add axiswise granularity to Float8Tensor #919

Merged: 9 commits merged into main on Oct 7, 2024
Conversation

@vkuzo (Contributor) commented Sep 23, 2024

Summary:

This is a copy-paste of pytorch-labs/float8_experimental#352
which never landed.

Test Plan:


Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
@vkuzo (Contributor, Author) commented Sep 23, 2024

pytorch-bot bot commented Sep 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/919

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1f01df9 with merge base 5dd0132:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 23, 2024
@vkuzo vkuzo requested a review from drisspg October 1, 2024 01:30
Review comment threads (outdated, resolved): test/float8/test_base.py, torchao/float8/float8_tensor.py
vkuzo added 3 commits October 2, 2024 08:29
@lw (Contributor) left a comment:

I skimmed through it, LGTM!

I have one suggestion for a more streamlined and extensible API; see below.

Comment on lines +148 to +149
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=axiswise_dim,
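
For context, the quoted call selects axiswise scaling: one float8 scale per slice along `axiswise_dim` instead of a single tensorwise scale. A minimal sketch of that idea (a hypothetical helper, not the PR's implementation):

```python
import torch

# Hypothetical sketch of axiswise (per-row) float8 scaling; not the PR's code.
def axiswise_scale(x: torch.Tensor, axiswise_dim: int) -> torch.Tensor:
    # One max-magnitude value per slice along axiswise_dim ...
    amax = x.abs().amax(dim=axiswise_dim, keepdim=True)
    # ... mapped so that amax lands on the float8_e4m3fn max (448.0).
    return 448.0 / amax.clamp(min=1e-12)

x = torch.randn(128, 256)
scale = axiswise_scale(x, axiswise_dim=-1)   # shape (128, 1): one scale per row
x_fp8 = (x * scale).to(torch.float8_e4m3fn)  # scaled cast; scale is kept for the matmul
```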
@lw (Contributor):

Suggestion for another API: instead of an enum + extra params on a case-by-case basis, we could reuse the same idea that @drisspg used in the _scaled_mm operator: deduce the kind of scaling based on the size/shape of the desired scale tensor!

Concretely, we could add a single scale_shape=... parameter, which for row-wise would be [-1, 1], indicating that:

  • all columns (second dim) should be grouped and reduced into a single scaling factor (because the second element has a value of 1)
  • but that for the rows (first dim) there should be as many scaling factors as there are rows (because the first element has a value of -1, which gets replaced with the corresponding dim size of the input tensor).

The scale shape is right-aligned to the shape of the tensor (thus following PyTorch's standard broadcast semantics), and then left-padded with 1 (again, standard semantics). This means that tensor-wise scaling is achieved with scale_shape=[].

Using this convention would later allow us to express block-wise scaling (e.g., 128x128), group-wise scaling (1x128), and maybe even column-wise scaling if that ever becomes a thing!
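
A minimal sketch of how such a scale_shape argument could be resolved against an input tensor's shape under the convention described above (the helper name and exact semantics are assumptions here, not an agreed API):

```python
import torch

# Hypothetical helper illustrating the proposed convention; not an agreed API.
def resolve_scale_shape(tensor: torch.Tensor, scale_shape: list[int]) -> list[int]:
    # Right-align scale_shape with the tensor's shape and left-pad with 1s,
    # following standard broadcasting semantics.
    padded = [1] * (tensor.dim() - len(scale_shape)) + list(scale_shape)
    # -1 means "one scaling factor per element along this dim";
    # 1 means "reduce this dim into a single scaling factor".
    return [tensor.shape[i] if s == -1 else s for i, s in enumerate(padded)]

x = torch.randn(256, 512)
print(resolve_scale_shape(x, [-1, 1]))  # row-wise: [256, 1]
print(resolve_scale_shape(x, []))       # tensor-wise: [1, 1]
```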

@vkuzo (Contributor, Author):

One wrinkle to work through would be that Float8Tensor can be of any rank, but operand inputs to torch._scaled_mm are required to be of rank 2, to match torch.mm|torch.addmm.

I'm definitely open to making this more flexible in the future. We've been careful to keep Float8Tensor and these utility functions out of the public API, to give us the freedom to make these kinds of changes as other scaling types become more important.
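
To illustrate the wrinkle, here is a sketch of how a higher-rank input is typically flattened to rank 2 for the matmul and reshaped back afterwards (plain torch.mm stands in for the fp8 scaled mm; the actual handling in torchao may differ):

```python
import torch

# Sketch of the rank-2 constraint: torch._scaled_mm, like torch.mm, takes
# 2-D operands, so higher-rank activations are flattened for the matmul.
x = torch.randn(4, 16, 32)            # e.g. (batch, seq, hidden), rank 3
w = torch.randn(64, 32)               # (out_features, in_features)
x_2d = x.reshape(-1, x.shape[-1])     # (64, 32): rank 2, as the mm requires
y_2d = torch.mm(x_2d, w.t())          # plain mm standing in for the fp8 scaled mm
y = y_2d.reshape(*x.shape[:-1], -1)   # restore the original leading dims: (4, 16, 64)
```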

@vkuzo (Contributor, Author):

also, if someone puts up a PR for ^, sgtm!

@@ -191,15 +188,6 @@ def __init__(self, *args, **kwargs):
# would be initialized in every iteration.
self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward

# See the comments in config.py for more details of this option.
@vkuzo (Contributor, Author):

Technically not related to this PR, but this makes the test logs non-spammy for now; we can add this back in a better way later.

@vkuzo vkuzo merged commit 52d27a1 into main Oct 7, 2024
43 checks passed
jainapurva pushed a commit that referenced this pull request Oct 15, 2024
Summary:

This is a copy-paste of pytorch-labs/float8_experimental#352
which never landed.

Test Plan:


Reviewers:

Subscribers:

Tasks:

Tags:
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
* Use ao's int4 quantizer

* Point AO to commit hash of Jerry's fix

* When device is cuda, only run for dtype==bfloat16

* Typo

* Use tensor subclass for int4 weight only quant

* Fix bug

* Fix

* Use both quantizer and subclass API

* Bug

* unwrap tensor subclass for aoti

* Add import

* Eval fix

* Evaluate AOTI

---------

Co-authored-by: Mengwei Liu <[email protected]>