[Mobile] Memory crash after repeated inference with dynamic shape input #22520

Open
laurenspriem opened this issue Oct 21, 2024 · 8 comments
Labels
api:Java issues related to the Java API memory platform:mobile issues related to ONNX Runtime mobile; typically submitted using template stale issues that have not been addressed in a while; categorized by a bot

Comments

@laurenspriem

Describe the issue

We recently changed the ONNX model used in our production mobile app to include the preprocessing steps, which were previously done separately before inference. Because it is an image model, the model now takes as input an array of the image's raw RGBA bytes, which tends to be a lot of data. Since this change, memory consumption keeps rising as the app performs more inference runs, eventually resulting in a crash.

Is there anything we can do in our Java/Kotlin code, beyond the outputs.close() and inputTensor.close() calls we already make, to make sure memory is properly released? Right now the GC does not seem able to keep up with continuous inference runs.

Please see below for the crash logs. Thank you in advance for any and all help!

```
I/dependent.debug(16354): Background concurrent mark compact GC freed 638KB AllocSpace bytes, 30(193MB) LOS objects, 29% free, 230MB/326MB, paused 762us,6.342ms total 308.671ms
I/dependent.debug(16354): Waiting for a blocking GC Alloc
I/dependent.debug(16354): Background concurrent mark compact GC freed 200KB AllocSpace bytes, 2(104KB) LOS objects, 21% free, 342MB/438MB, paused 453us,7.482ms total 133.554ms
I/dependent.debug(16354): WaitForGcToComplete blocked Alloc on Background for 53.726ms
I/dependent.debug(16354): Starting a blocking GC Alloc
I/dependent.debug(16354): Forcing collection of SoftReferences for 111MB allocation
I/dependent.debug(16354): Starting a blocking GC Alloc
I/dependent.debug(16354): Alloc concurrent mark compact GC freed 52KB AllocSpace bytes, 0(0B) LOS objects, 21% free, 342MB/438MB, paused 454us,3.273ms total 36.865ms
W/dependent.debug(16354): Throwing OutOfMemoryError "Failed to allocate a 117235219 byte allocation with 100630528 free bytes and 169MB until OOM, target footprint 460065312, growth limit 536870912" (VmSize 26713268 kB)
2
I/dependent.debug(16354): Starting a blocking GC Alloc
I/dependent.debug(16354): Alloc concurrent mark compact GC freed 64KB AllocSpace bytes, 0(0B) LOS objects, 21% free, 342MB/438MB, paused 501us,5.569ms total 47.679ms
I/dependent.debug(16354): Forcing collection of SoftReferences for 111MB allocation
I/dependent.debug(16354): Starting a blocking GC Alloc
I/dependent.debug(16354): Alloc concurrent mark compact GC freed 32KB AllocSpace bytes, 0(0B) LOS objects, 21% free, 342MB/438MB, paused 480us,3.199ms total 40.527ms
W/dependent.debug(16354): Throwing OutOfMemoryError "Failed to allocate a 117235219 byte allocation with 100663296 free bytes and 169MB until OOM, target footprint 460065312, growth limit 536870912" (VmSize 26713268 kB)
E/AndroidRuntime(16354): FATAL EXCEPTION: DefaultDispatcher-worker-1
E/AndroidRuntime(16354): Process: io.ente.photos.independent.debug, PID: 16354
E/AndroidRuntime(16354): java.lang.OutOfMemoryError: Failed to allocate a 117235219 byte allocation with 100663296 free bytes and 169MB until OOM, target footprint 460065312, growth limit 536870912
E/AndroidRuntime(16354): at dalvik.system.VMRuntime.newNonMovableArray(Native Method)
E/AndroidRuntime(16354): at java.nio.DirectByteBuffer$MemoryRef.<init>(DirectByteBuffer.java:73)
E/AndroidRuntime(16354): at java.nio.ByteBuffer.allocateDirect(ByteBuffer.java:347)
E/AndroidRuntime(16354): at ai.onnxruntime.OrtUtil.prepareBuffer(OrtUtil.java:507)
E/AndroidRuntime(16354): at ai.onnxruntime.OnnxTensor.createTensor(OnnxTensor.java:754)
E/AndroidRuntime(16354): at ai.onnxruntime.OnnxTensor.createTensor(OnnxTensor.java:610)
E/AndroidRuntime(16354): at ai.onnxruntime.OnnxTensor.createTensor(OnnxTensor.java:589)
E/AndroidRuntime(16354): at io.ente.photos.onnx_dart.OnnxDartPlugin$predict$2.invokeSuspend(OnnxDartPlugin.kt:205)
E/AndroidRuntime(16354): at kotlin.coroutines.jvm.internal.BaseContinuationImpl.resumeWith(ContinuationImpl.kt:33)
E/AndroidRuntime(16354): at kotlinx.coroutines.DispatchedTask.run(DispatchedTask.kt:106)
E/AndroidRuntime(16354): at kotlinx.coroutines.internal.LimitedDispatcher$Worker.run(LimitedDispatcher.kt:115)
E/AndroidRuntime(16354): at kotlinx.coroutines.scheduling.TaskImpl.run(Tasks.kt:100)
E/AndroidRuntime(16354): at kotlinx.coroutines.scheduling.CoroutineScheduler.runSafely(CoroutineScheduler.kt:584)
E/AndroidRuntime(16354): at kotlinx.coroutines.scheduling.CoroutineScheduler$Worker.executeTask(CoroutineScheduler.kt:793)
E/AndroidRuntime(16354): at kotlinx.coroutines.scheduling.CoroutineScheduler$Worker.runWorker(CoroutineScheduler.kt:697)
E/AndroidRuntime(16354): at kotlinx.coroutines.scheduling.CoroutineScheduler$Worker.run(CoroutineScheduler.kt:684)
E/AndroidRuntime(16354): Suppressed: kotlinx.coroutines.internal.DiagnosticCoroutineContextException: [StandaloneCoroutine{Cancelling}@4fa0031, Dispatchers.IO]
I/Process (16354): Sending signal. PID: 16354 SIG: 9
Lost connection to device.
```

To reproduce

  1. Get any model that has dynamic shaped input (let me know if I should share mine, I can do that)
  2. Continuously run inference on the model with different data using the Java API on Android
  3. Watch the memory consumption go up until the app crashes

Urgency

Urgent, as this issue is happening in production, causing crashes and inconvenience for our mobile customers.

Platform

Android

OS Version

Android 14

ONNX Runtime Installation

Released Package

Compiler Version (if 'Built from Source')

No response

Package Name (if 'Released Package')

onnxruntime-android

ONNX Runtime Version or Commit ID

1.18

ONNX Runtime API

Java/Kotlin

Architecture

ARM64

Execution Provider

Default CPU

Execution Provider Library Version

No response

@laurenspriem laurenspriem added the platform:mobile issues related to ONNX Runtime mobile; typically submitted using template label Oct 21, 2024
@github-actions github-actions bot added the api:Java issues related to the Java API label Oct 21, 2024
@Craigacp
Contributor

How are you constructing the tensors? For best performance you should be using a cache of direct ByteBuffers you manage rather than letting the JVM create & garbage collect them as the GC algorithm can get overwhelmed.

@laurenspriem
Author

laurenspriem commented Oct 22, 2024

I created an MRE in this repo.

Regarding your question about constructing the tensors, the relevant code is here:

```kotlin
val env = OrtEnvironment.getEnvironment()
var inputTensorShape: LongArray = longArrayOf(1, 112, 112, 3)
when (modelType) {
    ModelType.ClipImageEncoder -> {
        inputTensorShape = inputShapeArray!!.map { it.toLong() }.toLongArray()
    }
    ModelType.YOLOv5Face -> {
        inputTensorShape = inputShapeArray!!.map { it.toLong() }.toLongArray()
    }
}

var buffer: ByteBuffer = ByteBuffer.allocate(0)
if (inputUint8DataArray != null) {
    buffer = ByteBuffer.wrap(inputUint8DataArray)
}
val inputTensor = OnnxTensor.createTensor(env, buffer, inputTensorShape, OnnxJavaType.UINT8)
val inputs = mutableMapOf<String, OnnxTensor>()
inputs["input"] = inputTensor
val outputs = session.run(inputs)
val outputTensor = (outputs[0].value as Array<FloatArray>)
val flatList = outputTensor.flattenToFloatArray()
withContext(Dispatchers.Main) {
    result.success(flatList)
}
outputs.close()
inputTensor.close()
buffer.clear()
```

I'm not used to writing Kotlin code, so I might be missing something obvious. If so, any pointers on how to solve this would be appreciated! If not, then it's probably a memory management issue in ORT.

@Craigacp
Contributor

You should use ByteBuffer.allocateDirect(<size-of-input-in-bytes>) rather than ByteBuffer.allocate. allocateDirect produces something that ORT can directly consume, if you use allocate then we have to copy it to a direct buffer before we can pass it into the native library. Similarly using OnnxTensor.createTensor(<array-type>) will cause some copying to get the data out of the Java array and into the native library, which is pretty slow as Java multidimensional arrays are pointer heavy.
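To make the distinction concrete, here is a minimal `java.nio`-only sketch with no ORT calls; the helper name `toDirectBuffer` is hypothetical. `ByteBuffer.wrap` and `ByteBuffer.allocate` both return heap buffers (`isDirect` is false), which, per the comment above, ORT has to copy into a direct buffer first. `allocateDirect` returns memory the native library can read in place, so copying the bytes in once with `put` and then calling `flip` leaves the buffer ready to be consumed.

```kotlin
import java.nio.ByteBuffer

// Hypothetical helper: copies raw bytes (e.g. RGBA pixels) into a freshly
// allocated direct buffer. Afterwards position = 0 and limit = data.size,
// so the whole payload is readable. Note that this still allocates a new
// direct buffer on every call.
fun toDirectBuffer(data: ByteArray): ByteBuffer {
    val direct = ByteBuffer.allocateDirect(data.size)
    direct.put(data) // position advances to data.size
    direct.flip()    // limit = data.size, position = 0
    return direct
}
```

This removes ORT's extra copy, but a new input-sized buffer is still created per call, so the GC still has garbage to collect; reusing buffers, as suggested below, is what addresses that.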

If your inputs are always of the same size (or a small set of sizes) then keep around a cache of the buffers, you can rewrite the entries (via put(byte[]), rewind) and either recreate the tensor, or keep the tensor around itself and pass it back in. That will keep you to a fixed memory budget and also speed things up because you won't have any Java side allocations.
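A minimal sketch of such a cache, assuming the inputs come in a small set of byte sizes. `DirectBufferCache`, `fill` and `cachedCount` are hypothetical names. Each distinct size gets exactly one direct buffer, and that buffer is rewritten in place with `put` and `rewind` instead of being reallocated.

```kotlin
import java.nio.ByteBuffer

// Hypothetical cache: one direct buffer per distinct input size (in bytes).
// Not thread-safe: concurrent callers with the same size would share a buffer.
class DirectBufferCache {
    private val buffers = HashMap<Int, ByteBuffer>()

    // Returns the cached buffer for data.size (allocating it only the first
    // time that size is seen), overwritten with data and rewound to position 0.
    fun fill(data: ByteArray): ByteBuffer {
        val buffer = buffers.getOrPut(data.size) { ByteBuffer.allocateDirect(data.size) }
        buffer.clear()   // position = 0, limit = capacity (== data.size)
        buffer.put(data) // overwrite the previous contents in place
        buffer.rewind()  // position back to 0; limit stays at data.size
        return buffer
    }

    // Number of distinct sizes currently cached.
    fun cachedCount(): Int = buffers.size
}
```

The returned buffer is what would then be passed to `OnnxTensor.createTensor(env, buffer, shape, OnnxJavaType.UINT8)`. Memory stays bounded by the number of distinct sizes rather than by the number of inference calls.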

@laurenspriem
Author

> You should use `ByteBuffer.allocateDirect(<size-of-input-in-bytes>)` rather than ByteBuffer.allocate. allocateDirect produces something that ORT can directly consume, if you use allocate then we have to copy it to a direct buffer before we can pass it into the native library.

Thanks for pointing this out! I have changed it to the following:

```kotlin
val buffer: ByteBuffer = ByteBuffer.allocateDirect(inputUint8DataArray!!.size)
buffer.put(inputUint8DataArray)
buffer.flip()
```

Despite this change I'm still seeing a lot of GC work in the logs:

```
I/mple.ort_memory(11873): Waiting for a blocking GC Alloc
I/mple.ort_memory(11873): Background concurrent mark compact GC freed 194KB AllocSpace bytes, 4(191MB) LOS objects, 65% free, 50MB/146MB, paused 344us,3.309ms total 150.397ms
I/mple.ort_memory(11873): WaitForGcToComplete blocked Alloc on Background for 96.850ms
I/mple.ort_memory(11873): Starting a blocking GC Alloc
```

So I'm afraid this change alone doesn't solve the issue.

> Similarly using `OnnxTensor.createTensor(<array-type>)` will cause some copying to get the data out of the Java array and into the native library, which is pretty slow as Java multidimensional arrays are pointer heavy.

I'm not sure I understand this comment. As far as I understood, the ORT Java API only takes byte buffers for creating tensors of this type of data (uint8). Is there some other way of creating the ONNX tensor that I'm not aware of?

> If your inputs are always of the same size (or a small set of sizes) then keep around a cache of the buffers, you can rewrite the entries (via put(byte[]), rewind) and either recreate the tensor, or keep the tensor around itself and pass it back in. That will keep you to a fixed memory budget and also speed things up because you won't have any Java side allocations.

Unfortunately the inputs are very dynamic and can be of any size, so I don't think this would help.

@Craigacp
Contributor

> Despite this change I'm still seeing a lot of GC work in the logs [...] So I'm afraid this change alone doesn't solve the issue.

Lots of GC work just means you're creating a lot of garbage. If you keep passing in large bitmaps allocated in fresh objects then it'll necessarily have to create garbage. You can try to modify your code so you write directly to the buffer from the image source rather than having intermediate arrays, but I don't know what the rest of your codebase looks like.
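A minimal sketch of writing pixels straight into a reused buffer, assuming the image source can emit its pixels row by row. The `RowSource` interface and `fillFromSource` function are hypothetical. On Android, `Bitmap.copyPixelsToBuffer` is one real API that writes a bitmap's pixels directly into a `Buffer`. The destination is cleared, filled in place and flipped, so no full-image intermediate `ByteArray` is created.

```kotlin
import java.nio.ByteBuffer

// Hypothetical image source that writes one row of RGBA pixels
// (width * 4 bytes) into dest at dest's current position.
fun interface RowSource {
    fun writeRow(y: Int, dest: ByteBuffer)
}

// Fills a reused (ideally direct) buffer straight from the source, so no
// full-image intermediate ByteArray is needed on the Java heap.
fun fillFromSource(dest: ByteBuffer, width: Int, height: Int, source: RowSource): ByteBuffer {
    val needed = width.toLong() * height * 4
    require(needed <= dest.capacity()) { "image needs $needed bytes, buffer holds ${dest.capacity()}" }
    dest.clear()
    for (y in 0 until height) source.writeRow(y, dest)
    check(dest.position().toLong() == needed) { "source wrote ${dest.position()} bytes, expected $needed" }
    dest.flip() // position = 0, limit = bytes written
    return dest
}
```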

> Similarly using `OnnxTensor.createTensor(<array-type>)` will cause some copying to get the data out of the Java array and into the native library, which is pretty slow as Java multidimensional arrays are pointer heavy.

> I'm not sure I understand this comment. As far as I understood, the ORT Java API only takes byte buffers for creating tensors of this type of data (uint8). Is there some other way of creating the ONNX tensor that I'm not aware of?

Yeah, there's no way to create a uint8 input aside from a buffer, but if you have other inputs you should not use arrays to create those.
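For a float input, a minimal sketch of building the backing memory without a Java array: allocate a direct buffer and view it as a `FloatBuffer`, which could then be handed to the `OnnxTensor.createTensor(env, FloatBuffer, long[])` overload. Native byte order is assumed here to be the safe choice for the native side; `directFloatBuffer` is a hypothetical helper name.

```kotlin
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.FloatBuffer

// Hypothetical helper: a direct FloatBuffer for `count` floats, backed by
// native-byte-order memory outside the Java array machinery.
fun directFloatBuffer(count: Int): FloatBuffer =
    ByteBuffer.allocateDirect(count * Float.SIZE_BYTES)
        .order(ByteOrder.nativeOrder())
        .asFloatBuffer()
```

As with the byte inputs, allocating such a buffer once and refilling it per call avoids creating new garbage on every inference.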

> If your inputs are always of the same size (or a small set of sizes) then keep around a cache of the buffers, you can rewrite the entries (via put(byte[]), rewind) and either recreate the tensor, or keep the tensor around itself and pass it back in. That will keep you to a fixed memory budget and also speed things up because you won't have any Java side allocations.

> Unfortunately the inputs are very dynamic and can be of any size, so I don't think this would help.

If there's an upper bound on the size then you can allocate buffers of that size, set the limit on them as appropriate for the image you've got and pass it in to tensor construction. ORT doesn't care if the buffer has other stuff in it provided you've set the position and limit correctly.
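A minimal sketch of the upper-bound approach. `BoundedInputBuffer` is a hypothetical name and `maxBytes` is whatever bound fits the app. A single direct buffer is allocated once. Each call writes the new image at the front and sets position to 0 and limit to the image size, so any leftover bytes from an earlier, larger image sit beyond the limit. The shape passed to `createTensor` alongside it should presumably describe exactly `remaining()` bytes.

```kotlin
import java.nio.ByteBuffer

// Hypothetical holder for one direct buffer sized for the largest expected input.
class BoundedInputBuffer(maxBytes: Int) {
    private val buffer: ByteBuffer = ByteBuffer.allocateDirect(maxBytes)

    // Copies data into the front of the shared buffer and sets position = 0,
    // limit = data.size, so only this image's bytes are visible to a reader.
    // Bytes beyond the limit (left over from a larger earlier image) are ignored.
    fun load(data: ByteArray): ByteBuffer {
        require(data.size <= buffer.capacity()) {
            "input of ${data.size} bytes exceeds the ${buffer.capacity()}-byte bound"
        }
        buffer.clear()
        buffer.put(data)
        buffer.flip()
        return buffer
    }
}
```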

@laurenspriem
Author

First of all, I really appreciate all the pointers, thank you so much for your help @Craigacp ! 🙏

> Lots of GC work just means you're creating a lot of garbage. If you keep passing in large bitmaps allocated in fresh objects then it'll necessarily have to create garbage. You can try to modify your code so you write directly to the buffer from the image source rather than having intermediate arrays, but I don't know what the rest of your codebase looks like.

This will be tricky, since the app is in fact a Flutter app where I'm writing a platform plugin to access the Java API. Unfortunately, the data I have in Flutter/Dart can only be passed to Kotlin as a ByteArray, so the conversion to a ByteBuffer will always be needed.

> If there's an upper bound on the size then you can allocate buffers of that size, set the limit on them as appropriate for the image you've got and pass it in to tensor construction. ORT doesn't care if the buffer has other stuff in it provided you've set the position and limit correctly.

This is a great idea, thanks for the suggestion! Unfortunately I cannot seem to get it to work, or at least I'm not able to reduce the GC calls through this strategy. To make sure I'm not making a dumb mistake, here is the code I'm using:

First I initialize permanent direct buffers in the class I'm using:

```kotlin
private val yoloBuffer = ByteBuffer.allocateDirect(5000 * 5000 * 4)
private val clipBuffer = ByteBuffer.allocateDirect(5000 * 5000 * 4)
```
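For scale, each of these buffers is 5000 × 5000 × 4 = 100,000,000 bytes (about 95.4 MiB), so the pair holds about 190.7 MiB for as long as the class lives. In the original crash log, the failed 117235219-byte request (about 111.8 MiB) came from `ByteBuffer.allocateDirect` via `VMRuntime.newNonMovableArray`. That suggests direct buffers on Android are carved out of the managed heap and count against its growth limit of 536870912 bytes (512 MiB). The helper names below are hypothetical; they just make the arithmetic checkable.

```kotlin
// Bytes needed for a width x height RGBA image (4 bytes per pixel).
fun rgbaBytes(width: Long, height: Long): Long = width * height * 4

// Converts a byte count to mebibytes (1 MiB = 1024 * 1024 bytes).
fun toMiB(bytes: Long): Double = bytes / (1024.0 * 1024.0)
```

An image larger than the bound makes `put` throw `BufferOverflowException`, so checking the input size against `capacity()` before copying gives a clearer error.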

Then in my predict function that gets called repeatedly, I keep re-using these buffers:

```kotlin
private fun predict(modelType: ModelType, sessionAddress: Int, inputUint8DataArray: ByteArray? = null, inputShapeArray: IntArray? = null, result: Result) {
    scope.launch {
        val modelState = sessionMap[modelType]
        val session = modelState?.sessionAddresses?.get(sessionAddress)
        if (session == null) {
            withContext(Dispatchers.Main) {
                result.error("SESSION_NOT_FOUND", "Session not found for address: $sessionAddress", null)
            }
            return@launch
        }

        try {
            val env = OrtEnvironment.getEnvironment()
            var inputTensorShape: LongArray = longArrayOf(1, 112, 112, 3)
            var inputTensor: OnnxTensor? = null
            when (modelType) {
                ModelType.ClipImageEncoder -> {
                    inputTensorShape = inputShapeArray!!.map { it.toLong() }.toLongArray()
                    clipBuffer.clear()
                    clipBuffer.put(inputUint8DataArray!!)
                    clipBuffer.flip()
                    inputTensor = OnnxTensor.createTensor(env, clipBuffer, inputTensorShape, OnnxJavaType.UINT8)
                }
                ModelType.YOLOv5Face -> {
                    inputTensorShape = inputShapeArray!!.map { it.toLong() }.toLongArray()
                    yoloBuffer.clear()
                    yoloBuffer.put(inputUint8DataArray!!)
                    yoloBuffer.flip()
                    inputTensor = OnnxTensor.createTensor(env, yoloBuffer, inputTensorShape, OnnxJavaType.UINT8)
                }
            }
            val inputs = mutableMapOf<String, OnnxTensor>()
            inputs["input"] = inputTensor
            val outputs = session.run(inputs)
            val outputTensor = (outputs[0].value as Array<FloatArray>)
            val flatList = outputTensor.flattenToFloatArray()
            withContext(Dispatchers.Main) {
                result.success(flatList)
            }
            outputs.close()
            inputTensor.close()
        } catch (e: OrtException) {
            withContext(Dispatchers.Main) {
                result.error("PREDICTION_ERROR", "Error during prediction: ${e.message} ${e.stackTraceToString()}", null)
            }
        } catch (e: Exception) {
            Log.e(TAG, "Error during prediction: ${e.message}", e)
            withContext(Dispatchers.Main) {
                result.error("UNHANDLED_ERROR", "Error during prediction: ${e.message}", null)
            }
        }
    }
}
```
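The snippet relies on an extension `flattenToFloatArray()` that isn't shown in the issue. A plausible definition is sketched below; it is hypothetical, not the author's code. It copies the rows into one contiguous row-major array in a single pass. Note that the `Array<FloatArray>` obtained from `outputs[0].value` is itself a fresh, pointer-heavy Java-side allocation on every call, on top of the flattened copy.

```kotlin
// Hypothetical implementation of the flattenToFloatArray() helper used above:
// concatenates the rows of a 2-D float array into one row-major FloatArray.
fun Array<FloatArray>.flattenToFloatArray(): FloatArray {
    val result = FloatArray(sumOf { it.size })
    var offset = 0
    for (row in this) {
        row.copyInto(result, destinationOffset = offset)
        offset += row.size
    }
    return result
}
```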

To be honest I'm slowly starting to lose faith that I'll ever be able to get rid of the memory issues, but still motivated to try potential fixes out. So if you have any other ideas please let me know :)

@Craigacp
Contributor

You can supply a buffer as the output tensor too assuming you know the size of the output. That will prevent ORT from allocating memory to hold the output, and also prevent the Java code from allocating a float array to store it. The input side of things looks ok in your example.
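A sketch of the sizing side, assuming the output shape is fully known ahead of time. Dynamic dimensions, which ORT reports as -1, are rejected here. The element count is the product of the dimensions, and one direct, native-order `FloatBuffer` of that size can be allocated once and reused. A tensor wrapping it could then be created with `OnnxTensor.createTensor(env, FloatBuffer, long[])`. Recent versions of the Java API appear to offer an `OrtSession.run` overload that accepts a map of such pre-allocated ("pinned") outputs. Whether that overload is available in 1.18, and how it behaves, should be confirmed in the Javadoc. The helper names below are hypothetical.

```kotlin
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.FloatBuffer

// Number of elements in a tensor of the given, fully known shape.
fun elementCount(shape: LongArray): Long {
    require(shape.all { it >= 0 }) { "shape must be fully known: ${shape.contentToString()}" }
    return shape.fold(1L) { acc, dim -> Math.multiplyExact(acc, dim) }
}

// Allocates one reusable direct float buffer large enough for that output.
fun allocateOutputBuffer(shape: LongArray): FloatBuffer {
    val bytes = Math.multiplyExact(elementCount(shape), Float.SIZE_BYTES.toLong())
    require(bytes <= Int.MAX_VALUE) { "output of $bytes bytes is too large for one buffer" }
    return ByteBuffer.allocateDirect(bytes.toInt())
        .order(ByteOrder.nativeOrder())
        .asFloatBuffer()
}
```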

Contributor

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

@github-actions github-actions bot added the stale issues that have not been addressed in a while; categorized by a bot label Nov 24, 2024