Skip to content

Commit

Permalink
remove unused enum and fix a cast
Browse files Browse the repository at this point in the history
  • Loading branch information
wschin committed Jan 27, 2024
1 parent 55cc23a commit ede11d8
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 15 deletions.
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/cuda/atomic/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ __device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* addr
// 00000000 | 00000000 | 00000000 | ..... byte 2 .....
old_byte = (old >> shift) & AtomicCasType<ValueType>::mask;
// Use + for atomic addition, * for atomic multiplication, / for atomic division.
newval = static_cast<uint32_t>(func(val, static_cast<ValueType>(old_byte)));
newval = reinterpret_cast<uint32_t&>(func(val, reinterpret_cast<ValueType&>(old_byte)));
// Journey of a 32-bit value (cont'd):
//
// old
Expand Down Expand Up @@ -307,16 +307,16 @@ __device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType
// Cast to aribitrary 2-byte type to desired integer type supported by atomicCAS.
// 4
// 8
auto observed_as_cas_type = *reinterpret_cast<CasType*>(&observed);
auto new_value_as_cas_type = *reinterpret_cast<CasType*>(&new_value);
auto observed_as_cas_type = reinterpret_cast<CasType&>(&observed);
auto new_value_as_cas_type = reinterpret_cast<CasType&>(&new_value);

// Call atomicCAS as if the 2-byte type variables are all unsigned short int.
// 4 unsigned int (or int)
// 8 unsigned long long int
auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type);

// Cast the freshly observed value in memory back to the TwoByteType.
observed = *reinterpret_cast<ValueType*>(&cas_observed_as_cas_type);
observed = reinterpret_cast<ValueType&>(&cas_observed_as_cas_type);

// Two cases:
// 1. compare-and-swap success
Expand Down
11 changes: 0 additions & 11 deletions onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,6 @@ namespace onnxruntime {
namespace cuda {

struct GatherScatterElementsArgs {
enum class Operation {
NONE,
ADD,
MUL,
MAX,
MIN
};

int64_t rank;
int64_t axis;
int64_t input_size;
Expand All @@ -27,9 +19,6 @@ struct GatherScatterElementsArgs {
TArray<fast_divmod> indices_fdms;
TArray<int64_t> indices_strides;
int64_t indices_size;
// operation used to combine values associated the same
// memory location in the output tensor.
Operation operation;
};

template <typename T, typename TIndex>
Expand Down

0 comments on commit ede11d8

Please sign in to comment.