Skip to content

Commit

Permalink
Fix operator compositions
Browse files Browse the repository at this point in the history
  • Loading branch information
fzimmermann89 committed Feb 26, 2024
1 parent 9b62dae commit b8b599c
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 36 deletions.
8 changes: 4 additions & 4 deletions src/mrpro/operators/_LinearOperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,15 @@ class LinearOperatorComposition(LinearOperator, OperatorComposition[torch.Tensor

def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""Adjoint operator composition."""
return self._operator2.adjoint(self._operator1.adjoint(x))
return self._operator2.adjoint(*self._operator1.adjoint(x))


class LinearOperatorSum(LinearOperator, OperatorSum[torch.Tensor, tuple[torch.Tensor,]]):
"""Operator addition."""

def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""Adjoint operator addition."""
return self._operator1.adjoint(x) + self._operator2.adjoint(x)
return (self._operator1.adjoint(x)[0] + self._operator2.adjoint(x)[0],)


class LinearOperatorElementwiseProduct(LinearOperator, OperatorElementwiseProduct[torch.Tensor, tuple[torch.Tensor,]]):
Expand All @@ -86,8 +86,8 @@ class LinearOperatorElementwiseProduct(LinearOperator, OperatorElementwiseProduc
def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""Adjoint Operator elementwise multiplication with scalar/tensor."""
if self._tensor.is_complex():
return self._operator.adjoint(x) * self._tensor.conj()
return self._operator.adjoint(x) * self._tensor
return (self._operator.adjoint(x)[0] * self._tensor.conj(),)
return (self._operator.adjoint(x)[0] * self._tensor,)


class AdjointLinearOperator(LinearOperator):
Expand Down
8 changes: 5 additions & 3 deletions src/mrpro/operators/_Operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Generic
from typing import TypeVar
from typing import TypeVarTuple
from typing import cast

import torch

Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(self, operator1: Operator[*Tin2, Tout], operator2: Operator[*Tin, t

def forward(self, *args: *Tin) -> Tout:
"""Operator composition."""
return self._operator1(self._operator2(*args))
return self._operator1(*self._operator2(*args))


class OperatorSum(Operator[*Tin, Tout]):
Expand All @@ -74,7 +75,7 @@ def __init__(self, operator1: Operator[*Tin, Tout], operator2: Operator[*Tin, To

def forward(self, *args: *Tin) -> Tout:
"""Operator addition."""
return self._operator1(*args) + self._operator2(*args)
return cast(Tout, tuple(a + b for a, b in zip(self._operator1(*args), self._operator2(*args), strict=True)))


class OperatorElementwiseProduct(Operator[*Tin, Tout]):
Expand All @@ -87,4 +88,5 @@ def __init__(self, operator: Operator[*Tin, Tout], tensor: torch.Tensor):

def forward(self, *args: *Tin) -> Tout:
"""Operator elementwise multiplication."""
return self._tensor * self._operator(*args)
out = self._operator(*args)
return cast(Tout, tuple(a * self._tensor for a in out))
89 changes: 60 additions & 29 deletions tests/operators/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mrpro.operators import Operator


class DummyOperator(Operator):
class DummyOperator(Operator[torch.Tensor, torch.Tensor]):
"""Dummy operator for testing, raises input to the power of value."""

def __init__(self, value: torch.Tensor):
Expand All @@ -16,7 +16,7 @@ def __init__(self, value: torch.Tensor):

def forward(self, x: torch.Tensor):
"""Dummy operator."""
return x**self._value
return (x**self._value,)


class DummyLinearOperator(LinearOperator):
Expand All @@ -28,24 +28,22 @@ def __init__(self, value: torch.Tensor):

def forward(self, x: torch.Tensor):
"""Dummy linear operator."""
return x * 2
return (x * self._value,)

def adjoint(self, x: torch.Tensor):
"""Dummy adjoint linear operator."""
if x.is_complex():
x = x.conj()
if self._value.is_complex():
return x * self._value.conj()
return x * self._value
return (x * self._value.conj(),)
return (x * self._value,)


def test_composition_operator():
a = DummyOperator(torch.tensor(2.0))
b = DummyOperator(torch.tensor(3.0))
c = a @ b
x = torch.arange(10)
y1 = c(x)
y2 = a(b(x))
(y1,) = c(x)
(y2,) = a(*b(x))

torch.testing.assert_close(y1, y2)
assert isinstance(c, Operator)
Expand All @@ -57,8 +55,8 @@ def test_composition_linearoperator():
b = DummyLinearOperator(torch.tensor(3.0))
c = a @ b
x = torch.arange(10)
y1 = c(x)
y2 = a(b(x))
(y1,) = c(x)
(y2,) = a(*b(x))

torch.testing.assert_close(y1, y2)
assert isinstance(c, Operator), 'LinearOperator @ LinearOperator should be an Operator'
Expand All @@ -70,8 +68,8 @@ def test_composition_linearoperator_operator():
b = DummyOperator(torch.tensor(3.0))
c = a @ b
x = torch.arange(10)
y1 = c(x)
y2 = a(b(x))
(y1,) = c(x)
(y2,) = a(*b(x))

torch.testing.assert_close(y1, y2)
assert isinstance(c, Operator), 'LinearOperator @ Operator should be an Operator'
Expand All @@ -83,8 +81,8 @@ def test_sum_operator():
b = DummyOperator(torch.tensor(3.0))
c = a + b
x = torch.arange(10)
y1 = c(x)
y2 = a(x) + b(x)
(y1,) = c(x)
y2 = a(x)[0] + b(x)[0]

torch.testing.assert_close(y1, y2)
assert isinstance(c, Operator), 'Operator + Operator should be an Operator'
Expand All @@ -96,8 +94,8 @@ def test_sum_linearoperator():
b = DummyLinearOperator(torch.tensor(3.0))
c = a + b
x = torch.arange(10)
y1 = c(x)
y2 = a(x) + b(x)
(y1,) = c(x)
y2 = a(x)[0] + b(x)[0]

torch.testing.assert_close(y1, y2)
assert isinstance(c, Operator), 'LinearOperator + LinearOperator should be an Operator'
Expand All @@ -109,8 +107,8 @@ def test_sum_linearoperator_operator():
b = DummyOperator(torch.tensor(3.0))
c = a + b
x = torch.arange(10)
y1 = c(x)
y2 = a(x) + b(x)
(y1,) = c(x)
y2 = a(x)[0] + b(x)[0]

torch.testing.assert_close(y1, y2)
assert isinstance(c, Operator), 'LinearOperator + Operator should be an Operator'
Expand All @@ -122,8 +120,8 @@ def test_sum_operator_linearoperator():
b = DummyLinearOperator(torch.tensor(2.0))
c = a + b
x = torch.arange(10)
y1 = c(x)
y2 = a(x) + b(x)
(y1,) = c(x)
y2 = a(x)[0] + b(x)[0]

torch.testing.assert_close(y1, y2)
assert isinstance(c, Operator), 'Operator + LinearOperator should be an Operator'
Expand All @@ -135,8 +133,8 @@ def test_elementwise_product_operator():
b = torch.tensor(3.0)
c = a * b
x = torch.arange(10)
y1 = c(x)
y2 = a(x) * b
(y1,) = c(x)
y2 = a(x)[0] * b

torch.testing.assert_close(y1, y2)
assert isinstance(c, Operator), 'Operator * scalar should be an Operator'
Expand All @@ -148,8 +146,8 @@ def test_elementwise_rproduct_operator():
b = torch.tensor(3.0)
c = b * a
x = torch.arange(10)
y1 = c(x)
y2 = a(x) * b
(y1,) = c(x)
y2 = a(x)[0] * b

torch.testing.assert_close(y1, y2)
assert isinstance(c, Operator), 'Operator * scalar should be an Operator'
Expand All @@ -161,8 +159,8 @@ def test_elementwise_product_linearoperator():
b = torch.tensor(3.0)
c = a * b
x = torch.arange(10)
y1 = c(x)
y2 = a(x) * b
(y1,) = c(x)
y2 = a(x)[0] * b

torch.testing.assert_close(y1, y2)
assert isinstance(c, Operator), 'LinearOperator * scalar should be an Operator'
Expand All @@ -174,9 +172,42 @@ def test_elementwise_rproduct_linearoperator():
b = torch.tensor(3.0)
c = b * a
x = torch.arange(10)
y1 = c(x)
y2 = a(x) * b
(y1,) = c(x)
y2 = a(x)[0] * b

torch.testing.assert_close(y1, y2)
assert isinstance(c, Operator), 'LinearOperator * scalar should be an Operator'
assert isinstance(c, LinearOperator), 'LinearOperator * scalar should be a LinearOperator'


def test_adjoint_composition_operators():
a = DummyLinearOperator(torch.tensor(2.0 + 1j))
b = DummyLinearOperator(torch.tensor(3.0 + 2j))
u = torch.tensor(4 + 5j)
v = torch.tensor(7 + 8j)
A = a @ b
(Au,) = A(u)
(AHv,) = A.H(v)
torch.testing.assert_close(Au * v.conj(), u * AHv.conj())


def test_adjoint_product():
a = DummyLinearOperator(torch.tensor(2.0 + 1j))
b = torch.tensor(3.0 + 2j)
u = torch.tensor(4 + 5j)
v = torch.tensor(7 + 8j)
A = a * b
(Au,) = A(u)
(AHv,) = A.H(v)
torch.testing.assert_close(Au * v.conj(), u * AHv.conj())


def test_adjoint_sum():
a = DummyLinearOperator(torch.tensor(2.0 + 1j))
b = DummyLinearOperator(torch.tensor(3.0 + 2j))
u = torch.tensor(4 + 5j)
v = torch.tensor(7 + 8j)
A = a + b
(Au,) = A(u)
(AHv,) = A.H(v)
torch.testing.assert_close(Au * v.conj(), u * AHv.conj())

0 comments on commit b8b599c

Please sign in to comment.