Skip to content

Commit

Permalink
Fix column references in versioned schema code (#3259)
Browse files Browse the repository at this point in the history
When a column is defined like this:

```dart
  DateTimeColumn get creationTime => dateTime()
    .check(creationTime.isBiggerThan(Constant(DateTime(2020))))
    .withDefault(Constant(DateTime(2024, 1, 1)))();
```

Our generated code relies on the fact that `creationTime` as a getter is in scope for the copied `check` code (it is because we're generating columns in table classes). With versioned schemas however, we have an optimization that tries to not re-generate column code if it hasn't changed between different schema versions. With this generation mode, columns are no longer in scope for check constraints.

This fix relies on detecting columns in Dart code (so we see that the `creationTime` reference in `check` references a column) and then rewriting these expressions with a `CustomExpression` when generating code:

```dart
i1.GeneratedColumn<DateTime> _column_18(String aliasedName) =>
    i1.GeneratedColumn<DateTime>('birthday', aliasedName, true,
        check: () =>
            i2.ComparableExpr((i0.VersionedTable.col<DateTime>('birthday')))
                .isBiggerThan(i2.Constant(DateTime(1900))),
        type: i1.DriftSqlType.dateTime);
```

Closes #3219
  • Loading branch information
simolus3 authored Oct 3, 2024
1 parent 7665813 commit ab93107
Show file tree
Hide file tree
Showing 30 changed files with 1,372 additions and 53 deletions.
9 changes: 9 additions & 0 deletions drift/lib/internal/versioned_schema.dart
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,15 @@ class VersionedTable extends Table with TableInfo<Table, QueryRow> {
VersionedTable createAlias(String alias) {
return VersionedTable.aliased(source: this, alias: alias);
}

/// Generates an expression referencing a column in the same table with the
/// given [name].
///
/// Intended for generated code.
static Expression<T> col<T extends Object>(String name) {
return CustomExpression(SqlDialect.sqlite.escape(name),
precedence: Precedence.primary);
}
}

/// The version of [VersionedTable] for virtual tables.
Expand Down
25 changes: 22 additions & 3 deletions drift_dev/lib/src/analysis/resolver/dart/column.dart
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,25 @@ const String _errorMessage = 'This getter does not create a valid column that '
class ColumnParser {
final DartTableResolver _resolver;

ColumnParser(this._resolver);
/// A map of elements to their name for elements defining columns.
///
/// This is used to recognize column references in arbitrary Dart code, e.g.
/// in this definition:
///
/// ```
/// DateTimeColumn get creationTime => dateTime()
/// .check(creationTime.isBiggerThan(Constant(DateTime(2020))))();
/// ```
///
/// Here, the check constraint references the column itself. In some code
/// generation modes where we generate code for individual columns (instead
/// of for entire table structures, this mainly includes step-by-step
/// migrations), there might not be a `creationTime` in scope for the check
/// constraint. So, we annotate these references in [AnnotatedDartCode] and
/// use that information when generating code to transform the code.
final Map<Element, String> _columnsInSameTable;

ColumnParser(this._resolver, this._columnsInSameTable);

Future<PendingColumnInformation?> parse(
ColumnDeclaration columnDeclaration, Element element) async {
Expand Down Expand Up @@ -343,8 +361,9 @@ class ColumnParser {
break;
case _methodCheck:
final expr = remainingExpr.argumentList.arguments.first;
foundConstraints
.add(DartCheckExpression(AnnotatedDartCode.ast(expr)));

foundConstraints.add(DartCheckExpression(AnnotatedDartCode.build(
(b) => b.addAstNode(expr, taggedElements: _columnsInSameTable))));
}

// We're not at a starting method yet, so we need to go deeper!
Expand Down
14 changes: 8 additions & 6 deletions drift_dev/lib/src/analysis/resolver/dart/table.dart
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,9 @@ class DartTableResolver extends LocalElementResolver<DiscoveredDartTable> {
element.lookUpInheritedConcreteGetter(name, element.library);
// ignore: deprecated_member_use
return getter!.variable;
});
}).toList();
final all = {for (final entry in fields) entry.getter ?? entry: entry.name};

final results = <PendingColumnInformation>[];
for (final field in fields) {
final ColumnDeclaration node;
Expand All @@ -317,14 +319,14 @@ class DartTableResolver extends LocalElementResolver<DiscoveredDartTable> {
.loadElementDeclaration(field.declaration)
as VariableDeclaration,
null);
column = await _parseColumn(node, field.declaration);
column = await _parseColumn(node, field.declaration, all);
} else {
node = ColumnDeclaration(
null,
await resolver.driver.backend.loadElementDeclaration(field.getter!)
as MethodDeclaration);

column = await _parseColumn(node, field.getter!);
column = await _parseColumn(node, field.getter!, all);
}

if (column != null) {
Expand All @@ -335,9 +337,9 @@ class DartTableResolver extends LocalElementResolver<DiscoveredDartTable> {
return results.whereType();
}

Future<PendingColumnInformation?> _parseColumn(
ColumnDeclaration declaration, Element element) async {
return ColumnParser(this).parse(declaration, element);
Future<PendingColumnInformation?> _parseColumn(ColumnDeclaration declaration,
Element element, Map<Element, String> allColumns) async {
return ColumnParser(this, allColumns).parse(declaration, element);
}

Future<List<String>> _readCustomConstraints(Set<DriftElement> references,
Expand Down
4 changes: 2 additions & 2 deletions drift_dev/lib/src/analysis/resolver/shared/data_class.dart
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ CustomParentClass? parseCustomParentClass(
if (genericType.isDartCoreObject || genericType is DynamicType) {
code = AnnotatedDartCode([
DartTopLevelSymbol.topLevelElement(extendingType.element),
'<',
const DartLexeme('<'),
DartTopLevelSymbol(
dartTypeName ?? dataClassNameForClassName(element.name),
null,
),
'>',
const DartLexeme('>'),
]);
} else {
resolver.reportError(
Expand Down
96 changes: 80 additions & 16 deletions drift_dev/lib/src/analysis/results/dart.dart
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@ class AnnotatedDartCode {
static final Uri dartCore = Uri.parse('dart:core');
static final Uri drift = Uri.parse('package:drift/drift.dart');

final List<dynamic /* String|DartTopLevelSymbol */ > elements;
final List<DartCodeElement> elements;

AnnotatedDartCode(this.elements)
: assert(elements.every((e) => e is String || e is DartTopLevelSymbol));
AnnotatedDartCode(this.elements);

AnnotatedDartCode.text(String e) : elements = [e];
AnnotatedDartCode.text(String e) : elements = [DartLexeme(e)];

factory AnnotatedDartCode.ast(AstNode node) {
return AnnotatedDartCode.build(((builder) => builder.addAstNode(node)));
Expand All @@ -55,8 +54,7 @@ class AnnotatedDartCode {
final serializedElements = json['elements'] as List;

return AnnotatedDartCode([
for (final part in serializedElements)
if (part is Map) DartTopLevelSymbol.fromJson(part) else part as String
for (final part in serializedElements) DartCodeElement.fromJson(part)
]);
}

Expand All @@ -66,10 +64,7 @@ class AnnotatedDartCode {

Map<String, Object?> toJson() {
return {
'elements': [
for (final element in elements)
if (element is DartTopLevelSymbol) element.toJson() else element
],
'elements': [for (final element in elements) element.toJson()],
};
}

Expand All @@ -90,12 +85,12 @@ class AnnotatedDartCode {
}

class AnnotatedDartCodeBuilder {
final List<dynamic> _elements = [];
final List<DartCodeElement> _elements = [];
final StringBuffer _pendingText = StringBuffer();

void _addPendingText() {
if (_pendingText.isNotEmpty) {
_elements.add(_pendingText.toString());
_elements.add(DartLexeme(_pendingText.toString()));
_pendingText.clear();
}
}
Expand All @@ -122,12 +117,21 @@ class AnnotatedDartCodeBuilder {
_elements.add(DartTopLevelSymbol.topLevelElement(element));
}

void addTagged(String lexeme, String tag) {
_addPendingText();
_elements.add(TaggedDartLexeme(lexeme, tag));
}

void addDartType(DartType type) {
type.accept(_AddFromDartType(this));
}

void addAstNode(AstNode node, {Set<AstNode> exclude = const {}}) {
final visitor = _AddFromAst(this, exclude);
void addAstNode(
AstNode node, {
Set<AstNode> exclude = const {},
Map<Element, String> taggedElements = const {},
}) {
final visitor = _AddFromAst(this, exclude, taggedElements);
node.accept(visitor);
}

Expand Down Expand Up @@ -224,8 +228,64 @@ class AnnotatedDartCodeBuilder {
}
}

sealed class DartCodeElement {
Object? toJson();

factory DartCodeElement.fromJson(Object? json) {
return switch (json) {
String s => DartLexeme(s),
{'import_uri': _} => DartTopLevelSymbol.fromJson(json),
{'tag': _} => TaggedDartLexeme.fromJson(json),
_ => throw ArgumentError.value(json, 'json', 'Unknown code element'),
};
}
}

final class DartLexeme implements DartCodeElement {
final String lexeme;

const DartLexeme(this.lexeme);

@override
Object? toJson() {
return lexeme;
}

@override
String toString() {
return lexeme;
}
}

/// A variant of [DartLexeme] with a custom associated [tag].
///
/// For a motivation, see `ColumnParser._columnsInSameTable` - essentially, some
/// drift tools need to resolve column references in Dart code to rewrite them
/// depending on the generation mode.
@JsonSerializable()
final class TaggedDartLexeme implements DartCodeElement {
final String lexeme;
final String tag;

TaggedDartLexeme(this.lexeme, this.tag);

factory TaggedDartLexeme.fromJson(Map json) =>
_$TaggedDartLexemeFromJson(json);

@override
Map<String, Object?> toJson() => _$TaggedDartLexemeToJson(this);

@override
String toString() {
return lexeme;
}
}

/// A variant of [DartLexeme] that is used for top-level elements to also store
/// the import URI. This allows drift's code generator, when encountering such
/// element, to automatically add the relevant import to generated Dart files.
@JsonSerializable()
class DartTopLevelSymbol {
final class DartTopLevelSymbol implements DartCodeElement {
static final _driftUri = Uri.parse('package:drift/drift.dart');

static final list = DartTopLevelSymbol('List', AnnotatedDartCode.dartCore);
Expand Down Expand Up @@ -259,6 +319,7 @@ class DartTopLevelSymbol {
factory DartTopLevelSymbol.fromJson(Map json) =>
_$DartTopLevelSymbolFromJson(json);

@override
Map<String, Object?> toJson() => _$DartTopLevelSymbolToJson(this);
}

Expand Down Expand Up @@ -453,8 +514,9 @@ class _AddFromDartType extends UnifyingTypeVisitor<void> {
class _AddFromAst extends GeneralizingAstVisitor<void> {
final AnnotatedDartCodeBuilder _builder;
final Set<AstNode> _excluding;
final Map<Element, String> _taggedElements;

_AddFromAst(this._builder, this._excluding);
_AddFromAst(this._builder, this._excluding, this._taggedElements);

void _addTopLevelReference(Element? element, Token name2) {
if (element == null || (element.isSynthetic && element.library == null)) {
Expand Down Expand Up @@ -575,6 +637,8 @@ class _AddFromAst extends GeneralizingAstVisitor<void> {

if (isTopLevel) {
_builder.addTopLevelElement(target!);
} else if (_taggedElements[target] case final tag?) {
_builder.addTagged(node.token.lexeme, tag);
} else {
_builder.addText(node.name);
}
Expand Down
11 changes: 11 additions & 0 deletions drift_dev/lib/src/generated/analysis/results/dart.g.dart

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions drift_dev/lib/src/services/schema/schema_files.dart
Original file line number Diff line number Diff line change
Expand Up @@ -495,8 +495,9 @@ class SchemaReader {
nullable: nullable,
nameInSql: name,
nameInDart: getterName ?? ReCase(name).camelCase,
defaultArgument:
defaultDart != null ? AnnotatedDartCode([defaultDart]) : null,
defaultArgument: defaultDart != null
? AnnotatedDartCode([DartLexeme(defaultDart)])
: null,
declaration: _declaration,
customConstraints: customConstraints,
constraints: dslFeatures,
Expand Down
23 changes: 22 additions & 1 deletion drift_dev/lib/src/writer/schema_version_writer.dart
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,28 @@ class SchemaVersionWriter {
/// called in different places. This method looks up or creates a method for
/// the given [column], returning it if doesn't exist.
String _referenceColumn(DriftColumn column) {
final text = libraryScope.leaf();
final text = libraryScope.leaf(writeTaggedDartCode: (tag, buffer) {
final dartName = tag.tag;
final referencedColumn = column.owner.columns
.singleWhereOrNull((e) => e.nameInDart == dartName);

if (referencedColumn != null) {
// This references a column in the same table. Since we're not
// generating columns in a table structure where they would be in scope
// for Dart, we have to replace this with a custom expression evaluating
// to the column.
final sqlType = libraryScope.innerColumnType(referencedColumn.sqlType);
final result = libraryScope.dartCode(AnnotatedDartCode.build((b) => b
..addText('(')
..addSymbol('VersionedTable', _schemaLibrary)
..addText('.col<')
..addCode(sqlType)
..addText('>(${asDartLiteral(referencedColumn.nameInSql)}))')));
buffer.write(result);
} else {
buffer.write(tag.lexeme);
}
});
final (type, code) = TableOrViewWriter.instantiateColumn(column, text);

return _columnCodeToFactory.putIfAbsent(code, () {
Expand Down
2 changes: 1 addition & 1 deletion drift_dev/lib/src/writer/tables/data_class_writer.dart
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class DataClassWriter {
final nullable = converter.canBeSkippedForNulls && column.nullable;
final code = AnnotatedDartCode([
...AnnotatedDartCode.type(converter.jsonType!).elements,
if (nullable) '?',
if (nullable) const DartLexeme('?'),
]);

return _emitter.dartCode(code);
Expand Down
5 changes: 4 additions & 1 deletion drift_dev/lib/src/writer/utils/column_constraints.dart
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ Map<SqlDialect, String> defaultConstraints(DriftColumn column) {
result.write(defaults);
}
if (feature.dialectSpecific[dialect] case final specific?) {
result.write(' $specific');
if (result.isNotEmpty) {
result.write(' ');
}
result.write(specific);
}
return result.toString();
}
Expand Down
Loading

0 comments on commit ab93107

Please sign in to comment.