Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eun #1181

Merged
merged 6 commits into from
Sep 5, 2023
Merged

Eun #1181

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ members = [
"onnx",
"libcli",
"cli",
"extra",

"tflite",

Expand Down
8 changes: 8 additions & 0 deletions api/ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ pub unsafe extern "C" fn tract_nnef_enable_tract_core(nnef: *mut TractNnef) -> T
})
}

#[no_mangle]
pub unsafe extern "C" fn tract_nnef_enable_tract_extra(nnef: *mut TractNnef) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(nnef);
(*nnef).0.enable_tract_extra()
})
}

#[no_mangle]
pub unsafe extern "C" fn tract_nnef_enable_onnx(nnef: *mut TractNnef) -> TRACT_RESULT {
wrap(|| unsafe {
Expand Down
7 changes: 7 additions & 0 deletions api/generate-tract-h.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/sh

set -ex

cbindgen ffi > tract,h
cp tract.h c
cp tract.h proxy/sys
4 changes: 4 additions & 0 deletions api/proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ impl NnefInterface for Nnef {
check!(sys::tract_nnef_enable_tract_core(self.0))
}

fn enable_tract_extra(&mut self) -> Result<()> {
check!(sys::tract_nnef_enable_tract_extra(self.0))
}

fn enable_onnx(&mut self) -> Result<()> {
check!(sys::tract_nnef_enable_onnx(self.0))
}
Expand Down
4 changes: 3 additions & 1 deletion api/proxy/sys/tract.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ const char *tract_version(void);
void tract_free_cstring(char *ptr);

/**
* Creates an instance of an NNEF framework and parser that can be used to load models.
* Creates an instance of an NNEF framework and parser that can be used to load and dump NNEF models.
*
* The returned object should be destroyed with `tract_nnef_destroy` once the model
* has been loaded.
Expand All @@ -92,6 +92,8 @@ enum TRACT_RESULT tract_nnef_create(struct TractNnef **nnef);

enum TRACT_RESULT tract_nnef_enable_tract_core(struct TractNnef *nnef);

enum TRACT_RESULT tract_nnef_enable_tract_extra(struct TractNnef *nnef);

enum TRACT_RESULT tract_nnef_enable_onnx(struct TractNnef *nnef);

enum TRACT_RESULT tract_nnef_enable_pulse(struct TractNnef *nnef);
Expand Down
2 changes: 1 addition & 1 deletion api/py/tests/mobilenet_onnx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_state():
assert numpy.argmax(confidences) == 652

def test_nnef_register():
tract.nnef().with_tract_core().with_onnx().with_pulse()
tract.nnef().with_tract_core().with_onnx().with_pulse().with_tract_extra()

def test_nnef():
model = (
Expand Down
8 changes: 8 additions & 0 deletions api/py/tract/nnef.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ def with_tract_core(self) -> "Nnef":
check(lib.tract_nnef_enable_tract_core(self.ptr))
return self

def with_tract_extra(self) -> "Nnef":
"""
Enable tract-extra extensions to NNEF.
"""
self._valid()
check(lib.tract_nnef_enable_tract_extra(self.ptr))
return self

def with_onnx(self) -> "Nnef":
"""
Enable tract-opl extensions to NNEF to covers (more or) ONNX operator set
Expand Down
1 change: 1 addition & 0 deletions api/rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ tract-api = { path = ".." , version = "=0.20.19-pre" }
tract-nnef = { path = "../../nnef/" , version = "=0.20.19-pre" }
tract-onnx-opl = { path = "../../onnx-opl/" , version = "=0.20.19-pre" }
tract-onnx = { path = "../../onnx/" , version = "=0.20.19-pre" }
tract-extra = { path = "../../extra/" , version = "=0.20.19-pre" }
tract-pulse = { path = "../../pulse/" , version = "=0.20.19-pre" }
tract-libcli = { path = "../../libcli" , version = "=0.20.19-pre" }
serde_json.workspace = true
Expand Down
6 changes: 6 additions & 0 deletions api/rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;

use anyhow::{Context, Result};
use ndarray::{Data, Dimension, RawData};
use tract_extra::WithTractExtra;
use tract_libcli::annotations::Annotations;
use tract_libcli::profile::BenchLimits;
use tract_nnef::internal::parse_tdim;
Expand Down Expand Up @@ -46,6 +47,11 @@ impl NnefInterface for Nnef {
Ok(())
}

fn enable_tract_extra(&mut self) -> Result<()> {
self.0.enable_tract_extra();
Ok(())
}

fn enable_onnx(&mut self) -> Result<()> {
self.0.enable_onnx();
Ok(())
Expand Down
9 changes: 9 additions & 0 deletions api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ pub trait NnefInterface: Sized {
/// Allow the framework to use tract_core extensions instead of a stricter NNEF definition.
fn enable_tract_core(&mut self) -> Result<()>;

/// Allow the framework to use tract_extra extensions.
fn enable_tract_extra(&mut self) -> Result<()>;

/// Allow the framework to use tract_onnx extensions to support operators in ONNX that are
/// absent from NNEF.
fn enable_onnx(&mut self) -> Result<()>;
Expand All @@ -38,6 +41,12 @@ pub trait NnefInterface: Sized {
Ok(self)
}

/// Convenience function, similar with enable_tract_core but allowing method chaining.
fn with_tract_extra(mut self) -> Result<Self> {
self.enable_tract_extra()?;
Ok(self)
}

/// Convenience function, similar with enable_onnx but allowing method chaining.
fn with_onnx(mut self) -> Result<Self> {
self.enable_onnx()?;
Expand Down
4 changes: 3 additions & 1 deletion api/tract.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ const char *tract_version(void);
void tract_free_cstring(char *ptr);

/**
* Creates an instance of an NNEF framework and parser that can be used to load models.
* Creates an instance of an NNEF framework and parser that can be used to load and dump NNEF models.
*
* The returned object should be destroyed with `tract_nnef_destroy` once the model
* has been loaded.
Expand All @@ -92,6 +92,8 @@ enum TRACT_RESULT tract_nnef_create(struct TractNnef **nnef);

enum TRACT_RESULT tract_nnef_enable_tract_core(struct TractNnef *nnef);

enum TRACT_RESULT tract_nnef_enable_tract_extra(struct TractNnef *nnef);

enum TRACT_RESULT tract_nnef_enable_onnx(struct TractNnef *nnef);

enum TRACT_RESULT tract_nnef_enable_pulse(struct TractNnef *nnef);
Expand Down
4 changes: 3 additions & 1 deletion cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,17 @@ tract-core = { version = "=0.20.19-pre", path = "../core" }
tract-hir = { version = "=0.20.19-pre", path = "../hir" }
tract-nnef = { version = "=0.20.19-pre", path = "../nnef" }
tract-libcli = { version = "=0.20.19-pre", path = "../libcli" }
tract-extra = { optional = true, version = "=0.20.19-pre", path = "../extra" }
tract-pulse-opl = { optional = true, version = "=0.20.19-pre", path = "../pulse-opl" }
tract-pulse = { optional = true, version = "=0.20.19-pre", path = "../pulse" }
tract-onnx = { optional = true, version = "=0.20.19-pre", path = "../onnx" }
tract-tensorflow = { optional = true, version = "=0.20.19-pre", path = "../tensorflow" }
tract-tflite = { optional = true, version = "=0.20.19-pre", path = "../tflite" }

[features]
default = ["onnx", "tf", "pulse", "pulse-opl", "tflite"]
default = ["onnx", "tf", "pulse", "pulse-opl", "tflite", "extra"]
onnx = [ "tract-onnx", "tract-libcli/hir", "tract-libcli/onnx" ]
extra = [ "tract-extra" ]
pulse-opl = [ "tract-pulse-opl" ]
pulse = [ "tract-pulse", "tract-pulse-opl" ]
tf = [ "tract-tensorflow", "tract-libcli/hir" ]
Expand Down
12 changes: 12 additions & 0 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ fn main() -> tract_core::anyhow::Result<()> {
.arg(arg!(--"nnef-tract-core" "Allow usage of tract-core extension in NNEF dump and load"))
.arg(arg!(--"nnef-tract-onnx" "Allow usage of tract-onnx extension in NNEF dump and load"))
.arg(arg!(--"nnef-tract-pulse" "Allow usage of tract-pulse extension in NNEF dump and load"))
.arg(arg!(--"nnef-tract-extra" "Allow usage of tract-extra extension in NNEF dump and load"))
.arg(arg!(--"nnef-extended-identifier" "Allow usage of the i\"...\" syntax to escape identifier names"))

.arg(arg!(-O --optimize "Optimize before running"))
Expand Down Expand Up @@ -609,6 +610,17 @@ fn nnef(matches: &clap::ArgMatches) -> tract_nnef::internal::Nnef {
panic!("tract is build without pulse-opl support")
}
}
if matches.is_present("nnef-tract-extra") {
#[cfg(feature = "extra")]
{
use tract_extra::WithTractExtra;
fw = fw.with_tract_extra();
}
#[cfg(not(feature = "extra"))]
{
panic!("tract is build without tract-extra support")
}
}
if matches.is_present("nnef-tract-core") {
fw = fw.with_tract_core();
}
Expand Down
13 changes: 12 additions & 1 deletion core/src/model/fact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ impl ShapeFact {
self.dims.remove(axis);
if let Some(concrete) = &mut self.concrete {
concrete.remove(axis);
}
} else {
self.compute_concrete();
};
Ok(())
}

Expand All @@ -124,6 +126,14 @@ impl ShapeFact {
let void: &[usize] = &[];
Self::from(void)
}

pub fn consistent(&self) -> TractResult<()> {
ensure!(
self.concrete
== self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
);
Ok(())
}
}

impl std::ops::Deref for ShapeFact {
Expand Down Expand Up @@ -241,6 +251,7 @@ impl TypedFact {
}

pub fn consistent(&self) -> TractResult<()> {
self.shape.consistent()?;
if let Some(k) = &self.konst {
if !self.matches(k.as_ref(), None)? {
bail!("fact says {}, constant is {:?}", self.format_dt_shape_nocheck(), k);
Expand Down
26 changes: 26 additions & 0 deletions extra/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
[package]
name = "tract-extra"
version = "0.20.19-pre"
license = "MIT OR Apache-2.0"
authors = ["Mathieu Poumeyrol <[email protected]>"]
description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference"
repository = "https://github.com/snipsco/tract"
keywords = [ "TensorFlow", "NeuralNetworks" ]
categories = [ "science" ]
autobenches = false
edition = "2021"
rust-version = "1.65"

[badges]
maintenance = { status = "actively-developed" }

[dependencies]
tract-nnef = { version = "=0.20.19-pre", path = "../nnef" }
tract-pulse = { version = "=0.20.19-pre", path = "../pulse" }

[dev-dependencies]
criterion.workspace = true
env_logger.workspace = true
lazy_static.workspace = true
proptest.workspace = true
approx.workspace = true
Loading
Loading