diff --git a/sqlparser/lib/src/ast/ast.dart b/sqlparser/lib/src/ast/ast.dart index 2f775b4e..0e853b9b 100644 --- a/sqlparser/lib/src/ast/ast.dart +++ b/sqlparser/lib/src/ast/ast.dart @@ -7,6 +7,7 @@ import 'package:sqlparser/src/utils/meta.dart'; part 'clauses/limit.dart'; part 'clauses/ordering.dart'; +part 'clauses/with.dart'; part 'common/queryables.dart'; part 'common/renamable.dart'; @@ -149,6 +150,8 @@ abstract class AstVisitor { T visitUpdateStatement(UpdateStatement e); T visitCreateTableStatement(CreateTableStatement e); + T visitWithClause(WithClause e); + T visitCommonTableExpression(CommonTableExpression e); T visitOrderBy(OrderBy e); T visitOrderingTerm(OrderingTerm e); T visitLimit(Limit e); @@ -307,6 +310,12 @@ class RecursiveVisitor extends AstVisitor { @override T visitFrameSpec(FrameSpec e) => visitChildren(e); + @override + T visitWithClause(WithClause e) => visitChildren(e); + + @override + T visitCommonTableExpression(CommonTableExpression e) => visitChildren(e); + @override T visitMoorFile(MoorFile e) => visitChildren(e); diff --git a/sqlparser/lib/src/ast/clauses/with.dart b/sqlparser/lib/src/ast/clauses/with.dart new file mode 100644 index 00000000..1933baca --- /dev/null +++ b/sqlparser/lib/src/ast/clauses/with.dart @@ -0,0 +1,50 @@ +part of '../ast.dart'; + +class WithClause extends AstNode { + Token withToken; + + final bool recursive; + Token recursiveToken; + + final List ctes; + + WithClause({@required this.recursive, @required this.ctes}); + + @override + T accept(AstVisitor visitor) => visitor.visitWithClause(this); + + @override + Iterable get childNodes => ctes; + + @override + bool contentEquals(WithClause other) => other.recursive == recursive; +} + +class CommonTableExpression extends AstNode { + final String cteTableName; + + /// If this common table expression has explicit column names, e.g. with + /// `cnt(x) AS (...)`, contains the column names (`['x']`, in that case). + /// Otherwise null. + final List columnNames; + final BaseSelectStatement as; + + Token asToken; + IdentifierToken tableNameToken; + + CommonTableExpression( + {@required this.cteTableName, this.columnNames, @required this.as}); + + @override + T accept(AstVisitor visitor) { + return visitor.visitCommonTableExpression(this); + } + + @override + Iterable get childNodes => [as]; + + @override + bool contentEquals(CommonTableExpression other) { + return other.cteTableName == cteTableName; + } +} diff --git a/sqlparser/lib/src/ast/common/queryables.dart b/sqlparser/lib/src/ast/common/queryables.dart index c24ce031..887ad5b7 100644 --- a/sqlparser/lib/src/ast/common/queryables.dart +++ b/sqlparser/lib/src/ast/common/queryables.dart @@ -38,7 +38,7 @@ class TableReference extends TableOrSubquery @override final String as; - TableReference(this.tableName, this.as); + TableReference(this.tableName, [this.as]); @override Iterable get childNodes => const []; diff --git a/sqlparser/lib/src/ast/statements/delete.dart b/sqlparser/lib/src/ast/statements/delete.dart index 450d48f8..4f1c3e05 100644 --- a/sqlparser/lib/src/ast/statements/delete.dart +++ b/sqlparser/lib/src/ast/statements/delete.dart @@ -1,19 +1,22 @@ part of '../ast.dart'; -class DeleteStatement extends Statement - with CrudStatement - implements HasWhereClause { +class DeleteStatement extends CrudStatement implements HasWhereClause { final TableReference from; @override final Expression where; - DeleteStatement({@required this.from, this.where}); + DeleteStatement({WithClause withClause, @required this.from, this.where}) + : super._(withClause); @override T accept(AstVisitor visitor) => visitor.visitDeleteStatement(this); @override - Iterable get childNodes => [from, if (where != null) where]; + Iterable get childNodes => [ + if (withClause != null) withClause, + from, + if (where != null) where, + ]; @override bool contentEquals(DeleteStatement other) => true; diff --git a/sqlparser/lib/src/ast/statements/insert.dart b/sqlparser/lib/src/ast/statements/insert.dart index e194bff8..397eb2b2 100644 --- a/sqlparser/lib/src/ast/statements/insert.dart +++ b/sqlparser/lib/src/ast/statements/insert.dart @@ -10,7 +10,7 @@ enum InsertMode { insertOrIgnore } -class InsertStatement extends Statement with CrudStatement { +class InsertStatement extends CrudStatement { final InsertMode mode; final TableReference table; final List targetColumns; @@ -28,16 +28,19 @@ class InsertStatement extends Statement with CrudStatement { // todo parse upsert clauses InsertStatement( - {this.mode = InsertMode.insert, + {WithClause withClause, + this.mode = InsertMode.insert, @required this.table, @required this.targetColumns, - @required this.source}); + @required this.source}) + : super._(withClause); @override T accept(AstVisitor visitor) => visitor.visitInsertStatement(this); @override Iterable get childNodes sync* { + if (withClause != null) yield withClause; yield table; yield* targetColumns; yield* source.childNodes; diff --git a/sqlparser/lib/src/ast/statements/select.dart b/sqlparser/lib/src/ast/statements/select.dart index 424955a9..6df707cc 100644 --- a/sqlparser/lib/src/ast/statements/select.dart +++ b/sqlparser/lib/src/ast/statements/select.dart @@ -1,11 +1,12 @@ part of '../ast.dart'; -abstract class BaseSelectStatement extends Statement - with CrudStatement, ResultSet { +abstract class BaseSelectStatement extends CrudStatement with ResultSet { /// The resolved list of columns returned by this select statements. Not /// available from the parse tree, will be set later by the analyzer. @override List resolvedColumns; + + BaseSelectStatement._(WithClause withClause) : super._(withClause); } class SelectStatement extends BaseSelectStatement implements HasWhereClause { @@ -22,14 +23,16 @@ class SelectStatement extends BaseSelectStatement implements HasWhereClause { final LimitBase limit; SelectStatement( - {this.distinct = false, + {WithClause withClause, + this.distinct = false, this.columns, this.from, this.where, this.groupBy, this.windowDeclarations = const [], this.orderBy, - this.limit}); + this.limit}) + : super._(withClause); @override T accept(AstVisitor visitor) { @@ -39,6 +42,7 @@ class SelectStatement extends BaseSelectStatement implements HasWhereClause { @override Iterable get childNodes { return [ + if (withClause != null) withClause, ...columns, if (from != null) ...from, if (where != null) where, @@ -64,13 +68,14 @@ class CompoundSelectStatement extends BaseSelectStatement { // part of the last compound select statement in [additional] CompoundSelectStatement({ + WithClause withClause, @required this.base, this.additional = const [], - }); + }) : super._(withClause); @override Iterable get childNodes { - return [base, ...additional]; + return [if (withClause != null) withClause, base, ...additional]; } @override diff --git a/sqlparser/lib/src/ast/statements/statement.dart b/sqlparser/lib/src/ast/statements/statement.dart index f46d0750..ef498e57 100644 --- a/sqlparser/lib/src/ast/statements/statement.dart +++ b/sqlparser/lib/src/ast/statements/statement.dart @@ -4,8 +4,13 @@ abstract class Statement extends AstNode { Token semicolon; } -/// Marker mixin for statements that read from an existing table structure. -mixin CrudStatement on Statement {} +/// A statement that reads from an existing table structure and has an optional +/// `WITH` clause. +abstract class CrudStatement extends Statement { + WithClause withClause; + + CrudStatement._(this.withClause); +} /// Interface for statements that have a primary where clause (select, update, /// delete). diff --git a/sqlparser/lib/src/ast/statements/update.dart b/sqlparser/lib/src/ast/statements/update.dart index f7c8877f..6c9dc048 100644 --- a/sqlparser/lib/src/ast/statements/update.dart +++ b/sqlparser/lib/src/ast/statements/update.dart @@ -16,9 +16,7 @@ const Map _tokensToMode = { TokenType.ignore: FailureMode.ignore, }; -class UpdateStatement extends Statement - with CrudStatement - implements HasWhereClause { +class UpdateStatement extends CrudStatement implements HasWhereClause { final FailureMode or; final TableReference table; final List set; @@ -26,13 +24,23 @@ class UpdateStatement extends Statement final Expression where; UpdateStatement( - {this.or, @required this.table, @required this.set, this.where}); + {WithClause withClause, + this.or, + @required this.table, + @required this.set, + this.where}) + : super._(withClause); @override T accept(AstVisitor visitor) => visitor.visitUpdateStatement(this); @override - Iterable get childNodes => [table, ...set, if (where != null) where]; + Iterable get childNodes => [ + if (withClause != null) withClause, + table, + ...set, + if (where != null) where, + ]; @override bool contentEquals(UpdateStatement other) { diff --git a/sqlparser/lib/src/reader/parser/crud.dart b/sqlparser/lib/src/reader/parser/crud.dart index 7fcc7a37..b8e5d01e 100644 --- a/sqlparser/lib/src/reader/parser/crud.dart +++ b/sqlparser/lib/src/reader/parser/crud.dart @@ -1,12 +1,78 @@ part of 'parser.dart'; mixin CrudParser on ParserBase { + CrudStatement _crud() { + final withClause = _withClause(); + + if (_check(TokenType.select)) { + return select(withClause: withClause); + } else if (_check(TokenType.delete)) { + return _deleteStmt(withClause); + } else if (_check(TokenType.update)) { + return _update(withClause); + } else if (_check(TokenType.insert)) { + return _insertStmt(withClause); + } + return null; + } + + WithClause _withClause() { + if (!_matchOne(TokenType.$with)) return null; + final withToken = _previous; + + final recursive = _matchOne(TokenType.recursive); + final recursiveToken = recursive ? _previous : null; + + final ctes = []; + do { + final name = _consumeIdentifier('Expected name for common table'); + List columnNames; + + // can optionally declare the column names in (foo, bar, baz) syntax + if (_matchOne(TokenType.leftParen)) { + columnNames = []; + do { + final identifier = _consumeIdentifier('Expected column name'); + columnNames.add(identifier.identifier); + } while (_matchOne(TokenType.comma)); + + _consume(TokenType.rightParen, + 'Expected closing bracket after column names'); + } + + final asToken = _consume(TokenType.as, 'Expected AS'); + + const msg = 'Expected select statement in brackets'; + _consume(TokenType.leftParen, msg); + final selectStmt = select() ?? _error(msg); + _consume(TokenType.rightParen, msg); + + ctes.add(CommonTableExpression( + cteTableName: name.identifier, + columnNames: columnNames, + as: selectStmt, + ) + ..setSpan(name, _previous) + ..asToken = asToken + ..tableNameToken = name); + } while (_matchOne(TokenType.comma)); + + return WithClause( + recursive: recursive, + ctes: ctes, + ) + ..setSpan(withToken, _previous) + ..recursiveToken = recursiveToken + ..withToken = withToken; + } + @override - BaseSelectStatement select({bool noCompound}) { + BaseSelectStatement select({bool noCompound, WithClause withClause}) { if (noCompound == true) { - return _selectNoCompound(); + return _selectNoCompound(withClause); } else { - final first = _selectNoCompound(); + final firstTokenOfBase = _peek; + final first = _selectNoCompound(withClause); final parts = []; while (true) { @@ -19,18 +85,24 @@ mixin CrudParser on ParserBase { } if (parts.isEmpty) { - // no compound parts, just return the simple select statement + // no compound parts, just return the simple select statement. return first; } else { + // remove with clause from base select, it belongs to the compound + // select. + first.withClause = null; + first.first = firstTokenOfBase; + return CompoundSelectStatement( + withClause: withClause, base: first, additional: parts, - )..setSpan(first.first, _previous); + )..setSpan(withClause?.first ?? first.first, _previous); } } } - SelectStatement _selectNoCompound() { + SelectStatement _selectNoCompound([WithClause withClause]) { if (!_match(const [TokenType.select])) return null; final selectToken = _previous; @@ -54,7 +126,9 @@ mixin CrudParser on ParserBase { final orderBy = _orderBy(); final limit = _limit(); + final first = withClause?.first ?? selectToken; return SelectStatement( + withClause: withClause, distinct: distinct, columns: resultColumns, from: from, @@ -63,7 +137,7 @@ mixin CrudParser on ParserBase { windowDeclarations: windowDecls, orderBy: orderBy, limit: limit, - )..setSpan(selectToken, _previous); + )..setSpan(first, _previous); } CompoundSelectPart _compoundSelectPart() { @@ -413,7 +487,7 @@ mixin CrudParser on ParserBase { } } - DeleteStatement _deleteStmt() { + DeleteStatement _deleteStmt([WithClause withClause]) { if (!_matchOne(TokenType.delete)) return null; final deleteToken = _previous; @@ -429,11 +503,14 @@ mixin CrudParser on ParserBase { where = expression(); } - return DeleteStatement(from: table, where: where) - ..setSpan(deleteToken, _previous); + return DeleteStatement( + withClause: withClause, + from: table, + where: where, + )..setSpan(withClause?.first ?? deleteToken, _previous); } - UpdateStatement _update() { + UpdateStatement _update([WithClause withClause]) { if (!_matchOne(TokenType.update)) return null; final updateToken = _previous; @@ -461,11 +538,15 @@ mixin CrudParser on ParserBase { final where = _where(); return UpdateStatement( - or: failureMode, table: table, set: set, where: where) - ..setSpan(updateToken, _previous); + withClause: withClause, + or: failureMode, + table: table, + set: set, + where: where, + )..setSpan(withClause?.first ?? updateToken, _previous); } - InsertStatement _insertStmt() { + InsertStatement _insertStmt([WithClause withClause]) { if (!_match(const [TokenType.insert, TokenType.replace])) return null; final firstToken = _previous; @@ -513,11 +594,12 @@ mixin CrudParser on ParserBase { final source = _insertSource(); return InsertStatement( + withClause: withClause, mode: insertMode, table: table, targetColumns: targetColumns, source: source, - )..setSpan(firstToken, _previous); + )..setSpan(withClause?.first ?? firstToken, _previous); } InsertSource _insertSource() { diff --git a/sqlparser/lib/src/reader/parser/parser.dart b/sqlparser/lib/src/reader/parser/parser.dart index ae25dc29..c083eb9b 100644 --- a/sqlparser/lib/src/reader/parser/parser.dart +++ b/sqlparser/lib/src/reader/parser/parser.dart @@ -136,7 +136,7 @@ abstract class ParserBase { } @alwaysThrows - void _error(String message) { + Null _error(String message) { final error = ParsingError(_peek, message); errors.add(error); throw error; @@ -212,17 +212,6 @@ class Parser extends ParserBase return stmt..setSpan(first, _previous); } - CrudStatement _crud() { - // writing select() ?? _deleteStmt() and so on doesn't cast to CrudStatement - // for some reason. - CrudStatement stmt = select(); - stmt ??= _deleteStmt(); - stmt ??= _update(); - stmt ??= _insertStmt(); - - return stmt; - } - MoorFile moorFile() { final first = _peek; final foundComponents = []; diff --git a/sqlparser/lib/src/reader/tokenizer/token.dart b/sqlparser/lib/src/reader/tokenizer/token.dart index 753a6342..4e9da7a7 100644 --- a/sqlparser/lib/src/reader/tokenizer/token.dart +++ b/sqlparser/lib/src/reader/tokenizer/token.dart @@ -121,6 +121,7 @@ enum TokenType { create, table, $if, + $with, without, rowid, constraint, @@ -138,6 +139,7 @@ enum TokenType { restrict, no, action, + recursive, semicolon, comment, @@ -210,6 +212,7 @@ const Map keywords = { 'CREATE': TokenType.create, 'TABLE': TokenType.table, 'IF': TokenType.$if, + 'WITH': TokenType.$with, 'WITHOUT': TokenType.without, 'ROWID': TokenType.rowid, 'CONSTRAINT': TokenType.constraint, @@ -230,6 +233,7 @@ const Map keywords = { 'OVER': TokenType.over, 'PARTITION': TokenType.partition, 'RANGE': TokenType.range, + 'RECURSIVE': TokenType.recursive, 'ROWS': TokenType.rows, 'GROUPS': TokenType.groups, 'UNBOUNDED': TokenType.unbounded, diff --git a/sqlparser/test/parser/select/common_table_expression_test.dart b/sqlparser/test/parser/select/common_table_expression_test.dart new file mode 100644 index 00000000..0d4cb660 --- /dev/null +++ b/sqlparser/test/parser/select/common_table_expression_test.dart @@ -0,0 +1,71 @@ +import 'package:sqlparser/sqlparser.dart'; +import 'package:test/test.dart'; + +import '../utils.dart'; + +void main() { + test('parses WITH clauses', () { + testStatement( + ''' + WITH RECURSIVE + cnt(x) AS ( + SELECT 1 + UNION ALL + SELECT x+1 FROM cnt + LIMIT 1000000 + ) + SELECT x FROM cnt; + ''', + SelectStatement( + withClause: WithClause( + recursive: true, + ctes: [ + CommonTableExpression( + cteTableName: 'cnt', + columnNames: ['x'], + as: CompoundSelectStatement( + base: SelectStatement( + columns: [ + ExpressionResultColumn( + expression: NumericLiteral( + 1, + token(TokenType.numberLiteral), + ), + ), + ], + ), + additional: [ + CompoundSelectPart( + mode: CompoundSelectMode.unionAll, + select: SelectStatement( + columns: [ + ExpressionResultColumn( + expression: BinaryExpression( + Reference(columnName: 'x'), + token(TokenType.plus), + NumericLiteral(1, token(TokenType.numberLiteral)), + ), + ), + ], + from: [TableReference('cnt')], + limit: Limit( + count: NumericLiteral( + 1000000, + token(TokenType.numberLiteral), + ), + ), + ), + ), + ], + ), + ), + ], + ), + columns: [ + ExpressionResultColumn(expression: Reference(columnName: 'x')), + ], + from: [TableReference('cnt')], + ), + ); + }); +}