Skip to content

Commit

Permalink
feat: only have the callback_with_info api
Browse files Browse the repository at this point in the history
  • Loading branch information
uttarayan21 committed Sep 21, 2024
1 parent 51d4f70 commit 347b1d4
Show file tree
Hide file tree
Showing 8 changed files with 306 additions and 73 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ thiserror = "1.0"
error-stack = { version = "0.5" }
oneshot = "0.1"
tracing = { version = "0.1.40", optional = true }
dunce = "1.0.5"

[features]
metal = ["mnn-sys/metal"]
Expand Down
4 changes: 2 additions & 2 deletions examples/inspect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ pub fn main() -> anyhow::Result<()> {
});
time!(interpreter.run_session_with_callback(&session, |_, name| {
println!("Before Callback: {:?}", name);
1
true
},|_ , name| {
println!("After Callback: {:?}", name);
1
true
} , true)?;"run session");
let outputs = interpreter.outputs(&session);
outputs.iter().for_each(|x| {
Expand Down
61 changes: 53 additions & 8 deletions mnn-sys/mnn_c/interpreter_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
extern "C" {
int rust_closure_callback_runner(void *closure, Tensor *const *tensors,
size_t tensorCount, const char *opName);
int rust_closure_callback_runner_op(void *closure, Tensor *const *tensors,
size_t tensorCount, const void *op);

void modelPrintIO(const char *model) {
auto net = MNN::Interpreter::createFromFile(model);
Expand Down Expand Up @@ -198,20 +200,58 @@ ErrorCode Interpreter_runSessionWithCallBack(const Interpreter *interpreter,
MNN::TensorCallBack beforeCpp =
[before](const std::vector<MNN::Tensor *> &tensors,
const std::string &opName) {
return rust_closure_callback_runner(
if (before == nullptr) {
return true;
}
return static_cast<bool>(rust_closure_callback_runner(
before, reinterpret_cast<Tensor *const *>(tensors.data()),
tensors.size(), opName.c_str());
tensors.size(), opName.c_str()));
};

MNN::TensorCallBack endCpp = [end](const std::vector<MNN::Tensor *> &tensors,
const std::string &opName) {
return rust_closure_callback_runner(
if (end == nullptr) {
return true;
}
return static_cast<bool>(rust_closure_callback_runner(
end, reinterpret_cast<Tensor *const *>(tensors.data()), tensors.size(),
opName.c_str());
opName.c_str()));
};
auto net = reinterpret_cast<MNN::Interpreter const *>(interpreter);
auto sess = reinterpret_cast<MNN::Session const *>(session);
auto ret = net->runSessionWithCallBack(sess, beforeCpp, endCpp, sync);
auto ret = net->runSessionWithCallBack(sess, beforeCpp, endCpp,
static_cast<bool>(sync));
return static_cast<ErrorCode>(ret);
}

ErrorCode Interpreter_runSessionWithCallBackInfo(const Interpreter *interpreter,
const Session *session,
void *before, void *end,
int sync) {
MNN::TensorCallBackWithInfo beforeCpp =
[before](const std::vector<MNN::Tensor *> &tensors,
const MNN::OperatorInfo *op) {
if (before == nullptr) {
return true;
}
return static_cast<bool>(rust_closure_callback_runner_op(
before, reinterpret_cast<Tensor *const *>(tensors.data()),
tensors.size(), reinterpret_cast<const void *>(op)));
};
MNN::TensorCallBackWithInfo endCpp =
[end](const std::vector<MNN::Tensor *> &tensors,
const MNN::OperatorInfo *op) {
if (end == nullptr) {
return true;
}
return static_cast<bool>(rust_closure_callback_runner_op(
end, reinterpret_cast<Tensor *const *>(tensors.data()),
tensors.size(), reinterpret_cast<const void *>(op)));
};
auto net = reinterpret_cast<MNN::Interpreter const *>(interpreter);
auto sess = reinterpret_cast<MNN::Session const *>(session);
auto ret = net->runSessionWithCallBackInfo(sess, beforeCpp, endCpp,
static_cast<bool>(sync));
return static_cast<ErrorCode>(ret);
}

Expand Down Expand Up @@ -305,8 +345,13 @@ const char *Interpreter_uuid(const Interpreter *interpreter) {
reinterpret_cast<MNN::Interpreter const *>(interpreter);
return mnn_interpreter->uuid();
}
void Session_destroy(Session *session) {
auto mnn_session = reinterpret_cast<MNN::Session *>(session);
// delete mnn_session;
const char *OperatorInfo_name(const void *op) {
return reinterpret_cast<const MNN::OperatorInfo *>(op)->name().c_str();
}
const char *OperatorInfo_type(const void *op) {
return reinterpret_cast<const MNN::OperatorInfo *>(op)->type().c_str();
}
const float OperatorInfo_flops(const void *op) {
return reinterpret_cast<const MNN::OperatorInfo *>(op)->flops();
}
} // extern "C"
27 changes: 10 additions & 17 deletions mnn-sys/mnn_c/interpreter_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "backend_c.h"
#include "error_code_c.h"
#include "schedule_c.h"
#include "session_c.h"
#include "tensor_c.h"
#include "utils.h"
#include <MNN/HalideRuntime.h>
Expand All @@ -11,7 +12,6 @@
extern "C" {
#endif
typedef struct Interpreter Interpreter;
typedef struct Session Session;
typedef struct Backend Backend;

/** acquire runtime status by Runtime::getCurrentStatus with following keys,
Expand Down Expand Up @@ -55,16 +55,6 @@ enum RuntimeStatus {
// MNNBackendConfig *backendConfig;
// } ScheduleConfig;

typedef struct {
const char *name;
const char *type;
float flops;
} OperatorInfo;
typedef int (*TensorCallBack)(const Tensor **tensors, size_t tensorCount,
const char *opName);
typedef int (*TensorCallBackWithInfo)(const Tensor **tensors,
size_t tensorCount,
const OperatorInfo *opInfo);
#if 0
typedef struct {
std::map<MNNForwardType, std::shared_ptr<Runtime>> *runtimeMap;
Expand Down Expand Up @@ -175,11 +165,10 @@ ErrorCode Interpreter_runSession(const Interpreter *interpreter,
ErrorCode Interpreter_runSessionWithCallBack(const Interpreter *interpreter,
const Session *session,
void *before, void *end, int sync);
// ErrorCode Interpreter_runSessionWithCallBackInfo(const Interpreter *interpreter,
// const Session *session,
// TensorCallBackWithInfo before,
// TensorCallBackWithInfo end,
// int sync);
ErrorCode Interpreter_runSessionWithCallBackInfo(const Interpreter *interpreter,
const Session *session,
void *before, void *end,
int sync);
Tensor *Interpreter_getSessionInput(Interpreter *interpreter,
const Session *session, const char *name);
Tensor *Interpreter_getSessionOutput(Interpreter *interpreter,
Expand All @@ -203,7 +192,11 @@ const Backend *Interpreter_getBackend(const Interpreter *interpreter,
const Tensor *tensor);
const char *Interpreter_bizCode(const Interpreter *interpreter);
const char *Interpreter_uuid(const Interpreter *interpreter);
void Session_destroy(Session *session);

const char *OperatorInfo_name(const void *op);
const char *OperatorInfo_type(const void *op);
const float OperatorInfo_flops(const void *op);

#ifdef __cplusplus
}
#endif
Expand Down
19 changes: 19 additions & 0 deletions mnn-sys/mnn_c/session_c.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include "session_c.h"
#include <MNN/Interpreter.hpp>

namespace MNN {
class Session {
public:
bool hasAsyncWork();
};
} // namespace MNN
void Session_destroy(Session *session) {
auto mnn_session = reinterpret_cast<MNN::Session *>(session);
// delete mnn_session;
}

int Session_hasAsyncWork(Session *session) {
auto mnn_session = reinterpret_cast<MNN::Session *>(session);
return mnn_session->hasAsyncWork();
// return true;
}
16 changes: 16 additions & 0 deletions mnn-sys/mnn_c/session_c.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef SESSION_C_H
#define SESSION_C_H

#ifdef __cplusplus
extern "C" {
#endif

typedef struct Session Session;
void Session_destroy(Session *session);
int Session_hasAsyncWork(Session *session);

#ifdef __cplusplus
}
#endif

#endif // SESSION_C_H
Loading

0 comments on commit 347b1d4

Please sign in to comment.