Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vrf read #543

Merged
merged 2 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,000 changes: 500 additions & 500 deletions .github/cases/blastoise/default.json

Large diffs are not rendered by default.

872 changes: 436 additions & 436 deletions .github/cases/machamp/default.json

Large diffs are not rendered by default.

872 changes: 436 additions & 436 deletions .github/cases/sandslash/default.json

Large diffs are not rendered by default.

5 changes: 0 additions & 5 deletions t1/src/Bundles.scala
Original file line number Diff line number Diff line change
Expand Up @@ -699,8 +699,3 @@ final class EmptyBundle extends Bundle
class VRFReadPipe(size: BigInt) extends Bundle {
val address: UInt = UInt(log2Ceil(size).W)
}

class DataPipeInReadStage(dataWidth: Int, arbitrate: Boolean) extends Bundle {
val data: UInt = UInt(dataWidth.W)
val choose: Option[Bool] = Option.when(arbitrate)(Bool())
}
18 changes: 18 additions & 0 deletions t1/src/Lane.scala
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,13 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
)
)

// 3 * slot + 2 cross read
val readCheckRequestVec: Vec[VRFReadRequest] = Wire(Vec(parameter.chainingSize * 3 + 2,
new VRFReadRequest(parameter.vrfParam.regNumBits, parameter.vrfOffsetBits, parameter.instructionIndexBits)
))

val readCheckResult: Vec[Bool] = Wire(Vec(parameter.chainingSize * 3 + 2, Bool()))

/** signal used for prohibiting slots to access VRF.
* a slot will become inactive when:
* 1. cross lane read/write is not finished
Expand Down Expand Up @@ -638,6 +645,14 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
stage1.readFromScalar := record.laneRequest.readFromScalar
vrfReadRequest(index).zip(stage1.vrfReadRequest).foreach{ case (sink, source) => sink <> source }
vrfReadResult(index).zip(stage1.vrfReadResult).foreach{ case (source, sink) => sink := source }
// 3: read vs1 vs2 vd
// 2: cross read lsb & msb
val checkSize = if (isLastSlot) 5 else 3
Seq.tabulate(checkSize){ portIndex =>
// parameter.chainingSize - index: slot 0 need 5 port, so reverse connection
readCheckRequestVec((parameter.chainingSize - index - 1) * 3 + portIndex) := stage1.vrfCheckRequest(portIndex)
stage1.checkResult(portIndex) := readCheckResult((parameter.chainingSize - index - 1) * 3 + portIndex)
}
// connect cross read bus
if(isLastSlot) {
val tokenSize = parameter.crossLaneVRFWriteEscapeQueueSize
Expand Down Expand Up @@ -869,6 +884,9 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
}
val checkResult = vrf.writeAllow.asUInt

vrf.readCheck.zip(readCheckRequestVec).foreach{case (sink, source) => sink := source}
readCheckResult.zip(vrf.readCheckResult).foreach{case (sink, source) => sink := source}

// Arbiter
val writeSelect: UInt = ffo(checkResult & VecInit(allVrfWrite.map(_.valid)).asUInt)
allVrfWrite.zipWithIndex.foreach{ case (p, i) => p.ready := writeSelect(i) && queueBeforeMaskWrite.io.enq.ready }
Expand Down
196 changes: 132 additions & 64 deletions t1/src/laneStage/LaneStage1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class LaneStage1Dequeue(parameter: LaneParameter, isLastSlot: Boolean) extends B
* */
@instantiable
class LaneStage1(parameter: LaneParameter, isLastSlot: Boolean) extends Module {
val readRequestType: VRFReadRequest =
new VRFReadRequest(parameter.vrfParam.regNumBits, parameter.vrfOffsetBits, parameter.instructionIndexBits)
@public
val enqueue = IO(Flipped(Decoupled(new LaneStage1Enqueue(parameter, isLastSlot))))
@public
Expand All @@ -42,14 +44,14 @@ class LaneStage1(parameter: LaneParameter, isLastSlot: Boolean) extends Module {
@public
val state: LaneState = IO(Input(new LaneState(parameter)))
@public
val vrfReadRequest: Vec[DecoupledIO[VRFReadRequest]] = IO(
Vec(
3,
Decoupled(
new VRFReadRequest(parameter.vrfParam.regNumBits, parameter.vrfOffsetBits, parameter.instructionIndexBits)
)
)
)
val vrfReadRequest: Vec[DecoupledIO[VRFReadRequest]] = IO(Vec(3, Decoupled(readRequestType)))

val readCheckSize: Int = if(isLastSlot) 5 else 3
@public
val vrfCheckRequest: Vec[VRFReadRequest] = IO(Vec(readCheckSize, Output(readRequestType)))

@public
val checkResult: Vec[Bool] = IO(Vec(readCheckSize, Input(Bool())))

/** VRF read result for each slot,
* 3 is for [[source1]] [[source2]] [[source3]]
Expand All @@ -72,46 +74,74 @@ class LaneStage1(parameter: LaneParameter, isLastSlot: Boolean) extends Module {
val groupCounter: UInt = enqueue.bits.groupCounter

// todo: param
val readRequestQueueSize: Int = 4
val readRequestQueueSizeBeforeCheck: Int = 4
val readRequestQueueSizeAfterCheck: Int = 4
val dataQueueSize: Int = 4
val vrfReadEntryType = new VRFReadQueueEntry(parameter.vrfParam.regNumBits, parameter.vrfOffsetBits)

// read request queue for vs1 vs2 vd
val readRequestQueueVs1: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSize))
val readRequestQueueVs2: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSize))
val readRequestQueueVd: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSize))
val queueAfterCheck1: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSizeAfterCheck))
val queueAfterCheck2: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSizeAfterCheck))
val queueAfterCheckVd: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSizeAfterCheck))

// read request queue for vs1 vs2 vd
val queueBeforeCheck1: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSizeBeforeCheck))
val queueBeforeCheck2: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSizeBeforeCheck))
val queueBeforeCheckVd: Queue[VRFReadQueueEntry] = Module(new Queue(vrfReadEntryType, readRequestQueueSizeBeforeCheck))

// read request queue for cross read lsb & msb
val queueAfterCheckLSB: Option[Queue[VRFReadQueueEntry]] =
Option.when(isLastSlot)(Module(new Queue(vrfReadEntryType, readRequestQueueSizeAfterCheck)))
val queueAfterCheckMSB: Option[Queue[VRFReadQueueEntry]] =
Option.when(isLastSlot)(Module(new Queue(vrfReadEntryType, readRequestQueueSizeAfterCheck)))

// read request queue for cross read lsb & msb
val readRequestQueueLSB: Option[Queue[VRFReadQueueEntry]] =
Option.when(isLastSlot)(Module(new Queue(vrfReadEntryType, readRequestQueueSize)))
val readRequestQueueMSB: Option[Queue[VRFReadQueueEntry]] =
Option.when(isLastSlot)(Module(new Queue(vrfReadEntryType, readRequestQueueSize)))
val queueBeforeCheckLSB: Option[Queue[VRFReadQueueEntry]] =
Option.when(isLastSlot)(Module(new Queue(vrfReadEntryType, readRequestQueueSizeBeforeCheck)))
val queueBeforeCheckMSB: Option[Queue[VRFReadQueueEntry]] =
Option.when(isLastSlot)(Module(new Queue(vrfReadEntryType, readRequestQueueSizeBeforeCheck)))

// pipe from enqueue
val pipeQueue: Queue[LaneStage1Enqueue] =
Module(new Queue(chiselTypeOf(enqueue.bits), readRequestQueueSize + dataQueueSize + 2))
Module(
new Queue(chiselTypeOf(enqueue.bits),
readRequestQueueSizeBeforeCheck + readRequestQueueSizeAfterCheck + dataQueueSize + 2
))
pipeQueue.io.enq.bits := enqueue.bits
pipeQueue.io.enq.valid := enqueue.fire
pipeQueue.io.deq.ready := dequeue.fire

val readQueueVec: Seq[Queue[VRFReadQueueEntry]] =
Seq(readRequestQueueVs1, readRequestQueueVs2, readRequestQueueVd) ++
readRequestQueueLSB ++ readRequestQueueMSB
val allReadQueueReady: Bool = readQueueVec.map(_.io.enq.ready).reduce(_ && _)
val allReadQueueEmpty: Bool = readQueueVec.map(!_.io.deq.valid).reduce(_ && _)
readQueueVec.foreach(q => q.io.enq.bits.instructionIndex := state.instructionIndex)
val beforeCheckQueueVec: Seq[Queue[VRFReadQueueEntry]] =
Seq(queueBeforeCheck1, queueBeforeCheck2, queueBeforeCheckVd) ++
queueBeforeCheckLSB ++ queueBeforeCheckMSB
val afterCheckQueueVec: Seq[Queue[VRFReadQueueEntry]] =
Seq(queueAfterCheck1, queueAfterCheck2, queueAfterCheckVd) ++
queueAfterCheckLSB ++ queueAfterCheckMSB
val allReadQueueReady: Bool = beforeCheckQueueVec.map(_.io.enq.ready).reduce(_ && _)
beforeCheckQueueVec.foreach{ q =>
q.io.enq.bits.instructionIndex := state.instructionIndex
q.io.enq.bits.groupIndex := enqueue.bits.groupCounter
}

enqueue.ready := allReadQueueReady && pipeQueue.io.enq.ready

// chaining check
beforeCheckQueueVec.zip(afterCheckQueueVec).zipWithIndex.foreach { case ((before, after), i) =>
vrfCheckRequest(i) := before.io.deq.bits
before.io.deq.ready := after.io.enq.ready && checkResult(i)
after.io.enq.valid := before.io.deq.valid && checkResult(i)
after.io.enq.bits := before.io.deq.bits
}
// request enqueue
readRequestQueueVs1.io.enq.valid := enqueue.fire && state.decodeResult(Decoder.vtype) && !state.skipRead
readRequestQueueVs2.io.enq.valid := enqueue.fire && !state.skipRead
readRequestQueueVd.io.enq.valid := enqueue.fire && !state.decodeResult(Decoder.sReadVD)
(readRequestQueueLSB ++ readRequestQueueMSB).foreach { q =>
queueBeforeCheck1.io.enq.valid := enqueue.fire && state.decodeResult(Decoder.vtype) && !state.skipRead
queueBeforeCheck2.io.enq.valid := enqueue.fire && !state.skipRead
queueBeforeCheckVd.io.enq.valid := enqueue.fire && !state.decodeResult(Decoder.sReadVD)
(queueBeforeCheckLSB ++ queueBeforeCheckMSB).foreach { q =>
q.io.enq.valid := enqueue.valid && allReadQueueReady && state.decodeResult(Decoder.crossRead)
}

// calculate vs
readRequestQueueVs1.io.enq.bits.vs := Mux(
queueBeforeCheck1.io.enq.bits.vs := Mux(
// encodings with vm=0 are reserved for mask type logic
state.decodeResult(Decoder.maskLogic) && !state.decodeResult(Decoder.logic),
// read v0 for (15. Vector Mask Instructions)
Expand All @@ -121,25 +151,25 @@ class LaneStage1(parameter: LaneParameter, isLastSlot: Boolean) extends Module {
parameter.vrfOffsetBits
)
)
readRequestQueueVs1.io.enq.bits.readSource := Mux(
queueBeforeCheck1.io.enq.bits.readSource := Mux(
state.decodeResult(Decoder.maskLogic) && !state.decodeResult(Decoder.logic),
3.U,
0.U
)
readRequestQueueVs2.io.enq.bits.vs := state.vs2 +
queueBeforeCheck2.io.enq.bits.vs := state.vs2 +
groupCounter(parameter.groupNumberBits - 1, parameter.vrfOffsetBits)
readRequestQueueVs2.io.enq.bits.readSource := 1.U
readRequestQueueVd.io.enq.bits.vs := state.vd +
queueBeforeCheck2.io.enq.bits.readSource := 1.U
queueBeforeCheckVd.io.enq.bits.vs := state.vd +
groupCounter(parameter.groupNumberBits - 1, parameter.vrfOffsetBits)
readRequestQueueVd.io.enq.bits.readSource := 2.U
queueBeforeCheckVd.io.enq.bits.readSource := 2.U

// calculate offset
readRequestQueueVs1.io.enq.bits.offset := groupCounter(parameter.vrfOffsetBits - 1, 0)
readRequestQueueVs2.io.enq.bits.offset := groupCounter(parameter.vrfOffsetBits - 1, 0)
readRequestQueueVd.io.enq.bits.offset := groupCounter(parameter.vrfOffsetBits - 1, 0)
queueBeforeCheck1.io.enq.bits.offset := groupCounter(parameter.vrfOffsetBits - 1, 0)
queueBeforeCheck2.io.enq.bits.offset := groupCounter(parameter.vrfOffsetBits - 1, 0)
queueBeforeCheckVd.io.enq.bits.offset := groupCounter(parameter.vrfOffsetBits - 1, 0)

// cross read enqueue
readRequestQueueLSB.foreach { q =>
queueBeforeCheckLSB.foreach { q =>
q.io.enq.bits.vs := Mux(
state.decodeResult(Decoder.vwmacc),
// cross read vd for vwmacc, since it need dual [[dataPathWidth]], use vs2 port to read LSB part of it.
Expand All @@ -151,7 +181,7 @@ class LaneStage1(parameter: LaneParameter, isLastSlot: Boolean) extends Module {
q.io.enq.bits.offset := groupCounter(parameter.vrfOffsetBits - 2, 0) ## false.B
}

readRequestQueueMSB.foreach { q =>
queueBeforeCheckMSB.foreach { q =>
q.io.enq.bits.vs := Mux(
state.decodeResult(Decoder.vwmacc),
// cross read vd for vwmacc
Expand All @@ -163,23 +193,12 @@ class LaneStage1(parameter: LaneParameter, isLastSlot: Boolean) extends Module {
q.io.enq.bits.offset := groupCounter(parameter.vrfOffsetBits - 2, 0) ## true.B
}

// todo: for debug
readQueueVec.foreach {q => q.io.enq.bits.groupIndex := enqueue.bits.groupCounter}

// read pipe
val readPipe0: Instance[VrfReadPipe] = Instantiate(new VrfReadPipe(parameter, arbitrate = false))
val readPipe1: Instance[VrfReadPipe] = Instantiate(new VrfReadPipe(parameter, arbitrate = isLastSlot))
val readPipe2: Instance[VrfReadPipe] = Instantiate(new VrfReadPipe(parameter, arbitrate = isLastSlot))
val pipeVec: Seq[Instance[VrfReadPipe]] = Seq(readPipe0, readPipe1, readPipe2)

readPipe0.enqueue <> readRequestQueueVs1.io.deq
readPipe1.enqueue <> readRequestQueueVs2.io.deq
readPipe2.enqueue <> readRequestQueueVd.io.deq

// contender for cross read
readPipe1.contender.zip(readRequestQueueLSB).foreach { case (port, queue) => port <> queue.io.deq }
readPipe2.contender.zip(readRequestQueueMSB).foreach { case (port, queue) => port <> queue.io.deq }

// read port connect
vrfReadRequest.zip(pipeVec).foreach { case (port, pipe) => port <> pipe.vrfReadRequest }
vrfReadResult.zip(pipeVec).foreach { case (result, pipe) => pipe.vrfReadResult := result }
Expand All @@ -192,27 +211,71 @@ class LaneStage1(parameter: LaneParameter, isLastSlot: Boolean) extends Module {
val dataQueueLSB = Option.when(isLastSlot)(Module(new Queue(UInt(parameter.datapathWidth.W), dataQueueSize)))
val dataQueueMSB = Option.when(isLastSlot)(Module(new Queue(UInt(parameter.datapathWidth.W), dataQueueSize)))

val dataQueueNotFull2: Bool = {
val counterReg = RegInit(0.U(log2Ceil(dataQueueSize + 1).W))
val doEnq = queueAfterCheck2.io.deq.fire
val doDeq = dataQueueVs2.io.deq.fire
val countChange = Mux(doEnq, 1.U, -1.S(log2Ceil(dataQueueSize + 1).W).asUInt)
when(doEnq ^ doDeq) {
counterReg := counterReg + countChange
}
!counterReg(log2Ceil(dataQueueSize))
}

val dataQueueNotFullVd: Bool = {
val counterReg = RegInit(0.U(log2Ceil(dataQueueSize + 1).W))
val doEnq = queueAfterCheckVd.io.deq.fire
val doDeq = dataQueueVd.io.deq.fire
val countChange = Mux(doEnq, 1.U, -1.S(log2Ceil(dataQueueSize + 1).W).asUInt)
when(doEnq ^ doDeq) {
counterReg := counterReg + countChange
}
!counterReg(log2Ceil(dataQueueSize))
}

readPipe0.enqueue <> queueAfterCheck1.io.deq
blockingHandshake(readPipe1.enqueue, queueAfterCheck2.io.deq, dataQueueNotFull2)
blockingHandshake(readPipe2.enqueue, queueAfterCheckVd.io.deq, dataQueueNotFullVd)

// contender for cross read
readPipe1.contender.zip(queueAfterCheckLSB).foreach { case (port, queue) =>
val dataQueueNotFullLSB: Bool = {
val counterReg = RegInit(0.U(log2Ceil(dataQueueSize + 1).W))
val doEnq = queue.io.deq.fire
val doDeq = dataQueueLSB.get.io.deq.fire
val countChange = Mux(doEnq, 1.U, -1.S(log2Ceil(dataQueueSize + 1).W).asUInt)
when(doEnq ^ doDeq) {
counterReg := counterReg + countChange
}
!counterReg(log2Ceil(dataQueueSize))
}
blockingHandshake(port, queue.io.deq, dataQueueNotFullLSB)
}
readPipe2.contender.zip(queueAfterCheckMSB).foreach { case (port, queue) =>
val dataQueueNotFullMSB: Bool = {
val counterReg = RegInit(0.U(log2Ceil(dataQueueSize + 1).W))
val doEnq = queue.io.deq.fire
val doDeq = dataQueueMSB.get.io.deq.fire
val countChange = Mux(doEnq, 1.U, -1.S(log2Ceil(dataQueueSize + 1).W).asUInt)
when(doEnq ^ doDeq) {
counterReg := counterReg + countChange
}
!counterReg(log2Ceil(dataQueueSize))
}
blockingHandshake(port, queue.io.deq, dataQueueNotFullMSB)
}

// data: pipe <-> queue
if (isLastSlot) {
// pipe1 <-> dataQueueVs2
dataQueueVs2.io.enq.valid := readPipe1.dequeue.valid && readPipe1.dequeueChoose.get
dataQueueVs2.io.enq.bits := readPipe1.dequeue.bits
dataQueueVs2.io.enq <> readPipe1.dequeue
// pipe1 <> dataQueueLSB
dataQueueLSB.get.io.enq.valid := readPipe1.dequeue.valid && !readPipe1.dequeueChoose.get
dataQueueLSB.get.io.enq.bits := readPipe1.dequeue.bits
// ready select
readPipe1.dequeue.ready :=
Mux(readPipe1.dequeueChoose.get, dataQueueVs2.io.enq.ready, dataQueueLSB.get.io.enq.ready)
dataQueueLSB.zip(readPipe1.contenderDequeue).foreach { case (sink, source) => sink.io.enq <> source }

// pipe2 <-> dataQueueVd
dataQueueVd.io.enq.valid := readPipe2.dequeue.valid && readPipe2.dequeueChoose.get
dataQueueVd.io.enq.bits := readPipe2.dequeue.bits
dataQueueVd.io.enq <> readPipe2.dequeue
// pipe2 <-> dataQueueMSB
dataQueueMSB.get.io.enq.valid := readPipe2.dequeue.valid && !readPipe2.dequeueChoose.get
dataQueueMSB.get.io.enq.bits := readPipe2.dequeue.bits
// ready select
readPipe2.dequeue.ready :=
Mux(readPipe2.dequeueChoose.get, dataQueueVd.io.enq.ready, dataQueueMSB.get.io.enq.ready)
dataQueueMSB.zip(readPipe2.contenderDequeue).foreach { case (sink, source) => sink.io.enq <> source }
} else {
dataQueueVs2.io.enq <> readPipe1.dequeue
dataQueueVd.io.enq <> readPipe2.dequeue
Expand All @@ -225,7 +288,12 @@ class LaneStage1(parameter: LaneParameter, isLastSlot: Boolean) extends Module {
val crossReadUnitOp: Option[Instance[CrossReadUnit]] = Option.when(isLastSlot)(Instantiate(new CrossReadUnit(parameter)))
if (isLastSlot) {
val dataGroupQueue: Queue[UInt] =
Module(new Queue(UInt(parameter.groupNumberBits.W), readRequestQueueSize + dataQueueSize + 2))
Module(
new Queue(
UInt(parameter.groupNumberBits.W),
readRequestQueueSizeBeforeCheck + readRequestQueueSizeBeforeCheck + dataQueueSize + 2
)
)
val crossReadUnit = crossReadUnitOp.get
crossReadUnit.dataInputLSB <> dataQueueLSB.get.io.deq
crossReadUnit.dataInputMSB <> dataQueueMSB.get.io.deq
Expand Down
Loading
Loading