[torchlib] Fix and improve quantization support #1737

Merged
justinchuby merged 8 commits into main from justinchu/quant-improved on Jul 22, 2024

Conversation

justinchuby
Collaborator

@justinchuby justinchuby commented Jul 18, 2024

Fix bugs in the quantization implementation where the dtype of the zero point was not set correctly. Tested with the exporter.
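
For context, here is a minimal sketch of why the zero point's dtype matters. It assumes the usual affine quantization formula, q = clamp(round(x / scale) + zero_point, qmin, qmax), and uses a scale and zero point that appear in the exported graph in this thread. The helper names are hypothetical and this is plain Python, not the torchlib implementation. In ONNX, QuantizeLinear takes its output type from the zero point tensor, so a signed zero point such as -52 has to be emitted as int8 rather than uint8.

```python
# Hypothetical helpers illustrating affine (per-tensor) quantization.
# Plain Python for illustration only; not the torchlib code.

INT8_RANGE = (-128, 127)
UINT8_RANGE = (0, 255)


def fits(value: int, value_range: tuple[int, int]) -> bool:
    """Return True if value is representable in the given integer range."""
    lo, hi = value_range
    return lo <= value <= hi


def quantize(x: float, scale: float, zero_point: int, qmin: int, qmax: int) -> int:
    """q = clamp(round(x / scale) + zero_point, qmin, qmax)."""
    q = round(x / scale) + zero_point  # Python's round() rounds half to even
    return max(qmin, min(qmax, q))


def dequantize(q: int, scale: float, zero_point: int) -> float:
    """x is recovered approximately as (q - zero_point) * scale."""
    return (q - zero_point) * scale


# Scale and zero point as they appear in the exported graph in this thread.
scale, zero_point = 0.003849889151751995, -52
q = quantize(0.1, scale, zero_point, *INT8_RANGE)
x_hat = dequantize(q, scale, zero_point)

print(q, x_hat)
print(fits(zero_point, INT8_RANGE), fits(zero_point, UINT8_RANGE))
```

The round trip error stays within half a quantization step, and the zero point only fits the signed range.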

@justinchuby
Collaborator Author

Tested with:

import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(1, 5),)
m = M().eval()

# Step 1. program capture

from torch._export import capture_pre_autograd_graph

pt2e_torch_model = capture_pre_autograd_graph(m, example_inputs)

# Step 2. quantization
from torch.ao.quantization.quantize_pt2e import (
    prepare_pt2e,
    convert_pt2e,
)

from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
pt2e_torch_model = prepare_pt2e(pt2e_torch_model, quantizer)

# Run the prepared model with sample input data to ensure that internal observers are populated with correct values
pt2e_torch_model(*example_inputs)

# Convert the prepared model to a quantized model
pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)


program = torch.export.export(pt2e_torch_model, example_inputs)
# The exported program contains ATen ops plus the quantized_decomposed quantize/dequantize ops
print(program)

# Convert to ONNX
import torch_onnx

torch_onnx.patch_torch(error_report=True)

onnx_program = torch.onnx.export(program, example_inputs, "quantized.textproto")

@justinchuby
Collaborator Author

ir_version: 9
producer_name: "torch"
producer_version: "2.3.1+cu121"
graph {
  node {
    output: "val_0"
    name: "node_Constant_0"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 3
        raw_data: "E"
      }
      type: TENSOR
    }
    metadata_props {
      key: "namespace"
      value: ": torch.fx.graph_module.GraphModule/quantize_per_tensor: quantized_decomposed.quantize_per_tensor.default"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'torch.fx.graph_module.GraphModule\', \'quantized_decomposed.quantize_per_tensor.default\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%quantize_per_tensor : [num_users=1] = call_function[target=torch.ops.quantized_decomposed.quantize_per_tensor.default](args = (%arg2_1, 0.004384273663163185, 69, -128, 127, torch.int8), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'quantize_per_tensor\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "  File \"<eval_with_key>.11\", line 7, in forward\n    quantize_per_tensor_default = torch.ops.quantized_decomposed.quantize_per_tensor.default(arg0_1, 0.004384273663163185, 69, -128, 127, torch.int8);  arg0_1 = None\n"
    }
  }
  node {
    output: "val_1"
    name: "node_Constant_1"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 1
        raw_data: "\364\251\217;"
      }
      type: TENSOR
    }
    metadata_props {
      key: "namespace"
      value: ": torch.fx.graph_module.GraphModule/quantize_per_tensor: quantized_decomposed.quantize_per_tensor.default"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'torch.fx.graph_module.GraphModule\', \'quantized_decomposed.quantize_per_tensor.default\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%quantize_per_tensor : [num_users=1] = call_function[target=torch.ops.quantized_decomposed.quantize_per_tensor.default](args = (%arg2_1, 0.004384273663163185, 69, -128, 127, torch.int8), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'quantize_per_tensor\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "  File \"<eval_with_key>.11\", line 7, in forward\n    quantize_per_tensor_default = torch.ops.quantized_decomposed.quantize_per_tensor.default(arg0_1, 0.004384273663163185, 69, -128, 127, torch.int8);  arg0_1 = None\n"
    }
  }
  node {
    input: "arg2_1"
    input: "val_1"
    input: "val_0"
    output: "val_quantize_per_tensor"
    name: "node_QuantizeLinear_2"
    op_type: "QuantizeLinear"
    attribute {
      name: "axis"
      i: 1
      type: INT
    }
    metadata_props {
      key: "namespace"
      value: ": torch.fx.graph_module.GraphModule/quantize_per_tensor: quantized_decomposed.quantize_per_tensor.default"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'torch.fx.graph_module.GraphModule\', \'quantized_decomposed.quantize_per_tensor.default\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%quantize_per_tensor : [num_users=1] = call_function[target=torch.ops.quantized_decomposed.quantize_per_tensor.default](args = (%arg2_1, 0.004384273663163185, 69, -128, 127, torch.int8), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'quantize_per_tensor\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "  File \"<eval_with_key>.11\", line 7, in forward\n    quantize_per_tensor_default = torch.ops.quantized_decomposed.quantize_per_tensor.default(arg0_1, 0.004384273663163185, 69, -128, 127, torch.int8);  arg0_1 = None\n"
    }
  }
  node {
    output: "val_2"
    name: "node_Constant_3"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 3
        raw_data: "E"
      }
      type: TENSOR
    }
    metadata_props {
      key: "namespace"
      value: ": torch.fx.graph_module.GraphModule/dequantize_per_tensor: quantized_decomposed.dequantize_per_tensor.default"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'torch.fx.graph_module.GraphModule\', \'quantized_decomposed.dequantize_per_tensor.default\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%dequantize_per_tensor : [num_users=1] = call_function[target=torch.ops.quantized_decomposed.dequantize_per_tensor.default](args = (%quantize_per_tensor, 0.004384273663163185, 69, -128, 127, torch.int8), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'dequantize_per_tensor\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "  File \"/home/justinchu/dev/torch-onnx/quant.py\", line 10, in forward\n    return self.linear(x)\n"
    }
  }
  node {
    input: "val_quantize_per_tensor"
    input: "val_1"
    input: "val_2"
    output: "val_dequantize_per_tensor"
    name: "node_DequantizeLinear_4"
    op_type: "DequantizeLinear"
    attribute {
      name: "axis"
      i: 1
      type: INT
    }
    metadata_props {
      key: "namespace"
      value: ": torch.fx.graph_module.GraphModule/dequantize_per_tensor: quantized_decomposed.dequantize_per_tensor.default"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'torch.fx.graph_module.GraphModule\', \'quantized_decomposed.dequantize_per_tensor.default\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%dequantize_per_tensor : [num_users=1] = call_function[target=torch.ops.quantized_decomposed.dequantize_per_tensor.default](args = (%quantize_per_tensor, 0.004384273663163185, 69, -128, 127, torch.int8), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'dequantize_per_tensor\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "  File \"/home/justinchu/dev/torch-onnx/quant.py\", line 10, in forward\n    return self.linear(x)\n"
    }
  }
  node {
    output: "val_3"
    name: "node_Constant_5"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 3
        raw_data: "\000"
      }
      type: TENSOR
    }
    metadata_props {
      key: "namespace"
      value: ": torch.fx.graph_module.GraphModule/quantize_per_tensor_1: quantized_decomposed.quantize_per_tensor.default"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'torch.fx.graph_module.GraphModule\', \'quantized_decomposed.quantize_per_tensor.default\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%quantize_per_tensor_1 : [num_users=1] = call_function[target=torch.ops.quantized_decomposed.quantize_per_tensor.default](args = (%arg0_1, 0.003469259710982442, 0, -127, 127, torch.int8), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'quantize_per_tensor_1\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "  File \"/home/justinchu/dev/torch-onnx/quant.py\", line 10, in forward\n    return self.linear(x)\n"
    }
  }
  node {
    output: "val_4"
    name: "node_Constant_6"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 1
        raw_data: "\205\\c;"
      }
      type: TENSOR
    }
    metadata_props {
      key: "namespace"
      value: ": torch.fx.graph_module.GraphModule/quantize_per_tensor_1: quantized_decomposed.quantize_per_tensor.default"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'torch.fx.graph_module.GraphModule\', \'quantized_decomposed.quantize_per_tensor.default\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%quantize_per_tensor_1 : [num_users=1] = call_function[target=torch.ops.quantized_decomposed.quantize_per_tensor.default](args = (%arg0_1, 0.003469259710982442, 0, -127, 127, torch.int8), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'quantize_per_tensor_1\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "  File \"/home/justinchu/dev/torch-onnx/quant.py\", line 10, in forward\n    return self.linear(x)\n"
    }
  }
  node {
    input: "linear_weight"
    input: "val_4"
    input: "val_3"
    output: "val_quantize_per_tensor_1"
    name: "node_QuantizeLinear_7"
    op_type: "QuantizeLinear"
    attribute {
      name: "axis"
      i: 1
      type: INT
    }
    metadata_props {
      key: "namespace"
      value: ": torch.fx.graph_module.GraphModule/quantize_per_tensor_1: quantized_decomposed.quantize_per_tensor.default"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'torch.fx.graph_module.GraphModule\', \'quantized_decomposed.quantize_per_tensor.default\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%quantize_per_tensor_1 : [num_users=1] = call_function[target=torch.ops.quantized_decomposed.quantize_per_tensor.default](args = (%arg0_1, 0.003469259710982442, 0, -127, 127, torch.int8), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'quantize_per_tensor_1\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "  File \"/home/justinchu/dev/torch-onnx/quant.py\", line 10, in forward\n    return self.linear(x)\n"
    }
  }
  node {
    output: "val_5"
    name: "node_Constant_8"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 3
        raw_data: "\000"
      }
      type: TENSOR
    }
    metadata_props {
      key: "namespace"
      value: ": torch.fx.graph_module.GraphModule/dequantize_per_tensor_1: quantized_decomposed.dequantize_per_tensor.default"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'torch.fx.graph_module.GraphModule\', \'quantized_decomposed.dequantize_per_tensor.default\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%dequantize_per_tensor_1 : [num_users=1] = call_function[target=torch.ops.quantized_decomposed.dequantize_per_tensor.default](args = (%quantize_per_tensor_1, 0.003469259710982442, 0, -127, 127, torch.int8), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'dequantize_per_tensor_1\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "  File \"/home/justinchu/dev/torch-onnx/quant.py\", line 10, in forward\n    return self.linear(x)\n"
    }
  }
  node {
    input: "val_quantize_per_tensor_1"
    input: "val_4"
    input: "val_5"
    output: "val_dequantize_per_tensor_1"
    name: "node_DequantizeLinear_9"
    op_type: "DequantizeLinear"
    attribute {
      name: "axis"
      i: 1
      type: INT
    }
    metadata_props {
      key: "namespace"
      value: ": torch.fx.graph_module.GraphModule/dequantize_per_tensor_1: quantized_decomposed.dequantize_per_tensor.default"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'torch.fx.graph_module.GraphModule\', \'quantized_decomposed.dequantize_per_tensor.default\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%dequantize_per_tensor_1 : [num_users=1] = call_function[target=torch.ops.quantized_decomposed.dequantize_per_tensor.default](args = (%quantize_per_tensor_1, 0.003469259710982442, 0, -127, 127, torch.int8), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'dequantize_per_tensor_1\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "  File \"/home/justinchu/dev/torch-onnx/quant.py\", line 10, in forward\n    return self.linear(x)\n"
    }
  }
  node {
    input: "val_dequantize_per_tensor_1"
    output: "val_t"
    name: "node_Transpose_10"
    op_type: "Transpose"
    attribute {
      name: "perm"
      ints: 1
      ints: 0
      type: INTS
    }
    metadata_props {
      key: "namespace"
      value: ": torch.fx.graph_module.GraphModule/linear: torch.nn.modules.linear.Linear/t: aten.t.default"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'torch.fx.graph_module.GraphModule\', \'torch.nn.modules.linear.Linear\', \'aten.t.default\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%dequantize_per_tensor_1,), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'linear\', \'t\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "  File \"/home/justinchu/dev/torch-onnx/quant.py\", line 10, in forward\n    return self.linear(x)\n"
    }
  }
  node {
    input: "linear_bias"
    input: "val_dequantize_per_tensor"
    input: "val_t"
    output: "val_addmm"
    name: "node_aten_addmm_11"
    op_type: "aten_addmm"
    attribute {
      name: "beta"
      f: 1.0
      type: FLOAT
    }
    attribute {
      name: "alpha"
      f: 1.0
      type: FLOAT
    }
    domain: "pkg.onnxscript.torch_lib"
    metadata_props {
      key: "namespace"
      value: ": torch.fx.graph_module.GraphModule/linear: torch.nn.modules.linear.Linear/addmm: aten.addmm.default"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'torch.fx.graph_module.GraphModule\', \'torch.nn.modules.linear.Linear\', \'aten.addmm.default\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %dequantize_per_tensor, %t), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'linear\', \'addmm\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "  File \"/home/justinchu/dev/torch-onnx/quant.py\", line 10, in forward\n    return self.linear(x)\n"
    }
  }
  node {
    output: "val_6"
    name: "node_Constant_12"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 3
        raw_data: "\314"
      }
      type: TENSOR
    }
    metadata_props {
      key: "namespace"
      value: ": torch.fx.graph_module.GraphModule/quantize_per_tensor_2: quantized_decomposed.quantize_per_tensor.default"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'torch.fx.graph_module.GraphModule\', \'quantized_decomposed.quantize_per_tensor.default\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%quantize_per_tensor_2 : [num_users=1] = call_function[target=torch.ops.quantized_decomposed.quantize_per_tensor.default](args = (%addmm, 0.003849889151751995, -52, -128, 127, torch.int8), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'quantize_per_tensor_2\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "  File \"/home/justinchu/dev/torch-onnx/quant.py\", line 10, in forward\n    return self.linear(x)\n"
    }
  }
  node {
    output: "val_7"
    name: "node_Constant_13"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 1
        raw_data: "lN|;"
      }
      type: TENSOR
    }
    metadata_props {
      key: "namespace"
      value: ": torch.fx.graph_module.GraphModule/quantize_per_tensor_2: quantized_decomposed.quantize_per_tensor.default"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'torch.fx.graph_module.GraphModule\', \'quantized_decomposed.quantize_per_tensor.default\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%quantize_per_tensor_2 : [num_users=1] = call_function[target=torch.ops.quantized_decomposed.quantize_per_tensor.default](args = (%addmm, 0.003849889151751995, -52, -128, 127, torch.int8), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'quantize_per_tensor_2\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "  File \"/home/justinchu/dev/torch-onnx/quant.py\", line 10, in forward\n    return self.linear(x)\n"
    }
  }
  node {
    input: "val_addmm"
    input: "val_7"
    input: "val_6"
    output: "val_quantize_per_tensor_2"
    name: "node_QuantizeLinear_14"
    op_type: "QuantizeLinear"
    attribute {
      name: "axis"
      i: 1
      type: INT
    }
    metadata_props {
      key: "namespace"
      value: ": torch.fx.graph_module.GraphModule/quantize_per_tensor_2: quantized_decomposed.quantize_per_tensor.default"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'torch.fx.graph_module.GraphModule\', \'quantized_decomposed.quantize_per_tensor.default\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%quantize_per_tensor_2 : [num_users=1] = call_function[target=torch.ops.quantized_decomposed.quantize_per_tensor.default](args = (%addmm, 0.003849889151751995, -52, -128, 127, torch.int8), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'quantize_per_tensor_2\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "  File \"/home/justinchu/dev/torch-onnx/quant.py\", line 10, in forward\n    return self.linear(x)\n"
    }
  }
  node {
    output: "val_8"
    name: "node_Constant_15"
    op_type: "Constant"
    attribute {
      name: "value"
      t {
        data_type: 3
        raw_data: "\314"
      }
      type: TENSOR
    }
    metadata_props {
      key: "namespace"
      value: ": torch.fx.graph_module.GraphModule/dequantize_per_tensor_2: quantized_decomposed.dequantize_per_tensor.default"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'torch.fx.graph_module.GraphModule\', \'quantized_decomposed.dequantize_per_tensor.default\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%dequantize_per_tensor_2 : [num_users=1] = call_function[target=torch.ops.quantized_decomposed.dequantize_per_tensor.default](args = (%quantize_per_tensor_2, 0.003849889151751995, -52, -128, 127, torch.int8), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'dequantize_per_tensor_2\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "  File \"<eval_with_key>.11\", line 15, in forward\n    dequantize_per_tensor_default_2 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_2, 0.003849889151751995, -52, -128, 127, torch.int8);  quantize_per_tensor_default_2 = None\n"
    }
  }
  node {
    input: "val_quantize_per_tensor_2"
    input: "val_7"
    input: "val_8"
    output: "val_dequantize_per_tensor_2"
    name: "node_DequantizeLinear_16"
    op_type: "DequantizeLinear"
    attribute {
      name: "axis"
      i: 1
      type: INT
    }
    metadata_props {
      key: "namespace"
      value: ": torch.fx.graph_module.GraphModule/dequantize_per_tensor_2: quantized_decomposed.dequantize_per_tensor.default"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'torch.fx.graph_module.GraphModule\', \'quantized_decomposed.dequantize_per_tensor.default\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%dequantize_per_tensor_2 : [num_users=1] = call_function[target=torch.ops.quantized_decomposed.dequantize_per_tensor.default](args = (%quantize_per_tensor_2, 0.003849889151751995, -52, -128, 127, torch.int8), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'dequantize_per_tensor_2\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "  File \"<eval_with_key>.11\", line 15, in forward\n    dequantize_per_tensor_default_2 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_2, 0.003849889151751995, -52, -128, 127, torch.int8);  quantize_per_tensor_default_2 = None\n"
    }
  }
  name: "main_graph"
  initializer {
    dims: 10
    dims: 5
    data_type: 1
    name: "linear_weight"
    raw_data: "*\372\310\275\245\232\206=\314\225\341\276\371\214}\276\374%7\276\262\354\313<\205\241\273\276Q1\266>L\272v>N\rf>\200\263e>\321W\336>\017\016M\276\205\204\014\274\361@\253>\351\372\246\276\201\r\271\275\3015~\276\300\010N\276;\205n\276\002U\216=\033\253\340>l\016\031>\035\264\030>\311\032\226>\345w]\276*\346\262\275\300\356\025\276\034\201\271\276~~e>\010\245\310>\204\206\336\275\202,\\>%\3754\276\271\360\310>\242@\272>O.\226\275D\340\271\276s~\313\276~*\231\276\002*\276\276\360\032\222>\"\262G\276\227\326\217=GVC>\225\r\234>v\370,=\314\223\206\276\225\212\223>\337\002Q\275"
  }
  initializer {
    dims: 10
    data_type: 1
    name: "linear_bias"
    raw_data: "V\336\021\276TC4>\306y\250>h\225\277>\330\351\326>\'5n\274\275\036\270>\341bf\276\366\230\202\2768\222\303>"
  }
  input {
    name: "arg2_1"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 5
          }
        }
      }
    }
    metadata_props {
      key: "pkg.torch.export.graph_signature.InputSpec.kind"
      value: "USER_INPUT"
    }
    metadata_props {
      key: "pkg.torch.export.graph_signature.InputSpec.persistent"
      value: "None"
    }
  }
  input {
    name: "linear_weight"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 10
          }
          dim {
            dim_value: 5
          }
        }
      }
    }
    metadata_props {
      key: "pkg.torch.export.graph_signature.InputSpec.kind"
      value: "PARAMETER"
    }
    metadata_props {
      key: "pkg.torch.export.graph_signature.InputSpec.persistent"
      value: "None"
    }
    metadata_props {
      key: "pkg.torch.onnx.original_node_name"
      value: "arg0_1"
    }
  }
  input {
    name: "linear_bias"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 10
          }
        }
      }
    }
    metadata_props {
      key: "pkg.torch.export.graph_signature.InputSpec.kind"
      value: "PARAMETER"
    }
    metadata_props {
      key: "pkg.torch.export.graph_signature.InputSpec.persistent"
      value: "None"
    }
    metadata_props {
      key: "pkg.torch.onnx.original_node_name"
      value: "arg1_1"
    }
  }
  output {
    name: "val_dequantize_per_tensor_2"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 10
          }
        }
      }
    }
    metadata_props {
      key: "pkg.torch.export.graph_signature.OutputSpec.kind"
      value: "USER_OUTPUT"
    }
  }
  value_info {
    name: "val_quantize_per_tensor"
    type {
      tensor_type {
        elem_type: 3
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
  value_info {
    name: "val_dequantize_per_tensor"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
  value_info {
    name: "val_quantize_per_tensor_1"
    type {
      tensor_type {
        elem_type: 3
        shape {
          dim {
            dim_value: 10
          }
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
  value_info {
    name: "val_dequantize_per_tensor_1"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 10
          }
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
  value_info {
    name: "val_t"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 5
          }
          dim {
            dim_value: 10
          }
        }
      }
    }
  }
  value_info {
    name: "val_addmm"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 10
          }
        }
      }
    }
  }
  value_info {
    name: "val_quantize_per_tensor_2"
    type {
      tensor_type {
        elem_type: 3
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 10
          }
        }
      }
    }
  }
}
opset_import {
  domain: ""
  version: 18
}
opset_import {
  domain: "pkg.onnxscript.torch_lib"
  version: 1
}
opset_import {
  domain: "pkg.onnxscript.torch_lib.common"
  version: 1
}
functions {
  name: "aten_addmm"
  input: "self"
  input: "mat1"
  input: "mat2"
  output: "return_val"
  node {
    input: "mat1"
    input: "mat2"
    input: "self"
    output: "return_val"
    name: "n0"
    op_type: "Gemm"
    attribute {
      name: "alpha"
      type: FLOAT
      ref_attr_name: "alpha"
    }
    attribute {
      name: "beta"
      type: FLOAT
      ref_attr_name: "beta"
    }
  }
  doc_string: "addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"
  opset_import {
    domain: ""
    version: 18
  }
  domain: "pkg.onnxscript.torch_lib"
  attribute_proto {
    name: "beta"
    f: 1.0
    type: FLOAT
  }
  attribute_proto {
    name: "alpha"
    f: 1.0
    type: FLOAT
  }
}
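
The constants in the dump above can be checked by hand. This sketch assumes the ONNX TensorProto convention that data_type 1 is FLOAT and data_type 3 is INT8, and that raw_data is stored little-endian. The helper names are hypothetical. It decodes the escaped byte strings and compares them with the scale and zero point arguments shown in the fx_node metadata.

```python
import struct


def decode_int8(raw: bytes) -> int:
    """Decode a single signed 8-bit integer (TensorProto data_type 3)."""
    return struct.unpack("<b", raw)[0]


def decode_float32(raw: bytes) -> float:
    """Decode a single little-endian float32 (TensorProto data_type 1)."""
    return struct.unpack("<f", raw)[0]


# Textproto uses octal escapes, which Python bytes literals also accept.
zero_points = {
    "val_0": decode_int8(b"E"),
    "val_3": decode_int8(b"\000"),
    "val_6": decode_int8(b"\314"),
}
scales = {
    "val_1": decode_float32(b"\364\251\217;"),
    "val_4": decode_float32(b"\205\\c;"),
    "val_7": decode_float32(b"lN|;"),
}

print(zero_points)
print(scales)
```

The decoded zero points are signed 8-bit values, which is consistent with the torch.int8 argument in the fx nodes.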

@justinchuby justinchuby added the topic: torch_lib Related to the torch/aten function lib in development label Jul 18, 2024

codecov bot commented Jul 18, 2024

Codecov Report

Attention: Patch coverage is 50.00000% with 8 lines in your changes missing coverage. Please review.

Project coverage is 74.73%. Comparing base (2401de4) to head (ce37828).

| Files | Patch % | Lines |
|---|---|---|
| ...unction_libs/torch_lib/ops/quantized_decomposed.py | 16.66% | 5 Missing ⚠️ |
| .../torch_lib/graph_building/_graph_building_torch.py | 33.33% | 1 Missing and 1 partial ⚠️ |
| onnxscript/function_libs/torch_lib/ops/common.py | 85.71% | 1 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1737      +/-   ##
==========================================
- Coverage   74.89%   74.73%   -0.17%     
==========================================
  Files         245      245              
  Lines       26358    26369      +11     
  Branches     4793     4795       +2     
==========================================
- Hits        19742    19706      -36     
- Misses       5695     5743      +48     
+ Partials      921      920       -1     


Contributor

@titaiwangms titaiwangms left a comment

Should we add a test?

@@ -56,3 +62,10 @@
result = op.Cast(a, to=dtype)

return result


def constant(array, dtype: int | onnx.TensorProto.DataType | ir.DataType) -> TensorType:

Check failure

Code scanning / lintrunner

PYLINT/E0601 Error

Using variable 'onnx' before assignment (used-before-assignment)
See used-before-assignment. To disable, use # pylint: disable=used-before-assignment
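
The diff does not show how `onnx` is imported, so the actual cause here is unknown. One pattern that commonly avoids this class of error, offered as an assumption rather than the fix applied in this PR, is to import type-only names under `typing.TYPE_CHECKING` and quote the annotation (or postpone evaluation with `from __future__ import annotations`). The sketch uses the standard-library `decimal` module as a stand-in for `onnx`, and `constant_like` is a hypothetical name.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; this import is never executed at runtime.
    import decimal


def constant_like(value: "int | decimal.Decimal") -> str:
    """The quoted annotation is stored as a string and not evaluated at definition time."""
    return str(value)


result = constant_like(3)
annotation = constant_like.__annotations__["value"]
print(result, annotation)
```

Because the annotation stays a string, the function can be defined and called even though the guarded import never runs.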

github-actions bot commented Jul 18, 2024

Test Results

     24 files  ± 0      24 suites  ±0   1h 51m 44s ⏱️ + 3m 53s
 13 931 tests + 1  10 611 ✅ +2    3 312 💤 ±0   8 ❌  - 1 
276 206 runs  +10  68 959 ✅ +3  207 222 💤 ±0  25 ❌ +7 

For more details on these failures, see this check.

Results for commit 64e6416. ± Comparison against base commit 2401de4.

♻️ This comment has been updated with latest results.

@justinchuby justinchuby merged commit 842f38d into main Jul 22, 2024
29 of 41 checks passed
@justinchuby justinchuby deleted the justinchu/quant-improved branch July 22, 2024 19:34