Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept left offsets in the rotary embeddings layer #1372

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/ctranslate2/layers/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ namespace ctranslate2 {
const dim_t num_initial_positions = 2048,
const float base = 10000);

void apply(StorageView& x, const dim_t offset = 0);
void apply(StorageView& x, const dim_t step = 0, const StorageView* offsets = nullptr);

private:
void initialize(const dim_t num_positions,
Expand Down
8 changes: 6 additions & 2 deletions include/ctranslate2/ops/rotary.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,18 @@ namespace ctranslate2 {
void operator()(const StorageView& input,
const StorageView& sin,
const StorageView& cos,
StorageView& output) const;
StorageView& output,
const StorageView* offsets = nullptr,
const dim_t step = 0) const;

private:
const dim_t _ndims;
const bool _interleave;

template <Device D, typename T>
void compute(const StorageView& input,
void compute(const dim_t step,
const StorageView* offsets,
const StorageView& input,
const StorageView& sin,
const StorageView& cos,
StorageView& output) const;
Expand Down
14 changes: 3 additions & 11 deletions src/layers/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -602,28 +602,20 @@ namespace ctranslate2 {
{
}

void RotaryEmbeddings::apply(StorageView& x, const dim_t offset) {
void RotaryEmbeddings::apply(StorageView& x, const dim_t step, const StorageView* offsets) {
const Device device = x.device();
const DataType dtype = x.dtype();
const dim_t max_time = x.dim(-2);
const dim_t dim = _dim == 0 ? x.dim(-1) : _dim;

if (!_sin || offset + max_time > _sin.dim(0)) {
if (!_sin || step + max_time > _sin.dim(0)) {
const dim_t cur_num_positions = _sin ? _sin.dim(0) : 0;
const dim_t new_num_positions = cur_num_positions + _num_initial_positions;
initialize(new_num_positions, dim, device, dtype);
}

StorageView sin(dtype, device);
StorageView cos(dtype, device);
TYPE_DISPATCH(dtype,
{
sin.view(_sin.index<T>({offset, 0}), {max_time, dim});
cos.view(_cos.index<T>({offset, 0}), {max_time, dim});
});

StorageView y(dtype, device);
_rotary_op(x, sin, cos, y);
_rotary_op(x, _sin, _cos, y, offsets, step);
x = std::move(y);
}

Expand Down
17 changes: 15 additions & 2 deletions src/ops/rotary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,24 @@ namespace ctranslate2 {
void Rotary::operator()(const StorageView& input,
const StorageView& sin,
const StorageView& cos,
StorageView& output) const {
StorageView& output,
const StorageView* offsets,
const dim_t step) const {
PROFILE("Rotary");

if (offsets) {
const dim_t batch_size = input.size() / (input.dim(-1) * input.dim(-2));
if (offsets->size() != batch_size)
throw std::invalid_argument("Offsets has size "
+ std::to_string(offsets->size())
+ " which is different than the current batch size "
+ std::to_string(batch_size));
}

output.resize_as(input);

DEVICE_AND_FLOAT_DISPATCH("Rotary", input.device(), input.dtype(),
(compute<D, T>(input, sin, cos, output)));
(compute<D, T>(step, offsets, input, sin, cos, output)));
}

}
Expand Down
25 changes: 19 additions & 6 deletions src/ops/rotary_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ namespace ctranslate2 {
const T* sin,
const T* cos,
T* output,
const int32_t* offsets,
const dim_t step,
const dim_t batch_size,
const dim_t max_time,
const dim_t ndims,
Expand All @@ -18,9 +20,15 @@ namespace ctranslate2 {

cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
for (dim_t b = begin; b < end; ++b) {
const dim_t offset = offsets ? offsets[b] : 0;

for (dim_t t = 0; t < max_time; ++t) {
const T* s = sin + t * ndims;
const T* c = cos + t * ndims;
const dim_t signal_time = t - offset + step;
if (signal_time < 0)
continue;

const T* s = sin + signal_time * ndims;
const T* c = cos + signal_time * ndims;

const T* x = input + b * (max_time * depth) + t * depth;
T* y = output + b * (max_time * depth) + t * depth;
Expand All @@ -40,7 +48,9 @@ namespace ctranslate2 {
}

template <Device D, typename T>
void Rotary::compute(const StorageView& input,
void Rotary::compute(const dim_t step,
const StorageView* offsets,
const StorageView& input,
const StorageView& sin,
const StorageView& cos,
StorageView& output) const {
Expand All @@ -52,17 +62,20 @@ namespace ctranslate2 {
const auto* x = input.data<T>();
const auto* s = sin.data<T>();
const auto* c = cos.data<T>();
const auto* o = offsets ? offsets->data<int32_t>() : nullptr;
auto* y = output.data<T>();

if (_interleave)
rotary_kernel<T, true>(x, s, c, y, batch_size, max_time, ndims, depth);
rotary_kernel<T, true>(x, s, c, y, o, step, batch_size, max_time, ndims, depth);
else
rotary_kernel<T, false>(x, s, c, y, batch_size, max_time, ndims, depth);
rotary_kernel<T, false>(x, s, c, y, o, step, batch_size, max_time, ndims, depth);
}

#define DECLARE_IMPL(T) \
template void \
Rotary::compute<Device::CPU, T>(const StorageView&, \
Rotary::compute<Device::CPU, T>(const dim_t, \
const StorageView*, \
const StorageView&, \
const StorageView&, \
const StorageView&, \
StorageView&) const;
Expand Down
26 changes: 20 additions & 6 deletions src/ops/rotary_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,26 @@ namespace ctranslate2 {
const T* sin,
const T* cos,
T* y,
const int32_t* offsets,
const cuda::index_t step,
const cuda::index_t max_time,
const cuda::index_t ndims,
const cuda::index_t depth) {
const auto batch = blockIdx.x / max_time;
const auto time = blockIdx.x % max_time;
const auto middle = ndims / 2;

const int32_t offset = offsets ? offsets[batch] : 0;
const int32_t signal_time = time - offset + step;

if (signal_time < 0)
return;

x += blockIdx.x * depth;
y += blockIdx.x * depth;

sin += time * ndims;
cos += time * ndims;
sin += signal_time * ndims;
cos += signal_time * ndims;

using C = typename ComputeType<T>::type;

Expand All @@ -54,7 +63,9 @@ namespace ctranslate2 {
}

template <Device D, typename T>
void Rotary::compute(const StorageView& input,
void Rotary::compute(const dim_t step,
const StorageView* offsets,
const StorageView& input,
const StorageView& sin,
const StorageView& cos,
StorageView& output) const {
Expand All @@ -68,21 +79,24 @@ namespace ctranslate2 {
const auto* x = cuda::device_cast(input.data<T>());
const auto* s = cuda::device_cast(sin.data<T>());
const auto* c = cuda::device_cast(cos.data<T>());
const auto* o = offsets ? offsets->data<int32_t>() : nullptr;
auto* y = cuda::device_cast(output.data<T>());

using DeviceT = cuda::device_type<T>;

if (_interleave)
rotary_kernel<DeviceT, true><<<blocks, threads, 0, cuda::get_cuda_stream()>>>(
x, s, c, y, max_time, ndims, depth);
x, s, c, y, o, step, max_time, ndims, depth);
else
rotary_kernel<DeviceT, false><<<blocks, threads, 0, cuda::get_cuda_stream()>>>(
x, s, c, y, max_time, ndims, depth);
x, s, c, y, o, step, max_time, ndims, depth);
}

#define DECLARE_IMPL(T) \
template void \
Rotary::compute<Device::CUDA, T>(const StorageView&, \
Rotary::compute<Device::CUDA, T>(const dim_t, \
const StorageView*, \
const StorageView&, \
const StorageView&, \
const StorageView&, \
StorageView&) const;
Expand Down
55 changes: 55 additions & 0 deletions tests/layers_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,61 @@ TEST_P(LayerDeviceFPTest, RotaryEmbedding) {
}
}

TEST_P(LayerDeviceFPTest, RotaryEmbeddingOffset) {
const Device device = GetParam().device;
const DataType dtype = GetParam().dtype;
const float error = GetParam().error;

// The input and expected output were generated from PyTorch using the rotary embeddings layer from
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
//
// q = torch.rand([2, 4, 1, 6])
// print(q.numpy().flatten().tolist())
// position_ids = torch.tensor([[2], [4]])
// llama = LlamaRotaryEmbedding(6)
// cos, sin = llama(q, seq_len=12)
// q, _ = apply_rotary_pos_emb(q, q, cos, sin, position_ids)
// print(q.numpy().flatten().tolist())

const StorageView input({2, 4, 1, 6}, std::vector<float>{
0.23646563291549683, 0.9993839263916016, 0.4034807085990906, 0.5447465777397156,
0.9373598098754883, 0.3172609210014343, 0.19522875547409058, 0.707885205745697,
0.0094565749168396, 0.9327296018600464, 0.4594022035598755, 0.5009559392929077,
0.0743250846862793, 0.5236821174621582, 0.18698054552078247, 0.3285903334617615,
0.6952935457229614, 0.46870940923690796, 0.578666090965271, 0.11945730447769165,
0.16381490230560303, 0.38767993450164795, 0.15953445434570312, 0.5320672392845154,
0.10134690999984741, 0.26156187057495117, 0.9635066986083984, 0.7839735746383667,
0.2869170308113098, 0.5146785378456116, 0.2806260585784912, 0.6367897987365723,
0.9142636656761169, 0.7779543995857239, 0.5855610370635986, 0.23491668701171875,
0.6287166476249695, 0.400934636592865, 0.8011993169784546, 0.4153047204017639,
0.7990701198577881, 0.01711505651473999, 0.19538897275924683, 0.21076786518096924,
0.9088703989982605, 0.8127486109733582, 0.9860213994979858, 0.9132919907569885
}, device);

const StorageView expected({2, 4, 1, 6}, std::vector<float>{
-0.5937410593032837, 0.9081889986991882, 0.40210992097854614, -0.011676982045173645,
1.0259650945663452, 0.31899651885032654, -0.9293724298477173, 0.6622512936592102,
0.00729794055223465, -0.21063147485256195, 0.5230439901351929, 0.5009920597076416,
-0.3297165036201477, 0.4569746255874634, 0.18495920300483704, -0.06915822625160217,
0.7408443093299866, 0.4695107340812683, -0.5933264493942261, 0.10415434092283249,
0.16152077913284302, 0.3648477792739868, 0.16992105543613434, 0.5327681303024292,
0.5270683765411377, 0.20410215854644775, 0.9590356349945068, -0.589138925075531,
0.33027005195617676, 0.5229625701904297, 0.4053283929824829, 0.5177521109580994,
0.9122052788734436, -0.7208834290504456, 0.6930481195449829, 0.24278675019741058,
-0.09665298461914062, 0.24653685092926025, 0.8010221123695374, -0.7472755908966064,
0.8593493103981018, 0.02401886135339737, 0.4873754382133484, 0.025127321481704712,
0.900966227054596, -0.6791187524795532, 1.0079830884933472, 0.9210903644561768
}, device);

const StorageView offsets({2, 4}, std::vector<int32_t>{3, 3, 3, 3, 1, 1, 1, 1}, device);
const dim_t step = 5;

layers::RotaryEmbeddings rotary_embeddings(0, false);
StorageView x = input.to(dtype);
rotary_embeddings.apply(x, step, &offsets);
expect_storage_eq(x.to_float32(), expected, error);
}

TEST(LayerTest, Padder) {
const StorageView lengths({3}, std::vector<int32_t>{2, 3, 1});
const Padder padder(lengths, /*max_time=*/4);
Expand Down