Skip to content

Commit

Permalink
fix: OpenCL issues in linux
Browse files Browse the repository at this point in the history
  • Loading branch information
uttarayan21 committed Nov 11, 2024
1 parent db90692 commit e00e314
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 27 deletions.
8 changes: 7 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,14 @@
cmake
llvmPackages.libclang.lib
clang
pkg-config
];
buildInputs = with pkgs;
[]
++ (lib.optionals pkgs.stdenv.isLinux [
ocl-icd
opencl-headers
])
++ (lib.optionals pkgs.stdenv.isDarwin [
darwin.apple_sdk.frameworks.OpenCL
]
Expand Down Expand Up @@ -206,6 +211,7 @@
// {
packages = with pkgs;
[
clang
nushell
git
git-lfs
Expand Down
4 changes: 1 addition & 3 deletions mnn-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dunce = "1.0.4"
fs_extra = "1.3.0"
ignore = "0.4.23"
itertools = "0.13.0"
pkg-config = "0.3.31"
tap = "1.0.1"
walkdir = "2.5.0"

Expand All @@ -32,6 +33,3 @@ arm82 = []
bf16 = []

default = ["mnn-threadpool", "sparse-compute", "opencl"]

[dependencies]
libc = "0.2.155"
16 changes: 10 additions & 6 deletions mnn-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,13 @@ fn main() -> Result<()> {
println!("cargo:rustc-link-lib=framework=OpenCL");
#[cfg(feature = "opengl")]
println!("cargo:rustc-link-lib=framework=OpenGL");
} else {
// #[cfg(feature = "opencl")]
// println!("cargo:rustc-link-lib=static=opencl");
} else if *TARGET_OS == "linux" {
#[cfg(feature = "opencl")]
{
if !pkg_config::probe_library("OpenCL").is_ok() {
println!("cargo:rustc-link-lib=static=OpenCL");
};
}
}
if is_emscripten() {
// println!("cargo:rustc-link-lib=static=stdc++");
Expand Down Expand Up @@ -609,9 +613,9 @@ pub fn build_cpp_build(vendor: impl AsRef<Path>) -> Result<()> {
arm(&mut build, cpu_files_dir.join("arm"))?;
}

if TARGET_FEATURES.contains(&("sse".into())) && is_x86() {
x86_64(&mut build, cpu_files_dir.join("x86_64"))?;
}
// if TARGET_FEATURES.contains(&("sse".into())) && is_x86() {
// x86_64(&mut build, cpu_files_dir.join("x86_64"))?;
// }

build.files(cpu_files);
}
Expand Down
2 changes: 1 addition & 1 deletion src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ impl BackendConfig {
/// # Safety
/// This just binds to the underlying unsafe api and should be used only if you know what you
/// are doing
pub unsafe fn set_shared_context(&mut self, shared_context: *mut libc::c_void) {
pub unsafe fn set_shared_context(&mut self, shared_context: *mut core::ffi::c_void) {
unsafe {
mnn_sys::mnnbc_set_shared_context(self.inner, shared_context);
}
Expand Down
30 changes: 15 additions & 15 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ impl Default for TensorCallback {
}

impl TensorCallback {
pub fn from_ptr(f: *mut libc::c_void) -> Self {
pub fn from_ptr(f: *mut core::ffi::c_void) -> Self {
debug_assert!(!f.is_null());
unsafe {
Self {
Expand All @@ -30,8 +30,8 @@ impl TensorCallback {
}
}

pub fn into_ptr(self) -> *mut libc::c_void {
Arc::into_raw(self.inner) as *mut libc::c_void
pub fn into_ptr(self) -> *mut core::ffi::c_void {
Arc::into_raw(self.inner) as *mut core::ffi::c_void
}

pub fn identity() -> impl Fn(&[RawTensor], OperatorInfo) -> bool {
Expand Down Expand Up @@ -463,7 +463,7 @@ impl Interpreter {
end: impl Fn(&[RawTensor], OperatorInfo) -> bool + 'static,
sync: bool,
) -> Result<()> {
let sync = sync as libc::c_int;
let sync = sync as core::ffi::c_int;
let before = TensorCallback::from(before).into_ptr();
let end = TensorCallback::from(end).into_ptr();
let ret = unsafe {
Expand Down Expand Up @@ -555,7 +555,7 @@ impl Interpreter {
self.inner,
session.inner,
mnn_sys::cpp::MNN_Interpreter_SessionInfoCode_FLOPS as _,
flop_ptr.cast::<libc::c_void>(),
flop_ptr.cast::<core::ffi::c_void>(),
)
};
ensure!(
Expand Down Expand Up @@ -742,15 +742,15 @@ impl<'t, 'tl> Iterator for TensorListIter<'t, 'tl> {

// #[no_mangle]
// extern "C" fn rust_closure_callback_runner(
// f: *mut libc::c_void,
// f: *mut core::ffi::c_void,
// tensors: *const *mut mnn_sys::Tensor,
// tensor_count: usize,
// name: *const libc::c_char,
// ) -> libc::c_int {
// name: *const core::ffi::c_char,
// ) -> core::ffi::c_int {
// let tensors = unsafe { std::slice::from_raw_parts(tensors.cast(), tensor_count) };
// let name = unsafe { std::ffi::CStr::from_ptr(name) };
// let f: TensorCallback = unsafe { Box::from_raw(f.cast::<TensorCallback>()) };
// let ret = f(tensors, name) as libc::c_int;
// let ret = f(tensors, name) as core::ffi::c_int;
// core::mem::forget(f);
// ret
// }
Expand All @@ -766,32 +766,32 @@ impl<'t, 'tl> Iterator for TensorListIter<'t, 'tl> {
// let tensors = [std::ptr::null_mut()];
// let name = std::ffi::CString::new("Test").unwrap();
// let ret = rust_closure_callback_runner(f, tensors.as_ptr(), tensors.len(), name.as_ptr())
// as libc::c_int;
// as core::ffi::c_int;
// assert_eq!(ret, 0);
// }

#[no_mangle]
extern "C" fn rust_closure_callback_runner_op(
f: *mut libc::c_void,
f: *mut core::ffi::c_void,
tensors: *const *mut mnn_sys::Tensor,
tensor_count: usize,
op: *mut libc::c_void,
) -> libc::c_int {
op: *mut core::ffi::c_void,
) -> core::ffi::c_int {
let tensors = unsafe { std::slice::from_raw_parts(tensors.cast(), tensor_count) };
let f: TensorCallback = TensorCallback::from_ptr(f);
let op = OperatorInfo {
inner: op.cast(),
__marker: PhantomData,
};
let ret = f(tensors, op) as libc::c_int;
let ret = f(tensors, op) as core::ffi::c_int;

core::mem::forget(f);
ret
}

#[repr(transparent)]
pub struct OperatorInfo<'op> {
pub(crate) inner: *mut libc::c_void,
pub(crate) inner: *mut core::ffi::c_void,
pub(crate) __marker: PhantomData<&'op ()>,
}

Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,5 @@ pub mod prelude {
pub(crate) use crate::profile::profile;
pub use core::marker::PhantomData;
pub use error_stack::{Report, ResultExt};
pub use libc::*;
pub use core::ffi::*;
}

0 comments on commit e00e314

Please sign in to comment.