
Commit

adjust test and comments
prathikr committed Jul 2, 2024
1 parent dd1849a commit 8219ec9
Showing 3 changed files with 23 additions and 7 deletions.
@@ -278,9 +278,7 @@ def upsample_bicubic2d_gradient():
return _upsample_gradient("upsample_bicubic2d_backward", 2)


# based on the following kernel implementation from PyTorch:
# https://github.com/pytorch/pytorch/blob/52341c28e817ee6bc36b529823f8248ba395d5bb/aten/src/ATen/native/transformers/cuda/attention_backward.cu#L748
# dispatch logic:
# based on the following internal PyTorch kernel for efficient attention:
# https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784
@register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
def scaled_dot_product_attention_gradient():
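
For context, the native_functions.yaml entry referenced in the new comment is what torch.nn.functional.scaled_dot_product_attention dispatches to when the memory-efficient backend is selected. A minimal sketch using only public PyTorch APIs (not part of this commit) of forcing that backend:

import torch
import torch.nn.functional as F

# Sketch only: restrict SDPA to the memory-efficient backend so the call below
# dispatches to _scaled_dot_product_efficient_attention, the kernel whose
# gradient is registered here.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)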
@@ -971,9 +971,7 @@ def softmax(g, input, dim, dtype=None):
return softmax


# based on the following kernel implementation from PyTorch:
# https://github.com/pytorch/pytorch/blob/00f675bb4c2ec02bb5ffecfc75571026e220701c/aten/src/ATen/native/transformers/cuda/attention.cu#L788
# dispatch logic:
# based on the following internal PyTorch kernel for efficient attention:
# https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778
@register_symbolic("scaled_dot_product_attention")
def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
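
The registered symbolic above exports torch.nn.functional.scaled_dot_product_attention as an ATen fallback node. As a rough illustration of that export pattern (an assumption-laden sketch of the generic org.pytorch.aten::ATen fallback used by ORTModule, not the exact body of the registered symbolic):

# Simplified sketch of the ATen-fallback export pattern; the real symbolic also
# wires up attn_mask, dropout_p, is_causal, scale, and compute_log_sumexp.
def sdpa_symbolic_sketch(g, query, key, value):
    # Emit an org.pytorch.aten::ATen node whose "operator" string attribute names
    # the underlying kernel; ONNX Runtime dispatches it back to PyTorch at runtime.
    outputs = g.op(
        "org.pytorch.aten::ATen",
        query,
        key,
        value,
        operator_s="_scaled_dot_product_efficient_attention",
        outputs=4,  # the kernel returns (output, logsumexp, philox_seed, philox_offset)
    )
    return outputs[0]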
@@ -6946,7 +6946,7 @@ def gen_inputs(device, dtype):

device = "cuda"
pt_model = _NeuralNetAttention().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="mem_eff_attn"))

def run_step(model, inputs):
    prediction = model(*inputs)
@@ -6962,3 +6962,23 @@ def run_step(model, inputs):

_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)

execution_mgr = ort_model._torch_module._execution_manager._training_manager
from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name

path = os.path.join(
    execution_mgr._debug_options.save_onnx_models.path,
    _get_onnx_file_name(
        execution_mgr._debug_options.save_onnx_models.name_prefix, "execution_model", execution_mgr._export_mode
    ),
)

onnx_model = onnx.load(path)
onnx_nodes = onnx_model.graph.node

mem_eff_attn_nodes = 0
for node in onnx_nodes:
    if "ATen" in node.name and any(
        attr.name == "operator" and b"_scaled_dot_product_efficient_attention" in attr.s
        for attr in node.attribute
    ):
        mem_eff_attn_nodes += 1

assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes found"
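
For ad-hoc debugging outside the test, a small hypothetical helper (standard onnx protobuf API only, not part of this commit) generalizes the check above by listing the PyTorch operator behind every ATen fallback node:

import onnx

def list_aten_operators(onnx_path):
    # Collect the "operator" string attribute of every ATen fallback node.
    ops = []
    for node in onnx.load(onnx_path).graph.node:
        if node.op_type == "ATen":
            for attr in node.attribute:
                if attr.name == "operator":
                    ops.append(attr.s.decode("utf-8"))
    return ops

# Example: list_aten_operators(path) is expected to include
# "_scaled_dot_product_efficient_attention" after the export above.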
