Skip to content

Commit

Permalink
[rtl] Handle chaining checks across register groups
Browse files Browse the repository at this point in the history
  • Loading branch information
qinjun-li committed Nov 20, 2024
1 parent cf83bb5 commit 1b59b4f
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 44 deletions.
9 changes: 2 additions & 7 deletions t1/src/Lane.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1142,7 +1142,7 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
lastWriteOH
)

val selectMask: UInt = Mux(
val selectMask: UInt = Mux(
segmentLS,
segmentMask,
Mux(
Expand All @@ -1151,13 +1151,8 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
lastWriteOH
)
)
// 8 register
val paddingSize: Int = elementSizeForOneRegister * 8
val shifterMask: UInt = (((selectMask ## Fill(paddingSize, true.B))
<< laneRequest.bits.vd(2, 0) ## 0.U(log2Ceil(elementSizeForOneRegister).W))
>> paddingSize).asUInt

vrf.instructionWriteReport.bits.elementMask := shifterMask
vrf.instructionWriteReport.bits.elementMask := selectMask

// clear record by instructionFinished
vrf.instructionLastReport := instructionFinished
Expand Down
23 changes: 20 additions & 3 deletions t1/src/vrf/ChainingCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,26 @@ class ChainingCheck(val parameter: VRFParam) extends Module {
val sameInst: Bool = read.instructionIndex === record.bits.instIndex

// 3: 8 register
val readOH: UInt = UIntToOH((read.vs ## read.offset)(parameter.vrfOffsetBits + 3 - 1, 0))
val hitElement: Bool = (readOH & record.bits.elementMask) === 0.U
val readOH: UInt = UIntToOH((read.vs ## read.offset)(parameter.vrfOffsetBits + 3 - 1, 0))

val raw: Bool = record.bits.vd.valid && (read.vs(4, 3) === record.bits.vd.bits(4, 3)) && hitElement
// todo: def
val elementSizeForOneRegister: Int = parameter.vLen / parameter.datapathWidth / parameter.laneNumber
val paddingSize: Int = elementSizeForOneRegister * 8

// elementMask records the relative position of the relative instruction.
// Let's calculate the absolute position.
val maskShifter: UInt = (((Fill(paddingSize, true.B) ## record.bits.elementMask ## Fill(paddingSize, true.B))
<< record.bits.vd.bits(2, 0) ## 0.U(log2Ceil(elementSizeForOneRegister).W))
>> paddingSize).asUInt(2 * paddingSize - 1, 0)
// mask for vd's group
val maskForVD: UInt = cutUIntBySize(maskShifter, 2)(0)
// Due to the existence of segment load, writes may cross register groups
// So we need the mask of the previous set of registers
val maskForVD1: UInt = cutUIntBySize(maskShifter, 2)(1)

val hitVd: Bool = (readOH & maskForVD) === 0.U && read.vs(4, 3) === record.bits.vd.bits(4, 3)
val hitVd1: Bool = (readOH & maskForVD1) === 0.U && read.vs(4, 3) === (record.bits.vd.bits(4, 3) + 1.U)

val raw: Bool = record.bits.vd.valid && (hitVd || hitVd1)
checkResult := !(!older && raw && !sameInst && recordValid)
}
8 changes: 4 additions & 4 deletions t1/src/vrf/VRF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -518,13 +518,13 @@ class VRF(val parameter: VRFParam) extends Module with SerializableModule[VRFPar
)
vrfAllocateIssue := freeRecord.orR && olderCheck

val writePort: Seq[ValidIO[VRFWriteRequest]] = Seq(writePipe)
val writeOH = writePort.map(p => UIntToOH((p.bits.vd ## p.bits.offset)(parameter.vrfOffsetBits + 3 - 1, 0)))
val writePort: Seq[ValidIO[VRFWriteRequest]] = Seq(writePipe)
val loadUnitReadPorts: Seq[DecoupledIO[VRFReadRequest]] = Seq(readRequests.last)
val loadReadOH: Seq[UInt] =
loadUnitReadPorts.map(p => UIntToOH((p.bits.vs ## p.bits.offset)(parameter.vrfOffsetBits + 3 - 1, 0)))
Seq(chainingRecord, chainingRecordCopy).foreach { recordVec =>
recordVec.zipWithIndex.foreach { case (record, i) =>
// read write one hot base on base address
val writeOH = writePort.map(p => UIntToOH((p.bits.vd - record.bits.vd.bits)(2, 0) ## p.bits.offset))
val loadReadOH = loadUnitReadPorts.map(p => UIntToOH((p.bits.vs - record.bits.vs2)(2, 0) ## p.bits.offset))
val dataInLsuQueue = ohCheck(loadDataInLSUWriteQueue, record.bits.instIndex, parameter.chainingSize)
// elementMask update by write
val writeUpdateValidVec: Seq[Bool] =
Expand Down
66 changes: 36 additions & 30 deletions t1/src/vrf/WriteCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,35 +32,41 @@ class WriteCheck(val parameter: VRFParam) extends Module {
val sameInst: Bool = check.instructionIndex === record.bits.instIndex
val checkOH: UInt = UIntToOH((check.vd ## check.offset)(parameter.vrfOffsetBits + 3 - 1, 0))

// this element in record not execute
val notHitMaskVd: Bool = (checkOH & record.bits.elementMask) === 0.U
val waw: Bool = record.bits.vd.valid && check.vd(4, 3) === record.bits.vd.bits(4, 3) && notHitMaskVd
// inst eg: vadd v0, v1, v1 (lmul = 1)
// We only recorded vd-related masks.
// 0 base: 11111111111111xx eg vs = 0 off=2
// As above, using vd as the perspective,
// we will access the lowest two elements of the register group where vd is located.
// But from the perspective of vs1:
// 1 base: 111111111111xx11 eg vs = 1 off=2
// Apparently. Our mask has shifted
// 0 base => 1 base << (1 * off)
// we need vd%8 base => vs1%8 base => vd base mask << (vs1 - vd) * off
// => vd base mask >> 8 * off << (8 + vs1 - vd) * off
// => vd base mask << (8 + vs1 - vd) * off >> 8 * off
val vs1Mask: UInt = (((-1.S(parameter.elementSize.W)).asUInt ## record.bits.elementMask) <<
((8.U + record.bits.vs1.bits(2, 0) - record.bits.vd.bits(2, 0)) << parameter.vrfOffsetBits).asUInt).asUInt(
2 * 8 * parameter.singleGroupSize - 1,
8 * parameter.singleGroupSize
)
val notHitVs1: Bool = (checkOH & vs1Mask) === 0.U
val war1: Bool = record.bits.vs1.valid && check.vd(4, 3) === record.bits.vs1.bits(4, 3) && notHitVs1
val maskForVs2: UInt = record.bits.elementMask & Fill(parameter.elementSize, !record.bits.onlyRead)
val vs2Mask: UInt = (((-1.S(parameter.elementSize.W)).asUInt ## maskForVs2) <<
((8.U + record.bits.vs2(2, 0) - record.bits.vd.bits(2, 0)) << parameter.vrfOffsetBits).asUInt).asUInt(
2 * 8 * parameter.singleGroupSize - 1,
8 * parameter.singleGroupSize
)
val notHitVs2: Bool = (checkOH & vs2Mask) === 0.U
val war2: Bool = check.vd(4, 3) === record.bits.vs2(4, 3) && notHitVs2
val elementSizeForOneRegister: Int = parameter.vLen / parameter.datapathWidth / parameter.laneNumber
val paddingSize: Int = elementSizeForOneRegister * 8

// elementMask records the relative position of the relative instruction.
// Let's calculate the absolute position.
val maskShifter: UInt = (((Fill(paddingSize, true.B) ## record.bits.elementMask ## Fill(paddingSize, true.B))
<< record.bits.vd.bits(2, 0) ## 0.U(log2Ceil(elementSizeForOneRegister).W))
>> paddingSize).asUInt(2 * paddingSize - 1, 0)
// mask for vd's group
val maskForVD: UInt = cutUIntBySize(maskShifter, 2)(0)
// Due to the existence of segment load, writes may cross register groups
// So we need the mask of the previous set of registers
val maskForVD1: UInt = cutUIntBySize(maskShifter, 2)(1)

val hitVd: Bool = (checkOH & maskForVD) === 0.U && check.vd(4, 3) === record.bits.vd.bits(4, 3)
val hitVd1: Bool = (checkOH & maskForVD1) === 0.U && check.vd(4, 3) === (record.bits.vd.bits(4, 3) + 1.U)
val waw: Bool = record.bits.vd.valid && (hitVd || hitVd1)

// calculate the absolute position for vs1
val vs1Mask: UInt = (((record.bits.elementMask ## Fill(paddingSize, true.B))
<< record.bits.vs1.bits(2, 0) ## 0.U(log2Ceil(elementSizeForOneRegister).W))
>> paddingSize).asUInt
val notHitVs1: Bool = (checkOH & vs1Mask) === 0.U
val war1: Bool = record.bits.vs1.valid && check.vd(4, 3) === record.bits.vs1.bits(4, 3) && notHitVs1

// calculate the absolute position for vs2
val maskShifterForVs2: UInt = (((Fill(paddingSize, true.B) ## record.bits.elementMask ## Fill(paddingSize, true.B))
<< record.bits.vs2(2, 0) ## 0.U(log2Ceil(elementSizeForOneRegister).W))
>> paddingSize).asUInt(2 * paddingSize - 1, 0)

val maskForVs2: UInt = cutUIntBySize(maskShifterForVs2, 2)(0) & Fill(parameter.elementSize, !record.bits.onlyRead)
val maskForVs21: UInt = cutUIntBySize(maskShifterForVs2, 2)(1)
val hitVs2: Bool = (checkOH & maskForVs2) === 0.U && check.vd(4, 3) === record.bits.vs2(4, 3)
val hitVs21: Bool = (checkOH & maskForVs21) === 0.U && check.vd(4, 3) === (record.bits.vs2(4, 3) + 1.U)
val war2: Bool = hitVs2 || hitVs21

checkResult := !((!older && (waw || war1 || war2)) && !sameInst && record.valid)
}

0 comments on commit 1b59b4f

Please sign in to comment.