diff --git a/src/main/antlr4/FIRRTL.g4 b/src/main/antlr4/FIRRTL.g4 index d40c65601a..ca4b3fe99d 100644 --- a/src/main/antlr4/FIRRTL.g4 +++ b/src/main/antlr4/FIRRTL.g4 @@ -181,6 +181,7 @@ subref : '.' fieldId subref? | '.' DoubleLit subref? // TODO Workaround for #470 | '[' (intLit | exp) ']' subref? + | '[' intLit ':' intLit ']' subref? ; id diff --git a/src/main/proto/firrtl.proto b/src/main/proto/firrtl.proto index 6ce1c10822..6aeb91d59e 100644 --- a/src/main/proto/firrtl.proto +++ b/src/main/proto/firrtl.proto @@ -436,6 +436,15 @@ message Firrtl { Expression index = 2; } + message WSliceNode { + // Required. + Expression expr = 1; + // Required. + IntegerLiteral hi = 2; + // Required. + IntegerLiteral lo = 3; + } + message PrimOp { enum Op { @@ -504,6 +513,7 @@ message Firrtl { SubIndex sub_index = 8; SubAccess sub_access = 9; PrimOp prim_op = 10; + WSliceNode w_slice_node = 12; } } } diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index ec68d4ebf1..b0efe5ce28 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -616,17 +616,19 @@ object Utils extends LazyLogging { case ex => ExpKind } def flow(e: Expression): Flow = e match { - case ex: WRef => ex.flow - case ex: WSubField => ex.flow - case ex: WSubIndex => ex.flow - case ex: WSubAccess => ex.flow + case ex: WRef => ex.flow + case ex: WSubField => ex.flow + case ex: WSubIndex => ex.flow + case ex: WSubAccess => ex.flow + // bits can be used as sink (in a sub-word assignment) or as source + case DoPrim(PrimOps.Bits, Seq(inner), _, _) => flow(inner) case ex: DoPrim => SourceFlow case ex: UIntLiteral => SourceFlow case ex: SIntLiteral => SourceFlow case ex: Mux => SourceFlow case ex: ValidIf => SourceFlow - case WInvalid => SourceFlow - case ex => throwInternalError(s"flow: shouldn't be here - $e") + case _: WIntInvalid => SourceFlow + case ex => throwInternalError(s"flow: shouldn't be here - $e") } def get_flow(s: Statement): Flow = s match { case sx: DefWire => DuplexFlow @@ -892,7 +894,7 @@ object Utils extends LazyLogging { def and(e1: Expression, e2: Expression): Expression = { assert(e1.tpe == e2.tpe) (e1, e2) match { - case (a: UIntLiteral, b: UIntLiteral) => UIntLiteral(a.value | b.value, a.width) + case (a: UIntLiteral, b: UIntLiteral) => UIntLiteral(a.value & b.value, a.width) case (True(), b) => b case (a, True()) => a case (False(), _) => False() @@ -902,6 +904,16 @@ object Utils extends LazyLogging { } } + /** Applies the firrtl Cat primop. */ + def cat(e1: Expression, e2: Expression): Expression = + DoPrim(PrimOps.Cat, Seq(e1, e2), Nil, UIntType(IntWidth(bitWidth(e1.tpe) + bitWidth(e2.tpe)))) + + /** Applies the firrtl bits primop. */ + def bits(e1: Expression, hi: BigInt, lo: BigInt): Expression = { + require(lo >= 0 && hi >= lo) + DoPrim(PrimOps.Bits, Seq(e1), Seq(hi, lo), UIntType(IntWidth(hi - lo + 1))) + } + /** Applies the firrtl Eq primop. */ def eq(e1: Expression, e2: Expression): Expression = DoPrim(PrimOps.Eq, Seq(e1, e2), Nil, BoolType) diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala index f3b6837e56..42f7c75b22 100644 --- a/src/main/scala/firrtl/Visitor.scala +++ b/src/main/scala/firrtl/Visitor.scala @@ -4,6 +4,7 @@ package firrtl import org.antlr.v4.runtime.ParserRuleContext import org.antlr.v4.runtime.tree.{AbstractParseTreeVisitor, ParseTreeVisitor, TerminalNode} + import scala.collection.JavaConverters._ import scala.collection.mutable import scala.annotation.tailrec @@ -13,6 +14,7 @@ import FIRRTLParser._ import Parser.{AppendInfo, GenInfo, IgnoreInfo, InfoMode, UseInfo} import firrtl.ir._ import Utils.throwInternalError +import firrtl.passes.WSliceNode class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] with ParseTreeVisitor[FirrtlNode] { // Strip file path @@ -470,12 +472,16 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w } } case "[" => - if (ctx.intLit != null) { - val lit = string2Int(ctx.intLit.getText) - SubIndex(inner, lit, UnknownType) + if (ctx.getChildCount == 5 || ctx.getChildCount == 6) { + WSliceNode(inner, string2Int(ctx.intLit(0).getText), string2Int(ctx.intLit(1).getText)) } else { - val idx = visitExp(ctx.exp) - SubAccess(inner, idx, UnknownType) + assert(ctx.getChildCount == 3 || (ctx.getChildCount == 4 && ctx.subref() != null)) + if (ctx.intLit(0) != null) { + val index = string2Int(ctx.intLit(0).getText) + WSliceNode(inner, index, index) + } else { // firrtl expressions are only allowed for Vec sub-accesses + SubAccess(inner, visitExp(ctx.exp), UnknownType) + } } } if (ctx.subref != null) { diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index 6198c29dbf..923fa0bc74 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -74,6 +74,18 @@ object WSubAccess { ) } +/** A ground type void that carries an explicit bit-width. Used in the ExpandWhens pass. */ +case class WIntVoid(width: BigInt) extends Expression with UseSerializer { + override def tpe = UIntType(IntWidth(width)) + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = () +} + +@deprecated("Use WIntVoid instead.", "FIRRTL 1.6") case object WVoid extends Expression with UseSerializer { def tpe = UnknownType def mapExpr(f: Expression => Expression): Expression = this @@ -83,6 +95,19 @@ case object WVoid extends Expression with UseSerializer { def foreachType(f: Type => Unit): Unit = () def foreachWidth(f: Width => Unit): Unit = () } + +/** A ground type invalid that carries an explicit bit-width. Used in the ExpandWhens pass. */ +case class WIntInvalid(width: BigInt) extends Expression with UseSerializer { + override def tpe = UIntType(IntWidth(width)) + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = () +} + +@deprecated("Use WIntInvalid instead.", "FIRRTL 1.6") case object WInvalid extends Expression with UseSerializer { def tpe = UnknownType def mapExpr(f: Expression => Expression): Expression = this @@ -233,8 +258,8 @@ class WrappedExpression(val e1: Expression) { case (e1x: WSubField, e2x: WSubField) => (e1x.name.equals(e2x.name)) && weq(e1x.expr, e2x.expr) case (e1x: WSubIndex, e2x: WSubIndex) => (e1x.value == e2x.value) && weq(e1x.expr, e2x.expr) case (e1x: WSubAccess, e2x: WSubAccess) => weq(e1x.index, e2x.index) && weq(e1x.expr, e2x.expr) - case (WVoid, WVoid) => true - case (WInvalid, WInvalid) => true + case (_: WIntVoid, _: WIntVoid) => true + case (_: WIntInvalid, _: WIntInvalid) => true case (e1x: DoPrim, e2x: DoPrim) => e1x.op == e2x.op && ((e1x.consts.zip(e2x.consts)).forall { case (x, y) => x == y }) && diff --git a/src/main/scala/firrtl/analyses/ConnectionGraph.scala b/src/main/scala/firrtl/analyses/ConnectionGraph.scala index 85cbe4df65..d12b3ab069 100644 --- a/src/main/scala/firrtl/analyses/ConnectionGraph.scala +++ b/src/main/scala/firrtl/analyses/ConnectionGraph.scala @@ -7,7 +7,7 @@ import firrtl.annotations.{TargetToken, _} import firrtl.graph.{CyclicException, DiGraph, MutableDiGraph} import firrtl.ir._ import firrtl.passes.MemPortUtils -import firrtl.{InstanceKind, PortKind, SinkFlow, SourceFlow, Utils, WInvalid} +import firrtl.{InstanceKind, PortKind, SinkFlow, SourceFlow, Utils, WIntInvalid} import scala.collection.mutable @@ -338,16 +338,16 @@ object ConnectionGraph { * @return */ def asTarget(m: ModuleTarget, tagger: TokenTagger)(e: FirrtlNode): ReferenceTarget = e match { - case l: Literal => m.ref(tagger.getRef(l.value.toString)) - case r: Reference => m.ref(r.name) - case s: SubIndex => asTarget(m, tagger)(s.expr).index(s.value) - case s: SubField => asTarget(m, tagger)(s.expr).field(s.name) - case d: DoPrim => m.ref(tagger.getRef(d.op.serialize)) - case _: Mux => m.ref(tagger.getRef("mux")) - case _: ValidIf => m.ref(tagger.getRef("validif")) - case WInvalid => m.ref(tagger.getRef("invalid")) - case _: Print => m.ref(tagger.getRef("print")) - case _: Stop => m.ref(tagger.getRef("print")) + case l: Literal => m.ref(tagger.getRef(l.value.toString)) + case r: Reference => m.ref(r.name) + case s: SubIndex => asTarget(m, tagger)(s.expr).index(s.value) + case s: SubField => asTarget(m, tagger)(s.expr).field(s.name) + case d: DoPrim => m.ref(tagger.getRef(d.op.serialize)) + case _: Mux => m.ref(tagger.getRef("mux")) + case _: ValidIf => m.ref(tagger.getRef("validif")) + case _: WIntInvalid => m.ref(tagger.getRef("invalid")) + case _: Print => m.ref(tagger.getRef("print")) + case _: Stop => m.ref(tagger.getRef("print")) case other => sys.error(s"Unsupported: $other") } @@ -508,7 +508,7 @@ object ConnectionGraph { buildExpression(m, tagger, sinkTarget)(c.expr) case i: IsInvalid => - val sourceTarget = asTarget(m, tagger)(WInvalid) + val sourceTarget = asTarget(m, tagger)(WIntInvalid(firrtl.bitWidth(i.expr.tpe))) addLabeledVertex(sourceTarget, stmt) mdg.addVertex(sourceTarget) val sinkTarget = asTarget(m, tagger)(i.expr) diff --git a/src/main/scala/firrtl/analyses/IRLookup.scala b/src/main/scala/firrtl/analyses/IRLookup.scala index e403c1496c..7fff4e2d97 100644 --- a/src/main/scala/firrtl/analyses/IRLookup.scala +++ b/src/main/scala/firrtl/analyses/IRLookup.scala @@ -19,7 +19,7 @@ import firrtl.{ SourceFlow, UnknownFlow, Utils, - WInvalid, + WIntInvalid, WireKind } @@ -124,10 +124,10 @@ class IRLookup private[analyses] ( case other => sys.error(s"Cannot call expr with: $t, given declaration $other") } - case _: IsInvalid => + case i: IsInvalid => exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())( (pathless, SourceFlow) - ) = WInvalid + ) = WIntInvalid(firrtl.bitWidth(i.expr.tpe)) } } } diff --git a/src/main/scala/firrtl/ir/Serializer.scala b/src/main/scala/firrtl/ir/Serializer.scala index f5457dea4c..b244ca4b62 100644 --- a/src/main/scala/firrtl/ir/Serializer.scala +++ b/src/main/scala/firrtl/ir/Serializer.scala @@ -86,10 +86,12 @@ object Serializer { b ++= "Fixed"; s(width); sPoint(point) b ++= "(\"h"; b ++= value.toString(16); b ++= "\")" // WIR - case firrtl.WVoid => b ++= "VOID" - case firrtl.WInvalid => b ++= "INVALID" - case firrtl.EmptyExpression => b ++= "EMPTY" - case other => b ++= other.serialize // Handle user-defined nodes + case firrtl.WVoid => b ++= "VOID" + case firrtl.WInvalid => b ++= "INVALID" + case firrtl.WIntVoid(width) => b ++= "VOID<"; b ++= width.toString(); b += '>' + case firrtl.WIntInvalid(width) => b ++= "INVALID<"; b ++= width.toString(); b += '>' + case firrtl.EmptyExpression => b ++= "EMPTY" + case other => b ++= other.serialize // Handle user-defined nodes } private def s(node: Statement)(implicit b: StringBuilder, indent: Int): Unit = node match { diff --git a/src/main/scala/firrtl/ir/StructuralHash.scala b/src/main/scala/firrtl/ir/StructuralHash.scala index 26e7d210e8..4c8d791a81 100644 --- a/src/main/scala/firrtl/ir/StructuralHash.scala +++ b/src/main/scala/firrtl/ir/StructuralHash.scala @@ -229,8 +229,10 @@ class StructuralHash private (h: Hasher, renameModule: String => String) { case firrtl.WInvalid => id(11) case firrtl.EmptyExpression => id(12) // VRandom is used in the Emitter - case firrtl.VRandom(width) => id(13); hash(width) - // ids 14 ... 19 are reserved for future Expression nodes + case firrtl.VRandom(width) => id(13); hash(width) + case firrtl.WIntVoid(width) => id(14); hash(width) + case firrtl.WIntInvalid(width) => id(15); hash(width) + // ids 16 ... 19 are reserved for future Expression nodes } private def hash(node: Statement): Unit = node match { diff --git a/src/main/scala/firrtl/passes/CheckFlows.scala b/src/main/scala/firrtl/passes/CheckFlows.scala index f78a115a04..6d0d0b42ff 100644 --- a/src/main/scala/firrtl/passes/CheckFlows.scala +++ b/src/main/scala/firrtl/passes/CheckFlows.scala @@ -39,15 +39,21 @@ object CheckFlows extends Pass { def run(c: Circuit): Circuit = { val errors = new Errors() - def get_flow(e: Expression, flows: FlowMap): Flow = e match { + def get_flow(e: Expression, flows: FlowMap, desired: Flow): Flow = e match { case (e: WRef) => flows(e.name) - case (e: WSubIndex) => get_flow(e.expr, flows) - case (e: WSubAccess) => get_flow(e.expr, flows) + case (e: WSubIndex) => get_flow(e.expr, flows, desired) + case (e: WSubAccess) => get_flow(e.expr, flows, desired) case (e: WSubField) => e.expr.tpe match { case t: BundleType => val f = (t.fields.find(_.name == e.name)).get - times(get_flow(e.expr, flows), f.flip) + times(get_flow(e.expr, flows, desired), f.flip) + } + case DoPrim(PrimOps.Bits, Seq(expr), _, _) => // bits can be used as sink (sub-word assignments) and source flow + // when bits are used on the rhs, then they are treated like any other primop + if (desired == SourceFlow) { SourceFlow } + else { + get_flow(expr, flows, desired) } case _ => SourceFlow } @@ -62,7 +68,7 @@ object CheckFlows extends Pass { } def check_flow(info: Info, mname: String, flows: FlowMap, desired: Flow)(e: Expression): Unit = { - val flow = get_flow(e, flows) + val flow = get_flow(e, flows, desired) (flow, desired) match { case (SourceFlow, SinkFlow) => errors.append(new WrongFlow(info, mname, e.serialize, desired, flow)) diff --git a/src/main/scala/firrtl/passes/CheckHighForm.scala b/src/main/scala/firrtl/passes/CheckHighForm.scala index 8a88f8273a..15dec7d91a 100644 --- a/src/main/scala/firrtl/passes/CheckHighForm.scala +++ b/src/main/scala/firrtl/passes/CheckHighForm.scala @@ -207,6 +207,8 @@ trait CheckHighFormLike { this: Pass => } def checkValidLoc(info: Info, mname: String, e: Expression): Unit = e match { + case DoPrim(PrimOps.Bits, Seq(expr), _, _) => // bit slices are allowed as locs for subword assignments + checkValidLoc(info, mname, expr) case _: UIntLiteral | _: SIntLiteral | _: DoPrim => errors.append(new InvalidLOCException(info, mname)) case _ => // Do Nothing @@ -232,7 +234,7 @@ trait CheckHighFormLike { this: Pass => def validSubexp(info: Info, mname: String)(e: Expression): Unit = { e match { case _: Reference | _: SubField | _: SubIndex | _: SubAccess => // No error - case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf => // No error + case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf | _: WSliceNode => // No error case _ => errors.append(new InvalidAccessException(info, mname)) } } @@ -247,7 +249,8 @@ trait CheckHighFormLike { this: Pass => errors.append(new NegUIntException(info, mname)) case ex: DoPrim => checkHighFormPrimop(info, mname, ex) case _: Reference | _: WRef | _: UIntLiteral | _: Mux | _: ValidIf => - case ex: SubAccess => validSubexp(info, mname)(ex.expr) + case ex: SubAccess => validSubexp(info, mname)(ex.expr) + case ex: WSliceNode => validSubexp(info, mname)(ex.expr) case ex => ex.foreach(validSubexp(info, mname)) } e.foreach(checkHighFormW(info, mname + "/" + e.serialize)) diff --git a/src/main/scala/firrtl/passes/CheckInitialization.scala b/src/main/scala/firrtl/passes/CheckInitialization.scala index 896e21fc7a..ef008ffec5 100644 --- a/src/main/scala/firrtl/passes/CheckInitialization.scala +++ b/src/main/scala/firrtl/passes/CheckInitialization.scala @@ -11,7 +11,7 @@ import annotation.tailrec /** Reports errors for any references that are not fully initialized * - * @note This pass looks for [[firrtl.WVoid]]s left behind by [[ExpandWhens]] + * @note This pass looks for [[firrtl.WIntVoid]]s left behind by [[ExpandWhens]] * @note Assumes single connection (ie. no last connect semantics) */ object CheckInitialization extends Pass { @@ -48,7 +48,7 @@ object CheckInitialization extends Pass { var void = false val voidDeps = collection.mutable.ArrayBuffer[Expression]() def hasVoid(e: Expression): Unit = e match { - case WVoid => + case _: WIntVoid => void = true case (_: WRef | _: WSubField) => if (voidExprs.contains(e)) { diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index 8fb4e5fbd6..21ded953e5 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -66,20 +66,120 @@ object ExpandWhens extends Pass { */ type Defaults = Seq[mutable.Map[WrappedExpression, Expression]] + private def getVoid(tpe: ir.Type): WIntVoid = WIntVoid(bitWidth(tpe)) + private def doConnect( + netlist: Netlist, + defaults: Defaults, + lhs: ir.Expression, + rhs: ir.Expression, + info: Info + ): Unit = + lhs match { + case ref: ir.RefLikeExpression => + netlist(ref) = InfoExpr(info, rhs) + case otherLhs => + // normalize nested bits expressions before check + simplifyBits(otherLhs) match { + case ir.DoPrim(PrimOps.Bits, Seq(ref: ir.RefLikeExpression), Seq(hi, lo), _) => + val refWidth = bitWidth(ref.tpe) + val highest = refWidth - 1 + assert(hi < refWidth && lo >= 0) + // if we are assigning the whole range, things are simple + if (hi == highest && lo == 0) { + netlist(ref) = InfoExpr(info, rhs) + } else { + // perform a read modify write + val (prevInfo, prev) = unwrap( + netlist.getOrElse(ref, getDefault(ref, defaults).getOrElse(WIntVoid(refWidth))) + ) + val msb = if (hi == highest) { simplifyBits(rhs) } + else { + Utils.cat(simplifyBits(prev, highest, hi + 1), simplifyBits(rhs)) + } + val full = if (lo == 0) { msb } + else { + Utils.cat(msb, simplifyBits(prev, lo - 1, 0)) + } + netlist(ref) = InfoExpr(ir.MultiInfo(Seq(prevInfo, info)), full) + } + case other => + throw new PassException(s"Invalid expression at the left hand side of an assignment: ${other.serialize}") + } + } + private def simplifyBits(e: ir.Expression, hi: BigInt, lo: BigInt): ir.Expression = + simplifyBits(Utils.bits(e, hi, lo)) + + /** performs special simplifications which are needed in the case of sub-word assignments to avoid false positives + * in the connection check + */ + private def simplifyBits(e: ir.Expression): ir.Expression = e match { + case ir.DoPrim(PrimOps.Bits, Seq(expr), Seq(hi, lo), _) => + expr match { + // combine bits of bits + case ir.DoPrim(PrimOps.Bits, Seq(innerExpr), Seq(_, innerLo), _) => + simplifyBits(innerExpr, hi + innerLo, lo + innerLo) + // push bits into mux + case ir.Mux(cond, tval, fval, tpe) => + val tru = simplifyBits(tval, hi, lo) + val fals = simplifyBits(fval, hi, lo) + assert(tru.tpe == fals.tpe) + ir.Mux(cond, tru, fals, tru.tpe) + // push bits into concat + case ir.DoPrim(PrimOps.Cat, Seq(msb, lsb), _, _) => + val lsbWidth = bitWidth(lsb.tpe) + if (lsbWidth > hi) { + simplifyBits(lsb, hi, lo) + } else if (lo >= lsbWidth) { + simplifyBits(msb, hi - lsbWidth, lo - lsbWidth) + } else { + Utils.cat( + simplifyBits(msb, hi - lsbWidth, 0), + simplifyBits(lsb, lsbWidth - 1, lo) + ) + } + // reduce size of void + case _: WIntVoid => WIntVoid(hi - lo + 1) + case _: WIntInvalid => simplifyBits(WIntInvalid(hi - lo + 1)) + case _ => e // nothing to simplify + } + // we replace any invalids with zero because the ValidIf construct does not work well for subword assignments + case WIntInvalid(width) => Utils.getGroundZero(UIntType(IntWidth(width))) + case _ => e // nothing to simplify + } + + /** combines assignments from two branches */ + private def combineBranches( + pred: Expression, + trueValue: Expression, + falseValue: Expression, + tinfo: Info, + finfo: Info + ): (Expression, Info, Info) = { + (trueValue, falseValue) match { + case (i: WIntInvalid, _: WIntInvalid) => (i, NoInfo, NoInfo) + case (_: WIntInvalid, fv) => (ValidIf(NOT(pred), fv, fv.tpe), finfo, NoInfo) + case (tv, _: WIntInvalid) => (ValidIf(pred, tv, tv.tpe), tinfo, NoInfo) + case (tv, fv) => (Mux(pred, tv, fv, mux_type_and_widths(tv, fv)), tinfo, finfo) + } + } + /** Expands a module's when statements */ private def onModule(m: Module): Module = { val namespace = Namespace(m) val simlist = new Simlist // Memoizes if an expression contains any WVoids inserted in this pass - val memoizedVoid = new mutable.HashSet[WrappedExpression] += WVoid + val memoizedVoid = new mutable.HashSet[WrappedExpression] // Does an expression contain WVoid inserted in this pass? def containsVoid(e: Expression): Boolean = e match { - case WVoid => true + case _: WIntVoid => true + case ValidIf(_, _: WIntVoid, _) => true case ValidIf(_, value, _) => memoizedVoid(value) - case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv) - case _ => false + case Mux(_, _: WIntVoid, _, _) => true + case Mux(_, _, _: WIntVoid, _) => true + case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv) + case _ => false } // Memoizes the node that holds a particular expression, if any @@ -101,13 +201,13 @@ object ExpandWhens extends Pass { // Return self, unchanged case stmt @ (_: DefNode | EmptyStmt) => stmt case w: DefWire => - netlist ++= (getSinkRefs(w.name, w.tpe, DuplexFlow).map(ref => we(ref) -> WVoid)) + netlist ++= (getSinkRefs(w.name, w.tpe, DuplexFlow).map(ref => we(ref) -> getVoid(ref.tpe))) w case w: DefMemory => - netlist ++= (getSinkRefs(w.name, MemPortUtils.memType(w), SourceFlow).map(ref => we(ref) -> WVoid)) + netlist ++= (getSinkRefs(w.name, MemPortUtils.memType(w), SourceFlow).map(ref => we(ref) -> getVoid(ref.tpe))) w case w: WDefInstance => - netlist ++= (getSinkRefs(w.name, w.tpe, SourceFlow).map(ref => we(ref) -> WVoid)) + netlist ++= (getSinkRefs(w.name, w.tpe, SourceFlow).map(ref => we(ref) -> getVoid(ref.tpe))) w case r: DefRegister => // Update netlist with self reference for each sink reference @@ -115,10 +215,10 @@ object ExpandWhens extends Pass { r // For value assignments, update netlist/attaches and return EmptyStmt case c: Connect => - netlist(c.loc) = InfoExpr(c.info, c.expr) + doConnect(netlist, defaults, c.loc, c.expr, c.info) EmptyStmt case c: IsInvalid => - netlist(c.expr) = WInvalid + doConnect(netlist, defaults, c.expr, WIntInvalid(bitWidth(c.expr.tpe)), c.info) EmptyStmt case a: Attach => attaches += a @@ -161,12 +261,7 @@ object ExpandWhens extends Pass { case Some(defaultValue) => val (tinfo, trueValue) = unwrap(conseqNetlist.getOrElse(lvalue, defaultValue)) val (finfo, falseValue) = unwrap(altNetlist.getOrElse(lvalue, defaultValue)) - (trueValue, falseValue) match { - case (WInvalid, WInvalid) => (WInvalid, NoInfo, NoInfo) - case (WInvalid, fv) => (ValidIf(NOT(sx.pred), fv, fv.tpe), finfo, NoInfo) - case (tv, WInvalid) => (ValidIf(sx.pred, tv, tv.tpe), tinfo, NoInfo) - case (tv, fv) => (Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)), tinfo, finfo) - } + combineBranches(sx.pred, trueValue, falseValue, tinfo, finfo) case None => // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt (conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)), NoInfo, NoInfo) @@ -205,7 +300,7 @@ object ExpandWhens extends Pass { // Add ports to netlist netlist ++= (m.ports.flatMap { case Port(_, name, dir, tpe) => - getSinkRefs(name, tpe, to_flow(dir)).map(ref => we(ref) -> WVoid) + getSinkRefs(name, tpe, to_flow(dir)).map(ref => we(ref) -> getVoid(ref.tpe)) }) // Do traversal and construct mutable datastructures val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body) @@ -242,8 +337,8 @@ object ExpandWhens extends Pass { def handleInvalid(k: WrappedExpression, info: Info): Statement = if (attached.contains(k)) EmptyStmt else IsInvalid(info, k.e1) netlist.map { - case (k, WInvalid) => handleInvalid(k, NoInfo) - case (k, InfoExpr(info, WInvalid)) => handleInvalid(k, info) + case (k, _: WIntInvalid) => handleInvalid(k, NoInfo) + case (k, InfoExpr(info, _: WIntInvalid)) => handleInvalid(k, info) case (k, v) => val (info, expr) = unwrap(v) Connect(info, k.e1, expr) @@ -286,10 +381,8 @@ object ExpandWhens extends Pass { } } - private def AND(e1: Expression, e2: Expression) = - DoPrim(And, Seq(e1, e2), Nil, BoolType) - private def NOT(e: Expression) = - DoPrim(Eq, Seq(e, zero), Nil, BoolType) + private def AND(e1: Expression, e2: Expression) = Utils.and(e1, e2) + private def NOT(e: Expression) = Utils.not(e) } class ExpandWhensAndCheck extends Transform with DependencyAPIMigration { diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 8ab78fee6b..5bdae55802 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -9,7 +9,7 @@ import firrtl.Mappers._ import firrtl.options.Dependency object InferTypes extends Pass { - + import CInferTypes.resolveSlice override def prerequisites = Dependency(ResolveKinds) +: firrtl.stage.Forms.MinimalHighForm override def invalidates(a: Transform) = false @@ -54,6 +54,7 @@ object InferTypes extends Pass { case e: Mux => e.copy(tpe = mux_type_and_widths(e.tval, e.fval)) case e: ValidIf => e.copy(tpe = e.value.tpe) case e @ (_: UIntLiteral | _: SIntLiteral) => e + case e: WSliceNode => resolveSlice(e) } def infer_types_s(types: TypeLookup)(s: Statement): Statement = s match { @@ -97,6 +98,25 @@ object InferTypes extends Pass { } } +/** Internal node used by the parser for the `[...]` and `[... : ...]` syntax. + * CHIRRTL level type inference then converts this into either a `bits(..., ... , ...)`, + * or [[ir.SubIndex]] depending on the inferred type of the inner expression. + */ +case class WSliceNode(expr: ir.Expression, hi: Int, lo: Int) extends ir.Expression { + override def tpe = ir.UnknownType + override def mapExpr(f: ir.Expression => ir.Expression) = copy(expr = f(expr)) + override def mapType(f: ir.Type => ir.Type) = { f(ir.UnknownType); this } + override def mapWidth(f: ir.Width => ir.Width) = this + override def foreachExpr(f: ir.Expression => Unit): Unit = f(expr) + override def foreachType(f: ir.Type => Unit): Unit = f(ir.UnknownType) + override def foreachWidth(f: ir.Width => Unit): Unit = () + override def serialize = { + val suffix = if (hi == lo) { s"[$hi]" } + else { s"[$hi:$lo]" } + expr.serialize + suffix + } +} + object CInferTypes extends Pass { override def prerequisites = firrtl.stage.Forms.ChirrtlForm @@ -107,6 +127,17 @@ object CInferTypes extends Pass { private type TypeLookup = collection.mutable.HashMap[String, Type] + /** Turns an internal slice node into either a bits primop or a sub index depending on types. + * This is necessary because the parser cannot disambiguate the three AST nodes since they share the same syntax. + */ + def resolveSlice(e: WSliceNode): Expression = e.expr.tpe match { + case _: VectorType if e.hi == e.lo => + SubIndex(e.expr, e.hi, sub_type(e.expr.tpe)) + case _ => + val op = DoPrim(PrimOps.Bits, List(e.expr), List(e.hi, e.lo), UnknownType) + PrimOps.set_primop_type(op) + } + def run(c: Circuit): Circuit = { val mtypes = (c.modules.map(m => m.name -> module_type(m))).toMap @@ -120,6 +151,7 @@ object CInferTypes extends Pass { case (e: Mux) => e.copy(tpe = mux_type(e.tval, e.fval)) case (e: ValidIf) => e.copy(tpe = e.value.tpe) case e @ (_: UIntLiteral | _: SIntLiteral) => e + case e: WSliceNode => resolveSlice(e) } def infer_types_s(types: TypeLookup)(s: Statement): Statement = s match { diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala index 7a1a57fbef..6769b8b1af 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala @@ -25,7 +25,7 @@ object AnalysisUtils { case DefNode(_, name, value) => connects(name) = value case IsInvalid(_, value) => - connects(value.serialize) = WInvalid + connects(value.serialize) = WIntInvalid(bitWidth(value.tpe)) case _ => // do nothing } s.map(getConnects(connects)) diff --git a/src/main/scala/firrtl/proto/FromProto.scala b/src/main/scala/firrtl/proto/FromProto.scala index 91b3f872d0..270ec6d567 100644 --- a/src/main/scala/firrtl/proto/FromProto.scala +++ b/src/main/scala/firrtl/proto/FromProto.scala @@ -122,6 +122,9 @@ object FromProto { def convert(validif: Firrtl.Expression.ValidIf): ir.ValidIf = ir.ValidIf(convert(validif.getCondition), convert(validif.getValue), ir.UnknownType) + def convert(slice: Firrtl.Expression.WSliceNode): firrtl.passes.WSliceNode = + firrtl.passes.WSliceNode(convert(slice.getExpr), convert(slice.getHi).toInt, convert(slice.getLo).toInt) + def convert(expr: Firrtl.Expression): ir.Expression = { import Firrtl.Expression._ expr.getExpressionCase.getNumber match { @@ -135,6 +138,7 @@ object FromProto { case PRIM_OP_FIELD_NUMBER => convert(expr.getPrimOp) case MUX_FIELD_NUMBER => convert(expr.getMux) case VALID_IF_FIELD_NUMBER => convert(expr.getValidIf) + case W_SLICE_NODE_FIELD_NUMBER => convert(expr.getWSliceNode) } } diff --git a/src/main/scala/firrtl/proto/ToProto.scala b/src/main/scala/firrtl/proto/ToProto.scala index f5ade0e3e6..38242cf0f7 100644 --- a/src/main/scala/firrtl/proto/ToProto.scala +++ b/src/main/scala/firrtl/proto/ToProto.scala @@ -165,6 +165,13 @@ object ToProto { .setExpression(convert(e)) .setIndex(convertToIntegerLiteral(value)) eb.setSubIndex(sb) + case firrtl.passes.WSliceNode(expr, hi, lo) => + val sb = Firrtl.Expression.WSliceNode + .newBuilder() + .setExpr(convert(expr)) + .setHi(convertToIntegerLiteral(hi)) + .setLo(convertToIntegerLiteral(lo)) + eb.setWSliceNode(sb) case ir.SubAccess(e, index, _, _) => val sb = Firrtl.Expression.SubAccess .newBuilder() diff --git a/src/test/scala/firrtl/testutils/LeanTransformSpec.scala b/src/test/scala/firrtl/testutils/LeanTransformSpec.scala index d3510326a4..2d1cad8de7 100644 --- a/src/test/scala/firrtl/testutils/LeanTransformSpec.scala +++ b/src/test/scala/firrtl/testutils/LeanTransformSpec.scala @@ -35,6 +35,10 @@ class LeanTransformSpec(protected val transforms: Seq[TransformDependency]) actual should be(expected) finalState } + protected def removeSkip(c: ir.Circuit): ir.Circuit = { + def onStmt(s: ir.Statement): ir.Statement = s.mapStmt(onStmt) + c.mapModule(m => m.mapStmt(onStmt)) + } } private object LeanTransformSpec { diff --git a/src/test/scala/firrtlTests/ExpandWhensSpec.scala b/src/test/scala/firrtlTests/ExpandWhensSpec.scala index 52c87ffbc6..5299f9fe33 100644 --- a/src/test/scala/firrtlTests/ExpandWhensSpec.scala +++ b/src/test/scala/firrtlTests/ExpandWhensSpec.scala @@ -33,9 +33,9 @@ class ExpandWhensSpec extends FirrtlFlatSpec { val lines = c.serialize.split("\n").map(normalized) if (expected) { - c.serialize.contains(check) should be(true) + assert(c.serialize.contains(check)) } else { - lines.foreach(_.contains(check) should be(false)) + lines.foreach(l => assert(!l.contains(check))) } } "Expand Whens" should "not emit INVALID" in { @@ -146,7 +146,7 @@ class ExpandWhensSpec extends FirrtlFlatSpec { | else : | skip""".stripMargin val check = - "assert(clock, eq(in, UInt<1>(\"h1\")), and(and(UInt<1>(\"h1\"), p), UInt<1>(\"h1\")), \"assert0\") : test_assert" + "assert(clock, eq(in, UInt<1>(\"h1\")), p, \"assert0\") : test_assert" executeTest(input, check, true) } it should "handle stops" in { @@ -160,7 +160,7 @@ class ExpandWhensSpec extends FirrtlFlatSpec { | stop(clock, UInt(1), 1) : test_stop | else : | skip""".stripMargin - val check = """stop(clock, and(and(UInt<1>("h1"), p), UInt<1>("h1")), 1) : test_stop""" + val check = """stop(clock, p, 1) : test_stop""" executeTest(input, check, true) } } diff --git a/src/test/scala/firrtlTests/SubWordAssignmentTests.scala b/src/test/scala/firrtlTests/SubWordAssignmentTests.scala new file mode 100644 index 0000000000..1623e1a520 --- /dev/null +++ b/src/test/scala/firrtlTests/SubWordAssignmentTests.scala @@ -0,0 +1,342 @@ +package firrtlTests + +import firrtl.passes.CheckInitialization.RefNotInitializedException +import firrtl.passes.CheckWidths.BitsWidthException +import firrtl.stage.Forms +import firrtl.testutils.LeanTransformSpec + +class SubWordAssignmentTests extends LeanTransformSpec(Forms.LowFormOptimized) { + behavior.of("SubWordAssignment") + + private def check(input: String, expected: String) = { + val r = compile(input) + assert(removeSkip(r.circuit).serialize == removeSkip(parse(expected)).serialize) + } + + it should "support assigning individual output bits" in { + val src = + """circuit m: + | module m: + | output x : UInt<2> + | x[0] <= UInt(1) + | x[1] <= UInt(0) + |""".stripMargin + val expected = + """circuit m : + | module m : + | output x : UInt<2> + | x <= UInt<2>("h1") + |""".stripMargin + check(src, expected) + } + + it should "support assigning output bit ranges" in { + val src = + """circuit m: + | module m: + | output x : UInt<3> + | x[0] <= UInt(1) + | x[2:1] <= UInt(2) + |""".stripMargin + val expected = + """circuit m : + | module m : + | output x : UInt<3> + | x <= UInt<3>("h5") + |""".stripMargin + check(src, expected) + } + + it should "throw an error on uninitialized bits" in { + val src = + """circuit m: + | module m: + | output x : UInt<3> + | x[2:1] <= UInt(2) + |""".stripMargin + val e = intercept[RefNotInitializedException] { compile(src) } + assert(e.getMessage.contains("Reference x is not fully initialized")) + } + + it should "allow marking individual bits as DontCare" in { + val src = + """circuit m: + | module m: + | output x : UInt<3> + | x[0] is invalid + | x[2:1] <= UInt(2) + |""".stripMargin + val expected = + """circuit m : + | module m : + | output x : UInt<3> + | x <= UInt<3>("h4") + |""".stripMargin + check(src, expected) + } + + it should "allow marking individual bits as DontCare with conditionals" in { + val src = + """circuit m: + | module m: + | input c : UInt<1> + | output x : UInt<2> + | x[0] <= UInt(1) + | when c: + | x[1] <= UInt(1) + | else: + | x[1] is invalid + |""".stripMargin + // TODO: currently we actually do not perform this optimization since we replace invalid with 0 for sub-word assignments + val expected = + """circuit m : + | module m : + | input c : UInt<1> + | output x : UInt<2> + | x <= mux(c, UInt<2>("h3"), UInt<2>("h1")) + |""".stripMargin + check(src, expected) + } + + it should "take advantage of DontCare bits in different branches" in { + val src = + """circuit m: + | module m: + | input c : UInt<1> + | output x : UInt<2> + | when c: + | x[0] is invalid + | x[1] <= UInt(1) + | else: + | x[0] <= UInt(1) + | x[1] is invalid + |""".stripMargin + // TODO: currently we actually do not perform this optimization since we replace invalid with 0 for sub-word assignments + val expected = + """circuit m : + | module m : + | input c : UInt<1> + | output x : UInt<2> + | x <= mux(c, UInt<2>("h2"), UInt<2>("h1")) + |""".stripMargin + check(src, expected) + } + + it should "support assignments in conditionals" in { + val src = + """circuit m: + | module m: + | input c : UInt<1> + | output x : UInt<2> + | x[0] <= UInt(1) + | when c: + | x[1] <= UInt(1) + | else: + | x[1] <= UInt(0) + |""".stripMargin + val expected = + """circuit m : + | module m : + | input c : UInt<1> + | output x : UInt<2> + | x <= mux(c, UInt<2>("h3"), UInt<2>("h1")) + |""".stripMargin + check(src, expected) + } + + it should "support assignments for bundles" in { + val src = + """circuit m: + | module m: + | input c : UInt<1> + | output x : { a : UInt<2>, b: UInt<2> } + | x.a[0] <= UInt(1) + | when c: + | x.a[1] <= UInt(1) + | else: + | x.a[1] <= UInt(0) + | x.b <= UInt(3) + |""".stripMargin + val expected = + """circuit m : + | module m : + | input c : UInt<1> + | output x_a : UInt<2> + | output x_b : UInt<2> + | x_a <= mux(c, UInt<2>("h3"), UInt<2>("h1")) + | x_b <= UInt<2>("h3") + |""".stripMargin + check(src, expected) + } + + it should "support partial assignments with invalidation for bundles" in { + val src = + """circuit m: + | module m: + | input c : UInt<1> + | output x : { a : UInt<2>, b: UInt<2> } + | x is invalid + | when c: + | x.a[1] <= UInt(1) + |""".stripMargin + val expected = + """circuit m : + | module m : + | input c : UInt<1> + | output x_a : UInt<2> + | output x_b : UInt<2> + | x_a <= UInt<2>(2) + | x_b <= UInt<2>(0) + |""".stripMargin + check(src, expected) + } + + it should "work with wires" in { + val src = + """circuit m: + | module m: + | input c : UInt<1> + | output y : UInt<2> + | wire x : UInt<2> + | y <= x + | x[0] <= UInt(1) + | when c: + | x[1] <= UInt(1) + | else: + | x[1] <= UInt(0) + |""".stripMargin + val expected = + """circuit m : + | module m : + | input c : UInt<1> + | output y : UInt<2> + | y <= mux(c, UInt<2>("h3"), UInt<2>("h1")) + |""".stripMargin + check(src, expected) + } + + it should "work with registers" in { + val src = + """circuit m: + | module m: + | input clock : Clock + | input c : UInt<1> + | output y : UInt<2> + | reg x : UInt<2>, clock + | y <= x + | x[0] <= UInt(1) + | when c: + | x[1] <= UInt(1) + | else: + | x[1] <= UInt(0) + |""".stripMargin + val expected = + """circuit m : + | module m : + | input clock : Clock + | input c : UInt<1> + | output y : UInt<2> + | reg x : UInt<2>, clock with : + | reset => (UInt<1>("h0"), x) + | y <= x + | x <= mux(c, UInt<2>("h3"), UInt<2>("h1")) + |""".stripMargin + check(src, expected) + } + + it should "allow for nested bit indices" in { + val src = + """circuit m: + | module m: + | output x : UInt<2> + | x[1:0][0][0] <= UInt(1) + | x[1][0][0][0] <= UInt(0) + |""".stripMargin + val expected = + """circuit m : + | module m : + | output x : UInt<2> + | x <= UInt<2>("h1") + |""".stripMargin + check(src, expected) + } + + it should "error on out of bounds hi index" in { + val src = + """circuit m: + | module m: + | output x : UInt<2> + | x[2:0] <= UInt(1) + |""".stripMargin + val e = intercept[BitsWidthException] { + compile(src) + } + assert(e.getMessage.contains("High bit 2 in bits operator is larger than input width 2")) + } + + it should "error on out of bounds lo index" in { + val src = + """circuit m: + | module m: + | output x : UInt<2> + | x[1:-1] <= UInt(1) + |""".stripMargin + val e = intercept[firrtl.passes.CheckHighForm.NegArgException] { compile(src) } + assert(e.getMessage.contains("Primop bits argument -1 < 0")) + } + + it should "error on hi < lo indices" in { + val src = + """circuit m: + | module m: + | output x : UInt<2> + | x[0:1] <= UInt(1) + |""".stripMargin + val e = intercept[firrtl.passes.CheckHighForm.LsbLargerThanMsbException] { compile(src) } + assert(e.getMessage.contains("Primop bits lsb 1 > 0")) + } + + it should "error when assigning an input" in { + val src = + """circuit m: + | module m: + | input x : UInt<2> + | x[0] <= UInt(1) + |""".stripMargin + val e = intercept[firrtl.passes.CheckFlows.WrongFlow] { compile(src) } + assert(e.getMessage.contains("is used as a SinkFlow but can only be used as a SourceFlow")) + } + + it should "error when assigning an input inside a bundle" in { + val src = + """circuit m: + | module m: + | output io : { flip x : UInt<2> } + | io.x[0] <= UInt(1) + |""".stripMargin + val e = intercept[firrtl.passes.CheckFlows.WrongFlow] { compile(src) } + assert(e.getMessage.contains("is used as a SinkFlow but can only be used as a SourceFlow")) + } + + // TODO: should we support these kinds of word-level, but not bit-level loops? + it should "allow signals to be connected trough bit slices" ignore { + val src = + """circuit m: + | module m: + | input y : UInt<1> + | output x : UInt<1> + | wire tmp : UInt<4> + | x <= not(tmp[3]) + | tmp[3:1] <= not(tmp[2:0]) + | tmp[0] <= y + |""".stripMargin + val expected = + """circuit m : + | module m : + | input y : UInt<1> + | output x : UInt<1> + | x <= y + |""".stripMargin + check(src, expected) + } +}