Fix atomic
wschin committed Jan 26, 2024
1 parent c6ec60f commit df5b2e2
Showing 3 changed files with 153 additions and 16 deletions.
59 changes: 59 additions & 0 deletions onnxruntime/core/providers/cuda/atomic/common.cuh
@@ -25,6 +25,16 @@
namespace onnxruntime {
namespace cuda {

__device__ __forceinline__ void atomic_add(int8_t* address, int8_t val) {
  // CUDA has no native 8-bit atomicCAS, and casting a possibly unaligned
  // int8_t* to int* is undefined; instead, CAS the aligned 32-bit word that
  // contains the byte and update only that byte, using the same
  // byte-in-word pattern as atomic_mul(int8_t*) below.
  size_t offset = (size_t)address & 3;
  uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
  uint32_t shift = offset * 8;
  uint32_t old = *address_as_ui;
  uint32_t assumed;
  do {
    assumed = old;
    uint32_t old_byte = (old >> shift) & 0xff;
    uint32_t new_byte = static_cast<uint8_t>(val + static_cast<int8_t>(old_byte));
    uint32_t newval = (old & ~(0x000000ff << shift)) | (new_byte << shift);
    old = atomicCAS(address_as_ui, assumed, newval);
  } while (assumed != old);
}
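
Since CUDA exposes no native 8-bit atomics, the overload above emulates one with a word-wide CAS. A minimal hedged smoke test, not part of the commit (kernel name and launch are illustrative): 256 threads each add 1, and int8_t wraps modulo 256, so the cell must end exactly where it started.

__global__ void Int8AtomicAddSmoke(int8_t* cell) {
  // All 256 threads contend on the same byte; the byte-in-word CAS makes
  // every increment land exactly once.
  onnxruntime::cuda::atomic_add(cell, static_cast<int8_t>(1));
}
// Example launch: Int8AtomicAddSmoke<<<1, 256>>>(device_cell);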

__device__ __forceinline__ void atomic_add(float *address, float value) {
atomicAdd(address, value);
}
@@ -122,5 +132,54 @@ __device__ __forceinline__ void AtomicAdd<half>(half* start_addr, size_t index,
#endif
}

__device__ __forceinline__ void atomic_mul(half* address, half val) {
  unsigned short int* address_as_short = reinterpret_cast<unsigned short int*>(address);
  unsigned short int old = *address_as_short, assumed;
  do {
    assumed = old;
    old = atomicCAS(address_as_short, assumed,
                    __half_as_short(val * __short_as_half(assumed)));
  } while (assumed != old);
}

__device__ __forceinline__ void atomic_mul(float* address, float val) {
  int* address_as_int = reinterpret_cast<int*>(address);
  int old = *address_as_int, assumed;
  do {
    assumed = old;
    old = atomicCAS(address_as_int, assumed,
                    __float_as_int(val * __int_as_float(assumed)));
  } while (assumed != old);
}

__device__ __forceinline__ void atomic_mul(double* address, double val) {
  unsigned long long int* address_as_long_long = reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *address_as_long_long, assumed;
  do {
    assumed = old;
    old = atomicCAS(address_as_long_long, assumed,
                    __double_as_longlong(val * __longlong_as_double(assumed)));
  } while (assumed != old);
}
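
All three overloads above implement the same lock-free pattern: snapshot the destination bits, compute the product against the snapshot, and publish it with atomicCAS, retrying until no other thread intervened between snapshot and publish. (Note that the 16-bit atomicCAS used by the half overload requires compute capability 7.0 or higher.) A hedged generic sketch of the pattern for float, with a caller-supplied functor; the template itself is not part of the commit:

template <typename BinOp>
__device__ __forceinline__ void atomic_update(float* address, float val, BinOp op) {
  int* address_as_int = reinterpret_cast<int*>(address);
  int old = *address_as_int, assumed;
  do {
    assumed = old;  // snapshot the current bit pattern
    // Recompute against the snapshot and try to publish. atomicCAS returns
    // the word it actually found, so a mismatch means another thread won
    // the race and we must retry with the fresh value.
    old = atomicCAS(address_as_int, assumed,
                    __float_as_int(op(__int_as_float(assumed), val)));
  } while (assumed != old);
}

struct MulOp {
  __device__ float operator()(float a, float b) const { return a * b; }
};
// atomic_update(addr, v, MulOp{}) then behaves like atomic_mul(addr, v).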

// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh.
__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t val) {
  size_t offset = (size_t)address & 3;
  uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
  uint32_t old = *address_as_ui;
  uint32_t shift = offset * 8;
  uint32_t old_byte;
  uint32_t newval;
  uint32_t assumed;

  do {
    assumed = old;
    old_byte = (old >> shift) & 0xff;
    newval = static_cast<uint8_t>(val * static_cast<int8_t>(old_byte));
    newval = (old & ~(0x000000ff << shift)) | (newval << shift);
    old = atomicCAS(address_as_ui, assumed, newval);
  } while (assumed != old);
}
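
To make the byte masking concrete: for an address with offset 2 inside its word, shift is 16, so the loop replaces bits 16 through 23 and leaves the other three bytes untouched. A hedged host-side reference of the word transform (plain C++, a hypothetical helper for unit-testing the bit manipulation, not in the commit):

#include <cstdint>

// Serial reference: multiply the byte at `offset` inside `word` by `val`,
// preserving the other three bytes; this is the same transform the CAS loop
// above publishes atomically on the GPU.
inline uint32_t MulByteInWord(uint32_t word, size_t offset, int8_t val) {
  uint32_t shift = static_cast<uint32_t>(offset) * 8;
  uint32_t old_byte = (word >> shift) & 0xff;
  uint32_t new_byte = static_cast<uint8_t>(val * static_cast<int8_t>(old_byte));
  return (word & ~(0xffu << shift)) | (new_byte << shift);
}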

} // namespace cuda
} // namespace onnxruntime
@@ -100,12 +100,12 @@ struct FuncAssignment {

template <class T>
struct FuncAdd {
-  __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] += value; }
+  __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { atomic_add(start_addr + index, value); }
};

template <class T>
struct FuncMul {
-  __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] *= value; }
+  __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { atomic_mul(start_addr + index, value); }
};

template <class T>
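Why this functor change matters: ScatterElements with a reduction may map several indices to the same output element, and the CUDA kernel assigns one thread per update, so the previous plain `+=` / `*=` read-modify-write could race with itself. A hedged sketch of the call shape (kernel and parameter names are hypothetical, for illustration only):

template <typename T, typename Func>
__global__ void ApplyScatterUpdates(T* output, const int64_t* flat_offsets,
                                    const T* updates, size_t n, Func func) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // One thread per update; with FuncAdd/FuncMul now backed by
    // atomic_add/atomic_mul, colliding offsets accumulate correctly.
    func(output, static_cast<size_t>(flat_offsets[i]), updates[i]);
  }
}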
106 changes: 92 additions & 14 deletions onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc
@@ -75,20 +75,18 @@ void RunTest(const std::vector<int64_t>& input_dims, const std::vector<int64_t>&
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider});
}

-  for (int opset : {11, 18}) {
-    onnxruntime::test::OpTester test1("ScatterElements", opset);
-    if (has_axis) test1.AddAttribute<int64_t>("axis", axis);
-    test1.AddInput<T>("data", input_dims, input_data);
-    test1.AddInput<TIndex>("indices", indices_dims, indices_data);
-    test1.AddInput<T>("updates", indices_dims, updates_data);
-    test1.AddOutput<T>("y", input_dims, output_data);
-    if (std::is_same<T, int8_t>::value) {
-      test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
-    } else if (std::is_same<T, MLFloat16>::value) {
-      test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider});
-    } else {
-      test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider});
-    }
+  onnxruntime::test::OpTester test1("ScatterElements", 11);
+  if (has_axis) test1.AddAttribute<int64_t>("axis", axis);
+  test1.AddInput<T>("data", input_dims, input_data);
+  test1.AddInput<TIndex>("indices", indices_dims, indices_data);
+  test1.AddInput<T>("updates", indices_dims, updates_data);
+  test1.AddOutput<T>("y", input_dims, output_data);
+  if (std::is_same<T, int8_t>::value) {
+    test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+  } else if (std::is_same<T, MLFloat16>::value) {
+    test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider});
+  } else {
+    test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider});
+  }
}

@@ -304,5 +302,85 @@ TEST(Scatter, BoolInputWithAxis) {
scatter_bool_with_axis_tests("ScatterElements", 11);
}

TEST(ScatterElements, AddReduction) {
OpTester test("ScatterElements", 18);
test.AddAttribute<int64_t>("axis", 0);
test.AddAttribute<std::string>("reduction", "add");

test.AddInput<float>("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f});
test.AddInput<int64_t>("indices", {4, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
test.AddInput<float>("updates", {4, 3}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f});
test.AddOutput<float>("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f + (1.f + 2.f + 3.f + 4.f), -3.f + (1.f + 2.f + 3.f + 4.f), -6.f + (1.f + 2.f + 3.f + 4.f)});

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider});
}

TEST(ScatterElements, AddReductionAxis1) {
OpTester test("ScatterElements", 18);
test.AddAttribute<int64_t>("axis", 1);
test.AddAttribute<std::string>("reduction", "add");

// update's slice shape is {2, 1}
test.AddInput<float>("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f});
test.AddInput<int64_t>("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1});
test.AddInput<float>("updates", {2, 4}, {2.f, 5.f, 3.f, 6.f, 7.f, 9.f, 8.f, 10.f});
test.AddOutput<float>("y", {2, 3}, {9.f, 4.f + (2.f + 5.f + 3.f + 6.f), 1.f, 7.f, 3.f + (7.f + 9.f + 8.f + 10.f), 6.f});

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider});
}
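
For readers unfamiliar with ScatterElements reduction semantics, a hedged serial reference for the 2-D, axis=1 case exercised by these tests (plain C++, function name and layout are illustrative, not part of the commit):

#include <cstdint>

// output starts as a copy of data; each updates[i][j] is applied at
// output[i][indices[i][j]], and the "add" reduction combines colliding
// indices instead of overwriting. All indices above are 1, so column 1 of
// each row accumulates every update in that row.
void scatter_add_axis1(float* out, const int64_t* indices, const float* updates,
                       int rows, int out_cols, int upd_cols) {
  for (int i = 0; i < rows; ++i)
    for (int j = 0; j < upd_cols; ++j)
      out[i * out_cols + indices[i * upd_cols + j]] += updates[i * upd_cols + j];
}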

TEST(ScatterElements, MulReduction) {
OpTester test("ScatterElements", 18);
test.AddAttribute<int64_t>("axis", 0);
test.AddAttribute<std::string>("reduction", "mul");

test.AddInput<float>("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f});
test.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
test.AddInput<float>("updates", {2, 3}, {7.f, 3.f, 6.f, 7.f, 3.f, 6.f});
test.AddOutput<float>("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f * 7.f * 7.f, -3.f * 3.f * 3.f, -6.f * 6.f * 6.f});

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider});
}

TEST(ScatterElements, MulReductionAxis1) {
OpTester test("ScatterElements", 18);
test.AddAttribute<int64_t>("axis", 1);
test.AddAttribute<std::string>("reduction", "mul");

// update's slice shape is {2, 1}
test.AddInput<float>("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f});
test.AddInput<int64_t>("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1});
test.AddInput<float>("updates", {2, 4}, {2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f});
test.AddOutput<float>("y", {2, 3}, {9.f, 4.f * (2.f * 3.f * 4.f * 5.f), 1.f, 7.f, 3.f * (6.f * 7.f * 8.f * 9.f), 6.f});

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider});
}

TEST(ScatterElements, MaxReduction) {
OpTester test("ScatterElements", 18);
test.AddAttribute<int64_t>("axis", 0);
test.AddAttribute<std::string>("reduction", "max");

test.AddInput<float>("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f});
test.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
test.AddInput<float>("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f});
test.AddOutput<float>("y", {2, 3}, {-9.f, -4.f, -1.f, 7.f, 5.f, 6.f});

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider});
}

TEST(ScatterElements, MinReduction) {
OpTester test("ScatterElements", 18);
test.AddAttribute<int64_t>("axis", 0);
test.AddAttribute<std::string>("reduction", "min");

test.AddInput<float>("data", {2, 3}, {-9.f, -4.f, -1.f, 8.f, -3.f, 5.f});
test.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
test.AddInput<float>("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f});
test.AddOutput<float>("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f});

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider});
}

} // namespace test
} // namespace onnxruntime
