Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

discover relative path from exe location #1344

Merged
merged 1 commit into from
Feb 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 70 additions & 65 deletions cli/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use reqwest::Url;
use scan_fmt::scan_fmt;
use std::io::Cursor;
use std::io::Read;
use std::io::Seek;
use std::path::PathBuf;
use std::str::FromStr;
use tract_core::ops::konst::Const;
Expand Down Expand Up @@ -37,21 +36,21 @@ use super::info_usage;
use std::convert::*;

#[derive(Debug)]
enum ModelLocation {
enum Location {
Fs(PathBuf),
Http(Url),
}

impl ModelLocation {
impl Location {
fn path(&self) -> Cow<std::path::Path> {
match self {
ModelLocation::Fs(p) => p.into(),
ModelLocation::Http(u) => std::path::Path::new(u.path()).into(),
Location::Fs(p) => p.into(),
Location::Http(u) => std::path::Path::new(u.path()).into(),
}
}

fn is_dir(&self) -> bool {
if let &ModelLocation::Fs(p) = &self {
if let &Location::Fs(p) = &self {
p.is_dir()
} else {
false
Expand All @@ -60,10 +59,41 @@ impl ModelLocation {

fn read(&self) -> TractResult<Box<dyn Read>> {
match self {
ModelLocation::Fs(p) => Ok(Box::new(std::fs::File::open(p)?)),
ModelLocation::Http(u) => Ok(Box::new(reqwest::blocking::get(u.clone())?)),
Location::Fs(p) => Ok(Box::new(std::fs::File::open(p)?)),
Location::Http(u) => Ok(Box::new(reqwest::blocking::get(u.clone())?)),
}
}

fn bytes(&self) -> TractResult<Vec<u8>> {
let mut vec = vec![];
self.read()?.read_to_end(&mut vec)?;
Ok(vec)
}

fn find(s: impl AsRef<str>) -> TractResult<Self> {
let s = s.as_ref();
let path = std::path::PathBuf::from(s);
if s.starts_with("http://") || s.starts_with("https://") {
return Ok(Location::Http(s.parse()?));
} else if path.exists() {
return Ok(Location::Fs(path));
} else if path.is_relative()
&& cfg!(any(
target_os = "ios",
target_os = "watchos",
target_os = "tvos",
target_os = "android"
))
{
if let Ok(pwd) = std::env::current_exe() {
let absolute = pwd.parent().unwrap().join(&path);
if absolute.exists() {
return Ok(Location::Fs(absolute));
}
}
}
bail!("File not found {}", s)
}
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -128,31 +158,17 @@ type TfExt = tract_tensorflow::model::TfModelExtensions;
type TfExt = ();

impl Parameters {
fn disco_model(matches: &clap::ArgMatches) -> TractResult<(ModelLocation, bool)> {
fn disco_model(matches: &clap::ArgMatches) -> TractResult<(Location, bool)> {
let model = matches.value_of("model").context("Model argument required")?;
let path = std::path::PathBuf::from(model);
let (location, onnx_tc) = if model.starts_with("http://") || model.starts_with("https://") {
(ModelLocation::Http(model.parse()?), false)
} else if !path.exists() {
bail!("model not found: {:?}", path)
} else if std::fs::metadata(&path)?.is_file()
&& path.file_name().unwrap().to_string_lossy() == "graph.nnef"
{
(ModelLocation::Fs(path.parent().unwrap().to_owned()), false)
} else if std::fs::metadata(&path)?.is_dir() && path.join("graph.nnef").exists() {
(ModelLocation::Fs(path), false)
} else if std::fs::metadata(&path)?.is_dir() && path.join("model.onnx").exists() {
(ModelLocation::Fs(path.join("model.onnx")), true)
} else {
(ModelLocation::Fs(path), false)
};
let location = Location::find(model)?;
let onnx_tc = location.is_dir() && location.path().join("model.onnx").exists();
Ok((location, onnx_tc))
}

fn load_model(
matches: &clap::ArgMatches,
probe: Option<&Probe>,
location: &ModelLocation,
location: &Location,
tensors_values: &TensorsValues,
symbol_table: &SymbolTable,
) -> TractResult<(SomeGraphDef, Box<dyn Model>, Option<TfExt>)> {
Expand All @@ -178,7 +194,7 @@ impl Parameters {
"nnef" => {
let nnef = super::nnef(matches);
let mut proto_model = if location.is_dir() {
if let ModelLocation::Fs(dir) = location {
if let Location::Fs(dir) = location {
nnef.proto_model_for_path(dir)?
} else {
unreachable!();
Expand Down Expand Up @@ -446,37 +462,26 @@ impl Parameters {
get_values: bool,
get_facts: bool,
) -> TractResult<Vec<TensorValues>> {
fn do_it(reader: impl Read + Seek) -> TractResult<Vec<(String, usize, Tensor)>> {
let mut npz = ndarray_npy::NpzReader::new(reader)?;
npz.names()?
.iter()
.map(|n| {
if let Ok((turn, name)) =
scan_fmt::scan_fmt!(n, "turn_{d}/{}.npy", usize, String)
{
Ok((name, turn, tensor::for_npz(&mut npz, n)?))
} else {
let name = n.trim_end_matches(".npy").to_string();
Ok((name, 0, tensor::for_npz(&mut npz, n)?))
}
})
.collect()
}
let triples = if input.starts_with("http://") || input.starts_with("https://") {
let mut buf = vec![];
reqwest::blocking::get(input)?.error_for_status()?.read_to_end(&mut buf)?;
do_it(Cursor::new(buf)).with_context(|| format!("reading file from {input:?}"))
} else {
let fd =
std::fs::File::open(input).with_context(|| format!("reading file {input:?}"))?;
do_it(fd)
}?;
let loc = Location::find(input)?;
let mut npz = ndarray_npy::NpzReader::new(Cursor::new(loc.bytes()?))?;
let triples = npz
.names()?
.iter()
.map(|n| {
if let Ok((turn, name)) = scan_fmt::scan_fmt!(n, "turn_{d}/{}.npy", usize, String) {
Ok((name, turn, tensor::for_npz(&mut npz, n)?))
} else {
let name = n.trim_end_matches(".npy").to_string();
Ok((name, 0, tensor::for_npz(&mut npz, n)?))
}
})
.collect::<TractResult<Vec<_>>>()?;
Ok(Self::tensor_values_from_iter(triples.into_iter(), get_values, get_facts))
}

fn parse_tensors(
matches: &clap::ArgMatches,
location: &ModelLocation,
location: &Location,
onnx_tc: bool,
symbol_table: &SymbolTable,
) -> TractResult<TensorsValues> {
Expand Down Expand Up @@ -660,8 +665,8 @@ impl Parameters {
stage!("analyse", inference_model -> inference_model,
|mut m:InferenceModel| -> TractResult<_> {
m.analyse(!matches.is_present("analyse-fail-fast")).map_err(|e|
ModelBuildingError(Box::new(m.clone()), e.into())
)?;
ModelBuildingError(Box::new(m.clone()), e.into())
)?;
if let Some(fail) = m.missing_type_shape()?.first() {
bail!(ModelBuildingError(Box::new(m.clone()), format!("{} has incomplete typing", m.node(fail.node)).into()))
}
Expand Down Expand Up @@ -749,16 +754,16 @@ impl Parameters {
let node = m.node_mut(node);
if let Some(op) = node.op_as_mut::<Const>() {
if op.0.datum_type() == DatumType::TDim { {
// get inner value to Arc<Tensor>
let mut constant = op.0.as_ref().clone();
// Generally a shape or hyperparam
constant
.as_slice_mut::<TDim>()?
.iter_mut()
.for_each(|x| *x = x.eval(&values));

op.0 = constant.into_arc_tensor();
}
// get inner value to Arc<Tensor>
let mut constant = op.0.as_ref().clone();
// Generally a shape or hyperparam
constant
.as_slice_mut::<TDim>()?
.iter_mut()
.for_each(|x| *x = x.eval(&values));

op.0 = constant.into_arc_tensor();
}
}
}
}
Expand Down
Loading