From f3eb23a793799f7433b2d6f4b3300273cdfb4630 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Mon, 2 Dec 2024 11:55:45 +0100 Subject: [PATCH] [PT FE] Support AWQ models with GEMM module (#27624) ### Details: - *Support converting models quantized with AWQ algorithm from hf* ### Tickets: - *CVS-136653* --------- Signed-off-by: Maxim Vafin --- .../src/openvino/frontend/pytorch/gptq.py | 12 -- .../openvino/frontend/pytorch/quantized.py | 73 ++++++++++++ .../openvino/frontend/pytorch/ts_decoder.py | 26 ++--- src/frontends/pytorch/src/op/linear.cpp | 88 ++++++++++++++- src/frontends/pytorch/src/op_table.cpp | 7 +- src/frontends/pytorch/src/utils.cpp | 6 +- .../pytorch/detectron2_precommit | 7 -- tests/model_hub_tests/pytorch/test_llm.py | 105 +++++++++++++----- tests/requirements_pytorch | 1 + 9 files changed, 259 insertions(+), 66 deletions(-) create mode 100644 src/bindings/python/src/openvino/frontend/pytorch/quantized.py diff --git a/src/bindings/python/src/openvino/frontend/pytorch/gptq.py b/src/bindings/python/src/openvino/frontend/pytorch/gptq.py index a1c6aecc45d421..60a48c275d6681 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/gptq.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/gptq.py @@ -177,15 +177,3 @@ def unpatch_model(model): log.warning("Exception raised during GPTQ model unpatching. " "Depending on the exact issue it may lead to broken " "original model.\n%s", error) - - -def detect_gptq_model_raw(model): - return (model and getattr(model, 'config', None) and - getattr(model.config, 'quantization_config', None) and - model.config.quantization_config.quant_method == 'gptq') - - -def detect_gptq_model(model): - return (detect_gptq_model_raw(model) or - getattr(model, 'model', None) and - detect_gptq_model_raw(model.model)) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/quantized.py b/src/bindings/python/src/openvino/frontend/pytorch/quantized.py new file mode 100644 index 00000000000000..310e95cb9985d7 --- /dev/null +++ b/src/bindings/python/src/openvino/frontend/pytorch/quantized.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2018-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional +import torch +from openvino.frontend.pytorch import ModuleExtension, gptq +from openvino.frontend.pytorch.patch_model import patch_model, unpatch_model + + +def detect_quantized_model(model: torch.nn.Module) -> Optional[str]: + """Detects the quantization method used in a given PyTorch model. + + Args: + model (torch.nn.Module): The PyTorch model to check for quantization. + + Returns: + str: The quantization method if available, otherwise None. + """ + if (model and getattr(model, "config", None) + and getattr(model.config, "quantization_config", None)): + return model.config.quantization_config.quant_method + if getattr(model, "model", None): + return detect_quantized_model(model.model) + return None + + +def patch_quantized(model: torch.nn.Module) -> None: + """Patches a model based on its quantization type ("awq" or "gptq"). + + Args: + model (torch.nn.Module): The model to patch. + + Raises: + RuntimeError: If the quantization type is unknown. + """ + quant_type = detect_quantized_model(model) + if quant_type == "awq": + extensions = {} + try: + from awq.modules.linear import WQLinear_GEMM + extensions[WQLinear_GEMM] = ModuleExtension( + WQLinear_GEMM, "ov_ext::awq_gemm", + convert=lambda module, target_op, *args, **kwargs: target_op( + args[0], module.qweight, module.qzeros, module.scales, + torch.tensor(module.group_size), + torch.tensor(module.w_bit), module.bias), + evaluate=lambda module, *args, **kwargs: torch.full( + list(args[0].shape[:-1]) + [module.out_features], 0.5, + dtype=torch.float32)) # type: ignore + except ImportError: + pass + patch_model(model, extensions, + "_openvino_quantized_patch_orig_forward") # type: ignore + elif quant_type == "gptq": + model._openvino_gptq_patched = True + gptq.patch_model(model) # type: ignore + else: + raise RuntimeError(f"Unknown quantization type: {quant_type}.") + + +def unpatch_quantized(model: torch.nn.Module) -> None: + """Reverts the patching applied to a quantized PyTorch model. + + Args: + model (torch.nn.Module): The model to unpatch. + """ + if getattr(model, "_openvino_gptq_patched", False): + gptq.unpatch_model(model) # type: ignore + del model._openvino_gptq_patched + else: + unpatch_model(model, + "_openvino_quantized_patch_orig_forward") # type: ignore diff --git a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py index af8eafda8e9be7..62efca9a4fee0b 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py @@ -16,7 +16,7 @@ graph_has_ops, ) from openvino.runtime import opset11 as ops -from openvino.frontend.pytorch import gptq, patch_model +from openvino.frontend.pytorch import quantized, patch_model from openvino.frontend.pytorch.module_extension import ModuleExtension import inspect @@ -139,27 +139,25 @@ def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False) patch_model.patch_model( pt_module, self.module_extensions, orig_forward_name) - gptq_patched = False - if gptq.detect_gptq_model(pt_module): + patched = False + if quantized.detect_quantized_model(pt_module) is not None: try: - gptq.patch_model(pt_module) - gptq_patched = True + quantized.patch_quantized(pt_module) + patched = True except Exception as error: log.warning( - "Failed patching of AutoGPTQ model. Error message:\n%s" - "\nTracing of the model will likely be unsuccessful or incorrect", - error) - gptq.unpatch_model(pt_module) - gptq_patched = False + "Failed patching of AutoGPTQ model. Error message:\n" + "Tracing of the model will likely be unsuccessful or incorrect", + exc_info=error) + quantized.unpatch_quantized(pt_module) + patched = False try: scripted = torch.jit.trace( pt_module, **input_parameters, strict=False) finally: - if gptq_patched: - gptq.unpatch_model(pt_module) - if self.module_extensions: - patch_model.unpatch_model(pt_module, orig_forward_name) + if patched: + quantized.unpatch_quantized(pt_module) have_to_freeze_ops = ["prim::Uninitialized", "prim::unchecked_cast", "aten::append"] diff --git a/src/frontends/pytorch/src/op/linear.cpp b/src/frontends/pytorch/src/op/linear.cpp index 2d01dee84c151b..4a5ad4a6b0e73b 100644 --- a/src/frontends/pytorch/src/op/linear.cpp +++ b/src/frontends/pytorch/src/op/linear.cpp @@ -5,6 +5,10 @@ #include "openvino/frontend/pytorch/node_context.hpp" #include "openvino/op/add.hpp" #include "openvino/op/matmul.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/subtract.hpp" #include "utils.hpp" namespace ov { @@ -12,6 +16,8 @@ namespace frontend { namespace pytorch { namespace op { +using namespace ov::op; + OutputVector translate_linear(const NodeContext& context) { // schema: aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor num_inputs_check(context, 2, 3); @@ -20,17 +26,91 @@ OutputVector translate_linear(const NodeContext& context) { if (weight.get_element_type() == element::f16 || weight.get_element_type() == element::bf16) { // In case of patched linear it can have mixed fp16/bf16 and fp32 input type. // In other cases these conversion is not required. - weight = context.mark_node(std::make_shared(weight, x)); + weight = context.mark_node(std::make_shared(weight, x)); } - auto matmul = context.mark_node(std::make_shared(x, weight, false, true)); + auto matmul = context.mark_node(std::make_shared(x, weight, false, true)); if (!context.input_is_none(2)) { auto bias = context.get_input(2); if (bias.get_element_type() == element::f16 || bias.get_element_type() == element::bf16) { // Same reason as for weight. - bias = context.mark_node(std::make_shared(bias, x)); + bias = context.mark_node(std::make_shared(bias, x)); + } + matmul = context.mark_node(std::make_shared(matmul, bias)); + } + return {matmul}; +}; + +namespace { +uint32_t rearrange_awq_bits(uint32_t num) { + uint32_t result = 0; + uint32_t mask = 0xF; + + // Rearrange each 4-bit part in accordance with the AWQ i32->u4 unpacking schema + result |= (num & (mask << 0)) << 0; + result |= (num & (mask << 16)) >> 12; + result |= (num & (mask << 4)) << 4; + result |= (num & (mask << 20)) >> 8; + result |= (num & (mask << 8)) << 8; + result |= (num & (mask << 24)) >> 4; + result |= (num & (mask << 12)) << 12; + result |= (num & (mask << 28)) >> 0; + + return result; +} + +Output rearrange_constant(const Output& c, uint32_t groups) { + auto constant = std::dynamic_pointer_cast(c.get_node_shared_ptr()); + FRONT_END_OP_CONVERSION_CHECK(constant, "weight must be Constant."); + auto src = constant->get_data_ptr(); + auto initial_shape = constant->get_shape(); + FRONT_END_OP_CONVERSION_CHECK(initial_shape.size() == 2, "Only 2D constants are supported."); + auto new_shape = Shape{initial_shape[0] / groups, groups, initial_shape[1] * 8}; + auto new_qweight = std::make_shared(element::u4, new_shape); + auto dst = const_cast(reinterpret_cast(new_qweight->get_data_ptr())); + for (size_t i = 0; i < shape_size(constant->get_shape()); i++) { + dst[i] = rearrange_awq_bits(src[i]); + } + return new_qweight; +} +} // namespace + +OutputVector translate_linear_awq(const NodeContext& context) { + num_inputs_check(context, 4, 7); + auto x = context.get_input(0); + auto qweight = context.get_input(1); + auto qzeros = context.get_input(2); + auto scales = context.get_input(3); + auto groups = context.const_input(4); + auto bits = context.const_input(5); + + FRONT_END_OP_CONVERSION_CHECK(bits == 4, "Only 4 bit AWQ is supported."); + + auto new_qweight = rearrange_constant(qweight, static_cast(groups)); + auto new_qzeros = rearrange_constant(qzeros, 1); + new_qweight = context.mark_node(std::make_shared(new_qweight, scales.get_element_type())); + new_qzeros = context.mark_node(std::make_shared(new_qzeros, scales.get_element_type())); + + auto w_s = context.mark_node(std::make_shared(new_qweight, new_qzeros)); + FRONT_END_OP_CONVERSION_CHECK(scales.get_partial_shape().is_static(), "Scales must be constant."); + auto scales_shape = scales.get_shape(); + auto new_scales_shape = + v0::Constant::create(element::i32, {3}, std::vector{scales_shape[0], 1, scales_shape[1]}); + scales = context.mark_node(std::make_shared(scales, new_scales_shape, false)); + auto weight = context.mark_node(std::make_shared(w_s, scales)); + auto out_shape = + v0::Constant::create(element::i32, {2}, std::vector{static_cast(qweight.get_shape()[0]), -1}); + weight = context.mark_node(std::make_shared(weight, out_shape, false)); + weight = context.mark_node(std::make_shared(weight, x)); + + auto matmul = context.mark_node(std::make_shared(x, weight, false, false)); + if (!context.input_is_none(6)) { + auto bias = context.get_input(6); + + if (bias.get_element_type() == element::f16 || bias.get_element_type() == element::bf16) { + bias = context.mark_node(std::make_shared(bias, x)); } - matmul = context.mark_node(std::make_shared(matmul, bias)); + matmul = context.mark_node(std::make_shared(matmul, bias)); } return {matmul}; }; diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 607f0bd32db80d..00ef0e0a9d24f6 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -61,7 +61,6 @@ OP_CONVERTER(translate_clamp); OP_CONVERTER(translate_col2im); OP_CONVERTER(translate_constant); OP_CONVERTER(translate_conv_transposend); -OP_CONVERTER(translate_conv1d_ext); OP_CONVERTER(translate_convnd); OP_CONVERTER(translate_convolution); OP_CONVERTER(translate_convolution_mode); @@ -77,7 +76,6 @@ OP_CONVERTER(translate_dot); OP_CONVERTER(translate_elu); OP_CONVERTER(translate_embedding); OP_CONVERTER(translate_embedding_bag); -OP_CONVERTER(translate_embedding_ext); OP_CONVERTER(translate_empty); OP_CONVERTER(translate_empty_like); OP_CONVERTER(translate_erf); @@ -324,6 +322,10 @@ OP_CONVERTER(translate_unbind_int_fx); OP_CONVERTER(translate_unique2); OP_CONVERTER(translate_zeros_fx); OP_CONVERTER(translate_zeros_like_fx); +// Extensions +OP_CONVERTER(translate_conv1d_ext); +OP_CONVERTER(translate_embedding_ext); +OP_CONVERTER(translate_linear_awq); } // namespace op @@ -697,6 +699,7 @@ const std::unordered_map get_supported_ops_ts() { {"aten::zero", op::translate_zeros_like}, {"aten::zeros", op::translate_zeros}, {"aten::zeros_like", op::translate_zeros_like}, + {"ov_ext::awq_gemm", op::translate_linear_awq}, {"ov_ext::embedding", op::translate_embedding_ext}, {"ov_ext::conv1d", op::translate_conv1d_ext}, {"ov_ext::linear", op::translate_linear}, diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp index 752b9accb71d01..5cc7ec21f30911 100644 --- a/src/frontends/pytorch/src/utils.cpp +++ b/src/frontends/pytorch/src/utils.cpp @@ -42,7 +42,11 @@ using namespace ov::op; void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs) { auto num_inputs = context.get_input_size(); - FRONT_END_OP_CONVERSION_CHECK(num_inputs >= min_inputs, "Got less inputs than expected"); + FRONT_END_OP_CONVERSION_CHECK(num_inputs >= min_inputs, + "Got less inputs ", + num_inputs, + " than expected ", + min_inputs); for (auto i = max_inputs; i < num_inputs; i++) { FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(i), "Got more inputs than expected."); } diff --git a/tests/model_hub_tests/pytorch/detectron2_precommit b/tests/model_hub_tests/pytorch/detectron2_precommit index 155e4d2a359779..f98e44ad21871f 100644 --- a/tests/model_hub_tests/pytorch/detectron2_precommit +++ b/tests/model_hub_tests/pytorch/detectron2_precommit @@ -1,13 +1,8 @@ -COCO-Detection/faster_rcnn_R_50_C4_1x,none -COCO-Detection/faster_rcnn_R_50_DC5_3x,none COCO-Detection/faster_rcnn_R_50_FPN_1x,none COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x,none COCO-Detection/retinanet_R_50_FPN_1x,none COCO-Detection/rpn_R_50_C4_1x,none -COCO-Detection/rpn_R_50_FPN_1x,none COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x,none -COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x,none -COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_3x,none COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x,none COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x,none COCO-Keypoints/keypoint_rcnn_X_101_32x8d_FPN_3x,none @@ -19,8 +14,6 @@ LVISv0.5-InstanceSegmentation/mask_rcnn_R_50_FPN_1x,none LVISv0.5-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_1x,none Misc/cascade_mask_rcnn_R_50_FPN_3x,none Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv,none -Misc/mask_rcnn_R_50_FPN_3x_dconv_c3-c5,none -Misc/mask_rcnn_R_50_FPN_3x_gn,none Misc/mask_rcnn_R_50_FPN_3x_syncbn,none Misc/scratch_mask_rcnn_R_50_FPN_9x_syncbn,none PascalVOC-Detection/faster_rcnn_R_50_C4,none diff --git a/tests/model_hub_tests/pytorch/test_llm.py b/tests/model_hub_tests/pytorch/test_llm.py index e444f93db9d7ec..ba48634a070e39 100644 --- a/tests/model_hub_tests/pytorch/test_llm.py +++ b/tests/model_hub_tests/pytorch/test_llm.py @@ -15,10 +15,10 @@ from torch_utils import TestTorchConvertModel -def is_gptq_model(config): +def is_quantized_model(config): config_dict = config.to_dict() if not isinstance(config, dict) else config quantization_config = config_dict.get("quantization_config", None) - return quantization_config and quantization_config["quant_method"] == "gptq" + return quantization_config and quantization_config["quant_method"] in ["gptq", "awq"] def patch_gptq(): @@ -26,35 +26,83 @@ def patch_gptq(): orig_cuda_is_bf16_supported = torch.cuda.is_bf16_supported orig_cuda_get_device_capability = torch.cuda.get_device_capability orig_post_init_model = None + orig_gemm_forward = None torch.set_default_dtype(torch.float32) torch.cuda.is_available = lambda: True torch.cuda.is_bf16_supported = lambda: False torch.cuda.get_device_capability = lambda n: (9, 1) - from optimum.gptq import GPTQQuantizer + try: + from optimum.gptq import GPTQQuantizer - orig_post_init_model = GPTQQuantizer.post_init_model + orig_post_init_model = GPTQQuantizer.post_init_model - def post_init_model(self, model): - from auto_gptq import exllama_set_max_input_length + def post_init_model(self, model): + from auto_gptq import exllama_set_max_input_length - class StoreAttr(object): - pass + class StoreAttr(object): + pass - model.quantize_config = StoreAttr() - model.quantize_config.desc_act = self.desc_act - if self.desc_act and not self.disable_exllama and self.max_input_length is not None: - model = exllama_set_max_input_length(model, self.max_input_length) - return model + model.quantize_config = StoreAttr() + model.quantize_config.desc_act = self.desc_act + if self.desc_act and not self.disable_exllama and self.max_input_length is not None: + model = exllama_set_max_input_length(model, self.max_input_length) + return model + + GPTQQuantizer.post_init_model = post_init_model + except ImportError: + pass + + try: + # patch GEMM module to work without CUDA GPU + from awq.modules.linear.gemm import WQLinearMMFunction + from awq.utils.packing_utils import dequantize_gemm + + def new_forward( + ctx, + x, + qweight, + qzeros, + scales, + w_bit=4, + group_size=128, + bias=None, + out_features=0, + ): + ctx.out_features = out_features + + out_shape = x.shape[:-1] + (out_features,) + x = x.to(torch.float16) - GPTQQuantizer.post_init_model = post_init_model - return (orig_cuda_is_available, orig_cuda_is_bf16_supported, orig_cuda_get_device_capability), orig_post_init_model + out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size) + out = torch.matmul(x, out) + out = out + bias if bias is not None else out + out = out.reshape(out_shape) -def unpatch_gptq(orig_cuda_check, orig_post_init_model): - from optimum.gptq import GPTQQuantizer + if len(out.shape) == 2: + out = out.unsqueeze(0) + return out + + orig_gemm_forward = WQLinearMMFunction.forward + WQLinearMMFunction.forward = new_forward + except ImportError: + pass + return (orig_cuda_is_available, orig_cuda_is_bf16_supported, orig_cuda_get_device_capability), orig_post_init_model, orig_gemm_forward + + +def unpatch_gptq(orig_cuda_check, orig_post_init_model, orig_gemm_forward): torch.cuda.is_available, torch.cuda.is_bf16_supported, torch.cuda.get_device_capability = orig_cuda_check - GPTQQuantizer.post_init_model = orig_post_init_model + try: + from optimum.gptq import GPTQQuantizer + GPTQQuantizer.post_init_model = orig_post_init_model + except ImportError: + pass + try: + from awq.modules.linear.gemm import WQLinearMMFunction + WQLinearMMFunction.forward = orig_gemm_forward + except ImportError: + pass def to_numpy(t): @@ -88,7 +136,7 @@ def flattenize_outputs(outputs): class TestLLMModel(TestTorchConvertModel): def setup_class(self): self.infer_timeout = 1800 - self.cuda_available, self.gptq_postinit = None, None + self.cuda_available, self.gptq_postinit, self.orig_gemm_forward = None, None, None @retry(3, exceptions=(OSError,), delay=1) def load_model(self, name, type): @@ -99,11 +147,12 @@ def load_model(self, name, type): except Exception: config = {} model_kwargs = {"torchscript": True, "trust_remote_code": True} - is_gptq = is_gptq_model(config) + is_quant = is_quantized_model(config) is_gpt2 = name == "openai-community/gpt2" - if is_gptq: - self.cuda_available, self.gptq_postinit = patch_gptq() + if is_quant: + self.cuda_available, self.gptq_postinit, self.orig_gemm_forward = patch_gptq() + model_kwargs["torch_dtype"] = "auto" model_kwargs["torch_dtype"] = torch.float32 self.ov_config = {"DYNAMIC_QUANTIZATION_GROUP_SIZE": "0"} elif is_gpt2: @@ -113,7 +162,7 @@ def load_model(self, name, type): t = AutoTokenizer.from_pretrained(name, trust_remote_code=True) self.model = AutoModelForCausalLM.from_pretrained(name, **model_kwargs) - if is_gptq: + if is_quant: model = self.model else: assert self.model.config.torch_dtype in [ @@ -175,8 +224,8 @@ def convert_model_impl(self, model_obj): def teardown_method(self): # restore after gptq patching if self.cuda_available is not None: - unpatch_gptq(self.cuda_available, self.gptq_postinit) - self.cuda_available, self.gptq_postinit = None, None + unpatch_gptq(self.cuda_available, self.gptq_postinit, self.orig_gemm_forward) + self.cuda_available, self.gptq_postinit, self.orig_gemm_forward = None, None, None super().teardown_method() @staticmethod @@ -191,7 +240,8 @@ def get_pkv(model, tokenizer): @pytest.mark.parametrize("type,name", [ ("opt_gptq", "katuni4ka/opt-125m-gptq"), ("llama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"), - ("gpt2", "openai-community/gpt2") + ("gpt2", "openai-community/gpt2"), + ("llama_awq", "casperhansen/tinyllama-1b-awq") ]) @pytest.mark.precommit @pytest.mark.nightly @@ -210,6 +260,7 @@ def test_convert_model_precommit(self, name, type, ie_device): ("bloom_gptq", "sbolouki/bloom-1b7-gptq"), ("cohere_gptq", "shuyuej/aya-23-8B-GPTQ"), ("mbart_gptq", "Shivam098/opt-translation"), + ("llama_awq", "TheBloke/open-llama-3b-v2-wizard-evol-instuct-v2-196k-AWQ") ]) @pytest.mark.nightly def test_convert_model_nightly(self, name, type, ie_device): @@ -236,6 +287,8 @@ def test_convert_model_nightly(self, name, type, ie_device): marks=pytest.mark.xfail(reason="GPTQ QUANT_TYPE=cuda is not supported")), pytest.param("llama3_gptq", "TechxGenus/Meta-Llama-3-8B-GPTQ", marks=pytest.mark.xfail(reason="GPTQ QUANT_TYPE=cuda is not supported")), + ("qwen2_awq", "Qwen/Qwen2.5-Coder-32B-Instruct-AWQ"), + ("mixstral_awq", "TheBloke/SauerkrautLM-Mixtral-8x7B-AWQ"), ]) def test_convert_model_very_large(self, name, type, ie_device): self.run(model_name=name, model_link=type, ie_device=ie_device) diff --git a/tests/requirements_pytorch b/tests/requirements_pytorch index 56446beba12600..be304155e2afc0 100644 --- a/tests/requirements_pytorch +++ b/tests/requirements_pytorch @@ -19,6 +19,7 @@ pytest-html==4.1.1 pytest-xdist[psutil]==3.6.1 defusedxml==0.7.1 +autoawq==0.2.7; platform_system == "Linux" and platform_machine == "x86_64" auto-gptq==0.7.1; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.12" av==13.0.0 basicsr==1.4.2; python_version < "3.12"