diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py index 440f4a111..538070feb 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py +++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py @@ -60,9 +60,9 @@ def rewrite(self, op, x, inv_freq, position_ids, interleaved, num_heads, **_): pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1) angles = np.matmul(pos_id_range, inv_freq_values) cos_value = np.cos(angles) - cos_value = np.concatenate([cos_value, cos_value], axis=-1) + # cos_value = np.concatenate([cos_value, cos_value], axis=-1) sin_value = np.sin(angles) - sin_value = np.concatenate([sin_value, sin_value], axis=-1) + # sin_value = np.concatenate([sin_value, sin_value], axis=-1) cos_2d = op.Constant(value=ir.tensor(cos_value)) # cos = op.Gather(cos_2d, position_ids, axis=0) sin_2d = op.Constant(value=ir.tensor(sin_value))