Skip to content

Commit

Permalink
move loop ordering after fusion (#126254)
Browse files Browse the repository at this point in the history
Summary:
Restart the work from PR pytorch/pytorch#100331 in this new PR since it's hard to rebase. It would be expected that some code is copy/pasted from the previous PR and main idea is the same.

Previously we see relatively large compilation time increase due to too many loop orders being considered. This PR tries to continue the work by doing pruning and only considering loop orders that we know for sure are relevant (i.e. do it on demand).

Some manually created cases that loop ordering matters are added as unit tests. The PR can make sure inductor does not miss fusion opportunities for them.

This PR should solve the not-able to fusion problem in pytorch/pytorch#130015

Right now there is still significant increase of compilation time. I'll disable the feature by default. Later on after the compilation time issue is resolved, I'll enable it  by default.

X-link: pytorch/pytorch#126254
Approved by: https://github.com/jansel

Reviewed By: ZainRizvi

Differential Revision: D62008970

Pulled By: shunting314

fbshipit-source-id: ce4c7c7003b93a2faccd2c65d78eeee0300b6bff
  • Loading branch information
shunting314 authored and facebook-github-bot committed Aug 30, 2024
1 parent 1f4461f commit 49bcc38
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1728,6 +1728,26 @@ def to_tensor(t):

# Check error from fp64 version
if fp64_ref.dtype == torch.float64:
# Fix a corner case that res and fp64_ref does not contains NaN and match (with loose tolerance)
# while the ref contains NaN. In this case, RMSE should not match any ways.
# But res is 'BETTER' than ref so we count it pass.
#
# This happens for Super_SloMo when loop ordering after fusion is enabled:
# https://gist.github.com/shunting314/11f235c70f7db0d52718d26f4a701cab
loose_tol = 1e-2 * 4
if (
not fp64_ref.isnan().any()
and not res.isnan().any()
and ref.isnan().any()
and torch.allclose(
fp64_ref.to(dtype=res.dtype),
res,
atol=loose_tol,
rtol=loose_tol,
equal_nan=equal_nan,
)
):
return True
ref_error = rmse(fp64_ref, ref).item()
# ref unable to produce this with stable numerics in this precision, ignore
if math.isnan(ref_error):
Expand Down

0 comments on commit 49bcc38

Please sign in to comment.