Commit

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP into add_split_param
DesmonDay committed Oct 14, 2024
2 parents 9fdaae2 + d46bc06 commit 9a210db
Showing 29 changed files with 881 additions and 173 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -154,7 +154,7 @@ The Unified Checkpoint large model storage format supports dynamic… in model parameter distribution
### pip 安装

```shell
pip install --upgrade paddlenlp==3.0.0b1
pip install --upgrade paddlenlp==3.0.0b2
```

或者可通过以下命令安装最新 develop 分支代码:
2 changes: 1 addition & 1 deletion README_en.md
@@ -68,7 +68,7 @@ Detailed list 👉 [Supported Model List](https://github.com/PaddlePaddle/Paddle
### Pip Installation

```shell
pip install --upgrade paddlenlp==3.0.0b1
pip install --upgrade paddlenlp==3.0.0b2
```

or you can install the latest develop branch code with the following command:
67 changes: 30 additions & 37 deletions csrc/gpu/sample_kernels/sampling.cuh
@@ -33,6 +33,15 @@ namespace sampling {

using namespace cub;

#define DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, ...) \
if (compute_capacity.first >= 8) { \
constexpr uint32_t BLOCK_THREADS = 1024; \
__VA_ARGS__ \
} else { \
constexpr uint32_t BLOCK_THREADS = 512; \
__VA_ARGS__ \
}
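
// Dispatch helper: the caller's body is compiled for both block sizes, and at
// runtime BLOCK_THREADS resolves to 1024 on compute capability 8.x and newer
// (e.g. Ampere), or to 512 on older architectures.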

constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS;
constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS;

@@ -277,17 +286,12 @@ template <uint32_t BLOCK_THREADS,
__global__ void TopPSamplingFromProbKernel(DType* probs,
DType* uniform_samples,
IdType* output,
bool* success,
IdType* row_indices,
float* top_p_arr,
float* top_p_val,
uint32_t d,
uint32_t max_top_p_rounds) {
const uint32_t batch_size = gridDim.x;
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
float top_p = (top_p_arr == nullptr) ? top_p_val[bx] : top_p_arr[bx];

const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
float top_p = top_p_val[bx];

extern __shared__ __align__(alignof(SamplingTempStorage<DType,
BLOCK_THREADS,
@@ -313,7 +317,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs,
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.load(probs + row_idx * d +
probs_vec.load(probs + bx * d +
(i * BLOCK_THREADS + tx) * VEC_SIZE);
}

@@ -330,58 +334,51 @@
}
__syncthreads();
sampled_id = temp_storage.data.sampled_id;
pivot = max(pivot, probs[row_idx * d + sampled_id]);
pivot = max(pivot, probs[bx * d + sampled_id]);
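    // Raise the pivot to the drawn token's probability; later rounds sample
    // only from tokens with strictly higher probability.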

DType aggregate_gt_pivot = DType(0);
Pair<DType> aggregate_gt_pivot{DType(0), 0};
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.load(probs + row_idx * d +
(i * BLOCK_THREADS + tx) * VEC_SIZE);
probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}

DType probs_gt_pivot[VEC_SIZE];
Pair<DType> probs_gt_pivot[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? probs_vec[j] : DType(0),
(probs_vec[j] > pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
}

aggregate_gt_pivot +=
BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot);
aggregate_gt_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_pair)
.Sum<VEC_SIZE>(probs_gt_pivot);
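    // The pair reduction tracks both the probability mass above the pivot
    // (value) and how many entries remain above it (count); count drives the
    // top_p == 0 fallback below.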
if (tx == 0) {
temp_storage.data.block_aggregate.value = aggregate_gt_pivot;
temp_storage.data.block_aggregate.pair = aggregate_gt_pivot;
}
__syncthreads();
}
q = temp_storage.data.block_aggregate.value;
if (float(q) < top_p) {
q = temp_storage.data.block_aggregate.pair.value;
if (float(q) > 0 && float(q) < top_p) {
// top_p is not 0
break;
} else {
// top_p is 0, use top_k, k=1
if (temp_storage.data.block_aggregate.pair.count < 1) {
break;
}
}
}
__syncthreads();
if (tx == 0) {
output[bx] = sampled_id;
if (float(q) >= top_p) {
// failed to sample within MAX_TOP_P_ROUNDS
if (success != nullptr) {
success[bx] = false;
}
} else {
if (success != nullptr) {
success[bx] = true;
}
}
}
}


template <typename T, typename IdType>
cudaError_t TopPSamplingFromProb(T* probs,
T* uniform_samples,
IdType* output,
bool* success,
T* top_p_arr,
uint32_t batch_size,
const T* top_p_val,
uint32_t d,
@@ -395,13 +392,9 @@ cudaError_t TopPSamplingFromProb(T* probs,
sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
IdType* row_indices_placeholder = nullptr;
void* args[] = {&probs,
&uniform_samples,
&output,
&success,
&row_indices_placeholder,
&top_p_arr,
&top_p_val,
&d,
&max_top_p_rounds};
@@ -425,4 +418,4 @@ cudaError_t TopPSamplingFromProb(T* probs,
return cudaSuccess;
}

} // namespace sampling
} // namespace sampling
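
For orientation, here is a minimal NumPy sketch of the rejection-style top-p loop this kernel implements. The function and driver names are illustrative, not part of the CUDA source; the `top_p == 0` branch mirrors the kernel's top-1 fallback:

```python
import numpy as np

def top_p_sample_reject(probs, top_p, max_rounds=32, seed=0):
    """Illustrative rejection-style top-p sampling over one probability row."""
    rng = np.random.default_rng(seed)
    pivot = 0.0
    sampled_id = int(np.argmax(probs))  # safe default if all rounds are spent
    for _ in range(max_rounds):
        # Restrict to tokens strictly above the current pivot.
        masked = np.where(probs > pivot, probs, 0.0)
        total = masked.sum()
        if total == 0.0:
            break  # nothing above the pivot: only the top-1 token remains
        # Inverse-CDF draw from the masked distribution (one uniform per round).
        u = rng.uniform() * total
        sampled_id = int(np.searchsorted(np.cumsum(masked), u))
        pivot = max(pivot, float(probs[sampled_id]))
        # Accept once the mass still above the pivot drops below top_p.
        q = np.where(probs > pivot, probs, 0.0).sum()
        if 0.0 < q < top_p:
            break
    return sampled_id

# Example: one draw from a normalized random row with top_p = 0.9.
p = np.random.rand(40080).astype(np.float32)
p /= p.sum()
print(top_p_sample_reject(p, top_p=0.9))
```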
55 changes: 27 additions & 28 deletions csrc/gpu/sample_kernels/top_p_sampling_reject.cu
@@ -16,48 +16,46 @@
#include "sample_kernels/sampling.cuh"

std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor& probs,
const paddle::Tensor& top_p) {
const paddle::Tensor& top_p,
int seed) {
std::vector<int64_t> probs_shape = probs.shape();
unsigned int batch_size = probs_shape[0];
unsigned int vocab_size = probs_shape[1];

// default is 32
unsigned int max_top_p_rounds = 32;
std::vector<int64_t> uniform_samples_shape = {batch_size, max_top_p_rounds};
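  // One uniform draw is budgeted per (row, round) pair; each rejection round
  // consumes a fresh sample.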
paddle::Tensor uniform_samples = paddle::experimental::uniform(
uniform_samples_shape, paddle::DataType::FLOAT32, 0, 1, 0, probs.place());
paddle::Tensor uniform_samples =
paddle::experimental::uniform(uniform_samples_shape,
paddle::DataType::FLOAT32,
0,
1,
seed,
probs.place());

// todo: add parameter for deterministic, now default is true
bool deterministic = true;
paddle::Tensor probs_input;

probs_input = paddle::experimental::cast(probs, paddle::DataType::FLOAT32);
auto cu_stream = probs.stream();

auto samples =
paddle::full({batch_size}, 0, paddle::DataType::INT32, probs.place());
auto success =
paddle::full({batch_size}, 0, paddle::DataType::BOOL, probs.place());
paddle::empty({batch_size, 1}, paddle::DataType::INT64, probs.place());

cudaError_t status;

cudaError_t status =
sampling::TopPSamplingFromProb<float, int>(probs_input.data<float>(),
uniform_samples.data<float>(),
samples.data<int>(),
success.data<bool>(),
nullptr,
batch_size,
top_p.data<float>(),
vocab_size,
max_top_p_rounds,
deterministic,
cu_stream);
status = sampling::TopPSamplingFromProb<float, int64_t>(
const_cast<float*>(probs.data<float>()),
uniform_samples.data<float>(),
samples.data<int64_t>(),
batch_size,
top_p.data<float>(),
vocab_size,
max_top_p_rounds,
true,
cu_stream);

PD_CHECK(status == cudaSuccess,
"SamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));

paddle::Tensor samples_output;
samples_output = paddle::experimental::cast(samples, paddle::DataType::INT64);
return {samples_output};
return {samples};
}

std::vector<std::vector<int64_t>> TopPSamplingRejectInferShape(
@@ -69,12 +67,13 @@ std::vector<std::vector<int64_t>> TopPSamplingRejectInferShape(

std::vector<paddle::DataType> TopPSamplingRejectInferDtype(
const paddle::DataType& probs_dtype, const paddle::DataType& top_p_shape) {
return {probs_dtype};
return {paddle::DataType::INT64};
}

PD_BUILD_OP(top_p_sampling_reject)
.Inputs({"probs", "top_p"})
.Outputs({"samples"})
.Attrs({"seed: int"})
.SetKernelFn(PD_KERNEL(TopPSamplingReject))
.SetInferShapeFn(PD_INFER_SHAPE(TopPSamplingRejectInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(TopPSamplingRejectInferDtype));
.SetInferDtypeFn(PD_INFER_DTYPE(TopPSamplingRejectInferDtype));
63 changes: 63 additions & 0 deletions csrc/gpu/test/python/test_top_p_sampling_reject.py
@@ -0,0 +1,63 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle
from paddlenlp_ops import top_p_sampling_reject

paddle.seed(2023)

batch_size = 3
vocab_size = 40080
max_rounds = 32

class TestTopPSamplingReject(unittest.TestCase):
def test_top_p_sampling_reject_case1(self):
# top_p is 1, different seeds
pre_norm_prob_np = np.random.rand(batch_size, vocab_size).astype(np.float32)

paddle_pre_norm_prob = paddle.to_tensor(pre_norm_prob_np)
paddle_norm_prob = paddle_pre_norm_prob / paddle_pre_norm_prob.sum(axis=-1, keepdim=True)
top_p_paddle = paddle.full((batch_size,), 1)
samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 0)
print(samples)
samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 1024)
print(samples)
samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 2033)
print(samples)

def test_top_p_sampling_reject_case2(self):
# top_p is 0
pre_norm_prob_np = np.random.rand(batch_size, vocab_size).astype(np.float32)

paddle_pre_norm_prob = paddle.to_tensor(pre_norm_prob_np)
paddle_norm_prob = paddle_pre_norm_prob / paddle_pre_norm_prob.sum(axis=-1, keepdim=True)
top_p_paddle = paddle.full((batch_size,), 0)
samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 0)
print(samples)

def test_top_p_sampling_reject_case3(self):
# different top_p values across batches
pre_norm_prob_np = np.random.rand(batch_size, vocab_size).astype(np.float32)

paddle_pre_norm_prob = paddle.to_tensor(pre_norm_prob_np)
paddle_norm_prob = paddle_pre_norm_prob / paddle_pre_norm_prob.sum(axis=-1, keepdim=True)
top_p_paddle = paddle.uniform(shape=[batch_size,1], min=0, max=1)
samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 0)
print(samples)

if __name__ == "__main__":
unittest.main()