Skip to content

Commit

Permalink
fix gates size retrieval logic in _rnn_flops (#3921)
Browse files Browse the repository at this point in the history
Co-authored-by: Pinstripe Potoroo <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
3 people authored Jul 24, 2023
1 parent 23a11a3 commit f4d18fa
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions deepspeed/profiling/flops_profiler/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,11 +980,11 @@ def _reload_tensor_methods():


def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
input_size, hidden_size = w_ih.shape
gates_size = w_ih.shape[0]
# matrix matrix mult ih state and internal state
flops += 2 * input_size * hidden_size - hidden_size
flops += 2 * w_ih.shape[0] * w_ih.shape[1] - gates_size
# matrix matrix mult hh state and internal state
flops += 2 * hidden_size * hidden_size - hidden_size
flops += 2 * w_hh.shape[0] * w_hh.shape[1] - gates_size
if isinstance(rnn_module, (nn.RNN, nn.RNNCell)):
# add both operations
flops += rnn_module.hidden_size
Expand Down

0 comments on commit f4d18fa

Please sign in to comment.