From e5c6a7f742750ff904d16f06e74b24034ae8a2df Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Mon, 20 May 2024 08:16:50 +0100 Subject: [PATCH 1/2] Copy update and reset gates in GRU op before applying sigmoid activation This is a workaround until https://github.com/robertknight/rten/issues/192 is solved more generally. --- src/ops/rnn.rs | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/ops/rnn.rs b/src/ops/rnn.rs index 0a6ad13c..dae3f24b 100644 --- a/src/ops/rnn.rs +++ b/src/ops/rnn.rs @@ -7,8 +7,8 @@ use rten_tensor::{Tensor, TensorView}; use crate::check_dims; use crate::gemm::{GemmExecutor, GemmInputA, GemmInputB}; use crate::ops::{ - add_in_place, mul_in_place, sigmoid_in_place, tanh, tanh_in_place, InputList, IntoOpResult, - OpError, Operator, Output, + add_in_place, mul_in_place, sigmoid, sigmoid_in_place, tanh, tanh_in_place, InputList, + IntoOpResult, OpError, Operator, Output, }; use crate::tensor_pool::{AutoReturn, TensorPool}; @@ -265,12 +265,19 @@ pub fn gru( hidden_scratch_reset_update_gates.as_dyn(), ); - // nb. This is slower than it should be because it falls back to - // the slow path for non-contiguous tensors. - sigmoid_in_place(update_reset_gates.as_dyn_mut()); + // Copy gates before applying activation because `sigmoid_in_place` + // and `tanh_in_place` are slow with non-contiguous tensors, and + // `update_reset_gates` will be non-contiguous if the batch size is + // > 1. See https://github.com/robertknight/rten/issues/192. + // + // Note `gate_range` can be still used because the update and reset + // gates are in the same positions in the `update_reset_gates` slice + // as `gates`. + let update_reset_gates = sigmoid(pool, update_reset_gates.as_dyn()).auto_return(pool); + let update_gate = update_reset_gates.slice::<2, _>((.., gate_range(UPDATE_GATE))); + let reset_gate = update_reset_gates.slice::<2, _>((.., gate_range(RESET_GATE))); // Combine inputs for hidden gate and apply activation. - let reset_gate = gates.slice::<2, _>((.., gate_range(RESET_GATE))); let mut hidden_gate_recurrent = hidden_scratch.slice_mut::<2, _>((.., gate_range(HIDDEN_GATE))); mul_in_place(hidden_gate_recurrent.as_dyn_mut(), reset_gate.as_dyn()); @@ -278,13 +285,11 @@ pub fn gru( let mut hidden_gate = gates.slice_mut::<2, _>((.., gate_range(HIDDEN_GATE))); add_in_place(hidden_gate.as_dyn_mut(), hidden_gate_recurrent.as_dyn()); - // Copy the hidden gate because `tanh_in_place` is slow with - // non-contiguous tensors. + // See note above about `sigmoid_in_place`. let hidden_gate = tanh(pool, hidden_gate.as_dyn()).auto_return(pool); // Compute next hidden state let mut hidden_item = hidden.slice_mut::<2, _>([dir]); - let update_gate = gates.slice::<2, _>((.., gate_range(UPDATE_GATE))); for (hidden, update, hidden_gate) in zip3( hidden_item.iter_mut(), From bf848f7fb59c6a42960103eb555dfa113fbc8cb6 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Mon, 20 May 2024 08:34:31 +0100 Subject: [PATCH 2/2] Copy gates in LSTM op before applying activation This change was previously applied in the GRU operator to work around `sigmoid_in_place` and `tanh_in_place` being slow for non-contiguous inputs, which will be the case if the batch size is > 1. --- src/ops/rnn.rs | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/ops/rnn.rs b/src/ops/rnn.rs index dae3f24b..4f9b4f7e 100644 --- a/src/ops/rnn.rs +++ b/src/ops/rnn.rs @@ -7,8 +7,7 @@ use rten_tensor::{Tensor, TensorView}; use crate::check_dims; use crate::gemm::{GemmExecutor, GemmInputA, GemmInputB}; use crate::ops::{ - add_in_place, mul_in_place, sigmoid, sigmoid_in_place, tanh, tanh_in_place, InputList, - IntoOpResult, OpError, Operator, Output, + add_in_place, mul_in_place, sigmoid, tanh, InputList, IntoOpResult, OpError, Operator, Output, }; use crate::tensor_pool::{AutoReturn, TensorPool}; @@ -494,19 +493,19 @@ pub fn lstm( add_in_place(gates.view_mut(), hidden_bias.as_dyn()); } - let mut iof_gates = gates.slice_mut::<2, _>(( + // Copy gates to work around `tanh_in_place` and `sigmoid_in_place` + // being slow for non-contiguous inputs. See notes in GRU op. + let iof_gates = gates.slice::<2, _>(( .., gate_range(INPUT_GATE).start..gate_range(FORGET_GATE).end, )); - sigmoid_in_place(iof_gates.as_dyn_mut()); + let iof_gates = sigmoid(pool, iof_gates.as_dyn()).auto_return(pool); + let input_gate = iof_gates.slice::<2, _>((.., gate_range(INPUT_GATE))); + let out_gate = iof_gates.slice::<2, _>((.., gate_range(OUTPUT_GATE))); + let forget_gate = iof_gates.slice::<2, _>((.., gate_range(FORGET_GATE))); - let mut cell_gate = gates.slice_mut::<2, _>((.., gate_range(CELL_GATE))); - tanh_in_place(cell_gate.as_dyn_mut()); - - let input_gate = gates.slice::<2, _>((.., gate_range(INPUT_GATE))); - let out_gate = gates.slice::<2, _>((.., gate_range(OUTPUT_GATE))); - let forget_gate = gates.slice::<2, _>((.., gate_range(FORGET_GATE))); let cell_gate = gates.slice::<2, _>((.., gate_range(CELL_GATE))); + let cell_gate = tanh(pool, cell_gate.as_dyn()).auto_return(pool); // Update cell and hidden state let mut cell_item = cell.slice_mut::<2, _>([dir]);