From f66a6233e0ecfb2eb2c40200683b66b9d9c595b9 Mon Sep 17 00:00:00 2001 From: qinjun-li Date: Tue, 12 Nov 2024 02:30:02 +0800 Subject: [PATCH] [rtl] fix float reduce. --- t1/src/Bundles.scala | 10 ++++--- t1/src/laneStage/MaskExchangeUnit.scala | 2 ++ t1/src/mask/MaskReduce.scala | 36 ++++++++++++++----------- t1/src/mask/MaskUnit.scala | 5 ++++ 4 files changed, 34 insertions(+), 19 deletions(-) diff --git a/t1/src/Bundles.scala b/t1/src/Bundles.scala index 26926f6d5..22efc4a8f 100644 --- a/t1/src/Bundles.scala +++ b/t1/src/Bundles.scala @@ -730,11 +730,13 @@ class MaskUnitInstReq(parameter: T1Parameter) extends Bundle { class MaskUnitExeReq(parameter: LaneParameter) extends Bundle { // source1, read vs - val source1: UInt = UInt(parameter.datapathWidth.W) + val source1: UInt = UInt(parameter.datapathWidth.W) // source2, read offset - val source2: UInt = UInt(parameter.datapathWidth.W) - val index: UInt = UInt(parameter.instructionIndexBits.W) - val ffo: Bool = Bool() + val source2: UInt = UInt(parameter.datapathWidth.W) + val index: UInt = UInt(parameter.instructionIndexBits.W) + val ffo: Bool = Bool() + // Is there a valid element? 
+ val fpReduceValid: Option[Bool] = Option.when(parameter.fpuEnable)(Bool()) } class MaskUnitExeResponse(parameter: LaneParameter) extends Bundle { diff --git a/t1/src/laneStage/MaskExchangeUnit.scala b/t1/src/laneStage/MaskExchangeUnit.scala index c9a0770ac..f340e8eb1 100644 --- a/t1/src/laneStage/MaskExchangeUnit.scala +++ b/t1/src/laneStage/MaskExchangeUnit.scala @@ -55,6 +55,8 @@ class MaskExchangeUnit(parameter: LaneParameter) extends Module { maskReq.bits.index := enqueue.bits.instructionIndex maskReq.bits.ffo := enqueue.bits.ffoSuccess + maskReq.bits.fpReduceValid.zip(enqueue.bits.fpReduceValid).foreach { case (sink, source) => sink := source } + maskRequestToLSU := enqueue.bits.loadStore // type change MaskUnitExeResponse -> LaneStage3Enqueue diff --git a/t1/src/mask/MaskReduce.scala b/t1/src/mask/MaskReduce.scala index 4816910b4..c85388838 100644 --- a/t1/src/mask/MaskReduce.scala +++ b/t1/src/mask/MaskReduce.scala @@ -8,16 +8,18 @@ import chisel3.experimental.hierarchy.{Instance, Instantiate} import chisel3.util._ class ReduceInput(parameter: T1Parameter) extends Bundle { - val maskType: Bool = Bool() - val eew: UInt = UInt(2.W) - val uop: UInt = UInt(3.W) - val readVS1: UInt = UInt(parameter.datapathWidth.W) - val source2: UInt = UInt((parameter.laneNumber * parameter.datapathWidth).W) - val sourceValid: UInt = UInt(parameter.laneNumber.W) - val lastGroup: Bool = Bool() - val vxrm: UInt = UInt(3.W) - val aluUop: UInt = UInt(4.W) - val sign: Bool = Bool() + val maskType: Bool = Bool() + val eew: UInt = UInt(2.W) + val uop: UInt = UInt(3.W) + val readVS1: UInt = UInt(parameter.datapathWidth.W) + val source2: UInt = UInt((parameter.laneNumber * parameter.datapathWidth).W) + val sourceValid: UInt = UInt(parameter.laneNumber.W) + val lastGroup: Bool = Bool() + val vxrm: UInt = UInt(3.W) + val aluUop: UInt = UInt(4.W) + val sign: Bool = Bool() + // for fpu + val fpSourceValid: Option[UInt] = Option.when(parameter.fpuEnable)(UInt(parameter.laneNumber.W)) } 
class ReduceOutput(parameter: T1Parameter) extends Bundle { @@ -66,7 +68,7 @@ class MaskReduce(parameter: T1Parameter) extends Module { val skipFlotReduce: Bool = WireDefault(false.B) val eew1HReg: UInt = UIntToOH(reqReg.eew)(2, 0) - val floatType: Bool = reqReg.uop(2) + val floatType: Bool = reqReg.uop(2) || reqReg.uop(1, 0).andR val NotAdd: Bool = reqReg.uop(1) val widen: Bool = reqReg.uop === "b001".U || reqReg.uop(2, 1) === "b11".U val needFold: Bool = eew1HReg(0) || (eew1HReg(1) && !widen) @@ -117,7 +119,9 @@ class MaskReduce(parameter: T1Parameter) extends Module { } } - val updateInitMask: UInt = FillInterleaved(8, writeMask) + val enqWriteMask: UInt = Fill(2, in.bits.eew(1)) ## in.bits.eew.orR ## true.B + val updateInitMask: UInt = FillInterleaved(8, enqWriteMask) + val updateMask: UInt = FillInterleaved(8, writeMask) when(newInstruction) { // todo: update reduceInit when first in.fire reduceInit := in.bits.readVS1 & updateInitMask @@ -133,20 +137,22 @@ class MaskReduce(parameter: T1Parameter) extends Module { // result update when(updateResult) { - reduceInit := reduceResult & updateInitMask + reduceInit := reduceResult & updateMask } when(stateLast) { lastFoldCount := false.B } - val selectLaneResult: UInt = Mux1H( + val selectLaneResult: UInt = Mux1H( UIntToOH(crossFoldCount), cutUInt(reqReg.source2, parameter.datapathWidth) ) + val sourceValidCalculate: UInt = + reqReg.fpSourceValid.map(fv => Mux(floatType, fv & reqReg.sourceValid, reqReg.sourceValid)).getOrElse(reqReg.sourceValid) sourceValid := Mux1H( UIntToOH(crossFoldCount), - reqReg.sourceValid.asBools + sourceValidCalculate.asBools ) val reduceDataVec = cutUInt(reduceInit, 8) // reduceFoldCount = false => abcd -> xxab | xxcd -> mask 0011 diff --git a/t1/src/mask/MaskUnit.scala b/t1/src/mask/MaskUnit.scala index 14438ae89..ad2c71b18 100644 --- a/t1/src/mask/MaskUnit.scala +++ b/t1/src/mask/MaskUnit.scala @@ -871,6 +871,11 @@ class MaskUnit(parameter: T1Parameter) extends Module { reduceUnit.in.bits.sign :=
!instReg.decodeResult(Decoder.unsigned1) reduceUnit.newInstruction := !readVS1Reg.sendToExecution && reduceUnit.in.fire reduceUnit.validInst := instReg.vl.orR + + reduceUnit.in.bits.fpSourceValid.foreach { sink => + sink := VecInit(exeReqReg.map(_.bits.fpReduceValid.get)).asUInt + } + when(reduceUnit.in.fire) { readVS1Reg.sendToExecution := true.B }