Skip to content

Commit

Permalink
Add a new fusion instruction and remove unnecessary ones
Browse files Browse the repository at this point in the history
After constant propagation and constant folding, auipc + addi and lui +
addi become 2 lui instuctions. Moreover, we discover a sequence of lui
instrution existing in optimized basic block, so we add a new fusion
function to handle multiple lui instructions.
  • Loading branch information
qwe661234 committed Sep 24, 2023
1 parent 7c71c3f commit 4dbfe8d
Showing 1 changed file with 155 additions and 127 deletions.
282 changes: 155 additions & 127 deletions src/emulate.c
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,7 @@ static uint32_t last_pc = 0;
_(fuse3) \
_(fuse4) \
_(fuse5) \
_(fuse6) \
_(fuse7)
_(fuse6)

enum {
rv_insn_fuse0 = N_RV_INSNS,
Expand All @@ -396,24 +395,26 @@ enum {
#undef _
};

/* AUIPC + ADDI */
/* multiple lui */
static bool do_fuse1(riscv_t *rv, const rv_insn_t *ir)
{
rv->csr_cycle += 2;
rv->X[ir->rd] = rv->PC + ir->imm;
rv->X[ir->rs1] = rv->X[ir->rd] + ir->imm2;
rv->PC += 2 * ir->insn_len;
rv->csr_cycle += ir->imm2;
for (int i = 0; i < ir->imm2; i++) {
const rv_insn_t *cur_ir = ir + i;
rv->X[cur_ir->rd] = cur_ir->imm;
}
rv->PC += ir->imm2 * ir->insn_len;
if (unlikely(RVOP_NO_NEXT(ir)))
return true;
const rv_insn_t *next = ir + 2;
const rv_insn_t *next = ir + ir->imm2;
MUST_TAIL return next->impl(rv, next);
}

/* AUIPC + ADD */
/* LUI + ADD */
static bool do_fuse2(riscv_t *rv, const rv_insn_t *ir)
{
rv->csr_cycle += 2;
rv->X[ir->rd] = rv->PC + ir->imm;
rv->X[ir->rd] = ir->imm;
rv->X[ir->rs2] = rv->X[ir->rd] + rv->X[ir->rs1];
rv->PC += 2 * ir->insn_len;
if (unlikely(RVOP_NO_NEXT(ir)))
Expand Down Expand Up @@ -468,21 +469,8 @@ static bool do_fuse4(riscv_t *rv, const rv_insn_t *ir)
MUST_TAIL return next->impl(rv, next);
}

/* LUI + ADDI */
static bool do_fuse5(riscv_t *rv, const rv_insn_t *ir)
{
rv->csr_cycle += 2;
rv->X[ir->rd] = ir->imm;
rv->X[ir->rs1] = ir->imm + ir->imm2;
rv->PC += 2 * ir->insn_len;
if (unlikely(RVOP_NO_NEXT(ir)))
return true;
const rv_insn_t *next = ir + 2;
MUST_TAIL return next->impl(rv, next);
}

/* memset */
static bool do_fuse6(riscv_t *rv, const rv_insn_t *ir)
static bool do_fuse5(riscv_t *rv, const rv_insn_t *ir)
{
rv->csr_cycle += 2;
memory_t *m = ((state_t *) rv->userdata)->mem;
Expand All @@ -496,7 +484,7 @@ static bool do_fuse6(riscv_t *rv, const rv_insn_t *ir)
}

/* memcpy */
static bool do_fuse7(riscv_t *rv, const rv_insn_t *ir)
static bool do_fuse6(riscv_t *rv, const rv_insn_t *ir)
{
rv->csr_cycle += 2;
memory_t *m = ((state_t *) rv->userdata)->mem;
Expand Down Expand Up @@ -572,6 +560,7 @@ static void block_translate(riscv_t *rv, block_t *block)
break;
}
ir->impl = dispatch_table[ir->opcode];
ir->pc = block->pc_end;
/* compute the end of pc */
block->pc_end += ir->insn_len;
block->n_insn++;
Expand Down Expand Up @@ -834,102 +823,102 @@ static bool detect_memcpy(riscv_t *rv, int lib)
return true;
}

static bool libc_substitute(riscv_t *rv, block_t *block)
{
rv_insn_t *ir = block->ir, *next_ir = NULL;
switch (ir->opcode) {
case rv_insn_addi:
/* Compare the target block with the first basic block of
* memset/memcpy, if two block is match, we would extract the
* instruction sequence starting from the pc_start of the basic
* block and then compare it with the pre-recorded memset/memcpy
* instruction sequence.
*/
if (ir->imm == 15 && ir->rd == rv_reg_t1 && ir->rs1 == rv_reg_zero) {
next_ir = ir + 1;
if (next_ir->opcode == rv_insn_addi && next_ir->rd == rv_reg_a4 &&
next_ir->rs1 == rv_reg_a0 && next_ir->rs2 == rv_reg_zero) {
next_ir = next_ir + 1;
if (next_ir->opcode == rv_insn_bgeu && next_ir->imm == 60 &&
next_ir->rs1 == rv_reg_t1 && next_ir->rs2 == rv_reg_a2) {
if (detect_memset(rv, 1)) {
ir->opcode = rv_insn_fuse5;
ir->impl = dispatch_table[ir->opcode];
ir->tailcall = true;
return true;
};
}
}
} else if (ir->imm == 0 && ir->rd == rv_reg_t1 &&
ir->rs1 == rv_reg_a0) {
next_ir = ir + 1;
if (next_ir->opcode == rv_insn_beq && next_ir->rs1 == rv_reg_a2 &&
next_ir->rs2 == rv_reg_zero) {
if (next_ir->imm == 20 && detect_memset(rv, 2)) {
ir->opcode = rv_insn_fuse5;
ir->impl = dispatch_table[ir->opcode];
ir->tailcall = true;
return true;
} else if (next_ir->imm == 28 && detect_memcpy(rv, 2)) {
ir->opcode = rv_insn_fuse6;
ir->impl = dispatch_table[ir->opcode];
ir->tailcall = true;
return true;
};
}
}
break;
case rv_insn_xor:
/* Compare the target block with the first basic block of memcpy, if
* two block is match, we would extract the instruction sequence
* starting from the pc_start of the basic block and then compare
* it with the pre-recorded memcpy instruction sequence.
*/
if (ir->rd == rv_reg_a5 && ir->rs1 == rv_reg_a0 &&
ir->rs2 == rv_reg_a1) {
next_ir = ir + 1;
if (next_ir->opcode == rv_insn_andi && next_ir->imm == 3 &&
next_ir->rd == rv_reg_a5 && next_ir->rs1 == rv_reg_a5) {
next_ir = next_ir + 1;
if (next_ir->opcode == rv_insn_add &&
next_ir->rd == rv_reg_a7 && next_ir->rs1 == rv_reg_a0 &&
next_ir->rs2 == rv_reg_a2) {
next_ir = next_ir + 1;
if (next_ir->opcode == rv_insn_bne && next_ir->imm == 104 &&
next_ir->rs1 == rv_reg_a5 &&
next_ir->rs2 == rv_reg_zero) {
if (detect_memcpy(rv, 1)) {
ir->opcode = rv_insn_fuse6;
ir->impl = dispatch_table[ir->opcode];
ir->tailcall = true;
return true;
};
}
}
}
}
break;
/* TODO: inject other frequently-used standard library */
}
return false;
}

/* Check if instructions in a block match a specific pattern. If they do,
* rewrite them as fused instructions.
*
* Strategies are being devised to increase the number of instructions that
* match the pattern, including possible instruction reordering.
*/
static void match_pattern(riscv_t *rv, block_t *block)
static void match_pattern(block_t *block)
{
for (uint32_t i = 0; i < block->n_insn - 1; i++) {
rv_insn_t *ir = block->ir + i, *next_ir = NULL;
int32_t count = 0, sign = 1;
switch (ir->opcode) {
case rv_insn_addi:
/* Compare the target block with the first basic block of
* memset/memcpy, if two block is match, we would extract the
* instruction sequence starting from the pc_start of the basic
* block and then compare it with the pre-recorded memset/memcpy
* instruction sequence.
*/
if (ir->imm == 15 && ir->rd == rv_reg_t1 &&
ir->rs1 == rv_reg_zero) {
next_ir = ir + 1;
if (next_ir->opcode == rv_insn_addi &&
next_ir->rd == rv_reg_a4 && next_ir->rs1 == rv_reg_a0 &&
next_ir->rs2 == rv_reg_zero) {
next_ir = next_ir + 1;
if (next_ir->opcode == rv_insn_bgeu && next_ir->imm == 60 &&
next_ir->rs1 == rv_reg_t1 &&
next_ir->rs2 == rv_reg_a2) {
if (detect_memset(rv, 1)) {
ir->opcode = rv_insn_fuse6;
ir->impl = dispatch_table[ir->opcode];
ir->tailcall = true;
};
}
}
} else if (ir->imm == 0 && ir->rd == rv_reg_t1 &&
ir->rs1 == rv_reg_a0) {
next_ir = ir + 1;
if (next_ir->opcode == rv_insn_beq &&
next_ir->rs1 == rv_reg_a2 && next_ir->rs2 == rv_reg_zero) {
if (next_ir->imm == 20 && detect_memset(rv, 2)) {
ir->opcode = rv_insn_fuse6;
ir->impl = dispatch_table[ir->opcode];
ir->tailcall = true;
} else if (next_ir->imm == 28 && detect_memcpy(rv, 2)) {
ir->opcode = rv_insn_fuse7;
ir->impl = dispatch_table[ir->opcode];
ir->tailcall = true;
};
}
}
break;
case rv_insn_xor:
/* Compare the target block with the first basic block of memcpy, if
* two block is match, we would extract the instruction sequence
* starting from the pc_start of the basic block and then compare
* it with the pre-recorded memcpy instruction sequence.
*/
if (ir->rd == rv_reg_a5 && ir->rs1 == rv_reg_a0 &&
ir->rs2 == rv_reg_a1) {
next_ir = ir + 1;
if (next_ir->opcode == rv_insn_andi && next_ir->imm == 3 &&
next_ir->rd == rv_reg_a5 && next_ir->rs1 == rv_reg_a5) {
next_ir = next_ir + 1;
if (next_ir->opcode == rv_insn_add &&
next_ir->rd == rv_reg_a7 && next_ir->rs1 == rv_reg_a0 &&
next_ir->rs2 == rv_reg_a2) {
next_ir = next_ir + 1;
if (next_ir->opcode == rv_insn_bne &&
next_ir->imm == 104 && next_ir->rs1 == rv_reg_a5 &&
next_ir->rs2 == rv_reg_zero) {
if (detect_memcpy(rv, 1)) {
ir->opcode = rv_insn_fuse7;
ir->impl = dispatch_table[ir->opcode];
ir->tailcall = true;
};
}
}
}
}
break;
case rv_insn_auipc:
case rv_insn_lui:
next_ir = ir + 1;
if (next_ir->opcode == rv_insn_addi && ir->rd == next_ir->rs1) {
/* The destination register of the AUIPC instruction is the
* same as the source register 1 of the next instruction ADDI.
*/
ir->opcode = rv_insn_fuse1;
ir->rs1 = next_ir->rd;
ir->imm2 = next_ir->imm;
ir->impl = dispatch_table[ir->opcode];
ir->tailcall = next_ir->tailcall;
} else if (next_ir->opcode == rv_insn_add &&
ir->rd == next_ir->rs2) {
/* The destination register of the AUIPC instruction is the
if (next_ir->opcode == rv_insn_add && ir->rd == next_ir->rs2) {
/* The destination register of the LUI instruction is the
* same as the source register 2 of the next instruction ADD.
*/
ir->opcode = rv_insn_fuse2;
Expand All @@ -938,13 +927,28 @@ static void match_pattern(riscv_t *rv, block_t *block)
ir->impl = dispatch_table[ir->opcode];
} else if (next_ir->opcode == rv_insn_add &&
ir->rd == next_ir->rs1) {
/* The destination register of the AUIPC instruction is the
/* The destination register of the LUI instruction is the
* same as the source register 1 of the next instruction ADD.
*/
ir->opcode = rv_insn_fuse2;
ir->rs2 = next_ir->rd;
ir->rs1 = next_ir->rs2;
ir->impl = dispatch_table[ir->opcode];
} else {
count = 1;
next_ir = ir + 1;
while (1) {
if (next_ir->opcode != rv_insn_lui)
break;
next_ir->opcode = rv_insn_nop;
count++;
next_ir += 1;
}
if (count > 1) {
ir->imm2 = count;
ir->opcode = rv_insn_fuse1;
ir->impl = dispatch_table[ir->opcode];
}
}
break;

Expand All @@ -957,25 +961,46 @@ static void match_pattern(riscv_t *rv, block_t *block)
case rv_insn_lw:
COMBINE_MEM_OPS(1);
break;
case rv_insn_lui:
next_ir = ir + 1;
if (next_ir->opcode == rv_insn_addi && ir->rd == next_ir->rs1) {
/* The destination register of the LUI instruction is the
* same as the source register 1 of the next instruction ADDI.
*/
ir->opcode = rv_insn_fuse5;
ir->rs1 = next_ir->rd;
ir->imm2 = next_ir->imm;
ir->impl = dispatch_table[ir->opcode];
ir->tailcall = next_ir->tailcall;
}
break;
/* TODO: mixture of SW and LW */
/* TODO: reorder insturction to match pattern */
}
}
}


typedef struct const_opt_info {
bool is_constant[32];
uint32_t const_val[32];
} const_opt_info_t;

#define CONSTOPT(inst, code) \
static void const_opt_##inst(UNUSED rv_insn_t *ir, \
UNUSED const_opt_info_t *const_opt_info) \
{ \
code; \
}

#include "const_opt.c"
/* clang-format off */
static const void *const_opt_table[] = {
/* RV32 instructions */
#define _(inst, can_branch, reg_mask) [rv_insn_##inst] = const_opt_##inst,
RV_INSN_LIST
#undef _
};
/* clang-format on */
typedef void (*opt_func)(rv_insn_t *, const_opt_info_t *);
static void constant_opt(block_t *block)
{
const_opt_info_t const_opt_info;
memset(&const_opt_info, 0, sizeof(const_opt_info));
const_opt_info.is_constant[0] = true;
for (uint32_t i = 0; i < block->n_insn; i++) {
rv_insn_t *ir = block->ir + i;
((opt_func) const_opt_table[ir->opcode])(ir, &const_opt_info);
}
}

static block_t *prev = NULL;
static block_t *block_find_or_translate(riscv_t *rv)
{
Expand All @@ -994,12 +1019,15 @@ static block_t *block_find_or_translate(riscv_t *rv)

/* translate the basic block */
block_translate(rv, next);

if (!libc_substitute(rv, next)) {
constant_opt(next);
#if RV32_HAS(GDBSTUB)
if (likely(!rv->debug_mode))
if (likely(!rv->debug_mode))
#endif
/* macro operation fusion */
match_pattern(rv, next);

/* macro operation fusion */
match_pattern(next);
}
/* insert the block into block map */
block_insert(&rv->block_map, next);

Expand Down

0 comments on commit 4dbfe8d

Please sign in to comment.