Skip to content

Commit

Permalink
Expose call to change which core a model is running on via JNI (#10)
Browse files Browse the repository at this point in the history
* added default case to core specifier to allow NPU to handle load balancing internally

* added support for all possible core masks

* added explicit branch for auto core mask, changed default case to fail

* added support for changing core mask at runtime

* cpp oop skill issues

* bruh i forgor to accept the user defined core
  • Loading branch information
james20902 authored Feb 2, 2024
1 parent 1dbb427 commit c0836a6
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 34 deletions.
9 changes: 9 additions & 0 deletions src/main/java/org/photonvision/rknn/RknnJNI.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ public boolean equals(Object obj) {
* @return Pointer to the detector in native memory
*/
public static native long create(String modelPath, int numClasses, int modelVer, int coreNum);

/**
* Given an already running detector, change the bitmask controlling which
* of the 3 cores the model is running on
* @param ptr Pointer to detector in native memory
* @param desiredCore Which of the three cores to operate on
* @return return code of rknn_set_core_mask call, indicating success or failure
*/
public static native int setCoreMask(long ptr, int desiredCore);

/**
* Delete all native resources assocated with a detector
Expand Down
9 changes: 9 additions & 0 deletions src/main/native/cpp/rknn_jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ Java_org_photonvision_rknn_RknnJNI_create
return reinterpret_cast<jlong>(ret);
}

JNIEXPORT jint JNICALL Java_org_photonvision_rknn_RknnJNI_setCoreMask(JNIEnv *env,
jclass,
jlong ptr,
jint coreMask)
{
YoloModel *yolo = reinterpret_cast<YoloModel *>(ptr);
return yolo->changeCoreMask(coreMask);
}

/*
* Class: org_photonvision_rknn_RknnJNI
* Method: destroy
Expand Down
6 changes: 6 additions & 0 deletions src/main/native/cpp/rknn_jni.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ JNIEXPORT jlong JNICALL Java_org_photonvision_rknn_RknnJNI_create(JNIEnv *,
jstring,
jint, jint, jint);


JNIEXPORT jint JNICALL Java_org_photonvision_rknn_RknnJNI_setCoreMask(JNIEnv *,
jclass,
jlong,
jint);

/*
* Class: org_photonvision_rknn_RknnJNI
* Method: destroy
Expand Down
67 changes: 36 additions & 31 deletions src/main/native/cpp/yolo_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,37 +82,8 @@ YoloModel::YoloModel(std::string modelPath, int num_classes_, ModelVersion type_
throw std::runtime_error("rknn_init error ret=" + ret);
}

// 设置模型绑定的核心/Set the core of the model that needs to be bound
rknn_core_mask core_mask;
switch (coreNumber)
{
case -1:
core_mask = RKNN_NPU_CORE_AUTO;
break;
case 0:
core_mask = RKNN_NPU_CORE_0;
break;
case 1:
core_mask = RKNN_NPU_CORE_1;
break;
case 2:
core_mask = RKNN_NPU_CORE_2;
break;
case 10:
core_mask = RKNN_NPU_CORE_0_1;
break;
case 210:
core_mask = RKNN_NPU_CORE_0_1_2;
break;
default:
throw std::runtime_error("invalid core selection! core selected: " + coreNumber);
break;
}
ret = rknn_set_core_mask(ctx, core_mask);
if (ret < 0)
{
throw std::runtime_error("rknn_init core error ret=" + ret);
}
// hard coded to let npu decide where the model runs
this->changeCoreMask(coreNumber);

rknn_sdk_version version;
ret = rknn_query(ctx, RKNN_QUERY_SDK_VERSION, &version, sizeof(rknn_sdk_version));
Expand Down Expand Up @@ -204,6 +175,40 @@ YoloModel::~YoloModel() {
free(model_data);
}

int YoloModel::changeCoreMask(int coreNumber) {
// 设置模型绑定的核心/Set the core of the model that needs to be bound
rknn_core_mask core_mask;
switch (coreNumber)
{
case -1:
core_mask = RKNN_NPU_CORE_AUTO;
break;
case 0:
core_mask = RKNN_NPU_CORE_0;
break;
case 1:
core_mask = RKNN_NPU_CORE_1;
break;
case 2:
core_mask = RKNN_NPU_CORE_2;
break;
case 10:
core_mask = RKNN_NPU_CORE_0_1;
break;
case 210:
core_mask = RKNN_NPU_CORE_0_1_2;
break;
default:
throw std::runtime_error("invalid core selection! core selected: " + coreNumber);
break;
}
int ret = rknn_set_core_mask(ctx, core_mask);
if (ret < 0)
{
throw std::runtime_error("rknn_init core error ret=" + ret);
}
return ret;
}

detect_result_group_t YoloModel::forward(cv::Mat &orig_img, DetectionFilterParams params) {
cv::Mat img;
Expand Down
2 changes: 2 additions & 0 deletions src/main/native/include/yolo_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class YoloModel {
public:
YoloModel(std::string modelPath, int num_classes_, ModelVersion type_, int coreNumber);

int changeCoreMask(int coreNumber);

detect_result_group_t forward(cv::Mat &orig_img, DetectionFilterParams params);

~YoloModel();
Expand Down
14 changes: 11 additions & 3 deletions src/test/java/org/photonvision/rknn/RknnTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,26 @@ public void testBasicBlobs() {


System.out.println("Loading bus");
Mat img = Imgcodecs.imread("silly_notes2.png");
Mat img = Imgcodecs.imread("silly_notes.png");
Mat img2 = Imgcodecs.imread("silly_notes2.png");

System.out.println("Loading rknn-jni");
System.load("/home/coolpi/rknn_jni/cmake_build/librknn_jni.so");

System.out.println("Creating detector");
long ptr = RknnJNI.create("/home/coolpi/rknn_jni/note-640-640-yolov5s.rknn", 1, ModelVersion.YOLO_V5.ordinal(), 0);
System.out.println("Creating detector on three cores");
long ptr = RknnJNI.create("/home/coolpi/rknn_jni/note-640-640-yolov5s.rknn", 1, ModelVersion.YOLO_V5.ordinal(), 210);

System.out.println("Running detector");
var ret = RknnJNI.detect(ptr, img.getNativeObjAddr(), .45, .25);
System.out.println(Arrays.toString(ret));

System.out.println("Changing detector to run on core 0");
System.out.println("return code: " + RknnJNI.setCoreMask(ptr, 0));

System.out.println("Running detector again");
ret = RknnJNI.detect(ptr, img2.getNativeObjAddr(), .45, .25);
System.out.println(Arrays.toString(ret));

System.out.println("Killing detector");
RknnJNI.destroy(ptr);
}
Expand Down

0 comments on commit c0836a6

Please sign in to comment.