Skip to content

Commit

Permalink
fix add decomposition for complex numbers (pytorch#129044)
Browse files Browse the repository at this point in the history
Fixes pytorch#125745

Bug source: When addition requires broadcasting, adding complex numbers is not implemented correctly in `torch/_inductor/decomposition.py` because `x.view(x.real.dtype)` would multiply the last dimension by 2, and then broadcasting wouldn't work.

Fix: re-shape the complex tensors after view and before broadcasting.

Pull Request resolved: pytorch#129044
Approved by: https://github.com/zou3519, https://github.com/lezcano
  • Loading branch information
yushangdi authored and pytorchmergebot committed Jun 25, 2024
1 parent 6508f0f commit bbdeff7
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
20 changes: 20 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,26 @@ def fn(a, b):
3,
)

def test_add_complex5(self):
def fn(a, b, alpha):
return torch.add(a, b, alpha=alpha)

x = torch.tensor([[1 + 1j, -1 + 1j], [-2 + 2j, 3 - 3j]])
y = torch.tensor([[1 + 1j, -1 + 1j], [-2 + 2j, 3 - 3j]])

self.common(fn, (x, y, 2))

def test_add_complex6(self):
# Fix https://github.com/pytorch/pytorch/issues/125745.
# Add complex tensors with broadcasting.
def fn(a, b, alpha):
return torch.add(a, b, alpha=alpha)

x = torch.tensor([[1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j]])
y = torch.tensor([[1 + 1j]])

self.common(fn, (x, y, 2))

def test_concat_add_inplace(self):
def fn(x, y, z):
return torch.cat([x, y], dim=1).add_(z)
Expand Down
26 changes: 25 additions & 1 deletion torch/_inductor/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def angle(x):

@register_decomposition([aten.add])
def add(x, y, *, alpha=None):
# Require both x and y to be complex tensors.
x_is_complex_tensor = torch.is_tensor(x) and x.is_complex()
y_is_complex_tensor = torch.is_tensor(y) and y.is_complex()
if not x_is_complex_tensor or not y_is_complex_tensor:
Expand All @@ -342,7 +343,30 @@ def add(x, y, *, alpha=None):
if alpha is not None:
z = alpha * y
complex_type = torch.promote_types(x.dtype, y.dtype)
return (x.view(x.real.dtype) + z.view(y.real.dtype)).view(complex_type)

# For complex typed `x`, `x.view(x.real.dtype)` doubles the last dimension and can cause problem
# when broadcasting the add.
def reshape_tensor_complex(tensor):
"""Reshape tensor from [*initial_dims, last_dim] to *initial_dims, last_dim/2, 2]"""
# Get the current shape of the tensor
*initial_dims, last_dim = tensor.shape

# Check if the last dimension is even. We should never reach here since `x.view(x.real.dtype)`
# doubles the last dimension for complex numbers.
if last_dim % 2 != 0:
raise AssertionError(
"The size of the last dimension must be even to reshape it to [..., last_dim/2, 2]"
)

# Reshape the tensor
new_shape = (*initial_dims, last_dim // 2, 2)
reshaped_tensor = tensor.view(new_shape)
return reshaped_tensor

x_reshaped = reshape_tensor_complex(x.view(x.real.dtype))
z_reshaped = reshape_tensor_complex(z.view(y.real.dtype))
result = torch.flatten(x_reshaped + z_reshaped, start_dim=-2).view(complex_type)
return result


@register_decomposition([aten.conj_physical])
Expand Down

0 comments on commit bbdeff7

Please sign in to comment.