diff --git a/linear_operator/operators/kronecker_product_linear_operator.py b/linear_operator/operators/kronecker_product_linear_operator.py index 1cece8b0..db9b1935 100644 --- a/linear_operator/operators/kronecker_product_linear_operator.py +++ b/linear_operator/operators/kronecker_product_linear_operator.py @@ -8,6 +8,8 @@ from jaxtyping import Float from torch import Tensor +from pyfastkron import fastkrontorch as fktorch + from linear_operator import settings from linear_operator.operators._linear_operator import IndexType, LinearOperator from linear_operator.operators.dense_linear_operator import to_linear_operator @@ -267,14 +269,13 @@ def _matmul( self: Float[LinearOperator, "*batch M N"], rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: - is_vec = rhs.ndimension() == 1 - if is_vec: - rhs = rhs.unsqueeze(-1) - - res = _matmul(self.linear_ops, self.shape, rhs.contiguous()) - - if is_vec: - res = res.squeeze(-1) + res = fktorch.gekmm([op.to_dense() for op in self.linear_ops], rhs.contiguous()) + return res + + def rmatmul(self: Float[LinearOperator, "... M N"], + rhs: Union[Float[Tensor, "... P M"], Float[Tensor, "... M"], Float[LinearOperator, "... P M"]], + ) -> Union[Float[Tensor, "... P N"], Float[Tensor, "N"], Float[LinearOperator, "... P N"]]: + res = fktorch.gemkm(rhs.contiguous(), [op.to_dense() for op in self.linear_ops]) return res @cached(name="root_decomposition") @@ -357,14 +358,7 @@ def _t_matmul( self: Float[LinearOperator, "*batch M N"], rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]], ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]: - is_vec = rhs.ndimension() == 1 - if is_vec: - rhs = rhs.unsqueeze(-1) - - res = _t_matmul(self.linear_ops, self.shape, rhs.contiguous()) - - if is_vec: - res = res.squeeze(-1) + res = fktorch.gekmm([op.to_dense().mT for op in self.linear_ops], rhs.contiguous()) return res def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: diff --git a/setup.py b/setup.py index f3313a47..e30fafd4 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ "scipy", "jaxtyping==0.2.19", "mpmath>=0.19,<=1.3", # avoid incompatibiltiy with torch+sympy with mpmath 1.4 + "pyfastkron" ]