From a179da75e74b27c08f2d0c8ba2d27c83cdb8db3f Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Wed, 20 Mar 2024 11:18:59 -0700 Subject: [PATCH] Improve typeID resolving --- runtime/stdlib/account.go | 2 +- ..._to_v1_contract_upgrade_validation_test.go | 98 ++++++++++++++++++- ..._v0.42_to_v1_contract_upgrade_validator.go | 39 ++++++-- 3 files changed, 124 insertions(+), 15 deletions(-) diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index 1868725195..aec228ceb8 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -1625,7 +1625,7 @@ func changeAccountContracts( contractName, handler, oldProgram, - program.Program, + program, inter.AllElaborations(), ) } else { diff --git a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validation_test.go b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validation_test.go index 9db4d47172..4a4290fd9d 100644 --- a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validation_test.go +++ b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validation_test.go @@ -27,6 +27,7 @@ import ( "github.com/onflow/cadence/runtime" "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/interpreter" "github.com/onflow/cadence/runtime/old_parser" "github.com/onflow/cadence/runtime/parser" "github.com/onflow/cadence/runtime/sema" @@ -55,12 +56,14 @@ func testContractUpdate(t *testing.T, oldCode string, newCode string) error { err = checker.Check() require.NoError(t, err) + program := interpreter.ProgramFromChecker(checker) + upgradeValidator := stdlib.NewCadenceV042ToV1ContractUpdateValidator( utils.TestLocation, "Test", &runtime_utils.TestRuntimeInterface{}, oldProgram, - newProgram, + program, map[common.Location]*sema.Elaboration{ utils.TestLocation: checker.Elaboration, }) @@ -104,7 +107,7 @@ func parseAndCheckPrograms( newImports map[common.Location]string, ) ( oldProgram *ast.Program, - newProgram *ast.Program, + newProgram *interpreter.Program, elaborations map[common.Location]*sema.Elaboration, ) { @@ -112,7 +115,7 @@ func parseAndCheckPrograms( oldProgram, err = old_parser.ParseProgram(nil, []byte(oldCode), old_parser.Config{}) require.NoError(t, err) - newProgram, err = parser.ParseProgram(nil, []byte(newCode), parser.Config{}) + program, err := parser.ParseProgram(nil, []byte(newCode), parser.Config{}) require.NoError(t, err) elaborations = map[common.Location]*sema.Elaboration{} @@ -139,7 +142,7 @@ func parseAndCheckPrograms( } checker, err := sema.NewChecker( - newProgram, + program, location, nil, &sema.Config{ @@ -174,7 +177,7 @@ func parseAndCheckPrograms( err = checker.Check() require.NoError(t, err) - elaborations[location] = checker.Elaboration + newProgram = interpreter.ProgramFromChecker(checker) return } @@ -1977,4 +1980,89 @@ func TestInterfaceConformanceChange(t *testing.T) { err := upgradeValidator.Validate() require.NoError(t, err) }) + + t.Run("with custom rules and changed import", func(t *testing.T) { + t.Parallel() + + const oldCode = ` + import MetadataViews from 0x02 + + pub contract Test { + pub resource R: MetadataViews.Resolver {} + } + ` + + const newImport = ` + access(all) contract ViewResolver { + access(all) resource interface Resolver {} + } + ` + + const newCode = ` + import ViewResolver from 0x02 + + access(all) contract Test { + access(all) resource R: ViewResolver.Resolver {} + } + ` + + viewResolverLocation := common.AddressLocation{ + Name: "ViewResolver", + Address: common.MustBytesToAddress([]byte{0x2}), + } + + metadatViewsLocation := common.AddressLocation{ + Name: "MetadataViews", + Address: common.MustBytesToAddress([]byte{0x2}), + } + + imports := map[common.Location]string{ + viewResolverLocation: newImport, + } + + const contractName = "Test" + location := common.AddressLocation{ + Name: contractName, + Address: common.MustBytesToAddress([]byte{0x1}), + } + + oldProgram, newProgram, elaborations := parseAndCheckPrograms(t, location, oldCode, newCode, imports) + + metadataViewsResolverTypeID := common.NewTypeIDFromQualifiedName( + nil, + metadatViewsLocation, + "MetadataViews.Resolver", + ) + + viewResolverResolverTypeID := common.NewTypeIDFromQualifiedName( + nil, + viewResolverLocation, + "ViewResolver.Resolver", + ) + + upgradeValidator := stdlib.NewCadenceV042ToV1ContractUpdateValidator( + location, + contractName, + &runtime_utils.TestRuntimeInterface{ + OnGetAccountContractNames: func(address runtime.Address) ([]string, error) { + return []string{"TestImport"}, nil + }, + }, + oldProgram, + newProgram, + elaborations, + ).WithUserDefinedTypeChangeChecker( + func(oldTypeID common.TypeID, newTypeID common.TypeID) (checked, valid bool) { + switch oldTypeID { + case metadataViewsResolverTypeID: + return true, newTypeID == viewResolverResolverTypeID + } + + return false, false + }, + ) + + err := upgradeValidator.Validate() + require.NoError(t, err) + }) } diff --git a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go index 61de58cb33..5ce8350a9a 100644 --- a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go +++ b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go @@ -25,6 +25,7 @@ import ( "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/common/orderedmap" "github.com/onflow/cadence/runtime/errors" + "github.com/onflow/cadence/runtime/interpreter" "github.com/onflow/cadence/runtime/sema" ) @@ -46,11 +47,20 @@ func NewCadenceV042ToV1ContractUpdateValidator( contractName string, provider AccountContractNamesProvider, oldProgram *ast.Program, - newProgram *ast.Program, + newProgram *interpreter.Program, newElaborations map[common.Location]*sema.Elaboration, ) *CadenceV042ToV1ContractUpdateValidator { - underlyingValidator := NewContractUpdateValidator(location, contractName, provider, oldProgram, newProgram) + underlyingValidator := NewContractUpdateValidator( + location, + contractName, + provider, + oldProgram, + newProgram.Program, + ) + + // Also add the elaboration of the current program. + newElaborations[location] = newProgram.Elaboration return &CadenceV042ToV1ContractUpdateValidator{ underlyingUpdateValidator: underlyingValidator, @@ -157,17 +167,24 @@ func (validator *CadenceV042ToV1ContractUpdateValidator) idAndLocationOfQualifie // and in 1 and 2 we don't need to do anything typIdentifier := typ.Identifier.Identifier rootIdentifier := validator.TypeComparator.RootDeclIdentifier.Identifier - location := validator.underlyingUpdateValidator.location - foundLocations := validator.TypeComparator.foundIdentifierImportLocations + newImportLocations := validator.TypeComparator.foundIdentifierImportLocations + oldImportLocations := validator.TypeComparator.expectedIdentifierImportLocations - if typIdentifier != rootIdentifier && foundLocations[typIdentifier] == nil { - qualifiedString = fmt.Sprintf("%s.%s", rootIdentifier, qualifiedString) - return common.NewTypeIDFromQualifiedName(nil, location, qualifiedString), location + // Here we only need to find the qualified type ID. + // So check in both old imports as well as in new imports. + location, wasImported := newImportLocations[typIdentifier] + if !wasImported { + location, wasImported = oldImportLocations[typIdentifier] + } + + if !wasImported { + location = validator.underlyingUpdateValidator.location } - if loc := foundLocations[typIdentifier]; loc != nil { - location = loc + if typIdentifier != rootIdentifier && !wasImported { + qualifiedString = fmt.Sprintf("%s.%s", rootIdentifier, qualifiedString) + return common.NewTypeIDFromQualifiedName(nil, location, qualifiedString), location } return common.NewTypeIDFromQualifiedName(nil, location, qualifiedString), location @@ -425,6 +442,10 @@ func (validator *CadenceV042ToV1ContractUpdateValidator) checkUserDefinedTypeCus newType ast.Type, ) (checked, valid bool) { + if validator.checkUserDefinedType == nil { + return false, false + } + oldTypeID, err := validator.typeIDFromType(oldType) if err != nil { return false, false