Reduce communication to optimize performance for Mixtral TP #4

Open · wants to merge 2 commits into main

2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -39,7 +39,7 @@ repos:
name: check-torchdist
entry: ./scripts/check-torchdist.py
language: python
-exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py|deepspeed/elasticity/elastic_agent.py|deepspeed/launcher/launch.py|tests/unit/comm/test_dist.py)
+exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py|deepspeed/elasticity/elastic_agent.py|deepspeed/launcher/launch.py|tests/unit/comm/test_dist.py|deepspeed/runtime/zero/utils.py|deepspeed/tools/pg_sim/ut/base.py|deepspeed/tools/pg_sim/pg.py)
# Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm

- repo: local
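The exclude value is a Python regular expression that pre-commit tests against each staged file path, with the leading ^ anchoring it to the start of the repository-relative path. A small sketch of how the extended pattern behaves (the checked paths are illustrative):

```python
import re

# The extended exclude pattern from the hook above; the new alternatives cover
# deepspeed/runtime/zero/utils.py and the pg_sim tool files.
EXCLUDE = re.compile(
    r"^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py"
    r"|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py"
    r"|deepspeed/elasticity/elastic_agent.py|deepspeed/launcher/launch.py"
    r"|tests/unit/comm/test_dist.py|deepspeed/runtime/zero/utils.py"
    r"|deepspeed/tools/pg_sim/ut/base.py|deepspeed/tools/pg_sim/pg.py)")

assert EXCLUDE.search("deepspeed/tools/pg_sim/pg.py")        # newly skipped by check-torchdist
assert EXCLUDE.search("deepspeed/runtime/zero/utils.py")     # newly skipped by check-torchdist
assert not EXCLUDE.search("deepspeed/runtime/engine.py")     # still checked
```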
17 changes: 17 additions & 0 deletions accelerator/abstract_accelerator.py
@@ -17,6 +17,18 @@ def __init__(self):
def is_synchronized_device(self):
...

@abc.abstractmethod
def use_host_timers(self):
...

@abc.abstractmethod
def resolves_data_dependency(self):
...

@abc.abstractmethod
def handles_memory_backpressure(self):
...

# Device APIs
@abc.abstractmethod
def device_name(self, device_index):
@@ -255,6 +267,11 @@ def create_op_builder(self, class_name):
def get_op_builder(self, class_name):
...

# Creates and returns an optimizer specified by optimizer_name when the accelerator has its own implementation; otherwise returns None
@abc.abstractmethod
def get_optimizer(self, optimizer_name, cpu_optimization, model_parameters, **optimizer_parameters):
...

@abc.abstractmethod
def build_extension(self):
...
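To spell out the contract these hooks define for other backends, here is a minimal sketch of an accelerator implementing only the methods added above; the class name and the semantics described in the comments are assumptions made for illustration (and the import path assumes the abstract class is reachable as deepspeed.accelerator.abstract_accelerator.DeepSpeedAccelerator), not something documented by this PR.

```python
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator


class ExampleAccelerator(DeepSpeedAccelerator):
    """Hypothetical backend sketch covering only the newly added hooks."""

    def use_host_timers(self):
        # Assumed meaning: time regions with host wall-clock timers rather
        # than device events (natural when the device runs synchronously).
        return self.is_synchronized_device()

    def resolves_data_dependency(self):
        # Assumed meaning: the backend orders dependent host/device work on
        # its own, so the runtime can skip explicit synchronization points.
        return self.is_synchronized_device()

    def handles_memory_backpressure(self):
        # Assumed meaning: the backend throttles itself under memory
        # pressure, so the runtime does not need to defer work to relieve it.
        return self.is_synchronized_device()

    def get_optimizer(self, optimizer_name, cpu_optimization, model_parameters, **optimizer_parameters):
        # Return a backend-specific optimizer here, or None to defer to
        # DeepSpeed's built-in optimizers.
        return None
```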
12 changes: 12 additions & 0 deletions accelerator/cpu_accelerator.py
@@ -21,6 +21,15 @@ def __init__(self):
def is_synchronized_device(self):
return True

def use_host_timers(self):
return self.is_synchronized_device()

def resolves_data_dependency(self):
return self.is_synchronized_device()

def handles_memory_backpressure(self):
return self.is_synchronized_device()

# Device APIs
def device_name(self, device_index=None):
return 'cpu'
@@ -280,3 +289,6 @@ def get_op_builder(self, class_name):
def build_extension(self):
from torch.utils.cpp_extension import BuildExtension
return BuildExtension

def get_optimizer(self, optimizer_name, cpu_optimization, model_parameters, **optimizer_parameters):
return None
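Assuming a build with this change applied, the new flags are reachable through the usual accelerator entry point; on the CPU backend each one mirrors is_synchronized_device() and reports True, while the CUDA backend below reports False:

```python
from deepspeed.accelerator import get_accelerator

accel = get_accelerator()

# Expected output on the CPU accelerator: cpu True True True
print(accel.device_name(),
      accel.use_host_timers(),
      accel.resolves_data_dependency(),
      accel.handles_memory_backpressure())
```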
18 changes: 13 additions & 5 deletions accelerator/cuda_accelerator.py
@@ -42,6 +42,15 @@ def _init_pynvml(self):
def is_synchronized_device(self):
return False

def use_host_timers(self):
return self.is_synchronized_device()

def resolves_data_dependency(self):
return self.is_synchronized_device()

def handles_memory_backpressure(self):
return self.is_synchronized_device()

# Device APIs
def device_name(self, device_index=None):
if device_index == None:
@@ -183,11 +192,7 @@ def is_bf16_supported(self):
return torch.cuda.is_bf16_supported()

def is_fp16_supported(self):
-    major, _ = torch.cuda.get_device_capability()
-    if major >= 7:
-        return True
-    else:
-        return False
+    return True

def supported_dtypes(self):
return [torch.float, torch.half, torch.bfloat16]
@@ -322,3 +327,6 @@ def get_op_builder(self, class_name):
def build_extension(self):
from torch.utils.cpp_extension import BuildExtension
return BuildExtension

def get_optimizer(self, optimizer_name, cpu_optimization, model_parameters, **optimizer_parameters):
return None
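Both backends return None from get_optimizer, so a caller is expected to fall back to a standard optimizer when the accelerator offers no override; a sketch of that pattern, where build_optimizer is a hypothetical helper rather than a DeepSpeed API:

```python
import torch
from deepspeed.accelerator import get_accelerator


def build_optimizer(model_parameters, optimizer_name="adamw", cpu_optimization=False, **optimizer_parameters):
    # Hypothetical helper: ask the accelerator for a backend-specific
    # optimizer first; None (what the CPU and CUDA accelerators return here)
    # means "no override", so fall back to a standard torch optimizer.
    opt = get_accelerator().get_optimizer(optimizer_name, cpu_optimization,
                                          model_parameters, **optimizer_parameters)
    return opt if opt is not None else torch.optim.AdamW(model_parameters, **optimizer_parameters)


model = torch.nn.Linear(16, 16)
optimizer = build_optimizer(list(model.parameters()), lr=1e-3)
```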