Skip to content

Commit

Permalink
make replacement kwarg in warn_deprecated_function optional
Browse files Browse the repository at this point in the history
Sometimes we want to just deprecate a function without having a replacement for it

PiperOrigin-RevId: 678675304
  • Loading branch information
fabianp authored and ChexDev committed Sep 26, 2024
1 parent 343d03a commit 84a3899
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 11 deletions.
20 changes: 13 additions & 7 deletions chex/_src/warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Utilities to emit warnings."""

import functools
from typing import Any, Callable
import warnings


Expand Down Expand Up @@ -57,7 +58,9 @@ def wrapper(*args, **kwargs):
)


def warn_deprecated_function(fun, replacement):
def warn_deprecated_function(
fun: Callable[..., Any], replacement: str | None = None
):
"""A decorator to mark a function definition as deprecated.
Example usage:
Expand All @@ -67,20 +70,23 @@ def warn_deprecated_function(fun, replacement):
Args:
fun: the deprecated function.
replacement: the name of the function to be used instead.
replacement: name of the function to be used instead.
Returns:
the wrapped function.
"""
if hasattr(fun, '__name__'):
warning_message = f'The function {fun.__name__} is deprecated.'
else:
warning_message = 'The function is deprecated.'
if replacement:
warning_message += f' Please use {replacement} instead.'

@functools.wraps(fun)
def new_fun(*args, **kwargs):
warnings.warn(
f'The function {fun.__name__} is deprecated, '
f'please use {replacement} instead.',
category=DeprecationWarning,
stacklevel=2)
warnings.warn(warning_message, category=DeprecationWarning, stacklevel=2)
return fun(*args, **kwargs)

return new_fun


Expand Down
27 changes: 23 additions & 4 deletions chex/_src/warnings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import functools

from absl.testing import absltest

from chex._src import warnings


Expand All @@ -26,13 +25,20 @@ def f(a, b, c):
return a + b + c


@warnings.warn_deprecated_function
def g0(a, b, c):
return a + b + c


@functools.partial(warnings.warn_deprecated_function, replacement='h')
def g(a, b, c):
def g1(a, b, c):
return a + b + c


def h1(a, b, c):
return a + b + c


h2 = warnings.create_deprecated_function_alias(h1, 'path.h2', 'path.h1')


Expand All @@ -44,9 +50,22 @@ def test_warn_only_n_pos_args_in_future(self):
with self.assertWarns(Warning):
f(1, 2, c=3)

def test_warn_deprecated_function_no_replacement(self):
with self.assertWarns(Warning) as cm:
g0(1, 2, 3)

warning_message = str(cm.warnings[0].message)

self.assertIn('The function g0 is deprecated', warning_message)
# the warning message doesn't have a replacement function
self.assertNotIn('please use', warning_message)

def test_warn_deprecated_function(self):
with self.assertWarns(Warning):
g(1, 2, 3)
with self.assertWarns(Warning) as cm:
g1(1, 2, 3)

warning_message = str(cm.warnings[0].message)
self.assertIn('The function g1 is deprecated', warning_message)

def test_create_deprecated_function_alias(self):
with self.assertWarns(Warning):
Expand Down

0 comments on commit 84a3899

Please sign in to comment.