Add embeddings output to model builder (#1127)
### Description

This PR adds support for outputting the last hidden state in addition to
the logits in ONNX models. Users can run their models with ONNX Runtime
GenAI and use the generator's `GetOutput` API to obtain the hidden
states.

C++:
```cpp
std::unique_ptr<OgaTensor> embeddings = generator->GetOutput("hidden_states");
```

C#:
```csharp
using var embeddings = generator.GetOutput("hidden_states");
```

Java:
```java
Tensor embeddings = generator.getOutput("hidden_states");
```

Python:
```python
embeddings = generator.get_output("hidden_states")
```
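
For context, here is a minimal end-to-end sketch in Python. It assumes a model exported with `include_hidden_states=true` and follows the ONNX Runtime GenAI Python API used in the project's examples; exact method names (e.g. `append_tokens`) may vary between releases:

```python
import onnxruntime_genai as og

# Load a model that was built with --extra_options include_hidden_states=true
model = og.Model("path_to_output_folder")
tokenizer = og.Tokenizer(model)

params = og.GeneratorParams(model)
params.set_search_options(max_length=32)

generator = og.Generator(model, params)
generator.append_tokens(tokenizer.encode("Hello, world!"))
generator.generate_next_token()

# numpy array; assumed shape (batch_size, sequence_length, hidden_size)
hidden_states = generator.get_output("hidden_states")
```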

### Motivation and Context

In SLMs and LLMs, the last hidden state represents a model's embeddings
for a particular input before the language modeling head is applied.
Generating embeddings from a model is a popular task. These embeddings
can be used for many scenarios such as text classification, sequence
labeling, information retrieval using [retrieval-augmented generation
(RAG)](https://en.wikipedia.org/wiki/Retrieval-augmented_generation),
and more.
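
For example, a common way to turn the per-token hidden states into a single embedding vector is mean pooling, which can then drive similarity search for RAG. This is only a sketch; the `(batch_size, sequence_length, hidden_size)` shape is an assumption about the `hidden_states` output:

```python
import numpy as np

def mean_pool(hidden_states: np.ndarray) -> np.ndarray:
    # (batch_size, sequence_length, hidden_size) -> (batch_size, hidden_size)
    return hidden_states.mean(axis=1)

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Compare two pooled embeddings, e.g. a query against a stored document.
    a, b = a.ravel(), b.ravel()
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
```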

This PR helps address the following issues:
- microsoft/onnxruntime#20969
- #442
- #474
- #713
kunal-vaishnavi authored Dec 11, 2024
1 parent 76f949d commit c61aaa6
Showing 5 changed files with 132 additions and 104 deletions.
61 changes: 36 additions & 25 deletions src/python/py/models/README.md
@@ -14,11 +14,12 @@ This folder contains the model builder for quickly creating optimized and quanti
- [GGUF Model](#gguf-model)
- [Extra Options](#extra-options)
- [Config Only](#config-only)
- [Hugging Face Authentication](#hugging-face-authentication)
- [Exclude Embedding Layer](#exclude-embedding-layer)
- [Exclude Language Modeling Head](#exclude-language-modeling-head)
- [Include Last Hidden States Output](#include-last-hidden-states-output)
- [Enable CUDA Graph](#enable-cuda-graph)
- [Use 8 Bits Quantization in QMoE](#use-8-bits-quantization-in-qmoe)
- [Use QDQ Pattern for Quantization](#use-qdq-pattern-for-quantization)
- [LoRA Models](#lora-models)
- [Unit Testing Models](#unit-testing-models)
@@ -30,12 +31,13 @@ This folder contains the model builder for quickly creating optimized and quanti

The tool currently supports the following model architectures.

- ChatGLM
- Gemma
- LLaMA
- Mistral
- Nemotron
- Phi
- Qwen

It is intended to support the latest, popular state-of-the-art models.

@@ -141,6 +143,18 @@ python3 builder.py -m model_name -o path_to_output_folder -p precision -e execut

Afterwards, please open the `genai_config.json` file in the output folder and modify the fields as needed for your model. You should store your ONNX model in the output folder as well.

#### Hugging Face Authentication

This scenario is for when you need to disable Hugging Face authentication (`hf_token=false`) or use a different authentication token (`hf_token=<user_token>`) than the one stored by [huggingface-cli login](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli#huggingface-cli-login).

```
# From wheel:
python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files --extra_options hf_token=false
# From source:
python3 builder.py -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files --extra_options hf_token=false
```
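
For reference, a token stored by `huggingface-cli login` can also be set programmatically through the `huggingface_hub` package; this is a hedged alternative to passing `hf_token` through `--extra_options`, shown with a placeholder token:

```python
from huggingface_hub import login

# Store a Hugging Face token for this environment so gated or private models can be downloaded.
login(token="hf_...")  # placeholder token
```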

#### Exclude Embedding Layer

This scenario is for when you want to exclude the embedding layer from your ONNX model.
@@ -165,65 +179,60 @@ python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o p
python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options exclude_lm_head=true
```

#### Include Last Hidden States Output

This scenario is for when you want to include the last hidden states as an output to your ONNX model.

```
# From wheel:
python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options include_hidden_states=true
# From source:
python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options include_hidden_states=true
```

Note that this is the same as outputting embeddings since the last hidden states are also known as the embeddings.
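
To confirm that the extra output made it into the exported graph, here is a small check using the `onnx` package; it assumes the builder's default `model.onnx` filename in the output folder:

```python
import onnx

model = onnx.load("path_to_output_folder/model.onnx")
print([output.name for output in model.graph.output])
# With include_hidden_states=true, the list should contain "hidden_states" alongside "logits".
```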

#### Enable CUDA Graph

This scenario is for when you want to enable CUDA graph for your ONNX model.

```
# From wheel:
python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options enable_cuda_graph=true
# From source:
python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options enable_cuda_graph=true
```

#### Use 8 Bits Quantization in QMoE

This scenario is for when you want to use 8-bit quantization for the MoE layers. The default is 4-bit quantization.

```
# From wheel:
python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options use_8bits_moe=true
# From source:
python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options use_8bits_moe=true
```

#### Use QDQ Pattern for Quantization

This scenario is for when you want to use the QDQ pattern (DequantizeLinear + MatMul) instead of the MatMulNBits operator when quantizing the model to 4 bits.

```
# From wheel:
python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options use_qdq=true
# From source:
python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options use_qdq=true
```
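
To see the effect of this option on the exported graph, a small inspection with the `onnx` package can count operator types; this sketch assumes the builder's default `model.onnx` filename:

```python
import onnx
from collections import Counter

model = onnx.load("path_to_output_folder/model.onnx")
op_counts = Counter(node.op_type for node in model.graph.node)

# With use_qdq=true, quantized weights appear as DequantizeLinear + MatMul pairs;
# without it, they appear as MatMulNBits nodes.
print({op: op_counts[op] for op in ("DequantizeLinear", "MatMul", "MatMulNBits")})
```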

#### LoRA Models

This scenario is where you have a fine-tuned model with LoRA adapters and your model can be loaded in the Hugging Face style via [PEFT](https://github.com/huggingface/peft).


```
# From wheel:
python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p fp16 -e execution_provider -c cache_dir_to_store_temp_files --extra_options adapter_path=path_to_adapter_files
@@ -232,6 +241,8 @@ python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o p
python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p fp16 -e execution_provider -c cache_dir_to_store_temp_files --extra_options adapter_path=path_to_adapter_files
```

Base weights should be located in `path_to_local_folder_on_disk` and adapter weights should be located in `path_to_adapter_files`.
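
As a quick sanity check before running the model builder, the adapter can be loaded in the Hugging Face style via PEFT. This is only a sketch with the placeholder paths from the commands above, not part of the builder itself:

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Load the base weights, then attach the LoRA adapter weights on top of them.
base_model = AutoModelForCausalLM.from_pretrained("path_to_local_folder_on_disk")
lora_model = PeftModel.from_pretrained(base_model, "path_to_adapter_files")
```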

### Unit Testing Models

This scenario is where your PyTorch model is already downloaded locally (either in the default Hugging Face cache directory or in a local folder on disk). If it is not already downloaded locally, here is an example of how you can download it.