diff --git a/migrations/capcons/capabilitymigration.go b/migrations/capcons/capabilitymigration.go index fce3b419ff..048e8eb6e5 100644 --- a/migrations/capcons/capabilitymigration.go +++ b/migrations/capcons/capabilitymigration.go @@ -68,7 +68,12 @@ func (m *CapabilityValueMigration) Migrate( _ interpreter.StorageMapKey, value interpreter.Value, _ *interpreter.Interpreter, -) (interpreter.Value, error) { + _ migrations.ValueMigrationPosition, +) ( + interpreter.Value, + error, +) { + reporter := m.Reporter switch value := value.(type) { diff --git a/migrations/capcons/linkmigration.go b/migrations/capcons/linkmigration.go index b54b4504a4..922ef82c5a 100644 --- a/migrations/capcons/linkmigration.go +++ b/migrations/capcons/linkmigration.go @@ -70,7 +70,11 @@ func (m *LinkValueMigration) Migrate( storageMapKey interpreter.StorageMapKey, value interpreter.Value, inter *interpreter.Interpreter, -) (interpreter.Value, error) { + _ migrations.ValueMigrationPosition, +) ( + interpreter.Value, + error, +) { pathValue, ok := storageKeyToPathValue(storageKey, storageMapKey) if !ok { diff --git a/migrations/entitlements/migration.go b/migrations/entitlements/migration.go index c6639da7e4..b58136d05f 100644 --- a/migrations/entitlements/migration.go +++ b/migrations/entitlements/migration.go @@ -371,6 +371,7 @@ func (m EntitlementsMigration) Migrate( _ interpreter.StorageMapKey, value interpreter.Value, _ *interpreter.Interpreter, + _ migrations.ValueMigrationPosition, ) ( interpreter.Value, error, diff --git a/migrations/entitlements/migration_test.go b/migrations/entitlements/migration_test.go index 6d6e1f04e6..88cc6fdce1 100644 --- a/migrations/entitlements/migration_test.go +++ b/migrations/entitlements/migration_test.go @@ -30,6 +30,7 @@ import ( "github.com/onflow/cadence" "github.com/onflow/cadence/migrations" "github.com/onflow/cadence/migrations/statictypes" + "github.com/onflow/cadence/migrations/type_keys" "github.com/onflow/cadence/runtime" "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" @@ -690,6 +691,7 @@ func (m testEntitlementsMigration) Migrate( _ interpreter.StorageMapKey, value interpreter.Value, _ *interpreter.Interpreter, + _ migrations.ValueMigrationPosition, ) ( interpreter.Value, error, @@ -731,6 +733,7 @@ func convertEntireTestValue( }, reporter, true, + migrations.ValueMigrationPositionOther, ) err = migration.Commit() @@ -2863,6 +2866,7 @@ func TestConvertMigratedAccountTypes(t *testing.T) { nil, value, inter, + migrations.ValueMigrationPositionOther, ) require.NoError(t, err) require.NotNil(t, newValue) @@ -3658,6 +3662,7 @@ func TestUseAfterMigrationFailure(t *testing.T) { migration.NewValueMigrationsPathMigrator( reporter, NewEntitlementsMigration(inter), + type_keys.NewTypeKeyMigration(), ), ) @@ -3673,7 +3678,21 @@ func TestUseAfterMigrationFailure(t *testing.T) { assert.ErrorContains(t, reporter.errors[0], importErrorMessage) - require.Empty(t, reporter.migrated) + assert.Equal(t, + map[struct { + interpreter.StorageKey + interpreter.StorageMapKey + }]struct{}{ + { + StorageKey: interpreter.StorageKey{ + Address: testAddress, + Key: common.PathDomainStorage.Identifier(), + }, + StorageMapKey: interpreter.StringStorageMapKey("dict"), + }: {}, + }, + reporter.migrated, + ) })() // Load @@ -3718,15 +3737,9 @@ func TestUseAfterMigrationFailure(t *testing.T) { assert.Equal(t, 1, dictValue.Count()) - // Key did not get migrated, so is inaccessible using the "new" type value + // Key did not get migrated, but got still re-stored in new format, + // so it can be loaded and used after the migration failure _, ok := dictValue.Get(inter, locationRange, typeValue) - require.False(t, ok) - - // But the key is still accessible using the "old" type value - legacyKey := migrations.LegacyKey(typeValue) - - value, ok := dictValue.Get(inter, locationRange, legacyKey) require.True(t, ok) - require.Equal(t, newTestValue(), value) })() } diff --git a/migrations/migration.go b/migrations/migration.go index efc21e1a6f..d7877f4dc2 100644 --- a/migrations/migration.go +++ b/migrations/migration.go @@ -22,6 +22,8 @@ import ( "fmt" "runtime/debug" + "github.com/onflow/atree" + "github.com/onflow/cadence/runtime" "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/errors" @@ -37,11 +39,19 @@ type ValueMigration interface { storageMapKey interpreter.StorageMapKey, value interpreter.Value, interpreter *interpreter.Interpreter, + position ValueMigrationPosition, ) (newValue interpreter.Value, err error) CanSkip(valueType interpreter.StaticType) bool Domains() map[string]struct{} } +type ValueMigrationPosition uint8 + +const ( + ValueMigrationPositionOther ValueMigrationPosition = iota + ValueMigrationPositionDictionaryKey +) + type DomainMigration interface { Name() string Migrate( @@ -162,6 +172,7 @@ func (m *StorageMigration) NewValueMigrationsPathMigrator( valueMigrations, reporter, true, + ValueMigrationPositionOther, ) }, ) @@ -176,6 +187,7 @@ func (m *StorageMigration) MigrateNestedValue( valueMigrations []ValueMigration, reporter Reporter, allowMutation bool, + position ValueMigrationPosition, ) (migratedValue interpreter.Value) { defer func() { @@ -240,6 +252,7 @@ func (m *StorageMigration) MigrateNestedValue( valueMigrations, reporter, allowMutation, + ValueMigrationPositionOther, ) if newInnerValue != nil { migratedValue = interpreter.NewSomeValueNonCopying(inter, newInnerValue) @@ -264,6 +277,7 @@ func (m *StorageMigration) MigrateNestedValue( valueMigrations, reporter, allowMutation, + ValueMigrationPositionOther, ) if newElement == nil { @@ -325,6 +339,7 @@ func (m *StorageMigration) MigrateNestedValue( valueMigrations, reporter, allowMutation, + ValueMigrationPositionOther, ) if newValue == nil { @@ -363,7 +378,7 @@ func (m *StorageMigration) MigrateNestedValue( // The mutating iterator is only able to read new keys, // as it recalculates the stored values' hashes. - keys := m.migrateDictionaryKeys( + m.migrateDictionaryKeys( storageKey, storageMapKey, dictionary, @@ -379,7 +394,6 @@ func (m *StorageMigration) MigrateNestedValue( valueMigrations, reporter, allowMutation, - keys, ) case *interpreter.PublishedValue: @@ -391,6 +405,7 @@ func (m *StorageMigration) MigrateNestedValue( valueMigrations, reporter, allowMutation, + ValueMigrationPositionOther, ) if newInnerValue != nil { newInnerCapability := newInnerValue.(interpreter.CapabilityValue) @@ -415,6 +430,7 @@ func (m *StorageMigration) MigrateNestedValue( storageKey, storageMapKey, value, + position, ) if err != nil { @@ -469,11 +485,6 @@ func (m *StorageMigration) MigrateNestedValue( } -type migratedDictionaryKey struct { - key interpreter.Value - migrated bool -} - func (m *StorageMigration) migrateDictionaryKeys( storageKey interpreter.StorageKey, storageMapKey interpreter.StorageMapKey, @@ -481,7 +492,7 @@ func (m *StorageMigration) migrateDictionaryKeys( valueMigrations []ValueMigration, reporter Reporter, allowMutation bool, -) (migratedKeys []migratedDictionaryKey) { +) { inter := m.interpreter var existingKeys []interpreter.Value @@ -507,14 +518,10 @@ func (m *StorageMigration) migrateDictionaryKeys( reporter, // NOTE: Mutation of keys is not allowed. false, + ValueMigrationPositionDictionaryKey, ) if newKey == nil { - migratedKeys = append(migratedKeys, migratedDictionaryKey{ - key: existingKey, - migrated: false, - }) - continue } @@ -531,14 +538,25 @@ func (m *StorageMigration) migrateDictionaryKeys( // We only reach here because key needs to be migrated. - // Remove the old key-value pair + // Remove the old key-value pair. - existingKey = LegacyKey(existingKey) - existingKeyStorable, existingValueStorable := dictionary.RemoveWithoutTransfer( - inter, - emptyLocationRange, - existingKey, - ) + var existingKeyStorable, existingValueStorable atree.Storable + + legacyKey := LegacyKey(existingKey) + if legacyKey != nil { + existingKeyStorable, existingValueStorable = dictionary.RemoveWithoutTransfer( + inter, + emptyLocationRange, + legacyKey, + ) + } + if existingKeyStorable == nil { + existingKeyStorable, existingValueStorable = dictionary.RemoveWithoutTransfer( + inter, + emptyLocationRange, + existingKey, + ) + } if existingKeyStorable == nil { panic(errors.NewUnexpectedError( "failed to remove old value for migrated key: %s", @@ -589,6 +607,7 @@ func (m *StorageMigration) migrateDictionaryKeys( valueMigrations, reporter, allowMutation, + ValueMigrationPositionOther, ) var valueToSet interpreter.Value @@ -655,15 +674,8 @@ func (m *StorageMigration) migrateDictionaryKeys( newKey, existingValue, ) - - migratedKeys = append(migratedKeys, migratedDictionaryKey{ - key: newKey, - migrated: true, - }) } } - - return } func (m *StorageMigration) migrateDictionaryValues( @@ -673,21 +685,37 @@ func (m *StorageMigration) migrateDictionaryValues( valueMigrations []ValueMigration, reporter Reporter, allowMutation bool, - migratedDictionaryKeys []migratedDictionaryKey, ) { + inter := m.interpreter - for _, migratedDictionaryKey := range migratedDictionaryKeys { + type keyValuePair struct { + key, value interpreter.Value + } - existingKey := migratedDictionaryKey.key - if !migratedDictionaryKey.migrated { - existingKey = LegacyKey(existingKey) - } + var existingKeysAndValues []keyValuePair - existingValue, ok := dictionary.Get(inter, emptyLocationRange, existingKey) - if !ok { - panic(errors.NewUnexpectedError("failed to get existing value for key: %s", existingKey)) - } + dictionary.Iterate( + inter, + func(key, value interpreter.Value) (resume bool) { + + existingKeysAndValues = append( + existingKeysAndValues, + keyValuePair{ + key: key, + value: value, + }, + ) + + // Continue iteration + return true + }, + emptyLocationRange, + ) + + for _, existingKeyAndValue := range existingKeysAndValues { + existingKey := existingKeyAndValue.key + existingValue := existingKeyAndValue.value newValue := m.MigrateNestedValue( storageKey, @@ -696,6 +724,7 @@ func (m *StorageMigration) migrateDictionaryValues( valueMigrations, reporter, allowMutation, + ValueMigrationPositionOther, ) if newValue == nil { @@ -771,6 +800,7 @@ func (m *StorageMigration) migrate( storageKey interpreter.StorageKey, storageMapKey interpreter.StorageMapKey, value interpreter.Value, + position ValueMigrationPosition, ) ( converted interpreter.Value, err error, @@ -809,6 +839,7 @@ func (m *StorageMigration) migrate( storageMapKey, value, m.interpreter, + position, ) } @@ -832,7 +863,7 @@ func LegacyKey(key interpreter.Value) interpreter.Value { } } - return key + return nil } func legacyType(staticType interpreter.StaticType) interpreter.StaticType { diff --git a/migrations/migration_test.go b/migrations/migration_test.go index eb4c246ba2..cb6b4b2b51 100644 --- a/migrations/migration_test.go +++ b/migrations/migration_test.go @@ -104,7 +104,11 @@ func (testStringMigration) Migrate( _ interpreter.StorageMapKey, value interpreter.Value, _ *interpreter.Interpreter, -) (interpreter.Value, error) { + _ ValueMigrationPosition, +) ( + interpreter.Value, + error, +) { if value, ok := value.(*interpreter.StringValue); ok { return interpreter.NewUnmeteredStringValue(fmt.Sprintf("updated_%s", value.Str)), nil } @@ -137,7 +141,11 @@ func (m testInt8Migration) Migrate( _ interpreter.StorageMapKey, value interpreter.Value, _ *interpreter.Interpreter, -) (interpreter.Value, error) { + _ ValueMigrationPosition, +) ( + interpreter.Value, + error, +) { int8Value, ok := value.(interpreter.Int8Value) if !ok { return nil, nil @@ -173,7 +181,11 @@ func (testCapMigration) Migrate( _ interpreter.StorageMapKey, value interpreter.Value, _ *interpreter.Interpreter, -) (interpreter.Value, error) { + _ ValueMigrationPosition, +) ( + interpreter.Value, + error, +) { if value, ok := value.(*interpreter.IDCapabilityValue); ok { return interpreter.NewCapabilityValue( nil, @@ -209,7 +221,11 @@ func (testCapConMigration) Migrate( _ interpreter.StorageMapKey, value interpreter.Value, _ *interpreter.Interpreter, -) (interpreter.Value, error) { + _ ValueMigrationPosition, +) ( + interpreter.Value, + error, +) { switch value := value.(type) { case *interpreter.StorageCapabilityControllerValue: @@ -984,6 +1000,7 @@ func (m testCompositeValueMigration) Migrate( _ interpreter.StorageMapKey, value interpreter.Value, inter *interpreter.Interpreter, + _ ValueMigrationPosition, ) ( interpreter.Value, error, @@ -1178,7 +1195,11 @@ func (testContainerMigration) Migrate( _ interpreter.StorageMapKey, value interpreter.Value, inter *interpreter.Interpreter, -) (interpreter.Value, error) { + _ ValueMigrationPosition, +) ( + interpreter.Value, + error, +) { switch value := value.(type) { case *interpreter.DictionaryValue: @@ -1640,7 +1661,11 @@ func (m testPanicMigration) Migrate( _ interpreter.StorageMapKey, _ interpreter.Value, _ *interpreter.Interpreter, -) (interpreter.Value, error) { + _ ValueMigrationPosition, +) ( + interpreter.Value, + error, +) { // NOTE: out-of-bounds access, panic _ = []int{}[0] @@ -1759,7 +1784,11 @@ func (m *testSkipMigration) Migrate( _ interpreter.StorageMapKey, value interpreter.Value, _ *interpreter.Interpreter, -) (interpreter.Value, error) { + _ ValueMigrationPosition, +) ( + interpreter.Value, + error, +) { m.migrationCalls = append(m.migrationCalls, value) @@ -2114,7 +2143,11 @@ func (testPublishedValueMigration) Migrate( _ interpreter.StorageMapKey, value interpreter.Value, _ *interpreter.Interpreter, -) (interpreter.Value, error) { + _ ValueMigrationPosition, +) ( + interpreter.Value, + error, +) { if pathCap, ok := value.(*interpreter.PathCapabilityValue); ok { //nolint:staticcheck return pathCap, nil @@ -2219,7 +2252,11 @@ func (m testDomainsMigration) Migrate( _ interpreter.StorageMapKey, _ interpreter.Value, _ *interpreter.Interpreter, -) (interpreter.Value, error) { + _ ValueMigrationPosition, +) ( + interpreter.Value, + error, +) { if m.domains != nil { _, ok := m.domains[storageKey.Key] @@ -2355,7 +2392,11 @@ func (m testDictionaryKeyConflictMigration) Migrate( _ interpreter.StorageMapKey, value interpreter.Value, _ *interpreter.Interpreter, -) (interpreter.Value, error) { + _ ValueMigrationPosition, +) ( + interpreter.Value, + error, +) { typeValue, ok := value.(interpreter.TypeValue) if ok { return typeValue, nil @@ -2776,7 +2817,11 @@ func (testEnumMigration) Migrate( _ interpreter.StorageMapKey, value interpreter.Value, inter *interpreter.Interpreter, -) (interpreter.Value, error) { + _ ValueMigrationPosition, +) ( + interpreter.Value, + error, +) { if composite, ok := value.(*interpreter.CompositeValue); ok && composite.Kind == common.CompositeKindEnum { rawValue := composite.GetField(inter, emptyLocationRange, sema.EnumRawValueFieldName) raw := rawValue.(interpreter.UInt8Value) diff --git a/migrations/statictypes/statictype_migration.go b/migrations/statictypes/statictype_migration.go index 84de646a58..3522c154a3 100644 --- a/migrations/statictypes/statictype_migration.go +++ b/migrations/statictypes/statictype_migration.go @@ -62,9 +62,10 @@ func (m *StaticTypeMigration) Migrate( _ interpreter.StorageMapKey, value interpreter.Value, _ *interpreter.Interpreter, + _ migrations.ValueMigrationPosition, ) ( - newValue interpreter.Value, - err error, + interpreter.Value, + error, ) { switch value := value.(type) { @@ -72,18 +73,18 @@ func (m *StaticTypeMigration) Migrate( // Type is optional. nil represents "unknown"/"invalid" type ty := value.Type if ty == nil { - return + return nil, nil } convertedType := m.maybeConvertStaticType(ty, nil) if convertedType == nil { - return + return nil, nil } return interpreter.NewTypeValue(nil, convertedType), nil case *interpreter.IDCapabilityValue: convertedBorrowType := m.maybeConvertStaticType(value.BorrowType, nil) if convertedBorrowType == nil { - return + return nil, nil } return interpreter.NewUnmeteredCapabilityValue(value.ID, value.Address, convertedBorrowType), nil @@ -91,11 +92,11 @@ func (m *StaticTypeMigration) Migrate( // Type is optional borrowType := value.BorrowType if borrowType == nil { - return + return nil, nil } convertedBorrowType := m.maybeConvertStaticType(borrowType, nil) if convertedBorrowType == nil { - return + return nil, nil } return &interpreter.PathCapabilityValue{ //nolint:staticcheck BorrowType: convertedBorrowType, @@ -106,7 +107,7 @@ func (m *StaticTypeMigration) Migrate( case interpreter.PathLinkValue: //nolint:staticcheck convertedBorrowType := m.maybeConvertStaticType(value.Type, nil) if convertedBorrowType == nil { - return + return nil, nil } return interpreter.PathLinkValue{ //nolint:staticcheck Type: convertedBorrowType, @@ -116,7 +117,7 @@ func (m *StaticTypeMigration) Migrate( case *interpreter.AccountCapabilityControllerValue: convertedBorrowType := m.maybeConvertStaticType(value.BorrowType, nil) if convertedBorrowType == nil { - return + return nil, nil } borrowType := convertedBorrowType.(*interpreter.ReferenceStaticType) return interpreter.NewUnmeteredAccountCapabilityControllerValue(borrowType, value.CapabilityID), nil @@ -124,7 +125,7 @@ func (m *StaticTypeMigration) Migrate( case *interpreter.StorageCapabilityControllerValue: convertedBorrowType := m.maybeConvertStaticType(value.BorrowType, nil) if convertedBorrowType == nil { - return + return nil, nil } borrowType := convertedBorrowType.(*interpreter.ReferenceStaticType) return interpreter.NewUnmeteredStorageCapabilityControllerValue( @@ -136,7 +137,7 @@ func (m *StaticTypeMigration) Migrate( case *interpreter.ArrayValue: convertedElementType := m.maybeConvertStaticType(value.Type, nil) if convertedElementType == nil { - return + return nil, nil } value.SetType( @@ -146,7 +147,7 @@ func (m *StaticTypeMigration) Migrate( case *interpreter.DictionaryValue: convertedElementType := m.maybeConvertStaticType(value.Type, nil) if convertedElementType == nil { - return + return nil, nil } value.SetType( @@ -154,7 +155,7 @@ func (m *StaticTypeMigration) Migrate( ) } - return + return nil, nil } func (m *StaticTypeMigration) maybeConvertStaticType( @@ -586,6 +587,9 @@ func CanSkipStaticTypeMigration(valueType interpreter.StaticType) bool { case interpreter.PrimitiveStaticType: switch valueType { + case interpreter.PrimitiveStaticTypeMetaType: + return false + case interpreter.PrimitiveStaticTypeBool, interpreter.PrimitiveStaticTypeVoid, interpreter.PrimitiveStaticTypeAddress, diff --git a/migrations/string_normalization/migration.go b/migrations/string_normalization/migration.go index 512ca00a2f..96a0b22789 100644 --- a/migrations/string_normalization/migration.go +++ b/migrations/string_normalization/migration.go @@ -43,6 +43,7 @@ func (StringNormalizingMigration) Migrate( _ interpreter.StorageMapKey, value interpreter.Value, _ *interpreter.Interpreter, + _ migrations.ValueMigrationPosition, ) ( interpreter.Value, error, diff --git a/migrations/type_keys/migration.go b/migrations/type_keys/migration.go new file mode 100644 index 0000000000..01bc3bfaef --- /dev/null +++ b/migrations/type_keys/migration.go @@ -0,0 +1,118 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package type_keys + +import ( + "github.com/onflow/cadence/migrations" + "github.com/onflow/cadence/runtime/interpreter" + "github.com/onflow/cadence/runtime/sema" +) + +type TypeKeyMigration struct{} + +var _ migrations.ValueMigration = TypeKeyMigration{} + +func NewTypeKeyMigration() TypeKeyMigration { + return TypeKeyMigration{} +} + +func (TypeKeyMigration) Name() string { + return "TypeKeyMigration" +} + +func (TypeKeyMigration) Migrate( + _ interpreter.StorageKey, + _ interpreter.StorageMapKey, + value interpreter.Value, + _ *interpreter.Interpreter, + position migrations.ValueMigrationPosition, +) ( + interpreter.Value, + error, +) { + // Re-store Type values used as dictionary keys, + // to ensure that even when such values failed to get migrated + // by the static types and entitlements migration, + // they are still stored using their new hash. + + if position == migrations.ValueMigrationPositionDictionaryKey { + if typeValue, ok := value.(interpreter.TypeValue); ok { + return typeValue, nil + } + } + + return nil, nil +} + +func (TypeKeyMigration) Domains() map[string]struct{} { + return nil +} + +func (m TypeKeyMigration) CanSkip(valueType interpreter.StaticType) bool { + return CanSkipTypeKeyMigration(valueType) +} + +func CanSkipTypeKeyMigration(valueType interpreter.StaticType) bool { + + switch valueType := valueType.(type) { + case *interpreter.DictionaryStaticType: + return CanSkipTypeKeyMigration(valueType.KeyType) && + CanSkipTypeKeyMigration(valueType.ValueType) + + case interpreter.ArrayStaticType: + return CanSkipTypeKeyMigration(valueType.ElementType()) + + case *interpreter.OptionalStaticType: + return CanSkipTypeKeyMigration(valueType.Type) + + case *interpreter.CapabilityStaticType: + // Typed capability, can skip + return true + + case interpreter.PrimitiveStaticType: + + switch valueType { + case interpreter.PrimitiveStaticTypeMetaType: + return false + + case interpreter.PrimitiveStaticTypeBool, + interpreter.PrimitiveStaticTypeVoid, + interpreter.PrimitiveStaticTypeAddress, + interpreter.PrimitiveStaticTypeBlock, + interpreter.PrimitiveStaticTypeString, + interpreter.PrimitiveStaticTypeCharacter, + // Untyped capability, can skip + interpreter.PrimitiveStaticTypeCapability: + + return true + } + + if !valueType.IsDeprecated() { //nolint:staticcheck + semaType := valueType.SemaType() + + if sema.IsSubType(semaType, sema.NumberType) || + sema.IsSubType(semaType, sema.PathType) { + + return true + } + } + } + + return false +} diff --git a/migrations/type_keys/migration_test.go b/migrations/type_keys/migration_test.go new file mode 100644 index 0000000000..bc2a352c7b --- /dev/null +++ b/migrations/type_keys/migration_test.go @@ -0,0 +1,269 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package type_keys + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/onflow/atree" + + "github.com/onflow/cadence/migrations" + "github.com/onflow/cadence/runtime" + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/interpreter" + . "github.com/onflow/cadence/runtime/tests/runtime_utils" + "github.com/onflow/cadence/runtime/tests/utils" +) + +type testReporter struct { + migrated map[struct { + interpreter.StorageKey + interpreter.StorageMapKey + }][]string + errors []error +} + +var _ migrations.Reporter = &testReporter{} + +func newTestReporter() *testReporter { + return &testReporter{ + migrated: map[struct { + interpreter.StorageKey + interpreter.StorageMapKey + }][]string{}, + } +} + +func (t *testReporter) Migrated( + storageKey interpreter.StorageKey, + storageMapKey interpreter.StorageMapKey, + migration string, +) { + key := struct { + interpreter.StorageKey + interpreter.StorageMapKey + }{ + StorageKey: storageKey, + StorageMapKey: storageMapKey, + } + + t.migrated[key] = append( + t.migrated[key], + migration, + ) +} + +func (t *testReporter) Error(err error) { + t.errors = append(t.errors, err) +} + +func (t *testReporter) DictionaryKeyConflict(addressPath interpreter.AddressPath) { + // For testing purposes, record the conflict as an error + t.errors = append(t.errors, fmt.Errorf("dictionary key conflict: %s", addressPath)) +} + +func TestTypeKeyMigration(t *testing.T) { + t.Parallel() + + account := common.Address{0x42} + pathDomain := common.PathDomainPublic + locationRange := interpreter.EmptyLocationRange + + type testCase struct { + name string + storedValue func(inter *interpreter.Interpreter) interpreter.Value + expectedValue func(inter *interpreter.Interpreter) interpreter.Value + } + + test := func(t *testing.T, testCase testCase) { + + t.Run(testCase.name, func(t *testing.T) { + + t.Parallel() + + ledger := NewTestLedger(nil, nil) + + storageMapKey := interpreter.StringStorageMapKey("test") + + newStorageAndInterpreter := func(t *testing.T) (*runtime.Storage, *interpreter.Interpreter) { + storage := runtime.NewStorage(ledger, nil) + inter, err := interpreter.NewInterpreter( + nil, + utils.TestLocation, + &interpreter.Config{ + Storage: storage, + // NOTE: disabled, because encoded and decoded values are expected to not match + AtreeValueValidationEnabled: false, + AtreeStorageValidationEnabled: true, + }, + ) + require.NoError(t, err) + + return storage, inter + } + + // Store value + (func() { + + storage, inter := newStorageAndInterpreter(t) + + transferredValue := testCase.storedValue(inter).Transfer( + inter, + locationRange, + atree.Address(account), + false, + nil, + nil, + ) + + inter.WriteStored( + account, + pathDomain.Identifier(), + storageMapKey, + transferredValue, + ) + + err := storage.Commit(inter, true) + require.NoError(t, err) + })() + + // Migrate + (func() { + + storage, inter := newStorageAndInterpreter(t) + + migration, err := migrations.NewStorageMigration(inter, storage, "test", account) + require.NoError(t, err) + + reporter := newTestReporter() + + migration.Migrate( + migration.NewValueMigrationsPathMigrator( + reporter, + NewTypeKeyMigration(), + ), + ) + + err = migration.Commit() + require.NoError(t, err) + + require.Empty(t, reporter.errors) + + })() + + // Load + (func() { + + storage, inter := newStorageAndInterpreter(t) + + err := storage.CheckHealth() + require.NoError(t, err) + + storageMap := storage.GetStorageMap(account, pathDomain.Identifier(), false) + require.NotNil(t, storageMap) + require.Equal(t, uint64(1), storageMap.Count()) + + actualValue := storageMap.ReadValue(nil, storageMapKey) + + expectedValue := testCase.expectedValue(inter) + + utils.AssertValuesEqual(t, inter, expectedValue, actualValue) + })() + }) + } + + testCases := []testCase{ + { + name: "optional reference", + storedValue: func(inter *interpreter.Interpreter) interpreter.Value { + + dictValue := interpreter.NewDictionaryValue( + inter, + locationRange, + interpreter.NewDictionaryStaticType( + nil, + interpreter.PrimitiveStaticTypeMetaType, + interpreter.PrimitiveStaticTypeInt, + ), + ) + + dictValue.Insert( + inter, + locationRange, + // NOTE: storing with legacy key + migrations.LegacyKey( + interpreter.NewTypeValue( + nil, + interpreter.NewOptionalStaticType( + nil, + interpreter.NewReferenceStaticType( + nil, + interpreter.UnauthorizedAccess, + interpreter.PrimitiveStaticTypeInt, + ), + ), + ), + ), + interpreter.NewUnmeteredIntValueFromInt64(42), + ) + + return dictValue + }, + expectedValue: func(inter *interpreter.Interpreter) interpreter.Value { + dictValue := interpreter.NewDictionaryValue( + inter, + locationRange, + interpreter.NewDictionaryStaticType( + nil, + interpreter.PrimitiveStaticTypeMetaType, + interpreter.PrimitiveStaticTypeInt, + ), + ) + + dictValue.Insert( + inter, + locationRange, + // NOTE: expecting to load with new key + interpreter.NewTypeValue( + nil, + interpreter.NewOptionalStaticType( + nil, + interpreter.NewReferenceStaticType( + nil, + interpreter.UnauthorizedAccess, + interpreter.PrimitiveStaticTypeInt, + ), + ), + ), + interpreter.NewUnmeteredIntValueFromInt64(42), + ) + + return dictValue + }, + }, + } + + for _, testCase := range testCases { + test(t, testCase) + } + +}