diff --git a/cannon/mipsevm/tests/evm_common_test.go b/cannon/mipsevm/tests/evm_common_test.go index ad93014450dd..402fe89bd321 100644 --- a/cannon/mipsevm/tests/evm_common_test.go +++ b/cannon/mipsevm/tests/evm_common_test.go @@ -201,14 +201,20 @@ func TestEVMSingleStep_Operators(t *testing.T) { goVm := v.VMFactory(nil, os.Stdout, os.Stderr, testutil.CreateLogger(), testutil.WithRandomization(int64(i)), testutil.WithPC(0), testutil.WithNextPC(4)) state := goVm.GetState() var insn uint32 + var baseReg uint32 = 17 + var rtReg uint32 + var rdReg uint32 if tt.isImm { - insn = tt.opcode<<26 | uint32(17)<<21 | uint32(8)<<16 | uint32(tt.imm) - state.GetRegistersRef()[8] = tt.rt - state.GetRegistersRef()[17] = tt.rs + rtReg = 8 + insn = tt.opcode<<26 | baseReg<<21 | rtReg<<16 | uint32(tt.imm) + state.GetRegistersRef()[rtReg] = tt.rt + state.GetRegistersRef()[baseReg] = tt.rs } else { - insn = uint32(17)<<21 | uint32(18)<<16 | uint32(8)<<11 | tt.funct - state.GetRegistersRef()[17] = tt.rs - state.GetRegistersRef()[18] = tt.rt + rtReg = 18 + rdReg = 8 + insn = baseReg<<21 | rtReg<<16 | rdReg<<11 | tt.funct + state.GetRegistersRef()[baseReg] = tt.rs + state.GetRegistersRef()[rtReg] = tt.rt } state.GetMemory().SetMemory(0, insn) step := state.GetStep() @@ -218,8 +224,85 @@ func TestEVMSingleStep_Operators(t *testing.T) { expected.Step += 1 expected.PC = 4 expected.NextPC = 8 - expected.Registers[8] = tt.expectRes + if tt.isImm { + expected.Registers[rtReg] = tt.expectRes + } else { + expected.Registers[rdReg] = tt.expectRes + } + + stepWitness, err := goVm.Step(true) + require.NoError(t, err) + + // Check expectations + expected.Validate(t, state) + testutil.ValidateEVM(t, stepWitness, step, goVm, v.StateHashFn, v.Contracts, tracer) + }) + } + } +} + +func TestEVMSingleStep_LoadStore(t *testing.T) { + var tracer *tracing.Hooks + + versions := GetMipsVersionTestCases(t) + cases := []struct { + name string + rs uint32 + rt uint32 + isUnAligned bool + opcode uint32 + memVal uint32 + expectMemVal uint32 + expectRes uint32 + }{ + {name: "lb", opcode: uint32(0x20), memVal: uint32(0x12_00_00_00), expectRes: uint32(0x12)}, // lb $t0, 4($t1) + {name: "lh", opcode: uint32(0x21), memVal: uint32(0x12_23_00_00), expectRes: uint32(0x12_23)}, // lh $t0, 4($t1) + {name: "lw", opcode: uint32(0x23), memVal: uint32(0x12_23_45_67), expectRes: uint32(0x12_23_45_67)}, // lw $t0, 4($t1) + {name: "lbu", opcode: uint32(0x24), memVal: uint32(0x12_23_00_00), expectRes: uint32(0x12)}, // lbu $t0, 4($t1) + {name: "lhu", opcode: uint32(0x25), memVal: uint32(0x12_23_00_00), expectRes: uint32(0x12_23)}, // lhu $t0, 4($t1) + {name: "lwl", opcode: uint32(0x22), rt: uint32(0xaa_bb_cc_dd), memVal: uint32(0x12_34_56_78), expectRes: uint32(0x12_34_56_78)}, // lwl $t0, 4($t1) + {name: "lwl unaligned address", opcode: uint32(0x22), rt: uint32(0xaa_bb_cc_dd), isUnAligned: true, memVal: uint32(0x12_34_56_78), expectRes: uint32(0x34_56_78_dd)}, // lwl $t0, 5($t1) + {name: "lwr", opcode: uint32(0x26), rt: uint32(0xaa_bb_cc_dd), memVal: uint32(0x12_34_56_78), expectRes: uint32(0xaa_bb_cc_12)}, // lwr $t0, 4($t1) + {name: "lwr unaligned address", opcode: uint32(0x26), rt: uint32(0xaa_bb_cc_dd), isUnAligned: true, memVal: uint32(0x12_34_56_78), expectRes: uint32(0xaa_bb_12_34)}, // lwr $t0, 5($t1) + {name: "sb", opcode: uint32(0x28), rt: uint32(0xaa_bb_cc_dd), expectMemVal: uint32(0xdd_00_00_00)}, // sb $t0, 4($t1) + {name: "sh", opcode: uint32(0x29), rt: uint32(0xaa_bb_cc_dd), expectMemVal: uint32(0xcc_dd_00_00)}, // sh $t0, 4($t1) + {name: "swl", opcode: uint32(0x2a), rt: uint32(0xaa_bb_cc_dd), expectMemVal: uint32(0xaa_bb_cc_dd)}, // swl $t0, 4($t1) + {name: "sw", opcode: uint32(0x2b), rt: uint32(0xaa_bb_cc_dd), expectMemVal: uint32(0xaa_bb_cc_dd)}, // sw $t0, 4($t1) + {name: "swr unaligned address", opcode: uint32(0x2e), rt: uint32(0xaa_bb_cc_dd), isUnAligned: true, expectMemVal: uint32(0xcc_dd_00_00)}, // swr $t0, 5($t1) + } + + var t1 uint32 = 0x100 + var baseReg uint32 = 9 + var rtReg uint32 = 8 + for _, v := range versions { + for i, tt := range cases { + testName := fmt.Sprintf("%v (%v)", tt.name, v.Name) + t.Run(testName, func(t *testing.T) { + goVm := v.VMFactory(nil, os.Stdout, os.Stderr, testutil.CreateLogger(), testutil.WithRandomization(int64(i)), testutil.WithPC(0), testutil.WithNextPC(4)) + state := goVm.GetState() + var insn uint32 + imm := uint32(0x4) + if tt.isUnAligned { + imm = uint32(0x5) + } + + insn = tt.opcode<<26 | baseReg<<21 | rtReg<<16 | imm + state.GetRegistersRef()[rtReg] = tt.rt + state.GetRegistersRef()[baseReg] = t1 + + state.GetMemory().SetMemory(0, insn) + state.GetMemory().SetMemory(t1+4, tt.memVal) + step := state.GetStep() + // Setup expectations + expected := testutil.NewExpectedState(state) + expected.ExpectStep() + + if tt.expectMemVal != 0 { + expected.ExpectMemoryWrite(t1+4, tt.expectMemVal) + } else { + expected.Registers[rtReg] = tt.expectRes + } stepWitness, err := goVm.Step(true) require.NoError(t, err)