From 77197b255388631644595a47ca2b41b502951243 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 9 Oct 2023 15:06:02 +0100 Subject: [PATCH] amend --- tensordict/tensorclass.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 8435ba4af..9cd95fa59 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -183,14 +183,8 @@ def __torch_function__( cls.state_dict = _state_dict cls.load_state_dict = _load_state_dict - methods = set(TensorDict.__dict__.keys()).union(TensorDictBase.__dict__.keys()) - for attr in methods: - func = getattr(TensorDict, attr, None) - if func is None: - # get TensorDictBase - func = getattr(TensorDictBase, attr, None) - if func is None: - continue + for attr in TensorDict.__dict__.keys(): + func = getattr(TensorDict, attr) if ( inspect.ismethod(func) and issubclass(func.__self__, TensorDictBase) ): # detects classmethods