Skip to content

Commit

Permalink
[rtl] refactor store unit.
Browse files Browse the repository at this point in the history
  • Loading branch information
qinjun-li committed Jan 16, 2024
1 parent 5a6a2aa commit 2be1c22
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 78 deletions.
4 changes: 4 additions & 0 deletions t1/src/V.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ case class VParameter(
/** for TileLink `mask` element. */
val maskWidth: Int = memoryDataWidth / 8

// todo
val vrfReadLatency = 1

// each element: Each lane will be connected to the other two lanes,
// and the values are their respective delays.
val crossLaneConnectCycles: Seq[Seq[Int]] = Seq.tabulate(laneNumber)(_ => Seq(1, 1))
Expand Down Expand Up @@ -159,6 +162,7 @@ case class VParameter(
lsuMSHRSize,
lsuVRFWriteQueueSize,
lsuTransposeSize,
vrfReadLatency,
tlParam
)
def vrfParam: VRFParam = VRFParam(vLen, laneNumber, datapathWidth, chainingSize, portFactor)
Expand Down
9 changes: 5 additions & 4 deletions t1/src/lsu/LSU.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ case class LSUParam(
lsuMSHRSize: Int,
lsuVRFWriteQueueSize: Int,
lsuTransposeSize: Int,
vrfReadLatency: Int,
tlParam: TLBundleParameter) {

/** see [[VParameter.maskGroupWidth]]. */
Expand All @@ -49,7 +50,7 @@ case class LSUParam(
val bankPosition: Int = log2Ceil(lsuTransposeSize)

def mshrParam: MSHRParam =
MSHRParam(chainingSize, datapathWidth, vLen, laneNumber, paWidth, lsuTransposeSize, memoryBankSize, tlParam)
MSHRParam(chainingSize, datapathWidth, vLen, laneNumber, paWidth, lsuTransposeSize, memoryBankSize, vrfReadLatency, tlParam)

/** see [[VRFParam.regNumBits]] */
val regNumBits: Int = log2Ceil(32)
Expand Down Expand Up @@ -312,9 +313,9 @@ class LSU(param: LSUParam) extends Module {
val storeEndLessThanLoadEnd: Bool = storeUnit.status.endAddress <= loadUnit.status.endAddress

val addressOverlap: Bool = ((storeStartLargerThanLoadStart && storeStartLessThanLoadEnd) ||
(storeEndLargerThanLoadStart && storeEndLessThanLoadEnd)) && !(storeUnit.status.idle || loadUnit.status.idle)
val stallLoad: Bool = !unitOrder && addressOverlap
val stallStore: Bool = unitOrder && addressOverlap
(storeEndLargerThanLoadStart && storeEndLessThanLoadEnd))
val stallLoad: Bool = !unitOrder && addressOverlap && !storeUnit.status.idle
val stallStore: Bool = unitOrder && addressOverlap && !loadUnit.status.idle

loadUnit.addressConflict := stallLoad
storeUnit.addressConflict := stallStore
Expand Down
1 change: 1 addition & 0 deletions t1/src/lsu/SimpleAccessUnit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ case class MSHRParam(
paWidth: Int,
lsuTransposeSize: Int,
memoryBankSize: Int,
vrfReadLatency: Int,
outerTLParam: TLBundleParameter) {

/** see [[LaneParameter.lmulMax]] */
Expand Down
175 changes: 101 additions & 74 deletions t1/src/lsu/StoreUnit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package org.chipsalliance.t1.rtl.lsu
import chisel3._
import chisel3.util._
import chisel3.probe._
import org.chipsalliance.t1.rtl.{VRFReadRequest, cutUInt, multiShifter}
import org.chipsalliance.t1.rtl.{EmptyBundle, VRFReadRequest, cutUInt, multiShifter}
import tilelink.TLChannelA

class cacheLineEnqueueBundle(param: MSHRParam) extends Bundle {
Expand Down Expand Up @@ -37,92 +37,107 @@ class StoreUnit(param: MSHRParam) extends StrideBase(param) with LSUPublic {
val vrfReadyToStore: Bool = IO(Input(Bool()))

// stage 0: handle vl, mask ...
val changeReadGroup: Bool = Wire(Bool())
val dataGroupByteSize: Int = param.datapathWidth * param.laneNumber / 8
val dataByteSize: UInt = (csrInterface.vl << lsuRequest.bits.instructionInformation.eew).asUInt
val lastDataGroupForInstruction: UInt = (dataByteSize >> log2Ceil(dataGroupByteSize)).asUInt -
!dataByteSize(log2Ceil(dataGroupByteSize) - 1, 0).orR
val lastDataGroupReg: UInt = RegEnable(lastDataGroupForInstruction, 0.U, lsuRequest.valid)
val nextDataGroup: UInt = Mux(lsuRequest.valid, -1.S(dataGroup.getWidth.W).asUInt, dataGroup + 1.U)
val isLastRead: Bool = nextDataGroup === lastDataGroupReg
val isLastRead: Bool = dataGroup === lastDataGroupReg
val lastGroupAndNeedAlign: Bool = initOffset.orR && isLastRead
val stage0Idle: Bool = RegEnable(
Mux(lsuRequest.valid, invalidInstruction, isLastRead),
true.B,
changeReadGroup || lsuRequest.valid
)
val readStageEnqueueReady: Bool = Wire(Bool())
changeReadGroup := readStageEnqueueReady && !stage0Idle

when(changeReadGroup || lsuRequest.valid) {
maskCounterInGroup := Mux(isLastDataGroup || lsuRequest.valid, 0.U, nextMaskCount)
when(isLastDataGroup && !isLastMaskGroup) {
maskSelect.valid := true.B
}
when((isLastDataGroup && !isLastMaskGroup) || lsuRequest.valid) {
maskGroupCounter := Mux(lsuRequest.valid, 0.U, nextMaskGroup)
}
dataGroup := nextDataGroup
}

// stage 1: read the VRF
val readStageValid: Bool = RegInit(false.B)
// todo: need hazardCheck?
val hazardCheck: Bool = RegEnable(vrfReadyToStore && !lsuRequest.valid, false.B, lsuRequest.valid || vrfReadyToStore)
val readData: Vec[UInt] = RegInit(VecInit(Seq.fill(param.laneNumber)(0.U(param.datapathWidth.W))))
val readMask: UInt = RegInit(0.U(param.maskGroupWidth.W))
val tailLeft1: Bool = RegInit(false.B)
// read-stage dequeue ready requires every source to be valid; alternatively, add a queue to coordinate them
val vrfReadQueueVec: Seq[Queue[UInt]] =
Seq.tabulate(param.laneNumber)(_ => Module(new Queue(UInt(param.datapathWidth.W), 2, flow = true, pipe = true)))

// read data from the VRF
Seq.tabulate(param.laneNumber) { laneIndex =>
val readStageValid: Bool = Seq.tabulate(param.laneNumber) { laneIndex =>
val readPort: DecoupledIO[VRFReadRequest] = vrfReadDataPorts(laneIndex)
readPort.valid := accessState(laneIndex) && readStageValid && hazardCheck
val segPtr: UInt = RegInit(0.U(3.W))
val readCount: UInt = RegInit(0.U(dataGroupBits.W))
val stageValid = RegInit(false.B)
// queue for read latency
val queue: Queue[UInt] = Module(new Queue(UInt(param.datapathWidth.W), param.vrfReadLatency, flow = true))

val lastReadPtr: Bool = segPtr === 0.U

val nextReadCount: UInt = Mux(lsuRequest.valid, 0.U(dataGroup.getWidth.W), readCount + 1.U)
val lastReadGroup: Bool = readCount === lastDataGroupReg

// update stageValid
when((lsuRequest.valid && !invalidInstruction) || (lastReadGroup && lastReadPtr && readPort.fire)) {
stageValid := lsuRequest.valid
}

// update segPtr
when(lsuRequest.valid || readPort.fire) {
segPtr := Mux(
lsuRequest.valid,
lsuRequest.bits.instructionInformation.nf,
Mux(
lastReadPtr,
lsuRequestReg.instructionInformation.nf,
segPtr - 1.U
)
)
}

// update readCount
when(lsuRequest.valid || (readPort.fire && lastReadPtr)) {
readCount := nextReadCount
}

// vrf read request
readPort.valid := stageValid && vrfReadQueueVec(laneIndex).io.enq.ready
readPort.bits.vs :=
lsuRequestReg.instructionInformation.vs3 +
accessPtr * segmentInstructionIndexInterval +
(dataGroup >> readPort.bits.offset.getWidth).asUInt
segPtr * segmentInstructionIndexInterval +
(readCount >> readPort.bits.offset.getWidth).asUInt
readPort.bits.readSource := 2.U
readPort.bits.offset := dataGroup
readPort.bits.offset := readCount
readPort.bits.instructionIndex := lsuRequestReg.instructionIndex
when(readPort.fire) {
accessState(laneIndex) := false.B
}
when(RegNext(readPort.fire, false.B)) {
readData(laneIndex) := vrfReadResults(laneIndex)
}
}

// 需要等待 sram 的结果返回
val readResponseCheck: Bool = RegNext(accessStateCheck, true.B)
val lastPtr: Bool = accessPtr === 0.U
val readStateCheck: Bool = accessStateCheck && readResponseCheck
val readStateValid: Bool = lastPtr && readStateCheck
val accessBufferDequeueReady: Bool = Wire(Bool())
val accessBufferDequeueFire: Bool = readStateValid && accessBufferDequeueReady && readStageValid
readStageEnqueueReady := !readStageValid || accessBufferDequeueFire
when(changeReadGroup ^ accessBufferDequeueFire) {
readStageValid := changeReadGroup
}
// pipe read fire
val readResultFire = Pipe(readPort.fire, 0.U.asTypeOf(new EmptyBundle), param.vrfReadLatency).valid

when(changeReadGroup) {
readMask := maskForGroupWire
tailLeft1 := lastGroupAndNeedAlign
}
// latency queue enq
queue.io.enq.valid := readResultFire
queue.io.enq.bits := vrfReadResults(laneIndex)
assert(!queue.io.enq.valid || queue.io.enq.ready)

vrfReadQueueVec(laneIndex).io.enq <> queue.io.deq
stageValid
}.reduce(_ || _)

when(changeReadGroup || (readStateCheck && !lastPtr)) {
// stage buffer stage: data before regroup
val bufferFull: Bool = RegInit(false.B)
val accessBufferDequeueReady: Bool = Wire(Bool())
val accessBufferEnqueueReady: Bool = !bufferFull || accessBufferDequeueReady
val accessBufferEnqueueValid: Bool = vrfReadQueueVec.map(_.io.deq.valid).reduce(_ && _)
val readQueueClear: Bool = !vrfReadQueueVec.map(_.io.deq.valid).reduce(_ || _)
val accessBufferEnqueueFire: Bool = accessBufferEnqueueValid && accessBufferEnqueueReady
val lastPtr: Bool = accessPtr === 0.U
val lastPtrEnq: Bool = lastPtr && accessBufferEnqueueFire
val accessBufferDequeueValid: Bool = bufferFull || lastPtrEnq
val accessBufferDequeueFire: Bool = accessBufferDequeueValid && accessBufferDequeueReady
vrfReadQueueVec.foreach(_.io.deq.ready := accessBufferEnqueueFire)
val accessDataUpdate: Vec[UInt] =
VecInit(VecInit(vrfReadQueueVec.map(_.io.deq.bits)).asUInt +: accessData.init)

when(lastPtrEnq ^ accessBufferDequeueFire) {
bufferFull := lastPtrEnq
}
when(accessBufferDequeueFire || accessBufferEnqueueFire) {
accessPtr := Mux(
changeReadGroup,
accessBufferDequeueFire || lastPtr,
lsuRequestReg.instructionInformation.nf,
accessPtr - 1.U
)
// 在更新ptr的时候把数据推进 [[accessData]] 里面
accessData := VecInit(readData.asUInt +: accessData.init)

// 更新access state
accessState := Mux(changeReadGroup, initSendState, initStateReg)
}

// changeReadGroup 可能会换 mask 所以需要存起来
when(changeReadGroup) {
initStateReg := initSendState
accessData := accessDataUpdate
}

// stage 2: use a buffer to hold the data converted into cache lines
Expand All @@ -138,16 +153,17 @@ class StoreUnit(param: MSHRParam) extends StrideBase(param) with LSUPublic {
val maskTemp: UInt = RegInit(0.U(param.lsuTransposeSize.W))
val tailValid: Bool = RegInit(false.B)
val isLastCacheLineInBuffer: Bool = cacheLineIndexInBuffer === lsuRequestReg.instructionInformation.nf
accessBufferDequeueReady := !bufferValid
val bufferStageEnqueueData: Vec[UInt] = VecInit(readData.asUInt +: accessData.init)
val bufferWillClear: Bool = alignedDequeueFire && isLastCacheLineInBuffer
accessBufferDequeueReady := !bufferValid || (alignedDequeue.ready && isLastCacheLineInBuffer)
val bufferStageEnqueueData: Vec[UInt] = Mux(bufferFull, accessData, accessDataUpdate)
// handle the mask: for segment types, one mask covers nf elements
val fillBySeg: UInt = Mux1H(UIntToOH(lsuRequestReg.instructionInformation.nf), Seq.tabulate(8) { segSize =>
FillInterleaved(segSize + 1, readMask)
FillInterleaved(segSize + 1, maskForGroupWire)
})
// regroup the data, then put it into [[dataBuffer]]
when(accessBufferDequeueFire) {
maskForBufferData := cutUInt(fillBySeg, param.lsuTransposeSize)
tailLeft2 := tailLeft1
tailLeft2 := lastGroupAndNeedAlign
// todo: this only works because the parameters happen to be square; the reverse case still needs to be written
dataBuffer := Mux1H(dataEEWOH, Seq.tabulate(3) { sewSize =>
// each data block is 2 ** sew bytes
Expand Down Expand Up @@ -179,16 +195,27 @@ class StoreUnit(param: MSHRParam) extends StrideBase(param) with LSUPublic {
}).asUInt.suggestName(s"regroupLoadData_${sewSize}_$segSize")
})
}).asTypeOf(dataBuffer)
}
when(alignedDequeueFire) {
}.elsewhen(alignedDequeueFire) {
dataBuffer := VecInit(dataBuffer.tail :+ 0.U.asTypeOf(dataBuffer.head))
}

// update mask
when(lsuRequest.valid || accessBufferDequeueFire) {
maskCounterInGroup := Mux(isLastDataGroup || lsuRequest.valid, 0.U, nextMaskCount)
when(isLastDataGroup && !isLastMaskGroup) {
maskSelect.valid := true.B
}
when((isLastDataGroup && !isLastMaskGroup) || lsuRequest.valid) {
maskGroupCounter := Mux(lsuRequest.valid, 0.U, nextMaskGroup)
}
dataGroup := nextDataGroup
}

when(accessBufferDequeueFire || alignedDequeueFire) {
cacheLineIndexInBuffer := Mux(accessBufferDequeueFire, 0.U, cacheLineIndexInBuffer + 1.U)
}

when(accessBufferDequeueFire || (alignedDequeueFire && isLastCacheLineInBuffer)) {
when(accessBufferDequeueFire ^ bufferWillClear) {
bufferValid := accessBufferDequeueFire
}

Expand Down Expand Up @@ -217,7 +244,7 @@ class StoreUnit(param: MSHRParam) extends StrideBase(param) with LSUPublic {
val currentAddress: Vec[UInt] = Wire(Vec(param.memoryBankSize, UInt(param.tlParam.a.addressWidth.W)))
val sendStageReady: Vec[Bool] = Wire(Vec(param.memoryBankSize, Bool()))
// TileLink send unit
val readyVec: Vec[Bool] = VecInit(Seq.tabulate(param.memoryBankSize) { portIndex =>
val readyVec = Seq.tabulate(param.memoryBankSize) { portIndex =>
val dataToSend: ValidIO[cacheLineEnqueueBundle] = RegInit(0.U.asTypeOf(Valid(new cacheLineEnqueueBundle(param))))
val port: DecoupledIO[TLChannelA] = tlPortA(portIndex)
val portFire: Bool = port.fire
Expand Down Expand Up @@ -266,12 +293,12 @@ class StoreUnit(param: MSHRParam) extends StrideBase(param) with LSUPublic {
port.bits.corrupt := false.B
sendStageReady(portIndex) := enqueueReady
!dataToSend.valid
})
}

val sendStageClear: Bool = readyVec.asUInt.andR
val sendStageClear: Bool = readyVec.reduce(_ && _)
alignedDequeue.ready := (sendStageReady.asUInt & selectOH).orR

status.idle := sendStageClear && !bufferValid && !readStageValid && stage0Idle
status.idle := sendStageClear && !bufferValid && !readStageValid && readQueueClear && !bufferFull
val idleNext: Bool = RegNext(status.idle, true.B)
status.last := (!idleNext && status.idle) || invalidInstructionNext
status.changeMaskGroup := maskSelect.valid && !lsuRequest.valid
Expand Down

0 comments on commit 2be1c22

Please sign in to comment.