diff --git a/migrations/account_storage.go b/migrations/account_storage.go index 7c5578f723..7e53d7deba 100644 --- a/migrations/account_storage.go +++ b/migrations/account_storage.go @@ -41,35 +41,35 @@ func NewAccountStorage(storage *runtime.Storage, address common.Address) Account func (i *AccountStorage) ForEachValue( inter *interpreter.Interpreter, domains []common.PathDomain, - valueConverter func(interpreter.Value) interpreter.Value, + valueConverter func(interpreter.Value) (newValue interpreter.Value, updated bool), reporter Reporter, ) { for _, domain := range domains { storageMap := i.storage.GetStorageMap(i.address, domain.Identifier(), false) - if storageMap == nil { + if storageMap == nil || storageMap.Count() == 0 { continue } iterator := storageMap.Iterator(inter) - count := storageMap.Count() - if count > 0 { - for key, value := iterator.Next(); key != nil; key, value = iterator.Next() { - newValue := valueConverter(value) + for key, value := iterator.Next(); key != nil; key, value = iterator.Next() { + newValue, updated := valueConverter(value) + if newValue == nil && !updated { + continue + } - // If the converter returns a new value, then replace the existing value with the new one. - if newValue != nil { - // TODO: unfortunately, the iterator only returns an atree.Value, not a StorageMapKey - identifier := string(key.(interpreter.StringAtreeValue)) - storageMap.SetValue( - inter, - interpreter.StringStorageMapKey(identifier), - newValue, - ) + identifier := string(key.(interpreter.StringAtreeValue)) - reporter.Report(i.address, domain, identifier, newValue) - } + if newValue != nil { + // If the converter returns a new value, then replace the existing value with the new one. + storageMap.SetValue( + inter, + interpreter.StringStorageMapKey(identifier), + newValue, + ) } + + reporter.Report(i.address, domain, identifier) } } } diff --git a/migrations/account_type/migration.go b/migrations/account_type/migration.go index c6e077d3ca..52b52399bd 100644 --- a/migrations/account_type/migration.go +++ b/migrations/account_type/migration.go @@ -58,6 +58,11 @@ func (m *AccountTypeMigration) Migrate( reporter, ) } + + err := m.storage.Commit(m.interpreter, false) + if err != nil { + panic(err) + } } // migrateTypeValuesInAccount migrates `AuthAccount` and `PublicAccount` types in a given account @@ -77,21 +82,133 @@ func (m *AccountTypeMigration) migrateTypeValuesInAccount( ) } -func (m *AccountTypeMigration) migrateValue(value interpreter.Value) interpreter.Value { - typeValue, ok := value.(interpreter.TypeValue) - if !ok { - // TODO: support migration for type-values nested inside other values. - return nil - } +var locationRange = interpreter.EmptyLocationRange + +func (m *AccountTypeMigration) migrateValue(value interpreter.Value) (newValue interpreter.Value, updated bool) { + switch value := value.(type) { + case interpreter.TypeValue: + convertedType := m.maybeConvertAccountType(value.Type) + if convertedType == nil { + return + } - innerType := typeValue.Type + return interpreter.NewTypeValue(nil, convertedType), true - convertedType := m.maybeConvertAccountType(innerType) - if convertedType == nil { - return nil - } + case *interpreter.SomeValue: + innerValue := value.InnerValue(m.interpreter, locationRange) + newInnerValue, _ := m.migrateValue(innerValue) + if newInnerValue != nil { + return interpreter.NewSomeValueNonCopying(m.interpreter, newInnerValue), true + } + + return + + case *interpreter.ArrayValue: + var index int + + // Migrate array elements + + value.Iterate(m.interpreter, func(element interpreter.Value) (resume bool) { + newElement, elementUpdated := m.migrateValue(element) + if newElement != nil { + value.Set( + m.interpreter, + locationRange, + index, + newElement, + ) + } + + index++ + + updated = updated || elementUpdated + + return true + }) + + // The array itself doesn't need to be replaced. + return + + case *interpreter.CompositeValue: + value.ForEachField(nil, func(fieldName string, fieldValue interpreter.Value) (resume bool) { + newFieldValue, fieldUpdated := m.migrateValue(fieldValue) + if newFieldValue != nil { + value.SetMember( + m.interpreter, + locationRange, + fieldName, + newFieldValue, + ) + } + + updated = updated || fieldUpdated - return interpreter.NewTypeValue(nil, convertedType) + // continue iteration + return true + }) + + // The composite itself does not have to be replaced + return + + case *interpreter.DictionaryValue: + dictionary := value + + type migratedKeyValue struct { + oldKey interpreter.Value + newKey interpreter.Value + newValue interpreter.Value + } + + var keyValues []migratedKeyValue + + dictionary.Iterate(m.interpreter, func(key, value interpreter.Value) (resume bool) { + newKey, keyUpdated := m.migrateValue(key) + newValue, valueUpdated := m.migrateValue(value) + + if newKey != nil || newValue != nil { + keyValues = append( + keyValues, + migratedKeyValue{ + oldKey: key, + newKey: newKey, + newValue: newValue, + }, + ) + } + + updated = updated || keyUpdated || valueUpdated + + return true + }) + + for _, keyValue := range keyValues { + var key, value interpreter.Value + + // We only reach here is either the key or value has been migrated. + + if keyValue.newKey != nil { + // Key was migrated. + // Remove the old value at the old key. + // This old value will be inserted again with the new key, unless the value is also migrated. + value = dictionary.RemoveKey(m.interpreter, locationRange, keyValue.oldKey) + key = keyValue.newKey + } else { + key = keyValue.oldKey + } + + // Value was migrated + if keyValue.newValue != nil { + value = keyValue.newValue + } + + dictionary.SetKey(m.interpreter, locationRange, key, value) + } + + // The dictionary itself does not have to be replaced + return + default: + return + } } func (m *AccountTypeMigration) maybeConvertAccountType(staticType interpreter.StaticType) interpreter.StaticType { diff --git a/migrations/account_type/migration_test.go b/migrations/account_type/migration_test.go index d5687255e3..6cea4f0ae1 100644 --- a/migrations/account_type/migration_test.go +++ b/migrations/account_type/migration_test.go @@ -22,9 +22,10 @@ import ( "fmt" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/onflow/atree" + "github.com/onflow/cadence/migrations" "github.com/onflow/cadence/runtime" "github.com/onflow/cadence/runtime/common" @@ -33,9 +34,36 @@ import ( "github.com/onflow/cadence/runtime/tests/utils" ) -type testCase struct { - storedType interpreter.StaticType - expectedType interpreter.StaticType +var _ migrations.Reporter = &testReporter{} + +type testReporter struct { + migratedPaths map[common.Address]map[common.PathDomain]map[string]struct{} +} + +func newTestReporter() *testReporter { + return &testReporter{ + migratedPaths: map[common.Address]map[common.PathDomain]map[string]struct{}{}, + } +} + +func (t *testReporter) Report( + address common.Address, + domain common.PathDomain, + identifier string, +) { + migratedPathsInAddress, ok := t.migratedPaths[address] + if !ok { + migratedPathsInAddress = make(map[common.PathDomain]map[string]struct{}) + t.migratedPaths[address] = migratedPathsInAddress + } + + migratedPathsInDomain, ok := migratedPathsInAddress[domain] + if !ok { + migratedPathsInDomain = make(map[string]struct{}) + migratedPathsInAddress[domain] = migratedPathsInDomain + } + + migratedPathsInDomain[identifier] = struct{}{} } func TestMigration(t *testing.T) { @@ -48,6 +76,11 @@ func TestMigration(t *testing.T) { const authAccountType = interpreter.PrimitiveStaticTypeAuthAccount const stringType = interpreter.PrimitiveStaticTypeString + type testCase struct { + storedType interpreter.StaticType + expectedType interpreter.StaticType + } + testCases := map[string]testCase{ "public_account": { storedType: publicAccountType, @@ -285,7 +318,6 @@ func TestMigration(t *testing.T) { // Migrate rt := runtime_utils.NewTestInterpreterRuntime() - runtimeInterface := &runtime_utils.TestRuntimeInterface{ Storage: ledger, } @@ -296,7 +328,6 @@ func TestMigration(t *testing.T) { Interface: runtimeInterface, }, ) - require.NoError(t, err) reporter := newTestReporter() @@ -310,15 +341,11 @@ func TestMigration(t *testing.T) { reporter, ) - migratedPathsInDomain := reporter.migratedPaths[account][pathDomain] - - for path, _ := range migratedPathsInDomain { - require.Contains(t, testCases, path) - } + // Check reported migrated paths + migratedPathsInDomain := reporter.migratedPaths[account][pathDomain] for path, test := range testCases { - t.Run(path, func(t *testing.T) { - + t.Run(fmt.Sprintf("reported_%s", path), func(t *testing.T) { test := test path := path @@ -328,16 +355,34 @@ func TestMigration(t *testing.T) { require.NotContains(t, migratedPathsInDomain, path) } else { require.Contains(t, migratedPathsInDomain, path) + } + }) + } + + // Assert the migrated values. + // Traverse through the storage and see if the values are updated now. - actualValue := migratedPathsInDomain[path] - actualTypeValue := actualValue.(interpreter.TypeValue) + storageMap := storage.GetStorageMap(account, pathDomain.Identifier(), false) + require.NotNil(t, storageMap) + require.Greater(t, storageMap.Count(), uint64(0)) - assert.True( - t, - test.expectedType.Equal(actualTypeValue.Type), - fmt.Sprintf("expected `%s`, found `%s`", test.expectedType, actualTypeValue.Type), - ) + iterator := storageMap.Iterator(inter) + + for key, value := iterator.Next(); key != nil; key, value = iterator.Next() { + identifier := string(key.(interpreter.StringAtreeValue)) + + t.Run(identifier, func(t *testing.T) { + testCase, ok := testCases[identifier] + require.True(t, ok) + + var storageValue interpreter.Value + if testCase.expectedType != nil { + storageValue = interpreter.NewTypeValue(nil, testCase.expectedType) + } else { + storageValue = interpreter.NewTypeValue(nil, testCase.storedType) } + + utils.AssertValuesEqual(t, inter, storageValue, value) }) } } @@ -357,39 +402,172 @@ func storeTypeValue( ) } -var _ migrations.Reporter = &testReporter{} +func TestNestedTypeValueMigration(t *testing.T) { + t.Parallel() -type testReporter struct { - migratedPaths map[common.Address]map[common.PathDomain]map[string]interpreter.Value -} + account := common.Address{0x42} + pathDomain := common.PathDomainPublic -func newTestReporter() *testReporter { - return &testReporter{ - migratedPaths: map[common.Address]map[common.PathDomain]map[string]interpreter.Value{}, + type testCase struct { + storedValue interpreter.Value + expectedValue interpreter.Value } -} -func (t *testReporter) Report( - address common.Address, - domain common.PathDomain, - identifier string, - value interpreter.Value, -) { - migratedPathsInAddress, ok := t.migratedPaths[address] - if !ok { - migratedPathsInAddress = make(map[common.PathDomain]map[string]interpreter.Value) - t.migratedPaths[address] = migratedPathsInAddress + storedAccountTypeValue := interpreter.NewTypeValue(nil, interpreter.PrimitiveStaticTypePublicAccount) + expectedAccountTypeValue := interpreter.NewTypeValue(nil, unauthorizedAccountReferenceType) + stringTypeValue := interpreter.NewTypeValue(nil, interpreter.PrimitiveStaticTypeString) + + ledger := runtime_utils.NewTestLedger(nil, nil) + storage := runtime.NewStorage(ledger, nil) + + inter, err := interpreter.NewInterpreter( + nil, + utils.TestLocation, + &interpreter.Config{ + Storage: storage, + AtreeValueValidationEnabled: false, + AtreeStorageValidationEnabled: false, + }, + ) + require.NoError(t, err) + + testCases := map[string]testCase{ + "account_some_value": { + storedValue: interpreter.NewUnmeteredSomeValueNonCopying(storedAccountTypeValue), + expectedValue: interpreter.NewUnmeteredSomeValueNonCopying(expectedAccountTypeValue), + }, + "int8_some_value": { + storedValue: interpreter.NewUnmeteredSomeValueNonCopying(stringTypeValue), + }, + "account_array": { + storedValue: interpreter.NewArrayValue( + inter, + locationRange, + interpreter.NewVariableSizedStaticType(nil, interpreter.PrimitiveStaticTypeAnyStruct), + common.ZeroAddress, + stringTypeValue, + storedAccountTypeValue, + stringTypeValue, + stringTypeValue, + storedAccountTypeValue, + ), + expectedValue: interpreter.NewArrayValue( + inter, + locationRange, + interpreter.NewVariableSizedStaticType(nil, interpreter.PrimitiveStaticTypeAnyStruct), + common.ZeroAddress, + stringTypeValue, + expectedAccountTypeValue, + stringTypeValue, + stringTypeValue, + expectedAccountTypeValue, + ), + }, + "non_account_array": { + storedValue: interpreter.NewArrayValue( + inter, + locationRange, + interpreter.NewVariableSizedStaticType(nil, interpreter.PrimitiveStaticTypeAnyStruct), + common.ZeroAddress, + stringTypeValue, + stringTypeValue, + stringTypeValue, + ), + }, + "dictionary_with_account_type_value": { + storedValue: interpreter.NewDictionaryValue( + inter, + locationRange, + interpreter.NewDictionaryStaticType( + nil, + interpreter.PrimitiveStaticTypeInt8, + interpreter.PrimitiveStaticTypeAnyStruct, + ), + interpreter.NewUnmeteredInt8Value(4), + interpreter.NewUnmeteredSomeValueNonCopying(storedAccountTypeValue), + ), + expectedValue: interpreter.NewDictionaryValue( + inter, + locationRange, + interpreter.NewDictionaryStaticType( + nil, + interpreter.PrimitiveStaticTypeInt8, + interpreter.PrimitiveStaticTypeAnyStruct, + ), + interpreter.NewUnmeteredInt8Value(4), + interpreter.NewUnmeteredSomeValueNonCopying(expectedAccountTypeValue), + ), + }, } - migratedPathsInDomain, ok := migratedPathsInAddress[domain] - if !ok { - migratedPathsInDomain = make(map[string]interpreter.Value) - migratedPathsInAddress[domain] = migratedPathsInDomain + // Store values + + for name, testCase := range testCases { + transferredValue := testCase.storedValue.Transfer( + inter, + locationRange, + atree.Address(account), + true, + nil, + nil, + ) + + inter.WriteStored( + account, + pathDomain.Identifier(), + interpreter.StringStorageMapKey(name), + transferredValue, + ) } - migratedPathsInDomain[identifier] = value -} + err = storage.Commit(inter, true) + require.NoError(t, err) -func (t *testReporter) ReportError(err error) { - panic("implement me") + // Migrate + + rt := runtime_utils.NewTestInterpreterRuntime() + runtimeInterface := &runtime_utils.TestRuntimeInterface{ + Storage: ledger, + } + + migration, err := NewAccountTypeMigration( + rt, + runtime.Context{ + Interface: runtimeInterface, + }, + ) + require.NoError(t, err) + + migration.Migrate( + &migrations.AddressSliceIterator{ + Addresses: []common.Address{ + account, + }, + }, + nil, + ) + + // Assert: Traverse through the storage and see if the values are updated now. + + storageMap := storage.GetStorageMap(account, pathDomain.Identifier(), false) + require.NotNil(t, storageMap) + require.Greater(t, storageMap.Count(), uint64(0)) + + iterator := storageMap.Iterator(inter) + + for key, value := iterator.Next(); key != nil; key, value = iterator.Next() { + identifier := string(key.(interpreter.StringAtreeValue)) + + t.Run(identifier, func(t *testing.T) { + testCase, ok := testCases[identifier] + require.True(t, ok) + + expectedStoredValue := testCase.expectedValue + if expectedStoredValue == nil { + expectedStoredValue = testCase.storedValue + } + + utils.AssertValuesEqual(t, inter, expectedStoredValue, value) + }) + } } diff --git a/migrations/migration_reporter.go b/migrations/migration_reporter.go index a11bf2f087..a03d72957b 100644 --- a/migrations/migration_reporter.go +++ b/migrations/migration_reporter.go @@ -20,10 +20,8 @@ package migrations import ( "github.com/onflow/cadence/runtime/common" - "github.com/onflow/cadence/runtime/interpreter" ) type Reporter interface { - Report(address common.Address, domain common.PathDomain, key string, value interpreter.Value) - ReportError(err error) + Report(address common.Address, domain common.PathDomain, key string) }