Commit ab16f3f

Remove blank einsum calls with >3 operands

pomonam committed Jul 10, 2024
1 parent c6efd69 commit ab16f3f
Showing 2 changed files with 55 additions and 5 deletions.
24 changes: 22 additions & 2 deletions kronfluence/module/conv2d.py
@@ -196,12 +196,32 @@ def compute_pairwise_score(
                 )[0]
                 self.einsum_path = [item for pair in path for item in pair]
             return _VF.einsum(expr, (left_mat, right_mat, output_gradient, input_activation), path=self.einsum_path)  # pylint: disable=no-member
-        return torch.einsum("qio,bti,bto->qb", preconditioned_gradient, output_gradient, input_activation)
+        expr = "qio,bti,bto->qb"
+        if self.einsum_path is None:
+            path = contract_path(
+                expr,
+                preconditioned_gradient,
+                output_gradient,
+                input_activation,
+                optimize=DynamicProgramming(search_outer=True, minimize="flops"),
+            )[0]
+            self.einsum_path = [item for pair in path for item in pair]
+        return _VF.einsum(expr, (preconditioned_gradient, output_gradient, input_activation), path=self.einsum_path)  # pylint: disable=no-member
 
     def compute_self_measurement_score(
         self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
     ) -> torch.Tensor:
         input_activation = self._flatten_input_activation(input_activation=input_activation)
         input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1))
         output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o")
-        return torch.einsum("bio,bci,bco->b", preconditioned_gradient, output_gradient, input_activation)
+        expr = "bio,bci,bco->b"
+        if self.einsum_path is None:
+            path = contract_path(
+                expr,
+                preconditioned_gradient,
+                output_gradient,
+                input_activation,
+                optimize=DynamicProgramming(search_outer=True, minimize="flops"),
+            )[0]
+            self.einsum_path = [item for pair in path for item in pair]
+        return _VF.einsum(expr, (preconditioned_gradient, output_gradient, input_activation), path=self.einsum_path)  # pylint: disable=no-member
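
Note: the pattern added in both methods above can be read in isolation. opt_einsum's contract_path plans the contraction order once, the returned tuple path is flattened into the flat list that torch._VF.einsum expects, and the cached path is reused on every later call. Below is a minimal, self-contained sketch of that pattern; the CachedPairwiseScore class, shapes, and variable names are illustrative stand-ins rather than repository code, and it assumes opt_einsum is installed and leans on the same private torch._VF.einsum(..., path=...) entry point the diff itself uses.

```python
# A minimal sketch of the path-caching pattern above, not repository code.
# Assumptions: opt_einsum is installed, and we rely on the private
# torch._VF.einsum(..., path=...) entry point that the diff also uses.
import torch
from torch import _VF
from opt_einsum import DynamicProgramming, contract_path


class CachedPairwiseScore:  # hypothetical stand-in for the module class
    def __init__(self) -> None:
        self.einsum_path = None  # planned once, reused on every later call

    def score(
        self,
        preconditioned_gradient: torch.Tensor,  # shape (q, i, o)
        output_gradient: torch.Tensor,          # shape (b, t, i)
        input_activation: torch.Tensor,         # shape (b, t, o)
    ) -> torch.Tensor:
        expr = "qio,bti,bto->qb"
        if self.einsum_path is None:
            # contract_path returns (path, PathInfo); keep the path and
            # flatten its (op, op) pairs into the flat list _VF.einsum expects.
            path = contract_path(
                expr,
                preconditioned_gradient,
                output_gradient,
                input_activation,
                optimize=DynamicProgramming(search_outer=True, minimize="flops"),
            )[0]
            self.einsum_path = [item for pair in path for item in pair]
        return _VF.einsum(
            expr,
            (preconditioned_gradient, output_gradient, input_activation),
            path=self.einsum_path,
        )


module = CachedPairwiseScore()
out = module.score(torch.randn(4, 32, 64), torch.randn(8, 16, 32), torch.randn(8, 16, 64))
print(out.shape)  # torch.Size([4, 8])
```

The planning cost is paid once per module. Because an einsum over three or more operands is evaluated as a sequence of pairwise contractions, a poor contraction order can be far more expensive than the optimal one; that order is what the DynamicProgramming optimizer searches for.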
36 changes: 33 additions & 3 deletions kronfluence/module/linear.py
@@ -98,11 +98,41 @@ def compute_pairwise_score(
                 self.einsum_path = [item for pair in path for item in pair]
             return _VF.einsum(expr, (left_mat, right_mat, output_gradient, input_activation), path=self.einsum_path)  # pylint: disable=no-member
         if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3:
-            return torch.einsum("qio,bti,bto->qbt", preconditioned_gradient, output_gradient, input_activation)
-        return torch.einsum("qio,b...i,b...o->qb", preconditioned_gradient, output_gradient, input_activation)
+            expr = "qio,bti,bto->qbt"
+            if self.einsum_path is None:
+                path = contract_path(
+                    expr,
+                    preconditioned_gradient,
+                    output_gradient,
+                    input_activation,
+                    optimize=DynamicProgramming(search_outer=True, minimize="flops"),
+                )[0]
+                self.einsum_path = [item for pair in path for item in pair]
+            return _VF.einsum(expr, (preconditioned_gradient, output_gradient, input_activation), path=self.einsum_path)  # pylint: disable=no-member
+        expr = "qio,b...i,b...o->qb"
+        if self.einsum_path is None:
+            path = contract_path(
+                expr,
+                preconditioned_gradient,
+                output_gradient,
+                input_activation,
+                optimize=DynamicProgramming(search_outer=True, minimize="flops"),
+            )[0]
+            self.einsum_path = [item for pair in path for item in pair]
+        return _VF.einsum(expr, (preconditioned_gradient, output_gradient, input_activation), path=self.einsum_path)  # pylint: disable=no-member
 
     def compute_self_measurement_score(
         self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
     ) -> torch.Tensor:
         input_activation = self._flatten_input_activation(input_activation=input_activation)
-        return torch.einsum("bio,b...i,b...o->b", preconditioned_gradient, output_gradient, input_activation)
+        expr = "bio,b...i,b...o->b"
+        if self.einsum_path is None:
+            path = contract_path(
+                expr,
+                preconditioned_gradient,
+                output_gradient,
+                input_activation,
+                optimize=DynamicProgramming(search_outer=True, minimize="flops"),
+            )[0]
+            self.einsum_path = [item for pair in path for item in pair]
+        return _VF.einsum(expr, (preconditioned_gradient, output_gradient, input_activation), path=self.einsum_path)  # pylint: disable=no-member
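
Note: one subtlety in the linear.py changes is that contract_path plans against the concrete operand shapes it sees on the first call, and the flattened path is then cached on the module, so the plan is reused safely only while later batches keep compatible shapes. A quick way to convince yourself that the path-based call agrees with the plain torch.einsum it replaces, including for the "..." (ellipsis) expressions used here, is a check like the following sketch; the shapes are illustrative and torch.testing.assert_close uses its default tolerances.

```python
# Illustrative sanity check, not repository code: the cached-path call
# should agree with the plain torch.einsum it replaces, including for the
# ellipsis expressions used in linear.py.
import torch
from torch import _VF
from opt_einsum import DynamicProgramming, contract_path

expr = "qio,b...i,b...o->qb"
preconditioned_gradient = torch.randn(3, 11, 13)  # (q, i, o)
output_gradient = torch.randn(5, 7, 11)           # (b, t, i)
input_activation = torch.randn(5, 7, 13)          # (b, t, o)

# Plan the contraction once, then flatten the (op, op) pairs into the
# flat list form that _VF.einsum expects.
path = contract_path(
    expr,
    preconditioned_gradient,
    output_gradient,
    input_activation,
    optimize=DynamicProgramming(search_outer=True, minimize="flops"),
)[0]
flat_path = [item for pair in path for item in pair]

fast = _VF.einsum(expr, (preconditioned_gradient, output_gradient, input_activation), path=flat_path)
reference = torch.einsum(expr, preconditioned_gradient, output_gradient, input_activation)
torch.testing.assert_close(fast, reference)  # default tolerances
print("cached-path einsum matches torch.einsum")
```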
