Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter activation fusions on MCDM #18371

Merged
merged 2 commits into from
Nov 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@

namespace Dml
{
GraphTransformer::GraphTransformer(
const std::string& name,
const onnxruntime::IExecutionProvider* provider
Copy link
Contributor

@fdwr fdwr Nov 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

   [](http://example.com/codeflow?start=0&length=2)

tabs #Closed

) : onnxruntime::GraphTransformer(name),
m_providerImpl(static_cast<const ExecutionProvider*>(provider)->GetImpl())
{
}

onnxruntime::common::Status GraphTransformer::ApplyImpl(
onnxruntime::Graph& graph,
bool& modified,
Expand All @@ -27,7 +35,7 @@ namespace Dml
// Perform fusion
{
bool transformModifiedGraph = false;
PerformOperatorFusion(&graph, &transformModifiedGraph);
PerformOperatorFusion(&graph, m_providerImpl->IsMcdmDevice(), &transformModifiedGraph);
modified |= transformModifiedGraph;

if (modified)
Expand All @@ -50,7 +58,7 @@ namespace Dml
return ss.str();
}

void GraphTransformer::PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const
void GraphTransformer::PerformOperatorFusion(onnxruntime::Graph* graph, bool isMcdmDevice, bool* modified) const
{
struct NodeToAdd
{
Expand Down Expand Up @@ -112,7 +120,8 @@ namespace Dml
gsl::narrow_cast<uint32_t>(node.InputDefs().size()),
outputNode.OpType(),
outputNode.Domain(),
outputNode.Op()->SinceVersion());
outputNode.Op()->SinceVersion(),
isMcdmDevice);

if (!fusedOpProperties)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,25 @@

namespace Dml
{
class ExecutionProviderImpl;

// Applies transforms to a Lotus graph. The graph transformer is responsible for setting the execution provider
// on the graph nodes which DML supports.
class GraphTransformer : public onnxruntime::GraphTransformer
{
public:
GraphTransformer(
const std::string& name
) : onnxruntime::GraphTransformer(name)
{
}
const std::string& name,
const onnxruntime::IExecutionProvider* provider
Copy link
Contributor

@fdwr fdwr Nov 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  	 [](http://example.com/codeflow?start=0&length=3)

tabs
2 other lines too #Closed

);

private:
onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const onnxruntime::logging::Logger& logger) const final;

private:
void PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const;
void PerformOperatorFusion(onnxruntime::Graph* graph, bool isMcdmDevice, bool* modified) const;

const ExecutionProviderImpl* m_providerImpl = nullptr;
};

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ namespace Dml
std::string_view domain;
int sinceVersion;
std::vector<std::string_view> activationFilter;
bool enableOnMcdm;
Copy link
Contributor

@fdwr fdwr Nov 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

enableOnMcdm

Huh, that's weird that ORT needs such a low-level detail, and that it's not just handled suitably in DirectML.dll. I'm not asking for changes from you, because I know this is already written by Jeff, but it's weird. #Resolved

Copy link
Contributor

@fdwr fdwr Nov 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just talked with Jeff - it's a stopgap. Consider resolved.

std::vector<std::string_view> extraMcdmActivationFilter;
std::optional<uint32_t> inputCountFilter;
};

Expand All @@ -142,10 +144,10 @@ namespace Dml

static const OperatorInfo c_fusableOps[] =
{
OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Conv },
OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_Conv },
OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_ConvTranspose },
OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_ConvTranspose },
OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Conv, {}, true, {"Relu", "LeakyRelu"} },
OperatorInfo{ "Conv", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_Conv, {}, true, {"Relu", "LeakyRelu"} },
OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_ConvTranspose, {}, true, {"Relu", "LeakyRelu"} },
OperatorInfo{ "ConvTranspose", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_ConvTranspose, {}, true, {"Relu", "LeakyRelu"} },
OperatorInfo{ "BatchNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_BatchNormalization },
OperatorInfo{ "BatchNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_BatchNormalization },
OperatorInfo{ "BatchNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_BatchNormalization },
Expand All @@ -163,11 +165,11 @@ namespace Dml
OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_MatMul },

// The filter for activation functions maps to what DML's fused op internally fuses at the shader level.
OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Add, {"Relu", "LeakyRelu"} },
OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Add, {"Relu", "LeakyRelu"} },
OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_Add, {"Relu", "LeakyRelu"} },
OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet8::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, 2 },
OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, 2 },
OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Add, {"Relu", "LeakyRelu"}, true },
OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Add, {"Relu", "LeakyRelu"}, true },
OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet14::sc_sinceVer_Add, {"Relu", "LeakyRelu"}, true },
OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet8::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, true, {} , 2 },
OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, true, {} , 2 },
};

// Not all activations can be fused - only simple elementwise activations (i.e. activation functions which
Expand Down Expand Up @@ -205,7 +207,8 @@ namespace Dml
int candidateOpInputCount,
std::string_view activationOpType,
std::string_view activationOpDomain,
int activationOpSinceVersion)
int activationOpSinceVersion,
bool isMcdmDevice)
{
auto opIt = std::find(
std::begin(c_fusableOps),
Expand Down Expand Up @@ -233,6 +236,20 @@ namespace Dml
return std::nullopt;
}

if (isMcdmDevice)
{
if (!opIt->enableOnMcdm)
{
return std::nullopt;
}

if (!opIt->extraMcdmActivationFilter.empty() &&
std::find(opIt->extraMcdmActivationFilter.begin(), opIt->extraMcdmActivationFilter.end(), activationOpType) == opIt->extraMcdmActivationFilter.end())
{
return std::nullopt;
}
}

if (opIt->inputCountFilter && *opIt->inputCountFilter != static_cast<uint32_t>(candidateOpInputCount))
{
return std::nullopt;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ namespace Dml
int candidateOpInputCount,
std::string_view activationOpType,
std::string_view activationOpDomain,
int activationOpSinceVersion);
int activationOpSinceVersion,
bool isMcdmDevice);

// Returns true if the given activation operator type supports being fused with a fusable operator, false
// otherwise.
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1631,7 +1631,8 @@ common::Status InferenceSession::Initialize() {
// This transformer applies DML-specific fusions that go beyond what ORT offers by default
bool dml_operator_fusion_enabled = session_options_.graph_optimization_level >= TransformerLevel::Level2;
if (dml_operator_fusion_enabled) {
std::unique_ptr<onnxruntime::GraphTransformer> dmlOperatorFusionTransformer = std::make_unique<Dml::GraphTransformer>("DmlOperatorFusionTransformer");
std::unique_ptr<onnxruntime::GraphTransformer> dmlOperatorFusionTransformer = std::make_unique<Dml::GraphTransformer>("DmlOperatorFusionTransformer",
execution_providers_.Get(kDmlExecutionProvider));
if (dmlOperatorFusionTransformer == nullptr) {
return Status(common::ONNXRUNTIME, common::FAIL, "DmlOperatorFusionTransformer is nullptr");
}
Expand Down
Loading