From 91b0c7824636b5df73665ed6efcc07fda5fbd704 Mon Sep 17 00:00:00 2001
From: Vincent Wang
Date: Thu, 9 Nov 2023 15:47:44 +0800
Subject: [PATCH] bugfix and optimization

---
 .../python/training/ort_triton/_common.py         |  4 +--
 .../training/ort_triton/_sorted_graph.py          |  4 ++-
 .../training/ort_triton/_sympy_utils.py           |  4 +--
 .../training/ort_triton/triton_op_executor.py     | 25 +++++++++----------
 4 files changed, 19 insertions(+), 18 deletions(-)

diff --git a/orttraining/orttraining/python/training/ort_triton/_common.py b/orttraining/orttraining/python/training/ort_triton/_common.py
index 88f0502f4fb0a..b7e55bc733ede 100644
--- a/orttraining/orttraining/python/training/ort_triton/_common.py
+++ b/orttraining/orttraining/python/training/ort_triton/_common.py
@@ -211,7 +211,7 @@ def __init__(self, x_numel: sympy.Expr, r_numel: sympy.Expr, contiguous: bool):
             if x_numel.is_number
             else int(
                 x_numel.subs(
-                    {symbol: sympy.Integer(extract_shape_from_symbol(symbol)) for symbol in x_numel.free_symbols}
+                    {symbol: sympy.Integer(extract_shape_from_symbol(symbol.name)) for symbol in x_numel.free_symbols}
                 )
             )
         )
@@ -220,7 +220,7 @@ def __init__(self, x_numel: sympy.Expr, r_numel: sympy.Expr, contiguous: bool):
             if r_numel.is_number
             else int(
                 r_numel.subs(
-                    {symbol: sympy.Integer(extract_shape_from_symbol(symbol)) for symbol in r_numel.free_symbols}
+                    {symbol: sympy.Integer(extract_shape_from_symbol(symbol.name)) for symbol in r_numel.free_symbols}
                 )
             )
         )
diff --git a/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py b/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py
index 2645cc3b30337..23b591b48590c 100644
--- a/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py
+++ b/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py
@@ -41,7 +41,9 @@ def __init__(self, model: ModelProto, input_shapes: List[List[sympy.Expr]]):
         self._elementwise_graph_outputs: Set[str] = set()
         for node in self._graph.node:
             if is_elementwise_node(node):
-                self._elementwise_graph_outputs.update(node.output)
+                self._elementwise_graph_outputs.update(
+                    [output for output in node.output if output in self._graph.output]
+                )

         # Topological sort the nodes in the graph.
         self._sorted_nodes: List[NodeProto] = topological_sort(
diff --git a/orttraining/orttraining/python/training/ort_triton/_sympy_utils.py b/orttraining/orttraining/python/training/ort_triton/_sympy_utils.py
index 3e43d1f32a656..a4a384c021fe8 100644
--- a/orttraining/orttraining/python/training/ort_triton/_sympy_utils.py
+++ b/orttraining/orttraining/python/training/ort_triton/_sympy_utils.py
@@ -9,8 +9,8 @@
 import sympy


-def extract_shape_from_symbol(symbol: sympy.Symbol) -> int:
-    match = re.match(r"i(\d+)_dim(\d+)_(\d+)", symbol.name)
+def extract_shape_from_symbol(symbol: str) -> int:
+    match = re.match(r"i(\d+)_dim(\d+)_(\d+)", symbol)
     assert match
     return int(match.group(3))

diff --git a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py
index 18ab5d846d0c9..8a642a1f7f26c 100644
--- a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py
+++ b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py
@@ -8,10 +8,9 @@
 import os
 import sys
 from types import ModuleType
-from typing import List, Tuple
+from typing import List, Tuple, Union

 import onnx
-import sympy
 from torch._C import _from_dlpack
 from torch.utils.dlpack import to_dlpack

@@ -43,37 +42,37 @@ class _ShapeCache:
     clear = staticmethod(cache.clear)

     @classmethod
-    def get_shape(cls, onnx_key: int, shapes: List[List[sympy.Expr]]) -> List[List[sympy.Expr]]:
+    def get_shape(cls, onnx_key: int, shapes: List[List[int]]) -> List[List[Union[int, str]]]:
         if onnx_key not in cls.cache:
             cls.cache[onnx_key] = shapes
         else:
             changed = False
             for i, shape in enumerate(shapes):
                 for j, dim in enumerate(shape):
-                    if dim != cls.cache[onnx_key][i][j] and cls.cache[onnx_key][i][j].is_number:
-                        max_dim = max(int(dim), int(cls.cache[onnx_key][i][j]))
-                        shape[j] = sympy.Symbol(f"i{i}_dim{j}_{next_power_of_2(max_dim)}")
+                    if dim != cls.cache[onnx_key][i][j] and isinstance(cls.cache[onnx_key][i][j], int):
+                        max_dim = max(dim, cls.cache[onnx_key][i][j])
+                        shape[j] = f"i{i}_dim{j}_{next_power_of_2(max_dim)}"
                         changed = True
-                    elif cls.cache[onnx_key][i][j].is_symbol:
+                    elif isinstance(cls.cache[onnx_key][i][j], str):
                         pre = extract_shape_from_symbol(cls.cache[onnx_key][i][j])
-                        if pre >= int(dim):
+                        if pre >= dim:
                             shape[j] = cls.cache[onnx_key][i][j]
                         else:
-                            shape[j] = sympy.Symbol(f"i{i}_dim{j}_{next_power_of_2(int(dim))}")
+                            shape[j] = f"i{i}_dim{j}_{next_power_of_2(dim)}"
                             changed = True
             if changed:
                 cls.cache[onnx_key] = shapes
         return cls.cache[onnx_key]


-def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[sympy.Expr]]) -> int:
+def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> int:
     # pylint: disable=unused-argument
     return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}") % (10**8)


-def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[sympy.Expr]]) -> Tuple[str, ModuleType]:
+def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]:
     model = onnx.load_model_from_string(onnx_str)
-    sorted_graph = SortedGraph(model, shapes)
+    sorted_graph = SortedGraph(model, [parse_shape(shape) for shape in shapes])
     if _DEBUG_MODE:
         os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True)
         sorted_graph.save_onnx(f"triton_debug/{onnx_key}")
@@ -125,7 +124,7 @@ def call_triton_by_onnx(onnx_key: int, onnx_str: bytes, *tensors):
     assert all(tensor is not None for tensor in tensors)
     torch_tensors = [_from_dlpack(tensor) for tensor in tensors]
-    concrete_shapes = [parse_shape(list(tensor.size())) for tensor in torch_tensors]
+    concrete_shapes = [list(tensor.size()) for tensor in torch_tensors]
     shapes = _ShapeCache.get_shape(onnx_key, concrete_shapes)
     func_name, mod = ModuleCache.load(_gen_key, _gen_module, onnx_key, onnx_str, shapes)
     func = getattr(mod, func_name)
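
Note: below is a minimal standalone sketch of the dynamic-dim bucketing that this patch switches the shape cache to (plain int/str shapes instead of sympy expressions). The helpers next_power_of_2 and get_shape are simplified stand-ins redefined here so the snippet runs without onnxruntime; extract_shape_from_symbol mirrors the patched string-based signature, and the `changed` bookkeeping from the real _ShapeCache is omitted.

    import re
    from typing import Dict, List, Union


    def next_power_of_2(n: int) -> int:
        # Smallest power of two >= n (stand-in for the helper used in the patch).
        return 1 if n <= 1 else 1 << (n - 1).bit_length()


    def extract_shape_from_symbol(symbol: str) -> int:
        # Mirrors the patched _sympy_utils.extract_shape_from_symbol: the symbol
        # is now a plain string such as "i0_dim1_256" and the trailing number is
        # the power-of-2 upper bound recorded for that dimension.
        match = re.match(r"i(\d+)_dim(\d+)_(\d+)", symbol)
        assert match
        return int(match.group(3))


    # Toy stand-in for _ShapeCache.get_shape: concrete dims stay ints until they
    # are seen to vary across calls, then they are replaced by a symbolic name
    # encoding a power-of-2 bucket.
    _cache: Dict[int, List[List[Union[int, str]]]] = {}


    def get_shape(key: int, shapes: List[List[int]]) -> List[List[Union[int, str]]]:
        if key not in _cache:
            _cache[key] = shapes
            return _cache[key]
        cached = _cache[key]
        for i, shape in enumerate(shapes):
            for j, dim in enumerate(shape):
                if isinstance(cached[i][j], int) and dim != cached[i][j]:
                    # Dim changed between calls: bucket it to the next power of 2.
                    shape[j] = f"i{i}_dim{j}_{next_power_of_2(max(dim, cached[i][j]))}"
                elif isinstance(cached[i][j], str):
                    # Already symbolic: keep the bucket unless the new dim exceeds it.
                    bound = extract_shape_from_symbol(cached[i][j])
                    shape[j] = cached[i][j] if bound >= dim else f"i{i}_dim{j}_{next_power_of_2(dim)}"
        _cache[key] = shapes
        return _cache[key]


    print(get_shape(0, [[8, 128]]))  # first call: concrete ints are cached as-is
    print(get_shape(0, [[8, 200]]))  # dim 1 changed: bucketed to "i0_dim1_256"
    print(get_shape(0, [[8, 300]]))  # exceeds the bucket: grows to "i0_dim1_512"

As far as the diff shows, the string encoding is what lets _gen_key hash the bucketed shape list directly, so a new concrete dimension triggers module regeneration only when it crosses its current power-of-2 bound.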