Merge branch 'microsoft:master' into abhilash1910_cpu_fix
abhilash1910 authored Jan 4, 2024
2 parents b98b4df + 8342725 commit 81e2fb1
Showing 97 changed files with 2,145 additions and 569 deletions.
13 changes: 13 additions & 0 deletions accelerator/abstract_accelerator.py
@@ -193,6 +193,19 @@ def communication_backend_name(self):
    def is_triton_supported(self):
        ...

    # Graph operations
    @abc.abstractmethod
    def create_graph(self):
        ...

    @abc.abstractmethod
    def capture_to_graph(self, graph, pool=None, stream=None):
        ...

    @abc.abstractmethod
    def replay_graph(self, graph):
        ...

    # Tensor operations
    @property
    @abc.abstractmethod
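Taken together, these three abstract methods give callers a backend-agnostic way to capture and replay computation graphs. A minimal usage sketch, assuming only the API above plus DeepSpeed's `get_accelerator()` helper (`fn` and its input are illustrative, not part of this commit):

```python
# Sketch only: run `fn` through the accelerator graph API.
# On CUDA this captures and replays a CUDA graph; on CPU/MPS/NPU create_graph()
# returns None, capture_to_graph() yields a no-op context, and replay_graph()
# does nothing, so `fn` simply runs eagerly.
from deepspeed.accelerator import get_accelerator

def run_with_graph(fn, static_input):
    accel = get_accelerator()
    graph = accel.create_graph()
    with accel.capture_to_graph(graph):
        static_output = fn(static_input)   # recorded on CUDA, executed directly elsewhere
    accel.replay_graph(graph)              # re-executes the recorded work on CUDA
    return static_output
```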
12 changes: 11 additions & 1 deletion accelerator/cpu_accelerator.py
@@ -198,8 +198,18 @@ def is_fp16_supported(self):
    def supported_dtypes(self):
        return [torch.float, torch.bfloat16]

    # Tensor operations
    # Graph operations
    def create_graph(self):
        return None

    def capture_to_graph(self, graph, pool=None, stream=None):
        from deepspeed.runtime.utils import noop_context
        return noop_context()

    def replay_graph(self, graph):
        return

    # Tensor operations
    @property
    def BFloat16Tensor(self):
        return torch.BFloat16Tensor
11 changes: 11 additions & 0 deletions accelerator/cuda_accelerator.py
@@ -222,6 +222,17 @@ def is_triton_supported(self):
        else:
            return False

    # Graph operations
    def create_graph(self):
        return torch.cuda.CUDAGraph()

    def capture_to_graph(self, graph, pool=None, stream=None):
        return torch.cuda.graph(graph, pool, stream)

    def replay_graph(self, graph):
        graph.replay()
        return

    # Tensor operations

    @property
11 changes: 11 additions & 0 deletions accelerator/mps_accelerator.py
@@ -166,6 +166,17 @@ def communication_backend_name(self):
    def is_triton_supported(self):
        return False

    # Graph operations
    def create_graph(self):
        return None

    def capture_to_graph(self, graph, pool=None, stream=None):
        from deepspeed.runtime.utils import noop_context
        return noop_context()

    def replay_graph(self, graph):
        return

    # Tensor operations
    @property
    def BFloat16Tensor(self):
11 changes: 11 additions & 0 deletions accelerator/npu_accelerator.py
@@ -176,6 +176,17 @@ def communication_backend_name(self):
    def is_triton_supported(self):
        return False

    # Graph operations
    def create_graph(self):
        return None

    def capture_to_graph(self, graph, pool=None, stream=None):
        from deepspeed.runtime.utils import noop_context
        return noop_context()

    def replay_graph(self, graph):
        return

    # Tensor operations

    @property
2 changes: 2 additions & 0 deletions blogs/deepspeed-fastgen/README.md
@@ -229,6 +229,8 @@ We currently support the following model architectures in this alpha release of
* [Mistral](https://huggingface.co/models?other=mistral)
* [OPT](https://huggingface.co/models?other=opt)
* [Falcon](https://huggingface.co/models?other=falcon)
* [Mixtral](https://huggingface.co/models?other=mixtral)
* [Phi-2](https://huggingface.co/models?other=phi-msft)

All current models leverage [HuggingFace](https://github.com/huggingface) APIs in our backend to provide both the model weights and the model's corresponding tokenizer.
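With Mixtral and Phi-2 added to the supported list above, serving one of them follows the same quick-start pattern as the other FastGen models. A hedged sketch, assuming the `deepspeed-mii` pipeline front end that FastGen exposes (an assumption not shown in this diff; the model id and prompts are illustrative):

```python
# Sketch only: FastGen quick start via deepspeed-mii; adjust the model id as needed.
import mii

pipe = mii.pipeline("mistralai/Mixtral-8x7B-v0.1")   # newly supported Mixtral architecture
response = pipe(["DeepSpeed is", "Seattle is"], max_new_tokens=128)
print(response)
```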

3 changes: 2 additions & 1 deletion csrc/includes/cpu_adagrad.h
@@ -194,14 +194,15 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
#elif defined(__ENABLE_CANN__)
        if (dev_params) {
            size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
            if (half_precision) memoryCopySize /= 2;
            if (half_precision) memcpy_size /= 2;
            aclrtMemcpy(dev_params + t,
                        memcpy_size,
                        _doubled_buffer[_buf_index],
                        memcpy_size,
                        aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

            _buf_index = !_buf_index;
        }
#endif
    }
    *rounded_size = new_rounded_size;
6 changes: 3 additions & 3 deletions csrc/includes/cpu_adam.h
@@ -215,8 +215,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
#if defined(__ENABLE_CUDA__)
        if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
        if ((t / TILE) >= 2) { aclrtSynchronizeStream((_streams[_buf_index].stream());
        }
        if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
        for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
@@ -274,14 +273,15 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
#elif defined(__ENABLE_CANN__)
        if (dev_params) {
            size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
            if (half_precision) memoryCopySize /= 2;
            if (half_precision) memcpy_size /= 2;
            aclrtMemcpy(dev_params + t,
                        memcpy_size,
                        _doubled_buffer[_buf_index],
                        memcpy_size,
                        aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

            _buf_index = !_buf_index;
        }
#endif
    }
    *rounded_size = new_rounded_size;
3 changes: 2 additions & 1 deletion csrc/includes/cpu_lion.h
@@ -223,14 +223,15 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size,
#elif defined(__ENABLE_CANN__)
        if (dev_params) {
            size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
            if (half_precision) memoryCopySize /= 2;
            if (half_precision) memcpy_size /= 2;
            aclrtMemcpy(dev_params + t,
                        memcpy_size,
                        _doubled_buffer[_buf_index],
                        memcpy_size,
                        aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

            _buf_index = !_buf_index;
        }
#endif
    }
    *rounded_size = new_rounded_size;
2 changes: 1 addition & 1 deletion deepspeed/comm/comm.py
@@ -595,7 +595,7 @@ def get_all_ranks_from_group(group=None):
        while True:
            group_ranks.append(cdb.get_global_rank(group, rank))
            rank += 1
    except RuntimeError:
    except (RuntimeError, ValueError):
        pass
    return group_ranks

6 changes: 3 additions & 3 deletions deepspeed/inference/engine.py
@@ -531,11 +531,11 @@ def _create_cuda_graph(self, *inputs, **kwargs):
        get_accelerator().current_stream().wait_stream(cuda_stream)

        # create cuda_graph and assign static_inputs and static_outputs
        self._cuda_graphs = torch.cuda.CUDAGraph()
        self._cuda_graphs = get_accelerator().create_graph()
        self.static_inputs = inputs
        self.static_kwargs = kwargs

        with torch.cuda.graph(self._cuda_graphs):
        with get_accelerator().capture_to_graph(self._cuda_graphs):
            self.static_output = self.module(*self.static_inputs, **self.static_kwargs)

        self.cuda_graph_created = True
@@ -547,7 +547,7 @@ def _graph_replay(self, *inputs, **kwargs):
        for k in kwargs:
            if torch.is_tensor(kwargs[k]):
                self.static_kwargs[k].copy_(kwargs[k])
        self._cuda_graphs.replay()
        get_accelerator().replay_graph(self._cuda_graphs)
        return self.static_output

    def model_times(self):
2 changes: 1 addition & 1 deletion deepspeed/inference/v2/checkpoint/huggingface_engine.py
@@ -61,7 +61,7 @@ def model_has_safetensors(model_name_or_path: str) -> bool:
        # We need to download the checkpoint files from HF
        if model_has_safetensors(self.model_name_or_path):
            # Prioritize downloading safetensors if they are available
            allow_patterns = ["*.safetensors", "*.json", "*.pt"]
            allow_patterns = ["*.safetensors", "*.json"]
        else:
            # Fallback to bin files when safetensors are not present
            allow_patterns = ["*.bin", "*.json", "*.pt"]
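The change above narrows the safetensors filter so stray `*.pt` files are no longer pulled when a safetensors checkpoint exists. A hedged sketch of how such patterns drive a Hugging Face Hub download (the repo id is illustrative, and this direct call is an assumption rather than the code in this file):

```python
# Sketch only: download just the files matching the allow patterns.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    "mistralai/Mixtral-8x7B-v0.1",                 # illustrative repo id
    allow_patterns=["*.safetensors", "*.json"],    # skip *.bin / *.pt when safetensors exist
)
print(local_dir)
```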
10 changes: 10 additions & 0 deletions deepspeed/inference/v2/engine_factory.py
@@ -17,7 +17,9 @@
    OPTPolicy,
    Llama2Policy,
    MistralPolicy,
    MixtralPolicy,
    FalconPolicy,
    PhiPolicy,
)
from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata
@@ -105,8 +107,16 @@ def build_hf_engine(path: str,
        assert version.parse(transformers.__version__) >= version.parse("4.34.0"), \
            f"Mistral requires transformers >= 4.34.0, you have version {transformers.__version__}"
        policy = MistralPolicy(model_config, checkpoint_engine=checkpoint_engine)
    elif model_config.model_type == "mixtral":
        # Ensure we're using the correct version of transformers for mixtral
        import transformers
        assert version.parse(transformers.__version__) >= version.parse("4.36.1"), \
            f"Mixtral requires transformers >= 4.36.1, you have version {transformers.__version__}"
        policy = MixtralPolicy(model_config, checkpoint_engine=checkpoint_engine)
    elif model_config.model_type == "falcon":
        policy = FalconPolicy(model_config, checkpoint_engine=checkpoint_engine)
    elif model_config.model_type == "phi-msft":
        policy = PhiPolicy(model_config, checkpoint_engine=checkpoint_engine)
    else:
        raise ValueError(f"Unsupported model type {model_config.model_type}")

2 changes: 1 addition & 1 deletion deepspeed/inference/v2/kernels/ragged_ops/__init__.py
@@ -10,4 +10,4 @@
from .logits_gather import *
from .moe_gather import *
from .moe_scatter import *
from .top_1_gating import *
from .top_k_gating import *
15 changes: 15 additions & 0 deletions deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#define TOP_K_SWITCH(N_TOP_K, ...)             \
    [&] {                                      \
        if (1 == N_TOP_K) {                    \
            constexpr int CONST_TOP_K = 1;     \
            __VA_ARGS__();                     \
        } else if (2 == N_TOP_K) {             \
            constexpr int CONST_TOP_K = 2;     \
            __VA_ARGS__();                     \
        }                                      \
    }()
@@ -13,6 +13,8 @@
(C_TYPE*)k.data_ptr(), \
(C_TYPE*)v.data_ptr(), \
(C_TYPE*)inv_freq_ptr, \
rotary_dim, \
theta_base, \
batch_wrapper, \
qkv_stride, \
kv_cache_stride, \
@@ -51,6 +53,9 @@ void kv_trained_rotary_embeddings(torch::Tensor& kv_cache,
    TORCH_CHECK(n_tokens == k.size(0));
    TORCH_CHECK(n_tokens == v.size(0));

    const float theta_base = 0.f;
    const int32_t rotary_dim = inv_freq.size(0) * 2;

    // Dimensions
    const int32_t block_size = kv_cache.size(1);
    const int32_t n_kv_heads = kv_cache.size(3);
@@ -91,6 +96,8 @@ void kv_rotary_embeddings(torch::Tensor& kv_cache,
                          torch::Tensor& q,
                          torch::Tensor& k,
                          torch::Tensor& v,
                          const int32_t rotary_dim,
                          const float theta_base,
                          torch::Tensor& batch_metadata,
                          torch::Tensor& seq_metadata,
                          torch::Tensor& tokens_to_seq,