From 64353329586ae9f15d3b894497985eb4b2ed065e Mon Sep 17 00:00:00 2001 From: Ryan Hill <38674843+RyanUnderhill@users.noreply.github.com> Date: Mon, 22 Apr 2024 22:46:04 -0700 Subject: [PATCH] Support eos_token_id being an array (#284) --- src/config.cpp | 32 +++++++++++ src/config.h | 13 +++-- src/models/kernels.cu | 128 +++++++++++++++++++++++++----------------- src/models/kernels.h | 4 ++ src/models/logits.cpp | 61 ++++++++++++++++---- src/models/logits.h | 9 +++ 6 files changed, 180 insertions(+), 67 deletions(-) diff --git a/src/config.cpp b/src/config.cpp index 283bac3cb..39341f5b5 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -1,3 +1,5 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. #include "generators.h" #include "json.h" #include @@ -212,6 +214,29 @@ struct Decoder_Element : JSON::Element { Outputs_Element outputs_{v_.outputs}; }; +struct Eos_Array_Element : JSON::Element { + explicit Eos_Array_Element(Config::Model& v) : v_{v} {} + + void OnNumber(std::string_view name, double value) override { + v_.eos_token_ids.push_back(static_cast(value)); + } + + void OnComplete(bool empty) { + if (v_.eos_token_ids.empty()) + return; // Empty array, nothing to do + + // Copy the first eos_token_id into the eos_token_id value, it will be our primary eos token + v_.eos_token_id = v_.eos_token_ids.front(); + + // If the array is just one value, clear the array and just act like a single value was set + if (v_.eos_token_ids.size() == 1) + v_.eos_token_ids.clear(); + } + + private: + Config::Model& v_; +}; + struct Model_Element : JSON::Element { explicit Model_Element(Config::Model& v) : v_{v} {} @@ -241,6 +266,12 @@ struct Model_Element : JSON::Element { throw JSON::unknown_value_error{}; } + Element& OnArray(std::string_view name) override { + if (name == "eos_token_id") + return eos_token_ids_; + throw JSON::unknown_value_error{}; + } + Element& OnObject(std::string_view name) override { if (name == 
"encoder_decoder_init") { return encoder_decoder_init_; @@ -255,6 +286,7 @@ struct Model_Element : JSON::Element { Config::Model& v_; EncoderDecoderInit_Element encoder_decoder_init_{v_.encoder_decoder_init}; Decoder_Element decoder_{v_.decoder}; + Eos_Array_Element eos_token_ids_{v_}; }; struct Search_Element : JSON::Element { diff --git a/src/config.h b/src/config.h index a126419a4..b94e05ca0 100644 --- a/src/config.h +++ b/src/config.h @@ -1,3 +1,5 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. #pragma once namespace Generators { @@ -29,11 +31,12 @@ struct Config { struct Model { std::string type; - int pad_token_id{}; // The id of the padding token. - int eos_token_id{}; // The id of the end-of-stream token. - int bos_token_id{}; // The id of the beginning-of-stream token. - int sep_token_id{}; // The id of the separation token. - int decoder_start_token_id{}; // If an encoder-decoder model starts decoding with a different token than bos, the id of that token. + int pad_token_id{}; // The id of the padding token. + int eos_token_id{}; // The id of the end-of-stream token. + std::vector eos_token_ids; // If eos_token_id is passed as an array, this is where the values go (eos_token_id gets set to the first entry in the array) + int bos_token_id{}; // The id of the beginning-of-stream token. + int sep_token_id{}; // The id of the separation token. + int decoder_start_token_id{}; // If an encoder-decoder model starts decoding with a different token than bos, the id of that token. int vocab_size{}; int context_length{}; diff --git a/src/models/kernels.cu b/src/models/kernels.cu index 8bb551849..f90be8347 100644 --- a/src/models/kernels.cu +++ b/src/models/kernels.cu @@ -1,87 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
#include #include #include +#include namespace Generators { namespace cuda { -template __global__ void UpdatePositionIds(T *positions, int batch_beam_size) { - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < batch_beam_size) - positions[i]++; +template +__global__ void UpdatePositionIds(T* positions, int batch_beam_size) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < batch_beam_size) + positions[i]++; } -template void Launch_UpdatePositionIds(T *positions, int batch_beam_size, cudaStream_t stream) { - UpdatePositionIds<<<(batch_beam_size + 255) / 256, 256, 0, stream>>>(positions, batch_beam_size); +template +void Launch_UpdatePositionIds(T* positions, int batch_beam_size, cudaStream_t stream) { + UpdatePositionIds<<<(batch_beam_size + 255) / 256, 256, 0, stream>>>(positions, batch_beam_size); } -template void Launch_UpdatePositionIds(int32_t *positions, int batch_beam_size, cudaStream_t stream); -template void Launch_UpdatePositionIds(int64_t *positions, int batch_beam_size, cudaStream_t stream); +template void Launch_UpdatePositionIds(int32_t* positions, int batch_beam_size, cudaStream_t stream); +template void Launch_UpdatePositionIds(int64_t* positions, int batch_beam_size, cudaStream_t stream); template -__global__ void CopyAndUpdateAttentionMask(T *mask_data, const T *old_mask_data, int batch_beam_size, +__global__ void CopyAndUpdateAttentionMask(T* mask_data, const T* old_mask_data, int batch_beam_size, int current_length, int max_length) { - int global_index = blockIdx.x * blockDim.x + threadIdx.x; - int i = global_index / current_length; - int j = global_index % current_length; - if (i < batch_beam_size) { - if (j < current_length - 1) { - mask_data[i * max_length + j] = old_mask_data[i * (current_length - 1) + j]; - } else { - mask_data[i * max_length + j] = 1; - } + int global_index = blockIdx.x * blockDim.x + threadIdx.x; + int i = global_index / current_length; + int j = global_index % current_length; + if (i < batch_beam_size) { 
+ if (j < current_length - 1) { + mask_data[i * max_length + j] = old_mask_data[i * (current_length - 1) + j]; + } else { + mask_data[i * max_length + j] = 1; } + } } template -__global__ void UpdateAttentionMask(T *mask_data, int batch_beam_size, int current_length, int max_length) { - int i = blockIdx.x; - if (i < batch_beam_size) { - mask_data[i * max_length + current_length] = 1; - } +__global__ void UpdateAttentionMask(T* mask_data, int batch_beam_size, int current_length, int max_length) { + int i = blockIdx.x; + if (i < batch_beam_size) { + mask_data[i * max_length + current_length] = 1; + } } template -void Launch_UpdateAttentionMask(T *mask_data, const T *old_mask_data, int batch_beam_size, int current_length, +void Launch_UpdateAttentionMask(T* mask_data, const T* old_mask_data, int batch_beam_size, int current_length, int max_length, bool update_only, cudaStream_t stream) { - if (update_only) { - UpdateAttentionMask - <<>>(mask_data, batch_beam_size, current_length, max_length); - } else { - CopyAndUpdateAttentionMask<<<(batch_beam_size * max_length + 255) / 256, 256, 0, stream>>>( - mask_data, old_mask_data, batch_beam_size, current_length, max_length); - } + if (update_only) { + UpdateAttentionMask + <<>>(mask_data, batch_beam_size, current_length, max_length); + } else { + CopyAndUpdateAttentionMask<<<(batch_beam_size * max_length + 255) / 256, 256, 0, stream>>>( + mask_data, old_mask_data, batch_beam_size, current_length, max_length); + } } -template void Launch_UpdateAttentionMask(int32_t *mask_data, const int32_t *old_mask_data, int batch_beam_size, +template void Launch_UpdateAttentionMask(int32_t* mask_data, const int32_t* old_mask_data, int batch_beam_size, int current_length, int max_length, bool update_only, cudaStream_t stream); -template void Launch_UpdateAttentionMask(int64_t *mask_data, const int64_t *old_mask_data, int batch_beam_size, +template void Launch_UpdateAttentionMask(int64_t* mask_data, const int64_t* old_mask_data, int 
batch_beam_size, int current_length, int max_length, bool update_only, cudaStream_t stream); -__global__ void ConvertFp16ToFp32(const half *src, float *dst, int count) { - int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < count) - dst[idx] = __half2float(src[idx]); +__global__ void HandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= batch_beam_size) + return; + + float* logits = batch_logits + index * vocab_size; + float max = std::numeric_limits::lowest(); + for (int i = 0; i < eos_token_ids_count; i++) { + max = std::max(max, logits[eos_token_ids[i]]); + logits[eos_token_ids[i]] = std::numeric_limits::lowest(); // Set all EOS token options to never happen (the first will get the max of all) + } + + logits[eos_token_ids[0]] = max; // Set the score of the primary EOS token to the highest of any of the EOS tokens } -void LaunchFp16ToFp32(const uint16_t *fp16, float *fp32, int count, cudaStream_t stream) { - int block_size = 256; - int num_blocks = (count + block_size - 1) / block_size; - ConvertFp16ToFp32<<>>(reinterpret_cast(fp16), fp32, count); +void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream) { + HandleEOSArray<<<(batch_beam_size + 255) / 256, 256, 0, stream>>>(batch_logits, batch_beam_size, vocab_size, eos_token_ids, eos_token_ids_count); } -__global__ void ConvertInt32ToInt64(const int32_t *src, int64_t *dst, int count) { - int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < count) { - dst[idx] = src[idx]; - } +__global__ void ConvertFp16ToFp32(const half* src, float* dst, int count) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < count) + dst[idx] = __half2float(src[idx]); +} + +void LaunchFp16ToFp32(const uint16_t* fp16, float* fp32, int count, cudaStream_t stream) { + 
int block_size = 256; + int num_blocks = (count + block_size - 1) / block_size; + ConvertFp16ToFp32<<>>(reinterpret_cast(fp16), fp32, count); +} + +__global__ void ConvertInt32ToInt64(const int32_t* src, int64_t* dst, int count) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < count) { + dst[idx] = src[idx]; + } } -void LaunchInt32ToInt64(const int32_t *src, int64_t *dst, int count, cudaStream_t stream) { - int block_size = 256; - int num_blocks = (count + block_size - 1) / block_size; - ConvertInt32ToInt64<<>>(src, dst, count); +void LaunchInt32ToInt64(const int32_t* src, int64_t* dst, int count, cudaStream_t stream) { + int block_size = 256; + int num_blocks = (count + block_size - 1) / block_size; + ConvertInt32ToInt64<<>>(src, dst, count); } -} // namespace cuda -} // namespace Generators +} // namespace cuda +} // namespace Generators diff --git a/src/models/kernels.h b/src/models/kernels.h index 03160cfb2..b778d95a8 100644 --- a/src/models/kernels.h +++ b/src/models/kernels.h @@ -1,3 +1,5 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. #pragma once namespace Generators { @@ -9,6 +11,8 @@ template void Launch_UpdateAttentionMask(T* mask_data, const T* old_mask_data, int batch_beam_size, int current_length, int max_length, bool update_only, cudaStream_t stream); +void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream); + void LaunchFp16ToFp32(const uint16_t* fp16, float* fp32, int count, cudaStream_t stream); void LaunchInt32ToInt64(const int32_t* src, int64_t* dst, int count, cudaStream_t stream); } // namespace cuda diff --git a/src/models/logits.cpp b/src/models/logits.cpp index 57ce2f7f2..4c34f4b76 100644 --- a/src/models/logits.cpp +++ b/src/models/logits.cpp @@ -1,6 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
#include "../generators.h" #include "model.h" #include "logits.h" +#if USE_CUDA +#include "kernels.h" +#endif namespace Generators { @@ -23,6 +28,14 @@ Logits::Logits(const Model& model, State& state) sb_logits16_ = state_.GetCapturedGraphInfo()->sb_logits16_.get(); } } + +#if USE_CUDA + if (model_.device_type_ == DeviceType::CUDA && !model_.config_->model.eos_token_ids.empty()) { + auto& cpu_ids = model_.config_->model.eos_token_ids; + cuda_eos_token_ids_ptr_ = CudaMallocArray(cpu_ids.size(), &cuda_eos_token_ids_); + cudaMemcpyAsync(cuda_eos_token_ids_.data(), cpu_ids.data(), cpu_ids.size() * sizeof(int32_t), ::cudaMemcpyHostToDevice, model_.cuda_stream_); + } +#endif } RoamingArray Logits::Get() { @@ -76,7 +89,7 @@ RoamingArray Logits::Get() { for (int beam_index = 0; beam_index < num_beams; beam_index++) { switch (model_.device_type_) { -#ifdef USE_DML +#if USE_DML case DeviceType::DML: { ComPtr source_resource; Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, value32_->GetTensorMutableRawData(), &source_resource)); @@ -126,33 +139,61 @@ RoamingArray Logits::Get() { if (type_ == Ort::TypeToTensorType::type) value16_ = !sb_logits16_ ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_) : sb_logits16_->CreateTensorOnStaticBuffer(shape_, type_); - state_.outputs_[output_index_] = type_ == Ort::TypeToTensorType::type ? 
value32_.get() : value16_.get(); + element_count = shape_[0] * shape_[2]; // shape_[1] is now 1, so the element count must be updated } + assert(shape_[1] == 1); + #if USE_CUDA - if (model_.device_type_ == DeviceType::CUDA) - return gpu_span{value32_->GetTensorMutableData(), element_count}; + if (model_.device_type_ == DeviceType::CUDA) { + auto batched_logits_gpu = gpu_span{value32_->GetTensorMutableData(), element_count}; + if (cuda_eos_token_ids_ptr_) + cuda::LaunchHandleEOSArray(batched_logits_gpu.data(), static_cast(shape_[0]) /* batch_beam_size*/, static_cast(shape_[2]) /* vocab_size */, cuda_eos_token_ids_.data(), static_cast(cuda_eos_token_ids_.size()), model_.cuda_stream_); + return batched_logits_gpu; + } #elif USE_DML - auto cpu_tensor = value32_cpu_->GetTensorMutableData(); if (model_.device_type_ == DeviceType::DML) { // DML doesn't support on-device scoring yet, so we transfer the data to the CPU ComPtr gpu_resource; Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, value32_->GetTensorMutableRawData(), &gpu_resource)); - - size_t new_element_count = shape_[0] * shape_[1] * shape_[2]; + auto cpu_tensor = value32_cpu_->GetTensorMutableData(); model_.GetDmlReadbackHeap()->ReadbackFromGpu( - std::span(reinterpret_cast(cpu_tensor), new_element_count * sizeof(float)), + std::span(reinterpret_cast(cpu_tensor), element_count * sizeof(float)), gpu_resource.Get(), 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS); - return cpu_span{cpu_tensor, new_element_count}; + auto batched_logits_cpu = cpu_span{cpu_tensor, element_count}; + HandleEOSArray(batched_logits_cpu); + return batched_logits_cpu; } #endif - return cpu_span{value32_->GetTensorMutableData(), element_count}; + auto batched_logits_cpu = cpu_span{value32_->GetTensorMutableData(), element_count}; + HandleEOSArray(batched_logits_cpu); + return batched_logits_cpu; +} + +void Logits::HandleEOSArray(cpu_span batched_logits) { + if 
(model_.config_->model.eos_token_ids.empty()) + return; + + const size_t vocab_size = shape_[2]; + size_t vocab_index = 0; // Simpler math to have this index go up by vocab_size for every logit chunk we process + + for (int index = 0; index < shape_[0]; index++) { + auto logits = batched_logits.subspan(vocab_index, vocab_size); + float max = std::numeric_limits::lowest(); + for (auto id : model_.config_->model.eos_token_ids) { + max = std::max(max, logits[id]); + logits[id] = std::numeric_limits::lowest(); // Set all EOS token options to never happen (the first will get the max of all) + } + + logits[model_.config_->model.eos_token_id] = max; // Set the score of the primary EOS token to the highest of any of the EOS tokens + vocab_index += vocab_size; + } } void Logits::Add() { diff --git a/src/models/logits.h b/src/models/logits.h index 481e9e44a..4d22139ed 100644 --- a/src/models/logits.h +++ b/src/models/logits.h @@ -1,3 +1,5 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. #pragma once #include "static_buffer.h" @@ -11,6 +13,8 @@ struct Logits { RoamingArray Get(); private: + void HandleEOSArray(cpu_span logits); + const Model& model_; State& state_; size_t output_index_{~0U}; @@ -24,6 +28,11 @@ struct Logits { StaticBuffer* sb_logits32_{}; StaticBuffer* sb_logits16_{}; +#if USE_CUDA + cuda_unique_ptr cuda_eos_token_ids_ptr_; // eos_token_ids from params, but in cuda accessible memory + gpu_span cuda_eos_token_ids_; +#endif + #if USE_DML DmlReusedCommandListState logits_cast_command_list_state_{}; std::unique_ptr value32_cpu_;