Merge branch 'master' into mrwyattii/pydantic-2-support
adk9 authored Jul 16, 2024
2 parents b3804ad + 78c6c44 commit 41fc635
Showing 9 changed files with 124 additions and 30 deletions.
15 changes: 8 additions & 7 deletions accelerator/hpu_accelerator.py
@@ -3,6 +3,7 @@

# DeepSpeed Team

import functools
import os
import pkgutil
import importlib
@@ -196,31 +197,31 @@ def replay_graph(self, graph):
# Tensor operations
@property
def BFloat16Tensor(self):
return self.hpu.BFloat16Tensor
return functools.partial(torch.tensor, dtype=torch.bfloat16, device='hpu')

@property
def ByteTensor(self):
return self.hpu.ByteTensor
return functools.partial(torch.tensor, dtype=torch.uint8, device='hpu')

@property
def DoubleTensor(self):
return self.hpu.DoubleTensor
return functools.partial(torch.tensor, dtype=torch.double, device='hpu')

@property
def FloatTensor(self):
return self.hpu.FloatTensor
return functools.partial(torch.tensor, dtype=torch.float, device='hpu')

@property
def HalfTensor(self):
return self.hpu.HalfTensor
return functools.partial(torch.tensor, dtype=torch.half, device='hpu')

@property
def IntTensor(self):
return self.hpu.IntTensor
return functools.partial(torch.tensor, dtype=torch.int, device='hpu')

@property
def LongTensor(self):
return self.hpu.LongTensor
return functools.partial(torch.tensor, dtype=torch.long, device='hpu')

def pin_memory(self, tensor, align_bytes=1):
return tensor.pin_memory(self.device())
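The change above replaces the `self.hpu.*Tensor` aliases with `functools.partial` wrappers around `torch.tensor`. Below is a minimal sketch of that pattern using a hypothetical helper and the CPU device so it runs without Habana hardware; the accelerator itself binds `device='hpu'`.

```python
import functools
import torch

# Hypothetical helper illustrating the pattern: bind dtype and device once so the
# result can be called like the legacy torch.FloatTensor-style constructors.
def tensor_ctor(dtype, device="cpu"):  # the HPU accelerator uses device='hpu'
    return functools.partial(torch.tensor, dtype=dtype, device=device)

FloatTensor = tensor_ctor(torch.float)
LongTensor = tensor_ctor(torch.long)

x = FloatTensor([1.0, 2.0, 3.0])   # dtype torch.float32 on the bound device
y = LongTensor([[1, 2], [3, 4]])   # 2x2 tensor with dtype torch.int64
print(x.dtype, y.dtype, x.device)
```

One subtlety of this approach: `torch.tensor` always interprets its argument as data, so a call like `FloatTensor(5)` produces a 0-dimensional tensor holding the value 5 rather than an uninitialized 5-element tensor as the legacy constructors did.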
4 changes: 4 additions & 0 deletions blogs/deepspeed-fastgen/chinese/README.md
@@ -226,6 +226,10 @@ DeepSpeed-FastGen is [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII
* [LLaMA](https://huggingface.co/models?other=llama) and [LLaMA-2](https://huggingface.co/models?other=llama-2)
* [Mistral](https://huggingface.co/models?other=mistral)
* [OPT](https://huggingface.co/models?other=opt)
* [Falcon](https://huggingface.co/models?other=falcon)
* [Mixtral](https://huggingface.co/models?other=mixtral)
* [Phi-2](https://huggingface.co/models?other=phi-msft)
* [Qwen](https://huggingface.co/models?other=qwen)

All current models leverage [HuggingFace](https://github.com/huggingface) APIs in our backend to provide both the model weights and the model's corresponding tokenizer.

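As a usage sketch (not part of the diff), any of the newly listed model families can be served through the DeepSpeed-MII pipeline described earlier in this README; the checkpoint name below is only an example from the Mixtral family.

```python
import mii

# Example checkpoint from the newly supported Mixtral family; substitute any listed model.
pipe = mii.pipeline("mistralai/Mixtral-8x7B-v0.1")
response = pipe(["DeepSpeed is", "Seattle is"], max_new_tokens=64)
print(response)
```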
99 changes: 82 additions & 17 deletions csrc/cpu/comm/shm_interface.cpp
@@ -46,15 +46,13 @@ void initialize(int size, int rank)
if (all_ranks_local_p) { shm_initialize(size, rank, addr_string, port_string); }
}

int get_rank(int group = 0) { return world_rank; }

int get_world_size(int group = 0) { return world_size; }
void inference_all_reduce_(torch::Tensor& data, int op);

// Success - return 0
// Fail (cannot honor the request and need to fall back) - return -1
int inference_all_reduce(torch::Tensor& data, py::object op)
void inference_all_reduce_(torch::Tensor& data, int op)
{
if (!all_ranks_local_p) return -1;
assert(op == 0);
#ifdef DO_PROFILE
static double total_time = 0.0;
static double total_time_sq = 0.0;
@@ -67,11 +65,6 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
auto start = std::chrono::system_clock::now();
#endif

static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));

assert(py::int_(op.attr("value")) == ReduceOpSum);

auto numel = data.numel();

int data_size = 0;
@@ -84,7 +77,7 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
default: data_type_fallback = true;
}

if (data_type_fallback) return -1;
if (data_type_fallback) return;

all_reduce_outer_loop(data, numel, data_size);

@@ -109,13 +102,85 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
}
}
#endif
return 0;
return;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("initialize", &initialize, "shm initialize"); }

TORCH_LIBRARY(deepspeed, m)
{
m.def("inference_all_reduce(Tensor self) -> Tensor");
m.def("inference_all_reduce_(Tensor(a!) self) -> Tensor(a!)");
}

torch::Tensor inference_all_reduce_meta(const torch::Tensor& self_)
{
torch::Tensor result_ = torch::empty_like(self_);
return result_;
}

torch::Tensor& inference_all_reduce__meta(torch::Tensor& self_) { return self_; }

torch::Tensor& inference_all_reduce__cpu(torch::Tensor& self_)
{
TORCH_INTERNAL_ASSERT(self_.device().type() == torch::DeviceType::CPU);
torch::Tensor self_tensor = self_.contiguous();
inference_all_reduce_(self_tensor, 0);
return self_;
}

torch::Tensor inference_all_reduce_cpu(const torch::Tensor& self_)
{
torch::Tensor result = self_.clone();
inference_all_reduce__cpu(result);
return result;
}

#include <ATen/FunctionalTensorWrapper.h>
// The boilerplate functionalization logic, that teaches functionalization
// how to map x_() calls into x() calls.
// Long term, we'd like to not require users to write this logic.
// HOWEVER, if you have a custom op that is mutable,
// You will still need to write an out-of-place version of that op!
at::Tensor& inference_all_reduce__functionalization_glue(at::Tensor& x)
{
// We expect all tensor inputs to our op to be "functional tensors"
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(x));
// First, sync and unwrap and functional tensors
at::functionalization::impl::sync(x);
auto x_ = at::functionalization::impl::from_functional_tensor(x);
// Grab the dispatcher entry corresponding to the out-of-place op, "x"
static auto op_handle = c10::Dispatcher::singleton()
// specify namespace::op_name, op_overload_name
.findSchemaOrThrow("deepspeed::inference_all_reduce", "")
// Specify the C++ schema of the out-of-place op.
.typed<at::Tensor(const at::Tensor&)>();
// Next, redispatch to the out-of-place op, x() (user called x_, we call x)
at::Tensor tmp_output;
{
at::AutoDispatchSkipFunctionalize guard;
tmp_output = op_handle.call(x_);
}
// Finally, tell functionalization about this mutation.
at::functionalization::impl::replace_(x, tmp_output);
at::functionalization::impl::commit_update(x);
at::functionalization::impl::sync(x);
return x;
}

TORCH_LIBRARY_IMPL(deepspeed, CPU, m)
{
m.impl("inference_all_reduce", inference_all_reduce_cpu);
m.impl("inference_all_reduce_", inference_all_reduce__cpu);
}

TORCH_LIBRARY_IMPL(deepspeed, Meta, m)
{
m.impl("inference_all_reduce", inference_all_reduce_meta);
m.impl("inference_all_reduce_", inference_all_reduce__meta);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
TORCH_LIBRARY_IMPL(deepspeed, Functionalize, m)
{
m.def("initialize", &initialize, "shm initialize");
m.def("get_rank", &get_rank, "get rank");
m.def("get_world_size", &get_world_size, "get world size");
m.def("inference_all_reduce", &inference_all_reduce, "low latency all_reduce implementation");
m.impl("inference_all_reduce_", inference_all_reduce__functionalization_glue);
}
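The net effect of this refactor is that the SHM all-reduce is no longer a pybind11 function returning a success code, but a `torch.library` operator (`deepspeed::inference_all_reduce_`, plus an out-of-place variant) with CPU, Meta, and Functionalize registrations, so it can participate in torch.compile tracing. A hedged sketch of how the op is reached from Python, mirroring the `deepspeed/comm/torch.py` change below and assuming the extension has been built and loaded:

```python
import torch
import torch.distributed as dist

def inference_all_reduce(tensor: torch.Tensor) -> None:
    # Fast path: the SHM op registered above (CPU tensors, SUM reduction, all ranks local).
    if hasattr(torch.ops, "deepspeed") and hasattr(torch.ops.deepspeed, "inference_all_reduce_"):
        torch.ops.deepspeed.inference_all_reduce_(tensor)  # in-place variant
    else:
        # Fallback to the standard collective when the extension is unavailable.
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
```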
5 changes: 3 additions & 2 deletions deepspeed/comm/torch.py
@@ -151,11 +151,12 @@ def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, asyn
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

@compiler.disable
def inference_all_reduce(self, tensor, op, group=None):
if self.shm_comm_op == None or self.shm_comm_op.inference_all_reduce(tensor, op) == -1:
if not hasattr(torch.ops, 'deepspeed') or not hasattr(torch.ops.deepspeed, 'inference_all_reduce_'):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=False)
else:
return torch.ops.deepspeed.inference_all_reduce_(tensor)

@compiler.disable
def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
20 changes: 18 additions & 2 deletions deepspeed/launcher/runner.py
@@ -117,6 +117,12 @@ def parse_args(args=None):
help="(optional) IP address of node 0, will be "
"inferred via 'hostname -I' if not specified.")

parser.add_argument("--node_rank",
default=-1,
type=int,
help="ID of each node in the range [0:N). "
"Only required when --no_ssh is set.")

parser.add_argument("--launcher",
default=PDSH_LAUNCHER,
type=str,
@@ -145,6 +151,10 @@ def parse_args(args=None):
help="Do not pass local_rank as an argument when calling "
"the user's training script.")

parser.add_argument("--no_ssh",
action="store_true",
help="Launch training independently on each node without ssh setup.")

parser.add_argument("--no_ssh_check",
action="store_true",
help="Do not perform ssh check in multi-node launcher model")
@@ -427,7 +437,7 @@ def main(args=None):
env = os.environ.copy()

# validate that passwordless-ssh is working properly with this hostfile
if multi_node_exec and not args.no_ssh_check:
if multi_node_exec and not args.no_ssh_check and not args.no_ssh:
first_host = list(active_resources.keys())[0]
try:
ssh_check_cmd = "ssh -o PasswordAuthentication=no "
@@ -483,16 +493,22 @@ def main(args=None):
if args.elastic_training:
assert not args.no_local_rank, "--no_local_rank argument is not supported in Elastic training"

if args.no_ssh:
assert (0 <= args.node_rank <
len(active_resources)), "Launching training without ssh, but --node_rank is not set correctly."

# encode world info as base64 to make it easier to pass via command line
world_info_base64 = encode_world_info(active_resources)

multi_node_exec = args.force_multi or len(active_resources) > 1
multi_node_exec = (args.force_multi or len(active_resources) > 1) and not args.no_ssh

if not multi_node_exec:
deepspeed_launch = [
sys.executable, "-u", "-m", "deepspeed.launcher.launch", f"--world_info={world_info_base64}",
f"--master_addr={args.master_addr}", f"--master_port={args.master_port}"
]
if args.no_ssh:
deepspeed_launch.append(f"--node_rank={args.node_rank}")
if args.no_python:
deepspeed_launch.append("--no_python")
if args.module:
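In practice (hostfile, addresses, and script names here are hypothetical), the ssh-free flow means running the same command on every node with only `--node_rank` changed, for example `deepspeed --hostfile=hostfile --no_ssh --node_rank=0 --master_addr=192.168.0.1 --master_port=29500 train.py` on the first node and the same command with `--node_rank=1` on the second. Each node then launches only its local ranks, and the `--node_rank` value must fall in `[0, number_of_nodes)`, as the assertion above enforces.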
2 changes: 1 addition & 1 deletion deepspeed/module_inject/auto_tp.py
@@ -134,7 +134,7 @@ def is_load_module(module):
load_layer_names = [
"LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
"MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding",
"Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding"
"Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding"
]
return module.__class__ in load_layers or module._get_name() in load_layer_names

5 changes: 5 additions & 0 deletions deepspeed/runtime/bf16_optimizer.py
@@ -540,6 +540,11 @@ def param_groups(self):
"""Forward the wrapped optimizer's parameters."""
return self.optimizer.param_groups

@property
def state(self):
"""Forward the wrapped optimizer's states."""
return self.optimizer.state

def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx):
assert self.immediate_grad_update
self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=True)
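The new `state` property mirrors the existing `param_groups` forwarding so that code expecting a plain `torch.optim.Optimizer` (for example LR schedulers or checkpoint inspection) can read the wrapped optimizer's state. A minimal sketch of the forwarding pattern with a generic wrapper class (names hypothetical):

```python
import torch

class OptimizerWrapper:
    """Toy wrapper that forwards param_groups and state to the inner optimizer."""

    def __init__(self, inner: torch.optim.Optimizer):
        self.optimizer = inner

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    @property
    def state(self):
        return self.optimizer.state

params = [torch.nn.Parameter(torch.zeros(4))]
wrapped = OptimizerWrapper(torch.optim.Adam(params))
print(len(wrapped.param_groups), len(wrapped.state))  # 1 0 (state fills after the first step)
```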
3 changes: 3 additions & 0 deletions deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py
@@ -8,6 +8,7 @@

from deepspeed.ops.op_builder import AsyncIOBuilder
from deepspeed import comm as dist
import torch

from deepspeed.runtime.swap_tensor.constants import *
from deepspeed.runtime.swap_tensor.utils import swap_in_tensors, swap_out_tensors, print_object
@@ -154,6 +155,8 @@ def swap_out_gradients(self, parameter, gradient_offsets, gradient_tensors):

def _complete_swap_out(self, swap_out_type):
self.swap_ops[swap_out_type].wait()
for buffer in self.swap_ops[swap_out_type].state_buffers:
buffer = torch.Tensor()
self.swap_buffer_manager.free(self.swap_ops[swap_out_type].allocated_buffers)
self.swap_ops[swap_out_type] = None

1 change: 0 additions & 1 deletion deepspeed/runtime/zero/stage3.py
@@ -338,7 +338,6 @@ def __init__(
self.params_in_ipg_bucket = []

self.params_already_reduced = {}
self.is_gradient_accumulation_boundary = True
self._release_ipg_buffers()
self.previous_reduced_grads = None

