Commit 6435332

Support eos_token_id being an array (#284)

RyanUnderhill authored Apr 23, 2024
1 parent cfaa57c commit 6435332
Showing 6 changed files with 180 additions and 67 deletions.
32 changes: 32 additions & 0 deletions src/config.cpp
@@ -1,3 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "generators.h"
#include "json.h"
#include <fstream>
@@ -212,6 +214,29 @@ struct Decoder_Element : JSON::Element {
Outputs_Element outputs_{v_.outputs};
};

struct Eos_Array_Element : JSON::Element {
explicit Eos_Array_Element(Config::Model& v) : v_{v} {}

void OnNumber(std::string_view name, double value) override {
v_.eos_token_ids.push_back(static_cast<int>(value));
}

void OnComplete(bool empty) {
if (v_.eos_token_ids.empty())
return; // Empty array, nothing to do

// Copy the first eos_token_id into the eos_token_id value; it will be our primary EOS token
v_.eos_token_id = v_.eos_token_ids.front();

// If the array holds just one value, clear the array and act as if a single value was set
if (v_.eos_token_ids.size() == 1)
v_.eos_token_ids.clear();
}

private:
Config::Model& v_;
};
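For context, this element consumes the array form of eos_token_id in genai_config.json. A fragment like the following (the token ids are illustrative, other fields omitted) leaves eos_token_id set to 32000 and keeps all three values in eos_token_ids; a single-element array collapses back to the plain scalar case:

"model": {
"eos_token_id": [32000, 32001, 32007]
}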

struct Model_Element : JSON::Element {
explicit Model_Element(Config::Model& v) : v_{v} {}

@@ -241,6 +266,12 @@ struct Model_Element : JSON::Element {
throw JSON::unknown_value_error{};
}

Element& OnArray(std::string_view name) override {
if (name == "eos_token_id")
return eos_token_ids_;
throw JSON::unknown_value_error{};
}

Element& OnObject(std::string_view name) override {
if (name == "encoder_decoder_init") {
return encoder_decoder_init_;
@@ -255,6 +286,7 @@ struct Model_Element : JSON::Element {
Config::Model& v_;
EncoderDecoderInit_Element encoder_decoder_init_{v_.encoder_decoder_init};
Decoder_Element decoder_{v_.decoder};
Eos_Array_Element eos_token_ids_{v_};
};

struct Search_Element : JSON::Element {
13 changes: 8 additions & 5 deletions src/config.h
@@ -1,3 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

namespace Generators {
@@ -29,11 +31,12 @@ struct Config {
struct Model {
std::string type;

int pad_token_id{}; // The id of the padding token.
int eos_token_id{}; // The id of the end-of-stream token.
int bos_token_id{}; // The id of the beginning-of-stream token.
int sep_token_id{}; // The id of the separation token.
int decoder_start_token_id{}; // If an encoder-decoder model starts decoding with a different token than bos, the id of that token.
int pad_token_id{}; // The id of the padding token.
int eos_token_id{}; // The id of the end-of-stream token.
std::vector<int> eos_token_ids; // If eos_token_id is passed as an array, this is where the values go (eos_token_id gets set to the first entry in the array)
int bos_token_id{}; // The id of the beginning-of-stream token.
int sep_token_id{}; // The id of the separation token.
int decoder_start_token_id{}; // If an encoder-decoder model starts decoding with a different token than bos, the id of that token.
int vocab_size{};
int context_length{};

128 changes: 76 additions & 52 deletions src/models/kernels.cu
@@ -1,87 +1,111 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <stdint.h>
#include <limits>

namespace Generators {
namespace cuda {

template <typename T> __global__ void UpdatePositionIds(T *positions, int batch_beam_size) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < batch_beam_size)
positions[i]++;
template <typename T>
__global__ void UpdatePositionIds(T* positions, int batch_beam_size) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < batch_beam_size)
positions[i]++;
}

template <typename T> void Launch_UpdatePositionIds(T *positions, int batch_beam_size, cudaStream_t stream) {
UpdatePositionIds<T><<<(batch_beam_size + 255) / 256, 256, 0, stream>>>(positions, batch_beam_size);
template <typename T>
void Launch_UpdatePositionIds(T* positions, int batch_beam_size, cudaStream_t stream) {
UpdatePositionIds<T><<<(batch_beam_size + 255) / 256, 256, 0, stream>>>(positions, batch_beam_size);
}

template void Launch_UpdatePositionIds(int32_t *positions, int batch_beam_size, cudaStream_t stream);
template void Launch_UpdatePositionIds(int64_t *positions, int batch_beam_size, cudaStream_t stream);
template void Launch_UpdatePositionIds(int32_t* positions, int batch_beam_size, cudaStream_t stream);
template void Launch_UpdatePositionIds(int64_t* positions, int batch_beam_size, cudaStream_t stream);

template <typename T>
__global__ void CopyAndUpdateAttentionMask(T *mask_data, const T *old_mask_data, int batch_beam_size,
__global__ void CopyAndUpdateAttentionMask(T* mask_data, const T* old_mask_data, int batch_beam_size,
int current_length, int max_length) {
int global_index = blockIdx.x * blockDim.x + threadIdx.x;
int i = global_index / current_length;
int j = global_index % current_length;
if (i < batch_beam_size) {
if (j < current_length - 1) {
mask_data[i * max_length + j] = old_mask_data[i * (current_length - 1) + j];
} else {
mask_data[i * max_length + j] = 1;
}
int global_index = blockIdx.x * blockDim.x + threadIdx.x;
int i = global_index / current_length;
int j = global_index % current_length;
if (i < batch_beam_size) {
if (j < current_length - 1) {
mask_data[i * max_length + j] = old_mask_data[i * (current_length - 1) + j];
} else {
mask_data[i * max_length + j] = 1;
}
}
}

template <typename T>
__global__ void UpdateAttentionMask(T *mask_data, int batch_beam_size, int current_length, int max_length) {
int i = blockIdx.x;
if (i < batch_beam_size) {
mask_data[i * max_length + current_length] = 1;
}
__global__ void UpdateAttentionMask(T* mask_data, int batch_beam_size, int current_length, int max_length) {
int i = blockIdx.x;
if (i < batch_beam_size) {
mask_data[i * max_length + current_length] = 1;
}
}

template <typename T>
void Launch_UpdateAttentionMask(T *mask_data, const T *old_mask_data, int batch_beam_size, int current_length,
void Launch_UpdateAttentionMask(T* mask_data, const T* old_mask_data, int batch_beam_size, int current_length,
int max_length, bool update_only, cudaStream_t stream) {
if (update_only) {
UpdateAttentionMask<T>
<<<batch_beam_size, 1, 0, stream>>>(mask_data, batch_beam_size, current_length, max_length);
} else {
CopyAndUpdateAttentionMask<T><<<(batch_beam_size * max_length + 255) / 256, 256, 0, stream>>>(
mask_data, old_mask_data, batch_beam_size, current_length, max_length);
}
if (update_only) {
UpdateAttentionMask<T>
<<<batch_beam_size, 1, 0, stream>>>(mask_data, batch_beam_size, current_length, max_length);
} else {
CopyAndUpdateAttentionMask<T><<<(batch_beam_size * max_length + 255) / 256, 256, 0, stream>>>(
mask_data, old_mask_data, batch_beam_size, current_length, max_length);
}
}

template void Launch_UpdateAttentionMask(int32_t *mask_data, const int32_t *old_mask_data, int batch_beam_size,
template void Launch_UpdateAttentionMask(int32_t* mask_data, const int32_t* old_mask_data, int batch_beam_size,
int current_length, int max_length, bool update_only, cudaStream_t stream);
template void Launch_UpdateAttentionMask(int64_t *mask_data, const int64_t *old_mask_data, int batch_beam_size,
template void Launch_UpdateAttentionMask(int64_t* mask_data, const int64_t* old_mask_data, int batch_beam_size,
int current_length, int max_length, bool update_only, cudaStream_t stream);

__global__ void ConvertFp16ToFp32(const half *src, float *dst, int count) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < count)
dst[idx] = __half2float(src[idx]);
__global__ void HandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index >= batch_beam_size)
return;

float* logits = batch_logits + index * vocab_size;
float max = std::numeric_limits<float>::lowest();
for (int i = 0; i < eos_token_ids_count; i++) {
max = std::max(max, logits[eos_token_ids[i]]);
logits[eos_token_ids[i]] = std::numeric_limits<float>::lowest(); // Set all EOS token options to never happen (the first will get the max of all)
}

logits[eos_token_ids[0]] = max; // Set the score of the primary EOS token to the highest of any of the EOS tokens
}

void LaunchFp16ToFp32(const uint16_t *fp16, float *fp32, int count, cudaStream_t stream) {
int block_size = 256;
int num_blocks = (count + block_size - 1) / block_size;
ConvertFp16ToFp32<<<num_blocks, block_size, 0, stream>>>(reinterpret_cast<const half *>(fp16), fp32, count);
void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream) {
HandleEOSArray<<<(batch_beam_size + 255) / 256, 256, 0, stream>>>(batch_logits, batch_beam_size, vocab_size, eos_token_ids, eos_token_ids_count);
}
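A minimal host-side sketch of the kernel's effect (hypothetical values, error checking omitted, and the main() harness is not part of this change; the declaration mirrors src/models/kernels.h). With eos_token_ids = {1, 2}, the launch folds the best EOS score into the primary id and suppresses the rest:

#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

namespace Generators { namespace cuda {
// Declared in src/models/kernels.h.
void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size,
const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream);
}}  // namespace Generators::cuda

int main() {
float h_logits[] = {0.1f, 0.2f, 0.9f, 0.4f};  // batch_beam_size = 1, vocab_size = 4
int32_t h_eos[] = {1, 2};  // h_eos[0] is the primary EOS id
float* d_logits;
int32_t* d_eos;
cudaMalloc(&d_logits, sizeof(h_logits));
cudaMalloc(&d_eos, sizeof(h_eos));
cudaMemcpy(d_logits, h_logits, sizeof(h_logits), cudaMemcpyHostToDevice);
cudaMemcpy(d_eos, h_eos, sizeof(h_eos), cudaMemcpyHostToDevice);
Generators::cuda::LaunchHandleEOSArray(d_logits, 1, 4, d_eos, 2, /*stream*/ 0);
cudaMemcpy(h_logits, d_logits, sizeof(h_logits), cudaMemcpyDeviceToHost);  // syncs the default stream
std::printf("%g %g %g %g\n", h_logits[0], h_logits[1], h_logits[2], h_logits[3]);
// Prints 0.1 0.9 -3.40282e+38 0.4: index 1 took max(0.2f, 0.9f), index 2 can never win argmax.
cudaFree(d_logits);
cudaFree(d_eos);
}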

__global__ void ConvertInt32ToInt64(const int32_t *src, int64_t *dst, int count) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < count) {
dst[idx] = src[idx];
}
__global__ void ConvertFp16ToFp32(const half* src, float* dst, int count) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < count)
dst[idx] = __half2float(src[idx]);
}

void LaunchFp16ToFp32(const uint16_t* fp16, float* fp32, int count, cudaStream_t stream) {
int block_size = 256;
int num_blocks = (count + block_size - 1) / block_size;
ConvertFp16ToFp32<<<num_blocks, block_size, 0, stream>>>(reinterpret_cast<const half*>(fp16), fp32, count);
}

__global__ void ConvertInt32ToInt64(const int32_t* src, int64_t* dst, int count) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < count) {
dst[idx] = src[idx];
}
}

void LaunchInt32ToInt64(const int32_t *src, int64_t *dst, int count, cudaStream_t stream) {
int block_size = 256;
int num_blocks = (count + block_size - 1) / block_size;
ConvertInt32ToInt64<<<num_blocks, block_size, 0, stream>>>(src, dst, count);
void LaunchInt32ToInt64(const int32_t* src, int64_t* dst, int count, cudaStream_t stream) {
int block_size = 256;
int num_blocks = (count + block_size - 1) / block_size;
ConvertInt32ToInt64<<<num_blocks, block_size, 0, stream>>>(src, dst, count);
}

} // namespace cuda
} // namespace Generators
} // namespace cuda
} // namespace Generators
4 changes: 4 additions & 0 deletions src/models/kernels.h
@@ -1,3 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace Generators {

@@ -9,6 +11,8 @@ template <typename T>
void Launch_UpdateAttentionMask(T* mask_data, const T* old_mask_data, int batch_beam_size, int current_length,
int max_length, bool update_only, cudaStream_t stream);

void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream);

void LaunchFp16ToFp32(const uint16_t* fp16, float* fp32, int count, cudaStream_t stream);
void LaunchInt32ToInt64(const int32_t* src, int64_t* dst, int count, cudaStream_t stream);
} // namespace cuda
61 changes: 51 additions & 10 deletions src/models/logits.cpp
@@ -1,6 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "../generators.h"
#include "model.h"
#include "logits.h"
#if USE_CUDA
#include "kernels.h"
#endif

namespace Generators {

@@ -23,6 +28,14 @@ Logits::Logits(const Model& model, State& state)
sb_logits16_ = state_.GetCapturedGraphInfo()->sb_logits16_.get();
}
}

#if USE_CUDA
if (model_.device_type_ == DeviceType::CUDA && !model_.config_->model.eos_token_ids.empty()) {
auto& cpu_ids = model_.config_->model.eos_token_ids;
cuda_eos_token_ids_ptr_ = CudaMallocArray<int32_t>(cpu_ids.size(), &cuda_eos_token_ids_);
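// cpu_ids outlives the async copy (it is owned by the config), and the copy shares
// model_.cuda_stream_ with the LaunchHandleEOSArray call in Get(), so stream ordering
// makes the ids ready before the kernel reads them.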
cudaMemcpyAsync(cuda_eos_token_ids_.data(), cpu_ids.data(), cpu_ids.size() * sizeof(int32_t), ::cudaMemcpyHostToDevice, model_.cuda_stream_);
}
#endif
}

RoamingArray<float> Logits::Get() {
@@ -76,7 +89,7 @@ RoamingArray<float> Logits::Get() {

for (int beam_index = 0; beam_index < num_beams; beam_index++) {
switch (model_.device_type_) {
#ifdef USE_DML
#if USE_DML
case DeviceType::DML: {
ComPtr<ID3D12Resource> source_resource;
Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, value32_->GetTensorMutableRawData(), &source_resource));
@@ -126,33 +139,61 @@ RoamingArray<float> Logits::Get() {
if (type_ == Ort::TypeToTensorType<Ort::Float16_t>::type)
value16_ = !sb_logits16_ ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_)
: sb_logits16_->CreateTensorOnStaticBuffer(shape_, type_);

state_.outputs_[output_index_] = type_ == Ort::TypeToTensorType<float>::type ? value32_.get() : value16_.get();
element_count = shape_[0] * shape_[2]; // shape_[1] is now 1, so the element count must be updated
}

assert(shape_[1] == 1);

#if USE_CUDA
if (model_.device_type_ == DeviceType::CUDA)
return gpu_span<float>{value32_->GetTensorMutableData<float>(), element_count};
if (model_.device_type_ == DeviceType::CUDA) {
auto batched_logits_gpu = gpu_span<float>{value32_->GetTensorMutableData<float>(), element_count};
if (cuda_eos_token_ids_ptr_)
cuda::LaunchHandleEOSArray(batched_logits_gpu.data(), static_cast<int>(shape_[0]) /* batch_beam_size */, static_cast<int>(shape_[2]) /* vocab_size */, cuda_eos_token_ids_.data(), static_cast<int>(cuda_eos_token_ids_.size()), model_.cuda_stream_);
return batched_logits_gpu;
}
#elif USE_DML
auto cpu_tensor = value32_cpu_->GetTensorMutableData<float>();
if (model_.device_type_ == DeviceType::DML) {
// DML doesn't support on-device scoring yet, so we transfer the data to the CPU
ComPtr<ID3D12Resource> gpu_resource;
Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, value32_->GetTensorMutableRawData(), &gpu_resource));

size_t new_element_count = shape_[0] * shape_[1] * shape_[2];
auto cpu_tensor = value32_cpu_->GetTensorMutableData<float>();

model_.GetDmlReadbackHeap()->ReadbackFromGpu(
std::span(reinterpret_cast<uint8_t*>(cpu_tensor), new_element_count * sizeof(float)),
std::span(reinterpret_cast<uint8_t*>(cpu_tensor), element_count * sizeof(float)),
gpu_resource.Get(),
0,
D3D12_RESOURCE_STATE_UNORDERED_ACCESS);

return cpu_span<float>{cpu_tensor, new_element_count};
auto batched_logits_cpu = cpu_span<float>{cpu_tensor, element_count};
HandleEOSArray(batched_logits_cpu);
return batched_logits_cpu;
}
#endif

return cpu_span<float>{value32_->GetTensorMutableData<float>(), element_count};
auto batched_logits_cpu = cpu_span<float>{value32_->GetTensorMutableData<float>(), element_count};
HandleEOSArray(batched_logits_cpu);
return batched_logits_cpu;
}

void Logits::HandleEOSArray(cpu_span<float> batched_logits) {
if (model_.config_->model.eos_token_ids.empty())
return;

const size_t vocab_size = shape_[2];
size_t vocab_index = 0; // Simpler math to have this index go up by vocab_size for every logit chunk we process

for (int index = 0; index < shape_[0]; index++) {
auto logits = batched_logits.subspan(vocab_index, vocab_size);
float max = std::numeric_limits<float>::lowest();
for (auto id : model_.config_->model.eos_token_ids) {
max = std::max(max, logits[id]);
logits[id] = std::numeric_limits<float>::lowest(); // Set all EOS token options to never happen (the first will get the max of all)
}

logits[model_.config_->model.eos_token_id] = max; // Set the score of the primary EOS token to the highest of any of the EOS tokens
vocab_index += vocab_size;
}
}
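The fold is what lets the rest of the pipeline stay unaware of multiple EOS ids: a secondary id's logit becomes the lowest float, so greedy search can never pick it and samplers select it with effectively zero probability. A done-check therefore only needs the primary id, as in this hypothetical sketch (not code from this change):

// Sketch: after HandleEOSArray, comparing against the single primary id suffices.
bool IsEndOfStream(int32_t next_token, const Config::Model& model) {
return next_token == model.eos_token_id;
}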

void Logits::Add() {
9 changes: 9 additions & 0 deletions src/models/logits.h
@@ -1,3 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

#include "static_buffer.h"
@@ -11,6 +13,8 @@ struct Logits {
RoamingArray<float> Get();

private:
void HandleEOSArray(cpu_span<float> logits);

const Model& model_;
State& state_;
size_t output_index_{~0U};
@@ -24,6 +28,11 @@ struct Logits {
StaticBuffer* sb_logits32_{};
StaticBuffer* sb_logits16_{};

#if USE_CUDA
cuda_unique_ptr<int32_t> cuda_eos_token_ids_ptr_; // eos_token_ids from the model config, copied into CUDA-accessible memory
gpu_span<int32_t> cuda_eos_token_ids_;
#endif

#if USE_DML
DmlReusedCommandListState logits_cast_command_list_state_{};
std::unique_ptr<OrtValue> value32_cpu_;
