diff --git a/sqlparser/lib/src/analysis/types/data.dart b/sqlparser/lib/src/analysis/types/data.dart index a100958d..d5b2fc8a 100644 --- a/sqlparser/lib/src/analysis/types/data.dart +++ b/sqlparser/lib/src/analysis/types/data.dart @@ -26,11 +26,17 @@ class ResolvedType { const ResolvedType( {this.type, this.hint, this.nullable = false, this.isArray = false}); - const ResolvedType.bool({bool nullable}) + const ResolvedType.bool({bool nullable = false}) : this(type: BasicType.int, hint: const IsBoolean(), nullable: nullable); + ResolvedType get withoutNullabilityInfo { + return nullable == null + ? this + : ResolvedType(type: type, hint: hint, isArray: isArray); + } + ResolvedType withNullable(bool nullable) { - return copyWith(nullable: nullable); + return nullable == this.nullable ? this : copyWith(nullable: nullable); } ResolvedType toArray(bool array) { diff --git a/sqlparser/lib/src/analysis/types2/graph/relationships.dart b/sqlparser/lib/src/analysis/types2/graph/relationships.dart index 8b017f2c..971d65a9 100644 --- a/sqlparser/lib/src/analysis/types2/graph/relationships.dart +++ b/sqlparser/lib/src/analysis/types2/graph/relationships.dart @@ -56,9 +56,10 @@ class HaveSameType extends TypeRelation { class DefaultType extends TypeRelation implements DirectedRelation { @override final Typeable target; - final ResolvedType defaultType; + final ResolvedType /*?*/ defaultType; + final bool /*?*/ isNullable; - DefaultType(this.target, this.defaultType); + DefaultType(this.target, {this.defaultType, this.isNullable}); } enum CastMode { numeric, boolean } diff --git a/sqlparser/lib/src/analysis/types2/graph/type_graph.dart b/sqlparser/lib/src/analysis/types2/graph/type_graph.dart index e335ae4b..29f4c24a 100644 --- a/sqlparser/lib/src/analysis/types2/graph/type_graph.dart +++ b/sqlparser/lib/src/analysis/types2/graph/type_graph.dart @@ -12,13 +12,17 @@ class TypeGraph { final Set _multiSources = {}; final List _defaultTypes = []; - TypeGraph(); - ResolvedType operator [](Typeable t) { final normalized = variables.normalize(t); if (_knownTypes.containsKey(normalized)) { - return _knownTypes[normalized]; + final type = _knownTypes[normalized]; + final nullability = _knownNullability[normalized]; + + if (nullability != null) { + return type.withNullable(nullability); + } + return type; } return null; @@ -28,7 +32,7 @@ class TypeGraph { final normalized = variables.normalize(t); _knownTypes[normalized] = type; - if (type.nullable != null) { + if (type.nullable != null && !_knownNullability.containsKey(normalized)) { // nullability is known _knownNullability[normalized] = type.nullable; } @@ -40,6 +44,10 @@ class TypeGraph { _relations.add(relation); } + void markNullability(Typeable t, bool isNullable) { + _knownNullability[variables.normalize(t)] = isNullable; + } + void performResolve() { _indexRelations(); @@ -55,8 +63,16 @@ class TypeGraph { // apply default types for (final applyDefault in _defaultTypes) { - if (!knowsType(applyDefault.target)) { - this[applyDefault.target] = applyDefault.defaultType; + final target = applyDefault.target; + + final type = applyDefault.defaultType; + if (type != null && !knowsType(target)) { + this[target] = applyDefault.defaultType; + } + + final nullability = applyDefault.isNullable; + if (nullability != null && _knownNullability.containsKey(target)) { + markNullability(target, nullability); } } } @@ -87,13 +103,22 @@ class TypeGraph { } void _propagateManyToOne(MultiSourceRelation edge, List resolved) { - if (!knowsType(edge.target)) { - final fromTypes = edge.from.map((t) => this[t]).where((e) => e != null); - final encapsulated = _encapsulate(fromTypes); - if (encapsulated != null) { - this[edge.target] = encapsulated; - resolved.add(edge.target); + if (edge is CopyEncapsulating) { + if (!knowsType(edge.target)) { + final fromTypes = edge.from.map((t) => this[t]).where((e) => e != null); + final encapsulated = _encapsulate(fromTypes); + if (encapsulated != null) { + this[edge.target] = encapsulated; + resolved.add(edge.target); + } } + } else if (edge is NullableIfSomeOtherIs && + !_knownNullability.containsKey(edge.target)) { + final nullable = edge.from + .map((e) => _knownNullability[e]) + .any((nullable) => nullable == true); + + _knownNullability[edge.target] = nullable; } } @@ -108,8 +133,27 @@ class TypeGraph { } ResolvedType /*?*/ _encapsulate(Iterable targets) { - return targets.fold(null, (previous, element) { - return previous?.union(element) ?? element; + return targets.map((e) => e.withoutNullabilityInfo).fold(null, + (previous, element) { + if (previous == null) return element; + + final previousType = previous.type; + final elementType = element.type; + + if (previousType == elementType) return previous; + if (previousType == BasicType.nullType) return element; + + bool isIntOrNumeric(BasicType type) { + return type == BasicType.int || type == BasicType.real; + } + + // encapsulate two different numeric types to real + if (isIntOrNumeric(previousType) && isIntOrNumeric(elementType)) { + return const ResolvedType(type: BasicType.real); + } + + // fallback to text if everything else fails + return const ResolvedType(type: BasicType.text); }); } @@ -181,19 +225,6 @@ class _ResolvedVariables { } extension on ResolvedType { - ResolvedType union(ResolvedType other) { - if (other == this) return this; - - if (other.type == type) { - final thisNullable = nullable ?? true; - final otherNullable = other.nullable ?? true; - return withNullable(thisNullable || otherNullable); - } - - // fallback. todo: Support more cases - return const ResolvedType(type: BasicType.text, nullable: true); - } - ResolvedType cast(CastMode mode) { switch (mode) { case CastMode.numeric: diff --git a/sqlparser/lib/src/analysis/types2/resolving_visitor.dart b/sqlparser/lib/src/analysis/types2/resolving_visitor.dart index 0bf74a6d..9cad3902 100644 --- a/sqlparser/lib/src/analysis/types2/resolving_visitor.dart +++ b/sqlparser/lib/src/analysis/types2/resolving_visitor.dart @@ -189,7 +189,7 @@ class TypeResolver extends RecursiveVisitor { if (resolved != null) { session._checkAndResolve(e, resolved, arg); } else if (arg is RoughTypeExpectation) { - session._addRelation(DefaultType(e, arg.defaultType())); + session._addRelation(DefaultType(e, defaultType: arg.defaultType())); } visitChildren(e, arg); @@ -434,7 +434,7 @@ class TypeResolver extends RecursiveVisitor { throw AssertionError(); // required so that this switch compiles case 'sum': session._addRelation(CopyAndCast(e, params.first, CastMode.numeric)); - session._addRelation(DefaultType(e, _realType)); + session._addRelation(DefaultType(e, defaultType: _realType)); nullableIfChildIs(); return null; case 'lower': @@ -494,6 +494,9 @@ class TypeResolver extends RecursiveVisitor { case 'coalesce': case 'ifnull': session._addRelation(CopyEncapsulating(e, params)); + for (final param in params) { + session._addRelation(DefaultType(param, isNullable: true)); + } return null; case 'nullif': session._hintNullability(e, true); diff --git a/sqlparser/lib/src/analysis/types2/types.dart b/sqlparser/lib/src/analysis/types2/types.dart index 0b6d631c..29121523 100644 --- a/sqlparser/lib/src/analysis/types2/types.dart +++ b/sqlparser/lib/src/analysis/types2/types.dart @@ -46,6 +46,7 @@ class TypeInferenceSession { /// This is not currently implemented. void _hintNullability(Typeable t, bool nullable) { assert(nullable != null); + graph.markNullability(t, nullable); } /// Asks the underlying [TypeGraph] to propagate known types via known diff --git a/sqlparser/test/analysis/types2/misc_cases_test.dart b/sqlparser/test/analysis/types2/misc_cases_test.dart index 8898d776..d7044d50 100644 --- a/sqlparser/test/analysis/types2/misc_cases_test.dart +++ b/sqlparser/test/analysis/types2/misc_cases_test.dart @@ -30,6 +30,8 @@ const Map _types = { 'SELECT (3 * 4) = ?': ResolvedType(type: BasicType.int), 'SELECT (3 / 4) = ?': ResolvedType(type: BasicType.int), 'SELECT CURRENT_TIMESTAMP = ?': ResolvedType(type: BasicType.text), + "SELECT COALESCE(NULL, 'foo') = ?": ResolvedType(type: BasicType.text), + 'SELECT NULLIF(3, 4) = ?': ResolvedType(type: BasicType.int, nullable: true), 'INSERT INTO demo DEFAULT VALUES ON CONFLICT (id) WHERE ? DO NOTHING': ResolvedType.bool(), 'INSERT INTO demo DEFAULT VALUES ON CONFLICT DO UPDATE SET id = id WHERE ?':