Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT FE] Support AWQ models with GEMM module (#27624) #27859

Open
wants to merge 1 commit into
base: releases/2024/5
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions src/bindings/python/src/openvino/frontend/pytorch/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,3 @@ def unpatch_model(model):
log.warning("Exception raised during GPTQ model unpatching. "
"Depending on the exact issue it may lead to broken "
"original model.\n%s", error)


def detect_gptq_model_raw(model):
return (model and getattr(model, 'config', None) and
getattr(model.config, 'quantization_config', None) and
model.config.quantization_config.quant_method == 'gptq')


def detect_gptq_model(model):
return (detect_gptq_model_raw(model) or
getattr(model, 'model', None) and
detect_gptq_model_raw(model.model))
73 changes: 73 additions & 0 deletions src/bindings/python/src/openvino/frontend/pytorch/quantized.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Optional
import torch
from openvino.frontend.pytorch import ModuleExtension, gptq
from openvino.frontend.pytorch.patch_model import patch_model, unpatch_model


def detect_quantized_model(model: torch.nn.Module) -> Optional[str]:
"""Detects the quantization method used in a given PyTorch model.

Args:
model (torch.nn.Module): The PyTorch model to check for quantization.

Returns:
str: The quantization method if available, otherwise None.
"""
if (model and getattr(model, "config", None)
and getattr(model.config, "quantization_config", None)):
return model.config.quantization_config.quant_method
if getattr(model, "model", None):
return detect_quantized_model(model.model)
return None


def patch_quantized(model: torch.nn.Module) -> None:
"""Patches a model based on its quantization type ("awq" or "gptq").

Args:
model (torch.nn.Module): The model to patch.

Raises:
RuntimeError: If the quantization type is unknown.
"""
quant_type = detect_quantized_model(model)
if quant_type == "awq":
extensions = {}
try:
from awq.modules.linear import WQLinear_GEMM
extensions[WQLinear_GEMM] = ModuleExtension(
WQLinear_GEMM, "ov_ext::awq_gemm",
convert=lambda module, target_op, *args, **kwargs: target_op(
args[0], module.qweight, module.qzeros, module.scales,
torch.tensor(module.group_size),
torch.tensor(module.w_bit), module.bias),
evaluate=lambda module, *args, **kwargs: torch.full(
list(args[0].shape[:-1]) + [module.out_features], 0.5,
dtype=torch.float32)) # type: ignore
except ImportError:
pass
patch_model(model, extensions,
"_openvino_quantized_patch_orig_forward") # type: ignore
elif quant_type == "gptq":
model._openvino_gptq_patched = True
gptq.patch_model(model) # type: ignore
else:
raise RuntimeError(f"Unknown quantization type: {quant_type}.")


def unpatch_quantized(model: torch.nn.Module) -> None:
"""Reverts the patching applied to a quantized PyTorch model.

Args:
model (torch.nn.Module): The model to unpatch.
"""
if getattr(model, "_openvino_gptq_patched", False):
gptq.unpatch_model(model) # type: ignore
del model._openvino_gptq_patched
else:
unpatch_model(model,
"_openvino_quantized_patch_orig_forward") # type: ignore
26 changes: 12 additions & 14 deletions src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
graph_has_ops,
)
from openvino.runtime import opset11 as ops
from openvino.frontend.pytorch import gptq, patch_model
from openvino.frontend.pytorch import quantized, patch_model
from openvino.frontend.pytorch.module_extension import ModuleExtension

import inspect
Expand Down Expand Up @@ -139,27 +139,25 @@ def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False)
patch_model.patch_model(
pt_module, self.module_extensions, orig_forward_name)

gptq_patched = False
if gptq.detect_gptq_model(pt_module):
patched = False
if quantized.detect_quantized_model(pt_module) is not None:
try:
gptq.patch_model(pt_module)
gptq_patched = True
quantized.patch_quantized(pt_module)
patched = True
except Exception as error:
log.warning(
"Failed patching of AutoGPTQ model. Error message:\n%s"
"\nTracing of the model will likely be unsuccessful or incorrect",
error)
gptq.unpatch_model(pt_module)
gptq_patched = False
"Failed patching of AutoGPTQ model. Error message:\n"
"Tracing of the model will likely be unsuccessful or incorrect",
exc_info=error)
quantized.unpatch_quantized(pt_module)
patched = False

try:
scripted = torch.jit.trace(
pt_module, **input_parameters, strict=False)
finally:
if gptq_patched:
gptq.unpatch_model(pt_module)
if self.module_extensions:
patch_model.unpatch_model(pt_module, orig_forward_name)
if patched:
quantized.unpatch_quantized(pt_module)

have_to_freeze_ops = ["prim::Uninitialized",
"prim::unchecked_cast", "aten::append"]
Expand Down
88 changes: 84 additions & 4 deletions src/frontends/pytorch/src/op/linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/subtract.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_linear(const NodeContext& context) {
// schema: aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
num_inputs_check(context, 2, 3);
Expand All @@ -20,17 +26,91 @@ OutputVector translate_linear(const NodeContext& context) {
if (weight.get_element_type() == element::f16 || weight.get_element_type() == element::bf16) {
// In case of patched linear it can have mixed fp16/bf16 and fp32 input type.
// In other cases these conversion is not required.
weight = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(weight, x));
weight = context.mark_node(std::make_shared<v1::ConvertLike>(weight, x));
}
auto matmul = context.mark_node(std::make_shared<ov::op::v0::MatMul>(x, weight, false, true));
auto matmul = context.mark_node(std::make_shared<v0::MatMul>(x, weight, false, true));
if (!context.input_is_none(2)) {
auto bias = context.get_input(2);

if (bias.get_element_type() == element::f16 || bias.get_element_type() == element::bf16) {
// Same reason as for weight.
bias = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(bias, x));
bias = context.mark_node(std::make_shared<v1::ConvertLike>(bias, x));
}
matmul = context.mark_node(std::make_shared<v1::Add>(matmul, bias));
}
return {matmul};
};

namespace {
uint32_t rearrange_awq_bits(uint32_t num) {
uint32_t result = 0;
uint32_t mask = 0xF;

// Rearrange each 4-bit part in accordance with the AWQ i32->u4 unpacking schema
result |= (num & (mask << 0)) << 0;
result |= (num & (mask << 16)) >> 12;
result |= (num & (mask << 4)) << 4;
result |= (num & (mask << 20)) >> 8;
result |= (num & (mask << 8)) << 8;
result |= (num & (mask << 24)) >> 4;
result |= (num & (mask << 12)) << 12;
result |= (num & (mask << 28)) >> 0;

return result;
}

Output<Node> rearrange_constant(const Output<Node>& c, uint32_t groups) {
auto constant = std::dynamic_pointer_cast<v0::Constant>(c.get_node_shared_ptr());
FRONT_END_OP_CONVERSION_CHECK(constant, "weight must be Constant.");
auto src = constant->get_data_ptr<uint32_t>();
auto initial_shape = constant->get_shape();
FRONT_END_OP_CONVERSION_CHECK(initial_shape.size() == 2, "Only 2D constants are supported.");
auto new_shape = Shape{initial_shape[0] / groups, groups, initial_shape[1] * 8};
auto new_qweight = std::make_shared<v0::Constant>(element::u4, new_shape);
auto dst = const_cast<uint32_t*>(reinterpret_cast<const uint32_t*>(new_qweight->get_data_ptr()));
for (size_t i = 0; i < shape_size(constant->get_shape()); i++) {
dst[i] = rearrange_awq_bits(src[i]);
}
return new_qweight;
}
} // namespace

OutputVector translate_linear_awq(const NodeContext& context) {
num_inputs_check(context, 4, 7);
auto x = context.get_input(0);
auto qweight = context.get_input(1);
auto qzeros = context.get_input(2);
auto scales = context.get_input(3);
auto groups = context.const_input<int64_t>(4);
auto bits = context.const_input<int64_t>(5);

FRONT_END_OP_CONVERSION_CHECK(bits == 4, "Only 4 bit AWQ is supported.");

auto new_qweight = rearrange_constant(qweight, static_cast<uint32_t>(groups));
auto new_qzeros = rearrange_constant(qzeros, 1);
new_qweight = context.mark_node(std::make_shared<v0::Convert>(new_qweight, scales.get_element_type()));
new_qzeros = context.mark_node(std::make_shared<v0::Convert>(new_qzeros, scales.get_element_type()));

auto w_s = context.mark_node(std::make_shared<v1::Subtract>(new_qweight, new_qzeros));
FRONT_END_OP_CONVERSION_CHECK(scales.get_partial_shape().is_static(), "Scales must be constant.");
auto scales_shape = scales.get_shape();
auto new_scales_shape =
v0::Constant::create(element::i32, {3}, std::vector<uint64_t>{scales_shape[0], 1, scales_shape[1]});
scales = context.mark_node(std::make_shared<v1::Reshape>(scales, new_scales_shape, false));
auto weight = context.mark_node(std::make_shared<v1::Multiply>(w_s, scales));
auto out_shape =
v0::Constant::create(element::i32, {2}, std::vector<int32_t>{static_cast<int32_t>(qweight.get_shape()[0]), -1});
weight = context.mark_node(std::make_shared<v1::Reshape>(weight, out_shape, false));
weight = context.mark_node(std::make_shared<v1::ConvertLike>(weight, x));

auto matmul = context.mark_node(std::make_shared<v0::MatMul>(x, weight, false, false));
if (!context.input_is_none(6)) {
auto bias = context.get_input(6);

if (bias.get_element_type() == element::f16 || bias.get_element_type() == element::bf16) {
bias = context.mark_node(std::make_shared<v1::ConvertLike>(bias, x));
}
matmul = context.mark_node(std::make_shared<ov::op::v1::Add>(matmul, bias));
matmul = context.mark_node(std::make_shared<v1::Add>(matmul, bias));
}
return {matmul};
};
Expand Down
7 changes: 5 additions & 2 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ OP_CONVERTER(translate_clamp);
OP_CONVERTER(translate_col2im);
OP_CONVERTER(translate_constant);
OP_CONVERTER(translate_conv_transposend);
OP_CONVERTER(translate_conv1d_ext);
OP_CONVERTER(translate_convnd);
OP_CONVERTER(translate_convolution);
OP_CONVERTER(translate_convolution_mode);
Expand All @@ -77,7 +76,6 @@ OP_CONVERTER(translate_dot);
OP_CONVERTER(translate_elu);
OP_CONVERTER(translate_embedding);
OP_CONVERTER(translate_embedding_bag);
OP_CONVERTER(translate_embedding_ext);
OP_CONVERTER(translate_empty);
OP_CONVERTER(translate_empty_like);
OP_CONVERTER(translate_erf);
Expand Down Expand Up @@ -324,6 +322,10 @@ OP_CONVERTER(translate_unbind_int_fx);
OP_CONVERTER(translate_unique2);
OP_CONVERTER(translate_zeros_fx);
OP_CONVERTER(translate_zeros_like_fx);
// Extensions
OP_CONVERTER(translate_conv1d_ext);
OP_CONVERTER(translate_embedding_ext);
OP_CONVERTER(translate_linear_awq);

} // namespace op

Expand Down Expand Up @@ -697,6 +699,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::zero", op::translate_zeros_like},
{"aten::zeros", op::translate_zeros},
{"aten::zeros_like", op::translate_zeros_like},
{"ov_ext::awq_gemm", op::translate_linear_awq},
{"ov_ext::embedding", op::translate_embedding_ext},
{"ov_ext::conv1d", op::translate_conv1d_ext},
{"ov_ext::linear", op::translate_linear},
Expand Down
6 changes: 5 additions & 1 deletion src/frontends/pytorch/src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ using namespace ov::op;

void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs) {
auto num_inputs = context.get_input_size();
FRONT_END_OP_CONVERSION_CHECK(num_inputs >= min_inputs, "Got less inputs than expected");
FRONT_END_OP_CONVERSION_CHECK(num_inputs >= min_inputs,
"Got less inputs ",
num_inputs,
" than expected ",
min_inputs);
for (auto i = max_inputs; i < num_inputs; i++) {
FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(i), "Got more inputs than expected.");
}
Expand Down
7 changes: 0 additions & 7 deletions tests/model_hub_tests/pytorch/detectron2_precommit
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
COCO-Detection/faster_rcnn_R_50_C4_1x,none
COCO-Detection/faster_rcnn_R_50_DC5_3x,none
COCO-Detection/faster_rcnn_R_50_FPN_1x,none
COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x,none
COCO-Detection/retinanet_R_50_FPN_1x,none
COCO-Detection/rpn_R_50_C4_1x,none
COCO-Detection/rpn_R_50_FPN_1x,none
COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x,none
COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x,none
COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_3x,none
COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x,none
COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x,none
COCO-Keypoints/keypoint_rcnn_X_101_32x8d_FPN_3x,none
Expand All @@ -19,8 +14,6 @@ LVISv0.5-InstanceSegmentation/mask_rcnn_R_50_FPN_1x,none
LVISv0.5-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_1x,none
Misc/cascade_mask_rcnn_R_50_FPN_3x,none
Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv,none
Misc/mask_rcnn_R_50_FPN_3x_dconv_c3-c5,none
Misc/mask_rcnn_R_50_FPN_3x_gn,none
Misc/mask_rcnn_R_50_FPN_3x_syncbn,none
Misc/scratch_mask_rcnn_R_50_FPN_9x_syncbn,none
PascalVOC-Detection/faster_rcnn_R_50_C4,none
Loading
Loading