Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
cmikeh2 committed Nov 10, 2023
1 parent 89566b4 commit 295fa61
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 12 deletions.
7 changes: 5 additions & 2 deletions deepspeed/inference/v2/engine_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata

def buid_engine_from_ds_checkpoint(path:str, engine_config: RaggedInferenceEngineConfig,
debug_level: int = logging.INFO) -> InferenceEngineV2:

def buid_engine_from_ds_checkpoint(path: str,
engine_config: RaggedInferenceEngineConfig,
debug_level: int = logging.INFO) -> InferenceEngineV2:

inference_logger(level=debug_level)
# Load metadata, for grabbing the policy name we'll have all ranks just check for
Expand All @@ -42,6 +44,7 @@ def buid_engine_from_ds_checkpoint(path:str, engine_config: RaggedInferenceEngin

return InferenceEngineV2(policy, engine_config)


def build_hf_engine(path: str,
engine_config: RaggedInferenceEngineConfig,
debug_level: int = logging.INFO) -> InferenceEngineV2:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ def _initialization_checker(self, check_device: bool = True) -> bool:
if tensor is None:
continue
elif not isinstance(tensor, InferenceParameter):
raise ValueError(
"Layer should be finalized, but {} ({}) is neither InferenceParameter or None".format(name, type(tensor)))
raise ValueError("Layer should be finalized, but {} ({}) is neither InferenceParameter or None".format(
name, type(tensor)))
elif check_device and tensor.device != torch.device(get_accelerator().current_device()):
raise RuntimeError("Layer should be finalized, but {} is not on device {}".format(
name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,8 @@ def test_contiguify_roundtrip():

non_transformer_params = [p.to(get_accelerator().current_device()) for p in non_transformer_params]

def validate_containers(t_containers: List[LayerContainer],
n_t_containers: LayerContainer,
t_params: List[List[torch.Tensor]],
n_t_params: List[torch.Tensor]):
def validate_containers(t_containers: List[LayerContainer], n_t_containers: LayerContainer,
t_params: List[List[torch.Tensor]], n_t_params: List[torch.Tensor]):
"""
Validate params match what is on the containers.
"""
Expand All @@ -107,8 +105,7 @@ def validate_containers(t_containers: List[LayerContainer],
buffer, metadata = flatten_inference_model(transformer_containers, non_transformer_container, "NoOpPolicy")

# Validate containers before contiguify
validate_containers(transformer_containers, non_transformer_container, transformer_params,
non_transformer_params)
validate_containers(transformer_containers, non_transformer_container, transformer_params, non_transformer_params)

# Validate restore pass
transformer_containers_r = []
Expand All @@ -121,5 +118,3 @@ def validate_containers(t_containers: List[LayerContainer],

validate_containers(transformer_containers_r, non_transformer_container_r, transformer_params,
non_transformer_params)


Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def finalize(self) -> torch.Tensor:
param = torch.cat([self.dependency_1, self.dependency_2])
return InferenceParameter.initialize(param)


class ListDependencyContainer(ParameterBase):

dependencies: ParamList("list_items") # noqa: F821
Expand Down

0 comments on commit 295fa61

Please sign in to comment.