DeepSpeed content for 1.19.0
Signed-off-by: SW publisher <[email protected]>
SW publisher authored and Jenkins committed Dec 19, 2024
1 parent d254d75 commit 9c2d043
Showing 191 changed files with 12,655 additions and 1,104 deletions.
89 changes: 0 additions & 89 deletions .pre-commit-config.yaml

This file was deleted.

50 changes: 2 additions & 48 deletions CODEOWNERS
@@ -5,52 +5,6 @@
 # Learn more about CODEOWNERS syntax here:
 # https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
 
+* [email protected] [email protected] [email protected] [email protected] [email protected] [email protected] [email protected]
 
-# top-level repo folders
-/.github/ @mrwyattii @loadams
-/azure/ @mrwyattii @awan-10
-/benchmarks/ @awan-10 @mrwyattii
-/bin/ @mrwyattii
-/csrc/ @awan-10 @mrwyattii @cmikeh2 @arashb
-/deepspeed/ @mrwyattii
-/docker/ @mrwyattii @awan-10
-/docs/ @mrwyattii
-/examples/ @awan-10 @mrwyattii
-/op_builder/ @mrwyattii @cmikeh2
-/release/ @loadams @mrwyattii
-/requirements/ @loadams @mrwyattii
-/scripts/ @mrwyattii @awan-10
-/tests/ @mrwyattii @tjruwase @loadams
-
-# deepspeed
-/deepspeed/autotuning/ @mrwyattii
-/deepspeed/checkpoint/ @tjruwase
-/deepspeed/comm/ @awan-10
-/deepspeed/compression/ @minjiaz @xiaoxiawu-microsoft @conglongli
-/deepspeed/elasticity/ @mrwyattii @awan-10
-/deepspeed/launcher/ @mrwyattii @awan-10
-/deepspeed/module_inject/ @mrwyattii @awan-10 @cmikeh2 @arashb
-/deepspeed/moe/ @awan-10
-/deepspeed/monitor/ @awan-10 @mrwyattii
-/deepspeed/nebula/ @tjruwase @mrwyattii
-/deepspeed/ops/ @mrwyattii @awan-10 @cmikeh2 @arashb
-/deepspeed/pipe/ @ShadenSmith @duli2012
-/deepspeed/profiling/ @ShijieZZZZ
-/deepspeed/utils/ @mrwyattii @tjruwase @awan-10
-
-# inference
-/deepspeed/inference/ @mrwyattii @awan-10 @cmikeh2 @arashb
-/deepspeed/model_implementations/ @mrwyattii @awan-10 @cmikeh2 @arashb
-
-# training
-/deepspeed/runtime/ @mrwyattii @tjruwase
-/deepspeed/runtime/activation_checkpointing/ @mrwyattii @tjruwase
-/deepspeed/runtime/checkpoint_engine/ @tjruwase @mrwyattii
-/deepspeed/runtime/comm/ @awan-10
-/deepspeed/runtime/compression/ @awan-10 @conglongli
-/deepspeed/runtime/data_pipeline/ @conglongli
-/deepspeed/runtime/fp16/ @mrwyattii @tjruwase
-/deepspeed/runtime/fp16/onebit/ @conglongli @awan-10
-/deepspeed/runtime/pipe/ @ShadenSmith @duli2012
-/deepspeed/runtime/swap_tensor/ @tjruwase @mrwyattii
-/deepspeed/runtime/zero/ @tjruwase @mrwyattii
+CODEOWNERS [email protected] [email protected] [email protected]
45 changes: 34 additions & 11 deletions accelerator/hpu_accelerator.py
@@ -3,6 +3,7 @@
 
 # DeepSpeed Team
 
+import functools
 import os
 import pkgutil
 import importlib
@@ -17,6 +18,7 @@ def __init__(self):
         self._name = 'hpu'
         self._communication_backend_name = 'hccl'
         self._compile_backend = "hpu_backend"
+        self.apply_hpu_workarounds()
         try:
             import habana_frameworks.torch.hpu as hpu
             hpu.setDeterministic(True)
@@ -27,6 +29,15 @@ def __init__(self):
 
         self.fp16_supported = None
 
+    def apply_hpu_workarounds(self):
+
+        def update_wa_env_var(key, value):
+            if key not in os.environ.keys():
+                os.environ[key] = value
+
+        update_wa_env_var("PT_HPU_LAZY_ACC_PAR_MODE", "0")
+        update_wa_env_var("PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES", "0")
+
     # Device APIs
     def is_synchronized_device(self):
         return False
@@ -41,9 +52,8 @@ def handles_memory_backpressure(self):
         return True
 
     def device_name(self, device_index=None):
-        if device_index is None:
-            return 'hpu'
-        return 'hpu:{}'.format(device_index)
+        # ignoring device_index.
+        return 'hpu'
 
     def device(self, device_index=None):
         return torch.device(self.device_name(device_index))
@@ -196,31 +206,31 @@ def replay_graph(self, graph):
     # Tensor operations
     @property
     def BFloat16Tensor(self):
-        return self.hpu.BFloat16Tensor
+        return functools.partial(torch.tensor, dtype=torch.bfloat16, device='hpu')
 
     @property
     def ByteTensor(self):
-        return self.hpu.ByteTensor
+        return functools.partial(torch.tensor, dtype=torch.uint8, device='hpu')
 
     @property
     def DoubleTensor(self):
-        return self.hpu.DoubleTensor
+        return functools.partial(torch.tensor, dtype=torch.double, device='hpu')
 
     @property
     def FloatTensor(self):
-        return self.hpu.FloatTensor
+        return functools.partial(torch.tensor, dtype=torch.float, device='hpu')
 
     @property
     def HalfTensor(self):
-        return self.hpu.HalfTensor
+        return functools.partial(torch.tensor, dtype=torch.half, device='hpu')
 
     @property
     def IntTensor(self):
-        return self.hpu.IntTensor
+        return functools.partial(torch.tensor, dtype=torch.int, device='hpu')
 
     @property
     def LongTensor(self):
-        return self.hpu.LongTensor
+        return functools.partial(torch.tensor, dtype=torch.long, device='hpu')
 
     def pin_memory(self, tensor, align_bytes=1):
         return tensor.pin_memory(self.device())
@@ -289,6 +299,14 @@ def get_op_builder(self, class_name):
         else:
             return self.class_dict['NotImplementedBuilder'] if 'NotImplementedBuilder' in self.class_dict else None
 
+    #shall be removed once moving to torch.compile
+    def wrap_in_hpu_graph(self, module):
+        if self.hpu.is_lazy():
+            module = self.hpu.wrap_in_hpu_graph(module)
+        else:
+            print("Warning: hpu graphs in eager mode is not supported, ignoring")
+        return module
+
     def build_extension(self):
         from torch.utils.cpp_extension import BuildExtension
         return BuildExtension
@@ -297,7 +315,12 @@ def export_envs(self):
         return []
 
     def visible_devices_envs(self):
-        return ['HABANA_VISIBLE_MODULES']
+        # TODO SW-195658: remove WA to not return HABANA_VISIBLE_MODULES once SW-195657 is resolved
+        # Current way deepspeed set this env var is not applicable with all HPU instances
+        # User has to follow instructions in:
+        # https://docs.habana.ai/en/latest/PyTorch/Reference/PT_Multiple_Tenants_on_HPU/Multiple_Workloads_Single_Docker.html
+        # keeping CUDA_VISIBLE_DEVICES
+        return ['CUDA_VISIBLE_DEVICES'] #['HABANA_VISIBLE_MODULES']
 
     def set_visible_devices_envs(self, current_env, local_accelerator_ids):
         for env in self.visible_devices_envs():
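
The tensor-type properties above now return functools.partial factories over torch.tensor instead of the removed self.hpu.*Tensor classes, so callers get ordinary callables that allocate directly on the HPU device. A minimal usage sketch, assuming a Gaudi machine with habana_frameworks installed so that get_accelerator() resolves to this HPU_Accelerator (not part of the commit itself):

# Minimal sketch; assumes DeepSpeed is running on an HPU system.
import torch
from deepspeed.accelerator import get_accelerator

acc = get_accelerator()                  # HPU_Accelerator on Gaudi
x = acc.BFloat16Tensor([1.0, 2.0, 3.0])  # functools.partial(torch.tensor, dtype=torch.bfloat16, device='hpu')
y = acc.LongTensor([1, 2, 3])            # same pattern with dtype=torch.long
print(x.dtype, x.device)                 # expected: torch.bfloat16 hpu:0
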
1 change: 1 addition & 0 deletions build.txt
@@ -0,0 +1 @@
+hpu.synapse.v1.19.0
22 changes: 15 additions & 7 deletions csrc/fp_quantizer/fp_quantize.cpp
@@ -22,25 +22,20 @@
                            stochastic_rounding);                                       \
     }
 
-at::Tensor quantize(torch::Tensor& val,
+at::Tensor quantize(torch::Tensor& out,
+                    torch::Tensor& val,
                     int group_size,
                     int stochastic_rounding,
                     int q_bits,
                     int q_mantisa_bits)
 {
     int total_elems = at::numel(val);
-    auto options = at::TensorOptions()
-                       .dtype(torch::kInt8)
-                       .layout(val.layout())
-                       .device(val.device())
-                       .requires_grad(false);
     float q_range = q_bits == 8 ? (q_mantisa_bits == 3 ? 480.0 : 114688.0) :  // fp8 ranges
                     (q_bits == 12 ? 510.0 :                                   // fp12 range
                      (q_bits == 6 ? 28.0 :                                    // fp6 range
                       6.0));  // fp4 range (using power 2); TODO (Reza): add the power-4
                               // in case accuracy is not matching!
     int num_groups = total_elems / group_size;
-    auto out = torch::empty({num_groups, group_size * q_bits / 8 + 4}, options);
 
     DISPATCH_QUANTIZE(kHalf, __half, 23, 8);
 #ifdef BF16_AVAILABLE
@@ -108,9 +103,22 @@ void selective_dequantize(torch::Tensor& val,
 #endif
 }
 
+at::Tensor get_scales(torch::Tensor& out, int num_groups)
+{
+    auto options = at::TensorOptions()
+                       .dtype(torch::kFloat)
+                       .layout(at::kStrided)
+                       .device(at::kCUDA)
+                       .requires_grad(false);
+    auto scales =
+        torch::from_blob(out.data_ptr(), {num_groups, 1}, {out.stride(0) / 4, 1}, options);
+    return scales;
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("quantize", &quantize, "quantize function");
     m.def("dequantize", &dequantize, "dequantize function");
+    m.def("get_scales", &get_scales, "get scales function");
     m.def("selective_dequantize", &selective_dequantize, "selective dequantize function");
 }
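
With the new signature the packed output buffer is allocated by the caller and passed in as out instead of being created inside quantize(); each group stores group_size * q_bits / 8 payload bytes followed by 4 bytes holding the fp32 scale, which get_scales() exposes as a strided float view via torch::from_blob. A hedged Python-side sketch of that layout — fp_quant_module is a hypothetical handle to the compiled extension (the real wrapper lives under deepspeed.ops and may differ):

# Hedged sketch only; fp_quant_module stands in for the built fp_quantizer extension.
import torch

group_size, q_bits = 128, 8
val = torch.randn(4096, dtype=torch.half, device='cuda')
num_groups = val.numel() // group_size

# Caller now owns the packed buffer: payload bytes per group plus 4 trailing scale bytes.
out = torch.empty((num_groups, group_size * q_bits // 8 + 4),
                  dtype=torch.int8, device=val.device)

# fp_quant_module.quantize(out, val, group_size, 0, q_bits, 3)  # stochastic_rounding=0, q_mantisa_bits=3
# scales = fp_quant_module.get_scales(out, num_groups)          # fp32 view over each group's scale bytes
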
2 changes: 2 additions & 0 deletions csrc/fp_quantizer/fp_quantize.cu
@@ -15,7 +15,9 @@
 #include <cuda_fp16.h>
 #include <curand_kernel.h>
 
+#ifdef BF16_AVAILABLE
 #include <cuda_bf16.h>
+#endif
 #include <cuda_runtime_api.h>
 
 using ROp = reduce::ROpType;
2 changes: 2 additions & 0 deletions csrc/fp_quantizer/includes/fp_quantize.h
@@ -10,7 +10,9 @@
 
 #include <cuda_fp16.h>
 
+#ifdef BF16_AVAILABLE
 #include <cuda_bf16.h>
+#endif
 #include <cuda_runtime_api.h>
 #include <stdio.h>
 
9 changes: 6 additions & 3 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -452,14 +452,17 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
                                            unsigned layer_id,
                                            unsigned num_layers,
                                            at::Tensor& alibi,
-                                           float rope_theta)
+                                           float rope_theta,
+                                           bool is_prompt,
+                                           std::optional<at::Tensor> token_idx,
+                                           std::optional<at::Tensor> position_ids)
 {
     unsigned bsz = query_key_value.size(0);
     unsigned seq_len = query_key_value.size(1);
     int k = query_key_value.size(2) / (heads + 2 * (num_kv > 0 ? num_kv : heads));
     unsigned hidden_dim = heads * k;
 
-    bool is_prompt = (seq_len > 1);
+    is_prompt = (seq_len > 1);
 
     if (is_prompt) InferenceContext::Instance().reset_tokens(seq_len);
     unsigned soft_len = InferenceContext::Instance().current_tokens();
@@ -2028,7 +2031,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
           "DeepSpeed memory allocation for GPT inference with " #_name " (CUDA)");    \
     m.def("dequantize_" #_name,                                                       \
           &ds_dequantize<_dtype>,                                                     \
-          "DeepSpeed dequantize with " #_name " (CUDA)")
+          "DeepSpeed dequantize with " #_name " (CUDA)");
 
     DEF_OPS(fp32, float);
     DEF_OPS(fp16, __half);
2 changes: 1 addition & 1 deletion deepspeed/checkpoint/zero_checkpoint.py
@@ -105,7 +105,7 @@ def _strip_tensor_paddings(self, sd):
             if group_paddings[key] == 0:
                 continue
             for state_name, state_value in group_state.items():
-                if state_name != "step" and torch.is_tensor(state_value):
+                if state_name != "step" and torch.is_tensor(state_value) and state_value.dim():
                     raw_length = state_value.numel() - group_paddings[key]
                     group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone()
                 else:
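
The added state_value.dim() check skips zero-dimensional optimizer state (for example a scalar step counter stored as a tensor), since torch.narrow along dim 0 requires at least one dimension. A small standalone illustration of the guard, assuming only stock PyTorch (not DeepSpeed code):

# Standalone illustration; not part of the commit.
import torch

padded = torch.arange(10.)             # 1-D state: padding can be stripped with narrow()
print(torch.narrow(padded, 0, 0, 8))   # keeps the first 8 elements

scalar_state = torch.tensor(3.)        # 0-D state: narrow() raises, so it is now left untouched
try:
    torch.narrow(scalar_state, 0, 0, 1)
except (RuntimeError, IndexError) as err:
    print("skipped 0-dim state:", err)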