diff --git a/python-api-examples/non_streaming_server.py b/python-api-examples/non_streaming_server.py index 327fead8d..902f658c4 100755 --- a/python-api-examples/non_streaming_server.py +++ b/python-api-examples/non_streaming_server.py @@ -369,6 +369,7 @@ def check_args(args): if args.hotwords_file != "": assert args.decoding_method == "modified_beam_search", args.decoding_method + assert Path(args.hotwords_file).is_file(), args.hotwords_file def get_args(): diff --git a/sherpa-onnx/python/tests/test_text2token.py b/sherpa-onnx/python/tests/test_text2token.py index bad95288d..5f81449b1 100644 --- a/sherpa-onnx/python/tests/test_text2token.py +++ b/sherpa-onnx/python/tests/test_text2token.py @@ -7,6 +7,7 @@ # ctest --verbose -R test_text2token_py import unittest +from pathlib import Path import sherpa_onnx @@ -18,12 +19,23 @@ class TestText2Token(unittest.TestCase): def test_bpe(self): + tokens = f"{d}/text2token/tokens_en.txt" + bpe_model = f"{d}/text2token/bpe_en.model" + + if not Path(tokens).is_file() or not Path(bpe_model).is_file(): + print( + f"No test data found, skipping test_bpe().\n" + f"You can download the test data by: \n" + f"git clone git@github.com:pkufool/sherpa-test-data.git /tmp/sherpa-test-data" + ) + return + texts = ["HELLO WORLD", "I LOVE YOU"] encoded_texts = sherpa_onnx.text2token( texts, - tokens=f"{d}/text2token/tokens_en.txt", + tokens=tokens, tokens_type="bpe", - bpe_model=f"{d}/text2token/bpe_en.model", + bpe_model=bpe_model, ) assert encoded_texts == [ ["▁HE", "LL", "O", "▁WORLD"], @@ -32,17 +44,27 @@ def test_bpe(self): encoded_ids = sherpa_onnx.text2token( texts, - tokens=f"{d}/text2token/tokens_en.txt", + tokens=tokens, tokens_type="bpe", - bpe_model=f"{d}/text2token/bpe_en.model", + bpe_model=bpe_model, output_ids=True, ) assert encoded_ids == [[22, 58, 24, 425], [19, 370, 47]], encoded_ids def test_cjkchar(self): + tokens = f"{d}/text2token/tokens_cn.txt" + + if not Path(tokens).is_file(): + print( + f"No test data found, skipping test_cjkchar().\n" + f"You can download the test data by: \n" + f"git clone git@github.com:pkufool/sherpa-test-data.git /tmp/sherpa-test-data" + ) + return + texts = ["世界人民大团结", "中国 VS 美国"] encoded_texts = sherpa_onnx.text2token( - texts, tokens=f"{d}/text2token/tokens_cn.txt", tokens_type="cjkchar" + texts, tokens=tokens, tokens_type="cjkchar" ) assert encoded_texts == [ ["世", "界", "人", "民", "大", "团", "结"], @@ -50,7 +72,7 @@ def test_cjkchar(self): ], encoded_texts encoded_ids = sherpa_onnx.text2token( texts, - tokens=f"{d}/text2token/tokens_cn.txt", + tokens=tokens, tokens_type="cjkchar", output_ids=True, ) @@ -60,12 +82,23 @@ def test_cjkchar(self): ], encoded_ids def test_cjkchar_bpe(self): + tokens = f"{d}/text2token/tokens_mix.txt" + bpe_model = f"{d}/text2token/bpe_mix.model" + + if not Path(tokens).is_file() or not Path(bpe_model).is_file(): + print( + f"No test data found, skipping test_cjkchar_bpe().\n" + f"You can download the test data by: \n" + f"git clone git@github.com:pkufool/sherpa-test-data.git /tmp/sherpa-test-data" + ) + return + texts = ["世界人民 GOES TOGETHER", "中国 GOES WITH 美国"] encoded_texts = sherpa_onnx.text2token( texts, - tokens=f"{d}/text2token/tokens_mix.txt", + tokens=tokens, tokens_type="cjkchar+bpe", - bpe_model=f"{d}/text2token/bpe_mix.model", + bpe_model=bpe_model, ) assert encoded_texts == [ ["世", "界", "人", "民", "▁GO", "ES", "▁TOGETHER"], @@ -73,9 +106,9 @@ def test_cjkchar_bpe(self): ], encoded_texts encoded_ids = sherpa_onnx.text2token( texts, - tokens=f"{d}/text2token/tokens_mix.txt", + tokens=tokens, tokens_type="cjkchar+bpe", - bpe_model=f"{d}/text2token/bpe_mix.model", + bpe_model=bpe_model, output_ids=True, ) assert encoded_ids == [