[inductor][cpp][gemm] move bias add to epilogue (pytorch#130675)
Speed up the bias-add computation by moving it into the GEMM epilogue. Performance numbers were measured on an Intel(R) Xeon(R) CPU Max 9480, single core, bf16.
Before
AUTOTUNE linear_unary(512x768, 3072x768, 3072)
  cpp_packed_gemm_0 1.9200 ms 100.0%
  _linear_pointwise 1.9345 ms 99.3%

After
AUTOTUNE linear_unary(512x768, 3072x768, 3072)
  cpp_packed_gemm_0 1.8321 ms 100.0%
  _linear_pointwise 1.9246 ms 95.2%
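
A rough reproduction sketch for the numbers above (not part of the commit; the module, shapes, and config knobs are assumptions inferred from the AUTOTUNE line: a bf16 Linear(768, 3072) with bias on a 512x768 input, autotuned with the C++ GEMM template enabled):

import torch
import torch._inductor.config as inductor_config

# Assumed settings: autotune GEMMs and allow the C++ GEMM template backend.
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CPP,ATEN"

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # bias=True is what exercises the bias-add epilogue this PR moves
        self.linear = torch.nn.Linear(768, 3072)

    def forward(self, x):
        # a pointwise op after the matmul is what produces linear_unary
        return torch.nn.functional.relu(self.linear(x))

m = M().eval().to(torch.bfloat16)
x = torch.randn(512, 768, dtype=torch.bfloat16)
with torch.no_grad():
    y = torch.compile(m)(x)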

Pull Request resolved: pytorch#130675
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
jgong5 authored and pytorchmergebot committed Jul 19, 2024
1 parent 5a6a806 commit 39493aa
Showing 3 changed files with 88 additions and 47 deletions.
102 changes: 73 additions & 29 deletions torch/_inductor/codegen/cpp_gemm_template.py
@@ -1,5 +1,7 @@
# mypy: allow-untyped-defs
from typing import Any, Callable, cast, List, Optional, Union
import contextlib
from typing import Any, Callable, cast, List, Optional, Set, Union
from unittest.mock import patch

import torch
import torch.utils
@@ -9,7 +11,7 @@
from ..kernel.mm_common import mm_args
from ..select_algorithm import DataProcessorTemplateWrapper
from ..utils import cache_on_self, has_free_symbols, parallel_num_threads
from ..virtualized import V
from ..virtualized import ops, V
from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType
from .cpp_template import CppTemplate

@@ -28,7 +30,7 @@
{%- endif %}
extern "C"
{{kernel.def_kernel(inputs=kernel_args, outputs={"Y": Y}, aliases=buffer_aliases)}}
{{kernel.def_kernel(inputs=kernel_args, outputs={"Y": Y}, aliases=aliases)}}
{
{{kernel.maybe_codegen_profile()}}
constexpr int64_t num_threads = {{num_threads}};
@@ -107,30 +109,17 @@
{%- else %}
{%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
{%- endif %}
{%- if inp is not none and beta != 0 and x_scale is none %}
// For int8, bias should add after convert Y to FP32
for (int64_t m = 0; m < m_size; ++m) {
#pragma omp simd
for (int64_t n = 0; n < n_size; ++n) {
{{kernel.index(acc, ["m", "n"])}} = {{beta}} * {{kernel.index(inp, ["m + m_start", "n + n_start"])}};
}
}
{%- endif %}
for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
int64_t k_start = kc * K0;
int64_t k_end = std::min((kc + Kc_blocks) * K0, K);
{%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %}
{%- set tile_W_3d = kernel.slice_nd(W, [("nc", "nc + 1"), ("k_start", "k_end"), ()]) %}
{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
{%- if inp is not none and beta != 0 and x_scale is none %}
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True)|indent(20, false) }}
{%- else %}
if (kc == k_block_start) {
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=False)|indent(24, false) }}
} else {
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True)|indent(24, false) }}
}
{%- endif %}
}
{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
{{ kernel.store_output(
@@ -447,7 +436,7 @@ def postprocessor(output):
template.maybe_append_choice(choices)
return template

def render( # type: ignore[override]
def render( # type: ignore[override,return]
self,
kernel: CppTemplateKernel,
template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
@@ -486,17 +475,68 @@ def render( # type: ignore[override]

epilogues: List[ir.IRNode] = []
reindexers: List[Optional[Callable[[List[Any]], List[Any]]]] = []
epilogue_creators: List[Callable[[ir.Buffer], ir.Pointwise]] = []
fake_buffers: List[ir.Buffer] = []
Y_aliases: Set[str] = set()
# TODO(jgong5): for int8 gemm, bias-add is handled outside of gemm template,
# but we'd better move it here to align with fp.
if inp is not None and self.beta != 0 and not int8_gemm:

def bias_epilogue(input_buffer: ir.Buffer):
dtype = self.layout.dtype
bias_loader = inp.make_loader()
input_loader = input_buffer.make_loader()

def bias_add_inner(index):
bias = bias_loader(index)
input = input_loader(index)
if self.beta != 1:
result = ops.constant(self.beta, torch.float) * bias + input
else:
result = bias + input
return result

return ir.Pointwise(
device=input_buffer.get_device(),
dtype=dtype,
inner_fn=bias_add_inner,
ranges=input_buffer.get_size(),
)

epilogue_creators.append(bias_epilogue)

if self.epilogue_creator is not None:
gemm_output_name = "GemmOut"
epilogue_creators.append(self.epilogue_creator)

# NOTE [How CPP GEMM template epilogues are organized]
# gemm_output_buffer
# --> zero or more in-template epilogues (created by `epilogue_creators`) -->
# template_buffer
# --> zero or more out-of-template epilogues (`epilogue_nodes`) -->
# Y
if epilogue_creators:
gemm_output_name = "buf_GemmOut"
gemm_output_buffer = ir.Buffer(gemm_output_name, template_buffer.layout)
epilogues.append(
ir.ComputedBuffer(
name=template_buffer.get_name(),
layout=template_buffer.layout,
data=self.epilogue_creator(gemm_output_buffer),
current_input_buffer = gemm_output_buffer
for i, creator in enumerate(epilogue_creators):
if i == len(epilogue_creators) - 1:
buffer_name = template_buffer.get_name()
else:
buffer_name = f"buf_GemmOut_epilogue_{i}"
epilogues.append(
ir.ComputedBuffer(
name=buffer_name,
layout=template_buffer.layout,
data=creator(current_input_buffer),
)
)
)
reindexers.append(None)
fake_buffers.append(current_input_buffer)
Y_aliases.add(current_input_buffer.get_name())
reindexers.append(None)
if i < len(epilogue_creators) - 1:
current_input_buffer = ir.Buffer(
buffer_name, template_buffer.layout
)

Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y
use_local_acc = self.layout.dtype != torch.float or int8_gemm
@@ -505,6 +545,7 @@ def render( # type: ignore[override]
epilogues.extend(epilogue_nodes)
assert Y.get_numel() == epilogues[-1].get_numel()
Y = cast(ir.Buffer, epilogues[-1])
Y_aliases.add(template_buffer.get_name())
if (
Y.get_size() == template_buffer.get_size()
and Y.get_stride() == template_buffer.get_stride()
@@ -555,9 +596,7 @@ def render( # type: ignore[override]
inp=inp,
Y=Y,
GemmOut=gemm_output_buffer,
buffer_aliases=[(gemm_output_buffer, Y)]
if gemm_output_buffer is not Y
else None,
aliases={alias: Y.get_name() for alias in Y_aliases},
beta=self.beta,
alpha=self.alpha,
num_threads=self.num_threads,
@@ -576,4 +615,9 @@ def render( # type: ignore[override]
w_zp=w_zp,
acc_buf_dtype=torch.int32 if int8_gemm else torch.float,
)
return self._template_from_string(GEMM_TEMPLATE).render(**options)
with contextlib.ExitStack() as stack:
for buf in fake_buffers:
stack.enter_context(
patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
)
return self._template_from_string(GEMM_TEMPLATE).render(**options)
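
For reference, the in-template epilogue chain introduced above computes roughly the following (a plain-PyTorch sketch of the intended semantics, not code from the commit; the function name and the ReLU epilogue are illustrative): the micro-GEMM writes the raw accumulation into GemmOut, then the beta-scaled bias-add and any user-supplied epilogue are applied as pointwise stages before the result is stored to Y.

import torch

def gemm_with_epilogues(X, W, inp=None, beta=1.0, alpha=1.0, epilogues=()):
    acc = alpha * (X @ W.t())          # GemmOut: raw accumulation, no bias
    if inp is not None and beta != 0:
        acc = beta * inp + acc         # bias_epilogue: first in-template epilogue
    for ep in epilogues:               # e.g. the activation from self.epilogue_creator
        acc = ep(acc)
    return acc                         # stored into Y / Y_2d

# Shapes matching linear_unary(512x768, 3072x768, 3072), with a ReLU epilogue:
X, W, b = torch.randn(512, 768), torch.randn(3072, 768), torch.randn(3072)
Y = gemm_with_epilogues(X, W, inp=b, epilogues=(torch.relu,))
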
29 changes: 11 additions & 18 deletions torch/_inductor/codegen/cpp_template_kernel.py
@@ -50,25 +50,19 @@ def def_kernel(
self,
inputs: Dict[str, ir.Buffer],
outputs: Dict[str, ir.Buffer],
aliases: Optional[List[Tuple[ir.Buffer, ir.Buffer]]] = None,
aliases: Optional[Dict[str, str]] = None,
) -> str:
for name, inp in inputs.items():
if inp is not None:
self.args.input_buffers[inp.get_name()] = name
for name, out in outputs.items():
self.args.output_buffers[out.get_name()] = name
if aliases is not None:
for alias, orig in aliases:
orig_name = orig.get_name()
alias_name = alias.get_name()
if orig_name in self.args.input_buffers:
self.args.input_buffers[alias_name] = self.args.input_buffers[
orig_name
]
if orig_name in self.args.output_buffers:
self.args.output_buffers[alias_name] = self.args.output_buffers[
orig_name
]
for alias, orig in aliases.items():
if orig in self.args.input_buffers:
self.args.input_buffers[alias] = self.args.input_buffers[orig]
if orig in self.args.output_buffers:
self.args.output_buffers[alias] = self.args.output_buffers[orig]

unique_sizevars = {
s
@@ -92,12 +86,11 @@ def def_kernel
def hook():
# remove all aliases before generate function definition
if aliases is not None:
for alias, _ in aliases:
alias_name = alias.get_name()
if alias_name in self.args.input_buffers:
self.args.input_buffers[alias_name] = "REMOVED"
if alias_name in self.args.output_buffers:
self.args.output_buffers[alias_name] = "REMOVED"
for alias in aliases:
if alias in self.args.input_buffers:
self.args.input_buffers[alias] = "REMOVED"
if alias in self.args.output_buffers:
self.args.output_buffers[alias] = "REMOVED"
cpp_argdefs, _, _ = self.args.cpp_argdefs()
return f"void {self.kernel_name}({', '.join(cpp_argdefs)})"

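A hypothetical shape of the new dict-based aliases argument (buffer names are made up): each alias, e.g. an intermediate epilogue buffer, maps to the real output buffer's name, so all of them resolve to the same C++ kernel argument and are dropped from the generated signature by the hook above.

# Hypothetical aliases mapping passed to def_kernel (names are illustrative):
aliases = {
    "buf_GemmOut": "buf0",
    "buf_GemmOut_epilogue_0": "buf0",
}
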
4 changes: 4 additions & 0 deletions torch/_inductor/scheduler.py
@@ -2651,6 +2651,10 @@ def remove_kernel_local_buffers(self) -> None:
fused_node_names = V.kernel.store_buffer_names
names_to_remove = []
for out_buf in V.kernel.store_buffer_names:
if out_buf not in self.name_to_node:
# Aux buffers created during kernel codegen
names_to_remove.append(out_buf)
continue
users = self.name_to_node[out_buf].users
assert users is not None
users = {user.get_name() for user in users if not user.is_weak}
