diff --git a/orttraining/orttraining/python/training/ort_triton/_codegen.py b/orttraining/orttraining/python/training/ort_triton/_codegen.py
index 0bf402b750115..cac9b6fc4a2b6 100644
--- a/orttraining/orttraining/python/training/ort_triton/_codegen.py
+++ b/orttraining/orttraining/python/training/ort_triton/_codegen.py
@@ -45,12 +45,9 @@ class TritonCodegen(NodeVisitor):
     Specialized codegen for Triton backend.
     """

-    def __init__(self):
-        super().__init__()
-
     def codegen(self, node: IRNode, context: CodegenContext, code_buffer: CodeBuffer, indent: int):
         func = getattr(self, node.__class__.__name__)
-        assert func is not None, "unimplemented node: %s" % node.__class__.__name__
+        assert func is not None, f"unimplemented node: {node.__class__.__name__}"
         func(node, context, code_buffer, indent)

     def _get_elementwise_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) -> Tuple[str, str]:
@@ -125,18 +122,29 @@ def IONode(self, node: IONode, context: CodegenContext, code_buffer: CodeBuffer,
     def _gen_kernel_signature(self, node: KernelNode, context: CodegenContext, code_buffer: CodeBuffer, indent: int):
         is_reduction = node.offset_calc.is_reduction
         space_indent = " " * indent
-        autotune_configs_str = ""
-        for config in node.offset_calc.autotune_configs.configs:
-            if is_reduction:
-                autotune_configs_str += (
-                    f'{space_indent}        triton.Config({{"XBLOCK": {config[0]}, "RBLOCK": {config[1]}}}, '
-                    f"num_warps={config[2]}),\n"
-                )
-            else:
-                autotune_configs_str += (
-                    f'{space_indent}        triton.Config({{"XBLOCK": {config[0]}}}, num_warps={config[2]}),\n'
-                )
-        keys_str = '"xnumel", "rnumel"' if is_reduction else '"xnumel"'
+
+        if len(node.offset_calc.autotune_configs.configs) > 1:
+            autotune_configs_str = ""
+            for config in node.offset_calc.autotune_configs.configs:
+                if is_reduction:
+                    autotune_configs_str += (
+                        f'{space_indent}        triton.Config({{"XBLOCK": {config[0]}, "RBLOCK": {config[1]}}}, '
+                        f"num_warps={config[2]}),\n"
+                    )
+                else:
+                    autotune_configs_str += (
+                        f'{space_indent}        triton.Config({{"XBLOCK": {config[0]}}}, num_warps={config[2]}),\n'
+                    )
+            keys_str = '"xnumel", "rnumel"' if is_reduction else '"xnumel"'
+            code_buffer += (
+                f"{space_indent}@triton.autotune(\n"
+                f"{space_indent}    configs=[\n"
+                f"{autotune_configs_str}"
+                f"{space_indent}    ],\n"
+                f"{space_indent}    key=[{keys_str}],\n"
+                f"{space_indent})\n"
+            )
+
         input_args = [context.get_variable_name(input.name) for input in node.inputs]
         input_args_str = ", ".join(input_args)
         if input_args_str:
@@ -158,12 +166,6 @@ def _gen_kernel_signature(self, node: KernelNode, context: CodegenContext, code_
         )

         code_buffer += (
-            f"{space_indent}@triton.autotune(\n"
-            f"{space_indent}    configs=[\n"
-            f"{autotune_configs_str}"
-            f"{space_indent}    ],\n"
-            f"{space_indent}    key=[{keys_str}],\n"
-            f"{space_indent})\n"
             f"{space_indent}@triton.jit\n"
             f"{space_indent}def {node.name}({input_args_str}{output_args_str}{other_input_args}{blocks_str}):\n"
         )
@@ -175,8 +177,10 @@ def ElementwiseKernelNode(  # noqa: N802
         offset_calc = node.offset_calc
         indent += 4
         space_indent = " " * indent
+        x_numel_str = str(offset_calc.x_numel)
+        if x_numel_str.isnumeric():
+            code_buffer += f"{space_indent}xnumel = {x_numel_str}\n"
         code_buffer += (
-            f"{space_indent}xnumel = {offset_calc.x_numel}\n"
            f"{space_indent}xoffset = tl.program_id(0) * XBLOCK\n"
            f"{space_indent}xindex = xoffset + tl.arange(0, XBLOCK)\n"
         )
@@ -207,9 +211,13 @@ def ReduceKernelNode(  # noqa: N802
         offset_calc = node.offset_calc
         indent += 4
         space_indent = " " * indent
+        x_numel_str = str(offset_calc.x_numel)
+        if x_numel_str.isnumeric():
+            code_buffer += f"{space_indent}xnumel = {x_numel_str}\n"
+        r_numel_str = str(offset_calc.r_numel)
+        if r_numel_str.isnumeric():
+            code_buffer += f"{space_indent}rnumel = {r_numel_str}\n"
         code_buffer += (
-            f"{space_indent}xnumel = {offset_calc.x_numel}\n"
-            f"{space_indent}rnumel = {offset_calc.r_numel}\n"
             f"{space_indent}xoffset = tl.program_id(0) * XBLOCK\n"
             f"{space_indent}xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n"
             f"{space_indent}rbase = tl.arange(0, RBLOCK)[None, :]\n"
@@ -444,6 +452,13 @@ def ModuleNode(self, node: ModuleNode, context: CodegenContext, code_buffer: Cod
         indent += 4
         space_indent = " " * indent

+        seen_symbolic_shape = set()
+        for input in node.inputs:
+            for idx, dim in enumerate(input.shape):
+                if dim.is_symbol and dim not in seen_symbolic_shape:
+                    code_buffer += f"{space_indent}{dim} = {context.get_variable_name(input.name)}.size()[{idx}]\n"
+                    seen_symbolic_shape.add(dim)
+
         if node.has_dropout:
             code_buffer += (
                 f'{space_indent}seed_cuda = torch.randint(2**31, size=(), dtype=torch.int64, device="cuda")\n\n'
@@ -470,18 +485,31 @@ def ModuleNode(self, node: ModuleNode, context: CodegenContext, code_buffer: Cod
             if kernel_node.has_dropout:
                 kernel_args_str += ", seed_cuda"

+            # Support symbolic shape if any.
+            symbolic_shape_args_str = ", ".join(kernel_node.symbolic_shape_variables)
+            if symbolic_shape_args_str:
+                kernel_args_str += f", {symbolic_shape_args_str}"
+
+            block_str = ""
+            if len(kernel_node.offset_calc.autotune_configs.configs) == 1:
+                config = kernel_node.offset_calc.autotune_configs.configs[0]
+                if kernel_node.offset_calc.is_reduction:
+                    block_str = f", XBLOCK={config[0]}, RBLOCK={config[1]}, num_warps={config[2]}"
+                else:
+                    block_str = f", XBLOCK={config[0]}, num_warps={config[2]}"
+
             if isinstance(kernel_node, ReduceKernelNode):
                 code_buffer += (
                     f"{space_indent}x_numel = {kernel_node.offset_calc.x_numel}\n"
                     f"{space_indent}r_numel = {kernel_node.offset_calc.r_numel}\n"
                     f'{space_indent}grid = lambda meta: (triton.cdiv(x_numel, meta["XBLOCK"]),)\n'
-                    f"{space_indent}{kernel_node.name}[grid]({kernel_args_str}, x_numel, r_numel)\n"
+                    f"{space_indent}{kernel_node.name}[grid]({kernel_args_str}, x_numel, r_numel{block_str})\n"
                 )
             else:
                 code_buffer += (
                     f"{space_indent}n_elements = {kernel_node.offset_calc.x_numel}\n"
                     f'{space_indent}grid = lambda meta: (triton.cdiv(n_elements, meta["XBLOCK"]),)\n'
-                    f"{space_indent}{kernel_node.name}[grid]({kernel_args_str}, n_elements)\n"
+                    f"{space_indent}{kernel_node.name}[grid]({kernel_args_str}, n_elements{block_str})\n"
                 )

             for name in node.cross_kernel_args_to_delete[idx]:
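Note on the codegen change above: `xnumel = <constant>` is now emitted only for concrete shapes; for a symbolic shape the numel travels as a regular kernel argument (so masked loads/stores become mandatory), and when only a single autotune config survives, the block size and num_warps are passed at launch instead of wrapping the kernel in @triton.autotune. A rough, hand-written sketch of the shape of the generated code (illustrative only; names like add_kernel_0 and i0_dim0_256 are made up, not actual generator output):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel_0(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
        # xnumel is a runtime argument because the shape is symbolic; a concrete
        # shape would instead emit `xnumel = 4096` as the first statement.
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)
        xmask = xindex < xnumel  # symbolic shapes always require the x mask
        x = tl.load(in_ptr0 + xindex, mask=xmask)
        tl.store(out_ptr0 + xindex, x + 1.0, mask=xmask)

    def module_entry(t_in):
        t_out = torch.empty_like(t_in)
        i0_dim0_256 = t_in.size()[0]  # symbolic dims are read off the inputs, as in ModuleNode above
        n_elements = i0_dim0_256
        grid = lambda meta: (triton.cdiv(n_elements, meta["XBLOCK"]),)
        # Single surviving config: block size and num_warps are passed at launch
        # instead of decorating the kernel with @triton.autotune.
        add_kernel_0[grid](t_in, t_out, n_elements, XBLOCK=1024, num_warps=4)
        return t_out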
diff --git a/orttraining/orttraining/python/training/ort_triton/_common.py b/orttraining/orttraining/python/training/ort_triton/_common.py
index 82ac82cfa2919..b7e55bc733ede 100644
--- a/orttraining/orttraining/python/training/ort_triton/_common.py
+++ b/orttraining/orttraining/python/training/ort_triton/_common.py
@@ -9,9 +9,11 @@
 import sympy
 from onnx import GraphProto, NodeProto, TensorProto

-from ._sympy_utils import parse_shape
+from ._sympy_utils import extract_shape_from_symbol
 from ._utils import get_attribute, get_reduce_info, next_power_of_2

+_SPECIAL_FLOATS: List[str] = ["inf", "-inf"]
+

 class CodegenContext:
     """
@@ -28,7 +30,8 @@ def get_variable_name(self, name: str) -> str:
     # For some operators such as data load/store, we need an internal variable name inside the kernel function.
     def get_internal_variable_name(self, name: str) -> str:
         var_name = self._var_map[name]
-        return self._var_map[var_name] if var_name in self._var_map else var_name
+        var_name = self._var_map[var_name] if var_name in self._var_map else var_name
+        return f'float("{var_name}")' if var_name in _SPECIAL_FLOATS else var_name


 class CodeBuffer:
@@ -49,14 +52,38 @@ def codegen(self, node: Any, context: CodegenContext, code_buffer: CodeBuffer, i
         pass


+class SymbolicDSU:
+    """
+    A disjoint set union (DSU) that merges symbolic dims so the generated code uses fewer variables.
+    During shape inference for elementwise Ops, if two dims differ and neither is 1 (no broadcasting),
+    they must hold the same runtime value, so we merge their symbols.
+    """
+
+    def __init__(self):
+        self._dsu: Dict[sympy.Expr, sympy.Expr] = {}
+
+    def find(self, symbolic: sympy.Expr) -> sympy.Expr:
+        if symbolic not in self._dsu:
+            self._dsu[symbolic] = symbolic
+            return symbolic
+        if symbolic == self._dsu[symbolic]:
+            return symbolic
+        self._dsu[symbolic] = self.find(self._dsu[symbolic])
+        return self._dsu[symbolic]
+
+    def union(self, symbolic: sympy.Expr, other_symbolic: sympy.Expr):
+        root = self.find(symbolic)
+        other_root = self.find(other_symbolic)
+        self._dsu[other_root] = root
+
+
 class TensorInfo:
     """
     Represent a input/output tensor of a node.
     """

-    def __init__(self, dtype: TensorProto.DataType, shape: List[Any]):
+    def __init__(self, dtype: TensorProto.DataType, shape: List[sympy.Expr]):
         self._dtype: TensorProto.DataType = dtype
-        self._shape: List[sympy.Expr] = parse_shape(shape)
+        self._shape: List[sympy.Expr] = shape

     @property
     def dtype(self) -> TensorProto.DataType:
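For illustration, here is how SymbolicDSU behaves in isolation (a minimal sketch assuming the class above; the symbol names are made-up examples of the `i{input}_dim{dim}_{bound}` convention introduced later in this PR):

    import sympy

    dim_a = sympy.Symbol("i0_dim1_256")
    dim_b = sympy.Symbol("i1_dim1_256")
    dsu = SymbolicDSU()
    # Broadcasting [16, i0_dim1_256] with [16, i1_dim1_256]: neither dim is 1,
    # so both symbols must hold the same runtime value -> merge them.
    dsu.union(dim_a, dim_b)
    assert dsu.find(dim_b) == dim_a  # dim_a is the representative of the merged set
    # TensorInfo.update_shape() later rewrites every dim through find(), so the
    # generated kernel sees a single symbolic variable for the whole class.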
@@ -66,27 +93,42 @@ def dtype(self) -> TensorProto.DataType:
     def shape(self) -> List[sympy.Expr]:
         return self._shape

+    def update_shape(self, symbolics: SymbolicDSU):
+        self._shape = [symbolics.find(dim) if dim.is_symbol else dim for dim in self._shape]
+

-def _infer_elementwise_shape(input_infos: List[TensorInfo]) -> List[sympy.Expr]:
+def _infer_elementwise_shape(input_infos: List[TensorInfo], symbolics: SymbolicDSU) -> List[sympy.Expr]:
     max_len = max([len(input_info.shape) for input_info in input_infos])
     output_shape: List[sympy.Expr] = [sympy.Integer(1)] * max_len
     for input_info in input_infos:
         offset = max_len - len(input_info.shape)
-        for i in range(len(input_info.shape)):
-            if not input_info.shape[i].is_number or input_info.shape[i] != 1:
-                output_shape[i + offset] = input_info.shape[i]
+        for idx, dim in enumerate(input_info.shape):
+            if not dim.is_number or dim != 1:
+                if not output_shape[idx + offset].is_number or output_shape[idx + offset] != 1:
+                    symbolics.union(output_shape[idx + offset], dim)
+                else:
+                    output_shape[idx + offset] = dim
     return output_shape


-def _infer_elementwise(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
-    return [TensorInfo(input_infos[0].dtype, _infer_elementwise_shape(input_infos))]
+def _infer_elementwise(
+    node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
+) -> List[TensorInfo]:
+    # pylint: disable=unused-argument
+    return [TensorInfo(input_infos[0].dtype, _infer_elementwise_shape(input_infos, symbolics))]


-def _infer_where(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
-    return [TensorInfo(input_infos[1].dtype, _infer_elementwise_shape(input_infos))]
+def _infer_where(
+    node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
+) -> List[TensorInfo]:
+    # pylint: disable=unused-argument
+    return [TensorInfo(input_infos[1].dtype, _infer_elementwise_shape(input_infos, symbolics))]


-def _infer_reduction(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
+def _infer_reduction(
+    node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
+) -> List[TensorInfo]:
+    # pylint: disable=unused-argument
     input_rank = len(input_infos[0].shape)
     keep_dims, axes = get_reduce_info(node, graph, input_rank)
     axes = [axis + input_rank if axis < 0 else axis for axis in axes]
@@ -98,17 +140,26 @@ def _infer_reduction(node: NodeProto, input_infos: List[TensorInfo], graph: Grap
     return [TensorInfo(input_infos[0].dtype, shape)]


-def _infer_unary(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
+def _infer_unary(
+    node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
+) -> List[TensorInfo]:
+    # pylint: disable=unused-argument
     return [input_infos[0]]


-def _infer_cast(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
+def _infer_cast(
+    node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
+) -> List[TensorInfo]:
+    # pylint: disable=unused-argument
     dtype = get_attribute(node, "to", TensorProto.UNDEFINED)
     assert dtype != TensorProto.UNDEFINED
     return [TensorInfo(dtype, input_infos[0].shape)]


-def _infer_dropout(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
+def _infer_dropout(
+    node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
+) -> List[TensorInfo]:
+    # pylint: disable=unused-argument
     return [input_infos[0], TensorInfo(TensorProto.BOOL, input_infos[0].shape)]


@@ -138,10 +189,12 @@ class TypeAndShapeInfer:
     }

     @classmethod
-    def infer(cls, node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
+    def infer(
+        cls, node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
+    ) -> List[TensorInfo]:
         if node.op_type not in cls._INFER_FUNC_MAP:
             raise NotImplementedError(f"Unsupported op type: {node.op_type}")
-        return cls._INFER_FUNC_MAP[node.op_type](node, input_infos, graph)
+        return cls._INFER_FUNC_MAP[node.op_type](node, input_infos, graph, symbolics)


 class AutotuneConfigs:
@@ -152,9 +205,30 @@ class AutotuneConfigs:
     If it's reduction kernel on last contiguous dimensions, the contiguous flag is True.
     """

-    def __init__(self, x_numel: int, r_numel: int, contiguous: bool):
-        self.configs: List[Tuple[int, int, int]] = self._gen_autotune_configs(x_numel, r_numel, contiguous)
-        self.requires_for_loop: bool = any(config[1] < r_numel for config in self.configs)
+    def __init__(self, x_numel: sympy.Expr, r_numel: sympy.Expr, contiguous: bool):
+        x_numel_int = (
+            int(x_numel)
+            if x_numel.is_number
+            else int(
+                x_numel.subs(
+                    {symbol: sympy.Integer(extract_shape_from_symbol(symbol.name)) for symbol in x_numel.free_symbols}
+                )
+            )
+        )
+        r_numel_int = (
+            int(r_numel)
+            if r_numel.is_number
+            else int(
+                r_numel.subs(
+                    {symbol: sympy.Integer(extract_shape_from_symbol(symbol.name)) for symbol in r_numel.free_symbols}
+                )
+            )
+        )
+        self.configs: List[Tuple[int, int, int]] = self._gen_autotune_configs(x_numel_int, r_numel_int, contiguous)
+        # If there is a symbolic shape, we will not tune the kernel; keep only the last config.
+        if not x_numel.is_number or not r_numel.is_number:
+            self.configs = self.configs[-1:]
+        self.requires_for_loop: bool = any(config[1] < r_numel_int for config in self.configs)

     def _num_warps(self, x: int, r: int) -> int:
         return min(max(x * r // 256, 2), 8)
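The symbolic branch above relies on the naming convention this PR introduces for symbols: each name embeds a power-of-2 upper bound that extract_shape_from_symbol (added in _sympy_utils.py below) recovers, so autotune configs can still be generated from a concrete size. A standalone sketch of that substitution (values made up; the helper is inlined here for self-containment):

    import re

    import sympy

    def extract_shape_from_symbol(symbol: str) -> int:  # mirrors the helper added in _sympy_utils.py
        match = re.match(r"i(\d+)_dim(\d+)_(\d+)", symbol)
        assert match
        return int(match.group(3))

    x_numel = sympy.Integer(16) * sympy.Symbol("i0_dim2_256")  # symbolic numel: 16 * dim2
    bound = int(x_numel.subs({s: sympy.Integer(extract_shape_from_symbol(s.name)) for s in x_numel.free_symbols}))
    assert bound == 4096  # configs are generated for this upper bound; only the last one is kept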
diff --git a/orttraining/orttraining/python/training/ort_triton/_decompose.py b/orttraining/orttraining/python/training/ort_triton/_decompose.py
index e18bb16bb80db..ffd20b09b42ea 100644
--- a/orttraining/orttraining/python/training/ort_triton/_decompose.py
+++ b/orttraining/orttraining/python/training/ort_triton/_decompose.py
@@ -58,7 +58,7 @@ def _get_dtype_and_shape(self, arg_name: str, **kwargs):
         arg_info = node_arg_infos[arg_name]
         return arg_info.dtype, arg_info.shape

-    def _decompose_elementwise_precision(self, node: NodeProto, graph: GraphProto, **kwargs):
+    def _decompose_elementwise_precision(self, node: NodeProto, **kwargs):
         x = node.input[0]
         dtype, _ = self._get_dtype_and_shape(x, **kwargs)
         if not _is_half_dtype(dtype):
@@ -79,15 +79,19 @@ def _decompose_elementwise_precision(self, node: NodeProto, graph: GraphProto, *
         return [*cast_nodes, op_node, cast_node1]

     def Exp(self, node: NodeProto, graph: GraphProto, **kwargs):  # noqa: N802
-        return self._decompose_elementwise_precision(node, graph, **kwargs)
+        # pylint: disable=unused-argument
+        return self._decompose_elementwise_precision(node, **kwargs)

     def Pow(self, node: NodeProto, graph: GraphProto, **kwargs):  # noqa: N802
-        return self._decompose_elementwise_precision(node, graph, **kwargs)
+        # pylint: disable=unused-argument
+        return self._decompose_elementwise_precision(node, **kwargs)

     def Sqrt(self, node: NodeProto, graph: GraphProto, **kwargs):  # noqa: N802
-        return self._decompose_elementwise_precision(node, graph, **kwargs)
+        # pylint: disable=unused-argument
+        return self._decompose_elementwise_precision(node, **kwargs)

     def LayerNormalization(self, node: NodeProto, graph: GraphProto, **kwargs):  # noqa: N802
+        # pylint: disable=unused-argument
         node_name = node.name
         x = node.input[0]
         w = node.input[1]
@@ -153,6 +157,7 @@ def LayerNormalization(self, node: NodeProto, graph: GraphProto, **kwargs):  # n
         ]

     def LayerNormalizationGrad(self, node: NodeProto, graph: GraphProto, **kwargs):  # noqa: N802
+        # pylint: disable=unused-argument
         node_name = node.name
         dy = node.input[0]
         x = node.input[1]
@@ -241,6 +246,7 @@ def LayerNormalizationGrad(self, node: NodeProto, graph: GraphProto, **kwargs):
         return decomposed_nodes

     def Softmax(self, node: NodeProto, graph: GraphProto, **kwargs):  # noqa: N802
+        # pylint: disable=unused-argument
         node_name = node.name
         x = node.input[0]
         y = node.output[0]
@@ -259,6 +265,7 @@ def Softmax(self, node: NodeProto, graph: GraphProto, **kwargs):  # noqa: N802
         return [max_node, sub_node, exp_node, sum_node, div_node]

     def SoftmaxGrad_13(self, node: NodeProto, graph: GraphProto, **kwargs):  # noqa: N802
+        # pylint: disable=unused-argument
         node_name = node.name
         dy = node.input[0]
         y = node.input[1]
diff --git a/orttraining/orttraining/python/training/ort_triton/_ir.py b/orttraining/orttraining/python/training/ort_triton/_ir.py
index f7d3b31eac5b6..628ea822ff55b 100644
--- a/orttraining/orttraining/python/training/ort_triton/_ir.py
+++ b/orttraining/orttraining/python/training/ort_triton/_ir.py
@@ -88,13 +88,15 @@ def __init__(self, target_shape: List[sympy.Expr], reduce_axes: List[int]):
                 self.r_strides.insert(0, self.r_strides[0] * self.r_dims[i + 1])
         self.r_compute_dims: Set[int] = set()
         self.input_strides: Dict[str, List[sympy.Expr]] = dict()
-        # Support concrete shape only for now.
-        assert self.x_numel.is_integer and self.r_numel.is_integer
         self.autotune_configs: AutotuneConfigs = AutotuneConfigs(
-            int(self.x_numel), int(self.r_numel), not self.is_reduction or self.reduce_axes[-1] == self.rank - 1
+            self.x_numel, self.r_numel, not self.is_reduction or self.reduce_axes[-1] == self.rank - 1
+        )
+        self.requires_x_mask: bool = not self.x_numel.is_number or any(
+            int(self.x_numel) % config[0] != 0 for config in self.autotune_configs.configs
+        )
+        self.requires_r_mask: bool = not self.r_numel.is_number or any(
+            int(self.r_numel) % config[1] != 0 for config in self.autotune_configs.configs
         )
-        self.requires_x_mask: bool = any(int(self.x_numel) % config[0] != 0 for config in self.autotune_configs.configs)
-        self.requires_r_mask: bool = any(int(self.r_numel) % config[1] != 0 for config in self.autotune_configs.configs)
         self.reduced_args: Set[str] = set()

     def get_input_strides(self, name: str) -> List[sympy.Expr]:
@@ -300,17 +302,20 @@ def gen_variable_names(self):
             self.var_map[name] = "t_" + name
         for name in self.internal_args:
             self.var_map[name] = gen_variable_name(name, "t", existing_names)
-        for constant_name in self.constants:
-            self.var_map[constant_name] = gen_variable_name(constant_name, "c", existing_names)
-            if self.constants[constant_name].data is not None:
-                value = self.constants[constant_name].data
+        for name, tensor_arg in self.constants.items():
+            self.var_map[name] = gen_variable_name(name, "c", existing_names)
+            if tensor_arg.data is not None:
+                value = tensor_arg.data
                 if value is not None:
                     assert value.size == 1, f"unsupported constant array {value}"
-                    variable_name = self.var_map[constant_name]
+                    variable_name = self.var_map[name]
                     assert variable_name not in self.var_map
                     self.var_map[variable_name] = str(np.array(value.item(), value.dtype))
-
-        self.symbolic_shape_variables = [str(dim) for dim in self.target_shape if dim.is_symbol]
+        seen = set()
+        for dim in self.target_shape:
+            if dim.is_symbol and dim not in seen:
+                seen.add(dim)
+                self.symbolic_shape_variables.append(str(dim))


 class ElementwiseKernelNode(KernelNode):
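The new requires_x_mask/requires_r_mask conditions read: a bounds mask is needed whenever the numel is symbolic (divisibility is unknown at codegen time) or some candidate block size does not evenly divide the concrete numel. A self-contained sketch of the same predicate (the helper is hypothetical, not part of the PR):

    import sympy

    def needs_mask(numel: sympy.Expr, block_sizes) -> bool:
        if not numel.is_number:
            return True  # symbolic: divisibility unknown until runtime, always mask
        # Concrete: mask only if some candidate block size leaves a remainder.
        return any(int(numel) % block != 0 for block in block_sizes)

    assert not needs_mask(sympy.Integer(4096), [256, 1024])  # evenly divisible, no mask
    assert needs_mask(sympy.Integer(4000), [256, 1024])      # 4000 % 256 != 0
    assert needs_mask(sympy.Symbol("i0_dim2_256"), [256])    # symbolic, always masked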
diff --git a/orttraining/orttraining/python/training/ort_triton/_lowering.py b/orttraining/orttraining/python/training/ort_triton/_lowering.py
index 5de60e69437a0..5c848d2cecc58 100644
--- a/orttraining/orttraining/python/training/ort_triton/_lowering.py
+++ b/orttraining/orttraining/python/training/ort_triton/_lowering.py
@@ -51,10 +51,8 @@ def __init__(self, node: NodeProto, reduce_axes: List[int], keep_dims: int, node
         # r_numel is meant to hint how many elements in a row of tensor will be processed by each kernel.
         # r is a abbreviation of reduction, so, it's only used for reduction nodes.
         r_numel: sympy.Expr = sympy.prod(r_dims) if len(r_dims) > 0 else sympy.Integer(1)
-        # Support concrete shape only for now.
-        assert x_numel.is_integer and r_numel.is_integer
         self.autotune_configs: AutotuneConfigs = AutotuneConfigs(
-            int(x_numel), int(r_numel), len(self.reduce_axes) == 0 or self.reduce_axes[-1] == rank - 1
+            x_numel, r_numel, len(self.reduce_axes) == 0 or self.reduce_axes[-1] == rank - 1
         )
         self.reduced_args: Set[str] = set()
         if keep_dims != 1:
@@ -69,10 +67,8 @@ def _compatible_shape(self, shape: List[sympy.Expr], split_if_different: bool) -
             if len(shape) > len(self.target_shape):
                 return False
             shape = [sympy.Integer(1)] * (len(self.target_shape) - len(shape)) + shape
-        for axis in range(len(shape)):
-            if shape[axis] != self.target_shape[axis] and (
-                not shape[axis].is_number or shape[axis] != sympy.Integer(1)
-            ):
+        for axis, dim in enumerate(shape):
+            if dim != self.target_shape[axis] and (not dim.is_number or dim != sympy.Integer(1)):
                 return False
         return True

@@ -129,7 +125,7 @@ def has_reduced_elementwise_nodes(self) -> bool:
         return not is_reduction_node(self.nodes_groups[0]) and len(self.reduced_args) > 0

     def dependent_nodes(self, keep_reduce_node: bool):
-        node_map = dict()
+        node_map = {}
         reduce_nodes = []
         if not keep_reduce_node and self.has_reduced_elementwise_nodes():
             for item in self.nodes_groups:
@@ -151,8 +147,8 @@ def flatten(self, sorted_nodes: List[NodeProto]) -> Tuple[List[NodeProto], List[
         layers = []
         group_layer = [self]
         while len(group_layer) > 0:
-            node_map = dict()
-            reduce_node_map = dict()
+            node_map = {}
+            reduce_node_map = {}
             next_layer = []
             for group in group_layer:
                 sub_node_map, reduce_nodes = group.dependent_nodes(False)
@@ -201,7 +197,7 @@ def __init__(self):
         self.cross_kernel_inputs: List[str] = []
         self.constants: List[str] = []
         self.module_outputs: List[str] = []
-        self.cross_kernel_outputs: [str] = []
+        self.cross_kernel_outputs: List[str] = []
         self.internal_args: List[str] = []


@@ -284,7 +280,7 @@ def _process_node(self, node: NodeProto, precessors: Dict[str, List[NodeProto]],
         return dependent_nodes

     def _group_nodes(self):
-        producers = dict()
+        producers = {}
         precessors = defaultdict(list)
         processed = set()
         groups = []
@@ -321,13 +317,16 @@ def _group_nodes(self):
                     group_dependencies[k].add(j)

         flag = set()
-        for i in range(len(groups)):
-            if i not in flag:
-                for j in range(i + 1, len(groups)):
-                    if j not in flag and j not in group_dependencies[i] and groups[i].try_merge(groups[j]):
-                        flag.add(j)
-                self._groups.append(groups[i])
-                flag.add(i)
+        for i, group_i in enumerate(groups):
+            if i in flag:
+                continue
+            for j, group_j in enumerate(groups):
+                if j <= i:
+                    continue
+                if j not in flag and j not in group_dependencies[i] and group_i.try_merge(group_j):
+                    flag.add(j)
+            self._groups.append(group_i)
+            flag.add(i)

     def _get_node_io(self, node: NodeProto) -> Tuple[List[TensorArg], List[TensorArg]]:
         input_args = []
@@ -395,7 +394,7 @@ def _analyze_kernel_io_list(self):
     def _insert_load_and_store(self, kernel_node: KernelNode):
         input_names = [input.name for input in kernel_node.inputs]
-        output_name_map = dict()
+        output_name_map = {}
         for output in kernel_node.outputs:
             output_name_map[output.name] = 0
         for node in kernel_node.sub_nodes:
@@ -499,7 +498,7 @@ def _lower(self):
             warnings.warn("Use triton's random for Dropout, ignore the random seed from ORT.", UserWarning)

         self._analyze_kernel_io_list()
-        cross_kernel_arg_map = dict()
+        cross_kernel_arg_map = {}
         for idx, kernel_io in enumerate(self._kernel_io_list):
             for output in itertools.chain(kernel_io.cross_kernel_outputs, kernel_io.module_outputs):
                 cross_kernel_arg_map[output] = idx
diff --git a/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py b/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py
index 69df567500a89..32e54d0868013 100644
--- a/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py
+++ b/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py
@@ -5,15 +5,17 @@
 import copy
 import itertools
-from typing import Any, Dict, List, Set
+from typing import Dict, List, Set

 import numpy as np
 import onnx
-from onnx import GraphProto, ModelProto, NodeProto, helper
+import sympy
+from onnx import GraphProto, ModelProto, NodeProto, TensorProto, helper

-from ._common import TensorInfo, TypeAndShapeInfer
+from ._common import SymbolicDSU, TensorInfo, TypeAndShapeInfer
 from ._decompose import DecomposeDispatch
 from ._op_config import is_elementwise_node
+from ._sympy_utils import parse_shape
 from ._utils import get_attribute, to_numpy_array, topological_sort


@@ -29,17 +31,20 @@ class SortedGraph:
     input_shapes: the shapes of the model inputs. Can be numeric values or symbolic values.
     """

-    def __init__(self, model: ModelProto, input_shapes: List[List[Any]]):
+    def __init__(self, model: ModelProto, input_shapes: List[List[sympy.Expr]]):
         self._model: ModelProto = model
         self._graph: GraphProto = model.graph
-        self._input_shapes: List[List[Any]] = input_shapes
+        self._input_shapes: List[List[sympy.Expr]] = input_shapes

         # For elementwise graph outputs, when we group nodes to different kernels, if the target shape is different
         # from other nodes' target shape, even it can be broadcasted, we still need to create a new kernel for it.
         self._elementwise_graph_outputs: Set[str] = set()
+        graph_output_names = [output.name for output in self._graph.output]
         for node in self._graph.node:
             if is_elementwise_node(node):
-                self._elementwise_graph_outputs.update(node.output)
+                self._elementwise_graph_outputs.update(
+                    [output for output in node.output if output in graph_output_names]
+                )

         # Topological sort the nodes in the graph.
         self._sorted_nodes: List[NodeProto] = topological_sort(
@@ -53,7 +58,7 @@ def __init__(self, model: ModelProto, input_shapes: List[List[Any]]):
         for initializer in self._graph.initializer:
             self._node_arg_infos[initializer.name] = TensorInfo(
                 initializer.data_type,
-                list(to_numpy_array(initializer).shape),
+                parse_shape(list(to_numpy_array(initializer).shape)),
             )

         # Decompose complex operators.
@@ -66,7 +71,7 @@ def __init__(self, model: ModelProto, input_shapes: List[List[Any]]):
         initializers = {}
         for initializer in self._graph.initializer:
             initializers[initializer.name] = initializer
-        self._sorted_initializers: List[TensorInfo] = []
+        self._sorted_initializers: List[TensorProto] = []
         for node in self._sorted_nodes:
             for input in node.input:
                 if input in initializers:
@@ -157,6 +162,7 @@ def elementwise_graph_outputs(self) -> Set[str]:

     def _decompose(self):
         dispatch = DecomposeDispatch()
+        symbolics: SymbolicDSU = SymbolicDSU()
         pos = 0
         # If a node is complex, decompose it and insert the decomposed nodes at the same position.
         # All complex Ops are defined in DecomposeDispatch.
@@ -175,16 +181,18 @@ def _decompose(self):
                 value_attr = get_attribute(node, "value")
                 self._node_arg_infos[node.output[0]] = TensorInfo(
                     value_attr.data_type,
-                    list(to_numpy_array(value_attr).shape),
+                    parse_shape(list(to_numpy_array(value_attr).shape)),
                 )
             else:
                 input_infos = []
                 for input in node.input:
                     input_infos.append(self._node_arg_infos[input])
-                output_infos = TypeAndShapeInfer.infer(node, input_infos, self._graph)
+                output_infos = TypeAndShapeInfer.infer(node, input_infos, self._graph, symbolics)
                 for idx, output in enumerate(node.output):
                     self._node_arg_infos[output] = output_infos[idx]
             pos += 1
+        for tensor_info in self._node_arg_infos.values():
+            tensor_info.update_shape(symbolics)

     # Save the ONNX graphs for debug purpose. The original ONNX graph is the subgraph from backend.
     # The processed ONNX graph is the subgraph after decompose, it also contains the concrete shapes for each arg.
@@ -197,13 +205,20 @@ def save_onnx(self, file_path_prefix):
         for node in itertools.chain(processed_model.graph.input, processed_model.graph.output):
             node.type.tensor_type.shape.Clear()
             for dim in self.node_arg_infos[node.name].shape:
-                node.type.tensor_type.shape.dim.add().dim_value = int(dim)
+                if dim.is_number:
+                    node.type.tensor_type.shape.dim.add().dim_value = int(dim)
+                else:
+                    node.type.tensor_type.shape.dim.add().dim_param = str(dim)
         value_infos = []
         for node in itertools.chain(self.const_nodes, self.sorted_nodes):
             for output in node.output:
                 tensor_info = self.node_arg_infos[output]
                 value_infos.append(
-                    helper.make_tensor_value_info(output, tensor_info.dtype, [int(dim) for dim in tensor_info.shape])
+                    helper.make_tensor_value_info(
+                        output,
+                        tensor_info.dtype,
+                        [int(dim) if dim.is_number else str(dim) for dim in tensor_info.shape],
+                    )
                 )
         processed_model.graph.ClearField("value_info")
         processed_model.graph.value_info.extend(value_infos)
diff --git a/orttraining/orttraining/python/training/ort_triton/_sympy_utils.py b/orttraining/orttraining/python/training/ort_triton/_sympy_utils.py
index e3629b5effa38..a4a384c021fe8 100644
--- a/orttraining/orttraining/python/training/ort_triton/_sympy_utils.py
+++ b/orttraining/orttraining/python/training/ort_triton/_sympy_utils.py
@@ -9,6 +9,12 @@
 import sympy


+def extract_shape_from_symbol(symbol: str) -> int:
+    match = re.match(r"i(\d+)_dim(\d+)_(\d+)", symbol)
+    assert match
+    return int(match.group(3))
+
+
 def sympy_dot(seq1: List[sympy.Expr], seq2: List[sympy.Expr]) -> sympy.Expr:
     assert len(seq1) == len(seq2)
     return sympy.expand(sum(a * b for a, b in zip(seq1, seq2)))
diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py b/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py
index c1b99e4859dbd..dc9e0c18eac15 100644
--- a/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py
+++ b/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py
@@ -6,7 +6,7 @@
 import os

 from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out  # noqa: F401
-from ._slice_scel import optimize_graph_for_slice_scel, slice_scel, slice_scel_backward  # noqa: F401
+from ._slice_scel import slice_scel, slice_scel_backward  # noqa: F401

 _all_kernels = [
     "triton_gemm",
@@ -17,14 +17,9 @@
     "slice_scel_backward",
 ]

-_all_optimizers = [
-    "optimize_graph_for_slice_scel",
-]
-
 if "ORTMODULE_USE_FLASH_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1:
-    from ._flash_attn import flash_attn_backward, flash_attn_forward, optimize_graph_for_flash_attention  # noqa: F401
+    from ._flash_attn import flash_attn_backward, flash_attn_forward  # noqa: F401

     _all_kernels.extend(["flash_attn_forward", "flash_attn_backward"])
-    _all_optimizers.append("optimize_graph_for_flash_attention")

-__all__ = _all_kernels + _all_optimizers  # noqa: PLE0605
+__all__ = _all_kernels  # noqa: PLE0605
diff --git a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py
index b970c730d0441..8a642a1f7f26c 100644
--- a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py
+++ b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py
@@ -8,7 +8,7 @@
 import os
 import sys
 from types import ModuleType
-from typing import List, Tuple
+from typing import List, Tuple, Union

 import onnx
 from torch._C import _from_dlpack
@@ -18,8 +18,8 @@
 from ._codegen import codegen
 from ._op_config import get_supported_ops
 from ._sorted_graph import SortedGraph
-from ._sympy_utils import parse_shape
-from ._utils import gen_unique_name
+from ._sympy_utils import extract_shape_from_symbol, parse_shape
+from ._utils import gen_unique_name, next_power_of_2

 _DEBUG_MODE = "ORTMODULE_TRITON_DEBUG" in os.environ and int(os.getenv("ORTMODULE_TRITON_DEBUG")) == 1

@@ -31,11 +31,46 @@ def _gen_module_internal(sorted_graph: SortedGraph) -> Tuple[str, str, ModuleTyp
     return func_name, src_code, PyCodeCache().load(src_code)


-def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[int]]) -> int:
+class _ShapeCache:
+    """
+    Cache the input shapes from each step for a given ONNX model.
+    The cached shape starts from the concrete shapes of the first step.
+    For dimensions whose concrete value never changes between steps, we keep the concrete value.
+    For dimensions whose concrete value changes between steps, we switch to a symbolic shape.
+    """
+ """ + + cache = dict() # noqa: RUF012 + clear = staticmethod(cache.clear) + + @classmethod + def get_shape(cls, onnx_key: int, shapes: List[List[int]]) -> List[List[Union[int, str]]]: + if onnx_key not in cls.cache: + cls.cache[onnx_key] = shapes + else: + changed = False + for i, shape in enumerate(shapes): + for j, dim in enumerate(shape): + if dim != cls.cache[onnx_key][i][j] and isinstance(cls.cache[onnx_key][i][j], int): + max_dim = max(dim, cls.cache[onnx_key][i][j]) + shape[j] = f"i{i}_dim{j}_{next_power_of_2(max_dim)}" + changed = True + elif isinstance(cls.cache[onnx_key][i][j], str): + pre = extract_shape_from_symbol(cls.cache[onnx_key][i][j]) + if pre >= dim: + shape[j] = cls.cache[onnx_key][i][j] + else: + shape[j] = f"i{i}_dim{j}_{next_power_of_2(dim)}" + changed = True + if changed: + cls.cache[onnx_key] = shapes + return cls.cache[onnx_key] + + +def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> int: + # pylint: disable=unused-argument return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}") % (10**8) -def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[int]]) -> Tuple[str, ModuleType]: +def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]: model = onnx.load_model_from_string(onnx_str) sorted_graph = SortedGraph(model, [parse_shape(shape) for shape in shapes]) if _DEBUG_MODE: @@ -44,7 +79,7 @@ def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[int]]) -> Tupl func_name, src_code, mod = _gen_module_internal(sorted_graph) if _DEBUG_MODE: py_file_path = f"triton_debug/{func_name}_{onnx_key}.py" - with open(py_file_path, "w") as f: + with open(py_file_path, "w", encoding="UTF-8") as f: f.write(src_code) return func_name, mod @@ -90,7 +125,8 @@ def call_triton_by_onnx(onnx_key: int, onnx_str: bytes, *tensors): assert all(tensor is not None for tensor in tensors) torch_tensors = [_from_dlpack(tensor) for tensor in tensors] concrete_shapes = [list(tensor.size()) for tensor in torch_tensors] - func_name, mod = ModuleCache.load(_gen_key, _gen_module, onnx_key, onnx_str, concrete_shapes) + shapes = _ShapeCache.get_shape(onnx_key, concrete_shapes) + func_name, mod = ModuleCache.load(_gen_key, _gen_module, onnx_key, onnx_str, shapes) func = getattr(mod, func_name) output = func(*torch_tensors) if isinstance(output, tuple): diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py index d205e8f2377dd..e66582bda9ed1 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py @@ -5,6 +5,7 @@ import json import os import random +import uuid import _test_helpers import onnx @@ -200,7 +201,8 @@ def _run_op_test(op_type, onnx_dtype, create_model_func, gen_inputs_func, **kwar _, op_type = op_type.split("::") pt_outputs = TorchFuncExecutor.run(op_type, *pt_inputs, **kwargs) model_str = create_model_func(op_type, onnx_dtype, **kwargs).SerializeToString() - ort_outputs = call_triton_by_onnx(hash(model_str), model_str, *[to_dlpack(tensor) for tensor in ort_inputs]) + unique_id = uuid.uuid1().int >> 64 + ort_outputs = call_triton_by_onnx(unique_id, model_str, *[to_dlpack(tensor) for tensor in ort_inputs]) if isinstance(pt_outputs, tuple): assert isinstance(ort_outputs, tuple) assert len(pt_outputs) == len(ort_outputs) @@ -229,9 +231,9 @@ def _run_module_test(module_cls, 
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py
index d205e8f2377dd..e66582bda9ed1 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py
@@ -5,6 +5,7 @@
 import json
 import os
 import random
+import uuid

 import _test_helpers
 import onnx
@@ -200,7 +201,8 @@ def _run_op_test(op_type, onnx_dtype, create_model_func, gen_inputs_func, **kwar
         _, op_type = op_type.split("::")
     pt_outputs = TorchFuncExecutor.run(op_type, *pt_inputs, **kwargs)
     model_str = create_model_func(op_type, onnx_dtype, **kwargs).SerializeToString()
-    ort_outputs = call_triton_by_onnx(hash(model_str), model_str, *[to_dlpack(tensor) for tensor in ort_inputs])
+    unique_id = uuid.uuid1().int >> 64
+    ort_outputs = call_triton_by_onnx(unique_id, model_str, *[to_dlpack(tensor) for tensor in ort_inputs])
     if isinstance(pt_outputs, tuple):
         assert isinstance(ort_outputs, tuple)
         assert len(pt_outputs) == len(ort_outputs)
@@ -229,9 +231,9 @@ def _run_module_test(module_cls, dtype, gen_inputs_func, triton_op_count, **kwar
         ort_output = _run_step(ort_model, *ort_inputs)
         _test_helpers.assert_values_are_close(pt_output, ort_output, rtol=rtol, atol=atol)
         _test_helpers.assert_gradients_match_and_reset_gradient(pt_model, ort_model, rtol=rtol, atol=atol)
-        for i in range(len(pt_inputs)):
-            if pt_inputs[i].requires_grad:
-                _test_helpers.assert_values_are_close(pt_inputs[i].grad, ort_inputs[i].grad, rtol=rtol, atol=atol)
+        for idx, pt_input in enumerate(pt_inputs):
+            if pt_input.requires_grad:
+                _test_helpers.assert_values_are_close(pt_input.grad, ort_inputs[idx].grad, rtol=rtol, atol=atol)

     assert os.path.exists(os.path.join(os.getcwd(), "triton_model_torch_exported_training.onnx"))
     assert os.path.exists(os.path.join(os.getcwd(), "triton_model_optimized_training.onnx"))
@@ -250,12 +252,12 @@ def _run_module_test(module_cls, dtype, gen_inputs_func, triton_op_count, **kwar


 def _run_tunable_op_test(module_cls, dtype, gen_inputs_func, tunable_op, impl_count, **kwargs):
+    os.environ["ORTMODULE_ENABLE_TUNING"] = "1"
+    os.environ["ORTMODULE_TUNING_RESULTS_PATH"] = "./"
     pt_model = module_cls().to(DEVICE).to(dtype)
     ort_model = ORTModule(copy.deepcopy(pt_model))
     rtol = kwargs.get("rtol", 1e-03 if dtype == torch.float16 else 1e-04)
     atol = kwargs.get("atol", 1e-03 if dtype == torch.float16 else 1e-05)
-    os.environ["ORTMODULE_ENABLE_TUNING"] = "1"
-    os.environ["ORTMODULE_TUNING_RESULTS_PATH"] = "./"
     for _ in range(5):
         pt_inputs = gen_inputs_func(dtype)
         ort_inputs = copy.deepcopy(pt_inputs)
@@ -265,7 +267,7 @@ def _run_tunable_op_test(module_cls, dtype, gen_inputs_func, tunable_op, impl_co
     _test_helpers.assert_gradients_match_and_reset_gradient(pt_model, ort_model, rtol=rtol, atol=atol)
     tunable_results_file = os.path.join(os.getcwd(), "tuning_results_training.json")
     assert os.path.exists(tunable_results_file)
-    with open(tunable_results_file) as f:
+    with open(tunable_results_file, encoding="UTF-8") as f:
         tunable_results = json.load(f)
     assert tunable_op in str(tunable_results)
     del os.environ["ORTMODULE_ENABLE_TUNING"]
@@ -275,7 +277,7 @@ def _run_tunable_op_test(module_cls, dtype, gen_inputs_func, tunable_op, impl_co
         if tunable_op in k:
             for param, impl in v.items():
                 v[param] = (impl + 1 + i) % impl_count
-    with open(tunable_results_file, "w") as f:
+    with open(tunable_results_file, "w", encoding="UTF-8") as f:
         json.dump(new_tunable_results, f)
     ort_model = ORTModule(copy.deepcopy(pt_model))
     for _ in range(5):
@@ -781,6 +783,43 @@ def _gen_inputs(dtype):
     _run_module_test(NeuralNetLayerNorm, dtype, _gen_inputs, 2)


+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+def test_dynamic_shapes_elementwise_module(dtype):
+    class NeuralNetSymbolicShapesElementwise(torch.nn.Module):
+        def forward(self, x, y, u, v):
+            return x * y - (u + v)
+
+    def _gen_inputs(dtype):
+        dim1 = 64 * random.randint(2, 4)
+        dim2 = 64 * random.randint(2, 4)
+        return [
+            torch.rand(16, dim1, dim2, dtype=dtype, device=DEVICE, requires_grad=True),
+            torch.rand(16, 1, dim2, dtype=dtype, device=DEVICE, requires_grad=True),
+            torch.rand(dim1, 1, dtype=dtype, device=DEVICE, requires_grad=True),
+            torch.rand(16, dim1, dim2, dtype=dtype, device=DEVICE, requires_grad=True),
+        ]
+
+    _run_module_test(NeuralNetSymbolicShapesElementwise, dtype, _gen_inputs, 1)
+
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+def test_dynamic_shapes_reduction_module(dtype):
+    class NeuralNetSymbolicShapesReduction(torch.nn.Module):
+        def forward(self, x, y, z):
+            return torch.softmax(x * y + z, dim=-1)
+
+    def _gen_inputs(dtype):
+        dim1 = 64 * random.randint(2, 4)
+        dim2 = 64 * random.randint(2, 4)
+        return [
+            torch.rand(16, dim1, dim2, dtype=dtype, device=DEVICE, requires_grad=True),
+            torch.rand(16, 1, dim2, dtype=dtype, device=DEVICE, requires_grad=True),
+            torch.rand(dim1, 1, dtype=dtype, device=DEVICE, requires_grad=True),
+        ]
+
+    _run_module_test(NeuralNetSymbolicShapesReduction, dtype, _gen_inputs, 2)
+
+
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
 @pytest.mark.parametrize("has_sum", [True, False])
 def test_slice_scel_module(dtype, has_sum):