Commit: lint
justinchuby committed Jul 19, 2024
1 parent 9dee02a commit 1937a0d
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions tests/function_libs/torch_lib/quantization_test.py
@@ -1,14 +1,18 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 """Test quantized model export."""

 from __future__ import annotations

+import unittest
+
 import onnx
 import torch
-import unittest
-from torch.ao.quantization import quantize_pt2e
 import torch._export
+from torch.ao.quantization import quantize_pt2e
 from torch.ao.quantization.quantizer import xnnpack_quantizer


 class QuantizedModelExportTest(unittest.TestCase):
     def test_simple_quantized_model(self):
         class TestModel(torch.nn.Module):
@@ -19,15 +23,16 @@ def __init__(self):
             def forward(self, x):
                 return self.linear(x)

-
         example_inputs = (torch.randn(1, 5),)
         model = TestModel().eval()

         # Step 1. program capture
         pt2e_torch_model = torch._export.capture_pre_autograd_graph(model, example_inputs)

         # Step 2. quantization
-        quantizer = xnnpack_quantizer.XNNPACKQuantizer().set_global(xnnpack_quantizer.get_symmetric_quantization_config())
+        quantizer = xnnpack_quantizer.XNNPACKQuantizer().set_global(
+            xnnpack_quantizer.get_symmetric_quantization_config()
+        )
         pt2e_torch_model = quantize_pt2e.prepare_pt2e(pt2e_torch_model, quantizer)

         # Run the prepared model with sample input data to ensure that internal observers are populated with correct values
[remaining diff lines not shown]
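
For readers unfamiliar with the flow this test exercises, below is a minimal, self-contained sketch of the full PT2E quantize-then-export sequence: capture, prepare with the XNNPACK quantizer, calibrate, convert, and export to ONNX. The Linear(5, 10) layer, the calibration call, and the torch.onnx.dynamo_export / onnx.checker steps at the end are assumptions filled in around the visible comments, not the literal body of the truncated test.

# Sketch of the full PT2E flow (assumptions noted inline); not the literal test code.
import onnx
import torch
import torch._export
from torch.ao.quantization import quantize_pt2e
from torch.ao.quantization.quantizer import xnnpack_quantizer


class TestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Assumed layer shape, consistent with the (1, 5) example input below.
        self.linear = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(1, 5),)
model = TestModel().eval()

# Step 1. Program capture: trace the model into a pre-autograd ATen graph.
pt2e_torch_model = torch._export.capture_pre_autograd_graph(model, example_inputs)

# Step 2. Quantization: insert observers according to the XNNPACK symmetric config.
quantizer = xnnpack_quantizer.XNNPACKQuantizer().set_global(
    xnnpack_quantizer.get_symmetric_quantization_config()
)
pt2e_torch_model = quantize_pt2e.prepare_pt2e(pt2e_torch_model, quantizer)

# Step 3. Calibration: run sample data through the prepared model so the
# internal observers record activation ranges.
pt2e_torch_model(*example_inputs)

# Step 4. Convert the observed graph into a quantized graph.
pt2e_torch_model = quantize_pt2e.convert_pt2e(pt2e_torch_model)

# Step 5. Export to ONNX and validate. The export entry point used by the
# actual test is not visible here; torch.onnx.dynamo_export is one option.
onnx_program = torch.onnx.dynamo_export(pt2e_torch_model, *example_inputs)
onnx.checker.check_model(onnx_program.model_proto)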
