Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v3.0 python #579

Merged
merged 8 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .github/workflows/python-demos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ jobs:
- name: Pre-build dependencies
run: python -m pip install --upgrade pip

# ************** REMOVE AFTER RELEASE ********************
- name: Build binding
run: |
pip install wheel && cd ../../binding/python && python setup.py sdist bdist_wheel && pip install dist/pvrhino-3.0.0-py3-none-any.whl
# ********************************************************

- name: Install dependencies
run: pip install -r requirements.txt

Expand Down Expand Up @@ -77,6 +83,12 @@ jobs:
steps:
- uses: actions/checkout@v3

# ************** REMOVE AFTER RELEASE ********************
- name: Build binding
run: |
pip3 uninstall -y pvrhino && pip3 install wheel && cd ../../binding/python && python3 setup.py sdist bdist_wheel && pip3 install dist/pvrhino-3.0.0-py3-none-any.whl
# ********************************************************

- name: Install dependencies
run: pip3 install -r requirements.txt

Expand Down
92 changes: 84 additions & 8 deletions binding/python/_rhino.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,27 @@


class RhinoError(Exception):
pass
def __init__(self, message: str = '', message_stack: Sequence[str] = None):
super().__init__(message)

self._message = message
self._message_stack = list() if message_stack is None else message_stack

def __str__(self):
message = self._message
if len(self._message_stack) > 0:
message += ':'
for i in range(len(self._message_stack)):
message += '\n [%d] %s' % (i, self._message_stack[i])
return message

@property
def message(self) -> str:
return self._message

@property
def message_stack(self) -> Sequence[str]:
return self._message_stack


class RhinoMemoryError(RhinoError):
Expand Down Expand Up @@ -164,6 +184,22 @@ def __init__(
if not 0.5 <= endpoint_duration_sec <= 5.:
raise ValueError("Endpoint duration should be within [0.5, 5]")

set_sdk_func = library.pv_set_sdk
set_sdk_func.argtypes = [c_char_p]
set_sdk_func.restype = None

set_sdk_func('python'.encode('utf-8'))

self._get_error_stack_func = library.pv_get_error_stack
self._get_error_stack_func.argtypes = [POINTER(POINTER(c_char_p)), POINTER(c_int)]
self._get_error_stack_func.restype = self.PicovoiceStatuses

self._free_error_stack_func = library.pv_free_error_stack
self._free_error_stack_func.argtypes = [POINTER(c_char_p)]
self._free_error_stack_func.restype = None

set_sdk_func('python'.encode('utf-8'))

init_func = library.pv_rhino_init
init_func.argtypes = [
c_char_p,
Expand All @@ -186,7 +222,9 @@ def __init__(
require_endpoint,
byref(self._handle))
if status is not self.PicovoiceStatuses.SUCCESS:
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status]()
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status](
message='Initialization failed',
message_stack=self._get_error_stack())

self._delete_func = library.pv_rhino_delete
self._delete_func.argtypes = [POINTER(self.CRhino)]
Expand Down Expand Up @@ -224,7 +262,9 @@ def __init__(
context_info = c_char_p()
status = context_info_func(self._handle, byref(context_info))
if status is not self.PicovoiceStatuses.SUCCESS:
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status]()
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status](
message='Failed to get context info',
message_stack=self._get_error_stack())

self._context_info = context_info.value.decode('utf-8')

Expand Down Expand Up @@ -259,10 +299,23 @@ def process(self, pcm: Sequence[int]) -> bool:
is_finalized = c_bool()
status = self._process_func(self._handle, (c_short * len(pcm))(*pcm), byref(is_finalized))
if status is not self.PicovoiceStatuses.SUCCESS:
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status]()
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status](
message='Processing failed',
message_stack=self._get_error_stack())

return is_finalized.value

def reset(self) -> None:
"""
Resets the internal state of Rhino. It should be called before the engine can be used to infer intent from a new
stream of audio.
"""
status = self._reset_func(self._handle)
if status is not self.PicovoiceStatuses.SUCCESS:
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status](
message='Processing failed',
message_stack=self._get_error_stack())

def get_inference(self) -> Inference:
"""
Gets inference results from Rhino. If the spoken command was understood, it includes the specific intent name
Expand All @@ -274,7 +327,9 @@ def get_inference(self) -> Inference:
is_understood = c_bool()
status = self._is_understood_func(self._handle, byref(is_understood))
if status is not self.PicovoiceStatuses.SUCCESS:
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status]()
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status](
message='Failed to get inference',
message_stack=self._get_error_stack())
is_understood = is_understood.value

if is_understood:
Expand All @@ -289,7 +344,9 @@ def get_inference(self) -> Inference:
byref(slot_keys),
byref(slot_values))
if status is not self.PicovoiceStatuses.SUCCESS:
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status]()
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status](
message='Failed to get intent',
message_stack=self._get_error_stack())

intent = intent.value.decode('utf-8')

Expand All @@ -299,14 +356,18 @@ def get_inference(self) -> Inference:

status = self._free_slots_and_values_func(self._handle, slot_keys, slot_values)
if status is not self.PicovoiceStatuses.SUCCESS:
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status]()
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status](
message='Failed to clear resources',
message_stack=self._get_error_stack())
else:
intent = None
slots = dict()

status = self._reset_func(self._handle)
if status is not self.PicovoiceStatuses.SUCCESS:
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status]()
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status](
message='Reset failed',
message_stack=self._get_error_stack())

return Inference(is_understood=is_understood, intent=intent, slots=slots)

Expand Down Expand Up @@ -334,6 +395,21 @@ def sample_rate(self) -> int:

return self._sample_rate

def _get_error_stack(self) -> Sequence[str]:
message_stack_ref = POINTER(c_char_p)()
message_stack_depth = c_int()
status = self._get_error_stack_func(byref(message_stack_ref), byref(message_stack_depth))
if status is not self.PicovoiceStatuses.SUCCESS:
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status](message='Unable to get Rhino error state')

message_stack = list()
for i in range(message_stack_depth.value):
message_stack.append(message_stack_ref[i].decode('utf-8'))

self._free_error_stack_func(message_stack_ref)

return message_stack


__all__ = [
'Rhino',
Expand Down
2 changes: 1 addition & 1 deletion binding/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

setuptools.setup(
name="pvrhino",
version="2.2.1",
version="3.0.0",
author="Picovoice",
author_email="[email protected]",
description="Rhino Speech-to-Intent engine.",
Expand Down
81 changes: 69 additions & 12 deletions binding/python/test_rhino.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,18 @@

from parameterized import parameterized

from _rhino import Rhino
from _util import *
from _rhino import Rhino, RhinoError
from test_util import *

within_context_parameters, out_of_context_parameters = load_test_data()


class RhinoTestCase(unittest.TestCase):

def run_rhino(self, language, context_name, is_within_context, intent=None, slots=None):
relative_path = '../..'

rhino = Rhino(
access_key=sys.argv[1],
library_path=pv_library_path(relative_path),
model_path=get_model_path_by_language(relative_path, language),
context_path=get_context_path_by_language(relative_path, context_name, language)
)
@staticmethod
def _process_file_helper(rhino: Rhino, audio_file: str, max_process_count: int = -1) -> bool:
processed = 0

audio_file = get_audio_file_by_language(relative_path, language, is_within_context)
audio = read_wav_file(
audio_file,
rhino.sample_rate)
Expand All @@ -44,6 +36,24 @@ def run_rhino(self, language, context_name, is_within_context, intent=None, slot
is_finalized = rhino.process(frame)
if is_finalized:
break
if max_process_count != -1 and processed >= max_process_count:
break
processed += 1

return is_finalized

def run_rhino(self, language, context_name, is_within_context, intent=None, slots=None):
relative_path = '../..'

rhino = Rhino(
access_key=sys.argv[1],
library_path=pv_library_path(relative_path),
model_path=get_model_path_by_language(relative_path, language),
context_path=get_context_path_by_language(relative_path, context_name, language)
)

audio_file = get_audio_file_by_language(relative_path, language, is_within_context)
is_finalized = self._process_file_helper(rhino, audio_file)

self.assertTrue(is_finalized, "Failed to finalize.")

Expand Down Expand Up @@ -75,6 +85,53 @@ def test_out_of_context(self, language, context_name):
context_name=context_name,
is_within_context=False)

def test_reset(self):
relative_path = '../..'

rhino = Rhino(
access_key=sys.argv[1],
library_path=pv_library_path(relative_path),
model_path=get_model_path_by_language(relative_path, 'en'),
context_path=get_context_path_by_language(relative_path, 'coffee_maker', 'en')
)
audio_file = get_audio_file_by_language(relative_path, 'en', True)

is_finalized = self._process_file_helper(rhino, audio_file, 15)
self.assertFalse(is_finalized)

rhino.reset()
is_finalized = self._process_file_helper(rhino, audio_file)
self.assertTrue(is_finalized)

inference = rhino.get_inference()
self.assertTrue(inference.is_understood)

def test_message_stack(self):
relative_path = '../..'

error = None
try:
_ = Rhino(
access_key='invalid',
library_path=pv_library_path(relative_path),
model_path=get_model_path_by_language(relative_path, 'en'),
context_path=get_context_path_by_language(relative_path, 'smart_lighting', 'en'))
except RhinoError as e:
error = e.message_stack

self.assertIsNotNone(error)
self.assertGreater(len(error), 0)

try:
_ = Rhino(
access_key='invalid',
library_path=pv_library_path(relative_path),
model_path=get_model_path_by_language(relative_path, 'en'),
context_path=get_context_path_by_language(relative_path, 'smart_lighting', 'en'))
except RhinoError as e:
self.assertEqual(len(error), len(e.message_stack))
self.assertListEqual(list(error), list(e.message_stack))


if __name__ == '__main__':
if len(sys.argv) != 2:
Expand Down
2 changes: 1 addition & 1 deletion demo/python/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pvrhino==2.2.1
pvrhino==3.0.0
pvrecorder==1.2.1