OptimizedLinear updates #5791

Merged · 20 commits · Aug 14, 2024
Changes from all commits
1 change: 1 addition & 0 deletions deepspeed/linear/__init__.py
@@ -5,3 +5,4 @@

from .optimized_linear import OptimizedLinear
from .config import LoRAConfig, QuantizationConfig
from .context_manager import Init, init_lora
12 changes: 11 additions & 1 deletion deepspeed/linear/config.py
@@ -3,7 +3,8 @@

# DeepSpeed Team

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import List


@dataclass
@@ -17,10 +18,19 @@ class LoRAConfig:
base_weight_sharding (int): The degree to which the base weights are sharded,
should typically be set to the data-parallel world size to maximize the memory
reduction benefits. Defaults to 1, which means this feature is disabled.
offload (bool): offload frozen parameters to cpu when not in use
offload_ratio (float): ratio of parameters to offload to cpu when not in use
delay_lora_init (bool): initialize lora parameters at time of model init or allow manual init later
target_mods (str): target module names to apply LoRA to, defaults to llama-3.1 arch
"""
lora_r: int = 64
lora_alpha: float = 16.
base_weight_sharding: int = 1
offload: bool = False
offload_ratio: float = 0.0
delay_lora_init: bool = False
target_mods: List[str] = field(
default_factory=lambda: ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'])


@dataclass
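As a quick illustration of the new config surface, here is a minimal sketch of constructing a LoRAConfig with the fields added in this PR (the values are illustrative, not recommendations):

from deepspeed.linear import LoRAConfig

# Illustrative configuration exercising the fields introduced above.
lora_config = LoRAConfig(
    lora_r=64,
    lora_alpha=16.0,
    base_weight_sharding=8,     # typically the data-parallel world size
    offload=True,               # park frozen base weights on CPU when not in use
    offload_ratio=0.5,          # fraction of frozen parameters to offload
    delay_lora_init=True,       # defer creating the LoRA A/B weights until init_lora()
    target_mods=['q_proj', 'k_proj', 'v_proj', 'o_proj'],  # subset of the llama-style defaults
)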
90 changes: 90 additions & 0 deletions deepspeed/linear/context_manager.py
@@ -0,0 +1,90 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .optimized_linear import LoRAOptimizedLinear, OptimizedLinear

import torch

try:
import transformers
except ImportError:
transformers = None


def init_lora(model):
model.requires_grad_(False)
for m in model.modules():
if isinstance(m, LoRAOptimizedLinear):
m.init_lora()


class Init(object):
"""
Init context wrapper similar in style to zero.Init. Allows for injecting OptimizedLinear during model
construction which will shard base weights and reduce overall memory usage during model init. Primarily
useful when initializing a model via transformers.AutoModelForCausalLM.

Example usage:
lora_config = deepspeed.linear.LoRAConfig(..)
quant_config = deepspeed.linear.QuantizationConfig(..)
with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config):
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-405B")

"""

def __init__(self, lora_config=None, quant_config=None):
self._orig_nn_linear = torch.nn.Linear
self._orig_causallm_pretrained = None
if transformers != None:
self._orig_causallm_pretrained = transformers.AutoModelForCausalLM.from_pretrained
self._orig_causallm_config = transformers.AutoModelForCausalLM.from_config
self.lora_config = lora_config
self.quant_config = quant_config
self._post_init_complete = False

def __enter__(self):

class OptLinearWrapper:
_orig_nn_linear = self._orig_nn_linear
_lora_config = self.lora_config
_quant_config = self.quant_config

def __new__(self, *args, **kwargs):
self._lora_config.delay_lora_init = True
kwargs['lora_config'] = self._lora_config
kwargs['quantization_config'] = self._quant_config
kwargs['linear_cls'] = self._orig_nn_linear
return OptimizedLinear(*args, **kwargs)

def _model_init(model):
if self.lora_config != None:
init_lora(model)
self._post_init_complete = True
return model

# ensures non-lora params are frozen and lora weights are initialized
def from_pretrained(*args, **kwargs):
model = self._orig_causallm_pretrained(*args, **kwargs)
return _model_init(model)

def from_config(*args, **kwargs):
model = self._orig_causallm_config(*args, **kwargs)
return _model_init(model)

torch.nn.Linear = OptLinearWrapper
if transformers != None:
transformers.AutoModelForCausalLM.from_pretrained = from_pretrained
transformers.AutoModelForCausalLM.from_config = from_config

def __exit__(self, *args, **kwargs):
torch.nn.Linear = self._orig_nn_linear
if not self._post_init_complete:
print('WARNING: For some reason LoRA modules are not initialized, this is usually done automatically '
'if using transformers via (AutoModelForCausalLM from_pretrained/from_config). '
'You must call `init_lora` on each module in order to use DeepSpeed LoRA, otherwise '
'you will error out during runtime.')
else:
transformers.AutoModelForCausalLM.from_pretrained = self._orig_causallm_pretrained
transformers.AutoModelForCausalLM.from_config = self._orig_causallm_config
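For models built outside transformers' AutoModelForCausalLM, the post-init hook above never fires, so LoRA modules must be initialized by hand as the __exit__ warning says. A minimal sketch under that assumption (the toy Sequential model and its sizes are hypothetical):

import torch
import deepspeed.linear as dsl

lora_config = dsl.LoRAConfig(delay_lora_init=True)

# Inside the context every torch.nn.Linear constructed becomes an OptimizedLinear.
with dsl.Init(lora_config=lora_config):
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024, bias=False),
        torch.nn.Linear(1024, 4096, bias=False),
    )

# The model was not created via from_pretrained/from_config, so freeze the base
# weights and materialize the LoRA adapters manually.
dsl.init_lora(model)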
144 changes: 108 additions & 36 deletions deepspeed/linear/optimized_linear.py
@@ -40,15 +40,17 @@ def __new__(self,
bias: bool = False,
lora_config: LoRAConfig = None,
quantization_config: QuantizationConfig = None,
dtype=torch.bfloat16):
device=None,
dtype=torch.bfloat16,
linear_cls=nn.Linear):

if quantization_config is not None and not is_dataclass(quantization_config):
raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}")
if lora_config is not None and not is_dataclass(lora_config):
raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}")
if lora_config is None and quantization_config is None:
# Everything disabled, fall back to normal nn.Linear
self = nn.Linear(input_dim, output_dim, bias=bias, dtype=dtype)
self = linear_cls(input_dim, output_dim, bias=bias, dtype=dtype, device=device)

elif lora_config:
# lora enabled, quantization may or may not be
@@ -57,7 +59,9 @@ def __init__(self,
bias=bias,
lora_config=lora_config,
quantization_config=quantization_config,
dtype=dtype)
dtype=dtype,
device=device,
linear_cls=linear_cls)

elif quantization_config:
# only quantization enabled, no lora
@@ -78,72 +82,140 @@ def __init__(self,
lora_config: LoRAConfig = None,
quantization_config: QuantizationConfig = None,
device=None,
dtype=torch.bfloat16):
dtype=torch.bfloat16,
linear_cls=nn.Linear):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.bias = bias
self.lora_config = lora_config
self.quantization_config = quantization_config
device = get_accelerator().current_device_name() if device is None else device
self.device = get_accelerator().current_device_name() if device is None else device
self.linear_cls = linear_cls
self.dtype = dtype
assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config"

assert not self.bias, "bias=True is not supported by LoRAOptimizedLinear"
self.zero_shards = self.lora_config.base_weight_sharding
self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards)
w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype))
torch.nn.init.xavier_uniform_(w)
if self.zero_shards > 1:
assert self.zero_shards == dist.get_world_size(
), "base weight sharding is only supported across world size"
w = torch.nn.Parameter(torch.empty(self.output_dim * self.sharded_weight_size, dtype=dtype),
requires_grad=False)
else:
w = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=dtype), requires_grad=False)
torch.nn.init.xavier_uniform_(w.reshape(self.sharded_weight_size, self.output_dim))

if self.quantization_config is not None:
assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization"
self.base_weight = QuantizedParameter(w, quantization_config=quantization_config)
self.weight = QuantizedParameter(w, quantization_config=quantization_config)
else:
self.base_weight = w
self.weight = w

self.disabled = False
self._initialized = False
if not self.lora_config.delay_lora_init:
self.init_lora()

def disable(self):
self.disabled = True
self.weight = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=self.dtype),
requires_grad=False)

def init_lora(self):
if self.disabled:
return

if self.quantization_config is not None:
# ensure quant-param wasn't stripped, in some cases transformers will do this during model init
if not isinstance(self.weight, QuantizedParameter):
self.weight = QuantizedParameter(self.weight, quantization_config=self.quantization_config)

self._initialized = True
self.weight.requires_grad = False

self.base_weight.requires_grad = False
# Mark base weight to prevent broadcast and ensure proper offload behavior
self.weight.ds_optim_param = True

self.lora_scaling_factor = self.lora_config.lora_alpha / self.lora_config.lora_r

# Use RS lora for now.
self.lora_scaling_factor = self.lora_config.lora_alpha / math.sqrt(self.lora_config.lora_r)
# Keeping lora weights in bf16 precision for ease of training.
self.lora_weight_1 = nn.Linear(self.input_dim,
self.lora_config.lora_r,
bias=self.bias,
device=device,
dtype=dtype)
self.lora_weight_2 = nn.Linear(self.lora_config.lora_r,
self.output_dim,
bias=self.bias,
device=device,
dtype=dtype)
self.lora_weight_1 = self.linear_cls(self.input_dim,
self.lora_config.lora_r,
bias=self.bias,
device=self.device,
dtype=self.dtype)
self.lora_weight_2 = self.linear_cls(self.lora_config.lora_r,
self.output_dim,
bias=self.bias,
device=self.device,
dtype=self.dtype)

# initialize "A" with kaiming uniform and "B" with zeros following this
# https://github.com/huggingface/peft/blob/62122b5add8d6892f70c82eaef2147a6ba33b90b/src/peft/tuners/lora/layer.py#L155
nn.init.kaiming_uniform_(self.lora_weight_1.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_weight_2.weight)
self.lora_weight_1.weight.requires_grad = True
self.lora_weight_2.weight.requires_grad = True

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
if not any([target in prefix for target in self.lora_config.target_mods]):
# module does not match any target_mods, we must revert to normal nn.Linear via disable
self.disable()
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys,
unexpected_keys, error_msgs)

if self.zero_shards > 1:
if not dist.is_initialized():
raise RuntimeError(
"attempting to use optimized linear base weight sharding but torch-distributed is not initialized, please init first."
)
rank = dist.get_rank()
shape_local = self.output_dim * self.sharded_weight_size
base_weight_name = f"{prefix}weight"
incoming_param = state_dict[base_weight_name]
state_dict[base_weight_name] = incoming_param.flatten().narrow(0, rank * shape_local, shape_local)

return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs)

def full_weight(self):
# This assumes weights are evenly sharded across gpus. which might not be correct.
# in that case, we should flatten before all_gather.
local_weight = self.base_weight.dequantized() if isinstance(self.base_weight,
QuantizedParameter) else self.base_weight
tensor_list = [
torch.zeros_like(local_weight, device=local_weight.device, dtype=local_weight.dtype)
for _ in range(self.zero_shards)
]
dist.all_gather(tensor_list, local_weight)
weight = nn.Parameter(torch.cat([tensor for tensor in tensor_list], dim=1))
return weight
base_weight = self.weight
if getattr(base_weight, 'ds_offload', False):
# move to gpu so we can dequant and all-gather
assert base_weight.device == torch.device('cpu'), \
f"expected base weight on cpu but found {base_weight.device}"
base_weight.offload(revert=True)
local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight
base_weight.offload()
else:
local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight

tensor_out = torch.empty(self.output_dim * self.input_dim,
dtype=local_weight.dtype,
device=local_weight.device)
dist.all_gather_into_tensor(tensor_out, local_weight)
return tensor_out.reshape(self.output_dim, self.input_dim)

def linear_without_F_linear(self, input, weight):
output = torch.mm(input.reshape(-1, input.shape[-1]), weight)
output = output.view(*input.shape[:-1], weight.shape[1])
return output

def forward(self, input_tensor):
if self.disabled:
return F.linear(input_tensor, self.weight)
assert self._initialized, "init_lora was never called, please initialize before proceeding"

# Gather the sharded base weight
if self.zero_shards > 1:
with torch.no_grad():
base_weight = self.full_weight()
elif self.quantization_config:
base_weight = self.base_weight.dequantized()
base_weight = self.weight.dequantized()
else:
base_weight = self.base_weight
base_weight = self.weight

base_weight_output = F.linear(input_tensor, base_weight)
lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor))
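When base_weight_sharding > 1, _load_from_state_dict flattens the incoming weight and keeps one contiguous chunk per rank, and full_weight() reassembles the chunks with all_gather_into_tensor. A single-process sketch of that slicing and reassembly (shapes, rank, and the list comprehension standing in for dist are all illustrative):

import torch

output_dim, input_dim, world_size, rank = 8, 16, 4, 1    # illustrative shapes
shard = input_dim // world_size                          # sharded_weight_size in the diff
chunk = output_dim * shard                               # shape_local in _load_from_state_dict

full = torch.arange(output_dim * input_dim).float().reshape(output_dim, input_dim)
local = full.flatten().narrow(0, rank * chunk, chunk)    # the flat slice this rank keeps

# full_weight() all-gathers the flat chunks and reshapes them back; simulated here:
gathered = torch.cat([full.flatten().narrow(0, r * chunk, chunk) for r in range(world_size)])
assert torch.equal(gathered.reshape(output_dim, input_dim), full)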
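The tail of forward is not shown in the hunk above; conceptually it adds the LoRA branch, scaled by the RS-LoRA factor alpha / sqrt(r) set in init_lora, to the base projection. A standalone sketch of that combination (the helper name and signature are illustrative, not part of the PR):

import math
import torch.nn.functional as F

def rs_lora_forward(x, base_weight, lora_weight_1, lora_weight_2, lora_alpha, lora_r):
    base_out = F.linear(x, base_weight)            # frozen (gathered / dequantized) base projection
    lora_out = lora_weight_2(lora_weight_1(x))     # B(A(x)); B starts at zero
    scaling = lora_alpha / math.sqrt(lora_r)       # RS-LoRA: alpha / sqrt(r) instead of alpha / r
    return base_out + scaling * lora_out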
12 changes: 11 additions & 1 deletion deepspeed/linear/quantization.py
@@ -75,6 +75,13 @@ def dequantized(self) -> torch.Tensor:
q_mantisa_bits=self.quantization_config.mantissa_bits)
return self.data

def offload(self, revert=False):
if getattr(self, 'ds_offload', False):
if revert:
self.data = self.to(get_accelerator().current_device_name())
else:
self.data = self.to('cpu')

def __getstate__(self):
state = self.__dict__
state["data"] = self.data
@@ -104,14 +111,17 @@ def __copy__(self):
return new_instance

def cuda(self, device=None, non_blocking=False):
return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
device = "cuda" if device is None else device
self.quantizer.to(device, non_blocking=non_blocking)
return self.to(device, non_blocking=non_blocking)

def to(self, *args, **kwargs):
"""
Move the parameter to the given device. Then, if the device is a cuda device,
quantize it.
"""
tensor = super().to(*args, **kwargs)
self.quantizer.to(*args, **kwargs)
self._ensure_quantized(tensor)
return tensor

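The new offload helper only acts on parameters flagged with ds_offload elsewhere; a sketch of the round trip it enables, mirroring the pattern in full_weight() (the helper function below is illustrative):

def dequantize_offloaded(qparam):
    # Bring the quantized data back to the accelerator, dequantize for compute,
    # then park it on CPU again; offload() is a no-op unless ds_offload was set.
    qparam.offload(revert=True)
    dense = qparam.dequantized()
    qparam.offload()
    return dense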
7 changes: 7 additions & 0 deletions deepspeed/ops/fp_quantizer/quantize.py
@@ -91,6 +91,13 @@ def quantize(self,

return out

def to(self, *args, **kwargs):
# Intermediate tensors may need to be moved to different devices
if hasattr(self, 'input_q'):
self.input_q = self.input_q.to(*args, **kwargs)
if hasattr(self, 'scale'):
self.scale = self.scale.to(*args, **kwargs)

def get_scales(self):
return fp_quant_module.get_scales(self.scale, self.num_groups)

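The new to() keeps the quantizer's scratch tensors (input_q, scale) on the same device as the parameter they serve, and QuantizedParameter.to/cuda now forward to it. A small illustrative helper under the assumption that quantize() has already populated those buffers:

from deepspeed.accelerator import get_accelerator

def park_quantizer_buffers(quantizer, to_cpu: bool):
    # Move the quantizer's intermediate tensors alongside an offloaded weight,
    # or bring them back before the next quantize/dequantize call.
    target = 'cpu' if to_cpu else get_accelerator().current_device_name()
    quantizer.to(target)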