diff --git a/orttraining/orttraining/python/training/ort_triton/_codegen.py b/orttraining/orttraining/python/training/ort_triton/_codegen.py
index e0f65ed272d38..9c7214f467af1 100644
--- a/orttraining/orttraining/python/training/ort_triton/_codegen.py
+++ b/orttraining/orttraining/python/training/ort_triton/_codegen.py
@@ -37,7 +37,7 @@
 from ._lowering import lower
 from ._sorted_graph import SortedGraph
 from ._sympy_utils import parse_shape, sympy_dot
-from ._utils import may_add_brackets
+from ._utils import is_number, may_add_brackets
 
 
 class TritonCodegen(NodeVisitor):
@@ -318,7 +318,7 @@ def ComputeNode(  # noqa: N802
         if op_type == "Cast":
             from_dtype = node.inputs[0].dtype.type
             to_dtype = node.outputs[0].dtype.type
-            if from_dtype == to_dtype:
+            if from_dtype == to_dtype or is_number(kwargs["i0"]):
                 op_type = "Identity"
             elif to_dtype == np.bool_:
                 op_type = "CastBool"
diff --git a/orttraining/orttraining/python/training/ort_triton/_utils.py b/orttraining/orttraining/python/training/ort_triton/_utils.py
index c80e28f6f73df..95e6703be8783 100644
--- a/orttraining/orttraining/python/training/ort_triton/_utils.py
+++ b/orttraining/orttraining/python/training/ort_triton/_utils.py
@@ -150,3 +150,11 @@ def next_power_of_2(n: int) -> int:
     n |= n >> 16
     n += 1
     return n
+
+
+def is_number(name: str) -> bool:
+    try:
+        float(name)
+        return True
+    except ValueError:
+        return name.startswith("float(") and name.endswith(")")
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py
index 0c381d70ca4c1..922f5c696500d 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py
@@ -12,6 +12,7 @@
 import pytest
 import torch
 from onnx import TensorProto, helper
+from packaging.version import Version
 from torch._C import _from_dlpack
 from torch.utils.dlpack import to_dlpack
 
@@ -842,6 +843,32 @@ def _gen_inputs(dtype):
     _run_module_test(NeuralNetSliceScel, dtype, _gen_inputs, 2)
 
 
+@pytest.mark.skipif(
+    Version(torch.__version__) < Version("2.1"), reason="PyTorch has scaled_dot_product_attention since 2.1."
+)
+def test_scaled_dot_product_attention_module():
+    class NeuralNetScaledDotProductAttention(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear1 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16)
+            self.linear2 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16)
+            self.linear3 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16)
+
+        def forward(self, q, k, v):
+            return torch.nn.functional.scaled_dot_product_attention(
+                self.linear1(q), self.linear2(k), self.linear3(v)
+            ).to(torch.float16)
+
+    def _gen_inputs(dtype):
+        return [
+            (torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE),
+            (torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE),
+            (torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE),
+        ]
+
+    _run_module_test(NeuralNetScaledDotProductAttention, torch.float16, _gen_inputs, 3)
+
+
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
 @pytest.mark.parametrize("input_shapes", [([128, 64], [64, 64]), ([16, 64, 128], [16, 128, 64])])
 def test_matmul_tunable_op(dtype, input_shapes):
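
Note (not part of the diff): the Cast-to-Identity change relies on the new is_number helper. The sketch below reproduces the helper from the patch and adds a few illustrative checks; it assumes kwargs["i0"] carries the rendered source string of the Cast node's first input, which for inlined constants is either a bare numeric literal or a float(...) expression, so the generated Triton code can skip the explicit cast.

# Minimal sketch, assuming kwargs["i0"] is an operand string such as "tmp0",
# "0.5", or "float(1e-05)" as produced by the Triton codegen.
def is_number(name: str) -> bool:
    try:
        float(name)
        return True
    except ValueError:
        return name.startswith("float(") and name.endswith(")")

assert is_number("0.5")            # plain numeric literal
assert is_number("float(1e-05)")   # constant already wrapped in float(...)
assert not is_number("tmp0")       # intermediate variable: a real cast is still emitted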