[FEAT] EVA: ensure deterministic behavior of SVD on multi gpu setups #2225

Merged · 8 commits · Nov 21, 2024 · Changes from 7 commits
5 changes: 3 additions & 2 deletions examples/eva_finetuning/eva_finetuning_multi_gpu.py
@@ -87,7 +87,7 @@

# Wrap model in DDP
model = model.to(local_rank)
model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)

# setup peft config
eva_config = EvaConfig(rho=rho)
@@ -96,7 +96,8 @@
)

# EVA initialization
eva_state_dict = get_eva_state_dict(model, dataloader, peft_config)
# It is important to set `gather_distributed_inputs=True` here if you use a distributed data sampler.
eva_state_dict = get_eva_state_dict(model, dataloader, peft_config, gather_distributed_inputs=True)
eva_state_dict = {".".join(["base_model.model"] + k.split(".")[1:]): v for k, v in eva_state_dict.items()}

# cleanup ddp
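The gist of the example change, condensed into a sketch. The helper below is assumed (not part of the PR); `model`, `dataloader`, `peft_config`, and `local_rank` are expected to be set up as in the rest of the script, and the import path follows the file layout in this PR.

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from peft.tuners.lora.eva import get_eva_state_dict


def build_eva_state_dict_ddp(model, dataloader, peft_config, local_rank):
    """Condensed sketch of the updated example flow (assumed helper, not PR code)."""
    # Wrap the model for DDP; find_unused_parameters stays at its default (False).
    model = DDP(model.to(local_rank), device_ids=[local_rank], output_device=local_rank)

    # With a distributed sampler each rank sees different batches, so gather the layer
    # inputs from all ranks to make the SVD, and hence the EVA weights, identical everywhere.
    eva_state_dict = get_eva_state_dict(model, dataloader, peft_config, gather_distributed_inputs=True)

    # Re-prefix the keys so they match the PEFT-wrapped model ("base_model.model....").
    return {".".join(["base_model.model"] + k.split(".")[1:]): v for k, v in eva_state_dict.items()}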
67 changes: 48 additions & 19 deletions src/peft/tuners/lora/eva.py
@@ -43,8 +43,14 @@ class _Hook:
A base class for hooks that prepares layer inputs for EVA.
"""

def __init__(self, name: str, prepare_layer_inputs_fn: Optional[callable] = None):
def __init__(
self,
name: str,
prepare_layer_inputs_fn: Optional[callable] = None,
gather_distributed_inputs: bool = False,
):
self.name = name
self.gather_distributed_inputs = gather_distributed_inputs
if prepare_layer_inputs_fn is None:
self._prepare_layer_inputs_fn = self._prepare_layer_inputs_fn_default
else:
@@ -71,9 +77,8 @@ def _prepare_layer_inputs_fn_default(layer_input, model_input, layer_name) -> to
def prepare_layer_inputs(self, layer_input):
return self._prepare_layer_inputs_fn(layer_input, self.model_input, self.name)

@staticmethod
def gather_layer_inputs(layer_input):
if dist.is_initialized():
def gather_layer_inputs(self, layer_input):
if dist.is_initialized() and self.gather_distributed_inputs:
world_size = dist.get_world_size()

# First gather sizes from all processes more efficiently
@@ -116,12 +121,11 @@ class SVDHook(_Hook):

def __init__(
self,
name: str,
n_components: int,
sim_thresh: Union[float, torch.Tensor],
prepare_layer_inputs_fn: Optional[callable] = None,
**base_class_kwargs,
):
super().__init__(name, prepare_layer_inputs_fn)
super().__init__(**base_class_kwargs)
self.n_components = n_components
self.sim_thresh = sim_thresh
if isinstance(sim_thresh, torch.Tensor) and len(sim_thresh.shape) > 0:
@@ -131,7 +135,12 @@ def __init__(
raise ValueError(
"if sim_thresh is a tensor with more than 0 dimensions it must have shape (n_components,) or (1,)"
)
self.svd = IncrementalPCA(n_components=n_components, copy=True, lowrank=True)
self.svd = IncrementalPCA(
n_components=n_components,
copy=True,
lowrank=True,
lowrank_seed=42,
)
self.model_input = None
self.converged = torch.zeros((n_components,), dtype=torch.bool)

@@ -174,12 +183,8 @@ class HashHook(_Hook):
prepare_layer_inputs_fn (Optional[callable]): Function to prepare layer inputs for hashing.
"""

def __init__(
self,
name: str,
prepare_layer_inputs_fn: Optional[callable] = None,
):
super().__init__(name, prepare_layer_inputs_fn)
def __init__(self, **base_class_kwargs):
super().__init__(**base_class_kwargs)
self.hashed_inputs = []

@staticmethod
@@ -289,11 +294,9 @@ def _get_eva_state_dict(
forward_fn: Optional[callable],
prepare_model_inputs_fn: Optional[callable],
prepare_layer_inputs_fn: Union[callable, Dict[str, callable], None],
gather_distributed_inputs: bool,
show_progress_bar: bool,
) -> dict:
# Set seeds for reproducibility at the start of EVA computation
torch.manual_seed(0)

# Computes the rank distribution for each layer based on the explained variance ratio.
# when rank_pattern flag is False, all values in max_components are the same
def _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, rank_budget, max_components):
@@ -314,6 +317,14 @@ def _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, rank_budget,
if len(dataloader) == 0:
raise ValueError("dataloader is empty")

# check if dist is initialized
if dist.is_initialized() and gather_distributed_inputs:
warnings.warn(
"torch.distributed is initialized and `gather_distributed_inputs` is True, "
"therefore EVA initialization will gather tensors from all ranks. "
"Ensure the model does not receive the same inputs on different ranks."
)

# for unusually high rho values, define an upper limit
rho_threshold = 1000
rho = peft_config.eva_config.rho
@@ -345,7 +356,7 @@ def _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, rank_budget,
fn = prepare_layer_inputs_fn.pop(name, None)
else:
fn = prepare_layer_inputs_fn
hook = HashHook(name, fn)
hook = HashHook(name=name, prepare_layer_inputs_fn=fn, gather_distributed_inputs=gather_distributed_inputs)
hook.model_input = model_inputs_for_hooks
handle = module.register_forward_hook(hook)
hooks[name] = (hook, handle)
@@ -377,7 +388,13 @@ def _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, rank_budget,
handle.remove()
if name in equal_inputs_map:
continue
hook = SVDHook(name, max_components[name], peft_config.eva_config.tau, hook._prepare_layer_inputs_fn)
hook = SVDHook(
n_components=max_components[name],
sim_thresh=peft_config.eva_config.tau,
name=name,
prepare_layer_inputs_fn=hook._prepare_layer_inputs_fn,
gather_distributed_inputs=gather_distributed_inputs,
)
module = model.get_submodule(name)
handle = module.register_forward_hook(hook)
hooks[name] = (hook, handle)  # add the old handle here so we don't get errors in the first forward pass
@@ -546,6 +563,7 @@ def get_eva_state_dict(
prepare_model_inputs_fn: Optional[callable] = prepare_model_inputs_fn_language_modeling,
prepare_layer_inputs_fn: Union[callable, Dict[str, callable], None] = prepare_layer_inputs_fn_language_modeling,
adapter_name: str = "default",
gather_distributed_inputs: bool = True,
show_progress_bar: bool = True,
) -> dict:
"""
@@ -579,6 +597,10 @@
case model_inputs is the mask used to determine which indices should be used for SVD (created by
`prepare_model_inputs_fn_language_modeling`).
adapter_name (str): The name of the adapter to compute the SVD for.
gather_distributed_inputs (bool):
Whether to gather the layer inputs from all ranks. Defaults to True, meaning that in a distributed setting
the layer inputs are gathered from all ranks for the SVD computation. In non-distributed settings this
argument is ignored. Set to False if you are using a non-distributed dataloader in a distributed setting.
show_progress_bar (bool): Whether to show a progress bar. Default is True.

Returns:
@@ -624,6 +646,7 @@ def target_module_check_fn_default(name, module, peft_config):
forward_fn=forward_fn,
prepare_model_inputs_fn=prepare_model_inputs_fn,
prepare_layer_inputs_fn=prepare_layer_inputs_fn,
gather_distributed_inputs=gather_distributed_inputs,
show_progress_bar=show_progress_bar,
)
return eva_state_dict
@@ -638,6 +661,7 @@ def initialize_lora_eva_weights(
prepare_model_inputs_fn: Optional[callable] = prepare_model_inputs_fn_language_modeling,
prepare_layer_inputs_fn: Union[callable, Dict[str, callable], None] = prepare_layer_inputs_fn_language_modeling,
adapter_name: str = "default",
gather_distributed_inputs: bool = True,
show_progress_bar: bool = True,
):
"""
@@ -672,6 +696,10 @@
case model_inputs is the mask used to determine which indices should be used for SVD (created by
`prepare_model_inputs_fn_language_modeling`).
adapter_name (str): The name of the adapter to initialize the weights for.
gather_distributed_inputs (bool):
Whether to gather the layer inputs from all ranks. Defaults to True, meaning that in a distributed setting
the layer inputs are gathered from all ranks for the SVD computation. In non-distributed settings this
argument is ignored. Set to False if you are using a non-distributed dataloader in a distributed setting.
show_progress_bar (bool): Whether to show a progress bar. Default is True.

Returns:
@@ -700,6 +728,7 @@ def initialize_lora_eva_weights(
prepare_model_inputs_fn=prepare_model_inputs_fn,
prepare_layer_inputs_fn=prepare_layer_inputs_fn,
adapter_name=adapter_name,
gather_distributed_inputs=gather_distributed_inputs,
show_progress_bar=show_progress_bar,
)

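The body of `gather_layer_inputs` is truncated in the diff above; conceptually it has to collect layer inputs whose batch dimension can differ between ranks. Below is a generic sketch of that pad-and-gather pattern; it illustrates the technique only and is not the PR's exact implementation.

import torch
import torch.distributed as dist


def gather_variable_sized(x: torch.Tensor) -> torch.Tensor:
    """Gather 2D tensors whose first dimension differs across ranks (illustrative sketch)."""
    world_size = dist.get_world_size()

    # 1. Exchange the per-rank batch sizes.
    local_size = torch.tensor([x.shape[0]], device=x.device)
    sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(sizes, local_size)
    sizes = [int(s.item()) for s in sizes]

    # 2. Pad the local tensor to the largest batch size so that all_gather
    #    operates on identically shaped tensors.
    max_size = max(sizes)
    padded = torch.zeros((max_size, x.shape[1]), dtype=x.dtype, device=x.device)
    padded[: x.shape[0]] = x

    # 3. Gather the padded tensors, then trim each back to its true size.
    gathered = [torch.empty_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded)
    return torch.cat([t[:n] for t, n in zip(gathered, sizes)], dim=0)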
75 changes: 45 additions & 30 deletions src/peft/utils/incremental_pca.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Optional, Tuple

import torch
@@ -41,6 +40,7 @@ class IncrementalPCA:
n_components * 2.
lowrank_niter (int, optional): Number of subspace iterations to conduct for torch.svd_lowrank.
Defaults to 4.
lowrank_seed (int, optional): Seed used to make the results of torch.svd_lowrank reproducible. Defaults to None.
"""

def __init__(
@@ -52,28 +52,40 @@ def __init__(
lowrank: bool = False,
lowrank_q: Optional[int] = None,
lowrank_niter: int = 4,
lowrank_seed: Optional[int] = None,
):
self.n_components_ = n_components
self.n_components = n_components
self.copy = copy
self.batch_size = batch_size
self.svd_driver = svd_driver
self.lowrank = lowrank
self.lowrank_q = lowrank_q
self.lowrank_niter = lowrank_niter
self.lowrank_seed = lowrank_seed

self.n_features_ = None

if lowrank:
if lowrank_q is None:
if n_components is None:
raise ValueError("n_components must be specified when using lowrank mode with lowrank_q=None.")
lowrank_q = n_components * 2
if lowrank_q < n_components:
raise ValueError("lowrank_q must be greater than or equal to n_components.")
if self.lowrank:
self._validate_lowrank_params()

def svd_fn(X):
U, S, V = torch.svd_lowrank(X, q=lowrank_q, niter=lowrank_niter)
return U, S, V.mH # V is returned as a conjugate transpose
def _validate_lowrank_params(self):
if self.lowrank_q is None:
if self.n_components is None:
raise ValueError("n_components must be specified when using lowrank mode with lowrank_q=None.")
self.lowrank_q = self.n_components * 2
elif self.lowrank_q < self.n_components:
raise ValueError("lowrank_q must be greater than or equal to n_components.")

self._svd_fn = svd_fn
def _svd_fn_full(self, X):
return torch.linalg.svd(X, full_matrices=False, driver=self.svd_driver)

else:
self._svd_fn = partial(torch.linalg.svd, full_matrices=False, driver=svd_driver)
def _svd_fn_lowrank(self, X):
seed_enabled = self.lowrank_seed is not None
with torch.random.fork_rng(enabled=seed_enabled):
if seed_enabled:
torch.manual_seed(self.lowrank_seed)
U, S, V = torch.svd_lowrank(X, q=self.lowrank_q, niter=self.lowrank_niter)
return U, S, V.mH

def _validate_data(self, X) -> torch.Tensor:
"""
@@ -93,16 +105,16 @@ def _validate_data(self, X) -> torch.Tensor:
X = X.clone()

n_samples, n_features = X.shape
if self.n_components_ is None:
if self.n_components is None:
pass
elif self.n_components_ > n_features:
elif self.n_components > n_features:
raise ValueError(
f"n_components={self.n_components_} invalid for n_features={n_features}, "
f"n_components={self.n_components} invalid for n_features={n_features}, "
"need more rows than columns for IncrementalPCA processing."
)
elif self.n_components_ > n_samples:
elif self.n_components > n_samples:
raise ValueError(
f"n_components={self.n_components_} must be less or equal to the batch number of samples {n_samples}"
f"n_components={self.n_components} must be less or equal to the batch number of samples {n_samples}"
)

if X.dtype not in valid_dtypes:
@@ -210,7 +222,7 @@ def fit(self, X, check_input=True):
if self.batch_size is None:
self.batch_size = 5 * n_features

for batch in self.gen_batches(n_samples, self.batch_size, min_batch_size=self.n_components_ or 0):
for batch in self.gen_batches(n_samples, self.batch_size, min_batch_size=self.n_components or 0):
self.partial_fit(X[batch], check_input=False)

return self
@@ -238,8 +250,8 @@ def partial_fit(self, X, check_input=True):
self.var_ = None # Will be initialized properly in _incremental_mean_and_var based on data dimensions
self.n_samples_seen_ = torch.tensor([0], device=X.device)
self.n_features_ = n_features
if not self.n_components_:
self.n_components_ = min(n_samples, n_features)
if not self.n_components:
self.n_components = min(n_samples, n_features)

if n_features != self.n_features_:
raise ValueError(
@@ -265,20 +277,23 @@
)
)

U, S, Vt = self._svd_fn(X)
if self.lowrank:
U, S, Vt = self._svd_fn_lowrank(X)
else:
U, S, Vt = self._svd_fn_full(X)
U, Vt = self._svd_flip(U, Vt, u_based_decision=False)
explained_variance = S**2 / (n_total_samples - 1)
explained_variance_ratio = S**2 / torch.sum(col_var * n_total_samples)

self.n_samples_seen_ = n_total_samples
self.components_ = Vt[: self.n_components_]
self.singular_values_ = S[: self.n_components_]
self.components_ = Vt[: self.n_components]
self.singular_values_ = S[: self.n_components]
self.mean_ = col_mean
self.var_ = col_var
self.explained_variance_ = explained_variance[: self.n_components_]
self.explained_variance_ratio_ = explained_variance_ratio[: self.n_components_]
if self.n_components_ not in (n_samples, n_features):
self.noise_variance_ = explained_variance[self.n_components_ :].mean()
self.explained_variance_ = explained_variance[: self.n_components]
self.explained_variance_ratio_ = explained_variance_ratio[: self.n_components]
if self.n_components not in (n_samples, n_features):
self.noise_variance_ = explained_variance[self.n_components :].mean()
else:
self.noise_variance_ = torch.tensor(0.0, device=X.device)
return self
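The determinism mechanism itself is small enough to demonstrate in isolation: seeding inside `torch.random.fork_rng` makes `torch.svd_lowrank` return the same factorization on every call, and hence on every rank, without disturbing the global RNG state. The standalone sketch below is independent of the PR code.

import torch


def deterministic_svd_lowrank(x: torch.Tensor, q: int, niter: int = 4, seed: int = 42):
    # fork_rng snapshots and later restores the global RNG state, so the
    # manual_seed below does not leak into the rest of the program.
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        U, S, V = torch.svd_lowrank(x, q=q, niter=niter)
    return U, S, V.mH  # return Vt, matching the torch.linalg.svd convention


x = torch.randn(512, 64)
U1, S1, Vt1 = deterministic_svd_lowrank(x, q=16)
U2, S2, Vt2 = deterministic_svd_lowrank(x, q=16)
assert torch.allclose(S1, S2) and torch.allclose(Vt1, Vt2)  # identical across calls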
2 changes: 1 addition & 1 deletion tests/test_incremental_pca.py
@@ -107,7 +107,7 @@ def test_n_components_none():
# First partial_fit call, ipca.n_components_ is inferred from
# min(X.shape)
ipca.partial_fit(X)
assert ipca.n_components_ == min(X.shape)
assert ipca.n_components == min(X.shape)


def test_incremental_pca_num_features_change():
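For completeness, a hypothetical test, not part of this PR, that exercises the new `lowrank_seed` argument end to end; the import path follows the file layout above.

import torch

from peft.utils.incremental_pca import IncrementalPCA


def test_lowrank_seed_makes_fit_deterministic():
    torch.manual_seed(0)
    X = torch.randn(200, 30)

    def fit_components(seed):
        ipca = IncrementalPCA(n_components=5, lowrank=True, lowrank_seed=seed)
        ipca.fit(X)
        return ipca.components_

    # The same seed must yield identical components on repeated fits.
    torch.testing.assert_close(fit_components(42), fit_components(42))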