Skip to content

Commit

Permalink
CPUAdam fp16 and bf16 support
Browse files Browse the repository at this point in the history
Change-Id: I7846288693bdbb70884689dab8f9934109570f32
  • Loading branch information
BacharL committed Apr 17, 2024
1 parent 54c0687 commit a9d5b2c
Show file tree
Hide file tree
Showing 34 changed files with 639 additions and 110 deletions.
38 changes: 31 additions & 7 deletions csrc/adagrad/cpu_adagrad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

// C++ interface

template <typename T>
void Adagrad_Optimizer::Step_1(float* _params,
float* grads,
float* _exp_avg_sq,
Expand All @@ -30,7 +31,7 @@ void Adagrad_Optimizer::Step_1(float* _params,
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<1>(
Step_AVX<1, T>(
&rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
#endif
if (_param_size > rounded_size) {
Expand Down Expand Up @@ -97,6 +98,7 @@ void Adagrad_Optimizer::Step_1(float* _params,
}
}

template <typename T>
void Adagrad_Optimizer::Step_4(float* _params,
float* grads,
float* _exp_avg_sq,
Expand All @@ -106,11 +108,11 @@ void Adagrad_Optimizer::Step_4(float* _params,
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<4>(
Step_AVX<4, T>(
&rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
#endif
if (_param_size > rounded_size)
Step_1((_params + rounded_size),
Step_1<T>((_params + rounded_size),
(grads + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
Expand Down Expand Up @@ -149,6 +151,7 @@ int create_adagrad_optimizer(int optimizer_id,
return 0;
}

template <typename T>
void Adagrad_Optimizer::Step_8(float* _params,
float* grads,
float* _exp_avg_sq,
Expand All @@ -158,11 +161,11 @@ void Adagrad_Optimizer::Step_8(float* _params,
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<8>(
Step_AVX<8, T>(
&rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
#endif
if (_param_size > rounded_size)
Step_4((_params + rounded_size),
Step_4<T>((_params + rounded_size),
(grads + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
Expand Down Expand Up @@ -191,7 +194,12 @@ int ds_adagrad_step(int optimizer_id,
std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step);
opt->update_state(lr, epsilon, weight_decay);
opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel());
if (params.options().dtype() == at::kHalf)
opt->Step_8<c10::Half>(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel(), nullptr, true);
else if (params.options().dtype() == at::kBFloat16)
opt->Step_8<c10::BFloat16>(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel(), nullptr, true);
else
opt->Step_8<float>(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel(), nullptr, false);

#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
opt->SynchronizeStreams();
Expand Down Expand Up @@ -224,7 +232,23 @@ int ds_adagrad_step_plus_copy(int optimizer_id,
std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step);
opt->update_state(lr, epsilon, weight_decay);
opt->Step_8(params_ptr,

if (params.options().dtype() == at::kHalf)
opt->Step_8<c10::Half>(params_ptr,
grads_ptr,
exp_avg_sq_ptr,
params_c.numel(),
gpu_params_ptr,
true);
else if (params.options().dtype() == at::kBFloat16)
opt->Step_8<c10::BFloat16>(params_ptr,
grads_ptr,
exp_avg_sq_ptr,
params_c.numel(),
gpu_params_ptr,
true);
else
opt->Step_8<float>(params_ptr,
grads_ptr,
exp_avg_sq_ptr,
params_c.numel(),
Expand Down
60 changes: 41 additions & 19 deletions csrc/adam/cpu_adam_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

// C++ interface

template <typename T>
void Adam_Optimizer::Step_1(float* _params,
float* grads,
float* _exp_avg,
Expand All @@ -33,7 +34,7 @@ void Adam_Optimizer::Step_1(float* _params,
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<1>(&rounded_size,
Step_AVX<1, T>(&rounded_size,
_params,
grads,
_exp_avg,
Expand Down Expand Up @@ -116,6 +117,7 @@ void Adam_Optimizer::Step_1(float* _params,
}
}

template <typename T>
void Adam_Optimizer::Step_4(float* _params,
float* grads,
float* _exp_avg,
Expand All @@ -126,7 +128,7 @@ void Adam_Optimizer::Step_4(float* _params,
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<4>(&rounded_size,
Step_AVX<4, T>(&rounded_size,
_params,
grads,
_exp_avg,
Expand All @@ -136,7 +138,7 @@ void Adam_Optimizer::Step_4(float* _params,
half_precision);
#endif
if (_param_size > rounded_size)
Step_1((_params + rounded_size),
Step_1<T>((_params + rounded_size),
(grads + rounded_size),
(_exp_avg + rounded_size),
(_exp_avg_sq + rounded_size),
Expand Down Expand Up @@ -185,6 +187,7 @@ int create_adam_optimizer(int optimizer_id,
return 0;
}

template <typename T>
void Adam_Optimizer::Step_8(float* _params,
float* grads,
float* _exp_avg,
Expand All @@ -195,7 +198,7 @@ void Adam_Optimizer::Step_8(float* _params,
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<8>(&rounded_size,
Step_AVX<8, T>(&rounded_size,
_params,
grads,
_exp_avg,
Expand All @@ -205,7 +208,7 @@ void Adam_Optimizer::Step_8(float* _params,
half_precision);
#endif
if (_param_size > rounded_size)
Step_4((_params + rounded_size),
Step_4<T>((_params + rounded_size),
(grads + rounded_size),
(_exp_avg + rounded_size),
(_exp_avg_sq + rounded_size),
Expand Down Expand Up @@ -244,13 +247,15 @@ int ds_adam_step(int optimizer_id,
opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, epsilon, weight_decay, bias_correction);

opt->Step_8(params_ptr,
grads_ptr,
exp_avg_ptr,
exp_avg_sq_ptr,
params_c.numel(),
nullptr,
(params.options().dtype() == at::kHalf));
if (params.options().dtype() == at::kHalf)
opt->Step_8<c10::Half>(
params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.numel(), nullptr, true);
else if (params.options().dtype() == at::kBFloat16)
opt->Step_8<c10::BFloat16>(
params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.numel(), nullptr, true);
else
opt->Step_8<float>(
params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.numel(), nullptr, false);

#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
opt->SynchronizeStreams();
Expand Down Expand Up @@ -289,13 +294,30 @@ int ds_adam_step_plus_copy(int optimizer_id,
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, epsilon, weight_decay, bias_correction);
opt->Step_8(params_ptr,
grads_ptr,
exp_avg_ptr,
exp_avg_sq_ptr,
params_c.numel(),
device_params_ptr,
(params.options().dtype() == at::kHalf));
if (params.options().dtype() == at::kHalf)
opt->Step_8<c10::Half>(params_ptr,
grads_ptr,
exp_avg_ptr,
exp_avg_sq_ptr,
params_c.numel(),
device_params_ptr,
true);
else if (params.options().dtype() == at::kBFloat16)
opt->Step_8<c10::BFloat16>(params_ptr,
grads_ptr,
exp_avg_ptr,
exp_avg_sq_ptr,
params_c.numel(),
device_params_ptr,
true);
else
opt->Step_8<float>(params_ptr,
grads_ptr,
exp_avg_ptr,
exp_avg_sq_ptr,
params_c.numel(),
device_params_ptr,
false);

opt->SynchronizeStreams();
#else
Expand Down
40 changes: 40 additions & 0 deletions csrc/common/custom_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,43 @@ void launch_param_update_half(const float* input, __half* output, int size, cuda

param_update_kernel_half<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
}

#ifdef BF16_AVAILABLE
// Scalar copy kernel: narrows one fp32 master parameter per thread to bf16.
__global__ void param_update_kernel(const float* input, __nv_bfloat16* output, int size)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= size) { return; }

    output[idx] = (__nv_bfloat16)input[idx];
}

// Launches the scalar fp32 -> bf16 parameter-copy kernel on `stream`.
// Block sizing mirrors the fp16 launcher above: 1024 threads per block,
// ceiling-divided grid so every element is covered.
void launch_param_update(const float* input, __nv_bfloat16* output, int size, cudaStream_t stream)
{
    constexpr int kThreadsPerBlock = 1024;

    dim3 grid_dim((size - 1) / kThreadsPerBlock + 1);
    dim3 block_dim(kThreadsPerBlock);

    param_update_kernel<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
}

// Vectorized variant: each thread moves a packed pair of bf16 values with a
// single 32-bit load/store. No fp32 -> bf16 conversion happens here — the
// float load is bit-reinterpreted as two bf16s.
// NOTE(review): assumes the host has already packed two bf16 values into the
// bits of each float in `input` (as the fp16 variant above assumes for
// __half2) — confirm against the CPU-side buffer fill. `size` is the pair
// count: the launcher below halves the element count before the launch.
__global__ void param_update_kernel_half(const float* input, __nv_bfloat16* output, int size)
{
    int id = blockIdx.x * blockDim.x + threadIdx.x;
    // View the bf16 output buffer as packed bf16x2 words for 32-bit stores.
    __nv_bfloat162* output_cast = reinterpret_cast<__nv_bfloat162*>(output);
    if (id < size) {
        float input_f = input[id];
        // Reinterpret the float's bits as a __nv_bfloat162 pair (no convert).
        __nv_bfloat162* input_h = reinterpret_cast<__nv_bfloat162*>(&input_f);
        output_cast[id] = *input_h;
    }
}

// Launches the packed (two-bf16-per-thread) copy kernel on `stream`.
// `size` is the bf16 element count; each thread handles one packed pair, so
// the launch works over size / 2 pairs.
void launch_param_update_half(const float* input, __nv_bfloat16* output, int size, cudaStream_t stream)
{
    constexpr int kThreadsPerBlock = 1024;

    const int packed_size = size / 2;
    dim3 grid_dim((packed_size - 1) / kThreadsPerBlock + 1);
    dim3 block_dim(kThreadsPerBlock);

    param_update_kernel_half<<<grid_dim, block_dim, 0, stream>>>(input, output, packed_size);
}
#endif
35 changes: 24 additions & 11 deletions csrc/includes/cpu_adagrad.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,29 @@
#include <cassert>
#include "simd.h"

#ifndef HALF_DTYPE
#error Must provide compiler option -DHALF_DTYPE=<half data type>
#endif

#if defined(__ENABLE_CUDA__)
#include <cuda_fp16.h>
#ifdef BF16_AVAILABLE
#include <cuda_bf16.h>
#endif
#include <cuda_runtime_api.h>
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#elif defined(__ENABLE_CANN__)
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
typedef c10::Half ds_half_precision_t;
#else
typedef unsigned short ds_half_precision_t;
#endif

typedef HALF_DTYPE ds_half_precision_t;

#define STEP(SPAN) \
template <typename T> \
void Step_##SPAN(float* _params, \
float* grads, \
float* _exp_avg_sq, \
Expand Down Expand Up @@ -64,7 +72,7 @@ class Adagrad_Optimizer {
#endif
}
#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
template <int span, typename T>
void Step_AVX(size_t* rounded_size,
float* _params,
float* grads,
Expand Down Expand Up @@ -121,7 +129,7 @@ class Adagrad_Optimizer {
};

#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
template <int span, typename T>
void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
float* _params,
float* grads,
Expand All @@ -130,6 +138,11 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
ds_half_precision_t* dev_params,
bool half_precision)
{
#if !defined(__AVX512__)
if (std::is_same_v<T, c10::BFloat16>) {
return;
}
#endif
size_t new_rounded_size = 0;
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
Expand All @@ -153,16 +166,16 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
AVX_Data grad_4[span];
simd_load<span>(grad_4, grads + i, half_precision);
simd_load<span, T>(grad_4, grads + i);

AVX_Data momentum_4[span];
simd_load<span>(momentum_4, grads + i, false);
simd_load<span, float>(momentum_4, grads + i);

AVX_Data variance_4[span];
simd_load<span>(variance_4, _exp_avg_sq + i, false);
simd_load<span, float>(variance_4, _exp_avg_sq + i);

AVX_Data param_4[span];
simd_load<span>(param_4, _params + i, half_precision);
simd_load<span, T>(param_4, _params + i);

if (_weight_decay > 0) { simd_fma<span>(grad_4, param_4, weight_decay4, grad_4); }

Expand All @@ -172,13 +185,13 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
simd_div<span>(grad_4, momentum_4, grad_4);
simd_fma<span>(param_4, grad_4, step_size_4, param_4);

simd_store<span>(_params + i, param_4, half_precision);
simd_store<span, T>(_params + i, param_4);
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) {
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
simd_store<span, T>(_doubled_buffer[_buf_index] + (i - t), param_4);
}
#endif
simd_store<span>(_exp_avg_sq + i, variance_4, false);
simd_store<span, float>(_exp_avg_sq + i, variance_4);
}
#if defined(__ENABLE_CUDA__)
if (dev_params) {
Expand Down
Loading

0 comments on commit a9d5b2c

Please sign in to comment.