From a2efa476bf050a5f1e2afcbdacabbec2cd8d155e Mon Sep 17 00:00:00 2001
From: Juhan Bae
Date: Tue, 9 Jul 2024 03:12:54 -0400
Subject: [PATCH] Remove einsum

---
 kronfluence/module/conv2d.py | 14 ++------------
 kronfluence/module/linear.py | 23 +++++++----------------
 2 files changed, 9 insertions(+), 28 deletions(-)

diff --git a/kronfluence/module/conv2d.py b/kronfluence/module/conv2d.py
index 7f15a0f..a6d7d48 100644
--- a/kronfluence/module/conv2d.py
+++ b/kronfluence/module/conv2d.py
@@ -4,7 +4,7 @@
 import torch.nn.functional as F
 from einconv.utils import get_conv_paddings
 from einops import rearrange, reduce
-from opt_einsum import DynamicProgramming, contract, contract_expression
+from opt_einsum import DynamicProgramming, contract_expression
 from torch import nn
 from torch.nn.modules.utils import _pair
 
@@ -196,20 +196,10 @@ def compute_pairwise_score(
             return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation).contiguous()
         return torch.einsum("qio,bti,bto->qb", preconditioned_gradient, output_gradient, input_activation)
-
-        # if self.einsum_expression is None:
-        #     self.einsum_expression = contract_expression(
-        #         "qio,bti,bto->qb",
-        #         preconditioned_gradient.shape,
-        #         output_gradient.shape,
-        #         input_activation.shape,
-        #         optimize=DynamicProgramming(search_outer=True, minimize="flops"),
-        #     )
-        # return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)
 
     def compute_self_measurement_score(
         self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
     ) -> torch.Tensor:
         input_activation = self._flatten_input_activation(input_activation=input_activation)
         input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1))
         output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o")
-        return contract("bio,bci,bco->b", preconditioned_gradient, output_gradient, input_activation).contiguous()
+        return torch.einsum("bio,bci,bco->b", preconditioned_gradient, output_gradient, input_activation)
diff --git a/kronfluence/module/linear.py b/kronfluence/module/linear.py
index 171873f..f3dfa84 100644
--- a/kronfluence/module/linear.py
+++ b/kronfluence/module/linear.py
@@ -2,7 +2,7 @@
 
 import torch
 from einops import rearrange
-from opt_einsum import DynamicProgramming, contract, contract_expression
+from opt_einsum import DynamicProgramming, contract_expression
 from torch import nn
 
 from kronfluence.module.tracked_module import TrackedModule
@@ -79,8 +79,8 @@ def compute_per_sample_gradient(
     def compute_pairwise_score(
         self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
     ) -> torch.Tensor:
-        input_activation = self._flatten_input_activation(input_activation=input_activation)
         if isinstance(preconditioned_gradient, list):
+            input_activation = self._flatten_input_activation(input_activation=input_activation)
             left_mat, right_mat = preconditioned_gradient
             if self.einsum_expression is None:
                 if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3:
@@ -96,24 +96,15 @@
                     optimize=DynamicProgramming(search_outer=True, minimize="size"),
                 )
             return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation).contiguous()
-
         if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3:
+            input_activation = self._flatten_input_activation(input_activation=input_activation)
             return torch.einsum("qio,bti,bto->qbt", preconditioned_gradient, output_gradient, input_activation)
-        return torch.einsum("qio,b...i,b...o->qb", preconditioned_gradient, output_gradient, input_activation)
-        # else:
-        #     expr = "qio,b...i,b...o->qb"
-        #     minimize = "flops"
-        #     self.einsum_expression = contract_expression(
-        #         expr,
-        #         preconditioned_gradient.shape,
-        #         output_gradient.shape,
-        #         input_activation.shape,
-        #         optimize=DynamicProgramming(search_outer=True, minimize=minimize),
-        #     )
-        #     return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)
+        gradient = self.compute_per_sample_gradient(input_activation=input_activation, output_gradient=output_gradient)
+        # return torch.einsum("qio,b...i,b...o->qb", preconditioned_gradient, output_gradient, input_activation)
+        return torch.matmul(preconditioned_gradient.view(preconditioned_gradient.size(0), -1), gradient.view(gradient.size(0), -1).T)
 
     def compute_self_measurement_score(
         self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
     ) -> torch.Tensor:
         input_activation = self._flatten_input_activation(input_activation=input_activation)
-        return contract("bio,b...i,b...o->b", preconditioned_gradient, output_gradient, input_activation).contiguous()
+        return torch.einsum("bio,b...i,b...o->b", preconditioned_gradient, output_gradient, input_activation)
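
Note for reviewers: the matmul path in linear.py should be numerically equivalent to the einsum it replaces. Below is a minimal standalone sketch of that check with made-up shapes; the "bti,bto->bio" einsum stands in for what compute_per_sample_gradient is assumed to return (a per-sample gradient laid out to match the "io" axes of preconditioned_gradient), and is not code from the repository.

    import torch

    # Hypothetical sizes: q query points, b training samples, t tokens, feature dims i and o.
    q, b, t, i, o = 3, 4, 5, 6, 7
    preconditioned_gradient = torch.randn(q, i, o, dtype=torch.float64)
    output_gradient = torch.randn(b, t, i, dtype=torch.float64)
    input_activation = torch.randn(b, t, o, dtype=torch.float64)

    # Old path: a single three-operand contraction (previously opt_einsum.contract,
    # kept above only as a torch.einsum comment).
    old = torch.einsum("qio,bti,bto->qb", preconditioned_gradient, output_gradient, input_activation)

    # New path: materialize per-sample gradients, then one GEMM over flattened matrices.
    gradient = torch.einsum("bti,bto->bio", output_gradient, input_activation)
    new = torch.matmul(preconditioned_gradient.reshape(q, -1), gradient.reshape(b, -1).T)

    torch.testing.assert_close(old, new)

Trading the three-operand contraction for an explicit per-sample-gradient step turns the pairwise score into a single large matrix multiply, at the cost of materializing the (b, i, o) gradient tensor.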