Skip to content

Commit

Permalink
[rtl] fix float reduce.
Browse files Browse the repository at this point in the history
  • Loading branch information
qinjun-li committed Nov 11, 2024
1 parent 8bf475f commit f66a623
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 19 deletions.
10 changes: 6 additions & 4 deletions t1/src/Bundles.scala
Original file line number Diff line number Diff line change
Expand Up @@ -730,11 +730,13 @@ class MaskUnitInstReq(parameter: T1Parameter) extends Bundle {

/** Request bundle carrying two datapath-wide operands, an instruction index, an `ffo` flag and,
  * when the FPU is enabled, a floating-point reduce valid flag.
  *
  * All widths are taken from the supplied `LaneParameter`:
  *   - `source1` / `source2`: `parameter.datapathWidth` bits each
  *   - `index`: `parameter.instructionIndexBits` bits
  *   - `fpReduceValid`: only elaborated when `parameter.fpuEnable` is true
  *
  * @param parameter lane configuration that supplies the field widths and the FPU switch
  */
class MaskUnitExeReq(parameter: LaneParameter) extends Bundle {
// NOTE(review): this listing appears to come from a diff view, so the field declarations below
// show up twice (the row before and the row after the change). Presumably only one copy of each
// field exists in the real file -- confirm against the repository before editing.
// source1: first operand, one datapath word wide ("read vs" -- presumably data read from a vector
// source register; TODO confirm)
val source1: UInt = UInt(parameter.datapathWidth.W)
val source1: UInt = UInt(parameter.datapathWidth.W)
// source2: second operand, one datapath word wide ("read offset" in the original note)
val source2: UInt = UInt(parameter.datapathWidth.W)
val index: UInt = UInt(parameter.instructionIndexBits.W)
val ffo: Bool = Bool()
val source2: UInt = UInt(parameter.datapathWidth.W)
val index: UInt = UInt(parameter.instructionIndexBits.W)
val ffo: Bool = Bool()
// Is there a valid element? Present only when parameter.fpuEnable is set; it is None otherwise,
// so consumers must handle the missing field.
val fpReduceValid: Option[Bool] = Option.when(parameter.fpuEnable)(Bool())
}

class MaskUnitExeResponse(parameter: LaneParameter) extends Bundle {
Expand Down
2 changes: 2 additions & 0 deletions t1/src/laneStage/MaskExchangeUnit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class MaskExchangeUnit(parameter: LaneParameter) extends Module {
maskReq.bits.index := enqueue.bits.instructionIndex
maskReq.bits.ffo := enqueue.bits.ffoSuccess

maskReq.bits.fpReduceValid.zip(enqueue.bits.fpReduceValid).foreach { case (sink, source) => sink := source }

maskRequestToLSU := enqueue.bits.loadStore

// type change MaskUnitExeResponse -> LaneStage3Enqueue
Expand Down
36 changes: 21 additions & 15 deletions t1/src/mask/MaskReduce.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@ import chisel3.experimental.hierarchy.{Instance, Instantiate}
import chisel3.util._

/** Input bundle for the reduce unit.
  *
  * Field widths, as declared below:
  *   - `maskType`, `lastGroup`, `sign`: single-bit flags
  *   - `eew`: 2 bits; `uop`: 3 bits; `vxrm`: 3 bits; `aluUop`: 4 bits
  *   - `readVS1`: one datapath word (`parameter.datapathWidth` bits)
  *   - `source2`: one datapath word per lane, concatenated
  *     (`parameter.laneNumber * parameter.datapathWidth` bits)
  *   - `sourceValid`: one bit per lane (`parameter.laneNumber` bits)
  *   - `fpSourceValid`: one bit per lane, only elaborated when `parameter.fpuEnable` is true
  *
  * @param parameter top-level configuration supplying lane count, datapath width and FPU switch
  */
class ReduceInput(parameter: T1Parameter) extends Bundle {
// NOTE(review): this listing appears to come from a diff view, so the ten field declarations
// below show up twice (the rows before and the rows after the change). Presumably only one copy
// of each field exists in the real file -- confirm against the repository before editing.
val maskType: Bool = Bool()
val eew: UInt = UInt(2.W)
val uop: UInt = UInt(3.W)
val readVS1: UInt = UInt(parameter.datapathWidth.W)
val source2: UInt = UInt((parameter.laneNumber * parameter.datapathWidth).W)
val sourceValid: UInt = UInt(parameter.laneNumber.W)
val lastGroup: Bool = Bool()
val vxrm: UInt = UInt(3.W)
val aluUop: UInt = UInt(4.W)
val sign: Bool = Bool()
val maskType: Bool = Bool()
val eew: UInt = UInt(2.W)
val uop: UInt = UInt(3.W)
val readVS1: UInt = UInt(parameter.datapathWidth.W)
val source2: UInt = UInt((parameter.laneNumber * parameter.datapathWidth).W)
val sourceValid: UInt = UInt(parameter.laneNumber.W)
val lastGroup: Bool = Bool()
val vxrm: UInt = UInt(3.W)
val aluUop: UInt = UInt(4.W)
val sign: Bool = Bool()
// FPU-only field: a per-lane valid mask (one bit per lane). It is None when parameter.fpuEnable
// is false, so consumers must handle the missing field.
val fpSourceValid: Option[UInt] = Option.when(parameter.fpuEnable)(UInt(parameter.laneNumber.W))
}

class ReduceOutput(parameter: T1Parameter) extends Bundle {
Expand Down Expand Up @@ -66,7 +68,7 @@ class MaskReduce(parameter: T1Parameter) extends Module {
val skipFlotReduce: Bool = WireDefault(false.B)

val eew1HReg: UInt = UIntToOH(reqReg.eew)(2, 0)
val floatType: Bool = reqReg.uop(2)
val floatType: Bool = reqReg.uop(2) || reqReg.uop(1, 0).andR
val NotAdd: Bool = reqReg.uop(1)
val widen: Bool = reqReg.uop === "b001".U || reqReg.uop(2, 1) === "b11".U
val needFold: Bool = eew1HReg(0) || (eew1HReg(1) && !widen)
Expand Down Expand Up @@ -117,7 +119,9 @@ class MaskReduce(parameter: T1Parameter) extends Module {
}
}

val updateInitMask: UInt = FillInterleaved(8, writeMask)
val enqWriteMask: UInt = Fill(2, in.bits.eew(1)) ## in.bits.eew.orR ## true.B
val updateInitMask: UInt = FillInterleaved(8, enqWriteMask)
val updateMask: UInt = FillInterleaved(8, writeMask)
when(newInstruction) {
// TODO: update reduceInit on the first in.fire
reduceInit := in.bits.readVS1 & updateInitMask
Expand All @@ -133,20 +137,22 @@ class MaskReduce(parameter: T1Parameter) extends Module {

// result update
when(updateResult) {
reduceInit := reduceResult & updateInitMask
reduceInit := reduceResult & updateMask
}

when(stateLast) {
lastFoldCount := false.B
}

val selectLaneResult: UInt = Mux1H(
val selectLaneResult: UInt = Mux1H(
UIntToOH(crossFoldCount),
cutUInt(reqReg.source2, parameter.datapathWidth)
)
val sourceValidCalculate: UInt =
reqReg.fpSourceValid.map(fv => Mux(floatType, fv & reqReg.sourceValid, fv)).getOrElse(reqReg.sourceValid)
sourceValid := Mux1H(
UIntToOH(crossFoldCount),
reqReg.sourceValid.asBools
sourceValidCalculate.asBools
)
val reduceDataVec = cutUInt(reduceInit, 8)
// reduceFoldCount = false => abcd -> xxab | xxcd -> mask 0011
Expand Down
5 changes: 5 additions & 0 deletions t1/src/mask/MaskUnit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,11 @@ class MaskUnit(parameter: T1Parameter) extends Module {
reduceUnit.in.bits.sign := !instReg.decodeResult(Decoder.unsigned1)
reduceUnit.newInstruction := !readVS1Reg.sendToExecution && reduceUnit.in.fire
reduceUnit.validInst := instReg.vl.orR

reduceUnit.in.bits.fpSourceValid.foreach { sink =>
sink := VecInit(exeReqReg.map(_.bits.fpReduceValid.get)).asUInt
}

when(reduceUnit.in.fire) {
readVS1Reg.sendToExecution := true.B
}
Expand Down

0 comments on commit f66a623

Please sign in to comment.