Skip to content

Commit

Permalink
Ported the spad-to-spad code to a new branch
Browse files Browse the repository at this point in the history
  • Loading branch information
vikramjain236 committed Sep 23, 2024
1 parent 25809f7 commit bac4323
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 6 deletions.
18 changes: 18 additions & 0 deletions src/main/scala/gemmini/Configs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ object GemminiConfigs {
ex_read_from_acc = true,
ex_write_to_spad = true,
ex_write_to_acc = true,

use_tl_spad_mem = true,
tl_spad_mem_base = 0x1000000,
)

val dummyConfig = GemminiArrayConfig[DummySInt, Float, Float](
Expand Down Expand Up @@ -285,6 +288,21 @@ class LeanGemminiPrintfConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
)
})

/**
  * Config fragment that appends one Gemmini accelerator to the existing `BuildRoCC` list.
  *
  * The accelerator is built from `gemminiConfig`, with only its `tl_spad_mem_base`
  * field replaced by the value passed to this class.
  *
  * @param gemminiConfig    accelerator configuration to start from (defaults to `GemminiConfigs.leanConfig`)
  * @param tl_spad_mem_base value copied into the configuration's `tl_spad_mem_base` field
  */
class LeanGemminiPGASConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
  gemminiConfig: GemminiArrayConfig[T,U,V] = GemminiConfigs.leanConfig,
  tl_spad_mem_base: BigInt = 0
) extends Config((site, here, up) => {
  case BuildRoCC =>
    // Builder invoked with the final Parameters; it is appended after the inherited builders.
    val buildGemmini = (p: Parameters) => {
      implicit val q = p
      // Every field except the scratchpad base address is taken unchanged from gemminiConfig.
      val rebasedConfig = gemminiConfig.copy(tl_spad_mem_base = tl_spad_mem_base)
      val gemmini = LazyModule(new Gemmini(rebasedConfig))
      gemmini
    }
    up(BuildRoCC) ++ Seq(buildGemmini)
})

class DummyDefaultGemminiConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
gemminiConfig: GemminiArrayConfig[T,U,V] = GemminiConfigs.dummyConfig
) extends Config((site, here, up) => {
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/gemmini/ConfigsFP.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ object GemminiFPConfigs {
has_nonlinear_activations = true,

num_counter = 8,
use_tl_spad_mem = true, // Use the globally addressable local spad feature
tl_spad_mem_base = 0x1000000, // Global address for the local spad of gemmini
)

//FP32 Single Precision Configuration
Expand Down
8 changes: 7 additions & 1 deletion src/main/scala/gemmini/Controller.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import org.chipsalliance.cde.config._
import freechips.rocketchip.diplomacy._
import freechips.rocketchip.tile._
import freechips.rocketchip.util.ClockGate
import freechips.rocketchip.tilelink.TLIdentityNode
import freechips.rocketchip.tilelink._
import GemminiISA._
import Util._

Expand All @@ -34,6 +34,12 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA

val xLen = p(TileKey).core.xLen
val spad = LazyModule(new Scratchpad(config))
// Crossbar placed in front of the scratchpad's TileLink manager ports.
// NOTE(review): it is instantiated even when config.use_tl_spad_mem is false, in which case
// nothing in this block connects to it — confirm an unconnected TLXbar elaborates cleanly.
val xbar_mgr_node = TLXbar()

// Only wire the TileLink path into the scratchpad when the feature is enabled.
// Chain as written: stlNode, a buffer, the crossbar, another buffer, then the scratchpad managers.
// (stlNode is not defined in the visible code; presumably inherited from the RoCC base class.)
if (config.use_tl_spad_mem) {
spad.spad_rw_mgrs :*= TLBuffer() :*= xbar_mgr_node
xbar_mgr_node := TLBuffer() := stlNode
}

override lazy val module = new GemminiModule(this)
override val tlNode = if (config.use_dedicated_tl_port) spad.id_node else TLIdentityNode()
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/gemmini/ExecuteController.scala
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
if (ex_read_from_spad) {
io.srams.read(i).req.valid := (read_a || read_b || read_d) && cntl_ready
io.srams.read(i).req.bits.fromDMA := false.B
io.srams.read(i).req.bits.fromTL := false.B
io.srams.read(i).req.bits.addr := MuxCase(a_address_rs1.sp_row() + a_fire_counter,
Seq(read_b -> (b_address_rs2.sp_row() + b_fire_counter),
read_d -> (d_address_rs1.sp_row() + block_size.U - 1.U - d_fire_counter_mulpre)))
Expand All @@ -447,6 +448,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
} else {
io.srams.read(i).req.valid := false.B
io.srams.read(i).req.bits.fromDMA := false.B
io.srams.read(i).req.bits.fromTL := false.B
io.srams.read(i).req.bits.addr := DontCare
}

Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/gemmini/GemminiConfigs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
use_firesim_simulation_counters: Boolean = false,

use_shared_ext_mem: Boolean = false,
use_tl_spad_mem: Boolean = false,
tl_spad_mem_base: BigInt = 0,
clock_gate: Boolean = false,

headerFileName: String = "gemmini_params.h"
Expand Down
96 changes: 91 additions & 5 deletions src/main/scala/gemmini/Scratchpad.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package gemmini
import chisel3._
import chisel3.util._
import org.chipsalliance.cde.config.Parameters
import freechips.rocketchip.diplomacy.{LazyModule, LazyModuleImp}
import freechips.rocketchip.diplomacy._
import freechips.rocketchip.rocket._
import freechips.rocketchip.tile._
import freechips.rocketchip.tilelink._
Expand Down Expand Up @@ -75,11 +75,13 @@ class ScratchpadWriteMemIO(local_addr_t: LocalAddr, acc_t_bits: Int, scale_t_bit
// Read request presented to a scratchpad bank.
// `n` only sizes the address field (log2Ceil(n) bits); presumably the number of rows in the bank — TODO confirm.
class ScratchpadReadReq(val n: Int) extends Bundle {
// Row address within the bank.
val addr = UInt(log2Ceil(n).W)
// Origin tag: set by the requester when the read is issued on behalf of the DMA path.
val fromDMA = Bool()
// Origin tag: set by the requester when the read is issued on behalf of a TileLink access.
val fromTL = Bool()
}

// Read response returned by a scratchpad bank. `w` is the width of the data field in bits.
class ScratchpadReadResp(val w: Int) extends Bundle {
// Data read from the bank.
val data = UInt(w.W)
// Origin tag for the DMA path; consumers use it to decide who receives this response.
val fromDMA = Bool()
// Origin tag for the TileLink path; consumers use it to decide who receives this response.
val fromTL = Bool()
}

class ScratchpadReadIO(val n: Int, val w: Int) extends Bundle {
Expand Down Expand Up @@ -155,12 +157,14 @@ class ScratchpadBank(n: Int, w: Int, aligned_to: Int, single_ported: Boolean, us
}

val fromDMA = io.read.req.bits.fromDMA
val fromTL = io.read.req.bits.fromTL

// Make a queue which buffers the result of an SRAM read if it can't immediately be consumed
val q = Module(new Queue(new ScratchpadReadResp(w), 1, true, true))
q.io.enq.valid := RegNext(ren)
q.io.enq.bits.data := rdata
q.io.enq.bits.fromDMA := RegNext(fromDMA)
q.io.enq.bits.fromTL := RegNext(fromTL)

val q_will_be_empty = (q.io.count +& q.io.enq.fire) - q.io.deq.fire === 0.U
io.read.req.ready := q_will_be_empty && !singleport_busy_with_write
Expand Down Expand Up @@ -193,6 +197,27 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
if (acc_read_full_width) acc_w else spad_w, aligned_to, inputType, block_cols, use_tlb_register_filter,
use_firesim_simulation_counters))

// Parameters of the TileLink-addressable view of the scratchpad.
// use_tl_spad_mem selects whether the scratchpad is exposed through TileLink manager ports.
val use_tl_spad_mem = config.use_tl_spad_mem
// Base address of that view, taken from the configuration's tl_spad_mem_base.
val spad_base = config.tl_spad_mem_base
// sp_width / 8 — presumably the scratchpad row size in bytes (assumes sp_width is in bits; TODO confirm).
val spad_data_len = config.sp_width / 8
val max_data_len = spad_data_len // currently identical to spad_data_len; original note read "max acc_data_len"

// Because max_data_len equals spad_data_len, this evaluates to sp_bank_entries.
val mem_depth = config.sp_bank_entries * spad_data_len / max_data_len
// Used below both as the per-row byte factor in the address map and as the TileLink beatBytes.
val mem_width = max_data_len

// TileLink manager node exposing every scratchpad bank as its own address range.
// When use_tl_spad_mem is false this is a plain TLIdentityNode instead.
val spad_rw_mgrs = if (use_tl_spad_mem) TLManagerNode(Seq.tabulate(config.sp_banks) { i =>
// One manager port (with a single manager) per bank.
TLSlavePortParameters.v1(Seq(TLSlaveParameters.v2(
name = Some(s"spad_rw_mgr_$i"),
// Bank i starts at spad_base + i * (mem_width * mem_depth); the mask is (mem_width * mem_depth) - 1.
// NOTE(review): the base/mask form presumably needs the bank size to be a power of two — confirm.
address = Seq(AddressSet(spad_base + i * mem_width * mem_depth, mem_width * mem_depth - 1)),
// Gets and both Put flavours are advertised with the same 1..64 transfer-size range.
supports = TLMasterToSlaveTransferSizes(
get = TransferSizes(1, 64),
putFull = TransferSizes(1, 64),
putPartial = TransferSizes(1, 64)),
// Every bank uses the same fifoId (0).
fifoId = Some(0)
)),
// One beat is one scratchpad row of mem_width bytes.
beatBytes = mem_width)
}) else TLIdentityNode()

// TODO make a cross-bar vs two separate ports a config option
// id_node :=* reader.node
// id_node :=* writer.node
Expand Down Expand Up @@ -443,6 +468,29 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,

io.busy := writer.module.io.busy || reader.module.io.busy || write_issue_q.io.deq.valid || write_norm_q.io.deq.valid || write_scale_q.io.deq.valid || write_dispatch_q.valid

// Per-bank TileLink bundles and edge parameters seen on the scratchpad manager node.
// NOTE(review): the loop below indexes tls/edges with 0 until sp_banks, so it assumes the node has
// one inbound edge per bank. When use_tl_spad_mem is false the node is a TLIdentityNode —
// presumably with no inbound edges — confirm elaboration does not fail in that configuration.
val (tls, edges) = spad_rw_mgrs.in.unzip

// aHasData(i): the request on bank i's A channel carries data (as reported by edges(i).hasData).
// a_read_ready(i) / a_write_ready(i): readiness for data-less / data-carrying requests;
// they are only declared here and are driven elsewhere in this module.
val aHasData = Wire(Vec(config.sp_banks, Bool()))
val a_read_ready = Wire(Vec(config.sp_banks, Bool()))
val a_write_ready = Wire(Vec(config.sp_banks, Bool()))
for (i<-0 until config.sp_banks) {
aHasData(i) := edges(i).hasData(tls(i).a.bits)
// A-channel ready is chosen by request kind: write readiness when the beat carries data, read readiness otherwise.
tls(i).a.ready := Mux(aHasData(i), a_write_ready(i), a_read_ready(i))
}

// Builds a TileLink D-channel beat answering a read: an AccessAckData carrying `data`
// for the request identified by `sourceId`.
//   sourceId: value placed in the response's source field
//   data:     value placed in the response's data field
// Returns a newly created Wire of TLBundleD with opcode, param, size, source, sink,
// denied, data and corrupt all assigned.
def getDResponseFromID(sourceId: UInt, data: UInt) = {
// Bundle parameters are taken from bank 0's edge no matter which bank the response is for.
// NOTE(review): presumably all bank edges share identical bundle parameters — confirm.
val d = Wire(new TLBundleD(edges(0).bundle))
d.opcode := TLMessages.AccessAckData
d.param := 0.U
// size is fixed to one full row (mem_width bytes) and does not depend on the request.
// NOTE(review): the manager node advertises transfer sizes 1..64 — confirm requests are always mem_width bytes.
d.size := log2Ceil(mem_width).U
d.source := sourceId
d.sink := 0.U
// The response never reports an error: denied and corrupt are tied low.
d.denied := false.B
d.data := data
d.corrupt := false.B
d
}

val spad_mems = {
val banks = Seq.fill(sp_banks) { Module(new ScratchpadBank(
sp_bank_entries, spad_w,
Expand All @@ -465,23 +513,31 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
!(bio.write.en && config.sp_singleported.B) &&
!write_dispatch_q.bits.laddr.is_acc_addr && write_dispatch_q.bits.laddr.sp_bank() === i.U

bio.read.req.valid := exread || dmawrite
val tlread = tls(i).a.fire && !aHasData(i)

bio.read.req.valid := exread || dmawrite || tlread
ex_read_req.ready := bio.read.req.ready

// The ExecuteController gets priority when reading from SRAMs
when (exread) {
bio.read.req.bits.addr := ex_read_req.bits.addr
bio.read.req.bits.fromDMA := false.B
bio.read.req.bits.fromTL := false.B
}.elsewhen (dmawrite) {
bio.read.req.bits.addr := write_dispatch_q.bits.laddr.sp_row()
bio.read.req.bits.fromDMA := true.B
bio.read.req.bits.fromTL := false.B

when (bio.read.req.fire) {
write_dispatch_q.ready := true.B
write_norm_q.io.enq.valid := true.B

io.dma.write.resp.valid := true.B
}
}.elsewhen(tlread) {
bio.read.req.bits.addr := tls(i).a.bits.address >> (log2Up(mem_width).U)
bio.read.req.bits.fromDMA := false.B
bio.read.req.bits.fromTL := true.B
}.otherwise {
bio.read.req.bits := DontCare
}
Expand All @@ -492,14 +548,34 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
val ex_read_resp = Wire(Decoupled(new ScratchpadReadResp(spad_w)))
ex_read_resp.valid := bio.read.resp.valid && !bio.read.resp.bits.fromDMA
ex_read_resp.bits := bio.read.resp.bits
val tl_read_resp = Wire(Decoupled(new ScratchpadReadResp(spad_w)))
tl_read_resp.valid := bio.read.resp.valid && bio.read.resp.bits.fromTL && !bio.read.resp.bits.fromDMA
tl_read_resp.bits := bio.read.resp.bits
val src_id = Wire(Decoupled(UInt(tls(i).a.bits.source.getWidth.W)))
src_id.valid := tls(i).a.valid
src_id.bits := tls(i).a.bits.source

val dma_read_pipe = Pipeline(dma_read_resp, spad_read_delay)
val ex_read_pipe = Pipeline(ex_read_resp, spad_read_delay)
val tl_read_pipe = Pipeline(tl_read_resp, spad_read_delay)
val src_id_pipe = Pipeline(src_id, spad_read_delay+1)

bio.read.resp.ready := Mux(bio.read.resp.bits.fromTL, tl_read_resp.ready, (Mux(bio.read.resp.bits.fromDMA, dma_read_resp.ready, ex_read_resp.ready)))

a_read_ready(i) := tl_read_resp.ready && tls(i).d.ready

tl_read_pipe.ready := tls(i).d.ready
src_id_pipe.ready := tls(i).d.ready

bio.read.resp.ready := Mux(bio.read.resp.bits.fromDMA, dma_read_resp.ready, ex_read_resp.ready)
when (tl_read_pipe.fire) {
tls(i).d.valid := true.B
tls(i).d.bits := getDResponseFromID(src_id_pipe.bits, tl_read_pipe.bits.data.asUInt)
}.otherwise {
tls(i).d.valid := false.B
}

dma_read_pipe.ready := writer.module.io.req.ready &&
!write_issue_q.io.deq.bits.laddr.is_acc_addr && write_issue_q.io.deq.bits.laddr.sp_bank() === i.U && // I believe we don't need to check that write_issue_q is valid here, because if the SRAM's resp is valid, then that means that the write_issue_q's deq should also be valid
((!write_issue_q.io.deq.bits.laddr.is_acc_addr && write_issue_q.io.deq.bits.laddr.sp_bank() === i.U) && write_issue_q.io.deq.valid) &&
!write_issue_q.io.deq.bits.laddr.is_garbage()
when (dma_read_pipe.fire) {
writeData.valid := true.B
Expand Down Expand Up @@ -529,7 +605,10 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
// !((mvin_scale_out.valid && mvin_scale_out.bits.last) || (mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.last))
!((mvin_scale_pixel_repeater.io.resp.valid && mvin_scale_pixel_repeater.io.resp.bits.last) || (mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.last))

bio.write.en := exwrite || dmaread || zerowrite
val tlwrite = tls(i).a.fire && aHasData(i)
a_write_ready(i) := !exwrite && !dmaread && !zerowrite && tls(i).d.ready

bio.write.en := exwrite || dmaread || zerowrite || tlwrite

when (exwrite) {
bio.write.addr := io.srams.write(i).addr
Expand All @@ -547,6 +626,13 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
bio.write.mask := zero_writer_pixel_repeater.io.resp.bits.mask

zero_writer_pixel_repeater.io.resp.ready := true.B // TODO we combinationally couple valid and ready signals
}.elsewhen (tlwrite) {
bio.write.addr := tls(i).a.bits.address >> (log2Up(mem_width).U)
bio.write.data := tls(i).a.bits.data
bio.write.mask := tls(i).a.bits.mask.asTypeOf(bio.write.mask)

tls(i).d.valid := true.B
tls(i).d.bits := edges(i).AccessAck(tls(i).a.bits)
}.otherwise {
bio.write.addr := DontCare
bio.write.data := DontCare
Expand Down

0 comments on commit bac4323

Please sign in to comment.