diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py b/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py index ed92923589d48..a3681a13699a0 100644 --- a/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py +++ b/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py @@ -11,7 +11,7 @@ import torch from .._cache import ModuleCache, PyCodeCache -from .._utils import next_power_of_2 +from .._utils import gen_unique_name, next_power_of_2 _DEBUG_MODE = "ORTMODULE_TRITON_DEBUG" in os.environ and int(os.getenv("ORTMODULE_TRITON_DEBUG")) == 1 @@ -305,18 +305,18 @@ def _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name): def _gen_mm_key(dtype: torch.dtype, m: int, n: int, k: int, trans_a: bool, trans_b: bool, alpha: float) -> int: - return hash(f"mm|{dtype}|{m}|{n}|{k}|{trans_a}|{trans_b}|{alpha}") % (10**8) + return hash(f"mm|{dtype}|{m}|{n}|{k}|{trans_a}|{trans_b}|{alpha}") def _gen_mm_module( dtype: torch.dtype, m: int, n: int, k: int, trans_a: bool, trans_b: bool, alpha: float ) -> Tuple[str, ModuleType]: - func_name = f"mm_{_gen_mm_key(dtype, m, n, k, trans_a, trans_b, alpha)}" + func_name = gen_unique_name("mm") kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name) src_code = _MM_TEMPLATE.format(**kwargs) if _DEBUG_MODE: os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True) - with open(f"triton_debug/{func_name}.py", "w") as f: + with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f: f.write(src_code) return func_name, PyCodeCache().load(src_code) @@ -333,7 +333,7 @@ def _gen_gemm_key( alpha: float, beta: float, ) -> int: - return hash(f"gemm|{dtype}|{m}|{n}|{k}|{stride_cm}|{stride_cn}|{trans_a}|{trans_b}|{alpha}|{beta}") % (10**8) + return hash(f"gemm|{dtype}|{m}|{n}|{k}|{stride_cm}|{stride_cn}|{trans_a}|{trans_b}|{alpha}|{beta}") def _gen_gemm_module( @@ -348,7 +348,7 @@ def _gen_gemm_module( alpha: float, beta: float, ) -> Tuple[str, ModuleType]: - func_name = f"gemm_{_gen_gemm_key(dtype, m, n, k, stride_cm, stride_cn, trans_a, trans_b, alpha, beta)}" + func_name = gen_unique_name("gemm") kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name) kwargs["stride_cm"] = stride_cm kwargs["stride_cn"] = stride_cn @@ -356,7 +356,7 @@ def _gen_gemm_module( src_code = _GEMM_TEMPLATE.format(**kwargs) if _DEBUG_MODE: os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True) - with open(f"triton_debug/{func_name}.py", "w") as f: + with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f: f.write(src_code) return func_name, PyCodeCache().load(src_code) @@ -364,13 +364,13 @@ def _gen_gemm_module( def _gen_bmm_key( dtype: torch.dtype, m: int, n: int, k: int, batch_a: int, batch_b: int, trans_a: bool, trans_b: bool, alpha: float ) -> int: - return hash(f"bmm|{dtype}|{m}|{n}|{k}|{batch_a}|{batch_b}|{trans_a}|{trans_b}|{alpha}") % (10**8) + return hash(f"bmm|{dtype}|{m}|{n}|{k}|{batch_a}|{batch_b}|{trans_a}|{trans_b}|{alpha}") def _gen_bmm_module( dtype: torch.dtype, m: int, n: int, k: int, batch_a: int, batch_b: int, trans_a: bool, trans_b: bool, alpha: float ) -> Tuple[str, ModuleType]: - func_name = f"bmm_{_gen_bmm_key(dtype, m, n, k, batch_a, batch_b, trans_a, trans_b, alpha)}" + func_name = gen_unique_name("bmm") kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name) batch = batch_a if batch_a >= batch_b else batch_b kwargs["stride_aq"] = m * k if batch_a == batch else 0 @@ -379,7 +379,7 @@ def _gen_bmm_module( src_code = _BMM_TEMPLATE.format(**kwargs) if _DEBUG_MODE: os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True) - with open(f"triton_debug/{func_name}.py", "w") as f: + with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f: f.write(src_code) return func_name, PyCodeCache().load(src_code) diff --git a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py index 1fe61750e651e..f16abc71251ed 100644 --- a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py +++ b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py @@ -67,7 +67,7 @@ def get_shape(cls, onnx_key: int, shapes: List[List[int]]) -> List[List[Union[in def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> int: # pylint: disable=unused-argument - return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}") % (10**8) + return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}") def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]: