Skip to content

Commit

Permalink
bugfix and optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
centwang committed Nov 9, 2023
1 parent 7e0f97d commit 91b0c78
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 18 deletions.
4 changes: 2 additions & 2 deletions orttraining/orttraining/python/training/ort_triton/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def __init__(self, x_numel: sympy.Expr, r_numel: sympy.Expr, contiguous: bool):
if x_numel.is_number
else int(
x_numel.subs(
{symbol: sympy.Integer(extract_shape_from_symbol(symbol)) for symbol in x_numel.free_symbols}
{symbol: sympy.Integer(extract_shape_from_symbol(symbol.name)) for symbol in x_numel.free_symbols}
)
)
)
Expand All @@ -220,7 +220,7 @@ def __init__(self, x_numel: sympy.Expr, r_numel: sympy.Expr, contiguous: bool):
if r_numel.is_number
else int(
r_numel.subs(
{symbol: sympy.Integer(extract_shape_from_symbol(symbol)) for symbol in r_numel.free_symbols}
{symbol: sympy.Integer(extract_shape_from_symbol(symbol.name)) for symbol in r_numel.free_symbols}
)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def __init__(self, model: ModelProto, input_shapes: List[List[sympy.Expr]]):
self._elementwise_graph_outputs: Set[str] = set()
for node in self._graph.node:
if is_elementwise_node(node):
self._elementwise_graph_outputs.update(node.output)
self._elementwise_graph_outputs.update(
[output for output in node.output if output in self._graph.output]
)

# Topological sort the nodes in the graph.
self._sorted_nodes: List[NodeProto] = topological_sort(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import sympy


def extract_shape_from_symbol(symbol: sympy.Symbol) -> int:
match = re.match(r"i(\d+)_dim(\d+)_(\d+)", symbol.name)
def extract_shape_from_symbol(symbol: str) -> int:
match = re.match(r"i(\d+)_dim(\d+)_(\d+)", symbol)
assert match
return int(match.group(3))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
import os
import sys
from types import ModuleType
from typing import List, Tuple
from typing import List, Tuple, Union

import onnx
import sympy
from torch._C import _from_dlpack
from torch.utils.dlpack import to_dlpack

Expand Down Expand Up @@ -43,37 +42,37 @@ class _ShapeCache:
clear = staticmethod(cache.clear)

@classmethod
def get_shape(cls, onnx_key: int, shapes: List[List[sympy.Expr]]) -> List[List[sympy.Expr]]:
def get_shape(cls, onnx_key: int, shapes: List[List[int]]) -> List[List[Union[int, str]]]:
if onnx_key not in cls.cache:
cls.cache[onnx_key] = shapes
else:
changed = False
for i, shape in enumerate(shapes):
for j, dim in enumerate(shape):
if dim != cls.cache[onnx_key][i][j] and cls.cache[onnx_key][i][j].is_number:
max_dim = max(int(dim), int(cls.cache[onnx_key][i][j]))
shape[j] = sympy.Symbol(f"i{i}_dim{j}_{next_power_of_2(max_dim)}")
if dim != cls.cache[onnx_key][i][j] and isinstance(cls.cache[onnx_key][i][j], int):
max_dim = max(dim, cls.cache[onnx_key][i][j])
shape[j] = f"i{i}_dim{j}_{next_power_of_2(max_dim)}"
changed = True
elif cls.cache[onnx_key][i][j].is_symbol:
elif isinstance(cls.cache[onnx_key][i][j], str):
pre = extract_shape_from_symbol(cls.cache[onnx_key][i][j])
if pre >= int(dim):
if pre >= dim:
shape[j] = cls.cache[onnx_key][i][j]
else:
shape[j] = sympy.Symbol(f"i{i}_dim{j}_{next_power_of_2(int(dim))}")
shape[j] = f"i{i}_dim{j}_{next_power_of_2(dim)}"
changed = True
if changed:
cls.cache[onnx_key] = shapes
return cls.cache[onnx_key]


def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[sympy.Expr]]) -> int:
def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> int:
# pylint: disable=unused-argument
return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}") % (10**8)


def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[sympy.Expr]]) -> Tuple[str, ModuleType]:
def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]:
model = onnx.load_model_from_string(onnx_str)
sorted_graph = SortedGraph(model, shapes)
sorted_graph = SortedGraph(model, [parse_shape(shape) for shape in shapes])
if _DEBUG_MODE:
os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True)
sorted_graph.save_onnx(f"triton_debug/{onnx_key}")
Expand Down Expand Up @@ -125,7 +124,7 @@ def call_triton_by_onnx(onnx_key: int, onnx_str: bytes, *tensors):

assert all(tensor is not None for tensor in tensors)
torch_tensors = [_from_dlpack(tensor) for tensor in tensors]
concrete_shapes = [parse_shape(list(tensor.size())) for tensor in torch_tensors]
concrete_shapes = [list(tensor.size()) for tensor in torch_tensors]
shapes = _ShapeCache.get_shape(onnx_key, concrete_shapes)
func_name, mod = ModuleCache.load(_gen_key, _gen_module, onnx_key, onnx_str, shapes)
func = getattr(mod, func_name)
Expand Down

0 comments on commit 91b0c78

Please sign in to comment.