Skip to content

Commit

Permalink
vendor the onnxruntime source to enable cargo package verification
Browse files Browse the repository at this point in the history
Signed-off-by: David Justice <[email protected]>
  • Loading branch information
devigned committed Nov 7, 2023
1 parent 8aeac1a commit 1d87dc6
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 38 deletions.
13 changes: 13 additions & 0 deletions rust/justfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@


vendor:
mkdir -p ./onnxruntime-sys/vendor/onnxruntime-src
cp -rf ../onnxruntime ./onnxruntime-sys/vendor/onnxruntime-src
cp -rf ../cmake ./onnxruntime-sys/vendor/onnxruntime-src
rm -rf ./onnxruntime-sys/vendor/onnxruntime-src/cmake/external/onnx
cp -rf ../include ./onnxruntime-sys/vendor/onnxruntime-src
mkdir -p ./onnxruntime-sys/vendor/onnxruntime-src/tools
cp -rf ../tools/ci_build ./onnxruntime-sys/vendor/onnxruntime-src/tools
cp -rf ../samples ./onnxruntime-sys/vendor/onnxruntime-src
cp -f ../requirements.txt.in ./onnxruntime-sys/vendor/onnxruntime-src
cp -f ../VERSION_NUMBER ./onnxruntime-sys/vendor/onnxruntime-src
1 change: 1 addition & 0 deletions rust/onnxruntime-sys/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
vendor
5 changes: 2 additions & 3 deletions rust/onnxruntime-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,24 @@ authors = ["Nicolas Bigaouette <[email protected]>"]
edition = "2018"
name = "onnxruntime-sys"
version = "0.0.14"

links = "onnxruntime"

description = "Unsafe wrapper around Microsoft's ONNX Runtime"
documentation = "https://docs.rs/onnxruntime-sys"
homepage = "https://github.com/microsoft/onnxruntime"
license = "MIT OR Apache-2.0"
readme = "../README.md"
repository = "https://github.com/microsoft/onnxruntime"

categories = ["science"]
keywords = ["neuralnetworks", "onnx", "bindings"]
include = ["src", "example", "vendor", "build.rs"]

[dependencies]
libloading = "0.7"

[build-dependencies]
bindgen = "0.63"
cmake = "0.1"
anyhow = "1.0"

# Used on unix
flate2 = "1.0"
Expand Down
74 changes: 39 additions & 35 deletions rust/onnxruntime-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@ use std::{
str::FromStr,
};

// use cmake::build;

use anyhow::{anyhow, Context, Result};

/// ONNX Runtime version
///
/// WARNING: If version is changed, bindings for all platforms will have to be re-generated.
/// To do so, run this:
/// cargo build --package onnxruntime-sys --features generate-bindings
const ORT_VERSION: &str = include_str!("../../VERSION_NUMBER");
const ORT_VERSION: &str = include_str!("./vendor/onnxruntime-src/VERSION_NUMBER");

/// Base Url from which to download pre-built releases/
const ORT_RELEASE_BASE_URL: &str = "https://github.com/microsoft/onnxruntime/releases/download";
Expand All @@ -34,8 +38,8 @@ const ORT_RUST_ENV_GPU: &str = "ORT_RUST_USE_CUDA";
/// Subdirectory (of the 'target' directory) into which to extract the prebuilt library.
const ORT_PREBUILT_EXTRACT_DIR: &str = "onnxruntime";

fn main() {
let libort_install_dir = prepare_libort_dir();
fn main() -> Result<()> {
let libort_install_dir = prepare_libort_dir().context("preparing libort directory")?;

let include_dir = libort_install_dir.join("include");
let lib_dir = libort_install_dir.join("lib");
Expand All @@ -55,6 +59,7 @@ fn main() {
);

generate_bindings(&include_dir);
Ok(())
}

fn generate_bindings(include_dir: &Path) {
Expand All @@ -70,9 +75,7 @@ fn generate_bindings(include_dir: &Path) {
),
];

let path = include_dir
.join("onnxruntime")
.join("onnxruntime_c_api.h");
let path = include_dir.join("onnxruntime").join("onnxruntime_c_api.h");

// The bindgen::Builder is the main entry point
// to bindgen, and lets you build up options for
Expand Down Expand Up @@ -104,7 +107,7 @@ fn generate_bindings(include_dir: &Path) {

let generated_file = PathBuf::from(env::var("OUT_DIR").unwrap()).join("bindings.rs");
bindings
.write_to_file(&generated_file)
.write_to_file(generated_file)
.expect("Couldn't write bindings!");
}

Expand Down Expand Up @@ -142,15 +145,15 @@ fn extract_archive(filename: &Path, output: &Path) {
}

fn extract_tgz(filename: &Path, output: &Path) {
let file = fs::File::open(&filename).unwrap();
let file = fs::File::open(filename).unwrap();
let buf = io::BufReader::new(file);
let tar = flate2::read::GzDecoder::new(buf);
let mut archive = tar::Archive::new(tar);
archive.unpack(output).unwrap();
}

fn extract_zip(filename: &Path, outpath: &Path) {
let file = fs::File::open(&filename).unwrap();
let file = fs::File::open(filename).unwrap();
let buf = io::BufReader::new(file);
let mut archive = zip::ZipArchive::new(buf).unwrap();
for i in 0..archive.len() {
Expand All @@ -166,7 +169,7 @@ fn extract_zip(filename: &Path, outpath: &Path) {
);
if let Some(p) = outpath.parent() {
if !p.exists() {
fs::create_dir_all(&p).unwrap();
fs::create_dir_all(p).unwrap();
}
}
let mut outfile = fs::File::create(&outpath).unwrap();
Expand All @@ -188,15 +191,15 @@ enum Architecture {
}

impl FromStr for Architecture {
type Err = String;
type Err = anyhow::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
fn from_str(s: &str) -> Result<Self> {
match s.to_lowercase().as_str() {
"x86" => Ok(Architecture::X86),
"x86_64" => Ok(Architecture::X86_64),
"arm" => Ok(Architecture::Arm),
"aarch64" => Ok(Architecture::Arm64),
_ => Err(format!("Unsupported architecture: {}", s)),
_ => Err(anyhow!("Unsupported architecture: {s}")),
}
}
}
Expand Down Expand Up @@ -231,14 +234,14 @@ impl Os {
}

impl FromStr for Os {
type Err = String;
type Err = anyhow::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
fn from_str(s: &str) -> Result<Self> {
match s.to_lowercase().as_str() {
"windows" => Ok(Os::Windows),
"macos" => Ok(Os::MacOs),
"linux" => Ok(Os::Linux),
_ => Err(format!("Unsupported os: {}", s)),
_ => Err(anyhow!("Unsupported os: {s}")),
}
}
}
Expand All @@ -260,9 +263,9 @@ enum Accelerator {
}

impl FromStr for Accelerator {
type Err = String;
type Err = anyhow::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
fn from_str(s: &str) -> Result<Self> {
match s.to_lowercase().as_str() {
"1" | "yes" | "true" | "on" => Ok(Accelerator::Cuda),
_ => Ok(Accelerator::Cpu),
Expand Down Expand Up @@ -391,36 +394,37 @@ fn prepare_libort_dir_prebuilt() -> PathBuf {
extract_dir.join(prebuilt_archive.file_stem().unwrap())
}

fn prepare_libort_dir() -> PathBuf {
fn prepare_libort_dir() -> Result<PathBuf> {
let strategy = env::var(ORT_RUST_ENV_STRATEGY);
println!(
"strategy: {:?}",
strategy.as_ref().map_or_else(|_| "unknown", String::as_str)
);
match strategy.as_ref().map(String::as_str) {
Ok("download") => prepare_libort_dir_prebuilt(),
Ok("system") => PathBuf::from(match env::var(ORT_RUST_ENV_SYSTEM_LIB_LOCATION) {
Ok(p) => p,
Err(e) => {
panic!(
"Could not get value of environment variable {:?}: {:?}",
ORT_RUST_ENV_SYSTEM_LIB_LOCATION, e
);
}
}),
Ok("download") => Ok(prepare_libort_dir_prebuilt()),
Ok("system") => {
let location = env::var(ORT_RUST_ENV_SYSTEM_LIB_LOCATION).context(format!(
"Could not get value of environment variable {:?}",
ORT_RUST_ENV_SYSTEM_LIB_LOCATION
))?;
Ok(PathBuf::from(location))
}
Ok("compile") | Err(_) => prepare_libort_dir_compiled(),
_ => panic!("Unknown value for {:?}", ORT_RUST_ENV_STRATEGY),
_ => Err(anyhow!("Unknown value for {:?}", ORT_RUST_ENV_STRATEGY)),
}
}

fn prepare_libort_dir_compiled() -> PathBuf {
let mut config = cmake::Config::new("../../cmake");
fn prepare_libort_dir_compiled() -> Result<PathBuf> {
let manifest_dir_string = env::var("CARGO_MANIFEST_DIR").unwrap();
let mut config = cmake::Config::new(format!(
"{manifest_dir_string}/vendor/onnxruntime-src/cmake"
));

config.define("onnxruntime_BUILD_SHARED_LIB", "ON");

if env::var(ORT_RUST_ENV_GPU).unwrap_or_default().parse() == Ok(Accelerator::Cuda) {
if let Ok(Accelerator::Cuda) = env::var(ORT_RUST_ENV_GPU).unwrap_or_default().parse() {
config.define("onnxruntime_USE_CUDA", "ON");
}
};

config.build()
Ok(config.build())
}

0 comments on commit 1d87dc6

Please sign in to comment.