Convert whisper input_features from fp32 to fp16 if needed #200

Closed
wants to merge 12 commits
43 changes: 20 additions & 23 deletions src/generators.cpp
@@ -28,31 +28,28 @@ OrtEnv& GetOrtEnv() {
  return *GetOrtGlobals()->env_;
}

+// C++17 compatible version of bit_cast for the code below
+template <typename TTo, typename TFrom>
+TTo bit_cast(TFrom x) {
+  return *reinterpret_cast<TTo*>(&x);
+}
+
+// IEEE-754 16-bit floating-point format (without infinity): 1-5-10, exp-15, +-131008.0, +-6.1035156E-5, +-5.9604645E-8, 3.311 digits
+// IEEE 754-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction
-float Float16ToFloat32(uint16_t v) {
-  // Extract sign, exponent, and fraction from numpy.float16
-  int const sign = (v & 0x8000) >> 15;
-  int const exponent = (v & 0x7C00) >> 10;
-  int const fraction = v & 0x03FF;
-
-  // Handle special cases
-  if (exponent == 0) {
-    if (fraction == 0) {
-      // Zero
-      return sign != 0 ? -0.0f : 0.0f;
-    }  // Subnormal number
-    return std::ldexp((sign != 0 ? -1.0f : 1.0f) * static_cast<float>(fraction) / 1024.0f, -14);
-  }
-  if (exponent == 31) {
-    if (fraction == 0) {
-      // Infinity
-      return sign != 0 ? -std::numeric_limits<float>::infinity() : std::numeric_limits<float>::infinity();
-    }  // NaN
-    return std::numeric_limits<float>::quiet_NaN();
-  }
+float Float16ToFloat32(const uint16_t x) {
+  const uint32_t e = (x & 0x7C00) >> 10; // exponent
+  const uint32_t m = (x & 0x03FF) << 13; // mantissa
+
+  const uint32_t v = bit_cast<uint32_t>((float)m) >> 23; // log2 bit hack to count leading zeros in denormalized format
+  return bit_cast<float>((x & 0x8000) << 16 | (e != 0) * ((e + 112) << 23 | m) | ((e == 0) & (m != 0)) * ((v - 37) << 23 | ((m << (150 - v)) & 0x007FE000))); // sign : normalized : denormalized
+}

+uint16_t Float32ToFloat16(float v) {
+  const uint32_t b = bit_cast<uint32_t>(v) + 0x00001000; // round-to-nearest-even: add last bit after truncated mantissa

-  // Normalized number
-  return std::ldexp((sign != 0 ? -1.0f : 1.0f) * (1.0f + static_cast<float>(fraction) / 1024.0f), exponent - 15);
+  const uint32_t e = (b & 0x7F800000) >> 23; // exponent
+  const uint32_t m = b & 0x007FFFFF; // mantissa; in line below: 0x007FF000 = 0x00800000-0x00001000 = decimal indicator flag - initial rounding
+  return static_cast<uint16_t>((b & 0x80000000) >> 16 | (e > 112) * ((((e - 112) << 10) & 0x7C00) | m >> 13) | ((e < 113) & (e > 101)) * ((((0x007FF000 + m) >> (125 - e)) + 1) >> 1) | (e > 143) * 0x7FFF); // sign : normalized : denormalized : saturate
}

GeneratorParams::GeneratorParams(const Model& model)
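
Not part of the PR: a standalone sketch that sanity-checks the two scalar helpers above by round-tripping every half-precision bit pattern. The function bodies are copied from src/generators.cpp so the snippet compiles on its own; the test harness itself (main, the mismatch counter) is an assumption, not repo code.

// Round-trip check for the fp16/fp32 helpers; bodies copied from the PR above.
#include <cstdint>
#include <cstdio>

template <typename TTo, typename TFrom>
TTo bit_cast(TFrom x) {
  return *reinterpret_cast<TTo*>(&x);
}

float Float16ToFloat32(const uint16_t x) {
  const uint32_t e = (x & 0x7C00) >> 10;
  const uint32_t m = (x & 0x03FF) << 13;
  const uint32_t v = bit_cast<uint32_t>((float)m) >> 23;
  return bit_cast<float>((x & 0x8000) << 16 | (e != 0) * ((e + 112) << 23 | m) | ((e == 0) & (m != 0)) * ((v - 37) << 23 | ((m << (150 - v)) & 0x007FE000)));
}

uint16_t Float32ToFloat16(float f) {
  const uint32_t b = bit_cast<uint32_t>(f) + 0x00001000;
  const uint32_t e = (b & 0x7F800000) >> 23;
  const uint32_t m = b & 0x007FFFFF;
  return static_cast<uint16_t>((b & 0x80000000) >> 16 | (e > 112) * ((((e - 112) << 10) & 0x7C00) | m >> 13) | ((e < 113) & (e > 101)) * ((((0x007FF000 + m) >> (125 - e)) + 1) >> 1) | (e > 143) * 0x7FFF);
}

int main() {
  // Every finite half value should survive fp16 -> fp32 -> fp16 bit-exactly.
  // Inf/NaN encodings (exponent == 31) are skipped: this format variant trades
  // infinity handling for range, as the "(without infinity)" comment notes.
  int mismatches = 0;
  for (uint32_t bits = 0; bits < 0x10000; ++bits) {
    const uint16_t h = static_cast<uint16_t>(bits);
    if (((h & 0x7C00) >> 10) == 31) continue;
    if (Float32ToFloat16(Float16ToFloat32(h)) != h) ++mismatches;
  }
  printf("round-trip mismatches: %d\n", mismatches);
  return mismatches == 0 ? 0 : 1;
}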
2 changes: 2 additions & 0 deletions src/generators.h
@@ -144,6 +144,8 @@ std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorPa
std::vector<std::vector<int32_t>> Generate(const Model& model, const GeneratorParams& params); // Uses CreateGenerator and a simple loop to return the entire sequence

float Float16ToFloat32(uint16_t v); // v is a IEEE 754-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction
uint16_t Float32ToFloat16(float v); // Opposite direction of above

void top_k_indices(std::span<int32_t> top_k, std::span<const float> inputs);

}  // namespace Generators
2 changes: 1 addition & 1 deletion src/logging.cpp
@@ -75,7 +75,7 @@ void SGRExample(std::ostream& stream) {
}

bool RunExample = (SGRExample(std::cerr), false);
-#endif SGR_EXAMPLE
+#endif

std::ostream& Log(std::string_view label, std::string_view string) {
  assert(g_log.enabled);
12 changes: 12 additions & 0 deletions src/models/kernels.cu
@@ -94,6 +94,18 @@ void LaunchFp16ToFp32(const uint16_t* fp16, float* fp32, int count, cudaStream_t
  ConvertFp16ToFp32<<<num_blocks, block_size, 0, stream>>>(reinterpret_cast<const half*>(fp16), fp32, count);
}

__global__ void ConvertFp32ToFp16(const float* src, half* dst, int count) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < count)
    dst[idx] = __float2half(src[idx]);
}

void LaunchFp32ToFp16(const float* fp32, uint16_t* fp16, int count, cudaStream_t stream) {
  int block_size = 256;
  int num_blocks = (count + block_size - 1) / block_size;
  ConvertFp32ToFp16<<<num_blocks, block_size, 0, stream>>>(fp32, reinterpret_cast<half*>(fp16), count);
}

__global__ void ConvertInt32ToInt64(const int32_t* src, int64_t* dst, int count) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < count) {
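
For context, a host-side sketch of how the new launcher can be driven. This is not from the PR: the buffer size, the stream setup, and the assumption that the kernels live in a Generators::cuda namespace (as kernels.h below suggests) are all illustrative.

// Illustrative host code: upload fp32 data, convert to fp16 on-device, wait.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <vector>

namespace Generators {
namespace cuda {
void LaunchFp32ToFp16(const float* fp32, uint16_t* fp16, int count, cudaStream_t stream);
}
}

int main() {
  const int count = 80 * 3000;  // e.g. an 80-mel x 3000-frame Whisper input (illustrative size)
  std::vector<float> host(count, 0.5f);

  float* device_fp32 = nullptr;
  uint16_t* device_fp16 = nullptr;
  cudaMalloc(&device_fp32, count * sizeof(float));
  cudaMalloc(&device_fp16, count * sizeof(uint16_t));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Enqueue the copy and the conversion kernel on the same stream, then block.
  cudaMemcpyAsync(device_fp32, host.data(), count * sizeof(float), cudaMemcpyHostToDevice, stream);
  Generators::cuda::LaunchFp32ToFp16(device_fp32, device_fp16, count, stream);
  cudaStreamSynchronize(stream);
  printf("converted %d floats to fp16 on-device\n", count);

  cudaStreamDestroy(stream);
  cudaFree(device_fp32);
  cudaFree(device_fp16);
  return 0;
}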
1 change: 1 addition & 0 deletions src/models/kernels.h
@@ -14,6 +14,7 @@ void Launch_UpdateAttentionMask(T* mask_data, const T* old_mask_data, int batch_
void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream);

void LaunchFp16ToFp32(const uint16_t* fp16, float* fp32, int count, cudaStream_t stream);
void LaunchFp32ToFp16(const float* fp32, uint16_t* fp16, int count, cudaStream_t stream);
void LaunchInt32ToInt64(const int32_t* src, int64_t* dst, int count, cudaStream_t stream);
} // namespace cuda

38 changes: 38 additions & 0 deletions src/models/model.cpp
@@ -429,6 +429,44 @@ void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr<Or
  }
}

void ConvertFp32ToFp16(OrtAllocator& allocator, OrtValue& in, std::unique_ptr<OrtValue>& p_out, DeviceType device_type, cudaStream_t stream) {
  auto shape_info = in.GetTensorTypeAndShapeInfo();
  auto shape = shape_info->GetShape();
  assert(shape_info->GetElementType() == Ort::TypeToTensorType<float>::type);

  bool allocate_p_out = p_out == nullptr;
  if (p_out) {
    auto out_shape_info = p_out->GetTensorTypeAndShapeInfo();
    auto out_shape = out_shape_info->GetShape();
    allocate_p_out = shape != out_shape;
  }

  if (allocate_p_out)
    p_out = OrtValue::CreateTensor<Ort::Float16_t>(allocator, shape);

  int count = static_cast<int>(shape_info->GetElementCount());
  auto* fp32 = in.GetTensorData<float>();
  auto* fp16 = p_out->GetTensorMutableData<uint16_t>();

  switch (device_type) {
    case DeviceType::DML:
      // DML doesn't currently support on-device scoring, so we fall back to the CPU
    case DeviceType::CPU:
      for (int i = 0; i < count; i++)
        fp16[i] = Float32ToFloat16(fp32[i]);
      break;

#if USE_CUDA
    case DeviceType::CUDA:
      cuda::LaunchFp32ToFp16(fp32, fp16, count, stream);
      break;
#endif

    default:
      throw std::runtime_error("ConvertFp32ToFp16 - Unsupported device type");
  }
}

size_t GetOrtTypeSize(ONNXTensorElementDataType type) {
  switch (type) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
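
Worth noting about the function above: a null p_out (or one with a mismatched shape) makes it allocate the fp16 output itself. A minimal caller sketch under that contract; the helper name ToFp16OnCpu is hypothetical, and only calls visible elsewhere in this diff are used.

// Hypothetical helper showing the allocate-on-null contract of ConvertFp32ToFp16.
std::unique_ptr<OrtValue> ToFp16OnCpu(OrtAllocator& allocator, OrtValue& fp32_in) {
  std::unique_ptr<OrtValue> fp16_out;  // left null: ConvertFp32ToFp16 allocates it
  Generators::ConvertFp32ToFp16(allocator, fp32_in, fp16_out,
                                Generators::DeviceType::CPU, /*stream*/ nullptr);
  return fp16_out;
}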
1 change: 1 addition & 0 deletions src/models/model.h
@@ -15,6 +15,7 @@ namespace Generators {
struct Tokenizer;

void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr<OrtValue>& p_out, DeviceType device_type, cudaStream_t stream);
void ConvertFp32ToFp16(OrtAllocator& allocator, OrtValue& in, std::unique_ptr<OrtValue>& p_out, DeviceType device_type, cudaStream_t stream);

struct State {
  State(const GeneratorParams& params);
12 changes: 11 additions & 1 deletion src/models/whisper.cpp
@@ -5,10 +5,11 @@ namespace Generators {

Whisper_Model::Whisper_Model(std::unique_ptr<Config> config, OrtEnv& ort_env)
    : Model{std::move(config)} {
-  session_decoder_ = OrtSession::Create(ort_env, (config_->config_path / config_->model.decoder.filename).c_str(), session_options_.get());
  session_encoder_ = OrtSession::Create(ort_env, (config_->config_path / config_->model.encoder_decoder_init.filename).c_str(), session_options_.get());
+  session_decoder_ = OrtSession::Create(ort_env, (config_->config_path / config_->model.decoder.filename).c_str(), session_options_.get());

  InitDeviceAllocator(*session_decoder_);
+  session_encoder_info_ = std::make_unique<SessionInfo>(*session_encoder_);
}

std::unique_ptr<State> Whisper_Model::CreateState(RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) const {
@@ -20,6 +21,15 @@ Whisper_State::Whisper_State(const Whisper_Model& model, RoamingArray<int32_t> s
      model_{model} {
  auto& inputs = const_cast<GeneratorParams::Whisper&>(std::get<GeneratorParams::Whisper>(params.inputs));

+#if USE_CUDA
+  // Convert input_features from float32 to float16 if necessary
+  if (model_.device_type_ == DeviceType::CUDA && model_.session_encoder_info_->GetInputDataType("encoder_input_ids") == Ort::TypeToTensorType<Ort::Float16_t>::type) {
+    std::unique_ptr<OrtValue> input_features_32;
+    ConvertFp32ToFp16(*model_.allocator_device_, *inputs.input_features, input_features_32, model_.device_type_, model_.cuda_stream_);
+    inputs.input_features = std::move(input_features_32);
+  }
+#endif
+
  auto encoder_input_ids = model_.ExpandInputs(inputs.input_features, params_->search.num_beams);
  encoder_hidden_states_ = OrtValue::CreateTensor<float>(*model_.allocator_device_, std::array<int64_t, 3>{decoder_input_ids_.GetShape()[0], 1500, 384});
4 changes: 3 additions & 1 deletion src/models/whisper.h
@@ -10,8 +10,10 @@ struct Whisper_Model : Model {

  std::unique_ptr<State> CreateState(RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) const override;

-  std::unique_ptr<OrtSession> session_decoder_;  // decoder.onnx
  std::unique_ptr<OrtSession> session_encoder_;  // encoder_decoder_init.onnx
+  std::unique_ptr<OrtSession> session_decoder_;  // decoder.onnx
+
+  std::unique_ptr<SessionInfo> session_encoder_info_;
};

struct Whisper_State : State {
7 changes: 7 additions & 0 deletions src/python/python.cpp
@@ -132,6 +132,11 @@ struct PyGenerator {
    generator_->ComputeLogits();
  }

  pybind11::array_t<float> GetLogits() {
    py_logits_.Assign(generator_->search_->GetLogits());
    return ToPython(py_logits_.GetCPU());
  }

  void GenerateNextToken() {
    generator_->GenerateNextToken();
  }
@@ -146,6 +151,7 @@
  PyRoamingArray<int32_t> py_indices_;
  PyRoamingArray<int32_t> py_sequence_;
  PyRoamingArray<int32_t> py_sequencelengths_;
  PyRoamingArray<float> py_logits_;
};

void SetLogOptions(const pybind11::kwargs& dict) {
@@ -235,6 +241,7 @@
      .def(pybind11::init<Model&, PyGeneratorParams&>())
      .def("is_done", &PyGenerator::IsDone)
      .def("compute_logits", &PyGenerator::ComputeLogits)
      .def("get_logits", &PyGenerator::GetLogits)
      .def("generate_next_token", &PyGenerator::GenerateNextToken)
      .def("get_next_tokens", &PyGenerator::GetNextTokens)
      .def("get_sequence", &PyGenerator::GetSequence);
4 changes: 3 additions & 1 deletion src/search.h
@@ -14,6 +14,7 @@ struct Search {
  virtual RoamingArray<int32_t> GetSequenceLengths() = 0;
  virtual int GetSequenceLength() const = 0;
  virtual RoamingArray<int32_t> GetSequence(int index) = 0;
+  virtual RoamingArray<float> GetLogits() = 0;

  virtual void SetLogits(RoamingArray<float> logits) = 0;
  virtual bool IsDone() const = 0;
@@ -39,6 +40,7 @@ struct Search_Cpu : Search {
  int GetSequenceLength() const override;
  RoamingArray<int32_t> GetSequenceLengths() override { return sequence_lengths_; }
  RoamingArray<int32_t> GetSequence(int index) override { return sequences_.GetSequence(index); }
+  RoamingArray<float> GetLogits() override { return next_token_scores_; }

  bool IsDone() const override { return done_; }
  void SetLogits(RoamingArray<float> logits) override;
@@ -54,7 +56,7 @@

  cpu_span<int32_t> next_tokens_;  // shape (beam_size*batch_size)

-  std::span<float> next_token_scores_;  // shape (beam_size*batch_size, vocab_size)
+  cpu_span<float> next_token_scores_;  // shape (beam_size*batch_size, vocab_size)

  Sequences sequences_;
  bool done_{};
1 change: 1 addition & 0 deletions src/search_cuda.h
@@ -13,6 +13,7 @@ struct Search_Cuda : Search {
  int GetSequenceLength() const override;
  RoamingArray<int32_t> GetSequenceLengths() override { return sequence_lengths_; }
  RoamingArray<int32_t> GetSequence(int index) override { return sequences_.GetSequence(index); }
  RoamingArray<float> GetLogits() override { return next_token_scores_; }

  bool IsDone() const {
    cudaStreamSynchronize(params_->cuda_stream);