Commit 66321dd

Formatting

RyanUnderhill committed Jan 16, 2025
1 parent bdbb09c commit 66321dd
Showing 9 changed files with 17 additions and 19 deletions.
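Every hunk below is formatting-only and consistent with a clang-format pass: a space after `template` and `if`, two spaces before trailing comments, closing-namespace comments, and stray blank lines removed. A minimal illustration of those conventions (my reading of the diff, not a file from the repo; the function is invented):

```cpp
// Sketch of the style this commit applies.
template <typename T>  // was: template<typename T>
T PassThrough(T value) {
  if (!value) {        // was: if(!value) {
    return value;      // note two spaces before each trailing comment
  }
  return value;
}
```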
3 changes: 1 addition & 2 deletions src/dml/interface.cpp
@@ -17,7 +17,7 @@
std::string CurrentModulePath();

namespace Generators {
-namespace Dml { // If this was in a shared library it wouldn't need to be in its own namespace
+namespace Dml {  // If this was in a shared library it wouldn't need to be in its own namespace

Ort::Allocator* ort_allocator_{};
const char* label_dml = "dml";
@@ -95,7 +95,6 @@ struct GpuMemory final : DeviceBuffer {
};

struct DmlInterfaceImpl : DeviceInterface {
-
  DmlInterfaceImpl(LUID* p_device_luid) {
Ort::ThrowOnError(Ort::api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&dml_api_)));
if (!dml_api_) {
2 changes: 1 addition & 1 deletion src/dml/interface.h
@@ -8,4 +8,4 @@ void SetDmlProvider(OrtSessionOptions& options);

DeviceInterface* GetDmlInterface();

-}
+}  // namespace Generators
5 changes: 3 additions & 2 deletions src/generators.h
@@ -50,10 +50,10 @@ DeviceSpan<T> WrapTensor(DeviceInterface& device, OrtValue& value) {

DeviceSpan<uint8_t> ByteWrapTensor(DeviceInterface& device, OrtValue& value);

-template<typename T>
+template <typename T>
struct OrtTensor {
OrtTensor(std::unique_ptr<OrtValue> ort_value, DeviceInterface& device)
-    : ort_value_{std::move(ort_value)}, device_span_{WrapTensor<T>(device, *ort_value_)} {}
+      : ort_value_{std::move(ort_value)}, device_span_{WrapTensor<T>(device, *ort_value_)} {}

operator OrtValue*() { return ort_value_.get(); }

@@ -151,6 +151,7 @@ struct OrtGlobals {

std::unique_ptr<OrtEnv> env_;
std::unique_ptr<Ort::Allocator> allocator_device_[static_cast<int>(DeviceType::MAX)];
+
 private:
OrtGlobals(const OrtGlobals&) = delete;
void operator=(const OrtGlobals&) = delete;
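For context on the `OrtTensor` struct touched above: it owns an `OrtValue` and exposes a device-visible span over the same memory. A hypothetical usage sketch (the shape, type, and `device` are invented; access via the `device_span_` member is assumed from the initializer list shown in the hunk):

```cpp
// Hypothetical sketch, not code from the repo.
std::vector<int64_t> shape{1, 8};
auto value = OrtValue::CreateTensor(device.GetAllocator(), shape,
                                    ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
Generators::OrtTensor<int32_t> tokens{std::move(value), device};

OrtValue* raw = tokens;           // implicit conversion, e.g. for session bindings
auto span = tokens.device_span_;  // DeviceSpan<int32_t> over the tensor data
```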
2 changes: 1 addition & 1 deletion src/models/input_ids.cpp
@@ -81,7 +81,7 @@ void DefaultInputIDs::Update(DeviceSpan<int32_t> new_tokens) {
}

// Update input_ids with next tokens
-    auto data_span = WrapTensor<int32_t>(*model_.p_device_, *value_);
+  auto data_span = WrapTensor<int32_t>(*model_.p_device_, *value_);

// For beam search
if (is_prompt_ && state_.params_->search.num_beams > 1) {
2 changes: 1 addition & 1 deletion src/models/kv_cache.cpp
@@ -341,7 +341,7 @@ void DefaultKeyValueCache::PickPastState(DeviceSpan<int32_t> beam_indices_device
auto past = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
past.CopyFrom(present);
  }
-
+
  pasts_[index] = std::move(past_value);
}

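The hunk above sits in `PickPastState`, which reorders KV-cache blocks after a beam-search step. A reduced sketch of that pattern (the loop shape and `beam_indices` lookup are inferred from the visible lines, not verified against the full file):

```cpp
// Reduced sketch: for each target beam j, copy the present-state block of the
// beam it continues from into beam j's slot in the new past buffer.
for (size_t j = 0; j < num_beams; j++) {
  auto present = present_span.subspan(beam_indices[j] * block_size_per_beam, block_size_per_beam);
  auto past = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
  past.CopyFrom(present);  // device-side copy via DeviceSpan
}
pasts_[index] = std::move(past_value);  // swap the rebuilt buffer in
```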
2 changes: 0 additions & 2 deletions src/models/logits.cpp
@@ -12,7 +12,6 @@ Logits::Logits(State& state)
type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} {
  output_raw_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
-

if (model_.device_type_ == DeviceType::CUDA && !model_.config_->model.eos_token_ids.empty()) {
auto& cpu_ids = model_.config_->model.eos_token_ids;
cuda_eos_token_ids_ = state_.params_->p_device->Allocate<int32_t>(cpu_ids.size());
@@ -52,7 +51,6 @@ DeviceSpan<float> Logits::Get() {
// Find the first non pad token from the end
size_t token_index = input_sequence_lengths[batch_index] - 1;
    for (int beam_index = 0; beam_index < num_beams; beam_index++) {
-
auto target = logits_last_tokens.subspan(vocab_index * element_size, vocab_size * element_size);
auto source = logits_raw.subspan((vocab_index * seq_length + token_index * vocab_size) * element_size, vocab_size * element_size);
target.CopyFrom(source);
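For the second hunk: the surrounding code slices the logits of the last non-pad token out of a [batch*beams, seq_length, vocab_size] buffer. Working the subspan offsets with invented numbers:

```cpp
// Invented example: vocab_size = 4, seq_length = 3, token_index = 1 (last
// non-pad token), vocab_index = 0 (first beam), element_size = sizeof(float).
// source offset = (vocab_index * seq_length + token_index * vocab_size) * element_size
//               = (0 * 3 + 1 * 4) * 4 = 16 bytes
// i.e. skip the 4 floats of token 0 and copy the 4 floats of token 1 into
// this beam's slot in logits_last_tokens.
```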
2 changes: 1 addition & 1 deletion src/models/model.cpp
@@ -564,7 +564,7 @@ void Cast(OrtValue& input, std::unique_ptr<OrtValue>& output, DeviceInterface& d
if (!output)
output = OrtValue::CreateTensor(device.GetAllocator(), shape, output_type);

-  if(!device.Cast(input, *output))
+  if (!device.Cast(input, *output))
GetDeviceInterface(DeviceType::CPU)->Cast(input, *output);
}

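The `Cast` helper above tries the device's own cast first and falls back to the CPU implementation; `DeviceInterface::Cast` returning `false` for unsupported conversions is inferred from this call site. The pattern in isolation (a sketch, assuming `Cast` returns `bool` on both paths):

```cpp
// Try the device-native cast; fall back to the portable CPU cast on refusal.
bool CastWithFallback(OrtValue& input, OrtValue& output, DeviceInterface& device) {
  if (device.Cast(input, output))
    return true;  // the device handled the conversion natively
  return GetDeviceInterface(DeviceType::CPU)->Cast(input, output);
}
```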
16 changes: 8 additions & 8 deletions src/models/whisper.cpp
@@ -38,7 +38,7 @@ Whisper_State::Whisper_State(const Whisper_Model& model, DeviceSpan<int32_t> seq
}

if (inputs.alignment_heads != nullptr) {
-#if 0 // USE_CUDA
+#if 0  // USE_CUDA
auto alignment_heads_type_and_shape_info = inputs.alignment_heads->ort_tensor_->GetTensorTypeAndShapeInfo();
auto alignment_heads_type = alignment_heads_type_and_shape_info->GetElementType(); // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32
auto alignment_heads_shape = alignment_heads_type_and_shape_info->GetShape();
@@ -97,7 +97,7 @@ Whisper_State::Whisper_State(const Whisper_Model& model, DeviceSpan<int32_t> seq
}
}

-#if 0 // USE_CUDA
+#if 0  // USE_CUDA
template <typename T>
void TransposeKCacheForDMMHA(T* dest_data,
T* temp_buffer,
@@ -143,7 +143,7 @@ DeviceSpan<float> Whisper_State::Run(int current_length, DeviceSpan<int32_t>& ne

const auto copy_data_size_all = src_shape_info->GetElementCount() * SizeOf(src_shape_info->GetElementType());

-#if 0 // USE_CUDA
+#if 0  // USE_CUDA
const auto src_dims = src_shape_info->GetShape();
const auto src_element_type = src_shape_info->GetElementType();
const auto src_element_size = SizeOf(src_element_type);
@@ -183,7 +183,7 @@ DeviceSpan<float> Whisper_State::Run(int current_length, DeviceSpan<int32_t>& ne
auto dest_data = presents_[i]->GetTensorMutableRawData();

switch (model_.device_type_) {
-#if 0 // USE_CUDA
+#if 0  // USE_CUDA
case DeviceType::CUDA:
if (self_attn_kv_cache_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
// CUDA EP + FP16 precision == `DecoderMaskedMultiHeadAttention` op is used
@@ -224,7 +224,7 @@ DeviceSpan<float> Whisper_State::Run(int current_length, DeviceSpan<int32_t>& ne
}
}

-#if 0 // USE_CUDA
+#if 0  // USE_CUDA
if (self_attn_kv_cache_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 && model_.device_type_ == DeviceType::CUDA) {
// Transpose cross attention K caches for `DecoderMaskedMultiHeadAttention`

@@ -327,7 +327,7 @@ void Whisper_State::UpdateInputsOutputs(DeviceSpan<int32_t>& next_tokens, Device
}

if (cache_indirection_) {
-#if 0 // USE_CUDA
+#if 0  // USE_CUDA
auto beam_indices_gpu = gpu_span<int32_t>{beam_indices.Span()};
if (beam_indices_gpu.empty()) {
auto beam_indices_cpu = beam_indices.CpuSpan();
Expand Down Expand Up @@ -355,7 +355,7 @@ void Whisper_State::UpdateInputsOutputs(DeviceSpan<int32_t>& next_tokens, Device
}

if (output_cross_qk_.size() && alignment_heads_) {
-#if 0 // USE_CUDA
+#if 0  // USE_CUDA
// Collect a GPU array of float* pointers from the vector of OrtValues to pass to the kernel
auto output_cross_qk_ptrs = cross_qk_ptrs_gpu_.CpuSpan();
assert(output_cross_qk_ptrs.size() == output_cross_qk_.size());
@@ -386,7 +386,7 @@ void Whisper_State::Initialize(DeviceSpan<int32_t>& next_tokens, int total_lengt

void Whisper_State::Finalize() {
if (output_cross_qk_.size() && alignment_heads_) {
-#if 0 // USE_CUDA
+#if 0  // USE_CUDA
int decoded_length = *(past_sequence_length_->GetTensorMutableData<int32_t>()) + 1;
auto output_cross_qk_dims = output_cross_qk_[0]->GetTensorTypeAndShapeInfo()->GetShape();

2 changes: 1 addition & 1 deletion src/smartptrs.h
@@ -110,7 +110,7 @@ struct DeviceInterface {
virtual void LaunchHandleEOSArray(float* /*batch_logits*/, int /*batch_beam_size*/, int /*vocab_size*/, const int32_t* /*eos_token_ids*/, int /*eos_token_ids_count*/) { assert(false); }
virtual void UpdateCacheIndirectionKernelLauncher(int32_t* /*tgt_indir_cache*/, const int32_t* /*src_indir_cache*/, const int32_t* /*beam_ids*/, int /*batch_size*/, int /*beam_width*/, int /*input_seq_length*/, int /*max_seq_length*/, int /*current_length*/) { assert(false); }
virtual void ReorderPastStatesKernelLauncher(void* /*out_buffer*/, const void* /*in_buffer*/, int /*batch_size*/, int /*num_heads*/, int /*max_length*/, int /*head_size*/, int /*chunk_size*/) { assert(false); }
-  virtual void LaunchCopyCrossQKSingleDecodeStep(float* /*cross_qk_buffer_data*/, float** /*qk_layer_pointers*/, int /*token_index*/, int /*batch_beam_size*/, int /*num_layers*/, int /*num_heads*/, int /*num_alignment_heads*/, const int* /*alignment_heads*/, int /*frames*/, int /*max_length*/) { assert(false); }
+  virtual void LaunchCopyCrossQKSingleDecodeStep(float* /*cross_qk_buffer_data*/, float** /*qk_layer_pointers*/, int /*token_index*/, int /*batch_beam_size*/, int /*num_layers*/, int /*num_heads*/, int /*num_alignment_heads*/, const int* /*alignment_heads*/, int /*frames*/, int /*max_length*/) { assert(false); }
virtual void LaunchFinalizeCrossQK(int /*iteration_number*/, int /*context_decoding_len*/, int /*batch_size*/, int /*num_beams*/, int /*max_length*/, int /*num_alignment_heads*/, int /*frames_of_k*/, const float* /*cross_qk_buffer_data*/, float* /*cross_qk_output*/, int /*num_return_sequences*/, const int* /*cache_indir_data*/) { assert(false); }

virtual void* GetCudaStream() {
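The `DeviceInterface` methods in this hunk all share one idiom: device-specific kernel launchers get a base-class default of `assert(false)`, so only backends that override them can be asked to run them. Reduced to its core (illustrative types only, not repo code):

```cpp
#include <cassert>

struct Backend {
  virtual ~Backend() = default;
  // Default: this backend has no such kernel; misuse trips in debug builds.
  virtual void LaunchKernel(float* /*data*/, int /*n*/) { assert(false); }
};

struct CudaBackend : Backend {
  void LaunchKernel(float* /*data*/, int /*n*/) override {
    // real kernel launch would go here
  }
};
```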
