Skip to content

Commit

Permalink
Filter activation fusions on MCDM
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffbloo committed Oct 12, 2023
1 parent a9290e9 commit cfab945
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@

namespace Dml
{
GraphTransformer::GraphTransformer(
const std::string& name,
const onnxruntime::IExecutionProvider* provider
) : onnxruntime::GraphTransformer(name),
m_providerImpl(static_cast<const ExecutionProvider*>(provider)->GetImpl())
{
}

onnxruntime::common::Status GraphTransformer::ApplyImpl(
onnxruntime::Graph& graph,
bool& modified,
Expand All @@ -27,7 +35,7 @@ namespace Dml
// Perform fusion
{
bool transformModifiedGraph = false;
PerformOperatorFusion(&graph, &transformModifiedGraph);
PerformOperatorFusion(&graph, m_providerImpl->IsMcdmDevice(), &transformModifiedGraph);
modified |= transformModifiedGraph;

if (modified)
Expand All @@ -50,7 +58,7 @@ namespace Dml
return ss.str();
}

void GraphTransformer::PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const
void GraphTransformer::PerformOperatorFusion(onnxruntime::Graph* graph, bool isMcdmDevice, bool* modified) const
{
struct NodeToAdd
{
Expand Down Expand Up @@ -112,7 +120,8 @@ namespace Dml
gsl::narrow_cast<uint32_t>(node.InputDefs().size()),
outputNode.OpType(),
outputNode.Domain(),
outputNode.Op()->SinceVersion());
outputNode.Op()->SinceVersion(),
isMcdmDevice);

if (!fusedOpProperties)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,25 @@

namespace Dml
{
class ExecutionProviderImpl;

// Applies transforms to a Lotus graph. The graph transformer is responsible for setting the execution provider
// on the graph nodes which DML supports.
class GraphTransformer : public onnxruntime::GraphTransformer
{
public:
GraphTransformer(
const std::string& name
) : onnxruntime::GraphTransformer(name)
{
}
const std::string& name,
const onnxruntime::IExecutionProvider* provider
);

private:
onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const onnxruntime::logging::Logger& logger) const final;

private:
void PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const;
void PerformOperatorFusion(onnxruntime::Graph* graph, bool isMcdmDevice, bool* modified) const;

const ExecutionProviderImpl* m_providerImpl = nullptr;
};

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ namespace Dml
std::string_view domain;
int sinceVersion;
std::vector<std::string_view> activationFilter;
bool enableOnMcdm;
std::vector<std::string_view> extraMcdmActivationFilter;
std::optional<uint32_t> inputCountFilter;
};

Expand All @@ -142,10 +144,10 @@ namespace Dml

static const OperatorInfo c_fusableOps[] =
{
OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Conv },
OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_Conv },
OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_ConvTranspose },
OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_ConvTranspose },
OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Conv, {}, true, {"Relu", "LeakyRelu"} },
OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_Conv, {}, true, {"Relu", "LeakyRelu"} },
OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_ConvTranspose, {}, true, {"Relu", "LeakyRelu"} },
OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_ConvTranspose, {}, true, {"Relu", "LeakyRelu"} },
OperatorInfo{ "BatchNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_BatchNormalization },
OperatorInfo{ "BatchNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_BatchNormalization },
OperatorInfo{ "BatchNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_BatchNormalization },
Expand All @@ -163,11 +165,11 @@ namespace Dml
OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_MatMul },

// The filter for activation functions maps to what DML's fused op internally fuses at the shader level.
OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Add, {"Relu", "LeakyRelu"} },
OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Add, {"Relu", "LeakyRelu"} },
OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_Add, {"Relu", "LeakyRelu"} },
OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet8::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, 2 },
OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, 2 },
OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Add, {"Relu", "LeakyRelu"}, true },
OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Add, {"Relu", "LeakyRelu"}, true },
OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_Add, {"Relu", "LeakyRelu"}, true },
OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet8::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, true, {} , 2 },
OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, true, {} , 2 },
};

// Not all activations can be fused - only simple elementwise activations (i.e. activation functions which
Expand Down Expand Up @@ -205,7 +207,8 @@ namespace Dml
int candidateOpInputCount,
std::string_view activationOpType,
std::string_view activationOpDomain,
int activationOpSinceVersion)
int activationOpSinceVersion,
bool isMcdmDevice)
{
auto opIt = std::find(
std::begin(c_fusableOps),
Expand Down Expand Up @@ -233,6 +236,20 @@ namespace Dml
return std::nullopt;
}

if (isMcdmDevice)
{
if (!opIt->enableOnMcdm)
{
return std::nullopt;
}

if (!opIt->extraMcdmActivationFilter.empty() &&
std::find(opIt->extraMcdmActivationFilter.begin(), opIt->extraMcdmActivationFilter.end(), activationOpType) == opIt->extraMcdmActivationFilter.end())
{
return std::nullopt;
}
}

if (opIt->inputCountFilter && *opIt->inputCountFilter != static_cast<uint32_t>(candidateOpInputCount))
{
return std::nullopt;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ namespace Dml
int candidateOpInputCount,
std::string_view activationOpType,
std::string_view activationOpDomain,
int activationOpSinceVersion);
int activationOpSinceVersion,
bool isMcdmDevice);

// Returns true if the given activation operator type supports being fused with a fusable operator, false
// otherwise.
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1524,7 +1524,8 @@ common::Status InferenceSession::Initialize() {
// This transformer applies DML-specific fusions that go beyond what ORT offers by default
bool dml_operator_fusion_enabled = session_options_.graph_optimization_level >= TransformerLevel::Level2;
if (dml_operator_fusion_enabled) {
std::unique_ptr<onnxruntime::GraphTransformer> dmlOperatorFusionTransformer = std::make_unique<Dml::GraphTransformer>("DmlOperatorFusionTransformer");
std::unique_ptr<onnxruntime::GraphTransformer> dmlOperatorFusionTransformer = std::make_unique<Dml::GraphTransformer>("DmlOperatorFusionTransformer",
execution_providers_.Get(kDmlExecutionProvider));
if (dmlOperatorFusionTransformer == nullptr) {
return Status(common::ONNXRUNTIME, common::FAIL, "DmlOperatorFusionTransformer is nullptr");
}
Expand Down

0 comments on commit cfab945

Please sign in to comment.