Improve nullability detection in new type inference

This commit is contained in:
Simon Binder 2020-03-24 19:33:57 +01:00
parent cdd57f340d
commit 2b9a85714f
No known key found for this signature in database
GPG Key ID: 7891917E4147B8C0
6 changed files with 77 additions and 33 deletions

View File

@ -26,11 +26,17 @@ class ResolvedType {
const ResolvedType( const ResolvedType(
{this.type, this.hint, this.nullable = false, this.isArray = false}); {this.type, this.hint, this.nullable = false, this.isArray = false});
const ResolvedType.bool({bool nullable}) const ResolvedType.bool({bool nullable = false})
: this(type: BasicType.int, hint: const IsBoolean(), nullable: nullable); : this(type: BasicType.int, hint: const IsBoolean(), nullable: nullable);
ResolvedType get withoutNullabilityInfo {
return nullable == null
? this
: ResolvedType(type: type, hint: hint, isArray: isArray);
}
ResolvedType withNullable(bool nullable) { ResolvedType withNullable(bool nullable) {
return copyWith(nullable: nullable); return nullable == this.nullable ? this : copyWith(nullable: nullable);
} }
ResolvedType toArray(bool array) { ResolvedType toArray(bool array) {

View File

@ -56,9 +56,10 @@ class HaveSameType extends TypeRelation {
class DefaultType extends TypeRelation implements DirectedRelation { class DefaultType extends TypeRelation implements DirectedRelation {
@override @override
final Typeable target; final Typeable target;
final ResolvedType defaultType; final ResolvedType /*?*/ defaultType;
final bool /*?*/ isNullable;
DefaultType(this.target, this.defaultType); DefaultType(this.target, {this.defaultType, this.isNullable});
} }
enum CastMode { numeric, boolean } enum CastMode { numeric, boolean }

View File

@ -12,13 +12,17 @@ class TypeGraph {
final Set<MultiSourceRelation> _multiSources = {}; final Set<MultiSourceRelation> _multiSources = {};
final List<DefaultType> _defaultTypes = []; final List<DefaultType> _defaultTypes = [];
TypeGraph();
ResolvedType operator [](Typeable t) { ResolvedType operator [](Typeable t) {
final normalized = variables.normalize(t); final normalized = variables.normalize(t);
if (_knownTypes.containsKey(normalized)) { if (_knownTypes.containsKey(normalized)) {
return _knownTypes[normalized]; final type = _knownTypes[normalized];
final nullability = _knownNullability[normalized];
if (nullability != null) {
return type.withNullable(nullability);
}
return type;
} }
return null; return null;
@ -28,7 +32,7 @@ class TypeGraph {
final normalized = variables.normalize(t); final normalized = variables.normalize(t);
_knownTypes[normalized] = type; _knownTypes[normalized] = type;
if (type.nullable != null) { if (type.nullable != null && !_knownNullability.containsKey(normalized)) {
// nullability is known // nullability is known
_knownNullability[normalized] = type.nullable; _knownNullability[normalized] = type.nullable;
} }
@ -40,6 +44,10 @@ class TypeGraph {
_relations.add(relation); _relations.add(relation);
} }
void markNullability(Typeable t, bool isNullable) {
_knownNullability[variables.normalize(t)] = isNullable;
}
void performResolve() { void performResolve() {
_indexRelations(); _indexRelations();
@ -55,8 +63,16 @@ class TypeGraph {
// apply default types // apply default types
for (final applyDefault in _defaultTypes) { for (final applyDefault in _defaultTypes) {
if (!knowsType(applyDefault.target)) { final target = applyDefault.target;
this[applyDefault.target] = applyDefault.defaultType;
final type = applyDefault.defaultType;
if (type != null && !knowsType(target)) {
this[target] = applyDefault.defaultType;
}
final nullability = applyDefault.isNullable;
if (nullability != null && _knownNullability.containsKey(target)) {
markNullability(target, nullability);
} }
} }
} }
@ -87,13 +103,22 @@ class TypeGraph {
} }
void _propagateManyToOne(MultiSourceRelation edge, List<Typeable> resolved) { void _propagateManyToOne(MultiSourceRelation edge, List<Typeable> resolved) {
if (!knowsType(edge.target)) { if (edge is CopyEncapsulating) {
final fromTypes = edge.from.map((t) => this[t]).where((e) => e != null); if (!knowsType(edge.target)) {
final encapsulated = _encapsulate(fromTypes); final fromTypes = edge.from.map((t) => this[t]).where((e) => e != null);
if (encapsulated != null) { final encapsulated = _encapsulate(fromTypes);
this[edge.target] = encapsulated; if (encapsulated != null) {
resolved.add(edge.target); this[edge.target] = encapsulated;
resolved.add(edge.target);
}
} }
} else if (edge is NullableIfSomeOtherIs &&
!_knownNullability.containsKey(edge.target)) {
final nullable = edge.from
.map((e) => _knownNullability[e])
.any((nullable) => nullable == true);
_knownNullability[edge.target] = nullable;
} }
} }
@ -108,8 +133,27 @@ class TypeGraph {
} }
ResolvedType /*?*/ _encapsulate(Iterable<ResolvedType> targets) { ResolvedType /*?*/ _encapsulate(Iterable<ResolvedType> targets) {
return targets.fold<ResolvedType>(null, (previous, element) { return targets.map((e) => e.withoutNullabilityInfo).fold<ResolvedType>(null,
return previous?.union(element) ?? element; (previous, element) {
if (previous == null) return element;
final previousType = previous.type;
final elementType = element.type;
if (previousType == elementType) return previous;
if (previousType == BasicType.nullType) return element;
bool isIntOrNumeric(BasicType type) {
return type == BasicType.int || type == BasicType.real;
}
// encapsulate two different numeric types to real
if (isIntOrNumeric(previousType) && isIntOrNumeric(elementType)) {
return const ResolvedType(type: BasicType.real);
}
// fallback to text if everything else fails
return const ResolvedType(type: BasicType.text);
}); });
} }
@ -181,19 +225,6 @@ class _ResolvedVariables {
} }
extension on ResolvedType { extension on ResolvedType {
ResolvedType union(ResolvedType other) {
if (other == this) return this;
if (other.type == type) {
final thisNullable = nullable ?? true;
final otherNullable = other.nullable ?? true;
return withNullable(thisNullable || otherNullable);
}
// fallback. todo: Support more cases
return const ResolvedType(type: BasicType.text, nullable: true);
}
ResolvedType cast(CastMode mode) { ResolvedType cast(CastMode mode) {
switch (mode) { switch (mode) {
case CastMode.numeric: case CastMode.numeric:

View File

@ -189,7 +189,7 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
if (resolved != null) { if (resolved != null) {
session._checkAndResolve(e, resolved, arg); session._checkAndResolve(e, resolved, arg);
} else if (arg is RoughTypeExpectation) { } else if (arg is RoughTypeExpectation) {
session._addRelation(DefaultType(e, arg.defaultType())); session._addRelation(DefaultType(e, defaultType: arg.defaultType()));
} }
visitChildren(e, arg); visitChildren(e, arg);
@ -434,7 +434,7 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
throw AssertionError(); // required so that this switch compiles throw AssertionError(); // required so that this switch compiles
case 'sum': case 'sum':
session._addRelation(CopyAndCast(e, params.first, CastMode.numeric)); session._addRelation(CopyAndCast(e, params.first, CastMode.numeric));
session._addRelation(DefaultType(e, _realType)); session._addRelation(DefaultType(e, defaultType: _realType));
nullableIfChildIs(); nullableIfChildIs();
return null; return null;
case 'lower': case 'lower':
@ -494,6 +494,9 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
case 'coalesce': case 'coalesce':
case 'ifnull': case 'ifnull':
session._addRelation(CopyEncapsulating(e, params)); session._addRelation(CopyEncapsulating(e, params));
for (final param in params) {
session._addRelation(DefaultType(param, isNullable: true));
}
return null; return null;
case 'nullif': case 'nullif':
session._hintNullability(e, true); session._hintNullability(e, true);

View File

@ -46,6 +46,7 @@ class TypeInferenceSession {
/// This is not currently implemented. /// This is not currently implemented.
void _hintNullability(Typeable t, bool nullable) { void _hintNullability(Typeable t, bool nullable) {
assert(nullable != null); assert(nullable != null);
graph.markNullability(t, nullable);
} }
/// Asks the underlying [TypeGraph] to propagate known types via known /// Asks the underlying [TypeGraph] to propagate known types via known

View File

@ -30,6 +30,8 @@ const Map<String, ResolvedType> _types = {
'SELECT (3 * 4) = ?': ResolvedType(type: BasicType.int), 'SELECT (3 * 4) = ?': ResolvedType(type: BasicType.int),
'SELECT (3 / 4) = ?': ResolvedType(type: BasicType.int), 'SELECT (3 / 4) = ?': ResolvedType(type: BasicType.int),
'SELECT CURRENT_TIMESTAMP = ?': ResolvedType(type: BasicType.text), 'SELECT CURRENT_TIMESTAMP = ?': ResolvedType(type: BasicType.text),
"SELECT COALESCE(NULL, 'foo') = ?": ResolvedType(type: BasicType.text),
'SELECT NULLIF(3, 4) = ?': ResolvedType(type: BasicType.int, nullable: true),
'INSERT INTO demo DEFAULT VALUES ON CONFLICT (id) WHERE ? DO NOTHING': 'INSERT INTO demo DEFAULT VALUES ON CONFLICT (id) WHERE ? DO NOTHING':
ResolvedType.bool(), ResolvedType.bool(),
'INSERT INTO demo DEFAULT VALUES ON CONFLICT DO UPDATE SET id = id WHERE ?': 'INSERT INTO demo DEFAULT VALUES ON CONFLICT DO UPDATE SET id = id WHERE ?':