Add multi-dialect generation option

This commit is contained in:
Simon Binder 2023-07-26 23:59:06 +02:00
parent 94c4c1a8e0
commit 8d3f490604
No known key found for this signature in database
GPG Key ID: 7891917E4147B8C0
28 changed files with 436 additions and 101 deletions

View File

@ -110,6 +110,36 @@ in 3.34, so an error would be reported.
Currently, the generator can't provide compatibility checks for versions below 3.34, which is the
minimum version needed in options.
### Multi-dialect code generation
Thanks to community contributions, drift has in-progress support for Postgres and MariaDB.
You can change the `dialect` option to `postgres` or `mariadb` to generate code for those
database management systems.
In some cases, your generated code might have to support more than one DBMS. For instance,
you might want to share database code between your backend and a Flutter app. Or maybe
you're writing a server that should be able to talk to both MariaDB and Postgres, depending
on what the operator prefers.
Drift can generate code for multiple dialects - in that case, the right SQL will be chosen
at runtime when it makes a difference.
To enable this feature, remove the `dialect` option in the `sql` block and replace it with
a list of `dialects`:
```yaml
targets:
$default:
builders:
drift_dev:
options:
sql:
dialect:
- sqlite
- postgres
options:
version: "3.34"
```
### Available extensions
__Note__: This enables extensions in the analyzer for custom queries only. For instance, when the `json1` extension is

View File

@ -534,7 +534,7 @@ class $TodoCategoryItemCountView
@override
String get entityName => 'todo_category_item_count';
@override
String? get createViewStmt => null;
Map<SqlDialect, String>? get createViewStatements => null;
@override
$TodoCategoryItemCountView get asDslTable => this;
@override
@ -639,7 +639,7 @@ class $TodoItemWithCategoryNameViewView extends ViewInfo<
@override
String get entityName => 'customViewName';
@override
String? get createViewStmt => null;
Map<SqlDialect, String>? get createViewStatements => null;
@override
$TodoItemWithCategoryNameViewView get asDslTable => this;
@override

View File

@ -157,6 +157,10 @@ class VersionedView implements ViewInfo<HasResultSet, QueryRow>, HasResultSet {
@override
final String createViewStmt;
@override
Map<SqlDialect, String>? get createViewStatements =>
{SqlDialect.sqlite: createViewStmt};
@override
final List<GeneratedColumn> $columns;

View File

@ -11,6 +11,8 @@ class CustomExpression<D extends Object> extends Expression<D> {
/// The SQL of this expression
final String content;
final Map<SqlDialect, String>? _dialectSpecificContent;
/// Additional tables that this expression is watching.
///
/// When this expression is used in a stream query, the stream will update
@ -24,11 +26,25 @@ class CustomExpression<D extends Object> extends Expression<D> {
/// Constructs a custom expression by providing the raw sql [content].
const CustomExpression(this.content,
{this.watchedTables = const [], this.precedence = Precedence.unknown});
{this.watchedTables = const [], this.precedence = Precedence.unknown})
: _dialectSpecificContent = null;
/// Constructs a custom expression providing the raw SQL in [content] depending
/// on the SQL dialect when this expression is built.
const CustomExpression.dialectSpecific(Map<SqlDialect, String> content,
{this.watchedTables = const [], this.precedence = Precedence.unknown})
: _dialectSpecificContent = content,
content = '';
@override
void writeInto(GenerationContext context) {
context.buffer.write(content);
final dialectSpecific = _dialectSpecificContent;
if (dialectSpecific != null) {
} else {
context.buffer.write(content);
}
context.watchedTables.addAll(watchedTables);
}

View File

@ -1,6 +1,11 @@
@internal
library;
import 'package:drift/drift.dart';
/// Utilities for writing the definition of a result set into a query.
import 'package:meta/meta.dart';
/// Internal utilities for building queries that aren't exported.
extension WriteDefinition on GenerationContext {
/// Writes the result set to this context, suitable to implement `FROM`
/// clauses and joins.
@ -16,4 +21,21 @@ extension WriteDefinition on GenerationContext {
watchedTables.add(resultSet);
}
}
/// Returns a suitable SQL string in [sql] based on the current dialect.
String pickForDialect(Map<SqlDialect, String> sql) {
assert(
sql.containsKey(dialect),
'Tried running SQL optimized for the following dialects: ${sql.keys.join}. '
'However, the database is running $dialect. Has that dialect been added '
'to the `dialects` drift builder option?',
);
final found = sql[dialect];
if (found != null) {
return found;
}
return sql.values.first; // Fallback
}
}

View File

@ -87,7 +87,7 @@ class Migrator {
} else if (entity is Index) {
await createIndex(entity);
} else if (entity is OnCreateQuery) {
await _issueCustomQuery(entity.sql, const []);
await _issueQueryByDialect(entity.sqlByDialect);
} else if (entity is ViewInfo) {
await createView(entity);
} else {
@ -363,19 +363,19 @@ class Migrator {
/// Executes the `CREATE TRIGGER` statement that created the [trigger].
Future<void> createTrigger(Trigger trigger) {
return _issueCustomQuery(trigger.createTriggerStmt, const []);
return _issueQueryByDialect(trigger.createStatementsByDialect);
}
/// Executes a `CREATE INDEX` statement to create the [index].
Future<void> createIndex(Index index) {
return _issueCustomQuery(index.createIndexStmt, const []);
return _issueQueryByDialect(index.createStatementsByDialect);
}
/// Executes a `CREATE VIEW` statement to create the [view].
Future<void> createView(ViewInfo view) async {
final stmt = view.createViewStmt;
if (stmt != null) {
await _issueCustomQuery(stmt, const []);
final stmts = view.createViewStatements;
if (stmts != null) {
await _issueQueryByDialect(stmts);
} else if (view.query != null) {
final context = GenerationContext.fromDb(_db, supportsVariables: false);
final columnNames = view.$columns.map((e) => e.escapedName).join(', ');
@ -478,6 +478,11 @@ class Migrator {
return _issueCustomQuery(sql, args);
}
Future<void> _issueQueryByDialect(Map<SqlDialect, String> sql) {
final context = _createContext();
return _issueCustomQuery(context.pickForDialect(sql), const []);
}
Future<void> _issueCustomQuery(String sql, [List<dynamic>? args]) {
return _db.customStatement(sql, args);
}

View File

@ -17,15 +17,26 @@ abstract class DatabaseSchemaEntity {
/// [sqlite-docs]: https://sqlite.org/lang_createtrigger.html
/// [sql-tut]: https://www.sqlitetutorial.net/sqlite-trigger/
class Trigger extends DatabaseSchemaEntity {
/// The `CREATE TRIGGER` sql statement that can be used to create this
/// trigger.
final String createTriggerStmt;
@override
final String entityName;
/// The `CREATE TRIGGER` sql statement that can be used to create this
/// trigger.
@Deprecated('Use createStatementsByDialect instead')
String get createTriggerStmt => createStatementsByDialect.values.first;
/// The `CREATE TRIGGER` SQL statements used to create this trigger, accessible
/// for each dialect enabled when generating code.
final Map<SqlDialect, String> createStatementsByDialect;
/// Creates a trigger representation by the [createTriggerStmt] and its
/// [entityName]. Mainly used by generated code.
Trigger(this.createTriggerStmt, this.entityName);
Trigger(String createTriggerStmt, String entityName)
: this.byDialect(entityName, {SqlDialect.sqlite: createTriggerStmt});
/// Creates the trigger model from its [entityName] in the schema and all
/// [createStatementsByDialect] for the supported dialects.
Trigger.byDialect(this.entityName, this.createStatementsByDialect);
}
/// A sqlite index on columns or expressions.
@ -40,11 +51,21 @@ class Index extends DatabaseSchemaEntity {
final String entityName;
/// The `CREATE INDEX` sql statement that can be used to create this index.
final String createIndexStmt;
@Deprecated('Use createStatementsByDialect instead')
String get createIndexStmt => createStatementsByDialect.values.first;
/// The `CREATE INDEX` SQL statements used to create this index, accessible
/// for each dialect enabled when generating code.
final Map<SqlDialect, String> createStatementsByDialect;
/// Creates an index model by the [createIndexStmt] and its [entityName].
/// Mainly used by generated code.
Index(this.entityName, this.createIndexStmt);
Index(this.entityName, String createIndexStmt)
: createStatementsByDialect = {SqlDialect.sqlite: createIndexStmt};
/// Creates an index model by its [entityName] used in the schema and the
/// `CREATE INDEX` statements for each supported dialect.
Index.byDialect(this.entityName, this.createStatementsByDialect);
}
/// An internal schema entity to run an sql statement when the database is
@ -61,10 +82,19 @@ class Index extends DatabaseSchemaEntity {
/// drift file.
class OnCreateQuery extends DatabaseSchemaEntity {
/// The sql statement that should be run in the default `onCreate` clause.
final String sql;
@Deprecated('Use sqlByDialect instead')
String get sql => sqlByDialect.values.first;
/// The SQL statement to run, indexed by the dialect used in the database.
final Map<SqlDialect, String> sqlByDialect;
/// Create a query that will be run in the default `onCreate` migration.
OnCreateQuery(this.sql);
OnCreateQuery(String sql) : this.byDialect({SqlDialect.sqlite: sql});
/// Creates the entity of a query to run in the default `onCreate` migration.
///
/// The migrator will lookup a suitable query from the [sqlByDialect] map.
OnCreateQuery.byDialect(this.sqlByDialect);
@override
String get entityName => r'$internal$';

View File

@ -17,7 +17,14 @@ abstract class ViewInfo<Self extends HasResultSet, Row>
/// The `CREATE VIEW` sql statement that can be used to create this view.
///
/// This will be null if the view was defined in Dart.
String? get createViewStmt;
@Deprecated('Use createViewStatements instead')
String? get createViewStmt => createViewStatements?.values.first;
/// The `CREATE VIEW` sql statement that can be used to create this view,
/// depending on the dialect used by the current database.
///
/// This will be null if the view was defined in Dart.
Map<SqlDialect, String>? get createViewStatements;
/// Predefined query from `View.as()`
///

View File

@ -259,6 +259,32 @@ void main() {
));
});
});
group('dialect-specific', () {
Map<SqlDialect, String> statements(String base) {
return {
for (final dialect in SqlDialect.values) dialect: '$base $dialect',
};
}
for (final dialect in [SqlDialect.sqlite, SqlDialect.postgres]) {
test('with dialect $dialect', () async {
final executor = MockExecutor();
when(executor.dialect).thenReturn(dialect);
final db = TodoDb(executor);
final migrator = db.createMigrator();
await migrator.create(Trigger.byDialect('a', statements('trigger')));
await migrator.create(Index.byDialect('a', statements('index')));
await migrator.create(OnCreateQuery.byDialect(statements('@')));
verify(executor.runCustom('trigger $dialect', []));
verify(executor.runCustom('index $dialect', []));
verify(executor.runCustom('@ $dialect', []));
});
}
});
}
final class _FakeSchemaVersion extends VersionedSchema {

View File

@ -1600,8 +1600,10 @@ class MyView extends ViewInfo<MyView, MyViewData> implements HasResultSet {
@override
String get entityName => 'my_view';
@override
String get createViewStmt =>
'CREATE VIEW my_view AS SELECT * FROM config WHERE sync_state = 2';
Map<SqlDialect, String> get createViewStatements => {
SqlDialect.sqlite:
'CREATE VIEW my_view AS SELECT * FROM config WHERE sync_state = 2',
};
@override
MyView get asDslTable => this;
@override

View File

@ -641,16 +641,13 @@ class $UsersTable extends Users with TableInfo<$UsersTable, User> {
static const VerificationMeta _isAwesomeMeta =
const VerificationMeta('isAwesome');
@override
late final GeneratedColumn<bool> isAwesome =
GeneratedColumn<bool>('is_awesome', aliasedName, false,
type: DriftSqlType.bool,
requiredDuringInsert: false,
defaultConstraints: GeneratedColumn.constraintsDependsOnDialect({
SqlDialect.sqlite: 'CHECK ("is_awesome" IN (0, 1))',
SqlDialect.mysql: '',
SqlDialect.postgres: '',
}),
defaultValue: const Constant(true));
late final GeneratedColumn<bool> isAwesome = GeneratedColumn<bool>(
'is_awesome', aliasedName, false,
type: DriftSqlType.bool,
requiredDuringInsert: false,
defaultConstraints:
GeneratedColumn.constraintIsAlways('CHECK ("is_awesome" IN (0, 1))'),
defaultValue: const Constant(true));
static const VerificationMeta _profilePictureMeta =
const VerificationMeta('profilePicture');
@override
@ -1549,7 +1546,7 @@ class $CategoryTodoCountViewView
@override
String get entityName => 'category_todo_count_view';
@override
String? get createViewStmt => null;
Map<SqlDialect, String>? get createViewStatements => null;
@override
$CategoryTodoCountViewView get asDslTable => this;
@override
@ -1660,7 +1657,7 @@ class $TodoWithCategoryViewView
@override
String get entityName => 'todo_with_category_view';
@override
String? get createViewStmt => null;
Map<SqlDialect, String>? get createViewStatements => null;
@override
$TodoWithCategoryViewView get asDslTable => this;
@override

View File

@ -124,7 +124,7 @@ class DriftOptions {
this.modules = const [],
this.sqliteAnalysisOptions,
this.storeDateTimeValuesAsText = false,
this.dialect = const DialectOptions(SqlDialect.sqlite, null),
this.dialect = const DialectOptions(null, [SqlDialect.sqlite], null),
this.caseFromDartToSql = CaseFromDartToSql.snake,
this.writeToColumnsMixins = false,
this.fatalWarnings = false,
@ -189,7 +189,18 @@ class DriftOptions {
/// Whether the [module] has been enabled in this configuration.
bool hasModule(SqlModule module) => effectiveModules.contains(module);
SqlDialect get effectiveDialect => dialect?.dialect ?? SqlDialect.sqlite;
List<SqlDialect> get supportedDialects {
final dialects = dialect?.dialects;
final singleDialect = dialect?.dialect;
if (dialects != null) {
return dialects;
} else if (singleDialect != null) {
return [singleDialect];
} else {
return const [SqlDialect.sqlite];
}
}
/// The assumed sqlite version used when analyzing queries.
SqliteVersion get sqliteVersion {
@ -201,10 +212,11 @@ class DriftOptions {
@JsonSerializable()
class DialectOptions {
final SqlDialect dialect;
final SqlDialect? dialect;
final List<SqlDialect>? dialects;
final SqliteAnalysisOptions? options;
const DialectOptions(this.dialect, this.options);
const DialectOptions(this.dialect, this.dialects, this.options);
factory DialectOptions.fromJson(Map json) => _$DialectOptionsFromJson(json);

View File

@ -1,3 +1,4 @@
import 'package:drift/drift.dart' show SqlDialect;
import 'package:recase/recase.dart';
import 'package:sqlparser/sqlparser.dart';
import 'package:sqlparser/sqlparser.dart' as sql;
@ -110,7 +111,8 @@ class DriftViewResolver extends DriftElementResolver<DiscoveredDriftView> {
query: stmt.query,
// Remove drift-specific syntax
driftTableName: null,
).toSqlWithoutDriftSpecificSyntax(resolver.driver.options);
).toSqlWithoutDriftSpecificSyntax(
resolver.driver.options, SqlDialect.sqlite);
return DriftView(
discovered.ownId,

View File

@ -179,11 +179,16 @@ DialectOptions _$DialectOptionsFromJson(Map json) => $checkedCreate(
($checkedConvert) {
$checkKeys(
json,
allowedKeys: const ['dialect', 'options'],
allowedKeys: const ['dialect', 'dialects', 'options'],
);
final val = DialectOptions(
$checkedConvert(
'dialect', (v) => $enumDecode(_$SqlDialectEnumMap, v)),
'dialect', (v) => $enumDecodeNullable(_$SqlDialectEnumMap, v)),
$checkedConvert(
'dialects',
(v) => (v as List<dynamic>?)
?.map((e) => $enumDecode(_$SqlDialectEnumMap, e))
.toList()),
$checkedConvert(
'options',
(v) =>
@ -195,7 +200,9 @@ DialectOptions _$DialectOptionsFromJson(Map json) => $checkedCreate(
Map<String, dynamic> _$DialectOptionsToJson(DialectOptions instance) =>
<String, dynamic>{
'dialect': _$SqlDialectEnumMap[instance.dialect]!,
'dialect': _$SqlDialectEnumMap[instance.dialect],
'dialects':
instance.dialects?.map((e) => _$SqlDialectEnumMap[e]!).toList(),
'options': instance.options?.toJson(),
};

View File

@ -220,25 +220,37 @@ class DatabaseWriter {
}
static String createTrigger(Scope scope, DriftTrigger entity) {
final sql = scope.sqlCode(entity.parsedStatement!);
final (sql, dialectSpecific) = scope.sqlByDialect(entity.parsedStatement!);
final trigger = scope.drift('Trigger');
return '$trigger(${asDartLiteral(sql)}, ${asDartLiteral(entity.schemaName)})';
if (dialectSpecific) {
return '$trigger.byDialect(${asDartLiteral(entity.schemaName)}, $sql)';
} else {
return '$trigger($sql, ${asDartLiteral(entity.schemaName)})';
}
}
static String createIndex(Scope scope, DriftIndex entity) {
final sql = scope.sqlCode(entity.parsedStatement!);
final (sql, dialectSpecific) = scope.sqlByDialect(entity.parsedStatement!);
final index = scope.drift('Index');
return '$index(${asDartLiteral(entity.schemaName)}, ${asDartLiteral(sql)})';
if (dialectSpecific) {
return '$index.byDialect(${asDartLiteral(entity.schemaName)}, $sql)';
} else {
return '$index(${asDartLiteral(entity.schemaName)}, $sql)';
}
}
static String createOnCreate(
Scope scope, DefinedSqlQuery query, SqlQuery resolved) {
final sql = scope.sqlCode(resolved.root!);
final (sql, dialectSpecific) = scope.sqlByDialect(resolved.root!);
final onCreate = scope.drift('OnCreateQuery');
return '$onCreate(${asDartLiteral(sql)})';
if (dialectSpecific) {
return '$onCreate.byDialect($sql)';
} else {
return '$onCreate($sql)';
}
}
}

View File

@ -1,3 +1,4 @@
import 'package:drift/drift.dart';
import 'package:recase/recase.dart';
import 'package:sqlparser/sqlparser.dart' hide ResultColumn;
@ -495,7 +496,45 @@ class QueryWriter {
/// been expanded. For instance, 'SELECT * FROM t WHERE x IN ?' will be turned
/// into 'SELECT * FROM t WHERE x IN ($expandedVar1)'.
String _queryCode(SqlQuery query) {
return SqlWriter(scope.options, query: query).write();
final dialectForCode = <String, List<SqlDialect>>{};
for (final dialect in scope.options.supportedDialects) {
final code =
SqlWriter(scope.options, dialect: dialect, query: query).write();
dialectForCode.putIfAbsent(code, () => []).add(dialect);
}
if (dialectForCode.length == 1) {
// All supported dialects use the same SQL syntax, so we can just use that
return dialectForCode.keys.single;
} else {
// Create a switch expression matching over the dialect of the database
// we're connected to.
final buffer = StringBuffer('switch (executor.dialect) {');
final dialectEnum = scope.drift('SqlDialect');
var index = 0;
for (final MapEntry(key: code, value: dialects)
in dialectForCode.entries) {
index++;
buffer
.write(dialects.map((e) => '$dialectEnum.${e.name}').join(' || '));
if (index == dialectForCode.length) {
// In the last branch, match all dialects as a fallback
buffer.write(' || _ ');
}
buffer
..write(' => ')
..write(code)
..write(', ');
}
buffer.writeln('}');
return buffer.toString();
}
}
void _writeReadsFrom(SqlSelectQuery select) {
@ -818,9 +857,14 @@ String? _defaultForDartPlaceholder(
if (kind is ExpressionDartPlaceholderType && kind.defaultValue != null) {
// Wrap the default expression in parentheses to avoid issues with
// the surrounding precedence in SQL.
final sql = SqlWriter(scope.options)
.writeNodeIntoStringLiteral(Parentheses(kind.defaultValue!));
return 'const ${scope.drift('CustomExpression')}($sql)';
final (sql, dialectSpecific) =
scope.sqlByDialect(Parentheses(kind.defaultValue!));
if (dialectSpecific) {
return 'const ${scope.drift('CustomExpression')}.dialectSpecific($sql)';
} else {
return 'const ${scope.drift('CustomExpression')}($sql)';
}
} else if (kind is SimpleDartPlaceholderType &&
kind.kind == SimpleDartPlaceholderKind.orderBy) {
return 'const ${scope.drift('OrderBy')}.nothing()';

View File

@ -26,8 +26,9 @@ String placeholderContextName(FoundDartPlaceholder placeholder) {
}
extension ToSqlText on AstNode {
String toSqlWithoutDriftSpecificSyntax(DriftOptions options) {
final writer = SqlWriter(options, escapeForDart: false);
String toSqlWithoutDriftSpecificSyntax(
DriftOptions options, SqlDialect dialect) {
final writer = SqlWriter(options, dialect: dialect, escapeForDart: false);
return writer.writeSql(this);
}
}
@ -36,17 +37,19 @@ class SqlWriter extends NodeSqlBuilder {
final StringBuffer _out;
final SqlQuery? query;
final DriftOptions options;
final SqlDialect dialect;
final Map<NestedStarResultColumn, NestedResultTable> _starColumnToResolved;
bool get _isPostgres => options.effectiveDialect == SqlDialect.postgres;
bool get _isPostgres => dialect == SqlDialect.postgres;
SqlWriter._(this.query, this.options, this._starColumnToResolved,
StringBuffer out, bool escapeForDart)
SqlWriter._(this.query, this.options, this.dialect,
this._starColumnToResolved, StringBuffer out, bool escapeForDart)
: _out = out,
super(escapeForDart ? _DartEscapingSink(out) : out);
factory SqlWriter(
DriftOptions options, {
required SqlDialect dialect,
SqlQuery? query,
bool escapeForDart = true,
StringBuffer? buffer,
@ -61,7 +64,7 @@ class SqlWriter extends NodeSqlBuilder {
if (nestedResult is NestedResultTable) nestedResult.from: nestedResult
};
}
return SqlWriter._(query, options, doubleStarColumnToResolvedTable,
return SqlWriter._(query, options, dialect, doubleStarColumnToResolvedTable,
buffer ?? StringBuffer(), escapeForDart);
}
@ -84,7 +87,7 @@ class SqlWriter extends NodeSqlBuilder {
@override
bool isKeyword(String lexeme) {
switch (options.effectiveDialect) {
switch (dialect) {
case SqlDialect.postgres:
return isKeywordLexeme(lexeme) || isPostgresKeywordLexeme(lexeme);
default:

View File

@ -162,7 +162,7 @@ abstract class TableOrViewWriter {
}
/// Returns the Dart type and the Dart expression creating a `GeneratedColumn`
/// instance in drift for the givne [column].
/// instance in drift for the given [column].
static (String, String) instantiateColumn(
DriftColumn column,
TextEmitter emitter, {
@ -173,6 +173,10 @@ abstract class TableOrViewWriter {
final expressionBuffer = StringBuffer();
final constraints = defaultConstraints(column);
// Remove dialect-specific constraints for dialects we don't care about.
constraints.removeWhere(
(key, _) => !emitter.writer.options.supportedDialects.contains(key));
for (final constraint in column.constraints) {
if (constraint is LimitingTextLength) {
final buffer =

View File

@ -75,18 +75,29 @@ class ViewWriter extends TableOrViewWriter {
..write('@override\n String get entityName=>'
' ${asDartLiteral(view.schemaName)};\n');
emitter
..writeln('@override')
..write('Map<${emitter.drift('SqlDialect')}, String>')
..write(source is! SqlViewSource ? '?' : '')
..write('get createViewStatements => ');
if (source is SqlViewSource) {
final astNode = source.parsedStatement;
emitter.write('@override\nString get createViewStmt =>');
if (astNode != null) {
emitter.writeSqlAsDartLiteral(astNode);
emitter.writeSqlByDialectMap(astNode);
} else {
emitter.write(asDartLiteral(source.sqlCreateViewStmt));
final firstDialect = scope.options.supportedDialects.first;
emitter
..write('{')
..writeDriftRef('SqlDialect')
..write('.${firstDialect.name}: ')
..write(asDartLiteral(source.sqlCreateViewStmt))
..write('}');
}
buffer.writeln(';');
} else {
buffer.write('@override\n String? get createViewStmt => null;\n');
buffer.writeln('null;');
}
writeAsDslTable();

View File

@ -1,3 +1,4 @@
import 'package:drift/drift.dart';
import 'package:recase/recase.dart';
import 'package:sqlparser/sqlparser.dart' as sql;
import 'package:path/path.dart' show url;
@ -228,8 +229,50 @@ abstract class _NodeOrWriter {
return buffer.toString();
}
String sqlCode(sql.AstNode node) {
return SqlWriter(writer.options, escapeForDart: false).writeSql(node);
String sqlCode(sql.AstNode node, SqlDialect dialect) {
return SqlWriter(writer.options, dialect: dialect, escapeForDart: false)
.writeSql(node);
}
/// Builds a Dart expression writing the [node] into a Dart string.
///
/// If the code for [node] depends on the dialect, the code returned evaluates
/// to a `Map<SqlDialect, String>`. Otherwise, the code is a direct string
/// literal.
///
/// The boolean component in the record describes whether the code will be
/// dialect specific.
(String, bool) sqlByDialect(sql.AstNode node) {
final dialects = writer.options.supportedDialects;
if (dialects.length == 1) {
return (
SqlWriter(writer.options, dialect: dialects.single)
.writeNodeIntoStringLiteral(node),
false
);
}
final buffer = StringBuffer();
_writeSqlByDialectMap(node, buffer);
return (buffer.toString(), true);
}
void _writeSqlByDialectMap(sql.AstNode node, StringBuffer buffer) {
buffer.write('{');
for (final dialect in writer.options.supportedDialects) {
buffer
..write(drift('SqlDialect'))
..write(".${dialect.name}: '");
SqlWriter(writer.options, dialect: dialect, buffer: buffer)
.writeSql(node);
buffer.writeln("',");
}
buffer.write('}');
}
}
@ -302,16 +345,18 @@ class TextEmitter extends _Node {
void writeDart(AnnotatedDartCode code) => write(dartCode(code));
void writeSql(sql.AstNode node, {bool escapeForDartString = true}) {
SqlWriter(writer.options,
escapeForDart: escapeForDartString, buffer: buffer)
.writeSql(node);
void writeSql(sql.AstNode node,
{required SqlDialect dialect, bool escapeForDartString = true}) {
SqlWriter(
writer.options,
dialect: dialect,
escapeForDart: escapeForDartString,
buffer: buffer,
).writeSql(node);
}
void writeSqlAsDartLiteral(sql.AstNode node) {
buffer.write("'");
writeSql(node);
buffer.write("'");
void writeSqlByDialectMap(sql.AstNode node) {
_writeSqlByDialectMap(node, buffer);
}
}

View File

@ -76,7 +76,8 @@ SELECT rowid, highlight(example_table_search, 0, '[match]', '[match]') name,
{'a|lib/a.drift': 'CREATE VIRTUAL TABLE demo USING spellfix1;'},
options: DriftOptions.defaults(
dialect: DialectOptions(
SqlDialect.sqlite,
null,
[SqlDialect.sqlite],
SqliteAnalysisOptions(
modules: [SqlModule.spellfix1],
),

View File

@ -307,7 +307,7 @@ TypeConverter<Object, int> myConverter() => throw UnimplementedError();
'a|lib/a.drift.dart': decodedMatches(
allOf(
contains(
''''CREATE VIEW my_view AS SELECT CAST(1 AS INT) AS c1, CAST(\\'bar\\' AS TEXT) AS c2, 1 AS c3, NULLIF(1, 2) AS c4';'''),
''''CREATE VIEW my_view AS SELECT CAST(1 AS INT) AS c1, CAST(\\'bar\\' AS TEXT) AS c2, 1 AS c3, NULLIF(1, 2) AS c4','''),
contains(r'$converterc1 ='),
contains(r'$converterc2 ='),
contains(r'$converterc3 ='),

View File

@ -1,4 +1,5 @@
import 'package:build_test/build_test.dart';
import 'package:drift/drift.dart';
import 'package:drift_dev/src/analysis/options.dart';
import 'package:drift_dev/src/writer/import_manager.dart';
import 'package:drift_dev/src/writer/queries/query_writer.dart';
@ -11,14 +12,16 @@ import '../../utils.dart';
void main() {
Future<String> generateForQueryInDriftFile(String driftFile,
{DriftOptions options = const DriftOptions.defaults()}) async {
{DriftOptions options = const DriftOptions.defaults(
generateNamedParameters: true,
)}) async {
final state =
TestBackend.inTest({'a|lib/main.drift': driftFile}, options: options);
final file = await state.analyze('package:a/main.drift');
state.expectNoErrors();
final writer = Writer(
const DriftOptions.defaults(generateNamedParameters: true),
options,
generationOptions: GenerationOptions(
imports: ImportManagerForPartFiles(),
),
@ -528,4 +531,26 @@ class ADrift extends i1.ModularAccessor {
}'''))
}, outputs.dartOutputs, outputs);
});
test('creates dialect-specific query code', () async {
final result = await generateForQueryInDriftFile(
r'''
query (:foo AS TEXT): SELECT :foo;
''',
options: const DriftOptions.defaults(
dialect: DialectOptions(
null, [SqlDialect.sqlite, SqlDialect.postgres], null),
),
);
expect(
result,
contains(
'switch (executor.dialect) {'
"SqlDialect.sqlite => 'SELECT ?1 AS _c0', "
"SqlDialect.postgres || _ => 'SELECT \\\$1 AS _c0', "
'}',
),
);
});
}

View File

@ -6,14 +6,18 @@ import 'package:sqlparser/sqlparser.dart';
import 'package:test/test.dart';
void main() {
void check(String sql, String expectedDart,
{DriftOptions options = const DriftOptions.defaults()}) {
void check(
String sql,
String expectedDart, {
DriftOptions options = const DriftOptions.defaults(),
SqlDialect dialect = SqlDialect.sqlite,
}) {
final engine = SqlEngine();
final context = engine.analyze(sql);
final query = SqlSelectQuery('name', context, context.root, [], [],
InferredResultSet(null, []), null, null);
final result = SqlWriter(options, query: query).write();
final result = SqlWriter(options, dialect: dialect, query: query).write();
expect(result, expectedDart);
}
@ -33,7 +37,6 @@ void main() {
test('escapes postgres keywords', () {
check('SELECT * FROM user', "'SELECT * FROM user'");
check('SELECT * FROM user', "'SELECT * FROM \"user\"'",
options: DriftOptions.defaults(
dialect: DialectOptions(SqlDialect.postgres, null)));
dialect: SqlDialect.postgres);
});
}

View File

@ -602,8 +602,10 @@ class PopularUsers extends i0.ViewInfo<i1.PopularUsers, i1.PopularUser>
@override
String get entityName => 'popular_users';
@override
String get createViewStmt =>
'CREATE VIEW popular_users AS SELECT * FROM users ORDER BY (SELECT count(*) FROM follows WHERE followed = users.id)';
Map<i0.SqlDialect, String> get createViewStatements => {
i0.SqlDialect.sqlite:
'CREATE VIEW popular_users AS SELECT * FROM users ORDER BY (SELECT count(*) FROM follows WHERE followed = users.id)',
};
@override
PopularUsers get asDslTable => this;
@override

View File

@ -9,9 +9,7 @@ targets:
raw_result_set_data: false
named_parameters: false
sql:
# As sqlite3 is compatible with the postgres dialect (but not vice-versa), we're
# using this dialect so that we can run the tests on postgres as well.
dialect: postgres
dialects: [sqlite, postgres]
options:
version: "3.37"
modules:

View File

@ -336,7 +336,6 @@ class $FriendshipsTable extends Friendships
requiredDuringInsert: false,
defaultConstraints: GeneratedColumn.constraintsDependsOnDialect({
SqlDialect.sqlite: 'CHECK ("really_good_friends" IN (0, 1))',
SqlDialect.mysql: '',
SqlDialect.postgres: '',
}),
defaultValue: const Constant(false));
@ -554,7 +553,13 @@ abstract class _$Database extends GeneratedDatabase {
late final $FriendshipsTable friendships = $FriendshipsTable(this);
Selectable<User> mostPopularUsers(int amount) {
return customSelect(
'SELECT * FROM users AS u ORDER BY (SELECT COUNT(*) FROM friendships WHERE first_user = u.id OR second_user = u.id) DESC LIMIT \$1',
switch (executor.dialect) {
SqlDialect.sqlite =>
'SELECT * FROM users AS u ORDER BY (SELECT COUNT(*) FROM friendships WHERE first_user = u.id OR second_user = u.id) DESC LIMIT ?1',
SqlDialect.postgres ||
_ =>
'SELECT * FROM users AS u ORDER BY (SELECT COUNT(*) FROM friendships WHERE first_user = u.id OR second_user = u.id) DESC LIMIT \$1',
},
variables: [
Variable<int>(amount)
],
@ -566,7 +571,13 @@ abstract class _$Database extends GeneratedDatabase {
Selectable<int> amountOfGoodFriends(int user) {
return customSelect(
'SELECT COUNT(*) AS _c0 FROM friendships AS f WHERE f.really_good_friends = TRUE AND(f.first_user = \$1 OR f.second_user = \$1)',
switch (executor.dialect) {
SqlDialect.sqlite =>
'SELECT COUNT(*) AS _c0 FROM friendships AS f WHERE f.really_good_friends = TRUE AND(f.first_user = ?1 OR f.second_user = ?1)',
SqlDialect.postgres ||
_ =>
'SELECT COUNT(*) AS _c0 FROM friendships AS f WHERE f.really_good_friends = TRUE AND(f.first_user = \$1 OR f.second_user = \$1)',
},
variables: [
Variable<int>(user)
],
@ -577,19 +588,23 @@ abstract class _$Database extends GeneratedDatabase {
Selectable<FriendshipsOfResult> friendshipsOf(int user) {
return customSelect(
'SELECT f.really_good_friends,"user"."id" AS "nested_0.id", "user"."name" AS "nested_0.name", "user"."birth_date" AS "nested_0.birth_date", "user"."profile_picture" AS "nested_0.profile_picture", "user"."preferences" AS "nested_0.preferences" FROM friendships AS f INNER JOIN users AS "user" ON "user".id IN (f.first_user, f.second_user) AND "user".id != \$1 WHERE(f.first_user = \$1 OR f.second_user = \$1)',
switch (executor.dialect) {
SqlDialect.sqlite =>
'SELECT f.really_good_friends,"user"."id" AS "nested_0.id", "user"."name" AS "nested_0.name", "user"."birth_date" AS "nested_0.birth_date", "user"."profile_picture" AS "nested_0.profile_picture", "user"."preferences" AS "nested_0.preferences" FROM friendships AS f INNER JOIN users AS user ON user.id IN (f.first_user, f.second_user) AND user.id != ?1 WHERE(f.first_user = ?1 OR f.second_user = ?1)',
SqlDialect.postgres ||
_ =>
'SELECT f.really_good_friends,"user"."id" AS "nested_0.id", "user"."name" AS "nested_0.name", "user"."birth_date" AS "nested_0.birth_date", "user"."profile_picture" AS "nested_0.profile_picture", "user"."preferences" AS "nested_0.preferences" FROM friendships AS f INNER JOIN users AS "user" ON "user".id IN (f.first_user, f.second_user) AND "user".id != \$1 WHERE(f.first_user = \$1 OR f.second_user = \$1)',
},
variables: [
Variable<int>(user)
],
readsFrom: {
friendships,
users,
}).asyncMap((QueryRow row) async {
return FriendshipsOfResult(
reallyGoodFriends: row.read<bool>('really_good_friends'),
user: await users.mapFromRow(row, tablePrefix: 'nested_0'),
);
});
}).asyncMap((QueryRow row) async => FriendshipsOfResult(
reallyGoodFriends: row.read<bool>('really_good_friends'),
user: await users.mapFromRow(row, tablePrefix: 'nested_0'),
));
}
Selectable<int> userCount() {
@ -601,7 +616,13 @@ abstract class _$Database extends GeneratedDatabase {
}
Selectable<Preferences?> settingsFor(int user) {
return customSelect('SELECT preferences FROM users WHERE id = \$1',
return customSelect(
switch (executor.dialect) {
SqlDialect.sqlite => 'SELECT preferences FROM users WHERE id = ?1',
SqlDialect.postgres ||
_ =>
'SELECT preferences FROM users WHERE id = \$1',
},
variables: [
Variable<int>(user)
],
@ -626,7 +647,13 @@ abstract class _$Database extends GeneratedDatabase {
Future<List<Friendship>> returning(int var1, int var2, bool var3) {
return customWriteReturning(
'INSERT INTO friendships VALUES (\$1, \$2, \$3) RETURNING *',
switch (executor.dialect) {
SqlDialect.sqlite =>
'INSERT INTO friendships VALUES (?1, ?2, ?3) RETURNING *',
SqlDialect.postgres ||
_ =>
'INSERT INTO friendships VALUES (\$1, \$2, \$3) RETURNING *',
},
variables: [
Variable<int>(var1),
Variable<int>(var2),

View File

@ -4,7 +4,7 @@ version: 1.0.0
# homepage: https://www.example.com
environment:
sdk: '>=2.17.0 <3.0.0'
sdk: '>=3.0.0 <4.0.0'
dependencies:
drift: ^2.0.0-0