
Commit 8c42a73: 2023-10-28 nightly release (209b2b3)

pytorchbot committed Oct 28, 2023
1 parent 36aab13 commit 8c42a73
Showing 13 changed files with 293 additions and 95 deletions.
14 changes: 7 additions & 7 deletions references/classification/README.md
@@ -289,10 +289,10 @@ For all post training quantized models, the settings are:
2. num_workers: 16
3. batch_size: 32
4. eval_batch_size: 128
-5. backend: 'fbgemm'
+5. qbackend: 'fbgemm'

```
-python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='$MODEL'
+python train_quantization.py --device='cpu' --post-training-quantize --qbackend='fbgemm' --model='$MODEL'
```
Here `$MODEL` is one of `googlenet`, `inception_v3`, `resnet18`, `resnet50`, `resnext101_32x8d`, `shufflenet_v2_x0_5` and `shufflenet_v2_x1_0`.
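For orientation, the flow these commands drive is roughly the following. This is a minimal sketch based on the `train_quantization.py` changes further down; the model choice here is illustrative, not prescribed:

```
import torch
import torchvision

qbackend = "fbgemm"  # the value passed via --qbackend

# Select the kernels the quantized model will run on.
if qbackend not in torch.backends.quantized.supported_engines:
    raise RuntimeError("Quantized backend not supported: " + qbackend)
torch.backends.quantized.engine = qbackend

# Build a float model, fuse modules, and attach the post-training qconfig.
model = torchvision.models.quantization.resnet18(quantize=False)
model.eval()
model.fuse_model(is_qat=False)
model.qconfig = torch.ao.quantization.get_default_qconfig(qbackend)
torch.ao.quantization.prepare(model, inplace=True)

# Calibrate on representative data, then convert to a quantized model.
with torch.inference_mode():
    model(torch.randn(8, 3, 224, 224))
torch.ao.quantization.convert(model, inplace=True)
```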

@@ -301,12 +301,12 @@ Here `$MODEL` is one of `googlenet`, `inception_v3`, `resnet18`, `resnet50`, `resnext101_32x8d`, `shufflenet_v2_x0_5` and `shufflenet_v2_x1_0`.
Here are commands that we use to quantize the `shufflenet_v2_x1_5` and `shufflenet_v2_x2_0` models.
```
# For shufflenet_v2_x1_5
-python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' \
+python train_quantization.py --device='cpu' --post-training-quantize --qbackend='fbgemm' \
--model=shufflenet_v2_x1_5 --weights="ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1" \
--train-crop-size 176 --val-resize-size 232 --data-path /datasets01_ontap/imagenet_full_size/061417/
# For shufflenet_v2_x2_0
-python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' \
+python train_quantization.py --device='cpu' --post-training-quantize --qbackend='fbgemm' \
--model=shufflenet_v2_x2_0 --weights="ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1" \
--train-crop-size 176 --val-resize-size 232 --data-path /datasets01_ontap/imagenet_full_size/061417/
```
@@ -317,7 +317,7 @@ For Mobilenet-v2, the model was trained with quantization aware training, the se
1. num_workers: 16
2. batch_size: 32
3. eval_batch_size: 128
-4. backend: 'qnnpack'
+4. qbackend: 'qnnpack'
5. learning-rate: 0.0001
6. num_epochs: 90
7. num_observer_update_epochs: 4
@@ -339,7 +339,7 @@ For Mobilenet-v3 Large, the model was trained with quantization aware training,
1. num_workers: 16
2. batch_size: 32
3. eval_batch_size: 128
-4. backend: 'qnnpack'
+4. qbackend: 'qnnpack'
5. learning-rate: 0.001
6. num_epochs: 90
7. num_observer_update_epochs: 4
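For context, the quantization aware training path that consumes the settings above looks roughly like this. It is a hedged sketch of the `prepare_qat` flow from `train_quantization.py` below; the model constructor is illustrative:

```
import torch
import torchvision

qbackend = "qnnpack"  # matches the qbackend setting listed above
torch.backends.quantized.engine = qbackend

model = torchvision.models.quantization.mobilenet_v2(quantize=False)
model.fuse_model(is_qat=True)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qbackend)
model.train()
torch.ao.quantization.prepare_qat(model, inplace=True)

# ... train for num_epochs; observers are updated for the first
# num_observer_update_epochs, then frozen ...

model.eval()
quantized_model = torch.ao.quantization.convert(model)
```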
@@ -359,7 +359,7 @@ For post training quant, device is set to CPU. For training, the device is set t
### Command to evaluate quantized models using the pre-trained weights:

```
-python train_quantization.py --device='cpu' --test-only --backend='<backend>' --model='<model_name>'
+python train_quantization.py --device='cpu' --test-only --qbackend='<qbackend>' --model='<model_name>'
```

For inception_v3 you need to pass the following extra parameters:
20 changes: 14 additions & 6 deletions references/classification/train_quantization.py
@@ -23,9 +23,9 @@ def main(args):
raise RuntimeError("Post training quantization example should not be performed on distributed mode")

# Set backend engine to ensure that quantized model runs on the correct kernels
-if args.backend not in torch.backends.quantized.supported_engines:
-raise RuntimeError("Quantized backend not supported: " + str(args.backend))
-torch.backends.quantized.engine = args.backend
+if args.qbackend not in torch.backends.quantized.supported_engines:
+raise RuntimeError("Quantized backend not supported: " + str(args.qbackend))
+torch.backends.quantized.engine = args.qbackend

device = torch.device(args.device)
torch.backends.cudnn.benchmark = True
@@ -55,7 +55,7 @@ def main(args):

if not (args.test_only or args.post_training_quantize):
model.fuse_model(is_qat=True)
-model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.qbackend)
torch.ao.quantization.prepare_qat(model, inplace=True)

if args.distributed and args.sync_bn:
@@ -89,7 +89,7 @@ def main(args):
)
model.eval()
model.fuse_model(is_qat=False)
-model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+model.qconfig = torch.ao.quantization.get_default_qconfig(args.qbackend)
torch.ao.quantization.prepare(model, inplace=True)
# Calibrate first
print("Calibrating")
@@ -161,7 +161,7 @@ def get_args_parser(add_help=True):

parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
parser.add_argument("--backend", default="qnnpack", type=str, help="fbgemm or qnnpack")
parser.add_argument("--qbackend", default="qnnpack", type=str, help="Quantized backend: fbgemm or qnnpack")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")

parser.add_argument(
@@ -257,9 +257,17 @@ def get_args_parser(add_help=True):
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")

return parser


if __name__ == "__main__":
args = get_args_parser().parse_args()
if args.backend in ("fbgemm", "qnnpack"):
raise ValueError(
"The --backend parameter has been re-purposed to specify the backend of the transforms (PIL or Tensor) "
"instead of the quantized backend. Please use the --qbackend parameter to specify the quantized backend."
)
main(args)
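A hedged usage sketch of the repurposed flag, exercising `get_args_parser()` from this file; the argument values are illustrative:

```
parser = get_args_parser()

# Old-style invocation still parses, but main() is never reached:
# the ValueError guard above rejects quantized-engine names passed via --backend.
args = parser.parse_args(["--backend", "fbgemm"])

# New-style invocation: the quantized engine goes through --qbackend.
args = parser.parse_args(["--qbackend", "fbgemm", "--backend", "PIL"])
assert args.qbackend == "fbgemm"
assert args.backend == "pil"  # type=str.lower normalizes the value
```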
8 changes: 4 additions & 4 deletions setup.py
@@ -129,8 +129,10 @@ def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "torchvision", "csrc")

main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + glob.glob(
os.path.join(extensions_dir, "ops", "*.cpp")
main_file = (
glob.glob(os.path.join(extensions_dir, "*.cpp"))
+ glob.glob(os.path.join(extensions_dir, "ops", "*.cpp"))
+ glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))
)
source_cpu = (
glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp"))
@@ -184,8 +186,6 @@ def get_extensions():
else:
source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu"))

-source_cuda += glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))

sources = main_file + source_cpu
extension = CppExtension

34 changes: 19 additions & 15 deletions test/datasets_utils.py
@@ -27,7 +27,11 @@
import torchvision.io
from common_utils import disable_console_output, get_tmp_dir
from torch.utils._pytree import tree_any
+from torch.utils.data import DataLoader
+from torchvision import tv_tensors
+from torchvision.datasets import wrap_dataset_for_transforms_v2
from torchvision.transforms.functional import get_dimensions
+from torchvision.transforms.v2.functional import get_size


__all__ = [
@@ -568,9 +572,6 @@ def test_transforms(self, config):

@test_all_configs
def test_transforms_v2_wrapper(self, config):
-from torchvision import tv_tensors
-from torchvision.datasets import wrap_dataset_for_transforms_v2

try:
with self.create_dataset(config) as (dataset, info):
for target_keys in [None, "all"]:
@@ -709,26 +710,29 @@ def _no_collate(batch):
return batch


-def check_transforms_v2_wrapper_spawn(dataset):
-# On Linux and Windows, the DataLoader forks the main process by default. This is not available on macOS, so new
-# subprocesses are spawned. This requires the whole pipeline including the dataset to be pickleable, which is what
-# we are enforcing here.
-if platform.system() != "Darwin":
-pytest.skip("Multiprocessing spawning is only checked on macOS.")
+def check_transforms_v2_wrapper_spawn(dataset, expected_size):
+# This check ensures that the wrapped datasets can be used with multiprocessing_context="spawn" in the DataLoader.
+# We also check that transforms are applied correctly as a non-regression test for
+# https://github.com/pytorch/vision/issues/8066
+# Implicitly, this also checks that the wrapped datasets are pickleable.

-from torch.utils.data import DataLoader
-from torchvision import tv_tensors
-from torchvision.datasets import wrap_dataset_for_transforms_v2
+# To save CI/test time, we only check on Windows where "spawn" is the default
+if platform.system() != "Windows":
+pytest.skip("Multiprocessing spawning is only checked on Windows.")

wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)

dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)

-for wrapped_sample in dataloader:
-assert tree_any(
-lambda item: isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)), wrapped_sample
+def resize_was_applied(item):
+# Checking the size of the output ensures that the Resize transform was correctly applied
+return isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)) and get_size(item) == list(
+expected_size
)

+for wrapped_sample in dataloader:
+assert tree_any(resize_was_applied, wrapped_sample)
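As background, a minimal runnable sketch of the constraint this helper exercises: with multiprocessing_context="spawn", workers must pickle the dataset, so everything it references has to be importable at module level. The dataset here is a stand-in, not part of this commit.

```
import torch
from torch.utils.data import DataLoader, Dataset

class TinyDataset(Dataset):
    # Defined at module level (not inside a function) so "spawn" workers can pickle it.
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return torch.full((3, 8, 8), idx, dtype=torch.uint8)

if __name__ == "__main__":
    loader = DataLoader(TinyDataset(), num_workers=2, multiprocessing_context="spawn")
    for batch in loader:
        print(batch.shape)  # torch.Size([1, 3, 8, 8])
```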


def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
r"""Create a random uint8 tensor.
62 changes: 38 additions & 24 deletions test/test_datasets.py
@@ -24,6 +24,7 @@
import torch.nn.functional as F
from common_utils import combinations_grid
from torchvision import datasets
+from torchvision.transforms import v2


class STL10TestCase(datasets_utils.ImageDatasetTestCase):
@@ -184,8 +185,9 @@ def test_combined_targets(self):
f"{actual} is not {expected}",

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(target_type="category") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(target_type="category", transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)
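The (123, 321) size is deliberately non-square: v2.Resize with a (height, width) tuple forces an exact output size, so seeing that size on the worker's output proves the transform actually ran. A minimal sketch of that property:

```
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
from torchvision.transforms.v2.functional import get_size

expected_size = (123, 321)
image = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))
resized = v2.Resize(size=expected_size)(image)
assert get_size(resized) == list(expected_size)  # get_size returns [H, W] as a list
```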


class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
@@ -263,8 +265,9 @@ def inject_fake_data(self, tmpdir, config):
return split_to_num_examples[config["split"]]

def test_transforms_v2_wrapper_spawn(self):
-with self.create_dataset() as (dataset, _):
-datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+expected_size = (123, 321)
+with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
@@ -391,9 +394,10 @@ def test_feature_types_target_polygon(self):
(polygon_target, info["expected_polygon_target"])

def test_transforms_v2_wrapper_spawn(self):
+expected_size = (123, 321)
for target_type in ["instance", "semantic", ["instance", "semantic"]]:
-with self.create_dataset(target_type=target_type) as (dataset, _):
-datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _):
+datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -427,8 +431,9 @@ def inject_fake_data(self, tmpdir, config):
return num_examples

def test_transforms_v2_wrapper_spawn(self):
-with self.create_dataset() as (dataset, _):
-datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+expected_size = (123, 321)
+with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
@@ -625,9 +630,10 @@ def test_images_names_split(self):
assert merged_imgs_names == all_imgs_names

def test_transforms_v2_wrapper_spawn(self):
+expected_size = (123, 321)
for target_type in ["identity", "bbox", ["identity", "bbox"]]:
-with self.create_dataset(target_type=target_type) as (dataset, _):
-datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _):
+datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
@@ -717,8 +723,9 @@ def add_bndbox(obj, bndbox=None):
return data

def test_transforms_v2_wrapper_spawn(self):
-with self.create_dataset() as (dataset, _):
-datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+expected_size = (123, 321)
+with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class VOCDetectionTestCase(VOCSegmentationTestCase):
@@ -741,8 +748,9 @@ def test_annotations(self):
assert object == info["annotation"]

def test_transforms_v2_wrapper_spawn(self):
-with self.create_dataset() as (dataset, _):
-datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+expected_size = (123, 321)
+with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -815,8 +823,9 @@ def _create_json(self, root, name, content):
return file

def test_transforms_v2_wrapper_spawn(self):
-with self.create_dataset() as (dataset, _):
-datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+expected_size = (123, 321)
+with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class CocoCaptionsTestCase(CocoDetectionTestCase):
@@ -1005,9 +1014,11 @@ def inject_fake_data(self, tmpdir, config):
)
return num_videos_per_class * len(classes)

@pytest.mark.xfail(reason="FIXME")
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(output_format="TCHW") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(output_format="TCHW", transform=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
Expand Down Expand Up @@ -1237,8 +1248,9 @@ def _file_stem(self, idx):
return f"2008_{idx:06d}"

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(mode="segmentation") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
expected_size = (123, 321)
with self.create_dataset(mode="segmentation", transforms=v2.Resize(size=expected_size)) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class FakeDataTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -1690,8 +1702,9 @@ def inject_fake_data(self, tmpdir, config):
return split_to_num_examples[config["train"]]

def test_transforms_v2_wrapper_spawn(self):
-with self.create_dataset() as (dataset, _):
-datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+expected_size = (123, 321)
+with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -2568,8 +2581,9 @@ def _meta_to_split_and_classification_ann(self, meta, idx):
return (image_id, class_id, species, breed_id)

def test_transforms_v2_wrapper_spawn(self):
-with self.create_dataset() as (dataset, _):
-datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+expected_size = (123, 321)
+with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):