Use pyfastkron for KroneckerProduct (t/r)matmul
abhijangda committed Dec 8, 2024
1 parent 6dad1cb commit 06f97e2
Showing 2 changed files with 11 additions and 16 deletions.
26 changes: 10 additions & 16 deletions linear_operator/operators/kronecker_product_linear_operator.py
@@ -8,6 +8,8 @@
 from jaxtyping import Float
 from torch import Tensor
 
+from pyfastkron import fastkrontorch as fktorch
+
 from linear_operator import settings
 from linear_operator.operators._linear_operator import IndexType, LinearOperator
 from linear_operator.operators.dense_linear_operator import to_linear_operator
@@ -267,14 +269,13 @@ def _matmul(
         self: Float[LinearOperator, "*batch M N"],
         rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]],
     ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]:
-        is_vec = rhs.ndimension() == 1
-        if is_vec:
-            rhs = rhs.unsqueeze(-1)
-
-        res = _matmul(self.linear_ops, self.shape, rhs.contiguous())
-
-        if is_vec:
-            res = res.squeeze(-1)
+        res = fktorch.gekmm([op.to_dense() for op in self.linear_ops], rhs.contiguous())
         return res
 
+    def rmatmul(self: Float[LinearOperator, "... M N"],
+                rhs: Union[Float[Tensor, "... P M"], Float[Tensor, "... M"], Float[LinearOperator, "... P M"]],
+    ) -> Union[Float[Tensor, "... P N"], Float[Tensor, "N"], Float[LinearOperator, "... P N"]]:
+        res = fktorch.gemkm(rhs.contiguous(), [op.to_dense() for op in self.linear_ops])
+        return res
+
     @cached(name="root_decomposition")
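With this hunk, _matmul computes (A1 ⊗ ··· ⊗ Ak) @ rhs via pyfastkron's gekmm, and the new rmatmul computes rhs @ (A1 ⊗ ··· ⊗ Ak) via gemkm, replacing the reshape-based _matmul helper. A minimal sanity-check sketch of the semantics the diff implies; the factor shapes, test tensors, and tolerance below are illustrative assumptions, not part of the commit:

import torch
from pyfastkron import fastkrontorch as fktorch

# Two small dense Kronecker factors; kron(A, B) has shape (3*5, 4*6) = (15, 24).
A = torch.randn(3, 4)
B = torch.randn(5, 6)
K = torch.kron(A, B)

# gekmm([A, B], x) should agree with kron(A, B) @ x (the _matmul path) ...
x = torch.randn(24, 2)
print(torch.allclose(fktorch.gekmm([A, B], x), K @ x, atol=1e-5))

# ... and gemkm(y, [A, B]) with y @ kron(A, B) (the rmatmul path).
y = torch.randn(2, 15)
print(torch.allclose(fktorch.gemkm(y, [A, B]), y @ K, atol=1e-5))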
@@ -357,14 +358,7 @@ def _t_matmul(
         self: Float[LinearOperator, "*batch M N"],
         rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]],
     ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]:
-        is_vec = rhs.ndimension() == 1
-        if is_vec:
-            rhs = rhs.unsqueeze(-1)
-
-        res = _t_matmul(self.linear_ops, self.shape, rhs.contiguous())
-
-        if is_vec:
-            res = res.squeeze(-1)
+        res = fktorch.gekmm([op.to_dense().mT for op in self.linear_ops], rhs.contiguous())
         return res
 
     def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]:
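_t_matmul applies the transpose of the Kronecker product, and the diff gets it by transposing each factor before calling gekmm. That is sound because (A ⊗ B)ᵀ = Aᵀ ⊗ Bᵀ. A short check under the same assumed gekmm semantics and illustrative shapes as above:

import torch
from pyfastkron import fastkrontorch as fktorch

A = torch.randn(3, 4)
B = torch.randn(5, 6)
K = torch.kron(A, B)                  # shape (15, 24)

rhs = torch.randn(15, 2)              # _t_matmul's rhs has M (= 15) rows
expected = K.mT @ rhs                 # (A kron B)^T @ rhs, shape (24, 2)

# Transposing each factor, as _t_matmul now does, yields the same product.
print(torch.allclose(fktorch.gekmm([A.mT, B.mT], rhs), expected, atol=1e-5))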
1 change: 1 addition & 0 deletions setup.py
@@ -41,6 +41,7 @@
     "scipy",
     "jaxtyping==0.2.19",
     "mpmath>=0.19,<=1.3",  # avoid incompatibility with torch+sympy with mpmath 1.4
+    "pyfastkron"
 ]
 
 
