transformers support generation, trainer, tutorial, etc. #748

Open
wants to merge 176 commits into
base: master
3aaf623
add transformers generate
zhanghuiyao Oct 25, 2024
c3d0d0c
fix tensor_scatter_elements input data dtype
zhanghuiyao Oct 25, 2024
8a86468
1
zhanghuiyao Oct 25, 2024
f6262b5
support flash attention
zhanghuiyao Oct 25, 2024
6ea4b64
debug
zhanghuiyao Oct 25, 2024
dbb8129
fix fa mask
zhanghuiyao Oct 25, 2024
1edd1b3
1
zhanghuiyao Oct 25, 2024
121fe6f
add loading checkpoint
zhanghuiyao Oct 26, 2024
7783d5a
debug
zhanghuiyao Oct 26, 2024
8817181
debug
zhanghuiyao Oct 26, 2024
98ded28
1
zhanghuiyao Oct 26, 2024
cccf2bd
fix PreTrainedModel loaded_keys
zhanghuiyao Oct 26, 2024
f3bdbed
1
zhanghuiyao Oct 26, 2024
45c1bc6
add trainer support
zhanghuiyao Oct 28, 2024
8ddac5b
debug
zhanghuiyao Oct 28, 2024
32d6305
fix torch demo
zhanghuiyao Oct 28, 2024
91ae1d3
debug
zhanghuiyao Oct 28, 2024
8003840
debug
zhanghuiyao Oct 28, 2024
013ce28
debug
zhanghuiyao Oct 28, 2024
12cca4a
debug
zhanghuiyao Oct 28, 2024
a63fedb
debug
zhanghuiyao Oct 28, 2024
56a6886
debug
zhanghuiyao Oct 28, 2024
602676b
debug
zhanghuiyao Oct 28, 2024
0896b00
debug
zhanghuiyao Oct 28, 2024
65ec177
fix generate length
zhanghuiyao Oct 29, 2024
5a5d759
debug
zhanghuiyao Oct 29, 2024
217de5d
debug
zhanghuiyao Oct 29, 2024
e1b2d1d
debug
zhanghuiyao Oct 29, 2024
a54bd30
debug
zhanghuiyao Oct 29, 2024
99bccb8
debug
zhanghuiyao Oct 29, 2024
98438b7
debug
zhanghuiyao Oct 29, 2024
41dac1f
debug
zhanghuiyao Oct 29, 2024
717a59a
debug
zhanghuiyao Oct 29, 2024
d13d775
debug
zhanghuiyao Oct 29, 2024
20af81a
debug
zhanghuiyao Oct 29, 2024
ca79f6e
debug
zhanghuiyao Oct 30, 2024
36fe748
debug
zhanghuiyao Oct 30, 2024
9855824
debug
zhanghuiyao Oct 30, 2024
8e93e07
debug
zhanghuiyao Oct 30, 2024
03cb2a9
debug
zhanghuiyao Oct 30, 2024
80f7883
debug
zhanghuiyao Oct 30, 2024
d745e3f
debug
zhanghuiyao Oct 30, 2024
a301d6d
fix generate nan bug with graph+kbk on mindspore 2.3.1
zhanghuiyao Oct 30, 2024
e13be43
add some comment
zhanghuiyao Oct 30, 2024
75324db
add finetune
zhanghuiyao Oct 31, 2024
4be0c0b
fix data 1
zhanghuiyao Oct 31, 2024
714c7ba
fix data 1
zhanghuiyao Oct 31, 2024
763fcfc
debug
zhanghuiyao Oct 31, 2024
73215c1
fix args
zhanghuiyao Oct 31, 2024
8e93c73
fix args
zhanghuiyao Oct 31, 2024
0eec5b8
fix args
zhanghuiyao Oct 31, 2024
d9dea1f
fix args
zhanghuiyao Oct 31, 2024
95db3ca
fix args
zhanghuiyao Oct 31, 2024
e59228b
fix args
zhanghuiyao Oct 31, 2024
75275c8
fix llama model
zhanghuiyao Oct 31, 2024
bbe305f
fix evaluate.load(accuracy)
zhanghuiyao Oct 31, 2024
c0acbdd
fix args dataloader workers
zhanghuiyao Oct 31, 2024
9e4bdfd
support to create optimizer and lr_scheduler
zhanghuiyao Nov 4, 2024
5cf22c4
debug
zhanghuiyao Nov 4, 2024
027ae1e
debug
zhanghuiyao Nov 4, 2024
9b559d6
debug
zhanghuiyao Nov 4, 2024
0cf9256
debug
zhanghuiyao Nov 4, 2024
0bf4785
debug
zhanghuiyao Nov 4, 2024
5dfe239
debug
zhanghuiyao Nov 4, 2024
f69247b
debug
zhanghuiyao Nov 4, 2024
5a51a42
debug
zhanghuiyao Nov 4, 2024
f5d7628
debug
zhanghuiyao Nov 4, 2024
770b1b4
debug
zhanghuiyao Nov 4, 2024
4fd3342
debug
zhanghuiyao Nov 4, 2024
ad0bff1
debug
zhanghuiyao Nov 4, 2024
7f2dd1d
debug
zhanghuiyao Nov 4, 2024
4afa0a1
debug
zhanghuiyao Nov 4, 2024
8b9aab3
debug
zhanghuiyao Nov 4, 2024
d96aee3
debug
zhanghuiyao Nov 4, 2024
5a40fb6
debug
zhanghuiyao Nov 4, 2024
1d062cd
debug
zhanghuiyao Nov 4, 2024
0e9fa4b
support training_step
zhanghuiyao Nov 4, 2024
02cddd7
fix llama model training use cache
zhanghuiyao Nov 4, 2024
1f6b8be
fix train model wrapper
zhanghuiyao Nov 4, 2024
7689a09
support dict inputs
zhanghuiyao Nov 5, 2024
b76c406
delete numpy.float
zhanghuiyao Nov 5, 2024
71328c0
debug
zhanghuiyao Nov 5, 2024
00ee7ea
modify default
zhanghuiyao Nov 5, 2024
1b244e4
debug
zhanghuiyao Nov 5, 2024
3ec92fa
debug
zhanghuiyao Nov 5, 2024
3c91d16
debug
zhanghuiyao Nov 5, 2024
99e3e8e
debug
zhanghuiyao Nov 5, 2024
1717552
debug
zhanghuiyao Nov 5, 2024
0dd491e
debug
zhanghuiyao Nov 5, 2024
bf209f8
debug
zhanghuiyao Nov 5, 2024
bffbfc6
debug
zhanghuiyao Nov 5, 2024
849b5cd
debug
zhanghuiyao Nov 5, 2024
35da137
fix train wrap
zhanghuiyao Nov 5, 2024
c35c67a
fix prepare_inputs dtype
zhanghuiyao Nov 5, 2024
c09b699
support save_checkpoint
zhanghuiyao Nov 5, 2024
475dad6
fix save_checkpoint process
zhanghuiyao Nov 5, 2024
0fd6527
debug
zhanghuiyao Nov 5, 2024
b8f7001
debug
zhanghuiyao Nov 5, 2024
0b3635e
debug
zhanghuiyao Nov 5, 2024
748967a
fix recompute
zhanghuiyao Nov 5, 2024
83996d8
debug
zhanghuiyao Nov 5, 2024
b960c32
debug
zhanghuiyao Nov 5, 2024
7117179
debug
zhanghuiyao Nov 5, 2024
be323dc
debug
zhanghuiyao Nov 5, 2024
7571ce6
debug
zhanghuiyao Nov 5, 2024
1408e02
debug
zhanghuiyao Nov 5, 2024
b97ac4b
support momentum optimizer
zhanghuiyao Nov 5, 2024
71780f0
debug
zhanghuiyao Nov 5, 2024
0df43aa
debug
zhanghuiyao Nov 5, 2024
fe9cf78
add finetune_in_native_mindspore.py
zhanghuiyao Nov 5, 2024
4b0bd16
fix finetune_in_native_mindspore.py
zhanghuiyao Nov 5, 2024
e0cc533
fix finetune_in_native_mindspore.py second
zhanghuiyao Nov 5, 2024
4e9444d
debug
zhanghuiyao Nov 5, 2024
3a006f4
debug
zhanghuiyao Nov 5, 2024
147e5a8
debug
zhanghuiyao Nov 5, 2024
0a562e2
debug
zhanghuiyao Nov 5, 2024
378fca3
debug
zhanghuiyao Nov 5, 2024
be83fb4
Add bert training
zhanghuiyao Nov 6, 2024
88e6db9
fix bert training 1
zhanghuiyao Nov 6, 2024
1c9ca99
fix bert training 2
zhanghuiyao Nov 6, 2024
484e721
fix bert training 2
zhanghuiyao Nov 6, 2024
742bb5d
debug
zhanghuiyao Nov 6, 2024
5d157a3
debug
zhanghuiyao Nov 6, 2024
5548196
debug
zhanghuiyao Nov 6, 2024
15c7ac1
fix bert training 3
zhanghuiyao Nov 6, 2024
19cd28b
add modeling_llama init_weight
zhanghuiyao Nov 6, 2024
b1f8bbf
decommented loading checkpoint
zhanghuiyao Nov 6, 2024
e774a78
1. add llama distribute train; 2. fix some comment
zhanghuiyao Nov 6, 2024
fee1e68
add distribute scripts
zhanghuiyao Nov 6, 2024
2f15523
fix zero_stage args
zhanghuiyao Nov 6, 2024
7d041b9
fix llama train scripts args
zhanghuiyao Nov 6, 2024
5840bbd
update run scripts
zhanghuiyao Nov 6, 2024
a2536ac
debug
zhanghuiyao Nov 6, 2024
1797800
fix training_args jit_mode
zhanghuiyao Nov 6, 2024
1938063
fix LlamaForSequenceClassification
zhanghuiyao Nov 6, 2024
f469a49
fix training args mode
zhanghuiyao Nov 6, 2024
58cd5e9
debug
zhanghuiyao Nov 6, 2024
5bb7ffa
debug
zhanghuiyao Nov 6, 2024
54957c5
delete llama model output_attention
zhanghuiyao Nov 6, 2024
df68174
debug
zhanghuiyao Nov 6, 2024
21fc9fd
debug
zhanghuiyao Nov 6, 2024
fe5abd6
debug
zhanghuiyao Nov 6, 2024
076ce5e
debug
zhanghuiyao Nov 6, 2024
b8ae647
debug
zhanghuiyao Nov 6, 2024
099faa8
debug
zhanghuiyao Nov 6, 2024
a6ebd1c
debug
zhanghuiyao Nov 6, 2024
3dbd218
debug
zhanghuiyao Nov 6, 2024
9d7f5ca
debug
zhanghuiyao Nov 6, 2024
0951eee
fix llama return_dict bug
zhanghuiyao Nov 6, 2024
3ae74fa
fix adamw_zero in single cards
zhanghuiyao Nov 6, 2024
dc8388f
debug
zhanghuiyao Nov 6, 2024
251dca8
decomment loading checkpoint
zhanghuiyao Nov 6, 2024
667c9cb
enable bf16
zhanghuiyao Nov 6, 2024
8dd0ccb
debug
zhanghuiyao Nov 7, 2024
98b4658
fix trainer amp fp16
zhanghuiyao Nov 7, 2024
20c7033
support distribute llama_ft_in_native_mindspore
zhanghuiyao Nov 7, 2024
13967cf
set bs to 1
zhanghuiyao Nov 7, 2024
b35012b
set mindspore_dtype with args.fp16/bf16
zhanghuiyao Nov 7, 2024
5704771
fix adamw_zero bf16
zhanghuiyao Nov 7, 2024
2d24180
add lazy_inline for llama
zhanghuiyao Nov 7, 2024
e59bd56
set bs to 8
zhanghuiyao Nov 7, 2024
8020af6
decomment loading checkpoint
zhanghuiyao Nov 7, 2024
e1329f5
add native train script
zhanghuiyao Nov 7, 2024
b7318ca
fix args
zhanghuiyao Nov 7, 2024
6021035
update readme
zhanghuiyao Nov 7, 2024
77b6d56
update readme
zhanghuiyao Nov 7, 2024
229bf43
update docs
zhanghuiyao Nov 8, 2024
fa3dce0
delete hf_configs
zhanghuiyao Nov 8, 2024
ddaaedf
Merge branch 'master' into _transformers_pr
zhanghuiyao Nov 8, 2024
ad76449
update readme
zhanghuiyao Nov 8, 2024
5265311
update readme
zhanghuiyao Nov 8, 2024
06bf39e
update readme
zhanghuiyao Nov 8, 2024
5589753
fix clip grad norm on zero
zhanghuiyao Nov 14, 2024
28f0293
fix pre-commit format
zhanghuiyao Nov 14, 2024
8048e81
delete comment
zhanghuiyao Nov 15, 2024
328217e
modify amp
zhanghuiyao Nov 15, 2024
8 changes: 8 additions & 0 deletions docs/transformers/_toctree.yml
@@ -2,3 +2,11 @@
  - local: index
    title: 🤗 Transformers
  title: Get started
- sections:
  - local: tutorials/finetune
    title: Fine-tune a pretrained model
  - local: tutorials/finetune_distribute
    title: Distributed training and mixed precision
  - local: tutorials/generation
    title: Generation with LLMs
  title: Tutorials
243 changes: 243 additions & 0 deletions docs/transformers/tutorials/finetune.md
@@ -0,0 +1,243 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Fine-tune a pretrained model

Using a pretrained model has significant benefits: it reduces computation costs and your carbon footprint, and it lets you use state-of-the-art models without training one from scratch. 🤗 Transformers provides access to thousands of pretrained models for a wide range of tasks. When you use a pretrained model, you train it on a dataset specific to your task. This is known as fine-tuning, an incredibly powerful training technique. In this tutorial, you will fine-tune a pretrained model in one of two ways:

- Fine-tune a pretrained model with 🤗 Transformers Trainer.
- Fine-tune a pretrained model in native MindSpore.

## Prepare a dataset

Before you can fine-tune a pretrained model, download a dataset and prepare it for training. The previous tutorial showed you how to process data for training, and now you get an opportunity to put those skills to the test!

Begin by loading the Yelp Reviews dataset:

```pycon
>>> from datasets import load_dataset

>>> dataset = load_dataset("yelp_review_full")
>>> dataset["train"][100]
{'label': 0,
'text': 'My expectations for McDonalds are t rarely high. But for one to still fail so spectacularly...that takes something special!\\nThe cashier took my friends\'s order, then promptly ignored me. I had to force myself in front of a cashier who opened his register to wait on the person BEHIND me. I waited over five minutes for a gigantic order that included precisely one kid\'s meal. After watching two people who ordered after me be handed their food, I asked where mine was. The manager started yelling at the cashiers for \\"serving off their orders\\" when they didn\'t have their food. But neither cashier was anywhere near those controls, and the manager was the one serving food to customers and clearing the boards.\\nThe manager was rude when giving me my order. She didn\'t make sure that I had everything ON MY RECEIPT, and never even had the decency to apologize that I felt I was getting poor service.\\nI\'ve eaten at various McDonalds restaurants for over 30 years. I\'ve worked at more than one location. I expect bad days, bad moods, and the occasional mistake. But I have yet to have a decent experience at this store. It will remain a place I avoid unless someone in my party needs to avoid illness from low blood sugar. Perhaps I should go back to the racially biased service of Steak n Shake instead!'}
```

As you now know, you need a tokenizer to process the text, along with a padding and truncation strategy to handle variable sequence lengths. To process your dataset in one step, use the 🤗 Datasets `map` method to apply a preprocessing function over the entire dataset:

```pycon
>>> from transformers import AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")


>>> def tokenize_function(examples):
...     return tokenizer(examples["text"], padding="max_length", truncation=True)


>>> tokenized_datasets = dataset.map(tokenize_function, batched=True)
```
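
To make the effect of `padding="max_length"` and `truncation=True` concrete, here is a minimal pure-Python sketch of the idea. The helper `pad_or_truncate` and the token ids are hypothetical and only illustrate the shapes involved; a real tokenizer also handles special tokens during truncation, which this sketch ignores.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Return (input_ids, attention_mask), both exactly max_length long."""
    kept = list(token_ids[:max_length])  # truncation: drop tokens past max_length
    n_pad = max_length - len(kept)       # padding: fill the remainder on the right
    input_ids = kept + [pad_id] * n_pad
    attention_mask = [1] * len(kept) + [0] * n_pad  # 1 = real token, 0 = padding
    return input_ids, attention_mask


print(pad_or_truncate([101, 7, 8, 102], max_length=6))
# ([101, 7, 8, 102, 0, 0], [1, 1, 1, 1, 0, 0])
print(pad_or_truncate([101, 7, 8, 9, 10, 102], max_length=4))
# ([101, 7, 8, 9], [1, 1, 1, 1])
```

Because every example ends up with the same length, batches can be stacked into rectangular arrays without further processing.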

If you like, you can create a smaller subset of the full dataset to fine-tune on to reduce the time it takes:

```pycon
>>> small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
>>> small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
```

## Train

At this point, you should follow the section corresponding to the framework you want to use. You can use the links in the right sidebar to jump to the one you want - and if you want to hide all of the content for a given framework, just use the button at the top-right of that framework’s block!

### Train with MindSpore Trainer

<details open>

!!! Note

    Taking BERT as an example, you can find the complete code in `examples/transformers/bert/finetune_with_mindspore_trainer.py`.

🤗 Transformers provides a Trainer class optimized for training 🤗 Transformers models, making it easier to start training without manually writing your own training loop. The Trainer API supports a wide range of training options and features such as logging, gradient accumulation, and mixed precision.

Start by loading your model and specify the number of expected labels. From the Yelp Review dataset card, you know there are five labels:

```pycon
>>> from mindone.transformers.models.bert import BertForSequenceClassification

>>> model = BertForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)
```

!!! Note

    You will see a warning about some of the pretrained weights not being used and some weights being randomly initialized. Don’t worry, this is completely normal! The pretrained head of the BERT model is discarded, and replaced with a randomly initialized classification head. You will fine-tune this new model head on your sequence classification task, transferring the knowledge of the pretrained model to it.

#### Training hyperparameters

Next, create a TrainingArguments class which contains all the hyperparameters you can tune as well as flags for activating different training options. For this tutorial you can start with the default training hyperparameters, but feel free to experiment with these to find your optimal settings.

Specify where to save the checkpoints from your training:

```pycon
>>> from mindone.transformers.training_args import TrainingArguments

>>> training_args = TrainingArguments(output_dir="test_trainer")
```

Optionally (but recommended), initialize the MindSpore environment:

```pycon
>>> import mindspore as ms
>>> from mindone.transformers.mindspore_adapter import MindSporeArguments, init_environment

>>> env_args = MindSporeArguments(mode=ms.GRAPH_MODE, device_target="Ascend")
>>> init_environment(env_args)
```

#### Trainer
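
The `Trainer` call in this section takes a `compute_metrics` function, which this tutorial does not define. A minimal sketch follows, assuming (as in upstream 🤗 Transformers) that the Trainer passes a `(logits, labels)` pair of NumPy arrays at evaluation time. Plain accuracy is an illustrative choice here, not necessarily the metric used by the example script.

```python
import numpy as np


def compute_metrics(eval_pred):
    # Assumption: eval_pred unpacks into (logits, labels) NumPy arrays.
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)  # most likely class per example
    accuracy = float((predictions == labels).mean())
    return {"accuracy": accuracy}


toy_logits = np.array([[0.1, 0.9], [0.8, 0.2]])
toy_labels = np.array([1, 1])
print(compute_metrics((toy_logits, toy_labels)))  # {'accuracy': 0.5}
```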

Create a Trainer object with your model, training arguments, training and test datasets, and evaluation function:

```pycon
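>>> # Import path assumed to mirror `mindone.transformers.training_args`.
>>> from mindone.transformers.trainer import Trainer

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     train_dataset=small_train_dataset,
...     eval_dataset=small_eval_dataset,
...     compute_metrics=compute_metrics,
... )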
```

Then fine-tune your model by calling train():

```pycon
>>> trainer.train()
```

</details>

### Train in native MindSpore

<details open>

!!! Note

    Taking BERT as an example, you can find the complete code in `examples/transformers/bert/finetune_in_native_mindspore.py`.

Trainer takes care of the training loop and allows you to fine-tune a model in a single line of code. If you prefer to write your own training loop, you can also fine-tune a 🤗 Transformers model in native MindSpore.

At this point, you may need to restart your notebook to free memory.

Next, manually postprocess `tokenized_datasets` to prepare it for training.

1. Remove the text column because the model does not accept raw text as an input:

```pycon
>>> tokenized_datasets = tokenized_datasets.remove_columns(["text"])
```

2. Rename the `label` column to `labels` because the model expects the argument to be named `labels`:

```pycon
>>> tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
```

#### DataLoader

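Create a MindSpore DataLoader for your training dataset so you can iterate over batches of data. If you removed and renamed columns above, re-create `small_train_dataset` from the post-processed `tokenized_datasets` first so that it carries the `labels` column: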

```pycon
>>> import numpy as np
>>> import mindspore as ms
>>> from mindone.transformers.mindspore_adapter import HF2MSDataset

>>> def ms_data_collator(features, batch_info):
...     # Stack each field of the per-sample dicts into one NumPy array per batch.
...     batch = {}
...     for k, v in features[0].items():
...         batch[k] = np.stack([f[k] for f in features]) if isinstance(v, np.ndarray) else np.array([f[k] for f in features])
...     return batch

>>> batch_size, num_epochs = 1, 3
>>> train_dataloader = ms.dataset.GeneratorDataset(HF2MSDataset(small_train_dataset), column_names="item")
>>> train_dataloader = train_dataloader.batch(batch_size=batch_size, per_batch_map=ms_data_collator)
>>> train_dataloader = train_dataloader.repeat(1)
>>> train_dataloader = train_dataloader.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)
```
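
To see what the collator produces, it can be exercised on a couple of hand-made features without MindSpore. This is only a sketch: the toy feature values are made up, and `batch_info` is given a default here purely so the function can be called directly.

```python
import numpy as np


def ms_data_collator(features, batch_info=None):
    # Same logic as the collator above: one stacked array per field.
    batch = {}
    for k, v in features[0].items():
        if isinstance(v, np.ndarray):
            batch[k] = np.stack([f[k] for f in features])
        else:
            batch[k] = np.array([f[k] for f in features])
    return batch


toy_features = [
    {"input_ids": [101, 7, 102], "attention_mask": [1, 1, 1], "labels": 4},
    {"input_ids": [101, 9, 102], "attention_mask": [1, 1, 1], "labels": 0},
]
batch = ms_data_collator(toy_features)
print({k: v.shape for k, v in batch.items()})
# {'input_ids': (2, 3), 'attention_mask': (2, 3), 'labels': (2,)}
```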

Load your model with the number of expected labels:

```pycon
>>> from mindone.transformers.models.bert import BertForSequenceClassification

>>> model = BertForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)
```

#### Optimizer

Create an optimizer to fine-tune the model. Let’s use the AdamWeightDecay optimizer from MindSpore:

```pycon
>>> from mindspore import nn

>>> optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=5e-6)
```

#### Train Network

Create a MindSpore training network:

```pycon
>>> from mindone.transformers.mindspore_adapter import TrainOneStepWrapper

>>> class ReturnLoss(nn.Cell):
...     def __init__(self, model):
...         super(ReturnLoss, self).__init__(auto_prefix=False)
...         self.model = model
...
...     def construct(self, *args, **kwargs):
...         outputs = self.model(*args, **kwargs)
...         loss = outputs[0]
...         return loss

>>> train_model = TrainOneStepWrapper(ReturnLoss(model), optimizer)
```

Great, now you are ready to train! 🥳

#### Training loop

To keep track of your training progress, use the tqdm library to add a progress bar over the number of training steps:

```pycon
>>> from tqdm.auto import tqdm

>>> num_training_steps = len(small_train_dataset) * num_epochs // batch_size
>>> progress_bar = tqdm(range(num_training_steps))

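>>> train_model.set_train()  # put the wrapped network in training mode
>>> for step, batch in enumerate(train_dataloader):
...     batch = batch["item"]
...
...     # Positional inputs; the None slots are assumed to follow the upstream BERT
...     # signature (token_type_ids, position_ids, head_mask, inputs_embeds).
...     tuple_inputs = (
...         ms.Tensor(batch["input_ids"], ms.int32),
...         ms.Tensor(batch["attention_mask"], ms.bool_),
...         None,
...         None,
...         None,
...         None,
...         ms.Tensor(batch["labels"], ms.int32),
...     )
...
...     loss, _, overflow = train_model(*tuple_inputs)
...
...     progress_bar.update(1)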
```

</details>
37 changes: 37 additions & 0 deletions docs/transformers/tutorials/finetune_distribute.md
@@ -0,0 +1,37 @@
# Distributed training with mixed precision and ZeRO parallelism

The Trainer supports distributed training, mixed precision, and ZeRO optimizer parallelism, which means you can also use it in a script. See `examples/transformers/llama/finetune_with_mindspore_trainer.py` for more detail.

To enable these features:

- Add the `is_distribute` argument to enable distributed training.
- Add the `fp16` or `bf16` argument to enable mixed precision.
- Add the `zero_stage` argument to enable optimizer parallelism with the ZeRO algorithm.
- Set the number of global/local NPUs to use with the `worker_num`/`local_worker_num` arguments.

```shell
msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir=outputs/parallel_logs \
    python finetune_with_mindspore_trainer.py \
    --model_path $local_path/meta-llama/Meta-Llama-3-8B \
    --dataset_path $local_path/yelp_review_full \
    --output_dir ./outputs \
    --bf16 \
    --zero_stage 2 \
    --is_distribute True
```
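
The `zero_stage` flag selects how much training state is sharded across devices. As a rough illustration, the sketch below uses the per-parameter accounting commonly quoted for mixed-precision Adam training (2 bytes for the half-precision weight, 2 for its gradient, 12 for fp32 optimizer state) and assumes stage 1 shards optimizer state, stage 2 additionally shards gradients, and stage 3 additionally shards parameters. These figures are assumptions for illustration, not measurements of this implementation; the function name is hypothetical, the 8e9 parameter count is a round number, and activations and temporary buffers are ignored.

```python
def zero_bytes_per_param(stage, num_devices, param_bytes=2.0, grad_bytes=2.0, optim_bytes=12.0):
    """Approximate per-device bytes of model state per parameter (illustrative only)."""
    if stage >= 1:
        optim_bytes /= num_devices  # stage 1: shard optimizer state
    if stage >= 2:
        grad_bytes /= num_devices   # stage 2: also shard gradients
    if stage >= 3:
        param_bytes /= num_devices  # stage 3: also shard parameters
    return param_bytes + grad_bytes + optim_bytes


num_params = 8e9  # assumed round figure for an 8B-parameter model
for stage in range(4):
    gb = zero_bytes_per_param(stage, num_devices=8) * num_params / 1e9
    print(f"stage {stage}: {gb:.1f} GB per device")
# stage 0: 128.0 GB per device
# stage 1: 44.0 GB per device
# stage 2: 30.0 GB per device
# stage 3: 16.0 GB per device
```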

For an equivalent example implemented in native MindSpore, see `examples/transformers/llama/finetune_in_native_mindspore.py`.

<details onclose>

```shell
msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir=outputs/parallel_logs \
    python finetune_in_native_mindspore.py \
    --model_path meta-llama/Meta-Llama-3-8B \
    --dataset_path Yelp/yelp_review_full \
    --bf16 \
    --zero_stage 2 \
    --is_distribute True

</details>
102 changes: 102 additions & 0 deletions docs/transformers/tutorials/generation.md
@@ -0,0 +1,102 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Generation with LLMs

LLMs, or Large Language Models, are the key component behind text generation. In a nutshell, they consist of large pretrained transformer models trained to predict the next word (or, more precisely, token) given some input text. Since they predict one token at a time, you need to do something more elaborate to generate new sentences other than just calling the model — you need to do autoregressive generation.

Autoregressive generation is the inference-time procedure of iteratively calling a model with its own generated outputs, given a few initial inputs. In 🤗 Transformers, this is handled by the generate() method, which is available to all models with generative capabilities.

This tutorial will show you how to:

- Generate text with an LLM

Before you begin, make sure you have all the necessary libraries installed:

```shell
pip install transformers==4.42.4
```

## Generate text

!!! Note

    Taking Llama as an example, you can find the complete code in `examples/transformers/llama/generate.py`.
    To compare the results with PyTorch, run the script `examples/transformers/llama/generate_pt.py`.

A language model trained for causal language modeling takes a sequence of text tokens as input and returns the probability distribution for the next token.

A critical aspect of autoregressive generation with LLMs is how to select the next token from this probability distribution. Anything goes in this step as long as you end up with a token for the next iteration. This means it can be as simple as selecting the most likely token from the probability distribution or as complex as applying a dozen transformations before sampling from the resulting distribution.
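
As a small illustration of the two ends of that spectrum, the sketch below contrasts greedy selection with temperature-scaled sampling on a hand-written list of logits. The function names and the logits are hypothetical; this is not how `generate()` is implemented, only the underlying idea.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Turn raw scores into probabilities; lower temperature sharpens the distribution."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def select_greedy(logits):
    """Pick the single most likely token id."""
    return max(range(len(logits)), key=lambda i: logits[i])


def select_sample(logits, temperature=1.0, rng=random):
    """Draw a token id at random, weighted by the softmax probabilities."""
    probs = softmax(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


toy_logits = [2.0, 1.0, 0.1]
print(select_greedy(toy_logits))  # 0
print(select_sample(toy_logits, temperature=0.7, rng=random.Random(0)))  # one of 0, 1, 2
```

Greedy selection is deterministic, which corresponds to `do_sample=False` in the examples later in this tutorial.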

The process described above is repeated iteratively until some stopping condition is reached. Ideally, the stopping condition is dictated by the model, which should learn when to output an end-of-sequence (EOS) token. If this is not the case, generation stops when some predefined maximum length is reached.
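
The loop and its two stopping conditions can be sketched with a toy stand-in for the model. Everything here is hypothetical: `toy_next_token_logits` simply favors the token id after the last one, and `EOS_ID` is an arbitrary choice, so the example only shows the control flow of autoregressive generation.

```python
EOS_ID = 0      # hypothetical end-of-sequence token id
VOCAB_SIZE = 5  # hypothetical vocabulary size


def toy_next_token_logits(token_ids):
    """Stand-in for a causal LM: favors (last_token + 1) % VOCAB_SIZE."""
    logits = [0.0] * VOCAB_SIZE
    logits[(token_ids[-1] + 1) % VOCAB_SIZE] = 1.0
    return logits


def greedy_generate(prompt_ids, max_new_tokens):
    token_ids = list(prompt_ids)
    for _ in range(max_new_tokens):  # stopping condition 2: maximum length
        logits = toy_next_token_logits(token_ids)
        next_id = max(range(VOCAB_SIZE), key=lambda i: logits[i])
        token_ids.append(next_id)    # feed the output back in as input
        if next_id == EOS_ID:        # stopping condition 1: model emits EOS
            break
    return token_ids


print(greedy_generate([2], max_new_tokens=10))  # [2, 3, 4, 0]
print(greedy_generate([1], max_new_tokens=2))   # [1, 2, 3]
```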

Properly setting up the token selection step and the stopping condition is essential to make your model behave as you’d expect on your task. That is why we have a GenerationConfig file associated with each model, which contains a good default generative parameterization and is loaded alongside your model.

Let’s talk code!

!!! Note

    If you’re interested in basic LLM usage, our high-level Pipeline interface is a great starting point. However, LLMs often require advanced features like quantization and fine control of the token selection step, which is best done through generate(). Autoregressive generation with LLMs is also resource-intensive and should be executed on an Ascend NPU for adequate throughput.

First, you need to load the model.

```pycon
>>> from mindone.transformers.models.llama import LlamaForCausalLM

>>> model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
```

There are other ways to initialize a model, but this is a good baseline to begin with an LLM.

Next, you need to preprocess your text input with a tokenizer.

```pycon
>>> import mindspore as ms
>>> from transformers import AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
>>> input_ids = ms.Tensor(tokenizer(["A list of colors: red, blue"]).input_ids, ms.int32)
```

The `input_ids` tensor holds the tokenized text input. The tokenizer output also contains the attention mask. While generate() does its best to infer the attention mask when it is not passed, we recommend passing it whenever possible for optimal results.

After tokenizing the inputs, you can call the generate() method to return the generated tokens. The generated tokens should then be converted to text before printing.

```pycon
>>> generated_ids = model.generate(
...     input_ids=input_ids,
...     max_new_tokens=30,
...     use_cache=True,
...     do_sample=False,
... )

>>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
```

Finally, you don’t need to do it one sequence at a time! You can batch your inputs, which will greatly improve the throughput at a small latency and memory cost. All you need to do is make sure you pad your inputs properly.

```pycon
>>> tokenizer.pad_token = tokenizer.eos_token # Most LLMs don't have a pad token by default
>>> input_ids = ms.Tensor(tokenizer(
...     ["A list of colors: red, blue", "Portugal is"], padding=True
... ).input_ids, ms.int32)

>>> generated_ids = model.generate(
...     input_ids=input_ids,
...     max_new_tokens=30,
...     use_cache=True,
...     do_sample=False,
... )

>>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
```
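
On what padding "properly" can mean: upstream 🤗 Transformers guidance is that decoder-only models should be left-padded for generation, so that each prompt ends right where new tokens are appended. Whether this tutorial's tokenizer is configured that way is not stated, so treat the following as a sketch of the idea only. The helper `left_pad` and the token ids are hypothetical.

```python
def left_pad(batch, pad_id):
    """Left-pad variable-length id lists to a common length and build the attention mask."""
    max_len = max(len(seq) for seq in batch)
    input_ids = [[pad_id] * (max_len - len(seq)) + list(seq) for seq in batch]
    attention_mask = [[0] * (max_len - len(seq)) + [1] * len(seq) for seq in batch]
    return input_ids, attention_mask


ids, mask = left_pad([[5, 6, 7], [8]], pad_id=0)
print(ids)   # [[5, 6, 7], [0, 0, 8]]
print(mask)  # [[1, 1, 1], [0, 0, 1]]
```

The mask marks padding positions with 0 so that the model can ignore them.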

And that’s it! In a few lines of code, you can harness the power of an LLM.