Skip to content

Commit

Permalink
Add input box in android demo so that users can specify their keywords
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool committed Jan 11, 2024
1 parent c02284c commit d63cc88
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import android.os.Bundle
import android.text.method.ScrollingMovementMethod
import android.util.Log
import android.widget.Button
import android.widget.EditText
import android.widget.TextView
import androidx.appcompat.app.AppCompatActivity
import androidx.core.app.ActivityCompat
Expand All @@ -25,6 +26,7 @@ class MainActivity : AppCompatActivity() {
private var audioRecord: AudioRecord? = null
private lateinit var recordButton: Button
private lateinit var textView: TextView
private lateinit var inputText: EditText
private var recordingThread: Thread? = null

private val audioSource = MediaRecorder.AudioSource.MIC
Expand Down Expand Up @@ -74,6 +76,8 @@ class MainActivity : AppCompatActivity() {

textView = findViewById(R.id.my_text)
textView.movementMethod = ScrollingMovementMethod()

inputText = findViewById(R.id.input_text)
}

private fun onclick() {
Expand All @@ -91,6 +95,14 @@ class MainActivity : AppCompatActivity() {
lastText = ""
idx = 0

var keywords = inputText.text.toString()
Log.i(TAG, keywords)
keywords = keywords.replace("\n", "/")
val status = model.setKeywords(keywords)
if (!status) {
Log.i(TAG, "Failed to setKeywords.")
}

recordingThread = thread(true) {
processSamples()
}
Expand Down Expand Up @@ -129,9 +141,9 @@ class MainActivity : AppCompatActivity() {
if (lastText.isBlank()) {
textToDisplay = "${idx}: ${text}"
} else {
textToDisplay = "${lastText}\n${idx}: ${text}"
textToDisplay = "${idx}: ${text}\n${lastText}"
}
lastText = "${lastText}\n${idx}: ${text}"
lastText = "${idx}: ${text}\n${lastText}"
idx += 1
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class SherpaOnnxKws(
fun inputFinished() = inputFinished(ptr)
fun decode() = decode(ptr)
fun isReady(): Boolean = isReady(ptr)
fun setKeywords(keywords: String): Boolean = setKeywords(ptr, keywords)

val keyword: String
get() = getKeyword(ptr)
Expand All @@ -75,6 +76,7 @@ class SherpaOnnxKws(
private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
private external fun inputFinished(ptr: Long)
private external fun getKeyword(ptr: Long): String
private external fun setKeywords(ptr: Long, keywords: String): Boolean
private external fun decode(ptr: Long)
private external fun isReady(ptr: Long): Boolean

Expand Down
18 changes: 13 additions & 5 deletions android/SherpaOnnxKws/app/src/main/res/layout/activity_main.xml
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,34 @@
android:gravity="center"
android:orientation="vertical">

<EditText
android:id="@+id/input_text"
android:layout_width="match_parent"
android:layout_height="320dp"
android:layout_weight="2.5"
android:hint="@string/keyword_hint"
android:scrollbars="vertical"
android:text=""
android:textSize="15dp" />

<TextView
android:id="@+id/my_text"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_height="443dp"
android:layout_weight="2.5"
android:padding="24dp"
android:scrollbars="vertical"
android:singleLine="false"
android:text="@string/hint"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toTopOf="parent" />
android:textSize="15dp" />

<Button
android:id="@+id/record_button"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_weight="0.5"
android:text="@string/start" />

</LinearLayout>


Expand Down
1 change: 1 addition & 0 deletions android/SherpaOnnxKws/app/src/main/res/values/strings.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
The source code and pre-trained models are publicly available.
Please see https://github.com/k2-fsa/sherpa-onnx for details.
</string>
<string name="keyword_hint">Input your keywords here, one keyword perline.</string>
<string name="start">Start</string>
<string name="stop">Stop</string>
</resources>
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/keyword-spotter-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class KeywordSpotterImpl {

virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;

virtual std::unique_ptr<OnlineStream> CreateStream(
const std::string& keywords) const = 0;

virtual bool IsReady(OnlineStream *s) const = 0;

virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
Expand Down
64 changes: 64 additions & 0 deletions sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,70 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
return stream;
}

std::unique_ptr<OnlineStream> CreateStream(
const std::string& keywords) const override {
auto kws = std::regex_replace(keywords, std::regex("/"), "\n");
std::istringstream is(kws);

std::vector<std::vector<int32_t>> current_ids;
std::vector<std::string> current_kws;
std::vector<float> current_scores;
std::vector<float> current_thresholds;

if (!EncodeKeywords(is, sym_, &current_ids, &current_kws, &current_scores,
&current_thresholds)) {
SHERPA_ONNX_LOGE("Encode keywords failed.");
return nullptr;
}

int32_t num_kws = current_ids.size();
int32_t num_default_kws = keywords_id_.size();

current_ids.insert(current_ids.end(), keywords_id_.begin(), keywords_id_.end());

if (!current_kws.empty() && !keywords_.empty()) {
current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end());
} else if (!current_kws.empty() && keywords_.empty()) {
current_kws.insert(current_kws.end(), num_default_kws, std::string());
} else if (current_kws.empty() && !keywords_.empty()) {
current_kws.insert(current_kws.end(), num_kws, std::string());
current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end());
} else {
// Do nothing.
}

if (!current_scores.empty() && !boost_scores_.empty()) {
current_scores.insert(current_scores.end(), boost_scores_.begin(), boost_scores_.end());
} else if (!current_scores.empty() && boost_scores_.empty()) {
current_scores.insert(current_scores.end(), num_default_kws, config_.keywords_score);
} else if (current_scores.empty() && !boost_scores_.empty()) {
current_scores.insert(current_scores.end(), num_kws, config_.keywords_score);
current_scores.insert(current_scores.end(), boost_scores_.begin(), boost_scores_.end());
} else {
// Do nothing.
}

if (!current_thresholds.empty() && !thresholds_.empty()) {
current_thresholds.insert(current_thresholds.end(), thresholds_.begin(), thresholds_.end());
} else if (!current_thresholds.empty() && thresholds_.empty()) {
current_thresholds.insert(current_thresholds.end(), num_default_kws, config_.keywords_threshold);
} else if (current_thresholds.empty() && !thresholds_.empty()) {
current_thresholds.insert(current_thresholds.end(), num_kws, config_.keywords_threshold);
current_thresholds.insert(current_thresholds.end(), thresholds_.begin(), thresholds_.end());
} else {
// Do nothing.
}

auto keywords_graph = std::make_shared<ContextGraph>(
current_ids, config_.keywords_score, config_.keywords_threshold,
current_scores, current_kws, current_thresholds);

auto stream =
std::make_unique<OnlineStream>(config_.feat_config, keywords_graph);
InitOnlineStream(stream.get());
return stream;
}

bool IsReady(OnlineStream *s) const override {
return s->GetNumProcessedFrames() + model_->ChunkSize() <
s->NumFramesReady();
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/csrc/keyword-spotter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ std::unique_ptr<OnlineStream> KeywordSpotter::CreateStream() const {
return impl_->CreateStream();
}

std::unique_ptr<OnlineStream> KeywordSpotter::CreateStream(
const std::string& keywords) const {
return impl_->CreateStream(keywords);
}

bool KeywordSpotter::IsReady(OnlineStream *s) const {
return impl_->IsReady(s);
}
Expand Down
12 changes: 12 additions & 0 deletions sherpa-onnx/csrc/keyword-spotter.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,18 @@ class KeywordSpotter {
*/
std::unique_ptr<OnlineStream> CreateStream() const;

/** Create a stream for decoding.
*
* @param The keywords for this string, it might contain several keywords,
* the keywords are separated by "/". In each of the keywords, there
* are cjkchars or bpes, the bpe/cjkchar are separated by space (" ").
* For example, keywords I LOVE YOU and HELLO WORLD, looks like:
*
* "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
*/
std::unique_ptr<OnlineStream> CreateStream(
const std::string &keywords) const;

/**
* Return true if the given stream has enough frames for decoding.
* Return false otherwise
Expand Down
23 changes: 23 additions & 0 deletions sherpa-onnx/jni/jni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,16 @@ class SherpaOnnxKws {
stream_->InputFinished();
}

bool SetKeywords(const std::string& keywords) {
auto stream = keyword_spotter_.CreateStream(keywords);
if (stream == nullptr) {
return false;
} else {
stream_ = std::move(stream);
return true;
}
}

std::string GetKeyword() const {
auto result = keyword_spotter_.GetResult(stream_.get());
return result.keyword;
Expand Down Expand Up @@ -1251,6 +1261,19 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getKeyword(
return env->NewStringUTF(text.c_str());
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_setKeywords(
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {

const char *p_keywords = env->GetStringUTFChars(keywords, nullptr);

std::string keywords_str = p_keywords;

bool status = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->SetKeywords(keywords_str);
env->ReleaseStringUTFChars(keywords, p_keywords);
return status;
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getTokens(JNIEnv *env, jobject /*obj*/,
Expand Down

0 comments on commit d63cc88

Please sign in to comment.