Parse FROM clauses for update statements

This commit is contained in:
Simon Binder 2021-03-13 12:02:35 +01:00
parent 8499ef4e10
commit 3000cb2e44
No known key found for this signature in database
GPG Key ID: 7891917E4147B8C0
5 changed files with 99 additions and 13 deletions

View File

@ -22,16 +22,18 @@ class UpdateStatement extends CrudStatement
@override @override
TableReference table; TableReference table;
final List<SetComponent> set; final List<SetComponent> set;
Queryable? from;
@override @override
Expression? where; Expression? where;
UpdateStatement( UpdateStatement({
{WithClause? withClause, WithClause? withClause,
this.or, this.or,
required this.table, required this.table,
required this.set, required this.set,
this.where}) this.from,
: super._(withClause); this.where,
}) : super._(withClause);
@override @override
R accept<A, R>(AstVisitor<A, R> visitor, A arg) { R accept<A, R>(AstVisitor<A, R> visitor, A arg) {
@ -43,6 +45,7 @@ class UpdateStatement extends CrudStatement
withClause = transformer.transformNullableChild(withClause, this, arg); withClause = transformer.transformNullableChild(withClause, this, arg);
table = transformer.transformChild(table, this, arg); table = transformer.transformChild(table, this, arg);
transformer.transformChildren(set, this, arg); transformer.transformChildren(set, this, arg);
from = transformer.transformNullableChild(from, this, arg);
where = transformer.transformChild(where!, this, arg); where = transformer.transformChild(where!, this, arg);
} }
@ -51,6 +54,7 @@ class UpdateStatement extends CrudStatement
if (withClause != null) withClause!, if (withClause != null) withClause!,
table, table,
...set, ...set,
if (from != null) from!,
if (where != null) where!, if (where != null) where!,
]; ];

View File

@ -1447,6 +1447,7 @@ class Parser {
_consume(TokenType.set, 'Expected SET after the table name'); _consume(TokenType.set, 'Expected SET after the table name');
final set = _setComponents(); final set = _setComponents();
final from = _from();
final where = _where(); final where = _where();
return UpdateStatement( return UpdateStatement(
@ -1454,6 +1455,7 @@ class Parser {
or: failureMode, or: failureMode,
table: table, table: table,
set: set, set: set,
from: from,
where: where, where: where,
)..setSpan(withClause?.first ?? updateToken, _previous); )..setSpan(withClause?.first ?? updateToken, _previous);
} }

View File

@ -101,6 +101,13 @@ class NodeSqlBuilder extends AstVisitor<void, void> {
} }
} }
void _from(Queryable? from) {
if (from != null) {
_keyword(TokenType.from);
visit(from, null);
}
}
@override @override
void visitAggregateExpression(AggregateExpression e, void arg) { void visitAggregateExpression(AggregateExpression e, void arg) {
_symbol(e.name); _symbol(e.name);
@ -484,8 +491,7 @@ class NodeSqlBuilder extends AstVisitor<void, void> {
visitNullable(e.withClause, arg); visitNullable(e.withClause, arg);
_keyword(TokenType.delete); _keyword(TokenType.delete);
_keyword(TokenType.from); _from(e.from);
visit(e.from!, null);
_where(e.where); _where(e.where);
} }
@ -958,10 +964,7 @@ class NodeSqlBuilder extends AstVisitor<void, void> {
_join(e.columns, ','); _join(e.columns, ',');
if (e.from != null) { _from(e.from);
_keyword(TokenType.from);
visit(e.from!, arg);
}
_where(e.where); _where(e.where);
visitNullable(e.groupBy, arg); visitNullable(e.groupBy, arg);
if (e.windowDeclarations.isNotEmpty) { if (e.windowDeclarations.isNotEmpty) {
@ -1156,6 +1159,7 @@ class NodeSqlBuilder extends AstVisitor<void, void> {
visit(e.table, arg); visit(e.table, arg);
_keyword(TokenType.set); _keyword(TokenType.set);
_join(e.set, ','); _join(e.set, ',');
_from(e.from);
_where(e.where); _where(e.where);
} }

View File

@ -26,4 +26,57 @@ void main() {
group('update statements', () { group('update statements', () {
testAll(testCases); testAll(testCases);
}); });
test('parses updates with FROM clause', () {
testStatement(
'''
UPDATE inventory
SET quantity = quantity - daily.amt
FROM (SELECT sum(quantity) AS amt,
itemId FROM sales GROUP BY 2) AS daily
WHERE inventory.itemId = daily.itemId;
''',
UpdateStatement(
table: TableReference('inventory'),
set: [
SetComponent(
column: Reference(columnName: 'quantity'),
expression: BinaryExpression(
Reference(columnName: 'quantity'),
token(TokenType.minus),
Reference(entityName: 'daily', columnName: 'amt'),
),
),
],
from: SelectStatementAsSource(
statement: SelectStatement(
columns: [
ExpressionResultColumn(
expression: FunctionExpression(
name: 'sum',
parameters: ExprFunctionParameters(
parameters: [Reference(columnName: 'quantity')],
),
),
as: 'amt',
),
ExpressionResultColumn(
expression: Reference(columnName: 'itemId'),
),
],
from: TableReference('sales'),
groupBy: GroupBy(
by: [NumericLiteral(2, token(TokenType.numberLiteral))],
),
),
as: 'daily',
),
where: BinaryExpression(
Reference(entityName: 'inventory', columnName: 'itemId'),
token(TokenType.equal),
Reference(entityName: 'daily', columnName: 'itemId'),
),
),
);
});
} }

View File

@ -239,6 +239,29 @@ CREATE UNIQUE INDEX my_idx ON t1 (c1, c2, c3) WHERE c1 < c3;
'ON CONFLICT DO UPDATE SET a = b, c = d WHERE d < a;'); 'ON CONFLICT DO UPDATE SET a = b, c = d WHERE d < a;');
}); });
}); });
group('update', () {
test('simple', () {
testFormat('UPDATE foo SET bar = baz WHERE 1;');
});
const modes = [
'OR ABORT',
'OR FAIL',
'OR IGNORE',
'OR REPLACE',
' OR ROLLBACK',
];
for (var i = 0; i < modes.length; i++) {
test('failure mode #$i', () {
testFormat('UPDATE ${modes[i]} foo SET bar = baz');
});
}
test('from', () {
testFormat('UPDATE foo SET bar = baz FROM t1 CROSS JOIN t2');
});
});
}); });
group('expressions', () { group('expressions', () {