Skip to content

Commit

Permalink
[rtl] Correct the mask on the boundary line.
Browse files Browse the repository at this point in the history
  • Loading branch information
qinjun-li committed Jan 25, 2024
1 parent 3e7ed35 commit ab10f08
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 54 deletions.
5 changes: 3 additions & 2 deletions t1/src/Bundles.scala
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ class InstructionControlRecord(param: LaneParameter) extends Bundle {
val lastGroupForInstruction: UInt = UInt(param.groupNumberBits.W)

/** this is the last lane of the instruction and its tail is misaligned; set for mask-logic and for other instructions alike */
val isLastLaneForMaskLogic: Bool = Bool()
val isLastLaneForInstruction: Bool = Bool()

/** the find first one instruction is finished by other lanes,
* for example, sbf(set before first)
Expand Down Expand Up @@ -625,7 +625,8 @@ class LaneExecuteStage(parameter: LaneParameter)(isLastSlot: Boolean) extends Bu
class ExecutionUnitRecord(parameter: LaneParameter)(isLastSlot: Boolean) extends Bundle {
val crossReadVS2: Bool = Bool()
val bordersForMaskLogic: Bool = Bool()
val mask: UInt = UInt(4.W)
val maskForMaskInput: UInt = UInt(4.W)
val maskForFilter: UInt = UInt(4.W)
// false -> lsb of cross read group
val executeIndex: Bool = Bool()
val source: Vec[UInt] = Vec(3, UInt(parameter.datapathWidth.W))
Expand Down
23 changes: 18 additions & 5 deletions t1/src/Lane.scala
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,7 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
laneState.laneIndex := laneIndex
laneState.decodeResult := record.laneRequest.decodeResult
laneState.lastGroupForInstruction := record.lastGroupForInstruction
laneState.isLastLaneForInstruction := record.isLastLaneForInstruction
laneState.instructionFinished := record.instructionFinished
laneState.csr := record.csr
laneState.maskType := record.laneRequest.mask
Expand Down Expand Up @@ -541,7 +542,8 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
stage1.enqueue.valid := stage0.dequeue.valid
stage0.dequeue.ready := stage1.enqueue.ready
stage1.enqueue.bits.groupCounter := stage0.dequeue.bits.groupCounter
stage1.enqueue.bits.mask := stage0.dequeue.bits.mask
stage1.enqueue.bits.maskForMaskInput := stage0.dequeue.bits.maskForMaskInput
stage1.enqueue.bits.boundaryMaskCorrection := stage0.dequeue.bits.boundaryMaskCorrection
stage1.enqueue.bits.sSendResponse.zip(stage0.dequeue.bits.sSendResponse).foreach { case (sink, source) =>
sink := source
}
Expand Down Expand Up @@ -609,8 +611,9 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
executionUnit.state := laneState
executionUnit.enqueue.bits.src := stage1.dequeue.bits.src
executionUnit.enqueue.bits.bordersForMaskLogic :=
(stage1.dequeue.bits.groupCounter === record.lastGroupForInstruction && record.isLastLaneForMaskLogic)
(stage1.dequeue.bits.groupCounter === record.lastGroupForInstruction && record.isLastLaneForInstruction)
executionUnit.enqueue.bits.mask := stage1.dequeue.bits.mask
executionUnit.enqueue.bits.maskForFilter := stage1.dequeue.bits.maskForFilter
executionUnit.enqueue.bits.groupCounter := stage1.dequeue.bits.groupCounter
executionUnit.enqueue.bits.sSendResponse.zip(stage1.dequeue.bits.sSendResponse).foreach { case (sink, source) =>
sink := source
Expand Down Expand Up @@ -906,20 +909,30 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
csrInterface.vl(parameter.datapathWidthBits + parameter.laneNumberBits - 1, parameter.datapathWidthBits)
val vlHead: UInt = csrInterface.vl(parameter.vlMaxBits - 1, parameter.datapathWidthBits + parameter.laneNumberBits)
val lastGroupMask = scanRightOr(UIntToOH(vlTail)) >> 1
val dataPathMisaligned = vlTail.orR
val dataPathMisaligned: Bool = vlTail.orR
val maskeDataGroup = (vlHead ## vlBody) - !dataPathMisaligned
val lastLaneIndexForMaskLogic: UInt = maskeDataGroup(parameter.laneNumberBits - 1, 0)
val isLastLaneForMaskLogic: Bool = lastLaneIndexForMaskLogic === laneIndex
val lastGroupCountForMaskLogic: UInt = (maskeDataGroup >> parameter.laneNumberBits).asUInt
val misalignedForOther: Bool = Mux1H(
requestVSew1H(1, 0),
Seq(
csrInterface.vl(1, 0).orR,
csrInterface.vl(0),
)
)

entranceControl.lastGroupForInstruction := Mux(
laneRequest.bits.decodeResult(Decoder.maskLogic),
lastGroupCountForMaskLogic,
lastGroupForLane
)

entranceControl.isLastLaneForMaskLogic :=
isLastLaneForMaskLogic && dataPathMisaligned && laneRequest.bits.decodeResult(Decoder.maskLogic)
entranceControl.isLastLaneForInstruction := Mux(
laneRequest.bits.decodeResult(Decoder.maskLogic),
isLastLaneForMaskLogic && dataPathMisaligned,
isEndLane && misalignedForOther
)

// slot needs to be moved, try to shifter and stall pipe
slotShiftValid := VecInit(Seq.range(0, parameter.chainingSize).map { slotIndex =>
Expand Down
43 changes: 22 additions & 21 deletions t1/src/laneStage/LaneExecutionBridge.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class LaneExecuteRequest(parameter: LaneParameter, isLastSlot: Boolean) extends
val crossReadSource: Option[UInt] = Option.when(isLastSlot)(UInt((parameter.datapathWidth * 2).W))
val bordersForMaskLogic: Bool = Bool()
val mask: UInt = UInt((parameter.datapathWidth / 8).W)
val maskForFilter: UInt = UInt((parameter.datapathWidth / 8).W)
val groupCounter: UInt = UInt(parameter.groupNumberBits.W)
val sSendResponse: Option[Bool] = Option.when(isLastSlot)(Bool())
}
Expand All @@ -27,7 +28,7 @@ class LaneExecuteResponse(parameter: LaneParameter, isLastSlot: Boolean) extends

class ExecutionBridgeRecordQueue(parameter: LaneParameter, isLastSlot: Boolean) extends Bundle {
val bordersForMaskLogic: Bool = Bool()
val mask: UInt = UInt((parameter.datapathWidth / 8).W)
val maskForFilter: UInt = UInt((parameter.datapathWidth / 8).W)
val groupCounter: UInt = UInt(parameter.groupNumberBits.W)
val sSendResponse: Option[Bool] = Option.when(isLastSlot)(Bool())
val executeIndex: Bool = Bool()
Expand Down Expand Up @@ -66,7 +67,7 @@ class LaneExecutionBridge(parameter: LaneParameter, isLastSlot: Boolean, slotInd
RegEnable(state.newInstruction.get, false.B, state.newInstruction.get || firstRequestFire.get)
}
firstRequestFire.foreach(d =>
d := enqueue.fire && firstRequest.get && (state.maskNotMaskedElement || enqueue.bits.mask(0))
d := enqueue.fire && firstRequest.get && enqueue.bits.maskForFilter(0)
)

// Type widenReduce instructions occupy double the data registers because they need to retain the carry bit.
Expand All @@ -89,11 +90,11 @@ class LaneExecutionBridge(parameter: LaneParameter, isLastSlot: Boolean, slotInd
executionRecordValid := enqueue.fire
}
if (isLastSlot) {
val firstGroupNotExecute = decodeResult(Decoder.crossWrite) && !state.maskNotMaskedElement && !Mux(
val firstGroupNotExecute = decodeResult(Decoder.crossWrite) && !Mux(
state.vSew1H(0),
// sew = 8, 2 mask bit / group
enqueue.bits.mask(1, 0).orR,
enqueue.bits.mask(0)
enqueue.bits.maskForFilter(1, 0).orR,
enqueue.bits.maskForFilter(0)
)
// update execute index
when(enqueue.fire || vfuRequest.fire) {
Expand All @@ -106,8 +107,9 @@ class LaneExecutionBridge(parameter: LaneParameter, isLastSlot: Boolean, slotInd

when(enqueue.fire) {
executionRecord.crossReadVS2 := decodeResult(Decoder.crossRead) && !decodeResult(Decoder.vwmacc)
executionRecord.bordersForMaskLogic := enqueue.bits.bordersForMaskLogic
executionRecord.mask := enqueue.bits.mask
executionRecord.bordersForMaskLogic := enqueue.bits.bordersForMaskLogic && state.decodeResult(Decoder.maskLogic)
executionRecord.maskForMaskInput := enqueue.bits.mask
executionRecord.maskForFilter := enqueue.bits.maskForFilter
executionRecord.source := enqueue.bits.src
executionRecord.crossReadSource.foreach(_ := enqueue.bits.crossReadSource.get)
executionRecord.sSendResponse.foreach(_ := enqueue.bits.sSendResponse.get)
Expand Down Expand Up @@ -217,19 +219,18 @@ class LaneExecutionBridge(parameter: LaneParameter, isLastSlot: Boolean, slotInd
* use [[lastGroupMask]] to mask the result otherwise use [[fullMask]]. */
val maskCorrect: Bits = Mux(executionRecord.bordersForMaskLogic, lastGroupMask, fullMask)

val maskExtend = Mux(state.vSew1H(1), FillInterleaved(2, executionRecord.mask(1, 0)), executionRecord.mask)
val maskExtend = Mux(state.vSew1H(1), FillInterleaved(2, executionRecord.maskForMaskInput(1, 0)), executionRecord.maskForMaskInput)
vfuRequest.bits.src := VecInit(Seq(finalSource1, finalSource2, finalSource3, maskCorrect))
vfuRequest.bits.opcode := decodeResult(Decoder.uop)
vfuRequest.bits.mask := Mux(
decodeResult(Decoder.adder),
Mux(decodeResult(Decoder.maskSource), executionRecord.mask, 0.U(4.W)),
Mux(decodeResult(Decoder.maskSource), executionRecord.maskForMaskInput, 0.U(4.W)),
maskExtend | Fill(4, !state.maskType)
)
val executeMask = executionRecord.mask | FillInterleaved(4, state.maskNotMaskedElement)
vfuRequest.bits.executeMask := Mux(
executionRecord.executeIndex,
0.U(2.W) ## executeMask(3, 2),
executeMask
0.U(2.W) ## executionRecord.maskForFilter(3, 2),
executionRecord.maskForFilter
)
vfuRequest.bits.sign0 := !decodeResult(Decoder.unsigned0)
vfuRequest.bits.sign := !decodeResult(Decoder.unsigned1)
Expand Down Expand Up @@ -283,7 +284,7 @@ class LaneExecutionBridge(parameter: LaneParameter, isLastSlot: Boolean, slotInd
recordQueueReadyForNoExecute := notExecute && recordQueue.io.enq.ready
recordQueue.io.enq.valid := executionRecordValid && (vfuRequest.ready || notExecute)
recordQueue.io.enq.bits.bordersForMaskLogic := executionRecord.bordersForMaskLogic
recordQueue.io.enq.bits.mask := executionRecord.mask
recordQueue.io.enq.bits.maskForFilter := executionRecord.maskForFilter
recordQueue.io.enq.bits.groupCounter := executionRecord.groupCounter
recordQueue.io.enq.bits.executeIndex := executionRecord.executeIndex
recordQueue.io.enq.bits.source2 := executionRecord.source(1)
Expand Down Expand Up @@ -344,29 +345,29 @@ class LaneExecutionBridge(parameter: LaneParameter, isLastSlot: Boolean, slotInd
val normalReduceMask = Mux1H(
state.vSew1H,
Seq(
recordQueue.io.deq.bits.mask,
FillInterleaved(2, recordQueue.io.deq.bits.mask(1, 0)),
FillInterleaved(4, recordQueue.io.deq.bits.mask(0))
recordQueue.io.deq.bits.maskForFilter,
FillInterleaved(2, recordQueue.io.deq.bits.maskForFilter(1, 0)),
FillInterleaved(4, recordQueue.io.deq.bits.maskForFilter(0))
)
)
val widenReduceMask = Mux1H(
state.vSew1H(1, 0),
Seq(
FillInterleaved(2, Mux(
executionRecord.executeIndex,
recordQueue.io.deq.bits.mask(3, 2),
recordQueue.io.deq.bits.mask(1, 0)
recordQueue.io.deq.bits.maskForFilter(3, 2),
recordQueue.io.deq.bits.maskForFilter(1, 0)
)),
FillInterleaved(4, Mux(
executionRecord.executeIndex,
recordQueue.io.deq.bits.mask(1),
recordQueue.io.deq.bits.mask(0)
recordQueue.io.deq.bits.maskForFilter(1),
recordQueue.io.deq.bits.maskForFilter(0)
))
)
)
// masked elements don't update 'reduceResult'
val reduceUpdateByteMask: UInt =
Mux(widenReduce, widenReduceMask, normalReduceMask) | FillInterleaved(4, state.maskNotMaskedElement)
Mux(widenReduce, widenReduceMask, normalReduceMask)
val foldUpdateMask = Wire(UInt(4.W))
updateReduceResult.get := {
val dataVec = cutUInt(dataDequeue, 8)
Expand Down
1 change: 1 addition & 0 deletions t1/src/laneStage/LaneStage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class LaneState(parameter: LaneParameter) extends Bundle {
val decodeResult: DecodeBundle = Decoder.bundle(parameter.fpuEnable)
/** which group is the last group for instruction. */
val lastGroupForInstruction: UInt = UInt(parameter.groupNumberBits.W)
val isLastLaneForInstruction: Bool = Bool()
val instructionFinished: Bool = Bool()
val csr: CSRInterface = new CSRInterface(parameter.vlMaxBits)
// vm = 0
Expand Down
55 changes: 32 additions & 23 deletions t1/src/laneStage/LaneStage0.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class LaneStage0StateUpdate(parameter: LaneParameter) extends Bundle {
}

/** Bundle instantiated as `stageWire` inside [[LaneStage0]].
  *
  * NOTE(review): this text comes from a rendered commit diff whose +/- markers were stripped,
  * so the pre-change field `mask` and its replacements (`maskForMaskInput`,
  * `boundaryMaskCorrection`) appear side by side. Confirm against the real file which fields
  * exist after the commit.
  *
  * @param parameter  lane parameters; `datapathWidth / 8` and `groupNumberBits` size the fields
  * @param isLastSlot when true the optional `sSendResponse` field is generated
  */
class LaneStage0Dequeue(parameter: LaneParameter, isLastSlot: Boolean) extends Bundle {
// Pre-change field: was driven from `(state.mask.bits >> enqueue.bits.maskIndex).asUInt(3, 0)`.
// NOTE(review): presumably removed by this commit in favour of `maskForMaskInput` — confirm.
val mask: UInt = UInt((parameter.datapathWidth/8).W)
// Low four bits of `state.mask.bits` shifted right by `enqueue.bits.maskIndex`.
val maskForMaskInput: UInt = UInt((parameter.datapathWidth/8).W)
// Driven from `maskCorrect` in LaneStage0: 15 (all four bits set) unless this is the last group,
// `state.isLastLaneForInstruction` is set and vl needs correction, in which case `correctMask` is used.
val boundaryMaskCorrection: UInt = UInt((parameter.datapathWidth/8).W)
// Only present when `isLastSlot` is true.
val sSendResponse: Option[Bool] = Option.when(isLastSlot)(Bool())
// `groupNumberBits` wide; NOTE(review): presumably the data-group index of this request — confirm.
val groupCounter: UInt = UInt(parameter.groupNumberBits.W)
}
Expand All @@ -43,9 +44,9 @@ class LaneStage0(parameter: LaneParameter, isLastSlot: Boolean) extends
val stageWire: LaneStage0Dequeue = Wire(new LaneStage0Dequeue(parameter, isLastSlot))
// If every element of this group is masked off, the group is not pushed into the pipeline either
val notMaskedAllElement: Bool = Mux1H(state.vSew1H, Seq(
stageWire.mask.orR,
stageWire.mask(1, 0).orR,
stageWire.mask(0),
stageWire.maskForMaskInput.orR,
stageWire.maskForMaskInput(1, 0).orR,
stageWire.maskForMaskInput(0),
)) || state.maskNotMaskedElement ||
state.decodeResult(Decoder.maskDestination) || state.decodeResult(Decoder.red) ||
state.decodeResult(Decoder.readOnly) || state.loadStore || state.decodeResult(Decoder.gather) ||
Expand Down Expand Up @@ -84,23 +85,6 @@ class LaneStage0(parameter: LaneParameter, isLastSlot: Boolean) extends
/** The mask group will be updated */
val maskGroupWillUpdate: Bool = state.decodeResult(Decoder.maskLogic) || updateLaneState.maskExhausted

/** The index of next execute element in whole instruction */
val elementIndexForInstruction = enqueue.bits.maskGroupCount ## Mux1H(
state.vSew1H,
Seq(
enqueue.bits.maskIndex(parameter.datapathWidthBits - 1, 2) ## state.laneIndex ## enqueue.bits.maskIndex(1, 0),
enqueue.bits.maskIndex(parameter.datapathWidthBits - 1, 1) ## state.laneIndex ## enqueue.bits.maskIndex(0),
enqueue.bits.maskIndex ## state.laneIndex
)
)

/** The next element is out of execution range */
updateLaneState.outOfExecutionRange := Mux(
state.decodeResult(Decoder.maskLogic),
(enqueue.bits.maskGroupCount > state.lastGroupForInstruction),
elementIndexForInstruction >= state.csr.vl
) || state.instructionFinished

/** Encoding of different element lengths: 1, 8, 16, 32 */
val elementLengthOH = Mux(state.decodeResult(Decoder.maskLogic), 1.U, state.vSew1H(2, 0) ## false.B)

Expand All @@ -115,9 +99,34 @@ class LaneStage0(parameter: LaneParameter, isLastSlot: Boolean) extends
)
)

val isTheLastGroup = dataGroupIndex === state.lastGroupForInstruction
/** The next element is out of execution range */
updateLaneState.outOfExecutionRange := dataGroupIndex > state.lastGroupForInstruction || state.instructionFinished

stageWire.mask := (state.mask.bits >> enqueue.bits.maskIndex).asUInt(3, 0)
val isTheLastGroup: Bool = dataGroupIndex === state.lastGroupForInstruction

// Correct the mask on the boundary line
val vlNeedCorrect: Bool = Mux1H(
state.vSew1H(1, 0),
Seq(
state.csr.vl(1, 0).orR,
state.csr.vl(0)
)
)
val correctMask: UInt = Mux1H(
state.vSew1H(1, 0),
Seq(
(scanRightOr(UIntToOH(state.csr.vl(1, 0))) >> 1).asUInt,
1.U(4.W)
)
)
val needCorrect: Bool =
isTheLastGroup &&
state.isLastLaneForInstruction &&
vlNeedCorrect
val maskCorrect: UInt = Mux(needCorrect, correctMask, 15.U(4.W))

stageWire.maskForMaskInput := (state.mask.bits >> enqueue.bits.maskIndex).asUInt(3, 0)
stageWire.boundaryMaskCorrection := maskCorrect

/** The index of the next element in this mask group (0-31). */
updateLaneState.maskIndex := Mux(
Expand Down
9 changes: 6 additions & 3 deletions t1/src/laneStage/LaneStage1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import org.chipsalliance.t1.rtl.lane.{CrossReadUnit, LaneState, VrfReadPipe}

/** Bundle of per-group control fields entering lane stage 1.
  *
  * NOTE(review): presumably the payload type of `stage1.enqueue` in Lane, where the fields are
  * copied one-to-one from `stage0.dequeue.bits` — confirm against the LaneStage1 IO declaration.
  *
  * NOTE(review): this text comes from a rendered commit diff whose +/- markers were stripped,
  * so the pre-change field `mask` and its replacements appear side by side. Confirm against the
  * real file which fields exist after the commit.
  *
  * @param parameter  lane parameters; `groupNumberBits` and `datapathWidth / 8` size the fields
  * @param isLastSlot when true the optional `sSendResponse` field is generated
  */
class LaneStage1Enqueue(parameter: LaneParameter, isLastSlot: Boolean) extends Bundle {
// `groupNumberBits` wide; forwarded unchanged to `dequeue.bits.groupCounter` in LaneStage1.
val groupCounter: UInt = UInt(parameter.groupNumberBits.W)
// Pre-change field. NOTE(review): presumably removed by this commit in favour of `maskForMaskInput` — confirm.
val mask: UInt = UInt((parameter.datapathWidth / 8).W)
// Forwarded to `dequeue.bits.mask` in LaneStage1, and OR-ed with
// `FillInterleaved(4, state.maskNotMaskedElement)` when `dequeue.bits.maskForFilter` is built.
val maskForMaskInput: UInt = UInt((parameter.datapathWidth / 8).W)
// AND-ed into `dequeue.bits.maskForFilter` in LaneStage1.
val boundaryMaskCorrection: UInt = UInt((parameter.datapathWidth / 8).W)
// Only present when `isLastSlot` is true.
val sSendResponse: Option[Bool] = Option.when(isLastSlot)(Bool())
}

Expand Down Expand Up @@ -243,13 +244,15 @@ class LaneStage1(parameter: LaneParameter, isLastSlot: Boolean) extends Module {
)

val source1Select: UInt = Mux(state.decodeResult(Decoder.vtype), dataQueueVs1.bits, scalarDataRepeat)
dequeue.bits.mask := pipeQueue.io.deq.bits.mask
dequeue.bits.mask := pipeQueue.io.deq.bits.maskForMaskInput
dequeue.bits.groupCounter := pipeQueue.io.deq.bits.groupCounter
dequeue.bits.src := VecInit(Seq(source1Select, dataQueueVs2.io.deq.bits, dataQueueVd.io.deq.bits))
dequeue.bits.crossReadSource.foreach(_ := crossReadResultQueue.get.io.deq.bits)
dequeue.bits.sSendResponse.foreach(_ := pipeQueue.io.deq.bits.sSendResponse.get)

dequeue.bits.maskForFilter := FillInterleaved(4, state.maskNotMaskedElement) | pipeQueue.io.deq.bits.mask
dequeue.bits.maskForFilter :=
(FillInterleaved(4, state.maskNotMaskedElement) | pipeQueue.io.deq.bits.maskForMaskInput) &
pipeQueue.io.deq.bits.boundaryMaskCorrection
// All required data is ready
val dataQueueValidVec: Seq[Bool] =
Seq(
Expand Down

0 comments on commit ab10f08

Please sign in to comment.