Skip to content

Commit

Permalink
Implement torchvision operator nms | feat(torchlib) (#1253)
Browse files Browse the repository at this point in the history
- Create the scaffold and tests to support torchvision ops.
- Implement `torchvision::nms`
  • Loading branch information
justinchuby authored Jan 16, 2024
1 parent 3b2291c commit 5b9c318
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 4 deletions.
16 changes: 13 additions & 3 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ONNX = "onnx==1.14.1"
ONNX_RUNTIME = "onnxruntime==1.16.1"
PYTORCH = "torch==2.1.0"
TORCHVISON = "torchvision==0.16"
ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = (
"flatbuffers",
"coloredlogs",
Expand All @@ -52,6 +53,7 @@ def test(session):
session.install(
*COMMON_TEST_DEPENDENCIES,
PYTORCH,
TORCHVISON,
ONNX,
ONNX_RUNTIME,
)
Expand All @@ -78,7 +80,7 @@ def test_torch_nightly(session):
@nox.session(tags=["test-onnx-weekly"])
def test_onnx_weekly(session):
"""Test with ONNX weekly (preview) build."""
session.install(*COMMON_TEST_DEPENDENCIES, ONNX_RUNTIME, PYTORCH)
session.install(*COMMON_TEST_DEPENDENCIES, ONNX_RUNTIME, PYTORCH, TORCHVISON)
session.install("-r", "requirements/ci/requirements-onnx-weekly.txt")
session.install(".", "--no-deps")
session.run("pip", "list")
Expand All @@ -89,7 +91,11 @@ def test_onnx_weekly(session):
def test_ort_nightly(session):
"""Test with ONNX Runtime nightly builds."""
session.install(
*COMMON_TEST_DEPENDENCIES, PYTORCH, ONNX, *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES
*COMMON_TEST_DEPENDENCIES,
PYTORCH,
TORCHVISON,
ONNX,
*ONNX_RUNTIME_NIGHTLY_DEPENDENCIES,
)
session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
session.install(".", "--no-deps")
Expand All @@ -101,7 +107,11 @@ def test_ort_nightly(session):
def test_experimental_torchlib_tracing(session):
"""Test TorchLib with the experimental TORCHLIB_EXPERIMENTAL_PREFER_TRACING flag on."""
session.install(
*COMMON_TEST_DEPENDENCIES, PYTORCH, ONNX, *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES
*COMMON_TEST_DEPENDENCIES,
PYTORCH,
TORCHVISON,
ONNX,
*ONNX_RUNTIME_NIGHTLY_DEPENDENCIES,
)
session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
session.install(".", "--no-deps")
Expand Down
3 changes: 2 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"prims",
"sparse",
"special",
"vision",
]

from . import core, fft, linalg, nested, nn, prims, sparse, special
from . import core, fft, linalg, nested, nn, prims, sparse, special, vision
24 changes: 24 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/vision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
"""torchvision operators."""
from __future__ import annotations

from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import FLOAT, INT64

_INT64_MAX = 0x7FFFFFFFFFFFFFFF


@torch_op("torchvision::nms")
def torchvision_nms(boxes: FLOAT, scores: FLOAT, iou_threshold: float) -> INT64:
# boxes: [num_batches, spatial_dimension, 4]
boxes = op.Unsqueeze(boxes, [0])
# scores: [num_batches, num_classes, spatial_dimension]
scores = op.Unsqueeze(scores, [0, 1])
# nms_out: [num_selected_indices, 3] where each column is [batch_index, class_index, box_index]
nms_out = op.NonMaxSuppression(boxes, scores, _INT64_MAX, iou_threshold)
return op.Reshape(op.Slice(nms_out, axes=[1], starts=[2], ends=[3]), [-1])
29 changes: 29 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, List

import torch
import torchvision
from torch import testing as torch_testing
from torch.testing._internal import (
common_device_type,
Expand Down Expand Up @@ -997,6 +998,27 @@ def sample_inputs__native_batch_norm_legit_no_stats(
)


def sample_inputs_non_max_suppression(op_info, device, dtype, requires_grad, **kwargs):
del op_info
del kwargs
boxes = torch.tensor(
[
[0.0, 0.0, 10.0, 10.0],
[10.0, 10.0, 20.0, 20.0],
[32.0, 32.0, 40.0, 52.0],
],
device=device,
dtype=dtype,
requires_grad=requires_grad,
)
scores = torch.tensor(
[0.8, 0.4, 0.6], device=device, dtype=dtype, requires_grad=requires_grad
)

for iou_threshold in (0.3, 0.5, 0.7, 0.9):
yield opinfo_core.SampleInput(boxes, args=(scores, iou_threshold))


def sample_inputs_normal_tensor_float(op_info, device, dtype, requires_grad, **kwargs):
del op_info
del requires_grad
Expand Down Expand Up @@ -1961,4 +1983,11 @@ def __init__(self):
supports_autograd=False,
supports_out=False,
),
opinfo_core.OpInfo(
"torchvision.ops.nms",
op=torchvision.ops.nms,
dtypes=common_dtype.floating_types(),
sample_inputs_func=sample_inputs_non_max_suppression,
supports_out=False,
),
]
2 changes: 2 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from onnxscript.function_libs.torch_lib.ops import linalg as linalg_ops
from onnxscript.function_libs.torch_lib.ops import nn as nn_ops
from onnxscript.function_libs.torch_lib.ops import special as special_ops
from onnxscript.function_libs.torch_lib.ops import vision as vision_ops
from onnxscript.tests.function_libs.torch_lib import extra_opinfo, ops_test_common

# Create a copy of the op_db to modify
Expand Down Expand Up @@ -2300,6 +2301,7 @@ def _where_input_wrangler(
reason="this Aten overload only support when correction attribute exists",
),
TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True),
TorchLibOpInfo("torchvision.ops.nms", vision_ops.torchvision_nms),
)

ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
Expand Down
1 change: 1 addition & 0 deletions requirements/ci/requirements-pytorch-nightly.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
--index-url=https://download.pytorch.org/whl/nightly/cpu
--pre
torch
torchvision

0 comments on commit 5b9c318

Please sign in to comment.