Skip to content

Commit

Permalink
Selfie Quality Model
Browse files Browse the repository at this point in the history
  • Loading branch information
vanshg committed Jan 27, 2024
1 parent 165967b commit 668d8f5
Show file tree
Hide file tree
Showing 17 changed files with 464 additions and 160 deletions.
8 changes: 8 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ okhttp = "4.12.0"
play-services-mlkit-face-detection = "17.1.0"
retrofit = "2.9.0"
sentry = "7.2.0"
tflite = "2.14.0"
tflite-gpu = "2.14.0"
tflite-metadata = "0.4.4"
tflite-support = "0.4.4"
timber = "5.0.1"
truth = "1.3.0"
uiautomator = "2.3.0-beta01"
Expand Down Expand Up @@ -104,6 +108,10 @@ retrofit = { module = "com.squareup.retrofit2:retrofit", version.ref = "retrofit
retrofit-converter-moshi = { module = "com.squareup.retrofit2:converter-moshi", version.ref = "retrofit" }
sentry = { module = "io.sentry:sentry" }
sentry-bom = { module = "io.sentry:sentry-bom", version.ref = "sentry" }
tflite = { module = "org.tensorflow:tensorflow-lite", version.ref = "tflite" }
tflite-metadata = { group = "org.tensorflow", name = "tensorflow-lite-metadata", version.ref = "tflite-metadata" }
tflite-gpu = { group = "org.tensorflow", name = "tensorflow-lite-gpu", version.ref = "tflite-gpu" }
tflite-support = { group = "org.tensorflow", name = "tensorflow-lite-support", version.ref = "tflite-support" }
timber = { module = "com.jakewharton.timber:timber", version.ref = "timber" }
truth = { module = "com.google.truth:truth", version.ref = "truth" }
uiautomator = { module = "androidx.test.uiautomator:uiautomator", version.ref = "uiautomator" }
6 changes: 6 additions & 0 deletions lib/lib.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ android {
buildFeatures {
compose = true
buildConfig = true
mlModelBinding = true
}

composeOptions {
Expand Down Expand Up @@ -206,6 +207,11 @@ dependencies {
// Bundled model
implementation(libs.mlkit.obj.detection)

implementation(libs.tflite)
implementation(libs.tflite.gpu)
implementation(libs.tflite.metadata)
implementation(libs.tflite.support)

testImplementation(libs.junit)
testImplementation(libs.okhttp.mockwebserver)
testImplementation(libs.coroutines.test)
Expand Down
16 changes: 16 additions & 0 deletions lib/src/main/java/com/smileidentity/compose/SmileIDExt.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import com.smileidentity.compose.document.OrchestratedDocumentVerificationScreen
import com.smileidentity.compose.selfie.OrchestratedSelfieCaptureScreen
import com.smileidentity.compose.theme.colorScheme
import com.smileidentity.compose.theme.typography
import com.smileidentity.compose.transactionfraud.TransactionFraudScreen
import com.smileidentity.models.IdInfo
import com.smileidentity.models.JobType
import com.smileidentity.results.BiometricKycResult
Expand Down Expand Up @@ -434,3 +435,18 @@ fun SmileID.ConsentScreen(
)
}
}

@Composable
fun SmileID.TransactionFraud(
modifier: Modifier = Modifier,
colorScheme: ColorScheme = SmileID.colorScheme,
typography: Typography = SmileID.typography,
onResult: SmileIDCallback<Nothing> = {},
) {
MaterialTheme(colorScheme = colorScheme, typography = typography) {
TransactionFraudScreen(
modifier = modifier,
onResult = onResult,
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
package com.smileidentity.compose.transactionfraud

import android.graphics.Bitmap
import android.os.OperationCanceledException
import androidx.annotation.IntRange
import androidx.annotation.OptIn
import androidx.camera.core.ExperimentalGetImage
import androidx.camera.core.ImageProxy
import androidx.compose.animation.animateColorAsState
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.remember
import androidx.compose.ui.Alignment.Companion.BottomCenter
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import androidx.compose.ui.window.DialogProperties
import androidx.core.graphics.scale
import androidx.lifecycle.ViewModel
import androidx.lifecycle.compose.collectAsStateWithLifecycle
import androidx.lifecycle.viewModelScope
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions
import com.smileidentity.ml.ImQualCp20
import com.smileidentity.results.SmileIDCallback
import com.smileidentity.results.SmileIDResult
import com.smileidentity.util.rotated
import com.ujizin.camposer.CameraPreview
import com.ujizin.camposer.state.CamSelector
import com.ujizin.camposer.state.ImplementationMode
import com.ujizin.camposer.state.ScaleType
import com.ujizin.camposer.state.rememberCamSelector
import com.ujizin.camposer.state.rememberCameraState
import com.ujizin.camposer.state.rememberImageAnalyzer
import kotlinx.coroutines.FlowPreview
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.sample
import kotlinx.coroutines.flow.stateIn
import kotlinx.coroutines.flow.update
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.image.TensorImage
import timber.log.Timber

@Composable
fun TransactionFraudScreen(
modifier: Modifier = Modifier,
onResult: SmileIDCallback<Nothing> = {},
) {
val context = LocalContext.current
val imageQualityModel = remember { ImQualCp20.newInstance(context) }
// TODO: Request Permissions if not granted
Dialog(
onDismissRequest = {
onResult(SmileIDResult.Error(OperationCanceledException("User Cancelled")))
},
properties = DialogProperties(dismissOnBackPress = true, dismissOnClickOutside = false),
) {
TransactionFraudScreen(
imageQualityModel = imageQualityModel,
onResult = onResult,
modifier = modifier
.height(512.dp)
.clip(MaterialTheme.shapes.large),
)
}
}

@Composable
private fun TransactionFraudScreen(
imageQualityModel: ImQualCp20,
modifier: Modifier = Modifier,
onResult: SmileIDCallback<Nothing> = {},
viewModel: TransactionFraudViewModel = viewModel(
initializer = { TransactionFraudViewModel(imageQualityModel) },
),
) {
val uiState by viewModel.uiState.collectAsStateWithLifecycle()
val cameraState = rememberCameraState()
val camSelector by rememberCamSelector(CamSelector.Front)
Box(contentAlignment = BottomCenter, modifier = modifier) {
CameraPreview(
cameraState = cameraState,
camSelector = camSelector,
implementationMode = ImplementationMode.Compatible,
scaleType = ScaleType.FillCenter,
imageAnalyzer = cameraState.rememberImageAnalyzer(analyze = viewModel::analyzeImage),
isImageAnalysisEnabled = true,
modifier = Modifier.fillMaxSize(),
)

val textColor = if (uiState.faceQuality > 50) {
MaterialTheme.colorScheme.tertiary
} else {
MaterialTheme.colorScheme.error
}
Text(
text = "Face Quality\n${uiState.faceQuality}",
textAlign = TextAlign.Center,
color = animateColorAsState(targetValue = textColor, label = "faceQualityText").value,
style = MaterialTheme.typography.displaySmall,
fontWeight = FontWeight.Bold,
modifier = Modifier.padding(vertical = 64.dp),
)
}
}

data class TransactionFraudUiState(
@IntRange(0, 100) val faceQuality: Int = 0,
)

@kotlin.OptIn(FlowPreview::class)
class TransactionFraudViewModel(private val imageQualityModel: ImQualCp20) : ViewModel() {
private val _uiState = MutableStateFlow(TransactionFraudUiState())
val uiState = _uiState.asStateFlow().sample(250).stateIn(
viewModelScope,
SharingStarted.WhileSubscribed(),
TransactionFraudUiState(),
)
private val modelInputSize = intArrayOf(1, 120, 120, 3)
private val faceDetectorOptions = FaceDetectorOptions.Builder().apply {
setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_FAST)
setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_NONE)
setContourMode(FaceDetectorOptions.CONTOUR_MODE_NONE)
setClassificationMode(FaceDetectorOptions.CLASSIFICATION_MODE_ALL)
}.build()

private val faceDetector by lazy { FaceDetection.getClient(faceDetectorOptions) }

@OptIn(ExperimentalGetImage::class)
fun analyzeImage(imageProxy: ImageProxy) {
val image = imageProxy.image ?: run {
Timber.w("ImageProxy has no image")
imageProxy.close()
return
}

val inputImage = InputImage.fromMediaImage(image, imageProxy.imageInfo.rotationDegrees)
faceDetector.process(inputImage).addOnSuccessListener { faces ->
// TODO: Add all the protections
val face = faces.firstOrNull() ?: run {
Timber.w("No face detected")
_uiState.update { it.copy(faceQuality = 0) }
return@addOnSuccessListener
}

val bBox = face.boundingBox

// Check that the corners of the face bounding box are within the inputImage
val faceCornersInImage = bBox.left >= 0 && bBox.right <= inputImage.width &&
bBox.top >= 0 && bBox.bottom <= inputImage.height
if (!faceCornersInImage) {
Timber.w("Face bounding box not within image")
_uiState.update { it.copy(faceQuality = 0) }
return@addOnSuccessListener
}

// face mesh returns 480ish points. take min/max of all those points. use that as
// bounding box
// Check that the corners of the face bounding box are within the inputImage

// returns a matrix, each row is a probability of being a quality
// get 1 row if batch size is 1
// 1st column is the actual quality
// theoretically, 2nd column is 1-(1st_column)

// model is trained on *face mesh* crop (different from face detection potentially)

val startTime = System.nanoTime()
val bitmap = with(imageProxy.toBitmap().rotated(imageProxy.imageInfo.rotationDegrees)) {
if (bBox.left + bBox.width() > this.width) {
Timber.w("Face bounding box width is greater than image width")
_uiState.update { it.copy(faceQuality = 0) }
return@addOnSuccessListener
}

if (bBox.top + bBox.height() > this.height) {
Timber.w("Face bounding box height is greater than image height")
_uiState.update { it.copy(faceQuality = 0) }
return@addOnSuccessListener
}

val croppedBitmap = Bitmap.createBitmap(
this,
bBox.left,
bBox.top,
bBox.width(),
bBox.height(),
// NB! bBox is not guaranteed to be square, so scale might squish the image
).scale(modelInputSize[1], modelInputSize[2], false)
recycle()
return@with croppedBitmap
}

// Image Quality Model Inference
val input = TensorImage(DataType.FLOAT32).apply { load(bitmap) }
val outputs = imageQualityModel.process(input.tensorBuffer)
val output = outputs.outputFeature0AsTensorBuffer.floatArray.firstOrNull() ?: run {
Timber.e("No image quality output")
return@addOnSuccessListener
}

val elapsedTimeMs = (System.nanoTime() - startTime) / 1_000_000
Timber.d("Face Quality: $output (model inference time: $elapsedTimeMs ms)")

_uiState.update { it.copy(faceQuality = (output * 100).toInt()) }
}.addOnFailureListener { exception ->
Timber.e(exception, "Error detecting faces")
_uiState.update { it.copy(faceQuality = 0) }
}.addOnCompleteListener {
// Closing the proxy allows the next image to be delivered to the analyzer
imageProxy.close()
}
}
}
1 change: 0 additions & 1 deletion lib/src/main/java/com/smileidentity/models/PrepUpload.kt
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,4 @@ data class PrepUploadResponse(
@Json(name = "ref_id") val refId: String,
@Json(name = "upload_url") val uploadUrl: String,
@Json(name = "smile_job_id") val smileJobId: String,
@Json(name = "camera_config") val cameraConfig: String?,
)
22 changes: 22 additions & 0 deletions lib/src/main/java/com/smileidentity/util/Util.kt
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,28 @@ internal fun isValidDocumentImage(
uri: Uri?,
) = isImageAtLeast(context, uri, width = 1920, height = 1080)

fun Bitmap.rotated(
rotationDegrees: Int,
flipX: Boolean = false,
flipY: Boolean = false,
): Bitmap {
val matrix = Matrix()

// Rotate the image back to straight.
matrix.postRotate(rotationDegrees.toFloat())

// Mirror the image along the X or Y axis.
matrix.postScale(if (flipX) -1.0f else 1.0f, if (flipY) -1.0f else 1.0f)
val rotatedBitmap =
Bitmap.createBitmap(this, 0, 0, width, height, matrix, true)

// Recycle the old bitmap if it has changed.
if (rotatedBitmap !== this) {
recycle()
}
return rotatedBitmap
}

/**
* Post-processes the image stored in [bitmap] and saves to [file]. The image is scaled to
* [maxOutputSize], but maintains the aspect ratio. The image can also converted to grayscale.
Expand Down
25 changes: 1 addition & 24 deletions lib/src/main/java/com/smileidentity/viewmodel/SelfieViewModel.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package com.smileidentity.viewmodel

import android.graphics.Bitmap
import android.graphics.Matrix
import android.util.Size
import androidx.annotation.OptIn
import androidx.annotation.StringRes
Expand Down Expand Up @@ -33,6 +31,7 @@ import com.smileidentity.util.createLivenessFile
import com.smileidentity.util.createSelfieFile
import com.smileidentity.util.getExceptionHandler
import com.smileidentity.util.postProcessImageBitmap
import com.smileidentity.util.rotated
import kotlinx.collections.immutable.ImmutableMap
import kotlinx.collections.immutable.persistentMapOf
import kotlinx.coroutines.FlowPreview
Expand Down Expand Up @@ -344,26 +343,4 @@ class SelfieViewModel(
fun onFinished(callback: SmileIDCallback<SmartSelfieResult>) {
callback(result!!)
}

private fun Bitmap.rotated(
rotationDegrees: Int,
flipX: Boolean = false,
flipY: Boolean = false,
): Bitmap {
val matrix = Matrix()

// Rotate the image back to straight.
matrix.postRotate(rotationDegrees.toFloat())

// Mirror the image along the X or Y axis.
matrix.postScale(if (flipX) -1.0f else 1.0f, if (flipY) -1.0f else 1.0f)
val rotatedBitmap =
Bitmap.createBitmap(this, 0, 0, width, height, matrix, true)

// Recycle the old bitmap if it has changed.
if (rotatedBitmap !== this) {
recycle()
}
return rotatedBitmap
}
}
1 change: 1 addition & 0 deletions lib/src/main/ml/.gitkeep
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Placeholder for the ml directory. ML models go in this directory
3 changes: 3 additions & 0 deletions lib/src/main/res/values/strings.xml
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@
<!-- Enhanced DocV Screen -->
<string name="si_enhanced_docv_product_name">Enhanced Document Verification</string>

<!-- Transaction Fraud -->
<string name="si_transaction_fraud_product_name">Transaction Fraud</string>

<!-- Generic Errors -->
<string name="si_processing_error_subtitle">This could be because of image quality or internet connectivity</string>
</resources>
Loading

0 comments on commit 668d8f5

Please sign in to comment.