From 6645ca9577839c61721e8f31f347f64f9a4fc31c Mon Sep 17 00:00:00 2001
From: Juhan Bae
Date: Tue, 9 Jul 2024 02:21:26 -0400
Subject: [PATCH] Minimize size for dot product

---
 kronfluence/module/conv2d.py                 | 10 +---------
 kronfluence/module/linear.py                 | 12 ++----------
 kronfluence/module/tracker/pairwise_score.py |  2 +-
 kronfluence/module/tracker/precondition.py   |  2 +-
 kronfluence/score/dot_product.py             | 16 ----------------
 5 files changed, 5 insertions(+), 37 deletions(-)

diff --git a/kronfluence/module/conv2d.py b/kronfluence/module/conv2d.py
index 1becff0..38bd382 100644
--- a/kronfluence/module/conv2d.py
+++ b/kronfluence/module/conv2d.py
@@ -211,12 +211,4 @@ def compute_self_measurement_score(
         input_activation = self._flatten_input_activation(input_activation=input_activation)
         input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1))
         output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o")
-        if self.einsum_expression is None:
-            self.einsum_expression = contract_expression(
-                "bio,bci,bco->b",
-                preconditioned_gradient.shape,
-                output_gradient.shape,
-                input_activation.shape,
-                optimize=DynamicProgramming(search_outer=True, minimize="flops"),
-            )
-        return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)
+        return contract("bio,bci,bco->b", preconditioned_gradient, output_gradient, input_activation)
diff --git a/kronfluence/module/linear.py b/kronfluence/module/linear.py
index 0fbf40a..a6fcd47 100644
--- a/kronfluence/module/linear.py
+++ b/kronfluence/module/linear.py
@@ -103,7 +103,7 @@ def compute_pairwise_score(
                 minimize = "size"
             else:
                 expr = "qio,b...i,b...o->qb"
-                minimize = "flops"
+                minimize = "size"
             self.einsum_expression = contract_expression(
                 expr,
                 preconditioned_gradient.shape,
@@ -117,12 +117,4 @@ def compute_self_measurement_score(
         self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
     ) -> torch.Tensor:
         input_activation = self._flatten_input_activation(input_activation=input_activation)
-        if self.einsum_expression is None:
-            self.einsum_expression = contract_expression(
-                "bio,b...i,b...o->b",
-                preconditioned_gradient.shape,
-                output_gradient.shape,
-                input_activation.shape,
-                optimize=DynamicProgramming(search_outer=True, minimize="flops"),
-            )
-        return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)
+        return contract("bio,b...i,b...o->b", preconditioned_gradient, output_gradient, input_activation)
diff --git a/kronfluence/module/tracker/pairwise_score.py b/kronfluence/module/tracker/pairwise_score.py
index a04239f..2bde117 100644
--- a/kronfluence/module/tracker/pairwise_score.py
+++ b/kronfluence/module/tracker/pairwise_score.py
@@ -88,7 +88,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
                     input_activation=cached_activation.to(device=output_gradient.device),
                     output_gradient=output_gradient,
                 )
-                del cached_activation, output_gradient
+                del self.cached_activations, cached_activation, output_gradient
                 self.clear_all_cache()
             else:
                 per_sample_gradient = self.module.compute_per_sample_gradient(
diff --git a/kronfluence/module/tracker/precondition.py b/kronfluence/module/tracker/precondition.py
index b93c397..1d60dbc 100644
--- a/kronfluence/module/tracker/precondition.py
+++ b/kronfluence/module/tracker/precondition.py
@@ -236,7 +236,7 @@ def accumulate_iterations(self) -> None:
                 self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = torch.cat(
                     (accumulated_gradient, gradient), dim=0
                 ).contiguous()
-            del gradient, self.module.storage[PRECONDITIONED_GRADIENT_NAME]
+            del self.module.storage[PRECONDITIONED_GRADIENT_NAME], gradient
         self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None
 
     @torch.no_grad()
diff --git a/kronfluence/score/dot_product.py b/kronfluence/score/dot_product.py
index 6b324fd..cd96f3e 100644
--- a/kronfluence/score/dot_product.py
+++ b/kronfluence/score/dot_product.py
@@ -73,10 +73,6 @@ def compute_dot_products_with_loader(
     total_steps = 0
     enable_amp = score_args.amp_dtype is not None
 
-    print("Start")
-    print(torch.cuda.memory_allocated())
-    print(torch.cuda.memory_reserved())
-
     with tqdm(
         total=len(train_loader),
         desc="Computing pairwise scores (training gradient)",
@@ -86,10 +82,6 @@
         for batch in train_loader:
             batch = send_to_device(tensor=batch, device=state.device)
 
-            print("Begin")
-            print(torch.cuda.memory_allocated())
-            print(torch.cuda.memory_reserved())
-
             with no_sync(model=model, state=state):
                 model.zero_grad(set_to_none=True)
                 with autocast(device_type=state.device.type, enabled=enable_amp, dtype=score_args.amp_dtype):
@@ -103,10 +95,6 @@
             if factor_args.has_shared_parameters:
                 finalize_iteration(model=model, tracked_module_names=tracked_module_names)
 
-            print("Middle")
-            print(torch.cuda.memory_allocated())
-            print(torch.cuda.memory_reserved())
-
             with torch.no_grad():
                 if score_args.compute_per_module_scores:
                     for module in cached_module_lst:
@@ -137,10 +125,6 @@
             total_steps += 1
             pbar.update(1)
 
-            print("End")
-            print(torch.cuda.memory_allocated())
-            print(torch.cuda.memory_reserved())
-
     model.zero_grad(set_to_none=True)
     finalize_all_iterations(model=model, tracked_module_names=tracked_module_names)
     set_mode(