diff --git a/src/neurostatslib/glm.py b/src/neurostatslib/glm.py index 134adae3..8ee3e24c 100644 --- a/src/neurostatslib/glm.py +++ b/src/neurostatslib/glm.py @@ -565,7 +565,7 @@ def scan_fn( 0 ] - # Extract the corresponding slice of the feedforward input for the current time step + # Extract the slice of the feedforward input for the current time step input_slice = jax.lax.dynamic_slice( feed_forward_contrib, (chunk, 0),