Parse and analyze tuples, rework type resolution for `IN`

This commit is contained in:
Simon Binder 2019-07-02 14:38:28 +02:00
parent 285113717f
commit 3024157ec9
No known key found for this signature in database
GPG Key ID: 7891917E4147B8C0
6 changed files with 102 additions and 18 deletions

View File

@ -18,12 +18,22 @@ class ResolvedType {
final TypeHint hint;
final bool nullable;
const ResolvedType({this.type, this.hint, this.nullable = false});
/// Whether this type is an array.
final bool isArray;
const ResolvedType(
{this.type, this.hint, this.nullable = false, this.isArray = false});
const ResolvedType.bool()
: this(type: BasicType.int, hint: const IsBoolean());
ResolvedType withNullable(bool nullable) {
return ResolvedType(type: type, hint: hint, nullable: nullable);
return ResolvedType(
type: type, hint: hint, nullable: nullable, isArray: isArray);
}
ResolvedType toArray(bool array) {
return ResolvedType(
type: type, hint: hint, nullable: nullable, isArray: array);
}
@override
@ -32,7 +42,8 @@ class ResolvedType {
other is ResolvedType &&
other.type == type &&
other.hint == hint &&
other.nullable == nullable;
other.nullable == nullable &&
other.isArray == isArray;
}
@override
@ -42,7 +53,7 @@ class ResolvedType {
@override
String toString() {
return 'ResolvedType($type, hint: $hint, nullable: $nullable)';
return 'ResolvedType($type, hint: $hint, nullable: $nullable, array: $isArray)';
}
}

View File

@ -10,7 +10,7 @@ const _comparisonOperators = [
TokenType.less,
TokenType.lessEqual,
TokenType.more,
TokenType.moreEqual
TokenType.moreEqual,
];
class TypeResolver {
@ -79,6 +79,7 @@ class TypeResolver {
} else if (expr is FunctionExpression) {
return resolveFunctionCall(expr);
} else if (expr is IsExpression ||
expr is InExpression ||
expr is StringComparisonExpression ||
expr is BetweenExpression ||
expr is ExistsExpression) {
@ -240,11 +241,15 @@ class TypeResolver {
final parent = e.parent;
if (parent is Expression) {
final result = _argumentType(parent, e);
if (result.needsContext) {
return inferType(parent);
} else {
return result;
// while more context is needed, look at the parent
final inferredType = result.needsContext ? inferType(parent) : result;
// If this appears in a tuple, e.g. test IN (?). The "(?)" will be an
// array. Of course, the individual entry is not, so reset that state.
if (parent is TupleExpression) {
return inferredType.mapResult((r) => r.toArray(false));
}
return inferredType;
} else if (parent is Limit) {
return const ResolveResult(ResolvedType(type: BasicType.int));
} else if (parent is SetComponent) {
@ -257,12 +262,20 @@ class TypeResolver {
ResolveResult _argumentType(Expression parent, Expression argument) {
if (parent is IsExpression ||
parent is InExpression ||
parent is BinaryExpression ||
parent is BetweenExpression ||
parent is CaseExpression) {
final relevant = parent.childNodes
.lastWhere((node) => node is Expression && node != argument);
return resolveExpression(relevant as Expression);
final resolved = resolveExpression(relevant as Expression);
// if we have "a x IN argument" expression, the argument will be an array
if (parent is InExpression && argument == parent.inside) {
return resolved.mapResult((r) => r.toArray(true));
}
return resolved;
} else if (parent is StringComparisonExpression) {
if (argument == parent.escape) {
return const ResolveResult(ResolvedType(type: BasicType.text));
@ -271,7 +284,9 @@ class TypeResolver {
.firstWhere((node) => node is Expression && node != argument);
return resolveExpression(otherNode as Expression);
}
} else if (parent is Parentheses || parent is UnaryExpression) {
} else if (parent is Parentheses ||
parent is TupleExpression ||
parent is UnaryExpression) {
return const ResolveResult.needsContext();
} else if (parent is FunctionExpression) {
return resolveFunctionCall(parent);
@ -326,12 +341,9 @@ class ResolveResult {
bool get nullable => type?.nullable ?? true;
/// Copies the result with the [nullable] information, if there is one. If
/// there isn't, the failure state will be copied into the new
/// [ResolveResult].
ResolveResult withNullable(bool nullable) {
ResolveResult mapResult(ResolvedType Function(ResolvedType) map) {
if (type != null) {
return ResolveResult(type.withNullable(nullable));
return ResolveResult(map(type));
} else if (needsContext != null) {
return const ResolveResult.needsContext();
} else {
@ -339,6 +351,13 @@ class ResolveResult {
}
}
/// Copies the result with the [nullable] information, if there is one. If
/// there isn't, the failure state will be copied into the new
/// [ResolveResult].
ResolveResult withNullable(bool nullable) {
return mapResult((r) => r.withNullable(nullable));
}
@override
bool operator ==(other) {
return identical(this, other) ||

View File

@ -150,6 +150,7 @@ abstract class AstVisitor<T> {
T visitCaseExpression(CaseExpression e);
T visitWhen(WhenComponent e);
T visitTuple(TupleExpression e);
T visitInExpression(InExpression e);
T visitNumberedVariable(NumberedVariable e);
T visitNamedVariable(ColonNamedVariable e);
@ -184,6 +185,9 @@ class RecursiveVisitor<T> extends AstVisitor<T> {
@override
T visitTuple(TupleExpression e) => visitChildren(e);
@override
T visitInExpression(InExpression e) => visitChildren(e);
@override
T visitSubQuery(SubQuery e) => visitChildren(e);

View File

@ -61,6 +61,7 @@ class StringComparisonExpression extends Expression {
bool contentEquals(StringComparisonExpression other) => other.not == not;
}
/// `(NOT)? $left IS $right`
class IsExpression extends Expression {
final bool negated;
final Expression left;
@ -82,6 +83,7 @@ class IsExpression extends Expression {
}
}
/// `$check BETWEEN $lower AND $upper`
class BetweenExpression extends Expression {
final bool not;
final Expression check;
@ -100,6 +102,27 @@ class BetweenExpression extends Expression {
bool contentEquals(BetweenExpression other) => other.not == not;
}
/// `$left$ IN $inside`
class InExpression extends Expression {
final bool not;
final Expression left;
final Expression inside;
InExpression({this.not = false, @required this.left, @required this.inside});
@override
T accept<T>(AstVisitor<T> visitor) => visitor.visitInExpression(this);
@override
Iterable<AstNode> get childNodes => [left, inside];
@override
bool contentEquals(InExpression other) => other.not == not;
}
// todo we might be able to remove a hack in the parser at _in() if we make
// parentheses a subclass of tuples
class Parentheses extends Expression {
final Token openingLeft;
final Expression expression;
@ -117,4 +140,8 @@ class Parentheses extends Expression {
@override
bool contentEquals(Parentheses other) => true;
TupleExpression get asTuple {
return TupleExpression(expressions: [expression]);
}
}

View File

@ -507,7 +507,27 @@ class Parser {
}
Expression _or() => _parseSimpleBinary(const [TokenType.or], _and);
Expression _and() => _parseSimpleBinary(const [TokenType.and], _equals);
Expression _and() => _parseSimpleBinary(const [TokenType.and], _in);
Expression _in() {
final left = _equals();
if (_checkWithNot(TokenType.$in)) {
final not = _matchOne(TokenType.not);
_matchOne(TokenType.$in);
var inside = _equals();
if (inside is Parentheses) {
// if we have something like x IN (3), then (3) is a tuple and not a
// parenthesis. We can only know this from the context unfortunately
inside = (inside as Parentheses).asTuple;
}
return InExpression(left: left, inside: inside, not: not);
}
return left;
}
/// Parses expressions with the "equals" precedence. This contains
/// comparisons, "IS (NOT) IN" expressions, between expressions and "like"
@ -521,7 +541,6 @@ class Parser {
TokenType.exclamationEqual,
TokenType.lessMore,
TokenType.$is,
TokenType.$in,
];
final stringOps = const [
TokenType.like,

View File

@ -20,6 +20,10 @@ Map<String, ResolveResult> _types = {
const ResolveResult(ResolvedType(type: BasicType.text)),
"SELECT * FROM demo WHERE content LIKE '%e' ESCAPE ?":
const ResolveResult(ResolvedType(type: BasicType.text)),
'SELECT * FROM demo WHERE content IN ?':
const ResolveResult(ResolvedType(type: BasicType.text, isArray: true)),
'SELECT * FROM demo WHERE content IN (?)':
const ResolveResult(ResolvedType(type: BasicType.text, isArray: false)),
};
void main() {