diff --git a/runtime/contract_test.go b/runtime/contract_test.go index 2bdf7a1d8d..f188075382 100644 --- a/runtime/contract_test.go +++ b/runtime/contract_test.go @@ -26,8 +26,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/onflow/cadence/runtime/errors" "github.com/onflow/cadence/runtime/interpreter" + "github.com/onflow/cadence/runtime/sema" "github.com/onflow/cadence/runtime/stdlib" + "github.com/onflow/cadence/runtime/tests/checker" . "github.com/onflow/cadence/runtime/tests/utils" "github.com/onflow/cadence" @@ -1014,3 +1017,296 @@ func TestRuntimeContractInterfaceConditionEventEmission(t *testing.T) { require.Equal(t, concreteEvent.Fields[0], cadence.String("")) require.Equal(t, concreteEvent.Fields[1], cadence.NewInt(2)) } + +func TestRuntimeContractTryUpdate(t *testing.T) { + t.Parallel() + + newTestRuntimeInterface := func(onUpdate func()) *testRuntimeInterface { + var actualEvents []cadence.Event + storage := newTestLedger(nil, nil) + accountCodes := map[Location][]byte{} + + return &testRuntimeInterface{ + storage: storage, + log: func(message string) {}, + emitEvent: func(event cadence.Event) error { + actualEvents = append(actualEvents, event) + return nil + }, + resolveLocation: singleIdentifierLocationResolver(t), + getSigningAccounts: func() ([]Address, error) { + return []Address{[8]byte{0, 0, 0, 0, 0, 0, 0, 1}}, nil + }, + updateAccountContractCode: func(location common.AddressLocation, code []byte) error { + onUpdate() + accountCodes[location] = code + return nil + }, + getAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + code = accountCodes[location] + return code, nil + }, + } + } + + t.Run("tryUpdate simple", func(t *testing.T) { + + t.Parallel() + + rt := newTestInterpreterRuntime() + + deployTx := DeploymentTransaction("Foo", []byte(`access(all) contract Foo {}`)) + + updateTx := []byte(` + transaction { + prepare(signer: auth(UpdateContract) &Account) { + let code = "access(all) contract Foo { access(all) fun sayHello(): String {return \"hello\"} }".utf8 + + let deploymentResult = signer.contracts.tryUpdate( + name: "Foo", + code: code, + ) + + let deployedContract = deploymentResult.deployedContract! + assert(deployedContract.name == "Foo") + assert(deployedContract.address == 0x1) + assert(deployedContract.code == code) + } + } + `) + + invokeTx := []byte(` + import Foo from 0x1 + + transaction { + prepare(signer: &Account) { + assert(Foo.sayHello() == "hello") + } + } + `) + + runtimeInterface := newTestRuntimeInterface(func() {}) + nextTransactionLocation := newTransactionLocationGenerator() + + // Deploy 'Foo' + err := rt.ExecuteTransaction( + Script{ + Source: deployTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + // Update 'Foo' + err = rt.ExecuteTransaction( + Script{ + Source: updateTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + // Test the updated 'Foo' + err = rt.ExecuteTransaction( + Script{ + Source: invokeTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + }) + + t.Run("tryUpdate non existing", func(t *testing.T) { + + t.Parallel() + + rt := newTestInterpreterRuntime() + + updateTx := []byte(` + transaction { + prepare(signer: auth(UpdateContract) &Account) { + let deploymentResult = signer.contracts.tryUpdate( + name: "Foo", + code: "access(all) contract Foo { access(all) fun sayHello(): String {return \"hello\"} }".utf8, + ) + + assert(deploymentResult.deployedContract == nil) + } + } + `) + + invokeTx := []byte(` + import Foo from 0x1 + + transaction { + prepare(signer: &Account) { + assert(Foo.sayHello() == "hello") + } + } + `) + + runtimeInterface := newTestRuntimeInterface(func() {}) + nextTransactionLocation := newTransactionLocationGenerator() + + // Update non-existing 'Foo'. Should not panic. + err := rt.ExecuteTransaction( + Script{ + Source: updateTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + // Test the updated 'Foo'. + // Foo must not be available. + + err = rt.ExecuteTransaction( + Script{ + Source: invokeTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + RequireError(t, err) + + errs := checker.RequireCheckerErrors(t, err, 1) + var notExportedError *sema.NotExportedError + require.ErrorAs(t, errs[0], ¬ExportedError) + }) + + t.Run("tryUpdate with checking error", func(t *testing.T) { + + t.Parallel() + + rt := newTestInterpreterRuntime() + + deployTx := DeploymentTransaction("Foo", []byte(`access(all) contract Foo {}`)) + + updateTx := []byte(` + transaction { + prepare(signer: auth(UpdateContract) &Account) { + let deploymentResult = signer.contracts.tryUpdate( + name: "Foo", + + // Has a semantic error! + code: "access(all) contract Foo { access(all) fun sayHello(): Int { return \"hello\" } }".utf8, + ) + + assert(deploymentResult.deployedContract == nil) + } + } + `) + + runtimeInterface := newTestRuntimeInterface(func() {}) + + nextTransactionLocation := newTransactionLocationGenerator() + + // Deploy 'Foo' + err := rt.ExecuteTransaction( + Script{ + Source: deployTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + // Update 'Foo'. + // User errors (parsing, checking and interpreting) should be handled gracefully. + + err = rt.ExecuteTransaction( + Script{ + Source: updateTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + require.NoError(t, err) + }) + + t.Run("tryUpdate panic with internal error", func(t *testing.T) { + + t.Parallel() + + rt := newTestInterpreterRuntime() + + deployTx := DeploymentTransaction("Foo", []byte(`access(all) contract Foo {}`)) + + updateTx := []byte(` + transaction { + prepare(signer: auth(UpdateContract) &Account) { + let deploymentResult = signer.contracts.tryUpdate( + name: "Foo", + code: "access(all) contract Foo { access(all) fun sayHello(): String {return \"hello\"} }".utf8, + ) + + assert(deploymentResult.deployedContract == nil) + } + } + `) + + shouldPanic := false + didPanic := false + + runtimeInterface := newTestRuntimeInterface(func() { + if shouldPanic { + didPanic = true + panic("panic during update") + } + }) + + nextTransactionLocation := newTransactionLocationGenerator() + + // Deploy 'Foo' + err := rt.ExecuteTransaction( + Script{ + Source: deployTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + assert.False(t, didPanic) + + // Update 'Foo'. + // Internal errors should NOT be handled gracefully. + + shouldPanic = true + err = rt.ExecuteTransaction( + Script{ + Source: updateTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + RequireError(t, err) + var unexpectedError errors.UnexpectedError + require.ErrorAs(t, err, &unexpectedError) + + assert.True(t, didPanic) + }) +} diff --git a/runtime/convertValues_test.go b/runtime/convertValues_test.go index 5be6c0a4a8..8f6c898911 100644 --- a/runtime/convertValues_test.go +++ b/runtime/convertValues_test.go @@ -5242,3 +5242,153 @@ func TestRuntimeDestroyedResourceReferenceExport(t *testing.T) { require.Error(t, err) require.ErrorAs(t, err, &interpreter.DestroyedResourceError{}) } + +func TestRuntimeDeploymentResultValueImportExport(t *testing.T) { + + t.Parallel() + + t.Run("import", func(t *testing.T) { + + t.Parallel() + + script := ` + access(all) fun main(v: DeploymentResult) {} + ` + + rt := newTestInterpreterRuntime() + runtimeInterface := &testRuntimeInterface{} + + _, err := rt.ExecuteScript( + Script{ + Source: []byte(script), + }, + Context{ + Interface: runtimeInterface, + Location: common.ScriptLocation{}, + }, + ) + + RequireError(t, err) + + var notImportableError *ScriptParameterTypeNotImportableError + require.ErrorAs(t, err, ¬ImportableError) + }) + + t.Run("export", func(t *testing.T) { + + t.Parallel() + + script := ` + access(all) fun main(): DeploymentResult? { + return nil + } + ` + + rt := newTestInterpreterRuntime() + runtimeInterface := &testRuntimeInterface{} + + _, err := rt.ExecuteScript( + Script{ + Source: []byte(script), + }, + Context{ + Interface: runtimeInterface, + Location: common.ScriptLocation{}, + }, + ) + + RequireError(t, err) + + var invalidReturnTypeError *InvalidScriptReturnTypeError + require.ErrorAs(t, err, &invalidReturnTypeError) + }) +} + +func TestRuntimeDeploymentResultTypeImportExport(t *testing.T) { + + t.Parallel() + + t.Run("import", func(t *testing.T) { + + t.Parallel() + + script := ` + access(all) fun main(v: Type) { + assert(v == Type()) + } + ` + + rt := newTestInterpreterRuntime() + + typeValue := cadence.NewTypeValue(&cadence.StructType{ + QualifiedIdentifier: "DeploymentResult", + Fields: []cadence.Field{ + { + Type: cadence.NewOptionalType(cadence.DeployedContractType), + Identifier: "deployedContract", + }, + }, + }) + + encodedArg, err := json.Encode(typeValue) + require.NoError(t, err) + + runtimeInterface := &testRuntimeInterface{} + + runtimeInterface.decodeArgument = func(b []byte, t cadence.Type) (value cadence.Value, err error) { + return json.Decode(runtimeInterface, b) + } + + _, err = rt.ExecuteScript( + Script{ + Source: []byte(script), + Arguments: [][]byte{encodedArg}, + }, + Context{ + Interface: runtimeInterface, + Location: common.ScriptLocation{}, + }, + ) + + require.NoError(t, err) + }) + + t.Run("export", func(t *testing.T) { + + t.Parallel() + + script := ` + access(all) fun main(): Type { + return Type() + } + ` + + rt := newTestInterpreterRuntime() + runtimeInterface := &testRuntimeInterface{} + + result, err := rt.ExecuteScript( + Script{ + Source: []byte(script), + }, + Context{ + Interface: runtimeInterface, + Location: common.ScriptLocation{}, + }, + ) + + require.NoError(t, err) + + assert.Equal(t, + cadence.NewTypeValue(&cadence.StructType{ + QualifiedIdentifier: "DeploymentResult", + Fields: []cadence.Field{ + { + Type: cadence.NewOptionalType(cadence.DeployedContractType), + Identifier: "deployedContract", + }, + }, + }), + result, + ) + }) +} diff --git a/runtime/interpreter/value_account_contracts.go b/runtime/interpreter/value_account_contracts.go index 655b86e902..76211ad5ea 100644 --- a/runtime/interpreter/value_account_contracts.go +++ b/runtime/interpreter/value_account_contracts.go @@ -38,6 +38,7 @@ func NewAccountContractsValue( address AddressValue, addFunction FunctionValue, updateFunction FunctionValue, + tryUpdateFunction FunctionValue, getFunction FunctionValue, borrowFunction FunctionValue, removeFunction FunctionValue, @@ -45,11 +46,12 @@ func NewAccountContractsValue( ) Value { fields := map[string]Value{ - sema.Account_ContractsTypeAddFunctionName: addFunction, - sema.Account_ContractsTypeGetFunctionName: getFunction, - sema.Account_ContractsTypeBorrowFunctionName: borrowFunction, - sema.Account_ContractsTypeRemoveFunctionName: removeFunction, - sema.Account_ContractsTypeUpdateFunctionName: updateFunction, + sema.Account_ContractsTypeAddFunctionName: addFunction, + sema.Account_ContractsTypeGetFunctionName: getFunction, + sema.Account_ContractsTypeBorrowFunctionName: borrowFunction, + sema.Account_ContractsTypeRemoveFunctionName: removeFunction, + sema.Account_ContractsTypeUpdateFunctionName: updateFunction, + sema.Account_ContractsTypeTryUpdateFunctionName: tryUpdateFunction, } computeField := func( diff --git a/runtime/interpreter/value_deployment_result.go b/runtime/interpreter/value_deployment_result.go new file mode 100644 index 0000000000..c39f2a25ed --- /dev/null +++ b/runtime/interpreter/value_deployment_result.go @@ -0,0 +1,49 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package interpreter + +import ( + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/sema" +) + +// DeploymentResult + +var deploymentResultTypeID = sema.DeploymentResultType.ID() +var deploymentResultStaticType = ConvertSemaToStaticType(nil, sema.DeploymentResultType) // unmetered +var deploymentResultFieldNames []string = nil + +func NewDeploymentResultValue( + gauge common.MemoryGauge, + deployedContract OptionalValue, +) Value { + + return NewSimpleCompositeValue( + gauge, + deploymentResultTypeID, + deploymentResultStaticType, + deploymentResultFieldNames, + map[string]Value{ + sema.DeploymentResultTypeDeployedContractFieldName: deployedContract, + }, + nil, + nil, + nil, + ) +} diff --git a/runtime/sema/account.cdc b/runtime/sema/account.cdc index 6d765ee1fb..788214460c 100644 --- a/runtime/sema/account.cdc +++ b/runtime/sema/account.cdc @@ -205,6 +205,26 @@ struct Account { access(Contracts | UpdateContract) fun update(name: String, code: [UInt8]): DeployedContract + /// Updates the code for the contract/contract interface in the account, + /// and handle any deployment errors gracefully. + /// + /// The `code` parameter is the UTF-8 encoded representation of the source code. + /// The code must contain exactly one contract or contract interface, + /// which must have the same name as the `name` parameter. + /// + /// Does **not** run the initializer of the contract/contract interface again. + /// The contract instance in the world state stays as is. + /// + /// Fails if no contract/contract interface with the given name exists in the account, + /// if the given code does not declare exactly one contract or contract interface, + /// or if the given name does not match the name of the contract/contract interface declaration in the code. + /// + /// Returns the deployment result. + /// Result would contain the deployed contract for the updated contract, if the update was successfull. + /// Otherwise, the deployed contract would be nil. + access(Contracts | UpdateContract) + fun tryUpdate(name: String, code: [UInt8]): DeploymentResult + /// Returns the deployed contract for the contract/contract interface with the given name in the account, if any. /// /// Returns nil if no contract/contract interface with the given name exists in the account. diff --git a/runtime/sema/account.gen.go b/runtime/sema/account.gen.go index 55d6521263..df3ae4e4cd 100644 --- a/runtime/sema/account.gen.go +++ b/runtime/sema/account.gen.go @@ -644,6 +644,46 @@ or if the given name does not match the name of the contract/contract interface Returns the deployed contract for the updated contract. ` +const Account_ContractsTypeTryUpdateFunctionName = "tryUpdate" + +var Account_ContractsTypeTryUpdateFunctionType = &FunctionType{ + Parameters: []Parameter{ + { + Identifier: "name", + TypeAnnotation: NewTypeAnnotation(StringType), + }, + { + Identifier: "code", + TypeAnnotation: NewTypeAnnotation(&VariableSizedType{ + Type: UInt8Type, + }), + }, + }, + ReturnTypeAnnotation: NewTypeAnnotation( + DeploymentResultType, + ), +} + +const Account_ContractsTypeTryUpdateFunctionDocString = ` +Updates the code for the contract/contract interface in the account, +and handle any deployment errors gracefully. + +The ` + "`code`" + ` parameter is the UTF-8 encoded representation of the source code. +The code must contain exactly one contract or contract interface, +which must have the same name as the ` + "`name`" + ` parameter. + +Does **not** run the initializer of the contract/contract interface again. +The contract instance in the world state stays as is. + +Fails if no contract/contract interface with the given name exists in the account, +if the given code does not declare exactly one contract or contract interface, +or if the given name does not match the name of the contract/contract interface declaration in the code. + +Returns the deployment result. +Result would contain the deployed contract for the updated contract, if the update was successfull. +Otherwise, the deployed contract would be nil. +` + const Account_ContractsTypeGetFunctionName = "get" var Account_ContractsTypeGetFunctionType = &FunctionType{ @@ -771,6 +811,16 @@ func init() { Account_ContractsTypeUpdateFunctionType, Account_ContractsTypeUpdateFunctionDocString, ), + NewUnmeteredFunctionMember( + Account_ContractsType, + newEntitlementAccess( + []Type{ContractsType, UpdateContractType}, + Disjunction, + ), + Account_ContractsTypeTryUpdateFunctionName, + Account_ContractsTypeTryUpdateFunctionType, + Account_ContractsTypeTryUpdateFunctionDocString, + ), NewUnmeteredFunctionMember( Account_ContractsType, PrimitiveAccess(ast.AccessAll), diff --git a/runtime/sema/deployment_result.cdc b/runtime/sema/deployment_result.cdc new file mode 100644 index 0000000000..a927f5e28b --- /dev/null +++ b/runtime/sema/deployment_result.cdc @@ -0,0 +1,11 @@ +#compositeType +access(all) +struct DeploymentResult { + + /// The deployed contract. + /// + /// If the the deployment was unsuccessfull, this will be nil. + /// + access(all) + let deployedContract: DeployedContract? +} diff --git a/runtime/sema/deployment_result.gen.go b/runtime/sema/deployment_result.gen.go new file mode 100644 index 0000000000..0829352a90 --- /dev/null +++ b/runtime/sema/deployment_result.gen.go @@ -0,0 +1,66 @@ +// Code generated from deployment_result.cdc. DO NOT EDIT. +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sema + +import ( + "github.com/onflow/cadence/runtime/ast" + "github.com/onflow/cadence/runtime/common" +) + +const DeploymentResultTypeDeployedContractFieldName = "deployedContract" + +var DeploymentResultTypeDeployedContractFieldType = &OptionalType{ + Type: DeployedContractType, +} + +const DeploymentResultTypeDeployedContractFieldDocString = ` +The deployed contract. + +If the the deployment was unsuccessfull, this will be nil. +` + +const DeploymentResultTypeName = "DeploymentResult" + +var DeploymentResultType = func() *CompositeType { + var t = &CompositeType{ + Identifier: DeploymentResultTypeName, + Kind: common.CompositeKindStructure, + ImportableBuiltin: false, + HasComputedMembers: true, + } + + return t +}() + +func init() { + var members = []*Member{ + NewUnmeteredFieldMember( + DeploymentResultType, + PrimitiveAccess(ast.AccessAll), + ast.VariableKindConstant, + DeploymentResultTypeDeployedContractFieldName, + DeploymentResultTypeDeployedContractFieldType, + DeploymentResultTypeDeployedContractFieldDocString, + ), + } + + DeploymentResultType.Members = MembersAsMap(members) + DeploymentResultType.Fields = MembersFieldNames(members) +} diff --git a/runtime/sema/deployment_result.go b/runtime/sema/deployment_result.go new file mode 100644 index 0000000000..d360f76a71 --- /dev/null +++ b/runtime/sema/deployment_result.go @@ -0,0 +1,21 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sema + +//go:generate go run ./gen deployment_result.cdc deployment_result.gen.go diff --git a/runtime/sema/gen/main.go b/runtime/sema/gen/main.go index 9d7ed032fe..3fd35f3e5b 100644 --- a/runtime/sema/gen/main.go +++ b/runtime/sema/gen/main.go @@ -164,8 +164,9 @@ type typeDecl struct { } type generator struct { - typeStack []*typeDecl - decls []dst.Decl + typeStack []*typeDecl + decls []dst.Decl + leadingPragma map[string]struct{} } var _ ast.DeclarationVisitor[struct{}] = &generator{} @@ -365,24 +366,32 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ g.typeStack = g.typeStack[:lastIndex] }() - // We can generate a SimpleType declaration, - // if this is a top-level type, - // and this declaration has no nested type declarations. - // Otherwise, we have to generate a CompositeType - - canGenerateSimpleType := len(g.typeStack) == 1 - if canGenerateSimpleType { - switch compositeKind { - case common.CompositeKindStructure, - common.CompositeKindResource: - break - default: - canGenerateSimpleType = false + var generateSimpleType bool + + // Check if the declaration is explicitly marked to be generated as a composite type. + if _, ok := g.leadingPragma["compositeType"]; ok { + generateSimpleType = false + } else { + // If not, decide what to generate depending on the type. + + // We can generate a SimpleType declaration, + // if this is a top-level type, + // and this declaration has no nested type declarations. + // Otherwise, we have to generate a CompositeType + generateSimpleType = len(g.typeStack) == 1 + if generateSimpleType { + switch compositeKind { + case common.CompositeKindStructure, + common.CompositeKindResource: + break + default: + generateSimpleType = false + } } } for _, memberDeclaration := range decl.Members.Declarations() { - ast.AcceptDeclaration[struct{}](memberDeclaration, g) + generateDeclaration(g, memberDeclaration) // Visiting unsupported declarations panics, // so only supported member declarations are added @@ -391,14 +400,14 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ memberDeclaration, ) - if canGenerateSimpleType { + if generateSimpleType { switch memberDeclaration.(type) { case *ast.FieldDeclaration, *ast.FunctionDeclaration: break default: - canGenerateSimpleType = false + generateSimpleType = false } } } @@ -406,7 +415,7 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ for _, conformance := range decl.Conformances { switch conformance.Identifier.Identifier { case "Storable": - if !canGenerateSimpleType { + if !generateSimpleType { panic(fmt.Errorf( "composite types cannot be explicitly marked as storable: %s", g.currentTypeID(), @@ -415,7 +424,7 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ typeDecl.storable = true case "Equatable": - if !canGenerateSimpleType { + if !generateSimpleType { panic(fmt.Errorf( "composite types cannot be explicitly marked as equatable: %s", g.currentTypeID(), @@ -424,7 +433,7 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ typeDecl.equatable = true case "Comparable": - if !canGenerateSimpleType { + if !generateSimpleType { panic(fmt.Errorf( "composite types cannot be explicitly marked as comparable: %s", g.currentTypeID(), @@ -433,7 +442,7 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ typeDecl.comparable = true case "Exportable": - if !canGenerateSimpleType { + if !generateSimpleType { panic(fmt.Errorf( "composite types cannot be explicitly marked as exportable: %s", g.currentTypeID(), @@ -445,7 +454,7 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ typeDecl.importable = true case "ContainFields": - if !canGenerateSimpleType { + if !generateSimpleType { panic(fmt.Errorf( "composite types cannot be explicitly marked as having fields: %s", g.currentTypeID(), @@ -456,7 +465,7 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ } var typeVarDecl dst.Expr - if canGenerateSimpleType { + if generateSimpleType { typeVarDecl = simpleTypeLiteral(typeDecl) } else { typeVarDecl = compositeTypeExpr(typeDecl) @@ -479,7 +488,7 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ if len(memberDeclarations) > 0 { - if canGenerateSimpleType { + if generateSimpleType { // func init() { // t.Members = func(t *SimpleType) map[string]MemberResolver { @@ -1075,8 +1084,20 @@ func (*generator) VisitEnumCaseDeclaration(_ *ast.EnumCaseDeclaration) struct{} panic("enum case declarations are not supported") } -func (*generator) VisitPragmaDeclaration(_ *ast.PragmaDeclaration) struct{} { - panic("pragma declarations are not supported") +func (g *generator) VisitPragmaDeclaration(pragma *ast.PragmaDeclaration) (_ struct{}) { + // Treat pragmas as part of the declaration to follow. + + identifierExpr, ok := pragma.Expression.(*ast.IdentifierExpression) + if !ok { + panic("only identifier pragmas are supported") + } + + if g.leadingPragma == nil { + g.leadingPragma = map[string]struct{}{} + } + g.leadingPragma[identifierExpr.Identifier.Identifier] = struct{}{} + + return } func (*generator) VisitImportDeclaration(_ *ast.ImportDeclaration) struct{} { @@ -1954,7 +1975,7 @@ func gen(inPath string, outFile *os.File, packagePath string) { var gen generator for _, declaration := range program.Declarations() { - _ = ast.AcceptDeclaration[struct{}](declaration, &gen) + generateDeclaration(&gen, declaration) } gen.generateTypeInit(program) @@ -1962,6 +1983,19 @@ func gen(inPath string, outFile *os.File, packagePath string) { writeGoFile(inPath, outFile, gen.decls, packagePath) } +func generateDeclaration(gen *generator, declaration ast.Declaration) { + // Treat leading pragmas as part of this declaration. + // Reset them after finishing the current decl. This is to handle nested declarations. + if declaration.DeclarationKind() != common.DeclarationKindPragma { + prevLeadingPragma := gen.leadingPragma + defer func() { + gen.leadingPragma = prevLeadingPragma + }() + } + + _ = ast.AcceptDeclaration[struct{}](declaration, gen) +} + func writeGoFile(inPath string, outFile *os.File, decls []dst.Decl, packagePath string) { err := parsedHeaderTemplate.Execute(outFile, inPath) if err != nil { diff --git a/runtime/sema/gen/testdata/composite-type-pragma.cdc b/runtime/sema/gen/testdata/composite-type-pragma.cdc new file mode 100644 index 0000000000..29d2a8dd36 --- /dev/null +++ b/runtime/sema/gen/testdata/composite-type-pragma.cdc @@ -0,0 +1,2 @@ +#compositeType +access(all) struct Test {} diff --git a/runtime/sema/gen/testdata/composite-type-pragma.golden.go b/runtime/sema/gen/testdata/composite-type-pragma.golden.go new file mode 100644 index 0000000000..f975ff112f --- /dev/null +++ b/runtime/sema/gen/testdata/composite-type-pragma.golden.go @@ -0,0 +1,35 @@ +// Code generated from testdata/composite-type-pragma.cdc. DO NOT EDIT. +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sema + +import "github.com/onflow/cadence/runtime/common" + +const TestTypeName = "Test" + +var TestType = func() *CompositeType { + var t = &CompositeType{ + Identifier: TestTypeName, + Kind: common.CompositeKindStructure, + ImportableBuiltin: false, + HasComputedMembers: true, + } + + return t +}() diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 5d37b84cf4..4292b0e8dc 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -3725,6 +3725,7 @@ func init() { HashAlgorithmType, StorageCapabilityControllerType, AccountCapabilityControllerType, + DeploymentResultType, }, ) @@ -8016,6 +8017,7 @@ func init() { HashAlgorithmType, SignatureAlgorithmType, AccountType, + DeploymentResultType, } for len(compositeTypes) > 0 { diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index d80dbbc3a7..57a0071593 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -22,6 +22,7 @@ import ( "fmt" "golang.org/x/crypto/sha3" + "golang.org/x/xerrors" "github.com/onflow/atree" @@ -327,6 +328,12 @@ func newAccountContractsValue( addressValue, true, ), + newAccountContractsTryUpdateFunction( + sema.Account_ContractsTypeUpdateFunctionType, + gauge, + handler, + addressValue, + ), newAccountContractsGetFunction( sema.Account_ContractsTypeGetFunctionType, gauge, @@ -1369,250 +1376,306 @@ func newAccountContractsChangeFunction( gauge, functionType, func(invocation interpreter.Invocation) interpreter.Value { + return changeAccountContracts(invocation, handler, addressValue, isUpdate) + }, + ) +} - locationRange := invocation.LocationRange - - const requiredArgumentCount = 2 +func changeAccountContracts( + invocation interpreter.Invocation, + handler AccountContractAdditionHandler, + addressValue interpreter.AddressValue, + isUpdate bool, +) interpreter.Value { - nameValue, ok := invocation.Arguments[0].(*interpreter.StringValue) - if !ok { - panic(errors.NewUnreachableError()) - } + locationRange := invocation.LocationRange - newCodeValue, ok := invocation.Arguments[1].(*interpreter.ArrayValue) - if !ok { - panic(errors.NewUnreachableError()) - } + const requiredArgumentCount = 2 - constructorArguments := invocation.Arguments[requiredArgumentCount:] - constructorArgumentTypes := invocation.ArgumentTypes[requiredArgumentCount:] + nameValue, ok := invocation.Arguments[0].(*interpreter.StringValue) + if !ok { + panic(errors.NewUnreachableError()) + } - code, err := interpreter.ByteArrayValueToByteSlice(invocation.Interpreter, newCodeValue, locationRange) - if err != nil { - panic(errors.NewDefaultUserError("add requires the second argument to be an array")) - } + newCodeValue, ok := invocation.Arguments[1].(*interpreter.ArrayValue) + if !ok { + panic(errors.NewUnreachableError()) + } - // Get the existing code + constructorArguments := invocation.Arguments[requiredArgumentCount:] + constructorArgumentTypes := invocation.ArgumentTypes[requiredArgumentCount:] - contractName := nameValue.Str + code, err := interpreter.ByteArrayValueToByteSlice(invocation.Interpreter, newCodeValue, locationRange) + if err != nil { + panic(errors.NewDefaultUserError("add requires the second argument to be an array")) + } - if contractName == "" { - panic(errors.NewDefaultUserError( - "contract name argument cannot be empty." + - "it must match the name of the deployed contract declaration or contract interface declaration", - )) - } + // Get the existing code - address := addressValue.ToAddress() - location := common.NewAddressLocation(invocation.Interpreter, address, contractName) + contractName := nameValue.Str - existingCode, err := handler.GetAccountContractCode(location) - if err != nil { - panic(err) - } + if contractName == "" { + panic(errors.NewDefaultUserError( + "contract name argument cannot be empty." + + "it must match the name of the deployed contract declaration or contract interface declaration", + )) + } - if isUpdate { - // We are updating an existing contract. - // Ensure that there's a contract/contract-interface with the given name exists already + address := addressValue.ToAddress() + location := common.NewAddressLocation(invocation.Interpreter, address, contractName) - if len(existingCode) == 0 { - panic(errors.NewDefaultUserError( - "cannot update non-existing contract with name %q in account %s", - contractName, - address.ShortHexWithPrefix(), - )) - } + existingCode, err := handler.GetAccountContractCode(location) + if err != nil { + panic(err) + } - } else { - // We are adding a new contract. - // Ensure that no contract/contract interface with the given name exists already - - if len(existingCode) > 0 { - panic(errors.NewDefaultUserError( - "cannot overwrite existing contract with name %q in account %s", - contractName, - address.ShortHexWithPrefix(), - )) - } - } + if isUpdate { + // We are updating an existing contract. + // Ensure that there's a contract/contract-interface with the given name exists already - // Check the code - handleContractUpdateError := func(err error) { - if err == nil { - return - } + if len(existingCode) == 0 { + panic(errors.NewDefaultUserError( + "cannot update non-existing contract with name %q in account %s", + contractName, + address.ShortHexWithPrefix(), + )) + } - // Update the code for the error pretty printing - // NOTE: only do this when an error occurs + } else { + // We are adding a new contract. + // Ensure that no contract/contract interface with the given name exists already + + if len(existingCode) > 0 { + panic(errors.NewDefaultUserError( + "cannot overwrite existing contract with name %q in account %s", + contractName, + address.ShortHexWithPrefix(), + )) + } + } - handler.TemporarilyRecordCode(location, code) + // Check the code + handleContractUpdateError := func(err error) { + if err == nil { + return + } - panic(&InvalidContractDeploymentError{ - Err: err, - LocationRange: locationRange, - }) - } + // Update the code for the error pretty printing + // NOTE: only do this when an error occurs - // NOTE: do NOT use the program obtained from the host environment, as the current program. - // Always re-parse and re-check the new program. + handler.TemporarilyRecordCode(location, code) - // NOTE: *DO NOT* store the program – the new or updated program - // should not be effective during the execution + panic(&InvalidContractDeploymentError{ + Err: err, + LocationRange: locationRange, + }) + } - const getAndSetProgram = false + // NOTE: do NOT use the program obtained from the host environment, as the current program. + // Always re-parse and re-check the new program. - program, err := handler.ParseAndCheckProgram( - code, - location, - getAndSetProgram, - ) - handleContractUpdateError(err) + // NOTE: *DO NOT* store the program – the new or updated program + // should not be effective during the execution - // The code may declare exactly one contract or one contract interface. + const getAndSetProgram = false - var contractTypes []*sema.CompositeType - var contractInterfaceTypes []*sema.InterfaceType + program, err := handler.ParseAndCheckProgram( + code, + location, + getAndSetProgram, + ) + handleContractUpdateError(err) - program.Elaboration.ForEachGlobalType(func(_ string, variable *sema.Variable) { - switch ty := variable.Type.(type) { - case *sema.CompositeType: - if ty.Kind == common.CompositeKindContract { - contractTypes = append(contractTypes, ty) - } + // The code may declare exactly one contract or one contract interface. - case *sema.InterfaceType: - if ty.CompositeKind == common.CompositeKindContract { - contractInterfaceTypes = append(contractInterfaceTypes, ty) - } - } - }) + var contractTypes []*sema.CompositeType + var contractInterfaceTypes []*sema.InterfaceType - var deployedType sema.Type - var contractType *sema.CompositeType - var contractInterfaceType *sema.InterfaceType - var declaredName string - var declarationKind common.DeclarationKind + program.Elaboration.ForEachGlobalType(func(_ string, variable *sema.Variable) { + switch ty := variable.Type.(type) { + case *sema.CompositeType: + if ty.Kind == common.CompositeKindContract { + contractTypes = append(contractTypes, ty) + } - switch { - case len(contractTypes) == 1 && len(contractInterfaceTypes) == 0: - contractType = contractTypes[0] - declaredName = contractType.Identifier - deployedType = contractType - declarationKind = common.DeclarationKindContract - case len(contractInterfaceTypes) == 1 && len(contractTypes) == 0: - contractInterfaceType = contractInterfaceTypes[0] - declaredName = contractInterfaceType.Identifier - deployedType = contractInterfaceType - declarationKind = common.DeclarationKindContractInterface + case *sema.InterfaceType: + if ty.CompositeKind == common.CompositeKindContract { + contractInterfaceTypes = append(contractInterfaceTypes, ty) } + } + }) - if deployedType == nil { - // Update the code for the error pretty printing - // NOTE: only do this when an error occurs + var deployedType sema.Type + var contractType *sema.CompositeType + var contractInterfaceType *sema.InterfaceType + var declaredName string + var declarationKind common.DeclarationKind - handler.TemporarilyRecordCode(location, code) + switch { + case len(contractTypes) == 1 && len(contractInterfaceTypes) == 0: + contractType = contractTypes[0] + declaredName = contractType.Identifier + deployedType = contractType + declarationKind = common.DeclarationKindContract + case len(contractInterfaceTypes) == 1 && len(contractTypes) == 0: + contractInterfaceType = contractInterfaceTypes[0] + declaredName = contractInterfaceType.Identifier + deployedType = contractInterfaceType + declarationKind = common.DeclarationKindContractInterface + } - panic(errors.NewDefaultUserError( - "invalid %s: the code must declare exactly one contract or contract interface", - declarationKind.Name(), - )) - } + if deployedType == nil { + // Update the code for the error pretty printing + // NOTE: only do this when an error occurs - // The declared contract or contract interface must have the name - // passed to the constructor as the first argument + handler.TemporarilyRecordCode(location, code) - if declaredName != contractName { - // Update the code for the error pretty printing - // NOTE: only do this when an error occurs + panic(errors.NewDefaultUserError( + "invalid %s: the code must declare exactly one contract or contract interface", + declarationKind.Name(), + )) + } - handler.TemporarilyRecordCode(location, code) + // The declared contract or contract interface must have the name + // passed to the constructor as the first argument - panic(errors.NewDefaultUserError( - "invalid %s: the name argument must match the name of the declaration: got %q, expected %q", - declarationKind.Name(), - contractName, - declaredName, - )) - } + if declaredName != contractName { + // Update the code for the error pretty printing + // NOTE: only do this when an error occurs - // Validate the contract update + handler.TemporarilyRecordCode(location, code) - if isUpdate { - oldCode, err := handler.GetAccountContractCode(location) - handleContractUpdateError(err) + panic(errors.NewDefaultUserError( + "invalid %s: the name argument must match the name of the declaration: got %q, expected %q", + declarationKind.Name(), + contractName, + declaredName, + )) + } - oldProgram, err := parser.ParseProgram( - gauge, - oldCode, - parser.Config{ - IgnoreLeadingIdentifierEnabled: true, - }, - ) + // Validate the contract update - if !ignoreUpdatedProgramParserError(err) { - handleContractUpdateError(err) - } + if isUpdate { + oldCode, err := handler.GetAccountContractCode(location) + handleContractUpdateError(err) - validator := NewContractUpdateValidator( - location, - contractName, - oldProgram, - program.Program, - ) - err = validator.Validate() - handleContractUpdateError(err) - } + oldProgram, err := parser.ParseProgram( + invocation.Interpreter.SharedState.Config.MemoryGauge, + oldCode, + parser.Config{ + IgnoreLeadingIdentifierEnabled: true, + }, + ) - inter := invocation.Interpreter + if !ignoreUpdatedProgramParserError(err) { + handleContractUpdateError(err) + } - err = updateAccountContractCode( - handler, - location, - program, - code, - contractType, - constructorArguments, - constructorArgumentTypes, - updateAccountContractCodeOptions{ - createContract: !isUpdate, - }, - ) - if err != nil { - // Update the code for the error pretty printing - // NOTE: only do this when an error occurs + validator := NewContractUpdateValidator( + location, + contractName, + oldProgram, + program.Program, + ) + err = validator.Validate() + handleContractUpdateError(err) + } - handler.TemporarilyRecordCode(location, code) + inter := invocation.Interpreter - panic(err) - } + err = updateAccountContractCode( + handler, + location, + program, + code, + contractType, + constructorArguments, + constructorArgumentTypes, + updateAccountContractCodeOptions{ + createContract: !isUpdate, + }, + ) + if err != nil { + // Update the code for the error pretty printing + // NOTE: only do this when an error occurs - var eventType *sema.CompositeType + handler.TemporarilyRecordCode(location, code) - if isUpdate { - eventType = AccountContractUpdatedEventType - } else { - eventType = AccountContractAddedEventType - } + panic(err) + } - codeHashValue := CodeToHashValue(inter, code) + var eventType *sema.CompositeType - handler.EmitEvent( - inter, - eventType, - []interpreter.Value{ - addressValue, - codeHashValue, - nameValue, - }, - locationRange, - ) + if isUpdate { + eventType = AccountContractUpdatedEventType + } else { + eventType = AccountContractAddedEventType + } - return interpreter.NewDeployedContractValue( - inter, - addressValue, - nameValue, - newCodeValue, - ) + codeHashValue := CodeToHashValue(inter, code) + + handler.EmitEvent( + inter, + eventType, + []interpreter.Value{ + addressValue, + codeHashValue, + nameValue, + }, + locationRange, + ) + + return interpreter.NewDeployedContractValue( + inter, + addressValue, + nameValue, + newCodeValue, + ) +} + +func newAccountContractsTryUpdateFunction( + functionType *sema.FunctionType, + gauge common.MemoryGauge, + handler AccountContractAdditionHandler, + addressValue interpreter.AddressValue, +) *interpreter.HostFunctionValue { + return interpreter.NewHostFunctionValue( + gauge, + functionType, + func(invocation interpreter.Invocation) (deploymentResult interpreter.Value) { + var deployedContract interpreter.Value + + defer func() { + if r := recover(); r != nil { + rootError := r + for { + switch err := r.(type) { + case errors.UserError, errors.ExternalError: + // Error is ignored for now. + // Simply return with a `nil` deployed-contract + case xerrors.Wrapper: + r = err.Unwrap() + continue + default: + panic(rootError) + } + + break + } + } + + var optionalDeployedContract interpreter.OptionalValue + if deployedContract == nil { + optionalDeployedContract = interpreter.NilOptionalValue + } else { + optionalDeployedContract = interpreter.NewSomeValueNonCopying(invocation.Interpreter, deployedContract) + } + + deploymentResult = interpreter.NewDeploymentResultValue(gauge, optionalDeployedContract) + }() + + deployedContract = changeAccountContracts(invocation, handler, addressValue, true) + return }, ) } diff --git a/runtime/tests/checker/account_test.go b/runtime/tests/checker/account_test.go index d9d6677ccf..0876e90508 100644 --- a/runtime/tests/checker/account_test.go +++ b/runtime/tests/checker/account_test.go @@ -1097,6 +1097,48 @@ func TestCheckAccountContractsUpdate(t *testing.T) { `) require.NoError(t, err) }) + + t.Run("try update, unauthorized", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test(contracts: &Account.Contracts): DeploymentResult { + return contracts.tryUpdate(name: "foo", code: "012".decodeHex()) + } + `) + + errors := RequireCheckerErrors(t, err, 1) + + var invalidAccessErr *sema.InvalidAccessError + require.ErrorAs(t, errors[0], &invalidAccessErr) + assert.Equal(t, "tryUpdate", invalidAccessErr.Name) + }) + + t.Run("try update, authorized", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test(contracts: auth(Contracts) &Account.Contracts): DeploymentResult { + return contracts.tryUpdate(name: "foo", code: "012".decodeHex()) + } + `) + require.NoError(t, err) + }) + + t.Run("deployment result fields", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test(contracts: auth(Contracts) &Account.Contracts) { + let deploymentResult: DeploymentResult = contracts.tryUpdate(name: "foo", code: "012".decodeHex()) + let deployedContract: DeployedContract = deploymentResult.deployedContract! + let name: String = deployedContract.name + let address: Address = deployedContract.address + let code: [UInt8] = deployedContract.code + } + `) + require.NoError(t, err) + }) } func TestCheckAccountContractsRemove(t *testing.T) { diff --git a/runtime/tests/checker/type_inference_test.go b/runtime/tests/checker/type_inference_test.go index 8d0ef43450..5c6f69c615 100644 --- a/runtime/tests/checker/type_inference_test.go +++ b/runtime/tests/checker/type_inference_test.go @@ -1258,3 +1258,31 @@ func TestCheckCompositeSupertypeInference(t *testing.T) { assert.Equal(t, expectedType.ID(), intersectionType.ID()) }) } + +func TestCheckDeploymentResultInference(t *testing.T) { + + t.Parallel() + + code := ` + let x: DeploymentResult = getDeploymentResult() + let y: DeploymentResult = getDeploymentResult() + + // Function is just to get a 'DeploymentResult' return type. + fun getDeploymentResult(): DeploymentResult { + let v: DeploymentResult? = nil + return v! + } + + let z = [x, y] + ` + + checker, err := ParseAndCheck(t, code) + require.NoError(t, err) + + zType := RequireGlobalValue(t, checker.Elaboration, "z") + + require.IsType(t, &sema.VariableSizedType{}, zType) + variableSizedType := zType.(*sema.VariableSizedType) + + assert.Equal(t, sema.DeploymentResultType, variableSizedType.Type) +}