From e2a302cd6e0a8538184e5362af9da1b4a72ee52e Mon Sep 17 00:00:00 2001 From: Ardit Marku Date: Mon, 6 Nov 2023 16:17:14 +0200 Subject: [PATCH] Add native function declarations for all relevant Test contract functions --- runtime/sema/type.go | 14 +-- runtime/stdlib/contracts/test.cdc | 83 +++++++++++++ runtime/stdlib/test_contract.go | 23 ++-- runtime/stdlib/test_test.go | 191 ++++++++++++++++++++++++++---- 4 files changed, 268 insertions(+), 43 deletions(-) diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 251769e704..9450d9405c 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -4738,11 +4738,7 @@ func (t *CompositeType) GetMembers() map[string]MemberResolver { } func (t *CompositeType) initializeMemberResolvers() { - t.memberResolversOnce.Do(t.initializerMemberResolversFunc()) -} - -func (t *CompositeType) initializerMemberResolversFunc() func() { - return func() { + t.memberResolversOnce.Do(func() { memberResolvers := MembersMapAsResolvers(t.Members) // Check conformances. @@ -4777,13 +4773,7 @@ func (t *CompositeType) initializerMemberResolversFunc() func() { } t.memberResolvers = withBuiltinMembers(t, memberResolvers) - } -} - -func (t *CompositeType) ResolveMembers() { - if t.Members.Len() != len(t.GetMembers()) { - t.initializerMemberResolversFunc()() - } + }) } func (t *CompositeType) FieldPosition(name string, declaration ast.CompositeLikeDeclaration) ast.Position { diff --git a/runtime/stdlib/contracts/test.cdc b/runtime/stdlib/contracts/test.cdc index 70a484d54f..155015f948 100644 --- a/runtime/stdlib/contracts/test.cdc +++ b/runtime/stdlib/contracts/test.cdc @@ -473,4 +473,87 @@ contract Test { assert(found, message: "the error message did not contain the given sub-string") } + + /// Creates a matcher with a test function. + /// The test function is of type 'fun(T): Bool', + /// where 'T' is bound to 'AnyStruct'. + /// + access(all) + native fun newMatcher(_ test: fun(T): Bool): Test.Matcher {} + + /// Wraps a function call in a closure, and expects it to fail with + /// an error message that contains the given error message portion. + /// + access(all) + native fun expectFailure( + _ functionWrapper: fun(): Void, + errorMessageSubstring: String + ) {} + + /// Expect function tests a value against a matcher + /// and fails the test if it's not a match. + /// + access(all) + native fun expect(_ value: T, _ matcher: Test.Matcher) {} + + /// Returns a matcher that succeeds if the tested + /// value is equal to the given value. + /// + access(all) + native fun equal(_ value: T): Test.Matcher {} + + /// Fails the test-case if the given values are not equal, and + /// reports a message which explains how the two values differ. + /// + access(all) + native fun assertEqual(_ expected: AnyStruct, _ actual: AnyStruct) {} + + /// Returns a matcher that succeeds if the tested value is + /// an array or dictionary and the tested value contains + /// no elements. + /// + access(all) + native fun beEmpty(): Test.Matcher {} + + /// Returns a matcher that succeeds if the tested value is + /// an array or dictionary and has the given number of elements. + /// + access(all) + native fun haveElementCount(_ count: Int): Test.Matcher {} + + /// Returns a matcher that succeeds if the tested value is + /// an array that contains a value that is equal to the given + /// value, or the tested value is a dictionary that contains + /// an entry where the key is equal to the given value. + /// + access(all) + native fun contain(_ element: AnyStruct): Test.Matcher {} + + /// Returns a matcher that succeeds if the tested value + /// is a number and greater than the given number. + /// + access(all) + native fun beGreaterThan(_ value: Number): Test.Matcher {} + + /// Returns a matcher that succeeds if the tested value + /// is a number and less than the given number. + /// + access(all) + native fun beLessThan(_ value: Number): Test.Matcher {} + + /// Read a local file, and return the content as a string. + /// + access(all) + native fun readFile(_ path: String): String {} + + /// Fails the test-case if the given condition is false, + /// and reports a message which explains how the condition is false. + /// + access(all) + native fun assert(_ condition: Bool, message: String = ""): Void {} + + /// Fails the test-case with a message. + /// + access(all) + native fun fail(message: String = ""): Void {} } diff --git a/runtime/stdlib/test_contract.go b/runtime/stdlib/test_contract.go index 862f234ac1..3e3eb7e517 100644 --- a/runtime/stdlib/test_contract.go +++ b/runtime/stdlib/test_contract.go @@ -64,13 +64,12 @@ var testTypeAssertFunctionType = &sema.FunctionType{ TypeAnnotation: sema.BoolTypeAnnotation, }, { - Identifier: "message", - TypeAnnotation: sema.StringTypeAnnotation, + Identifier: "message", + TypeAnnotation: sema.StringTypeAnnotation, + DefaultArgument: sema.StringType, }, }, ReturnTypeAnnotation: sema.VoidTypeAnnotation, - // `message` parameter is optional - Arity: &sema.Arity{Min: 1, Max: 2}, } var testTypeAssertFunction = interpreter.NewUnmeteredHostFunctionValue( @@ -181,13 +180,12 @@ var testTypeFailFunctionType = &sema.FunctionType{ Purity: sema.FunctionPurityView, Parameters: []sema.Parameter{ { - Identifier: "message", - TypeAnnotation: sema.StringTypeAnnotation, + Identifier: "message", + TypeAnnotation: sema.StringTypeAnnotation, + DefaultArgument: sema.StringType, }, }, ReturnTypeAnnotation: sema.VoidTypeAnnotation, - // `message` parameter is optional - Arity: &sema.Arity{Min: 0, Max: 1}, } var testTypeFailFunction = interpreter.NewUnmeteredHostFunctionValue( @@ -915,7 +913,10 @@ func newTestContractType() *TestContractType { program, err := parser.ParseProgram( nil, contracts.TestContract, - parser.Config{}, + parser.Config{ + NativeModifierEnabled: true, + TypeParametersEnabled: true, + }, ) if err != nil { panic(err) @@ -933,7 +934,8 @@ func newTestContractType() *TestContractType { BaseValueActivationHandler: func(_ common.Location) *sema.VariableActivation { return activation }, - AccessCheckMode: sema.AccessCheckModeStrict, + AccessCheckMode: sema.AccessCheckModeStrict, + AllowNativeDeclarations: true, }, ) if err != nil { @@ -1160,7 +1162,6 @@ func newTestContractType() *TestContractType { ty.expectFailureFunction = newTestTypeExpectFailureFunction( expectFailureFunctionType, ) - compositeType.ResolveMembers() return ty } diff --git a/runtime/stdlib/test_test.go b/runtime/stdlib/test_test.go index a2c00688a8..31316da3fb 100644 --- a/runtime/stdlib/test_test.go +++ b/runtime/stdlib/test_test.go @@ -57,7 +57,10 @@ func newTestContractInterpreterWithTestFramework( program, err := parser.ParseProgram( nil, []byte(code), - parser.Config{}, + parser.Config{ + NativeModifierEnabled: true, + TypeParametersEnabled: true, + }, ) require.NoError(t, err) @@ -73,7 +76,8 @@ func newTestContractInterpreterWithTestFramework( BaseValueActivationHandler: func(_ common.Location) *sema.VariableActivation { return baseValueActivation }, - AccessCheckMode: sema.AccessCheckModeStrict, + AccessCheckMode: sema.AccessCheckModeStrict, + AllowNativeDeclarations: true, ImportHandler: func( checker *sema.Checker, importedLocation common.Location, @@ -714,6 +718,80 @@ func TestTestEqualMatcher(t *testing.T) { }) } +func TestAssertFunction(t *testing.T) { + t.Parallel() + + const script = ` + import Test + + access(all) + fun testAssertWithNoArgs() { + Test.assert(true) + } + + access(all) + fun testAssertWithNoArgsFail() { + Test.assert(false) + } + + access(all) + fun testAssertWithMessage() { + Test.assert(true, message: "some reason") + } + + access(all) + fun testAssertWithMessageFail() { + Test.assert(false, message: "some reason") + } + ` + + inter, err := newTestContractInterpreter(t, script) + require.NoError(t, err) + + _, err = inter.Invoke("testAssertWithNoArgs") + require.NoError(t, err) + + _, err = inter.Invoke("testAssertWithNoArgsFail") + require.Error(t, err) + assert.ErrorContains(t, err, "assertion failed") + + _, err = inter.Invoke("testAssertWithMessage") + require.NoError(t, err) + + _, err = inter.Invoke("testAssertWithMessageFail") + require.Error(t, err) + require.ErrorContains(t, err, "assertion failed: some reason") +} + +func TestFailFunction(t *testing.T) { + t.Parallel() + + const script = ` + import Test + + access(all) + fun testFailWithoutMessage() { + Test.fail() + } + + access(all) + fun testFailWithMessage() { + Test.fail(message: "some error") + } + ` + + inter, err := newTestContractInterpreter(t, script) + require.NoError(t, err) + + _, err = inter.Invoke("testFailWithoutMessage") + require.Error(t, err) + require.ErrorContains(t, err, "assertion failed") + + _, err = inter.Invoke("testFailWithMessage") + require.Error(t, err) + require.ErrorContains(t, err, "assertion failed: some error") +} + func TestAssertEqual(t *testing.T) { t.Parallel() @@ -2209,7 +2287,7 @@ func TestBlockchain(t *testing.T) { Test.expect(events, Test.beEmpty()) } - ` + ` eventsInvoked := false @@ -2258,7 +2336,7 @@ func TestBlockchain(t *testing.T) { Test.expect(events, Test.beEmpty()) } - ` + ` eventsInvoked := false @@ -2303,7 +2381,7 @@ func TestBlockchain(t *testing.T) { fun test() { Test.reset(to: 5) } - ` + ` resetInvoked := false @@ -2337,7 +2415,7 @@ func TestBlockchain(t *testing.T) { fun test() { Test.reset(to: 5.5) } - ` + ` resetInvoked := false @@ -2370,7 +2448,7 @@ func TestBlockchain(t *testing.T) { let timeDelta = Fix64(35 * 24 * 60 * 60) Test.moveTime(by: timeDelta) } - ` + ` moveTimeInvoked := false @@ -2407,7 +2485,7 @@ func TestBlockchain(t *testing.T) { let timeDelta = Fix64(35 * 24 * 60 * 60) * -1.0 Test.moveTime(by: timeDelta) } - ` + ` moveTimeInvoked := false @@ -2441,7 +2519,7 @@ func TestBlockchain(t *testing.T) { fun testMoveTime() { Test.moveTime(by: 3000) } - ` + ` moveTimeInvoked := false @@ -2471,7 +2549,7 @@ func TestBlockchain(t *testing.T) { fun test() { Test.createSnapshot(name: "adminCreated") } - ` + ` createSnapshotInvoked := false @@ -2507,7 +2585,7 @@ func TestBlockchain(t *testing.T) { fun test() { Test.createSnapshot(name: "adminCreated") } - ` + ` createSnapshotInvoked := false @@ -2544,7 +2622,7 @@ func TestBlockchain(t *testing.T) { Test.createSnapshot(name: "adminCreated") Test.loadSnapshot(name: "adminCreated") } - ` + ` loadSnapshotInvoked := false @@ -2586,7 +2664,7 @@ func TestBlockchain(t *testing.T) { Test.createSnapshot(name: "adminCreated") Test.loadSnapshot(name: "contractDeployed") } - ` + ` loadSnapshotInvoked := false @@ -2633,7 +2711,7 @@ func TestBlockchain(t *testing.T) { Test.expect(err, Test.beNil()) } - ` + ` deployContractInvoked := false @@ -2687,7 +2765,7 @@ func TestBlockchain(t *testing.T) { err!.message ) } - ` + ` deployContractInvoked := false @@ -2726,9 +2804,9 @@ func TestBlockchain(t *testing.T) { access(all) fun test() { let account = Test.getAccount(0x0000000000000009) - Test.assertEqual(0x0000000000000009 as Address, account.address) + Test.assertEqual(Address(0x0000000000000009), account.address) } - ` + ` getAccountInvoked := false @@ -2774,7 +2852,7 @@ func TestBlockchain(t *testing.T) { fun test() { let account = Test.getAccount(0x0000000000000009) } - ` + ` getAccountInvoked := false @@ -2804,6 +2882,79 @@ func TestBlockchain(t *testing.T) { assert.True(t, getAccountInvoked) }) + t.Run("readFile", func(t *testing.T) { + t.Parallel() + + const script = ` + import Test + + access(all) + fun test() { + let content = Test.readFile("some_file.cdc") + Test.assertEqual("Hey there!", content) + } + ` + + readFileInvoked := false + + testFramework := &mockedTestFramework{ + emulatorBackend: func() stdlib.Blockchain { + return &mockedBlockchain{} + }, + readFile: func(path string) (string, error) { + readFileInvoked = true + assert.Equal(t, "some_file.cdc", path) + + return "Hey there!", nil + }, + } + + inter, err := newTestContractInterpreterWithTestFramework(t, script, testFramework) + require.NoError(t, err) + + _, err = inter.Invoke("test") + require.NoError(t, err) + + assert.True(t, readFileInvoked) + }) + + t.Run("readFile with failure", func(t *testing.T) { + t.Parallel() + + const script = ` + import Test + + access(all) + fun test() { + let content = Test.readFile("some_file.cdc") + Test.assertEqual("Hey there!", content) + } + ` + + readFileInvoked := false + + testFramework := &mockedTestFramework{ + emulatorBackend: func() stdlib.Blockchain { + return &mockedBlockchain{} + }, + readFile: func(path string) (string, error) { + readFileInvoked = true + assert.Equal(t, "some_file.cdc", path) + + return "", fmt.Errorf("could not read file: %s", path) + }, + } + + inter, err := newTestContractInterpreterWithTestFramework(t, script, testFramework) + require.NoError(t, err) + + _, err = inter.Invoke("test") + require.Error(t, err) + assert.ErrorContains(t, err, "could not read file: some_file.cdc") + + assert.True(t, readFileInvoked) + }) + // TODO: Add more tests for the remaining functions. } @@ -2820,9 +2971,9 @@ func TestBlockchainAccount(t *testing.T) { access(all) fun test() { let account = Test.createAccount() - assert(account.address == 0x0100000000000000) + Test.assertEqual(Address(0x0100000000000000), account.address) } - ` + ` testFramework := &mockedTestFramework{ emulatorBackend: func() stdlib.Blockchain {