mirror of https://github.com/AMT-Cheif/drift.git
Improve type inference for parentheses
This commit is contained in:
parent
c8a155a44b
commit
8c62365f26
|
@ -1,3 +1,4 @@
|
||||||
|
@Tags(['analyzer'])
|
||||||
import 'dart:convert';
|
import 'dart:convert';
|
||||||
|
|
||||||
import 'package:moor_generator/src/analyzer/runner/results.dart';
|
import 'package:moor_generator/src/analyzer/runner/results.dart';
|
||||||
|
|
|
@ -9,6 +9,7 @@ class TypeGraph {
|
||||||
final List<TypeRelation> _relations = [];
|
final List<TypeRelation> _relations = [];
|
||||||
|
|
||||||
final Map<Typeable, List<TypeRelation>> _edges = {};
|
final Map<Typeable, List<TypeRelation>> _edges = {};
|
||||||
|
final Set<MultiSourceRelation> _multiSources = {};
|
||||||
final List<DefaultType> _defaultTypes = [];
|
final List<DefaultType> _defaultTypes = [];
|
||||||
|
|
||||||
TypeGraph();
|
TypeGraph();
|
||||||
|
@ -40,13 +41,18 @@ class TypeGraph {
|
||||||
}
|
}
|
||||||
|
|
||||||
void performResolve() {
|
void performResolve() {
|
||||||
_indexRelationships();
|
_indexRelations();
|
||||||
|
|
||||||
final queue = List.of(_knownTypes.keys);
|
final queue = List.of(_knownTypes.keys);
|
||||||
while (queue.isNotEmpty) {
|
while (queue.isNotEmpty) {
|
||||||
_propagateTypeInfo(queue, queue.removeLast());
|
_propagateTypeInfo(queue, queue.removeLast());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// propagate many-to-one sources where we don't know each source type
|
||||||
|
for (final remaining in _multiSources) {
|
||||||
|
_propagateManyToOne(remaining, queue);
|
||||||
|
}
|
||||||
|
|
||||||
// apply default types
|
// apply default types
|
||||||
for (final applyDefault in _defaultTypes) {
|
for (final applyDefault in _defaultTypes) {
|
||||||
if (!knowsType(applyDefault.target)) {
|
if (!knowsType(applyDefault.target)) {
|
||||||
|
@ -73,14 +79,14 @@ class TypeGraph {
|
||||||
} else if (edge is MultiSourceRelation) {
|
} else if (edge is MultiSourceRelation) {
|
||||||
// handle many-to-one changes, if all targets have been resolved
|
// handle many-to-one changes, if all targets have been resolved
|
||||||
if (edge.from.every(knowsType)) {
|
if (edge.from.every(knowsType)) {
|
||||||
_propagateManyToOne(edge, resolved, t);
|
_multiSources.remove(edge);
|
||||||
|
_propagateManyToOne(edge, resolved);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void _propagateManyToOne(
|
void _propagateManyToOne(MultiSourceRelation edge, List<Typeable> resolved) {
|
||||||
MultiSourceRelation edge, List<Typeable> resolved, Typeable t) {
|
|
||||||
if (!knowsType(edge.target)) {
|
if (!knowsType(edge.target)) {
|
||||||
final fromTypes = edge.from.map((t) => this[t]).where((e) => e != null);
|
final fromTypes = edge.from.map((t) => this[t]).where((e) => e != null);
|
||||||
final encapsulated = _encapsulate(fromTypes);
|
final encapsulated = _encapsulate(fromTypes);
|
||||||
|
@ -107,26 +113,27 @@ class TypeGraph {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void _indexRelationships() {
|
void _indexRelations() {
|
||||||
_edges.clear();
|
_edges.clear();
|
||||||
|
|
||||||
void put(Typeable t, TypeRelation r) {
|
void put(Typeable t, TypeRelation r) {
|
||||||
_edges.putIfAbsent(t, () => []).add(r);
|
_edges.putIfAbsent(t, () => []).add(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
void putAll(Iterable<Typeable> t, TypeRelation r) {
|
void putAll(MultiSourceRelation r) {
|
||||||
for (final element in t) {
|
_multiSources.add(r);
|
||||||
|
for (final element in r.from) {
|
||||||
put(element, r);
|
put(element, r);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (final relation in _relations) {
|
for (final relation in _relations) {
|
||||||
if (relation is NullableIfSomeOtherIs) {
|
if (relation is NullableIfSomeOtherIs) {
|
||||||
putAll(relation.from, relation);
|
putAll(relation);
|
||||||
} else if (relation is CopyTypeFrom) {
|
} else if (relation is CopyTypeFrom) {
|
||||||
put(relation.other, relation);
|
put(relation.other, relation);
|
||||||
} else if (relation is CopyEncapsulating) {
|
} else if (relation is CopyEncapsulating) {
|
||||||
putAll(relation.from, relation);
|
putAll(relation);
|
||||||
} else if (relation is HaveSameType) {
|
} else if (relation is HaveSameType) {
|
||||||
put(relation.first, relation);
|
put(relation.first, relation);
|
||||||
put(relation.second, relation);
|
put(relation.second, relation);
|
||||||
|
|
|
@ -268,7 +268,10 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
|
||||||
break;
|
break;
|
||||||
case TokenType.plus:
|
case TokenType.plus:
|
||||||
case TokenType.minus:
|
case TokenType.minus:
|
||||||
|
case TokenType.star:
|
||||||
|
case TokenType.slash:
|
||||||
session._addRelation(CopyEncapsulating(e, [e.left, e.right]));
|
session._addRelation(CopyEncapsulating(e, [e.left, e.right]));
|
||||||
|
visitChildren(e, const RoughTypeExpectation.numeric());
|
||||||
break;
|
break;
|
||||||
// all of those only really make sense for integers
|
// all of those only really make sense for integers
|
||||||
case TokenType.shiftLeft:
|
case TokenType.shiftLeft:
|
||||||
|
@ -378,6 +381,12 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
|
||||||
visitNullable(e.escape, _expectString);
|
visitNullable(e.escape, _expectString);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@override
|
||||||
|
void visitParentheses(Parentheses e, TypeExpectation arg) {
|
||||||
|
session._addRelation(CopyTypeFrom(e, e.expression));
|
||||||
|
visit(e.expression, arg);
|
||||||
|
}
|
||||||
|
|
||||||
@override
|
@override
|
||||||
void visitReference(Reference e, TypeExpectation arg) {
|
void visitReference(Reference e, TypeExpectation arg) {
|
||||||
final resolved = e.resolvedColumn;
|
final resolved = e.resolvedColumn;
|
||||||
|
|
|
@ -189,7 +189,7 @@ class Parentheses extends Expression {
|
||||||
|
|
||||||
@override
|
@override
|
||||||
R accept<A, R>(AstVisitor<A, R> visitor, A arg) {
|
R accept<A, R>(AstVisitor<A, R> visitor, A arg) {
|
||||||
return expression.accept(visitor, arg);
|
return visitor.visitParentheses(this, arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
@override
|
@override
|
||||||
|
|
|
@ -51,6 +51,7 @@ abstract class AstVisitor<A, R> {
|
||||||
R visitCaseExpression(CaseExpression e, A arg);
|
R visitCaseExpression(CaseExpression e, A arg);
|
||||||
R visitWhen(WhenComponent e, A arg);
|
R visitWhen(WhenComponent e, A arg);
|
||||||
R visitTuple(Tuple e, A arg);
|
R visitTuple(Tuple e, A arg);
|
||||||
|
R visitParentheses(Parentheses e, A arg);
|
||||||
R visitInExpression(InExpression e, A arg);
|
R visitInExpression(InExpression e, A arg);
|
||||||
|
|
||||||
R visitAggregateExpression(AggregateExpression e, A arg);
|
R visitAggregateExpression(AggregateExpression e, A arg);
|
||||||
|
@ -378,6 +379,11 @@ class RecursiveVisitor<A, R> implements AstVisitor<A, R> {
|
||||||
return visitExpression(e, arg);
|
return visitExpression(e, arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@override
|
||||||
|
R visitParentheses(Parentheses e, A arg) {
|
||||||
|
return e.expression.accept(this, arg);
|
||||||
|
}
|
||||||
|
|
||||||
@override
|
@override
|
||||||
R visitInExpression(InExpression e, A arg) {
|
R visitInExpression(InExpression e, A arg) {
|
||||||
return visitExpression(e, arg);
|
return visitExpression(e, arg);
|
||||||
|
|
|
@ -32,6 +32,8 @@ Map<String, ResolveResult> _types = {
|
||||||
'SELECT ?;': const ResolveResult.unknown(),
|
'SELECT ?;': const ResolveResult.unknown(),
|
||||||
'SELECT CAST(3 AS TEXT) = ?':
|
'SELECT CAST(3 AS TEXT) = ?':
|
||||||
const ResolveResult(ResolvedType(type: BasicType.text)),
|
const ResolveResult(ResolvedType(type: BasicType.text)),
|
||||||
|
'SELECT (3 * 4) = ?': const ResolveResult(ResolvedType(type: BasicType.int)),
|
||||||
|
'SELECT (3 / 4) = ?': const ResolveResult(ResolvedType(type: BasicType.int)),
|
||||||
};
|
};
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
|
|
|
@ -27,6 +27,8 @@ const Map<String, ResolvedType> _types = {
|
||||||
ResolvedType(type: BasicType.int),
|
ResolvedType(type: BasicType.int),
|
||||||
'SELECT ?;': null,
|
'SELECT ?;': null,
|
||||||
'SELECT CAST(3 AS TEXT) = ?': ResolvedType(type: BasicType.text),
|
'SELECT CAST(3 AS TEXT) = ?': ResolvedType(type: BasicType.text),
|
||||||
|
'SELECT (3 * 4) = ?': ResolvedType(type: BasicType.int),
|
||||||
|
'SELECT (3 / 4) = ?': ResolvedType(type: BasicType.int),
|
||||||
};
|
};
|
||||||
|
|
||||||
SqlEngine _spawnEngine() {
|
SqlEngine _spawnEngine() {
|
||||||
|
@ -50,4 +52,24 @@ void main() {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('resolves all expressions in CTE', () {
|
||||||
|
final engine = _spawnEngine();
|
||||||
|
final content = engine.analyze('''
|
||||||
|
WITH RECURSIVE
|
||||||
|
cnt(x) AS (
|
||||||
|
SELECT 1
|
||||||
|
UNION ALL
|
||||||
|
SELECT x+1 FROM cnt
|
||||||
|
LIMIT 1000000
|
||||||
|
)
|
||||||
|
SELECT x FROM cnt;
|
||||||
|
''');
|
||||||
|
|
||||||
|
final expressions = content.root.allDescendants.whereType<Expression>();
|
||||||
|
expect(
|
||||||
|
expressions.map((e) => content.typeOf(e).type),
|
||||||
|
everyElement(isNotNull),
|
||||||
|
);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,6 +67,11 @@ void main() {
|
||||||
const ResolvedType(type: BasicType.text));
|
const ResolvedType(type: BasicType.text));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('resolves arithmetic expressions', () {
|
||||||
|
expect(_resolveFirstVariable('SELECT ((3 + 4) * 5) = ?'),
|
||||||
|
const ResolvedType(type: BasicType.int));
|
||||||
|
});
|
||||||
|
|
||||||
group('cast expressions', () {
|
group('cast expressions', () {
|
||||||
test('resolve to type argument', () {
|
test('resolve to type argument', () {
|
||||||
expect(_resolveResultColumn('SELECT CAST(3+4 AS TEXT)'),
|
expect(_resolveResultColumn('SELECT CAST(3+4 AS TEXT)'),
|
||||||
|
|
Loading…
Reference in New Issue