diff --git a/sqlparser/lib/src/analysis/types/data.dart b/sqlparser/lib/src/analysis/types/data.dart index f72b0dc7..bc9ad4bf 100644 --- a/sqlparser/lib/src/analysis/types/data.dart +++ b/sqlparser/lib/src/analysis/types/data.dart @@ -18,12 +18,22 @@ class ResolvedType { final TypeHint hint; final bool nullable; - const ResolvedType({this.type, this.hint, this.nullable = false}); + /// Whether this type is an array. + final bool isArray; + + const ResolvedType( + {this.type, this.hint, this.nullable = false, this.isArray = false}); const ResolvedType.bool() : this(type: BasicType.int, hint: const IsBoolean()); ResolvedType withNullable(bool nullable) { - return ResolvedType(type: type, hint: hint, nullable: nullable); + return ResolvedType( + type: type, hint: hint, nullable: nullable, isArray: isArray); + } + + ResolvedType toArray(bool array) { + return ResolvedType( + type: type, hint: hint, nullable: nullable, isArray: array); } @override @@ -32,7 +42,8 @@ class ResolvedType { other is ResolvedType && other.type == type && other.hint == hint && - other.nullable == nullable; + other.nullable == nullable && + other.isArray == isArray; } @override @@ -42,7 +53,7 @@ class ResolvedType { @override String toString() { - return 'ResolvedType($type, hint: $hint, nullable: $nullable)'; + return 'ResolvedType($type, hint: $hint, nullable: $nullable, array: $isArray)'; } } diff --git a/sqlparser/lib/src/analysis/types/resolver.dart b/sqlparser/lib/src/analysis/types/resolver.dart index 9e5088cb..2f5b7f61 100644 --- a/sqlparser/lib/src/analysis/types/resolver.dart +++ b/sqlparser/lib/src/analysis/types/resolver.dart @@ -10,7 +10,7 @@ const _comparisonOperators = [ TokenType.less, TokenType.lessEqual, TokenType.more, - TokenType.moreEqual + TokenType.moreEqual, ]; class TypeResolver { @@ -79,6 +79,7 @@ class TypeResolver { } else if (expr is FunctionExpression) { return resolveFunctionCall(expr); } else if (expr is IsExpression || + expr is InExpression || expr is StringComparisonExpression || expr is BetweenExpression || expr is ExistsExpression) { @@ -240,11 +241,15 @@ class TypeResolver { final parent = e.parent; if (parent is Expression) { final result = _argumentType(parent, e); - if (result.needsContext) { - return inferType(parent); - } else { - return result; + // while more context is needed, look at the parent + final inferredType = result.needsContext ? inferType(parent) : result; + + // If this appears in a tuple, e.g. test IN (?). The "(?)" will be an + // array. Of course, the individual entry is not, so reset that state. + if (parent is TupleExpression) { + return inferredType.mapResult((r) => r.toArray(false)); } + return inferredType; } else if (parent is Limit) { return const ResolveResult(ResolvedType(type: BasicType.int)); } else if (parent is SetComponent) { @@ -257,12 +262,20 @@ class TypeResolver { ResolveResult _argumentType(Expression parent, Expression argument) { if (parent is IsExpression || + parent is InExpression || parent is BinaryExpression || parent is BetweenExpression || parent is CaseExpression) { final relevant = parent.childNodes .lastWhere((node) => node is Expression && node != argument); - return resolveExpression(relevant as Expression); + final resolved = resolveExpression(relevant as Expression); + + // if we have "a x IN argument" expression, the argument will be an array + if (parent is InExpression && argument == parent.inside) { + return resolved.mapResult((r) => r.toArray(true)); + } + + return resolved; } else if (parent is StringComparisonExpression) { if (argument == parent.escape) { return const ResolveResult(ResolvedType(type: BasicType.text)); @@ -271,7 +284,9 @@ class TypeResolver { .firstWhere((node) => node is Expression && node != argument); return resolveExpression(otherNode as Expression); } - } else if (parent is Parentheses || parent is UnaryExpression) { + } else if (parent is Parentheses || + parent is TupleExpression || + parent is UnaryExpression) { return const ResolveResult.needsContext(); } else if (parent is FunctionExpression) { return resolveFunctionCall(parent); @@ -326,12 +341,9 @@ class ResolveResult { bool get nullable => type?.nullable ?? true; - /// Copies the result with the [nullable] information, if there is one. If - /// there isn't, the failure state will be copied into the new - /// [ResolveResult]. - ResolveResult withNullable(bool nullable) { + ResolveResult mapResult(ResolvedType Function(ResolvedType) map) { if (type != null) { - return ResolveResult(type.withNullable(nullable)); + return ResolveResult(map(type)); } else if (needsContext != null) { return const ResolveResult.needsContext(); } else { @@ -339,6 +351,13 @@ class ResolveResult { } } + /// Copies the result with the [nullable] information, if there is one. If + /// there isn't, the failure state will be copied into the new + /// [ResolveResult]. + ResolveResult withNullable(bool nullable) { + return mapResult((r) => r.withNullable(nullable)); + } + @override bool operator ==(other) { return identical(this, other) || diff --git a/sqlparser/lib/src/ast/ast.dart b/sqlparser/lib/src/ast/ast.dart index 0b7a7049..e08d895a 100644 --- a/sqlparser/lib/src/ast/ast.dart +++ b/sqlparser/lib/src/ast/ast.dart @@ -150,6 +150,7 @@ abstract class AstVisitor { T visitCaseExpression(CaseExpression e); T visitWhen(WhenComponent e); T visitTuple(TupleExpression e); + T visitInExpression(InExpression e); T visitNumberedVariable(NumberedVariable e); T visitNamedVariable(ColonNamedVariable e); @@ -184,6 +185,9 @@ class RecursiveVisitor extends AstVisitor { @override T visitTuple(TupleExpression e) => visitChildren(e); + @override + T visitInExpression(InExpression e) => visitChildren(e); + @override T visitSubQuery(SubQuery e) => visitChildren(e); diff --git a/sqlparser/lib/src/ast/expressions/simple.dart b/sqlparser/lib/src/ast/expressions/simple.dart index 124d6205..52b6bc37 100644 --- a/sqlparser/lib/src/ast/expressions/simple.dart +++ b/sqlparser/lib/src/ast/expressions/simple.dart @@ -61,6 +61,7 @@ class StringComparisonExpression extends Expression { bool contentEquals(StringComparisonExpression other) => other.not == not; } +/// `(NOT)? $left IS $right` class IsExpression extends Expression { final bool negated; final Expression left; @@ -82,6 +83,7 @@ class IsExpression extends Expression { } } +/// `$check BETWEEN $lower AND $upper` class BetweenExpression extends Expression { final bool not; final Expression check; @@ -100,6 +102,27 @@ class BetweenExpression extends Expression { bool contentEquals(BetweenExpression other) => other.not == not; } +/// `$left$ IN $inside` +class InExpression extends Expression { + final bool not; + final Expression left; + final Expression inside; + + InExpression({this.not = false, @required this.left, @required this.inside}); + + @override + T accept(AstVisitor visitor) => visitor.visitInExpression(this); + + @override + Iterable get childNodes => [left, inside]; + + @override + bool contentEquals(InExpression other) => other.not == not; +} + +// todo we might be able to remove a hack in the parser at _in() if we make +// parentheses a subclass of tuples + class Parentheses extends Expression { final Token openingLeft; final Expression expression; @@ -117,4 +140,8 @@ class Parentheses extends Expression { @override bool contentEquals(Parentheses other) => true; + + TupleExpression get asTuple { + return TupleExpression(expressions: [expression]); + } } diff --git a/sqlparser/lib/src/reader/parser/parser.dart b/sqlparser/lib/src/reader/parser/parser.dart index 2a98ec90..3609b1f7 100644 --- a/sqlparser/lib/src/reader/parser/parser.dart +++ b/sqlparser/lib/src/reader/parser/parser.dart @@ -507,7 +507,27 @@ class Parser { } Expression _or() => _parseSimpleBinary(const [TokenType.or], _and); - Expression _and() => _parseSimpleBinary(const [TokenType.and], _equals); + Expression _and() => _parseSimpleBinary(const [TokenType.and], _in); + + Expression _in() { + final left = _equals(); + + if (_checkWithNot(TokenType.$in)) { + final not = _matchOne(TokenType.not); + _matchOne(TokenType.$in); + + var inside = _equals(); + if (inside is Parentheses) { + // if we have something like x IN (3), then (3) is a tuple and not a + // parenthesis. We can only know this from the context unfortunately + inside = (inside as Parentheses).asTuple; + } + + return InExpression(left: left, inside: inside, not: not); + } + + return left; + } /// Parses expressions with the "equals" precedence. This contains /// comparisons, "IS (NOT) IN" expressions, between expressions and "like" @@ -521,7 +541,6 @@ class Parser { TokenType.exclamationEqual, TokenType.lessMore, TokenType.$is, - TokenType.$in, ]; final stringOps = const [ TokenType.like, diff --git a/sqlparser/test/analysis/type_resolver_test.dart b/sqlparser/test/analysis/type_resolver_test.dart index 6f447577..78be19de 100644 --- a/sqlparser/test/analysis/type_resolver_test.dart +++ b/sqlparser/test/analysis/type_resolver_test.dart @@ -20,6 +20,10 @@ Map _types = { const ResolveResult(ResolvedType(type: BasicType.text)), "SELECT * FROM demo WHERE content LIKE '%e' ESCAPE ?": const ResolveResult(ResolvedType(type: BasicType.text)), + 'SELECT * FROM demo WHERE content IN ?': + const ResolveResult(ResolvedType(type: BasicType.text, isArray: true)), + 'SELECT * FROM demo WHERE content IN (?)': + const ResolveResult(ResolvedType(type: BasicType.text, isArray: false)), }; void main() {