Support cast to enum types

This commit is contained in:
Simon Binder 2023-01-06 16:53:04 +01:00
parent bc325dd31c
commit 1d7e656e30
No known key found for this signature in database
GPG Key ID: 7891917E4147B8C0
19 changed files with 357 additions and 41 deletions

View File

@ -3,6 +3,7 @@
- Support `MAPPED BY` for individual columns in queries or in views defined with SQL. - Support `MAPPED BY` for individual columns in queries or in views defined with SQL.
- Consistently interpret `CAST (x AS DATETIME)` and `CAST(x AS TEXT)` in drift files. - Consistently interpret `CAST (x AS DATETIME)` and `CAST(x AS TEXT)` in drift files.
- Support a `CAST` to an enum type in drift types.
## 2.4.1 ## 2.4.1

View File

@ -11,9 +11,11 @@ import '../../backend.dart';
import '../../driver/error.dart'; import '../../driver/error.dart';
import '../../driver/state.dart'; import '../../driver/state.dart';
import '../../results/results.dart'; import '../../results/results.dart';
import '../dart/helper.dart';
import '../resolver.dart'; import '../resolver.dart';
import '../shared/dart_types.dart'; import '../shared/dart_types.dart';
import 'sqlparser/drift_lints.dart'; import 'sqlparser/drift_lints.dart';
import 'sqlparser/mapping.dart';
abstract class DriftElementResolver<T extends DiscoveredElement> abstract class DriftElementResolver<T extends DiscoveredElement>
extends LocalElementResolver<T> { extends LocalElementResolver<T> {
@ -164,7 +166,7 @@ abstract class DriftElementResolver<T extends DiscoveredElement>
return references.firstWhereOrNull((e) => e.id.sameName(name)); return references.firstWhereOrNull((e) => e.id.sameName(name));
} }
Future<List<DriftElement>> resolveSqlReferences(AstNode stmt) async { Future<List<DriftElement>> resolveTableReferences(AstNode stmt) async {
final references = final references =
resolver.driver.newSqlEngine().findReferencedSchemaTables(stmt); resolver.driver.newSqlEngine().findReferencedSchemaTables(stmt);
final found = <DriftElement>[]; final found = <DriftElement>[];
@ -186,7 +188,91 @@ abstract class DriftElementResolver<T extends DiscoveredElement>
return found; return found;
} }
/// Finds all referenced tables, Dart expressions and Dart types referenced
/// in [stmt].
Future<FoundReferencesInSql> resolveSqlReferences(AstNode stmt) async {
final driftElements = await resolveTableReferences(stmt);
final identifier = _IdentifyDartElements();
stmt.accept(identifier, null);
return FoundReferencesInSql(
referencedElements: driftElements,
dartExpressions: identifier.dartExpressions,
dartTypes: identifier.dartTypes,
);
}
/// Creates a type resolver capable of resolving `ENUM` and `ENUMNAME` types.
///
/// Because actual type resolving work is synchronous, types are pre-resolved
/// and must be known beforehand. Types can be found by [resolveSqlReferences].
Future<TypeFromText> createTypeResolver(
FoundReferencesInSql references,
KnownDriftTypes helper,
) async {
final typeLiteralToResolved = <String, DartType>{};
for (final entry in references.dartTypes.entries) {
final type = await findDartTypeOrReportError(entry.value, entry.key);
if (type != null) {
typeLiteralToResolved[entry.value] = type;
}
}
return enumColumnFromText(typeLiteralToResolved, helper);
}
void reportLint(AnalysisError parserError) { void reportLint(AnalysisError parserError) {
reportError(DriftAnalysisError.fromSqlError(parserError)); reportError(DriftAnalysisError.fromSqlError(parserError));
} }
} }
class FoundReferencesInSql {
/// All referenced tables in the statement.
final List<DriftElement> referencedElements;
/// All inline Dart tokens used in a `MAPPED BY`.
final List<String> dartExpressions;
/// All Dart types that were referenced in an `ENUM` or `ENUMNAME` cast
/// expression in SQL.
final Map<SyntacticEntity, String> dartTypes;
const FoundReferencesInSql({
this.referencedElements = const [],
this.dartExpressions = const [],
this.dartTypes = const {},
});
static final RegExp enumRegex =
RegExp(r'^enum(name)?\((\w+)\)$', caseSensitive: false);
}
class _IdentifyDartElements extends RecursiveVisitor<void, void> {
final List<String> dartExpressions = [];
final Map<SyntacticEntity, String> dartTypes = {};
@override
void visitCastExpression(CastExpression e, void arg) {
final match = FoundReferencesInSql.enumRegex.firstMatch(e.typeName);
if (match != null) {
// Found `ENUMNAME(x)`, where `x` is a Dart type that we might want to
// resolve later.
dartTypes[e] = match.group(2)!;
}
super.visitCastExpression(e, arg);
}
@override
void visitColumnConstraint(ColumnConstraint e, void arg) {
if (e is MappedBy) {
dartExpressions.add(e.mapper.dartCode);
} else {
super.visitColumnConstraint(e, arg);
}
}
}

View File

@ -11,7 +11,7 @@ class DriftIndexResolver extends DriftElementResolver<DiscoveredDriftIndex> {
@override @override
Future<DriftIndex> resolve() async { Future<DriftIndex> resolve() async {
final stmt = discovered.sqlNode; final stmt = discovered.sqlNode;
final references = await resolveSqlReferences(stmt); final references = await resolveTableReferences(stmt);
final engine = newEngineWithTables(references); final engine = newEngineWithTables(references);
final source = (file.discovery as DiscoveredDriftFile).originalSource; final source = (file.discovery as DiscoveredDriftFile).originalSource;

View File

@ -36,15 +36,25 @@ class DriftQueryResolver
} }
} }
final resolvedDartTypes = <String, DartType>{};
for (final entry in references.dartTypes.entries) {
final dartType = await findDartTypeOrReportError(entry.value, entry.key);
if (dartType != null) {
resolvedDartTypes[entry.value] = dartType;
}
}
return DefinedSqlQuery( return DefinedSqlQuery(
discovered.ownId, discovered.ownId,
DriftDeclaration.driftFile(stmt, file.ownUri), DriftDeclaration.driftFile(stmt, file.ownUri),
references: references, references: references.referencedElements,
sql: source.substring(stmt.firstPosition, stmt.lastPosition), sql: source.substring(stmt.firstPosition, stmt.lastPosition),
sqlOffset: stmt.firstPosition, sqlOffset: stmt.firstPosition,
mode: isCreate ? QueryMode.atCreate : QueryMode.regular, mode: isCreate ? QueryMode.atCreate : QueryMode.regular,
resultClassName: resultClassName, resultClassName: resultClassName,
existingDartType: existingType, existingDartType: existingType,
dartTypes: resolvedDartTypes,
dartTokens: references.dartExpressions,
); );
} }
} }

View File

@ -1,8 +1,12 @@
import 'package:analyzer/dart/element/type.dart';
import 'package:drift/drift.dart' show DriftSqlType; import 'package:drift/drift.dart' show DriftSqlType;
import 'package:sqlparser/sqlparser.dart'; import 'package:sqlparser/sqlparser.dart';
import '../../../driver/driver.dart'; import '../../../driver/driver.dart';
import '../../../results/results.dart'; import '../../../results/results.dart';
import '../../dart/helper.dart';
import '../../shared/dart_types.dart';
import '../element_resolver.dart';
/// Converts tables and types between `drift_dev` internal reprensentation and /// Converts tables and types between `drift_dev` internal reprensentation and
/// the one used by the `sqlparser` package. /// the one used by the `sqlparser` package.
@ -140,6 +144,35 @@ class TypeMapping {
} }
} }
/// Creates a [TypeFromText] implementation that will look up type converters
/// for `ENUM` and `ENUMNAME` column.
TypeFromText enumColumnFromText(
Map<String, DartType> knownTypes, KnownDriftTypes helper) {
return (String typeName) {
final match = FoundReferencesInSql.enumRegex.firstMatch(typeName);
if (match != null) {
final isStoredAsName = match.group(1) != null;
final type = knownTypes[match.group(2)];
if (type != null) {
return ResolvedType(
type: isStoredAsName ? BasicType.text : BasicType.int,
hint: TypeConverterHint(
readEnumConverter(
(_) {},
type,
isStoredAsName ? EnumType.textEnum : EnumType.intEnum,
helper,
)..owningColumn = null,
),
);
}
}
return null;
};
}
class TypeConverterHint extends TypeHint { class TypeConverterHint extends TypeHint {
final AppliedTypeConverter converter; final AppliedTypeConverter converter;

View File

@ -17,9 +17,6 @@ import 'element_resolver.dart';
import 'sqlparser/drift_lints.dart'; import 'sqlparser/drift_lints.dart';
class DriftTableResolver extends DriftElementResolver<DiscoveredDriftTable> { class DriftTableResolver extends DriftElementResolver<DiscoveredDriftTable> {
static final RegExp _enumRegex =
RegExp(r'^enum(name)?\((\w+)\)$', caseSensitive: false);
DriftTableResolver(super.file, super.discovered, super.resolver, super.state); DriftTableResolver(super.file, super.discovered, super.resolver, super.state);
@override @override
@ -55,8 +52,9 @@ class DriftTableResolver extends DriftElementResolver<DiscoveredDriftTable> {
final typeName = column.definition?.typeName; final typeName = column.definition?.typeName;
final enumIndexMatch = final enumIndexMatch = typeName != null
typeName != null ? _enumRegex.firstMatch(typeName) : null; ? FoundReferencesInSql.enumRegex.firstMatch(typeName)
: null;
if (enumIndexMatch != null) { if (enumIndexMatch != null) {
final dartTypeName = enumIndexMatch.group(2)!; final dartTypeName = enumIndexMatch.group(2)!;
final dartType = await findDartTypeOrReportError( final dartType = await findDartTypeOrReportError(

View File

@ -16,7 +16,7 @@ class DriftTriggerResolver
@override @override
Future<DriftTrigger> resolve() async { Future<DriftTrigger> resolve() async {
final stmt = discovered.sqlNode; final stmt = discovered.sqlNode;
final references = await resolveSqlReferences(stmt); final references = await resolveTableReferences(stmt);
final engine = newEngineWithTables(references); final engine = newEngineWithTables(references);
final source = (file.discovery as DiscoveredDriftFile).originalSource; final source = (file.discovery as DiscoveredDriftFile).originalSource;

View File

@ -17,11 +17,21 @@ class DriftViewResolver extends DriftElementResolver<DiscoveredDriftView> {
@override @override
Future<DriftView> resolve() async { Future<DriftView> resolve() async {
final stmt = discovered.sqlNode; final stmt = discovered.sqlNode;
final references = await resolveSqlReferences(stmt); final allReferences = await resolveSqlReferences(stmt);
final references = allReferences.referencedElements;
final engine = newEngineWithTables(references); final engine = newEngineWithTables(references);
final source = (file.discovery as DiscoveredDriftFile).originalSource; final source = (file.discovery as DiscoveredDriftFile).originalSource;
final context = engine.analyzeNode(stmt, source); final context = engine.analyzeNode(
stmt,
source,
stmtOptions: AnalyzeStatementOptions(
resolveTypeFromText: await createTypeResolver(
allReferences,
await resolver.driver.loadKnownTypes(),
),
),
);
reportLints(context, references); reportLints(context, references);
final parserView = engine.schemaReader.readView(context, stmt); final parserView = engine.schemaReader.readView(context, stmt);
@ -51,6 +61,7 @@ class DriftViewResolver extends DriftElementResolver<DiscoveredDriftView> {
if (type != null && type.hint is TypeConverterHint) { if (type != null && type.hint is TypeConverterHint) {
converter ??= (type.hint as TypeConverterHint).converter; converter ??= (type.hint as TypeConverterHint).converter;
ownsConverter = converter.owningColumn == null;
} }
final driftColumn = DriftColumn( final driftColumn = DriftColumn(

View File

@ -1,3 +1,4 @@
import 'package:drift_dev/src/analysis/resolver/drift/sqlparser/mapping.dart';
import 'package:sqlparser/sqlparser.dart'; import 'package:sqlparser/sqlparser.dart';
import '../../utils/entity_reference_sorter.dart'; import '../../utils/entity_reference_sorter.dart';
@ -6,6 +7,7 @@ import '../driver/error.dart';
import '../driver/state.dart'; import '../driver/state.dart';
import '../results/file_results.dart'; import '../results/file_results.dart';
import '../results/results.dart'; import '../results/results.dart';
import 'dart/helper.dart';
import 'queries/query_analyzer.dart'; import 'queries/query_analyzer.dart';
import 'queries/required_variables.dart'; import 'queries/required_variables.dart';
@ -90,7 +92,8 @@ class FileAnalyzer {
AstPreparingVisitor.resolveIndexOfVariables( AstPreparingVisitor.resolveIndexOfVariables(
stmt.allDescendants.whereType<Variable>().toList()); stmt.allDescendants.whereType<Variable>().toList());
final options = _createOptionsAndVars(engine, stmt); final options =
_createOptionsAndVars(engine, stmt, element, knownTypes);
final analysisResult = engine.analyzeNode(stmt.statement, source, final analysisResult = engine.analyzeNode(stmt.statement, source,
stmtOptions: options.options); stmtOptions: options.options);
@ -127,7 +130,11 @@ class FileAnalyzer {
} }
_OptionsAndRequiredVariables _createOptionsAndVars( _OptionsAndRequiredVariables _createOptionsAndVars(
SqlEngine engine, DeclaredStatement stmt) { SqlEngine engine,
DeclaredStatement stmt,
DefinedSqlQuery query,
KnownDriftTypes helper,
) {
final reader = engine.schemaReader; final reader = engine.schemaReader;
final indexedHints = <int, ResolvedType>{}; final indexedHints = <int, ResolvedType>{};
final namedHints = <String, ResolvedType>{}; final namedHints = <String, ResolvedType>{};
@ -168,6 +175,7 @@ class FileAnalyzer {
indexedVariableTypes: indexedHints, indexedVariableTypes: indexedHints,
namedVariableTypes: namedHints, namedVariableTypes: namedHints,
defaultValuesForPlaceholder: defaultValues, defaultValuesForPlaceholder: defaultValues,
resolveTypeFromText: enumColumnFromText(query.dartTypes, helper),
), ),
RequiredVariables(requiredIndex, requiredName), RequiredVariables(requiredIndex, requiredName),
); );

View File

@ -53,7 +53,7 @@ class QueryAnalyzer {
final RequiredVariables requiredVariables; final RequiredVariables requiredVariables;
final Map<String, DriftElement> referencesByName; final Map<String, DriftElement> referencesByName;
final Map<InlineDartToken, dart.Expression> _resolvedExpressions = {}; final Map<String, dart.Expression> _resolvedExpressions = {};
/// Found tables and views found need to be shared between the query and /// Found tables and views found need to be shared between the query and
/// all subqueries to not muss any updates when watching query. /// all subqueries to not muss any updates when watching query.
@ -95,7 +95,7 @@ class QueryAnalyzer {
/// messages. /// messages.
Future<SqlQuery> analyze(DriftQueryDeclaration declaration, Future<SqlQuery> analyze(DriftQueryDeclaration declaration,
{DriftTableName? sourceForCustomName}) async { {DriftTableName? sourceForCustomName}) async {
await _resolveDartTokens(); await _resolveDartTokens(declaration);
final nestedAnalyzer = NestedQueryAnalyzer(); final nestedAnalyzer = NestedQueryAnalyzer();
NestedQueriesContainer? nestedScope; NestedQueriesContainer? nestedScope;
@ -143,25 +143,26 @@ class QueryAnalyzer {
return query; return query;
} }
Future<void> _resolveDartTokens() async { Future<void> _resolveDartTokens(DriftQueryDeclaration declaration) async {
for (final mappedBy in context.root.allDescendants.whereType<MappedBy>()) { if (declaration is DefinedSqlQuery) {
try { for (final expression in declaration.dartTokens) {
final expression = await driver.backend.resolveExpression( try {
fromFile.ownUri, final resolved = await driver.backend.resolveExpression(
mappedBy.mapper.dartCode, fromFile.ownUri,
fromFile.discovery?.importDependencies expression,
.map((e) => e.toString()) fromFile.discovery?.importDependencies
.where((e) => e.endsWith('.dart')) ?? .map((e) => e.toString())
const Iterable.empty(), .where((e) => e.endsWith('.dart')) ??
); const Iterable.empty(),
);
_resolvedExpressions[mappedBy.mapper] = expression; _resolvedExpressions[expression] = resolved;
} on CannotReadExpressionException catch (e) { } on CannotReadExpressionException catch (e) {
lints.add(AnalysisError( lints.add(AnalysisError(
type: AnalysisErrorType.other, type: AnalysisErrorType.other,
message: 'Could not read expression: ${e.msg}', message: 'Could not read expression: ${e.msg}',
relevantNode: mappedBy.mapper, ));
)); }
} }
} }
} }
@ -291,7 +292,7 @@ class QueryAnalyzer {
if (type?.hint is TypeConverterHint) { if (type?.hint is TypeConverterHint) {
converter = (type!.hint as TypeConverterHint).converter; converter = (type!.hint as TypeConverterHint).converter;
} else if (mappedBy != null) { } else if (mappedBy != null) {
final dartExpression = _resolvedExpressions[mappedBy.mapper]; final dartExpression = _resolvedExpressions[mappedBy.mapper.dartCode];
if (dartExpression != null) { if (dartExpression != null) {
converter = readTypeConverter( converter = readTypeConverter(
knownTypes.helperLibrary, knownTypes.helperLibrary,

View File

@ -58,6 +58,13 @@ class DefinedSqlQuery extends DriftElement implements DriftQueryDeclaration {
@override @override
String get name => id.name; String get name => id.name;
/// All in-line Dart source code literals embedded into the query.
final List<String> dartTokens;
/// All Dart type names embedded into the query, for instance in a
/// `CAST(x AS ENUMNAME(MyDartType))` expression.
final Map<String, DartType> dartTypes;
DefinedSqlQuery( DefinedSqlQuery(
super.id, super.id,
super.declaration, { super.declaration, {
@ -66,6 +73,8 @@ class DefinedSqlQuery extends DriftElement implements DriftQueryDeclaration {
required this.sqlOffset, required this.sqlOffset,
this.resultClassName, this.resultClassName,
this.existingDartType, this.existingDartType,
this.dartTokens = const [],
this.dartTypes = const {},
this.mode = QueryMode.regular, this.mode = QueryMode.regular,
}); });
} }

View File

@ -70,8 +70,13 @@ class ElementSerializer {
'sql': element.sql, 'sql': element.sql,
'offset': element.sqlOffset, 'offset': element.sqlOffset,
'result_class': element.resultClassName, 'result_class': element.resultClassName,
'eixsting_type': _serializeType(element.existingDartType), 'existing_type': _serializeType(element.existingDartType),
'mode': element.mode.name, 'mode': element.mode.name,
'dart_tokens': element.dartTokens,
'dart_types': {
for (final entry in element.dartTypes.entries)
entry.key: _serializeType(entry.value)
},
}; };
} else if (element is DriftTrigger) { } else if (element is DriftTrigger) {
additionalInformation = { additionalInformation = {
@ -515,7 +520,13 @@ class ElementDeserializer {
createStmt: json['sql'] as String, createStmt: json['sql'] as String,
); );
case 'query': case 'query':
final rawExistingType = json['eixsting_type']; final rawExistingType = json['existing_type'];
final types = <String, DartType>{};
for (final entry in (json['dart_types'] as Map).entries) {
types[entry.key as String] =
await _readDartType(id.libraryUri, entry.value as int);
}
return DefinedSqlQuery( return DefinedSqlQuery(
id, id,
@ -528,6 +539,8 @@ class ElementDeserializer {
? await _readDartType(id.libraryUri, rawExistingType as int) ? await _readDartType(id.libraryUri, rawExistingType as int)
: null, : null,
mode: QueryMode.values.byName(json['mode'] as String), mode: QueryMode.values.byName(json['mode'] as String),
dartTokens: (json['dart_tokens'] as List).cast(),
dartTypes: types,
); );
case 'trigger': case 'trigger':
DriftTable? on; DriftTable? on;

View File

@ -4,6 +4,7 @@ import 'package:drift/drift.dart' show SqlDialect;
import 'package:sqlparser/sqlparser.dart'; import 'package:sqlparser/sqlparser.dart';
import 'package:sqlparser/utils/node_to_text.dart'; import 'package:sqlparser/utils/node_to_text.dart';
import '../../analysis/resolver/drift/element_resolver.dart';
import '../../analysis/results/results.dart'; import '../../analysis/results/results.dart';
import '../../analysis/options.dart'; import '../../analysis/options.dart';
import '../../utils/string_escaper.dart'; import '../../utils/string_escaper.dart';
@ -117,6 +118,13 @@ class SqlWriter extends NodeSqlBuilder {
overriddenTypeName = options.storeDateTimeValuesAsText ? 'TEXT' : 'INT'; overriddenTypeName = options.storeDateTimeValuesAsText ? 'TEXT' : 'INT';
} else if (hint is IsBoolean) { } else if (hint is IsBoolean) {
overriddenTypeName = 'INT'; overriddenTypeName = 'INT';
} else {
final enumMatch = FoundReferencesInSql.enumRegex.firstMatch(e.typeName);
if (enumMatch != null) {
final isStoredAsText = enumMatch.group(1) != null;
overriddenTypeName = isStoredAsText ? 'TEXT' : 'INT';
}
} }
if (overriddenTypeName != null) { if (overriddenTypeName != null) {

View File

@ -222,6 +222,57 @@ TypeConverter<Object, int> createConverter() => throw UnimplementedError();
); );
}); });
test('supports enum columns', () async {
final backend = TestBackend.inTest({
'a|lib/a.drift': '''
import 'enums.dart';
CREATE VIEW foo AS SELECT
1 AS c1,
CAST(1 AS ENUM(MyEnum)) AS c2,
CAST('foo' AS ENUMNAME(MyEnum)) AS c3;
''',
'a|lib/enums.dart': '''
enum MyEnum {
foo, bar
}
''',
});
final state = await backend.analyze('package:a/a.drift');
backend.expectNoErrors();
final view = state.analyzedElements.single as DriftView;
final c1 = view.columns[0];
final c2 = view.columns[1];
final c3 = view.columns[2];
expect(c1.typeConverter, isNull);
expect(
c2.typeConverter,
isA<AppliedTypeConverter>()
.having((e) => e.isDriftEnumTypeConverter, 'isDriftEnumTypeConverter',
isTrue)
.having((e) => e.owningColumn, 'owningColumn', c2),
);
expect(
c3.typeConverter,
isA<AppliedTypeConverter>()
.having((e) => e.isDriftEnumTypeConverter, 'isDriftEnumTypeConverter',
isTrue)
.having((e) => e.owningColumn, 'owningColumn', c3),
);
expect(
view.source,
isA<SqlViewSource>().having(
(e) => e.sqlCreateViewStmt,
'sqlCreateViewStmt',
"CREATE VIEW foo AS SELECT 1 AS c1, CAST(1 AS INT) AS c2, CAST('foo' AS TEXT) AS c3;",
),
);
});
group('desugars cast', () { group('desugars cast', () {
Future<void> expectView( Future<void> expectView(
String definition, String definition,

View File

@ -311,4 +311,40 @@ a: SELECT CAST(1 AS BOOLEAN) AS a, CAST(2 AS DATETIME) as b;
.having((e) => e.sqlType, 'sqlType', DriftSqlType.dateTime), .having((e) => e.sqlType, 'sqlType', DriftSqlType.dateTime),
]); ]);
}); });
test('can cast to enum type', () async {
final backend = TestBackend.inTest({
'a|lib/a.drift': '''
import 'enum.dart';
a: SELECT
1 AS c1,
CAST(1 AS ENUM(MyEnum)) AS c2,
CAST('foo' AS ENUMNAME(MyEnum)) AS c3;
''',
'a|lib/enum.dart': '''
enum MyEnum {
foo, bar
}
''',
});
final state = await backend.analyze('package:a/a.drift');
backend.expectNoErrors();
final query = state.fileAnalysis!.resolvedQueries.values.single;
final resultSet = query.resultSet!;
final isEnumConverter = isA<AppliedTypeConverter>().having(
(e) => e.isDriftEnumTypeConverter, 'isDriftEnumTypeConverter', isTrue);
expect(resultSet.columns, [
scalarColumn('c1')
.having((e) => e.typeConverter, 'typeConverter', isNull),
scalarColumn('c2')
.having((e) => e.typeConverter, 'typeConverter', isEnumConverter),
scalarColumn('c3')
.having((e) => e.typeConverter, 'typeConverter', isEnumConverter),
]);
});
} }

View File

@ -1,5 +1,8 @@
part of 'analysis.dart'; part of 'analysis.dart';
/// Signature of a function that resolves the type of a SQL type literal.
typedef TypeFromText = ResolvedType? Function(String);
/// Options to analyze a sql statement. This can be used if the type of a /// Options to analyze a sql statement. This can be used if the type of a
/// variable is known from the outside. /// variable is known from the outside.
class AnalyzeStatementOptions { class AnalyzeStatementOptions {
@ -10,10 +13,13 @@ class AnalyzeStatementOptions {
/// expression, if set. /// expression, if set.
final Map<String, Expression> defaultValuesForPlaceholder; final Map<String, Expression> defaultValuesForPlaceholder;
final TypeFromText? resolveTypeFromText;
const AnalyzeStatementOptions({ const AnalyzeStatementOptions({
this.indexedVariableTypes = const {}, this.indexedVariableTypes = const {},
this.namedVariableTypes = const {}, this.namedVariableTypes = const {},
this.defaultValuesForPlaceholder = const {}, this.defaultValuesForPlaceholder = const {},
this.resolveTypeFromText,
}); });
/// Looks up the defined type for that variable. /// Looks up the defined type for that variable.

View File

@ -10,9 +10,12 @@ class SchemaFromCreateTable {
/// enabled) should be reported as a text column instead of an int column. /// enabled) should be reported as a text column instead of an int column.
final bool driftUseTextForDateTime; final bool driftUseTextForDateTime;
final AnalyzeStatementOptions? statementOptions;
const SchemaFromCreateTable({ const SchemaFromCreateTable({
this.driftExtensions = false, this.driftExtensions = false,
this.driftUseTextForDateTime = false, this.driftUseTextForDateTime = false,
this.statementOptions,
}); });
/// Reads a [Table] schema from the [stmt] inducing a table (either a /// Reads a [Table] schema from the [stmt] inducing a table (either a
@ -146,6 +149,12 @@ class SchemaFromCreateTable {
return const ResolvedType(type: BasicType.blob); return const ResolvedType(type: BasicType.blob);
} }
// S if a custom resolver is installed and yields a type for this column:
final custom = statementOptions?.resolveTypeFromText?.call(typeName);
if (custom != null) {
return custom;
}
final upper = typeName.toUpperCase(); final upper = typeName.toUpperCase();
if (upper.contains('INT')) { if (upper.contains('INT')) {
return const ResolvedType(type: BasicType.int); return const ResolvedType(type: BasicType.int);

View File

@ -34,12 +34,25 @@ class SqlEngine {
/// The returned reader can be used to read the table structure from a /// The returned reader can be used to read the table structure from a
/// [TableInducingStatement] by using [SchemaFromCreateTable.read]. /// [TableInducingStatement] by using [SchemaFromCreateTable.read].
SchemaFromCreateTable get schemaReader { SchemaFromCreateTable get schemaReader {
return _schemaReader ??= _createSchemaReader(null);
}
SchemaFromCreateTable _createSchemaReader(
AnalyzeStatementOptions? stmtOptions) {
final driftOptions = options.driftOptions; final driftOptions = options.driftOptions;
return _schemaReader ??= SchemaFromCreateTable( if (stmtOptions != null) {
driftExtensions: driftOptions != null, return SchemaFromCreateTable(
driftUseTextForDateTime: driftOptions?.storeDateTimesAsText == true, driftExtensions: driftOptions != null,
); driftUseTextForDateTime: driftOptions?.storeDateTimesAsText == true,
statementOptions: stmtOptions,
);
} else {
return _schemaReader ??= SchemaFromCreateTable(
driftExtensions: driftOptions != null,
driftUseTextForDateTime: driftOptions?.storeDateTimesAsText == true,
);
}
} }
/// Registers the [table], which means that it can later be used in sql /// Registers the [table], which means that it can later be used in sql
@ -223,8 +236,10 @@ class SqlEngine {
AnalysisContext _createContext( AnalysisContext _createContext(
AstNode node, String sql, AnalyzeStatementOptions? stmtOptions) { AstNode node, String sql, AnalyzeStatementOptions? stmtOptions) {
final schemaSupport = _createSchemaReader(stmtOptions);
return AnalysisContext(node, sql, _constructRootScope(), options, return AnalysisContext(node, sql, _constructRootScope(), options,
stmtOptions: stmtOptions, schemaSupport: schemaReader); stmtOptions: stmtOptions, schemaSupport: schemaSupport);
} }
void _analyzeContext(AnalysisContext context) { void _analyzeContext(AnalysisContext context) {

View File

@ -116,4 +116,25 @@ WITH RECURSIVE
everyElement(isA<ResolveResult>() everyElement(isA<ResolveResult>()
.having((e) => e.type?.nullable, 'type.nullable', isTrue))); .having((e) => e.type?.nullable, 'type.nullable', isTrue)));
}); });
test('can extract custom type', () {
final engine = SqlEngine();
final content = engine.analyze(
'SELECT CAST(1 AS MyCustomType)',
stmtOptions: AnalyzeStatementOptions(
resolveTypeFromText: expectAsync1(
(typeName) {
expect(typeName, 'MyCustomType');
return ResolvedType.bool();
},
),
),
);
final select = content.root as SelectStatement;
final column = select.resolvedColumns!.single;
expect(content.typeOf(column).type, ResolvedType.bool());
});
} }