diff --git a/CMakeLists.txt b/CMakeLists.txt
index 16ebf4b8b..277fe0659 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -58,6 +58,11 @@ if(USE_CUDA AND CMAKE_CUDA_COMPILER)
     "${MODELS_ROOT}/*.cu"
     "${MODELS_ROOT}/*.cuh"
   )
+  file(GLOB test_cuda_srcs CONFIGURE_DEPENDS
+    "${TESTS_ROOT}/*.cu"
+    "${TESTS_ROOT}/*.cuh"
+  )
+  list(APPEND test_srcs ${test_cuda_srcs})
   list(APPEND generator_srcs ${generator_cuda_srcs})
   add_compile_definitions(USE_CUDA=1)
   include_directories("${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}")
diff --git a/README.md b/README.md
index 2cf827ba5..314afbaa3 100644
--- a/README.md
+++ b/README.md
@@ -6,44 +6,10 @@ This library provides the generative AI loop for ONNX models, including inferenc
 Users can call a high level `generate()` method, or run each iteration of the model in a loop.
 
-* Search techniques like greedy/beam search to generate token sequences
-* Built in scoring tools like repetition penalties
+* Greedy/beam search and top-p/top-k sampling to generate token sequences
+* Built-in logits processing, such as repetition penalties
 * Easy custom scoring
 
-## Sample code for phi-2 in Python
-
-Install onnxruntime-genai.
-
-(Temporary) Build and install from source according to the instructions below.
-
-
-```python
-import onnxruntime_genai as og
-
-model=og.Model(f'models/microsoft/phi-2', device_type)
-
-tokenizer = model.create_tokenizer()
-
-prompt = '''def print_prime(n):
-    """
-    Print all primes between 1 and n
-    """'''
-
-tokens = tokenizer.encode(prompt)
-
-params=og.SearchParams(model)
-params.max_length = 200
-params.input_ids = tokens
-
-output_tokens=model.generate(params)
-
-text = tokenizer.decode(output_tokens)
-
-print("Output:")
-print(text)
-```
-
-
 ## Features
 
 * Supported model architectures:
@@ -126,6 +92,39 @@ huggingface-cli login --token 
 python export.py -m microsoft/phi-2 -p int4 -e cpu -o phi2-int4-cpu.onnx
 ```
 
+## Sample code for phi-2 in Python
+
+Install onnxruntime-genai.
+
+(Temporary) Build and install from source according to the instructions below.
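+
+Note: `device_type` in the example below is a placeholder and must be defined before constructing the model. As a sketch (run after the import in the example below; this assumes the package exposes a `DeviceType` enum, so adjust to whatever device selector your installed build expects):
+
+```python
+# Hypothetical device selector; check the API of the build you installed.
+device_type = og.DeviceType.CPU  # or og.DeviceType.CUDA for GPU builds
+```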
+
+```python
+import onnxruntime_genai as og
+
+model = og.Model('models/microsoft/phi-2', device_type)
+
+tokenizer = model.create_tokenizer()
+
+prompt = '''def print_prime(n):
+    """
+    Print all primes between 1 and n
+    """'''
+
+tokens = tokenizer.encode(prompt)
+
+params = og.SearchParams(model)
+params.max_length = 200
+params.input_ids = tokens
+
+output_tokens = model.generate(params)
+
+text = tokenizer.decode(output_tokens)
+
+print("Output:")
+print(text)
+```
+
 ## Contributing
diff --git a/src/beam_search_topk.cu b/src/beam_search_topk.cu
index 1dae92cc4..222561ce8 100644
--- a/src/beam_search_topk.cu
+++ b/src/beam_search_topk.cu
@@ -6,51 +6,7 @@
 namespace Generators {
 namespace cuda {
 
-template <typename T, int max_k>
-struct TopK {
-  int32_t key[max_k];
-  T value[max_k];
-
-  __device__ __forceinline__ void Insert(T elem, int elem_id) {
-    T v = value[max_k - 1];
-    if (v < elem ||
-        (key[max_k - 1] == -1) ||
-        ((elem == value[max_k - 1]) && (elem_id < key[max_k - 1]))) {
-      value[max_k - 1] = elem;
-      key[max_k - 1] = elem_id;
-    }
-
-    for (int k = max_k - 2; k >= 0; --k) {
-      if (value[k + 1] > value[k] ||
-          key[k] == -1 ||
-          ((value[k + 1] == value[k]) && (key[k + 1] < key[k]))) {
-        T u2 = value[k];
-        int p2 = key[k];
-        value[k] = value[k + 1];
-        key[k] = key[k + 1];
-        value[k + 1] = u2;
-        key[k + 1] = p2;
-      }
-    }
-  }
-
-  __device__ __forceinline__ void Init() {
-    for (int i = 0; i < max_k; i++) {
-      key[i] = -1;
-      value[i] = -std::numeric_limits<T>::infinity();
-    }
-  }
-};
-
-template <typename T, int max_k>
-__device__ __forceinline__ TopK<T, max_k> reduce_topk_op(const TopK<T, max_k>& a, const TopK<T, max_k>& b) {
-  TopK<T, max_k> res = a;
-  for (int i = 0; i < max_k; ++i)
-    res.Insert(b.value[i], b.key[i]);
-  return res;
-}
-
-// kernel to compute the top k on last axis for tensor with shape: [batch, beam_size, parts_of_vocab, vacab_part_size]
+// kernel to compute the top k on last axis for tensor with shape: [batch, beam_size, parts_of_vocab, vocab_part_size]
 // Its grid is [batch * beam_size, parts_of_vocab]
 template <typename T, int max_k, int thread_block_size>
 __launch_bounds__(thread_block_size) __global__ void BeamSearchOnlineTopKStage1Kernel(
@@ -319,18 +275,18 @@ void BeamSearchTopK(
       tmp_indices_2nd_stage, \
       tmp_values_1st_stage, \
       tmp_indices_1st_stage, \
-      stream);
+      stream)
 
   if (k <= 4) {
-    TopKLauncher(4)
+    TopKLauncher(4);
   } else if (k <= 8) {
-    TopKLauncher(8)
+    TopKLauncher(8);
   } else if (k <= 16) {
-    TopKLauncher(16)
+    TopKLauncher(16);
   } else if (k <= 32) {
-    TopKLauncher(32)
+    TopKLauncher(32);
   } else {
-    TopKLauncher(64)
+    TopKLauncher(64);
   }
 
   LaunchBatchTopKKernel(tmp_values_2nd_stage,
diff --git a/src/beam_search_topk.h b/src/beam_search_topk.h
index 4bf6210ce..4cd3f4b62 100644
--- a/src/beam_search_topk.h
+++ b/src/beam_search_topk.h
@@ -20,5 +20,49 @@ void BeamSearchTopK(
     int32_t* output_indices,
     cudaStream_t stream);
 
+template <typename T, int max_k>
+struct TopK {
+  int32_t key[max_k];
+  T value[max_k];
+
+  __device__ __forceinline__ void Insert(T elem, int elem_id) {
+    // If elem beats the current smallest entry (or a slot is still empty),
+    // overwrite the last slot, then bubble it up to keep the list sorted.
+    T v = value[max_k - 1];
+    if (v < elem ||
+        (key[max_k - 1] == -1) ||
+        ((elem == value[max_k - 1]) && (elem_id < key[max_k - 1]))) {
+      value[max_k - 1] = elem;
+      key[max_k - 1] = elem_id;
+    }
+
+    for (int k = max_k - 2; k >= 0; --k) {
+      if (value[k + 1] > value[k] ||
+          key[k] == -1 ||
+          ((value[k + 1] == value[k]) && (key[k + 1] < key[k]))) {
+        T u2 = value[k];
+        int p2 = key[k];
+        value[k] = value[k + 1];
+        key[k] = key[k + 1];
+        value[k + 1] = u2;
+        key[k + 1] = p2;
+      }
+    }
+  }
+
+  __device__ __forceinline__ void Init() {
+    for (int i = 0; i < max_k; i++) {
+      key[i] = -1;
+      value[i] = -std::numeric_limits<T>::infinity();
+    }
+  }
+};
+
+template <typename T, int max_k>
+__device__ __forceinline__ TopK<T, max_k> reduce_topk_op(const TopK<T, max_k>& a, const TopK<T, max_k>& b) {
+  TopK<T, max_k> res = a;
+  for (int i = 0; i < max_k; ++i)
+    res.Insert(b.value[i], b.key[i]);
+  return res;
+}
+
 }  // namespace cuda
 }  // namespace Generators
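In the header above, `TopK` keeps a small list sorted by score that `Insert` updates in place, and `reduce_topk_op` merges two such lists, which is what lets a kernel reduce per-thread top-k candidates down to a single result. A minimal sketch of that pattern (illustrative only, not part of this patch; the kernel name and the `scores`/`num_items` parameters are hypothetical):

```cuda
#include <cub/cub.cuh>
#include "beam_search_topk.h"

// Each thread scans a strided slice of `scores`, keeping a private top-k;
// cub::BlockReduce then merges the per-thread lists with reduce_topk_op.
template <typename T, int max_k, int block_size>
__global__ void BlockTopKSketch(const T* scores, int num_items,
                                T* out_values, int32_t* out_indices) {
  using TopKT = Generators::cuda::TopK<T, max_k>;
  using BlockReduce = cub::BlockReduce<TopKT, block_size>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  TopKT thread_top_k;
  thread_top_k.Init();
  for (int i = threadIdx.x; i < num_items; i += block_size)
    thread_top_k.Insert(scores[i], i);

  // Merge all per-thread candidate lists into one block-wide top-k.
  TopKT block_top_k = BlockReduce(temp_storage)
                          .Reduce(thread_top_k, Generators::cuda::reduce_topk_op<T, max_k>);

  if (threadIdx.x == 0) {
    for (int i = 0; i < max_k; ++i) {
      out_values[i] = block_top_k.value[i];
      out_indices[i] = block_top_k.key[i];
    }
  }
}
```

A launch such as `BlockTopKSketch<float, 4, 256><<<1, 256>>>(scores, vocab_size, values, indices);` would then leave the four best scores and their indices in `values`/`indices`.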
diff --git a/src/cuda_sampling.cu b/src/cuda_sampling.cu
new file mode 100644
index 000000000..15636b928
--- /dev/null
+++ b/src/cuda_sampling.cu
@@ -0,0 +1,596 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <cuda_runtime.h>
+#include <cub/cub.cuh>
+#include <curand_kernel.h>
+#include <cmath>
+#include <cassert>
+#include "beam_search_topk.h"
+#include "cuda_sampling.cuh"
+#include "smartptrs.h"
+#include <algorithm>
+#include <iostream>
+#include <limits>
+#include <numeric>
+
+namespace Generators {
+namespace cuda {
+
+constexpr int kMaxThreads = 1024;
+constexpr int kGPUWarpSize = 32;
+
+SamplingData::SamplingData(int batch_size, int vocab_size, cudaStream_t stream) {
+  indices_sorted = CudaMallocArray<int>(vocab_size * batch_size);
+  scores_sorted = CudaMallocArray<float>(vocab_size * batch_size);
+  scores_softmaxed = CudaMallocArray<float>(vocab_size * batch_size);
+  prefix_sums = CudaMallocArray<float>(vocab_size * batch_size);
+  thresholds = CudaMallocArray<float>(batch_size);
+  indices_in = CudaMallocArray<int>(vocab_size * batch_size);
+  offsets = CudaMallocArray<int>(batch_size + 1);
+  temp_storage_bytes = 0;
+  // Query pass: with a null temp-storage pointer, CUB only computes the
+  // scratch size needed for the segmented sort and writes it to
+  // temp_storage_bytes; no sorting happens here.
+  cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, (float*)nullptr, (float*)nullptr,
+                                                     (int*)nullptr, (int*)nullptr, vocab_size * batch_size, batch_size,
+                                                     (int*)nullptr, (int*)nullptr, 0, sizeof(float) * 8, stream);
+  temp_buffer = CudaMallocArray<float>(temp_storage_bytes / sizeof(float));
+}
+
+// Softmax Kernels and Launchers
+
+template <typename T, typename AccumT>
+struct MaxFloat {
+  __device__ __forceinline__ AccumT operator()(AccumT max, T v) const {
+    return ::max(max, (AccumT)v);
+  }
+};
+
+template <typename T>
+struct Max {
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a < b ? b : a;
+  }
+};
+
+template <typename T, typename AccumT>
+struct SumExpFloat {
+  __device__ __forceinline__ SumExpFloat(AccumT v)
+      : max_k(v) {}
+
+  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
+    return sum + exp((AccumT)v - max_k);
+  }
+
+  const AccumT max_k;
+};
+
+template <typename T>
+struct Add {
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a + b;
+  }
+};
+
+// aligned vector generates vectorized load/store on CUDA
+template <typename T, int vec_size>
+struct alignas(sizeof(T) * vec_size) aligned_vector {
+  T val[vec_size];
+};
+
+template