Overview
We want to be able to compile and export trained inference models for usage outside of sleap-nn.
Background
The logic for inference is broken down into:
1. Data loading: I/O (VideoReader, LabelsReader)
2. Data preprocessing: moving to GPU, normalization, batching, etc.
3. Model forward pass
4. Postprocessing: peak finding, PAF grouping, etc.
Right now, some of these ops are a bit mixed across Predictor classes and underlying torch.nn.Modules.
In order to best support workflows where we compile/export the final model for inference-only workloads, we need to include steps 2-4 in the inference model itself (as done in core SLEAP).
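To make that concrete, here is a rough sketch of what bundling steps 2-4 into a single exportable module could look like. This is not the actual sleap-nn API; the backbone interface and the trivial argmax "peak finder" are placeholder assumptions.

```python
# Hypothetical sketch: preprocessing, forward pass, and postprocessing bundled
# into one exportable torch.nn.Module. `backbone` and the argmax-based peak
# finder are stand-ins, not the real sleap-nn implementation.
import torch


class InferenceModel(torch.nn.Module):
    def __init__(self, backbone: torch.nn.Module):
        super().__init__()
        self.backbone = backbone

    def forward(self, frames: torch.Tensor) -> torch.Tensor:
        # Step 2: preprocessing on-device (uint8 -> float32, normalization).
        x = frames.to(torch.float32) / 255.0

        # Step 3: model forward pass (e.g., confidence maps of shape B x C x H x W).
        cms = self.backbone(x)

        # Step 4: postprocessing; a global argmax stands in for real peak
        # finding / PAF grouping.
        batch, channels, height, width = cms.shape
        flat = cms.reshape(batch, channels, -1)
        idx = flat.argmax(dim=-1)
        peaks_y = (idx // width).to(torch.float32)
        peaks_x = (idx % width).to(torch.float32)
        return torch.stack([peaks_x, peaks_y], dim=-1)  # (batch, channels, 2)
```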
This is for two reasons:
Performance: vectorized tensor ops like normalization are much faster on the GPU, and we won't incur the overhead of transferring float32 data from the CPU. Additionally, inference engines like torch.compile and TensorRT can yield dramatic performance improvements when the system supports them (see the sketch below).
Portability: being able to run those ops from an exported artifact without having to ship instructions for pre/post-processing, including implementation-dependent details like those we have in sleap-nn. This will be useful for building web demos, realtime inference, and more.
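For instance, once pre/post-processing lives inside the module, the entire pipeline can be handed to an inference engine in one call. A minimal sketch, assuming a trained backbone module and the hypothetical InferenceModel wrapper from above:

```python
# Sketch: compiling the full inference module so preprocessing and peak
# finding are optimized along with the backbone, not just the backbone alone.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = InferenceModel(backbone).eval().to(device)
compiled = torch.compile(model)

with torch.inference_mode():
    frames = torch.zeros(4, 1, 512, 512, dtype=torch.uint8, device=device)
    peaks = compiled(frames)  # (batch, channels, 2) peak coordinates
```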
Ultralytics is a gold-standard example of this: they support a huge number of export formats (ref).
Some of these formats can express more complex ops than others, which affects how well they fit our needs.
Our goal will be to implement support for:
Required: TensorRT, ONNX
Nice to have: torch.compile, CoreML, TF SavedModel/GraphDef/Lite/JS, OpenVINO
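As one concrete example for the required ONNX target, an export call might look roughly like this. Again a sketch, reusing the hypothetical InferenceModel wrapper from above; the file name, shapes, and opset are illustrative assumptions, not sleap-nn's actual interface:

```python
# Sketch: exporting the bundled module to ONNX with a dynamic batch dimension
# but static spatial dimensions (friendlier to runtimes like TensorRT).
import torch

model = InferenceModel(backbone).eval()
dummy = torch.zeros(1, 1, 512, 512, dtype=torch.uint8)

torch.onnx.export(
    model,
    dummy,
    "inference_model.onnx",
    input_names=["frames"],
    output_names=["peaks"],
    dynamic_axes={"frames": {0: "batch"}, "peaks": {0: "batch"}},
    opset_version=17,
)
```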
Likely, we'll need to adapt to the nuances of each inference runtime framework (TensorRT is notoriously picky), which will impose a particular modularization of the inference steps above. Examples of potential pitfalls:
Not supporting variable-length shapes (meaning we need to implement padding logic ourselves; see the sketch after this list)
Not supporting cropping in the middle of the pipeline (e.g., for top-down models)
In cases where the framework does support everything, it may be that we need to do it in a particular way for the conversion to work (e.g., sometimes resizing ops support nearest neighbor but not bilinear interpolation mode).
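For the variable-length shape pitfall, for example, one option is to pad frames to a single static export resolution before they enter the graph. A minimal sketch, with an arbitrary fixed size:

```python
# Sketch: pad-to-fixed-size helper so exported graphs only ever see one static
# spatial shape. Uses bottom/right zero-padding; offsets would have to be
# tracked separately to map peaks back to original coordinates.
import torch
import torch.nn.functional as F


def pad_to_fixed(frames: torch.Tensor, height: int = 1024, width: int = 1024) -> torch.Tensor:
    """Pad a (batch, channels, h, w) tensor to (batch, channels, height, width)."""
    h, w = frames.shape[-2:]
    pad_h, pad_w = height - h, width - w
    if pad_h < 0 or pad_w < 0:
        raise ValueError("Frame is larger than the fixed export size.")
    # F.pad takes (left, right, top, bottom) for the last two dims.
    return F.pad(frames, (0, pad_w, 0, pad_h))
```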
Examples
SavedModel export lets you do something like this to use a trained model without installing any special dependencies:
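Something along these lines, assuming only TensorFlow is installed (the exported directory name, input shape, and signature key below are placeholders rather than the actual exported layout):

```python
# Sketch: running an exported TF SavedModel with only TensorFlow installed
# (no sleap or sleap-nn).
import numpy as np
import tensorflow as tf

model = tf.saved_model.load("exported_model/")
infer = model.signatures["serving_default"]

# A dummy batch of frames: (batch, height, width, channels).
frames = np.zeros((1, 512, 512, 1), dtype=np.uint8)
outputs = infer(tf.constant(frames))

for name, tensor in outputs.items():
    print(name, tensor.shape)
```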
PRs
TODO