Improve type inference for parentheses

This commit is contained in:
Simon Binder 2020-01-19 13:07:47 +01:00
parent c8a155a44b
commit 8c62365f26
No known key found for this signature in database
GPG Key ID: 7891917E4147B8C0
8 changed files with 62 additions and 10 deletions

View File

@ -1,3 +1,4 @@
@Tags(['analyzer'])
import 'dart:convert'; import 'dart:convert';
import 'package:moor_generator/src/analyzer/runner/results.dart'; import 'package:moor_generator/src/analyzer/runner/results.dart';

View File

@ -9,6 +9,7 @@ class TypeGraph {
final List<TypeRelation> _relations = []; final List<TypeRelation> _relations = [];
final Map<Typeable, List<TypeRelation>> _edges = {}; final Map<Typeable, List<TypeRelation>> _edges = {};
final Set<MultiSourceRelation> _multiSources = {};
final List<DefaultType> _defaultTypes = []; final List<DefaultType> _defaultTypes = [];
TypeGraph(); TypeGraph();
@ -40,13 +41,18 @@ class TypeGraph {
} }
void performResolve() { void performResolve() {
_indexRelationships(); _indexRelations();
final queue = List.of(_knownTypes.keys); final queue = List.of(_knownTypes.keys);
while (queue.isNotEmpty) { while (queue.isNotEmpty) {
_propagateTypeInfo(queue, queue.removeLast()); _propagateTypeInfo(queue, queue.removeLast());
} }
// propagate many-to-one sources where we don't know each source type
for (final remaining in _multiSources) {
_propagateManyToOne(remaining, queue);
}
// apply default types // apply default types
for (final applyDefault in _defaultTypes) { for (final applyDefault in _defaultTypes) {
if (!knowsType(applyDefault.target)) { if (!knowsType(applyDefault.target)) {
@ -73,14 +79,14 @@ class TypeGraph {
} else if (edge is MultiSourceRelation) { } else if (edge is MultiSourceRelation) {
// handle many-to-one changes, if all targets have been resolved // handle many-to-one changes, if all targets have been resolved
if (edge.from.every(knowsType)) { if (edge.from.every(knowsType)) {
_propagateManyToOne(edge, resolved, t); _multiSources.remove(edge);
_propagateManyToOne(edge, resolved);
} }
} }
} }
} }
void _propagateManyToOne( void _propagateManyToOne(MultiSourceRelation edge, List<Typeable> resolved) {
MultiSourceRelation edge, List<Typeable> resolved, Typeable t) {
if (!knowsType(edge.target)) { if (!knowsType(edge.target)) {
final fromTypes = edge.from.map((t) => this[t]).where((e) => e != null); final fromTypes = edge.from.map((t) => this[t]).where((e) => e != null);
final encapsulated = _encapsulate(fromTypes); final encapsulated = _encapsulate(fromTypes);
@ -107,26 +113,27 @@ class TypeGraph {
}); });
} }
void _indexRelationships() { void _indexRelations() {
_edges.clear(); _edges.clear();
void put(Typeable t, TypeRelation r) { void put(Typeable t, TypeRelation r) {
_edges.putIfAbsent(t, () => []).add(r); _edges.putIfAbsent(t, () => []).add(r);
} }
void putAll(Iterable<Typeable> t, TypeRelation r) { void putAll(MultiSourceRelation r) {
for (final element in t) { _multiSources.add(r);
for (final element in r.from) {
put(element, r); put(element, r);
} }
} }
for (final relation in _relations) { for (final relation in _relations) {
if (relation is NullableIfSomeOtherIs) { if (relation is NullableIfSomeOtherIs) {
putAll(relation.from, relation); putAll(relation);
} else if (relation is CopyTypeFrom) { } else if (relation is CopyTypeFrom) {
put(relation.other, relation); put(relation.other, relation);
} else if (relation is CopyEncapsulating) { } else if (relation is CopyEncapsulating) {
putAll(relation.from, relation); putAll(relation);
} else if (relation is HaveSameType) { } else if (relation is HaveSameType) {
put(relation.first, relation); put(relation.first, relation);
put(relation.second, relation); put(relation.second, relation);

View File

@ -268,7 +268,10 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
break; break;
case TokenType.plus: case TokenType.plus:
case TokenType.minus: case TokenType.minus:
case TokenType.star:
case TokenType.slash:
session._addRelation(CopyEncapsulating(e, [e.left, e.right])); session._addRelation(CopyEncapsulating(e, [e.left, e.right]));
visitChildren(e, const RoughTypeExpectation.numeric());
break; break;
// all of those only really make sense for integers // all of those only really make sense for integers
case TokenType.shiftLeft: case TokenType.shiftLeft:
@ -378,6 +381,12 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
visitNullable(e.escape, _expectString); visitNullable(e.escape, _expectString);
} }
@override
void visitParentheses(Parentheses e, TypeExpectation arg) {
session._addRelation(CopyTypeFrom(e, e.expression));
visit(e.expression, arg);
}
@override @override
void visitReference(Reference e, TypeExpectation arg) { void visitReference(Reference e, TypeExpectation arg) {
final resolved = e.resolvedColumn; final resolved = e.resolvedColumn;

View File

@ -189,7 +189,7 @@ class Parentheses extends Expression {
@override @override
R accept<A, R>(AstVisitor<A, R> visitor, A arg) { R accept<A, R>(AstVisitor<A, R> visitor, A arg) {
return expression.accept(visitor, arg); return visitor.visitParentheses(this, arg);
} }
@override @override

View File

@ -51,6 +51,7 @@ abstract class AstVisitor<A, R> {
R visitCaseExpression(CaseExpression e, A arg); R visitCaseExpression(CaseExpression e, A arg);
R visitWhen(WhenComponent e, A arg); R visitWhen(WhenComponent e, A arg);
R visitTuple(Tuple e, A arg); R visitTuple(Tuple e, A arg);
R visitParentheses(Parentheses e, A arg);
R visitInExpression(InExpression e, A arg); R visitInExpression(InExpression e, A arg);
R visitAggregateExpression(AggregateExpression e, A arg); R visitAggregateExpression(AggregateExpression e, A arg);
@ -378,6 +379,11 @@ class RecursiveVisitor<A, R> implements AstVisitor<A, R> {
return visitExpression(e, arg); return visitExpression(e, arg);
} }
@override
R visitParentheses(Parentheses e, A arg) {
return e.expression.accept(this, arg);
}
@override @override
R visitInExpression(InExpression e, A arg) { R visitInExpression(InExpression e, A arg) {
return visitExpression(e, arg); return visitExpression(e, arg);

View File

@ -32,6 +32,8 @@ Map<String, ResolveResult> _types = {
'SELECT ?;': const ResolveResult.unknown(), 'SELECT ?;': const ResolveResult.unknown(),
'SELECT CAST(3 AS TEXT) = ?': 'SELECT CAST(3 AS TEXT) = ?':
const ResolveResult(ResolvedType(type: BasicType.text)), const ResolveResult(ResolvedType(type: BasicType.text)),
'SELECT (3 * 4) = ?': const ResolveResult(ResolvedType(type: BasicType.int)),
'SELECT (3 / 4) = ?': const ResolveResult(ResolvedType(type: BasicType.int)),
}; };
void main() { void main() {

View File

@ -27,6 +27,8 @@ const Map<String, ResolvedType> _types = {
ResolvedType(type: BasicType.int), ResolvedType(type: BasicType.int),
'SELECT ?;': null, 'SELECT ?;': null,
'SELECT CAST(3 AS TEXT) = ?': ResolvedType(type: BasicType.text), 'SELECT CAST(3 AS TEXT) = ?': ResolvedType(type: BasicType.text),
'SELECT (3 * 4) = ?': ResolvedType(type: BasicType.int),
'SELECT (3 / 4) = ?': ResolvedType(type: BasicType.int),
}; };
SqlEngine _spawnEngine() { SqlEngine _spawnEngine() {
@ -50,4 +52,24 @@ void main() {
}); });
}); });
}); });
test('resolves all expressions in CTE', () {
final engine = _spawnEngine();
final content = engine.analyze('''
WITH RECURSIVE
cnt(x) AS (
SELECT 1
UNION ALL
SELECT x+1 FROM cnt
LIMIT 1000000
)
SELECT x FROM cnt;
''');
final expressions = content.root.allDescendants.whereType<Expression>();
expect(
expressions.map((e) => content.typeOf(e).type),
everyElement(isNotNull),
);
});
} }

View File

@ -67,6 +67,11 @@ void main() {
const ResolvedType(type: BasicType.text)); const ResolvedType(type: BasicType.text));
}); });
test('resolves arithmetic expressions', () {
expect(_resolveFirstVariable('SELECT ((3 + 4) * 5) = ?'),
const ResolvedType(type: BasicType.int));
});
group('cast expressions', () { group('cast expressions', () {
test('resolve to type argument', () { test('resolve to type argument', () {
expect(_resolveResultColumn('SELECT CAST(3+4 AS TEXT)'), expect(_resolveResultColumn('SELECT CAST(3+4 AS TEXT)'),