diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 8cd94194ff..5ca7bc0fb4 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -156,6 +156,9 @@ class OnnxConfig(ExportConfig, ABC): "image-to-image": OrderedDict( {"reconstruction": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}} ), + "keypoint-detection": OrderedDict( + {"heatmaps": {0: "batch_size", 1: "num_keypoints", 2: "height", 3: "width"}} + ), "mask-generation": OrderedDict({"logits": {0: "batch_size"}}), "masked-im": OrderedDict( {"reconstruction" if check_if_transformers_greater("4.29.0") else "logits": {0: "batch_size"}} diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index cc752779d3..64b9ae52d2 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -774,6 +774,10 @@ def outputs(self) -> Dict[str, Dict[int, str]]: return common_outputs +class VitPoseOnnxConfig(ViTOnnxConfig): + pass + + class CvTOnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 13 ATOL_FOR_VALIDATION = 1e-2 diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index fdc8bfcb53..ab98dddcb5 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -212,6 +212,7 @@ class TasksManager: "image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"), "image-to-image": "AutoModelForImageToImage", "image-to-text": "AutoModelForVision2Seq", + "keypoint-detection": "VitPoseForPoseEstimation", # TODO support AutoModelForXXX "mask-generation": "AutoModel", "masked-im": "AutoModelForMaskedImageModeling", "multiple-choice": "AutoModelForMultipleChoice", @@ -1104,6 +1105,7 @@ class TasksManager: "vit": supported_tasks_mapping( "feature-extraction", "image-classification", "masked-im", onnx="ViTOnnxConfig" ), + "vitpose": supported_tasks_mapping("feature-extraction", "keypoint-detection", onnx="VitPoseOnnxConfig"), "vits": supported_tasks_mapping( "text-to-audio", onnx="VitsOnnxConfig",