mirror of https://github.com/AMT-Cheif/drift.git
Parse non-window aggregate expressions
This commit is contained in:
parent
56f048b42f
commit
c9e22bf8d2
|
@ -304,7 +304,7 @@ class LintingVisitor extends RecursiveVisitor<void, void> {
|
|||
));
|
||||
} else if (column is ExpressionResultColumn) {
|
||||
// While we're at it, window expressions aren't allowed either
|
||||
if (column.expression is AggregateExpression) {
|
||||
if (column.expression is WindowFunctionInvocation) {
|
||||
context.reportError(
|
||||
AnalysisError(
|
||||
type: AnalysisErrorType.illegalUseOfReturning,
|
||||
|
|
|
@ -6,16 +6,6 @@ class ReferenceResolver extends RecursiveVisitor<void, void> {
|
|||
|
||||
ReferenceResolver(this.context);
|
||||
|
||||
@override
|
||||
void visitAggregateExpression(AggregateExpression e, void arg) {
|
||||
if (e.windowName != null && e.resolved == null) {
|
||||
final resolved = e.scope.resolve<NamedWindowDeclaration>(e.windowName!);
|
||||
e.resolved = resolved;
|
||||
}
|
||||
|
||||
visitChildren(e, arg);
|
||||
}
|
||||
|
||||
@override
|
||||
void visitInsertStatement(InsertStatement e, void arg) {
|
||||
final table = e.table.resultSet;
|
||||
|
@ -115,6 +105,16 @@ class ReferenceResolver extends RecursiveVisitor<void, void> {
|
|||
visitChildren(e, arg);
|
||||
}
|
||||
|
||||
@override
|
||||
void visitWindowFunctionInvocation(WindowFunctionInvocation e, void arg) {
|
||||
if (e.windowName != null && e.resolved == null) {
|
||||
final resolved = e.scope.resolve<NamedWindowDeclaration>(e.windowName!);
|
||||
e.resolved = resolved;
|
||||
}
|
||||
|
||||
visitChildren(e, arg);
|
||||
}
|
||||
|
||||
void _reportUnknownColumnError(Reference e, {Iterable<Column>? columns}) {
|
||||
columns ??= e.scope.availableColumns;
|
||||
final columnNames = e.scope.availableColumns
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
part of '../ast.dart';
|
||||
|
||||
class AggregateExpression extends Expression
|
||||
class AggregateFunctionInvocation extends Expression
|
||||
implements ExpressionInvocation, ReferenceOwner {
|
||||
final IdentifierToken function;
|
||||
|
||||
|
@ -13,6 +13,31 @@ class AggregateExpression extends Expression
|
|||
|
||||
@override
|
||||
Referencable? resolved;
|
||||
|
||||
AggregateFunctionInvocation({
|
||||
required this.function,
|
||||
required this.parameters,
|
||||
this.filter,
|
||||
});
|
||||
|
||||
@override
|
||||
R accept<A, R>(AstVisitor<A, R> visitor, A arg) {
|
||||
return visitor.visitAggregateFunctionInvocation(this, arg);
|
||||
}
|
||||
|
||||
@override
|
||||
void transformChildren<A>(Transformer<A> transformer, A arg) {
|
||||
parameters = transformer.transformChild(parameters, this, arg);
|
||||
filter = transformer.transformNullableChild(filter, this, arg);
|
||||
}
|
||||
|
||||
@override
|
||||
Iterable<AstNode> get childNodes {
|
||||
return [parameters, if (filter != null) filter!];
|
||||
}
|
||||
}
|
||||
|
||||
class WindowFunctionInvocation extends AggregateFunctionInvocation {
|
||||
WindowDefinition? get over {
|
||||
if (windowDefinition != null) return windowDefinition;
|
||||
return (resolved as NamedWindowDeclaration?)?.definition;
|
||||
|
@ -32,18 +57,19 @@ class AggregateExpression extends Expression
|
|||
/// [over] in either case.
|
||||
final String? windowName;
|
||||
|
||||
AggregateExpression(
|
||||
{required this.function,
|
||||
required this.parameters,
|
||||
this.filter,
|
||||
WindowFunctionInvocation(
|
||||
{required IdentifierToken function,
|
||||
required FunctionParameters parameters,
|
||||
Expression? filter,
|
||||
this.windowDefinition,
|
||||
this.windowName})
|
||||
// either window definition or name must be null
|
||||
: assert((windowDefinition == null) != (windowName == null));
|
||||
// one of window definition or name must be null
|
||||
: assert((windowDefinition == null) || (windowName == null)),
|
||||
super(function: function, parameters: parameters, filter: filter);
|
||||
|
||||
@override
|
||||
R accept<A, R>(AstVisitor<A, R> visitor, A arg) {
|
||||
return visitor.visitAggregateExpression(this, arg);
|
||||
return visitor.visitWindowFunctionInvocation(this, arg);
|
||||
}
|
||||
|
||||
@override
|
||||
|
@ -66,7 +92,7 @@ class AggregateExpression extends Expression
|
|||
|
||||
/// A window declaration that appears in a `SELECT` statement like
|
||||
/// `WINDOW <name> AS <window-defn>`. It can be referenced from an
|
||||
/// [AggregateExpression] if it uses the same name.
|
||||
/// [AggregateFunctionInvocation] if it uses the same name.
|
||||
class NamedWindowDeclaration with Referencable {
|
||||
// todo: Should be an ast node
|
||||
final String name;
|
||||
|
|
|
@ -81,7 +81,8 @@ abstract class AstVisitor<A, R> {
|
|||
R visitInExpression(InExpression e, A arg);
|
||||
R visitRaiseExpression(RaiseExpression e, A arg);
|
||||
|
||||
R visitAggregateExpression(AggregateExpression e, A arg);
|
||||
R visitAggregateFunctionInvocation(AggregateFunctionInvocation e, A arg);
|
||||
R visitWindowFunctionInvocation(WindowFunctionInvocation e, A arg);
|
||||
R visitWindowDefinition(WindowDefinition e, A arg);
|
||||
R visitFrameSpec(FrameSpec e, A arg);
|
||||
R visitIndexedColumn(IndexedColumn e, A arg);
|
||||
|
@ -488,7 +489,12 @@ class RecursiveVisitor<A, R> implements AstVisitor<A, R?> {
|
|||
}
|
||||
|
||||
@override
|
||||
R? visitAggregateExpression(AggregateExpression e, A arg) {
|
||||
R? visitAggregateFunctionInvocation(AggregateFunctionInvocation e, A arg) {
|
||||
return visitExpressionInvocation(e, arg);
|
||||
}
|
||||
|
||||
@override
|
||||
R? visitWindowFunctionInvocation(WindowFunctionInvocation e, A arg) {
|
||||
return visitExpressionInvocation(e, arg);
|
||||
}
|
||||
|
||||
|
|
|
@ -913,7 +913,7 @@ class Parser {
|
|||
..setSpan(first, _previous);
|
||||
}
|
||||
|
||||
AggregateExpression _aggregate(
|
||||
AggregateFunctionInvocation _aggregate(
|
||||
IdentifierToken name, FunctionParameters params) {
|
||||
Expression? filter;
|
||||
|
||||
|
@ -926,24 +926,30 @@ class Parser {
|
|||
_consume(TokenType.rightParen, 'Expecteded closing parenthes');
|
||||
}
|
||||
|
||||
_consume(TokenType.over, 'Expected OVER to begin window clause');
|
||||
if (_matchOne(TokenType.over)) {
|
||||
String? windowName;
|
||||
WindowDefinition? window;
|
||||
|
||||
String? windowName;
|
||||
WindowDefinition? window;
|
||||
if (_matchOne(TokenType.identifier)) {
|
||||
windowName = (_previous as IdentifierToken).identifier;
|
||||
} else {
|
||||
window = _windowDefinition();
|
||||
}
|
||||
|
||||
if (_matchOne(TokenType.identifier)) {
|
||||
windowName = (_previous as IdentifierToken).identifier;
|
||||
return WindowFunctionInvocation(
|
||||
function: name,
|
||||
parameters: params,
|
||||
filter: filter,
|
||||
windowDefinition: window,
|
||||
windowName: windowName,
|
||||
)..setSpan(name, _previous);
|
||||
} else {
|
||||
window = _windowDefinition();
|
||||
return AggregateFunctionInvocation(
|
||||
function: name,
|
||||
parameters: params,
|
||||
filter: filter,
|
||||
)..setSpan(name, _previous);
|
||||
}
|
||||
|
||||
return AggregateExpression(
|
||||
function: name,
|
||||
parameters: params,
|
||||
filter: filter,
|
||||
windowDefinition: window,
|
||||
windowName: windowName,
|
||||
)..setSpan(name, _previous);
|
||||
}
|
||||
|
||||
/// Parses a [Tuple]. If [orSubQuery] is set (defaults to false), a [SubQuery]
|
||||
|
|
|
@ -44,9 +44,10 @@ class EqualityEnforcingVisitor implements AstVisitor<void, void> {
|
|||
: _considerChildren = considerChildren;
|
||||
|
||||
@override
|
||||
void visitAggregateExpression(AggregateExpression e, void arg) {
|
||||
final current = _currentAs<AggregateExpression>(e);
|
||||
_assert(current.name == e.name && current.windowName == e.windowName, e);
|
||||
void visitAggregateFunctionInvocation(
|
||||
AggregateFunctionInvocation e, void arg) {
|
||||
final current = _currentAs<AggregateFunctionInvocation>(e);
|
||||
_assert(current.name == e.name, e);
|
||||
_checkChildren(e);
|
||||
}
|
||||
|
||||
|
@ -716,6 +717,13 @@ class EqualityEnforcingVisitor implements AstVisitor<void, void> {
|
|||
_checkChildren(e);
|
||||
}
|
||||
|
||||
@override
|
||||
void visitWindowFunctionInvocation(WindowFunctionInvocation e, void arg) {
|
||||
final current = _currentAs<WindowFunctionInvocation>(e);
|
||||
_assert(current.name == e.name && current.windowName == e.windowName, e);
|
||||
_checkChildren(e);
|
||||
}
|
||||
|
||||
@override
|
||||
void visitWindowDefinition(WindowDefinition e, void arg) {
|
||||
final current = _currentAs<WindowDefinition>(e);
|
||||
|
|
|
@ -25,7 +25,8 @@ class NodeSqlBuilder extends AstVisitor<void, void> {
|
|||
}
|
||||
|
||||
@override
|
||||
void visitAggregateExpression(AggregateExpression e, void arg) {
|
||||
void visitAggregateFunctionInvocation(
|
||||
AggregateFunctionInvocation e, void arg) {
|
||||
symbol(e.name);
|
||||
|
||||
symbol('(');
|
||||
|
@ -39,6 +40,11 @@ class NodeSqlBuilder extends AstVisitor<void, void> {
|
|||
visit(e.filter!, arg);
|
||||
symbol(')', spaceAfter: true);
|
||||
}
|
||||
}
|
||||
|
||||
@override
|
||||
void visitWindowFunctionInvocation(WindowFunctionInvocation e, void arg) {
|
||||
visitAggregateFunctionInvocation(e, arg);
|
||||
|
||||
if (e.windowDefinition != null) {
|
||||
_keyword(TokenType.over);
|
||||
|
|
|
@ -184,7 +184,7 @@ SELECT row_number() OVER wnd FROM demo
|
|||
final column = (context.root as SelectStatement).resolvedColumns!.single
|
||||
as ExpressionColumn;
|
||||
|
||||
final over = (column.expression as AggregateExpression).over!;
|
||||
final over = (column.expression as WindowFunctionInvocation).over!;
|
||||
|
||||
enforceEqual(
|
||||
over,
|
||||
|
|
|
@ -7,7 +7,7 @@ import 'package:test/test.dart';
|
|||
import 'utils.dart';
|
||||
|
||||
final Map<String, Expression> _testCases = {
|
||||
'row_number() OVER (ORDER BY y)': AggregateExpression(
|
||||
'row_number() OVER (ORDER BY y)': WindowFunctionInvocation(
|
||||
function: identifier('row_number'),
|
||||
parameters: ExprFunctionParameters(),
|
||||
windowDefinition: WindowDefinition(
|
||||
|
@ -20,7 +20,7 @@ final Map<String, Expression> _testCases = {
|
|||
'row_number(*) FILTER (WHERE 1) OVER '
|
||||
'(base_name PARTITION BY a, b '
|
||||
'GROUPS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING EXCLUDE TIES)':
|
||||
AggregateExpression(
|
||||
WindowFunctionInvocation(
|
||||
function: identifier('row_number'),
|
||||
parameters: StarFunctionParameter(),
|
||||
filter: NumericLiteral(1, token(TokenType.numberLiteral)),
|
||||
|
@ -41,7 +41,7 @@ final Map<String, Expression> _testCases = {
|
|||
),
|
||||
),
|
||||
'row_number() OVER (RANGE CURRENT ROW EXCLUDE NO OTHERS)':
|
||||
AggregateExpression(
|
||||
WindowFunctionInvocation(
|
||||
function: identifier('row_number'),
|
||||
parameters: ExprFunctionParameters(),
|
||||
windowDefinition: WindowDefinition(
|
||||
|
@ -53,6 +53,18 @@ final Map<String, Expression> _testCases = {
|
|||
),
|
||||
),
|
||||
),
|
||||
'COUNT(is_skipped) FILTER (WHERE is_skipped = true)':
|
||||
AggregateFunctionInvocation(
|
||||
function: identifier('COUNT'),
|
||||
parameters: ExprFunctionParameters(
|
||||
parameters: [Reference(columnName: 'is_skipped')],
|
||||
),
|
||||
filter: BinaryExpression(
|
||||
Reference(columnName: 'is_skipped'),
|
||||
token(TokenType.equal),
|
||||
BooleanLiteral.withTrue(token(TokenType.$true)),
|
||||
),
|
||||
),
|
||||
};
|
||||
|
||||
void main() {
|
||||
|
|
|
@ -189,7 +189,7 @@ CREATE UNIQUE INDEX my_idx ON t1 (c1, c2, c3) WHERE c1 < c3;
|
|||
|
||||
test('with materialized CTEs', () {
|
||||
testFormat('''
|
||||
WITH
|
||||
WITH
|
||||
foo (id) AS NOT MATERIALIZED (SELECT 1),
|
||||
bar (id) AS MATERIALIZED (SELECT 2)
|
||||
SELECT * FROM foo UNION ALL SELECT * FROM bar;
|
||||
|
@ -238,6 +238,18 @@ CREATE UNIQUE INDEX my_idx ON t1 (c1, c2, c3) WHERE c1 < c3;
|
|||
''');
|
||||
});
|
||||
|
||||
test('aggregate', () {
|
||||
testFormat('''
|
||||
SELECT
|
||||
subs_id, subs_name,
|
||||
COUNT(is_skipped) FILTER (WHERE is_skipped = true) skipped,
|
||||
COUNT(is_touched) FILTER (WHERE is_touched = true) touched,
|
||||
COUNT(is_passed) FILTER (WHERE is_passed = true) passed
|
||||
FROM stats
|
||||
GROUP BY subs_id;
|
||||
''');
|
||||
});
|
||||
|
||||
test('joins', () {
|
||||
testFormat('''
|
||||
SELECT * FROM
|
||||
|
|
Loading…
Reference in New Issue