Skip to content

Commit

Permalink
Fix concat bug in rotary embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam committed Dec 22, 2024
1 parent a745039 commit 17c06c3
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def rewrite(self, op, x, inv_freq, position_ids, interleaved, num_heads, **_):
pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1)
angles = np.matmul(pos_id_range, inv_freq_values)
cos_value = np.cos(angles)
cos_value = np.concatenate([cos_value, cos_value], axis=-1)
# cos_value = np.concatenate([cos_value, cos_value], axis=-1)
sin_value = np.sin(angles)
sin_value = np.concatenate([sin_value, sin_value], axis=-1)
# sin_value = np.concatenate([sin_value, sin_value], axis=-1)
cos_2d = op.Constant(value=ir.tensor(cos_value))
# cos = op.Gather(cos_2d, position_ids, axis=0)
sin_2d = op.Constant(value=ir.tensor(sin_value))
Expand Down

0 comments on commit 17c06c3

Please sign in to comment.