Add multi-dialect generation option

This commit is contained in:
Simon Binder 2023-07-26 23:59:06 +02:00
parent 94c4c1a8e0
commit 8d3f490604
No known key found for this signature in database
GPG Key ID: 7891917E4147B8C0
28 changed files with 436 additions and 101 deletions

View File

@ -110,6 +110,36 @@ in 3.34, so an error would be reported.
Currently, the generator can't provide compatibility checks for versions below 3.34, which is the Currently, the generator can't provide compatibility checks for versions below 3.34, which is the
minimum version needed in options. minimum version needed in options.
### Multi-dialect code generation
Thanks to community contributions, drift has in-progress support for Postgres and MariaDB.
You can change the `dialect` option to `postgres` or `mariadb` to generate code for those
database management systems.
In some cases, your generated code might have to support more than one DBMS. For instance,
you might want to share database code between your backend and a Flutter app. Or maybe
you're writing a server that should be able to talk to both MariaDB and Postgres, depending
on what the operator prefers.
Drift can generate code for multiple dialects - in that case, the right SQL will be chosen
at runtime when it makes a difference.
To enable this feature, remove the `dialect` option in the `sql` block and replace it with
a list of `dialects`:
```yaml
targets:
$default:
builders:
drift_dev:
options:
sql:
dialect:
- sqlite
- postgres
options:
version: "3.34"
```
### Available extensions ### Available extensions
__Note__: This enables extensions in the analyzer for custom queries only. For instance, when the `json1` extension is __Note__: This enables extensions in the analyzer for custom queries only. For instance, when the `json1` extension is

View File

@ -534,7 +534,7 @@ class $TodoCategoryItemCountView
@override @override
String get entityName => 'todo_category_item_count'; String get entityName => 'todo_category_item_count';
@override @override
String? get createViewStmt => null; Map<SqlDialect, String>? get createViewStatements => null;
@override @override
$TodoCategoryItemCountView get asDslTable => this; $TodoCategoryItemCountView get asDslTable => this;
@override @override
@ -639,7 +639,7 @@ class $TodoItemWithCategoryNameViewView extends ViewInfo<
@override @override
String get entityName => 'customViewName'; String get entityName => 'customViewName';
@override @override
String? get createViewStmt => null; Map<SqlDialect, String>? get createViewStatements => null;
@override @override
$TodoItemWithCategoryNameViewView get asDslTable => this; $TodoItemWithCategoryNameViewView get asDslTable => this;
@override @override

View File

@ -157,6 +157,10 @@ class VersionedView implements ViewInfo<HasResultSet, QueryRow>, HasResultSet {
@override @override
final String createViewStmt; final String createViewStmt;
@override
Map<SqlDialect, String>? get createViewStatements =>
{SqlDialect.sqlite: createViewStmt};
@override @override
final List<GeneratedColumn> $columns; final List<GeneratedColumn> $columns;

View File

@ -11,6 +11,8 @@ class CustomExpression<D extends Object> extends Expression<D> {
/// The SQL of this expression /// The SQL of this expression
final String content; final String content;
final Map<SqlDialect, String>? _dialectSpecificContent;
/// Additional tables that this expression is watching. /// Additional tables that this expression is watching.
/// ///
/// When this expression is used in a stream query, the stream will update /// When this expression is used in a stream query, the stream will update
@ -24,11 +26,25 @@ class CustomExpression<D extends Object> extends Expression<D> {
/// Constructs a custom expression by providing the raw sql [content]. /// Constructs a custom expression by providing the raw sql [content].
const CustomExpression(this.content, const CustomExpression(this.content,
{this.watchedTables = const [], this.precedence = Precedence.unknown}); {this.watchedTables = const [], this.precedence = Precedence.unknown})
: _dialectSpecificContent = null;
/// Constructs a custom expression providing the raw SQL in [content] depending
/// on the SQL dialect when this expression is built.
const CustomExpression.dialectSpecific(Map<SqlDialect, String> content,
{this.watchedTables = const [], this.precedence = Precedence.unknown})
: _dialectSpecificContent = content,
content = '';
@override @override
void writeInto(GenerationContext context) { void writeInto(GenerationContext context) {
final dialectSpecific = _dialectSpecificContent;
if (dialectSpecific != null) {
} else {
context.buffer.write(content); context.buffer.write(content);
}
context.watchedTables.addAll(watchedTables); context.watchedTables.addAll(watchedTables);
} }

View File

@ -1,6 +1,11 @@
@internal
library;
import 'package:drift/drift.dart'; import 'package:drift/drift.dart';
/// Utilities for writing the definition of a result set into a query. import 'package:meta/meta.dart';
/// Internal utilities for building queries that aren't exported.
extension WriteDefinition on GenerationContext { extension WriteDefinition on GenerationContext {
/// Writes the result set to this context, suitable to implement `FROM` /// Writes the result set to this context, suitable to implement `FROM`
/// clauses and joins. /// clauses and joins.
@ -16,4 +21,21 @@ extension WriteDefinition on GenerationContext {
watchedTables.add(resultSet); watchedTables.add(resultSet);
} }
} }
/// Returns a suitable SQL string in [sql] based on the current dialect.
String pickForDialect(Map<SqlDialect, String> sql) {
assert(
sql.containsKey(dialect),
'Tried running SQL optimized for the following dialects: ${sql.keys.join}. '
'However, the database is running $dialect. Has that dialect been added '
'to the `dialects` drift builder option?',
);
final found = sql[dialect];
if (found != null) {
return found;
}
return sql.values.first; // Fallback
}
} }

View File

@ -87,7 +87,7 @@ class Migrator {
} else if (entity is Index) { } else if (entity is Index) {
await createIndex(entity); await createIndex(entity);
} else if (entity is OnCreateQuery) { } else if (entity is OnCreateQuery) {
await _issueCustomQuery(entity.sql, const []); await _issueQueryByDialect(entity.sqlByDialect);
} else if (entity is ViewInfo) { } else if (entity is ViewInfo) {
await createView(entity); await createView(entity);
} else { } else {
@ -363,19 +363,19 @@ class Migrator {
/// Executes the `CREATE TRIGGER` statement that created the [trigger]. /// Executes the `CREATE TRIGGER` statement that created the [trigger].
Future<void> createTrigger(Trigger trigger) { Future<void> createTrigger(Trigger trigger) {
return _issueCustomQuery(trigger.createTriggerStmt, const []); return _issueQueryByDialect(trigger.createStatementsByDialect);
} }
/// Executes a `CREATE INDEX` statement to create the [index]. /// Executes a `CREATE INDEX` statement to create the [index].
Future<void> createIndex(Index index) { Future<void> createIndex(Index index) {
return _issueCustomQuery(index.createIndexStmt, const []); return _issueQueryByDialect(index.createStatementsByDialect);
} }
/// Executes a `CREATE VIEW` statement to create the [view]. /// Executes a `CREATE VIEW` statement to create the [view].
Future<void> createView(ViewInfo view) async { Future<void> createView(ViewInfo view) async {
final stmt = view.createViewStmt; final stmts = view.createViewStatements;
if (stmt != null) { if (stmts != null) {
await _issueCustomQuery(stmt, const []); await _issueQueryByDialect(stmts);
} else if (view.query != null) { } else if (view.query != null) {
final context = GenerationContext.fromDb(_db, supportsVariables: false); final context = GenerationContext.fromDb(_db, supportsVariables: false);
final columnNames = view.$columns.map((e) => e.escapedName).join(', '); final columnNames = view.$columns.map((e) => e.escapedName).join(', ');
@ -478,6 +478,11 @@ class Migrator {
return _issueCustomQuery(sql, args); return _issueCustomQuery(sql, args);
} }
Future<void> _issueQueryByDialect(Map<SqlDialect, String> sql) {
final context = _createContext();
return _issueCustomQuery(context.pickForDialect(sql), const []);
}
Future<void> _issueCustomQuery(String sql, [List<dynamic>? args]) { Future<void> _issueCustomQuery(String sql, [List<dynamic>? args]) {
return _db.customStatement(sql, args); return _db.customStatement(sql, args);
} }

View File

@ -17,15 +17,26 @@ abstract class DatabaseSchemaEntity {
/// [sqlite-docs]: https://sqlite.org/lang_createtrigger.html /// [sqlite-docs]: https://sqlite.org/lang_createtrigger.html
/// [sql-tut]: https://www.sqlitetutorial.net/sqlite-trigger/ /// [sql-tut]: https://www.sqlitetutorial.net/sqlite-trigger/
class Trigger extends DatabaseSchemaEntity { class Trigger extends DatabaseSchemaEntity {
/// The `CREATE TRIGGER` sql statement that can be used to create this
/// trigger.
final String createTriggerStmt;
@override @override
final String entityName; final String entityName;
/// The `CREATE TRIGGER` sql statement that can be used to create this
/// trigger.
@Deprecated('Use createStatementsByDialect instead')
String get createTriggerStmt => createStatementsByDialect.values.first;
/// The `CREATE TRIGGER` SQL statements used to create this trigger, accessible
/// for each dialect enabled when generating code.
final Map<SqlDialect, String> createStatementsByDialect;
/// Creates a trigger representation by the [createTriggerStmt] and its /// Creates a trigger representation by the [createTriggerStmt] and its
/// [entityName]. Mainly used by generated code. /// [entityName]. Mainly used by generated code.
Trigger(this.createTriggerStmt, this.entityName); Trigger(String createTriggerStmt, String entityName)
: this.byDialect(entityName, {SqlDialect.sqlite: createTriggerStmt});
/// Creates the trigger model from its [entityName] in the schema and all
/// [createStatementsByDialect] for the supported dialects.
Trigger.byDialect(this.entityName, this.createStatementsByDialect);
} }
/// A sqlite index on columns or expressions. /// A sqlite index on columns or expressions.
@ -40,11 +51,21 @@ class Index extends DatabaseSchemaEntity {
final String entityName; final String entityName;
/// The `CREATE INDEX` sql statement that can be used to create this index. /// The `CREATE INDEX` sql statement that can be used to create this index.
final String createIndexStmt; @Deprecated('Use createStatementsByDialect instead')
String get createIndexStmt => createStatementsByDialect.values.first;
/// The `CREATE INDEX` SQL statements used to create this index, accessible
/// for each dialect enabled when generating code.
final Map<SqlDialect, String> createStatementsByDialect;
/// Creates an index model by the [createIndexStmt] and its [entityName]. /// Creates an index model by the [createIndexStmt] and its [entityName].
/// Mainly used by generated code. /// Mainly used by generated code.
Index(this.entityName, this.createIndexStmt); Index(this.entityName, String createIndexStmt)
: createStatementsByDialect = {SqlDialect.sqlite: createIndexStmt};
/// Creates an index model by its [entityName] used in the schema and the
/// `CREATE INDEX` statements for each supported dialect.
Index.byDialect(this.entityName, this.createStatementsByDialect);
} }
/// An internal schema entity to run an sql statement when the database is /// An internal schema entity to run an sql statement when the database is
@ -61,10 +82,19 @@ class Index extends DatabaseSchemaEntity {
/// drift file. /// drift file.
class OnCreateQuery extends DatabaseSchemaEntity { class OnCreateQuery extends DatabaseSchemaEntity {
/// The sql statement that should be run in the default `onCreate` clause. /// The sql statement that should be run in the default `onCreate` clause.
final String sql; @Deprecated('Use sqlByDialect instead')
String get sql => sqlByDialect.values.first;
/// The SQL statement to run, indexed by the dialect used in the database.
final Map<SqlDialect, String> sqlByDialect;
/// Create a query that will be run in the default `onCreate` migration. /// Create a query that will be run in the default `onCreate` migration.
OnCreateQuery(this.sql); OnCreateQuery(String sql) : this.byDialect({SqlDialect.sqlite: sql});
/// Creates the entity of a query to run in the default `onCreate` migration.
///
/// The migrator will lookup a suitable query from the [sqlByDialect] map.
OnCreateQuery.byDialect(this.sqlByDialect);
@override @override
String get entityName => r'$internal$'; String get entityName => r'$internal$';

View File

@ -17,7 +17,14 @@ abstract class ViewInfo<Self extends HasResultSet, Row>
/// The `CREATE VIEW` sql statement that can be used to create this view. /// The `CREATE VIEW` sql statement that can be used to create this view.
/// ///
/// This will be null if the view was defined in Dart. /// This will be null if the view was defined in Dart.
String? get createViewStmt; @Deprecated('Use createViewStatements instead')
String? get createViewStmt => createViewStatements?.values.first;
/// The `CREATE VIEW` sql statement that can be used to create this view,
/// depending on the dialect used by the current database.
///
/// This will be null if the view was defined in Dart.
Map<SqlDialect, String>? get createViewStatements;
/// Predefined query from `View.as()` /// Predefined query from `View.as()`
/// ///

View File

@ -259,6 +259,32 @@ void main() {
)); ));
}); });
}); });
group('dialect-specific', () {
Map<SqlDialect, String> statements(String base) {
return {
for (final dialect in SqlDialect.values) dialect: '$base $dialect',
};
}
for (final dialect in [SqlDialect.sqlite, SqlDialect.postgres]) {
test('with dialect $dialect', () async {
final executor = MockExecutor();
when(executor.dialect).thenReturn(dialect);
final db = TodoDb(executor);
final migrator = db.createMigrator();
await migrator.create(Trigger.byDialect('a', statements('trigger')));
await migrator.create(Index.byDialect('a', statements('index')));
await migrator.create(OnCreateQuery.byDialect(statements('@')));
verify(executor.runCustom('trigger $dialect', []));
verify(executor.runCustom('index $dialect', []));
verify(executor.runCustom('@ $dialect', []));
});
}
});
} }
final class _FakeSchemaVersion extends VersionedSchema { final class _FakeSchemaVersion extends VersionedSchema {

View File

@ -1600,8 +1600,10 @@ class MyView extends ViewInfo<MyView, MyViewData> implements HasResultSet {
@override @override
String get entityName => 'my_view'; String get entityName => 'my_view';
@override @override
String get createViewStmt => Map<SqlDialect, String> get createViewStatements => {
'CREATE VIEW my_view AS SELECT * FROM config WHERE sync_state = 2'; SqlDialect.sqlite:
'CREATE VIEW my_view AS SELECT * FROM config WHERE sync_state = 2',
};
@override @override
MyView get asDslTable => this; MyView get asDslTable => this;
@override @override

View File

@ -641,15 +641,12 @@ class $UsersTable extends Users with TableInfo<$UsersTable, User> {
static const VerificationMeta _isAwesomeMeta = static const VerificationMeta _isAwesomeMeta =
const VerificationMeta('isAwesome'); const VerificationMeta('isAwesome');
@override @override
late final GeneratedColumn<bool> isAwesome = late final GeneratedColumn<bool> isAwesome = GeneratedColumn<bool>(
GeneratedColumn<bool>('is_awesome', aliasedName, false, 'is_awesome', aliasedName, false,
type: DriftSqlType.bool, type: DriftSqlType.bool,
requiredDuringInsert: false, requiredDuringInsert: false,
defaultConstraints: GeneratedColumn.constraintsDependsOnDialect({ defaultConstraints:
SqlDialect.sqlite: 'CHECK ("is_awesome" IN (0, 1))', GeneratedColumn.constraintIsAlways('CHECK ("is_awesome" IN (0, 1))'),
SqlDialect.mysql: '',
SqlDialect.postgres: '',
}),
defaultValue: const Constant(true)); defaultValue: const Constant(true));
static const VerificationMeta _profilePictureMeta = static const VerificationMeta _profilePictureMeta =
const VerificationMeta('profilePicture'); const VerificationMeta('profilePicture');
@ -1549,7 +1546,7 @@ class $CategoryTodoCountViewView
@override @override
String get entityName => 'category_todo_count_view'; String get entityName => 'category_todo_count_view';
@override @override
String? get createViewStmt => null; Map<SqlDialect, String>? get createViewStatements => null;
@override @override
$CategoryTodoCountViewView get asDslTable => this; $CategoryTodoCountViewView get asDslTable => this;
@override @override
@ -1660,7 +1657,7 @@ class $TodoWithCategoryViewView
@override @override
String get entityName => 'todo_with_category_view'; String get entityName => 'todo_with_category_view';
@override @override
String? get createViewStmt => null; Map<SqlDialect, String>? get createViewStatements => null;
@override @override
$TodoWithCategoryViewView get asDslTable => this; $TodoWithCategoryViewView get asDslTable => this;
@override @override

View File

@ -124,7 +124,7 @@ class DriftOptions {
this.modules = const [], this.modules = const [],
this.sqliteAnalysisOptions, this.sqliteAnalysisOptions,
this.storeDateTimeValuesAsText = false, this.storeDateTimeValuesAsText = false,
this.dialect = const DialectOptions(SqlDialect.sqlite, null), this.dialect = const DialectOptions(null, [SqlDialect.sqlite], null),
this.caseFromDartToSql = CaseFromDartToSql.snake, this.caseFromDartToSql = CaseFromDartToSql.snake,
this.writeToColumnsMixins = false, this.writeToColumnsMixins = false,
this.fatalWarnings = false, this.fatalWarnings = false,
@ -189,7 +189,18 @@ class DriftOptions {
/// Whether the [module] has been enabled in this configuration. /// Whether the [module] has been enabled in this configuration.
bool hasModule(SqlModule module) => effectiveModules.contains(module); bool hasModule(SqlModule module) => effectiveModules.contains(module);
SqlDialect get effectiveDialect => dialect?.dialect ?? SqlDialect.sqlite; List<SqlDialect> get supportedDialects {
final dialects = dialect?.dialects;
final singleDialect = dialect?.dialect;
if (dialects != null) {
return dialects;
} else if (singleDialect != null) {
return [singleDialect];
} else {
return const [SqlDialect.sqlite];
}
}
/// The assumed sqlite version used when analyzing queries. /// The assumed sqlite version used when analyzing queries.
SqliteVersion get sqliteVersion { SqliteVersion get sqliteVersion {
@ -201,10 +212,11 @@ class DriftOptions {
@JsonSerializable() @JsonSerializable()
class DialectOptions { class DialectOptions {
final SqlDialect dialect; final SqlDialect? dialect;
final List<SqlDialect>? dialects;
final SqliteAnalysisOptions? options; final SqliteAnalysisOptions? options;
const DialectOptions(this.dialect, this.options); const DialectOptions(this.dialect, this.dialects, this.options);
factory DialectOptions.fromJson(Map json) => _$DialectOptionsFromJson(json); factory DialectOptions.fromJson(Map json) => _$DialectOptionsFromJson(json);

View File

@ -1,3 +1,4 @@
import 'package:drift/drift.dart' show SqlDialect;
import 'package:recase/recase.dart'; import 'package:recase/recase.dart';
import 'package:sqlparser/sqlparser.dart'; import 'package:sqlparser/sqlparser.dart';
import 'package:sqlparser/sqlparser.dart' as sql; import 'package:sqlparser/sqlparser.dart' as sql;
@ -110,7 +111,8 @@ class DriftViewResolver extends DriftElementResolver<DiscoveredDriftView> {
query: stmt.query, query: stmt.query,
// Remove drift-specific syntax // Remove drift-specific syntax
driftTableName: null, driftTableName: null,
).toSqlWithoutDriftSpecificSyntax(resolver.driver.options); ).toSqlWithoutDriftSpecificSyntax(
resolver.driver.options, SqlDialect.sqlite);
return DriftView( return DriftView(
discovered.ownId, discovered.ownId,

View File

@ -179,11 +179,16 @@ DialectOptions _$DialectOptionsFromJson(Map json) => $checkedCreate(
($checkedConvert) { ($checkedConvert) {
$checkKeys( $checkKeys(
json, json,
allowedKeys: const ['dialect', 'options'], allowedKeys: const ['dialect', 'dialects', 'options'],
); );
final val = DialectOptions( final val = DialectOptions(
$checkedConvert( $checkedConvert(
'dialect', (v) => $enumDecode(_$SqlDialectEnumMap, v)), 'dialect', (v) => $enumDecodeNullable(_$SqlDialectEnumMap, v)),
$checkedConvert(
'dialects',
(v) => (v as List<dynamic>?)
?.map((e) => $enumDecode(_$SqlDialectEnumMap, e))
.toList()),
$checkedConvert( $checkedConvert(
'options', 'options',
(v) => (v) =>
@ -195,7 +200,9 @@ DialectOptions _$DialectOptionsFromJson(Map json) => $checkedCreate(
Map<String, dynamic> _$DialectOptionsToJson(DialectOptions instance) => Map<String, dynamic> _$DialectOptionsToJson(DialectOptions instance) =>
<String, dynamic>{ <String, dynamic>{
'dialect': _$SqlDialectEnumMap[instance.dialect]!, 'dialect': _$SqlDialectEnumMap[instance.dialect],
'dialects':
instance.dialects?.map((e) => _$SqlDialectEnumMap[e]!).toList(),
'options': instance.options?.toJson(), 'options': instance.options?.toJson(),
}; };

View File

@ -220,25 +220,37 @@ class DatabaseWriter {
} }
static String createTrigger(Scope scope, DriftTrigger entity) { static String createTrigger(Scope scope, DriftTrigger entity) {
final sql = scope.sqlCode(entity.parsedStatement!); final (sql, dialectSpecific) = scope.sqlByDialect(entity.parsedStatement!);
final trigger = scope.drift('Trigger'); final trigger = scope.drift('Trigger');
return '$trigger(${asDartLiteral(sql)}, ${asDartLiteral(entity.schemaName)})'; if (dialectSpecific) {
return '$trigger.byDialect(${asDartLiteral(entity.schemaName)}, $sql)';
} else {
return '$trigger($sql, ${asDartLiteral(entity.schemaName)})';
}
} }
static String createIndex(Scope scope, DriftIndex entity) { static String createIndex(Scope scope, DriftIndex entity) {
final sql = scope.sqlCode(entity.parsedStatement!); final (sql, dialectSpecific) = scope.sqlByDialect(entity.parsedStatement!);
final index = scope.drift('Index'); final index = scope.drift('Index');
return '$index(${asDartLiteral(entity.schemaName)}, ${asDartLiteral(sql)})'; if (dialectSpecific) {
return '$index.byDialect(${asDartLiteral(entity.schemaName)}, $sql)';
} else {
return '$index(${asDartLiteral(entity.schemaName)}, $sql)';
}
} }
static String createOnCreate( static String createOnCreate(
Scope scope, DefinedSqlQuery query, SqlQuery resolved) { Scope scope, DefinedSqlQuery query, SqlQuery resolved) {
final sql = scope.sqlCode(resolved.root!); final (sql, dialectSpecific) = scope.sqlByDialect(resolved.root!);
final onCreate = scope.drift('OnCreateQuery'); final onCreate = scope.drift('OnCreateQuery');
return '$onCreate(${asDartLiteral(sql)})'; if (dialectSpecific) {
return '$onCreate.byDialect($sql)';
} else {
return '$onCreate($sql)';
}
} }
} }

View File

@ -1,3 +1,4 @@
import 'package:drift/drift.dart';
import 'package:recase/recase.dart'; import 'package:recase/recase.dart';
import 'package:sqlparser/sqlparser.dart' hide ResultColumn; import 'package:sqlparser/sqlparser.dart' hide ResultColumn;
@ -495,7 +496,45 @@ class QueryWriter {
/// been expanded. For instance, 'SELECT * FROM t WHERE x IN ?' will be turned /// been expanded. For instance, 'SELECT * FROM t WHERE x IN ?' will be turned
/// into 'SELECT * FROM t WHERE x IN ($expandedVar1)'. /// into 'SELECT * FROM t WHERE x IN ($expandedVar1)'.
String _queryCode(SqlQuery query) { String _queryCode(SqlQuery query) {
return SqlWriter(scope.options, query: query).write(); final dialectForCode = <String, List<SqlDialect>>{};
for (final dialect in scope.options.supportedDialects) {
final code =
SqlWriter(scope.options, dialect: dialect, query: query).write();
dialectForCode.putIfAbsent(code, () => []).add(dialect);
}
if (dialectForCode.length == 1) {
// All supported dialects use the same SQL syntax, so we can just use that
return dialectForCode.keys.single;
} else {
// Create a switch expression matching over the dialect of the database
// we're connected to.
final buffer = StringBuffer('switch (executor.dialect) {');
final dialectEnum = scope.drift('SqlDialect');
var index = 0;
for (final MapEntry(key: code, value: dialects)
in dialectForCode.entries) {
index++;
buffer
.write(dialects.map((e) => '$dialectEnum.${e.name}').join(' || '));
if (index == dialectForCode.length) {
// In the last branch, match all dialects as a fallback
buffer.write(' || _ ');
}
buffer
..write(' => ')
..write(code)
..write(', ');
}
buffer.writeln('}');
return buffer.toString();
}
} }
void _writeReadsFrom(SqlSelectQuery select) { void _writeReadsFrom(SqlSelectQuery select) {
@ -818,9 +857,14 @@ String? _defaultForDartPlaceholder(
if (kind is ExpressionDartPlaceholderType && kind.defaultValue != null) { if (kind is ExpressionDartPlaceholderType && kind.defaultValue != null) {
// Wrap the default expression in parentheses to avoid issues with // Wrap the default expression in parentheses to avoid issues with
// the surrounding precedence in SQL. // the surrounding precedence in SQL.
final sql = SqlWriter(scope.options) final (sql, dialectSpecific) =
.writeNodeIntoStringLiteral(Parentheses(kind.defaultValue!)); scope.sqlByDialect(Parentheses(kind.defaultValue!));
if (dialectSpecific) {
return 'const ${scope.drift('CustomExpression')}.dialectSpecific($sql)';
} else {
return 'const ${scope.drift('CustomExpression')}($sql)'; return 'const ${scope.drift('CustomExpression')}($sql)';
}
} else if (kind is SimpleDartPlaceholderType && } else if (kind is SimpleDartPlaceholderType &&
kind.kind == SimpleDartPlaceholderKind.orderBy) { kind.kind == SimpleDartPlaceholderKind.orderBy) {
return 'const ${scope.drift('OrderBy')}.nothing()'; return 'const ${scope.drift('OrderBy')}.nothing()';

View File

@ -26,8 +26,9 @@ String placeholderContextName(FoundDartPlaceholder placeholder) {
} }
extension ToSqlText on AstNode { extension ToSqlText on AstNode {
String toSqlWithoutDriftSpecificSyntax(DriftOptions options) { String toSqlWithoutDriftSpecificSyntax(
final writer = SqlWriter(options, escapeForDart: false); DriftOptions options, SqlDialect dialect) {
final writer = SqlWriter(options, dialect: dialect, escapeForDart: false);
return writer.writeSql(this); return writer.writeSql(this);
} }
} }
@ -36,17 +37,19 @@ class SqlWriter extends NodeSqlBuilder {
final StringBuffer _out; final StringBuffer _out;
final SqlQuery? query; final SqlQuery? query;
final DriftOptions options; final DriftOptions options;
final SqlDialect dialect;
final Map<NestedStarResultColumn, NestedResultTable> _starColumnToResolved; final Map<NestedStarResultColumn, NestedResultTable> _starColumnToResolved;
bool get _isPostgres => options.effectiveDialect == SqlDialect.postgres; bool get _isPostgres => dialect == SqlDialect.postgres;
SqlWriter._(this.query, this.options, this._starColumnToResolved, SqlWriter._(this.query, this.options, this.dialect,
StringBuffer out, bool escapeForDart) this._starColumnToResolved, StringBuffer out, bool escapeForDart)
: _out = out, : _out = out,
super(escapeForDart ? _DartEscapingSink(out) : out); super(escapeForDart ? _DartEscapingSink(out) : out);
factory SqlWriter( factory SqlWriter(
DriftOptions options, { DriftOptions options, {
required SqlDialect dialect,
SqlQuery? query, SqlQuery? query,
bool escapeForDart = true, bool escapeForDart = true,
StringBuffer? buffer, StringBuffer? buffer,
@ -61,7 +64,7 @@ class SqlWriter extends NodeSqlBuilder {
if (nestedResult is NestedResultTable) nestedResult.from: nestedResult if (nestedResult is NestedResultTable) nestedResult.from: nestedResult
}; };
} }
return SqlWriter._(query, options, doubleStarColumnToResolvedTable, return SqlWriter._(query, options, dialect, doubleStarColumnToResolvedTable,
buffer ?? StringBuffer(), escapeForDart); buffer ?? StringBuffer(), escapeForDart);
} }
@ -84,7 +87,7 @@ class SqlWriter extends NodeSqlBuilder {
@override @override
bool isKeyword(String lexeme) { bool isKeyword(String lexeme) {
switch (options.effectiveDialect) { switch (dialect) {
case SqlDialect.postgres: case SqlDialect.postgres:
return isKeywordLexeme(lexeme) || isPostgresKeywordLexeme(lexeme); return isKeywordLexeme(lexeme) || isPostgresKeywordLexeme(lexeme);
default: default:

View File

@ -162,7 +162,7 @@ abstract class TableOrViewWriter {
} }
/// Returns the Dart type and the Dart expression creating a `GeneratedColumn` /// Returns the Dart type and the Dart expression creating a `GeneratedColumn`
/// instance in drift for the givne [column]. /// instance in drift for the given [column].
static (String, String) instantiateColumn( static (String, String) instantiateColumn(
DriftColumn column, DriftColumn column,
TextEmitter emitter, { TextEmitter emitter, {
@ -173,6 +173,10 @@ abstract class TableOrViewWriter {
final expressionBuffer = StringBuffer(); final expressionBuffer = StringBuffer();
final constraints = defaultConstraints(column); final constraints = defaultConstraints(column);
// Remove dialect-specific constraints for dialects we don't care about.
constraints.removeWhere(
(key, _) => !emitter.writer.options.supportedDialects.contains(key));
for (final constraint in column.constraints) { for (final constraint in column.constraints) {
if (constraint is LimitingTextLength) { if (constraint is LimitingTextLength) {
final buffer = final buffer =

View File

@ -75,18 +75,29 @@ class ViewWriter extends TableOrViewWriter {
..write('@override\n String get entityName=>' ..write('@override\n String get entityName=>'
' ${asDartLiteral(view.schemaName)};\n'); ' ${asDartLiteral(view.schemaName)};\n');
emitter
..writeln('@override')
..write('Map<${emitter.drift('SqlDialect')}, String>')
..write(source is! SqlViewSource ? '?' : '')
..write('get createViewStatements => ');
if (source is SqlViewSource) { if (source is SqlViewSource) {
final astNode = source.parsedStatement; final astNode = source.parsedStatement;
emitter.write('@override\nString get createViewStmt =>');
if (astNode != null) { if (astNode != null) {
emitter.writeSqlAsDartLiteral(astNode); emitter.writeSqlByDialectMap(astNode);
} else { } else {
emitter.write(asDartLiteral(source.sqlCreateViewStmt)); final firstDialect = scope.options.supportedDialects.first;
emitter
..write('{')
..writeDriftRef('SqlDialect')
..write('.${firstDialect.name}: ')
..write(asDartLiteral(source.sqlCreateViewStmt))
..write('}');
} }
buffer.writeln(';'); buffer.writeln(';');
} else { } else {
buffer.write('@override\n String? get createViewStmt => null;\n'); buffer.writeln('null;');
} }
writeAsDslTable(); writeAsDslTable();

View File

@ -1,3 +1,4 @@
import 'package:drift/drift.dart';
import 'package:recase/recase.dart'; import 'package:recase/recase.dart';
import 'package:sqlparser/sqlparser.dart' as sql; import 'package:sqlparser/sqlparser.dart' as sql;
import 'package:path/path.dart' show url; import 'package:path/path.dart' show url;
@ -228,8 +229,50 @@ abstract class _NodeOrWriter {
return buffer.toString(); return buffer.toString();
} }
String sqlCode(sql.AstNode node) { String sqlCode(sql.AstNode node, SqlDialect dialect) {
return SqlWriter(writer.options, escapeForDart: false).writeSql(node); return SqlWriter(writer.options, dialect: dialect, escapeForDart: false)
.writeSql(node);
}
/// Builds a Dart expression writing the [node] into a Dart string.
///
/// If the code for [node] depends on the dialect, the code returned evaluates
/// to a `Map<SqlDialect, String>`. Otherwise, the code is a direct string
/// literal.
///
/// The boolean component in the record describes whether the code will be
/// dialect specific.
(String, bool) sqlByDialect(sql.AstNode node) {
final dialects = writer.options.supportedDialects;
if (dialects.length == 1) {
return (
SqlWriter(writer.options, dialect: dialects.single)
.writeNodeIntoStringLiteral(node),
false
);
}
final buffer = StringBuffer();
_writeSqlByDialectMap(node, buffer);
return (buffer.toString(), true);
}
void _writeSqlByDialectMap(sql.AstNode node, StringBuffer buffer) {
buffer.write('{');
for (final dialect in writer.options.supportedDialects) {
buffer
..write(drift('SqlDialect'))
..write(".${dialect.name}: '");
SqlWriter(writer.options, dialect: dialect, buffer: buffer)
.writeSql(node);
buffer.writeln("',");
}
buffer.write('}');
} }
} }
@ -302,16 +345,18 @@ class TextEmitter extends _Node {
void writeDart(AnnotatedDartCode code) => write(dartCode(code)); void writeDart(AnnotatedDartCode code) => write(dartCode(code));
void writeSql(sql.AstNode node, {bool escapeForDartString = true}) { void writeSql(sql.AstNode node,
SqlWriter(writer.options, {required SqlDialect dialect, bool escapeForDartString = true}) {
escapeForDart: escapeForDartString, buffer: buffer) SqlWriter(
.writeSql(node); writer.options,
dialect: dialect,
escapeForDart: escapeForDartString,
buffer: buffer,
).writeSql(node);
} }
void writeSqlAsDartLiteral(sql.AstNode node) { void writeSqlByDialectMap(sql.AstNode node) {
buffer.write("'"); _writeSqlByDialectMap(node, buffer);
writeSql(node);
buffer.write("'");
} }
} }

View File

@ -76,7 +76,8 @@ SELECT rowid, highlight(example_table_search, 0, '[match]', '[match]') name,
{'a|lib/a.drift': 'CREATE VIRTUAL TABLE demo USING spellfix1;'}, {'a|lib/a.drift': 'CREATE VIRTUAL TABLE demo USING spellfix1;'},
options: DriftOptions.defaults( options: DriftOptions.defaults(
dialect: DialectOptions( dialect: DialectOptions(
SqlDialect.sqlite, null,
[SqlDialect.sqlite],
SqliteAnalysisOptions( SqliteAnalysisOptions(
modules: [SqlModule.spellfix1], modules: [SqlModule.spellfix1],
), ),

View File

@ -307,7 +307,7 @@ TypeConverter<Object, int> myConverter() => throw UnimplementedError();
'a|lib/a.drift.dart': decodedMatches( 'a|lib/a.drift.dart': decodedMatches(
allOf( allOf(
contains( contains(
''''CREATE VIEW my_view AS SELECT CAST(1 AS INT) AS c1, CAST(\\'bar\\' AS TEXT) AS c2, 1 AS c3, NULLIF(1, 2) AS c4';'''), ''''CREATE VIEW my_view AS SELECT CAST(1 AS INT) AS c1, CAST(\\'bar\\' AS TEXT) AS c2, 1 AS c3, NULLIF(1, 2) AS c4','''),
contains(r'$converterc1 ='), contains(r'$converterc1 ='),
contains(r'$converterc2 ='), contains(r'$converterc2 ='),
contains(r'$converterc3 ='), contains(r'$converterc3 ='),

View File

@ -1,4 +1,5 @@
import 'package:build_test/build_test.dart'; import 'package:build_test/build_test.dart';
import 'package:drift/drift.dart';
import 'package:drift_dev/src/analysis/options.dart'; import 'package:drift_dev/src/analysis/options.dart';
import 'package:drift_dev/src/writer/import_manager.dart'; import 'package:drift_dev/src/writer/import_manager.dart';
import 'package:drift_dev/src/writer/queries/query_writer.dart'; import 'package:drift_dev/src/writer/queries/query_writer.dart';
@ -11,14 +12,16 @@ import '../../utils.dart';
void main() { void main() {
Future<String> generateForQueryInDriftFile(String driftFile, Future<String> generateForQueryInDriftFile(String driftFile,
{DriftOptions options = const DriftOptions.defaults()}) async { {DriftOptions options = const DriftOptions.defaults(
generateNamedParameters: true,
)}) async {
final state = final state =
TestBackend.inTest({'a|lib/main.drift': driftFile}, options: options); TestBackend.inTest({'a|lib/main.drift': driftFile}, options: options);
final file = await state.analyze('package:a/main.drift'); final file = await state.analyze('package:a/main.drift');
state.expectNoErrors(); state.expectNoErrors();
final writer = Writer( final writer = Writer(
const DriftOptions.defaults(generateNamedParameters: true), options,
generationOptions: GenerationOptions( generationOptions: GenerationOptions(
imports: ImportManagerForPartFiles(), imports: ImportManagerForPartFiles(),
), ),
@ -528,4 +531,26 @@ class ADrift extends i1.ModularAccessor {
}''')) }'''))
}, outputs.dartOutputs, outputs); }, outputs.dartOutputs, outputs);
}); });
test('creates dialect-specific query code', () async {
final result = await generateForQueryInDriftFile(
r'''
query (:foo AS TEXT): SELECT :foo;
''',
options: const DriftOptions.defaults(
dialect: DialectOptions(
null, [SqlDialect.sqlite, SqlDialect.postgres], null),
),
);
expect(
result,
contains(
'switch (executor.dialect) {'
"SqlDialect.sqlite => 'SELECT ?1 AS _c0', "
"SqlDialect.postgres || _ => 'SELECT \\\$1 AS _c0', "
'}',
),
);
});
} }

View File

@ -6,14 +6,18 @@ import 'package:sqlparser/sqlparser.dart';
import 'package:test/test.dart'; import 'package:test/test.dart';
void main() { void main() {
void check(String sql, String expectedDart, void check(
{DriftOptions options = const DriftOptions.defaults()}) { String sql,
String expectedDart, {
DriftOptions options = const DriftOptions.defaults(),
SqlDialect dialect = SqlDialect.sqlite,
}) {
final engine = SqlEngine(); final engine = SqlEngine();
final context = engine.analyze(sql); final context = engine.analyze(sql);
final query = SqlSelectQuery('name', context, context.root, [], [], final query = SqlSelectQuery('name', context, context.root, [], [],
InferredResultSet(null, []), null, null); InferredResultSet(null, []), null, null);
final result = SqlWriter(options, query: query).write(); final result = SqlWriter(options, dialect: dialect, query: query).write();
expect(result, expectedDart); expect(result, expectedDart);
} }
@ -33,7 +37,6 @@ void main() {
test('escapes postgres keywords', () { test('escapes postgres keywords', () {
check('SELECT * FROM user', "'SELECT * FROM user'"); check('SELECT * FROM user', "'SELECT * FROM user'");
check('SELECT * FROM user', "'SELECT * FROM \"user\"'", check('SELECT * FROM user', "'SELECT * FROM \"user\"'",
options: DriftOptions.defaults( dialect: SqlDialect.postgres);
dialect: DialectOptions(SqlDialect.postgres, null)));
}); });
} }

View File

@ -602,8 +602,10 @@ class PopularUsers extends i0.ViewInfo<i1.PopularUsers, i1.PopularUser>
@override @override
String get entityName => 'popular_users'; String get entityName => 'popular_users';
@override @override
String get createViewStmt => Map<i0.SqlDialect, String> get createViewStatements => {
'CREATE VIEW popular_users AS SELECT * FROM users ORDER BY (SELECT count(*) FROM follows WHERE followed = users.id)'; i0.SqlDialect.sqlite:
'CREATE VIEW popular_users AS SELECT * FROM users ORDER BY (SELECT count(*) FROM follows WHERE followed = users.id)',
};
@override @override
PopularUsers get asDslTable => this; PopularUsers get asDslTable => this;
@override @override

View File

@ -9,9 +9,7 @@ targets:
raw_result_set_data: false raw_result_set_data: false
named_parameters: false named_parameters: false
sql: sql:
# As sqlite3 is compatible with the postgres dialect (but not vice-versa), we're dialects: [sqlite, postgres]
# using this dialect so that we can run the tests on postgres as well.
dialect: postgres
options: options:
version: "3.37" version: "3.37"
modules: modules:

View File

@ -336,7 +336,6 @@ class $FriendshipsTable extends Friendships
requiredDuringInsert: false, requiredDuringInsert: false,
defaultConstraints: GeneratedColumn.constraintsDependsOnDialect({ defaultConstraints: GeneratedColumn.constraintsDependsOnDialect({
SqlDialect.sqlite: 'CHECK ("really_good_friends" IN (0, 1))', SqlDialect.sqlite: 'CHECK ("really_good_friends" IN (0, 1))',
SqlDialect.mysql: '',
SqlDialect.postgres: '', SqlDialect.postgres: '',
}), }),
defaultValue: const Constant(false)); defaultValue: const Constant(false));
@ -554,7 +553,13 @@ abstract class _$Database extends GeneratedDatabase {
late final $FriendshipsTable friendships = $FriendshipsTable(this); late final $FriendshipsTable friendships = $FriendshipsTable(this);
Selectable<User> mostPopularUsers(int amount) { Selectable<User> mostPopularUsers(int amount) {
return customSelect( return customSelect(
switch (executor.dialect) {
SqlDialect.sqlite =>
'SELECT * FROM users AS u ORDER BY (SELECT COUNT(*) FROM friendships WHERE first_user = u.id OR second_user = u.id) DESC LIMIT ?1',
SqlDialect.postgres ||
_ =>
'SELECT * FROM users AS u ORDER BY (SELECT COUNT(*) FROM friendships WHERE first_user = u.id OR second_user = u.id) DESC LIMIT \$1', 'SELECT * FROM users AS u ORDER BY (SELECT COUNT(*) FROM friendships WHERE first_user = u.id OR second_user = u.id) DESC LIMIT \$1',
},
variables: [ variables: [
Variable<int>(amount) Variable<int>(amount)
], ],
@ -566,7 +571,13 @@ abstract class _$Database extends GeneratedDatabase {
Selectable<int> amountOfGoodFriends(int user) { Selectable<int> amountOfGoodFriends(int user) {
return customSelect( return customSelect(
switch (executor.dialect) {
SqlDialect.sqlite =>
'SELECT COUNT(*) AS _c0 FROM friendships AS f WHERE f.really_good_friends = TRUE AND(f.first_user = ?1 OR f.second_user = ?1)',
SqlDialect.postgres ||
_ =>
'SELECT COUNT(*) AS _c0 FROM friendships AS f WHERE f.really_good_friends = TRUE AND(f.first_user = \$1 OR f.second_user = \$1)', 'SELECT COUNT(*) AS _c0 FROM friendships AS f WHERE f.really_good_friends = TRUE AND(f.first_user = \$1 OR f.second_user = \$1)',
},
variables: [ variables: [
Variable<int>(user) Variable<int>(user)
], ],
@ -577,19 +588,23 @@ abstract class _$Database extends GeneratedDatabase {
Selectable<FriendshipsOfResult> friendshipsOf(int user) { Selectable<FriendshipsOfResult> friendshipsOf(int user) {
return customSelect( return customSelect(
switch (executor.dialect) {
SqlDialect.sqlite =>
'SELECT f.really_good_friends,"user"."id" AS "nested_0.id", "user"."name" AS "nested_0.name", "user"."birth_date" AS "nested_0.birth_date", "user"."profile_picture" AS "nested_0.profile_picture", "user"."preferences" AS "nested_0.preferences" FROM friendships AS f INNER JOIN users AS user ON user.id IN (f.first_user, f.second_user) AND user.id != ?1 WHERE(f.first_user = ?1 OR f.second_user = ?1)',
SqlDialect.postgres ||
_ =>
'SELECT f.really_good_friends,"user"."id" AS "nested_0.id", "user"."name" AS "nested_0.name", "user"."birth_date" AS "nested_0.birth_date", "user"."profile_picture" AS "nested_0.profile_picture", "user"."preferences" AS "nested_0.preferences" FROM friendships AS f INNER JOIN users AS "user" ON "user".id IN (f.first_user, f.second_user) AND "user".id != \$1 WHERE(f.first_user = \$1 OR f.second_user = \$1)', 'SELECT f.really_good_friends,"user"."id" AS "nested_0.id", "user"."name" AS "nested_0.name", "user"."birth_date" AS "nested_0.birth_date", "user"."profile_picture" AS "nested_0.profile_picture", "user"."preferences" AS "nested_0.preferences" FROM friendships AS f INNER JOIN users AS "user" ON "user".id IN (f.first_user, f.second_user) AND "user".id != \$1 WHERE(f.first_user = \$1 OR f.second_user = \$1)',
},
variables: [ variables: [
Variable<int>(user) Variable<int>(user)
], ],
readsFrom: { readsFrom: {
friendships, friendships,
users, users,
}).asyncMap((QueryRow row) async { }).asyncMap((QueryRow row) async => FriendshipsOfResult(
return FriendshipsOfResult(
reallyGoodFriends: row.read<bool>('really_good_friends'), reallyGoodFriends: row.read<bool>('really_good_friends'),
user: await users.mapFromRow(row, tablePrefix: 'nested_0'), user: await users.mapFromRow(row, tablePrefix: 'nested_0'),
); ));
});
} }
Selectable<int> userCount() { Selectable<int> userCount() {
@ -601,7 +616,13 @@ abstract class _$Database extends GeneratedDatabase {
} }
Selectable<Preferences?> settingsFor(int user) { Selectable<Preferences?> settingsFor(int user) {
return customSelect('SELECT preferences FROM users WHERE id = \$1', return customSelect(
switch (executor.dialect) {
SqlDialect.sqlite => 'SELECT preferences FROM users WHERE id = ?1',
SqlDialect.postgres ||
_ =>
'SELECT preferences FROM users WHERE id = \$1',
},
variables: [ variables: [
Variable<int>(user) Variable<int>(user)
], ],
@ -626,7 +647,13 @@ abstract class _$Database extends GeneratedDatabase {
Future<List<Friendship>> returning(int var1, int var2, bool var3) { Future<List<Friendship>> returning(int var1, int var2, bool var3) {
return customWriteReturning( return customWriteReturning(
switch (executor.dialect) {
SqlDialect.sqlite =>
'INSERT INTO friendships VALUES (?1, ?2, ?3) RETURNING *',
SqlDialect.postgres ||
_ =>
'INSERT INTO friendships VALUES (\$1, \$2, \$3) RETURNING *', 'INSERT INTO friendships VALUES (\$1, \$2, \$3) RETURNING *',
},
variables: [ variables: [
Variable<int>(var1), Variable<int>(var1),
Variable<int>(var2), Variable<int>(var2),

View File

@ -4,7 +4,7 @@ version: 1.0.0
# homepage: https://www.example.com # homepage: https://www.example.com
environment: environment:
sdk: '>=2.17.0 <3.0.0' sdk: '>=3.0.0 <4.0.0'
dependencies: dependencies:
drift: ^2.0.0-0 drift: ^2.0.0-0