Skip to content

Commit

Permalink
combine op_dialect_version_map_, import_handler_map_ into onnx_ops_map_
Browse files Browse the repository at this point in the history
Signed-off-by: Soren Lassen <[email protected]>
  • Loading branch information
sorenlassen committed May 18, 2023
1 parent 571b72b commit afa19b0
Showing 1 changed file with 72 additions and 64 deletions.
136 changes: 72 additions & 64 deletions src/Builder/FrontendDialectTransformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,25 @@ class FrontendGenImpl {
ModuleOp module_;
OpBuilder builder_;

// onnxop: list of versions for dialect
std::unordered_map<std::string, std::vector<int>> op_dialect_version_map_;
using ImportHandlerType = void (FrontendGenImpl::*)(const onnx::NodeProto &);

struct VersionedHandler {
int version;
ImportHandlerType handler;
};

using ONNXOpVersions = SmallVector<VersionedHandler, 1>;

// Maps NodeProto::op_type() to sorted vector of (version, handler) pairs.
// TODO: Key by (domain, op_type) pair so we don't rely on names being unique
// across all domains.
std::unordered_map<std::string, ONNXOpVersions> onnx_ops_map_;

// mapping between string name and symbol
ValueSymbolMapping frontend_symbols_;

ModelInputShaper modelInputShaper_;

using ImportHandlerType = void (FrontendGenImpl::*)(const onnx::NodeProto &);

std::unordered_map<std::string, ImportHandlerType> import_handler_map_;

// The total number of elements in all initializers. This value is a rough
// counter of the number of parameters in a model.
int64_t num_of_parameters_ = 0;
Expand Down Expand Up @@ -682,45 +689,6 @@ class FrontendGenImpl {
node.op_type(), version, node.domain());
}

std::string GetImportVersionOfNode(const onnx::NodeProto &node) {
int64_t version = GetDomainVersion(node.domain());
if (version == 0)
return "";

LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Importing ONNX"
<< node.op_type() << " (" << node.name() << ")"
<< ", Opset: " << version << "\n");

auto opset_list_it = op_dialect_version_map_.find(node.op_type());

// Custom ops may not be present in op_dialect_version_map_. If no version
// info is found, treat as unversioned (no renaming).
if (opset_list_it == op_dialect_version_map_.end())
return "";

auto opset_list = opset_list_it->second;

// A new opset is added to onnx-mlir when it becomes imcompactible.
// But the lowest opset in op_dialect_version_map_ is an exception.
// It is the current opset when onnx-mlir project is started.
// All opset lower than the last opset should use the last opset(version)
if (node.domain().compare("ai.onnx.ml") != 0 &&
version < opset_list.back() && version < MINIMUM_SUPPORTED_OPSET)
llvm::outs() << "Warning: ONNX " << node.op_type()
<< " in your model is using Opset " << version
<< ", which is quite old. Please consider regenerating your "
"model with a newer Opset.\n";

for (int i = opset_list.size() - 1; i > 0; i--) {
if (version < opset_list[i - 1]) {
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": - use Opset "
<< opset_list[i] << "\n");
return "V" + std::to_string(opset_list[i]);
}
}
return "";
}

func::FuncOp CreateFuncOp(
std::string namePrefix, TypeRange operandTypes, TypeRange resultTypes) {
auto funcType = builder_.getFunctionType(operandTypes, resultTypes);
Expand Down Expand Up @@ -912,16 +880,58 @@ class FrontendGenImpl {
}
}

bool TryImportONNXNode(const onnx::NodeProto &node) {
int64_t version = GetDomainVersion(node.domain());
if (version == 0) {
// Unknown domain.
return false;
}

LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Importing ONNX"
<< node.op_type() << " (" << node.name() << ")"
<< ", Opset: " << version << "\n");

auto versions_it = onnx_ops_map_.find(node.op_type());
if (versions_it == onnx_ops_map_.end()) {
// Unknown op_type.
llvm::outs() << "Warning: ONNX " << node.op_type() << " from domain '"
<< node.domain() << ","
<< " in your model is unsupported.\n";
return false;
}

const ONNXOpVersions &opVersions = versions_it->second;

// A new opset is added to onnx-mlir when it becomes imcompatible.
// But the lowest opset in op_dialect_version_map_ is an exception.
// It is the current opset when onnx-mlir project is started.
// All opset lower than the last opset should use the last opset(version)
if (node.domain().compare("ai.onnx.ml") != 0 &&
version < opVersions.back().version &&
version < MINIMUM_SUPPORTED_OPSET)
llvm::outs() << "Warning: ONNX " << node.op_type()
<< " in your model is using Opset " << version
<< ", which is quite old. Please consider regenerating your "
"model with a newer Opset.\n";

ImportHandlerType handler = opVersions.front().handler;
for (int i = opVersions.size() - 1; i > 0; --i) {
if (version < opVersions[i - 1].version) {
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": - use Opset "
<< opVersions[i].version << "\n");
handler = opVersions[i].handler;
}
}
(this->*handler)(node);
return true;
}

void ImportNode(const onnx::NodeProto &node) {
std::string versionStr = GetImportVersionOfNode(node);

// look up handler for the opName. If not found, create a node
// for a custom op, and issue a warning.
std::string versionedName = node.op_type() + versionStr;
auto handler = import_handler_map_.find(versionedName);
if (handler != import_handler_map_.end()) {
(this->*(handler->second))(node);
} else {
bool imported = TryImportONNXNode(node);
if (!imported) {
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Importing Custom op "
<< node.op_type() << " (" << node.name() << ")"
<< ", domain: '" << node.domain() << "'\n");
ImportCustomNode(node);
}
}
Expand All @@ -932,18 +942,16 @@ class FrontendGenImpl {
if constexpr (std::is_base_of_v<ONNXOperationTrait<T>, T>) {
StringRef name = T::getONNXName();
int version = T::getONNXSinceVersion();
op_dialect_version_map_[name.str()].push_back(version);

StringRef versionedName = T::getOperationName();
bool hadOnnxPrefix = versionedName.consume_front("onnx.");
assert(hadOnnxPrefix);
import_handler_map_[versionedName.str()] =
&FrontendGenImpl::buildOperation<T>;
ImportHandlerType handler = &FrontendGenImpl::buildOperation<T>;
ONNXOpVersions &opVersions = onnx_ops_map_[name.str()];
// Insert in descending version order:
auto it = opVersions.begin();
while (it != opVersions.end() && it->version > version) {
++it; // Skip past larger versions.
}
opVersions.insert(it, {version, handler});
}
});
for (auto &[name, versions] : op_dialect_version_map_) {
std::sort(versions.begin(), versions.end(), std::greater<int>());
}
}

/*!
Expand Down

0 comments on commit afa19b0

Please sign in to comment.