Support CASE expressions in the new type resolver

This commit is contained in:
Simon Binder 2020-01-10 21:31:22 +01:00
parent 6434f7a7d5
commit 9c38ed1ea5
No known key found for this signature in database
GPG Key ID: 7891917E4147B8C0
5 changed files with 130 additions and 5 deletions

View File

@ -52,6 +52,15 @@ class RoughTypeExpectation extends TypeExpectation {
}
return false;
}
ResolvedType defaultType() {
switch (_type) {
case _RoughType.numeric:
return const ResolvedType(type: BasicType.real);
}
throw AssertionError();
}
}
enum _RoughType {

View File

@ -9,6 +9,8 @@ class TypeGraph {
final List<TypeRelationship> _relationships = [];
final Map<Typeable, List<TypeRelationship>> _edges = {};
final List<DefaultType> _defaultTypes = [];
final List<CopyEncapsulating> _manyToOne = [];
TypeGraph();
@ -45,11 +47,30 @@ class TypeGraph {
while (queue.isNotEmpty) {
_propagateTypeInfo(queue, queue.removeLast());
}
// propagate many-to-one changes
for (final edge in _manyToOne) {
if (!knowsType(edge.target)) {
final fromTypes = edge.from.map((t) => this[t]).where((e) => e != null);
final encapsulated = _encapsulate(fromTypes);
if (encapsulated != null) {
this[edge.target] = encapsulated;
}
}
}
// apply default types
for (final applyDefault in _defaultTypes) {
if (!knowsType(applyDefault.target)) {
this[applyDefault.target] = applyDefault.defaultType;
}
}
}
void _propagateTypeInfo(List<Typeable> resolved, Typeable t) {
if (!_edges.containsKey(t)) return;
// propagate one-to-one and one-to-many changes
for (final edge in _edges[t]) {
if (edge is CopyTypeFrom) {
_copyType(resolved, edge.other, edge.target);
@ -68,6 +89,12 @@ class TypeGraph {
}
}
ResolvedType /*?*/ _encapsulate(Iterable<ResolvedType> targets) {
return targets.fold<ResolvedType>(null, (previous, element) {
return previous?.union(element) ?? element;
});
}
void _indexRelationships() {
_edges.clear();
@ -87,12 +114,12 @@ class TypeGraph {
} else if (relation is CopyTypeFrom) {
put(relation.other, relation);
} else if (relation is CopyEncapsulating) {
putAll(relation.from, relation);
_manyToOne.add(relation);
} else if (relation is HaveSameType) {
put(relation.first, relation);
put(relation.second, relation);
} else if (relation is DefaultType) {
put(relation.target, relation);
_defaultTypes.add(relation);
} else if (relation is CopyAndCast) {
put(relation.other, relation);
} else {
@ -117,3 +144,16 @@ class _ResolvedVariables {
return _referenceForIndex[normalized.resolvedIndex] ??= normalized;
}
}
extension on ResolvedType {
ResolvedType union(ResolvedType other) {
if (other == this) return this;
if (other.type == type) {
return withNullable(nullable || other.nullable);
}
// fallback. todo: Support more cases
return const ResolvedType(type: BasicType.text, nullable: true);
}
}

View File

@ -2,6 +2,7 @@ part of 'types.dart';
const _expectInt =
ExactTypeExpectation.laxly(ResolvedType(type: BasicType.int));
const _expectNum = RoughTypeExpectation.numeric();
const _expectString =
ExactTypeExpectation.laxly(ResolvedType(type: BasicType.text));
@ -141,10 +142,13 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
void visitVariable(Variable e, TypeExpectation arg) {
final resolved = session.context.stmtOptions.specifiedTypeOf(e) ??
_inferFromContext(arg);
if (resolved != null) {
session.markTypeResolved(e, resolved);
session.checkAndResolve(e, resolved, arg);
} else if (arg is RoughTypeExpectation) {
session.addRelationship(DefaultType(e, arg.defaultType()));
}
// todo support when arg is RoughTypeExpectation
visitChildren(e, arg);
}
@ -178,6 +182,16 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
}
}
@override
void visitBetweenExpression(BetweenExpression e, TypeExpectation arg) {
visitChildren(e, _expectNum);
session
..addRelationship(NullableIfSomeOtherIs(e, e.childNodes))
..addRelationship(HaveSameType(e.lower, e.upper))
..addRelationship(HaveSameType(e.check, e.lower));
}
@override
void visitBinaryExpression(BinaryExpression e, TypeExpectation arg) {
switch (e.operator.type) {
@ -241,6 +255,45 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
visitChildren(e, const NoTypeExpectation());
}
@override
void visitIsNullExpression(IsNullExpression e, TypeExpectation arg) {
session.checkAndResolve(e, const ResolvedType.bool(), arg);
session.hintNullability(e, false);
visitChildren(e, const NoTypeExpectation());
}
@override
void visitCaseExpression(CaseExpression e, TypeExpectation arg) {
session.addRelationship(CopyEncapsulating(e, [
for (final when in e.whens) when.then,
if (e.elseExpr != null) e.elseExpr,
]));
if (e.base != null) {
session.addRelationship(
CopyEncapsulating(e.base, [for (final when in e.whens) when.when]),
);
}
visitNullable(e.base, const NoTypeExpectation());
visitExcept(e, e.base, arg);
}
@override
void visitWhen(WhenComponent e, TypeExpectation arg) {
final parent = e.parent;
if (parent is CaseExpression && parent.base != null) {
// case expressions with base -> condition is compared to base
session.addRelationship(CopyTypeFrom(e.when, parent.base));
visit(e.when, const NoTypeExpectation());
} else {
// case expression without base -> the conditions are booleans
visit(e.when, const ExactTypeExpectation(ResolvedType.bool()));
}
visit(e.then, arg);
}
@override
void visitCastExpression(CastExpression e, TypeExpectation arg) {
final type = session.context.schemaSupport.resolveColumnType(e.typeName);

View File

@ -143,7 +143,7 @@ class BetweenExpression extends Expression {
}
@override
Iterable<AstNode> get childNodes => [check, lower, upper];
List<Expression> get childNodes => [check, lower, upper];
@override
bool contentEquals(BetweenExpression other) => other.not == not;

View File

@ -137,4 +137,27 @@ void main() {
"SELECT * FROM demo WHERE content LIKE 'foo' ESCAPE ?");
expect(escapedType, const ResolvedType(type: BasicType.text));
});
group('case expressions', () {
test('infers base clause from when', () {
final type = _resolveFirstVariable("SELECT CASE ? WHEN 1 THEN 'two' END");
expect(type, const ResolvedType(type: BasicType.int));
});
test('infers when condition from base', () {
final type = _resolveFirstVariable("SELECT CASE 1 WHEN ? THEN 'two' END");
expect(type, const ResolvedType(type: BasicType.int));
});
test('infers when conditions as boolean when no base is set', () {
final type = _resolveFirstVariable("SELECT CASE WHEN ? THEN 'two' END;");
expect(type, const ResolvedType.bool());
});
test('infers type of whole when expression', () {
final type = _resolveResultColumn("SELECT CASE WHEN false THEN 'one' "
"WHEN true THEN 'two' ELSE 'three' END;");
expect(type, const ResolvedType(type: BasicType.text));
});
});
}