Skip to content

Commit

Permalink
plug extra in
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Sep 5, 2023
1 parent 02eaa35 commit 2265ebe
Show file tree
Hide file tree
Showing 13 changed files with 64 additions and 17 deletions.
8 changes: 8 additions & 0 deletions api/ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ pub unsafe extern "C" fn tract_nnef_enable_tract_core(nnef: *mut TractNnef) -> T
})
}

#[no_mangle]
pub unsafe extern "C" fn tract_nnef_enable_tract_extra(nnef: *mut TractNnef) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(nnef);
(*nnef).0.enable_tract_extra()
})
}

#[no_mangle]
pub unsafe extern "C" fn tract_nnef_enable_onnx(nnef: *mut TractNnef) -> TRACT_RESULT {
wrap(|| unsafe {
Expand Down
7 changes: 7 additions & 0 deletions api/generate-tract-h.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/sh

set -ex

cbindgen ffi > tract,h
cp tract.h c
cp tract.h proxy/sys
4 changes: 4 additions & 0 deletions api/proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ impl NnefInterface for Nnef {
check!(sys::tract_nnef_enable_tract_core(self.0))
}

fn enable_tract_extra(&mut self) -> Result<()> {
check!(sys::tract_nnef_enable_tract_extra(self.0))
}

fn enable_onnx(&mut self) -> Result<()> {
check!(sys::tract_nnef_enable_onnx(self.0))
}
Expand Down
4 changes: 3 additions & 1 deletion api/proxy/sys/tract.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ const char *tract_version(void);
void tract_free_cstring(char *ptr);

/**
* Creates an instance of an NNEF framework and parser that can be used to load models.
* Creates an instance of an NNEF framework and parser that can be used to load and dump NNEF models.
*
* The returned object should be destroyed with `tract_nnef_destroy` once the model
* has been loaded.
Expand All @@ -92,6 +92,8 @@ enum TRACT_RESULT tract_nnef_create(struct TractNnef **nnef);

enum TRACT_RESULT tract_nnef_enable_tract_core(struct TractNnef *nnef);

enum TRACT_RESULT tract_nnef_enable_tract_extra(struct TractNnef *nnef);

enum TRACT_RESULT tract_nnef_enable_onnx(struct TractNnef *nnef);

enum TRACT_RESULT tract_nnef_enable_pulse(struct TractNnef *nnef);
Expand Down
2 changes: 1 addition & 1 deletion api/py/tests/mobilenet_onnx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_state():
assert numpy.argmax(confidences) == 652

def test_nnef_register():
tract.nnef().with_tract_core().with_onnx().with_pulse()
tract.nnef().with_tract_core().with_onnx().with_pulse().with_tract_extra()

def test_nnef():
model = (
Expand Down
8 changes: 8 additions & 0 deletions api/py/tract/nnef.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ def with_tract_core(self) -> "Nnef":
check(lib.tract_nnef_enable_tract_core(self.ptr))
return self

def with_tract_extra(self) -> "Nnef":
"""
Enable tract-extra extensions to NNEF.
"""
self._valid()
check(lib.tract_nnef_enable_tract_extra(self.ptr))
return self

def with_onnx(self) -> "Nnef":
"""
Enable tract-opl extensions to NNEF to covers (more or) ONNX operator set
Expand Down
1 change: 1 addition & 0 deletions api/rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ tract-api = { path = ".." , version = "=0.20.19-pre" }
tract-nnef = { path = "../../nnef/" , version = "=0.20.19-pre" }
tract-onnx-opl = { path = "../../onnx-opl/" , version = "=0.20.19-pre" }
tract-onnx = { path = "../../onnx/" , version = "=0.20.19-pre" }
tract-extra = { path = "../../extra/" , version = "=0.20.19-pre" }
tract-pulse = { path = "../../pulse/" , version = "=0.20.19-pre" }
tract-libcli = { path = "../../libcli" , version = "=0.20.19-pre" }
serde_json.workspace = true
Expand Down
6 changes: 6 additions & 0 deletions api/rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;

use anyhow::{Context, Result};
use ndarray::{Data, Dimension, RawData};
use tract_extra::WithTractExtra;
use tract_libcli::annotations::Annotations;
use tract_libcli::profile::BenchLimits;
use tract_nnef::internal::parse_tdim;
Expand Down Expand Up @@ -46,6 +47,11 @@ impl NnefInterface for Nnef {
Ok(())
}

fn enable_tract_extra(&mut self) -> Result<()> {
self.0.enable_tract_extra();
Ok(())
}

fn enable_onnx(&mut self) -> Result<()> {
self.0.enable_onnx();
Ok(())
Expand Down
9 changes: 9 additions & 0 deletions api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ pub trait NnefInterface: Sized {
/// Allow the framework to use tract_core extensions instead of a stricter NNEF definition.
fn enable_tract_core(&mut self) -> Result<()>;

/// Allow the framework to use tract_extra extensions.
fn enable_tract_extra(&mut self) -> Result<()>;

/// Allow the framework to use tract_onnx extensions to support operators in ONNX that are
/// absent from NNEF.
fn enable_onnx(&mut self) -> Result<()>;
Expand All @@ -38,6 +41,12 @@ pub trait NnefInterface: Sized {
Ok(self)
}

/// Convenience function, similar with enable_tract_core but allowing method chaining.
fn with_tract_extra(mut self) -> Result<Self> {
self.enable_tract_extra()?;
Ok(self)
}

/// Convenience function, similar with enable_onnx but allowing method chaining.
fn with_onnx(mut self) -> Result<Self> {
self.enable_onnx()?;
Expand Down
4 changes: 3 additions & 1 deletion api/tract.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ const char *tract_version(void);
void tract_free_cstring(char *ptr);

/**
* Creates an instance of an NNEF framework and parser that can be used to load models.
* Creates an instance of an NNEF framework and parser that can be used to load and dump NNEF models.
*
* The returned object should be destroyed with `tract_nnef_destroy` once the model
* has been loaded.
Expand All @@ -92,6 +92,8 @@ enum TRACT_RESULT tract_nnef_create(struct TractNnef **nnef);

enum TRACT_RESULT tract_nnef_enable_tract_core(struct TractNnef *nnef);

enum TRACT_RESULT tract_nnef_enable_tract_extra(struct TractNnef *nnef);

enum TRACT_RESULT tract_nnef_enable_onnx(struct TractNnef *nnef);

enum TRACT_RESULT tract_nnef_enable_pulse(struct TractNnef *nnef);
Expand Down
4 changes: 2 additions & 2 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,8 +613,8 @@ fn nnef(matches: &clap::ArgMatches) -> tract_nnef::internal::Nnef {
if matches.is_present("nnef-tract-extra") {
#[cfg(feature = "extra")]
{
use tract_extra::WithExtra;
fw = fw.with_extra();
use tract_extra::WithTractExtra;
fw = fw.with_tract_extra();
}
#[cfg(not(feature = "extra"))]
{
Expand Down
4 changes: 2 additions & 2 deletions extra/src/exp_unit_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl EvalOp for ExpUnitNorm {
_session: &mut SessionState,
_node_id: usize,
) -> TractResult<Option<Box<dyn OpState>>> {
Ok(Some(Box::new(ExpUnitNormState::default())))
Ok(Some(Box::<ExpUnitNormState>::default()))
}

fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
Expand All @@ -93,7 +93,7 @@ impl ExpUnitNormState {
*s = x.max(op.epsilon) * (1f32 - op.alpha) + *s * op.alpha;
});
}
time_slice.zip_mut_with(&state, |x, s| *x = *x / s.sqrt());
time_slice.zip_mut_with(&state, |x, s| *x /= s.sqrt());
self.index += 1;
}
Ok(tvec!(input.into_tvalue()))
Expand Down
20 changes: 10 additions & 10 deletions extra/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,29 @@ use tract_nnef::internal::*;

mod exp_unit_norm;

pub trait WithExtra {
fn enable_extra(&mut self);
fn with_extra(self) -> Self;
pub trait WithTractExtra {
fn enable_tract_extra(&mut self);
fn with_tract_extra(self) -> Self;
}

impl WithExtra for tract_nnef::framework::Nnef {
fn enable_extra(&mut self) {
impl WithTractExtra for tract_nnef::framework::Nnef {
fn enable_tract_extra(&mut self) {
self.enable_tract_core();
self.registries.push(tract_nnef_registry());
self.registries.push(tract_extra_registry());
}

fn with_extra(mut self) -> Self {
self.enable_extra();
fn with_tract_extra(mut self) -> Self {
self.enable_tract_extra();
self
}
}

pub fn tract_nnef_registry() -> Registry {
pub fn tract_extra_registry() -> Registry {
let mut reg = Registry::new("tract_extra");
exp_unit_norm::register(&mut reg);
reg
}

pub fn register_pulsifiers() {
let _ = tract_nnef_registry();
let _ = tract_extra_registry();
}

0 comments on commit 2265ebe

Please sign in to comment.