diff --git a/arch/arm64/arm64_assembler.cpp b/arch/arm64/arm64_assembler.cpp index c3fcd2a..6ddae72 100644 --- a/arch/arm64/arm64_assembler.cpp +++ b/arch/arm64/arm64_assembler.cpp @@ -35,6 +35,7 @@ uint32_t NOP = 0xd503201f; uint32_t CRASH = 0xd63f03e0; uint32_t MRS_X0_NZCV = 0xd53b4200; +uint32_t MRS_X1_NZCV = 0xd53b4201; uint32_t MSR_NZCV_X0 = 0xd51b4200; // strip pac tag @@ -271,15 +272,12 @@ void Arm64Assembler::InstrumentGlobalIndirect(ModuleInfo *module, please use 'local'"); } -// converts an indirect jump/call into a MOV instruction -// which moves the target of the indirect call into the X0 register -// and writes this instruction into the code buffer -// + // returns the register number which is used by the original code // to perform the branch. -uint8_t Arm64Assembler::MovIndirectTarget(ModuleInfo *module, Instruction &inst) { +uint8_t Arm64Assembler::GetIndirectTarget(Instruction &inst, uint8_t *is_pac) { Register target_address_reg = Register::X0; - bool strip_pac = false; + uint8_t strip_pac = 0; switch (inst.instr.opcode) { case arm64::Opcode::kBraa: @@ -291,7 +289,7 @@ uint8_t Arm64Assembler::MovIndirectTarget(ModuleInfo *module, Instruction &inst) case arm64::Opcode::kBlrab: case arm64::Opcode::kBlraaz: case arm64::Opcode::kBlrabz: - strip_pac = true; + strip_pac = 1; // fall through case arm64::Opcode::kBr: case arm64::Opcode::kBlr: @@ -301,7 +299,7 @@ uint8_t Arm64Assembler::MovIndirectTarget(ModuleInfo *module, Instruction &inst) case arm64::Opcode::kRetaa: case arm64::Opcode::kRetab: - strip_pac = true; + strip_pac = 1; // fall through case arm64::Opcode::kRet: target_address_reg = Register::LR; @@ -311,12 +309,19 @@ uint8_t Arm64Assembler::MovIndirectTarget(ModuleInfo *module, Instruction &inst) FATAL("not implemented yet"); } - uint32_t mov_instr = mov(Register::X0, target_address_reg); + if(is_pac) *is_pac = strip_pac; + return static_cast(target_address_reg); +} + +// converts an indirect jump/call into a MOV instruction +// which moves the target of the indirect call into the X0 register +// and writes this instruction into the code buffer +void Arm64Assembler::MovIndirectTarget(ModuleInfo *module, uint8_t target_address_reg, uint8_t is_pac) { + uint32_t mov_instr = mov(Register::X0, static_cast(target_address_reg)); tinyinst_.WriteCode(module, &mov_instr, sizeof(mov_instr)); - if (strip_pac) { + if (is_pac) { tinyinst_.WriteCode(module, &XPACI_X0, sizeof(XPACI_X0)); } - return static_cast(target_address_reg); } // translates indirect jump or call @@ -334,17 +339,25 @@ void Arm64Assembler::InstrumentLocalIndirect(ModuleInfo *module, // it in the if clause. OffsetStack(module, -32); + uint8_t is_pac = 0; + uint8_t branch_register_number = GetIndirectTarget(inst, &is_pac); + // stack layout // x0 // x1 // alu flags WriteRegStack(module, Register::X1, 16); WriteRegStack(module, Register::X0, 8); - tinyinst_.WriteCode(module, &MRS_X0_NZCV, sizeof(MRS_X0_NZCV)); - WriteRegStack(module, Register::X0, 0); + if(branch_register_number != 0) { + tinyinst_.WriteCode(module, &MRS_X0_NZCV, sizeof(MRS_X0_NZCV)); + WriteRegStack(module, Register::X0, 0); + } else { + tinyinst_.WriteCode(module, &MRS_X1_NZCV, sizeof(MRS_X1_NZCV)); + WriteRegStack(module, Register::X1, 0); + } // Emit instructions that load the target address to the X0 register. - uint8_t branch_register_number = MovIndirectTarget(module, inst); + MovIndirectTarget(module, branch_register_number, is_pac); // InstrumentLocalIndirect iterates through a linked list until the it // finds the code that was generated for the target address. Jumps are diff --git a/arch/arm64/arm64_assembler.h b/arch/arm64/arm64_assembler.h index 4b45457..5f55586 100644 --- a/arch/arm64/arm64_assembler.h +++ b/arch/arm64/arm64_assembler.h @@ -66,7 +66,8 @@ class Arm64Assembler : public Assembler { bool is_signed, uint64_t value); private: - uint8_t MovIndirectTarget(ModuleInfo *module, Instruction &inst); + uint8_t GetIndirectTarget(Instruction &inst, uint8_t *is_pac); + void MovIndirectTarget(ModuleInfo *module, uint8_t target_address_reg, uint8_t is_pac); void ReadStack(ModuleInfo *module, int32_t offset); void WriteStack(ModuleInfo *module, int32_t offset);