
Commit af95a37
resolve conflict
tohtana committed Aug 28, 2024
2 parents 36d6e10 + 4864991
Showing 33 changed files with 446 additions and 251 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/nv-a6000.yml
@@ -47,7 +47,8 @@ jobs:
       - name: Install deepspeed
         run: |
           python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja
-          python -m pip install pydantic==1.10.11
+          # Update packages included in the container that do not support pydantic 2+ to versions that do
+          python -m pip install thinc spacy confection --upgrade
           python -m pip install .[dev,1bit,autotuning,inf]
           ds_report
       - name: Python environment
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -59,7 +59,7 @@ repos:
           # Do not check files that are automatically generated
           '--skip=docs/Gemfile.lock,tests/unit/gpt2-merges.txt,tests/unit/gpt2-vocab.json',
           '--ignore-regex=\\n', # Do not count the 'n' in an escaped newline as part of a word
-          '--ignore-words-list=youn,unsupport,noe', # Word used in error messages that need rewording
+          '--ignore-words-list=youn,unsupport,noe,cann', # Word used in error messages that need rewording
           --check-filenames,
           --check-hidden
       ]
2 changes: 1 addition & 1 deletion csrc/aio/py_lib/deepspeed_aio_op_desc.h
@@ -16,7 +16,7 @@ struct io_op_desc_t {
     const std::string _filename;
     const long long int _file_num_bytes;
     const int _num_threads;
-    const int _num_bytes_per_thread;
+    const long long int _num_bytes_per_thread;
     torch::Tensor _contiguous_buffer;
     const bool _validate;

14 changes: 3 additions & 11 deletions deepspeed/comm/config.py
@@ -3,20 +3,12 @@
 
 # DeepSpeed Team
 
-from .constants import *
-from ..pydantic_v1 import BaseModel
+from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 
-
-class CommsConfig(BaseModel):
-
-    class Config:
-        validate_all = True
-        validate_assignment = True
-        use_enum_values = True
-        extra = 'forbid'
+from .constants import *
 
 
-class CommsLoggerConfig(CommsConfig):
+class CommsLoggerConfig(DeepSpeedConfigModel):
     enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT
     prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT
     prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT
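Note on this change: the deleted inner `Config` class is the pydantic v1 idiom for model settings; in pydantic 2 the same options are declared via `model_config = ConfigDict(...)`, which is presumably what the shared `DeepSpeedConfigModel` base centralizes. A minimal sketch of the equivalent settings (pydantic 2 renamed v1's `validate_all` to `validate_default`; the exact options in `DeepSpeedConfigModel` may differ):

    # Hedged sketch of the pydantic-2 analogue of the removed v1 Config class.
    from pydantic import BaseModel, ConfigDict

    class CommsConfigSketch(BaseModel):
        model_config = ConfigDict(
            validate_default=True,    # v1: validate_all
            validate_assignment=True,
            use_enum_values=True,
            extra="forbid",
        )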
119 changes: 63 additions & 56 deletions deepspeed/inference/config.py
@@ -5,38 +5,25 @@
 
 import torch
 import deepspeed
-from deepspeed.pydantic_v1 import Field, validator
+from pydantic import Field, field_validator
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
-from typing import Dict, Union
+from typing import Dict, Union, Optional
 from enum import Enum
 
 
 class DtypeEnum(Enum):
-    # The torch dtype must always be the first value (so we return torch.dtype)
-    fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
-    fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
-    bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat"
-    int8 = torch.int8, "torch.int8", "int8"
-
-    # Copied from https://stackoverflow.com/a/43210118
-    # Allows us to use multiple values for each Enum index and returns first
-    # listed value when Enum is called
-    def __new__(cls, *values):
-        obj = object.__new__(cls)
-        # first value is canonical value
-        obj._value_ = values[0]
-        for other_value in values[1:]:
-            cls._value2member_map_[other_value] = obj
-        obj._all_values = values
-        return obj
-
-    def __repr__(self):
-        return "<%s.%s: %s>" % (
-            self.__class__.__name__,
-            self._name_,
-            ", ".join([repr(v) for v in self._all_values]),
-        )
+    fp16 = (torch.float16, "torch.float16", "fp16", "float16", "half")
+    fp32 = (torch.float32, "torch.float32", "fp32", "float32", "float")
+    bf16 = (torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat")
+    int8 = (torch.int8, "torch.int8", "int8")
+
+    @classmethod
+    def from_str(cls, value: str):
+        for dtype in cls:
+            if value in dtype.value:
+                return dtype
+        raise ValueError(f"'{value}' is not a valid DtypeEnum")
 
 
 class MoETypeEnum(str, Enum):
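The multi-value `__new__`/`__repr__` machinery above is replaced by an explicit lookup helper. Expected usage, as a sketch (assumes `torch` and `DtypeEnum` are in scope as in this module):

    dtype = DtypeEnum.from_str("half")        # any alias in the value tuple resolves
    assert dtype is DtypeEnum.fp16
    assert dtype.value[0] is torch.float16    # element 0 is the canonical torch dtype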
Expand Down Expand Up @@ -91,24 +78,24 @@ class QuantTypeEnum(str, Enum):


class BaseQuantConfig(DeepSpeedConfigModel):
enabled = True
num_bits = 8
enabled: bool = True
num_bits: int = 8
q_type: QuantTypeEnum = QuantTypeEnum.sym
q_groups: int = 1


class WeightQuantConfig(BaseQuantConfig):
enabled = True
enabled: bool = True
quantized_initialization: Dict = {}
post_init_quant: Dict = {}


class ActivationQuantConfig(BaseQuantConfig):
enabled = True
enabled: bool = True


class QKVQuantConfig(DeepSpeedConfigModel):
enabled = True
enabled: bool = True


class QuantizationConfig(DeepSpeedConfigModel):
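The added `bool`/`int` annotations are required rather than cosmetic: pydantic 2 rejects un-annotated class attributes at class-creation time instead of silently treating them as fields. A minimal illustration with hypothetical models (assuming pydantic 2.x):

    from pydantic import BaseModel, PydanticUserError

    try:
        class Broken(BaseModel):
            enabled = True  # no annotation: pydantic 2 raises at class creation
    except PydanticUserError as err:
        print(err)  # "A non-annotated attribute was detected: ..."

    class Fixed(BaseModel):
        enabled: bool = True  # annotated: a real, validated field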
@@ -120,9 +107,9 @@ class QuantizationConfig(DeepSpeedConfigModel):
 
 # todo: brainstorm on how to do ckpt loading for DS inference
 class InferenceCheckpointConfig(DeepSpeedConfigModel):
-    checkpoint_dir: str = None
-    save_mp_checkpoint_path: str = None
-    base_dir: str = None
+    checkpoint_dir: Optional[str] = None
+    save_mp_checkpoint_path: Optional[str] = None
+    base_dir: Optional[str] = None
 
 
 class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
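Assuming `DeepSpeedConfigModel` keeps default validation enabled, the v1-style `str = None` defaults would now fail; widening the annotation to `Optional[str]` makes `None` a legal value. Hedged illustration with stand-in models:

    from typing import Optional
    from pydantic import BaseModel, ConfigDict, ValidationError

    class GoodPaths(BaseModel):
        model_config = ConfigDict(validate_default=True)
        checkpoint_dir: Optional[str] = None  # None is a valid Optional[str]

    class BadPaths(BaseModel):
        model_config = ConfigDict(validate_default=True)
        checkpoint_dir: str = None  # tolerated in v1; v2 validates the default

    GoodPaths()  # fine
    try:
        BadPaths()  # raises: None is not a str
    except ValidationError as err:
        print(err)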
@@ -136,7 +123,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     `(attention_output projection, transformer output projection)`
     """
 
-    dtype: DtypeEnum = torch.float16
+    dtype: torch.dtype = torch.float16
     """
     Desired model data type, will convert model to this type.
     Supported target types: `torch.half`, `torch.int8`, `torch.float`
@@ -198,7 +185,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     """
 
     #todo: refactor the following 3 into the new checkpoint_config
-    checkpoint: Union[str, Dict] = None
+    checkpoint: Optional[Union[str, Dict]] = None
     """
     Path to deepspeed compatible checkpoint or path to JSON with load policy.
     """
@@ -214,7 +201,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     specifying whether the inference-module is created with empty or real Tensor
     """
 
-    save_mp_checkpoint_path: str = None
+    save_mp_checkpoint_path: Optional[str] = None
     """
     The path for which we want to save the loaded model with a checkpoint. This
     feature is used for adjusting the parallelism degree to help alleviate the
@@ -243,19 +230,21 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
 
     replace_method: str = Field(
         "auto",
-        deprecated=True,
-        deprecated_msg="This parameter is no longer needed, please remove from your call to DeepSpeed-inference")
+        json_schema_extra={
+            "deprecated": True,
+            "deprecated_msg": "This parameter is no longer needed, please remove from your call to DeepSpeed-inference"
+        })
 
-    injection_policy: Dict = Field(None, alias="injection_dict")
+    injection_policy: Optional[Dict] = Field(None, alias="injection_dict")
     """
     Dictionary mapping a client nn.Module to its corresponding injection
     policy. e.g., `{BertLayer : deepspeed.inference.HFBertLayerPolicy}`
     """
 
-    injection_policy_tuple: tuple = None
+    injection_policy_tuple: Optional[tuple] = None
     """ TODO: Add docs """
 
-    config: Dict = Field(None, alias="args")  # todo: really no need for this field if we can refactor
+    config: Optional[Dict] = Field(None, alias="args")  # todo: really no need for this field if we can refactor
 
     max_out_tokens: int = Field(1024, alias="max_tokens")
     """
@@ -274,31 +263,49 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
 
     transposed_mode: bool = Field(False, alias="transposed_mode")
 
-    mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size")
+    mp_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.tp_size"})
     """
     Desired model parallel size, default is 1 meaning no model parallelism.
     Deprecated, please use the ``tensor_parallel` config to control model
     parallelism.
     """
-    mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu")
-    ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size")
-    ep_group: object = Field(None, alias="expert_group", deprecated=True, new_param="moe.ep_group")
-    ep_mp_group: object = Field(None, alias="expert_mp_group", deprecated=True, new_param="moe.ep_mp_group")
-    moe_experts: list = Field([1], deprecated=True, new_param="moe.moe_experts")
-    moe_type: MoETypeEnum = Field(MoETypeEnum.standard, deprecated=True, new_param="moe.type")
-
-    @validator("moe")
+    mpu: object = Field(None, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.mpu"})
+    ep_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "moe.ep_size"})
+    ep_group: object = Field(None,
+                             alias="expert_group",
+                             json_schema_extra={
+                                 "deprecated": True,
+                                 "new_param": "moe.ep_group"
+                             })
+    ep_mp_group: object = Field(None,
+                                alias="expert_mp_group",
+                                json_schema_extra={
+                                    "deprecated": True,
+                                    "new_param": "moe.ep_mp_group"
+                                })
+    moe_experts: list = Field([1], json_schema_extra={"deprecated": True, "new_param": "moe.moe_experts"})
+    moe_type: MoETypeEnum = Field(MoETypeEnum.standard,
+                                  json_schema_extra={
+                                      "deprecated": True,
+                                      "new_param": "moe.type"
+                                  })
+
+    @field_validator("dtype", mode="before")
+    def validate_dtype(cls, field_value, values):
+        if isinstance(field_value, str):
+            return DtypeEnum.from_str(field_value).value[0]
+        if isinstance(field_value, torch.dtype):
+            return field_value
+        raise TypeError(f"Invalid type for dtype: {type(field_value)}")
+
+    @field_validator("moe")
     def moe_backward_compat(cls, field_value, values):
         if isinstance(field_value, bool):
             return DeepSpeedMoEConfig(moe=field_value)
         return field_value
 
-    @validator("use_triton")
+    @field_validator("use_triton")
     def has_triton(cls, field_value, values):
         if field_value and not deepspeed.HAS_TRITON:
             raise ValueError('Triton needs to be installed to use deepspeed with triton kernels')
         return field_value
-
-    class Config:
-        # Get the str representation of the datatype for serialization
-        json_encoders = {torch.dtype: lambda x: str(x)}
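Net effect of this hunk: `dtype` is stored as a plain `torch.dtype`, strings are coerced through `DtypeEnum.from_str` by the `mode="before"` validator, and the v1-only `json_encoders` hook is dropped (under pydantic 2, emitting the string form would need a field serializer instead). Expected behavior, as a hedged usage sketch:

    import torch
    from deepspeed.inference.config import DeepSpeedInferenceConfig

    cfg = DeepSpeedInferenceConfig(dtype="fp16")          # alias string accepted
    assert cfg.dtype is torch.float16                     # stored as torch.dtype

    cfg = DeepSpeedInferenceConfig(dtype=torch.bfloat16)  # dtypes pass through
    assert cfg.dtype is torch.bfloat16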
3 changes: 2 additions & 1 deletion deepspeed/inference/v2/config_v2.py
@@ -3,8 +3,9 @@
 
 # DeepSpeed Team
 
+from pydantic import Field
 from typing import Optional
-from deepspeed.pydantic_v1 import Field
+
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 from .ragged import DSStateManagerConfig
 
@@ -27,17 +27,17 @@ class TensorMetadata(DeepSpeedConfigModel):
     """
     A class to represent a tensor specification.
     """
-    dtype: Optional[str]
-    shape: Optional[Tuple[int, ...]]
-    strides: Optional[Tuple[int, ...]]
+    dtype: Optional[str] = None
+    shape: Optional[Tuple[int, ...]] = None
+    strides: Optional[Tuple[int, ...]] = None
     offset: int
 
 
 class ParameterMetadata(DeepSpeedConfigModel):
     """
     A class to represent a parameter specification.
     """
-    core_param: TensorMetadata = None
+    core_param: Optional[TensorMetadata] = None
     aux_params: Dict[str, TensorMetadata] = {}
 
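These `= None` additions reflect a different v1→v2 change than the `Optional[str] = None` edits above: pydantic v1 implicitly gave `Optional[...]` fields a `None` default, whereas v2 treats a field without a default as required even when its type allows `None`. Illustration with a stand-in model:

    from typing import Optional
    from pydantic import BaseModel, ValidationError

    class Meta(BaseModel):
        dtype: Optional[str]  # v2: required; None must be passed explicitly

    try:
        Meta()  # field "dtype" is missing
    except ValidationError as err:
        print(err)

    class MetaFixed(BaseModel):
        dtype: Optional[str] = None  # truly optional again, as in v1

    MetaFixed()  # fine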
14 changes: 6 additions & 8 deletions deepspeed/inference/v2/ragged/manager_configs.py
@@ -6,7 +6,7 @@
 from enum import Enum
 from typing import Tuple
 
-from deepspeed.pydantic_v1 import PositiveInt, validator
+from pydantic import PositiveInt, model_validator
 
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 from ..inference_utils import DtypeEnum
@@ -173,11 +173,9 @@ class DSStateManagerConfig(DeepSpeedConfigModel):
     Enable tracking for offloading KV-cache to host memory. Currently unsupported.
     """
 
-    @validator("max_ragged_sequence_count")
-    def max_ragged_sequence_count_validator(cls, v: int, values: dict):
-        # If the attributes below failed their validation they won't appear in the values dict.
-        if "max_tracked_sequences" in values and v > values["max_tracked_sequences"]:
-            raise ValueError("max_ragged_sequence_count must be less than max_tracked_sequences")
-        if "max_ragged_batch_size" in values and v > values["max_ragged_batch_size"]:
-            raise ValueError("max_ragged_sequence_count must be less than max_ragged_batch_size")
-        return v
+    @model_validator(mode="after")
+    def max_ragged_sequence_count_validator(self):
+        assert self.max_ragged_sequence_count <= self.max_tracked_sequences, "max_ragged_sequence_count must be less than max_tracked_sequences"
+        assert self.max_ragged_sequence_count <= self.max_ragged_batch_size, "max_ragged_sequence_count must be less than max_ragged_batch_size"
+        return self
16 changes: 8 additions & 8 deletions deepspeed/monitor/config.py
@@ -5,7 +5,7 @@
 
 from typing import Optional
 
-from deepspeed.pydantic_v1 import root_validator
+from pydantic import model_validator
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 
 
@@ -36,10 +36,10 @@ class WandbConfig(DeepSpeedConfigModel):
     enabled: bool = False
     """ Whether logging to WandB is enabled. Requires `wandb` package is installed. """
 
-    group: str = None
+    group: Optional[str] = None
     """ Name for the WandB group. This can be used to group together runs. """
 
-    team: str = None
+    team: Optional[str] = None
     """ Name for the WandB team. """
 
     project: str = "deepspeed"
@@ -137,8 +137,8 @@ class DeepSpeedMonitorConfig(DeepSpeedConfigModel):
     csv_monitor: CSVConfig = {}
     """ Local CSV output of monitoring data. """
 
-    @root_validator
-    def check_enabled(cls, values):
-        values["enabled"] = values.get("tensorboard").enabled or values.get("wandb").enabled or values.get(
-            "csv_monitor").enabled or values.get("comet").enabled
-        return values
+    @model_validator(mode="after")
+    def check_enabled(self):
+        enabled = self.tensorboard.enabled or self.wandb.enabled or self.csv_monitor.enabled or self.comet.enabled
+        self.__dict__["enabled"] = enabled
+        return self
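The `self.__dict__["enabled"] = enabled` write is deliberate: assuming `DeepSpeedConfigModel` enables `validate_assignment`, a plain `self.enabled = ...` inside an after-validator would re-trigger assignment validation, and with it this validator, recursively. A sketch of the pattern under that assumption:

    from pydantic import BaseModel, ConfigDict, model_validator

    class Monitor(BaseModel):
        model_config = ConfigDict(validate_assignment=True)
        tensorboard_enabled: bool = False
        enabled: bool = False

        @model_validator(mode="after")
        def check_enabled(self):
            # self.enabled = ...  would recurse: assignment validation re-runs
            # this validator. Writing to __dict__ bypasses validation.
            self.__dict__["enabled"] = self.tensorboard_enabled
            return self

    print(Monitor(tensorboard_enabled=True).enabled)  # True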
16 changes: 0 additions & 16 deletions deepspeed/pydantic_v1.py

This file was deleted.

(Diffs for the remaining changed files did not load and are not shown.)
