Skip to content

Commit

Permalink
Merge pull request #193 from robertknight/gru-sigmoid-copy
Browse files Browse the repository at this point in the history
Copy gates in LSTM, GRU ops before applying activations
  • Loading branch information
robertknight authored May 20, 2024
2 parents 2851fed + bf848f7 commit 7f24d0f
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions src/ops/rnn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ use rten_tensor::{Tensor, TensorView};
use crate::check_dims;
use crate::gemm::{GemmExecutor, GemmInputA, GemmInputB};
use crate::ops::{
add_in_place, mul_in_place, sigmoid_in_place, tanh, tanh_in_place, InputList, IntoOpResult,
OpError, Operator, Output,
add_in_place, mul_in_place, sigmoid, tanh, InputList, IntoOpResult, OpError, Operator, Output,
};
use crate::tensor_pool::{AutoReturn, TensorPool};

Expand Down Expand Up @@ -265,26 +264,31 @@ pub fn gru(
hidden_scratch_reset_update_gates.as_dyn(),
);

// nb. This is slower than it should be because it falls back to
// the slow path for non-contiguous tensors.
sigmoid_in_place(update_reset_gates.as_dyn_mut());
// Copy gates before applying activation because `sigmoid_in_place`
// and `tanh_in_place` are slow with non-contiguous tensors, and
// `update_reset_gates` will be non-contiguous if the batch size is
// > 1. See https://github.com/robertknight/rten/issues/192.
//
// Note `gate_range` can be still used because the update and reset
// gates are in the same positions in the `update_reset_gates` slice
// as `gates`.
let update_reset_gates = sigmoid(pool, update_reset_gates.as_dyn()).auto_return(pool);
let update_gate = update_reset_gates.slice::<2, _>((.., gate_range(UPDATE_GATE)));
let reset_gate = update_reset_gates.slice::<2, _>((.., gate_range(RESET_GATE)));

// Combine inputs for hidden gate and apply activation.
let reset_gate = gates.slice::<2, _>((.., gate_range(RESET_GATE)));
let mut hidden_gate_recurrent =
hidden_scratch.slice_mut::<2, _>((.., gate_range(HIDDEN_GATE)));
mul_in_place(hidden_gate_recurrent.as_dyn_mut(), reset_gate.as_dyn());

let mut hidden_gate = gates.slice_mut::<2, _>((.., gate_range(HIDDEN_GATE)));
add_in_place(hidden_gate.as_dyn_mut(), hidden_gate_recurrent.as_dyn());

// Copy the hidden gate because `tanh_in_place` is slow with
// non-contiguous tensors.
// See note above about `sigmoid_in_place`.
let hidden_gate = tanh(pool, hidden_gate.as_dyn()).auto_return(pool);

// Compute next hidden state
let mut hidden_item = hidden.slice_mut::<2, _>([dir]);
let update_gate = gates.slice::<2, _>((.., gate_range(UPDATE_GATE)));

for (hidden, update, hidden_gate) in zip3(
hidden_item.iter_mut(),
Expand Down Expand Up @@ -489,19 +493,19 @@ pub fn lstm(
add_in_place(gates.view_mut(), hidden_bias.as_dyn());
}

let mut iof_gates = gates.slice_mut::<2, _>((
// Copy gates to work around `tanh_in_place` and `sigmoid_in_place`
// being slow for non-contiguous inputs. See notes in GRU op.
let iof_gates = gates.slice::<2, _>((
..,
gate_range(INPUT_GATE).start..gate_range(FORGET_GATE).end,
));
sigmoid_in_place(iof_gates.as_dyn_mut());

let mut cell_gate = gates.slice_mut::<2, _>((.., gate_range(CELL_GATE)));
tanh_in_place(cell_gate.as_dyn_mut());
let iof_gates = sigmoid(pool, iof_gates.as_dyn()).auto_return(pool);
let input_gate = iof_gates.slice::<2, _>((.., gate_range(INPUT_GATE)));
let out_gate = iof_gates.slice::<2, _>((.., gate_range(OUTPUT_GATE)));
let forget_gate = iof_gates.slice::<2, _>((.., gate_range(FORGET_GATE)));

let input_gate = gates.slice::<2, _>((.., gate_range(INPUT_GATE)));
let out_gate = gates.slice::<2, _>((.., gate_range(OUTPUT_GATE)));
let forget_gate = gates.slice::<2, _>((.., gate_range(FORGET_GATE)));
let cell_gate = gates.slice::<2, _>((.., gate_range(CELL_GATE)));
let cell_gate = tanh(pool, cell_gate.as_dyn()).auto_return(pool);

// Update cell and hidden state
let mut cell_item = cell.slice_mut::<2, _>([dir]);
Expand Down

0 comments on commit 7f24d0f

Please sign in to comment.