Skip to content

Commit

Permalink
Make LoopMatmul run ahead to avoid RAW (read-after-write) issues
Browse files: browse the repository at this point in the history
  • Loading branch information
jerryz123 committed Mar 8, 2021
1 parent 2090812 commit defd120
Showing 1 changed file with 57 additions and 29 deletions.
86 changes: 57 additions & 29 deletions src/main/scala/gemmini/LoopMatmul.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
val io = IO(new Bundle {
val req = Flipped(Decoupled(new LoopMatmulLdAReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, concurrent_loops)))
val cmd = Decoupled(Output(new RoCCCommand))
val prefetch = Output(Valid(new RoCCCommand))
val i = Output(UInt(iterator_bitwidth.W))
val k = Output(UInt(iterator_bitwidth.W))
val idle = Output(Bool())
val rob_overloaded = Input(Bool())
val loop_id = Output(UInt(log2Up(concurrent_loops).W))
})

Expand Down Expand Up @@ -71,17 +73,31 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
mvin_cmd.rs1 := dram_addr
mvin_cmd.rs2 := (rows << 48).asUInt() | (cols << 32).asUInt() | sp_addr

io.req.ready := state === idle
io.i := i
io.k := k
io.idle := state === idle

io.cmd.valid := state =/= idle
io.cmd.bits := mvin_cmd

// Element type of the A-load command queue (cmd_q): pairs a generated RoCC
// command with the i and k iterator values that were driven on the enqueue
// side when that command was queued. On dequeue, cmd feeds io.cmd.bits and
// i / k feed io.i / io.k.
// NOTE(review): i and k are declared without an explicit width, so their
// widths are left to inference from whatever is connected to them. The
// corresponding io.i / io.k ports are UInt(iterator_bitwidth.W) — confirm
// whether these fields should be given the same explicit width.
class CmdQEntry extends Bundle {
val cmd = new RoCCCommand // command payload forwarded to io.cmd when dequeued
val i = UInt() // i iterator value captured alongside cmd
val k = UInt() // k iterator value captured alongside cmd
}
val cmd_q = Module(new Queue(new CmdQEntry, 8))
cmd_q.io.enq.valid := state =/= idle && !io.rob_overloaded
cmd_q.io.enq.bits.cmd := mvin_cmd
cmd_q.io.enq.bits.i := i
cmd_q.io.enq.bits.k := k

io.cmd.valid := cmd_q.io.deq.valid
cmd_q.io.deq.ready := io.cmd.ready
io.cmd.bits := cmd_q.io.deq.bits.cmd

io.prefetch.valid := cmd_q.io.enq.fire()
io.prefetch.bits := cmd_q.io.enq.bits.cmd

io.req.ready := state === idle && !cmd_q.io.deq.valid
io.i := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.i, 0.U)
io.k := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.k, 0.U)
io.loop_id := req.loop_id
io.idle := state === idle && !cmd_q.io.deq.valid

when (io.cmd.fire()) {
when (cmd_q.io.enq.fire()) {
// The order here is k, j, i
val i_blocks = Mux(req.transpose, max_blocks, 1.U)
val k_blocks = Mux(req.transpose, 1.U, max_blocks)
Expand Down Expand Up @@ -125,11 +141,13 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
val io = IO(new Bundle {
val req = Flipped(Decoupled(new LoopMatmulLdBReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, concurrent_loops)))
val cmd = Decoupled(Output(new RoCCCommand))
val prefetch = Output(Valid(new RoCCCommand))

val k = Output(UInt(iterator_bitwidth.W))
val j = Output(UInt(iterator_bitwidth.W))

val idle = Output(Bool())
val rob_overloaded = Input(Bool())

val loop_id = Output(UInt(log2Up(concurrent_loops).W))
})
Expand Down Expand Up @@ -171,17 +189,31 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
mvin_cmd.rs1 := dram_addr
mvin_cmd.rs2 := (rows << 48).asUInt() | (cols << 32).asUInt() | sp_addr

io.req.ready := state === idle
io.k := k
io.j := j
io.idle := state === idle

io.cmd.valid := state =/= idle
io.cmd.bits := mvin_cmd

// Element type of the B-load command queue (cmd_q): pairs a generated RoCC
// command with the k and j iterator values that were driven on the enqueue
// side when that command was queued. On dequeue, cmd feeds io.cmd.bits and
// k / j feed io.k / io.j.
// NOTE(review): k and j are declared without an explicit width, so their
// widths are left to inference from whatever is connected to them. The
// corresponding io.k / io.j ports are UInt(iterator_bitwidth.W) — confirm
// whether these fields should be given the same explicit width.
class CmdQEntry extends Bundle {
val cmd = new RoCCCommand // command payload forwarded to io.cmd when dequeued
val k = UInt() // k iterator value captured alongside cmd
val j = UInt() // j iterator value captured alongside cmd
}
val cmd_q = Module(new Queue(new CmdQEntry, 8))
cmd_q.io.enq.valid := state =/= idle && !io.rob_overloaded
cmd_q.io.enq.bits.cmd := mvin_cmd
cmd_q.io.enq.bits.k := k
cmd_q.io.enq.bits.j := j

io.cmd.valid := cmd_q.io.deq.valid
cmd_q.io.deq.ready := io.cmd.ready
io.cmd.bits := cmd_q.io.deq.bits.cmd

io.prefetch.valid := cmd_q.io.enq.fire()
io.prefetch.bits := cmd_q.io.enq.bits.cmd

io.req.ready := state === idle && !cmd_q.io.deq.valid
io.k := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.k, 0.U)
io.j := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.j, 0.U)
io.loop_id := req.loop_id
io.idle := state === idle && !cmd_q.io.deq.valid

when (io.cmd.fire()) {
when (cmd_q.io.enq.fire()) {
// The order here is k, j, i
val j_blocks = Mux(req.transpose, 1.U, max_blocks)
val k_blocks = Mux(req.transpose, max_blocks, 1.U)
Expand Down Expand Up @@ -637,21 +669,17 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds:

// Create ld arbiters
val ldab_arb = Module(new WeightedArbiter(new RoCCCommand(), weightA=3))
val lda_q = Module(new Queue(new RoCCCommand, 8))
lda_q.io.enq <> ldA.io.cmd
val ldb_q = Module(new Queue(new RoCCCommand, 8))
ldb_q.io.enq <> ldB.io.cmd
ldab_arb.io.inA <> lda_q.io.deq
ldab_arb.io.inB <> ldb_q.io.deq
ldab_arb.io.inA <> ldA.io.cmd
ldab_arb.io.inB <> ldB.io.cmd
val ab_loads_on_same_loop = ldA.io.loop_id === ldB.io.loop_id
ldab_arb.io.forceA := !ab_loads_on_same_loop && ldA.io.loop_id === head_loop_id
ldab_arb.io.forceB := !ab_loads_on_same_loop && ldB.io.loop_id === head_loop_id

val prefetch_arb = Module(new Arbiter(new RoCCCommand, 2))
prefetch_arb.io.in(0).valid := ldA.io.cmd.fire()
prefetch_arb.io.in(0).bits := ldA.io.cmd.bits
prefetch_arb.io.in(1).valid := ldB.io.cmd.fire()
prefetch_arb.io.in(1).bits := ldB.io.cmd.bits
prefetch_arb.io.in(0).valid := ldA.io.prefetch.fire()
prefetch_arb.io.in(0).bits := ldA.io.prefetch.bits
prefetch_arb.io.in(1).valid := ldB.io.prefetch.fire()
prefetch_arb.io.in(1).bits := ldB.io.prefetch.bits
val prefetch_q_size = 4
val prefetch_q = Module(new Queue(new RoCCCommand, prefetch_q_size, pipe=true))
io.prefetch <> prefetch_q.io.deq
Expand All @@ -661,7 +689,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds:
prefetch_q.io.deq.ready := true.B
}

io.busy := cmd.valid || loop_configured || lda_q.io.deq.valid || ldb_q.io.deq.valid
io.busy := cmd.valid || loop_configured


// Create global arbiter
Expand Down

0 comments on commit defd120

Please sign in to comment.