diff --git a/sqlparser/CHANGELOG.md b/sqlparser/CHANGELOG.md index 07313de8..98702e07 100644 --- a/sqlparser/CHANGELOG.md +++ b/sqlparser/CHANGELOG.md @@ -1,6 +1,8 @@ ## unreleased - Added a argument type and argument to the visitor classes +- Experimental new type inference algorithm +- Support `CAST` expressions. ## 0.5.0 - Optionally support the `json1` module @@ -19,7 +21,6 @@ __0.3.0+1__: Accept `\r` characters as whitespace - ## 0.2.0 - Parse `CREATE TABLE` statements - Extract schema information from parsed create table statements with `SchemaFromCreateTable`. diff --git a/sqlparser/lib/src/analysis/context.dart b/sqlparser/lib/src/analysis/context.dart index 49f499a6..901e58d0 100644 --- a/sqlparser/lib/src/analysis/context.dart +++ b/sqlparser/lib/src/analysis/context.dart @@ -17,6 +17,9 @@ class AnalysisContext { /// outside. final AnalyzeStatementOptions stmtOptions; + /// Utilities to read types. + final SchemaFromCreateTable schemaSupport; + /// A resolver that can be used to obtain the type of a [Typeable]. This /// mostly applies to [Expression]s, [Reference]s, [Variable]s and /// [ResultSet.resolvedColumns] of a select statement. @@ -24,7 +27,7 @@ class AnalysisContext { /// Constructs a new analysis context from the AST and the source sql. AnalysisContext(this.root, this.sql, EngineOptions options, - {AnalyzeStatementOptions stmtOptions}) + {AnalyzeStatementOptions stmtOptions, this.schemaSupport}) : stmtOptions = stmtOptions ?? const AnalyzeStatementOptions() { types = TypeResolver(this, options); } diff --git a/sqlparser/lib/src/analysis/types/resolver.dart b/sqlparser/lib/src/analysis/types/resolver.dart index 65a1ac48..85cb82aa 100644 --- a/sqlparser/lib/src/analysis/types/resolver.dart +++ b/sqlparser/lib/src/analysis/types/resolver.dart @@ -130,6 +130,9 @@ class TypeResolver { } } else if (expr is CaseExpression) { return resolveExpression(expr.whens.first.then); + } else if (expr is CastExpression) { + final type = context.schemaSupport.resolveColumnType(expr.typeName); + return ResolveResult(type); } else if (expr is SubQuery) { final columns = expr.select.resultSet.resolvedColumns; if (columns.length != 1) { diff --git a/sqlparser/lib/src/analysis/types2/resolving_visitor.dart b/sqlparser/lib/src/analysis/types2/resolving_visitor.dart index 1415920f..d15685e9 100644 --- a/sqlparser/lib/src/analysis/types2/resolving_visitor.dart +++ b/sqlparser/lib/src/analysis/types2/resolving_visitor.dart @@ -156,6 +156,14 @@ class TypeResolver extends RecursiveVisitor { visitChildren(e, const NoTypeExpectation()); } + @override + void visitCastExpression(CastExpression e, TypeExpectation arg) { + final type = session.context.schemaSupport.resolveColumnType(e.typeName); + session.checkAndResolve(e, type, arg); + session.addRelationship(NullableIfSomeOtherIs(e, [e.operand])); + visit(e.operand, const NoTypeExpectation()); + } + void _handleWhereClause(HasWhereClause stmt) { // assume that a where statement is a boolean expression. Sqlite internally // casts (https://www.sqlite.org/lang_expr.html#booleanexpr), so be lax diff --git a/sqlparser/lib/src/ast/ast.dart b/sqlparser/lib/src/ast/ast.dart index b77b7bc5..30a962f2 100644 --- a/sqlparser/lib/src/ast/ast.dart +++ b/sqlparser/lib/src/ast/ast.dart @@ -14,6 +14,7 @@ part 'common/renamable.dart'; part 'common/tuple.dart'; part 'expressions/aggregate.dart'; part 'expressions/case.dart'; +part 'expressions/cast.dart'; part 'expressions/expressions.dart'; part 'expressions/function.dart'; part 'expressions/literals.dart'; diff --git a/sqlparser/lib/src/ast/expressions/cast.dart b/sqlparser/lib/src/ast/expressions/cast.dart new file mode 100644 index 00000000..08a3e593 --- /dev/null +++ b/sqlparser/lib/src/ast/expressions/cast.dart @@ -0,0 +1,22 @@ +part of '../ast.dart'; + +/// A `CAST( AS )` expression. +class CastExpression extends Expression { + final Expression operand; + final String typeName; + + CastExpression(this.operand, this.typeName); + + @override + R accept(AstVisitor visitor, A arg) { + return visitor.visitCastExpression(this, arg); + } + + @override + Iterable get childNodes => [operand]; + + @override + bool contentEquals(CastExpression other) { + return other.typeName == typeName; + } +} diff --git a/sqlparser/lib/src/ast/visitor.dart b/sqlparser/lib/src/ast/visitor.dart index a53fe055..389877bb 100644 --- a/sqlparser/lib/src/ast/visitor.dart +++ b/sqlparser/lib/src/ast/visitor.dart @@ -27,6 +27,7 @@ abstract class AstVisitor { R visitTableConstraint(TableConstraint e, A arg); R visitForeignKeyClause(ForeignKeyClause e, A arg); + R visitCastExpression(CastExpression e, A arg); R visitBinaryExpression(BinaryExpression e, A arg); R visitStringComparison(StringComparisonExpression e, A arg); R visitUnaryExpression(UnaryExpression e, A arg); @@ -231,6 +232,11 @@ class RecursiveVisitor implements AstVisitor { // Expressions + @override + R visitCastExpression(CastExpression e, A arg) { + return visitExpression(e, arg); + } + @override R visitBinaryExpression(BinaryExpression e, A arg) { return visitExpression(e, arg); diff --git a/sqlparser/lib/src/engine/sql_engine.dart b/sqlparser/lib/src/engine/sql_engine.dart index a7c0aba2..3c6b2609 100644 --- a/sqlparser/lib/src/engine/sql_engine.dart +++ b/sqlparser/lib/src/engine/sql_engine.dart @@ -150,8 +150,7 @@ class SqlEngine { {AnalyzeStatementOptions stmtOptions}) { final node = result.rootNode; - final context = - AnalysisContext(node, result.sql, options, stmtOptions: stmtOptions); + final context = _createContext(node, result.sql, stmtOptions); _analyzeContext(context); return context; @@ -169,12 +168,17 @@ class SqlEngine { /// this statement only. AnalysisContext analyzeNode(AstNode node, String file, {AnalyzeStatementOptions stmtOptions}) { - final context = - AnalysisContext(node, file, options, stmtOptions: stmtOptions); + final context = _createContext(node, file, stmtOptions); _analyzeContext(context); return context; } + AnalysisContext _createContext( + AstNode node, String sql, AnalyzeStatementOptions stmtOptions) { + return AnalysisContext(node, sql, options, + stmtOptions: stmtOptions, schemaSupport: schemaReader); + } + void _analyzeContext(AnalysisContext context) { final node = context.root; _attachRootScope(node); diff --git a/sqlparser/lib/src/reader/parser/expressions.dart b/sqlparser/lib/src/reader/parser/expressions.dart index 628592ef..815c33ed 100644 --- a/sqlparser/lib/src/reader/parser/expressions.dart +++ b/sqlparser/lib/src/reader/parser/expressions.dart @@ -284,7 +284,7 @@ mixin ExpressionParser on ParserBase { final first = _consumeIdentifier( 'This error message should never be displayed. Please report.'); - // could be table.column, function(...) or just column + // could be table.column, function(...), cast(...) or just column if (_matchOne(TokenType.dot)) { final second = _consumeIdentifier('Expected a column name here', lenient: true); @@ -292,17 +292,28 @@ mixin ExpressionParser on ParserBase { tableName: first.identifier, columnName: second.identifier) ..setSpan(first, second); } else if (_matchOne(TokenType.leftParen)) { - final parameters = _functionParameters(); - final rightParen = _consume(TokenType.rightParen, - 'Expected closing bracket after argument list'); + if (first.identifier.toLowerCase() == 'cast') { + final operand = expression(); + _consume(TokenType.as, 'Expected AS operator here'); + final type = _typeName(); + final typeName = type.lexeme; + _consume(TokenType.rightParen, 'Expected closing bracket here'); - if (_peek.type == TokenType.filter || _peek.type == TokenType.over) { - return _aggregate(first, parameters); + return CastExpression(operand, typeName)..setSpan(first, _previous); + } else { + // regular function invocation + final parameters = _functionParameters(); + final rightParen = _consume(TokenType.rightParen, + 'Expected closing bracket after argument list'); + + if (_peek.type == TokenType.filter || _peek.type == TokenType.over) { + return _aggregate(first, parameters); + } + + return FunctionExpression( + name: first.identifier, parameters: parameters) + ..setSpan(first, rightParen); } - - return FunctionExpression( - name: first.identifier, parameters: parameters) - ..setSpan(first, rightParen); } else { return Reference(columnName: first.identifier)..setSpan(first, first); } diff --git a/sqlparser/lib/src/reader/parser/parser.dart b/sqlparser/lib/src/reader/parser/parser.dart index 75a21c07..91ebffc2 100644 --- a/sqlparser/lib/src/reader/parser/parser.dart +++ b/sqlparser/lib/src/reader/parser/parser.dart @@ -313,8 +313,7 @@ class Parser extends ParserBase _error('Expected a type name here'); } - final typeName = - typeNameTokens.first.span.expand(typeNameTokens.last.span).text; + final typeName = typeNameTokens.lexeme; parameters.add(VariableTypeHint(variable, typeName) ..as = as ..setSpan(first, _previous)); @@ -364,3 +363,7 @@ class Parser extends ParserBase while (!_isAtEnd && _advance().type != TokenType.semicolon) {} } } + +extension on List { + String get lexeme => first.span.expand(last.span).text; +} diff --git a/sqlparser/lib/src/reader/parser/schema.dart b/sqlparser/lib/src/reader/parser/schema.dart index 6ecadcd9..7274aecf 100644 --- a/sqlparser/lib/src/reader/parser/schema.dart +++ b/sqlparser/lib/src/reader/parser/schema.dart @@ -158,8 +158,7 @@ mixin SchemaParser on ParserBase { String typeName; if (typeTokens != null) { - final typeSpan = typeTokens.first.span.expand(typeTokens.last.span); - typeName = typeSpan.text; + typeName = typeTokens.lexeme; } final constraints = []; diff --git a/sqlparser/test/analysis/types/resolver_test.dart b/sqlparser/test/analysis/types/resolver_test.dart index 378ed15a..4f3481a8 100644 --- a/sqlparser/test/analysis/types/resolver_test.dart +++ b/sqlparser/test/analysis/types/resolver_test.dart @@ -30,6 +30,8 @@ Map _types = { 'SELECT row_number() OVER (RANGE ? PRECEDING)': const ResolveResult(ResolvedType(type: BasicType.int)), 'SELECT ?;': const ResolveResult.unknown(), + 'SELECT CAST(3 AS TEXT) = ?': + const ResolveResult(ResolvedType(type: BasicType.text)), }; void main() { diff --git a/sqlparser/test/analysis/types2/resolver_test.dart b/sqlparser/test/analysis/types2/resolver_test.dart index 6444f8fc..8b80c67f 100644 --- a/sqlparser/test/analysis/types2/resolver_test.dart +++ b/sqlparser/test/analysis/types2/resolver_test.dart @@ -61,4 +61,15 @@ void main() { expect(_resolveFirstVariable("SELECT '' || :foo"), const ResolvedType(type: BasicType.text)); }); + + group('case expressions', () { + test('resolve to type argument', () { + expect(_resolveResultColumn('SELECT CAST(3+4 AS TEXT)'), + const ResolvedType(type: BasicType.text)); + }); + + test('allow anything as their operand', () { + expect(_resolveFirstVariable('SELECT CAST(? AS TEXT)'), null); + }); + }); } diff --git a/sqlparser/test/parser/expression_test.dart b/sqlparser/test/parser/expression_test.dart index 34b21906..7fe2fdd8 100644 --- a/sqlparser/test/parser/expression_test.dart +++ b/sqlparser/test/parser/expression_test.dart @@ -136,6 +136,14 @@ final Map _testCases = { ], ), ), + 'CAST(3 + 4 AS TEXT)': CastExpression( + BinaryExpression( + NumericLiteral(3.0, token(TokenType.numberLiteral)), + token(TokenType.plus), + NumericLiteral(4.0, token(TokenType.numberLiteral)), + ), + 'TEXT', + ), }; void main() {