Parse non-window aggregate expressions

This commit is contained in:
Simon Binder 2022-03-08 10:58:40 +01:00
parent 56f048b42f
commit c9e22bf8d2
No known key found for this signature in database
GPG Key ID: 7891917E4147B8C0
10 changed files with 122 additions and 46 deletions

View File

@ -304,7 +304,7 @@ class LintingVisitor extends RecursiveVisitor<void, void> {
));
} else if (column is ExpressionResultColumn) {
// While we're at it, window expressions aren't allowed either
if (column.expression is AggregateExpression) {
if (column.expression is WindowFunctionInvocation) {
context.reportError(
AnalysisError(
type: AnalysisErrorType.illegalUseOfReturning,

View File

@ -6,16 +6,6 @@ class ReferenceResolver extends RecursiveVisitor<void, void> {
ReferenceResolver(this.context);
@override
void visitAggregateExpression(AggregateExpression e, void arg) {
if (e.windowName != null && e.resolved == null) {
final resolved = e.scope.resolve<NamedWindowDeclaration>(e.windowName!);
e.resolved = resolved;
}
visitChildren(e, arg);
}
@override
void visitInsertStatement(InsertStatement e, void arg) {
final table = e.table.resultSet;
@ -115,6 +105,16 @@ class ReferenceResolver extends RecursiveVisitor<void, void> {
visitChildren(e, arg);
}
@override
void visitWindowFunctionInvocation(WindowFunctionInvocation e, void arg) {
if (e.windowName != null && e.resolved == null) {
final resolved = e.scope.resolve<NamedWindowDeclaration>(e.windowName!);
e.resolved = resolved;
}
visitChildren(e, arg);
}
void _reportUnknownColumnError(Reference e, {Iterable<Column>? columns}) {
columns ??= e.scope.availableColumns;
final columnNames = e.scope.availableColumns

View File

@ -1,6 +1,6 @@
part of '../ast.dart';
class AggregateExpression extends Expression
class AggregateFunctionInvocation extends Expression
implements ExpressionInvocation, ReferenceOwner {
final IdentifierToken function;
@ -13,6 +13,31 @@ class AggregateExpression extends Expression
@override
Referencable? resolved;
AggregateFunctionInvocation({
required this.function,
required this.parameters,
this.filter,
});
@override
R accept<A, R>(AstVisitor<A, R> visitor, A arg) {
return visitor.visitAggregateFunctionInvocation(this, arg);
}
@override
void transformChildren<A>(Transformer<A> transformer, A arg) {
parameters = transformer.transformChild(parameters, this, arg);
filter = transformer.transformNullableChild(filter, this, arg);
}
@override
Iterable<AstNode> get childNodes {
return [parameters, if (filter != null) filter!];
}
}
class WindowFunctionInvocation extends AggregateFunctionInvocation {
WindowDefinition? get over {
if (windowDefinition != null) return windowDefinition;
return (resolved as NamedWindowDeclaration?)?.definition;
@ -32,18 +57,19 @@ class AggregateExpression extends Expression
/// [over] in either case.
final String? windowName;
AggregateExpression(
{required this.function,
required this.parameters,
this.filter,
WindowFunctionInvocation(
{required IdentifierToken function,
required FunctionParameters parameters,
Expression? filter,
this.windowDefinition,
this.windowName})
// either window definition or name must be null
: assert((windowDefinition == null) != (windowName == null));
// one of window definition or name must be null
: assert((windowDefinition == null) || (windowName == null)),
super(function: function, parameters: parameters, filter: filter);
@override
R accept<A, R>(AstVisitor<A, R> visitor, A arg) {
return visitor.visitAggregateExpression(this, arg);
return visitor.visitWindowFunctionInvocation(this, arg);
}
@override
@ -66,7 +92,7 @@ class AggregateExpression extends Expression
/// A window declaration that appears in a `SELECT` statement like
/// `WINDOW <name> AS <window-defn>`. It can be referenced from an
/// [AggregateExpression] if it uses the same name.
/// [AggregateFunctionInvocation] if it uses the same name.
class NamedWindowDeclaration with Referencable {
// todo: Should be an ast node
final String name;

View File

@ -81,7 +81,8 @@ abstract class AstVisitor<A, R> {
R visitInExpression(InExpression e, A arg);
R visitRaiseExpression(RaiseExpression e, A arg);
R visitAggregateExpression(AggregateExpression e, A arg);
R visitAggregateFunctionInvocation(AggregateFunctionInvocation e, A arg);
R visitWindowFunctionInvocation(WindowFunctionInvocation e, A arg);
R visitWindowDefinition(WindowDefinition e, A arg);
R visitFrameSpec(FrameSpec e, A arg);
R visitIndexedColumn(IndexedColumn e, A arg);
@ -488,7 +489,12 @@ class RecursiveVisitor<A, R> implements AstVisitor<A, R?> {
}
@override
R? visitAggregateExpression(AggregateExpression e, A arg) {
R? visitAggregateFunctionInvocation(AggregateFunctionInvocation e, A arg) {
return visitExpressionInvocation(e, arg);
}
@override
R? visitWindowFunctionInvocation(WindowFunctionInvocation e, A arg) {
return visitExpressionInvocation(e, arg);
}

View File

@ -913,7 +913,7 @@ class Parser {
..setSpan(first, _previous);
}
AggregateExpression _aggregate(
AggregateFunctionInvocation _aggregate(
IdentifierToken name, FunctionParameters params) {
Expression? filter;
@ -926,24 +926,30 @@ class Parser {
_consume(TokenType.rightParen, 'Expecteded closing parenthes');
}
_consume(TokenType.over, 'Expected OVER to begin window clause');
if (_matchOne(TokenType.over)) {
String? windowName;
WindowDefinition? window;
String? windowName;
WindowDefinition? window;
if (_matchOne(TokenType.identifier)) {
windowName = (_previous as IdentifierToken).identifier;
} else {
window = _windowDefinition();
}
if (_matchOne(TokenType.identifier)) {
windowName = (_previous as IdentifierToken).identifier;
return WindowFunctionInvocation(
function: name,
parameters: params,
filter: filter,
windowDefinition: window,
windowName: windowName,
)..setSpan(name, _previous);
} else {
window = _windowDefinition();
return AggregateFunctionInvocation(
function: name,
parameters: params,
filter: filter,
)..setSpan(name, _previous);
}
return AggregateExpression(
function: name,
parameters: params,
filter: filter,
windowDefinition: window,
windowName: windowName,
)..setSpan(name, _previous);
}
/// Parses a [Tuple]. If [orSubQuery] is set (defaults to false), a [SubQuery]

View File

@ -44,9 +44,10 @@ class EqualityEnforcingVisitor implements AstVisitor<void, void> {
: _considerChildren = considerChildren;
@override
void visitAggregateExpression(AggregateExpression e, void arg) {
final current = _currentAs<AggregateExpression>(e);
_assert(current.name == e.name && current.windowName == e.windowName, e);
void visitAggregateFunctionInvocation(
AggregateFunctionInvocation e, void arg) {
final current = _currentAs<AggregateFunctionInvocation>(e);
_assert(current.name == e.name, e);
_checkChildren(e);
}
@ -716,6 +717,13 @@ class EqualityEnforcingVisitor implements AstVisitor<void, void> {
_checkChildren(e);
}
@override
void visitWindowFunctionInvocation(WindowFunctionInvocation e, void arg) {
final current = _currentAs<WindowFunctionInvocation>(e);
_assert(current.name == e.name && current.windowName == e.windowName, e);
_checkChildren(e);
}
@override
void visitWindowDefinition(WindowDefinition e, void arg) {
final current = _currentAs<WindowDefinition>(e);

View File

@ -25,7 +25,8 @@ class NodeSqlBuilder extends AstVisitor<void, void> {
}
@override
void visitAggregateExpression(AggregateExpression e, void arg) {
void visitAggregateFunctionInvocation(
AggregateFunctionInvocation e, void arg) {
symbol(e.name);
symbol('(');
@ -39,6 +40,11 @@ class NodeSqlBuilder extends AstVisitor<void, void> {
visit(e.filter!, arg);
symbol(')', spaceAfter: true);
}
}
@override
void visitWindowFunctionInvocation(WindowFunctionInvocation e, void arg) {
visitAggregateFunctionInvocation(e, arg);
if (e.windowDefinition != null) {
_keyword(TokenType.over);

View File

@ -184,7 +184,7 @@ SELECT row_number() OVER wnd FROM demo
final column = (context.root as SelectStatement).resolvedColumns!.single
as ExpressionColumn;
final over = (column.expression as AggregateExpression).over!;
final over = (column.expression as WindowFunctionInvocation).over!;
enforceEqual(
over,

View File

@ -7,7 +7,7 @@ import 'package:test/test.dart';
import 'utils.dart';
final Map<String, Expression> _testCases = {
'row_number() OVER (ORDER BY y)': AggregateExpression(
'row_number() OVER (ORDER BY y)': WindowFunctionInvocation(
function: identifier('row_number'),
parameters: ExprFunctionParameters(),
windowDefinition: WindowDefinition(
@ -20,7 +20,7 @@ final Map<String, Expression> _testCases = {
'row_number(*) FILTER (WHERE 1) OVER '
'(base_name PARTITION BY a, b '
'GROUPS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING EXCLUDE TIES)':
AggregateExpression(
WindowFunctionInvocation(
function: identifier('row_number'),
parameters: StarFunctionParameter(),
filter: NumericLiteral(1, token(TokenType.numberLiteral)),
@ -41,7 +41,7 @@ final Map<String, Expression> _testCases = {
),
),
'row_number() OVER (RANGE CURRENT ROW EXCLUDE NO OTHERS)':
AggregateExpression(
WindowFunctionInvocation(
function: identifier('row_number'),
parameters: ExprFunctionParameters(),
windowDefinition: WindowDefinition(
@ -53,6 +53,18 @@ final Map<String, Expression> _testCases = {
),
),
),
'COUNT(is_skipped) FILTER (WHERE is_skipped = true)':
AggregateFunctionInvocation(
function: identifier('COUNT'),
parameters: ExprFunctionParameters(
parameters: [Reference(columnName: 'is_skipped')],
),
filter: BinaryExpression(
Reference(columnName: 'is_skipped'),
token(TokenType.equal),
BooleanLiteral.withTrue(token(TokenType.$true)),
),
),
};
void main() {

View File

@ -189,7 +189,7 @@ CREATE UNIQUE INDEX my_idx ON t1 (c1, c2, c3) WHERE c1 < c3;
test('with materialized CTEs', () {
testFormat('''
WITH
WITH
foo (id) AS NOT MATERIALIZED (SELECT 1),
bar (id) AS MATERIALIZED (SELECT 2)
SELECT * FROM foo UNION ALL SELECT * FROM bar;
@ -238,6 +238,18 @@ CREATE UNIQUE INDEX my_idx ON t1 (c1, c2, c3) WHERE c1 < c3;
''');
});
test('aggregate', () {
testFormat('''
SELECT
subs_id, subs_name,
COUNT(is_skipped) FILTER (WHERE is_skipped = true) skipped,
COUNT(is_touched) FILTER (WHERE is_touched = true) touched,
COUNT(is_passed) FILTER (WHERE is_passed = true) passed
FROM stats
GROUP BY subs_id;
''');
});
test('joins', () {
testFormat('''
SELECT * FROM