From 95bd846a859b201377a6baaefaf11bb97a60b225 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 13 Nov 2023 16:41:57 +0000 Subject: [PATCH] [Refactor] Refactor implement_for (#556) --- tensordict/utils.py | 120 +++++++++++++++++++++++++++++++++----------- 1 file changed, 92 insertions(+), 28 deletions(-) diff --git a/tensordict/utils.py b/tensordict/utils.py index 7e79fc2ff..5fc2e5d16 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2,12 +2,14 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - from __future__ import annotations +import collections import dataclasses import inspect import math + +import sys import time import warnings @@ -16,7 +18,7 @@ from copy import copy from functools import wraps from importlib import import_module -from typing import Any, Callable, List, Sequence, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Sequence, Tuple, TYPE_CHECKING, Union import numpy as np import torch @@ -984,6 +986,7 @@ class implement_for: # Stores pointers to fitting implementations: dict[func_name] = func_pointer _implementations = {} _setters = [] + _cache_modules = {} def __init__( self, @@ -1005,35 +1008,87 @@ def check_version(version, from_version, to_version): @staticmethod def get_class_that_defined_method(f): """Returns the class of a method, if it is defined, and None otherwise.""" - return f.__globals__.get(f.__qualname__.split(".")[0], None) + out = f.__globals__.get(f.__qualname__.split(".")[0], None) + return out - @property - def func_name(self): - return self.fn.__name__ + @classmethod + def get_func_name(cls, fn): + # produces a name like torchrl.module.Class.method or torchrl.module.function + first = str(fn).split(".")[0][len(" str: + @classmethod + def import_module(cls, module_name: Union[Callable, str]) -> str: """Imports module and returns its version.""" if not callable(module_name): - module = import_module(module_name) + module = cls._cache_modules.get(module_name, None) + if module is None: + if module_name in sys.modules: + sys.modules[module_name] = module = import_module(module_name) + else: + cls._cache_modules[module_name] = module = import_module( + module_name + ) else: module = module_name() return module.__version__ + _lazy_impl = collections.defaultdict(list) + + def _delazify(self, func_name): + for local_call in implement_for._lazy_impl[func_name]: + out = local_call() + return out + def __call__(self, fn): + # function names are unique + self.func_name = self.get_func_name(fn) self.fn = fn + implement_for._lazy_impl[self.func_name].append(self._call) + + @wraps(fn) + def _lazy_call_fn(*args, **kwargs): + # first time we call the function, we also do the replacement. + # This will cause the imports to occur only during the first call to fn + return self._delazify(self.func_name)(*args, **kwargs) + + return _lazy_call_fn + + def _call(self): # If the module is missing replace the function with the mock. + fn = self.fn func_name = self.func_name implementations = implement_for._implementations @@ -1043,41 +1098,50 @@ def unsupported(*args, **kwargs): f"Supported version of '{func_name}' has not been found." ) - do_set = False + self.do_set = False # Return fitting implementation if it was encountered before. if func_name in implementations: try: # check that backends don't conflict version = self.import_module(self.module_name) if self.check_version(version, self.from_version, self.to_version): - do_set = True - if not do_set: - return implementations[func_name] + self.do_set = True + if not self.do_set: + return implementations[func_name].fn except ModuleNotFoundError: # then it's ok, there is no conflict - return implementations[func_name] + return implementations[func_name].fn else: try: version = self.import_module(self.module_name) if self.check_version(version, self.from_version, self.to_version): - do_set = True + self.do_set = True except ModuleNotFoundError: return unsupported - if do_set: - implementations[func_name] = fn + if self.do_set: self.module_set() return fn return unsupported @classmethod - def reset(cls, setters=None): - if setters is None: - setters = copy(cls._setters) - cls._setters = [] - cls._implementations = {} - for setter in setters: - setter(setter.fn) - cls._setters.append(setter) + def reset(cls, setters_dict: Dict[str, implement_for] = None): + """Resets the setters in setter_dict. + + ``setter_dict`` is a copy of implementations. We just need to iterate through its + values and call :meth:`~.module_set` for each. + + """ + if setters_dict is None: + setters_dict = copy(cls._implementations) + for setter in setters_dict.values(): + setter.module_set() + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"module_name={self.module_name}({self.from_version, self.to_version}), " + f"fn_name={self.fn.__name__}, cls={self._get_cls(self.fn)}, is_set={self.do_set})" + ) def _unfold_sequence(seq):