diff --git a/sqlparser/lib/src/analysis/steps/linting_visitor.dart b/sqlparser/lib/src/analysis/steps/linting_visitor.dart index 517d9f20..0b299711 100644 --- a/sqlparser/lib/src/analysis/steps/linting_visitor.dart +++ b/sqlparser/lib/src/analysis/steps/linting_visitor.dart @@ -304,7 +304,7 @@ class LintingVisitor extends RecursiveVisitor { )); } else if (column is ExpressionResultColumn) { // While we're at it, window expressions aren't allowed either - if (column.expression is AggregateExpression) { + if (column.expression is WindowFunctionInvocation) { context.reportError( AnalysisError( type: AnalysisErrorType.illegalUseOfReturning, diff --git a/sqlparser/lib/src/analysis/steps/reference_resolver.dart b/sqlparser/lib/src/analysis/steps/reference_resolver.dart index 51bff015..bf8233bc 100644 --- a/sqlparser/lib/src/analysis/steps/reference_resolver.dart +++ b/sqlparser/lib/src/analysis/steps/reference_resolver.dart @@ -6,16 +6,6 @@ class ReferenceResolver extends RecursiveVisitor { ReferenceResolver(this.context); - @override - void visitAggregateExpression(AggregateExpression e, void arg) { - if (e.windowName != null && e.resolved == null) { - final resolved = e.scope.resolve(e.windowName!); - e.resolved = resolved; - } - - visitChildren(e, arg); - } - @override void visitInsertStatement(InsertStatement e, void arg) { final table = e.table.resultSet; @@ -115,6 +105,16 @@ class ReferenceResolver extends RecursiveVisitor { visitChildren(e, arg); } + @override + void visitWindowFunctionInvocation(WindowFunctionInvocation e, void arg) { + if (e.windowName != null && e.resolved == null) { + final resolved = e.scope.resolve(e.windowName!); + e.resolved = resolved; + } + + visitChildren(e, arg); + } + void _reportUnknownColumnError(Reference e, {Iterable? columns}) { columns ??= e.scope.availableColumns; final columnNames = e.scope.availableColumns diff --git a/sqlparser/lib/src/ast/expressions/aggregate.dart b/sqlparser/lib/src/ast/expressions/aggregate.dart index 5d17023c..b8ccb348 100644 --- a/sqlparser/lib/src/ast/expressions/aggregate.dart +++ b/sqlparser/lib/src/ast/expressions/aggregate.dart @@ -1,6 +1,6 @@ part of '../ast.dart'; -class AggregateExpression extends Expression +class AggregateFunctionInvocation extends Expression implements ExpressionInvocation, ReferenceOwner { final IdentifierToken function; @@ -13,6 +13,31 @@ class AggregateExpression extends Expression @override Referencable? resolved; + + AggregateFunctionInvocation({ + required this.function, + required this.parameters, + this.filter, + }); + + @override + R accept(AstVisitor visitor, A arg) { + return visitor.visitAggregateFunctionInvocation(this, arg); + } + + @override + void transformChildren(Transformer transformer, A arg) { + parameters = transformer.transformChild(parameters, this, arg); + filter = transformer.transformNullableChild(filter, this, arg); + } + + @override + Iterable get childNodes { + return [parameters, if (filter != null) filter!]; + } +} + +class WindowFunctionInvocation extends AggregateFunctionInvocation { WindowDefinition? get over { if (windowDefinition != null) return windowDefinition; return (resolved as NamedWindowDeclaration?)?.definition; @@ -32,18 +57,19 @@ class AggregateExpression extends Expression /// [over] in either case. final String? windowName; - AggregateExpression( - {required this.function, - required this.parameters, - this.filter, + WindowFunctionInvocation( + {required IdentifierToken function, + required FunctionParameters parameters, + Expression? filter, this.windowDefinition, this.windowName}) - // either window definition or name must be null - : assert((windowDefinition == null) != (windowName == null)); + // one of window definition or name must be null + : assert((windowDefinition == null) || (windowName == null)), + super(function: function, parameters: parameters, filter: filter); @override R accept(AstVisitor visitor, A arg) { - return visitor.visitAggregateExpression(this, arg); + return visitor.visitWindowFunctionInvocation(this, arg); } @override @@ -66,7 +92,7 @@ class AggregateExpression extends Expression /// A window declaration that appears in a `SELECT` statement like /// `WINDOW AS `. It can be referenced from an -/// [AggregateExpression] if it uses the same name. +/// [AggregateFunctionInvocation] if it uses the same name. class NamedWindowDeclaration with Referencable { // todo: Should be an ast node final String name; diff --git a/sqlparser/lib/src/ast/visitor.dart b/sqlparser/lib/src/ast/visitor.dart index 4f7cfc83..226eb88a 100644 --- a/sqlparser/lib/src/ast/visitor.dart +++ b/sqlparser/lib/src/ast/visitor.dart @@ -81,7 +81,8 @@ abstract class AstVisitor { R visitInExpression(InExpression e, A arg); R visitRaiseExpression(RaiseExpression e, A arg); - R visitAggregateExpression(AggregateExpression e, A arg); + R visitAggregateFunctionInvocation(AggregateFunctionInvocation e, A arg); + R visitWindowFunctionInvocation(WindowFunctionInvocation e, A arg); R visitWindowDefinition(WindowDefinition e, A arg); R visitFrameSpec(FrameSpec e, A arg); R visitIndexedColumn(IndexedColumn e, A arg); @@ -488,7 +489,12 @@ class RecursiveVisitor implements AstVisitor { } @override - R? visitAggregateExpression(AggregateExpression e, A arg) { + R? visitAggregateFunctionInvocation(AggregateFunctionInvocation e, A arg) { + return visitExpressionInvocation(e, arg); + } + + @override + R? visitWindowFunctionInvocation(WindowFunctionInvocation e, A arg) { return visitExpressionInvocation(e, arg); } diff --git a/sqlparser/lib/src/reader/parser.dart b/sqlparser/lib/src/reader/parser.dart index f6024405..1d31e214 100644 --- a/sqlparser/lib/src/reader/parser.dart +++ b/sqlparser/lib/src/reader/parser.dart @@ -913,7 +913,7 @@ class Parser { ..setSpan(first, _previous); } - AggregateExpression _aggregate( + AggregateFunctionInvocation _aggregate( IdentifierToken name, FunctionParameters params) { Expression? filter; @@ -926,24 +926,30 @@ class Parser { _consume(TokenType.rightParen, 'Expecteded closing parenthes'); } - _consume(TokenType.over, 'Expected OVER to begin window clause'); + if (_matchOne(TokenType.over)) { + String? windowName; + WindowDefinition? window; - String? windowName; - WindowDefinition? window; + if (_matchOne(TokenType.identifier)) { + windowName = (_previous as IdentifierToken).identifier; + } else { + window = _windowDefinition(); + } - if (_matchOne(TokenType.identifier)) { - windowName = (_previous as IdentifierToken).identifier; + return WindowFunctionInvocation( + function: name, + parameters: params, + filter: filter, + windowDefinition: window, + windowName: windowName, + )..setSpan(name, _previous); } else { - window = _windowDefinition(); + return AggregateFunctionInvocation( + function: name, + parameters: params, + filter: filter, + )..setSpan(name, _previous); } - - return AggregateExpression( - function: name, - parameters: params, - filter: filter, - windowDefinition: window, - windowName: windowName, - )..setSpan(name, _previous); } /// Parses a [Tuple]. If [orSubQuery] is set (defaults to false), a [SubQuery] diff --git a/sqlparser/lib/src/utils/ast_equality.dart b/sqlparser/lib/src/utils/ast_equality.dart index cbb3e329..577470b0 100644 --- a/sqlparser/lib/src/utils/ast_equality.dart +++ b/sqlparser/lib/src/utils/ast_equality.dart @@ -44,9 +44,10 @@ class EqualityEnforcingVisitor implements AstVisitor { : _considerChildren = considerChildren; @override - void visitAggregateExpression(AggregateExpression e, void arg) { - final current = _currentAs(e); - _assert(current.name == e.name && current.windowName == e.windowName, e); + void visitAggregateFunctionInvocation( + AggregateFunctionInvocation e, void arg) { + final current = _currentAs(e); + _assert(current.name == e.name, e); _checkChildren(e); } @@ -716,6 +717,13 @@ class EqualityEnforcingVisitor implements AstVisitor { _checkChildren(e); } + @override + void visitWindowFunctionInvocation(WindowFunctionInvocation e, void arg) { + final current = _currentAs(e); + _assert(current.name == e.name && current.windowName == e.windowName, e); + _checkChildren(e); + } + @override void visitWindowDefinition(WindowDefinition e, void arg) { final current = _currentAs(e); diff --git a/sqlparser/lib/utils/node_to_text.dart b/sqlparser/lib/utils/node_to_text.dart index 57169833..d7d69b5a 100644 --- a/sqlparser/lib/utils/node_to_text.dart +++ b/sqlparser/lib/utils/node_to_text.dart @@ -25,7 +25,8 @@ class NodeSqlBuilder extends AstVisitor { } @override - void visitAggregateExpression(AggregateExpression e, void arg) { + void visitAggregateFunctionInvocation( + AggregateFunctionInvocation e, void arg) { symbol(e.name); symbol('('); @@ -39,6 +40,11 @@ class NodeSqlBuilder extends AstVisitor { visit(e.filter!, arg); symbol(')', spaceAfter: true); } + } + + @override + void visitWindowFunctionInvocation(WindowFunctionInvocation e, void arg) { + visitAggregateFunctionInvocation(e, arg); if (e.windowDefinition != null) { _keyword(TokenType.over); diff --git a/sqlparser/test/analysis/reference_resolver_test.dart b/sqlparser/test/analysis/reference_resolver_test.dart index 2cb71e02..95e6562a 100644 --- a/sqlparser/test/analysis/reference_resolver_test.dart +++ b/sqlparser/test/analysis/reference_resolver_test.dart @@ -184,7 +184,7 @@ SELECT row_number() OVER wnd FROM demo final column = (context.root as SelectStatement).resolvedColumns!.single as ExpressionColumn; - final over = (column.expression as AggregateExpression).over!; + final over = (column.expression as WindowFunctionInvocation).over!; enforceEqual( over, diff --git a/sqlparser/test/parser/partition_test.dart b/sqlparser/test/parser/partition_test.dart index fc5c754f..64905265 100644 --- a/sqlparser/test/parser/partition_test.dart +++ b/sqlparser/test/parser/partition_test.dart @@ -7,7 +7,7 @@ import 'package:test/test.dart'; import 'utils.dart'; final Map _testCases = { - 'row_number() OVER (ORDER BY y)': AggregateExpression( + 'row_number() OVER (ORDER BY y)': WindowFunctionInvocation( function: identifier('row_number'), parameters: ExprFunctionParameters(), windowDefinition: WindowDefinition( @@ -20,7 +20,7 @@ final Map _testCases = { 'row_number(*) FILTER (WHERE 1) OVER ' '(base_name PARTITION BY a, b ' 'GROUPS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING EXCLUDE TIES)': - AggregateExpression( + WindowFunctionInvocation( function: identifier('row_number'), parameters: StarFunctionParameter(), filter: NumericLiteral(1, token(TokenType.numberLiteral)), @@ -41,7 +41,7 @@ final Map _testCases = { ), ), 'row_number() OVER (RANGE CURRENT ROW EXCLUDE NO OTHERS)': - AggregateExpression( + WindowFunctionInvocation( function: identifier('row_number'), parameters: ExprFunctionParameters(), windowDefinition: WindowDefinition( @@ -53,6 +53,18 @@ final Map _testCases = { ), ), ), + 'COUNT(is_skipped) FILTER (WHERE is_skipped = true)': + AggregateFunctionInvocation( + function: identifier('COUNT'), + parameters: ExprFunctionParameters( + parameters: [Reference(columnName: 'is_skipped')], + ), + filter: BinaryExpression( + Reference(columnName: 'is_skipped'), + token(TokenType.equal), + BooleanLiteral.withTrue(token(TokenType.$true)), + ), + ), }; void main() { diff --git a/sqlparser/test/utils/node_to_text_test.dart b/sqlparser/test/utils/node_to_text_test.dart index 61339007..cce7d726 100644 --- a/sqlparser/test/utils/node_to_text_test.dart +++ b/sqlparser/test/utils/node_to_text_test.dart @@ -189,7 +189,7 @@ CREATE UNIQUE INDEX my_idx ON t1 (c1, c2, c3) WHERE c1 < c3; test('with materialized CTEs', () { testFormat(''' - WITH + WITH foo (id) AS NOT MATERIALIZED (SELECT 1), bar (id) AS MATERIALIZED (SELECT 2) SELECT * FROM foo UNION ALL SELECT * FROM bar; @@ -238,6 +238,18 @@ CREATE UNIQUE INDEX my_idx ON t1 (c1, c2, c3) WHERE c1 < c3; '''); }); + test('aggregate', () { + testFormat(''' + SELECT + subs_id, subs_name, + COUNT(is_skipped) FILTER (WHERE is_skipped = true) skipped, + COUNT(is_touched) FILTER (WHERE is_touched = true) touched, + COUNT(is_passed) FILTER (WHERE is_passed = true) passed + FROM stats + GROUP BY subs_id; + '''); + }); + test('joins', () { testFormat(''' SELECT * FROM