[inductor] support vec for atomic add (pytorch#131314)
Depends on pytorch#130827 for the correct `index_expr` dtype.

Support vectorization for atomic add via a scalar implementation: the vector of indices and the vector of values are stored to aligned temporary buffers, and `atomic_add` is applied lane by lane (see the new `atomic_add_vec` helper in `cpp_prefix.h`).
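The op being vectorized accumulates into possibly repeated destination indices, which is why the store must be an atomic add once the kernel runs under multiple threads. A minimal eager-mode sketch of the semantics, reusing the inputs from the test below:
```
import torch

a = torch.zeros(1, 1, 25)
index = torch.tensor([[[3, 5, 7, 9] * 5]])  # each destination index repeats 5 times
src = torch.ones(1, 1, 25)

# scatter with reduce="add" along dim=2: out[i][j][index[i][j][k]] += src[i][j][k]
out = torch.ops.aten.scatter(a, 2, index, src, reduce="add")
print(out[0, 0, 3].item())  # 5.0 - five contributions land on the same slot
```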
Test plan:
```
python test/inductor/test_cpu_repro.py -k test_scatter_using_atomic_add_vec
```
Generated code for `test_scatter_using_atomic_add_vec`:
```
cpp_fused_scatter_0 = async_compile.cpp_pybinding(['const float*', 'const int64_t*', 'const float*', 'float*'], '''
#include "/tmp/torchinductor_root/nn/cnnpkaxivwaa5rzng6qsyc4ao42vschogi3yk33ukwv3emlvxeqq.h"
extern "C"  void kernel(const float* in_ptr0,
                       const int64_t* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0)
{
    {
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(16L); x0+=static_cast<long>(16L))
        {
            auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x0), 16);
            tmp0.store(out_ptr0 + static_cast<long>(x0));
        }
        #pragma omp simd simdlen(8)
        for(long x0=static_cast<long>(16L); x0<static_cast<long>(25L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = in_ptr0[static_cast<long>(x0)];
            out_ptr0[static_cast<long>(x0)] = tmp0;
        }
    }
    {
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(16L); x0+=static_cast<long>(16L))
        {
            auto tmp0 = at::vec::VectorizedN<int64_t,2>::loadu(in_ptr1 + static_cast<long>(x0), 16);
            auto tmp12 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x0), 16);
            auto tmp1 = 25L;
            auto tmp2 = c10::convert<int64_t>(tmp1);
            auto tmp3 = at::vec::VectorizedN<int64_t,2>(tmp2);
            auto tmp4 = tmp0 + tmp3;
            auto tmp5 = static_cast<int64_t>(0);
            auto tmp6 = at::vec::VectorizedN<int64_t,2>(tmp5);
            auto tmp7 = at::vec::VecMask<int64_t,2>(tmp0 < tmp6);
            auto tmp8 = decltype(tmp4)::blendv(tmp0, tmp4, tmp7.template cast<int64_t,2>());
            auto tmp9 =
            [&]
            {
                __at_align__ std::array<int64_t, 16> tmpbuf;
                tmp8.store(tmpbuf.data());
                return tmpbuf;
            }
            ()
            ;
            auto tmp10 =
            [&]
            {
                __at_align__ std::array<int64_t, 16> tmpbuf;
                #pragma GCC unroll 16
                for (long x0_inner = 0; x0_inner < 16; x0_inner++)
                {
                    tmpbuf[x0_inner] = static_cast<long>(tmp9[x0_inner]);
                }
                return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
            }
            ()
            ;
            TORCH_CHECK((at::vec::VecMask<int64_t,2>((at::vec::VectorizedN<int64_t,2>(0) <= tmp10) & (tmp10 < at::vec::VectorizedN<int64_t,2>(25L)))).all_masked(), "index out of bounds: 0 <= tmp10 < 25L");
            atomic_add_vec(out_ptr0, tmp8, tmp12);
        }
        #pragma omp simd simdlen(8)
        for(long x0=static_cast<long>(16L); x0<static_cast<long>(20L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = in_ptr1[static_cast<long>(x0)];
            auto tmp9 = in_ptr2[static_cast<long>(x0)];
            auto tmp1 = 25L;
            auto tmp2 = c10::convert<int64_t>(tmp1);
            auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
            auto tmp4 = tmp0 < 0;
            auto tmp5 = tmp4 ? tmp3 : tmp0;
            auto tmp6 = tmp5;
            auto tmp7 = c10::convert<int64_t>(tmp6);
            TORCH_CHECK((0 <= tmp7) & (tmp7 < 25L), "index out of bounds: 0 <= tmp7 < 25L");
            atomic_add(&out_ptr0[static_cast<long>(tmp5)], static_cast<float>(tmp9));
        }
    }
}
''')
```

Pull Request resolved: pytorch#131314
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel
zhuhaozhe authored and pytorchmergebot committed Aug 26, 2024
1 parent bf5c7bf commit 1ff226d
Showing 3 changed files with 73 additions and 10 deletions.
26 changes: 26 additions & 0 deletions test/inductor/test_cpu_repro.py
@@ -1972,6 +1972,32 @@ def _internal_check(
with config.patch({"cpp.dynamic_threads": True}), set_num_threads(1):
_internal_check(fn, inps, "aten.scatter_reduce_")

@patch("torch.cuda.is_available", lambda: False)
@requires_vectorization
@torch._inductor.config.patch({"cpp.fallback_scatter_reduce_sum": False})
def test_scatter_using_atomic_add_vec(self):
def fn(a, dim, index, b):
return aten.scatter(a, dim, index, b, reduce="add")

inps = (
torch.zeros(1, 1, 25),
2,
torch.tensor([[[3, 5, 7, 9] * 5]]),
torch.ones(1, 1, 25),
)
torch._dynamo.reset()
metrics.reset()
self.common(fn, inps)
assert metrics.generated_cpp_vec_kernel_count == 2

with set_num_threads(1), config.patch(
{"fx_graph_cache": False, "fx_graph_remote_cache": False}
):
torch._dynamo.reset()
metrics.reset()
self.common(fn, inps)
assert metrics.generated_cpp_vec_kernel_count == 2

@unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
@requires_vectorization
@patch("torch.cuda.is_available", lambda: False)
42 changes: 32 additions & 10 deletions torch/_inductor/codegen/cpp.py
@@ -2157,6 +2157,7 @@ def _load_or_store_non_contiguous(
dtype: torch.dtype,
buffer: Optional[IndentedBuffer] = None,
store_value: Optional[Union[str, CppCSEVariable]] = None,
accu_store: bool = False,
) -> Optional[CppCSEVariable]:
"""
Load or store a vector in a non-contiguous way. The vector is initialized from an array that is
@@ -2171,10 +2172,12 @@ def _load_or_store_non_contiguous(
:param dtype: data type of `var` or `index` if `var` is None.
:param buffer: the code buffer to write the generated code to. If None, we write to `self.loads`.
:param store_value: the value to store. If None, we load the vector.
:param accu_store: whether to accumulate `store_value` into the stored location. If True, a `store_value` must be provided.
:return: a CppCSEVariable that represents the loaded vector or None if it is a store.
"""
assert not store_value or var is not None, "store var must be provided"

if accu_store:
assert store_value
if buffer is None:
buffer = self.loads

@@ -2261,7 +2264,8 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable:
code.writeline(f"if ({load_mask})")
stack.enter_context(code.indent())
if store_value:
code.writeline(f"{rhs} = tmpbuf[{itervar_inner}];")
conjunction = "+=" if accu_store else "="
code.writeline(f"{rhs} {conjunction} tmpbuf[{itervar_inner}];")
else:
code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};")
if not store_value:
@@ -2304,6 +2308,7 @@ def _get_store_line(
var: str,
index: sympy.Expr,
dtype: torch.dtype,
accu_store: bool = False,
):
"""
Get a store line buffer that stores `value` into `var` at `index` of `dtype`. It handles
@@ -2328,21 +2333,42 @@
code.writeline(f"{value}.store({var_expr}, {self.num_elems});")
else:
self._load_or_store_non_contiguous(
var, index, dtype, buffer=code, store_value=value
var, index, dtype, buffer=code, store_value=value, accu_store=accu_store
)
return code

def store(self, name, index, value, mode=None):
assert "buf" in name
assert mode is None
assert isinstance(value, CppCSEVariable), value
if not value.is_vec:
# this happens when we store a scalar into a vectorized buffer like "fill"
value = self.broadcast(value)
var = self.args.output(name)
index = self.rename_indexing(index)
code = self._get_store_line(value, var, index, V.graph.get_dtype(name))
self.stores.splice(code.map(lambda x: DeferredLine(name, x)))
dtype = V.graph.get_dtype(name)
if mode is None:
code = self._get_store_line(value, var, index, dtype)
self.stores.splice(code.map(lambda x: DeferredLine(name, x)))
elif mode == "atomic_add":
if not config.cpp.dynamic_threads and self.num_threads == 1:
code = self._get_store_line(
f"{value}",
var,
index,
dtype,
accu_store=True,
)
self.stores.splice(code.map(lambda x: DeferredLine(name, x)))
else:
n_src = self._get_num_vectors(dtype)
n_idx = self._get_num_vectors(torch.int64)
cdtype = DTYPE_TO_CPP[dtype]
index = ops.index_expr(index, torch.int64).value
assert index.is_vec
line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value});"
self.stores.writeline(DeferredLine(name, line))
else:
raise NotImplementedError(f"store mode={mode}")

def reduction(self, dtype, src_dtype, reduction_type, value):
assert reduction_type in VECTORIZABLE_RTYPES
@@ -3081,10 +3107,6 @@ def store(self, name, index, value, mode=None):
assert "buf" in name
index = self.rename_indexing(index)

if mode:
self.disable_vec(f"store mode: {mode}")
return self.simd_vec

return self.simd_vec

def reduction(self, dtype, src_dtype, reduction_type, value):
15 changes: 15 additions & 0 deletions torch/_inductor/codegen/cpp_prefix.h
@@ -628,6 +628,21 @@ atomic_add(volatile T *addr, T offset) {
atomic_addr->fetch_add(offset, std::memory_order_relaxed);
}

#if INDUCTOR_USE_VECTOR_TYPES()
template <typename T, int NI, int NV>
void atomic_add_vec(T *addr, at::vec::VectorizedN<int64_t, NI> index, at::vec::VectorizedN<T, NV> offset) {
constexpr int len = at::vec::VectorizedN<int64_t, NI>::size();
static_assert(len <= at::vec::VectorizedN<T, NV>::size());
__at_align__ std::array<T, len> tmpbuf;
__at_align__ std::array<int64_t, len> tmpidx;
offset.store(tmpbuf.data());
index.store(tmpidx.data());
for (int i = 0; i < len; i++){
atomic_add(addr + tmpidx[i], tmpbuf[i]);
}
}
#endif

void mm_get_thread_blocking(
int num_threads,
int64_t M,
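For reference, a hedged end-to-end sketch (not part of this commit) that exercises the vectorized atomic-add path through `torch.compile` on CPU; it mirrors the added test and assumes a build where Inductor's C++ vectorization is available:
```
import torch
from torch._inductor import config as inductor_config

def fn(a, index, src):
    return torch.ops.aten.scatter(a, 2, index, src, reduce="add")

a = torch.zeros(1, 1, 25)
index = torch.tensor([[[3, 5, 7, 9] * 5]])
src = torch.ones(1, 1, 25)

# As in the test, disable the scatter_reduce fallback so Inductor generates its
# own C++ kernel; with multiple threads the vectorized store uses atomic_add_vec.
with inductor_config.patch({"cpp.fallback_scatter_reduce_sum": False}):
    compiled = torch.compile(fn)
    torch.testing.assert_close(compiled(a, index, src), fn(a, index, src))
```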
