[ORTModule] Support User Config for Triton Codegen, Bugfix for Reduce-to-scalar (#18448)

Users can provide a Triton codegen config JSON through an environment variable. Also
fixes some bugs related to the reduce-to-scalar case.
centwang authored Nov 15, 2023
1 parent b0699d9 commit ed89ca5
Showing 4 changed files with 53 additions and 8 deletions.
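For context, here is a minimal sketch of how a user might exercise the new config hook; the file name and config contents are illustrative, not part of this commit. The config JSON is written to disk and ORTMODULE_TRITON_CONFIG_FILE is pointed at it before the ORTModule-wrapped model runs (the new test_user_config below follows the same flow).

import json
import os

# Illustrative Triton codegen config: fuse only Add/Mul subgraphs, keep scalar
# initializers inside fused subgraphs, and require at least 2 non-no-op nodes.
triton_config = {
    "ops": {"Add": {"versions": [13, 14]}, "Mul": {"versions": [13, 14]}},
    "initializer": "scalar",
    "min_nodes": 2,
}

with open("my_triton_config.json", "w", encoding="UTF-8") as f:
    json.dump(triton_config, f)

# get_config() (see the diff below) returns this file's content instead of the
# defaults built from _op_config.py whenever the variable points to an existing file.
os.environ["ORTMODULE_TRITON_CONFIG_FILE"] = "./my_triton_config.json"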
14 changes: 9 additions & 5 deletions orttraining/orttraining/python/training/ort_triton/_codegen.py
@@ -52,7 +52,8 @@ def codegen(self, node: IRNode, context: CodegenContext, code_buffer: CodeBuffer

     def _get_elementwise_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) -> Tuple[str, str]:
         if offset_calc.is_x_reduced(arg_name):
-            return "", ""
+            # Scalar.
+            return "tl.full([1], 0, tl.int32)", ""
         if offset_calc.is_same_x_shape(arg_name):
             return "xindex", "xmask" if offset_calc.requires_x_mask else ""
         strides = offset_calc.get_input_strides(arg_name)
@@ -88,13 +89,16 @@ def _get_reduce_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str)
         if offset_calc.requires_r_mask:
             mask_strs.append("rmask")

+        # If both is_x_reduced and is_r_reduced are True, it's scalar.
+        if len(offset_strs) == 0:
+            offset_strs.append("tl.full([1, 1], 0, tl.int32)")
         return " + ".join(offset_strs), " & ".join(mask_strs)

-    def _get_offset_mask(self, node: OffsetCalculator, arg_name: str) -> Tuple[str, str]:
+    def _get_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) -> Tuple[str, str]:
         return (
-            self._get_reduce_offset_mask(node, arg_name)
-            if node.is_reduction
-            else self._get_elementwise_offset_mask(node, arg_name)
+            self._get_reduce_offset_mask(offset_calc, arg_name)
+            if offset_calc.is_reduction
+            else self._get_elementwise_offset_mask(offset_calc, arg_name)
         )

     def IONode(self, node: IONode, context: CodegenContext, code_buffer: CodeBuffer, indent: int):  # noqa: N802
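To make the reduce-to-scalar codegen change above concrete, below is a hand-written sketch of the offset pattern now emitted for a scalar argument, a constant zero index from tl.full rather than an empty offset string. This is not the actual generated kernel; it assumes Triton and a CUDA device are available.

import torch
import triton
import triton.language as tl


@triton.jit
def scalar_copy_kernel(in_ptr, out_ptr):
    # Offset/mask for a scalar argument: a single constant zero index, no mask.
    xoffset = tl.full([1], 0, tl.int32)
    val = tl.load(in_ptr + xoffset)
    tl.store(out_ptr + xoffset, val)


x = torch.tensor([3.0], device="cuda")
y = torch.empty_like(x)
scalar_copy_kernel[(1,)](x, y)
assert torch.equal(x, y)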
6 changes: 4 additions & 2 deletions orttraining/orttraining/python/training/ort_triton/_ir.py
@@ -131,7 +131,7 @@ def register_tensor_arg(self, tensor_arg: TensorArg):
         input_shape = tensor_arg.shape
         if tensor_arg.name in self.reduced_args:
             assert self.is_reduction
-            reduced_rank = len(input_shape) - len(self.reduce_axes)
+            reduced_rank = len(self.target_shape) - len(self.reduce_axes)
             if len(input_shape) < reduced_rank:
                 input_shape = [sympy.Integer(1)] * (reduced_rank - len(input_shape)) + input_shape
             input_shape = (
@@ -143,7 +143,9 @@ def register_tensor_arg(self, tensor_arg: TensorArg):
             input_shape = [sympy.Integer(1)] * (len(self.target_shape) - len(input_shape)) + input_shape
         running_stride = sympy.Integer(1)
         for i in range(len(self.target_shape) - 1, -1, -1):
-            if self.target_shape[i] == input_shape[i]:
+            if self.target_shape[i] == input_shape[i] and not (
+                tensor_arg.name in self.reduced_args and i in self.reduce_axes
+            ):
                 strides.insert(0, running_stride)
                 running_stride = running_stride * input_shape[i]
             else:
@@ -87,14 +87,20 @@ def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str
 def get_config() -> str:
     """
     Get the supported ops and other configs in JSON format to control the Triton fusion on backend side.
-    All supported ops are from _op_config.py. The Triton fusion will try to fuse subgraphs with connected supported ops.
+    All supported ops are from user config specified by env ORTMODULE_TRITON_CONFIG_FILE or from _op_config.py.
+    The Triton fusion will try to fuse subgraphs with connected supported ops.
     The initializer value can be "none", "scalar", and "all".
         "none": no initializer will be added to subgraphs.
         "scalar": only related scalar initializers will be added to subgraphs.
         "all": all related initializers will be added to subgraphs.
     The min_nodes is used to control the minimum number of non-no-op nodes in a subgraph.
     """

+    config_file = os.getenv("ORTMODULE_TRITON_CONFIG_FILE", "")
+    if config_file and os.path.exists(config_file):
+        with open(config_file, encoding="UTF-8") as f:
+            return f.read()
+
     config = {"ops": get_supported_ops(), "initializer": "scalar", "min_nodes": 2}
     return json.dumps(config)

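The docstring above describes the full shape of the config JSON that get_config() returns; the small check below is written only for illustration (the backend does its own parsing of the returned JSON) and makes the documented keys and allowed values explicit.

import json


def check_triton_config(config_json: str) -> None:
    # Illustrative schema check for the documented config format.
    config = json.loads(config_json)
    assert isinstance(config["ops"], dict)  # op type -> per-op options such as supported versions
    assert config["initializer"] in ("none", "scalar", "all")
    assert isinstance(config["min_nodes"], int)  # minimum non-no-op nodes per fused subgraph


check_triton_config('{"ops": {"Add": {"versions": [13, 14]}}, "initializer": "scalar", "min_nodes": 2}')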
@@ -543,6 +543,8 @@ def _gen_inputs(dtype):
         ([123, 4, 5, 6], [2], False),
         ([16, 8, 16, 8], [1, 3], True),
         ([16, 8, 16, 8], [0, 2], False),
+        ([16, 8, 16, 8], [0, 1, 2, 3], True),
+        ([16, 1, 16, 8], [0, 1, 2, 3], False),
     ],
 )
 def test_reduce_op(op_type, onnx_dtype, input_shape_and_reduce_info):
@@ -871,3 +873,34 @@ def _gen_inputs(dtype):
         return [torch.rand(m_n_k[0], m_n_k[2], dtype=dtype, device=DEVICE, requires_grad=True)]

     _run_tunable_op_test(NeuralNetGemm, dtype, _gen_inputs, "GemmTunableOp", 2)


+def test_user_config():
+    n, d, h, w = 8, 768, 12, 64
+    dtype = torch.float32
+
+    class NeuralNetElementwise(torch.nn.Module):
+        def forward(self, input1, input2, input3, input4):
+            return input1 + input2 - input3 * input4
+
+    def _gen_inputs(dtype):
+        return [
+            torch.rand(n, d, h, w, dtype=dtype, device=DEVICE, requires_grad=True),
+            torch.rand(w, dtype=dtype, device=DEVICE, requires_grad=True),
+            torch.rand(d, 1, 1, dtype=dtype, device=DEVICE, requires_grad=True),
+            torch.rand(n, 1, h, w, dtype=dtype, device=DEVICE, requires_grad=True),
+        ]
+
+    user_config = (
+        '{"ops": {"Add": {"versions": [13, 14]}, "Mul": {"versions": [13, 14]}}, '
+        '"initializer": "scalar", "min_nodes": 2}'
+    )
+    with open("user_config.json", "w", encoding="UTF-8") as f:
+        f.write(user_config)
+    os.environ["ORTMODULE_TRITON_CONFIG_FILE"] = "./user_config.json"
+
+    # Sub is not in the user config, so the graph is split into 2 single-op subgraphs, which will not be fused into TritonOp.
+    _run_module_test(NeuralNetElementwise, dtype, _gen_inputs, 0)

+    del os.environ["ORTMODULE_TRITON_CONFIG_FILE"]
+    os.remove(os.path.join(os.getcwd(), "user_config.json"))
