diff --git a/sqlparser/lib/src/analysis/context.dart b/sqlparser/lib/src/analysis/context.dart index a6fbaf19..6ff7dcb1 100644 --- a/sqlparser/lib/src/analysis/context.dart +++ b/sqlparser/lib/src/analysis/context.dart @@ -12,15 +12,5 @@ class AnalysisContext { errors.add(error); } - ResolveResult typeOf(Typeable t) { - if (t is Column) { - return types.resolveColumn(t); - } else if (t is Variable) { - return types.inferType(t); - } else if (t is Expression) { - return types.resolveExpression(t); - } - - throw StateError('Unknown typeable $t'); - } + ResolveResult typeOf(Typeable t) => types.resolveOrInfer(t); } diff --git a/sqlparser/lib/src/analysis/types/resolver.dart b/sqlparser/lib/src/analysis/types/resolver.dart index 873de405..eccefdfc 100644 --- a/sqlparser/lib/src/analysis/types/resolver.dart +++ b/sqlparser/lib/src/analysis/types/resolver.dart @@ -29,6 +29,28 @@ class TypeResolver { return calculated; } + ResolveResult resolveOrInfer(Typeable t) { + if (t is Column) { + return resolveColumn(t); + } else if (t is Variable) { + return inferType(t); + } else if (t is Expression) { + return resolveExpression(t); + } + + throw StateError('Unknown typeable $t'); + } + + ResolveResult justResolve(Typeable t) { + if (t is Column) { + return resolveColumn(t); + } else if (t is Expression) { + return resolveExpression(t); + } + + throw StateError('Unknown typeable $t'); + } + ResolveResult resolveColumn(Column column) { return _cache((column) { if (column is TableColumn) { @@ -97,8 +119,109 @@ class TypeResolver { } ResolveResult resolveFunctionCall(FunctionExpression call) { - // todo - return const ResolveResult.unknown(); + return _cache((FunctionExpression call) { + List parameters; + final sqlParameters = call.parameters; + if (sqlParameters is ExprFunctionParameters) { + parameters = sqlParameters.parameters; + } else if (sqlParameters is StarFunctionParameter) { + parameters = call.scope.availableColumns; + } + + final firstNullable = justResolve(parameters.first).nullable; + final anyNullable = parameters.map(justResolve).any((r) => r.nullable); + + switch (call.name.toLowerCase()) { + case 'round': + // if there is only one param, returns an int. otherwise real + if (parameters.length == 1) { + return ResolveResult( + ResolvedType(type: BasicType.int, nullable: firstNullable)); + } else { + return ResolveResult( + ResolvedType(type: BasicType.real, nullable: anyNullable)); + } + break; + case 'sum': + final firstType = justResolve(parameters.first); + if (firstType.type?.type == BasicType.int) { + return firstType; + } else { + return ResolveResult(ResolvedType( + type: BasicType.real, nullable: firstType.nullable)); + } + break; // can't happen, though + case 'lower': + case 'ltrim': + case 'printf': + case 'replace': + case 'rtrim': + case 'substr': + case 'trim': + case 'upper': + case 'group_concat': + return ResolveResult( + ResolvedType(type: BasicType.text, nullable: firstNullable)); + case 'date': + case 'time': + case 'datetime': + case 'julianday': + case 'strftime': + case 'char': + case 'hex': + case 'quote': + case 'soundex': + case 'sqlite_compileoption_set': + case 'sqlite_version': + case 'typeof': + return const ResolveResult(ResolvedType(type: BasicType.text)); + case 'changes': + case 'last_insert_rowid': + case 'random': + case 'sqlite_compileoption_used': + case 'total_changes': + case 'count': + return const ResolveResult(ResolvedType(type: BasicType.int)); + case 'instr': + case 'length': + case 'unicode': + return ResolveResult( + ResolvedType(type: BasicType.int, nullable: anyNullable)); + case 'randomblob': + case 'zeroblob': + return const ResolveResult(ResolvedType(type: BasicType.blob)); + case 'total': + case 'avg': + return const ResolveResult(ResolvedType(type: BasicType.real)); + case 'abs': + case 'likelihood': + case 'likely': + case 'unlikely': + return justResolve(parameters.first); + case 'coalesce': + case 'ifnull': + return ResolveResult(_encapsulate(parameters, + [BasicType.int, BasicType.real, BasicType.text, BasicType.blob])); + case 'nullif': + return justResolve(parameters.first).withNullable(true); + case 'max': + return ResolveResult(_encapsulate(parameters, [ + BasicType.int, + BasicType.real, + BasicType.text, + BasicType.blob + ])).withNullable(true); + case 'min': + return ResolveResult(_encapsulate(parameters, [ + BasicType.blob, + BasicType.text, + BasicType.int, + BasicType.real + ])).withNullable(true); + } + + throw StateError('Unknown function: ${call.name}'); + }, call); } ResolveResult inferType(Expression e) { @@ -126,6 +249,8 @@ class TypeResolver { return resolveExpression(relevant as Expression); } else if (parent is Parentheses || parent is UnaryExpression) { return const ResolveResult.needsContext(); + } else if (parent is FunctionExpression) { + return resolveFunctionCall(parent); } throw StateError('Cannot infer argument type: $parent'); @@ -134,9 +259,9 @@ class TypeResolver { /// Returns the type of an expression in [expressions] that has the highest /// order in [types]. ResolvedType _encapsulate( - Iterable expressions, List types) { + Iterable expressions, List types) { final argTypes = expressions - .map((expr) => resolveExpression(expr).type) + .map((expr) => justResolve(expr).type) .where((t) => t != null); final type = types.lastWhere((t) => argTypes.any((arg) => arg.type == t)); final notNull = argTypes.any((t) => !t.nullable); @@ -162,6 +287,18 @@ class ResolveResult { needsContext = false, unknown = true; + bool get nullable => type?.nullable ?? true; + + ResolveResult withNullable(bool nullable) { + if (type != null) { + return ResolveResult(type.withNullable(nullable)); + } else if (needsContext != null) { + return const ResolveResult.needsContext(); + } else { + return const ResolveResult.unknown(); + } + } + @override bool operator ==(other) { return identical(this, other) || diff --git a/sqlparser/lib/src/ast/statements/select.dart b/sqlparser/lib/src/ast/statements/select.dart index 5854076f..d6e89bb4 100644 --- a/sqlparser/lib/src/ast/statements/select.dart +++ b/sqlparser/lib/src/ast/statements/select.dart @@ -36,6 +36,7 @@ class SelectStatement extends Statement with ResultSet { if (where != null) where, ...columns, if (from != null) ...from, + if (groupBy != null) groupBy, if (limit != null) limit, if (orderBy != null) orderBy, ]; diff --git a/sqlparser/test/analysis/type_resolver_test.dart b/sqlparser/test/analysis/type_resolver_test.dart index eb0f2fb5..9916ba40 100644 --- a/sqlparser/test/analysis/type_resolver_test.dart +++ b/sqlparser/test/analysis/type_resolver_test.dart @@ -10,6 +10,8 @@ Map _types = { const ResolveResult(ResolvedType(type: BasicType.text)), 'SELECT * FROM demo LIMIT ?': const ResolveResult(ResolvedType(type: BasicType.int)), + 'SELECT 1 FROM demo GROUP BY id HAVING COUNT(*) = ?': + const ResolveResult(ResolvedType(type: BasicType.int)), }; void main() {