diff --git a/references/classification/README.md b/references/classification/README.md index 66ae871aede..203dae5dbc4 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -289,10 +289,10 @@ For all post training quantized models, the settings are: 2. num_workers: 16 3. batch_size: 32 4. eval_batch_size: 128 -5. backend: 'fbgemm' +5. qbackend: 'fbgemm' ``` -python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='$MODEL' +python train_quantization.py --device='cpu' --post-training-quantize --qbackend='fbgemm' --model='$MODEL' ``` Here `$MODEL` is one of `googlenet`, `inception_v3`, `resnet18`, `resnet50`, `resnext101_32x8d`, `shufflenet_v2_x0_5` and `shufflenet_v2_x1_0`. @@ -301,12 +301,12 @@ Here `$MODEL` is one of `googlenet`, `inception_v3`, `resnet18`, `resnet50`, `re Here are commands that we use to quantize the `shufflenet_v2_x1_5` and `shufflenet_v2_x2_0` models. ``` # For shufflenet_v2_x1_5 -python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' \ +python train_quantization.py --device='cpu' --post-training-quantize --qbackend='fbgemm' \ --model=shufflenet_v2_x1_5 --weights="ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1" \ --train-crop-size 176 --val-resize-size 232 --data-path /datasets01_ontap/imagenet_full_size/061417/ # For shufflenet_v2_x2_0 -python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' \ +python train_quantization.py --device='cpu' --post-training-quantize --qbackend='fbgemm' \ --model=shufflenet_v2_x2_0 --weights="ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1" \ --train-crop-size 176 --val-resize-size 232 --data-path /datasets01_ontap/imagenet_full_size/061417/ ``` @@ -317,7 +317,7 @@ For Mobilenet-v2, the model was trained with quantization aware training, the se 1. num_workers: 16 2. batch_size: 32 3. eval_batch_size: 128 -4. backend: 'qnnpack' +4. qbackend: 'qnnpack' 5. learning-rate: 0.0001 6. num_epochs: 90 7. num_observer_update_epochs:4 @@ -339,7 +339,7 @@ For Mobilenet-v3 Large, the model was trained with quantization aware training, 1. num_workers: 16 2. batch_size: 32 3. eval_batch_size: 128 -4. backend: 'qnnpack' +4. qbackend: 'qnnpack' 5. learning-rate: 0.001 6. num_epochs: 90 7. num_observer_update_epochs:4 @@ -359,7 +359,7 @@ For post training quant, device is set to CPU. For training, the device is set t ### Command to evaluate quantized models using the pre-trained weights: ``` -python train_quantization.py --device='cpu' --test-only --backend='' --model='' +python train_quantization.py --device='cpu' --test-only --qbackend='' --model='' ``` For inception_v3 you need to pass the following extra parameters: diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py index ed36e13a028..ca1937bdbe4 100644 --- a/references/classification/train_quantization.py +++ b/references/classification/train_quantization.py @@ -23,9 +23,9 @@ def main(args): raise RuntimeError("Post training quantization example should not be performed on distributed mode") # Set backend engine to ensure that quantized model runs on the correct kernels - if args.backend not in torch.backends.quantized.supported_engines: - raise RuntimeError("Quantized backend not supported: " + str(args.backend)) - torch.backends.quantized.engine = args.backend + if args.qbackend not in torch.backends.quantized.supported_engines: + raise RuntimeError("Quantized backend not supported: " + str(args.qbackend)) + torch.backends.quantized.engine = args.qbackend device = torch.device(args.device) torch.backends.cudnn.benchmark = True @@ -55,7 +55,7 @@ def main(args): if not (args.test_only or args.post_training_quantize): model.fuse_model(is_qat=True) - model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend) + model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.qbackend) torch.ao.quantization.prepare_qat(model, inplace=True) if args.distributed and args.sync_bn: @@ -89,7 +89,7 @@ def main(args): ) model.eval() model.fuse_model(is_qat=False) - model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend) + model.qconfig = torch.ao.quantization.get_default_qconfig(args.qbackend) torch.ao.quantization.prepare(model, inplace=True) # Calibrate first print("Calibrating") @@ -161,7 +161,7 @@ def get_args_parser(add_help=True): parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path") parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name") - parser.add_argument("--backend", default="qnnpack", type=str, help="fbgemm or qnnpack") + parser.add_argument("--qbackend", default="qnnpack", type=str, help="Quantized backend: fbgemm or qnnpack") parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") parser.add_argument( @@ -257,9 +257,17 @@ def get_args_parser(add_help=True): parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)") parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") + parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive") + parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms") + return parser if __name__ == "__main__": args = get_args_parser().parse_args() + if args.backend in ("fbgemm", "qnnpack"): + raise ValueError( + "The --backend parameter has been re-purposed to specify the backend of the transforms (PIL or Tensor) " + "instead of the quantized backend. Please use the --qbackend parameter to specify the quantized backend." + ) main(args) diff --git a/setup.py b/setup.py index ce67413f410..7818a598244 100644 --- a/setup.py +++ b/setup.py @@ -129,8 +129,10 @@ def get_extensions(): this_dir = os.path.dirname(os.path.abspath(__file__)) extensions_dir = os.path.join(this_dir, "torchvision", "csrc") - main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + glob.glob( - os.path.join(extensions_dir, "ops", "*.cpp") + main_file = ( + glob.glob(os.path.join(extensions_dir, "*.cpp")) + + glob.glob(os.path.join(extensions_dir, "ops", "*.cpp")) + + glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp")) ) source_cpu = ( glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp")) @@ -184,8 +186,6 @@ def get_extensions(): else: source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu")) - source_cuda += glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp")) - sources = main_file + source_cpu extension = CppExtension diff --git a/test/datasets_utils.py b/test/datasets_utils.py index bd9f7ea3a0f..43b4103646a 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -27,7 +27,11 @@ import torchvision.io from common_utils import disable_console_output, get_tmp_dir from torch.utils._pytree import tree_any +from torch.utils.data import DataLoader +from torchvision import tv_tensors +from torchvision.datasets import wrap_dataset_for_transforms_v2 from torchvision.transforms.functional import get_dimensions +from torchvision.transforms.v2.functional import get_size __all__ = [ @@ -568,9 +572,6 @@ def test_transforms(self, config): @test_all_configs def test_transforms_v2_wrapper(self, config): - from torchvision import tv_tensors - from torchvision.datasets import wrap_dataset_for_transforms_v2 - try: with self.create_dataset(config) as (dataset, info): for target_keys in [None, "all"]: @@ -709,26 +710,29 @@ def _no_collate(batch): return batch -def check_transforms_v2_wrapper_spawn(dataset): - # On Linux and Windows, the DataLoader forks the main process by default. This is not available on macOS, so new - # subprocesses are spawned. This requires the whole pipeline including the dataset to be pickleable, which is what - # we are enforcing here. - if platform.system() != "Darwin": - pytest.skip("Multiprocessing spawning is only checked on macOS.") +def check_transforms_v2_wrapper_spawn(dataset, expected_size): + # This check ensures that the wrapped datasets can be used with multiprocessing_context="spawn" in the DataLoader. + # We also check that transforms are applied correctly as a non-regression test for + # https://github.com/pytorch/vision/issues/8066 + # Implicitly, this also checks that the wrapped datasets are pickleable. - from torch.utils.data import DataLoader - from torchvision import tv_tensors - from torchvision.datasets import wrap_dataset_for_transforms_v2 + # To save CI/test time, we only check on Windows where "spawn" is the default + if platform.system() != "Windows": + pytest.skip("Multiprocessing spawning is only checked on macOS.") wrapped_dataset = wrap_dataset_for_transforms_v2(dataset) dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate) - for wrapped_sample in dataloader: - assert tree_any( - lambda item: isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)), wrapped_sample + def resize_was_applied(item): + # Checking the size of the output ensures that the Resize transform was correctly applied + return isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)) and get_size(item) == list( + expected_size ) + for wrapped_sample in dataloader: + assert tree_any(resize_was_applied, wrapped_sample) + def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor: r"""Create a random uint8 tensor. diff --git a/test/test_datasets.py b/test/test_datasets.py index 1270201d53e..832aefe5e09 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -24,6 +24,7 @@ import torch.nn.functional as F from common_utils import combinations_grid from torchvision import datasets +from torchvision.transforms import v2 class STL10TestCase(datasets_utils.ImageDatasetTestCase): @@ -184,8 +185,9 @@ def test_combined_targets(self): f"{actual} is not {expected}", def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset(target_type="category") as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(target_type="category", transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class Caltech256TestCase(datasets_utils.ImageDatasetTestCase): @@ -263,8 +265,9 @@ def inject_fake_data(self, tmpdir, config): return split_to_num_examples[config["split"]] def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset() as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class CityScapesTestCase(datasets_utils.ImageDatasetTestCase): @@ -391,9 +394,10 @@ def test_feature_types_target_polygon(self): (polygon_target, info["expected_polygon_target"]) def test_transforms_v2_wrapper_spawn(self): + expected_size = (123, 321) for target_type in ["instance", "semantic", ["instance", "semantic"]]: - with self.create_dataset(target_type=target_type) as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class ImageNetTestCase(datasets_utils.ImageDatasetTestCase): @@ -427,8 +431,9 @@ def inject_fake_data(self, tmpdir, config): return num_examples def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset() as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase): @@ -625,9 +630,10 @@ def test_images_names_split(self): assert merged_imgs_names == all_imgs_names def test_transforms_v2_wrapper_spawn(self): + expected_size = (123, 321) for target_type in ["identity", "bbox", ["identity", "bbox"]]: - with self.create_dataset(target_type=target_type) as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase): @@ -717,8 +723,9 @@ def add_bndbox(obj, bndbox=None): return data def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset() as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class VOCDetectionTestCase(VOCSegmentationTestCase): @@ -741,8 +748,9 @@ def test_annotations(self): assert object == info["annotation"] def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset() as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase): @@ -815,8 +823,9 @@ def _create_json(self, root, name, content): return file def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset() as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class CocoCaptionsTestCase(CocoDetectionTestCase): @@ -1005,9 +1014,11 @@ def inject_fake_data(self, tmpdir, config): ) return num_videos_per_class * len(classes) + @pytest.mark.xfail(reason="FIXME") def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset(output_format="TCHW") as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(output_format="TCHW", transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class HMDB51TestCase(datasets_utils.VideoDatasetTestCase): @@ -1237,8 +1248,9 @@ def _file_stem(self, idx): return f"2008_{idx:06d}" def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset(mode="segmentation") as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(mode="segmentation", transforms=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class FakeDataTestCase(datasets_utils.ImageDatasetTestCase): @@ -1690,8 +1702,9 @@ def inject_fake_data(self, tmpdir, config): return split_to_num_examples[config["train"]] def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset() as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class SvhnTestCase(datasets_utils.ImageDatasetTestCase): @@ -2568,8 +2581,9 @@ def _meta_to_split_and_classification_ann(self, meta, idx): return (image_id, class_id, species, breed_id) def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset() as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase): diff --git a/test/test_ops.py b/test/test_ops.py index d1cfce5919c..f4d7c2840ba 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -131,6 +131,8 @@ def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, determinist tol = 5e-3 else: tol = 4e-3 + elif x_dtype == torch.bfloat16: + tol = 5e-3 pool_size = 5 # n_channels % (pool_size ** 2) == 0 required for PS operations. @@ -504,6 +506,21 @@ def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype): rois_dtype=rois_dtype, ) + @pytest.mark.parametrize("aligned", (True, False)) + @pytest.mark.parametrize("deterministic", (True, False)) + @pytest.mark.parametrize("x_dtype", (torch.float, torch.bfloat16)) + @pytest.mark.parametrize("rois_dtype", (torch.float, torch.bfloat16)) + def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype): + with torch.cpu.amp.autocast(): + self.test_forward( + torch.device("cpu"), + contiguous=False, + deterministic=deterministic, + aligned=aligned, + x_dtype=x_dtype, + rois_dtype=rois_dtype, + ) + @pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) @@ -808,6 +825,15 @@ def test_autocast(self, iou, dtype): with torch.cuda.amp.autocast(): self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda") + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + @pytest.mark.parametrize("dtype", (torch.float, torch.bfloat16)) + def test_autocast_cpu(self, iou, dtype): + boxes, scores = self._create_tensors_with_iou(1000, iou) + with torch.cpu.amp.autocast(): + keep_ref_float = ops.nms(boxes.to(dtype).float(), scores.to(dtype).float(), iou) + keep_dtype = ops.nms(boxes.to(dtype), scores.to(dtype), iou) + torch.testing.assert_close(keep_ref_float, keep_dtype) + @pytest.mark.parametrize( "device", ( diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py index 7baece2ae2c..58512753ef7 100644 --- a/torchvision/_meta_registrations.py +++ b/torchvision/_meta_registrations.py @@ -33,7 +33,7 @@ def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, samp ), ) num_rois = rois.size(0) - _, channels, height, width = input.size() + channels = input.size(1) return input.new_empty((num_rois, channels, pooled_height, pooled_width)) @@ -51,6 +51,51 @@ def meta_roi_align_backward( return grad.new_empty((batch_size, channels, height, width)) +@register_meta("ps_roi_align") +def meta_ps_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio): + torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]") + torch._check( + input.dtype == rois.dtype, + lambda: ( + "Expected tensor for input to have the same type as tensor for rois; " + f"but type {input.dtype} does not equal {rois.dtype}" + ), + ) + channels = input.size(1) + torch._check( + channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width", + ) + + num_rois = rois.size(0) + out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width) + return input.new_empty(out_size), torch.empty(out_size, dtype=torch.int32, device="meta") + + +@register_meta("_ps_roi_align_backward") +def meta_ps_roi_align_backward( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + batch_size, + channels, + height, + width, +): + torch._check( + grad.dtype == rois.dtype, + lambda: ( + "Expected tensor for grad to have the same type as tensor for rois; " + f"but type {grad.dtype} does not equal {rois.dtype}" + ), + ) + return grad.new_empty((batch_size, channels, height, width)) + + @torch._custom_ops.impl_abstract("torchvision::nms") def meta_nms(dets, scores, iou_threshold): torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D") diff --git a/torchvision/csrc/ops/autocast/nms_kernel.cpp b/torchvision/csrc/ops/autocast/nms_kernel.cpp index 96c9ad041de..2acd0f5d0dc 100644 --- a/torchvision/csrc/ops/autocast/nms_kernel.cpp +++ b/torchvision/csrc/ops/autocast/nms_kernel.cpp @@ -9,21 +9,33 @@ namespace ops { namespace { +template at::Tensor nms_autocast( const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) { - c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); + c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key); + return nms( - at::autocast::cached_cast(at::kFloat, dets), - at::autocast::cached_cast(at::kFloat, scores), + at::autocast::cached_cast(at::kFloat, dets, device_type), + at::autocast::cached_cast(at::kFloat, scores, device_type), iou_threshold); } } // namespace TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { - m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_autocast)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::nms"), + TORCH_FN( + (nms_autocast))); +} + +TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::nms"), + TORCH_FN( + (nms_autocast))); } } // namespace ops diff --git a/torchvision/csrc/ops/autocast/roi_align_kernel.cpp b/torchvision/csrc/ops/autocast/roi_align_kernel.cpp index 78cb2309bbe..919393a5ef0 100644 --- a/torchvision/csrc/ops/autocast/roi_align_kernel.cpp +++ b/torchvision/csrc/ops/autocast/roi_align_kernel.cpp @@ -9,6 +9,7 @@ namespace ops { namespace { +template at::Tensor roi_align_autocast( const at::Tensor& input, const at::Tensor& rois, @@ -17,10 +18,10 @@ at::Tensor roi_align_autocast( int64_t pooled_width, int64_t sampling_ratio, bool aligned) { - c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); + c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key); return roi_align( - at::autocast::cached_cast(at::kFloat, input), - at::autocast::cached_cast(at::kFloat, rois), + at::autocast::cached_cast(at::kFloat, input, device_type), + at::autocast::cached_cast(at::kFloat, rois, device_type), spatial_scale, pooled_height, pooled_width, @@ -34,7 +35,17 @@ at::Tensor roi_align_autocast( TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { m.impl( TORCH_SELECTIVE_NAME("torchvision::roi_align"), - TORCH_FN(roi_align_autocast)); + TORCH_FN((roi_align_autocast< + c10::DispatchKey::Autocast, + c10::DeviceType::CUDA>))); +} + +TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::roi_align"), + TORCH_FN((roi_align_autocast< + c10::DispatchKey::AutocastCPU, + c10::DeviceType::CPU>))); } } // namespace ops diff --git a/torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp b/torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp index 47e51ce9ca2..7205e9b15db 100644 --- a/torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp +++ b/torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp @@ -16,16 +16,16 @@ class PSROIAlignFunction const torch::autograd::Variable& input, const torch::autograd::Variable& rois, double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, + c10::SymInt pooled_height, + c10::SymInt pooled_width, int64_t sampling_ratio) { ctx->saved_data["spatial_scale"] = spatial_scale; ctx->saved_data["pooled_height"] = pooled_height; ctx->saved_data["pooled_width"] = pooled_width; ctx->saved_data["sampling_ratio"] = sampling_ratio; - ctx->saved_data["input_shape"] = input.sizes(); + ctx->saved_data["input_shape"] = input.sym_sizes(); at::AutoDispatchBelowADInplaceOrView g; - auto result = ps_roi_align( + auto result = ps_roi_align_symint( input, rois, spatial_scale, @@ -48,19 +48,19 @@ class PSROIAlignFunction auto saved = ctx->get_saved_variables(); auto rois = saved[0]; auto channel_mapping = saved[1]; - auto input_shape = ctx->saved_data["input_shape"].toIntList(); - auto grad_in = detail::_ps_roi_align_backward( + auto input_shape = ctx->saved_data["input_shape"].toList(); + auto grad_in = detail::_ps_roi_align_backward_symint( grad_output[0], rois, channel_mapping, ctx->saved_data["spatial_scale"].toDouble(), - ctx->saved_data["pooled_height"].toInt(), - ctx->saved_data["pooled_width"].toInt(), + ctx->saved_data["pooled_height"].toSymInt(), + ctx->saved_data["pooled_width"].toSymInt(), ctx->saved_data["sampling_ratio"].toInt(), - input_shape[0], - input_shape[1], - input_shape[2], - input_shape[3]); + input_shape[0].get().toSymInt(), + input_shape[1].get().toSymInt(), + input_shape[2].get().toSymInt(), + input_shape[3].get().toSymInt()); return { grad_in, @@ -82,15 +82,15 @@ class PSROIAlignBackwardFunction const torch::autograd::Variable& rois, const torch::autograd::Variable& channel_mapping, double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, + c10::SymInt pooled_height, + c10::SymInt pooled_width, int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { at::AutoDispatchBelowADInplaceOrView g; - auto grad_in = detail::_ps_roi_align_backward( + auto grad_in = detail::_ps_roi_align_backward_symint( grad, rois, channel_mapping, @@ -117,8 +117,8 @@ std::tuple ps_roi_align_autograd( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, + c10::SymInt pooled_height, + c10::SymInt pooled_width, int64_t sampling_ratio) { auto result = PSROIAlignFunction::apply( input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); @@ -131,13 +131,13 @@ at::Tensor ps_roi_align_backward_autograd( const at::Tensor& rois, const at::Tensor& channel_mapping, double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, + c10::SymInt pooled_height, + c10::SymInt pooled_width, int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { return PSROIAlignBackwardFunction::apply( grad, rois, diff --git a/torchvision/csrc/ops/ps_roi_align.cpp b/torchvision/csrc/ops/ps_roi_align.cpp index 6d091b3c695..de458c0d62d 100644 --- a/torchvision/csrc/ops/ps_roi_align.cpp +++ b/torchvision/csrc/ops/ps_roi_align.cpp @@ -22,6 +22,21 @@ std::tuple ps_roi_align( input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); } +std::tuple ps_roi_align_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_align.ps_roi_align"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::ps_roi_align", "") + .typed(); + return op.call( + input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); +} + namespace detail { at::Tensor _ps_roi_align_backward( @@ -54,13 +69,43 @@ at::Tensor _ps_roi_align_backward( width); } +at::Tensor _ps_roi_align_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "") + .typed(); + return op.call( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + batch_size, + channels, + height, + width); +} + } // namespace detail TORCH_LIBRARY_FRAGMENT(torchvision, m) { m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)")); + "torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio) -> (Tensor, Tensor)")); m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor")); + "torchvision::_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); } } // namespace ops diff --git a/torchvision/csrc/ops/ps_roi_align.h b/torchvision/csrc/ops/ps_roi_align.h index c5ed865982c..75650586bc6 100644 --- a/torchvision/csrc/ops/ps_roi_align.h +++ b/torchvision/csrc/ops/ps_roi_align.h @@ -14,6 +14,14 @@ VISION_API std::tuple ps_roi_align( int64_t pooled_width, int64_t sampling_ratio); +VISION_API std::tuple ps_roi_align_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio); + namespace detail { at::Tensor _ps_roi_align_backward( @@ -29,6 +37,19 @@ at::Tensor _ps_roi_align_backward( int64_t height, int64_t width); +at::Tensor _ps_roi_align_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width); + } // namespace detail } // namespace ops diff --git a/torchvision/tv_tensors/_dataset_wrapper.py b/torchvision/tv_tensors/_dataset_wrapper.py index ef9260ebde9..04c3bf7133d 100644 --- a/torchvision/tv_tensors/_dataset_wrapper.py +++ b/torchvision/tv_tensors/_dataset_wrapper.py @@ -6,6 +6,7 @@ import contextlib from collections import defaultdict +from copy import copy import torch @@ -198,8 +199,19 @@ def __getitem__(self, idx): def __len__(self): return len(self._dataset) + # TODO: maybe we should use __getstate__ and __setstate__ instead of __reduce__, as recommended in the docs. def __reduce__(self): - return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys) + # __reduce__ gets called when we try to pickle the dataset. + # In a DataLoader with spawn context, this gets called `num_workers` times from the main process. + + # We have to reset the [target_]transform[s] attributes of the dataset + # to their original values, because we previously set them to None in __init__(). + dataset = copy(self._dataset) + dataset.transform = self.transform + dataset.transforms = self.transforms + dataset.target_transform = self.target_transform + + return wrap_dataset_for_transforms_v2, (dataset, self._target_keys) def raise_not_supported(description):