sqlparser: Support CAST expressions

This commit is contained in:
Simon Binder 2019-12-30 20:46:54 +01:00
parent 4484890609
commit c54a62120d
No known key found for this signature in database
GPG Key ID: 7891917E4147B8C0
14 changed files with 102 additions and 20 deletions

View File

@ -1,6 +1,8 @@
## unreleased
- Added a argument type and argument to the visitor classes
- Experimental new type inference algorithm
- Support `CAST` expressions.
## 0.5.0
- Optionally support the `json1` module
@ -19,7 +21,6 @@
__0.3.0+1__: Accept `\r` characters as whitespace
## 0.2.0
- Parse `CREATE TABLE` statements
- Extract schema information from parsed create table statements with `SchemaFromCreateTable`.

View File

@ -17,6 +17,9 @@ class AnalysisContext {
/// outside.
final AnalyzeStatementOptions stmtOptions;
/// Utilities to read types.
final SchemaFromCreateTable schemaSupport;
/// A resolver that can be used to obtain the type of a [Typeable]. This
/// mostly applies to [Expression]s, [Reference]s, [Variable]s and
/// [ResultSet.resolvedColumns] of a select statement.
@ -24,7 +27,7 @@ class AnalysisContext {
/// Constructs a new analysis context from the AST and the source sql.
AnalysisContext(this.root, this.sql, EngineOptions options,
{AnalyzeStatementOptions stmtOptions})
{AnalyzeStatementOptions stmtOptions, this.schemaSupport})
: stmtOptions = stmtOptions ?? const AnalyzeStatementOptions() {
types = TypeResolver(this, options);
}

View File

@ -130,6 +130,9 @@ class TypeResolver {
}
} else if (expr is CaseExpression) {
return resolveExpression(expr.whens.first.then);
} else if (expr is CastExpression) {
final type = context.schemaSupport.resolveColumnType(expr.typeName);
return ResolveResult(type);
} else if (expr is SubQuery) {
final columns = expr.select.resultSet.resolvedColumns;
if (columns.length != 1) {

View File

@ -156,6 +156,14 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
visitChildren(e, const NoTypeExpectation());
}
@override
void visitCastExpression(CastExpression e, TypeExpectation arg) {
final type = session.context.schemaSupport.resolveColumnType(e.typeName);
session.checkAndResolve(e, type, arg);
session.addRelationship(NullableIfSomeOtherIs(e, [e.operand]));
visit(e.operand, const NoTypeExpectation());
}
void _handleWhereClause(HasWhereClause stmt) {
// assume that a where statement is a boolean expression. Sqlite internally
// casts (https://www.sqlite.org/lang_expr.html#booleanexpr), so be lax

View File

@ -14,6 +14,7 @@ part 'common/renamable.dart';
part 'common/tuple.dart';
part 'expressions/aggregate.dart';
part 'expressions/case.dart';
part 'expressions/cast.dart';
part 'expressions/expressions.dart';
part 'expressions/function.dart';
part 'expressions/literals.dart';

View File

@ -0,0 +1,22 @@
part of '../ast.dart';
/// A `CAST(<expr> AS <type>)` expression.
class CastExpression extends Expression {
final Expression operand;
final String typeName;
CastExpression(this.operand, this.typeName);
@override
R accept<A, R>(AstVisitor<A, R> visitor, A arg) {
return visitor.visitCastExpression(this, arg);
}
@override
Iterable<AstNode> get childNodes => [operand];
@override
bool contentEquals(CastExpression other) {
return other.typeName == typeName;
}
}

View File

@ -27,6 +27,7 @@ abstract class AstVisitor<A, R> {
R visitTableConstraint(TableConstraint e, A arg);
R visitForeignKeyClause(ForeignKeyClause e, A arg);
R visitCastExpression(CastExpression e, A arg);
R visitBinaryExpression(BinaryExpression e, A arg);
R visitStringComparison(StringComparisonExpression e, A arg);
R visitUnaryExpression(UnaryExpression e, A arg);
@ -231,6 +232,11 @@ class RecursiveVisitor<A, R> implements AstVisitor<A, R> {
// Expressions
@override
R visitCastExpression(CastExpression e, A arg) {
return visitExpression(e, arg);
}
@override
R visitBinaryExpression(BinaryExpression e, A arg) {
return visitExpression(e, arg);

View File

@ -150,8 +150,7 @@ class SqlEngine {
{AnalyzeStatementOptions stmtOptions}) {
final node = result.rootNode;
final context =
AnalysisContext(node, result.sql, options, stmtOptions: stmtOptions);
final context = _createContext(node, result.sql, stmtOptions);
_analyzeContext(context);
return context;
@ -169,12 +168,17 @@ class SqlEngine {
/// this statement only.
AnalysisContext analyzeNode(AstNode node, String file,
{AnalyzeStatementOptions stmtOptions}) {
final context =
AnalysisContext(node, file, options, stmtOptions: stmtOptions);
final context = _createContext(node, file, stmtOptions);
_analyzeContext(context);
return context;
}
AnalysisContext _createContext(
AstNode node, String sql, AnalyzeStatementOptions stmtOptions) {
return AnalysisContext(node, sql, options,
stmtOptions: stmtOptions, schemaSupport: schemaReader);
}
void _analyzeContext(AnalysisContext context) {
final node = context.root;
_attachRootScope(node);

View File

@ -284,7 +284,7 @@ mixin ExpressionParser on ParserBase {
final first = _consumeIdentifier(
'This error message should never be displayed. Please report.');
// could be table.column, function(...) or just column
// could be table.column, function(...), cast(...) or just column
if (_matchOne(TokenType.dot)) {
final second =
_consumeIdentifier('Expected a column name here', lenient: true);
@ -292,17 +292,28 @@ mixin ExpressionParser on ParserBase {
tableName: first.identifier, columnName: second.identifier)
..setSpan(first, second);
} else if (_matchOne(TokenType.leftParen)) {
final parameters = _functionParameters();
final rightParen = _consume(TokenType.rightParen,
'Expected closing bracket after argument list');
if (first.identifier.toLowerCase() == 'cast') {
final operand = expression();
_consume(TokenType.as, 'Expected AS operator here');
final type = _typeName();
final typeName = type.lexeme;
_consume(TokenType.rightParen, 'Expected closing bracket here');
if (_peek.type == TokenType.filter || _peek.type == TokenType.over) {
return _aggregate(first, parameters);
return CastExpression(operand, typeName)..setSpan(first, _previous);
} else {
// regular function invocation
final parameters = _functionParameters();
final rightParen = _consume(TokenType.rightParen,
'Expected closing bracket after argument list');
if (_peek.type == TokenType.filter || _peek.type == TokenType.over) {
return _aggregate(first, parameters);
}
return FunctionExpression(
name: first.identifier, parameters: parameters)
..setSpan(first, rightParen);
}
return FunctionExpression(
name: first.identifier, parameters: parameters)
..setSpan(first, rightParen);
} else {
return Reference(columnName: first.identifier)..setSpan(first, first);
}

View File

@ -313,8 +313,7 @@ class Parser extends ParserBase
_error('Expected a type name here');
}
final typeName =
typeNameTokens.first.span.expand(typeNameTokens.last.span).text;
final typeName = typeNameTokens.lexeme;
parameters.add(VariableTypeHint(variable, typeName)
..as = as
..setSpan(first, _previous));
@ -364,3 +363,7 @@ class Parser extends ParserBase
while (!_isAtEnd && _advance().type != TokenType.semicolon) {}
}
}
extension on List<Token> {
String get lexeme => first.span.expand(last.span).text;
}

View File

@ -158,8 +158,7 @@ mixin SchemaParser on ParserBase {
String typeName;
if (typeTokens != null) {
final typeSpan = typeTokens.first.span.expand(typeTokens.last.span);
typeName = typeSpan.text;
typeName = typeTokens.lexeme;
}
final constraints = <ColumnConstraint>[];

View File

@ -30,6 +30,8 @@ Map<String, ResolveResult> _types = {
'SELECT row_number() OVER (RANGE ? PRECEDING)':
const ResolveResult(ResolvedType(type: BasicType.int)),
'SELECT ?;': const ResolveResult.unknown(),
'SELECT CAST(3 AS TEXT) = ?':
const ResolveResult(ResolvedType(type: BasicType.text)),
};
void main() {

View File

@ -61,4 +61,15 @@ void main() {
expect(_resolveFirstVariable("SELECT '' || :foo"),
const ResolvedType(type: BasicType.text));
});
group('case expressions', () {
test('resolve to type argument', () {
expect(_resolveResultColumn('SELECT CAST(3+4 AS TEXT)'),
const ResolvedType(type: BasicType.text));
});
test('allow anything as their operand', () {
expect(_resolveFirstVariable('SELECT CAST(? AS TEXT)'), null);
});
});
}

View File

@ -136,6 +136,14 @@ final Map<String, Expression> _testCases = {
],
),
),
'CAST(3 + 4 AS TEXT)': CastExpression(
BinaryExpression(
NumericLiteral(3.0, token(TokenType.numberLiteral)),
token(TokenType.plus),
NumericLiteral(4.0, token(TokenType.numberLiteral)),
),
'TEXT',
),
};
void main() {