diff --git a/once/__init__.py b/once/__init__.py index 82b7615..5d1c67a 100644 --- a/once/__init__.py +++ b/once/__init__.py @@ -16,7 +16,7 @@ def _is_method(func: collections.abc.Callable): """Determine if a function is a method on a class.""" - if isinstance(func, (classmethod, staticmethod)): + if isinstance(func, (classmethod, staticmethod, property)): return True sig = inspect.signature(func) return "self" in sig.parameters @@ -396,6 +396,8 @@ def __get__(self, obj, cls) -> collections.abc.Callable: class once_per_instance: # pylint: disable=invalid-name """A version of once for class methods which runs once per instance.""" + is_property: bool + @classmethod def with_options(cls, per_thread: bool = False, retry_exceptions=False, allow_reset=False): return lambda func: cls( @@ -432,6 +434,11 @@ def once_factory(self) -> _ONCE_FACTORY_TYPE: def _inspect_function(self, func: collections.abc.Callable): if isinstance(func, (classmethod, staticmethod)): raise SyntaxError("Must use @once.once_per_class on classmethod and staticmethod") + if isinstance(func, property): + func = func.fget + self.is_property = True + else: + self.is_property = False if not _is_method(func): raise SyntaxError( "Attempting to use @once.once_per_instance method-only decorator " @@ -450,4 +457,6 @@ def __get__(self, obj, cls) -> collections.abc.Callable: bound_func, self.once_factory(), self.fn_type, self.retry_exceptions ) self.callables[obj] = callable + if self.is_property: + return callable() return callable diff --git a/once_test.py b/once_test.py index 28fd60b..8738a8d 100644 --- a/once_test.py +++ b/once_test.py @@ -882,6 +882,25 @@ def execute(i): self.assertEqual(min(results), 1) self.assertEqual(max(results), math.ceil(_N_WORKERS / 4)) + def test_once_per_instance_property(self): + counter = Counter() + + class _CallOnceClass: + @once.once_per_instance + @property + def value(self): + nonlocal counter + return counter.get_incremented() + + a = _CallOnceClass() + b = _CallOnceClass() + self.assertEqual(a.value, 1) + self.assertEqual(b.value, 2) + self.assertEqual(a.value, 1) + self.assertEqual(b.value, 2) + self.assertEqual(_CallOnceClass().value, 3) + self.assertEqual(_CallOnceClass().value, 4) + def test_once_per_class_classmethod(self): counter = Counter()