Skip to content
This repository has been archived by the owner on Aug 20, 2024. It is now read-only.

implement sub-word assignments #2545

Open
wants to merge 1 commit into
base: master-deprecated
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/main/antlr4/FIRRTL.g4
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ subref
: '.' fieldId subref?
| '.' DoubleLit subref? // TODO Workaround for #470
| '[' (intLit | exp) ']' subref?
| '[' intLit ':' intLit ']' subref?
;

id
Expand Down
10 changes: 10 additions & 0 deletions src/main/proto/firrtl.proto
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,15 @@ message Firrtl {
Expression index = 2;
}

message WSliceNode {
// Required.
Expression expr = 1;
// Required.
IntegerLiteral hi = 2;
// Required.
IntegerLiteral lo = 3;
}

message PrimOp {

enum Op {
Expand Down Expand Up @@ -504,6 +513,7 @@ message Firrtl {
SubIndex sub_index = 8;
SubAccess sub_access = 9;
PrimOp prim_op = 10;
WSliceNode w_slice_node = 12;
}
}
}
26 changes: 19 additions & 7 deletions src/main/scala/firrtl/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -616,17 +616,19 @@ object Utils extends LazyLogging {
case ex => ExpKind
}
def flow(e: Expression): Flow = e match {
case ex: WRef => ex.flow
case ex: WSubField => ex.flow
case ex: WSubIndex => ex.flow
case ex: WSubAccess => ex.flow
case ex: WRef => ex.flow
case ex: WSubField => ex.flow
case ex: WSubIndex => ex.flow
case ex: WSubAccess => ex.flow
// bits can be used as sink (in a sub-word assignment) or as source
case DoPrim(PrimOps.Bits, Seq(inner), _, _) => flow(inner)
case ex: DoPrim => SourceFlow
case ex: UIntLiteral => SourceFlow
case ex: SIntLiteral => SourceFlow
case ex: Mux => SourceFlow
case ex: ValidIf => SourceFlow
case WInvalid => SourceFlow
case ex => throwInternalError(s"flow: shouldn't be here - $e")
case _: WIntInvalid => SourceFlow
case ex => throwInternalError(s"flow: shouldn't be here - $e")
}
def get_flow(s: Statement): Flow = s match {
case sx: DefWire => DuplexFlow
Expand Down Expand Up @@ -892,7 +894,7 @@ object Utils extends LazyLogging {
def and(e1: Expression, e2: Expression): Expression = {
assert(e1.tpe == e2.tpe)
(e1, e2) match {
case (a: UIntLiteral, b: UIntLiteral) => UIntLiteral(a.value | b.value, a.width)
case (a: UIntLiteral, b: UIntLiteral) => UIntLiteral(a.value & b.value, a.width)
case (True(), b) => b
case (a, True()) => a
case (False(), _) => False()
Expand All @@ -902,6 +904,16 @@ object Utils extends LazyLogging {
}
}

/** Applies the firrtl Cat primop. */
def cat(e1: Expression, e2: Expression): Expression =
DoPrim(PrimOps.Cat, Seq(e1, e2), Nil, UIntType(IntWidth(bitWidth(e1.tpe) + bitWidth(e2.tpe))))

/** Applies the firrtl bits primop. */
def bits(e1: Expression, hi: BigInt, lo: BigInt): Expression = {
require(lo >= 0 && hi >= lo)
DoPrim(PrimOps.Bits, Seq(e1), Seq(hi, lo), UIntType(IntWidth(hi - lo + 1)))
}

/** Applies the firrtl Eq primop. */
def eq(e1: Expression, e2: Expression): Expression = DoPrim(PrimOps.Eq, Seq(e1, e2), Nil, BoolType)

Expand Down
16 changes: 11 additions & 5 deletions src/main/scala/firrtl/Visitor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package firrtl

import org.antlr.v4.runtime.ParserRuleContext
import org.antlr.v4.runtime.tree.{AbstractParseTreeVisitor, ParseTreeVisitor, TerminalNode}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.annotation.tailrec
Expand All @@ -13,6 +14,7 @@ import FIRRTLParser._
import Parser.{AppendInfo, GenInfo, IgnoreInfo, InfoMode, UseInfo}
import firrtl.ir._
import Utils.throwInternalError
import firrtl.passes.WSliceNode

class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] with ParseTreeVisitor[FirrtlNode] {
// Strip file path
Expand Down Expand Up @@ -470,12 +472,16 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
}
}
case "[" =>
if (ctx.intLit != null) {
val lit = string2Int(ctx.intLit.getText)
SubIndex(inner, lit, UnknownType)
if (ctx.getChildCount == 5 || ctx.getChildCount == 6) {
WSliceNode(inner, string2Int(ctx.intLit(0).getText), string2Int(ctx.intLit(1).getText))
} else {
val idx = visitExp(ctx.exp)
SubAccess(inner, idx, UnknownType)
assert(ctx.getChildCount == 3 || (ctx.getChildCount == 4 && ctx.subref() != null))
if (ctx.intLit(0) != null) {
val index = string2Int(ctx.intLit(0).getText)
WSliceNode(inner, index, index)
} else { // firrtl expressions are only allowed for Vec sub-accesses
SubAccess(inner, visitExp(ctx.exp), UnknownType)
}
}
}
if (ctx.subref != null) {
Expand Down
29 changes: 27 additions & 2 deletions src/main/scala/firrtl/WIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ object WSubAccess {
)
}

/** A ground type void that carries an explicit bit-width. Used in the ExpandWhens pass. */
case class WIntVoid(width: BigInt) extends Expression with UseSerializer {
override def tpe = UIntType(IntWidth(width))
def mapExpr(f: Expression => Expression): Expression = this
def mapType(f: Type => Type): Expression = this
def mapWidth(f: Width => Width): Expression = this
def foreachExpr(f: Expression => Unit): Unit = ()
def foreachType(f: Type => Unit): Unit = f(tpe)
def foreachWidth(f: Width => Unit): Unit = ()
}

@deprecated("Use WIntVoid instead.", "FIRRTL 1.6")
case object WVoid extends Expression with UseSerializer {
def tpe = UnknownType
def mapExpr(f: Expression => Expression): Expression = this
Expand All @@ -83,6 +95,19 @@ case object WVoid extends Expression with UseSerializer {
def foreachType(f: Type => Unit): Unit = ()
def foreachWidth(f: Width => Unit): Unit = ()
}

/** A ground type invalid that carries an explicit bit-width. Used in the ExpandWhens pass. */
case class WIntInvalid(width: BigInt) extends Expression with UseSerializer {
override def tpe = UIntType(IntWidth(width))
def mapExpr(f: Expression => Expression): Expression = this
def mapType(f: Type => Type): Expression = this
def mapWidth(f: Width => Width): Expression = this
def foreachExpr(f: Expression => Unit): Unit = ()
def foreachType(f: Type => Unit): Unit = f(tpe)
def foreachWidth(f: Width => Unit): Unit = ()
}

@deprecated("Use WIntInvalid instead.", "FIRRTL 1.6")
case object WInvalid extends Expression with UseSerializer {
def tpe = UnknownType
def mapExpr(f: Expression => Expression): Expression = this
Expand Down Expand Up @@ -233,8 +258,8 @@ class WrappedExpression(val e1: Expression) {
case (e1x: WSubField, e2x: WSubField) => (e1x.name.equals(e2x.name)) && weq(e1x.expr, e2x.expr)
case (e1x: WSubIndex, e2x: WSubIndex) => (e1x.value == e2x.value) && weq(e1x.expr, e2x.expr)
case (e1x: WSubAccess, e2x: WSubAccess) => weq(e1x.index, e2x.index) && weq(e1x.expr, e2x.expr)
case (WVoid, WVoid) => true
case (WInvalid, WInvalid) => true
case (_: WIntVoid, _: WIntVoid) => true
case (_: WIntInvalid, _: WIntInvalid) => true
case (e1x: DoPrim, e2x: DoPrim) =>
e1x.op == e2x.op &&
((e1x.consts.zip(e2x.consts)).forall { case (x, y) => x == y }) &&
Expand Down
24 changes: 12 additions & 12 deletions src/main/scala/firrtl/analyses/ConnectionGraph.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import firrtl.annotations.{TargetToken, _}
import firrtl.graph.{CyclicException, DiGraph, MutableDiGraph}
import firrtl.ir._
import firrtl.passes.MemPortUtils
import firrtl.{InstanceKind, PortKind, SinkFlow, SourceFlow, Utils, WInvalid}
import firrtl.{InstanceKind, PortKind, SinkFlow, SourceFlow, Utils, WIntInvalid}

import scala.collection.mutable

Expand Down Expand Up @@ -338,16 +338,16 @@ object ConnectionGraph {
* @return
*/
def asTarget(m: ModuleTarget, tagger: TokenTagger)(e: FirrtlNode): ReferenceTarget = e match {
case l: Literal => m.ref(tagger.getRef(l.value.toString))
case r: Reference => m.ref(r.name)
case s: SubIndex => asTarget(m, tagger)(s.expr).index(s.value)
case s: SubField => asTarget(m, tagger)(s.expr).field(s.name)
case d: DoPrim => m.ref(tagger.getRef(d.op.serialize))
case _: Mux => m.ref(tagger.getRef("mux"))
case _: ValidIf => m.ref(tagger.getRef("validif"))
case WInvalid => m.ref(tagger.getRef("invalid"))
case _: Print => m.ref(tagger.getRef("print"))
case _: Stop => m.ref(tagger.getRef("print"))
case l: Literal => m.ref(tagger.getRef(l.value.toString))
case r: Reference => m.ref(r.name)
case s: SubIndex => asTarget(m, tagger)(s.expr).index(s.value)
case s: SubField => asTarget(m, tagger)(s.expr).field(s.name)
case d: DoPrim => m.ref(tagger.getRef(d.op.serialize))
case _: Mux => m.ref(tagger.getRef("mux"))
case _: ValidIf => m.ref(tagger.getRef("validif"))
case _: WIntInvalid => m.ref(tagger.getRef("invalid"))
case _: Print => m.ref(tagger.getRef("print"))
case _: Stop => m.ref(tagger.getRef("print"))
case other => sys.error(s"Unsupported: $other")
}

Expand Down Expand Up @@ -508,7 +508,7 @@ object ConnectionGraph {
buildExpression(m, tagger, sinkTarget)(c.expr)

case i: IsInvalid =>
val sourceTarget = asTarget(m, tagger)(WInvalid)
val sourceTarget = asTarget(m, tagger)(WIntInvalid(firrtl.bitWidth(i.expr.tpe)))
addLabeledVertex(sourceTarget, stmt)
mdg.addVertex(sourceTarget)
val sinkTarget = asTarget(m, tagger)(i.expr)
Expand Down
6 changes: 3 additions & 3 deletions src/main/scala/firrtl/analyses/IRLookup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import firrtl.{
SourceFlow,
UnknownFlow,
Utils,
WInvalid,
WIntInvalid,
WireKind
}

Expand Down Expand Up @@ -124,10 +124,10 @@ class IRLookup private[analyses] (
case other =>
sys.error(s"Cannot call expr with: $t, given declaration $other")
}
case _: IsInvalid =>
case i: IsInvalid =>
exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())(
(pathless, SourceFlow)
) = WInvalid
) = WIntInvalid(firrtl.bitWidth(i.expr.tpe))
}
}
}
Expand Down
10 changes: 6 additions & 4 deletions src/main/scala/firrtl/ir/Serializer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,12 @@ object Serializer {
b ++= "Fixed"; s(width); sPoint(point)
b ++= "(\"h"; b ++= value.toString(16); b ++= "\")"
// WIR
case firrtl.WVoid => b ++= "VOID"
case firrtl.WInvalid => b ++= "INVALID"
case firrtl.EmptyExpression => b ++= "EMPTY"
case other => b ++= other.serialize // Handle user-defined nodes
case firrtl.WVoid => b ++= "VOID"
case firrtl.WInvalid => b ++= "INVALID"
case firrtl.WIntVoid(width) => b ++= "VOID<"; b ++= width.toString(); b += '>'
case firrtl.WIntInvalid(width) => b ++= "INVALID<"; b ++= width.toString(); b += '>'
case firrtl.EmptyExpression => b ++= "EMPTY"
case other => b ++= other.serialize // Handle user-defined nodes
}

private def s(node: Statement)(implicit b: StringBuilder, indent: Int): Unit = node match {
Expand Down
6 changes: 4 additions & 2 deletions src/main/scala/firrtl/ir/StructuralHash.scala
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,10 @@ class StructuralHash private (h: Hasher, renameModule: String => String) {
case firrtl.WInvalid => id(11)
case firrtl.EmptyExpression => id(12)
// VRandom is used in the Emitter
case firrtl.VRandom(width) => id(13); hash(width)
// ids 14 ... 19 are reserved for future Expression nodes
case firrtl.VRandom(width) => id(13); hash(width)
case firrtl.WIntVoid(width) => id(14); hash(width)
case firrtl.WIntInvalid(width) => id(15); hash(width)
// ids 16 ... 19 are reserved for future Expression nodes
}

private def hash(node: Statement): Unit = node match {
Expand Down
16 changes: 11 additions & 5 deletions src/main/scala/firrtl/passes/CheckFlows.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,21 @@ object CheckFlows extends Pass {
def run(c: Circuit): Circuit = {
val errors = new Errors()

def get_flow(e: Expression, flows: FlowMap): Flow = e match {
def get_flow(e: Expression, flows: FlowMap, desired: Flow): Flow = e match {
case (e: WRef) => flows(e.name)
case (e: WSubIndex) => get_flow(e.expr, flows)
case (e: WSubAccess) => get_flow(e.expr, flows)
case (e: WSubIndex) => get_flow(e.expr, flows, desired)
case (e: WSubAccess) => get_flow(e.expr, flows, desired)
case (e: WSubField) =>
e.expr.tpe match {
case t: BundleType =>
val f = (t.fields.find(_.name == e.name)).get
times(get_flow(e.expr, flows), f.flip)
times(get_flow(e.expr, flows, desired), f.flip)
}
case DoPrim(PrimOps.Bits, Seq(expr), _, _) => // bits can be used as sink (sub-word assignments) and source flow
// when bits are used on the rhs, then they are treated like any other primop
if (desired == SourceFlow) { SourceFlow }
else {
get_flow(expr, flows, desired)
}
case _ => SourceFlow
}
Expand All @@ -62,7 +68,7 @@ object CheckFlows extends Pass {
}

def check_flow(info: Info, mname: String, flows: FlowMap, desired: Flow)(e: Expression): Unit = {
val flow = get_flow(e, flows)
val flow = get_flow(e, flows, desired)
(flow, desired) match {
case (SourceFlow, SinkFlow) =>
errors.append(new WrongFlow(info, mname, e.serialize, desired, flow))
Expand Down
7 changes: 5 additions & 2 deletions src/main/scala/firrtl/passes/CheckHighForm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ trait CheckHighFormLike { this: Pass =>
}

def checkValidLoc(info: Info, mname: String, e: Expression): Unit = e match {
case DoPrim(PrimOps.Bits, Seq(expr), _, _) => // bit slices are allowed as locs for subword assignments
checkValidLoc(info, mname, expr)
case _: UIntLiteral | _: SIntLiteral | _: DoPrim =>
errors.append(new InvalidLOCException(info, mname))
case _ => // Do Nothing
Expand All @@ -232,7 +234,7 @@ trait CheckHighFormLike { this: Pass =>
def validSubexp(info: Info, mname: String)(e: Expression): Unit = {
e match {
case _: Reference | _: SubField | _: SubIndex | _: SubAccess => // No error
case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf => // No error
case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf | _: WSliceNode => // No error
case _ => errors.append(new InvalidAccessException(info, mname))
}
}
Expand All @@ -247,7 +249,8 @@ trait CheckHighFormLike { this: Pass =>
errors.append(new NegUIntException(info, mname))
case ex: DoPrim => checkHighFormPrimop(info, mname, ex)
case _: Reference | _: WRef | _: UIntLiteral | _: Mux | _: ValidIf =>
case ex: SubAccess => validSubexp(info, mname)(ex.expr)
case ex: SubAccess => validSubexp(info, mname)(ex.expr)
case ex: WSliceNode => validSubexp(info, mname)(ex.expr)
case ex => ex.foreach(validSubexp(info, mname))
}
e.foreach(checkHighFormW(info, mname + "/" + e.serialize))
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/firrtl/passes/CheckInitialization.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import annotation.tailrec

/** Reports errors for any references that are not fully initialized
*
* @note This pass looks for [[firrtl.WVoid]]s left behind by [[ExpandWhens]]
* @note This pass looks for [[firrtl.WIntVoid]]s left behind by [[ExpandWhens]]
* @note Assumes single connection (ie. no last connect semantics)
*/
object CheckInitialization extends Pass {
Expand Down Expand Up @@ -48,7 +48,7 @@ object CheckInitialization extends Pass {
var void = false
val voidDeps = collection.mutable.ArrayBuffer[Expression]()
def hasVoid(e: Expression): Unit = e match {
case WVoid =>
case _: WIntVoid =>
void = true
case (_: WRef | _: WSubField) =>
if (voidExprs.contains(e)) {
Expand Down
Loading