Removing graph breaks in transforms #8056

Open
NicolasHug opened this issue Oct 20, 2023 · 10 comments

@NicolasHug
Member

NicolasHug commented Oct 20, 2023

This issue tracks progress on graph breaks removal for the v2 transforms.
Restricting to pure tensor inputs (images) for now; we can figure out the TVTensors and arbitrary structures later.

Kernels

The low-level kernels are almost all fine. Only 4 kernels are problematic.

import torch
from torchvision.transforms import v2
import torchvision.transforms.v2.functional as F

img = torch.rand(3, 256, 256)

# These kernels don't have graph breaks
# -------------------------------------
# torch.compile(F.get_dimensions_image, fullgraph=True)(img)
# torch.compile(F.get_num_channels_image, fullgraph=True)(img)
# torch.compile(F.get_size_image, fullgraph=True)(img)
# torch.compile(F.erase_image, fullgraph=True)(img, 0, 0, 10, 10, v=torch.tensor(0.5))
# torch.compile(F.adjust_brightness_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_contrast_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_gamma_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_hue_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_saturation_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_sharpness_image, fullgraph=True)(img, .5)
# torch.compile(F.autocontrast_image, fullgraph=True)(img)
# torch.compile(F.invert_image, fullgraph=True)(img)
# torch.compile(F.permute_channels_image, fullgraph=True)(img, [2, 1, 0])
# torch.compile(F.posterize_image, fullgraph=True)(img, bits=3)
# torch.compile(F.rgb_to_grayscale_image, fullgraph=True)(img)
# torch.compile(F.solarize_image, fullgraph=True)(img, .4)
# torch.compile(F.affine_image, fullgraph=True)(img, angle=20, translate=[1, 4], scale=1.3, shear=[0, 0])
# torch.compile(F.center_crop_image, fullgraph=True)(img, output_size=(223, 223))
# torch.compile(F.crop_image, fullgraph=True)(img, 0, 10, 10, 10)
# torch.compile(F.elastic_image, fullgraph=True)(img, displacement=torch.randn(1, *img.shape[-2:], 2))
# torch.compile(F.five_crop_image, fullgraph=True)(img, size=(223, 224))
# torch.compile(F.horizontal_flip_image, fullgraph=True)(img)
# torch.compile(F.pad_image, fullgraph=True)(img, [2, 2, 2, 2])
# torch.compile(F.rotate_image, fullgraph=True)(img, angle=30)
# torch.compile(F.ten_crop_image, fullgraph=True)(img, size=(223, 224))
# torch.compile(F.vertical_flip_image, fullgraph=True)(img)
# torch.compile(F.gaussian_blur_image, fullgraph=True)(img, kernel_size=3)
# torch.compile(F.normalize_image, fullgraph=True)(img, mean=0, std=1)
# torch.compile(F.to_dtype_image, fullgraph=True)(img, dtype=torch.uint8, scale=True)


# These ones have breaks

# torch.compile(F.perspective_image, fullgraph=False)(img, None, None, coefficients=torch.rand(8))
# torch.compile(F.resize_image, fullgraph=False)(img, size=(223, 223))
# torch.compile(F.resized_crop_image, fullgraph=False)(img, 0, 12, 10, 34, (223, 223))

# This one doesn't even compile
# torch.compile(F.equalize_image, fullgraph=False)(img) 

Weird thing: resize_image and resized_crop_image both break on

if (interpolation == InterpolationMode.BILINEAR and "AVX2" in torch.backends.cpu.get_cpu_capability()) or (

but when calling them both consecutively, one of them also starts breaking on

if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:

I have no idea why.

Functionals

As @pmeier noted offline the functionals break on

registry = _KERNEL_REGISTRY.get(functional)

which, technically, can probably be avoided since the dict entry should be constant across one execution. We still need to make sure this won't affect custom kernels that users register, and check whether it changes anything if we eventually want to allow users to override our default kernels.
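
To picture the idea, one could resolve the registry entry once per (functional, input type) pair and reuse it, while clearing the cache whenever a kernel is registered so that custom or overriding kernels still take effect. The sketch below is plain Python and only illustrative: `add_kernel`, `resolve_kernel` and the toy `negate` functional are hypothetical names, the registry is a simplified stand-in for `_KERNEL_REGISTRY`, and whether dynamo traces such a cached lookup without a break is not verified here.

```python
from functools import lru_cache
from typing import Callable, Dict, Type

# Simplified stand-in for the registry: {functional: {input_type: kernel}}
_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}


def add_kernel(functional: Callable, input_type: Type, kernel: Callable) -> None:
    """Hypothetical registration helper; clears the cache so new kernels are seen."""
    _KERNEL_REGISTRY.setdefault(functional, {})[input_type] = kernel
    resolve_kernel.cache_clear()


@lru_cache(maxsize=None)
def resolve_kernel(functional: Callable, input_type: Type) -> Callable:
    """Look the kernel up once per (functional, type) pair and remember the result."""
    return _KERNEL_REGISTRY[functional][input_type]


def negate(inpt):
    # Toy functional: dispatch on the input type via the cached lookup.
    return resolve_kernel(negate, type(inpt))(inpt)


add_kernel(negate, int, lambda x: -x)
first = negate(3)

# A user overriding the default kernel later must still take effect.
add_kernel(negate, int, lambda x: -x - 1000)
second = negate(3)
print(first, second)
```

The second call picks up the overriding kernel only because registration invalidates the cache, which is the kind of interaction with user-registered kernels that would need checking.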

TODO: figure out whether the call to log_api_usage_once() introduces a break.

Transforms

The transforms also break where the functionals break.
On top of that, the random transforms seem to break on the call to if rand() < self.p, although I don't see those breaks when using TORCH_LOGS="graph_breaks"; I only see them when using _dynamo.explain(). And _dynamo.explain() in turn doesn't show the graph breaks that happen on the _KERNEL_REGISTRY. 🤷‍♂️

TODO: figure out which one we should trust, and also assess the rest of the transforms more systematically with a script similar to the one above.

CC @pmeier @vfdev-5

@pmeier
Collaborator

pmeier commented Oct 24, 2023

I've run a few quick benchmarks to check whether it is useful to compile kernels in the first place. I've used a simple classification pipeline (random_resized_crop, horizontal_flip, to_dtype, normalize) and pure tensor input:

[------------------  -----------------]
                  |  eager  |  compiled
1 threads: ----------------------------
      kernel      |   279   |    225   
      functional  |   280   |    328   

Times are in microseconds (us).

The slowdown in the functionals stems from the graph break mentioned above in _get_kernel, which is the heart of our dispatch mechanism and thus present in every functional. If we hardcode the kernel, e.g.

    # kernel = _get_kernel(horizontal_flip, type(inpt))
    kernel = horizontal_flip_image

we get the following results

[------------------  -----------------]
                  |  eager  |  compiled
1 threads: ----------------------------
      kernel      |   270   |    228   
      functional  |   270   |    225   

Times are in microseconds (us).

Meaning, if we can somehow resolve the graph break, compiling the functionals will net us the same speedup as compiling the kernels directly. Note that, for now, this only applies to pure tensors and thus image-only pipelines.

@vfdev-5
Collaborator

vfdev-5 commented Nov 2, 2023

I'll be working on this item:

This one doesn't even compile
torch.compile(F.equalize_image, fullgraph=False)(img)

=> PR on pytorch: pytorch/pytorch#112753

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Nov 6, 2023
Description:
- Fixed cat uint8 lowering

Otherwise, it gives the following issue on the repro code:
```python
import torch


def func(x):
    batch_shape = x.shape[:1]
    out = torch.cat([x.new_zeros(1).expand(batch_shape + (1,)), x], dim=-1)
    return out

cfunc = torch.compile(func)

x = torch.randint(0, 256, size=(3, 255), dtype=torch.uint8)
out = cfunc(x)
```
Error message:
```
  File "/pytorch/torch/_inductor/lowering.py", line 1037, in <genexpr>
    if all(len(input.layout.size) == 4 for input in inputs):
  File "/pytorch/torch/_inductor/ir.py", line 5795, in __getattr__
    fn = getattr(self.data, name)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'ExpandView' object has no attribute 'layout'
  target: aten.cat.default
  args[0]: [TensorBox(
    ExpandView(data=StorageBox(
      ComputedBuffer(name='buf0', layout=FlexibleLayout('cpu', torch.uint8, size=[1], stride=[1]), data=Pointwise(
        'cpu',
        torch.uint8,
        def inner_fn(index):
            _ = index
            tmp0 = ops.constant(0, torch.uint8)
            return tmp0
        ,
        ranges=[1],
        origin_node=full,
        origins={full}
      ))
    ), size=[3, 1])
  ), TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cpu', torch.uint8, size=[3, 255], stride=[255, 1]))
  ))]
  args[1]: 1

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```

Context: compiling is not working for torchvision's `F.equalize` op: pytorch/vision#8056

Pull Request resolved: #112753
Approved by: https://github.com/peterbell10
@vfdev-5
Collaborator

vfdev-5 commented Nov 8, 2023

EDIT: Wrong conclusion:

Additional torch compile failures for boxes and seg masks:

...

@pmeier
Collaborator

pmeier commented Nov 8, 2023

torch.compile doesn't yet handle tensor subclasses. From this error message

Argument displacement shape should be (1, 1, 4, 2), but given torch.Size([1, 17, 11, 2])

you can see that likely a tensor image made its way into a bounding box kernel.

What exactly are you testing there? That bounding box / mask inputs work properly on a compiled functional?

@vfdev-5
Collaborator

vfdev-5 commented Nov 8, 2023

Well, I was running tests from #8092 and it is partially my fault, as I was running dispatched functions on tensors instead of subclasses... Now, the problem is a recursive error due to tv_tensors.wrap, which we can temporarily decorate to skip it from compilation.
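
One way such a decorator could look is `torch._dynamo.disable`, which tells dynamo not to trace into the decorated function. This is only a sketch under assumptions: it is not confirmed that this is the decorator meant above, `wrap` here is a hypothetical stand-in rather than `tv_tensors.wrap`, and `backend="eager"` is used purely so the example runs without a compiler toolchain.

```python
import torch
import torch._dynamo


@torch._dynamo.disable
def wrap(tensor):
    # Hypothetical stand-in for a helper that dynamo should not trace into.
    return tensor.clone()


def transform(x):
    # The call to wrap() runs in eager mode; the surrounding ops are still traced.
    return wrap(x * 2) + 1


compiled_transform = torch.compile(transform, backend="eager")
out = compiled_transform(torch.ones(3))
print(out)
```

Skipping a function this way trades the error for a graph break around the call, so it fits as a temporary workaround rather than a fix.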

@pmeier
Collaborator

pmeier commented Nov 9, 2023

There are two sources of graph breaks in the way we currently dispatch:

  1. We use the dispatcher and the input type directly as dictionary keys:

    # {functional: {input_type: type_specific_kernel}}
    _KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}

    This is currently not supported by dynamo. However, there is pytorch#111196 ("Support tensors as Dict keys") that opens up dictionary keys to other types than primitives as well. If that is merged, we should be able to send a small fix to allow our use case as well.

  2. Inlining functions that use types, which is what happens when dynamo hits _get_kernel the first time, is not properly supported. I have pytorch#113340 ("use sourceless builder for builtin getattr") to address this.

Apart from that, nothing needs to change on our side. Dynamo is fine with all the other things we worried about, i.e. global dicts, MRO traversal, ... 🎉
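
For readers unfamiliar with the dispatch, the two points above can be pictured with a simplified, pure-Python sketch of a type-keyed registry with MRO traversal. This is not the actual torchvision implementation: the body of `_get_kernel` is an assumption about its general shape, and `Image`, `Thumbnail`, `image_kernel` and the toy `horizontal_flip` are hypothetical names used only for illustration.

```python
from typing import Callable, Dict, Type

# {functional: {input_type: type_specific_kernel}}
_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}


def _get_kernel(functional: Callable, input_type: Type) -> Callable:
    """Simplified lookup: walk the MRO until a registered type is found."""
    registry = _KERNEL_REGISTRY.get(functional)
    if not registry:
        raise ValueError(f"No kernel registered for functional {functional.__name__}.")
    for cls in input_type.__mro__:
        if cls in registry:
            return registry[cls]
    raise TypeError(f"{functional.__name__} supports no input of type {input_type.__name__}.")


class Image:  # hypothetical toy input types, not torchvision's TVTensors
    pass


class Thumbnail(Image):
    pass


def image_kernel(inpt):
    return "image kernel"


def horizontal_flip(inpt):
    # Toy functional: pick the kernel from the input type, then call it.
    kernel = _get_kernel(horizontal_flip, type(inpt))
    return kernel(inpt)


# Both the functional and the input type are used directly as dictionary keys.
_KERNEL_REGISTRY[horizontal_flip] = {Image: image_kernel}

exact = horizontal_flip(Image())
inherited = horizontal_flip(Thumbnail())  # falls back to the parent's kernel via the MRO
print(exact, inherited)
```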

I've rerun my benchmark with fixes for the points above and this is what I got:

[------------------  -----------------]
                  |  eager  |  compiled
1 threads: ----------------------------
      kernel      |   265   |    230   
      functional  |   270   |    240   

Times are in microseconds (us).

I've re-run it a couple of times and the 10µs gap between compiled kernels and functionals is reproducible. Meaning the compiled functionals don't fully get to the same level as the kernels, but they still outperform their eager counterpart.

@pmeier
Collaborator

pmeier commented Nov 9, 2023

One thing that I noticed while playing around with the benchmarks is that dynamo does not give us a strict improvement for individual ops.

random_resized_crop

[------------------  -----------------]
                  |  eager  |  compiled
1 threads: ----------------------------
      kernel      |   178   |    206   
      functional  |   178   |    207   

Times are in microseconds (us).

horizontal_flip

[------------------  -----------------]
                  |  eager  |  compiled
1 threads: ----------------------------
      kernel      |    22   |    36.4  
      functional  |    24   |    41.7  

Times are in microseconds (us).

to_dtype

[------------------  -----------------]
                  |  eager  |  compiled
1 threads: ----------------------------
      kernel      |   65.2  |    54.6  
      functional  |   67.0  |    59.3  

to_dtype and normalize

[------------------  -----------------]
                  |  eager  |  compiled
1 threads: ----------------------------
      kernel      |   170   |    61.4  
      functional  |   180   |    67.5  

  • resizing (random_resized_crop) and horizontal_flip are slower in the compiled version than in eager
  • to_dtype is marginally faster
  • normalize (prefixed with to_dtype, since normalize requires floating point input) is massively faster. IIUC, the high values in eager come from the fact that we are inputting an image with CHW memory layout and that hurts normalize. In the full pipeline this is mitigated by having the resize before it, which produces an artificial HWC layout. The compiled version seems to have this natively.
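
To put numbers on the summary above, the kernel rows of the per-op tables can be turned into eager/compiled ratios. The timings are copied from the tables in this comment; the dictionary layout and the "faster"/"slower" labelling are just an illustrative choice.

```python
# (eager_us, compiled_us) taken from the kernel rows of the per-op tables
kernel_timings_us = {
    "random_resized_crop": (178, 206),
    "horizontal_flip": (22, 36.4),
    "to_dtype": (65.2, 54.6),
    "to_dtype + normalize": (170, 61.4),
}

# A ratio above 1 means the compiled version is faster than eager.
speedups = {op: eager / compiled for op, (eager, compiled) in kernel_timings_us.items()}

for op, ratio in speedups.items():
    verdict = "faster" if ratio > 1 else "slower"
    print(f"{op:<22} {ratio:.2f}x ({verdict} when compiled)")
```

This gives roughly 0.86x for random_resized_crop, 0.60x for horizontal_flip, 1.19x for to_dtype and 2.77x for to_dtype followed by normalize.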

@lezcano
Contributor

lezcano commented Nov 9, 2023

Note that what's going to be great for torchvision is that I expect pretty much any combination of transformations to be fused into one kernel. That is where the main speed-ups will be coming from.

To this end, it'd be useful to benchmark a list of transformations applied one after the other. As I told Victor, I expect these wins to heavily outweigh the slight regression in resize and flips.

On a different note, I'd expect the flip issue to be fixable.

@NicolasHug
Member Author

Thanks a lot for this great investigation Philip.

@lezcano I tend to have a different intuition from yours: if resize is much faster than compiled(resize), then perhaps the speed-up gained by not compiling resize will outweigh the speed-up coming from fusing resize with the op coming just before and the one coming just after (keeping the rest of the transforms compiled / fused as well). But we'll see with benchmarks. Regardless, we probably don't need to worry too much about benchmarks for now; the main goal of this issue is to remove graph breaks as a first step.

@vfdev-5
Copy link
Collaborator

vfdev-5 commented Nov 13, 2023

A few other findings on failing tests when kernels are compiled with variable input shapes: https://gist.github.com/vfdev-5/5b2733b5641d08c6889a17eda6267aba (the logs contain 32k lines in total, so the browser may hang for a few seconds while loading...)
