Skip to content

Commit

Permalink
64bit indexing fused adam (#5187)
Browse files Browse the repository at this point in the history
## The Issue

Applying `FusedAdam` on large tensors will cause an error `CUDA error:
an illegal memory access was encountered`.

#3429

NVIDIA/apex#1654

## PR Content

Following the solution in the apex repository
(NVIDIA/apex#1765), changing indexing type to
`int64` if necessary.

---------

Co-authored-by: Michael Wyatt <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
  • Loading branch information
3 people authored Apr 22, 2024
1 parent 3f875d9 commit 9b6ef9e
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 26 deletions.
77 changes: 56 additions & 21 deletions csrc/adam/multi_tensor_adam.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ typedef enum : int {

using MATH_T = float;

template <typename T>
template <typename T, typename index_t>
struct AdamFunctor {
__device__ __forceinline__ void operator()(int chunk_size,
volatile int* noop_gmem,
Expand All @@ -48,13 +48,13 @@ struct AdamFunctor {
// if(*noop_gmem == 1)
// return;

int tensor_loc = tl.block_to_tensor[blockIdx.x];
index_t tensor_loc = tl.block_to_tensor[blockIdx.x];

// potentially use to pass in list of scalar
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;

int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
index_t n = tl.sizes[tensor_loc];

T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx * chunk_size;
Expand All @@ -71,7 +71,8 @@ struct AdamFunctor {
n -= chunk_idx * chunk_size;

// see note in multi_tensor_scale_kernel.cu
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
for (index_t i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
Expand Down Expand Up @@ -146,23 +147,57 @@ void multi_tensor_adam_cuda(int chunk_size,
bias_correction2 = 1 - std::pow(beta2, step);
}

size_t max_size = 0;
bool requires_64bit_indexing = false;
for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
for (auto it2 = it->begin(); it2 != it->end(); it2++) {
if (it2->numel() > max_size) {
max_size = it2->numel();
if (max_size >= INT_MAX) {
requires_64bit_indexing = true;
break;
}
}
}
if (requires_64bit_indexing) { break; }
}

// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
0,
"adam",
multi_tensor_apply<4>(BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
lr,
(adamMode_t)mode,
weight_decay);)
if (requires_64bit_indexing) {
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
0,
"adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE,
(int64_t)chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0, int64_t>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
lr,
(adamMode_t)mode,
weight_decay);)
} else {
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
0,
"adam",
multi_tensor_apply<4>(BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0, int32_t>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
lr,
(adamMode_t)mode,
weight_decay);)
}

AT_CUDA_CHECK(cudaGetLastError());
}
10 changes: 5 additions & 5 deletions csrc/adam/multi_tensor_apply.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct TensorListMetadata {
};

template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int chunk_size,
__global__ void multi_tensor_apply_kernel(int64_t chunk_size,
volatile int* noop_flag,
T tl,
U callable,
Expand All @@ -46,8 +46,8 @@ __global__ void multi_tensor_apply_kernel(int chunk_size,
}

template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(int block_size,
int chunk_size,
void multi_tensor_apply(int64_t block_size,
int64_t chunk_size,
const at::Tensor& noop_flag,
const std::vector<std::vector<at::Tensor>>& tensor_lists,
T callable,
Expand Down Expand Up @@ -91,9 +91,9 @@ void multi_tensor_apply(int block_size,
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
loc_tensor_info++;

int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;

for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) {
// std::cout << chunks_this_tensor << std::endl;
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
Expand Down

0 comments on commit 9b6ef9e

Please sign in to comment.