diff --git a/api/ffi/src/lib.rs b/api/ffi/src/lib.rs index de0f2b7560..9c57c0ed81 100644 --- a/api/ffi/src/lib.rs +++ b/api/ffi/src/lib.rs @@ -124,6 +124,14 @@ pub unsafe extern "C" fn tract_nnef_enable_tract_core(nnef: *mut TractNnef) -> T }) } +#[no_mangle] +pub unsafe extern "C" fn tract_nnef_enable_tract_extra(nnef: *mut TractNnef) -> TRACT_RESULT { + wrap(|| unsafe { + check_not_null!(nnef); + (*nnef).0.enable_tract_extra() + }) +} + #[no_mangle] pub unsafe extern "C" fn tract_nnef_enable_onnx(nnef: *mut TractNnef) -> TRACT_RESULT { wrap(|| unsafe { diff --git a/api/generate-tract-h.sh b/api/generate-tract-h.sh new file mode 100755 index 0000000000..3ce87536c5 --- /dev/null +++ b/api/generate-tract-h.sh @@ -0,0 +1,7 @@ +#!/bin/sh + +set -ex + +cbindgen ffi > tract,h +cp tract.h c +cp tract.h proxy/sys diff --git a/api/proxy/src/lib.rs b/api/proxy/src/lib.rs index c2da4459be..06e6f3b96e 100644 --- a/api/proxy/src/lib.rs +++ b/api/proxy/src/lib.rs @@ -69,6 +69,10 @@ impl NnefInterface for Nnef { check!(sys::tract_nnef_enable_tract_core(self.0)) } + fn enable_tract_extra(&mut self) -> Result<()> { + check!(sys::tract_nnef_enable_tract_extra(self.0)) + } + fn enable_onnx(&mut self) -> Result<()> { check!(sys::tract_nnef_enable_onnx(self.0)) } diff --git a/api/proxy/sys/tract.h b/api/proxy/sys/tract.h index b3c9254ff6..1471d65425 100644 --- a/api/proxy/sys/tract.h +++ b/api/proxy/sys/tract.h @@ -83,7 +83,7 @@ const char *tract_version(void); void tract_free_cstring(char *ptr); /** - * Creates an instance of an NNEF framework and parser that can be used to load models. + * Creates an instance of an NNEF framework and parser that can be used to load and dump NNEF models. * * The returned object should be destroyed with `tract_nnef_destroy` once the model * has been loaded. @@ -92,6 +92,8 @@ enum TRACT_RESULT tract_nnef_create(struct TractNnef **nnef); enum TRACT_RESULT tract_nnef_enable_tract_core(struct TractNnef *nnef); +enum TRACT_RESULT tract_nnef_enable_tract_extra(struct TractNnef *nnef); + enum TRACT_RESULT tract_nnef_enable_onnx(struct TractNnef *nnef); enum TRACT_RESULT tract_nnef_enable_pulse(struct TractNnef *nnef); diff --git a/api/py/tests/mobilenet_onnx_test.py b/api/py/tests/mobilenet_onnx_test.py index 4bd6d7b7b2..b3d5785bd5 100644 --- a/api/py/tests/mobilenet_onnx_test.py +++ b/api/py/tests/mobilenet_onnx_test.py @@ -48,7 +48,7 @@ def test_state(): assert numpy.argmax(confidences) == 652 def test_nnef_register(): - tract.nnef().with_tract_core().with_onnx().with_pulse() + tract.nnef().with_tract_core().with_onnx().with_pulse().with_tract_extra() def test_nnef(): model = ( diff --git a/api/py/tract/nnef.py b/api/py/tract/nnef.py index 074895c3a7..e9d7634b78 100644 --- a/api/py/tract/nnef.py +++ b/api/py/tract/nnef.py @@ -56,6 +56,14 @@ def with_tract_core(self) -> "Nnef": check(lib.tract_nnef_enable_tract_core(self.ptr)) return self + def with_tract_extra(self) -> "Nnef": + """ + Enable tract-extra extensions to NNEF. + """ + self._valid() + check(lib.tract_nnef_enable_tract_extra(self.ptr)) + return self + def with_onnx(self) -> "Nnef": """ Enable tract-opl extensions to NNEF to covers (more or) ONNX operator set diff --git a/api/rs/Cargo.toml b/api/rs/Cargo.toml index 4cb1dd03d5..105921040b 100644 --- a/api/rs/Cargo.toml +++ b/api/rs/Cargo.toml @@ -22,6 +22,7 @@ tract-api = { path = ".." , version = "=0.20.19-pre" } tract-nnef = { path = "../../nnef/" , version = "=0.20.19-pre" } tract-onnx-opl = { path = "../../onnx-opl/" , version = "=0.20.19-pre" } tract-onnx = { path = "../../onnx/" , version = "=0.20.19-pre" } +tract-extra = { path = "../../extra/" , version = "=0.20.19-pre" } tract-pulse = { path = "../../pulse/" , version = "=0.20.19-pre" } tract-libcli = { path = "../../libcli" , version = "=0.20.19-pre" } serde_json.workspace = true diff --git a/api/rs/src/lib.rs b/api/rs/src/lib.rs index ff68712c9e..87ea00c116 100644 --- a/api/rs/src/lib.rs +++ b/api/rs/src/lib.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use anyhow::{Context, Result}; use ndarray::{Data, Dimension, RawData}; +use tract_extra::WithTractExtra; use tract_libcli::annotations::Annotations; use tract_libcli::profile::BenchLimits; use tract_nnef::internal::parse_tdim; @@ -46,6 +47,11 @@ impl NnefInterface for Nnef { Ok(()) } + fn enable_tract_extra(&mut self) -> Result<()> { + self.0.enable_tract_extra(); + Ok(()) + } + fn enable_onnx(&mut self) -> Result<()> { self.0.enable_onnx(); Ok(()) diff --git a/api/src/lib.rs b/api/src/lib.rs index 5e71771411..55dd5ec56b 100644 --- a/api/src/lib.rs +++ b/api/src/lib.rs @@ -19,6 +19,9 @@ pub trait NnefInterface: Sized { /// Allow the framework to use tract_core extensions instead of a stricter NNEF definition. fn enable_tract_core(&mut self) -> Result<()>; + /// Allow the framework to use tract_extra extensions. + fn enable_tract_extra(&mut self) -> Result<()>; + /// Allow the framework to use tract_onnx extensions to support operators in ONNX that are /// absent from NNEF. fn enable_onnx(&mut self) -> Result<()>; @@ -38,6 +41,12 @@ pub trait NnefInterface: Sized { Ok(self) } + /// Convenience function, similar with enable_tract_core but allowing method chaining. + fn with_tract_extra(mut self) -> Result { + self.enable_tract_extra()?; + Ok(self) + } + /// Convenience function, similar with enable_onnx but allowing method chaining. fn with_onnx(mut self) -> Result { self.enable_onnx()?; diff --git a/api/tract.h b/api/tract.h index b3c9254ff6..1471d65425 100644 --- a/api/tract.h +++ b/api/tract.h @@ -83,7 +83,7 @@ const char *tract_version(void); void tract_free_cstring(char *ptr); /** - * Creates an instance of an NNEF framework and parser that can be used to load models. + * Creates an instance of an NNEF framework and parser that can be used to load and dump NNEF models. * * The returned object should be destroyed with `tract_nnef_destroy` once the model * has been loaded. @@ -92,6 +92,8 @@ enum TRACT_RESULT tract_nnef_create(struct TractNnef **nnef); enum TRACT_RESULT tract_nnef_enable_tract_core(struct TractNnef *nnef); +enum TRACT_RESULT tract_nnef_enable_tract_extra(struct TractNnef *nnef); + enum TRACT_RESULT tract_nnef_enable_onnx(struct TractNnef *nnef); enum TRACT_RESULT tract_nnef_enable_pulse(struct TractNnef *nnef); diff --git a/cli/src/main.rs b/cli/src/main.rs index 5a983055a7..507e57d6d7 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -613,8 +613,8 @@ fn nnef(matches: &clap::ArgMatches) -> tract_nnef::internal::Nnef { if matches.is_present("nnef-tract-extra") { #[cfg(feature = "extra")] { - use tract_extra::WithExtra; - fw = fw.with_extra(); + use tract_extra::WithTractExtra; + fw = fw.with_tract_extra(); } #[cfg(not(feature = "extra"))] { diff --git a/extra/src/exp_unit_norm.rs b/extra/src/exp_unit_norm.rs index 3a0945b8e2..981eb75261 100644 --- a/extra/src/exp_unit_norm.rs +++ b/extra/src/exp_unit_norm.rs @@ -70,7 +70,7 @@ impl EvalOp for ExpUnitNorm { _session: &mut SessionState, _node_id: usize, ) -> TractResult>> { - Ok(Some(Box::new(ExpUnitNormState::default()))) + Ok(Some(Box::::default())) } fn eval(&self, inputs: TVec) -> TractResult> { @@ -93,7 +93,7 @@ impl ExpUnitNormState { *s = x.max(op.epsilon) * (1f32 - op.alpha) + *s * op.alpha; }); } - time_slice.zip_mut_with(&state, |x, s| *x = *x / s.sqrt()); + time_slice.zip_mut_with(&state, |x, s| *x /= s.sqrt()); self.index += 1; } Ok(tvec!(input.into_tvalue())) diff --git a/extra/src/lib.rs b/extra/src/lib.rs index 9cad7ef951..dbe46874b3 100644 --- a/extra/src/lib.rs +++ b/extra/src/lib.rs @@ -2,29 +2,29 @@ use tract_nnef::internal::*; mod exp_unit_norm; -pub trait WithExtra { - fn enable_extra(&mut self); - fn with_extra(self) -> Self; +pub trait WithTractExtra { + fn enable_tract_extra(&mut self); + fn with_tract_extra(self) -> Self; } -impl WithExtra for tract_nnef::framework::Nnef { - fn enable_extra(&mut self) { +impl WithTractExtra for tract_nnef::framework::Nnef { + fn enable_tract_extra(&mut self) { self.enable_tract_core(); - self.registries.push(tract_nnef_registry()); + self.registries.push(tract_extra_registry()); } - fn with_extra(mut self) -> Self { - self.enable_extra(); + fn with_tract_extra(mut self) -> Self { + self.enable_tract_extra(); self } } -pub fn tract_nnef_registry() -> Registry { +pub fn tract_extra_registry() -> Registry { let mut reg = Registry::new("tract_extra"); exp_unit_norm::register(&mut reg); reg } pub fn register_pulsifiers() { - let _ = tract_nnef_registry(); + let _ = tract_extra_registry(); }