Improve nullability detection in new type inference

This commit is contained in:
Simon Binder 2020-03-24 19:33:57 +01:00
parent cdd57f340d
commit 2b9a85714f
No known key found for this signature in database
GPG Key ID: 7891917E4147B8C0
6 changed files with 77 additions and 33 deletions

View File

@ -26,11 +26,17 @@ class ResolvedType {
const ResolvedType(
{this.type, this.hint, this.nullable = false, this.isArray = false});
const ResolvedType.bool({bool nullable})
const ResolvedType.bool({bool nullable = false})
: this(type: BasicType.int, hint: const IsBoolean(), nullable: nullable);
ResolvedType get withoutNullabilityInfo {
return nullable == null
? this
: ResolvedType(type: type, hint: hint, isArray: isArray);
}
ResolvedType withNullable(bool nullable) {
return copyWith(nullable: nullable);
return nullable == this.nullable ? this : copyWith(nullable: nullable);
}
ResolvedType toArray(bool array) {

View File

@ -56,9 +56,10 @@ class HaveSameType extends TypeRelation {
class DefaultType extends TypeRelation implements DirectedRelation {
@override
final Typeable target;
final ResolvedType defaultType;
final ResolvedType /*?*/ defaultType;
final bool /*?*/ isNullable;
DefaultType(this.target, this.defaultType);
DefaultType(this.target, {this.defaultType, this.isNullable});
}
enum CastMode { numeric, boolean }

View File

@ -12,13 +12,17 @@ class TypeGraph {
final Set<MultiSourceRelation> _multiSources = {};
final List<DefaultType> _defaultTypes = [];
TypeGraph();
ResolvedType operator [](Typeable t) {
final normalized = variables.normalize(t);
if (_knownTypes.containsKey(normalized)) {
return _knownTypes[normalized];
final type = _knownTypes[normalized];
final nullability = _knownNullability[normalized];
if (nullability != null) {
return type.withNullable(nullability);
}
return type;
}
return null;
@ -28,7 +32,7 @@ class TypeGraph {
final normalized = variables.normalize(t);
_knownTypes[normalized] = type;
if (type.nullable != null) {
if (type.nullable != null && !_knownNullability.containsKey(normalized)) {
// nullability is known
_knownNullability[normalized] = type.nullable;
}
@ -40,6 +44,10 @@ class TypeGraph {
_relations.add(relation);
}
void markNullability(Typeable t, bool isNullable) {
_knownNullability[variables.normalize(t)] = isNullable;
}
void performResolve() {
_indexRelations();
@ -55,8 +63,16 @@ class TypeGraph {
// apply default types
for (final applyDefault in _defaultTypes) {
if (!knowsType(applyDefault.target)) {
this[applyDefault.target] = applyDefault.defaultType;
final target = applyDefault.target;
final type = applyDefault.defaultType;
if (type != null && !knowsType(target)) {
this[target] = applyDefault.defaultType;
}
final nullability = applyDefault.isNullable;
if (nullability != null && _knownNullability.containsKey(target)) {
markNullability(target, nullability);
}
}
}
@ -87,13 +103,22 @@ class TypeGraph {
}
void _propagateManyToOne(MultiSourceRelation edge, List<Typeable> resolved) {
if (!knowsType(edge.target)) {
final fromTypes = edge.from.map((t) => this[t]).where((e) => e != null);
final encapsulated = _encapsulate(fromTypes);
if (encapsulated != null) {
this[edge.target] = encapsulated;
resolved.add(edge.target);
if (edge is CopyEncapsulating) {
if (!knowsType(edge.target)) {
final fromTypes = edge.from.map((t) => this[t]).where((e) => e != null);
final encapsulated = _encapsulate(fromTypes);
if (encapsulated != null) {
this[edge.target] = encapsulated;
resolved.add(edge.target);
}
}
} else if (edge is NullableIfSomeOtherIs &&
!_knownNullability.containsKey(edge.target)) {
final nullable = edge.from
.map((e) => _knownNullability[e])
.any((nullable) => nullable == true);
_knownNullability[edge.target] = nullable;
}
}
@ -108,8 +133,27 @@ class TypeGraph {
}
ResolvedType /*?*/ _encapsulate(Iterable<ResolvedType> targets) {
return targets.fold<ResolvedType>(null, (previous, element) {
return previous?.union(element) ?? element;
return targets.map((e) => e.withoutNullabilityInfo).fold<ResolvedType>(null,
(previous, element) {
if (previous == null) return element;
final previousType = previous.type;
final elementType = element.type;
if (previousType == elementType) return previous;
if (previousType == BasicType.nullType) return element;
bool isIntOrNumeric(BasicType type) {
return type == BasicType.int || type == BasicType.real;
}
// encapsulate two different numeric types to real
if (isIntOrNumeric(previousType) && isIntOrNumeric(elementType)) {
return const ResolvedType(type: BasicType.real);
}
// fallback to text if everything else fails
return const ResolvedType(type: BasicType.text);
});
}
@ -181,19 +225,6 @@ class _ResolvedVariables {
}
extension on ResolvedType {
ResolvedType union(ResolvedType other) {
if (other == this) return this;
if (other.type == type) {
final thisNullable = nullable ?? true;
final otherNullable = other.nullable ?? true;
return withNullable(thisNullable || otherNullable);
}
// fallback. todo: Support more cases
return const ResolvedType(type: BasicType.text, nullable: true);
}
ResolvedType cast(CastMode mode) {
switch (mode) {
case CastMode.numeric:

View File

@ -189,7 +189,7 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
if (resolved != null) {
session._checkAndResolve(e, resolved, arg);
} else if (arg is RoughTypeExpectation) {
session._addRelation(DefaultType(e, arg.defaultType()));
session._addRelation(DefaultType(e, defaultType: arg.defaultType()));
}
visitChildren(e, arg);
@ -434,7 +434,7 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
throw AssertionError(); // required so that this switch compiles
case 'sum':
session._addRelation(CopyAndCast(e, params.first, CastMode.numeric));
session._addRelation(DefaultType(e, _realType));
session._addRelation(DefaultType(e, defaultType: _realType));
nullableIfChildIs();
return null;
case 'lower':
@ -494,6 +494,9 @@ class TypeResolver extends RecursiveVisitor<TypeExpectation, void> {
case 'coalesce':
case 'ifnull':
session._addRelation(CopyEncapsulating(e, params));
for (final param in params) {
session._addRelation(DefaultType(param, isNullable: true));
}
return null;
case 'nullif':
session._hintNullability(e, true);

View File

@ -46,6 +46,7 @@ class TypeInferenceSession {
/// This is not currently implemented.
void _hintNullability(Typeable t, bool nullable) {
assert(nullable != null);
graph.markNullability(t, nullable);
}
/// Asks the underlying [TypeGraph] to propagate known types via known

View File

@ -30,6 +30,8 @@ const Map<String, ResolvedType> _types = {
'SELECT (3 * 4) = ?': ResolvedType(type: BasicType.int),
'SELECT (3 / 4) = ?': ResolvedType(type: BasicType.int),
'SELECT CURRENT_TIMESTAMP = ?': ResolvedType(type: BasicType.text),
"SELECT COALESCE(NULL, 'foo') = ?": ResolvedType(type: BasicType.text),
'SELECT NULLIF(3, 4) = ?': ResolvedType(type: BasicType.int, nullable: true),
'INSERT INTO demo DEFAULT VALUES ON CONFLICT (id) WHERE ? DO NOTHING':
ResolvedType.bool(),
'INSERT INTO demo DEFAULT VALUES ON CONFLICT DO UPDATE SET id = id WHERE ?':