Merge branch 'master' into mrwyattii/pydantic-2-support
loadams authored May 22, 2024
2 parents fcee6a7 + 995ba11 commit d80508d
Showing 17 changed files with 84 additions and 34 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/cpu-inference.yml
@@ -97,5 +97,5 @@ jobs:
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
# LOCAL_SIZE=2 enforce CPU to report 2 devices, this helps run the test on github default runner
- LOCAL_SIZE=2 COLUMNS=240 TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'seq_inference' unit/
- LOCAL_SIZE=2 COLUMNS=240 TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference_ops' -m 'inference' unit/
+ LOCAL_SIZE=2 COLUMNS=240 HF_HOME=~/tmp/hf_home/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'seq_inference' unit/
+ LOCAL_SIZE=2 COLUMNS=240 HF_HOME=~/tmp/hf_home/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference_ops' -m 'inference' unit/
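Note: these workflow updates replace the deprecated TRANSFORMERS_CACHE variable with the umbrella HF_HOME directory. A minimal sketch of the cache resolution the change assumes (hypothetical snippet, not repository code; model files conventionally land under $HF_HOME/hub):

    import os

    # HF_HOME is the umbrella Hugging Face directory; when unset, libraries
    # fall back to ~/.cache/huggingface. The model cache sits in its hub/ subdir.
    hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
    model_cache = os.path.join(hf_home, "hub")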
4 changes: 2 additions & 2 deletions .github/workflows/cpu-torch-latest.yml
@@ -50,5 +50,5 @@ jobs:
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
- TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.3"
- TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.3"
+ HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.3"
+ HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.3"
2 changes: 1 addition & 1 deletion .github/workflows/setup-venv/action.yml
@@ -22,7 +22,7 @@ runs:
- id: set-env-vars
run: |
echo TEST_DATA_DIR=/blob/ >> $GITHUB_ENV
- echo TRANSFORMERS_CACHE=/blob/transformers_cache/ >> $GITHUB_ENV
+ echo HF_HOME=/blob/hf_home/ >> $GITHUB_ENV
echo TORCH_EXTENSIONS_DIR=./torch-extensions/ >> $GITHUB_ENV
echo TORCH_CACHE=/blob/torch_cache/ >> $GITHUB_ENV
echo HF_DATASETS_CACHE=/blob/datasets_cache/ >> $GITHUB_ENV
2 changes: 1 addition & 1 deletion deepspeed/comm/comm.py
@@ -618,7 +618,7 @@ def init_distributed(dist_backend=None,
auto_mpi_discovery Optional (bool). if distributed environment variables are not set, attempt to discover them from MPI
distributed_port: Optional (int). torch distributed backend port
verbose: Optional (bool). verbose logging
- timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
+ timeout: Optional (timedelta). Timeout for operations executed against the process group. The default value of 30 minutes can be overridden by the environment variable `DEEPSPEED_TIMEOUT`.
init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified.
config: Optional (dict). DeepSpeed configuration for setting up comms options (e.g. Comms profiling)
rank: Optional (int). The current manually specified rank. Some init_method like “tcp://” need the rank and world_size as well (see: https://pytorch.org/docs/stable/distributed.html#tcp-initialization)
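The new docstring mentions a DEEPSPEED_TIMEOUT override. A hedged sketch of how such an override can be consumed (assumes the variable is given in minutes, matching the documented 30-minute default):

    import os
    from datetime import timedelta

    # Build the process-group timeout, letting the environment override the
    # 30-minute default (value assumed to be minutes).
    timeout = timedelta(minutes=int(os.getenv("DEEPSPEED_TIMEOUT", "30")))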
2 changes: 1 addition & 1 deletion deepspeed/moe/sharded_moe.py
@@ -220,7 +220,7 @@ def top1gating(logits: Tensor,
tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
# Make sure the capacity value does not exceed the number of tokens.
- capacity = min(new_capacity, torch.tensor(mask1.size(0)))
+ capacity = min(new_capacity, torch.tensor(mask1.size(0)).to(new_capacity.device))

# Compute l_aux
me = torch.mean(gates, dim=0)
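The added .to(...) matters because Python's built-in min() returns one of its arguments unchanged: if the CPU-resident token count is smaller, capacity would silently come back as a CPU tensor and trip device mismatches downstream. A minimal sketch (assumes a CUDA device is available):

    import torch

    new_capacity = torch.tensor(8, device="cuda")  # capacity computed on the accelerator
    num_tokens = torch.tensor(4)                   # 0-dim CPU tensor, as torch.tensor(mask1.size(0)) is

    # Without the cast, min() would return the CPU tensor here.
    capacity = min(new_capacity, num_tokens.to(new_capacity.device))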
2 changes: 1 addition & 1 deletion deepspeed/ops/transformer/inference/ds_attention.py
@@ -93,7 +93,7 @@ def compute_attention(self, qkv_out, input_mask, layer_past, alibi):
if isinstance(qkv_out, list) or isinstance(qkv_out, tuple):
qkv_out = qkv_out[0]

- no_masking = input_mask is None
+ no_masking = input_mask is None or input_mask is False

if no_masking:
input_mask = torch.empty(1)
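This widens the "no mask" check so a caller passing input_mask=False to disable masking is treated like one passing None; previously an explicit False did not short-circuit and would have been handled as if it were a real mask. A one-line sketch of the two accepted spellings:

    # Both "no mask" spellings now route to the dummy torch.empty(1) mask.
    for input_mask in (None, False):
        assert input_mask is None or input_mask is False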
2 changes: 1 addition & 1 deletion deepspeed/runtime/compiler.py
@@ -89,7 +89,7 @@ def CompiledModuleWrapper(mod, compile_config: Union[CompileConfig, None] = None
class wrapper(mod.__class__):

def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
- self.__dict__ = module.__dict__.copy()
+ self.__dict__ = {k: module.__dict__[k] for k in module.__dict__ if not k in self.__class__.__dict__}

assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch."

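The filtered copy prevents attributes of the wrapped module from shadowing attributes the wrapper class itself defines, since instance attributes win over class attributes on lookup. A minimal sketch with hypothetical names:

    # Hypothetical classes, for illustration only.
    class Module:
        pass

    class Wrapper(Module):
        compiled = True  # wrapper-level attribute that must stay visible

    m = Module()
    m.__dict__["compiled"] = False  # same-named attribute on the wrapped instance

    w = Wrapper.__new__(Wrapper)
    w.__dict__ = {k: v for k, v in m.__dict__.items() if k not in Wrapper.__dict__}
    assert w.compiled is True  # the class attribute survives the filtered copy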
4 changes: 4 additions & 0 deletions deepspeed/runtime/config.py
@@ -66,6 +66,8 @@
from .data_pipeline.config import get_data_efficiency_enabled, get_data_efficiency_config, get_curriculum_enabled_legacy, get_curriculum_params_legacy
from .data_pipeline.constants import *

+ from ..utils.config import get_timers_config

TENSOR_CORE_ALIGN_SIZE = 8

ADAGRAD_OPTIMIZER = 'adagrad'
@@ -911,6 +913,8 @@ def _initialize_params(self, param_dict):

self.compile_config = get_compile_config(param_dict)

+ self.timers_config = get_timers_config(param_dict)

def _batch_assertion(self):

train_batch = self.train_batch_size
9 changes: 4 additions & 5 deletions deepspeed/runtime/engine.py
@@ -271,11 +271,10 @@ def __init__(self,
# Configure wall clock timers
self.timers = SynchronizedWallClockTimer()
# Throughput timer
- self.tput_timer = ThroughputTimer(
- batch_size=self.train_batch_size(),
- steps_per_output=self.steps_per_print(),
- monitor_memory=False,
- )
+ self.tput_timer = ThroughputTimer(self._config.timers_config,
+ batch_size=self.train_batch_size(),
+ steps_per_output=self.steps_per_print(),
+ monitor_memory=False)

log_dist(f"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}", ranks=[0])

3 changes: 2 additions & 1 deletion deepspeed/runtime/pipe/engine.py
@@ -117,7 +117,8 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):

self._force_grad_boundary = False

- self.batch_timer = ThroughputTimer(batch_size=self.train_batch_size(),
+ self.batch_timer = ThroughputTimer(self._config.timers_config,
+ batch_size=self.train_batch_size(),
logging_fn=self.tput_log,
monitor_memory=False,
steps_per_output=self.steps_per_print())
10 changes: 6 additions & 4 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -56,7 +56,7 @@ def __init__(self, param: Parameter) -> None:
self.__param = param

def wait(self) -> None:
- if not get_accelerator().is_synchronized_device():
+ if not get_accelerator().resolves_data_dependency():
get_accelerator().current_stream().synchronize()
self.__param.ds_status = ZeroParamStatus.AVAILABLE

@@ -82,7 +82,7 @@ def wait(self) -> None:
if self.__complete:
return

- if not get_accelerator().is_synchronized_device():
+ if not get_accelerator().resolves_data_dependency():
get_accelerator().current_stream().synchronize()
for param in self.__params:
assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
@@ -1737,7 +1737,8 @@ def _allgather_param(self, param, async_op=False, hierarchy=0):
f'After allocate allgather param {debug_param2name_id_shape_status(param)} {aligned_param_size} {partition_size} ',
force=False)

- get_accelerator().synchronize()
+ if not get_accelerator().resolves_data_dependency():
+ get_accelerator().synchronize()

print_rank_0(
f"{'--'* hierarchy}----allgather param with {debug_param2name_id_shape_status(param)} partition size={partition_size}"
@@ -1870,7 +1871,8 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0, quantize=False):
param.data = gathered_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data

# guarantee the communication to be completed
- get_accelerator().synchronize()
+ if not get_accelerator().resolves_data_dependency():
+ get_accelerator().synchronize()

return None

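The recurring pattern in this file replaces unconditional synchronization with a capability query, so accelerators whose runtimes already order dependent operations skip the blocking call. A hedged sketch of the idiom, using only the calls visible in the diff:

    from deepspeed.accelerator import get_accelerator

    # Block only when the accelerator's runtime will not resolve the data
    # dependency on its own.
    if not get_accelerator().resolves_data_dependency():
        get_accelerator().synchronize()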
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage3.py
@@ -1575,7 +1575,7 @@ def set_none_gradients_to_zero(self, i, partition_id):
for param_id in self.is_grad_computed[i][partition_id]:
param = self.param_dict[param_id]
if param.grad is None:
- param.grad = torch.zero_like(param)
+ param.grad = torch.zeros_like(param)

######################Reduction Related Methods##############################

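This is a straight bug fix: torch provides zeros_like but no zero_like, so the old line raised AttributeError the first time a None gradient needed zero-filling. Minimal repro:

    import torch

    p = torch.randn(3)
    # torch.zero_like(p)      # AttributeError: module 'torch' has no attribute 'zero_like'
    grad = torch.zeros_like(p)  # zero tensor matching p's shape, dtype, and device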
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
@@ -1474,7 +1474,7 @@ def set_none_gradients_to_zero(self, i, partition_id):
for param_id in self.is_grad_computed[i][partition_id]:
param = self.param_dict[param_id]
if param.grad is None:
- param.grad = torch.zero_like(param)
+ param.grad = torch.zeros_like(param)

######################Reduction Related Methods##############################
def allreduce_bucket(self, bucket, rank=None, log=None, divide=True, process_group=None):
46 changes: 46 additions & 0 deletions deepspeed/utils/config.py
@@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.runtime.config_utils import DeepSpeedConfigModel

#########################################
# Timers
#########################################
# Timers. By default, timers are enabled.
# Users can configure in ds_config.json as below example:
TIMERS_FORMAT = '''
Timers should be enabled as:
"timers": {
"throughput": {
"enabled": true,
"synchronized": true
}
}
'''

TIMERS = "timers"
TIMERS_THROUGHPUT = "throughput"


def get_timers_config(param_dict):
if param_dict and TIMERS in param_dict and TIMERS_THROUGHPUT in param_dict[TIMERS]:
timers_config_dict = param_dict[TIMERS][TIMERS_THROUGHPUT]
else:
timers_config_dict = {}
return DeepSpeedThroughputTimerConfig(**timers_config_dict)


class DeepSpeedThroughputTimerConfig(DeepSpeedConfigModel):
""" Configure throughput timers """

enabled: bool = True
""" Turn on/off throughput timers """

synchronized: bool = True
""" Whether to synchronize a device when measuring the time.
Synchronizing a device is required to produce the most accurate timer measurements.
However, this comes at the expense of performance degradation. The CPU timer provides
sufficient accuracy in many cases.
"""
20 changes: 9 additions & 11 deletions deepspeed/utils/timer.py
@@ -197,15 +197,9 @@ def get_mean(self, names, normalizer=1.0, reset=True):

class ThroughputTimer:

- def __init__(
- self,
- batch_size,
- start_step=2,
- steps_per_output=50,
- monitor_memory=False,
- logging_fn=None,
- ):
+ def __init__(self, config, batch_size, start_step=2, steps_per_output=50, monitor_memory=False, logging_fn=None):
from deepspeed.utils import logger
+ self.config = config
self.start_time = 0
self.end_time = 0
self.started = False
@@ -234,22 +228,26 @@ def _init_timer(self):
self.initialized = True

def start(self):
+ if not self.config.enabled:
+ return
self._init_timer()
self.started = True
if self.global_step_count >= self.start_step:
- get_accelerator().synchronize()
+ if self.config.synchronized:
+ get_accelerator().synchronize()
self.start_time = time.time()

def stop(self, global_step=False, report_speed=True):
- if not self.started:
+ if not self.config.enabled or not self.started:
return
self.started = False
self.micro_step_count += 1
if global_step:
self.global_step_count += 1

if self.start_time > 0:
- get_accelerator().synchronize()
+ if self.config.synchronized:
+ get_accelerator().synchronize()
self.end_time = time.time()
duration = self.end_time - self.start_time
self.total_elapsed_time += duration
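A hedged sketch of constructing the timer with its new leading config argument, matching the call sites updated in engine.py and pipe/engine.py (batch size and flags are illustrative):

    from deepspeed.utils.config import DeepSpeedThroughputTimerConfig
    from deepspeed.utils.timer import ThroughputTimer

    timer = ThroughputTimer(DeepSpeedThroughputTimerConfig(),  # enabled/synchronized defaults
                            batch_size=8,
                            monitor_memory=False)
    timer.start()
    # ... one training step ...
    timer.stop(global_step=True)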
2 changes: 1 addition & 1 deletion tests/unit/inference/test_checkpoint_sharding.py
@@ -110,7 +110,7 @@ def write_checkpoints_json(model_name, class_tmpdir):
cached_repo_dir = snapshot_download(
model_name,
local_files_only=is_offline_mode(),
cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
cache_dir=os.getenv("HF_HOME", None),
ignore_patterns=["*.safetensors", "*.msgpack", "*.h5"],
)
file_list = [str(entry) for entry in Path(cached_repo_dir).rglob("*.[bp][it][n]") if entry.is_file()]
2 changes: 1 addition & 1 deletion tests/unit/inference/test_inference.py
@@ -84,7 +84,7 @@ class ModelInfo:
def _hf_model_list() -> List[ModelInfo]:
""" Caches HF model list to avoid repeated API calls """

cache_dir = os.getenv("TRANSFORMERS_CACHE", "~/.cache/huggingface")
cache_dir = os.getenv("HF_HOME", "~/.cache/huggingface")
cache_file_path = os.path.join(cache_dir, "DS_model_cache.pkl")
cache_expiration_seconds = 60 * 60 * 24 # 1 day

