diff --git a/docs/pyartnet.rst b/docs/pyartnet.rst index f631e8f..90c024b 100644 --- a/docs/pyartnet.rst +++ b/docs/pyartnet.rst @@ -164,10 +164,10 @@ Example channel = universe.add_channel(start=1, width=3) # set quadratic correction for the whole universe to quadratic - universe.set_output_correction(output_correction.quadratic) + universe.set_output_correction(output_correction.Quadratic()) # Explicitly set output for this channel to linear - channel.set_output_correction(output_correction.linear) + channel.set_output_correction(output_correction.Linear()) # Remove output correction for the channel. # The channel will now use the correction from the universe again diff --git a/readme.md b/readme.md index 576fe65..f445ad2 100644 --- a/readme.md +++ b/readme.md @@ -54,3 +54,7 @@ Docs and examples can be found [here](https://pyartnet.readthedocs.io/en/latest/ - added more and better type hints - switched to pytest - small fixes + +--- + +`Art-Netâ„¢ Designed by and Copyright Artistic Licence Engineering Ltd` diff --git a/src/pyartnet/base/channel.py b/src/pyartnet/base/channel.py index f1e3116..ae6d323 100644 --- a/src/pyartnet/base/channel.py +++ b/src/pyartnet/base/channel.py @@ -6,12 +6,12 @@ from pyartnet.errors import ChannelOutOfUniverseError, ChannelValueOutOfBoundsError, \ ChannelWidthError, ValueCountDoesNotMatchChannelWidthError -from pyartnet.output_correction import linear from ..fades import FadeBase, LinearFade from .channel_fade import ChannelBoundFade from .output_correction import OutputCorrection from .universe import BaseUniverse +from ..output_correction import Correction, Linear log = logging.getLogger('pyartnet.Channel') @@ -68,7 +68,7 @@ def __init__(self, universe: BaseUniverse, self._parent_universe: Final = universe self._parent_node: Final = universe._node - self._correction_current: Callable[[float, int], float] = linear + self._correction_current: Correction = Linear() # Fade self._current_fade: Optional[ChannelBoundFade] = None @@ -78,10 +78,12 @@ def __init__(self, universe: BaseUniverse, # --------------------------------------------------------------------- # Callbacks self.callback_fade_finished: Optional[Callable[[Channel], Any]] = None + self.callback_values_updated: Optional[Callable[[array[int]], None]] = None + def _apply_output_correction(self): # default correction is linear - self._correction_current = linear + self._correction_current = Linear() # inherit correction if it is not set first from universe and then from the node for obj in (self, self._parent_universe, self._parent_node): @@ -113,7 +115,7 @@ def set_values(self, values: Iterable[Union[int, float]]): raise ChannelValueOutOfBoundsError(f'Channel value out of bounds! 0 <= {val} <= {value_max:d}') self._values_raw[i] = raw_new - act_new = round(correction(val, value_max)) if correction is not linear else raw_new + act_new = round(correction.correct(val, value_max)) if not isinstance(correction, Linear) else raw_new if self._values_act[i] != act_new: changed = True self._values_act[i] = act_new @@ -137,6 +139,55 @@ def to_buffer(self, buf: bytearray): start += byte_size return self + def from_buffer(self, buf: bytearray): + byte_order = self._byte_order + byte_size = self._byte_size + correction = self._correction_current + value_max = self._value_max + + start_index = self._start + end_index = self._stop + 1 + + byte_chunks = Channel.__chunks(buf[start_index:end_index], byte_size) + + values_act = array( + 'i', [int.from_bytes(byte_chunk, byte_order, signed=False) + if len(byte_chunk) == byte_size else None + for byte_chunk in byte_chunks] + ) + + changed = False + for act_value_index, act_value in enumerate(values_act): + if act_value is None: + log.warning(f"Channel {start_index + act_value_index} was updated externally, but is part of an " + f"incomplete {byte_size} byte number. This is very likely unintended by the external " + f"controller.") + continue + + if self._values_act[act_value_index] == values_act[act_value_index]: + continue + + self._values_act[act_value_index] = values_act[act_value_index] + changed = True + + if not changed: + return + + self._values_act = values_act + + values_raw = [round(correction.reverse_correct(val, value_max)) for val in values_act] + for raw_value_index, raw_value in enumerate(values_raw): + self._values_raw[raw_value_index] = raw_value + + if self.callback_values_updated is not None: + self.callback_values_updated(self._values_raw) + + @staticmethod + def __chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i:i + n] + # noinspection PyProtectedMember def add_fade(self, values: Iterable[Union[int, FadeBase]], duration_ms: int, fade_class: Type[FadeBase] = LinearFade): diff --git a/src/pyartnet/base/universe.py b/src/pyartnet/base/universe.py index 5718ca9..d25dde7 100644 --- a/src/pyartnet/base/universe.py +++ b/src/pyartnet/base/universe.py @@ -29,6 +29,18 @@ def __init__(self, node: 'pyartnet.base.BaseNode', universe: int = 0): self._channels: Dict[str, 'pyartnet.base.Channel'] = {} + @property + def universe(self): + return self._universe + + @property + def data(self): + return self._data + + @property + def data_changed(self): + return self._data_changed + def _apply_output_correction(self): for c in self._channels.values(): c._apply_output_correction() @@ -49,6 +61,12 @@ def send_data(self): self._last_send = monotonic() self._data_changed = False + def receive_data(self, data: bytearray): + channels = self._channels + + for channel in channels.values(): + channel.from_buffer(data) + def get_channel(self, channel_name: str) -> 'pyartnet.base.Channel': """Return a channel by name or raise an exception diff --git a/src/pyartnet/output_correction.py b/src/pyartnet/output_correction.py index 7181cb2..4fd0102 100644 --- a/src/pyartnet/output_correction.py +++ b/src/pyartnet/output_correction.py @@ -1,18 +1,42 @@ -def linear(val: float, max_val: int = 0xFF) -> float: - """linear output correction""" - return val +class Correction: + def correct(self, val: float, max_val: int = 0xFF) -> float: + raise NotImplementedError() + def reverse_correct(self, val: float, max_val: int = 0xFF) -> float: + raise NotImplementedError() -def quadratic(val: float, max_val: int = 0xFF) -> float: - """Quadratic output correction""" - return (val ** 2) / max_val +class Linear(Correction): -def cubic(val: float, max_val: int = 0xFF) -> float: - """Cubic output correction""" - return (val ** 3) / (max_val ** 2) + def correct(self, val: float, max_val: int = 0xFF): + return val + def reverse_correct(self, val: float, max_val: int = 0xFF): + return val -def quadruple(val: float, max_val: int = 0xFF) -> float: - """Quadruple output correction""" - return (val ** 4) / (max_val ** 3) + +class Quadratic(Correction): + + def correct(self, val: float, max_val: int = 0xFF): + return val ** 2 / max_val + + def reverse_correct(self, val: float, max_val: int = 0xFF): + return val ** (1. / 2.) * max_val ** (1. / 2.) + + +class Cubic(Correction): + + def correct(self, val: float, max_val: int = 0xFF): + return val ** 3 / max_val ** 2 + + def reverse_correct(self, val: float, max_val: int = 0xFF): + return val ** (1. / 3.) * max_val ** (2. / 3.) + + +class Quadruple(Correction): + + def correct(self, val: float, max_val: int = 0xFF): + return val ** 4 / max_val ** 3 + + def reverse_correct(self, val: float, max_val: int = 0xFF): + return val ** (1. / 4.) * max_val ** (3. / 4.) diff --git a/tests/test_output_correction.py b/tests/test_output_correction.py index e37e7a0..827228e 100644 --- a/tests/test_output_correction.py +++ b/tests/test_output_correction.py @@ -1,11 +1,11 @@ import pytest -from pyartnet.output_correction import cubic, quadratic, quadruple +from pyartnet.output_correction import Quadratic, Quadruple, Cubic @pytest.mark.parametrize('max_val', [ pytest.param(k, id=f'{k:X}') for k in (0xFF, 0xFFFF, 0xFFFFFF, 0xFFFFFFFF, 0xFFFFFFFFFF)]) -@pytest.mark.parametrize('corr', [quadratic, quadruple, cubic]) +@pytest.mark.parametrize('corr', [Quadratic(), Quadruple(), Cubic()]) def test_correction(corr, max_val): - assert corr(0, max_val=max_val) == 0 - assert corr(max_val, max_val=max_val) == max_val + assert corr.correct(0, max_val=max_val) == 0 + assert corr.correct(max_val, max_val=max_val) == max_val