From 3be320d0c517e3f0429c708cfed18727d5a995f3 Mon Sep 17 00:00:00 2001 From: Simon Binder Date: Wed, 8 Sep 2021 22:53:57 +0200 Subject: [PATCH] Parse BEGIN and COMMIT statements --- sqlparser/lib/src/ast/ast.dart | 1 + .../lib/src/ast/statements/transaction.dart | 40 +++ sqlparser/lib/src/ast/visitor.dart | 14 +- sqlparser/lib/src/reader/parser.dart | 111 +++++- sqlparser/lib/src/utils/ast_equality.dart | 189 +++++----- sqlparser/lib/utils/node_to_text.dart | 338 ++++++++++-------- sqlparser/test/parser/errors_test.dart | 10 + sqlparser/test/parser/misc_test.dart | 35 ++ sqlparser/test/parser/utils.dart | 10 + sqlparser/test/utils/node_to_text_test.dart | 11 + 10 files changed, 504 insertions(+), 255 deletions(-) create mode 100644 sqlparser/lib/src/ast/statements/transaction.dart create mode 100644 sqlparser/test/parser/errors_test.dart create mode 100644 sqlparser/test/parser/misc_test.dart diff --git a/sqlparser/lib/src/ast/ast.dart b/sqlparser/lib/src/ast/ast.dart index f7e12ba3..bff3330c 100644 --- a/sqlparser/lib/src/ast/ast.dart +++ b/sqlparser/lib/src/ast/ast.dart @@ -30,6 +30,7 @@ export 'statements/insert.dart'; export 'statements/invalid.dart'; export 'statements/select.dart'; export 'statements/statement.dart'; +export 'statements/transaction.dart'; export 'statements/update.dart'; export 'visitor.dart'; diff --git a/sqlparser/lib/src/ast/statements/transaction.dart b/sqlparser/lib/src/ast/statements/transaction.dart new file mode 100644 index 00000000..f32f95e4 --- /dev/null +++ b/sqlparser/lib/src/ast/statements/transaction.dart @@ -0,0 +1,40 @@ +import '../../reader/tokenizer/token.dart'; +import '../node.dart'; +import '../visitor.dart'; +import 'statement.dart'; + +enum TransactionMode { none, deferred, immediate, exclusive } + +class BeginTransactionStatement extends Statement { + Token? begin, modeToken, transaction; + + final TransactionMode mode; + + BeginTransactionStatement([this.mode = TransactionMode.none]); + + @override + R accept(AstVisitor visitor, A arg) { + return visitor.visitBeginTransaction(this, arg); + } + + @override + Iterable get childNodes => const Iterable.empty(); + + @override + void transformChildren(Transformer transformer, A arg) {} +} + +class CommitStatement extends Statement { + Token? commitOrEnd, transaction; + + @override + R accept(AstVisitor visitor, A arg) { + return visitor.visitCommitStatement(this, arg); + } + + @override + Iterable get childNodes => const Iterable.empty(); + + @override + void transformChildren(Transformer transformer, A arg) {} +} diff --git a/sqlparser/lib/src/ast/visitor.dart b/sqlparser/lib/src/ast/visitor.dart index 00849476..51434ba2 100644 --- a/sqlparser/lib/src/ast/visitor.dart +++ b/sqlparser/lib/src/ast/visitor.dart @@ -1,7 +1,5 @@ import 'ast.dart'; -import 'expressions/raise.dart'; - abstract class AstVisitor { R visitSelectStatement(SelectStatement e, A arg); R visitCompoundSelectStatement(CompoundSelectStatement e, A arg); @@ -92,6 +90,8 @@ abstract class AstVisitor { R visitNamedVariable(ColonNamedVariable e, A arg); R visitBlock(Block block, A arg); + R visitBeginTransaction(BeginTransactionStatement e, A arg); + R visitCommitStatement(CommitStatement e, A arg); R visitMoorFile(MoorFile e, A arg); R visitMoorImportStatement(ImportStatement e, A arg); @@ -393,6 +393,16 @@ class RecursiveVisitor implements AstVisitor { return defaultNode(e, arg); } + @override + R? visitBeginTransaction(BeginTransactionStatement e, A arg) { + return visitStatement(e, arg); + } + + @override + R? visitCommitStatement(CommitStatement e, A arg) { + return visitStatement(e, arg); + } + // Moor-specific additions @override R? visitMoorFile(MoorFile e, A arg) { diff --git a/sqlparser/lib/src/reader/parser.dart b/sqlparser/lib/src/reader/parser.dart index 4cd958d1..18d88377 100644 --- a/sqlparser/lib/src/reader/parser.dart +++ b/sqlparser/lib/src/reader/parser.dart @@ -110,6 +110,11 @@ class Parser { return _peek.type == type; } + bool _checkAny(Iterable type) { + if (_isAtEnd) return false; + return type.contains(_peek.type); + } + /// Returns whether the next token is an [TokenType.identifier] or a /// [KeywordToken]. If this method returns true, calling [_consumeIdentifier] /// with same [lenient] parameter will now throw. @@ -177,18 +182,45 @@ class Parser { InvalidStatement(); } - Statement statement() { - final first = _peek; - Statement? stmt = _crud(); - stmt ??= _create(); + Statement _statementWithoutSemicolon() { + if (_checkAny(const [ + TokenType.$with, + TokenType.select, + TokenType.$values, + TokenType.delete, + TokenType.update, + TokenType.insert, + TokenType.replace, + ])) { + return _crud()!; + } + + if (_check(TokenType.create)) { + return _create()!; + } + + if (_check(TokenType.begin)) { + return _beginStatement(); + } + if (_checkAny(const [TokenType.commit, TokenType.end])) { + return _commit(); + } if (enableMoorExtensions) { - stmt ??= _import() ?? _declaredStatement(); + if (_check(TokenType.import)) { + return _import()!; + } + if (_check(TokenType.identifier) || _peek is KeywordToken) { + return _declaredStatement()!; + } } - if (stmt == null) { - _error('Expected a sql statement to start here'); - } + _error('Expected a sql statement to start here'); + } + + Statement statement() { + final first = _peek; + final stmt = _statementWithoutSemicolon(); if (_matchOne(TokenType.semicolon)) { stmt.semicolon = _previous; @@ -909,6 +941,61 @@ class Parser { } } + Statement _beginStatement() { + final begin = _consume(TokenType.begin); + Token? modeToken; + var mode = TransactionMode.none; + + if (_match( + const [TokenType.deferred, TokenType.immediate, TokenType.exclusive])) { + modeToken = _previous; + + switch (modeToken.type) { + case TokenType.deferred: + mode = TransactionMode.deferred; + break; + case TokenType.immediate: + mode = TransactionMode.immediate; + break; + case TokenType.exclusive: + mode = TransactionMode.exclusive; + break; + default: + throw AssertionError('unreachable'); + } + } + + Token? transaction; + if (_matchOne(TokenType.transaction)) { + transaction = _previous; + } + + return BeginTransactionStatement(mode) + ..setSpan(begin, _previous) + ..begin = begin + ..modeToken = modeToken + ..transaction = transaction; + } + + CommitStatement _commit() { + Token commitOrEnd; + if (_match(const [TokenType.commit, TokenType.end])) { + commitOrEnd = _previous; + } else { + _error('Expected COMMIT or END here'); + } + + Token? transaction; + if (_matchOne(TokenType.transaction)) { + transaction = _previous; + } + + return CommitStatement() + ..setSpan(commitOrEnd, _previous) + ..commitOrEnd = commitOrEnd + ..transaction = transaction; + } + CrudStatement? _crud() { final withClause = _withClause(); @@ -921,6 +1008,14 @@ class Parser { } else if (_check(TokenType.insert) || _check(TokenType.replace)) { return _insertStmt(withClause); } + + // A WITH clause without a following select, insert, delete or update + // is invalid! + if (withClause != null) { + _error('Expected a SELECT, INSERT, UPDATE or DELETE statement to ' + 'follow this WITH clause.'); + } + return null; } diff --git a/sqlparser/lib/src/utils/ast_equality.dart b/sqlparser/lib/src/utils/ast_equality.dart index 3b04d07c..05f6eda9 100644 --- a/sqlparser/lib/src/utils/ast_equality.dart +++ b/sqlparser/lib/src/utils/ast_equality.dart @@ -1,6 +1,25 @@ import 'package:collection/collection.dart'; import 'package:sqlparser/src/ast/ast.dart'; +/// Checks whether [a] and [b] are equal. If they aren't, throws an exception. +void enforceEqual(AstNode a, AstNode b) { + EqualityEnforcingVisitor(a).visit(b, null); +} + +void enforceEqualIterable(Iterable a, Iterable b) { + final childrenA = a.iterator; + final childrenB = b.iterator; + + // always move both iterators + while (childrenA.moveNext() & childrenB.moveNext()) { + enforceEqual(childrenA.current, childrenB.current); + } + + if (childrenA.moveNext() || childrenB.moveNext()) { + throw ArgumentError("$a and $b don't have an equal amount of children"); + } +} + /// Visitor enforcing the equality of two ast nodes. class EqualityEnforcingVisitor implements AstVisitor { // The current ast node. Visitor methods will compare the node they receive to @@ -19,58 +38,6 @@ class EqualityEnforcingVisitor implements AstVisitor { EqualityEnforcingVisitor(this._current, {bool considerChildren = true}) : _considerChildren = considerChildren; - void _check(AstNode? childOfCurrent, AstNode? childOfOther) { - if (identical(childOfCurrent, childOfOther)) return; - - if ((childOfCurrent == null) != (childOfOther == null)) { - throw NotEqualException('$childOfCurrent and $childOfOther'); - } - - // Both non nullable here - final savedCurrent = _current; - _current = childOfCurrent!; - visit(childOfOther!, null); - _current = savedCurrent; - } - - void _checkChildren(AstNode other) { - if (!_considerChildren) return; - - final currentChildren = _current.childNodes.iterator; - final otherChildren = other.childNodes.iterator; - - while (currentChildren.moveNext()) { - if (otherChildren.moveNext()) { - _check(currentChildren.current, otherChildren.current); - } else { - // Current has more elements than other - throw NotEqualException( - "$_current and $other don't have an equal amount of children"); - } - } - - if (otherChildren.moveNext()) { - // Other has more elements than current - throw NotEqualException( - "$_current and $other don't have an equal amount of children"); - } - } - - Never _notEqual(AstNode other) { - throw NotEqualException('$_current and $other'); - } - - T _currentAs(T context) { - final current = _current; - if (current is T) return current; - - _notEqual(context); - } - - void _assert(bool contentEqual, AstNode context) { - if (!contentEqual) _notEqual(context); - } - @override void visitAggregateExpression(AggregateExpression e, void arg) { final current = _currentAs(e); @@ -78,6 +45,13 @@ class EqualityEnforcingVisitor implements AstVisitor { _checkChildren(e); } + @override + void visitBeginTransaction(BeginTransactionStatement e, void arg) { + final current = _currentAs(e); + _assert(current.mode == e.mode, e); + _checkChildren(e); + } + @override void visitBetweenExpression(BetweenExpression e, void arg) { final current = _currentAs(e); @@ -165,6 +139,12 @@ class EqualityEnforcingVisitor implements AstVisitor { _checkChildren(e); } + @override + void visitCommitStatement(CommitStatement e, void arg) { + _currentAs(e); + _checkChildren(e); + } + @override void visitCommonTableExpression(CommonTableExpression e, void arg) { final current = _currentAs(e); @@ -348,24 +328,6 @@ class EqualityEnforcingVisitor implements AstVisitor { _checkChildren(e); } - @override - void visitInExpression(InExpression e, void arg) { - final current = _currentAs(e); - _assert(current.not == e.not, e); - _checkChildren(e); - } - - @override - void visitRaiseExpression(RaiseExpression e, void arg) { - final current = _currentAs(e); - _assert( - current.raiseKind == e.raiseKind && - current.errorMessage == e.errorMessage, - e, - ); - _checkChildren(e); - } - @override void visitIndexedColumn(IndexedColumn e, void arg) { final current = _currentAs(e); @@ -373,6 +335,13 @@ class EqualityEnforcingVisitor implements AstVisitor { _checkChildren(e); } + @override + void visitInExpression(InExpression e, void arg) { + final current = _currentAs(e); + _assert(current.not == e.not, e); + _checkChildren(e); + } + @override void visitInsertStatement(InsertStatement e, void arg) { final current = _currentAs(e); @@ -544,6 +513,17 @@ class EqualityEnforcingVisitor implements AstVisitor { _checkChildren(e); } + @override + void visitRaiseExpression(RaiseExpression e, void arg) { + final current = _currentAs(e); + _assert( + current.raiseKind == e.raiseKind && + current.errorMessage == e.errorMessage, + e, + ); + _checkChildren(e); + } + @override void visitReference(Reference e, void arg) { final current = _currentAs(e); @@ -721,11 +701,58 @@ class EqualityEnforcingVisitor implements AstVisitor { _assert(current.recursive == e.recursive, e); _checkChildren(e); } -} -/// Checks whether [a] and [b] are equal. If they aren't, throws an exception. -void enforceEqual(AstNode a, AstNode b) { - EqualityEnforcingVisitor(a).visit(b, null); + void _assert(bool contentEqual, AstNode context) { + if (!contentEqual) _notEqual(context); + } + + void _check(AstNode? childOfCurrent, AstNode? childOfOther) { + if (identical(childOfCurrent, childOfOther)) return; + + if ((childOfCurrent == null) != (childOfOther == null)) { + throw NotEqualException('$childOfCurrent and $childOfOther'); + } + + // Both non nullable here + final savedCurrent = _current; + _current = childOfCurrent!; + visit(childOfOther!, null); + _current = savedCurrent; + } + + void _checkChildren(AstNode other) { + if (!_considerChildren) return; + + final currentChildren = _current.childNodes.iterator; + final otherChildren = other.childNodes.iterator; + + while (currentChildren.moveNext()) { + if (otherChildren.moveNext()) { + _check(currentChildren.current, otherChildren.current); + } else { + // Current has more elements than other + throw NotEqualException( + "$_current and $other don't have an equal amount of children"); + } + } + + if (otherChildren.moveNext()) { + // Other has more elements than current + throw NotEqualException( + "$_current and $other don't have an equal amount of children"); + } + } + + T _currentAs(T context) { + final current = _current; + if (current is T) return current; + + _notEqual(context); + } + + Never _notEqual(AstNode other) { + throw NotEqualException('$_current and $other'); + } } /// Thrown by the [EqualityEnforcingVisitor] when two nodes were determined to @@ -740,17 +767,3 @@ class NotEqualException implements Exception { return 'Not equal: $message'; } } - -void enforceEqualIterable(Iterable a, Iterable b) { - final childrenA = a.iterator; - final childrenB = b.iterator; - - // always move both iterators - while (childrenA.moveNext() & childrenB.moveNext()) { - enforceEqual(childrenA.current, childrenB.current); - } - - if (childrenA.moveNext() || childrenB.moveNext()) { - throw ArgumentError("$a and $b don't have an equal amount of children"); - } -} diff --git a/sqlparser/lib/utils/node_to_text.dart b/sqlparser/lib/utils/node_to_text.dart index c8d90b1e..b4bff023 100644 --- a/sqlparser/lib/utils/node_to_text.dart +++ b/sqlparser/lib/utils/node_to_text.dart @@ -7,25 +7,6 @@ import 'package:charcode/charcode.dart'; import 'package:sqlparser/sqlparser.dart'; import 'package:sqlparser/src/reader/tokenizer/token.dart'; -/// Defines the [toSql] extension method that turns ast nodes into a compatible -/// textual representation. -/// -/// Parsing the output of [toSql] will result in an equal AST. -extension NodeToText on AstNode { - /// Obtains a textual representation for AST nodes. - /// - /// Parsing the output of [toSql] will result in an equal AST. Since only the - /// AST is used, the output will not contain comments. It's possible for the - /// output to have more than just whitespace changes if there are multiple - /// ways to represent an equivalent node (e.g. the no-op `FOR EACH ROW` on - /// triggers). - String toSql() { - final builder = NodeSqlBuilder(); - builder.visit(this, null); - return builder.buffer.toString(); - } -} - class NodeSqlBuilder extends AstVisitor { final StringSink buffer; @@ -34,42 +15,6 @@ class NodeSqlBuilder extends AstVisitor { NodeSqlBuilder([StringSink? buffer]) : buffer = buffer ?? StringBuffer(); - void _join(Iterable nodes, String separatingSymbol) { - var isFirst = true; - - for (final node in nodes) { - if (!isFirst) { - _symbol(separatingSymbol, spaceAfter: true); - } - - visit(node, null); - isFirst = false; - } - } - - void _identifier(String identifier, - {bool spaceBefore = true, bool spaceAfter = true}) { - if (isKeywordLexeme(identifier) || identifier.contains(' ')) { - identifier = '"$identifier"'; - } - - _symbol(identifier, spaceBefore: spaceBefore, spaceAfter: spaceAfter); - } - - void _ifNotExists(bool ifNotExists) { - if (ifNotExists) { - _keyword(TokenType.$if); - _keyword(TokenType.not); - _keyword(TokenType.exists); - } - } - - void _keyword(TokenType type) { - _symbol(reverseKeywords[type]!, spaceAfter: true, spaceBefore: true); - } - - void _space() => buffer.writeCharCode($space); - /// Writes a space character if [needsSpace] is set. /// /// This also resets [needsSpace] to `false`. @@ -80,35 +25,6 @@ class NodeSqlBuilder extends AstVisitor { } } - void _stringLiteral(String content) { - final escapedChars = content.replaceAll("'", "''"); - _symbol("'$escapedChars'", spaceBefore: true, spaceAfter: true); - } - - void _symbol(String lexeme, - {bool spaceBefore = false, bool spaceAfter = false}) { - if (needsSpace && spaceBefore) { - _space(); - } - - buffer.write(lexeme); - needsSpace = spaceAfter; - } - - void _where(Expression? where) { - if (where != null) { - _keyword(TokenType.where); - visit(where, null); - } - } - - void _from(Queryable? from) { - if (from != null) { - _keyword(TokenType.from); - visit(from, null); - } - } - @override void visitAggregateExpression(AggregateExpression e, void arg) { _symbol(e.name); @@ -134,6 +50,25 @@ class NodeSqlBuilder extends AstVisitor { } } + @override + void visitBeginTransaction(BeginTransactionStatement e, void arg) { + _keyword(TokenType.begin); + + switch (e.mode) { + case TransactionMode.none: + break; + case TransactionMode.deferred: + _keyword(TokenType.deferred); + break; + case TransactionMode.immediate: + _keyword(TokenType.immediate); + break; + case TransactionMode.exclusive: + _keyword(TokenType.exclusive); + break; + } + } + @override void visitBetweenExpression(BetweenExpression e, void arg) { visit(e.check, arg); @@ -229,21 +164,6 @@ class NodeSqlBuilder extends AstVisitor { _identifier(e.collation); } - void _conflictClause(ConflictClause? clause) { - if (clause != null) { - _keyword(TokenType.on); - _keyword(TokenType.conflict); - - _keyword(const { - ConflictClause.rollback: TokenType.rollback, - ConflictClause.abort: TokenType.abort, - ConflictClause.fail: TokenType.fail, - ConflictClause.ignore: TokenType.ignore, - ConflictClause.replace: TokenType.replace, - }[clause]!); - } - } - @override void visitColumnConstraint(ColumnConstraint e, void arg) { if (e.name != null) { @@ -305,6 +225,11 @@ class NodeSqlBuilder extends AstVisitor { visitList(e.constraints, arg); } + @override + void visitCommitStatement(CommitStatement e, void arg) { + _keyword(TokenType.commit); + } + @override void visitCommonTableExpression(CommonTableExpression e, void arg) { _identifier(e.cteTableName); @@ -549,6 +474,15 @@ class NodeSqlBuilder extends AstVisitor { _join(e.parameters, ','); } + @override + void visitExpressionResultColumn(ExpressionResultColumn e, void arg) { + visit(e.expression, arg); + if (e.as != null) { + _keyword(TokenType.as); + _identifier(e.as!); + } + } + @override void visitForeignKeyClause(ForeignKeyClause e, void arg) { _keyword(TokenType.references); @@ -673,6 +607,12 @@ class NodeSqlBuilder extends AstVisitor { } } + @override + void visitIndexedColumn(IndexedColumn e, void arg) { + visit(e.expression, arg); + _orderingMode(e.ordering); + } + @override void visitInExpression(InExpression e, void arg) { visit(e.left, arg); @@ -685,30 +625,6 @@ class NodeSqlBuilder extends AstVisitor { visit(e.inside, arg); } - @override - void visitRaiseExpression(RaiseExpression e, void arg) { - _keyword(TokenType.raise); - _symbol('(', spaceBefore: true); - _keyword(const { - RaiseKind.ignore: TokenType.ignore, - RaiseKind.rollback: TokenType.rollback, - RaiseKind.abort: TokenType.abort, - RaiseKind.fail: TokenType.fail, - }[e.raiseKind]!); - - if (e.errorMessage != null) { - _symbol(',', spaceAfter: true); - _stringLiteral(e.errorMessage!); - } - _symbol(')', spaceAfter: true); - } - - @override - void visitIndexedColumn(IndexedColumn e, void arg) { - visit(e.expression, arg); - _orderingMode(e.ordering); - } - @override void visitInsertStatement(InsertStatement e, void arg) { visitNullable(e.withClause, arg); @@ -875,6 +791,12 @@ class NodeSqlBuilder extends AstVisitor { _symbol(';', spaceAfter: true); } + @override + void visitMoorNestedStarResultColumn(NestedStarResultColumn e, void arg) { + _identifier(e.tableName); + _symbol('.**', spaceAfter: true); + } + @override void visitMoorStatementParameter(StatementParameter e, void arg) { if (e is VariableTypeHint) { @@ -937,15 +859,6 @@ class NodeSqlBuilder extends AstVisitor { _join(e.terms, ','); } - void _orderingMode(OrderingMode? mode) { - if (mode != null) { - _keyword(const { - OrderingMode.ascending: TokenType.asc, - OrderingMode.descending: TokenType.desc, - }[mode]!); - } - } - @override void visitOrderingTerm(OrderingTerm e, void arg) { visit(e.expression, arg); @@ -968,6 +881,24 @@ class NodeSqlBuilder extends AstVisitor { _symbol(')'); } + @override + void visitRaiseExpression(RaiseExpression e, void arg) { + _keyword(TokenType.raise); + _symbol('(', spaceBefore: true); + _keyword(const { + RaiseKind.ignore: TokenType.ignore, + RaiseKind.rollback: TokenType.rollback, + RaiseKind.abort: TokenType.abort, + RaiseKind.fail: TokenType.fail, + }[e.raiseKind]!); + + if (e.errorMessage != null) { + _symbol(',', spaceAfter: true); + _stringLiteral(e.errorMessage!); + } + _symbol(')', spaceAfter: true); + } + @override void visitReference(Reference e, void arg) { var didWriteSpaceBefore = false; @@ -988,31 +919,6 @@ class NodeSqlBuilder extends AstVisitor { spaceAfter: true, spaceBefore: !didWriteSpaceBefore); } - @override - void visitStarResultColumn(StarResultColumn e, void arg) { - if (e.tableName != null) { - _identifier(e.tableName!); - _symbol('.'); - } - - _symbol('*', spaceAfter: true, spaceBefore: e.tableName == null); - } - - @override - void visitMoorNestedStarResultColumn(NestedStarResultColumn e, void arg) { - _identifier(e.tableName); - _symbol('.**', spaceAfter: true); - } - - @override - void visitExpressionResultColumn(ExpressionResultColumn e, void arg) { - visit(e.expression, arg); - if (e.as != null) { - _keyword(TokenType.as); - _identifier(e.as!); - } - } - @override void visitReturning(Returning e, void arg) { _keyword(TokenType.returning); @@ -1081,6 +987,16 @@ class NodeSqlBuilder extends AstVisitor { _symbol('*', spaceAfter: true); } + @override + void visitStarResultColumn(StarResultColumn e, void arg) { + if (e.tableName != null) { + _identifier(e.tableName!); + _symbol('.'); + } + + _symbol('*', spaceAfter: true, spaceBefore: e.tableName == null); + } + @override void visitStringComparison(StringComparisonExpression e, void arg) { visit(e.left, arg); @@ -1316,4 +1232,112 @@ class NodeSqlBuilder extends AstVisitor { _join(e.ctes, ','); } + + void _conflictClause(ConflictClause? clause) { + if (clause != null) { + _keyword(TokenType.on); + _keyword(TokenType.conflict); + + _keyword(const { + ConflictClause.rollback: TokenType.rollback, + ConflictClause.abort: TokenType.abort, + ConflictClause.fail: TokenType.fail, + ConflictClause.ignore: TokenType.ignore, + ConflictClause.replace: TokenType.replace, + }[clause]!); + } + } + + void _from(Queryable? from) { + if (from != null) { + _keyword(TokenType.from); + visit(from, null); + } + } + + void _identifier(String identifier, + {bool spaceBefore = true, bool spaceAfter = true}) { + if (isKeywordLexeme(identifier) || identifier.contains(' ')) { + identifier = '"$identifier"'; + } + + _symbol(identifier, spaceBefore: spaceBefore, spaceAfter: spaceAfter); + } + + void _ifNotExists(bool ifNotExists) { + if (ifNotExists) { + _keyword(TokenType.$if); + _keyword(TokenType.not); + _keyword(TokenType.exists); + } + } + + void _join(Iterable nodes, String separatingSymbol) { + var isFirst = true; + + for (final node in nodes) { + if (!isFirst) { + _symbol(separatingSymbol, spaceAfter: true); + } + + visit(node, null); + isFirst = false; + } + } + + void _keyword(TokenType type) { + _symbol(reverseKeywords[type]!, spaceAfter: true, spaceBefore: true); + } + + void _orderingMode(OrderingMode? mode) { + if (mode != null) { + _keyword(const { + OrderingMode.ascending: TokenType.asc, + OrderingMode.descending: TokenType.desc, + }[mode]!); + } + } + + void _space() => buffer.writeCharCode($space); + + void _stringLiteral(String content) { + final escapedChars = content.replaceAll("'", "''"); + _symbol("'$escapedChars'", spaceBefore: true, spaceAfter: true); + } + + void _symbol(String lexeme, + {bool spaceBefore = false, bool spaceAfter = false}) { + if (needsSpace && spaceBefore) { + _space(); + } + + buffer.write(lexeme); + needsSpace = spaceAfter; + } + + void _where(Expression? where) { + if (where != null) { + _keyword(TokenType.where); + visit(where, null); + } + } +} + +/// Defines the [toSql] extension method that turns ast nodes into a compatible +/// textual representation. +/// +/// Parsing the output of [toSql] will result in an equal AST. +extension NodeToText on AstNode { + /// Obtains a textual representation for AST nodes. + /// + /// Parsing the output of [toSql] will result in an equal AST. Since only the + /// AST is used, the output will not contain comments. It's possible for the + /// output to have more than just whitespace changes if there are multiple + /// ways to represent an equivalent node (e.g. the no-op `FOR EACH ROW` on + /// triggers). + String toSql() { + final builder = NodeSqlBuilder(); + builder.visit(this, null); + return builder.buffer.toString(); + } } diff --git a/sqlparser/test/parser/errors_test.dart b/sqlparser/test/parser/errors_test.dart new file mode 100644 index 00000000..df5b3983 --- /dev/null +++ b/sqlparser/test/parser/errors_test.dart @@ -0,0 +1,10 @@ +import 'package:test/test.dart'; + +import 'utils.dart'; + +void main() { + test('WITH without following statement', () { + enforceError('WITH foo AS (SELECT * FROM bar)', + contains('to follow this WITH clause')); + }); +} diff --git a/sqlparser/test/parser/misc_test.dart b/sqlparser/test/parser/misc_test.dart new file mode 100644 index 00000000..b2289302 --- /dev/null +++ b/sqlparser/test/parser/misc_test.dart @@ -0,0 +1,35 @@ +import 'package:sqlparser/sqlparser.dart'; +import 'package:test/test.dart'; + +import 'utils.dart'; + +void main() { + group('BEGIN', () { + test('without mode', () { + testStatement('BEGIN;', BeginTransactionStatement()); + testStatement('BEGIN TRANSACTION;', BeginTransactionStatement()); + }); + + test('deferred', () { + testStatement('BEGIN DEFERRED;', + BeginTransactionStatement(TransactionMode.deferred)); + }); + + test('immediate', () { + testStatement('BEGIN IMMEDIATE;', + BeginTransactionStatement(TransactionMode.immediate)); + }); + + test('exclusive', () { + testStatement('BEGIN EXCLUSIVE TRANSACTION;', + BeginTransactionStatement(TransactionMode.exclusive)); + }); + }); + + test('COMMIT', () { + testStatement('COMMIT', CommitStatement()); + testStatement('END', CommitStatement()); + testStatement('COMMIT TRANSACTION', CommitStatement()); + testStatement('END TRANSACTION', CommitStatement()); + }); +} diff --git a/sqlparser/test/parser/utils.dart b/sqlparser/test/parser/utils.dart index fbd4204a..cd6c0da9 100644 --- a/sqlparser/test/parser/utils.dart +++ b/sqlparser/test/parser/utils.dart @@ -66,3 +66,13 @@ void enforceHasSpan(AstNode node) { throw ArgumentError('Node $problematic did not have a span'); } } + +void enforceError(String sql, Matcher textMatcher) { + final parsed = SqlEngine().parse(sql); + + expect( + parsed.errors, + contains( + isA().having((e) => e.message, 'message', textMatcher)), + ); +} diff --git a/sqlparser/test/utils/node_to_text_test.dart b/sqlparser/test/utils/node_to_text_test.dart index 825aec14..022399e4 100644 --- a/sqlparser/test/utils/node_to_text_test.dart +++ b/sqlparser/test/utils/node_to_text_test.dart @@ -167,6 +167,17 @@ CREATE UNIQUE INDEX my_idx ON t1 (c1, c2, c3) WHERE c1 < c3; }); }); + group('misc', () { + test('transactions', () { + testFormat('BEGIN DEFERRED TRANSACTION;'); + testFormat('BEGIN IMMEDIATE'); + testFormat('BEGIN EXCLUSIVE'); + + testFormat('COMMIT'); + testFormat('END TRANSACTION'); + }); + }); + group('query statements', () { group('select', () { test('with common table expressions', () {