Revert "RF PT matmul, use conv1d if possible"
This reverts commit 7956f0b.

In my test benchmark (demo-rf-pt-benchmark.py),
this was actually slightly slower.
albertz committed Oct 10, 2023
1 parent 7956f0b · commit 336b71e
Showing 1 changed file with 22 additions and 58 deletions.
returnn/torch/frontend/_backend.py
@@ -946,7 +946,10 @@ def matmul(a: _TT, b: _TT, *, reduce: Union[Dim, Sequence[Dim]], use_mask: bool
 
         common_axes_total_dimension = prod(common_axes_shape)
 
+        a_unique_axes_shape = tuple(a_shape[i] for i in a_unique_axes)
         b_unique_axes_shape = tuple(b_shape[i] for i in b_unique_axes)
+
+        a_unique_axes_total_dimension = prod(a_unique_axes_shape)
         b_unique_axes_total_dimension = prod(b_unique_axes_shape)
 
         reduce_axes_shape = tuple(a_shape[i] for i in a_reduce_axes)
@@ -955,72 +958,33 @@ def matmul(a: _TT, b: _TT, *, reduce: Union[Dim, Sequence[Dim]], use_mask: bool
 
         reduce_axes_total_dimension = prod(reduce_axes_shape)
 
-        # Check if conv1d makes sense, which is optimized for (Unique...,Reduce...,Unique...) * (Reduce...,Unique...) [^T]
-        if (
-            common_axes_total_dimension == 1
-            and len(a_unique_axes) >= 2
-            and len(a_reduce_axes) == 1
-            and min(a_unique_axes) < a_reduce_axes[0] < max(a_unique_axes)
-            and len(b_unique_axes) == 1
-            and len(b_reduce_axes) == 1
-        ):
-            a_unique_axes_low = [i for i in a_unique_axes if i < a_reduce_axes[0]]
-            a_unique_axes_high = [i for i in a_unique_axes if i > a_reduce_axes[0]]
-
-            a_unique_axes_low_shape = tuple(a_shape[i] for i in a_unique_axes_low)
-            a_unique_axes_high_shape = tuple(a_shape[i] for i in a_unique_axes_high)
-            a_unique_axes_low_total_dimension = prod(a_unique_axes_low_shape)
-            a_unique_axes_high_total_dimension = prod(a_unique_axes_high_shape)
-
-            a_raw = torch.permute(a_raw, a_common_axes + a_unique_axes_low + a_reduce_axes + a_unique_axes_high)
-            b_raw = torch.permute(b_raw, b_common_axes + b_unique_axes + b_reduce_axes)
-
-            # The expectation is that view should always be possible here.
-            a_raw = a_raw.view(
-                a_unique_axes_low_total_dimension, reduce_axes_total_dimension, a_unique_axes_high_total_dimension
-            )
-            b_raw = b_raw.view(b_unique_axes_total_dimension, reduce_axes_total_dimension, 1)
-
-            raw_result = torch.nn.functional.conv1d(a_raw, b_raw)
-            raw_result = raw_result.view(
-                common_axes_shape + a_unique_axes_low_shape + b_unique_axes_shape + a_unique_axes_high_shape
-            )
-
-            a_unique_dims_low = [a_dims[i] for i in a_unique_axes_low]
-            a_unique_dims_high = [a_dims[i] for i in a_unique_axes_high]
-            b_unique_dims = [b_dims[i] for i in b_unique_axes]
-            result_dims = common_dims + a_unique_dims_low + b_unique_dims + a_unique_dims_high
-
-        else:
-            a_unique_axes_shape = tuple(a_shape[i] for i in a_unique_axes)
-            a_unique_axes_total_dimension = prod(a_unique_axes_shape)
-            a_raw = torch.permute(a_raw, a_common_axes + a_unique_axes + a_reduce_axes)
-            b_raw = torch.permute(b_raw, b_common_axes + b_reduce_axes + b_unique_axes)
-
-            if common_axes_total_dimension == 1:  # standard matrix multiplication
-                a_raw = torch.reshape(a_raw, (a_unique_axes_total_dimension, reduce_axes_total_dimension))
-                b_raw = torch.reshape(b_raw, (reduce_axes_total_dimension, b_unique_axes_total_dimension))
-
-                raw_result = torch.mm(a_raw, b_raw)
-
-            else:  # batched matrix multiplication
-                a_raw = torch.reshape(
-                    a_raw, (common_axes_total_dimension, a_unique_axes_total_dimension, reduce_axes_total_dimension)
-                )
-                b_raw = torch.reshape(
-                    b_raw, (common_axes_total_dimension, reduce_axes_total_dimension, b_unique_axes_total_dimension)
-                )
-
-                raw_result = torch.bmm(a_raw, b_raw)
-
-            raw_result = raw_result.view(common_axes_shape + a_unique_axes_shape + b_unique_axes_shape)
-
-            a_unique_dims = [a_dims[i] for i in a_unique_axes]
-            b_unique_dims = [b_dims[i] for i in b_unique_axes]
-            result_dims = common_dims + a_unique_dims + b_unique_dims
+        a_raw = torch.permute(a_raw, a_common_axes + a_unique_axes + a_reduce_axes)
+        b_raw = torch.permute(b_raw, b_common_axes + b_reduce_axes + b_unique_axes)
+
+        if common_axes_total_dimension == 1:  # standard matrix multiplication
+            a_raw = torch.reshape(a_raw, (a_unique_axes_total_dimension, reduce_axes_total_dimension))
+            b_raw = torch.reshape(b_raw, (reduce_axes_total_dimension, b_unique_axes_total_dimension))
+
+            raw_result = torch.mm(a_raw, b_raw)
+
+        else:  # batched matrix multiplication
+            a_raw = torch.reshape(
+                a_raw, (common_axes_total_dimension, a_unique_axes_total_dimension, reduce_axes_total_dimension)
+            )
+            b_raw = torch.reshape(
+                b_raw, (common_axes_total_dimension, reduce_axes_total_dimension, b_unique_axes_total_dimension)
+            )
+
+            raw_result = torch.bmm(a_raw, b_raw)
+
+        raw_result = torch.reshape(raw_result, common_axes_shape + a_unique_axes_shape + b_unique_axes_shape)
+
+        a_unique_dims = [a_dims[i] for i in a_unique_axes]
+        b_unique_dims = [b_dims[i] for i in b_unique_axes]
+        result_dims = common_dims + a_unique_dims + b_unique_dims
 
         result_tensor = Tensor(name="dot", dims=result_dims, raw_tensor=raw_result, dtype=a.dtype)
 
         return result_tensor
 
     @staticmethod
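For reference, the trick in the reverted code: torch.nn.functional.conv1d with a kernel of size 1 contracts the in_channels axis, so an input laid out as (U_low, R, U_high) against a weight of shape (U_b, R, 1) yields (U_low, U_b, U_high), which is exactly the (Unique...,Reduce...,Unique...) * (Reduce...,Unique...) pattern named in the removed comment. A minimal self-contained check (sizes are arbitrary and the variable names are illustrative, not from the RETURNN source):

import torch

# Illustrative sizes for the pattern (U_low, R, U_high) * (U_b, R).
u_low, r, u_high, u_b = 3, 5, 7, 4

a = torch.randn(u_low, r, u_high)  # conv1d reads this as (batch, in_channels, length)
b = torch.randn(u_b, r, 1)         # conv1d reads this as (out_channels, in_channels, kernel=1)

# The reverted path: contraction over R via a kernel-size-1 convolution.
via_conv = torch.nn.functional.conv1d(a, b)  # -> (u_low, u_b, u_high)

# The same contraction written as an explicit einsum.
via_einsum = torch.einsum("nrl,or->nol", a, b.squeeze(-1))

assert torch.allclose(via_conv, via_einsum, atol=1e-6)
print(via_conv.shape)  # torch.Size([3, 4, 7])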

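The restored code path is the plain permute/reshape/matmul pattern: collapse the common, unique, and reduce axis groups, apply torch.mm (or torch.bmm when common axes exist), then split the unique axes back out. A sketch of the mm branch under assumed toy shapes (two unique axes on a, one on b, one reduce axis, no common axes):

import torch

a1, a2, r, b1 = 3, 4, 5, 6
a = torch.randn(a1, a2, r)
b = torch.randn(r, b1)

a2d = a.reshape(a1 * a2, r)    # collapse a's unique axes: (A_unique_total, R_total)
res = torch.mm(a2d, b)         # (A_unique_total, B_unique_total)
res = res.reshape(a1, a2, b1)  # split a's unique axes back out

assert torch.allclose(res, torch.einsum("xyr,rb->xyb", a, b), atol=1e-6)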

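The benchmark script mentioned in the commit message (demo-rf-pt-benchmark.py) is not shown on this page; the following is only a rough guess at how the two kernels could be timed against each other. All shapes and the timing harness are assumptions, not the author's setup:

import timeit
import torch

# Hypothetical shapes for the (U_low, R, U_high) * (R, U_b) pattern.
u_low, r, u_high, u_b = 32, 512, 128, 512
a = torch.randn(u_low, r, u_high)
w = torch.randn(u_b, r)

def via_conv1d():
    # The reverted path: kernel-size-1 convolution contracting R.
    return torch.nn.functional.conv1d(a, w.unsqueeze(-1))

def via_mm():
    # The restored path: permute so R is last, collapse to 2D, mm, split back.
    a2d = a.permute(0, 2, 1).reshape(u_low * u_high, r)
    out = torch.mm(a2d, w.t())
    return out.reshape(u_low, u_high, u_b).permute(0, 2, 1)

assert torch.allclose(via_conv1d(), via_mm(), atol=1e-3)
print("conv1d:", timeit.timeit(via_conv1d, number=100))
print("mm:    ", timeit.timeit(via_mm, number=100))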