From defd120823429b985743c136edeff36bb451e504 Mon Sep 17 00:00:00 2001 From: Jerry Zhao Date: Mon, 8 Mar 2021 14:01:59 -0800 Subject: [PATCH] Make LoopMatmul run-ahead to avoid RAW issues --- src/main/scala/gemmini/LoopMatmul.scala | 86 ++++++++++++++++--------- 1 file changed, 57 insertions(+), 29 deletions(-) diff --git a/src/main/scala/gemmini/LoopMatmul.scala b/src/main/scala/gemmini/LoopMatmul.scala index 5fd2e56c..79c435e0 100644 --- a/src/main/scala/gemmini/LoopMatmul.scala +++ b/src/main/scala/gemmini/LoopMatmul.scala @@ -28,9 +28,11 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In val io = IO(new Bundle { val req = Flipped(Decoupled(new LoopMatmulLdAReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, concurrent_loops))) val cmd = Decoupled(Output(new RoCCCommand)) + val prefetch = Output(Valid(new RoCCCommand)) val i = Output(UInt(iterator_bitwidth.W)) val k = Output(UInt(iterator_bitwidth.W)) val idle = Output(Bool()) + val rob_overloaded = Input(Bool()) val loop_id = Output(UInt(log2Up(concurrent_loops).W)) }) @@ -71,17 +73,31 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In mvin_cmd.rs1 := dram_addr mvin_cmd.rs2 := (rows << 48).asUInt() | (cols << 32).asUInt() | sp_addr - io.req.ready := state === idle - io.i := i - io.k := k - io.idle := state === idle - - io.cmd.valid := state =/= idle - io.cmd.bits := mvin_cmd - + class CmdQEntry extends Bundle { + val cmd = new RoCCCommand + val i = UInt() + val k = UInt() + } + val cmd_q = Module(new Queue(new CmdQEntry, 8)) + cmd_q.io.enq.valid := state =/= idle && !io.rob_overloaded + cmd_q.io.enq.bits.cmd := mvin_cmd + cmd_q.io.enq.bits.i := i + cmd_q.io.enq.bits.k := k + + io.cmd.valid := cmd_q.io.deq.valid + cmd_q.io.deq.ready := io.cmd.ready + io.cmd.bits := cmd_q.io.deq.bits.cmd + + io.prefetch.valid := cmd_q.io.enq.fire() + io.prefetch.bits := cmd_q.io.enq.bits.cmd + + io.req.ready := state === idle && !cmd_q.io.deq.valid + io.i := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.i, 0.U) + io.k := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.k, 0.U) io.loop_id := req.loop_id + io.idle := state === idle && !cmd_q.io.deq.valid - when (io.cmd.fire()) { + when (cmd_q.io.enq.fire()) { // The order here is k, j, i val i_blocks = Mux(req.transpose, max_blocks, 1.U) val k_blocks = Mux(req.transpose, 1.U, max_blocks) @@ -125,11 +141,13 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In val io = IO(new Bundle { val req = Flipped(Decoupled(new LoopMatmulLdBReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, concurrent_loops))) val cmd = Decoupled(Output(new RoCCCommand)) + val prefetch = Output(Valid(new RoCCCommand)) val k = Output(UInt(iterator_bitwidth.W)) val j = Output(UInt(iterator_bitwidth.W)) val idle = Output(Bool()) + val rob_overloaded = Input(Bool()) val loop_id = Output(UInt(log2Up(concurrent_loops).W)) }) @@ -171,17 +189,31 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In mvin_cmd.rs1 := dram_addr mvin_cmd.rs2 := (rows << 48).asUInt() | (cols << 32).asUInt() | sp_addr - io.req.ready := state === idle - io.k := k - io.j := j - io.idle := state === idle - - io.cmd.valid := state =/= idle - io.cmd.bits := mvin_cmd - + class CmdQEntry extends Bundle { + val cmd = new RoCCCommand + val k = UInt() + val j = UInt() + } + val cmd_q = Module(new Queue(new CmdQEntry, 8)) + cmd_q.io.enq.valid := state =/= idle && !io.rob_overloaded + cmd_q.io.enq.bits.cmd := mvin_cmd + cmd_q.io.enq.bits.k := k + cmd_q.io.enq.bits.j := j + + io.cmd.valid := cmd_q.io.deq.valid + cmd_q.io.deq.ready := io.cmd.ready + io.cmd.bits := cmd_q.io.deq.bits.cmd + + io.prefetch.valid := cmd_q.io.enq.fire() + io.prefetch.bits := cmd_q.io.enq.bits.cmd + + io.req.ready := state === idle && !cmd_q.io.deq.valid + io.k := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.k, 0.U) + io.j := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.j, 0.U) io.loop_id := req.loop_id + io.idle := state === idle && !cmd_q.io.deq.valid - when (io.cmd.fire()) { + when (cmd_q.io.enq.fire()) { // The order here is k, j, i val j_blocks = Mux(req.transpose, 1.U, max_blocks) val k_blocks = Mux(req.transpose, max_blocks, 1.U) @@ -637,21 +669,17 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: // Create ld arbiters val ldab_arb = Module(new WeightedArbiter(new RoCCCommand(), weightA=3)) - val lda_q = Module(new Queue(new RoCCCommand, 8)) - lda_q.io.enq <> ldA.io.cmd - val ldb_q = Module(new Queue(new RoCCCommand, 8)) - ldb_q.io.enq <> ldB.io.cmd - ldab_arb.io.inA <> lda_q.io.deq - ldab_arb.io.inB <> ldb_q.io.deq + ldab_arb.io.inA <> ldA.io.cmd + ldab_arb.io.inB <> ldB.io.cmd val ab_loads_on_same_loop = ldA.io.loop_id === ldB.io.loop_id ldab_arb.io.forceA := !ab_loads_on_same_loop && ldA.io.loop_id === head_loop_id ldab_arb.io.forceB := !ab_loads_on_same_loop && ldB.io.loop_id === head_loop_id val prefetch_arb = Module(new Arbiter(new RoCCCommand, 2)) - prefetch_arb.io.in(0).valid := ldA.io.cmd.fire() - prefetch_arb.io.in(0).bits := ldA.io.cmd.bits - prefetch_arb.io.in(1).valid := ldB.io.cmd.fire() - prefetch_arb.io.in(1).bits := ldB.io.cmd.bits + prefetch_arb.io.in(0).valid := ldA.io.prefetch.fire() + prefetch_arb.io.in(0).bits := ldA.io.prefetch.bits + prefetch_arb.io.in(1).valid := ldB.io.prefetch.fire() + prefetch_arb.io.in(1).bits := ldB.io.prefetch.bits val prefetch_q_size = 4 val prefetch_q = Module(new Queue(new RoCCCommand, prefetch_q_size, pipe=true)) io.prefetch <> prefetch_q.io.deq @@ -661,7 +689,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: prefetch_q.io.deq.ready := true.B } - io.busy := cmd.valid || loop_configured || lda_q.io.deq.valid || ldb_q.io.deq.valid + io.busy := cmd.valid || loop_configured // Create global arbiter