[ORTModule] Remove Unused Arguments from Generated Triton Code (#18636)
This PR:
- Removes unused arguments from the generated Triton code.
- Removes the unnecessary mask for the symbolic-shape case from the generated Triton code (illustrated by the sketch below).
- Adds documentation for the usage of ORTMODULE_TRITON_CONFIG_FILE.
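The mask removal relies on a divisibility check over the symbolic element count: substituting every free symbol with 1 isolates the constant factor, and if that factor is a multiple of every candidate block size, the boundary mask can be dropped. Below is a minimal standalone sketch of that check (not the ORT implementation; the block sizes are hypothetical autotune candidates):

```python
import sympy


def needs_mask(numel: sympy.Expr, block_sizes) -> bool:
    # Substitute every free symbol with 1 to isolate the constant factor of the
    # element count; if it is a multiple of every candidate block size, loads and
    # stores always cover full blocks and the mask can be omitted.
    simplified = numel.subs({s: sympy.Integer(1) for s in numel.free_symbols})
    return any(simplified % sympy.Integer(block) != 0 for block in block_sizes)


batch = sympy.Symbol("batch")
print(needs_mask(batch * 1024, [256, 1024]))  # False: 1024 is a multiple of both block sizes
print(needs_mask(batch * 1000, [256, 1024]))  # True: 1000 % 256 != 0, so a mask is still required
```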
centwang authored Nov 30, 2023
1 parent 5c67a00 commit e1d1033
Showing 3 changed files with 53 additions and 14 deletions.
24 changes: 24 additions & 0 deletions docs/ORTModule_Training_Guidelines.md
@@ -379,6 +379,30 @@ Check [FP16_Optimizer implementation](../orttraining/orttraining/python/training
export ORTMODULE_USE_TRITON=1
```

#### ORTMODULE_TRITON_CONFIG_FILE

- **Feature Area**: *ORTMODULE/TritonOp*
- **Description**: Triton codegen currently supports a limited set of Ops, such as certain elementwise and reduction Ops. If Triton optimization is enabled, all supported Ops are optimized by default whenever possible. Users can provide a custom JSON config file to control which Ops are optimized and how. A sample config is shown below. For each Op, the list of supported opset versions is required; a "domain" is only needed for Ops outside the default ONNX domain (see SoftmaxGrad_13 below). The "conditions" field can constrain an axis/axes attribute or input: specify the concrete value, "single" to require that it contain only one dimension, or "constant" to require that it be a constant tensor. Save the JSON to a file and assign its path to the environment variable below to enable the custom config.

```json
{
    "ops": {
        "Add": {"versions": [13, 14]},
        "Sub": {"versions": [13, 14]},
        "Identity": {"versions": [13], "is_no_op": true},
        "ReduceSum": {"versions": [13], "conditions": {"axes": "[-1]"}},
        "Softmax": {"versions": [13]},
        "SoftmaxGrad_13": {"domain": "com.microsoft", "versions": [1]}
    },
    "initializer": "scalar",
    "min_nodes": 2
}
```

```bash
export ORTMODULE_TRITON_CONFIG_FILE=triton_config.json
```
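A minimal usage sketch (assuming `onnxruntime-training` with Triton support is installed, a CUDA device is available, and the JSON above is saved as `triton_config.json`; the model is purely illustrative):

```python
import os

# Set both variables before ORTModule builds its graph so the Triton backend picks them up.
os.environ["ORTMODULE_USE_TRITON"] = "1"
os.environ["ORTMODULE_TRITON_CONFIG_FILE"] = "triton_config.json"

import torch
from onnxruntime.training.ortmodule import ORTModule

# Ops listed in the config (e.g., Add, Softmax) exported from this model are candidates for Triton codegen.
model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.Softmax(dim=-1)).cuda()
model = ORTModule(model)
output = model(torch.randn(4, 128, device="cuda"))
```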

#### ORTMODULE_ENABLE_TUNING

- **Feature Area**: *ORTMODULE/TritonOp*
@@ -159,7 +159,7 @@ def _gen_kernel_signature(self, node: KernelNode, context: CodegenContext, code_

other_input_args = "seed_cuda, " if node.has_dropout else ""
# Support symbolic shape if any.
- symbolic_shape_args_str = ", ".join(node.symbolic_shape_variables)
+ symbolic_shape_args_str = ", ".join(sorted(node.offset_calc.symbolic_shape_variables))
if symbolic_shape_args_str:
other_input_args += f"{symbolic_shape_args_str}, "

@@ -490,7 +490,7 @@ def ModuleNode(self, node: ModuleNode, context: CodegenContext, code_buffer: Cod
kernel_args_str += ", seed_cuda"

# Support symbolic shape if any.
- symbolic_shape_args_str = ", ".join(kernel_node.symbolic_shape_variables)
+ symbolic_shape_args_str = ", ".join(sorted(kernel_node.offset_calc.symbolic_shape_variables))
if symbolic_shape_args_str:
kernel_args_str += f", {symbolic_shape_args_str}"

39 changes: 27 additions & 12 deletions orttraining/orttraining/python/training/ort_triton/_ir.py
@@ -91,13 +91,16 @@ def __init__(self, target_shape: List[sympy.Expr], reduce_axes: List[int]):
self.autotune_configs: AutotuneConfigs = AutotuneConfigs(
self.x_numel, self.r_numel, not self.is_reduction or self.reduce_axes[-1] == self.rank - 1
)
- self.requires_x_mask: bool = not self.x_numel.is_number or any(
- int(self.x_numel) % config[0] != 0 for config in self.autotune_configs.configs
+ simplified_x_numel = self.x_numel.subs({symbol: sympy.Integer(1) for symbol in self.x_numel.free_symbols})
+ self.requires_x_mask: bool = any(
+ simplified_x_numel % sympy.Integer(config[0]) != 0 for config in self.autotune_configs.configs
)
- self.requires_r_mask: bool = not self.r_numel.is_number or any(
- int(self.r_numel) % config[1] != 0 for config in self.autotune_configs.configs
+ simplified_r_numel = self.r_numel.subs({symbol: sympy.Integer(1) for symbol in self.r_numel.free_symbols})
+ self.requires_r_mask: bool = any(
+ simplified_r_numel % sympy.Integer(config[1]) != 0 for config in self.autotune_configs.configs
)
self.reduced_args: Set[str] = set()
+ self.symbolic_shape_variables: Set[str] = set()

def get_input_strides(self, name: str) -> List[sympy.Expr]:
assert name in self.input_strides
@@ -151,14 +154,32 @@ def register_tensor_arg(self, tensor_arg: TensorArg):
else:
strides.insert(0, sympy.Integer(0))
self.input_strides[tensor_arg.name] = strides
+ x_input_strides = self.get_x_input_strides(tensor_arg.name)
if not self.is_same_x_shape(tensor_arg.name):
- for idx, dim in enumerate(self.get_x_input_strides(tensor_arg.name)):
+ for idx, dim in enumerate(x_input_strides):
if dim != sympy.Integer(0):
self.x_compute_dims.add(idx)
+ if idx != self.x_rank - 1:
+ self.symbolic_shape_variables.update(
+ [symbol.name for symbol in self.x_strides[idx].free_symbols]
+ )
+ if idx != 0:
+ self.symbolic_shape_variables.update([symbol.name for symbol in self.x_dims[idx].free_symbols])
+ elif len(x_input_strides) > 0 and x_input_strides[-1] != sympy.Integer(1):
+ self.symbolic_shape_variables.update([symbol.name for symbol in x_input_strides[-1].free_symbols])
+ r_input_strides = self.get_r_input_strides(tensor_arg.name)
if not self.is_same_r_shape(tensor_arg.name):
- for idx, dim in enumerate(self.get_r_input_strides(tensor_arg.name)):
+ for idx, dim in enumerate(r_input_strides):
if dim != sympy.Integer(0):
self.r_compute_dims.add(idx)
+ if idx != self.r_rank - 1:
+ self.symbolic_shape_variables.update(
+ [symbol.name for symbol in self.r_strides[idx].free_symbols]
+ )
+ if idx != 0:
+ self.symbolic_shape_variables.update([symbol.name for symbol in self.r_dims[idx].free_symbols])
+ elif len(r_input_strides) > 0 and r_input_strides[-1] != sympy.Integer(1):
+ self.symbolic_shape_variables.update([symbol.name for symbol in r_input_strides[-1].free_symbols])

def is_x_reduced(self, name: str) -> bool:
strides = self.get_input_strides(name)
@@ -288,7 +309,6 @@ def __init__(self, inputs: List[TensorArg], outputs: List[TensorArg], target_sha
self.target_shape: List[sympy.Expr] = target_shape
self.sub_nodes: List[IRNode] = []
self.var_map: Dict[str, str] = dict()
- self.symbolic_shape_variables: List[str] = []
self.has_dropout: bool = False
self.offset_calc: OffsetCalculator = OffsetCalculator(target_shape, reduce_axes)

@@ -313,11 +333,6 @@ def gen_variable_names(self):
variable_name = self.var_map[name]
assert variable_name not in self.var_map
self.var_map[variable_name] = str(np.array(value.item(), value.dtype))
- seen = set()
- for dim in self.target_shape:
- if dim.is_symbol and dim not in seen:
- seen.add(dim)
- self.symbolic_shape_variables.append(str(dim))


class ElementwiseKernelNode(KernelNode):
