
post training quantization #247

Open
shanek16 opened this issue Oct 25, 2023 · 2 comments

@shanek16

Thank you for your work.

I am trying to quantize the MiDaS DPT_Large model to INT8.

I have searched GitHub, Googled, and asked Bing whether there is a one-liner to quantize the model to INT8 given a folder of calibration images.

However, there does not seem to be such a way, and I found no examples of anyone else quantizing MiDaS.
I have tried quantization through the torch.quantization module, but got the following error:

Using cache found in /root/.cache/torch/hub/intel-isl_MiDaS_master
Traceback (most recent call last):
  File "quantize.py", line 58, in <module>
    output = prepared_model(data)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1533, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/torch/hub/intel-isl_MiDaS_master/midas/dpt_depth.py", line 166, in forward
    return super().forward(x).squeeze(dim=1)
  File "/root/.cache/torch/hub/intel-isl_MiDaS_master/midas/dpt_depth.py", line 114, in forward
    layers = self.forward_transformer(self.pretrained, x)
  File "/root/.cache/torch/hub/intel-isl_MiDaS_master/midas/backbones/vit.py", line 13, in forward_vit
    return forward_adapted_unflatten(pretrained, x, "forward_flex")
  File "/root/.cache/torch/hub/intel-isl_MiDaS_master/midas/backbones/utils.py", line 88, in forward_adapted_unflatten
    layer_1 = pretrained.activations["1"]
KeyError: '1'

I used this code:

import os
import copy
from PIL import Image
import torch
import torch.quantization
from torchvision import transforms
from sklearn.metrics import mean_squared_error
from torch.utils.data import random_split, Dataset
from torch.quantization import QConfig
from torch.quantization import default_observer, default_weight_observer


class CustomImageDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.img_names = os.listdir(img_dir)
        self.transform = transform

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        image = Image.open(img_path).convert("RGB")  # ensure 3 channels
        if self.transform:
            image = self.transform(image)
        return image.unsqueeze(0)  # add a batch dimension so the Subset can be iterated directly

# Load your model
model = torch.hub.load("intel-isl/MiDaS", "DPT_Large")
model.eval()  # Set the model to evaluation mode
original_model = copy.deepcopy(model)

# Define the transformations to be applied to the images (e.g., resize, normalization)
transform = transforms.Compose([
    transforms.Resize((384, 672)),
    transforms.ToTensor(),
])

# Load the calibration data
calibration_data = CustomImageDataset(img_dir='/workspace/data/safety/calib', transform=transform)
# Shuffle and split into calibration and validation halves (random_split shuffles)
calibration_set, validation_set = random_split(
    calibration_data,
    [len(calibration_data) - len(calibration_data) // 2, len(calibration_data) // 2],
)
# Prepare the model for quantization
# Create a custom qconfig
custom_qconfig = QConfig(
    activation=default_observer,     # affine quint8 observer for activations
    weight=default_weight_observer,  # weights need a symmetric qint8 observer
)

# Apply the custom qconfig to the whole model
model.qconfig = custom_qconfig
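# Note: with inplace=False, prepare() first makes a deepcopy of the model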
prepared_model = torch.quantization.prepare(model, inplace=False)

# Calibrate the model: running data through lets the observers record activation ranges
for data in calibration_set:
    with torch.no_grad():
        output = prepared_model(data)

# Convert to a quantized model (observed modules are replaced using the recorded ranges)
quantized_model = torch.quantization.convert(prepared_model, inplace=False)
torch.save(quantized_model, 'quantized_midas.pth')

quantized_model.eval()
original_model.eval()
mse = 0
num_samples = 0
try:
    for data in validation_set:  # the dataset yields images only, no labels
        with torch.no_grad():
            original_output = original_model(data)
            quantized_output = quantized_model(data)
            mse += mean_squared_error(original_output.cpu().numpy(), quantized_output.cpu().numpy())
            num_samples += 1
except Exception as e:
    print(f"Failed during validation: {e}")
    exit()

# Print the mean squared error over the validation set
print(f'Mean Squared Error on Validation Set: {mse / num_samples}')

For my environment I used the NVIDIA PyTorch 23.04 Docker image: https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-23-04.html

I am having a hard time trying to quantize MiDaS manually, as I am still not familiar with PyTorch's quantization functions.
It would be a great help if anyone who has succeeded in quantizing a MiDaS model could share how they did it.
Any comments or advice are also welcome.
Thanks.

@charliebudd

I don't think this behaviour is caused by quantisation itself. The error trace starts during your forward call to the prepared model...

# Calibrate the model
for data in calibration_set:
    with torch.no_grad():
        output = prepared_model(data)

However, I am seeing the same error, so there is still a problem here. I notice you are not applying the MiDaS transform in your code. This does not fix the issue for me, but you will want to do it. Probably give the tutorial another read...
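
For reference, a minimal sketch of applying the official MiDaS transform via torch.hub, mirroring the README usage (the image path is a placeholder):

import cv2
import torch

# Load the transforms that pair with the hub models
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.dpt_transform  # dpt_transform matches DPT_Large

# The MiDaS transforms expect a numpy RGB image rather than a PIL image
img = cv2.imread("example.jpg")             # placeholder path
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
input_batch = transform(img)                # returns a batched tensor, shape (1, 3, H, W)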

@charliebudd

I have tracked down the issue (for me at least). I was making a copy of the model using the Python built-in copy.deepcopy. The copy throws the error in the forward call, but the original works fine. I wouldn't be too surprised if the torch quantisation module does something similar internally, so this may well be the cause of @shanek16's issue.
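
To illustrate, a minimal sketch of the failure mode, assuming the MiDaS-style hook setup (the backbone saves intermediate outputs into an activations dict via forward hooks whose closures capture the original dict; the class below is a toy stand-in, not MiDaS code):

import copy
import torch
import torch.nn as nn

class ToyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)
        self.activations = {}
        acts = self.activations  # the hook closure captures THIS dict object

        def hook(module, inputs, output):
            acts["1"] = output

        self.layer.register_forward_hook(hook)

    def forward(self, x):
        self.layer(x)
        return self.activations["1"]

model = ToyBackbone()
clone = copy.deepcopy(model)

model(torch.randn(1, 4))  # works: the hook fills self.activations
clone(torch.randn(1, 4))  # KeyError: '1' -- deepcopy gives the clone a new
                          # activations dict, but functions are copied by
                          # reference, so the cloned hook still writes into
                          # the original model's dict

Re-registering the hooks on the copy, or loading a fresh model instead of copying, avoids the KeyError.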
