Commit

Address PR comments
PatriceVignola committed Apr 22, 2024
1 parent 261b0b7 commit d33f5ef
Showing 36 changed files with 29 additions and 31 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -93,8 +93,8 @@ if(USE_DML)
   add_compile_definitions(NOMINMAX)
 
   file(GLOB dml_srcs CONFIGURE_DEPENDS
-    "${MODELS_ROOT}/dml/*.h"
-    "${MODELS_ROOT}/dml/*.cpp"
+    "${PROJECT_SOURCE_DIR}/src/dml/*.h"
+    "${PROJECT_SOURCE_DIR}/src/dml/*.cpp"
   )
 
   list(APPEND generator_srcs ${dml_srcs})
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -6,7 +6,7 @@
 #include <wil/result.h>
 #include "dml_command_recorder.h"
 #include "dml_command_queue.h"
-#include "../onnxruntime_api.h"
+#include "../models/onnxruntime_api.h"
 
 DmlCommandRecorder::DmlCommandRecorder(
     ID3D12Device* d3d_device,
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -7,7 +7,7 @@
 #include <d3dx12.h>
 #include "dml_command_recorder.h"
 #include "dml_gpu_event.h"
-#include "../onnxruntime_api.h"
+#include "../models/onnxruntime_api.h"
 
 // Asynchronously performs GPU work, and automatically manages command list recording and submission to queues.
 // Work submitted to the DmlExecutionContext is typically recorded onto a command list and may not immediately begin
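(Aside, not part of the diff: a minimal, D3D12-free sketch of the record-now/submit-later idea the DmlExecutionContext comment above describes. DeferredContext and its members are made-up names used only for illustration.)

#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-in: work items are recorded into a pending list and only
// run when Flush() is called, mirroring how work handed to an execution
// context is recorded onto a command list and submitted later in a batch.
class DeferredContext {
 public:
  void Record(std::function<void()> work) { pending_.push_back(std::move(work)); }

  // Submits everything recorded so far as one batch.
  void Flush() {
    for (auto& work : pending_) work();
    pending_.clear();
  }

 private:
  std::vector<std::function<void()>> pending_;
};

Callers record several operations and flush once, which is what makes batched submission cheaper than executing each piece of work eagerly.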
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -6,7 +6,7 @@
 #include <memory>
 #include <wrl/client.h>
 #include <wrl/implements.h>
-#include "onnxruntime_api.h"
+#include "../models/onnxruntime_api.h"
 
 // Allows objects to be added to a D3D12 object via SetPrivateDataInterface and extend its lifetime beyond the life of the model. For
 // example, we can put the DML allocator on the D3D12 device (which is a unique singleton for each adapter) and be sure that the allocator won't be
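(Aside, not part of the diff: a hedged sketch of the lifetime-extension trick the comment above describes, using the documented ID3D12Object::SetPrivateDataInterface call. MyPayload, kMyPayloadGuid, and AttachPayload are made-up names, not the project's API.)

#include <d3d12.h>
#include <wrl/client.h>
#include <wrl/implements.h>

using Microsoft::WRL::ComPtr;

// A refcounted COM object holding state that must outlive its creator,
// e.g. an allocator or cached resources.
struct MyPayload : Microsoft::WRL::RuntimeClass<
                       Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
                       IUnknown> {
};

// Arbitrary GUID used as the private-data key (made-up value).
static constexpr GUID kMyPayloadGuid = {
    0x4f2a1c3e, 0x9b7d, 0x4e21, {0x8a, 0x10, 0x5c, 0x33, 0x7f, 0x01, 0xab, 0xcd}};

void AttachPayload(ID3D12Device* device) {
  ComPtr<MyPayload> payload = Microsoft::WRL::Make<MyPayload>();
  // The device AddRef()s the payload, so it stays alive until the device is
  // destroyed or the key is overwritten, even after its creator goes away.
  device->SetPrivateDataInterface(kMyPayloadGuid, payload.Get());
}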
File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions src/models/input_ids.h
@@ -23,11 +23,11 @@ struct InputIDs {
   std::unique_ptr<OrtValue> value_;
 
   // Used for decoding runs with cuda graphs.
-  StaticBuffer* sb_input_ids_ = nullptr;
+  StaticBuffer* sb_input_ids_{};
 
 #if USE_DML
   std::unique_ptr<OrtValue> value_int32_;
-  StaticBuffer* sb_input_ids_int32_ = nullptr;
+  StaticBuffer* sb_input_ids_int32_{};
   DmlReusedCommandListState input_ids_cast_command_list_state_{};
 #endif
 };
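(Aside, not part of the diff: the brace form above is value-initialization and leaves the pointer null, exactly like the old "= nullptr" spelling; it simply matches the brace style already used for the neighboring members. A tiny standalone check, with made-up names:)

#include <cassert>

struct StaticBuffer;  // stand-in declaration; only pointers are used here

struct Example {
  StaticBuffer* a_{};          // value-initialized: starts out null
  StaticBuffer* b_ = nullptr;  // explicit form; same end state
};

int main() {
  Example e;
  assert(e.a_ == nullptr && e.b_ == nullptr);  // both members are null
  return 0;
}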
18 changes: 8 additions & 10 deletions src/models/logits.cpp
@@ -99,22 +99,20 @@ RoamingArray<float> Logits::Get() {
     } break;
 #endif
 
-#if USE_CUDA
+    case DeviceType::CPU:
     case DeviceType::CUDA: {
       auto logits = std::span<float>{value32_->GetTensorMutableData<float>(), element_count};
       auto logits_next = gpu_span<float>{value_next->GetTensorMutableData<float>(), element_count};
       auto target = logits_next.subspan(vocab_index, vocab_size);
       std::span<const float> source = logits.subspan(vocab_index * seq_length + token_index * vocab_size, vocab_size);
-      CudaCheck() == cudaMemcpyAsync(target.data(), source.data(), source.size_bytes(), cudaMemcpyDeviceToDevice, state_.params_->cuda_stream);
-
-    } break;
+      if (model_.device_type_ == DeviceType::CUDA)
+#if USE_CUDA
+        CudaCheck() == cudaMemcpyAsync(target.data(), source.data(), source.size_bytes(), cudaMemcpyDeviceToDevice, state_.params_->cuda_stream);
+#else
+        throw std::runtime_error("Unexpected CUDA device usage");
 #endif
-    case DeviceType::CPU: {
-      auto logits = std::span<float>{value32_->GetTensorMutableData<float>(), element_count};
-      auto logits_next = cpu_span<float>{value_next->GetTensorMutableData<float>(), element_count};
-      auto target = logits_next.subspan(vocab_index, vocab_size);
-      std::span<const float> source = logits.subspan(vocab_index * seq_length + token_index * vocab_size, vocab_size);
-      copy(source, target);
+      else
+        copy(source, target);
     } break;
   }
 
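(Aside, not part of the diff: a compressed, standalone sketch of the dispatch pattern the new code uses: one shared case, a runtime device check, and a preprocessor guard so a non-CUDA build fails loudly if a CUDA device is requested. CopyLogits and its parameters are illustrative, not the project's API.)

#include <algorithm>
#include <span>
#include <stdexcept>

#if USE_CUDA
#include <cuda_runtime.h>
#endif

enum class DeviceType { CPU, CUDA };

// Copies source into target on the requested device. The CUDA path is only
// compiled when USE_CUDA is set; otherwise asking for CUDA throws instead of
// silently falling back to the CPU copy.
void CopyLogits(DeviceType device, std::span<const float> source, std::span<float> target) {
  if (device == DeviceType::CUDA)
#if USE_CUDA
    cudaMemcpyAsync(target.data(), source.data(), source.size_bytes(),
                    cudaMemcpyDeviceToDevice, /*stream*/ nullptr);
#else
    throw std::runtime_error("Unexpected CUDA device usage");
#endif
  else
    std::copy(source.begin(), source.end(), target.begin());
}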
4 changes: 2 additions & 2 deletions src/models/logits.h
@@ -21,8 +21,8 @@ struct Logits {
   std::unique_ptr<OrtValue> value16_;  // When model output is fp16
 
   // Used for decoding runs with cuda graphs.
-  StaticBuffer* sb_logits32_ = nullptr;
-  StaticBuffer* sb_logits16_ = nullptr;
+  StaticBuffer* sb_logits32_{};
+  StaticBuffer* sb_logits16_{};
 
 #if USE_DML
   DmlReusedCommandListState logits_cast_command_list_state_{};
2 changes: 1 addition & 1 deletion src/models/model.cpp
@@ -11,7 +11,7 @@
 #ifdef USE_DML
 #include <wil/wrl.h>
 #include "dml_provider_factory.h"
-#include "dml/dml_smart_container.h"
+#include "../dml/dml_smart_container.h"
 
 EXTERN_C IMAGE_DOS_HEADER __ImageBase;
 
10 changes: 5 additions & 5 deletions src/models/model.h
@@ -7,10 +7,10 @@
 
 #ifdef USE_DML
 #include "dml_provider_factory.h"
-#include "dml/dml_helpers.h"
-#include "dml/dml_execution_context.h"
-#include "dml/dml_pooled_upload_heap.h"
-#include "dml/dml_readback_heap.h"
+#include "../dml/dml_helpers.h"
+#include "../dml/dml_execution_context.h"
+#include "../dml/dml_pooled_upload_heap.h"
+#include "../dml/dml_readback_heap.h"
 #endif
 
 namespace Generators {
@@ -155,7 +155,7 @@ struct Model : std::enable_shared_from_this<Model> {
  private:
 #if USE_DML
   mutable DmlObjects dml_objects_;
-  const OrtDmlApi* p_dml_api_ = nullptr;
+  const OrtDmlApi* p_dml_api_{};
   std::unique_ptr<DmlPooledUploadHeap> dml_pooled_upload_heap_;
   std::unique_ptr<DmlExecutionContext> dml_execution_context_;
   std::unique_ptr<DmlReadbackHeap> dml_readback_heap_;
2 changes: 1 addition & 1 deletion src/models/position_inputs.cpp
@@ -4,7 +4,7 @@
 #include "kernels.h"
 
 #if USE_DML
-#include "dml/dml_update_mask_kernel.h"
+#include "../dml/dml_update_mask_kernel.h"
 #endif
 
 namespace Generators {
10 changes: 5 additions & 5 deletions src/models/position_inputs.h
@@ -3,8 +3,8 @@
 #include "static_buffer.h"
 
 #if USE_DML
-#include "dml/dml_update_mask_kernel.h"
-#include "dml/dml_increment_values_kernel.h"
+#include "../dml/dml_update_mask_kernel.h"
+#include "../dml/dml_increment_values_kernel.h"
 #endif
 
 namespace Generators {
@@ -51,15 +51,15 @@ struct PositionInputs {
   std::vector<int32_t> initial_sequence_lengths_;
 
   // Used for decoding runs with cuda graphs.
-  StaticBuffer* sb_position_ids_ = nullptr;
-  StaticBuffer* sb_attention_mask_ = nullptr;
+  StaticBuffer* sb_position_ids_{};
+  StaticBuffer* sb_attention_mask_{};
 
   bool is_first_posid_update_{true};
   bool is_first_mask_update_{true};
 
 #ifdef USE_DML
   std::optional<DmlUpdateMaskKernel> dml_update_mask_kernel_;
-  StaticBuffer* sb_attention_mask_next_ = nullptr;
+  StaticBuffer* sb_attention_mask_next_{};
   std::optional<DmlIncrementValuesKernel> dml_update_position_ids_kernel_;
   bool is_second_mask_update_{false};
 #endif
