Skip to content

Commit

Permalink
Add ONNX export support for ViTPose
Browse files: browse the repository at this point in the history
  • Loading branch information
xenova committed Nov 18, 2024
1 parent 35eebfe commit c0df045
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 0 deletions.
3 changes: 3 additions & 0 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ class OnnxConfig(ExportConfig, ABC):
"image-to-image": OrderedDict(
{"reconstruction": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}
),
"keypoint-detection": OrderedDict(
{"heatmaps": {0: "batch_size", 1: "num_keypoints", 2: "height", 3: "width"}}
),
"mask-generation": OrderedDict({"logits": {0: "batch_size"}}),
"masked-im": OrderedDict(
{"reconstruction" if check_if_transformers_greater("4.29.0") else "logits": {0: "batch_size"}}
Expand Down
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,10 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
return common_outputs


class VitPoseOnnxConfig(ViTOnnxConfig):
    """ONNX export configuration for ViTPose models.

    Defines nothing of its own: every attribute and method is inherited
    unchanged from ``ViTOnnxConfig``.
    """


class CvTOnnxConfig(ViTOnnxConfig):
DEFAULT_ONNX_OPSET = 13
ATOL_FOR_VALIDATION = 1e-2
Expand Down
2 changes: 2 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ class TasksManager:
"image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"),
"image-to-image": "AutoModelForImageToImage",
"image-to-text": "AutoModelForVision2Seq",
"keypoint-detection": "VitPoseForPoseEstimation", # TODO support AutoModelForXXX
"mask-generation": "AutoModel",
"masked-im": "AutoModelForMaskedImageModeling",
"multiple-choice": "AutoModelForMultipleChoice",
Expand Down Expand Up @@ -1104,6 +1105,7 @@ class TasksManager:
"vit": supported_tasks_mapping(
"feature-extraction", "image-classification", "masked-im", onnx="ViTOnnxConfig"
),
"vitpose": supported_tasks_mapping("feature-extraction", "keypoint-detection", onnx="VitPoseOnnxConfig"),
"vits": supported_tasks_mapping(
"text-to-audio",
onnx="VitsOnnxConfig",
Expand Down

0 comments on commit c0df045

Please sign in to comment.