
Commit 51030e5
nice code (#1035)
Signed-off-by: Liu, Kaixuan <[email protected]>
kaixuanliu authored Nov 27, 2024
1 parent bcce6b0 commit 51030e5
Showing 4 changed files with 18 additions and 14 deletions.
optimum/exporters/ipex/modeling_utils.py (2 changes: 1 addition & 1 deletion)
@@ -507,7 +507,7 @@ def __init__(self, module, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device.type
+        self.module_device = next(module.parameters()).device
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device
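For context on this one-line change: a parameter's `.device` is a `torch.device` object that keeps both the device type and its index, whereas `.device.type` is a bare string such as "xpu" that drops the index. A minimal sketch of the difference, using a plain `nn.Linear` as a stand-in for the patched module:

import torch

# Stand-in for the patched module; its parameters live on CPU here.
module = torch.nn.Linear(4, 4)

device = next(module.parameters()).device             # torch.device('cpu'), index preserved
device_type = next(module.parameters()).device.type   # 'cpu', a bare string

# Both forms are accepted by tensor factories, so torch.arange works either way,
# but only the device object pins the result to a specific index such as 'xpu:1':
mapping = torch.arange(0, 4, dtype=torch.int32, device=device)
print(mapping.device)  # cpu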
optimum/intel/ipex/modeling_base.py (20 changes: 12 additions & 8 deletions)
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 
-import copy
 import inspect
 import logging
 import warnings
@@ -331,8 +330,7 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
         return self.model.prepare_inputs_for_generation(*args, **kwargs)
 
     def generate(self, *args, **kwargs):
-        new_kwargs = copy.deepcopy(kwargs)
-        if is_ipex_version("<", "2.4.0") and self._add_patch and new_kwargs.get("assistant_model", None):
+        if is_ipex_version("<", "2.4.0") and self._add_patch and kwargs.get("assistant_model", None):
             raise ValueError(
                 f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
             )
@@ -343,25 +341,31 @@ def generate(self, *args, **kwargs):
             if is_transformers_version(">=", "4.45.0"):
                 if "ipex_paged" not in transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS:
                     transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append("ipex_paged")
-            if new_kwargs.get("generation_config", None):
-                new_kwargs["generation_config"].cache_implementation = "ipex_paged"
+            if kwargs.get("generation_config", None):
+                # Change cache implementation temporarily
+                orig_cache_implementation = kwargs["generation_config"].cache_implementation
+                kwargs["generation_config"].cache_implementation = "ipex_paged"
 
-        if self._add_patch and new_kwargs.get("assistant_model", None):
+        if self._add_patch and kwargs.get("assistant_model", None):
             transformers.generation.utils._crop_past_key_values = _ipex_crop_past_key_values
         elif self._add_patch:
             transformers.generation.candidate_generator._crop_past_key_values = _ipex_crop_past_key_values
 
         try:
-            result = super().generate(*args, **new_kwargs)
+            result = super().generate(*args, **kwargs)
         except Exception as e:
             transformers.generation.utils._crop_past_key_values = _crop_past_key_values
             transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values
             raise e
 
-        if self._add_patch and new_kwargs.get("assistant_model", None):
+        if self._add_patch and kwargs.get("assistant_model", None):
             transformers.generation.utils._crop_past_key_values = _crop_past_key_values
             transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values
 
+        # change back cache_implementation
+        if self._add_patch and kwargs.get("generation_config", None):
+            kwargs["generation_config"].cache_implementation = orig_cache_implementation
+
         return result


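The dropped `copy.deepcopy(kwargs)` used to protect the caller's keyword arguments from mutation, at the cost of deep-copying everything in `kwargs`. The replacement is a save/mutate/restore pattern on the one field that actually changes: remember the caller's `cache_implementation`, overwrite it for the `super().generate(...)` call, then write the original value back. A minimal self-contained sketch of that pattern; `Config` and `run_generation` are hypothetical stand-ins, not optimum-intel or transformers APIs:

class Config:  # hypothetical stand-in for transformers' GenerationConfig
    cache_implementation = "static"

def run_generation(**kwargs):  # stand-in for super().generate(...)
    return kwargs["generation_config"].cache_implementation

def generate(**kwargs):
    config = kwargs.get("generation_config", None)
    if config is not None:
        orig_cache_implementation = config.cache_implementation  # remember the caller's setting
        config.cache_implementation = "ipex_paged"                # temporarily force the paged cache
    try:
        return run_generation(**kwargs)
    finally:
        if config is not None:
            config.cache_implementation = orig_cache_implementation  # undo the in-place mutation

cfg = Config()
print(generate(generation_config=cfg))  # ipex_paged
print(cfg.cache_implementation)         # static: the caller's object is unchanged

Note the commit itself restores the value after a successful call on the patched path; the try/finally variant in this sketch also restores it if generation raises.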
tests/ipex/test_modeling.py (7 changes: 2 additions & 5 deletions)
@@ -49,6 +49,7 @@
 
 
 SEED = 42
+torch.use_deterministic_algorithms(True)
 
 
 class Timer(object):
@@ -104,7 +105,7 @@ def test_compare_to_transformers(self, model_arch):
         # Compare tensor outputs
         for output_name in {"logits", "last_hidden_state"}:
             if output_name in transformers_outputs:
-                self.assertTrue(torch.allclose(outputs[output_name], transformers_outputs[output_name], atol=1e-4))
+                self.assertTrue(torch.allclose(outputs[output_name], transformers_outputs[output_name], atol=1e-3))
                 self.assertTrue(torch.allclose(outputs[output_name], loaded_model_outputs[output_name]))
                 self.assertTrue(torch.allclose(outputs[output_name], init_model_outputs[output_name]))

@@ -205,10 +206,7 @@ class IPEXModelForCausalLMTest(unittest.TestCase):
         "gpt_neo",
         "gpt_neox",
         "mistral",
-        # "llama",
         "llama2",
-        # "phi",
-        # "distilgpt2",
         "mpt",
         "opt",
     )
@@ -431,7 +429,6 @@ class IPEXModelForImageClassificationIntegrationTest(unittest.TestCase):
     IPEX_MODEL_CLASS = IPEXModelForImageClassification
     SUPPORTED_ARCHITECTURES = (
         "beit",
-        # "levit",
         "mobilenet_v1",
         "mobilenet_v2",
         "mobilevit",
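On the loosened tolerance: `torch.allclose` checks |out - ref| <= atol + rtol * |ref| per element (default rtol is 1e-5), so raising atol from 1e-4 to 1e-3 widens the allowed absolute drift by roughly an order of magnitude. A quick illustration:

import torch

ref = torch.tensor([1.0000, 2.0000])
out = torch.tensor([1.0004, 1.9996])  # about 4e-4 absolute drift per element

print(torch.allclose(out, ref, atol=1e-4))  # False: drift exceeds the old tolerance
print(torch.allclose(out, ref, atol=1e-3))  # True: within the loosened tolerance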
tests/ipex/test_pipelines.py (3 changes: 3 additions & 0 deletions)
@@ -34,6 +34,9 @@
 from optimum.intel.pipelines import pipeline as ipex_pipeline
 
 
+torch.use_deterministic_algorithms(True)
+
+
 class PipelinesIntegrationTest(unittest.TestCase):
     COMMON_SUPPORTED_ARCHITECTURES = (
         "albert",
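Both test files now opt in to deterministic kernels at import time. `torch.use_deterministic_algorithms(True)` is a global switch that makes operations without a deterministic implementation raise a RuntimeError instead of silently varying between runs; together with the fixed SEED it makes the allclose comparisons above reproducible. A minimal sketch:

import torch

torch.use_deterministic_algorithms(True)  # nondeterministic kernels now raise RuntimeError
print(torch.are_deterministic_algorithms_enabled())  # True

torch.manual_seed(42)  # mirrors SEED = 42 in the tests
first = torch.randn(4)
torch.manual_seed(42)
second = torch.randn(4)
print(torch.equal(first, second))  # True: identical values across seeded runs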

