diff --git a/docs/lib/snippets/modular/custom_types/drift_table.drift b/docs/lib/snippets/modular/custom_types/drift_table.drift new file mode 100644 index 00000000..92317018 --- /dev/null +++ b/docs/lib/snippets/modular/custom_types/drift_table.drift @@ -0,0 +1,7 @@ +import 'type.dart'; + +CREATE TABLE periodic_reminders ( + id INTEGER NOT NULL PRIMARY KEY, + frequency `const DurationType()` NOT NULL, + reminder TEXT NOT NULL +); diff --git a/docs/lib/snippets/modular/custom_types/drift_table.drift.dart b/docs/lib/snippets/modular/custom_types/drift_table.drift.dart new file mode 100644 index 00000000..db590906 --- /dev/null +++ b/docs/lib/snippets/modular/custom_types/drift_table.drift.dart @@ -0,0 +1,223 @@ +// ignore_for_file: type=lint +import 'package:drift/drift.dart' as i0; +import 'package:drift_docs/snippets/modular/custom_types/drift_table.drift.dart' + as i1; +import 'package:drift_docs/snippets/modular/custom_types/type.dart' as i2; + +class PeriodicReminders extends i0.Table + with i0.TableInfo { + @override + final i0.GeneratedDatabase attachedDatabase; + final String? _alias; + PeriodicReminders(this.attachedDatabase, [this._alias]); + static const i0.VerificationMeta _idMeta = const i0.VerificationMeta('id'); + late final i0.GeneratedColumn id = i0.GeneratedColumn( + 'id', aliasedName, false, + type: i0.DriftSqlType.int, + requiredDuringInsert: false, + $customConstraints: 'NOT NULL PRIMARY KEY'); + static const i0.VerificationMeta _frequencyMeta = + const i0.VerificationMeta('frequency'); + late final i0.GeneratedColumn frequency = + i0.GeneratedColumn('frequency', aliasedName, false, + type: const i2.DurationType(), + requiredDuringInsert: true, + $customConstraints: 'NOT NULL'); + static const i0.VerificationMeta _reminderMeta = + const i0.VerificationMeta('reminder'); + late final i0.GeneratedColumn reminder = i0.GeneratedColumn( + 'reminder', aliasedName, false, + type: i0.DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL'); + @override + List get $columns => [id, frequency, reminder]; + @override + String get aliasedName => _alias ?? actualTableName; + @override + String get actualTableName => $name; + static const String $name = 'periodic_reminders'; + @override + i0.VerificationContext validateIntegrity( + i0.Insertable instance, + {bool isInserting = false}) { + final context = i0.VerificationContext(); + final data = instance.toColumns(true); + if (data.containsKey('id')) { + context.handle(_idMeta, id.isAcceptableOrUnknown(data['id']!, _idMeta)); + } + if (data.containsKey('frequency')) { + context.handle(_frequencyMeta, + frequency.isAcceptableOrUnknown(data['frequency']!, _frequencyMeta)); + } else if (isInserting) { + context.missing(_frequencyMeta); + } + if (data.containsKey('reminder')) { + context.handle(_reminderMeta, + reminder.isAcceptableOrUnknown(data['reminder']!, _reminderMeta)); + } else if (isInserting) { + context.missing(_reminderMeta); + } + return context; + } + + @override + Set get $primaryKey => {id}; + @override + i1.PeriodicReminder map(Map data, {String? tablePrefix}) { + final effectivePrefix = tablePrefix != null ? '$tablePrefix.' : ''; + return i1.PeriodicReminder( + id: attachedDatabase.typeMapping + .read(i0.DriftSqlType.int, data['${effectivePrefix}id'])!, + frequency: attachedDatabase.typeMapping + .read(const i2.DurationType(), data['${effectivePrefix}frequency'])!, + reminder: attachedDatabase.typeMapping + .read(i0.DriftSqlType.string, data['${effectivePrefix}reminder'])!, + ); + } + + @override + PeriodicReminders createAlias(String alias) { + return PeriodicReminders(attachedDatabase, alias); + } + + @override + bool get dontWriteConstraints => true; +} + +class PeriodicReminder extends i0.DataClass + implements i0.Insertable { + final int id; + final Duration frequency; + final String reminder; + const PeriodicReminder( + {required this.id, required this.frequency, required this.reminder}); + @override + Map toColumns(bool nullToAbsent) { + final map = {}; + map['id'] = i0.Variable(id); + map['frequency'] = i0.Variable(frequency); + map['reminder'] = i0.Variable(reminder); + return map; + } + + i1.PeriodicRemindersCompanion toCompanion(bool nullToAbsent) { + return i1.PeriodicRemindersCompanion( + id: i0.Value(id), + frequency: i0.Value(frequency), + reminder: i0.Value(reminder), + ); + } + + factory PeriodicReminder.fromJson(Map json, + {i0.ValueSerializer? serializer}) { + serializer ??= i0.driftRuntimeOptions.defaultSerializer; + return PeriodicReminder( + id: serializer.fromJson(json['id']), + frequency: serializer.fromJson(json['frequency']), + reminder: serializer.fromJson(json['reminder']), + ); + } + @override + Map toJson({i0.ValueSerializer? serializer}) { + serializer ??= i0.driftRuntimeOptions.defaultSerializer; + return { + 'id': serializer.toJson(id), + 'frequency': serializer.toJson(frequency), + 'reminder': serializer.toJson(reminder), + }; + } + + i1.PeriodicReminder copyWith( + {int? id, Duration? frequency, String? reminder}) => + i1.PeriodicReminder( + id: id ?? this.id, + frequency: frequency ?? this.frequency, + reminder: reminder ?? this.reminder, + ); + @override + String toString() { + return (StringBuffer('PeriodicReminder(') + ..write('id: $id, ') + ..write('frequency: $frequency, ') + ..write('reminder: $reminder') + ..write(')')) + .toString(); + } + + @override + int get hashCode => Object.hash(id, frequency, reminder); + @override + bool operator ==(Object other) => + identical(this, other) || + (other is i1.PeriodicReminder && + other.id == this.id && + other.frequency == this.frequency && + other.reminder == this.reminder); +} + +class PeriodicRemindersCompanion + extends i0.UpdateCompanion { + final i0.Value id; + final i0.Value frequency; + final i0.Value reminder; + const PeriodicRemindersCompanion({ + this.id = const i0.Value.absent(), + this.frequency = const i0.Value.absent(), + this.reminder = const i0.Value.absent(), + }); + PeriodicRemindersCompanion.insert({ + this.id = const i0.Value.absent(), + required Duration frequency, + required String reminder, + }) : frequency = i0.Value(frequency), + reminder = i0.Value(reminder); + static i0.Insertable custom({ + i0.Expression? id, + i0.Expression? frequency, + i0.Expression? reminder, + }) { + return i0.RawValuesInsertable({ + if (id != null) 'id': id, + if (frequency != null) 'frequency': frequency, + if (reminder != null) 'reminder': reminder, + }); + } + + i1.PeriodicRemindersCompanion copyWith( + {i0.Value? id, + i0.Value? frequency, + i0.Value? reminder}) { + return i1.PeriodicRemindersCompanion( + id: id ?? this.id, + frequency: frequency ?? this.frequency, + reminder: reminder ?? this.reminder, + ); + } + + @override + Map toColumns(bool nullToAbsent) { + final map = {}; + if (id.present) { + map['id'] = i0.Variable(id.value); + } + if (frequency.present) { + map['frequency'] = + i0.Variable(frequency.value, const i2.DurationType()); + } + if (reminder.present) { + map['reminder'] = i0.Variable(reminder.value); + } + return map; + } + + @override + String toString() { + return (StringBuffer('PeriodicRemindersCompanion(') + ..write('id: $id, ') + ..write('frequency: $frequency, ') + ..write('reminder: $reminder') + ..write(')')) + .toString(); + } +} diff --git a/docs/lib/snippets/modular/custom_types/table.dart b/docs/lib/snippets/modular/custom_types/table.dart new file mode 100644 index 00000000..80790571 --- /dev/null +++ b/docs/lib/snippets/modular/custom_types/table.dart @@ -0,0 +1,9 @@ +import 'package:drift/drift.dart'; +import 'type.dart'; + +class PeriodicReminders extends Table { + IntColumn get id => integer().autoIncrement()(); + Column get frequency => customType(const DurationType()) + .clientDefault(() => Duration(minutes: 15))(); + TextColumn get reminder => text()(); +} diff --git a/docs/lib/snippets/modular/custom_types/table.drift.dart b/docs/lib/snippets/modular/custom_types/table.drift.dart new file mode 100644 index 00000000..2e3239e1 --- /dev/null +++ b/docs/lib/snippets/modular/custom_types/table.drift.dart @@ -0,0 +1,221 @@ +// ignore_for_file: type=lint +import 'package:drift/drift.dart' as i0; +import 'package:drift_docs/snippets/modular/custom_types/table.drift.dart' + as i1; +import 'package:drift_docs/snippets/modular/custom_types/type.dart' as i2; +import 'package:drift_docs/snippets/modular/custom_types/table.dart' as i3; + +class $PeriodicRemindersTable extends i3.PeriodicReminders + with i0.TableInfo<$PeriodicRemindersTable, i1.PeriodicReminder> { + @override + final i0.GeneratedDatabase attachedDatabase; + final String? _alias; + $PeriodicRemindersTable(this.attachedDatabase, [this._alias]); + static const i0.VerificationMeta _idMeta = const i0.VerificationMeta('id'); + @override + late final i0.GeneratedColumn id = i0.GeneratedColumn( + 'id', aliasedName, false, + hasAutoIncrement: true, + type: i0.DriftSqlType.int, + requiredDuringInsert: false, + defaultConstraints: + i0.GeneratedColumn.constraintIsAlways('PRIMARY KEY AUTOINCREMENT')); + static const i0.VerificationMeta _frequencyMeta = + const i0.VerificationMeta('frequency'); + @override + late final i0.GeneratedColumn frequency = + i0.GeneratedColumn('frequency', aliasedName, false, + type: const i2.DurationType(), + requiredDuringInsert: false, + clientDefault: () => Duration(minutes: 15)); + static const i0.VerificationMeta _reminderMeta = + const i0.VerificationMeta('reminder'); + @override + late final i0.GeneratedColumn reminder = i0.GeneratedColumn( + 'reminder', aliasedName, false, + type: i0.DriftSqlType.string, requiredDuringInsert: true); + @override + List get $columns => [id, frequency, reminder]; + @override + String get aliasedName => _alias ?? actualTableName; + @override + String get actualTableName => $name; + static const String $name = 'periodic_reminders'; + @override + i0.VerificationContext validateIntegrity( + i0.Insertable instance, + {bool isInserting = false}) { + final context = i0.VerificationContext(); + final data = instance.toColumns(true); + if (data.containsKey('id')) { + context.handle(_idMeta, id.isAcceptableOrUnknown(data['id']!, _idMeta)); + } + if (data.containsKey('frequency')) { + context.handle(_frequencyMeta, + frequency.isAcceptableOrUnknown(data['frequency']!, _frequencyMeta)); + } + if (data.containsKey('reminder')) { + context.handle(_reminderMeta, + reminder.isAcceptableOrUnknown(data['reminder']!, _reminderMeta)); + } else if (isInserting) { + context.missing(_reminderMeta); + } + return context; + } + + @override + Set get $primaryKey => {id}; + @override + i1.PeriodicReminder map(Map data, {String? tablePrefix}) { + final effectivePrefix = tablePrefix != null ? '$tablePrefix.' : ''; + return i1.PeriodicReminder( + id: attachedDatabase.typeMapping + .read(i0.DriftSqlType.int, data['${effectivePrefix}id'])!, + frequency: attachedDatabase.typeMapping + .read(const i2.DurationType(), data['${effectivePrefix}frequency'])!, + reminder: attachedDatabase.typeMapping + .read(i0.DriftSqlType.string, data['${effectivePrefix}reminder'])!, + ); + } + + @override + $PeriodicRemindersTable createAlias(String alias) { + return $PeriodicRemindersTable(attachedDatabase, alias); + } +} + +class PeriodicReminder extends i0.DataClass + implements i0.Insertable { + final int id; + final Duration frequency; + final String reminder; + const PeriodicReminder( + {required this.id, required this.frequency, required this.reminder}); + @override + Map toColumns(bool nullToAbsent) { + final map = {}; + map['id'] = i0.Variable(id); + map['frequency'] = i0.Variable(frequency); + map['reminder'] = i0.Variable(reminder); + return map; + } + + i1.PeriodicRemindersCompanion toCompanion(bool nullToAbsent) { + return i1.PeriodicRemindersCompanion( + id: i0.Value(id), + frequency: i0.Value(frequency), + reminder: i0.Value(reminder), + ); + } + + factory PeriodicReminder.fromJson(Map json, + {i0.ValueSerializer? serializer}) { + serializer ??= i0.driftRuntimeOptions.defaultSerializer; + return PeriodicReminder( + id: serializer.fromJson(json['id']), + frequency: serializer.fromJson(json['frequency']), + reminder: serializer.fromJson(json['reminder']), + ); + } + @override + Map toJson({i0.ValueSerializer? serializer}) { + serializer ??= i0.driftRuntimeOptions.defaultSerializer; + return { + 'id': serializer.toJson(id), + 'frequency': serializer.toJson(frequency), + 'reminder': serializer.toJson(reminder), + }; + } + + i1.PeriodicReminder copyWith( + {int? id, Duration? frequency, String? reminder}) => + i1.PeriodicReminder( + id: id ?? this.id, + frequency: frequency ?? this.frequency, + reminder: reminder ?? this.reminder, + ); + @override + String toString() { + return (StringBuffer('PeriodicReminder(') + ..write('id: $id, ') + ..write('frequency: $frequency, ') + ..write('reminder: $reminder') + ..write(')')) + .toString(); + } + + @override + int get hashCode => Object.hash(id, frequency, reminder); + @override + bool operator ==(Object other) => + identical(this, other) || + (other is i1.PeriodicReminder && + other.id == this.id && + other.frequency == this.frequency && + other.reminder == this.reminder); +} + +class PeriodicRemindersCompanion + extends i0.UpdateCompanion { + final i0.Value id; + final i0.Value frequency; + final i0.Value reminder; + const PeriodicRemindersCompanion({ + this.id = const i0.Value.absent(), + this.frequency = const i0.Value.absent(), + this.reminder = const i0.Value.absent(), + }); + PeriodicRemindersCompanion.insert({ + this.id = const i0.Value.absent(), + this.frequency = const i0.Value.absent(), + required String reminder, + }) : reminder = i0.Value(reminder); + static i0.Insertable custom({ + i0.Expression? id, + i0.Expression? frequency, + i0.Expression? reminder, + }) { + return i0.RawValuesInsertable({ + if (id != null) 'id': id, + if (frequency != null) 'frequency': frequency, + if (reminder != null) 'reminder': reminder, + }); + } + + i1.PeriodicRemindersCompanion copyWith( + {i0.Value? id, + i0.Value? frequency, + i0.Value? reminder}) { + return i1.PeriodicRemindersCompanion( + id: id ?? this.id, + frequency: frequency ?? this.frequency, + reminder: reminder ?? this.reminder, + ); + } + + @override + Map toColumns(bool nullToAbsent) { + final map = {}; + if (id.present) { + map['id'] = i0.Variable(id.value); + } + if (frequency.present) { + map['frequency'] = + i0.Variable(frequency.value, const i2.DurationType()); + } + if (reminder.present) { + map['reminder'] = i0.Variable(reminder.value); + } + return map; + } + + @override + String toString() { + return (StringBuffer('PeriodicRemindersCompanion(') + ..write('id: $id, ') + ..write('frequency: $frequency, ') + ..write('reminder: $reminder') + ..write(')')) + .toString(); + } +} diff --git a/docs/lib/snippets/modular/custom_types/type.dart b/docs/lib/snippets/modular/custom_types/type.dart new file mode 100644 index 00000000..4acc263a --- /dev/null +++ b/docs/lib/snippets/modular/custom_types/type.dart @@ -0,0 +1,19 @@ +import 'package:drift/drift.dart'; + +class DurationType implements CustomSqlType { + const DurationType(); + + @override + String mapToSqlLiteral(Duration dartValue) { + return "interval '${dartValue.inMicroseconds} microseconds'"; + } + + @override + Object mapToSqlParameter(Duration dartValue) => dartValue; + + @override + Duration read(Object fromSql) => fromSql as Duration; + + @override + String sqlTypeName(GenerationContext context) => 'interval'; +} diff --git a/docs/lib/snippets/modular/many_to_many/json.drift.dart b/docs/lib/snippets/modular/many_to_many/json.drift.dart index e5542372..57ddbccf 100644 --- a/docs/lib/snippets/modular/many_to_many/json.drift.dart +++ b/docs/lib/snippets/modular/many_to_many/json.drift.dart @@ -187,6 +187,7 @@ class ShoppingCartsCompanion extends i0.UpdateCompanion { } if (entries.present) { final converter = i2.$ShoppingCartsTable.$converterentries; + map['entries'] = i0.Variable(converter.toSql(entries.value)); } return map; diff --git a/docs/pages/docs/SQL API/types.md b/docs/pages/docs/SQL API/types.md new file mode 100644 index 00000000..c868c797 --- /dev/null +++ b/docs/pages/docs/SQL API/types.md @@ -0,0 +1,86 @@ +--- +data: + title: "Custom SQL types" + weight: 10 + description: Use custom SQL types in Drift files and Dart code. + +template: layouts/docs/single +--- + +Drift's core library is written with sqlite3 as a primary target. This is +reflected in the [SQL types][types] drift supports out of the box - these +types supported by sqlite3 with a few additions that are handled in Dart. + +Other databases for which drift has limited support commonly support more types. +For instance, postgres has a dedicated type for durations, JSON values, UUIDs +and more. With a sqlite3 database, you'd use a [type converter][type converters] +to store these values with the types supported by sqlite3. +While type converters can also work here, they tell drift to use a regular text +column under the hood. When a database has builtin support for UUIDs for instance, +this could lead to less efficient statements or issues with other applications +talking to same database. +For this reason, drift allows the use of "custom types" - types that are not defined +in the core `drift` package and don't work with all databases. + +{% block "blocks/alert" title="When to use custom types - summary" %} +Custom types are a good tool when extending drift support to new database engines +with their own types not already covered by drift. + +Unless you're extending drift to work with a new database package (which is awesome, +please reach out!), you probably don't need to implement custom types yourself. +Packages like `drift_postgres` already define relevant custom types for you. +{% endblock %} + +## Defining a type + +As an example, let's assume we have a database with native support for `Duration` +values via the `interval` type. We're using a database driver that also has native +support for `Duration` values, meaning that they can be passed to the database in +prepared statements and also be read from rows without manual conversions. + +In that case, a custom type class to implement `Duration` support for drift would be +added: + +{% include "blocks/snippet" snippets = ('package:drift_docs/snippets/modular/custom_types/type.dart.excerpt.json' | readString | json_decode) %} + +This type defines the following things: + +- When `Duration` values are mapped to SQL literals (for instance, because they're used in `Constant`s), + we represent them as `interval '123754 microseconds'` in SQL. +- When a `Duration` value is mapped to a parameter, we just use the value directly (since we + assume it is supported by the underlying database driver here). +- Similarly, we expect that the database driver correctly returns durations as instances of + `Duration`, so the other way around in `read` also just casts the value. +- The name to use in `CREATE TABLE` statements and casts is `interval`. + +## Using custom types + +### In Dart + +To define a custom type on a Dart table, use the `customType` column builder method with the type: + +{% include "blocks/snippet" snippets = ('package:drift_docs/snippets/modular/custom_types/table.dart.excerpt.json' | readString | json_decode) %} + +As the example shows, other column constraints like `clientDefault` can still be added to custom +columns. You can even combine custom columns and type converters if needed. + +This is enough to get most queries to work, but in some advanced scenarios you may have to provide +more information to use custom types. +For instance, when manually constructing a `Variable` or a `Constant` with a custom type, the custom +type must be added as a second parameter to the constructor. This is because, unlike for builtin types, +drift doesn't have a central register describing how to deal with custom type values. + +### In SQL + +In SQL, Drift's [inline Dart]({{ 'drift_files.md#dart-interop' | pageUrl }}) syntax may be used to define +the custom type: + +{% include "blocks/snippet" snippets = ('package:drift_docs/snippets/modular/custom_types/drift_table.drift.excerpt.json' | readString | json_decode) %} + +Please note that support for custom types in drift files is currently limited. +For instance, custom types are not currently supported in `CAST` expressions. +If you are interested in advanced analysis support for custom types, please reach out by +opening an issue or a discussion describing your use-cases, thanks! + +[types]: {{ '../Dart API/tables.md#supported-column-types' | pageUrl }} +[type converters]: {{ '../type_converters.md' | pageUrl }} diff --git a/drift/CHANGELOG.md b/drift/CHANGELOG.md index e7fafe98..fb569da8 100644 --- a/drift/CHANGELOG.md +++ b/drift/CHANGELOG.md @@ -1,6 +1,8 @@ ## 2.13.0-dev - Add APIs to setup Wasm databases with custom drift workers. +- Add support for [custom types](https://drift.simonbinder.eu/docs/sql-api/types/), + which are useful when extending drift to support other database engines. - Add `Expression.and` and `Expression.or` to create disjunctions and conjunctions of sub-predicates. - Step-by-step migrations now save the intermediate schema version after each step. diff --git a/drift/example/main.g.dart b/drift/example/main.g.dart index adb70e1b..966f06b6 100644 --- a/drift/example/main.g.dart +++ b/drift/example/main.g.dart @@ -691,17 +691,10 @@ abstract class _$Database extends GeneratedDatabase { $TodoCategoryItemCountView(this); late final $TodoItemWithCategoryNameViewView customViewName = $TodoItemWithCategoryNameViewView(this); - late final Index itemTitle = - Index('item_title', 'CREATE INDEX item_title ON todo_items (title)'); @override Iterable> get allTables => allSchemaEntities.whereType>(); @override - List get allSchemaEntities => [ - todoCategories, - todoItems, - todoCategoryItemCount, - customViewName, - itemTitle - ]; + List get allSchemaEntities => + [todoCategories, todoItems, todoCategoryItemCount, customViewName]; } diff --git a/drift/lib/drift.dart b/drift/lib/drift.dart index 40f8f423..51886b96 100644 --- a/drift/lib/drift.dart +++ b/drift/lib/drift.dart @@ -18,7 +18,7 @@ export 'src/runtime/executor/executor.dart'; export 'src/runtime/query_builder/query_builder.dart' hide CaseWhenExpressionWithBase, BaseCaseWhenExpression; export 'src/runtime/types/converters.dart'; -export 'src/runtime/types/mapping.dart'; +export 'src/runtime/types/mapping.dart' hide BaseSqlType; export 'src/utils/lazy_database.dart'; /// A [ListEquality] instance used by generated drift code for the `==` and diff --git a/drift/lib/src/drift_dev_helper.dart b/drift/lib/src/drift_dev_helper.dart index 3243c1b5..af7ff171 100644 --- a/drift/lib/src/drift_dev_helper.dart +++ b/drift/lib/src/drift_dev_helper.dart @@ -2,7 +2,7 @@ export 'dart:typed_data' show Uint8List; export 'runtime/types/converters.dart' show TypeConverter, JsonTypeConverter2; -export 'runtime/types/mapping.dart' show DriftAny; +export 'runtime/types/mapping.dart' show DriftAny, CustomSqlType; export 'runtime/query_builder/query_builder.dart' show TableInfo; export 'dsl/dsl.dart' diff --git a/drift/lib/src/dsl/table.dart b/drift/lib/src/dsl/table.dart index 072f12f6..1ca952c7 100644 --- a/drift/lib/src/dsl/table.dart +++ b/drift/lib/src/dsl/table.dart @@ -196,6 +196,17 @@ abstract class Table extends HasResultSet { /// ``` @protected ColumnBuilder real() => _isGenerated(); + + /// Defines a column with a custom [type] when used as a getter. + /// + /// For more information on custom types and when they can be useful, see + /// https://drift.simonbinder.eu/docs/sql-api/types/. + /// + /// For most users, [TypeConverter]s are a more appropriate tool to store + /// custom values in the database. + @protected + ColumnBuilder customType(CustomSqlType type) => + _isGenerated(); } /// Subclasses represent a view in a database generated by drift. diff --git a/drift/lib/src/runtime/query_builder/expressions/datetimes.dart b/drift/lib/src/runtime/query_builder/expressions/datetimes.dart index 99f0c292..ae4c8794 100644 --- a/drift/lib/src/runtime/query_builder/expressions/datetimes.dart +++ b/drift/lib/src/runtime/query_builder/expressions/datetimes.dart @@ -33,6 +33,7 @@ const Expression currentDateAndTime = _DependingOnDateTimeExpression( 'strftime', [Constant('%s'), _currentTimestampLiteral], ), + DriftSqlType.dateTime, ), ); @@ -341,7 +342,7 @@ class _DependingOnDateTimeExpression extends Expression { /// For another explanation of modifiers, see the [sqlite3 docs]. /// /// [sqlite3 docs]: https://sqlite.org/lang_datefunc.html#modifiers -class DateTimeModifier extends Constant { +final class DateTimeModifier extends Constant { const DateTimeModifier._(super.value); /// Adds or subtracts [days] calendar days from the date time value. diff --git a/drift/lib/src/runtime/query_builder/expressions/expression.dart b/drift/lib/src/runtime/query_builder/expressions/expression.dart index 26a4c814..64aaa096 100644 --- a/drift/lib/src/runtime/query_builder/expressions/expression.dart +++ b/drift/lib/src/runtime/query_builder/expressions/expression.dart @@ -90,8 +90,13 @@ abstract class Expression implements FunctionParameter { /// Note that this does not do a meaningful conversion for drift-only types /// like `bool` or `DateTime`. Both would simply generate a `CAST AS INT` /// expression. - Expression cast() { - return _CastInSqlExpression(this); + /// + /// The optional [type] parameter can be used to specify the SQL type to cast + /// to. This is mainly useful for [CustomSqlType]s. For types supported by + /// drift, [DriftSqlType.forType] will be used as a default. + Expression cast([BaseSqlType? type]) { + return _CastInSqlExpression( + this, type ?? DriftSqlType.forType()); } /// Generates an `IS` expression in SQL, comparing this expression with the @@ -269,8 +274,11 @@ abstract class Expression implements FunctionParameter { inner.writeAroundPrecedence(ctx, precedence); } - /// The supported [DriftSqlType] backing this expression. - DriftSqlType get driftSqlType => DriftSqlType.forType(); + /// The [BaseSqlType] backing this expression. + /// + /// This is a recognized [DriftSqlType] for all expressions for which a custom + /// type has not explicitly been set. + BaseSqlType get driftSqlType => DriftSqlType.forType(); /// Chains all [predicates] together into a single expression that will /// evaluate to `true` iff any of the [predicates] evaluates to `true`. @@ -503,16 +511,20 @@ class _DartCastExpression class _CastInSqlExpression extends Expression { final Expression inner; + final BaseSqlType targetType; @override Precedence get precedence => Precedence.primary; - const _CastInSqlExpression(this.inner); + @override + BaseSqlType get driftSqlType => targetType; + + const _CastInSqlExpression(this.inner, this.targetType); @override void writeInto(GenerationContext context) { - final type = DriftSqlType.forType(); - if (type == DriftSqlType.any) { + // ignore: unrelated_type_equality_checks + if (targetType == DriftSqlType.any) { inner.writeInto(context); // No need to cast } @@ -523,7 +535,7 @@ class _CastInSqlExpression // ones used in a create table statement. // ignore: unnecessary_cast - typeName = switch (type as DriftSqlType) { + typeName = switch (targetType) { DriftSqlType.int || DriftSqlType.bigInt || DriftSqlType.bool => @@ -533,9 +545,10 @@ class _CastInSqlExpression DriftSqlType.blob => 'BINARY', DriftSqlType.dateTime => 'DATETIME', DriftSqlType.any => '', + CustomSqlType() => targetType.sqlTypeName(context), }; } else { - typeName = type.sqlTypeName(context); + typeName = targetType.sqlTypeName(context); } context.buffer.write('CAST('); diff --git a/drift/lib/src/runtime/query_builder/expressions/variables.dart b/drift/lib/src/runtime/query_builder/expressions/variables.dart index 21a2d9b2..76bff618 100644 --- a/drift/lib/src/runtime/query_builder/expressions/variables.dart +++ b/drift/lib/src/runtime/query_builder/expressions/variables.dart @@ -5,9 +5,10 @@ part of '../query_builder.dart'; /// An expression that represents the value of a dart object encoded to sql /// using prepared statements. -class Variable extends Expression { +final class Variable extends Expression { /// The Dart value that will be sent to the database final T? value; + final CustomSqlType? _customType; // note that we keep the identity hash/equals here because each variable would // get its own index in sqlite and is thus different. @@ -18,8 +19,14 @@ class Variable extends Expression { @override int get hashCode => value.hashCode; + @override + BaseSqlType get driftSqlType => _customType ?? super.driftSqlType; + /// Constructs a new variable from the [value]. - const Variable(this.value); + /// + /// For variables of [CustomSqlType]s, the `type` can also be provided as a + /// parameter to control how the value is mapped to SQL. + const Variable(this.value, [this._customType]); /// Creates a variable that holds the specified boolean. static Variable withBool(bool value) { @@ -60,7 +67,12 @@ class Variable extends Expression { /// database engine. For instance, a [DateTime] will me mapped to its unix /// timestamp. dynamic mapToSimpleValue(GenerationContext context) { - return context.typeMapping.mapToSqlVariable(value); + final type = _customType; + if (value != null && type != null) { + return type.mapToSqlParameter(value!); + } else { + return context.typeMapping.mapToSqlVariable(value); + } } @override @@ -110,22 +122,32 @@ class Variable extends Expression { /// An expression that represents the value of a dart object encoded to sql /// by writing them into the sql statements. For most cases, consider using /// [Variable] instead. -class Constant extends Expression { +final class Constant extends Expression { + /// The value that will be converted to an sql literal. + final T? value; + + final CustomSqlType? _customType; + /// Constructs a new constant (sql literal) holding the [value]. - const Constant(this.value); + const Constant(this.value, [this._customType]); @override Precedence get precedence => Precedence.primary; - /// The value that will be converted to an sql literal. - final T? value; + @override + BaseSqlType get driftSqlType => _customType ?? super.driftSqlType; @override bool get isLiteral => true; @override void writeInto(GenerationContext context) { - context.buffer.write(context.typeMapping.mapToSqlLiteral(value)); + final type = _customType; + if (value != null && type != null) { + context.buffer.write(type.mapToSqlLiteral(value!)); + } else { + context.buffer.write(context.typeMapping.mapToSqlLiteral(value)); + } } @override diff --git a/drift/lib/src/runtime/query_builder/schema/column_impl.dart b/drift/lib/src/runtime/query_builder/schema/column_impl.dart index e9c15bc8..852ec721 100644 --- a/drift/lib/src/runtime/query_builder/schema/column_impl.dart +++ b/drift/lib/src/runtime/query_builder/schema/column_impl.dart @@ -54,7 +54,7 @@ class GeneratedColumn extends Column { final VerificationResult Function(T?, VerificationMeta)? additionalChecks; /// The sql type to use for this column. - final DriftSqlType type; + final BaseSqlType type; /// If this column is generated (that is, it is a SQL expression of other) /// columns, contains information about how to generate this column. @@ -70,6 +70,9 @@ class GeneratedColumn extends Column { @override String get name => $name; + @override + BaseSqlType get driftSqlType => type; + /// Used by generated code. GeneratedColumn( this.$name, @@ -294,7 +297,7 @@ class GeneratedColumnWithTypeConverter String tableName, bool nullable, S? Function()? clientDefault, - DriftSqlType type, + BaseSqlType type, void Function(GenerationContext)? defaultConstraints, String? customConstraints, Expression? defaultValue, diff --git a/drift/lib/src/runtime/query_builder/statements/select/custom_select.dart b/drift/lib/src/runtime/query_builder/statements/select/custom_select.dart index 86cd8afa..a992b315 100644 --- a/drift/lib/src/runtime/query_builder/statements/select/custom_select.dart +++ b/drift/lib/src/runtime/query_builder/statements/select/custom_select.dart @@ -81,7 +81,17 @@ class QueryRow { /// support non-nullable types. T read(String key) { final type = DriftSqlType.forNullableType(); - return _db.typeMapping.read(type, data[key]) as T; + return readNullableWithType(type, key) as T; + } + + /// Interprets the column named [key] under the known drift type [type]. + /// + /// Like [read], except that the [type] is fixed and not inferred from the + /// type parameter [T]. Also, this method does not support nullable values - + /// use [readNullableWithType] if needed. + @optionalTypeArgs + T readWithType(BaseSqlType type, String key) { + return _db.typeMapping.read(type, data[key])!; } /// Reads a nullable value from this row. @@ -90,7 +100,7 @@ class QueryRow { /// drift (e.g. booleans, strings, numbers, dates, `Uint8List`s). T? readNullable(String key) { final type = DriftSqlType.forType(); - return _db.typeMapping.read(type, data[key]); + return readNullableWithType(type, key); } /// Interprets the column named [key] under the known drift type [type]. @@ -98,7 +108,7 @@ class QueryRow { /// Like [readNullable], except that the [type] is fixed and not inferred from /// the type parameter [T]. @optionalTypeArgs - T? readNullableWithType(DriftSqlType type, String key) { + T? readNullableWithType(BaseSqlType type, String key) { return _db.typeMapping.read(type, data[key]); } diff --git a/drift/lib/src/runtime/types/mapping.dart b/drift/lib/src/runtime/types/mapping.dart index 50fe7613..f3c77a0b 100644 --- a/drift/lib/src/runtime/types/mapping.dart +++ b/drift/lib/src/runtime/types/mapping.dart @@ -10,8 +10,7 @@ import '../query_builder/query_builder.dart'; /// Database-specific helper methods mapping Dart values from and to SQL /// variables or literals. -@sealed -class SqlTypes { +final class SqlTypes { // Stolen from DateTime._parseFormat static final RegExp _timeZoneInDateTime = RegExp(r' ?([-+])(\d\d)(?::?(\d\d))?$'); @@ -127,76 +126,77 @@ class SqlTypes { 'Must be null, bool, String, int, DateTime, Uint8List or double'); } - /// Maps a raw [sqlValue] to Dart given its sql [type]. - T? read(DriftSqlType type, Object? sqlValue) { + DateTime _readDateTime(Object sqlValue) { + if (storeDateTimesAsText) { + final rawValue = read(DriftSqlType.string, sqlValue)!; + DateTime result; + + // We store date times like this: + // + // - if it's in UTC, we call [DateTime.toIso8601String], so there's a + // trailing `Z`. We can just use [DateTime.parse] and get an utc + // datetime back. + // - for local date times, we append the time zone offset, e.g. + // `+02:00`. [DateTime.parse] respects this UTC offset and returns + // the correct date, but it returns it in UTC. Since we only use + // this format for local times, we need to transform it back to + // local. + // + // Additionally, complex date time expressions are wrapped in a + // `datetime` sqlite call, which doesn't append a `Z` or a time zone + // offset. As sqlite3 always uses UTC for these computations + // internally, we'll return a UTC datetime as well. + if (_timeZoneInDateTime.hasMatch(rawValue)) { + // Case 2: Explicit time zone offset given, we do this for local + // dates. + result = DateTime.parse(rawValue).toLocal(); + } else if (rawValue.endsWith('Z')) { + // Case 1: Date time in UTC, [DateTime.parse] will do the right + // thing. + result = DateTime.parse(rawValue); + } else { + // Result from complex date tmie transformation. Interpret as UTC, + // which is what sqlite3 does by default. + result = DateTime.parse('${rawValue}Z'); + } + + return result; + } else { + final unixSeconds = read(DriftSqlType.int, sqlValue)!; + return DateTime.fromMillisecondsSinceEpoch(unixSeconds * 1000); + } + } + + /// Maps a raw [sqlValue] to Dart given its sql [type] (typically a + /// [DriftSqlType]). + T? read(BaseSqlType type, Object? sqlValue) { if (sqlValue == null) return null; - // ignore: unnecessary_cast - switch (type as DriftSqlType) { - case DriftSqlType.bool: - return (sqlValue != 0 && sqlValue != false) as T; - case DriftSqlType.string: - return sqlValue.toString() as T; - case DriftSqlType.bigInt: - if (sqlValue is BigInt) return sqlValue as T?; - if (sqlValue is int) return BigInt.from(sqlValue) as T; - return BigInt.parse(sqlValue.toString()) as T; - case DriftSqlType.int: - if (sqlValue is int) return sqlValue as T; - if (sqlValue is BigInt) return sqlValue.toInt() as T; - return int.parse(sqlValue.toString()) as T; - case DriftSqlType.dateTime: - if (storeDateTimesAsText) { - final rawValue = read(DriftSqlType.string, sqlValue)!; - DateTime result; - - // We store date times like this: - // - // - if it's in UTC, we call [DateTime.toIso8601String], so there's a - // trailing `Z`. We can just use [DateTime.parse] and get an utc - // datetime back. - // - for local date times, we append the time zone offset, e.g. - // `+02:00`. [DateTime.parse] respects this UTC offset and returns - // the correct date, but it returns it in UTC. Since we only use - // this format for local times, we need to transform it back to - // local. - // - // Additionally, complex date time expressions are wrapped in a - // `datetime` sqlite call, which doesn't append a `Z` or a time zone - // offset. As sqlite3 always uses UTC for these computations - // internally, we'll return a UTC datetime as well. - if (_timeZoneInDateTime.hasMatch(rawValue)) { - // Case 2: Explicit time zone offset given, we do this for local - // dates. - result = DateTime.parse(rawValue).toLocal(); - } else if (rawValue.endsWith('Z')) { - // Case 1: Date time in UTC, [DateTime.parse] will do the right - // thing. - result = DateTime.parse(rawValue); - } else { - // Result from complex date tmie transformation. Interpret as UTC, - // which is what sqlite3 does by default. - result = DateTime.parse('${rawValue}Z'); - } - - return result as T; - } else { - final unixSeconds = read(DriftSqlType.int, sqlValue)!; - return DateTime.fromMillisecondsSinceEpoch(unixSeconds * 1000) as T; - } - case DriftSqlType.blob: - if (sqlValue is String) { - final list = sqlValue.codeUnits; - return Uint8List.fromList(list) as T; - } - return sqlValue as T; - case DriftSqlType.double: - if (sqlValue case final BigInt bi) return bi.toDouble() as T; - - return (sqlValue as num?)?.toDouble() as T; - case DriftSqlType.any: - return DriftAny(sqlValue) as T; - } + return switch (type) { + DriftSqlType.bool => (sqlValue != 0 && sqlValue != false), + DriftSqlType.string => sqlValue.toString(), + DriftSqlType.bigInt => switch (sqlValue) { + BigInt() => sqlValue, + int() => BigInt.from(sqlValue), + _ => BigInt.parse(sqlValue.toString()), + }, + DriftSqlType.int => switch (sqlValue) { + int() => sqlValue, + BigInt() => sqlValue.toInt(), + _ => int.parse(sqlValue.toString()), + }, + DriftSqlType.dateTime => _readDateTime(sqlValue), + DriftSqlType.blob => switch (sqlValue) { + String() => Uint8List.fromList(sqlValue.codeUnits), + _ => sqlValue, + }, + DriftSqlType.double => switch (sqlValue) { + BigInt() => sqlValue.toDouble(), + _ => (sqlValue as num).toDouble(), + }, + DriftSqlType.any => DriftAny(sqlValue), + CustomSqlType() => type.read(sqlValue), + } as T; } } @@ -212,8 +212,7 @@ class SqlTypes { /// column with an `ANY` type. /// /// [STRICT tables]: https://www.sqlite.org/stricttables.html -@sealed -class DriftAny { +final class DriftAny { /// The direct, unmodified SQL value being wrapped by this [DriftAny] /// instance. /// @@ -249,7 +248,7 @@ class DriftAny { /// [DatabaseConnectionUser.typeMapping]. /// /// [as text]: https://drift.simonbinder.eu/docs/getting-started/advanced_dart_tables/#datetime-options - T readAs(DriftSqlType type, SqlTypes types) { + T readAs(BaseSqlType type, SqlTypes types) { return types.read(type, rawSqlValue)!; } @@ -268,19 +267,16 @@ class DriftAny { } } -/// In [DriftSqlType.forNullableType], we need to do an `is` check over -/// `DriftSqlType` with a potentially nullable `T`. Since `DriftSqlType` is -/// defined with a non-nullable `T`, this is illegal. -/// The non-nullable upper bound in [DriftSqlType] is generally useful, for -/// instance because it works well with [SqlTypes.read] which can then have a -/// sound nullable return type. -/// -/// As a hack, we define this base class that doesn't have this restriction and -/// use this one for type checks. -abstract class _InternalDriftSqlType {} +/// The superclass for SQL types, whether built-in to drift ([DriftSqlType]) or +/// provided by the user through [CustomSqlType]s. +@internal +sealed class BaseSqlType { + /// Returns a suitable representation of this type in SQL. + String sqlTypeName(GenerationContext context); +} /// An enumation of type mappings that are builtin to drift and `drift_dev`. -enum DriftSqlType implements _InternalDriftSqlType { +enum DriftSqlType implements BaseSqlType { /// A boolean type, represented as `0` or `1` (int) in SQL. bool(), @@ -316,7 +312,7 @@ enum DriftSqlType implements _InternalDriftSqlType { /// [STRICT tables]: https://www.sqlite.org/stricttables.html any(); - /// Returns a suitable representation of this type in SQL. + @override String sqlTypeName(GenerationContext context) { final dialect = context.dialect; @@ -390,7 +386,7 @@ enum DriftSqlType implements _InternalDriftSqlType { // typecheck where that doesn't work (which can be the case for complex // type like `forNullableType>`). final type = _dartToDrift[Dart] ?? - values.whereType<_InternalDriftSqlType>().singleOrNull; + values.whereType>().singleOrNull; if (type == null) { throw ArgumentError('Could not find a matching SQL type for $Dart'); @@ -399,3 +395,31 @@ enum DriftSqlType implements _InternalDriftSqlType { return type as DriftSqlType; } } + +/// Interface for a custom SQL type. +/// +/// Being designed with sqlite3 as its primary database engine, drift lacks +/// builtin support for the rich type system found in more complex database +/// systems like postgres. By providing the [CustomSqlType] interface, drift can +/// be extended to support any database type by customizing the way these types +/// are mapped from and to the database. +/// +/// To create a custom type, implement this interface. You can now create values +/// of this type by passing it to [Constant] or [Variable], [Expression.cast] +/// and other methods operating on types. +/// Custom types can also be applied to table columns, see https://drift.simonbinder.eu/docs/sql-api/types/ +/// for details. +abstract interface class CustomSqlType + implements BaseSqlType { + /// Interprets the underlying [fromSql] value from the database driver into + /// the Dart representation [T] of this type. + T read(Object fromSql); + + /// Maps the [dartValue] to a value understood by the underlying database + /// driver. + Object mapToSqlParameter(T dartValue); + + /// Maps the [dartValue] to a SQL snippet that can be embedded as a literal + /// into SQL queries generated by drift. + String mapToSqlLiteral(T dartValue); +} diff --git a/drift/test/database/statements/schema_test.dart b/drift/test/database/statements/schema_test.dart index da7397fe..bdb0606a 100644 --- a/drift/test/database/statements/schema_test.dart +++ b/drift/test/database/statements/schema_test.dart @@ -109,6 +109,14 @@ void main() { [])); }); + test('creates tables with custom types', () async { + await db.createMigrator().createTable(db.withCustomType); + + verify(mockExecutor.runCustom( + 'CREATE TABLE IF NOT EXISTS "with_custom_type" ("id" uuid NOT NULL);', + [])); + }); + test('creates views through create()', () async { await db.createMigrator().create(db.categoryTodoCountView); diff --git a/drift/test/database/types/custom_type_test.dart b/drift/test/database/types/custom_type_test.dart new file mode 100644 index 00000000..173e13b8 --- /dev/null +++ b/drift/test/database/types/custom_type_test.dart @@ -0,0 +1,64 @@ +import 'package:drift/drift.dart'; +import 'package:mockito/mockito.dart'; +import 'package:test/test.dart'; +import 'package:uuid/uuid.dart'; + +import '../../generated/todos.dart'; +import '../../test_utils/test_utils.dart'; + +void main() { + final uuid = Uuid().v4obj(); + + group('in expression', () { + test('variable', () { + final c = Variable(uuid, const UuidType()); + + expect(c.driftSqlType, isA()); + expect(c, generates('?', [uuid])); + }); + + test('constant', () { + final c = Constant(uuid, const UuidType()); + + expect(c.driftSqlType, isA()); + expect(c, generates("'$uuid'")); + }); + + test('cast', () { + final cast = Variable('foo').cast(const UuidType()); + + expect(cast.driftSqlType, isA()); + expect(cast, generates('CAST(? AS uuid)', ['foo'])); + }); + }); + + test('for inserts', () async { + final executor = MockExecutor(); + final database = TodoDb(executor); + addTearDown(database.close); + + final uuid = Uuid().v4obj(); + await database + .into(database.withCustomType) + .insert(WithCustomTypeCompanion.insert(id: uuid)); + + verify(executor + .runInsert('INSERT INTO "with_custom_type" ("id") VALUES (?)', [uuid])); + }); + + test('for selects', () async { + final executor = MockExecutor(); + final database = TodoDb(executor); + addTearDown(database.close); + + final uuid = Uuid().v4obj(); + when(executor.runSelect(any, any)).thenAnswer((_) { + return Future.value([ + {'id': uuid} + ]); + }); + + final row = await database.withCustomType.all().getSingle(); + expect(row.id, uuid); + }); +} diff --git a/drift/test/generated/custom_tables.g.dart b/drift/test/generated/custom_tables.g.dart index 7f02afa8..a6018c24 100644 --- a/drift/test/generated/custom_tables.g.dart +++ b/drift/test/generated/custom_tables.g.dart @@ -797,10 +797,12 @@ class ConfigCompanion extends UpdateCompanion { } if (syncState.present) { final converter = ConfigTable.$convertersyncStaten; + map['sync_state'] = Variable(converter.toSql(syncState.value)); } if (syncStateImplicit.present) { final converter = ConfigTable.$convertersyncStateImplicitn; + map['sync_state_implicit'] = Variable(converter.toSql(syncStateImplicit.value)); } diff --git a/drift/test/generated/todos.dart b/drift/test/generated/todos.dart index 301c586d..43509b3e 100644 --- a/drift/test/generated/todos.dart +++ b/drift/test/generated/todos.dart @@ -180,6 +180,32 @@ abstract class TodoWithCategoryView extends View { .join([innerJoin(categories, categories.id.equalsExp(todos.category))]); } +class WithCustomType extends Table { + Column get id => customType(const UuidType())(); +} + +class UuidType implements CustomSqlType { + const UuidType(); + + @override + String mapToSqlLiteral(UuidValue dartValue) { + return "'$dartValue'"; + } + + @override + Object mapToSqlParameter(UuidValue dartValue) { + return dartValue; + } + + @override + UuidValue read(Object fromSql) { + return fromSql as UuidValue; + } + + @override + String sqlTypeName(GenerationContext context) => 'uuid'; +} + @DriftDatabase( tables: [ TodosTable, @@ -188,6 +214,7 @@ abstract class TodoWithCategoryView extends View { SharedTodos, TableWithoutPK, PureDefaults, + WithCustomType, ], views: [ CategoryTodoCountView, diff --git a/drift/test/generated/todos.g.dart b/drift/test/generated/todos.g.dart index 6362fc83..50933861 100644 --- a/drift/test/generated/todos.g.dart +++ b/drift/test/generated/todos.g.dart @@ -247,6 +247,7 @@ class CategoriesCompanion extends UpdateCompanion { } if (priority.present) { final converter = $CategoriesTable.$converterpriority; + map['priority'] = Variable(converter.toSql(priority.value)); } return map; @@ -598,6 +599,7 @@ class TodosTableCompanion extends UpdateCompanion { } if (status.present) { final converter = $TodosTableTable.$converterstatusn; + map['status'] = Variable(converter.toSql(status.value)); } return map; @@ -1269,6 +1271,7 @@ class TableWithoutPKCompanion extends UpdateCompanion { } if (custom.present) { final converter = $TableWithoutPKTable.$convertercustom; + map['custom'] = Variable(converter.toSql(custom.value)); } if (rowid.present) { @@ -1455,6 +1458,7 @@ class PureDefaultsCompanion extends UpdateCompanion { final map = {}; if (txt.present) { final converter = $PureDefaultsTable.$convertertxtn; + map['insert'] = Variable(converter.toSql(txt.value)); } if (rowid.present) { @@ -1473,6 +1477,160 @@ class PureDefaultsCompanion extends UpdateCompanion { } } +class $WithCustomTypeTable extends WithCustomType + with TableInfo<$WithCustomTypeTable, WithCustomTypeData> { + @override + final GeneratedDatabase attachedDatabase; + final String? _alias; + $WithCustomTypeTable(this.attachedDatabase, [this._alias]); + static const VerificationMeta _idMeta = const VerificationMeta('id'); + @override + late final GeneratedColumn id = GeneratedColumn( + 'id', aliasedName, false, + type: const UuidType(), requiredDuringInsert: true); + @override + List get $columns => [id]; + @override + String get aliasedName => _alias ?? actualTableName; + @override + String get actualTableName => $name; + static const String $name = 'with_custom_type'; + @override + VerificationContext validateIntegrity(Insertable instance, + {bool isInserting = false}) { + final context = VerificationContext(); + final data = instance.toColumns(true); + if (data.containsKey('id')) { + context.handle(_idMeta, id.isAcceptableOrUnknown(data['id']!, _idMeta)); + } else if (isInserting) { + context.missing(_idMeta); + } + return context; + } + + @override + Set get $primaryKey => const {}; + @override + WithCustomTypeData map(Map data, {String? tablePrefix}) { + final effectivePrefix = tablePrefix != null ? '$tablePrefix.' : ''; + return WithCustomTypeData( + id: attachedDatabase.typeMapping + .read(const UuidType(), data['${effectivePrefix}id'])!, + ); + } + + @override + $WithCustomTypeTable createAlias(String alias) { + return $WithCustomTypeTable(attachedDatabase, alias); + } +} + +class WithCustomTypeData extends DataClass + implements Insertable { + final UuidValue id; + const WithCustomTypeData({required this.id}); + @override + Map toColumns(bool nullToAbsent) { + final map = {}; + map['id'] = Variable(id); + return map; + } + + WithCustomTypeCompanion toCompanion(bool nullToAbsent) { + return WithCustomTypeCompanion( + id: Value(id), + ); + } + + factory WithCustomTypeData.fromJson(Map json, + {ValueSerializer? serializer}) { + serializer ??= driftRuntimeOptions.defaultSerializer; + return WithCustomTypeData( + id: serializer.fromJson(json['id']), + ); + } + factory WithCustomTypeData.fromJsonString(String encodedJson, + {ValueSerializer? serializer}) => + WithCustomTypeData.fromJson( + DataClass.parseJson(encodedJson) as Map, + serializer: serializer); + @override + Map toJson({ValueSerializer? serializer}) { + serializer ??= driftRuntimeOptions.defaultSerializer; + return { + 'id': serializer.toJson(id), + }; + } + + WithCustomTypeData copyWith({UuidValue? id}) => WithCustomTypeData( + id: id ?? this.id, + ); + @override + String toString() { + return (StringBuffer('WithCustomTypeData(') + ..write('id: $id') + ..write(')')) + .toString(); + } + + @override + int get hashCode => id.hashCode; + @override + bool operator ==(Object other) => + identical(this, other) || + (other is WithCustomTypeData && other.id == this.id); +} + +class WithCustomTypeCompanion extends UpdateCompanion { + final Value id; + final Value rowid; + const WithCustomTypeCompanion({ + this.id = const Value.absent(), + this.rowid = const Value.absent(), + }); + WithCustomTypeCompanion.insert({ + required UuidValue id, + this.rowid = const Value.absent(), + }) : id = Value(id); + static Insertable custom({ + Expression? id, + Expression? rowid, + }) { + return RawValuesInsertable({ + if (id != null) 'id': id, + if (rowid != null) 'rowid': rowid, + }); + } + + WithCustomTypeCompanion copyWith({Value? id, Value? rowid}) { + return WithCustomTypeCompanion( + id: id ?? this.id, + rowid: rowid ?? this.rowid, + ); + } + + @override + Map toColumns(bool nullToAbsent) { + final map = {}; + if (id.present) { + map['id'] = Variable(id.value, const UuidType()); + } + if (rowid.present) { + map['rowid'] = Variable(rowid.value); + } + return map; + } + + @override + String toString() { + return (StringBuffer('WithCustomTypeCompanion(') + ..write('id: $id, ') + ..write('rowid: $rowid') + ..write(')')) + .toString(); + } +} + class CategoryTodoCountViewData extends DataClass { final int? categoryId; final String? description; @@ -1703,6 +1861,7 @@ abstract class _$TodoDb extends GeneratedDatabase { late final $SharedTodosTable sharedTodos = $SharedTodosTable(this); late final $TableWithoutPKTable tableWithoutPK = $TableWithoutPKTable(this); late final $PureDefaultsTable pureDefaults = $PureDefaultsTable(this); + late final $WithCustomTypeTable withCustomType = $WithCustomTypeTable(this); late final $CategoryTodoCountViewView categoryTodoCountView = $CategoryTodoCountViewView(this); late final $TodoWithCategoryViewView todoWithCategoryView = @@ -1787,6 +1946,7 @@ abstract class _$TodoDb extends GeneratedDatabase { sharedTodos, tableWithoutPK, pureDefaults, + withCustomType, categoryTodoCountView, todoWithCategoryView ]; diff --git a/drift_dev/lib/src/analysis/resolver/dart/column.dart b/drift_dev/lib/src/analysis/resolver/dart/column.dart index 95c0a762..ef9b9e4a 100644 --- a/drift_dev/lib/src/analysis/resolver/dart/column.dart +++ b/drift_dev/lib/src/analysis/resolver/dart/column.dart @@ -22,6 +22,7 @@ const String _startBool = 'boolean'; const String _startDateTime = 'dateTime'; const String _startBlob = 'blob'; const String _startReal = 'real'; +const String _startCustom = 'customType'; const Set _starters = { _startInt, @@ -33,6 +34,7 @@ const Set _starters = { _startDateTime, _startBlob, _startReal, + _startCustom, }; const String _methodNamed = 'named'; @@ -337,19 +339,40 @@ class ColumnParser { remainingExpr = inner; } - _resolver.resolver.driver.options.caseFromDartToSql; final sqlName = foundExplicitName ?? _resolver.resolver.driver.options.caseFromDartToSql .apply(getter.name.lexeme); - final sqlType = _startMethodToColumnType(foundStartMethod); + ColumnType columnType; + final helper = await _resolver.resolver.driver.loadKnownTypes(); + if (foundStartMethod == _startCustom) { + final expression = remainingExpr.argumentList.arguments.single; + + final custom = readCustomType( + element.library!, + expression, + helper, + (message) => _resolver.reportError( + DriftAnalysisError.inDartAst(element, mappedAs!, message), + ), + ); + columnType = custom != null + ? ColumnType.custom(custom) + // Fallback if we fail to read the custom type - we'll also emit an + // error int that case. + : ColumnType.drift(DriftSqlType.any); + } else { + columnType = + ColumnType.drift(_startMethodToBuiltinColumnType(foundStartMethod)); + } + AppliedTypeConverter? converter; if (mappedAs != null) { converter = readTypeConverter( element.library!, mappedAs, - sqlType, + columnType, nullable, (message) => _resolver.reportError( DriftAnalysisError.inDartAst(element, mappedAs!, message)), @@ -437,7 +460,7 @@ class ColumnParser { return PendingColumnInformation( DriftColumn( - sqlType: sqlType, + sqlType: columnType, nullable: nullable, nameInSql: sqlName, nameInDart: element.name!, @@ -454,7 +477,7 @@ class ColumnParser { ); } - DriftSqlType _startMethodToColumnType(String name) { + DriftSqlType _startMethodToBuiltinColumnType(String name) { return const { _startBool: DriftSqlType.bool, _startString: DriftSqlType.string, diff --git a/drift_dev/lib/src/analysis/resolver/dart/helper.dart b/drift_dev/lib/src/analysis/resolver/dart/helper.dart index 058ce08c..fb9338f5 100644 --- a/drift_dev/lib/src/analysis/resolver/dart/helper.dart +++ b/drift_dev/lib/src/analysis/resolver/dart/helper.dart @@ -27,6 +27,7 @@ class KnownDriftTypes { final InterfaceType tableInfoType; final InterfaceType driftDatabase; final InterfaceType driftAccessor; + final InterfaceElement customSqlType; final InterfaceElement typeConverter; final InterfaceElement jsonTypeConverter; final InterfaceType driftAny; @@ -39,6 +40,7 @@ class KnownDriftTypes { this.tableIndexType, this.viewType, this.tableInfoType, + this.customSqlType, this.typeConverter, this.jsonTypeConverter, this.driftDatabase, @@ -62,6 +64,7 @@ class KnownDriftTypes { (exportNamespace.get('TableIndex') as InterfaceElement).thisType, (exportNamespace.get('View') as InterfaceElement).thisType, (exportNamespace.get('TableInfo') as InterfaceElement).thisType, + exportNamespace.get('CustomSqlType') as InterfaceElement, exportNamespace.get('TypeConverter') as InterfaceElement, exportNamespace.get('JsonTypeConverter2') as InterfaceElement, dbElement.defaultInstantiation, @@ -81,6 +84,10 @@ class KnownDriftTypes { return type.asInstanceOf(typeConverter); } + InterfaceType? asCustomType(DartType type) { + return type.asInstanceOf(customSqlType); + } + /// Converts the given Dart [type] into an instantiation of the /// `JsonTypeConverter` class from drift. /// diff --git a/drift_dev/lib/src/analysis/resolver/dart/view.dart b/drift_dev/lib/src/analysis/resolver/dart/view.dart index 41146767..bea648d6 100644 --- a/drift_dev/lib/src/analysis/resolver/dart/view.dart +++ b/drift_dev/lib/src/analysis/resolver/dart/view.dart @@ -275,7 +275,7 @@ class DartViewResolver extends LocalElementResolver { columns.add(DriftColumn( declaration: DriftDeclaration.dartElement(getter), - sqlType: sqlType, + sqlType: ColumnType.drift(sqlType), nameInDart: getter.name, nameInSql: ReCase(getter.name).snakeCase, nullable: true, diff --git a/drift_dev/lib/src/analysis/resolver/drift/element_resolver.dart b/drift_dev/lib/src/analysis/resolver/drift/element_resolver.dart index 103d2089..010356b3 100644 --- a/drift_dev/lib/src/analysis/resolver/drift/element_resolver.dart +++ b/drift_dev/lib/src/analysis/resolver/drift/element_resolver.dart @@ -3,7 +3,6 @@ import 'package:analyzer/dart/element/nullability_suffix.dart'; import 'package:collection/collection.dart'; import 'package:analyzer/dart/element/element.dart'; import 'package:analyzer/dart/element/type.dart'; -import 'package:drift/drift.dart'; import 'package:sqlparser/sqlparser.dart'; import 'package:sqlparser/utils/find_referenced_tables.dart'; @@ -22,8 +21,33 @@ abstract class DriftElementResolver DriftElementResolver( super.file, super.discovered, super.resolver, super.state); + Future resolveCustomColumnType( + InlineDartToken type) async { + dart.Expression expression; + try { + expression = await resolver.driver.backend.resolveExpression( + file.ownUri, + type.dartCode, + file.discovery!.importDependencies + .map((e) => e.uri.toString()) + .where((e) => e.endsWith('.dart')), + ); + } on CannotReadExpressionException catch (e) { + reportError(DriftAnalysisError.inDriftFile(type, e.msg)); + return null; + } + + final knownTypes = await resolver.driver.loadKnownTypes(); + return readCustomType( + knownTypes.helperLibrary, + expression, + knownTypes, + (msg) => reportError(DriftAnalysisError.inDriftFile(type, msg)), + ); + } + Future typeConverterFromMappedBy( - DriftSqlType sqlType, bool nullable, MappedBy mapper) async { + ColumnType sqlType, bool nullable, MappedBy mapper) async { final code = mapper.mapper.dartCode; dart.Expression expression; diff --git a/drift_dev/lib/src/analysis/resolver/drift/sqlparser/mapping.dart b/drift_dev/lib/src/analysis/resolver/drift/sqlparser/mapping.dart index 3e87a008..50a1d651 100644 --- a/drift_dev/lib/src/analysis/resolver/drift/sqlparser/mapping.dart +++ b/drift_dev/lib/src/analysis/resolver/drift/sqlparser/mapping.dart @@ -73,14 +73,17 @@ class TypeMapping { } ResolvedType _columnType(DriftColumn column) { - final type = - _driftTypeToParser(column.sqlType).withNullable(column.nullable); + var type = _driftTypeToParser(column.sqlType.builtin) + .withNullable(column.nullable); - if (column.typeConverter case final AppliedTypeConverter c) { - return type.addHint(TypeConverterHint(c)); - } else { - return type; + if (column.sqlType.isCustom) { + type = type.addHint(CustomTypeHint(column.sqlType.custom!)); } + if (column.typeConverter case AppliedTypeConverter c) { + type = type.addHint(TypeConverterHint(c)); + } + + return type; } ResolvedType _driftTypeToParser(DriftSqlType type) { @@ -103,11 +106,7 @@ class TypeMapping { }; } - DriftSqlType sqlTypeToDrift(ResolvedType? type) { - if (type == null) { - return DriftSqlType.string; - } - + DriftSqlType _toDefaultType(ResolvedType type) { switch (type.type) { case null: case BasicType.nullType: @@ -137,6 +136,19 @@ class TypeMapping { return DriftSqlType.any; } } + + ColumnType sqlTypeToDrift(ResolvedType? type) { + if (type == null) { + return const ColumnType.drift(DriftSqlType.string); + } + + final customHint = type.hint(); + if (customHint != null) { + return ColumnType.custom(customHint.type); + } + + return ColumnType.drift(_toDefaultType(type)); + } } /// Creates a [TypeFromText] implementation that will look up type converters @@ -176,6 +188,12 @@ class TypeConverterHint extends TypeHint { TypeConverterHint(this.converter); } +class CustomTypeHint extends TypeHint { + final CustomColumnType type; + + CustomTypeHint(this.type); +} + class _SimpleColumn extends Column implements ColumnWithType { @override final String name; diff --git a/drift_dev/lib/src/analysis/resolver/drift/table.dart b/drift_dev/lib/src/analysis/resolver/drift/table.dart index 8c41816e..e6b7c16a 100644 --- a/drift_dev/lib/src/analysis/resolver/drift/table.dart +++ b/drift_dev/lib/src/analysis/resolver/drift/table.dart @@ -43,31 +43,43 @@ class DriftTableResolver extends DriftElementResolver { for (final column in table.resultColumns) { String? overriddenDartName; - final type = resolver.driver.typeMapping.sqlTypeToDrift(column.type); + var type = resolver.driver.typeMapping.sqlTypeToDrift(column.type); final nullable = column.type.nullable != false; final constraints = []; AppliedTypeConverter? converter; AnnotatedDartCode? defaultArgument; String? overriddenJsonName; - final typeName = column.definition?.typeName; + final definition = column.definition; + if (definition != null) { + final typeName = definition.typeName; - final enumIndexMatch = typeName != null - ? FoundReferencesInSql.enumRegex.firstMatch(typeName) - : null; - if (enumIndexMatch != null) { - final dartTypeName = enumIndexMatch.group(2)!; - final dartType = await findDartTypeOrReportError( - dartTypeName, column.definition?.typeNames?.toSingleEntity ?? stmt); + final enumIndexMatch = typeName != null + ? FoundReferencesInSql.enumRegex.firstMatch(typeName) + : null; - if (dartType != null) { - converter = readEnumConverter( - (msg) => reportError( - DriftAnalysisError.inDriftFile(column.definition ?? stmt, msg)), - dartType, - type == DriftSqlType.int ? EnumType.intEnum : EnumType.textEnum, - await resolver.driver.loadKnownTypes(), - ); + if (definition.typeNames case [InlineDartToken token]) { + // An inline Dart token used as a type name indicates a custom type. + final custom = await resolveCustomColumnType(token); + if (custom != null) { + type = ColumnType.custom(custom); + } + } else if (enumIndexMatch != null) { + final dartTypeName = enumIndexMatch.group(2)!; + final dartType = await findDartTypeOrReportError(dartTypeName, + column.definition?.typeNames?.toSingleEntity ?? stmt); + + if (dartType != null) { + converter = readEnumConverter( + (msg) => reportError(DriftAnalysisError.inDriftFile( + column.definition ?? stmt, msg)), + dartType, + type.builtin == DriftSqlType.int + ? EnumType.intEnum + : EnumType.textEnum, + await resolver.driver.loadKnownTypes(), + ); + } } } diff --git a/drift_dev/lib/src/analysis/resolver/queries/query_analyzer.dart b/drift_dev/lib/src/analysis/resolver/queries/query_analyzer.dart index 5c11a3dc..2937bd14 100644 --- a/drift_dev/lib/src/analysis/resolver/queries/query_analyzer.dart +++ b/drift_dev/lib/src/analysis/resolver/queries/query_analyzer.dart @@ -1,5 +1,4 @@ import 'package:analyzer/dart/ast/ast.dart' as dart; -import 'package:drift/drift.dart' show DriftSqlType; import 'package:drift/drift.dart' as drift; import 'package:recase/recase.dart'; import 'package:sqlparser/sqlparser.dart' hide ResultColumn; @@ -715,7 +714,7 @@ class QueryAnalyzer { final type = placeholder.when( isExpression: (e) { final foundType = context.typeOf(e); - DriftSqlType? columnType; + ColumnType? columnType; if (foundType.type != null) { columnType = driver.typeMapping.sqlTypeToDrift(foundType.type); } diff --git a/drift_dev/lib/src/analysis/resolver/shared/dart_types.dart b/drift_dev/lib/src/analysis/resolver/shared/dart_types.dart index 659a6237..2d29f572 100644 --- a/drift_dev/lib/src/analysis/resolver/shared/dart_types.dart +++ b/drift_dev/lib/src/analysis/resolver/shared/dart_types.dart @@ -267,10 +267,30 @@ enum EnumType { textEnum, } +CustomColumnType? readCustomType( + LibraryElement library, + Expression dartExpression, + KnownDriftTypes helper, + void Function(String) reportError, +) { + final staticType = dartExpression.staticType; + final asCustomType = + staticType != null ? helper.asCustomType(staticType) : null; + + if (asCustomType == null) { + reportError('Not a custom type'); + return null; + } + + final dartType = asCustomType.typeArguments[0]; + + return CustomColumnType(AnnotatedDartCode.ast(dartExpression), dartType); +} + AppliedTypeConverter? readTypeConverter( LibraryElement library, Expression dartExpression, - DriftSqlType columnType, + ColumnType columnType, bool columnIsNullable, void Function(String) reportError, KnownDriftTypes helper, @@ -369,9 +389,9 @@ AppliedTypeConverter readEnumConverter( jsonType: columnEnumType == EnumType.intEnum ? typeProvider.intType : typeProvider.stringType, - sqlType: columnEnumType == EnumType.intEnum + sqlType: ColumnType.drift(columnEnumType == EnumType.intEnum ? DriftSqlType.int - : DriftSqlType.string, + : DriftSqlType.string), dartTypeIsNullable: false, sqlTypeIsNullable: false, isDriftEnumTypeConverter: true, @@ -415,7 +435,7 @@ void _checkParameterType( } bool checkType( - DriftSqlType columnType, + ColumnType columnType, bool columnIsNullable, AppliedTypeConverter? typeConverter, DartType typeToCheck, @@ -467,8 +487,12 @@ DartType regularColumnType( } extension on TypeProvider { - DartType typeFor(DriftSqlType type, KnownDriftTypes knownTypes) { - switch (type) { + DartType typeFor(ColumnType type, KnownDriftTypes knownTypes) { + if (type.custom case CustomColumnType custom) { + return custom.dartType; + } + + switch (type.builtin) { case DriftSqlType.int: return intType; case DriftSqlType.bigInt: diff --git a/drift_dev/lib/src/analysis/results/column.dart b/drift_dev/lib/src/analysis/results/column.dart index 4ec3461f..d34c3ad8 100644 --- a/drift_dev/lib/src/analysis/results/column.dart +++ b/drift_dev/lib/src/analysis/results/column.dart @@ -1,5 +1,4 @@ import 'package:analyzer/dart/element/type.dart'; -import 'package:drift/drift.dart' show DriftSqlType; import 'package:json_annotation/json_annotation.dart'; import 'package:sqlparser/sqlparser.dart' show GeneratedAs, ReferenceAction; import 'package:sqlparser/utils/node_to_text.dart'; @@ -15,7 +14,7 @@ part '../../generated/analysis/results/column.g.dart'; class DriftColumn implements HasType { @override - final DriftSqlType sqlType; + final ColumnType sqlType; @override bool get isArray => false; @@ -127,6 +126,15 @@ class DriftColumn implements HasType { } } +class CustomColumnType { + /// The Dart expression creating an instance of the `CustomSqlType` responsible + /// for the column. + final AnnotatedDartCode expression; + final DartType dartType; + + CustomColumnType(this.expression, this.dartType); +} + class AppliedTypeConverter { /// The Dart expression creating an instance of the applied type converter. final AnnotatedDartCode expression; @@ -136,7 +144,7 @@ class AppliedTypeConverter { /// The JSON type representation of this column, if this type converter /// applies to the JSON serialization as well. final DartType? jsonType; - final DriftSqlType sqlType; + final ColumnType sqlType; late DriftColumn? owningColumn; diff --git a/drift_dev/lib/src/analysis/results/dart.dart b/drift_dev/lib/src/analysis/results/dart.dart index 3acc2103..561189c6 100644 --- a/drift_dev/lib/src/analysis/results/dart.dart +++ b/drift_dev/lib/src/analysis/results/dart.dart @@ -128,13 +128,18 @@ class AnnotatedDartCodeBuilder { void addDriftType(HasType hasType) { void addNonListType() { final converter = hasType.typeConverter; + final customType = hasType.sqlType.custom; + if (converter != null) { final nullable = converter.canBeSkippedForNulls && hasType.nullable; addDartType(converter.dartType); if (nullable) addText('?'); + } else if (customType != null) { + addDartType(customType.dartType); + if (hasType.nullable) addText('?'); } else { - addTopLevel(dartTypeNames[hasType.sqlType]!); + addTopLevel(dartTypeNames[hasType.sqlType.builtin]!); if (hasType.nullable) addText('?'); } } diff --git a/drift_dev/lib/src/analysis/results/query.dart b/drift_dev/lib/src/analysis/results/query.dart index c027c16b..014c18ba 100644 --- a/drift_dev/lib/src/analysis/results/query.dart +++ b/drift_dev/lib/src/analysis/results/query.dart @@ -722,7 +722,7 @@ final class ScalarResultColumn extends ResultColumn implements HasType, ArgumentForQueryRowType { final String name; @override - final DriftSqlType sqlType; + final ColumnType sqlType; @override final bool nullable; @@ -746,17 +746,22 @@ final class ScalarResultColumn extends ResultColumn return dartNameForSqlColumn(name, existingNames: existingNames); } + int get _columnTypeCompatibilityHash { + return Object.hash(sqlType.builtin, sqlType.custom?.dartType); + } + @override int get compatibilityHashCode { - return Object.hash( - ScalarResultColumn, name, sqlType, nullable, typeConverter); + return Object.hash(ScalarResultColumn, name, _columnTypeCompatibilityHash, + nullable, typeConverter); } @override bool isCompatibleTo(ResultColumn other) { return other is ScalarResultColumn && other.name == name && - other.sqlType == sqlType && + other.sqlType.builtin == sqlType.builtin && + other.sqlType.custom?.dartType == sqlType.custom?.dartType && other.nullable == nullable && other.typeConverter == typeConverter; } @@ -924,7 +929,7 @@ class FoundVariable extends FoundElement implements HasType { /// The (inferred) type for this variable. @override - final DriftSqlType sqlType; + final ColumnType sqlType; /// The type converter to apply before writing this value. @override @@ -1014,7 +1019,7 @@ class SimpleDartPlaceholderType extends DartPlaceholderType { class ExpressionDartPlaceholderType extends DartPlaceholderType { /// The sql type of this expression. - final DriftSqlType? columnType; + final ColumnType? columnType; final Expression? defaultValue; ExpressionDartPlaceholderType(this.columnType, this.defaultValue); diff --git a/drift_dev/lib/src/analysis/results/table.dart b/drift_dev/lib/src/analysis/results/table.dart index f0c11604..c9385e09 100644 --- a/drift_dev/lib/src/analysis/results/table.dart +++ b/drift_dev/lib/src/analysis/results/table.dart @@ -7,6 +7,7 @@ import 'element.dart'; import 'column.dart'; import 'result_sets.dart'; +import 'types.dart'; class DriftTable extends DriftElementWithResultSet { @override @@ -82,7 +83,7 @@ class DriftTable extends DriftElementWithResultSet { this.attachedIndices = const [], }) { _rowIdColumn = DriftColumn( - sqlType: DriftSqlType.int, + sqlType: ColumnType.drift(DriftSqlType.int), nullable: false, nameInSql: 'rowid', nameInDart: 'rowid', @@ -126,8 +127,9 @@ class DriftTable extends DriftElementWithResultSet { final primaryKey = fullPrimaryKey; if (primaryKey.length == 1) { final column = primaryKey.single; - if (column.sqlType == DriftSqlType.int || - column.sqlType == DriftSqlType.bigInt) { + final builtinType = column.sqlType.builtin; + if (builtinType == DriftSqlType.int || + builtinType == DriftSqlType.bigInt) { // So this column is an alias for the rowid return column; } diff --git a/drift_dev/lib/src/analysis/results/types.dart b/drift_dev/lib/src/analysis/results/types.dart index 41eff1d6..54dd42c8 100644 --- a/drift_dev/lib/src/analysis/results/types.dart +++ b/drift_dev/lib/src/analysis/results/types.dart @@ -17,16 +17,47 @@ abstract class HasType { bool get isArray; /// The associated sql type. - DriftSqlType get sqlType; + ColumnType get sqlType; /// The applied type converter, or null if no type converter has been applied /// to this column. AppliedTypeConverter? get typeConverter; } +/// The underlying SQL type of a column analyzed by drift. +/// +/// We distinguish between types directly supported by drift, and types that +/// are supplied by another library. Custom types can hold different Dart types, +/// but are a feature distinct from type converters: They indicate that a type +/// is directly supported by the underlying database driver, whereas a type +/// converter is a mapping done in drift. +/// +/// In addition to the SQL type, we also track whether a column is nullable, +/// appears where an array is expected or has a type converter applied to it. +/// [HasType] is the interface for sql-typed elements and is implemented by +/// columns. +class ColumnType { + /// The builtin drift type used by this column. + /// + /// Even though it's unused there, custom types also have this field set - + /// to [DriftSqlType.any] because drift doesn't reinterpret these values at + /// all. + final DriftSqlType builtin; + + /// Details about the custom type, if one is present. + final CustomColumnType? custom; + + bool get isCustom => custom != null; + + const ColumnType.drift(this.builtin) : custom = null; + + ColumnType.custom(CustomColumnType this.custom) : builtin = DriftSqlType.any; +} + extension OperationOnTypes on HasType { - bool get isUint8ListInDart => - sqlType == DriftSqlType.blob && typeConverter == null; + bool get isUint8ListInDart { + return sqlType.builtin == DriftSqlType.blob && typeConverter == null; + } /// Whether this type is nullable in Dart bool get nullableInDart { @@ -52,15 +83,3 @@ Map dartTypeNames = Map.unmodifiable({ DriftSqlType.double: DartTopLevelSymbol('double', Uri.parse('dart:core')), DriftSqlType.any: DartTopLevelSymbol('DriftAny', AnnotatedDartCode.drift), }); - -/// Maps from a column type to code that can be used to create a variable of the -/// respective type. -const Map createVariable = { - DriftSqlType.bool: 'Variable.withBool', - DriftSqlType.string: 'Variable.withString', - DriftSqlType.int: 'Variable.withInt', - DriftSqlType.bigInt: 'Variable.withBigInt', - DriftSqlType.dateTime: 'Variable.withDateTime', - DriftSqlType.blob: 'Variable.withBlob', - DriftSqlType.double: 'Variable.withReal', -}; diff --git a/drift_dev/lib/src/analysis/serializer.dart b/drift_dev/lib/src/analysis/serializer.dart index 46fd2745..14db24a9 100644 --- a/drift_dev/lib/src/analysis/serializer.dart +++ b/drift_dev/lib/src/analysis/serializer.dart @@ -194,9 +194,23 @@ class ElementSerializer { }; } + Map _serializeColumnType(ColumnType type) { + final custom = type.custom; + + return { + if (custom != null) + 'custom': { + 'dart': _serializeType(custom.dartType), + 'expression': custom.expression.toJson(), + } + else + 'builtin': type.builtin.name, + }; + } + Map _serializeColumn(DriftColumn column) { return { - 'sqlType': column.sqlType.name, + 'sqlType': _serializeColumnType(column.sqlType), 'nullable': column.nullable, 'nameInSql': column.nameInSql, 'nameInDart': column.nameInDart, @@ -306,7 +320,7 @@ class ElementSerializer { 'expression': converter.expression.toJson(), 'dart_type': _serializeType(converter.dartType), 'json_type': _serializeType(converter.jsonType), - 'sql_type': converter.sqlType.name, + 'sql_type': _serializeColumnType(converter.sqlType), 'dart_type_is_nullable': converter.dartTypeIsNullable, 'sql_type_is_nullable': converter.sqlTypeIsNullable, 'is_drift_enum_converter': converter.isDriftEnumTypeConverter, @@ -709,11 +723,24 @@ class ElementDeserializer { } } + Future _readColumnType(Map json, Uri definition) async { + if (json.containsKey('custom')) { + return ColumnType.custom(CustomColumnType( + AnnotatedDartCode.fromJson(json['expression'] as Map), + await _readDartType(definition, json['dart'] as int), + )); + } else { + return ColumnType.drift( + DriftSqlType.values.byName(json['builtin'] as String)); + } + } + Future _readColumn(Map json, DriftElementId ownTable) async { final rawConverter = json['typeConverter'] as Map?; return DriftColumn( - sqlType: DriftSqlType.values.byName(json['sqlType'] as String), + sqlType: + await _readColumnType(json['sqlType'] as Map, ownTable.libraryUri), nullable: json['nullable'] as bool, nameInSql: json['nameInSql'] as String, nameInDart: json['nameInDart'] as String, @@ -752,7 +779,7 @@ class ElementDeserializer { jsonType: json['json_type'] != null ? await _readDartType(definition, json['json_type'] as int) : null, - sqlType: DriftSqlType.values.byName(json['sql_type'] as String), + sqlType: await _readColumnType(json['sql_type'] as Map, definition), dartTypeIsNullable: json['dart_type_is_nullable'] as bool, sqlTypeIsNullable: json['sql_type_is_nullable'] as bool, isDriftEnumTypeConverter: json['is_drift_enum_converter'] as bool, diff --git a/drift_dev/lib/src/services/schema/schema_files.dart b/drift_dev/lib/src/services/schema/schema_files.dart index c31b2e5a..3127b834 100644 --- a/drift_dev/lib/src/services/schema/schema_files.dart +++ b/drift_dev/lib/src/services/schema/schema_files.dart @@ -158,7 +158,7 @@ class SchemaWriter { return { 'name': column.nameInSql, 'getter_name': column.nameInDart, - 'moor_type': column.sqlType.toSerializedString(), + 'moor_type': column.sqlType.builtin.toSerializedString(), 'nullable': column.nullable, 'customConstraints': column.customConstraints, if (constraints[SqlDialect.sqlite]!.isNotEmpty && @@ -467,7 +467,7 @@ class SchemaReader { // Note: Not including client default code because that usually depends on // imports from the database. return DriftColumn( - sqlType: columnType, + sqlType: ColumnType.drift(columnType), nullable: nullable, nameInSql: name, nameInDart: getterName ?? ReCase(name).camelCase, diff --git a/drift_dev/lib/src/writer/function_stubs_writer.dart b/drift_dev/lib/src/writer/function_stubs_writer.dart index 86572bec..30705685 100644 --- a/drift_dev/lib/src/writer/function_stubs_writer.dart +++ b/drift_dev/lib/src/writer/function_stubs_writer.dart @@ -59,7 +59,7 @@ class FunctionStubsWriter { String _nameFor(String sqlName) => ReCase(sqlName).camelCase; void _writeTypeFor(ResolvedType type) { - final driftType = _driver.typeMapping.sqlTypeToDrift(type); + final driftType = _driver.typeMapping.sqlTypeToDrift(type).builtin; _emitter.writeDart(AnnotatedDartCode([dartTypeNames[driftType]!])); if (type.nullable == true) { diff --git a/drift_dev/lib/src/writer/queries/query_writer.dart b/drift_dev/lib/src/writer/queries/query_writer.dart index e14a1d61..b150576a 100644 --- a/drift_dev/lib/src/writer/queries/query_writer.dart +++ b/drift_dev/lib/src/writer/queries/query_writer.dart @@ -202,10 +202,19 @@ class QueryWriter { } final dartLiteral = asDartLiteral(name); - final method = isNullable ? 'readNullable' : 'read'; + final rawDartType = - _emitter.dartCode(AnnotatedDartCode([dartTypeNames[column.sqlType]!])); - var code = 'row.$method<$rawDartType>($dartLiteral)'; + _emitter.dartCode(_emitter.innerColumnType(column.sqlType)); + String code; + + if (column.sqlType.isCustom) { + final method = isNullable ? 'readNullableWithType' : 'readWithType'; + final typeImpl = _emitter.dartCode(column.sqlType.custom!.expression); + code = 'row.$method<$rawDartType>($dartLiteral, $typeImpl)'; + } else { + final method = isNullable ? 'readNullable' : 'read'; + code = 'row.$method<$rawDartType>($dartLiteral)'; + } final converter = column.typeConverter; if (converter != null) { @@ -838,8 +847,8 @@ class _ExpandedVariableWriter { // write all the variables sequentially. String constructVar(String dartExpr) { // No longer an array here, we apply a for loop if necessary - final type = - _emitter.dartCode(_emitter.innerColumnType(element, nullable: false)); + final type = _emitter + .dartCode(_emitter.innerColumnType(element.sqlType, nullable: false)); final varType = _emitter.drift('Variable'); final buffer = StringBuffer('$varType<$type>('); diff --git a/drift_dev/lib/src/writer/queries/utils.dart b/drift_dev/lib/src/writer/queries/utils.dart index 022345c1..d903bc2f 100644 --- a/drift_dev/lib/src/writer/queries/utils.dart +++ b/drift_dev/lib/src/writer/queries/utils.dart @@ -26,7 +26,7 @@ extension FoundElementType on FoundElement { builder ..addSymbol('Expression', AnnotatedDartCode.drift) ..addText('<') - ..addTopLevel(dartTypeNames[kind.columnType]!) + ..addCode(scope.innerColumnType(kind.columnType!)) ..addText('>'); } else if (kind is InsertableDartPlaceholderType) { final table = kind.table; diff --git a/drift_dev/lib/src/writer/schema_version_writer.dart b/drift_dev/lib/src/writer/schema_version_writer.dart index c6856a6b..4eed4c63 100644 --- a/drift_dev/lib/src/writer/schema_version_writer.dart +++ b/drift_dev/lib/src/writer/schema_version_writer.dart @@ -47,7 +47,7 @@ final class _TableShape { DriftElementWithResultSet e) { return { for (final column in e.columns) - column.nameInDart: (column.nameInSql, column.sqlType), + column.nameInDart: (column.nameInSql, column.sqlType.builtin), }; } } diff --git a/drift_dev/lib/src/writer/tables/data_class_writer.dart b/drift_dev/lib/src/writer/tables/data_class_writer.dart index a0857448..3f51b147 100644 --- a/drift_dev/lib/src/writer/tables/data_class_writer.dart +++ b/drift_dev/lib/src/writer/tables/data_class_writer.dart @@ -339,7 +339,13 @@ class RowMappingWriter { final columnName = column.nameInSql; final rawData = "data['\${effectivePrefix}$columnName']"; - final sqlType = writer.drift(column.sqlType.toString()); + String sqlType; + if (column.sqlType.custom case CustomColumnType custom) { + sqlType = writer.dartCode(custom.expression); + } else { + sqlType = writer.drift(column.sqlType.builtin.toString()); + } + var loadType = '$databaseGetter.typeMapping.read($sqlType, $rawData)'; if (!column.nullable) { diff --git a/drift_dev/lib/src/writer/tables/table_writer.dart b/drift_dev/lib/src/writer/tables/table_writer.dart index 40466b15..73d9f761 100644 --- a/drift_dev/lib/src/writer/tables/table_writer.dart +++ b/drift_dev/lib/src/writer/tables/table_writer.dart @@ -210,7 +210,13 @@ abstract class TableOrViewWriter { } } - additionalParams['type'] = emitter.drift(column.sqlType.toString()); + if (column.sqlType.isCustom) { + additionalParams['type'] = + emitter.dartCode(column.sqlType.custom!.expression); + } else { + additionalParams['type'] = + emitter.drift(column.sqlType.builtin.toString()); + } if (isRequiredForInsert != null) { additionalParams['requiredDuringInsert'] = isRequiredForInsert.toString(); @@ -256,7 +262,7 @@ abstract class TableOrViewWriter { emitter.dartCode(column.clientDefaultCode!); } - final innerType = emitter.innerColumnType(column); + final innerType = emitter.innerColumnType(column.sqlType); var type = '${emitter.drift('GeneratedColumn')}<${emitter.dartCode(innerType)}>'; expressionBuffer diff --git a/drift_dev/lib/src/writer/tables/update_companion_writer.dart b/drift_dev/lib/src/writer/tables/update_companion_writer.dart index 2752afa5..d5abc262 100644 --- a/drift_dev/lib/src/writer/tables/update_companion_writer.dart +++ b/drift_dev/lib/src/writer/tables/update_companion_writer.dart @@ -135,7 +135,8 @@ class UpdateCompanionWriter { final expression = _emitter.drift('Expression'); for (final column in columns) { - final typeName = _emitter.dartCode(_emitter.innerColumnType(column)); + final typeName = + _emitter.dartCode(_emitter.innerColumnType(column.sqlType)); _buffer.write('$expression<$typeName>? ${column.nameInDart}, \n'); } @@ -190,32 +191,35 @@ class UpdateCompanionWriter { for (final column in columns) { final getterName = thisIfNeeded(column.nameInDart, locals); - _buffer.write('if ($getterName.present) {'); + _buffer.writeln('if ($getterName.present) {'); final typeName = _emitter.dartCode(_emitter.variableTypeCode(column, nullable: false)); final mapSetter = 'map[${asDartLiteral(column.nameInSql)}] = ' '${_emitter.drift('Variable')}<$typeName>'; + var value = '$getterName.value'; final converter = column.typeConverter; if (converter != null) { // apply type converter before writing the variable final fieldName = _emitter.dartCode( _emitter.readConverter(converter, forNullable: column.nullable)); - _buffer - ..write('final converter = $fieldName;\n') - ..write(mapSetter) - ..write('(converter.toSql($getterName.value)') - ..write(');'); - } else { - // no type converter. Write variable directly - _buffer - ..write(mapSetter) - ..write('(') - ..write('$getterName.value') - ..write(');'); + _buffer.writeln('final converter = $fieldName;\n'); + value = 'converter.toSql($value)'; } - _buffer.write('}'); + _buffer + ..write(mapSetter) + ..write('($value'); + + if (column.sqlType.isCustom) { + // Also specify the custom type since it can't be inferred from the + // value passed to the variable. + _buffer + ..write(', ') + ..write(_emitter.dartCode(column.sqlType.custom!.expression)); + } + + _buffer.writeln(');}'); } _buffer.write('return map; \n}\n'); diff --git a/drift_dev/lib/src/writer/utils/column_constraints.dart b/drift_dev/lib/src/writer/utils/column_constraints.dart index d7eebb76..de81b33c 100644 --- a/drift_dev/lib/src/writer/utils/column_constraints.dart +++ b/drift_dev/lib/src/writer/utils/column_constraints.dart @@ -72,7 +72,7 @@ Map defaultConstraints(DriftColumn column) { } } - if (column.sqlType == DriftSqlType.bool) { + if (column.sqlType.builtin == DriftSqlType.bool) { final name = column.nameInSql; dialectSpecificConstraints[SqlDialect.sqlite]! .add('CHECK (${SqlDialect.sqlite.escape(name)} IN (0, 1))'); diff --git a/drift_dev/lib/src/writer/writer.dart b/drift_dev/lib/src/writer/writer.dart index bd4e1fca..e64b8953 100644 --- a/drift_dev/lib/src/writer/writer.dart +++ b/drift_dev/lib/src/writer/writer.dart @@ -134,7 +134,16 @@ abstract class _NodeOrWriter { {bool makeNullable = false}) { // Write something like `TypeConverter` return AnnotatedDartCode.build((b) { - var sqlDartType = dartTypeNames[converter.sqlType]!; + AnnotatedDartCode sqlDartType; + + if (converter.sqlType.isCustom) { + sqlDartType = + AnnotatedDartCode.type(converter.sqlType.custom!.dartType); + } else { + sqlDartType = + AnnotatedDartCode([dartTypeNames[converter.sqlType.builtin]!]); + } + final className = converter.alsoAppliesToJsonConversion ? 'JsonTypeConverter2' : 'TypeConverter'; @@ -145,7 +154,7 @@ abstract class _NodeOrWriter { ..addDartType(converter.dartType) ..questionMarkIfNullable(makeNullable) ..addText(',') - ..addTopLevel(sqlDartType) + ..addCode(sqlDartType) ..questionMarkIfNullable(makeNullable || converter.sqlTypeIsNullable); if (converter.alsoAppliesToJsonConversion) { @@ -169,7 +178,8 @@ abstract class _NodeOrWriter { /// This is the same as [dartType] but without custom types. AnnotatedDartCode variableTypeCode(HasType type, {bool? nullable}) { if (type.isArray) { - final inner = innerColumnType(type, nullable: nullable ?? type.nullable); + final inner = + innerColumnType(type.sqlType, nullable: nullable ?? type.nullable); return AnnotatedDartCode([ DartTopLevelSymbol.list, '<', @@ -177,7 +187,7 @@ abstract class _NodeOrWriter { '>', ]); } else { - return innerColumnType(type, nullable: nullable ?? type.nullable); + return innerColumnType(type.sqlType, nullable: nullable ?? type.nullable); } } @@ -185,11 +195,20 @@ abstract class _NodeOrWriter { /// [nullable] parameter. /// /// This type does not respect type converters or arrays. - AnnotatedDartCode innerColumnType(HasType type, {bool nullable = false}) { - return AnnotatedDartCode([ - dartTypeNames[type.sqlType], - if (nullable) '?', - ]); + AnnotatedDartCode innerColumnType(ColumnType type, {bool nullable = false}) { + return AnnotatedDartCode.build((b) { + final custom = type.custom; + + if (custom != null) { + b.addDartType(custom.dartType); + } else { + b.addTopLevel(dartTypeNames[type.builtin]!); + } + + if (nullable) { + b.addText('?'); + } + }); } String refUri(Uri definition, String element) { diff --git a/drift_dev/test/analysis/generic_test.dart b/drift_dev/test/analysis/generic_test.dart index d7841023..8b19e74e 100644 --- a/drift_dev/test/analysis/generic_test.dart +++ b/drift_dev/test/analysis/generic_test.dart @@ -172,7 +172,8 @@ class ProgrammingLanguages extends Table { final tablesFile = await backend.analyze('package:a/tables.drift'); final librariesQuery = tablesFile.fileAnalysis!.resolvedQueries.values .singleWhere((e) => e.name == 'findLibraries') as SqlSelectQuery; - expect(librariesQuery.variables.single.sqlType, DriftSqlType.string); + expect( + librariesQuery.variables.single.sqlType.builtin, DriftSqlType.string); expect(librariesQuery.declaredInDriftFile, isTrue); }); diff --git a/drift_dev/test/analysis/resolver/dart/column_test.dart b/drift_dev/test/analysis/resolver/dart/column_test.dart index 3f5e8b5b..2e45c5ce 100644 --- a/drift_dev/test/analysis/resolver/dart/column_test.dart +++ b/drift_dev/test/analysis/resolver/dart/column_test.dart @@ -1,3 +1,4 @@ +import 'package:drift/drift.dart'; import 'package:drift_dev/src/analysis/options.dart'; import 'package:drift_dev/src/analysis/results/results.dart'; import 'package:test/test.dart'; @@ -233,6 +234,30 @@ class Database {} expect(column.nameInSql, 'TEXTCOLUMN'); }); + test('recognizes custom column types', () async { + final state = TestBackend.inTest({ + 'a|lib/main.dart': ''' +import 'package:drift/drift.dart'; + +class StringArrayType implements CustomSqlType> {} + +class TestTable extends Table { + Column> get list => customType(StringArrayType())(); +} +''', + }); + + final file = await state.analyze('package:a/main.dart'); + state.expectNoErrors(); + + final table = file.analyzedElements.whereType().single; + final column = table.columns.single; + + expect(column.sqlType.builtin, DriftSqlType.any); + expect(column.sqlType.custom?.dartType.toString(), 'List'); + expect(column.sqlType.custom?.expression.toString(), 'StringArrayType()'); + }); + group('customConstraint analysis', () { test('reports errors', () async { final state = TestBackend.inTest({ diff --git a/drift_dev/test/analysis/resolver/dart/view_test.dart b/drift_dev/test/analysis/resolver/dart/view_test.dart index e818b613..60d586f0 100644 --- a/drift_dev/test/analysis/resolver/dart/view_test.dart +++ b/drift_dev/test/analysis/resolver/dart/view_test.dart @@ -95,7 +95,7 @@ abstract class TodoItemWithCategoryNameView extends View { todoCategoryItemCount.columns[1], isA() .having((e) => e.nameInDart, 'nameInDart', 'itemCount') - .having((e) => e.sqlType, 'sqlType', DriftSqlType.int) + .having((e) => e.sqlType.builtin, 'sqlType', DriftSqlType.int) .having((e) => e.nullable, 'nullable', isTrue)); expect(todoItemWithCategoryName.columns, hasLength(2)); @@ -108,7 +108,7 @@ abstract class TodoItemWithCategoryNameView extends View { todoItemWithCategoryName.columns[1], isA() .having((e) => e.nameInDart, 'nameInDart', 'title') - .having((e) => e.sqlType, 'sqlType', DriftSqlType.string) + .having((e) => e.sqlType.builtin, 'sqlType', DriftSqlType.string) .having((e) => e.nullable, 'nullable', isTrue)); }); } diff --git a/drift_dev/test/analysis/resolver/drift/create_view_test.dart b/drift_dev/test/analysis/resolver/drift/create_view_test.dart index 0ba371e4..9fe69b93 100644 --- a/drift_dev/test/analysis/resolver/drift/create_view_test.dart +++ b/drift_dev/test/analysis/resolver/drift/create_view_test.dart @@ -23,8 +23,8 @@ void main() { final view = file.analyzedElements.single as DriftView; expect(view.columns, [ - isA() - .having((e) => e.sqlType, 'sqlType', drift.DriftSqlType.string) + isA().having( + (e) => e.sqlType.builtin, 'sqlType', drift.DriftSqlType.string) ]); expect(view.references, @@ -56,8 +56,8 @@ void main() { expect(parentView.columns, hasLength(2)); expect(childView.columns, [ - isA() - .having((e) => e.sqlType, 'sqlType', drift.DriftSqlType.string) + isA().having( + (e) => e.sqlType.builtin, 'sqlType', drift.DriftSqlType.string) ]); expect(parentView.references.map((e) => e.id.name), ['t']); @@ -290,7 +290,7 @@ enum MyEnum { final view = state.analyzedElements.single as DriftView; final column = view.columns.single; - expect(column.sqlType, expectedType); + expect(column.sqlType.builtin, expectedType); expect( view.source, isA().having( diff --git a/drift_dev/test/analysis/resolver/drift/cte_test.dart b/drift_dev/test/analysis/resolver/drift/cte_test.dart index b7d4ec79..f9583be1 100644 --- a/drift_dev/test/analysis/resolver/drift/cte_test.dart +++ b/drift_dev/test/analysis/resolver/drift/cte_test.dart @@ -33,7 +33,8 @@ SELECT final resultSet = query.resultSet; expect(resultSet.singleColumn, isTrue); expect(resultSet.needsOwnClass, isFalse); - expect(resultSet.scalarColumns.map((c) => c.sqlType), [DriftSqlType.int]); + expect(resultSet.scalarColumns.map((c) => c.sqlType.builtin), + [DriftSqlType.int]); }); test('recognizes CTE clause', () async { @@ -64,7 +65,8 @@ WITH RECURSIVE expect(resultSet.singleColumn, isTrue); expect(resultSet.needsOwnClass, isFalse); expect(resultSet.columns.map(resultSet.dartNameFor), ['x']); - expect(resultSet.scalarColumns.map((c) => c.sqlType), [DriftSqlType.int]); + expect(resultSet.scalarColumns.map((c) => c.sqlType.builtin), + [DriftSqlType.int]); }); test('finds the underlying table when aliased through CTE', () async { diff --git a/drift_dev/test/analysis/resolver/drift/ffi_extension_test.dart b/drift_dev/test/analysis/resolver/drift/ffi_extension_test.dart index c8581f8d..b3b5cbed 100644 --- a/drift_dev/test/analysis/resolver/drift/ffi_extension_test.dart +++ b/drift_dev/test/analysis/resolver/drift/ffi_extension_test.dart @@ -108,7 +108,7 @@ wrongArgs: SELECT sin(oid, foo) FROM numbers; expect( queryInA.resultSet.scalarColumns.single, const TypeMatcher() - .having((e) => e.sqlType, 'type', DriftSqlType.double), + .having((e) => e.sqlType.builtin, 'type', DriftSqlType.double), ); final fileB = await state.analyze('package:foo/b.drift'); diff --git a/drift_dev/test/analysis/resolver/drift/table_test.dart b/drift_dev/test/analysis/resolver/drift/table_test.dart index 38c8a2c2..cf81d1ae 100644 --- a/drift_dev/test/analysis/resolver/drift/table_test.dart +++ b/drift_dev/test/analysis/resolver/drift/table_test.dart @@ -33,12 +33,12 @@ CREATE TABLE b ( final b = results[1].result! as DriftTable; final bBar = b.columns[0]; - expect(aFoo.sqlType, DriftSqlType.int); + expect(aFoo.sqlType.builtin, DriftSqlType.int); expect(aFoo.nullable, isFalse); expect(aFoo.constraints, [isA()]); expect(aFoo.customConstraints, 'PRIMARY KEY'); - expect(aBar.sqlType, DriftSqlType.int); + expect(aBar.sqlType.builtin, DriftSqlType.int); expect(aBar.nullable, isTrue); expect(aBar.constraints, [ isA() @@ -48,7 +48,7 @@ CREATE TABLE b ( ]); expect(aBar.customConstraints, 'REFERENCES b(bar)'); - expect(bBar.sqlType, DriftSqlType.int); + expect(bBar.sqlType.builtin, DriftSqlType.int); expect(bBar.nullable, isFalse); expect(bBar.constraints, isEmpty); expect(bBar.customConstraints, 'NOT NULL'); @@ -111,7 +111,7 @@ CREATE TABLE b ( final indexColumn = table.columns.singleWhere((c) => c.nameInSql == 'fruitIndex'); - expect(indexColumn.sqlType, DriftSqlType.int); + expect(indexColumn.sqlType.builtin, DriftSqlType.int); expect( indexColumn.typeConverter, isA() @@ -126,7 +126,7 @@ CREATE TABLE b ( final withGenericIndexColumn = table.columns .singleWhere((c) => c.nameInSql == 'fruitWithGenericIndex'); - expect(withGenericIndexColumn.sqlType, DriftSqlType.int); + expect(withGenericIndexColumn.sqlType.builtin, DriftSqlType.int); expect( withGenericIndexColumn.typeConverter, isA() @@ -142,7 +142,7 @@ CREATE TABLE b ( final nameColumn = table.columns.singleWhere((c) => c.nameInSql == 'fruitName'); - expect(nameColumn.sqlType, DriftSqlType.string); + expect(nameColumn.sqlType.builtin, DriftSqlType.string); expect( nameColumn.typeConverter, isA() @@ -263,4 +263,31 @@ CREATE TABLE IF NOT EXISTS currencies ( 'documentationComment', '/// The name of this currency'), ); }); + + test('can use custom types', () async { + final state = TestBackend.inTest({ + 'a|lib/a.drift': ''' +import 'b.dart'; + +CREATE TABLE foo ( + bar `MyType()` NOT NULL +); +''', + 'a|lib/b.dart': ''' +import 'package:drift/drift.dart'; + +class MyType implements CustomSqlType {} + ''', + }); + + final file = await state.analyze('package:a/a.drift'); + state.expectNoErrors(); + + final table = file.analyzedElements.single as DriftTable; + final column = table.columns.single; + + expect(column.sqlType.isCustom, isTrue); + expect(column.sqlType.custom?.dartType.toString(), 'String'); + expect(column.sqlType.custom?.expression.toString(), 'MyType()'); + }); } diff --git a/drift_dev/test/analysis/resolver/queries/query_analyzer_test.dart b/drift_dev/test/analysis/resolver/queries/query_analyzer_test.dart index 25d58892..6bab2e91 100644 --- a/drift_dev/test/analysis/resolver/queries/query_analyzer_test.dart +++ b/drift_dev/test/analysis/resolver/queries/query_analyzer_test.dart @@ -24,7 +24,7 @@ bar(?1 AS TEXT, :foo AS BOOLEAN): SELECT ?, :foo; final resultSet = (query as SqlSelectQuery).resultSet; expect(resultSet.matchingTable, isNull); expect(resultSet.scalarColumns.map((c) => c.name), ['?', ':foo']); - expect(resultSet.scalarColumns.map((c) => c.sqlType), + expect(resultSet.scalarColumns.map((c) => c.sqlType.builtin), [DriftSqlType.string, DriftSqlType.bool]); }); @@ -181,19 +181,22 @@ q3: SELECT datetime('now'); expect(queries, hasLength(3)); final q1 = queries[0]; - expect(q1.resultSet!.scalarColumns.single.sqlType, DriftSqlType.dateTime); + expect(q1.resultSet!.scalarColumns.single.sqlType.builtin, + DriftSqlType.dateTime); final q2 = queries[1]; final q3 = queries[2]; if (dateTimeAsText) { - expect(q2.resultSet!.scalarColumns.single.sqlType, DriftSqlType.int); - expect( - q3.resultSet!.scalarColumns.single.sqlType, DriftSqlType.dateTime); + expect(q2.resultSet!.scalarColumns.single.sqlType.builtin, + DriftSqlType.int); + expect(q3.resultSet!.scalarColumns.single.sqlType.builtin, + DriftSqlType.dateTime); } else { - expect( - q2.resultSet!.scalarColumns.single.sqlType, DriftSqlType.dateTime); - expect(q3.resultSet!.scalarColumns.single.sqlType, DriftSqlType.string); + expect(q2.resultSet!.scalarColumns.single.sqlType.builtin, + DriftSqlType.dateTime); + expect(q3.resultSet!.scalarColumns.single.sqlType.builtin, + DriftSqlType.string); } }); } @@ -306,11 +309,12 @@ LEFT JOIN tableB1 AS tableB2 -- nullable final query = result.fileAnalysis!.resolvedQueries.values.single; expect(query.resultSet!.columns, [ isA() - .having((e) => e.sqlType, 'sqlType', DriftSqlType.bool) + .having((e) => e.sqlType.builtin, 'sqlType', DriftSqlType.bool) ]); final args = query.variables; - expect(args.map((e) => e.sqlType), [DriftSqlType.int, DriftSqlType.string]); + expect(args.map((e) => e.sqlType.builtin), + [DriftSqlType.int, DriftSqlType.string]); }); test('can cast to DATETIME and BOOLEAN', () async { @@ -325,9 +329,10 @@ a: SELECT CAST(1 AS BOOLEAN) AS a, CAST(2 AS DATETIME) as b; final resultSet = query.resultSet!; expect(resultSet.columns, [ - scalarColumn('a').having((e) => e.sqlType, 'sqlType', DriftSqlType.bool), + scalarColumn('a') + .having((e) => e.sqlType.builtin, 'sqlType', DriftSqlType.bool), scalarColumn('b') - .having((e) => e.sqlType, 'sqlType', DriftSqlType.dateTime), + .having((e) => e.sqlType.builtin, 'sqlType', DriftSqlType.dateTime), ]); }); diff --git a/drift_dev/test/analysis/test_utils.dart b/drift_dev/test/analysis/test_utils.dart index 79dc1557..df38ca6c 100644 --- a/drift_dev/test/analysis/test_utils.dart +++ b/drift_dev/test/analysis/test_utils.dart @@ -279,7 +279,8 @@ class _HasInferredColumnTypes extends CustomMatcher { final resultSet = actual.resultSet; return { - for (final column in resultSet.scalarColumns) column.name: column.sqlType + for (final column in resultSet.scalarColumns) + column.name: column.sqlType.builtin }; } } diff --git a/drift_dev/test/writer/schema_version_writer_test.dart b/drift_dev/test/writer/schema_version_writer_test.dart index 76f6f71d..519f80ac 100644 --- a/drift_dev/test/writer/schema_version_writer_test.dart +++ b/drift_dev/test/writer/schema_version_writer_test.dart @@ -15,7 +15,7 @@ void main() { DriftDeclaration(fakeUri, -1, ''), columns: [ DriftColumn( - sqlType: DriftSqlType.int, + sqlType: ColumnType.drift(DriftSqlType.int), nullable: false, nameInSql: 'foo', nameInDart: 'foo', diff --git a/extras/drift_postgres/example/main.dart b/extras/drift_postgres/example/main.dart new file mode 100644 index 00000000..45e7d8a1 --- /dev/null +++ b/extras/drift_postgres/example/main.dart @@ -0,0 +1,37 @@ +import 'package:drift/drift.dart'; +import 'package:drift_postgres/postgres.dart'; +import 'package:postgres/postgres_v3_experimental.dart'; +import 'package:uuid/uuid.dart'; + +part 'main.g.dart'; + +class Users extends Table { + UuidColumn get id => customType(PgTypes.uuid).withDefault(genRandomUuid())(); + TextColumn get name => text()(); +} + +@DriftDatabase(tables: [Users]) +class DriftPostgresDatabase extends _$DriftPostgresDatabase { + DriftPostgresDatabase(super.e); + + @override + int get schemaVersion => 1; +} + +void main() async { + final database = DriftPostgresDatabase(PgDatabase( + endpoint: PgEndpoint( + host: 'localhost', + database: 'postgres', + username: 'postgres', + password: 'postgres', + ), + logStatements: true, + )); + + final user = await database.users.insertReturning( + UsersCompanion.insert(name: 'Simon', id: Value(Uuid().v4obj()))); + print(user); + + await database.close(); +} diff --git a/extras/drift_postgres/example/main.g.dart b/extras/drift_postgres/example/main.g.dart new file mode 100644 index 00000000..3b79ed01 --- /dev/null +++ b/extras/drift_postgres/example/main.g.dart @@ -0,0 +1,192 @@ +// GENERATED CODE - DO NOT MODIFY BY HAND + +part of 'main.dart'; + +// ignore_for_file: type=lint +class $UsersTable extends Users with TableInfo<$UsersTable, User> { + @override + final GeneratedDatabase attachedDatabase; + final String? _alias; + $UsersTable(this.attachedDatabase, [this._alias]); + static const VerificationMeta _idMeta = const VerificationMeta('id'); + @override + late final GeneratedColumn id = GeneratedColumn( + 'id', aliasedName, false, + type: PgTypes.uuid, + requiredDuringInsert: false, + defaultValue: genRandomUuid()); + static const VerificationMeta _nameMeta = const VerificationMeta('name'); + @override + late final GeneratedColumn name = GeneratedColumn( + 'name', aliasedName, false, + type: DriftSqlType.string, requiredDuringInsert: true); + @override + List get $columns => [id, name]; + @override + String get aliasedName => _alias ?? actualTableName; + @override + String get actualTableName => $name; + static const String $name = 'users'; + @override + VerificationContext validateIntegrity(Insertable instance, + {bool isInserting = false}) { + final context = VerificationContext(); + final data = instance.toColumns(true); + if (data.containsKey('id')) { + context.handle(_idMeta, id.isAcceptableOrUnknown(data['id']!, _idMeta)); + } + if (data.containsKey('name')) { + context.handle( + _nameMeta, name.isAcceptableOrUnknown(data['name']!, _nameMeta)); + } else if (isInserting) { + context.missing(_nameMeta); + } + return context; + } + + @override + Set get $primaryKey => const {}; + @override + User map(Map data, {String? tablePrefix}) { + final effectivePrefix = tablePrefix != null ? '$tablePrefix.' : ''; + return User( + id: attachedDatabase.typeMapping + .read(PgTypes.uuid, data['${effectivePrefix}id'])!, + name: attachedDatabase.typeMapping + .read(DriftSqlType.string, data['${effectivePrefix}name'])!, + ); + } + + @override + $UsersTable createAlias(String alias) { + return $UsersTable(attachedDatabase, alias); + } +} + +class User extends DataClass implements Insertable { + final UuidValue id; + final String name; + const User({required this.id, required this.name}); + @override + Map toColumns(bool nullToAbsent) { + final map = {}; + map['id'] = Variable(id); + map['name'] = Variable(name); + return map; + } + + UsersCompanion toCompanion(bool nullToAbsent) { + return UsersCompanion( + id: Value(id), + name: Value(name), + ); + } + + factory User.fromJson(Map json, + {ValueSerializer? serializer}) { + serializer ??= driftRuntimeOptions.defaultSerializer; + return User( + id: serializer.fromJson(json['id']), + name: serializer.fromJson(json['name']), + ); + } + @override + Map toJson({ValueSerializer? serializer}) { + serializer ??= driftRuntimeOptions.defaultSerializer; + return { + 'id': serializer.toJson(id), + 'name': serializer.toJson(name), + }; + } + + User copyWith({UuidValue? id, String? name}) => User( + id: id ?? this.id, + name: name ?? this.name, + ); + @override + String toString() { + return (StringBuffer('User(') + ..write('id: $id, ') + ..write('name: $name') + ..write(')')) + .toString(); + } + + @override + int get hashCode => Object.hash(id, name); + @override + bool operator ==(Object other) => + identical(this, other) || + (other is User && other.id == this.id && other.name == this.name); +} + +class UsersCompanion extends UpdateCompanion { + final Value id; + final Value name; + final Value rowid; + const UsersCompanion({ + this.id = const Value.absent(), + this.name = const Value.absent(), + this.rowid = const Value.absent(), + }); + UsersCompanion.insert({ + this.id = const Value.absent(), + required String name, + this.rowid = const Value.absent(), + }) : name = Value(name); + static Insertable custom({ + Expression? id, + Expression? name, + Expression? rowid, + }) { + return RawValuesInsertable({ + if (id != null) 'id': id, + if (name != null) 'name': name, + if (rowid != null) 'rowid': rowid, + }); + } + + UsersCompanion copyWith( + {Value? id, Value? name, Value? rowid}) { + return UsersCompanion( + id: id ?? this.id, + name: name ?? this.name, + rowid: rowid ?? this.rowid, + ); + } + + @override + Map toColumns(bool nullToAbsent) { + final map = {}; + if (id.present) { + map['id'] = Variable(id.value, PgTypes.uuid); + } + if (name.present) { + map['name'] = Variable(name.value); + } + if (rowid.present) { + map['rowid'] = Variable(rowid.value); + } + return map; + } + + @override + String toString() { + return (StringBuffer('UsersCompanion(') + ..write('id: $id, ') + ..write('name: $name, ') + ..write('rowid: $rowid') + ..write(')')) + .toString(); + } +} + +abstract class _$DriftPostgresDatabase extends GeneratedDatabase { + _$DriftPostgresDatabase(QueryExecutor e) : super(e); + late final $UsersTable users = $UsersTable(this); + @override + Iterable> get allTables => + allSchemaEntities.whereType>(); + @override + List get allSchemaEntities => [users]; +} diff --git a/extras/drift_postgres/example/test.dart b/extras/drift_postgres/example/test.dart deleted file mode 100644 index 87c2fa30..00000000 --- a/extras/drift_postgres/example/test.dart +++ /dev/null @@ -1,32 +0,0 @@ -import 'package:drift/backends.dart'; -import 'package:drift/src/runtime/query_builder/query_builder.dart'; -import 'package:drift_postgres/postgres.dart'; -import 'package:postgres/postgres_v3_experimental.dart'; - -void main() async { - final postgres = PgDatabase( - endpoint: PgEndpoint( - host: 'localhost', - database: 'postgres', - username: 'postgres', - password: 'postgres', - ), - logStatements: true, - ); - - await postgres.ensureOpen(_NullUser()); - - final rows = await postgres.runSelect(r'SELECT $1', [true]); - final row = rows.single; - print(row); - print(row.values.map((e) => e.runtimeType).toList()); -} - -class _NullUser extends QueryExecutorUser { - @override - Future beforeOpen( - QueryExecutor executor, OpeningDetails details) async {} - - @override - int get schemaVersion => 1; -} diff --git a/extras/drift_postgres/lib/postgres.dart b/extras/drift_postgres/lib/postgres.dart index ba57a666..0c63bf65 100644 --- a/extras/drift_postgres/lib/postgres.dart +++ b/extras/drift_postgres/lib/postgres.dart @@ -2,6 +2,33 @@ @experimental library drift.postgres; +import 'package:drift/drift.dart'; import 'package:meta/meta.dart'; +import 'package:postgres/postgres_v3_experimental.dart'; +import 'package:uuid/uuid.dart'; + +import 'src/types.dart'; export 'src/pg_database.dart'; + +typedef UuidColumn = Column; +typedef IntervalColumn = Column; +typedef JsonColumn = Column; +typedef PointColumn = Column; + +final class PgTypes { + PgTypes._(); + + static const CustomSqlType uuid = UuidType(); + static const CustomSqlType interval = IntervalType(); + static const CustomSqlType json = + PostgresType(type: PgDataType.json, name: 'json'); + static const CustomSqlType jsonb = + PostgresType(type: PgDataType.json, name: 'jsonb'); + static const CustomSqlType point = PointType(); +} + +/// Calls the `gen_random_uuid` function in postgres. +Expression genRandomUuid() { + return FunctionCallExpression('gen_random_uuid', []); +} diff --git a/extras/drift_postgres/lib/src/pg_database.dart b/extras/drift_postgres/lib/src/pg_database.dart index 0bab0232..257ef748 100644 --- a/extras/drift_postgres/lib/src/pg_database.dart +++ b/extras/drift_postgres/lib/src/pg_database.dart @@ -152,25 +152,16 @@ class _BoundArguments { } for (final value in args) { - if (value == null) { - add(PgTypedParameter(PgDataType.text, null)); - } else if (value is int) { - add(PgTypedParameter(PgDataType.bigInteger, value)); - } else if (value is BigInt) { - // Drift only uses BigInts to represent 64-bit values on the web, so we - // can use toInt() here. - add(PgTypedParameter(PgDataType.bigInteger, value)); - } else if (value is bool) { - add(PgTypedParameter(PgDataType.boolean, value)); - } else if (value is double) { - add(PgTypedParameter(PgDataType.double, value)); - } else if (value is String) { - add(PgTypedParameter(PgDataType.text, value)); - } else if (value is List) { - add(PgTypedParameter(PgDataType.byteArray, value)); - } else { - throw ArgumentError.value(value, 'value', 'Unsupported type'); - } + add(switch (value) { + PgTypedParameter() => value, + null => PgTypedParameter(PgDataType.text, null), + int() || BigInt() => PgTypedParameter(PgDataType.bigInteger, value), + String() => PgTypedParameter(PgDataType.text, value), + bool() => PgTypedParameter(PgDataType.boolean, value), + double() => PgTypedParameter(PgDataType.double, value), + List() => PgTypedParameter(PgDataType.byteArray, value), + _ => throw ArgumentError.value(value, 'value', 'Unsupported type'), + }); } return _BoundArguments(types, parameters); diff --git a/extras/drift_postgres/lib/src/types.dart b/extras/drift_postgres/lib/src/types.dart new file mode 100644 index 00000000..ab810e97 --- /dev/null +++ b/extras/drift_postgres/lib/src/types.dart @@ -0,0 +1,67 @@ +import 'package:drift/drift.dart'; +import 'package:postgres/postgres_v3_experimental.dart'; +// ignore: implementation_imports +import 'package:postgres/src/text_codec.dart'; +import 'package:uuid/uuid.dart'; + +class PostgresType implements CustomSqlType { + static final _encoder = PostgresTextEncoder(); + + final PgDataType type; + final String name; + + const PostgresType({required this.type, required this.name}); + + @override + String mapToSqlLiteral(T dartValue) { + return '${_encoder.convert(dartValue)}::$name'; + } + + @override + Object mapToSqlParameter(T dartValue) => PgTypedParameter(type, dartValue); + + @override + T read(Object fromSql) => fromSql as T; + + @override + String sqlTypeName(GenerationContext context) => name; +} + +class UuidType extends PostgresType { + const UuidType() : super(type: PgDataType.uuid, name: 'uuid'); + + @override + String mapToSqlLiteral(UuidValue dartValue) { + // UUIDs can't contain escape characters, so we don't check these values. + return "'${dartValue.uuid}'"; + } + + @override + Object mapToSqlParameter(UuidValue dartValue) { + return PgTypedParameter(PgDataType.uuid, dartValue.uuid); + } + + @override + UuidValue read(Object fromSql) { + return UuidValue(fromSql as String); + } +} + +// override because the text encoder doesn't properly encode PgPoint values +class PointType extends PostgresType { + const PointType() : super(type: PgDataType.point, name: 'point'); + + @override + String mapToSqlLiteral(PgPoint dartValue) { + return "'(${dartValue.latitude}, ${dartValue.longitude})'::point"; + } +} + +class IntervalType extends PostgresType { + const IntervalType() : super(type: PgDataType.interval, name: 'interval'); + + @override + String mapToSqlLiteral(Duration dartValue) { + return "'${dartValue.inMicroseconds} microseconds'::interval"; + } +} diff --git a/extras/drift_postgres/pubspec.yaml b/extras/drift_postgres/pubspec.yaml index 1bf14d5d..1368d6f7 100644 --- a/extras/drift_postgres/pubspec.yaml +++ b/extras/drift_postgres/pubspec.yaml @@ -3,19 +3,22 @@ description: Postgres support for drift version: 1.0.0 environment: - sdk: '>=2.12.0-0 <4.0.0' + sdk: '>=3.0.0 <4.0.0' dependencies: collection: ^1.16.0 drift: ^2.0.0 postgres: meta: ^1.8.0 + uuid: ^4.1.0 dev_dependencies: lints: ^2.0.0 test: ^1.18.0 + drift_dev: drift_testcases: path: ../integration_tests/drift_testcases + build_runner: ^2.4.6 dependency_overrides: drift: diff --git a/extras/drift_postgres/test/types_test.dart b/extras/drift_postgres/test/types_test.dart new file mode 100644 index 00000000..ea3e9493 --- /dev/null +++ b/extras/drift_postgres/test/types_test.dart @@ -0,0 +1,53 @@ +import 'package:drift/drift.dart'; +import 'package:drift_postgres/postgres.dart'; +import 'package:postgres/postgres_v3_experimental.dart'; +import 'package:test/test.dart'; +import 'package:uuid/uuid.dart'; + +import '../example/main.dart'; + +void main() { + final database = DriftPostgresDatabase(PgDatabase( + endpoint: PgEndpoint( + host: 'localhost', + database: 'postgres', + username: 'postgres', + password: 'postgres', + ), + )); + + setUpAll(() async { + await database.users.insertOne(UsersCompanion.insert(name: 'test user')); + }); + + tearDownAll(() async { + await database.users.deleteAll(); + await database.close(); + }); + + group('custom types pass through', () { + void testWith(CustomSqlType type, T value) { + test('with variable', () async { + final variable = Variable(value, type); + final query = database.selectOnly(database.users) + ..addColumns([variable]); + final row = await query.getSingle(); + expect(row.read(variable), value); + }); + + test('with constant', () async { + final constant = Constant(value, type); + final query = database.selectOnly(database.users) + ..addColumns([constant]); + final row = await query.getSingle(); + expect(row.read(constant), value); + }); + } + + group('uuid', () => testWith(PgTypes.uuid, Uuid().v4obj())); + group('interval', () => testWith(PgTypes.interval, Duration(seconds: 15))); + group('json', () => testWith(PgTypes.json, {'foo': 'bar'})); + group('jsonb', () => testWith(PgTypes.jsonb, {'foo': 'bar'})); + group('point', () => testWith(PgTypes.point, PgPoint(90, -90))); + }); +} diff --git a/sqlparser/lib/src/reader/parser.dart b/sqlparser/lib/src/reader/parser.dart index ccc32383..ed7bc439 100644 --- a/sqlparser/lib/src/reader/parser.dart +++ b/sqlparser/lib/src/reader/parser.dart @@ -2496,6 +2496,10 @@ class Parser { } List? _typeName() { + if (enableDriftExtensions && _matchOne(TokenType.inlineDart)) { + return [_previous]; + } + // sqlite doesn't really define what a type name is and has very loose rules // at turning them into a type affinity. We support this pattern: // typename = identifier [ "(" { identifier | comma | number_literal } ")" ] diff --git a/sqlparser/test/parser/create_table_test.dart b/sqlparser/test/parser/create_table_test.dart index 74b01c5b..f094096d 100644 --- a/sqlparser/test/parser/create_table_test.dart +++ b/sqlparser/test/parser/create_table_test.dart @@ -217,7 +217,7 @@ void main() { ); }); - test('parses CREATE TABLE WITH in drift more', () { + test('parses CREATE TABLE WITH in drift mode', () { testStatement( 'CREATE TABLE a (b INTEGER) WITH MyExistingClass', CreateTableStatement( @@ -237,6 +237,23 @@ void main() { ); }); + test('parses custom types in drift mode', () { + testStatement( + 'CREATE TABLE a (b `PgTypes.uuid` NOT NULL)', + CreateTableStatement( + tableName: 'a', + columns: [ + ColumnDefinition( + columnName: 'b', + typeName: '`PgTypes.uuid`', + constraints: [NotNull(null)], + ), + ], + ), + driftMode: true, + ); + }); + test('parses CREATE VIRTUAL TABLE statement', () { testStatement( 'CREATE VIRTUAL TABLE IF NOT EXISTS foo USING bar(a, b(), c) AS drift',