Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Copy gates in LSTM, GRU ops before applying activations #193

Merged
merged 2 commits into from
May 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 21 additions & 17 deletions src/ops/rnn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ use rten_tensor::{Tensor, TensorView};
use crate::check_dims;
use crate::gemm::{GemmExecutor, GemmInputA, GemmInputB};
use crate::ops::{
add_in_place, mul_in_place, sigmoid_in_place, tanh, tanh_in_place, InputList, IntoOpResult,
OpError, Operator, Output,
add_in_place, mul_in_place, sigmoid, tanh, InputList, IntoOpResult, OpError, Operator, Output,
};
use crate::tensor_pool::{AutoReturn, TensorPool};

Expand Down Expand Up @@ -265,26 +264,31 @@ pub fn gru(
hidden_scratch_reset_update_gates.as_dyn(),
);

// nb. This is slower than it should be because it falls back to
// the slow path for non-contiguous tensors.
sigmoid_in_place(update_reset_gates.as_dyn_mut());
// Copy gates before applying activation because `sigmoid_in_place`
// and `tanh_in_place` are slow with non-contiguous tensors, and
// `update_reset_gates` will be non-contiguous if the batch size is
// > 1. See https://github.com/robertknight/rten/issues/192.
//
// Note that `gate_range` can still be used because the update and reset
// gates occupy the same positions in the `update_reset_gates` slice as
// they do in `gates`.
let update_reset_gates = sigmoid(pool, update_reset_gates.as_dyn()).auto_return(pool);
let update_gate = update_reset_gates.slice::<2, _>((.., gate_range(UPDATE_GATE)));
let reset_gate = update_reset_gates.slice::<2, _>((.., gate_range(RESET_GATE)));

// Combine inputs for hidden gate and apply activation.
let reset_gate = gates.slice::<2, _>((.., gate_range(RESET_GATE)));
let mut hidden_gate_recurrent =
hidden_scratch.slice_mut::<2, _>((.., gate_range(HIDDEN_GATE)));
mul_in_place(hidden_gate_recurrent.as_dyn_mut(), reset_gate.as_dyn());

let mut hidden_gate = gates.slice_mut::<2, _>((.., gate_range(HIDDEN_GATE)));
add_in_place(hidden_gate.as_dyn_mut(), hidden_gate_recurrent.as_dyn());

// Copy the hidden gate because `tanh_in_place` is slow with
// non-contiguous tensors.
// See note above about `sigmoid_in_place`.
let hidden_gate = tanh(pool, hidden_gate.as_dyn()).auto_return(pool);

// Compute next hidden state
let mut hidden_item = hidden.slice_mut::<2, _>([dir]);
let update_gate = gates.slice::<2, _>((.., gate_range(UPDATE_GATE)));

for (hidden, update, hidden_gate) in zip3(
hidden_item.iter_mut(),
Expand Down Expand Up @@ -489,19 +493,19 @@ pub fn lstm(
add_in_place(gates.view_mut(), hidden_bias.as_dyn());
}

let mut iof_gates = gates.slice_mut::<2, _>((
// Copy gates to work around `tanh_in_place` and `sigmoid_in_place`
// being slow for non-contiguous inputs. See notes in GRU op.
let iof_gates = gates.slice::<2, _>((
..,
gate_range(INPUT_GATE).start..gate_range(FORGET_GATE).end,
));
sigmoid_in_place(iof_gates.as_dyn_mut());

let mut cell_gate = gates.slice_mut::<2, _>((.., gate_range(CELL_GATE)));
tanh_in_place(cell_gate.as_dyn_mut());
let iof_gates = sigmoid(pool, iof_gates.as_dyn()).auto_return(pool);
let input_gate = iof_gates.slice::<2, _>((.., gate_range(INPUT_GATE)));
let out_gate = iof_gates.slice::<2, _>((.., gate_range(OUTPUT_GATE)));
let forget_gate = iof_gates.slice::<2, _>((.., gate_range(FORGET_GATE)));

let input_gate = gates.slice::<2, _>((.., gate_range(INPUT_GATE)));
let out_gate = gates.slice::<2, _>((.., gate_range(OUTPUT_GATE)));
let forget_gate = gates.slice::<2, _>((.., gate_range(FORGET_GATE)));
let cell_gate = gates.slice::<2, _>((.., gate_range(CELL_GATE)));
let cell_gate = tanh(pool, cell_gate.as_dyn()).auto_return(pool);

// Update cell and hidden state
let mut cell_item = cell.slice_mut::<2, _>([dir]);
Expand Down
Loading