Skip to content

Commit

Permalink
fix: Port pykeio#218 (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip authored Jun 25, 2024
1 parent a2d6ae2 commit b6c41c6
Showing 1 changed file with 142 additions and 110 deletions.
252 changes: 142 additions & 110 deletions src/operator/bound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use super::{
kernel::{Kernel, KernelAttributes, KernelContext},
DummyOperator, Operator
};
use crate::error::IntoStatus;
use crate::{error::IntoStatus, extern_system_fn};

#[repr(C)]
#[derive(Clone)]
Expand Down Expand Up @@ -62,115 +62,147 @@ impl<O: Operator> BoundOperator<O> {
&*op.cast()
}

pub(crate) unsafe extern "C" fn CreateKernelV2(
_: *const ort_sys::OrtCustomOp,
_: *const ort_sys::OrtApi,
info: *const ort_sys::OrtKernelInfo,
kernel_ptr: *mut *mut ort_sys::c_void
) -> *mut ort_sys::OrtStatus {
let kernel = match O::create_kernel(&KernelAttributes::new(info)) {
Ok(kernel) => kernel,
e => return e.into_status()
};
*kernel_ptr = (Box::leak(Box::new(kernel)) as *mut O::Kernel).cast();
Ok(()).into_status()
}

pub(crate) unsafe extern "C" fn ComputeKernelV2(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> *mut ort_sys::OrtStatus {
let context = KernelContext::new(context);
O::Kernel::compute(unsafe { &mut *kernel_ptr.cast::<O::Kernel>() }, &context).into_status()
}

pub(crate) unsafe extern "C" fn KernelDestroy(op_kernel: *mut ort_sys::c_void) {
drop(Box::from_raw(op_kernel.cast::<O::Kernel>()));
}

pub(crate) unsafe extern "C" fn GetName(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char {
let safe = Self::safe(op);
safe.name.as_ptr()
}
pub(crate) unsafe extern "C" fn GetExecutionProviderType(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char {
let safe = Self::safe(op);
safe.execution_provider_type.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null)
}

pub(crate) unsafe extern "C" fn GetStartVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::min_version()
}
pub(crate) unsafe extern "C" fn GetEndVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::max_version()
}

pub(crate) unsafe extern "C" fn GetInputMemoryType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtMemType {
O::inputs()[index as usize].memory_type.into()
}
pub(crate) unsafe extern "C" fn GetInputCharacteristic(
_: *const ort_sys::OrtCustomOp,
index: ort_sys::size_t
) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
O::inputs()[index as usize].characteristic.into()
}
pub(crate) unsafe extern "C" fn GetOutputCharacteristic(
_: *const ort_sys::OrtCustomOp,
index: ort_sys::size_t
) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
O::outputs()[index as usize].characteristic.into()
}
pub(crate) unsafe extern "C" fn GetInputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t {
O::inputs().len() as _
}
pub(crate) unsafe extern "C" fn GetOutputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t {
O::outputs().len() as _
}
pub(crate) unsafe extern "C" fn GetInputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType {
O::inputs()[index as usize]
.r#type
.map(|c| c.into())
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
}
pub(crate) unsafe extern "C" fn GetOutputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType {
O::outputs()[index as usize]
.r#type
.map(|c| c.into())
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
}
pub(crate) unsafe extern "C" fn GetVariadicInputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::inputs()
.into_iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_min_arity)
.unwrap_or(1)
.try_into()
.expect("input minimum arity overflows i32")
}
pub(crate) unsafe extern "C" fn GetVariadicInputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::inputs()
.into_iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_homogeneity)
.unwrap_or(false)
.into()
}
pub(crate) unsafe extern "C" fn GetVariadicOutputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::outputs()
.into_iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_min_arity)
.unwrap_or(1)
.try_into()
.expect("output minimum arity overflows i32")
}
pub(crate) unsafe extern "C" fn GetVariadicOutputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::outputs()
.into_iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_homogeneity)
.unwrap_or(false)
.into()
}

pub(crate) unsafe extern "C" fn InferOutputShapeFn(_: *const ort_sys::OrtCustomOp, arg1: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus {
O::get_infer_shape_function().expect("missing infer shape function")(arg1).into_status()
extern_system_fn! {
pub(crate) unsafe fn CreateKernelV2(
_: *const ort_sys::OrtCustomOp,
_: *const ort_sys::OrtApi,
info: *const ort_sys::OrtKernelInfo,
kernel_ptr: *mut *mut ort_sys::c_void
) -> *mut ort_sys::OrtStatus {
let kernel = match O::create_kernel(&KernelAttributes::new(info)) {
Ok(kernel) => kernel,
e => return e.into_status()
};
*kernel_ptr = (Box::leak(Box::new(kernel)) as *mut O::Kernel).cast();
Ok(()).into_status()
}
}

extern_system_fn! {
pub(crate) unsafe fn ComputeKernelV2(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> *mut ort_sys::OrtStatus {
let context = KernelContext::new(context);
O::Kernel::compute(unsafe { &mut *kernel_ptr.cast::<O::Kernel>() }, &context).into_status()
}
}

extern_system_fn! {
pub(crate) unsafe fn KernelDestroy(op_kernel: *mut ort_sys::c_void) {
drop(Box::from_raw(op_kernel.cast::<O::Kernel>()));
}
}

extern_system_fn! {
pub(crate) unsafe fn GetName(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char {
let safe = Self::safe(op);
safe.name.as_ptr()
}
}
extern_system_fn! {
pub(crate) unsafe fn GetExecutionProviderType(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char {
let safe = Self::safe(op);
safe.execution_provider_type.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null)
}
}

extern_system_fn! {
pub(crate) unsafe fn GetStartVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::min_version()
}
}
extern_system_fn! {
pub(crate) unsafe fn GetEndVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::max_version()
}
}

extern_system_fn! {
pub(crate) unsafe fn GetInputMemoryType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtMemType {
O::inputs()[index as usize].memory_type.into()
}
}
extern_system_fn! {
pub(crate) unsafe fn GetInputCharacteristic(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
O::inputs()[index as usize].characteristic.into()
}
}
extern_system_fn! {
pub(crate) unsafe fn GetOutputCharacteristic(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
O::outputs()[index as usize].characteristic.into()
}
}
extern_system_fn! {
pub(crate) unsafe fn GetInputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t {
O::inputs().len() as _
}
}
extern_system_fn! {
pub(crate) unsafe fn GetOutputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t {
O::outputs().len() as _
}
}
extern_system_fn! {
pub(crate) unsafe fn GetInputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType {
O::inputs()[index as usize]
.r#type
.map(|c| c.into())
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
}
}
extern_system_fn! {
pub(crate) unsafe fn GetOutputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType {
O::outputs()[index as usize]
.r#type
.map(|c| c.into())
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
}
}
extern_system_fn! {
pub(crate) unsafe fn GetVariadicInputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::inputs()
.into_iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_min_arity)
.unwrap_or(1)
.try_into()
.expect("input minimum arity overflows i32")
}
}
extern_system_fn! {
pub(crate) unsafe fn GetVariadicInputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::inputs()
.into_iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_homogeneity)
.unwrap_or(false)
.into()
}
}
extern_system_fn! {
pub(crate) unsafe fn GetVariadicOutputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::outputs()
.into_iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_min_arity)
.unwrap_or(1)
.try_into()
.expect("output minimum arity overflows i32")
}
}
extern_system_fn! {
pub(crate) unsafe fn GetVariadicOutputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::outputs()
.into_iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_homogeneity)
.unwrap_or(false)
.into()
}
}

extern_system_fn! {
pub(crate) unsafe fn InferOutputShapeFn(_: *const ort_sys::OrtCustomOp, arg1: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus {
O::get_infer_shape_function().expect("missing infer shape function")(arg1).into_status()
}
}
}

Expand Down

0 comments on commit b6c41c6

Please sign in to comment.