determine available modules #6

Open · wants to merge 2 commits into main

4 changes: 0 additions & 4 deletions .github/workflows/nv-nightly.yml
@@ -38,10 +38,6 @@ jobs:
git rev-parse --short HEAD
pip install .

- name: Install datasets
run: |
pip install datasets

- name: Install deepspeed
run: |
pip install .[dev,1bit,autotuning,inf]
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -39,7 +39,7 @@ repos:
name: check-torchdist
entry: ./scripts/check-torchdist.py
language: python
exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py|deepspeed/elasticity/elastic_agent.py|deepspeed/launcher/launch.py|tests/unit/comm/test_dist.py)
exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py|deepspeed/elasticity/elastic_agent.py|deepspeed/launcher/launch.py|tests/unit/comm/test_dist.py|deepspeed/runtime/zero/utils.py|deepspeed/tools/pg_sim/ut/base.py|deepspeed/tools/pg_sim/pg.py)
# Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm

- repo: local
4 changes: 4 additions & 0 deletions accelerator/abstract_accelerator.py
@@ -280,6 +280,10 @@ def create_op_builder(self, class_name):
def get_op_builder(self, class_name):
...

@abc.abstractmethod
def get_compile_backend(self):
...

@abc.abstractmethod
def build_extension(self):
...
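
The new get_compile_backend hook gives every accelerator one place to report its torch.compile backend. A minimal sketch of how a caller might consume it (the wrapper function is illustrative, not part of this PR):

    import torch
    from deepspeed.accelerator import get_accelerator

    def compile_for_accelerator(module: torch.nn.Module) -> torch.nn.Module:
        # "inductor" on CPU/CUDA/MPS/NPU, "hpu_backend" on HPU (per this PR)
        backend = get_accelerator().get_compile_backend()
        return torch.compile(module, backend=backend)
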
3 changes: 3 additions & 0 deletions accelerator/cpu_accelerator.py
@@ -302,3 +302,6 @@ def build_extension(self):

def export_envs(self):
return []

def get_compile_backend(self):
return "inductor"
3 changes: 3 additions & 0 deletions accelerator/cuda_accelerator.py
@@ -360,3 +360,6 @@ def build_extension(self):

def export_envs(self):
return ['NCCL']

def get_compile_backend(self):
return "inductor"
11 changes: 11 additions & 0 deletions accelerator/hpu_accelerator.py
@@ -288,6 +288,17 @@ def get_op_builder(self, class_name):
else:
return self.class_dict['NotImplementedBuilder'] if 'NotImplementedBuilder' in self.class_dict else None

def get_compile_backend(self):
return "hpu_backend"

# to be removed once we move to torch.compile
def wrap_in_hpu_graph(self, module):
if self.hpu.is_lazy():
module = self.hpu.wrap_in_hpu_graph(module)
else:
print("Warning: hpu graphs in eager mode is not supported, ignoring")
return module

def build_extension(self):
from torch.utils.cpp_extension import BuildExtension
return BuildExtension
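
wrap_in_hpu_graph is a temporary lazy-mode helper until torch.compile covers this path. A hedged usage sketch (the hasattr guard is an assumption, not an established DeepSpeed convention):

    import torch
    from deepspeed.accelerator import get_accelerator

    model = torch.nn.Linear(16, 16)
    accelerator = get_accelerator()
    if hasattr(accelerator, "wrap_in_hpu_graph"):
        # No-op (with a warning) when the HPU runtime is in eager mode
        model = accelerator.wrap_in_hpu_graph(model)
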
3 changes: 3 additions & 0 deletions accelerator/mps_accelerator.py
@@ -258,3 +258,6 @@ def build_extension(self):

def export_envs(self):
return []

def get_compile_backend(self):
return "inductor"
3 changes: 3 additions & 0 deletions accelerator/npu_accelerator.py
@@ -278,3 +278,6 @@ def build_extension(self):

def export_envs(self):
return ['ASCEND', 'HCCL', 'LD_LIBRARY', 'PATH']

def get_compile_backend(self):
return "inductor"
1 change: 1 addition & 0 deletions build.txt
@@ -0,0 +1 @@
+hpu.synapse.v1.16.1
6 changes: 5 additions & 1 deletion csrc/adam/cpu_adam_impl.cpp
@@ -244,13 +244,17 @@ int ds_adam_step(int optimizer_id,
opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, epsilon, weight_decay, bias_correction);

bool bit16_precision = false;
if ((params.options().dtype() == at::kHalf) || (params.options().dtype() == at::kBFloat16))
bit16_precision = true;

opt->Step_8(params_ptr,
grads_ptr,
exp_avg_ptr,
exp_avg_sq_ptr,
params_c.numel(),
nullptr,
(params.options().dtype() == at::kHalf));
bit16_precision);

#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
opt->SynchronizeStreams();
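
With bit16_precision now covering both fp16 and bf16 parameters, the fused CPU Adam step can in principle be driven with bf16 tensors. A rough end-to-end check from Python, assuming the extension was built with bf16 support enabled:

    import torch
    from deepspeed.ops.adam import DeepSpeedCPUAdam

    param = torch.nn.Parameter(torch.zeros(8, dtype=torch.bfloat16))
    opt = DeepSpeedCPUAdam([param], lr=1e-3)
    param.grad = torch.ones_like(param)
    opt.step()  # the underlying ds_adam_step call should take the bit16 path
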
4 changes: 4 additions & 0 deletions csrc/includes/cpu_adam.h
@@ -23,6 +23,9 @@ typedef __half ds_half_precision_t;
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
typedef c10::Half ds_half_precision_t;
#elif defined(__BFLOAT16__)
#include <torch/torch.h>
typedef at::BFloat16 ds_half_precision_t;
#else
#include <cmath>
typedef unsigned short ds_half_precision_t;
@@ -259,6 +262,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
simd_store<span>(_exp_avg + i, momentum_4, false);
simd_store<span>(_exp_avg_sq + i, variance_4, false);
}
// Params are updated only in case of float16, which is currently not supported on HPU
#if defined(__ENABLE_CUDA__)
if (dev_params) {
if (half_precision)
110 changes: 109 additions & 1 deletion csrc/includes/simd.h
@@ -9,6 +9,23 @@
#include <cpuid.h>
#include <x86intrin.h>
#endif
#include <cstdint>
#include <cstring>
#include <type_traits>

template <typename T>
inline T readAs(const void* src)
{
T res;
std::memcpy(&res, src, sizeof(T));
return res;
}

template <typename T>
inline void writeAs(void* dst, const T& val)
{
std::memcpy(dst, &val, sizeof(T));
}

#define TILE (128 * 1024 * 1024)
#if defined(__AVX512__) or defined(__AVX256__)
@@ -29,12 +46,58 @@
#define SIMD_OR(x, y) _mm512_or_ps(x, y)
#define SIMD_XOR(x, y) _mm512_xor_ps(x, y)
#define SIMD_WIDTH 16
#if defined(ENABLE_BFLOAT16)
static __m512 load_16_bf16_as_f32(const void* data)
{
__m256i a = readAs<__m256i>(data); // use memcpy to avoid aliasing
__m512i b = _mm512_cvtepu16_epi32(a); // convert 16 u16 to 16 u32
__m512i c = _mm512_slli_epi32(b, 16); // logical shift left of all u32 by
// 16 bits (representing bf16->f32)
return readAs<__m512>(&c); // use memcpy to avoid aliasing
}

static void store_16_f32_as_bf16_nearest(__m512 v, void* data)
{
__m512i u32 = readAs<__m512i>(&v);

// flow assuming non-nan:

// uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
__m512i b = _mm512_srli_epi32(u32, 16);
__m512i lsb_mask = _mm512_set1_epi32(0x00000001);
__m512i c = _mm512_and_si512(b, lsb_mask);
__m512i bias_constant = _mm512_set1_epi32(0x00007fff);
__m512i rounding_bias = _mm512_add_epi32(c, bias_constant);

// uint16_t res = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
__m512i d = _mm512_add_epi32(u32, rounding_bias);
__m512i e = _mm512_srli_epi32(d, 16);
__m256i non_nan_res = _mm512_cvtusepi32_epi16(e);

// handle nan (exp is all 1s and mantissa != 0)
// if ((x & 0x7fffffffU) > 0x7f800000U)
__m512i mask_out_sign = _mm512_set1_epi32(0x7fffffff);
__m512i non_sign_bits = _mm512_and_si512(u32, mask_out_sign);
__m512i nan_threshold = _mm512_set1_epi32(0x7f800000);
__mmask16 nan_mask = _mm512_cmp_epi32_mask(non_sign_bits, nan_threshold, _MM_CMPINT_GT);

// mix in results with nans as needed
__m256i nans = _mm256_set1_epi16(0x7fc0);
__m256i res = _mm256_mask_mov_epi16(non_nan_res, nan_mask, nans);

writeAs(data, res);
}

#define SIMD_LOAD2(x, h) ((h) ? load_16_bf16_as_f32(x) : _mm512_loadu_ps(x))

#define SIMD_STORE2(x, d, h) ((h) ? store_16_f32_as_bf16_nearest(d, x) : _mm512_storeu_ps(x, d))
#else // ENABLE_BFLOAT16
#define SIMD_LOAD2(x, h) \
((h) ? _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x))) : _mm512_loadu_ps(x))
#define SIMD_STORE2(x, d, h) \
((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
: _mm512_storeu_ps(x, d))
#endif // ENABLE_BFLOAT16

#define INTV __m256i
#elif defined(__AVX256__)
@@ -52,12 +115,57 @@
#define SIMD_XOR(x, y) _mm256_xor_ps(x, y)
#define SIMD_WIDTH 8

#if defined(ENABLE_BFLOAT16)
__m256 load_8_bf16_as_f32(const float* data)
{
__m128i a = readAs<__m128i>(data); // use memcpy to avoid aliasing
__m256i b = _mm256_cvtepu16_epi32(a); // convert 8 u16 to 8 u32
__m256i c = _mm256_slli_epi32(b, 16); // logical shift left of all u32 by
// 16 bits (representing bf16->f32)
return readAs<__m256>(&c); // use memcpy to avoid aliasing
}

void store_8_f32_as_bf16_nearest(__m256 v, float* data)
{
__m256i u32 = readAs<__m256i>(&v);

// flow assuming non-nan:

// uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
__m256i b = _mm256_srli_epi32(u32, 16);
__m256i lsb_mask = _mm256_set1_epi32(0x00000001);
__m256i c = _mm256_and_si256(b, lsb_mask);
__m256i bias_constant = _mm256_set1_epi32(0x00007fff);
__m256i rounding_bias = _mm256_add_epi32(c, bias_constant);

// uint16_t res = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
__m256i d = _mm256_add_epi32(u32, rounding_bias);
__m256i e = _mm256_srli_epi32(d, 16);
__m128i non_nan_res = _mm256_cvtusepi32_epi16(e);

// handle nan (exp is all 1s and mantissa != 0)
// if ((x & 0x7fffffffU) > 0x7f800000U)
__m256i mask_out_sign = _mm256_set1_epi32(0x7fffffff);
__m256i non_sign_bits = _mm256_and_si256(u32, mask_out_sign);
__m256i nan_threshold = _mm256_set1_epi32(0x7f800000);
__mmask8 nan_mask = _mm256_cmp_epi32_mask(non_sign_bits, nan_threshold, _MM_CMPINT_GT);

// mix in results with nans as needed
__m128i nans = _mm_set1_epi16(0x7fc0);
__m128i res = _mm_mask_mov_epi16(non_nan_res, nan_mask, nans);

writeAs(data, res);
}
#define SIMD_LOAD2(x, h) ((h) ? load_8_bf16_as_f32(x) : _mm256_loadu_ps(x))

#define SIMD_STORE2(x, d, h) ((h) ? store_8_f32_as_bf16_nearest(d, x) : _mm256_storeu_ps(x, d))
#else // ENABLE_BFLOAT16
#define SIMD_LOAD2(x, h) \
((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x))) : _mm256_loadu_ps(x))
#define SIMD_STORE2(x, d, h) \
((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
: _mm256_storeu_ps(x, d))

#endif // ENABLE_BFLOAT16
#define INTV __m128i
#endif

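
For reference, the bf16 load and round-to-nearest-even store implemented by the intrinsics above reduce to the following scalar bit manipulation; this NumPy transcription is for illustration only and is not part of the diff:

    import numpy as np

    def bf16_to_f32(b: np.ndarray) -> np.ndarray:
        # load_*_bf16_as_f32: widen u16 -> u32, shift left by 16, reinterpret as f32
        return (b.astype(np.uint32) << 16).view(np.float32)

    def f32_to_bf16_nearest(x: np.ndarray) -> np.ndarray:
        # store_*_f32_as_bf16_nearest: round-to-nearest-even, NaNs mapped to 0x7fc0
        u32 = x.astype(np.float32).view(np.uint32)
        rounding_bias = ((u32 >> 16) & 1) + np.uint32(0x7FFF)
        res = ((u32 + rounding_bias) >> 16).astype(np.uint16)
        nan_mask = (u32 & np.uint32(0x7FFFFFFF)) > np.uint32(0x7F800000)
        res[nan_mask] = np.uint16(0x7FC0)
        return res
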
95 changes: 91 additions & 4 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -446,15 +446,15 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
unsigned layer_id,
unsigned num_layers,
at::Tensor& alibi,
float rope_theta)
float rope_theta,
bool is_prompt,
std::optional<at::Tensor> token_idx)
{
unsigned bsz = query_key_value.size(0);
unsigned seq_len = query_key_value.size(1);
int k = query_key_value.size(2) / (heads + 2 * (num_kv > 0 ? num_kv : heads));
unsigned hidden_dim = heads * k;

bool is_prompt = (seq_len > 1);

if (is_prompt) InferenceContext::Instance().reset_tokens(seq_len);
unsigned soft_len = InferenceContext::Instance().current_tokens();

@@ -847,6 +847,87 @@ std::vector<at::Tensor> ds_layer_norm_residual_store_pre_ln_res(at::Tensor& inpu
return {norm_output, res_output};
}

template <typename T>
at::Tensor ds_transform4d_0213(at::Tensor& input, int seq_length)
{
auto input_cont = input.contiguous();
unsigned batch_size = input.size(0);
unsigned num_heads = input.size(1);
unsigned seq_length_head_dim = input.size(2);
unsigned head_dim = seq_length_head_dim / seq_length;
unsigned hidden_dim = num_heads * head_dim;

auto options = at::TensorOptions()
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();

launch_transform4d_0213<T>(workspace,
(T*)input.data_ptr(),
batch_size,
num_heads,
seq_length,
hidden_dim,
InferenceContext::Instance().GetCurrentStream(),
1);
auto output = at::from_blob(workspace, {batch_size, seq_length, num_heads, head_dim}, options);
return output;
}

template <typename T>
std::vector<at::Tensor> ds_bias_add_transform_0213(at::Tensor& input,
at::Tensor& bias,
int num_heads,
int trans_count)
{
TORCH_CHECK(
trans_count == 1 or trans_count == 3, "trans_count ", trans_count, " is not supported");
auto input_cont = input.contiguous();

unsigned batch_size = input.size(0);
unsigned seq_length = input.size(1);
unsigned value_size = input.size(2);
unsigned hidden_dim = input.size(2) / trans_count;
unsigned head_dim = hidden_dim / num_heads;

auto options = at::TensorOptions()
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
auto final_output = workspace;
int num_kv = -1;
int repo_theta = -1;
size_t offset = (batch_size * seq_length * hidden_dim);
launch_bias_add_transform_0213<T>(final_output,
final_output + offset,
final_output + 2 * offset,
(T*)input.data_ptr(),
(T*)bias.data_ptr(),
batch_size,
seq_length,
0, // seq_offset
input.size(1), // all_tokens .. unused?
hidden_dim,
num_heads,
num_kv,
-1, // rotary_dim
false, // rotate_half
false, // rotate_every_two
InferenceContext::Instance().GetCurrentStream(),
trans_count, // trans_count
input.size(1), // max_out_tokens
repo_theta);
return {at::from_blob(final_output, {batch_size, num_heads, seq_length, head_dim}, options),
at::from_blob(
final_output + offset, {batch_size, num_heads, seq_length, head_dim}, options),
at::from_blob(
final_output + 2 * offset, {batch_size, num_heads, seq_length, head_dim}, options)};
}

template <typename T>
void quantized_gemm(void* output,
T* input,
@@ -2010,7 +2091,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"DeepSpeed memory allocation for GPT inference with " #_name " (CUDA)"); \
m.def("dequantize_" #_name, \
&ds_dequantize<_dtype>, \
"DeepSpeed dequantize with " #_name " (CUDA)")
"DeepSpeed dequantize with " #_name " (CUDA)"); \
m.def("transform4d_0213_" #_name, \
&ds_transform4d_0213<_dtype>, \
"DeepSpeed transform4d 0213 with " #_name " (CUDA)"); \
m.def("bias_add_transform_0213_" #_name, \
&ds_bias_add_transform_0213<_dtype>, \
"DeepSpeed bias and transform 0213 with " #_name " (CUDA)")

DEF_OPS(fp32, float);
DEF_OPS(fp16, __half);
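
The two new kernels are registered through the existing DEF_OPS macro, so they are exposed per dtype (e.g. bias_add_transform_0213_fp16). A hedged sketch of driving them from Python, assuming the inference op module loads on your setup and its workspace has already been allocated by the normal inference path:

    import torch
    from deepspeed.ops.op_builder import InferenceBuilder

    ops = InferenceBuilder().load()
    batch, heads, seq, head_dim = 1, 12, 16, 64
    qkv = torch.randn(batch, seq, 3 * heads * head_dim, dtype=torch.half, device="cuda")
    bias = torch.zeros(3 * heads * head_dim, dtype=torch.half, device="cuda")

    # Split the fused QKV activation into three [batch, heads, seq, head_dim] tensors (trans_count=3)
    q, k, v = ops.bias_add_transform_0213_fp16(qkv, bias, heads, 3)
    # Inverse layout transform back to [batch, seq, heads, head_dim]
    q_0213 = ops.transform4d_0213_fp16(q, seq)
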
3 changes: 1 addition & 2 deletions deepspeed/__init__.py
@@ -22,6 +22,7 @@
else:
HAS_TRITON = False

from .utils import log_dist, OnDevice, logger
from . import ops
from . import module_inject

@@ -38,11 +39,9 @@
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .module_inject import replace_transformer_layer, revert_transformer_layer

from .utils import log_dist, OnDevice, logger
from .comm.comm import init_distributed

from .runtime import zero
from .runtime import DeepSpeedOptimizer, ZeROOptimizer
from .runtime.compiler import is_compile_supported

from .pipe import PipelineModule
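
With is_compile_supported re-exported at package level, user code can guard torch.compile usage in one line; a small hedged example of that pattern:

    import torch
    import deepspeed

    model = torch.nn.Linear(8, 8)
    if deepspeed.is_compile_supported():
        model = torch.compile(model)
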