mirror of https://github.com/AMT-Cheif/drift.git
Support CASE expressions in the new type resolver
This commit is contained in:
parent
6434f7a7d5
commit
9c38ed1ea5
|
@ -52,6 +52,15 @@ class RoughTypeExpectation extends TypeExpectation {
|
|||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
ResolvedType defaultType() {
|
||||
switch (_type) {
|
||||
case _RoughType.numeric:
|
||||
return const ResolvedType(type: BasicType.real);
|
||||
}
|
||||
|
||||
throw AssertionError();
|
||||
}
|
||||
}
|
||||
|
||||
enum _RoughType {
|
||||
|
|
|
@ -9,6 +9,8 @@ class TypeGraph {
|
|||
final List<TypeRelationship> _relationships = [];
|
||||
|
||||
final Map<Typeable, List<TypeRelationship>> _edges = {};
|
||||
final List<DefaultType> _defaultTypes = [];
|
||||
final List<CopyEncapsulating> _manyToOne = [];
|
||||
|
||||
TypeGraph();
|
||||
|
||||
|
@ -45,11 +47,30 @@ class TypeGraph {
|
|||
while (queue.isNotEmpty) {
|
||||
_propagateTypeInfo(queue, queue.removeLast());
|
||||
}
|
||||
|
||||
// propagate many-to-one changes
|
||||
for (final edge in _manyToOne) {
|
||||
if (!knowsType(edge.target)) {
|
||||
final fromTypes = edge.from.map((t) => this[t]).where((e) => e != null);
|
||||
final encapsulated = _encapsulate(fromTypes);
|
||||
if (encapsulated != null) {
|
||||
this[edge.target] = encapsulated;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// apply default types
|
||||
for (final applyDefault in _defaultTypes) {
|
||||
if (!knowsType(applyDefault.target)) {
|
||||
this[applyDefault.target] = applyDefault.defaultType;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void _propagateTypeInfo(List<Typeable> resolved, Typeable t) {
|
||||
if (!_edges.containsKey(t)) return;
|
||||
|
||||
// propagate one-to-one and one-to-many changes
|
||||
for (final edge in _edges[t]) {
|
||||
if (edge is CopyTypeFrom) {
|
||||
_copyType(resolved, edge.other, edge.target);
|
||||
|
@ -68,6 +89,12 @@ class TypeGraph {
|
|||
}
|
||||
}
|
||||
|
||||
ResolvedType /*?*/ _encapsulate(Iterable<ResolvedType> targets) {
|
||||
return targets.fold<ResolvedType>(null, (previous, element) {
|
||||
return previous?.union(element) ?? element;
|
||||
});
|
||||
}
|
||||
|
||||
void _indexRelationships() {
|
||||
_edges.clear();
|
||||
|
||||
|
@ -87,12 +114,12 @@ class TypeGraph {
|
|||
} else if (relation is CopyTypeFrom) {
|
||||
put(relation.other, relation);
|
||||
} else if (relation is CopyEncapsulating) {
|
||||
putAll(relation.from, relation);
|
||||
_manyToOne.add(relation);
|
||||
} else if (relation is HaveSameType) {
|
||||
put(relation.first, relation);
|
||||
put(relation.second, relation);
|
||||
} else if (relation is DefaultType) {
|
||||
put(relation.target, relation);
|
||||
_defaultTypes.add(relation);
|
||||
} else if (relation is CopyAndCast) {
|
||||
put(relation.other, relation);
|
||||
} else {
|
||||
|
@ -117,3 +144,16 @@ class _ResolvedVariables {
|
|||
return _referenceForIndex[normalized.resolvedIndex] ??= normalized;
|
||||
}
|
||||
}
|
||||
|
||||
extension on ResolvedType {
|
||||
ResolvedType union(ResolvedType other) {
|
||||
if (other == this) return this;
|
||||
|
||||
if (other.type == type) {
|
||||
return withNullable(nullable || other.nullable);
|
||||
}
|
||||
|
||||
// fallback. todo: Support more cases
|
||||
return const ResolvedType(type: BasicType.text, nullable: true);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ part of 'types.dart';
|
|||
|
||||
const _expectInt =
|
||||
ExactTypeExpectation.laxly(ResolvedType(type: BasicType.int));
|
||||
const _expectNum = RoughTypeExpectation.numeric();
|
||||
const _expectString =
|
||||
ExactTypeExpectation.laxly(ResolvedType(type: BasicType.text));
|
||||
|
||||
|
@ -141,10 +142,13 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
|
|||
void visitVariable(Variable e, TypeExpectation arg) {
|
||||
final resolved = session.context.stmtOptions.specifiedTypeOf(e) ??
|
||||
_inferFromContext(arg);
|
||||
|
||||
if (resolved != null) {
|
||||
session.markTypeResolved(e, resolved);
|
||||
session.checkAndResolve(e, resolved, arg);
|
||||
} else if (arg is RoughTypeExpectation) {
|
||||
session.addRelationship(DefaultType(e, arg.defaultType()));
|
||||
}
|
||||
// todo support when arg is RoughTypeExpectation
|
||||
|
||||
visitChildren(e, arg);
|
||||
}
|
||||
|
||||
|
@ -178,6 +182,16 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
|
|||
}
|
||||
}
|
||||
|
||||
@override
|
||||
void visitBetweenExpression(BetweenExpression e, TypeExpectation arg) {
|
||||
visitChildren(e, _expectNum);
|
||||
|
||||
session
|
||||
..addRelationship(NullableIfSomeOtherIs(e, e.childNodes))
|
||||
..addRelationship(HaveSameType(e.lower, e.upper))
|
||||
..addRelationship(HaveSameType(e.check, e.lower));
|
||||
}
|
||||
|
||||
@override
|
||||
void visitBinaryExpression(BinaryExpression e, TypeExpectation arg) {
|
||||
switch (e.operator.type) {
|
||||
|
@ -241,6 +255,45 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
|
|||
visitChildren(e, const NoTypeExpectation());
|
||||
}
|
||||
|
||||
@override
|
||||
void visitIsNullExpression(IsNullExpression e, TypeExpectation arg) {
|
||||
session.checkAndResolve(e, const ResolvedType.bool(), arg);
|
||||
session.hintNullability(e, false);
|
||||
visitChildren(e, const NoTypeExpectation());
|
||||
}
|
||||
|
||||
@override
|
||||
void visitCaseExpression(CaseExpression e, TypeExpectation arg) {
|
||||
session.addRelationship(CopyEncapsulating(e, [
|
||||
for (final when in e.whens) when.then,
|
||||
if (e.elseExpr != null) e.elseExpr,
|
||||
]));
|
||||
|
||||
if (e.base != null) {
|
||||
session.addRelationship(
|
||||
CopyEncapsulating(e.base, [for (final when in e.whens) when.when]),
|
||||
);
|
||||
}
|
||||
|
||||
visitNullable(e.base, const NoTypeExpectation());
|
||||
visitExcept(e, e.base, arg);
|
||||
}
|
||||
|
||||
@override
|
||||
void visitWhen(WhenComponent e, TypeExpectation arg) {
|
||||
final parent = e.parent;
|
||||
if (parent is CaseExpression && parent.base != null) {
|
||||
// case expressions with base -> condition is compared to base
|
||||
session.addRelationship(CopyTypeFrom(e.when, parent.base));
|
||||
visit(e.when, const NoTypeExpectation());
|
||||
} else {
|
||||
// case expression without base -> the conditions are booleans
|
||||
visit(e.when, const ExactTypeExpectation(ResolvedType.bool()));
|
||||
}
|
||||
|
||||
visit(e.then, arg);
|
||||
}
|
||||
|
||||
@override
|
||||
void visitCastExpression(CastExpression e, TypeExpectation arg) {
|
||||
final type = session.context.schemaSupport.resolveColumnType(e.typeName);
|
||||
|
|
|
@ -143,7 +143,7 @@ class BetweenExpression extends Expression {
|
|||
}
|
||||
|
||||
@override
|
||||
Iterable<AstNode> get childNodes => [check, lower, upper];
|
||||
List<Expression> get childNodes => [check, lower, upper];
|
||||
|
||||
@override
|
||||
bool contentEquals(BetweenExpression other) => other.not == not;
|
||||
|
|
|
@ -137,4 +137,27 @@ void main() {
|
|||
"SELECT * FROM demo WHERE content LIKE 'foo' ESCAPE ?");
|
||||
expect(escapedType, const ResolvedType(type: BasicType.text));
|
||||
});
|
||||
|
||||
group('case expressions', () {
|
||||
test('infers base clause from when', () {
|
||||
final type = _resolveFirstVariable("SELECT CASE ? WHEN 1 THEN 'two' END");
|
||||
expect(type, const ResolvedType(type: BasicType.int));
|
||||
});
|
||||
|
||||
test('infers when condition from base', () {
|
||||
final type = _resolveFirstVariable("SELECT CASE 1 WHEN ? THEN 'two' END");
|
||||
expect(type, const ResolvedType(type: BasicType.int));
|
||||
});
|
||||
|
||||
test('infers when conditions as boolean when no base is set', () {
|
||||
final type = _resolveFirstVariable("SELECT CASE WHEN ? THEN 'two' END;");
|
||||
expect(type, const ResolvedType.bool());
|
||||
});
|
||||
|
||||
test('infers type of whole when expression', () {
|
||||
final type = _resolveResultColumn("SELECT CASE WHEN false THEN 'one' "
|
||||
"WHEN true THEN 'two' ELSE 'three' END;");
|
||||
expect(type, const ResolvedType(type: BasicType.text));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue