Skip to content

Commit

Permalink
bug: allow unsafe_broadcast in VStack and Fredholm1
Browse files Browse the repository at this point in the history
  • Loading branch information
mrava87 committed Nov 26, 2024
1 parent e5d7b52 commit fedb902
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
5 changes: 3 additions & 2 deletions pylops_mpi/basicoperators/VStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ def __init__(self, ops: Sequence[LinearOperator],

def _matvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
if x.partition is not Partition.BROADCAST:
raise ValueError(f"x should have partition={Partition.BROADCAST}, {x.partition} != {Partition.BROADCAST}")
if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
f"Got {x.partition} instead...")
y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n,
engine=x.engine, dtype=self.dtype)
y1 = []
Expand Down
10 changes: 6 additions & 4 deletions pylops_mpi/signalprocessing/Fredholm1.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ def __init__(

def _matvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
if x.partition is not Partition.BROADCAST:
raise ValueError(f"x should have partition={Partition.BROADCAST}, {x.partition} != {Partition.BROADCAST}")
if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
f"Got {x.partition} instead...")
y = DistributedArray(global_shape=self.shape[0], partition=Partition.BROADCAST,
engine=x.engine, dtype=self.dtype)
x = x.local_array.reshape(self.dims).squeeze()
Expand All @@ -129,8 +130,9 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:

def _rmatvec(self, x: NDArray) -> NDArray:
ncp = get_module(x.engine)
if x.partition is not Partition.BROADCAST:
raise ValueError(f"x should have partition={Partition.BROADCAST}, {x.partition} != {Partition.BROADCAST}")
if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
f"Got {x.partition} instead...")
y = DistributedArray(global_shape=self.shape[1], partition=Partition.BROADCAST,
engine=x.engine, dtype=self.dtype)
x = x.local_array.reshape(self.dimsd).squeeze()
Expand Down

0 comments on commit fedb902

Please sign in to comment.