Skip to content

Commit

Permalink
Copy update and reset gates in GRU op before applying sigmoid activation
Browse files Browse the repository at this point in the history
This is a workaround until #192 is
solved more generally.
  • Loading branch information
robertknight committed May 20, 2024
1 parent 2851fed commit e5c6a7f
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions src/ops/rnn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use rten_tensor::{Tensor, TensorView};
use crate::check_dims;
use crate::gemm::{GemmExecutor, GemmInputA, GemmInputB};
use crate::ops::{
add_in_place, mul_in_place, sigmoid_in_place, tanh, tanh_in_place, InputList, IntoOpResult,
OpError, Operator, Output,
add_in_place, mul_in_place, sigmoid, sigmoid_in_place, tanh, tanh_in_place, InputList,
IntoOpResult, OpError, Operator, Output,
};
use crate::tensor_pool::{AutoReturn, TensorPool};

Expand Down Expand Up @@ -265,26 +265,31 @@ pub fn gru(
hidden_scratch_reset_update_gates.as_dyn(),
);

// nb. This is slower than it should be because it falls back to
// the slow path for non-contiguous tensors.
sigmoid_in_place(update_reset_gates.as_dyn_mut());
// Copy gates before applying activation because `sigmoid_in_place`
// and `tanh_in_place` are slow with non-contiguous tensors, and
// `update_reset_gates` will be non-contiguous if the batch size is
// > 1. See https://github.com/robertknight/rten/issues/192.
//
// Note `gate_range` can be still used because the update and reset
// gates are in the same positions in the `update_reset_gates` slice
// as `gates`.
let update_reset_gates = sigmoid(pool, update_reset_gates.as_dyn()).auto_return(pool);
let update_gate = update_reset_gates.slice::<2, _>((.., gate_range(UPDATE_GATE)));
let reset_gate = update_reset_gates.slice::<2, _>((.., gate_range(RESET_GATE)));

// Combine inputs for hidden gate and apply activation.
let reset_gate = gates.slice::<2, _>((.., gate_range(RESET_GATE)));
let mut hidden_gate_recurrent =
hidden_scratch.slice_mut::<2, _>((.., gate_range(HIDDEN_GATE)));
mul_in_place(hidden_gate_recurrent.as_dyn_mut(), reset_gate.as_dyn());

let mut hidden_gate = gates.slice_mut::<2, _>((.., gate_range(HIDDEN_GATE)));
add_in_place(hidden_gate.as_dyn_mut(), hidden_gate_recurrent.as_dyn());

// Copy the hidden gate because `tanh_in_place` is slow with
// non-contiguous tensors.
// See note above about `sigmoid_in_place`.
let hidden_gate = tanh(pool, hidden_gate.as_dyn()).auto_return(pool);

// Compute next hidden state
let mut hidden_item = hidden.slice_mut::<2, _>([dir]);
let update_gate = gates.slice::<2, _>((.., gate_range(UPDATE_GATE)));

for (hidden, update, hidden_gate) in zip3(
hidden_item.iter_mut(),
Expand Down

0 comments on commit e5c6a7f

Please sign in to comment.