CPUAdam fp16 and bf16 support (microsoft#5409)
Hi,
please review the following changes.

I added BF16 support to CPU Adam. BF16, FP16, and float are supported at
compile time; the correct template instantiation is called at runtime
according to the dtype of the input parameters.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
2 people authored and sfc-gh-reyazda committed Jun 10, 2024
1 parent f53895f commit bb146c3
Showing 27 changed files with 530 additions and 1,021 deletions.
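
Before the diff: the dispatch pattern the commit message describes, reduced to a small standalone sketch. A table keyed by the (param dtype, state dtype) pair selects a template instantiation at runtime. The names below (toy_step, dispatch_table) are illustrative and are not from the diff; the real registration code is in csrc/adagrad/cpu_adagrad.cpp further down.

    #include <torch/torch.h>
    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <map>
    #include <tuple>

    // Toy update kernel, templated on parameter precision and state precision.
    template <typename param_t, typename state_t>
    void toy_step(void* params, void* state, size_t n)
    {
        auto* p = static_cast<param_t*>(params);
        auto* s = static_cast<state_t*>(state);
        for (size_t i = 0; i < n; i++) {
            float v = (float)p[i] + (float)s[i];  // do the math in fp32
            p[i] = (param_t)v;                    // store back in parameter precision
        }
    }

    // Runtime dispatch table keyed by (param dtype, state dtype).
    std::map<std::tuple<c10::ScalarType, c10::ScalarType>, std::function<void(void*, void*, size_t)>>
        dispatch_table = {
            {std::tuple(c10::kFloat, c10::kFloat), toy_step<float, float>},
            {std::tuple(c10::kHalf, c10::kFloat), toy_step<c10::Half, float>},
            {std::tuple(c10::kBFloat16, c10::kFloat), toy_step<c10::BFloat16, float>},
    };

    int main()
    {
        auto params = torch::zeros({8}, torch::kBFloat16);
        auto state = torch::ones({8}, torch::kFloat);
        dispatch_table.at(std::tuple(params.scalar_type(), state.scalar_type()))(
            params.data_ptr(), state.data_ptr(), params.numel());
        std::cout << params << std::endl;  // updated through the bf16 instantiation
    }

The actual change registers the full set of supported pairs (fp16/bf16/fp32 params with fp32 or same-precision state) and throws for unsupported combinations, as the diff below shows.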
207 changes: 83 additions & 124 deletions csrc/adagrad/cpu_adagrad.cpp
@@ -5,55 +5,38 @@
 
 #include "cpu_adagrad.h"
 #include <torch/extension.h>
+#include <functional>
 #include <iostream>
+#include <map>
 #include <memory>
 #include <type_traits>
 #include <unordered_map>
-#if defined(__ENABLE_CUDA__)
-#include <cuda_runtime_api.h>
-#include "cublas_v2.h"
-#include "cuda.h"
-#include "curand.h"
-#include "custom_cuda_layers.h"
-#endif
 
+using namespace std::string_literals;
 static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
 
 // C++ interface
 
-void Adagrad_Optimizer::Step_1(float* _params,
-                               float* grads,
-                               float* _exp_avg_sq,
-                               size_t _param_size,
-                               ds_half_precision_t* dev_params,
-                               bool half_precision)
+template <typename ds_params_percision_t, typename ds_state_precision_t>
+void Adagrad_Optimizer::Step_1(ds_params_percision_t* _params,
+                               ds_params_percision_t* grads,
+                               ds_state_precision_t* _exp_avg_sq,
+                               size_t _param_size)
 {
     size_t rounded_size = 0;
 #if defined(__AVX512__) or defined(__AVX256__)
-    Step_AVX<1>(
-        &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
+    Step_AVX<1>(&rounded_size, _params, grads, _exp_avg_sq, _param_size);
 #endif
     if (_param_size > rounded_size) {
         float step_size = -1 * _alpha;
-        ds_half_precision_t* grads_cast_h;
-        ds_half_precision_t* params_cast_h;
-        if (half_precision) {
-            grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
-            params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
-        }
         for (size_t t = rounded_size; t < _param_size; t += TILE) {
             size_t copy_size = TILE;
             if ((t + TILE) > _param_size) copy_size = _param_size - t;
             size_t offset = copy_size + t;
-#if defined(__ENABLE_CUDA__)
-            if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
-#elif defined(__ENABLE_CANN__)
-            if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
-#endif
 #pragma omp parallel for
             for (size_t k = t; k < offset; k++) {
-                float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
-                float param = half_precision ? (float)params_cast_h[k] : _params[k];
+                float grad = (float)grads[k];
+                float param = (float)_params[k];
                 float momentum = grads[k];
                 float variance = _exp_avg_sq[k];
                 if (_weight_decay > 0) { grad = param * _weight_decay + grad; }
@@ -64,58 +47,30 @@ void Adagrad_Optimizer::Step_1(float* _params,
                 grad += _eps;
                 grad = momentum / grad;
                 param = grad * step_size + param;
-#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
-                if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
-#endif
-                if (half_precision)
-                    params_cast_h[k] = (ds_half_precision_t)param;
-                else
-                    _params[k] = param;
+                _params[k] = param;
                 // STORE UPDATE TERM TO GRAD'S MEMORY
                 grads[k] = grad * step_size;
                 _exp_avg_sq[k] = variance;
             }
-#if defined(__ENABLE_CUDA__)
-            if (dev_params) {
-                launch_param_update(
-                    _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
-                _buf_index = !_buf_index;
-            }
-#elif defined(__ENABLE_CANN__)
-            if (dev_params) {
-                size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
-                aclrtMemcpy(dev_params + t,
-                            memcpy_size,
-                            _doubled_buffer[_buf_index],
-                            memcpy_size,
-                            aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
-
-                _buf_index = !_buf_index;
-            }
-#endif
         }
     }
 }
 
-void Adagrad_Optimizer::Step_4(float* _params,
-                               float* grads,
-                               float* _exp_avg_sq,
-                               size_t _param_size,
-                               ds_half_precision_t* dev_params,
-                               bool half_precision)
+template <typename ds_params_percision_t, typename ds_state_precision_t>
+void Adagrad_Optimizer::Step_4(ds_params_percision_t* _params,
+                               ds_params_percision_t* grads,
+                               ds_state_precision_t* _exp_avg_sq,
+                               size_t _param_size)
 {
     size_t rounded_size = 0;
 #if defined(__AVX512__) or defined(__AVX256__)
-    Step_AVX<4>(
-        &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
+    Step_AVX<4>(&rounded_size, _params, grads, _exp_avg_sq, _param_size);
 #endif
     if (_param_size > rounded_size)
         Step_1((_params + rounded_size),
                (grads + rounded_size),
                (_exp_avg_sq + rounded_size),
-               (_param_size - rounded_size),
-               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
-               half_precision);
+               (_param_size - rounded_size));
 }
 
 int create_adagrad_optimizer(int optimizer_id,
@@ -149,25 +104,77 @@ int create_adagrad_optimizer(int optimizer_id,
     return 0;
 }
 
-void Adagrad_Optimizer::Step_8(float* _params,
-                               float* grads,
-                               float* _exp_avg_sq,
-                               size_t _param_size,
-                               ds_half_precision_t* dev_params,
-                               bool half_precision)
+template <typename ds_params_percision_t, typename ds_state_precision_t>
+void Adagrad_Optimizer::Step_8(ds_params_percision_t* _params,
+                               ds_params_percision_t* grads,
+                               ds_state_precision_t* _exp_avg_sq,
+                               size_t _param_size)
 {
     size_t rounded_size = 0;
 #if defined(__AVX512__) or defined(__AVX256__)
-    Step_AVX<8>(
-        &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
+    Step_AVX<8>(&rounded_size, _params, grads, _exp_avg_sq, _param_size);
 #endif
     if (_param_size > rounded_size)
         Step_4((_params + rounded_size),
                (grads + rounded_size),
                (_exp_avg_sq + rounded_size),
-               (_param_size - rounded_size),
-               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
-               half_precision);
+               (_param_size - rounded_size));
 }
 
+template <typename ds_params_percision_t, typename ds_state_precision_t>
+void step_invoker(std::shared_ptr<Adagrad_Optimizer> opt,
+                  void* _params,
+                  void* grads,
+                  void* _exp_avg_sq,
+                  size_t _param_size)
+{
+    opt->Step_8((ds_params_percision_t*)(_params),
+                (ds_params_percision_t*)(grads),
+                (ds_state_precision_t*)(_exp_avg_sq),
+                _param_size);
+}
+
+std::map<std::tuple<c10::ScalarType, c10::ScalarType>,
+         std::function<void(std::shared_ptr<Adagrad_Optimizer>, void*, void*, void*, size_t)>>
+    invokers;
+
+// Fill map with template functions for each type
+template <class ds_params_percision_t, class ds_state_precision_t>
+void create_invoker()
+{
+    invokers[std::tuple(c10::CppTypeToScalarType<ds_params_percision_t>(),
+                        c10::CppTypeToScalarType<ds_state_precision_t>())] =
+        step_invoker<ds_params_percision_t, ds_state_precision_t>;
+}
+struct InvokerInitializer {
+    InvokerInitializer()
+    {
+        create_invoker<c10::Half, float>();
+        create_invoker<c10::Half, c10::Half>();
+        create_invoker<c10::BFloat16, float>();
+        create_invoker<c10::BFloat16, c10::BFloat16>();
+        create_invoker<float, float>();
+    }
+} _invoker_initializer;
+
+void invoke(std::shared_ptr<Adagrad_Optimizer> opt,
+            torch::Tensor& params,
+            torch::Tensor& grads,
+            torch::Tensor& exp_avg_sq,
+            size_t param_size)
+{
+    c10::ScalarType params_type = at::typeMetaToScalarType(params.options().dtype());
+    c10::ScalarType state_type = at::typeMetaToScalarType(exp_avg_sq.options().dtype());
+
+    auto it = invokers.find(std::tuple(params_type, state_type));
+    if (it == invokers.end()) {
+        throw std::runtime_error("Adagrad optimizer with param type "s +
+                                 c10::toString(params_type) + " and state type "s +
+                                 c10::toString(state_type) +
+                                 " is not supported on current hardware"s);
+    }
+
+    it->second(opt, params.data_ptr(), grads.data_ptr(), exp_avg_sq.data_ptr(), param_size);
+}
+
 int ds_adagrad_step(int optimizer_id,
@@ -183,58 +190,13 @@ int ds_adagrad_step(int optimizer_id,
     auto grads_c = grads.contiguous();
     auto exp_avg_sq_c = exp_avg_sq.contiguous();
 
-    float* params_ptr = (float*)params_c.data_ptr();
-    float* grads_ptr = (float*)grads_c.data_ptr();
-    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
-
     std::shared_ptr<Adagrad_Optimizer> opt =
         std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
     opt->IncrementStep(step);
     opt->update_state(lr, epsilon, weight_decay);
-    opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel());
 
-#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
-    opt->SynchronizeStreams();
-#endif
-    return 0;
-}
+    invoke(opt, params_c, grads_c, exp_avg_sq_c, params_c.numel());
 
-int ds_adagrad_step_plus_copy(int optimizer_id,
-                              size_t step,
-                              float lr,
-                              float epsilon,
-                              float weight_decay,
-                              torch::Tensor& params,
-                              torch::Tensor& grads,
-                              torch::Tensor& exp_avg_sq,
-                              torch::Tensor& gpu_params)
-{
-#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
-    auto params_c = params.contiguous();
-    auto gpu_params_c = gpu_params.contiguous();
-    auto exp_avg_sq_c = exp_avg_sq.contiguous();
-    auto grads_c = grads.contiguous();
-
-    float* params_ptr = (float*)params_c.data_ptr();
-    float* grads_ptr = (float*)grads_c.data_ptr();
-    ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
-    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
-
-    std::shared_ptr<Adagrad_Optimizer> opt =
-        std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
-    opt->IncrementStep(step);
-    opt->update_state(lr, epsilon, weight_decay);
-    opt->Step_8(params_ptr,
-                grads_ptr,
-                exp_avg_sq_ptr,
-                params_c.numel(),
-                gpu_params_ptr,
-                (params.options().dtype() == at::kHalf));
-
-    opt->SynchronizeStreams();
-#else
-    assert(false);
-#endif
     return 0;
 }
 
Expand All @@ -248,9 +210,6 @@ int destroy_adagrad_optimizer(int optimizer_id)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("adagrad_update", &ds_adagrad_step, "DeepSpeed CPU Adagrad update (C++)");
m.def("adagrad_update_copy",
&ds_adagrad_step_plus_copy,
"DeepSpeed CPU Adagrad update and param copy (C++)");
m.def("create_adagrad", &create_adagrad_optimizer, "DeepSpeed CPU Adagrad (C++)");
m.def("destroy_adagrad", &destroy_adagrad_optimizer, "DeepSpeed CPU Adagrad destroy (C++)");
}
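
A note on how the new templated scalar path above handles reduced precision: values are widened from the parameter dtype to fp32, the Adagrad math runs in fp32, and results are narrowed back, while the state tensor keeps its own precision. A reduced standalone sketch of that pattern (hypothetical names adagrad_tail/alpha/eps; it omits the weight decay, update-term store, and TILE blocking of the real Step_1):

    #include <c10/util/BFloat16.h>
    #include <cmath>
    #include <cstddef>
    #include <cstdio>

    template <typename param_t, typename state_t>
    void adagrad_tail(param_t* params, param_t* grads, state_t* exp_avg_sq,
                      size_t n, float alpha, float eps)
    {
        for (size_t k = 0; k < n; k++) {
            float grad = (float)grads[k];    // widen param-precision inputs to fp32
            float param = (float)params[k];
            float variance = (float)exp_avg_sq[k] + grad * grad;
            float step = grad / (std::sqrt(variance) + eps);
            params[k] = (param_t)(param - alpha * step);  // narrow back to param precision
            exp_avg_sq[k] = (state_t)variance;            // state stays in its own precision
        }
    }

    int main()
    {
        c10::BFloat16 p[2] = {1.0f, 2.0f}, g[2] = {0.5f, 0.5f};
        float v[2] = {0.0f, 0.0f};  // fp32 state alongside bf16 params
        adagrad_tail(p, g, v, 2, /*alpha=*/0.01f, /*eps=*/1e-8f);
        std::printf("%f %f\n", (float)p[0], (float)p[1]);
    }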
3 changes: 0 additions & 3 deletions csrc/adam/cpu_adam.cpp
@@ -8,9 +8,6 @@
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
-    m.def("adam_update_copy",
-          &ds_adam_step_plus_copy,
-          "DeepSpeed CPU Adam update and param copy (C++)");
     m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
     m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)");
 }
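
With the fused adagrad_update_copy/adam_update_copy bindings removed, a host-side caller that still wants the updated parameters on an accelerator would issue that copy itself. A minimal libtorch sketch, assuming hypothetical cpu_params/device_params tensors (this uses plain torch::Tensor::copy_, not a DeepSpeed API):

    #include <torch/torch.h>

    void copy_params_back(torch::Tensor& device_params, const torch::Tensor& cpu_params)
    {
        // Asynchronous host-to-device copy; copy_() also handles any dtype conversion
        // (e.g. fp32 master params on the CPU into fp16/bf16 device params).
        device_params.copy_(cpu_params, /*non_blocking=*/true);
    }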