Skip to content

Commit

Permalink
Make utils private
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Nov 25, 2023
1 parent 9589d86 commit 7a2c917
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 15 deletions.
6 changes: 3 additions & 3 deletions benchmarl/algorithms/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater

from benchmarl.models.common import ModelConfig
from benchmarl.utils import DEVICE_TYPING, read_yaml_config
from benchmarl.utils import _read_yaml_config, DEVICE_TYPING


class Algorithm(ABC):
Expand Down Expand Up @@ -351,7 +351,7 @@ def _load_from_yaml(name: str) -> Dict[str, Any]:
/ "algorithm"
/ f"{name.lower()}.yaml"
)
return read_yaml_config(str(yaml_path.resolve()))
return _read_yaml_config(str(yaml_path.resolve()))

@classmethod
def get_from_yaml(cls, path: Optional[str] = None):
Expand All @@ -372,7 +372,7 @@ def get_from_yaml(cls, path: Optional[str] = None):
)
)
else:
return cls(**read_yaml_config(path))
return cls(**_read_yaml_config(path))

@staticmethod
@abstractmethod
Expand Down
6 changes: 3 additions & 3 deletions benchmarl/environments/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torchrl.data import CompositeSpec
from torchrl.envs import EnvBase, RewardSum, Transform

from benchmarl.utils import DEVICE_TYPING, read_yaml_config
from benchmarl.utils import _read_yaml_config, DEVICE_TYPING


def _load_config(name: str, config: Dict[str, Any]):
Expand Down Expand Up @@ -255,7 +255,7 @@ def __str__(self):
@staticmethod
def _load_from_yaml(name: str) -> Dict[str, Any]:
yaml_path = Path(__file__).parent.parent / "conf" / "task" / f"{name}.yaml"
return read_yaml_config(str(yaml_path.resolve()))
return _read_yaml_config(str(yaml_path.resolve()))

def get_from_yaml(self, path: Optional[str] = None) -> Task:
"""
Expand All @@ -273,4 +273,4 @@ def get_from_yaml(self, path: Optional[str] = None) -> Task:
Task._load_from_yaml(str(Path(self.env_name()) / Path(task_name)))
)
else:
return self.update_config(**read_yaml_config(path))
return self.update_config(**_read_yaml_config(path))
6 changes: 3 additions & 3 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from benchmarl.experiment.callback import Callback, CallbackNotifier
from benchmarl.experiment.logger import Logger
from benchmarl.models.common import ModelConfig
from benchmarl.utils import read_yaml_config
from benchmarl.utils import _read_yaml_config

_has_hydra = importlib.util.find_spec("hydra") is not None
if _has_hydra:
Expand Down Expand Up @@ -244,9 +244,9 @@ def get_from_yaml(path: Optional[str] = None):
/ "experiment"
/ "base_experiment.yaml"
)
return ExperimentConfig(**read_yaml_config(str(yaml_path.resolve())))
return ExperimentConfig(**_read_yaml_config(str(yaml_path.resolve())))
else:
return ExperimentConfig(**read_yaml_config(path))
return ExperimentConfig(**_read_yaml_config(path))

def validate(self, on_policy: bool):
"""
Expand Down
8 changes: 4 additions & 4 deletions benchmarl/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase

from benchmarl.utils import class_from_name, DEVICE_TYPING, read_yaml_config
from benchmarl.utils import _class_from_name, _read_yaml_config, DEVICE_TYPING


def _check_spec(tensordict, spec):
Expand All @@ -28,7 +28,7 @@ def parse_model_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
kwargs = {}
for key, value in cfg.items():
if key.endswith("class") and value is not None:
value = class_from_name(cfg[key])
value = _class_from_name(cfg[key])
kwargs.update({key: value})
return kwargs

Expand Down Expand Up @@ -282,7 +282,7 @@ def _load_from_yaml(name: str) -> Dict[str, Any]:
/ "layers"
/ f"{name.lower()}.yaml"
)
cfg = read_yaml_config(str(yaml_path.resolve()))
cfg = _read_yaml_config(str(yaml_path.resolve()))
return parse_model_config(cfg)

@classmethod
Expand All @@ -304,7 +304,7 @@ def get_from_yaml(cls, path: Optional[str] = None):
)
)
else:
return cls(**parse_model_config(read_yaml_config(path)))
return cls(**parse_model_config(_read_yaml_config(path)))


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions benchmarl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
DEVICE_TYPING = Union[torch.device, str, int]


def read_yaml_config(config_file: str) -> Dict[str, Any]:
def _read_yaml_config(config_file: str) -> Dict[str, Any]:
with open(config_file) as config:
yaml_string = config.read()
config_dict = yaml.safe_load(yaml_string)
Expand All @@ -22,7 +22,7 @@ def read_yaml_config(config_file: str) -> Dict[str, Any]:
return config_dict


def class_from_name(name: str):
def _class_from_name(name: str):
name_split = name.split(".")
module_name = ".".join(name_split[:-1])
class_name = name_split[-1]
Expand Down

0 comments on commit 7a2c917

Please sign in to comment.