From f574229c509ba8254e3207352af557a01cad87c4 Mon Sep 17 00:00:00 2001 From: Kasperi Apell Date: Fri, 3 Jan 2025 00:00:57 -0800 Subject: [PATCH] Propagate callable parameter types using ParamSpec (#142306) (#143797) Summary: The codebase has a few locations where callable parameter type information is lost when the unpackings *args and **kwargs are typed as Any. Refactor these instances to retain type information using typing_extensions.ParamSpec. Also, in these functions, enforce return type with TypeVar. Addresses #142306 X-link: https://github.com/pytorch/pytorch/pull/143797 Approved by: https://github.com/Skylion007 Reviewed By: jeanschmidt Differential Revision: D67707277 fbshipit-source-id: 18e8dbde435a46e9dca17e3c81d1dd019601cf1d Co-authored-by: Aaron Gokaslan Co-authored-by: Xuehai Pan --- userbenchmark/dynamo/dynamobench/_dynamo/testing.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/testing.py b/userbenchmark/dynamo/dynamobench/_dynamo/testing.py index f8583b1bd..d401c83f0 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/testing.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/testing.py @@ -20,6 +20,7 @@ TypeVar, Union, ) +from typing_extensions import ParamSpec from unittest.mock import patch import torch @@ -51,6 +52,8 @@ log = logging.getLogger(__name__) +_P = ParamSpec("_P") + def clone_me(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: if x is None: @@ -407,9 +410,9 @@ def check_dynamic_shape_capture() -> bool: return not config.assume_static_by_default -def _make_fn_with_patches(fn: Callable[..., _T], *patches: Any) -> Callable[..., _T]: +def _make_fn_with_patches(fn: Callable[_P, _T], *patches: Any) -> Callable[_P, _T]: @functools.wraps(fn) - def _fn(*args: Any, **kwargs: Any) -> _T: + def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: with contextlib.ExitStack() as stack: for module, attr, val in patches: stack.enter_context(patch.object(module, attr, val))