Improve resolving recursive CTEs

This commit is contained in:
Simon Binder 2023-05-30 13:51:47 +02:00
parent 0177175878
commit 9154f60dfd
No known key found for this signature in database
GPG Key ID: 7891917E4147B8C0
5 changed files with 186 additions and 119 deletions

View File

@ -69,6 +69,7 @@ enum AnalysisErrorType {
starColumnWithoutTable,
compoundColumnCountMismatch,
cteColumnCountMismatch,
circularReference,
valuesSelectCountMismatch,
viewColumnNamesMismatch,
rowValueMisuse,

View File

@ -11,17 +11,55 @@ class ColumnResolver extends RecursiveVisitor<ColumnResolverContext, void> {
@override
void visitSelectStatement(SelectStatement e, ColumnResolverContext arg) {
// visit children first so that common table expressions are resolved
visitChildren(e, arg);
e.withClause?.accept(this, arg);
_resolveSelect(e, arg);
// We've handled the from clause in _resolveSelect, but we still need to
// visit other children to handle things like subquery expressions.
for (final child in e.childNodes) {
if (child != e.withClause && child != e.from) {
visit(child, arg);
}
}
}
@override
void visitCreateIndexStatement(
CreateIndexStatement e, ColumnResolverContext arg) {
_resolveTableReference(e.on, arg);
visitExcept(e, e.on, arg);
}
@override
void visitCreateTriggerStatement(
CreateTriggerStatement e, ColumnResolverContext arg) {
final table = _resolveTableReference(e.onTable, arg);
if (table == null) {
// further analysis is not really possible without knowing the table
super.visitCreateTriggerStatement(e, arg);
return;
}
final scope = e.statementScope;
// Add columns of the target table for when and update of clauses
scope.expansionOfStarColumn = table.resolvedColumns;
if (e.target.introducesNew) {
scope.addAlias(e, table, 'new');
}
if (e.target.introducesOld) {
scope.addAlias(e, table, 'old');
}
visitChildren(e, arg);
}
@override
void visitCompoundSelectStatement(
CompoundSelectStatement e, ColumnResolverContext arg) {
// first, visit all children so that the compound parts have their columns
// resolved
visitChildren(e, arg);
e.base.accept(this, arg);
visitList(e.additional, arg);
_resolveCompoundSelect(e);
}
@ -29,29 +67,75 @@ class ColumnResolver extends RecursiveVisitor<ColumnResolverContext, void> {
@override
void visitValuesSelectStatement(
ValuesSelectStatement e, ColumnResolverContext arg) {
// visit children to resolve CTEs
visitChildren(e, arg);
e.withClause?.accept(this, arg);
_resolveValuesSelect(e);
// Still visit expressions because they could have subqueries that we need
// to handle.
visitList(e.values, arg);
}
@override
void visitCommonTableExpression(
CommonTableExpression e, ColumnResolverContext arg) {
visitChildren(
e,
const ColumnResolverContext(referencesUseNameOfReferencedColumn: false),
// If we have a compound select statement as a CTE, resolve the initial
// query first because the whole CTE will have those columns in the end.
// This allows subsequent parts of the compound select to refer to the CTE.
final query = e.as;
final contextForFirstChild = ColumnResolverContext(
referencesUseNameOfReferencedColumn: false,
inDefinitionOfCte: [
...arg.inDefinitionOfCte,
e.cteTableName.toLowerCase(),
],
);
final resolved = e.as.resolvedColumns;
final names = e.columnNames;
if (names != null && resolved != null && names.length != resolved.length) {
context.reportError(AnalysisError(
type: AnalysisErrorType.cteColumnCountMismatch,
message: 'This CTE declares ${names.length} columns, but its select '
'statement actually returns ${resolved.length}.',
relevantNode: e,
));
void applyColumns(BaseSelectStatement source) {
final resolved = source.resolvedColumns!;
final names = e.columnNames;
if (names == null) {
e.resolvedColumns = resolved;
} else {
if (names.length != resolved.length) {
context.reportError(AnalysisError(
type: AnalysisErrorType.cteColumnCountMismatch,
message:
'This CTE declares ${names.length} columns, but its select '
'statement actually returns ${resolved.length}.',
relevantNode: e.tableNameToken ?? e,
));
}
final cteColumns = names
.map((name) => CommonTableExpressionColumn(name)..containingSet = e)
.toList();
for (var i = 0; i < cteColumns.length; i++) {
if (i < resolved.length) {
final selectColumn = resolved[i];
cteColumns[i].innerColumn = selectColumn;
}
}
e.resolvedColumns = cteColumns;
}
}
if (query is CompoundSelectStatement) {
// The first nested select statement determines the columns of this CTE.
query.base.accept(this, contextForFirstChild);
applyColumns(query.base);
// Subsequent queries can refer to the CTE though.
final contextForOtherChildren = ColumnResolverContext(
referencesUseNameOfReferencedColumn: false,
inDefinitionOfCte: arg.inDefinitionOfCte,
);
visitList(query.additional, contextForOtherChildren);
_resolveCompoundSelect(query);
} else {
visitChildren(e, contextForFirstChild);
applyColumns(query);
}
}
@ -70,10 +154,9 @@ class ColumnResolver extends RecursiveVisitor<ColumnResolverContext, void> {
}
@override
void visitTableReference(TableReference e, void arg) {
if (e.resolved == null) {
_resolveTableReference(e);
}
void visitForeignKeyClause(ForeignKeyClause e, ColumnResolverContext arg) {
_resolveTableReference(e.foreignTable, arg);
visitExcept(e, e.foreignTable, arg);
}
@override
@ -100,8 +183,9 @@ class ColumnResolver extends RecursiveVisitor<ColumnResolverContext, void> {
_resolveReturningClause(e, e.table.resultSet, arg);
}
ResultSet? _addIfResolved(AstNode node, TableReference ref) {
final table = _resolveTableReference(ref);
ResultSet? _addIfResolved(
AstNode node, TableReference ref, ColumnResolverContext arg) {
final table = _resolveTableReference(ref, arg);
if (table != null) {
node.statementScope.expansionOfStarColumn = table.resolvedColumns;
}
@ -114,7 +198,7 @@ class ColumnResolver extends RecursiveVisitor<ColumnResolverContext, void> {
// Resolve CTEs first
e.withClause?.accept(this, arg);
final into = _addIfResolved(e, e.table);
final into = _addIfResolved(e, e.table, arg);
for (final child in e.childNodes) {
if (child != e.withClause) visit(child, arg);
}
@ -126,7 +210,7 @@ class ColumnResolver extends RecursiveVisitor<ColumnResolverContext, void> {
// Resolve CTEs first
e.withClause?.accept(this, arg);
final from = _addIfResolved(e, e.from);
final from = _addIfResolved(e, e.from, arg);
for (final child in e.childNodes) {
if (child != e.withClause) visit(child, arg);
}
@ -168,31 +252,10 @@ class ColumnResolver extends RecursiveVisitor<ColumnResolverContext, void> {
stmt.returnedResultSet = CustomResultSet(columns);
}
@override
void visitCreateTriggerStatement(
CreateTriggerStatement e, ColumnResolverContext arg) {
final table = _resolveTableReference(e.onTable);
if (table == null) {
// further analysis is not really possible without knowing the table
super.visitCreateTriggerStatement(e, arg);
return;
}
final scope = e.statementScope;
// Add columns of the target table for when and update of clauses
scope.expansionOfStarColumn = table.resolvedColumns;
if (e.target.introducesNew) {
scope.addAlias(e, table, 'new');
}
if (e.target.introducesOld) {
scope.addAlias(e, table, 'old');
}
visitChildren(e, arg);
}
/// Visits a [queryable] appearing in a `FROM` clause under the state [state].
///
/// This also adds columns contributed to the resolved source to
/// [availableColumns], which is later used to expand `*` parameters.
void _handle(Queryable queryable, List<Column> availableColumns,
ColumnResolverContext state) {
void addColumns(Iterable<Column> columns) {
@ -211,39 +274,32 @@ class ColumnResolver extends RecursiveVisitor<ColumnResolverContext, void> {
queryable.when(
isTable: (table) {
final resolved = _resolveTableReference(table);
final resolved = _resolveTableReference(table, state);
if (resolved != null) {
// an error will be logged when resolved is null, so the != null check
// is fine and avoids crashes
addColumns(table.resultSet!.resolvedColumns!);
}
},
isSelect: (select) {
// Inside subqueries, references don't take the name of the referenced
// column.
final childState =
ColumnResolverContext(referencesUseNameOfReferencedColumn: false);
// the inner select statement doesn't have access to columns defined in
// the outer statements, which is why we use _resolveSelect instead of
// passing availableColumns down to a recursive call of _handle
final childState = ColumnResolverContext(
referencesUseNameOfReferencedColumn: false,
inDefinitionOfCte: state.inDefinitionOfCte,
);
final stmt = select.statement;
if (stmt is CompoundSelectStatement) {
_resolveCompoundSelect(stmt);
} else if (stmt is SelectStatement) {
_resolveSelect(stmt, childState);
} else if (stmt is ValuesSelectStatement) {
_resolveValuesSelect(stmt);
} else {
throw AssertionError('Unknown type of select statement: $stmt');
}
visit(stmt, childState);
addColumns(stmt.resolvedColumns!);
},
isJoin: (join) {
_handle(join.primary, availableColumns, state);
for (final query in join.joins.map((j) => j.query)) {
_handle(query, availableColumns, state);
isJoin: (joinClause) {
_handle(joinClause.primary, availableColumns, state);
for (final join in joinClause.joins) {
_handle(join.query, availableColumns, state);
final constraint = join.constraint;
if (constraint is OnConstraint) {
visit(constraint.expression, state);
}
}
},
isTableFunction: (function) {
@ -458,7 +514,18 @@ class ColumnResolver extends RecursiveVisitor<ColumnResolverContext, void> {
return span;
}
ResultSet? _resolveTableReference(TableReference r) {
ResultSet? _resolveTableReference(
TableReference r, ColumnResolverContext state) {
// Check for circular references
if (state.inDefinitionOfCte.contains(r.tableName.toLowerCase())) {
context.reportError(AnalysisError(
type: AnalysisErrorType.circularReference,
relevantNode: r,
message: 'Circular reference to its own CTE',
));
return null;
}
final scope = r.scope;
// Try resolving to a top-level table in the schema and to a result set that
@ -528,6 +595,13 @@ class ColumnResolverContext {
/// column in subqueries or CTEs.
final bool referencesUseNameOfReferencedColumn;
const ColumnResolverContext(
{this.referencesUseNameOfReferencedColumn = true});
/// The common table expressions that are currently being defined.
///
/// This is used to detect forbidden circular references.
final List<String> inDefinitionOfCte;
const ColumnResolverContext({
this.referencesUseNameOfReferencedColumn = true,
this.inDefinitionOfCte = const [],
});
}

View File

@ -49,7 +49,8 @@ class CommonTableExpression extends AstNode with ResultSet {
Token? asToken;
IdentifierToken? tableNameToken;
List<CommonTableExpressionColumn>? _cachedColumns;
@override
List<Column>? resolvedColumns;
CommonTableExpression({
required this.cteTableName,
@ -71,33 +72,6 @@ class CommonTableExpression extends AstNode with ResultSet {
@override
Iterable<AstNode> get childNodes => [as];
@override
List<Column>? get resolvedColumns {
final columnsOfSelect = as.resolvedColumns;
// we don't override column names, so just return the columns declared by
// the select statement
if (columnNames == null) return columnsOfSelect;
final cached = _cachedColumns ??= columnNames!
.map((name) => CommonTableExpressionColumn(name)..containingSet = this)
.toList();
if (columnsOfSelect != null) {
// bind the CommonTableExpressionColumn to the real underlying column
// returned by the select statement
for (var i = 0; i < cached.length; i++) {
if (i < columnsOfSelect.length) {
final selectColumn = columnsOfSelect[i];
cached[i].innerColumn = selectColumn;
}
}
}
return _cachedColumns;
}
@override
bool get visibleToChildren => true;
}

View File

@ -270,22 +270,18 @@ class SqlEngine {
final node = context.root;
node.scope = context.rootScope;
try {
AstPreparingVisitor(context: context).start(node);
AstPreparingVisitor(context: context).start(node);
node
..accept(ColumnResolver(context), const ColumnResolverContext())
..accept(ReferenceResolver(context), const ReferenceResolvingContext());
node
..accept(ColumnResolver(context), const ColumnResolverContext())
..accept(ReferenceResolver(context), const ReferenceResolvingContext());
final session = TypeInferenceSession(context, options);
final resolver = TypeResolver(session);
resolver.run(node);
context.types2 = session.results!;
final session = TypeInferenceSession(context, options);
final resolver = TypeResolver(session);
resolver.run(node);
context.types2 = session.results!;
node.acceptWithoutArg(LintingVisitor(options, context));
} catch (_) {
rethrow;
}
node.acceptWithoutArg(LintingVisitor(options, context));
}
}

View File

@ -100,6 +100,21 @@ END;
});
});
test('resolves index', () {
final context = engine.analyze('CREATE INDEX foo ON demo (content)');
context.expectNoError();
final tableReference =
context.root.allDescendants.whereType<TableReference>().first;
final columnReference = context.root.allDescendants
.whereType<IndexedColumn>()
.first
.expression as Reference;
expect(tableReference.resolved, demoTable);
expect(columnReference.resolvedColumn, isA<AvailableColumn>());
});
test("DO UPDATE action in upsert can refer to 'exluded'", () {
final context = engine.analyze('''
INSERT INTO demo VALUES (?, ?)
@ -270,4 +285,11 @@ INSERT INTO demo VALUES (?, ?)
.root as SelectStatement;
expect(cte.resolvedColumns?.map((e) => e.name), ['RoWiD']);
});
test('reports error for circular reference', () {
final query = engine.analyze('WITH x AS (SELECT * FROM x) SELECT 1;');
expect(query.errors, [
analysisErrorWith(lexeme: 'x', type: AnalysisErrorType.circularReference),
]);
});
}