From 0c171c3b813dfc2673f9fce978413fe216f9c161 Mon Sep 17 00:00:00 2001 From: Simon Binder Date: Mon, 3 Feb 2020 21:43:18 +0100 Subject: [PATCH] sqlparser: Support upsert clauses (#367) --- .../analysis/types2/resolving_visitor.dart | 6 +- sqlparser/lib/src/ast/ast.dart | 6 + sqlparser/lib/src/ast/clauses/upsert.dart | 62 +++++++++++ .../lib/src/ast/statements/create_index.dart | 2 +- sqlparser/lib/src/ast/statements/delete.dart | 2 +- sqlparser/lib/src/ast/statements/insert.dart | 7 +- sqlparser/lib/src/ast/statements/select.dart | 3 +- .../lib/src/ast/statements/statement.dart | 4 +- sqlparser/lib/src/ast/statements/update.dart | 2 +- sqlparser/lib/src/ast/visitor.dart | 23 ++++ sqlparser/lib/src/reader/parser/crud.dart | 72 ++++++++++-- sqlparser/lib/src/reader/parser/parser.dart | 2 + sqlparser/lib/src/reader/parser/schema.dart | 21 ++-- sqlparser/lib/src/reader/tokenizer/token.dart | 4 + sqlparser/test/parser/insert_test.dart | 104 ++++++++++++++++++ 15 files changed, 292 insertions(+), 28 deletions(-) create mode 100644 sqlparser/lib/src/ast/clauses/upsert.dart diff --git a/sqlparser/lib/src/analysis/types2/resolving_visitor.dart b/sqlparser/lib/src/analysis/types2/resolving_visitor.dart index 2f5f48bc..54990a51 100644 --- a/sqlparser/lib/src/analysis/types2/resolving_visitor.dart +++ b/sqlparser/lib/src/analysis/types2/resolving_visitor.dart @@ -89,8 +89,8 @@ class TypeResolver extends RecursiveVisitor { @override void visitCrudStatement(CrudStatement stmt, TypeExpectation arg) { - if (stmt is HasWhereClause) { - final typedStmt = stmt as HasWhereClause; + if (stmt is StatementWithWhere) { + final typedStmt = stmt as StatementWithWhere; _handleWhereClause(typedStmt); visitExcept(stmt, typedStmt.where, arg); } else { @@ -561,7 +561,7 @@ class TypeResolver extends RecursiveVisitor { } } - void _handleWhereClause(HasWhereClause stmt) { + void _handleWhereClause(StatementWithWhere stmt) { if (stmt.where != null) { // assume that a where statement is a boolean expression. Sqlite // internally casts (https://www.sqlite.org/lang_expr.html#booleanexpr), diff --git a/sqlparser/lib/src/ast/ast.dart b/sqlparser/lib/src/ast/ast.dart index c6bc9c36..f6732730 100644 --- a/sqlparser/lib/src/ast/ast.dart +++ b/sqlparser/lib/src/ast/ast.dart @@ -8,6 +8,7 @@ import 'package:sqlparser/src/utils/meta.dart'; part 'clauses/limit.dart'; part 'clauses/ordering.dart'; +part 'clauses/upsert.dart'; part 'clauses/with.dart'; part 'common/queryables.dart'; part 'common/renamable.dart'; @@ -156,3 +157,8 @@ abstract class AstNode with HasMetaMixin implements SyntacticEntity { return super.toString(); } } + +/// Common interface for every node that has a `where` clause. +abstract class HasWhereClause implements AstNode { + Expression get where; +} diff --git a/sqlparser/lib/src/ast/clauses/upsert.dart b/sqlparser/lib/src/ast/clauses/upsert.dart new file mode 100644 index 00000000..47fbe824 --- /dev/null +++ b/sqlparser/lib/src/ast/clauses/upsert.dart @@ -0,0 +1,62 @@ +part of '../ast.dart'; + +class UpsertClause extends AstNode implements HasWhereClause { + final List /*?*/ onColumns; + @override + final Expression where; + + final UpsertAction action; + + UpsertClause({this.onColumns, this.where, @required this.action}); + + @override + R accept(AstVisitor visitor, A arg) { + return visitor.visitUpsertClause(this, arg); + } + + @override + Iterable get childNodes { + return [ + if (onColumns != null) ...onColumns, + if (where != null) where, + action, + ]; + } + + @override + bool contentEquals(UpsertClause other) => true; +} + +abstract class UpsertAction extends AstNode {} + +class DoNothing extends UpsertAction { + @override + R accept(AstVisitor visitor, A arg) { + return visitor.visitDoNothing(this, arg); + } + + @override + Iterable get childNodes => const []; + + @override + bool contentEquals(DoNothing other) => true; +} + +class DoUpdate extends UpsertAction implements HasWhereClause { + final List set; + @override + final Expression where; + + DoUpdate(this.set, {this.where}); + + @override + R accept(AstVisitor visitor, A arg) { + return visitor.visitDoUpdate(this, arg); + } + + @override + Iterable get childNodes => [...set, if (where != null) where]; + + @override + bool contentEquals(DoUpdate other) => true; +} diff --git a/sqlparser/lib/src/ast/statements/create_index.dart b/sqlparser/lib/src/ast/statements/create_index.dart index f535bccc..08a78619 100644 --- a/sqlparser/lib/src/ast/statements/create_index.dart +++ b/sqlparser/lib/src/ast/statements/create_index.dart @@ -1,7 +1,7 @@ part of '../ast.dart'; class CreateIndexStatement extends Statement - implements CreatingStatement, HasWhereClause { + implements CreatingStatement, StatementWithWhere { final String indexName; final bool unique; final bool ifNotExists; diff --git a/sqlparser/lib/src/ast/statements/delete.dart b/sqlparser/lib/src/ast/statements/delete.dart index 9135932d..cb80750a 100644 --- a/sqlparser/lib/src/ast/statements/delete.dart +++ b/sqlparser/lib/src/ast/statements/delete.dart @@ -1,6 +1,6 @@ part of '../ast.dart'; -class DeleteStatement extends CrudStatement implements HasWhereClause { +class DeleteStatement extends CrudStatement implements StatementWithWhere { final TableReference from; @override final Expression where; diff --git a/sqlparser/lib/src/ast/statements/insert.dart b/sqlparser/lib/src/ast/statements/insert.dart index 0b205bdf..04861d83 100644 --- a/sqlparser/lib/src/ast/statements/insert.dart +++ b/sqlparser/lib/src/ast/statements/insert.dart @@ -15,6 +15,7 @@ class InsertStatement extends CrudStatement { final TableReference table; final List targetColumns; final InsertSource source; + final UpsertClause upsert; List get resolvedTargetColumns { if (targetColumns.isNotEmpty) { @@ -25,14 +26,13 @@ class InsertStatement extends CrudStatement { } } - // todo parse upsert clauses - InsertStatement( {WithClause withClause, this.mode = InsertMode.insert, @required this.table, @required this.targetColumns, - @required this.source}) + @required this.source, + this.upsert}) : super._(withClause); @override @@ -46,6 +46,7 @@ class InsertStatement extends CrudStatement { yield table; yield* targetColumns; yield* source.childNodes; + if (upsert != null) yield upsert; } @override diff --git a/sqlparser/lib/src/ast/statements/select.dart b/sqlparser/lib/src/ast/statements/select.dart index a45840e6..0ffe47c0 100644 --- a/sqlparser/lib/src/ast/statements/select.dart +++ b/sqlparser/lib/src/ast/statements/select.dart @@ -9,7 +9,8 @@ abstract class BaseSelectStatement extends CrudStatement with ResultSet { BaseSelectStatement._(WithClause withClause) : super._(withClause); } -class SelectStatement extends BaseSelectStatement implements HasWhereClause { +class SelectStatement extends BaseSelectStatement + implements StatementWithWhere { final bool distinct; final List columns; final List from; diff --git a/sqlparser/lib/src/ast/statements/statement.dart b/sqlparser/lib/src/ast/statements/statement.dart index f6b2706b..382d696c 100644 --- a/sqlparser/lib/src/ast/statements/statement.dart +++ b/sqlparser/lib/src/ast/statements/statement.dart @@ -14,9 +14,7 @@ abstract class CrudStatement extends Statement { /// Interface for statements that have a primary where clause (select, update, /// delete). -abstract class HasWhereClause extends Statement { - Expression get where; -} +abstract class StatementWithWhere extends Statement implements HasWhereClause {} /// Marker interface for statements that change the table structure. abstract class SchemaStatement extends Statement implements PartOfMoorFile {} diff --git a/sqlparser/lib/src/ast/statements/update.dart b/sqlparser/lib/src/ast/statements/update.dart index 04f138f5..3f721537 100644 --- a/sqlparser/lib/src/ast/statements/update.dart +++ b/sqlparser/lib/src/ast/statements/update.dart @@ -16,7 +16,7 @@ const Map _tokensToMode = { TokenType.ignore: FailureMode.ignore, }; -class UpdateStatement extends CrudStatement implements HasWhereClause { +class UpdateStatement extends CrudStatement implements StatementWithWhere { final FailureMode or; final TableReference table; final List set; diff --git a/sqlparser/lib/src/ast/visitor.dart b/sqlparser/lib/src/ast/visitor.dart index d7944111..a0e41007 100644 --- a/sqlparser/lib/src/ast/visitor.dart +++ b/sqlparser/lib/src/ast/visitor.dart @@ -14,6 +14,7 @@ abstract class AstVisitor { R visitCreateIndexStatement(CreateIndexStatement e, A arg); R visitWithClause(WithClause e, A arg); + R visitUpsertClause(UpsertClause e, A arg); R visitCommonTableExpression(CommonTableExpression e, A arg); R visitOrderBy(OrderBy e, A arg); R visitOrderingTerm(OrderingTerm e, A arg); @@ -22,6 +23,9 @@ abstract class AstVisitor { R visitJoin(Join e, A arg); R visitGroupBy(GroupBy e, A arg); + R visitDoNothing(DoNothing e, A arg); + R visitDoUpdate(DoUpdate e, A arg); + R visitSetComponent(SetComponent e, A arg); R visitColumnDefinition(ColumnDefinition e, A arg); @@ -158,6 +162,25 @@ class RecursiveVisitor implements AstVisitor { return visitChildren(e, arg); } + @override + R visitUpsertClause(UpsertClause e, A arg) { + return visitChildren(e, arg); + } + + @override + R visitDoNothing(DoNothing e, A arg) { + return defaultUpsertAction(e, arg); + } + + @override + R visitDoUpdate(DoUpdate e, A arg) { + return defaultUpsertAction(e, arg); + } + + R defaultUpsertAction(UpsertAction e, A arg) { + return visitChildren(e, arg); + } + @override R visitCommonTableExpression(CommonTableExpression e, A arg) { return visitChildren(e, arg); diff --git a/sqlparser/lib/src/reader/parser/crud.dart b/sqlparser/lib/src/reader/parser/crud.dart index 46fe437c..503a6686 100644 --- a/sqlparser/lib/src/reader/parser/crud.dart +++ b/sqlparser/lib/src/reader/parser/crud.dart @@ -562,6 +562,19 @@ mixin CrudParser on ParserBase { final table = _tableReference(); _consume(TokenType.set, 'Expected SET after the table name'); + final set = _setComponents(); + + final where = _where(); + return UpdateStatement( + withClause: withClause, + or: failureMode, + table: table, + set: set, + where: where, + )..setSpan(withClause?.first ?? updateToken, _previous); + } + + List _setComponents() { final set = []; do { final columnName = @@ -576,14 +589,7 @@ mixin CrudParser on ParserBase { ..setSpan(columnName, _previous)); } while (_matchOne(TokenType.comma)); - final where = _where(); - return UpdateStatement( - withClause: withClause, - or: failureMode, - table: table, - set: set, - where: where, - )..setSpan(withClause?.first ?? updateToken, _previous); + return set; } InsertStatement _insertStmt([WithClause withClause]) { @@ -632,6 +638,7 @@ mixin CrudParser on ParserBase { 'Expected clpsing parenthesis after column list'); } final source = _insertSource(); + final upsert = _upsertClauseOrNull(); return InsertStatement( withClause: withClause, @@ -639,6 +646,7 @@ mixin CrudParser on ParserBase { table: table, targetColumns: targetColumns, source: source, + upsert: upsert, )..setSpan(withClause?.first ?? firstToken, _previous); } @@ -659,6 +667,54 @@ mixin CrudParser on ParserBase { } } + UpsertClause _upsertClauseOrNull() { + if (!_matchOne(TokenType.on)) return null; + + final first = _previous; + _consume(TokenType.conflict, 'Expected CONFLICT keyword for upsert clause'); + + List indexedColumns; + Expression where; + if (_matchOne(TokenType.leftParen)) { + indexedColumns = _indexedColumns(); + + _consume(TokenType.rightParen, 'Expected closing paren here'); + if (_matchOne(TokenType.where)) { + where = expression(); + } + } + + _consume(TokenType.$do, + 'Expected DO, followed by the action (NOTHING or UPDATE SET)'); + + UpsertAction action; + if (_matchOne(TokenType.nothing)) { + action = DoNothing()..setSpan(_previous, _previous); + } else if (_check(TokenType.update)) { + action = _doUpdate(); + } + + return UpsertClause( + onColumns: indexedColumns, + where: where, + action: action, + )..setSpan(first, _previous); + } + + DoUpdate _doUpdate() { + _consume(TokenType.update, 'Expected UPDATE SET keyword here'); + final first = _previous; + _consume(TokenType.set, 'Expected UPDATE SET keyword here'); + + final set = _setComponents(); + Expression where; + if (_matchOne(TokenType.where)) { + where = expression(); + } + + return DoUpdate(set, where: where)..setSpan(first, _previous); + } + @override WindowDefinition _windowDefinition() { _consume(TokenType.leftParen, 'Expected opening parenthesis'); diff --git a/sqlparser/lib/src/reader/parser/parser.dart b/sqlparser/lib/src/reader/parser/parser.dart index 7bb7a46b..c9e14c53 100644 --- a/sqlparser/lib/src/reader/parser/parser.dart +++ b/sqlparser/lib/src/reader/parser/parser.dart @@ -204,6 +204,8 @@ abstract class ParserBase { /// Parses function parameters, without the surrounding parentheses. FunctionParameters _functionParameters(); + List _indexedColumns(); + /// Skips all tokens until it finds one with [type]. If [skipTarget] is true, /// that token will be skipped as well. /// diff --git a/sqlparser/lib/src/reader/parser/schema.dart b/sqlparser/lib/src/reader/parser/schema.dart index 5626ca83..0f543876 100644 --- a/sqlparser/lib/src/reader/parser/schema.dart +++ b/sqlparser/lib/src/reader/parser/schema.dart @@ -267,13 +267,7 @@ mixin SchemaParser on ParserBase { _consume(TokenType.leftParen, 'Expected indexed columns in parentheses'); - final indexes = []; - do { - final expr = expression(); - final mode = _orderingModeOrNull(); - - indexes.add(IndexedColumn(expr, mode)..setSpan(expr.first, _previous)); - } while (_matchOne(TokenType.comma)); + final indexes = _indexedColumns(); _consume(TokenType.rightParen, 'Expected closing bracket'); @@ -294,6 +288,19 @@ mixin SchemaParser on ParserBase { ..setSpan(create, _previous); } + @override + List _indexedColumns() { + final indexes = []; + do { + final expr = expression(); + final mode = _orderingModeOrNull(); + + indexes.add(IndexedColumn(expr, mode)..setSpan(expr.first, _previous)); + } while (_matchOne(TokenType.comma)); + + return indexes; + } + /// Parses `IF NOT EXISTS` | epsilon bool _ifNotExists() { if (_matchOne(TokenType.$if)) { diff --git a/sqlparser/lib/src/reader/tokenizer/token.dart b/sqlparser/lib/src/reader/tokenizer/token.dart index f0da59f3..b74065b0 100644 --- a/sqlparser/lib/src/reader/tokenizer/token.dart +++ b/sqlparser/lib/src/reader/tokenizer/token.dart @@ -6,6 +6,7 @@ enum TokenType { rightParen, comma, dot, + $do, doublePipe, star, slash, @@ -27,6 +28,7 @@ enum TokenType { $is, $in, not, + nothing, like, glob, match, @@ -171,6 +173,7 @@ const Map keywords = { 'INTO': TokenType.into, 'COLLATE': TokenType.collate, 'DISTINCT': TokenType.distinct, + 'DO': TokenType.$do, 'UPDATE': TokenType.update, 'ALL': TokenType.all, 'AND': TokenType.and, @@ -206,6 +209,7 @@ const Map keywords = { 'REGEXP': TokenType.regexp, 'ESCAPE': TokenType.escape, 'NOT': TokenType.not, + 'NOTHING': TokenType.nothing, 'TRUE': TokenType.$true, 'FALSE': TokenType.$false, 'NULL': TokenType.$null, diff --git a/sqlparser/test/parser/insert_test.dart b/sqlparser/test/parser/insert_test.dart index 2e5bc6c5..49a77c7a 100644 --- a/sqlparser/test/parser/insert_test.dart +++ b/sqlparser/test/parser/insert_test.dart @@ -54,4 +54,108 @@ void main() { ), ); }); + + group('parses upsert clauses', () { + const prefix = 'INSERT INTO tbl DEFAULT VALUES ON CONFLICT'; + test('without listing indexed columns', () { + testStatement( + '$prefix DO NOTHING', + InsertStatement( + table: TableReference('tbl'), + targetColumns: const [], + source: const DefaultValues(), + upsert: UpsertClause(action: DoNothing()), + ), + ); + }); + + test('listing indexed columns without where clause', () { + testStatement( + '$prefix (foo, bar DESC) DO NOTHING', + InsertStatement( + table: TableReference('tbl'), + targetColumns: const [], + source: const DefaultValues(), + upsert: UpsertClause( + onColumns: [ + IndexedColumn(Reference(columnName: 'foo')), + IndexedColumn( + Reference(columnName: 'bar'), + OrderingMode.descending, + ), + ], + action: DoNothing(), + ), + ), + ); + }); + + test('listing indexed columns and where clause', () { + testStatement( + '$prefix (foo, bar) WHERE 2 = foo DO NOTHING', + InsertStatement( + table: TableReference('tbl'), + targetColumns: const [], + source: const DefaultValues(), + upsert: UpsertClause( + onColumns: [ + IndexedColumn(Reference(columnName: 'foo')), + IndexedColumn(Reference(columnName: 'bar')), + ], + where: BinaryExpression( + NumericLiteral(2, token(TokenType.numberLiteral)), + token(TokenType.equal), + Reference(columnName: 'foo'), + ), + action: DoNothing(), + ), + ), + ); + }); + + test('having an update action without where', () { + testStatement( + '$prefix DO UPDATE SET foo = 2', + InsertStatement( + table: TableReference('tbl'), + targetColumns: const [], + source: const DefaultValues(), + upsert: UpsertClause( + action: DoUpdate( + [ + SetComponent( + column: Reference(columnName: 'foo'), + expression: NumericLiteral(2, token(TokenType.numberLiteral)), + ), + ], + ), + ), + ), + ); + }); + + test('having an update action with where', () { + testStatement( + '$prefix DO UPDATE SET foo = 2 WHERE ?', + InsertStatement( + table: TableReference('tbl'), + targetColumns: const [], + source: const DefaultValues(), + upsert: UpsertClause( + action: DoUpdate( + [ + SetComponent( + column: Reference(columnName: 'foo'), + expression: NumericLiteral(2, token(TokenType.numberLiteral)), + ), + ], + where: NumberedVariable( + QuestionMarkVariableToken(fakeSpan('?'), null), + ), + ), + ), + ), + ); + }); + }); }