From cfb1a06fc3dae1ee710932a49dd51cec394d0a6d Mon Sep 17 00:00:00 2001 From: jamjamjon Date: Sun, 29 Dec 2024 17:30:18 +0800 Subject: [PATCH] ###### --- .github/workflows/rust-ci.yml | 218 +--- .gitignore | 3 +- Cargo.toml | 83 +- README.md | 10 +- benches/yolo.rs | 94 -- examples/blip/README.md | 16 +- examples/blip/main.rs | 67 +- examples/clip/README.md | 12 +- .../clip/images/{peoples.jpg => drink.jpg} | Bin examples/clip/main.rs | 55 +- examples/dataloader.rs | 42 + examples/dataloader/main.rs | 66 - examples/db/README.md | 9 - examples/db/main.rs | 27 +- examples/depth-anything/main.rs | 19 +- examples/depth-pro/README.md | 10 + examples/depth-pro/main.rs | 43 +- examples/dfine/README.md | 6 + examples/dfine/main.rs | 27 + examples/dinov2/main.rs | 48 +- examples/doclayout-yolo/README.md | 10 + examples/doclayout-yolo/main.rs | 41 + examples/fastsam.rs | 46 + examples/florence2/README.md | 31 + examples/florence2/main.rs | 328 ++--- examples/grounding-dino/README.md | 2 +- examples/grounding-dino/main.rs | 94 +- examples/hub.rs | 25 + examples/image-classification/README.md | 13 + examples/image-classification/main.rs | 59 + examples/modnet/main.rs | 15 +- examples/picodet-layout/README.md | 10 + examples/picodet-layout/main.rs | 30 + examples/rtdetr/README.md | 7 + examples/rtdetr/main.rs | 33 + examples/rtmo/main.rs | 23 +- examples/sam/README.md | 8 +- examples/sam/main.rs | 131 +- examples/sapiens/README.md | 2 +- examples/sapiens/main.rs | 30 +- examples/slanet/README.md | 10 + examples/slanet/main.rs | 46 + examples/svtr/README.md | 32 +- .../svtr/images/{3.png => license-ch-2.png} | Bin .../svtr/images/{2.png => license-ch.png} | Bin examples/svtr/images/{4.png => sign-ch-2.png} | Bin examples/svtr/images/{6.png => sign-ch.png} | Bin .../svtr/images/{8.png => text-110022345.png} | Bin examples/svtr/images/{1.png => text-ch.png} | Bin examples/svtr/images/{9.png => text-en-2.png} | Bin .../svtr/images/{7.png => text-en-dark.png} | Bin 
examples/svtr/images/{5.png => text-en.png} | Bin .../images/text-hello-rust-handwritten.png | Bin 0 -> 3777 bytes examples/svtr/main.rs | 41 +- examples/trocr/README.md | 13 + examples/trocr/main.rs | 91 ++ examples/viewer.rs | 38 + examples/yolo-sam/README.md | 2 +- examples/yolo-sam/main.rs | 54 +- examples/yolo/README.md | 76 +- examples/yolo/main.rs | 358 ++--- examples/yolop/main.rs | 13 +- examples/yolov8-rtdetr.rs | 46 + rust-toolchain.toml | 2 - scripts/CelebAMask-HQ-To-YOLO-Labels.py | 63 - scripts/convert2f16.py | 8 - src/core/device.rs | 14 - src/core/metric.rs | 6 - src/core/options.rs | 163 --- src/core/ort_engine.rs | 666 ---------- src/core/tokenizer_stream.rs | 87 -- src/core/ts.rs | 49 - src/core/vision.rs | 51 - src/lib.rs | 253 +--- src/{core => misc}/annotator.rs | 25 +- src/{utils => misc}/color.rs | 0 src/{utils => misc}/colormap256.rs | 0 src/{core => misc}/dataloader.rs | 96 +- src/misc/device.rs | 63 + src/{core => misc}/dir.rs | 6 +- src/misc/dtype.rs | 114 ++ src/{core => misc}/dynconf.rs | 0 src/misc/engine.rs | 743 +++++++++++ src/{core => misc}/hub.rs | 339 +++-- src/misc/iiix.rs | 15 + src/{core => misc}/logits_sampler.rs | 1 - src/{core => misc}/media.rs | 18 +- src/{core => misc}/min_opt_max.rs | 3 + src/{core => misc}/mod.rs | 37 +- src/{core => misc}/onnx.rs | 2 + src/{core => misc}/ops.rs | 204 ++- src/misc/ts.rs | 392 ++++++ src/{utils/mod.rs => misc/utils.rs} | 42 +- src/{core => misc}/viewer.rs | 14 +- src/models/basemodel.rs | 148 +++ src/models/beit/config.rs | 26 + src/models/beit/mod.rs | 1 + src/models/blip.rs | 155 --- src/models/blip/config.rs | 34 + src/models/blip/impl.rs | 130 ++ src/models/blip/mod.rs | 4 + src/models/clip.rs | 107 -- src/models/clip/config.rs | 71 + src/models/clip/impl.rs | 149 +++ src/models/clip/mod.rs | 4 + src/models/convnext/config.rs | 66 + src/models/convnext/mod.rs | 1 + src/models/db/config.rs | 29 + src/models/{db.rs => db/impl.rs} | 107 +- src/models/db/mod.rs | 4 + 
src/models/deit/config.rs | 30 + src/models/deit/mod.rs | 1 + src/models/depth_anything.rs | 90 -- src/models/depth_anything/config.rs | 40 + src/models/depth_anything/impl.rs | 98 ++ src/models/depth_anything/mod.rs | 4 + src/models/depth_pro.rs | 86 -- src/models/depth_pro/config.rs | 27 + src/models/depth_pro/impl.rs | 99 ++ src/models/depth_pro/mod.rs | 4 + src/models/dinov2.rs | 161 --- src/models/dinov2/config.rs | 28 + src/models/dinov2/impl.rs | 68 + src/models/dinov2/mod.rs | 4 + src/models/fastvit/config.rs | 74 ++ src/models/fastvit/mod.rs | 1 + src/models/florence2.rs | 459 ------- src/models/florence2/config.rs | 59 + src/models/florence2/impl.rs | 417 ++++++ src/models/florence2/mod.rs | 6 + src/{utils => models/florence2}/quantizer.rs | 16 +- src/models/grounding_dino.rs | 245 ---- src/models/grounding_dino/config.rs | 22 + src/models/grounding_dino/impl.rs | 223 ++++ src/models/grounding_dino/mod.rs | 4 + src/models/image_classifier.rs | 125 ++ src/models/kind.rs | 18 + src/models/labels.rs | 1155 +++++++++++++++++ src/models/mobileone/config.rs | 50 + src/models/mobileone/mod.rs | 1 + src/models/mod.rs | 65 +- src/models/modnet.rs | 84 -- src/models/modnet/config.rs | 17 + src/models/modnet/impl.rs | 90 ++ src/models/modnet/mod.rs | 4 + src/models/options.rs | 432 ++++++ src/models/picodet/config.rs | 61 + src/models/picodet/impl.rs | 111 ++ src/models/picodet/mod.rs | 4 + src/models/processor.rs | 320 +++++ src/models/rtdetr/README.md | 3 + src/models/rtdetr/config.rs | 80 ++ src/models/rtdetr/impl.rs | 128 ++ src/models/rtdetr/mod.rs | 4 + src/models/rtmo/config.rs | 28 + src/models/{rtmo.rs => rtmo/impl.rs} | 115 +- src/models/rtmo/mod.rs | 4 + src/models/sam/config.rs | 100 ++ src/models/{sam.rs => sam/impl.rs} | 151 ++- src/models/sam/mod.rs | 4 + src/models/sapiens/config.rs | 47 + src/models/{sapiens.rs => sapiens/impl.rs} | 111 +- src/models/sapiens/mod.rs | 4 + src/models/scale.rs | 83 ++ src/models/slanet/config.rs | 22 + 
src/models/slanet/impl.rs | 109 ++ src/models/slanet/mod.rs | 4 + src/models/svtr.rs | 101 -- src/models/svtr/config.rs | 43 + src/models/svtr/impl.rs | 109 ++ src/models/svtr/mod.rs | 4 + src/{core => models}/task.rs | 72 +- src/models/trocr/config.rs | 92 ++ src/models/trocr/impl.rs | 292 +++++ src/models/trocr/mod.rs | 4 + src/models/version.rs | 43 + src/models/yolo/config.rs | 199 +++ src/models/{yolo.rs => yolo/impl.rs} | 486 +++---- src/models/yolo/mod.rs | 6 + src/models/{yolo_.rs => yolo/preds.rs} | 127 +- src/models/yolop/config.rs | 22 + src/models/{yolop.rs => yolop/impl.rs} | 122 +- src/models/yolop/mod.rs | 4 + src/utils/names.rs | 154 --- src/{ys => xy}/bbox.rs | 4 +- src/{ys => xy}/keypoint.rs | 12 + src/{ys => xy}/mask.rs | 0 src/{ys => xy}/mbr.rs | 0 src/{ys => xy}/mod.rs | 11 +- src/{ys => xy}/polygon.rs | 0 src/{ys => xy}/prob.rs | 0 src/xy/text.rs | 17 + src/{core => xy}/x.rs | 116 +- src/{core => xy}/xs.rs | 13 + src/{ys => xy}/y.rs | 58 +- src/xy/ys.rs | 19 + src/ys/embedding.rs | 49 - 197 files changed, 9879 insertions(+), 5225 deletions(-) delete mode 100644 benches/yolo.rs rename examples/clip/images/{peoples.jpg => drink.jpg} (100%) create mode 100644 examples/dataloader.rs delete mode 100644 examples/dataloader/main.rs create mode 100644 examples/depth-pro/README.md create mode 100644 examples/dfine/README.md create mode 100644 examples/dfine/main.rs create mode 100644 examples/doclayout-yolo/README.md create mode 100644 examples/doclayout-yolo/main.rs create mode 100644 examples/fastsam.rs create mode 100644 examples/florence2/README.md create mode 100644 examples/hub.rs create mode 100644 examples/image-classification/README.md create mode 100644 examples/image-classification/main.rs create mode 100644 examples/picodet-layout/README.md create mode 100644 examples/picodet-layout/main.rs create mode 100644 examples/rtdetr/README.md create mode 100644 examples/rtdetr/main.rs create mode 100644 examples/slanet/README.md create mode 100644 
examples/slanet/main.rs rename examples/svtr/images/{3.png => license-ch-2.png} (100%) rename examples/svtr/images/{2.png => license-ch.png} (100%) rename examples/svtr/images/{4.png => sign-ch-2.png} (100%) rename examples/svtr/images/{6.png => sign-ch.png} (100%) rename examples/svtr/images/{8.png => text-110022345.png} (100%) rename examples/svtr/images/{1.png => text-ch.png} (100%) rename examples/svtr/images/{9.png => text-en-2.png} (100%) rename examples/svtr/images/{7.png => text-en-dark.png} (100%) rename examples/svtr/images/{5.png => text-en.png} (100%) create mode 100644 examples/svtr/images/text-hello-rust-handwritten.png create mode 100644 examples/trocr/README.md create mode 100644 examples/trocr/main.rs create mode 100644 examples/viewer.rs create mode 100644 examples/yolov8-rtdetr.rs delete mode 100644 rust-toolchain.toml delete mode 100644 scripts/CelebAMask-HQ-To-YOLO-Labels.py delete mode 100644 scripts/convert2f16.py delete mode 100644 src/core/device.rs delete mode 100644 src/core/metric.rs delete mode 100644 src/core/options.rs delete mode 100644 src/core/ort_engine.rs delete mode 100644 src/core/tokenizer_stream.rs delete mode 100644 src/core/ts.rs delete mode 100644 src/core/vision.rs rename src/{core => misc}/annotator.rs (97%) rename src/{utils => misc}/color.rs (100%) rename src/{utils => misc}/colormap256.rs (100%) rename src/{core => misc}/dataloader.rs (83%) create mode 100644 src/misc/device.rs rename src/{core => misc}/dir.rs (97%) create mode 100644 src/misc/dtype.rs rename src/{core => misc}/dynconf.rs (100%) create mode 100644 src/misc/engine.rs rename src/{core => misc}/hub.rs (60%) create mode 100644 src/misc/iiix.rs rename src/{core => misc}/logits_sampler.rs (99%) rename src/{core => misc}/media.rs (100%) rename src/{core => misc}/min_opt_max.rs (99%) rename src/{core => misc}/mod.rs (57%) rename src/{core => misc}/onnx.rs (99%) rename src/{core => misc}/ops.rs (65%) create mode 100644 src/misc/ts.rs rename src/{utils/mod.rs 
=> misc/utils.rs} (79%) rename src/{core => misc}/viewer.rs (92%) create mode 100644 src/models/basemodel.rs create mode 100644 src/models/beit/config.rs create mode 100644 src/models/beit/mod.rs delete mode 100644 src/models/blip.rs create mode 100644 src/models/blip/config.rs create mode 100644 src/models/blip/impl.rs create mode 100644 src/models/blip/mod.rs delete mode 100644 src/models/clip.rs create mode 100644 src/models/clip/config.rs create mode 100644 src/models/clip/impl.rs create mode 100644 src/models/clip/mod.rs create mode 100644 src/models/convnext/config.rs create mode 100644 src/models/convnext/mod.rs create mode 100644 src/models/db/config.rs rename src/models/{db.rs => db/impl.rs} (64%) create mode 100644 src/models/db/mod.rs create mode 100644 src/models/deit/config.rs create mode 100644 src/models/deit/mod.rs delete mode 100644 src/models/depth_anything.rs create mode 100644 src/models/depth_anything/config.rs create mode 100644 src/models/depth_anything/impl.rs create mode 100644 src/models/depth_anything/mod.rs delete mode 100644 src/models/depth_pro.rs create mode 100644 src/models/depth_pro/config.rs create mode 100644 src/models/depth_pro/impl.rs create mode 100644 src/models/depth_pro/mod.rs delete mode 100644 src/models/dinov2.rs create mode 100644 src/models/dinov2/config.rs create mode 100644 src/models/dinov2/impl.rs create mode 100644 src/models/dinov2/mod.rs create mode 100644 src/models/fastvit/config.rs create mode 100644 src/models/fastvit/mod.rs delete mode 100644 src/models/florence2.rs create mode 100644 src/models/florence2/config.rs create mode 100644 src/models/florence2/impl.rs create mode 100644 src/models/florence2/mod.rs rename src/{utils => models/florence2}/quantizer.rs (76%) delete mode 100644 src/models/grounding_dino.rs create mode 100644 src/models/grounding_dino/config.rs create mode 100644 src/models/grounding_dino/impl.rs create mode 100644 src/models/grounding_dino/mod.rs create mode 100644 
src/models/image_classifier.rs create mode 100644 src/models/kind.rs create mode 100644 src/models/labels.rs create mode 100644 src/models/mobileone/config.rs create mode 100644 src/models/mobileone/mod.rs delete mode 100644 src/models/modnet.rs create mode 100644 src/models/modnet/config.rs create mode 100644 src/models/modnet/impl.rs create mode 100644 src/models/modnet/mod.rs create mode 100644 src/models/options.rs create mode 100644 src/models/picodet/config.rs create mode 100644 src/models/picodet/impl.rs create mode 100644 src/models/picodet/mod.rs create mode 100644 src/models/processor.rs create mode 100644 src/models/rtdetr/README.md create mode 100644 src/models/rtdetr/config.rs create mode 100644 src/models/rtdetr/impl.rs create mode 100644 src/models/rtdetr/mod.rs create mode 100644 src/models/rtmo/config.rs rename src/models/{rtmo.rs => rtmo/impl.rs} (55%) create mode 100644 src/models/rtmo/mod.rs create mode 100644 src/models/sam/config.rs rename src/models/{sam.rs => sam/impl.rs} (73%) create mode 100644 src/models/sam/mod.rs create mode 100644 src/models/sapiens/config.rs rename src/models/{sapiens.rs => sapiens/impl.rs} (65%) create mode 100644 src/models/sapiens/mod.rs create mode 100644 src/models/scale.rs create mode 100644 src/models/slanet/config.rs create mode 100644 src/models/slanet/impl.rs create mode 100644 src/models/slanet/mod.rs delete mode 100644 src/models/svtr.rs create mode 100644 src/models/svtr/config.rs create mode 100644 src/models/svtr/impl.rs create mode 100644 src/models/svtr/mod.rs rename src/{core => models}/task.rs (75%) create mode 100644 src/models/trocr/config.rs create mode 100644 src/models/trocr/impl.rs create mode 100644 src/models/trocr/mod.rs create mode 100644 src/models/version.rs create mode 100644 src/models/yolo/config.rs rename src/models/{yolo.rs => yolo/impl.rs} (58%) create mode 100644 src/models/yolo/mod.rs rename src/models/{yolo_.rs => yolo/preds.rs} (73%) create mode 100644 
src/models/yolop/config.rs rename src/models/{yolop.rs => yolop/impl.rs} (69%) create mode 100644 src/models/yolop/mod.rs delete mode 100644 src/utils/names.rs rename src/{ys => xy}/bbox.rs (100%) rename src/{ys => xy}/keypoint.rs (96%) rename src/{ys => xy}/mask.rs (100%) rename src/{ys => xy}/mbr.rs (100%) rename src/{ys => xy}/mod.rs (72%) rename src/{ys => xy}/polygon.rs (100%) rename src/{ys => xy}/prob.rs (100%) create mode 100644 src/xy/text.rs rename src/{core => xy}/x.rs (54%) rename src/{core => xy}/xs.rs (88%) rename src/{ys => xy}/y.rs (71%) create mode 100644 src/xy/ys.rs delete mode 100644 src/ys/embedding.rs diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index 28bce6f..c35fadf 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -11,193 +11,73 @@ env: jobs: - build-on-linux: - name: build / linux / ffmpeg ${{ matrix.ffmpeg_version }} - runs-on: ubuntu-latest - container: jrottenberg/ffmpeg:${{ matrix.ffmpeg_version }}-ubuntu - + check: + name: Check + runs-on: ${{ matrix.os }} strategy: matrix: - ffmpeg_version: ["4.3", "4.4", "5.0", "5.1", "6.0", "6.1", "7.0"] - fail-fast: false - + os: [ubuntu-latest, macOS-latest, windows-latest] + rust: [stable] steps: - - name: Checkout - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - apt update - apt install -y --no-install-recommends clang curl pkg-config - - - name: Setup Rust - uses: dtolnay/rust-toolchain@v1 + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 with: - toolchain: stable - - - name: Build - run: cargo build - - build-on-macos: - name: build / macos / ffmpeg latest - runs-on: macos-latest - - steps: - - name: Checkout - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - brew install ffmpeg pkg-config - - - name: Setup Rust - uses: dtolnay/rust-toolchain@v1 + profile: minimal + toolchain: ${{ matrix.rust }} + override: true + - uses: actions-rs/cargo@v1 with: - toolchain: stable - - - 
name: Build - run: cargo build - - - build-on-windows: - name: build / windows / ffmpeg latest - runs-on: windows-latest - - env: - FFMPEG_DOWNLOAD_URL: https://www.gyan.dev/ffmpeg/builds/ffmpeg-release-full-shared.7z - - steps: - - name: Checkout - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - $VCINSTALLDIR = $(& "${env:ProgramFiles(x86)}\Microsoft Visual Studio\Installer\vswhere.exe" -latest -property installationPath) - Add-Content $env:GITHUB_ENV "LIBCLANG_PATH=${VCINSTALLDIR}\VC\Tools\LLVM\x64\bin`n" - Invoke-WebRequest "${env:FFMPEG_DOWNLOAD_URL}" -OutFile ffmpeg-release-full-shared.7z - 7z x ffmpeg-release-full-shared.7z - mkdir ffmpeg - mv ffmpeg-*/* ffmpeg/ - Add-Content $env:GITHUB_ENV "FFMPEG_DIR=${pwd}\ffmpeg`n" - Add-Content $env:GITHUB_PATH "${pwd}\ffmpeg\bin`n" - - - name: Setup Rust - uses: dtolnay/rust-toolchain@v1 - with: - toolchain: stable - - - name: Build - run: cargo build - - - test-on-linux: - name: test / linux / ffmpeg ${{ matrix.ffmpeg_version }} - runs-on: ubuntu-latest - container: jrottenberg/ffmpeg:${{ matrix.ffmpeg_version }}-ubuntu + command: check + args: --all + test: + name: Test + runs-on: ${{ matrix.os }} strategy: matrix: - ffmpeg_version: ["4.3", "4.4", "5.0", "5.1", "6.0", "6.1", "7.0"] - fail-fast: false - + os: [ubuntu-latest, macOS-latest, windows-latest] + rust: [stable] steps: - - name: Checkout - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - apt update - apt install -y --no-install-recommends clang curl pkg-config - - - name: Setup Rust - uses: dtolnay/rust-toolchain@v1 + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 with: - toolchain: stable - - - name: Run Tests with All Features - run: cargo test --all-features - - - name: Run Tests in Release Mode - run: cargo test --release - - test-on-macos: - name: test / macos / ffmpeg latest - runs-on: macos-latest - - steps: - - name: Checkout - uses: actions/checkout@v3 - - - name: Install dependencies - run: | 
- brew install ffmpeg pkg-config - - - name: Setup Rust - uses: dtolnay/rust-toolchain@v1 + profile: minimal + toolchain: ${{ matrix.rust }} + override: true + - uses: actions-rs/cargo@v1 with: - toolchain: stable - - - name: Run Tests with All Features - run: cargo test --all-features - - - name: Run Tests in Release Mode - run: cargo test --release - - test-on-windows: - name: test / windows / ffmpeg latest - runs-on: windows-latest - - env: - FFMPEG_DOWNLOAD_URL: https://www.gyan.dev/ffmpeg/builds/ffmpeg-release-full-shared.7z + command: test + args: --all + fmt: + name: Rustfmt + runs-on: ubuntu-latest steps: - - name: Checkout - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - $VCINSTALLDIR = $(& "${env:ProgramFiles(x86)}\Microsoft Visual Studio\Installer\vswhere.exe" -latest -property installationPath) - Add-Content $env:GITHUB_ENV "LIBCLANG_PATH=${VCINSTALLDIR}\VC\Tools\LLVM\x64\bin`n" - Invoke-WebRequest "${env:FFMPEG_DOWNLOAD_URL}" -OutFile ffmpeg-release-full-shared.7z - 7z x ffmpeg-release-full-shared.7z - mkdir ffmpeg - mv ffmpeg-*/* ffmpeg/ - Add-Content $env:GITHUB_ENV "FFMPEG_DIR=${pwd}\ffmpeg`n" - Add-Content $env:GITHUB_PATH "${pwd}\ffmpeg\bin`n" - - - name: Setup Rust - uses: dtolnay/rust-toolchain@v1 + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 with: + profile: minimal toolchain: stable - - - name: Run Tests with All Features - run: cargo test --all-features - - - name: Run Tests in Release Mode - run: cargo test --release - + override: true + - run: rustup component add rustfmt + - uses: actions-rs/cargo@v1 + with: + command: fmt + args: --all -- --check - lints: + clippy: + name: Clippy runs-on: ubuntu-latest - container: jrottenberg/ffmpeg:6-ubuntu - steps: - - name: Checkout - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - apt update - apt install -y --no-install-recommends clang curl pkg-config - - - name: Setup Rust - uses: dtolnay/rust-toolchain@v1 + - uses: actions/checkout@v2 
+ - uses: actions-rs/toolchain@v1 with: + profile: minimal toolchain: stable - components: rustfmt, clippy - - - name: Rustfmt - run: cargo fmt --all -- --check + override: true + - run: rustup component add clippy + - uses: actions-rs/cargo@v1 + with: + command: clippy + args: --all --all-targets -- -D warnings - - name: Clippy - run: cargo clippy --all --all-targets --all-features -- -D warnings diff --git a/.gitignore b/.gitignore index b99985e..e1a526e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,8 @@ debug/ target/ +**/*.DS_Store + # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html Cargo.lock @@ -13,7 +15,6 @@ Cargo.lock # MSVC Windows builds of rustc generate these, which store debugging information *.pdb - .debug .vscode runs/ diff --git a/Cargo.toml b/Cargo.toml index f5714e9..15aa2e9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,66 +1,75 @@ [package] name = "usls" -version = "0.0.20" +version = "0.0.21" edition = "2021" description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models." 
repository = "https://github.com/jamjamjon/usls" authors = ["Jamjamjon "] license = "MIT" readme = "README.md" -exclude = ["assets/*", "examples/*", "scripts/*", "runs/*"] +exclude = ["assets/*", "examples/*", "runs/*", "benches/*"] + [dependencies] -clap = { version = "4.2.4", features = ["derive"] } +aksr = { version = "0.0.2" } +image = { version = "0.25.2" } +imageproc = { version = "0.24" } ndarray = { version = "0.16.1", features = ["rayon"] } -ort = { version = "2.0.0-rc.5", default-features = false} +rayon = { version = "1.10.0" } anyhow = { version = "1.0.75" } regex = { version = "1.5.4" } rand = { version = "0.8.5" } chrono = { version = "0.4.30" } -half = { version = "2.3.1" } -dirs = { version = "5.0.1" } -ureq = { version = "2.9.1", default-features = true, features = [ - "socks-proxy", -] } tokenizers = { version = "0.15.2" } -rayon = "1.10.0" +log = { version = "0.4.22" } +env_logger = { version = "0.11.5" } indicatif = "0.17.8" -image = "0.25.2" -imageproc = { version = "0.24" } +serde_json = "1.0" +serde = { version = "1.0", features = ["derive"] } +ort = { version = "2.0.0-rc.9", default-features = false} +prost = "0.12.4" ab_glyph = "0.2.23" +dirs = { version = "5.0.1" } +tempfile = "3.12.0" geo = "0.28.0" -prost = "0.12.4" +half = { version = "2.3.1" } +ureq = { version = "2.9.1", default-features = true, features = [ + "socks-proxy", +] } fast_image_resize = { version = "4.2.1", features = ["image"]} -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -tempfile = "3.12.0" -video-rs = { version = "0.9.0", features = ["ndarray"] } natord = "1.0.9" -tracing = "0.1.40" -tracing-subscriber = "0.3.18" -minifb = "0.27.0" -aksr = "0.0.1" +video-rs = { version = "0.10.0", features = ["ndarray"], optional = true } +minifb = { version = "0.27.0", optional = true } +argh = "0.1.13" + + +[dev-dependencies] +tracing-subscriber = { version = "0.3.18" } +tracing = { version = "0.1.40", features = ["log"] } + + + +[[example]] +name = 
"viewer" +required-features = ["ffmpeg"] + [features] default = [ - "ort/load-dynamic", - "ort/copy-dylibs", - "ort/half", - "ort/ndarray", - "ort/cuda", - "ort/tensorrt", - "ort/coreml", - "ort/operator-libraries" + "ort/ndarray", + "ort/copy-dylibs", + "ort/load-dynamic", + "ort/half", ] auto = ["ort/download-binaries"] +ffmpeg = ["dep:video-rs", "dep:minifb"] +cuda = [ "ort/cuda" ] +trt = [ "ort/tensorrt" ] +mps = [ "ort/coreml" ] -[dev-dependencies] -criterion = "0.5.1" - -[[bench]] -name = "yolo" -harness = false -[lib] -bench = false +[profile.release] +# lto = true +strip = true +panic = "abort" diff --git a/README.md b/README.md index 150e77b..5fa1ec4 100644 --- a/README.md +++ b/README.md @@ -124,7 +124,7 @@ cargo run -r --example yolo # blip, clip, yolop, svtr, db, ... ``` - #### Follow the pipeline - - Build model with the provided `models` and `Options` + - Build model with the provided `models` and `ModelConfig` - Load images, video and stream with `DataLoader` - Do inference - Retrieve inference results from `Vec` @@ -136,11 +136,11 @@ cargo run -r --example yolo # blip, clip, yolop, svtr, db, ... example code ```rust - use usls::{models::YOLO, Annotator, DataLoader, Nms, Options, Vision, YOLOTask, YOLOVersion}; + use usls::{models::YOLO, Annotator, DataLoader, Nms, ModelConfig, Vision, YOLOTask, YOLOVersion}; fn main() -> anyhow::Result<()> { - // Build model with Options - let options = Options::new() + // Build model with ModelConfig + let options = ModelConfig::new() .with_trt(0) .with_model("yolo/v8-m-dyn.onnx")? .with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR @@ -181,7 +181,7 @@ cargo run -r --example yolo # blip, clip, yolop, svtr, db, ... 
viewer.imshow(&images_plotted)?; // check out window and key event - if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) { + if !viewer.is_open() || viewer.is_key_pressed(crate::Key::Escape) { break; } diff --git a/benches/yolo.rs b/benches/yolo.rs deleted file mode 100644 index 4fbecc4..0000000 --- a/benches/yolo.rs +++ /dev/null @@ -1,94 +0,0 @@ -use anyhow::Result; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; - -use usls::{models::YOLO, DataLoader, Options, Vision, YOLOTask, YOLOVersion}; - -enum Stage { - Pre, - Run, - Post, - Pipeline, -} - -fn yolo_stage_bench( - model: &mut YOLO, - x: &[image::DynamicImage], - stage: Stage, - n: u64, -) -> std::time::Duration { - let mut t_pre = std::time::Duration::new(0, 0); - let mut t_run = std::time::Duration::new(0, 0); - let mut t_post = std::time::Duration::new(0, 0); - let mut t_pipeline = std::time::Duration::new(0, 0); - for _ in 0..n { - let t0 = std::time::Instant::now(); - let xs = model.preprocess(x).unwrap(); - t_pre += t0.elapsed(); - - let t = std::time::Instant::now(); - let xs = model.inference(xs).unwrap(); - t_run += t.elapsed(); - - let t = std::time::Instant::now(); - let _ys = black_box(model.postprocess(xs, x).unwrap()); - t_post += t.elapsed(); - t_pipeline += t0.elapsed(); - } - match stage { - Stage::Pre => t_pre, - Stage::Run => t_run, - Stage::Post => t_post, - Stage::Pipeline => t_pipeline, - } -} - -pub fn benchmark_cuda(c: &mut Criterion, h: isize, w: isize) -> Result<()> { - let mut group = c.benchmark_group(format!("YOLO ({}-{})", w, h)); - group - .significance_level(0.05) - .sample_size(80) - .measurement_time(std::time::Duration::new(20, 0)); - - let options = Options::default() - .with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR - .with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb - .with_model("yolo/v8-m-dyn.onnx")? 
- .with_cuda(0) - // .with_cpu() - .with_num_dry_run(0) - .with_ixx(0, 2, (320, h, 1280).into()) - .with_ixx(0, 3, (320, w, 1280).into()) - .with_confs(&[0.2, 0.15]); - let mut model = YOLO::new(options)?; - - let xs = [DataLoader::try_read("./assets/bus.jpg")?]; - - group.bench_function("pre-process", |b| { - b.iter_custom(|n| yolo_stage_bench(&mut model, &xs, Stage::Pre, n)) - }); - - group.bench_function("run", |b| { - b.iter_custom(|n| yolo_stage_bench(&mut model, &xs, Stage::Run, n)) - }); - - group.bench_function("post-process", |b| { - b.iter_custom(|n| yolo_stage_bench(&mut model, &xs, Stage::Post, n)) - }); - - group.bench_function("pipeline", |b| { - b.iter_custom(|n| yolo_stage_bench(&mut model, &xs, Stage::Pipeline, n)) - }); - - group.finish(); - Ok(()) -} - -pub fn criterion_benchmark(c: &mut Criterion) { - // benchmark_cuda(c, 416, 416).unwrap(); - benchmark_cuda(c, 640, 640).unwrap(); - benchmark_cuda(c, 448, 768).unwrap(); - // benchmark_cuda(c, 800, 800).unwrap(); -} - -criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); diff --git a/examples/blip/README.md b/examples/blip/README.md index e0dfe3e..7b2161d 100644 --- a/examples/blip/README.md +++ b/examples/blip/README.md @@ -3,20 +3,10 @@ This demo shows how to use [BLIP](https://arxiv.org/abs/2201.12086) to do condit ## Quick Start ```shell -cargo run -r --example blip +cargo run -r -F cuda --example blip -- --device cuda:0 --source images/dog.jpg --source ./assets/bus.jpg --source images/green-car.jpg ``` -## Results - ```shell -[Unconditional]: a group of people walking around a bus -[Conditional]: three man walking in front of a bus -Some(["three man walking in front of a bus"]) +Unconditional: Ys([Y { Texts: [Text("a dog running through a field of grass")] }, Y { Texts: [Text("a group of people walking around a bus")] }, Y { Texts: [Text("a green volkswagen beetle parked in front of a yellow building")] }]) +Conditional: Ys([Y { Texts: [Text("this image depicting a dog 
running in a field")] }, Y { Texts: [Text("this image depict a bus in barcelona")] }, Y { Texts: [Text("this image depict a blue volkswagen beetle parked in a street in havana, cuba")] }]) ``` - -## TODO - -* [ ] Multi-batch inference for image caption -* [ ] VQA -* [ ] Retrival -* [ ] TensorRT support for textual model diff --git a/examples/blip/main.rs b/examples/blip/main.rs index da7fc89..0c7273e 100644 --- a/examples/blip/main.rs +++ b/examples/blip/main.rs @@ -1,28 +1,39 @@ -use usls::{models::Blip, DataLoader, Options}; - -fn main() -> Result<(), Box> { - // visual - let options_visual = Options::default() - .with_model("blip/visual-base.onnx")? - // .with_ixx(0, 2, 384.into()) - // .with_ixx(0, 3, 384.into()) - .with_profile(false); - - // textual - let options_textual = Options::default() - .with_model("blip/textual-base.onnx")? - .with_tokenizer("blip/tokenizer.json")? - .with_profile(false); - - // build model - let mut model = Blip::new(options_visual, options_textual)?; - - // image caption (this demo use batch_size=1) - let xs = [DataLoader::try_read("images/bus.jpg")?]; - let image_embeddings = model.encode_images(&xs)?; - let _y = model.caption(&image_embeddings, None, true)?; // unconditional - let y = model.caption(&image_embeddings, Some("three man"), true)?; // conditional - println!("{:?}", y[0].texts()); - - Ok(()) -} +use usls::{models::Blip, DataLoader, Options}; + +#[derive(argh::FromArgs)] +/// BLIP Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// source image + #[argh(option, default = "vec![String::from(\"./assets/bus.jpg\")]")] + source: Vec, +} + +fn main() -> anyhow::Result<()> { + let args: Args = argh::from_env(); + + // build model + let options_visual = Options::blip_v1_base_caption_visual() + .with_model_device(args.device.as_str().try_into()?) 
+ .commit()?; + let options_textual = Options::blip_v1_base_caption_textual() + .with_model_device(args.device.as_str().try_into()?) + .commit()?; + let mut model = Blip::new(options_visual, options_textual)?; + + // image caption + let xs = DataLoader::try_read_batch(&args.source)?; + + // unconditional caption + let ys = model.forward(&xs, None)?; + println!("Unconditional: {:?}", ys); + + // conditional caption + let ys = model.forward(&xs, Some("this image depict"))?; + println!("Conditional: {:?}", ys); + + Ok(()) +} diff --git a/examples/clip/README.md b/examples/clip/README.md index d85a682..09ff510 100644 --- a/examples/clip/README.md +++ b/examples/clip/README.md @@ -3,18 +3,14 @@ This demo showcases how to use [CLIP](https://github.com/openai/CLIP) to compute ## Quick Start ```shell -cargo run -r --example clip +cargo run -r -F cuda --example clip -- --device cuda:0 ``` ## Results ```shell -(90.11472%) ./examples/clip/images/carrot.jpg => 几个胡萝卜 -[0.04573484, 0.0048218793, 0.0011618224, 0.90114725, 0.0036694852, 0.031348046, 0.0121166315] +(99.9675%) ./examples/clip/images/carrot.jpg => Some carrots +(99.93718%) ./examples/clip/images/doll.jpg => There is a doll with red hair and a clock on a table +(100.0%) ./examples/clip/images/drink.jpg => Some people holding wine glasses in a restaurant -(94.07785%) ./examples/clip/images/peoples.jpg => Some people holding wine glasses in a restaurant -[0.050406333, 0.0011632168, 0.0019338318, 0.0013227565, 0.003916758, 0.00047858112, 0.9407785] - -(86.59852%) ./examples/clip/images/doll.jpg => There is a doll with red hair and a clock on a table -[0.07032883, 0.00053773675, 0.0006372929, 0.06066096, 0.0007378078, 0.8659852, 0.0011121632] ``` \ No newline at end of file diff --git a/examples/clip/images/peoples.jpg b/examples/clip/images/drink.jpg similarity index 100% rename from examples/clip/images/peoples.jpg rename to examples/clip/images/drink.jpg diff --git a/examples/clip/main.rs b/examples/clip/main.rs index 
0fd03ce..9acd02f 100644 --- a/examples/clip/main.rs +++ b/examples/clip/main.rs @@ -1,43 +1,49 @@ -use usls::{models::Clip, DataLoader, Options}; +use anyhow::Result; +use usls::{models::Clip, DataLoader, Ops, Options}; -fn main() -> Result<(), Box> { - // visual - let options_visual = Options::default().with_model("clip/visual-base-dyn.onnx")?; - - // textual - let options_textual = Options::default() - .with_model("clip/textual-base-dyn.onnx")? - .with_tokenizer("clip/tokenizer.json")?; +#[derive(argh::FromArgs)] +/// CLIP Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} +fn main() -> Result<()> { + let args: Args = argh::from_env(); // build model + let options_visual = Options::jina_clip_v1_visual() + // clip_vit_b32_visual() + .with_model_device(args.device.as_str().try_into()?) + .commit()?; + let options_textual = Options::jina_clip_v1_textual() + // clip_vit_b32_textual() + .with_model_device(args.device.as_str().try_into()?) 
+ .commit()?; let mut model = Clip::new(options_visual, options_textual)?; // texts let texts = vec![ - "A photo of a dinosaur ".to_string(), - "A photo of a cat".to_string(), - "A photo of a dog".to_string(), - "几个胡萝卜".to_string(), - "There are some playing cards on a striped table cloth".to_string(), - "There is a doll with red hair and a clock on a table".to_string(), - "Some people holding wine glasses in a restaurant".to_string(), + "A photo of a dinosaur", + "A photo of a cat", + "A photo of a dog", + "Some carrots", + "There are some playing cards on a striped table cloth", + "There is a doll with red hair and a clock on a table", + "Some people holding wine glasses in a restaurant", ]; let feats_text = model.encode_texts(&texts)?; // [n, ndim] - // load image + // load images let dl = DataLoader::new("./examples/clip/images")?.build()?; - // loop + // run for (images, paths) in dl { - let feats_image = model.encode_images(&images).unwrap(); + let feats_image = model.encode_images(&images)?; // use image to query texts - let matrix = match feats_image.embedding() { - Some(x) => x.dot2(feats_text.embedding().unwrap())?, - None => continue, - }; + let matrix = Ops::dot2(&feats_image, &feats_text)?; - // summary for i in 0..paths.len() { let probs = &matrix[i]; let (id, &score) = probs @@ -52,7 +58,6 @@ fn main() -> Result<(), Box> { paths[i].display(), &texts[id] ); - println!("{:?}\n", probs); } } diff --git a/examples/dataloader.rs b/examples/dataloader.rs new file mode 100644 index 0000000..799e56b --- /dev/null +++ b/examples/dataloader.rs @@ -0,0 +1,42 @@ +use usls::DataLoader; + +fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt::init(); + + // 1. 
iterator + let dl = DataLoader::try_from( + // "images/bus.jpg", // remote image + // "../images", // image folder + // "../demo.mp4", // local video + // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // remote video + // "rtsp://admin:xyz@192.168.2.217:554/h265/ch1/", // rtsp h264 stream + "./assets/bus.jpg", // local image + )? + .with_batch(1) + .with_progress_bar(true) + .build()?; + + for (_xs, _paths) in dl { + println!("Paths: {:?}", _paths); + } + + // 2. read one image + let image = DataLoader::try_read("./assets/bus.jpg")?; + println!( + "Read one image. Height: {}, Width: {}", + image.height(), + image.width() + ); + + // 3. read several images + let images = DataLoader::try_read_batch(&[ + "./assets/bus.jpg", + "./assets/bus.jpg", + "./assets/bus.jpg", + "./assets/bus.jpg", + "./assets/bus.jpg", + ])?; + println!("Read {} images.", images.len()); + + Ok(()) +} diff --git a/examples/dataloader/main.rs b/examples/dataloader/main.rs deleted file mode 100644 index 9ab7430..0000000 --- a/examples/dataloader/main.rs +++ /dev/null @@ -1,66 +0,0 @@ -use usls::{ - models::YOLO, Annotator, DataLoader, Device, Options, Viewer, Vision, YOLOTask, YOLOVersion, -}; - -fn main() -> anyhow::Result<()> { - tracing_subscriber::fmt() - .with_max_level(tracing::Level::ERROR) - .init(); - - let options = Options::new() - .with_device(Device::Cuda(0)) - .with_model("yolo/v8-m-det.onnx")? 
- .with_yolo_version(YOLOVersion::V8) - .with_yolo_task(YOLOTask::Detect) - .with_batch(2) - .with_ixx(0, 2, (416, 640, 800).into()) - .with_ixx(0, 3, (416, 640, 800).into()) - .with_confs(&[0.2]); - let mut model = YOLO::new(options)?; - - // build annotator - let annotator = Annotator::new() - .with_bboxes_thickness(4) - .with_saveout("YOLO-DataLoader"); - - // build dataloader - let dl = DataLoader::new( - // "images/bus.jpg", // remote image - // "../images", // image folder - // "../demo.mp4", // local video - // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // remote video - // "rtsp://admin:xyz@192.168.2.217:554/h265/ch1/", // rtsp h264 stream - // "./assets/bus.jpg", // local image - "../7.mp4", - )? - .with_batch(1) - .build()?; - - let mut viewer = Viewer::new().with_delay(10).with_scale(1.).resizable(true); - - // iteration - for (xs, _) in dl { - // inference & annotate - let ys = model.run(&xs)?; - let images_plotted = annotator.plot(&xs, &ys, false)?; - - // show image - viewer.imshow(&images_plotted)?; - - // check out window and key event - if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) { - break; - } - - // write video - viewer.write_batch(&images_plotted)?; - } - - // finish video write - viewer.finish_write()?; - - // images -> video - // DataLoader::is2v("runs/YOLO-DataLoader", &["runs", "is2v"], 24)?; - - Ok(()) -} diff --git a/examples/db/README.md b/examples/db/README.md index 6da1cfc..9e19375 100644 --- a/examples/db/README.md +++ b/examples/db/README.md @@ -4,15 +4,6 @@ cargo run -r --example db ``` -### Speed test - -| Model | Image size | TensorRT
f16
batch=1
(ms) | TensorRT
f32
batch=1
(ms) | CUDA
f32
batch=1
(ms) | -| --------------- | ---------- | ---------------------------------------- | ---------------------------------------- | ------------------------------------ | -| ppocr-v3-db-dyn | 640x640 | 1.8585 | 2.5739 | 4.3314 | -| ppocr-v4-db-dyn | 640x640 | 2.0507 | 2.8264 | 6.6064 | - -***Test on RTX3060*** - ## Results ![](https://github.com/jamjamjon/assets/releases/download/db/demo-paper.png) diff --git a/examples/db/main.rs b/examples/db/main.rs index b133216..fa3d8fe 100644 --- a/examples/db/main.rs +++ b/examples/db/main.rs @@ -1,27 +1,20 @@ +use anyhow::Result; use usls::{models::DB, Annotator, DataLoader, Options}; -fn main() -> Result<(), Box> { - // build model - let options = Options::default() - .with_ixx(0, 0, (1, 4, 8).into()) - .with_ixx(0, 2, (608, 960, 1280).into()) - .with_ixx(0, 3, (608, 960, 1280).into()) - // .with_trt(0) - .with_confs(&[0.4]) - .with_min_width(5.0) - .with_min_height(12.0) - .with_model("db/ppocr-v4-db-dyn.onnx")?; +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + // build model + let options = Options::ppocr_det_v4_ch().commit()?; let mut model = DB::new(options)?; // load image - let x = [ - DataLoader::try_read("images/db.png")?, - DataLoader::try_read("images/street.jpg")?, - ]; + let x = DataLoader::try_read_batch(&["images/db.png", "images/street.jpg"])?; // run - let y = model.run(&x)?; + let y = model.forward(&x)?; // annotate let annotator = Annotator::default() @@ -29,7 +22,7 @@ fn main() -> Result<(), Box> { .with_polygons_alpha(60) .with_contours_color([255, 105, 180, 255]) .without_mbrs(true) - .with_saveout("DB"); + .with_saveout(model.spec()); annotator.annotate(&x, &y); Ok(()) diff --git a/examples/depth-anything/main.rs b/examples/depth-anything/main.rs index d339ff3..7e1ba7d 100644 --- a/examples/depth-anything/main.rs +++ b/examples/depth-anything/main.rs @@ -1,24 +1,25 @@ +use anyhow::Result; use usls::{models::DepthAnything, Annotator, DataLoader, 
Options}; -fn main() -> Result<(), Box> { - // options - let options = Options::default() - // .with_model("depth-anything/v1-s-dyn.onnx")? - .with_model("depth-anything/v2-s.onnx")? - .with_ixx(0, 2, (384, 512, 1024).into()) - .with_ixx(0, 3, (384, 512, 1024).into()); +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + + // build model + let options = Options::depth_anything_v2_small().commit()?; let mut model = DepthAnything::new(options)?; // load let x = [DataLoader::try_read("images/street.jpg")?]; // run - let y = model.run(&x)?; + let y = model.forward(&x)?; // annotate let annotator = Annotator::default() .with_colormap("Turbo") - .with_saveout("Depth-Anything"); + .with_saveout(model.spec()); annotator.annotate(&x, &y); Ok(()) diff --git a/examples/depth-pro/README.md b/examples/depth-pro/README.md new file mode 100644 index 0000000..52c1418 --- /dev/null +++ b/examples/depth-pro/README.md @@ -0,0 +1,10 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example depth-pro -- --device cuda +``` + + +## Results + +![](https://github.com/jamjamjon/assets/releases/download/depth-pro/demo-depth-pro.png) diff --git a/examples/depth-pro/main.rs b/examples/depth-pro/main.rs index eb72a9a..7a0c246 100644 --- a/examples/depth-pro/main.rs +++ b/examples/depth-pro/main.rs @@ -1,25 +1,46 @@ +use anyhow::Result; use usls::{models::DepthPro, Annotator, DataLoader, Options}; -fn main() -> Result<(), Box> { - // options - let options = Options::default() - .with_model("depth-pro/q4f16.onnx")? // bnb4, f16 - .with_ixx(0, 0, 1.into()) // batch. 
Note: now only support batch_size = 1 - .with_ixx(0, 1, 3.into()) // channel - .with_ixx(0, 2, 1536.into()) // height - .with_ixx(0, 3, 1536.into()); // width +#[derive(argh::FromArgs)] +/// BLIP Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// dtype + #[argh(option, default = "String::from(\"q4f16\")")] + dtype: String, + + /// source image + #[argh(option, default = "String::from(\"images/street.jpg\")")] + source: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + + let args: Args = argh::from_env(); + + // model + let options = Options::depth_pro() + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .commit()?; let mut model = DepthPro::new(options)?; // load - let x = [DataLoader::try_read("images/street.jpg")?]; + let x = [DataLoader::try_read(&args.source)?]; // run - let y = model.run(&x)?; + let y = model.forward(&x)?; // annotate let annotator = Annotator::default() .with_colormap("Turbo") - .with_saveout("Depth-Pro"); + .with_saveout(model.spec()); annotator.annotate(&x, &y); Ok(()) diff --git a/examples/dfine/README.md b/examples/dfine/README.md new file mode 100644 index 0000000..7105118 --- /dev/null +++ b/examples/dfine/README.md @@ -0,0 +1,6 @@ +## Quick Start + +```shell +cargo run -r --example dfine +``` + diff --git a/examples/dfine/main.rs b/examples/dfine/main.rs new file mode 100644 index 0000000..3f6813c --- /dev/null +++ b/examples/dfine/main.rs @@ -0,0 +1,27 @@ +use anyhow::Result; +use usls::{models::RTDETR, Annotator, DataLoader, Options}; + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + + // options + let options = Options::dfine_n_coco().commit()?; + let mut model = RTDETR::new(options)?; + + // load + let x = [DataLoader::try_read("./assets/bus.jpg")?]; + + // run + let y = 
model.forward(&x)?; + println!("{:?}", y); + + // annotate + let annotator = Annotator::default() + .with_bboxes_thickness(3) + .with_saveout(model.spec()); + annotator.annotate(&x, &y); + + Ok(()) +} diff --git a/examples/dinov2/main.rs b/examples/dinov2/main.rs index 4cc7732..749adb5 100644 --- a/examples/dinov2/main.rs +++ b/examples/dinov2/main.rs @@ -1,40 +1,20 @@ -use usls::{models::Dinov2, DataLoader, Options}; +use anyhow::Result; +use usls::{models::DINOv2, DataLoader, Options}; -fn main() -> Result<(), Box> { - // build model - let options = Options::default() - .with_model("dinov2/s-dyn.onnx")? - .with_ixx(0, 2, 224.into()) - .with_ixx(0, 3, 224.into()); - let mut model = Dinov2::new(options)?; - let x = [DataLoader::try_read("images/bus.jpg")?]; - let y = model.run(&x)?; - println!("{y:?}"); +fn main() -> Result<()> { + // images + let xs = [ + DataLoader::try_read("./assets/bus.jpg")?, + DataLoader::try_read("./assets/bus.jpg")?, + ]; - // TODO: - // query from vector - // let ys = model.query_from_vec( - // "./assets/bus.jpg", - // &[ - // "./examples/dinov2/images/bus.jpg", - // "./examples/dinov2/images/1.jpg", - // "./examples/dinov2/images/2.jpg", - // ], - // Metric::L2, - // )?; + // model + let options = Options::dinov2_small().with_batch_size(xs.len()).commit()?; + let mut model = DINOv2::new(options)?; - // or query from folder - // let ys = model.query_from_folder("./assets/bus.jpg", "./examples/dinov2/images", Metric::IP)?; - - // results - // for (i, y) in ys.iter().enumerate() { - // println!( - // "Top-{:<3}{:.7} {}", - // i + 1, - // y.1, - // y.2.canonicalize()?.display() - // ); - // } + // encode images + let y = model.encode_images(&xs)?; + println!("Feat shape: {:?}", y.shape()); Ok(()) } diff --git a/examples/doclayout-yolo/README.md b/examples/doclayout-yolo/README.md new file mode 100644 index 0000000..b9b233f --- /dev/null +++ b/examples/doclayout-yolo/README.md @@ -0,0 +1,10 @@ +## Quick Start + +```shell +cargo run -r -F cuda 
--example doclayout-yolo -- --device cuda +``` + + +## Results + +![](https://github.com/jamjamjon/assets/releases/download/yolo/demo-doclayout-yolo.png) diff --git a/examples/doclayout-yolo/main.rs b/examples/doclayout-yolo/main.rs new file mode 100644 index 0000000..9727428 --- /dev/null +++ b/examples/doclayout-yolo/main.rs @@ -0,0 +1,41 @@ +use anyhow::Result; +use usls::{models::YOLO, Annotator, DataLoader, Options}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + + let args: Args = argh::from_env(); + + // build model + let config = Options::doclayout_yolo_docstructbench() + .with_model_device(args.device.as_str().try_into()?) + .commit()?; + let mut model = YOLO::new(config)?; + + // load images + let xs = [DataLoader::try_read("images/academic.jpg")?]; + + // run + let ys = model.forward(&xs)?; + // println!("{:?}", ys); + + // annotate + let annotator = Annotator::default() + .with_bboxes_thickness(3) + .with_saveout("doclayout-yolo"); + annotator.annotate(&xs, &ys); + + model.summary(); + + Ok(()) +} diff --git a/examples/fastsam.rs b/examples/fastsam.rs new file mode 100644 index 0000000..1b3bde8 --- /dev/null +++ b/examples/fastsam.rs @@ -0,0 +1,46 @@ +use anyhow::Result; +use usls::{models::YOLO, Annotator, DataLoader, Options}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"fp16\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + + let args: Args = argh::from_env(); + + // build model + let config = Options::fastsam_s() + .with_model_dtype(args.dtype.as_str().try_into()?) 
+ .with_model_device(args.device.as_str().try_into()?) + .commit()?; + let mut model = YOLO::new(config)?; + + // load images + let xs = DataLoader::try_read_batch(&["./assets/bus.jpg"])?; + + // run + let ys = model.forward(&xs)?; + + // annotate + let annotator = Annotator::default() + .without_masks(true) + .with_bboxes_thickness(3) + .with_saveout("fastsam"); + annotator.annotate(&xs, &ys); + + model.summary(); + + Ok(()) +} diff --git a/examples/florence2/README.md b/examples/florence2/README.md new file mode 100644 index 0000000..3764ea8 --- /dev/null +++ b/examples/florence2/README.md @@ -0,0 +1,31 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example florence2 -- --device cuda --scale base --dtype fp16 +``` + + +```Shell +Task: Caption(0) +Ys([Y { Texts: [Text("A green car parked in front of a yellow building.")] }, Y { Texts: [Text("A group of people walking down a street next to a bus.")] }]) + +Task: Caption(1) +Ys([Y { Texts: [Text("The image shows a green car parked in front of a yellow building with two brown doors. The car is on the road, and the building has a wall and a tree in the background.")] }, Y { Texts: [Text("The image shows a group of people walking down a street next to a bus, with a building in the background. The bus is likely part of the World Electric Emission Bus, which is a new bus that will be launched in Madrid. The people are walking on the road, and there are trees and a sign board to the left of the bus.")] }]) + +Task: Caption(2) +Ys([Y { Texts: [Text("The image shows a vintage Volkswagen Beetle car parked on a cobblestone street in front of a yellow building with two wooden doors. The car is a light blue color with silver rims and appears to be in good condition. The building has a sloping roof and is painted in a bright yellow color. The sky is blue and there are trees in the background. 
The overall mood of the image is peaceful and serene.")] }, Y { Texts: [Text("The image shows a blue and white bus with the logo of the Brazilian football club, Cero Emisiones, on the side. The bus is parked on a street with a building in the background. There are several people walking on the sidewalk in front of the bus, some of them are carrying bags and one person is holding a camera. The sky is blue and there are trees and a traffic light visible in the top right corner of the image. The image appears to be taken during the day.")] }]) +``` + + +# Tasks + +| Task | Demo | +| -----| ------| +|Caption-To-Phrase-Grounding | | +| Ocr-With-Region | | +| Dense-Region-Caption | | +| Object-Detection | | +| Region-Proposal | | +| Referring-Expression-Segmentation | | + + diff --git a/examples/florence2/main.rs b/examples/florence2/main.rs index 07cc7d1..87ac9db 100644 --- a/examples/florence2/main.rs +++ b/examples/florence2/main.rs @@ -1,157 +1,171 @@ -use usls::{models::Florence2, Annotator, DataLoader, Options, Task}; - -fn main() -> Result<(), Box> { - let batch_size = 3; - - // vision encoder - let options_vision_encoder = Options::default() - .with_model("florence2/base-vision-encoder-f16.onnx")? - .with_ixx(0, 2, (512, 768, 800).into()) - .with_ixx(0, 3, 768.into()) - .with_ixx(0, 0, (1, batch_size as _, 8).into()); - - // text embed - let options_text_embed = Options::default() - .with_model("florence2/base-embed-tokens-f16.onnx")? - .with_tokenizer("florence2/tokenizer.json")? - .with_batch(batch_size); - - // transformer encoder - let options_encoder = Options::default() - .with_model("florence2/base-encoder-f16.onnx")? - .with_batch(batch_size); - - // transformer decoder - let options_decoder = Options::default() - .with_model("florence2/base-decoder-f16.onnx")? - .with_batch(batch_size); - - // transformer decoder merged - let options_decoder_merged = Options::default() - .with_model("florence2/base-decoder-merged-f16.onnx")? 
- .with_batch(batch_size); - - // build model - let mut model = Florence2::new( - options_vision_encoder, - options_text_embed, - options_encoder, - options_decoder, - options_decoder_merged, - )?; - - // load images - let xs = [ - // DataLoader::try_read("florence2/car.jpg")?, // for testing region-related tasks - DataLoader::try_read("florence2/car.jpg")?, - // DataLoader::try_read("images/db.png")?, - DataLoader::try_read("assets/bus.jpg")?, - ]; - - // region-related tasks - let quantizer = usls::Quantizer::default(); - // let coords = [449., 270., 556., 372.]; // wheel - let coords = [31., 156., 581., 373.]; // car - let (width_car, height_car) = (xs[0].width(), xs[0].height()); - let quantized_coords = quantizer.quantize(&coords, (width_car as _, height_car as _)); - - // run with tasks - let ys = model.run_with_tasks( - &xs, - &[ - // w/ inputs - Task::Caption(0), - Task::Caption(1), - Task::Caption(2), - Task::Ocr, - Task::OcrWithRegion, - Task::RegionProposal, - Task::ObjectDetection, - Task::DenseRegionCaption, - // w/o inputs - Task::OpenSetDetection("a vehicle".into()), - Task::CaptionToPhraseGrounding( - "A vehicle with two wheels parked in front of a building.".into(), - ), - Task::ReferringExpressionSegmentation("a vehicle".into()), - Task::RegionToSegmentation( - quantized_coords[0], - quantized_coords[1], - quantized_coords[2], - quantized_coords[3], - ), - Task::RegionToCategory( - quantized_coords[0], - quantized_coords[1], - quantized_coords[2], - quantized_coords[3], - ), - Task::RegionToDescription( - quantized_coords[0], - quantized_coords[1], - quantized_coords[2], - quantized_coords[3], - ), - ], - )?; - - // annotator - let annotator = Annotator::new() - .without_bboxes_conf(true) - .with_bboxes_thickness(3) - .with_saveout_subs(&["Florence2"]); - for (task, ys_) in ys.iter() { - match task { - Task::Caption(_) - | Task::Ocr - | Task::RegionToCategory(..) - | Task::RegionToDescription(..) 
=> { - println!("Task: {:?}\n{:?}\n", task, ys_) - } - Task::DenseRegionCaption => { - let annotator = annotator.clone().with_saveout("Dense-Region-Caption"); - annotator.annotate(&xs, ys_); - } - Task::RegionProposal => { - let annotator = annotator - .clone() - .without_bboxes_name(false) - .with_saveout("Region-Proposal"); - - annotator.annotate(&xs, ys_); - } - Task::ObjectDetection => { - let annotator = annotator.clone().with_saveout("Object-Detection"); - annotator.annotate(&xs, ys_); - } - Task::OpenSetDetection(_) => { - let annotator = annotator.clone().with_saveout("Open-Set-Detection"); - annotator.annotate(&xs, ys_); - } - Task::CaptionToPhraseGrounding(_) => { - let annotator = annotator - .clone() - .with_saveout("Caption-To-Phrase-Grounding"); - annotator.annotate(&xs, ys_); - } - Task::ReferringExpressionSegmentation(_) => { - let annotator = annotator - .clone() - .with_saveout("Referring-Expression-Segmentation"); - annotator.annotate(&xs, ys_); - } - Task::RegionToSegmentation(..) 
=> { - let annotator = annotator.clone().with_saveout("Region-To-Segmentation"); - annotator.annotate(&xs, ys_); - } - Task::OcrWithRegion => { - let annotator = annotator.clone().with_saveout("Ocr-With-Region"); - annotator.annotate(&xs, ys_); - } - - _ => (), - } - } - - Ok(()) -} +use anyhow::Result; +use usls::{models::Florence2, Annotator, DataLoader, Options, Scale, Task}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// scale + #[argh(option, default = "String::from(\"base\")")] + scale: String, +} + +fn main() -> Result<()> { + let args: Args = argh::from_env(); + + // load images + let xs = [ + DataLoader::try_read("images/green-car.jpg")?, + DataLoader::try_read("assets/bus.jpg")?, + ]; + + // build model + let ( + options_vision_encoder, + options_text_embed, + options_encoder, + options_decoder, + options_decoder_merged, + ) = match args.scale.as_str().try_into()? { + Scale::B => ( + Options::florence2_visual_encoder_base(), + Options::florence2_textual_embed_base(), + Options::florence2_texual_encoder_base(), + Options::florence2_texual_decoder_base(), + Options::florence2_texual_decoder_merged_base(), + ), + Scale::L => todo!(), + _ => anyhow::bail!("Unsupported Florence2 scale."), + }; + + let mut model = Florence2::new( + options_vision_encoder + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .with_batch_size(xs.len()) + .commit()?, + options_text_embed + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .with_batch_size(xs.len()) + .commit()?, + options_encoder + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) 
+ .with_batch_size(xs.len()) + .commit()?, + options_decoder + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .with_batch_size(xs.len()) + .commit()?, + options_decoder_merged + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .with_batch_size(xs.len()) + .commit()?, + )?; + + // tasks + let tasks = [ + // w inputs + Task::Caption(0), + Task::Caption(1), + Task::Caption(2), + Task::Ocr, + // Task::OcrWithRegion, + Task::RegionProposal, + Task::ObjectDetection, + Task::DenseRegionCaption, + // w/o inputs + Task::OpenSetDetection("a vehicle"), + Task::CaptionToPhraseGrounding("A vehicle with two wheels parked in front of a building."), + Task::ReferringExpressionSegmentation("a vehicle"), + Task::RegionToSegmentation( + // 31, 156, 581, 373, // car + 449, 270, 556, 372, // wheel + ), + Task::RegionToCategory( + // 31, 156, 581, 373, + 449, 270, 556, 372, + ), + Task::RegionToDescription( + // 31, 156, 581, 373, + 449, 270, 556, 372, + ), + ]; + + // annotator + let annotator = Annotator::new() + .without_bboxes_conf(true) + .with_bboxes_thickness(3) + .with_saveout_subs(&["Florence2"]); + + // inference + for task in tasks.iter() { + let ys = model.forward(&xs, task)?; + + // annotate + match task { + Task::Caption(_) + | Task::Ocr + | Task::RegionToCategory(..) + | Task::RegionToDescription(..) 
=> { + println!("Task: {:?}\n{:?}\n", task, &ys) + } + Task::DenseRegionCaption => { + let annotator = annotator.clone().with_saveout("Dense-Region-Caption"); + annotator.annotate(&xs, &ys); + } + Task::RegionProposal => { + let annotator = annotator + .clone() + .without_bboxes_name(false) + .with_saveout("Region-Proposal"); + + annotator.annotate(&xs, &ys); + } + Task::ObjectDetection => { + let annotator = annotator.clone().with_saveout("Object-Detection"); + annotator.annotate(&xs, &ys); + } + Task::OpenSetDetection(_) => { + let annotator = annotator.clone().with_saveout("Open-Set-Detection"); + annotator.annotate(&xs, &ys); + } + Task::CaptionToPhraseGrounding(_) => { + let annotator = annotator + .clone() + .with_saveout("Caption-To-Phrase-Grounding"); + annotator.annotate(&xs, &ys); + } + Task::ReferringExpressionSegmentation(_) => { + let annotator = annotator + .clone() + .with_saveout("Referring-Expression-Segmentation"); + annotator.annotate(&xs, &ys); + } + Task::RegionToSegmentation(..) 
=> { + let annotator = annotator.clone().with_saveout("Region-To-Segmentation"); + annotator.annotate(&xs, &ys); + } + Task::OcrWithRegion => { + let annotator = annotator.clone().with_saveout("Ocr-With-Region"); + annotator.annotate(&xs, &ys); + } + + _ => (), + } + } + + model.summary(); + + Ok(()) +} diff --git a/examples/grounding-dino/README.md b/examples/grounding-dino/README.md index a94cb0b..f97321f 100644 --- a/examples/grounding-dino/README.md +++ b/examples/grounding-dino/README.md @@ -1,7 +1,7 @@ ## Quick Start ```shell -cargo run -r --example grounding-dino +cargo run -r -F cuda --example grounding-dino -- --device cuda --dtype fp16 ``` diff --git a/examples/grounding-dino/main.rs b/examples/grounding-dino/main.rs index 2ceb61c..c837d53 100644 --- a/examples/grounding-dino/main.rs +++ b/examples/grounding-dino/main.rs @@ -1,41 +1,67 @@ +use anyhow::Result; use usls::{models::GroundingDINO, Annotator, DataLoader, Options}; -fn main() -> Result<(), Box> { - let opts = Options::default() - .with_ixx(0, 0, (1, 1, 4).into()) - .with_ixx(0, 2, (640, 800, 1200).into()) - .with_ixx(0, 3, (640, 1200, 1200).into()) - // .with_i10((1, 1, 4).into()) - // .with_i11((256, 256, 512).into()) - // .with_i20((1, 1, 4).into()) - // .with_i21((256, 256, 512).into()) - // .with_i30((1, 1, 4).into()) - // .with_i31((256, 256, 512).into()) - // .with_i40((1, 1, 4).into()) - // .with_i41((256, 256, 512).into()) - // .with_i50((1, 1, 4).into()) - // .with_i51((256, 256, 512).into()) - // .with_i52((256, 256, 512).into()) - .with_model("grounding-dino/swint-ogc-dyn-u8.onnx")? // TODO: current onnx model does not support bs > 1 - // .with_model("grounding-dino/swint-ogc-dyn-f32.onnx")? - .with_tokenizer("grounding-dino/tokenizer.json")? 
- .with_confs(&[0.2]) - .with_profile(false); - let mut model = GroundingDINO::new(opts)?; - - // Load images and set class names - let x = [DataLoader::try_read("images/bus.jpg")?]; - let texts = [ - "person", "hand", "shoes", "bus", "dog", "cat", "sign", "tie", "monitor", "window", - "glasses", "tree", "head", - ]; - - // Run and annotate - let y = model.run(&x, &texts)?; +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// source image + #[argh(option, default = "vec![String::from(\"./assets/bus.jpg\")]")] + source: Vec, + + /// open class names + #[argh( + option, + default = "vec![ + String::from(\"person\"), + String::from(\"hand\"), + String::from(\"shoes\"), + String::from(\"bus\"), + String::from(\"dog\"), + String::from(\"cat\"), + String::from(\"sign\"), + String::from(\"tie\"), + String::from(\"monitor\"), + String::from(\"glasses\"), + String::from(\"tree\"), + String::from(\"head\"), + ]" + )] + labels: Vec, +} + +fn main() -> Result<()> { + let args: Args = argh::from_env(); + + let options = Options::grounding_dino_tiny() + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) 
+ .with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::>()) + .commit()?; + + let mut model = GroundingDINO::new(options)?; + + // load images + let xs = DataLoader::try_read_batch(&args.source)?; + + // run + let ys = model.forward(&xs)?; + + // annotate let annotator = Annotator::default() .with_bboxes_thickness(4) - .with_saveout("GroundingDINO"); - annotator.annotate(&x, &y); + .with_saveout(model.spec()); + annotator.annotate(&xs, &ys); + + // summary + model.summary(); Ok(()) } diff --git a/examples/hub.rs b/examples/hub.rs new file mode 100644 index 0000000..dd7452c --- /dev/null +++ b/examples/hub.rs @@ -0,0 +1,25 @@ +fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + + // build + let mut hub = usls::Hub::new()?; + println!("{:#?}", hub); + + // download + let image_downloaded = hub.try_fetch("images/bus.jpg")?; + println!("Fetch one image. path: {:?}", image_downloaded); + + // download again + let image_downloaded = hub.try_fetch("images/bus.jpg")?; + println!("Fetch one image. 
path: {:?}", image_downloaded); + + // tags and files + for tag in hub.tags().iter() { + let files = hub.files(tag); + println!("{} => {:?}", tag, files); + } + + Ok(()) +} diff --git a/examples/image-classification/README.md b/examples/image-classification/README.md new file mode 100644 index 0000000..6d337fa --- /dev/null +++ b/examples/image-classification/README.md @@ -0,0 +1,13 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example image-classification -- --device cuda --dtype fp16 +``` + + +```shell +0: Y { Probs: { Top5: [(263, 0.6109131, Some("Pembroke, Pembroke Welsh corgi")), (264, 0.2062352, Some("Cardigan, Cardigan Welsh corgi")), (231, 0.028572788, Some("collie")), (273, 0.015174894, Some("dingo, warrigal, warragal, Canis dingo")), (248, 0.014367299, Some("Eskimo dog, husky"))] } } +1: Y { Probs: { Top5: [(284, 0.9907692, Some("siamese cat, Siamese")), (285, 0.0015794479, Some("Egyptian cat")), (174, 0.0015189401, Some("Norwegian elkhound, elkhound")), (225, 0.00031838714, Some("malinois")), (17, 0.00027021166, Some("jay"))] } } +2: Y { Probs: { Top5: [(387, 0.94238573, Some("lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens")), (368, 0.0029994072, Some("gibbon, Hylobates lar")), (277, 0.0016564301, Some("red fox, Vulpes vulpes")), (356, 0.0015081967, Some("weasel")), (295, 0.001427932, Some("American black bear, black bear, Ursus americanus, Euarctos americanus"))] } } + +``` diff --git a/examples/image-classification/main.rs b/examples/image-classification/main.rs new file mode 100644 index 0000000..46f9992 --- /dev/null +++ b/examples/image-classification/main.rs @@ -0,0 +1,59 @@ +use usls::{models::ImageClassifier, Annotator, DataLoader, Options}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// source image + #[argh( + option, + default = 
"vec![ + String::from(\"images/dog.jpg\"), + String::from(\"images/siamese.png\"), + String::from(\"images/ailurus-fulgens.jpg\"), + ]" + )] + source: Vec, +} + +fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + let args: Args = argh::from_env(); + + // build model + let options = Options::mobileone_s0() + // convnext_v2_atto() + // fastvit_sa24_distill() + // deit_tiny_distill() + // beit_base() + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .commit()?; + let mut model = ImageClassifier::try_from(options)?; + + // load images + let xs = DataLoader::try_read_batch(&args.source)?; + + // run + let ys = model.forward(&xs)?; + + // results + for (i, y) in ys.iter().enumerate() { + println!("{}: {:?}", i, y); + } + + // annotate + let annotator = Annotator::default().with_saveout(model.spec()); + annotator.annotate(&xs, &ys); + + Ok(()) +} diff --git a/examples/modnet/main.rs b/examples/modnet/main.rs index 660ded5..4c1a7b1 100644 --- a/examples/modnet/main.rs +++ b/examples/modnet/main.rs @@ -1,22 +1,19 @@ use usls::{models::MODNet, Annotator, DataLoader, Options}; -fn main() -> Result<(), Box> { +fn main() -> anyhow::Result<()> { // build model - let options = Options::default() - .with_model("modnet/dyn-f32.onnx")? 
- .with_ixx(0, 2, (416, 512, 800).into()) - .with_ixx(0, 3, (416, 512, 800).into()); + let options = Options::modnet_photographic().commit()?; let mut model = MODNet::new(options)?; // load image - let x = [DataLoader::try_read("images/liuyifei.png")?]; + let xs = [DataLoader::try_read("images/liuyifei.png")?]; // run - let y = model.run(&x)?; + let ys = model.forward(&xs)?; // annotate - let annotator = Annotator::default().with_saveout("MODNet"); - annotator.annotate(&x, &y); + let annotator = Annotator::default().with_saveout(model.spec()); + annotator.annotate(&xs, &ys); Ok(()) } diff --git a/examples/picodet-layout/README.md b/examples/picodet-layout/README.md new file mode 100644 index 0000000..8e29d70 --- /dev/null +++ b/examples/picodet-layout/README.md @@ -0,0 +1,10 @@ +## Quick Start + +```shell +cargo run -r --example picodet-layout +``` + + +## Results + +![](https://github.com/jamjamjon/assets/releases/download/picodet/demo-layout-1x.png) diff --git a/examples/picodet-layout/main.rs b/examples/picodet-layout/main.rs new file mode 100644 index 0000000..a288696 --- /dev/null +++ b/examples/picodet-layout/main.rs @@ -0,0 +1,30 @@ +use anyhow::Result; +use usls::{models::PicoDet, Annotator, DataLoader, Options}; + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + + // options + let options = Options::picodet_layout_1x() + // picodet_l_layout_3cls() + // picodet_l_layout_17cls() + .commit()?; + let mut model = PicoDet::new(options)?; + + // load + let xs = [DataLoader::try_read("images/academic.jpg")?]; + + // annotator + let annotator = Annotator::default() + .with_bboxes_thickness(3) + .with_saveout(model.spec()); + + // run + let ys = model.forward(&xs)?; + println!("{:?}", ys); + annotator.annotate(&xs, &ys); + + Ok(()) +} diff --git a/examples/rtdetr/README.md b/examples/rtdetr/README.md new file mode 100644 index 0000000..b131882 --- /dev/null +++ b/examples/rtdetr/README.md @@ -0,0 +1,7 @@ +## 
Quick Start + +```shell +cargo run -r --example rtdetr +``` + + diff --git a/examples/rtdetr/main.rs b/examples/rtdetr/main.rs new file mode 100644 index 0000000..0201cc2 --- /dev/null +++ b/examples/rtdetr/main.rs @@ -0,0 +1,33 @@ +use anyhow::Result; +use usls::{models::RTDETR, Annotator, DataLoader, Options}; + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + + // options + let options = Options::rtdetr_v2_s_coco() + // rtdetr_v1_r18vd_coco() + // rtdetr_v2_ms_coco() + // rtdetr_v2_m_coco() + // rtdetr_v2_l_coco() + // rtdetr_v2_x_coco() + .commit()?; + let mut model = RTDETR::new(options)?; + + // load + let x = [DataLoader::try_read("./assets/bus.jpg")?]; + + // run + let y = model.forward(&x)?; + println!("{:?}", y); + + // annotate + let annotator = Annotator::default() + .with_bboxes_thickness(3) + .with_saveout(model.spec()); + annotator.annotate(&x, &y); + + Ok(()) +} diff --git a/examples/rtmo/main.rs b/examples/rtmo/main.rs index aae1706..cce0b37 100644 --- a/examples/rtmo/main.rs +++ b/examples/rtmo/main.rs @@ -1,25 +1,24 @@ -use usls::{models::RTMO, Annotator, DataLoader, Options, COCO_SKELETONS_16}; +use anyhow::Result; +use usls::{ + models::{COCO_SKELETONS_16, RTMO}, + Annotator, DataLoader, Options, +}; -fn main() -> Result<(), Box> { +fn main() -> Result<()> { // build model - let options = Options::default() - .with_model("rtmo/s-dyn.onnx")? 
- .with_nk(17) - .with_confs(&[0.3]) - .with_kconfs(&[0.5]); - let mut model = RTMO::new(options)?; + let mut model = RTMO::new(Options::rtmo_s().commit()?)?; // load image - let x = [DataLoader::try_read("images/bus.jpg")?]; + let xs = [DataLoader::try_read("images/bus.jpg")?]; // run - let y = model.run(&x)?; + let ys = model.forward(&xs)?; // annotate let annotator = Annotator::default() - .with_saveout("RTMO") + .with_saveout(model.spec()) .with_skeletons(&COCO_SKELETONS_16); - annotator.annotate(&x, &y); + annotator.annotate(&xs, &ys); Ok(()) } diff --git a/examples/sam/README.md b/examples/sam/README.md index 92af792..6b85c99 100644 --- a/examples/sam/README.md +++ b/examples/sam/README.md @@ -3,16 +3,16 @@ ```Shell # SAM -cargo run -r --example sam +cargo run -r -F cuda --example sam -- --device cuda --kind sam # MobileSAM -cargo run -r --example sam -- --kind mobile-sam +cargo run -r -F cuda --example sam -- --device cuda --kind mobile-sam # EdgeSAM -cargo run -r --example sam -- --kind edge-sam +cargo run -r -F cuda --example sam -- --device cuda --kind edge-sam # SAM-HQ -cargo run -r --example sam -- --kind sam-hq +cargo run -r -F cuda --example sam -- --device cuda --kind sam-hq ``` diff --git a/examples/sam/main.rs b/examples/sam/main.rs index 72eca95..f4bd351 100644 --- a/examples/sam/main.rs +++ b/examples/sam/main.rs @@ -1,97 +1,72 @@ -use clap::Parser; - +use anyhow::Result; use usls::{ models::{SamKind, SamPrompt, SAM}, - Annotator, DataLoader, Options, + Annotator, DataLoader, Options, Scale, }; -#[derive(Parser)] -#[command(author, version, about, long_about = None)] -pub struct Args { - #[arg(long, value_enum, default_value_t = SamKind::Sam)] - pub kind: SamKind, +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, - #[arg(long, default_value_t = 0)] - pub device_id: usize, + /// scale + #[argh(option, default = "String::from(\"t\")")] + scale: String, - 
#[arg(long)] - pub use_low_res_mask: bool, + /// SAM kind + #[argh(option, default = "String::from(\"sam\")")] + kind: String, } -fn main() -> Result<(), Box> { - let args = Args::parse(); - - // Options - let (options_encoder, options_decoder, saveout) = match args.kind { - SamKind::Sam => { - let options_encoder = Options::default() - // .with_model("sam/sam-vit-b-encoder.onnx")?; - .with_model("sam/sam-vit-b-encoder-u8.onnx")?; +fn main() -> Result<()> { + // tracing_subscriber::fmt() + // .with_max_level(tracing::Level::INFO) + // .init(); - let options_decoder = Options::default() - .with_sam_kind(SamKind::Sam) - // .with_model("sam/sam-vit-b-decoder.onnx")?; - // .with_model("sam/sam-vit-b-decoder-singlemask.onnx")?; - .with_model("sam/sam-vit-b-decoder-u8.onnx")?; - (options_encoder, options_decoder, "SAM") - } - SamKind::Sam2 => { - let options_encoder = Options::default() - // .with_model("sam/sam2-hiera-tiny-encoder.onnx")?; - // .with_model("sam/sam2-hiera-small-encoder.onnx")?; - .with_model("sam/sam2-hiera-base-plus-encoder.onnx")?; - let options_decoder = Options::default() - .with_sam_kind(SamKind::Sam2) - // .with_model("sam/sam2-hiera-tiny-decoder.onnx")?; - // .with_model("sam/sam2-hiera-small-decoder.onnx")?; - .with_model("sam/sam2-hiera-base-plus-decoder.onnx")?; - (options_encoder, options_decoder, "SAM2") - } - SamKind::MobileSam => { - let options_encoder = - Options::default().with_model("sam/mobile-sam-vit-t-encoder.onnx")?; - - let options_decoder = Options::default() - .with_sam_kind(SamKind::MobileSam) - .with_model("sam/mobile-sam-vit-t-decoder.onnx")?; - (options_encoder, options_decoder, "Mobile-SAM") - } - SamKind::SamHq => { - let options_encoder = Options::default().with_model("sam/sam-hq-vit-t-encoder.onnx")?; + let args: Args = argh::from_env(); + // Build model + let (options_encoder, options_decoder) = match args.kind.as_str().try_into()? 
{ + SamKind::Sam => ( + Options::sam_v1_base_encoder(), + Options::sam_v1_base_decoder(), + ), + SamKind::Sam2 => match args.scale.as_str().try_into()? { + Scale::T => (Options::sam2_tiny_encoder(), Options::sam2_tiny_decoder()), + Scale::S => (Options::sam2_small_encoder(), Options::sam2_small_decoder()), + Scale::B => ( + Options::sam2_base_plus_encoder(), + Options::sam2_base_plus_decoder(), + ), + _ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t.", args.scale), + }, - let options_decoder = Options::default() - .with_sam_kind(SamKind::SamHq) - .with_model("sam/sam-hq-vit-t-decoder.onnx")?; - (options_encoder, options_decoder, "SAM-HQ") - } - SamKind::EdgeSam => { - let options_encoder = Options::default().with_model("sam/edge-sam-3x-encoder.onnx")?; - let options_decoder = Options::default() - .with_sam_kind(SamKind::EdgeSam) - .with_model("sam/edge-sam-3x-decoder.onnx")?; - (options_encoder, options_decoder, "Edge-SAM") - } + SamKind::MobileSam => ( + Options::mobile_sam_tiny_encoder(), + Options::mobile_sam_tiny_decoder(), + ), + SamKind::SamHq => ( + Options::sam_hq_tiny_encoder(), + Options::sam_hq_tiny_decoder(), + ), + SamKind::EdgeSam => ( + Options::edge_sam_3x_encoder(), + Options::edge_sam_3x_decoder(), + ), }; - let options_encoder = options_encoder - .with_cuda(args.device_id) - .with_ixx(0, 2, (800, 1024, 1024).into()) - .with_ixx(0, 3, (800, 1024, 1024).into()); - let options_decoder = options_decoder - .with_cuda(args.device_id) - .with_low_res_mask(args.use_low_res_mask) - .with_find_contours(true); - // Build model + let options_encoder = options_encoder + .with_model_device(args.device.as_str().try_into()?) 
+ .commit()?; + let options_decoder = options_decoder.commit()?; let mut model = SAM::new(options_encoder, options_decoder)?; // Load image - let xs = [ - DataLoader::try_read("images/truck.jpg")?, - // DataLoader::try_read("images/dog.jpg")?, - ]; + let xs = [DataLoader::try_read("images/truck.jpg")?]; // Build annotator - let annotator = Annotator::default().with_saveout(saveout); + let annotator = Annotator::default().with_saveout(model.spec()); // Prompt let prompts = vec![ @@ -102,7 +77,7 @@ fn main() -> Result<(), Box> { ]; // Run & Annotate - let ys = model.run(&xs, &prompts)?; + let ys = model.forward(&xs, &prompts)?; annotator.annotate(&xs, &ys); Ok(()) diff --git a/examples/sapiens/README.md b/examples/sapiens/README.md index 6bf5cfe..3112e69 100644 --- a/examples/sapiens/README.md +++ b/examples/sapiens/README.md @@ -1,7 +1,7 @@ ## Quick Start ```shell -cargo run -r --example sapiens +cargo run -r -F cuda --example sapiens -- --device cuda ``` diff --git a/examples/sapiens/main.rs b/examples/sapiens/main.rs index 111d90f..ab0bda7 100644 --- a/examples/sapiens/main.rs +++ b/examples/sapiens/main.rs @@ -1,27 +1,33 @@ -use usls::{ - models::{Sapiens, SapiensTask}, - Annotator, DataLoader, Options, BODY_PARTS_28, -}; +use anyhow::Result; +use usls::{models::Sapiens, Annotator, DataLoader, Options}; -fn main() -> Result<(), Box> { +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} + +fn main() -> Result<()> { + let args: Args = argh::from_env(); // build - let options = Options::default() - .with_model("sapiens/seg-0.3b-dyn.onnx")? - .with_sapiens_task(SapiensTask::Seg) - .with_names(&BODY_PARTS_28); + let options = Options::sapiens_seg_0_3b() + .with_model_device(args.device.as_str().try_into()?) 
+ .commit()?; let mut model = Sapiens::new(options)?; // load let x = [DataLoader::try_read("images/paul-george.jpg")?]; // run - let y = model.run(&x)?; + let y = model.forward(&x)?; // annotate let annotator = Annotator::default() .without_masks(true) - .with_polygons_name(false) - .with_saveout("Sapiens"); + .with_polygons_name(true) + .with_saveout(model.spec()); annotator.annotate(&x, &y); Ok(()) diff --git a/examples/slanet/README.md b/examples/slanet/README.md new file mode 100644 index 0000000..dac09f0 --- /dev/null +++ b/examples/slanet/README.md @@ -0,0 +1,10 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example slanet -- --device cuda +``` + + +## Results + +![](https://github.com/jamjamjon/assets/releases/download/slanet/demo.png) diff --git a/examples/slanet/main.rs b/examples/slanet/main.rs new file mode 100644 index 0000000..e028269 --- /dev/null +++ b/examples/slanet/main.rs @@ -0,0 +1,46 @@ +use anyhow::Result; +use usls::{models::SLANet, Annotator, DataLoader, Options}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// source + #[argh(option, default = "String::from(\"images/table.png\")")] + source: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + let args: Args = argh::from_env(); + + // build model + let options = Options::slanet_lcnet_v2_mobile_ch() + .with_model_device(args.device.as_str().try_into()?) 
+ .commit()?; + let mut model = SLANet::new(options)?; + + // load + let xs = DataLoader::try_read_batch(&[args.source])?; + + // run + let ys = model.forward(&xs)?; + println!("{:?}", ys); + + // annotate + let annotator = Annotator::default() + .with_keypoints_radius(2) + .with_skeletons(&[(0, 1), (1, 2), (2, 3), (3, 0)]) + .with_saveout(model.spec()); + annotator.annotate(&xs, &ys); + + // summary + model.summary(); + + Ok(()) +} diff --git a/examples/svtr/README.md b/examples/svtr/README.md index cc192bc..82c10c5 100644 --- a/examples/svtr/README.md +++ b/examples/svtr/README.md @@ -1,29 +1,21 @@ ## Quick Start ```shell -cargo run -r --example svtr +cargo run -r -F cuda --example svtr -- --device cuda ``` -### Speed test - -| Model | Width | TensorRT
f16
batch=1
(ms) | TensorRT
f32
batch=1
(ms) | CUDA
f32
batch=1
(ms) | -| --------------------------- | :---: | :--------------------------------------: | :--------------------------------------: | :----------------------------------: | -| ppocr-v4-server-svtr-ch-dyn | 1500 | 4.2116 | 13.0013 | 20.8673 | -| ppocr-v4-svtr-ch-dyn | 1500 | 2.0435 | 3.1959 | 10.1750 | -| ppocr-v3-svtr-ch-dyn | 1500 | 1.8596 | 2.9401 | 6.8210 | - -***Test on RTX3060*** - ## Results ```shell -["./examples/svtr/images/5.png"]: Some(["are closely jointed. Some examples are illustrated in Fig.7."]) -["./examples/svtr/images/6.png"]: Some(["小菊儿胡同71号"]) -["./examples/svtr/images/4.png"]: Some(["我在南锣鼓捣猫呢"]) -["./examples/svtr/images/1.png"]: Some(["你有这么高速运转的机械进入中国,记住我给出的原理"]) -["./examples/svtr/images/2.png"]: Some(["冀B6G000"]) -["./examples/svtr/images/9.png"]: Some(["from the background, but also separate text instances which"]) -["./examples/svtr/images/8.png"]: Some(["110022345"]) -["./examples/svtr/images/3.png"]: Some(["粤A·68688"]) -["./examples/svtr/images/7.png"]: Some(["Please lower your volume"]) +["./examples/svtr/images/license-ch-2.png"]: Ys([Y { Texts: [Text("粤A·68688")] }]) +["./examples/svtr/images/license-ch.png"]: Ys([Y { Texts: [Text("冀B6G000")] }]) +["./examples/svtr/images/sign-ch-2.png"]: Ys([Y { Texts: [Text("我在南锣鼓捣猫呢")] }]) +["./examples/svtr/images/sign-ch.png"]: Ys([Y { Texts: [Text("小菊儿胡同71号")] }]) +["./examples/svtr/images/text-110022345.png"]: Ys([Y { Texts: [Text("110022345")] }]) +["./examples/svtr/images/text-ch.png"]: Ys([Y { Texts: [Text("你有这么高速运转的机械进入中国,记住我给出的原理")] }]) +["./examples/svtr/images/text-en-2.png"]: Ys([Y { Texts: [Text("from the background, but also separate text instances which")] }]) +["./examples/svtr/images/text-en-dark.png"]: Ys([Y { Texts: [Text("Please lower your volume")] }]) +["./examples/svtr/images/text-en.png"]: Ys([Y { Texts: [Text("are closely jointed. 
Some examples are illustrated in Fig.7.")] }]) +["./examples/svtr/images/text-hello-rust-handwritten.png"]: Ys([Y { Texts: [Text("HeloRuSt")] }]) + ``` \ No newline at end of file diff --git a/examples/svtr/images/3.png b/examples/svtr/images/license-ch-2.png similarity index 100% rename from examples/svtr/images/3.png rename to examples/svtr/images/license-ch-2.png diff --git a/examples/svtr/images/2.png b/examples/svtr/images/license-ch.png similarity index 100% rename from examples/svtr/images/2.png rename to examples/svtr/images/license-ch.png diff --git a/examples/svtr/images/4.png b/examples/svtr/images/sign-ch-2.png similarity index 100% rename from examples/svtr/images/4.png rename to examples/svtr/images/sign-ch-2.png diff --git a/examples/svtr/images/6.png b/examples/svtr/images/sign-ch.png similarity index 100% rename from examples/svtr/images/6.png rename to examples/svtr/images/sign-ch.png diff --git a/examples/svtr/images/8.png b/examples/svtr/images/text-110022345.png similarity index 100% rename from examples/svtr/images/8.png rename to examples/svtr/images/text-110022345.png diff --git a/examples/svtr/images/1.png b/examples/svtr/images/text-ch.png similarity index 100% rename from examples/svtr/images/1.png rename to examples/svtr/images/text-ch.png diff --git a/examples/svtr/images/9.png b/examples/svtr/images/text-en-2.png similarity index 100% rename from examples/svtr/images/9.png rename to examples/svtr/images/text-en-2.png diff --git a/examples/svtr/images/7.png b/examples/svtr/images/text-en-dark.png similarity index 100% rename from examples/svtr/images/7.png rename to examples/svtr/images/text-en-dark.png diff --git a/examples/svtr/images/5.png b/examples/svtr/images/text-en.png similarity index 100% rename from examples/svtr/images/5.png rename to examples/svtr/images/text-en.png diff --git a/examples/svtr/images/text-hello-rust-handwritten.png b/examples/svtr/images/text-hello-rust-handwritten.png new file mode 100644 index 
0000000000000000000000000000000000000000..750c634277c02cd5dc29db2e5805340114eb51d6 GIT binary patch literal 3777 zcmV;y4nFaTP)Oh00001b5ch_0Itp) z=>Px#1ZP1_K>z@;j|==^1poj532;bRa{vGi!vFvd!vV){sAK>D4pvD-K~#8N?VSmX z6h#z3i$w<NFfPIYpzOcz^~3A|lA;fhdBgQ2`Hv93BBgK?DgBAV@%tpq#E~bXgbC za46z|fCdHSP*hgr3|=UB0Byg&RlCD5J=3#2!*2J?OFA`O(_K5=UH$*B|9{n%vTRhE zG-=WVj0|0-Ns}f`phye1G-=WVlC*F;No?G>QGZGt$!If8@c8%Nf7Ks<{2|+sBS&;? z@JS1|lf=FE-mBhy_g%f07?ROuDv+0#r`~$&ExlLBufP7P7B60`KK}S)5d=Y@Teohi zcJ11-U4HrHx;FTv0}3aJ%*;&Dpzpl%j%w1RiLUiVGTKZ9rcIlsh7B90PCM;1Ri;cC zb@1Rp75*BOE?ruL!l+TB)R||VscZcS!n$+k&hq)py?b{l4Xut>k&arNAc$b!zI{?Af@HLr z3jF!!pZb$hD^{#fufF=Kdhx{<)l*MBrCPLTp)R=K0@bHaAGKh?fl$!gVktO(lFgEIoSksH@12>B#4ucb=4P*}Z$WTDfv%f!0eDWy_Y; zpAttR+B|UJfO_)DCsmy~b!5Bq&O6oq{rmN|!D7RP4N|sm-(E^5IIdQ$T6)hJ6evSm zf)U}MC_en~L%mltXum^;4&wW(TD7W-T;|W8ufL7Wx^?SRqehL?si&T*Yf=N>3AaOs z4(U(%Rz@t06;h!>1u0`e#O#qr9#NNEa*4X_w%gQ|S6(U3`WZ83sK5XITOU<)7$PSc zuWi2m`s?bdtFBTrXUg!%lF=UPagoQ@U!;s-(L+K zI#eI&4DSo0mMgBfLf51a_Rz*x?Ay1`s#2wjY=;jY*0s(&|NQfEY;JC@u5o7Gym?lQ z8Z|6iJ6gBje!FFNBXdrFa!xZbIk3P~z$A|gNwl$z@ypxDFr%#_=(9WPi zgRH!~JgHy1cCFNLKN9(yiyPbGw`?n3u5HblHM;gh_|CT4q`yUD(xgcR_s>=cx;CFF zQ>IA0Ekv!~e)~<=x?|_gotCX!HJGk-X6)Fpa_sx>zprZ|gJ9!0I^v2ID|Br>bf9U|rpfnbpMADvTZ8%=_n6;& z@ZdpR;~lSU+8W)eR;`+}5vlV(@x&AIU6(FhbdCSv*_!*db?a7LdptbftFOLV6k{y< z_U$WewP?{o*A!xe)2LCSbd7ts=bSrWzyMuS$bbL+S8xrn-hl}dCdhG4*BWGeH!S_} z#~-CERjQOLAA;Qxq^w!9rj(H&k~(nUKy~f4*Q!mMHk}ZO&B(}54?g&y>esKIw8xO3 zC{(OiQI#)WUhgFaFKvz=KVE#7_6sULws3m%=pm(T#Yq_iX3w6jKPh#}DW~X9g)oe> zt#)NAG4Pwyo_ntTlsJmoCg~7MmoAlLRT5PrL3FKh<;qD3f)oe%Ol=EZ?IYu7IE=b9v011TnO*IjqH3QB}BX3Q8VUwGjKDZ??dC(bzI3{kw% zAecbZHx==naQpJhFQufz&^g)1=}q(qEk~{riI)#Q{IFPKIC(kN$;!itPBQ28>C>ZH zKsb&=wP-xTja)l$60k+VYn#iKEfdFdyLRmyUt;RHcZvMn#RW2<>e}Pkv15mn_3G7= z(utpc{#pI<%P*>K-MYHwL=gB`nJD?dp@*lw0N)o*UgK|Y+Kj{4Y~8ze*DV|m{yThV zQ>RXqeS8_)w{I7}j4kx6Aw!0UGq+BiIt6W+U<1dx#qMBoE)M%d;kC`)y?e{~v3!yE zjqoc4foGn1rrwO+Lnn-?(vOUE`iPbLPmo_&c32 zel*i>4g7xkt1WD;(W6HfJmXHCI?1*7=+VQnEmQF)+QQRn)Toj4e~w`zK*6~|!B>_W 
z8isj*z;KErO#DUojutLlD4G->16t^dFTN1%2To@B*s){9hxXia&&jZ~diCmZ4k;jZ z=Ta}d^ishv2LA_|p5O4=Yp=x|vV>xWKWCkFmXwLWYn#Rbj+G95KXH#A*@^d#e}yXKl}q>KmiT%)BqVJYGo!O9E;w+j80v=rWn zH{N(d>X}jnVFd>=SPdQl8(SJ z;VBM&WcBLRaxGXhuw=;+v0AXoH*el7b@5zLnAd=64D3^Wx;+rEoADv-?Za98{2j>P% zdm^QghE*nG9EW+bDT%_36c8Y&;6!e2u1v*BRxvq#EkX;4ANH^L3g zx?#fx@e{ZeNKB%NPpQEIulId0$^;fAHcP>Kn_ZlPm1h0|HZwpz7n((z+yMqYFr&n;WFl;qBvZn{Z- z;|?GG^tXjkdg{I2N21|8 zYqnLZR_WTWO4!*rze>(}er=#Xg3@gzLDfmItwzGTG(k}~<6 zr&q}Eiuvt{6mDMIBq5rFPScLrNbHQZT87`a?6S)wshboDH;}}NAnDhypQB>5tzAj_ zMKN#>zxwJc{Y^fk4-FqayddEm|q<1!1y#agvdd zA$f5IG^AIO(Pj`aOAaxGF4bVfg0O;cs}+I9KRzO7j^thzVI`F%(xRAtCy>IeBtdeg zA$)PfqLPd@g9f>1EIR3wxEo~1Ya2d%xP0!?rHjnWU~R+u?z^wxA0K69Wy!o)R?m6- z@yF%Zpy3;zn_pqN7DFS-?Ck7>BvdA&&7d)2#0Z&n5bqM-L4u{Zrc9Y4f*$t_LsVuf rFw+LtO|&6v5J(HRG-=WV5T*VHkp(^3^j3~}00000NkvXXu0mjf8AfZD literal 0 HcmV?d00001 diff --git a/examples/svtr/main.rs b/examples/svtr/main.rs index 43562c1..a46f0f8 100644 --- a/examples/svtr/main.rs +++ b/examples/svtr/main.rs @@ -1,24 +1,43 @@ +use anyhow::Result; use usls::{models::SVTR, DataLoader, Options}; -fn main() -> Result<(), Box> { +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + + let args: Args = argh::from_env(); + // build model - let options = Options::default() - .with_ixx(0, 0, (1, 2, 8).into()) - .with_ixx(0, 2, (320, 960, 1600).into()) - .with_ixx(0, 3, (320, 960, 1600).into()) - .with_confs(&[0.2]) - .with_vocab("svtr/ppocr_rec_vocab.txt")? - .with_model("svtr/ppocr-v4-svtr-ch-dyn.onnx")?; + let options = Options::ppocr_rec_v4_ch() + // svtr_v2_teacher_ch() + // .with_batch_size(2) + .with_model_device(args.device.as_str().try_into()?) 
+ .commit()?; let mut model = SVTR::new(options)?; // load images - let dl = DataLoader::new("./examples/svtr/images")?.build()?; + let dl = DataLoader::new("./examples/svtr/images")? + .with_batch(model.batch() as _) + .with_progress_bar(false) + .build()?; // run for (xs, paths) in dl { - let ys = model.run(&xs)?; - println!("{paths:?}: {:?}", ys[0].texts()) + let ys = model.forward(&xs)?; + println!("{paths:?}: {:?}", ys) } + //summary + model.summary(); + Ok(()) } diff --git a/examples/trocr/README.md b/examples/trocr/README.md new file mode 100644 index 0000000..dba262c --- /dev/null +++ b/examples/trocr/README.md @@ -0,0 +1,13 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example trocr -- --device cuda --dtype fp16 --scale s --kind printed + +cargo run -r -F cuda --example trocr -- --device cuda --dtype fp16 --scale s --kind hand-written + +``` + + +```shell +Ys([Y { Texts: [Text("PLEASE LOWER YOUR VOLUME")] }, Y { Texts: [Text("HELLO RUST")] }]) +``` \ No newline at end of file diff --git a/examples/trocr/main.rs b/examples/trocr/main.rs new file mode 100644 index 0000000..d3392b0 --- /dev/null +++ b/examples/trocr/main.rs @@ -0,0 +1,91 @@ +use usls::{ + models::{TrOCR, TrOCRKind}, + DataLoader, Options, Scale, +}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// scale + #[argh(option, default = "String::from(\"s\")")] + scale: String, + + /// kind + #[argh(option, default = "String::from(\"printed\")")] + kind: String, +} + +fn main() -> anyhow::Result<()> { + let args: Args = argh::from_env(); + + // load images + let xs = DataLoader::try_read_batch(&[ + "images/text-en-dark.png", + "images/text-hello-rust-handwritten.png", + ])?; + + // build model + let (options_encoder, options_decoder, options_decoder_merged) = + match args.scale.as_str().try_into()? 
{ + Scale::S => match args.kind.as_str().try_into()? { + TrOCRKind::Printed => ( + Options::trocr_encoder_small_printed(), + Options::trocr_decoder_small_printed(), + Options::trocr_decoder_merged_small_printed(), + ), + TrOCRKind::HandWritten => ( + Options::trocr_encoder_small_handwritten(), + Options::trocr_decoder_small_handwritten(), + Options::trocr_decoder_merged_small_handwritten(), + ), + }, + Scale::B => match args.kind.as_str().try_into()? { + TrOCRKind::Printed => ( + Options::trocr_encoder_base_printed(), + Options::trocr_decoder_base_printed(), + Options::trocr_decoder_merged_base_printed(), + ), + TrOCRKind::HandWritten => ( + Options::trocr_encoder_base_handwritten(), + Options::trocr_decoder_base_handwritten(), + Options::trocr_decoder_merged_base_handwritten(), + ), + }, + x => anyhow::bail!("Unsupported TrOCR scale: {:?}", x), + }; + + let mut model = TrOCR::new( + options_encoder + .with_model_device(args.device.as_str().try_into()?) + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_batch_size(xs.len()) + .commit()?, + options_decoder + .with_model_device(args.device.as_str().try_into()?) + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_batch_size(xs.len()) + .commit()?, + options_decoder_merged + .with_model_device(args.device.as_str().try_into()?) + .with_model_dtype(args.dtype.as_str().try_into()?) 
+ .with_batch_size(xs.len()) + .commit()?, + )?; + + // inference + let ys = model.forward(&xs)?; + println!("{:?}", ys); + + // summary + model.summary(); + + Ok(()) +} diff --git a/examples/viewer.rs b/examples/viewer.rs new file mode 100644 index 0000000..6863be6 --- /dev/null +++ b/examples/viewer.rs @@ -0,0 +1,38 @@ +use usls::{DataLoader, Key, Viewer}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// source + #[argh( + option, + default = "String::from(\"http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4\")" + )] + source: String, +} + +fn main() -> anyhow::Result<()> { + let args: Args = argh::from_env(); + let dl = DataLoader::new(&args.source)?.with_batch(1).build()?; + + let mut viewer = Viewer::new().with_delay(5).with_scale(1.).resizable(true); + + // run & annotate + for (xs, _paths) in dl { + // show image + viewer.imshow(&xs)?; + + // check out window and key event + if !viewer.is_open() || viewer.is_key_pressed(Key::Escape) { + break; + } + + // write video + viewer.write_batch(&xs)? 
+ } + + // finish video write + viewer.finish_write()?; + + Ok(()) +} diff --git a/examples/yolo-sam/README.md b/examples/yolo-sam/README.md index 1dfab0c..84dfb0f 100644 --- a/examples/yolo-sam/README.md +++ b/examples/yolo-sam/README.md @@ -1,7 +1,7 @@ ## Quick Start ```shell -cargo run -r --example yolo-sam +cargo run -r -F cuda --example yolo-sam -- --device cuda ``` ## Results diff --git a/examples/yolo-sam/main.rs b/examples/yolo-sam/main.rs index 3b51ace..5628ba6 100644 --- a/examples/yolo-sam/main.rs +++ b/examples/yolo-sam/main.rs @@ -1,31 +1,41 @@ +use anyhow::Result; use usls::{ - models::{SamKind, SamPrompt, YOLOTask, YOLOVersion, SAM, YOLO}, - Annotator, DataLoader, Options, Vision, + models::{SamPrompt, SAM, YOLO}, + Annotator, DataLoader, Options, Scale, }; -fn main() -> Result<(), Box> { +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} + +fn main() -> Result<()> { + // tracing_subscriber::fmt() + // .with_max_level(tracing::Level::INFO) + // .init(); + + let args: Args = argh::from_env(); + // build SAM - let options_encoder = Options::default().with_model("sam/mobile-sam-vit-t-encoder.onnx")?; - let options_decoder = Options::default() - .with_find_contours(true) - .with_sam_kind(SamKind::Sam) - .with_model("sam/mobile-sam-vit-t-decoder.onnx")?; + let (options_encoder, options_decoder) = ( + Options::mobile_sam_tiny_encoder().commit()?, + Options::mobile_sam_tiny_decoder().commit()?, + ); let mut sam = SAM::new(options_encoder, options_decoder)?; - // build YOLOv8-Det - let options_yolo = Options::default() - .with_yolo_version(YOLOVersion::V8) - .with_yolo_task(YOLOTask::Detect) - .with_model("yolo/v8-m-dyn.onnx")? 
- .with_cuda(0) - .with_ixx(0, 2, (416, 640, 800).into()) - .with_ixx(0, 3, (416, 640, 800).into()) - .with_find_contours(false) - .with_confs(&[0.45]); + // build YOLOv8 + let options_yolo = Options::yolo_detect() + .with_model_scale(Scale::N) + .with_model_version(8.0.into()) + .with_model_device(args.device.as_str().try_into()?) + .commit()?; let mut yolo = YOLO::new(options_yolo)?; // load one image - let xs = [DataLoader::try_read("images/dog.jpg")?]; + let xs = DataLoader::try_read_batch(&["images/dog.jpg"])?; // build annotator let annotator = Annotator::default() @@ -36,11 +46,11 @@ fn main() -> Result<(), Box> { .with_saveout("YOLO-SAM"); // run & annotate - let ys_det = yolo.run(&xs)?; - for y_det in ys_det { + let ys_det = yolo.forward(&xs)?; + for y_det in ys_det.iter() { if let Some(bboxes) = y_det.bboxes() { for bbox in bboxes { - let ys_sam = sam.run( + let ys_sam = sam.forward( &xs, &[SamPrompt::default().with_bbox( bbox.xmin(), diff --git a/examples/yolo/README.md b/examples/yolo/README.md index d443f43..15be924 100644 --- a/examples/yolo/README.md +++ b/examples/yolo/README.md @@ -28,76 +28,38 @@ cargo run -r --example yolo -- --task detect --ver v8 --nc 6 --model xxx.onnx # YOLOv8 # Classify -cargo run -r --example yolo -- --task classify --ver v5 --scale s --width 224 --height 224 --nc 1000 # YOLOv5 -cargo run -r --example yolo -- --task classify --ver v8 --scale n --width 224 --height 224 --nc 1000 # YOLOv8 -cargo run -r --example yolo -- --task classify --ver v11 --scale n --width 224 --height 224 --nc 1000 # YOLOv11 + +cargo run -r --example yolo -- --task classify --ver 5 --scale s --image-width 224 --image-height 224 --num-classes 1000 --use-imagenet-1k-classes # YOLOv5 +cargo run -r --example yolo -- --task classify --ver 8 --scale n --image-width 224 --image-height 224 # YOLOv8 +cargo run -r --example yolo -- --task classify --ver 11 --scale n --image-width 224 --image-height 224 # YOLOv11 # Detect -cargo run -r --example yolo -- --task 
detect --ver v5 --scale n # YOLOv5 -cargo run -r --example yolo -- --task detect --ver v6 --scale n # YOLOv6 -cargo run -r --example yolo -- --task detect --ver v7 --scale t # YOLOv7 -cargo run -r --example yolo -- --task detect --ver v8 --scale n # YOLOv8 -cargo run -r --example yolo -- --task detect --ver v9 --scale t # YOLOv9 -cargo run -r --example yolo -- --task detect --ver v10 --scale n # YOLOv10 -cargo run -r --example yolo -- --task detect --ver v11 --scale n # YOLOv11 -cargo run -r --example yolo -- --task detect --ver rtdetr --scale l # RTDETR -cargo run -r --example yolo -- --task detect --ver v8 --model yolo/v8-s-world-v2-shoes.onnx # YOLOv8-world +cargo run -r --example yolo -- --task detect --ver 5 --scale n --use-coco-80-classes # YOLOv5 +cargo run -r --example yolo -- --task detect --ver 6 --scale n --use-coco-80-classes # YOLOv6 +cargo run -r --example yolo -- --task detect --ver 7 --scale t --use-coco-80-classes # YOLOv7 +cargo run -r --example yolo -- --task detect --ver 8 --scale n --use-coco-80-classes # YOLOv8 +cargo run -r --example yolo -- --task detect --ver 9 --scale t --use-coco-80-classes # YOLOv9 +cargo run -r --example yolo -- --task detect --ver 10 --scale n --use-coco-80-classes # YOLOv10 +cargo run -r --example yolo -- --task detect --ver 11 --scale n --use-coco-80-classes # YOLOv11 +cargo run -r --example yolo -- --task detect --ver 8 --model v8-s-world-v2-shoes.onnx # YOLOv8-world # Pose -cargo run -r --example yolo -- --task pose --ver v8 --scale n # YOLOv8-Pose -cargo run -r --example yolo -- --task pose --ver v11 --scale n # YOLOv11-Pose +cargo run -r --example yolo -- --task pose --ver 8 --scale n # YOLOv8-Pose +cargo run -r --example yolo -- --task pose --ver 11 --scale n # YOLOv11-Pose # Segment -cargo run -r --example yolo -- --task segment --ver v5 --scale n # YOLOv5-Segment -cargo run -r --example yolo -- --task segment --ver v8 --scale n # YOLOv8-Segment -cargo run -r --example yolo -- --task segment --ver v11 --scale n 
# YOLOv8-Segment -cargo run -r --example yolo -- --task segment --ver v8 --model yolo/FastSAM-s-dyn-f16.onnx # FastSAM +cargo run -r --example yolo -- --task segment --ver 5 --scale n # YOLOv5-Segment +cargo run -r --example yolo -- --task segment --ver 8 --scale n # YOLOv8-Segment +cargo run -r --example yolo -- --task segment --ver 11 --scale n # YOLOv11-Segment # Obb -cargo run -r --example yolo -- --ver v8 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv8-Obb -cargo run -r --example yolo -- --ver v11 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv11-Obb +cargo run -r --example yolo -- --ver 8 --task obb --scale n --image-width 1024 --image-height 1024 --source images/dota.png # YOLOv8-Obb +cargo run -r --example yolo -- --ver 11 --task obb --scale n --image-width 1024 --image-height 1024 --source images/dota.png # YOLOv11-Obb ``` **`cargo run -r --example yolo -- --help` for more options** -## YOLOs configs with `Options` - -
-Use official YOLO Models - -```Rust -let options = Options::default() - .with_yolo_version(YOLOVersion::V5) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR - .with_yolo_task(YOLOTask::Classify) // YOLOTask: Classify, Detect, Pose, Segment, Obb - .with_model("xxxx.onnx")?; - -``` -
- -
-Cutomized your own YOLO model - -```Rust -// This config is for YOLOv8-Segment -use usls::{AnchorsPosition, BoxType, ClssType, YOLOPreds}; - -let options = Options::default() - .with_yolo_preds( - YOLOPreds { - bbox: Some(BoxType::Cxcywh), - clss: ClssType::Clss, - coefs: Some(true), - anchors: Some(AnchorsPosition::After), - ..Default::default() - } - ) - // .with_nc(80) - // .with_names(&COCO_CLASS_NAMES_80) - .with_model("xxxx.onnx")?; -``` -
- ## Other YOLOv8 Solution Models | Model | Weights | Datasets| diff --git a/examples/yolo/main.rs b/examples/yolo/main.rs index 2df0cc4..056b8a7 100644 --- a/examples/yolo/main.rs +++ b/examples/yolo/main.rs @@ -1,171 +1,211 @@ use anyhow::Result; -use clap::Parser; - use usls::{ - models::YOLO, Annotator, DataLoader, Device, Options, Viewer, Vision, YOLOScale, YOLOTask, - YOLOVersion, COCO_SKELETONS_16, + models::{COCO_CLASS_NAMES_80, COCO_SKELETONS_16, IMAGENET_NAMES_1K, YOLO}, + Annotator, DataLoader, Options, }; -#[derive(Parser, Clone)] -#[command(author, version, about, long_about = None)] -pub struct Args { - /// Path to the model - #[arg(long)] - pub model: Option, +#[derive(argh::FromArgs, Debug)] +/// Example +struct Args { + /// model file + #[argh(option)] + model: Option, + + /// source + #[argh(option, default = "String::from(\"./assets/bus.jpg\")")] + source: String, + + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, + + /// task + #[argh(option, default = "String::from(\"det\")")] + task: String, - /// Input source path - #[arg(long, default_value_t = String::from("./assets/bus.jpg"))] - pub source: String, + /// version + #[argh(option, default = "8.0")] + ver: f32, - /// YOLO Task - #[arg(long, value_enum, default_value_t = YOLOTask::Detect)] - pub task: YOLOTask, + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, - /// YOLO Version - #[arg(long, value_enum, default_value_t = YOLOVersion::V8)] - pub ver: YOLOVersion, + /// scale + #[argh(option, default = "String::from(\"n\")")] + scale: String, - /// YOLO Scale - #[arg(long, value_enum, default_value_t = YOLOScale::N)] - pub scale: YOLOScale, + /// trt_fp16 + #[argh(option, default = "true")] + trt_fp16: bool, - /// Batch size - #[arg(long, default_value_t = 1)] - pub batch_size: usize, + /// find_contours + #[argh(option, default = "true")] + find_contours: bool, - /// Minimum input width - #[arg(long, default_value_t = 224)] 
- pub width_min: isize, + /// batch_size + #[argh(option, default = "1")] + batch_size: usize, - /// Input width - #[arg(long, default_value_t = 640)] - pub width: isize, + /// min_batch_size + #[argh(option, default = "1")] + min_batch_size: usize, - /// Maximum input width - #[arg(long, default_value_t = 1024)] - pub width_max: isize, + /// max_batch_size + #[argh(option, default = "4")] + max_batch_size: usize, - /// Minimum input height - #[arg(long, default_value_t = 224)] - pub height_min: isize, + /// min_image_width + #[argh(option, default = "224")] + min_image_width: isize, - /// Input height - #[arg(long, default_value_t = 640)] - pub height: isize, + /// image_width + #[argh(option, default = "640")] + image_width: isize, - /// Maximum input height - #[arg(long, default_value_t = 1024)] - pub height_max: isize, + /// max_image_width + #[argh(option, default = "1280")] + max_image_width: isize, - /// Number of classes - #[arg(long, default_value_t = 80)] - pub nc: usize, + /// min_image_height + #[argh(option, default = "224")] + min_image_height: isize, - /// Class confidence - #[arg(long)] - pub confs: Vec, + /// image_height + #[argh(option, default = "640")] + image_height: isize, - /// Enable TensorRT support - #[arg(long)] - pub trt: bool, + /// max_image_height + #[argh(option, default = "1280")] + max_image_height: isize, - /// Enable CUDA support - #[arg(long)] - pub cuda: bool, + /// num_classes + #[argh(option)] + num_classes: Option, - /// Enable CoreML support - #[arg(long)] - pub coreml: bool, + /// num_keypoints + #[argh(option)] + num_keypoints: Option, - /// Use TensorRT half precision - #[arg(long)] - pub half: bool, + /// use_coco_80_classes + #[argh(switch)] + use_coco_80_classes: bool, - /// Device ID to use - #[arg(long, default_value_t = 0)] - pub device_id: usize, + /// use_imagenet_1k_classes + #[argh(switch)] + use_imagenet_1k_classes: bool, - /// Enable performance profiling - #[arg(long)] - pub profile: bool, + /// confs + 
#[argh(option)] + confs: Vec, - /// Disable contour drawing - #[arg(long)] - pub no_contours: bool, + /// keypoint_confs + #[argh(option)] + keypoint_confs: Vec, - /// Show result - #[arg(long)] - pub view: bool, + /// exclude_classes + #[argh(option)] + exclude_classes: Vec, + + /// retain_classes + #[argh(option)] + retain_classes: Vec, - /// Do not save output - #[arg(long)] - pub nosave: bool, + /// class_names + #[argh(option)] + class_names: Vec, + + /// keypoint_names + #[argh(option)] + keypoint_names: Vec, } fn main() -> Result<()> { - let args = Args::parse(); - - // model path - let path = match &args.model { - None => format!( - "yolo/{}-{}-{}.onnx", - args.ver.name(), - args.scale.name(), - args.task.name() - ), - Some(x) => x.to_string(), - }; - - // saveout - let saveout = match &args.model { - None => format!( - "{}-{}-{}", - args.ver.name(), - args.scale.name(), - args.task.name() - ), - Some(x) => { - let p = std::path::PathBuf::from(&x); - p.file_stem().unwrap().to_str().unwrap().to_string() - } - }; - - // device - let device = if args.cuda { - Device::Cuda(args.device_id) - } else if args.trt { - Device::Trt(args.device_id) - } else if args.coreml { - Device::CoreML(args.device_id) - } else { - Device::Cpu(args.device_id) - }; - - // build options - let options = Options::new() - .with_model(&path)? - .with_yolo_version(args.ver) - .with_yolo_task(args.task) - .with_device(device) - .with_trt_fp16(args.half) - .with_ixx(0, 0, (1, args.batch_size as _, 4).into()) - .with_ixx(0, 2, (args.height_min, args.height, args.height_max).into()) - .with_ixx(0, 3, (args.width_min, args.width, args.width_max).into()) - .with_confs(if args.confs.is_empty() { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::ERROR) + .init(); + let args: Args = argh::from_env(); + + let mut options = Options::yolo() + .with_model_file(&args.model.unwrap_or_default()) + .with_model_task(args.task.as_str().try_into()?) 
+ .with_model_version(args.ver.into()) + .with_model_scale(args.scale.as_str().try_into()?) + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .with_trt_fp16(args.trt_fp16) + .with_model_ixx( + 0, + 0, + (args.min_batch_size, args.batch_size, args.max_batch_size).into(), + ) + .with_model_ixx( + 0, + 2, + ( + args.min_image_height, + args.image_height, + args.max_image_height, + ) + .into(), + ) + .with_model_ixx( + 0, + 3, + (args.min_image_width, args.image_width, args.max_image_width).into(), + ) + .with_class_confs(if args.confs.is_empty() { &[0.2, 0.15] } else { &args.confs }) - .with_nc(args.nc) - // .with_names(&COCO_CLASS_NAMES_80) - // .with_names2(&COCO_KEYPOINTS_17) - .with_find_contours(!args.no_contours) // find contours or not - // .exclude_classes(&[0]) - // .retain_classes(&[0, 5]) - .with_profile(args.profile); + .with_keypoint_confs(if args.keypoint_confs.is_empty() { + &[0.5] + } else { + &args.keypoint_confs + }) + .with_find_contours(args.find_contours) + .retain_classes(&args.retain_classes) + .exclude_classes(&args.exclude_classes); + + if args.use_coco_80_classes { + options = options.with_class_names(&COCO_CLASS_NAMES_80); + } + + if args.use_imagenet_1k_classes { + options = options.with_class_names(&IMAGENET_NAMES_1K); + } + + if let Some(nc) = args.num_classes { + options = options.with_nc(nc); + } + + if let Some(nk) = args.num_keypoints { + options = options.with_nk(nk); + } + + if !args.class_names.is_empty() { + options = options.with_class_names( + &args + .class_names + .iter() + .map(|x| x.as_str()) + .collect::>(), + ); + } + + if !args.keypoint_names.is_empty() { + options = options.with_keypoint_names( + &args + .keypoint_names + .iter() + .map(|x| x.as_str()) + .collect::>(), + ); + } // build model - let mut model = YOLO::new(options)?; + let mut model = YOLO::try_from(options.commit()?)?; // build dataloader let dl = DataLoader::new(&args.source)? 
@@ -175,56 +215,28 @@ fn main() -> Result<()> { // build annotator let annotator = Annotator::default() .with_skeletons(&COCO_SKELETONS_16) - .without_masks(true) // No masks plotting when doing segment task. + .without_masks(true) .with_bboxes_thickness(3) - .with_keypoints_name(false) // Enable keypoints names - .with_saveout_subs(&["YOLO"]) - .with_saveout(&saveout); - - // build viewer - let mut viewer = if args.view { - Some(Viewer::new().with_delay(5).with_scale(1.).resizable(true)) - } else { - None - }; + .with_saveout(model.spec()); // run & annotate for (xs, _paths) in dl { - // let ys = model.run(&xs)?; // way one - let ys = model.forward(&xs, args.profile)?; // way two - let images_plotted = annotator.plot(&xs, &ys, !args.nosave)?; - - // show image - match &mut viewer { - Some(viewer) => viewer.imshow(&images_plotted)?, - None => continue, - } - - // check out window and key event - match &mut viewer { - Some(viewer) => { - if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) { - break; + let ys = model.forward(&xs)?; + // extract bboxes + for y in ys.iter() { + if let Some(bboxes) = y.bboxes() { + println!("[Bboxes]: Found {} objects", bboxes.len()); + for (i, bbox) in bboxes.iter().enumerate() { + println!("{}: {:?}", i, bbox) } } - None => continue, } - // write video - if !args.nosave { - match &mut viewer { - Some(viewer) => viewer.write_batch(&images_plotted)?, - None => continue, - } - } + // plot + annotator.annotate(&xs, &ys); } - // finish video write - if !args.nosave { - if let Some(viewer) = &mut viewer { - viewer.finish_write()?; - } - } + model.summary(); Ok(()) } diff --git a/examples/yolop/main.rs b/examples/yolop/main.rs index 2e338cc..820ca39 100644 --- a/examples/yolop/main.rs +++ b/examples/yolop/main.rs @@ -1,22 +1,21 @@ +use anyhow::Result; use usls::{models::YOLOPv2, Annotator, DataLoader, Options}; -fn main() -> Result<(), Box> { +fn main() -> Result<()> { // build model - let options = Options::default() - 
.with_model("yolop/v2-dyn-480x800.onnx")? - .with_confs(&[0.3]); + let options = Options::yolop_v2_480x800().commit()?; let mut model = YOLOPv2::new(options)?; // load image - let x = [DataLoader::try_read("images/car.jpg")?]; + let x = [DataLoader::try_read("images/car-view.jpg")?]; // run - let y = model.run(&x)?; + let y = model.forward(&x)?; // annotate let annotator = Annotator::default() .with_polygons_name(true) - .with_saveout("YOLOPv2"); + .with_saveout(model.spec()); annotator.annotate(&x, &y); Ok(()) diff --git a/examples/yolov8-rtdetr.rs b/examples/yolov8-rtdetr.rs new file mode 100644 index 0000000..9b87e71 --- /dev/null +++ b/examples/yolov8-rtdetr.rs @@ -0,0 +1,46 @@ +use anyhow::Result; +use usls::{models::YOLO, Annotator, DataLoader, Options}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + + let args: Args = argh::from_env(); + + // build model + let config = Options::yolo_v8_rtdetr_l() + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) 
+ .commit()?; + let mut model = YOLO::new(config)?; + + // load images + let xs = DataLoader::try_read_batch(&["./assets/bus.jpg"])?; + + // run + let ys = model.forward(&xs)?; + println!("{:?}", ys); + + // annotate + let annotator = Annotator::default() + .with_bboxes_thickness(3) + .with_saveout(model.spec()); + annotator.annotate(&xs, &ys); + + model.summary(); + + Ok(()) +} diff --git a/rust-toolchain.toml b/rust-toolchain.toml deleted file mode 100644 index c6e4d7d..0000000 --- a/rust-toolchain.toml +++ /dev/null @@ -1,2 +0,0 @@ -[toolchain] -channel = "1.79" diff --git a/scripts/CelebAMask-HQ-To-YOLO-Labels.py b/scripts/CelebAMask-HQ-To-YOLO-Labels.py deleted file mode 100644 index 95babb6..0000000 --- a/scripts/CelebAMask-HQ-To-YOLO-Labels.py +++ /dev/null @@ -1,63 +0,0 @@ -import cv2 -import numpy as np -from pathlib import Path -from tqdm import tqdm - - -mapping = { - 'background': 0, - 'skin': 1, - 'nose': 2, - 'eye_g': 3, - 'l_eye': 4, - 'r_eye': 5, - 'l_brow': 6, - 'r_brow': 7, - 'l_ear': 8, - 'r_ear': 9, - 'mouth': 10, - 'u_lip': 11, - 'l_lip': 12, - 'hair': 13, - 'hat': 14, - 'ear_r': 15, - 'neck_l': 16, - 'neck': 17, - 'cloth': 18 -} - - - -def main(): - saveout_dir = Path("labels") - if not saveout_dir.exists(): - saveout_dir.mkdir() - else: - import shutil - shutil.rmtree(saveout_dir) - saveout_dir.mkdir() - - - image_list = [x for x in Path("CelebAMask-HQ-mask-anno/").rglob("*.png")] - for image_path in tqdm(image_list, total=len(image_list)): - image_gray = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE) - stem = image_path.stem - name, cls_ = stem.split("_", 1) - segments = cv2.findContours(image_gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] - - saveout = saveout_dir / f"{int(name)}.txt" - with open(saveout, 'a+') as f: - for segment in segments: - line = f"{mapping[cls_]}" - segment = segment / 512 - for seg in segment: - xn, yn = seg[0] - line += f" {xn} {yn}" - f.write(line + "\n") - - - - -if __name__ == "__main__": - main() - 
diff --git a/scripts/convert2f16.py b/scripts/convert2f16.py deleted file mode 100644 index 6b9eec3..0000000 --- a/scripts/convert2f16.py +++ /dev/null @@ -1,8 +0,0 @@ -import onnx -from pathlib import Path -from onnxconverter_common import float16 - -model_f32 = "onnx_model.onnx" -model_f16 = float16.convert_float_to_float16(onnx.load(model_f32)) -saveout = Path(model_f32).with_name(Path(model_f32).stem + "-f16.onnx") -onnx.save(model_f16, saveout) diff --git a/src/core/device.rs b/src/core/device.rs deleted file mode 100644 index 583df16..0000000 --- a/src/core/device.rs +++ /dev/null @@ -1,14 +0,0 @@ -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub enum Device { - Auto(usize), - Cpu(usize), - Cuda(usize), - Trt(usize), - CoreML(usize), - // Cann(usize), - // Acl(usize), - // Rocm(usize), - // Rknpu(usize), - // Openvino(usize), - // Onednn(usize), -} diff --git a/src/core/metric.rs b/src/core/metric.rs deleted file mode 100644 index af0a5ed..0000000 --- a/src/core/metric.rs +++ /dev/null @@ -1,6 +0,0 @@ -#[derive(Debug)] -pub enum Metric { - IP, - Cos, - L2, -} diff --git a/src/core/options.rs b/src/core/options.rs deleted file mode 100644 index 291e60a..0000000 --- a/src/core/options.rs +++ /dev/null @@ -1,163 +0,0 @@ -//! Options for build models. 
- -use aksr::Builder; -use anyhow::Result; - -use crate::{ - models::{SamKind, SapiensTask, YOLOPreds, YOLOTask, YOLOVersion}, - Device, Hub, Iiix, MinOptMax, Task, -}; - -/// Options for building models -#[derive(Builder, Debug, Clone)] -pub struct Options { - pub onnx_path: String, - pub task: Task, - pub device: Device, - pub batch_size: usize, - pub iiixs: Vec, - pub profile: bool, - pub num_dry_run: usize, - - // trt related - pub trt_engine_cache_enable: bool, - pub trt_int8: bool, - pub trt_fp16: bool, - - // options for Vision and Language models - pub nc: Option, - pub nk: Option, - pub nm: Option, - pub confs: Vec, - pub confs2: Vec, - pub confs3: Vec, - pub kconfs: Vec, - pub iou: Option, - #[args(setter = false)] - pub tokenizer: Option, - #[args(setter = false)] - pub vocab: Option, - pub context_length: Option, - pub names: Option>, // names - pub names2: Option>, // names2 - pub names3: Option>, // names3 - pub min_width: Option, - pub min_height: Option, - pub unclip_ratio: f32, // DB - pub yolo_task: Option, - pub yolo_version: Option, - pub yolo_preds: Option, - pub find_contours: bool, - pub sam_kind: Option, - pub low_res_mask: Option, - pub sapiens_task: Option, - pub classes_excluded: Vec, - pub classes_retained: Vec, -} - -impl Default for Options { - fn default() -> Self { - Self { - onnx_path: String::new(), - device: Device::Cuda(0), - profile: false, - batch_size: 1, - iiixs: vec![], - num_dry_run: 3, - - trt_engine_cache_enable: true, - trt_int8: false, - trt_fp16: false, - nc: None, - nk: None, - nm: None, - confs: vec![0.3f32], - confs2: vec![0.3f32], - confs3: vec![0.3f32], - kconfs: vec![0.5f32], - iou: None, - tokenizer: None, - vocab: None, - context_length: None, - names: None, - names2: None, - names3: None, - min_width: None, - min_height: None, - unclip_ratio: 1.5, - yolo_task: None, - yolo_version: None, - yolo_preds: None, - find_contours: false, - sam_kind: None, - low_res_mask: None, - sapiens_task: None, - task: 
Task::Untitled, - classes_excluded: vec![], - classes_retained: vec![], - } - } -} - -impl Options { - pub fn new() -> Self { - Default::default() - } - - pub fn with_model(mut self, onnx_path: &str) -> Result { - self.onnx_path = Hub::new()?.fetch(onnx_path)?.commit()?; - Ok(self) - } - - pub fn with_batch(mut self, n: usize) -> Self { - self.batch_size = n; - self - } - - pub fn with_cuda(mut self, id: usize) -> Self { - self.device = Device::Cuda(id); - self - } - - pub fn with_trt(mut self, id: usize) -> Self { - self.device = Device::Trt(id); - self - } - - pub fn with_cpu(mut self) -> Self { - self.device = Device::Cpu(0); - self - } - - pub fn with_coreml(mut self, id: usize) -> Self { - self.device = Device::CoreML(id); - self - } - - pub fn with_vocab(mut self, vocab: &str) -> Result { - self.vocab = Some(Hub::new()?.fetch(vocab)?.commit()?); - Ok(self) - } - - pub fn with_tokenizer(mut self, tokenizer: &str) -> Result { - self.tokenizer = Some(Hub::new()?.fetch(tokenizer)?.commit()?); - Ok(self) - } - - pub fn with_ixx(mut self, i: usize, ii: usize, x: MinOptMax) -> Self { - self.iiixs.push(Iiix::from((i, ii, x))); - self - } - - pub fn exclude_classes(mut self, xs: &[isize]) -> Self { - self.classes_retained.clear(); - self.classes_excluded.extend_from_slice(xs); - self - } - - pub fn retain_classes(mut self, xs: &[isize]) -> Self { - self.classes_excluded.clear(); - self.classes_retained.extend_from_slice(xs); - self - } -} diff --git a/src/core/ort_engine.rs b/src/core/ort_engine.rs deleted file mode 100644 index b7a6f09..0000000 --- a/src/core/ort_engine.rs +++ /dev/null @@ -1,666 +0,0 @@ -use anyhow::Result; -use half::f16; -use ndarray::{Array, IxDyn}; -use ort::{ - ExecutionProvider, Session, SessionBuilder, TensorElementType, TensorRTExecutionProvider, -}; -use prost::Message; -use std::collections::HashSet; - -use crate::{ - build_progress_bar, human_bytes, onnx, Device, Dir, MinOptMax, Ops, Options, Ts, Xs, - CHECK_MARK, CROSS_MARK, X, -}; - 
-/// A struct for input composed of the i-th input, the ii-th dimension, and the value. -#[derive(Clone, Debug, Default)] -pub struct Iiix { - pub i: usize, - pub ii: usize, - pub x: MinOptMax, -} - -impl From<(usize, usize, MinOptMax)> for Iiix { - fn from((i, ii, x): (usize, usize, MinOptMax)) -> Self { - Self { i, ii, x } - } -} - -/// A struct for tensor attrs composed of the names, the dtypes, and the dimensions. -#[derive(Debug)] -pub struct OrtTensorAttr { - pub names: Vec, - pub dtypes: Vec, - pub dimss: Vec>, -} - -/// ONNXRuntime Backend -#[derive(Debug)] -pub struct OrtEngine { - name: String, - session: Session, - device: Device, - inputs_minoptmax: Vec>, - inputs_attrs: OrtTensorAttr, - outputs_attrs: OrtTensorAttr, - profile: bool, - num_dry_run: usize, - model_proto: onnx::ModelProto, - params: usize, - wbmems: usize, - ts: Ts, -} - -impl OrtEngine { - pub fn new(config: &Options) -> Result { - let span = tracing::span!(tracing::Level::INFO, "OrtEngine-new"); - let _guard = span.enter(); - - // onnx graph - let model_proto = Self::load_onnx(&config.onnx_path)?; - let graph = match &model_proto.graph { - Some(graph) => graph, - None => anyhow::bail!("No graph found in this proto. 
Failed to parse ONNX model."), - }; - - // model params & mems - let byte_alignment = 16; // 16 for simd; 8 for most - let mut params: usize = 0; - let mut wbmems: usize = 0; - let mut initializer_names: HashSet<&str> = HashSet::new(); - for tensor_proto in graph.initializer.iter() { - initializer_names.insert(&tensor_proto.name); - let param = tensor_proto.dims.iter().product::() as usize; - params += param; - - // mems - let param = Ops::make_divisible(param, byte_alignment); - let n = Self::nbytes_from_onnx_dtype_id(tensor_proto.data_type as usize); - let wbmem = param * n; - wbmems += wbmem; - } - - // inputs & outputs - let inputs_attrs = Self::io_from_onnx_value_info(&initializer_names, &graph.input)?; - let outputs_attrs = Self::io_from_onnx_value_info(&initializer_names, &graph.output)?; - let inputs_minoptmax = - Self::build_inputs_minoptmax(&inputs_attrs, &config.iiixs, config.batch_size)?; - - // build - ort::init().commit()?; - let builder = Session::builder()?; - let mut device = config.device.to_owned(); - match device { - Device::Trt(device_id) => { - Self::build_trt( - &inputs_attrs.names, - &inputs_minoptmax, - &builder, - device_id, - config.trt_int8, - config.trt_fp16, - config.trt_engine_cache_enable, - )?; - } - Device::Cuda(device_id) => { - Self::build_cuda(&builder, device_id).unwrap_or_else(|err| { - tracing::warn!("{err}, Using cpu"); - device = Device::Cpu(0); - }) - } - Device::CoreML(_) => Self::build_coreml(&builder).unwrap_or_else(|err| { - tracing::warn!("{err}, Using cpu"); - device = Device::Cpu(0); - }), - Device::Cpu(_) => { - Self::build_cpu(&builder)?; - } - _ => todo!(), - } - - let session = builder - .with_optimization_level(ort::GraphOptimizationLevel::Level3)? 
- .commit_from_file(&config.onnx_path)?; - - // summary - tracing::info!( - "{CHECK_MARK} Backend: ONNXRuntime | Opset: {} | Device: {:?} | Params: {}", - model_proto.opset_import[0].version, - device, - human_bytes(params as f64), - ); - - Ok(Self { - name: config.onnx_path.to_owned(), - session, - device, - inputs_minoptmax, - inputs_attrs, - outputs_attrs, - profile: config.profile, - num_dry_run: config.num_dry_run, - model_proto, - params, - wbmems, - ts: Ts::default(), - }) - } - - fn build_trt( - names: &[String], - inputs_minoptmax: &[Vec], - builder: &SessionBuilder, - device_id: usize, - int8_enable: bool, - fp16_enable: bool, - engine_cache_enable: bool, - ) -> Result<()> { - let span = tracing::span!(tracing::Level::INFO, "OrtEngine-build_trt"); - let _guard = span.enter(); - - // auto generate shapes - let mut spec_min = String::new(); - let mut spec_opt = String::new(); - let mut spec_max = String::new(); - for (i, name) in names.iter().enumerate() { - if i != 0 { - spec_min.push(','); - spec_opt.push(','); - spec_max.push(','); - } - let mut s_min = format!("{}:", name); - let mut s_opt = format!("{}:", name); - let mut s_max = format!("{}:", name); - for d in inputs_minoptmax[i].iter() { - let min_ = &format!("{}x", d.min()); - let opt_ = &format!("{}x", d.opt()); - let max_ = &format!("{}x", d.max()); - s_min += min_; - s_opt += opt_; - s_max += max_; - } - s_min.pop(); - s_opt.pop(); - s_max.pop(); - spec_min += &s_min; - spec_opt += &s_opt; - spec_max += &s_max; - } - let p = Dir::Cache.path_with_subs(&["trt-cache"])?; - let trt = TensorRTExecutionProvider::default() - .with_device_id(device_id as i32) - .with_int8(int8_enable) - .with_fp16(fp16_enable) - .with_engine_cache(engine_cache_enable) - .with_engine_cache_path(p.to_str().unwrap()) - .with_timing_cache(false) - .with_profile_min_shapes(spec_min) - .with_profile_opt_shapes(spec_opt) - .with_profile_max_shapes(spec_max); - if trt.is_available()? 
&& trt.register(builder).is_ok() { - tracing::info!("🐢 Initial model serialization with TensorRT may require a wait...\n"); - Ok(()) - } else { - anyhow::bail!("{CROSS_MARK} TensorRT initialization failed") - } - } - - fn build_cuda(builder: &SessionBuilder, device_id: usize) -> Result<()> { - let ep = ort::CUDAExecutionProvider::default().with_device_id(device_id as i32); - if ep.is_available()? && ep.register(builder).is_ok() { - Ok(()) - } else { - anyhow::bail!("{CROSS_MARK} CUDA initialization failed") - } - } - - fn build_coreml(builder: &SessionBuilder) -> Result<()> { - let ep = ort::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only(); - if ep.is_available()? && ep.register(builder).is_ok() { - Ok(()) - } else { - anyhow::bail!("{CROSS_MARK} CoreML initialization failed") - } - } - - fn build_cpu(builder: &SessionBuilder) -> Result<()> { - let ep = ort::CPUExecutionProvider::default(); - if ep.is_available()? && ep.register(builder).is_ok() { - Ok(()) - } else { - anyhow::bail!("{CROSS_MARK} CPU initialization failed") - } - } - - pub fn dry_run(&mut self) -> Result<()> { - if self.num_dry_run > 0 { - // pb - let name = std::path::Path::new(&self.name); - let pb = build_progress_bar( - self.num_dry_run as u64, - "DryRun", - Some( - name.file_name() - .and_then(|x| x.to_str()) - .unwrap_or_default(), - ), - crate::PROGRESS_BAR_STYLE_CYAN_2, - )?; - - // dummy inputs - let mut xs = Vec::new(); - for i in self.inputs_minoptmax.iter() { - let mut x: Vec = Vec::new(); - for i_ in i.iter() { - x.push(i_.opt()); - } - let x: Array = Array::ones(x).into_dyn(); - xs.push(X::from(x)); - } - let xs = Xs::from(xs); - - // run - for _ in 0..self.num_dry_run { - pb.inc(1); - self.run(xs.clone())?; - } - self.ts.clear(); - - // update - let name = std::path::Path::new(&self.name); - pb.set_message(format!( - "{} on {:?}", - name.file_name() - .and_then(|x| x.to_str()) - .unwrap_or_default(), - self.device, - )); - 
pb.set_style(indicatif::ProgressStyle::with_template( - crate::PROGRESS_BAR_STYLE_FINISH, - )?); - pb.finish(); - } - Ok(()) - } - - pub fn run(&mut self, xs: Xs) -> Result { - let span = tracing::span!(tracing::Level::INFO, "OrtEngine-run"); - let _guard = span.enter(); - - // inputs dtype alignment - let mut xs_ = Vec::new(); - let t_pre = std::time::Instant::now(); - for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.into_iter()) { - let x_ = match &idtype { - TensorElementType::Float32 => ort::Value::from_array(x.view())?.into_dyn(), - TensorElementType::Float16 => { - ort::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn() - } - TensorElementType::Int32 => { - ort::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn() - } - TensorElementType::Int64 => { - ort::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn() - } - TensorElementType::Uint8 => { - ort::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn() - } - TensorElementType::Int8 => { - ort::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn() - } - TensorElementType::Bool => { - ort::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn() - } - _ => todo!(), - }; - xs_.push(Into::>::into(x_)); - } - let t_pre = t_pre.elapsed(); - self.ts.add_or_push(0, t_pre); - - // inference - let t_run = std::time::Instant::now(); - let outputs = self.session.run(&xs_[..])?; - - let t_run = t_run.elapsed(); - self.ts.add_or_push(1, t_run); - - // oputput - let mut ys = Xs::new(); - let t_post = std::time::Instant::now(); - for (dtype, name) in self - .outputs_attrs - .dtypes - .iter() - .zip(self.outputs_attrs.names.iter()) - { - let y = &outputs[name.as_str()]; - - let y_ = match &dtype { - TensorElementType::Float32 => match y.try_extract_tensor::() { - Err(err) => { - tracing::error!("Error: {:?}. 
Output name: {:?}", err, name); - Array::zeros(0).into_dyn() - } - Ok(x) => x.view().into_owned(), - }, - TensorElementType::Float16 => match y.try_extract_tensor::() { - Err(err) => { - tracing::error!("Error: {:?}. Output name: {:?}", err, name); - Array::zeros(0).into_dyn() - } - Ok(x) => x.view().mapv(f16::to_f32).into_owned(), - }, - TensorElementType::Int64 => match y.try_extract_tensor::() { - Err(err) => { - tracing::error!("Error: {:?}. Output name: {:?}", err, name); - Array::zeros(0).into_dyn() - } - Ok(x) => x.view().to_owned().mapv(|x| x as f32).into_owned(), - }, - _ => todo!(), - }; - - ys.push_kv(name.as_str(), X::from(y_))?; - } - let t_post = t_post.elapsed(); - self.ts.add_or_push(2, t_post); - - if self.profile { - let len = 10usize; - let n = 4usize; - tracing::info!( - "[Profile] {:>len$.n$?} ({:>len$.n$?} avg) [alignment: {:>len$.n$?} ({:>len$.n$?} avg) | inference: {:>len$.n$?} ({:>len$.n$?} avg) | to_f32: {:>len$.n$?} ({:>len$.n$?} avg)]", - t_pre + t_run + t_post, - self.ts.avg(), - t_pre, - self.ts.avgi(0), - t_run, - self.ts.avgi(1), - t_post, - self.ts.avgi(2), - ); - } - Ok(ys) - } - - fn build_inputs_minoptmax( - inputs_attrs: &OrtTensorAttr, - iiixs: &[Iiix], - batch_size: usize, - ) -> Result>> { - let span = tracing::span!(tracing::Level::INFO, "OrtEngine-build_inputs_minoptmax"); - let _guard = span.enter(); - - // init - let mut ys: Vec> = inputs_attrs - .dimss - .iter() - .map(|dims| dims.iter().map(|&x| MinOptMax::from(x)).collect()) - .collect(); - - // update from customized - for iiix in iiixs.iter() { - if let Some(x) = inputs_attrs - .dimss - .get(iiix.i) - .and_then(|dims| dims.get(iiix.ii)) - { - // dynamic - if *x == 0 { - ys[iiix.i][iiix.ii] = iiix.x.clone(); - } - } else { - anyhow::bail!( - "Cannot retrieve the {}-th dimension of the {}-th input.", - iiix.ii, - iiix.i, - ); - } - } - - // deal with the dynamic axis - ys.iter_mut().enumerate().for_each(|(i, xs)| { - xs.iter_mut().enumerate().for_each(|(ii, x)| { - if 
x.is_dyn() { - let n = if ii == 0 { batch_size } else { 1 }; - let y = MinOptMax::from(n); - tracing::warn!( - "Using dynamic shapes in inputs without specifying it: the {}-th input, the {}-th dimension. \ - Using {:?} by default. You should make it clear when using TensorRT.", - i + 1, ii + 1, y - ); - *x = y; - } - }); - }); - - Ok(ys) - } - - #[allow(dead_code)] - fn nbytes_from_onnx_dtype_id(x: usize) -> usize { - match x { - 7 | 11 | 13 => 8, // i64, f64, u64 - 1 | 6 | 12 => 4, // f32, i32, u32 - 10 | 16 | 5 | 4 => 2, // f16, bf16, i16, u16 - 2 | 3 | 9 => 1, // u8, i8, bool - 8 => 4, // string(1~4) - _ => todo!(), - } - } - - #[allow(dead_code)] - fn nbytes_from_onnx_dtype(x: &ort::TensorElementType) -> usize { - match x { - ort::TensorElementType::Float64 - | ort::TensorElementType::Uint64 - | ort::TensorElementType::Int64 => 8, // i64, f64, u64 - ort::TensorElementType::Float32 - | ort::TensorElementType::Uint32 - | ort::TensorElementType::Int32 - | ort::TensorElementType::String => 4, // f32, i32, u32, string(1~4) - ort::TensorElementType::Float16 - | ort::TensorElementType::Bfloat16 - | ort::TensorElementType::Int16 - | ort::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16 - ort::TensorElementType::Uint8 - | ort::TensorElementType::Int8 - | ort::TensorElementType::Bool => 1, // u8, i8, bool - } - } - - #[allow(dead_code)] - fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option { - match value { - 0 => None, - 1 => Some(ort::TensorElementType::Float32), - 2 => Some(ort::TensorElementType::Uint8), - 3 => Some(ort::TensorElementType::Int8), - 4 => Some(ort::TensorElementType::Uint16), - 5 => Some(ort::TensorElementType::Int16), - 6 => Some(ort::TensorElementType::Int32), - 7 => Some(ort::TensorElementType::Int64), - 8 => Some(ort::TensorElementType::String), - 9 => Some(ort::TensorElementType::Bool), - 10 => Some(ort::TensorElementType::Float16), - 11 => Some(ort::TensorElementType::Float64), - 12 => Some(ort::TensorElementType::Uint32), - 13 => 
Some(ort::TensorElementType::Uint64), - 14 => None, // COMPLEX64 - 15 => None, // COMPLEX128 - 16 => Some(ort::TensorElementType::Bfloat16), - _ => None, - } - } - - fn io_from_onnx_value_info( - initializer_names: &HashSet<&str>, - value_info: &[onnx::ValueInfoProto], - ) -> Result { - let mut dimss: Vec> = Vec::new(); - let mut dtypes: Vec = Vec::new(); - let mut names: Vec = Vec::new(); - for v in value_info.iter() { - if initializer_names.contains(v.name.as_str()) { - continue; - } - names.push(v.name.to_string()); - let dtype = match &v.r#type { - Some(dtype) => dtype, - None => continue, - }; - let dtype = match &dtype.value { - Some(dtype) => dtype, - None => continue, - }; - let tensor = match dtype { - onnx::type_proto::Value::TensorType(tensor) => tensor, - _ => continue, - }; - let tensor_type = tensor.elem_type; - let tensor_type = match Self::ort_dtype_from_onnx_dtype_id(tensor_type) { - Some(dtype) => dtype, - None => continue, - }; - dtypes.push(tensor_type); - - let shapes = match &tensor.shape { - Some(shapes) => shapes, - None => continue, - }; - let mut shape_: Vec = Vec::new(); - for shape in shapes.dim.iter() { - match &shape.value { - None => continue, - Some(value) => match value { - onnx::tensor_shape_proto::dimension::Value::DimValue(x) => { - shape_.push(*x as _); - } - onnx::tensor_shape_proto::dimension::Value::DimParam(_) => { - shape_.push(0); - } - }, - } - } - dimss.push(shape_); - } - Ok(OrtTensorAttr { - dimss, - dtypes, - names, - }) - } - - pub fn load_onnx>(p: P) -> Result { - let f = std::fs::read(p)?; - Ok(onnx::ModelProto::decode(f.as_slice())?) 
- } - - pub fn oshapes(&self) -> &Vec> { - &self.outputs_attrs.dimss - } - - pub fn odimss(&self) -> &Vec> { - &self.outputs_attrs.dimss - } - - pub fn onames(&self) -> &Vec { - &self.outputs_attrs.names - } - - pub fn odtypes(&self) -> &Vec { - &self.outputs_attrs.dtypes - } - - pub fn ishapes(&self) -> &Vec> { - &self.inputs_attrs.dimss - } - - pub fn idimss(&self) -> &Vec> { - &self.inputs_attrs.dimss - } - - pub fn inames(&self) -> &Vec { - &self.inputs_attrs.names - } - - pub fn idtypes(&self) -> &Vec { - &self.inputs_attrs.dtypes - } - - pub fn device(&self) -> &Device { - &self.device - } - - pub fn inputs_minoptmax(&self) -> &Vec> { - &self.inputs_minoptmax - } - - pub fn batch(&self) -> &MinOptMax { - &self.inputs_minoptmax[0][0] - } - - pub fn try_height(&self) -> Option<&MinOptMax> { - self.inputs_minoptmax.first().and_then(|x| x.get(2)) - } - - pub fn try_width(&self) -> Option<&MinOptMax> { - self.inputs_minoptmax.first().and_then(|x| x.get(3)) - } - - pub fn height(&self) -> &MinOptMax { - &self.inputs_minoptmax[0][2] - } - - pub fn width(&self) -> &MinOptMax { - &self.inputs_minoptmax[0][3] - } - - pub fn is_batch_dyn(&self) -> bool { - self.ishapes()[0][0] == 0 - } - - pub fn try_fetch(&self, key: &str) -> Option { - match self.session.metadata() { - Err(_) => None, - Ok(metadata) => metadata.custom(key).unwrap_or_default(), - } - } - - pub fn session(&self) -> &Session { - &self.session - } - - pub fn ir_version(&self) -> usize { - self.model_proto.ir_version as usize - } - - pub fn opset_version(&self) -> usize { - self.model_proto.opset_import[0].version as usize - } - - pub fn producer_name(&self) -> String { - self.model_proto.producer_name.to_string() - } - - pub fn producer_version(&self) -> String { - self.model_proto.producer_version.to_string() - } - - pub fn model_version(&self) -> usize { - self.model_proto.model_version as usize - } - - pub fn parameters(&self) -> usize { - self.params - } - - pub fn memory_weights(&self) -> usize { - 
self.wbmems - } - - pub fn ts(&self) -> &Ts { - &self.ts - } -} diff --git a/src/core/tokenizer_stream.rs b/src/core/tokenizer_stream.rs deleted file mode 100644 index 495d69a..0000000 --- a/src/core/tokenizer_stream.rs +++ /dev/null @@ -1,87 +0,0 @@ -// TODO: refactor -use anyhow::Result; - -/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a -/// streaming way rather than having to wait for the full decoding. -#[derive(Debug)] -pub struct TokenizerStream { - tokenizer: tokenizers::Tokenizer, - tokens: Vec, - prev_index: usize, - current_index: usize, -} - -impl TokenizerStream { - pub fn new(tokenizer: tokenizers::Tokenizer) -> Self { - Self { - tokenizer, - tokens: Vec::new(), - prev_index: 0, - current_index: 0, - } - } - - pub fn into_inner(self) -> tokenizers::Tokenizer { - self.tokenizer - } - - fn decode(&self, tokens: &[u32]) -> Result { - match self.tokenizer.decode(tokens, true) { - Ok(str) => Ok(str), - Err(err) => anyhow::bail!("cannot decode: {err}"), - } - } - - pub fn next_token(&mut self, token: u32) -> Result> { - let prev_text = if self.tokens.is_empty() { - String::new() - } else { - let tokens = &self.tokens[self.prev_index..self.current_index]; - self.decode(tokens)? - }; - self.tokens.push(token); - let text = self.decode(&self.tokens[self.prev_index..])?; - if text.len() > prev_text.len() { - let text = text.split_at(prev_text.len()); - self.prev_index = self.current_index; - self.current_index = self.tokens.len(); - Ok(Some(text.1.to_string())) - } else { - Ok(None) - } - } - - pub fn decode_rest(&self) -> Result> { - let prev_text = if self.tokens.is_empty() { - String::new() - } else { - let tokens = &self.tokens[self.prev_index..self.current_index]; - self.decode(tokens)? 
- }; - let text = self.decode(&self.tokens[self.prev_index..])?; - if text.len() > prev_text.len() { - let text = text.split_at(prev_text.len()); - Ok(Some(text.1.to_string())) - } else { - Ok(None) - } - } - - pub fn decode_all(&self) -> Result { - self.decode(&self.tokens) - } - - pub fn get_token(&self, token_s: &str) -> Option { - self.tokenizer.get_vocab(true).get(token_s).copied() - } - - pub fn tokenizer(&self) -> &tokenizers::Tokenizer { - &self.tokenizer - } - - pub fn clear(&mut self) { - self.tokens.clear(); - self.prev_index = 0; - self.current_index = 0; - } -} diff --git a/src/core/ts.rs b/src/core/ts.rs deleted file mode 100644 index dc65ae1..0000000 --- a/src/core/ts.rs +++ /dev/null @@ -1,49 +0,0 @@ -use std::time::Duration; - -#[derive(Debug, Default)] -pub struct Ts { - n: usize, - ts: Vec, -} - -impl Ts { - pub fn total(&self) -> Duration { - self.ts.iter().sum::() - } - - pub fn n(&self) -> usize { - self.n / self.ts.len() - } - - pub fn avg(&self) -> Duration { - self.total() / self.n() as u32 - } - - pub fn avgi(&self, i: usize) -> Duration { - if i >= self.ts.len() { - panic!("Index out of bound"); - } - self.ts[i] / self.n() as u32 - } - - pub fn ts(&self) -> &Vec { - &self.ts - } - - pub fn add_or_push(&mut self, i: usize, x: Duration) { - match self.ts.get_mut(i) { - Some(elem) => *elem += x, - None => { - if i >= self.ts.len() { - self.ts.push(x) - } - } - } - self.n += 1; - } - - pub fn clear(&mut self) { - self.n = Default::default(); - self.ts = Default::default(); - } -} diff --git a/src/core/vision.rs b/src/core/vision.rs deleted file mode 100644 index f78bc18..0000000 --- a/src/core/vision.rs +++ /dev/null @@ -1,51 +0,0 @@ -use crate::{Options, Xs, Y}; - -pub trait Vision: Sized { - type Input; // DynamicImage - - /// Creates a new instance of the model with the given options. - fn new(options: Options) -> anyhow::Result; - - /// Preprocesses the input data. 
- fn preprocess(&self, xs: &[Self::Input]) -> anyhow::Result; - - /// Executes the model on the preprocessed data. - fn inference(&mut self, xs: Xs) -> anyhow::Result; - - /// Postprocesses the model's output. - fn postprocess(&self, xs: Xs, xs0: &[Self::Input]) -> anyhow::Result>; - - /// Executes the full pipeline. - fn run(&mut self, xs: &[Self::Input]) -> anyhow::Result> { - let ys = self.preprocess(xs)?; - let ys = self.inference(ys)?; - let ys = self.postprocess(ys, xs)?; - Ok(ys) - } - - /// Executes the full pipeline. - fn forward(&mut self, xs: &[Self::Input], profile: bool) -> anyhow::Result> { - let span = tracing::span!(tracing::Level::INFO, "Vision-forward"); - let _guard = span.enter(); - - let t_pre = std::time::Instant::now(); - let ys = self.preprocess(xs)?; - let t_pre = t_pre.elapsed(); - - let t_exe = std::time::Instant::now(); - let ys = self.inference(ys)?; - let t_exe = t_exe.elapsed(); - - let t_post = std::time::Instant::now(); - let ys = self.postprocess(ys, xs)?; - let t_post = t_post.elapsed(); - - if profile { - tracing::info!( - "> Preprocess: {t_pre:?} | Execution: {t_exe:?} | Postprocess: {t_post:?}" - ); - } - - Ok(ys) - } -} diff --git a/src/lib.rs b/src/lib.rs index ce9d586..6ffb03f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,254 +5,11 @@ //! - **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569) //! 
- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242) //! -//! # Examples -//! -//! Refer to [All Runnable Demos](https://github.com/jamjamjon/usls/tree/main/examples) -//! -//! # Quick Start -//! -//! The following demo shows how to build a `YOLO` with [`Options`], load `image(s)`, `video` and `stream` with [`DataLoader`], and annotate the model's inference results with [`Annotator`]. -//! -//! ```ignore -//! use usls::{models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion}; -//! -//! fn main() -> anyhow::Result<()> { -//! // Build model with Options -//! let options = Options::new() -//! .with_trt(0) -//! .with_model("yolo/v8-m-dyn.onnx")? -//! .with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR -//! .with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb -//! .with_i00((1, 1, 4).into()) -//! .with_i02((0, 640, 640).into()) -//! .with_i03((0, 640, 640).into()) -//! .with_confs(&[0.2]); -//! let mut model = YOLO::new(options)?; -//! -//! // Build DataLoader to load image(s), video, stream -//! let dl = DataLoader::new( -//! "./assets/bus.jpg", // local image -//! // "images/bus.jpg", // remote image -//! // "../set-negs", // local images (from folder) -//! // "../hall.mp4", // local video -//! // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // remote video -//! // "rtsp://admin:kkasd1234@192.168.2.217:554/h264/ch1/", // stream -//! )? -//! .with_batch(3) // iterate with batch_size = 3 -//! .build()?; -//! -//! // Build annotator -//! let annotator = Annotator::new().with_saveout("YOLO-Demo"); -//! -//! // Run and Annotate images -//! for (xs, _) in dl { -//! let ys = model.forward(&xs, false)?; -//! 
annotator.annotate(&xs, &ys); -//! } -//! -//! Ok(()) -//! } -//! ``` -//! - -//! # What's More -//! -//! This guide covers the process of using provided models for inference, including how to build a model, load data, annotate results, and retrieve the outputs. Click the sections below to expand for detailed instructions. -//! -//!
-//! Build the Model -//! -//! To build a model, you can use the provided [models] with [Options]: -//! -//! ```ignore -//! use usls::{models::YOLO, Annotator, DataLoader, Options, Vision}; -//! -//! let options = Options::default() -//! .with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR -//! .with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb -//! .with_model("xxxx.onnx")?; -//! let mut model = YOLO::new(options)?; -//! ``` -//! -//! **And there're many options provided by [Options]** -//! -//! - **Choose Execution Provider:** -//! Select `CUDA` (default), `TensorRT`, or `CoreML`: -//! -//! ```ignore -//! let options = Options::default() -//! .with_cuda(0) -//! // .with_trt(0) -//! // .with_coreml(0) -//! // .with_cpu(); -//! ``` -//! -//! - **Dynamic Input Shapes:** -//! Specify dynamic shapes with [MinOptMax]: -//! -//! ```ignore -//! let options = Options::default() -//! .with_i00((1, 2, 4).into()) // batch(min=1, opt=2, max=4) -//! .with_i02((416, 640, 800).into()) // height(min=416, opt=640, max=800) -//! .with_i03((416, 640, 800).into()); // width(min=416, opt=640, max=800) -//! ``` -//! -//! - **Set Confidence Thresholds:** -//! Adjust thresholds for each category: -//! -//! ```ignore -//! let options = Options::default() -//! .with_confs(&[0.4, 0.15]); // class_0: 0.4, others: 0.15 -//! ``` -//! -//! - **Set Class Names:** -//! Provide class names if needed: -//! -//! ```ignore -//! let options = Options::default() -//! .with_names(&COCO_CLASS_NAMES_80); -//! ``` -//! -//! **More options are detailed in the [Options] documentation.** -//! -//! -//!
-//! -//!
-//! Load Images, Video and Stream -//! -//! - **Load a Single Image** -//! Use [DataLoader::try_read] to load an image from a local file or remote source: -//! -//! ```ignore -//! let x = DataLoader::try_read("./assets/bus.jpg")?; // from local -//! let x = DataLoader::try_read("images/bus.jpg")?; // from remote -//! ``` -//! -//! Alternatively, use [image::ImageReader] directly: -//! -//! ```ignore -//! let x = image::ImageReader::open("myimage.png")?.decode()?; -//! ``` -//! -//! - **Load Multiple Images, Videos, or Streams** -//! Create a [DataLoader] instance for batch processing: -//! -//! ```ignore -//! let dl = DataLoader::new( -//! "./assets/bus.jpg", // local image -//! // "images/bus.jpg", // remote image -//! // "../set-negs", // local images (from folder) -//! // "../hall.mp4", // local video -//! // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // remote video -//! // "rtsp://admin:kkasd1234@192.168.2.217:554/h264/ch1/", // stream -//! )? -//! .with_batch(3) // iterate with batch_size = 3 -//! .build()?; -//! -//! // Iterate through the data -//! for (xs, _) in dl {} -//! ``` -//! -//! - **Convert Images to Video** -//! Use [DataLoader::is2v] to create a video from a sequence of images: -//! -//! ```ignore -//! let fps = 24; -//! let image_folder = "runs/YOLO-DataLoader"; -//! let saveout = ["runs", "is2v"]; -//! DataLoader::is2v(image_folder, &saveout, fps)?; -//! ``` -//! -//!
-//! -//!
-//! Annotate Inference Results -//! -//! - **Create an Annotator Instance** -//! -//! ```ignore -//! let annotator = Annotator::default(); -//! ``` -//! -//! - **Set Saveout Name:** -//! -//! ```ignore -//! let annotator = Annotator::default() -//! .with_saveout("YOLOs"); -//! ``` -//! -//! - **Set Bounding Box Line Width:** -//! -//! ```ignore -//! let annotator = Annotator::default() -//! .with_bboxes_thickness(4); -//! ``` -//! -//! - **Disable Mask Plotting** -//! -//! ```ignore -//! let annotator = Annotator::default() -//! .without_masks(true); -//! ``` -//! -//! - **Perform Inference and nnotate the results** -//! -//! ```ignore -//! for (xs, _paths) in dl { -//! let ys = model.run(&xs)?; -//! annotator.annotate(&xs, &ys); -//! } -//! ``` -//! -//! More options are detailed in the [Annotator] documentation. -//! -//!
-//! -//!
-//! Retrieve Model's Inference Results -//! -//! Retrieve the inference outputs, which are saved in a [`Vec`]: -//! -//! - **Get Detection Bounding Boxes** -//! -//! ```ignore -//! let ys = model.run(&xs)?; -//! for y in ys { -//! // bboxes -//! if let Some(bboxes) = y.bboxes() { -//! for bbox in bboxes { -//! println!( -//! "Bbox: {}, {}, {}, {}, {}, {}", -//! bbox.xmin(), -//! bbox.ymin(), -//! bbox.xmax(), -//! bbox.ymax(), -//! bbox.confidence(), -//! bbox.id(), -//! ); -//! } -//! } -//! } -//! ``` -//! -//!
-//! -//!
-//! Custom Model Implementation -//! -//! You can also implement your own model using [OrtEngine] and [Options]. [OrtEngine] supports ONNX model loading, metadata parsing, dry_run, inference, and other functions, with execution providers such as CUDA, TensorRT, CoreML, etc. -//! -//! For more details, refer to the [Demo: Depth-Anything](https://github.com/jamjamjon/usls/blob/main/src/models/depth_anything.rs). -//! -//!
-mod core; +mod misc; pub mod models; -mod utils; -mod ys; +mod xy; -pub use core::*; -pub use models::*; -pub use utils::*; -pub use ys::*; +pub use misc::*; +pub use models::{Kind, Options, Processor, ResizeMode, Scale, Task, Version}; +pub use xy::*; diff --git a/src/core/annotator.rs b/src/misc/annotator.rs similarity index 97% rename from src/core/annotator.rs rename to src/misc/annotator.rs index b9f51ee..ae5ddba 100644 --- a/src/core/annotator.rs +++ b/src/misc/annotator.rs @@ -1,11 +1,12 @@ -use crate::{ - string_now, Bbox, Color, ColorMap256, Dir, Hub, Keypoint, Mask, Mbr, Polygon, Prob, CHECK_MARK, - CROSS_MARK, Y, -}; use ab_glyph::{FontArc, PxScale}; use anyhow::Result; use image::{DynamicImage, GenericImage, Rgba, RgbaImage}; use imageproc::map::map_colors; +use log::{error, info}; + +use crate::{ + string_now, Bbox, Color, ColorMap256, Dir, Hub, Keypoint, Mask, Mbr, Polygon, Prob, Y, +}; /// Annotator for struct `Y` #[derive(Clone)] @@ -342,7 +343,7 @@ impl Annotator { } // mkdir even no filename specified - Dir::Currnet.raw_path_with_subs(&subs) + Dir::Current.raw_path_with_subs(&subs) } /// Annotate images, save, and no return @@ -352,9 +353,6 @@ impl Annotator { /// Plot images and return plotted images pub fn plot(&self, imgs: &[DynamicImage], ys: &[Y], save: bool) -> Result> { - let span = tracing::span!(tracing::Level::INFO, "Annotator-plot"); - let _guard = span.enter(); - let mut vs: Vec = Vec::new(); // annotate @@ -405,9 +403,9 @@ impl Annotator { if save { let saveout = self.saveout()?.join(format!("{}.png", string_now("-"))); match img_rgba.save(&saveout) { - Err(err) => tracing::error!("{} Saving failed: {:?}", CROSS_MARK, err), + Err(err) => error!("Saving failed: {:?}", err), Ok(_) => { - tracing::info!("{} Annotated image saved to: {:?}", CHECK_MARK, saveout); + info!("Annotated image saved to: {:?}", saveout); } } } @@ -415,6 +413,7 @@ impl Annotator { // RgbaImage -> DynamicImage vs.push(image::DynamicImage::from(img_rgba)); } + 
Ok(vs) } @@ -761,11 +760,11 @@ impl Annotator { /// Load custom font fn load_font(path: Option<&str>) -> Result { let path_font = match path { - None => Hub::new()?.fetch("fonts/Arial.ttf")?.commit()?, + None => Hub::new()?.try_fetch("fonts/Arial.ttf")?, Some(p) => p.into(), }; - let buffer = std::fs::read(path_font)?; - Ok(FontArc::try_from_vec(buffer.to_owned())?) + let buf = std::fs::read(path_font)?; + Ok(FontArc::try_from_vec(buf.to_owned())?) } /// Color palette diff --git a/src/utils/color.rs b/src/misc/color.rs similarity index 100% rename from src/utils/color.rs rename to src/misc/color.rs diff --git a/src/utils/colormap256.rs b/src/misc/colormap256.rs similarity index 100% rename from src/utils/colormap256.rs rename to src/misc/colormap256.rs diff --git a/src/core/dataloader.rs b/src/misc/dataloader.rs similarity index 83% rename from src/core/dataloader.rs rename to src/misc/dataloader.rs index d96d439..bff5b80 100644 --- a/src/core/dataloader.rs +++ b/src/misc/dataloader.rs @@ -1,18 +1,18 @@ use anyhow::{anyhow, Result}; use image::DynamicImage; -use indicatif::{ProgressBar, ProgressStyle}; +use indicatif::ProgressBar; +use log::{info, warn}; use std::collections::VecDeque; use std::path::{Path, PathBuf}; use std::sync::mpsc; +#[cfg(feature = "ffmpeg")] use video_rs::{ encode::{Encoder, Settings}, time::Time, Decoder, Url, }; -use crate::{ - build_progress_bar, string_now, Dir, Hub, Location, MediaType, CHECK_MARK, CROSS_MARK, -}; +use crate::{build_progress_bar, Hub, Location, MediaType}; type TempReturnType = (Vec, Vec); @@ -37,9 +37,7 @@ impl Iterator for DataLoaderIterator { None => { progress_bar.set_prefix("Iterated"); progress_bar.set_style( - match indicatif::ProgressStyle::with_template( - crate::PROGRESS_BAR_STYLE_FINISH_2, - ) { + match indicatif::ProgressStyle::with_template(crate::PROGRESS_BAR_STYLE_FINISH_2) { Ok(x) => x, Err(err) => panic!("Failed to set style for progressbar in `DataLoaderIterator`: {}", err), }, @@ -99,6 +97,7 @@ pub 
struct DataLoader { receiver: mpsc::Receiver, /// Video decoder for handling video or stream data. + #[cfg(feature = "ffmpeg")] decoder: Option, /// Number of images or frames; `u64::MAX` is used for live streams (indicating no limit). @@ -108,10 +107,18 @@ pub struct DataLoader { with_pb: bool, } +impl TryFrom<&str> for DataLoader { + type Error = anyhow::Error; + + fn try_from(str: &str) -> Result { + Self::new(str) + } +} + impl DataLoader { pub fn new(source: &str) -> Result { - let span = tracing::span!(tracing::Level::INFO, "DataLoader-new"); - let _guard = span.enter(); + // TODO: multi-types + // Vec<&str> // Number of frames or stream let mut nf = 0; @@ -153,6 +160,21 @@ impl DataLoader { } // video decoder + #[cfg(not(feature = "ffmpeg"))] + { + match &media_type { + MediaType::Video(Location::Local) + | MediaType::Video(Location::Remote) + | MediaType::Stream => { + anyhow::bail!( + "Video processing requires the features: `ffmpeg`. \ + \nConsider enabling them by passing, e.g., `--features ffmpeg`" + ); + } + _ => {} + }; + } + #[cfg(feature = "ffmpeg")] let decoder = match &media_type { MediaType::Video(Location::Local) => Some(Decoder::new(source_path)?), MediaType::Video(Location::Remote) | MediaType::Stream => { @@ -163,6 +185,7 @@ impl DataLoader { }; // video & stream frames + #[cfg(feature = "ffmpeg")] if let Some(decoder) = &decoder { nf = match decoder.frames() { Err(_) => u64::MAX, @@ -172,7 +195,7 @@ impl DataLoader { } // summary - tracing::info!("{} Found {:?} x{}", CHECK_MARK, media_type, nf); + info!("Found {:?} x{}", media_type, nf); Ok(DataLoader { paths, @@ -180,6 +203,7 @@ impl DataLoader { bound: 50, receiver: mpsc::sync_channel(1).1, batch_size: 1, + #[cfg(feature = "ffmpeg")] decoder, nf, with_pb: true, @@ -196,6 +220,11 @@ impl DataLoader { self } + pub fn with_batch_size(mut self, x: usize) -> Self { + self.batch_size = x; + self + } + pub fn with_progress_bar(mut self, x: bool) -> Self { self.with_pb = x; self @@ -207,11 +236,19 
@@ impl DataLoader { let batch_size = self.batch_size; let data = self.paths.take().unwrap_or_default(); let media_type = self.media_type.clone(); + #[cfg(feature = "ffmpeg")] let decoder = self.decoder.take(); // Spawn the producer thread std::thread::spawn(move || { - DataLoader::producer_thread(sender, data, batch_size, media_type, decoder); + DataLoader::producer_thread( + sender, + data, + batch_size, + media_type, + #[cfg(feature = "ffmpeg")] + decoder, + ); }); Ok(self) @@ -222,10 +259,8 @@ impl DataLoader { mut data: VecDeque, batch_size: usize, media_type: MediaType, - mut decoder: Option, + #[cfg(feature = "ffmpeg")] mut decoder: Option, ) { - let span = tracing::span!(tracing::Level::INFO, "DataLoader-producer-thread"); - let _guard = span.enter(); let mut yis: Vec = Vec::with_capacity(batch_size); let mut yps: Vec = Vec::with_capacity(batch_size); @@ -234,7 +269,7 @@ impl DataLoader { while let Some(path) = data.pop_front() { match Self::try_read(&path) { Err(err) => { - tracing::warn!("{} {:?} | {:?}", CROSS_MARK, path, err); + warn!("{:?} | {:?}", path, err); continue; } Ok(img) => { @@ -251,6 +286,7 @@ impl DataLoader { } } } + #[cfg(feature = "ffmpeg")] MediaType::Video(_) | MediaType::Stream => { if let Some(decoder) = decoder.as_mut() { let (w, h) = decoder.size(); @@ -285,12 +321,12 @@ impl DataLoader { } } } - _ => todo!(), + _ => unimplemented!(), } // Deal with remaining data if !yis.is_empty() && sender.send((yis, yps)).is_err() { - tracing::info!("Receiver dropped, stopping production"); + info!("Receiver dropped, stopping production"); } } @@ -325,13 +361,30 @@ impl DataLoader { // try to fetch from hub or local cache if !path.exists() { - let p = Hub::new()?.fetch(path.to_str().unwrap())?.commit()?; + let p = Hub::new()?.try_fetch(path.to_str().unwrap())?; path = PathBuf::from(&p); } let img = Self::read_into_rgb8(path)?; Ok(DynamicImage::from(img)) } + pub fn try_read_batch + std::fmt::Debug>( + paths: &[P], + ) -> Result> { + let images 
= paths + .iter() + .filter_map(|path| match Self::try_read(path) { + Ok(img) => Some(img), + Err(err) => { + warn!("Failed to read from: {:?}. Error: {:?}", path, err); + None + } + }) + .collect(); + + Ok(images) + } + fn read_into_rgb8>(path: P) -> Result { let path = path.as_ref(); let img = image::ImageReader::open(path) @@ -363,6 +416,7 @@ impl DataLoader { } /// Convert images into a video + #[cfg(feature = "ffmpeg")] pub fn is2v>(source: P, subs: &[&str], fps: usize) -> Result<()> { let paths = Self::load_from_folder(source.as_ref())?; if paths.is_empty() { @@ -370,9 +424,9 @@ impl DataLoader { } let mut encoder = None; let mut position = Time::zero(); - let saveout = Dir::Currnet + let saveout = crate::Dir::Current .raw_path_with_subs(subs)? - .join(format!("{}.mp4", string_now("-"))); + .join(format!("{}.mp4", crate::string_now("-"))); let pb = build_progress_bar( paths.len() as u64, "Converting", @@ -412,7 +466,7 @@ impl DataLoader { // update pb.set_prefix("Converted"); pb.set_message(saveout.to_str().unwrap_or_default().to_string()); - pb.set_style(ProgressStyle::with_template( + pb.set_style(indicatif::ProgressStyle::with_template( crate::PROGRESS_BAR_STYLE_FINISH_4, )?); pb.finish(); diff --git a/src/misc/device.rs b/src/misc/device.rs new file mode 100644 index 0000000..e1029e1 --- /dev/null +++ b/src/misc/device.rs @@ -0,0 +1,63 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum Device { + Auto(usize), + Cpu(usize), + Cuda(usize), + TensorRT(usize), + CoreML(usize), + // Cann(usize), + // Acl(usize), + // Rocm(usize), + // Rknpu(usize), + // Openvino(usize), + // Onednn(usize), +} + +impl std::fmt::Display for Device { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let x = match self { + Self::Auto(i) => format!("auto:{}", i), + Self::Cpu(i) => format!("cpu:{}", i), + Self::Cuda(i) => format!("cuda:{}", i), + Self::TensorRT(i) => format!("tensorrt:{}", i), + Self::CoreML(i) => format!("mps:{}", i), + 
}; + write!(f, "{}", x) + } +} + +impl TryFrom<&str> for Device { + type Error = anyhow::Error; + + fn try_from(s: &str) -> Result { + // device and its id + let d_id: Vec<&str> = s.trim().split(':').collect(); + let (d, id) = match d_id.len() { + 1 => (d_id[0], 0), + 2 => (d_id[0], d_id[1].parse::().unwrap_or(0)), + _ => anyhow::bail!( + "Fail to parse device string: {s}. Expect: `device:device_id` or `device`. e.g. `cuda:0` or `cuda`" + ), + }; + // TODO: device-id checking + match d.to_lowercase().as_str() { + "cpu" => Ok(Self::Cpu(id)), + "cuda" => Ok(Self::Cuda(id)), + "trt" | "tensorrt" => Ok(Self::TensorRT(id)), + "coreml" | "mps" => Ok(Self::CoreML(id)), + _ => anyhow::bail!("Unsupported device str: {s:?}."), + } + } +} + +impl Device { + pub fn id(&self) -> usize { + match self { + Device::Auto(i) => *i, + Device::Cpu(i) => *i, + Device::Cuda(i) => *i, + Device::TensorRT(i) => *i, + Device::CoreML(i) => *i, + } + } +} diff --git a/src/core/dir.rs b/src/misc/dir.rs similarity index 97% rename from src/core/dir.rs rename to src/misc/dir.rs index 7c36c56..9b260ff 100644 --- a/src/core/dir.rs +++ b/src/misc/dir.rs @@ -4,7 +4,7 @@ pub enum Dir { Home, Cache, Config, - Currnet, + Current, Document, Data, Download, @@ -15,7 +15,7 @@ pub enum Dir { impl Dir { pub fn saveout(subs: &[&str]) -> anyhow::Result { - Self::Currnet.raw_path_with_subs(subs) + Self::Current.raw_path_with_subs(subs) } /// Retrieves the base path for the specified directory type, optionally appending the `usls` subdirectory. 
@@ -30,7 +30,7 @@ impl Dir { Dir::Home => dirs::home_dir(), Dir::Cache => dirs::cache_dir(), Dir::Config => dirs::config_dir(), - Dir::Currnet => std::env::current_dir().ok(), + Dir::Current => std::env::current_dir().ok(), _ => None, }; diff --git a/src/misc/dtype.rs b/src/misc/dtype.rs new file mode 100644 index 0000000..81f0d50 --- /dev/null +++ b/src/misc/dtype.rs @@ -0,0 +1,114 @@ +use ort::tensor::TensorElementType; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum DType { + Auto, + Int8, + Int16, + Int32, + Int64, + Uint8, + Uint16, + Uint32, + Uint64, + Fp16, + Fp32, + Fp64, + Bf16, + Bool, + String, + Bnb4, + Q4, + Q4f16, +} + +impl TryFrom<&str> for DType { + type Error = anyhow::Error; + + fn try_from(s: &str) -> Result { + match s.to_lowercase().as_str() { + "auto" | "dyn" => Ok(Self::Auto), + "u8" | "uint8" => Ok(Self::Uint8), + "u16" | "uint16" => Ok(Self::Uint16), + "u32" | "uint32" => Ok(Self::Uint32), + "u64" | "uint64" => Ok(Self::Uint64), + "i8" | "int8" => Ok(Self::Int8), + "i16" | "int=16" => Ok(Self::Int16), + "i32" | "int32" => Ok(Self::Int32), + "i64" | "int64" => Ok(Self::Int64), + "f16" | "fp16" => Ok(Self::Fp16), + "f32" | "fp32" => Ok(Self::Fp32), + "f64" | "fp64" => Ok(Self::Fp64), + "b16" | "bf16" => Ok(Self::Bf16), + "q4f16" => Ok(Self::Q4f16), + "q4" => Ok(Self::Q4), + "bnb4" => Ok(Self::Bnb4), + x => anyhow::bail!("Unsupported Model DType: {}", x), + } + } +} + +impl std::fmt::Display for DType { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let x = match self { + Self::Auto => "auto", + Self::Int8 => "int8", + Self::Int16 => "int16", + Self::Int32 => "int32", + Self::Int64 => "int64", + Self::Uint8 => "uint8", + Self::Uint16 => "uint16", + Self::Uint32 => "uint32", + Self::Uint64 => "uint64", + Self::Fp16 => "fp16", + Self::Fp32 => "fp32", + Self::Fp64 => "fp64", + Self::Bf16 => "bf16", + Self::String => "string", + Self::Bool => "bool", + Self::Bnb4 => "bnb4", + Self::Q4 => "q4", + 
Self::Q4f16 => "q4f16", + }; + write!(f, "{}", x) + } +} + +impl DType { + pub fn to_ort(&self) -> TensorElementType { + match self { + Self::Int8 => TensorElementType::Int8, + Self::Int16 => TensorElementType::Int16, + Self::Int32 => TensorElementType::Int32, + Self::Int64 => TensorElementType::Int64, + Self::Uint8 => TensorElementType::Uint8, + Self::Uint16 => TensorElementType::Uint16, + Self::Uint32 => TensorElementType::Uint32, + Self::Uint64 => TensorElementType::Uint64, + Self::Fp16 => TensorElementType::Float16, + Self::Fp32 => TensorElementType::Float32, + Self::Fp64 => TensorElementType::Float64, + Self::Bf16 => TensorElementType::Bfloat16, + _ => todo!(), + } + } + + pub fn from_ort(dtype: &TensorElementType) -> Self { + match dtype { + TensorElementType::Int8 => Self::Int8, + TensorElementType::Int16 => Self::Int16, + TensorElementType::Int32 => Self::Int32, + TensorElementType::Int64 => Self::Int64, + TensorElementType::Uint8 => Self::Uint8, + TensorElementType::Uint16 => Self::Uint16, + TensorElementType::Uint32 => Self::Uint32, + TensorElementType::Uint64 => Self::Uint64, + TensorElementType::Float16 => Self::Fp16, + TensorElementType::Float32 => Self::Fp32, + TensorElementType::Float64 => Self::Fp64, + TensorElementType::Bfloat16 => Self::Bf16, + TensorElementType::String => Self::String, + TensorElementType::Bool => Self::Bool, + } + } +} diff --git a/src/core/dynconf.rs b/src/misc/dynconf.rs similarity index 100% rename from src/core/dynconf.rs rename to src/misc/dynconf.rs diff --git a/src/misc/engine.rs b/src/misc/engine.rs new file mode 100644 index 0000000..88ccf10 --- /dev/null +++ b/src/misc/engine.rs @@ -0,0 +1,743 @@ +use aksr::Builder; +use anyhow::Result; +use half::{bf16, f16}; +use log::{error, info, warn}; +use ndarray::{Array, IxDyn}; +#[allow(unused_imports)] +use ort::{ + execution_providers::ExecutionProvider, + session::{ + builder::{GraphOptimizationLevel, SessionBuilder}, + Session, SessionInputValue, + }, + 
tensor::TensorElementType, + value::{DynValue, Value}, +}; +use prost::Message; +use std::collections::HashSet; + +use crate::{ + build_progress_bar, elapsed, human_bytes, onnx, DType, Device, Iiix, MinOptMax, Ops, Ts, Xs, X, +}; + +/// A struct for tensor attrs composed of the names, the dtypes, and the dimensions. +#[derive(Builder, Debug, Clone)] +pub struct OrtTensorAttr { + pub names: Vec, + pub dtypes: Vec, + pub dimss: Vec>, +} + +#[derive(Debug)] +pub struct OnnxIo { + pub inputs: OrtTensorAttr, + pub outputs: OrtTensorAttr, + pub session: Session, + pub proto: onnx::ModelProto, +} + +#[derive(Debug, Builder)] +pub struct Engine { + pub file: String, + pub spec: String, + pub device: Device, + pub trt_fp16: bool, + #[args(inc = true)] + pub iiixs: Vec, + #[args(alias = "parameters")] + pub params: Option, + #[args(alias = "memory")] + pub wbmems: Option, + pub inputs_minoptmax: Vec>, + pub onnx: Option, + pub ts: Ts, + pub num_dry_run: usize, +} + +impl Default for Engine { + fn default() -> Self { + Self { + file: Default::default(), + device: Device::Cpu(0), + trt_fp16: false, + spec: Default::default(), + iiixs: Default::default(), + num_dry_run: 3, + params: None, + wbmems: None, + inputs_minoptmax: vec![], + onnx: None, + ts: Ts::default(), + } + } +} + +impl Engine { + pub fn build(mut self) -> Result { + let name = format!("[{}] ort_initialization", self.spec); + elapsed!(&name, self.ts, { + let proto = Self::load_onnx(self.file())?; + let graph = match &proto.graph { + Some(graph) => graph, + None => { + anyhow::bail!( + "No graph found in this proto. 
Invalid ONNX model: {}", + self.file() + ) + } + }; + + // params & mems + let byte_alignment = 16; + let mut params: usize = 0; + let mut wbmems: usize = 0; + let mut initializer_names: HashSet<&str> = HashSet::new(); + if !graph.initializer.is_empty() { + // from initializer + for tensor_proto in graph.initializer.iter() { + initializer_names.insert(&tensor_proto.name); + let param = tensor_proto.dims.iter().product::() as usize; + params += param; + let param = Ops::make_divisible(param, byte_alignment); + let n = Self::nbytes_from_onnx_dtype_id(tensor_proto.data_type as usize); + let wbmem = param * n; + wbmems += wbmem; + } + } else { + // from node, workaround + for node in &graph.node { + for attr in &node.attribute { + if let Some(tensor) = &attr.t { + let param = tensor.dims.iter().product::() as usize; + params += param; + let param = Ops::make_divisible(param, byte_alignment); + let n = Self::nbytes_from_onnx_dtype_id(tensor.data_type as usize); + let wbmem = param * n; + wbmems += wbmem; + } + } + } + } + self.params = Some(params); + self.wbmems = Some(wbmems); + + // inputs & outputs + let inputs = Self::io_from_onnx_value_info(&initializer_names, &graph.input)?; + let outputs = Self::io_from_onnx_value_info(&initializer_names, &graph.output)?; + self.inputs_minoptmax = Self::build_ort_inputs(&inputs, self.iiixs())?; + + // session + ort::init().commit()?; + let session = self.build_session(&inputs)?; + + // onnxio + self.onnx = Some(OnnxIo { + inputs, + outputs, + proto, + session, + }); + }); + self.dry_run()?; + self.info(); + + Ok(self) + } + + pub fn dry_run(&mut self) -> Result<()> { + if self.num_dry_run > 0 { + // pb + let pb = build_progress_bar( + self.num_dry_run as u64, + "DryRun", + Some(self.spec()), + crate::PROGRESS_BAR_STYLE_CYAN_2, + )?; + + // dummy + let mut xs = Vec::new(); + for i in self.inputs_minoptmax().iter() { + let mut x: Vec = Vec::new(); + for i_ in i.iter() { + x.push(i_.opt()); + } + let x: Array = 
Array::ones(x).into_dyn(); + xs.push(X::from(x)); + } + let xs = Xs::from(xs); + + // run + for i in 0..self.num_dry_run { + pb.inc(1); + let name = format!("[{}] ort_dry_run_{}", self.spec, i); + elapsed!(&name, self.ts, { + self.run(xs.clone())?; + }); + } + + // update + pb.set_message(format!( + "{}(Params: {}) on {:?}", + self.spec, + match self.params { + Some(bytes) if bytes != 0 => { + human_bytes(bytes as f64, true) + } + _ => "Unknown".to_string(), + }, + self.device, + )); + pb.set_style(indicatif::ProgressStyle::with_template( + crate::PROGRESS_BAR_STYLE_FINISH, + )?); + pb.finish(); + } + Ok(()) + } + + pub fn run(&mut self, xs: Xs) -> Result { + let mut ys = xs.derive(); + if let Some(onnx) = &self.onnx { + // alignment + let xs_ = elapsed!(&format!("[{}] ort_preprocessing", self.spec), self.ts, { + let mut xs_ = Vec::new(); + for (dtype, x) in onnx.inputs.dtypes.iter().zip(xs.into_iter()) { + xs_.push(Into::>::into(Self::preprocess( + x, dtype, + )?)); + } + xs_ + }); + + // run + let outputs = elapsed!( + &format!("[{}] ort_inference", self.spec), + self.ts, + onnx.session.run(&xs_[..])? + ); + + // extract + elapsed!(&format!("[{}] ort_postprocessing", self.spec), self.ts, { + for (dtype, name) in onnx.outputs.dtypes.iter().zip(onnx.outputs.names.iter()) { + let y = Self::postprocess(&outputs[name.as_str()], dtype)?; + ys.push_kv(name.as_str(), X::from(y))?; + } + }); + Ok(ys) + } else { + anyhow::bail!("Failed to run with ONNXRuntime. 
No model info found."); + } + } + + fn preprocess(x: &X, dtype: &TensorElementType) -> Result { + let x = match dtype { + TensorElementType::Float32 => Value::from_array(x.view())?.into_dyn(), + TensorElementType::Float16 => { + Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn() + } + TensorElementType::Float64 => Value::from_array(x.view())?.into_dyn(), + TensorElementType::Bfloat16 => { + Value::from_array(x.mapv(bf16::from_f32).view())?.into_dyn() + } + TensorElementType::Int8 => Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn(), + TensorElementType::Int16 => { + Value::from_array(x.mapv(|x_| x_ as i16).view())?.into_dyn() + } + TensorElementType::Int32 => { + Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn() + } + TensorElementType::Int64 => { + Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn() + } + TensorElementType::Uint8 => Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn(), + TensorElementType::Uint16 => { + Value::from_array(x.mapv(|x_| x_ as u16).view())?.into_dyn() + } + TensorElementType::Uint32 => { + Value::from_array(x.mapv(|x_| x_ as u32).view())?.into_dyn() + } + TensorElementType::Uint64 => { + Value::from_array(x.mapv(|x_| x_ as u64).view())?.into_dyn() + } + TensorElementType::Bool => Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn(), + _ => unimplemented!(), + }; + + Ok(x) + } + + fn postprocess(x: &DynValue, dtype: &TensorElementType) -> Result> { + fn _extract_and_convert(x: &DynValue, map_fn: impl Fn(T) -> f32) -> Array + where + T: Clone + 'static + ort::tensor::PrimitiveTensorElementType, + { + match x.try_extract_tensor::() { + Err(err) => { + error!("Failed to extract from ort outputs: {:?}", err); + Array::zeros(0).into_dyn() + } + Ok(x) => x.view().mapv(map_fn).into_owned(), + } + } + let x = match dtype { + TensorElementType::Float32 => _extract_and_convert::(x, |x| x), + TensorElementType::Float16 => _extract_and_convert::(x, f16::to_f32), + TensorElementType::Bfloat16 => 
_extract_and_convert::(x, bf16::to_f32), + TensorElementType::Float64 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Int64 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Int32 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Int16 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Int8 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Uint64 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Uint32 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Uint16 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Uint8 => _extract_and_convert::(x, |x| x as f32), + TensorElementType::Bool => _extract_and_convert::(x, |x| x as u8 as f32), + _ => return Err(anyhow::anyhow!("Unsupported ort tensor type: {:?}", dtype)), + }; + + Ok(x) + } + + #[allow(unused_variables)] + fn build_session(&mut self, inputs: &OrtTensorAttr) -> Result { + #[allow(unused_mut)] + let mut builder = Session::builder()?; + let compile_help = "Please compile ONNXRuntime with #EP"; + let feature_help = "#EP EP requires the features: `#FEATURE`. 
\ + \nConsider enabling them by passing, e.g., `--features #FEATURE`"; + + match self.device { + Device::TensorRT(id) => { + #[cfg(not(feature = "trt"))] + { + anyhow::bail!(feature_help + .replace("#EP", "TensorRT") + .replace("#FEATURE", "trt")); + } + + #[cfg(feature = "trt")] + { + // generate shapes + let mut spec_min = String::new(); + let mut spec_opt = String::new(); + let mut spec_max = String::new(); + for (i, name) in inputs.names.iter().enumerate() { + if i != 0 { + spec_min.push(','); + spec_opt.push(','); + spec_max.push(','); + } + let mut s_min = format!("{}:", name); + let mut s_opt = format!("{}:", name); + let mut s_max = format!("{}:", name); + for d in self.inputs_minoptmax[i].iter() { + let min_ = &format!("{}x", d.min()); + let opt_ = &format!("{}x", d.opt()); + let max_ = &format!("{}x", d.max()); + s_min += min_; + s_opt += opt_; + s_max += max_; + } + s_min.pop(); + s_opt.pop(); + s_max.pop(); + spec_min += &s_min; + spec_opt += &s_opt; + spec_max += &s_max; + } + + let p = crate::Dir::Cache.path_with_subs(&["trt-cache"])?; + let ep = ort::execution_providers::TensorRTExecutionProvider::default() + .with_device_id(id as i32) + .with_fp16(self.trt_fp16) + .with_engine_cache(true) + .with_engine_cache_path(p.to_str().unwrap()) + .with_timing_cache(false) + .with_profile_min_shapes(spec_min) + .with_profile_opt_shapes(spec_opt) + .with_profile_max_shapes(spec_max); + + match ep.is_available() { + Ok(true) => { + info!( + "Initial model serialization with TensorRT may require a wait..." 
+ ); + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register TensorRT: {}", err) + })?; + } + _ => { + anyhow::bail!(compile_help.replace("#EP", "TensorRT")) + } + } + } + } + Device::Cuda(id) => { + #[cfg(not(feature = "cuda"))] + { + anyhow::bail!(feature_help + .replace("#EP", "CUDA") + .replace("#FEATURE", "cuda")); + } + + #[cfg(feature = "cuda")] + { + let ep = ort::execution_providers::CUDAExecutionProvider::default() + .with_device_id(id as i32); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register CUDA: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "CUDA")), + } + } + } + Device::CoreML(id) => { + #[cfg(not(feature = "mps"))] + { + anyhow::bail!(feature_help + .replace("#EP", "CoreML") + .replace("#FEATURE", "mps")); + } + #[cfg(feature = "mps")] + { + let ep = ort::execution_providers::CoreMLExecutionProvider::default(); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register CoreML: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "CoreML")), + } + } + } + _ => { + let ep = ort::execution_providers::CPUExecutionProvider::default(); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder) + .map_err(|err| anyhow::anyhow!("Failed to register Cpu: {}", err))?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "Cpu")), + } + } + } + + // session + let session = builder + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_intra_threads(std::thread::available_parallelism()?.get())? 
+ .commit_from_file(self.file())?; + + Ok(session) + } + + fn build_ort_inputs(xs: &OrtTensorAttr, iiixs: &[Iiix]) -> Result>> { + // init + let mut ys: Vec> = xs + .dimss + .iter() + .map(|dims| dims.iter().map(|&x| MinOptMax::from(x)).collect()) + .collect(); + + // update from customized + for iiix in iiixs.iter() { + if let Some(x) = xs.dimss.get(iiix.i).and_then(|dims| dims.get(iiix.ii)) { + // dynamic + if *x == 0 { + ys[iiix.i][iiix.ii] = iiix.x.clone(); + } + } else { + anyhow::bail!( + "Cannot retrieve the {}-th dimension of the {}-th input.", + iiix.ii, + iiix.i, + ); + } + } + + // set batch size <- i00 + let batch_size: MinOptMax = if ys[0][0].is_dyn() { + 1.into() + } else { + ys[0][0].clone() + }; + + // deal with the dynamic axis + ys.iter_mut().enumerate().for_each(|(i, xs)| { + xs.iter_mut().enumerate().for_each(|(ii, x)| { + if x.is_dyn() { + let z = if ii == 0 { + batch_size.clone() + } else { + let z = MinOptMax::from(1); + warn!( + "Using dynamic shapes in inputs without specifying it: the {}-th input, the {}-th dimension. \ + Using {:?} by default. 
You should make it clear when using TensorRT.", + i + 1, ii + 1, z + ); + z + }; + *x = z; + } + }); + }); + + Ok(ys) + } + + #[allow(dead_code)] + fn nbytes_from_onnx_dtype_id(x: usize) -> usize { + match x { + 7 | 11 | 13 => 8, // i64, f64, u64 + 1 | 6 | 12 => 4, // f32, i32, u32 + 10 | 16 | 5 | 4 => 2, // f16, bf16, i16, u16 + 2 | 3 | 9 => 1, // u8, i8, bool + 8 => 4, // string(1~4) + _ => 1, // TODO: others + } + } + + #[allow(dead_code)] + fn nbytes_from_onnx_dtype(x: &TensorElementType) -> usize { + match x { + TensorElementType::Float64 | TensorElementType::Uint64 | TensorElementType::Int64 => 8, // i64, f64, u64 + TensorElementType::Float32 + | TensorElementType::Uint32 + | TensorElementType::Int32 + | TensorElementType::String => 4, // f32, i32, u32, string(1~4) + TensorElementType::Float16 + | TensorElementType::Bfloat16 + | TensorElementType::Int16 + | TensorElementType::Uint16 => 2, // f16, bf16, i16, u16 + TensorElementType::Uint8 | TensorElementType::Int8 | TensorElementType::Bool => 1, // u8, i8, bool + } + } + + #[allow(dead_code)] + fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option { + match value { + 0 => None, + 1 => Some(TensorElementType::Float32), + 2 => Some(TensorElementType::Uint8), + 3 => Some(TensorElementType::Int8), + 4 => Some(TensorElementType::Uint16), + 5 => Some(TensorElementType::Int16), + 6 => Some(TensorElementType::Int32), + 7 => Some(TensorElementType::Int64), + 8 => Some(TensorElementType::String), + 9 => Some(TensorElementType::Bool), + 10 => Some(TensorElementType::Float16), + 11 => Some(TensorElementType::Float64), + 12 => Some(TensorElementType::Uint32), + 13 => Some(TensorElementType::Uint64), + 14 => None, // COMPLEX64 + 15 => None, // COMPLEX128 + 16 => Some(TensorElementType::Bfloat16), + _ => None, + } + } + + fn io_from_onnx_value_info( + initializer_names: &HashSet<&str>, + value_info: &[onnx::ValueInfoProto], + ) -> Result { + let mut dimss: Vec> = Vec::new(); + let mut dtypes: Vec = Vec::new(); + let mut 
names: Vec = Vec::new(); + for v in value_info.iter() { + if initializer_names.contains(v.name.as_str()) { + continue; + } + names.push(v.name.to_string()); + let dtype = match &v.r#type { + Some(dtype) => dtype, + None => continue, + }; + let dtype = match &dtype.value { + Some(dtype) => dtype, + None => continue, + }; + let tensor = match dtype { + onnx::type_proto::Value::TensorType(tensor) => tensor, + _ => continue, + }; + let tensor_type = tensor.elem_type; + let tensor_type = match Self::ort_dtype_from_onnx_dtype_id(tensor_type) { + Some(dtype) => dtype, + None => continue, + }; + dtypes.push(tensor_type); + + let shapes = match &tensor.shape { + Some(shapes) => shapes, + None => continue, + }; + let mut shape_: Vec = Vec::new(); + for shape in shapes.dim.iter() { + match &shape.value { + None => continue, + Some(value) => match value { + onnx::tensor_shape_proto::dimension::Value::DimValue(x) => { + shape_.push(*x as _); + } + onnx::tensor_shape_proto::dimension::Value::DimParam(_) => { + shape_.push(0); + } + }, + } + } + dimss.push(shape_); + } + Ok(OrtTensorAttr { + dimss, + dtypes, + names, + }) + } + + pub fn load_onnx>(p: P) -> Result { + let f = std::fs::read(p)?; + Ok(onnx::ModelProto::decode(f.as_slice())?) 
+ } + + pub fn batch(&self) -> &MinOptMax { + &self.inputs_minoptmax[0][0] + } + + pub fn is_batch_dyn(&self) -> bool { + self.batch().is_dyn() + } + + pub fn try_height(&self) -> Option<&MinOptMax> { + self.inputs_minoptmax.first().and_then(|x| x.get(2)) + } + + pub fn height(&self) -> &MinOptMax { + // unsafe + &self.inputs_minoptmax[0][2] + } + + pub fn is_height_dyn(&self) -> bool { + self.height().is_dyn() + } + + pub fn try_width(&self) -> Option<&MinOptMax> { + self.inputs_minoptmax.first().and_then(|x| x.get(3)) + } + + pub fn width(&self) -> &MinOptMax { + // unsafe + &self.inputs_minoptmax[0][3] + } + + pub fn is_width_dyn(&self) -> bool { + self.width().is_dyn() + } + + pub fn try_fetch(&self, key: &str) -> Option { + match self.onnx.as_ref().unwrap().session.metadata() { + Err(_) => None, + Ok(metadata) => metadata.custom(key).unwrap_or_default(), + } + } + + pub fn ir_version(&self) -> Option { + self.onnx.as_ref().map(|x| x.proto.ir_version as usize) + } + + pub fn opset_version(&self) -> Option { + self.onnx + .as_ref() + .map(|x| x.proto.opset_import[0].version as usize) + } + + pub fn producer_name(&self) -> Option { + self.onnx.as_ref().map(|x| x.proto.producer_name.clone()) + } + + pub fn producer_version(&self) -> Option { + self.onnx.as_ref().map(|x| x.proto.producer_version.clone()) + } + + pub fn model_version(&self) -> Option { + self.onnx.as_ref().map(|x| x.proto.model_version as usize) + } + + pub fn ishapes(&self) -> Option<&[Vec]> { + self.onnx.as_ref().map(|x| x.inputs.dimss()) + } + + pub fn idimss(&self) -> Option<&[Vec]> { + self.onnx.as_ref().map(|x| x.inputs.dimss()) + } + + pub fn inames(&self) -> Option<&[String]> { + self.onnx.as_ref().map(|x| x.inputs.names()) + } + + pub fn idtypes(&self) -> Option> { + self.onnx.as_ref().and_then(|x| { + x.inputs + .dtypes() + .iter() + .map(DType::from_ort) + .collect::>() + .into() + }) + } + + pub fn oshapes(&self) -> Option<&[Vec]> { + self.onnx.as_ref().map(|x| x.outputs.dimss()) + } + + 
pub fn odimss(&self) -> Option<&[Vec]> { + self.onnx.as_ref().map(|x| x.outputs.dimss()) + } + + pub fn onames(&self) -> Option<&[String]> { + self.onnx.as_ref().map(|x| x.outputs.names()) + } + + pub fn odtypes(&self) -> Option> { + self.onnx.as_ref().and_then(|x| { + x.outputs + .dtypes() + .iter() + .map(DType::from_ort) + .collect::>() + .into() + }) + } + + pub fn profile(&self) { + self.ts.summary(); + } + + pub fn info(&self) { + let info = format!( + "Minimum Supported Ort Version: 1.{}.x, Opset Version: {}, Device: {}, Parameters: {}, Memory: {}", + ort::MINOR_VERSION, + self.opset_version().map_or("Unknown".to_string(), |x| x.to_string()), + self.device, + match self.params { + Some(bytes) if bytes != 0 => { + human_bytes(bytes as f64, true) + } + _ => "Unknown".to_string(), + }, + match self.wbmems { + Some(bytes) if bytes != 0 => { + human_bytes(bytes as f64, true) + } + _ => "Unknown".to_string(), + }, + ); + + info!("{}", info); + } +} diff --git a/src/core/hub.rs b/src/misc/hub.rs similarity index 60% rename from src/core/hub.rs rename to src/misc/hub.rs index 5a13e92..ec4697b 100644 --- a/src/core/hub.rs +++ b/src/misc/hub.rs @@ -1,10 +1,11 @@ use anyhow::{Context, Result}; use indicatif::{ProgressBar, ProgressStyle}; +use log::debug; use serde::{Deserialize, Serialize}; use std::io::{Read, Write}; use std::path::{Path, PathBuf}; -use crate::Dir; +use crate::{Dir, PREFIX_LENGTH}; /// Represents a downloadable asset in a release #[derive(Clone, Debug, Serialize, Deserialize)] @@ -23,36 +24,20 @@ pub struct Release { /// Manages interactions with a GitHub repository's releases pub struct Hub { - /// github api - _gh_api_release: String, - /// GitHub repository owner owner: String, /// GitHub repository name repo: String, - /// Optional list of releases fetched from GitHub - releases: Option>, + /// Directory to store the downloaded file + to: Dir, /// Path to cache file cache: PathBuf, - /// Optional release tag to be used - tag: Option, - - /// 
Filename for the asset, used in cache management - file_name: Option, - file_size: Option, - - /// Full URL constructed for downloading the asset - url: Option, - - /// Local path where the asset will be stored - path: PathBuf, - - /// Directory to store the downloaded file - to: Dir, + /// Optional list of releases fetched from GitHub + releases: Vec, /// Download timeout in seconds timeout: u64, @@ -70,136 +55,122 @@ impl std::fmt::Debug for Hub { .field("owner", &self.owner) .field("repo", &self.repo) .field("cache", &self.cache) - .field("path", &self.path) - .field("releases", &self.releases.as_ref().map(|x| x.len())) + // .field("releases", &self.releases.as_ref().map(|x| x.len())) .field("ttl", &self.ttl) .field("max_attempts", &self.max_attempts) .finish() } } -impl Default for Hub { - fn default() -> Self { - let owner = "jamjamjon".to_string(); - let repo = "assets".to_string(); - let _gh_api_release = format!("https://api.github.com/repos/{}/{}/releases", owner, repo); - - Self { - owner, - repo, - _gh_api_release, - url: None, - path: PathBuf::new(), - to: Dir::Cache, - tag: None, - file_name: None, - file_size: None, - releases: None, - cache: PathBuf::new(), - timeout: 3000, - max_attempts: 3, - ttl: std::time::Duration::from_secs(10 * 60), - } - } -} - impl Hub { pub fn new() -> Result { + // Build the Hub instance + let owner = "jamjamjon".to_string(); + let repo = "assets".to_string(); let mut to = Dir::Cache; let cache = to .path() .or_else(|_| { to = Dir::Home; to.path() - })? - .join("cache_releases"); + }) + .or_else(|_| { + to = Dir::Config; + to.path() + }) + .or_else(|_| { + to = Dir::Current; + to.path() + }) + .expect( + "Unable to get cache directory, home directory, config directory, and current directory. Possible reason:\ + \n1. Unsupported OS\ + \n2. Directory does not exist\ + \n3. 
Insufficient permissions to access" + ) + .join(".gh_releases.cache"); + + let ttl = std::time::Duration::from_secs(10 * 60); + + // releases + let is_file_expired = Self::is_file_expired(&cache, ttl)?; + let body = if is_file_expired { + let gh_api_release = + format!("https://api.github.com/repos/{}/{}/releases", owner, repo); + Self::fetch_and_cache_releases(&gh_api_release, &cache)? + } else { + std::fs::read_to_string(&cache)? + }; + let releases = serde_json::from_str(&body)?; Ok(Self { + owner, + repo, to, cache, - ..Default::default() + releases, + ttl, + timeout: 3000, + max_attempts: 3, }) } - pub fn with_owner(mut self, owner: &str) -> Self { - self.owner = owner.to_string(); - self - } - - pub fn with_repo(mut self, repo: &str) -> Self { - self.repo = repo.to_string(); - self - } - - pub fn with_ttl(mut self, x: u64) -> Self { - self.ttl = std::time::Duration::from_secs(x); - self - } - - pub fn with_timeout(mut self, x: u64) -> Self { - self.timeout = x; - self - } - - pub fn with_max_attempts(mut self, x: u32) -> Self { - self.max_attempts = x; - self - } + pub fn try_fetch(&mut self, s: &str) -> Result { + // mutables + let mut url: Option = None; + let mut tag: Option = None; + let mut file_size: Option = None; + let mut file_name: Option = None; - pub fn fetch(mut self, s: &str) -> Result { - // try to fetch from hub or local cache let p = PathBuf::from(s); - match p.exists() { - true => self.path = p, + let path = match p.exists() { + true => p, false => { + // check empty + if self.releases.is_empty() { + anyhow::bail!("No releases found in this repo."); + } + // check remote match s.split_once('/') { - Some((tag, file_name)) => { + Some((tag_, file_name_)) => { // Extract tag and file from input string - self.tag = Some(tag.to_string()); - self.file_name = Some(file_name.to_string()); - - // Check if releases are already loaded in memory - if self.releases.is_none() { - self.releases = Some(self.connect_remote()?); - } - - if let Some(releases) 
= &self.releases { - // Validate the tag - let tags: Vec<&str> = - releases.iter().map(|x| x.tag_name.as_str()).collect(); - if !tags.contains(&tag) { - anyhow::bail!( - "Hub tag '{}' not found in releases. Available tags: {:?}", - tag, + tag = Some(tag_.to_string()); + file_name = Some(file_name_.to_string()); + + // Validate the tag + let tags = self.tags(); + if !tags.contains(&tag_) { + anyhow::bail!( + "Try to fetch from GitHub releases. However, tag: `{}` is not found. Available tags: {:#?}", + tag_, tags ); - } + } - // Validate the file - if let Some(release) = releases.iter().find(|r| r.tag_name == tag) { - let files: Vec<&str> = - release.assets.iter().map(|x| x.name.as_str()).collect(); - if !files.contains(&file_name) { - anyhow::bail!( - "Hub file '{}' not found in tag '{}'. Available files: {:?}", - file_name, - tag, + // Validate the file + if let Some(release) = self.releases.iter().find(|r| r.tag_name == tag_) { + let files: Vec<&str> = + release.assets.iter().map(|x| x.name.as_str()).collect(); + if !files.contains(&file_name_) { + anyhow::bail!( + "Try to fetch from GitHub releases. However, file: `{}` is not found in tag: `{}`. Available files: {:#?}", + file_name_, + tag_, files ); - } else { - for f_ in release.assets.iter() { - if f_.name.as_str() == file_name { - self.url = Some(f_.browser_download_url.clone()); - self.file_size = Some(f_.size); - - break; - } + } else { + for f_ in release.assets.iter() { + if f_.name.as_str() == file_name_ { + url = Some(f_.browser_download_url.clone()); + file_size = Some(f_.size); + + break; } } } - self.path = self.to.path_with_subs(&[tag])?.join(file_name); } + self.to.path_with_subs(&[tag_])?.join(file_name_) } _ => anyhow::bail!( "Download failed due to invalid format. 
Expected: /, got: {}", @@ -207,9 +178,28 @@ impl Hub { ), } } + }; + + // Commit the downloaded file, downloading if necessary + if let Some(url) = &url { + // Download if the file does not exist or if the size of file does not match + if !path.is_file() + || path.is_file() && Some(std::fs::metadata(&path)?.len()) != file_size + { + let name = format!("{}/{}", tag.as_ref().unwrap(), file_name.as_ref().unwrap()); + Self::download( + url.as_str(), + &path, + Some(&name), + Some(self.timeout), + Some(self.max_attempts), + )?; + } } - Ok(self) + path.to_str() + .map(|s| s.to_string()) + .with_context(|| format!("Failed to convert PathBuf: {:?} to String", path)) } /// Fetch releases from GitHub and cache them @@ -243,6 +233,8 @@ impl Hub { let mut temp_file = tempfile::NamedTempFile::new_in(parent_dir) .context("Failed to create temporary cache file")?; + // Encode? + // Write data to temporary file temp_file .write_all(body.as_bytes()) @@ -256,89 +248,41 @@ impl Hub { Ok(body) } - pub fn tags(&mut self) -> Option> { - if self.releases.is_none() { - self.releases = self.connect_remote().ok(); - } - - self.releases - .as_ref() - .map(|releases| releases.iter().map(|x| x.tag_name.as_str()).collect()) + pub fn tags(&self) -> Vec<&str> { + self.releases.iter().map(|x| x.tag_name.as_str()).collect() } - pub fn files(&mut self, tag: &str) -> Option> { - if self.releases.is_none() { - self.releases = self.connect_remote().ok(); - } - - self.releases.as_ref().map(|releases| { - releases - .iter() - .find(|r| r.tag_name == tag) - .map(|a| a.assets.iter().map(|x| x.name.as_str()).collect()) - })? 
+ pub fn files(&self, tag: &str) -> Vec<&str> { + self.releases + .iter() + .find(|r| r.tag_name == tag) + .map(|a| a.assets.iter().map(|x| x.name.as_str()).collect()) + .unwrap_or_default() } - pub fn connect_remote(&mut self) -> Result> { - let span = tracing::span!(tracing::Level::INFO, "Hub-connect_remote"); - let _guard = span.enter(); - - let should_download = if !self.cache.exists() { - tracing::info!("No cache found, fetching data from GitHub"); + pub fn is_file_expired>(file: P, ttl: std::time::Duration) -> Result { + let file = file.as_ref(); + let y = if !file.exists() { + debug!("No cache found, fetching data from GitHub"); true } else { - match std::fs::metadata(&self.cache)?.modified() { + match std::fs::metadata(file)?.modified() { Err(_) => { - tracing::info!("Cannot get file modified time, fetching new data from GitHub"); + debug!("Cannot get file modified time, fetching new data from GitHub"); true } Ok(modified_time) => { - if std::time::SystemTime::now().duration_since(modified_time)? < self.ttl { - tracing::info!("Using cached data"); + if std::time::SystemTime::now().duration_since(modified_time)? < ttl { + debug!("Using cached data"); false } else { - tracing::info!("Cache expired, fetching new data from GitHub"); + debug!("Cache expired, fetching new data from GitHub"); true } } } }; - - let body = if should_download { - Self::fetch_and_cache_releases(&self._gh_api_release, &self.cache)? - } else { - std::fs::read_to_string(&self.cache)? 
- }; - let releases: Vec = serde_json::from_str(&body)?; - Ok(releases) - } - - /// Commit the downloaded file, downloading if necessary - pub fn commit(&self) -> Result { - if let Some(url) = &self.url { - // Download if the file does not exist or if the size of file does not match - if !self.path.is_file() - || self.path.is_file() - && Some(std::fs::metadata(&self.path)?.len()) != self.file_size - { - let name = format!( - "{}/{}", - self.tag.as_ref().unwrap(), - self.file_name.as_ref().unwrap() - ); - Self::download( - url.as_str(), - &self.path, - Some(&name), - Some(self.timeout), - Some(self.max_attempts), - )?; - } - } - self.path - .to_str() - .map(|s| s.to_string()) - .with_context(|| format!("Failed to convert PathBuf: {:?} to String", self.path)) + Ok(y) } /// Download a file from a github release to a specified path with a progress bar @@ -349,8 +293,6 @@ impl Hub { timeout: Option, max_attempts: Option, ) -> Result<()> { - // TODO: other url, not just github release page - let max_attempts = max_attempts.unwrap_or(2); let timeout_duration = std::time::Duration::from_secs(timeout.unwrap_or(2000)); let agent = ureq::AgentBuilder::new().try_proxy_from_env(true).build(); @@ -379,9 +321,9 @@ impl Hub { .progress_chars("██ "), ); pb.set_prefix(if i_try == 0 { - "Fetching" + format!("{:>PREFIX_LENGTH$}", "Fetching") } else { - "Re-Fetching" + format!("{:>PREFIX_LENGTH$}", "Re-Fetching") }); pb.set_message(prompt.unwrap_or_default().to_string()); @@ -423,4 +365,29 @@ impl Hub { Ok(()) } + + pub fn with_owner(mut self, owner: &str) -> Self { + self.owner = owner.to_string(); + self + } + + pub fn with_repo(mut self, repo: &str) -> Self { + self.repo = repo.to_string(); + self + } + + pub fn with_ttl(mut self, x: u64) -> Self { + self.ttl = std::time::Duration::from_secs(x); + self + } + + pub fn with_timeout(mut self, x: u64) -> Self { + self.timeout = x; + self + } + + pub fn with_max_attempts(mut self, x: u32) -> Self { + self.max_attempts = x; + self + } } 
diff --git a/src/misc/iiix.rs b/src/misc/iiix.rs new file mode 100644 index 0000000..6db1626 --- /dev/null +++ b/src/misc/iiix.rs @@ -0,0 +1,15 @@ +use crate::MinOptMax; + +/// A struct for input composed of the i-th input, the ii-th dimension, and the value. +#[derive(Clone, Debug, Default)] +pub struct Iiix { + pub i: usize, + pub ii: usize, + pub x: MinOptMax, +} + +impl From<(usize, usize, MinOptMax)> for Iiix { + fn from((i, ii, x): (usize, usize, MinOptMax)) -> Self { + Self { i, ii, x } + } +} diff --git a/src/core/logits_sampler.rs b/src/misc/logits_sampler.rs similarity index 99% rename from src/core/logits_sampler.rs rename to src/misc/logits_sampler.rs index 5867fd7..1be03e0 100644 --- a/src/core/logits_sampler.rs +++ b/src/misc/logits_sampler.rs @@ -1,7 +1,6 @@ use anyhow::Result; use rand::distributions::{Distribution, WeightedIndex}; -/// Logits Sampler #[derive(Debug)] pub struct LogitsSampler { temperature: f32, diff --git a/src/core/media.rs b/src/misc/media.rs similarity index 100% rename from src/core/media.rs rename to src/misc/media.rs index 23cee6a..ee76c69 100644 --- a/src/core/media.rs +++ b/src/misc/media.rs @@ -1,14 +1,5 @@ use crate::{AUDIO_EXTENSIONS, IMAGE_EXTENSIONS, STREAM_PROTOCOLS, VIDEO_EXTENSIONS}; -#[derive(Debug, Clone)] -pub enum MediaType { - Image(Location), - Video(Location), - Audio(Location), - Stream, - Unknown, -} - #[derive(Debug, Clone)] pub enum Location { Local, @@ -21,6 +12,15 @@ pub enum StreamType { Live, } +#[derive(Debug, Clone)] +pub enum MediaType { + Image(Location), + Video(Location), + Audio(Location), + Stream, + Unknown, +} + impl MediaType { pub fn from_path>(path: P) -> Self { let extension = path diff --git a/src/core/min_opt_max.rs b/src/misc/min_opt_max.rs similarity index 99% rename from src/core/min_opt_max.rs rename to src/misc/min_opt_max.rs index e5412cd..e4f47ca 100644 --- a/src/core/min_opt_max.rs +++ b/src/misc/min_opt_max.rs @@ -83,6 +83,7 @@ impl MinOptMax { } } +// TODO: min = 1????? 
impl From for MinOptMax { fn from(opt: i32) -> Self { let opt = opt.max(0) as usize; @@ -92,6 +93,7 @@ impl From for MinOptMax { } } +// TODO: min = 1????? impl From for MinOptMax { fn from(opt: i64) -> Self { let opt = opt.max(0) as usize; @@ -127,6 +129,7 @@ impl From for MinOptMax { } } +// TODO: min = 1????? impl From for MinOptMax { fn from(opt: isize) -> Self { let opt = opt.max(0) as usize; diff --git a/src/core/mod.rs b/src/misc/mod.rs similarity index 57% rename from src/core/mod.rs rename to src/misc/mod.rs index 0b0c2f1..685dae5 100644 --- a/src/core/mod.rs +++ b/src/misc/mod.rs @@ -1,45 +1,44 @@ mod annotator; +mod color; +mod colormap256; mod dataloader; mod device; mod dir; +mod dtype; mod dynconf; +mod engine; mod hub; +mod iiix; mod logits_sampler; mod media; -mod metric; mod min_opt_max; -pub mod onnx; -pub mod ops; -mod options; -mod ort_engine; -mod task; -mod tokenizer_stream; +pub(crate) mod onnx; +mod ops; mod ts; +mod utils; +#[cfg(feature = "ffmpeg")] mod viewer; -mod vision; -mod x; -mod xs; pub use annotator::Annotator; +pub use color::Color; +pub use colormap256::*; pub use dataloader::DataLoader; pub use device::Device; pub use dir::Dir; +pub use dtype::DType; pub use dynconf::DynConf; +pub use engine::*; pub use hub::Hub; +pub use iiix::Iiix; pub use logits_sampler::LogitsSampler; pub use media::*; -pub use metric::Metric; pub use min_opt_max::MinOptMax; -pub use ops::Ops; -pub use options::Options; -pub use ort_engine::*; -pub use task::Task; -pub use tokenizer_stream::TokenizerStream; +pub use ops::*; pub use ts::Ts; +pub use utils::*; +#[cfg(feature = "ffmpeg")] pub use viewer::Viewer; -pub use vision::Vision; -pub use x::X; -pub use xs::Xs; // re-export +#[cfg(feature = "ffmpeg")] pub use minifb::Key; diff --git a/src/core/onnx.rs b/src/misc/onnx.rs similarity index 99% rename from src/core/onnx.rs rename to src/misc/onnx.rs index d88dc84..33bdfc0 100644 --- a/src/core/onnx.rs +++ b/src/misc/onnx.rs @@ -866,6 +866,7 @@ pub mod 
type_proto { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] + #[allow(clippy::enum_variant_names)] pub enum Value { /// The type of a tensor. #[prost(message, tag = "1")] @@ -945,6 +946,7 @@ pub struct FunctionProto { /// that is not defined by the default value but an explicit enum number. #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +#[allow(clippy::enum_variant_names)] pub enum Version { /// proto3 requires the first enum value to be zero. /// We add this just to appease the compiler. diff --git a/src/core/ops.rs b/src/misc/ops.rs similarity index 65% rename from src/core/ops.rs rename to src/misc/ops.rs index 7b60fb3..2b9211d 100644 --- a/src/core/ops.rs +++ b/src/misc/ops.rs @@ -1,4 +1,4 @@ -//! Some processing functions to image and ndarray. +//! Some processing functions. use anyhow::Result; use fast_image_resize::{ @@ -7,11 +7,11 @@ use fast_image_resize::{ FilterType, ResizeAlg, ResizeOptions, Resizer, }; use image::{DynamicImage, GenericImageView}; -use ndarray::{concatenate, s, Array, Array3, Axis, IntoDimension, IxDyn}; +use ndarray::{concatenate, s, Array, Array3, Axis, IntoDimension, Ix2, IxDyn}; use rayon::prelude::*; pub enum Ops<'a> { - Resize(&'a [DynamicImage], u32, u32, &'a str), + FitExact(&'a [DynamicImage], u32, u32, &'a str), Letterbox(&'a [DynamicImage], u32, u32, &'a str, u8, &'a str, bool), Normalize(f32, f32), Standardize(&'a [f32], &'a [f32], usize), @@ -80,11 +80,20 @@ impl Ops<'_> { dim: usize, ) -> Result> { if mean.len() != std.len() { - anyhow::bail!("`standardize`: `mean` and `std` lengths are not equal. Mean length: {}, Std length: {}.", mean.len(), std.len()); + anyhow::bail!( + "`standardize`: `mean` and `std` lengths are not equal. Mean length: {}, Std length: {}.", + mean.len(), + std.len() + ); } let shape = x.shape(); if dim >= shape.len() || shape[dim] != mean.len() { - anyhow::bail!("`standardize`: Dimension mismatch. 
`dim` is {} but shape length is {} or `mean` length is {}.", dim, shape.len(), mean.len()); + anyhow::bail!( + "`standardize`: Dimension mismatch. `dim` is {} but shape length is {} or `mean` length is {}.", + dim, + shape.len(), + mean.len() + ); } let mut shape = vec![1; shape.len()]; shape[dim] = mean.len(); @@ -122,6 +131,23 @@ impl Ops<'_> { Ok(concatenate(Axis(d), &[x.view(), y.view()])?) } + pub fn concat(xs: &[Array], d: usize) -> Result> { + let xs = xs.iter().map(|x| x.view()).collect::>(); + Ok(concatenate(Axis(d), &xs)?) + } + + pub fn dot2(x: &Array, other: &Array) -> Result>> { + // (m, ndim) * (n, ndim).t => (m, n) + let query = x.to_owned().into_dimensionality::()?; + let gallery = other.to_owned().into_dimensionality::()?; + let matrix = query.dot(&gallery.t()); + let exps = matrix.mapv(|x| x.exp()); + let stds = exps.sum_axis(Axis(1)); + let matrix = exps / stds.insert_axis(Axis(1)); + let matrix: Vec> = matrix.axis_iter(Axis(0)).map(|row| row.to_vec()).collect(); + Ok(matrix) + } + pub fn insert_axis(x: Array, d: usize) -> Result> { if x.shape().len() < d { anyhow::bail!( @@ -167,11 +193,6 @@ impl Ops<'_> { mask.resize_exact(w1 as u32, h1 as u32, image::imageops::FilterType::Triangle) } - // pub fn argmax(xs: Array, d: usize, keep_dims: bool) -> Result> { - // let mask = Array::zeros(xs.raw_dim()); - // todo!(); - // } - pub fn interpolate_3d( xs: Array, tw: f32, @@ -238,14 +259,36 @@ impl Ops<'_> { }; resizer.resize(&src, &mut dst, &options)?; - // u8*2 -> f32 - let mask_f32: Vec = dst - .into_vec() - .chunks_exact(4) - .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) - .collect(); + // u8 -> f32 + Self::u8_slice_to_f32(&dst.into_vec()) + } + + pub fn u8_slice_to_f32(data: &[u8]) -> Result> { + let size_in_bytes = 4; + let elem_count = data.len() / size_in_bytes; + if (data.as_ptr() as usize) % size_in_bytes == 0 { + let data: &[f32] = + unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, elem_count) }; - 
Ok(mask_f32) + Ok(data.to_vec()) + } else { + let mut c: Vec = Vec::with_capacity(elem_count); + unsafe { + std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len()); + c.set_len(elem_count) + } + + Ok(c) + } + } + + pub fn f32_slice_to_u8(mut vs: Vec) -> Vec { + let size_in_bytes = 4; + let length = vs.len() * size_in_bytes; + let capacity = vs.capacity() * size_in_bytes; + let ptr = vs.as_mut_ptr() as *mut u8; + std::mem::forget(vs); + unsafe { Vec::from_raw_parts(ptr, length, capacity) } } pub fn resize_luma8_u8( @@ -285,6 +328,26 @@ impl Ops<'_> { )) } + pub fn resize_rgb( + x: &DynamicImage, + th: u32, + tw: u32, + resizer: &mut Resizer, + options: &ResizeOptions, + ) -> Result> { + let buffer = if x.dimensions() == (tw, th) { + x.to_rgb8().into_raw() + } else { + let mut dst = Image::new(tw, th, PixelType::U8x3); + resizer.resize(x, &mut dst, options)?; + dst.into_vec() + }; + let y = Array::from_shape_vec((th as usize, tw as usize, 3), buffer)? + .mapv(|x| x as f32) + .into_dyn(); + Ok(y) + } + pub fn resize( xs: &[DynamicImage], th: u32, @@ -294,20 +357,65 @@ impl Ops<'_> { let mut ys = Array::ones((xs.len(), th as usize, tw as usize, 3)).into_dyn(); let (mut resizer, options) = Self::build_resizer_filter(filter)?; for (idx, x) in xs.iter().enumerate() { - let buffer = if x.dimensions() == (tw, th) { - x.to_rgb8().into_raw() - } else { - let mut dst = Image::new(tw, th, PixelType::U8x3); - resizer.resize(x, &mut dst, &options)?; - dst.into_vec() - }; - let y_ = - Array::from_shape_vec((th as usize, tw as usize, 3), buffer)?.mapv(|x| x as f32); - ys.slice_mut(s![idx, .., .., ..]).assign(&y_); + let y = Self::resize_rgb(x, th, tw, &mut resizer, &options)?; + ys.slice_mut(s![idx, .., .., ..]).assign(&y); } Ok(ys) } + #[allow(clippy::too_many_arguments)] + pub fn letterbox_rgb( + x: &DynamicImage, + th: u32, + tw: u32, + bg: u8, + resize_by: &str, + center: bool, + resizer: &mut Resizer, + options: &ResizeOptions, + ) -> Result> { + 
let (w0, h0) = x.dimensions(); + let buffer = if w0 == tw && h0 == th { + x.to_rgb8().into_raw() + } else { + let (w, h) = match resize_by { + "auto" => { + let r = (tw as f32 / w0 as f32).min(th as f32 / h0 as f32); + ( + (w0 as f32 * r).round() as u32, + (h0 as f32 * r).round() as u32, + ) + } + "height" => (th * w0 / h0, th), + "width" => (tw, tw * h0 / w0), + _ => anyhow::bail!("ModelConfig for `letterbox`: width, height, auto"), + }; + + let mut dst = Image::from_vec_u8( + tw, + th, + vec![bg; 3 * th as usize * tw as usize], + PixelType::U8x3, + )?; + let (l, t) = if center { + if w == tw { + (0, (th - h) / 2) + } else { + ((tw - w) / 2, 0) + } + } else { + (0, 0) + }; + let mut dst_cropped = CroppedImageMut::new(&mut dst, l, t, w, h)?; + resizer.resize(x, &mut dst_cropped, options)?; + dst.into_vec() + }; + let y = Array::from_shape_vec((th as usize, tw as usize, 3), buffer)? + .mapv(|x| x as f32) + .into_dyn(); + Ok(y) + } + pub fn letterbox( xs: &[DynamicImage], th: u32, @@ -319,47 +427,9 @@ impl Ops<'_> { ) -> Result> { let mut ys = Array::ones((xs.len(), th as usize, tw as usize, 3)).into_dyn(); let (mut resizer, options) = Self::build_resizer_filter(filter)?; - for (idx, x) in xs.iter().enumerate() { - let (w0, h0) = x.dimensions(); - let buffer = if w0 == tw && h0 == th { - x.to_rgb8().into_raw() - } else { - let (w, h) = match resize_by { - "auto" => { - let r = (tw as f32 / w0 as f32).min(th as f32 / h0 as f32); - ( - (w0 as f32 * r).round() as u32, - (h0 as f32 * r).round() as u32, - ) - } - "height" => (th * w0 / h0, th), - "width" => (tw, tw * h0 / w0), - _ => anyhow::bail!("Options for `letterbox`: width, height, auto"), - }; - - let mut dst = Image::from_vec_u8( - tw, - th, - vec![bg; 3 * th as usize * tw as usize], - PixelType::U8x3, - )?; - let (l, t) = if center { - if w == tw { - (0, (th - h) / 2) - } else { - ((tw - w) / 2, 0) - } - } else { - (0, 0) - }; - let mut dst_cropped = CroppedImageMut::new(&mut dst, l, t, w, h)?; - 
resizer.resize(x, &mut dst_cropped, &options)?; - dst.into_vec() - }; - let y_ = - Array::from_shape_vec((th as usize, tw as usize, 3), buffer)?.mapv(|x| x as f32); - ys.slice_mut(s![idx, .., .., ..]).assign(&y_); + let y = Self::letterbox_rgb(x, th, tw, bg, resize_by, center, &mut resizer, &options)?; + ys.slice_mut(s![idx, .., .., ..]).assign(&y); } Ok(ys) } diff --git a/src/misc/ts.rs b/src/misc/ts.rs new file mode 100644 index 0000000..fcdcce4 --- /dev/null +++ b/src/misc/ts.rs @@ -0,0 +1,392 @@ +use std::collections::HashMap; +use std::time::Duration; + +#[macro_export] +macro_rules! elapsed { + ($code:expr) => {{ + let t = std::time::Instant::now(); + let ret = $code; + let duration = t.elapsed(); + (duration, ret) + }}; + ($label:expr, $ts:expr, $code:expr) => {{ + let t = std::time::Instant::now(); + let ret = $code; + let duration = t.elapsed(); + $ts.push($label, duration); + ret + }}; +} + +#[derive(aksr::Builder, Debug, Default, Clone, PartialEq)] +pub struct Ts { + // { k1: [d1,d1,d1,..], k2: [d2,d2,d2,..], k3: [d3,d3,d3,..], ..} + map: HashMap>, + names: Vec, +} + +impl std::ops::Index<&str> for Ts { + type Output = Vec; + + fn index(&self, index: &str) -> &Self::Output { + self.map.get(index).expect("Index was not found in `Ts`") + } +} + +impl std::ops::Index for Ts { + type Output = Vec; + + fn index(&self, index: usize) -> &Self::Output { + self.names + .get(index) + .and_then(|key| self.map.get(key)) + .expect("Index was not found in `Ts`") + } +} + +impl Ts { + pub fn summary(&self) { + let decimal_places = 4; + let place_holder = '-'; + let width_count = 10; + let width_time = 15; + let width_task = self + .names + .iter() + .map(|s| s.len()) + .max() + .map(|x| x + 8) + .unwrap_or(60); + + let sep = "-".repeat(width_task + 66); + + // cols + println!( + "\n\n{: Self { + let mut names = Vec::new(); + let mut map: HashMap> = HashMap::new(); + for x in xs.iter() { + names.extend_from_slice(x.names()); + map.extend(x.map().to_owned()); + } + + 
Self { names, map } + } + + pub fn push(&mut self, k: &str, v: Duration) { + if !self.names.contains(&k.to_string()) { + self.names.push(k.to_string()); + } + self.map + .entry(k.to_string()) + .and_modify(|x| x.push(v)) + .or_insert(vec![v]); + } + + pub fn numit(&self) -> anyhow::Result { + // num of iterations + if self.names.is_empty() { + anyhow::bail!("Empty Ts"); + } + + let len = self[0].len(); + for v in self.map.values() { + if v.len() != len { + anyhow::bail!( + "Invalid Ts: The number of elements in each values entry is inconsistent" + ); + } + } + + Ok(len) + } + + pub fn is_valid(&self) -> bool { + let mut iter = self.map.values(); + if let Some(first) = iter.next() { + let len = first.len(); + iter.all(|v| v.len() == len) + } else { + true + } + } + + pub fn sum_by_index(&self, i: usize) -> Duration { + self[i].iter().sum::() + } + + pub fn sum_by_key(&self, i: &str) -> Duration { + self[i].iter().sum::() + } + + pub fn avg_by_index(&self, i: usize) -> anyhow::Result { + let len = self[i].len(); + if len == 0 { + anyhow::bail!("Cannot compute average for an empty duration vector.") + } else { + Ok(self.sum_by_index(i) / len as u32) + } + } + + pub fn avg_by_key(&self, i: &str) -> anyhow::Result { + let len = self[i].len(); + if len == 0 { + anyhow::bail!("Cannot compute average for an empty duration vector.") + } else { + Ok(self.sum_by_key(i) / len as u32) + } + } + + pub fn sum_column(&self, i: usize) -> Duration { + self.map + .values() + .filter_map(|vec| vec.get(i)) + .copied() + .sum() + } + + pub fn sum(&self) -> Duration { + self.map.values().flat_map(|vec| vec.iter()).copied().sum() + } + + pub fn avg(&self) -> anyhow::Result { + self.names.iter().map(|x| self.avg_by_key(x)).sum() + } + + pub fn skip(mut self, n: usize) -> Self { + self.map.iter_mut().for_each(|(_, vec)| { + *vec = vec.iter().skip(n).copied().collect(); + }); + self + } + + pub fn clear(&mut self) { + self.names.clear(); + self.map.clear(); + } + + pub fn is_empty(&self) -> 
bool { + self.names.is_empty() && self.map.is_empty() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + #[test] + fn test_push_and_indexing() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + + assert_eq!(ts["task1"], vec![Duration::new(1, 0), Duration::new(2, 0)]); + assert_eq!(ts["task2"], vec![Duration::new(3, 0)]); + } + + #[test] + fn test_numit() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + ts.push("task2", Duration::new(4, 0)); + + assert_eq!(ts.numit().unwrap(), 2); + } + + #[test] + fn test_is_valid() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + + assert!(!ts.is_valid()); + + ts.push("task2", Duration::new(4, 0)); + ts.push("task3", Duration::new(5, 0)); + + assert!(!ts.is_valid()); + + ts.push("task3", Duration::new(6, 0)); + assert!(ts.is_valid()); + } + + #[test] + fn test_sum_by_index() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + ts.push("task2", Duration::new(3, 0)); + ts.push("task2", Duration::new(3, 0)); + + assert_eq!(ts.sum_by_index(0), Duration::new(3, 0)); // 1 + 2 + assert_eq!(ts.sum_by_index(1), Duration::new(9, 0)); // 1 + 2 + } + + #[test] + fn test_sum_by_key() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + ts.push("task2", Duration::new(3, 0)); + ts.push("task2", Duration::new(3, 0)); + + assert_eq!(ts.sum_by_key("task1"), Duration::new(3, 0)); // 1 + 2 + assert_eq!(ts.sum_by_key("task2"), Duration::new(9, 0)); // 1 + 2 + } + + #[test] + fn 
test_avg_by_index() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(2, 0)); + ts.push("task2", Duration::new(2, 0)); + ts.push("task3", Duration::new(2, 0)); + + assert_eq!(ts.avg_by_index(0).unwrap(), Duration::new(1, 500_000_000)); + assert_eq!(ts.avg_by_index(1).unwrap(), Duration::new(2, 0)); + assert_eq!(ts.avg_by_index(2).unwrap(), Duration::new(2, 0)); + } + + #[test] + fn test_avg_by_key() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + + let avg = ts.avg_by_key("task1").unwrap(); + assert_eq!(avg, Duration::new(1, 500_000_000)); + } + + #[test] + fn test_sum_column() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + + assert_eq!(ts.sum_column(0), Duration::new(4, 0)); // 1 + 3 + } + + #[test] + fn test_sum() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + + assert_eq!(ts.sum(), Duration::new(6, 0)); + } + + #[test] + fn test_avg() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + ts.push("task2", Duration::new(4, 0)); + + assert_eq!(ts.avg().unwrap(), Duration::new(5, 0)); + } + + #[test] + fn test_skip() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task1", Duration::new(2, 0)); + ts.push("task2", Duration::new(3, 0)); + ts.push("task2", Duration::new(4, 0)); + ts.push("task2", Duration::new(4, 0)); + + let ts_skipped = ts.skip(1); + + assert_eq!(ts_skipped["task1"], vec![Duration::new(2, 0)]); + assert_eq!( + ts_skipped["task2"], + vec![Duration::new(4, 0), Duration::new(4, 0)] + ); + + let ts_skipped = ts_skipped.skip(1); + 
+ assert!(ts_skipped["task1"].is_empty()); + assert_eq!(ts_skipped["task2"], vec![Duration::new(4, 0)]); + } + + #[test] + fn test_clear() { + let mut ts = Ts::default(); + + ts.push("task1", Duration::new(1, 0)); + ts.push("task2", Duration::new(2, 0)); + + ts.clear(); + assert!(ts.names.is_empty()); + assert!(ts.map.is_empty()); + } +} diff --git a/src/utils/mod.rs b/src/misc/utils.rs similarity index 79% rename from src/utils/mod.rs rename to src/misc/utils.rs index ec69d44..c243618 100644 --- a/src/utils/mod.rs +++ b/src/misc/utils.rs @@ -3,20 +3,7 @@ use indicatif::{ProgressBar, ProgressStyle}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; -pub mod color; -mod colormap256; -pub mod names; -mod quantizer; - -pub use color::Color; -pub use colormap256::*; -pub use names::*; -pub use quantizer::Quantizer; - pub(crate) const PREFIX_LENGTH: usize = 12; -pub(crate) const CHECK_MARK: &str = "✅"; -pub(crate) const CROSS_MARK: &str = "❌"; -pub(crate) const SAFE_CROSS_MARK: &str = "❎"; pub(crate) const NETWORK_PREFIXES: &[&str] = &[ "http://", "https://", "ftp://", "ftps://", "sftp://", "rtsp://", "mms://", "mmsh://", "rtmp://", "rtmps://", "file://", @@ -47,18 +34,37 @@ pub(crate) const PROGRESS_BAR_STYLE_FINISH_3: &str = "{prefix:>12.green.bold} {msg} ({binary_total_bytes}) in {elapsed}"; pub(crate) const PROGRESS_BAR_STYLE_FINISH_4: &str = "{prefix:>12.green.bold} {msg} in {elapsed}"; -pub fn human_bytes(size: f64) -> String { - let units = ["B", "KB", "MB", "GB", "TB", "PB", "EB"]; +pub(crate) fn try_fetch_stem>(p: P) -> anyhow::Result { + let p = p.as_ref(); + let stem = p + .file_stem() + .ok_or(anyhow::anyhow!( + "Failed to get the `file_stem` of `model_file`: {:?}", + p + ))? 
+ .to_str() + .ok_or(anyhow::anyhow!("Failed to convert from `&OsStr` to `&str`"))?; + + Ok(stem.to_string()) +} + +pub fn human_bytes(size: f64, use_binary: bool) -> String { + let units = if use_binary { + ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB"] + } else { + ["B", "KB", "MB", "GB", "TB", "PB", "EB"] + }; + let mut size = size; let mut unit_index = 0; - let k = 1024.; + let k = if use_binary { 1024. } else { 1000. }; while size >= k && unit_index < units.len() - 1 { size /= k; unit_index += 1; } - format!("{:.1} {}", size, units[unit_index]) + format!("{:.2} {}", size, units[unit_index]) } pub(crate) fn string_random(n: usize) -> String { @@ -78,7 +84,7 @@ pub(crate) fn string_now(delimiter: &str) -> String { t_now.format(&fmt).to_string() } -pub fn build_progress_bar( +pub(crate) fn build_progress_bar( n: u64, prefix: &str, msg: Option<&str>, diff --git a/src/core/viewer.rs b/src/misc/viewer.rs similarity index 92% rename from src/core/viewer.rs rename to src/misc/viewer.rs index 982fc8a..cb37c77 100644 --- a/src/core/viewer.rs +++ b/src/misc/viewer.rs @@ -1,13 +1,12 @@ use anyhow::Result; use image::DynamicImage; +use log::info; use minifb::{Window, WindowOptions}; use video_rs::{ encode::{Encoder, Settings}, time::Time, }; -use crate::{string_now, Dir, Key}; - pub struct Viewer<'a> { name: &'a str, window: Option, @@ -107,8 +106,9 @@ impl Viewer<'_> { let (w, h) = frame.dimensions(); if self.writer.is_none() { let settings = Settings::preset_h264_yuv420p(w as _, h as _, false); - let saveout = Dir::saveout(&["runs"])?.join(format!("{}.mp4", string_now("-"))); - tracing::info!("Video will be save to: {:?}", saveout); + let saveout = + crate::Dir::saveout(&["runs"])?.join(format!("{}.mp4", crate::string_now("-"))); + info!("Video will be save to: {:?}", saveout); self.writer = Some(Encoder::new(saveout, settings)?); } @@ -138,7 +138,7 @@ impl Viewer<'_> { match &mut self.writer { Some(writer) => writer.finish()?, None => { - tracing::info!("Found no video 
writer. No need to release."); + info!("Found no video writer. No need to release."); } } Ok(()) @@ -152,7 +152,7 @@ impl Viewer<'_> { } } - pub fn is_key_pressed(&self, key: Key) -> bool { + pub fn is_key_pressed(&self, key: crate::Key) -> bool { if let Some(window) = &self.window { window.is_key_down(key) } else { @@ -161,7 +161,7 @@ impl Viewer<'_> { } pub fn is_esc_pressed(&self) -> bool { - self.is_key_pressed(Key::Escape) + self.is_key_pressed(crate::Key::Escape) } pub fn resizable(mut self, x: bool) -> Self { diff --git a/src/models/basemodel.rs b/src/models/basemodel.rs new file mode 100644 index 0000000..52b73fc --- /dev/null +++ b/src/models/basemodel.rs @@ -0,0 +1,148 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; + +use crate::{ + elapsed, DType, Device, Engine, Kind, Options, Processor, Scale, Task, Ts, Version, Xs, X, +}; + +#[derive(Debug, Builder)] +pub struct BaseModelVisual { + engine: Engine, + height: usize, + width: usize, + batch: usize, + processor: Processor, + ts: Ts, + spec: String, + name: &'static str, + device: Device, + dtype: DType, + task: Option, + scale: Option, + kind: Option, + version: Option, +} + +impl BaseModelVisual { + pub fn summary(&self) { + self.ts.summary(); + } + + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let err_msg = "You need to specify the image height and image width for visual model."; + let (batch, height, width, ts, spec) = ( + engine.batch().opt(), + engine.try_height().expect(err_msg).opt(), + engine.try_width().expect(err_msg).opt(), + engine.ts.clone(), + engine.spec().to_owned(), + ); + let processor = options + .to_processor()? 
+ .with_image_width(width as _) + .with_image_height(height as _); + + let device = options.model_device; + let task = options.model_task; + let scale = options.model_scale; + let dtype = options.model_dtype; + let kind = options.model_kind; + let name = options.model_name; + let version = options.model_version; + + Ok(Self { + engine, + height, + width, + batch, + processor, + ts, + spec, + dtype, + task, + scale, + kind, + device, + version, + name, + }) + } + + pub fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + let x = self.processor.process_images(xs)?; + self.batch = xs.len(); // update + + Ok(x.into()) + } + + pub fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn encode(&mut self, xs: &[DynamicImage]) -> Result { + let xs = elapsed!("visual-preprocess", self.ts, { self.preprocess(xs)? }); + let xs = elapsed!("visual-inference", self.ts, { self.inference(xs)? }); + + Ok(xs[0].to_owned()) + } +} + +#[derive(Debug, Builder)] +pub struct BaseModelTextual { + engine: Engine, + batch: usize, + processor: Processor, + ts: Ts, + spec: String, + name: &'static str, + device: Device, + dtype: DType, + task: Option, + scale: Option, + kind: Option, + version: Option, +} + +impl BaseModelTextual { + pub fn summary(&self) { + self.ts.summary(); + } + + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let (batch, ts, spec) = ( + engine.batch().opt(), + engine.ts.clone(), + engine.spec().to_owned(), + ); + let processor = options.to_processor()?; + let device = options.model_device; + let task = options.model_task; + let scale = options.model_scale; + let dtype = options.model_dtype; + let kind = options.model_kind; + let name = options.model_name; + let version = options.model_version; + + Ok(Self { + engine, + batch, + processor, + ts, + spec, + dtype, + task, + scale, + kind, + device, + version, + name, + }) + } + + pub fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } +} diff 
--git a/src/models/beit/config.rs b/src/models/beit/config.rs new file mode 100644 index 0000000..97ba9db --- /dev/null +++ b/src/models/beit/config.rs @@ -0,0 +1,26 @@ +use crate::models::IMAGENET_NAMES_1K; + +/// Model configuration for `BEiT` +impl crate::Options { + pub fn beit() -> Self { + Self::default() + .with_model_name("beit") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 224.into()) + .with_model_ixx(0, 3, 224.into()) + .with_image_mean(&[0.5, 0.5, 0.5]) + .with_image_std(&[0.5, 0.5, 0.5]) + .with_normalize(true) + .with_apply_softmax(true) + .with_class_names(&IMAGENET_NAMES_1K) + } + + pub fn beit_base() -> Self { + Self::beit().with_model_file("b.onnx") + } + + pub fn beit_large() -> Self { + Self::beit().with_model_file("l.onnx") + } +} diff --git a/src/models/beit/mod.rs b/src/models/beit/mod.rs new file mode 100644 index 0000000..1bf79df --- /dev/null +++ b/src/models/beit/mod.rs @@ -0,0 +1 @@ +mod config; diff --git a/src/models/blip.rs b/src/models/blip.rs deleted file mode 100644 index bb94c2f..0000000 --- a/src/models/blip.rs +++ /dev/null @@ -1,155 +0,0 @@ -use anyhow::Result; -use image::DynamicImage; -use ndarray::s; -use std::io::Write; -use tokenizers::Tokenizer; - -use crate::{ - Embedding, LogitsSampler, MinOptMax, Ops, Options, OrtEngine, TokenizerStream, Xs, X, Y, -}; - -#[derive(Debug)] -pub struct Blip { - pub textual: OrtEngine, - pub visual: OrtEngine, - pub height: MinOptMax, - pub width: MinOptMax, - pub batch_visual: MinOptMax, - pub batch_textual: MinOptMax, - tokenizer: TokenizerStream, -} - -impl Blip { - pub fn new(options_visual: Options, options_textual: Options) -> Result { - let mut visual = OrtEngine::new(&options_visual)?; - let mut textual = OrtEngine::new(&options_textual)?; - let (batch_visual, batch_textual, height, width) = ( - visual.batch().to_owned(), - textual.batch().to_owned(), - visual.height().to_owned(), - visual.width().to_owned(), - ); - - let tokenizer = 
options_textual - .tokenizer - .ok_or(anyhow::anyhow!("No tokenizer file found"))?; - let tokenizer = match Tokenizer::from_file(tokenizer) { - Err(err) => anyhow::bail!("Failed to build tokenizer: {:?}", err), - Ok(x) => x, - }; - - let tokenizer = TokenizerStream::new(tokenizer); - visual.dry_run()?; - textual.dry_run()?; - Ok(Self { - textual, - visual, - batch_visual, - batch_textual, - height, - width, - tokenizer, - }) - } - - pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result { - let xs_ = X::apply(&[ - Ops::Resize( - xs, - self.height.opt() as u32, - self.width.opt() as u32, - "Bilinear", - ), - Ops::Normalize(0., 255.), - Ops::Standardize( - &[0.48145466, 0.4578275, 0.40821073], - &[0.26862954, 0.2613026, 0.2757771], - 3, - ), - Ops::Nhwc2nchw, - ])?; - let ys = self.visual.run(Xs::from(xs_))?; - Ok(Y::default().with_embedding(Embedding::from(ys[0].to_owned()))) - } - - pub fn caption(&mut self, xs: &Y, prompt: Option<&str>, show: bool) -> Result> { - let mut ys: Vec = Vec::new(); - let image_embeds = match xs.embedding() { - Some(x) => X::from(x.data().to_owned()), - None => anyhow::bail!("No image embeddings found."), - }; - let image_embeds_attn_mask = X::ones(&[self.batch_visual(), image_embeds.dims()[1]]); - - let mut y_text = String::new(); - - // conditional - let mut input_ids = match prompt { - None => { - if show { - print!("[Unconditional]: "); - } - vec![0.0f32] - } - Some(prompt) => { - let encodings = match self.tokenizer.tokenizer().encode(prompt, false) { - Err(err) => anyhow::bail!("{}", err), - Ok(x) => x, - }; - let ids: Vec = encodings.get_ids().iter().map(|x| *x as f32).collect(); - if show { - print!("[Conditional]: {} ", prompt); - } - y_text.push_str(&format!("{} ", prompt)); - ids - } - }; - - let mut logits_sampler = LogitsSampler::new(); - loop { - let input_ids_nd = X::from(input_ids.to_owned()) - .insert_axis(0)? 
- .repeat(0, self.batch_textual())?; - let input_ids_attn_mask = X::ones(input_ids_nd.dims()); - - let y = self.textual.run(Xs::from(vec![ - input_ids_nd, - input_ids_attn_mask, - image_embeds.clone(), - image_embeds_attn_mask.clone(), - ]))?; // N, length, vocab_size - let y = y[0].slice(s!(0, -1.., ..)); - let logits = y.slice(s!(0, ..)).to_vec(); - let token_id = logits_sampler.decode(&logits)?; - input_ids.push(token_id as f32); - - // SEP - if token_id == 102 { - break; - } - - // streaming generation - if let Some(t) = self.tokenizer.next_token(token_id as u32)? { - y_text.push_str(&t); - if show { - print!("{t}"); - // std::thread::sleep(std::time::Duration::from_millis(5)); - } - std::io::stdout().flush()?; - } - } - if show { - println!(); - } - self.tokenizer.clear(); - ys.push(Y::default().with_texts(&[&y_text])); - Ok(ys) - } - - pub fn batch_visual(&self) -> usize { - self.batch_visual.opt() - } - - pub fn batch_textual(&self) -> usize { - self.batch_textual.opt() - } -} diff --git a/src/models/blip/config.rs b/src/models/blip/config.rs new file mode 100644 index 0000000..2248a9e --- /dev/null +++ b/src/models/blip/config.rs @@ -0,0 +1,34 @@ +/// Model configuration for `BLIP` +impl crate::Options { + pub fn blip() -> Self { + Self::default().with_model_name("blip").with_batch_size(1) + } + + #[allow(clippy::excessive_precision)] + pub fn blip_visual() -> Self { + Self::blip() + .with_model_kind(crate::Kind::Vision) + .with_model_ixx(0, 2, 384.into()) + .with_model_ixx(0, 3, 384.into()) + .with_image_mean(&[0.48145466, 0.4578275, 0.40821073]) + .with_image_std(&[0.26862954, 0.26130258, 0.27577711]) + .with_resize_filter("Bilinear") + .with_normalize(true) + } + + pub fn blip_textual() -> Self { + Self::blip().with_model_kind(crate::Kind::Language) + } + + pub fn blip_v1_base_caption_visual() -> Self { + Self::blip_visual() + .with_model_version(1.0.into()) + .with_model_file("v1-base-caption-visual.onnx") + } + + pub fn blip_v1_base_caption_textual() 
-> Self { + Self::blip_textual() + .with_model_version(1.0.into()) + .with_model_file("v1-base-caption-textual.onnx") + } +} diff --git a/src/models/blip/impl.rs b/src/models/blip/impl.rs new file mode 100644 index 0000000..b77a1a5 --- /dev/null +++ b/src/models/blip/impl.rs @@ -0,0 +1,130 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::{s, Axis}; + +use crate::{ + elapsed, + models::{BaseModelTextual, BaseModelVisual}, + LogitsSampler, Options, Ts, Xs, Ys, X, Y, +}; + +#[derive(Debug, Builder)] +pub struct Blip { + visual: BaseModelVisual, + textual: BaseModelTextual, + ts: Ts, + max_length: usize, + eos_token_id: u32, +} + +impl Blip { + pub fn new(options_visual: Options, options_textual: Options) -> Result { + let visual = BaseModelVisual::new(options_visual)?; + let textual = BaseModelTextual::new(options_textual)?; + let ts = Ts::merge(&[visual.engine().ts(), textual.engine().ts()]); + let max_length = 512; + let eos_token_id = 102; + + Ok(Self { + textual, + visual, + ts, + max_length, + eos_token_id, + }) + } + + pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result { + self.visual.encode(xs) + } + + pub fn encode_texts(&mut self, text: Option<&str>) -> Result>> { + let input_ids = self + .textual + .processor() + .encode_text_ids(text.unwrap_or_default(), false)?; + Ok(vec![input_ids.clone(); self.batch()]) + } + + pub fn forward(&mut self, images: &[DynamicImage], text: Option<&str>) -> Result { + let image_embeds = elapsed!("encode_images", self.ts, { self.encode_images(images)? }); + let ys = elapsed!("generate", self.ts, { self.generate(&image_embeds, text)? 
}); + + Ok(ys) + } + + pub fn generate(&mut self, image_embeds: &X, text: Option<&str>) -> Result { + // encode texts + let mut token_ids = self.encode_texts(text)?; + + // generate + let mut logits_sampler = LogitsSampler::new(); + let mut finished = vec![false; self.batch()]; + for _ in 0..self.max_length { + let input_ids_nd = token_ids + .iter() + .map(|tokens| X::from(tokens.clone()).insert_axis(0)) + .collect::, _>>()?; + + let input_ids_nd = X::concat(&input_ids_nd, 0)?; + let input_ids_attn_mask = X::ones(input_ids_nd.dims()); + + // decode + let outputs = self.textual.inference(Xs::from(vec![ + input_ids_nd, + input_ids_attn_mask, + image_embeds.clone(), + X::ones(&[self.visual().batch(), image_embeds.dims()[1]]), // image_embeds_attn_mask + ]))?; + + // decode each token for each batch + for (i, logit) in outputs[0].axis_iter(Axis(0)).enumerate() { + if !finished[i] { + let token_id = logits_sampler.decode( + &logit + .slice(s![-1, ..]) + .into_owned() + .into_raw_vec_and_offset() + .0, + )?; + if token_id == self.eos_token_id { + finished[i] = true; + } + token_ids[i].push(token_id as f32); + } else { + token_ids[i].push(self.eos_token_id as f32); + } + } + + if finished.iter().all(|&x| x) { + break; + } + } + + // batch decode + let texts = self.textual.processor().decode_tokens_batch( + &token_ids + .into_iter() + .map(|v| v.into_iter().map(|x| x as u32).collect::>()) + .collect::>>(), + true, + )?; + + let ys = texts + .into_iter() + .map(|x| Y::default().with_texts(&[x.into()])) + .collect::>() + .into(); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + pub fn batch(&self) -> usize { + self.visual.batch() as _ + } +} diff --git a/src/models/blip/mod.rs b/src/models/blip/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/blip/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/clip.rs b/src/models/clip.rs deleted file mode 100644 index 550e03e..0000000 
--- a/src/models/clip.rs +++ /dev/null @@ -1,107 +0,0 @@ -use anyhow::Result; -use image::DynamicImage; -use ndarray::Array2; -use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer}; - -use crate::{Embedding, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y}; - -#[derive(Debug)] -pub struct Clip { - pub textual: OrtEngine, - pub visual: OrtEngine, - pub height: MinOptMax, - pub width: MinOptMax, - pub batch_visual: MinOptMax, - pub batch_textual: MinOptMax, - tokenizer: Tokenizer, - context_length: usize, -} - -impl Clip { - pub fn new(options_visual: Options, options_textual: Options) -> Result { - let context_length = 77; - let mut visual = OrtEngine::new(&options_visual)?; - let mut textual = OrtEngine::new(&options_textual)?; - let (batch_visual, batch_textual, height, width) = ( - visual.inputs_minoptmax()[0][0].to_owned(), - textual.inputs_minoptmax()[0][0].to_owned(), - visual.inputs_minoptmax()[0][2].to_owned(), - visual.inputs_minoptmax()[0][3].to_owned(), - ); - - let tokenizer = options_textual - .tokenizer - .ok_or(anyhow::anyhow!("No tokenizer file found"))?; - - let mut tokenizer = match Tokenizer::from_file(tokenizer) { - Err(err) => anyhow::bail!("Failed to build tokenizer: {:?}", err), - Ok(x) => x, - }; - - tokenizer.with_padding(Some(PaddingParams { - strategy: PaddingStrategy::Fixed(context_length), - direction: PaddingDirection::Right, - pad_to_multiple_of: None, - pad_id: 0, - pad_type_id: 0, - pad_token: "[PAD]".to_string(), - })); - - visual.dry_run()?; - textual.dry_run()?; - - Ok(Self { - textual, - visual, - batch_visual, - batch_textual, - height, - width, - tokenizer, - context_length, - }) - } - - pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result { - let xs_ = X::apply(&[ - Ops::Resize( - xs, - self.height.opt() as u32, - self.width.opt() as u32, - "Bilinear", - ), - Ops::Normalize(0., 255.), - Ops::Standardize( - &[0.48145466, 0.4578275, 0.40821073], - &[0.26862954, 0.2613026, 0.2757771], - 3, - ), - 
Ops::Nhwc2nchw, - ])?; - let ys = self.visual.run(Xs::from(xs_))?; - Ok(Y::default().with_embedding(Embedding::from(ys[0].to_owned()))) - } - - pub fn encode_texts(&mut self, texts: &[String]) -> Result { - let encodings = match self.tokenizer.encode_batch(texts.to_owned(), false) { - Err(err) => anyhow::bail!("{:?}", err), - Ok(x) => x, - }; - let xs: Vec = encodings - .iter() - .flat_map(|i| i.get_ids().iter().map(|&b| b as f32)) - .collect(); - let xs = Array2::from_shape_vec((texts.len(), self.context_length), xs)?.into_dyn(); - let xs = X::from(xs); - let ys = self.textual.run(Xs::from(xs))?; - Ok(Y::default().with_embedding(Embedding::from(ys[0].to_owned()))) - } - - pub fn batch_visual(&self) -> usize { - self.batch_visual.opt() - } - - pub fn batch_textual(&self) -> usize { - self.batch_textual.opt() - } -} diff --git a/src/models/clip/config.rs b/src/models/clip/config.rs new file mode 100644 index 0000000..0454261 --- /dev/null +++ b/src/models/clip/config.rs @@ -0,0 +1,71 @@ +use crate::Kind; + +/// Model configuration for `CLIP` +impl crate::Options { + pub fn clip() -> Self { + Self::default() + .with_model_name("clip") + .with_model_ixx(0, 0, 1.into()) + } + + pub fn clip_visual() -> Self { + Self::clip() + .with_model_kind(Kind::Vision) + .with_model_ixx(0, 2, 224.into()) + .with_model_ixx(0, 3, 224.into()) + .with_image_mean(&[0.48145466, 0.4578275, 0.40821073]) + .with_image_std(&[0.26862954, 0.2613026, 0.2757771]) + } + + pub fn clip_textual() -> Self { + Self::clip() + .with_model_kind(Kind::Language) + .with_model_max_length(77) + } + + pub fn clip_vit_b16_visual() -> Self { + Self::clip_visual().with_model_file("vit-b16-visual.onnx") + } + + pub fn clip_vit_b16_textual() -> Self { + Self::clip_textual().with_model_file("vit-b16-textual.onnx") + } + + pub fn clip_vit_b32_visual() -> Self { + Self::clip_visual().with_model_file("vit-b32-visual.onnx") + } + + pub fn clip_vit_b32_textual() -> Self { + 
Self::clip_textual().with_model_file("vit-b32-textual.onnx") + } + + pub fn clip_vit_l14_visual() -> Self { + Self::clip_visual().with_model_file("vit-l14-visual.onnx") + } + + pub fn clip_vit_l14_textual() -> Self { + Self::clip_textual().with_model_file("vit-l14-textual.onnx") + } + + pub fn jina_clip_v1() -> Self { + Self::default() + .with_model_name("jina-clip-v1") + .with_model_ixx(0, 0, 1.into()) + } + + pub fn jina_clip_v1_visual() -> Self { + Self::jina_clip_v1() + .with_model_kind(Kind::Vision) + .with_model_ixx(0, 2, 224.into()) + .with_model_ixx(0, 3, 224.into()) + .with_image_mean(&[0.48145466, 0.4578275, 0.40821073]) + .with_image_std(&[0.26862954, 0.2613026, 0.2757771]) + .with_model_file("visual.onnx") + } + + pub fn jina_clip_v1_textual() -> Self { + Self::jina_clip_v1() + .with_model_kind(Kind::Language) + .with_model_file("textual.onnx") + } +} diff --git a/src/models/clip/impl.rs b/src/models/clip/impl.rs new file mode 100644 index 0000000..24f6abb --- /dev/null +++ b/src/models/clip/impl.rs @@ -0,0 +1,149 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::Array2; + +use crate::{elapsed, Engine, Options, Processor, Ts, Xs, X}; + +#[derive(Debug, Builder)] +pub struct ClipVisual { + engine: Engine, + height: usize, + width: usize, + batch: usize, + processor: Processor, + ts: Ts, +} + +impl ClipVisual { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&224.into()).opt(), + engine.try_width().unwrap_or(&224.into()).opt(), + engine.ts.clone(), + ); + let processor = options + .to_processor()? 
+ .with_image_width(width as _) + .with_image_height(height as _); + + Ok(Self { + engine, + height, + width, + batch, + processor, + ts, + }) + } + + pub fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + let x = self.processor.process_images(xs)?; + + Ok(x.into()) + } + + pub fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result { + let xs = elapsed!("visual-preprocess", self.ts, { self.preprocess(xs)? }); + let xs = elapsed!("visual-inference", self.ts, { self.inference(xs)? }); + let x = elapsed!("visual-postprocess", self.ts, { xs[0].to_owned() }); + + Ok(x) + } +} + +#[derive(Debug, Builder)] +pub struct ClipTextual { + engine: Engine, + batch: usize, + processor: Processor, + ts: Ts, +} + +impl ClipTextual { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let (batch, ts) = (engine.batch().opt(), engine.ts.clone()); + let processor = options.to_processor()?; + + Ok(Self { + engine, + batch, + processor, + ts, + }) + } + + pub fn preprocess(&self, xs: &[&str]) -> Result { + let encodings: Vec = self + .processor + .encode_texts_ids(xs, false)? // skip_special_tokens + .into_iter() + .flatten() + .collect(); + + let x: X = Array2::from_shape_vec((xs.len(), encodings.len() / xs.len()), encodings)? + .into_dyn() + .into(); + + Ok(x.into()) + } + + pub fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn encode_texts(&mut self, xs: &[&str]) -> Result { + let xs = elapsed!("textual-preprocess", self.ts, { self.preprocess(xs)? }); + let xs = elapsed!("textual-inference", self.ts, { self.inference(xs)? 
}); + let x = elapsed!("textual-postprocess", self.ts, { xs[0].to_owned() }); + + Ok(x) + } +} + +#[derive(Debug, Builder)] +pub struct Clip { + textual: ClipTextual, + visual: ClipVisual, + ts: Ts, +} + +impl Clip { + pub fn new(options_visual: Options, options_textual: Options) -> Result { + let visual = ClipVisual::new(options_visual)?; + let textual = ClipTextual::new(options_textual)?; + // let ts = Ts::merge(&[visual.engine().ts(), textual.engine().ts()]); + let ts = Ts::default(); + + Ok(Self { + textual, + visual, + ts, + }) + } + + pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result { + let x = elapsed!("encode_images", self.ts, { self.visual.encode_images(xs)? }); + Ok(x) + } + + pub fn encode_texts(&mut self, xs: &[&str]) -> Result { + let x = elapsed!("encode_texts", self.ts, { self.textual.encode_texts(xs)? }); + Ok(x) + } + + pub fn summary(&mut self) { + // self.ts.clear(); + // self.ts = Ts::merge(&[&self.ts, self.visual.ts(), self.textual.ts()]); + self.ts.summary(); + self.visual.ts().summary(); + self.textual.ts().summary(); + } +} diff --git a/src/models/clip/mod.rs b/src/models/clip/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/clip/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/convnext/config.rs b/src/models/convnext/config.rs new file mode 100644 index 0000000..784a4b2 --- /dev/null +++ b/src/models/convnext/config.rs @@ -0,0 +1,66 @@ +use crate::models::IMAGENET_NAMES_1K; + +/// Model configuration for `ConvNeXt` +impl crate::Options { + pub fn convnext() -> Self { + Self::default() + .with_model_name("convnext") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 224.into()) + .with_model_ixx(0, 3, 224.into()) + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_normalize(true) + .with_apply_softmax(true) + .with_class_names(&IMAGENET_NAMES_1K) + } + + pub fn 
convnext_v1_tiny() -> Self { + Self::convnext().with_model_file("v1-t.onnx") + } + + pub fn convnext_v1_small() -> Self { + Self::convnext().with_model_file("v1-s.onnx") + } + + pub fn convnext_v1_base() -> Self { + Self::convnext().with_model_file("v1-b.onnx") + } + + pub fn convnext_v1_large() -> Self { + Self::convnext().with_model_file("v1-l.onnx") + } + + pub fn convnext_v2_atto() -> Self { + Self::convnext().with_model_file("v2-a.onnx") + } + + pub fn convnext_v2_femto() -> Self { + Self::convnext().with_model_file("v2-f.onnx") + } + + pub fn convnext_v2_pico() -> Self { + Self::convnext().with_model_file("v2-p.onnx") + } + + pub fn convnext_v2_nano() -> Self { + Self::convnext().with_model_file("v2-n.onnx") + } + + pub fn convnext_v2_tiny() -> Self { + Self::convnext().with_model_file("v2-t.onnx") + } + + pub fn convnext_v2_small() -> Self { + Self::convnext().with_model_file("v2-s.onnx") + } + + pub fn convnext_v2_base() -> Self { + Self::convnext().with_model_file("v2-b.onnx") + } + + pub fn convnext_v2_large() -> Self { + Self::convnext().with_model_file("v2-l.onnx") + } +} diff --git a/src/models/convnext/mod.rs b/src/models/convnext/mod.rs new file mode 100644 index 0000000..1bf79df --- /dev/null +++ b/src/models/convnext/mod.rs @@ -0,0 +1 @@ +mod config; diff --git a/src/models/db/config.rs b/src/models/db/config.rs new file mode 100644 index 0000000..4843e54 --- /dev/null +++ b/src/models/db/config.rs @@ -0,0 +1,29 @@ +/// Model configuration for `DB` +impl crate::Options { + pub fn db() -> Self { + Self::default() + .with_model_name("db") + .with_model_ixx(0, 0, (1, 1, 8).into()) + .with_model_ixx(0, 2, (608, 960, 1600).into()) + .with_model_ixx(0, 3, (608, 960, 1600).into()) + .with_resize_mode(crate::ResizeMode::FitAdaptive) + .with_normalize(true) + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_class_confs(&[0.4]) + .with_min_width(5.0) + .with_min_height(12.0) + } + + pub fn ppocr_det_v3_ch() -> Self { 
+ Self::db().with_model_file("ppocr-v3-ch.onnx") + } + + pub fn ppocr_det_v4_ch() -> Self { + Self::db().with_model_file("ppocr-v4-ch.onnx") + } + + pub fn ppocr_det_v4_server_ch() -> Self { + Self::db().with_model_file("ppocr-v4-server-ch.onnx") + } +} diff --git a/src/models/db.rs b/src/models/db/impl.rs similarity index 64% rename from src/models/db.rs rename to src/models/db/impl.rs index aefa620..de1e793 100644 --- a/src/models/db.rs +++ b/src/models/db/impl.rs @@ -1,36 +1,45 @@ +use aksr::Builder; use anyhow::Result; use image::DynamicImage; use ndarray::Axis; -use crate::{DynConf, Mbr, MinOptMax, Ops, Options, OrtEngine, Polygon, Xs, X, Y}; +use crate::{elapsed, DynConf, Engine, Mbr, Ops, Options, Polygon, Processor, Ts, Xs, Ys, Y}; -#[derive(Debug)] +#[derive(Debug, Builder)] pub struct DB { - engine: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, + engine: Engine, + height: usize, + width: usize, + batch: usize, confs: DynConf, unclip_ratio: f32, binary_thresh: f32, min_width: f32, min_height: f32, + spec: String, + ts: Ts, + processor: Processor, } impl DB { pub fn new(options: Options) -> Result { - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.batch().to_owned(), - engine.height().to_owned(), - engine.width().to_owned(), + let engine = options.to_engine()?; + let (batch, height, width, ts, spec) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&960.into()).opt(), + engine.try_width().unwrap_or(&960.into()).opt(), + engine.ts.clone(), + engine.spec().to_owned(), ); - let confs = DynConf::new(&options.confs, 1); - let unclip_ratio = options.unclip_ratio; - let binary_thresh = 0.2; - let min_width = options.min_width.unwrap_or(0.); - let min_height = options.min_height.unwrap_or(0.); - engine.dry_run()?; + let processor = options + .to_processor()? 
+ .with_image_width(width as _) + .with_image_height(height as _); + let confs = DynConf::new(options.class_confs(), 1); + let binary_thresh = options.binary_thresh().unwrap_or(0.2); + let unclip_ratio = options.unclip_ratio().unwrap_or(1.5); + let min_width = options.min_width().unwrap_or(12.0); + let min_height = options.min_height().unwrap_or(5.0); Ok(Self { engine, @@ -42,29 +51,33 @@ impl DB { min_height, unclip_ratio, binary_thresh, + processor, + spec, + ts, }) } - pub fn run(&mut self, xs: &[DynamicImage]) -> Result> { - let xs_ = X::apply(&[ - Ops::Letterbox( - xs, - self.height() as u32, - self.width() as u32, - "Bilinear", - 114, - "auto", - false, - ), - Ops::Normalize(0., 255.), - Ops::Standardize(&[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225], 3), - Ops::Nhwc2nchw, - ])?; - let ys = self.engine.run(Xs::from(xs_))?; - self.postprocess(ys, xs) + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + Ok(self.processor.process_images(xs)?.into()) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? 
}); + + Ok(ys) } - pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result> { + pub fn summary(&mut self) { + self.ts.summary(); + } + + pub fn postprocess(&mut self, xs: Xs) -> Result { let mut ys = Vec::new(); for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() { let mut y_bbox = Vec::new(); @@ -72,13 +85,10 @@ impl DB { let mut y_mbrs: Vec = Vec::new(); // input image - let image_width = xs0[idx].width() as f32; - let image_height = xs0[idx].height() as f32; + let (image_height, image_width) = self.processor.image0s_size[idx]; // reshape - let h = luma.dim()[1]; - let w = luma.dim()[2]; - let (ratio, _, _) = Ops::scale_wh(image_width, image_height, w as f32, h as f32); + let ratio = self.processor.scale_factors_hw[idx][0]; let v = luma .into_owned() .into_raw_vec_and_offset() @@ -95,8 +105,8 @@ impl DB { let luma = Ops::resize_luma8_u8( &v, - self.width() as _, - self.height() as _, + self.width as _, + self.height as _, image_width as _, image_height as _, true, @@ -158,18 +168,7 @@ impl DB { .with_mbrs(&y_mbrs), ); } - Ok(ys) - } - - pub fn batch(&self) -> isize { - self.batch.opt() as _ - } - - pub fn width(&self) -> isize { - self.width.opt() as _ - } - pub fn height(&self) -> isize { - self.height.opt() as _ + Ok(ys.into()) } } diff --git a/src/models/db/mod.rs b/src/models/db/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/db/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/deit/config.rs b/src/models/deit/config.rs new file mode 100644 index 0000000..999319b --- /dev/null +++ b/src/models/deit/config.rs @@ -0,0 +1,30 @@ +use crate::models::IMAGENET_NAMES_1K; + +/// Model configuration for `DeiT` +impl crate::Options { + pub fn deit() -> Self { + Self::default() + .with_model_name("deit") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 224.into()) + .with_model_ixx(0, 3, 224.into()) + .with_image_mean(&[0.485, 0.456, 
0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_normalize(true) + .with_apply_softmax(true) + .with_class_names(&IMAGENET_NAMES_1K) + } + + pub fn deit_tiny_distill() -> Self { + Self::deit().with_model_file("t-distill.onnx") + } + + pub fn deit_small_distill() -> Self { + Self::deit().with_model_file("s-distill.onnx") + } + + pub fn deit_base_distill() -> Self { + Self::deit().with_model_file("b-distill.onnx") + } +} diff --git a/src/models/deit/mod.rs b/src/models/deit/mod.rs new file mode 100644 index 0000000..1bf79df --- /dev/null +++ b/src/models/deit/mod.rs @@ -0,0 +1 @@ +mod config; diff --git a/src/models/depth_anything.rs b/src/models/depth_anything.rs deleted file mode 100644 index 4573dfb..0000000 --- a/src/models/depth_anything.rs +++ /dev/null @@ -1,90 +0,0 @@ -use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y}; -use anyhow::Result; -use image::DynamicImage; -use ndarray::Axis; - -#[derive(Debug)] -pub struct DepthAnything { - engine: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, -} - -impl DepthAnything { - pub fn new(options: Options) -> Result { - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.batch().to_owned(), - engine.height().to_owned(), - engine.width().to_owned(), - ); - engine.dry_run()?; - - Ok(Self { - engine, - height, - width, - batch, - }) - } - - pub fn run(&mut self, xs: &[DynamicImage]) -> Result> { - let xs_ = X::apply(&[ - Ops::Resize( - xs, - self.height.opt() as u32, - self.width.opt() as u32, - "Lanczos3", - ), - Ops::Normalize(0., 255.), - Ops::Standardize(&[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225], 3), - Ops::Nhwc2nchw, - ])?; - let ys = self.engine.run(Xs::from(xs_))?; - self.postprocess(ys, xs) - } - - pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result> { - let mut ys: Vec = Vec::new(); - for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() { - let (w1, h1) = (xs0[idx].width(), xs0[idx].height()); - let v = 
luma.into_owned().into_raw_vec_and_offset().0; - let max_ = v.iter().max_by(|x, y| x.total_cmp(y)).unwrap(); - let min_ = v.iter().min_by(|x, y| x.total_cmp(y)).unwrap(); - let v = v - .iter() - .map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8) - .collect::>(); - - let luma = Ops::resize_luma8_u8( - &v, - self.width() as _, - self.height() as _, - w1 as _, - h1 as _, - false, - "Bilinear", - )?; - let luma: image::ImageBuffer, Vec<_>> = - match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) { - None => continue, - Some(x) => x, - }; - ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)])); - } - Ok(ys) - } - - pub fn batch(&self) -> isize { - self.batch.opt() as _ - } - - pub fn width(&self) -> isize { - self.width.opt() as _ - } - - pub fn height(&self) -> isize { - self.height.opt() as _ - } -} diff --git a/src/models/depth_anything/config.rs b/src/models/depth_anything/config.rs new file mode 100644 index 0000000..6133876 --- /dev/null +++ b/src/models/depth_anything/config.rs @@ -0,0 +1,40 @@ +/// Model configuration for `DepthAnything` +impl crate::Options { + pub fn depth_anything() -> Self { + Self::default() + .with_model_name("depth-anything") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, (384, 518, 1024).into()) + .with_model_ixx(0, 3, (384, 518, 1024).into()) + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_resize_filter("Lanczos3") + .with_normalize(true) + } + + pub fn depth_anything_s() -> Self { + Self::depth_anything().with_model_scale(crate::Scale::S) + } + + pub fn depth_anything_v1() -> Self { + Self::depth_anything().with_model_version(1.0.into()) + } + + pub fn depth_anything_v2() -> Self { + Self::depth_anything().with_model_version(2.0.into()) + } + + pub fn depth_anything_v1_small() -> Self { + Self::depth_anything_v1() + .with_model_scale(crate::Scale::S) + .with_model_file("v1-s.onnx") + } + + pub fn 
depth_anything_v2_small() -> Self { + Self::depth_anything_v2() + .with_model_scale(crate::Scale::S) + .with_model_file("v2-s.onnx") + } + // TODO +} diff --git a/src/models/depth_anything/impl.rs b/src/models/depth_anything/impl.rs new file mode 100644 index 0000000..7a9b61c --- /dev/null +++ b/src/models/depth_anything/impl.rs @@ -0,0 +1,98 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; + +use crate::{elapsed, Engine, Mask, Ops, Options, Processor, Ts, Xs, Ys, Y}; + +#[derive(Debug, Builder)] +pub struct DepthAnything { + engine: Engine, + height: usize, + width: usize, + batch: usize, + spec: String, + ts: Ts, + processor: Processor, +} + +impl DepthAnything { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let spec = engine.spec().to_string(); + + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&518.into()).opt(), + engine.try_width().unwrap_or(&518.into()).opt(), + engine.ts().clone(), + ); + + let processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + + Ok(Self { + engine, + height, + width, + batch, + spec, + ts, + processor, + }) + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + Ok(self.processor.process_images(xs)?.into()) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + fn postprocess(&mut self, xs: Xs) -> Result { + let mut ys: Vec = Vec::new(); + for (idx, luma) in xs[0].axis_iter(ndarray::Axis(0)).enumerate() { + // image size + let (h1, w1) = self.processor.image0s_size[idx]; + let v = luma.into_owned().into_raw_vec_and_offset().0; + let max_ = v.iter().max_by(|x, y| x.total_cmp(y)).unwrap(); + let min_ = v.iter().min_by(|x, y| x.total_cmp(y)).unwrap(); + let v = v + .iter() + .map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) 
as u8) + .collect::>(); + + let luma = Ops::resize_luma8_u8( + &v, + self.width() as _, + self.height() as _, + w1 as _, + h1 as _, + false, + "Bilinear", + )?; + let luma: image::ImageBuffer, Vec<_>> = + match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) { + None => continue, + Some(x) => x, + }; + ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)])); + } + + Ok(ys.into()) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? }); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } +} diff --git a/src/models/depth_anything/mod.rs b/src/models/depth_anything/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/depth_anything/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/depth_pro.rs b/src/models/depth_pro.rs deleted file mode 100644 index 26938f7..0000000 --- a/src/models/depth_pro.rs +++ /dev/null @@ -1,86 +0,0 @@ -use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y}; -use anyhow::Result; -use image::DynamicImage; -use ndarray::Axis; - -#[derive(Debug)] -pub struct DepthPro { - engine: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, -} - -impl DepthPro { - pub fn new(options: Options) -> Result { - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.batch().clone(), - engine.height().clone(), - engine.width().clone(), - ); - engine.dry_run()?; - - Ok(Self { - engine, - height, - width, - batch, - }) - } - - pub fn run(&mut self, xs: &[DynamicImage]) -> Result> { - let xs_ = X::apply(&[ - Ops::Resize( - xs, - self.height.opt() as u32, - self.width.opt() as u32, - "Bilinear", - ), - Ops::Normalize(0., 255.), - Ops::Standardize(&[0.5, 0.5, 0.5], &[0.5, 0.5, 0.5], 
3), - Ops::Nhwc2nchw, - ])?; - let ys = self.engine.run(Xs::from(xs_))?; - - self.postprocess(ys, xs) - } - - pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result> { - let (predicted_depth, _focallength_px) = (&xs["predicted_depth"], &xs["focallength_px"]); - let predicted_depth = predicted_depth.mapv(|x| 1. / x); - - let mut ys: Vec = Vec::new(); - for (idx, luma) in predicted_depth.axis_iter(Axis(0)).enumerate() { - let (w1, h1) = (xs0[idx].width(), xs0[idx].height()); - let v = luma.into_owned().into_raw_vec_and_offset().0; - let max_ = v.iter().max_by(|x, y| x.total_cmp(y)).unwrap(); - let min_ = v.iter().min_by(|x, y| x.total_cmp(y)).unwrap(); - let v = v - .iter() - .map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8) - .collect::>(); - - let luma = Ops::resize_luma8_u8( - &v, - self.width.opt() as _, - self.height.opt() as _, - w1 as _, - h1 as _, - false, - "Bilinear", - )?; - let luma: image::ImageBuffer, Vec<_>> = - match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) { - None => continue, - Some(x) => x, - }; - ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)])); - } - Ok(ys) - } - - pub fn batch(&self) -> isize { - self.batch.opt() as _ - } -} diff --git a/src/models/depth_pro/config.rs b/src/models/depth_pro/config.rs new file mode 100644 index 0000000..451682e --- /dev/null +++ b/src/models/depth_pro/config.rs @@ -0,0 +1,27 @@ +/// Model configuration for `DepthPro` +impl crate::Options { + pub fn depth_pro() -> Self { + Self::default() + .with_model_name("depth-pro") + .with_model_ixx(0, 0, 1.into()) // batch. 
Note: now only support batch_size = 1 + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 1536.into()) + .with_model_ixx(0, 3, 1536.into()) + .with_image_mean(&[0.5, 0.5, 0.5]) + .with_image_std(&[0.5, 0.5, 0.5]) + .with_resize_mode(crate::ResizeMode::FitExact) + .with_normalize(true) + } + + // pub fn depth_pro_q4f16() -> Self { + // Self::depth_pro().with_model_file("q4f16.onnx") + // } + + // pub fn depth_pro_fp16() -> Self { + // Self::depth_pro().with_model_file("fp16.onnx") + // } + + // pub fn depth_pro_bnb4() -> Self { + // Self::depth_pro().with_model_file("bnb4.onnx") + // } +} diff --git a/src/models/depth_pro/impl.rs b/src/models/depth_pro/impl.rs new file mode 100644 index 0000000..49518d3 --- /dev/null +++ b/src/models/depth_pro/impl.rs @@ -0,0 +1,99 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::Axis; + +use crate::{elapsed, Engine, Mask, Ops, Options, Processor, Ts, Xs, Ys, Y}; + +#[derive(Builder, Debug)] +pub struct DepthPro { + engine: Engine, + height: usize, + width: usize, + batch: usize, + ts: Ts, + spec: String, + processor: Processor, +} + +impl DepthPro { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let spec = engine.spec().to_string(); + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&512.into()).opt(), + engine.try_width().unwrap_or(&512.into()).opt(), + engine.ts().clone(), + ); + let processor = options + .to_processor()? 
+ .with_image_width(width as _) + .with_image_height(height as _); + + Ok(Self { + engine, + height, + width, + batch, + ts, + spec, + processor, + }) + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + Ok(self.processor.process_images(xs)?.into()) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + fn postprocess(&mut self, xs: Xs) -> Result { + let (predicted_depth, _focallength_px) = (&xs["predicted_depth"], &xs["focallength_px"]); + let predicted_depth = predicted_depth.mapv(|x| 1. / x); + + let mut ys: Vec = Vec::new(); + for (idx, luma) in predicted_depth.axis_iter(Axis(0)).enumerate() { + let (h1, w1) = self.processor.image0s_size[idx]; + let v = luma.into_owned().into_raw_vec_and_offset().0; + let max_ = v.iter().max_by(|x, y| x.total_cmp(y)).unwrap(); + let min_ = v.iter().min_by(|x, y| x.total_cmp(y)).unwrap(); + let v = v + .iter() + .map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8) + .collect::>(); + + let luma = Ops::resize_luma8_u8( + &v, + self.width as _, + self.height as _, + w1 as _, + h1 as _, + false, + "Bilinear", + )?; + let luma: image::ImageBuffer, Vec<_>> = + match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) { + None => continue, + Some(x) => x, + }; + ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)])); + } + + Ok(ys.into()) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? 
}); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } +} diff --git a/src/models/depth_pro/mod.rs b/src/models/depth_pro/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/depth_pro/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/dinov2.rs b/src/models/dinov2.rs deleted file mode 100644 index 46b676d..0000000 --- a/src/models/dinov2.rs +++ /dev/null @@ -1,161 +0,0 @@ -use crate::{Embedding, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y}; -use anyhow::Result; -use image::DynamicImage; -// use std::path::PathBuf; -// use usearch::ffi::{IndexOptions, MetricKind, ScalarKind}; - -#[derive(Debug)] -pub enum Model { - S, - B, -} - -#[derive(Debug)] -pub struct Dinov2 { - engine: OrtEngine, - pub height: MinOptMax, - pub width: MinOptMax, - pub batch: MinOptMax, - pub hidden_size: usize, -} - -impl Dinov2 { - pub fn new(options: Options) -> Result { - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.inputs_minoptmax()[0][0].to_owned(), - engine.inputs_minoptmax()[0][2].to_owned(), - engine.inputs_minoptmax()[0][3].to_owned(), - ); - let which = match options.onnx_path { - s if s.contains('b') => Model::B, - s if s.contains('s') => Model::S, - _ => todo!(), - }; - let hidden_size = match which { - Model::S => 384, - Model::B => 768, - }; - engine.dry_run()?; - - Ok(Self { - engine, - height, - width, - batch, - hidden_size, - }) - } - - pub fn run(&mut self, xs: &[DynamicImage]) -> Result { - let xs_ = X::apply(&[ - Ops::Resize( - xs, - self.height.opt() as u32, - self.width.opt() as u32, - "Lanczos3", - ), - Ops::Normalize(0., 255.), - Ops::Standardize( - &[0.48145466, 0.4578275, 0.40821073], - &[0.26862954, 0.2613026, 0.2757771], - 3, - ), - Ops::Nhwc2nchw, - ])?; - let ys = self.engine.run(Xs::from(xs_))?; - Ok(Y::default().with_embedding(Embedding::from(ys[0].to_owned()))) - } - - // pub fn build_index(&self, metric: Metric) -> Result { 
- // let metric = match metric { - // Metric::IP => MetricKind::IP, - // Metric::L2 => MetricKind::L2sq, - // Metric::Cos => MetricKind::Cos, - // }; - // let options = IndexOptions { - // metric, - // dimensions: self.hidden_size, - // quantization: ScalarKind::F16, - // ..Default::default() - // }; - // Ok(usearch::new_index(&options)?) - // } - - // pub fn query_from_folder( - // &mut self, - // qurey: &str, - // gallery: &str, - // metric: Metric, - // ) -> Result> { - // // load query - // let query = DataLoader::try_read(qurey)?; - // let query = self.run(&[query])?; - - // // build index & gallery - // let index = self.build_index(metric)?; - // let dl = DataLoader::default() - // .with_batch(self.batch.opt as usize) - // .load(gallery)?; - // let paths = dl.paths().to_owned(); - // index.reserve(paths.len())?; - - // // load feats - // for (idx, (x, _path)) in dl.enumerate() { - // let y = self.run(&x)?; - // index.add(idx as u64, &y.into_raw_vec())?; - // } - - // // output - // let matches = index.search(&query.into_raw_vec(), index.size())?; - // let mut results: Vec<(usize, f32, PathBuf)> = Vec::new(); - // matches - // .keys - // .into_iter() - // .zip(matches.distances) - // .for_each(|(k, score)| { - // results.push((k as usize, score, paths[k as usize].to_owned())); - // }); - - // Ok(results) - // } - - // pub fn query_from_vec( - // &mut self, - // qurey: &str, - // gallery: &[&str], - // metric: Metric, - // ) -> Result> { - // // load query - // let query = DataLoader::try_read(qurey)?; - // let query = self.run(&[query])?; - - // // build index & gallery - // let index = self.build_index(metric)?; - // index.reserve(gallery.len())?; - // let mut dl = DataLoader::default().with_batch(self.batch.opt as usize); - // gallery.iter().for_each(|x| { - // dl.load(x).unwrap(); - // }); - - // // load feats - // let paths = dl.paths().to_owned(); - // for (idx, (x, _path)) in dl.enumerate() { - // let y = self.run(&x)?; - // index.add(idx as u64, 
&y.into_raw_vec())?; - // } - - // // output - // let matches = index.search(&query.into_raw_vec(), index.size())?; - // let mut results: Vec<(usize, f32, PathBuf)> = Vec::new(); - // matches - // .keys - // .into_iter() - // .zip(matches.distances) - // .for_each(|(k, score)| { - // results.push((k as usize, score, paths[k as usize].to_owned())); - // }); - - // Ok(results) - // } -} diff --git a/src/models/dinov2/config.rs b/src/models/dinov2/config.rs new file mode 100644 index 0000000..abf7696 --- /dev/null +++ b/src/models/dinov2/config.rs @@ -0,0 +1,28 @@ +/// Model configuration for `DINOv2` +impl crate::Options { + pub fn dinov2() -> Self { + Self::default() + .with_model_name("dinov2") + .with_model_ixx(0, 0, (1, 1, 8).into()) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 224.into()) + .with_model_ixx(0, 3, 224.into()) + .with_resize_mode(crate::ResizeMode::FitExact) + .with_resize_filter("Lanczos3") + .with_normalize(true) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_image_mean(&[0.485, 0.456, 0.406]) + } + + pub fn dinov2_small() -> Self { + Self::dinov2() + .with_model_scale(crate::Scale::S) + .with_model_file("s.onnx") + } + + pub fn dinov2_base() -> Self { + Self::dinov2() + .with_model_scale(crate::Scale::B) + .with_model_file("b.onnx") + } +} diff --git a/src/models/dinov2/impl.rs b/src/models/dinov2/impl.rs new file mode 100644 index 0000000..de0897e --- /dev/null +++ b/src/models/dinov2/impl.rs @@ -0,0 +1,68 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; + +use crate::{elapsed, Engine, Options, Processor, Scale, Ts, Xs, X}; + +#[derive(Builder, Debug)] +pub struct DINOv2 { + engine: Engine, + height: usize, + width: usize, + batch: usize, + dim: usize, + ts: Ts, + processor: Processor, +} + +impl DINOv2 { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&384.into()).opt(), + 
engine.try_width().unwrap_or(&384.into()).opt(), + engine.ts.clone(), + ); + let dim = match options.model_scale() { + Some(Scale::S) => 384, + Some(Scale::B) => 768, + Some(Scale::L) => 1024, + Some(Scale::G) => 1536, + Some(x) => anyhow::bail!("Unsupported scale: {:?}", x), + None => anyhow::bail!("No model scale specified"), + }; + let processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + + Ok(Self { + engine, + height, + width, + batch, + dim, + ts, + processor, + }) + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + let x = self.processor.process_images(xs)?; + + Ok(x.into()) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result { + let xs = elapsed!("visual-preprocess", self.ts, { self.preprocess(xs)? }); + let xs = elapsed!("visual-inference", self.ts, { self.inference(xs)? }); + let x = elapsed!("visual-postprocess", self.ts, { xs[0].to_owned() }); + + Ok(x) + } +} diff --git a/src/models/dinov2/mod.rs b/src/models/dinov2/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/dinov2/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/fastvit/config.rs b/src/models/fastvit/config.rs new file mode 100644 index 0000000..c0eca50 --- /dev/null +++ b/src/models/fastvit/config.rs @@ -0,0 +1,74 @@ +use crate::models::IMAGENET_NAMES_1K; + +/// Model configuration for `FastViT` +impl crate::Options { + pub fn fastvit() -> Self { + Self::default() + .with_model_name("fastvit") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 224.into()) + .with_model_ixx(0, 3, 224.into()) + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_normalize(true) + .with_apply_softmax(true) + .with_class_names(&IMAGENET_NAMES_1K) + } + + pub fn fastvit_t8() -> Self { + 
Self::fastvit().with_model_file("t8.onnx") + } + + pub fn fastvit_t8_distill() -> Self { + Self::fastvit().with_model_file("t8-distill.onnx") + } + + pub fn fastvit_t12() -> Self { + Self::fastvit().with_model_file("t12.onnx") + } + + pub fn fastvit_t12_distill() -> Self { + Self::fastvit().with_model_file("t12-distill.onnx") + } + + pub fn fastvit_s12() -> Self { + Self::fastvit().with_model_file("s12.onnx") + } + + pub fn fastvit_s12_distill() -> Self { + Self::fastvit().with_model_file("s12-distill.onnx") + } + + pub fn fastvit_sa12() -> Self { + Self::fastvit().with_model_file("sa12.onnx") + } + + pub fn fastvit_sa12_distill() -> Self { + Self::fastvit().with_model_file("sa12-distill.onnx") + } + + pub fn fastvit_sa24() -> Self { + Self::fastvit().with_model_file("sa24.onnx") + } + + pub fn fastvit_sa24_distill() -> Self { + Self::fastvit().with_model_file("sa24-distill.onnx") + } + + pub fn fastvit_sa36() -> Self { + Self::fastvit().with_model_file("sa36.onnx") + } + + pub fn fastvit_sa36_distill() -> Self { + Self::fastvit().with_model_file("sa36-distill.onnx") + } + + pub fn fastvit_ma36() -> Self { + Self::fastvit().with_model_file("ma36.onnx") + } + + pub fn fastvit_ma36_distill() -> Self { + Self::fastvit().with_model_file("ma36-distill.onnx") + } +} diff --git a/src/models/fastvit/mod.rs b/src/models/fastvit/mod.rs new file mode 100644 index 0000000..1bf79df --- /dev/null +++ b/src/models/fastvit/mod.rs @@ -0,0 +1 @@ +mod config; diff --git a/src/models/florence2.rs b/src/models/florence2.rs deleted file mode 100644 index bcfb9b7..0000000 --- a/src/models/florence2.rs +++ /dev/null @@ -1,459 +0,0 @@ -use anyhow::Result; -use image::DynamicImage; -use ndarray::{s, Axis}; -use rayon::prelude::*; -use std::collections::BTreeMap; -use tokenizers::Tokenizer; - -use crate::{ - build_progress_bar, Bbox, LogitsSampler, MinOptMax, Ops, Options, OrtEngine, Polygon, - Quantizer, Task, Xs, X, Y, -}; - -#[derive(Debug)] -pub struct Florence2 { - pub vision_encoder: 
OrtEngine, - pub text_embed: OrtEngine, - pub encoder: OrtEngine, - pub decoder: OrtEngine, - pub decoder_merged: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, - tokenizer: Tokenizer, - max_length: usize, - quantizer: Quantizer, -} - -impl Florence2 { - pub fn new( - options_vision_encoder: Options, - options_text_embed: Options, - options_encoder: Options, - options_decoder: Options, - options_decoder_merged: Options, - ) -> Result { - let mut vision_encoder = OrtEngine::new(&options_vision_encoder)?; - let mut text_embed = OrtEngine::new(&options_text_embed)?; - let mut encoder = OrtEngine::new(&options_encoder)?; - let mut decoder = OrtEngine::new(&options_decoder)?; - let mut decoder_merged = OrtEngine::new(&options_decoder_merged)?; - let (batch, height, width) = ( - vision_encoder.batch().to_owned(), - vision_encoder.height().to_owned(), - vision_encoder.width().to_owned(), - ); - let tokenizer = options_text_embed - .tokenizer - .ok_or(anyhow::anyhow!("No tokenizer file found"))?; - let tokenizer = match Tokenizer::from_file(tokenizer) { - Err(err) => anyhow::bail!("Failed to build tokenizer: {:?}", err), - Ok(x) => x, - }; - - let quantizer = Quantizer::default(); - let max_length = 1024; - - // dry run - vision_encoder.dry_run()?; - text_embed.dry_run()?; - encoder.dry_run()?; - decoder.dry_run()?; - decoder_merged.dry_run()?; - - Ok(Self { - vision_encoder, - text_embed, - encoder, - decoder, - decoder_merged, - height, - width, - batch, - tokenizer, - max_length, - quantizer, - }) - } - - pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result { - let xs_ = X::apply(&[ - Ops::Resize( - xs, - self.height.opt() as u32, - self.width.opt() as u32, - "Bilinear", - ), - Ops::Normalize(0., 255.), - Ops::Standardize(&[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225], 3), - Ops::Nhwc2nchw, - ])?; - let ys = self.vision_encoder.run(Xs::from(xs_))?[0].to_owned(); - Ok(ys) - } - - pub fn run_with_tasks( - &mut self, - xs: 
&[DynamicImage], - tasks: &[Task], - ) -> Result>> { - let mut ys: BTreeMap> = BTreeMap::new(); - - // encode images - let image_embeddings = self.encode_images(xs)?; - - // note: the length of xs is not always equal to batch size - self.batch.update_opt(xs.len() as _); - - // build pb - let pb = build_progress_bar( - tasks.len() as u64, - "Working On", - None, - crate::PROGRESS_BAR_STYLE_CYAN_2, - )?; - - // tasks - for task in tasks.iter() { - pb.inc(1); - pb.set_message(format!("{:?}", task)); - - // construct prompt and encode - let input_ids = self - .encode_prompt(task)? - .insert_axis(0)? - .repeat(0, self.batch())?; - let text_embeddings = self.text_embed.run(Xs::from(input_ids))?[0].clone(); - - // run - let texts = self.run_batch(&image_embeddings, &text_embeddings)?; - - // tasks iteration - let ys_task = (0..self.batch()) - .into_par_iter() - .map(|batch| { - // image size - let image_width = xs[batch].width() as usize; - let image_height = xs[batch].height() as usize; - - // texts cleanup - let text = texts[batch] - .as_str() - .replace("", "") - .replace("", "") - .replace("", ""); - - // postprocess - let mut y = Y::default(); - if let Task::Caption(_) | Task::Ocr = task { - y = y.with_texts(&[&text]); - } else { - let elems = Self::loc_parse(&text)?; - match task { - Task::RegionToCategory(..) | Task::RegionToDescription(..) 
=> { - let text = elems[0][0].clone(); - y = y.with_texts(&[&text]); - } - Task::ObjectDetection - | Task::OpenSetDetection(_) - | Task::DenseRegionCaption - | Task::CaptionToPhraseGrounding(_) => { - let y_bboxes: Vec = elems - .par_iter() - .enumerate() - .flat_map(|(i, elem)| { - Self::process_bboxes( - &elem[1..], - &self.quantizer, - image_width, - image_height, - Some((&elem[0], i)), - ) - }) - .collect(); - y = y.with_bboxes(&y_bboxes); - } - Task::RegionProposal => { - let y_bboxes: Vec = Self::process_bboxes( - &elems[0], - &self.quantizer, - image_width, - image_height, - None, - ); - y = y.with_bboxes(&y_bboxes); - } - Task::ReferringExpressionSegmentation(_) - | Task::RegionToSegmentation(..) => { - let points = Self::process_polygons( - &elems[0], - &self.quantizer, - image_width, - image_height, - ); - y = y.with_polygons(&[Polygon::default() - .with_points(&points) - .with_id(0)]); - } - Task::OcrWithRegion => { - let y_polygons: Vec = elems - .par_iter() - .enumerate() - .map(|(i, elem)| { - let points = Self::process_polygons( - &elem[1..], - &self.quantizer, - image_width, - image_height, - ); - Polygon::default() - .with_name(&elem[0]) - .with_points(&points) - .with_id(i as _) - }) - .collect(); - y = y.with_polygons(&y_polygons); - } - _ => anyhow::bail!("Unsupported Florence2 task."), - }; - } - Ok(y) - }) - .collect::>>()?; - - ys.insert(task.clone(), ys_task); - } - - // update pb - pb.set_prefix("Completed"); - pb.set_message("Florence2 tasks"); - pb.set_style(indicatif::ProgressStyle::with_template( - crate::PROGRESS_BAR_STYLE_FINISH_2, - )?); - pb.finish(); - - Ok(ys) - } - - fn run_batch(&mut self, image_embeddings: &X, text_embeddings: &X) -> Result> { - // concate image_embeddings and prompt embeddings - let inputs_embeds = image_embeddings.clone().concatenate(text_embeddings, 1)?; - let attention_mask = X::ones(&[self.batch(), inputs_embeds.dims()[1]]); - - // encoder - let last_hidden_state = self.encoder.run(Xs::from(vec![ - 
attention_mask.clone(), - inputs_embeds.clone(), - ]))?[0] - .clone(); - - // decoder - let inputs_embeds = inputs_embeds.slice(s![.., -1.., ..]); - let inputs_embeds = X::from(inputs_embeds.to_owned().into_dyn()); - let mut decoder_outputs = self.decoder.run(Xs::from(vec![ - attention_mask.clone(), - last_hidden_state.clone(), - inputs_embeds, - ]))?; - - let encoder_k0 = decoder_outputs[3].clone(); - let encoder_v0 = decoder_outputs[4].clone(); - let encoder_k1 = decoder_outputs[7].clone(); - let encoder_v1 = decoder_outputs[8].clone(); - let encoder_k2 = decoder_outputs[11].clone(); - let encoder_v2 = decoder_outputs[12].clone(); - let encoder_k3 = decoder_outputs[15].clone(); - let encoder_v3 = decoder_outputs[16].clone(); - let encoder_k4 = decoder_outputs[19].clone(); - let encoder_v4 = decoder_outputs[20].clone(); - let encoder_k5 = decoder_outputs[23].clone(); - let encoder_v5 = decoder_outputs[24].clone(); - - let mut generated_tokens: Vec> = vec![vec![]; self.batch()]; - let mut finished = vec![false; self.batch()]; - - // save last batch tokens - let mut last_tokens: Vec = vec![0.; self.batch()]; - let mut logits_sampler = LogitsSampler::new(); - - // generate - for _ in 0..self.max_length { - let logits = &decoder_outputs["logits"]; - let decoder_k0 = &decoder_outputs[1]; - let decoder_v0 = &decoder_outputs[2]; - let decoder_k1 = &decoder_outputs[5]; - let decoder_v1 = &decoder_outputs[6]; - let decoder_k2 = &decoder_outputs[9]; - let decoder_v2 = &decoder_outputs[10]; - let decoder_k3 = &decoder_outputs[13]; - let decoder_v3 = &decoder_outputs[14]; - let decoder_k4 = &decoder_outputs[17]; - let decoder_v4 = &decoder_outputs[18]; - let decoder_k5 = &decoder_outputs[21]; - let decoder_v5 = &decoder_outputs[22]; - - // decode each token for each batch - for (i, logit) in logits.axis_iter(Axis(0)).enumerate() { - if !finished[i] { - let token_id = logits_sampler.decode( - &logit - .slice(s![-1, ..]) - .into_owned() - .into_raw_vec_and_offset() - .0, - )?; 
// - generated_tokens[i].push(token_id); - - // update last_tokens - last_tokens[i] = token_id as f32; - - if token_id == 2 { - finished[i] = true; - } - } - } - - // all finished? - if finished.iter().all(|&x| x) { - break; - } - - // next input text embedding - let next_tokens = X::from(last_tokens.clone()).insert_axis(1)?; - - // decode - let inputs_embeds = &self.text_embed.run(Xs::from(next_tokens))?[0].clone(); - let use_cache = X::ones(&[1]); - decoder_outputs = self.decoder_merged.run(Xs::from(vec![ - attention_mask.clone(), - last_hidden_state.clone(), - inputs_embeds.clone(), - decoder_k0.clone(), - decoder_v0.clone(), - encoder_k0.clone(), - encoder_v0.clone(), - decoder_k1.clone(), - decoder_v1.clone(), - encoder_k1.clone(), - encoder_v1.clone(), - decoder_k2.clone(), - decoder_v2.clone(), - encoder_k2.clone(), - encoder_v2.clone(), - decoder_k3.clone(), - decoder_v3.clone(), - encoder_k3.clone(), - encoder_v3.clone(), - decoder_k4.clone(), - decoder_v4.clone(), - encoder_k4.clone(), - encoder_v4.clone(), - decoder_k5.clone(), - decoder_v5.clone(), - encoder_k5.clone(), - encoder_v5.clone(), - use_cache, - ]))?; - } - - // batch decode - let texts = match self.tokenizer.decode_batch( - &generated_tokens - .iter() - .map(|tokens| tokens.as_slice()) - .collect::>(), - false, - ) { - Err(err) => anyhow::bail!("{:?}", err), - Ok(xs) => xs, - }; - - Ok(texts) - } - - pub fn encode_prompt(&self, task: &Task) -> Result { - let prompt = task.prompt_for_florence2()?; - let encodings = match self.tokenizer.encode(prompt, true) { - Err(err) => anyhow::bail!("{}", err), - Ok(x) => x, - }; - let ids: Vec = encodings.get_ids().iter().map(|x| *x as f32).collect(); - - Ok(X::from(ids)) - } - - fn process_polygons( - elems: &[String], - quantizer: &Quantizer, - image_width: usize, - image_height: usize, - ) -> Vec> { - elems - .par_chunks(2) - .map(|chunk| { - let coord: Vec<_> = chunk.iter().map(|s| s.parse::().unwrap()).collect(); - quantizer.dequantize(&coord, 
(image_width, image_height)) - }) - .collect() - } - - fn process_bboxes( - elems: &[String], - quantizer: &Quantizer, - image_width: usize, - image_height: usize, - class_name: Option<(&str, usize)>, - ) -> Vec { - elems - .par_chunks(4) - .enumerate() - .map(|(i, chunk)| { - let bbox: Vec<_> = chunk.iter().map(|s| s.parse::().unwrap()).collect(); - let dequantized_bbox = quantizer.dequantize(&bbox, (image_width, image_height)); - - let mut bbox = Bbox::default().with_xyxy( - dequantized_bbox[0].max(0.0f32).min(image_width as f32), - dequantized_bbox[1].max(0.0f32).min(image_height as f32), - dequantized_bbox[2], - dequantized_bbox[3], - ); - if let Some((class_name, i)) = class_name { - bbox = bbox.with_name(class_name).with_id(i as _); - } else { - bbox = bbox.with_id(i as _); - } - - bbox - }) - .collect() - } - - fn loc_parse(hay: &str) -> Result>> { - let pattern = r"(?i)(\d+)>)|(?[^<]+)"; - let re = regex::Regex::new(pattern)?; - let mut ys: Vec> = Vec::new(); - let mut y = Vec::new(); - - for cap in re.captures_iter(hay) { - if let Some(loc) = cap.name("coord") { - y.push(loc.as_str().to_string()); - } else if let Some(text) = cap.name("name") { - if !text.as_str().is_empty() { - if !y.is_empty() { - ys.push(y); - y = Vec::new(); - } - y.push(text.as_str().to_string()); - } - } - } - if !y.is_empty() { - ys.push(y); - } - Ok(ys) - } - - pub fn batch(&self) -> usize { - self.batch.opt() - } -} diff --git a/src/models/florence2/config.rs b/src/models/florence2/config.rs new file mode 100644 index 0000000..8ef74ac --- /dev/null +++ b/src/models/florence2/config.rs @@ -0,0 +1,59 @@ +/// Model configuration for `Florence2` +impl crate::Options { + pub fn florence2() -> Self { + Self::default() + .with_model_name("florence2") + .with_batch_size(1) + } + + pub fn florence2_visual() -> Self { + Self::florence2() + .with_model_kind(crate::Kind::Vision) + .with_model_ixx(0, 2, 768.into()) + .with_model_ixx(0, 3, 768.into()) + .with_image_mean(&[0.485, 0.456, 0.406]) 
+ .with_image_std(&[0.229, 0.224, 0.225]) + .with_resize_filter("Bilinear") + .with_normalize(true) + } + + pub fn florence2_textual() -> Self { + Self::florence2().with_model_kind(crate::Kind::Language) + } + + pub fn florence2_visual_base() -> Self { + Self::florence2_visual().with_model_scale(crate::Scale::B) + } + + pub fn florence2_textual_base() -> Self { + Self::florence2_textual().with_model_scale(crate::Scale::B) + } + + pub fn florence2_visual_large() -> Self { + Self::florence2_visual().with_model_scale(crate::Scale::L) + } + + pub fn florence2_textual_large() -> Self { + Self::florence2_textual().with_model_scale(crate::Scale::L) + } + + pub fn florence2_visual_encoder_base() -> Self { + Self::florence2_visual_base().with_model_file("base-vision-encoder.onnx") + } + + pub fn florence2_textual_embed_base() -> Self { + Self::florence2_textual_base().with_model_file("base-embed-tokens.onnx") + } + + pub fn florence2_texual_encoder_base() -> Self { + Self::florence2_textual_base().with_model_file("base-encoder.onnx") + } + + pub fn florence2_texual_decoder_base() -> Self { + Self::florence2_textual_base().with_model_file("base-decoder.onnx") + } + + pub fn florence2_texual_decoder_merged_base() -> Self { + Self::florence2_textual_base().with_model_file("base-decoder-merged.onnx") + } +} diff --git a/src/models/florence2/impl.rs b/src/models/florence2/impl.rs new file mode 100644 index 0000000..52b0a16 --- /dev/null +++ b/src/models/florence2/impl.rs @@ -0,0 +1,417 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::{s, Axis}; +use rayon::prelude::*; + +use crate::{ + elapsed, + models::{BaseModelTextual, BaseModelVisual, Quantizer}, + Bbox, LogitsSampler, Options, Polygon, Scale, Task, Ts, Xs, Ys, X, Y, +}; + +#[derive(Debug, Builder)] +pub struct Florence2 { + pub vision_encoder: BaseModelVisual, + pub text_embed: BaseModelTextual, + pub encoder: BaseModelTextual, + pub decoder: BaseModelTextual, + pub decoder_merged: 
BaseModelTextual, + ts: Ts, + quantizer: Quantizer, + max_length: usize, + eos_token_id: u32, + decoder_start_token_id: u32, + n_kvs: usize, +} + +impl Florence2 { + pub fn new( + options_vision_encoder: Options, + options_text_embed: Options, + options_encoder: Options, + options_decoder: Options, + options_decoder_merged: Options, + ) -> Result { + let vision_encoder = BaseModelVisual::new(options_vision_encoder)?; + let text_embed = BaseModelTextual::new(options_text_embed)?; + let encoder = BaseModelTextual::new(options_encoder)?; + let decoder = BaseModelTextual::new(options_decoder)?; + let decoder_merged = BaseModelTextual::new(options_decoder_merged)?; + let quantizer = Quantizer::default(); + let ts = Ts::merge(&[ + vision_encoder.engine().ts(), + text_embed.engine().ts(), + encoder.engine().ts(), + decoder.engine().ts(), + decoder_merged.engine().ts(), + ]); + let max_length = 1024; + let eos_token_id = 2; + let decoder_start_token_id = 2; + let n_kvs = match decoder.scale() { + Some(Scale::B) => 6, + Some(Scale::L) => 12, + _ => unimplemented!(), + }; + + Ok(Self { + vision_encoder, + text_embed, + encoder, + decoder, + decoder_merged, + max_length, + quantizer, + ts, + eos_token_id, + decoder_start_token_id, + n_kvs, + }) + } + + fn process_task(task: &Task, image_height: usize, image_width: usize) -> Task { + // region-related tasks + match task { + Task::RegionToSegmentation(x0, y0, x1, y1) => { + let xyxy = Quantizer::default() + .quantize(&[*x0, *y0, *x1, *y1], (image_width, image_height)); + Task::RegionToSegmentation(xyxy[0], xyxy[1], xyxy[2], xyxy[3]) + } + Task::RegionToCategory(x0, y0, x1, y1) => { + let xyxy = Quantizer::default() + .quantize(&[*x0, *y0, *x1, *y1], (image_width, image_height)); + Task::RegionToCategory(xyxy[0], xyxy[1], xyxy[2], xyxy[3]) + } + Task::RegionToDescription(x0, y0, x1, y1) => { + let xyxy = Quantizer::default() + .quantize(&[*x0, *y0, *x1, *y1], (image_width, image_height)); + Task::RegionToDescription(xyxy[0], 
xyxy[1], xyxy[2], xyxy[3]) + } + _ => *task, + } + } + + fn encode_text(&mut self, task: &Task, images: &[DynamicImage]) -> Result { + let xs = images + .par_iter() + .map(|im| { + let text = Self::process_task(task, im.height() as _, im.width() as _) + .prompt_for_florence2()?; + let ids = self.text_embed.processor().encode_text_ids(&text, true)?; + X::from(ids).insert_axis(0) + }) + .collect::, _>>()?; + let x = X::concat(&xs, 0)?; + let xs = self.text_embed.inference(x.into())?; + let x = xs[0].to_owned(); + + Ok(x) + } + + pub fn forward(&mut self, xs_visual: &[DynamicImage], x_textual: &Task) -> Result { + let visual_embeddings = elapsed!("visual-encode", self.ts, { + self.vision_encoder.encode(xs_visual)? + }); + + let textual_embedding = elapsed!("textual-encode", self.ts, { + self.encode_text(x_textual, xs_visual)? + }); + + let generated = elapsed!("generate-then-decode", self.ts, { + self.generate_then_decode(&visual_embeddings, &textual_embedding)? + }); + + let ys = elapsed!("postprocess", self.ts, { + self.postprocess(&generated, xs_visual, x_textual)? 
+ }); + + Ok(ys) + } + + // decode or postprocess, batch images and one text + fn generate_then_decode( + &mut self, + visual_embeddings: &X, + textual_embedding: &X, + ) -> Result> { + // concate image embeddings and prompt embeddings + let inputs_embeds = visual_embeddings + .clone() + .concatenate(textual_embedding, 1)?; + let attention_mask = X::ones(&[self.batch(), inputs_embeds.dims()[1]]); + + // encoder + let last_hidden_state = self.encoder.inference(Xs::from(vec![ + attention_mask.clone(), + inputs_embeds.clone(), + ]))?[0] + .clone(); + + // decoder + let inputs_embeds = inputs_embeds.slice(s![.., -1.., ..]); + let inputs_embeds = X::from(inputs_embeds.to_owned().into_dyn()); + let mut decoder_outputs = self.decoder.inference(Xs::from(vec![ + attention_mask.clone(), + last_hidden_state.clone(), + inputs_embeds, + ]))?; + + // encoder kvs + let encoder_kvs: Vec<_> = (3..4 * self.n_kvs) + .step_by(4) + .flat_map(|i| [i, i + 1]) + .map(|i| decoder_outputs[i].clone()) + .collect(); + + // token ids + let mut token_ids: Vec> = vec![vec![]; self.batch()]; + let mut finished = vec![false; self.batch()]; + let mut last_tokens: Vec = vec![0.; self.batch()]; + let mut logits_sampler = LogitsSampler::new(); + + // generate + for _ in 0..self.max_length { + let logits = &decoder_outputs[0]; + let decoder_kvs: Vec<_> = (1..(4 * self.n_kvs) - 2) + .step_by(4) + .flat_map(|i| [i, i + 1]) + .map(|i| decoder_outputs[i].clone()) + .collect(); + + // decode each token for each batch + for (i, logit) in logits.axis_iter(Axis(0)).enumerate() { + if !finished[i] { + let token_id = logits_sampler.decode( + &logit + .slice(s![-1, ..]) + .into_owned() + .into_raw_vec_and_offset() + .0, + )?; + if token_id == self.eos_token_id { + finished[i] = true; + } else { + token_ids[i].push(token_id); + } + // update + last_tokens[i] = token_id as f32; + } + } + + // all finished? 
+ if finished.iter().all(|&x| x) { + break; + } + + // decode + let next_tokens = X::from(last_tokens.clone()).insert_axis(1)?; + let inputs_embeds = &self.text_embed.inference(Xs::from(next_tokens))?[0].clone(); + let use_cache = X::ones(&[1]); + let mut xs = vec![ + attention_mask.clone(), + last_hidden_state.clone(), + inputs_embeds.clone(), + ]; + for i in 0..self.n_kvs { + xs.push(decoder_kvs[i * 2].clone()); + xs.push(decoder_kvs[i * 2 + 1].clone()); + xs.push(encoder_kvs[i * 2].clone()); + xs.push(encoder_kvs[i * 2 + 1].clone()); + } + xs.push(use_cache); + decoder_outputs = self.decoder_merged.inference(xs.into())?; + } + + // batch decode + let texts = self + .text_embed + .processor() + .decode_tokens_batch(&token_ids, false)?; + + Ok(texts) + } + + fn postprocess( + &mut self, + generated_text: &[String], + xs_visual: &[DynamicImage], + x_textual: &Task, + ) -> Result { + let mut ys = Vec::new(); + let ys_task = (0..self.batch()) + .into_par_iter() + .map(|batch| { + // image size + let image_width = xs_visual[batch].width() as usize; + let image_height = xs_visual[batch].height() as usize; + + // texts cleanup + let text = generated_text[batch] + .as_str() + .replace("", "") + .replace("", "") + .replace("", ""); + + // postprocess + let mut y = Y::default(); + if let Task::Caption(_) | Task::Ocr = x_textual { + y = y.with_texts(&[text.into()]); + } else { + let elems = Self::loc_parse(&text)?; + match x_textual { + Task::RegionToCategory(..) | Task::RegionToDescription(..) 
=> { + let text = elems[0][0].clone(); + y = y.with_texts(&[text.into()]); + } + Task::ObjectDetection + | Task::OpenSetDetection(_) + | Task::DenseRegionCaption + | Task::CaptionToPhraseGrounding(_) => { + let y_bboxes: Vec = elems + .par_iter() + .enumerate() + .flat_map(|(i, elem)| { + Self::process_bboxes( + &elem[1..], + &self.quantizer, + image_width, + image_height, + Some((&elem[0], i)), + ) + }) + .collect(); + y = y.with_bboxes(&y_bboxes); + } + Task::RegionProposal => { + let y_bboxes: Vec = Self::process_bboxes( + &elems[0], + &self.quantizer, + image_width, + image_height, + None, + ); + y = y.with_bboxes(&y_bboxes); + } + Task::ReferringExpressionSegmentation(_) + | Task::RegionToSegmentation(..) => { + let points = Self::process_polygons( + &elems[0], + &self.quantizer, + image_width, + image_height, + ); + y = y.with_polygons(&[Polygon::default() + .with_points(&points) + .with_id(0)]); + } + Task::OcrWithRegion => { + let y_polygons: Vec = elems + .par_iter() + .enumerate() + .map(|(i, elem)| { + let points = Self::process_polygons( + &elem[1..], + &self.quantizer, + image_width, + image_height, + ); + Polygon::default() + .with_name(&elem[0]) + .with_points(&points) + .with_id(i as _) + }) + .collect(); + y = y.with_polygons(&y_polygons); + } + _ => anyhow::bail!("Unsupported Florence2 task."), + }; + } + Ok(y) + }) + .collect::>>()?; + + ys.extend_from_slice(&ys_task); + + Ok(ys.into()) + } + + fn process_polygons( + elems: &[String], + quantizer: &Quantizer, + image_width: usize, + image_height: usize, + ) -> Vec> { + elems + .par_chunks(2) + .map(|chunk| { + let coord: Vec<_> = chunk.iter().map(|s| s.parse::().unwrap()).collect(); + quantizer.dequantize(&coord, (image_width, image_height)) + }) + .collect() + } + + fn process_bboxes( + elems: &[String], + quantizer: &Quantizer, + image_width: usize, + image_height: usize, + class_name: Option<(&str, usize)>, + ) -> Vec { + elems + .par_chunks(4) + .enumerate() + .map(|(i, chunk)| { + let bbox: 
Vec<_> = chunk.iter().map(|s| s.parse::().unwrap()).collect(); + let dequantized_bbox = quantizer.dequantize(&bbox, (image_width, image_height)); + + let mut bbox = Bbox::default().with_xyxy( + dequantized_bbox[0].max(0.0f32).min(image_width as f32), + dequantized_bbox[1].max(0.0f32).min(image_height as f32), + dequantized_bbox[2], + dequantized_bbox[3], + ); + if let Some((class_name, i)) = class_name { + bbox = bbox.with_name(class_name).with_id(i as _); + } else { + bbox = bbox.with_id(i as _); + } + + bbox + }) + .collect() + } + + fn loc_parse(hay: &str) -> Result>> { + let pattern = r"(?i)(\d+)>)|(?[^<]+)"; + let re = regex::Regex::new(pattern)?; + let mut ys: Vec> = Vec::new(); + let mut y = Vec::new(); + + for cap in re.captures_iter(hay) { + if let Some(loc) = cap.name("coord") { + y.push(loc.as_str().to_string()); + } else if let Some(text) = cap.name("name") { + if !text.as_str().is_empty() { + if !y.is_empty() { + ys.push(y); + y = Vec::new(); + } + y.push(text.as_str().to_string()); + } + } + } + if !y.is_empty() { + ys.push(y); + } + Ok(ys) + } + + pub fn batch(&self) -> usize { + self.vision_encoder.batch() as _ + } + + pub fn summary(&mut self) { + self.ts.summary(); + } +} diff --git a/src/models/florence2/mod.rs b/src/models/florence2/mod.rs new file mode 100644 index 0000000..5405ab1 --- /dev/null +++ b/src/models/florence2/mod.rs @@ -0,0 +1,6 @@ +mod config; +mod r#impl; +mod quantizer; + +pub use quantizer::Quantizer; +pub use r#impl::*; diff --git a/src/utils/quantizer.rs b/src/models/florence2/quantizer.rs similarity index 76% rename from src/utils/quantizer.rs rename to src/models/florence2/quantizer.rs index 1a3a6ac..615247d 100644 --- a/src/utils/quantizer.rs +++ b/src/models/florence2/quantizer.rs @@ -22,7 +22,7 @@ impl Quantizer { ((val as f64 + 0.5) * bin_size) as f32 } - fn quantize_internal(&self, input: &[f32], size: (usize, usize)) -> Vec { + fn quantize_internal(&self, input: &[usize], size: (usize, usize)) -> Vec { let (bins_w, 
bins_h) = self.bins; let (size_w, size_h) = size; @@ -31,14 +31,14 @@ impl Quantizer { match input.len() { 4 => vec![ - self.quantize_value(input[0], size_per_bin_w, bins_w), - self.quantize_value(input[1], size_per_bin_h, bins_h), - self.quantize_value(input[2], size_per_bin_w, bins_w), - self.quantize_value(input[3], size_per_bin_h, bins_h), + self.quantize_value(input[0] as f32, size_per_bin_w, bins_w), + self.quantize_value(input[1] as f32, size_per_bin_h, bins_h), + self.quantize_value(input[2] as f32, size_per_bin_w, bins_w), + self.quantize_value(input[3] as f32, size_per_bin_h, bins_h), ], 2 => vec![ - self.quantize_value(input[0], size_per_bin_w, bins_w), - self.quantize_value(input[1], size_per_bin_h, bins_h), + self.quantize_value(input[0] as f32, size_per_bin_w, bins_w), + self.quantize_value(input[1] as f32, size_per_bin_h, bins_h), ], _ => panic!( "Error: Unsupported input length: {} in Quantizer.", @@ -72,7 +72,7 @@ impl Quantizer { } } - pub fn quantize(&self, input: &[f32], size: (usize, usize)) -> Vec { + pub fn quantize(&self, input: &[usize], size: (usize, usize)) -> Vec { self.quantize_internal(input, size) } diff --git a/src/models/grounding_dino.rs b/src/models/grounding_dino.rs deleted file mode 100644 index 8700c91..0000000 --- a/src/models/grounding_dino.rs +++ /dev/null @@ -1,245 +0,0 @@ -use crate::{Bbox, DynConf, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y}; -use anyhow::Result; -use image::DynamicImage; -use ndarray::{s, Array, Axis}; -use rayon::prelude::*; -use tokenizers::{Encoding, Tokenizer}; - -#[derive(Debug)] -pub struct GroundingDINO { - pub engine: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, - tokenizer: Tokenizer, - pub context_length: usize, - confs_visual: DynConf, - confs_textual: DynConf, -} - -impl GroundingDINO { - pub fn new(options: Options) -> Result { - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.inputs_minoptmax()[0][0].to_owned(), - 
engine.inputs_minoptmax()[0][2].to_owned(), - engine.inputs_minoptmax()[0][3].to_owned(), - ); - let context_length = options.context_length.unwrap_or(256); - // let special_tokens = ["[CLS]", "[SEP]", ".", "?"]; - let tokenizer = options - .tokenizer - .ok_or(anyhow::anyhow!("No tokenizer file found"))?; - let tokenizer = match Tokenizer::from_file(tokenizer) { - Err(err) => anyhow::bail!("Failed to build tokenizer: {:?}", err), - Ok(x) => x, - }; - let confs_visual = DynConf::new(&options.confs, 1); - let confs_textual = DynConf::new(&options.confs, 1); - - engine.dry_run()?; - - Ok(Self { - engine, - batch, - height, - width, - tokenizer, - context_length, - confs_visual, - confs_textual, - }) - } - - pub fn run(&mut self, xs: &[DynamicImage], texts: &[&str]) -> Result> { - // image embeddings - let image_embeddings = X::apply(&[ - Ops::Letterbox( - xs, - self.height() as u32, - self.width() as u32, - "CatmullRom", - 114, - "auto", - false, - ), - Ops::Normalize(0., 255.), - Ops::Standardize(&[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225], 3), - Ops::Nhwc2nchw, - ])?; - - // encoding - let text = Self::parse_texts(texts); - let encoding = match self.tokenizer.encode(text, true) { - Err(err) => anyhow::bail!("{}", err), - Ok(x) => x, - }; - let tokens = encoding.get_tokens(); - - // input_ids - let input_ids = X::from( - encoding - .get_ids() - .iter() - .map(|&x| x as f32) - .collect::>(), - ) - .insert_axis(0)? - .repeat(0, self.batch() as usize)?; - - // token_type_ids - let token_type_ids = X::zeros(&[self.batch() as usize, tokens.len()]); - - // attention_mask - let attention_mask = X::ones(&[self.batch() as usize, tokens.len()]); - - // position_ids - let position_ids = X::from( - encoding - .get_tokens() - .iter() - .map(|x| if x == "." { 1. } else { 0. }) - .collect::>(), - ) - .insert_axis(0)? - .repeat(0, self.batch() as usize)?; - - // text_self_attention_masks - let text_self_attention_masks = Self::gen_text_self_attention_masks(&encoding)? 
- .insert_axis(0)? - .repeat(0, self.batch() as usize)?; - - // run - let ys = self.engine.run(Xs::from(vec![ - image_embeddings, - input_ids, - attention_mask, - position_ids, - token_type_ids, - text_self_attention_masks, - ]))?; - - // post-process - self.postprocess(ys, xs, tokens) - } - - fn postprocess(&self, xs: Xs, xs0: &[DynamicImage], tokens: &[String]) -> Result> { - let ys: Vec = xs["logits"] - .axis_iter(Axis(0)) - .into_par_iter() - .enumerate() - .filter_map(|(idx, logits)| { - let image_width = xs0[idx].width() as f32; - let image_height = xs0[idx].height() as f32; - let ratio = - (self.width() as f32 / image_width).min(self.height() as f32 / image_height); - - let y_bboxes: Vec = logits - .axis_iter(Axis(0)) - .into_par_iter() - .enumerate() - .filter_map(|(i, clss)| { - let (class_id, &conf) = clss - .mapv(|x| 1. / ((-x).exp() + 1.)) - .iter() - .enumerate() - .max_by(|a, b| a.1.total_cmp(b.1))?; - - if conf < self.conf_visual() { - return None; - } - - let bbox = xs["boxes"].slice(s![idx, i, ..]).mapv(|x| x / ratio); - let cx = bbox[0] * self.width() as f32; - let cy = bbox[1] * self.height() as f32; - let w = bbox[2] * self.width() as f32; - let h = bbox[3] * self.height() as f32; - let x = cx - w / 2.; - let y = cy - h / 2.; - let x = x.max(0.0).min(image_width); - let y = y.max(0.0).min(image_height); - - Some( - Bbox::default() - .with_xywh(x, y, w, h) - .with_id(class_id as _) - .with_name(&tokens[class_id]) - .with_confidence(conf), - ) - }) - .collect(); - - if !y_bboxes.is_empty() { - Some(Y::default().with_bboxes(&y_bboxes)) - } else { - None - } - }) - .collect(); - Ok(ys) - } - - fn parse_texts(texts: &[&str]) -> String { - let mut y = String::new(); - for text in texts.iter() { - if !text.is_empty() { - y.push_str(&format!("{} . ", text)); - } - } - y - } - - fn gen_text_self_attention_masks(encoding: &Encoding) -> Result { - let mut vs = encoding - .get_tokens() - .iter() - .map(|x| if x == "." { 1. } else { 0. 
}) - .collect::>(); - - let n = vs.len(); - vs[0] = 1.; - vs[n - 1] = 1.; - let mut ys = Array::zeros((n, n)).into_dyn(); - let mut i_last = -1; - for (i, &v) in vs.iter().enumerate() { - if v == 0. { - if i_last == -1 { - i_last = i as isize; - } else { - i_last = -1; - } - } else if v == 1. { - if i_last == -1 { - ys.slice_mut(s![i, i]).fill(1.); - } else { - ys.slice_mut(s![i_last as _..i + 1, i_last as _..i + 1]) - .fill(1.); - } - i_last = -1; - } else { - continue; - } - } - Ok(X::from(ys)) - } - - pub fn conf_visual(&self) -> f32 { - self.confs_visual[0] - } - - pub fn conf_textual(&self) -> f32 { - self.confs_textual[0] - } - - pub fn batch(&self) -> isize { - self.batch.opt() as _ - } - - pub fn width(&self) -> isize { - self.width.opt() as _ - } - - pub fn height(&self) -> isize { - self.height.opt() as _ - } -} diff --git a/src/models/grounding_dino/config.rs b/src/models/grounding_dino/config.rs new file mode 100644 index 0000000..4c54ee0 --- /dev/null +++ b/src/models/grounding_dino/config.rs @@ -0,0 +1,22 @@ +/// Model configuration for `GroundingDino` +impl crate::Options { + pub fn grounding_dino() -> Self { + Self::default() + .with_model_name("grounding-dino") + .with_model_kind(crate::Kind::VisionLanguage) + .with_model_ixx(0, 0, 1.into()) // TODO: current onnx model does not support bs > 1 + .with_model_ixx(0, 2, 800.into()) // TODO: matters + .with_model_ixx(0, 3, 1200.into()) // TODO: matters + .with_resize_mode(crate::ResizeMode::FitAdaptive) + .with_resize_filter("CatmullRom") + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_normalize(true) + .with_class_confs(&[0.4]) + .with_text_confs(&[0.3]) + } + + pub fn grounding_dino_tiny() -> Self { + Self::grounding_dino().with_model_file("swint-ogc.onnx") + } +} diff --git a/src/models/grounding_dino/impl.rs b/src/models/grounding_dino/impl.rs new file mode 100644 index 0000000..46e8fc8 --- /dev/null +++ b/src/models/grounding_dino/impl.rs @@ -0,0 +1,223 
@@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::{s, Array, Axis}; +use rayon::prelude::*; + +use crate::{elapsed, Bbox, DynConf, Engine, Options, Processor, Ts, Xs, Ys, X, Y}; + +#[derive(Builder, Debug)] +pub struct GroundingDINO { + pub engine: Engine, + height: usize, + width: usize, + batch: usize, + confs_visual: DynConf, + confs_textual: DynConf, + class_names: Vec, + tokens: Vec, + token_ids: Vec, + ts: Ts, + processor: Processor, + spec: String, +} + +impl GroundingDINO { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let spec = engine.spec().to_string(); + + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&800.into()).opt(), + engine.try_width().unwrap_or(&1200.into()).opt(), + engine.ts().clone(), + ); + let processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + let confs_visual = DynConf::new(options.class_confs(), 1); + let confs_textual = DynConf::new(options.text_confs(), 1); + + let class_names = Self::parse_texts( + &options + .text_names + .expect("No class names specified!") + .iter() + .map(|x| x.as_str()) + .collect::>(), + ); + let token_ids = processor.encode_text_ids(&class_names, true)?; + let tokens = processor.encode_text_tokens(&class_names, true)?; + let class_names = tokens.clone(); + + Ok(Self { + engine, + batch, + height, + width, + confs_visual, + confs_textual, + class_names, + token_ids, + tokens, + ts, + processor, + spec, + }) + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + // encode images + let image_embeddings = self.processor.process_images(xs)?; + + // encode texts + let tokens_f32 = self + .tokens + .iter() + .map(|x| if x == "." { 1. } else { 0. }) + .collect::>(); + + // input_ids + let input_ids = X::from(self.token_ids.clone()) + .insert_axis(0)? 
+ .repeat(0, self.batch)?; + + // token_type_ids + let token_type_ids = X::zeros(&[self.batch, tokens_f32.len()]); + + // attention_mask + let attention_mask = X::ones(&[self.batch, tokens_f32.len()]); + + // text_self_attention_masks + let text_self_attention_masks = Self::gen_text_self_attention_masks(&tokens_f32)? + .insert_axis(0)? + .repeat(0, self.batch)?; + + // position_ids + let position_ids = X::from(tokens_f32).insert_axis(0)?.repeat(0, self.batch)?; + + // inputs + let xs = Xs::from(vec![ + image_embeddings, + input_ids, + attention_mask, + position_ids, + token_type_ids, + text_self_attention_masks, + ]); + + Ok(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? }); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + fn postprocess(&self, xs: Xs) -> Result { + let ys: Vec = xs["logits"] + .axis_iter(Axis(0)) + .into_par_iter() + .enumerate() + .filter_map(|(idx, logits)| { + let (image_height, image_width) = self.processor.image0s_size[idx]; + let ratio = self.processor.scale_factors_hw[idx][0]; + + let y_bboxes: Vec = logits + .axis_iter(Axis(0)) + .into_par_iter() + .enumerate() + .filter_map(|(i, clss)| { + let (class_id, &conf) = clss + .mapv(|x| 1. 
/ ((-x).exp() + 1.)) + .iter() + .enumerate() + .max_by(|a, b| a.1.total_cmp(b.1))?; + + if conf < self.confs_visual[0] { + return None; + } + + let bbox = xs["boxes"].slice(s![idx, i, ..]).mapv(|x| x / ratio); + let cx = bbox[0] * self.width as f32; + let cy = bbox[1] * self.height as f32; + let w = bbox[2] * self.width as f32; + let h = bbox[3] * self.height as f32; + let x = cx - w / 2.; + let y = cy - h / 2.; + let x = x.max(0.0).min(image_width as _); + let y = y.max(0.0).min(image_height as _); + + Some( + Bbox::default() + .with_xywh(x, y, w, h) + .with_id(class_id as _) + .with_name(&self.class_names[class_id]) + .with_confidence(conf), + ) + }) + .collect(); + + if !y_bboxes.is_empty() { + Some(Y::default().with_bboxes(&y_bboxes)) + } else { + None + } + }) + .collect(); + + Ok(ys.into()) + } + + fn parse_texts(texts: &[&str]) -> String { + let mut y = String::new(); + for text in texts.iter() { + if !text.is_empty() { + y.push_str(&format!("{} . ", text)); + } + } + y + } + + fn gen_text_self_attention_masks(tokens: &[f32]) -> Result { + let mut vs = tokens.to_vec(); + let n = vs.len(); + vs[0] = 1.; + vs[n - 1] = 1.; + let mut ys = Array::zeros((n, n)).into_dyn(); + let mut i_last = -1; + for (i, &v) in vs.iter().enumerate() { + if v == 0. { + if i_last == -1 { + i_last = i as isize; + } else { + i_last = -1; + } + } else if v == 1. 
{ + if i_last == -1 { + ys.slice_mut(s![i, i]).fill(1.); + } else { + ys.slice_mut(s![i_last as _..i + 1, i_last as _..i + 1]) + .fill(1.); + } + i_last = -1; + } else { + continue; + } + } + Ok(X::from(ys)) + } +} diff --git a/src/models/grounding_dino/mod.rs b/src/models/grounding_dino/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/grounding_dino/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/image_classifier.rs b/src/models/image_classifier.rs new file mode 100644 index 0000000..25ccfaa --- /dev/null +++ b/src/models/image_classifier.rs @@ -0,0 +1,125 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::Axis; +use rayon::prelude::*; + +use crate::{elapsed, DynConf, Engine, Options, Prob, Processor, Ts, Xs, Ys, Y}; + +#[derive(Debug, Builder)] +pub struct ImageClassifier { + engine: Engine, + height: usize, + width: usize, + batch: usize, + apply_softmax: bool, + ts: Ts, + processor: Processor, + confs: DynConf, + nc: usize, + names: Vec, + spec: String, +} + +impl TryFrom for ImageClassifier { + type Error = anyhow::Error; + + fn try_from(options: Options) -> Result { + let engine = options.to_engine()?; + let spec = engine.spec().to_string(); + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&224.into()).opt(), + engine.try_width().unwrap_or(&224.into()).opt(), + engine.ts().clone(), + ); + let processor = options + .to_processor()? 
+ .with_image_width(width as _) + .with_image_height(height as _); + let (nc, names) = match (options.nc(), options.class_names()) { + (Some(nc), Some(names)) => { + if nc != names.len() { + anyhow::bail!( + "The length of the input class names: {} is inconsistent with the number of classes: {}.", + names.len(), + nc + ); + } + (nc, names.to_vec()) + } + (Some(nc), None) => ( + nc, + (0..nc).map(|x| format!("# {}", x)).collect::>(), + ), + (None, Some(names)) => (names.len(), names.to_vec()), + (None, None) => { + anyhow::bail!("Neither class names nor class numbers were specified."); + } + }; + let confs = DynConf::new(options.class_confs(), nc); + let apply_softmax = options.apply_softmax.unwrap_or_default(); + + Ok(Self { + engine, + height, + width, + batch, + nc, + ts, + spec, + processor, + confs, + names, + apply_softmax, + }) + } +} + +impl ImageClassifier { + pub fn summary(&mut self) { + self.ts.summary(); + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + let x = self.processor.process_images(xs)?; + + Ok(x.into()) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? 
}); + + Ok(ys) + } + + fn postprocess(&self, xs: Xs) -> Result { + let ys: Ys = xs[0] + .axis_iter(Axis(0)) + .into_par_iter() + .filter_map(|logits| { + let logits = if self.apply_softmax { + let exps = logits.mapv(|x| x.exp()); + let stds = exps.sum_axis(Axis(0)); + exps / stds + } else { + logits.into_owned() + }; + let probs = Prob::default() + .with_probs(&logits.into_raw_vec_and_offset().0) + .with_names(&self.names.iter().map(|x| x.as_str()).collect::>()); + + Some(Y::default().with_probs(probs)) + }) + .collect::>() + .into(); + + Ok(ys) + } +} diff --git a/src/models/kind.rs b/src/models/kind.rs new file mode 100644 index 0000000..4519427 --- /dev/null +++ b/src/models/kind.rs @@ -0,0 +1,18 @@ +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +pub enum Kind { + // Do we really need this? + Vision, + Language, + VisionLanguage, +} + +impl std::fmt::Display for Kind { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let x = match self { + Self::Vision => "visual", + Self::Language => "textual", + Self::VisionLanguage => "vl", + }; + write!(f, "{}", x) + } +} diff --git a/src/models/labels.rs b/src/models/labels.rs new file mode 100644 index 0000000..0415615 --- /dev/null +++ b/src/models/labels.rs @@ -0,0 +1,1155 @@ +pub const COCO_SKELETONS_16: [(usize, usize); 16] = [ + (0, 1), + (0, 2), + (1, 3), + (2, 4), + (5, 6), + (5, 11), + (6, 12), + (11, 12), + (5, 7), + (6, 8), + (7, 9), + (8, 10), + (11, 13), + (12, 14), + (13, 15), + (14, 16), +]; + +pub const COCO_KEYPOINTS_NAMES_17: [&str; 17] = [ + "nose", + "left_eye", + "right_eye", + "left_ear", + "right_ear", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", + "left_hip", + "right_hip", + "left_knee", + "right_knee", + "left_ankle", + "right_ankle", +]; + +pub const COCO_CLASS_NAMES_80: [&str; 80] = [ + "person", + "bicycle", + "car", + "motorcycle", + "airplane", + "bus", + "train", + "truck", + "boat", + "traffic light", + 
"fire hydrant", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "backpack", + "umbrella", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "couch", + "potted plant", + "bed", + "dining table", + "toilet", + "tv", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush", +]; + +pub const BODY_PARTS_NAMES_28: [&str; 28] = [ + "Background", + "Apparel", + "Face Neck", + "Hair", + "Left Foot", + "Left Hand", + "Left Lower Arm", + "Left Lower Leg", + "Left Shoe", + "Left Sock", + "Left Upper Arm", + "Left Upper Leg", + "Lower Clothing", + "Right Foot", + "Right Hand", + "Right Lower Arm", + "Right Lower Leg", + "Right Shoe", + "Right Sock", + "Right Upper Arm", + "Right Upper Leg", + "Torso", + "Upper Clothing", + "Lower Lip", + "Upper Lip", + "Lower Teeth", + "Upper Teeth", + "Tongue", +]; + +pub const IMAGENET_NAMES_1K: [&str; 1000] = [ + "tench, Tinca tinca", + "goldfish, Carassius auratus", + "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", + "tiger shark, Galeocerdo cuvieri", + "hammerhead, hammerhead shark", + "electric ray, crampfish, numbfish, torpedo", + "stingray", + "cock", + "hen", + "ostrich, Struthio camelus", + "brambling, Fringilla montifringilla", + "goldfinch, Carduelis carduelis", + "house finch, linnet, Carpodacus mexicanus", + "junco, snowbird", + "indigo bunting, indigo finch, indigo bird, Passerina 
cyanea", + "robin, American robin, Turdus migratorius", + "bulbul", + "jay", + "magpie", + "chickadee", + "water ouzel, dipper", + "kite", + "bald eagle, American eagle, Haliaeetus leucocephalus", + "vulture", + "great grey owl, great gray owl, Strix nebulosa", + "European fire salamander, Salamandra salamandra", + "common newt, Triturus vulgaris", + "eft", + "spotted salamander, Ambystoma maculatum", + "axolotl, mud puppy, Ambystoma mexicanum", + "bullfrog, Rana catesbeiana", + "tree frog, tree-frog", + "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", + "loggerhead, loggerhead turtle, Caretta caretta", + "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", + "mud turtle", + "terrapin", + "box turtle, box tortoise", + "banded gecko", + "common iguana, iguana, Iguana iguana", + "American chameleon, anole, Anolis carolinensis", + "whiptail, whiptail lizard", + "agama", + "frilled lizard, Chlamydosaurus kingi", + "alligator lizard", + "Gila monster, Heloderma suspectum", + "green lizard, Lacerta viridis", + "African chameleon, Chamaeleo chamaeleon", + "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", + "African crocodile, Nile crocodile, Crocodylus niloticus", + "American alligator, Alligator mississipiensis", + "triceratops", + "thunder snake, worm snake, Carphophis amoenus", + "ringneck snake, ring-necked snake, ring snake", + "hognose snake, puff adder, sand viper", + "green snake, grass snake", + "king snake, kingsnake", + "garter snake, grass snake", + "water snake", + "vine snake", + "night snake, Hypsiglena torquata", + "boa constrictor, Constrictor constrictor", + "rock python, rock snake, Python sebae", + "Indian cobra, Naja naja", + "green mamba", + "sea snake", + "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", + "diamondback, diamondback rattlesnake, Crotalus adamanteus", + "sidewinder, horned rattlesnake, Crotalus cerastes", + "trilobite", + "harvestman, daddy 
longlegs, Phalangium opilio", + "scorpion", + "black and gold garden spider, Argiope aurantia", + "barn spider, Araneus cavaticus", + "garden spider, Aranea diademata", + "black widow, Latrodectus mactans", + "tarantula", + "wolf spider, hunting spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse, partridge, Bonasa umbellus", + "prairie chicken, prairie grouse, prairie fowl", + "peacock", + "quail", + "partridge", + "African grey, African gray, Psittacus erithacus", + "macaw", + "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "drake", + "red-breasted merganser, Mergus serrator", + "goose", + "black swan, Cygnus atratus", + "tusker", + "echidna, spiny anteater, anteater", + "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", + "wallaby, brush kangaroo", + "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", + "wombat", + "jellyfish", + "sea anemone, anemone", + "brain coral", + "flatworm, platyhelminth", + "nematode, nematode worm, roundworm", + "conch", + "snail", + "slug", + "sea slug, nudibranch", + "chiton, coat-of-mail shell, sea cradle, polyplacophore", + "chambered nautilus, pearly nautilus, nautilus", + "Dungeness crab, Cancer magister", + "rock crab, Cancer irroratus", + "fiddler crab", + "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", + "American lobster, Northern lobster, Maine lobster, Homarus americanus", + "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", + "crayfish, crawfish, crawdad, crawdaddy", + "hermit crab", + "isopod", + "white stork, Ciconia ciconia", + "black stork, Ciconia nigra", + "spoonbill", + "flamingo", + "little blue heron, Egretta caerulea", + "American egret, great white heron, Egretta albus", + "bittern", + "crane", + "limpkin, Aramus pictus", + "European 
gallinule, Porphyrio porphyrio", + "American coot, marsh hen, mud hen, water hen, Fulica americana", + "bustard", + "ruddy turnstone, Arenaria interpres", + "red-backed sandpiper, dunlin, Erolia alpina", + "redshank, Tringa totanus", + "dowitcher", + "oystercatcher, oyster catcher", + "pelican", + "king penguin, Aptenodytes patagonica", + "albatross, mollymawk", + "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", + "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", + "dugong, Dugong dugon", + "sea lion", + "Chihuahua", + "Japanese spaniel", + "Maltese dog, Maltese terrier, Maltese", + "Pekinese, Pekingese, Peke", + "shih-Tzu", + "Blenheim spaniel", + "papillon", + "toy terrier", + "Rhodesian ridgeback", + "Afghan hound, Afghan", + "basset, basset hound", + "beagle", + "bloodhound, sleuthhound", + "bluetick", + "black-and-tan coonhound", + "Walker hound, Walker foxhound", + "English foxhound", + "redbone", + "borzoi, Russian wolfhound", + "Irish wolfhound", + "Italian greyhound", + "whippet", + "Ibizan hound, Ibizan Podenco", + "Norwegian elkhound, elkhound", + "otterhound, otter hound", + "saluki, gazelle hound", + "scottish deerhound, deerhound", + "Weimaraner", + "staffordshire bullterrier, Staffordshire bull terrier", + "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", + "Bedlington terrier", + "Border terrier", + "Kerry blue terrier", + "Irish terrier", + "Norfolk terrier", + "Norwich terrier", + "Yorkshire terrier", + "wire-haired fox terrier", + "Lakeland terrier", + "sealyham terrier, Sealyham", + "Airedale, Airedale terrier", + "cairn, cairn terrier", + "Australian terrier", + "Dandie Dinmont, Dandie Dinmont terrier", + "Boston bull, Boston terrier", + "miniature schnauzer", + "giant schnauzer", + "standard schnauzer", + "scotch terrier, Scottish terrier, Scottie", + "Tibetan terrier, chrysanthemum dog", + "silky terrier, Sydney silky", + "soft-coated wheaten 
terrier", + "West Highland white terrier", + "Lhasa, Lhasa apso", + "flat-coated retriever", + "curly-coated retriever", + "golden retriever", + "Labrador retriever", + "Chesapeake Bay retriever", + "German short-haired pointer", + "vizsla, Hungarian pointer", + "English setter", + "Irish setter, red setter", + "Gordon setter", + "Brittany spaniel", + "clumber, clumber spaniel", + "English springer, English springer spaniel", + "Welsh springer spaniel", + "cocker spaniel, English cocker spaniel, cocker", + "sussex spaniel", + "Irish water spaniel", + "kuvasz", + "schipperke", + "groenendael", + "malinois", + "briard", + "kelpie", + "komondor", + "Old English sheepdog, bobtail", + "shetland sheepdog, Shetland sheep dog, Shetland", + "collie", + "Border collie", + "Bouvier des Flandres, Bouviers des Flandres", + "Rottweiler", + "German shepherd, German shepherd dog, German police dog, alsatian", + "Doberman, Doberman pinscher", + "miniature pinscher", + "Greater Swiss Mountain dog", + "Bernese mountain dog", + "Appenzeller", + "EntleBucher", + "boxer", + "bull mastiff", + "Tibetan mastiff", + "French bulldog", + "Great Dane", + "saint Bernard, St Bernard", + "Eskimo dog, husky", + "malamute, malemute, Alaskan malamute", + "siberian husky", + "dalmatian, coach dog, carriage dog", + "affenpinscher, monkey pinscher, monkey dog", + "basenji", + "pug, pug-dog", + "Leonberg", + "Newfoundland, Newfoundland dog", + "Great Pyrenees", + "samoyed, Samoyede", + "Pomeranian", + "chow, chow chow", + "keeshond", + "Brabancon griffon", + "Pembroke, Pembroke Welsh corgi", + "Cardigan, Cardigan Welsh corgi", + "toy poodle", + "miniature poodle", + "standard poodle", + "Mexican hairless", + "timber wolf, grey wolf, gray wolf, Canis lupus", + "white wolf, Arctic wolf, Canis lupus tundrarum", + "red wolf, maned wolf, Canis rufus, Canis niger", + "coyote, prairie wolf, brush wolf, Canis latrans", + "dingo, warrigal, warragal, Canis dingo", + "dhole, Cuon alpinus", + "African hunting dog, 
hyena dog, Cape hunting dog, Lycaon pictus", + "hyena, hyaena", + "red fox, Vulpes vulpes", + "kit fox, Vulpes macrotis", + "Arctic fox, white fox, Alopex lagopus", + "grey fox, gray fox, Urocyon cinereoargenteus", + "tabby, tabby cat", + "tiger cat", + "Persian cat", + "siamese cat, Siamese", + "Egyptian cat", + "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", + "lynx, catamount", + "leopard, Panthera pardus", + "snow leopard, ounce, Panthera uncia", + "jaguar, panther, Panthera onca, Felis onca", + "lion, king of beasts, Panthera leo", + "tiger, Panthera tigris", + "cheetah, chetah, Acinonyx jubatus", + "brown bear, bruin, Ursus arctos", + "American black bear, black bear, Ursus americanus, Euarctos americanus", + "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus", + "sloth bear, Melursus ursinus, Ursus ursinus", + "mongoose", + "meerkat, mierkat", + "tiger beetle", + "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", + "ground beetle, carabid beetle", + "long-horned beetle, longicorn, longicorn beetle", + "leaf beetle, chrysomelid", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant, emmet, pismire", + "grasshopper, hopper", + "cricket", + "walking stick, walkingstick, stick insect", + "cockroach, roach", + "mantis, mantid", + "cicada, cicala", + "leafhopper", + "lacewing, lacewing fly", + "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", + "damselfly", + "admiral", + "ringlet, ringlet butterfly", + "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", + "cabbage butterfly", + "sulphur butterfly, sulfur butterfly", + "lycaenid, lycaenid butterfly", + "starfish, sea star", + "sea urchin", + "sea cucumber, holothurian", + "wood rabbit, cottontail, cottontail rabbit", + "hare", + "Angora, Angora rabbit", + "hamster", + "porcupine, hedgehog", + "fox squirrel, eastern fox squirrel, Sciurus niger", + 
"marmot", + "beaver", + "guinea pig, Cavia cobaya", + "sorrel", + "zebra", + "hog, pig, grunter, squealer, Sus scrofa", + "wild boar, boar, Sus scrofa", + "warthog", + "hippopotamus, hippo, river horse, Hippopotamus amphibius", + "ox", + "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", + "bison", + "ram, tup", + "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", + "ibex, Capra ibex", + "hartebeest", + "impala, Aepyceros melampus", + "gazelle", + "Arabian camel, dromedary, Camelus dromedarius", + "llama", + "weasel", + "mink", + "polecat, fitch, foulmart, foumart, Mustela putorius", + "black-footed ferret, ferret, Mustela nigripes", + "otter", + "skunk, polecat, wood pussy", + "badger", + "armadillo", + "three-toed sloth, ai, Bradypus tridactylus", + "orangutan, orang, orangutang, Pongo pygmaeus", + "gorilla, Gorilla gorilla", + "chimpanzee, chimp, Pan troglodytes", + "gibbon, Hylobates lar", + "siamang, Hylobates syndactylus, Symphalangus syndactylus", + "guenon, guenon monkey", + "patas, hussar monkey, Erythrocebus patas", + "baboon", + "macaque", + "langur", + "colobus, colobus monkey", + "proboscis monkey, Nasalis larvatus", + "marmoset", + "capuchin, ringtail, Cebus capucinus", + "howler monkey, howler", + "titi, titi monkey", + "spider monkey, Ateles geoffroyi", + "squirrel monkey, Saimiri sciureus", + "Madagascar cat, ring-tailed lemur, Lemur catta", + "indri, indris, Indri indri, Indri brevicaudatus", + "Indian elephant, Elephas maximus", + "African elephant, Loxodonta africana", + "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", + "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", + "barracouta, snoek", + "eel", + "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", + "rock beauty, Holocanthus tricolor", + "anemone fish", + "sturgeon", + "gar, garfish, garpike, billfish, Lepisosteus osseus", + "lionfish", + "puffer, pufferfish, 
blowfish, globefish", + "abacus", + "abaya", + "academic gown, academic robe, judge's robe", + "accordion, piano accordion, squeeze box", + "acoustic guitar", + "aircraft carrier, carrier, flattop, attack aircraft carrier", + "airliner", + "airship, dirigible", + "altar", + "ambulance", + "amphibian, amphibious vehicle", + "analog clock", + "apiary, bee house", + "apron", + "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", + "assault rifle, assault gun", + "backpack, back pack, knapsack, packsack, rucksack, haversack", + "bakery, bakeshop, bakehouse", + "balance beam, beam", + "balloon", + "ballpoint, ballpoint pen, ballpen, Biro", + "Band Aid", + "banjo", + "bannister, banister, balustrade, balusters, handrail", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel, cask", + "barrow, garden cart, lawn cart, wheelbarrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "bathing cap, swimming cap", + "bath towel", + "bathtub, bathing tub, bath, tub", + "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", + "beacon, lighthouse, beacon light, pharos", + "beaker", + "bearskin, busby, shako", + "beer bottle", + "beer glass", + "bell cote, bell cot", + "bib", + "bicycle-built-for-two, tandem bicycle, tandem", + "bikini, two-piece", + "binder, ring-binder", + "binoculars, field glasses, opera glasses", + "birdhouse", + "boathouse", + "bobsled, bobsleigh, bob", + "bolo tie, bolo, bola tie, bola", + "bonnet, poke bonnet", + "bookcase", + "bookshop, bookstore, bookstall", + "bottlecap", + "bow", + "bow tie, bow-tie, bowtie", + "brass, memorial tablet, plaque", + "brassiere, bra, bandeau", + "breakwater, groin, groyne, mole, bulwark, seawall, jetty", + "breastplate, aegis, egis", + "broom", + "bucket, pail", + "buckle", + "bulletproof vest", + "bullet train, bullet", + "butcher shop, meat market", + "cab, hack, taxi, taxicab", + "caldron, cauldron", + 
"candle, taper, wax light", + "cannon", + "canoe", + "can opener, tin opener", + "cardigan", + "car mirror", + "carousel, carrousel, merry-go-round, roundabout, whirligig", + "carpenter's kit, tool kit", + "carton", + "car wheel", + "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello, violoncello", + "cellular telephone, cellular phone, cellphone, cell, mobile phone", + "chain", + "chainlink fence", + "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", + "chain saw, chainsaw", + "chest", + "chiffonier, commode", + "chime, bell, gong", + "china cabinet, china closet", + "Christmas stocking", + "church, church building", + "cinema, movie theater, movie theatre, movie house, picture palace", + "cleaver, meat cleaver, chopper", + "cliff dwelling", + "cloak", + "clog, geta, patten, sabot", + "cocktail shaker", + "coffee mug", + "coffeepot", + "coil, spiral, volute, whorl, helix", + "combination lock", + "computer keyboard, keypad", + "confectionery, confectionary, candy store", + "container ship, containership, container vessel", + "convertible", + "corkscrew, bottle screw", + "cornet, horn, trumpet, trump", + "cowboy boot", + "cowboy hat, ten-gallon hat", + "cradle", + "crane", + "crash helmet", + "crate", + "crib, cot", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam, dike, dyke", + "desk", + "desktop computer", + "dial telephone, dial phone", + "diaper, nappy, napkin", + "digital clock", + "digital watch", + "dining table, board", + "dishrag, dishcloth", + "dishwasher, dish washer, dishwashing machine", + "disk brake, disc brake", + "dock, dockage, docking facility", + "dogsled, dog sled, dog sleigh", + "dome", + "doormat, welcome mat", + "drilling platform, offshore rig", + "drum, membranophone, tympan", + "drumstick", + "dumbbell", + "Dutch oven", + "electric 
fan, blower", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso maker", + "face powder", + "feather boa, boa", + "file, file cabinet, filing cabinet", + "fireboat", + "fire engine, fire truck", + "fire screen, fireguard", + "flagpole, flagstaff", + "flute, transverse flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster", + "freight car", + "French horn, horn", + "frying pan, frypan, skillet", + "fur coat", + "garbage truck, dustcart", + "gasmask, respirator, gas helmet", + "gas pump, gasoline pump, petrol pump, island dispenser", + "goblet", + "go-kart", + "golf ball", + "golfcart, golf cart", + "gondola", + "gong, tam-tam", + "gown", + "grand piano, grand", + "greenhouse, nursery, glasshouse", + "grille, radiator grille", + "grocery store, grocery, food market, market", + "guillotine", + "hair slide", + "hair spray", + "half track", + "hammer", + "hamper", + "hand blower, blow dryer, blow drier, hair dryer, hair drier", + "hand-held computer, hand-held microcomputer", + "handkerchief, hankie, hanky, hankey", + "hard disc, hard disk, fixed disk", + "harmonica, mouth organ, harp, mouth harp", + "harp", + "harvester, reaper", + "hatchet", + "holster", + "home theater, home theatre", + "honeycomb", + "hook, claw", + "hoopskirt, crinoline", + "horizontal bar, high bar", + "horse cart, horse-cart", + "hourglass", + "iPod", + "iron, smoothing iron", + "jack-o'-lantern", + "jean, blue jean, denim", + "jeep, landrover", + "jersey, T-shirt, tee shirt", + "jigsaw puzzle", + "jinrikisha, ricksha, rickshaw", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat, laboratory coat", + "ladle", + "lampshade, lamp shade", + "laptop, laptop computer", + "lawn mower, mower", + "lens cap, lens cover", + "letter opener, paper knife, paperknife", + "library", + "lifeboat", + "lighter, light, igniter, ignitor", + "limousine, limo", + "liner, ocean liner", + "lipstick, lip 
rouge", + "Loafer", + "lotion", + "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", + "loupe, jeweler's loupe", + "lumbermill, sawmill", + "magnetic compass", + "mailbag, postbag", + "mailbox, letter box", + "maillot", + "maillot, tank suit", + "manhole cover", + "maraca", + "marimba, xylophone", + "mask", + "matchstick", + "maypole", + "maze, labyrinth", + "measuring cup", + "medicine chest, medicine cabinet", + "megalith, megalithic structure", + "microphone, mike", + "microwave, microwave oven", + "military uniform", + "milk can", + "minibus", + "miniskirt, mini", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home, manufactured home", + "Model T", + "modem", + "monastery", + "monitor", + "moped", + "mortar", + "mortarboard", + "mosque", + "mosquito net", + "motor scooter, scooter", + "mountain bike, all-terrain bike, off-roader", + "mountain tent", + "mouse, computer mouse", + "mousetrap", + "moving van", + "muzzle", + "nail", + "neck brace", + "necklace", + "nipple", + "notebook, notebook computer", + "obelisk", + "oboe, hautboy, hautbois", + "ocarina, sweet potato", + "odometer, hodometer, mileometer, milometer", + "oil filter", + "organ, pipe organ", + "oscilloscope, scope, cathode-ray oscilloscope, CRO", + "overskirt", + "oxcart", + "oxygen mask", + "packet", + "paddle, boat paddle", + "paddlewheel, paddle wheel", + "padlock", + "paintbrush", + "pajama, pyjama, pj's, jammies", + "palace", + "panpipe, pandean pipe, syrinx", + "paper towel", + "parachute, chute", + "parallel bars, bars", + "park bench", + "parking meter", + "passenger car, coach, carriage", + "patio, terrace", + "pay-phone, pay-station", + "pedestal, plinth, footstall", + "pencil box, pencil case", + "pencil sharpener", + "perfume, essence", + "Petri dish", + "photocopier", + "pick, plectrum, plectron", + "pickelhaube", + "picket fence, paling", + "pickup, pickup truck", + "pier", + "piggy bank, penny bank", + "pill bottle", + "pillow", + "ping-pong 
ball", + "pinwheel", + "pirate, pirate ship", + "pitcher, ewer", + "plane, carpenter's plane, woodworking plane", + "planetarium", + "plastic bag", + "plate rack", + "plow, plough", + "plunger, plumber's helper", + "Polaroid camera, Polaroid Land camera", + "pole", + "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", + "poncho", + "pool table, billiard table, snooker table", + "pop bottle, soda bottle", + "pot, flowerpot", + "potter's wheel", + "power drill", + "prayer rug, prayer mat", + "printer", + "prison, prison house", + "projectile, missile", + "projector", + "puck, hockey puck", + "punching bag, punch bag, punching ball, punchball", + "purse", + "quill, quill pen", + "quilt, comforter, comfort, puff", + "racer, race car, racing car", + "racket, racquet", + "radiator", + "radio, wireless", + "radio telescope, radio reflector", + "rain barrel", + "recreational vehicle, RV, R.V.", + "reel", + "reflex camera", + "refrigerator, icebox", + "remote control, remote", + "restaurant, eating house, eating place, eatery", + "revolver, six-gun, six-shooter", + "rifle", + "rocking chair, rocker", + "rotisserie", + "rubber eraser, rubber, pencil eraser", + "rugby ball", + "rule, ruler", + "running shoe", + "safe", + "safety pin", + "saltshaker, salt shaker", + "sandal", + "sarong", + "sax, saxophone", + "scabbard", + "scale, weighing machine", + "school bus", + "schooner", + "scoreboard", + "screen, CRT screen", + "screw", + "screwdriver", + "seat belt, seatbelt", + "sewing machine", + "shield, buckler", + "shoe shop, shoe-shop, shoe store", + "shoji", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "ski mask", + "sleeping bag", + "slide rule, slipstick", + "sliding door", + "slot, one-armed bandit", + "snorkel", + "snowmobile", + "snowplow, snowplough", + "soap dispenser", + "soccer ball", + "sock", + "solar dish, solar collector, solar furnace", + "sombrero", + "soup bowl", + "space bar", + 
"space heater", + "space shuttle", + "spatula", + "speedboat", + "spider web, spider's web", + "spindle", + "sports car, sport car", + "spotlight, spot", + "stage", + "steam locomotive", + "steel arch bridge", + "steel drum", + "stethoscope", + "stole", + "stone wall", + "stopwatch, stop watch", + "stove", + "strainer", + "streetcar, tram, tramcar, trolley, trolley car", + "stretcher", + "studio couch, day bed", + "stupa, tope", + "submarine, pigboat, sub, U-boat", + "suit, suit of clothes", + "sundial", + "sunglass", + "sunglasses, dark glasses, shades", + "sunscreen, sunblock, sun blocker", + "suspension bridge", + "swab, swob, mop", + "sweatshirt", + "swimming trunks, bathing trunks", + "swing", + "switch, electric switch, electrical switch", + "syringe", + "table lamp", + "tank, army tank, armored combat vehicle, armoured combat vehicle", + "tape player", + "teapot", + "teddy, teddy bear", + "television, television system", + "tennis ball", + "thatch, thatched roof", + "theater curtain, theatre curtain", + "thimble", + "thresher, thrasher, threshing machine", + "throne", + "tile roof", + "toaster", + "tobacco shop, tobacconist shop, tobacconist", + "toilet seat", + "torch", + "totem pole", + "tow truck, tow car, wrecker", + "toyshop", + "tractor", + "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", + "tray", + "trench coat", + "tricycle, trike, velocipede", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus, trolley coach, trackless trolley", + "trombone", + "tub, vat", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle, monocycle", + "upright, upright piano", + "vacuum, vacuum cleaner", + "vase", + "vault", + "velvet", + "vending machine", + "vestment", + "viaduct", + "violin, fiddle", + "volleyball", + "waffle iron", + "wall clock", + "wallet, billfold, notecase, pocketbook", + "wardrobe, closet, press", + "warplane, military plane", + "washbasin, handbasin, washbowl, lavabo, wash-hand basin", + "washer, 
automatic washer, washing machine", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "wing", + "wok", + "wooden spoon", + "wool, woolen, woollen", + "worm fence, snake fence, snake-rail fence, Virginia fence", + "wreck", + "yawl", + "yurt", + "web site, website, internet site, site", + "comic book", + "crossword puzzle, crossword", + "street sign", + "traffic light, traffic signal, stoplight", + "book jacket, dust cover, dust jacket, dust wrapper", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot, hotpot", + "trifle", + "ice cream, icecream", + "ice lolly, lolly, lollipop, popsicle", + "French loaf", + "bagel, beigel", + "pretzel", + "cheeseburger", + "hotdog, hot dog, red hot", + "mashed potato", + "head cabbage", + "broccoli", + "cauliflower", + "zucchini, courgette", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber, cuke", + "artichoke, globe artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple, ananas", + "banana", + "jackfruit, jak, jack", + "custard apple", + "pomegranate", + "hay", + "carbonara", + "chocolate sauce, chocolate syrup", + "dough", + "meat loaf, meatloaf", + "pizza, pizza pie", + "potpie", + "burrito", + "red wine", + "espresso", + "cup", + "eggnog", + "alp", + "bubble", + "cliff, drop, drop-off", + "coral reef", + "geyser", + "lakeside, lakeshore", + "promontory, headland, head, foreland", + "sandbar, sand bar", + "seashore, coast, seacoast, sea-coast", + "valley, vale", + "volcano", + "ballplayer, baseball player", + "groom, bridegroom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", + "corn", + "acorn", + "hip, rose hip, rosehip", + "buckeye, horse chestnut, conker", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn, 
use crate::models::IMAGENET_NAMES_1K;

/// Model configuration for `MobileOne`
impl crate::Options {
    /// Base configuration shared by every MobileOne variant:
    /// fixed 1x3x224x224 input, ImageNet mean/std normalization,
    /// softmax applied to the logits, and the ImageNet-1k class names.
    pub fn mobileone() -> Self {
        Self::default()
            .with_model_name("mobileone")
            .with_model_ixx(0, 0, 1.into()) // input 0, dim 0: batch = 1
            .with_model_ixx(0, 1, 3.into()) // input 0, dim 1: channels = 3
            .with_model_ixx(0, 2, 224.into()) // input 0, dim 2: height = 224
            .with_model_ixx(0, 3, 224.into()) // input 0, dim 3: width = 224
            .with_image_mean(&[0.485, 0.456, 0.406]) // ImageNet mean
            .with_image_std(&[0.229, 0.224, 0.225]) // ImageNet std
            .with_apply_softmax(true)
            .with_normalize(true)
            .with_class_names(&IMAGENET_NAMES_1K)
    }

    /// MobileOne-S0 weights.
    pub fn mobileone_s0() -> Self {
        Self::mobileone().with_model_file("s0.onnx")
    }

    /// MobileOne-S1 weights.
    pub fn mobileone_s1() -> Self {
        Self::mobileone().with_model_file("s1.onnx")
    }

    /// MobileOne-S2 weights.
    pub fn mobileone_s2() -> Self {
        Self::mobileone().with_model_file("s2.onnx")
    }

    /// MobileOne-S3 weights.
    pub fn mobileone_s3() -> Self {
        Self::mobileone().with_model_file("s3.onnx")
    }

    /// MobileOne-S4 exported at 224x224.
    /// NOTE(review): the base config pins the ixx dims to 224; presumably the
    /// engine/ONNX shape takes precedence for the larger S4 exports — confirm.
    pub fn mobileone_s4_224x224() -> Self {
        Self::mobileone().with_model_file("s4-224x224.onnx")
    }

    /// MobileOne-S4 exported at 256x256.
    pub fn mobileone_s4_256x256() -> Self {
        Self::mobileone().with_model_file("s4-256x256.onnx")
    }

    /// MobileOne-S4 exported at 384x384.
    pub fn mobileone_s4_384x384() -> Self {
        Self::mobileone().with_model_file("s4-384x384.onnx")
    }

    /// MobileOne-S4 exported at 512x512.
    pub fn mobileone_s4_512x512() -> Self {
        Self::mobileone().with_model_file("s4-512x512.onnx")
    }
}
//! Model zoo and shared building blocks.
//!
//! Each sub-module implements one model family; shared infrastructure
//! (`Options`, `Processor`, `Task`, `Scale`, `Kind`, `Version`, labels)
//! lives alongside them. Modules that are declared but not re-exported
//! (`beit`, `convnext`, `deit`, `fastvit`, `mobileone`) appear to
//! contribute only `Options` configuration constructors — confirm.

mod basemodel;
mod beit;
mod blip;
mod clip;
mod convnext;
mod db;
mod deit;
mod depth_anything;
mod depth_pro;
mod dinov2;
mod fastvit;
mod florence2;
mod grounding_dino;
mod image_classifier;
mod kind;
mod labels;
mod mobileone;
mod modnet;
mod options;
mod picodet;
mod processor;
mod rtdetr;
mod rtmo;
mod sam;
mod sapiens;
mod scale;
mod slanet;
mod svtr;
mod task;
mod trocr;
mod version;
mod yolo;
mod yolop;

pub use basemodel::*;
pub use blip::*;
pub use clip::*;
pub use db::*;
pub use depth_anything::*;
pub use depth_pro::*;
pub use dinov2::*;
pub use florence2::*;
pub use grounding_dino::*;
pub use image_classifier::*;
pub use kind::Kind;
pub use labels::*;
pub use modnet::*;
pub use options::*;
pub use picodet::*;
pub use processor::*;
pub use rtdetr::*;
pub use rtmo::*;
pub use sam::*;
pub use sapiens::*;
pub use scale::Scale;
pub use slanet::*;
pub use svtr::*;
pub use task::Task;
pub use trocr::*;
pub use version::Version;
pub use yolo::*;
pub use yolop::*;
/// Model configuration for `MODNet`
impl crate::Options {
    /// Base MODNet config: batch = 1, dynamic input height/width in
    /// [416, 800] with optimum 512, and inputs normalized to [-1, 1]
    /// (mean = std = 0.5).
    pub fn modnet() -> Self {
        Self::default()
            .with_model_name("modnet")
            .with_model_ixx(0, 0, 1.into()) // input 0, dim 0: batch = 1
            .with_model_ixx(0, 2, (416, 512, 800).into()) // height: (min, opt, max)
            .with_model_ixx(0, 3, (416, 512, 800).into()) // width: (min, opt, max)
            .with_image_mean(&[0.5, 0.5, 0.5])
            .with_image_std(&[0.5, 0.5, 0.5])
            .with_normalize(true)
    }

    /// Photographic portrait-matting weights.
    pub fn modnet_photographic() -> Self {
        Self::modnet().with_model_file("photographic-portrait-matting.onnx")
    }
}
use aksr::Builder;
use anyhow::Result;
use image::DynamicImage;
use ndarray::Axis;

use crate::{elapsed, Engine, Mask, Ops, Options, Processor, Ts, Xs, Ys, Y};

/// MODNet portrait matting: runs the ONNX model via `Engine` and returns
/// one alpha-matte `Mask` per input image, resized back to the original
/// image dimensions.
#[derive(Builder, Debug)]
pub struct MODNet {
    engine: Engine,       // ONNX runtime wrapper
    height: usize,        // model input height (optimum value)
    width: usize,         // model input width (optimum value)
    batch: usize,         // batch size (optimum value)
    ts: Ts,               // per-stage timing statistics
    spec: String,         // engine spec string (for identification/logging)
    processor: Processor, // image pre-processing pipeline
}

impl MODNet {
    /// Builds the engine and processor from `options`.
    ///
    /// Falls back to 512x512 when the model does not declare a static
    /// input height/width.
    pub fn new(options: Options) -> Result<Self> {
        let engine = options.to_engine()?;
        let spec = engine.spec().to_string();
        let (batch, height, width, ts) = (
            engine.batch().opt(),
            engine.try_height().unwrap_or(&512.into()).opt(),
            engine.try_width().unwrap_or(&512.into()).opt(),
            engine.ts().clone(),
        );
        // Keep the processor's target size in sync with the engine's input.
        let processor = options
            .to_processor()?
            .with_image_width(width as _)
            .with_image_height(height as _);

        Ok(Self {
            engine,
            height,
            width,
            batch,
            ts,
            spec,
            processor,
        })
    }

    /// Resizes/normalizes the images into the model's input tensor(s).
    fn preprocess(&mut self, xs: &[DynamicImage]) -> Result<Xs> {
        Ok(self.processor.process_images(xs)?.into())
    }

    /// Runs the ONNX session.
    fn inference(&mut self, xs: Xs) -> Result<Xs> {
        self.engine.run(xs)
    }

    /// Full pipeline: preprocess -> inference -> postprocess, with each
    /// stage's elapsed time recorded in `self.ts`.
    pub fn forward(&mut self, xs: &[DynamicImage]) -> Result<Ys> {
        let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? });
        let ys = elapsed!("inference", self.ts, { self.inference(ys)? });
        let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? });

        Ok(ys)
    }

    /// Prints the accumulated timing summary.
    pub fn summary(&mut self) {
        self.ts.summary();
    }

    /// Converts the model output (one single-channel matte per batch item)
    /// into `Mask`s at each image's original resolution.
    fn postprocess(&mut self, xs: Xs) -> Result<Ys> {
        let mut ys: Vec<Y> = Vec::new();
        for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() {
            // Original (pre-resize) image size recorded by the processor.
            let (h1, w1) = self.processor.image0s_size[idx];

            // Scale alpha to u8. NOTE(review): assumes the model output is
            // already in [0, 1]; values outside that range would wrap — confirm.
            let luma = luma.mapv(|x| (x * 255.0) as u8);
            let luma = Ops::resize_luma8_u8(
                &luma.into_raw_vec_and_offset().0,
                self.width as _,
                self.height as _,
                w1 as _,
                h1 as _,
                false,
                "Bilinear",
            )?;
            // Skip the image (rather than fail the batch) if the buffer
            // cannot be reassembled from raw bytes.
            let luma: image::ImageBuffer<image::Luma<_>, Vec<_>> =
                match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) {
                    None => continue,
                    Some(x) => x,
                };
            ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)]));
        }

        Ok(ys.into())
    }
}
Options for everthing + +use aksr::Builder; +use anyhow::Result; +use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams}; + +use crate::{ + models::{SamKind, YOLOPredsFormat}, + DType, Device, Engine, Hub, Iiix, Kind, MinOptMax, Processor, ResizeMode, Scale, Task, Version, +}; + +/// Options for building models and inference +#[derive(Builder, Debug, Clone)] +pub struct Options { + // Model configs + pub model_file: String, + pub model_name: &'static str, + pub model_device: Device, + pub model_dtype: DType, + pub model_version: Option, + pub model_task: Option, + pub model_scale: Option, + pub model_kind: Option, + pub model_iiixs: Vec, + pub model_spec: String, + pub model_num_dry_run: usize, + pub trt_fp16: bool, + pub profile: bool, + + // Processor configs + #[args(setter = false)] + pub image_width: u32, + #[args(setter = false)] + pub image_height: u32, + pub resize_mode: ResizeMode, + pub resize_filter: &'static str, + pub padding_value: u8, + pub letterbox_center: bool, + pub normalize: bool, + pub image_std: Vec, + pub image_mean: Vec, + pub nchw: bool, + pub unsigned: bool, + + // Names + pub class_names: Option>, + pub class_names_2: Option>, + pub class_names_3: Option>, + pub keypoint_names: Option>, + pub keypoint_names_2: Option>, + pub keypoint_names_3: Option>, + pub text_names: Option>, + pub text_names_2: Option>, + pub text_names_3: Option>, + pub category_names: Option>, + pub category_names_2: Option>, + pub category_names_3: Option>, + + // Confs + pub class_confs: Vec, + pub class_confs_2: Vec, + pub class_confs_3: Vec, + pub keypoint_confs: Vec, + pub keypoint_confs_2: Vec, + pub keypoint_confs_3: Vec, + pub text_confs: Vec, + pub text_confs_2: Vec, + pub text_confs_3: Vec, + + // For classification + pub apply_softmax: Option, + + // For detection + #[args(alias = "nc")] + pub num_classes: Option, + #[args(alias = "nk")] + pub num_keypoints: Option, + #[args(alias = "nm")] + pub num_masks: Option, + pub iou: Option, 
+ pub iou_2: Option, + pub iou_3: Option, + pub apply_nms: Option, + pub find_contours: bool, + pub yolo_preds_format: Option, + pub classes_excluded: Vec, + pub classes_retained: Vec, + pub min_width: Option, + pub min_height: Option, + + // Language models related + pub model_max_length: Option, + pub tokenizer_file: Option, + pub config_file: Option, + pub special_tokens_map_file: Option, + pub tokenizer_config_file: Option, + pub generation_config_file: Option, + pub vocab_file: Option, // vocab.json file + pub vocab_txt: Option, // vacab.txt file, not kv pairs + + // For DB + pub unclip_ratio: Option, + pub binary_thresh: Option, + + // For SAM + pub sam_kind: Option, + pub low_res_mask: Option, +} + +impl Default for Options { + fn default() -> Self { + Self { + model_file: Default::default(), + model_name: Default::default(), + model_version: Default::default(), + model_task: Default::default(), + model_scale: Default::default(), + model_kind: Default::default(), + model_device: Device::Cpu(0), + model_dtype: DType::Auto, + model_spec: Default::default(), + model_iiixs: Default::default(), + model_num_dry_run: 3, + trt_fp16: true, + profile: false, + normalize: true, + image_mean: vec![], + image_std: vec![], + image_height: 640, + image_width: 640, + padding_value: 114, + resize_mode: ResizeMode::FitExact, + resize_filter: "Bilinear", + letterbox_center: false, + nchw: true, + unsigned: false, + class_names: None, + class_names_2: None, + class_names_3: None, + category_names: None, + category_names_2: None, + category_names_3: None, + keypoint_names: None, + keypoint_names_2: None, + keypoint_names_3: None, + text_names: None, + text_names_2: None, + text_names_3: None, + class_confs: vec![0.3f32], + class_confs_2: vec![0.3f32], + class_confs_3: vec![0.3f32], + keypoint_confs: vec![0.3f32], + keypoint_confs_2: vec![0.5f32], + keypoint_confs_3: vec![0.5f32], + text_confs: vec![0.4f32], + text_confs_2: vec![0.4f32], + text_confs_3: vec![0.4f32], + 
apply_softmax: Some(false), + num_classes: None, + num_keypoints: None, + num_masks: None, + iou: None, + iou_2: None, + iou_3: None, + find_contours: false, + yolo_preds_format: None, + classes_excluded: vec![], + classes_retained: vec![], + apply_nms: None, + model_max_length: None, + tokenizer_file: None, + config_file: None, + special_tokens_map_file: None, + tokenizer_config_file: None, + generation_config_file: None, + vocab_file: None, + vocab_txt: None, + min_width: None, + min_height: None, + unclip_ratio: Some(1.5), + binary_thresh: Some(0.2), + sam_kind: None, + low_res_mask: None, + } + } +} + +impl Options { + pub fn new() -> Self { + Default::default() + } + + pub fn to_engine(&self) -> Result { + Engine { + file: self.model_file.clone(), + spec: self.model_spec.clone(), + device: self.model_device, + trt_fp16: self.trt_fp16, + iiixs: self.model_iiixs.clone(), + num_dry_run: self.model_num_dry_run, + ..Default::default() + } + .build() + } + + pub fn to_processor(&self) -> Result { + // try to build tokenizer + let tokenizer = match self.model_kind { + Some(Kind::Language) | Some(Kind::VisionLanguage) => Some(self.try_build_tokenizer()?), + _ => None, + }; + + // try to build vocab from `vocab.txt` + let vocab: Vec = match &self.vocab_txt { + Some(x) => { + let file = if !std::path::PathBuf::from(&x).exists() { + Hub::new()?.try_fetch(&format!("{}/{}", self.model_name, x))? + } else { + x.to_string() + }; + std::fs::read_to_string(file)? 
+ .lines() + .map(|line| line.to_string()) + .collect() + } + None => vec![], + }; + + Ok(Processor { + image_width: self.image_width, + image_height: self.image_height, + resize_mode: self.resize_mode.clone(), + resize_filter: self.resize_filter, + padding_value: self.padding_value, + do_normalize: self.normalize, + image_mean: self.image_mean.clone(), + image_std: self.image_std.clone(), + nchw: self.nchw, + unsigned: self.unsigned, + tokenizer, + vocab, + ..Default::default() + }) + } + + pub fn commit(mut self) -> Result { + // Identify the local model or fetch the remote model + + if std::path::PathBuf::from(&self.model_file).exists() { + // Local + self.model_spec = crate::try_fetch_stem(&self.model_file)?; + } else { + // Remote + if self.model_file.is_empty() && self.model_name.is_empty() { + anyhow::bail!("Neither `model_name` nor `model_file` were specified. Faild to fetch model from remote.") + } + + // special yolo case + if self.model_file.is_empty() && self.model_name == "yolo" { + // [version]-[scale]-[task] + let mut y = String::new(); + if let Some(x) = self.model_version() { + y.push_str(&x.to_string()); + } + if let Some(x) = self.model_scale() { + y.push_str(&format!("-{}", x)); + } + if let Some(x) = self.model_task() { + y.push_str(&format!("-{}", x.yolo_str())); + } + y.push_str(".onnx"); + self.model_file = y; + } + + // append dtype to model file + match self.model_dtype { + d @ (DType::Auto | DType::Fp32) => { + if self.model_file.is_empty() { + self.model_file = format!("{}.onnx", d); + } + } + dtype => { + if self.model_file.is_empty() { + self.model_file = format!("{}.onnx", dtype); + } else { + let pos = self.model_file.len() - 5; // .onnx + let suffix = self.model_file.split_off(pos); + self.model_file = format!("{}-{}{}", self.model_file, dtype, suffix); + } + } + } + + // Load + let stem = crate::try_fetch_stem(&self.model_file)?; + self.model_spec = format!("{}/{}", self.model_name, stem); + self.model_file = + 
Hub::new()?.try_fetch(&format!("{}/{}", self.model_name, self.model_file))?; + } + + Ok(self) + } + + pub fn with_batch_size(mut self, x: usize) -> Self { + self.model_iiixs.push(Iiix::from((0, 0, x.into()))); + self + } + + pub fn with_image_height(mut self, x: u32) -> Self { + self.image_height = x; + self.model_iiixs.push(Iiix::from((0, 2, x.into()))); + self + } + + pub fn with_image_width(mut self, x: u32) -> Self { + self.image_width = x; + self.model_iiixs.push(Iiix::from((0, 3, x.into()))); + self + } + + pub fn with_model_ixx(mut self, i: usize, ii: usize, x: MinOptMax) -> Self { + self.model_iiixs.push(Iiix::from((i, ii, x))); + self + } + + pub fn exclude_classes(mut self, xs: &[usize]) -> Self { + // TODO: remove??? + self.classes_retained.clear(); + self.classes_excluded.extend_from_slice(xs); + self + } + + pub fn retain_classes(mut self, xs: &[usize]) -> Self { + self.classes_excluded.clear(); + self.classes_retained.extend_from_slice(xs); + self + } + + pub fn try_build_tokenizer(&self) -> Result { + let mut hub = Hub::new()?; + // config file + // TODO: save configs? 
+ let pad_id = match hub.try_fetch( + self.tokenizer_config_file + .as_ref() + .unwrap_or(&format!("{}/config.json", self.model_name)), + ) { + Ok(x) => { + let config: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(x)?)?; + config["pad_token_id"].as_u64().unwrap_or(0) as u32 + } + Err(_err) => 0u32, + }; + + // tokenizer_config file + let mut max_length = None; + let mut pad_token = String::from("[PAD]"); + match hub.try_fetch( + self.tokenizer_config_file + .as_ref() + .unwrap_or(&format!("{}/tokenizer_config.json", self.model_name)), + ) { + Err(_) => {} + Ok(x) => { + let tokenizer_config: serde_json::Value = + serde_json::from_str(&std::fs::read_to_string(x)?)?; + max_length = tokenizer_config["model_max_length"].as_u64(); + pad_token = tokenizer_config["pad_token"] + .as_str() + .unwrap_or("[PAD]") + .to_string(); + } + } + + // tokenizer file + let mut tokenizer: tokenizers::Tokenizer = tokenizers::Tokenizer::from_file( + hub.try_fetch( + self.tokenizer_file + .as_ref() + .unwrap_or(&format!("{}/tokenizer.json", self.model_name)), + )?, + ) + .map_err(|_| anyhow::anyhow!("No `tokenizer.json` found"))?; + + // TODO: padding + // if `max_length` specified: use `Fixed` strategy + // else: use `BatchLongest` strategy + // TODO: if sequence_length is dynamic, `BatchLongest` is fine + let tokenizer = match self.model_max_length { + Some(n) => { + let n = match max_length { + None => n, + Some(x) => x.min(n), + }; + tokenizer + .with_padding(Some(PaddingParams { + strategy: PaddingStrategy::Fixed(n as _), + pad_token, + pad_id, + ..Default::default() + })) + .clone() + } + None => match max_length { + Some(n) => tokenizer + .with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + pad_token, + pad_id, + ..Default::default() + })) + .with_truncation(Some(TruncationParams { + max_length: n as _, + ..Default::default() + })) + .map_err(|err| anyhow::anyhow!("Failed to truncate: {}", err))? 
+ .clone(), + None => tokenizer + .with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + pad_token, + pad_id, + ..Default::default() + })) + .clone(), + }, + }; + + // TODO: generation_config.json & special_tokens_map file + + Ok(tokenizer.into()) + } +} diff --git a/src/models/picodet/config.rs b/src/models/picodet/config.rs new file mode 100644 index 0000000..a509e39 --- /dev/null +++ b/src/models/picodet/config.rs @@ -0,0 +1,61 @@ +use crate::{models::COCO_CLASS_NAMES_80, ResizeMode}; + +/// Model configuration for `PicoDet` +impl crate::Options { + pub fn picodet() -> Self { + Self::default() + .with_model_name("picodet") + .with_batch_size(1) // TODO: ONNX model's batch size seems always = 1 + .with_model_ixx(0, 2, 640.into()) + .with_model_ixx(0, 3, 640.into()) + .with_model_ixx(1, 0, (1, 1, 8).into()) + .with_model_ixx(1, 1, 2.into()) + .with_resize_mode(ResizeMode::FitAdaptive) + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_normalize(true) + .with_class_confs(&[0.5]) + } + + pub fn picodet_l_coco() -> Self { + Self::picodet() + .with_model_file("l-coco.onnx") + .with_class_names(&COCO_CLASS_NAMES_80) + } + + pub fn picodet_layout_1x() -> Self { + Self::picodet() + .with_model_file("layout-1x.onnx") + .with_class_names(&["Text", "Title", "List", "Table", "Figure"]) + } + + pub fn picodet_l_layout_3cls() -> Self { + Self::picodet() + .with_model_file("l-layout-3cls.onnx") + .with_class_names(&["image", "table", "seal"]) + } + + pub fn picodet_l_layout_17cls() -> Self { + Self::picodet() + .with_model_file("l-layout-17cls.onnx") + .with_class_names(&[ + "paragraph_title", + "image", + "text", + "number", + "abstract", + "content", + "figure_title", + "formula", + "table", + "table_title", + "reference", + "doc_title", + "footnote", + "header", + "algorithm", + "footer", + "seal", + ]) + } +} diff --git a/src/models/picodet/impl.rs b/src/models/picodet/impl.rs new file mode 100644 index 
0000000..5bada86 --- /dev/null +++ b/src/models/picodet/impl.rs @@ -0,0 +1,111 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::Axis; +use rayon::prelude::*; + +use crate::{elapsed, Bbox, DynConf, Engine, Options, Processor, Ts, Xs, Ys, X, Y}; + +#[derive(Debug, Builder)] +pub struct PicoDet { + engine: Engine, + height: usize, + width: usize, + batch: usize, + spec: String, + names: Vec, + confs: DynConf, + ts: Ts, + processor: Processor, +} + +impl PicoDet { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&640.into()).opt(), + engine.try_width().unwrap_or(&640.into()).opt(), + engine.ts.clone(), + ); + let spec = engine.spec().to_owned(); + let processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + let names = options + .class_names() + .expect("No class names are specified.") + .to_vec(); + let confs = DynConf::new(options.class_confs(), names.len()); + + Ok(Self { + engine, + height, + width, + batch, + spec, + names, + confs, + ts, + processor, + }) + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + let x1 = self.processor.process_images(xs)?; + let x2: X = self.processor.scale_factors_hw.clone().try_into()?; + + Ok(Xs::from(vec![x1, x2])) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? 
}); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + fn postprocess(&mut self, xs: Xs) -> Result { + // ONNX models exported by paddle2onnx + // TODO: ONNX model's batch size seems always = 1 + // xs[0] : n, 6 + // xs[1] : n + let y_bboxes: Vec = xs[0] + .axis_iter(Axis(0)) + .into_par_iter() + .enumerate() + .filter_map(|(_i, pred)| { + let (class_id, confidence) = (pred[0] as usize, pred[1]); + if confidence < self.confs[class_id] { + return None; + } + let (x1, y1, x2, y2) = (pred[2], pred[3], pred[4], pred[5]); + + Some( + Bbox::default() + .with_xyxy(x1.max(0.0f32), y1.max(0.0f32), x2, y2) + .with_confidence(confidence) + .with_id(class_id as isize) + .with_name(&self.names[class_id]), + ) + }) + .collect(); + + let mut y = Y::default(); + if !y_bboxes.is_empty() { + y = y.with_bboxes(&y_bboxes); + } + + Ok(vec![y].into()) + } +} diff --git a/src/models/picodet/mod.rs b/src/models/picodet/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/picodet/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/processor.rs b/src/models/processor.rs new file mode 100644 index 0000000..629b05f --- /dev/null +++ b/src/models/processor.rs @@ -0,0 +1,320 @@ +use anyhow::Result; +use fast_image_resize::{ + images::{CroppedImageMut, Image}, + pixels::PixelType, + FilterType, ResizeAlg, ResizeOptions, Resizer, +}; +use image::{DynamicImage, GenericImageView}; +use ndarray::{s, Array}; +use tokenizers::{Encoding, Tokenizer}; + +use crate::X; + +#[derive(Debug, Clone)] +pub enum ResizeMode { + FitExact, // StretchToFit + FitWidth, + FitHeight, + FitAdaptive, + Letterbox, +} + +#[derive(aksr::Builder, Debug, Clone)] +pub struct Processor { + pub image_width: u32, // target image width + pub image_height: u32, // target image height + pub image0s_size: Vec<(u32, u32)>, // original image height and width + pub scale_factors_hw: Vec>, + pub resize_mode: ResizeMode, + pub resize_filter: 
&'static str, + pub padding_value: u8, + pub do_normalize: bool, + pub image_mean: Vec, + pub image_std: Vec, + pub nchw: bool, + pub tokenizer: Option, + pub vocab: Vec, + pub unsigned: bool, +} + +impl Default for Processor { + fn default() -> Self { + Self { + image0s_size: vec![], + image_width: 0, + image_height: 0, + scale_factors_hw: vec![], + resize_mode: ResizeMode::FitAdaptive, + resize_filter: "Bilinear", + padding_value: 114, + do_normalize: true, + image_mean: vec![], + image_std: vec![], + nchw: true, + tokenizer: Default::default(), + vocab: vec![], + unsigned: false, + } + } +} + +impl Processor { + pub fn reset_image0_status(&mut self) { + self.scale_factors_hw.clear(); + self.image0s_size.clear(); + } + + pub fn process_images(&mut self, xs: &[DynamicImage]) -> Result { + // reset + self.reset_image0_status(); + + let mut x = self.resize_batch(xs)?; + if self.do_normalize { + x = x.normalize(0., 255.)?; + } + if !self.image_std.is_empty() && !self.image_mean.is_empty() { + x = x.standardize(&self.image_mean, &self.image_std, 3)?; + } + if self.nchw { + x = x.nhwc2nchw()?; + } + + // Cope with padding problem + if self.unsigned { + x = x.unsigned(); + } + Ok(x) + } + + pub fn encode_text(&self, x: &str, skip_special_tokens: bool) -> Result { + self.tokenizer + .as_ref() + .expect("No tokenizer specified in `Processor`") + .encode(x, skip_special_tokens) + .map_err(|err| anyhow::anyhow!("Tokenizer encode error: {}", err)) + } + + pub fn encode_texts(&self, xs: &[&str], skip_special_tokens: bool) -> Result> { + self.tokenizer + .as_ref() + .expect("No tokenizer specified in `Processor`") + .encode_batch(xs.to_vec(), skip_special_tokens) + .map_err(|err| anyhow::anyhow!("Tokenizer encode_batch error: {}", err)) + } + + pub fn encode_text_ids(&self, x: &str, skip_special_tokens: bool) -> Result> { + let ids: Vec = if x.is_empty() { + vec![0.0f32] + } else { + self.encode_text(x, skip_special_tokens)? 
+ .get_ids() + .iter() + .map(|x| *x as f32) + .collect() + }; + + Ok(ids) + } + + pub fn encode_texts_ids( + &self, + xs: &[&str], + skip_special_tokens: bool, + ) -> Result>> { + let ids: Vec> = if xs.is_empty() { + vec![vec![0.0f32]] + } else { + self.encode_texts(xs, skip_special_tokens)? + .into_iter() + .map(|encoding| encoding.get_ids().iter().map(|x| *x as f32).collect()) + .collect() + }; + + Ok(ids) + } + + pub fn encode_text_tokens(&self, x: &str, skip_special_tokens: bool) -> Result> { + Ok(self + .encode_text(x, skip_special_tokens)? + .get_tokens() + .to_vec()) + } + + pub fn encode_texts_tokens( + &self, + xs: &[&str], + skip_special_tokens: bool, + ) -> Result>> { + Ok(self + .encode_texts(xs, skip_special_tokens)? + .into_iter() + .map(|encoding| encoding.get_tokens().to_vec()) + .collect()) + } + + pub fn decode_tokens(&self, ids: &[u32], skip_special_tokens: bool) -> Result { + self.tokenizer + .as_ref() + .expect("No tokenizer specified in `Processor`") + .decode(ids, skip_special_tokens) + .map_err(|err| anyhow::anyhow!("Tokenizer decode error: {}", err)) + } + + pub fn decode_tokens_batch2( + &self, + ids: &[&[u32]], + skip_special_tokens: bool, + ) -> Result> { + self.tokenizer + .as_ref() + .expect("No tokenizer specified in `Processor`") + .decode_batch(ids, skip_special_tokens) + .map_err(|err| anyhow::anyhow!("Tokenizer decode_batch error: {}", err)) + } + + pub fn decode_tokens_batch( + &self, + ids: &[Vec], + skip_special_tokens: bool, + ) -> Result> { + self.tokenizer + .as_ref() + .expect("No tokenizer specified in `Processor`") + .decode_batch( + &ids.iter().map(|x| x.as_slice()).collect::>(), + skip_special_tokens, + ) + .map_err(|err| anyhow::anyhow!("Tokenizer decode_batch error: {}", err)) + } + + pub fn build_resizer_filter(ty: &str) -> Result<(Resizer, ResizeOptions)> { + let ty = match ty.to_lowercase().as_str() { + "box" => FilterType::Box, + "bilinear" => FilterType::Bilinear, + "hamming" => FilterType::Hamming, + 
"catmullrom" => FilterType::CatmullRom, + "mitchell" => FilterType::Mitchell, + "gaussian" => FilterType::Gaussian, + "lanczos3" => FilterType::Lanczos3, + x => anyhow::bail!("Unsupported resizer's filter type: {}", x), + }; + Ok(( + Resizer::new(), + ResizeOptions::new().resize_alg(ResizeAlg::Convolution(ty)), + )) + } + + pub fn resize(&mut self, x: &DynamicImage) -> Result { + if self.image_width + self.image_height == 0 { + anyhow::bail!( + "Invalid target height: {} or width: {}.", + self.image_height, + self.image_width + ); + } + + let buffer = match x.dimensions() { + (w, h) if (w, h) == (self.image_height, self.image_width) => { + self.image0s_size.push((h, w)); + self.scale_factors_hw.push(vec![1., 1.]); + x.to_rgb8().into_raw() + } + (w0, h0) => { + self.image0s_size.push((h0, w0)); + let (mut resizer, options) = Self::build_resizer_filter(self.resize_filter)?; + + if let ResizeMode::FitExact = self.resize_mode { + let mut dst = Image::new(self.image_width, self.image_height, PixelType::U8x3); + resizer.resize(x, &mut dst, &options)?; + self.scale_factors_hw.push(vec![ + (self.image_height as f32 / h0 as f32), + (self.image_width as f32 / w0 as f32), + ]); + + dst.into_vec() + } else { + let (w, h) = match self.resize_mode { + ResizeMode::Letterbox | ResizeMode::FitAdaptive => { + let r = (self.image_width as f32 / w0 as f32) + .min(self.image_height as f32 / h0 as f32); + self.scale_factors_hw.push(vec![r, r]); + + ( + (w0 as f32 * r).round() as u32, + (h0 as f32 * r).round() as u32, + ) + } + ResizeMode::FitHeight => { + let r = self.image_height as f32 / h0 as f32; + self.scale_factors_hw.push(vec![1.0, r]); + ((r * w0 as f32).round() as u32, self.image_height) + } + ResizeMode::FitWidth => { + // scale factor + let r = self.image_width as f32 / w0 as f32; + self.scale_factors_hw.push(vec![r, 1.0]); + (self.image_width, (r * h0 as f32).round() as u32) + } + + _ => unreachable!(), + }; + + let mut dst = Image::from_vec_u8( + self.image_width, + 
self.image_height, + vec![ + self.padding_value; + 3 * self.image_height as usize * self.image_width as usize + ], + PixelType::U8x3, + )?; + let (l, t) = if let ResizeMode::Letterbox = self.resize_mode { + if w == self.image_width { + (0, (self.image_height - h) / 2) + } else { + ((self.image_width - w) / 2, 0) + } + } else { + (0, 0) + }; + + let mut dst_cropped = CroppedImageMut::new(&mut dst, l, t, w, h)?; + resizer.resize(x, &mut dst_cropped, &options)?; + dst.into_vec() + } + } + }; + + let y = Array::from_shape_vec( + (self.image_height as usize, self.image_width as usize, 3), + buffer, + )? + .mapv(|x| x as f32) + .into_dyn(); + + Ok(y.into()) + } + + pub fn resize_batch(&mut self, xs: &[DynamicImage]) -> Result { + // TODO: par resize + if xs.is_empty() { + anyhow::bail!("Found no input images.") + } + + let mut ys = Array::ones(( + xs.len(), + self.image_height as usize, + self.image_width as usize, + 3, + )) + .into_dyn(); + + xs.iter().enumerate().try_for_each(|(idx, x)| { + let y = self.resize(x)?; + ys.slice_mut(s![idx, .., .., ..]).assign(&y); + anyhow::Ok(()) + })?; + + Ok(ys.into()) + } +} diff --git a/src/models/rtdetr/README.md b/src/models/rtdetr/README.md new file mode 100644 index 0000000..2295290 --- /dev/null +++ b/src/models/rtdetr/README.md @@ -0,0 +1,3 @@ +# RT-DETR + +**Models exported from [RT-DETR](https://github.com/lyuwenyu/RT-DETR)** \ No newline at end of file diff --git a/src/models/rtdetr/config.rs b/src/models/rtdetr/config.rs new file mode 100644 index 0000000..fa96137 --- /dev/null +++ b/src/models/rtdetr/config.rs @@ -0,0 +1,80 @@ +use crate::models::COCO_CLASS_NAMES_80; + +/// Model configuration for `RT-DETR` +impl crate::Options { + pub fn rtdetr() -> Self { + Self::default() + .with_model_name("rtdetr") + .with_batch_size(1) + .with_model_ixx(0, 2, 640.into()) + .with_model_ixx(0, 3, 640.into()) + .with_resize_mode(crate::ResizeMode::FitAdaptive) + .with_normalize(true) + .with_class_confs(&[0.5]) + 
.with_class_names(&COCO_CLASS_NAMES_80) + } + + pub fn rtdetr_v1_r18vd_coco() -> Self { + Self::rtdetr().with_model_file("v1-r18vd-coco.onnx") + } + + pub fn rtdetr_v2_s_coco() -> Self { + Self::rtdetr().with_model_file("v2-s-coco.onnx") + } + + pub fn rtdetr_v2_ms_coco() -> Self { + Self::rtdetr().with_model_file("v2-ms-coco.onnx") + } + + pub fn rtdetr_v2_m_coco() -> Self { + Self::rtdetr().with_model_file("v2-m-coco.onnx") + } + + pub fn rtdetr_v2_l_coco() -> Self { + Self::rtdetr().with_model_file("v2-l-coco.onnx") + } + + pub fn rtdetr_v2_x_coco() -> Self { + Self::rtdetr().with_model_file("v2-x-coco.onnx") + } + + pub fn dfine() -> Self { + Self::rtdetr().with_model_name("dfine") + } + + pub fn dfine_n_coco() -> Self { + Self::dfine().with_model_file("n-coco.onnx") + } + + pub fn dfine_s_coco() -> Self { + Self::dfine().with_model_file("s-coco.onnx") + } + + pub fn dfine_m_coco() -> Self { + Self::dfine().with_model_file("m-coco.onnx") + } + + pub fn dfine_l_coco() -> Self { + Self::dfine().with_model_file("l-coco.onnx") + } + + pub fn dfine_x_coco() -> Self { + Self::dfine().with_model_file("x-coco.onnx") + } + + pub fn dfine_s_coco_obj365() -> Self { + Self::dfine().with_model_file("s-obj2coco.onnx") + } + + pub fn dfine_m_coco_obj365() -> Self { + Self::dfine().with_model_file("m-obj2coco.onnx") + } + + pub fn dfine_l_coco_obj365() -> Self { + Self::dfine().with_model_file("l-obj2coco.onnx") + } + + pub fn dfine_x_coco_obj365() -> Self { + Self::dfine().with_model_file("x-obj2coco.onnx") + } +} diff --git a/src/models/rtdetr/impl.rs b/src/models/rtdetr/impl.rs new file mode 100644 index 0000000..70e5262 --- /dev/null +++ b/src/models/rtdetr/impl.rs @@ -0,0 +1,128 @@ +use crate::{elapsed, Bbox, DynConf, Engine, Processor, Ts, Xs, Ys, X, Y}; +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::{s, Axis}; +use rayon::prelude::*; + +use crate::Options; + +#[derive(Debug, Builder)] +pub struct RTDETR { + engine: Engine, + height: 
usize, + width: usize, + batch: usize, + names: Vec, + confs: DynConf, + ts: Ts, + processor: Processor, + spec: String, +} + +impl RTDETR { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&640.into()).opt(), + engine.try_width().unwrap_or(&640.into()).opt(), + engine.ts.clone(), + ); + let spec = engine.spec().to_owned(); + let processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + let names = options + .class_names() + .expect("No class names specified.") + .to_vec(); + let confs = DynConf::new(options.class_confs(), names.len()); + + Ok(Self { + engine, + height, + width, + batch, + spec, + names, + confs, + ts, + processor, + }) + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + let x1 = self.processor.process_images(xs)?; + let x2 = X::from(vec![self.height as f32, self.width as f32]) + .insert_axis(0)? + .repeat(0, self.batch)?; + + let xs = Xs::from(vec![x1, x2]); + + Ok(xs) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? 
}); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + fn postprocess(&mut self, xs: Xs) -> Result { + let ys: Vec = xs[0] + .axis_iter(Axis(0)) + .into_par_iter() + .zip(xs[1].axis_iter(Axis(0)).into_par_iter()) + .zip(xs[2].axis_iter(Axis(0)).into_par_iter()) + .enumerate() + .filter_map(|(idx, ((labels, boxes), scores))| { + let ratio = self.processor.scale_factors_hw[idx][0]; + + let mut y_bboxes = Vec::new(); + for (i, &score) in scores.iter().enumerate() { + let class_id = labels[i] as usize; + if score < self.confs[class_id] { + continue; + } + + let xyxy = boxes.slice(s![i, ..]); + let (x1, y1, x2, y2) = ( + xyxy[0] / ratio, + xyxy[1] / ratio, + xyxy[2] / ratio, + xyxy[3] / ratio, + ); + + y_bboxes.push( + Bbox::default() + .with_xyxy(x1.max(0.0f32), y1.max(0.0f32), x2, y2) + .with_confidence(score) + .with_id(class_id as isize) + .with_name(&self.names[class_id]), + ); + } + + let mut y = Y::default(); + if !y_bboxes.is_empty() { + y = y.with_bboxes(&y_bboxes); + } + + Some(y) + }) + .collect(); + + Ok(ys.into()) + } +} diff --git a/src/models/rtdetr/mod.rs b/src/models/rtdetr/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/rtdetr/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/rtmo/config.rs b/src/models/rtmo/config.rs new file mode 100644 index 0000000..d223269 --- /dev/null +++ b/src/models/rtmo/config.rs @@ -0,0 +1,28 @@ +/// Model configuration for `RTMO` +impl crate::Options { + pub fn rtmo() -> Self { + Self::default() + .with_model_name("rtmo") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 2, 640.into()) + .with_model_ixx(0, 3, 640.into()) + .with_resize_mode(crate::ResizeMode::FitAdaptive) + .with_resize_filter("CatmullRom") + .with_normalize(false) + .with_nk(17) + .with_class_confs(&[0.3]) + .with_keypoint_confs(&[0.5]) + } + + pub fn rtmo_s() -> Self { + Self::rtmo().with_model_file("s.onnx") + } + + pub fn rtmo_m() -> Self { + 
Self::rtmo().with_model_file("m.onnx") + } + + pub fn rtmo_l() -> Self { + Self::rtmo().with_model_file("l.onnx") + } +} diff --git a/src/models/rtmo.rs b/src/models/rtmo/impl.rs similarity index 55% rename from src/models/rtmo.rs rename to src/models/rtmo/impl.rs index 1ae4b4d..f23f448 100644 --- a/src/models/rtmo.rs +++ b/src/models/rtmo/impl.rs @@ -1,75 +1,87 @@ +use aksr::Builder; use anyhow::Result; use image::DynamicImage; use ndarray::Axis; -use crate::{Bbox, DynConf, Keypoint, MinOptMax, Options, OrtEngine, Xs, X, Y}; +use crate::{elapsed, Bbox, DynConf, Engine, Keypoint, Options, Processor, Ts, Xs, Ys, Y}; -#[derive(Debug)] +#[derive(Builder, Debug)] pub struct RTMO { - engine: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, + engine: Engine, + height: usize, + width: usize, + batch: usize, + ts: Ts, + spec: String, + processor: Processor, confs: DynConf, kconfs: DynConf, } impl RTMO { pub fn new(options: Options) -> Result { - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.batch().to_owned(), - engine.height().to_owned(), - engine.width().to_owned(), + let engine = options.to_engine()?; + let spec = engine.spec().to_string(); + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&512.into()).opt(), + engine.try_width().unwrap_or(&512.into()).opt(), + engine.ts().clone(), ); - let nc = 1; - let nk = options.nk.unwrap_or(17); - let confs = DynConf::new(&options.confs, nc); - let kconfs = DynConf::new(&options.kconfs, nk); - engine.dry_run()?; + let processor = options + .to_processor()? 
+ .with_image_width(width as _) + .with_image_height(height as _); + + let nk = options.nk().unwrap_or(17); + let confs = DynConf::new(options.class_confs(), 1); + let kconfs = DynConf::new(options.keypoint_confs(), nk); Ok(Self { engine, - confs, - kconfs, height, width, batch, + ts, + spec, + processor, + confs, + kconfs, }) } - pub fn run(&mut self, xs: &[DynamicImage]) -> Result> { - let xs_ = X::letterbox( - xs, - self.height() as u32, - self.width() as u32, - "CatmullRom", - 114, - "auto", - false, - )? - .nhwc2nchw()?; - let ys = self.engine.run(Xs::from(xs_))?; - self.postprocess(ys, xs) + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + Ok(self.processor.process_images(xs)?.into()) } - pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result> { + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? 
}); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + fn postprocess(&mut self, xs: Xs) -> Result { let mut ys: Vec = Vec::new(); - let (preds_bboxes, preds_kpts) = if xs[0].ndim() == 3 { - (&xs[0], &xs[1]) - } else { - (&xs[1], &xs[0]) - }; + // let (preds_bboxes, preds_kpts) = (&xs["dets"], &xs["keypoints"]); + let (preds_bboxes, preds_kpts) = (&xs[0], &xs[1]); for (idx, (batch_bboxes, batch_kpts)) in preds_bboxes .axis_iter(Axis(0)) .zip(preds_kpts.axis_iter(Axis(0))) .enumerate() { - let width_original = xs0[idx].width() as f32; - let height_original = xs0[idx].height() as f32; - let ratio = - (self.width() as f32 / width_original).min(self.height() as f32 / height_original); + let (height_original, width_original) = self.processor.image0s_size[idx]; + let ratio = self.processor.scale_factors_hw[idx][0]; let mut y_bboxes = Vec::new(); let mut y_kpts: Vec> = Vec::new(); @@ -90,8 +102,8 @@ impl RTMO { y_bboxes.push( Bbox::default() .with_xyxy( - x1.max(0.0f32).min(width_original), - y1.max(0.0f32).min(height_original), + x1.max(0.0f32).min(width_original as _), + y1.max(0.0f32).min(height_original as _), x2, y2, ) @@ -114,8 +126,8 @@ impl RTMO { .with_id(i as isize) .with_confidence(c) .with_xy( - x.max(0.0f32).min(width_original), - y.max(0.0f32).min(height_original), + x.max(0.0f32).min(width_original as _), + y.max(0.0f32).min(height_original as _), ), ); } @@ -124,18 +136,7 @@ impl RTMO { } ys.push(Y::default().with_bboxes(&y_bboxes).with_keypoints(&y_kpts)); } - Ok(ys) - } - - pub fn batch(&self) -> isize { - self.batch.opt() as _ - } - - pub fn width(&self) -> isize { - self.width.opt() as _ - } - pub fn height(&self) -> isize { - self.height.opt() as _ + Ok(ys.into()) } } diff --git a/src/models/rtmo/mod.rs b/src/models/rtmo/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/rtmo/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/sam/config.rs 
b/src/models/sam/config.rs new file mode 100644 index 0000000..0e9ce58 --- /dev/null +++ b/src/models/sam/config.rs @@ -0,0 +1,100 @@ +use crate::{models::SamKind, Options}; + +/// Model configuration for `Segment Anything Model` +impl Options { + pub fn sam() -> Self { + Self::default() + .with_model_name("sam") + .with_model_ixx(0, 0, 1.into()) + } + + pub fn sam_encoder() -> Self { + Self::sam() + .with_model_ixx(0, 2, 1024.into()) + .with_model_ixx(0, 3, 1024.into()) + .with_resize_mode(crate::ResizeMode::FitAdaptive) + .with_resize_filter("Bilinear") + .with_image_mean(&[123.5, 116.5, 103.5]) + .with_image_std(&[58.5, 57.0, 57.5]) + .with_normalize(false) + .with_sam_kind(SamKind::Sam) + .with_low_res_mask(false) + .with_find_contours(true) + } + + pub fn sam_decoder() -> Self { + Self::sam() + } + + pub fn sam_v1_base_encoder() -> Self { + Self::sam_encoder().with_model_file("sam-vit-b-encoder.onnx") + } + + pub fn sam_v1_base_decoder() -> Self { + Self::sam_decoder().with_model_file("sam-vit-b-decoder.onnx") + } + + pub fn sam_v1_base_singlemask_decoder() -> Self { + Self::sam_decoder().with_model_file("sam-vit-b-decoder-singlemask.onnx") + } + + pub fn sam2_tiny_encoder() -> Self { + Self::sam_encoder() + .with_model_file("sam2-hiera-tiny-encoder.onnx") + .with_sam_kind(SamKind::Sam2) + } + + pub fn sam2_tiny_decoder() -> Self { + Self::sam_decoder().with_model_file("sam2-hiera-tiny-decoder.onnx") + } + + pub fn sam2_small_encoder() -> Self { + Self::sam_encoder() + .with_model_file("sam2-hiera-small-encoder.onnx") + .with_sam_kind(SamKind::Sam2) + } + + pub fn sam2_small_decoder() -> Self { + Self::sam_decoder().with_model_file("sam2-hiera-small-decoder.onnx") + } + + pub fn sam2_base_plus_encoder() -> Self { + Self::sam_encoder() + .with_model_file("sam2-hiera-base-plus-encoder.onnx") + .with_sam_kind(SamKind::Sam2) + } + + pub fn sam2_base_plus_decoder() -> Self { + Self::sam_decoder().with_model_file("sam2-hiera-base-plus-decoder.onnx") + } + + pub fn 
mobile_sam_tiny_encoder() -> Self { + Self::sam_encoder() + .with_model_file("mobile-sam-vit-t-encoder.onnx") + .with_sam_kind(SamKind::MobileSam) + } + + pub fn mobile_sam_tiny_decoder() -> Self { + Self::sam_decoder().with_model_file("mobile-sam-vit-t-decoder.onnx") + } + + pub fn sam_hq_tiny_encoder() -> Self { + Self::sam_encoder() + .with_model_file("sam-hq-vit-t-encoder.onnx") + .with_sam_kind(SamKind::SamHq) + } + + pub fn sam_hq_tiny_decoder() -> Self { + Self::sam_decoder().with_model_file("sam-hq-vit-t-decoder.onnx") + } + + pub fn edge_sam_3x_encoder() -> Self { + Self::sam_encoder() + .with_model_file("edge-sam-3x-encoder.onnx") + .with_sam_kind(SamKind::EdgeSam) + } + + pub fn edge_sam_3x_decoder() -> Self { + Self::sam_decoder().with_model_file("edge-sam-3x-decoder.onnx") + } +} diff --git a/src/models/sam.rs b/src/models/sam/impl.rs similarity index 73% rename from src/models/sam.rs rename to src/models/sam/impl.rs index 2c094c2..41eb7ba 100644 --- a/src/models/sam.rs +++ b/src/models/sam/impl.rs @@ -1,11 +1,12 @@ +use aksr::Builder; use anyhow::Result; use image::DynamicImage; use ndarray::{s, Array, Axis}; use rand::prelude::*; -use crate::{DynConf, Mask, MinOptMax, Ops, Options, OrtEngine, Polygon, Xs, X, Y}; +use crate::{elapsed, DynConf, Engine, Mask, Ops, Options, Polygon, Processor, Ts, Xs, Ys, X, Y}; -#[derive(Debug, Clone, clap::ValueEnum)] +#[derive(Debug, Clone)] pub enum SamKind { Sam, Sam2, @@ -14,6 +15,21 @@ pub enum SamKind { EdgeSam, } +impl TryFrom<&str> for SamKind { + type Error = anyhow::Error; + + fn try_from(s: &str) -> Result { + match s.to_lowercase().as_str() { + "sam" => Ok(Self::Sam), + "sam2" => Ok(Self::Sam2), + "mobilesam" | "mobile-sam" => Ok(Self::MobileSam), + "samhq" | "sam-hq" => Ok(Self::SamHq), + "edgesam" | "edge-sam" => Ok(Self::EdgeSam), + x => anyhow::bail!("Unsupported SamKind: {}", x), + } + } +} + #[derive(Debug, Default, Clone)] pub struct SamPrompt { points: Vec, @@ -62,87 +78,81 @@ impl SamPrompt { } } 
-#[derive(Debug)] +#[derive(Builder, Debug)] pub struct SAM { - encoder: OrtEngine, - decoder: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, - pub conf: DynConf, + encoder: Engine, + decoder: Engine, + height: usize, + width: usize, + batch: usize, + processor: Processor, + conf: DynConf, find_contours: bool, kind: SamKind, use_low_res_mask: bool, + ts: Ts, + spec: String, } impl SAM { pub fn new(options_encoder: Options, options_decoder: Options) -> Result { - let mut encoder = OrtEngine::new(&options_encoder)?; - let mut decoder = OrtEngine::new(&options_decoder)?; + let encoder = options_encoder.to_engine()?; + let decoder = options_decoder.to_engine()?; let (batch, height, width) = ( - encoder.inputs_minoptmax()[0][0].to_owned(), - encoder.inputs_minoptmax()[0][2].to_owned(), - encoder.inputs_minoptmax()[0][3].to_owned(), + encoder.batch().opt(), + encoder.try_height().unwrap_or(&1024.into()).opt(), + encoder.try_width().unwrap_or(&1024.into()).opt(), ); - let conf = DynConf::new(&options_decoder.confs, 1); + let ts = Ts::merge(&[encoder.ts(), decoder.ts()]); + let spec = encoder.spec().to_owned(); + + let processor = options_encoder + .to_processor()? 
+ .with_image_width(width as _) + .with_image_height(height as _); - let kind = match options_decoder.sam_kind { + let conf = DynConf::new(options_encoder.class_confs(), 1); + let find_contours = options_encoder.find_contours; + let kind = match options_encoder.sam_kind { Some(x) => x, None => anyhow::bail!("Error: no clear `SamKind` specified."), }; - let find_contours = options_decoder.find_contours; let use_low_res_mask = match kind { SamKind::Sam | SamKind::MobileSam | SamKind::SamHq => { - options_decoder.low_res_mask.unwrap_or(false) + options_encoder.low_res_mask.unwrap_or(false) } SamKind::EdgeSam | SamKind::Sam2 => true, }; - encoder.dry_run()?; - decoder.dry_run()?; - Ok(Self { encoder, decoder, + conf, batch, height, width, - conf, + ts, + processor, kind, find_contours, use_low_res_mask, + spec, }) } - pub fn run(&mut self, xs: &[DynamicImage], prompts: &[SamPrompt]) -> Result> { - let ys = self.encode(xs)?; - self.decode(&ys, xs, prompts) + pub fn forward(&mut self, xs: &[DynamicImage], prompts: &[SamPrompt]) -> Result { + let ys = elapsed!("encode", self.ts, { self.encode(xs)? }); + let ys = elapsed!("decode", self.ts, { self.decode(&ys, prompts)? 
}); + + Ok(ys) } pub fn encode(&mut self, xs: &[DynamicImage]) -> Result { - let xs_ = X::apply(&[ - Ops::Letterbox( - xs, - self.height() as u32, - self.width() as u32, - "Bilinear", - 0, - "auto", - false, - ), - Ops::Standardize(&[123.675, 116.28, 103.53], &[58.395, 57.12, 57.375], 3), - Ops::Nhwc2nchw, - ])?; - + let xs_ = self.processor.process_images(xs)?; self.encoder.run(Xs::from(xs_)) } - pub fn decode( - &mut self, - xs: &Xs, - xs0: &[DynamicImage], - prompts: &[SamPrompt], - ) -> Result> { + pub fn decode(&mut self, xs: &Xs, prompts: &[SamPrompt]) -> Result { let (image_embeddings, high_res_features_0, high_res_features_1) = match self.kind { SamKind::Sam2 => (&xs[0], Some(&xs[1]), Some(&xs[2])), _ => (&xs[0], None, None), @@ -150,44 +160,43 @@ impl SAM { let mut ys: Vec = Vec::new(); for (idx, image_embedding) in image_embeddings.axis_iter(Axis(0)).enumerate() { - let image_width = xs0[idx].width() as f32; - let image_height = xs0[idx].height() as f32; - let ratio = - (self.width() as f32 / image_width).min(self.height() as f32 / image_height); + let (image_height, image_width) = self.processor.image0s_size[idx]; + let ratio = self.processor.scale_factors_hw[idx][0]; + let args = match self.kind { SamKind::Sam | SamKind::MobileSam => { vec![ X::from(image_embedding.into_dyn().into_owned()) .insert_axis(0)? - .repeat(0, self.batch() as usize)?, // image_embedding + .repeat(0, self.batch)?, // image_embedding prompts[idx].point_coords(ratio)?, // point_coords prompts[idx].point_labels()?, // point_labels X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input, - X::zeros(&[1]), // has_mask_input - X::from(vec![image_height, image_width]), // orig_im_size + X::zeros(&[1]), // has_mask_input + X::from(vec![image_height as _, image_width as _]), // orig_im_size ] } SamKind::SamHq => { vec![ X::from(image_embedding.into_dyn().into_owned()) .insert_axis(0)? 
- .repeat(0, self.batch() as usize)?, // image_embedding + .repeat(0, self.batch)?, // image_embedding X::from(xs[1].slice(s![idx, .., .., ..]).into_dyn().into_owned()) .insert_axis(0)? .insert_axis(0)? - .repeat(0, self.batch() as usize)?, // intern_embedding + .repeat(0, self.batch)?, // intern_embedding prompts[idx].point_coords(ratio)?, // point_coords prompts[idx].point_labels()?, // point_labels X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input - X::zeros(&[1]), // has_mask_input - X::from(vec![image_height, image_width]), // orig_im_size + X::zeros(&[1]), // has_mask_input + X::from(vec![image_height as _, image_width as _]), // orig_im_size ] } SamKind::EdgeSam => { vec![ X::from(image_embedding.into_dyn().into_owned()) .insert_axis(0)? - .repeat(0, self.batch() as usize)?, + .repeat(0, self.batch)?, prompts[idx].point_coords(ratio)?, prompts[idx].point_labels()?, ] @@ -196,7 +205,7 @@ impl SAM { vec![ X::from(image_embedding.into_dyn().into_owned()) .insert_axis(0)? - .repeat(0, self.batch() as usize)?, + .repeat(0, self.batch)?, X::from( high_res_features_0 .unwrap() @@ -205,7 +214,7 @@ impl SAM { .into_owned(), ) .insert_axis(0)? - .repeat(0, self.batch() as usize)?, + .repeat(0, self.batch)?, X::from( high_res_features_1 .unwrap() @@ -214,12 +223,12 @@ impl SAM { .into_owned(), ) .insert_axis(0)? 
- .repeat(0, self.batch() as usize)?, + .repeat(0, self.batch)?, prompts[idx].point_coords(ratio)?, prompts[idx].point_labels()?, X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input - X::zeros(&[1]), // has_mask_input - X::from(vec![image_height, image_width]), // orig_im_size + X::zeros(&[1]), // has_mask_input + X::from(vec![image_height as _, image_width as _]), // orig_im_size ] } }; @@ -310,26 +319,14 @@ impl SAM { ys.push(y); } - Ok(ys) + Ok(ys.into()) } pub fn width_low_res(&self) -> usize { - self.width() as usize / 4 + self.width / 4 } pub fn height_low_res(&self) -> usize { - self.height() as usize / 4 - } - - pub fn batch(&self) -> isize { - self.batch.opt() as _ - } - - pub fn width(&self) -> isize { - self.width.opt() as _ - } - - pub fn height(&self) -> isize { - self.height.opt() as _ + self.height / 4 } } diff --git a/src/models/sam/mod.rs b/src/models/sam/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/sam/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/sapiens/config.rs b/src/models/sapiens/config.rs new file mode 100644 index 0000000..87dd6c5 --- /dev/null +++ b/src/models/sapiens/config.rs @@ -0,0 +1,47 @@ +use crate::models::BODY_PARTS_NAMES_28; + +/// Model configuration for `Sapiens` +impl crate::Options { + pub fn sapiens() -> Self { + Self::default() + .with_model_name("sapiens") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 2, 1024.into()) + .with_model_ixx(0, 3, 768.into()) + .with_resize_mode(crate::ResizeMode::FitExact) + .with_resize_filter("Bilinear") + .with_image_mean(&[123.5, 116.5, 103.5]) + .with_image_std(&[58.5, 57.0, 57.5]) + .with_normalize(false) + } + + pub fn sapiens_body_part_segmentation() -> Self { + Self::sapiens() + .with_model_task(crate::Task::InstanceSegmentation) + .with_class_names(&BODY_PARTS_NAMES_28) + } + + pub fn sapiens_seg_0_3b() -> Self { + 
Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b.onnx") + } + + // pub fn sapiens_seg_0_3b_uint8() -> Self { + // Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b-uint8.onnx") + // } + + // pub fn sapiens_seg_0_3b_fp16() -> Self { + // Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b-fp16.onnx") + // } + + // pub fn sapiens_seg_0_3b_bnb4() -> Self { + // Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b-bnb4.onnx") + // } + + // pub fn sapiens_seg_0_3b_q4f16() -> Self { + // Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b-q4f16.onnx") + // } + + // pub fn sapiens_seg_0_6b_fp16() -> Self { + // Self::sapiens_body_part_segmentation().with_model_file("seg-0.6b-fp16.onnx") + // } +} diff --git a/src/models/sapiens.rs b/src/models/sapiens/impl.rs similarity index 65% rename from src/models/sapiens.rs rename to src/models/sapiens/impl.rs index c43a6b9..927ca4b 100644 --- a/src/models/sapiens.rs +++ b/src/models/sapiens/impl.rs @@ -1,73 +1,84 @@ +use aksr::Builder; use anyhow::Result; use image::DynamicImage; use ndarray::{s, Array2, Axis}; -use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Polygon, Xs, X, Y}; +use crate::{elapsed, Engine, Mask, Ops, Options, Polygon, Processor, Task, Ts, Xs, Ys, Y}; -#[derive(Debug, Clone, clap::ValueEnum)] -pub enum SapiensTask { - Seg, - Depth, - Normal, - Pose, -} - -#[derive(Debug)] +#[derive(Builder, Debug)] pub struct Sapiens { - engine_seg: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, - task: SapiensTask, + engine: Engine, + height: usize, + width: usize, + batch: usize, + task: Task, names_body: Option>, + ts: Ts, + processor: Processor, + spec: String, } impl Sapiens { - pub fn new(options_seg: Options) -> Result { - let mut engine_seg = OrtEngine::new(&options_seg)?; - let (batch, height, width) = ( - engine_seg.batch().to_owned(), - engine_seg.height().to_owned(), - engine_seg.width().to_owned(), + pub fn new(options: 
Options) -> Result { + let engine = options.to_engine()?; + let spec = engine.spec().to_string(); + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&1024.into()).opt(), + engine.try_width().unwrap_or(&768.into()).opt(), + engine.ts().clone(), ); - let task = options_seg - .sapiens_task - .expect("Error: No sapiens task specified."); - let names_body = options_seg.names; - engine_seg.dry_run()?; + let processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + let task = options.model_task.expect("No sapiens task specified."); + let names_body = options.class_names; Ok(Self { - engine_seg, + engine, height, width, batch, task, names_body, + ts, + processor, + spec, }) } - pub fn run(&mut self, xs: &[DynamicImage]) -> Result> { - let xs_ = X::apply(&[ - Ops::Resize(xs, self.height() as u32, self.width() as u32, "Bilinear"), - Ops::Standardize(&[123.5, 116.5, 103.5], &[58.5, 57.0, 57.5], 3), - Ops::Nhwc2nchw, - ])?; - - match self.task { - SapiensTask::Seg => { - let ys = self.engine_seg.run(Xs::from(xs_))?; - self.postprocess_seg(ys, xs) + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + Ok(self.processor.process_images(xs)?.into()) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { + if let Task::InstanceSegmentation = self.task { + self.postprocess_seg(ys)? 
+ } else { + unimplemented!() } - _ => todo!(), - } + }); + + Ok(ys) } - pub fn postprocess_seg(&self, xs: Xs, xs0: &[DynamicImage]) -> Result> { + pub fn summary(&mut self) { + self.ts.summary(); + } + + pub fn postprocess_seg(&self, xs: Xs) -> Result { let mut ys: Vec = Vec::new(); for (idx, b) in xs[0].axis_iter(Axis(0)).enumerate() { - let (w1, h1) = (xs0[idx].width(), xs0[idx].height()); - // rescale + let (h1, w1) = self.processor.image0s_size[idx]; let masks = Ops::interpolate_3d(b.to_owned(), w1 as _, h1 as _, "Bilinear")?; // generate mask @@ -131,7 +142,6 @@ impl Sapiens { Some(p) => p, None => continue, }; - y_polygons.push(polygon); let mut mask = Mask::default().with_mask(luma).with_id(*i as _); @@ -142,18 +152,7 @@ impl Sapiens { } ys.push(Y::default().with_masks(&y_masks).with_polygons(&y_polygons)); } - Ok(ys) - } - - pub fn batch(&self) -> isize { - self.batch.opt() as _ - } - - pub fn width(&self) -> isize { - self.width.opt() as _ - } - pub fn height(&self) -> isize { - self.height.opt() as _ + Ok(ys.into()) } } diff --git a/src/models/sapiens/mod.rs b/src/models/sapiens/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/sapiens/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/scale.rs b/src/models/scale.rs new file mode 100644 index 0000000..4dc5ab4 --- /dev/null +++ b/src/models/scale.rs @@ -0,0 +1,83 @@ +#[derive(Debug, Copy, Clone)] +pub enum Scale { + N, + T, + B, + S, + M, + L, + C, + E, + X, + G, + P, + A, + F, +} + +impl std::fmt::Display for Scale { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let x = match self { + Self::N => "n", + Self::T => "t", + Self::S => "s", + Self::B => "b", + Self::M => "m", + Self::L => "l", + Self::C => "c", + Self::E => "e", + Self::X => "x", + Self::G => "g", + Self::P => "p", + Self::A => "a", + Self::F => "f", + }; + write!(f, "{}", x) + } +} + +impl TryFrom for Scale { + type Error = anyhow::Error; + + 
fn try_from(s: char) -> Result { + match s { + 'n' => Ok(Self::N), + 't' => Ok(Self::T), + 'b' => Ok(Self::B), + 's' => Ok(Self::S), + 'm' => Ok(Self::M), + 'l' => Ok(Self::L), + 'c' => Ok(Self::C), + 'e' => Ok(Self::E), + 'x' => Ok(Self::X), + 'g' => Ok(Self::G), + 'p' => Ok(Self::P), + 'a' => Ok(Self::A), + 'f' => Ok(Self::F), + x => anyhow::bail!("Unsupported model scale: {:?}", x), + } + } +} + +impl TryFrom<&str> for Scale { + type Error = anyhow::Error; + + fn try_from(s: &str) -> Result { + match s.to_lowercase().as_str() { + "n" | "nano" => Ok(Self::N), + "t" | "tiny" => Ok(Self::T), + "b" | "base" => Ok(Self::B), + "s" | "small" => Ok(Self::S), + "m" | "medium" => Ok(Self::M), + "l" | "large" => Ok(Self::L), + "c" => Ok(Self::C), + "e" => Ok(Self::E), + "x" | "extra-large" => Ok(Self::X), + "g" | "giant" => Ok(Self::G), + "p" | "pico" => Ok(Self::P), + "a" | "atto" => Ok(Self::A), + "f" | "femto" => Ok(Self::F), + x => anyhow::bail!("Unsupported model scale: {:?}", x), + } + } +} diff --git a/src/models/slanet/config.rs b/src/models/slanet/config.rs new file mode 100644 index 0000000..f29b311 --- /dev/null +++ b/src/models/slanet/config.rs @@ -0,0 +1,22 @@ +/// Model configuration for `SLANet` +impl crate::Options { + pub fn slanet() -> Self { + Self::default() + .with_model_name("slanet") + .with_model_ixx(0, 0, (1, 1, 8).into()) + .with_model_ixx(0, 2, (320, 488, 488).into()) + .with_model_ixx(0, 3, (320, 488, 488).into()) + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + .with_normalize(true) + .with_resize_mode(crate::ResizeMode::FitAdaptive) + .with_padding_value(0) + .with_unsigned(true) + } + + pub fn slanet_lcnet_v2_mobile_ch() -> Self { + Self::slanet() + .with_model_file("v2-mobile-ch.onnx") + .with_vocab_txt("vocab-sla-v2.txt") + } +} diff --git a/src/models/slanet/impl.rs b/src/models/slanet/impl.rs new file mode 100644 index 0000000..cfbd50c --- /dev/null +++ b/src/models/slanet/impl.rs @@ -0,0 +1,109 @@ 
+use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::{s, Axis}; + +use crate::{elapsed, models::BaseModelVisual, Keypoint, Options, Text, Ts, Xs, Ys, Y}; + +#[derive(Builder, Debug)] +pub struct SLANet { + base: BaseModelVisual, + td_tokens: Vec<&'static str>, + eos: usize, + sos: usize, + ts: Ts, + spec: String, +} + +impl SLANet { + pub fn summary(&mut self) { + self.ts.summary(); + } + + pub fn new(options: Options) -> Result { + let base = BaseModelVisual::new(options)?; + let spec = base.engine().spec().to_owned(); + let sos = 0; + let eos = base.processor().vocab().len() - 1; + let td_tokens = vec!["", ""]; + let ts = base.ts().clone(); + + Ok(Self { + base, + td_tokens, + eos, + sos, + ts, + spec, + }) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.base.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.base.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? 
}); + + Ok(ys) + } + + fn postprocess(&self, xs: Xs) -> Result { + let mut ys: Vec = Vec::new(); + for (bid, (bboxes, structures)) in xs[0] + .axis_iter(Axis(0)) + .zip(xs[1].axis_iter(Axis(0))) + .enumerate() + { + let mut y_texts: Vec = vec!["".into(), "".into(), "".into()]; + let mut y_kpts: Vec> = Vec::new(); + let (image_height, image_width) = self.base.processor().image0s_size[bid]; + for (i, structure) in structures.axis_iter(Axis(0)).enumerate() { + let (token_id, &_confidence) = match structure + .into_iter() + .enumerate() + .max_by(|a, b| a.1.total_cmp(b.1)) + { + None => continue, + Some((id, conf)) => (id, conf), + }; + if token_id == self.eos { + break; + } + if token_id == self.sos { + continue; + } + + // token + let token = self.base.processor().vocab()[token_id].as_str(); + + // keypoint + if self.td_tokens.contains(&token) { + let slice_bboxes = bboxes.slice(s![i, ..]); + let x14 = slice_bboxes + .slice(s![0..;2]) + .mapv(|x| x * image_width as f32); + let y14 = slice_bboxes + .slice(s![1..;2]) + .mapv(|x| x * image_height as f32); + y_kpts.push( + (0..=3) + .map(|i| (x14[i], y14[i], i as isize).into()) + .collect(), + ); + } + + y_texts.push(token.into()); + } + + // clean up text + if y_texts.len() == 3 { + y_texts.clear(); + } else { + y_texts.extend_from_slice(&["
".into(), "".into(), "".into()]); + } + + ys.push(Y::default().with_keypoints(&y_kpts).with_texts(&y_texts)); + } + + Ok(ys.into()) + } +} diff --git a/src/models/slanet/mod.rs b/src/models/slanet/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/slanet/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/svtr.rs b/src/models/svtr.rs deleted file mode 100644 index e2e9df7..0000000 --- a/src/models/svtr.rs +++ /dev/null @@ -1,101 +0,0 @@ -use anyhow::Result; -use image::DynamicImage; -use ndarray::Axis; - -use crate::{DynConf, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y}; - -#[derive(Debug)] -pub struct SVTR { - engine: OrtEngine, - pub height: MinOptMax, - pub width: MinOptMax, - pub batch: MinOptMax, - confs: DynConf, - vocab: Vec, -} - -impl SVTR { - pub fn new(options: Options) -> Result { - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.batch().to_owned(), - engine.height().to_owned(), - engine.width().to_owned(), - ); - let confs = DynConf::new(&options.confs, 1); - let mut vocab: Vec<_> = - std::fs::read_to_string(options.vocab.expect("No vocabulary found"))? 
- .lines() - .map(|line| line.to_string()) - .collect(); - vocab.push(" ".to_string()); - vocab.insert(0, "Blank".to_string()); - engine.dry_run()?; - - Ok(Self { - engine, - height, - width, - batch, - vocab, - confs, - }) - } - - pub fn run(&mut self, xs: &[DynamicImage]) -> Result> { - let xs_ = X::apply(&[ - Ops::Letterbox( - xs, - self.height.opt() as u32, - self.width.opt() as u32, - "Bilinear", - 0, - "auto", - false, - ), - Ops::Normalize(0., 255.), - Ops::Nhwc2nchw, - ])?; - - let ys = self.engine.run(Xs::from(xs_))?; - self.postprocess(ys) - } - - pub fn postprocess(&self, xs: Xs) -> Result> { - let mut ys: Vec = Vec::new(); - for batch in xs[0].axis_iter(Axis(0)) { - let preds = batch - .axis_iter(Axis(0)) - .filter_map(|x| { - x.into_iter() - .enumerate() - .max_by(|(_, x), (_, y)| x.total_cmp(y)) - }) - .collect::>(); - - let text = preds - .iter() - .enumerate() - .fold(Vec::new(), |mut text_ids, (idx, (text_id, &confidence))| { - if *text_id == 0 || confidence < self.confs[0] { - return text_ids; - } - - if idx == 0 || idx == self.vocab.len() - 1 { - return text_ids; - } - - if *text_id != preds[idx - 1].0 { - text_ids.push(*text_id); - } - text_ids - }) - .into_iter() - .map(|idx| self.vocab[idx].to_owned()) - .collect::(); - - ys.push(Y::default().with_texts(&[&text])) - } - Ok(ys) - } -} diff --git a/src/models/svtr/config.rs b/src/models/svtr/config.rs new file mode 100644 index 0000000..93fc38e --- /dev/null +++ b/src/models/svtr/config.rs @@ -0,0 +1,43 @@ +/// Model configuration for `SVTR` +impl crate::Options { + pub fn svtr() -> Self { + Self::default() + .with_model_name("svtr") + .with_model_ixx(0, 0, (1, 1, 8).into()) + .with_model_ixx(0, 2, 48.into()) + .with_model_ixx(0, 3, (320, 960, 1600).into()) + .with_resize_mode(crate::ResizeMode::FitHeight) + .with_padding_value(0) + .with_normalize(true) + .with_class_confs(&[0.2]) + .with_vocab_txt("vocab-v1-ppocr-rec-ch.txt") + } + + pub fn ppocr_rec_v3_ch() -> Self { + 
Self::svtr().with_model_file("ppocr-v3-ch.onnx") + } + + pub fn ppocr_rec_v4_ch() -> Self { + Self::svtr().with_model_file("ppocr-v4-ch.onnx") + } + + pub fn ppocr_rec_v4_server_ch() -> Self { + Self::svtr().with_model_file("ppocr-v4-server-ch.onnx") + } + + pub fn svtr_v2_server_ch() -> Self { + Self::svtr().with_model_file("v2-server-ch.onnx") + } + + pub fn repsvtr_ch() -> Self { + Self::svtr().with_model_file("repsvtr-ch.onnx") + } + + pub fn svtr_v2_teacher_ch() -> Self { + Self::svtr().with_model_file("v2-distill-teacher-ch.onnx") + } + + pub fn svtr_v2_student_ch() -> Self { + Self::svtr().with_model_file("v2-distill-student-ch.onnx") + } +} diff --git a/src/models/svtr/impl.rs b/src/models/svtr/impl.rs new file mode 100644 index 0000000..f14f5f8 --- /dev/null +++ b/src/models/svtr/impl.rs @@ -0,0 +1,109 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::Axis; + +use crate::{elapsed, DynConf, Engine, Options, Processor, Ts, Xs, Ys, Y}; + +#[derive(Builder, Debug)] +pub struct SVTR { + engine: Engine, + height: usize, + width: usize, + batch: usize, + confs: DynConf, + spec: String, + ts: Ts, + processor: Processor, +} + +impl SVTR { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&960.into()).opt(), + engine.try_width().unwrap_or(&960.into()).opt(), + engine.ts.clone(), + ); + let spec = options.model_spec().to_string(); + let confs = DynConf::new(options.class_confs(), 1); + let processor = options + .to_processor()? 
+ .with_image_width(width as _) + .with_image_height(height as _); + if processor.vocab().is_empty() { + anyhow::bail!("No vocab file found") + } + + Ok(Self { + engine, + height, + width, + batch, + confs, + processor, + spec, + ts, + }) + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + Ok(self.processor.process_images(xs)?.into()) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? }); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + pub fn postprocess(&self, xs: Xs) -> Result { + let mut ys: Vec = Vec::new(); + for batch in xs[0].axis_iter(Axis(0)) { + let preds = batch + .axis_iter(Axis(0)) + .filter_map(|x| { + x.into_iter() + .enumerate() + .max_by(|(_, x), (_, y)| x.total_cmp(y)) + }) + .collect::>(); + + let text = preds + .iter() + .enumerate() + .fold(Vec::new(), |mut text_ids, (idx, (text_id, &confidence))| { + if *text_id == 0 || confidence < self.confs[0] { + return text_ids; + } + + if idx == 0 || idx == self.processor.vocab().len() - 1 { + return text_ids; + } + + if *text_id != preds[idx - 1].0 { + text_ids.push(*text_id); + } + text_ids + }) + .into_iter() + .map(|idx| self.processor.vocab()[idx].to_owned()) + .collect::(); + + ys.push(Y::default().with_texts(&[text.into()])) + } + + Ok(ys.into()) + } +} diff --git a/src/models/svtr/mod.rs b/src/models/svtr/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/svtr/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/core/task.rs b/src/models/task.rs similarity index 75% rename from src/core/task.rs rename to src/models/task.rs index 8090625..80e5c33 100644 --- a/src/core/task.rs +++ 
b/src/models/task.rs @@ -1,7 +1,5 @@ -#[derive(Debug, Clone, Ord, Eq, PartialOrd, PartialEq)] +#[derive(Debug, Copy, Clone, Ord, Eq, PartialOrd, PartialEq)] pub enum Task { - Untitled, - /// Image classification task. /// Input: image /// Output: a label representing the class of the image @@ -27,13 +25,14 @@ pub enum Task { /// Input: image /// Output: bounding boxes (bboxes), class labels, and optional scores for the detected objects ObjectDetection, + OrientedObjectDetection, + Obb, /// Open set detection task, detecting and classifying objects in an image, with the ability to handle unseen or unknown objects. /// Input: image /// Output: bounding boxes, class labels (including an "unknown" category for unfamiliar objects), and detection scores /// Open set detection task, with String query - OpenSetDetection(String), - + OpenSetDetection(&'static str), /// Task for generating brief descriptions of dense regions in the image. /// Input: image /// Output: bounding boxes (bboxes), brief phrase labels, and optional scores for detected regions @@ -44,12 +43,16 @@ pub enum Task { /// Input: image /// Output: coordinates of detected keypoints KeypointsDetection, + Pose, /// Semantic segmentation task, segmenting the image into different semantic regions. /// Input: image /// Output: per-pixel class labels indicating object or background SemanticSegmentation, + ImageFeatureExtraction, + TextFeatureExtraction, + /// Instance segmentation task, detecting and segmenting individual object instances. /// Input: image /// Output: pixel masks for each object instance @@ -94,12 +97,12 @@ pub enum Task { /// Input: image and text /// Output: image region and the corresponding phrase /// caption to phrase grounding - CaptionToPhraseGrounding(String), + CaptionToPhraseGrounding(&'static str), /// Referring expression segmentation task, segmenting objects in the image based on a text description. 
/// Input: image and referring expression /// Output: a segmentation mask for the object referred to by the text - ReferringExpressionSegmentation(String), + ReferringExpressionSegmentation(&'static str), /// Region-to-segmentation task, similar to combining object detection with segmentation (e.g., YOLO + SAM). /// Input: image and region proposals @@ -122,7 +125,7 @@ pub enum Task { /// Visual question answering (VQA) task, answering questions related to an image. /// Input: image and question text /// Output: the answer to the question - Vqa(String), + Vqa(&'static str), /// Optical character recognition (OCR) task, recognizing text in an image. /// Input: image @@ -135,10 +138,59 @@ pub enum Task { OcrWithRegion, } +impl std::fmt::Display for Task { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let x = match self { + Self::ImageClassification => "image-classification", + Self::ObjectDetection => "object-detection", + Self::Pose => "pose", + Self::KeypointsDetection => "pose-detection", + Self::InstanceSegmentation => "instance-segmentation", + Self::Obb => "obb", + Self::OrientedObjectDetection => "oriented-object-detection", + Self::DepthEstimation => "depth", + Self::Caption(0) => "caption", + Self::Caption(1) => "detailed-caption", + Self::Caption(2) => "more-detailed-caption", + Self::ImageTagging => "image-tagging", + Self::Ocr => "ocr", + Self::OcrWithRegion => "ocr-with-region", + Self::Vqa(_) => "vqa", + _ => todo!(), + }; + write!(f, "{}", x) + } +} + +impl TryFrom<&str> for Task { + type Error = anyhow::Error; + + fn try_from(s: &str) -> Result { + match s.to_lowercase().as_str() { + "cls" | "classify" | "classification" => Ok(Self::ImageClassification), + "det" | "od" | "detect" => Ok(Self::ObjectDetection), + "kpt" | "pose" => Ok(Self::KeypointsDetection), + "seg" | "segment" => Ok(Self::InstanceSegmentation), + "obb" => Ok(Self::OrientedObjectDetection), + _ => todo!(), // x => anyhow::bail!("Unsupported model task: {}", x), 
+ } + } +} + impl Task { + pub fn yolo_str(&self) -> &'static str { + match self { + Self::ImageClassification => "cls", + Self::ObjectDetection => "det", + Self::Pose | Self::KeypointsDetection => "pose", + Self::InstanceSegmentation => "seg", + Self::Obb | Self::OrientedObjectDetection => "obb", + x => unimplemented!("Unsupported YOLO Task: {}", x), + } + } + pub fn prompt_for_florence2(&self) -> anyhow::Result { let prompt = match self { - Self::Untitled => anyhow::bail!("No task specified."), Self::Caption(0) => "What does the image describe?".to_string(), Self::Caption(1) => "Describe in detail what is shown in the image.".to_string(), Self::Caption(2) => "Describe with a paragraph what is shown in the image.".to_string(), @@ -178,7 +230,7 @@ impl Task { x0, y0, x1, y1 ) } - _ => anyhow::bail!("Unsupported task."), + x => anyhow::bail!("Unsupported Florence2 task: {:?}", x), }; Ok(prompt) diff --git a/src/models/trocr/config.rs b/src/models/trocr/config.rs new file mode 100644 index 0000000..8343434 --- /dev/null +++ b/src/models/trocr/config.rs @@ -0,0 +1,92 @@ +use crate::Scale; + +/// Model configuration for `TrOCR` +impl crate::Options { + pub fn trocr() -> Self { + Self::default().with_model_name("trocr").with_batch_size(1) + } + + pub fn trocr_visual() -> Self { + Self::trocr() + .with_model_kind(crate::Kind::Vision) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 384.into()) + .with_model_ixx(0, 3, 384.into()) + .with_image_mean(&[0.5, 0.5, 0.5]) + .with_image_std(&[0.5, 0.5, 0.5]) + .with_resize_filter("Bilinear") + .with_normalize(true) + } + + pub fn trocr_textual() -> Self { + Self::trocr().with_model_kind(crate::Kind::Language) + } + + pub fn trocr_visual_small() -> Self { + Self::trocr_visual().with_model_scale(Scale::S) + } + + pub fn trocr_textual_small() -> Self { + Self::trocr_textual() + .with_model_scale(Scale::S) + .with_tokenizer_file("trocr/tokenizer-small.json") + } + + pub fn trocr_visual_base() -> Self { + 
Self::trocr_visual().with_model_scale(Scale::B) + } + + pub fn trocr_textual_base() -> Self { + Self::trocr_textual() + .with_model_scale(Scale::B) + .with_tokenizer_file("trocr/tokenizer-base.json") + } + + pub fn trocr_encoder_small_printed() -> Self { + Self::trocr_visual_small().with_model_file("s-encoder-printed.onnx") + } + + pub fn trocr_decoder_small_printed() -> Self { + Self::trocr_textual_small().with_model_file("s-decoder-printed.onnx") + } + + pub fn trocr_decoder_merged_small_printed() -> Self { + Self::trocr_textual_small().with_model_file("s-decoder-merged-printed.onnx") + } + + pub fn trocr_encoder_small_handwritten() -> Self { + Self::trocr_visual_small().with_model_file("s-encoder-handwritten.onnx") + } + + pub fn trocr_decoder_small_handwritten() -> Self { + Self::trocr_textual_small().with_model_file("s-decoder-handwritten.onnx") + } + + pub fn trocr_decoder_merged_small_handwritten() -> Self { + Self::trocr_textual_small().with_model_file("s-decoder-merged-handwritten.onnx") + } + + pub fn trocr_encoder_base_printed() -> Self { + Self::trocr_visual_base().with_model_file("b-encoder-printed.onnx") + } + + pub fn trocr_decoder_base_printed() -> Self { + Self::trocr_textual_base().with_model_file("b-decoder-printed.onnx") + } + + pub fn trocr_decoder_merged_base_printed() -> Self { + Self::trocr_textual_base().with_model_file("b-decoder-merged-printed.onnx") + } + + pub fn trocr_encoder_base_handwritten() -> Self { + Self::trocr_visual_base().with_model_file("b-encoder-handwritten.onnx") + } + + pub fn trocr_decoder_base_handwritten() -> Self { + Self::trocr_textual_base().with_model_file("b-decoder-handwritten.onnx") + } + + pub fn trocr_decoder_merged_base_handwritten() -> Self { + Self::trocr_textual_base().with_model_file("b-decoder-merged-handwritten.onnx") + } +} diff --git a/src/models/trocr/impl.rs b/src/models/trocr/impl.rs new file mode 100644 index 0000000..3e39b59 --- /dev/null +++ b/src/models/trocr/impl.rs @@ -0,0 +1,292 @@ +use 
aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::{s, Axis}; +use rayon::prelude::*; + +use crate::{ + elapsed, + models::{BaseModelTextual, BaseModelVisual}, + LogitsSampler, Options, Scale, Ts, Xs, Ys, X, Y, +}; + +#[derive(Debug, Copy, Clone)] +pub enum TrOCRKind { + Printed, + HandWritten, +} + +impl TryFrom<&str> for TrOCRKind { + type Error = anyhow::Error; + + fn try_from(s: &str) -> Result { + match s.to_lowercase().as_str() { + "printed" => Ok(Self::Printed), + "handwritten" | "hand-written" => Ok(Self::HandWritten), + x => anyhow::bail!("Unsupported TrOCRKind: {}", x), + } + } +} + +#[derive(Debug, Builder)] +pub struct TrOCR { + encoder: BaseModelVisual, + decoder: BaseModelTextual, + decoder_merged: BaseModelTextual, + max_length: u32, + eos_token_id: u32, + decoder_start_token_id: u32, + ts: Ts, + n_kvs: usize, +} + +impl TrOCR { + pub fn summary(&self) { + self.ts.summary(); + } + + pub fn new( + options_encoder: Options, + options_decoder: Options, + options_decoder_merged: Options, + ) -> Result { + let encoder = BaseModelVisual::new(options_encoder)?; + let decoder = BaseModelTextual::new(options_decoder)?; + let decoder_merged = BaseModelTextual::new(options_decoder_merged)?; + let ts = Ts::merge(&[ + encoder.engine().ts(), + decoder.engine().ts(), + decoder_merged.engine().ts(), + ]); + + // "bos_token": "", "eos_token": "", "sep_token": "", + // "model_max_length": 1000000000000000019884624838656, + // let bos_token = ""; + // let eos_token = ""; + // let sep_token = ""; + // let bos_token_id = 0; + // let pad_token_id = 1; + let max_length = 1024; // TODO + let eos_token_id = 2; + let decoder_start_token_id = 2; + let n_kvs = match decoder.scale() { + Some(Scale::S) => 6, + Some(Scale::B) => 12, + _ => unimplemented!(), + }; + + Ok(Self { + encoder, + decoder, + decoder_merged, + max_length, + ts, + eos_token_id, + decoder_start_token_id, + n_kvs, + }) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { 
+ let encoder_hidden_states = elapsed!("encode", self.ts, { self.encoder.encode(xs)? }); + let generated = elapsed!("generate", self.ts, { + self.generate(&encoder_hidden_states)? + }); + let ys = elapsed!("decode", self.ts, { self.decode(generated)? }); + + Ok(ys) + } + + fn generate(&mut self, encoder_hidden_states: &X) -> Result>> { + // input_ids + let input_ids = X::from(vec![self.decoder_start_token_id as f32]) + .insert_axis(0)? + .repeat(0, self.encoder.batch())?; + + // decoder + let mut decoder_outputs = self.decoder.inference(Xs::from(vec![ + input_ids.clone(), + encoder_hidden_states.clone(), + ]))?; + + // encoder kvs + let encoder_kvs: Vec<_> = (3..4 * self.n_kvs) + .step_by(4) + .flat_map(|i| [i, i + 1]) + .map(|i| decoder_outputs[i].clone()) + .collect(); + + // token ids + let mut token_ids: Vec> = vec![vec![]; self.encoder.batch()]; + let mut finished = vec![false; self.encoder.batch()]; + let mut last_tokens: Vec = vec![0.; self.encoder.batch()]; + let mut logits_sampler = LogitsSampler::new(); + + // generate + for _ in 0..self.max_length { + let logits = &decoder_outputs[0]; + let decoder_kvs: Vec<_> = (1..(4 * self.n_kvs) - 2) + .step_by(4) + .flat_map(|i| [i, i + 1]) + .map(|i| decoder_outputs[i].clone()) + .collect(); + + // decode each token for each batch + for (i, logit) in logits.axis_iter(Axis(0)).enumerate() { + if !finished[i] { + let token_id = logits_sampler.decode( + &logit + .slice(s![-1, ..]) + .into_owned() + .into_raw_vec_and_offset() + .0, + )?; + + if token_id == self.eos_token_id { + finished[i] = true; + } else { + token_ids[i].push(token_id); + } + + // update + last_tokens[i] = token_id as f32; + } + } + + // all finished? 
+ if finished.iter().all(|&x| x) { + break; + } + + // build inputs + let input_ids = X::from(last_tokens.clone()).insert_axis(1)?; + let mut xs = vec![input_ids, encoder_hidden_states.clone()]; + for i in 0..self.n_kvs { + xs.push(decoder_kvs[i * 2].clone()); + xs.push(decoder_kvs[i * 2 + 1].clone()); + xs.push(encoder_kvs[i * 2].clone()); + xs.push(encoder_kvs[i * 2 + 1].clone()); + } + xs.push(X::ones(&[1])); // use_cache + + // generate + decoder_outputs = self.decoder_merged.inference(xs.into())?; + } + + Ok(token_ids) + } + + pub fn decode(&self, token_ids: Vec>) -> Result { + // decode + let texts = self + .decoder_merged + .processor() + .decode_tokens_batch(&token_ids, false)?; + + // to texts + let texts = texts + .into_par_iter() + .map(|x| Y::default().with_texts(&[x.into()])) + .collect::>() + .into(); + + Ok(texts) + } +} + +// #[derive(Debug, Builder)] +// pub struct TrOCREncoder { +// // TODO: `BaseVisualEncoder`, `BaseVisualModel` struct? +// engine: Engine, +// height: usize, +// width: usize, +// batch: usize, +// processor: Processor, +// } + +// impl TrOCREncoder { +// pub fn new(options: Options) -> Result { +// let engine = options.to_engine()?; +// let (batch, height, width) = ( +// engine.batch().opt(), +// engine.try_height().unwrap_or(&384.into()).opt(), +// engine.try_width().unwrap_or(&384.into()).opt(), +// ); +// let processor = options +// .to_processor()? 
+// .with_image_width(width as _) +// .with_image_height(height as _); + +// Ok(Self { +// engine, +// height, +// width, +// batch, +// processor, +// }) +// } + +// pub fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { +// self.batch = xs.len(); // TODO +// let x = self.processor.process_images(xs)?; + +// Ok(x.into()) +// } + +// pub fn inference(&mut self, xs: Xs) -> Result { +// self.engine.run(xs) +// } + +// fn encode(&mut self, xs: &[DynamicImage]) -> Result { +// // encode a batch of images into one embedding, that's `X` +// let xs = self.preprocess(xs)?; +// let xs = self.inference(xs)?; +// let x = xs[0].to_owned(); + +// Ok(x) +// } +// } + +// #[derive(Debug, Builder)] +// pub struct TrOCRDecoder { +// engine: Engine, +// batch: usize, +// } + +// impl TrOCRDecoder { +// pub fn new(options: Options) -> Result { +// let engine = options.to_engine()?; +// let batch = engine.batch().opt(); + +// Ok(Self { engine, batch }) +// } + +// pub fn inference(&mut self, xs: Xs) -> Result { +// self.engine.run(xs) +// } +// } + +// #[derive(Debug, Builder)] +// pub struct TrOCRDecoderMerged { +// engine: Engine, +// batch: usize, +// processor: Processor, +// } + +// impl TrOCRDecoderMerged { +// pub fn new(options: Options) -> Result { +// let engine = options.to_engine()?; +// let batch = engine.batch().opt(); +// let processor = options.to_processor()?; + +// Ok(Self { +// engine, +// batch, +// processor, +// }) +// } + +// pub fn inference(&mut self, xs: Xs) -> Result { +// self.engine.run(xs) +// } +// } diff --git a/src/models/trocr/mod.rs b/src/models/trocr/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/trocr/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/version.rs b/src/models/version.rs new file mode 100644 index 0000000..022f39e --- /dev/null +++ b/src/models/version.rs @@ -0,0 +1,43 @@ +#[derive(Debug, PartialEq, Eq, Copy, Clone, Hash, Default)] +pub struct 
Version(pub u8, pub u8); + +impl std::fmt::Display for Version { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let x = if self.1 == 0 { + format!("v{}", self.0) + } else { + format!("v{}.{}", self.0, self.1) + }; + write!(f, "{}", x) + } +} + +impl From<(u8, u8)> for Version { + fn from((x, y): (u8, u8)) -> Self { + Self(x, y) + } +} + +impl From for Version { + fn from(x: f32) -> Self { + let x = format!("{:?}", x); + let x: Vec = x + .as_str() + .split('.') + .map(|x| x.parse::().unwrap_or(0)) + .collect(); + Self(x[0], x[1]) + } +} + +impl From for Version { + fn from(x: u8) -> Self { + Self(x, 0) + } +} + +impl Version { + pub fn new(x: u8, y: u8) -> Self { + Self(x, y) + } +} diff --git a/src/models/yolo/config.rs b/src/models/yolo/config.rs new file mode 100644 index 0000000..9a0c31d --- /dev/null +++ b/src/models/yolo/config.rs @@ -0,0 +1,199 @@ +use crate::{ + models::{YOLOPredsFormat, COCO_KEYPOINTS_NAMES_17}, + Options, ResizeMode, Scale, Task, +}; + +impl Options { + pub fn yolo() -> Self { + Self::default() + .with_model_name("yolo") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) + .with_model_ixx(0, 2, 640.into()) + .with_model_ixx(0, 3, 640.into()) + .with_resize_mode(ResizeMode::FitAdaptive) + .with_resize_filter("CatmullRom") + .with_find_contours(true) + } + + pub fn doclayout_yolo_docstructbench() -> Self { + Self::yolo_v10() + .with_model_file("doclayout-docstructbench.onnx") // TODO: batch_size > 1 + .with_model_ixx(0, 2, (640, 1024, 1024).into()) + .with_model_ixx(0, 3, (640, 1024, 1024).into()) + .with_class_confs(&[0.4]) + .with_class_names(&[ + "title", + "plain text", + "abandon", + "figure", + "figure_caption", + "table", + "table_caption", + "table_footnote", + "isolate_formula", + "formula_caption", + ]) + } + + pub fn yolo_classify() -> Self { + Self::yolo() + .with_model_task(Task::ImageClassification) + .with_model_ixx(0, 2, 224.into()) + .with_model_ixx(0, 3, 224.into()) + 
.with_resize_mode(ResizeMode::FitExact) + .with_resize_filter("Bilinear") + } + + pub fn yolo_detect() -> Self { + Self::yolo().with_model_task(Task::ObjectDetection) + } + + pub fn yolo_pose() -> Self { + Self::yolo() + .with_model_task(Task::KeypointsDetection) + .with_keypoint_names(&COCO_KEYPOINTS_NAMES_17) + } + + pub fn yolo_segment() -> Self { + Self::yolo().with_model_task(Task::InstanceSegmentation) + } + + pub fn yolo_obb() -> Self { + Self::yolo().with_model_task(Task::OrientedObjectDetection) + } + + pub fn fastsam_s() -> Self { + Self::yolo_segment() + .with_model_scale(Scale::S) + .with_model_version(8.0.into()) + .with_model_file("FastSAM-s.onnx") + } + + pub fn yolo_v8_rtdetr() -> Self { + Self::yolo() + .with_model_version(7.0.into()) + .with_yolo_preds_format(YOLOPredsFormat::n_a_cxcywh_clss_n()) + } + + pub fn yolo_v8_rtdetr_l() -> Self { + Self::yolo_v8_rtdetr() + .with_yolo_preds_format(YOLOPredsFormat::n_a_cxcywh_clss_n()) + .with_model_scale(Scale::L) + .with_model_file("rtdetr-l-det.onnx") + } + + pub fn yolo_v8_rtdetr_x() -> Self { + Self::yolo_v8_rtdetr() + .with_yolo_preds_format(YOLOPredsFormat::n_a_cxcywh_clss_n()) + .with_model_scale(Scale::X) + } + + pub fn yolo_n() -> Self { + Self::yolo().with_model_scale(Scale::N) + } + + pub fn yolo_s() -> Self { + Self::yolo().with_model_scale(Scale::S) + } + + pub fn yolo_m() -> Self { + Self::yolo().with_model_scale(Scale::M) + } + + pub fn yolo_l() -> Self { + Self::yolo().with_model_scale(Scale::L) + } + + pub fn yolo_x() -> Self { + Self::yolo().with_model_scale(Scale::X) + } + + pub fn yolo_v5() -> Self { + Self::yolo().with_model_version(5.0.into()) + } + + pub fn yolo_v6() -> Self { + Self::yolo().with_model_version(6.0.into()) + } + + pub fn yolo_v7() -> Self { + Self::yolo().with_model_version(7.0.into()) + } + + pub fn yolo_v8() -> Self { + Self::yolo().with_model_version(8.0.into()) + } + + pub fn yolo_v9() -> Self { + Self::yolo().with_model_version(9.0.into()) + } + + pub fn 
yolo_v10() -> Self { + Self::yolo().with_model_version(10.0.into()) + } + + pub fn yolo_v11() -> Self { + Self::yolo().with_model_version(11.0.into()) + } + + pub fn yolo_v8_n() -> Self { + Self::yolo() + .with_model_version(8.0.into()) + .with_model_scale(Scale::N) + } + + pub fn yolo_v8_s() -> Self { + Self::yolo() + .with_model_version(8.0.into()) + .with_model_scale(Scale::S) + } + + pub fn yolo_v8_m() -> Self { + Self::yolo() + .with_model_version(8.0.into()) + .with_model_scale(Scale::M) + } + + pub fn yolo_v8_l() -> Self { + Self::yolo() + .with_model_version(8.0.into()) + .with_model_scale(Scale::L) + } + + pub fn yolo_v8_x() -> Self { + Self::yolo() + .with_model_version(8.0.into()) + .with_model_scale(Scale::X) + } + + pub fn yolo_v11_n() -> Self { + Self::yolo() + .with_model_version(11.0.into()) + .with_model_scale(Scale::N) + } + + pub fn yolo_v11_s() -> Self { + Self::yolo() + .with_model_version(11.0.into()) + .with_model_scale(Scale::S) + } + + pub fn yolo_v11_m() -> Self { + Self::yolo() + .with_model_version(11.0.into()) + .with_model_scale(Scale::M) + } + + pub fn yolo_v11_l() -> Self { + Self::yolo() + .with_model_version(11.0.into()) + .with_model_scale(Scale::L) + } + + pub fn yolo_v11_x() -> Self { + Self::yolo() + .with_model_version(11.0.into()) + .with_model_scale(Scale::X) + } +} diff --git a/src/models/yolo.rs b/src/models/yolo/impl.rs similarity index 58% rename from src/models/yolo.rs rename to src/models/yolo/impl.rs index 45c23c3..396b602 100644 --- a/src/models/yolo.rs +++ b/src/models/yolo/impl.rs @@ -1,224 +1,296 @@ +use aksr::Builder; use anyhow::Result; use image::DynamicImage; +use log::{error, info}; use ndarray::{s, Array, Axis}; use rayon::prelude::*; use regex::Regex; use crate::{ - Bbox, BoxType, DynConf, Keypoint, Mask, Mbr, MinOptMax, Ops, Options, OrtEngine, Polygon, Prob, - Vision, Xs, YOLOPreds, YOLOTask, YOLOVersion, X, Y, + elapsed, + models::{BoxType, YOLOPredsFormat}, + Bbox, DynConf, Engine, Keypoint, Mask, Mbr, 
Ops, Options, Polygon, Prob, Processor, Task, Ts, + Version, Xs, Ys, Y, }; -#[derive(Debug)] +#[derive(Debug, Builder)] pub struct YOLO { - engine: OrtEngine, + engine: Engine, + height: usize, + width: usize, + batch: usize, + layout: YOLOPredsFormat, + task: Task, + version: Option, + names: Vec, + names_kpt: Vec, nc: usize, nk: usize, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, confs: DynConf, kconfs: DynConf, iou: f32, - names: Vec, - names_kpt: Vec, - task: YOLOTask, - layout: YOLOPreds, find_contours: bool, - version: Option, - classes_excluded: Vec, - classes_retained: Vec, + processor: Processor, + ts: Ts, + spec: String, + classes_excluded: Vec, + classes_retained: Vec, } -impl Vision for YOLO { - type Input = DynamicImage; +impl TryFrom for YOLO { + type Error = anyhow::Error; - fn new(options: Options) -> Result { - let span = tracing::span!(tracing::Level::INFO, "YOLO-new"); - let _guard = span.enter(); + fn try_from(options: Options) -> Result { + Self::new(options) + } +} - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.batch().to_owned(), - engine.height().to_owned(), - engine.width().to_owned(), +impl YOLO { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let (batch, height, width, ts, spec) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&640.into()).opt(), + engine.try_width().unwrap_or(&640.into()).opt(), + engine.ts.clone(), + engine.spec().to_owned(), ); - - // YOLO Task - let task = options - .yolo_task - .or(engine.try_fetch("task").and_then(|x| match x.as_str() { - "classify" => Some(YOLOTask::Classify), - "detect" => Some(YOLOTask::Detect), - "pose" => Some(YOLOTask::Pose), - "segment" => Some(YOLOTask::Segment), - "obb" => Some(YOLOTask::Obb), - s => { - tracing::error!("YOLO Task: {s:?} is unsupported"); - None - } - })); - - // YOLO Outputs Format - let (version, layout) = match options.yolo_version { - Some(ver) => match &task { - None 
=> anyhow::bail!("No clear YOLO Task specified for Version: {ver:?}."), - Some(task) => match task { - YOLOTask::Classify => match ver { - YOLOVersion::V5 => (Some(ver), YOLOPreds::n_clss().apply_softmax(true)), - YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_clss()), - x => anyhow::bail!("YOLOTask::Classify is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.") - } - YOLOTask::Detect => match ver { - YOLOVersion::V5 | YOLOVersion::V6 | YOLOVersion::V7 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss()), - YOLOVersion::V8 | YOLOVersion::V9 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_a()), - YOLOVersion::V10 => (Some(ver), YOLOPreds::n_a_xyxy_confcls().apply_nms(false)), - YOLOVersion::RTDETR => (Some(ver), YOLOPreds::n_a_cxcywh_clss_n().apply_nms(false)), - } - YOLOTask::Pose => match ver { - YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_xycs_a()), - x => anyhow::bail!("YOLOTask::Pose is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.") + let processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + let task: Option = match options.model_task { + Some(task) => Some(task), + None => match engine.try_fetch("task") { + Some(x) => match x.as_str() { + "classify" => Some(Task::ImageClassification), + "detect" => Some(Task::ObjectDetection), + "pose" => Some(Task::KeypointsDetection), + "segment" => Some(Task::InstanceSegmentation), + "obb" => Some(Task::OrientedObjectDetection), + x => { + error!("Unsupported YOLO Task: {}", x); + None } - YOLOTask::Segment => match ver { - YOLOVersion::V5 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss_coefs()), - YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_coefs_a()), - x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. 
Try using `.with_yolo_preds()` for customization.") - } - YOLOTask::Obb => match ver { - YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_r_a()), - x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.") + }, + None => None, + }, + }; + + // Task & layout + let version = options.model_version; + let (layout, task) = match &options.yolo_preds_format { + // customized + Some(layout) => { + // check task + let task_parsed = layout.task(); + let task = match task { + Some(task) => { + if task_parsed != task { + anyhow::bail!( + "Task specified: {:?} is inconsistent with parsed from yolo_preds_format: {:?}", + task, + task_parsed + ); + } + task_parsed } - } - } - None => match options.yolo_preds { - None => anyhow::bail!("No clear YOLO version or YOLO Format specified."), - Some(fmt) => (None, fmt) + None => task_parsed, + }; + + (layout.clone(), task) } - }; - let task = task.unwrap_or(layout.task()); + // version + task + None => match (task, version) { + (Some(task), Some(version)) => { + let layout = match (task, version) { + (Task::ImageClassification, Version(5, 0)) => { + YOLOPredsFormat::n_clss().apply_softmax(true) + } + (Task::ImageClassification, Version(8, 0) | Version(11, 0)) => { + YOLOPredsFormat::n_clss() + } + (Task::ObjectDetection, Version(5, 0) | Version(6, 0) | Version(7, 0)) => { + YOLOPredsFormat::n_a_cxcywh_confclss() + } + (Task::ObjectDetection, Version(8, 0) | Version(9, 0) | Version(11, 0)) => { + YOLOPredsFormat::n_cxcywh_clss_a() + } + (Task::ObjectDetection, Version(10, 0)) => { + YOLOPredsFormat::n_a_xyxy_confcls().apply_nms(false) + } + (Task::KeypointsDetection, Version(8, 0) | Version(11, 0)) => { + YOLOPredsFormat::n_cxcywh_clss_xycs_a() + } + (Task::InstanceSegmentation, Version(5, 0)) => { + YOLOPredsFormat::n_a_cxcywh_confclss_coefs() + } + (Task::InstanceSegmentation, Version(8, 0) | Version(11, 0)) => { + 
YOLOPredsFormat::n_cxcywh_clss_coefs_a() + } + (Task::OrientedObjectDetection, Version(8, 0) | Version(11, 0)) => { + YOLOPredsFormat::n_cxcywh_clss_r_a() + } + (task, version) => { + anyhow::bail!("Task: {:?} is unsupported for Version: {:?}. Try using `.with_yolo_preds()` for customization.", task, version) + } + }; + + (layout, task) + } + (None, Some(version)) => { + let layout = match version { + // single task, no need to specified task + Version(6, 0) | Version(7, 0) => YOLOPredsFormat::n_a_cxcywh_confclss(), + Version(9, 0) => YOLOPredsFormat::n_cxcywh_clss_a(), + Version(10, 0) => YOLOPredsFormat::n_a_xyxy_confcls().apply_nms(false), + _ => { + anyhow::bail!( + "No clear YOLO Task specified for Version: {:?}.", + version + ) + } + }; + + (layout, Task::ObjectDetection) + } + (Some(task), None) => { + anyhow::bail!("No clear YOLO Version specified for Task: {:?}.", task) + } + (None, None) => { + anyhow::bail!("No clear YOLO Task and Version specified.") + } + }, + }; - // Class names: user-defined.or(parsed) - let names_parsed = Self::fetch_names(&engine); - let names = match names_parsed { - Some(names_parsed) => match options.names { + // Class names + let names: Option> = match Self::fetch_names_from_onnx(&engine) { + Some(names_parsed) => match &options.class_names { Some(names) => { if names.len() == names_parsed.len() { - Some(names) + // prioritize user-defined + Some(names.clone()) } else { + // Fail to override anyhow::bail!( "The lengths of parsed class names: {} and user-defined class names: {} do not match.", names_parsed.len(), names.len(), - ); + ) } } None => Some(names_parsed), }, - None => options.names, + None => options.class_names.clone(), }; - // nc: names.len().or(options.nc) - let nc = match &names { - Some(names) => names.len(), - None => match options.nc { - Some(nc) => nc, - None => anyhow::bail!( - "Unable to obtain the number of classes. Please specify them explicitly using `options.with_nc(usize)` or `options.with_names()`." 
- ), + // Class names & Number of class + let (nc, names) = match (options.nc(), names) { + (_, Some(names)) => (names.len(), names.to_vec()), + (Some(nc), None) => (nc, Self::n2s(nc)), + (None, None) => { + anyhow::bail!( + "Neither class names nor the number of classes were specified. \ + \nConsider specify them with `Options::default().with_nc()` or `Options::default().with_class_names()`" + ); } }; - // Class names - let names = match names { - None => Self::n2s(nc), - Some(names) => names, - }; - - // Keypoint names & nk - let (nk, names_kpt) = match Self::fetch_kpts(&engine) { - None => (0, vec![]), - Some(nk) => match options.names2 { - Some(names) => { - if names.len() == nk { - (nk, names) - } else { + // Keypoint names & Number of keypoints + let (nk, names_kpt) = if let Task::KeypointsDetection = task { + let nk = Self::fetch_nk_from_onnx(&engine).or(options.nk()); + match (&options.keypoint_names, nk) { + (Some(names), Some(nk)) => { + if names.len() != nk { anyhow::bail!( - "The lengths of user-defined keypoint names: {} and nk: {} do not match.", + "The lengths of user-defined keypoint names: {} and nk parsed: {} do not match.", names.len(), nk, ); } + (nk, names.clone()) } - None => (nk, Self::n2s(nk)), - }, + (Some(names), None) => (names.len(), names.clone()), + (None, Some(nk)) => (nk, Self::n2s(nk)), + (None, None) => anyhow::bail!( + "Neither keypoint names nor the number of keypoints were specified when doing `KeypointsDetection` task. 
\ + \nConsider specify them with `Options::default().with_nk()` or `Options::default().with_keypoint_names()`" + ), + } + } else { + (0, vec![]) }; - // Confs & Iou - let confs = DynConf::new(&options.confs, nc); - let kconfs = DynConf::new(&options.kconfs, nk); - let iou = options.iou.unwrap_or(0.45); - - // Classes excluded and retained - let classes_excluded = options.classes_excluded; - let classes_retained = options.classes_retained; - - // Summary - tracing::info!("YOLO Task: {:?}, Version: {:?}", task, version); - - // dry run - engine.dry_run()?; + // Attributes + let confs = DynConf::new(options.class_confs(), nc); + let kconfs = DynConf::new(options.keypoint_confs(), nk); + let iou = options.iou().unwrap_or(0.45); + let classes_excluded = options.classes_excluded().to_vec(); + let classes_retained = options.classes_retained().to_vec(); + let find_contours = options.find_contours(); + let mut info = format!( + "YOLO Version: {}, Task: {:?}, Category Count: {}, Keypoint Count: {}", + version.map_or("Unknown".into(), |x| x.to_string()), + task, + nc, + nk, + ); + if !classes_excluded.is_empty() { + info = format!("{}, classes_excluded: {:?}", info, classes_excluded); + } + if !classes_retained.is_empty() { + info = format!("{}, classes_retained: {:?}", info, classes_retained); + } + info!("{}", info); Ok(Self { engine, - confs, - kconfs, - iou, - nc, - nk, height, width, batch, task, + version, + spec, + layout, names, names_kpt, - layout, - version, - find_contours: options.find_contours, + confs, + kconfs, + iou, + nc, + nk, + find_contours, classes_excluded, classes_retained, + processor, + ts, }) } - fn preprocess(&self, xs: &[Self::Input]) -> Result { - let xs_ = match self.task { - YOLOTask::Classify => { - X::resize(xs, self.height() as u32, self.width() as u32, "Bilinear")? - .normalize(0., 255.)? - .nhwc2nchw()? 
- } - _ => X::apply(&[ - Ops::Letterbox( - xs, - self.height() as u32, - self.width() as u32, - "CatmullRom", - 114, - "auto", - false, - ), - Ops::Normalize(0., 255.), - Ops::Nhwc2nchw, - ])?, - }; - Ok(Xs::from(xs_)) + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + let x = self.processor.process_images(xs)?; + + Ok(x.into()) } fn inference(&mut self, xs: Xs) -> Result { self.engine.run(xs) } - fn postprocess(&self, xs: Xs, xs0: &[Self::Input]) -> Result> { + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? }); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + fn postprocess(&self, xs: Xs) -> Result { let protos = if xs.len() == 2 { Some(&xs[1]) } else { None }; let ys: Vec = xs[0] .axis_iter(Axis(0)) @@ -227,7 +299,7 @@ impl Vision for YOLO { .filter_map(|(idx, preds)| { let mut y = Y::default(); - // parse preditions + // Parse predictions let ( slice_bboxes, slice_id, @@ -238,8 +310,8 @@ impl Vision for YOLO { slice_radians, ) = self.layout.parse_preds(preds, self.nc); - // Classifcation - if let YOLOTask::Classify = self.task { + // ImageClassifcation + if let Task::ImageClassification = self.task { let x = if self.layout.apply_softmax { let exps = slice_clss.mapv(|x| x.exp()); let stds = exps.sum_axis(Axis(0)); @@ -247,17 +319,16 @@ impl Vision for YOLO { } else { slice_clss.into_owned() }; - let mut probs = Prob::default().with_probs(&x.into_raw_vec_and_offset().0); - probs = probs + let probs = Prob::default() + .with_probs(&x.into_raw_vec_and_offset().0) .with_names(&self.names.iter().map(|x| x.as_str()).collect::>()); return Some(y.with_probs(probs)); } - let image_width = xs0[idx].width() as f32; - let image_height = xs0[idx].height() as f32; - let ratio = - (self.width() as f32 / 
image_width).min(self.height() as f32 / image_height); + // Original image size + let (image_height, image_width) = self.processor.image0s_size[idx]; + let ratio = self.processor.scale_factors_hw[idx][0]; // Other tasks let (y_bboxes, y_mbrs) = slice_bboxes? @@ -284,19 +355,21 @@ impl Vision for YOLO { } }; - // filtering by class id + // filter out class id if !self.classes_excluded.is_empty() - && self.classes_excluded.contains(&(class_id as isize)) + && self.classes_excluded.contains(&class_id) { return None; } + + // filter by class id if !self.classes_retained.is_empty() - && !self.classes_retained.contains(&(class_id as isize)) + && !self.classes_retained.contains(&class_id) { return None; } - // filtering by conf + // filter by conf if confidence < self.confs[class_id] { return None; } @@ -354,8 +427,7 @@ impl Vision for YOLO { (h, w, radians + std::f32::consts::PI / 2.) }; let radians = radians % std::f32::consts::PI; - - let mut mbr = Mbr::from_cxcywhr( + let mbr = Mbr::from_cxcywhr( cx as f64, cy as f64, w as f64, @@ -363,18 +435,18 @@ impl Vision for YOLO { radians as f64, ) .with_confidence(confidence) - .with_id(class_id as isize); - mbr = mbr.with_name(&self.names[class_id]); + .with_id(class_id as isize) + .with_name(&self.names[class_id]); (None, Some(mbr)) } None => { - let mut bbox = Bbox::default() + let bbox = Bbox::default() .with_xywh(x, y, w, h) .with_confidence(confidence) .with_id(class_id as isize) - .with_id_born(i as isize); - bbox = bbox.with_name(&self.names[class_id]); + .with_id_born(i as isize) + .with_name(&self.names[class_id]); (Some(bbox), None) } @@ -404,7 +476,7 @@ impl Vision for YOLO { } } - // Pose + // KeypointsDetection if let Some(pred_kpts) = slice_kpts { let kpt_step = self.layout.kpt_step().unwrap_or(3); if let Some(bboxes) = y.bboxes() { @@ -421,16 +493,14 @@ impl Vision for YOLO { if kconf < self.kconfs[i] { Keypoint::default() } else { - let mut kpt = Keypoint::default() + Keypoint::default() .with_id(i as isize) 
.with_confidence(kconf) .with_xy( - kx.max(0.0f32).min(image_width), - ky.max(0.0f32).min(image_height), - ); - - kpt = kpt.with_name(&self.names_kpt[i]); - kpt + kx.max(0.0f32).min(image_width as f32), + ky.max(0.0f32).min(image_height as f32), + ) + .with_name(&self.names_kpt[i]) } }) .collect::>(); @@ -441,7 +511,7 @@ impl Vision for YOLO { } } - // Segment + // InstanceSegmentation if let Some(coefs) = slice_coefs { if let Some(bboxes) = y.bboxes() { let (y_polygons, y_masks) = bboxes @@ -533,54 +603,26 @@ impl Vision for YOLO { }) .collect(); - Ok(ys) - } -} - -impl YOLO { - pub fn batch(&self) -> usize { - self.batch.opt() - } - - pub fn width(&self) -> usize { - self.width.opt() - } - - pub fn height(&self) -> usize { - self.height.opt() + Ok(ys.into()) } - pub fn version(&self) -> Option<&YOLOVersion> { - self.version.as_ref() - } - - pub fn task(&self) -> &YOLOTask { - &self.task - } - - pub fn layout(&self) -> &YOLOPreds { - &self.layout - } - - fn fetch_names(engine: &OrtEngine) -> Option> { + fn fetch_names_from_onnx(engine: &Engine) -> Option> { // fetch class names from onnx metadata // String format: `{0: 'person', 1: 'bicycle', 2: 'sports ball', ..., 27: "yellow_lady's_slipper"}` - engine.try_fetch("names").map(|names| { - let re = Regex::new(r#"(['"])([-()\w '"]+)(['"])"#).unwrap(); - let mut names_ = vec![]; - for (_, [_, name, _]) in re.captures_iter(&names).map(|x| x.extract()) { - names_.push(name.to_string()); - } - names_ - }) + Regex::new(r#"(['"])([-()\w '"]+)(['"])"#) + .ok()? + .captures_iter(&engine.try_fetch("names")?) + .filter_map(|caps| caps.get(2).map(|m| m.as_str().to_string())) + .collect::>() + .into() } - fn fetch_kpts(engine: &OrtEngine) -> Option { - engine.try_fetch("kpt_shape").map(|s| { - let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap(); - let caps = re.captures(&s).unwrap(); - caps.get(1).unwrap().as_str().parse::().unwrap() - }) + fn fetch_nk_from_onnx(engine: &Engine) -> Option { + Regex::new(r"(\d+), \d+") + .ok()? 
+ .captures(&engine.try_fetch("kpt_shape")?) + .and_then(|caps| caps.get(1)) + .and_then(|m| m.as_str().parse::().ok()) } fn n2s(n: usize) -> Vec { diff --git a/src/models/yolo/mod.rs b/src/models/yolo/mod.rs new file mode 100644 index 0000000..a7ca057 --- /dev/null +++ b/src/models/yolo/mod.rs @@ -0,0 +1,6 @@ +mod config; +mod r#impl; +mod preds; + +pub use preds::*; +pub use r#impl::*; diff --git a/src/models/yolo_.rs b/src/models/yolo/preds.rs similarity index 73% rename from src/models/yolo_.rs rename to src/models/yolo/preds.rs index 994dd3a..e85e762 100644 --- a/src/models/yolo_.rs +++ b/src/models/yolo/preds.rs @@ -1,107 +1,13 @@ use ndarray::{ArrayBase, ArrayView, Axis, Dim, IxDyn, IxDynImpl, ViewRepr}; -#[derive(Debug, Clone, clap::ValueEnum)] -pub enum YOLOTask { - Classify, - Detect, - Pose, - Segment, - Obb, -} - -impl YOLOTask { - pub fn name(&self) -> String { - match self { - Self::Classify => "cls".to_string(), - Self::Detect => "det".to_string(), - Self::Pose => "pose".to_string(), - Self::Segment => "seg".to_string(), - Self::Obb => "obb".to_string(), - } - } - - pub fn name_detailed(&self) -> String { - match self { - Self::Classify => "image-classification".to_string(), - Self::Detect => "object-detection".to_string(), - Self::Pose => "pose-estimation".to_string(), - Self::Segment => "instance-segment".to_string(), - Self::Obb => "oriented-object-detection".to_string(), - } - } -} - -#[derive(Debug, Copy, Clone, clap::ValueEnum)] -pub enum YOLOVersion { - V5, - V6, - V7, - V8, - V9, - V10, - V11, - RTDETR, -} - -impl YOLOVersion { - pub fn name(&self) -> String { - match self { - Self::V5 => "v5".to_string(), - Self::V6 => "v6".to_string(), - Self::V7 => "v7".to_string(), - Self::V8 => "v8".to_string(), - Self::V9 => "v9".to_string(), - Self::V10 => "v10".to_string(), - Self::V11 => "v11".to_string(), - Self::RTDETR => "rtdetr".to_string(), - } - } -} - -#[derive(Debug, Copy, Clone, clap::ValueEnum)] -pub enum YOLOScale { - N, - T, - B, - S, - 
M, - L, - C, - E, - X, -} - -impl YOLOScale { - pub fn name(&self) -> String { - match self { - Self::N => "n".to_string(), - Self::T => "t".to_string(), - Self::S => "s".to_string(), - Self::B => "b".to_string(), - Self::M => "m".to_string(), - Self::L => "l".to_string(), - Self::C => "c".to_string(), - Self::E => "e".to_string(), - Self::X => "x".to_string(), - } - } -} +use crate::Task; #[derive(Debug, Clone, PartialEq)] pub enum BoxType { - /// 1 Cxcywh, - - /// 2 Cxcybr Cxcyxy, - - /// 3 Tlbr Xyxy, - - /// 4 Tlwh Xywh, - - /// 5 Tlcxcy XyCxcy, } @@ -127,7 +33,7 @@ pub enum AnchorsPosition { } #[derive(Debug, Clone, PartialEq)] -pub struct YOLOPreds { +pub struct YOLOPredsFormat { pub clss: ClssType, pub bbox: Option, pub kpts: Option, @@ -137,9 +43,11 @@ pub struct YOLOPreds { pub is_bbox_normalized: bool, pub apply_nms: bool, pub apply_softmax: bool, + // ------------------------------------------------ + // pub is_concatenated: bool, // TODO: how to tell which parts? } -impl Default for YOLOPreds { +impl Default for YOLOPredsFormat { fn default() -> Self { Self { clss: ClssType::Clss, @@ -151,11 +59,12 @@ impl Default for YOLOPreds { is_bbox_normalized: false, apply_nms: true, apply_softmax: false, + // is_concatenated: true, } } } -impl YOLOPreds { +impl YOLOPredsFormat { pub fn apply_nms(mut self, x: bool) -> Self { self.apply_nms = x; self @@ -259,16 +168,16 @@ impl YOLOPreds { } } - pub fn task(&self) -> YOLOTask { + pub fn task(&self) -> Task { match self.obb { - Some(_) => YOLOTask::Obb, + Some(_) => Task::OrientedObjectDetection, None => match self.coefs { - Some(_) => YOLOTask::Segment, + Some(_) => Task::InstanceSegmentation, None => match self.kpts { - Some(_) => YOLOTask::Pose, + Some(_) => Task::KeypointsDetection, None => match self.bbox { - Some(_) => YOLOTask::Detect, - None => YOLOTask::Classify, + Some(_) => Task::ObjectDetection, + None => Task::ImageClassification, }, }, }, @@ -327,7 +236,7 @@ impl YOLOPreds { Option>, ) { match 
self.task() { - YOLOTask::Classify => (None, None, x, None, None, None, None), + Task::ImageClassification => (None, None, x, None, None, None, None), _ => { let x = if self.is_anchors_first() { x @@ -335,7 +244,7 @@ impl YOLOPreds { x.reversed_axes() }; - // get each tasks slices + // each tasks slices let (slice_bboxes, xs) = x.split_at(Axis(1), 4); let (slice_id, slice_clss, slice_confs, xs) = match self.clss { ClssType::ConfClss => { @@ -364,9 +273,9 @@ impl YOLOPreds { } }; let (slice_kpts, slice_coefs, slice_radians) = match self.task() { - YOLOTask::Pose => (Some(xs), None, None), - YOLOTask::Segment => (None, Some(xs), None), - YOLOTask::Obb => (None, None, Some(xs)), + Task::Pose | Task::KeypointsDetection => (Some(xs), None, None), + Task::InstanceSegmentation => (None, Some(xs), None), + Task::Obb | Task::OrientedObjectDetection => (None, None, Some(xs)), _ => (None, None, None), }; diff --git a/src/models/yolop/config.rs b/src/models/yolop/config.rs new file mode 100644 index 0000000..6e1564e --- /dev/null +++ b/src/models/yolop/config.rs @@ -0,0 +1,22 @@ +/// Model configuration for `YOLOP` +impl crate::Options { + pub fn yolop() -> Self { + Self::default() + .with_model_name("yolop") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 2, 640.into()) + .with_model_ixx(0, 3, 640.into()) + .with_resize_mode(crate::ResizeMode::FitAdaptive) + .with_resize_filter("Bilinear") + .with_normalize(true) + .with_class_confs(&[0.3]) + } + + pub fn yolop_v2_480x800() -> Self { + Self::yolop().with_model_file("v2-480x800.onnx") + } + + pub fn yolop_v2_736x1280() -> Self { + Self::yolop().with_model_file("v2-736x1280.onnx") + } +} diff --git a/src/models/yolop.rs b/src/models/yolop/impl.rs similarity index 69% rename from src/models/yolop.rs rename to src/models/yolop/impl.rs index 05adbba..0e734ae 100644 --- a/src/models/yolop.rs +++ b/src/models/yolop/impl.rs @@ -1,61 +1,75 @@ +use aksr::Builder; use anyhow::Result; use image::DynamicImage; use ndarray::{s, 
Array, Axis, IxDyn}; -use crate::{Bbox, DynConf, MinOptMax, Ops, Options, OrtEngine, Polygon, Xs, X, Y}; +use crate::{elapsed, Bbox, DynConf, Engine, Ops, Options, Polygon, Processor, Ts, Xs, Ys, Y}; -#[derive(Debug)] +#[derive(Builder, Debug)] pub struct YOLOPv2 { - engine: OrtEngine, - height: MinOptMax, - width: MinOptMax, - batch: MinOptMax, + engine: Engine, + height: usize, + width: usize, + batch: usize, + ts: Ts, + spec: String, + processor: Processor, confs: DynConf, iou: f32, } impl YOLOPv2 { pub fn new(options: Options) -> Result { - let mut engine = OrtEngine::new(&options)?; - let (batch, height, width) = ( - engine.batch().to_owned(), - engine.height().to_owned(), - engine.width().to_owned(), + let engine = options.to_engine()?; + let spec = engine.spec().to_string(); + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&512.into()).opt(), + engine.try_width().unwrap_or(&512.into()).opt(), + engine.ts().clone(), ); - let confs = DynConf::new(&options.kconfs, 80); + let processor = options + .to_processor()? 
+ .with_image_width(width as _) + .with_image_height(height as _); + + let confs = DynConf::new(options.class_confs(), 80); let iou = options.iou.unwrap_or(0.45f32); - engine.dry_run()?; Ok(Self { engine, - confs, height, width, batch, + confs, iou, + ts, + processor, + spec, }) } - pub fn run(&mut self, xs: &[DynamicImage]) -> Result> { - let xs_ = X::apply(&[ - Ops::Letterbox( - xs, - self.height() as u32, - self.width() as u32, - "Bilinear", - 114, - "auto", - false, - ), - Ops::Normalize(0., 255.), - Ops::Nhwc2nchw, - ])?; - let ys = self.engine.run(Xs::from(xs_))?; - self.postprocess(ys, xs) + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + Ok(self.processor.process_images(xs)?.into()) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) } - pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result> { - // pub fn postprocess(&self, xs: Vec, xs0: &[DynamicImage]) -> Result> { + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? 
}); + + Ok(ys) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } + + fn postprocess(&mut self, xs: Xs) -> Result { let mut ys: Vec = Vec::new(); let (xs_da, xs_ll, xs_det) = (&xs[0], &xs[1], &xs[2]); for (idx, ((x_det, x_ll), x_da)) in xs_det @@ -64,14 +78,8 @@ impl YOLOPv2 { .zip(xs_da.axis_iter(Axis(0))) .enumerate() { - let image_width = xs0[idx].width() as f32; - let image_height = xs0[idx].height() as f32; - let (ratio, _, _) = Ops::scale_wh( - image_width, - image_height, - self.width() as f32, - self.height() as f32, - ); + let (image_height, image_width) = self.processor.image0s_size[idx]; + let ratio = self.processor.scale_factors_hw[idx][0]; // Vehicle let mut y_bboxes = Vec::new(); @@ -94,8 +102,8 @@ impl YOLOPv2 { let h = bbox[3] / ratio; let x = cx - w / 2.; let y = cy - h / 2.; - let x = x.max(0.0).min(image_width); - let y = y.max(0.0).min(image_height); + let x = x.max(0.0).min(image_width as _); + let y = y.max(0.0).min(image_height as _); y_bboxes.push( Bbox::default() .with_xywh(x, y, w, h) @@ -112,10 +120,10 @@ impl YOLOPv2 { let contours = match self.get_contours_from_mask( x_da.into_dyn(), 0.0, - self.width() as _, - self.height() as _, - image_width, - image_height, + self.width as _, + self.height as _, + image_width as _, + image_height as _, ) { Err(_) => continue, Ok(x) => x, @@ -138,10 +146,10 @@ impl YOLOPv2 { let contours = match self.get_contours_from_mask( x_ll.to_owned(), 0.5, - self.width() as _, - self.height() as _, - image_width, - image_height, + self.width as _, + self.height as _, + image_width as _, + image_height as _, ) { Err(_) => continue, Ok(x) => x, @@ -168,19 +176,7 @@ impl YOLOPv2 { .apply_nms(self.iou), ); } - Ok(ys) - } - - pub fn batch(&self) -> isize { - self.batch.opt() as _ - } - - pub fn width(&self) -> isize { - self.width.opt() as _ - } - - pub fn height(&self) -> isize { - self.height.opt() as _ + Ok(ys.into()) } fn get_contours_from_mask( diff --git a/src/models/yolop/mod.rs 
b/src/models/yolop/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/yolop/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/utils/names.rs b/src/utils/names.rs deleted file mode 100644 index ea6b648..0000000 --- a/src/utils/names.rs +++ /dev/null @@ -1,154 +0,0 @@ -//! Some constants releated with COCO dataset: [`COCO_SKELETONS_16`], [`COCO_KEYPOINTS_17`], [`COCO_CLASS_NAMES_80`] - -pub const COCO_SKELETONS_16: [(usize, usize); 16] = [ - (0, 1), - (0, 2), - (1, 3), - (2, 4), - (5, 6), - (5, 11), - (6, 12), - (11, 12), - (5, 7), - (6, 8), - (7, 9), - (8, 10), - (11, 13), - (12, 14), - (13, 15), - (14, 16), -]; - -pub const COCO_KEYPOINTS_17: [&str; 17] = [ - "nose", - "left_eye", - "right_eye", - "left_ear", - "right_ear", - "left_shoulder", - "right_shoulder", - "left_elbow", - "right_elbow", - "left_wrist", - "right_wrist", - "left_hip", - "right_hip", - "left_knee", - "right_knee", - "left_ankle", - "right_ankle", -]; - -pub const COCO_CLASS_NAMES_80: [&str; 80] = [ - "person", - "bicycle", - "car", - "motorcycle", - "airplane", - "bus", - "train", - "truck", - "boat", - "traffic light", - "fire hydrant", - "stop sign", - "parking meter", - "bench", - "bird", - "cat", - "dog", - "horse", - "sheep", - "cow", - "elephant", - "bear", - "zebra", - "giraffe", - "backpack", - "umbrella", - "handbag", - "tie", - "suitcase", - "frisbee", - "skis", - "snowboard", - "sports ball", - "kite", - "baseball bat", - "baseball glove", - "skateboard", - "surfboard", - "tennis racket", - "bottle", - "wine glass", - "cup", - "fork", - "knife", - "spoon", - "bowl", - "banana", - "apple", - "sandwich", - "orange", - "broccoli", - "carrot", - "hot dog", - "pizza", - "donut", - "cake", - "chair", - "couch", - "potted plant", - "bed", - "dining table", - "toilet", - "tv", - "laptop", - "mouse", - "remote", - "keyboard", - "cell phone", - "microwave", - "oven", - "toaster", - "sink", - "refrigerator", - "book", - 
"clock", - "vase", - "scissors", - "teddy bear", - "hair drier", - "toothbrush", -]; - -pub const BODY_PARTS_28: [&str; 28] = [ - "Background", - "Apparel", - "Face Neck", - "Hair", - "Left Foot", - "Left Hand", - "Left Lower Arm", - "Left Lower Leg", - "Left Shoe", - "Left Sock", - "Left Upper Arm", - "Left Upper Leg", - "Lower Clothing", - "Right Foot", - "Right Hand", - "Right Lower Arm", - "Right Lower Leg", - "Right Shoe", - "Right Sock", - "Right Upper Arm", - "Right Upper Leg", - "Torso", - "Upper Clothing", - "Lower Lip", - "Upper Lip", - "Lower Teeth", - "Upper Teeth", - "Tongue", -]; diff --git a/src/ys/bbox.rs b/src/xy/bbox.rs similarity index 100% rename from src/ys/bbox.rs rename to src/xy/bbox.rs index a4294b1..66c30f9 100644 --- a/src/ys/bbox.rs +++ b/src/xy/bbox.rs @@ -13,9 +13,9 @@ pub struct Bbox { w: f32, h: f32, id: isize, + id_born: isize, confidence: f32, name: Option, - id_born: isize, } impl Nms for Bbox { @@ -38,9 +38,9 @@ impl Default for Bbox { w: 0., h: 0., id: -1, + id_born: -1, confidence: 0., name: None, - id_born: -1, } } } diff --git a/src/ys/keypoint.rs b/src/xy/keypoint.rs similarity index 96% rename from src/ys/keypoint.rs rename to src/xy/keypoint.rs index 5d5b5e8..e75d00a 100644 --- a/src/ys/keypoint.rs +++ b/src/xy/keypoint.rs @@ -150,6 +150,18 @@ impl From<[f32; 2]> for Keypoint { } } +impl From<(f32, f32, isize)> for Keypoint { + fn from((x, y, id): (f32, f32, isize)) -> Self { + Self { + x, + y, + id, + confidence: 1., + ..Default::default() + } + } +} + impl From<(f32, f32, isize, f32)> for Keypoint { fn from((x, y, id, confidence): (f32, f32, isize, f32)) -> Self { Self { diff --git a/src/ys/mask.rs b/src/xy/mask.rs similarity index 100% rename from src/ys/mask.rs rename to src/xy/mask.rs diff --git a/src/ys/mbr.rs b/src/xy/mbr.rs similarity index 100% rename from src/ys/mbr.rs rename to src/xy/mbr.rs diff --git a/src/ys/mod.rs b/src/xy/mod.rs similarity index 72% rename from src/ys/mod.rs rename to src/xy/mod.rs index 
f07d38a..626e151 100644 --- a/src/ys/mod.rs +++ b/src/xy/mod.rs @@ -1,20 +1,27 @@ mod bbox; -mod embedding; +// mod embedding; mod keypoint; mod mask; mod mbr; mod polygon; mod prob; +mod text; +mod x; +mod xs; mod y; +mod ys; pub use bbox::Bbox; -pub use embedding::Embedding; pub use keypoint::Keypoint; pub use mask::Mask; pub use mbr::Mbr; pub use polygon::Polygon; pub use prob::Prob; +pub use text::Text; +pub use x::X; +pub use xs::Xs; pub use y::Y; +pub use ys::Ys; pub trait Nms { fn iou(&self, other: &Self) -> f32; diff --git a/src/ys/polygon.rs b/src/xy/polygon.rs similarity index 100% rename from src/ys/polygon.rs rename to src/xy/polygon.rs diff --git a/src/ys/prob.rs b/src/xy/prob.rs similarity index 100% rename from src/ys/prob.rs rename to src/xy/prob.rs diff --git a/src/xy/text.rs b/src/xy/text.rs new file mode 100644 index 0000000..0c67b5f --- /dev/null +++ b/src/xy/text.rs @@ -0,0 +1,17 @@ +/// Wrapper over [`String`] +#[derive(aksr::Builder, Debug, Clone, Default, PartialEq)] +pub struct Text(String); + +impl std::ops::Deref for Text { + type Target = String; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl> From for Text { + fn from(x: T) -> Self { + Self(x.as_ref().to_string()) + } +} diff --git a/src/core/x.rs b/src/xy/x.rs similarity index 54% rename from src/core/x.rs rename to src/xy/x.rs index 3a245e2..6b4b482 100644 --- a/src/core/x.rs +++ b/src/xy/x.rs @@ -2,10 +2,10 @@ use anyhow::Result; use image::DynamicImage; use ndarray::{Array, Dim, IntoDimension, IxDyn, IxDynImpl}; -use crate::Ops; +use crate::{Ops, ResizeMode}; -/// Model input, wrapper over [`Array`] -#[derive(Debug, Clone, Default)] +/// Wrapper over [`Array`] +#[derive(Debug, Clone, Default, PartialEq)] pub struct X(pub Array); impl From> for X { @@ -20,6 +20,42 @@ impl From> for X { } } +impl TryFrom> for X { + type Error = anyhow::Error; + + fn try_from(values: Vec<(u32, u32)>) -> Result { + if values.is_empty() { + Ok(Self::default()) + } else { + let mut 
flattened: Vec = Vec::new(); + for &(a, b) in values.iter() { + flattened.push(a); + flattened.push(b); + } + let shape = (values.len(), 2); + let x = Array::from_shape_vec(shape, flattened)? + .map(|x| *x as f32) + .into_dyn(); + Ok(Self(x)) + } + } +} + +impl TryFrom>> for X { + type Error = anyhow::Error; + + fn try_from(xs: Vec>) -> Result { + if xs.is_empty() { + Ok(Self::default()) + } else { + let shape = (xs.len(), xs[0].len()); + let flattened: Vec = xs.iter().flatten().cloned().collect(); + let x = Array::from_shape_vec(shape, flattened)?.into_dyn(); + Ok(Self(x)) + } + } +} + impl std::ops::Deref for X { type Target = Array; @@ -29,6 +65,8 @@ impl std::ops::Deref for X { } impl X { + // TODO: Add some slice and index method + pub fn zeros(shape: &[usize]) -> Self { Self::from(Array::zeros(Dim(IxDynImpl::from(shape.to_vec())))) } @@ -37,11 +75,27 @@ impl X { Self::from(Array::ones(Dim(IxDynImpl::from(shape.to_vec())))) } + pub fn zeros_like(x: Self) -> Self { + Self::from(Array::zeros(x.raw_dim())) + } + + pub fn ones_like(x: Self) -> Self { + Self::from(Array::ones(x.raw_dim())) + } + + pub fn full(shape: &[usize], x: f32) -> Self { + Self::from(Array::from_elem(shape, x)) + } + + pub fn from_shape_vec(shape: &[usize], xs: Vec) -> Result { + Ok(Self::from(Array::from_shape_vec(shape, xs)?)) + } + pub fn apply(ops: &[Ops]) -> Result { let mut y = Self::default(); for op in ops { y = match op { - Ops::Resize(xs, h, w, filter) => Self::resize(xs, *h, *w, filter)?, + Ops::FitExact(xs, h, w, filter) => Self::resize(xs, *h, *w, filter)?, Ops::Letterbox(xs, h, w, filter, bg, resize_by, center) => { Self::letterbox(xs, *h, *w, filter, *bg, resize_by, *center)? 
} @@ -103,6 +157,12 @@ impl X { Ok(self) } + pub fn concat(xs: &[Self], d: usize) -> Result { + let xs = xs.iter().cloned().map(|x| x.0).collect::>(); + let x = Ops::concat(&xs, d)?; + Ok(x.into()) + } + pub fn dims(&self) -> &[usize] { self.0.shape() } @@ -126,6 +186,11 @@ impl X { Ok(self) } + pub fn unsigned(mut self) -> Self { + self.0 = self.0.mapv(|x| if x < 0.0 { 0.0 } else { x }); + self + } + pub fn resize(xs: &[DynamicImage], height: u32, width: u32, filter: &str) -> Result { Ok(Self::from(Ops::resize(xs, height, width, filter)?)) } @@ -143,4 +208,47 @@ impl X { xs, height, width, filter, bg, resize_by, center, )?)) } + + #[allow(clippy::too_many_arguments)] + pub fn preprocess( + xs: &[image::DynamicImage], + image_width: u32, + image_height: u32, + resize_mode: &ResizeMode, + resizer_filter: &str, + padding_value: u8, + letterbox_center: bool, + normalize: bool, + image_std: &[f32], + image_mean: &[f32], + nchw: bool, + ) -> Result { + let mut x = match resize_mode { + ResizeMode::FitExact => X::resize(xs, image_height, image_width, resizer_filter)?, + ResizeMode::Letterbox => X::letterbox( + xs, + image_height, + image_width, + resizer_filter, + padding_value, + "auto", + letterbox_center, + )?, + _ => unimplemented!(), + }; + + if normalize { + x = x.normalize(0., 255.)?; + } + + if !image_std.is_empty() && !image_mean.is_empty() { + x = x.standardize(image_mean, image_std, 3)?; + } + + if nchw { + x = x.nhwc2nchw()?; + } + + Ok(x) + } } diff --git a/src/core/xs.rs b/src/xy/xs.rs similarity index 88% rename from src/core/xs.rs rename to src/xy/xs.rs index b3cef8f..d54c347 100644 --- a/src/core/xs.rs +++ b/src/xy/xs.rs @@ -1,5 +1,6 @@ use aksr::Builder; use anyhow::Result; +use image::DynamicImage; use std::collections::HashMap; use std::ops::{Deref, Index}; @@ -9,6 +10,10 @@ use crate::{string_random, X}; pub struct Xs { map: HashMap, names: Vec, + + // TODO: move to Processor + pub images: Vec>, + pub texts: Vec>, } impl From for Xs { @@ -36,6 +41,14 
@@ impl Xs { } } + pub fn derive(&self) -> Self { + Self { + map: Default::default(), + names: Default::default(), + ..self.clone() + } + } + pub fn push(&mut self, value: X) { loop { let key = string_random(5); diff --git a/src/ys/y.rs b/src/xy/y.rs similarity index 71% rename from src/ys/y.rs rename to src/xy/y.rs index 140a43a..cd263d1 100644 --- a/src/ys/y.rs +++ b/src/xy/y.rs @@ -1,6 +1,6 @@ use aksr::Builder; -use crate::{Bbox, Embedding, Keypoint, Mask, Mbr, Nms, Polygon, Prob}; +use crate::{Bbox, Keypoint, Mask, Mbr, Nms, Polygon, Prob, Text, X}; /// Container for inference results for each image. /// @@ -10,60 +10,60 @@ use crate::{Bbox, Embedding, Keypoint, Mask, Mbr, Nms, Polygon, Prob}; /// /// # Fields /// +/// * `texts` - Optionally contains a vector of texts. +/// * `embedding` - Optionally contains the embedding representation. /// * `probs` - Optionally contains the probability scores for the detected objects. /// * `bboxes` - Optionally contains a vector of bounding boxes. /// * `keypoints` - Optionally contains a nested vector of keypoints. /// * `mbrs` - Optionally contains a vector of minimum bounding rectangles. /// * `polygons` - Optionally contains a vector of polygons. -/// * `texts` - Optionally contains a vector of text annotations. /// * `masks` - Optionally contains a vector of masks. -/// * `embedding` - Optionally contains the embedding representation. 
#[derive(Builder, Clone, PartialEq, Default)] pub struct Y { + texts: Option>, + embedding: Option, probs: Option, bboxes: Option>, keypoints: Option>>, mbrs: Option>, polygons: Option>, - texts: Option>, masks: Option>, - embedding: Option, } impl std::fmt::Debug for Y { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut f = f.debug_struct("Y"); - if let Some(x) = &self.texts { - if !x.is_empty() { - f.field("Texts", &x); + if let Some(xs) = &self.texts { + if !xs.is_empty() { + f.field("Texts", &xs); } } - if let Some(x) = &self.probs { - f.field("Probabilities", &x); + if let Some(xs) = &self.probs { + f.field("Probs", &xs); } - if let Some(x) = &self.bboxes { - if !x.is_empty() { - f.field("BoundingBoxes", &x); + if let Some(xs) = &self.bboxes { + if !xs.is_empty() { + f.field("BBoxes", &xs); } } - if let Some(x) = &self.mbrs { - if !x.is_empty() { - f.field("MinimumBoundingRectangles", &x); + if let Some(xs) = &self.mbrs { + if !xs.is_empty() { + f.field("OBBs", &xs); } } - if let Some(x) = &self.keypoints { - if !x.is_empty() { - f.field("Keypoints", &x); + if let Some(xs) = &self.keypoints { + if !xs.is_empty() { + f.field("Kpts", &xs); } } - if let Some(x) = &self.polygons { - if !x.is_empty() { - f.field("Polygons", &x); + if let Some(xs) = &self.polygons { + if !xs.is_empty() { + f.field("Polys", &xs); } } - if let Some(x) = &self.masks { - if !x.is_empty() { - f.field("Masks", &x); + if let Some(xs) = &self.masks { + if !xs.is_empty() { + f.field("Masks", &xs); } } if let Some(x) = &self.embedding { @@ -74,6 +74,14 @@ impl std::fmt::Debug for Y { } impl Y { + pub fn hbbs(&self) -> Option<&[Bbox]> { + self.bboxes.as_deref() + } + + pub fn obbs(&self) -> Option<&[Mbr]> { + self.mbrs.as_deref() + } + pub fn apply_nms(mut self, iou_threshold: f32) -> Self { match &mut self.bboxes { None => match &mut self.mbrs { diff --git a/src/xy/ys.rs b/src/xy/ys.rs new file mode 100644 index 0000000..b4303b1 --- /dev/null +++ b/src/xy/ys.rs @@ 
-0,0 +1,19 @@ +use crate::Y; + +/// Wrapper over `Vec` +#[derive(aksr::Builder, Default, Debug)] +pub struct Ys(pub Vec); + +impl From> for Ys { + fn from(xs: Vec) -> Self { + Self(xs) + } +} + +impl std::ops::Deref for Ys { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/src/ys/embedding.rs b/src/ys/embedding.rs deleted file mode 100644 index c6b554b..0000000 --- a/src/ys/embedding.rs +++ /dev/null @@ -1,49 +0,0 @@ -use aksr::Builder; -use anyhow::Result; -use ndarray::{Array, Axis, Ix2, IxDyn}; - -use crate::X; - -/// Embedding for image or text. -#[derive(Builder, Clone, PartialEq, Default)] -pub struct Embedding(#[args(alias = "embedding")] Array); - -impl std::fmt::Debug for Embedding { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("").field("Shape", &self.0.shape()).finish() - } -} - -impl From for Embedding { - fn from(x: X) -> Self { - Self(x.0) - } -} - -impl Embedding { - pub fn new(x: Array) -> Self { - Self(x) - } - - pub fn data(&self) -> &Array { - &self.0 - } - - pub fn norm(mut self) -> Self { - let std_ = self.0.mapv(|x| x * x).sum_axis(Axis(0)).mapv(f32::sqrt); - self.0 = self.0 / std_; - self - } - - pub fn dot2(&self, other: &Embedding) -> Result>> { - // (m, ndim) * (n, ndim).t => (m, n) - let query = self.0.to_owned().into_dimensionality::()?; - let gallery = other.0.to_owned().into_dimensionality::()?; - let matrix = query.dot(&gallery.t()); - let exps = matrix.mapv(|x| x.exp()); - let stds = exps.sum_axis(Axis(1)); - let matrix = exps / stds.insert_axis(Axis(1)); - let matrix: Vec> = matrix.axis_iter(Axis(0)).map(|row| row.to_vec()).collect(); - Ok(matrix) - } -}