-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add onnx export script for segment anything v2 (#22119)
### Description Add ONNX export script for segment anything v2 (SAM2). ### Limitations * Does not support video. Only support image right now. * The decoder does not support batch inference. ### Credits The demo that is based on [SAM2 notebook](https://github.com/facebookresearch/segment-anything-2/blob/main/notebooks/image_predictor_example.ipynb), and modified to run with ORT. The export of decoder is inspired by https://github.com/vietanhdev/samexporter. ### Demo Example output of demo: ![sam2_demo](https://github.com/user-attachments/assets/9a9fa360-8c20-482e-9935-a7aba9cf15de) ### Motivation and Context For support optimization of SAM2 image segmentation.
- Loading branch information
Showing
13 changed files
with
1,796 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
65 changes: 65 additions & 0 deletions
65
onnxruntime/python/tools/transformers/models/sam2/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# SAM2 ONNX Model Export | ||
|
||
## Setup Environment | ||
It is recommend to setup a machine with python 3.10, 3.11 or 3.12. Then install [PyTorch 2.4.1](https://pytorch.org/) and [Onnx Runtime 1.19.2]. | ||
|
||
### CPU Only | ||
To install the CPU-only version of PyTorch and Onnx Runtime for exporting and running ONNX models, use the following commands: | ||
``` | ||
python3 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu | ||
python3 -m pip install onnxruntime onnx opencv-python matplotlib | ||
``` | ||
|
||
### GPU | ||
If your machine has an NVIDIA GPU, you can install the CUDA version of PyTorch and Onnx Runtime for exporting and running ONNX models: | ||
|
||
``` | ||
python3 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 | ||
python3 -m pip install onnxruntime-gpu onnx opencv-python matplotlib | ||
``` | ||
|
||
onnxruntime-gpu requires CUDA 12.x, cuDNN 9.x, and other dependencies (such as MSVC Runtime on Windows). For more information, see the [installation guide](https://onnxruntime.ai/docs/install/#python-installs). | ||
|
||
## Download Checkpoints | ||
|
||
Clone the SAM2 git repository and download the checkpoints: | ||
```bash | ||
git clone https://github.com/facebookresearch/segment-anything-2.git | ||
cd segment-anything-2 | ||
python3 -m pip install -e . | ||
cd checkpoints | ||
sh ./download_ckpts.sh | ||
``` | ||
|
||
On Windows, you can replace `sh ./download_ckpts.sh` with the following commands: | ||
```bash | ||
curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt > sam2_hiera_tiny.pt | ||
curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt > sam2_hiera_small.pt | ||
curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt > sam2_hiera_base_plus.pt | ||
curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt > sam2_hiera_large.pt | ||
``` | ||
|
||
## Export ONNX | ||
To export ONNX models, run the convert_to_onnx.py script and specify the segment-anything-2 directory created by the above git clone command: | ||
```bash | ||
python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 | ||
``` | ||
|
||
The exported ONNX models will be found in the sam2_onnx_models sub-directory. You can change the output directory using the `--output_dir` option. | ||
|
||
If you want the model outputs multiple masks, append the `--multimask_output` option. | ||
|
||
To see all parameters, run the following command: | ||
```bash | ||
python3 convert_to_onnx.py -h | ||
``` | ||
|
||
## Run Demo | ||
The exported ONNX models can run on a CPU. The demo will output sam2_demo.png. | ||
```bash | ||
curl https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg > truck.jpg | ||
python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --demo | ||
``` | ||
|
||
## Limitations | ||
- The exported image_decoder model does not support batch mode for now. |
12 changes: 12 additions & 0 deletions
12
onnxruntime/python/tools/transformers/models/sam2/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
# -------------------------------------------------------------------------- | ||
import os.path | ||
import sys | ||
|
||
sys.path.append(os.path.dirname(__file__)) | ||
|
||
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..")) | ||
if transformers_dir not in sys.path: | ||
sys.path.append(transformers_dir) |
195 changes: 195 additions & 0 deletions
195
onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (R) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
# -------------------------------------------------------------------------- | ||
import argparse | ||
import os | ||
import pathlib | ||
import sys | ||
|
||
import torch | ||
from image_decoder import export_decoder_onnx, test_decoder_onnx | ||
from image_encoder import export_image_encoder_onnx, test_image_encoder_onnx | ||
from mask_decoder import export_mask_decoder_onnx, test_mask_decoder_onnx | ||
from prompt_encoder import export_prompt_encoder_onnx, test_prompt_encoder_onnx | ||
from sam2_demo import run_demo, show_all_images | ||
from sam2_utils import build_sam2_model, get_decoder_onnx_path, get_image_encoder_onnx_path, setup_logger | ||
|
||
|
||
def parse_arguments(): | ||
parser = argparse.ArgumentParser(description="Export SAM2 models to ONNX") | ||
|
||
parser.add_argument( | ||
"--model_type", | ||
required=False, | ||
type=str, | ||
choices=["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"], | ||
default="sam2_hiera_large", | ||
help="The model type to export", | ||
) | ||
|
||
parser.add_argument( | ||
"--components", | ||
required=False, | ||
nargs="+", | ||
choices=["image_encoder", "mask_decoder", "prompt_encoder", "image_decoder"], | ||
default=["image_encoder", "image_decoder"], | ||
help="Type of ONNX models to export. " | ||
"Note that image_decoder is a combination of prompt_encoder and mask_decoder", | ||
) | ||
|
||
parser.add_argument( | ||
"--output_dir", | ||
type=str, | ||
help="The output directory for the ONNX models", | ||
default="sam2_onnx_models", | ||
) | ||
|
||
parser.add_argument( | ||
"--dynamic_batch_axes", | ||
required=False, | ||
default=False, | ||
action="store_true", | ||
help="Export image_encoder with dynamic batch axes", | ||
) | ||
|
||
parser.add_argument( | ||
"--multimask_output", | ||
required=False, | ||
default=False, | ||
action="store_true", | ||
help="Export mask_decoder or image_decoder with multimask_output", | ||
) | ||
|
||
parser.add_argument( | ||
"--disable_dynamic_multimask_via_stability", | ||
required=False, | ||
action="store_true", | ||
help="Disable mask_decoder dynamic_multimask_via_stability, and output first mask only." | ||
"This option will be ignored when multimask_output is True", | ||
) | ||
|
||
parser.add_argument( | ||
"--sam2_dir", | ||
required=False, | ||
type=str, | ||
default="./segment-anything-2", | ||
help="The directory of segment-anything-2 git repository", | ||
) | ||
|
||
parser.add_argument( | ||
"--overwrite", | ||
required=False, | ||
default=False, | ||
action="store_true", | ||
help="Overwrite onnx model file if exists.", | ||
) | ||
|
||
parser.add_argument( | ||
"--demo", | ||
required=False, | ||
default=False, | ||
action="store_true", | ||
help="Run demo with the exported ONNX models.", | ||
) | ||
|
||
parser.add_argument( | ||
"--verbose", | ||
required=False, | ||
default=False, | ||
action="store_true", | ||
help="Print verbose information", | ||
) | ||
|
||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def main(): | ||
args = parse_arguments() | ||
|
||
checkpoints_dir = os.path.join(args.sam2_dir, "checkpoints") | ||
sam2_config_dir = os.path.join(args.sam2_dir, "sam2_configs") | ||
if not os.path.exists(args.sam2_dir): | ||
raise FileNotFoundError(f"{args.sam2_dir} does not exist. Please specify --sam2_dir correctly.") | ||
|
||
if not os.path.exists(checkpoints_dir): | ||
raise FileNotFoundError(f"{checkpoints_dir} does not exist. Please specify --sam2_dir correctly.") | ||
|
||
if not os.path.exists(sam2_config_dir): | ||
raise FileNotFoundError(f"{sam2_config_dir} does not exist. Please specify --sam2_dir correctly.") | ||
|
||
if not os.path.exists(os.path.join(checkpoints_dir, f"{args.model_type}.pt")): | ||
raise FileNotFoundError( | ||
f"{checkpoints_dir}/{args.model_type}.pt does not exist. Please download checkpoints under the directory." | ||
) | ||
|
||
if args.sam2_dir not in sys.path: | ||
sys.path.append(args.sam2_dir) | ||
|
||
pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True) | ||
|
||
sam2_model = build_sam2_model(checkpoints_dir, args.model_type, device="cpu") | ||
|
||
for component in args.components: | ||
if component == "image_encoder": | ||
onnx_model_path = get_image_encoder_onnx_path(args.output_dir, args.model_type) | ||
if args.overwrite or not os.path.exists(onnx_model_path): | ||
export_image_encoder_onnx(sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose) | ||
test_image_encoder_onnx(sam2_model, onnx_model_path, dynamic_batch_axes=False) | ||
|
||
elif component == "mask_decoder": | ||
onnx_model_path = os.path.join(args.output_dir, f"{args.model_type}_mask_decoder.onnx") | ||
if args.overwrite or not os.path.exists(onnx_model_path): | ||
export_mask_decoder_onnx( | ||
sam2_model, | ||
onnx_model_path, | ||
args.multimask_output, | ||
not args.disable_dynamic_multimask_via_stability, | ||
args.verbose, | ||
) | ||
test_mask_decoder_onnx( | ||
sam2_model, | ||
onnx_model_path, | ||
args.multimask_output, | ||
not args.disable_dynamic_multimask_via_stability, | ||
) | ||
elif component == "prompt_encoder": | ||
onnx_model_path = os.path.join(args.output_dir, f"{args.model_type}_prompt_encoder.onnx") | ||
if args.overwrite or not os.path.exists(onnx_model_path): | ||
export_prompt_encoder_onnx(sam2_model, onnx_model_path) | ||
test_prompt_encoder_onnx(sam2_model, onnx_model_path) | ||
elif component == "image_decoder": | ||
onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, args.multimask_output) | ||
if args.overwrite or not os.path.exists(onnx_model_path): | ||
export_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output) | ||
test_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output) | ||
|
||
if args.demo: | ||
# Export required ONNX models for demo if not already exported. | ||
onnx_model_path = get_image_encoder_onnx_path(args.output_dir, args.model_type) | ||
if not os.path.exists(onnx_model_path): | ||
export_image_encoder_onnx(sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose) | ||
|
||
onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, True) | ||
if not os.path.exists(onnx_model_path): | ||
export_decoder_onnx(sam2_model, onnx_model_path, True) | ||
|
||
onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, False) | ||
if not os.path.exists(onnx_model_path): | ||
export_decoder_onnx(sam2_model, onnx_model_path, False) | ||
|
||
ort_image_files = run_demo(checkpoints_dir, args.model_type, engine="ort", onnx_directory=args.output_dir) | ||
print("demo output files for ONNX Runtime:", ort_image_files) | ||
|
||
# Get results from torch engine to compare. | ||
torch_image_files = run_demo(checkpoints_dir, args.model_type, engine="torch", onnx_directory=args.output_dir) | ||
print("demo output files for PyTorch:", torch_image_files) | ||
|
||
show_all_images(ort_image_files, torch_image_files) | ||
|
||
|
||
if __name__ == "__main__": | ||
setup_logger(verbose=False) | ||
with torch.no_grad(): | ||
main() |
Oops, something went wrong.