Make aggregate expression builder public

This commit is contained in:
Simon Binder 2024-04-22 22:17:21 +02:00
parent 428de9354a
commit a68ec2a73e
No known key found for this signature in database
GPG Key ID: 7891917E4147B8C0
7 changed files with 102 additions and 32 deletions

View File

@ -1,3 +1,10 @@
## 2.18.0-dev
- Add `AggregateFunctionExpression` to write custom [aggregate function](https://www.sqlite.org/lang_aggfunc.html)
invocations in the Dart query builder.
- The `json_group_array` and `jsonb_group_array` functions now contain an `orderBy`
and `filter` parameter.
## 2.17.0
- Adds `companion` entry to `DataClassName` to override the name of the

View File

@ -144,8 +144,13 @@ extension JsonExtensions on Expression<String> {
/// all emails in that folder.
/// This string could be turned back into a list with
/// `(json.decode(row.read(subjects)!) as List).cast<String>()`.
Expression<String> jsonGroupArray(Expression value) {
return FunctionCallExpression('json_group_array', [value]);
Expression<String> jsonGroupArray(
Expression value, {
OrderBy? orderBy,
Expression<bool>? filter,
}) {
return AggregateFunctionExpression('json_group_array', [value],
orderBy: orderBy, filter: filter);
}
/// Returns a binary representation of a JSON array containing the result of
@ -153,8 +158,13 @@ Expression<String> jsonGroupArray(Expression value) {
///
/// See [jsonGroupArray], the variant of this function returning a textual
/// description, for more details and an example.
Expression<Uint8List> jsonbGroupArray(Expression value) {
return FunctionCallExpression('jsonb_group_array', [value]);
Expression<Uint8List> jsonbGroupArray(
Expression value, {
OrderBy? orderBy,
Expression<bool>? filter,
}) {
return AggregateFunctionExpression('jsonb_group_array', [value],
orderBy: orderBy, filter: filter);
}
List<Expression> _groupObjectArgs(Map<Expression<String>, Expression> values) {

View File

@ -12,7 +12,7 @@ part of '../query_builder.dart';
/// This is equivalent to the `COUNT(*) FILTER (WHERE filter)` sql function. The
/// filter will be omitted if null.
Expression<int> countAll({Expression<bool>? filter}) {
return _AggregateExpression('COUNT', const [_StarFunctionParameter()],
return AggregateFunctionExpression('COUNT', const [_StarFunctionParameter()],
filter: filter);
}
@ -26,7 +26,7 @@ extension BaseAggregate<DT extends Object> on Expression<DT> {
/// counted twice.
/// {@macro drift_aggregate_filter}
Expression<int> count({bool distinct = false, Expression<bool>? filter}) {
return _AggregateExpression('COUNT', [this],
return AggregateFunctionExpression('COUNT', [this],
filter: filter, distinct: distinct);
}
@ -35,14 +35,14 @@ extension BaseAggregate<DT extends Object> on Expression<DT> {
/// If there are no non-null values in the group, returns null.
/// {@macro drift_aggregate_filter}
Expression<DT> max({Expression<bool>? filter}) =>
_AggregateExpression('MAX', [this], filter: filter);
AggregateFunctionExpression('MAX', [this], filter: filter);
/// Return the minimum of all non-null values in this group.
///
/// If there are no non-null values in the group, returns null.
/// {@macro drift_aggregate_filter}
Expression<DT> min({Expression<bool>? filter}) =>
_AggregateExpression('MIN', [this], filter: filter);
AggregateFunctionExpression('MIN', [this], filter: filter);
/// Returns the concatenation of all non-null values in the current group,
/// joined by the [separator].
@ -71,7 +71,7 @@ extension BaseAggregate<DT extends Object> on Expression<DT> {
'Cannot use groupConcat with distinct: true and a custom separator');
}
return _AggregateExpression(
return AggregateFunctionExpression(
'GROUP_CONCAT',
[
this,
@ -89,21 +89,21 @@ extension ArithmeticAggregates<DT extends num> on Expression<DT> {
///
/// {@macro drift_aggregate_filter}
Expression<double> avg({Expression<bool>? filter}) =>
_AggregateExpression('AVG', [this], filter: filter);
AggregateFunctionExpression('AVG', [this], filter: filter);
/// Return the maximum of all non-null values in this group.
///
/// If there are no non-null values in the group, returns null.
/// {@macro drift_aggregate_filter}
Expression<DT> max({Expression<bool>? filter}) =>
_AggregateExpression('MAX', [this], filter: filter);
AggregateFunctionExpression('MAX', [this], filter: filter);
/// Return the minimum of all non-null values in this group.
///
/// If there are no non-null values in the group, returns null.
/// {@macro drift_aggregate_filter}
Expression<DT> min({Expression<bool>? filter}) =>
_AggregateExpression('MIN', [this], filter: filter);
AggregateFunctionExpression('MIN', [this], filter: filter);
/// Calculate the sum of all non-null values in the group.
///
@ -115,7 +115,7 @@ extension ArithmeticAggregates<DT extends num> on Expression<DT> {
/// value and doesn't throw an overflow exception.
/// {@macro drift_aggregate_filter}
Expression<DT> sum({Expression<bool>? filter}) =>
_AggregateExpression('SUM', [this], filter: filter);
AggregateFunctionExpression('SUM', [this], filter: filter);
/// Calculate the sum of all non-null values in the group.
///
@ -123,7 +123,7 @@ extension ArithmeticAggregates<DT extends num> on Expression<DT> {
/// uses floating-point values internally.
/// {@macro drift_aggregate_filter}
Expression<double> total({Expression<bool>? filter}) =>
_AggregateExpression('TOTAL', [this], filter: filter);
AggregateFunctionExpression('TOTAL', [this], filter: filter);
}
/// Provides aggregate functions that are available for BigInt expressions.
@ -197,16 +197,41 @@ extension DateTimeAggregate on Expression<DateTime> {
}
}
class _AggregateExpression<D extends Object> extends Expression<D> {
/// An expression invoking an [aggregate function](https://www.sqlite.org/lang_aggfunc.html).
///
/// Aggregate functions, like `count()` or `sum()` collapse the entire data set
/// (or a partition of it, if `GROUP BY` is used) into a single value.
///
/// Drift exposes direct bindings to most aggregate functions (e.g. via
/// [BaseAggregate.count]). This class is useful when writing custom aggregate
/// function invocations.
final class AggregateFunctionExpression<D extends Object>
extends Expression<D> {
/// The name of the aggregate function to invoke.
final String functionName;
final bool distinct;
final List<FunctionParameter> parameter;
/// Whether only distinct rows should be passed to the function.
final bool distinct;
/// The arguments to pass to the function.
final List<FunctionParameter> arguments;
/// The order in which rows of the current group should be passed to the
/// aggregate function.
final OrderBy? orderBy;
/// An optional filter clause only passing rows matching this condition into
/// the function.
final Where? filter;
_AggregateExpression(this.functionName, this.parameter,
{Expression<bool>? filter, this.distinct = false})
: filter = filter != null ? Where(filter) : null;
/// Creates an aggregate function expression from the syntactic components.
AggregateFunctionExpression(
this.functionName,
this.arguments, {
Expression<bool>? filter,
this.distinct = false,
this.orderBy,
}) : filter = filter != null ? Where(filter) : null;
@override
final Precedence precedence = Precedence.primary;
@ -220,7 +245,11 @@ class _AggregateExpression<D extends Object> extends Expression<D> {
if (distinct) {
context.buffer.write('DISTINCT ');
}
_writeCommaSeparated(context, parameter);
_writeCommaSeparated(context, arguments);
if (orderBy case final orderBy?) {
context.writeWhitespace();
orderBy.writeInto(context);
}
context.buffer.write(')');
if (filter != null) {
@ -233,20 +262,20 @@ class _AggregateExpression<D extends Object> extends Expression<D> {
@override
int get hashCode {
return Object.hash(functionName, distinct,
const ListEquality<Object?>().hash(parameter), filter);
const ListEquality<Object?>().hash(arguments), orderBy, filter);
}
@override
bool operator ==(Object other) {
if (!identical(this, other) && other.runtimeType != runtimeType) {
if (!identical(this, other) && other is! AggregateFunctionExpression<D>) {
return false;
}
// ignore: test_types_in_equals
final typedOther = other as _AggregateExpression;
final typedOther = other as AggregateFunctionExpression<D>;
return typedOther.functionName == functionName &&
typedOther.distinct == distinct &&
const ListEquality<Object?>().equals(typedOther.parameter, parameter) &&
const ListEquality<Object?>().equals(typedOther.arguments, arguments) &&
typedOther.orderBy == orderBy &&
typedOther.filter == filter;
}
}

View File

@ -101,15 +101,17 @@ void main() {
db.todosTable, db.todosTable.category.equalsExp(db.categories.id))
]);
final stringArray = jsonGroupArray(db.todosTable.id);
final binaryArray = jsonbGroupArray(db.todosTable.id).json();
final stringArray = jsonGroupArray(db.todosTable.id,
orderBy: OrderBy([OrderingTerm.desc(db.todosTable.id)]));
final binaryArray = jsonbGroupArray(db.todosTable.id,
orderBy: OrderBy([OrderingTerm.asc(db.todosTable.id)])).json();
query
..groupBy([db.categories.id])
..addColumns([stringArray, binaryArray]);
final row = await query.getSingle();
expect(json.decode(row.read(stringArray)!), unorderedEquals([1, 3]));
expect(json.decode(row.read(binaryArray)!), unorderedEquals([1, 3]));
expect(json.decode(row.read(stringArray)!), [3, 1]);
expect(json.decode(row.read(binaryArray)!), [1, 3]);
});
test('json_group_object', () async {

View File

@ -42,6 +42,17 @@ void main() {
test('aggregates', () {
expect(jsonGroupArray(column), generates('json_group_array(col)'));
expect(
jsonGroupArray(
column,
orderBy: OrderBy([OrderingTerm.desc(column)]),
filter: column.length.isBiggerOrEqualValue(10),
),
generates(
'json_group_array(col ORDER BY col DESC) FILTER (WHERE LENGTH(col) >= ?)',
[10],
),
);
expect(
jsonGroupObject({
Variable('foo'): column,
@ -84,6 +95,17 @@ void main() {
test('aggregates', () {
expect(jsonbGroupArray(column), generates('jsonb_group_array(col)'));
expect(
jsonbGroupArray(
column,
orderBy: OrderBy([OrderingTerm.desc(column)]),
filter: column.length.isBiggerOrEqualValue(10),
),
generates(
'jsonb_group_array(col ORDER BY col DESC) FILTER (WHERE LENGTH(col) >= ?)',
[10],
),
);
expect(
jsonbGroupObject({
Variable('foo'): column,

View File

@ -107,7 +107,7 @@ class _GeneratesSqlMatcher extends Matcher {
matches = false;
}
final argsMatchState = <String, Object?>{};
final argsMatchState = <Object?, Object?>{};
if (_matchVariables != null &&
!_matchVariables.matches(ctx.boundVariables, argsMatchState)) {
matchState['vars'] = ctx.boundVariables;

View File

@ -1,4 +1,4 @@
// Mocks generated by Mockito 5.4.3 from annotations
// Mocks generated by Mockito 5.4.4 from annotations
// in drift/test/test_utils/test_utils.dart.
// Do not manually edit this file.