Skip to content

Commit

Permalink
Paged attn (#1036)
Browse files Browse the repository at this point in the history
* nice code
* device type adjustment

Signed-off-by: Liu, Kaixuan <[email protected]>
  • Loading branch information
kaixuanliu authored Nov 27, 2024
1 parent 51030e5 commit 587837e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions optimum/exporters/ipex/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,11 +602,11 @@ def __init__(self, module, config) -> None:
self.q_slice = self.q_proj.out_features
self.k_slice = self.q_slice + self.k_proj.out_features
self.v_slice = self.k_slice + self.v_proj.out_features
if self.module_device == "cpu":
if self.module_device.type == "cpu":
if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mha_linear_add = LinearAdd(module.o_proj)

elif self.module_device == "xpu":
elif self.module_device.type == "xpu":
if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mha_linear_add = XPULinearAdd(module.o_proj)

Expand Down

0 comments on commit 587837e

Please sign in to comment.