Commit db5f57e

Fix CPU devices in DML build - rc3 (#309)

yufenglee and PatriceVignola authored Apr 23, 2024
Co-authored-by: Patrice Vignola <[email protected]>
1 parent dd98289 · commit db5f57e
Showing 11 changed files with 47 additions and 28 deletions.
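The common thread across these files: a binary built with USE_DML can still create a model on a CPU device, so code that allocates DML-only resources now checks the runtime device type in addition to the compile-time flag. A minimal self-contained sketch of the pattern (Model, DeviceType, and StaticBuffer here are stand-ins modeled on the names in the diff, not the project's real declarations):

    #include <cstddef>
    #include <memory>

    enum class DeviceType { CPU, CUDA, DML };
    struct StaticBuffer { explicit StaticBuffer(std::size_t) {} };
    struct Model { DeviceType device_type_{DeviceType::CPU}; };

    #define USE_DML 1  // set by the build system in the real project

    std::unique_ptr<StaticBuffer> ReserveDmlOnlyBuffer(const Model& model, std::size_t size) {
    #if USE_DML
      // The compile-time guard alone is not enough: a DML build may run on CPU.
      if (model.device_type_ == DeviceType::DML)
        return std::make_unique<StaticBuffer>(size);
    #endif
      return nullptr;  // non-DML device: intentionally left unallocated
    }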
VERSION_INFO (2 changes: 1 addition & 1 deletion)

@@ -1 +1 @@
-0.2.0rc2
+0.2.0rc3
src/models/captured_graph_pool.cpp (8 changes: 6 additions & 2 deletions)

@@ -54,7 +54,9 @@ CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model,
   new_captured_graph->sb_input_ids_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);

 #if USE_DML
-  new_captured_graph->sb_input_ids_int32_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);
+  if (model.device_type_ == DeviceType::DML) {
+    new_captured_graph->sb_input_ids_int32_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);
+  }
 #endif

   // Create the static buffers for the cache

@@ -76,7 +78,9 @@ CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model,

 #if USE_DML
   // DML currently needs an additional static buffer for the mask
-  new_captured_graph->sb_attention_mask_next_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);
+  if (model.device_type_ == DeviceType::DML) {
+    new_captured_graph->sb_attention_mask_next_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);
+  }
 #endif
 }
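Note the pairing this relies on: when the device is not DML, sb_input_ids_int32_ and sb_attention_mask_next_ are simply never allocated, and the matching runtime checks added in input_ids.cpp and position_inputs.cpp below keep the null pointers from ever being dereferenced. A sketch of the invariant, assuming producer and consumer agree on the predicate:

    // Producer (captured-graph pool): allocate only when the predicate holds.
    if (model.device_type_ == DeviceType::DML)
      sb_input_ids_int32_ = std::make_unique<StaticBuffer>(allocator, size);

    // Consumer (InputIDs::Update): guard with the same predicate, so the
    // pointer is non-null whenever it is dereferenced.
    if (model_.device_type_ == DeviceType::DML)
      value_int32_ = sb_input_ids_int32_->CreateTensorOnStaticBuffer(shape_, type);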
src/models/input_ids.cpp (12 changes: 9 additions & 3 deletions)

@@ -32,7 +32,9 @@ InputIDs::InputIDs(const Model& model, State& state)
     sb_input_ids_ = state_.GetCapturedGraphInfo()->sb_input_ids_.get();

 #if USE_DML
-    sb_input_ids_int32_ = state_.GetCapturedGraphInfo()->sb_input_ids_int32_.get();
+    if (model_.device_type_ == DeviceType::DML) {
+      sb_input_ids_int32_ = state_.GetCapturedGraphInfo()->sb_input_ids_int32_.get();
+    }
 #endif
   }
 }

@@ -52,13 +54,17 @@ void InputIDs::Update(RoamingArray<int32_t> next_tokens_unk) {
     value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);

 #if USE_DML
-    value_int32_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+    if (model_.device_type_ == DeviceType::DML) {
+      value_int32_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+    }
 #endif
   } else {
     value_ = sb_input_ids_->CreateTensorOnStaticBuffer(shape_, type_);

 #if USE_DML
-    value_int32_ = sb_input_ids_int32_->CreateTensorOnStaticBuffer(shape_, Ort::TypeToTensorType<int32_t>::type);
+    if (model_.device_type_ == DeviceType::DML) {
+      value_int32_ = sb_input_ids_int32_->CreateTensorOnStaticBuffer(shape_, Ort::TypeToTensorType<int32_t>::type);
+    }
 #endif
   }
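The two branches of Update reflect the two execution modes: without a captured graph, the tensor is freshly allocated each step via OrtValue::CreateTensor; with a captured graph, CreateTensorOnStaticBuffer wraps a pre-reserved StaticBuffer so the device address stays stable across graph replays (a reading of the surrounding code, not spelled out in the diff):

    // Sketch of the allocation choice (names from the diff; the branch
    // condition sits above the visible hunk, so this condition is assumed).
    if (!sb_input_ids_) {
      value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);  // fresh each step
    } else {
      value_ = sb_input_ids_->CreateTensorOnStaticBuffer(shape_, type_);          // fixed address for replay
    }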
src/models/logits.cpp (25 changes: 14 additions & 11 deletions)

@@ -44,17 +44,18 @@ RoamingArray<float> Logits::Get() {
   // Convert from float16 to float32 if necessary
   if (type_ == Ort::TypeToTensorType<Ort::Float16_t>::type) {
 #if USE_DML
-    DmlHelpers::DmlCastInputToOutput(
-        model_.GetDmlExecutionContext(),
-        *model_.allocator_device_,
-        *value16_,
-        value32_,
-        model_.GetDmlDevice(),
-        model_.GetOrtDmlApi(),
-        logits_cast_command_list_state_);
-#else
-    ConvertFp16ToFp32(*model_.allocator_device_, *value16_, value32_, model_.device_type_, model_.cuda_stream_);
+    if (model_.device_type_ == DeviceType::DML) {
+      DmlHelpers::DmlCastInputToOutput(
+          model_.GetDmlExecutionContext(),
+          *model_.allocator_device_,
+          *value16_,
+          value32_,
+          model_.GetDmlDevice(),
+          model_.GetOrtDmlApi(),
+          logits_cast_command_list_state_);
+    } else
 #endif
+      ConvertFp16ToFp32(*model_.allocator_device_, *value16_, value32_, model_.device_type_, model_.cuda_stream_);
   }

   // First iteration? Then copy the logits over to a {batch_beams, 1, vocab_size} tensor

@@ -73,7 +74,9 @@ RoamingArray<float> Logits::Get() {

 #if USE_DML
   // DML doesn't support on-device scoring yet, so we need to download some data to the CPU
-  value32_cpu_ = OrtValue::CreateTensor<float>(model_.allocator_cpu_, shape_);
+  if (model_.device_type_ == DeviceType::DML) {
+    value32_cpu_ = OrtValue::CreateTensor<float>(model_.allocator_cpu_, shape_);
+  }
 #endif

   size_t vocab_index = 0;  // Simpler math to have this index go up by vocab_size for every logit chunk we process
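The rewritten cast dispatch uses a } else that dangles across the #endif: in a DML build, the generic fallback becomes the else-branch of the runtime check; in a non-DML build, the preprocessor strips the if/else header and the fallback runs unconditionally. A self-contained sketch of the same construction (flip USE_DML to 0 to see the non-DML expansion):

    #include <cstdio>

    enum class DeviceType { CPU, CUDA, DML };
    #define USE_DML 1

    void CastLogits(DeviceType device) {
    #if USE_DML
      if (device == DeviceType::DML) {
        std::puts("DML command-list cast");
      } else
    #endif
        std::puts("generic fp16 -> fp32 fallback");
    }

    int main() {
      CastLogits(DeviceType::CPU);  // fallback in both build modes
      CastLogits(DeviceType::DML);  // DML path only when USE_DML is nonzero
    }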
src/models/model.cpp (6 changes: 3 additions & 3 deletions)

@@ -8,7 +8,7 @@
 #include "decoder_only.h"
 #include "whisper.h"
 #include "kernels.h"
-#ifdef USE_DML
+#if USE_DML
 #include <wil/wrl.h>
 #include "dml_provider_factory.h"
 #include "../dml/dml_smart_container.h"

@@ -347,7 +347,7 @@ void Model::CreateSessionOptions() {

     Ort::ThrowOnError(Ort::api->UpdateROCMProviderOptions(&ort_provider_options, keys.data(), values.data(), keys.size()));
     ort_options.AppendExecutionProvider_ROCM(ort_provider_options);
-#ifdef USE_DML
+#if USE_DML
   } else if (provider_options.name == "dml") {
     dml_objects_ = DmlHelpers::CreateDmlObjects();

@@ -442,7 +442,7 @@ void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr<Or
       fp32[i] = Float16ToFloat32(fp16[i]);
       break;

-#ifdef USE_CUDA
+#if USE_CUDA
     case DeviceType::CUDA:
       cuda::LaunchFp16ToFp32(fp16, fp32, count, stream);
       break;
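This file also switches #ifdef USE_DML to #if USE_DML (and #ifdef USE_CUDA to #if USE_CUDA). The two are not interchangeable when the macro carries a value: #ifdef tests only that the name is defined, while #if evaluates it (an undefined identifier evaluates to 0), so a build that defines USE_DML=0 disables the guarded code only under #if — presumably how this project's build system drives the flags. The test files' #if !USE_DML below relies on the same rule. A compile-time illustration:

    #define USE_DML 0  // e.g. -DUSE_DML=0 on the compiler command line

    #ifdef USE_DML
    // Reached: the name is defined, even though its value is 0.
    #endif

    #if USE_DML
    // Not reached: 0 is false, which is the behavior the commit wants.
    #endif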
src/models/model.h (2 changes: 1 addition & 1 deletion)

@@ -5,7 +5,7 @@

 #include "captured_graph_pool.h"

-#ifdef USE_DML
+#if USE_DML
 #include "dml_provider_factory.h"
 #include "../dml/dml_helpers.h"
 #include "../dml/dml_execution_context.h"
src/models/position_inputs.cpp (10 changes: 8 additions & 2 deletions)

@@ -58,7 +58,9 @@ PositionInputs::PositionInputs(const Model& model, State& state, RoamingArray<in
     sb_attention_mask_ = state_.GetCapturedGraphInfo()->sb_attention_mask_.get();

 #if USE_DML
-    sb_attention_mask_next_ = state_.GetCapturedGraphInfo()->sb_attention_mask_next_.get();
+    if (model_.device_type_ == DeviceType::DML) {
+      sb_attention_mask_next_ = state_.GetCapturedGraphInfo()->sb_attention_mask_next_.get();
+    }
 #endif
   }
 }

@@ -301,7 +303,11 @@ void PositionInputs::UpdateAttentionMask(int current_length) {
       throw std::runtime_error("PositionIDs::Update - Unsupported device type");
   }

-#ifndef USE_DML
+#if USE_DML
+  if (model_.device_type_ != DeviceType::DML) {
+    attention_mask_ = std::move(attention_mask_next_);
+  }
+#else
   attention_mask_ = std::move(attention_mask_next_);
 #endif
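The second hunk inverts the old #ifndef USE_DML guard: previously a DML build never moved attention_mask_next_ into attention_mask_, which became wrong once a DML build could run on a CPU or CUDA device. After the change, the move happens for every device except DML, in both build modes. Written as a single runtime predicate, the logic is equivalent to this restatement (the commit keeps the preprocessor form; this sketch only paraphrases it):

    bool device_is_dml = false;
    #if USE_DML
    device_is_dml = (model_.device_type_ == DeviceType::DML);
    #endif
    if (!device_is_dml)
      attention_mask_ = std::move(attention_mask_next_);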
src/models/position_inputs.h (2 changes: 1 addition & 1 deletion)

@@ -57,7 +57,7 @@ struct PositionInputs {
   bool is_first_posid_update_{true};
   bool is_first_mask_update_{true};

-#ifdef USE_DML
+#if USE_DML
   std::optional<DmlUpdateMaskKernel> dml_update_mask_kernel_;
   StaticBuffer* sb_attention_mask_next_{};
   std::optional<DmlIncrementValuesKernel> dml_update_position_ids_kernel_;
src/python/python.cpp (4 changes: 2 additions & 2 deletions)

@@ -247,15 +247,15 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
   m.def("set_log_options", &SetLogOptions);

   m.def("is_cuda_available", []() {
-#ifdef USE_CUDA
+#if USE_CUDA
     return true;
 #else
     return false;
 #endif
   });

   m.def("is_dml_available", []() {
-#ifdef USE_DML
+#if USE_DML
     return true;
 #else
     return false;
test/c_api_tests.cpp (2 changes: 1 addition & 1 deletion)

@@ -90,7 +90,7 @@ TEST(CAPITests, EndToEndPhiBatch) {
 }

 // DML doesn't support GPT attention
-#ifndef USE_DML
+#if !USE_DML
 TEST(CAPITests, GreedySearchGptFp32CAPI) {
   std::vector<int64_t> input_ids_shape{2, 4};
   std::vector<int32_t> input_ids{0, 0, 0, 52, 0, 0, 195, 731};
test/model_tests.cpp (2 changes: 1 addition & 1 deletion)

@@ -21,7 +21,7 @@ static const std::pair<const char*, const char*> c_tiny_gpt2_model_paths[] = {
 };

 // DML doesn't support GPT attention
-#ifndef USE_DML
+#if !USE_DML
 TEST(ModelTests, GreedySearchGptFp32) {
   std::vector<int64_t> input_ids_shape{2, 4};
   std::vector<int32_t> input_ids{0, 0, 0, 52, 0, 0, 195, 731};
