
Add DDPWrapper #2479

Open
wants to merge 13 commits into base: juliagmt/test
Conversation

juliagmt-google (Collaborator)

Benchmark improvements: #2468

anijain2305 and others added 4 commits September 26, 2024 00:50
Summary:
This reverts commit 7743149b2be4a9eba7e0997ccdc6abe552bec266.

Reverts
* pytorch/pytorch#135503
* pytorch/pytorch#135502
* pytorch/pytorch#135422

This revert makes the following test pass. Previously, the getitem stayed as a getitem node in the FX graph, but now fake-tensor propagation fails, reporting that `.item()` is called. It seems the torch function override is not triggered during fake-tensor propagation.

```
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

torch.set_default_device('cuda')

flex_attention = torch.compile(flex_attention, dynamic=False)

# prefix_lengths[b] is the data-dependent getitem that fake-tensor
# propagation now rejects with a .item() error.
prefix_lengths = torch.arange(8)
def prefix_lm(b, h, q, kv):
    return prefix_lengths[b] >= kv

mask = create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True)
```

X-link: pytorch/pytorch#136590
Approved by: https://github.com/Chillee

Reviewed By: atalman

Differential Revision: D63431470

Pulled By: anijain2305

fbshipit-source-id: 60915b30336121b845af71f423582c22a6c65c3f
Summary: Add a new metric, `--metric nsys`, to collect an nsys trace.

Reviewed By: htyu

Differential Revision: D63274918

fbshipit-source-id: 0536310df6290ea5f5a02d85cc0ad6d342d45dbd
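
For a sense of what collecting an nsys trace involves, here is a rough, hypothetical sketch of shelling out to NVIDIA's `nsys profile` around a benchmark command; the `run_under_nsys` helper and the commented invocation are illustrative, not the actual TritonBench implementation.

```
import subprocess
import sys

def run_under_nsys(benchmark_cmd, report_path="nsys_trace"):
    """Hypothetical helper: re-run a benchmark command under Nsight Systems.

    Assumes `nsys` is on PATH; recent versions write <report_path>.nsys-rep.
    """
    cmd = [
        "nsys", "profile",
        "-o", report_path,              # output report basename
        "--force-overwrite", "true",    # replace an existing report
    ] + list(benchmark_cmd)
    result = subprocess.run(cmd, check=False)
    if result.returncode != 0:
        print(f"nsys exited with {result.returncode}", file=sys.stderr)
    return report_path + ".nsys-rep"

# Illustrative usage (command line is a placeholder, not the real entry point):
# run_under_nsys([sys.executable, "run_benchmark.py", "--op", "gemm"])
```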
Summary:
pytorch#2458

Pull Request resolved: pytorch#2459

Reviewed By: xuzhao9

Differential Revision: D63476542

Pulled By: kit1980

fbshipit-source-id: 01e9db9cb03d34e82a773897417df2ccda410634
Summary: Pull Request resolved: pytorch#2473

Reviewed By: xuzhao9

Differential Revision: D63543625

Pulled By: bertmaher

fbshipit-source-id: 1693e15875544bda0f5f6c69daa5597fffd80509
@juliagmt-google (Collaborator, Author) left a comment:

test

Summary: Pull Request resolved: pytorch#2475

Reviewed By: htyu

Differential Revision: D63653081

Pulled By: xuzhao9

fbshipit-source-id: 8d840986779b6124cbccc2425c24e2b892d55ce4
Summary: We had the imports wrong for the internal port.

Reviewed By: xuzhao9, adamomainz

Differential Revision: D63643617

fbshipit-source-id: 04a49d419fede71d2681dedbfb55112a67cb4d55
Summary:
We have an old Triton internally that doesn't have the cublasLt bindings.

Reviewed By: adamomainz

Differential Revision: D63643619

fbshipit-source-id: 39aece74b52f7747fe2100d7bb905bad49ba1fa0
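
The fix itself just avoids the missing bindings; as a hedged illustration of the general pattern, the sketch below feature-detects an optional binding at import time so a backend can be skipped instead of crashing on older Triton builds. The module and attribute names in the commented usage are hypothetical.

```
import importlib

def has_optional_binding(module_name: str, attr: str) -> bool:
    """Return True if `module_name` imports cleanly and exposes `attr`.

    Illustrative guard only; the actual change simply stops relying on the
    cublasLt bindings when the installed Triton does not provide them.
    """
    try:
        mod = importlib.import_module(module_name)
    except ImportError:
        return False
    return hasattr(mod, attr)

# Hypothetical usage: skip a backend when the binding is absent.
# HAS_CUBLASLT = has_optional_binding("triton.ops", "cublaslt_matmul")
```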
Summary:
X-link: facebookresearch/FBGEMM#301

X-link: pytorch/FBGEMM#3202

Printing warnings to stdout mucks up the output of various tools/benchmarks

Reviewed By: xuzhao9, htyu

Differential Revision: D63643615

fbshipit-source-id: 1f34508a7fd36f5aa421e11bddd5ce77fc13038a
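
For context, benchmark harnesses typically parse stdout, so diagnostics belong on stderr. A minimal Python sketch of that pattern (not the actual FBGEMM change):

```
import sys
import warnings

def report_fallback(msg: str) -> None:
    # print(msg) would interleave with benchmark results on stdout;
    # warnings.warn goes to stderr by default, keeping stdout machine-parsable.
    warnings.warn(msg, stacklevel=2)
    # Equivalent low-tech alternative:
    # print(msg, file=sys.stderr)

report_fallback("falling back to the non-Cutlass kernel")  # illustrative message
```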
Summary: FBGEMM has changed how it declares its Cutlass-based blockwise gemm.

Reviewed By: htyu, sijiac, adamomainz

Differential Revision: D63643618

fbshipit-source-id: e46e3bbd2e07be0653f7c7fa6bd080b6c8db171e
Summary:
We have a big list of interesting shapes for blockwise/rowwise scaled gemm. A lot of these are variants of Llama. We might want to use them for gemm and fp8_gemm (unscaled) as well, but for now they are only used for blockwise/rowwise.

Reviewed By: xuzhao9, adamomainz

Differential Revision: D63643616

fbshipit-source-id: 328961fe8c91e66428fcd1e5b72c89813f58a5a3
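
To make the idea concrete, such a shape table is typically just a mapping from a name to an (M, N, K) triple. The entries below are placeholders for illustration only, not the actual list added in this commit:

```
# Illustrative only: the real shape list lives in the operator's config and
# includes Llama-derived sizes; these names and triples are placeholders.
LLAMA_LIKE_SHAPES = {
    "example_qkv_proj": (4096, 4096, 4096),
    "example_mlp_up":   (4096, 11008, 4096),
    "example_mlp_down": (4096, 4096, 11008),
}

def iter_shapes():
    """Yield (name, M, N, K) for each registered shape."""
    for name, (m, n, k) in LLAMA_LIKE_SHAPES.items():
        yield name, m, n, k
```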
Summary:
We were only benchmarking `row-major x row-major` gemms (also called
`TT` or `transpose-transpose`, because BLAS inherits FORTRAN's column-major
convention), which is not the common case: `nn.Linear` uses column-major
layouts for weights, which means `TN` is actually much more common.

Reviewed By: adamomainz

Differential Revision: D63714661

fbshipit-source-id: 735c25c59ddeb6596afd9b19f463af92036a830b
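
The layout point is easy to check in PyTorch: `nn.Linear` stores its weight as `(out_features, in_features)` and computes `x @ W.T`, so the weight operand reaches the GEMM transposed relative to the row-major activations, i.e. a `TN` problem. A minimal sketch:

```
import torch
import torch.nn as nn

lin = nn.Linear(in_features=512, out_features=1024, bias=False)
x = torch.randn(64, 512)

# Weight is stored (out_features, in_features), so the forward pass is x @ W.T:
# the second GEMM operand is transposed relative to the row-major input.
print(lin.weight.shape)            # torch.Size([1024, 512])
y1 = lin(x)
y2 = x @ lin.weight.T
print(torch.allclose(y1, y2))      # True
```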