mirror of https://github.com/AMT-Cheif/drift.git
Support CASE expressions in the new type resolver
This commit is contained in:
parent
6434f7a7d5
commit
9c38ed1ea5
|
@ -52,6 +52,15 @@ class RoughTypeExpectation extends TypeExpectation {
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ResolvedType defaultType() {
|
||||||
|
switch (_type) {
|
||||||
|
case _RoughType.numeric:
|
||||||
|
return const ResolvedType(type: BasicType.real);
|
||||||
|
}
|
||||||
|
|
||||||
|
throw AssertionError();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
enum _RoughType {
|
enum _RoughType {
|
||||||
|
|
|
@ -9,6 +9,8 @@ class TypeGraph {
|
||||||
final List<TypeRelationship> _relationships = [];
|
final List<TypeRelationship> _relationships = [];
|
||||||
|
|
||||||
final Map<Typeable, List<TypeRelationship>> _edges = {};
|
final Map<Typeable, List<TypeRelationship>> _edges = {};
|
||||||
|
final List<DefaultType> _defaultTypes = [];
|
||||||
|
final List<CopyEncapsulating> _manyToOne = [];
|
||||||
|
|
||||||
TypeGraph();
|
TypeGraph();
|
||||||
|
|
||||||
|
@ -45,11 +47,30 @@ class TypeGraph {
|
||||||
while (queue.isNotEmpty) {
|
while (queue.isNotEmpty) {
|
||||||
_propagateTypeInfo(queue, queue.removeLast());
|
_propagateTypeInfo(queue, queue.removeLast());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// propagate many-to-one changes
|
||||||
|
for (final edge in _manyToOne) {
|
||||||
|
if (!knowsType(edge.target)) {
|
||||||
|
final fromTypes = edge.from.map((t) => this[t]).where((e) => e != null);
|
||||||
|
final encapsulated = _encapsulate(fromTypes);
|
||||||
|
if (encapsulated != null) {
|
||||||
|
this[edge.target] = encapsulated;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// apply default types
|
||||||
|
for (final applyDefault in _defaultTypes) {
|
||||||
|
if (!knowsType(applyDefault.target)) {
|
||||||
|
this[applyDefault.target] = applyDefault.defaultType;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void _propagateTypeInfo(List<Typeable> resolved, Typeable t) {
|
void _propagateTypeInfo(List<Typeable> resolved, Typeable t) {
|
||||||
if (!_edges.containsKey(t)) return;
|
if (!_edges.containsKey(t)) return;
|
||||||
|
|
||||||
|
// propagate one-to-one and one-to-many changes
|
||||||
for (final edge in _edges[t]) {
|
for (final edge in _edges[t]) {
|
||||||
if (edge is CopyTypeFrom) {
|
if (edge is CopyTypeFrom) {
|
||||||
_copyType(resolved, edge.other, edge.target);
|
_copyType(resolved, edge.other, edge.target);
|
||||||
|
@ -68,6 +89,12 @@ class TypeGraph {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ResolvedType /*?*/ _encapsulate(Iterable<ResolvedType> targets) {
|
||||||
|
return targets.fold<ResolvedType>(null, (previous, element) {
|
||||||
|
return previous?.union(element) ?? element;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
void _indexRelationships() {
|
void _indexRelationships() {
|
||||||
_edges.clear();
|
_edges.clear();
|
||||||
|
|
||||||
|
@ -87,12 +114,12 @@ class TypeGraph {
|
||||||
} else if (relation is CopyTypeFrom) {
|
} else if (relation is CopyTypeFrom) {
|
||||||
put(relation.other, relation);
|
put(relation.other, relation);
|
||||||
} else if (relation is CopyEncapsulating) {
|
} else if (relation is CopyEncapsulating) {
|
||||||
putAll(relation.from, relation);
|
_manyToOne.add(relation);
|
||||||
} else if (relation is HaveSameType) {
|
} else if (relation is HaveSameType) {
|
||||||
put(relation.first, relation);
|
put(relation.first, relation);
|
||||||
put(relation.second, relation);
|
put(relation.second, relation);
|
||||||
} else if (relation is DefaultType) {
|
} else if (relation is DefaultType) {
|
||||||
put(relation.target, relation);
|
_defaultTypes.add(relation);
|
||||||
} else if (relation is CopyAndCast) {
|
} else if (relation is CopyAndCast) {
|
||||||
put(relation.other, relation);
|
put(relation.other, relation);
|
||||||
} else {
|
} else {
|
||||||
|
@ -117,3 +144,16 @@ class _ResolvedVariables {
|
||||||
return _referenceForIndex[normalized.resolvedIndex] ??= normalized;
|
return _referenceForIndex[normalized.resolvedIndex] ??= normalized;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
extension on ResolvedType {
|
||||||
|
ResolvedType union(ResolvedType other) {
|
||||||
|
if (other == this) return this;
|
||||||
|
|
||||||
|
if (other.type == type) {
|
||||||
|
return withNullable(nullable || other.nullable);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fallback. todo: Support more cases
|
||||||
|
return const ResolvedType(type: BasicType.text, nullable: true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ part of 'types.dart';
|
||||||
|
|
||||||
const _expectInt =
|
const _expectInt =
|
||||||
ExactTypeExpectation.laxly(ResolvedType(type: BasicType.int));
|
ExactTypeExpectation.laxly(ResolvedType(type: BasicType.int));
|
||||||
|
const _expectNum = RoughTypeExpectation.numeric();
|
||||||
const _expectString =
|
const _expectString =
|
||||||
ExactTypeExpectation.laxly(ResolvedType(type: BasicType.text));
|
ExactTypeExpectation.laxly(ResolvedType(type: BasicType.text));
|
||||||
|
|
||||||
|
@ -141,10 +142,13 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
|
||||||
void visitVariable(Variable e, TypeExpectation arg) {
|
void visitVariable(Variable e, TypeExpectation arg) {
|
||||||
final resolved = session.context.stmtOptions.specifiedTypeOf(e) ??
|
final resolved = session.context.stmtOptions.specifiedTypeOf(e) ??
|
||||||
_inferFromContext(arg);
|
_inferFromContext(arg);
|
||||||
|
|
||||||
if (resolved != null) {
|
if (resolved != null) {
|
||||||
session.markTypeResolved(e, resolved);
|
session.checkAndResolve(e, resolved, arg);
|
||||||
|
} else if (arg is RoughTypeExpectation) {
|
||||||
|
session.addRelationship(DefaultType(e, arg.defaultType()));
|
||||||
}
|
}
|
||||||
// todo support when arg is RoughTypeExpectation
|
|
||||||
visitChildren(e, arg);
|
visitChildren(e, arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -178,6 +182,16 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@override
|
||||||
|
void visitBetweenExpression(BetweenExpression e, TypeExpectation arg) {
|
||||||
|
visitChildren(e, _expectNum);
|
||||||
|
|
||||||
|
session
|
||||||
|
..addRelationship(NullableIfSomeOtherIs(e, e.childNodes))
|
||||||
|
..addRelationship(HaveSameType(e.lower, e.upper))
|
||||||
|
..addRelationship(HaveSameType(e.check, e.lower));
|
||||||
|
}
|
||||||
|
|
||||||
@override
|
@override
|
||||||
void visitBinaryExpression(BinaryExpression e, TypeExpectation arg) {
|
void visitBinaryExpression(BinaryExpression e, TypeExpectation arg) {
|
||||||
switch (e.operator.type) {
|
switch (e.operator.type) {
|
||||||
|
@ -241,6 +255,45 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
|
||||||
visitChildren(e, const NoTypeExpectation());
|
visitChildren(e, const NoTypeExpectation());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@override
|
||||||
|
void visitIsNullExpression(IsNullExpression e, TypeExpectation arg) {
|
||||||
|
session.checkAndResolve(e, const ResolvedType.bool(), arg);
|
||||||
|
session.hintNullability(e, false);
|
||||||
|
visitChildren(e, const NoTypeExpectation());
|
||||||
|
}
|
||||||
|
|
||||||
|
@override
|
||||||
|
void visitCaseExpression(CaseExpression e, TypeExpectation arg) {
|
||||||
|
session.addRelationship(CopyEncapsulating(e, [
|
||||||
|
for (final when in e.whens) when.then,
|
||||||
|
if (e.elseExpr != null) e.elseExpr,
|
||||||
|
]));
|
||||||
|
|
||||||
|
if (e.base != null) {
|
||||||
|
session.addRelationship(
|
||||||
|
CopyEncapsulating(e.base, [for (final when in e.whens) when.when]),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
visitNullable(e.base, const NoTypeExpectation());
|
||||||
|
visitExcept(e, e.base, arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
@override
|
||||||
|
void visitWhen(WhenComponent e, TypeExpectation arg) {
|
||||||
|
final parent = e.parent;
|
||||||
|
if (parent is CaseExpression && parent.base != null) {
|
||||||
|
// case expressions with base -> condition is compared to base
|
||||||
|
session.addRelationship(CopyTypeFrom(e.when, parent.base));
|
||||||
|
visit(e.when, const NoTypeExpectation());
|
||||||
|
} else {
|
||||||
|
// case expression without base -> the conditions are booleans
|
||||||
|
visit(e.when, const ExactTypeExpectation(ResolvedType.bool()));
|
||||||
|
}
|
||||||
|
|
||||||
|
visit(e.then, arg);
|
||||||
|
}
|
||||||
|
|
||||||
@override
|
@override
|
||||||
void visitCastExpression(CastExpression e, TypeExpectation arg) {
|
void visitCastExpression(CastExpression e, TypeExpectation arg) {
|
||||||
final type = session.context.schemaSupport.resolveColumnType(e.typeName);
|
final type = session.context.schemaSupport.resolveColumnType(e.typeName);
|
||||||
|
|
|
@ -143,7 +143,7 @@ class BetweenExpression extends Expression {
|
||||||
}
|
}
|
||||||
|
|
||||||
@override
|
@override
|
||||||
Iterable<AstNode> get childNodes => [check, lower, upper];
|
List<Expression> get childNodes => [check, lower, upper];
|
||||||
|
|
||||||
@override
|
@override
|
||||||
bool contentEquals(BetweenExpression other) => other.not == not;
|
bool contentEquals(BetweenExpression other) => other.not == not;
|
||||||
|
|
|
@ -137,4 +137,27 @@ void main() {
|
||||||
"SELECT * FROM demo WHERE content LIKE 'foo' ESCAPE ?");
|
"SELECT * FROM demo WHERE content LIKE 'foo' ESCAPE ?");
|
||||||
expect(escapedType, const ResolvedType(type: BasicType.text));
|
expect(escapedType, const ResolvedType(type: BasicType.text));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
group('case expressions', () {
|
||||||
|
test('infers base clause from when', () {
|
||||||
|
final type = _resolveFirstVariable("SELECT CASE ? WHEN 1 THEN 'two' END");
|
||||||
|
expect(type, const ResolvedType(type: BasicType.int));
|
||||||
|
});
|
||||||
|
|
||||||
|
test('infers when condition from base', () {
|
||||||
|
final type = _resolveFirstVariable("SELECT CASE 1 WHEN ? THEN 'two' END");
|
||||||
|
expect(type, const ResolvedType(type: BasicType.int));
|
||||||
|
});
|
||||||
|
|
||||||
|
test('infers when conditions as boolean when no base is set', () {
|
||||||
|
final type = _resolveFirstVariable("SELECT CASE WHEN ? THEN 'two' END;");
|
||||||
|
expect(type, const ResolvedType.bool());
|
||||||
|
});
|
||||||
|
|
||||||
|
test('infers type of whole when expression', () {
|
||||||
|
final type = _resolveResultColumn("SELECT CASE WHEN false THEN 'one' "
|
||||||
|
"WHEN true THEN 'two' ELSE 'three' END;");
|
||||||
|
expect(type, const ResolvedType(type: BasicType.text));
|
||||||
|
});
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue