Skip to content

Commit

Permalink
fix: fix callback of throttled/debounced decorated functions with mis…
Browse files Browse the repository at this point in the history
…matched args (#184)

* fix: fix throttled inspection

* build: change typing-ext deps

* fix: use inspect.signature

* use get_max_args

* fix: fix typing
  • Loading branch information
tlambert03 authored Aug 17, 2023
1 parent 1da26ce commit 64dfb43
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 71 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ jobs:
run: |
python -m pip install -U pip
python -m pip install -e .[test,pyqt5]
python -m pip install qtpy==1.1.0 typing-extensions==3.10.0.0
python -m pip install qtpy==1.1.0 typing-extensions==3.7.4.3
- name: Test
uses: aganders3/[email protected]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ dependencies = [
"packaging",
"pygments>=2.4.0",
"qtpy>=1.1.0",
"typing-extensions",
"typing-extensions >=3.7.4.3,!=3.10.0.0",
]

# extras
Expand Down
126 changes: 57 additions & 69 deletions src/superqt/utils/_throttler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,19 @@
SOFTWARE.
"""
import sys
from __future__ import annotations

from concurrent.futures import Future
from enum import IntFlag, auto
from functools import wraps
from typing import TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union, overload
from typing import TYPE_CHECKING, Callable, Generic, TypeVar, overload

from qtpy.QtCore import QObject, Qt, QTimer, Signal

from ._util import get_max_args

if TYPE_CHECKING:
from qtpy.QtCore import SignalInstance
from typing_extensions import Literal, ParamSpec
from typing_extensions import ParamSpec

P = ParamSpec("P")
# maintain runtime compatibility with older typing_extensions
Expand Down Expand Up @@ -70,7 +72,7 @@ def __init__(
self,
kind: Kind,
emissionPolicy: EmissionPolicy,
parent: Optional[QObject] = None,
parent: QObject | None = None,
) -> None:
super().__init__(parent)

Expand Down Expand Up @@ -166,7 +168,7 @@ class QSignalThrottler(GenericSignalThrottler):
def __init__(
self,
policy: EmissionPolicy = EmissionPolicy.Leading,
parent: Optional[QObject] = None,
parent: QObject | None = None,
) -> None:
super().__init__(Kind.Throttler, policy, parent)

Expand All @@ -181,38 +183,52 @@ class QSignalDebouncer(GenericSignalThrottler):
def __init__(
self,
policy: EmissionPolicy = EmissionPolicy.Trailing,
parent: Optional[QObject] = None,
parent: QObject | None = None,
) -> None:
super().__init__(Kind.Debouncer, policy, parent)


# below here part is unique to superqt (not from KD)


if TYPE_CHECKING:
from typing_extensions import Protocol

class ThrottledCallable(Generic[P, R], Protocol):
triggered: "SignalInstance"
class ThrottledCallable(GenericSignalThrottler, Generic[P, R]):
def __init__(
self,
func: Callable[P, R],
kind: Kind,
emissionPolicy: EmissionPolicy,
parent: QObject | None = None,
) -> None:
super().__init__(kind, emissionPolicy, parent)

def cancel(self) -> None:
...
self._future: Future[R] = Future()
self.__wrapped__ = func

def flush(self) -> None:
...
self._args: tuple = ()
self._kwargs: dict = {}
self.triggered.connect(self._set_future_result)

def set_timeout(self, timeout: int) -> None:
...
# even if we were to compile __call__ with a signature matching that of func,
# PySide wouldn't correctly inspect the signature of the ThrottledCallable
# instance: https://bugreports.qt.io/browse/PYSIDE-2423
# so we do it ourselfs and limit the number of positional arguments
# that we pass to func
self._max_args: int | None = get_max_args(func)

if sys.version_info < (3, 9):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "Future[R]": # noqa
if not self._future.done():
self._future.cancel()

def __call__(self, *args: "P.args", **kwargs: "P.kwargs") -> Future:
...
self._future = Future()
self._args = args
self._kwargs = kwargs

else:
self.throttle()
return self._future

def __call__(self, *args: "P.args", **kwargs: "P.kwargs") -> Future[R]:
...
def _set_future_result(self):
result = self.__wrapped__(*self._args[: self._max_args], **self._kwargs)
self._future.set_result(result)


@overload
Expand All @@ -221,28 +237,26 @@ def qthrottled(
timeout: int = 100,
leading: bool = True,
timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer,
) -> "ThrottledCallable[P, R]":
) -> ThrottledCallable[P, R]:
...


@overload
def qthrottled(
func: Optional["Literal[None]"] = None,
func: None = ...,
timeout: int = 100,
leading: bool = True,
timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer,
) -> Callable[[Callable[P, R]], "ThrottledCallable[P, R]"]:
) -> Callable[[Callable[P, R]], ThrottledCallable[P, R]]:
...


def qthrottled(
func: Optional[Callable[P, R]] = None,
func: Callable[P, R] | None = None,
timeout: int = 100,
leading: bool = True,
timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer,
) -> Union[
"ThrottledCallable[P, R]", Callable[[Callable[P, R]], "ThrottledCallable[P, R]"]
]:
) -> ThrottledCallable[P, R] | Callable[[Callable[P, R]], ThrottledCallable[P, R]]:
"""Creates a throttled function that invokes func at most once per timeout.
The throttled function comes with a `cancel` method to cancel delayed func
Expand Down Expand Up @@ -280,28 +294,26 @@ def qdebounced(
timeout: int = 100,
leading: bool = False,
timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer,
) -> "ThrottledCallable[P, R]":
) -> ThrottledCallable[P, R]:
...


@overload
def qdebounced(
func: Optional["Literal[None]"] = None,
func: None = ...,
timeout: int = 100,
leading: bool = False,
timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer,
) -> Callable[[Callable[P, R]], "ThrottledCallable[P, R]"]:
) -> Callable[[Callable[P, R]], ThrottledCallable[P, R]]:
...


def qdebounced(
func: Optional[Callable[P, R]] = None,
func: Callable[P, R] | None = None,
timeout: int = 100,
leading: bool = False,
timer_type: Qt.TimerType = Qt.TimerType.PreciseTimer,
) -> Union[
"ThrottledCallable[P, R]", Callable[[Callable[P, R]], "ThrottledCallable[P, R]"]
]:
) -> ThrottledCallable[P, R] | Callable[[Callable[P, R]], ThrottledCallable[P, R]]:
"""Creates a debounced function that delays invoking `func`.
`func` will not be invoked until `timeout` ms have elapsed since the last time
Expand Down Expand Up @@ -337,41 +349,17 @@ def qdebounced(


def _make_decorator(
func: Optional[Callable[P, R]],
func: Callable[P, R] | None,
timeout: int,
leading: bool,
timer_type: Qt.TimerType,
kind: Kind,
) -> Union[
"ThrottledCallable[P, R]", Callable[[Callable[P, R]], "ThrottledCallable[P, R]"]
]:
def deco(func: Callable[P, R]) -> "ThrottledCallable[P, R]":
) -> ThrottledCallable[P, R] | Callable[[Callable[P, R]], ThrottledCallable[P, R]]:
def deco(func: Callable[P, R]) -> ThrottledCallable[P, R]:
policy = EmissionPolicy.Leading if leading else EmissionPolicy.Trailing
throttle = GenericSignalThrottler(kind, policy)
throttle.setTimerType(timer_type)
throttle.setTimeout(timeout)
last_f = None
future: Optional[Future] = None

@wraps(func)
def inner(*args: "P.args", **kwargs: "P.kwargs") -> Future:
nonlocal last_f
nonlocal future
if last_f is not None:
throttle.triggered.disconnect(last_f)
if future is not None and not future.done():
future.cancel()

future = Future()
last_f = lambda: future.set_result(func(*args, **kwargs)) # noqa
throttle.triggered.connect(last_f)
throttle.throttle()
return future

inner.cancel = throttle.cancel
inner.flush = throttle.flush
inner.set_timeout = throttle.setTimeout
inner.triggered = throttle.triggered
return inner # type: ignore
obj = ThrottledCallable(func, kind, policy)
obj.setTimerType(timer_type)
obj.setTimeout(timeout)
return wraps(func)(obj)

return deco(func) if func is not None else deco
29 changes: 29 additions & 0 deletions tests/test_throttler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from unittest.mock import Mock

import pytest
from qtpy.QtCore import QObject, Signal

from superqt.utils import qdebounced, qthrottled


Expand Down Expand Up @@ -41,3 +44,29 @@ def f2() -> str:
qtbot.wait(5)
assert mock1.call_count == 2
assert mock2.call_count == 10


@pytest.mark.parametrize("deco", [qthrottled, qdebounced])
def test_ensure_throttled_sig_inspection(deco, qtbot):
mock = Mock()

class Emitter(QObject):
sig = Signal(int, int, int)

@deco
def func(a: int, b: int):
"""docstring"""
mock(a, b)

obj = Emitter()
obj.sig.connect(func)

# this is the crux of the test...
# we emit 3 args, but the function only takes 2
# this should normally work fine in Qt.
# testing here that the decorator doesn't break it.
with qtbot.waitSignal(func.triggered, timeout=1000):
obj.sig.emit(1, 2, 3)
mock.assert_called_once_with(1, 2)
assert func.__doc__ == "docstring"
assert func.__name__ == "func"

0 comments on commit 64dfb43

Please sign in to comment.