Inference Checkpoints (#4620)
Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Michael Wyatt <[email protected]>
Co-authored-by: Ammar Ahmad Awan <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>
8 people authored Nov 10, 2023
1 parent da652d0 commit 19b2587
Showing 98 changed files with 924 additions and 4,597 deletions.
1 change: 1 addition & 0 deletions .github/workflows/nv-accelerate-v100.yml
@@ -6,6 +6,7 @@ on:
- 'docs/**'
- 'blogs/**'
- 'deepspeed/inference/v2/**'
- "tests/unit/inference/v2/**"
merge_group:
branches: [ master ]
schedule:
1 change: 1 addition & 0 deletions .github/workflows/nv-inference.yml
@@ -6,6 +6,7 @@ on:
- 'docs/**'
- 'blogs/**'
- 'deepspeed/inference/v2/**'
- "tests/unit/inference/v2/**"
merge_group:
branches: [ master ]
schedule:
1 change: 1 addition & 0 deletions .github/workflows/nv-lightning-v100.yml
@@ -6,6 +6,7 @@ on:
- 'docs/**'
- 'blogs/**'
- 'deepspeed/inference/v2/**'
- "tests/unit/inference/v2/**"
merge_group:
branches: [ master ]
schedule:
1 change: 1 addition & 0 deletions .github/workflows/nv-megatron.yml
@@ -6,6 +6,7 @@ on:
- 'docs/**'
- 'blogs/**'
- 'deepspeed/inference/v2/**'
- "tests/unit/inference/v2/**"
merge_group:
branches: [ master ]
schedule:
1 change: 1 addition & 0 deletions .github/workflows/nv-pre-compile-ops.yml
@@ -8,6 +8,7 @@ on:
- 'docs/**'
- 'blogs/**'
- 'deepspeed/inference/v2/**'
- "tests/unit/inference/v2/**"
merge_group:
branches: [ master ]
schedule:
1 change: 1 addition & 0 deletions .github/workflows/nv-torch-latest-cpu.yml
@@ -6,6 +6,7 @@ on:
- 'docs/**'
- 'blogs/**'
- 'deepspeed/inference/v2/**'
- "tests/unit/inference/v2/**"
merge_group:
branches: [ master ]
schedule:
1 change: 1 addition & 0 deletions .github/workflows/nv-torch-latest-v100.yml
@@ -6,6 +6,7 @@ on:
- 'docs/**'
- 'blogs/**'
- 'deepspeed/inference/v2/**'
- "tests/unit/inference/v2/**"
merge_group:
branches: [ master ]
schedule:
1 change: 1 addition & 0 deletions .github/workflows/nv-transformers-v100.yml
@@ -6,6 +6,7 @@ on:
- 'docs/**'
- 'blogs/**'
- 'deepspeed/inference/v2/**'
- "tests/unit/inference/v2/**"
merge_group:
branches: [ master ]
schedule:
2 changes: 1 addition & 1 deletion deepspeed/inference/__init__.py
@@ -4,4 +4,4 @@
# DeepSpeed Team
from .v2 import RaggedInferenceEngineConfig, DeepSpeedTPConfig
from .v2.engine_v2 import InferenceEngineV2
from .v2 import build_hf_engine
from .v2 import build_hf_engine, buid_engine_from_ds_checkpoint
2 changes: 1 addition & 1 deletion deepspeed/inference/v2/__init__.py
@@ -4,4 +4,4 @@
# DeepSpeed Team
from .config_v2 import RaggedInferenceEngineConfig, DeepSpeedTPConfig
from .engine_v2 import InferenceEngineV2
from .engine_factory import build_hf_engine
from .engine_factory import build_hf_engine, buid_engine_from_ds_checkpoint
2 changes: 1 addition & 1 deletion deepspeed/inference/v2/allocator.py
@@ -26,7 +26,7 @@ def on_device(method) -> torch.Tensor:
def wrapped(self, *args, **kwargs):
tensor = method(self, *args, **kwargs)
if isinstance(tensor, torch.Tensor):
return tensor.to(get_accelerator().current_device()).contiguous()
return tensor.to(get_accelerator().current_device())
return tensor

return wrapped
2 changes: 2 additions & 0 deletions deepspeed/inference/v2/checkpoint/huggingface_engine.py
@@ -90,6 +90,8 @@ def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
param = checkpoint_sd[param_name]
yield param_name, param

del checkpoint_sd


if __name__ == "__main__":
# To test, add your auth_token here and run `python huggingface_engine.py`
107 changes: 81 additions & 26 deletions deepspeed/inference/v2/engine_factory.py
@@ -3,44 +3,99 @@

# DeepSpeed Team

import json
import logging
from typing import Any
import os
import pickle

from .engine_v2 import InferenceEngineV2
from .config_v2 import RaggedInferenceEngineConfig
from .checkpoint import HuggingFaceCheckpointEngine
from .logging import inference_logger
from .model_implementations import (
OPTPolicy,
Llama2Policy,
MistralPolicy,
)
from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata


def buid_engine_from_ds_checkpoint(path: str,
engine_config: RaggedInferenceEngineConfig,
debug_level: int = logging.INFO) -> InferenceEngineV2:
"""
Creates an engine from a checkpoint saved by ``InferenceEngineV2``.
Arguments:
path: Path to the checkpoint. This does not need to point to any files in particular,
just the directory containing the checkpoint.
engine_config: Engine configuration. See ``RaggedInferenceEngineConfig`` for details.
debug_level: Logging level to use. Unless you are actively seeing issues, the recommended
value is ``logging.INFO``.
Returns:
Fully initialized inference engine ready to serve queries.
"""

inference_logger(level=debug_level)
# Load metadata. To grab the policy name, all ranks just read rank 0's metadata file.
metadata_filename = make_metadata_filename(path, 0, engine_config.tensor_parallel.tp_size)
metadata = json.load(open(metadata_filename, "r"))
metadata = ModelMetadata.parse_raw(metadata)

# Get the policy
try:
policy_cls: InferenceV2Policy = POLICIES[metadata.policy]
except KeyError:
raise ValueError(f"Unknown policy {metadata.policy} for model {path}")

# Load the model config
model_config = pickle.load(open(os.path.join(path, "ds_model_config.pkl"), "rb"))
policy = policy_cls(model_config, inf_checkpoint_path=path)

return InferenceEngineV2(policy, engine_config)


def build_hf_engine(path: str,
engine_config: RaggedInferenceEngineConfig,
debug_level: int = logging.INFO,
random_weights_config: Any = None,
fill_random: bool = False) -> InferenceEngineV2:
debug_level: int = logging.INFO) -> InferenceEngineV2:
"""
Build an InferenceV2 engine for HuggingFace models.
Build an InferenceV2 engine for HuggingFace models. This accepts either a HuggingFace
model name or a path to an Inference-V2 checkpoint.
Arguments:
path: Path to the checkpoint. This does not need to point to any files in particular,
just the directory containing the checkpoint.
engine_config: Engine configuration. See ``RaggedInferenceEngineConfig`` for details.
debug_level: Logging level to use. Unless you are actively seeing issues, the recommended
value is ``logging.INFO``.
Returns:
Fully initialized inference engine ready to serve queries.
"""
# Set up logging
inference_logger(level=debug_level)

# get HF checkpoint engine
checkpoint_engine = HuggingFaceCheckpointEngine(path)

# get model config from HF AutoConfig
model_config = checkpoint_engine.model_config

# get the policy
# TODO: generalize this to other models
if model_config.model_type == "opt":
from .model_implementations.opt.policy import OPTPolicy
policy = OPTPolicy(checkpoint_engine, model_config)
elif model_config.model_type == "llama":
from .model_implementations.llama_v2.llama_v2_policy import Llama2Policy
policy = Llama2Policy(checkpoint_engine, model_config)
elif model_config.model_type == "mistral":
from .model_implementations.mistral.policy import MistralPolicy
policy = MistralPolicy(checkpoint_engine, model_config)
if os.path.exists(os.path.join(path, "ds_model_config.pkl")):
return buid_engine_from_ds_checkpoint(path, engine_config, debug_level=debug_level)
else:
raise ValueError(f"Unsupported model type {model_config.model_type}")
# Set up logging
inference_logger(level=debug_level)
# get HF checkpoint engine
checkpoint_engine = HuggingFaceCheckpointEngine(path)

return InferenceEngineV2(policy, engine_config)
# get model config from HF AutoConfig
model_config = checkpoint_engine.model_config

# get the policy
# TODO: generalize this to other models
if model_config.model_type == "opt":
policy = OPTPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "llama":
policy = Llama2Policy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "mistral":
policy = MistralPolicy(model_config, checkpoint_engine=checkpoint_engine)
else:
raise ValueError(f"Unsupported model type {model_config.model_type}")

return InferenceEngineV2(policy, engine_config)
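
For context, a minimal usage sketch of the two entry points this file now exposes. The model id, checkpoint directory, and default-constructed config below are illustrative assumptions, not part of this commit:

# Sketch (illustrative model id, path, and default config; not from this diff).
from deepspeed.inference import (RaggedInferenceEngineConfig, build_hf_engine,
                                 buid_engine_from_ds_checkpoint)

config = RaggedInferenceEngineConfig()

# Build directly from a HuggingFace model; weights are pulled through
# HuggingFaceCheckpointEngine and mapped by the matching policy.
engine = build_hf_engine("mistralai/Mistral-7B-v0.1", config)

# Build from a directory previously written by InferenceEngineV2.serialize();
# build_hf_engine also routes here automatically when it finds ds_model_config.pkl.
engine = buid_engine_from_ds_checkpoint("/path/to/ds_checkpoint", config)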
32 changes: 31 additions & 1 deletion deepspeed/inference/v2/engine_v2.py
@@ -4,6 +4,8 @@
# DeepSpeed Team

import os
import json
import pickle
from typing import Iterable, Tuple

import torch
@@ -17,6 +19,8 @@
from .logging import inference_logger
from .ragged import DSStateManager, RaggedBatchWrapper, PlaceholderSequenceDescriptor
from .scheduling_utils import SchedulingError, SchedulingResult
from .model_implementations.flat_model_helpers import make_param_filename, make_metadata_filename
from .model_implementations.inference_model_base import DSInferenceModelBase

from .config_v2 import RaggedInferenceEngineConfig

@@ -30,7 +34,7 @@ class InferenceEngineV2:
Configuration of the inference engine.
"""

#_model: DSInferenceModelBase
_model: DSInferenceModelBase
"""
Inference model supporting ragged inference.
"""
@@ -47,6 +51,13 @@ def free_blocks(self) -> int:
"""
return self._state_manager.free_blocks

@property
def model(self) -> DSInferenceModelBase:
"""
The model implementation.
"""
return self._model

def __init__(self, policy: InferenceV2Policy, engine_config: RaggedInferenceEngineConfig) -> None:
"""
Create the Inference V2 engine.
@@ -215,3 +226,22 @@ def flush(self, uid: int) -> None:
uid (int): The UID of the sequence to flush.
"""
self._state_manager.flush_sequence(uid)

def serialize(self, save_path: str) -> None:
"""
Serialize the model to a checkpoint directory.
Arguments:
save_path (str): Path to the directory to serialize to.
"""
param_file_name = make_param_filename(save_path, self._model.tp_rank, self._model.tp_size)
metadata_file_name = make_metadata_filename(save_path, self._model.tp_rank, self._model.tp_size)

# Save the flattened parameters

torch.save(self._model.flattened_params, param_file_name)

json.dump(self._model.flattened_param_metadata.json(), open(metadata_file_name, "w"))

if self._model.tp_rank == 0:
pickle.dump(self._model._config, open(os.path.join(save_path, "ds_model_config.pkl"), "wb"))
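
A hedged round-trip sketch of the new serialize() path. The model id and directory are illustrative; each rank writes its own flattened parameter and metadata files, and rank 0 additionally writes ds_model_config.pkl:

# Sketch (illustrative paths/model): save a running engine, then rebuild it
# later without touching the original HuggingFace checkpoint.
from deepspeed.inference import (RaggedInferenceEngineConfig, build_hf_engine,
                                 buid_engine_from_ds_checkpoint)

config = RaggedInferenceEngineConfig()
engine = build_hf_engine("mistralai/Mistral-7B-v0.1", config)
engine.serialize("/path/to/ds_checkpoint")

# Later: reload from the flattened parameters and metadata written above.
engine = buid_engine_from_ds_checkpoint("/path/to/ds_checkpoint", config)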
89 changes: 89 additions & 0 deletions deepspeed/inference/v2/inference_parameter.py
@@ -0,0 +1,89 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Dict

import torch

CORE_PARAM = "_ds_core_param_key"

STR_TO_DTYPE = {
"torch.float32": torch.float32,
"torch.float64": torch.float64,
"torch.float16": torch.float16,
"torch.bfloat16": torch.bfloat16,
"torch.int64": torch.int64,
"torch.int32": torch.int32,
"torch.int16": torch.int16,
"torch.int8": torch.int8,
"torch.uint8": torch.uint8,
"torch.bool": torch.bool,
}


class InferenceParameter(torch.Tensor):
"""
An extension of the torch.Tensor class to support our inference focused features. One important
thing to note here is that an InferenceParam can be used as a torch.Tensor, but outputs of
torch.Tensor operations will not be InferenceParams.
"""

@staticmethod
def __new__(cls, tensor, *args, **kwargs):
new_tensor = super().__new__(cls, tensor, *args, **kwargs)
if hasattr(tensor, "_aux_attrs"):
setattr(new_tensor, "_aux_attrs", tensor.aux_attrs)
return new_tensor

def to(self, *args, **kwargs):
new_tensor = super().to(*args, **kwargs)
if hasattr(self, "_aux_attrs"):
setattr(new_tensor, "_aux_attrs", self.aux_attrs)
try:
_ = torch.device(args[0])
for name, attr in new_tensor.aux_attrs.items():
new_attr = attr.to(*args, **kwargs)
setattr(new_tensor, name, new_attr)
new_tensor.aux_attrs[name] = new_attr
except:
pass

return new_tensor

@classmethod
def initialize(cls, core_param: torch.Tensor, **kwargs) -> 'InferenceParameter':
"""
Create the inference parameter.
"""
param = InferenceParameter(core_param)
setattr(param, "_aux_attrs", kwargs)

for attr_name, attr in kwargs.items():
if hasattr(param, attr_name):
raise ValueError(f"Attribute {attr_name} already exists on param.")

if not isinstance(attr, torch.Tensor):
raise ValueError(f"Attribute {attr_name} must be a tensor.")

setattr(param, attr_name, attr)

return param

@classmethod
def initialize_raw(self, **kwargs) -> 'InferenceParameter':
"""
All kwargs must be torch.Tensors and must include the core parameter.
"""
if CORE_PARAM not in kwargs:
raise ValueError(f"Must provide core parameter, with key {CORE_PARAM}.")

return InferenceParameter.initialize(kwargs[CORE_PARAM], **kwargs)

@property
def aux_attrs(self) -> Dict[str, torch.Tensor]:
"""
Dictionary of auxiliary attributes.
"""
return self._aux_attrs
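
A small sketch of how the new InferenceParameter is meant to be used. The "scales" auxiliary tensor and the shapes are invented for illustration (not a DeepSpeed API), and the device move assumes a CUDA accelerator is available:

# Sketch (illustrative): attach an auxiliary tensor to a core parameter and
# move both together through .to(). "scales" is a made-up attribute name.
import torch
from deepspeed.inference.v2.inference_parameter import InferenceParameter

weight = torch.randn(4096, 4096, dtype=torch.float16)
scales = torch.randn(4096, dtype=torch.float32)

param = InferenceParameter.initialize(weight, scales=scales)
param = param.to("cuda:0")  # core tensor and aux attrs move together
print(param.scales.device, param.aux_attrs["scales"].device)  # both cuda:0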
5 changes: 5 additions & 0 deletions deepspeed/inference/v2/model_implementations/__init__.py
@@ -7,3 +7,8 @@
from .inference_transformer_base import DSTransformerModelBase, DSMoETransformerModelBase
from .inference_policy_base import InferenceV2Policy, ContainerMap
from .sharding import *

# Model Implementations
from .llama_v2 import *
from .opt import *
from .mistral import *
File renamed without changes.
@@ -6,7 +6,6 @@
import torch

from ...model_implementations.parameter_base import ParameterBase
from ...allocator import on_device
"""
Embedding containers.
"""
@@ -23,7 +22,5 @@ class EmbeddingParameter(ParameterBase):
Vocabulary parameter of shape [vocab_size, model_dim].
"""

@on_device
def finalize(self) -> torch.Tensor:
return self.params
#return self.inference_model.transform_embed_param(self.params)
return self.inference_model.transform_embedding_param(self.params)