diff --git a/deepspeed/inference/v2/engine_factory.py b/deepspeed/inference/v2/engine_factory.py
index 3cbb317519d6..ed1ecf717c21 100644
--- a/deepspeed/inference/v2/engine_factory.py
+++ b/deepspeed/inference/v2/engine_factory.py
@@ -20,8 +20,10 @@
 from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
 from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata
 
-def buid_engine_from_ds_checkpoint(path:str, engine_config: RaggedInferenceEngineConfig,
-                                   debug_level: int = logging.INFO) -> InferenceEngineV2:
+
+def buid_engine_from_ds_checkpoint(path: str,
+                                   engine_config: RaggedInferenceEngineConfig,
+                                   debug_level: int = logging.INFO) -> InferenceEngineV2:
     inference_logger(level=debug_level)
 
     # Load metadata, for grabbing the policy name we'll have all ranks just check for
@@ -42,6 +44,7 @@ def buid_engine_from_ds_checkpoint(path:str, engine_config: RaggedInferenceEngin
 
     return InferenceEngineV2(policy, engine_config)
 
+
 def build_hf_engine(path: str,
                     engine_config: RaggedInferenceEngineConfig,
                     debug_level: int = logging.INFO) -> InferenceEngineV2:
diff --git a/deepspeed/inference/v2/model_implementations/layer_container_base.py b/deepspeed/inference/v2/model_implementations/layer_container_base.py
index f0a1618c6605..98e3e0bb31ed 100644
--- a/deepspeed/inference/v2/model_implementations/layer_container_base.py
+++ b/deepspeed/inference/v2/model_implementations/layer_container_base.py
@@ -214,8 +214,8 @@ def _initialization_checker(self, check_device: bool = True) -> bool:
             if tensor is None:
                 continue
             elif not isinstance(tensor, InferenceParameter):
-                raise ValueError(
-                    "Layer should be finalized, but {} ({}) is neither InferenceParameter or None".format(name, type(tensor)))
+                raise ValueError("Layer should be finalized, but {} ({}) is neither InferenceParameter or None".format(
+                    name, type(tensor)))
             elif check_device and tensor.device != torch.device(get_accelerator().current_device()):
                 raise RuntimeError("Layer should be finalized, but {} is not on device {}".format(
                     name,
diff --git a/tests/unit/inference/v2/model_implementations/parameters/test_contiguify.py b/tests/unit/inference/v2/model_implementations/parameters/test_contiguify.py
index 8199185cbe48..52ff0e134dfc 100644
--- a/tests/unit/inference/v2/model_implementations/parameters/test_contiguify.py
+++ b/tests/unit/inference/v2/model_implementations/parameters/test_contiguify.py
@@ -81,10 +81,8 @@ def test_contiguify_roundtrip():
 
     non_transformer_params = [p.to(get_accelerator().current_device()) for p in non_transformer_params]
 
-    def validate_containers(t_containers: List[LayerContainer],
-                            n_t_containers: LayerContainer,
-                            t_params: List[List[torch.Tensor]],
-                            n_t_params: List[torch.Tensor]):
+    def validate_containers(t_containers: List[LayerContainer], n_t_containers: LayerContainer,
+                            t_params: List[List[torch.Tensor]], n_t_params: List[torch.Tensor]):
         """
         Validate params match what is on the containers.
         """
@@ -107,8 +105,7 @@ def validate_containers(t_containers: List[LayerContainer],
     buffer, metadata = flatten_inference_model(transformer_containers, non_transformer_container, "NoOpPolicy")
 
     # Validate containers before contiguify
-    validate_containers(transformer_containers, non_transformer_container, transformer_params,
-                        non_transformer_params)
+    validate_containers(transformer_containers, non_transformer_container, transformer_params, non_transformer_params)
 
     # Validate restore pass
     transformer_containers_r = []
@@ -121,5 +118,3 @@ def validate_containers(t_containers: List[LayerContainer],
 
     validate_containers(transformer_containers_r, non_transformer_container_r, transformer_params,
                         non_transformer_params)
-
-
diff --git a/tests/unit/inference/v2/model_implementations/parameters/test_mapping.py b/tests/unit/inference/v2/model_implementations/parameters/test_mapping.py
index 0feaaff72a7e..52313cb6f202 100644
--- a/tests/unit/inference/v2/model_implementations/parameters/test_mapping.py
+++ b/tests/unit/inference/v2/model_implementations/parameters/test_mapping.py
@@ -23,6 +23,7 @@ def finalize(self) -> torch.Tensor:
         param = torch.cat([self.dependency_1, self.dependency_2])
         return InferenceParameter.initialize(param)
 
+
 class ListDependencyContainer(ParameterBase):
     dependencies: ParamList("list_items")  # noqa: F821