Skip to content

Commit

Permalink
format -> black
Browse files Browse the repository at this point in the history
  • Loading branch information
Pouya Rostam committed Jan 24, 2024
1 parent e4e32f8 commit 5af920b
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 203 deletions.
160 changes: 64 additions & 96 deletions binding/python/_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


class EagleError(Exception):
def __init__(self, message: str = '', message_stack: Sequence[str] = None):
def __init__(self, message: str = "", message_stack: Sequence[str] = None):
super().__init__(message)

self._message = message
Expand All @@ -25,9 +25,9 @@ def __init__(self, message: str = '', message_stack: Sequence[str] = None):
def __str__(self):
message = self._message
if len(self._message_stack) > 0:
message += ':'
message += ":"
for i in range(len(self._message_stack)):
message += '\n [%d] %s' % (i, self._message_stack[i])
message += "\n [%d] %s" % (i, self._message_stack[i])
return message

@property
Expand Down Expand Up @@ -109,7 +109,7 @@ class PicovoiceStatuses(Enum):
PicovoiceStatuses.ACTIVATION_ERROR: EagleActivationError,
PicovoiceStatuses.ACTIVATION_LIMIT_REACHED: EagleActivationLimitError,
PicovoiceStatuses.ACTIVATION_THROTTLED: EagleActivationThrottledError,
PicovoiceStatuses.ACTIVATION_REFUSED: EagleActivationRefusedError
PicovoiceStatuses.ACTIVATION_REFUSED: EagleActivationRefusedError,
}


Expand Down Expand Up @@ -144,7 +144,7 @@ def to_bytes(self) -> bytes:
return self._to_bytes(self.handle, self.size)

@classmethod
def from_bytes(cls, profile: bytes) -> 'EagleProfile':
def from_bytes(cls, profile: bytes) -> "EagleProfile":
"""
Creates an instance of EagleProfile from a bytes object.
Expand Down Expand Up @@ -207,7 +207,7 @@ def __init__(self, access_key: str, model_path: str, library_path: str) -> None:
set_sdk_func.argtypes = [c_char_p]
set_sdk_func.restype = None

set_sdk_func('python'.encode('utf-8'))
set_sdk_func("python".encode("utf-8"))

self._get_error_stack_func = library.pv_get_error_stack
self._get_error_stack_func.argtypes = [POINTER(POINTER(c_char_p)), POINTER(c_int)]
Expand All @@ -221,50 +221,37 @@ def __init__(self, access_key: str, model_path: str, library_path: str) -> None:
self._eagle_profiler = POINTER(self.CEagleProfiler)()

init_func = library.pv_eagle_profiler_init
init_func.argtypes = [
c_char_p,
c_char_p,
POINTER(POINTER(self.CEagleProfiler))]
init_func.argtypes = [c_char_p, c_char_p, POINTER(POINTER(self.CEagleProfiler))]
init_func.restype = PicovoiceStatuses

status = init_func(
access_key.encode('utf-8'),
model_path.encode('utf-8'),
byref(self._eagle_profiler))
status = init_func(access_key.encode("utf-8"), model_path.encode("utf-8"), byref(self._eagle_profiler))
if status is not PicovoiceStatuses.SUCCESS:
raise _PICOVOICE_STATUS_TO_EXCEPTION[status](
message='Profile initialization failed',
message_stack=self._get_error_stack())
message="Profile initialization failed", message_stack=self._get_error_stack()
)

speaker_profile_size_func = library.pv_eagle_profiler_export_size
speaker_profile_size_func.argtypes = [
POINTER(self.CEagleProfiler),
POINTER(c_int32)]
speaker_profile_size_func.argtypes = [POINTER(self.CEagleProfiler), POINTER(c_int32)]
speaker_profile_size_func.restype = PicovoiceStatuses

profile_size = c_int32()
status = speaker_profile_size_func(self._eagle_profiler, byref(profile_size))
if status is not PicovoiceStatuses.SUCCESS:
raise _PICOVOICE_STATUS_TO_EXCEPTION[status](
message='Failed to get profile size',
message_stack=self._get_error_stack())
message="Failed to get profile size", message_stack=self._get_error_stack()
)
self._profile_size = profile_size.value

enroll_min_audio_length_sample_func = \
library.pv_eagle_profiler_enroll_min_audio_length_samples
enroll_min_audio_length_sample_func.argtypes = [
POINTER(self.CEagleProfiler),
POINTER(c_int32)]
enroll_min_audio_length_sample_func = library.pv_eagle_profiler_enroll_min_audio_length_samples
enroll_min_audio_length_sample_func.argtypes = [POINTER(self.CEagleProfiler), POINTER(c_int32)]
enroll_min_audio_length_sample_func.restype = PicovoiceStatuses

min_enroll_samples = c_int32()
status = enroll_min_audio_length_sample_func(
self._eagle_profiler,
byref(min_enroll_samples))
status = enroll_min_audio_length_sample_func(self._eagle_profiler, byref(min_enroll_samples))
if status is not PicovoiceStatuses.SUCCESS:
raise _PICOVOICE_STATUS_TO_EXCEPTION[status](
message='Failed to get min audio length sample',
message_stack=self._get_error_stack())
message="Failed to get min audio length sample", message_stack=self._get_error_stack()
)
self._min_enroll_samples = min_enroll_samples.value

self._delete_func = library.pv_eagle_profiler_delete
Expand All @@ -277,17 +264,16 @@ def __init__(self, access_key: str, model_path: str, library_path: str) -> None:
POINTER(c_int16),
c_int32,
POINTER(c_int),
POINTER(c_float)]
POINTER(c_float),
]
self._enroll_func.restype = PicovoiceStatuses

self._reset_func = library.pv_eagle_profiler_reset
self._reset_func.argtypes = [POINTER(self.CEagleProfiler)]
self._reset_func.restype = PicovoiceStatuses

self._export_func = library.pv_eagle_profiler_export
self._export_func.argtypes = [
POINTER(self.CEagleProfiler),
c_void_p]
self._export_func.argtypes = [POINTER(self.CEagleProfiler), c_void_p]
self._export_func.restype = PicovoiceStatuses

self._sample_rate = library.pv_sample_rate()
Expand All @@ -297,7 +283,7 @@ def __init__(self, access_key: str, model_path: str, library_path: str) -> None:
version_func = library.pv_eagle_version
version_func.argtypes = []
version_func.restype = c_char_p
self._version = version_func().decode('utf-8')
self._version = version_func().decode("utf-8")

def enroll(self, pcm: Sequence[int]) -> Tuple[float, EagleProfilerEnrollFeedback]:
"""
Expand Down Expand Up @@ -330,17 +316,12 @@ def enroll(self, pcm: Sequence[int]) -> Tuple[float, EagleProfilerEnrollFeedback

feedback_code = c_int()
percentage = c_float()
status = self._enroll_func(
self._eagle_profiler,
c_pcm,
len(c_pcm),
byref(feedback_code),
byref(percentage))
status = self._enroll_func(self._eagle_profiler, c_pcm, len(c_pcm), byref(feedback_code), byref(percentage))
feedback = EagleProfilerEnrollFeedback(feedback_code.value)
if status is not PicovoiceStatuses.SUCCESS:
raise _PICOVOICE_STATUS_TO_EXCEPTION[status](
message='Enrollment failed',
message_stack=self._get_error_stack())
message="Enrollment failed", message_stack=self._get_error_stack()
)

return percentage.value, feedback

Expand All @@ -353,14 +334,9 @@ def export(self) -> EagleProfile:
"""

profile = (c_byte * self._profile_size)()
status = self._export_func(
self._eagle_profiler,
byref(profile)
)
status = self._export_func(self._eagle_profiler, byref(profile))
if status is not PicovoiceStatuses.SUCCESS:
raise _PICOVOICE_STATUS_TO_EXCEPTION[status](
message='Export failed',
message_stack=self._get_error_stack())
raise _PICOVOICE_STATUS_TO_EXCEPTION[status](message="Export failed", message_stack=self._get_error_stack())

return EagleProfile(cast(profile, c_void_p), self._profile_size)

Expand All @@ -373,8 +349,8 @@ def reset(self) -> None:
status = self._reset_func(self._eagle_profiler)
if status is not PicovoiceStatuses.SUCCESS:
raise _PICOVOICE_STATUS_TO_EXCEPTION[status](
message='Profile reset failed',
message_stack=self._get_error_stack())
message="Profile reset failed", message_stack=self._get_error_stack()
)

def delete(self) -> None:
"""
Expand Down Expand Up @@ -412,11 +388,11 @@ def _get_error_stack(self) -> Sequence[str]:
message_stack_depth = c_int()
status = self._get_error_stack_func(byref(message_stack_ref), byref(message_stack_depth))
if status is not PicovoiceStatuses.SUCCESS:
raise _PICOVOICE_STATUS_TO_EXCEPTION[status](message='Unable to get Eagle error state')
raise _PICOVOICE_STATUS_TO_EXCEPTION[status](message="Unable to get Eagle error state")

message_stack = list()
for i in range(message_stack_depth.value):
message_stack.append(message_stack_ref[i].decode('utf-8'))
message_stack.append(message_stack_ref[i].decode("utf-8"))

self._free_error_stack_func(message_stack_ref)

Expand Down Expand Up @@ -465,7 +441,7 @@ def __init__(
set_sdk_func.argtypes = [c_char_p]
set_sdk_func.restype = None

set_sdk_func('python'.encode('utf-8'))
set_sdk_func("python".encode("utf-8"))

self._get_error_stack_func = library.pv_get_error_stack
self._get_error_stack_func.argtypes = [POINTER(POINTER(c_char_p)), POINTER(c_int)]
Expand All @@ -479,38 +455,31 @@ def __init__(
self._eagle = POINTER(self.CEagle)()

init_func = library.pv_eagle_init
init_func.argtypes = [
c_char_p,
c_char_p,
c_int32,
POINTER(c_void_p),
POINTER(POINTER(self.CEagle))]
init_func.argtypes = [c_char_p, c_char_p, c_int32, POINTER(c_void_p), POINTER(POINTER(self.CEagle))]
init_func.restype = PicovoiceStatuses

profile_bytes = (c_void_p * len(speaker_profiles))()
for i, profile in enumerate(speaker_profiles):
profile_bytes[i] = profile.handle

status = init_func(
access_key.encode('utf-8'),
model_path.encode('utf-8'),
access_key.encode("utf-8"),
model_path.encode("utf-8"),
len(speaker_profiles),
profile_bytes,
byref(self._eagle))
byref(self._eagle),
)
if status is not PicovoiceStatuses.SUCCESS:
raise _PICOVOICE_STATUS_TO_EXCEPTION[status](
message='Initialization failed',
message_stack=self._get_error_stack())
message="Initialization failed", message_stack=self._get_error_stack()
)

self._delete_func = library.pv_eagle_delete
self._delete_func.argtypes = [POINTER(self.CEagle)]
self._delete_func.restype = None

self._process_func = library.pv_eagle_process
self._process_func.argtypes = [
POINTER(self.CEagle),
POINTER(c_int16),
POINTER(c_float)]
self._process_func.argtypes = [POINTER(self.CEagle), POINTER(c_int16), POINTER(c_float)]
self._process_func.restype = PicovoiceStatuses

self._scores = (c_float * len(speaker_profiles))()
Expand All @@ -526,7 +495,7 @@ def __init__(
version_func = library.pv_eagle_version
version_func.argtypes = []
version_func.restype = c_char_p
self._version = version_func().decode('utf-8')
self._version = version_func().decode("utf-8")

def process(self, pcm: Sequence[int]) -> Sequence[float]:
"""
Expand All @@ -541,16 +510,17 @@ def process(self, pcm: Sequence[int]) -> Sequence[float]:

if len(pcm) != self.frame_length:
raise EagleInvalidArgumentError(
"Length of input frame %d does not match required frame length %d" % (len(pcm), self.frame_length))
"Length of input frame %d does not match required frame length %d" % (len(pcm), self.frame_length)
)

frame_type = c_int16 * self.frame_length
pcm = frame_type(*pcm)

status = self._process_func(self._eagle, pcm, self._scores)
if status is not PicovoiceStatuses.SUCCESS:
raise _PICOVOICE_STATUS_TO_EXCEPTION[status](
message='Process failed',
message_stack=self._get_error_stack())
message="Process failed", message_stack=self._get_error_stack()
)

# noinspection PyTypeChecker
return [float(score) for score in self._scores]
Expand All @@ -564,9 +534,7 @@ def reset(self) -> None:

status = self._reset_func(self._eagle)
if status is not PicovoiceStatuses.SUCCESS:
raise _PICOVOICE_STATUS_TO_EXCEPTION[status](
message='Reset failed',
message_stack=self._get_error_stack())
raise _PICOVOICE_STATUS_TO_EXCEPTION[status](message="Reset failed", message_stack=self._get_error_stack())

def delete(self) -> None:
"""
Expand Down Expand Up @@ -604,32 +572,32 @@ def _get_error_stack(self) -> Sequence[str]:
message_stack_depth = c_int()
status = self._get_error_stack_func(byref(message_stack_ref), byref(message_stack_depth))
if status is not PicovoiceStatuses.SUCCESS:
raise _PICOVOICE_STATUS_TO_EXCEPTION[status](message='Unable to get Eagle error state')
raise _PICOVOICE_STATUS_TO_EXCEPTION[status](message="Unable to get Eagle error state")

message_stack = list()
for i in range(message_stack_depth.value):
message_stack.append(message_stack_ref[i].decode('utf-8'))
message_stack.append(message_stack_ref[i].decode("utf-8"))

self._free_error_stack_func(message_stack_ref)

return message_stack


__all__ = [
'Eagle',
'EagleProfile',
'EagleProfiler',
'EagleProfilerEnrollFeedback',
'EagleActivationError',
'EagleActivationLimitError',
'EagleActivationRefusedError',
'EagleActivationThrottledError',
'EagleError',
'EagleInvalidArgumentError',
'EagleInvalidStateError',
'EagleIOError',
'EagleKeyError',
'EagleMemoryError',
'EagleRuntimeError',
'EagleStopIterationError',
"Eagle",
"EagleProfile",
"EagleProfiler",
"EagleProfilerEnrollFeedback",
"EagleActivationError",
"EagleActivationLimitError",
"EagleActivationRefusedError",
"EagleActivationThrottledError",
"EagleError",
"EagleInvalidArgumentError",
"EagleInvalidStateError",
"EagleIOError",
"EagleKeyError",
"EagleMemoryError",
"EagleRuntimeError",
"EagleStopIterationError",
]
15 changes: 4 additions & 11 deletions binding/python/_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@

from typing import Optional, Sequence, Union

from ._eagle import (
Eagle,
EagleProfile,
EagleProfiler
)
from ._eagle import Eagle, EagleProfile, EagleProfiler
from ._util import default_library_path, default_model_path


Expand Down Expand Up @@ -74,13 +70,10 @@ def create_profiler(
if library_path is None:
library_path = default_library_path()

return EagleProfiler(
access_key=access_key,
model_path=model_path,
library_path=library_path)
return EagleProfiler(access_key=access_key, model_path=model_path, library_path=library_path)


__all__ = [
'create_recognizer',
'create_profiler'
"create_recognizer",
"create_profiler"
]
Loading

0 comments on commit 5af920b

Please sign in to comment.