[ORTModule] Remove Mod from Hash to Avoid Conflict for Triton Code-gen #19256

Merged · 1 commit · Jan 25, 2024
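For context on the title: reducing the hash modulo 10**8 leaves only 10**8 possible key values, so by the birthday bound two distinct kernel configurations can collide once a few thousand kernels are generated. A small sketch (not part of the PR) makes the numbers concrete:

```python
# Sketch (not from the PR): expected number of colliding key pairs among
# n keys hashed into b buckets is roughly n * (n - 1) / (2 * b) by the
# birthday bound.
def expected_collisions(n: int, buckets: int = 10**8) -> float:
    return n * (n - 1) / (2 * buckets)

print(expected_collisions(10_000))   # ~0.5  -> collisions plausible
print(expected_collisions(100_000))  # ~50   -> collisions near-certain
```

Dropping the `% (10**8)` truncation keeps the full-width Python hash for the cache key, and the generated function names are made unique independently of the hash.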
orttraining/orttraining/python/training/ort_triton/kernel/_mm.py (10 additions, 10 deletions)

@@ -11,7 +11,7 @@
 import torch

 from .._cache import ModuleCache, PyCodeCache
-from .._utils import next_power_of_2
+from .._utils import gen_unique_name, next_power_of_2

 _DEBUG_MODE = "ORTMODULE_TRITON_DEBUG" in os.environ and int(os.getenv("ORTMODULE_TRITON_DEBUG")) == 1
@@ -305,18 +305,18 @@ def _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name):


 def _gen_mm_key(dtype: torch.dtype, m: int, n: int, k: int, trans_a: bool, trans_b: bool, alpha: float) -> int:
-    return hash(f"mm|{dtype}|{m}|{n}|{k}|{trans_a}|{trans_b}|{alpha}") % (10**8)
+    return hash(f"mm|{dtype}|{m}|{n}|{k}|{trans_a}|{trans_b}|{alpha}")


 def _gen_mm_module(
     dtype: torch.dtype, m: int, n: int, k: int, trans_a: bool, trans_b: bool, alpha: float
 ) -> Tuple[str, ModuleType]:
-    func_name = f"mm_{_gen_mm_key(dtype, m, n, k, trans_a, trans_b, alpha)}"
+    func_name = gen_unique_name("mm")
     kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name)
     src_code = _MM_TEMPLATE.format(**kwargs)
     if _DEBUG_MODE:
         os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True)
-        with open(f"triton_debug/{func_name}.py", "w") as f:
+        with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f:
             f.write(src_code)
     return func_name, PyCodeCache().load(src_code)
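`gen_unique_name` is imported from `.._utils`, but its body is outside this diff. A minimal counter-based sketch of what such a helper could look like (an assumption, not the actual `_utils.py` implementation):

```python
import itertools
from collections import defaultdict

# Hypothetical stand-in for _utils.gen_unique_name; the real helper may
# differ. A per-prefix counter guarantees names are unique within the
# process, independent of any hash of the kernel parameters.
_counters: defaultdict = defaultdict(itertools.count)

def gen_unique_name(prefix: str) -> str:
    return f"{prefix}_{next(_counters[prefix])}"

assert gen_unique_name("mm") == "mm_0"
assert gen_unique_name("mm") == "mm_1"  # can never collide with mm_0
```

This is why the function name no longer needs to embed the (possibly colliding) hash: uniqueness of the name and identity of the cached module become separate concerns.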

@@ -333,7 +333,7 @@ def _gen_gemm_key(
     alpha: float,
     beta: float,
 ) -> int:
-    return hash(f"gemm|{dtype}|{m}|{n}|{k}|{stride_cm}|{stride_cn}|{trans_a}|{trans_b}|{alpha}|{beta}") % (10**8)
+    return hash(f"gemm|{dtype}|{m}|{n}|{k}|{stride_cm}|{stride_cn}|{trans_a}|{trans_b}|{alpha}|{beta}")


 def _gen_gemm_module(
@@ -348,29 +348,29 @@ def _gen_gemm_module(
     alpha: float,
     beta: float,
 ) -> Tuple[str, ModuleType]:
-    func_name = f"gemm_{_gen_gemm_key(dtype, m, n, k, stride_cm, stride_cn, trans_a, trans_b, alpha, beta)}"
+    func_name = gen_unique_name("gemm")
     kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name)
     kwargs["stride_cm"] = stride_cm
     kwargs["stride_cn"] = stride_cn
     kwargs["beta"] = beta
     src_code = _GEMM_TEMPLATE.format(**kwargs)
     if _DEBUG_MODE:
         os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True)
-        with open(f"triton_debug/{func_name}.py", "w") as f:
+        with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f:
             f.write(src_code)
     return func_name, PyCodeCache().load(src_code)


 def _gen_bmm_key(
     dtype: torch.dtype, m: int, n: int, k: int, batch_a: int, batch_b: int, trans_a: bool, trans_b: bool, alpha: float
 ) -> int:
-    return hash(f"bmm|{dtype}|{m}|{n}|{k}|{batch_a}|{batch_b}|{trans_a}|{trans_b}|{alpha}") % (10**8)
+    return hash(f"bmm|{dtype}|{m}|{n}|{k}|{batch_a}|{batch_b}|{trans_a}|{trans_b}|{alpha}")


 def _gen_bmm_module(
     dtype: torch.dtype, m: int, n: int, k: int, batch_a: int, batch_b: int, trans_a: bool, trans_b: bool, alpha: float
 ) -> Tuple[str, ModuleType]:
-    func_name = f"bmm_{_gen_bmm_key(dtype, m, n, k, batch_a, batch_b, trans_a, trans_b, alpha)}"
+    func_name = gen_unique_name("bmm")
     kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name)
     batch = batch_a if batch_a >= batch_b else batch_b
     kwargs["stride_aq"] = m * k if batch_a == batch else 0
@@ -379,7 +379,7 @@ def _gen_bmm_module(
     src_code = _BMM_TEMPLATE.format(**kwargs)
     if _DEBUG_MODE:
         os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True)
-        with open(f"triton_debug/{func_name}.py", "w") as f:
+        with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f:
             f.write(src_code)
     return func_name, PyCodeCache().load(src_code)
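After this change, the full-width `hash()` value serves only as the in-memory cache key, while the emitted function name is generated independently. A self-contained sketch of a key-to-module cache in that spirit (an assumed shape; the real `ModuleCache` lives in `ort_triton/_cache.py` and may differ):

```python
from types import ModuleType
from typing import Callable, Dict, Tuple

# Sketch of a ModuleCache-like class (assumed, not the real _cache.py).
# With a truncated key, two different kernel configurations could alias
# the same entry, and the second caller would silently receive the wrong
# compiled module; a full-width key makes that vanishingly unlikely.
class SketchModuleCache:
    def __init__(self) -> None:
        self._cache: Dict[int, Tuple[str, ModuleType]] = {}

    def load(
        self,
        key_func: Callable[..., int],
        mod_func: Callable[..., Tuple[str, ModuleType]],
        *args,
    ) -> Tuple[str, ModuleType]:
        key = key_func(*args)  # full Python hash after this PR
        if key not in self._cache:
            self._cache[key] = mod_func(*args)
        return self._cache[key]
```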

Second changed file:

@@ -67,7 +67,7 @@ def get_shape(cls, onnx_key: int, shapes: List[List[int]]) -> List[List[Union[int, str]]]:


 def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> int:
     # pylint: disable=unused-argument
-    return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}") % (10**8)
+    return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}")


 def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]:
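One property worth noting (a general Python fact, not stated in the PR): `hash()` of a `str` is salted per process via `PYTHONHASHSEED`, so these full-width keys differ across runs but are stable within a single process, which is all an in-memory module cache needs:

```python
# Python salts str hashes per process (PYTHONHASHSEED), so this key
# changes between runs; within one process it is deterministic.
shapes = [[2, 3], ["dim", 4]]
key = hash(f"{123}|{str(shapes).replace(' ', '')}")
print(key)  # full-width hash; no % (10**8) truncation after this PR
```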