Skip to content

Commit

Permalink
Fix fp16 whisper conversion
Browse files (browse the repository at this point in the history)
  • Loading branch information
RyanUnderhill committed May 1, 2024
1 parent 13dede8 commit 512a71d
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/models/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Whisper_Model::Whisper_Model(std::unique_ptr<Config> config, OrtEnv& ort_env)
session_encoder_ = OrtSession::Create(ort_env, (config_->config_path / config_->model.encoder_decoder_init.filename).c_str(), session_options_.get());

InitDeviceAllocator(*session_decoder_);
session_decoder_info_ = std::make_unique<SessionInfo>(*session_decoder_);
}

std::unique_ptr<State> Whisper_Model::CreateState(RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) const {
Expand All @@ -22,7 +23,7 @@ Whisper_State::Whisper_State(const Whisper_Model& model, RoamingArray<int32_t> s

#if USE_CUDA
// Convert input_features from float32 to float16 if necessary
- if (model_.device_type_ == DeviceType::CUDA && model_.session_info_->GetInputDataType("encoder_input_ids") == Ort::TypeToTensorType<Ort::Float16_t>::type) {
+ if (model_.device_type_ == DeviceType::CUDA && model_.session_decoder_info_->GetInputDataType("encoder_input_ids") == Ort::TypeToTensorType<Ort::Float16_t>::type) {
std::unique_ptr<OrtValue> input_features_32;
ConvertFp32ToFp16(*model_.allocator_device_, *inputs.input_features, input_features_32, model_.device_type_, model_.cuda_stream_);
inputs.input_features = std::move(input_features_32);
Expand Down
1 change: 1 addition & 0 deletions src/models/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ struct Whisper_Model : Model {

std::unique_ptr<OrtSession> session_decoder_; // decoder.onnx
std::unique_ptr<OrtSession> session_encoder_; // encoder_decoder_init.onnx
std::unique_ptr<SessionInfo> session_decoder_info_;
};

struct Whisper_State : State {
Expand Down

0 comments on commit 512a71d

Please sign in to comment.