Skip to content

Commit

Permalink
Address a DML regression caused by the continuous decoding changes (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani authored Jan 13, 2025
1 parent 5e5c544 commit 49eb184
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 66 deletions.
3 changes: 2 additions & 1 deletion src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ void Generator::AppendTokens(cpu_span<const int32_t> input_ids) {
if (search_->GetSequenceLength() != 0 &&
std::none_of(devices_supporting_continuous_decoding.begin(), devices_supporting_continuous_decoding.end(),
[this](DeviceType device_type) { return device_type == state_->params_->device_type; }))
throw std::runtime_error("Continuous decoding is not supported on the selected device type: " + to_string(state_->params_->device_type));
throw std::runtime_error("Continuous decoding is not supported on the selected device type (" + to_string(state_->params_->device_type) +
"). Please recreate the generator instance to avoid using continuous decoding.");

if (last_action_ == Action::generated) {
ComputeLogits(search_->GetNextTokens());
Expand Down
75 changes: 38 additions & 37 deletions src/models/debugging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,47 +88,48 @@ void DumpTensor(const Model& model, std::ostream& stream, OrtValue* value, bool
stream << SGR::Fg_Green << " Location: " << SGR::Reset;

const auto& memory_info = value->GetTensorMemoryInfo();
switch (memory_info.GetDeviceType()) {
case OrtMemoryInfoDeviceType_CPU:
stream << "CPU\r\n";
DumpValues(stream, type_info->GetElementType(), value->GetTensorRawData(), element_count);
break;
case OrtMemoryInfoDeviceType_GPU: {
stream << "GPU\r\n";
auto device_type = memory_info.GetDeviceType();
if (device_type == OrtMemoryInfoDeviceType_CPU) {
stream << "CPU\r\n";
DumpValues(stream, type_info->GetElementType(), value->GetTensorRawData(), element_count);
} else if (device_type == OrtMemoryInfoDeviceType_GPU) {
stream << "GPU\r\n";
#if USE_CUDA
auto type = type_info->GetElementType();
size_t element_size = SizeOf(type);
auto cpu_copy = std::make_unique<uint8_t[]>(element_size * element_count);
CudaCheck() == cudaMemcpy(cpu_copy.get(), value->GetTensorRawData(), element_size * element_count, cudaMemcpyDeviceToHost);
DumpValues(stream, type, cpu_copy.get(), element_count);
#elif USE_DML
auto type = type_info->GetElementType();
size_t element_size = SizeOf(type);
auto cpu_copy = std::make_unique<uint8_t[]>(element_size * element_count);

if (value->GetTensorMutableRawData()) {
ComPtr<ID3D12Resource> gpu_resource;
Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(
model.allocator_device_,
value->GetTensorMutableRawData(),
&gpu_resource));

model.GetDmlReadbackHeap()->ReadbackFromGpu(
std::span(cpu_copy.get(), element_size * element_count),
gpu_resource.Get(),
0,
D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
}

DumpValues(stream, type, cpu_copy.get(), element_count);
auto type = type_info->GetElementType();
size_t element_size = SizeOf(type);
auto cpu_copy = std::make_unique<uint8_t[]>(element_size * element_count);
CudaCheck() == cudaMemcpy(cpu_copy.get(), value->GetTensorRawData(), element_size * element_count, cudaMemcpyDeviceToHost);
DumpValues(stream, type, cpu_copy.get(), element_count);
#else
stream << "Unexpected, using GPU memory but not compiled with CUDA or DML?";
throw std::runtime_error("Unexpected error. Trying to access GPU memory but the project is not compiled with CUDA.");
#endif
break;
} else if (static_cast<int>(device_type) == 4) {
stream << "DML\r\n";
#if USE_DML
auto type = type_info->GetElementType();
size_t element_size = SizeOf(type);
auto cpu_copy = std::make_unique<uint8_t[]>(element_size * element_count);

if (value->GetTensorMutableRawData()) {
ComPtr<ID3D12Resource> gpu_resource;
Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(
model.allocator_device_,
value->GetTensorMutableRawData(),
&gpu_resource));

model.GetDmlReadbackHeap()->ReadbackFromGpu(
std::span(cpu_copy.get(), element_size * element_count),
gpu_resource.Get(),
0,
D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
}
default:
stream << "Unhandled device type";
break;

DumpValues(stream, type, cpu_copy.get(), element_count);
#else
throw std::runtime_error("Unexpected error. Trying to access DML memory but the project is not compiled with DML.");
#endif
} else {
stream << "Unhandled device type: " << static_cast<int>(device_type) << "\r\n";
}
}

Expand Down
59 changes: 47 additions & 12 deletions src/models/input_ids.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,9 @@ namespace Generators {
DefaultInputIDs::DefaultInputIDs(State& state)
: state_{state} {
name_ = model_.config_->model.decoder.inputs.input_ids.c_str();
shape_ = {state_.params_->BatchBeamSize(), 0};
shape_ = {state_.params_->search.batch_size, 0};
type_ = model_.session_info_->GetInputDataType(name_);

if (state_.GetCapturedGraphInfo()) {
sb_input_ids_ = state_.GetCapturedGraphInfo()->sb_input_ids_.get();

#if USE_DML
if (model_.device_type_ == DeviceType::DML) {
sb_input_ids_int32_ = state_.GetCapturedGraphInfo()->sb_input_ids_int32_.get();
}
#endif
}

if (model_.session_info_->HasInput(model_.config_->model.decoder.inputs.current_sequence_length) &&
model_.session_info_->HasInput(model_.config_->model.decoder.inputs.past_sequence_length)) {
if (state_.params_->BatchBeamSize() != 1) {
Expand All @@ -36,7 +26,7 @@ DefaultInputIDs::DefaultInputIDs(State& state)
current_sequence_length_ = OrtValue::CreateTensor(model_.allocator_cpu_, current_sequence_length_shape, model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.current_sequence_length));
*current_sequence_length_->GetTensorMutableData<int32_t>() = 0;

past_sequence_length_ = OrtValue::CreateTensor(*model_.allocator_device_, past_sequence_length_shape, model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.past_sequence_length));
past_sequence_length_ = OrtValue::CreateTensor(model_.allocator_cpu_, past_sequence_length_shape, model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.past_sequence_length));
*past_sequence_length_->GetTensorMutableData<int32_t>() = -1;
}
}
Expand All @@ -56,6 +46,51 @@ void DefaultInputIDs::Add() {
}

void DefaultInputIDs::Update(DeviceSpan<int32_t>& new_tokens) {
// There are three scopes involved when the Update function is called:
// 1. A new Generator state has been just created. This is a prompt stage, and value_ is a nullptr.
// i.e. this is the very first time ever that Update is being called for this Generator.
// 2. We move to the token generation stage. value_ has already been previously created in the prompt stage.
// Update is called on every new token generated.
// 3. We move from the token generation stage back to the prompt stage (e.g. in continuous decoding). value_ is already created.

// For instances where value_ has not been created yet, we need to handle graph capture correctly.
// For subsequent prompt stages, the limiting factor is that the subsequent prompts cannot
// be larger than the first prompt (when graph capture is enabled).
if (!value_) {
shape_[1] = static_cast<int64_t>(new_tokens.size()) / shape_[0];

// If 64-bit, convert from 32-bit to 64-bit
auto input_ids = new_tokens.CopyDeviceToCpu();
if (type_ == Ort::TypeToTensorType<int64_t>) {
value_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_);
auto* p_data = value_->GetTensorMutableData<int64_t>();
for (auto v : input_ids) {
*p_data++ = v;
}
} else {
if (type_ != Ort::TypeToTensorType<int32_t>)
throw std::runtime_error("InputIDs must be int64 or int32");
value_ = OrtValue::CreateTensor<int32_t>(model_.allocator_cpu_.GetInfo(), input_ids, shape_);
}

value_ = model_.ExpandInputs(value_, state_.params_->search.num_beams);
shape_[0] *= state_.params_->search.num_beams;

if (state_.GetCapturedGraphInfo()) {
sb_input_ids_ = state_.GetCapturedGraphInfo()->sb_input_ids_.get();

#if USE_DML
if (model_.device_type_ == DeviceType::DML) {
sb_input_ids_int32_ = state_.GetCapturedGraphInfo()->sb_input_ids_int32_.get();
}
#endif
}

is_prompt_ = false;
state_.inputs_[input_index_] = value_.get();
return;
}

const auto get_unpadded_sequence_length = [](std::span<const int32_t> input_ids,
int32_t pad_token_id) {
int32_t seq_length = 0;
Expand Down
21 changes: 12 additions & 9 deletions src/models/logits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,6 @@ Logits::Logits(State& state)
type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} {
output_raw_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);

if (state_.GetCapturedGraphInfo()) {
if (type_ == Ort::TypeToTensorType<float>) {
sb_logits32_ = state_.GetCapturedGraphInfo()->sb_logits32_.get();
}
if (type_ == Ort::TypeToTensorType<Ort::Float16_t>) {
sb_logits16_ = state_.GetCapturedGraphInfo()->sb_logits16_.get();
}
}

#if USE_CUDA
if (model_.device_type_ == DeviceType::CUDA && !model_.config_->model.eos_token_ids.empty()) {
auto& cpu_ids = model_.config_->model.eos_token_ids;
Expand Down Expand Up @@ -215,6 +206,18 @@ void Logits::Update(const DeviceSpan<int32_t>& next_tokens, size_t new_kv_length
StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType<Ort::Float16_t> ? sb_logits16_ : sb_logits32_;
output_raw_ = !sb_logits ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_)
: sb_logits->CreateTensorOnStaticBuffer(shape_, type_);

if (state_.GetCapturedGraphInfo()) {
if (!sb_logits16_ && !sb_logits32_) {
if (type_ == Ort::TypeToTensorType<float>) {
sb_logits32_ = state_.GetCapturedGraphInfo()->sb_logits32_.get();
}
if (type_ == Ort::TypeToTensorType<Ort::Float16_t>) {
sb_logits16_ = state_.GetCapturedGraphInfo()->sb_logits16_.get();
}
}
}

state_.outputs_[output_index_] = output_raw_.get();
}

Expand Down
19 changes: 12 additions & 7 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,18 @@ OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator
// Copy data to ortvalue_clone
auto element_size = Generators::SizeOf(type_info->GetElementType());
auto data_size = type_info->GetElementCount() * element_size;
if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::CUDA) {
const auto device_type = ortvalue_output->GetTensorMemoryInfo().GetDeviceType();
if (device_type == OrtMemoryInfoDeviceType_CPU) {
std::copy(static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()),
static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()) + data_size,
static_cast<uint8_t*>(ortvalue_clone->GetTensorMutableRawData()));
} else if (device_type == OrtMemoryInfoDeviceType_GPU) {
#if USE_CUDA
cudaMemcpy(ortvalue_clone->GetTensorMutableRawData(), ortvalue_output->GetTensorMutableRawData(), data_size, cudaMemcpyDeviceToHost);
#else
throw std::runtime_error("Unexpected error. Trying to access GPU memory but the project is not compiled with CUDA.");
#endif
} else if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::DML) {
} else if (static_cast<int>(device_type) == 4) {
#if USE_DML
ComPtr<ID3D12Resource> gpu_resource;
Ort::ThrowOnError(generator.model_->GetOrtDmlApi()->GetD3D12ResourceFromAllocation(
Expand All @@ -354,13 +361,11 @@ OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator
gpu_resource.Get(),
0,
D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
#else
throw std::runtime_error("Unexpected error. Trying to access DML memory but the project is not compiled with DML.");
#endif
} else if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_CPU) {
std::copy(static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()),
static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()) + data_size,
static_cast<uint8_t*>(ortvalue_clone->GetTensorMutableRawData()));
} else {
throw std::runtime_error("Unsupported Device type: " + std::to_string(ortvalue_output->GetTensorMemoryInfo().GetDeviceType()));
throw std::runtime_error("Unsupported device type: " + static_cast<int>(device_type));
}

auto tensor = std::make_shared<Generators::Tensor>(std::move(ortvalue_clone));
Expand Down

0 comments on commit 49eb184

Please sign in to comment.