Merge branch 'add-update-column-name-list' into develop

This commit is contained in:
Simon Binder 2023-11-18 00:32:33 +01:00
commit fcce984af4
No known key found for this signature in database
GPG Key ID: 7891917E4147B8C0
18 changed files with 379 additions and 41 deletions

View File

@ -388,9 +388,8 @@ class LintingVisitor extends RecursiveVisitor<void, void> {
visitChildren(e, arg);
}
@override
void visitSetComponent(SetComponent e, void arg) {
final target = e.column.resolvedColumn;
void _checkForGeneratedColumn(Reference column) {
final target = column.resolvedColumn;
if (target is TableColumn && target.isGenerated) {
context.reportError(
@ -398,7 +397,43 @@ class LintingVisitor extends RecursiveVisitor<void, void> {
type: AnalysisErrorType.writeToGeneratedColumn,
message: 'This column is generated, and generated columns cannot be '
'updated explicitly.',
relevantNode: e.column,
relevantNode: column,
),
);
}
}
@override
void visitSingleColumnSetComponent(SingleColumnSetComponent e, void arg) {
_checkForGeneratedColumn(e.column);
visitChildren(e, arg);
}
@override
void visitMultiColumnSetComponent(MultiColumnSetComponent e, void arg) {
for (final column in e.columns) {
_checkForGeneratedColumn(column);
}
if (e.rowValue is Tuple &&
e.columns.length != (e.rowValue as Tuple).expressions.length) {
context.reportError(
AnalysisError(
type: AnalysisErrorType.cteColumnCountMismatch,
message:
'Length of column-name-list must match length of row values.',
relevantNode: e.rowValue,
),
);
} else if (e.rowValue is SubQuery &&
e.columns.length !=
(e.rowValue as SubQuery).select.resolvedColumns?.length) {
context.reportError(
AnalysisError(
type: AnalysisErrorType.cteColumnCountMismatch,
message:
'Length of column-name-list must match length of columns returned by SubQuery.',
relevantNode: e.rowValue,
),
);
}
@ -494,6 +529,8 @@ class LintingVisitor extends RecursiveVisitor<void, void> {
// In expressions are tricky. The rhs can always be a row value, but the
// lhs can only be a row value if the rhs is a subquery
isAllowed = e == parent.inside || parent.inside is SubQuery;
} else if (parent is SetComponent) {
isAllowed = true;
}
if (!isAllowed) {

View File

@ -113,7 +113,9 @@ class ReferenceResolver
if (table != null) {
// Resolve the set components against the primary table
for (final set in e.set) {
_resolveReferenceInTable(set.column, table);
for (final column in set.columns) {
_resolveReferenceInTable(column, table);
}
}
}

View File

@ -131,12 +131,33 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
}
@override
void visitSetComponent(SetComponent e, TypeExpectation arg) {
void visitSingleColumnSetComponent(
SingleColumnSetComponent e, TypeExpectation arg) {
visit(e.column, const NoTypeExpectation());
_lazyCopy(e.expression, e.column);
visit(e.expression, const NoTypeExpectation());
}
@override
void visitMultiColumnSetComponent(
MultiColumnSetComponent e, TypeExpectation arg) {
visitList(e.columns, const NoTypeExpectation());
final targets = e.resolvedTargetColumns ?? const [];
for (final column in targets) {
_handleColumn(column, e);
}
final expectations = targets.map((r) {
if (r != null && session.graph.knowsType(r)) {
return ExactTypeExpectation(session.typeOf(r)!);
}
return const NoTypeExpectation();
}).toList();
visit(e.rowValue, SelectTypeExpectation(expectations));
}
@override
void visitGroupBy(GroupBy e, TypeExpectation arg) {
visitList(e.by, const NoTypeExpectation());

View File

@ -78,15 +78,22 @@ class UpdateStatement extends CrudStatement
}
}
class SetComponent extends AstNode {
abstract class SetComponent extends AstNode {
List<Reference> get columns;
}
class SingleColumnSetComponent extends SetComponent {
Reference column;
Expression expression;
SetComponent({required this.column, required this.expression});
@override
List<Reference> get columns => [column];
SingleColumnSetComponent({required this.column, required this.expression});
@override
R accept<A, R>(AstVisitor<A, R> visitor, A arg) {
return visitor.visitSetComponent(this, arg);
return visitor.visitSingleColumnSetComponent(this, arg);
}
@override
@ -98,3 +105,31 @@ class SetComponent extends AstNode {
@override
Iterable<AstNode> get childNodes => [column, expression];
}
class MultiColumnSetComponent extends SetComponent {
@override
List<Reference> columns;
// Will be either Tuple or SubQuery
Expression rowValue;
List<Column?>? get resolvedTargetColumns {
if (columns.isEmpty) return null;
return columns.map((c) => c.resolvedColumn).toList();
}
MultiColumnSetComponent({required this.columns, required this.rowValue});
@override
R accept<A, R>(AstVisitor<A, R> visitor, A arg) {
return visitor.visitMultiColumnSetComponent(this, arg);
}
@override
void transformChildren<A>(Transformer<A> transformer, A arg) {
columns = transformer.transformChildren(columns, this, arg);
rowValue = transformer.transformChild(rowValue, this, arg);
}
@override
Iterable<AstNode> get childNodes => [...columns, rowValue];
}

View File

@ -43,7 +43,8 @@ abstract class AstVisitor<A, R> {
R visitDoNothing(DoNothing e, A arg);
R visitDoUpdate(DoUpdate e, A arg);
R visitSetComponent(SetComponent e, A arg);
R visitSingleColumnSetComponent(SingleColumnSetComponent e, A arg);
R visitMultiColumnSetComponent(MultiColumnSetComponent e, A arg);
R visitValuesSource(ValuesSource e, A arg);
R visitSelectInsertSource(SelectInsertSource e, A arg);
@ -315,7 +316,16 @@ class RecursiveVisitor<A, R> implements AstVisitor<A, R?> {
}
@override
R? visitSetComponent(SetComponent e, A arg) {
R? visitSingleColumnSetComponent(SingleColumnSetComponent e, A arg) {
return defaultSetComponent(e, arg);
}
@override
R? visitMultiColumnSetComponent(MultiColumnSetComponent e, A arg) {
return defaultSetComponent(e, arg);
}
R? defaultSetComponent(SetComponent e, A arg) {
return defaultNode(e, arg);
}

View File

@ -1030,8 +1030,10 @@ class Parser {
}
/// Parses a [Tuple]. If [orSubQuery] is set (defaults to false), a [SubQuery]
/// (in brackets) will be accepted as well.
Expression _consumeTuple({bool orSubQuery = false}) {
/// (in brackets) will be accepted as well. If parsing a [Tuple], [usedAsRowValue] is
/// passed into the [Tuple] constructor.
Expression _consumeTuple(
{bool orSubQuery = false, bool usedAsRowValue = false}) {
final firstToken =
_consume(TokenType.leftParen, 'Expected opening parenthesis for tuple');
final expressions = <Expression>[];
@ -1049,7 +1051,8 @@ class Parser {
_consume(
TokenType.rightParen, 'Expected right parenthesis to close tuple');
return Tuple(expressions: expressions)..setSpan(firstToken, _previous);
return Tuple(expressions: expressions, usedAsRowValue: usedAsRowValue)
..setSpan(firstToken, _previous);
} else {
_consume(TokenType.rightParen,
'Expected right parenthesis to finish subquery');
@ -1686,19 +1689,47 @@ class Parser {
)..setSpan(withClause?.first ?? updateToken, _previous);
}
SingleColumnSetComponent _singleColumnSetComponent() {
final columnName = _consumeIdentifier('Expected a column name to set');
final reference = Reference(columnName: columnName.identifier)
..setSpan(columnName, columnName);
_consume(TokenType.equal, 'Expected = after the column name');
final expr = expression();
return SingleColumnSetComponent(column: reference, expression: expr)
..setSpan(columnName, _previous);
}
MultiColumnSetComponent _multiColumnSetComponent() {
final first = _consume(
TokenType.leftParen, 'Expected opening parenthesis before column list');
final targetColumns = <Reference>[];
do {
final columnName = _consumeIdentifier('Expected a column');
targetColumns.add(Reference(columnName: columnName.identifier)
..setSpan(columnName, columnName));
} while (_matchOne(TokenType.comma));
_consume(
TokenType.rightParen, 'Expected closing parenthesis after column list');
_consume(TokenType.equal, 'Expected = after the column name list');
final tupleOrSubQuery =
_consumeTuple(orSubQuery: true, usedAsRowValue: true);
return MultiColumnSetComponent(
columns: targetColumns, rowValue: tupleOrSubQuery)
..setSpan(first, _previous);
}
List<SetComponent> _setComponents() {
final set = <SetComponent>[];
do {
final columnName =
_consume(TokenType.identifier, 'Expected a column name to set')
as IdentifierToken;
final reference = Reference(columnName: columnName.identifier)
..setSpan(columnName, columnName);
_consume(TokenType.equal, 'Expected = after the column name');
final expr = expression();
set.add(SetComponent(column: reference, expression: expr)
..setSpan(columnName, _previous));
if (_check(TokenType.leftParen)) {
set.add(_multiColumnSetComponent());
} else {
set.add(_singleColumnSetComponent());
}
} while (_matchOne(TokenType.comma));
return set;

View File

@ -608,8 +608,14 @@ class EqualityEnforcingVisitor implements AstVisitor<void, void> {
}
@override
void visitSetComponent(SetComponent e, void arg) {
_currentAs<SetComponent>(e);
void visitSingleColumnSetComponent(SingleColumnSetComponent e, void arg) {
_currentAs<SingleColumnSetComponent>(e);
_checkChildren(e);
}
@override
void visitMultiColumnSetComponent(MultiColumnSetComponent e, void arg) {
_currentAs<MultiColumnSetComponent>(e);
_checkChildren(e);
}

View File

@ -1062,12 +1062,21 @@ class NodeSqlBuilder extends AstVisitor<void, void> {
}
@override
void visitSetComponent(SetComponent e, void arg) {
void visitSingleColumnSetComponent(SingleColumnSetComponent e, void arg) {
visit(e.column, arg);
symbol('=', spaceBefore: true, spaceAfter: true);
visit(e.expression, arg);
}
@override
void visitMultiColumnSetComponent(MultiColumnSetComponent e, void arg) {
symbol('(', spaceBefore: true);
_join(e.columns, ',');
symbol(')');
symbol('=', spaceBefore: true, spaceAfter: true);
visit(e.rowValue, arg);
}
@override
void visitStarFunctionParameter(StarFunctionParameter e, void arg) {
symbol('*', spaceAfter: true);

View File

@ -89,12 +89,25 @@ void main() {
expect(result.errors, isEmpty);
});
test('update statements with column-name-list', () {
final result = engine.analyze(
'WITH x AS (SELECT * FROM demo) UPDATE demo '
'SET (id, content) = (x.id, x.content) FROM x WHERE demo.id = x.id;');
expect(result.errors, isEmpty);
});
test('insert statements', () {
final result = engine.analyze(
'WITH x AS (SELECT * FROM demo) INSERT INTO demo SELECT * FROM x;');
expect(result.errors, isEmpty);
});
test('insert statements with upsert using column-name-list', () {
final result = engine.analyze(
'WITH x AS (SELECT * FROM demo) INSERT INTO demo SELECT * FROM x ON CONFLICT(content) DO UPDATE SET (id, content) = (0, \'hello\');');
expect(result.errors, isEmpty);
});
test('delete statements', () {
final result = engine.analyze(
'WITH x AS (SELECT * FROM demo) DELETE FROM demo WHERE id IN (SELECT id FROM x);');

View File

@ -0,0 +1,72 @@
import 'package:sqlparser/sqlparser.dart';
import 'package:test/test.dart';
import '../data.dart';
import 'utils.dart';
void main() {
late SqlEngine engine;
setUp(() {
engine = SqlEngine();
engine.registerTableFromSql('''
CREATE TABLE foo (
a INTEGER NOT NULL,
b INTEGER NOT NULL
);
''');
});
group("using column-name-list and tuple in update", () {
test('reports error if they have different sizes', () {
engine
.analyze("UPDATE foo SET (a, b) = (1);")
.expectError('(1)', type: AnalysisErrorType.cteColumnCountMismatch);
engine
.analyze("UPDATE foo SET (a) = (1,2);")
.expectError('(1,2)', type: AnalysisErrorType.cteColumnCountMismatch);
});
test('reports no error if they have same sizes', () {
engine.analyze("UPDATE foo SET (a, b) = (1,2);").expectNoError();
});
});
group("using column-name-list and values in update", () {
test('reports error if they have different sizes', () {
engine.analyze("UPDATE foo SET (a, b) = (VALUES(1));").expectError(
"(VALUES(1))",
type: AnalysisErrorType.cteColumnCountMismatch);
engine.analyze("UPDATE foo SET (a) = (VALUES(1,2));").expectError(
"(VALUES(1,2))",
type: AnalysisErrorType.cteColumnCountMismatch);
});
test('reports no error if they have same sizes', () {
engine.analyze("UPDATE foo SET (a, b) = (VALUES(1,2));").expectNoError();
});
});
group("using column-name-list and subquery in update", () {
test('reports error if they have different sizes', () {
engine.analyze("UPDATE foo SET (a, b) = (SELECT 1);").expectError(
'(SELECT 1)',
type: AnalysisErrorType.cteColumnCountMismatch);
engine.analyze("UPDATE foo SET (a) = (SELECT 1,2);").expectError(
'(SELECT 1,2)',
type: AnalysisErrorType.cteColumnCountMismatch);
engine
.analyze(
"UPDATE foo SET (a, b) = (SELECT b FROM foo as f WHERE f.a=a);")
.expectError('(SELECT b FROM foo as f WHERE f.a=a)',
type: AnalysisErrorType.cteColumnCountMismatch);
});
test('reports no error if they have same sizes', () {
engine.analyze("UPDATE foo SET (a, b) = (SELECT 1,2);").expectNoError();
engine
.analyze(
"UPDATE foo SET (a, b) = (SELECT b, a FROM foo as f WHERE f.a=a);")
.expectNoError();
});
});
}

View File

@ -23,6 +23,13 @@ void main() {
.expectError('g', type: AnalysisErrorType.writeToGeneratedColumn);
});
test('reports error when updating generated column with column-name-list',
() {
engine
.analyze("UPDATE a SET (ok, g) = ('new', 'old');")
.expectError('g', type: AnalysisErrorType.writeToGeneratedColumn);
});
test('reports error when inserting generated column', () {
engine
.analyze('INSERT INTO a (ok, g) VALUES (?, ?)')

View File

@ -14,13 +14,18 @@ void main() {
return TypeResolver(TypeInferenceSession(context))..run(context.root);
}
ResolvedType? resolveFirstVariable(String sql,
Iterable<ResolvedType?> resolveVariableTypes(String sql,
{AnalyzeStatementOptions? options}) {
final resolver = obtainResolver(sql, options: options);
final session = resolver.session;
final variable =
session.context.root.allDescendants.whereType<Variable>().first;
return session.typeOf(variable);
return session.context.root.allDescendants
.whereType<Variable>()
.map((variable) => session.typeOf(variable));
}
ResolvedType? resolveFirstVariable(String sql,
{AnalyzeStatementOptions? options}) {
return resolveVariableTypes(sql, options: options).first;
}
ResolvedType? resolveResultColumn(String sql) {
@ -292,6 +297,32 @@ WITH RECURSIVE
expect(type, const ResolvedType(type: BasicType.int));
});
test('handles multi column set components in updates', () {
final variableTypes =
resolveVariableTypes('UPDATE demo SET (id, content) = (?, ?)');
expect(variableTypes.first, const ResolvedType(type: BasicType.int));
expect(
variableTypes.elementAt(1), const ResolvedType(type: BasicType.text));
});
test('handles multi column set components in updates with select subquery',
() {
final variableTypes =
resolveVariableTypes('UPDATE demo SET (id, content) = (SELECT ?,?)');
expect(variableTypes.first, const ResolvedType(type: BasicType.int));
expect(
variableTypes.elementAt(1), const ResolvedType(type: BasicType.text));
});
test('handles multi column set components in updates with values subquery',
() {
final variableTypes =
resolveVariableTypes('UPDATE demo SET (id, content) = (VALUES(?,?))');
expect(variableTypes.first, const ResolvedType(type: BasicType.int));
expect(
variableTypes.elementAt(1), const ResolvedType(type: BasicType.text));
});
test('infers offsets in frame specs', () {
final type = resolveFirstVariable('SELECT SUM(id) OVER (ROWS ? PRECEDING)');
expect(type, const ResolvedType(type: BasicType.int));

View File

@ -5,7 +5,7 @@ import 'utils.dart';
final _block = Block([
UpdateStatement(table: TableReference('tbl'), set: [
SetComponent(
SingleColumnSetComponent(
column: Reference(columnName: 'foo'),
expression: Reference(columnName: 'bar'),
),

View File

@ -151,7 +151,7 @@ END;
UpdateStatement(
table: TableReference('foo'),
set: [
SetComponent(
SingleColumnSetComponent(
column: Reference(columnName: 'bar'),
expression: Reference(columnName: 'baz'),
),

View File

@ -133,7 +133,7 @@ void main() {
UpsertClauseEntry(
action: DoUpdate(
[
SetComponent(
SingleColumnSetComponent(
column: Reference(columnName: 'foo'),
expression: NumericLiteral(2),
),
@ -158,7 +158,7 @@ void main() {
UpsertClauseEntry(
action: DoUpdate(
[
SetComponent(
SingleColumnSetComponent(
column: Reference(columnName: 'foo'),
expression: NumericLiteral(2),
),
@ -189,7 +189,7 @@ void main() {
onColumns: [IndexedColumn(Reference(columnName: 'bar'))],
action: DoUpdate(
[
SetComponent(
SingleColumnSetComponent(
column: Reference(columnName: 'x'),
expression: NumericLiteral(2),
),

View File

@ -18,7 +18,7 @@ void main() {
UpdateStatement(
table: TableReference('tbl'),
set: [
SetComponent(
SingleColumnSetComponent(
column: Reference(columnName: 'a'),
expression: Reference(columnName: 'b'),
),

View File

@ -8,11 +8,11 @@ final Map<String, AstNode> testCases = {
or: FailureMode.rollback,
table: TableReference('tbl'),
set: [
SetComponent(
SingleColumnSetComponent(
column: Reference(columnName: 'a'),
expression: NullLiteral(),
),
SetComponent(
SingleColumnSetComponent(
column: Reference(columnName: 'b'),
expression: Reference(columnName: 'c'),
)
@ -26,6 +26,57 @@ void main() {
testAll(testCases);
});
test('parses updates with column-name-list and subquery', () {
testStatement(
'''
UPDATE foo
SET (a, b) = (SELECT b, a FROM bar AS b WHERE b.id=foo.id);
''',
UpdateStatement(table: TableReference('foo'), set: [
MultiColumnSetComponent(
columns: [Reference(columnName: 'a'), Reference(columnName: 'b')],
rowValue: SubQuery(
select: SelectStatement(
columns: [
ExpressionResultColumn(
expression: Reference(columnName: 'b'),
),
ExpressionResultColumn(
expression: Reference(columnName: 'a'),
),
],
from: TableReference('bar', as: 'b'),
where: BinaryExpression(
Reference(entityName: 'b', columnName: 'id'),
token(TokenType.equal),
Reference(entityName: 'foo', columnName: 'id'),
),
),
),
)
]),
);
});
test('parses updates with column-name-list and scalar rowValues', () {
testStatement(
'''
UPDATE foo
SET (a, b) = (b, 3+4);
''',
UpdateStatement(table: TableReference('foo'), set: [
MultiColumnSetComponent(
columns: [Reference(columnName: 'a'), Reference(columnName: 'b')],
rowValue: Tuple(expressions: [
Reference(columnName: "b"),
BinaryExpression(
NumericLiteral(3), token(TokenType.plus), NumericLiteral(4)),
], usedAsRowValue: true),
)
]),
);
});
test('parses updates with FROM clause', () {
testStatement(
'''
@ -38,7 +89,7 @@ void main() {
UpdateStatement(
table: TableReference('inventory'),
set: [
SetComponent(
SingleColumnSetComponent(
column: Reference(columnName: 'quantity'),
expression: BinaryExpression(
Reference(columnName: 'quantity'),
@ -85,7 +136,7 @@ void main() {
UpdateStatement(
table: TableReference('tbl'),
set: [
SetComponent(
SingleColumnSetComponent(
column: Reference(columnName: 'foo'),
expression: Reference(columnName: 'bar'),
),

View File

@ -407,6 +407,11 @@ CREATE UNIQUE INDEX my_idx ON t1 (c1, c2, c3) WHERE c1 < c3;
'ON CONFLICT DO UPDATE SET a = b, c = d WHERE d < a;');
});
test('upsert - update with column-name-list', () {
testFormat('INSERT INTO foo VALUES (1, 2, 3) '
'ON CONFLICT DO UPDATE SET (a, c) = (b, d) WHERE d < a;');
});
test('upsert - multiple clauses', () {
testFormat('INSERT INTO foo VALUES (1, 2, 3) '
'ON CONFLICT DO NOTHING '
@ -419,6 +424,14 @@ CREATE UNIQUE INDEX my_idx ON t1 (c1, c2, c3) WHERE c1 < c3;
testFormat('UPDATE foo SET bar = baz WHERE 1;');
});
test('with column-name-list', () {
testFormat('UPDATE foo SET (bar, baz) = (baz, bar) WHERE 1;');
});
test('with column-name-list and subquery', () {
testFormat('UPDATE foo SET (bar, baz) = (SELECT 1,2) WHERE 1;');
});
test('with returning', () {
testFormat('UPDATE foo SET bar = baz RETURNING *');
});