CPUAdam fp16 and bf16 support #5409

Merged: 37 commits into master from hab_cpu_adam on May 20, 2024
Changes from all commits
Commits (37)
65faa95
[SW-0] allow running tests on simulator
BacharL Apr 3, 2024
f18eef3
fused adam for hpu
BacharL May 5, 2024
c389971
CPUAdam fp16 and bf16 support
BacharL Apr 15, 2024
04a95e5
add missing functions
BacharL May 5, 2024
4613804
remvoe set_dtype function, move dtype argument to constructor
BacharL May 5, 2024
4f13b2e
fix dead code
BacharL May 5, 2024
fd9901d
cleanup half_precision
BacharL May 5, 2024
9b3151d
remove HALF_DTYPE compiler define
BacharL May 6, 2024
f940312
apply changes to cpulion and cpuadagrad
BacharL May 6, 2024
beb0d6c
fix compile errors
BacharL May 6, 2024
f7ee43c
fix compile errors
BacharL May 6, 2024
d8efc50
pre commit
BacharL May 6, 2024
a5a0531
codepsell fixes
BacharL May 6, 2024
42cbf95
fix typo
BacharL May 6, 2024
aaeffe6
Merge branch 'master' into hab_cpu_adam
tjruwase May 7, 2024
d9d2188
cpu adam templated param,state and device types
BacharL May 8, 2024
4d150cf
fix cuda build
BacharL May 8, 2024
930ca33
pre commit
BacharL May 8, 2024
468a314
apply changes to adagrad and lion
BacharL May 8, 2024
2987e5f
pre commit
BacharL May 8, 2024
3c86804
fix typos
BacharL May 8, 2024
cebcd21
Merge branch 'master' into hab_cpu_adam
tjruwase May 8, 2024
a690a0d
Revert "fused adam for hpu"
BacharL May 8, 2024
f518a70
cleanup device params and cuda specific code from cpuadam
BacharL May 9, 2024
79477f3
cleanup cuda code from cpulion and cpuadagrad
BacharL May 9, 2024
023d1f8
fix builders
BacharL May 9, 2024
005c7cc
pre commit
BacharL May 9, 2024
9b11559
fix builders
BacharL May 9, 2024
48ac129
fix cpulion build
BacharL May 9, 2024
abc5b2b
remove unused function declerations
BacharL May 9, 2024
6bd525a
Merge branch 'master' into hab_cpu_adam
tjruwase May 9, 2024
8aa6a20
fix bf16 to numpy conversion in tests
BacharL May 12, 2024
cbef0bf
Merge branch 'master' into hab_cpu_adam
tjruwase May 13, 2024
23625cf
skip test_torch_adamw_equal for invalid parameter combinations
BacharL May 13, 2024
e8e1b25
remove fp16_param_groups argument
BacharL May 19, 2024
0f5ffc3
pre commit
BacharL May 19, 2024
62975f2
Merge branch 'master' into hab_cpu_adam
tjruwase May 19, 2024
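
Taken together, these commits drop the float-only Step_1/Step_4/Step_8 signatures (and their dev_params / half_precision side channel) in favor of versions templated on separate parameter and optimizer-state precisions, and add a small dispatch table that maps a (param dtype, state dtype) pair to the matching template instantiation at runtime. The following is a minimal, torch-free sketch of that dispatch pattern, not the DeepSpeed code itself: the DType enum and the Half16/BF16 structs are illustrative stand-ins for c10::ScalarType, c10::Half and c10::BFloat16, and the worker only prints instead of calling opt->Step_8.

#include <cstddef>
#include <cstdio>
#include <functional>
#include <map>
#include <stdexcept>
#include <tuple>

enum class DType { Float, Half, BFloat16 };  // stand-in for c10::ScalarType

struct Half16 { float v; };  // stand-in for c10::Half
struct BF16 { float v; };    // stand-in for c10::BFloat16

template <typename T> constexpr DType dtype_of();
template <> constexpr DType dtype_of<float>() { return DType::Float; }
template <> constexpr DType dtype_of<Half16>() { return DType::Half; }
template <> constexpr DType dtype_of<BF16>() { return DType::BFloat16; }

// Type-erased step signature: raw data pointers plus element count.
using StepFn = std::function<void(void* params, void* grads, void* state, std::size_t n)>;
std::map<std::tuple<DType, DType>, StepFn> invokers;

// One templated worker per (param precision, state precision) pair.
template <typename ParamT, typename StateT>
void step_invoker(void* params, void* grads, void* state, std::size_t n)
{
    // Real code: opt->Step_8((ParamT*)params, (ParamT*)grads, (StateT*)state, n);
    std::printf("stepping %zu elements (param dtype %d, state dtype %d)\n",
                n, (int)dtype_of<ParamT>(), (int)dtype_of<StateT>());
}

template <typename ParamT, typename StateT>
void create_invoker()
{
    invokers[{dtype_of<ParamT>(), dtype_of<StateT>()}] = step_invoker<ParamT, StateT>;
}

// Registered once at static-initialization time, mirroring InvokerInitializer.
struct InvokerInit {
    InvokerInit()
    {
        create_invoker<Half16, float>();
        create_invoker<Half16, Half16>();
        create_invoker<BF16, float>();
        create_invoker<BF16, BF16>();
        create_invoker<float, float>();
    }
} invoker_init;

void invoke(DType param_t, DType state_t, void* p, void* g, void* s, std::size_t n)
{
    auto it = invokers.find({param_t, state_t});
    if (it == invokers.end())
        throw std::runtime_error("unsupported param/state dtype combination");
    it->second(p, g, s, n);
}

int main()
{
    float params[4] = {0}, grads[4] = {0}, state[4] = {0};
    invoke(DType::Float, DType::Float, params, grads, state, 4);
}

Keying the table on the runtime dtype pair keeps the pybind entry point type-erased (raw data_ptr() pointers) while the compiler still emits one tight loop per precision combination, and an unsupported combination fails with a clear runtime_error instead of silently reinterpreting memory, which is the behavior of the invoke() function in the diff below.
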
207 changes: 83 additions & 124 deletions csrc/adagrad/cpu_adagrad.cpp
@@ -5,55 +5,38 @@

#include "cpu_adagrad.h"
#include <torch/extension.h>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <type_traits>
#include <unordered_map>
#if defined(__ENABLE_CUDA__)
#include <cuda_runtime_api.h>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "custom_cuda_layers.h"
#endif

using namespace std::string_literals;
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

// C++ interface

void Adagrad_Optimizer::Step_1(float* _params,
float* grads,
float* _exp_avg_sq,
size_t _param_size,
ds_half_precision_t* dev_params,
bool half_precision)
template <typename ds_params_percision_t, typename ds_state_precision_t>
void Adagrad_Optimizer::Step_1(ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg_sq,
size_t _param_size)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<1>(
&rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
Step_AVX<1>(&rounded_size, _params, grads, _exp_avg_sq, _param_size);
#endif
if (_param_size > rounded_size) {
float step_size = -1 * _alpha;
ds_half_precision_t* grads_cast_h;
ds_half_precision_t* params_cast_h;
if (half_precision) {
grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
}
for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
float param = half_precision ? (float)params_cast_h[k] : _params[k];
float grad = (float)grads[k];
float param = (float)_params[k];
float momentum = grads[k];
float variance = _exp_avg_sq[k];
if (_weight_decay > 0) { grad = param * _weight_decay + grad; }
@@ -64,58 +47,30 @@ void Adagrad_Optimizer::Step_1(float* _params,
grad += _eps;
grad = momentum / grad;
param = grad * step_size + param;
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
#endif
if (half_precision)
params_cast_h[k] = (ds_half_precision_t)param;
else
_params[k] = param;
_params[k] = param;
// STORE UPDATE TERM TO GRAD'S MEMORY
grads[k] = grad * step_size;
_exp_avg_sq[k] = variance;
}
#if defined(__ENABLE_CUDA__)
if (dev_params) {
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
}
}
}

void Adagrad_Optimizer::Step_4(float* _params,
float* grads,
float* _exp_avg_sq,
size_t _param_size,
ds_half_precision_t* dev_params,
bool half_precision)
template <typename ds_params_percision_t, typename ds_state_precision_t>
void Adagrad_Optimizer::Step_4(ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg_sq,
size_t _param_size)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<4>(
&rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
Step_AVX<4>(&rounded_size, _params, grads, _exp_avg_sq, _param_size);
#endif
if (_param_size > rounded_size)
Step_1((_params + rounded_size),
(grads + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
half_precision);
(_param_size - rounded_size));
}

int create_adagrad_optimizer(int optimizer_id,
@@ -149,25 +104,77 @@ int create_adagrad_optimizer(int optimizer_id,
return 0;
}

void Adagrad_Optimizer::Step_8(float* _params,
float* grads,
float* _exp_avg_sq,
size_t _param_size,
ds_half_precision_t* dev_params,
bool half_precision)
template <typename ds_params_percision_t, typename ds_state_precision_t>
void Adagrad_Optimizer::Step_8(ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg_sq,
size_t _param_size)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<8>(
&rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
Step_AVX<8>(&rounded_size, _params, grads, _exp_avg_sq, _param_size);
#endif
if (_param_size > rounded_size)
Step_4((_params + rounded_size),
(grads + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
half_precision);
(_param_size - rounded_size));
}

template <typename ds_params_percision_t, typename ds_state_precision_t>
void step_invoker(std::shared_ptr<Adagrad_Optimizer> opt,
void* _params,
void* grads,
void* _exp_avg_sq,
size_t _param_size)
{
opt->Step_8((ds_params_percision_t*)(_params),
(ds_params_percision_t*)(grads),
(ds_state_precision_t*)(_exp_avg_sq),
_param_size);
}

std::map<std::tuple<c10::ScalarType, c10::ScalarType>,
std::function<void(std::shared_ptr<Adagrad_Optimizer>, void*, void*, void*, size_t)>>
invokers;

// Fill map with template functions for each type
template <class ds_params_percision_t, class ds_state_precision_t>
void create_invoker()
{
invokers[std::tuple(c10::CppTypeToScalarType<ds_params_percision_t>(),
c10::CppTypeToScalarType<ds_state_precision_t>())] =
step_invoker<ds_params_percision_t, ds_state_precision_t>;
}
struct InvokerInitializer {
InvokerInitializer()
{
create_invoker<c10::Half, float>();
create_invoker<c10::Half, c10::Half>();
create_invoker<c10::BFloat16, float>();
create_invoker<c10::BFloat16, c10::BFloat16>();
create_invoker<float, float>();
}
} _invoker_initializer;

void invoke(std::shared_ptr<Adagrad_Optimizer> opt,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg_sq,
size_t param_size)
{
c10::ScalarType params_type = at::typeMetaToScalarType(params.options().dtype());
c10::ScalarType state_type = at::typeMetaToScalarType(exp_avg_sq.options().dtype());

auto it = invokers.find(std::tuple(params_type, state_type));
if (it == invokers.end()) {
throw std::runtime_error("Adagrad optimizer with param type "s +
c10::toString(params_type) + " and state type "s +
c10::toString(state_type) +
" is not supported on current hardware"s);
}

it->second(opt, params.data_ptr(), grads.data_ptr(), exp_avg_sq.data_ptr(), param_size);
}

int ds_adagrad_step(int optimizer_id,
@@ -183,58 +190,13 @@ int ds_adagrad_step(int optimizer_id,
auto grads_c = grads.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();

float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

std::shared_ptr<Adagrad_Optimizer> opt =
std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step);
opt->update_state(lr, epsilon, weight_decay);
opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel());

#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
opt->SynchronizeStreams();
#endif
return 0;
}
invoke(opt, params_c, grads_c, exp_avg_sq_c, params_c.numel());

int ds_adagrad_step_plus_copy(int optimizer_id,
size_t step,
float lr,
float epsilon,
float weight_decay,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg_sq,
torch::Tensor& gpu_params)
{
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
auto params_c = params.contiguous();
auto gpu_params_c = gpu_params.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
auto grads_c = grads.contiguous();

float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

std::shared_ptr<Adagrad_Optimizer> opt =
std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step);
opt->update_state(lr, epsilon, weight_decay);
opt->Step_8(params_ptr,
grads_ptr,
exp_avg_sq_ptr,
params_c.numel(),
gpu_params_ptr,
(params.options().dtype() == at::kHalf));

opt->SynchronizeStreams();
#else
assert(false);
#endif
return 0;
}

@@ -248,9 +210,6 @@ int destroy_adagrad_optimizer(int optimizer_id)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("adagrad_update", &ds_adagrad_step, "DeepSpeed CPU Adagrad update (C++)");
m.def("adagrad_update_copy",
&ds_adagrad_step_plus_copy,
"DeepSpeed CPU Adagrad update and param copy (C++)");
m.def("create_adagrad", &create_adagrad_optimizer, "DeepSpeed CPU Adagrad (C++)");
m.def("destroy_adagrad", &destroy_adagrad_optimizer, "DeepSpeed CPU Adagrad destroy (C++)");
}
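
The templated bodies above read parameters and gradients in their storage precision, do all of the arithmetic in fp32, and write results back in each tensor's own dtype; the optimizer state can independently stay in fp32 or use half precision. Below is a hedged, self-contained sketch of that scalar tail loop, assuming a minimal software bfloat16 stand-in (c10::BFloat16 plays this role in the real code) and filling in the variance accumulation hidden by the collapsed hunk with the standard Adagrad update.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Minimal bfloat16 stand-in so the sketch compiles without torch headers.
struct bf16 {
    uint16_t bits = 0;
    bf16() = default;
    bf16(float f) { uint32_t u; std::memcpy(&u, &f, 4); bits = uint16_t(u >> 16); }
    operator float() const { uint32_t u = uint32_t(bits) << 16; float f; std::memcpy(&f, &u, 4); return f; }
};

// Scalar Adagrad tail loop: storage precision is templated, math stays in fp32.
template <typename ParamT, typename StateT>
void adagrad_tail(ParamT* params, ParamT* grads, StateT* exp_avg_sq,
                  std::size_t n, float alpha, float eps, float weight_decay)
{
    const float step_size = -alpha;
#pragma omp parallel for
    for (std::size_t k = 0; k < n; k++) {
        float grad = (float)grads[k];
        float param = (float)params[k];
        float momentum = (float)grads[k];
        float variance = (float)exp_avg_sq[k];
        if (weight_decay > 0) grad = param * weight_decay + grad;
        variance += grad * grad;                 // assumed content of the collapsed hunk
        grad = std::sqrt(variance) + eps;
        grad = momentum / grad;
        param = grad * step_size + param;
        params[k] = (ParamT)param;               // written back in parameter precision
        grads[k] = (ParamT)(grad * step_size);   // update term stored to grad memory
        exp_avg_sq[k] = (StateT)variance;        // state may be fp32 or low precision
    }
}

int main()
{
    std::vector<bf16> p(8, bf16(1.0f)), g(8, bf16(0.5f));
    std::vector<float> v(8, 0.0f);  // fp32 optimizer state paired with bf16 params
    adagrad_tail(p.data(), g.data(), v.data(), p.size(), 0.01f, 1e-8f, 0.0f);
    std::printf("p[0] after one step: %f\n", (float)p[0]);
}

This is not guaranteed to be a faithful copy of DeepSpeed's Step_1; it is only meant to show why a single fp32 code path can serve fp32, fp16, and bf16 storage once the pointer types are template parameters.
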
3 changes: 0 additions & 3 deletions csrc/adam/cpu_adam.cpp
@@ -8,9 +8,6 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
m.def("adam_update_copy",
&ds_adam_step_plus_copy,
"DeepSpeed CPU Adam update and param copy (C++)");
m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)");
}