Commit: rebase with main

wangyems committed Apr 12, 2024
2 parents 9e89942 + e295090 commit 0fbcaac
Showing 6 changed files with 172 additions and 86 deletions.
2 changes: 1 addition & 1 deletion VERSION_INFO

```diff
@@ -1 +1 @@
-0.2.0-dev
+0.2.0-dev
```
9 changes: 2 additions & 7 deletions src/models/logits.cpp

```diff
@@ -9,9 +9,6 @@ Logits::Logits(const Model& model, State& state)
       state_{state},
       shape_{static_cast<int64_t>(state_.params_->batch_size) * state_.params_->search.num_beams, state_.params_->sequence_length, state_.params_->vocab_size},
       type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} {
-  if (model_.device_type_ == DeviceType::CPU && type_ != Ort::TypeToTensorType<float>::type)
-    throw std::runtime_error("Model logits_type can only be float32 on CPU");
-
   auto logits_tensor = OrtValue::CreateTensor(*model.allocator_device_, shape_, type_);
   if (type_ == Ort::TypeToTensorType<float>::type)
     value32_ = std::move(logits_tensor);
@@ -32,11 +29,9 @@ Logits::Logits(const Model& model, State& state)
 RoamingArray<float> Logits::Get() {
   size_t element_count = shape_[0] * shape_[1] * shape_[2];
 
-#if USE_CUDA
   // Convert from float16 to float32 if necessary
-  if (model_.device_type_ == DeviceType::CUDA && type_ == Ort::TypeToTensorType<Ort::Float16_t>::type)
-    ConvertFp16ToFp32(*model_.allocator_device_, model_.cuda_stream_, *value16_, value32_);
-#endif
+  if (type_ == Ort::TypeToTensorType<Ort::Float16_t>::type)
+    ConvertFp16ToFp32(*model_.allocator_device_, *value16_, value32_, model_.device_type_, model_.cuda_stream_);
 
   // First iteration? Then copy the logits over to a {batch_beams, 1, vocab_size} tensor
   // We'll reuse this tensor for all future iterations
```
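For downstream code that used the old helper, the call-site migration is mechanical. Both forms are shown below, lifted from the hunks above; note that the old form only ever compiled under `USE_CUDA`.

```cpp
// Old call (CUDA-only, guarded by #if USE_CUDA at the call site):
ConvertFp16ToFp32(*model_.allocator_device_, model_.cuda_stream_, *value16_, value32_);

// New call (any device; the helper now dispatches on device_type internally):
ConvertFp16ToFp32(*model_.allocator_device_, *value16_, value32_,
                  model_.device_type_, model_.cuda_stream_);
```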
20 changes: 16 additions & 4 deletions src/models/model.cpp

```diff
@@ -366,8 +366,7 @@ std::shared_ptr<GeneratorParams> CreateGeneratorParams() {
   return std::make_shared<GeneratorParams>();
 }
 
-#if USE_CUDA
-void ConvertFp16ToFp32(OrtAllocator& allocator, cudaStream_t stream, OrtValue& in, std::unique_ptr<OrtValue>& p_out) {
+void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr<OrtValue>& p_out, DeviceType device_type, cudaStream_t stream) {
   auto shape_info = in.GetTensorTypeAndShapeInfo();
   auto shape = shape_info->GetShape();
   assert(shape_info->GetElementType() == Ort::TypeToTensorType<Ort::Float16_t>::type);
@@ -386,10 +385,23 @@ void ConvertFp16ToFp32(OrtAllocator& allocator, cudaStream_t stream, OrtValue& in, std::unique_ptr<OrtValue>& p_out) {
   auto* fp16 = in.GetTensorData<uint16_t>();
   auto* fp32 = p_out->GetTensorMutableData<float>();
 
-  cuda::LaunchFp16ToFp32(fp16, fp32, count, stream);
-}
+  switch (device_type) {
+    case DeviceType::CPU:
+      for (int i = 0; i < count; i++)
+        fp32[i] = Float16ToFloat32(fp16[i]);
+      break;
+
+#ifdef USE_CUDA
+    case DeviceType::CUDA:
+      cuda::LaunchFp16ToFp32(fp16, fp32, count, stream);
+      break;
+#endif
+
+    default:
+      throw std::runtime_error("ConvertFp16ToFp32 - Unsupported device type");
+  }
+}
 
 size_t GetOrtTypeSize(ONNXTensorElementDataType type) {
   switch (type) {
     case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
```
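The new CPU branch calls a scalar `Float16ToFloat32` helper that this diff does not show. For orientation, a standard IEEE 754 half-to-single widening looks roughly like the sketch below; this is hypothetical illustration code, not the repository's implementation, which may differ.

```cpp
#include <cstdint>
#include <cstring>

// Sketch of a scalar fp16 -> fp32 conversion (IEEE 754 binary16 to binary32).
// Shown only to illustrate what the CPU path computes per element.
float Float16ToFloat32(uint16_t v) {
  const uint32_t sign = static_cast<uint32_t>(v & 0x8000) << 16;
  uint32_t exponent = (v >> 10) & 0x1F;  // 5-bit exponent, bias 15
  uint32_t mantissa = v & 0x3FF;         // 10-bit fraction

  uint32_t bits;
  if (exponent == 0x1F) {                // Inf or NaN: widen the payload
    bits = sign | 0x7F800000 | (mantissa << 13);
  } else if (exponent != 0) {            // Normal: rebias 15 -> 127
    bits = sign | ((exponent + 112) << 23) | (mantissa << 13);
  } else if (mantissa != 0) {            // Subnormal half: renormalize
    exponent = 113;                      // 127 - 15 + 1
    while ((mantissa & 0x400) == 0) {
      mantissa <<= 1;
      exponent--;
    }
    bits = sign | (exponent << 23) | ((mantissa & 0x3FF) << 13);
  } else {                               // Plus or minus zero
    bits = sign;
  }

  float out;
  std::memcpy(&out, &bits, sizeof(out));  // bit-cast without aliasing UB
  return out;
}
```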
2 changes: 1 addition & 1 deletion src/models/model.h

```diff
@@ -7,7 +7,7 @@ namespace Generators {
 
 struct Tokenizer;
 
-void ConvertFp16ToFp32(OrtAllocator& allocator, cudaStream_t stream, OrtValue& in, std::unique_ptr<OrtValue>& p_out);
+void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr<OrtValue>& p_out, DeviceType device_type, cudaStream_t stream);
 
 struct State {
   State(const GeneratorParams& params);
```
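Two things stand out in the new declaration: it is no longer wrapped in `#if USE_CUDA`, and it still names `cudaStream_t`, which implies that non-CUDA builds see a stub definition of that type elsewhere in the headers. Under that assumption, a CPU-only caller might look like the sketch below, where `allocator` and `fp16_logits` are invented names for illustration.

```cpp
// Hypothetical CPU-side usage of the new signature; the stream argument
// is only consumed by the CUDA branch, so a null stream is passed here.
std::unique_ptr<OrtValue> fp32_logits;
ConvertFp16ToFp32(allocator, *fp16_logits, fp32_logits,
                  DeviceType::CPU, /*stream=*/nullptr);
```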
57 changes: 41 additions & 16 deletions src/python/py/models/README.md

````diff
@@ -3,18 +3,21 @@
 This folder contains the model builder for quickly creating optimized and quantized ONNX models within a few minutes that run with ONNX Runtime GenAI.
 
 # Contents
-- [Current Support](#current-support)
-- [Usage](#usage)
-  - [Full Usage](#full-usage)
-  - [Original PyTorch Model from Hugging Face](#original-pytorch-model-from-hugging-face)
-  - [Original PyTorch Model from Disk](#original-pytorch-model-from-disk)
-  - [Customized or Finetuned PyTorch Model](#customized-or-finetuned-pytorch-model)
-  - [GGUF Model](#gguf-model)
-  - [Extra Options](#extra-options)
-  - [Config Only](#config-only)
-- [Unit Testing Models](#unit-testing-models)
-  - [Option 1: Use the model builder tool directly](#option-1-use-the-model-builder-tool-directly)
-  - [Option 2: Edit the config.json file](#option-2-edit-the-configjson-file-on-disk-and-then-run-the-model-builder-tool)
+- [Current Support](#current-support)
+- [Usage](#usage)
+  - [Full Usage](#full-usage)
+  - [Original PyTorch Model from Hugging Face](#original-pytorch-model-from-hugging-face)
+  - [Original PyTorch Model from Disk](#original-pytorch-model-from-disk)
+  - [Customized or Finetuned PyTorch Model](#customized-or-finetuned-pytorch-model)
+  - [GGUF Model](#gguf-model)
+  - [Extra Options](#extra-options)
+    - [Config Only](#config-only)
+    - [Exclude Embedding Layer](#exclude-embedding-layer)
+    - [Exclude Language Modeling Head](#exclude-language-modeling-head)
+  - [Unit Testing Models](#unit-testing-models)
+    - [Option 1: Use the model builder directly](#option-1-use-the-model-builder-directly)
+    - [Option 2: Edit the config.json file](#option-2-edit-the-configjson-file-on-disk-and-then-run-the-model-builder)
+- [Design](#design)
 
 ## Current Support
 The tool currently supports the following model architectures.
@@ -89,7 +92,7 @@
 ```
 To see all available options through `--extra_options`, please use the `help` commands in the `Full Usage` section above.
 
-### Config Only
+#### Config Only
 This scenario is for when you already have your optimized and/or quantized ONNX model and you need to create the config files to run with ONNX Runtime GenAI.
 ```
 # From wheel:
@@ -101,6 +104,28 @@
 
 Afterwards, please open the `genai_config.json` file in the output folder and modify the fields as needed for your model. You should store your ONNX model in the output folder as well.
 
+#### Exclude Embedding Layer
+This scenario is for when you want to exclude the embedding layer from your ONNX model.
+
+```
+# From wheel:
+python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options exclude_embeds=true
+# From source:
+python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options exclude_embeds=true
+```
+
+#### Exclude Language Modeling Head
+This scenario is for when you want to exclude the language modeling head from your ONNX model.
+
+```
+# From wheel:
+python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options exclude_lm_head=true
+# From source:
+python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options exclude_lm_head=true
+```
+
 ### Unit Testing Models
 This scenario is where your PyTorch model is already downloaded locally (either in the default Hugging Face cache directory or in a local folder on disk). If it is not already downloaded locally, here is an example of how you can download it.
 
@@ -117,7 +142,7 @@
 tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
 tokenizer.save_pretrained(cache_dir)
 ```
 
-#### Option 1: Use the model builder tool directly
+#### Option 1: Use the model builder directly
 This option is the simplest but it will download another copy of the PyTorch model onto disk to accommodate the change in the number of hidden layers.
 ```
 # From wheel:
@@ -127,11 +152,11 @@
 python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p precision -e execution_provider --extra_options num_hidden_layers=4
 # From source:
 python3 builder.py -m model_name -o path_to_output_folder -p precision -e execution_provider --extra_options num_hidden_layers=4
 ```
 
-#### Option 2: Edit the config.json file on disk and then run the model builder tool
+#### Option 2: Edit the config.json file on disk and then run the model builder
 
 1. Navigate to where the PyTorch model and its associated files are saved on disk.
 2. Modify `num_hidden_layers` in `config.json` to your desired target (e.g. 4 layers).
-3. Run the below command for the model builder tool.
+3. Run the below command for the model builder.
 
 ```
 # From wheel:
````