Skip to content

Commit

Permalink
[rtl] Check in advance whether vrf can be read.
Browse files Browse the repository at this point in the history
  • Loading branch information
qinjun-li committed Apr 24, 2024
1 parent 149bc2f commit b7091a1
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 48 deletions.
18 changes: 18 additions & 0 deletions t1/src/Lane.scala
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,13 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
)
)

// 3 * slot + 2 cross read
val readCheckRequestVec: Vec[VRFReadRequest] = Wire(Vec(parameter.chainingSize * 3 + 2,
new VRFReadRequest(parameter.vrfParam.regNumBits, parameter.vrfOffsetBits, parameter.instructionIndexBits)
))

val readCheckResult: Vec[Bool] = Wire(Vec(parameter.chainingSize * 3 + 2, Bool()))

/** signal used for prohibiting slots to access VRF.
* a slot will become inactive when:
* 1. cross lane read/write is not finished
Expand Down Expand Up @@ -623,6 +630,14 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
stage1.readFromScalar := record.laneRequest.readFromScalar
vrfReadRequest(index).zip(stage1.vrfReadRequest).foreach{ case (sink, source) => sink <> source }
vrfReadResult(index).zip(stage1.vrfReadResult).foreach{ case (source, sink) => sink := source }
// 3: read vs1 vs2 vd
// 2: cross read lsb & msb
val checkSize = if (isLastSlot) 5 else 3
Seq.tabulate(checkSize){ portIndex =>
// parameter.chainingSize - index: slot 0 need 5 port, so reverse connection
readCheckRequestVec((parameter.chainingSize - index - 1) * 3 + portIndex) := stage1.vrfCheckRequest(portIndex)
stage1.checkResult(portIndex) := readCheckResult((parameter.chainingSize - index - 1) * 3 + portIndex)
}
// connect cross read bus
if(isLastSlot) {
val tokenSize = parameter.crossLaneVRFWriteEscapeQueueSize
Expand Down Expand Up @@ -855,6 +870,9 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
}
val checkResult = vrf.writeAllow.asUInt

vrf.readCheck.zip(readCheckRequestVec).foreach{case (sink, source) => sink := source}
readCheckResult.zip(vrf.readCheckResult).foreach{case (sink, source) => sink := source}

// Arbiter
val writeSelect: UInt = ffo(checkResult & VecInit(allVrfWrite.map(_.valid)).asUInt)
allVrfWrite.zipWithIndex.foreach{ case (p, i) => p.ready := writeSelect(i) && queueBeforeMaskWrite.io.enq.ready }
Expand Down
121 changes: 74 additions & 47 deletions t1/src/laneStage/LaneStage1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class LaneStage1Dequeue(parameter: LaneParameter, isLastSlot: Boolean) extends B
* */
@instantiable
class LaneStage1(parameter: LaneParameter, isLastSlot: Boolean) extends Module {
val readRequestType: VRFReadRequest =
new VRFReadRequest(parameter.vrfParam.regNumBits, parameter.vrfOffsetBits, parameter.instructionIndexBits)
@public
val enqueue = IO(Flipped(Decoupled(new LaneStage1Enqueue(parameter, isLastSlot))))
@public
Expand All @@ -42,14 +44,14 @@ class LaneStage1(parameter: LaneParameter, isLastSlot: Boolean) extends Module {
@public
val state: LaneState = IO(Input(new LaneState(parameter)))
@public
val vrfReadRequest: Vec[DecoupledIO[VRFReadRequest]] = IO(
Vec(
3,
Decoupled(
new VRFReadRequest(parameter.vrfParam.regNumBits, parameter.vrfOffsetBits, parameter.instructionIndexBits)
)
)
)
val vrfReadRequest: Vec[DecoupledIO[VRFReadRequest]] = IO(Vec(3, Decoupled(readRequestType)))

val readCheckSize: Int = if(isLastSlot) 5 else 3
@public
val vrfCheckRequest: Vec[VRFReadRequest] = IO(Vec(readCheckSize, Output(readRequestType)))

@public
val checkResult: Vec[Bool] = IO(Vec(readCheckSize, Input(Bool())))

/** VRF read result for each slot,
* 3 is for [[source1]] [[source2]] [[source3]]
Expand All @@ -72,46 +74,74 @@ class LaneStage1(parameter: LaneParameter, isLastSlot: Boolean) extends Module {
val groupCounter: UInt = enqueue.bits.groupCounter

// todo: param
val readRequestQueueSize: Int = 4
val readRequestQueueSizeBeforeCheck: Int = 4
val readRequestQueueSizeAfterCheck: Int = 4
val dataQueueSize: Int = 4
val vrfReadEntryType = new VRFReadQueueEntry(parameter.vrfParam.regNumBits, parameter.vrfOffsetBits)

// read request queue for vs1 vs2 vd
val readRequestQueueVs1: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSize))
val readRequestQueueVs2: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSize))
val readRequestQueueVd: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSize))
val queueAfterCheck1: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSizeAfterCheck))
val queueAfterCheck2: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSizeAfterCheck))
val queueAfterCheckVd: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSizeAfterCheck))

// read request queue for vs1 vs2 vd
val queueBeforeCheck1: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSizeBeforeCheck))
val queueBeforeCheck2: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSizeBeforeCheck))
val queueBeforeCheckVd: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSizeBeforeCheck))

// read request queue for cross read lsb & msb
val queueAfterCheckLSB: Option[Queue[VRFReadQueueEntry]] =
Option.when(isLastSlot)(Module(new Queue(vrfReadEntryType, readRequestQueueSizeAfterCheck)))
val queueAfterCheckMSB: Option[Queue[VRFReadQueueEntry]] =
Option.when(isLastSlot)(Module(new Queue(vrfReadEntryType, readRequestQueueSizeAfterCheck)))

// read request queue for cross read lsb & msb
val readRequestQueueLSB: Option[Queue[VRFReadQueueEntry]] =
Option.when(isLastSlot)(Module(new Queue(vrfReadEntryType, readRequestQueueSize)))
val readRequestQueueMSB: Option[Queue[VRFReadQueueEntry]] =
Option.when(isLastSlot)(Module(new Queue(vrfReadEntryType, readRequestQueueSize)))
val queueBeforeCheckLSB: Option[Queue[VRFReadQueueEntry]] =
Option.when(isLastSlot)(Module(new Queue(vrfReadEntryType, readRequestQueueSizeBeforeCheck)))
val queueBeforeCheckMSB: Option[Queue[VRFReadQueueEntry]] =
Option.when(isLastSlot)(Module(new Queue(vrfReadEntryType, readRequestQueueSizeBeforeCheck)))

// pipe from enqueue
val pipeQueue: Queue[LaneStage1Enqueue] =
Module(new Queue(chiselTypeOf(enqueue.bits), readRequestQueueSize + dataQueueSize + 2))
Module(
new Queue(chiselTypeOf(enqueue.bits),
readRequestQueueSizeBeforeCheck + readRequestQueueSizeAfterCheck + dataQueueSize + 2
))
pipeQueue.io.enq.bits := enqueue.bits
pipeQueue.io.enq.valid := enqueue.fire
pipeQueue.io.deq.ready := dequeue.fire

val readQueueVec: Seq[Queue[VRFReadQueueEntry]] =
Seq(readRequestQueueVs1, readRequestQueueVs2, readRequestQueueVd) ++
readRequestQueueLSB ++ readRequestQueueMSB
val allReadQueueReady: Bool = readQueueVec.map(_.io.enq.ready).reduce(_ && _)
val allReadQueueEmpty: Bool = readQueueVec.map(!_.io.deq.valid).reduce(_ && _)
readQueueVec.foreach(q => q.io.enq.bits.instructionIndex := state.instructionIndex)
val beforeCheckQueueVec: Seq[Queue[VRFReadQueueEntry]] =
Seq(queueBeforeCheck1, queueBeforeCheck2, queueBeforeCheckVd) ++
queueBeforeCheckLSB ++ queueBeforeCheckMSB
val afterCheckQueueVec: Seq[Queue[VRFReadQueueEntry]] =
Seq(queueAfterCheck1, queueAfterCheck2, queueAfterCheckVd) ++
queueAfterCheckLSB ++ queueAfterCheckMSB
val allReadQueueReady: Bool = beforeCheckQueueVec.map(_.io.enq.ready).reduce(_ && _)
beforeCheckQueueVec.foreach{ q =>
q.io.enq.bits.instructionIndex := state.instructionIndex
q.io.enq.bits.groupIndex := enqueue.bits.groupCounter
}

enqueue.ready := allReadQueueReady && pipeQueue.io.enq.ready

// chaining check
beforeCheckQueueVec.zip(afterCheckQueueVec).zipWithIndex.foreach { case ((before, after), i) =>
vrfCheckRequest(i) := before.io.deq.bits
before.io.deq.ready := after.io.enq.ready && checkResult(i)
after.io.enq.valid := before.io.deq.valid && checkResult(i)
after.io.enq.bits := before.io.deq.bits
}
// request enqueue
readRequestQueueVs1.io.enq.valid := enqueue.fire && state.decodeResult(Decoder.vtype) && !state.skipRead
readRequestQueueVs2.io.enq.valid := enqueue.fire && !state.skipRead
readRequestQueueVd.io.enq.valid := enqueue.fire && !state.decodeResult(Decoder.sReadVD)
(readRequestQueueLSB ++ readRequestQueueMSB).foreach { q =>
queueBeforeCheck1.io.enq.valid := enqueue.fire && state.decodeResult(Decoder.vtype) && !state.skipRead
queueBeforeCheck2.io.enq.valid := enqueue.fire && !state.skipRead
queueBeforeCheckVd.io.enq.valid := enqueue.fire && !state.decodeResult(Decoder.sReadVD)
(queueBeforeCheckLSB ++ queueBeforeCheckMSB).foreach { q =>
q.io.enq.valid := enqueue.valid && allReadQueueReady && state.decodeResult(Decoder.crossRead)
}

// calculate vs
readRequestQueueVs1.io.enq.bits.vs := Mux(
queueBeforeCheck1.io.enq.bits.vs := Mux(
// encodings with vm=0 are reserved for mask type logic
state.decodeResult(Decoder.maskLogic) && !state.decodeResult(Decoder.logic),
// read v0 for (15. Vector Mask Instructions)
Expand All @@ -121,25 +151,25 @@ class LaneStage1(parameter: LaneParameter, isLastSlot: Boolean) extends Module {
parameter.vrfOffsetBits
)
)
readRequestQueueVs1.io.enq.bits.readSource := Mux(
queueBeforeCheck1.io.enq.bits.readSource := Mux(
state.decodeResult(Decoder.maskLogic) && !state.decodeResult(Decoder.logic),
3.U,
0.U
)
readRequestQueueVs2.io.enq.bits.vs := state.vs2 +
queueBeforeCheck2.io.enq.bits.vs := state.vs2 +
groupCounter(parameter.groupNumberBits - 1, parameter.vrfOffsetBits)
readRequestQueueVs2.io.enq.bits.readSource := 1.U
readRequestQueueVd.io.enq.bits.vs := state.vd +
queueBeforeCheck2.io.enq.bits.readSource := 1.U
queueBeforeCheckVd.io.enq.bits.vs := state.vd +
groupCounter(parameter.groupNumberBits - 1, parameter.vrfOffsetBits)
readRequestQueueVd.io.enq.bits.readSource := 2.U
queueBeforeCheckVd.io.enq.bits.readSource := 2.U

// calculate offset
readRequestQueueVs1.io.enq.bits.offset := groupCounter(parameter.vrfOffsetBits - 1, 0)
readRequestQueueVs2.io.enq.bits.offset := groupCounter(parameter.vrfOffsetBits - 1, 0)
readRequestQueueVd.io.enq.bits.offset := groupCounter(parameter.vrfOffsetBits - 1, 0)
queueBeforeCheck1.io.enq.bits.offset := groupCounter(parameter.vrfOffsetBits - 1, 0)
queueBeforeCheck2.io.enq.bits.offset := groupCounter(parameter.vrfOffsetBits - 1, 0)
queueBeforeCheckVd.io.enq.bits.offset := groupCounter(parameter.vrfOffsetBits - 1, 0)

// cross read enqueue
readRequestQueueLSB.foreach { q =>
queueBeforeCheckLSB.foreach { q =>
q.io.enq.bits.vs := Mux(
state.decodeResult(Decoder.vwmacc),
// cross read vd for vwmacc, since it need dual [[dataPathWidth]], use vs2 port to read LSB part of it.
Expand All @@ -151,7 +181,7 @@ class LaneStage1(parameter: LaneParameter, isLastSlot: Boolean) extends Module {
q.io.enq.bits.offset := groupCounter(parameter.vrfOffsetBits - 2, 0) ## false.B
}

readRequestQueueMSB.foreach { q =>
queueBeforeCheckMSB.foreach { q =>
q.io.enq.bits.vs := Mux(
state.decodeResult(Decoder.vwmacc),
// cross read vd for vwmacc
Expand All @@ -163,22 +193,19 @@ class LaneStage1(parameter: LaneParameter, isLastSlot: Boolean) extends Module {
q.io.enq.bits.offset := groupCounter(parameter.vrfOffsetBits - 2, 0) ## true.B
}

// todo: for debug
readQueueVec.foreach {q => q.io.enq.bits.groupIndex := enqueue.bits.groupCounter}

// read pipe
val readPipe0: Instance[VrfReadPipe] = Instantiate(new VrfReadPipe(parameter, arbitrate = false))
val readPipe1: Instance[VrfReadPipe] = Instantiate(new VrfReadPipe(parameter, arbitrate = isLastSlot))
val readPipe2: Instance[VrfReadPipe] = Instantiate(new VrfReadPipe(parameter, arbitrate = isLastSlot))
val pipeVec: Seq[Instance[VrfReadPipe]] = Seq(readPipe0, readPipe1, readPipe2)

readPipe0.enqueue <> readRequestQueueVs1.io.deq
readPipe1.enqueue <> readRequestQueueVs2.io.deq
readPipe2.enqueue <> readRequestQueueVd.io.deq
readPipe0.enqueue <> queueAfterCheck1.io.deq
readPipe1.enqueue <> queueAfterCheck2.io.deq
readPipe2.enqueue <> queueAfterCheckVd.io.deq

// contender for cross read
readPipe1.contender.zip(readRequestQueueLSB).foreach { case (port, queue) => port <> queue.io.deq }
readPipe2.contender.zip(readRequestQueueMSB).foreach { case (port, queue) => port <> queue.io.deq }
readPipe1.contender.zip(queueAfterCheckLSB).foreach { case (port, queue) => port <> queue.io.deq }
readPipe2.contender.zip(queueAfterCheckMSB).foreach { case (port, queue) => port <> queue.io.deq }

// read port connect
vrfReadRequest.zip(pipeVec).foreach { case (port, pipe) => port <> pipe.vrfReadRequest }
Expand Down Expand Up @@ -225,7 +252,7 @@ class LaneStage1(parameter: LaneParameter, isLastSlot: Boolean) extends Module {
val crossReadUnitOp: Option[Instance[CrossReadUnit]] = Option.when(isLastSlot)(Instantiate(new CrossReadUnit(parameter)))
if (isLastSlot) {
val dataGroupQueue: Queue[UInt] =
Module(new Queue(UInt(parameter.groupNumberBits.W), readRequestQueueSize + dataQueueSize + 2))
Module(new Queue(UInt(parameter.groupNumberBits.W), readRequestQueueSizeBeforeCheck + dataQueueSize + 2))
val crossReadUnit = crossReadUnitOp.get
crossReadUnit.dataInputLSB <> dataQueueLSB.get.io.deq
crossReadUnit.dataInputMSB <> dataQueueMSB.get.io.deq
Expand Down
29 changes: 28 additions & 1 deletion t1/src/vrf/VRF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,15 @@ class VRF(val parameter: VRFParam) extends Module with SerializableModule[VRFPar
)
)

// 3 * slot + 2 cross read
@public
val readCheck: Vec[VRFReadRequest] = IO(Vec(parameter.chainingSize * 3 + 2, Input(
new VRFReadRequest(parameter.regNumBits, parameter.vrfOffsetBits, parameter.instructionIndexBits)
)))

@public
val readCheckResult: Vec[Bool] = IO(Vec(parameter.chainingSize * 3 + 2, Output(Bool())))

/** VRF read results. */
@public
val readResults: Vec[UInt] = IO(Output(Vec(parameter.vrfReadPort, UInt(parameter.datapathWidth.W))))
Expand Down Expand Up @@ -251,6 +260,24 @@ class VRF(val parameter: VRFParam) extends Module with SerializableModule[VRFPar
when(write.fire) { writePipe.bits := write.bits }
val writeBankPipe: UInt = RegNext(writeBank)

// lane chaining check
readCheck.zip(readCheckResult).foreach { case (req, res) =>
val recordSelect = chainingRecord
// 先找到自的record
val readRecord =
Mux1H(recordSelect.map(_.bits.instIndex === req.instructionIndex), recordSelect.map(_.bits))
res :=
recordSelect.zip(recordValidVec).zipWithIndex.map {
case ((r, f), recordIndex) =>
val checkModule = Instantiate(new ChainingCheck(parameter))
checkModule.read := req
checkModule.readRecord := readRecord
checkModule.record := r
checkModule.recordValid := f
checkModule.checkResult
}.reduce(_ && _)
}

val checkSize: Int = readRequests.size
val (firstOccupied, secondOccupied) = readRequests.zipWithIndex.foldLeft(
(0.U(parameter.rfBankNum.W), 0.U(parameter.rfBankNum.W))
Expand All @@ -276,7 +303,7 @@ class VRF(val parameter: VRFParam) extends Module with SerializableModule[VRFPar
checkModule.recordValid := f
checkModule.checkResult
}.reduce(_ && _) && portConflictCheck
val validCorrect: Bool = if (i == 0) v.valid else v.valid && checkResult
val validCorrect: Bool = if (i == (readRequests.size - 1)) v.valid && checkResult else v.valid
// select bank
val bank = if (parameter.rfBankNum == 1) true.B else UIntToOH(v.bits.offset(log2Ceil(parameter.rfBankNum) - 1, 0))
val pipeBank = Pipe(true.B, bank, parameter.vrfReadLatency).bits
Expand Down

0 comments on commit b7091a1

Please sign in to comment.