diff --git a/.gitignore b/.gitignore index 8505e00..d1da908 100644 --- a/.gitignore +++ b/.gitignore @@ -1,16 +1,9 @@ *.iml .gradle /local.properties -/.idea/caches -/.idea/caches/build_file_checksums.ser -/.idea/dictionaries -/.idea/libraries -/.idea/assetWizardSettings.xml -/.idea/gradle.xml -/.idea/modules.xml -/.idea/tasks.xml -/.idea/workspace.xml +.idea/ .DS_Store /build /captures .externalNativeBuild +**/.cxx \ No newline at end of file diff --git a/.idea/caches/build_file_checksums.ser b/.idea/caches/build_file_checksums.ser deleted file mode 100644 index ffd528f..0000000 Binary files a/.idea/caches/build_file_checksums.ser and /dev/null differ diff --git a/.idea/caches/gradle_models.ser b/.idea/caches/gradle_models.ser deleted file mode 100644 index 113dcf8..0000000 Binary files a/.idea/caches/gradle_models.ser and /dev/null differ diff --git a/.idea/codeStyles/Project.xml b/.idea/codeStyles/Project.xml deleted file mode 100644 index 30aa626..0000000 --- a/.idea/codeStyles/Project.xml +++ /dev/null @@ -1,29 +0,0 @@ [29 deleted XML lines elided] \ No newline at end of file diff --git a/.idea/compiler.xml b/.idea/compiler.xml deleted file mode 100644 index 40ed937..0000000 --- a/.idea/compiler.xml +++ /dev/null @@ -1,15 +0,0 @@ [15 deleted XML lines elided] \ No newline at end of file diff --git a/.idea/encodings.xml b/.idea/encodings.xml deleted file mode 100644 index 97626ba..0000000 --- a/.idea/encodings.xml +++ /dev/null @@ -1,6 +0,0 @@ [6 deleted XML lines elided] \ No newline at end of file diff --git a/.idea/gradle.xml b/.idea/gradle.xml deleted file mode 100644 index 2996d53..0000000 --- a/.idea/gradle.xml +++ /dev/null @@ -1,15 +0,0 @@ [15 deleted XML lines elided] \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml deleted file mode 100644 index 703e5d4..0000000 --- a/.idea/misc.xml +++ /dev/null @@ -1,14 +0,0 @@ [14 deleted XML lines elided] \ No newline at end of file diff --git a/.idea/runConfigurations.xml b/.idea/runConfigurations.xml deleted file mode 100644 index 7f68460..0000000 --- a/.idea/runConfigurations.xml +++ /dev/null @@ -1,12 +0,0 @@ [12 deleted XML lines elided] \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index 35eb1dd..0000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ [6 deleted XML lines elided] \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index b0e2f21..7733989 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,8 +7,8 @@ android: - tools - platform-tools - tools - - build-tools-28.0.3 - - android-28 + - build-tools-29.0.2 + - android-29 - extra-android-m2repository - extra-google-m2repository install: diff --git a/README.md b/README.md index 8860a34..973407c 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ This sample demonstrates realtime face recognition on Android. 
The project is ba ## Inspiration The project is heavily inspired by * [FaceNet](https://github.com/davidsandberg/facenet) -* [MTCNN](https://github.com/blaueck/tf-mtcnn) +* [MediaPipe](https://github.com/google/mediapipe) * [Android LibSVM](https://github.com/yctung/AndroidLibSVM) * [Tensorflow Android Camera Demo](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android) @@ -25,5 +25,9 @@ from davidsandberg's facenet |-----------------|--------------|------------------|-------------| | [20180402-114759](https://drive.google.com/open?id=1EXPBSXwTaqrSC0OhUdXNmKSh9qJUQ55-) | 0.9965 | VGGFace2 | [Inception ResNet v1](https://github.com/davidsandberg/facenet/blob/master/src/models/inception_resnet_v1.py) | +from MediaPipe + * [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_front.tflite) + * Paper: ["BlazeFace: Sub-millisecond Neural Face Detection on Mobile GPUs"](https://sites.google.com/corp/view/perception-cv4arvr/blazeface) + ## License [Apache License 2.0](./LICENSE) diff --git a/app/build.gradle b/app/build.gradle index 77faece..3bff525 100644 --- a/app/build.gradle +++ b/app/build.gradle @@ -1,15 +1,15 @@ apply plugin: 'com.android.application' android { - compileSdkVersion 28 + compileSdkVersion 29 defaultConfig { applicationId "pp.facerecognizer" minSdkVersion 25 - targetSdkVersion 28 - versionCode 1 - versionName "1.0" + targetSdkVersion 29 + versionCode 2 + versionName "1.0.1" ndk { - abiFilters "armeabi-v7a" + abiFilters 'armeabi-v7a', 'arm64-v8a' } } buildTypes { @@ -23,6 +23,9 @@ android { path 'src/main/jni/CMakeLists.txt' } } + aaptOptions { + noCompress "tflite" + } compileOptions { sourceCompatibility JavaVersion.VERSION_1_8 targetCompatibility JavaVersion.VERSION_1_8 @@ -30,8 +33,8 @@ android { } dependencies { - implementation 'androidx.annotation:annotation:1.0.1' - implementation 'androidx.appcompat:appcompat:1.0.2' - implementation 'com.google.android.material:material:1.1.0-alpha03' - implementation 'org.tensorflow:tensorflow-android:1.13.0-rc0' + implementation 'androidx.annotation:annotation:1.1.0' + implementation 'androidx.appcompat:appcompat:1.1.0' + implementation 'com.google.android.material:material:1.2.0-alpha03' + implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly' } diff --git a/app/src/main/AndroidManifest.xml b/app/src/main/AndroidManifest.xml index c8768fc..17ee3a4 100644 --- a/app/src/main/AndroidManifest.xml +++ b/app/src/main/AndroidManifest.xml @@ -14,7 +14,8 @@ android:label="@string/app_name" android:roundIcon="@mipmap/ic_launcher_round" android:supportsRtl="true" - android:theme="@style/AppTheme"> + android:theme="@style/AppTheme" + android:requestLegacyExternalStorage="true"> diff --git a/app/src/main/assets/face_detection_front.tflite b/app/src/main/assets/face_detection_front.tflite new file mode 100644 index 0000000..419e1a8 Binary files /dev/null and b/app/src/main/assets/face_detection_front.tflite differ diff --git a/app/src/main/assets/facenet.pb b/app/src/main/assets/facenet.tflite similarity index 84% rename from app/src/main/assets/facenet.pb rename to app/src/main/assets/facenet.tflite index 8287bd6..7e54671 100644 Binary files a/app/src/main/assets/facenet.pb and b/app/src/main/assets/facenet.tflite differ diff --git a/app/src/main/assets/mtcnn.pb b/app/src/main/assets/mtcnn.pb deleted file mode 100644 index c12fb2a..0000000 Binary files a/app/src/main/assets/mtcnn.pb and /dev/null differ diff --git a/app/src/main/java/pp/facerecognizer/MainActivity.java 
b/app/src/main/java/pp/facerecognizer/MainActivity.java index 24806ca..4f332c5 100644 --- a/app/src/main/java/pp/facerecognizer/MainActivity.java +++ b/app/src/main/java/pp/facerecognizer/MainActivity.java @@ -41,7 +41,6 @@ import java.io.File; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Vector; @@ -50,6 +49,7 @@ import pp.facerecognizer.env.FileUtils; import pp.facerecognizer.env.ImageUtils; import pp.facerecognizer.env.Logger; +import pp.facerecognizer.ml.BlazeFace; import pp.facerecognizer.tracking.MultiBoxTracker; /** @@ -59,8 +59,8 @@ public class MainActivity extends CameraActivity implements OnImageAvailableListener { private static final Logger LOGGER = new Logger(); - private static final int FACE_SIZE = 160; - private static final int CROP_SIZE = 300; + private static final int CROP_HEIGHT = BlazeFace.INPUT_SIZE_HEIGHT; + private static final int CROP_WIDTH = BlazeFace.INPUT_SIZE_WIDTH; private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480); @@ -69,7 +69,7 @@ public class MainActivity extends CameraActivity implements OnImageAvailableList private Integer sensorOrientation; - private Classifier classifier; + private Recognizer recognizer; private long lastProcessingTimeMs; private Bitmap rgbFrameBitmap = null; @@ -101,8 +101,10 @@ protected void onCreate(final Bundle savedInstanceState) { super.onCreate(savedInstanceState); FrameLayout container = findViewById(R.id.container); - initSnackbar = Snackbar.make(container, "Initializing...", Snackbar.LENGTH_INDEFINITE); - trainSnackbar = Snackbar.make(container, "Training data...", Snackbar.LENGTH_INDEFINITE); + initSnackbar = Snackbar.make( + container, getString(R.string.initializing), Snackbar.LENGTH_INDEFINITE); + trainSnackbar = Snackbar.make( + container, getString(R.string.training), Snackbar.LENGTH_INDEFINITE); View dialogView = getLayoutInflater().inflate(R.layout.dialog_edittext, null); EditText editText = dialogView.findViewById(R.id.edit_text); @@ -110,7 +112,7 @@ protected void onCreate(final Bundle savedInstanceState) { .setTitle(R.string.enter_name) .setView(dialogView) .setPositiveButton(getString(R.string.ok), (dialogInterface, i) -> { - int idx = classifier.addPerson(editText.getText().toString()); + int idx = recognizer.addPerson(editText.getText().toString()); performFileSearch(idx - 1); }) .create(); @@ -119,7 +121,7 @@ protected void onCreate(final Bundle savedInstanceState) { button.setOnClickListener(view -> new AlertDialog.Builder(MainActivity.this) .setTitle(getString(R.string.select_name)) - .setItems(classifier.getClassNames(), (dialogInterface, i) -> { + .setItems(recognizer.getClassNames(), (dialogInterface, i) -> { if (i == 0) { editDialog.show(); } else { @@ -132,7 +134,7 @@ protected void onCreate(final Bundle savedInstanceState) { @Override public void onPreviewSizeChosen(final Size size, final int rotation) { if (!initialized) - new Thread(this::init).start(); + init(); final float textSizePx = TypedValue.applyDimension( @@ -150,12 +152,12 @@ public void onPreviewSizeChosen(final Size size, final int rotation) { LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); - croppedBitmap = Bitmap.createBitmap(CROP_SIZE, CROP_SIZE, Config.ARGB_8888); + croppedBitmap = Bitmap.createBitmap(CROP_WIDTH, CROP_HEIGHT, Config.ARGB_8888); frameToCropTransform = ImageUtils.getTransformationMatrix( previewWidth, previewHeight, - CROP_SIZE, CROP_SIZE, + 
CROP_WIDTH, CROP_HEIGHT, sensorOrientation, false); cropToFrameTransform = new Matrix(); @@ -191,13 +193,7 @@ public void onPreviewSizeChosen(final Size size, final int rotation) { canvas.getHeight() - copy.getHeight() * scaleFactor); canvas.drawBitmap(copy, matrix, new Paint()); - final Vector<String> lines = new Vector<String>(); - if (classifier != null) { - final String statString = classifier.getStatString(); - final String[] statLines = statString.split("\n"); - Collections.addAll(lines, statLines); - } - lines.add(""); + final Vector<String> lines = new Vector<>(); lines.add("Frame: " + previewWidth + "x" + previewHeight); lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight()); lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight()); @@ -211,28 +207,30 @@ public void onPreviewSizeChosen(final Size size, final int rotation) { OverlayView trackingOverlay; void init() { - runOnUiThread(()-> initSnackbar.show()); - File dir = new File(FileUtils.ROOT); - - if (!dir.isDirectory()) { - if (dir.exists()) dir.delete(); - dir.mkdirs(); - - AssetManager mgr = getAssets(); - FileUtils.copyAsset(mgr, FileUtils.DATA_FILE); - FileUtils.copyAsset(mgr, FileUtils.MODEL_FILE); - FileUtils.copyAsset(mgr, FileUtils.LABEL_FILE); - } + runInBackground(() -> { + runOnUiThread(()-> initSnackbar.show()); + File dir = new File(FileUtils.ROOT); + + if (!dir.isDirectory()) { + if (dir.exists()) dir.delete(); + dir.mkdirs(); + + AssetManager mgr = getAssets(); + FileUtils.copyAsset(mgr, FileUtils.DATA_FILE); + FileUtils.copyAsset(mgr, FileUtils.MODEL_FILE); + FileUtils.copyAsset(mgr, FileUtils.LABEL_FILE); + } - try { - classifier = Classifier.getInstance(getAssets(), FACE_SIZE, FACE_SIZE); - } catch (Exception e) { - LOGGER.e("Exception initializing classifier!", e); - finish(); - } + try { + recognizer = Recognizer.getInstance(getAssets()); + } catch (Exception e) { + LOGGER.e("Exception initializing classifier!", e); + finish(); + } - runOnUiThread(()-> initSnackbar.dismiss()); - initialized = true; + runOnUiThread(()-> initSnackbar.dismiss()); + initialized = true; + }); } @Override @@ -278,8 +276,8 @@ protected void processImage() { final long startTime = SystemClock.uptimeMillis(); cropCopyBitmap = Bitmap.createBitmap(croppedBitmap); - List<Recognition> mappedRecognitions = - classifier.recognizeImage(croppedBitmap,cropToFrameTransform); + List<Recognition> mappedRecognitions = + recognizer.recognizeImage(croppedBitmap,cropToFrameTransform); lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; tracker.trackResults(mappedRecognitions, luminanceCopy, currTimestamp); @@ -302,10 +300,12 @@ protected Size getDesiredPreviewFrameSize() { @Override protected void onActivityResult(int requestCode, int resultCode, Intent data) { + super.onActivityResult(requestCode, resultCode, data); + if (!initialized) { Snackbar.make( getWindow().getDecorView().findViewById(R.id.container), - "Try it again later", Snackbar.LENGTH_SHORT) + getString(R.string.try_it_later), Snackbar.LENGTH_SHORT) .show(); return; } @@ -327,7 +327,7 @@ protected void onActivityResult(int requestCode, int resultCode, Intent data) { new Thread(() -> { try { - classifier.updateData(requestCode, getContentResolver(), uris); + recognizer.updateData(requestCode, getContentResolver(), uris); } catch (Exception e) { LOGGER.e(e, "Exception!"); } finally {
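Note on the crop pipeline in MainActivity above: the preview frame is warped into the 128x128 BlazeFace input via frameToCropTransform, and detections are mapped back with the inverse matrix. A minimal sketch of the round trip, assuming frameToCropTransform is inverted into cropToFrameTransform as in the TensorFlow Android demo this activity derives from (the invert() call sits outside the visible hunks, so it is an assumption here; the literal sizes are illustrative):

```java
import android.graphics.Matrix;
import android.graphics.RectF;

// Hypothetical illustration; previewWidth/Height and sensorOrientation
// come from the camera setup in MainActivity.
Matrix frameToCropTransform = ImageUtils.getTransformationMatrix(
        640, 480,                      // previewWidth, previewHeight
        BlazeFace.INPUT_SIZE_WIDTH,    // 128
        BlazeFace.INPUT_SIZE_HEIGHT,   // 128
        90,                            // sensorOrientation (assumed)
        false);                        // maintainAspectRatio

Matrix cropToFrameTransform = new Matrix();
frameToCropTransform.invert(cropToFrameTransform); // assumed, as in the TF demo

// A box detected in 128x128 crop space, mapped back to preview space:
RectF box = new RectF(32, 40, 96, 104);
cropToFrameTransform.mapRect(box); // now in preview coordinates
```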
diff --git a/app/src/main/java/pp/facerecognizer/Classifier.java b/app/src/main/java/pp/facerecognizer/Recognizer.java similarity index 71% rename from app/src/main/java/pp/facerecognizer/Classifier.java rename to app/src/main/java/pp/facerecognizer/Recognizer.java index 2fa6957..ea4a0d3 100644 --- a/app/src/main/java/pp/facerecognizer/Classifier.java +++ b/app/src/main/java/pp/facerecognizer/Recognizer.java @@ -31,16 +31,15 @@ import java.util.LinkedList; import java.util.List; -import androidx.core.util.Pair; import pp.facerecognizer.env.FileUtils; -import pp.facerecognizer.wrapper.FaceNet; -import pp.facerecognizer.wrapper.LibSVM; -import pp.facerecognizer.wrapper.MTCNN; +import pp.facerecognizer.ml.BlazeFace; +import pp.facerecognizer.ml.FaceNet; +import pp.facerecognizer.ml.LibSVM; /** * Generic interface for interacting with different recognition engines. */ -public class Classifier { +public class Recognizer { /** * An immutable result returned by a Classifier describing what was recognized. */ @@ -88,10 +87,6 @@ public RectF getLocation() { return new RectF(location); } - void setLocation(RectF location) { - this.location = location; - } - @Override public String toString() { String resultString = ""; @@ -115,38 +110,33 @@ public String toString() { } } - public static final int EMBEDDING_SIZE = 512; - private static Classifier classifier; + private static Recognizer recognizer; - private MTCNN mtcnn; + private BlazeFace blazeFace; private FaceNet faceNet; private LibSVM svm; private List<String> classNames; - private Classifier() {} + private Recognizer() {} - static Classifier getInstance (AssetManager assetManager, - int inputHeight, - int inputWidth) throws Exception { - if (classifier != null) return classifier; + static Recognizer getInstance (AssetManager assetManager) throws Exception { + if (recognizer != null) return recognizer; - classifier = new Classifier(); + recognizer = new Recognizer(); + recognizer.blazeFace = BlazeFace.create(assetManager); + recognizer.faceNet = FaceNet.create(assetManager); + recognizer.svm = LibSVM.getInstance(); + recognizer.classNames = FileUtils.readLabel(FileUtils.LABEL_FILE); - classifier.mtcnn = MTCNN.create(assetManager); - classifier.faceNet = FaceNet.create(assetManager, inputHeight, inputWidth); - classifier.svm = LibSVM.getInstance(); - - classifier.classNames = FileUtils.readLabel(FileUtils.LABEL_FILE); - - return classifier; + return recognizer; } CharSequence[] getClassNames() { CharSequence[] cs = new CharSequence[classNames.size() + 1]; int idx = 1; - cs[0] = "+ Add new person"; + cs[0] = "+ add new person"; for (String name : classNames) { cs[idx++] = name; } @@ -156,30 +146,22 @@ CharSequence[] getClassNames() { List<Recognition> recognizeImage(Bitmap bitmap, Matrix matrix) { synchronized (this) { - Pair faces[] = mtcnn.detect(bitmap); - + List<RectF> faces = blazeFace.detect(bitmap); final List<Recognition> mappedRecognitions = new LinkedList<>(); - for (Pair face : faces) { - RectF rectF = (RectF) face.first; - + for (RectF rectF : faces) { Rect rect = new Rect(); rectF.round(rect); FloatBuffer buffer = faceNet.getEmbeddings(bitmap, rect); - Pair<Integer, Float> pair = svm.predict(buffer); + LibSVM.Prediction prediction = svm.predict(buffer); matrix.mapRect(rectF); - Float prob = pair.second; - - String name; - if (prob > 0.5) - name = classNames.get(pair.first); - else - name = "Unknown"; + int index = prediction.getIndex(); + String name = classNames.get(index); Recognition result = - new Recognition("" + pair.first, name, prob, rectF); + new Recognition("" + index, name, prediction.getProb(), rectF); mappedRecognitions.add(result); } return mappedRecognitions; }
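One behavioral change worth flagging in recognizeImage above: the old code fell back to "Unknown" whenever the SVM probability was at most 0.5, while the new code always labels a face with the top class. If that guard is still wanted, a sketch of restoring it against the new Prediction API; the 0.5 cutoff comes from the deleted lines, not from any tuning done here:

```java
LibSVM.Prediction prediction = svm.predict(buffer);
String name = prediction.getProb() > 0.5f
        ? classNames.get(prediction.getIndex())
        : "Unknown"; // fallback removed by this diff
```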
@@ -193,22 +175,14 @@ void updateData(int label, ContentResolver contentResolver, ArrayList<Uri> uris) for (Uri uri : uris) { Bitmap bitmap = getBitmapFromUri(contentResolver, uri); - Pair faces[] = mtcnn.detect(bitmap); + List<RectF> faces = blazeFace.detect(bitmap); - float max = 0f; Rect rect = new Rect(); - - for (Pair face : faces) { - Float prob = (Float) face.second; - if (prob > max) { - max = prob; - - RectF rectF = (RectF) face.first; - rectF.round(rect); - } + if (!faces.isEmpty()) { + faces.get(0).round(rect); } - float[] emb_array = new float[EMBEDDING_SIZE]; + float[] emb_array = new float[FaceNet.EMBEDDING_SIZE]; faceNet.getEmbeddings(bitmap, rect).get(emb_array); list.add(emb_array); } @@ -237,12 +211,8 @@ private Bitmap getBitmapFromUri(ContentResolver contentResolver, Uri uri) throws void enableStatLogging(final boolean debug){ } - String getStatString() { - return faceNet.getStatString(); - } - void close() { - mtcnn.close(); + blazeFace.close(); faceNet.close(); } } diff --git a/app/src/main/java/pp/facerecognizer/env/ImageUtils.java b/app/src/main/java/pp/facerecognizer/env/ImageUtils.java index 54ec68e..a3e971f 100644 --- a/app/src/main/java/pp/facerecognizer/env/ImageUtils.java +++ b/app/src/main/java/pp/facerecognizer/env/ImageUtils.java @@ -18,6 +18,8 @@ import android.graphics.Bitmap; import android.graphics.Matrix; +import java.nio.FloatBuffer; + /** * Utility class for manipulating images. **/ @@ -307,4 +309,39 @@ public static Matrix getTransformationMatrix( return matrix; } + + public static void prewhiten(float[] input, FloatBuffer output) { + if (useNativeConversion) { + try { + ImageUtils.prewhiten(input, input.length, output); + return; + } catch (UnsatisfiedLinkError e) { + LOGGER.w( + "Native prewhiten implementation not found, falling back to Java implementation"); + useNativeConversion = false; + } + } + + double sum = 0f; + for (float value : input) { + sum += value; + } + double mean = sum / input.length; + sum = 0f; + + for (int i = 0; i < input.length; ++i) { + input[i] -= mean; + sum += Math.pow(input[i], 2); + } + double std = Math.sqrt(sum / input.length); + double std_adj = Math.max(std, 1.0 / Math.sqrt(input.length)); + + output.clear(); + for (float value : input) { + output.put((float) (value / std_adj)); + } + output.rewind(); + } + + private static native void prewhiten(float[] input, int length, FloatBuffer output); } diff --git a/app/src/main/java/pp/facerecognizer/ml/BlazeFace.java b/app/src/main/java/pp/facerecognizer/ml/BlazeFace.java new file mode 100644 index 0000000..04b60fb --- /dev/null +++ b/app/src/main/java/pp/facerecognizer/ml/BlazeFace.java @@ -0,0 +1,376 @@ +package pp.facerecognizer.ml; + +import android.content.res.AssetFileDescriptor; +import android.content.res.AssetManager; +import android.graphics.Bitmap; +import android.graphics.RectF; +import android.os.Trace; + +import org.tensorflow.lite.Interpreter; + +import java.io.FileInputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.nio.channels.FileChannel; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class BlazeFace { + private static final String MODEL_FILE = "face_detection_front.tflite"; + + public static final int INPUT_SIZE_HEIGHT = 128; + public static final int INPUT_SIZE_WIDTH = 128;
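The NUM_BOXES constant just below follows from the anchor layout that GenerateAnchors builds further down: the single stride-8 layer covers a 16x16 grid with 2 anchors per cell, and the three merged stride-16 layers cover an 8x8 grid with 6 anchors per cell. A quick check of the arithmetic:

```java
// 128x128 input, strides {8, 16, 16, 16} (see GenerateAnchors below).
int stride8 = (128 / 8) * (128 / 8) * 2;    // 16*16 cells, 2 anchors each = 512
int stride16 = (128 / 16) * (128 / 16) * 6; // 8*8 cells, 6 anchors each = 384
System.out.println(stride8 + stride16);     // 896 == NUM_BOXES
```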
+ + // Total number of anchor boxes the model outputs. + private static final int NUM_BOXES = 896; + private static final int NUM_COORDS = 16; + private static final int BYTE_SIZE_OF_FLOAT = 4; + + private static final float MIN_SCORE_THRESH = 0.95f; + + private static final int[] strides = {8, 16, 16, 16}; + + private static final int ASPECT_RATIOS_SIZE = 1; + + private static final float MIN_SCALE = 0.1484375f; + private static final float MAX_SCALE = 0.75f; + + private static final float ANCHOR_OFFSET_X = 0.5f; + private static final float ANCHOR_OFFSET_Y = 0.5f; + + private static final float X_SCALE = 128f; + private static final float Y_SCALE = 128f; + private static final float H_SCALE = 128f; + private static final float W_SCALE = 128f; + + private static final float MIN_SUPPRESSION_THRESHOLD = 0.3f; + + // Pre-allocated buffers. + private int[] intValues; + private float[][][][] floatValues; + private Object[] inputArray; + + private FloatBuffer outputScores; + private FloatBuffer outputBoxes; + private Map<Integer, Object> outputMap; + + private Interpreter interpreter; + + private List<Anchor> anchors; + + private static class Anchor { + private float x_center; + private float y_center; + private float h; + private float w; + } + + private class Detection { + private RectF location; + private float score; + + Detection(RectF location, float score) { + this.location = location; + this.score = score; + } + } + + private class IndexedScore { + private int index; + private float score; + + IndexedScore(int index, float score) { + this.index = index; + this.score = score; + } + } + + /** Memory-map the model file in Assets. */ + private static ByteBuffer loadModelFile(AssetManager assets) + throws IOException { + AssetFileDescriptor fileDescriptor = assets.openFd(MODEL_FILE); + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + }
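loadModelFile above memory-maps the .tflite asset through an AssetFileDescriptor, which only works when the asset is stored uncompressed; that is what the aaptOptions { noCompress "tflite" } block added to app/build.gradle guarantees (AssetManager.openFd throws for compressed assets). If a model ever had to be loaded without that guarantee, a hypothetical fallback could stream it into a direct buffer instead. loadModelFallback below is illustrative only, not part of this patch, and assumes the usual java.io/java.nio imports:

```java
// Hypothetical fallback: stream a (possibly compressed) asset into a
// direct, native-ordered ByteBuffer, which Interpreter also accepts.
private static ByteBuffer loadModelFallback(AssetManager assets, String name)
        throws IOException {
    try (InputStream is = assets.open(name);
         ByteArrayOutputStream bos = new ByteArrayOutputStream()) {
        byte[] chunk = new byte[8192];
        for (int n; (n = is.read(chunk)) != -1; ) {
            bos.write(chunk, 0, n);
        }
        byte[] bytes = bos.toByteArray();
        ByteBuffer model = ByteBuffer.allocateDirect(bytes.length)
                .order(ByteOrder.nativeOrder());
        model.put(bytes);
        model.rewind();
        return model;
    }
}
```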
1.0f + : CalculateScale(MIN_SCALE, MAX_SCALE, + last_same_stride_layer + 1, + strides.length); + scales.add((float) Math.sqrt(scale * scale_next)); + aspect_ratios.add(1.0f); + last_same_stride_layer++; + } + + for (int i = 0; i < aspect_ratios.size(); ++i) { + float ratio_sqrts = (float) Math.sqrt(aspect_ratios.get(i)); + anchor_height.add(scales.get(i) / ratio_sqrts); + anchor_width.add(scales.get(i) * ratio_sqrts); + } + + int stride = strides[layer_id]; + int feature_map_height = (int) Math.ceil(1.0f * INPUT_SIZE_HEIGHT / stride); + int feature_map_width = (int) Math.ceil(1.0f * INPUT_SIZE_WIDTH / stride); + + for (int y = 0; y < feature_map_height; ++y) { + for (int x = 0; x < feature_map_width; ++x) { + for (int anchor_id = 0; anchor_id < anchor_height.size(); ++anchor_id) { + // TODO: Support specifying anchor_offset_x, anchor_offset_y. + float x_center = (x + ANCHOR_OFFSET_X) * 1.0f / feature_map_width; + float y_center = (y + ANCHOR_OFFSET_Y) * 1.0f / feature_map_height; + + Anchor new_anchor = new Anchor(); + new_anchor.x_center = x_center; + new_anchor.y_center = y_center; + new_anchor.w = 1.0f; + new_anchor.h = 1.0f; + + anchors.add(new_anchor); + } + } + } + layer_id = last_same_stride_layer; + } + return anchors; + } + + /** + * Initializes a native TensorFlow session for classifying images. + * + * @param assetManager The asset manager to be used to load assets. + */ + public static BlazeFace create( + final AssetManager assetManager) { + final BlazeFace b = new BlazeFace(); + + try { + b.interpreter = new Interpreter(loadModelFile(assetManager)); + } catch (Exception e) { + throw new RuntimeException(e); + } + + // Pre-allocate buffers. + b.intValues = new int[INPUT_SIZE_WIDTH * INPUT_SIZE_HEIGHT]; + b.floatValues = new float[1][INPUT_SIZE_HEIGHT][INPUT_SIZE_WIDTH][3]; + b.inputArray = new Object[]{b.floatValues}; + + b.outputScores = ByteBuffer.allocateDirect(NUM_BOXES * BYTE_SIZE_OF_FLOAT) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer(); + + b.outputBoxes = ByteBuffer.allocateDirect(NUM_BOXES * NUM_COORDS * BYTE_SIZE_OF_FLOAT) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer(); + + b.outputMap = new HashMap<>(); + b.outputMap.put(0, b.outputBoxes); + b.outputMap.put(1, b.outputScores); + + b.anchors = GenerateAnchors(); + + return b; + } + + private BlazeFace() {} + + public List detect(Bitmap bitmap) { + // Log this method so that it can be analyzed with systrace. + Trace.beginSection("detect"); + + Trace.beginSection("preprocessBitmap"); + // Preprocess the image data from 0-255 int to normalized float based + // on the provided parameters. + bitmap.getPixels(intValues, 0, INPUT_SIZE_WIDTH, 0, 0, INPUT_SIZE_WIDTH, INPUT_SIZE_HEIGHT); + + for (int i = 0; i < INPUT_SIZE_HEIGHT; ++i) { + for (int j = 0; j < INPUT_SIZE_WIDTH; ++j) { + int p = intValues[i * INPUT_SIZE_WIDTH + j]; + + floatValues[0][i][j][2] = (p & 0xFF) / 127.5f - 1; + floatValues[0][i][j][1] = ((p >> 8) & 0xFF) / 127.5f - 1; + floatValues[0][i][j][0] = ((p >> 16) & 0xFF) / 127.5f - 1; + } + } + Trace.endSection(); // preprocessBitmap + + // Run the inference call. + Trace.beginSection("run"); + interpreter.runForMultipleInputsOutputs(inputArray, outputMap); + Trace.endSection(); + + outputScores.flip(); + outputBoxes.flip(); + + List detections = new ArrayList<>(); + for (int i = 0; i < NUM_BOXES; i++) { + float score = outputScores.get(i); + score = score < -100.0f ? -100.0f : score; + score = score > 100.0f ? 
+ + List<Detection> detections = new ArrayList<>(); + for (int i = 0; i < NUM_BOXES; i++) { + float score = outputScores.get(i); + score = score < -100.0f ? -100.0f : score; + score = score > 100.0f ? 100.0f : score; + score = 1.0f / (1.0f + (float) Math.exp(-score)); + + if (score <= MIN_SCORE_THRESH) + continue; + + float x_center = outputBoxes.get(i * NUM_COORDS); + float y_center = outputBoxes.get(i * NUM_COORDS + 1); + float w = outputBoxes.get(i * NUM_COORDS + 2); + float h = outputBoxes.get(i * NUM_COORDS + 3); + + x_center = + x_center / X_SCALE * anchors.get(i).w + anchors.get(i).x_center; + y_center = + y_center / Y_SCALE * anchors.get(i).h + anchors.get(i).y_center; + + h = h / H_SCALE * anchors.get(i).h; + w = w / W_SCALE * anchors.get(i).w; + + float ymin = y_center - h / 2.f; + float xmin = x_center - w / 2.f; + float ymax = y_center + h / 2.f; + float xmax = x_center + w / 2.f; + + detections.add(new Detection(new RectF(xmin, ymin, xmax, ymax), score)); + } + + outputScores.clear(); + outputBoxes.clear(); + + // Check if there are any detections at all. + if (detections.isEmpty()) { + return new ArrayList<>(); + } + + // Copy all the scores (there is a single score in each detection after + // the above pruning) to an indexed vector for sorting. The first value is + // the index of the detection in the original vector from which the score + // stems, while the second is the actual score. + List<IndexedScore> indexed_scores = new ArrayList<>(); + for (int index = 0; index < detections.size(); ++index) { + indexed_scores.add( + new IndexedScore(index, detections.get(index).score)); + } + indexed_scores.sort((o1, o2) -> { + if (o1.score > o2.score) return 1; + else if (o1.score == o2.score) return 0; + return -1; + }); + + // A set of detections and locations, wrapping the location data from each + // detection, which are retained after the non-maximum suppression. + List<RectF> retained_detections = WeightedNonMaxSuppression(indexed_scores, detections); + + Trace.endSection(); // "detect" + return retained_detections; + } + + private List<RectF> WeightedNonMaxSuppression(List<IndexedScore> indexed_scores, + List<Detection> detections) { + List<IndexedScore> remained_indexed_scores = new ArrayList<>(indexed_scores); + + List<IndexedScore> remained = new ArrayList<>(); + List<IndexedScore> candidates = new ArrayList<>(); + List<RectF> output_locations = new ArrayList<>(); + + while (!remained_indexed_scores.isEmpty()) { + Detection detection = detections.get(remained_indexed_scores.get(0).index); + if ((int) detection.score < -1.f) { + break; + } + + remained.clear(); + candidates.clear(); + RectF location = new RectF(detection.location); + // This includes the first box. 
+ for (IndexedScore indexed_score : remained_indexed_scores) { + RectF rest_location = new RectF(detections.get(indexed_score.index).location); + float similarity = + OverlapSimilarity(rest_location, location); + if (similarity > MIN_SUPPRESSION_THRESHOLD) { + candidates.add(indexed_score); + } else { + remained.add(indexed_score); + } + } + RectF weighted_location = new RectF(detection.location); + if (!candidates.isEmpty()) { + float w_xmin = 0.0f; + float w_ymin = 0.0f; + float w_xmax = 0.0f; + float w_ymax = 0.0f; + float total_score = 0.0f; + for (IndexedScore candidate : candidates) { + total_score += candidate.score; + RectF bbox = + detections.get(candidate.index).location; + w_xmin += bbox.left * candidate.score; + w_ymin += bbox.top * candidate.score; + w_xmax += bbox.right * candidate.score; + w_ymax += bbox.bottom * candidate.score; + + } + weighted_location.left = w_xmin / total_score * INPUT_SIZE_WIDTH; + weighted_location.top = w_ymin / total_score * INPUT_SIZE_HEIGHT; + weighted_location.right = w_xmax / total_score * INPUT_SIZE_WIDTH; + weighted_location.bottom = w_ymax / total_score * INPUT_SIZE_HEIGHT; + } + remained_indexed_scores.clear(); + remained_indexed_scores.addAll(remained); + output_locations.add(weighted_location); + } + + return output_locations; + } + + // Computes an overlap similarity between two rectangles. Similarity measure is + // defined by overlap_type parameter. + private float OverlapSimilarity(RectF rect1, RectF rect2) { + if (!RectF.intersects(rect1, rect2)) return 0.0f; + RectF intersection = new RectF(); + intersection.setIntersect(rect1, rect2); + + float intersection_area = intersection.height() * intersection.width(); + float normalization = rect1.height() * rect1.width() + + rect2.height() * rect2.width() - intersection_area; + + return normalization > 0.0f ? intersection_area / normalization : 0.0f; + } + + public void close() { + interpreter.close(); + } +} diff --git a/app/src/main/java/pp/facerecognizer/ml/FaceNet.java b/app/src/main/java/pp/facerecognizer/ml/FaceNet.java new file mode 100644 index 0000000..756b356 --- /dev/null +++ b/app/src/main/java/pp/facerecognizer/ml/FaceNet.java @@ -0,0 +1,135 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package pp.facerecognizer.ml; + +import android.content.res.AssetFileDescriptor; +import android.content.res.AssetManager; +import android.graphics.Bitmap; +import android.graphics.Bitmap.Config; +import android.graphics.Canvas; +import android.graphics.Rect; +import android.os.Trace; + +import org.tensorflow.lite.Interpreter; + +import java.io.FileInputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.nio.channels.FileChannel; + +import pp.facerecognizer.env.ImageUtils; + +public class FaceNet { + private static final String MODEL_FILE = "facenet.tflite"; + + public static final int EMBEDDING_SIZE = 512; + + private static final int INPUT_SIZE_HEIGHT = 160; + private static final int INPUT_SIZE_WIDTH = 160; + + private static final int BYTE_SIZE_OF_FLOAT = 4; + + // Pre-allocated buffers. + private int[] intValues; + private float[] rgbValues; + + private FloatBuffer inputBuffer; + private FloatBuffer outputBuffer; + + private Bitmap bitmap; + + private Interpreter interpreter; + + /** Memory-map the model file in Assets. */ + private static ByteBuffer loadModelFile(AssetManager assets) + throws IOException { + AssetFileDescriptor fileDescriptor = assets.openFd(MODEL_FILE); + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } + + /** + * Initializes a native TensorFlow session for classifying images. + * + * @param assetManager The asset manager to be used to load assets. + */ + public static FaceNet create(final AssetManager assetManager) { + final FaceNet f = new FaceNet(); + + try { + f.interpreter = new Interpreter(loadModelFile(assetManager)); + } catch (Exception e) { + throw new RuntimeException(e); + } + + // Pre-allocate buffers. + f.intValues = new int[INPUT_SIZE_HEIGHT * INPUT_SIZE_WIDTH]; + f.rgbValues = new float[INPUT_SIZE_HEIGHT * INPUT_SIZE_WIDTH * 3]; + f.inputBuffer = ByteBuffer.allocateDirect(INPUT_SIZE_HEIGHT * INPUT_SIZE_WIDTH * 3 * BYTE_SIZE_OF_FLOAT) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer(); + f.outputBuffer = ByteBuffer.allocateDirect(EMBEDDING_SIZE * BYTE_SIZE_OF_FLOAT) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer(); + + f.bitmap = Bitmap.createBitmap(INPUT_SIZE_WIDTH, INPUT_SIZE_HEIGHT, Config.ARGB_8888); + return f; + } + + private FaceNet() {} + + public FloatBuffer getEmbeddings(Bitmap originalBitmap, Rect rect) { + // Log this method so that it can be analyzed with systrace. + Trace.beginSection("getEmbeddings"); + + Trace.beginSection("preprocessBitmap"); + Canvas canvas = new Canvas(bitmap); + canvas.drawBitmap(originalBitmap, rect, + new Rect(0, 0, INPUT_SIZE_WIDTH, INPUT_SIZE_HEIGHT), null); + + bitmap.getPixels(intValues, 0, INPUT_SIZE_WIDTH, 0, 0, + INPUT_SIZE_WIDTH, INPUT_SIZE_HEIGHT); + ImageUtils.saveBitmap(bitmap); + + for (int i = 0; i < intValues.length; ++i) { + int p = intValues[i]; + + rgbValues[i * 3 + 2] = (float) (p & 0xFF); + rgbValues[i * 3 + 1] = (float) ((p >> 8) & 0xFF); + rgbValues[i * 3 + 0] = (float) ((p >> 16) & 0xFF); + } + + ImageUtils.prewhiten(rgbValues, inputBuffer); + + Trace.endSection(); // preprocessBitmap + + // Run the inference call. 
+ Trace.beginSection("run"); + outputBuffer.rewind(); + interpreter.run(inputBuffer, outputBuffer); + outputBuffer.flip(); + Trace.endSection(); + + Trace.endSection(); // "getEmbeddings" + return outputBuffer; + } + + public void close() { + interpreter.close(); + } +} diff --git a/app/src/main/java/pp/facerecognizer/wrapper/LibSVM.java b/app/src/main/java/pp/facerecognizer/ml/LibSVM.java similarity index 81% rename from app/src/main/java/pp/facerecognizer/wrapper/LibSVM.java rename to app/src/main/java/pp/facerecognizer/ml/LibSVM.java index 14262ea..f189d97 100644 --- a/app/src/main/java/pp/facerecognizer/wrapper/LibSVM.java +++ b/app/src/main/java/pp/facerecognizer/ml/LibSVM.java @@ -1,4 +1,4 @@ -package pp.facerecognizer.wrapper; +package pp.facerecognizer.ml; import android.text.TextUtils; import android.util.Log; @@ -8,8 +8,6 @@ import java.util.ArrayList; import java.util.Arrays; -import androidx.core.util.Pair; -import pp.facerecognizer.Classifier; import pp.facerecognizer.env.FileUtils; /** @@ -25,6 +23,22 @@ public class LibSVM { private int index; private double prob; + public class Prediction { + private int index; + private float prob; + + Prediction(int index, float prob) { + this.index = index; + this.prob = prob; + } + public int getIndex() { + return index; + } + public float getProb() { + return prob; + } + } + static { System.loadLibrary("jnilibsvm"); } @@ -68,12 +82,12 @@ public void train() { train(cmd); } - public Pair predict(FloatBuffer buffer) { + public Prediction predict(FloatBuffer buffer) { String options = "-b 1"; String cmd = TextUtils.join(" ", Arrays.asList(options, MODEL_PATH)); - predict(cmd, buffer, Classifier.EMBEDDING_SIZE); - return new Pair<>(index, (float) prob); + predict(cmd, buffer, FaceNet.EMBEDDING_SIZE); + return new Prediction(index, (float) prob); } // singleton for the easy access diff --git a/app/src/main/java/pp/facerecognizer/tracking/MultiBoxTracker.java b/app/src/main/java/pp/facerecognizer/tracking/MultiBoxTracker.java index 496b806..2af4195 100644 --- a/app/src/main/java/pp/facerecognizer/tracking/MultiBoxTracker.java +++ b/app/src/main/java/pp/facerecognizer/tracking/MultiBoxTracker.java @@ -33,7 +33,7 @@ import java.util.List; import java.util.Queue; -import pp.facerecognizer.Classifier.Recognition; +import pp.facerecognizer.Recognizer.Recognition; import pp.facerecognizer.env.BorderedText; import pp.facerecognizer.env.ImageUtils; import pp.facerecognizer.env.Logger; diff --git a/app/src/main/java/pp/facerecognizer/wrapper/FaceNet.java b/app/src/main/java/pp/facerecognizer/wrapper/FaceNet.java deleted file mode 100644 index 4e35c2b..0000000 --- a/app/src/main/java/pp/facerecognizer/wrapper/FaceNet.java +++ /dev/null @@ -1,172 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -package pp.facerecognizer.wrapper; - -import android.content.res.AssetManager; -import android.graphics.Bitmap; -import android.graphics.Bitmap.Config; -import android.graphics.Canvas; -import android.graphics.Rect; -import android.os.Trace; - -import org.tensorflow.Graph; -import org.tensorflow.Operation; -import org.tensorflow.contrib.android.TensorFlowInferenceInterface; - -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.FloatBuffer; - -import pp.facerecognizer.Classifier; - -public class FaceNet { - private static final String MODEL_FILE = "file:///android_asset/facenet.pb"; - private static final int BYTE_SIZE_OF_FLOAT = 4; - - // Config values. - private String inputName; - private int inputHeight; - private int inputWidth; - - // Pre-allocated buffers. - private int[] intValues; - private short[] shortValues; - private FloatBuffer inputBuffer; - - private FloatBuffer outputBuffer; - private String[] outputNames; - - private TensorFlowInferenceInterface inferenceInterface; - - private Bitmap bitmap; - - /** - * Initializes a native TensorFlow session for classifying images. - * - * @param assetManager The asset manager to be used to load assets. - */ - public static FaceNet create( - final AssetManager assetManager, - final int inputHeight, - final int inputWidth) { - final FaceNet d = new FaceNet(); - - d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, MODEL_FILE); - - final Graph g = d.inferenceInterface.graph(); - - d.inputName = "input"; - // The inputName node has a shape of [N, H, W, C], where - // N is the batch size - // H = W are the height and width - // C is the number of channels (3 for our purposes - RGB) - final Operation - inputOp1 = g.operation(d.inputName); - if (inputOp1 == null) { - throw new RuntimeException("Failed to find input Node '" + d.inputName + "'"); - } - - d.inputHeight = inputHeight; - d.inputWidth = inputWidth; - - d.outputNames = new String[] {"embeddings"}; - final Operation outputOp1 = g.operation(d.outputNames[0]); - if (outputOp1 == null) { - throw new RuntimeException("Failed to find output Node'" + d.outputNames[0] + "'"); - } - - // Pre-allocate buffers. - d.intValues = new int[inputHeight * inputWidth]; - d.shortValues = new short[inputHeight * inputWidth * 3]; - d.inputBuffer = ByteBuffer.allocateDirect(inputHeight * inputWidth * BYTE_SIZE_OF_FLOAT * 3) - .order(ByteOrder.nativeOrder()) - .asFloatBuffer(); - - d.outputBuffer = ByteBuffer.allocateDirect(Classifier.EMBEDDING_SIZE * BYTE_SIZE_OF_FLOAT) - .order(ByteOrder.nativeOrder()) - .asFloatBuffer(); - - d.bitmap = Bitmap.createBitmap(inputWidth, inputHeight, Config.ARGB_8888); - return d; - } - - private FaceNet() {} - - public FloatBuffer getEmbeddings(Bitmap originalBitmap, Rect rect) { - // Log this method so that it can be analyzed with systrace. 
- Trace.beginSection("getEmbeddings"); - - Trace.beginSection("preprocessBitmap"); - Canvas canvas = new Canvas(bitmap); - canvas.drawBitmap(originalBitmap, rect, new Rect(0, 0, inputWidth, inputHeight), null); - - bitmap.getPixels(intValues, 0, inputWidth, 0, 0, inputWidth, inputHeight); - - for (int i = 0; i < intValues.length; ++i) { - int p = intValues[i]; - - shortValues[i * 3 + 2] = (short) (p & 0xFF); - shortValues[i * 3 + 1] = (short) ((p >> 8) & 0xFF); - shortValues[i * 3 + 0] = (short) ((p >> 16) & 0xFF); - } - - double sum = 0f; - for (short shortValue : shortValues) { - sum += shortValue; - } - double mean = sum / shortValues.length; - sum = 0f; - - for (short shortValue : shortValues) { - sum += Math.pow(shortValue - mean, 2); - } - double std = Math.sqrt(sum / shortValues.length); - double std_adj = Math.max(std, 1.0/Math.sqrt(shortValues.length)); - - inputBuffer.rewind(); - for (short shortValue : shortValues) { - inputBuffer.put((float) ((shortValue - mean) * (1 / std_adj))); - } - inputBuffer.flip(); - - Trace.endSection(); // preprocessBitmap - - // Copy the input data into TensorFlow. - Trace.beginSection("feed"); - inferenceInterface.feed(inputName, inputBuffer, 1, inputHeight, inputWidth, 3); - Trace.endSection(); - - // Run the inference call. - Trace.beginSection("run"); - inferenceInterface.run(outputNames, false); - Trace.endSection(); - - // Copy the output Tensor back into the output array. - Trace.beginSection("fetch"); - outputBuffer.rewind(); - inferenceInterface.fetch(outputNames[0], outputBuffer); - outputBuffer.flip(); - Trace.endSection(); - - Trace.endSection(); // "getEmbeddings" - return outputBuffer; - } - - public String getStatString() { - return inferenceInterface.getStatString(); - } - - public void close() { - inferenceInterface.close(); - } -} diff --git a/app/src/main/java/pp/facerecognizer/wrapper/MTCNN.java b/app/src/main/java/pp/facerecognizer/wrapper/MTCNN.java deleted file mode 100644 index 0a968d6..0000000 --- a/app/src/main/java/pp/facerecognizer/wrapper/MTCNN.java +++ /dev/null @@ -1,158 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package pp.facerecognizer.wrapper; - -import android.content.res.AssetManager; -import android.graphics.Bitmap; -import android.graphics.RectF; -import android.os.Trace; - -import org.tensorflow.Graph; -import org.tensorflow.contrib.android.TensorFlowInferenceInterface; - -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.FloatBuffer; - -import androidx.core.util.Pair; - -public class MTCNN { - private static final String MODEL_FILE = "file:///android_asset/mtcnn.pb"; - // Only return this many results. - private static final int MAX_RESULTS = 100; - private static final int BYTE_SIZE_OF_FLOAT = 4; - - // Config values. - private String inputName; - - // Pre-allocated buffers. 
- private FloatBuffer outputProbs; - private FloatBuffer outputBoxes; - private String[] outputNames; - - private TensorFlowInferenceInterface inferenceInterface; - - /** - * Initializes a native TensorFlow session for classifying images. - * - * @param assetManager The asset manager to be used to load assets. - */ - public static MTCNN create( - final AssetManager assetManager) { - final MTCNN d = new MTCNN(); - - d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, MODEL_FILE); - - final Graph g = d.inferenceInterface.graph(); - - d.inputName = "input"; - if (g.operation(d.inputName) == null) - throw new RuntimeException("Failed to find input Node '" + d.inputName + "'"); - - d.outputNames = new String[] {"prob", "landmarks", "box"}; - if (g.operation(d.outputNames[0]) == null) - throw new RuntimeException("Failed to find output Node '" + d.outputNames[0] + "'"); - - if (g.operation(d.outputNames[1]) == null) - throw new RuntimeException("Failed to find output Node '" + d.outputNames[1] + "'"); - - if (g.operation(d.outputNames[2]) == null) - throw new RuntimeException("Failed to find output Node '" + d.outputNames[2] + "'"); - - // Pre-allocate buffers. - ByteBuffer byteBuffer = ByteBuffer.allocateDirect(MAX_RESULTS * BYTE_SIZE_OF_FLOAT); - byteBuffer.order(ByteOrder.nativeOrder()); - d.outputProbs = byteBuffer.asFloatBuffer(); - - d.outputBoxes = ByteBuffer.allocateDirect(MAX_RESULTS * BYTE_SIZE_OF_FLOAT * 4) - .order(ByteOrder.nativeOrder()) - .asFloatBuffer(); - - return d; - } - - private MTCNN() {} - - public Pair[] detect(Bitmap bitmap) { - // Log this method so that it can be analyzed with systrace. - Trace.beginSection("detect"); - - Trace.beginSection("preprocessBitmap"); - // Preprocess the image data from 0-255 int to normalized float based - // on the provided parameters. - int w = bitmap.getWidth(), h = bitmap.getHeight(); - int intValues[] = new int[w * h]; - float floatValues[] = new float[w * h * 3]; - - bitmap.getPixels(intValues, 0, w, 0, 0, w, h); - - // BGR - for (int i = 0; i < intValues.length; ++i) { - int p = intValues[i]; - - floatValues[i * 3 + 0] = p & 0xFF; - floatValues[i * 3 + 1] = (p >> 8) & 0xFF; - floatValues[i * 3 + 2] = (p >> 16) & 0xFF; - } - Trace.endSection(); // preprocessBitmap - - // Copy the input data into TensorFlow. - Trace.beginSection("feed"); - inferenceInterface.feed(inputName, floatValues, h, w, 3); - Trace.endSection(); - - // Run the inference call. - Trace.beginSection("run"); - inferenceInterface.run(outputNames, false); - Trace.endSection(); - - // Copy the output Tensor back into the output array. 
- Trace.beginSection("fetch"); - inferenceInterface.fetch(outputNames[0], outputProbs); - inferenceInterface.fetch(outputNames[2], outputBoxes); - Trace.endSection(); - - outputProbs.flip(); - outputBoxes.flip(); - - int len = outputProbs.remaining(); - Pair faces[] = new Pair[len]; - - for (int i = 0; i < len; i++) { - float top = outputBoxes.get(); - float left = outputBoxes.get(); - float bottom = outputBoxes.get(); - float right = outputBoxes.get(); - - faces[i] = new Pair<>( - new RectF(left, top, right, bottom), outputProbs.get()); - } - - if (outputBoxes.hasRemaining()) - outputBoxes.position(outputBoxes.limit()); - - outputProbs.compact(); - outputBoxes.compact(); - - Trace.endSection(); // "detect" - return faces; - } - - public String getStatString() { - return inferenceInterface.getStatString(); - } - - public void close() { - inferenceInterface.close(); - } -} diff --git a/app/src/main/jni/jnilibsvm/common.h b/app/src/main/jni/jnilibsvm/common.h index 3b252e7..2c3e58a 100755 --- a/app/src/main/jni/jnilibsvm/common.h +++ b/app/src/main/jni/jnilibsvm/common.h @@ -20,7 +20,7 @@ #define DEBUG_TAG "LibSVM-NDK" #define DEBUG_MACRO(x) __android_log_print(ANDROID_LOG_DEBUG, DEBUG_TAG, "NDK: %s", x); -#define JNI_FUNC_NAME(name) Java_pp_facerecognizer_wrapper_LibSVM_ ## name +#define JNI_FUNC_NAME(name) Java_pp_facerecognizer_ml_LibSVM_ ## name const int debug_message_max=1024; diff --git a/app/src/main/jni/tensorflow_demo/imageutils_jni.cc b/app/src/main/jni/tensorflow_demo/imageutils_jni.cc index cb689cb..34bb26b 100755 --- a/app/src/main/jni/tensorflow_demo/imageutils_jni.cc +++ b/app/src/main/jni/tensorflow_demo/imageutils_jni.cc @@ -22,6 +22,7 @@ limitations under the License. #include "rgb2yuv.h" #include "yuv2rgb.h" +#include "prewhiten.h" #define IMAGEUTILS_METHOD(METHOD_NAME) \ Java_pp_facerecognizer_env_ImageUtils_##METHOD_NAME // NOLINT @@ -54,6 +55,10 @@ IMAGEUTILS_METHOD(convertRGB565ToYUV420SP)( JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output, jint width, jint height); +JNIEXPORT void JNICALL +IMAGEUTILS_METHOD(prewhiten)( + JNIEnv* env, jclass clazz, jfloatArray input, jint length, jobject output); + #ifdef __cplusplus } #endif @@ -161,3 +166,15 @@ IMAGEUTILS_METHOD(convertRGB565ToYUV420SP)( env->ReleaseByteArrayElements(input, i, JNI_ABORT); env->ReleaseByteArrayElements(output, o, 0); } + +JNIEXPORT void JNICALL +IMAGEUTILS_METHOD(prewhiten)( + JNIEnv* env, jclass clazz, jfloatArray input, jint length, jobject output) { + jboolean inputCopy = JNI_FALSE; + jfloat* const i = env->GetFloatArrayElements(input, &inputCopy); + auto* const o = (jfloat*) env->GetDirectBufferAddress(output); + + Prewhiten(i, length, o); + + env->ReleaseFloatArrayElements(input, i, JNI_ABORT); +} \ No newline at end of file diff --git a/app/src/main/jni/tensorflow_demo/object_tracking/gl_utils.h b/app/src/main/jni/tensorflow_demo/object_tracking/gl_utils.h index a29e677..56b0b96 100755 --- a/app/src/main/jni/tensorflow_demo/object_tracking/gl_utils.h +++ b/app/src/main/jni/tensorflow_demo/object_tracking/gl_utils.h @@ -19,7 +19,7 @@ limitations under the License. 
#include <GLES/gl.h> #include <GLES/glext.h> -#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "geom.h" namespace tf_tracking { diff --git a/app/src/main/jni/tensorflow_demo/prewhiten.cc b/app/src/main/jni/tensorflow_demo/prewhiten.cc new file mode 100644 index 0000000..d71cabd --- /dev/null +++ b/app/src/main/jni/tensorflow_demo/prewhiten.cc @@ -0,0 +1,38 @@ +#include "prewhiten.h" +#include "object_tracking/utils.h" + +#ifdef __ARM_NEON +#include <arm_neon.h> +#endif + +void Prewhiten(const float* const input, const int num_vals, float* const output) { + float mean = tf_tracking::ComputeMean(input, num_vals); + float std = tf_tracking::ComputeStdDev(input, num_vals, mean); + auto std_adj = (float) fmax(std, 1.0/sqrt(num_vals)); + + Normalize(input, mean, std_adj, num_vals, output); +} + +#ifdef __ARM_NEON +void NormalizeNeon(const float* const input, const float mean, + const float std_adj, const int num_vals, float* const output) { + const float32x4_t mean_vec = vdupq_n_f32(-mean); + const float32x4_t std_vec = vdupq_n_f32(1/std_adj); + + float32x4_t result; + + int offset = 0; + for (; offset <= num_vals - 4; offset += 4) { + const float32x4_t deltas = + vaddq_f32(mean_vec, vld1q_f32(&input[offset])); + + result = vmulq_f32(deltas, std_vec); + vst1q_f32(&output[offset], result); + } + + // Get the remaining 1 to 3 values. + for (; offset < num_vals; ++offset) { + output[offset] = (input[offset] - mean) / std_adj; + } +} +#endif \ No newline at end of file diff --git a/app/src/main/jni/tensorflow_demo/prewhiten.h b/app/src/main/jni/tensorflow_demo/prewhiten.h new file mode 100644 index 0000000..a8e0580 --- /dev/null +++ b/app/src/main/jni/tensorflow_demo/prewhiten.h @@ -0,0 +1,38 @@ +#ifndef ORG_TENSORFLOW_JNI_IMAGEUTILS_PREWHITEN_H_ +#define ORG_TENSORFLOW_JNI_IMAGEUTILS_PREWHITEN_H_ + +#include <math.h> +#include <stdlib.h> + +#ifdef __cplusplus extern "C" { #endif + +void Prewhiten(const float* const input, const int num_vals, float* const output); + +#ifdef __ARM_NEON +void NormalizeNeon(const float* const input, const float mean, + const float std_adj, const int num_vals, float* const output); +#endif + +inline void NormalizeCpu(const float* const input, const float mean, + const float std_adj, const int num_vals, float* const output) { + for (int i = 0; i < num_vals; ++i) { + output[i] = (input[i] - mean) / std_adj; + } +} + +inline void Normalize(const float* const input, const float mean, + const float std_adj, const int num_vals, float* const output) { +#ifdef __ARM_NEON + (num_vals >= 8) ? 
NormalizeNeon(input, mean, std_adj, num_vals, output) + : +#endif + NormalizeCpu(input, mean, std_adj, num_vals, output); +} + +#ifdef __cplusplus +} +#endif + +#endif // ORG_TENSORFLOW_JNI_IMAGEUTILS_PREWHITEN_H_ diff --git a/app/src/main/res/values/strings.xml b/app/src/main/res/values/strings.xml index d0209a0..678787b 100644 --- a/app/src/main/res/values/strings.xml +++ b/app/src/main/res/values/strings.xml @@ -5,6 +5,7 @@ <string name="select_name">Select name</string> <string name="ok">OK</string> - <string name="add_person">add person</string> - <string name="update_person">update person</string> + <string name="initializing">Initializing…</string> + <string name="training">Training data…</string> + <string name="try_it_later">Try it again later</string> </resources> diff --git a/build.gradle b/build.gradle index a012883..4058440 100644 --- a/build.gradle +++ b/build.gradle @@ -7,7 +7,7 @@ buildscript { jcenter() } dependencies { - classpath 'com.android.tools.build:gradle:3.3.1' + classpath 'com.android.tools.build:gradle:3.5.3' // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index c7572e1..bf9dec0 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,6 @@ -#Mon Jan 28 21:37:57 KST 2019 +#Sun Nov 10 01:04:20 KST 2019 distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-4.10.1-all.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-5.4.1-all.zip
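For completeness, the Java fallback in ImageUtils.prewhiten and the native NEON/CPU paths above should all produce the same normalization: zero mean and unit standard deviation, with the deviation floored at 1/sqrt(n) so near-constant patches do not blow up. A self-contained sanity check of that contract, written as plain desktop Java (not part of the patch; it inlines the Java path from ImageUtils.prewhiten):

```java
import java.nio.FloatBuffer;

public class PrewhitenCheck {
    public static void main(String[] args) {
        // Synthetic 160x160 RGB image, same size FaceNet consumes.
        float[] input = new float[160 * 160 * 3];
        for (int i = 0; i < input.length; i++) input[i] = (i * 31) % 256;

        // Inline copy of the Java path from ImageUtils.prewhiten above.
        double sum = 0;
        for (float v : input) sum += v;
        double mean = sum / input.length;
        sum = 0;
        for (int i = 0; i < input.length; i++) {
            input[i] -= mean;
            sum += input[i] * input[i];
        }
        double std = Math.sqrt(sum / input.length);
        double stdAdj = Math.max(std, 1.0 / Math.sqrt(input.length));

        FloatBuffer out = FloatBuffer.allocate(input.length);
        for (float v : input) out.put((float) (v / stdAdj));
        out.rewind();

        // Output should report mean ~0 and std ~1.
        double m = 0, s = 0;
        for (int i = 0; i < input.length; i++) m += out.get(i);
        m /= input.length;
        for (int i = 0; i < input.length; i++) s += Math.pow(out.get(i) - m, 2);
        System.out.printf("mean=%.6f std=%.6f%n", m, Math.sqrt(s / input.length));
    }
}
```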