From a68ec2a73eec0abdd48430307bd2ee79ab2137f8 Mon Sep 17 00:00:00 2001 From: Simon Binder Date: Mon, 22 Apr 2024 22:17:21 +0200 Subject: [PATCH] Make aggregate expression builder public --- drift/CHANGELOG.md | 7 ++ drift/lib/extensions/json1.dart | 18 ++++- .../query_builder/expressions/aggregate.dart | 73 +++++++++++++------ .../extensions/json1_integration_test.dart | 10 ++- drift/test/extensions/json1_test.dart | 22 ++++++ drift/test/test_utils/matchers.dart | 2 +- drift/test/test_utils/test_utils.mocks.dart | 2 +- 7 files changed, 102 insertions(+), 32 deletions(-) diff --git a/drift/CHANGELOG.md b/drift/CHANGELOG.md index 59b00974..f4a87ebb 100644 --- a/drift/CHANGELOG.md +++ b/drift/CHANGELOG.md @@ -1,3 +1,10 @@ +## 2.18.0-dev + +- Add `AggregateFunctionExpression` to write custom [aggregate function](https://www.sqlite.org/lang_aggfunc.html) + invocations in the Dart query builder. +- The `json_group_array` and `jsonb_group_array` functions now contain an `orderBy` + and `filter` parameter. + ## 2.17.0 - Adds `companion` entry to `DataClassName` to override the name of the diff --git a/drift/lib/extensions/json1.dart b/drift/lib/extensions/json1.dart index 1e600dc8..cd2fc8f9 100644 --- a/drift/lib/extensions/json1.dart +++ b/drift/lib/extensions/json1.dart @@ -144,8 +144,13 @@ extension JsonExtensions on Expression { /// all emails in that folder. /// This string could be turned back into a list with /// `(json.decode(row.read(subjects)!) as List).cast()`. -Expression jsonGroupArray(Expression value) { - return FunctionCallExpression('json_group_array', [value]); +Expression jsonGroupArray( + Expression value, { + OrderBy? orderBy, + Expression? filter, +}) { + return AggregateFunctionExpression('json_group_array', [value], + orderBy: orderBy, filter: filter); } /// Returns a binary representation of a JSON array containing the result of @@ -153,8 +158,13 @@ Expression jsonGroupArray(Expression value) { /// /// See [jsonGroupArray], the variant of this function returning a textual /// description, for more details and an example. -Expression jsonbGroupArray(Expression value) { - return FunctionCallExpression('jsonb_group_array', [value]); +Expression jsonbGroupArray( + Expression value, { + OrderBy? orderBy, + Expression? filter, +}) { + return AggregateFunctionExpression('jsonb_group_array', [value], + orderBy: orderBy, filter: filter); } List _groupObjectArgs(Map, Expression> values) { diff --git a/drift/lib/src/runtime/query_builder/expressions/aggregate.dart b/drift/lib/src/runtime/query_builder/expressions/aggregate.dart index b2f99b52..451ff018 100644 --- a/drift/lib/src/runtime/query_builder/expressions/aggregate.dart +++ b/drift/lib/src/runtime/query_builder/expressions/aggregate.dart @@ -12,7 +12,7 @@ part of '../query_builder.dart'; /// This is equivalent to the `COUNT(*) FILTER (WHERE filter)` sql function. The /// filter will be omitted if null. Expression countAll({Expression? filter}) { - return _AggregateExpression('COUNT', const [_StarFunctionParameter()], + return AggregateFunctionExpression('COUNT', const [_StarFunctionParameter()], filter: filter); } @@ -26,7 +26,7 @@ extension BaseAggregate
on Expression
{ /// counted twice. /// {@macro drift_aggregate_filter} Expression count({bool distinct = false, Expression? filter}) { - return _AggregateExpression('COUNT', [this], + return AggregateFunctionExpression('COUNT', [this], filter: filter, distinct: distinct); } @@ -35,14 +35,14 @@ extension BaseAggregate
on Expression
{ /// If there are no non-null values in the group, returns null. /// {@macro drift_aggregate_filter} Expression
max({Expression? filter}) => - _AggregateExpression('MAX', [this], filter: filter); + AggregateFunctionExpression('MAX', [this], filter: filter); /// Return the minimum of all non-null values in this group. /// /// If there are no non-null values in the group, returns null. /// {@macro drift_aggregate_filter} Expression
min({Expression? filter}) => - _AggregateExpression('MIN', [this], filter: filter); + AggregateFunctionExpression('MIN', [this], filter: filter); /// Returns the concatenation of all non-null values in the current group, /// joined by the [separator]. @@ -71,7 +71,7 @@ extension BaseAggregate
on Expression
{ 'Cannot use groupConcat with distinct: true and a custom separator'); } - return _AggregateExpression( + return AggregateFunctionExpression( 'GROUP_CONCAT', [ this, @@ -89,21 +89,21 @@ extension ArithmeticAggregates
on Expression
{ /// /// {@macro drift_aggregate_filter} Expression avg({Expression? filter}) => - _AggregateExpression('AVG', [this], filter: filter); + AggregateFunctionExpression('AVG', [this], filter: filter); /// Return the maximum of all non-null values in this group. /// /// If there are no non-null values in the group, returns null. /// {@macro drift_aggregate_filter} Expression
max({Expression? filter}) => - _AggregateExpression('MAX', [this], filter: filter); + AggregateFunctionExpression('MAX', [this], filter: filter); /// Return the minimum of all non-null values in this group. /// /// If there are no non-null values in the group, returns null. /// {@macro drift_aggregate_filter} Expression
min({Expression? filter}) => - _AggregateExpression('MIN', [this], filter: filter); + AggregateFunctionExpression('MIN', [this], filter: filter); /// Calculate the sum of all non-null values in the group. /// @@ -115,7 +115,7 @@ extension ArithmeticAggregates
on Expression
{ /// value and doesn't throw an overflow exception. /// {@macro drift_aggregate_filter} Expression
sum({Expression? filter}) => - _AggregateExpression('SUM', [this], filter: filter); + AggregateFunctionExpression('SUM', [this], filter: filter); /// Calculate the sum of all non-null values in the group. /// @@ -123,7 +123,7 @@ extension ArithmeticAggregates
on Expression
{ /// uses floating-point values internally. /// {@macro drift_aggregate_filter} Expression total({Expression? filter}) => - _AggregateExpression('TOTAL', [this], filter: filter); + AggregateFunctionExpression('TOTAL', [this], filter: filter); } /// Provides aggregate functions that are available for BigInt expressions. @@ -197,16 +197,41 @@ extension DateTimeAggregate on Expression { } } -class _AggregateExpression extends Expression { +/// An expression invoking an [aggregate function](https://www.sqlite.org/lang_aggfunc.html). +/// +/// Aggregate functions, like `count()` or `sum()` collapse the entire data set +/// (or a partition of it, if `GROUP BY` is used) into a single value. +/// +/// Drift exposes direct bindings to most aggregate functions (e.g. via +/// [BaseAggregate.count]). This class is useful when writing custom aggregate +/// function invocations. +final class AggregateFunctionExpression + extends Expression { + /// The name of the aggregate function to invoke. final String functionName; - final bool distinct; - final List parameter; + /// Whether only distinct rows should be passed to the function. + final bool distinct; + + /// The arguments to pass to the function. + final List arguments; + + /// The order in which rows of the current group should be passed to the + /// aggregate function. + final OrderBy? orderBy; + + /// An optional filter clause only passing rows matching this condition into + /// the function. final Where? filter; - _AggregateExpression(this.functionName, this.parameter, - {Expression? filter, this.distinct = false}) - : filter = filter != null ? Where(filter) : null; + /// Creates an aggregate function expression from the syntactic components. + AggregateFunctionExpression( + this.functionName, + this.arguments, { + Expression? filter, + this.distinct = false, + this.orderBy, + }) : filter = filter != null ? Where(filter) : null; @override final Precedence precedence = Precedence.primary; @@ -220,7 +245,11 @@ class _AggregateExpression extends Expression { if (distinct) { context.buffer.write('DISTINCT '); } - _writeCommaSeparated(context, parameter); + _writeCommaSeparated(context, arguments); + if (orderBy case final orderBy?) { + context.writeWhitespace(); + orderBy.writeInto(context); + } context.buffer.write(')'); if (filter != null) { @@ -233,20 +262,20 @@ class _AggregateExpression extends Expression { @override int get hashCode { return Object.hash(functionName, distinct, - const ListEquality().hash(parameter), filter); + const ListEquality().hash(arguments), orderBy, filter); } @override bool operator ==(Object other) { - if (!identical(this, other) && other.runtimeType != runtimeType) { + if (!identical(this, other) && other is! AggregateFunctionExpression) { return false; } - // ignore: test_types_in_equals - final typedOther = other as _AggregateExpression; + final typedOther = other as AggregateFunctionExpression; return typedOther.functionName == functionName && typedOther.distinct == distinct && - const ListEquality().equals(typedOther.parameter, parameter) && + const ListEquality().equals(typedOther.arguments, arguments) && + typedOther.orderBy == orderBy && typedOther.filter == filter; } } diff --git a/drift/test/extensions/json1_integration_test.dart b/drift/test/extensions/json1_integration_test.dart index 2a268218..f62a1f22 100644 --- a/drift/test/extensions/json1_integration_test.dart +++ b/drift/test/extensions/json1_integration_test.dart @@ -101,15 +101,17 @@ void main() { db.todosTable, db.todosTable.category.equalsExp(db.categories.id)) ]); - final stringArray = jsonGroupArray(db.todosTable.id); - final binaryArray = jsonbGroupArray(db.todosTable.id).json(); + final stringArray = jsonGroupArray(db.todosTable.id, + orderBy: OrderBy([OrderingTerm.desc(db.todosTable.id)])); + final binaryArray = jsonbGroupArray(db.todosTable.id, + orderBy: OrderBy([OrderingTerm.asc(db.todosTable.id)])).json(); query ..groupBy([db.categories.id]) ..addColumns([stringArray, binaryArray]); final row = await query.getSingle(); - expect(json.decode(row.read(stringArray)!), unorderedEquals([1, 3])); - expect(json.decode(row.read(binaryArray)!), unorderedEquals([1, 3])); + expect(json.decode(row.read(stringArray)!), [3, 1]); + expect(json.decode(row.read(binaryArray)!), [1, 3]); }); test('json_group_object', () async { diff --git a/drift/test/extensions/json1_test.dart b/drift/test/extensions/json1_test.dart index f6eca7d0..bee8f927 100644 --- a/drift/test/extensions/json1_test.dart +++ b/drift/test/extensions/json1_test.dart @@ -42,6 +42,17 @@ void main() { test('aggregates', () { expect(jsonGroupArray(column), generates('json_group_array(col)')); + expect( + jsonGroupArray( + column, + orderBy: OrderBy([OrderingTerm.desc(column)]), + filter: column.length.isBiggerOrEqualValue(10), + ), + generates( + 'json_group_array(col ORDER BY col DESC) FILTER (WHERE LENGTH(col) >= ?)', + [10], + ), + ); expect( jsonGroupObject({ Variable('foo'): column, @@ -84,6 +95,17 @@ void main() { test('aggregates', () { expect(jsonbGroupArray(column), generates('jsonb_group_array(col)')); + expect( + jsonbGroupArray( + column, + orderBy: OrderBy([OrderingTerm.desc(column)]), + filter: column.length.isBiggerOrEqualValue(10), + ), + generates( + 'jsonb_group_array(col ORDER BY col DESC) FILTER (WHERE LENGTH(col) >= ?)', + [10], + ), + ); expect( jsonbGroupObject({ Variable('foo'): column, diff --git a/drift/test/test_utils/matchers.dart b/drift/test/test_utils/matchers.dart index 98497716..7b72826e 100644 --- a/drift/test/test_utils/matchers.dart +++ b/drift/test/test_utils/matchers.dart @@ -107,7 +107,7 @@ class _GeneratesSqlMatcher extends Matcher { matches = false; } - final argsMatchState = {}; + final argsMatchState = {}; if (_matchVariables != null && !_matchVariables.matches(ctx.boundVariables, argsMatchState)) { matchState['vars'] = ctx.boundVariables; diff --git a/drift/test/test_utils/test_utils.mocks.dart b/drift/test/test_utils/test_utils.mocks.dart index 03d207d3..0d052c05 100644 --- a/drift/test/test_utils/test_utils.mocks.dart +++ b/drift/test/test_utils/test_utils.mocks.dart @@ -1,4 +1,4 @@ -// Mocks generated by Mockito 5.4.3 from annotations +// Mocks generated by Mockito 5.4.4 from annotations // in drift/test/test_utils/test_utils.dart. // Do not manually edit this file.