From 06a07e8f4722c1216f7d535251af29e50b6721ba Mon Sep 17 00:00:00 2001 From: qinjun-li Date: Sun, 8 Dec 2024 19:01:27 +0800 Subject: [PATCH] [rtl] Pipe result in float adder. --- t1/src/FloatModule.scala | 6 +++--- t1/src/mask/MaskReduce.scala | 27 +++++++++++++++++++++++---- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/t1/src/FloatModule.scala b/t1/src/FloatModule.scala index da0651e1d..ea4abad8a 100644 --- a/t1/src/FloatModule.scala +++ b/t1/src/FloatModule.scala @@ -10,7 +10,7 @@ import chisel3.util.experimental.decode.TruthTable import hardfloat._ @instantiable -class FloatAdder(expWidth: Int, sigWidth: Int) extends Module { +class FloatAdder(expWidth: Int, sigWidth: Int, latency: Int) extends Module { @public val io = IO(new Bundle { val a = Input(UInt((expWidth + sigWidth).W)) @@ -26,8 +26,8 @@ class FloatAdder(expWidth: Int, sigWidth: Int) extends Module { addRecFN.io.roundingMode := io.roundingMode addRecFN.io.detectTininess := false.B - io.out := fNFromRecFN(8, 24, addRecFN.io.out) - io.exceptionFlags := addRecFN.io.exceptionFlags + io.out := Pipe(true.B, fNFromRecFN(8, 24, addRecFN.io.out), latency).bits + io.exceptionFlags := Pipe(true.B, addRecFN.io.exceptionFlags, latency).bits } /** float compare module diff --git a/t1/src/mask/MaskReduce.scala b/t1/src/mask/MaskReduce.scala index 288a190f5..4e26d4c67 100644 --- a/t1/src/mask/MaskReduce.scala +++ b/t1/src/mask/MaskReduce.scala @@ -35,7 +35,8 @@ class MaskReduce(parameter: T1Parameter) extends Module { val validInst: Bool = IO(Input(Bool())) val pop: Bool = IO(Input(Bool())) - val maskSize: Int = parameter.laneNumber * parameter.datapathWidth / 8 + val floatAdderLatency: Int = 1 + val maskSize: Int = parameter.laneNumber * parameter.datapathWidth / 8 // todo: uop decode val order: Bool = in.bits.uop === "b101".U @@ -49,7 +50,7 @@ class MaskReduce(parameter: T1Parameter) extends Module { val logicUnit: Instance[LaneLogic] = Instantiate(new LaneLogic(parameter.datapathWidth)) // option unit for flot reduce val floatAdder: Option[Instance[FloatAdder]] = - Option.when(parameter.fpuEnable)(Instantiate(new FloatAdder(8, 24))) + Option.when(parameter.fpuEnable)(Instantiate(new FloatAdder(8, 24, floatAdderLatency))) val flotCompare: Option[Instance[FloatCompare]] = Option.when(parameter.fpuEnable)(Instantiate(new FloatCompare(8, 24))) @@ -73,6 +74,7 @@ class MaskReduce(parameter: T1Parameter) extends Module { val floatType: Bool = reqReg.uop(2) || reqReg.uop(1, 0).andR val NotAdd: Bool = reqReg.uop(1) val widen: Bool = reqReg.uop === "b001".U || reqReg.uop(2, 1) === "b11".U + val floatAdd: Bool = floatType && !NotAdd // eew1HReg(0) || (eew1HReg(1) && !widen) val needFold: Bool = false.B val writeEEW: UInt = Mux(pop, 2.U, reqReg.eew + widen) @@ -82,16 +84,21 @@ class MaskReduce(parameter: T1Parameter) extends Module { // crossFold: reduce between lane // lastFold: reduce in data path // orderRed: order reduce - val idle :: crossFold :: lastFold :: orderRed :: Nil = Enum(4) + val idle :: crossFold :: lastFold :: orderRed :: waitRes :: Nil = Enum(5) val state: UInt = RegInit(idle) val stateIdle: Bool = state === idle val stateCross: Bool = state === crossFold val stateLast: Bool = state === lastFold val stateOrder: Bool = state === orderRed + val stateWait: Bool = state === waitRes + // wait float + val waitCount: UInt = RegInit(0.U(log2Ceil(floatAdderLatency.max(2)).W)) + when(stateWait) { waitCount := waitCount + 1.U } + val resFire: Bool = stateWait && waitCount === (floatAdderLatency - 1).U updateResult := - stateLast || ((stateCross || stateOrder) && sourceValid) + stateLast || ((stateCross || stateOrder) && sourceValid && !floatAdd) || resFire // state update in.ready := stateIdle @@ -102,9 +109,21 @@ class MaskReduce(parameter: T1Parameter) extends Module { } when(stateCross) { + when(floatAdd) { + state := waitRes + waitCount := 0.U + }.elsewhen(groupLastReduce) { + state := Mux(reqReg.lastGroup && needFold, lastFold, idle) + outValid := reqReg.lastGroup && !needFold + } + } + + when(stateWait && resFire) { when(groupLastReduce) { state := Mux(reqReg.lastGroup && needFold, lastFold, idle) outValid := reqReg.lastGroup && !needFold + }.otherwise { + state := crossFold } }