diff --git a/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td b/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td index 839961e3e70a..edd79267e575 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td +++ b/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td @@ -160,6 +160,128 @@ def InstanceOp : HardwareDeclOp<"instance", [ let hasVerifier = true; } +def StrictInstanceOp : HardwareDeclOp<"strictinstance", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + ]> { + let summary = "Instantiate an instance of a module"; + let description = [{ + This represents an instance of a module. The results are the modules inputs + and outputs. The inputs have LHSType wrapped types, the outputs do not. + + Examples: + ```mlir + %0 = firrtl.strictinstance foo @Foo(in io: !firrtl.uint) + firrtl.strictconnect %0, %c0_ui + ``` + }]; + + let arguments = (ins FlatSymbolRefAttr:$moduleName, StrAttr:$name, NameKindAttr:$nameKind, + DenseBoolArrayAttr:$portDirections, StrArrayAttr:$portNames, + AnnotationArrayAttr:$annotations, + PortAnnotationsAttr:$portAnnotations, + LayerArrayAttr:$layers, + UnitAttr:$lowerToBind, + OptionalAttr:$inner_sym); + + let results = (outs Variadic:$results); + + let hasCustomAssemblyFormat = 1; + + let builders = [ + OpBuilder<(ins "::mlir::TypeRange":$resultTypes, + "::mlir::StringRef":$moduleName, + "::mlir::StringRef":$name, + "::circt::firrtl::NameKindEnum":$nameKind, + "::mlir::ArrayRef":$portDirections, + "::mlir::ArrayRef":$portNames, + CArg<"ArrayRef", "{}">:$annotations, + CArg<"ArrayRef", "{}">:$portAnnotations, + CArg<"::mlir::ArrayRef", "{}">:$layers, + CArg<"bool","false">:$lowerToBind, + CArg<"StringAttr", "StringAttr()">:$innerSym)>, + OpBuilder<(ins "::mlir::TypeRange":$resultTypes, + "::mlir::StringRef":$moduleName, + "::mlir::StringRef":$name, + "::circt::firrtl::NameKindEnum":$nameKind, + "::mlir::ArrayRef":$portDirections, + "::mlir::ArrayRef":$portNames, + "ArrayRef":$annotations, + "ArrayRef":$portAnnotations, + "::mlir::ArrayRef":$layers, + "bool":$lowerToBind, + "hw::InnerSymAttr":$innerSym)>, + + /// Constructor when you have the target module in hand. + OpBuilder<(ins "FModuleLike":$module, + "mlir::StringRef":$name, + CArg<"NameKindEnum", "NameKindEnum::DroppableName">:$nameKind, + CArg<"ArrayRef", "{}">:$annotations, + CArg<"ArrayRef", "{}">:$portAnnotations, + CArg<"bool","false">:$lowerToBind, + CArg<"hw::InnerSymAttr", "hw::InnerSymAttr()">:$innerSym)>, + + /// Constructor when you have a port info list in hand. + OpBuilder<(ins "ArrayRef":$ports, + "::mlir::StringRef":$moduleName, + "mlir::StringRef":$name, + CArg<"NameKindEnum", "NameKindEnum::DroppableName">:$nameKind, + CArg<"ArrayRef", "{}">:$annotations, + CArg<"ArrayRef", "{}">:$layers, + CArg<"bool","false">:$lowerToBind, + CArg<"hw::InnerSymAttr", "hw::InnerSymAttr()">:$innerSym)> + ]; + + let extraClassDeclaration = [{ + /// Return the port direction for the specified result number. + Direction getPortDirection(size_t resultNo) { + return direction::get(getPortDirections()[resultNo]); + } + + /// Return the port name for the specified result number. + StringAttr getPortName(size_t resultNo) { + return cast(getPortNames()[resultNo]); + } + StringRef getPortNameStr(size_t resultNo) { + return getPortName(resultNo).getValue(); + } + + /// Hooks for port annotations. + ArrayAttr getPortAnnotation(unsigned portIdx); + void setAllPortAnnotations(ArrayRef annotations); + + /// Builds a new `StrictInstanceOp` with the ports listed in `portIndices` erased, + /// and updates any users of the remaining ports to point at the new + /// instance. + StrictInstanceOp erasePorts(OpBuilder &builder, const llvm::BitVector &portIndices); + + /// Clone the instance op and add ports. This is usually used in + /// conjuction with adding ports to the referenced module. This will emit + /// the new StrictInstanceOp to the same location. + StrictInstanceOp cloneAndInsertPorts(ArrayRef> ports); + + //===------------------------------------------------------------------===// + // Instance graph methods + //===------------------------------------------------------------------===// + + // Quick lookup of the referenced module using the instance graph. + template + T getReferencedModule(::circt::igraph::InstanceGraph &instanceGraph) { + auto moduleNameAttr = getModuleNameAttr().getAttr(); + auto *node = instanceGraph.lookup(moduleNameAttr); + if (!node) + return nullptr; + Operation *moduleOp = node->getModule(); + return dyn_cast_or_null(moduleOp); + } + }]; + + let hasVerifier = true; +} + def InstanceChoiceOp : HardwareDeclOp<"instance_choice", [ DeclareOpInterfaceMethods ]> { @@ -579,6 +701,41 @@ def RegResetOp : HardwareDeclOp<"regreset", [Forceable, CombDataflow]> { let hasVerifier = 1; } +def StrictRegOp : HardwareDeclOp<"strictreg", [Forceable, SameVariadicOperandSize]> { + let summary = "Define a new register with an optional reset"; + let description = [{ + Declare a new register: + ``` + %name, %name_write = firrtl.regreset %clockVal (reset %resetSignal, %resetValue)? : t1 (, rt)? + ``` + }]; + + let arguments = ( + ins ClockType:$clockVal, + Optional:$resetSignal, + Optional:$resetValue, + StrAttr:$name, NameKindAttr:$nameKind, + AnnotationArrayAttr:$annotations, + OptionalAttr:$inner_sym, + UnitAttr:$forceable); + let results = (outs AnyRegisterType:$read, LHSType:$write, Optional:$ref); + + let assemblyFormat = [{ + (`sym` $inner_sym^)? `` custom($nameKind) + $clockVal (`,` `reset` `(` $resetSignal^ `:` type($resetSignal) `,` $resetValue `:` type($resetValue) `)` )? (`,` `forceable` $forceable^)? `` custom(attr-dict) + `:` type($clockVal) `,` custom(type($write), type($read)) (`,` qualified(type($ref))^)? + + }]; + +// let hasCanonicalizer = true; +// let hasVerifier = 1; + + let extraClassDeclaration = [{ + Value getResult(); + }]; +} + + def WireOp : HardwareDeclOp<"wire", [ Forceable, DeclareOpInterfaceMethods @@ -670,6 +827,41 @@ def WireOp : HardwareDeclOp<"wire", [ }]; } +def StrictWireOp : HardwareDeclOp<"strictwire", [ + Forceable, + LHSTypeConstraint<"read", "write">, + OptRefTypeConstraint<"read", "ref">, + DeclareOpInterfaceMethods +]> { + let summary = "Define a new wire"; + let description = [{ + Declare a new wire: + ``` + %name,%name.wp,%name.ref = firrtl.strictwire forceable : t1 + ``` + }]; + + let arguments = (ins StrAttr:$name, NameKindAttr:$nameKind, + AnnotationArrayAttr:$annotations, + OptionalAttr:$inner_sym, + UnitAttr:$forceable); // ReferenceKinds + let results = (outs PassiveType:$read, + LHSType:$write, + Optional:$ref); + +// let hasCanonicalizer = true; + + let assemblyFormat = [{ + (`sym` $inner_sym^)? `` custom($nameKind) + (`forceable` $forceable^)? `` custom(attr-dict) `:` + custom(type($write), type($read)) (`,` qualified(type($ref))^)? + }]; + + let extraClassDeclaration = [{ + Value getResult(); + }]; +} + //===----------------------------------------------------------------------===// // Property Ops //===----------------------------------------------------------------------===// diff --git a/include/circt/Dialect/FIRRTL/FIRRTLExpressions.td b/include/circt/Dialect/FIRRTL/FIRRTLExpressions.td index f435e86ef7ce..68ac55d344e4 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLExpressions.td +++ b/include/circt/Dialect/FIRRTL/FIRRTLExpressions.td @@ -271,7 +271,7 @@ def InvalidValueOp : FIRRTLOp<"invalidvalue", let assemblyFormat = "attr-dict `:` qualified(type($result))"; } -class BaseSubfieldOp : FIRRTLExprOp { +class BaseSubfieldOp : FIRRTLExprOp { let summary = "Extract a subfield of another value"; let description = [{ The subfield expression refers to a subelement of an expression with a @@ -281,14 +281,15 @@ class BaseSubfieldOp : FIRRTLExprOp { ``` }]; - let arguments = (ins btype:$input, I32Attr:$fieldIndex); + let arguments = (ins itype:$input, I32Attr:$fieldIndex); let results = (outs rtype:$result); let hasVerifier = 1; let hasCustomAssemblyFormat = 1; let builders = [ OpBuilder<(ins "Value":$input, "StringRef":$fieldName), [{ - auto bundleType = firrtl::type_cast<}] # btype # [{>(input.getType()); + auto ttype = isa(input.getType()) ? cast(input.getType()).getType() : input.getType(); + auto bundleType = firrtl::type_cast<}] # btype # [{>(ttype); auto fieldIndex = bundleType.getElementIndex(fieldName); assert(fieldIndex.has_value() && "subfield operation to unknown field"); return build($_builder, $_state, input, *fieldIndex); @@ -296,29 +297,39 @@ class BaseSubfieldOp : FIRRTLExprOp { ]; let firrtlExtraClassDeclaration = [{ - using InputType = }] # btype # [{; + using InputType = }] # itype # [{; + using InputBundleType = }] # btype # [{; + + FIRRTLType stripLHS(FIRRTLType type) { + if (auto t = dyn_cast(type)) + return t.getType(); + return type; + } /// Return true if the specified field is flipped. bool isFieldFlipped(); /// Return a `FieldRef` to the accessed field. FieldRef getAccessedField() { - return FieldRef(getInput(), firrtl::type_cast(getInput().getType()) + auto innerType = stripLHS(getInput().getType()); + return FieldRef(getInput(), firrtl::type_cast(innerType) .getFieldID(getFieldIndex())); } /// Return the name of the accessed field. StringRef getFieldName() { - return firrtl::type_cast(getInput().getType()).getElementName(getFieldIndex()); + auto innerType = stripLHS(getInput().getType()); + return firrtl::type_cast(innerType).getElementName(getFieldIndex()); } }]; } -def SubfieldOp : BaseSubfieldOp<"subfield", BundleType, FIRRTLBaseType> { +def SubfieldOp : BaseSubfieldOp<"subfield", BundleType, BundleType, FIRRTLBaseType> { let hasFolder = 1; let hasCanonicalizer = 1; } -def OpenSubfieldOp : BaseSubfieldOp<"opensubfield", OpenBundleType, FIRRTLType>; +def OpenSubfieldOp : BaseSubfieldOp<"opensubfield", OpenBundleType, OpenBundleType, FIRRTLType>; +def LHSSubfieldOp : BaseSubfieldOp<"lhssubfield", BundleType, LHSType, LHSType>; def SubindexOp : FIRRTLExprOp<"subindex"> { let summary = "Extract an element of a vector value"; @@ -377,6 +388,35 @@ def OpenSubindexOp : FIRRTLExprOp<"opensubindex"> { }]; } +def LHSSubindexOp : FIRRTLExprOp<"lhssubindex"> { + let summary = "Reference an element of a vector value"; + let description = [{ + The subindex expression statically refers, by index, to a subelement + of an expression with a vector type. The index must be a non-negative + integer and cannot be equal to or exceed the length of the vector it + indexes. + ``` + %result = firrtl.lhssubindex %input[index] : lhs + ``` + }]; + + let arguments = (ins LHSType:$input, I32Attr:$index); + let results = (outs LHSType:$result); + + let assemblyFormat = + "$input `[` $index `]` attr-dict `:` qualified(type($input))"; + + let firrtlExtraClassDeclaration = [{ + /// Return a `FieldRef` to the accessed field. + FieldRef getAccessedField() { + auto btype = cast(getInput().getType().getType()); + return FieldRef(getInput(), btype.getFieldID(getIndex())); + } + + using InputType = FVectorType; + }]; +} + def SubaccessOp : FIRRTLExprOp<"subaccess"> { let summary = "Extract a dynamic element of a vector value"; let description = [{ diff --git a/include/circt/Dialect/FIRRTL/FIRRTLStatements.td b/include/circt/Dialect/FIRRTL/FIRRTLStatements.td index 649801ad0328..ae0d33344b02 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLStatements.td +++ b/include/circt/Dialect/FIRRTL/FIRRTLStatements.td @@ -77,6 +77,27 @@ def MatchingConnectOp : FIRRTLOp<"matchingconnect", [FConnectLike, custom(type($dest), type($src))}]; } +def SameAnonLHSTypeOperands: PredOpTrait< + "operands must be structurally equivalent with LHS wrapper", + CPred<"areAnonymousLHSandTypeEquivalent(getDest().getType(), getSrc().getType())">>; +def StrictConnectOp : FIRRTLOp<"strictconnect", [FConnectLike, + SameAnonLHSTypeOperands]> { + let summary = "Connect two signals"; + let description = [{ + Connect two values with strict constraints: + ``` + firrtl.strictconnect %dest, %src : t1 + firrtl.strictconnect %dest, %src : t1, !firrtl.alias + ``` + }]; + + let arguments = (ins LHSType:$dest, + SizedPassiveType:$src); + let results = (outs); + + let assemblyFormat = [{$dest `,` $src attr-dict `:` custom(type($dest), type($src))}]; +} + def RefDefineOp : FIRRTLOp<"ref.define", [SameTypeOperands, FConnectLike]> { let summary = "FIRRTL Define References"; let description = [{ diff --git a/include/circt/Dialect/FIRRTL/FIRRTLTypes.h b/include/circt/Dialect/FIRRTL/FIRRTLTypes.h index 71360161163a..a86db842b447 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLTypes.h +++ b/include/circt/Dialect/FIRRTL/FIRRTLTypes.h @@ -248,6 +248,7 @@ bool isTypeLarger(FIRRTLBaseType dstType, FIRRTLBaseType srcType); /// comparison. bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs); bool areAnonymousTypesEquivalent(mlir::Type lhs, mlir::Type rhs); +bool areAnonymousLHSandTypeEquivalent(LHSType lhs, FIRRTLBaseType rhs); mlir::Type getPassiveType(mlir::Type anyBaseFIRRTLType); diff --git a/include/circt/Dialect/FIRRTL/FIRRTLTypes.td b/include/circt/Dialect/FIRRTL/FIRRTLTypes.td index 3051928e458c..c4cc0d6135f2 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLTypes.td +++ b/include/circt/Dialect/FIRRTL/FIRRTLTypes.td @@ -108,6 +108,16 @@ def RWProbe : FIRRTLDialectType< def LHSType : FIRRTLDialectType($_self)">, "writable type", "::circt::firrtl::LHSType">; +class LHSTypeConstraint + : TypesMatchWith<"Result must be lhs of input type", + input, result, + "type_cast($_self).getType()">; + +class OptRefTypeConstraint + : OptionalTypesMatchWith<"Result must be ref of input type", + input, result, + "type_cast($_self).getType()">; + def ConnectableType : AnyTypeOf<[FIRRTLBaseType, ForeignType]>; def MatchingConnectableType : AnyTypeOf<[SizedPassiveType, ForeignType]>; diff --git a/include/circt/Dialect/FIRRTL/FIRRTLVisitors.h b/include/circt/Dialect/FIRRTL/FIRRTLVisitors.h index 03a2466d2b52..5372b0c1c955 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLVisitors.h +++ b/include/circt/Dialect/FIRRTL/FIRRTLVisitors.h @@ -270,6 +270,7 @@ class StmtVisitor { HANDLE(AttachOp); HANDLE(ConnectOp); HANDLE(MatchingConnectOp); + HANDLE(StrictConnectOp); HANDLE(RefDefineOp); HANDLE(ForceOp); HANDLE(PrintFOp); @@ -304,10 +305,10 @@ class DeclVisitor { auto *thisCast = static_cast(this); return TypeSwitch(op) .template Case( - [&](auto opNode) -> ResultType { - return thisCast->visitDecl(opNode, args...); - }) + RegOp, RegResetOp, StrictRegOp, WireOp, StrictWireOp, + VerbatimWireOp>([&](auto opNode) -> ResultType { + return thisCast->visitDecl(opNode, args...); + }) .Default([&](auto expr) -> ResultType { return thisCast->visitInvalidDecl(op, args...); }); @@ -337,7 +338,9 @@ class DeclVisitor { HANDLE(NodeOp); HANDLE(RegOp); HANDLE(RegResetOp); + HANDLE(StrictRegOp); HANDLE(WireOp); + HANDLE(StrictWireOp); HANDLE(VerbatimWireOp); #undef HANDLE }; diff --git a/lib/Dialect/FIRRTL/FIRRTLInstanceImplementation.cpp b/lib/Dialect/FIRRTL/FIRRTLInstanceImplementation.cpp index 7403c72b81b0..e6f9aa9d6416 100644 --- a/lib/Dialect/FIRRTL/FIRRTLInstanceImplementation.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLInstanceImplementation.cpp @@ -99,6 +99,9 @@ instance_like_impl::verifyReferencedModule(Operation *instanceOp, for (size_t i = 0; i != numResults; i++) { auto resultType = instanceOp->getResult(i).getType(); auto expectedType = referencedModule.getPortType(i); + if (direction::get(portDirections[i]) == Direction::In && + isa(resultType)) + resultType = cast(resultType).getType(); if (resultType != expectedType) { return emitNote(instanceOp->emitOpError() << "result type for " << portNames[i] << " must be " diff --git a/lib/Dialect/FIRRTL/FIRRTLOps.cpp b/lib/Dialect/FIRRTL/FIRRTLOps.cpp index aa5adadc383b..1c8fca8ffc0d 100644 --- a/lib/Dialect/FIRRTL/FIRRTLOps.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLOps.cpp @@ -2465,6 +2465,333 @@ std::optional InstanceOp::getTargetResultIndex() { return std::nullopt; } +//===----------------------------------------------------------------------===// +// StrictInstanceOp +//===----------------------------------------------------------------------===// + +void StrictInstanceOp::build( + OpBuilder &builder, OperationState &result, TypeRange resultTypes, + StringRef moduleName, StringRef name, NameKindEnum nameKind, + ArrayRef portDirections, ArrayRef portNames, + ArrayRef annotations, ArrayRef portAnnotations, + ArrayRef layers, bool lowerToBind, hw::InnerSymAttr innerSym) { + SmallVector actualTypes; + for (auto [t, d] : llvm::zip(resultTypes, portDirections)) + if (d == Direction::In && isa(t)) + actualTypes.push_back( + LHSType::get(t.getContext(), cast(t))); + else + actualTypes.push_back(t); + result.addTypes(actualTypes); + result.addAttribute("moduleName", + SymbolRefAttr::get(builder.getContext(), moduleName)); + result.addAttribute("name", builder.getStringAttr(name)); + result.addAttribute( + "portDirections", + direction::packAttribute(builder.getContext(), portDirections)); + result.addAttribute("portNames", builder.getArrayAttr(portNames)); + result.addAttribute("annotations", builder.getArrayAttr(annotations)); + result.addAttribute("layers", builder.getArrayAttr(layers)); + if (lowerToBind) + result.addAttribute("lowerToBind", builder.getUnitAttr()); + if (innerSym) + result.addAttribute("inner_sym", innerSym); + result.addAttribute("nameKind", + NameKindEnumAttr::get(builder.getContext(), nameKind)); + + if (portAnnotations.empty()) { + SmallVector portAnnotationsVec(resultTypes.size(), + builder.getArrayAttr({})); + result.addAttribute("portAnnotations", + builder.getArrayAttr(portAnnotationsVec)); + } else { + assert(portAnnotations.size() == resultTypes.size()); + result.addAttribute("portAnnotations", + builder.getArrayAttr(portAnnotations)); + } +} + +void StrictInstanceOp::build(OpBuilder &builder, OperationState &odsState, + ArrayRef ports, StringRef moduleName, + StringRef name, NameKindEnum nameKind, + ArrayRef annotations, + ArrayRef layers, bool lowerToBind, + hw::InnerSymAttr innerSym) { + // Gather the result types. + SmallVector newResultTypes; + SmallVector newPortDirections; + SmallVector newPortNames; + SmallVector newPortAnnotations; + for (auto &p : ports) { + newResultTypes.push_back(p.type); + newPortDirections.push_back(p.direction); + newPortNames.push_back(p.name); + newPortAnnotations.push_back(p.annotations.getArrayAttr()); + } + + return build(builder, odsState, newResultTypes, moduleName, name, nameKind, + newPortDirections, newPortNames, annotations, newPortAnnotations, + layers, lowerToBind, innerSym); +} + +LogicalResult StrictInstanceOp::verify() { + // The instance may only be instantiated under its required layers. + auto ambientLayers = getAmbientLayersAt(getOperation()); + SmallVector missingLayers; + for (auto layer : getLayersAttr().getAsRange()) + if (!isLayerCompatibleWith(layer, ambientLayers)) + missingLayers.push_back(layer); + + if (missingLayers.empty()) + return success(); + + auto diag = + emitOpError("ambient layers are insufficient to instantiate module"); + auto ¬e = diag.attachNote(); + note << "missing layer requirements: "; + interleaveComma(missingLayers, note); + return failure(); +} + +/// Builds a new `StrictInstanceOp` with the ports listed in `portIndices` +/// erased, and updates any users of the remaining ports to point at the new +/// instance. +StrictInstanceOp +StrictInstanceOp::erasePorts(OpBuilder &builder, + const llvm::BitVector &portIndices) { + assert(portIndices.size() >= getNumResults() && + "portIndices is not at least as large as getNumResults()"); + + if (portIndices.none()) + return *this; + + SmallVector newResultTypes = removeElementsAtIndices( + SmallVector(result_type_begin(), result_type_end()), portIndices); + SmallVector newPortDirections = removeElementsAtIndices( + direction::unpackAttribute(getPortDirectionsAttr()), portIndices); + SmallVector newPortNames = + removeElementsAtIndices(getPortNames().getValue(), portIndices); + SmallVector newPortAnnotations = + removeElementsAtIndices(getPortAnnotations().getValue(), portIndices); + + auto newOp = builder.create( + getLoc(), newResultTypes, getModuleName(), getName(), getNameKind(), + newPortDirections, newPortNames, getAnnotations().getValue(), + newPortAnnotations, getLayers(), getLowerToBind(), getInnerSymAttr()); + + for (unsigned oldIdx = 0, newIdx = 0, numOldPorts = getNumResults(); + oldIdx != numOldPorts; ++oldIdx) { + if (portIndices.test(oldIdx)) { + assert(getResult(oldIdx).use_empty() && "removed instance port has uses"); + continue; + } + getResult(oldIdx).replaceAllUsesWith(newOp.getResult(newIdx)); + ++newIdx; + } + + // Compy over "output_file" information so that this is not lost when ports + // are erased. + // + // TODO: Other attributes may need to be copied over. + if (auto outputFile = (*this)->getAttr("output_file")) + newOp->setAttr("output_file", outputFile); + + return newOp; +} + +ArrayAttr StrictInstanceOp::getPortAnnotation(unsigned portIdx) { + assert(portIdx < getNumResults() && + "index should be smaller than result number"); + return cast(getPortAnnotations()[portIdx]); +} + +void StrictInstanceOp::setAllPortAnnotations(ArrayRef annotations) { + assert(annotations.size() == getNumResults() && + "number of annotations is not equal to result number"); + (*this)->setAttr("portAnnotations", + ArrayAttr::get(getContext(), annotations)); +} + +StrictInstanceOp StrictInstanceOp::cloneAndInsertPorts( + ArrayRef> ports) { + auto portSize = ports.size(); + auto newPortCount = getNumResults() + portSize; + SmallVector newPortDirections; + newPortDirections.reserve(newPortCount); + SmallVector newPortNames; + newPortNames.reserve(newPortCount); + SmallVector newPortTypes; + newPortTypes.reserve(newPortCount); + SmallVector newPortAnnos; + newPortAnnos.reserve(newPortCount); + + unsigned oldIndex = 0; + unsigned newIndex = 0; + while (oldIndex + newIndex < newPortCount) { + // Check if we should insert a port here. + if (newIndex < portSize && ports[newIndex].first == oldIndex) { + auto &newPort = ports[newIndex].second; + newPortDirections.push_back(newPort.direction); + newPortNames.push_back(newPort.name); + newPortTypes.push_back(newPort.type); + newPortAnnos.push_back(newPort.annotations.getArrayAttr()); + ++newIndex; + } else { + // Copy the next old port. + newPortDirections.push_back(getPortDirection(oldIndex)); + newPortNames.push_back(getPortName(oldIndex)); + newPortTypes.push_back(getType(oldIndex)); + newPortAnnos.push_back(getPortAnnotation(oldIndex)); + ++oldIndex; + } + } + + // Create a new instance op with the reset inserted. + return OpBuilder(*this).create( + getLoc(), newPortTypes, getModuleName(), getName(), getNameKind(), + newPortDirections, newPortNames, getAnnotations().getValue(), + newPortAnnos, getLayers(), getLowerToBind(), getInnerSymAttr()); +} + +LogicalResult +StrictInstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + return instance_like_impl::verifyReferencedModule(*this, symbolTable, + getModuleNameAttr()); +} + +StringRef StrictInstanceOp::getInstanceName() { return getName(); } + +StringAttr StrictInstanceOp::getInstanceNameAttr() { return getNameAttr(); } + +void StrictInstanceOp::print(OpAsmPrinter &p) { + // Print the instance name. + p << " "; + p.printKeywordOrString(getName()); + if (auto attr = getInnerSymAttr()) { + p << " sym "; + p.printSymbolName(attr.getSymName()); + } + if (getNameKindAttr().getValue() != NameKindEnum::DroppableName) + p << ' ' << stringifyNameKindEnum(getNameKindAttr().getValue()); + + // Print the attr-dict. + SmallVector omittedAttrs = { + "moduleName", "name", "portDirections", + "portNames", "portTypes", "portAnnotations", + "inner_sym", "nameKind"}; + if (getAnnotations().empty()) + omittedAttrs.push_back("annotations"); + if (getLayers().empty()) + omittedAttrs.push_back("layers"); + p.printOptionalAttrDict((*this)->getAttrs(), omittedAttrs); + + // Print the module name. + p << " "; + p.printSymbolName(getModuleName()); + + // Collect all the result types as TypeAttrs for printing. + SmallVector portTypes; + portTypes.reserve(getNumResults()); + for (auto p : getResultTypes()) + if (auto t = dyn_cast(p)) + portTypes.push_back(TypeAttr::get(t.getType())); + else + portTypes.push_back(TypeAttr::get(p)); + + printModulePorts(p, /*block=*/nullptr, getPortDirectionsAttr(), + getPortNames().getValue(), portTypes, + getPortAnnotations().getValue(), {}, {}); +} + +ParseResult StrictInstanceOp::parse(OpAsmParser &parser, + OperationState &result) { + auto *context = parser.getContext(); + auto &resultAttrs = result.attributes; + + std::string name; + hw::InnerSymAttr innerSymAttr; + FlatSymbolRefAttr moduleName; + SmallVector entryArgs; + SmallVector portDirections; + SmallVector portNames; + SmallVector portTypes; + SmallVector portAnnotations; + SmallVector portSyms; + SmallVector portLocs; + NameKindEnumAttr nameKind; + + if (parser.parseKeywordOrString(&name)) + return failure(); + if (succeeded(parser.parseOptionalKeyword("sym"))) { + if (parser.parseCustomAttributeWithFallback( + innerSymAttr, ::mlir::Type{}, + hw::InnerSymbolTable::getInnerSymbolAttrName(), + result.attributes)) { + return ::mlir::failure(); + } + } + if (parseNameKind(parser, nameKind) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseAttribute(moduleName, "moduleName", resultAttrs) || + parseModulePorts(parser, /*hasSSAIdentifiers=*/false, + /*supportsSymbols=*/false, entryArgs, portDirections, + portNames, portTypes, portAnnotations, portSyms, + portLocs)) + return failure(); + + // Add the attributes. We let attributes defined in the attr-dict override + // attributes parsed out of the module signature. + if (!resultAttrs.get("moduleName")) + result.addAttribute("moduleName", moduleName); + if (!resultAttrs.get("name")) + result.addAttribute("name", StringAttr::get(context, name)); + result.addAttribute("nameKind", nameKind); + if (!resultAttrs.get("portDirections")) + result.addAttribute("portDirections", + direction::packAttribute(context, portDirections)); + if (!resultAttrs.get("portNames")) + result.addAttribute("portNames", ArrayAttr::get(context, portNames)); + if (!resultAttrs.get("portAnnotations")) + result.addAttribute("portAnnotations", + ArrayAttr::get(context, portAnnotations)); + + // Annotations, layers, and LowerToBind are omitted in the printed format + // if they are empty, empty, and false (respectively). + if (!resultAttrs.get("annotations")) + resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({})); + if (!resultAttrs.get("layers")) + resultAttrs.append("layers", parser.getBuilder().getArrayAttr({})); + + // Add result types. + for (auto tup : llvm::zip_equal(portTypes, portDirections)) { + auto pType = cast(std::get<0>(tup)).getValue(); + if (std::get<1>(tup) == Direction::In && isa(pType)) + std::get<0>(tup) = TypeAttr::get( + LHSType::get(pType.getContext(), cast(pType))); + } + result.types.reserve(portTypes.size()); + llvm::transform( + portTypes, std::back_inserter(result.types), + [](Attribute typeAttr) { return cast(typeAttr).getValue(); }); + + return success(); +} + +void StrictInstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { + StringRef base = getName(); + if (base.empty()) + base = "inst"; + + for (size_t i = 0, e = (*this)->getNumResults(); i != e; ++i) { + setNameFn(getResult(i), (base + "_" + getPortNameStr(i)).str()); + } +} + +std::optional StrictInstanceOp::getTargetResultIndex() { + // Inner symbols on instance operations target the op not any result. + return std::nullopt; +} + // ----------------------------------------------------------------------------- // InstanceChoiceOp // ----------------------------------------------------------------------------- @@ -3323,8 +3650,15 @@ void RegOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { return forceableAsmResultNames(*this, getName(), setNameFn); } +void StrictRegOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { + setNameFn(getWrite(), (getName() + "_write").str()); + return forceableAsmResultNames(*this, getName(), setNameFn); +} + std::optional RegOp::getTargetResultIndex() { return 0; } +std::optional StrictRegOp::getTargetResultIndex() { return 0; } + LogicalResult RegResetOp::verify() { auto reset = getResetValue(); @@ -3349,8 +3683,15 @@ void WireOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { return forceableAsmResultNames(*this, getName(), setNameFn); } +void StrictWireOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { + setNameFn(getWrite(), (getName() + "_write").str()); + return forceableAsmResultNames(*this, getName(), setNameFn); +} + std::optional WireOp::getTargetResultIndex() { return 0; } +std::optional StrictWireOp::getTargetResultIndex() { return 0; } + LogicalResult WireOp::verifySymbolUses(SymbolTableCollection &symbolTable) { auto refType = type_dyn_cast(getType(0)); if (!refType) @@ -3361,6 +3702,21 @@ LogicalResult WireOp::verifySymbolUses(SymbolTableCollection &symbolTable) { symbolTable, Twine("'") + getOperationName() + "' op is"); } +LogicalResult +StrictWireOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto refType = type_dyn_cast(getType(0)); + if (!refType) + return success(); + + return verifyProbeType( + refType, getLoc(), getOperation()->getParentOfType(), + symbolTable, Twine("'") + getOperationName() + "' op is"); +} + +Value StrictWireOp::getResult() { return getRead(); } + +Value StrictRegOp::getResult() { return getRead(); } + void ObjectOp::build(OpBuilder &builder, OperationState &state, ClassLike klass, StringRef name) { build(builder, state, klass.getInstanceType(), @@ -4428,11 +4784,22 @@ ParseResult parseSubfieldLikeOp(OpAsmParser &parser, OperationState &result) { if (parser.resolveOperand(input, inputType, result.operands)) return failure(); - auto bundleType = type_dyn_cast(inputType); + auto actualInputType = type_dyn_cast(inputType); + if (!actualInputType) + return parser.emitError(parser.getNameLoc(), + "input must be bundle or lhs of bundle type, got ") + << inputType; + + auto bundleType = isa(actualInputType) + ? firrtl::type_dyn_cast( + cast(actualInputType).getType()) + : firrtl::type_dyn_cast( + actualInputType); if (!bundleType) return parser.emitError(parser.getNameLoc(), - "input must be bundle type, got ") + "input must be effectively bundle type, got ") << inputType; + auto fieldIndex = bundleType.getElementIndex(fieldName); if (!fieldIndex) return parser.emitError(parser.getNameLoc(), @@ -4501,6 +4868,9 @@ ParseResult SubfieldOp::parse(OpAsmParser &parser, OperationState &result) { ParseResult OpenSubfieldOp::parse(OpAsmParser &parser, OperationState &result) { return parseSubfieldLikeOp(parser, result); } +ParseResult LHSSubfieldOp::parse(OpAsmParser &parser, OperationState &result) { + return parseSubfieldLikeOp(parser, result); +} template static void printSubfieldLikeOp(OpTy op, ::mlir::OpAsmPrinter &printer) { @@ -4518,6 +4888,9 @@ void SubfieldOp::print(::mlir::OpAsmPrinter &printer) { void OpenSubfieldOp::print(::mlir::OpAsmPrinter &printer) { return printSubfieldLikeOp(*this, printer); } +void LHSSubfieldOp::print(::mlir::OpAsmPrinter &printer) { + return printSubfieldLikeOp(*this, printer); +} void SubtagOp::print(::mlir::OpAsmPrinter &printer) { printer << ' ' << getInput() << '['; @@ -4529,20 +4902,23 @@ void SubtagOp::print(::mlir::OpAsmPrinter &printer) { printer << " : " << getInput().getType(); } -template -static LogicalResult verifySubfieldLike(OpTy op) { +template +static LogicalResult verifySubfieldLike(OpTy op, ITy ty) { if (op.getFieldIndex() >= - firrtl::type_cast(op.getInput().getType()) - .getNumElements()) + firrtl::type_cast(ty).getNumElements()) return op.emitOpError("subfield element index is greater than the number " "of fields in the bundle type"); return success(); } LogicalResult SubfieldOp::verify() { - return verifySubfieldLike(*this); + return verifySubfieldLike(*this, getInput().getType()); } LogicalResult OpenSubfieldOp::verify() { - return verifySubfieldLike(*this); + return verifySubfieldLike(*this, getInput().getType()); +} +LogicalResult LHSSubfieldOp::verify() { + return verifySubfieldLike(*this, + stripLHS(getInput().getType())); } LogicalResult SubtagOp::verify() { @@ -4636,10 +5012,29 @@ FIRRTLType OpenSubfieldOp::inferReturnType(ValueRange operands, return inType.getElementTypePreservingConst(fieldIndex); } +FIRRTLType LHSSubfieldOp::inferReturnType(ValueRange operands, + ArrayRef attrs, + std::optional loc) { + auto aType = cast(operands[0].getType()).getType(); + auto inType = type_cast(aType); + auto fieldIndex = + getAttr(attrs, "fieldIndex").getValue().getZExtValue(); + + if (fieldIndex >= inType.getNumElements()) + return emitInferRetTypeError(loc, + "subfield element index is greater than the " + "number of fields in the bundle type"); + + // OpenSubfieldOp verifier checks that the field index is valid with number of + // subelements. + return LHSType::get(inType.getElementTypePreservingConst(fieldIndex)); +} + bool SubfieldOp::isFieldFlipped() { BundleType bundle = getInput().getType(); return bundle.getElement(getFieldIndex()).isFlip; } + bool OpenSubfieldOp::isFieldFlipped() { auto bundle = getInput().getType(); return bundle.getElement(getFieldIndex()).isFlip; @@ -4679,6 +5074,23 @@ FIRRTLType OpenSubindexOp::inferReturnType(ValueRange operands, return emitInferRetTypeError(loc, "subindex requires vector operand"); } +FIRRTLType LHSSubindexOp::inferReturnType(ValueRange operands, + ArrayRef attrs, + std::optional loc) { + auto inType = cast(operands[0].getType()).getType(); + auto fieldIdx = + getAttr(attrs, "index").getValue().getZExtValue(); + + if (auto vectorType = type_dyn_cast(inType)) { + if (fieldIdx < vectorType.getNumElements()) + return LHSType::get(vectorType.getElementTypePreservingConst()); + return emitInferRetTypeError(loc, "out of range index '", fieldIdx, + "' in vector type ", inType); + } + + return emitInferRetTypeError(loc, "subindex requires vector operand"); +} + FIRRTLType SubtagOp::inferReturnType(ValueRange operands, ArrayRef attrs, std::optional loc) { @@ -5956,6 +6368,10 @@ void OpenSubfieldOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { genericAsmResultNames(*this, setNameFn); } +void LHSSubfieldOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { + genericAsmResultNames(*this, setNameFn); +} + void SubtagOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { genericAsmResultNames(*this, setNameFn); } @@ -5968,6 +6384,10 @@ void OpenSubindexOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { genericAsmResultNames(*this, setNameFn); } +void LHSSubindexOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { + genericAsmResultNames(*this, setNameFn); +} + void TagExtractOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { genericAsmResultNames(*this, setNameFn); } @@ -6276,6 +6696,40 @@ LayerBlockOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } +//===----------------------------------------------------------------------===// +// Printer/Parser Helpers. +//===----------------------------------------------------------------------===// + +/// Elide the lhs wrapper and the lhs if the inside is the rhs. +static void printOptionalLHSOpTypes(OpAsmPrinter &p, Operation *op, Type lhs, + Type rhs) { + // If operand types are the same, print a rhs type. + auto lhsCast = dyn_cast(lhs); + if (!lhsCast || lhsCast.getType() != rhs) + p << lhs << ", " << rhs; + else + p << rhs; +} + +static ParseResult parseOptionalLHSOpTypes(OpAsmParser &parser, Type &lhs, + Type &rhs) { + if (parser.parseType(rhs)) + return failure(); + + // Parse an optional rhs type. + if (parser.parseOptionalComma()) { + auto cRhs = dyn_cast(rhs); + if (!cRhs) + return failure(); + lhs = LHSType::get(parser.getContext(), cRhs); + } else { + lhs = rhs; + if (parser.parseType(rhs)) + return failure(); + } + return success(); +} + //===----------------------------------------------------------------------===// // TblGen Generated Logic. //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/FIRRTL/FIRRTLTypes.cpp b/lib/Dialect/FIRRTL/FIRRTLTypes.cpp index 0705de3a7152..92ea2572d3d7 100644 --- a/lib/Dialect/FIRRTL/FIRRTLTypes.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLTypes.cpp @@ -1184,6 +1184,11 @@ bool firrtl::areAnonymousTypesEquivalent(FIRRTLBaseType lhs, return lhs.getAnonymousType() == rhs.getAnonymousType(); } +bool firrtl::areAnonymousLHSandTypeEquivalent(LHSType lhs, FIRRTLBaseType rhs) { + auto baseLHS = lhs.getType(); + return baseLHS.getAnonymousType() == rhs.getAnonymousType(); +} + bool firrtl::areAnonymousTypesEquivalent(mlir::Type lhs, mlir::Type rhs) { if (auto destBaseType = type_dyn_cast(lhs)) if (auto srcBaseType = type_dyn_cast(rhs)) diff --git a/lib/Dialect/FIRRTL/Transforms/LowerCHIRRTL.cpp b/lib/Dialect/FIRRTL/Transforms/LowerCHIRRTL.cpp index dd2245bdf1cc..adb5d5fb4626 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerCHIRRTL.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerCHIRRTL.cpp @@ -54,7 +54,7 @@ struct LowerCHIRRTLPass void visitExpr(SubfieldOp op); void visitExpr(SubindexOp op); void visitStmt(ConnectOp op); - void visitStmt(MatchingConnectOp op); + void visitStmt(MatchingConnectOp connect); void visitUnhandledOp(Operation *op); // Chain the CHIRRTL visitor to the FIRRTL visitor. diff --git a/test/Dialect/FIRRTL/errors.mlir b/test/Dialect/FIRRTL/errors.mlir index 2f1cadd853b5..48931533397b 100644 --- a/test/Dialect/FIRRTL/errors.mlir +++ b/test/Dialect/FIRRTL/errors.mlir @@ -751,13 +751,23 @@ firrtl.circuit "SubfieldOpInputTypeMismatch" { firrtl.circuit "SubfieldOpNonBundleInputType" { firrtl.module @SubfieldOpFieldError() { %w = firrtl.wire : !firrtl.uint<1> - // expected-error @+1 {{'firrtl.subfield' input must be bundle type, got '!firrtl.uint<1>'}} + // expected-error @+1 {{'firrtl.subfield' input must be bundle or lhs of bundle type, got '!firrtl.uint<1>'}} %w_a = firrtl.subfield %w[a] : !firrtl.uint<1> } } // ----- +firrtl.circuit "SubfieldOpNonBundleInputType" { + firrtl.module @SubfieldOpFieldError() { + %w, %w_write = firrtl.strictwire : !firrtl.uint<1> + // expected-error @+1 {{'firrtl.subfield' input must be bundle or lhs of bundle type, got '!firrtl.lhs>'}} + %w_a = firrtl.subfield %w_write[a] : !firrtl.lhs> + } +} + +// ----- + firrtl.circuit "BitCast1" { firrtl.module @BitCast1() { %a = firrtl.wire : !firrtl.bundle, ready: uint<1>, data: uint> diff --git a/test/Dialect/FIRRTL/test.mlir b/test/Dialect/FIRRTL/test.mlir index 66a233105027..634ae97691c6 100644 --- a/test/Dialect/FIRRTL/test.mlir +++ b/test/Dialect/FIRRTL/test.mlir @@ -366,4 +366,27 @@ firrtl.module @TypeAlias(in %in: !firrtl.alias>, firrtl.matchingconnect %out, %in: !firrtl.alias>, !firrtl.alias> } +// CHECK-LABEL: FlowFix +firrtl.module @FlowFix(in %in : !firrtl.uint<8>, out %out : !firrtl.uint<8>) { + %mod_in, %mod_out = firrtl.strictinstance mod @MyModule(in in : !firrtl.uint<8>, out out : !firrtl.uint<8>) + firrtl.strictconnect %mod_in, %in : !firrtl.uint<8> + firrtl.matchingconnect %out, %mod_out : !firrtl.uint<8> +} + +// CHECK-LABEL: FlowFix2 +firrtl.module @FlowFix2(in %clk : !firrtl.clock) { + %wireb, %wireb_write = firrtl.strictwire : !firrtl.bundle> + %regb, %regb_write = firrtl.strictreg %clk : !firrtl.clock, !firrtl.bundle> + + %wirev, %wirev_write = firrtl.strictwire : !firrtl.vector,2> + %regv, %regv_write = firrtl.strictreg %clk : !firrtl.clock, !firrtl.vector,2> + + %b_a = firrtl.subfield %wireb[a] : !firrtl.bundle> + %v_a = firrtl.subindex %wirev[0] : !firrtl.vector,2> + %wb_a = firrtl.lhssubfield %wireb_write[a] : !firrtl.lhs>> + %wv_a = firrtl.lhssubindex %wirev_write[0] : !firrtl.lhs,2>> + firrtl.strictconnect %wb_a, %v_a : !firrtl.uint<3> + firrtl.strictconnect %wv_a, %b_a : !firrtl.uint<3> +} + }