add llama2+peft support for flash attn, refactor
centwang committed Oct 26, 2023
1 parent 2d5c4cf commit 5431e8e
Showing 14 changed files with 308 additions and 164 deletions.
10 changes: 5 additions & 5 deletions cmake/onnxruntime_python.cmake
@@ -387,8 +387,8 @@ if (onnxruntime_ENABLE_TRAINING)
file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/*"
)
-file(GLOB onnxruntime_python_ortmodule_transformers_srcs CONFIGURE_DEPENDS
-"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/transformers/*"
+file(GLOB onnxruntime_python_ortmodule_graph_optimizers_srcs CONFIGURE_DEPENDS
+"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/graph_optimizers/*"
)
file(GLOB onnxruntime_python_ort_triton_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/ort_triton/*.py"
@@ -744,7 +744,7 @@ if (onnxruntime_ENABLE_TRAINING)
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops
-COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/transformers
+COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/kernel
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils
@@ -799,8 +799,8 @@ if (onnxruntime_ENABLE_TRAINING)
${onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/
COMMAND ${CMAKE_COMMAND} -E copy
-${onnxruntime_python_ortmodule_transformers_srcs}
-$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/transformers/
+${onnxruntime_python_ortmodule_graph_optimizers_srcs}
+$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_ort_triton_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/
@@ -6,7 +6,7 @@
import os

from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out # noqa: F401
-from ._slice_scel import slice_scel, slice_scel_backward, transform_slice_scel # noqa: F401
+from ._slice_scel import optimize_graph_for_slice_scel, slice_scel, slice_scel_backward # noqa: F401

_all_kernels = [
"triton_gemm",
@@ -17,14 +17,14 @@
"slice_scel_backward",
]

-_all_transformers = [
-"transform_slice_scel",
+_all_optimizers = [
+"optimize_graph_for_slice_scel",
]

if "ORTMODULE_USE_FLASH_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1:
-from ._flash_attn import flash_attn_backward, flash_attn_forward, transform_flash_attention # noqa: F401
+from ._flash_attn import flash_attn_backward, flash_attn_forward, optimize_graph_for_flash_attention # noqa: F401

_all_kernels.extend(["flash_attn_forward", "flash_attn_backward"])
-_all_transformers.append("transform_flash_attention")
+_all_optimizers.append("optimize_graph_for_flash_attention")

-__all__ = _all_kernels + _all_transformers # noqa: PLE0605
+__all__ = _all_kernels + _all_optimizers # noqa: PLE0605
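
The flag-gated import above means the flash-attention kernels and their graph optimizer stay off by default. A minimal usage sketch (assuming a CUDA build of onnxruntime-training):

import os

# Must be set before the training modules are imported, since the optimizer
# is registered at import time based on this flag.
os.environ["ORTMODULE_USE_FLASH_ATTENTION"] = "1"

from onnxruntime.training.ortmodule import ORTModule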
241 changes: 190 additions & 51 deletions orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py
@@ -47,13 +47,8 @@
import triton.language as tl
from onnx import GraphProto, NodeProto, TensorProto, helper

-from onnxruntime.training.ortmodule import register_graph_transformer
-from onnxruntime.training.ortmodule.transformers.utils import (
-    GraphMatcher,
-    check_attribute_value,
-    make_constant_node,
-    update_graph,
-)
+from onnxruntime.training.ortmodule import register_graph_optimizer
+from onnxruntime.training.ortmodule.graph_optimizers.utils import GraphMatcher, check_attribute_value, update_graph


# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
@@ -117,9 +112,6 @@ def _fwd_kernel(
off_hb = tl.program_id(1)
off_b = off_hb // nheads
off_h = off_hb % nheads
-# off_b = tl.program_id(1)
-# off_h = tl.program_id(2)
-# off_hb = off_b * nheads + off_h
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
@@ -503,17 +495,6 @@ def _bwd_kernel_one_col_block(
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
other=0.0,
)
-# if EVEN_M:
-#     if EVEN_HEADDIM:
-#         do = tl.load(do_ptrs)
-#     else:
-#         do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
-# else:
-#     if EVEN_HEADDIM:
-#         do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
-#     else:
-#         do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
-#             & (offs_d[None, :] < headdim), other=0.0)
dv += tl.dot(p.to(do.dtype), do, trans_a=True)
# compute dp = dot(v, do)
# There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
@@ -988,32 +969,6 @@ def flash_attn_backward(do, q, k, v, o, lse, bias=None, **kwargs):
return dq, dk, dv


-# Without causal mask, without Dropout. For example, BERT model in HuggingFace.
-_PATTERN_0: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
-("MatMul", False, []), # 0
-("Transpose", True, [(0, 0, 0)]), # 1
-("Transpose", True, [(0, 0, 1)]), # 2
-("Div", False, [(0, 0, 0)]), # 3
-("Add", False, [(3, 0, 0)]), # 4
-("Softmax", False, [(4, 0, 0)]), # 5
-("MatMul", False, [(5, 0, 0)]), # 6
-("Transpose", True, [(6, 0, 1)]), # 7
-("Transpose", False, [(6, 0, 0)]), # 8
-("FusedMatMul", False, [(7, 0, 1)]), # 9
-("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10
-("Identity", False, [(10, 0, 0)]), # 11
-("Div", False, [(11, 0, 0)]), # 12
-("Identity", False, [(12, 0, 0)]), # 13
-("FusedMatMul", False, [(2, 0, 1), (13, 0, 0)]), # 14
-("FusedMatMul", False, [(1, 0, 0), (13, 0, 1)]), # 15
-("FusedMatMul", False, [(5, 0, 0)]), # 16
-("Transpose", True, [(16, 0, 1)]), # 17
-("Transpose", False, [(14, 0, 0)]), # 18
-("Transpose", False, [(15, 0, 0)]), # 19
-("Transpose", False, [(16, 0, 0)]), # 20
-]


def _make_flash_attention_nodes(
idx: int,
q: str,
@@ -1053,7 +1008,33 @@ def _make_flash_attention_nodes(
return [fwd_node, bwd_node], [logsumexp]
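# _make_flash_attention_nodes emits the paired custom forward/backward nodes
# that invoke the flash_attn_forward / flash_attn_backward Triton kernels, and
# returns them together with the logsumexp value info that the backward pass
# re-reads (inferred from the return value above and the kernel definitions).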


-def _apply_transform_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
# Without causal mask, without Dropout. For example, BERT model in HuggingFace.
_PATTERN_0: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
("MatMul", False, []), # 0
("Transpose", True, [(0, 0, 0)]), # 1
("Transpose", True, [(0, 0, 1)]), # 2
("Div", False, [(0, 0, 0)]), # 3
("Add", False, [(3, 0, 0)]), # 4
("Softmax", False, [(4, 0, 0)]), # 5
("MatMul", False, [(5, 0, 0)]), # 6
("Transpose", True, [(6, 0, 1)]), # 7
("Transpose", False, [(6, 0, 0)]), # 8
("FusedMatMul", False, [(7, 0, 1)]), # 9
("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10
("Identity", False, [(10, 0, 0)]), # 11
("Div", False, [(11, 0, 0)]), # 12
("Identity", False, [(12, 0, 0)]), # 13
("FusedMatMul", False, [(2, 0, 1), (13, 0, 0)]), # 14
("FusedMatMul", False, [(1, 0, 0), (13, 0, 1)]), # 15
("FusedMatMul", False, [(5, 0, 0)]), # 16
("Transpose", True, [(16, 0, 1)]), # 17
("Transpose", False, [(14, 0, 0)]), # 18
("Transpose", False, [(15, 0, 0)]), # 19
("Transpose", False, [(16, 0, 0)]), # 20
]
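
# Pattern encoding (inferred from usage in this file; the GraphMatcher contract
# may differ): each entry is (op_type, is_input_side, edges). An edge (i, j, k)
# ties the node to pattern node i; for input-side nodes the node's output j
# feeds input k of node i, otherwise the node's input k consumes output j of
# node i (k == -1 appears to mean "any input", as with YieldOp below).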


def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
# Check forward only as the backward is expected to be consistent if it's built correctly.
scale_value = matcher.get_constant_value(nodes[3].input[1])
if not (
@@ -1081,14 +1062,172 @@ def _apply_transform_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[
return nodes, nodes_to_add, new_value_infos
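# Each pattern optimizer returns (nodes_to_remove, nodes_to_add, new_value_infos);
# returning three empty lists rejects the candidate match.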


# llama2+peft, k doesn't require grad.
_PATTERN_1: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
("MatMul", False, []), # 0
("Transpose", True, [(0, 0, 1)]), # 1
("Div", False, [(0, 0, 0)]), # 2
("Add", False, [(2, 0, 0)]), # 3
("Softmax", False, [(3, 0, 0)]), # 4
("MatMul", False, [(4, 0, 0)]), # 5
("Transpose", True, [(5, 0, 1)]), # 6
("Identity", False, [(6, 0, 0)]), # 7
("YieldOp", False, [(7, 0, -1)]), # 8
("Transpose", False, [(5, 0, 0)]), # 9
("FusedMatMul", False, [(6, 0, 1)]), # 10
("SoftmaxGrad_13", False, [(10, 0, 0), (4, 0, 1)]), # 11
("Identity", False, [(11, 0, 0)]), # 12
("Div", False, [(12, 0, 0)]), # 13
("Identity", False, [(13, 0, 0)]), # 14
("FusedMatMul", False, [(1, 0, 1), (14, 0, 0)]), # 15
("FusedMatMul", False, [(4, 0, 0)]), # 16
("Transpose", True, [(16, 0, 1)]), # 17
("Sum", False, [(16, 0, 0)]), # 18
("Transpose", False, [(18, 0, 0)]), # 19
]


def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
# Check forward only as the backward is expected to be consistent if it's built correctly.
scale_value = matcher.get_constant_value(nodes[2].input[1])
if not (
check_attribute_value(nodes[1], "perm", [0, 1, 3, 2])
and scale_value is not None
and check_attribute_value(nodes[6], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[9], "perm", [0, 2, 1, 3])
and matcher.get_consumer_count(nodes[14].output[0]) == 1
):
return [], [], []

dtype, _ = matcher.get_type_and_shape(nodes[0].input[0])
assert dtype is not None
trans_q_tensor = helper.make_tensor_value_info("trans_q_" + str(idx), dtype, None)
trans_q_grad_tensor = helper.make_tensor_value_info("trans_q_grad_" + str(idx), dtype, None)
trans_k_tensor = helper.make_tensor_value_info("trans_k_" + str(idx), dtype, None)
trans_q = helper.make_node(
"Transpose", [nodes[0].input[0]], [trans_q_tensor.name], "Trans_Q_" + str(idx), perm=[0, 2, 1, 3]
)
trans_q_grad = helper.make_node(
"Transpose", [trans_q_grad_tensor.name], [nodes[15].output[0]], "Trans_Q_Grad_" + str(idx), perm=[0, 2, 1, 3]
)
trans_k = helper.make_node(
"Transpose", [nodes[1].input[0]], [trans_k_tensor.name], "Trans_K_" + str(idx), perm=[0, 2, 1, 3]
)
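    # Rewire the trailing Sum/Transpose pair (the Sum adds a second dV term,
    # presumably from the PEFT/LoRA branch): transpose the other Sum operand
    # instead of the Sum output, so the Sum runs last and directly produces the
    # final V gradient in the layout the flash-attention backward emits dv in.
    # (Intent inferred from the index shuffling below.)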
nodes[19].input[0] = nodes[18].input[1]
v_grad = nodes[19].output[0]
nodes[19].output[0] = nodes[18].output[0]
nodes[18].input[1] = nodes[18].output[0]
nodes[18].output[0] = v_grad
nodes_to_add, new_value_infos = _make_flash_attention_nodes(
idx,
trans_q_tensor.name,
trans_k_tensor.name,
nodes[6].input[0],
nodes[9].output[0],
nodes[17].input[0],
trans_q_grad_tensor.name,
"",
nodes[16].output[0],
nodes[3].input[1],
1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value),
)
nodes_to_remove = nodes[:6] + nodes[9:18]
nodes_to_add.extend([trans_q, trans_q_grad, trans_k])
new_value_infos.extend([trans_q_tensor, trans_q_grad_tensor, trans_k_tensor])
return nodes_to_remove, nodes_to_add, new_value_infos


# llama2+peft, k requires grad.
_PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
("MatMul", False, []), # 0
("Transpose", True, [(0, 0, 1)]), # 1
("Div", False, [(0, 0, 0)]), # 2
("Add", False, [(2, 0, 0)]), # 3
("Softmax", False, [(3, 0, 0)]), # 4
("MatMul", False, [(4, 0, 0)]), # 5
("Transpose", True, [(5, 0, 1)]), # 6
("Identity", False, [(6, 0, 0)]), # 7
("YieldOp", False, [(7, 0, -1)]), # 8
("Transpose", False, [(5, 0, 0)]), # 9
("FusedMatMul", False, [(6, 0, 1)]), # 10
("SoftmaxGrad_13", False, [(10, 0, 0), (4, 0, 1)]), # 11
("Identity", False, [(11, 0, 0)]), # 12
("Div", False, [(12, 0, 0)]), # 13
("Identity", False, [(13, 0, 0)]), # 14
("FusedMatMul", False, [(1, 0, 1), (14, 0, 0)]), # 15
("FusedMatMul", False, [(14, 0, 1)]), # 16
("Transpose", False, [(16, 0, 0)]), # 17
("FusedMatMul", False, [(4, 0, 0)]), # 18
("Transpose", True, [(18, 0, 1)]), # 19
("Sum", False, [(18, 0, 0)]), # 20
("Transpose", False, [(20, 0, 0)]), # 21
]


def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
# Check forward only as the backward is expected to be consistent if it's built correctly.
scale_value = matcher.get_constant_value(nodes[2].input[1])
if not (
check_attribute_value(nodes[1], "perm", [0, 1, 3, 2])
and scale_value is not None
and check_attribute_value(nodes[6], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[9], "perm", [0, 2, 1, 3])
and matcher.get_consumer_count(nodes[14].output[0]) == 2
):
return [], [], []

dtype, _ = matcher.get_type_and_shape(nodes[0].input[0])
assert dtype is not None
trans_q_tensor = helper.make_tensor_value_info("trans_q_" + str(idx), dtype, None)
trans_q_grad_tensor = helper.make_tensor_value_info("trans_q_grad_" + str(idx), dtype, None)
trans_k_tensor = helper.make_tensor_value_info("trans_k_" + str(idx), dtype, None)
trans_k_grad_tensor = helper.make_tensor_value_info("trans_k_grad_" + str(idx), dtype, None)
trans_q = helper.make_node(
"Transpose", [nodes[0].input[0]], [trans_q_tensor.name], "Trans_Q_" + str(idx), perm=[0, 2, 1, 3]
)
trans_q_grad = helper.make_node(
"Transpose", [trans_q_grad_tensor.name], [nodes[15].output[0]], "Trans_Q_Grad_" + str(idx), perm=[0, 2, 1, 3]
)
trans_k = helper.make_node(
"Transpose", [nodes[1].input[0]], [trans_k_tensor.name], "Trans_K_" + str(idx), perm=[0, 2, 1, 3]
)
trans_k_grad = helper.make_node(
"Transpose", [trans_k_grad_tensor.name], [nodes[17].output[0]], "Trans_K_Grad_" + str(idx), perm=[0, 2, 1, 3]
)
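    # Same Sum/Transpose rewiring as in _optimize_for_pattern_1 above.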
nodes[21].input[0] = nodes[20].input[1]
v_grad = nodes[21].output[0]
nodes[21].output[0] = nodes[20].output[0]
nodes[20].input[1] = nodes[20].output[0]
nodes[20].output[0] = v_grad
nodes_to_add, new_value_infos = _make_flash_attention_nodes(
idx,
trans_q_tensor.name,
trans_k_tensor.name,
nodes[6].input[0],
nodes[9].output[0],
nodes[19].input[0],
trans_q_grad_tensor.name,
trans_k_grad_tensor.name,
nodes[18].output[0],
nodes[3].input[1],
1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value),
)
nodes_to_remove = nodes[:6] + nodes[9:20]
nodes_to_add.extend([trans_q, trans_q_grad, trans_k, trans_k_grad])
new_value_infos.extend([trans_q_tensor, trans_q_grad_tensor, trans_k_tensor, trans_k_grad_tensor])
return nodes_to_remove, nodes_to_add, new_value_infos


# TODO: add pattern to support attention with causal mask, such as GPT2 in HuggingFace.
_PATTERNS = [
-(_PATTERN_0, _apply_transform_for_pattern_0),
+(_PATTERN_0, _optimize_for_pattern_0),
+(_PATTERN_1, _optimize_for_pattern_1),
+(_PATTERN_2, _optimize_for_pattern_2),
]


-@register_graph_transformer(devices="cuda")
-def transform_flash_attention(graph: GraphProto):
+@register_graph_optimizer(devices="cuda")
+def optimize_graph_for_flash_attention(graph: GraphProto):
nodes_to_remove = []
nodes_to_add = []
new_value_infos = []
@@ -11,7 +11,7 @@
import triton.language as tl
from onnx import TensorProto, helper

-from onnxruntime.training.ortmodule import register_graph_transformer
+from onnxruntime.training.ortmodule import register_graph_optimizer

from .._utils import get_attribute, to_numpy_array

@@ -246,8 +246,8 @@ def _get_shape_related_nodes(graph, start_arg, sub_graph_nodes):
args.append(output)


-@register_graph_transformer(devices="cuda")
-def transform_slice_scel(graph):
+@register_graph_optimizer(devices="cuda")
+def optimize_graph_for_slice_scel(graph):
remove_nodes = []
triton_nodes = []
value_infos = []
4 changes: 2 additions & 2 deletions orttraining/orttraining/python/training/ortmodule/__init__.py
@@ -124,9 +124,9 @@ def _are_deterministic_algorithms_enabled():
return ORTMODULE_IS_DETERMINISTIC


-from .graph_transformer_registry import register_graph_transformer # noqa: E402, F401
+from .graph_optimizer_registry import register_graph_optimizer # noqa: E402, F401
+from .graph_optimizers import * # noqa: E402, F403
from .options import DebugOptions, LogLevel # noqa: E402, F401

# ORTModule must be loaded only after all validation passes
from .ortmodule import ORTModule # noqa: E402, F401
-from .transformers import * # noqa: F403
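
The registration API keeps its decorator shape under the new name. A minimal sketch of a custom optimizer (hypothetical fusion name; API assumed from the usage in this commit):

from onnx import GraphProto

from onnxruntime.training.ortmodule import register_graph_optimizer


@register_graph_optimizer(devices="cuda")
def optimize_graph_for_my_fusion(graph: GraphProto):
    # Rewrite `graph` in place, e.g. match a subgraph and swap in custom
    # Triton kernels, as the flash-attention optimizer above does.
    ...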
@@ -21,7 +21,7 @@
from ._logger import ORTModuleInitPhase, SuppressLogs, TrackTime
from ._runtime_inspector import Phase
from ._utils import save_tuning_results, set_tuning_results
-from .graph_transformer_registry import GraphTransformerRegistry
+from .graph_optimizer_registry import GraphOptimizerRegistry
from .options import DebugOptions, _SkipCheck


@@ -369,7 +369,7 @@ def _build_graph(self, graph_transformer_config):
device_type = self._device.type
if device_type == "cuda" and self.is_rocm_pytorch:
device_type = "rocm"
-GraphTransformerRegistry.transform_all(
+GraphOptimizerRegistry.optimize_all(
type(self._flattened_module._original_module).__name__, device_type, self._onnx_models.optimized_model.graph
)
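        # Registered optimizers for the resolved device type are applied here
        # to the optimized ONNX training graph; the original module's class
        # name is passed along, presumably so optimizers can target specific
        # model types.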
