Support cast to enum types

This commit is contained in:
Simon Binder 2023-01-06 16:53:04 +01:00
parent bc325dd31c
commit 1d7e656e30
No known key found for this signature in database
GPG Key ID: 7891917E4147B8C0
19 changed files with 357 additions and 41 deletions

View File

@ -3,6 +3,7 @@
- Support `MAPPED BY` for individual columns in queries or in views defined with SQL.
- Consistently interpret `CAST (x AS DATETIME)` and `CAST(x AS TEXT)` in drift files.
- Support a `CAST` to an enum type in drift types.
## 2.4.1

View File

@ -11,9 +11,11 @@ import '../../backend.dart';
import '../../driver/error.dart';
import '../../driver/state.dart';
import '../../results/results.dart';
import '../dart/helper.dart';
import '../resolver.dart';
import '../shared/dart_types.dart';
import 'sqlparser/drift_lints.dart';
import 'sqlparser/mapping.dart';
abstract class DriftElementResolver<T extends DiscoveredElement>
extends LocalElementResolver<T> {
@ -164,7 +166,7 @@ abstract class DriftElementResolver<T extends DiscoveredElement>
return references.firstWhereOrNull((e) => e.id.sameName(name));
}
Future<List<DriftElement>> resolveSqlReferences(AstNode stmt) async {
Future<List<DriftElement>> resolveTableReferences(AstNode stmt) async {
final references =
resolver.driver.newSqlEngine().findReferencedSchemaTables(stmt);
final found = <DriftElement>[];
@ -186,7 +188,91 @@ abstract class DriftElementResolver<T extends DiscoveredElement>
return found;
}
/// Finds all referenced tables, Dart expressions and Dart types referenced
/// in [stmt].
Future<FoundReferencesInSql> resolveSqlReferences(AstNode stmt) async {
final driftElements = await resolveTableReferences(stmt);
final identifier = _IdentifyDartElements();
stmt.accept(identifier, null);
return FoundReferencesInSql(
referencedElements: driftElements,
dartExpressions: identifier.dartExpressions,
dartTypes: identifier.dartTypes,
);
}
/// Creates a type resolver capable of resolving `ENUM` and `ENUMNAME` types.
///
/// Because actual type resolving work is synchronous, types are pre-resolved
/// and must be known beforehand. Types can be found by [resolveSqlReferences].
Future<TypeFromText> createTypeResolver(
FoundReferencesInSql references,
KnownDriftTypes helper,
) async {
final typeLiteralToResolved = <String, DartType>{};
for (final entry in references.dartTypes.entries) {
final type = await findDartTypeOrReportError(entry.value, entry.key);
if (type != null) {
typeLiteralToResolved[entry.value] = type;
}
}
return enumColumnFromText(typeLiteralToResolved, helper);
}
void reportLint(AnalysisError parserError) {
reportError(DriftAnalysisError.fromSqlError(parserError));
}
}
class FoundReferencesInSql {
/// All referenced tables in the statement.
final List<DriftElement> referencedElements;
/// All inline Dart tokens used in a `MAPPED BY`.
final List<String> dartExpressions;
/// All Dart types that were referenced in an `ENUM` or `ENUMNAME` cast
/// expression in SQL.
final Map<SyntacticEntity, String> dartTypes;
const FoundReferencesInSql({
this.referencedElements = const [],
this.dartExpressions = const [],
this.dartTypes = const {},
});
static final RegExp enumRegex =
RegExp(r'^enum(name)?\((\w+)\)$', caseSensitive: false);
}
class _IdentifyDartElements extends RecursiveVisitor<void, void> {
final List<String> dartExpressions = [];
final Map<SyntacticEntity, String> dartTypes = {};
@override
void visitCastExpression(CastExpression e, void arg) {
final match = FoundReferencesInSql.enumRegex.firstMatch(e.typeName);
if (match != null) {
// Found `ENUMNAME(x)`, where `x` is a Dart type that we might want to
// resolve later.
dartTypes[e] = match.group(2)!;
}
super.visitCastExpression(e, arg);
}
@override
void visitColumnConstraint(ColumnConstraint e, void arg) {
if (e is MappedBy) {
dartExpressions.add(e.mapper.dartCode);
} else {
super.visitColumnConstraint(e, arg);
}
}
}

View File

@ -11,7 +11,7 @@ class DriftIndexResolver extends DriftElementResolver<DiscoveredDriftIndex> {
@override
Future<DriftIndex> resolve() async {
final stmt = discovered.sqlNode;
final references = await resolveSqlReferences(stmt);
final references = await resolveTableReferences(stmt);
final engine = newEngineWithTables(references);
final source = (file.discovery as DiscoveredDriftFile).originalSource;

View File

@ -36,15 +36,25 @@ class DriftQueryResolver
}
}
final resolvedDartTypes = <String, DartType>{};
for (final entry in references.dartTypes.entries) {
final dartType = await findDartTypeOrReportError(entry.value, entry.key);
if (dartType != null) {
resolvedDartTypes[entry.value] = dartType;
}
}
return DefinedSqlQuery(
discovered.ownId,
DriftDeclaration.driftFile(stmt, file.ownUri),
references: references,
references: references.referencedElements,
sql: source.substring(stmt.firstPosition, stmt.lastPosition),
sqlOffset: stmt.firstPosition,
mode: isCreate ? QueryMode.atCreate : QueryMode.regular,
resultClassName: resultClassName,
existingDartType: existingType,
dartTypes: resolvedDartTypes,
dartTokens: references.dartExpressions,
);
}
}

View File

@ -1,8 +1,12 @@
import 'package:analyzer/dart/element/type.dart';
import 'package:drift/drift.dart' show DriftSqlType;
import 'package:sqlparser/sqlparser.dart';
import '../../../driver/driver.dart';
import '../../../results/results.dart';
import '../../dart/helper.dart';
import '../../shared/dart_types.dart';
import '../element_resolver.dart';
/// Converts tables and types between `drift_dev` internal reprensentation and
/// the one used by the `sqlparser` package.
@ -140,6 +144,35 @@ class TypeMapping {
}
}
/// Creates a [TypeFromText] implementation that will look up type converters
/// for `ENUM` and `ENUMNAME` column.
TypeFromText enumColumnFromText(
Map<String, DartType> knownTypes, KnownDriftTypes helper) {
return (String typeName) {
final match = FoundReferencesInSql.enumRegex.firstMatch(typeName);
if (match != null) {
final isStoredAsName = match.group(1) != null;
final type = knownTypes[match.group(2)];
if (type != null) {
return ResolvedType(
type: isStoredAsName ? BasicType.text : BasicType.int,
hint: TypeConverterHint(
readEnumConverter(
(_) {},
type,
isStoredAsName ? EnumType.textEnum : EnumType.intEnum,
helper,
)..owningColumn = null,
),
);
}
}
return null;
};
}
class TypeConverterHint extends TypeHint {
final AppliedTypeConverter converter;

View File

@ -17,9 +17,6 @@ import 'element_resolver.dart';
import 'sqlparser/drift_lints.dart';
class DriftTableResolver extends DriftElementResolver<DiscoveredDriftTable> {
static final RegExp _enumRegex =
RegExp(r'^enum(name)?\((\w+)\)$', caseSensitive: false);
DriftTableResolver(super.file, super.discovered, super.resolver, super.state);
@override
@ -55,8 +52,9 @@ class DriftTableResolver extends DriftElementResolver<DiscoveredDriftTable> {
final typeName = column.definition?.typeName;
final enumIndexMatch =
typeName != null ? _enumRegex.firstMatch(typeName) : null;
final enumIndexMatch = typeName != null
? FoundReferencesInSql.enumRegex.firstMatch(typeName)
: null;
if (enumIndexMatch != null) {
final dartTypeName = enumIndexMatch.group(2)!;
final dartType = await findDartTypeOrReportError(

View File

@ -16,7 +16,7 @@ class DriftTriggerResolver
@override
Future<DriftTrigger> resolve() async {
final stmt = discovered.sqlNode;
final references = await resolveSqlReferences(stmt);
final references = await resolveTableReferences(stmt);
final engine = newEngineWithTables(references);
final source = (file.discovery as DiscoveredDriftFile).originalSource;

View File

@ -17,11 +17,21 @@ class DriftViewResolver extends DriftElementResolver<DiscoveredDriftView> {
@override
Future<DriftView> resolve() async {
final stmt = discovered.sqlNode;
final references = await resolveSqlReferences(stmt);
final allReferences = await resolveSqlReferences(stmt);
final references = allReferences.referencedElements;
final engine = newEngineWithTables(references);
final source = (file.discovery as DiscoveredDriftFile).originalSource;
final context = engine.analyzeNode(stmt, source);
final context = engine.analyzeNode(
stmt,
source,
stmtOptions: AnalyzeStatementOptions(
resolveTypeFromText: await createTypeResolver(
allReferences,
await resolver.driver.loadKnownTypes(),
),
),
);
reportLints(context, references);
final parserView = engine.schemaReader.readView(context, stmt);
@ -51,6 +61,7 @@ class DriftViewResolver extends DriftElementResolver<DiscoveredDriftView> {
if (type != null && type.hint is TypeConverterHint) {
converter ??= (type.hint as TypeConverterHint).converter;
ownsConverter = converter.owningColumn == null;
}
final driftColumn = DriftColumn(

View File

@ -1,3 +1,4 @@
import 'package:drift_dev/src/analysis/resolver/drift/sqlparser/mapping.dart';
import 'package:sqlparser/sqlparser.dart';
import '../../utils/entity_reference_sorter.dart';
@ -6,6 +7,7 @@ import '../driver/error.dart';
import '../driver/state.dart';
import '../results/file_results.dart';
import '../results/results.dart';
import 'dart/helper.dart';
import 'queries/query_analyzer.dart';
import 'queries/required_variables.dart';
@ -90,7 +92,8 @@ class FileAnalyzer {
AstPreparingVisitor.resolveIndexOfVariables(
stmt.allDescendants.whereType<Variable>().toList());
final options = _createOptionsAndVars(engine, stmt);
final options =
_createOptionsAndVars(engine, stmt, element, knownTypes);
final analysisResult = engine.analyzeNode(stmt.statement, source,
stmtOptions: options.options);
@ -127,7 +130,11 @@ class FileAnalyzer {
}
_OptionsAndRequiredVariables _createOptionsAndVars(
SqlEngine engine, DeclaredStatement stmt) {
SqlEngine engine,
DeclaredStatement stmt,
DefinedSqlQuery query,
KnownDriftTypes helper,
) {
final reader = engine.schemaReader;
final indexedHints = <int, ResolvedType>{};
final namedHints = <String, ResolvedType>{};
@ -168,6 +175,7 @@ class FileAnalyzer {
indexedVariableTypes: indexedHints,
namedVariableTypes: namedHints,
defaultValuesForPlaceholder: defaultValues,
resolveTypeFromText: enumColumnFromText(query.dartTypes, helper),
),
RequiredVariables(requiredIndex, requiredName),
);

View File

@ -53,7 +53,7 @@ class QueryAnalyzer {
final RequiredVariables requiredVariables;
final Map<String, DriftElement> referencesByName;
final Map<InlineDartToken, dart.Expression> _resolvedExpressions = {};
final Map<String, dart.Expression> _resolvedExpressions = {};
/// Found tables and views found need to be shared between the query and
/// all subqueries to not muss any updates when watching query.
@ -95,7 +95,7 @@ class QueryAnalyzer {
/// messages.
Future<SqlQuery> analyze(DriftQueryDeclaration declaration,
{DriftTableName? sourceForCustomName}) async {
await _resolveDartTokens();
await _resolveDartTokens(declaration);
final nestedAnalyzer = NestedQueryAnalyzer();
NestedQueriesContainer? nestedScope;
@ -143,25 +143,26 @@ class QueryAnalyzer {
return query;
}
Future<void> _resolveDartTokens() async {
for (final mappedBy in context.root.allDescendants.whereType<MappedBy>()) {
try {
final expression = await driver.backend.resolveExpression(
fromFile.ownUri,
mappedBy.mapper.dartCode,
fromFile.discovery?.importDependencies
.map((e) => e.toString())
.where((e) => e.endsWith('.dart')) ??
const Iterable.empty(),
);
Future<void> _resolveDartTokens(DriftQueryDeclaration declaration) async {
if (declaration is DefinedSqlQuery) {
for (final expression in declaration.dartTokens) {
try {
final resolved = await driver.backend.resolveExpression(
fromFile.ownUri,
expression,
fromFile.discovery?.importDependencies
.map((e) => e.toString())
.where((e) => e.endsWith('.dart')) ??
const Iterable.empty(),
);
_resolvedExpressions[mappedBy.mapper] = expression;
} on CannotReadExpressionException catch (e) {
lints.add(AnalysisError(
type: AnalysisErrorType.other,
message: 'Could not read expression: ${e.msg}',
relevantNode: mappedBy.mapper,
));
_resolvedExpressions[expression] = resolved;
} on CannotReadExpressionException catch (e) {
lints.add(AnalysisError(
type: AnalysisErrorType.other,
message: 'Could not read expression: ${e.msg}',
));
}
}
}
}
@ -291,7 +292,7 @@ class QueryAnalyzer {
if (type?.hint is TypeConverterHint) {
converter = (type!.hint as TypeConverterHint).converter;
} else if (mappedBy != null) {
final dartExpression = _resolvedExpressions[mappedBy.mapper];
final dartExpression = _resolvedExpressions[mappedBy.mapper.dartCode];
if (dartExpression != null) {
converter = readTypeConverter(
knownTypes.helperLibrary,

View File

@ -58,6 +58,13 @@ class DefinedSqlQuery extends DriftElement implements DriftQueryDeclaration {
@override
String get name => id.name;
/// All in-line Dart source code literals embedded into the query.
final List<String> dartTokens;
/// All Dart type names embedded into the query, for instance in a
/// `CAST(x AS ENUMNAME(MyDartType))` expression.
final Map<String, DartType> dartTypes;
DefinedSqlQuery(
super.id,
super.declaration, {
@ -66,6 +73,8 @@ class DefinedSqlQuery extends DriftElement implements DriftQueryDeclaration {
required this.sqlOffset,
this.resultClassName,
this.existingDartType,
this.dartTokens = const [],
this.dartTypes = const {},
this.mode = QueryMode.regular,
});
}

View File

@ -70,8 +70,13 @@ class ElementSerializer {
'sql': element.sql,
'offset': element.sqlOffset,
'result_class': element.resultClassName,
'eixsting_type': _serializeType(element.existingDartType),
'existing_type': _serializeType(element.existingDartType),
'mode': element.mode.name,
'dart_tokens': element.dartTokens,
'dart_types': {
for (final entry in element.dartTypes.entries)
entry.key: _serializeType(entry.value)
},
};
} else if (element is DriftTrigger) {
additionalInformation = {
@ -515,7 +520,13 @@ class ElementDeserializer {
createStmt: json['sql'] as String,
);
case 'query':
final rawExistingType = json['eixsting_type'];
final rawExistingType = json['existing_type'];
final types = <String, DartType>{};
for (final entry in (json['dart_types'] as Map).entries) {
types[entry.key as String] =
await _readDartType(id.libraryUri, entry.value as int);
}
return DefinedSqlQuery(
id,
@ -528,6 +539,8 @@ class ElementDeserializer {
? await _readDartType(id.libraryUri, rawExistingType as int)
: null,
mode: QueryMode.values.byName(json['mode'] as String),
dartTokens: (json['dart_tokens'] as List).cast(),
dartTypes: types,
);
case 'trigger':
DriftTable? on;

View File

@ -4,6 +4,7 @@ import 'package:drift/drift.dart' show SqlDialect;
import 'package:sqlparser/sqlparser.dart';
import 'package:sqlparser/utils/node_to_text.dart';
import '../../analysis/resolver/drift/element_resolver.dart';
import '../../analysis/results/results.dart';
import '../../analysis/options.dart';
import '../../utils/string_escaper.dart';
@ -117,6 +118,13 @@ class SqlWriter extends NodeSqlBuilder {
overriddenTypeName = options.storeDateTimeValuesAsText ? 'TEXT' : 'INT';
} else if (hint is IsBoolean) {
overriddenTypeName = 'INT';
} else {
final enumMatch = FoundReferencesInSql.enumRegex.firstMatch(e.typeName);
if (enumMatch != null) {
final isStoredAsText = enumMatch.group(1) != null;
overriddenTypeName = isStoredAsText ? 'TEXT' : 'INT';
}
}
if (overriddenTypeName != null) {

View File

@ -222,6 +222,57 @@ TypeConverter<Object, int> createConverter() => throw UnimplementedError();
);
});
test('supports enum columns', () async {
final backend = TestBackend.inTest({
'a|lib/a.drift': '''
import 'enums.dart';
CREATE VIEW foo AS SELECT
1 AS c1,
CAST(1 AS ENUM(MyEnum)) AS c2,
CAST('foo' AS ENUMNAME(MyEnum)) AS c3;
''',
'a|lib/enums.dart': '''
enum MyEnum {
foo, bar
}
''',
});
final state = await backend.analyze('package:a/a.drift');
backend.expectNoErrors();
final view = state.analyzedElements.single as DriftView;
final c1 = view.columns[0];
final c2 = view.columns[1];
final c3 = view.columns[2];
expect(c1.typeConverter, isNull);
expect(
c2.typeConverter,
isA<AppliedTypeConverter>()
.having((e) => e.isDriftEnumTypeConverter, 'isDriftEnumTypeConverter',
isTrue)
.having((e) => e.owningColumn, 'owningColumn', c2),
);
expect(
c3.typeConverter,
isA<AppliedTypeConverter>()
.having((e) => e.isDriftEnumTypeConverter, 'isDriftEnumTypeConverter',
isTrue)
.having((e) => e.owningColumn, 'owningColumn', c3),
);
expect(
view.source,
isA<SqlViewSource>().having(
(e) => e.sqlCreateViewStmt,
'sqlCreateViewStmt',
"CREATE VIEW foo AS SELECT 1 AS c1, CAST(1 AS INT) AS c2, CAST('foo' AS TEXT) AS c3;",
),
);
});
group('desugars cast', () {
Future<void> expectView(
String definition,

View File

@ -311,4 +311,40 @@ a: SELECT CAST(1 AS BOOLEAN) AS a, CAST(2 AS DATETIME) as b;
.having((e) => e.sqlType, 'sqlType', DriftSqlType.dateTime),
]);
});
test('can cast to enum type', () async {
final backend = TestBackend.inTest({
'a|lib/a.drift': '''
import 'enum.dart';
a: SELECT
1 AS c1,
CAST(1 AS ENUM(MyEnum)) AS c2,
CAST('foo' AS ENUMNAME(MyEnum)) AS c3;
''',
'a|lib/enum.dart': '''
enum MyEnum {
foo, bar
}
''',
});
final state = await backend.analyze('package:a/a.drift');
backend.expectNoErrors();
final query = state.fileAnalysis!.resolvedQueries.values.single;
final resultSet = query.resultSet!;
final isEnumConverter = isA<AppliedTypeConverter>().having(
(e) => e.isDriftEnumTypeConverter, 'isDriftEnumTypeConverter', isTrue);
expect(resultSet.columns, [
scalarColumn('c1')
.having((e) => e.typeConverter, 'typeConverter', isNull),
scalarColumn('c2')
.having((e) => e.typeConverter, 'typeConverter', isEnumConverter),
scalarColumn('c3')
.having((e) => e.typeConverter, 'typeConverter', isEnumConverter),
]);
});
}

View File

@ -1,5 +1,8 @@
part of 'analysis.dart';
/// Signature of a function that resolves the type of a SQL type literal.
typedef TypeFromText = ResolvedType? Function(String);
/// Options to analyze a sql statement. This can be used if the type of a
/// variable is known from the outside.
class AnalyzeStatementOptions {
@ -10,10 +13,13 @@ class AnalyzeStatementOptions {
/// expression, if set.
final Map<String, Expression> defaultValuesForPlaceholder;
final TypeFromText? resolveTypeFromText;
const AnalyzeStatementOptions({
this.indexedVariableTypes = const {},
this.namedVariableTypes = const {},
this.defaultValuesForPlaceholder = const {},
this.resolveTypeFromText,
});
/// Looks up the defined type for that variable.

View File

@ -10,9 +10,12 @@ class SchemaFromCreateTable {
/// enabled) should be reported as a text column instead of an int column.
final bool driftUseTextForDateTime;
final AnalyzeStatementOptions? statementOptions;
const SchemaFromCreateTable({
this.driftExtensions = false,
this.driftUseTextForDateTime = false,
this.statementOptions,
});
/// Reads a [Table] schema from the [stmt] inducing a table (either a
@ -146,6 +149,12 @@ class SchemaFromCreateTable {
return const ResolvedType(type: BasicType.blob);
}
// S if a custom resolver is installed and yields a type for this column:
final custom = statementOptions?.resolveTypeFromText?.call(typeName);
if (custom != null) {
return custom;
}
final upper = typeName.toUpperCase();
if (upper.contains('INT')) {
return const ResolvedType(type: BasicType.int);

View File

@ -34,12 +34,25 @@ class SqlEngine {
/// The returned reader can be used to read the table structure from a
/// [TableInducingStatement] by using [SchemaFromCreateTable.read].
SchemaFromCreateTable get schemaReader {
return _schemaReader ??= _createSchemaReader(null);
}
SchemaFromCreateTable _createSchemaReader(
AnalyzeStatementOptions? stmtOptions) {
final driftOptions = options.driftOptions;
return _schemaReader ??= SchemaFromCreateTable(
driftExtensions: driftOptions != null,
driftUseTextForDateTime: driftOptions?.storeDateTimesAsText == true,
);
if (stmtOptions != null) {
return SchemaFromCreateTable(
driftExtensions: driftOptions != null,
driftUseTextForDateTime: driftOptions?.storeDateTimesAsText == true,
statementOptions: stmtOptions,
);
} else {
return _schemaReader ??= SchemaFromCreateTable(
driftExtensions: driftOptions != null,
driftUseTextForDateTime: driftOptions?.storeDateTimesAsText == true,
);
}
}
/// Registers the [table], which means that it can later be used in sql
@ -223,8 +236,10 @@ class SqlEngine {
AnalysisContext _createContext(
AstNode node, String sql, AnalyzeStatementOptions? stmtOptions) {
final schemaSupport = _createSchemaReader(stmtOptions);
return AnalysisContext(node, sql, _constructRootScope(), options,
stmtOptions: stmtOptions, schemaSupport: schemaReader);
stmtOptions: stmtOptions, schemaSupport: schemaSupport);
}
void _analyzeContext(AnalysisContext context) {

View File

@ -116,4 +116,25 @@ WITH RECURSIVE
everyElement(isA<ResolveResult>()
.having((e) => e.type?.nullable, 'type.nullable', isTrue)));
});
test('can extract custom type', () {
final engine = SqlEngine();
final content = engine.analyze(
'SELECT CAST(1 AS MyCustomType)',
stmtOptions: AnalyzeStatementOptions(
resolveTypeFromText: expectAsync1(
(typeName) {
expect(typeName, 'MyCustomType');
return ResolvedType.bool();
},
),
),
);
final select = content.root as SelectStatement;
final column = select.resolvedColumns!.single;
expect(content.typeOf(column).type, ResolvedType.bool());
});
}