From 21ff422f19a003b20a08b56dc6bb1721865809b1 Mon Sep 17 00:00:00 2001 From: Simon Binder Date: Sun, 12 Feb 2023 12:44:47 +0100 Subject: [PATCH] Support `@create` queries in modular mode (#2313) --- drift_dev/lib/src/analysis/driver/state.dart | 7 +++-- .../lib/src/analysis/resolver/discover.dart | 2 +- .../lib/src/analysis/results/element.dart | 10 +++---- drift_dev/lib/src/analysis/results/query.dart | 9 ++++++ .../lib/src/backends/build/drift_builder.dart | 21 ++++++++++++++ drift_dev/lib/src/writer/database_writer.dart | 26 ++++++++++------- drift_dev/lib/src/writer/modules.dart | 7 ++++- .../build/build_integration_test.dart | 28 +++++++++++++++++++ examples/modular/lib/database.drift.dart | 15 +++++----- examples/modular/lib/src/user_queries.drift | 1 + .../modular/lib/src/user_queries.drift.dart | 3 ++ .../integration_test/drift_native_test.dart | 1 - 12 files changed, 102 insertions(+), 28 deletions(-) diff --git a/drift_dev/lib/src/analysis/driver/state.dart b/drift_dev/lib/src/analysis/driver/state.dart index 398e4933..a44f322f 100644 --- a/drift_dev/lib/src/analysis/driver/state.dart +++ b/drift_dev/lib/src/analysis/driver/state.dart @@ -63,11 +63,12 @@ class FileState { } bool get _definesQuery { - return analyzedElements.any((e) => e is DefinedSqlQuery) || + return analyzedElements + .any((e) => e is DefinedSqlQuery && e.mode == QueryMode.regular) || // Also check discovery, we might not have analyzed all elements in this // file if it's just an import. - discovery?.locallyDefinedElements - .any((e) => e is DiscoveredDriftStatement) == + discovery?.locallyDefinedElements.any((e) => + e is DiscoveredDriftStatement && e.sqlNode.isRegularQuery) == true; } diff --git a/drift_dev/lib/src/analysis/resolver/discover.dart b/drift_dev/lib/src/analysis/resolver/discover.dart index d7b85013..3819dc6d 100644 --- a/drift_dev/lib/src/analysis/resolver/discover.dart +++ b/drift_dev/lib/src/analysis/resolver/discover.dart @@ -131,7 +131,7 @@ class DiscoverStep { if (declaredName is SimpleName) { name = declaredName.name; } else { - name = 'special:${specialQueryNameCount++}'; + name = '\$drift_${specialQueryNameCount++}'; } pendingElements.add(DiscoveredDriftStatement(_id(name), node)); diff --git a/drift_dev/lib/src/analysis/results/element.dart b/drift_dev/lib/src/analysis/results/element.dart index 51008d90..0608fb84 100644 --- a/drift_dev/lib/src/analysis/results/element.dart +++ b/drift_dev/lib/src/analysis/results/element.dart @@ -87,6 +87,11 @@ abstract class DriftElement { /// - tables included in the `@DriftDatabase` annotation. Iterable get references => const Iterable.empty(); + /// The getter in a generated database accessor referring to this model. + /// + /// Returns null for entities that shouldn't have a getter. + String? get dbGetterName => null; + /// If this element was extracted from a defined Dart class, returns the name /// of that class. AnnotatedDartCode? get definingDartClass { @@ -105,11 +110,6 @@ abstract class DriftSchemaElement extends DriftElement { /// The exact, unaliased name of this element in the database's schema. String get schemaName => id.name; - /// The getter in a generated database accessor referring to this model. - /// - /// Returns null for entities that shouldn't have a getter. - String? get dbGetterName; - static String dbFieldName(String baseName) { return ReCase(baseName).camelCase; } diff --git a/drift_dev/lib/src/analysis/results/query.dart b/drift_dev/lib/src/analysis/results/query.dart index 88f138b0..b0355e1a 100644 --- a/drift_dev/lib/src/analysis/results/query.dart +++ b/drift_dev/lib/src/analysis/results/query.dart @@ -58,6 +58,15 @@ class DefinedSqlQuery extends DriftElement implements DriftQueryDeclaration { @override String get name => id.name; + @override + String? get dbGetterName { + if (mode != QueryMode.regular) { + return DriftSchemaElement.dbFieldName(id.name); + } else { + return null; + } + } + /// All in-line Dart source code literals embedded into the query. final List dartTokens; diff --git a/drift_dev/lib/src/backends/build/drift_builder.dart b/drift_dev/lib/src/backends/build/drift_builder.dart index a4600c7c..f7c97f11 100644 --- a/drift_dev/lib/src/backends/build/drift_builder.dart +++ b/drift_dev/lib/src/backends/build/drift_builder.dart @@ -280,6 +280,27 @@ class _DriftBuildRun { final input = AccessorGenerationInput(result, resolved, const {}, driver); AccessorWriter(input, writer.child()).write(); + } else if (result is DefinedSqlQuery) { + switch (result.mode) { + case QueryMode.regular: + // Ignore, this query will be made available in a generated accessor + // class. + break; + case QueryMode.atCreate: + final resolved = + entrypointState.fileAnalysis?.resolvedQueries[result.id]; + + if (resolved != null) { + writer.leaf() + ..writeDriftRef('OnCreateQuery') + ..write(' get ${result.dbGetterName} => ') + ..write(DatabaseWriter.createOnCreate( + writer.child(), result, resolved)) + ..writeln(';'); + } + + break; + } } } diff --git a/drift_dev/lib/src/writer/database_writer.dart b/drift_dev/lib/src/writer/database_writer.dart index 5a492746..96eba75a 100644 --- a/drift_dev/lib/src/writer/database_writer.dart +++ b/drift_dev/lib/src/writer/database_writer.dart @@ -70,9 +70,9 @@ class DatabaseWriter { firstLeaf.write('$className.connect($conn c): super.connect(c); \n'); } - final entityGetters = {}; + final entityGetters = {}; - for (final entity in elements.whereType()) { + for (final entity in elements.whereType()) { final getterName = entity.dbGetterName; if (getterName != null) { @@ -80,7 +80,7 @@ class DatabaseWriter { // created in the database instance. However, triggers and indices are // generated as a top-level field which is simply imported. if (scope.generationOptions.isModular && - (entity is DriftTrigger || entity is DriftIndex)) { + (entity is! DriftElementWithResultSet)) { final import = dbScope.generatedElement(entity, getterName); entityGetters[entity] = dbScope.dartCode(import); @@ -163,16 +163,14 @@ class DatabaseWriter { ..write('@override\nIterable<$tableInfoType> get allTables => ') ..write('allSchemaEntities.whereType<$tableInfoType>();\n') ..write('@override\nList<$schemaEntity> get allSchemaEntities ') - ..write('=> ['); - - schemaScope + ..write('=> [') ..write(elements .map((e) { - if (e is DefinedSqlQuery && e.mode == QueryMode.atCreate) { + if (e is DefinedSqlQuery && + e.mode == QueryMode.atCreate && + !scope.generationOptions.isModular) { final resolved = input.importedQueries[e]!; - final sql = schemaScope.sqlCode(resolved.root!); - - return 'OnCreateQuery(${asDartLiteral(sql)})'; + return createOnCreate(dbScope, e, resolved); } return entityGetters[e]; @@ -234,6 +232,14 @@ class DatabaseWriter { return '$index(${asDartLiteral(entity.schemaName)}, ${asDartLiteral(sql)})'; } + + static String createOnCreate( + Scope scope, DefinedSqlQuery query, SqlQuery resolved) { + final sql = scope.sqlCode(resolved.root!); + final onCreate = scope.drift('OnCreateQuery'); + + return '$onCreate(${asDartLiteral(sql)})'; + } } class GenerationInput { diff --git a/drift_dev/lib/src/writer/modules.dart b/drift_dev/lib/src/writer/modules.dart index c15b5b39..1392a689 100644 --- a/drift_dev/lib/src/writer/modules.dart +++ b/drift_dev/lib/src/writer/modules.dart @@ -48,9 +48,14 @@ class ModularAccessorWriter { queries = queries.map((k, v) => MapEntry(k, mappedQueries[v] ?? v)); for (final query in queries.entries) { - final queryElement = file.analysis[query.key]?.result; + final queryElement = file.analysis[query.key]?.result as DefinedSqlQuery?; if (queryElement != null) { referencedElements.addAll(queryElement.references); + + if (queryElement.mode != QueryMode.regular) { + // Not a query for which a public API should exist + continue; + } } final value = query.value; diff --git a/drift_dev/test/backends/build/build_integration_test.dart b/drift_dev/test/backends/build/build_integration_test.dart index a5c17fdb..b81d8ca1 100644 --- a/drift_dev/test/backends/build/build_integration_test.dart +++ b/drift_dev/test/backends/build/build_integration_test.dart @@ -319,4 +319,32 @@ TypeConverter myConverter() => throw UnimplementedError(); result, ); }); + + test('supports @create queries in modular generation', () async { + final result = await emulateDriftBuild( + inputs: { + 'a|lib/a.drift': ''' +CREATE TABLE foo (bar INTEGER PRIMARY KEY); + +@create: INSERT INTO foo VALUES (1); +''', + 'a|lib/db.dart': r''' +import 'package:drift/drift.dart'; + +import 'db.drift.dart'; + +@DriftDatabase(include: {'a.drift'}) +class Database extends $Database {} +''', + }, + modularBuild: true, + logger: loggerThat(neverEmits(anything)), + ); + + checkOutputs({ + 'a|lib/a.drift.dart': + decodedMatches(contains(r'OnCreateQuery get $drift0 => ')), + 'a|lib/db.drift.dart': decodedMatches(contains(r'.$drift0];')) + }, result.dartOutputs, result); + }); } diff --git a/examples/modular/lib/database.drift.dart b/examples/modular/lib/database.drift.dart index 754de1e4..60db061f 100644 --- a/examples/modular/lib/database.drift.dart +++ b/examples/modular/lib/database.drift.dart @@ -3,9 +3,9 @@ import 'package:drift/drift.dart' as i0; import 'package:modular/src/users.drift.dart' as i1; import 'package:modular/src/posts.drift.dart' as i2; import 'package:modular/src/search.drift.dart' as i3; -import 'package:modular/accessor.dart' as i4; -import 'package:modular/database.dart' as i5; -import 'package:modular/src/user_queries.drift.dart' as i6; +import 'package:modular/src/user_queries.drift.dart' as i4; +import 'package:modular/accessor.dart' as i5; +import 'package:modular/database.dart' as i6; import 'package:drift/internal/modular.dart' as i7; abstract class $Database extends i0.GeneratedDatabase { @@ -16,9 +16,9 @@ abstract class $Database extends i0.GeneratedDatabase { late final i2.Likes likes = i2.Likes(this); late final i1.Follows follows = i1.Follows(this); late final i1.PopularUsers popularUsers = i1.PopularUsers(this); - late final i4.MyAccessor myAccessor = i4.MyAccessor(this as i5.Database); - i6.UserQueriesDrift get userQueriesDrift => i7.ReadDatabaseContainer(this) - .accessor(i6.UserQueriesDrift.new); + late final i5.MyAccessor myAccessor = i5.MyAccessor(this as i6.Database); + i4.UserQueriesDrift get userQueriesDrift => i7.ReadDatabaseContainer(this) + .accessor(i4.UserQueriesDrift.new); i3.SearchDrift get searchDrift => i7.ReadDatabaseContainer(this) .accessor(i3.SearchDrift.new); @override @@ -34,7 +34,8 @@ abstract class $Database extends i0.GeneratedDatabase { i3.postsDelete, likes, follows, - popularUsers + popularUsers, + i4.$drift0 ]; @override i0.StreamQueryUpdateRules get streamUpdateRules => diff --git a/examples/modular/lib/src/user_queries.drift b/examples/modular/lib/src/user_queries.drift index ddb76414..fb8aaeb4 100644 --- a/examples/modular/lib/src/user_queries.drift +++ b/examples/modular/lib/src/user_queries.drift @@ -3,3 +3,4 @@ import 'users.drift'; findUsers($predicate = TRUE): SELECT * FROM users WHERE $predicate; findPopularUsers: SELECT * FROM popular_users; follow: INSERT INTO follows VALUES (?, ?); +@create: UPDATE users SET id = id + 1; diff --git a/examples/modular/lib/src/user_queries.drift.dart b/examples/modular/lib/src/user_queries.drift.dart index 8fe27dbd..76d20022 100644 --- a/examples/modular/lib/src/user_queries.drift.dart +++ b/examples/modular/lib/src/user_queries.drift.dart @@ -3,6 +3,9 @@ import 'package:drift/drift.dart' as i0; import 'package:drift/internal/modular.dart' as i1; import 'package:modular/src/users.drift.dart' as i2; +i0.OnCreateQuery get $drift0 => + i0.OnCreateQuery('UPDATE users SET id = id + 1'); + class UserQueriesDrift extends i1.ModularAccessor { UserQueriesDrift(i0.GeneratedDatabase db) : super(db); i0.Selectable findUsers({FindUsers$predicate? predicate}) { diff --git a/extras/integration_tests/ffi_on_flutter/integration_test/drift_native_test.dart b/extras/integration_tests/ffi_on_flutter/integration_test/drift_native_test.dart index bb193020..e994401f 100644 --- a/extras/integration_tests/ffi_on_flutter/integration_test/drift_native_test.dart +++ b/extras/integration_tests/ffi_on_flutter/integration_test/drift_native_test.dart @@ -1,5 +1,4 @@ import 'dart:io'; -import 'dart:ui'; import 'package:drift/isolate.dart'; import 'package:drift/native.dart';