
[Mobile] whisper android sample #18521

Closed
marouanetalaa opened this issue Nov 20, 2023 · 6 comments
Labels
api:Java (issues related to the Java API), platform:mobile (issues related to ONNX Runtime mobile; typically submitted using template)

Comments

@marouanetalaa

Describe the issue

I get this error when I run a model I optimized with ONNX Runtime (for example openai/whisper-base). I then use it in the Whisper cloud example, where I only change model.onnx but keep the same audio.pcm:

```
Error: Unknown input name audio_pcm, expected one of [audio_stream, max_length, min_length, num_beams, num_return_sequences, length_penalty, repetition_penalty]
ai.onnxruntime.OrtException: Unknown input name audio_pcm, expected one of [audio_stream, max_length, min_length, num_beams, num_return_sequences, length_penalty, repetition_penalty]
	at ai.onnxruntime.OrtSession.run(OrtSession.java:284)
	at ai.onnxruntime.OrtSession.run(OrtSession.java:242)
	at ai.onnxruntime.OrtSession.run(OrtSession.java:210)
	at com.alex.newnotes.utils.SpeechRecognizer.run(SpeechRecognizer.kt:40)
	at com.alex.newnotes.utils.AudioTensorSource$Companion.fromRecording(AudioTensorSource.kt:118)
	at com.alex.newnotes.ui.edit.EditFragment.addSpeech$lambda-28(EditFragment.kt:454)
	at com.alex.newnotes.ui.edit.EditFragment.$r8$lambda$imHSLHbY2ngBNrJv8EgIh9aj6Dg(Unknown Source:0)
	at com.alex.newnotes.ui.edit.EditFragment$$ExternalSyntheticLambda11.run(Unknown Source:2)
	at java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:462)
	at java.util.concurrent.FutureTask.run(FutureTask.java:266)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1167)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:641)
	at java.lang.Thread.run(Thread.java:923)
```

SpeechRecognizer.kt:

```kotlin
class SpeechRecognizer(modelBytes: ByteArray) : AutoCloseable {
    private val session: OrtSession
    private val baseInputs: Map<String, OnnxTensor>

    init {
        val env = OrtEnvironment.getEnvironment()
        val sessionOptions = OrtSession.SessionOptions()
        sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath())

        session = env.createSession(modelBytes, sessionOptions)

        val nMels: Long = 80
        val nFrames: Long = 3000

        baseInputs = mapOf(
            "min_length" to createIntTensor(env, intArrayOf(1), tensorShape(1)),
            "max_length" to createIntTensor(env, intArrayOf(200), tensorShape(1)),
            "num_beams" to createIntTensor(env, intArrayOf(1), tensorShape(1)),
            "num_return_sequences" to createIntTensor(env, intArrayOf(1), tensorShape(1)),
            "length_penalty" to createFloatTensor(env, floatArrayOf(1.0f), tensorShape(1)),
            "repetition_penalty" to createFloatTensor(env, floatArrayOf(1.0f), tensorShape(1)),
        )
    }

    data class Result(val text: String, val inferenceTimeInMs: Long)

    fun run(audioTensor: OnnxTensor): Result {
        val inputs = mutableMapOf<String, OnnxTensor>()
        baseInputs.toMap(inputs)
        inputs["audio_pcm"] = audioTensor
        val startTimeInMs = SystemClock.elapsedRealtime()
        val outputs = session.run(inputs)
        val elapsedTimeInMs = SystemClock.elapsedRealtime() - startTimeInMs
        val recognizedText = outputs.use {
            @Suppress("UNCHECKED_CAST")
            (outputs[0].value as Array<Array<String>>)[0][0]
        }
        return Result(recognizedText, elapsedTimeInMs)
    }

    override fun close() {
        baseInputs.values.forEach {
            it.close()
        }
        session.close()
    }
}
```

To reproduce

It is basically the same as the example.

Urgency

urgent

Platform

Android

OS Version

11.0

ONNX Runtime Installation

Built from Source

Compiler Version (if 'Built from Source')

Package Name (if 'Released Package')

None

ONNX Runtime Version or Commit ID

ONNX Runtime API

Java/Kotlin

Architecture

X86

Execution Provider

Default CPU

Execution Provider Library Version

No response

@marouanetalaa marouanetalaa added the platform:mobile issues related to ONNX Runtime mobile; typically submitted using template label Nov 20, 2023
@github-actions github-actions bot added the api:Java issues related to the Java API label Nov 20, 2023
@YUNQIUGUO
Contributor

YUNQIUGUO commented Nov 20, 2023

Can you clarify whether you are following the cloud example or the local one?

Inferring from the error message and the code snippet you provided, it seems like you were following the local one.

So in the local example, we are using the input name "audio_pcm", based on this whisper_cpu_int8_model.onnx.

From your error message, your model is expecting this set of input names:
[audio_stream, max_length, min_length, num_beams, num_return_sequences, length_penalty, repetition_penalty]

which is probably different from the input names that the example model has. You might want to adjust the input tensor name accordingly, for example:
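A minimal sketch of the change in SpeechRecognizer.run() (the name audio_stream is taken from your error message; note that the tensor's data type must also match, see below):

```kotlin
// In SpeechRecognizer.run(): use the input name your model actually declares.
inputs["audio_stream"] = audioTensor  // was: inputs["audio_pcm"] = audioTensor
```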

@marouanetalaa
Author

Sorry for the confusion, I meant to say the local example.

Do you mean that I should change audio_pcm to audio_stream?

When I try that, I get this error:

```
Error: Error code - ORT_INVALID_ARGUMENT - message: Unexpected input data type. Actual: (tensor(float)) , expected: (tensor(uint8))
ai.onnxruntime.OrtException: Error code - ORT_INVALID_ARGUMENT - message: Unexpected input data type. Actual: (tensor(float)) , expected: (tensor(uint8))
	at ai.onnxruntime.OrtSession.run(Native Method)
	at ai.onnxruntime.OrtSession.run(OrtSession.java:301)
	at ai.onnxruntime.OrtSession.run(OrtSession.java:242)
	at ai.onnxruntime.OrtSession.run(OrtSession.java:210)
	at com.alex.newnotes.utils.SpeechRecognizer.run(SpeechRecognizer.kt:40)
	at com.alex.newnotes.utils.AudioTensorSource$Companion.fromRecording(AudioTensorSource.kt:118)
	at com.alex.newnotes.ui.edit.EditFragment.addSpeech$lambda-28(EditFragment.kt:454)
	at com.alex.newnotes.ui.edit.EditFragment.$r8$lambda$imHSLHbY2ngBNrJv8EgIh9aj6Dg(Unknown Source:0)
	at com.alex.newnotes.ui.edit.EditFragment$$ExternalSyntheticLambda11.run(Unknown Source:2)
	at java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:462)
	at java.util.concurrent.FutureTask.run(FutureTask.java:266)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1167)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:641)
	at java.lang.Thread.run(Thread.java:923)
```

@YUNQIUGUO
Contributor

YUNQIUGUO commented Nov 20, 2023

The input/output names and the input/output data types all have to match the ONNX model you are testing with.

The original example uses a model whose input data type is a float tensor, and I am guessing your model is expecting a uint8 tensor.

Could you please check with Netron what inputs/outputs your model has?
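If it does turn out to expect a uint8 byte stream (as the end-to-end models with a built-in audio decoder do), a rough sketch of creating such a tensor in Kotlin could look like this — audioBytes here is a placeholder for the raw bytes of your audio file:

```kotlin
import java.nio.ByteBuffer
import ai.onnxruntime.OnnxJavaType
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment

// Sketch: build a uint8 tensor of shape [1, numBytes] from raw audio bytes,
// matching a uint8[1,?] model input. `audioBytes` is a placeholder for the
// bytes of your audio file/recording.
fun createAudioStreamTensor(env: OrtEnvironment, audioBytes: ByteArray): OnnxTensor {
    return OnnxTensor.createTensor(
        env,
        ByteBuffer.wrap(audioBytes),
        longArrayOf(1, audioBytes.size.toLong()),
        OnnxJavaType.UINT8
    )
}
```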

@YUNQIUGUO YUNQIUGUO changed the title [Mobile] [Mobile] whisper android sample Nov 20, 2023
@marouanetalaa
Author

Yes, you are right: it takes a uint8[1,?] tensor as input and outputs a string[N,text] tensor; internally it uses an AudioDecoder to produce the float PCM.

Is the solution to add the --no_audio_decoder option when generating the model, or is there a way I can change my code to adapt to the model?

@marouanetalaa
Author

marouanetalaa commented Nov 21, 2023

After doing that, I don't get that error anymore. But I am trying to use the --multilingual option, which requires decoder_input_ids that I am not adding, so I get this error:
```
[E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running BeamSearch node. Name:'BeamSearch_node' Status Message: /onnxruntime_src/include/onnxruntime/core/framework/op_kernel_context.h:42 const T *onnxruntime::OpKernelContext::Input(int) const [T = onnxruntime::Tensor] Missing Input: decoder_input_ids
2023-11-21 10:30:40.303  4324-4373  note_id  com.alex.newnotes  E  Error: Error code - ORT_RUNTIME_EXCEPTION - message: Non-zero status code returned while running BeamSearch node. Name:'BeamSearch_node' Status Message: /onnxruntime_src/include/onnxruntime/core/framework/op_kernel_context.h:42 const T *onnxruntime::OpKernelContext::Input(int) const [T = onnxruntime::Tensor] Missing Input: decoder_input_ids

ai.onnxruntime.OrtException: Error code - ORT_RUNTIME_EXCEPTION - message: Non-zero status code returned while running BeamSearch node. Name:'BeamSearch_node' Status Message: /onnxruntime_src/include/onnxruntime/core/framework/op_kernel_context.h:42 const T *onnxruntime::OpKernelContext::Input(int) const [T = onnxruntime::Tensor] Missing Input: decoder_input_ids
	at ai.onnxruntime.OrtSession.run(Native Method)
	at ai.onnxruntime.OrtSession.run(OrtSession.java:301)
	at ai.onnxruntime.OrtSession.run(OrtSession.java:242)
	at ai.onnxruntime.OrtSession.run(OrtSession.java:210)
	at com.alex.newnotes.utils.SpeechRecognizer.run(SpeechRecognizer.kt:40)
	at com.alex.newnotes.utils.AudioTensorSource$Companion.fromRecording(AudioTensorSource.kt:118)
	at com.alex.newnotes.ui.edit.EditFragment.addSpeech$lambda-28(EditFragment.kt:454)
	at com.alex.newnotes.ui.edit.EditFragment.$r8$lambda$imHSLHbY2ngBNrJv8EgIh9aj6Dg(Unknown Source:0)
	at com.alex.newnotes.ui.edit.EditFragment$$ExternalSyntheticLambda11.run(Unknown Source:2)
	at java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:462)
	at java.util.concurrent.FutureTask.run(FutureTask.java:266)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1167)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:641)
	at java.lang.Thread.run(Thread.java:923)
```

I am supposed to do something that should look like this in Python:
```python
import numpy as np
from transformers import AutoConfig, AutoProcessor

model = "openai/whisper-tiny"
config = AutoConfig.from_pretrained(model)
processor = AutoProcessor.from_pretrained(model)

# English transcription
forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")

# forced_decoder_ids is of the format [(1, 50259), (2, 50359), (3, 50363)] and needs
# to be of the format [50258, 50259, 50359, 50363] where 50258 is the start token id
forced_decoder_ids = [config.decoder_start_token_id] + list(map(lambda token: token[1], forced_decoder_ids))

# If you don't want to provide specific decoder input ids, or you want Whisper
# to predict the output language and task, you can set
# forced_decoder_ids = [config.decoder_start_token_id], i.e. [50258].

# decoder input ids
decoder_input_ids = np.array([forced_decoder_ids], dtype=np.int32)
```

But how do I go about it in Kotlin? Is there some documentation about it? Thank you.

@edgchen1
Contributor

To construct an int32 tensor input in Kotlin, you can refer to this example:
https://github.com/microsoft/onnxruntime-inference-examples/blob/174f8bd10b82d65fb6bbf36deb6c5aaf87e8ed6c/mobile/examples/whisper/local/android/app/src/main/java/ai/onnxruntime/example/whisperLocal/SpeechRecognizer.kt#L25

I'm not sure if there's a transformers equivalent in Kotlin. If you know the values you'd like to set for decoder_input_ids, you can set them directly. You could also run the Python code, examine what values processor.get_decoder_prompt_ids(language="english", task="transcribe") produces, and copy those over.
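For illustration, a minimal sketch that reuses the createIntTensor and tensorShape helpers from the sample's SpeechRecognizer.kt and hardcodes the English-transcription ids from your Python snippet above (treat the exact values as placeholders for your model and language; env and inputs are the variables from SpeechRecognizer):

```kotlin
// Sketch: hardcode the forced decoder ids produced by the Python processor for
// language="english", task="transcribe" (values taken from the snippet above;
// 50258 is the decoder start token id). BeamSearch expects int32 with shape [1, N].
val forcedDecoderIds = intArrayOf(50258, 50259, 50359, 50363)
val decoderInputIds = createIntTensor(
    env,
    forcedDecoderIds,
    tensorShape(1, forcedDecoderIds.size.toLong())
)
inputs["decoder_input_ids"] = decoderInputIds
```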
