diff --git a/software/gemmini-rocc-tests b/software/gemmini-rocc-tests
index 1a1a1c6b..89e38f1c 160000
--- a/software/gemmini-rocc-tests
+++ b/software/gemmini-rocc-tests
@@ -1 +1 @@
-Subproject commit 1a1a1c6bd60df6d7cae3d87aac96c8f406cae084
+Subproject commit 89e38f1cc3df721c633914d28e0b8de658672b5a
diff --git a/src/main/scala/gemmini/AccumulatorMem.scala b/src/main/scala/gemmini/AccumulatorMem.scala
index c664bd0f..8289f1a5 100644
--- a/src/main/scala/gemmini/AccumulatorMem.scala
+++ b/src/main/scala/gemmini/AccumulatorMem.scala
@@ -92,7 +92,7 @@ class AccPipeShared[T <: Data : Arithmetic](latency: Int, t: Vec[Vec[T]], banks:
 class AccumulatorMem[T <: Data, U <: Data](
   n: Int, t: Vec[Vec[T]], scale_func: (T, U) => T, scale_t: U,
   acc_singleported: Boolean, acc_sub_banks: Int,
-  use_shared_ext_mem: Boolean,
+  use_shared_ext_mem: Boolean, use_tl_ext_ram: Boolean,
   acc_latency: Int, acc_type: T, is_dummy: Boolean
 )
 (implicit ev: Arithmetic[T]) extends Module {
@@ -134,8 +134,46 @@ class AccumulatorMem[T <: Data, U <: Data](
   val mask_len = t.getWidth / 8
   val mask_elem = UInt((t.getWidth / mask_len).W)
 
+  // val ext_mem_write_q_enq = if (use_shared_ext_mem && use_tl_ext_ram) {
+  //   require(acc_sub_banks == 1)
+  //   Some(io.ext_mem.get.map { ext_mem =>
+  //     val write_q = Module(new Queue(new Bundle {
+  //       val write_addr = UInt()
+  //       val write_data = UInt()
+  //       val write_mask = UInt()
+  //     }, 8, pipe = true, flow = true))
+
+  //     write_q.io.enq.valid := false.B
+  //     write_q.io.enq.bits := DontCare
+
+  //     ext_mem.write_valid := write_q.io.deq.valid
+  //     ext_mem.write_addr := write_q.io.deq.bits.write_addr
+  //     ext_mem.write_data := write_q.io.deq.bits.write_data
+  //     ext_mem.write_mask := write_q.io.deq.bits.write_mask
+  //     write_q.io.deq.ready := ext_mem.write_ready
+  //     write_q.io.enq
+  //   })
+  // } else None
+  io.ext_mem.get.foreach(_.write_req.valid := false.B)
+  io.ext_mem.get.foreach(_.write_req.bits.addr := 0.U(io.write.bits.addr.getWidth.W))
+  io.ext_mem.get.foreach(_.write_req.bits.mask := 0.U(io.write.bits.mask.getWidth.W))
+  io.ext_mem.get.foreach(_.write_req.bits.data := 0.U(io.write.bits.data.getWidth.W))
+  io.ext_mem.get.foreach(_.read_req.bits := 0.U((mask_len * mask_elem.getWidth).W))
+  io.ext_mem.get.foreach(_.read_req.valid := false.B)
+  io.ext_mem.get.foreach(_.read_resp.ready := false.B) // no reading from external accmem
+
   if (!acc_singleported && !is_dummy) {
-    require(!use_shared_ext_mem)
+    // if (use_shared_ext_mem && use_tl_ext_ram) {
+    //   // duplicate write to external memory
+    //   val enq = ext_mem_write_q_enq.get(0)
+    //   enq.valid := oldest_pipelined_write.valid
+    //   enq.bits.write_addr := oldest_pipelined_write.bits.addr
+    //   enq.bits.write_data := Mux(oldest_pipelined_write.bits.acc, adder_sum.asUInt, oldest_pipelined_write.bits.data.asUInt)
+    //   enq.bits.write_mask := oldest_pipelined_write.bits.mask.asUInt
+    //   // TODO (richard): add buffer here and potentially propagate backpressure to systolic array
+    //   assert(enq.ready || !enq.valid, "accumulator external memory write dropped")
+    // } else if (use_shared_ext_mem) {
+    //   require(false, "cannot use two-port external acc mem bank")
+    // }
     val mem = TwoPortSyncMem(n, t, mask_len) // TODO We assume byte-alignment here. Use aligned_to instead
     mem.io.waddr := oldest_pipelined_write.bits.addr
     mem.io.wen := oldest_pipelined_write.valid
@@ -163,27 +201,39 @@ class AccumulatorMem[T <: Data, U <: Data](
   for (i <- 0 until acc_sub_banks) {
     def isThisBank(addr: UInt) = addr(log2Ceil(acc_sub_banks)-1,0) === i.U
     def getBankIdx(addr: UInt) = addr >> log2Ceil(acc_sub_banks)
-    val (read, write) = if (use_shared_ext_mem) {
+    val (read, write) = if (use_shared_ext_mem && !use_tl_ext_ram) {
       def read(addr: UInt, ren: Bool): Data = {
-        io.ext_mem.get(i).read_en := ren
-        io.ext_mem.get(i).read_addr := addr
-        io.ext_mem.get(i).read_data
+        io.ext_mem.get(i).read_req.valid := ren
+        io.ext_mem.get(i).read_req.bits := addr
+        io.ext_mem.get(i).read_resp.bits
       }
-      io.ext_mem.get(i).write_en := false.B
-      io.ext_mem.get(i).write_addr := DontCare
-      io.ext_mem.get(i).write_data := DontCare
-      io.ext_mem.get(i).write_mask := DontCare
+      io.ext_mem.get(i).write_req.bits := DontCare
       def write(addr: UInt, wdata: Vec[UInt], wmask: Vec[Bool]) = {
-        io.ext_mem.get(i).write_en := true.B
-        io.ext_mem.get(i).write_addr := addr
-        io.ext_mem.get(i).write_data := wdata.asUInt
-        io.ext_mem.get(i).write_mask := wmask.asUInt
+        io.ext_mem.get(i).write_req.valid := true.B
+        io.ext_mem.get(i).write_req.bits.addr := addr
+        io.ext_mem.get(i).write_req.bits.data := wdata.asUInt
+        io.ext_mem.get(i).write_req.bits.mask := wmask.asUInt
       }
       (read _, write _)
     } else {
       val mem = SyncReadMem(n / acc_sub_banks, Vec(mask_len, mask_elem))
+      io.ext_mem.get(i).read_req.bits := 0.U((mask_len * mask_elem.getWidth).W)
+      io.ext_mem.get(i).read_req.valid := false.B
+
       def read(addr: UInt, ren: Bool): Data = mem.read(addr, ren)
-      def write(addr: UInt, wdata: Vec[UInt], wmask: Vec[Bool]) = mem.write(addr, wdata, wmask)
+      def write(addr: UInt, wdata: Vec[UInt], wmask: Vec[Bool]) = if (use_tl_ext_ram) {
+        mem.write(addr, wdata, wmask)
+        // duplicate write signal to external memory
+        // val enq = ext_mem_write_q_enq.get(i)
+        // enq.valid := true.B
+        // enq.bits.write_mask := wmask.asUInt
+        // enq.bits.write_addr := addr
+        // enq.bits.write_data := wdata.asUInt
+        // // TODO (richard): propagate backpressure to systolic array, add fence ability
+        // assert(enq.ready, "accumulator external memory write dropped")
+      } else {
+        mem.write(addr, wdata, wmask)
+      }
       (read _, write _)
     }
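The commented-out ext_mem_write_q_enq blocks above sketch a plan for mirroring accumulator writes into the TileLink-backed external memory through a small queue, since the SRAM write port itself cannot stall. A minimal Chisel sketch of that queue pattern, with hypothetical module and bundle names and illustrative widths (none of this is part of the patch):

import chisel3._
import chisel3.util._

// Hypothetical stand-in for the write side of the decoupled ExtMemIO.
class ExtWriteReq(addrBits: Int, dataBits: Int) extends Bundle {
  val addr = UInt(addrBits.W)
  val data = UInt(dataBits.W)
  val mask = UInt((dataBits / 8).W)
}

class AccExtWriteMirror(addrBits: Int = 9, dataBits: Int = 512) extends Module {
  val io = IO(new Bundle {
    val wen       = Input(Bool())               // valid-only write from the accumulator bank
    val waddr     = Input(UInt(addrBits.W))
    val wdata     = Input(UInt(dataBits.W))
    val wmask     = Input(UInt((dataBits / 8).W))
    val write_req = Decoupled(new ExtWriteReq(addrBits, dataBits))
  })

  // Same Queue parameters as the commented-out block: depth 8, with pipe and flow
  // so a request can pass straight through in the cycle the queue is empty.
  val q = Module(new Queue(new ExtWriteReq(addrBits, dataBits), 8, pipe = true, flow = true))
  q.io.enq.valid     := io.wen
  q.io.enq.bits.addr := io.waddr
  q.io.enq.bits.data := io.wdata
  q.io.enq.bits.mask := io.wmask
  io.write_req <> q.io.deq

  // The systolic array cannot be backpressured yet (see the TODO above), so a
  // full queue means a dropped mirror write.
  assert(q.io.enq.ready || !io.wen, "accumulator external memory write dropped")
}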
diff --git a/src/main/scala/gemmini/Controller.scala b/src/main/scala/gemmini/Controller.scala
index 0fdda55f..1b74bc18 100644
--- a/src/main/scala/gemmini/Controller.scala
+++ b/src/main/scala/gemmini/Controller.scala
@@ -3,14 +3,13 @@ package gemmini
 
 import java.nio.charset.StandardCharsets
 import java.nio.file.{Files, Paths}
-
 import chisel3._
 import chisel3.util._
 import org.chipsalliance.cde.config._
 import freechips.rocketchip.diplomacy._
 import freechips.rocketchip.tile._
 import freechips.rocketchip.util.ClockGate
-import freechips.rocketchip.tilelink.TLIdentityNode
+import freechips.rocketchip.tilelink.{TLBundle, TLClientNode, TLEdgeOut, TLFragmenter, TLIdentityNode, TLManagerNode, TLMasterParameters, TLMasterPortParameters, TLMasterToSlaveTransferSizes, TLRAM, TLSlaveParameters, TLSlavePortParameters, TLWidthWidget, TLXbar}
 import GemminiISA._
 import Util._
@@ -35,11 +34,115 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA
   val xLen = p(XLen)
   val spad = LazyModule(new Scratchpad(config))
 
+  val create_tl_mem = config.use_shared_ext_mem && config.use_tl_ext_mem
+
+  val num_ids = 32 // TODO (richard): move to config
+  val spad_base = 0 // 0x60000000L
+
+  val unified_mem_read_node = TLIdentityNode()
+  val spad_read_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.sp_banks) {i =>
+    TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"spad_read_node_$i", sourceId = IdRange(0, num_ids))))
+  }) else TLIdentityNode()
+  // val acc_read_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.acc_banks) { i =>
+  //   TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"acc_read_node_$i", sourceId = IdRange(0, numIDs))))
+  // }) else TLIdentityNode()
+
+  val unified_mem_write_node = TLIdentityNode()
+  val spad_write_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.sp_banks) { i =>
+    TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"spad_write_node_$i", sourceId = IdRange(0, num_ids))))
+  }) else TLIdentityNode()
+
+  // val spad_dma_write_node = TLClientNode(Seq(
+  //   TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"spad_dma_write_node", sourceId = IdRange(0, num_ids))))))
+  // val acc_write_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.acc_banks) { i =>
+  //   TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"acc_write_node_$i", sourceId = IdRange(0, numIDs))))
+  // }) else TLIdentityNode()
+
+  val spad_data_len = config.sp_width / 8
+  val acc_data_len = config.sp_width / config.inputType.getWidth * config.accType.getWidth / 8
+  val max_data_len = spad_data_len // max acc_data_len
+
+  val spad_tl_ram : Seq[Seq[TLManagerNode]] = if (config.use_shared_ext_mem && config.use_tl_ext_mem) {
+    unified_mem_read_node :=* TLWidthWidget(spad_data_len) :=* spad_read_nodes
+    // unified_mem_read_node :=* TLWidthWidget(acc_data_len) :=* acc_read_nodes
+    unified_mem_write_node :=* TLWidthWidget(spad_data_len) :=* spad_write_nodes
+    // unified_mem_write_node :=* TLWidthWidget(acc_data_len) :=* acc_write_nodes
+
+    val stride_by_word = false // TODO (richard): move to config
+
+    require(isPow2(config.sp_banks))
+    val banks : Seq[Seq[TLManagerNode]] =
+      if (stride_by_word) {
+        assert(false, "TODO under construction")
+        assert((config.sp_capacity match { case CapacityInKilobytes(kb) => kb * 1024}) ==
+          config.sp_bank_entries * spad_data_len / max_data_len * config.sp_banks * max_data_len)
+        (0 until config.sp_banks).map { bank =>
+          LazyModule(new TLRAM(
+            address = AddressSet(max_data_len * bank,
+              ((config.sp_bank_entries * spad_data_len / max_data_len - 1) * config.sp_banks + bank)
+                * max_data_len + (max_data_len - 1)),
+            beatBytes = max_data_len
+          ))
+        }.map(x => Seq(x.node))
+      } else {
+        (0 until config.sp_banks).map { bank =>
+          val mem_depth = config.sp_bank_entries * spad_data_len / max_data_len
+          val mem_width = max_data_len
+
+          Seq(TLManagerNode(Seq(TLSlavePortParameters.v1(
+            managers = Seq(TLSlaveParameters.v2(
+              name = Some(f"sp_bank${bank}_read_mgr"),
+              address = Seq(AddressSet(spad_base + (mem_depth * mem_width * bank),
+                mem_depth * mem_width - 1)),
+              supports = TLMasterToSlaveTransferSizes(
+                get = TransferSizes(1, mem_width)),
+              fifoId = Some(0)
+            )),
+            beatBytes = mem_width
+          ))),
+          TLManagerNode(Seq(TLSlavePortParameters.v1(
+            managers = Seq(TLSlaveParameters.v2(
+              name = Some(f"sp_bank${bank}_write_mgr"),
+              address = Seq(AddressSet(spad_base + (mem_depth * mem_width * bank),
+                mem_depth * mem_width - 1)),
+              supports = TLMasterToSlaveTransferSizes(
+                putFull = TransferSizes(1, mem_width),
+                putPartial = TransferSizes(1, mem_width)),
+              fifoId = Some(0)
+            )),
+            beatBytes = mem_width
+          ))))
+        }
+      }
+
+    require(!config.sp_singleported)
+    if (config.sp_singleported) {
+      val xbar = TLXbar()
+      xbar :=* unified_mem_read_node
+      xbar :=* unified_mem_write_node
+      banks.foreach(_.head := xbar)
+    } else {
+      val r_xbar = TLXbar()
+      val w_xbar = TLXbar()
+      r_xbar :=* unified_mem_read_node
+      w_xbar :=* unified_mem_write_node
+      banks.foreach { mem =>
+        require(mem.length == 2)
+        mem.head := r_xbar
+        mem.last := TLFragmenter(spad_data_len, spad.maxBytes) := w_xbar
+      }
+    }
+
+    banks
+  } else Seq()
+
   override lazy val module = new GemminiModule(this)
   override val tlNode = if (config.use_dedicated_tl_port) spad.id_node else TLIdentityNode()
   override val atlNode = if (config.use_dedicated_tl_port) TLIdentityNode() else spad.id_node
 
   val node = if (config.use_dedicated_tl_port) tlNode else atlNode
+
+  unified_mem_write_node := spad.spad_writer.node
 }
 
 class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
@@ -50,8 +153,121 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
   import outer.config._
   import outer.spad
 
-  val ext_mem_io = if (use_shared_ext_mem) Some(IO(new ExtSpadMemIO(sp_banks, acc_banks, acc_sub_banks))) else None
-  ext_mem_io.foreach(_ <> outer.spad.module.io.ext_mem.get)
+  val ext_mem_io = if (use_shared_ext_mem && !use_tl_ext_mem)
+    Some(IO(new ExtSpadMemIO(sp_banks, acc_banks, acc_sub_banks))) else None
+
+  // we need these 2 separate signals because ext_mem_io is not writable in this module
+  val ext_mem_spad = outer.spad.module.io.ext_mem.get.spad
+  val ext_mem_acc = outer.spad.module.io.ext_mem.get.acc
+
+  // connecting to unified TL interface
+  val source_counters = Seq.fill(4)(Counter(outer.num_ids))
+
+  if (outer.create_tl_mem) {
+    def connect(ext_mem: ExtMemIO, req_size: Int, r_node: TLBundle, r_edge: TLEdgeOut, r_source: Counter,
+                w_node: TLBundle, w_edge: TLEdgeOut, w_source: Counter): Unit = {
+      r_node.a.valid := ext_mem.read_req.valid
+      r_node.a.bits := r_edge.Get(r_source.value,
+        (ext_mem.read_req.bits << req_size.U).asUInt | outer.spad_base.U,
+        req_size.U)._2
+      ext_mem.read_req.ready := r_node.a.ready
+
+      val w_shifted_addr = (ext_mem.write_req.bits.addr << req_size.U).asUInt
+      val w_mask = (ext_mem.write_req.bits.mask << (w_shifted_addr & (w_edge.manager.beatBytes - 1).U)).asUInt
+
+      w_node.a.valid := ext_mem.write_req.valid
+      w_node.a.bits := w_edge.Put(w_source.value,
+        w_shifted_addr | outer.spad_base.U,
+        req_size.U, ext_mem.write_req.bits.data, w_mask)._2
+      ext_mem.write_req.ready := w_node.a.ready
+
+      ext_mem.read_resp.valid := r_node.d.valid
+      ext_mem.read_resp.bits := r_node.d.bits.data
+      r_node.d.ready := ext_mem.read_resp.ready
+
+      w_node.d.ready := true.B // writes are not acknowledged in gemmini
+
+      when(ext_mem.read_req.fire) { r_source.inc() }
+      when(ext_mem.write_req.fire) { w_source.inc() }
+    }
+    (outer.spad_read_nodes.out zip outer.spad_write_nodes.out)
+      .zipWithIndex.foreach{ case (((r_node, r_edge), (w_node, w_edge)), i) =>
+      connect(ext_mem_spad(i), log2Up(outer.spad_data_len),
+        r_node, r_edge, source_counters(0), w_node, w_edge, source_counters(1))
+    }
+
+    outer.spad_tl_ram.foreach { case Seq(r, w) =>
+      val mem_depth = outer.config.sp_bank_entries * outer.spad_data_len / outer.max_data_len
+      val mem_width = outer.max_data_len
+
+      val mem = TwoPortSyncMem(
+        n = mem_depth,
+        t = UInt((mem_width * 8).W),
+        mask_len = mem_width // byte level mask
+      )
+
+      val (r_node, r_edge) = r.in.head
+      val (w_node, w_edge) = w.in.head
+
+      // READ
+      mem.io.ren := r_node.a.fire
+      mem.io.raddr := r_node.a.bits.address ^ outer.spad_base.U
+
+      val data_pipe_in = Wire(DecoupledIO(mem.io.rdata.cloneType))
+      data_pipe_in.valid := RegNext(mem.io.ren)
+      data_pipe_in.bits := mem.io.rdata
+
+      val metadata_pipe_in = Wire(DecoupledIO(new Bundle {
+        val source = r_node.a.bits.source.cloneType
+        val size = r_node.a.bits.size.cloneType
+      }))
+      metadata_pipe_in.valid := mem.io.ren
+      metadata_pipe_in.bits.source := r_node.a.bits.source
+      metadata_pipe_in.bits.size := r_node.a.bits.size
+
+      val data_pipe_inst = Module(new Pipeline(data_pipe_in.bits.cloneType, 1)())
+      data_pipe_inst.io.in <> data_pipe_in
+      val data_pipe = data_pipe_inst.io.out
+      val metadata_pipe = Pipeline(metadata_pipe_in, 2)
+      assert(data_pipe_in.ready || !data_pipe_in.valid)
+      assert(metadata_pipe_in.ready || !data_pipe_in.ready)
+      assert(data_pipe.valid === metadata_pipe.valid)
+
+      r_node.d.bits := r_edge.AccessAck(
+        metadata_pipe.bits.source,
+        metadata_pipe.bits.size,
+        data_pipe.bits)
+      r_node.d.valid := data_pipe.valid
+      // take new requests only if we have a buffer slot open in case downstream becomes unready
+      r_node.a.ready := r_node.d.ready && !data_pipe_inst.io.busy
+      data_pipe.ready := r_node.d.ready
+      metadata_pipe.ready := r_node.d.ready
+
+      // WRITE
+      mem.io.wen := w_node.a.fire
+      mem.io.waddr := w_node.a.bits.address ^ outer.spad_base.U
+      mem.io.wdata := w_node.a.bits.data
+      mem.io.mask := w_node.a.bits.mask.asBools
+      w_node.a.ready := w_node.d.ready // && (mem.io.waddr =/= mem.io.raddr)
+      w_node.d.valid := w_node.a.valid
+      w_node.d.bits := w_edge.AccessAck(w_node.a.bits)
+    }
+
+    ext_mem_acc.foreach(_.foreach(x => {
+      x.read_resp.bits := 0.U(1.W)
+      x.read_resp.valid := false.B
+      x.read_req.ready := false.B
+      x.write_req.ready := false.B
+    }))
+    // (outer.acc_read_nodes.out zip outer.acc_write_nodes.out)
+    //   .zipWithIndex.foreach { case (((r_node, r_edge), (w_node, w_edge)), i) =>
+    //   // TODO (richard): one subbank only for now
+    //   connect(ext_mem_acc(i)(0), log2Up(outer.acc_data_len),
+    //     r_node, r_edge, source_counters(2), w_node, w_edge, source_counters(3))
+    // }
+  } else if (use_shared_ext_mem) {
+    ext_mem_io.foreach(_ <> outer.spad.module.io.ext_mem.get)
+  }
 
   val tagWidth = 32
@@ -66,7 +282,8 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
 
   // TLB
   implicit val edge = outer.spad.id_node.edges.out.head
-  val tlb = Module(new FrontendTLB(2, tlb_size, dma_maxbytes, use_tlb_register_filter, use_firesim_simulation_counters, use_shared_tlb))
+  // TODO(richard): bypass TLB
+  val tlb = Module(new FrontendTLB(3, tlb_size, dma_maxbytes, use_tlb_register_filter, use_firesim_simulation_counters, use_shared_tlb))
   (tlb.io.clients zip outer.spad.module.io.tlb).foreach(t => t._1 <> t._2)
 
   tlb.io.exp.foreach(_.flush_skip := false.B)
diff --git a/src/main/scala/gemmini/CustomConfigs.scala b/src/main/scala/gemmini/CustomConfigs.scala
index 011d7ce1..b7f93ab2 100644
--- a/src/main/scala/gemmini/CustomConfigs.scala
+++ b/src/main/scala/gemmini/CustomConfigs.scala
@@ -49,8 +49,18 @@ object GemminiCustomConfigs {
     acc_capacity = CapacityInKilobytes(128),
   )
 
+  val unifiedMemConfig = defaultConfig.copy(
+    has_training_convs = false,
+    has_max_pool = false,
+    use_tl_ext_mem = true,
+    sp_singleported = false,
+    spad_read_delay = 8,
+    use_shared_ext_mem = true,
+    acc_sub_banks = 1
+  )
+
   // Specify which of your custom configs you want to build here
-  val customConfig = baselineInferenceConfig
+  val customConfig = unifiedMemConfig
 }
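The connect() helper added to GemminiModule above maps a bank-entry index to a TileLink address by shifting it by the beat size and OR-ing in spad_base, then realigns the byte mask within the beat before building the Put. The same arithmetic as a pure-Scala sketch; the function and parameter names are mine, for illustration only:

// Mirrors: addr = (write_req.bits.addr << req_size) | spad_base
//          mask = write_req.bits.mask << (addr & (beatBytes - 1))
def tlAddrAndMask(entryIdx: BigInt, byteMask: BigInt,
                  spadBase: BigInt, reqSize: Int): (BigInt, BigInt) = {
  val beatBytes = BigInt(1) << reqSize                    // req_size = log2Up(spad_data_len)
  val addr = (entryIdx << reqSize) | spadBase             // byte address of the scratchpad row
  val mask = byteMask << (addr & (beatBytes - 1)).toInt   // align the mask within the beat
  (addr, mask)
}

Because each scratchpad row is exactly one beat wide, the read path can issue a full-beat Get with the same shifted address and no mask adjustment.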
diff --git a/src/main/scala/gemmini/DMA.scala b/src/main/scala/gemmini/DMA.scala
index dac1a369..162b3226 100644
--- a/src/main/scala/gemmini/DMA.scala
+++ b/src/main/scala/gemmini/DMA.scala
@@ -338,6 +338,7 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf
 
 class StreamWriteRequest(val dataWidth: Int, val maxBytes: Int)(implicit p: Parameters) extends CoreBundle {
   val vaddr = UInt(coreMaxAddrBits.W)
+  val physical = Bool()
   val data = UInt(dataWidth.W)
   val len = UInt(log2Up((dataWidth/8 max maxBytes)+1).W) // The number of bytes to write
   val block = UInt(8.W) // TODO magic number
@@ -354,12 +355,15 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes:
                                           (implicit p: Parameters) extends LazyModule {
   val node = TLClientNode(Seq(TLMasterPortParameters.v1(Seq(TLClientParameters(
     name = "stream-writer", sourceId = IdRange(0, nXacts))))))
+// val spad_node = TLClientNode(Seq(TLMasterPortParameters.v1(Seq(TLClientParameters(
+// name = "spad-writer", sourceId = IdRange(0, nXacts))))))
 
   require(isPow2(aligned_to))
 
   lazy val module = new Impl
   class Impl extends LazyModuleImp(this) with HasCoreParameters with MemoryOpConstants {
     val (tl, edge) = node.out(0)
+// val (tl_spad, edge_spad) = spad_node.out(0)
     val dataBytes = dataWidth / 8
     val beatBytes = beatBits / 8
     val lgBeatBytes = log2Ceil(beatBytes)
@@ -387,7 +391,7 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes:
     val data_single_block = Reg(UInt(dataWidth.W)) // For data that's just one-block-wide
     val data = Mux(req.block === 0.U, data_single_block, data_blocks.asUInt)
 
-    val bytesSent = Reg(UInt(log2Ceil((dataBytes max maxBytes)+1).W)) // TODO this only needs to count up to (dataBytes/aligned_to), right?
+    val bytesSent = Reg(UInt((log2Ceil((dataBytes max maxBytes)+1) + 1).W)) // TODO this only needs to count up to (dataBytes/aligned_to), right?
     val bytesLeft = req.len - bytesSent
 
     val xactBusy = RegInit(0.U(nXacts.W))
@@ -502,6 +506,7 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes:
       val tl_a = DataMirror.internal.chiselTypeClone[TLBundleA](tl.a.bits)
       val vaddr = Output(UInt(vaddrBits.W))
       val status = Output(new MStatus)
+      val passthrough = Output(Bool())
     }
 
     val untranslated_a = Wire(Decoupled(new TLBundleAWithInfo))
@@ -510,6 +515,7 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes:
     untranslated_a.bits.tl_a := Mux(write_full, putFull, putPartial)
     untranslated_a.bits.vaddr := write_vaddr
     untranslated_a.bits.status := req.status
+    untranslated_a.bits.passthrough := req.physical
 
     // 0 goes to retries, 1 goes to state machine
     val retry_a = Wire(Decoupled(new TLBundleAWithInfo))
@@ -527,7 +533,7 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes:
     io.tlb.req.valid := tlb_q.io.deq.fire
     io.tlb.req.bits := DontCare
     io.tlb.req.bits.tlb_req.vaddr := tlb_q.io.deq.bits.vaddr
-    io.tlb.req.bits.tlb_req.passthrough := false.B
+    io.tlb.req.bits.tlb_req.passthrough := tlb_q.io.deq.bits.passthrough
     io.tlb.req.bits.tlb_req.size := 0.U // send_size
    io.tlb.req.bits.tlb_req.cmd := M_XWR
     io.tlb.req.bits.status := tlb_q.io.deq.bits.status
diff --git a/src/main/scala/gemmini/GemminiConfigs.scala b/src/main/scala/gemmini/GemminiConfigs.scala
index 98254299..6873c683 100644
--- a/src/main/scala/gemmini/GemminiConfigs.scala
+++ b/src/main/scala/gemmini/GemminiConfigs.scala
@@ -92,6 +92,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
   use_firesim_simulation_counters: Boolean = false,
 
   use_shared_ext_mem: Boolean = false,
+  use_tl_ext_mem: Boolean = false,
 
   clock_gate: Boolean = false,
 
   headerFileName: String = "gemmini_params.h"
diff --git a/src/main/scala/gemmini/GemminiISA.scala b/src/main/scala/gemmini/GemminiISA.scala
index 7bca089b..31efd346 100644
--- a/src/main/scala/gemmini/GemminiISA.scala
+++ b/src/main/scala/gemmini/GemminiISA.scala
@@ -34,6 +34,8 @@ object GemminiISA {
 
   val CLKGATE_EN = 22.U
 
+  val STORE_SPAD_CMD = 23.U
+
   // rs1[2:0] values
   val CONFIG_EX = 0.U
   val CONFIG_LOAD = 1.U
@@ -73,6 +75,16 @@ object GemminiISA {
     val local_addr = local_addr_t.cloneType
   }
 
+  val MVOUT_SPAD_RS1_ADDR_WIDTH = 32
+  val MVOUT_SPAD_RS1_STRIDE_WIDTH = 32
+
+  class MvoutSpadRs1(stride_bits: Int, local_addr_t: LocalAddr) extends Bundle {
+    val _spacer1 = UInt((MVOUT_SPAD_RS1_STRIDE_WIDTH - stride_bits).W)
+    val stride = UInt(stride_bits.W)
+    val _spacer0 = UInt((MVOUT_SPAD_RS1_ADDR_WIDTH - local_addr_t.getWidth).W)
+    val local_addr = local_addr_t.cloneType
+  }
+
   val MVOUT_RS2_ADDR_WIDTH = 32
   val MVOUT_RS2_COLS_WIDTH = 16
   val MVOUT_RS2_ROWS_WIDTH = 16
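Judging from MvoutSpadRs1 above and the STORE_SPAD_CMD decode in ReservationStation and StoreController below, the new command packs the destination scratchpad address and stride into rs1 and reuses the MVOUT rs2 layout for the source. A hypothetical encoder for the two operands; this is an inference from the bundle layouts, not a documented ABI:

object StoreSpadEncoding {
  // rs1 = {stride[63:32], dst_local_addr[31:0]}, per MvoutSpadRs1
  def rs1(dstLocalAddr: BigInt, strideBytes: BigInt): BigInt =
    (strideBytes << 32) | (dstLocalAddr & 0xFFFFFFFFL)

  // rs2 = {rows[63:48], cols[47:32], src_local_addr[31:0]}, as in MvoutRs2
  def rs2(srcLocalAddr: BigInt, cols: Int, rows: Int): BigInt =
    (BigInt(rows) << 48) | (BigInt(cols) << 32) | (srcLocalAddr & 0xFFFFFFFFL)
}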
diff --git a/src/main/scala/gemmini/ReservationStation.scala b/src/main/scala/gemmini/ReservationStation.scala
index 47dd5ef1..1ff8eda7 100644
--- a/src/main/scala/gemmini/ReservationStation.scala
+++ b/src/main/scala/gemmini/ReservationStation.scala
@@ -228,7 +228,7 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G
       op1.bits.wraps_around := op1.bits.start.add_with_overflow(compute_rows)._2
     }
 
-    op2.valid := funct_is_compute || funct === STORE_CMD
+    op2.valid := funct_is_compute || funct === STORE_CMD || funct === STORE_SPAD_CMD
     op2.bits.start := cmd.rs2.asTypeOf(local_addr_t)
     when (funct_is_compute) {
       val compute_rows = cmd.rs2(48 + log2Up(block_rows + 1) - 1, 48)
@@ -258,12 +258,25 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G
       op2.bits.wraps_around := pooling_is_enabled || op2.bits.start.add_with_overflow(total_mvout_rows)._2
     }
 
-    dst.valid := funct === PRELOAD_CMD || funct === LOAD_CMD || funct === LOAD2_CMD || funct === LOAD3_CMD
+    dst.valid := funct === PRELOAD_CMD || funct === STORE_SPAD_CMD ||
+      funct === LOAD_CMD || funct === LOAD2_CMD || funct === LOAD3_CMD
     dst.bits.start := cmd.rs2(31, 0).asTypeOf(local_addr_t)
     when (funct === PRELOAD_CMD) {
       val preload_rows = cmd.rs2(48 + log2Up(block_rows + 1) - 1, 48) * c_stride
       dst.bits.end := dst.bits.start + preload_rows
       dst.bits.wraps_around := dst.bits.start.add_with_overflow(preload_rows)._2
+    }.elsewhen(funct === STORE_SPAD_CMD) {
+      // TODO: make it so that spad move has its own load config states
+      val mv_cols = cmd.rs2(32 + mvout_cols_bits - 1, 32)
+      val mv_rows = cmd.rs2(48 + mvout_rows_bits - 1, 48)
+      val mvout_dst_bytes = mv_rows * (cmd.rs1(63, 32) max mv_cols)
+
+      dst.bits.start := cmd.rs1(31, 0).asTypeOf(local_addr_t)
+      dst.bits.end := dst.bits.start + mvout_dst_bytes
+      dst.bits.wraps_around := dst.bits.start.add_with_overflow(mvout_dst_bytes.asUInt)._2
+
+      assert(!pooling_is_enabled, "cannot pool while moving between internal memories")
+      assert(!dst.bits.start.is_acc_addr, "cannot move to accumulator memory")
     }.otherwise {
       val id = MuxCase(0.U, Seq((new_entry.cmd.cmd.inst.funct === LOAD2_CMD) -> 1.U,
         (new_entry.cmd.cmd.inst.funct === LOAD3_CMD) -> 2.U))
@@ -294,7 +307,7 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G
 
     val is_load = funct === LOAD_CMD || funct === LOAD2_CMD || funct === LOAD3_CMD || (funct === CONFIG_CMD && config_cmd_type === CONFIG_LOAD)
     val is_ex = funct === PRELOAD_CMD || funct_is_compute || (funct === CONFIG_CMD && config_cmd_type === CONFIG_EX)
-    val is_store = funct === STORE_CMD || (funct === CONFIG_CMD && (config_cmd_type === CONFIG_STORE || config_cmd_type === CONFIG_NORM))
+    val is_store = funct === STORE_CMD || funct === STORE_SPAD_CMD || (funct === CONFIG_CMD && (config_cmd_type === CONFIG_STORE || config_cmd_type === CONFIG_NORM))
     val is_norm = funct === CONFIG_CMD && config_cmd_type === CONFIG_NORM // normalization commands are a subset of store commands, so they still go in the store queue
 
     new_entry.q := Mux1H(Seq(
@@ -308,34 +321,57 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G
     val not_config = !new_entry.is_config
     when (is_load) {
-      // war (after ex/st) | waw (after ex)
+      // war/waw (after ex) | war/waw (after st)
       new_entry.deps_ld := VecInit(entries_ld.map { e => e.valid && !e.bits.issued }) // same q
 
       new_entry.deps_ex := VecInit(entries_ex.map { e => e.valid && !new_entry.is_config && (
        (new_entry.opa.bits.overlaps(e.bits.opa.bits) && e.bits.opa.valid) || // waw if preload, war if compute
        (new_entry.opa.bits.overlaps(e.bits.opb.bits) && e.bits.opb.valid))}) // war
 
-      new_entry.deps_st := VecInit(entries_st.map { e => e.valid && e.bits.opa.valid && not_config &&
-        new_entry.opa.bits.overlaps(e.bits.opa.bits)}) // war
+      new_entry.deps_st := VecInit(entries_st.map { e => e.valid && not_config && (
+        (new_entry.opa.bits.overlaps(e.bits.opa.bits) && e.bits.opa.valid) || // waw if st_spad, war otherwise
+        (new_entry.opa.bits.overlaps(e.bits.opb.bits) && e.bits.opb.valid))}) // war if st_spad
     }.elsewhen (is_ex) {
-      // raw (after ld) | war (after st) | waw (after ld)
+      // raw/waw (after ld) | war/waw/raw (after st)
       new_entry.deps_ld := VecInit(entries_ld.map { e => e.valid && e.bits.opa.valid && not_config && (
         new_entry.opa.bits.overlaps(e.bits.opa.bits) || // waw if preload, raw if compute
         new_entry.opb.bits.overlaps(e.bits.opa.bits))}) // raw
 
       new_entry.deps_ex := VecInit(entries_ex.map { e => e.valid && !e.bits.issued }) // same q
 
-      new_entry.deps_st := VecInit(entries_st.map { e => e.valid && e.bits.opa.valid && not_config && new_entry.opa_is_dst &&
-        new_entry.opa.bits.overlaps(e.bits.opa.bits)}) // war
+      new_entry.deps_st := VecInit(entries_st.map { e => e.valid && e.bits.opa.valid && not_config &&
+        Mux(e.bits.opa_is_dst,
+          // if st writes, raw/waw for ex a/b <- st a
+          new_entry.opa.bits.overlaps(e.bits.opa.bits) || new_entry.opb.bits.overlaps(e.bits.opa.bits) ||
+            // additionally if ex writes, war for ex a <- st b
+            (new_entry.opa_is_dst && new_entry.opa.bits.overlaps(e.bits.opb.bits))
+          ,
+          // if st only reads, only check ex writes, war for ex a <- st a
+          new_entry.opa_is_dst && new_entry.opa.bits.overlaps(e.bits.opa.bits)
+        )
+      })
     }.otherwise {
-      // raw (after ld/ex)
-      new_entry.deps_ld := VecInit(entries_ld.map { e => e.valid && e.bits.opa.valid && not_config &&
-        new_entry.opa.bits.overlaps(e.bits.opa.bits)}) // raw
-
-      new_entry.deps_ex := VecInit(entries_ex.map { e => e.valid && e.bits.opa.valid && not_config &&
-        e.bits.opa_is_dst && new_entry.opa.bits.overlaps(e.bits.opa.bits)}) // raw only if ex is preload
+      assert((!new_entry.opa_is_dst) || new_entry.opb.valid)
+      // raw (after ld/ex), waw/war if destination is spad
+      new_entry.deps_ld := VecInit(entries_ld.map { e => e.valid && e.bits.opa.valid && not_config && (
+        new_entry.opa.bits.overlaps(e.bits.opa.bits) || // waw/raw
+        new_entry.opb.valid && new_entry.opb.bits.overlaps(e.bits.opa.bits))}) // raw
+
+      new_entry.deps_ex := VecInit(entries_ex.map { e => e.valid && not_config &&
+        Mux(new_entry.opa_is_dst,
+          // if st writes, war/waw for st a <- ex a/b
+          (new_entry.opa.bits.overlaps(e.bits.opa.bits) && e.bits.opa.valid) ||
+            (new_entry.opa.bits.overlaps(e.bits.opb.bits) && e.bits.opb.valid) ||
+            // additionally if ex writes, raw for st b <- ex a
+            (e.bits.opa.valid && e.bits.opa_is_dst && new_entry.opb.bits.overlaps(e.bits.opa.bits))
+          ,
+          // if st only reads, only check ex writes, raw for st a <- ex a
+          e.bits.opa.valid && e.bits.opa_is_dst && new_entry.opa.bits.overlaps(e.bits.opa.bits)
+        )
+      })
 
-      new_entry.deps_st := VecInit(entries_st.map { e => e.valid && !e.bits.issued }) // same q
+      // new_entry.deps_st := VecInit(entries_st.map { e => e.valid && !e.bits.issued }) // same q
+      new_entry.deps_st := VecInit(entries_st.map { e => e.valid }) // same q
     }
 
     new_entry.allocated_at := instructions_allocated
@@ -349,7 +385,8 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G
       .foreach { case (q, entries_type, new_allocs_type, entries_count) =>
       when (new_entry.q === q) {
         val is_full = PopCount(Seq(dst.valid, op1.valid, op2.valid)) > 1.U
-        when (q =/= exq) { assert(!is_full) }
+        when (q === ldq) { assert(!is_full) }
+        when (q === stq && new_entry.cmd.cmd.inst.funct =/= STORE_SPAD_CMD) { assert(!is_full) }
 
         // looking for the first invalid entry
         val alloc_id = MuxCase((entries_count - 1).U, entries_type.zipWithIndex.map { case (e, i) => !e.valid -> i.U })
@@ -425,7 +462,7 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G
       entries_type_.zipWithIndex.foreach { case (e, i) =>
         val deps_type = if (q == ldq) e.bits.deps_ld else if (q == exq) e.bits.deps_ex else e.bits.deps_st
-        when (q === q_) {
+        when ((q === q_) && (q_ =/= stq)) {
           deps_type(issue_id) := false.B
         }.otherwise {
           when (issue_entry.bits.complete_on_issue) {
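The reworked dependency logic above hinges on one new fact: a store entry can now write on-chip memory, so store/ex ordering must be checked in both directions. The ex-entry-versus-older-store predicate reduces to the following boolean model (a simplification for illustration; each overlap flag stands for one of the overlaps() tests in the code):

// stWrites = e.bits.opa_is_dst, exWrites = new_entry.opa_is_dst,
// aOverA = new.opa overlaps st.opa, bOverA = new.opb overlaps st.opa,
// aOverB = new.opa overlaps st.opb
def exDependsOnSt(stWrites: Boolean, exWrites: Boolean,
                  aOverA: Boolean, bOverA: Boolean, aOverB: Boolean): Boolean =
  if (stWrites) aOverA || bOverA || (exWrites && aOverB) // raw/waw, plus war when ex also writes
  else exWrites && aOverA                                // st only reads: war on ex's destination

The store-after-ex case in the .otherwise branch is the mirror image of this rule with the roles of the two entries swapped.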
diff --git a/src/main/scala/gemmini/Scratchpad.scala b/src/main/scala/gemmini/Scratchpad.scala
index cdd63062..b67da1f9 100644
--- a/src/main/scala/gemmini/Scratchpad.scala
+++ b/src/main/scala/gemmini/Scratchpad.scala
@@ -1,14 +1,13 @@
 package gemmini
 
-import chisel3._
+import chisel3.{Bool, _}
 import chisel3.util._
 import org.chipsalliance.cde.config.Parameters
 import freechips.rocketchip.diplomacy.{LazyModule, LazyModuleImp}
 import freechips.rocketchip.rocket._
 import freechips.rocketchip.tile._
 import freechips.rocketchip.tilelink._
-
 import Util._
 
 class ScratchpadMemReadRequest[U <: Data](local_addr_t: LocalAddr, scale_t_bits: Int)(implicit p: Parameters) extends CoreBundle {
@@ -32,6 +31,8 @@ class ScratchpadMemWriteRequest(local_addr_t: LocalAddr, acc_t_bits: Int, scale_
   val vaddr = UInt(coreMaxAddrBits.W)
   val laddr = local_addr_t.cloneType
 
+  val dest = UInt(1.W)
+
   val acc_act = UInt(Activation.bitwidth.W) // TODO don't use a magic number for the width here
   val acc_scale = UInt(scale_t_bits.W)
   val acc_igelu_qb = UInt(acc_t_bits.W)
@@ -107,63 +108,77 @@ class ScratchpadBank(n: Int, w: Int, aligned_to: Int, single_ported: Boolean, us
     val ext_mem = if (use_shared_ext_mem) Some(new ExtMemIO) else None
   })
 
-  val (read, write) = if (is_dummy) {
-    def read(addr: UInt, ren: Bool): Data = 0.U
-    def write(addr: UInt, wdata: Vec[UInt], wmask: Vec[Bool]): Unit = { }
-    (read _, write _)
-  } else if (use_shared_ext_mem) {
-    def read(addr: UInt, ren: Bool): Data = {
-      io.ext_mem.get.read_en := ren
-      io.ext_mem.get.read_addr := addr
-      io.ext_mem.get.read_data
-    }
-    io.ext_mem.get.write_en := false.B
-    io.ext_mem.get.write_addr := DontCare
-    io.ext_mem.get.write_data := DontCare
-    io.ext_mem.get.write_mask := DontCare
-    def write(addr: UInt, wdata: Vec[UInt], wmask: Vec[Bool]) = {
-      io.ext_mem.get.write_en := true.B
-      io.ext_mem.get.write_addr := addr
-      io.ext_mem.get.write_data := wdata.asUInt
-      io.ext_mem.get.write_mask := wmask.asUInt
-    }
-    (read _, write _)
-  } else {
-    val mem = SyncReadMem(n, Vec(mask_len, mask_elem))
-    def read(addr: UInt, ren: Bool): Data = mem.read(addr, ren)
-    def write(addr: UInt, wdata: Vec[UInt], wmask: Vec[Bool]) = mem.write(addr, wdata, wmask)
-    (read _, write _)
-  }
+  val ren = io.read.req.fire
+  val fromDMA = io.read.req.bits.fromDMA
+  // Make a queue which buffers the result of an SRAM read if it can't immediately be consumed
+  val q = Module(new Queue(new ScratchpadReadResp(w), 1, true, true))
+  val q_will_be_empty = (q.io.count +& q.io.enq.fire) - q.io.deq.fire === 0.U
 
   // When the scratchpad is single-ported, the writes take precedence
   val singleport_busy_with_write = single_ported.B && io.write.en
 
-  when (io.write.en) {
-    if (aligned_to >= w)
-      write(io.write.addr, io.write.data.asTypeOf(Vec(mask_len, mask_elem)), VecInit((~(0.U(mask_len.W))).asBools))
-    else
-      write(io.write.addr, io.write.data.asTypeOf(Vec(mask_len, mask_elem)), io.write.mask)
-  }
-
-  val raddr = io.read.req.bits.addr
-  val ren = io.read.req.fire
-  val rdata = if (single_ported) {
-    assert(!(ren && io.write.en))
-    read(raddr, ren && !io.write.en).asUInt
-  } else {
-    read(raddr, ren).asUInt
-  }
+  if (is_dummy) {
+    q.io.enq.valid := RegNext(ren)
+    q.io.enq.bits.data := 0.U
+    q.io.enq.bits.fromDMA := RegNext(fromDMA)
+    io.read.req.ready := q_will_be_empty && !singleport_busy_with_write
+  } else if (use_shared_ext_mem) { // use ready-valid interface
+    val ext_mem = io.ext_mem.get
+
+    /* READ */
+    ext_mem.read_req.valid := io.read.req.valid
+    ext_mem.read_req.bits := io.read.req.bits.addr
+    io.read.req.ready := ext_mem.read_req.ready
+
+    // TODO (richard): the number of entries here should be configurable
+    val dma_q = Module(new Queue(Bool(), 4, true, true))
+    dma_q.io.enq.valid := ren
+    dma_q.io.enq.bits := fromDMA
+    dma_q.io.deq.ready := q.io.enq.fire
+    assert(dma_q.io.deq.fire || (!q.io.enq.fire), "fromDMA should be dequeued only when read resp comes back")
+    assert(dma_q.io.enq.ready || (!ren), "DMA queue does not have enough entries") // TODO (richard): do backpressure
+
+    q.io.enq.valid := ext_mem.read_resp.valid
+    q.io.enq.bits.data := ext_mem.read_resp.bits
+    q.io.enq.bits.fromDMA := dma_q.io.deq.bits
+    ext_mem.read_resp.ready := q.io.enq.ready
+
+    /* WRITE */
+    val wq = Module(new Queue(ext_mem.write_req.bits.cloneType, 12, pipe=true, flow=true))
+    ext_mem.write_req <> wq.io.deq
+
+    wq.io.enq.valid := io.write.en
+    wq.io.enq.bits.addr := io.write.addr
+    wq.io.enq.bits.data := io.write.data
+    if (aligned_to >= w) {
+      wq.io.enq.bits.mask := VecInit((~(0.U(mask_len.W))).asBools).asUInt
+    } else {
+      wq.io.enq.bits.mask := io.write.mask.asUInt
+    }
+    assert(wq.io.enq.ready || (!io.write.en), "TODO (richard): fix this if triggered")
+  } else { // use valid only interface
+    val mem = SyncReadMem(n, Vec(mask_len, mask_elem))
 
-  val fromDMA = io.read.req.bits.fromDMA
+    val raddr = io.read.req.bits.addr
+    val rdata = if (single_ported) {
+      assert(!(ren && io.write.en))
+      mem.read(raddr, ren && !io.write.en).asUInt
+    } else {
+      mem.read(raddr, ren).asUInt
+    }
+    q.io.enq.valid := RegNext(ren)
+    q.io.enq.bits.data := rdata
+    q.io.enq.bits.fromDMA := RegNext(fromDMA)
 
-  // Make a queue which buffers the result of an SRAM read if it can't immediately be consumed
-  val q = Module(new Queue(new ScratchpadReadResp(w), 1, true, true))
-  q.io.enq.valid := RegNext(ren)
-  q.io.enq.bits.data := rdata
-  q.io.enq.bits.fromDMA := RegNext(fromDMA)
+    io.read.req.ready := q_will_be_empty && !singleport_busy_with_write
 
-  val q_will_be_empty = (q.io.count +& q.io.enq.fire) - q.io.deq.fire === 0.U
-  io.read.req.ready := q_will_be_empty && !singleport_busy_with_write
+    when(io.write.en) {
+      if (aligned_to >= w)
+        mem.write(io.write.addr, io.write.data.asTypeOf(Vec(mask_len, mask_elem)), VecInit((~(0.U(mask_len.W))).asBools))
+      else
+        mem.write(io.write.addr, io.write.data.asTypeOf(Vec(mask_len, mask_elem)), io.write.mask)
+    }
+  }
 
   io.read.resp <> q.io.deq
 }
@@ -192,6 +207,9 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
   val writer = LazyModule(new StreamWriter(max_in_flight_mem_reqs, dataBits, maxBytes,
     if (acc_read_full_width) acc_w else spad_w, aligned_to, inputType, block_cols, use_tlb_register_filter,
     use_firesim_simulation_counters))
+  val spad_writer = LazyModule(new StreamWriter(max_in_flight_mem_reqs, dataBits, maxBytes,
+    if (acc_read_full_width) acc_w else spad_w, aligned_to, inputType, block_cols, use_tlb_register_filter,
+    use_firesim_simulation_counters))
 
   // TODO make a cross-bar vs two separate ports a config option
   // id_node :=* reader.node
@@ -237,7 +255,8 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
     }
 
     // TLB ports
-    val tlb = Vec(2, new FrontendTLBIO)
+    // TODO(richard): bypass TLB
+    val tlb = Vec(3, new FrontendTLBIO)
 
     // Misc. ports
     val busy = Output(Bool())
@@ -291,9 +310,10 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
     write_issue_q.io.deq.bits.laddr.is_acc_addr && write_issue_q.io.deq.bits.laddr.read_full_acc_row
   val writeData_is_all_zeros = write_issue_q.io.deq.bits.laddr.is_garbage()
 
-  writer.module.io.req.valid := write_issue_q.io.deq.valid && writeData.valid
-  write_issue_q.io.deq.ready := writer.module.io.req.ready && writeData.valid
+  writer.module.io.req.valid := write_issue_q.io.deq.valid && writeData.valid && !write_issue_q.io.deq.bits.dest.asBool
+  // write_issue_q.io.deq.ready := writer.module.io.req.ready && writeData.valid
   writer.module.io.req.bits.vaddr := write_issue_q.io.deq.bits.vaddr
+  writer.module.io.req.bits.physical := write_issue_q.io.deq.bits.dest
   writer.module.io.req.bits.len := Mux(writeData_is_full_width,
     write_issue_q.io.deq.bits.len * (accType.getWidth / 8).U,
     write_issue_q.io.deq.bits.len * (inputType.getWidth / 8).U)
@@ -306,6 +326,22 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
   writer.module.io.req.bits.pool_en := write_issue_q.io.deq.bits.pool_en
   writer.module.io.req.bits.store_en := write_issue_q.io.deq.bits.store_en
 
+  spad_writer.module.io.req.valid := write_issue_q.io.deq.valid && writeData.valid && write_issue_q.io.deq.bits.dest.asBool
+  write_issue_q.io.deq.ready := writer.module.io.req.ready && spad_writer.module.io.req.ready && writeData.valid
+  spad_writer.module.io.req.bits.vaddr := write_issue_q.io.deq.bits.vaddr << 4.U // TODO(richard): do not hardcode
+  spad_writer.module.io.req.bits.physical := write_issue_q.io.deq.bits.dest
+  spad_writer.module.io.req.bits.len := Mux(writeData_is_full_width,
+    write_issue_q.io.deq.bits.len * (accType.getWidth / 8).U,
+    write_issue_q.io.deq.bits.len * (inputType.getWidth / 8).U)
+  spad_writer.module.io.req.bits.data := MuxCase(writeData.bits, Seq(
+    writeData_is_all_zeros -> 0.U,
+    writeData_is_full_width -> fullAccWriteData
+  ))
+  spad_writer.module.io.req.bits.block := write_issue_q.io.deq.bits.block
+  spad_writer.module.io.req.bits.status := write_issue_q.io.deq.bits.status
+  spad_writer.module.io.req.bits.pool_en := write_issue_q.io.deq.bits.pool_en
+  spad_writer.module.io.req.bits.store_en := write_issue_q.io.deq.bits.store_en
+
   io.dma.write.resp.valid := false.B
   io.dma.write.resp.bits.cmd_id := write_dispatch_q.bits.cmd_id
 
   when (write_dispatch_q.bits.laddr.is_garbage() && write_dispatch_q.fire) {
@@ -437,11 +473,14 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
 
   io.tlb(0) <> writer.module.io.tlb
   io.tlb(1) <> reader.module.io.tlb
+  io.tlb(2) <> spad_writer.module.io.tlb
 
+  spad_writer.module.io.flush := io.flush
   writer.module.io.flush := io.flush
   reader.module.io.flush := io.flush
 
-  io.busy := writer.module.io.busy || reader.module.io.busy || write_issue_q.io.deq.valid || write_norm_q.io.deq.valid || write_scale_q.io.deq.valid || write_dispatch_q.valid
+  io.busy := writer.module.io.busy || spad_writer.module.io.busy || reader.module.io.busy ||
+    write_issue_q.io.deq.valid || write_norm_q.io.deq.valid || write_scale_q.io.deq.valid || write_dispatch_q.valid
 
   val spad_mems = {
     val banks = Seq.fill(sp_banks) {
       Module(new ScratchpadBank(
@@ -493,20 +532,24 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
       ex_read_resp.valid := bio.read.resp.valid && !bio.read.resp.bits.fromDMA
       ex_read_resp.bits := bio.read.resp.bits
 
-      val dma_read_pipe = Pipeline(dma_read_resp, spad_read_delay)
-      val ex_read_pipe = Pipeline(ex_read_resp, spad_read_delay)
+      val dma_read_pipe = Module(new Queue(dma_read_resp.bits.cloneType, spad_read_delay, flow = false, pipe = true))
+      val ex_read_pipe = Module(new Queue(ex_read_resp.bits.cloneType, spad_read_delay, flow = false, pipe = true))
+
+      dma_read_pipe.io.enq <> dma_read_resp
+      ex_read_pipe.io.enq <> ex_read_resp
 
+      // TODO (richard): perhaps backpressure can be applied here when writes are blocked
       bio.read.resp.ready := Mux(bio.read.resp.bits.fromDMA, dma_read_resp.ready, ex_read_resp.ready)
 
-      dma_read_pipe.ready := writer.module.io.req.ready &&
+      dma_read_pipe.io.deq.ready := writer.module.io.req.ready && spad_writer.module.io.req.ready &&
        !write_issue_q.io.deq.bits.laddr.is_acc_addr && write_issue_q.io.deq.bits.laddr.sp_bank() === i.U && // I believe we don't need to check that write_issue_q is valid here, because if the SRAM's resp is valid, then that means that the write_issue_q's deq should also be valid
        !write_issue_q.io.deq.bits.laddr.is_garbage()
 
-      when (dma_read_pipe.fire) {
+      when (dma_read_pipe.io.deq.fire) {
         writeData.valid := true.B
-        writeData.bits := dma_read_pipe.bits.data
+        writeData.bits := dma_read_pipe.io.deq.bits.data
       }
-      io.srams.read(i).resp <> ex_read_pipe
+      io.srams.read(i).resp <> ex_read_pipe.io.deq
     }
 
     // Writing to the SRAM banks
@@ -597,14 +640,14 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
     acc_scale_unit.io.in.valid := acc_norm_unit_out.valid && acc_waiting_to_be_scaled
     acc_scale_unit.io.in.bits := acc_norm_unit_out.bits
 
-    when (acc_scale_unit.io.in.fire()) {
+    when (acc_scale_unit.io.in.fire) {
       write_issue_q.io.enq <> write_scale_q.io.deq
     }
 
     acc_scale_unit.io.out.ready := false.B
 
    val dma_resp_ready =
-      writer.module.io.req.ready &&
+      (writer.module.io.req.ready && spad_writer.module.io.req.ready) &&
       write_issue_q.io.deq.bits.laddr.is_acc_addr &&
      !write_issue_q.io.deq.bits.laddr.is_garbage()
 
@@ -631,7 +674,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
     val banks = Seq.fill(acc_banks) {
       Module(new AccumulatorMem(
         acc_bank_entries, acc_row_t, acc_scale_func, acc_scale_t.asInstanceOf[V], acc_singleported, acc_sub_banks,
-        use_shared_ext_mem,
+        use_shared_ext_mem, use_tl_ext_mem,
         acc_latency, accType, is_dummy
       ))
     }
     val bank_ios = VecInit(banks.map(_.io))
@@ -828,5 +871,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
     io.counter := DontCare
     io.counter.collect(reader.module.io.counter)
     io.counter.collect(writer.module.io.counter)
+    spad_writer.module.io.counter := DontCare
+// io.counter.collect(spad_writer.module.io.counter)
   }
 }
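In the shared-external-memory path of ScratchpadBank above, the read response returns on a separate decoupled channel, so the fromDMA flag can no longer ride along with the SRAM read; a small side queue keeps the flags in flight order instead. A stripped-down sketch of that bookkeeping; the module name and queue depth are illustrative:

import chisel3._
import chisel3.util._

class TagInFlight extends Module {
  val io = IO(new Bundle {
    val req_fire  = Input(Bool())   // a read request was accepted
    val tag_in    = Input(Bool())   // fromDMA at request time
    val resp_fire = Input(Bool())   // the matching response is consumed
    val tag_out   = Output(Bool())  // fromDMA, re-associated with the response
  })

  // Responses return in order, so a FIFO re-attaches each tag to its response.
  val tags = Module(new Queue(Bool(), 4, pipe = true, flow = true))
  tags.io.enq.valid := io.req_fire
  tags.io.enq.bits  := io.tag_in
  tags.io.deq.ready := io.resp_fire
  io.tag_out := tags.io.deq.bits

  assert(tags.io.enq.ready || !io.req_fire, "tag queue overflow: too many reads in flight")
}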
diff --git a/src/main/scala/gemmini/SharedExtMem.scala b/src/main/scala/gemmini/SharedExtMem.scala
index f3acdd2e..26f28a16 100644
--- a/src/main/scala/gemmini/SharedExtMem.scala
+++ b/src/main/scala/gemmini/SharedExtMem.scala
@@ -7,14 +7,13 @@ import Util._
 
 class ExtMemIO extends Bundle {
-  val read_en = Output(Bool())
-  val read_addr = Output(UInt())
-  val read_data = Input(UInt())
-
-  val write_en = Output(Bool())
-  val write_addr = Output(UInt())
-  val write_data = Output(UInt())
-  val write_mask = Output(UInt())
+  val read_req = DecoupledIO(UInt())
+  val read_resp = Flipped(DecoupledIO(UInt()))
+  val write_req = DecoupledIO(new Bundle {
+    val addr = Output(UInt())
+    val data = Output(UInt())
+    val mask = Output(UInt())
+  })
 }
 
 class ExtSpadMemIO(sp_banks: Int, acc_banks: Int, acc_sub_banks: Int) extends Bundle {
@@ -28,18 +27,18 @@ class SharedSyncReadMem(nSharers: Int, depth: Int, mask_len: Int, data_len: Int)
     val in = Vec(nSharers, Flipped(new ExtMemIO()))
   })
   val mem = SyncReadMem(depth, Vec(mask_len, UInt(data_len.W)))
-  val wens = io.in.map(_.write_en)
+  val wens = io.in.map(_.write_req.valid)
   val wen = wens.reduce(_||_)
-  val waddr = Mux1H(wens, io.in.map(_.write_addr))
-  val wmask = Mux1H(wens, io.in.map(_.write_mask))
-  val wdata = Mux1H(wens, io.in.map(_.write_data))
+  val waddr = Mux1H(wens, io.in.map(_.write_req.bits.addr))
+  val wmask = Mux1H(wens, io.in.map(_.write_req.bits.mask))
+  val wdata = Mux1H(wens, io.in.map(_.write_req.bits.data))
   assert(PopCount(wens) <= 1.U)
-  val rens = io.in.map(_.read_en)
+  val rens = io.in.map(_.read_req.valid)
   assert(PopCount(rens) <= 1.U)
   val ren = rens.reduce(_||_)
-  val raddr = Mux1H(rens, io.in.map(_.read_addr))
+  val raddr = Mux1H(rens, io.in.map(_.read_req.bits))
   val rdata = mem.read(raddr, ren && !wen)
-  io.in.foreach(_.read_data := rdata.asUInt)
+  io.in.foreach(_.read_resp.bits := rdata.asUInt)
   when (wen) {
     mem.write(waddr, wdata.asTypeOf(Vec(mask_len, UInt(data_len.W))), wmask.asTypeOf(Vec(mask_len, Bool())))
   }
@@ -68,12 +67,14 @@ class SharedExtMem(
       acc_mem.io.in(0) <> io.in(0).acc(i)(s)
       // The FP gemmini expects a taller, skinnier accumulator mem
       acc_mem.io.in(1) <> io.in(1).acc(i)(s)
-      acc_mem.io.in(1).read_addr := io.in(1).acc(i)(s).read_addr >> 1
-      io.in(1).acc(i)(s).read_data := acc_mem.io.in(1).read_data.asTypeOf(Vec(2, UInt((acc_data_len * acc_mask_len / 2).W)))(RegNext(io.in(1).acc(i)(s).read_addr(0)))
+      acc_mem.io.in(1).read_req.bits := io.in(1).acc(i)(s).read_req.bits >> 1
+      io.in(1).acc(i)(s).read_resp.bits := acc_mem.io.in(1).read_resp.bits.asTypeOf(Vec(2, UInt((acc_data_len * acc_mask_len / 2).W)))(RegNext(io.in(1).acc(i)(s).read_req.bits(0)))
 
-      acc_mem.io.in(1).write_addr := io.in(1).acc(i)(s).write_addr >> 1
-      acc_mem.io.in(1).write_data := Cat(io.in(1).acc(i)(s).write_data, io.in(1).acc(i)(s).write_data)
-      acc_mem.io.in(1).write_mask := Mux(io.in(1).acc(i)(s).write_addr(0), io.in(1).acc(i)(s).write_mask << (acc_mask_len / 2), io.in(1).acc(i)(s).write_mask)
+      acc_mem.io.in(1).write_req.bits.addr := io.in(1).acc(i)(s).write_req.bits.addr >> 1
+      acc_mem.io.in(1).write_req.bits.data := Cat(io.in(1).acc(i)(s).write_req.bits.data,
+        io.in(1).acc(i)(s).write_req.bits.data)
+      acc_mem.io.in(1).write_req.bits.mask := Mux(io.in(1).acc(i)(s).write_req.bits.addr(0),
+        io.in(1).acc(i)(s).write_req.bits.mask << (acc_mask_len / 2), io.in(1).acc(i)(s).write_req.bits.mask)
     }
   }
 }
diff --git a/src/main/scala/gemmini/StoreController.scala b/src/main/scala/gemmini/StoreController.scala
index 72cd761b..4be9f920 100644
--- a/src/main/scala/gemmini/StoreController.scala
+++ b/src/main/scala/gemmini/StoreController.scala
@@ -84,6 +84,10 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm
   // Commands
   val cmd = Queue(io.cmd, st_queue_length)
   val vaddr = cmd.bits.cmd.rs1
+  val mvout_spad_rs1 = cmd.bits.cmd.rs1.asTypeOf(new MvoutSpadRs1(32, local_addr_t))
+  val dst_spad_addr = mvout_spad_rs1.local_addr
+  val dst_spad_stride = mvout_spad_rs1.stride
+  val dst_is_spad = cmd.bits.cmd.inst.funct === STORE_SPAD_CMD
   val mvout_rs2 = cmd.bits.cmd.rs2.asTypeOf(new MvoutRs2(mvout_rows_bits, mvout_cols_bits, local_addr_t))
   val localaddr = mvout_rs2.local_addr
   val cols = mvout_rs2.num_cols
@@ -122,6 +126,7 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm
 
   val current_vaddr = vaddr + row_counter * stride
   val current_localaddr = WireInit(localaddr + (block_counter * block_stride + row_counter))
+  val current_dst_spad_addr = dst_spad_addr.asUInt + row_counter * dst_spad_stride
 
   val pool_row_addr = localaddr + (orow * pool_ocols +& ocol)
   when (orow_is_negative || ocol_is_negative || orow >= pool_orows || ocol >= pool_ocols) {
@@ -157,7 +162,10 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm
     (control_state === sending_rows && (block_counter =/= 0.U || row_counter =/= 0.U)) ||
     (control_state === pooling && (wcol_counter =/= 0.U || wrow_counter =/= 0.U || pocol_counter =/= 0.U || porow_counter =/= 0.U))
 
-  io.dma.req.bits.vaddr := Mux(pooling_is_enabled || mvout_1d_enabled, pool_vaddr, current_vaddr)
+  io.dma.req.bits.vaddr := Mux(dst_is_spad,
+    current_dst_spad_addr,
+    Mux(pooling_is_enabled || mvout_1d_enabled, pool_vaddr, current_vaddr))
+  io.dma.req.bits.dest := dst_is_spad
   io.dma.req.bits.laddr := Mux(pooling_is_enabled, pool_row_addr, current_localaddr) //Todo: laddr for 1D?
   io.dma.req.bits.laddr.norm_cmd := Mux(block_counter === blocks - 1.U, current_localaddr.norm_cmd,
     NormCmd.non_reset_version(current_localaddr.norm_cmd))
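The FP-Gemmini adapter at the end of SharedExtMem.scala folds a tall, narrow accumulator view onto a half-depth, double-width backing memory: the address drops its LSB, the write data is replicated into both halves, and the mask selects the half indicated by the dropped bit. The same transformation as a pure-Scala sketch (function and parameter names are mine):

// View write -> backing write, as in the acc_mem.io.in(1) connections above.
def foldWrite(addr: BigInt, data: BigInt, mask: BigInt,
              dataBits: Int, maskLen: Int): (BigInt, BigInt, BigInt) = {
  val backingAddr = addr >> 1                        // half the depth
  val backingData = (data << dataBits) | data        // Cat(data, data)
  val backingMask =                                  // pick the upper or lower half
    if ((addr & 1) == 1) mask << (maskLen / 2) else mask
  (backingAddr, backingData, backingMask)
}

On the read side the inverse applies: the full-width read data is split in two and a half is selected by RegNext(addr(0)), matching the one-cycle SRAM read latency.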