mirror of https://github.com/AMT-Cheif/drift.git
Improve type inference for parentheses
This commit is contained in:
parent
c8a155a44b
commit
8c62365f26
|
@ -1,3 +1,4 @@
|
|||
@Tags(['analyzer'])
|
||||
import 'dart:convert';
|
||||
|
||||
import 'package:moor_generator/src/analyzer/runner/results.dart';
|
||||
|
|
|
@ -9,6 +9,7 @@ class TypeGraph {
|
|||
final List<TypeRelation> _relations = [];
|
||||
|
||||
final Map<Typeable, List<TypeRelation>> _edges = {};
|
||||
final Set<MultiSourceRelation> _multiSources = {};
|
||||
final List<DefaultType> _defaultTypes = [];
|
||||
|
||||
TypeGraph();
|
||||
|
@ -40,13 +41,18 @@ class TypeGraph {
|
|||
}
|
||||
|
||||
void performResolve() {
|
||||
_indexRelationships();
|
||||
_indexRelations();
|
||||
|
||||
final queue = List.of(_knownTypes.keys);
|
||||
while (queue.isNotEmpty) {
|
||||
_propagateTypeInfo(queue, queue.removeLast());
|
||||
}
|
||||
|
||||
// propagate many-to-one sources where we don't know each source type
|
||||
for (final remaining in _multiSources) {
|
||||
_propagateManyToOne(remaining, queue);
|
||||
}
|
||||
|
||||
// apply default types
|
||||
for (final applyDefault in _defaultTypes) {
|
||||
if (!knowsType(applyDefault.target)) {
|
||||
|
@ -73,14 +79,14 @@ class TypeGraph {
|
|||
} else if (edge is MultiSourceRelation) {
|
||||
// handle many-to-one changes, if all targets have been resolved
|
||||
if (edge.from.every(knowsType)) {
|
||||
_propagateManyToOne(edge, resolved, t);
|
||||
_multiSources.remove(edge);
|
||||
_propagateManyToOne(edge, resolved);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void _propagateManyToOne(
|
||||
MultiSourceRelation edge, List<Typeable> resolved, Typeable t) {
|
||||
void _propagateManyToOne(MultiSourceRelation edge, List<Typeable> resolved) {
|
||||
if (!knowsType(edge.target)) {
|
||||
final fromTypes = edge.from.map((t) => this[t]).where((e) => e != null);
|
||||
final encapsulated = _encapsulate(fromTypes);
|
||||
|
@ -107,26 +113,27 @@ class TypeGraph {
|
|||
});
|
||||
}
|
||||
|
||||
void _indexRelationships() {
|
||||
void _indexRelations() {
|
||||
_edges.clear();
|
||||
|
||||
void put(Typeable t, TypeRelation r) {
|
||||
_edges.putIfAbsent(t, () => []).add(r);
|
||||
}
|
||||
|
||||
void putAll(Iterable<Typeable> t, TypeRelation r) {
|
||||
for (final element in t) {
|
||||
void putAll(MultiSourceRelation r) {
|
||||
_multiSources.add(r);
|
||||
for (final element in r.from) {
|
||||
put(element, r);
|
||||
}
|
||||
}
|
||||
|
||||
for (final relation in _relations) {
|
||||
if (relation is NullableIfSomeOtherIs) {
|
||||
putAll(relation.from, relation);
|
||||
putAll(relation);
|
||||
} else if (relation is CopyTypeFrom) {
|
||||
put(relation.other, relation);
|
||||
} else if (relation is CopyEncapsulating) {
|
||||
putAll(relation.from, relation);
|
||||
putAll(relation);
|
||||
} else if (relation is HaveSameType) {
|
||||
put(relation.first, relation);
|
||||
put(relation.second, relation);
|
||||
|
|
|
@ -268,7 +268,10 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
|
|||
break;
|
||||
case TokenType.plus:
|
||||
case TokenType.minus:
|
||||
case TokenType.star:
|
||||
case TokenType.slash:
|
||||
session._addRelation(CopyEncapsulating(e, [e.left, e.right]));
|
||||
visitChildren(e, const RoughTypeExpectation.numeric());
|
||||
break;
|
||||
// all of those only really make sense for integers
|
||||
case TokenType.shiftLeft:
|
||||
|
@ -378,6 +381,12 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
|
|||
visitNullable(e.escape, _expectString);
|
||||
}
|
||||
|
||||
@override
|
||||
void visitParentheses(Parentheses e, TypeExpectation arg) {
|
||||
session._addRelation(CopyTypeFrom(e, e.expression));
|
||||
visit(e.expression, arg);
|
||||
}
|
||||
|
||||
@override
|
||||
void visitReference(Reference e, TypeExpectation arg) {
|
||||
final resolved = e.resolvedColumn;
|
||||
|
|
|
@ -189,7 +189,7 @@ class Parentheses extends Expression {
|
|||
|
||||
@override
|
||||
R accept<A, R>(AstVisitor<A, R> visitor, A arg) {
|
||||
return expression.accept(visitor, arg);
|
||||
return visitor.visitParentheses(this, arg);
|
||||
}
|
||||
|
||||
@override
|
||||
|
|
|
@ -51,6 +51,7 @@ abstract class AstVisitor<A, R> {
|
|||
R visitCaseExpression(CaseExpression e, A arg);
|
||||
R visitWhen(WhenComponent e, A arg);
|
||||
R visitTuple(Tuple e, A arg);
|
||||
R visitParentheses(Parentheses e, A arg);
|
||||
R visitInExpression(InExpression e, A arg);
|
||||
|
||||
R visitAggregateExpression(AggregateExpression e, A arg);
|
||||
|
@ -378,6 +379,11 @@ class RecursiveVisitor<A, R> implements AstVisitor<A, R> {
|
|||
return visitExpression(e, arg);
|
||||
}
|
||||
|
||||
@override
|
||||
R visitParentheses(Parentheses e, A arg) {
|
||||
return e.expression.accept(this, arg);
|
||||
}
|
||||
|
||||
@override
|
||||
R visitInExpression(InExpression e, A arg) {
|
||||
return visitExpression(e, arg);
|
||||
|
|
|
@ -32,6 +32,8 @@ Map<String, ResolveResult> _types = {
|
|||
'SELECT ?;': const ResolveResult.unknown(),
|
||||
'SELECT CAST(3 AS TEXT) = ?':
|
||||
const ResolveResult(ResolvedType(type: BasicType.text)),
|
||||
'SELECT (3 * 4) = ?': const ResolveResult(ResolvedType(type: BasicType.int)),
|
||||
'SELECT (3 / 4) = ?': const ResolveResult(ResolvedType(type: BasicType.int)),
|
||||
};
|
||||
|
||||
void main() {
|
||||
|
|
|
@ -27,6 +27,8 @@ const Map<String, ResolvedType> _types = {
|
|||
ResolvedType(type: BasicType.int),
|
||||
'SELECT ?;': null,
|
||||
'SELECT CAST(3 AS TEXT) = ?': ResolvedType(type: BasicType.text),
|
||||
'SELECT (3 * 4) = ?': ResolvedType(type: BasicType.int),
|
||||
'SELECT (3 / 4) = ?': ResolvedType(type: BasicType.int),
|
||||
};
|
||||
|
||||
SqlEngine _spawnEngine() {
|
||||
|
@ -50,4 +52,24 @@ void main() {
|
|||
});
|
||||
});
|
||||
});
|
||||
|
||||
test('resolves all expressions in CTE', () {
|
||||
final engine = _spawnEngine();
|
||||
final content = engine.analyze('''
|
||||
WITH RECURSIVE
|
||||
cnt(x) AS (
|
||||
SELECT 1
|
||||
UNION ALL
|
||||
SELECT x+1 FROM cnt
|
||||
LIMIT 1000000
|
||||
)
|
||||
SELECT x FROM cnt;
|
||||
''');
|
||||
|
||||
final expressions = content.root.allDescendants.whereType<Expression>();
|
||||
expect(
|
||||
expressions.map((e) => content.typeOf(e).type),
|
||||
everyElement(isNotNull),
|
||||
);
|
||||
});
|
||||
}
|
||||
|
|
|
@ -67,6 +67,11 @@ void main() {
|
|||
const ResolvedType(type: BasicType.text));
|
||||
});
|
||||
|
||||
test('resolves arithmetic expressions', () {
|
||||
expect(_resolveFirstVariable('SELECT ((3 + 4) * 5) = ?'),
|
||||
const ResolvedType(type: BasicType.int));
|
||||
});
|
||||
|
||||
group('cast expressions', () {
|
||||
test('resolve to type argument', () {
|
||||
expect(_resolveResultColumn('SELECT CAST(3+4 AS TEXT)'),
|
||||
|
|
Loading…
Reference in New Issue