Unit test implementation
cmikeh2 committed Nov 10, 2023
1 parent 91385d2 commit 89566b4
Showing 13 changed files with 146 additions and 33 deletions.
deepspeed/inference/v2/allocator.py (2 changes: 1 addition & 1 deletion)
@@ -26,7 +26,7 @@ def on_device(method) -> torch.Tensor:
    def wrapped(self, *args, **kwargs):
        tensor = method(self, *args, **kwargs)
        if isinstance(tensor, torch.Tensor):
-            return tensor.to(get_accelerator().current_device()).contiguous()
+            return tensor.to(get_accelerator().current_device())
        return tensor

    return wrapped
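A hypothetical usage sketch of the decorator after this hunk, assuming the .contiguous() call is the line being removed (consistent with the "1 addition & 1 deletion" count and the non-contiguity assertions in the new test file below); HypotheticalTransform is not from the repository:

import torch
from deepspeed.inference.v2.allocator import on_device


class HypotheticalTransform:

    @on_device
    def transform(self, param: torch.Tensor) -> torch.Tensor:
        # Returns a transposed view; the wrapper still moves it to the current
        # accelerator device but no longer forces it to be contiguous.
        return param.permute(1, 0)


t = HypotheticalTransform().transform(torch.rand(16, 8))
assert not t.is_contiguous()  # callers that need contiguous storage must call .contiguous() themselves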
deepspeed/inference/v2/checkpoint/__init__.py (2 changes: 0 additions & 2 deletions)
@@ -6,5 +6,3 @@
from .base_engine import CheckpointEngineBase
from .in_memory_engine import InMemoryModelEngine
from .huggingface_engine import HuggingFaceCheckpointEngine
-
-from .megatron_engine import MegatronCheckpointEngine
deepspeed/inference/v2/engine_factory.py (6 changes: 3 additions & 3 deletions)
@@ -22,7 +22,7 @@

def buid_engine_from_ds_checkpoint(path: str, engine_config: RaggedInferenceEngineConfig,
                                   debug_level: int = logging.INFO) -> InferenceEngineV2:
-
+    inference_logger(level=debug_level)
    # Load metadata, for grabbing the policy name we'll have all ranks just check for
    # rank 0.
@@ -39,7 +39,7 @@ def buid_engine_from_ds_checkpoint(path: str, engine_config: RaggedInferenceEngin
    # Load the model config
    model_config = pickle.load(open(os.path.join(path, "ds_model_config.pkl"), "rb"))
    policy = policy_cls(model_config, inf_checkpoint_path=path)

    return InferenceEngineV2(policy, engine_config)

def build_hf_engine(path: str,
@@ -50,7 +50,7 @@ def build_hf_engine(path: str,
"""

if os.path.exists(os.path.join(path, "ds_model_config.pkl")):
return buid_engine_from_ds_checkpoint(path, engine_config)
return buid_engine_from_ds_checkpoint(path, engine_config, debug_level=debug_level)
else:
# Set up logging
inference_logger(level=debug_level)
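A usage sketch of the behavior this hunk changes: the debug level passed to build_hf_engine is now forwarded when the path contains a DeepSpeed-format checkpoint. The checkpoint path is a placeholder, and the import location and no-argument default of RaggedInferenceEngineConfig are assumptions, not taken from the diff:

import logging
from deepspeed.inference.v2 import RaggedInferenceEngineConfig  # import path assumed
from deepspeed.inference.v2.engine_factory import build_hf_engine

# "/path/to/checkpoint" is a placeholder; if it contains ds_model_config.pkl,
# build_hf_engine dispatches to buid_engine_from_ds_checkpoint and, after this
# change, passes debug_level through so logging is configured on both code paths.
engine = build_hf_engine("/path/to/checkpoint",
                         RaggedInferenceEngineConfig(),  # default config assumed valid
                         debug_level=logging.DEBUG)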
deepspeed/inference/v2/model_implementations/__init__.py (1 change: 0 additions & 1 deletion)
@@ -12,4 +12,3 @@
from .llama_v2 import *
from .opt import *
from .mistral import *
-from .tlg import *
@@ -111,9 +111,5 @@ def finalize(self) -> torch.Tensor:
        head_size = self.inference_model.head_size
        n_q_heads = self.inference_model.n_heads_q
        n_kv_heads = self.inference_model.n_heads_kv
-        if self.params.shape[0] != (2 * n_kv_heads + n_q_heads) * head_size:
-            world_size = 1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size()
-            n_q_heads = n_q_heads // world_size
-            n_kv_heads = n_kv_heads // world_size
        transposed_param = transform_gqa_megatron(self.params, head_size, n_q_heads, n_kv_heads)
        return self.inference_model.transform_qkv_param(transposed_param)
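For reference, the guard removed above compared the parameter's leading dimension against the fused QKV size. A small illustrative calculation of that quantity; the head counts are example values, not taken from the diff:

# Example: 32 query heads, 8 KV heads (GQA), head size 128 -> fused QKV rows.
head_size = 128
n_q_heads = 32
n_kv_heads = 8
fused_rows = (2 * n_kv_heads + n_q_heads) * head_size  # K rows + V rows + Q rows
print(fused_rows)  # 6144 = 32*128 (Q) + 8*128 (K) + 8*128 (V)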
@@ -182,7 +182,7 @@ def gated_mlp(self) -> bool:
"""

def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInferenceEngineConfig,
base_mp_group: MPType, please_dont_shard: bool=False) -> None:
base_mp_group: MPType) -> None:
"""
Base implementation for initialization. By default, this will initialize
the traditional components of a transformer model:
@@ -201,7 +201,7 @@ def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInf
            engine_config (RaggedInferenceEngineConfig): Engine configuration.
            base_mp_group (MPType): Base communication group for Tensor-parallel inference.
        """
-        super().__init__(config, engine_config, base_mp_group, please_dont_shard)
+        super().__init__(config, engine_config, base_mp_group)

        self.make_norm_layer()
        self.make_qkv_layer()
@@ -69,7 +69,6 @@ def shard_param(param: Optional[torch.Tensor],
        bias_dims (int): The number of dimensions that are considered bias dimensions. This is used to support
            sharding of MoE and non-MoE biases on the same codepath.
    """
-    return param
    assert shard_rank < num_shards, "Shard rank must be less than num_shards"

    # Easier to hide this inside of the sharding logic than to add checks in every model
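Since the removed `return param` short-circuit had effectively disabled sharding, here is a generic illustration of what splitting a parameter across tensor-parallel ranks looks like; naive_shard is a hypothetical helper, not the repository's shard_param:

import torch

def naive_shard(param: torch.Tensor, shard_rank: int, num_shards: int, dim: int = 0) -> torch.Tensor:
    # Give each rank an equal contiguous slice along the sharded dimension.
    assert shard_rank < num_shards, "Shard rank must be less than num_shards"
    chunk = param.shape[dim] // num_shards
    return param.narrow(dim, shard_rank * chunk, chunk)

full = torch.rand(4096, 4096)
local = naive_shard(full, shard_rank=1, num_shards=4)  # rows 1024..2047 for rank 1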
@@ -68,7 +68,6 @@ def forward(self,
            ragged_batch (RaggedBatchWrapper): The input ids and associated ragged batch metadata.
            word_embeddings (torch.Tensor): The word embedding table
        """
-
        output = empty_from(self._output, (ragged_batch.tensor_toks, self._config.embedding_dim))
        self._ragged_embed(output,
                           ragged_batch,
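For background on the empty_from helper used above, the sketch below is an assumption about its behavior, not the actual DeepSpeed implementation: it reuses a preallocated buffer by returning a view of the requested shape, avoiding a fresh allocation for every ragged batch.

import torch

def empty_from(buffer: torch.Tensor, shape: tuple) -> torch.Tensor:
    # Hypothetical re-implementation: carve a view of the requested shape out of
    # the front of a preallocated, contiguous scratch buffer.
    numel = 1
    for dim in shape:
        numel *= dim
    return buffer.flatten()[:numel].view(shape)

# Example: a max-sized scratch buffer reused for a smaller ragged batch.
scratch = torch.empty(2048, 4096)
out = empty_from(scratch, (128, 4096))
assert out.data_ptr() == scratch.data_ptr()  # same storage, no new allocation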
op_builder/inference_cutlass_builder.py (10 changes: 5 additions & 5 deletions)
@@ -68,11 +68,11 @@ def sources(self):
        return sources

    def extra_ldflags(self):
-        #import dskernels
-        lib_path = '/data/users/reyazda/ds-alpha/DeepSpeed-Kernels/ft_gemm/gemm_variants/lib' #dskernels.library_path()
-        #prefix = self.get_prefix()
-        #lib_path = os.path.join(prefix, lib_path)
-        #lib_path = self.deepspeed_src_path(lib_path)
+        import dskernels
+        lib_path = dskernels.library_path()
+        prefix = self.get_prefix()
+        lib_path = os.path.join(prefix, lib_path)
+        lib_path = self.deepspeed_src_path(lib_path)

        args = [f'-L{lib_path}', '-ldeepspeedft']
        if self.jit_load:
op_builder/ragged_ops.py (10 changes: 5 additions & 5 deletions)
@@ -82,12 +82,12 @@ def sources(self):
        return sources

    def extra_ldflags(self):
-        #import dskernels
-        lib_path = '/data/users/reyazda/ds-alpha/DeepSpeed-Kernels/inf_flash_attn/blocked_flash/lib' #dskernels.library_path()
+        import dskernels
+        lib_path = dskernels.library_path()

-        #prefix = self.get_prefix()
-        #lib_path = os.path.join(prefix, lib_path)
-        #lib_path = self.deepspeed_src_path(lib_path)
+        prefix = self.get_prefix()
+        lib_path = os.path.join(prefix, lib_path)
+        lib_path = self.deepspeed_src_path(lib_path)

        args = [f'-L{lib_path}', '-lblockedflash']
        if self.jit_load:
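For context on what both op_builder changes accomplish, a short sketch assuming the dskernels package is installed: instead of a hard-coded developer path, the builders now resolve the prebuilt DeepSpeed-Kernels libraries from the installed package and hand the resulting search path and library name to the linker.

import os
import dskernels

# The installed dskernels wheel ships the prebuilt shared libraries; library_path()
# points at them, so no user-specific absolute path is needed.
lib_path = dskernels.library_path()
ldflags = [f'-L{lib_path}', '-lblockedflash']
print(os.path.isdir(lib_path), ldflags)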
@@ -0,0 +1,125 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import List

import pytest
import torch

from deepspeed.accelerator import get_accelerator
from deepspeed.inference.v2.model_implementations.flat_model_helpers import (
    flatten_inference_model,
    restore_inference_model,
)
from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer
from .utils import SimpleParam, DummyInferenceModel


class TransformerLayerContainer(LayerContainer):
    """
    Stub layer container
    """
    PARAM_MAPPING = {
        "param_1": "param_1.param",
        "param_2": "param_2.param",
    }

    param_1: SimpleParam

    param_2: SimpleParam


class NonTransformerContainer(LayerContainer):
    """
    Stub layer container
    """
    PARAM_MAPPING = {
        "param_1": "param_1.param",
        "param_2": "param_2.param",
        "param_3": "param_3.param",
    }

    param_1: SimpleParam

    param_2: SimpleParam

    param_3: SimpleParam


@pytest.mark.inference_v2
def test_contiguify_roundtrip():
    """
    Validate that contiguify round trips and reconstructions are correct.
    """
    model = DummyInferenceModel()

    n_layers = 2
    transformer_params = []
    transformer_containers = []

    # Create parameters and populate them into the containers
    for i in range(n_layers):
        transformer_containers.append(TransformerLayerContainer(model))
        layer_params = []
        for j in range(2):
            layer_params.append(torch.rand(16, 16))
            transformer_containers[i].set_dependency(f"param_{j+1}", layer_params[j])

        layer_params = [p.to(get_accelerator().current_device()) for p in layer_params]

        transformer_params.append(layer_params)
        assert transformer_containers[i].is_populated == True

    non_transformer_params = []
    non_transformer_container = NonTransformerContainer(model)

    for i in range(3):
        non_transformer_params.append(torch.rand(16, 16).permute(1, 0))
        non_transformer_container.set_dependency(f"param_{i+1}", non_transformer_params[i])

    non_transformer_params = [p.to(get_accelerator().current_device()) for p in non_transformer_params]

    def validate_containers(t_containers: List[LayerContainer],
                            n_t_containers: LayerContainer,
                            t_params: List[List[torch.Tensor]],
                            n_t_params: List[torch.Tensor]):
        """
        Validate params match what is on the containers.
        """
        for i in range(n_layers):
            l_c = t_containers[i]

            assert l_c.is_initialized == True

            assert torch.equal(l_c.param_1, t_params[i][0])
            assert torch.equal(l_c.param_2, t_params[i][1])

        assert n_t_containers.is_initialized == True
        assert torch.equal(n_t_containers.param_1, n_t_params[0])
        assert torch.equal(n_t_containers.param_2, n_t_params[1])
        assert torch.equal(n_t_containers.param_3, n_t_params[2])
        assert not n_t_containers.param_1.is_contiguous()
        assert not n_t_containers.param_2.is_contiguous()
        assert not n_t_containers.param_3.is_contiguous()

    buffer, metadata = flatten_inference_model(transformer_containers, non_transformer_container, "NoOpPolicy")

    # Validate containers before contiguify
    validate_containers(transformer_containers, non_transformer_container, transformer_params,
                        non_transformer_params)

    # Validate restore pass
    transformer_containers_r = []
    for i in range(n_layers):
        transformer_containers_r.append(TransformerLayerContainer(model))

    non_transformer_container_r = NonTransformerContainer(model)

    restore_inference_model(buffer, metadata, transformer_containers_r, non_transformer_container_r)

    validate_containers(transformer_containers_r, non_transformer_container_r, transformer_params,
                        non_transformer_params)
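Conceptually, the round trip this test exercises looks like the sketch below. naive_flatten and naive_restore are invented stand-ins, not the flat_model_helpers API, and unlike this naive version the real helpers also preserve the original strides, which is what the non-contiguity assertions above check.

import torch

def naive_flatten(tensors):
    # Record shape and element count, then copy everything into one flat buffer.
    metadata = [(t.shape, t.numel()) for t in tensors]
    buffer = torch.cat([t.contiguous().flatten() for t in tensors])
    return buffer, metadata

def naive_restore(buffer, metadata):
    # Rebuild each tensor as a view into the flat buffer.
    tensors, offset = [], 0
    for shape, numel in metadata:
        tensors.append(buffer[offset:offset + numel].view(shape))
        offset += numel
    return tensors

params = [torch.rand(16, 16), torch.rand(16, 16).permute(1, 0)]
buffer, metadata = naive_flatten(params)
assert all(torch.equal(a, b) for a, b in zip(params, naive_restore(buffer, metadata)))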


@@ -6,9 +6,10 @@
import pytest
import torch

+from deepspeed.inference.v2.inference_parameter import InferenceParameter
from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer

-from .utils import validate_device, SimpleParam, DummyInferenceModel
+from .utils import SimpleParam, DummyInferenceModel


class ParentLayer(LayerContainer):
@@ -42,9 +43,6 @@ def test_layer_inheritance():

    multi_param_layer.param_2.param = torch.full((16, 16), 2.0)

-    assert multi_param_layer.is_initialized is True
-    assert isinstance(multi_param_layer.param_1, torch.Tensor)
-    assert isinstance(multi_param_layer.param_2, torch.Tensor)
-
-    validate_device(multi_param_layer.param_1)
-    validate_device(multi_param_layer.param_2)
+    assert multi_param_layer.is_populated is True
+    assert isinstance(multi_param_layer.param_1, InferenceParameter)
+    assert isinstance(multi_param_layer.param_2, InferenceParameter)
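A minimal usage sketch following the pattern in the updated tests: the InferenceParameter import and the .initialize() classmethod appear in this diff, while everything else (that the object behaves like a tensor under torch.equal) is inferred from the round-trip test above rather than stated in the source.

import torch
from deepspeed.inference.v2.inference_parameter import InferenceParameter

core = torch.rand(16, 16)
param = InferenceParameter.initialize(core)   # classmethod used by DummyInferenceModel.transform
assert isinstance(param, InferenceParameter)
assert torch.equal(param, core)               # tensor-like comparison, as in the tests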
@@ -52,7 +52,6 @@ class DummyInferenceModel:
    def num_dependencies(self) -> int:
        return 2

-    @on_device
    def transform(self, param: torch.Tensor) -> torch.Tensor:
        return InferenceParameter.initialize(param)

