[inductor] support vec for atomic add (pytorch#131314)
Depends on pytorch#130827 to have the correct `index_expr` dtype.

Supports vectorization for atomic add via a scalar fallback implementation.

Test plan:

```
python test/inductor/test_cpu_repro.py -k test_scatter_using_atomic_add_vec
```

Generated code for `test_scatter_using_atomic_add_vec`:

```
cpp_fused_scatter_0 = async_compile.cpp_pybinding(['const float*', 'const int64_t*', 'const float*', 'float*'], '''
#include "/tmp/torchinductor_root/nn/cnnpkaxivwaa5rzng6qsyc4ao42vschogi3yk33ukwv3emlvxeqq.h"
extern "C" void kernel(const float* in_ptr0,
                       const int64_t* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0)
{
    {
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(16L); x0+=static_cast<long>(16L))
        {
            auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x0), 16);
            tmp0.store(out_ptr0 + static_cast<long>(x0));
        }
        #pragma omp simd simdlen(8)
        for(long x0=static_cast<long>(16L); x0<static_cast<long>(25L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = in_ptr0[static_cast<long>(x0)];
            out_ptr0[static_cast<long>(x0)] = tmp0;
        }
    }
    {
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(16L); x0+=static_cast<long>(16L))
        {
            auto tmp0 = at::vec::VectorizedN<int64_t,2>::loadu(in_ptr1 + static_cast<long>(x0), 16);
            auto tmp12 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x0), 16);
            auto tmp1 = 25L;
            auto tmp2 = c10::convert<int64_t>(tmp1);
            auto tmp3 = at::vec::VectorizedN<int64_t,2>(tmp2);
            auto tmp4 = tmp0 + tmp3;
            auto tmp5 = static_cast<int64_t>(0);
            auto tmp6 = at::vec::VectorizedN<int64_t,2>(tmp5);
            auto tmp7 = at::vec::VecMask<int64_t,2>(tmp0 < tmp6);
            auto tmp8 = decltype(tmp4)::blendv(tmp0, tmp4, tmp7.template cast<int64_t,2>());
            auto tmp9 = [&]
            {
                __at_align__ std::array<int64_t, 16> tmpbuf;
                tmp8.store(tmpbuf.data());
                return tmpbuf;
            }
            ();
            auto tmp10 = [&]
            {
                __at_align__ std::array<int64_t, 16> tmpbuf;
                #pragma GCC unroll 16
                for (long x0_inner = 0; x0_inner < 16; x0_inner++)
                {
                    tmpbuf[x0_inner] = static_cast<long>(tmp9[x0_inner]);
                }
                return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);
            }
            ();
            TORCH_CHECK((at::vec::VecMask<int64_t,2>((at::vec::VectorizedN<int64_t,2>(0) <= tmp10) & (tmp10 < at::vec::VectorizedN<int64_t,2>(25L)))).all_masked(), "index out of bounds: 0 <= tmp10 < 25L");
            atomic_add_vec(out_ptr0, tmp8, tmp12);
        }
        #pragma omp simd simdlen(8)
        for(long x0=static_cast<long>(16L); x0<static_cast<long>(20L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = in_ptr1[static_cast<long>(x0)];
            auto tmp9 = in_ptr2[static_cast<long>(x0)];
            auto tmp1 = 25L;
            auto tmp2 = c10::convert<int64_t>(tmp1);
            auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
            auto tmp4 = tmp0 < 0;
            auto tmp5 = tmp4 ? tmp3 : tmp0;
            auto tmp6 = tmp5;
            auto tmp7 = c10::convert<int64_t>(tmp6);
            TORCH_CHECK((0 <= tmp7) & (tmp7 < 25L), "index out of bounds: 0 <= tmp7 < 25L");
            atomic_add(&out_ptr0[static_cast<long>(tmp5)], static_cast<float>(tmp9));
        }
    }
}
''')
```

Pull Request resolved: pytorch#131314
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel
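For context, the "scalar fallback implementation" means the vectorized path does not require a hardware scatter-add: `atomic_add_vec` can spill the index and value lanes to aligned stack buffers and issue one scalar `atomic_add` per lane. Below is a minimal sketch of that idea, written against the `at::vec` types seen in the generated code; the exact helper in Inductor's C++ prelude header may differ in signature and detail.

```
#include <array>
#include <atomic>
#include <ATen/cpu/vec/vec.h>

// Scalar atomic add on one element (fetch_add on std::atomic<float>
// requires C++20; a compare-exchange loop is the portable alternative).
template <typename T>
void atomic_add(volatile T* addr, T offset) {
    std::atomic<T>* atomic_addr = (std::atomic<T>*)addr;
    atomic_addr->fetch_add(offset, std::memory_order_relaxed);
}

// Vectorized atomic add by scalar fallback: store the index and value
// lanes to aligned stack buffers, then do one scalar atomic_add per lane.
template <typename T, int N>
void atomic_add_vec(T* addr,
                    at::vec::VectorizedN<int64_t, N> index,
                    at::vec::Vectorized<T> offset) {
    constexpr int len = at::vec::Vectorized<T>::size();
    static_assert(len == at::vec::VectorizedN<int64_t, N>::size());
    __at_align__ std::array<int64_t, len> tmpidx;
    __at_align__ std::array<T, len> tmpval;
    index.store(tmpidx.data());
    offset.store(tmpval.data());
    for (int i = 0; i < len; i++) {
        atomic_add(addr + tmpidx[i], tmpval[i]);
    }
}
```

In the generated kernel, the main loop takes this vectorized path while the `#pragma omp simd` tail loop calls the scalar `atomic_add` directly, so both loops synchronize through the same primitive.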