
transformers support generation, trainer, tutorial, etc. #748

Open
wants to merge 176 commits into base: master

Conversation

@zhanghuiyao (Collaborator) commented Nov 8, 2024

What does this PR do?

Adds the following features:

  1. generation
  2. trainer
  3. some tutorials, README, and docs
  4. Llama 3 8B inference/generation/finetuning

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Did you make sure to update the documentation with your changes? E.g. record bug fixes or new features in What's New. Here are the
    documentation guidelines
  • Did you build and run the code without any errors?
  • Did you report the running environment (NPU type / MindSpore version) and performance in the docs? (Best to record this for data loading, model inference, and training tasks.)
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@xxx

python finetune_in_native_mindspore.py \
    --model_path meta-llama/Meta-Llama-3-8B \
    --dataset_path Yelp/yelp_review_full
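The invocation above implies a small CLI; a minimal argparse sketch of that interface (the flag defaults are assumptions — the real internals of finetune_in_native_mindspore.py are not shown in this PR):

```python
import argparse

def build_parser():
    # Flags mirroring the command shown above; defaults are assumed,
    # not taken from the actual finetune_in_native_mindspore.py script.
    parser = argparse.ArgumentParser(prog="finetune_in_native_mindspore.py")
    parser.add_argument("--model_path", default="meta-llama/Meta-Llama-3-8B",
                        help="Hugging Face model id or local checkpoint path")
    parser.add_argument("--dataset_path", default="Yelp/yelp_review_full",
                        help="Hugging Face dataset id or local dataset path")
    return parser
```

Either flag can then be overridden on the command line, as in the example command.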
Collaborator:

delete

Collaborator Author:

ok

nn.AvgPool1d,
nn.AvgPool2d,
nn.AvgPool3d,
nn.CrossEntropyLoss,
Collaborator:

This hard-coded list of fp32 layers may not fit all models.
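One way to address this concern would be to let individual models extend the set of layer types kept in fp32 instead of sharing one hard-coded list; a framework-agnostic sketch (the function and parameter names are hypothetical, not from this PR, and the stand-in classes model nn.AvgPool1d/2d/3d and nn.CrossEntropyLoss):

```python
# Stand-in layer classes; in the real code these would be the
# MindSpore nn.AvgPool1d/2d/3d and nn.CrossEntropyLoss cells.
class AvgPool1d: ...
class AvgPool2d: ...
class AvgPool3d: ...
class CrossEntropyLoss: ...
class Dense: ...  # a layer assumed safe to run in fp16/bf16

# Default layer types to keep in fp32 under mixed precision.
DEFAULT_FP32_TYPES = (AvgPool1d, AvgPool2d, AvgPool3d, CrossEntropyLoss)

def keep_in_fp32(layer, extra_types=(), fp32_types=DEFAULT_FP32_TYPES):
    """Return True if `layer` should stay in fp32.

    `extra_types` lets a specific model extend the default list
    instead of relying on one hard-coded set for all models.
    """
    return isinstance(layer, tuple(fp32_types) + tuple(extra_types))
```

A model that needs, say, its normalization layers in fp32 could then pass those classes via `extra_types` rather than editing the shared list.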

return out


class FlashAttention2(nn.Cell):
Collaborator:

It may be worth noting that this is just a wrapper, not the real FlashAttention 2.

Collaborator Author:

I heard before that it really is FA2.

Collaborator:

What about using mint AdamW?

Collaborator Author:

We will switch to mint uniformly later.

Collaborator:

Is this necessary for PyNative mode?

Collaborator @SamitHuang commented Nov 15, 2024:

Can you list the supported features compared to torch? For example, beam_search for generation and 8-bit quantization for memory reduction are quite commonly used, but seem to be missing from this PR.

Collaborator Author:

Currently only the most basic sample method is supported for generation, provided for MLLMs to use; quantization is not supported at all.
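For reference, the basic "sample" decoding mentioned here reduces to softmax-then-multinomial at each step of an autoregressive loop; a self-contained sketch in plain Python (the function names and toy `step_fn` are illustrative, not the PR's actual API):

```python
import math
import random

def sample_next_token(logits, temperature=1.0, rng=None):
    """Sample one token id from raw logits (basic 'sample' decoding)."""
    rng = rng or random.Random()
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract max for numerical stability
    exps = [math.exp(l - m) for l in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Inverse-CDF multinomial draw.
    r = rng.random()
    cum = 0.0
    for i, p in enumerate(probs):
        cum += p
        if r < cum:
            return i
    return len(probs) - 1

def generate(step_fn, prompt_ids, max_new_tokens, eos_id=None,
             temperature=1.0, seed=0):
    """Autoregressive loop: call step_fn(ids) -> next-token logits, sample,
    append, and stop on eos_id or after max_new_tokens."""
    rng = random.Random(seed)
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        tok = sample_next_token(step_fn(ids), temperature, rng)
        ids.append(tok)
        if tok == eos_id:
            break
    return ids
```

Beam search, by contrast, would keep several candidate sequences alive per step and rank them by cumulative log-probability, which is the part this PR defers to a later version.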

Collaborator Author:

The full interface, including beam_search for generation, will be provided in the subsequent version 4.46.2.

Collaborator:

It would be better to state which MindSpore version and mode (graph/PyNative) were mainly tested. If both graph and PyNative mode are supported, do both guarantee good accuracy, or are they just runnable?

Collaborator Author:

The newly added features were validated on MindSpore 2.3.1, but some existing interfaces may not be supported; complete validation is required before a confirmed supported version can be announced publicly.

Collaborator Author:

For inference/generation it achieves good accuracy; training is currently only in a runnable state.

3 participants