From 1ea41e04b07d541f715b38a85beec5a34cc453c2 Mon Sep 17 00:00:00 2001 From: qinjun-li Date: Sat, 24 Aug 2024 17:44:57 +0800 Subject: [PATCH] [rocketv] read frs1 from fpu for vector instruction. --- rocketv/src/Decoder.scala | 46 ++++++++++++++++++++++++------------ rocketv/src/RocketCore.scala | 13 ++++++++-- 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/rocketv/src/Decoder.scala b/rocketv/src/Decoder.scala index dd4abaca0d..cef5fa8829 100644 --- a/rocketv/src/Decoder.scala +++ b/rocketv/src/Decoder.scala @@ -120,7 +120,7 @@ case class DecoderParameter( ) ++ (if (useFPU) Seq(fp, rfs1, rfs2, rfs3, wfd, dp) else None) ++ (if (useMulDiv) if (pipelinedMul) Seq(mul, div) else Seq(div) else None) ++ - (if (useVector) Seq(vector, vectorLSU, vectorCSR) else None) + (if (useVector) Seq(vector, vectorLSU, vectorCSR, vectorReadFRs1) else None) private val Y = BitPat.Y() private val N = BitPat.N() @@ -720,6 +720,12 @@ case class DecoderParameter( override def genTable(op: RocketDecodePattern): BitPat = if (op.isVectorCSR) Y else N } + object vectorReadFRs1 extends BoolDecodeField[RocketDecodePattern] { + override def name: String = "vectorReadFRs1" + + override def genTable(op: RocketDecodePattern): BitPat = if (op.vectorReadFRegFile) Y else N + } + // fpu decode object fldst extends BoolDecodeField[RocketDecodePattern] { override def name: String = "ldst" @@ -744,8 +750,9 @@ case class DecoderParameter( object fren1 extends BoolDecodeField[RocketDecodePattern] { override def name: String = "ren1" - override def genTable(op: RocketDecodePattern): BitPat = op.instruction.name match { - case i if Seq("fmv.x.h", "fclass.h", "fcvt.w.h", "fcvt.wu.h", "fcvt.l.h", "fcvt.lu.h", "fcvt.s.h", "fcvt.h.s", "feq.h", "flt.h", "fle.h", "fsgnj.h", "fsgnjn.h", "fsgnjx.h", "fmin.h", "fmax.h", "fadd.h", "fsub.h", "fmul.h", "fmadd.h", "fmsub.h", "fnmadd.h", "fnmsub.h", "fdiv.h", "fsqrt.h", "fmv.x.w", "fclass.s", "fcvt.w.s", "fcvt.wu.s", "fcvt.l.s", "fcvt.lu.s", "feq.s", "flt.s", "fle.s", "fsgnj.s", "fsgnjn.s", "fsgnjx.s", "fmin.s", "fmax.s", "fadd.s", "fsub.s", "fmul.s", "fmadd.s", "fmsub.s", "fnmadd.s", "fnmsub.s", "fdiv.s", "fsqrt.s", "fmv.x.d", "fclass.d", "fcvt.w.d", "fcvt.wu.d", "fcvt.l.d", "fcvt.lu.d", "fcvt.s.d", "fcvt.d.s", "feq.d", "flt.d", "fle.d", "fsgnj.d", "fsgnjn.d", "fsgnjx.d", "fmin.d", "fmax.d", "fadd.d", "fsub.d", "fmul.d", "fmadd.d", "fmsub.d", "fnmadd.d", "fnmsub.d", "fdiv.d", "fsqrt.d", "fcvt.h.d", "fcvt.d.h").contains(i) => y + override def genTable(op: RocketDecodePattern): BitPat = (op.instruction.name, op) match { + case (i, _) if Seq("fmv.x.h", "fclass.h", "fcvt.w.h", "fcvt.wu.h", "fcvt.l.h", "fcvt.lu.h", "fcvt.s.h", "fcvt.h.s", "feq.h", "flt.h", "fle.h", "fsgnj.h", "fsgnjn.h", "fsgnjx.h", "fmin.h", "fmax.h", "fadd.h", "fsub.h", "fmul.h", "fmadd.h", "fmsub.h", "fnmadd.h", "fnmsub.h", "fdiv.h", "fsqrt.h", "fmv.x.w", "fclass.s", "fcvt.w.s", "fcvt.wu.s", "fcvt.l.s", "fcvt.lu.s", "feq.s", "flt.s", "fle.s", "fsgnj.s", "fsgnjn.s", "fsgnjx.s", "fmin.s", "fmax.s", "fadd.s", "fsub.s", "fmul.s", "fmadd.s", "fmsub.s", "fnmadd.s", "fnmsub.s", "fdiv.s", "fsqrt.s", "fmv.x.d", "fclass.d", "fcvt.w.d", "fcvt.wu.d", "fcvt.l.d", "fcvt.lu.d", "fcvt.s.d", "fcvt.d.s", "feq.d", "flt.d", "fle.d", "fsgnj.d", "fsgnjn.d", "fsgnjx.d", "fmin.d", "fmax.d", "fadd.d", "fsub.d", "fmul.d", "fmadd.d", "fmsub.d", "fnmadd.d", "fnmsub.d", "fdiv.d", "fsqrt.d", "fcvt.h.d", "fcvt.d.h").contains(i) => y + case (_, p) if p.vectorReadFRegFile => y case _ => n } } @@ -800,11 +807,12 @@ case class DecoderParameter( object ftypeTagIn extends UOPDecodeField[RocketDecodePattern] { override def name: String = "typeTagIn" - override def genTable(op: RocketDecodePattern): BitPat = op.instruction.name match { - case i if Seq("fsh", "fmv.x.h", "fsw", "fmv.x.w", "fsd", "fmv.x.d").contains(i) => UOPFType.I - case i if Seq("fmv.h.x", "fcvt.h.w", "fcvt.h.wu", "fcvt.h.l", "fcvt.h.lu", "fclass.h", "fcvt.w.h", "fcvt.wu.h", "fcvt.l.h", "fcvt.lu.h", "fcvt.s.h", "feq.h", "flt.h", "fle.h", "fsgnj.h", "fsgnjn.h", "fsgnjx.h", "fmin.h", "fmax.h", "fadd.h", "fsub.h", "fmul.h", "fmadd.h", "fmsub.h", "fnmadd.h", "fnmsub.h", "fdiv.h", "fsqrt.h", "fcvt.d.h").contains(i) => UOPFType.H - case i if Seq("fcvt.h.s", "fmv.w.x", "fcvt.s.w", "fcvt.s.wu", "fcvt.s.l", "fcvt.s.lu", "fclass.s", "fcvt.w.s", "fcvt.wu.s", "fcvt.l.s", "fcvt.lu.s", "feq.s", "flt.s", "fle.s", "fsgnj.s", "fsgnjn.s", "fsgnjx.s", "fmin.s", "fmax.s", "fadd.s", "fsub.s", "fmul.s", "fmadd.s", "fmsub.s", "fnmadd.s", "fnmsub.s", "fdiv.s", "fsqrt.s", "fcvt.d.s").contains(i) => UOPFType.S - case i if Seq("fmv.d.x", "fcvt.d.w", "fcvt.d.wu", "fcvt.d.l", "fcvt.d.lu", "fclass.d", "fcvt.w.d", "fcvt.wu.d", "fcvt.l.d", "fcvt.lu.d", "fcvt.s.d", "feq.d", "flt.d", "fle.d", "fsgnj.d", "fsgnjn.d", "fsgnjx.d", "fmin.d", "fmax.d", "fadd.d", "fsub.d", "fmul.d", "fmadd.d", "fmsub.d", "fnmadd.d", "fnmsub.d", "fdiv.d", "fsqrt.d", "fcvt.h.d").contains(i) => UOPFType.D + override def genTable(op: RocketDecodePattern): BitPat = (op.instruction.name, op) match { + case (i, _) if Seq("fsh", "fmv.x.h", "fsw", "fmv.x.w", "fsd", "fmv.x.d").contains(i) => UOPFType.I + case (i, _) if Seq("fmv.h.x", "fcvt.h.w", "fcvt.h.wu", "fcvt.h.l", "fcvt.h.lu", "fclass.h", "fcvt.w.h", "fcvt.wu.h", "fcvt.l.h", "fcvt.lu.h", "fcvt.s.h", "feq.h", "flt.h", "fle.h", "fsgnj.h", "fsgnjn.h", "fsgnjx.h", "fmin.h", "fmax.h", "fadd.h", "fsub.h", "fmul.h", "fmadd.h", "fmsub.h", "fnmadd.h", "fnmsub.h", "fdiv.h", "fsqrt.h", "fcvt.d.h").contains(i) => UOPFType.H + case (i, _) if Seq("fcvt.h.s", "fmv.w.x", "fcvt.s.w", "fcvt.s.wu", "fcvt.s.l", "fcvt.s.lu", "fclass.s", "fcvt.w.s", "fcvt.wu.s", "fcvt.l.s", "fcvt.lu.s", "feq.s", "flt.s", "fle.s", "fsgnj.s", "fsgnjn.s", "fsgnjx.s", "fmin.s", "fmax.s", "fadd.s", "fsub.s", "fmul.s", "fmadd.s", "fmsub.s", "fnmadd.s", "fnmsub.s", "fdiv.s", "fsqrt.s", "fcvt.d.s").contains(i) => UOPFType.S + case (i, _) if Seq("fmv.d.x", "fcvt.d.w", "fcvt.d.wu", "fcvt.d.l", "fcvt.d.lu", "fclass.d", "fcvt.w.d", "fcvt.wu.d", "fcvt.l.d", "fcvt.lu.d", "fcvt.s.d", "feq.d", "flt.d", "fle.d", "fsgnj.d", "fsgnjn.d", "fsgnjx.d", "fmin.d", "fmax.d", "fadd.d", "fsub.d", "fmul.d", "fmadd.d", "fmsub.d", "fnmadd.d", "fnmsub.d", "fdiv.d", "fsqrt.d", "fcvt.h.d").contains(i) => UOPFType.D + case (_, op) if op.vectorReadFRegFile => UOPFType.I case _ => UOPFType.X2 } @@ -814,11 +822,12 @@ case class DecoderParameter( object ftypeTagOut extends UOPDecodeField[RocketDecodePattern] { override def name: String = "typeTagOut" - override def genTable(op: RocketDecodePattern): BitPat = op.instruction.name match { - case i if Seq("fmv.h.x", "fmv.w.x", "fmv.d.x").contains(i) => UOPFType.I - case i if Seq("fsh", "fcvt.h.w", "fcvt.h.wu", "fcvt.h.l", "fcvt.h.lu", "fmv.x.h", "fclass.h", "fcvt.h.s", "feq.h", "flt.h", "fle.h", "fsgnj.h", "fsgnjn.h", "fsgnjx.h", "fmin.h", "fmax.h", "fadd.h", "fsub.h", "fmul.h", "fmadd.h", "fmsub.h", "fnmadd.h", "fnmsub.h", "fdiv.h", "fsqrt.h", "fcvt.h.d").contains(i) => UOPFType.H - case i if Seq("fcvt.s.h", "fsw", "fcvt.s.w", "fcvt.s.wu", "fcvt.s.l", "fcvt.s.lu", "fmv.x.w", "fclass.s", "feq.s", "flt.s", "fle.s", "fsgnj.s", "fsgnjn.s", "fsgnjx.s", "fmin.s", "fmax.s", "fadd.s", "fsub.s", "fmul.s", "fmadd.s", "fmsub.s", "fnmadd.s", "fnmsub.s", "fdiv.s", "fsqrt.s", "fcvt.s.d").contains(i) => UOPFType.S - case i if Seq("fsd", "fcvt.d.w", "fcvt.d.wu", "fcvt.d.l", "fcvt.d.lu", "fmv.x.d", "fclass.d", "fcvt.d.s", "feq.d", "flt.d", "fle.d", "fsgnj.d", "fsgnjn.d", "fsgnjx.d", "fmin.d", "fmax.d", "fadd.d", "fsub.d", "fmul.d", "fmadd.d", "fmsub.d", "fnmadd.d", "fnmsub.d", "fdiv.d", "fsqrt.d", "fcvt.d.h").contains(i) => UOPFType.D + override def genTable(op: RocketDecodePattern): BitPat = (op.instruction.name, op) match { + case (i, _) if Seq("fmv.h.x", "fmv.w.x", "fmv.d.x").contains(i) => UOPFType.I + case (i, _) if Seq("fsh", "fcvt.h.w", "fcvt.h.wu", "fcvt.h.l", "fcvt.h.lu", "fmv.x.h", "fclass.h", "fcvt.h.s", "feq.h", "flt.h", "fle.h", "fsgnj.h", "fsgnjn.h", "fsgnjx.h", "fmin.h", "fmax.h", "fadd.h", "fsub.h", "fmul.h", "fmadd.h", "fmsub.h", "fnmadd.h", "fnmsub.h", "fdiv.h", "fsqrt.h", "fcvt.h.d").contains(i) => UOPFType.H + case (i, _) if Seq("fcvt.s.h", "fsw", "fcvt.s.w", "fcvt.s.wu", "fcvt.s.l", "fcvt.s.lu", "fmv.x.w", "fclass.s", "feq.s", "flt.s", "fle.s", "fsgnj.s", "fsgnjn.s", "fsgnjx.s", "fmin.s", "fmax.s", "fadd.s", "fsub.s", "fmul.s", "fmadd.s", "fmsub.s", "fnmadd.s", "fnmsub.s", "fdiv.s", "fsqrt.s", "fcvt.s.d").contains(i) => UOPFType.S + case (i, _) if Seq("fsd", "fcvt.d.w", "fcvt.d.wu", "fcvt.d.l", "fcvt.d.lu", "fmv.x.d", "fclass.d", "fcvt.d.s", "feq.d", "flt.d", "fle.d", "fsgnj.d", "fsgnjn.d", "fsgnjx.d", "fmin.d", "fmax.d", "fadd.d", "fsub.d", "fmul.d", "fmadd.d", "fmsub.d", "fnmadd.d", "fnmsub.d", "fdiv.d", "fsqrt.d", "fcvt.d.h").contains(i) => UOPFType.D + case (_, op) if op.vectorReadFRegFile => UOPFType.S case _ => UOPFType.X2 } @@ -837,8 +846,9 @@ case class DecoderParameter( object ftoint extends BoolDecodeField[RocketDecodePattern] { override def name: String = "toint" - override def genTable(op: RocketDecodePattern): BitPat = op.instruction.name match { - case i if Seq("fsh", "fmv.x.h", "fclass.h", "fcvt.w.h", "fcvt.wu.h", "fcvt.l.h", "fcvt.lu.h", "feq.h", "flt.h", "fle.h", "fsw", "fmv.x.w", "fclass.s", "fcvt.w.s", "fcvt.wu.s", "fcvt.l.s", "fcvt.lu.s", "feq.s", "flt.s", "fle.s", "fsd", "fmv.x.d", "fclass.d", "fcvt.w.d", "fcvt.wu.d", "fcvt.l.d", "fcvt.lu.d", "feq.d", "flt.d", "fle.d").contains(i) => y + override def genTable(op: RocketDecodePattern): BitPat = (op.instruction.name, op) match { + case (i, _) if Seq("fsh", "fmv.x.h", "fclass.h", "fcvt.w.h", "fcvt.wu.h", "fcvt.l.h", "fcvt.lu.h", "feq.h", "flt.h", "fle.h", "fsw", "fmv.x.w", "fclass.s", "fcvt.w.s", "fcvt.wu.s", "fcvt.l.s", "fcvt.lu.s", "feq.s", "flt.s", "fle.s", "fsd", "fmv.x.d", "fclass.d", "fcvt.w.d", "fcvt.wu.d", "fcvt.l.d", "fcvt.lu.d", "feq.d", "flt.d", "fle.d").contains(i) => y + case (_, op) if op.vectorReadFRegFile => y case _ => n } } @@ -922,6 +932,12 @@ case class RocketDecodePattern(instruction: Instruction) extends DecodePattern { case s"v${to}xei${sz}.v" if (to == "lo" || to == "lu" || to == "so" || to == "su") => true case _ => false } + def vectorReadFRegFile: Boolean = instruction.name match { + case i if Seq("vfadd.vf", "vfdiv.vf", "vfmacc.vf", "vfmadd.vf", "vfmax.vf", "vfmerge.vfm", "vfmin.vf", "vfmsac.vf", "vfmsub.vf", "vfmul.vf", "vfnmacc.vf", "vfnmadd.vf", "vfnmsac.vf", "vfnmsub.vf", "vfrdiv.vf", "vfredusum.vs", "vfrsub.vf", "vfsgnj.vf", "vfsgnjn.vf", "vfsgnjx.vf", "vfsub.vf", "vmfeq.vf", "vmfge.vf", "vmfgt.vf", "vmflt.vf", "vmfne.vf").contains(i) => true + case i if Seq("vfmv.s.f", "vfmv.v.f").contains(i) => true + + case _ => false + } // todo: unsure. def vectorReadRs1: Boolean = isVectorLSU || (instruction.name match { // vx type diff --git a/rocketv/src/RocketCore.scala b/rocketv/src/RocketCore.scala index 29fe04016c..93e707391c 100644 --- a/rocketv/src/RocketCore.scala +++ b/rocketv/src/RocketCore.scala @@ -974,7 +974,12 @@ class Rocket(val parameter: RocketParameter) Mux( !memRegException && memRegDecodeOutput(parameter.decoderParameter.fp) && memRegDecodeOutput(parameter.decoderParameter.wxd), fpu.toint_data, - memIntWdata + Mux( + !memRegException && Option.when(usingVector)(memRegDecodeOutput(parameter.decoderParameter.vectorReadFRs1)).getOrElse(false.B), + fpu.store_data, + memIntWdata + ) + ) ) .getOrElse(memIntWdata) @@ -1388,7 +1393,11 @@ class Rocket(val parameter: RocketParameter) io.fpu.foreach { fpu => fpuDecoder.get.io.instruction := idInstruction fpu.dec := fpuDecoder.get.io.output - fpu.valid := !ctrlKilled && idDecodeOutput(parameter.decoderParameter.fp) + fpu.valid := !ctrlKilled && ( + idDecodeOutput(parameter.decoderParameter.fp) || + // vector read frs1 + (fpu.dec.ren1 && idDecodeOutput(parameter.decoderParameter.vector)) + ) fpu.killx := ctrlKillx fpu.killm := killmCommon fpu.inst := idInstruction