Remove einsum

pomonam committed Jul 9, 2024
1 parent 77c0e3a · commit a2efa47
Showing 2 changed files with 9 additions and 28 deletions.
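In short, the commit replaces direct opt_einsum `contract` calls with `torch.einsum`, deletes commented-out `contract_expression` experiments, and switches the fallback path of `Linear.compute_pairwise_score` to an explicit per-sample-gradient matmul. The first swap is behavior-preserving, since both functions evaluate the same Einstein-summation expression. A minimal check of that claim (tensor names and shapes are illustrative, not taken from kronfluence; float64 keeps the comparison tight):

import torch
from opt_einsum import contract

# Illustrative shapes, chosen to match the "qio,bti,bto->qb" subscripts
# that appear in the conv2d.py diff below.
pg = torch.randn(4, 8, 16, dtype=torch.float64)  # preconditioned_gradient ("qio")
og = torch.randn(2, 5, 8, dtype=torch.float64)   # output_gradient ("bti")
ia = torch.randn(2, 5, 16, dtype=torch.float64)  # input_activation ("bto")

# Same contraction either way; opt_einsum may only choose a different
# evaluation order internally.
s_contract = contract("qio,bti,bto->qb", pg, og, ia)
s_einsum = torch.einsum("qio,bti,bto->qb", pg, og, ia)
torch.testing.assert_close(s_contract, s_einsum)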
kronfluence/module/conv2d.py (14 changes: 2 additions & 12 deletions)
@@ -4,7 +4,7 @@
 import torch.nn.functional as F
 from einconv.utils import get_conv_paddings
 from einops import rearrange, reduce
-from opt_einsum import DynamicProgramming, contract, contract_expression
+from opt_einsum import DynamicProgramming, contract_expression
 from torch import nn
 from torch.nn.modules.utils import _pair

@@ -196,20 +196,10 @@ def compute_pairwise_score(
             return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation).contiguous()
         return torch.einsum("qio,bti,bto->qb", preconditioned_gradient, output_gradient, input_activation)

-        # if self.einsum_expression is None:
-        #     self.einsum_expression = contract_expression(
-        #         "qio,bti,bto->qb",
-        #         preconditioned_gradient.shape,
-        #         output_gradient.shape,
-        #         input_activation.shape,
-        #         optimize=DynamicProgramming(search_outer=True, minimize="flops"),
-        #     )
-        # return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)
-
     def compute_self_measurement_score(
         self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
     ) -> torch.Tensor:
         input_activation = self._flatten_input_activation(input_activation=input_activation)
         input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1))
         output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o")
-        return contract("bio,bci,bco->b", preconditioned_gradient, output_gradient, input_activation).contiguous()
+        return torch.einsum("bio,bci,bco->b", preconditioned_gradient, output_gradient, input_activation)
kronfluence/module/linear.py (23 changes: 7 additions & 16 deletions)
@@ -2,7 +2,7 @@

 import torch
 from einops import rearrange
-from opt_einsum import DynamicProgramming, contract, contract_expression
+from opt_einsum import DynamicProgramming, contract_expression
 from torch import nn

 from kronfluence.module.tracked_module import TrackedModule
@@ -79,8 +79,8 @@ def compute_per_sample_gradient(
     def compute_pairwise_score(
         self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
     ) -> torch.Tensor:
+        input_activation = self._flatten_input_activation(input_activation=input_activation)
         if isinstance(preconditioned_gradient, list):
-            input_activation = self._flatten_input_activation(input_activation=input_activation)
             left_mat, right_mat = preconditioned_gradient
             if self.einsum_expression is None:
                 if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3:
@@ -96,24 +96,15 @@ def compute_pairwise_score(
                     optimize=DynamicProgramming(search_outer=True, minimize="size"),
                 )
             return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation).contiguous()

         if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3:
-            input_activation = self._flatten_input_activation(input_activation=input_activation)
             return torch.einsum("qio,bti,bto->qbt", preconditioned_gradient, output_gradient, input_activation)
-        return torch.einsum("qio,b...i,b...o->qb", preconditioned_gradient, output_gradient, input_activation)
-        # else:
-        #     expr = "qio,b...i,b...o->qb"
-        #     minimize = "flops"
-        #     self.einsum_expression = contract_expression(
-        #         expr,
-        #         preconditioned_gradient.shape,
-        #         output_gradient.shape,
-        #         input_activation.shape,
-        #         optimize=DynamicProgramming(search_outer=True, minimize=minimize),
-        #     )
-        # return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)
+        gradient = self.compute_per_sample_gradient(input_activation=input_activation, output_gradient=output_gradient)
+        # return torch.einsum("qio,b...i,b...o->qb", preconditioned_gradient, output_gradient, input_activation)
+        return torch.matmul(preconditioned_gradient.view(preconditioned_gradient.size(0), -1), gradient.view(gradient.size(0), -1).T)
+
     def compute_self_measurement_score(
         self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
     ) -> torch.Tensor:
         input_activation = self._flatten_input_activation(input_activation=input_activation)
-        return contract("bio,b...i,b...o->b", preconditioned_gradient, output_gradient, input_activation).contiguous()
+        return torch.einsum("bio,b...i,b...o->b", preconditioned_gradient, output_gradient, input_activation)
