diff --git a/.circleci/build-toolchains.sh b/.circleci/build-toolchains.sh index 39caa772..1a4e4b13 100755 --- a/.circleci/build-toolchains.sh +++ b/.circleci/build-toolchains.sh @@ -28,5 +28,5 @@ if [ ! -d "$HOME/$1-install" ]; then cd $HOME # init all submodules including the tools (doesn't use CI_MAKE_PROC due to mem. constraints) - CHIPYARD_DIR="$LOCAL_CHIPYARD_DIR" NPROC=$CI_MAKE_PROC $LOCAL_CHIPYARD_DIR/scripts/build-toolchains.sh esp-tools + CHIPYARD_DIR="$LOCAL_CHIPYARD_DIR" NPROC=$CI_MAKE_NPROC $LOCAL_CHIPYARD_DIR/scripts/build-toolchains.sh esp-tools fi diff --git a/.circleci/defaults.sh b/.circleci/defaults.sh index 6100774a..2d200104 100755 --- a/.circleci/defaults.sh +++ b/.circleci/defaults.sh @@ -14,7 +14,7 @@ ############# # make parallelism -CI_MAKE_NPROC=8 +CI_MAKE_NPROC=4 LOCAL_MAKE_NPROC=$CI_MAKE_NPROC # verilator version diff --git a/CHIPYARD.hash b/CHIPYARD.hash index 8f7c41f5..70a1842b 100644 --- a/CHIPYARD.hash +++ b/CHIPYARD.hash @@ -1 +1 @@ -939e3a9f94d5bfef9671f49c37cd3acd5fc26128 +1e2f778a6705033d67ccbcc932e66083e4646f15 diff --git a/README.md b/README.md index 1853c913..a70a2146 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ Gemmini is implemented as a RoCC accelerator with non-standard RISC-V custom ins At the heart of the accelerator lies a systolic array which performs matrix multiplications. By default, the matrix multiplication supports both _output-stationary_ and _weight-stationary_ dataflows, which programmers can pick between at runtime. However, the dataflow can also be hardened at elaboration time. -The systolic array's inputs and outputs are stored in an explicity managed scratchpad, made up of banked SRAMs. A DMA engine facilitates the tranfer of data between main memory and the scratchpad. +The systolic array's inputs and outputs are stored in an explicitly managed scratchpad, made up of banked SRAMs. A DMA engine facilitates the transfer of data between main memory and the scratchpad. Because weight-stationary dataflows require an accumulator outside the systolic array, we add a final SRAM bank, equipped with adder units, which can be conceptually considered an extension of the scratchpad memory space. The systolic array can store results to any address in the accumulator, and can also read new inputs from any address in the accumulator. The DMA engine can also transfer data directly between the accumulator and main memory, which is often necessary to load in biases. @@ -75,7 +75,7 @@ The ``software`` directory of the generator includes the aforementioned library The Gemmini generator generates a C header file based on the generator parameters. This header file gets compiled together with the matrix multiplication library to tune library performance. The generated header file can be found under ``software/gemmini-rocc-tests/include/gemmini_params.h`` Gemmini can also be used to run ONNX-specified neural-networks through a port of Microsoft's ONNX-Runtime framework. The port is included as the [onnxruntime-riscv](https://github.com/pranav-prakash/onnxruntime-riscv) repository submoduled in the `software` directory. -To start using ONNX-Runtime, run `git submodule update --init --recursive software/onnxruntime-riscv`, and read the documentation at [here](https://github.com/pranav-prakash/onnxruntime-riscv/blob/systolic/systolic_runner/docs). 
+To start using ONNX-Runtime, run `git submodule update --init --recursive software/onnxruntime-riscv`, and read the documentation [here](https://github.com/pranav-prakash/onnxruntime-riscv/blob/systolic/systolic_runner/docs). ## Build and Run Gemmini Tests @@ -317,3 +317,15 @@ This section describes an additional set of RoCC instructions that configure and ### `COMPUTE_CISC` runs a complete hardware tiling sequence with the configured A, B, C, D, M, N, K, RPT_BIAS values **Format:** `compute_cisc` - `funct` = 17 + +# Citing Gemmini +If Gemmini helps you in your academic research, you are encouraged to cite our paper. Here is an example bibtex: +``` +@article{genc2019gemmini, + title={Gemmini: An Agile Systolic Array Generator Enabling Systematic Evaluations of Deep-Learning Architectures}, + author={Genc, Hasan and Haj-Ali, Ameer and Iyer, Vighnesh and Amid, Alon and Mao, Howard and Wright, John and Schmidt, Colin and Zhao, Jerry and Ou, Albert and Banister, Max and Shao, Yakun Sophia and Nikolic, Borivoje and Stoica, Ion and Asanovic, Krste}, + journal={arXiv preprint arXiv:1911.09925}, + year={2019} +} +``` + diff --git a/SPIKE.hash b/SPIKE.hash index df14dfbe..3137a068 100644 --- a/SPIKE.hash +++ b/SPIKE.hash @@ -1 +1 @@ -3db7a449d97bf40a101ef541089054e6af59d7df \ No newline at end of file +bc3222e351cdd645b6fd2605fd9611e3bc0d9cae diff --git a/software/gemmini-rocc-tests b/software/gemmini-rocc-tests index bf934460..bd9dbe0b 160000 --- a/software/gemmini-rocc-tests +++ b/software/gemmini-rocc-tests @@ -1 +1 @@ -Subproject commit bf934460b1addc96ebbf50ab4574a4dc46959703 +Subproject commit bd9dbe0b0dcde33b5445711ed27c6840167c10bf diff --git a/src/main/scala/gemmini/AccumulatorMem.scala b/src/main/scala/gemmini/AccumulatorMem.scala index e218bd51..1939c624 100644 --- a/src/main/scala/gemmini/AccumulatorMem.scala +++ b/src/main/scala/gemmini/AccumulatorMem.scala @@ -17,19 +17,21 @@ class AccumulatorReadReq[T <: Data](n: Int, shift_width: Int, scale_t: T) extend override def cloneType: this.type = new AccumulatorReadReq(n, shift_width, scale_t.cloneType).asInstanceOf[this.type] } -class AccumulatorReadResp[T <: Data: Arithmetic](rdataType: Vec[Vec[T]], fullDataType: Vec[Vec[T]]) extends Bundle { - val data = rdataType.cloneType - val full_data = fullDataType.cloneType +class AccumulatorReadResp[T <: Data: Arithmetic, U <: Data](fullDataType: Vec[Vec[T]], scale_t: U, shift_width: Int) extends Bundle { + val data = fullDataType.cloneType val fromDMA = Bool() - - override def cloneType: this.type = new AccumulatorReadResp(rdataType.cloneType, fullDataType.cloneType).asInstanceOf[this.type] + val scale = scale_t.cloneType + val relu6_shift = UInt(shift_width.W) + val act = UInt(2.W) + val acc_bank_id = UInt(2.W) // TODO don't hardcode + override def cloneType: this.type = new AccumulatorReadResp(fullDataType.cloneType, scale_t, shift_width).asInstanceOf[this.type] } -class AccumulatorReadIO[T <: Data: Arithmetic, U <: Data](n: Int, shift_width: Int, rdataType: Vec[Vec[T]], fullDataType: Vec[Vec[T]], scale_t: U) extends Bundle { - val req = Decoupled(new AccumulatorReadReq(n, shift_width, scale_t)) - val resp = Flipped(Decoupled(new AccumulatorReadResp(rdataType.cloneType, fullDataType.cloneType))) +class AccumulatorReadIO[T <: Data: Arithmetic, U <: Data](n: Int, shift_width: Int, fullDataType: Vec[Vec[T]], scale_t: U) extends Bundle { + val req = Decoupled(new AccumulatorReadReq[U](n, shift_width, scale_t)) + val resp = Flipped(Decoupled(new AccumulatorReadResp[T, U](fullDataType, scale_t, 
shift_width))) - override def cloneType: this.type = new AccumulatorReadIO(n, shift_width, rdataType.cloneType, fullDataType.cloneType, scale_t.cloneType).asInstanceOf[this.type] + override def cloneType: this.type = new AccumulatorReadIO(n, shift_width, fullDataType.cloneType, scale_t.cloneType).asInstanceOf[this.type] } class AccumulatorWriteReq[T <: Data: Arithmetic](n: Int, t: Vec[Vec[T]]) extends Bundle { @@ -42,16 +44,19 @@ class AccumulatorWriteReq[T <: Data: Arithmetic](n: Int, t: Vec[Vec[T]]) extends override def cloneType: this.type = new AccumulatorWriteReq(n, t).asInstanceOf[this.type] } -class AccumulatorMemIO [T <: Data: Arithmetic, U <: Data](n: Int, t: Vec[Vec[T]], rdata: Vec[Vec[T]], scale_t: U) extends Bundle { - val read = Flipped(new AccumulatorReadIO(n, log2Ceil(t.head.head.getWidth), rdata, t, scale_t)) +class AccumulatorMemIO [T <: Data: Arithmetic, U <: Data](n: Int, t: Vec[Vec[T]], scale_t: U) extends Bundle { + val read = Flipped(new AccumulatorReadIO(n, log2Ceil(t.head.head.getWidth), t, scale_t)) // val write = Flipped(new AccumulatorWriteIO(n, t)) val write = Flipped(Decoupled(new AccumulatorWriteReq(n, t))) - override def cloneType: this.type = new AccumulatorMemIO(n, t, rdata, scale_t).asInstanceOf[this.type] + override def cloneType: this.type = new AccumulatorMemIO(n, t, scale_t).asInstanceOf[this.type] } -class AccumulatorMem[T <: Data, U <: Data](n: Int, t: Vec[Vec[T]], rdataType: Vec[Vec[T]], mem_pipeline: Int, scale_args: ScaleArguments[T, U], read_small_data: Boolean, read_full_data: Boolean) - (implicit ev: Arithmetic[T]) extends Module { +class AccumulatorMem[T <: Data, U <: Data]( + n: Int, t: Vec[Vec[T]], scale_args: ScaleArguments[T, U], + acc_singleported: Boolean, num_acc_sub_banks: Int +) + (implicit ev: Arithmetic[T]) extends Module { // TODO Do writes in this module work with matrices of size 2? If we try to read from an address right after writing // to it, then we might not get the written data. We might need some kind of cooldown counter after addresses in the // accumulator have been written to for configurations with such small matrices @@ -64,9 +69,8 @@ class AccumulatorMem[T <: Data, U <: Data](n: Int, t: Vec[Vec[T]], rdataType: Ve import ev._ // TODO unify this with TwoPortSyncMemIO - val io = IO(new AccumulatorMemIO(n, t, rdataType, scale_args.multiplicand_t)) + val io = IO(new AccumulatorMemIO(n, t, scale_args.multiplicand_t)) - val mem = TwoPortSyncMem(n, t, t.getWidth / 8) // TODO We assume byte-alignment here. 
Use aligned_to instead // For any write operation, we spend 2 cycles reading the existing address out, buffering it in a register, and then // accumulating on top of it (if necessary) @@ -75,55 +79,146 @@ class AccumulatorMem[T <: Data, U <: Data](n: Int, t: Vec[Vec[T]], rdataType: Ve val acc_buf = ShiftRegister(io.write.bits.acc, 2) val mask_buf = ShiftRegister(io.write.bits.mask, 2) val w_buf_valid = ShiftRegister(io.write.fire(), 2) - - val w_sum = VecInit((RegNext(mem.io.rdata) zip wdata_buf).map { case (rv, wv) => + val acc_rdata = Wire(t) + acc_rdata := DontCare + val read_rdata = Wire(t) + read_rdata := DontCare + val block_read_req = WireInit(false.B) + val w_sum = VecInit((RegNext(acc_rdata) zip wdata_buf).map { case (rv, wv) => VecInit((rv zip wv).map(t => t._1 + t._2)) }) - mem.io.waddr := waddr_buf - mem.io.wen := w_buf_valid - mem.io.wdata := Mux(acc_buf, w_sum, wdata_buf) - mem.io.mask := mask_buf - - mem.io.raddr := Mux(io.write.fire() && io.write.bits.acc, io.write.bits.addr, io.read.req.bits.addr) - mem.io.ren := io.read.req.fire() || (io.write.fire() && io.write.bits.acc) - - class PipelinedRdataAndActT extends Bundle { - val data = mem.io.rdata.cloneType - val full_data = mem.io.rdata.cloneType - val scale = io.read.req.bits.scale.cloneType - val relu6_shift = io.read.req.bits.relu6_shift.cloneType - val act = io.read.req.bits.act.cloneType - val fromDMA = io.read.req.bits.fromDMA.cloneType + if (!acc_singleported) { + val mem = TwoPortSyncMem(n, t, t.getWidth / 8) // TODO We assume byte-alignment here. Use aligned_to instead + mem.io.waddr := waddr_buf + mem.io.wen := w_buf_valid + mem.io.wdata := Mux(acc_buf, w_sum, wdata_buf) + mem.io.mask := mask_buf + acc_rdata := mem.io.rdata + read_rdata := mem.io.rdata + mem.io.raddr := Mux(io.write.fire() && io.write.bits.acc, io.write.bits.addr, io.read.req.bits.addr) + mem.io.ren := io.read.req.fire() || (io.write.fire() && io.write.bits.acc) + } else { + val mask_len = t.getWidth / 8 + val mask_elem = UInt((t.getWidth / mask_len).W) + val reads = Wire(Vec(2, Decoupled(UInt()))) + reads(0).valid := io.write.valid && io.write.bits.acc + reads(0).bits := io.write.bits.addr + reads(0).ready := true.B + reads(1).valid := io.read.req.valid + reads(1).bits := io.read.req.bits.addr + reads(1).ready := true.B + block_read_req := !reads(1).ready + for (i <- 0 until num_acc_sub_banks) { + def isThisBank(addr: UInt) = addr(log2Ceil(num_acc_sub_banks)-1,0) === i.U + def getBankIdx(addr: UInt) = addr >> log2Ceil(num_acc_sub_banks) + val mem = SyncReadMem(n / num_acc_sub_banks, Vec(mask_len, mask_elem)) + + val ren = WireInit(false.B) + val raddr = WireInit(getBankIdx(reads(0).bits)) + val nEntries = 3 + // Writes coming 2 cycles after read leads to bad bank behavior + // Add another buffer here + class W_Q_Entry[T <: Data](mask_len: Int, mask_elem: T) extends Bundle { + val valid = Bool() + val data = Vec(mask_len, mask_elem) + val mask = Vec(mask_len, Bool()) + val addr = UInt(log2Ceil(n/num_acc_sub_banks).W) + override def cloneType: this.type = new W_Q_Entry(mask_len, mask_elem).asInstanceOf[this.type] + } + val w_q = Reg(Vec(nEntries, new W_Q_Entry(mask_len, mask_elem))) + for (e <- w_q) { + when (e.valid) { + assert(!( + io.write.valid && io.write.bits.acc && + isThisBank(io.write.bits.addr) && getBankIdx(io.write.bits.addr) === e.addr && + ((io.write.bits.mask.asUInt & e.mask.asUInt) =/= 0.U) + )) + when (io.read.req.valid && isThisBank(io.read.req.bits.addr) && getBankIdx(io.read.req.bits.addr) === e.addr) { + reads(1).ready := 
false.B + } + } + } + val w_q_head = RegInit(1.U(nEntries.W)) + val w_q_tail = RegInit(1.U(nEntries.W)) + when (reset.asBool) { + w_q.foreach(_.valid := false.B) + } + val wen = WireInit(false.B) + val wdata = Mux1H(w_q_head.asBools, w_q.map(_.data)) + val wmask = Mux1H(w_q_head.asBools, w_q.map(_.mask)) + val waddr = Mux1H(w_q_head.asBools, w_q.map(_.addr)) + when (wen) { + w_q_head := w_q_head << 1 | w_q_head(nEntries-1) + for (i <- 0 until nEntries) { + when (w_q_head(i)) { + w_q(i).valid := false.B + } + } + } + + when (w_buf_valid && isThisBank(waddr_buf)) { + assert(!((w_q_tail.asBools zip w_q.map(_.valid)).map({ case (h,v) => h && v }).reduce(_||_))) + w_q_tail := w_q_tail << 1 | w_q_tail(nEntries-1) + for (i <- 0 until nEntries) { + when (w_q_tail(i)) { + w_q(i).valid := true.B + w_q(i).data := Mux(acc_buf, w_sum, wdata_buf).asTypeOf(Vec(mask_len, mask_elem)) + w_q(i).mask := mask_buf + w_q(i).addr := getBankIdx(waddr_buf) + } + } + + } + val bank_rdata = mem.read(raddr, ren && !wen).asTypeOf(t) + when (RegNext(ren && reads(0).valid && isThisBank(reads(0).bits))) { + acc_rdata := bank_rdata + } .elsewhen (RegNext(ren)) { + read_rdata := bank_rdata + } + when (wen) { + mem.write(waddr, wdata, wmask) + } + // Three requestors, 1 slot + // Priority is incoming reads for RMW > writes from RMW > incoming reads + when (reads(0).valid && isThisBank(reads(0).bits)) { + ren := true.B + when (isThisBank(reads(1).bits)) { + reads(1).ready := false.B + } + } .elsewhen ((w_q_head.asBools zip w_q.map(_.valid)).map({ case (h,v) => h && v }).reduce(_||_)) { + wen := true.B + when (isThisBank(reads(1).bits)) { + reads(1).ready := false.B + } + } .otherwise { + ren := isThisBank(reads(1).bits) + raddr := getBankIdx(reads(1).bits) + } + } } - val q = Module(new Queue(new PipelinedRdataAndActT, 1, true, true)) - q.io.enq.bits.data := mem.io.rdata - q.io.enq.bits.full_data := mem.io.rdata + val q = Module(new Queue(new AccumulatorReadResp(t, scale_args.multiplicand_t, log2Ceil(t.head.head.getWidth)), 1, true, true)) + q.io.enq.bits.data := read_rdata q.io.enq.bits.scale := RegNext(io.read.req.bits.scale) q.io.enq.bits.relu6_shift := RegNext(io.read.req.bits.relu6_shift) q.io.enq.bits.act := RegNext(io.read.req.bits.act) q.io.enq.bits.fromDMA := RegNext(io.read.req.bits.fromDMA) + q.io.enq.bits.acc_bank_id := DontCare q.io.enq.valid := RegNext(io.read.req.fire()) - val p = Pipeline(q.io.deq, mem_pipeline, Seq.fill(mem_pipeline)((x: PipelinedRdataAndActT) => x) :+ { - x: PipelinedRdataAndActT => - val activated_rdata = VecInit(x.data.map(v => VecInit(v.map { e => - // val e_scaled = e >> x.shift - val e_scaled = scale_args.scale_func(e, x.scale) - val e_clipped = e_scaled.clippedToWidthOf(rdataType.head.head) - val e_act = MuxCase(e_clipped, Seq( - (x.act === Activation.RELU) -> e_clipped.relu, - (x.act === Activation.RELU6) -> e_clipped.relu6(x.relu6_shift))) - e_act - }))) + val p = q.io.deq - val result = WireInit(x) - result.data := activated_rdata + io.read.resp.bits.data := p.bits.data + io.read.resp.bits.fromDMA := p.bits.fromDMA + io.read.resp.bits.relu6_shift := p.bits.relu6_shift + io.read.resp.bits.act := p.bits.act + io.read.resp.bits.scale := p.bits.scale + io.read.resp.bits.acc_bank_id := DontCare // This is set in Scratchpad + io.read.resp.valid := p.valid + p.ready := io.read.resp.ready - result - }) val q_will_be_empty = (q.io.count +& q.io.enq.fire()) - q.io.deq.fire() === 0.U io.read.req.ready := q_will_be_empty && ( @@ -131,27 +226,15 @@ class AccumulatorMem[T <: Data, U <: Data](n: 
Int, t: Vec[Vec[T]], rdataType: Ve !(io.write.fire() && io.write.bits.acc) && // Make sure we aren't reading something that is still being written !(RegNext(io.write.fire()) && RegNext(io.write.bits.addr) === io.read.req.bits.addr) && - !(w_buf_valid && waddr_buf === io.read.req.bits.addr) - ) - io.read.resp.bits.data := p.bits.data - io.read.resp.bits.full_data := p.bits.full_data - io.read.resp.bits.fromDMA := p.bits.fromDMA - io.read.resp.valid := p.valid - p.ready := io.read.resp.ready + !(w_buf_valid && waddr_buf === io.read.req.bits.addr) && + !block_read_req + ) - if (read_small_data) - io.read.resp.bits.data := p.bits.data - else - io.read.resp.bits.data := 0.U.asTypeOf(p.bits.data) // TODO make this DontCare instead - if (read_full_data) - io.read.resp.bits.full_data := p.bits.full_data - else - io.read.resp.bits.full_data := 0.U.asTypeOf(q.io.enq.bits.full_data) // TODO make this DontCare instead // io.write.current_waddr.valid := mem.io.wen // io.write.current_waddr.bits := mem.io.waddr - io.write.ready := !io.write.bits.acc || (!(io.write.bits.addr === mem.io.waddr && mem.io.wen) && + io.write.ready := !io.write.bits.acc || (!(io.write.bits.addr === waddr_buf && w_buf_valid) && !(io.write.bits.addr === RegNext(io.write.bits.addr) && RegNext(io.write.fire()))) // assert(!(io.read.req.valid && io.write.en && io.write.acc), "reading and accumulating simultaneously is not supported") diff --git a/src/main/scala/gemmini/AccumulatorScale.scala b/src/main/scala/gemmini/AccumulatorScale.scala new file mode 100644 index 00000000..304126fe --- /dev/null +++ b/src/main/scala/gemmini/AccumulatorScale.scala @@ -0,0 +1,216 @@ +package gemmini + +import chisel3._ +import chisel3.util._ + +import Util._ + +class AccumulatorReadRespWithFullData[T <: Data: Arithmetic, U <: Data](fullDataType: Vec[Vec[T]], scale_t: U, shift_width: Int) extends Bundle { + val resp = new AccumulatorReadResp(fullDataType, scale_t, shift_width) + val full_data = fullDataType.cloneType + override def cloneType: this.type = new AccumulatorReadRespWithFullData(fullDataType.cloneType, scale_t, shift_width).asInstanceOf[this.type] +} + + +class AccumulatorScaleResp[T <: Data: Arithmetic](fullDataType: Vec[Vec[T]], rDataType: Vec[Vec[T]]) extends Bundle { + val full_data = fullDataType.cloneType + val data = rDataType.cloneType + val acc_bank_id = UInt(2.W) + val fromDMA = Bool() + override def cloneType: this.type = new AccumulatorScaleResp(fullDataType, rDataType).asInstanceOf[this.type] +} + +class AccumulatorScaleIO[T <: Data: Arithmetic, U <: Data]( + fullDataType: Vec[Vec[T]], scale_t: U, shift_width: Int, + rDataType: Vec[Vec[T]] +) extends Bundle { + val in = Flipped(Decoupled(new AccumulatorReadResp[T,U](fullDataType, scale_t, shift_width))) + val out = Decoupled(new AccumulatorScaleResp[T](fullDataType, rDataType)) + override def cloneType: this.type = new AccumulatorScaleIO(fullDataType, scale_t, + shift_width, rDataType).asInstanceOf[this.type] +} + +class AccScaleDataWithIndex[T <: Data: Arithmetic, U <: Data](t: T, u: U, scale_args: ScaleArguments[T, U]) extends Bundle { + val shift_width = log2Ceil(t.getWidth) + + val scale = u.cloneType + val act = UInt(2.W) + val relu6_shift = UInt(shift_width.W) + val data = t.cloneType + val full_data = t.cloneType + val id = UInt(2.W) // TODO hardcoded + val index = UInt() + override def cloneType: this.type = new AccScaleDataWithIndex(t, u, scale_args: ScaleArguments[T, U]).asInstanceOf[this.type] +} + +class AccScalePipe[T <: Data : Arithmetic, U <: Data](t: T, 
rDataType: Vec[Vec[T]], scale_args: ScaleArguments[T, U])(implicit ev: Arithmetic[T]) extends Module { + val u = scale_args.multiplicand_t + val io = IO(new Bundle { + val in = Input(Valid(new AccScaleDataWithIndex(t, u, scale_args)(ev))) + val out = Output(Valid(new AccScaleDataWithIndex(t, u, scale_args)(ev))) + }) + import ev._ + val latency = scale_args.latency + val out = WireInit(io.in) + + val e_scaled = scale_args.scale_func(io.in.bits.data, io.in.bits.scale) + val e_clipped = e_scaled.clippedToWidthOf(rDataType.head.head) + val e_act = MuxCase(e_clipped, Seq( + (io.in.bits.act === Activation.RELU) -> e_clipped.relu, + (io.in.bits.act === Activation.RELU6) -> e_clipped.relu6(io.in.bits.relu6_shift))) + + out.bits.data := e_act + io.out := Pipe(out, latency) +} + + +class AccumulatorScale[T <: Data: Arithmetic, U <: Data]( + fullDataType: Vec[Vec[T]], rDataType: Vec[Vec[T]], + scale_t: U, shift_width: Int, + read_small_data: Boolean, read_full_data: Boolean, + scale_args: ScaleArguments[T, U])(implicit ev: Arithmetic[T]) extends Module { + + import ev._ + val io = IO(new AccumulatorScaleIO[T,U]( + fullDataType, scale_t, shift_width, rDataType + )(ev)) + val t = io.in.bits.data(0)(0).cloneType + val out = Wire(Decoupled(new AccumulatorScaleResp[T]( + fullDataType, rDataType)(ev))) + + val num_scale_units = scale_args.num_scale_units + val acc_scale_latency = scale_args.latency + + if (num_scale_units == -1) { + val in = Wire(Decoupled(new AccumulatorReadRespWithFullData(fullDataType, scale_t, shift_width)(ev))) + in.valid := io.in.valid + io.in.ready := in.ready + in.bits.resp := io.in.bits + in.bits.full_data := io.in.bits.data + + val pipe_out = Pipeline(in, acc_scale_latency, Seq.fill(acc_scale_latency)((x: AccumulatorReadRespWithFullData[T,U]) => x) :+ { + x: AccumulatorReadRespWithFullData[T,U] => + val activated_rdata = VecInit(x.resp.data.map(v => VecInit(v.map { e => + // val e_scaled = e >> x.shiftls + val e_scaled = scale_args.scale_func(e, x.resp.scale) + val e_clipped = e_scaled.clippedToWidthOf(rDataType.head.head) + val e_act = MuxCase(e_clipped, Seq( + (x.resp.act === Activation.RELU) -> e_clipped.relu, + (x.resp.act === Activation.RELU6) -> e_clipped.relu6(x.resp.relu6_shift))) + + e_act + }))) + val result = WireInit(x) + result.resp.data := activated_rdata + result + }) + out.valid := pipe_out.valid + pipe_out.ready := out.ready + out.bits.full_data := pipe_out.bits.full_data + out.bits.data := pipe_out.bits.resp.data + out.bits.fromDMA := pipe_out.bits.resp.fromDMA + out.bits.acc_bank_id := pipe_out.bits.resp.acc_bank_id + } else { + val width = io.in.bits.data.size * io.in.bits.data(0).size + val nEntries = 3 + val regs = Reg(Vec(nEntries, Valid(new AccumulatorReadResp[T,U]( + fullDataType, scale_t, shift_width)(ev)))) + val out_regs = Reg(Vec(nEntries, new AccumulatorScaleResp[T]( + fullDataType, rDataType)(ev))) + + val fired_masks = Reg(Vec(nEntries, Vec(width, Bool()))) + val completed_masks = Reg(Vec(nEntries, Vec(width, Bool()))) + val head_oh = RegInit(1.U(nEntries.W)) + val tail_oh = RegInit(1.U(nEntries.W)) + out.valid := Mux1H(head_oh.asBools, (regs zip completed_masks).map({case (r, c) => r.valid && c.reduce(_&&_)})) + out.bits := Mux1H(head_oh.asBools, out_regs) + when (out.fire()) { + for (i <- 0 until nEntries) { + when (head_oh(i)) { + regs(i).valid := false.B + } + } + head_oh := (head_oh << 1) | head_oh(nEntries-1) + } + + io.in.ready := !Mux1H(tail_oh.asBools, regs.map(_.valid)) || (tail_oh === head_oh && out.fire()) + when (io.in.fire()) { + for 
(i <- 0 until nEntries) { + when (tail_oh(i)) { + regs(i).valid := true.B + regs(i).bits := io.in.bits + out_regs(i).fromDMA := io.in.bits.fromDMA + out_regs(i).acc_bank_id := io.in.bits.acc_bank_id + fired_masks(i).foreach(_ := false.B) + completed_masks(i).foreach(_ := false.B) + } + } + tail_oh := (tail_oh << 1) | tail_oh(nEntries-1) + } + + val inputs = Seq.fill(width*nEntries) { Wire(Decoupled(new AccScaleDataWithIndex(t, scale_t, scale_args)(ev))) } + + for (i <- 0 until nEntries) { + for (w <- 0 until width) { + val input = inputs(i*width+w) + input.valid := regs(i).valid && !fired_masks(i)(w) + input.bits.data := regs(i).bits.data(w / io.in.bits.data(0).size)(w % io.in.bits.data(0).size) + input.bits.full_data := regs(i).bits.data(w / io.in.bits.data(0).size)(w % io.in.bits.data(0).size) + input.bits.scale := regs(i).bits.scale + input.bits.act := regs(i).bits.act + input.bits.relu6_shift := regs(i).bits.relu6_shift + input.bits.id := i.U + input.bits.index := w.U + when (input.fire()) { + fired_masks(i)(w) := true.B + } + } + } + for (i <- 0 until num_scale_units) { + val arbIn = inputs.zipWithIndex.filter({ case (_, w) => w % num_scale_units == i }).map(_._1) + val arb = Module(new RRArbiter(new AccScaleDataWithIndex(t, scale_t, scale_args)(ev), arbIn.length)) + arb.io.in <> arbIn + arb.io.out.ready := true.B + val arbOut = Reg(Valid(new AccScaleDataWithIndex(t, scale_t, scale_args)(ev))) + arbOut.valid := arb.io.out.valid + arbOut.bits := arb.io.out.bits + when (reset.asBool) { + arbOut.valid := false.B + } + val pipe = Module(new AccScalePipe(t, rDataType, scale_args)(ev, ev)) + pipe.io.in := arbOut + val pipe_out = pipe.io.out + + for (j <- 0 until nEntries) { + for (w <- 0 until width) { + if ((j*width+w) % num_scale_units == i) { + val id0 = w % io.in.bits.data(0).size + val id1 = w / io.in.bits.data(0).size + when (pipe_out.fire() && pipe_out.bits.id === j.U && pipe_out.bits.index === w.U) { + out_regs(j).data (id1)(id0) := pipe_out.bits.data + out_regs(j).full_data(id1)(id0) := pipe_out.bits.full_data + completed_masks(j)(w) := true.B + } + } + } + } + } + when (reset.asBool) { + regs.foreach(_.valid := false.B) + } + } + + io.out <> out + + if (read_small_data) + io.out.bits.data := out.bits.data + else + io.out.bits.data := DontCare + + if (read_full_data) + io.out.bits.full_data := out.bits.full_data + else + io.out.bits.full_data := DontCare + +} + diff --git a/src/main/scala/gemmini/Arithmetic.scala b/src/main/scala/gemmini/Arithmetic.scala index 0fcac90e..9170b834 100644 --- a/src/main/scala/gemmini/Arithmetic.scala +++ b/src/main/scala/gemmini/Arithmetic.scala @@ -17,6 +17,7 @@ abstract class ArithmeticOps[T <: Data](self: T) { def +(t: T): T def >>(u: UInt): T // This is a rounding shift! 
Rounds away from 0 def >(t: T): Bool + def identity: T def withWidthOf(t: T): T def clippedToWidthOf(t: T): T // Like "withWidthOf", except that it saturates def relu: T @@ -62,6 +63,7 @@ object Arithmetic { } override def zero: UInt = 0.U + override def identity: UInt = 1.U } } @@ -111,6 +113,7 @@ object Arithmetic { } override def zero: SInt = 0.S + override def identity: SInt = 1.S } } @@ -271,7 +274,25 @@ object Arithmetic { */ } - override def >(t: Float): Bool = true.B // TODO + override def >(t: Float): Bool = { + // Recode all operands + val t_rec = recFNFromFN(t.expWidth, t.sigWidth, t.bits) + val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits) + + // Resize t to self's width + val t_resizer = Module(new RecFNToRecFN(t.expWidth, t.sigWidth, self.expWidth, self.sigWidth)) + t_resizer.io.in := t_rec + t_resizer.io.roundingMode := consts.round_near_even + t_resizer.io.detectTininess := consts.tininess_afterRounding + val t_rec_resized = t_resizer.io.out + + val comparator = Module(new CompareRecFN(self.expWidth, self.sigWidth)) + comparator.io.a := self_rec + comparator.io.b := t_rec_resized + comparator.io.signaling := false.B + + comparator.io.gt + } override def withWidthOf(t: Float): Float = { val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits) @@ -375,6 +396,7 @@ object Arithmetic { } override def zero: Float = 0.U.asTypeOf(self) + override def identity: Float = Cat(0.U(2.W), ~(0.U((self.expWidth-1).W)), 0.U((self.sigWidth-1).W)).asTypeOf(self) } } } diff --git a/src/main/scala/gemmini/BeatMerger.scala b/src/main/scala/gemmini/BeatMerger.scala index cac08aac..c3922f15 100644 --- a/src/main/scala/gemmini/BeatMerger.scala +++ b/src/main/scala/gemmini/BeatMerger.scala @@ -31,7 +31,6 @@ class BeatMerger[U <: Data](beatBits: Int, maxShift: Int, spadWidth: Int, accWid val io = IO(new Bundle { val req = Flipped(Decoupled(new XactTrackerEntry(maxShift, spadWidth, accWidth, spadRows, accRows, maxReqBytes, mvin_scale_t_bits, nCmds))) val in = Flipped(Decoupled(UInt(beatBits.W))) - // val in = Flipped(Decoupled(new BeatPackerIn(beatBits))) val out = Decoupled(new BeatMergerOut(spadWidth, accWidth, spadRows, accRows, alignedTo)) }) @@ -72,7 +71,7 @@ class BeatMerger[U <: Data](beatBits: Int, maxShift: Int, spadWidth: Int, accWid i.U >= spad_row_offset && i.U < spad_row_offset +& (req.bits.bytes_to_read - bytesSent) }) - io.out.bits.addr := req.bits.addr + meshRows.U * { + io.out.bits.addr := req.bits.addr + req.bits.block_stride * { val total_bytes_sent = req.bits.spad_row_offset + bytesSent Mux(req.bits.has_acc_bitwidth, // We only add "if" statements here to satisfy the Verilator linter. 
The code would be cleaner without the diff --git a/src/main/scala/gemmini/Configs.scala b/src/main/scala/gemmini/Configs.scala index 5e6b0336..c6006497 100644 --- a/src/main/scala/gemmini/Configs.scala +++ b/src/main/scala/gemmini/Configs.scala @@ -35,58 +35,42 @@ class WithMultiRoCC extends Config((site, here, up) => { // ----------------------- object GemminiConfigs { - // import Arithmetic.FloatArithmetic._ - val defaultConfig = GemminiArrayConfig[SInt, Float, Float]( - // val defaultConfig = GemminiArrayConfig[Float, Float]( + opcodes = OpcodeSet.custom3, + tileRows = 1, tileColumns = 1, - // meshRows = 4, - // meshColumns = 4, meshRows = 16, meshColumns = 16, + ld_queue_length = 8, st_queue_length = 2, ex_queue_length = 8, + rob_entries = 16, + hasIm2col = false, //declare im2col block + sp_banks = 4, + sp_singleported = true, acc_banks = 2, + acc_singleported = false, + num_acc_sub_banks = -1, sp_capacity = CapacityInKilobytes(256), shifter_banks = 1, // TODO add separate parameters for left and up shifter banks dataflow = Dataflow.BOTH, acc_capacity = CapacityInKilobytes(64), - mem_pipeline = 1, - hasIm2col = true, //declare im2col block + mem_pipeline = 4, dma_maxbytes = 64, // TODO get this from cacheblockbytes dma_buswidth = 128, // TODO get this from SystemBusKey aligned_to = 1, + tlb_size = 4, + use_tlb_register_filter = true, + max_in_flight_reqs = 16, + use_dedicated_tl_port = false, inputType = SInt(8.W), outputType = SInt(20.W), accType = SInt(32.W), - // inputType = Float(8, 24), - // outputType = Float(8, 24), - // accType = Float(8, 24), - - // mvin_scale_args = Some(MvinScaleArguments((t: SInt, u: SInt) => t * u, 0, SInt(8.W))), - // mvin_scale_acc_args = Some(MvinScaleArguments((t: SInt, u: SInt) => t * u, 0, SInt(8.W))), - // mvin_scale_args = None, - -// mvin_scale_args = Some(ScaleArguments( -// (t: SInt, s: SInt) => { -// // The equation we use can be found here: https://riscv.github.io/documents/riscv-v-spec/#_vector_fixed_point_rounding_mode_register_vxrm -// -// // TODO Do we need to explicitly handle the cases where "u" is a small number (like 0)? What is the default behavior here? -// val u = s.asUInt() -// val point_five = Mux(u === 0.U, 0.U, t(u - 1.U)) -// val zeros = Mux(u <= 1.U, 0.U, t.asUInt() & ((1.U << (u - 1.U)).asUInt() - 1.U)) =/= 0.U -// val ones_digit = t(u) -// -// val r = (point_five & (zeros | ones_digit)).asBool() -// -// Mux(s >= 0.S, ((t >> u).asSInt() + Mux(r, 1.S, 0.S)).asSInt(), (t << (0.S-s).asUInt()).asSInt()) -// }, -// 0, SInt(8.W), "0")), mvin_scale_args = Some(ScaleArguments( (t: SInt, f: Float) => { @@ -122,13 +106,11 @@ object GemminiConfigs { Mux(overflow, sat, rec_fn_to_in.io.out.asTypeOf(t)) }, - 0, Float(8, 24), + 4, Float(8, 24), 4, identity = "1.0", c_str = "({float y = ROUND_NEAR_EVEN((x) * (scale)); y > INT8_MAX ? INT8_MAX : (y < INT8_MIN ? INT8_MIN : (elem_t)y);})" )), - mvin_scale_acc_args = None, - mvin_scale_shared = false, acc_scale_args = ScaleArguments( @@ -165,20 +147,34 @@ object GemminiConfigs { Mux(overflow, sat, rec_fn_to_in.io.out.asTypeOf(t)) }, - 0, Float(8, 24), + 1, Float(8, 24), -1, // TODO pipelining should be 5 identity = "1.0", c_str = "({float y = ROUND_NEAR_EVEN((x) * (scale)); y > INT8_MAX ? INT8_MAX : (y < INT8_MIN ? 
INT8_MIN : (acc_t)y);})" ), acc_read_full_width = true, acc_read_small_width = true, - use_dedicated_tl_port = false, + pe_latency = 0, - tlb_size = 4, - use_tlb_register_filter = true, - max_in_flight_reqs = 16, + ex_read_from_spad = true, + ex_read_from_acc = true, + ex_write_to_spad = true, + ex_write_to_acc = true + ) + + val chipConfig = defaultConfig.copy(sp_capacity=CapacityInKilobytes(64), acc_capacity=CapacityInKilobytes(32), dataflow=Dataflow.WS, + acc_scale_args=defaultConfig.acc_scale_args.copy(latency=4), + acc_singleported=true, + num_acc_sub_banks=2, + ex_read_from_acc=false, + ex_write_to_spad=false + ) + val largeChipConfig = chipConfig.copy(sp_capacity=CapacityInKilobytes(128), acc_capacity=CapacityInKilobytes(64), + meshRows=32, meshColumns=32 ) + + val highPerfConfig = defaultConfig.copy(dataflow=Dataflow.WS, acc_read_full_width = false, ex_read_from_acc = false, ex_write_to_spad = false, max_in_flight_reqs = 64) } /** @@ -186,18 +182,19 @@ object GemminiConfigs { Also sets the system bus width to 128 bits (instead of the deafult 64 bits) to allow for the default 16x16 8-bit systolic array to be attached. */ -class DefaultGemminiConfig extends Config((site, here, up) => { +class DefaultGemminiConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( + gemminiConfig: GemminiArrayConfig[T,U,V] = GemminiConfigs.defaultConfig +) extends Config((site, here, up) => { case BuildRoCC => up(BuildRoCC) ++ Seq( - (p: Parameters) => { - implicit val q = p - val gemmini = LazyModule(new Gemmini(OpcodeSet.custom3, GemminiConfigs.defaultConfig)) - gemmini + (p: Parameters) => { + implicit val q = p + val gemmini = LazyModule(new Gemmini(gemminiConfig)) + gemmini } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) - /** * Mixin which configures a smaller host processor for the systolic array. This mixin **replaces** the default host rocket (assuming a single core config). 
@@ -231,7 +228,7 @@ class GemminiHostMiniCore extends Config((site, here, up) => { (up(RocketTilesKey, site).length - 1 -> Seq((p: Parameters) => { implicit val q = p - val gemmini = LazyModule(new Gemmini(OpcodeSet.custom3, GemminiConfigs.defaultConfig)) + val gemmini = LazyModule(new Gemmini(GemminiConfigs.defaultConfig)) gemmini })) }) @@ -270,7 +267,7 @@ class WithGemminiHostMiniCore extends Config((site, here, up) => { (up(RocketTilesKey, site).length -> Seq((p: Parameters) => { implicit val q = p - val gemmini = LazyModule(new Gemmini(OpcodeSet.custom3, GemminiConfigs.defaultConfig)) + val gemmini = LazyModule(new Gemmini(GemminiConfigs.defaultConfig)) gemmini })) }) @@ -316,5 +313,3 @@ class GemminiAcceleratorDeviceConfig extends Config( new WithoutTLMonitors ++ new freechips.rocketchip.system.DefaultConfig ) - - diff --git a/src/main/scala/gemmini/ConfigsFP.scala b/src/main/scala/gemmini/ConfigsFP.scala new file mode 100644 index 00000000..946915ca --- /dev/null +++ b/src/main/scala/gemmini/ConfigsFP.scala @@ -0,0 +1,167 @@ +package gemmini + +import chisel3._ +import freechips.rocketchip.config.{Config, Parameters} +import freechips.rocketchip.diplomacy.{LazyModule, ValName} +import freechips.rocketchip.subsystem._ +import freechips.rocketchip.tile.{BuildRoCC, OpcodeSet} + +// ----------------------------- +// Floating Point Config Mixins +// ----------------------------- + + +object GemminiFPConfigs { + import Arithmetic.FloatArithmetic._ + val defaultFPConfig = GemminiArrayConfig[Float, Float, Float]( + opcodes = OpcodeSet.custom3, + tileRows = 1, + tileColumns = 1, + meshRows = 4, + meshColumns = 4, + + ld_queue_length = 8, + st_queue_length = 2, + ex_queue_length = 8, + + rob_entries = 16, + + hasIm2col = false, + + sp_banks = 4, + sp_singleported = true, + acc_banks = 1, + acc_singleported = false, + num_acc_sub_banks = -1, + sp_capacity = CapacityInKilobytes(256), + shifter_banks = 1, // TODO add separate parameters for left and up shifter banks + dataflow = Dataflow.BOTH, + acc_capacity = CapacityInKilobytes(64), + mem_pipeline = 1, + + dma_maxbytes = 64, // TODO get this from cacheblockbytes + dma_buswidth = 128, // TODO get this from SystemBusKey + aligned_to = 1, + tlb_size = 4, + use_tlb_register_filter = true, + max_in_flight_reqs = 16, + use_dedicated_tl_port = false, + + inputType = Float(8, 24), + outputType = Float(8, 24), + accType = Float(8, 24), + + mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")), + mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")), + mvin_scale_shared = false, + + acc_scale_args = ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", + c_str = "((x) * (scale))" + ), + acc_read_full_width = true, + acc_read_small_width = true, + + pe_latency = 1, + + ex_read_from_spad = true, + ex_read_from_acc = true, + ex_write_to_spad = true, + ex_write_to_acc = true, + ) + + //FP32 Single Precision Configuration + val FP32DefaultConfig = defaultFPConfig.copy(inputType = Float(8, 24), outputType = Float(8, 24), accType = Float(8, 24), + pe_latency = 2, + mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")), + mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")), + ) + + //FP16 Half Precision 
Configuration + val FP16DefaultConfig = defaultFPConfig.copy(inputType = Float(5, 11), outputType = Float(5, 11), accType = Float(8, 24), + pe_latency = 2, + mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))")), + mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))")), + ) + + //Bfloat16 Brain-half Precision Configuration + val BF16DefaultConfig = defaultFPConfig.copy(inputType = Float(8, 8), outputType = Float(8, 8), accType = Float(8, 24), + pe_latency = 2, + mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")), + mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")), + ) + + //Bfloat16 Brain-half Precision Configuration 8x8 array + val BF16Default8Config = defaultFPConfig.copy(inputType = Float(8, 8), outputType = Float(8, 8), accType = Float(8, 24), + meshRows = 8, meshColumns = 8, + pe_latency = 2, + mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")), + mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")), + ) + +} + + +//===========FP32 Default Config========= +class GemminiFP32DefaultConfig extends Config((site, here, up) => { + case BuildRoCC => Seq( + (p: Parameters) => { + implicit val q = p + implicit val v = implicitly[ValName] + LazyModule(new Gemmini(GemminiFPConfigs.FP32DefaultConfig)) + } + ) + case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) +}) + + +//===========FP16 Default Config========= +class GemminiFP16DefaultConfig extends Config((site, here, up) => { + case BuildRoCC => Seq( + (p: Parameters) => { + implicit val q = p + implicit val v = implicitly[ValName] + LazyModule(new Gemmini(GemminiFPConfigs.FP16DefaultConfig)) + } + ) + case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) +}) + +//===========BFLOAT16 Default Config========= +class GemminiBF16DefaultConfig extends Config((site, here, up) => { + case BuildRoCC => Seq( + (p: Parameters) => { + implicit val q = p + implicit val v = implicitly[ValName] + LazyModule(new Gemmini(GemminiFPConfigs.BF16DefaultConfig)) + } + ) + case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) +}) + +class GemminiBF16DefaultHighPerfConfig extends Config((site, here, up) => { + case BuildRoCC => Seq( + (p: Parameters) => { + implicit val q = p + implicit val v = implicitly[ValName] + val gemmini = LazyModule(new Gemmini(GemminiFPConfigs.BF16DefaultConfig.copy( + ex_read_from_acc = false, + ex_write_to_spad = false, + ))) + gemmini + } + ) + case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) +}) + +//===========BFLOAT16 Default Config 8x8========= +class GemminiBF16Default8Config extends Config((site, here, up) => { + case BuildRoCC => Seq( + (p: Parameters) => { + implicit val q = p + implicit val v = implicitly[ValName] + LazyModule(new Gemmini(GemminiFPConfigs.BF16Default8Config)) + } + ) + case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) +}) + diff --git a/src/main/scala/gemmini/Controller.scala b/src/main/scala/gemmini/Controller.scala index d64cb3ce..d9a0f93a 100644 --- a/src/main/scala/gemmini/Controller.scala +++ b/src/main/scala/gemmini/Controller.scala @@ -9,7 +9,7 @@ import chisel3.util._ import 
freechips.rocketchip.config._ import freechips.rocketchip.diplomacy._ import freechips.rocketchip.tile._ -import freechips.rocketchip.tilelink.{TLIdentityNode} +import freechips.rocketchip.tilelink.TLIdentityNode import GemminiISA._ import Util._ @@ -20,84 +20,10 @@ class GemminiCmd(rob_entries: Int)(implicit p: Parameters) extends Bundle { override def cloneType: this.type = new GemminiCmd(rob_entries).asInstanceOf[this.type] } - -class LocalAddr(sp_banks: Int, sp_bank_entries: Int, acc_banks: Int, acc_bank_entries: Int) extends Bundle { - private val localAddrBits = 32 // TODO magic number - - private val spAddrBits = log2Ceil(sp_banks * sp_bank_entries) - private val accAddrBits = log2Ceil(acc_banks * acc_bank_entries) - private val maxAddrBits = spAddrBits max accAddrBits - - private val spBankBits = log2Up(sp_banks) - private val spBankRowBits = log2Up(sp_bank_entries) - - private val accBankBits = log2Up(acc_banks) - private val accBankRowBits = log2Up(acc_bank_entries) - - val is_acc_addr = Bool() - val accumulate = Bool() - val read_full_acc_row = Bool() - val garbage = UInt(((localAddrBits - maxAddrBits - 4) max 0).W) - val garbage_bit = if (localAddrBits - maxAddrBits >= 4) UInt(1.W) else UInt(0.W) - val data = UInt(maxAddrBits.W) - - def sp_bank(dummy: Int = 0) = if (spAddrBits == spBankRowBits) 0.U else data(spAddrBits - 1, spBankRowBits) - def sp_row(dummy: Int = 0) = data(spBankRowBits - 1, 0) - def acc_bank(dummy: Int = 0) = if (accAddrBits == accBankRowBits) 0.U else data(accAddrBits - 1, accBankRowBits) - def acc_row(dummy: Int = 0) = data(accBankRowBits - 1, 0) - - def full_sp_addr(dummy: Int = 0) = data(spAddrBits - 1, 0) - def full_acc_addr(dummy: Int = 0) = data(accAddrBits - 1, 0) - - def is_same_address(other: LocalAddr): Bool = is_acc_addr === other.is_acc_addr && data === other.data - def is_same_address(other: UInt): Bool = is_same_address(other.asTypeOf(this)) - def is_garbage(dummy: Int = 0) = is_acc_addr && accumulate && read_full_acc_row && data.andR() && - (if (garbage_bit.getWidth > 0) garbage_bit.asBool() else true.B) - - def +(other: UInt) = { - require(isPow2(sp_bank_entries)) // TODO remove this requirement - require(isPow2(acc_bank_entries)) // TODO remove this requirement - - val result = WireInit(this) - result.data := data + other - result - } - - def <=(other: LocalAddr) = - is_acc_addr === other.is_acc_addr && - Mux(is_acc_addr, full_acc_addr() <= other.full_acc_addr(), full_sp_addr() <= other.full_sp_addr()) - - def >(other: LocalAddr) = - is_acc_addr === other.is_acc_addr && - Mux(is_acc_addr, full_acc_addr() > other.full_acc_addr(), full_sp_addr() > other.full_sp_addr()) - - def add_with_overflow(other: UInt): Tuple2[LocalAddr, Bool] = { - require(isPow2(sp_bank_entries)) // TODO remove this requirement - require(isPow2(acc_bank_entries)) // TODO remove this requirement - - val sum = data +& other - - val result = WireInit(this) - result.data := sum(data.getWidth-1, 0) - - (result, sum(data.getWidth)) - } - - def make_this_garbage(dummy: Int = 0): Unit = { - is_acc_addr := true.B - accumulate := true.B - read_full_acc_row := true.B - garbage_bit := 1.U - data := ~(0.U(maxAddrBits.W)) - } - - override def cloneType: LocalAddr.this.type = new LocalAddr(sp_banks, sp_bank_entries, acc_banks, acc_bank_entries).asInstanceOf[this.type] -} - -class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](opcodes: OpcodeSet, val config: GemminiArrayConfig[T, U, V]) +class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: 
GemminiArrayConfig[T, U, V]) (implicit p: Parameters) extends LazyRoCC ( - opcodes = OpcodeSet.custom3, + opcodes = config.opcodes, nPTWPorts = 1) { Files.write(Paths.get(config.headerFilePath), config.generateHeader().getBytes(StandardCharsets.UTF_8)) @@ -177,7 +103,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] */ // Incoming commands and ROB - val rob = Module(new ROB(new RoCCCommand, rob_entries, local_addr_t, meshRows*tileRows, meshColumns*tileColumns)) + val rob = Module(new ROB(outer.config, new RoCCCommand)) val raw_cmd_q = Module(new Queue(new RoCCCommand, 2)) val fence_stall = io.cmd.bits.inst.funct === FENCE_CMD && io.busy @@ -204,11 +130,23 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] val raw_cmd = raw_cmd_q.io.deq + // TODO replace 4,12,2 with parameters based on ROB size + val (conv_cmd, loop_conv_unroller_busy) = LoopConv(raw_cmd, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization, + meshRows*tileRows, coreMaxAddrBits, rob_entries, 4, 12, 2, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries, + inputType.getWidth, accType.getWidth, dma_maxbytes) + // val (compressed_cmd, compressor_busy) = InstCompressor(unrolled_cmd) // compressed_cmd.ready := false.B - val (unrolled_cmd, loop_unroller_busy) = LoopMatmul(raw_cmd, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization, - meshRows*tileRows, coreMaxAddrBits, rob_entries, 4, 12, 2, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries, + + // val (unrolled_cmd, loop_matmul_unroller_busy) = LoopMatmul(unrolled_cmd_after_conv, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization, + val max_lds = rob_entries * 1 / 4 + val max_exs = rob_entries * 3 / 4 + val max_sts = rob_entries * 1 / 8 + val (loop_cmd, loop_matmul_unroller_busy) = LoopMatmul(conv_cmd, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization, + meshRows*tileRows, coreMaxAddrBits, rob_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries, inputType.getWidth, accType.getWidth, dma_maxbytes) + + val unrolled_cmd = Queue(loop_cmd) unrolled_cmd.ready := false.B // val cmd_decompressor = Module(new InstDecompressor(rob_entries)) @@ -307,7 +245,8 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] spad.module.io.dma.write <> store_controller.io.dma ex_controller.io.srams.read <> spad.module.io.srams.read ex_controller.io.srams.write <> spad.module.io.srams.write - ex_controller.io.acc.read <> spad.module.io.acc.read + spad.module.io.acc.read_req <> ex_controller.io.acc.read_req + ex_controller.io.acc.read_resp <> spad.module.io.acc.read_resp ex_controller.io.acc.write <> spad.module.io.acc.write // Im2Col unit @@ -382,12 +321,12 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] rob_completed_arb.io.out.ready := true.B // Wire up global RoCC signals - io.busy := (raw_cmd.valid || loop_unroller_busy || rob.io.busy || spad.module.io.busy) && !tlb.io.exp.interrupt + io.busy := (raw_cmd.valid || loop_conv_unroller_busy || loop_matmul_unroller_busy || rob.io.busy || spad.module.io.busy || unrolled_cmd.valid || loop_cmd.valid || conv_cmd.valid) && !tlb.io.exp.interrupt io.interrupt := false.B rob.io.solitary_preload := ex_controller.io.solitary_preload - assert(!io.interrupt, "Interrupt handlers have not been written yet") + // assert(!io.interrupt, "Interrupt handlers have not been written yet") // Cycle counters val ld_cycles = RegInit(0.U(34.W)) diff --git a/src/main/scala/gemmini/DMA.scala 
b/src/main/scala/gemmini/DMA.scala index d79295b3..7af1751d 100644 --- a/src/main/scala/gemmini/DMA.scala +++ b/src/main/scala/gemmini/DMA.scala @@ -7,7 +7,7 @@ import chisel3.experimental.DataMirror import freechips.rocketchip.config.Parameters import freechips.rocketchip.diplomacy.{IdRange, LazyModule, LazyModuleImp} import freechips.rocketchip.tile.{CoreBundle, HasCoreParameters} -import freechips.rocketchip.tilelink.{TLBundleA} +import freechips.rocketchip.tilelink.TLBundleA import testchipip.TLHelper import freechips.rocketchip.rocket.MStatus import freechips.rocketchip.rocket.constants.MemoryOpConstants @@ -24,6 +24,7 @@ class StreamReadRequest[U <: Data](spad_rows: Int, acc_rows: Int, mvin_scale_t_b val status = new MStatus val len = UInt(16.W) // TODO magic number val repeats = UInt(16.W) // TODO magic number + val block_stride = UInt(16.W) // TODO magic number val cmd_id = UInt(8.W) // TODO magic number override def cloneType: StreamReadRequest.this.type = new StreamReadRequest(spad_rows, acc_rows, mvin_scale_t_bits).asInstanceOf[this.type] @@ -38,7 +39,7 @@ class StreamReadResponse[U <: Data](spadWidth: Int, accWidth: Int, spad_rows: In val accumulate = Bool() val has_acc_bitwidth = Bool() val scale = UInt(mvin_scale_t_bits.W) - val rows = UInt(16.W) // TODO magic number + val repeats = UInt(16.W) // TODO magic number val last = Bool() val bytes_read = UInt(8.W) // TODO magic number val cmd_id = UInt(8.W) // TODO magic number @@ -93,7 +94,7 @@ class StreamReader[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T io.resp.bits.accumulate := beatPacker.io.out.bits.accumulate io.resp.bits.has_acc_bitwidth := beatPacker.io.out.bits.has_acc_bitwidth io.resp.bits.scale := RegEnable(xactTracker.io.peek.entry.scale, beatPacker.io.req.fire()) - io.resp.bits.rows := RegEnable(xactTracker.io.peek.entry.rows, beatPacker.io.req.fire()) + io.resp.bits.repeats := RegEnable(xactTracker.io.peek.entry.repeats, beatPacker.io.req.fire()) io.resp.bits.cmd_id := RegEnable(xactTracker.io.peek.entry.cmd_id, beatPacker.io.req.fire()) io.resp.bits.bytes_read := RegEnable(xactTracker.io.peek.entry.bytes_to_read, beatPacker.io.req.fire()) io.resp.bits.last := beatPacker.io.out.bits.last @@ -217,10 +218,9 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf io.tlb.req.bits.tlb_req.vaddr := tlb_q.io.deq.bits.vaddr io.tlb.req.bits.tlb_req.passthrough := false.B io.tlb.req.bits.tlb_req.size := 0.U // send_size - io.tlb.req.bits.tlb_req.cmd := M_XWR + io.tlb.req.bits.tlb_req.cmd := M_XRD io.tlb.req.bits.status := tlb_q.io.deq.bits.status - val translate_q = Module(new Queue(new TLBundleAWithInfo, 1, pipe=true)) translate_q.io.enq <> tlb_q.io.deq translate_q.io.deq.ready := true.B @@ -229,29 +229,30 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf retry_a.bits := translate_q.io.deq.bits assert(retry_a.ready) - tl.a.valid := translate_q.io.deq.valid && !io.tlb.resp.miss - tl.a.bits := translate_q.io.deq.bits.tl_a + tl.a.valid := translate_q.io.deq.valid && !io.tlb.resp.miss + tl.a.bits := translate_q.io.deq.bits.tl_a tl.a.bits.address := io.tlb.resp.paddr - io.reserve.valid := state === s_req_new_block && untranslated_a.ready // TODO decouple "reserve.valid" from "tl.a.ready" io.reserve.entry.shift := read_shift io.reserve.entry.is_acc := req.is_acc io.reserve.entry.accumulate := req.accumulate io.reserve.entry.has_acc_bitwidth := req.has_acc_bitwidth io.reserve.entry.scale := req.scale - io.reserve.entry.rows := req.repeats + 
io.reserve.entry.repeats := req.repeats + io.reserve.entry.block_stride := req.block_stride io.reserve.entry.lg_len_req := DontCare // TODO just remove this from the IO completely io.reserve.entry.bytes_to_read := read_bytes_read io.reserve.entry.cmd_id := req.cmd_id - io.reserve.entry.addr := req.spaddr + meshRows.U * + io.reserve.entry.addr := req.spaddr + req.block_stride * Mux(req.has_acc_bitwidth, // We only add "if" statements here to satisfy the Verilator linter. The code would be cleaner without the // "if" condition and the "else" clause if (bytesRequested.getWidth >= log2Up(accWidthBytes+1)) bytesRequested / accWidthBytes.U else 0.U, if (bytesRequested.getWidth >= log2Up(spadWidthBytes+1)) bytesRequested / spadWidthBytes.U else 0.U) io.reserve.entry.spad_row_offset := Mux(req.has_acc_bitwidth, bytesRequested % accWidthBytes.U, bytesRequested % spadWidthBytes.U) + when (untranslated_a.fire()) { val next_vaddr = req.vaddr + read_bytes_read // send_size val new_page = next_vaddr(pgIdxBits-1, 0) === 0.U @@ -286,10 +287,11 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf } } -class StreamWriteRequest(val dataWidth: Int)(implicit p: Parameters) extends CoreBundle { +class StreamWriteRequest(val dataWidth: Int, val maxBytes: Int)(implicit p: Parameters) extends CoreBundle { val vaddr = UInt(coreMaxAddrBits.W) val data = UInt(dataWidth.W) - val len = UInt(16.W) // The number of bytes to write // TODO magic number + val len = UInt(log2Up((dataWidth/8 max maxBytes)+1).W) // The number of bytes to write + val block = UInt(8.W) // TODO magic number val status = new MStatus // Pooling variables @@ -311,11 +313,13 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: val beatBytes = beatBits / 8 val lgBeatBytes = log2Ceil(beatBytes) val maxBeatsPerReq = maxBytes / beatBytes + val inputTypeRowBytes = block_cols * inputType.getWidth / 8 + val maxBlocks = maxBytes / inputTypeRowBytes require(beatBytes > 0) val io = IO(new Bundle { - val req = Flipped(Decoupled(new StreamWriteRequest(dataWidth))) + val req = Flipped(Decoupled(new StreamWriteRequest(dataWidth, maxBytes))) val tlb = new FrontendTLBIO val busy = Output(Bool()) val flush = Input(Bool()) @@ -324,9 +328,14 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: val (s_idle :: s_writing_new_block :: s_writing_beats :: Nil) = Enum(3) val state = RegInit(s_idle) - val req = Reg(new StreamWriteRequest(dataWidth)) + val req = Reg(new StreamWriteRequest(dataWidth, maxBytes)) + + // TODO use the same register to hold data_blocks and data_single_block, so that this Mux here is not necessary + val data_blocks = Reg(Vec(maxBlocks, UInt((inputTypeRowBytes * 8).W))) + val data_single_block = Reg(UInt(dataWidth.W)) // For data that's just one-block-wide + val data = Mux(req.block === 0.U, data_single_block, data_blocks.asUInt()) - val bytesSent = Reg(UInt(log2Ceil(dataBytes).W)) // TODO this only needs to count up to (dataBytes/aligned_to), right? + val bytesSent = Reg(UInt(log2Ceil((dataBytes max maxBytes)+1).W)) // TODO this only needs to count up to (dataBytes/aligned_to), right? 
val bytesLeft = req.len - bytesSent val xactBusy = RegInit(0.U(nXacts.W)) @@ -346,13 +355,15 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: // Select the size and mask of the TileLink request class Packet extends Bundle { - val size = UInt(log2Ceil(maxBytes).W) - val lg_size = UInt(log2Ceil(log2Ceil(maxBytes)).W) + val size = UInt(log2Ceil(maxBytes+1).W) + val lg_size = UInt(log2Ceil(log2Ceil(maxBytes+1)+1).W) val mask = Vec(maxBeatsPerReq, Vec(beatBytes, Bool())) val vaddr = UInt(vaddrBits.W) val is_full = Bool() - def bytes_written(dummy: Int = 0) = PopCount(mask.flatten) + val bytes_written = UInt(log2Up(dataBytes+1).W) + val bytes_written_per_beat = Vec(maxBeatsPerReq, UInt(log2Up(beatBytes+1).W)) + def total_beats(dummy: Int = 0) = Mux(size < beatBytes.U, 1.U, size / beatBytes.U) } @@ -364,16 +375,12 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: val write_packets = write_sizes.map { s => val lg_s = log2Ceil(s) val vaddr_aligned_to_size = if (s == 1) vaddr else Cat(vaddr(vaddrBits-1, lg_s), 0.U(lg_s.W)) + val vaddr_offset = if (s > 1) vaddr(lg_s - 1, 0) else 0.U - val mask = (0 until maxBytes).map { i => - if (s > 1) { - val vaddr_offset = vaddr(lg_s - 1, 0) + val mask = (0 until maxBytes).map { i => i.U >= vaddr_offset && i.U < vaddr_offset +& bytesLeft && (i < s).B } - i.U >= vaddr_offset && - i.U < vaddr_offset +& bytesLeft - } else { - true.B - } && (i < s).B + val bytes_written = { + Mux(vaddr_offset +& bytesLeft > s.U, s.U - vaddr_offset, bytesLeft) } val packet = Wire(new Packet()) @@ -383,10 +390,29 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: packet.vaddr := vaddr_aligned_to_size packet.is_full := mask.take(s).reduce(_ && _) + packet.bytes_written := bytes_written + packet.bytes_written_per_beat.zipWithIndex.foreach { case (b, i) => + val start_of_beat = i * beatBytes + val end_of_beat = (i+1) * beatBytes + + val left_shift = Mux(vaddr_offset >= start_of_beat.U && vaddr_offset < end_of_beat.U, + vaddr_offset - start_of_beat.U, + 0.U) + + val right_shift = Mux(vaddr_offset +& bytesLeft >= start_of_beat.U && vaddr_offset +& bytesLeft < end_of_beat.U, + end_of_beat.U - (vaddr_offset +& bytesLeft), + 0.U) + + val too_early = vaddr_offset >= end_of_beat.U + val too_late = vaddr_offset +& bytesLeft <= start_of_beat.U + + b := Mux(too_early || too_late, 0.U, beatBytes.U - (left_shift +& right_shift)) + } + packet } val best_write_packet = write_packets.reduce { (acc, p) => - Mux(p.bytes_written() > acc.bytes_written(), p, acc) + Mux(p.bytes_written > acc.bytes_written, p, acc) } val write_packet = RegEnableThru(best_write_packet, state === s_writing_new_block) @@ -402,21 +428,21 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: val write_mask = write_packet.mask(beatsSent) val write_shift = PriorityEncoder(write_mask) - val bytes_written_this_beat = PopCount(write_mask) + val bytes_written_this_beat = write_packet.bytes_written_per_beat(beatsSent) // Firing off TileLink write requests val putFull = edge.Put( fromSource = RegEnableThru(xactId, state === s_writing_new_block), toAddress = 0.U, lgSize = lg_write_size, - data = (req.data >> (bytesSent * 8.U)).asUInt() + data = (data >> (bytesSent * 8.U)).asUInt() )._2 val putPartial = edge.Put( fromSource = RegEnableThru(xactId, state === s_writing_new_block), toAddress = 0.U, lgSize = lg_write_size, - data = ((req.data >> (bytesSent * 8.U)) << (write_shift * 8.U)).asUInt(), + data = ((data >> 
(bytesSent * 8.U)) << (write_shift * 8.U)).asUInt(), mask = write_mask.asUInt() )._2 @@ -427,7 +453,7 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: } val untranslated_a = Wire(Decoupled(new TLBundleAWithInfo)) - xactBusy_fire := untranslated_a.fire() + xactBusy_fire := untranslated_a.fire() && state === s_writing_new_block untranslated_a.valid := (state === s_writing_new_block || state === s_writing_beats) && !xactBusy.andR() untranslated_a.bits.tl_a := Mux(write_full, putFull, putPartial) untranslated_a.bits.vaddr := write_vaddr @@ -435,14 +461,18 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: // 0 goes to retries, 1 goes to state machine val retry_a = Wire(Decoupled(new TLBundleAWithInfo)) - val tlb_arb = Module(new Arbiter(new TLBundleAWithInfo, 2)) + val shadow_retry_a = Module(new Queue(new TLBundleAWithInfo, 1)) + shadow_retry_a.io.enq.valid := false.B + shadow_retry_a.io.enq.bits := DontCare + val tlb_arb = Module(new Arbiter(new TLBundleAWithInfo, 3)) tlb_arb.io.in(0) <> retry_a - tlb_arb.io.in(1) <> untranslated_a + tlb_arb.io.in(1) <> shadow_retry_a.io.deq + tlb_arb.io.in(2) <> untranslated_a val tlb_q = Module(new Queue(new TLBundleAWithInfo, 1, pipe=true)) tlb_q.io.enq <> tlb_arb.io.out - io.tlb.req.valid := tlb_q.io.deq.valid + io.tlb.req.valid := tlb_q.io.deq.fire() io.tlb.req.bits.tlb_req.vaddr := tlb_q.io.deq.bits.vaddr io.tlb.req.bits.tlb_req.passthrough := false.B io.tlb.req.bits.tlb_req.size := 0.U // send_size @@ -451,15 +481,20 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: val translate_q = Module(new Queue(new TLBundleAWithInfo, 1, pipe=true)) translate_q.io.enq <> tlb_q.io.deq - translate_q.io.deq.ready := true.B + when (retry_a.valid) { + translate_q.io.enq.valid := false.B + shadow_retry_a.io.enq.valid := tlb_q.io.deq.valid + shadow_retry_a.io.enq.bits := tlb_q.io.deq.bits + } + translate_q.io.deq.ready := tl.a.ready || io.tlb.resp.miss - retry_a.valid := translate_q.io.deq.valid && (io.tlb.resp.miss || !tl.a.ready) + retry_a.valid := translate_q.io.deq.valid && io.tlb.resp.miss retry_a.bits := translate_q.io.deq.bits - assert(retry_a.ready) + assert(!(retry_a.valid && !retry_a.ready)) - tl.a.valid := translate_q.io.deq.valid && !io.tlb.resp.miss - tl.a.bits := translate_q.io.deq.bits.tl_a - tl.a.bits.address := io.tlb.resp.paddr + tl.a.valid := translate_q.io.deq.valid && !io.tlb.resp.miss + tl.a.bits := translate_q.io.deq.bits.tl_a + tl.a.bits.address := RegEnableThru(io.tlb.resp.paddr, RegNext(io.tlb.req.fire())) tl.d.ready := xactBusy.orR() @@ -467,7 +502,7 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: when (state === s_writing_new_block) { beatsLeft := write_beats - 1.U - val next_vaddr = req.vaddr + bytes_written_this_beat + val next_vaddr = req.vaddr + write_packet.bytes_written req.vaddr := next_vaddr bytesSent := bytesSent + bytes_written_this_beat @@ -485,9 +520,9 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: beatsLeft := beatsLeft - 1.U bytesSent := bytesSent + bytes_written_this_beat - when (beatsLeft === 0.U) { - val new_page = req.vaddr(pgIdxBits-1, 0) === 0.U + assert(beatsLeft > 0.U) + when (beatsLeft === 1.U) { when (bytes_written_this_beat >= bytesLeft) { // We're done with this request at this point state_machine_ready_for_req := true.B @@ -504,17 +539,23 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: val pooled = { val cols = 
dataWidth / inputType.getWidth val v1 = io.req.bits.data.asTypeOf(Vec(cols, inputType)) - val v2 = req.data.asTypeOf(Vec(cols, inputType)) + val v2 = data_single_block.asTypeOf(Vec(cols, inputType)) val m = v1.zip(v2) VecInit(m.zipWithIndex.map{case ((x, y), i) => if (i < block_cols) maxOf(x, y) else y}).asUInt() } req := io.req.bits - req.data := Mux(io.req.bits.pool_en, pooled, io.req.bits.data) + req.len := io.req.bits.block * inputTypeRowBytes.U + io.req.bits.len + + data_single_block := Mux(io.req.bits.pool_en, pooled, io.req.bits.data) + data_blocks(io.req.bits.block) := io.req.bits.data bytesSent := 0.U state := Mux(io.req.bits.store_en, s_writing_new_block, s_idle) + + assert(io.req.bits.len <= (block_cols * inputType.getWidth / 8).U || io.req.bits.block === 0.U, "DMA can't write multiple blocks to main memory when writing full accumulator output") + assert(!io.req.bits.pool_en || io.req.bits.block === 0.U, "Can't pool with block-mvout") } } } diff --git a/src/main/scala/gemmini/DMAReadCommandTracker.scala b/src/main/scala/gemmini/DMACommandTracker.scala similarity index 86% rename from src/main/scala/gemmini/DMAReadCommandTracker.scala rename to src/main/scala/gemmini/DMACommandTracker.scala index 386bf52e..2632f753 100644 --- a/src/main/scala/gemmini/DMAReadCommandTracker.scala +++ b/src/main/scala/gemmini/DMACommandTracker.scala @@ -6,7 +6,7 @@ import chisel3.util._ // This module is meant to go inside the Load controller, where it can track which commands are currently // in flight and which are completed -class DMAReadCommandTracker[T <: Data](val nCmds: Int, val maxBytes: Int, tag_t: => T) extends Module { +class DMACommandTracker[T <: Data](val nCmds: Int, val maxBytes: Int, tag_t: => T) extends Module { def cmd_id_t = UInt((log2Ceil(nCmds) max 1).W) val io = IO(new Bundle { @@ -24,12 +24,6 @@ class DMAReadCommandTracker[T <: Data](val nCmds: Int, val maxBytes: Int, tag_t: override def cloneType: this.type = new BitsT(tag_t.cloneType, cmd_id_t.cloneType).asInstanceOf[this.type] } - /*val bits = new Bundle { - val tag = Input(tag_t) - val bytes_to_read = Input(UInt(log2Up(maxBytes+1).W)) - val cmd_id = Output(cmd_id_t) - }*/ - val bits = new BitsT(tag_t.cloneType, cmd_id_t.cloneType) def fire(dummy: Int = 0) = valid && ready @@ -43,11 +37,6 @@ class DMAReadCommandTracker[T <: Data](val nCmds: Int, val maxBytes: Int, tag_t: override def cloneType: this.type = new RequestReturnedT(cmd_id_t.cloneType).asInstanceOf[this.type] } - /*val request_returned = Flipped(Valid(new Bundle { - val bytes_read = UInt(log2Up(maxBytes+1).W) - val cmd_id = cmd_id_t - }))*/ - val request_returned = Flipped(Valid(new RequestReturnedT(cmd_id_t.cloneType))) class CmdCompletedT(cmd_id_t: UInt, tag_t: T) extends Bundle { @@ -57,11 +46,6 @@ class DMAReadCommandTracker[T <: Data](val nCmds: Int, val maxBytes: Int, tag_t: override def cloneType: this.type = new CmdCompletedT(cmd_id_t.cloneType, tag_t.cloneType).asInstanceOf[this.type] } - /*val cmd_completed = Decoupled(new Bundle { - val cmd_id = cmd_id_t - val tag = tag_t - })*/ - val cmd_completed = Decoupled(new CmdCompletedT(cmd_id_t.cloneType, tag_t.cloneType)) val busy = Output(Bool()) diff --git a/src/main/scala/gemmini/DMAWriteCommandTracker.scala b/src/main/scala/gemmini/DMAWriteCommandTracker.scala deleted file mode 100644 index 30765a44..00000000 --- a/src/main/scala/gemmini/DMAWriteCommandTracker.scala +++ /dev/null @@ -1,9 +0,0 @@ -package gemmini - -import chisel3._ -import chisel3.util._ - -object DMAWriteCommandTracker { - def apply[T <: 
Data](nCmds: Int, nRows: Int, tag_t: => T) = Module(new DMAReadCommandTracker(nCmds = nCmds, - maxBytes = nRows, tag_t = tag_t)) -} diff --git a/src/main/scala/gemmini/DSEConfigs.scala b/src/main/scala/gemmini/DSEConfigs.scala index 51bdc192..da3b1795 100644 --- a/src/main/scala/gemmini/DSEConfigs.scala +++ b/src/main/scala/gemmini/DSEConfigs.scala @@ -13,6 +13,7 @@ import freechips.rocketchip.tile.{BuildRoCC, OpcodeSet} object DSEBaseConfig { val baseConfig = GemminiArrayConfig[SInt, Bool, UInt]( + opcodes = OpcodeSet.custom3, tileRows = 1, tileColumns = 1, meshRows = 16, @@ -23,7 +24,10 @@ object DSEBaseConfig { rob_entries = 8, sp_banks = 4, // TODO support one-bank designs acc_banks = 1, + acc_singleported = false, + num_acc_sub_banks = -1, sp_capacity = CapacityInKilobytes(64), + sp_singleported = false, shifter_banks = 1, // TODO add separate parameters for left and up shifter banks dataflow = Dataflow.OS, acc_capacity = CapacityInKilobytes(16), @@ -50,12 +54,17 @@ object DSEBaseConfig { val r = (point_five & (zeros | ones_digit)).asBool() (t >> u).asSInt() + Mux(r, 1.S, 0.S) - }, 0, UInt(8.W)), + }, 0, UInt(8.W), -1), acc_read_full_width = true, acc_read_small_width = true, use_dedicated_tl_port = false, pe_latency = 0, + ex_read_from_spad = true, + ex_read_from_acc = true, + ex_write_to_spad = true, + ex_write_to_acc = true, + tlb_size = 4, use_tlb_register_filter = true, max_in_flight_reqs = 16, @@ -90,7 +99,7 @@ class GemminiParamsDSE1 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.baseConfig)) + LazyModule(new Gemmini(DSEConfigs.baseConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -102,7 +111,7 @@ class GemminiParamsDSE2 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.wsOnlyConfig)) + LazyModule(new Gemmini(DSEConfigs.wsOnlyConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -114,7 +123,7 @@ class GemminiParamsDSE3 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.bothDataflowsConfig)) + LazyModule(new Gemmini(DSEConfigs.bothDataflowsConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -126,7 +135,7 @@ class GemminiParamsDSE4 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.highBitwidthConfig)) + LazyModule(new Gemmini(DSEConfigs.highBitwidthConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -138,7 +147,7 @@ class GemminiParamsDSE5 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.largerDimConfig)) + LazyModule(new Gemmini(DSEConfigs.largerDimConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -150,7 +159,7 @@ class GemminiParamsDSE6 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.fullyCombinationalConfig)) + LazyModule(new Gemmini(DSEConfigs.fullyCombinationalConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -162,7 +171,7 @@ 
class GemminiParamsDSE7 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.moreMemoryConfig)) + LazyModule(new Gemmini(DSEConfigs.moreMemoryConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -174,7 +183,7 @@ class GemminiParamsDSE8 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.moreBanksConfig)) + LazyModule(new Gemmini(DSEConfigs.moreBanksConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -186,7 +195,7 @@ class GemminiParamsDSE10 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.narrowerBusConfig)) + LazyModule(new Gemmini(DSEConfigs.narrowerBusConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 8) @@ -198,7 +207,7 @@ class GemminiParamsPnR16 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.pnr16Config)) + LazyModule(new Gemmini(DSEConfigs.pnr16Config)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -210,7 +219,7 @@ class GemminiParamsPnR32 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.pnr32Config)) + LazyModule(new Gemmini(DSEConfigs.pnr32Config)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -222,7 +231,7 @@ class GemminiParamsDSE11 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.baseConfig)) + LazyModule(new Gemmini(DSEConfigs.baseConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) diff --git a/src/main/scala/gemmini/ExecuteController.scala b/src/main/scala/gemmini/ExecuteController.scala index c3601d50..300dab16 100644 --- a/src/main/scala/gemmini/ExecuteController.scala +++ b/src/main/scala/gemmini/ExecuteController.scala @@ -27,7 +27,15 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In } val acc = new Bundle { - val read = Vec(acc_banks, new AccumulatorReadIO(acc_bank_entries, log2Up(accType.getWidth), Vec(meshColumns, Vec(tileColumns, inputType)), Vec(meshColumns, Vec(tileColumns, accType)), acc_scale_args.multiplicand_t)) + val read_req = Vec(acc_banks, Decoupled(new AccumulatorReadReq( + acc_bank_entries, log2Up(accType.getWidth), acc_scale_args.multiplicand_t + ))) + + val read_resp = Flipped(Vec(acc_banks, Decoupled(new AccumulatorScaleResp( + Vec(meshColumns, Vec(tileColumns, inputType)), + Vec(meshColumns, Vec(tileColumns, accType)) + )))) + // val write = Vec(acc_banks, new AccumulatorWriteIO(acc_bank_entries, Vec(meshColumns, Vec(tileColumns, accType)))) val write = Vec(acc_banks, Decoupled(new AccumulatorWriteReq(acc_bank_entries, Vec(meshColumns, Vec(tileColumns, accType))))) } @@ -131,7 +139,6 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In //val row_turn_counter = RegInit(row_turn) im2col_en := Mux(weight_stride === 0.U, false.B, true.B) - // SRAM addresses of matmul operands val a_address_rs1 = rs1s(a_address_place).asTypeOf(local_addr_t) val b_address_rs2 = 
rs2s(b_address_place).asTypeOf(local_addr_t) @@ -211,6 +218,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In !is_garbage && (mul_raw_haz || pre_raw_haz) }.reduce(_ || _) + val raw_hazards_are_impossible = !ex_read_from_acc && !ex_write_to_spad // Special case where RAW hazards are impossible + val matmul_in_progress = mesh.io.tags_in_progress.map(_.rob_id.valid).reduce(_ || _) io.busy := cmd.valid(0) || matmul_in_progress @@ -242,9 +251,9 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val dataBBankAcc = b_address.acc_bank() val dataDBankAcc = d_address.acc_bank() - val a_read_from_acc = a_address_rs1.is_acc_addr - val b_read_from_acc = b_address_rs2.is_acc_addr - val d_read_from_acc = d_address_rs1.is_acc_addr + val a_read_from_acc = ex_read_from_acc.B && a_address_rs1.is_acc_addr + val b_read_from_acc = ex_read_from_acc.B && b_address_rs2.is_acc_addr + val d_read_from_acc = ex_read_from_acc.B && d_address_rs1.is_acc_addr val start_inputting_a = WireInit(false.B) val start_inputting_b = WireInit(false.B) @@ -322,9 +331,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In !must_wait_for.reduce(_ || _) } - val a_fire = a_valid && a_ready - dontTouch(a_fire) val b_fire = b_valid && b_ready val d_fire = d_valid && d_ready @@ -353,7 +360,6 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In d_fire_started := true.B } - when(performing_mul_pre && !cntl_ready && !mul_pre_counter_lock){ mul_pre_counter_count := d_fire_counter //store 2 }.elsewhen(!performing_mul_pre){ @@ -371,9 +377,9 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In // The last line in this (long) Boolean is just to make sure that we don't think we're done as soon as we begin firing // TODO change when square requirement lifted - val about_to_fire_all_rows = ((a_fire_counter === (block_size-1).U && a_valid) || a_fire_counter === 0.U) && - ((b_fire_counter === (block_size-1).U && b_valid) || b_fire_counter === 0.U) && - ((d_fire_counter === (block_size-1).U && d_valid) || d_fire_counter === 0.U) && + val about_to_fire_all_rows = ((a_fire_counter === (block_size-1).U && a_fire) || a_fire_counter === 0.U) && + ((b_fire_counter === (block_size-1).U && b_fire) || b_fire_counter === 0.U) && + ((d_fire_counter === (block_size-1).U && d_fire) || d_fire_counter === 0.U) && (a_fire_counter =/= 0.U || b_fire_counter =/= 0.U || d_fire_counter =/= 0.U) && cntl_ready @@ -403,19 +409,26 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In } } - io.srams.read(i).req.valid := read_a || read_b || read_d - io.srams.read(i).req.bits.fromDMA := false.B - io.srams.read(i).req.bits.addr := MuxCase(a_address_rs1.sp_row() + a_fire_counter, - Seq(read_b -> (b_address_rs2.sp_row() + b_fire_counter), - read_d -> (d_address_rs1.sp_row() + block_size.U - 1.U - d_fire_counter_mulpre))) - - when(im2col_en === false.B){ - io.srams.read(i).req.bits.addr := MuxCase(a_address.sp_row(), - Seq(read_b -> b_address.sp_row(), - read_d -> d_address.sp_row())) + if (ex_read_from_spad) { + io.srams.read(i).req.valid := (read_a || read_b || read_d) && cntl_ready + io.srams.read(i).req.bits.fromDMA := false.B + io.srams.read(i).req.bits.addr := MuxCase(a_address_rs1.sp_row() + a_fire_counter, + Seq(read_b -> (b_address_rs2.sp_row() + b_fire_counter), + read_d -> (d_address_rs1.sp_row() + block_size.U - 1.U - d_fire_counter_mulpre))) + + // TODO this just overrides the 
previous line. Should we erase the previous line? + when(im2col_en === false.B) { + io.srams.read(i).req.bits.addr := MuxCase(a_address.sp_row(), + Seq(read_b -> b_address.sp_row(), + read_d -> d_address.sp_row())) + } + } else { + io.srams.read(i).req.valid := false.B + io.srams.read(i).req.bits.fromDMA := false.B + io.srams.read(i).req.bits.addr := DontCare } - io.srams.read(i).resp.ready := true.B + io.srams.read(i).resp.ready := false.B } // Accumulator read @@ -425,45 +438,39 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val read_d_from_acc = d_valid && d_read_from_acc && dataDBankAcc === i.U && start_inputting_d && !preload_zeros && d_row_is_not_all_zeros //&& !im2col_wire Seq((read_a_from_acc, a_ready), (read_b_from_acc, b_ready), (read_d_from_acc, d_ready)).foreach { case (rd, r) => - when(rd && !io.acc.read(i).req.ready) { + when(rd && !io.acc.read_req(i).ready) { r := false.B } } - /* - io.acc.read(i).req.valid := read_a_from_acc || read_b_from_acc || read_d_from_acc - io.acc.read(i).req.bits.scale := acc_scale - io.acc.read(i).req.bits.full := false.B - io.acc.read(i).req.bits.relu6_shift := relu6_shift - io.acc.read(i).req.bits.act := activation - io.acc.read(i).req.bits.fromDMA := false.B - io.acc.read(i).req.bits.addr := MuxCase(a_address_rs1.acc_row() + a_fire_counter, - Seq(read_b_from_acc -> (b_address_rs2.acc_row() + b_fire_counter), - read_d_from_acc -> (d_address_rs1.acc_row() + block_size.U - 1.U - d_fire_counter))) - - when(im2col_en === false.B){ - io.acc.read(i).req.bits.addr := MuxCase(a_address.acc_row(), - Seq(read_b_from_acc -> b_address.acc_row(), - read_d_from_acc -> d_address.acc_row())) - } - */ - - // TODO Remove the ability to read into Mesh from AccumulatorMem completely - io.acc.read(i).req.valid := false.B - io.acc.read(i).req.bits.scale := acc_scale - io.acc.read(i).req.bits.full := false.B - io.acc.read(i).req.bits.relu6_shift := relu6_shift - io.acc.read(i).req.bits.act := activation - io.acc.read(i).req.bits.fromDMA := false.B - io.acc.read(i).req.bits.addr := DontCare - - when(im2col_en === false.B){ - io.acc.read(i).req.bits.addr := MuxCase(a_address.acc_row(), - Seq(read_b_from_acc -> b_address.acc_row(), - read_d_from_acc -> d_address.acc_row())) + if (ex_read_from_acc) { + io.acc.read_req(i).valid := read_a_from_acc || read_b_from_acc || read_d_from_acc + io.acc.read_req(i).bits.scale := acc_scale + io.acc.read_req(i).bits.full := false.B + io.acc.read_req(i).bits.relu6_shift := relu6_shift + io.acc.read_req(i).bits.act := activation + io.acc.read_req(i).bits.fromDMA := false.B + io.acc.read_req(i).bits.addr := MuxCase(a_address_rs1.acc_row() + a_fire_counter, + Seq(read_b_from_acc -> (b_address_rs2.acc_row() + b_fire_counter), + read_d_from_acc -> (d_address_rs1.acc_row() + block_size.U - 1.U - d_fire_counter))) + + // TODO this just overrides the previous line. Should we erase the previous line? 
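+      // Context for the TODO above: Chisel uses last-connect semantics, so the
+      // assignment inside this when-block silently wins whenever im2col_en is
+      // false, and the MuxCase above is only visible while im2col is enabled.
+      // A minimal sketch of the same pattern (hypothetical signals):
+      //   io.out := a                    // default connection
+      //   when (cond) { io.out := b }    // overrides the default when cond holds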
+ when(im2col_en === false.B){ + io.acc.read_req(i).bits.addr := MuxCase(a_address.acc_row(), + Seq(read_b_from_acc -> b_address.acc_row(), + read_d_from_acc -> d_address.acc_row())) + } + } else { + io.acc.read_req(i).valid := false.B + io.acc.read_req(i).bits.scale := acc_scale + io.acc.read_req(i).bits.full := false.B + io.acc.read_req(i).bits.relu6_shift := relu6_shift + io.acc.read_req(i).bits.act := activation + io.acc.read_req(i).bits.fromDMA := false.B + io.acc.read_req(i).bits.addr := DontCare } - io.acc.read(i).resp.ready := true.B + io.acc.read_resp(i).ready := false.B } // Im2Col reads @@ -495,7 +502,6 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In io.im2col.resp.ready := mesh.io.a.ready } - // FSM logic switch (control_state) { is(waiting_for_cmd) { @@ -514,7 +520,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In in_shift := rs2s(0)(31, 0) // TODO magic number acc_scale := rs1s(0)(xLen - 1, 32).asTypeOf(acc_scale_args.multiplicand_t) // TODO magic number relu6_shift := rs2s(0)(xLen - 1, 32) // TODO magic number - a_addr_stride := rs1s(0)(31, 16) // TODO magic number + a_addr_stride := rs1s(0)(31, 16) // TODO magic number // TODO this needs to be kept in sync with ROB.scala a_transpose := rs1s(0)(8) bd_transpose := rs1s(0)(9) @@ -541,7 +547,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In } // Preload - .elsewhen(DoPreloads(0) && cmd.valid(1) && !raw_hazard_pre) { + .elsewhen(DoPreloads(0) && cmd.valid(1) && (raw_hazards_are_impossible.B || !raw_hazard_pre)) { perform_single_preload := true.B performing_single_preload := true.B @@ -556,7 +562,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In } // Overlap compute and preload - .elsewhen(DoComputes(0) && cmd.valid(1) && DoPreloads(1) && cmd.valid(2) && !raw_hazard_mulpre) { + .elsewhen(DoComputes(0) && cmd.valid(1) && DoPreloads(1) && (raw_hazards_are_impossible.B || (cmd.valid(2) && !raw_hazard_mulpre))) { perform_mul_pre := true.B performing_mul_pre := true.B @@ -751,11 +757,11 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In mesh_cntl_signals_q.io.enq.bits.im2colling := im2col_wire && im2col_en //im2col_wire val readData = VecInit(io.srams.read.map(_.resp.bits.data)) - val accReadData = readData // VecInit(io.acc.read.map(_.resp.bits.data.asUInt())) // TODO remove ability to read from AccumulatorMem + val accReadData = if (ex_read_from_acc) VecInit(io.acc.read_resp.map(_.bits.data.asUInt())) else readData val im2ColData = io.im2col.resp.bits.a_im2col.asUInt() - val readValid = VecInit(io.srams.read.map(bank => bank.resp.valid && !bank.resp.bits.fromDMA)) - val accReadValid = false.B // VecInit(io.acc.read.map(bank => bank.resp.valid && !bank.resp.bits.fromDMA)) // TODO remove ability to read from AccumulatorMem + val readValid = VecInit(io.srams.read.map(bank => ex_read_from_spad.B && bank.resp.valid && !bank.resp.bits.fromDMA)) + val accReadValid = VecInit(io.acc.read_resp.map(bank => ex_read_from_acc.B && bank.valid && !bank.bits.fromDMA)) val im2ColValid = io.im2col.resp.valid mesh_cntl_signals_q.io.deq.ready := (!cntl.a_fire || mesh.io.a.fire() || !mesh.io.a.ready) && @@ -786,6 +792,37 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val dataB = VecInit(dataB_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.b_unpadded_cols, d, inputType.zero)}) val dataD = 
VecInit(dataD_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.d_unpadded_cols, d, inputType.zero)}) + // Pop responses off the scratchpad io ports + when (mesh_cntl_signals_q.io.deq.fire()) { + when (cntl.a_fire && mesh.io.a.fire() && !cntl.a_garbage && cntl.a_unpadded_cols > 0.U && !cntl.im2colling) { + when (cntl.a_read_from_acc) { + io.acc.read_resp(cntl.a_bank_acc).ready := !io.acc.read_resp(cntl.a_bank_acc).bits.fromDMA + }.otherwise { + io.srams.read(cntl.a_bank).resp.ready := !io.srams.read(cntl.a_bank).resp.bits.fromDMA + } + } + + when (cntl.b_fire && mesh.io.b.fire() && !cntl.b_garbage && !cntl.accumulate_zeros && cntl.b_unpadded_cols > 0.U) { + when (cntl.b_read_from_acc) { + io.acc.read_resp(cntl.b_bank_acc).ready := !io.acc.read_resp(cntl.b_bank_acc).bits.fromDMA + }.otherwise { + io.srams.read(cntl.b_bank).resp.ready := !io.srams.read(cntl.b_bank).resp.bits.fromDMA + } + } + + when (cntl.d_fire && mesh.io.d.fire() && !cntl.d_garbage && !cntl.preload_zeros && cntl.d_unpadded_cols > 0.U) { + when (cntl.d_read_from_acc) { + io.acc.read_resp(cntl.d_bank_acc).ready := !io.acc.read_resp(cntl.d_bank_acc).bits.fromDMA + }.otherwise { + io.srams.read(cntl.d_bank).resp.ready := !io.srams.read(cntl.d_bank).resp.bits.fromDMA + } + } + } + + for (acc_r <- io.acc.read_resp) { + acc_r.ready := true.B + } + when (cntl_valid) { // Default inputs mesh.io.a.valid := cntl.a_fire && dataA_valid @@ -804,14 +841,11 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In } when (cntl_valid && cntl.perform_single_preload) { - // mesh.io.a.bits := Mux(cntl.dataflow === Dataflow.WS.id.U, 0.U, dataA.asUInt).asTypeOf(Vec(meshRows, Vec(tileRows, inputType))) mesh.io.a.bits := Mux(a_should_be_fed_into_transposer, dataA.asUInt, 0.U).asTypeOf(Vec(meshRows, Vec(tileRows, inputType))) - // mesh.io.b.bits := 0.U.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType))) mesh.io.b.bits := Mux(b_should_be_fed_into_transposer, dataB.asUInt, 0.U).asTypeOf(Vec(meshRows, Vec(tileRows, inputType))) } when (cntl_valid && cntl.perform_single_mul) { - // mesh.io.a.bits := Mux(cntl.dataflow === Dataflow.OS.id.U, 0.U, dataA.asUInt).asTypeOf(Vec(meshRows, Vec(tileRows, inputType))) mesh.io.a.bits := Mux(a_should_be_fed_into_transposer, 0.U, dataA.asUInt).asTypeOf(Vec(meshRows, Vec(tileRows, inputType))) mesh.io.b.bits := Mux(b_should_be_fed_into_transposer, 0.U, dataB.asUInt).asTypeOf(Vec(meshRows, Vec(tileRows, inputType))) mesh.io.tag_in.bits.addr.make_this_garbage() @@ -847,20 +881,34 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In e_act }))) - io.srams.write(i).en := start_array_outputting && w_bank === i.U && !write_to_acc && !is_garbage_addr && write_this_row - io.srams.write(i).addr := w_row - io.srams.write(i).data := activated_wdata.asUInt() - // io.srams.write(i).mask := VecInit(Seq.fill(io.srams.write(0).mask.length)(true.B)) - io.srams.write(i).mask := w_mask.flatMap(b => Seq.fill(inputType.getWidth / (aligned_to * 8))(b)) + if (ex_write_to_spad) { + io.srams.write(i).en := start_array_outputting && w_bank === i.U && !write_to_acc && !is_garbage_addr && write_this_row + io.srams.write(i).addr := w_row + io.srams.write(i).data := activated_wdata.asUInt() + io.srams.write(i).mask := w_mask.flatMap(b => Seq.fill(inputType.getWidth / (aligned_to * 8))(b)) + } else { + io.srams.write(i).en := false.B + io.srams.write(i).addr := DontCare + io.srams.write(i).data := DontCare + io.srams.write(i).mask := DontCare + 
} } // Write to accumulator for (i <- 0 until acc_banks) { - io.acc.write(i).valid := start_array_outputting && w_bank === i.U && write_to_acc && !is_garbage_addr && write_this_row - io.acc.write(i).bits.addr := w_row - io.acc.write(i).bits.data := VecInit(mesh.io.out.bits.map(v => VecInit(v.map(e => e.withWidthOf(accType))))) - io.acc.write(i).bits.acc := w_address.accumulate - io.acc.write(i).bits.mask := w_mask.flatMap(b => Seq.fill(accType.getWidth / (aligned_to * 8))(b)) + if (ex_write_to_acc) { + io.acc.write(i).valid := start_array_outputting && w_bank === i.U && write_to_acc && !is_garbage_addr && write_this_row + io.acc.write(i).bits.addr := w_row + io.acc.write(i).bits.data := VecInit(mesh.io.out.bits.map(v => VecInit(v.map(e => e.withWidthOf(accType))))) + io.acc.write(i).bits.acc := w_address.accumulate + io.acc.write(i).bits.mask := w_mask.flatMap(b => Seq.fill(accType.getWidth / (aligned_to * 8))(b)) + } else { + io.acc.write(i).valid := false.B + io.acc.write(i).bits.addr := DontCare + io.acc.write(i).bits.data := DontCare + io.acc.write(i).bits.acc := DontCare + io.acc.write(i).bits.mask := DontCare + } assert(!(io.acc.write(i).valid && !io.acc.write(i).ready), "Execute controller write to AccumulatorMem was skipped") } diff --git a/src/main/scala/gemmini/FrontendTLB.scala b/src/main/scala/gemmini/FrontendTLB.scala index ae8711a0..819d9f57 100644 --- a/src/main/scala/gemmini/FrontendTLB.scala +++ b/src/main/scala/gemmini/FrontendTLB.scala @@ -102,7 +102,6 @@ class FrontendTLB(nClients: Int, entries: Int, maxSize: Int) val l0_tlb_hit = last_translated_valid && ((client.req.bits.tlb_req.vaddr >> pgIdxBits) === (last_translated_vpn >> pgIdxBits)) val l0_tlb_paddr = Cat(last_translated_ppn >> pgIdxBits, client.req.bits.tlb_req.vaddr(pgIdxBits-1,0)) - when (req.fire() && !tlb.io.resp.miss) { last_translated_valid := true.B last_translated_vpn := req.bits.tlb_req.vaddr diff --git a/src/main/scala/gemmini/GemminiConfigs.scala b/src/main/scala/gemmini/GemminiConfigs.scala index fb9026d4..b8df083b 100644 --- a/src/main/scala/gemmini/GemminiConfigs.scala +++ b/src/main/scala/gemmini/GemminiConfigs.scala @@ -4,15 +4,18 @@ package gemmini import scala.math.{pow,sqrt} import chisel3._ import chisel3.util._ +import freechips.rocketchip.tile._ sealed abstract trait GemminiMemCapacity case class CapacityInKilobytes(kilobytes: Int) extends GemminiMemCapacity case class CapacityInMatrices(matrices: Int) extends GemminiMemCapacity case class ScaleArguments[T <: Data, U <: Data](scale_func: (T, U) => T, latency: Int, multiplicand_t: U, + num_scale_units: Int, identity: String="0", c_str: String="ROUNDING_RIGHT_SHIFT(x, scale)") case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( + opcodes: OpcodeSet, tileRows: Int, tileColumns: Int, meshRows: Int, @@ -22,8 +25,11 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( ex_queue_length: Int, rob_entries: Int, sp_banks: Int, // TODO support one-bank designs + sp_singleported: Boolean, sp_capacity: GemminiMemCapacity, acc_banks: Int, + acc_singleported: Boolean, + num_acc_sub_banks: Int, acc_capacity: GemminiMemCapacity, shifter_banks: Int, dataflow: Dataflow.Value, @@ -50,6 +56,11 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( use_tlb_register_filter: Boolean, max_in_flight_reqs: Int, + ex_read_from_spad: Boolean, + ex_read_from_acc: Boolean, + ex_write_to_spad: Boolean, + ex_write_to_acc: Boolean, + headerFileName: String = "gemmini_params.h" ) { val sp_width = 
meshColumns * tileColumns * inputType.getWidth @@ -61,16 +72,17 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( case CapacityInKilobytes(kb) => kb * 1024 * 8 / (acc_banks * meshColumns * tileColumns * accType.getWidth) case CapacityInMatrices(ms) => ms * meshRows * tileRows / acc_banks } + require (!acc_singleported || (num_acc_sub_banks <= 4 && isPow2(num_acc_sub_banks))) val local_addr_t = new LocalAddr(sp_banks, sp_bank_entries, acc_banks, acc_bank_entries) val mvin_scale_t = mvin_scale_args match { - case Some(ScaleArguments(_, _, t, _, _)) => t + case Some(ScaleArguments(_, _, t, _, _, _)) => t case None => Bool() // TODO replace this with UInt(0.W) } val mvin_scale_acc_t = mvin_scale_acc_args match { - case Some(ScaleArguments(_, _, t, _, _)) => t + case Some(ScaleArguments(_, _, t, _, _, _)) => t case None => Bool() // TODO replace this with UInt(0.W) } @@ -81,13 +93,14 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( val acc_scale_t_bits = acc_scale_t.getWidth - // val max_in_flight_reqs = 16 // TODO calculate this somehow - - val mvin_len_bits = log2Up(((dma_maxbytes / (inputType.getWidth / 8)) max (meshColumns * tileColumns)) + 1) - val mvin_rows_bits = 16 // log2Up(meshRows * tileRows + 1) - val mvout_len_bits = log2Up(meshColumns * tileColumns + 1) + val mvin_cols_bits = log2Up(((dma_maxbytes / (inputType.getWidth / 8)) max (meshColumns * tileColumns)) + 1) + val mvin_rows_bits = log2Up(meshRows * tileRows + 1) + val mvout_cols_bits = log2Up(((dma_maxbytes / (inputType.getWidth / 8)) max (meshColumns * tileColumns)) + 1) val mvout_rows_bits = log2Up(meshRows * tileRows + 1) + val load_states = 3 + val block_stride_bits = 16 + //========================================================================== // sanity check mesh size //========================================================================== @@ -110,13 +123,15 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( // cisc-gemmini hardware-specific compile-time global constants //========================================================================== + val cisc_dim = (meshRows * tileRows) / 2 + val ITYPE_BITS = inputType.getWidth - val ITYPE_BYTES = (inputType.getWidth+7) / 8 + val ITYPE_BYTES = (inputType.getWidth+cisc_dim-1) / cisc_dim val LOG2_ITYPE_BYTES = if(ITYPE_BYTES <= 1) 0 else log2Up(ITYPE_BYTES) val OTYPE_BITS = accType.getWidth val LOG2_OTYPE_BITS = log2Up(OTYPE_BITS) - val OTYPE_BYTES = (accType.getWidth+7) / 8 + val OTYPE_BYTES = (accType.getWidth+cisc_dim-1) / cisc_dim val LOG2_OTYPE_BYTES = if(OTYPE_BYTES <= 1) 0 else log2Up(OTYPE_BYTES) val SP_BANKS = sp_banks @@ -133,12 +148,12 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( val LOG2_MNK_BYTES = log2Up(MNK_BYTES) val MNK_BYTES_PER_TILE_ROW = MNK_BYTES * DIM val LOG2_MNK_BYTES_PER_TILE_ROW = log2Up(MNK_BYTES_PER_TILE_ROW) - val TILE_IDX = MNK_BYTES / (DIM / 8) + val TILE_IDX = MNK_BYTES / (DIM / cisc_dim) val LOG2_TILE_IDX = log2Up(TILE_IDX) //-------------------------------------------------------------------------- - val I_TILE_BYTE_WIDTH = DIM * ((inputType.getWidth+7) / 8) - val O_TILE_BYTE_WIDTH = DIM * ((accType.getWidth+7) / 8) + val I_TILE_BYTE_WIDTH = DIM * ((inputType.getWidth+cisc_dim-1) / cisc_dim) + val O_TILE_BYTE_WIDTH = DIM * ((accType.getWidth+cisc_dim-1) / cisc_dim) val I_TILE_BYTE_WIDTH_LOG2 = log2Up(I_TILE_BYTE_WIDTH) val O_TILE_BYTE_WIDTH_LOG2 = log2Up(O_TILE_BYTE_WIDTH) require(pow(2,I_TILE_BYTE_WIDTH_LOG2) 
== I_TILE_BYTE_WIDTH, @@ -187,7 +202,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( (dt.expWidth, dt.sigWidth) match { case (8, 24) => (scala.Float.MinValue.toString, scala.Float.MaxValue.toString) case (11, 53) => (scala.Double.MinValue.toString, scala.Double.MaxValue.toString) - case _ => throw new IllegalArgumentException(s"Only single- and double-precision IEEE754 floating point types are currently supported") + case _ => (((Range(-1,-(dt.sigWidth),-1).map(-Math.pow(2, _)).foldLeft(-1.0)(_ + _)) * Math.pow(2, Math.pow(2, dt.expWidth - 1) - 1)).toString, ((Range(-1,-(dt.sigWidth),-1).map(Math.pow(2, _)).foldLeft(1.0)(_ + _)) * Math.pow(2, Math.pow(2, dt.expWidth - 1) - 1)).toString) } case _ => throw new IllegalArgumentException(s"Data type $dataType is unknown") } @@ -201,7 +216,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( (dt.expWidth, dt.sigWidth) match { case (8, 24) => "float" case (11, 53) => "double" - case _ => throw new IllegalArgumentException(s"Only single- and double-precision IEEE754 floating point types are currently supported") + case _ => s"uint" + (Math.pow(2, Math.ceil(Math.log(dt.expWidth + dt.sigWidth)/Math.log(2.0)))).toInt.toString + s"_t" } case _ => throw new IllegalArgumentException(s"Data type $dataType is unknown") } @@ -221,7 +236,6 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( // assert(Set(8, 16, 32, 64).contains(outputType.getWidth)) assert(Set(8, 16, 32, 64).contains(accType.getWidth)) - assert(acc_scale_args.latency == 0, "Accumulator's scale latency must be 0 cycles") val header = new StringBuilder() header ++= s"#ifndef $guard\n" @@ -230,6 +244,13 @@ header ++= s"#include <stdint.h>\n" header ++= s"#include <limits.h>\n\n" + val opcodeid = Seq( + OpcodeSet.custom0, OpcodeSet.custom1, OpcodeSet.custom2, OpcodeSet.custom3 + ).indexWhere(o => o.opcodes(0).litValue == opcodes.opcodes(0).litValue) + println(opcodeid, opcodes.opcodes) + require (opcodeid != -1 && opcodes.opcodes.size == 1) + header ++= s"#define XCUSTOM_ACC $opcodeid\n" + header ++= s"#define DIM ${tileColumns*meshColumns}\n" header ++= s"#define ADDR_LEN 32\n" header ++= s"#define BANK_NUM $sp_banks\n" @@ -254,8 +275,15 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( // Datatype of the systolic array val limits = limitsOfDataType(inputType) header ++= s"typedef ${c_type(inputType)} elem_t;\n" - header ++= s"static const elem_t elem_t_max = ${limits._2};\n" - header ++= s"static const elem_t elem_t_min = ${limits._1};\n" + if (inputType.isInstanceOf[Float] && !((inputType.asInstanceOf[Float].expWidth, inputType.asInstanceOf[Float].sigWidth) == (8, 24) || (inputType.asInstanceOf[Float].expWidth, inputType.asInstanceOf[Float].sigWidth) == (11, 53))) + { + header ++= "#define ELEM_T_IS_LOWPREC_FLOAT\n" + header ++= s"static const float elem_t_max = ${limits._2};\n" + header ++= s"static const float elem_t_min = ${limits._1};\n" + } else { + header ++= s"static const elem_t elem_t_max = ${limits._2};\n" + header ++= s"static const elem_t elem_t_min = ${limits._1};\n" + } header ++= s"typedef ${c_type(accType)} acc_t;\n" header ++= s"typedef ${full_c_type(inputType)} full_t;\n\n" @@ -296,7 +324,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( header ++= s"#define row_align_acc(blocks) __attribute__((aligned(blocks*DIM*sizeof(acc_t))))\n\n" val mvin_scale_identity =
mvin_scale_args match { - case Some(ScaleArguments(_, _, _, identity, _)) => identity + case Some(ScaleArguments(_, _, _, _, identity, _)) => identity case None => "0" } header ++= s"#define MVIN_SCALE_IDENTITY $mvin_scale_identity\n\n" @@ -333,6 +361,13 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( |""".stripMargin header ++= "\n" + header ++= """// Rounding right shift equation: https://riscv.github.io/documents/riscv-v-spec/#_vector_fixed_point_rounding_mode_register_vxrm +#define ROUNDING_RIGHT_SHIFT_BITS(x, shift) \ +((shift) > 0 ? (((x) >> (shift)) + \ + (((shift) == 0 ? 0 : (((x) >> ((shift)-1)) & 1)) & \ + ((((shift) <= 1 ? 0 : ((x) & ((1 << ((shift)-1)) - 1))) != 0) | (((x) >> (shift)) & 1)))) : ((x) << (-(shift))))""" + header ++= "\n\n" + header ++= """#define ACC_SCALE(x, scale) \ """ header ++= s" ${acc_scale_args.c_str}" @@ -364,7 +399,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( header ++= s"#define ACC_READ_FULL_WIDTH\n" header ++= s"\n" - header ++= s"#endif // $guard" + header ++= s"#endif // $guard\n" header.toString() } diff --git a/src/main/scala/gemmini/GemminiISA.scala b/src/main/scala/gemmini/GemminiISA.scala index b799f8ce..2bf48dad 100644 --- a/src/main/scala/gemmini/GemminiISA.scala +++ b/src/main/scala/gemmini/GemminiISA.scala @@ -22,6 +22,15 @@ object GemminiISA { val LOAD3_CMD = 14.U + // TODO add orows and ocols to this as well + val LOOP_CONV_WS = 15.U // no_bias, no_pool + val LOOP_CONV_WS_CONFIG_1 = 16.U // batch_size, in_dim, in_channels, out_channels | out_dim, pool_out_dim, stride, padding + val LOOP_CONV_WS_CONFIG_2 = 17.U // kernel_dim, pool_size, pool_stride, pool_padding | batches, porows, pocols, pochs + val LOOP_CONV_WS_CONFIG_3 = 18.U // krows, kcols, kchs, lpad | rpad, upad, dpad, plpad + val LOOP_CONV_WS_CONFIG_4 = 19.U // prad, pupad, pdpad, orows | ocols + val LOOP_CONV_WS_CONFIG_5 = 20.U // *weights | *output + val LOOP_CONV_WS_CONFIG_6 = 21.U // *bias, *input + val FENCE_CMD = 127.U // rs1[2:0] values diff --git a/src/main/scala/gemmini/Im2Col.scala b/src/main/scala/gemmini/Im2Col.scala index 52039d4d..f264ad32 100644 --- a/src/main/scala/gemmini/Im2Col.scala +++ b/src/main/scala/gemmini/Im2Col.scala @@ -415,7 +415,7 @@ class Im2Col[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V when(i.U < channel){ im2col_data(i) := sram_req_output(i) }.otherwise{ - im2col_data(i) := 0.S //when channel < 16, pad with 0 + im2col_data(i) := 0.U.asTypeOf(inputType) //when channel < 16, pad with 0 } } } @@ -446,5 +446,6 @@ class Im2Col[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V io.resp.valid := false.B io.req.ready := true.B io.sram_reads.foreach(_.req.valid := false.B) + io.sram_reads.foreach(_.resp.ready := false.B) } } diff --git a/src/main/scala/gemmini/LoadController.scala b/src/main/scala/gemmini/LoadController.scala index cf5f0c57..d9221f5b 100644 --- a/src/main/scala/gemmini/LoadController.scala +++ b/src/main/scala/gemmini/LoadController.scala @@ -6,6 +6,7 @@ import GemminiISA._ import Util._ import freechips.rocketchip.config.Parameters +// TODO we need to check for WAW errors here // TODO deal with errors when reading scratchpad responses class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V], coreMaxAddrBits: Int, local_addr_t: LocalAddr) (implicit p: Parameters) extends Module { @@ -24,9 +25,10 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig val 
waiting_for_command :: waiting_for_dma_req_ready :: sending_rows :: Nil = Enum(3) val control_state = RegInit(waiting_for_command) - val strides = Reg(Vec(3, UInt(coreMaxAddrBits.W))) - val scales = Reg(Vec(3, UInt(mvin_scale_t_bits.W))) - val shrinks = Reg(Vec(3, Bool())) // Shrink inputs to accumulator + val strides = Reg(Vec(load_states, UInt(coreMaxAddrBits.W))) + val scales = Reg(Vec(load_states, UInt(mvin_scale_t_bits.W))) + val shrinks = Reg(Vec(load_states, Bool())) // Shrink inputs to accumulator + val block_strides = Reg(Vec(load_states, UInt(block_stride_bits.W))) // Spad stride during block move-ins val block_rows = meshRows * tileRows val block_cols = meshColumns * tileColumns val row_counter = RegInit(0.U(log2Ceil(block_rows).W)) @@ -34,22 +36,26 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig val cmd = Queue(io.cmd, ld_queue_length) val vaddr = cmd.bits.cmd.rs1 val localaddr = cmd.bits.cmd.rs2.asTypeOf(local_addr_t) - val cols = cmd.bits.cmd.rs2(32 + mvin_len_bits - 1, 32) // TODO magic numbers + val cols = cmd.bits.cmd.rs2(32 + mvin_cols_bits - 1, 32) // TODO magic numbers val rows = cmd.bits.cmd.rs2(48 + mvin_rows_bits - 1, 48) // TODO magic numbers val config_stride = cmd.bits.cmd.rs2 val config_scale = cmd.bits.cmd.rs1(32 + mvin_scale_t_bits - 1, 32) // TODO magic numbers - val config_shrink = cmd.bits.cmd.rs1(2) + val config_shrink = cmd.bits.cmd.rs1(2) // TODO magic numbers + val config_block_stride = cmd.bits.cmd.rs1(31, 16) // TODO magic numbers val mstatus = cmd.bits.cmd.status val load_state_id = MuxCase(0.U, Seq((cmd.bits.cmd.inst.funct === LOAD2_CMD) -> 1.U, (cmd.bits.cmd.inst.funct === LOAD3_CMD) -> 2.U)) - val config_state_id = cmd.bits.cmd.rs1(4,3) + val config_state_id = cmd.bits.cmd.rs1(4,3) // TODO magic numbers val state_id = Mux(cmd.bits.cmd.inst.funct === CONFIG_CMD, config_state_id, load_state_id) val stride = strides(state_id) val scale = scales(state_id) val shrink = shrinks(state_id) + val block_stride = block_strides(state_id) + + val all_zeros = vaddr === 0.U val localaddr_plus_row_counter = localaddr + row_counter @@ -71,7 +77,7 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig (block_cols * config.accType.getWidth / 8) val maxBytesInMatRequest = block_rows * maxBytesInRowRequest - val cmd_tracker = Module(new DMAReadCommandTracker(nCmds, maxBytesInMatRequest, deps_t)) + val cmd_tracker = Module(new DMACommandTracker(nCmds, maxBytesInMatRequest, deps_t)) io.busy := cmd.valid || cmd_tracker.io.busy @@ -81,10 +87,12 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig (control_state === sending_rows && row_counter =/= 0.U) io.dma.req.bits.vaddr := vaddr + row_counter * stride io.dma.req.bits.laddr := localaddr_plus_row_counter - io.dma.req.bits.len := cols - io.dma.req.bits.repeats := Mux(stride === 0.U, rows - 1.U, 0.U) + io.dma.req.bits.cols := cols + io.dma.req.bits.repeats := Mux(stride === 0.U && !all_zeros, rows - 1.U, 0.U) + io.dma.req.bits.block_stride := block_stride io.dma.req.bits.scale := scale io.dma.req.bits.has_acc_bitwidth := localaddr_plus_row_counter.is_acc_addr && !shrink + io.dma.req.bits.all_zeros := all_zeros io.dma.req.bits.status := mstatus // Command tracker IO @@ -109,6 +117,8 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig // Row counter when (io.dma.req.fire()) { row_counter := wrappingAdd(row_counter, 1.U, actual_rows_read) + + assert(block_stride >= rows) } // Control logic @@ -120,6 
+130,7 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig stride := config_stride scale := config_scale shrink := config_shrink + block_stride := config_block_stride cmd.ready := true.B } diff --git a/src/main/scala/gemmini/LocalAddr.scala b/src/main/scala/gemmini/LocalAddr.scala new file mode 100644 index 00000000..6520b7f9 --- /dev/null +++ b/src/main/scala/gemmini/LocalAddr.scala @@ -0,0 +1,83 @@ +package gemmini + +import chisel3._ +import chisel3.util._ + +class LocalAddr(sp_banks: Int, sp_bank_entries: Int, acc_banks: Int, acc_bank_entries: Int) extends Bundle { + private val localAddrBits = 32 // TODO magic number + + private val spAddrBits = log2Ceil(sp_banks * sp_bank_entries) + private val accAddrBits = log2Ceil(acc_banks * acc_bank_entries) + private val maxAddrBits = spAddrBits max accAddrBits + + private val spBankBits = log2Up(sp_banks) + private val spBankRowBits = log2Up(sp_bank_entries) + + private val accBankBits = log2Up(acc_banks) + private val accBankRowBits = log2Up(acc_bank_entries) + + val is_acc_addr = Bool() + val accumulate = Bool() + val read_full_acc_row = Bool() + val garbage = UInt(((localAddrBits - maxAddrBits - 4) max 0).W) + val garbage_bit = if (localAddrBits - maxAddrBits >= 4) UInt(1.W) else UInt(0.W) + val data = UInt(maxAddrBits.W) + + def sp_bank(dummy: Int = 0) = if (spAddrBits == spBankRowBits) 0.U else data(spAddrBits - 1, spBankRowBits) + def sp_row(dummy: Int = 0) = data(spBankRowBits - 1, 0) + def acc_bank(dummy: Int = 0) = if (accAddrBits == accBankRowBits) 0.U else data(accAddrBits - 1, accBankRowBits) + def acc_row(dummy: Int = 0) = data(accBankRowBits - 1, 0) + + def full_sp_addr(dummy: Int = 0) = data(spAddrBits - 1, 0) + def full_acc_addr(dummy: Int = 0) = data(accAddrBits - 1, 0) + + def is_same_address(other: LocalAddr): Bool = is_acc_addr === other.is_acc_addr && data === other.data + def is_same_address(other: UInt): Bool = is_same_address(other.asTypeOf(this)) + def is_garbage(dummy: Int = 0) = is_acc_addr && accumulate && read_full_acc_row && data.andR() && + (if (garbage_bit.getWidth > 0) garbage_bit.asBool() else true.B) + + def +(other: UInt) = { + require(isPow2(sp_bank_entries)) // TODO remove this requirement + require(isPow2(acc_bank_entries)) // TODO remove this requirement + + val result = WireInit(this) + result.data := data + other + result + } + + def <=(other: LocalAddr) = + is_acc_addr === other.is_acc_addr && + Mux(is_acc_addr, full_acc_addr() <= other.full_acc_addr(), full_sp_addr() <= other.full_sp_addr()) + + def <(other: LocalAddr) = + is_acc_addr === other.is_acc_addr && + Mux(is_acc_addr, full_acc_addr() < other.full_acc_addr(), full_sp_addr() < other.full_sp_addr()) + + def >(other: LocalAddr) = + is_acc_addr === other.is_acc_addr && + Mux(is_acc_addr, full_acc_addr() > other.full_acc_addr(), full_sp_addr() > other.full_sp_addr()) + + def add_with_overflow(other: UInt): Tuple2[LocalAddr, Bool] = { + require(isPow2(sp_bank_entries)) // TODO remove this requirement + require(isPow2(acc_bank_entries)) // TODO remove this requirement + + val sum = data +& other + + val overflow = Mux(is_acc_addr, sum(accAddrBits), sum(spAddrBits)) + + val result = WireInit(this) + result.data := sum(maxAddrBits - 1, 0) + + (result, overflow) + } + + def make_this_garbage(dummy: Int = 0): Unit = { + is_acc_addr := true.B + accumulate := true.B + read_full_acc_row := true.B + garbage_bit := 1.U + data := ~(0.U(maxAddrBits.W)) + } + + override def cloneType: LocalAddr.this.type = new 
LocalAddr(sp_banks, sp_bank_entries, acc_banks, acc_bank_entries).asInstanceOf[this.type] +} diff --git a/src/main/scala/gemmini/LoopConv.scala b/src/main/scala/gemmini/LoopConv.scala new file mode 100644 index 00000000..83f34fcf --- /dev/null +++ b/src/main/scala/gemmini/LoopConv.scala @@ -0,0 +1,1025 @@ +package gemmini + +import chisel3._ +import chisel3.util._ +import chisel3.experimental._ +import freechips.rocketchip.tile.RoCCCommand +import freechips.rocketchip.config.Parameters +import GemminiISA._ +import Util._ + +class LoopConvOuterBounds(val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int) extends Bundle { + val batch_size = UInt(large_iterator_bitwidth.W) + val in_dim = UInt(small_iterator_bitwidth.W) + val in_channels = UInt(large_iterator_bitwidth.W) + val out_channels = UInt(large_iterator_bitwidth.W) + val out_dim = UInt(small_iterator_bitwidth.W) + val pool_out_dim = UInt(small_iterator_bitwidth.W) + val stride = UInt(tiny_iterator_bitwidth.W) + val padding = UInt(tiny_iterator_bitwidth.W) + val kernel_dim = UInt(tiny_iterator_bitwidth.W) + val pool_size = UInt(tiny_iterator_bitwidth.W) + val pool_stride = UInt(tiny_iterator_bitwidth.W) + val pool_padding = UInt(tiny_iterator_bitwidth.W) +} + +class LoopConvInnerBounds(val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int) extends Bundle { + val batches = UInt(large_iterator_bitwidth.W) + val porows = UInt(small_iterator_bitwidth.W) + val pocols = UInt(small_iterator_bitwidth.W) + val pochs = UInt(large_iterator_bitwidth.W) + val krows = UInt(tiny_iterator_bitwidth.W) + val kcols = UInt(tiny_iterator_bitwidth.W) + val kchs = UInt(large_iterator_bitwidth.W) + val lpad = UInt(tiny_iterator_bitwidth.W) + val rpad = UInt(tiny_iterator_bitwidth.W) + val upad = UInt(tiny_iterator_bitwidth.W) + val dpad = UInt(tiny_iterator_bitwidth.W) + val plpad = UInt(tiny_iterator_bitwidth.W) + val prad = UInt(tiny_iterator_bitwidth.W) + val pupad = UInt(tiny_iterator_bitwidth.W) + val pdpad = UInt(tiny_iterator_bitwidth.W) + val orows = UInt(small_iterator_bitwidth.W) + val ocols = UInt(small_iterator_bitwidth.W) +} + +class LoopConvDerivedParams(val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int) extends Bundle { + val ochs = UInt(large_iterator_bitwidth.W) + + val irows = UInt(small_iterator_bitwidth.W) + val icols = UInt(small_iterator_bitwidth.W) + val irows_unpadded = UInt(small_iterator_bitwidth.W) + val icols_unpadded = UInt(small_iterator_bitwidth.W) + val ichs = UInt(large_iterator_bitwidth.W) + + val out_channels_per_bank = UInt(small_iterator_bitwidth.W) // TODO this won't work for systolic arrays above 256 in size + + val bias_spad_stride = UInt(large_iterator_bitwidth.W) + val input_spad_stride = UInt(large_iterator_bitwidth.W) + val weight_spad_stride = UInt(large_iterator_bitwidth.W) + + val ex_overwrite = Bool() +} + +class LoopConvLdBiasReq(val coreMaxAddrBits: Int, val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int, val max_acc_addr: Int, val concurrent_loops: Int) extends Bundle { + val outer_bounds = new LoopConvOuterBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val inner_bounds = new LoopConvInnerBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val derived_params = new LoopConvDerivedParams(large_iterator_bitwidth, small_iterator_bitwidth, 
tiny_iterator_bitwidth) + val addr_start = UInt(log2Up(max_acc_addr).W) + val dram_addr = UInt(coreMaxAddrBits.W) + val no_bias = Bool() + val loop_id = UInt(log2Up(concurrent_loops).W) +} + +class LoopConvLdBias(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: Int, small_iterator_bitwidth: Int, tiny_iterator_bitwidth: Int, max_acc_addr: Int, acc_w: Int, + max_block_len_acc: Int, concurrent_loops: Int)(implicit p: Parameters) extends Module { + val MVIN_SCALE_IDENTITY = 0x3f800000.U // TODO get this from configs somehow + + val io = IO(new Bundle { + val req = Flipped(Decoupled(new LoopConvLdBiasReq(coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth: Int, max_acc_addr, concurrent_loops))) + val cmd = Decoupled(Output(new RoCCCommand)) + + val idle = Output(Bool()) + val rob_overloaded = Input(Bool()) + val wait_for_prev_loop = Input(Bool()) + + val loop_id = Output(UInt(log2Up(concurrent_loops).W)) + }) + + object State extends ChiselEnum { + val idle, config, ld = Value + } + import State._ + val state = RegInit(idle) + + val req = Reg(new LoopConvLdBiasReq(coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth: Int, max_acc_addr, concurrent_loops)) + import req.inner_bounds._ + import req.derived_params._ + + val acc_addr_start = (BigInt(1) << 31).U | req.addr_start + + // Derived parameters + val max_ochs_per_mvin = Mux(ochs < (max_block_len_acc * block_size).U, ochs, (max_block_len_acc * block_size).U) + + val skip = req.no_bias || (req.dram_addr === 0.U) + + // Iterators + val b = Reg(UInt(large_iterator_bitwidth.W)) + val orow = Reg(UInt(small_iterator_bitwidth.W)) + val ocol = Reg(UInt(small_iterator_bitwidth.W)) + val och = Reg(UInt(large_iterator_bitwidth.W)) + + // Addresses + val dram_addr = req.dram_addr +& och * (acc_w/8).U + val spad_addr = acc_addr_start +& (och / block_size.U) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol + + // Sizes + val I = Mux(ocols - ocol > block_size.U, block_size.U, ocols - ocol) + val J = Mux(ochs - och > max_ochs_per_mvin, max_ochs_per_mvin, ochs - och) + + // Commands + val config_cmd = Wire(new RoCCCommand) + config_cmd := DontCare + config_cmd.inst.funct := CONFIG_CMD + config_cmd.rs1 := (MVIN_SCALE_IDENTITY << 32.U) | (req.derived_params.bias_spad_stride << 16.U) | (2.U << 3) | 1.U + config_cmd.rs2 := 0.U + + val mvin_cmd = Wire(new RoCCCommand) + mvin_cmd := DontCare + mvin_cmd.inst.funct := LOAD3_CMD + mvin_cmd.rs1 := dram_addr + mvin_cmd.rs2 := (I << 48.U) | (J << 32.U) | spad_addr + + // Inputs and outputs + io.req.ready := state === idle + io.idle := state === idle + io.loop_id := req.loop_id + + io.cmd.valid := state =/= idle && !io.rob_overloaded && !io.wait_for_prev_loop && !skip + io.cmd.bits := Mux(state === config, config_cmd, mvin_cmd) + + // Sending outputs + when (skip) { + state := idle + }.elsewhen(io.cmd.fire()) { + when (state === config) { + state := ld + }.otherwise { + val next_och = floorAdd(och, max_ochs_per_mvin, ochs) + val next_ocol = floorAdd(ocol, block_size.U, ocols, next_och === 0.U) + val next_orow = floorAdd(orow, 1.U, orows, next_ocol === 0.U && next_och === 0.U) + val next_b = floorAdd(b, 1.U, batches, next_orow === 0.U && next_ocol === 0.U && next_och === 0.U) + + och := next_och + ocol := next_ocol + orow := next_orow + b := next_b + + state := Mux(next_b === 0.U && next_orow === 0.U && next_ocol === 0.U && next_och === 0.U, + idle, ld) + } + } + + // Accepting requests + when (io.req.fire()) { + 
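+    // Latch the request and clear all four loop iterators; entering the config
+    // state means a fresh CONFIG_CMD is issued for every accepted request
+    // before any bias mvin is sent (unless the skip condition bails out first).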
req := io.req.bits + state := config + b := 0.U + orow := 0.U + ocol := 0.U + och := 0.U + } +} + +class LoopConvLdInputReq(val coreMaxAddrBits: Int, val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int, val max_acc_addr: Int, val concurrent_loops: Int) extends Bundle { + val outer_bounds = new LoopConvOuterBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val inner_bounds = new LoopConvInnerBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val derived_params = new LoopConvDerivedParams(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val addr_start = UInt(log2Up(max_acc_addr).W) + val dram_addr = UInt(coreMaxAddrBits.W) + val loop_id = UInt(log2Up(concurrent_loops).W) +} + +class LoopConvLdInput(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: Int, small_iterator_bitwidth: Int, tiny_iterator_bitwidth: Int, max_addr: Int, input_w: Int, + max_block_len: Int, concurrent_loops: Int)(implicit p: Parameters) extends Module { + val MVIN_SCALE_IDENTITY = 0x3f800000.U // TODO get this from configs somehow + + val io = IO(new Bundle { + val req = Flipped(Decoupled(new LoopConvLdInputReq(coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_addr, concurrent_loops))) + val cmd = Decoupled(Output(new RoCCCommand)) + + val idle = Output(Bool()) + val rob_overloaded = Input(Bool()) + val wait_for_prev_loop = Input(Bool()) + + val loop_id = Output(UInt(log2Up(concurrent_loops).W)) + }) + + object State extends ChiselEnum { + val idle, config, ld = Value + } + import State._ + val state = RegInit(idle) + + val req = Reg(new LoopConvLdInputReq(coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_addr, concurrent_loops)) + import req.outer_bounds._ + import req.inner_bounds._ + import req.derived_params._ + + // Derived parameters + val max_ichs_per_mvin = Mux(ichs < (max_block_len * block_size).U, ichs, (max_block_len * block_size).U).zext() + + // Iterators + val b = Reg(SInt(large_iterator_bitwidth.W)) + val irow = Reg(SInt(small_iterator_bitwidth.W)) + val icol = Reg(SInt(small_iterator_bitwidth.W)) + val ich = Reg(SInt(large_iterator_bitwidth.W)) + + // Calculated params + val irow_padded = irow +& upad.zext() + val icol_padded = icol +& lpad.zext() + val is_zeros = irow < 0.S || irow >= irows_unpadded.zext() || icol < 0.S || icol >= icols_unpadded.zext() + + // Addresses + val dram_addr = Mux(is_zeros, 0.U, + req.dram_addr +& (((b * in_dim * in_dim +& irow*in_dim +& icol) * in_channels +& ich) * (input_w/8).U).asUInt()) + val spad_addr = req.addr_start.zext() +& (ich / block_size.S) * batches * irows * icols +& b * irows * icols +& irow_padded * icols +& icol_padded + + // Sizes + val I = MuxCase( + Mux(icols_unpadded.zext() -& icol > block_size.S, block_size.S, icols_unpadded.zext() -& icol), + Seq( + (icol < 0.S) -> Mux((0.S-&icol) > block_size.S, block_size.S, 0.S-&icol), + (icol >= icols_unpadded.zext()) -> Mux(icols_unpadded.zext() +& rpad.zext() -& icol > block_size.S, block_size.S, icols_unpadded.zext() +& rpad.zext() -& icol) + ) + ) + val K = Mux(ichs.zext() -& ich > max_ichs_per_mvin, max_ichs_per_mvin, ichs.zext() -& ich) + + // Commands + val config_cmd = Wire(new RoCCCommand) + config_cmd := DontCare + config_cmd.inst.funct := CONFIG_CMD + config_cmd.rs1 := (MVIN_SCALE_IDENTITY << 32.U) | (req.derived_params.input_spad_stride << 16.U) | (0.U << 3) | 1.U + 
config_cmd.rs2 := in_channels * (input_w/8).U + + val mvin_cmd = Wire(new RoCCCommand) + mvin_cmd := DontCare + mvin_cmd.inst.funct := LOAD_CMD + mvin_cmd.rs1 := dram_addr + mvin_cmd.rs2 := (I << 48.U).asUInt() | (K << 32.U).asUInt() | spad_addr.asUInt() + + // Inputs and outputs + io.req.ready := state === idle + io.idle := state === idle + io.loop_id := req.loop_id + + io.cmd.valid := state =/= idle && !io.wait_for_prev_loop && !io.rob_overloaded + io.cmd.bits := Mux(state === config, config_cmd, mvin_cmd) + + // Sending outputs + when(io.cmd.fire()) { + when (state === config) { + state := ld + }.otherwise { + val next_ich = sFloorAdd(ich, max_ichs_per_mvin.asUInt(), ichs.zext(), 0.S) + val next_icol = sFloorAdd(icol, I.asUInt(), (icols_unpadded +& rpad).zext(), 0.S-&lpad.zext(), + next_ich === 0.S) + val next_irow = sFloorAdd(irow, 1.U, (irows_unpadded +& dpad).zext(), 0.S-&upad.zext(), + next_icol === 0.S-&lpad.zext() && next_ich === 0.S) + val next_b = sFloorAdd(b, 1.U, batches.zext(), 0.S, + next_irow === 0.S-&upad.zext() && next_icol === 0.S-&lpad.zext() && next_ich === 0.S) + + ich := next_ich + icol := next_icol + irow := next_irow + b := next_b + + state := Mux(next_b === 0.S && next_irow === 0.S-&upad.zext() && next_icol === 0.S-&lpad.zext() && next_ich === 0.S, + idle, ld) + } + } + + // Accepting requests + when (io.req.fire()) { + req := io.req.bits + state := config + b := 0.S + irow := 0.S -& io.req.bits.inner_bounds.upad.zext() + icol := 0.S -& io.req.bits.inner_bounds.lpad.zext() + ich := 0.S + } +} + +class LoopConvLdWeightReq(val coreMaxAddrBits: Int, val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int, val max_addr: Int, val concurrent_loops: Int) extends Bundle { + val outer_bounds = new LoopConvOuterBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val inner_bounds = new LoopConvInnerBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val derived_params = new LoopConvDerivedParams(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val addr_end = UInt(log2Up(max_addr).W) + val dram_addr = UInt(coreMaxAddrBits.W) + val loop_id = UInt(log2Up(concurrent_loops).W) +} + +class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: Int, small_iterator_bitwidth: Int, tiny_iterator_bitwidth: Int, max_addr: Int, input_w: Int, + max_block_len: Int, concurrent_loops: Int)(implicit p: Parameters) extends Module { + val MVIN_SCALE_IDENTITY = 0x3f800000.U // TODO get this from configs somehow + + val io = IO(new Bundle { + val req = Flipped(Decoupled(new LoopConvLdWeightReq(coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_addr, concurrent_loops))) + val cmd = Decoupled(Output(new RoCCCommand)) + + val idle = Output(Bool()) + val rob_overloaded = Input(Bool()) + val wait_for_prev_loop = Input(Bool()) + + val loop_id = Output(UInt(log2Up(concurrent_loops).W)) + }) + + object State extends ChiselEnum { + val idle, config, ld = Value + } + import State._ + val state = RegInit(idle) + + val req = Reg(new LoopConvLdWeightReq(coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_addr, concurrent_loops)) + import req.outer_bounds._ + import req.inner_bounds._ + import req.derived_params._ + + // Derived parameters + val max_ochs_per_mvin = Mux(ochs < (max_block_len * block_size).U, ochs, (max_block_len * block_size).U) + val B_rows = 
out_channels_per_bank * kcols * krows * kchs + val addr_start = req.addr_end - B_rows + + // Iterators + val och = Reg(UInt(large_iterator_bitwidth.W)) + val krow = Reg(UInt(tiny_iterator_bitwidth.W)) + val kcol = Reg(UInt(tiny_iterator_bitwidth.W)) + val kch = Reg(UInt(large_iterator_bitwidth.W)) + + // Addresses + val dram_addr = req.dram_addr +& ((krow*kernel_dim*in_channels +& kcol*in_channels +& kch) * out_channels +& och) * (input_w/8).U + val spad_addr = addr_start + (och / block_size.U) * krows * kcols * kchs + krow * kcols * kchs + kcol * kchs + kch + + // Sizes + val J = Mux(ochs - och > max_ochs_per_mvin, max_ochs_per_mvin, ochs - och) + val K = Mux(kchs - kch > block_size.U, block_size.U, kchs - kch) + + // Commands + val config_cmd = Wire(new RoCCCommand) + config_cmd := DontCare + config_cmd.inst.funct := CONFIG_CMD + config_cmd.rs1 := (MVIN_SCALE_IDENTITY << 32.U) | (req.derived_params.weight_spad_stride << 16.U) | (1.U << 3) | 1.U + config_cmd.rs2 := out_channels * (input_w/8).U + + val mvin_cmd = Wire(new RoCCCommand) + mvin_cmd := DontCare + mvin_cmd.inst.funct := LOAD2_CMD + mvin_cmd.rs1 := dram_addr + mvin_cmd.rs2 := (K << 48.U) | (J << 32.U) | spad_addr + + // Inputs and outputs + io.req.ready := state === idle + io.idle := state === idle + io.loop_id := req.loop_id + + io.cmd.valid := state =/= idle && !io.wait_for_prev_loop && !io.rob_overloaded + io.cmd.bits := Mux(state === config, config_cmd, mvin_cmd) + + // Sending outputs + when(io.cmd.fire()) { + when (state === config) { + state := ld + }.otherwise { + val next_kch = floorAdd(kch, block_size.U, kchs) + val next_kcol = floorAdd(kcol, 1.U, kcols, next_kch === 0.U) + val next_krow = floorAdd(krow, 1.U, krows, next_kcol === 0.U && next_kch === 0.U) + val next_och = floorAdd(och, max_ochs_per_mvin, ochs, next_krow === 0.U && next_kcol === 0.U && next_kch === 0.U) + + kch := next_kch + kcol := next_kcol + krow := next_krow + och := next_och + + state := Mux(next_och === 0.U && next_krow === 0.U && next_kcol === 0.U && next_kch === 0.U, + idle, ld) + } + } + + // Accepting requests + when (io.req.fire()) { + req := io.req.bits + state := config + kch := 0.U + kcol := 0.U + krow := 0.U + och := 0.U + } +} + +class LoopConvExecuteReq(val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int, val max_addr: Int, val max_acc_addr: Int, val concurrent_loops: Int) extends Bundle { + val outer_bounds = new LoopConvOuterBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val inner_bounds = new LoopConvInnerBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val derived_params = new LoopConvDerivedParams(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val a_addr_start = UInt(log2Up(max_addr).W) + val b_addr_end = UInt(log2Up(max_addr).W) + val c_addr_start = UInt(log2Up(max_acc_addr).W) + val loop_id = UInt(log2Up(concurrent_loops).W) +} + +class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_iterator_bitwidth: Int, tiny_iterator_bitwidth: Int, max_addr: Int, + max_acc_addr: Int, concurrent_loops: Int)(implicit p: Parameters) extends Module { + val GARBAGE_ADDR = (~0.U(32.W)).asUInt() + + val io = IO(new Bundle { + val req = Flipped(Decoupled(new LoopConvExecuteReq(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_addr, max_acc_addr, concurrent_loops))) + val cmd = Decoupled(Output(new RoCCCommand)) + + val lda_completed = Input(Bool()) + val 
ldb_completed = Input(Bool()) + val ldd_completed = Input(Bool()) + + val idle = Output(Bool()) + val rob_overloaded = Input(Bool()) + + val loop_id = Output(UInt(log2Up(concurrent_loops).W)) + }) + + object State extends ChiselEnum { + val idle, pre, comp = Value + } + import State._ + val state = RegInit(idle) + + val req = Reg(new LoopConvExecuteReq(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, + max_addr, max_acc_addr, concurrent_loops)) + import req.outer_bounds._ + import req.inner_bounds._ + import req.derived_params._ + + // Derived parameters + val B_rows = out_channels_per_bank * kcols * krows * kchs + + val a_addr_start = req.a_addr_start + val b_addr_start = req.b_addr_end - B_rows + val d_addr_start = (BigInt(1) << 31).U | req.c_addr_start + val c_addr_start = (BigInt(3) << 30).U | req.c_addr_start + + // Iterators + val b = Reg(UInt(large_iterator_bitwidth.W)) + val orow = Reg(UInt(small_iterator_bitwidth.W)) + val ocol = Reg(UInt(small_iterator_bitwidth.W)) + val och = Reg(UInt(large_iterator_bitwidth.W)) + val krow = Reg(UInt(tiny_iterator_bitwidth.W)) + val kcol = Reg(UInt(tiny_iterator_bitwidth.W)) + val kch = Reg(UInt(large_iterator_bitwidth.W)) + + val irow = orow * stride +& krow + val icol = ocol * stride +& kcol + + val I = Mux(ocols - ocol > block_size.U, block_size.U, ocols - ocol) + val J = Mux(ochs - och > block_size.U, block_size.U, ochs - och) + val K = Mux(kchs - kch > block_size.U, block_size.U, kchs - kch) + + // Addresses + val a_addr = a_addr_start +& (kch / block_size.U) * batches * irows * icols +& b * irows * icols +& irow * icols +& icol + val b_addr = b_addr_start +& (och / block_size.U) * krows * kcols * kchs +& krow * kcols * kchs +& kcol * kchs +& kch + val c_addr = Mux(ex_overwrite && krow === 0.U && kcol === 0.U && kch === 0.U, d_addr_start, c_addr_start) +& + (och / block_size.U) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol + + // Commands + val pre_cmd = Wire(new RoCCCommand) + pre_cmd := DontCare + pre_cmd.inst.funct := PRELOAD_CMD + pre_cmd.rs1 := (K << 48) | (J << 32) | b_addr + pre_cmd.rs2 := (I << 48) | (J << 32) | c_addr + + val comp_cmd = Wire(new RoCCCommand()) + comp_cmd := DontCare + comp_cmd.inst.funct := COMPUTE_AND_FLIP_CMD + comp_cmd.rs1 := (I << 48) | (K << 32) | a_addr + comp_cmd.rs2 := (I << 48) | (J << 32) | GARBAGE_ADDR + + // Inputs and outputs + io.req.ready := state === idle + io.idle := state === idle + + val ld_ahead = io.lda_completed && io.ldb_completed && io.ldd_completed + + io.cmd.valid := state =/= idle && !io.rob_overloaded && ld_ahead + io.cmd.bits := Mux(state === pre, pre_cmd, comp_cmd) + + io.loop_id := req.loop_id + + // Sending outputs + when (io.cmd.fire()) { + when (state === pre) { + state := comp + }.otherwise { + val next_kch = floorAdd(kch, block_size.U, kchs) + val next_kcol = floorAdd(kcol, 1.U, kcols, next_kch === 0.U) + val next_krow = floorAdd(krow, 1.U, krows, next_kcol === 0.U && next_kch === 0.U) + val next_och = floorAdd(och, block_size.U, ochs, + next_krow === 0.U && next_kcol === 0.U && next_kch === 0.U) + val next_ocol = floorAdd(ocol, block_size.U, ocols, + next_och === 0.U && next_krow === 0.U && next_kcol === 0.U && next_kch === 0.U) + val next_orow = floorAdd(orow, 1.U, orows, + next_ocol === 0.U && next_och === 0.U && next_krow === 0.U && next_kcol === 0.U && next_kch === 0.U) + val next_b = floorAdd(b, 1.U, batches, next_orow === 0.U && + next_ocol === 0.U && next_och === 0.U && next_krow === 0.U && next_kcol === 0.U && next_kch 
=== 0.U) + + kch := next_kch + kcol := next_kcol + krow := next_krow + och := next_och + ocol := next_ocol + orow := next_orow + b := next_b + + state := Mux(next_b === 0.U && next_orow === 0.U && next_ocol === 0.U && + next_och === 0.U && next_krow === 0.U && next_kcol === 0.U && next_kch === 0.U, + idle, pre) + } + } + + // Accepting requests + when (io.req.fire()) { + req := io.req.bits + state := pre + + b := 0.U + orow := 0.U + ocol := 0.U + och := 0.U + krow := 0.U + kcol := 0.U + kch := 0.U + } +} + +class LoopConvStReq(val coreMaxAddrBits: Int, val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int, val max_acc_addr: Int, val concurrent_loops: Int) extends Bundle { + val outer_bounds = new LoopConvOuterBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val inner_bounds = new LoopConvInnerBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val derived_params = new LoopConvDerivedParams(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val addr_start = UInt(log2Up(max_acc_addr).W) + val dram_addr = UInt(coreMaxAddrBits.W) + val no_pool = Bool() + val loop_id = UInt(log2Up(concurrent_loops).W) +} + +class LoopConvSt(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: Int, small_iterator_bitwidth: Int, tiny_iterator_bitwidth: Int, max_acc_addr: Int, input_w: Int, concurrent_loops: Int)(implicit p: Parameters) extends Module { + val MVIN_SCALE_IDENTITY = 0x3f800000.U // TODO get this from configs somehow + + val io = IO(new Bundle { + val req = Flipped(Decoupled(new LoopConvStReq(coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_acc_addr, concurrent_loops))) + val cmd = Decoupled(Output(new RoCCCommand)) + + val ex_completed = Input(Bool()) + + val idle = Output(Bool()) + val rob_overloaded = Input(Bool()) + + val loop_id = Output(UInt(log2Up(concurrent_loops).W)) + }) + + object State extends ChiselEnum { + val idle, st = Value + } + import State._ + val state = RegInit(idle) + + val req = Reg(new LoopConvStReq(coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_acc_addr, concurrent_loops)) + import req.outer_bounds._ + import req.inner_bounds._ + import req.derived_params._ + + val acc_addr_start = (BigInt(1) << 31).U | req.addr_start + + // Derived parameters + val skip = !(req.no_pool && (req.dram_addr =/= 0.U)) + + // Iterators + val b = Reg(UInt(large_iterator_bitwidth.W)) + val orow = Reg(UInt(small_iterator_bitwidth.W)) + val ocol = Reg(UInt(small_iterator_bitwidth.W)) + val och = Reg(UInt(large_iterator_bitwidth.W)) + + // Addresses + val dram_addr = req.dram_addr + ((b*out_dim*out_dim + orow*out_dim + ocol) * out_channels + och) * (input_w/8).U + val spad_addr = acc_addr_start +& (och / block_size.U) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol + + // Sizes + val I = Mux(ocols - ocol > block_size.U, block_size.U, ocols - ocol) + val J = Mux(ochs - och > block_size.U, block_size.U, ochs - och) + + // Commands + val mvout_cmd = Wire(new RoCCCommand) + mvout_cmd := DontCare + mvout_cmd.inst.funct := STORE_CMD + mvout_cmd.rs1 := dram_addr + mvout_cmd.rs2 := (I << 48.U) | (J << 32.U) | spad_addr + + // Inputs and outputs + io.req.ready := state === idle + io.idle := state === idle + io.loop_id := req.loop_id + + io.cmd.valid := state =/= idle && !io.rob_overloaded && !skip && io.ex_completed + io.cmd.bits := mvout_cmd
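// The "Sending outputs" block below, like the other loop unrollers in this
// file, builds its nested counters out of floorAdd from Util.scala: add a
// step to a counter, wrap to zero once it reaches its bound, and only
// advance when the inner counter has just wrapped. A behavioral sketch
// consistent with the call sites here (the real helper may differ in how it
// handles widths):
//   def floorAdd(x: UInt, step: UInt, bound: UInt, en: Bool = true.B): UInt = {
//     val sum = x +& step
//     Mux(!en, x, Mux(sum >= bound, 0.U, sum))
//   }
// sFloorAdd, used by LoopConvLdInput above, is presumably the signed variant
// that wraps to an explicit minimum (e.g. -lpad) instead of zero. Chaining
// next_x === 0.U into the next counter's enable is what makes these flat
// registers behave like a nested for-loop.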
+ + // Sending outputs + when (skip) { + state := idle + }.elsewhen(io.cmd.fire()) { + val next_och = floorAdd(och, block_size.U, ochs) + val next_ocol = floorAdd(ocol, block_size.U, ocols, next_och === 0.U) + val next_orow = floorAdd(orow, 1.U, orows, next_ocol === 0.U && next_och === 0.U) + val next_b = floorAdd(b, 1.U, batches, next_orow === 0.U && next_ocol === 0.U && next_och === 0.U) + + och := next_och + ocol := next_ocol + orow := next_orow + b := next_b + + state := Mux(next_b === 0.U && next_orow === 0.U && next_ocol === 0.U && next_och === 0.U, + idle, st) + } + + // Accepting requests + when (io.req.fire()) { + req := io.req.bits + state := st + + b := 0.U + orow := 0.U + ocol := 0.U + och := 0.U + } +} + +class LoopConvState(val block_size: Int, val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int, val coreMaxAddrBits: Int, val max_addr: Int, val max_acc_addr: Int) extends Bundle { + val outer_bounds = new LoopConvOuterBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val inner_bounds = new LoopConvInnerBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + + val bias_dram_addr = UInt(coreMaxAddrBits.W) + val weights_dram_addr = UInt(coreMaxAddrBits.W) + val input_dram_addr = UInt(coreMaxAddrBits.W) + val output_dram_addr = UInt(coreMaxAddrBits.W) + + val no_bias = Bool() + val no_pool = Bool() + + val configured = Bool() + + val running = Bool() + + val ld_bias_started = Bool() + val ld_input_started = Bool() + val ld_weights_started = Bool() + val ex_started = Bool() + val st_started = Bool() + + val ld_bias_completed = Bool() + val ld_input_completed = Bool() + val ld_weights_completed = Bool() + val ex_completed = Bool() + val st_completed = Bool() + + def all_completed(dummy: Int=0): Bool = ld_bias_completed && ld_input_completed && ld_weights_completed && ex_completed && st_completed + + val a_addr_start = UInt(log2Up(max_addr).W) + val b_addr_end = UInt(log2Up(max_addr).W) + + def derived_params(dummy: Int=0): LoopConvDerivedParams = { + import outer_bounds.stride + import inner_bounds.{batches, pochs, orows, ocols, krows, kcols, upad, dpad, lpad, rpad, kchs} + + val result = Wire(new LoopConvDerivedParams(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth)) + + result.ochs := pochs + + result.irows := orows * stride +& krows - 1.U + result.icols := ocols * stride +& kcols - 1.U + result.irows_unpadded := result.irows - upad - dpad + result.icols_unpadded := result.icols - lpad - rpad + result.ichs := kchs + + result.out_channels_per_bank := result.ochs / block_size.U +& (result.ochs % block_size.U =/= 0.U) + + result.bias_spad_stride := batches * orows * ocols + result.input_spad_stride := batches * result.irows * result.icols + result.weight_spad_stride := krows * kcols * kchs + + result.ex_overwrite := bias_dram_addr =/= 0.U && no_bias + + result + } + + def reset(): Unit = { + configured := false.B + + running := false.B + + ld_bias_started := false.B + ld_input_started := false.B + ld_weights_started := false.B + ex_started := false.B + st_started := false.B + + ld_bias_completed := false.B + ld_input_completed := false.B + ld_weights_completed := false.B + ex_completed := false.B + st_completed := false.B + } +} + +class LoopConv (block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: Int, max_exs: Int, max_sts: Int, + max_addr: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, dma_max_bytes: Int) + (implicit p: 
Parameters) extends Module { + val large_iterator_bitwidth = 16 + val small_iterator_bitwidth = 8 + val tiny_iterator_bitwidth = 4 + + val max_block_len = (dma_max_bytes / (block_size * (input_w / 8))) max 1 + val max_block_len_acc = (dma_max_bytes / (block_size * (acc_w / 8))) max 1 + + val io = IO(new Bundle { + val in = Flipped(Decoupled(new RoCCCommand)) + val out = Decoupled(new RoCCCommand) + val ld_utilization = Input(UInt(log2Up(rob_size+1).W)) + val st_utilization = Input(UInt(log2Up(rob_size+1).W)) + val ex_utilization = Input(UInt(log2Up(rob_size+1).W)) + val busy = Output(Bool()) + }) + + // Create states + val concurrent_loops = 2 + val loops = Reg(Vec(concurrent_loops, new LoopConvState(block_size, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, coreMaxAddrBits, max_addr, max_acc_addr))) + val head_loop_id = RegInit(0.U(log2Up(concurrent_loops).W)) + val tail_loop_id = (~head_loop_id).asUInt() // This is the loop that we always try to configure if available + val head_loop = loops(head_loop_id) + val tail_loop = loops(tail_loop_id) + + val loop_configured = loops.map(_.configured).reduce(_ || _) + + val loop_being_configured_id = Mux(head_loop.configured, tail_loop_id, head_loop_id) + val loop_being_configured = loops(loop_being_configured_id) + + // Create inner modules + val ld_bias = Module(new LoopConvLdBias(block_size, coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_acc_addr, acc_w, max_block_len_acc, concurrent_loops)) + val ld_input = Module(new LoopConvLdInput(block_size, coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_addr, input_w, max_block_len, concurrent_loops)) + val ld_weights = Module(new LoopConvLdWeight(block_size, coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_addr, input_w, max_block_len, concurrent_loops)) + val ex = Module(new LoopConvExecute(block_size, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_addr, max_acc_addr, concurrent_loops)) + val st = Module(new LoopConvSt(block_size, coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_acc_addr, input_w, concurrent_loops)) + + // Create command queue + val cmd = Queue(io.in) + + io.busy := cmd.valid || loop_configured + + // Create arbiter + val arb = Module(new Arbiter(new RoCCCommand, 5)) + arb.io.in(0) <> st.io.cmd + arb.io.in(1) <> ex.io.cmd + arb.io.in(2) <> ld_bias.io.cmd + arb.io.in(3) <> ld_weights.io.cmd + arb.io.in(4) <> ld_input.io.cmd + val unrolled_cmd = arb.io.out + + // Wire up unrolled command output + val is_loop_run_cmd = cmd.bits.inst.funct === LOOP_CONV_WS + val is_loop_config_cmd = cmd.bits.inst.funct >= LOOP_CONV_WS_CONFIG_1 && cmd.bits.inst.funct <= LOOP_CONV_WS_CONFIG_6 + val is_loop_cmd = is_loop_run_cmd || is_loop_config_cmd + + io.out.bits := Mux(loop_configured, unrolled_cmd.bits, cmd.bits) + io.out.bits.status := cmd.bits.status // TODO This is not guaranteed to be the correct fix! 
We must fix this + io.out.valid := Mux(loop_configured, unrolled_cmd.valid, cmd.valid && !is_loop_config_cmd && !is_loop_run_cmd) + + cmd.ready := Mux(is_loop_cmd, !loop_being_configured.configured, !loop_configured && io.out.ready) + arb.io.out.ready := io.out.ready + + // Wire up waiting-for-loads signals + val ex_is_waiting_for_loads = loops(ex.io.loop_id).ex_started && !loops(ex.io.loop_id).ex_completed && + !(loops(ex.io.loop_id).ld_input_completed && loops(ex.io.loop_id).ld_weights_completed && + loops(ex.io.loop_id).ld_bias_completed) + + ld_bias.io.wait_for_prev_loop := ex_is_waiting_for_loads && ld_bias.io.loop_id =/= ex.io.loop_id + ld_weights.io.wait_for_prev_loop := ex_is_waiting_for_loads && ld_weights.io.loop_id =/= ex.io.loop_id + ld_input.io.wait_for_prev_loop := ex_is_waiting_for_loads && ld_input.io.loop_id =/= ex.io.loop_id + + // Wire up overloaded signals + ld_bias.io.rob_overloaded := io.ld_utilization >= max_lds.U + ld_input.io.rob_overloaded := io.ld_utilization >= max_lds.U + ld_weights.io.rob_overloaded := io.ld_utilization >= max_lds.U + ex.io.rob_overloaded := io.ex_utilization >= max_exs.U + st.io.rob_overloaded := io.st_utilization >= max_sts.U + + // Wire up iterator inputs + ex.io.lda_completed := (ld_input.io.loop_id =/= ex.io.loop_id) || ld_input.io.idle + ex.io.ldb_completed := (ld_weights.io.loop_id =/= ex.io.loop_id) || ld_weights.io.idle + ex.io.ldd_completed := (ld_bias.io.loop_id =/= ex.io.loop_id) || ld_bias.io.idle + st.io.ex_completed := (ex.io.loop_id =/= st.io.loop_id) || ex.io.idle + + // Create config registers + when(cmd.valid && is_loop_cmd && !loop_being_configured.configured) { + + switch (cmd.bits.inst.funct) { + is (LOOP_CONV_WS_CONFIG_1) { + loop_being_configured.outer_bounds.out_channels := cmd.bits.rs1(63, 48) + loop_being_configured.outer_bounds.in_channels := cmd.bits.rs1(47, 32) + loop_being_configured.outer_bounds.in_dim := cmd.bits.rs1(31, 16) + loop_being_configured.outer_bounds.batch_size := cmd.bits.rs1(15, 0) + + loop_being_configured.outer_bounds.padding := cmd.bits.rs2(63, 48) + loop_being_configured.outer_bounds.stride := cmd.bits.rs2(47, 32) + loop_being_configured.outer_bounds.pool_out_dim := cmd.bits.rs2(31, 16) + loop_being_configured.outer_bounds.out_dim := cmd.bits.rs2(15, 0) + } + + is (LOOP_CONV_WS_CONFIG_2) { + loop_being_configured.outer_bounds.kernel_dim := cmd.bits.rs1(63, 48) + loop_being_configured.outer_bounds.pool_size := cmd.bits.rs1(47, 32) + loop_being_configured.outer_bounds.pool_stride := cmd.bits.rs1(31, 16) + loop_being_configured.outer_bounds.pool_padding := cmd.bits.rs1(15, 0) + + loop_being_configured.inner_bounds.batches := cmd.bits.rs2(63, 48) + loop_being_configured.inner_bounds.porows := cmd.bits.rs2(47, 32) + loop_being_configured.inner_bounds.pocols := cmd.bits.rs2(31, 16) + loop_being_configured.inner_bounds.pochs := cmd.bits.rs2(15, 0) + } + + is (LOOP_CONV_WS_CONFIG_3) { + loop_being_configured.inner_bounds.krows := cmd.bits.rs1(63, 48) + loop_being_configured.inner_bounds.kcols := cmd.bits.rs1(47, 32) + loop_being_configured.inner_bounds.kchs := cmd.bits.rs1(31, 16) + loop_being_configured.inner_bounds.lpad := cmd.bits.rs1(15, 0) + + loop_being_configured.inner_bounds.rpad := cmd.bits.rs2(63, 48) + loop_being_configured.inner_bounds.upad := cmd.bits.rs2(47, 32) + loop_being_configured.inner_bounds.dpad := cmd.bits.rs2(31, 16) + loop_being_configured.inner_bounds.plpad := cmd.bits.rs2(15, 0) + } + + is (LOOP_CONV_WS_CONFIG_4) { + loop_being_configured.inner_bounds.orows := 
cmd.bits.rs1(63, 48) + loop_being_configured.inner_bounds.prad := cmd.bits.rs1(47, 32) + loop_being_configured.inner_bounds.pupad := cmd.bits.rs1(31, 16) + loop_being_configured.inner_bounds.pdpad := cmd.bits.rs1(15, 0) + + loop_being_configured.inner_bounds.ocols := cmd.bits.rs2(15, 0) + } + + is (LOOP_CONV_WS_CONFIG_5) { + loop_being_configured.weights_dram_addr := cmd.bits.rs1 + + loop_being_configured.output_dram_addr := cmd.bits.rs2 + } + + is (LOOP_CONV_WS_CONFIG_6) { + loop_being_configured.bias_dram_addr := cmd.bits.rs1 + + loop_being_configured.input_dram_addr := cmd.bits.rs2 + } + + is (LOOP_CONV_WS) { + loop_being_configured.no_bias := cmd.bits.rs1(0) + + loop_being_configured.no_pool := cmd.bits.rs2(0) + + loop_being_configured.configured := true.B + } + } + } + + // Wire up request signals + val ld_bias_addr_start = RegInit(0.U(log2Up(max_acc_addr).W)) + val ex_c_addr_start = RegInit(0.U(log2Up(max_acc_addr).W)) + val st_addr_start = RegInit(0.U(log2Up(max_acc_addr).W)) + + val loop_requesting_ld_bias_id = Mux(head_loop.ld_bias_started, tail_loop_id, head_loop_id) + val loop_requesting_ld_bias = loops(loop_requesting_ld_bias_id) + ld_bias.io.req.bits.outer_bounds := loop_requesting_ld_bias.outer_bounds + ld_bias.io.req.bits.inner_bounds := loop_requesting_ld_bias.inner_bounds + ld_bias.io.req.bits.derived_params := loop_requesting_ld_bias.derived_params() + ld_bias.io.req.bits.addr_start := ld_bias_addr_start + ld_bias.io.req.bits.dram_addr := loop_requesting_ld_bias.bias_dram_addr + ld_bias.io.req.bits.no_bias := loop_requesting_ld_bias.no_bias + ld_bias.io.req.bits.loop_id := loop_requesting_ld_bias_id + + ld_bias.io.req.valid := !loop_requesting_ld_bias.ld_bias_started && loop_requesting_ld_bias.configured + + when (ld_bias.io.req.fire()) { + loop_requesting_ld_bias.running := true.B + loop_requesting_ld_bias.ld_bias_started := true.B + + // when (loop_requesting_ld_bias.bias_dram_addr =/= 0.U) { + when (loop_requesting_ld_bias.output_dram_addr =/= 0.U) { + ld_bias_addr_start := floorAdd(ld_bias_addr_start, (max_acc_addr / concurrent_loops).U, max_acc_addr.U) + } + } + + val loop_requesting_ld_input_id = Mux(head_loop.ld_input_started, tail_loop_id, head_loop_id) + val loop_requesting_ld_input = loops(loop_requesting_ld_input_id) + ld_input.io.req.bits.outer_bounds := loop_requesting_ld_input.outer_bounds + ld_input.io.req.bits.inner_bounds := loop_requesting_ld_input.inner_bounds + ld_input.io.req.bits.derived_params := loop_requesting_ld_input.derived_params() + ld_input.io.req.bits.addr_start := loop_requesting_ld_input.a_addr_start + ld_input.io.req.bits.dram_addr := loop_requesting_ld_input.input_dram_addr + ld_input.io.req.bits.loop_id := loop_requesting_ld_input_id + + ld_input.io.req.valid := !loop_requesting_ld_input.ld_input_started && loop_requesting_ld_input.configured + + when (ld_input.io.req.fire()) { + loop_requesting_ld_input.running := true.B + loop_requesting_ld_input.ld_input_started := true.B + } + + val loop_requesting_ld_weights_id = Mux(head_loop.ld_weights_started, tail_loop_id, head_loop_id) + val loop_requesting_ld_weights = loops(loop_requesting_ld_weights_id) + ld_weights.io.req.bits.outer_bounds := loop_requesting_ld_weights.outer_bounds + ld_weights.io.req.bits.inner_bounds := loop_requesting_ld_weights.inner_bounds + ld_weights.io.req.bits.derived_params := loop_requesting_ld_weights.derived_params() + ld_weights.io.req.bits.addr_end := loop_requesting_ld_weights.b_addr_end + ld_weights.io.req.bits.dram_addr := 
loop_requesting_ld_weights.weights_dram_addr + ld_weights.io.req.bits.loop_id := loop_requesting_ld_weights_id + + ld_weights.io.req.valid := !loop_requesting_ld_weights.ld_weights_started && loop_requesting_ld_weights.configured + + when (ld_weights.io.req.fire()) { + loop_requesting_ld_weights.running := true.B + loop_requesting_ld_weights.ld_weights_started := true.B + } + + val loop_requesting_ex_id = Mux(head_loop.ex_started, tail_loop_id, head_loop_id) + val loop_requesting_ex = loops(loop_requesting_ex_id) + ex.io.req.bits.outer_bounds := loop_requesting_ex.outer_bounds + ex.io.req.bits.inner_bounds := loop_requesting_ex.inner_bounds + ex.io.req.bits.derived_params := loop_requesting_ex.derived_params() + ex.io.req.bits.a_addr_start := loop_requesting_ex.a_addr_start + ex.io.req.bits.b_addr_end := loop_requesting_ex.b_addr_end + ex.io.req.bits.c_addr_start := ex_c_addr_start + ex.io.req.bits.loop_id := loop_requesting_ex_id + + ex.io.req.valid := !loop_requesting_ex.ex_started && loop_requesting_ex.ld_bias_started && + loop_requesting_ex.ld_input_started && loop_requesting_ex.ld_weights_started && loop_requesting_ex.configured + + when (ex.io.req.fire()) { + loop_requesting_ex.running := true.B + loop_requesting_ex.ex_started := true.B + + when (loop_requesting_ex.output_dram_addr =/= 0.U) { + ex_c_addr_start := floorAdd(ex_c_addr_start, (max_acc_addr / concurrent_loops).U, max_acc_addr.U) + } + } + + val loop_requesting_st_id = Mux(head_loop.st_started, tail_loop_id, head_loop_id) + val loop_requesting_st = loops(loop_requesting_st_id) + st.io.req.bits.outer_bounds := loop_requesting_st.outer_bounds + st.io.req.bits.inner_bounds := loop_requesting_st.inner_bounds + st.io.req.bits.derived_params := loop_requesting_st.derived_params() + st.io.req.bits.addr_start := st_addr_start + st.io.req.bits.dram_addr := loop_requesting_st.output_dram_addr + st.io.req.bits.no_pool := loop_requesting_st.no_pool + st.io.req.bits.loop_id := loop_requesting_st_id + + st.io.req.valid := !loop_requesting_st.st_started && loop_requesting_st.ex_started && loop_requesting_st.configured + + when (st.io.req.fire()) { + loop_requesting_st.running := true.B + loop_requesting_st.st_started := true.B + + when (loop_requesting_st.output_dram_addr =/= 0.U) { + st_addr_start := floorAdd(st_addr_start, (max_acc_addr / concurrent_loops).U, max_acc_addr.U) + } + } + + // Handle completed signals + when (ld_bias.io.idle && loops(ld_bias.io.loop_id).running && loops(ld_bias.io.loop_id).ld_bias_started) { + loops(ld_bias.io.loop_id).ld_bias_completed := true.B + } + + when (ld_input.io.idle && loops(ld_input.io.loop_id).running && loops(ld_input.io.loop_id).ld_input_started) { + loops(ld_input.io.loop_id).ld_input_completed := true.B + } + + when (ld_weights.io.idle && loops(ld_weights.io.loop_id).running && loops(ld_weights.io.loop_id).ld_weights_started) { + loops(ld_weights.io.loop_id).ld_weights_completed := true.B + } + + when (ex.io.idle && loops(ex.io.loop_id).running && loops(ex.io.loop_id).ex_started) { + loops(ex.io.loop_id).ex_completed := true.B + } + + when (st.io.idle && loops(st.io.loop_id).running && loops(st.io.loop_id).st_started) { + loops(st.io.loop_id).st_completed := true.B + } + + when (head_loop.running && head_loop.all_completed()) { + head_loop.reset() + head_loop_id := ~head_loop_id + } + + // Resets + when (reset.toBool()) { + loops.zipWithIndex.foreach { case (l, i) => + l.reset() + l.a_addr_start := (i * (max_addr / concurrent_loops)).U + l.b_addr_end := ((i+1) * (max_addr / 
concurrent_loops) - block_size).U + } + } +} + +object LoopConv { + def apply(in: DecoupledIO[RoCCCommand], ld_utilization: UInt, st_utilization: UInt, ex_utilization: UInt, + block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: Int, max_exs: Int, max_sts: Int, + max_addr: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, dma_max_bytes: Int) + (implicit p: Parameters): Tuple2[DecoupledIO[RoCCCommand], Bool] = { + val mod = Module(new LoopConv(block_size, coreMaxAddrBits, rob_size, max_lds, max_exs, max_sts, + max_addr, max_acc_addr, input_w, acc_w, dma_max_bytes)) + mod.io.in <> in + mod.io.ld_utilization := ld_utilization + mod.io.st_utilization := st_utilization + mod.io.ex_utilization := ex_utilization + (mod.io.out, mod.io.busy) + } +} diff --git a/src/main/scala/gemmini/LoopMatmul.scala b/src/main/scala/gemmini/LoopMatmul.scala index 181202b3..74db8914 100644 --- a/src/main/scala/gemmini/LoopMatmul.scala +++ b/src/main/scala/gemmini/LoopMatmul.scala @@ -84,8 +84,11 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In when (io.cmd.fire()) { // The order here is k, j, i - val next_i = floorAdd(i, 1.U, req.max_i) - val next_k = floorAdd(k, max_blocks, req.max_k, next_i === 0.U) + val i_blocks = Mux(req.transpose, max_blocks, 1.U) + val k_blocks = Mux(req.transpose, 1.U, max_blocks) + + val next_i = floorAdd(i, i_blocks, req.max_i) + val next_k = floorAdd(k, k_blocks, req.max_k, next_i === 0.U) i := next_i k := next_k @@ -182,11 +185,14 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In when (io.cmd.fire()) { // The order here is k, j, i - val next_j = floorAdd(j, max_blocks, req.max_j) - val next_k = floorAdd(k, 1.U, req.max_k, next_j === 0.U) + val j_blocks = Mux(req.transpose, 1.U, max_blocks) + val k_blocks = Mux(req.transpose, max_blocks, 1.U) + + val next_j = floorAdd(j, j_blocks, req.max_j) + val next_k = floorAdd(k, k_blocks, req.max_k, next_j === 0.U) - k := next_k j := next_j + k := next_k when (next_j === 0.U && next_k === 0.U) { state := idle @@ -229,7 +235,7 @@ class LoopMatmulLdD(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In }) object State extends ChiselEnum { - val idle, st = Value + val idle, ld = Value } import State._ val state = RegInit(idle) @@ -270,8 +276,8 @@ class LoopMatmulLdD(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In state := idle }.elsewhen (io.cmd.fire()) { // The order here is k, j, i - val next_i = floorAdd(i, max_blocks, req.max_i) - val next_j = floorAdd(j, 1.U, req.max_j, next_i === 0.U) + val next_i = floorAdd(i, 1.U, req.max_i) + val next_j = floorAdd(j, max_blocks, req.max_j, next_i === 0.U) i := next_i j := next_j @@ -283,7 +289,7 @@ class LoopMatmulLdD(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In when (io.req.fire()) { req := io.req.bits - state := st + state := ld j := 0.U i := 0.U } @@ -308,7 +314,6 @@ class LoopMatmulExecuteReq(val block_size: Int, val coreMaxAddrBits: Int, val it class LoopMatmulExecute(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_addr: Int, max_acc_addr: Int, concurrent_loops: Int) (implicit p: Parameters) extends Module { - val MAX_BLOCK_LEN = 4 // TODO get this from configs val GARBAGE_ADDR = (~0.U(32.W)).asUInt() val io = IO(new Bundle { @@ -443,7 +448,7 @@ class LoopMatmulStCReq(val block_size: Int, val coreMaxAddrBits: Int, val iterat val loop_id = UInt(log2Up(concurrent_loops).W) } -class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_acc_addr: 
Int, input_w: Int, acc_w: Int, concurrent_loops: Int) +class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, max_block_len: Int, concurrent_loops: Int) (implicit p: Parameters) extends Module { val io = IO(new Bundle { val req = Flipped(Decoupled(new LoopMatmulStCReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_acc_addr, concurrent_loops))) @@ -471,6 +476,8 @@ class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In val req = Reg(new LoopMatmulStCReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_acc_addr, concurrent_loops)) + val max_blocks = Mux(req.full_c, 1.U, Mux(req.max_j <= max_block_len.U, req.max_j, max_block_len.U)) + val j = Reg(UInt(iterator_bitwidth.W)) val i = Reg(UInt(iterator_bitwidth.W)) @@ -479,7 +486,8 @@ class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In val dram_addr = Mux(req.full_c, req.dram_addr + (i * req.dram_stride + j) * block_size.U * (acc_w/8).U, req.dram_addr + (i * req.dram_stride + j) * block_size.U * (input_w/8).U) val sp_addr = acc_addr_start + (i * req.max_j + j) * block_size.U - val cols = block_size.U - Mux(j + 1.U >= req.max_j, req.pad_j, 0.U) + val blocks = Mux(j + max_blocks <= req.max_j, max_blocks, req.max_j-j) + val cols = (blocks * block_size.U) - Mux(j + blocks >= req.max_j, req.pad_j, 0.U) val rows = block_size.U - Mux(i === req.max_i-1.U, req.pad_i, 0.U) val mvout_cmd = Wire(new RoCCCommand) @@ -494,7 +502,11 @@ class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In io.idle := state === idle // The order here is k, j, i - val ex_ahead = io.ex_completed || (io.ex_k === req.max_k - 1.U && (io.ex_j > j || (io.ex_j === j && io.ex_i > i))) + // val ex_ahead = io.ex_completed || (io.ex_k === req.max_k - 1.U && (io.ex_j > j || (io.ex_j === j && io.ex_i > i))) + val ex_ahead = io.ex_completed || + (io.ex_k === req.max_k - 1.U && + (io.ex_j >= j + blocks || + ((io.ex_j === j + blocks - 1.U) && io.ex_i > i))) io.cmd.valid := state =/= idle && !io.rob_overloaded && ex_ahead && req.dram_addr =/= 0.U io.cmd.bits := mvout_cmd @@ -506,7 +518,7 @@ class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In }.elsewhen (io.cmd.fire()) { // The order here is k, j, i val next_i = floorAdd(i, 1.U, req.max_i) - val next_j = floorAdd(j, 1.U, req.max_j, next_i === 0.U) + val next_j = floorAdd(j, max_blocks, req.max_j, next_i === 0.U) i := next_i j := next_j @@ -595,15 +607,15 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: max_addr: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, dma_max_bytes: Int) (implicit p: Parameters) extends Module { val iterator_bitwidth = 16 - val max_block_len = (dma_max_bytes / (block_size * input_w * 8)) max 1 - val max_block_len_acc = (dma_max_bytes / (block_size * acc_w * 8)) max 1 + val max_block_len = (dma_max_bytes / (block_size * input_w / 8)) max 1 + val max_block_len_acc = (dma_max_bytes / (block_size * acc_w / 8)) max 1 val io = IO(new Bundle { val in = Flipped(Decoupled(new RoCCCommand)) val out = Decoupled(new RoCCCommand) - val ld_utilization = Input(UInt(log2Up(rob_size).W)) - val st_utilization = Input(UInt(log2Up(rob_size).W)) - val ex_utilization = Input(UInt(log2Up(rob_size).W)) + val ld_utilization = Input(UInt(log2Up(rob_size+1).W)) + val st_utilization = Input(UInt(log2Up(rob_size+1).W)) + val ex_utilization = Input(UInt(log2Up(rob_size+1).W)) val busy = Output(Bool()) }) @@ -625,7 +637,7 @@ class 
LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: val ldB = Module(new LoopMatmulLdB(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, input_w, max_block_len, concurrent_loops)) val ldD = Module(new LoopMatmulLdD(block_size, coreMaxAddrBits, iterator_bitwidth, max_acc_addr, input_w, acc_w, max_block_len, max_block_len_acc, concurrent_loops)) val ex = Module(new LoopMatmulExecute(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, max_acc_addr, concurrent_loops)) - val stC = Module(new LoopMatmulStC(block_size, coreMaxAddrBits, iterator_bitwidth, max_acc_addr, input_w, acc_w, concurrent_loops)) + val stC = Module(new LoopMatmulStC(block_size, coreMaxAddrBits, iterator_bitwidth, max_acc_addr, input_w, acc_w, max_block_len, concurrent_loops)) // Create command queue val cmd = Queue(io.in) @@ -654,7 +666,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: val is_loop_cmd = is_loop_run_cmd || is_loop_config_cmd io.out.bits := Mux(loop_configured, unrolled_cmd.bits, cmd.bits) - io.out.bits.status := cmd.bits.status + io.out.bits.status := cmd.bits.status // TODO This is not guaranteed to be the correct fix! We must fix this io.out.valid := Mux(loop_configured, unrolled_cmd.valid, cmd.valid && !is_loop_config_cmd && !is_loop_run_cmd) cmd.ready := Mux(is_loop_cmd, !loop_being_configured.configured, !loop_configured && io.out.ready) @@ -681,6 +693,9 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: stC.io.ex_j := ex.io.j stC.io.ex_i := ex.io.i + val loops_configured = RegInit(0.U(16.W)) + dontTouch(loops_configured) + // Create config registers when(cmd.valid && is_loop_cmd && !loop_being_configured.configured) { @@ -723,6 +738,8 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: loop_being_configured.b_transpose := cmd.bits.rs2(1) loop_being_configured.configured := true.B + + loops_configured := loops_configured + 1.U } } } diff --git a/src/main/scala/gemmini/Mesh.scala b/src/main/scala/gemmini/Mesh.scala index 5f50c992..22ece6f3 100644 --- a/src/main/scala/gemmini/Mesh.scala +++ b/src/main/scala/gemmini/Mesh.scala @@ -57,6 +57,7 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, } } // Chain control signals (pipeline across each column) + assert(!(mesh.map(_.map(_.io.bad_dataflow).reduce(_||_)).reduce(_||_))) for (c <- 0 until meshColumns) { meshT(c).foldLeft((io.in_control(c), io.in_valid(c))) { case ((in_ctrl, valid), tile) => @@ -68,6 +69,7 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, (tile.io.out_control, tile.io.out_valid) } } + // Chain in_valid (pipeline across each column) for (c <- 0 until meshColumns) { meshT(c).foldLeft(io.in_valid(c)) { diff --git a/src/main/scala/gemmini/PE.scala b/src/main/scala/gemmini/PE.scala index b912ad34..7c17cc39 100644 --- a/src/main/scala/gemmini/PE.scala +++ b/src/main/scala/gemmini/PE.scala @@ -34,6 +34,8 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, val in_valid = Input(Bool()) val out_valid = Output(Bool()) + + val bad_dataflow = Output(Bool()) }) val cType = if (df == Dataflow.WS) inputType else accType @@ -66,6 +68,7 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, val COMPUTE = 0.U(1.W) val PROPAGATE = 1.U(1.W) + io.bad_dataflow := false.B when ((df == Dataflow.OS).B || ((df == Dataflow.BOTH).B && dataflow === OUTPUT_STATIONARY)) { when(prop === PROPAGATE) { io.out_c := (c1 >> 
shift_offset).clippedToWidthOf(outputType) @@ -89,7 +92,8 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, c2 := d } }.otherwise { - assert(false.B, "unknown dataflow") + io.bad_dataflow := true.B + //assert(false.B, "unknown dataflow") io.out_c := DontCare io.out_b := DontCare } diff --git a/src/main/scala/gemmini/ROB.scala b/src/main/scala/gemmini/ROB.scala index ccc6dbd2..c002da48 100644 --- a/src/main/scala/gemmini/ROB.scala +++ b/src/main/scala/gemmini/ROB.scala @@ -3,66 +3,79 @@ package gemmini import chisel3._ import chisel3.util._ - import freechips.rocketchip.tile.RoCCCommand - import GemminiISA._ import Util._ -//import midas.targetutils.FpgaDebug // TODO unify this class with GemminiCmdWithDeps -class ROBIssue[T <: Data](cmd_t: T, nEntries: Int) extends Bundle { +class ROBIssue[T <: Data](cmd_t: T, rob_entries: Int) extends Bundle { val valid = Output(Bool()) val ready = Input(Bool()) val cmd = Output(cmd_t.cloneType) - val rob_id = Output(UInt(log2Up(nEntries).W)) + val rob_id = Output(UInt(log2Up(rob_entries).W)) def fire(dummy: Int=0) = valid && ready - override def cloneType: this.type = new ROBIssue(cmd_t, nEntries).asInstanceOf[this.type] + override def cloneType: this.type = new ROBIssue(cmd_t, rob_entries).asInstanceOf[this.type] } // TODO we don't need to store the full command in here. We should be able to release the command directly into the relevant controller and only store the associated metadata in the ROB. This would reduce the size considerably -class ROB(cmd_t: RoCCCommand, nEntries: Int, local_addr_t: LocalAddr, block_rows: Int, block_cols: Int) extends Module { +class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V], cmd_t: RoCCCommand) extends Module { + import config._ + + val block_rows = tileRows * meshRows + val block_cols = tileColumns * meshColumns + val io = IO(new Bundle { val alloc = Flipped(Decoupled(cmd_t.cloneType)) - val completed = Flipped(Valid(UInt(log2Up(nEntries).W))) + val completed = Flipped(Valid(UInt(log2Up(rob_entries).W))) val issue = new Bundle { - val ld = new ROBIssue(cmd_t, nEntries) - val st = new ROBIssue(cmd_t, nEntries) - val ex = new ROBIssue(cmd_t, nEntries) + val ld = new ROBIssue(cmd_t, rob_entries) + val st = new ROBIssue(cmd_t, rob_entries) + val ex = new ROBIssue(cmd_t, rob_entries) } - val ld_utilization = Output(UInt(log2Up(nEntries).W)) - val st_utilization = Output(UInt(log2Up(nEntries).W)) - val ex_utilization = Output(UInt(log2Up(nEntries).W)) + val ld_utilization = Output(UInt(log2Up(rob_entries+1).W)) + val st_utilization = Output(UInt(log2Up(rob_entries+1).W)) + val ex_utilization = Output(UInt(log2Up(rob_entries+1).W)) val busy = Output(Bool()) val solitary_preload = Input(Bool()) // TODO very hacky. from ExecuteController, to prevent infinite fence stalls. 
remove later }) + // TODO make this a ChiselEnum val ldq :: stq :: exq :: Nil = Enum(3) val q_t = ldq.cloneType + class OpT extends Bundle { + val start = local_addr_t.cloneType + val end = local_addr_t.cloneType + val wraps_around = Bool() + + def overlaps(other: OpT): Bool = { + ((other.start <= start && (start < other.end || other.wraps_around)) || + (start <= other.start && (other.start < end || wraps_around))) && + !(start.is_garbage() || other.start.is_garbage()) // TODO the "is_garbage" check might not really be necessary + } + } + + val instructions_allocated = RegInit(0.U(32.W)) + when (io.alloc.fire()) { + instructions_allocated := instructions_allocated + 1.U + } + dontTouch(instructions_allocated) + class Entry extends Bundle { val q = q_t.cloneType val is_config = Bool() - val op1 = UDValid(local_addr_t.cloneType) - val op2 = UDValid(local_addr_t.cloneType) - // val op3 = UDValid(local_addr_t.cloneType) - - val dst = UDValid(new Bundle { - val start = local_addr_t.cloneType - val len = UInt(8.W) // TODO magic number - - def end(dummy: Int = 0): LocalAddr = start + len * block_rows.U - def wraps_around(dummy: Int = 0): Bool = start.add_with_overflow(len * block_rows.U)._2 - }) + val op1 = UDValid(new OpT) + val op2 = UDValid(new OpT) + val dst = UDValid(new OpT) val issued = Bool() @@ -70,16 +83,20 @@ class ROB(cmd_t: RoCCCommand, nEntries: Int, local_addr_t: LocalAddr, block_rows val cmd = cmd_t.cloneType - val deps = Vec(nEntries, Bool()) + val deps = Vec(rob_entries, Bool()) def ready(dummy: Int = 0): Bool = !deps.reduce(_ || _) + + // Debugging signals + val allocated_at = UInt(instructions_allocated.getWidth.W) } - val entries = Reg(Vec(nEntries, UDValid(new Entry))) + val entries = Reg(Vec(rob_entries, UDValid(new Entry))) val empty = !entries.map(_.valid).reduce(_ || _) val full = entries.map(_.valid).reduce(_ && _) - // io.busy := !empty + // TODO we could also check for a solitary preload by recording the last instruction that was allocated, rather than + // reading all entries to check for preloads, which is an O(n) operation in terms of area cost val utilization = PopCount(entries.map(_.valid)) val solitary_preload = utilization === 1.U && entries.map(e => e.valid && e.bits.cmd.inst.funct === PRELOAD_CMD).reduce(_ || _) io.busy := !empty && !(solitary_preload && io.solitary_preload) @@ -87,13 +104,40 @@ class ROB(cmd_t: RoCCCommand, nEntries: Int, local_addr_t: LocalAddr, block_rows // Read in commands to the buffer io.alloc.ready := !full - val last_allocated = Reg(UInt(log2Up(nEntries).W)) + val last_allocated = Reg(UInt(log2Up(rob_entries).W)) + val a_stride = Reg(UInt(16.W)) // TODO magic numbers // TODO we also need to check the transpose to see how many rows we're reading + val ld_block_strides = Reg(Vec(load_states, UInt(block_stride_bits.W))) + val st_block_stride = block_rows.U val new_entry = Wire(new Entry) new_entry := DontCare - val new_entry_id = MuxCase((nEntries-1).U, entries.zipWithIndex.map { case (e, i) => !e.valid -> i.U }) + val new_entry_id = MuxCase((rob_entries-1).U, entries.zipWithIndex.map { case (e, i) => !e.valid -> i.U }) val alloc_fire = io.alloc.fire() + val raws_probe = WireInit(0.U(rob_entries.W)) + val waws_probe = WireInit(0.U(rob_entries.W)) + val wars_probe = WireInit(0.U(rob_entries.W)) + val older_in_same_q_probe = WireInit(0.U(rob_entries.W)) + val is_st_and_must_wait_for_prior_ex_config_probe = WireInit(0.U(rob_entries.W)) + val is_ex_config_and_must_wait_for_prior_st_probe = WireInit(0.U(rob_entries.W)) + + val wars_op1_probe 
= WireInit(0.U(rob_entries.W)) + val wars_op2_probe = WireInit(0.U(rob_entries.W)) + + val raws_op1_probe = WireInit(0.U(rob_entries.W)) + val raws_op2_probe = WireInit(0.U(rob_entries.W)) + + dontTouch(raws_probe) + dontTouch(waws_probe) + dontTouch(wars_probe) + dontTouch(wars_op1_probe) + dontTouch(wars_op2_probe) + dontTouch(raws_op1_probe) + dontTouch(raws_op2_probe) + dontTouch(older_in_same_q_probe) + dontTouch(is_st_and_must_wait_for_prior_ex_config_probe) + dontTouch(is_ex_config_and_must_wait_for_prior_st_probe) + when (io.alloc.fire()) { val spAddrBits = 32 val cmd = io.alloc.bits @@ -107,22 +151,61 @@ class ROB(cmd_t: RoCCCommand, nEntries: Int, local_addr_t: LocalAddr, block_rows new_entry.is_config := funct === CONFIG_CMD new_entry.op1.valid := funct === PRELOAD_CMD || funct_is_compute - new_entry.op1.bits := cmd.rs1.asTypeOf(local_addr_t) + new_entry.op1.bits.start := cmd.rs1.asTypeOf(local_addr_t) + when (funct === PRELOAD_CMD) { + val preload_rows = cmd.rs1(48 + log2Up(block_rows + 1) - 1, 48) + new_entry.op1.bits.end := new_entry.op1.bits.start + preload_rows + new_entry.op1.bits.wraps_around := new_entry.op1.bits.start.add_with_overflow(preload_rows)._2 + }.otherwise { + val compute_rows = cmd.rs1(48 + log2Up(block_rows + 1) - 1, 48) * a_stride + new_entry.op1.bits.end := new_entry.op1.bits.start + compute_rows + new_entry.op1.bits.wraps_around := new_entry.op1.bits.start.add_with_overflow(compute_rows)._2 + } new_entry.op2.valid := funct_is_compute || funct === STORE_CMD - new_entry.op2.bits := cmd.rs2.asTypeOf(local_addr_t) - - // new_entry.op3.valid := funct_is_compute - // new_entry.op3.bits := cmd.rs1(63, 32).asTypeOf(local_addr_t) + new_entry.op2.bits.start := cmd.rs2.asTypeOf(local_addr_t) + when (funct_is_compute) { + val compute_rows = cmd.rs2(48 + log2Up(block_rows + 1) - 1, 48) + new_entry.op2.bits.end := new_entry.op2.bits.start + compute_rows + new_entry.op2.bits.wraps_around := new_entry.op2.bits.start.add_with_overflow(compute_rows)._2 + }.otherwise { + val block_stride = st_block_stride + + val mvout_cols = cmd.rs2(32 + mvout_cols_bits - 1, 32) + val mvout_rows = cmd.rs2(48 + mvout_rows_bits - 1, 48) + + val mvout_mats = mvout_cols / block_cols.U + (mvout_cols % block_cols.U =/= 0.U) + val total_mvout_rows = ((mvout_mats - 1.U) * block_stride) + mvout_rows + + new_entry.op2.bits.end := new_entry.op2.bits.start + total_mvout_rows + new_entry.op2.bits.wraps_around := new_entry.op2.bits.start.add_with_overflow(total_mvout_rows)._2 + } - val mvin_mvout_len = cmd.rs2(48, spAddrBits) new_entry.dst.valid := funct === PRELOAD_CMD || funct === LOAD_CMD || funct === LOAD2_CMD || funct === LOAD3_CMD new_entry.dst.bits.start := cmd.rs2(31, 0).asTypeOf(local_addr_t) - new_entry.dst.bits.len := Mux(funct === PRELOAD_CMD, 1.U, mvin_mvout_len / block_cols.U + (mvin_mvout_len % block_cols.U =/= 0.U)) + when (funct === PRELOAD_CMD) { + val preload_rows = cmd.rs2(48 + log2Up(block_rows + 1) - 1, 48) + new_entry.dst.bits.end := new_entry.dst.bits.start + preload_rows + new_entry.dst.bits.wraps_around := new_entry.dst.bits.start.add_with_overflow(preload_rows)._2 + }.otherwise { + val id = MuxCase(0.U, Seq((new_entry.cmd.inst.funct === LOAD2_CMD) -> 1.U, + (new_entry.cmd.inst.funct === LOAD3_CMD) -> 2.U)) + val block_stride = ld_block_strides(id) + + val mvin_cols = cmd.rs2(spAddrBits + mvin_cols_bits - 1, spAddrBits) + val mvin_rows = cmd.rs2(spAddrBits + mvin_cols_bits + mvin_rows_bits - 1, spAddrBits + mvin_cols_bits) + + val mvin_mats = mvin_cols / block_cols.U + 
(mvin_cols % block_cols.U =/= 0.U) + val total_mvin_rows = ((mvin_mats - 1.U) * block_stride) + mvin_rows + + new_entry.dst.bits.end := new_entry.dst.bits.start + total_mvin_rows + new_entry.dst.bits.wraps_around := new_entry.dst.bits.start.add_with_overflow(total_mvin_rows)._2 + } val is_load = funct === LOAD_CMD || funct === LOAD2_CMD || funct === LOAD3_CMD || (funct === CONFIG_CMD && config_cmd_type === CONFIG_LOAD) val is_store = funct === STORE_CMD || (funct === CONFIG_CMD && config_cmd_type === CONFIG_STORE) val is_ex = funct === PRELOAD_CMD || funct_is_compute || (funct === CONFIG_CMD && (config_cmd_type === CONFIG_EX || config_cmd_type === CONFIG_IM2COL)) + val is_im2col = funct === CONFIG_CMD && config_cmd_type === CONFIG_IM2COL // im2col commands are a subset of ex commands, so they still go in the ex queue new_entry.q := Mux1H(Seq( is_load -> ldq, @@ -130,30 +213,53 @@ class ROB(cmd_t: RoCCCommand, nEntries: Int, local_addr_t: LocalAddr, block_rows is_ex -> exq )) + assert(is_load || is_store || is_ex) + + // TODO we should check whether op1 and op2 are valid here val raws = entries.map { e => // We search for all entries which write to an address which we read from e.valid && e.bits.dst.valid && e.bits.q =/= new_entry.q && ( - (new_entry.op1.valid && e.bits.dst.bits.start <= new_entry.op1.bits && (e.bits.dst.bits.end() > new_entry.op1.bits || e.bits.dst.bits.wraps_around())) || - (new_entry.op2.valid && e.bits.dst.bits.start <= new_entry.op2.bits && (e.bits.dst.bits.end() > new_entry.op2.bits || e.bits.dst.bits.wraps_around()))) /* || - (new_entry.op3.valid && e.bits.dst.bits.start <= new_entry.op3.bits && e.bits.dst.bits.end() > new_entry.op3.bits)) */ + (new_entry.op1.valid && new_entry.op1.bits.overlaps(e.bits.dst.bits)) || + (new_entry.op2.valid && new_entry.op2.bits.overlaps(e.bits.dst.bits))) + } + + val raws_op1 = entries.map { e => + // We search for all entries which write to an address which we read from + e.valid && e.bits.dst.valid && e.bits.q =/= new_entry.q && ( + (new_entry.op1.valid && new_entry.op1.bits.overlaps(e.bits.dst.bits))) } + val raws_op2 = entries.map { e => + // We search for all entries which write to an address which we read from + e.valid && e.bits.dst.valid && e.bits.q =/= new_entry.q && ( + (new_entry.op2.valid && new_entry.op2.bits.overlaps(e.bits.dst.bits))) + } + + // TODO we should check whether op1 and op2 are valid here val wars = entries.map { e => // We search for all entries which read from an address that we write to e.valid && new_entry.dst.valid && e.bits.q =/= new_entry.q && ( - (e.bits.op1.valid && new_entry.dst.bits.start <= e.bits.op1.bits && (new_entry.dst.bits.end() > e.bits.op1.bits || new_entry.dst.bits.wraps_around())) || - (e.bits.op2.valid && new_entry.dst.bits.start <= e.bits.op2.bits && (new_entry.dst.bits.end() > e.bits.op2.bits || new_entry.dst.bits.wraps_around()))) /* || - (e.bits.op3.valid && new_entry.dst.bits.start <= e.bits.op3.bits && new_entry.dst.bits.end() > e.bits.op3.bits)) */ + (e.bits.op1.valid && e.bits.op1.bits.overlaps(new_entry.dst.bits)) || + (e.bits.op2.valid && e.bits.op2.bits.overlaps(new_entry.dst.bits))) } - val waws = entries.map { e => - def is_accumulative(laddr: LocalAddr): Bool = laddr.is_acc_addr && laddr.accumulate + val wars_op1 = entries.map { e => + // We search for all entries which read from an address that we write to + e.valid && new_entry.dst.valid && e.bits.q =/= new_entry.q && ( + e.bits.op1.bits.overlaps(new_entry.dst.bits)) + }
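// Every hazard check in this block (raws/wars above, wars_op2/waws below)
// reduces to OpT.overlaps, an overlap test on half-open address ranges
// [start, end) in which wraps_around marks a range whose end wrapped past
// the top of the local address space. A plain-integer restatement of the
// same predicate, for illustration only (the real method also filters out
// garbage addresses):
//   def overlaps(aStart: Int, aEnd: Int, aWraps: Boolean,
//                bStart: Int, bEnd: Int, bWraps: Boolean): Boolean =
//     (bStart <= aStart && (aStart < bEnd || bWraps)) ||
//     (aStart <= bStart && (bStart < aEnd || aWraps))
// A wrapped range effectively extends from its start to the top of memory,
// which is why the end-of-range comparison is skipped when the other side
// wraps around.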
+ val wars_op2 = entries.map { e => + // We search for all entries which read from an address that we write to + e.valid && new_entry.dst.valid && e.bits.q =/= new_entry.q && ( + e.bits.op2.bits.overlaps(new_entry.dst.bits)) + } + + // TODO we should check whether op1 and op2 are valid here + val waws = entries.map { e => // We search for all entries which write to an address that we write to e.valid && new_entry.dst.valid && e.bits.dst.valid && e.bits.q =/= new_entry.q && - !(is_accumulative(new_entry.dst.bits.start) && is_accumulative(e.bits.dst.bits.start)) && - ((new_entry.dst.bits.start <= e.bits.dst.bits.start && (new_entry.dst.bits.end() > e.bits.dst.bits.start || new_entry.dst.bits.wraps_around())) || - (e.bits.dst.bits.start <= new_entry.dst.bits.start && (e.bits.dst.bits.end() > new_entry.dst.bits.start || e.bits.dst.bits.wraps_around()))) + (new_entry.dst.bits.overlaps(e.bits.dst.bits) || e.bits.dst.bits.overlaps(new_entry.dst.bits)) } val older_in_same_q = entries.map { e => @@ -171,17 +277,38 @@ class ROB(cmd_t: RoCCCommand, nEntries: Int, local_addr_t: LocalAddr, block_rows new_entry.deps := (Cat(raws) | Cat(wars) | Cat(waws) | Cat(older_in_same_q) | Cat(is_st_and_must_wait_for_prior_ex_config) | Cat(is_ex_config_and_must_wait_for_prior_st)).asBools().reverse + raws_probe := Cat(raws.reverse) + waws_probe := Cat(waws.reverse) + wars_probe := Cat(wars.reverse) + wars_op1_probe := Cat(wars_op1.reverse) + wars_op2_probe := Cat(wars_op2.reverse) + raws_op1_probe := Cat(raws_op1.reverse) + raws_op2_probe := Cat(raws_op2.reverse) + older_in_same_q_probe := Cat(older_in_same_q.reverse) + is_st_and_must_wait_for_prior_ex_config_probe := Cat(is_st_and_must_wait_for_prior_ex_config.reverse) + is_ex_config_and_must_wait_for_prior_st_probe := Cat(is_ex_config_and_must_wait_for_prior_st.reverse) + + new_entry.allocated_at := instructions_allocated + new_entry.complete_on_issue := new_entry.is_config && new_entry.q =/= exq entries(new_entry_id).valid := true.B entries(new_entry_id).bits := new_entry last_allocated := new_entry_id + + when (new_entry.is_config && new_entry.q === exq && !is_im2col) { + a_stride := new_entry.cmd.rs1(31, 16) // TODO magic numbers // TODO this needs to be kept in sync with ExecuteController.scala + }.elsewhen(new_entry.is_config && new_entry.q === ldq) { + val id = new_entry.cmd.rs1(4,3) // TODO magic numbers + val block_stride = new_entry.cmd.rs1(31, 16) // TODO magic numbers + ld_block_strides(id) := block_stride + } } // Issue commands which are ready to be issued Seq((ldq, io.issue.ld), (stq, io.issue.st), (exq, io.issue.ex)).foreach { case (q, io) => - val issue_id = MuxCase((nEntries-1).U, entries.zipWithIndex.map { case (e, i) => + val issue_id = MuxCase((rob_entries-1).U, entries.zipWithIndex.map { case (e, i) => (e.valid && e.bits.ready() && !e.bits.issued && e.bits.q === q) -> i.U }) @@ -227,7 +354,14 @@ class ROB(cmd_t: RoCCCommand, nEntries: Int, local_addr_t: LocalAddr, block_rows io.st_utilization := utilization_st_q io.ex_utilization := utilization_ex_q - val packed_deps = VecInit(entries.map(e => Cat(e.bits.deps))) + val valids = VecInit(entries.map(_.valid)) + val functs = VecInit(entries.map(_.bits.cmd.inst.funct)) + val issueds = VecInit(entries.map(_.bits.issued)) + val packed_deps = VecInit(entries.map(e => Cat(e.bits.deps.reverse))) + + dontTouch(valids) + dontTouch(functs) + dontTouch(issueds) dontTouch(packed_deps) val pop_count_packed_deps = VecInit(entries.map(e => Mux(e.valid, PopCount(e.bits.deps), 0.U))) @@ -238,13 +372,16 @@ class ROB(cmd_t: RoCCCommand, nEntries: Int, 
local_addr_t: LocalAddr, block_rows val cycles_since_issue = RegInit(0.U(16.W)) - when (io.issue.ld.fire() || io.issue.st.fire() || io.issue.ex.fire() || !io.busy) { + when (io.issue.ld.fire() || io.issue.st.fire() || io.issue.ex.fire() || !io.busy || io.completed.fire()) { cycles_since_issue := 0.U }.elsewhen(io.busy) { cycles_since_issue := cycles_since_issue + 1.U } assert(cycles_since_issue < 10000.U, "pipeline stall") + for (e <- entries) { + dontTouch(e.bits.allocated_at) + } val cntr = Counter(10000000) when (cntr.inc()) { diff --git a/src/main/scala/gemmini/Scratchpad.scala b/src/main/scala/gemmini/Scratchpad.scala index 39596cc3..f4812cc6 100644 --- a/src/main/scala/gemmini/Scratchpad.scala +++ b/src/main/scala/gemmini/Scratchpad.scala @@ -15,15 +15,13 @@ class ScratchpadMemReadRequest[U <: Data](local_addr_t: LocalAddr, scale_t_bits: val vaddr = UInt(coreMaxAddrBits.W) val laddr = local_addr_t.cloneType - val len = UInt(16.W) // TODO don't use a magic number for the width here + val cols = UInt(16.W) // TODO don't use a magic number for the width here val repeats = UInt(16.W) // TODO don't use a magic number for the width here - val scale = UInt(scale_t_bits.W) - val has_acc_bitwidth = Bool() - + val all_zeros = Bool() + val block_stride = UInt(16.W) // TODO magic numbers val cmd_id = UInt(8.W) // TODO don't use a magic number here - val status = new MStatus override def cloneType: this.type = new ScratchpadMemReadRequest(local_addr_t, scale_t_bits).asInstanceOf[this.type] @@ -35,9 +33,9 @@ class ScratchpadMemWriteRequest(local_addr_t: LocalAddr) val laddr = local_addr_t.cloneType val len = UInt(16.W) // TODO don't use a magic number for the width here + val block = UInt(8.W) // TODO don't use a magic number for the width here val cmd_id = UInt(8.W) // TODO don't use a magic number here - val status = new MStatus // Pooling variables @@ -95,7 +93,7 @@ class ScratchpadWriteIO(val n: Int, val w: Int, val mask_len: Int) extends Bundl val data = Output(UInt(w.W)) } -class ScratchpadBank(n: Int, w: Int, mem_pipeline: Int, aligned_to: Int) extends Module { +class ScratchpadBank(n: Int, w: Int, mem_pipeline: Int, aligned_to: Int, single_ported: Boolean) extends Module { // This is essentially a pipelined SRAM with the ability to stall pipeline stages require(w % aligned_to == 0 || w < aligned_to) @@ -107,9 +105,11 @@ class ScratchpadBank(n: Int, w: Int, mem_pipeline: Int, aligned_to: Int) extends val write = Flipped(new ScratchpadWriteIO(n, w, mask_len)) }) - // val mem = SyncReadMem(n, UInt(w.W)) val mem = SyncReadMem(n, Vec(mask_len, mask_elem)) + // When the scratchpad is single-ported, the writes take precedence + val singleport_busy_with_write = single_ported.B && io.write.en + when (io.write.en) { if (aligned_to >= w) mem.write(io.write.addr, io.write.data.asTypeOf(Vec(mask_len, mask_elem))) @@ -119,7 +119,13 @@ class ScratchpadBank(n: Int, w: Int, mem_pipeline: Int, aligned_to: Int) extends val raddr = io.read.req.bits.addr val ren = io.read.req.fire() - val rdata = mem.read(raddr, ren).asUInt() + val rdata = if (single_ported) { + assert(!(ren && io.write.en)) + mem.read(raddr, ren && !io.write.en).asUInt() + } else { + mem.read(raddr, ren).asUInt() + } + val fromDMA = io.read.req.bits.fromDMA // Make a queue which buffers the result of an SRAM read if it can't immediately be consumed @@ -129,7 +135,7 @@ class ScratchpadBank(n: Int, w: Int, mem_pipeline: Int, aligned_to: Int) extends q.io.enq.bits.fromDMA := RegNext(fromDMA) val q_will_be_empty = (q.io.count +& q.io.enq.fire()) 
- q.io.deq.fire() === 0.U - io.read.req.ready := q_will_be_empty + io.read.req.ready := q_will_be_empty && !singleport_busy_with_write // Build the rest of the resp pipeline val rdata_p = Pipeline(q.io.deq, mem_pipeline) @@ -183,9 +189,16 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // Accumulator ports val acc = new Bundle { - val read = Flipped(Vec(acc_banks, new AccumulatorReadIO(acc_bank_entries, log2Up(accType.getWidth), Vec(meshColumns, Vec(tileColumns, inputType)), Vec(meshColumns, Vec(tileColumns, accType)), acc_scale_args.multiplicand_t))) - // val write = Flipped(Vec(acc_banks, new AccumulatorWriteReq(acc_bank_entries, Vec(meshColumns, Vec(tileColumns, accType))))) - val write = Flipped(Vec(acc_banks, Decoupled(new AccumulatorWriteReq(acc_bank_entries, Vec(meshColumns, Vec(tileColumns, accType)))))) + val read_req = Flipped(Vec(acc_banks, Decoupled(new AccumulatorReadReq( + acc_bank_entries, log2Up(accType.getWidth), acc_scale_args.multiplicand_t + )))) + val read_resp = Vec(acc_banks, Decoupled(new AccumulatorScaleResp( + Vec(meshColumns, Vec(tileColumns, inputType)), + Vec(meshColumns, Vec(tileColumns, accType)) + ))) + val write = Flipped(Vec(acc_banks, Decoupled(new AccumulatorWriteReq( + acc_bank_entries, Vec(meshColumns, Vec(tileColumns, accType)) + )))) } // TLB ports @@ -197,17 +210,34 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, }) val write_dispatch_q = Queue(io.dma.write.req) - write_dispatch_q.ready := false.B - + // The write scale queue is necessary to maintain in-order requests to the accumulator scale unit. + // Writes from the main SPAD flow directly between scale_q and issue_q, while writes + // from the accumulator are kept in order + val write_scale_q = Module(new Queue(new ScratchpadMemWriteRequest(local_addr_t), mem_pipeline)) val write_issue_q = Module(new Queue(new ScratchpadMemWriteRequest(local_addr_t), mem_pipeline+1, pipe=true)) val read_issue_q = Module(new Queue(new ScratchpadMemReadRequest(local_addr_t, mvin_scale_t_bits), mem_pipeline+1, pipe=true)) // TODO can't this just be a normal queue?
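// ---------------------------------------------------------------------------
// A minimal, self-contained sketch (illustrative only, not part of this
// patch) of the wiring idiom used for write_scale_q above and in the bypass
// `when` blocks just below: the queues are first given safe default
// connections, and a conditional bulk `<>` then lets selected requests skip
// the intermediate stage, relying on Chisel's last-connect semantics. All
// names and the bypass condition below are hypothetical.
import chisel3._
import chisel3.util._

class ConditionalBypass extends Module {
  val io = IO(new Bundle {
    val in  = Flipped(Decoupled(UInt(8.W)))
    val out = Decoupled(UInt(8.W))
  })
  val stage = Module(new Queue(UInt(8.W), 2))

  // Defaults: everything flows through the intermediate stage
  stage.io.enq <> io.in
  io.out <> stage.io.deq

  // Override: selected beats skip the stage entirely (last connect wins).
  // The real patch only bypasses garbage writes, which carry no data, so
  // reordering against beats still sitting in the stage is harmless there.
  when (io.in.valid && io.in.bits < 16.U) {
    stage.io.enq.valid := false.B
    stage.io.deq.ready := false.B
    io.out <> io.in
  }
}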
+ write_scale_q.io.enq.valid := false.B + write_scale_q.io.enq.bits := write_dispatch_q.bits + write_scale_q.io.deq.ready := false.B + write_issue_q.io.enq.valid := false.B - write_issue_q.io.enq.bits := write_dispatch_q.bits + write_issue_q.io.enq.bits := write_scale_q.io.deq.bits + + + // Garbage can immediately fire between dispatch_q and scale_q + when (write_dispatch_q.bits.laddr.is_garbage()) { + write_scale_q.io.enq <> write_dispatch_q + } + // Non-acc or garbage can immediately fire between scale_q and issue_q + when (write_scale_q.io.deq.bits.laddr.is_garbage() || !write_scale_q.io.deq.bits.laddr.is_acc_addr) { + write_issue_q.io.enq <> write_scale_q.io.deq + } + val writeData = Wire(Valid(UInt((spad_w max acc_w).W))) - writeData.valid := false.B + writeData.valid := write_issue_q.io.deq.bits.laddr.is_garbage() writeData.bits := DontCare val fullAccWriteData = Wire(UInt(acc_w.W)) fullAccWriteData := DontCare @@ -215,8 +245,8 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, write_issue_q.io.deq.bits.laddr.is_acc_addr && write_issue_q.io.deq.bits.laddr.read_full_acc_row val writeData_is_all_zeros = write_issue_q.io.deq.bits.laddr.is_garbage() - writer.module.io.req.valid := write_issue_q.io.deq.valid && (writeData.valid || writeData_is_all_zeros) - write_issue_q.io.deq.ready := writer.module.io.req.ready && (writeData.valid || writeData_is_all_zeros) + writer.module.io.req.valid := write_issue_q.io.deq.valid && writeData.valid + write_issue_q.io.deq.ready := writer.module.io.req.ready && writeData.valid writer.module.io.req.bits.vaddr := write_issue_q.io.deq.bits.vaddr writer.module.io.req.bits.len := Mux(writeData_is_full_width, write_issue_q.io.deq.bits.len * (accType.getWidth / 8).U, @@ -225,43 +255,68 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, writeData_is_all_zeros -> 0.U, writeData_is_full_width -> fullAccWriteData )) + writer.module.io.req.bits.block := write_issue_q.io.deq.bits.block writer.module.io.req.bits.status := write_issue_q.io.deq.bits.status writer.module.io.req.bits.pool_en := write_issue_q.io.deq.bits.pool_en writer.module.io.req.bits.store_en := write_issue_q.io.deq.bits.store_en - // FpgaDebug(write_issue_q.io.deq.bits.laddr.data) - // FpgaDebug(write_issue_q.io.deq.bits.laddr.accumulate) - // FpgaDebug(write_issue_q.io.deq.bits.laddr.is_acc_addr) - io.dma.write.resp.valid := false.B io.dma.write.resp.bits.cmd_id := write_dispatch_q.bits.cmd_id + when (write_dispatch_q.bits.laddr.is_garbage() && write_dispatch_q.fire()) { + io.dma.write.resp.valid := true.B + } read_issue_q.io.enq <> io.dma.read.req + val zero_writer = Module(new ZeroWriter(config, new ScratchpadMemReadRequest(local_addr_t, mvin_scale_t_bits))) + + when (io.dma.read.req.bits.all_zeros) { + read_issue_q.io.enq.valid := false.B + io.dma.read.req.ready := zero_writer.io.req.ready + } + + zero_writer.io.req.valid := io.dma.read.req.valid && io.dma.read.req.bits.all_zeros + zero_writer.io.req.bits.laddr := io.dma.read.req.bits.laddr + zero_writer.io.req.bits.cols := io.dma.read.req.bits.cols + zero_writer.io.req.bits.block_stride := io.dma.read.req.bits.block_stride + zero_writer.io.req.bits.tag := io.dma.read.req.bits + + zero_writer.io.resp.ready := false.B + reader.module.io.req.valid := read_issue_q.io.deq.valid read_issue_q.io.deq.ready := reader.module.io.req.ready reader.module.io.req.bits.vaddr := read_issue_q.io.deq.bits.vaddr reader.module.io.req.bits.spaddr := Mux(read_issue_q.io.deq.bits.laddr.is_acc_addr, 
read_issue_q.io.deq.bits.laddr.full_acc_addr(), read_issue_q.io.deq.bits.laddr.full_sp_addr()) - reader.module.io.req.bits.len := read_issue_q.io.deq.bits.len + reader.module.io.req.bits.len := read_issue_q.io.deq.bits.cols reader.module.io.req.bits.repeats := read_issue_q.io.deq.bits.repeats reader.module.io.req.bits.scale := read_issue_q.io.deq.bits.scale reader.module.io.req.bits.is_acc := read_issue_q.io.deq.bits.laddr.is_acc_addr reader.module.io.req.bits.accumulate := read_issue_q.io.deq.bits.laddr.accumulate reader.module.io.req.bits.has_acc_bitwidth := read_issue_q.io.deq.bits.has_acc_bitwidth + reader.module.io.req.bits.block_stride := read_issue_q.io.deq.bits.block_stride reader.module.io.req.bits.status := read_issue_q.io.deq.bits.status reader.module.io.req.bits.cmd_id := read_issue_q.io.deq.bits.cmd_id - val (mvin_scale_in, mvin_scale_out) = VectorScalarMultiplier(config.mvin_scale_args, config.inputType, config.meshColumns * config.tileColumns, chiselTypeOf(reader.module.io.resp.bits), is_acc = false) - val (mvin_scale_acc_in, mvin_scale_acc_out) = if (mvin_scale_shared) (mvin_scale_in, mvin_scale_out) else - VectorScalarMultiplier(config.mvin_scale_acc_args, config.accType, config.meshColumns * config.tileColumns, chiselTypeOf(reader.module.io.resp.bits), is_acc = true) + val (mvin_scale_in, mvin_scale_out) = VectorScalarMultiplier( + config.mvin_scale_args, + config.inputType, config.meshColumns * config.tileColumns, chiselTypeOf(reader.module.io.resp.bits), + is_acc = false + ) + val (mvin_scale_acc_in, mvin_scale_acc_out) = if (mvin_scale_shared) (mvin_scale_in, mvin_scale_out) else ( + VectorScalarMultiplier( + config.mvin_scale_acc_args, + config.accType, config.meshColumns * config.tileColumns, chiselTypeOf(reader.module.io.resp.bits), + is_acc = true + ) + ) mvin_scale_in.valid := reader.module.io.resp.valid && (mvin_scale_shared.B || !reader.module.io.resp.bits.is_acc || (reader.module.io.resp.bits.is_acc && !reader.module.io.resp.bits.has_acc_bitwidth)) mvin_scale_in.bits.in := reader.module.io.resp.bits.data.asTypeOf(chiselTypeOf(mvin_scale_in.bits.in)) mvin_scale_in.bits.scale := reader.module.io.resp.bits.scale.asTypeOf(mvin_scale_t) - mvin_scale_in.bits.repeats := reader.module.io.resp.bits.rows + mvin_scale_in.bits.repeats := reader.module.io.resp.bits.repeats mvin_scale_in.bits.last := reader.module.io.resp.bits.last mvin_scale_in.bits.tag := reader.module.io.resp.bits @@ -272,7 +327,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, (reader.module.io.resp.bits.is_acc && reader.module.io.resp.bits.has_acc_bitwidth) mvin_scale_acc_in.bits.in := reader.module.io.resp.bits.data.asTypeOf(chiselTypeOf(mvin_scale_acc_in.bits.in)) mvin_scale_acc_in.bits.scale := reader.module.io.resp.bits.scale.asTypeOf(mvin_scale_acc_t) - mvin_scale_acc_in.bits.repeats := reader.module.io.resp.bits.rows + mvin_scale_acc_in.bits.repeats := reader.module.io.resp.bits.repeats mvin_scale_acc_in.bits.last := reader.module.io.resp.bits.last mvin_scale_acc_in.bits.tag := reader.module.io.resp.bits @@ -284,9 +339,22 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, val mvin_scale_finished = mvin_scale_out.fire() && mvin_scale_out.bits.last val mvin_scale_acc_finished = mvin_scale_acc_out.fire() && mvin_scale_acc_out.bits.last - io.dma.read.resp.valid := mvin_scale_finished || mvin_scale_acc_finished - io.dma.read.resp.bits.cmd_id := Mux(mvin_scale_finished, mvin_scale_out.bits.tag.cmd_id, mvin_scale_acc_out.bits.tag.cmd_id) 
- io.dma.read.resp.bits.bytesRead := Mux(mvin_scale_finished, mvin_scale_out.bits.tag.bytes_read, mvin_scale_acc_out.bits.tag.bytes_read) + val zero_writer_finished = zero_writer.io.resp.fire() && zero_writer.io.resp.bits.last + + val zero_writer_bytes_read = Mux(zero_writer.io.resp.bits.laddr.is_acc_addr, + zero_writer.io.resp.bits.tag.cols * (accType.getWidth / 8).U, + zero_writer.io.resp.bits.tag.cols * (inputType.getWidth / 8).U) + + // For DMA read responses, mvin_scale gets first priority, then mvin_scale_acc, and then zero_writer + io.dma.read.resp.valid := mvin_scale_finished || mvin_scale_acc_finished || zero_writer_finished + + io.dma.read.resp.bits.cmd_id := MuxCase(zero_writer.io.resp.bits.tag.cmd_id, Seq( + mvin_scale_finished -> mvin_scale_out.bits.tag.cmd_id, + mvin_scale_acc_finished -> mvin_scale_acc_out.bits.tag.cmd_id)) + + io.dma.read.resp.bits.bytesRead := MuxCase(zero_writer_bytes_read, Seq( + mvin_scale_finished -> mvin_scale_out.bits.tag.bytes_read, + mvin_scale_acc_finished -> mvin_scale_acc_out.bits.tag.bytes_read)) io.tlb(0) <> writer.module.io.tlb io.tlb(1) <> reader.module.io.tlb @@ -294,10 +362,10 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, writer.module.io.flush := io.flush reader.module.io.flush := io.flush - io.busy := writer.module.io.busy || reader.module.io.busy || write_issue_q.io.deq.valid + io.busy := writer.module.io.busy || reader.module.io.busy || write_issue_q.io.deq.valid || write_scale_q.io.deq.valid || write_dispatch_q.valid { - val banks = Seq.fill(sp_banks) { Module(new ScratchpadBank(sp_bank_entries, spad_w, mem_pipeline, aligned_to)) } + val banks = Seq.fill(sp_banks) { Module(new ScratchpadBank(sp_bank_entries, spad_w, mem_pipeline, aligned_to, config.sp_singleported)) } val bank_ios = VecInit(banks.map(_.io)) // Getting the output of the bank that's about to be issued to the writer @@ -314,10 +382,12 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, val exread = ex_read_req.valid // TODO we tie the write dispatch queue's, and write issue queue's, ready and valid signals together here - val dmawrite = write_dispatch_q.valid && write_issue_q.io.enq.ready && + val dmawrite = write_dispatch_q.valid && write_scale_q.io.enq.ready && + !write_dispatch_q.bits.laddr.is_garbage() && + !(bio.write.en && config.sp_singleported.B) && !write_dispatch_q.bits.laddr.is_acc_addr && write_dispatch_q.bits.laddr.sp_bank() === i.U - bio.read.req.valid := exread || (dmawrite && !write_dispatch_q.bits.laddr.is_garbage()) + bio.read.req.valid := exread || dmawrite ex_read_req.ready := bio.read.req.ready // The ExecuteController gets priority when reading from SRAMs @@ -328,9 +398,9 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, bio.read.req.bits.addr := write_dispatch_q.bits.laddr.sp_row() bio.read.req.bits.fromDMA := true.B - when (bio.read.req.fire() || write_dispatch_q.bits.laddr.is_garbage()) { + when (bio.read.req.fire()) { write_dispatch_q.ready := true.B - write_issue_q.io.enq.valid := true.B + write_scale_q.io.enq.valid := true.B io.dma.write.resp.valid := true.B } @@ -357,7 +427,13 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, val dmaread = mvin_scale_out.valid && !mvin_scale_out.bits.tag.is_acc && laddr.sp_bank() === i.U - bio.write.en := exwrite || dmaread + // We need to make sure that we don't try to return a dma read resp from both zero_writer and either mvin_scale + // or mvin_scale_acc at the same time.
The scalers always get priority in those cases + val zerowrite = zero_writer.io.resp.valid && !zero_writer.io.resp.bits.laddr.is_acc_addr && + zero_writer.io.resp.bits.laddr.sp_bank() === i.U && + !((mvin_scale_out.valid && mvin_scale_out.bits.last) || (mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.last)) + + bio.write.en := exwrite || dmaread || zerowrite when (exwrite) { bio.write.addr := io.srams.write(i).addr @@ -369,6 +445,17 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, bio.write.mask := mvin_scale_out.bits.tag.mask take ((spad_w / (aligned_to * 8)) max 1) mvin_scale_out.ready := true.B // TODO we combinationally couple valid and ready signals + }.elsewhen (zerowrite) { + bio.write.addr := zero_writer.io.resp.bits.laddr.sp_row() + bio.write.data := 0.U + bio.write.mask := { + val n = inputType.getWidth / 8 + val mask = zero_writer.io.resp.bits.mask + val expanded = VecInit(mask.flatMap(e => Seq.fill(n)(e))) + expanded + } + + zero_writer.io.resp.ready := true.B // TODO we combinationally couple valid and ready signals }.otherwise { bio.write.addr := DontCare bio.write.data := DontCare @@ -377,32 +464,64 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, } } + val acc_row_t = Vec(meshColumns, Vec(tileColumns, accType)) + val spad_row_t = Vec(meshColumns, Vec(tileColumns, inputType)) + + val acc_scale_unit = Module(new AccumulatorScale( + acc_row_t, + spad_row_t, + acc_scale_args.multiplicand_t, + log2Up(accType.getWidth), + acc_read_small_width, + acc_read_full_width, + acc_scale_args + )) + + acc_scale_unit.io.in.valid := false.B + acc_scale_unit.io.in.bits := DontCare + val dma_resp_ready = ( + writer.module.io.req.ready && + write_issue_q.io.deq.bits.laddr.is_acc_addr && + !write_issue_q.io.deq.bits.laddr.is_garbage() + ) + acc_scale_unit.io.out.ready := false.B + when (acc_scale_unit.io.out.bits.fromDMA && dma_resp_ready) { + acc_scale_unit.io.out.ready := true.B + writeData.valid := acc_scale_unit.io.out.valid + writeData.bits := acc_scale_unit.io.out.bits.data.asUInt + fullAccWriteData := acc_scale_unit.io.out.bits.full_data.asUInt + } + for (i <- 0 until acc_banks) { + io.acc.read_resp(i).valid := false.B + io.acc.read_resp(i).bits := acc_scale_unit.io.out.bits + when (!acc_scale_unit.io.out.bits.fromDMA && acc_scale_unit.io.out.bits.acc_bank_id === i.U) { + acc_scale_unit.io.out.ready := io.acc.read_resp(i).ready + io.acc.read_resp(i).valid := acc_scale_unit.io.out.valid + } + } + { - val acc_row_t = Vec(meshColumns, Vec(tileColumns, accType)) - val spad_row_t = Vec(meshColumns, Vec(tileColumns, inputType)) - val banks = Seq.fill(acc_banks) { Module(new AccumulatorMem(acc_bank_entries, acc_row_t, spad_row_t, mem_pipeline, acc_scale_args, acc_read_small_width, acc_read_full_width)) } + val banks = Seq.fill(acc_banks) { Module(new AccumulatorMem( + acc_bank_entries, acc_row_t, acc_scale_args, + acc_singleported, num_acc_sub_banks + )) } val bank_ios = VecInit(banks.map(_.io)) // Getting the output of the bank that's about to be issued to the writer val bank_issued_io = bank_ios(write_issue_q.io.deq.bits.laddr.acc_bank()) - when (write_issue_q.io.deq.bits.laddr.is_acc_addr) { - writeData.valid := bank_issued_io.read.resp.valid && bank_issued_io.read.resp.bits.fromDMA - writeData.bits := bank_issued_io.read.resp.bits.data.asUInt() - fullAccWriteData := bank_issued_io.read.resp.bits.full_data.asUInt() - } - // Reading from the Accumulator banks bank_ios.zipWithIndex.foreach { case (bio, i) => - val 
ex_read_req = io.acc.read(i).req + val ex_read_req = io.acc.read_req(i) val exread = ex_read_req.valid // TODO we tie the write dispatch queue's, and write issue queue's, ready and valid signals together here - val dmawrite = write_dispatch_q.valid && write_issue_q.io.enq.ready && + val dmawrite = write_dispatch_q.valid && write_scale_q.io.enq.ready && + !write_dispatch_q.bits.laddr.is_garbage() && write_dispatch_q.bits.laddr.is_acc_addr && write_dispatch_q.bits.laddr.acc_bank() === i.U - bio.read.req.valid := exread || (dmawrite && !write_dispatch_q.bits.laddr.is_garbage()) + bio.read.req.valid := exread || dmawrite bio.read.req.bits.scale := ex_read_req.bits.scale bio.read.req.bits.relu6_shift := ex_read_req.bits.relu6_shift bio.read.req.bits.act := ex_read_req.bits.act @@ -418,28 +537,41 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, bio.read.req.bits.full := write_dispatch_q.bits.laddr.read_full_acc_row bio.read.req.bits.fromDMA := true.B - when (bio.read.req.fire() || write_dispatch_q.bits.laddr.is_garbage()) { + when (bio.read.req.fire()) { write_dispatch_q.ready := true.B - write_issue_q.io.enq.valid := true.B + write_scale_q.io.enq.valid := true.B io.dma.write.resp.valid := true.B } }.otherwise { bio.read.req.bits := DontCare } + bio.read.resp.ready := false.B + + + when (write_scale_q.io.deq.valid && + acc_scale_unit.io.in.ready && + bio.read.resp.valid && + write_issue_q.io.enq.ready && + write_scale_q.io.deq.bits.laddr.is_acc_addr && + !write_scale_q.io.deq.bits.laddr.is_garbage() && + write_scale_q.io.deq.bits.laddr.acc_bank() === i.U) { + write_scale_q.io.deq.ready := true.B + acc_scale_unit.io.in.valid := true.B + bio.read.resp.ready := true.B + write_issue_q.io.enq.valid := true.B + + acc_scale_unit.io.in.bits := bio.read.resp.bits + acc_scale_unit.io.in.bits.acc_bank_id := i.U + } - val ex_read_resp = io.acc.read(i).resp - val dma_resp_ready = writer.module.io.req.ready && - write_issue_q.io.deq.bits.laddr.is_acc_addr && write_issue_q.io.deq.bits.laddr.acc_bank() === i.U && // I believe we don't need to check that write_issue_q is valid here, because if the accumulator bank's resp is valid, then that means that the write_issue_q's deq should also be valid - !write_issue_q.io.deq.bits.laddr.is_garbage() - - bio.read.resp.ready := Mux(bio.read.resp.bits.fromDMA, dma_resp_ready, ex_read_resp.ready) - ex_read_resp.valid := bio.read.resp.valid // TODO should we AND this with fromDMA? 
- ex_read_resp.bits := bio.read.resp.bits } // Writing to the accumulator banks bank_ios.zipWithIndex.foreach { case (bio, i) => + // The order of precedence during writes is ExecuteController, then mvin_scale, then mvin_scale_acc, and + // then zero_writer + val exwrite = io.acc.write(i).valid io.acc.write(i).ready := true.B assert(!(exwrite && !bio.write.ready), "Execute controller write to AccumulatorMem was skipped") @@ -447,48 +579,85 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, val from_mvin_scale = mvin_scale_out.valid && mvin_scale_out.bits.tag.is_acc val from_mvin_scale_acc = mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.tag.is_acc - val mvin_scale_acc_laddr = mvin_scale_acc_out.bits.tag.addr.asTypeOf(local_addr_t) + mvin_scale_acc_out.bits.row val mvin_scale_laddr = mvin_scale_out.bits.tag.addr.asTypeOf(local_addr_t) + mvin_scale_out.bits.row + val mvin_scale_acc_laddr = mvin_scale_acc_out.bits.tag.addr.asTypeOf(local_addr_t) + mvin_scale_acc_out.bits.row - val dmaread_bank = Mux(from_mvin_scale_acc, mvin_scale_acc_laddr.acc_bank(), - mvin_scale_laddr.acc_bank()) - val dmaread_row = Mux(from_mvin_scale_acc, mvin_scale_acc_laddr.acc_row(), mvin_scale_laddr.acc_row()) + val dmaread_bank = Mux(from_mvin_scale, mvin_scale_laddr.acc_bank(), + mvin_scale_acc_laddr.acc_bank()) + val dmaread_row = Mux(from_mvin_scale, mvin_scale_laddr.acc_row(), mvin_scale_acc_laddr.acc_row()) // We need to make sure that we don't try to return a dma read resp from both mvin_scale and mvin_scale_acc // at the same time. mvin_scale always gets priority in these cases - val mvin_scale_out_last = mvin_scale_out.valid && mvin_scale_out.bits.last + val spad_last = mvin_scale_out.valid && mvin_scale_out.bits.last && !mvin_scale_out.bits.tag.is_acc val dmaread = (from_mvin_scale || from_mvin_scale_acc) && - dmaread_bank === i.U && - (mvin_scale_same.B || from_mvin_scale || !mvin_scale_out_last) + dmaread_bank === i.U /* && + (mvin_scale_same.B || from_mvin_scale || !spad_dmaread_last) */ + + // We need to make sure that we don't try to return a dma read resp from both zero_writer and either mvin_scale + // or mvin_scale_acc at the same time.
The scalers always get priority in those cases + val zerowrite = zero_writer.io.resp.valid && zero_writer.io.resp.bits.laddr.is_acc_addr && + zero_writer.io.resp.bits.laddr.acc_bank() === i.U && + !((mvin_scale_out.valid && mvin_scale_out.bits.last) || (mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.last)) + val consecutive_write_block = RegInit(false.B) + if (acc_singleported) { + val consecutive_write_sub_bank = RegInit(0.U((1 max log2Ceil(num_acc_sub_banks)).W)) + when (bio.write.fire() && bio.write.bits.acc && + (bio.write.bits.addr(log2Ceil(num_acc_sub_banks)-1,0) === consecutive_write_sub_bank)) { + consecutive_write_block := true.B + } .elsewhen (bio.write.fire() && bio.write.bits.acc) { + consecutive_write_block := false.B + consecutive_write_sub_bank := bio.write.bits.addr(log2Ceil(num_acc_sub_banks)-1,0) + } .otherwise { + consecutive_write_block := false.B + } + } + bio.write.valid := false.B - bio.write.valid := exwrite || dmaread - bio.write.bits.acc := Mux(exwrite, io.acc.write(i).bits.acc, - Mux(from_mvin_scale_acc, mvin_scale_acc_out.bits.tag.accumulate, mvin_scale_out.bits.tag.accumulate)) - bio.write.bits.addr := Mux(exwrite, io.acc.write(i).bits.addr, dmaread_row) + bio.write.bits.acc := MuxCase(zero_writer.io.resp.bits.laddr.accumulate, + Seq(exwrite -> io.acc.write(i).bits.acc, + from_mvin_scale -> mvin_scale_out.bits.tag.accumulate, + from_mvin_scale_acc -> mvin_scale_acc_out.bits.tag.accumulate)) + + bio.write.bits.addr := MuxCase(zero_writer.io.resp.bits.laddr.acc_row(), + Seq(exwrite -> io.acc.write(i).bits.addr, + (from_mvin_scale || from_mvin_scale_acc) -> dmaread_row)) when (exwrite) { + bio.write.valid := true.B bio.write.bits.data := io.acc.write(i).bits.data bio.write.bits.mask := io.acc.write(i).bits.mask - }.elsewhen (dmaread && bio.write.fire()) { - bio.write.bits.data := Mux(from_mvin_scale_acc, - mvin_scale_acc_out.bits.out.asTypeOf(acc_row_t), - VecInit(mvin_scale_out.bits.out.map(e => e.withWidthOf(accType))).asTypeOf(acc_row_t)) + }.elsewhen (dmaread && !spad_last && !consecutive_write_block) { + bio.write.valid := true.B + bio.write.bits.data := Mux(from_mvin_scale, + VecInit(mvin_scale_out.bits.out.map(e => e.withWidthOf(accType))).asTypeOf(acc_row_t), + mvin_scale_acc_out.bits.out.asTypeOf(acc_row_t)) bio.write.bits.mask := - Mux(from_mvin_scale_acc, - mvin_scale_acc_out.bits.tag.mask, + Mux(from_mvin_scale, { val n = accType.getWidth / inputType.getWidth val mask = mvin_scale_out.bits.tag.mask take ((spad_w / (aligned_to * 8)) max 1) val expanded = VecInit(mask.flatMap(e => Seq.fill(n)(e))) expanded - }) + }, + mvin_scale_acc_out.bits.tag.mask) - when (from_mvin_scale_acc) { - mvin_scale_acc_out.ready := true.B + when(from_mvin_scale) { + mvin_scale_out.ready := bio.write.ready }.otherwise { - mvin_scale_out.ready := true.B + mvin_scale_acc_out.ready := bio.write.ready + } + }.elsewhen (zerowrite && !spad_last && !consecutive_write_block) { + bio.write.valid := true.B + bio.write.bits.data := 0.U.asTypeOf(acc_row_t) + bio.write.bits.mask := { + val n = accType.getWidth / 8 + val mask = zero_writer.io.resp.bits.mask + val expanded = VecInit(mask.flatMap(e => Seq.fill(n)(e))) + expanded } + + zero_writer.io.resp.ready := bio.write.ready }.otherwise { bio.write.bits.data := DontCare bio.write.bits.mask := DontCare diff --git a/src/main/scala/gemmini/StoreController.scala b/src/main/scala/gemmini/StoreController.scala index 07399a16..98584bca 100644 --- a/src/main/scala/gemmini/StoreController.scala +++ 
b/src/main/scala/gemmini/StoreController.scala @@ -35,8 +35,13 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm val stride = Reg(UInt(coreMaxAddrBits.W)) val block_rows = meshRows * tileRows + val block_stride = block_rows.U + val block_cols = meshColumns * tileColumns + val max_blocks = (dma_maxbytes / (block_cols * inputType.getWidth / 8)) max 1 + //val row_counter = RegInit(0.U(log2Ceil(block_rows).W)) - val row_counter = RegInit(0.U(12.W)) + val row_counter = RegInit(0.U(12.W)) // TODO magic number + val block_counter = RegInit(0.U(8.W)) // TODO magic number // Pooling variables val pool_stride = Reg(UInt(2.W)) // When this is 0, pooling is disabled // TODO magic number @@ -69,8 +74,9 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm val cmd = Queue(io.cmd, st_queue_length) val vaddr = cmd.bits.cmd.rs1 val localaddr = cmd.bits.cmd.rs2.asTypeOf(local_addr_t) - val cols = cmd.bits.cmd.rs2(32 + mvout_len_bits - 1, 32) // TODO magic numbers + val cols = cmd.bits.cmd.rs2(32 + mvout_cols_bits - 1, 32) // TODO magic numbers val rows = cmd.bits.cmd.rs2(48 + mvout_rows_bits - 1, 48) // TODO magic numbers + val blocks = (cols / block_cols.U) + (cols % block_cols.U =/= 0.U) val config_stride = cmd.bits.cmd.rs2 val config_pool_stride = cmd.bits.cmd.rs1(5, 4) // TODO magic numbers val config_pool_size = cmd.bits.cmd.rs1(7, 6) // TODO magic numbers @@ -84,7 +90,8 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm val mstatus = cmd.bits.cmd.status - val localaddr_plus_row_counter = localaddr + row_counter + val current_vaddr = vaddr + row_counter * stride + val current_localaddr = localaddr + (block_counter * block_stride + row_counter) val pool_row_addr = localaddr + (orow * pool_ocols +& ocol) when (orow_is_negative || ocol_is_negative || orow >= pool_orows || ocol >= pool_ocols) { @@ -106,31 +113,32 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm val rob_id = UInt(log2Up(rob_entries).W) } - val cmd_tracker_max_rows = (block_rows max + val cmd_tracker_max_rows = ((block_rows * max_blocks) max (((1 << pool_orows.getWidth)-1) * ((1 << pool_ocols.getWidth)-1) + 2*((1 << pool_lpad.getWidth)-1) + 2*((1 << pool_upad.getWidth)-1))) min ((config.sp_banks * config.sp_bank_entries) max (config.acc_banks * config.acc_bank_entries)) - val cmd_tracker = Module(new DMAReadCommandTracker(nCmds, cmd_tracker_max_rows, deps_t)) + val cmd_tracker = Module(new DMACommandTracker(nCmds, cmd_tracker_max_rows, deps_t)) // DMA IO wiring io.dma.req.valid := (control_state === waiting_for_command && cmd.valid && DoStore && cmd_tracker.io.alloc.ready) || control_state === waiting_for_dma_req_ready || - (control_state === sending_rows && row_counter =/= 0.U) || // TODO Do we really have to check whether the counters should be 0 here? + (control_state === sending_rows && (block_counter =/= 0.U || row_counter =/= 0.U)) || (control_state === pooling && (wcol_counter =/= 0.U || wrow_counter =/= 0.U || pocol_counter =/= 0.U || porow_counter =/= 0.U)) - io.dma.req.bits.vaddr := Mux(pooling_is_enabled || mvout_1d_enabled, pool_vaddr, vaddr + row_counter * stride) - io.dma.req.bits.laddr := Mux(pooling_is_enabled, pool_row_addr, localaddr_plus_row_counter) //Todo: laddr for 1D? + io.dma.req.bits.vaddr := Mux(pooling_is_enabled || mvout_1d_enabled, pool_vaddr, current_vaddr) + io.dma.req.bits.laddr := Mux(pooling_is_enabled, pool_row_addr, current_localaddr) //Todo: laddr for 1D? 
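// ---------------------------------------------------------------------------
// A plain-Scala model (illustrative only, not part of this patch) of the
// blocked-mvout address stream defined by current_vaddr, current_localaddr,
// and the per-request length set just below: block_counter is the inner
// loop, the scratchpad row advances by block_stride per block, and the final
// block of a row carries only the leftover columns. Parameter values here
// are hypothetical.
object BlockedMvoutModel {
  def requests(vaddr: Long, laddr: Int, stride: Int, rows: Int, cols: Int,
               blockCols: Int, blockStride: Int): Seq[(Long, Int, Int)] = {
    val blocks = (cols + blockCols - 1) / blockCols
    for (row <- 0 until rows; block <- 0 until blocks) yield {
      // last block of each row covers only the leftover columns
      val len = if (block == blocks - 1) ((cols - 1) % blockCols) + 1 else blockCols
      (vaddr + row.toLong * stride, laddr + block * blockStride + row, len)
    }
  }

  def main(args: Array[String]): Unit = {
    // 2 rows of 20 columns with 16-wide blocks: two DMA requests per row,
    // the second covering the 4 leftover columns
    requests(vaddr = 0x80000000L, laddr = 0, stride = 64,
             rows = 2, cols = 20, blockCols = 16, blockStride = 16)
      .foreach { case (va, la, len) => println(f"vaddr=0x$va%x laddr=$la len=$len") }
  }
}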
- io.dma.req.bits.len := cols + io.dma.req.bits.len := Mux(block_counter === blocks - 1.U, ((cols - 1.U) % block_cols.U) + 1.U, block_cols.U) + io.dma.req.bits.block := block_counter io.dma.req.bits.status := mstatus io.dma.req.bits.pool_en := pooling_is_enabled && (wrow_counter =/= 0.U || wcol_counter =/= 0.U) - io.dma.req.bits.store_en := !pooling_is_enabled || - (wrow_counter === pool_size - 1.U && wcol_counter === pool_size - 1.U) + io.dma.req.bits.store_en := Mux(pooling_is_enabled, wrow_counter === pool_size - 1.U && wcol_counter === pool_size - 1.U, + block_counter === blocks - 1.U) // Command tracker IO cmd_tracker.io.alloc.valid := control_state === waiting_for_command && cmd.valid && DoStore - cmd_tracker.io.alloc.bits.bytes_to_read := Mux(!pooling_is_enabled, Mux(mvout_1d_enabled, mvout_1d_rows, rows), pool_total_rows) // TODO do we have to add upad and lpad to this? + cmd_tracker.io.alloc.bits.bytes_to_read := Mux(!pooling_is_enabled, Mux(mvout_1d_enabled, mvout_1d_rows, rows*blocks), pool_total_rows) // TODO do we have to add upad and lpad to this? cmd_tracker.io.alloc.bits.tag.rob_id := cmd.bits.rob_id.bits cmd_tracker.io.request_returned.valid := io.dma.resp.fire() // TODO use a bundle connect @@ -155,13 +163,18 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm pocol_counter := wrappingAdd(pocol_counter, 1.U, pool_ocols) porow_counter := wrappingAdd(porow_counter, 1.U, pool_orows, pocol_counter === pool_ocols - 1.U) } - row_counter := Mux(mvout_1d_enabled, wrappingAdd(row_counter, 1.U, mvout_1d_rows), wrappingAdd(row_counter, 1.U, rows)) + + block_counter := wrappingAdd(block_counter, 1.U, blocks) + row_counter := Mux(mvout_1d_enabled, wrappingAdd(row_counter, 1.U, mvout_1d_rows), wrappingAdd(row_counter, 1.U, rows, block_counter === blocks - 1.U)) }.otherwise { wcol_counter := wrappingAdd(wcol_counter, 1.U, pool_size) wrow_counter := wrappingAdd(wrow_counter, 1.U, pool_size, wcol_counter === pool_size - 1.U) pocol_counter := wrappingAdd(pocol_counter, 1.U, pool_pocols, wrow_counter === pool_size - 1.U && wcol_counter === pool_size - 1.U) porow_counter := wrappingAdd(porow_counter, 1.U, pool_porows, pocol_counter === pool_pocols - 1.U && wrow_counter === pool_size - 1.U && wcol_counter === pool_size - 1.U) } + + assert(!(io.dma.req.bits.laddr.read_full_acc_row && blocks > 1.U), "Block-mvouts are not permitted when moving out full accumulator data") + assert(!((pooling_is_enabled || mvout_1d_enabled) && blocks > 1.U), "Block-mvouts are not permitted when pooling") } // Control logic @@ -201,11 +214,13 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm } is (sending_rows) { - // TODO Is it really possible for row_counter to be 0 here? 
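// ---------------------------------------------------------------------------
// A plain-Scala model (illustrative only, not part of this patch) of how the
// two wrappingAdd calls above chain the counters: block_counter wraps modulo
// `blocks` on every DMA request, and row_counter advances only on the
// request that wraps it, via the `block_counter === blocks - 1.U` enable.
object NestedCounterModel {
  // software analogue of Util.wrappingAdd(u, n, max_plus_one, en)
  def wrappingAdd(u: Int, n: Int, maxPlusOne: Int, en: Boolean = true): Int =
    if (en) (u + n) % maxPlusOne else u

  def main(args: Array[String]): Unit = {
    val (blocks, rows) = (3, 2)
    var (block, row) = (0, 0)
    for (_ <- 0 until blocks * rows) {
      print(s"(row=$row, block=$block) ")
      val wrap = block == blocks - 1
      block = wrappingAdd(block, 1, blocks)
      row = wrappingAdd(row, 1, rows, wrap)
    }
    println() // (row=0, block=0) (row=0, block=1) (row=0, block=2) (row=1, block=0) ...
  }
}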
- val last_row = row_counter === 0.U || (Mux(mvout_1d_enabled, row_counter === mvout_1d_rows - 1.U, row_counter === rows - 1.U) && io.dma.req.fire()) + val last_block = block_counter === blocks - 1.U && io.dma.req.fire() + val last_row = Mux(mvout_1d_enabled, row_counter === mvout_1d_rows - 1.U, row_counter === rows - 1.U) && io.dma.req.fire() //normal mvout: row, 1D mvout: orows*ocols - when (last_row) { + val only_one_dma_req = block_counter === 0.U && row_counter === 0.U // This is a special case when only one DMA request is made + + when ((last_block && last_row) || only_one_dma_req) { control_state := waiting_for_command cmd.ready := true.B } diff --git a/src/main/scala/gemmini/Tile.scala b/src/main/scala/gemmini/Tile.scala index 69b606b8..1a2bfe74 100644 --- a/src/main/scala/gemmini/Tile.scala +++ b/src/main/scala/gemmini/Tile.scala @@ -25,6 +25,8 @@ class Tile[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, df: val in_valid = Input(Vec(columns, Bool())) val out_valid = Output(Vec(columns, Bool())) + + val bad_dataflow = Output(Bool()) }) val tile = Seq.fill(rows, columns)(Module(new PE(inputType, outputType, accType, df, pe_latency))) @@ -83,6 +85,7 @@ class Tile[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, df: io.out_control(c) := tile(rows-1)(c).io.out_control io.out_valid(c) := tile(rows-1)(c).io.out_valid } + io.bad_dataflow := tile.map(_.map(_.io.bad_dataflow).reduce(_||_)).reduce(_||_) // Drive the Tile's right IO for (r <- 0 until rows) { diff --git a/src/main/scala/gemmini/Util.scala b/src/main/scala/gemmini/Util.scala index 593d2070..dd837c7d 100644 --- a/src/main/scala/gemmini/Util.scala +++ b/src/main/scala/gemmini/Util.scala @@ -44,6 +44,15 @@ object Util { )) } + def sFloorAdd(s: SInt, n: UInt, max_plus_one: SInt, min: SInt, en: Bool = true.B): SInt = { + val max = max_plus_one - 1.S + + MuxCase(s + n.zext(), Seq( + (!en) -> s, + ((s +& n.zext()) > max) -> min + )) + } + def wrappingSub(u: UInt, n: UInt, max_plus_one: Int): UInt = { val max = max_plus_one - 1 assert(n <= max.U, "cannot wrapSub when n is larger than max") diff --git a/src/main/scala/gemmini/VectorScalarMultiplier.scala b/src/main/scala/gemmini/VectorScalarMultiplier.scala index 4e86f61a..d1cefcb3 100644 --- a/src/main/scala/gemmini/VectorScalarMultiplier.scala +++ b/src/main/scala/gemmini/VectorScalarMultiplier.scala @@ -24,15 +24,33 @@ class VectorScalarMultiplierResp[T <: Data, Tag <: Data](block_cols: Int, t: T, override def cloneType: VectorScalarMultiplierResp.this.type = new VectorScalarMultiplierResp(block_cols, t, tag_t).asInstanceOf[this.type] } -// Currently, this class only supports multiplications of scratchpad inputs, rather than accumulator inputs -// class VectorScalarMultiplier[T <: Data, U <: Data, Tag <: Data](config: GemminiArrayConfig[T, U], tag_t: Tag) extends Module { - // import config._ - // val block_cols = meshColumns * tileColumns -class VectorScalarMultiplier[T <: Data, U <: Data, Tag <: Data](mvin_scale_args: Option[ScaleArguments[T, U]], block_cols: Int, t: T, tag_t: Tag) extends Module { - - val u = mvin_scale_args match { - case Some(ScaleArguments(_, _, multiplicand_t, _, _)) => multiplicand_t - case None => Bool() // TODO make this a 0-width UInt +class DataWithIndex[T <: Data, U <: Data](t: T, u: U) extends Bundle { + val data = t.cloneType + val scale = u.cloneType + val id = UInt(2.W) // TODO hardcoded + val index = UInt() + override def cloneType: DataWithIndex.this.type = new DataWithIndex(t, u).asInstanceOf[this.type] +} + 
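// ---------------------------------------------------------------------------
// A plain-Scala model (illustrative only, not part of this patch) of the
// sFloorAdd helper added to Util.scala above: a signed counter that steps by
// n and, when it would pass max_plus_one - 1, falls back to `min` rather
// than to zero, which suits counters whose floor is negative.
object SFloorAddModel {
  def sFloorAdd(s: Int, n: Int, maxPlusOne: Int, min: Int, en: Boolean = true): Int =
    if (!en) s
    else if (s + n > maxPlusOne - 1) min
    else s + n

  def main(args: Array[String]): Unit = {
    // stepping by 1 over [-1, 2] wraps back to -1: prints -1 0 1 2 -1 0 1 2
    println(Iterator.iterate(-1)(sFloorAdd(_, 1, 3, -1)).take(8).mkString(" "))
  }
}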
+class ScalePipe[T <: Data, U <: Data](t: T, mvin_scale_args: ScaleArguments[T, U]) extends Module { + val u = mvin_scale_args.multiplicand_t + val io = IO(new Bundle { + val in = Input(Valid(new DataWithIndex(t, u))) + val out = Output(Valid(new DataWithIndex(t, u))) + }) + val latency = mvin_scale_args.latency + val out = WireInit(io.in) + out.bits.data := mvin_scale_args.scale_func(io.in.bits.data, io.in.bits.scale.asTypeOf(u)) + io.out := Pipe(out, latency) +} + +class VectorScalarMultiplier[T <: Data, U <: Data, Tag <: Data]( + mvin_scale_args: Option[ScaleArguments[T, U]], block_cols: Int, t: T, tag_t: Tag +) extends Module { + + val (u, num_scale_units, always_identity) = mvin_scale_args match { + case Some(ScaleArguments(_, _, multiplicand_t, num_scale_units, _, _)) => (multiplicand_t, num_scale_units, false) + case None => (Bool(), -1, true) // TODO make this a 0-width UInt } val io = IO(new Bundle { @@ -40,48 +58,155 @@ class VectorScalarMultiplier[T <: Data, U <: Data, Tag <: Data](mvin_scale_args: val resp = Decoupled(new VectorScalarMultiplierResp(block_cols, t, tag_t)) }) - val req = Reg(UDValid(chiselTypeOf(io.req.bits))) - - io.req.ready := !req.valid || (req.bits.repeats === 0.U && io.resp.fire()) - io.resp.valid := req.valid - io.resp.bits.tag := req.bits.tag - io.resp.bits.last := req.bits.repeats === 0.U && req.bits.last - io.resp.bits.row := req.bits.repeats - io.resp.bits.out := (mvin_scale_args match { - case Some(ScaleArguments(mvin_scale_func, _, multiplicand_t, _, _)) => - req.bits.in.map(x => mvin_scale_func(x, req.bits.scale.asTypeOf(multiplicand_t))) + val width = block_cols + val latency = mvin_scale_args match { + case Some(ScaleArguments(_, latency, _, _, _, _)) => latency + case None => 0 + } - case None => req.bits.in - }) + val in = Reg(Valid(new VectorScalarMultiplierReq(block_cols, t, u, tag_t))) + val in_fire = WireInit(false.B) + io.req.ready := !in.valid || (in.bits.repeats === 0.U && in_fire) when (io.req.fire()) { - req.push(io.req.bits) - }.elsewhen(io.resp.fire()) { - when (req.bits.repeats === 0.U) { - req.pop() - }.otherwise { - req.bits.repeats := req.bits.repeats - 1.U + in.valid := io.req.valid + in.bits := io.req.bits + } .elsewhen (in_fire) { + when (in.bits.repeats === 0.U) { + in.valid := false.B } + in.bits.repeats := in.bits.repeats - 1.U + } + when (reset.asBool) { + in.valid := false.B } - when (reset.toBool()) { - req.pop() + + if (num_scale_units == -1) { + val pipe = Module(new Pipeline( + new VectorScalarMultiplierResp(block_cols, t, tag_t), + latency + )()) + io.resp <> pipe.io.out + in_fire := pipe.io.in.fire() + + pipe.io.in.valid := in.valid + pipe.io.in.bits.tag := in.bits.tag + pipe.io.in.bits.last := in.bits.repeats === 0.U && in.bits.last + pipe.io.in.bits.row := in.bits.repeats + pipe.io.in.bits.out := (mvin_scale_args match { + case Some(ScaleArguments(mvin_scale_func, _, multiplicand_t, _, _, _)) => + in.bits.in.map(x => mvin_scale_func(x, in.bits.scale.asTypeOf(multiplicand_t))) + case None => in.bits.in + }) + } else { + val nEntries = 3 + val regs = Reg(Vec(nEntries, Valid(new VectorScalarMultiplierReq(block_cols, t, u, tag_t)))) + val out_regs = Reg(Vec(nEntries, new VectorScalarMultiplierResp(block_cols, t, tag_t))) + + val fired_masks = Reg(Vec(nEntries, Vec(width, Bool()))) + val completed_masks = Reg(Vec(nEntries, Vec(width, Bool()))) + val head_oh = RegInit(1.U(nEntries.W)) + val tail_oh = RegInit(1.U(nEntries.W)) + + io.resp.valid := Mux1H(head_oh.asBools, (regs zip completed_masks).map({case (r,c) => r.valid 
&& c.reduce(_&&_)})) + io.resp.bits := Mux1H(head_oh.asBools, out_regs) + when (io.resp.fire()) { + for (i <- 0 until nEntries) { + when (head_oh(i)) { + regs(i).valid := false.B + } + } + head_oh := (head_oh << 1) | head_oh(nEntries-1) + } + in_fire := (in.valid && + (!Mux1H(tail_oh.asBools, regs.map(_.valid)) || (tail_oh === head_oh && io.resp.fire())) + ) + when (in_fire) { + for (i <- 0 until nEntries) { + when (tail_oh(i)) { + regs(i).valid := true.B + regs(i).bits := in.bits + out_regs(i).tag := in.bits.tag + out_regs(i).last := in.bits.repeats === 0.U && in.bits.last + out_regs(i).row := in.bits.repeats + out_regs(i).out := in.bits.in + val identity = (u match { + case u: UInt => Arithmetic.UIntArithmetic.cast(u).identity + case s: SInt => Arithmetic.SIntArithmetic.cast(s).identity + case f: Float => Arithmetic.FloatArithmetic.cast(f).identity + case b: Bool => 1.U(1.W) + }) + fired_masks(i).foreach(_ := in.bits.scale.asUInt === identity.asUInt || always_identity.B) + completed_masks(i).foreach(_ := in.bits.scale.asUInt === identity.asUInt || always_identity.B) + } + } + tail_oh := (tail_oh << 1) | tail_oh(nEntries-1) + } + + + + val inputs = Seq.fill(width*nEntries) { Wire(Decoupled(new DataWithIndex(t, u))) } + for (i <- 0 until nEntries) { + for (w <- 0 until width) { + val input = inputs(i*width+w) + input.valid := regs(i).valid && !fired_masks(i)(w) + input.bits.data := regs(i).bits.in(w) + input.bits.scale := regs(i).bits.scale.asTypeOf(u) + input.bits.id := i.U + input.bits.index := w.U + when (input.fire()) { + fired_masks(i)(w) := true.B + } + } + } + for (i <- 0 until num_scale_units) { + val arbIn = inputs.zipWithIndex.filter({ case (_, w) => w % num_scale_units == i }).map(_._1) + val arb = Module(new RRArbiter(new DataWithIndex(t, u), arbIn.length)) + arb.io.in <> arbIn + arb.io.out.ready := true.B + val arbOut = Reg(Valid(new DataWithIndex(t, u))) + arbOut.valid := arb.io.out.valid + arbOut.bits := arb.io.out.bits + when (reset.asBool) { + arbOut.valid := false.B + } + + + val pipe = Module(new ScalePipe(t, mvin_scale_args.get)) + pipe.io.in := arbOut + val pipe_out = pipe.io.out + for (j <- 0 until nEntries) { + for (w <- 0 until width) { + if ((j*width+w) % num_scale_units == i) { + when (pipe_out.fire() && pipe_out.bits.id === j.U && pipe_out.bits.index === w.U) { + out_regs(j).out(w) := pipe_out.bits.data + completed_masks(j)(w) := true.B + } + } + } + } + } + when (reset.asBool) { + regs.foreach(_.valid := false.B) + } + + } + + } object VectorScalarMultiplier { // Returns the input and output IO of the module (together with the pipeline) - def apply[T <: Data, U <: Data, Tag <: Data](scale_args: Option[ScaleArguments[T, U]], t: T, cols: Int, tag_t: Tag, is_acc: Boolean, is_mvin: Boolean=true) = { + def apply[T <: Data, U <: Data, Tag <: Data]( + scale_args: Option[ScaleArguments[T, U]], + t: T, cols: Int, tag_t: Tag, + is_acc: Boolean, + is_mvin: Boolean=true + ) = { assert(!is_acc || is_mvin) - val vsm = Module(new VectorScalarMultiplier(scale_args, cols, t, tag_t)) - - val in = vsm.io.req - val out = scale_args match { - case Some(ScaleArguments(_, latency, _, _, _)) => Pipeline(vsm.io.resp, latency) - case None => vsm.io.resp - } - - (in, out) + (vsm.io.req, vsm.io.resp) } } diff --git a/src/main/scala/gemmini/XactTracker.scala b/src/main/scala/gemmini/XactTracker.scala index af020ed9..9eee539a 100644 --- a/src/main/scala/gemmini/XactTracker.scala +++ b/src/main/scala/gemmini/XactTracker.scala @@ -13,7 +13,8 @@ class XactTrackerEntry[U <: Data](maxShift: Int, 
spadWidth: Int, accWidth: Int, val accumulate = Bool() val has_acc_bitwidth = Bool() val scale = UInt(mvin_scale_t_bits.W) - val rows = UInt(16.W) // TODO magic number + val repeats = UInt(16.W) // TODO magic number + val block_stride = UInt(16.W) // TODO magic number val spad_row_offset = UInt(log2Up(spadWidth max accWidth).W) val lg_len_req = UInt(log2Up(log2Up(maxReqBytes+1)+1).W) val bytes_to_read = UInt(log2Up(maxReqBytes+1).W) diff --git a/src/main/scala/gemmini/ZeroWriter.scala b/src/main/scala/gemmini/ZeroWriter.scala new file mode 100644 index 00000000..c2e97b36 --- /dev/null +++ b/src/main/scala/gemmini/ZeroWriter.scala @@ -0,0 +1,70 @@ +package gemmini + +import chisel3._ +import chisel3.util._ + +import Util._ + +class ZeroWriterReq[Tag <: Data](laddr_t: LocalAddr, max_cols: Int, tag_t: Tag) extends Bundle { + val laddr = laddr_t + val cols = UInt(log2Up(max_cols+1).W) + val block_stride = UInt(16.W) // TODO magic number + val tag = tag_t + + override def cloneType: ZeroWriterReq.this.type = new ZeroWriterReq(laddr_t.cloneType, max_cols, tag_t.cloneType).asInstanceOf[this.type] +} + +class ZeroWriterResp[Tag <: Data](laddr_t: LocalAddr, block_cols: Int, tag_t: Tag) extends Bundle { + val laddr = laddr_t.cloneType + val mask = Vec(block_cols, Bool()) + val last = Bool() + val tag = tag_t + + override def cloneType: ZeroWriterResp.this.type = new ZeroWriterResp(laddr_t, block_cols, tag_t.cloneType).asInstanceOf[this.type] +} + +class ZeroWriter[T <: Data, U <: Data, V <: Data, Tag <: Data](config: GemminiArrayConfig[T, U, V], tag_t: Tag) + extends Module { + import config._ + + val block_cols = meshColumns * tileColumns + val max_cols = (dma_maxbytes / (inputType.getWidth / 8)) max block_cols + + val io = IO(new Bundle { + val req = Flipped(Decoupled(new ZeroWriterReq(local_addr_t, max_cols, tag_t))) + val resp = Decoupled(new ZeroWriterResp(local_addr_t, block_cols, tag_t)) + }) + + val req = Reg(UDValid(new ZeroWriterReq(local_addr_t, max_cols, tag_t))) + + val col_counter = Reg(UInt(log2Up(max_cols).W)) + + io.req.ready := !req.valid + + io.resp.valid := req.valid + io.resp.bits.laddr := req.bits.laddr + req.bits.block_stride * (col_counter / block_cols.U) + io.resp.bits.mask.zipWithIndex.foreach { case (m, i) => m := col_counter + i.U < req.bits.cols } + io.resp.bits.last := col_counter +& block_cols.U >= req.bits.cols + io.resp.bits.tag := req.bits.tag + + when (io.resp.fire()) { + val next_col_counter = floorAdd(col_counter, block_cols.U, req.bits.cols) + + col_counter := next_col_counter + + when (next_col_counter === 0.U) { + req.pop() + io.req.ready := true.B + } + } + + when (io.req.fire()) { + req.push(io.req.bits) + + col_counter := 0.U + } + + when (reset.toBool()) { + req.pop() + } +}
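// ---------------------------------------------------------------------------
// A plain-Scala model (illustrative only, not part of this patch) of the
// response stream the ZeroWriter above produces: one beat per block of
// columns, with the local address bumped by block_stride per block, a
// per-column mask that switches off past `cols`, and `last` set on the
// final beat.
object ZeroWriterModel {
  def beats(cols: Int, blockCols: Int, blockStride: Int): Seq[(Int, Seq[Boolean], Boolean)] =
    (0 until cols by blockCols).map { start =>
      val laddrOffset = blockStride * (start / blockCols) // laddr advance for this beat
      val mask = (0 until blockCols).map(i => start + i < cols)
      val last = start + blockCols >= cols
      (laddrOffset, mask, last)
    }

  def main(args: Array[String]): Unit = {
    // 20 columns in 16-wide blocks: beat 0 is fully enabled, beat 1 enables
    // only the 4 leftover columns and is marked last
    beats(cols = 20, blockCols = 16, blockStride = 4).foreach(println)
  }
}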