Skip to content

Commit

Permalink
initial files
Browse the repository at this point in the history
  • Loading branch information
lelzeiny committed Mar 19, 2024
1 parent 8c8b38b commit 2f62f3e
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/main/scala/gemmini/ExecuteController.scala
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In

// Instantiate the actual mesh
val mesh = Module(new MeshWithDelays(inputType, spatialArrayOutputType, accType, mesh_tag, dataflow, tree_reduction, tile_latency, mesh_output_delay,
tileRows, tileColumns, meshRows, meshColumns, shifter_banks, shifter_banks))
tileRows, tileColumns, meshRows, meshColumns, shifter_banks, shifter_banks, gemv_support=false))

mesh.io.a.valid := false.B
mesh.io.b.valid := false.B
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/gemmini/GemminiConfigs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
accType: T,

dataflow: Dataflow.Value = Dataflow.BOTH,
gemv_support: Boolean = false,

tileRows: Int = 1,
tileColumns: Int = 1,
Expand Down
7 changes: 4 additions & 3 deletions src/main/scala/gemmini/MeshWithDelays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data]
(inputType: T, val outputType: T, accType: T,
tagType: U, df: Dataflow.Value, tree_reduction: Boolean, tile_latency: Int, output_delay: Int,
tileRows: Int, tileColumns: Int, meshRows: Int, meshColumns: Int,
leftBanks: Int, upBanks: Int, outBanks: Int = 1, n_simultaneous_matmuls: Int = -1)
leftBanks: Int, upBanks: Int, outBanks: Int = 1, n_simultaneous_matmuls: Int = -1, gemv_support: Boolean = false)
extends Module {

val A_TYPE = Vec(meshRows, Vec(tileRows, inputType))
Expand Down Expand Up @@ -164,15 +164,16 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data]
val transposer_out = VecInit(transposer.io.outCol.bits.grouped(tileRows).map(t => VecInit(t)).toSeq)

// Wire up mesh's IO to this module's IO
val mesh = Module(new Mesh(inputType, outputType, accType, df, tree_reduction, tile_latency, max_simultaneous_matmuls, output_delay, tileRows, tileColumns, meshRows, meshColumns))
val mesh = Module(new MeshWithVectors(inputType, outputType, accType, df, tree_reduction, tile_latency, max_simultaneous_matmuls, output_delay, tileRows, tileColumns, meshRows, meshColumns))

// TODO wire only to *_buf here, instead of io.*.bits
val a_shifter_in = WireInit(Mux(a_is_from_transposer, transposer_out.asTypeOf(A_TYPE), a_buf))
val b_shifter_in = WireInit(Mux(b_is_from_transposer, transposer_out.asTypeOf(B_TYPE), b_buf))
val d_shifter_in = WireInit(Mux(d_is_from_transposer,
VecInit(transposer_out.flatten.reverse.grouped(tileRows).map(VecInit(_)).toSeq).asTypeOf(D_TYPE), d_buf))

mesh.io.in_a := shifted(a_shifter_in, leftBanks)
val in_a_row = shifted(a_shifter_in, leftBanks)
mesh.io.in_a := VecInit.fill(meshColumns)(in_a_row)
mesh.io.in_b := shifted(b_shifter_in, upBanks)
mesh.io.in_d := shifted(d_shifter_in, upBanks)

Expand Down
144 changes: 144 additions & 0 deletions src/main/scala/gemmini/MeshWithVectors.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@

package gemmini

import chisel3._
import chisel3.util._
import chisel3.experimental._

/**
 * A 2D array of Tile modules with pipeline registers between neighbouring
 * tiles and extra output registers after the bottom row. This variant differs
 * from the plain Mesh in that the A operand port is indexed per *column*
 * (io.in_a(column)(row)) — the wiring hook intended for GEMV support, where a
 * different A vector could eventually be driven into each column of the array.
 *
 * NOTE(review): as written, only column 0 of io.in_a is actually consumed —
 * A values still chain horizontally across each row exactly as in Mesh. The
 * per-column A injection appears unfinished (see the TODO below); confirm
 * intended behavior with the MeshWithDelays caller, which currently
 * broadcasts the same A row to every column.
 *
 * @param inputType element type of the A/B/D operands entering the array
 * @param outputType element type of the B/C results leaving the array
 * @param accType accumulator element type used inside the PEs
 * @param df dataflow(s) (OS and/or WS) the PEs must support
 * @param tree_reduction whether tiles reduce partial sums with an adder tree
 * @param tile_latency number of pipeline stages inside each tile
 * @param max_simultaneous_matmuls max matmuls in flight at once (sizes the id fields)
 * @param output_delay extra register stages on the bottom-row outputs
 * @param tileRows rows of PEs within a single tile
 * @param tileColumns columns of PEs within a single tile
 * @param meshRows rows of tiles in the mesh
 * @param meshColumns columns of tiles in the mesh
 */
class MeshWithVectors[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T,
                                              df: Dataflow.Value, tree_reduction: Boolean, tile_latency: Int,
                                              max_simultaneous_matmuls: Int, output_delay: Int,
                                              val tileRows: Int, val tileColumns: Int,
                                              val meshRows: Int, val meshColumns: Int) extends Module {
  val io = IO(new Bundle {
    // A is indexed (column)(row), unlike in_b/in_d, so that each column can
    // in principle receive its own A vector (GEMV).
    val in_a = Input(Vec(meshColumns, Vec(meshRows, Vec(tileRows, inputType))))
    val in_b = Input(Vec(meshColumns, Vec(tileColumns, inputType)))
    val in_d = Input(Vec(meshColumns, Vec(tileColumns, inputType)))
    val in_control = Input(Vec(meshColumns, Vec(tileColumns, new PEControl(accType))))
    val in_id = Input(Vec(meshColumns, Vec(tileColumns, UInt(log2Up(max_simultaneous_matmuls).W)))) // The unique id of this particular matmul
    val in_last = Input(Vec(meshColumns, Vec(tileColumns, Bool())))
    val out_b = Output(Vec(meshColumns, Vec(tileColumns, outputType)))
    val out_c = Output(Vec(meshColumns, Vec(tileColumns, outputType)))
    val in_valid = Input(Vec(meshColumns, Vec(tileColumns, Bool())))
    val out_valid = Output(Vec(meshColumns, Vec(tileColumns, Bool())))
    val out_control = Output(Vec(meshColumns, Vec(tileColumns, new PEControl(accType))))
    val out_id = Output(Vec(meshColumns, Vec(tileColumns, UInt(log2Up(max_simultaneous_matmuls).W))))
    val out_last = Output(Vec(meshColumns, Vec(tileColumns, Bool())))
  })

  // mesh(r)(c) => Tile at row r, column c
  val mesh: Seq[Seq[Tile[T]]] = Seq.fill(meshRows, meshColumns)(Module(new Tile(inputType, outputType, accType, df, tree_reduction, max_simultaneous_matmuls, tileRows, tileColumns)))
  val meshT = mesh.transpose

  /** Delay `t` by `latency` cycles, gated by `valid`, with reset tied off.
    * The default "Pipe" function resets its valid signals to false.B. We would
    * like to avoid using global signals in the Mesh, so over here, we make it
    * clear that the reset signal will never be asserted.
    * (Type parameter named P to avoid shadowing the class's T.) */
  def pipe[P <: Data](valid: Bool, t: P, latency: Int): P = {
    chisel3.withReset(false.B) { Pipe(valid, t, latency).bits }
  }

  // Chain tile_a_out -> tile_a_in (pipeline a across each row)
  // TODO clock-gate A signals with in_garbage
  // TODO only io.in_a(0) is used here; to complete GEMV support, inject
  // io.in_a(c) into column c (mux against the chained out_a of the tile above)
  // instead of chaining a single A stream across the whole row.
  for (r <- 0 until meshRows) {
    mesh(r).foldLeft(io.in_a(0)(r)) {
      case (in_a, tile) =>
        tile.io.in_a := ShiftRegister(in_a, tile_latency + 1)
        tile.io.out_a
    }
  }

  // Chain tile_out_b -> tile_b_in (pipeline b across each column)
  for (c <- 0 until meshColumns) {
    meshT(c).foldLeft((io.in_b(c), io.in_valid(c))) {
      case ((in_b, valid), tile) =>
        tile.io.in_b := pipe(valid.head, in_b, tile_latency + 1)
        (tile.io.out_b, tile.io.out_valid)
    }
  }

  // Chain tile_out -> tile_propag (pipeline output across each column)
  for (c <- 0 until meshColumns) {
    meshT(c).foldLeft((io.in_d(c), io.in_valid(c))) {
      case ((in_propag, valid), tile) =>
        tile.io.in_d := pipe(valid.head, in_propag, tile_latency + 1)
        (tile.io.out_c, tile.io.out_valid)
    }
  }

  // Chain control signals (pipeline across each column)
  assert(!(mesh.map(_.map(_.io.bad_dataflow).reduce(_||_)).reduce(_||_)))
  for (c <- 0 until meshColumns) {
    meshT(c).foldLeft((io.in_control(c), io.in_valid(c))) {
      case ((in_ctrl, valid), tile) =>
        (tile.io.in_control, in_ctrl, valid).zipped.foreach { case (tile_ctrl, ctrl, v) =>
          tile_ctrl.shift := pipe(v, ctrl.shift, tile_latency + 1)
          tile_ctrl.dataflow := pipe(v, ctrl.dataflow, tile_latency + 1)
          tile_ctrl.propagate := pipe(v, ctrl.propagate, tile_latency + 1)
        }
        (tile.io.out_control, tile.io.out_valid)
    }
  }

  // Chain in_valid (pipeline across each column)
  for (c <- 0 until meshColumns) {
    meshT(c).foldLeft(io.in_valid(c)) {
      case (in_v, tile) =>
        tile.io.in_valid := ShiftRegister(in_v, tile_latency + 1)
        tile.io.out_valid
    }
  }

  // Chain in_id (pipeline across each column)
  for (c <- 0 until meshColumns) {
    meshT(c).foldLeft(io.in_id(c)) {
      case (in_id, tile) =>
        tile.io.in_id := ShiftRegister(in_id, tile_latency + 1)
        tile.io.out_id
    }
  }

  // Chain in_last (pipeline across each column)
  for (c <- 0 until meshColumns) {
    meshT(c).foldLeft(io.in_last(c)) {
      case (in_last, tile) =>
        tile.io.in_last := ShiftRegister(in_last, tile_latency + 1)
        tile.io.out_last
    }
  }

  // Capture out_vec and out_control_vec (connect IO to bottom row of mesh)
  // (The only reason we have so many zips is because Scala doesn't provide a zipped function for Tuple4)
  for (((((((b, c), v), ctrl), id), last), tile) <- io.out_b zip io.out_c zip io.out_valid zip io.out_control zip io.out_id zip io.out_last zip mesh.last) {
    // TODO we pipelined this to make physical design easier. Consider removing these if possible
    // TODO shouldn't we clock-gate these signals with "garbage" as well?
    b := ShiftRegister(tile.io.out_b, output_delay)
    c := ShiftRegister(tile.io.out_c, output_delay)
    v := ShiftRegister(tile.io.out_valid, output_delay)
    ctrl := ShiftRegister(tile.io.out_control, output_delay)
    id := ShiftRegister(tile.io.out_id, output_delay)
    last := ShiftRegister(tile.io.out_last, output_delay)
  }
}

0 comments on commit 2f62f3e

Please sign in to comment.