
Commit 3f80b6a

Merge branch 'main' of https://github.com/microsoft/onnxruntime into baijumeswani/nominal-checkpoint

baijumeswani committed Jan 24, 2024
2 parents: d9e9dda + d7aebf9
Showing 5 changed files with 20 additions and 14 deletions.
onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu (3 additions, 2 deletions)
@@ -68,6 +68,7 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
   scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
 
   past_present_share_buffer_ = info.GetAttrOrDefault<int64_t>("past_present_share_buffer", 0LL) != 0LL;
+  is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
 
   using HipT = typename ToHipType<T>::MappedType;
   using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp<HipT>;
@@ -121,8 +122,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       query, key, value, bias,
       key_padding_mask, relative_position_bias,
       past_key, past_value, past_seq_len,
-      &attn,
-      num_heads_, mask_filter_value_, scale_,
+      &attn, num_heads_,
+      mask_filter_value_, scale_, false, /*is_unidirectional_*/
       past_present_share_buffer_, false, device_prop.maxThreadsPerBlock));
 
   if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) {
onnxruntime/contrib_ops/rocm/bert/multihead_attention.h (1 addition, 0 deletions)
@@ -25,6 +25,7 @@ class MultiHeadAttention final : public RocmKernel {
   float mask_filter_value_;
   float scale_;
   bool past_present_share_buffer_{false};
+  bool is_unidirectional_{false};
 
   // type-erased GemmSoftmaxGemmPermuteTunableOp<HipT>, the reason for this is:
   // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp<HipT> is defined.
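The two changes above make the ROCm MultiHeadAttention kernel read the optional `unidirectional` node attribute (defaulting to 0) into a new `is_unidirectional_` member. A minimal sketch of how a graph author might set that attribute on the com.microsoft contrib op (assumed usage, not part of this commit; input and output names are placeholders):

    from onnx import helper

    # Build a com.microsoft MultiHeadAttention node carrying the
    # `unidirectional` attribute that the constructor now parses via
    # GetAttrOrDefault<int64_t>("unidirectional", 0).
    mha_node = helper.make_node(
        "MultiHeadAttention",
        inputs=["query", "key", "value"],  # placeholder tensor names
        outputs=["attn_out"],              # placeholder tensor name
        domain="com.microsoft",            # contrib-op domain
        num_heads=8,                       # required attribute
        unidirectional=1,                  # causal (unidirectional) attention
    )

Note that in this commit `ComputeInternal` still passes a hard-coded `false` at the call site (see the `/*is_unidirectional_*/` marker above), so the attribute is parsed but not yet acted on by the ROCm kernel.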
onnxruntime/python/onnxruntime_inference_collection.py (8 additions, 4 deletions)
@@ -466,7 +466,7 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi

         session_options = self._sess_options if self._sess_options else C.get_default_session_options()
 
-        self._register_ep_custom_ops(session_options, providers, provider_options)
+        self._register_ep_custom_ops(session_options, providers, provider_options, available_providers)
 
         if self._model_path:
             sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
@@ -510,11 +510,15 @@ def _reset_session(self, providers, provider_options):
         self._sess_options = self._sess_options_initial
         self._create_inference_session(providers, provider_options)
 
-    def _register_ep_custom_ops(self, session_options, providers, provider_options):
+    def _register_ep_custom_ops(self, session_options, providers, provider_options, available_providers):
         for i in range(len(providers)):
-            if providers[i] == "TensorrtExecutionProvider":
+            if providers[i] in available_providers and providers[i] == "TensorrtExecutionProvider":
                 C.register_tensorrt_plugins_as_custom_ops(session_options, provider_options[i])
-            elif isinstance(providers[i], tuple) and providers[i][0] == "TensorrtExecutionProvider":
+            elif (
+                isinstance(providers[i], tuple)
+                and providers[i][0] in available_providers
+                and providers[i][0] == "TensorrtExecutionProvider"
+            ):
                 C.register_tensorrt_plugins_as_custom_ops(session_options, providers[i][1])


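The new `available_providers` parameter lets `_register_ep_custom_ops` skip TensorRT plugin registration when the requested provider is not actually available in the build. A minimal standalone sketch of the same check (names are illustrative; in the real method `available_providers` is presumably populated from `C.get_available_providers()`):

    # Sketch of the guard added above (illustrative, not the library code).
    def should_register_trt_plugins(provider, available_providers):
        """True only when the TensorRT EP is both requested and available."""
        if isinstance(provider, tuple):  # ("TensorrtExecutionProvider", {options})
            provider = provider[0]
        return provider == "TensorrtExecutionProvider" and provider in available_providers

    # TensorRT requested but not present in this build: registration is skipped.
    available = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    assert not should_register_trt_plugins("TensorrtExecutionProvider", available)

    # TensorRT requested as a (name, options) tuple and available: registration proceeds.
    assert should_register_trt_plugins(
        ("TensorrtExecutionProvider", {"trt_engine_cache_enable": True}),
        available + ["TensorrtExecutionProvider"],
    )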
(pipeline YAML file, name not shown in this view: 4 additions, 4 deletions)
@@ -1023,7 +1023,7 @@ stages:

 - template: nuget/templates/test_win.yml
   parameters:
-    AgentPool : 'onnxruntime-Win2022-GPU-T4'
+    AgentPool : 'onnxruntime-Win2022-GPU-A10'
     NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu'
     ArtifactSuffix: 'GPU'
     StageSuffix: 'GPU'
@@ -1034,7 +1034,7 @@

 - template: nuget/templates/test_win.yml
   parameters:
-    AgentPool : 'onnxruntime-Win2022-GPU-T4'
+    AgentPool : 'onnxruntime-Win2022-GPU-A10'
     NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu.Windows'
     ArtifactSuffix: 'GPU'
     StageSuffix: 'GPU'
@@ -1046,7 +1046,7 @@

 - template: nuget/templates/test_linux.yml
   parameters:
-    AgentPool : Onnxruntime-Linux-GPU
+    AgentPool : Onnxruntime-Linux-GPU-A10
     ArtifactSuffix: 'GPU'
     StageSuffix: 'GPU'
     NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu'
@@ -1055,7 +1055,7 @@

 - template: nuget/templates/test_linux.yml
   parameters:
-    AgentPool : Onnxruntime-Linux-GPU
+    AgentPool : Onnxruntime-Linux-GPU-A10
     ArtifactSuffix: 'GPU'
     StageSuffix: 'GPU'
     MoreSuffix: '_Linux'
(pipeline YAML file, name not shown in this view: 4 additions, 4 deletions)
@@ -151,7 +151,7 @@ stages:
 # Testing
 - template: nuget/templates/test_win.yml
   parameters:
-    AgentPool : 'onnxruntime-Win2022-GPU-T4'
+    AgentPool : 'onnxruntime-Win2022-GPU-A10'
     NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu'
     ArtifactSuffix: 'GPU'
     StageSuffix: 'GPU'
@@ -162,7 +162,7 @@

 - template: nuget/templates/test_win.yml
   parameters:
-    AgentPool : 'onnxruntime-Win2022-GPU-T4'
+    AgentPool : 'onnxruntime-Win2022-GPU-A10'
     NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu.Windows'
     ArtifactSuffix: 'GPU'
     StageSuffix: 'GPU'
@@ -174,7 +174,7 @@

 - template: nuget/templates/test_linux.yml
   parameters:
-    AgentPool : Onnxruntime-Linux-GPU
+    AgentPool : Onnxruntime-Linux-GPU-A10
     ArtifactSuffix: 'GPU'
     StageSuffix: 'GPU'
     NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu'
@@ -184,7 +184,7 @@

 - template: nuget/templates/test_linux.yml
   parameters:
-    AgentPool : Onnxruntime-Linux-GPU
+    AgentPool : Onnxruntime-Linux-GPU-A10
     ArtifactSuffix: 'GPU'
     StageSuffix: 'GPU'
     MoreSuffix: '_Linux'
