Skip to content

Commit

Permalink
merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
aciddelgado committed Feb 16, 2024
2 parents 5164bf0 + c0cfc7a commit a0a8b5c
Show file tree
Hide file tree
Showing 19 changed files with 1,298 additions and 149 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ if(USE_CUDA AND CMAKE_CUDA_COMPILER)
"${MODELS_ROOT}/*.cu"
"${MODELS_ROOT}/*.cuh"
)
file(GLOB test_cuda_srcs CONFIGURE_DEPENDS
"${TESTS_ROOT}/*.cu"
"${TESTS_ROOT}/*.cuh"
)
list(APPEND test_srcs ${test_cuda_srcs})
list(APPEND generator_srcs ${generator_cuda_srcs})
add_compile_definitions(USE_CUDA=1)
include_directories("${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}")
Expand Down
71 changes: 35 additions & 36 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,10 @@ This library provides the generative AI loop for ONNX models, including inferenc

Users can call a high level `generate()` method, or run each iteration of the model in a loop.

* Search techniques like greedy/beam search to generate token sequences
* Built in scoring tools like repetition penalties
* Supports greedy/beam search and TopP/TopK sampling to generate token sequences
* Built-in logits processing, such as repetition penalties
* Easy custom scoring

## Sample code for phi-2 in Python

Install onnxruntime-genai.

(Temporary) Build and install from source according to the instructions below.


```python
import onnxruntime_genai as og

model=og.Model(f'models/microsoft/phi-2', device_type)

tokenizer = model.create_tokenizer()

prompt = '''def print_prime(n):
"""
Print all primes between 1 and n
"""'''

tokens = tokenizer.encode(prompt)

params=og.SearchParams(model)
params.max_length = 200
params.input_ids = tokens

output_tokens=model.generate(params)

text = tokenizer.decode(output_tokens)

print("Output:")
print(text)
```


## Features

* Supported model architectures:
Expand Down Expand Up @@ -126,6 +92,39 @@ huggingface-cli login --token <your HuggingFace token>
python export.py -m microsoft/phi-2 -p int4 -e cpu -o phi2-int4-cpu.onnx
```

## Sample code for phi-2 in Python

Install onnxruntime-genai.

(Temporary) Build and install from source according to the instructions below.


```python
import onnxruntime_genai as og

model=og.Model(f'models/microsoft/phi-2', device_type)

tokenizer = model.create_tokenizer()

prompt = '''def print_prime(n):
"""
Print all primes between 1 and n
"""'''

tokens = tokenizer.encode(prompt)

params=og.SearchParams(model)
params.max_length = 200
params.input_ids = tokens

output_tokens=model.generate(params)

text = tokenizer.decode(output_tokens)

print("Output:")
print(text)
```


## Contributing

Expand Down
58 changes: 7 additions & 51 deletions src/beam_search_topk.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,51 +6,7 @@
namespace Generators {
namespace cuda {

// Fixed-capacity top-k accumulator kept by a single thread.
// `value` is maintained in descending order, with `key` holding the matching
// element ids. A key of -1 marks an unused slot (see Init). Ties on value
// are resolved in favour of the smaller element id.
template <typename T, int max_k>
struct TopK {
  int32_t key[max_k];
  T value[max_k];

  // Offers (elem, elem_id) to the tracker: it takes the last (smallest)
  // slot if it wins the comparison, then is bubbled toward the front so the
  // arrays stay sorted.
  __device__ __forceinline__ void Insert(T elem, int elem_id) {
    T v = value[max_k - 1];
    // Claim the smallest slot when the candidate is strictly larger, the
    // slot is still empty (key == -1), or it ties on value with a smaller id.
    if (v < elem ||
        (key[max_k - 1] == -1) ||
        ((elem == value[max_k - 1]) && (elem_id < key[max_k - 1]))) {
      value[max_k - 1] = elem;
      key[max_k - 1] = elem_id;
    }

    // Single insertion-sort sweep from the tail: swap adjacent entries until
    // descending-value order (smaller-id tie-break) is restored.
    for (int k = max_k - 2; k >= 0; --k) {
      if (value[k + 1] > value[k] ||
          key[k] == -1 ||
          ((value[k + 1] == value[k]) && (key[k + 1] < key[k]))) {
        T u2 = value[k];
        int p2 = key[k];
        value[k] = value[k + 1];
        key[k] = key[k + 1];
        value[k + 1] = u2;
        key[k + 1] = p2;
      }
    }
  }

  // Marks every slot empty: key = -1 and value = -infinity, so any real
  // candidate passed to Insert afterwards wins the slot.
  __device__ __forceinline__ void Init() {
    for (int i = 0; i < max_k; i++) {
      key[i] = -1;
      value[i] = -std::numeric_limits<T>::infinity();
    }
  }
};

// Reduction operator for combining two partial top-k results: every
// (value, key) pair held by `b` is offered to a copy of `a`, so the result
// contains the best max_k entries across both inputs.
template <typename T, int max_k>
__device__ __forceinline__ TopK<T, max_k> reduce_topk_op(const TopK<T, max_k>& a, const TopK<T, max_k>& b) {
  TopK<T, max_k> merged = a;
  for (int idx = 0; idx != max_k; ++idx) {
    merged.Insert(b.value[idx], b.key[idx]);
  }
  return merged;
}

// kernel to compute the top k on last axis for tensor with shape: [batch, beam_size, parts_of_vocab, vacab_part_size]
// kernel to compute the top k on last axis for tensor with shape: [batch, beam_size, parts_of_vocab, vocab_part_size]
// Its grid is [batch * beam_size, parts_of_vocab]
template <typename T, int max_k, int thread_block_size>
__launch_bounds__(thread_block_size) __global__ void BeamSearchOnlineTopKStage1Kernel(
Expand Down Expand Up @@ -319,18 +275,18 @@ void BeamSearchTopK(
tmp_indices_2nd_stage, \
tmp_values_1st_stage, \
tmp_indices_1st_stage, \
stream);
stream)

if (k <= 4) {
TopKLauncher(4)
TopKLauncher(4);
} else if (k <= 8) {
TopKLauncher(8)
TopKLauncher(8);
} else if (k <= 16) {
TopKLauncher(16)
TopKLauncher(16);
} else if (k <= 32) {
TopKLauncher(32)
TopKLauncher(32);
} else {
TopKLauncher(64)
TopKLauncher(64);
}

LaunchBatchTopKKernel(tmp_values_2nd_stage,
Expand Down
44 changes: 44 additions & 0 deletions src/beam_search_topk.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,49 @@ void BeamSearchTopK(
int32_t* output_indices,
cudaStream_t stream);

// Per-thread top-k tracker: stores the max_k largest (value, key) pairs seen
// so far. `value` is kept sorted in descending order and `key` carries the
// matching element ids; a key of -1 denotes an empty slot (see Init).
// Equal values are ordered by the smaller element id.
template <typename T, int max_k>
struct TopK {
  int32_t key[max_k];
  T value[max_k];

  // Inserts (elem, elem_id) if it beats the currently smallest tracked
  // entry (or fills an empty slot), then performs one insertion-sort pass
  // to restore descending order.
  __device__ __forceinline__ void Insert(T elem, int elem_id) {
    T v = value[max_k - 1];
    // Replace the last (smallest) slot when the new element is larger, the
    // slot is unused (key == -1), or it ties on value with a smaller id.
    if (v < elem ||
        (key[max_k - 1] == -1) ||
        ((elem == value[max_k - 1]) && (elem_id < key[max_k - 1]))) {
      value[max_k - 1] = elem;
      key[max_k - 1] = elem_id;
    }

    // Bubble the freshly written entry toward the front until descending
    // value order (with the smaller-id tie-break) holds again.
    for (int k = max_k - 2; k >= 0; --k) {
      if (value[k + 1] > value[k] ||
          key[k] == -1 ||
          ((value[k + 1] == value[k]) && (key[k + 1] < key[k]))) {
        T u2 = value[k];
        int p2 = key[k];
        value[k] = value[k + 1];
        key[k] = key[k + 1];
        value[k + 1] = u2;
        key[k + 1] = p2;
      }
    }
  }

  // Resets all slots to the empty state (key = -1, value = -infinity) so
  // that any genuine element offered to Insert afterwards is accepted.
  __device__ __forceinline__ void Init() {
    for (int i = 0; i < max_k; i++) {
      key[i] = -1;
      value[i] = -std::numeric_limits<T>::infinity();
    }
  }
};

// Merges two partial TopK accumulators into one. Starting from a copy of
// `a`, each of `b`'s candidates is inserted in turn, leaving the overall
// best max_k entries; intended as the binary operator of a top-k reduction.
template <typename T, int max_k>
__device__ __forceinline__ TopK<T, max_k> reduce_topk_op(const TopK<T, max_k>& a, const TopK<T, max_k>& b) {
  TopK<T, max_k> combined = a;
  for (int slot = 0; slot < max_k; ++slot)
    combined.Insert(b.value[slot], b.key[slot]);
  return combined;
}

} // namespace cuda
} // namespace Generators
Loading

0 comments on commit a0a8b5c

Please sign in to comment.