From aaae9ec70d70adbfee16409044e44ce9c045aa2c Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Thu, 27 Jul 2023 13:59:54 +0200
Subject: [PATCH] Accept left offsets in the rotary embeddings layer

---
Review notes:
 * Positions whose table index (t - offset + step) is negative are now copied
   from the input by both kernels instead of being skipped without a write.
 * RotaryEmbeddings::apply now grows the sin/cos tables to at least
   step + max_time positions instead of by a single fixed increment.
 * Hunk headers and the summary below were recomputed for these changes; the
   "index" blob ids were not.

 include/ctranslate2/layers/attention.h |  6 ++-
 include/ctranslate2/ops/rotary.h       |  8 +++-
 src/layers/attention.cc                | 19 +++------
 src/ops/rotary.cc                      | 18 +++++++-
 src/ops/rotary_cpu.cc                  | 31 ++++++++++++---
 src/ops/rotary_gpu.cu                  | 31 ++++++++++++---
 tests/layers_test.cc                   | 55 ++++++++++++++++++++++++++
 7 files changed, 139 insertions(+), 29 deletions(-)

diff --git a/include/ctranslate2/layers/attention.h b/include/ctranslate2/layers/attention.h
index e76e7f1b9..31b13d4b5 100644
--- a/include/ctranslate2/layers/attention.h
+++ b/include/ctranslate2/layers/attention.h
@@ -78,7 +78,11 @@ namespace ctranslate2 {
                        const dim_t num_initial_positions = 2048,
                        const float base = 10000);
 
-      void apply(StorageView& x, const dim_t offset = 0);
+      // Applies the rotary embeddings to x, which is replaced by the result.
+      // The sin/cos tables are read at index (t + step - offsets[b]) for time
+      // index t of batch entry b; a null offsets means an offset of 0 for
+      // every entry.
+      void apply(StorageView& x, const dim_t step = 0, const StorageView* offsets = nullptr);
 
     private:
       void initialize(const dim_t num_positions,
diff --git a/include/ctranslate2/ops/rotary.h b/include/ctranslate2/ops/rotary.h
index c0a4cf091..ebd24feca 100644
--- a/include/ctranslate2/ops/rotary.h
+++ b/include/ctranslate2/ops/rotary.h
@@ -12,14 +12,18 @@ namespace ctranslate2 {
       void operator()(const StorageView& input,
                       const StorageView& sin,
                       const StorageView& cos,
-                      StorageView& output) const;
+                      StorageView& output,
+                      const StorageView* offsets = nullptr,
+                      const dim_t step = 0) const;
 
     private:
       const dim_t _ndims;
       const bool _interleave;
 
       template
-      void compute(const StorageView& input,
+      void compute(const dim_t step,
+                   const StorageView* offsets,
+                   const StorageView& input,
                    const StorageView& sin,
                    const StorageView& cos,
                    StorageView& output) const;
diff --git a/src/layers/attention.cc b/src/layers/attention.cc
index 4b057b535..01beb3e20 100644
--- a/src/layers/attention.cc
+++ b/src/layers/attention.cc
@@ -602,28 +602,23 @@ namespace ctranslate2 {
     {
     }
 
-    void RotaryEmbeddings::apply(StorageView& x, const dim_t offset) {
+    void RotaryEmbeddings::apply(StorageView& x, const dim_t step, const StorageView* offsets) {
       const Device device = x.device();
       const DataType dtype = x.dtype();
       const dim_t max_time = x.dim(-2);
       const dim_t dim = _dim == 0 ? x.dim(-1) : _dim;
 
-      if (!_sin || offset + max_time > _sin.dim(0)) {
+      if (!_sin || step + max_time > _sin.dim(0)) {
         const dim_t cur_num_positions = _sin ? _sin.dim(0) : 0;
-        const dim_t new_num_positions = cur_num_positions + _num_initial_positions;
+        // Grow by the usual increment, but never to less than this call reads.
+        dim_t new_num_positions = cur_num_positions + _num_initial_positions;
+        if (new_num_positions < step + max_time)
+          new_num_positions = step + max_time;
         initialize(new_num_positions, dim, device, dtype);
       }
 
-      StorageView sin(dtype, device);
-      StorageView cos(dtype, device);
-      TYPE_DISPATCH(dtype,
-                    {
-                      sin.view(_sin.index({offset, 0}), {max_time, dim});
-                      cos.view(_cos.index({offset, 0}), {max_time, dim});
-                    });
-
       StorageView y(dtype, device);
-      _rotary_op(x, sin, cos, y);
+      _rotary_op(x, _sin, _cos, y, offsets, step);
       x = std::move(y);
     }
 
diff --git a/src/ops/rotary.cc b/src/ops/rotary.cc
index 0058db784..d47c3bb53 100644
--- a/src/ops/rotary.cc
+++ b/src/ops/rotary.cc
@@ -14,11 +14,25 @@ namespace ctranslate2 {
     void Rotary::operator()(const StorageView& input,
                             const StorageView& sin,
                             const StorageView& cos,
-                            StorageView& output) const {
+                            StorageView& output,
+                            const StorageView* offsets,
+                            const dim_t step) const {
+      PROFILE("Rotary");
+
+      if (offsets) {
+        // One offset is expected for each slice over the last two dimensions.
+        const dim_t batch_size = input.size() / (input.dim(-1) * input.dim(-2));
+        if (offsets->size() != batch_size)
+          throw std::invalid_argument("Offsets has size "
+                                      + std::to_string(offsets->size())
+                                      + " which is different than the current batch size "
+                                      + std::to_string(batch_size));
+      }
+
       output.resize_as(input);
 
       DEVICE_AND_FLOAT_DISPATCH("Rotary", input.device(), input.dtype(),
-                                (compute(input, sin, cos, output)));
+                                (compute(step, offsets, input, sin, cos, output)));
     }
 
   }
diff --git a/src/ops/rotary_cpu.cc b/src/ops/rotary_cpu.cc
index bdf35a5ae..1c827297b 100644
--- a/src/ops/rotary_cpu.cc
+++ b/src/ops/rotary_cpu.cc
@@ -10,6 +10,8 @@ namespace ctranslate2 {
                        const T* sin,
                        const T* cos,
                        T* output,
+                       const int32_t* offsets,
+                       const dim_t step,
                        const dim_t batch_size,
                        const dim_t max_time,
                        const dim_t ndims,
@@ -18,9 +20,21 @@ namespace ctranslate2 {
 
       cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
         for (dim_t b = begin; b < end; ++b) {
+          const dim_t offset = offsets ? offsets[b] : 0;
+
           for (dim_t t = 0; t < max_time; ++t) {
-            const T* s = sin + t * ndims;
-            const T* c = cos + t * ndims;
+            const dim_t signal_time = t - offset + step;
+            if (signal_time < 0) {
+              // No table entry exists for this position: pass the input through
+              // so that every output element is written.
+              const dim_t start = b * (max_time * depth) + t * depth;
+              for (dim_t i = 0; i < depth; ++i)
+                output[start + i] = input[start + i];
+              continue;
+            }
+
+            const T* s = sin + signal_time * ndims;
+            const T* c = cos + signal_time * ndims;
 
             const T* x = input + b * (max_time * depth) + t * depth;
             T* y = output + b * (max_time * depth) + t * depth;
@@ -40,7 +54,9 @@ namespace ctranslate2 {
     }
 
     template
-    void Rotary::compute(const StorageView& input,
+    void Rotary::compute(const dim_t step,
+                         const StorageView* offsets,
+                         const StorageView& input,
                          const StorageView& sin,
                          const StorageView& cos,
                          StorageView& output) const {
@@ -52,17 +68,20 @@ namespace ctranslate2 {
       const auto* x = input.data();
       const auto* s = sin.data();
       const auto* c = cos.data();
+      const auto* o = offsets ? offsets->data() : nullptr;
       auto* y = output.data();
 
       if (_interleave)
-        rotary_kernel(x, s, c, y, batch_size, max_time, ndims, depth);
+        rotary_kernel(x, s, c, y, o, step, batch_size, max_time, ndims, depth);
       else
-        rotary_kernel(x, s, c, y, batch_size, max_time, ndims, depth);
+        rotary_kernel(x, s, c, y, o, step, batch_size, max_time, ndims, depth);
     }
 
 #define DECLARE_IMPL(T)                                                 \
   template void                                                         \
-  Rotary::compute(const StorageView&,                                   \
+  Rotary::compute(const dim_t,                                          \
+                  const StorageView*,                                   \
+                  const StorageView&,                                   \
                   const StorageView&,                                   \
                   const StorageView&,                                   \
                   StorageView&) const;
diff --git a/src/ops/rotary_gpu.cu b/src/ops/rotary_gpu.cu
index 511608ce0..2411e90df 100644
--- a/src/ops/rotary_gpu.cu
+++ b/src/ops/rotary_gpu.cu
@@ -29,17 +29,31 @@ namespace ctranslate2 {
                                   const T* sin,
                                   const T* cos,
                                   T* y,
+                                  const int32_t* offsets,
+                                  const cuda::index_t step,
                                   const cuda::index_t max_time,
                                   const cuda::index_t ndims,
                                   const cuda::index_t depth) {
+      const auto batch = blockIdx.x / max_time;
       const auto time = blockIdx.x % max_time;
       const auto middle = ndims / 2;
 
+      const int32_t offset = offsets ? offsets[batch] : 0;
+      const int32_t signal_time = time - offset + step;
+
       x += blockIdx.x * depth;
       y += blockIdx.x * depth;
 
-      sin += time * ndims;
-      cos += time * ndims;
+      if (signal_time < 0) {
+        // No table entry exists for this position: pass the input through
+        // so that every output element is written.
+        for (cuda::index_t i = threadIdx.x; i < depth; i += blockDim.x)
+          y[i] = x[i];
+        return;
+      }
+
+      sin += signal_time * ndims;
+      cos += signal_time * ndims;
 
       using C = typename ComputeType::type;
 
@@ -54,7 +68,9 @@ namespace ctranslate2 {
     }
 
     template
-    void Rotary::compute(const StorageView& input,
+    void Rotary::compute(const dim_t step,
+                         const StorageView* offsets,
+                         const StorageView& input,
                          const StorageView& sin,
                          const StorageView& cos,
                          StorageView& output) const {
@@ -68,21 +84,24 @@ namespace ctranslate2 {
       const auto* x = cuda::device_cast(input.data());
       const auto* s = cuda::device_cast(sin.data());
       const auto* c = cuda::device_cast(cos.data());
+      const auto* o = offsets ? offsets->data() : nullptr;
       auto* y = cuda::device_cast(output.data());
 
       using DeviceT = cuda::device_type;
 
       if (_interleave)
         rotary_kernel<<>>(
-          x, s, c, y, max_time, ndims, depth);
+          x, s, c, y, o, step, max_time, ndims, depth);
       else
         rotary_kernel<<>>(
-          x, s, c, y, max_time, ndims, depth);
+          x, s, c, y, o, step, max_time, ndims, depth);
     }
 
 #define DECLARE_IMPL(T)                                                 \
   template void                                                         \
-  Rotary::compute(const StorageView&,                                   \
+  Rotary::compute(const dim_t,                                          \
+                  const StorageView*,                                   \
+                  const StorageView&,                                   \
                   const StorageView&,                                   \
                   const StorageView&,                                   \
                   StorageView&) const;
diff --git a/tests/layers_test.cc b/tests/layers_test.cc
index f1359bc05..f2489b46e 100644
--- a/tests/layers_test.cc
+++ b/tests/layers_test.cc
@@ -180,6 +180,61 @@ TEST_P(LayerDeviceFPTest, RotaryEmbedding) {
   }
 }
 
+TEST_P(LayerDeviceFPTest, RotaryEmbeddingOffset) {
+  const Device device = GetParam().device;
+  const DataType dtype = GetParam().dtype;
+  const float error = GetParam().error;
+
+  // The input and expected output were generated from PyTorch using the rotary embeddings layer from
+  // https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+  //
+  // q = torch.rand([2, 4, 1, 6])
+  // print(q.numpy().flatten().tolist())
+  // position_ids = torch.tensor([[2], [4]])
+  // llama = LlamaRotaryEmbedding(6)
+  // cos, sin = llama(q, seq_len=12)
+  // q, _ = apply_rotary_pos_emb(q, q, cos, sin, position_ids)
+  // print(q.numpy().flatten().tolist())
+
+  const StorageView input({2, 4, 1, 6}, std::vector{
+      0.23646563291549683, 0.9993839263916016, 0.4034807085990906, 0.5447465777397156,
+      0.9373598098754883, 0.3172609210014343, 0.19522875547409058, 0.707885205745697,
+      0.0094565749168396, 0.9327296018600464, 0.4594022035598755, 0.5009559392929077,
+      0.0743250846862793, 0.5236821174621582, 0.18698054552078247, 0.3285903334617615,
+      0.6952935457229614, 0.46870940923690796, 0.578666090965271, 0.11945730447769165,
+      0.16381490230560303, 0.38767993450164795, 0.15953445434570312, 0.5320672392845154,
+      0.10134690999984741, 0.26156187057495117, 0.9635066986083984, 0.7839735746383667,
+      0.2869170308113098, 0.5146785378456116, 0.2806260585784912, 0.6367897987365723,
+      0.9142636656761169, 0.7779543995857239, 0.5855610370635986, 0.23491668701171875,
+      0.6287166476249695, 0.400934636592865, 0.8011993169784546, 0.4153047204017639,
+      0.7990701198577881, 0.01711505651473999, 0.19538897275924683, 0.21076786518096924,
+      0.9088703989982605, 0.8127486109733582, 0.9860213994979858, 0.9132919907569885
+    }, device);
+
+  const StorageView expected({2, 4, 1, 6}, std::vector{
+      -0.5937410593032837, 0.9081889986991882, 0.40210992097854614, -0.011676982045173645,
+      1.0259650945663452, 0.31899651885032654, -0.9293724298477173, 0.6622512936592102,
+      0.00729794055223465, -0.21063147485256195, 0.5230439901351929, 0.5009920597076416,
+      -0.3297165036201477, 0.4569746255874634, 0.18495920300483704, -0.06915822625160217,
+      0.7408443093299866, 0.4695107340812683, -0.5933264493942261, 0.10415434092283249,
+      0.16152077913284302, 0.3648477792739868, 0.16992105543613434, 0.5327681303024292,
+      0.5270683765411377, 0.20410215854644775, 0.9590356349945068, -0.589138925075531,
+      0.33027005195617676, 0.5229625701904297, 0.4053283929824829, 0.5177521109580994,
+      0.9122052788734436, -0.7208834290504456, 0.6930481195449829, 0.24278675019741058,
+      -0.09665298461914062, 0.24653685092926025, 0.8010221123695374, -0.7472755908966064,
+      0.8593493103981018, 0.02401886135339737, 0.4873754382133484, 0.025127321481704712,
+      0.900966227054596, -0.6791187524795532, 1.0079830884933472, 0.9210903644561768
+    }, device);
+
+  const StorageView offsets({2, 4}, std::vector{3, 3, 3, 3, 1, 1, 1, 1}, device);
+  const dim_t step = 5;
+
+  layers::RotaryEmbeddings rotary_embeddings(0, false);
+  StorageView x = input.to(dtype);
+  rotary_embeddings.apply(x, step, &offsets);
+  expect_storage_eq(x.to_float32(), expected, error);
+}
+
 TEST(LayerTest, Padder) {
   const StorageView lengths({3}, std::vector{2, 3, 1});
   const Padder padder(lengths, /*max_time=*/4);