Skip to content

Commit

Permalink
get rid of im2col counter
Browse files Browse the repository at this point in the history
  • Loading branch information
SeahK committed Feb 8, 2024
1 parent 25710ae commit 8a9ed28
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 53 deletions.
1 change: 1 addition & 0 deletions src/main/scala/gemmini/ConfigsFP.scala
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ object GemminiFPConfigs {
ex_write_to_spad=false,
has_training_convs = false,
hardcode_d_to_garbage_addr = true,
has_loop_conv = false,
acc_read_full_width = false,
//has_loop_conv = false,
max_in_flight_mem_reqs = 16,
Expand Down
33 changes: 19 additions & 14 deletions src/main/scala/gemmini/Controller.scala
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
val max_exs = reservation_station_entries_ex
val max_sts = reservation_station_entries_st

/*
val (conv_cmd, loop_conv_unroller_busy) = withClock (gated_clock) { LoopConv(raw_cmd, reservation_station.io.conv_ld_completed, reservation_station.io.conv_st_completed, reservation_station.io.conv_ex_completed,
meshRows*tileRows, coreMaxAddrBits, reservation_station_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries,
inputType.getWidth, accType.getWidth, dma_maxbytes,
Expand All @@ -151,8 +152,20 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
new PreloadRs(mvout_rows_bits, mvout_cols_bits, local_addr_t),
new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t),
has_training_convs, has_max_pool, has_first_layer_optimizations, has_dw_convs) }
*/

val (loop_cmd, loop_matmul_unroller_busy) = withClock (gated_clock) { LoopMatmul(conv_cmd, reservation_station.io.matmul_ld_completed, reservation_station.io.matmul_st_completed, reservation_station.io.matmul_ex_completed,
val (conv_cmd, loop_conv_unroller_busy) = if (has_loop_conv) withClock (gated_clock) { LoopConv(raw_cmd, reservation_station.io.conv_ld_completed, reservation_station.io.conv_st_completed, reservation_station.io.conv_ex_completed,
meshRows*tileRows, coreMaxAddrBits, reservation_station_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries,
inputType.getWidth, accType.getWidth, dma_maxbytes,
new ConfigMvinRs1(mvin_scale_t_bits, block_stride_bits, pixel_repeats_bits), new MvinRs2(mvin_rows_bits, mvin_cols_bits, local_addr_t),
new ConfigMvoutRs2(acc_scale_t_bits, 32), new MvoutRs2(mvout_rows_bits, mvout_cols_bits, local_addr_t),
new ConfigExRs1(acc_scale_t_bits), new PreloadRs(mvin_rows_bits, mvin_cols_bits, local_addr_t),
new PreloadRs(mvout_rows_bits, mvout_cols_bits, local_addr_t),
new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t),
has_training_convs, has_max_pool, has_first_layer_optimizations, has_dw_convs) }
else (raw_cmd, false.B)

val (loop_cmd, loop_matmul_unroller_busy) = withClock (gated_clock) { LoopMatmul(if (has_loop_conv) conv_cmd else raw_cmd, reservation_station.io.matmul_ld_completed, reservation_station.io.matmul_st_completed, reservation_station.io.matmul_ex_completed,
meshRows*tileRows, coreMaxAddrBits, reservation_station_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries,
inputType.getWidth, accType.getWidth, dma_maxbytes, new MvinRs2(mvin_rows_bits, mvin_cols_bits, local_addr_t),
new PreloadRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), new PreloadRs(mvout_rows_bits, mvout_cols_bits, local_addr_t),
Expand Down Expand Up @@ -255,32 +268,24 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
ex_controller.io.acc.write <> spad.module.io.acc.write

// Im2Col unit
/*
val im2col = withClock (gated_clock) { Module(new Im2Col(outer.config)) }
// Wire up Im2col
counters.io.event_io.collect(im2col.io.counter)
// im2col.io.sram_reads <> spad.module.io.srams.read
im2col.io.req <> ex_controller.io.im2col.req
ex_controller.io.im2col.resp <> im2col.io.resp
*/

// Wire arbiter for ExecuteController scratchpad reads (the Im2Col read port is no longer arbitrated here)
(ex_controller.io.srams.read, im2col.io.sram_reads, spad.module.io.srams.read).zipped.foreach { case (ex_read, im2col_read, spad_read) =>
val req_arb = Module(new Arbiter(new ScratchpadReadReq(n=sp_bank_entries), 2))

(ex_controller.io.srams.read, spad.module.io.srams.read).zipped.foreach { case (ex_read, spad_read) =>
val req_arb = Module(new Arbiter(new ScratchpadReadReq(n=sp_bank_entries), 1))
req_arb.io.in(0) <> ex_read.req
req_arb.io.in(1) <> im2col_read.req

spad_read.req <> req_arb.io.out

// TODO if necessary, change how the responses are handled when fromIm2Col is added to spad read interface

ex_read.resp.valid := spad_read.resp.valid
im2col_read.resp.valid := spad_read.resp.valid

ex_read.resp.bits := spad_read.resp.bits
im2col_read.resp.bits := spad_read.resp.bits

spad_read.resp.ready := ex_read.resp.ready || im2col_read.resp.ready
spad_read.resp.ready := ex_read.resp.ready
}

// Wire up controllers to ROB
Expand Down
16 changes: 8 additions & 8 deletions src/main/scala/gemmini/CounterFile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,17 @@ object CounterEvent {
val B_GARBAGE_CYCLES = 36
val D_GARBAGE_CYCLES = 37

val IM2COL_MEM_CYCLES = 38
val IM2COL_ACTIVE_CYCLES = 39
val IM2COL_TRANSPOSER_WAIT_CYCLE = 40
//val IM2COL_MEM_CYCLES = 38
//val IM2COL_ACTIVE_CYCLES = 39
//val IM2COL_TRANSPOSER_WAIT_CYCLE = 40

val RESERVATION_STATION_FULL_CYCLES = 41
val RESERVATION_STATION_ACTIVE_CYCLES = 42
val RESERVATION_STATION_FULL_CYCLES = 38
val RESERVATION_STATION_ACTIVE_CYCLES = 39

val LOOP_MATMUL_ACTIVE_CYCLES = 43
val TRANSPOSE_PRELOAD_UNROLLER_ACTIVE_CYCLES = 44
val LOOP_MATMUL_ACTIVE_CYCLES = 40
val TRANSPOSE_PRELOAD_UNROLLER_ACTIVE_CYCLES = 41

val n = 45
val n = 42
}

object CounterExternal {
Expand Down
64 changes: 39 additions & 25 deletions src/main/scala/gemmini/ExecuteController.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
val io = IO(new Bundle {
val cmd = Flipped(Decoupled(new GemminiCmd(reservation_station_entries)))

/*
val im2col = new Bundle {
val req = Decoupled(new Im2ColReadReq(config))
val resp = Flipped(Decoupled(new Im2ColReadResp(config)))
}
*/

val srams = new Bundle {
val read = Vec(sp_banks, new ScratchpadReadIO(sp_bank_entries, sp_width))
val write = Vec(sp_banks, new ScratchpadWriteIO(sp_bank_entries, sp_width, (sp_width / (aligned_to * 8)) max 1))
Expand Down Expand Up @@ -111,7 +114,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
icol := ((ocol - 1.U) * weight_stride + krow)//.asSInt
irow := ((orow - 1.U) * weight_stride + krow)//.asSInt

val im2col_turn = WireInit(0.U(9.W))
//val im2col_turn = WireInit(0.U(9.W))

val in_shift = Reg(UInt(log2Up(accType.getWidth).W))
val acc_scale = Reg(acc_scale_t)
Expand All @@ -133,7 +136,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
"Too many inputs are being fed into the single transposer we have")

//fix by input
val im2col_en = config.hasIm2Col.B && weight_stride =/= 0.U
val im2col_en = false.B //config.hasIm2Col.B && weight_stride =/= 0.U

// SRAM addresses of matmul operands
val a_address_rs1 = rs1s(a_address_place).asTypeOf(local_addr_t)
Expand Down Expand Up @@ -311,7 +314,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
val b_row_is_not_all_zeros = b_fire_counter < b_rows
val d_row_is_not_all_zeros = block_size.U - 1.U - d_fire_counter < d_rows //Todo: d_fire_counter_mulpre?

val im2col_wire = io.im2col.req.ready
val im2col_wire = false.B //io.im2col.req.ready

def same_bank(addr1: LocalAddr, addr2: LocalAddr, is_garbage1: Bool, is_garbage2: Bool, start_inputting1: Bool, start_inputting2: Bool, can_be_im2colled: Boolean): Bool = {
val addr1_read_from_acc = addr1.is_acc_addr
Expand Down Expand Up @@ -394,11 +397,16 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
mul_pre_counter_lock := true.B
}

/*
when(!io.im2col.resp.bits.im2col_delay && performing_mul_pre){
mul_pre_counter_sub := Mux(mul_pre_counter_sub > 0.U, mul_pre_counter_sub - 1.U, 0.U)
}.elsewhen(io.im2col.resp.bits.im2col_delay){
mul_pre_counter_sub := 2.U
}.otherwise{mul_pre_counter_sub := 0.U}
*/
when(performing_mul_pre){
mul_pre_counter_sub := Mux(mul_pre_counter_sub > 0.U, mul_pre_counter_sub - 1.U, 0.U)
}.otherwise{mul_pre_counter_sub := 0.U}

// The last line in this (long) Boolean is just to make sure that we don't think we're done as soon as we begin firing
// TODO change when square requirement lifted
Expand All @@ -415,9 +423,13 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
}

val d_fire_counter_mulpre = WireInit(b_fire_counter)
/*
when(performing_mul_pre && !io.im2col.resp.bits.im2col_delay&&im2col_en){
d_fire_counter_mulpre := d_fire_counter - mul_pre_counter_sub
}.otherwise{d_fire_counter_mulpre := d_fire_counter}
*/
d_fire_counter_mulpre := d_fire_counter


// Scratchpad reads
for (i <- 0 until sp_banks) {
Expand Down Expand Up @@ -505,27 +517,28 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
{
val read_a = a_valid && start_inputting_a && !multiply_garbage && im2col_wire&&im2col_en //or just im2col_wire

when (read_a && !io.im2col.req.ready) {
when (read_a && !im2col_wire) {
a_ready := false.B
}

io.im2col.req.valid := read_a
io.im2col.req.bits.addr := a_address_rs1
io.im2col.req.bits.icol := icol
io.im2col.req.bits.irow := irow
io.im2col.req.bits.ocol := ocol
io.im2col.req.bits.stride := weight_stride
io.im2col.req.bits.krow := krow
io.im2col.req.bits.kdim2 := kdim2
io.im2col.req.bits.row_turn := row_turn
io.im2col.req.bits.row_left := row_left
io.im2col.req.bits.channel := channel
io.im2col.req.bits.im2col_cmd := im2col_en
io.im2col.req.bits.start_inputting := start_inputting_a
io.im2col.req.bits.weight_double_bank := weight_double_bank
io.im2col.req.bits.weight_triple_bank := weight_triple_bank

io.im2col.resp.ready := mesh.io.a.ready
/*
io.im2col.req.valid := read_a
io.im2col.req.bits.addr := a_address_rs1
io.im2col.req.bits.icol := icol
io.im2col.req.bits.irow := irow
io.im2col.req.bits.ocol := ocol
io.im2col.req.bits.stride := weight_stride
io.im2col.req.bits.krow := krow
io.im2col.req.bits.kdim2 := kdim2
io.im2col.req.bits.row_turn := row_turn
io.im2col.req.bits.row_left := row_left
io.im2col.req.bits.channel := channel
io.im2col.req.bits.im2col_cmd := im2col_en
io.im2col.req.bits.start_inputting := start_inputting_a
io.im2col.req.bits.weight_double_bank := weight_double_bank
io.im2col.req.bits.weight_triple_bank := weight_triple_bank
io.im2col.resp.ready := mesh.io.a.ready
*/
}

// FSM logic
Expand Down Expand Up @@ -802,11 +815,11 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In

val readData = VecInit(io.srams.read.map(_.resp.bits.data))
val accReadData = if (ex_read_from_acc) VecInit(io.acc.read_resp.map(_.bits.data.asUInt)) else readData
val im2ColData = io.im2col.resp.bits.a_im2col.asUInt
//val im2ColData = io.im2col.resp.bits.a_im2col.asUInt

val readValid = VecInit(io.srams.read.map(bank => ex_read_from_spad.B && bank.resp.valid && !bank.resp.bits.fromDMA))
val accReadValid = VecInit(io.acc.read_resp.map(bank => ex_read_from_acc.B && bank.valid && !bank.bits.fromDMA))
val im2ColValid = io.im2col.resp.valid
val im2ColValid = false.B //io.im2col.resp.valid

mesh_cntl_signals_q.io.deq.ready := (!cntl.a_fire || mesh.io.a.fire || !mesh.io.a.ready) &&
(!cntl.b_fire || mesh.io.b.fire || !mesh.io.b.ready) &&
Expand All @@ -829,7 +842,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
//val neg_shift_sub = block_size.U - cntl.c_rows
preload_zero_counter := wrappingAdd(preload_zero_counter, 1.U, block_size.U, dataA_valid && dataD_valid && cntl.preload_zeros && (cntl.perform_single_preload || cntl.perform_mul_pre))

val dataA_unpadded = Mux(cntl.im2colling, im2ColData, Mux(cntl.a_read_from_acc, accReadData(cntl.a_bank_acc), readData(cntl.a_bank)))
//val dataA_unpadded = Mux(cntl.im2colling, im2ColData, Mux(cntl.a_read_from_acc, accReadData(cntl.a_bank_acc), readData(cntl.a_bank)))
val dataA_unpadded = Mux(cntl.a_read_from_acc, accReadData(cntl.a_bank_acc), readData(cntl.a_bank))
val dataB_unpadded = MuxCase(readData(cntl.b_bank), Seq(cntl.accumulate_zeros -> 0.U, cntl.b_read_from_acc -> accReadData(cntl.b_bank_acc)))
val dataD_unpadded = MuxCase(readData(cntl.d_bank), Seq(cntl.preload_zeros -> 0.U, cntl.d_read_from_acc -> accReadData(cntl.d_bank_acc)))

Expand Down
1 change: 1 addition & 0 deletions src/main/scala/gemmini/GemminiConfigs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
has_dw_convs: Boolean = true,
has_normalizations: Boolean = false,
has_first_layer_optimizations: Boolean = true,
has_loop_conv: Boolean = true,

use_firesim_simulation_counters: Boolean = false,

Expand Down
6 changes: 0 additions & 6 deletions src/main/scala/gemmini/Im2Col.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ class Im2Col[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V

val sram_reads = Vec(sp_banks, new ScratchpadReadIO(sp_bank_entries, sp_width)) // from Scratchpad

val counter = new CounterEventIO()
})
val req = Reg(new Im2ColReadReq(config))

Expand Down Expand Up @@ -449,9 +448,4 @@ class Im2Col[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V
io.sram_reads.foreach(_.resp.ready := false.B)
}

// Performance counter
CounterEventIO.init(io.counter)
io.counter.connectEventSignal(CounterEvent.IM2COL_ACTIVE_CYCLES, im2col_state === preparing_im2col)
io.counter.connectEventSignal(CounterEvent.IM2COL_MEM_CYCLES, im2col_state === doing_im2col)
io.counter.connectEventSignal(CounterEvent.IM2COL_TRANSPOSER_WAIT_CYCLE, im2col_state === waiting_for_im2col && sram_read_req)
}

0 comments on commit 8a9ed28

Please sign in to comment.