Skip to content

Commit

Permalink
fix: formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
TheEimer committed May 16, 2024
1 parent e3dea64 commit 97103ca
Show file tree
Hide file tree
Showing 23 changed files with 56 additions and 124 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: clean clean-build clean-pyc clean-test coverage dist docs help install check formatbump-version release
.PHONY: clean clean-build clean-pyc clean-test coverage dist docs help install check bump-version release format
.DEFAULT_GOAL := help

define BROWSER_PYSCRIPT
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
</a>
</p>

[![PyPI Version](https://img.shields.io/pypi/v/arlbench.svg)](https://pypi.python.org/pypi/arlbench)
<!--- [![PyPI Version](https://img.shields.io/pypi/v/arlbench.svg)](https://pypi.python.org/pypi/arlbench) -->
[![Test](https://github.com/automl/arlbench/actions/workflows/pytest.yaml/badge.svg)](https://github.com/automl/arlbench/actions/workflows/pytest.yaml)
[![Doc Status](https://github.com/automl/arlbench/actions/workflows/docs.yaml/badge.svg)](https://github.com/automl/arlbench/actions/workflows/docs.yaml)
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
Expand Down
11 changes: 2 additions & 9 deletions arlbench/autorl/autorl_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,8 @@
import numpy as np
import pandas as pd

from arlbench.core.algorithms import (
DQN,
PPO,
SAC,
Algorithm,
AlgorithmState,
TrainResult,
TrainReturnT,
)
from arlbench.core.algorithms import (DQN, PPO, SAC, Algorithm, AlgorithmState,
TrainResult, TrainReturnT)
from arlbench.core.environments import make_env
from arlbench.utils import config_space_to_gymnasium_space

Expand Down
5 changes: 2 additions & 3 deletions arlbench/autorl/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
from flashbax.buffers.prioritised_trajectory_buffer import (
PrioritisedTrajectoryBufferState,
)
from flashbax.buffers.prioritised_trajectory_buffer import \
PrioritisedTrajectoryBufferState
from flashbax.buffers.sum_tree import SumTreeState
from flashbax.vault import Vault
from flax.core.frozen_dict import FrozenDict
Expand Down
35 changes: 8 additions & 27 deletions arlbench/core/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,16 @@
from collections.abc import Callable
from typing import Optional, Union

from flashbax.buffers.prioritised_trajectory_buffer import (
PrioritisedTrajectoryBufferState,
)
from flashbax.buffers.prioritised_trajectory_buffer import \
PrioritisedTrajectoryBufferState

from .algorithm import Algorithm
from .dqn import (
DQN,
DQNMetrics,
DQNRunnerState,
DQNState,
DQNTrainingResult,
DQNTrainReturnT,
)
from .ppo import (
PPO,
PPOMetrics,
PPORunnerState,
PPOState,
PPOTrainingResult,
PPOTrainReturnT,
)
from .sac import (
SAC,
SACMetrics,
SACRunnerState,
SACState,
SACTrainingResult,
SACTrainReturnT,
)
from .dqn import (DQN, DQNMetrics, DQNRunnerState, DQNState, DQNTrainingResult,
DQNTrainReturnT)
from .ppo import (PPO, PPOMetrics, PPORunnerState, PPOState, PPOTrainingResult,
PPOTrainReturnT)
from .sac import (SAC, SACMetrics, SACRunnerState, SACState, SACTrainingResult,
SACTrainReturnT)

TrainResult = Union[DQNTrainingResult, PPOTrainingResult, SACTrainingResult]
TrainMetrics = Union[DQNMetrics, PPOMetrics, SACMetrics]
Expand Down
9 changes: 3 additions & 6 deletions arlbench/core/algorithms/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,9 @@
from flashbax import utils
from flashbax.buffers import sum_tree
from flashbax.buffers.prioritised_trajectory_buffer import (
Experience,
PrioritisedTrajectoryBufferSample,
PrioritisedTrajectoryBufferState,
_get_sample_trajectories,
get_invalid_indices,
)
Experience, PrioritisedTrajectoryBufferSample,
PrioritisedTrajectoryBufferState, _get_sample_trajectories,
get_invalid_indices)
from flashbax.buffers.trajectory_buffer import calculate_uniform_item_indices

if TYPE_CHECKING:
Expand Down
10 changes: 2 additions & 8 deletions arlbench/core/algorithms/dqn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
from .dqn import (
DQN,
DQNMetrics,
DQNRunnerState,
DQNState,
DQNTrainingResult,
DQNTrainReturnT,
)
from .dqn import (DQN, DQNMetrics, DQNRunnerState, DQNState, DQNTrainingResult,
DQNTrainReturnT)

__all__ = [
"DQN",
Expand Down
20 changes: 6 additions & 14 deletions arlbench/core/algorithms/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,23 @@
import jax.numpy as jnp
import numpy as np
import optax
from ConfigSpace import (
Categorical,
Configuration,
ConfigurationSpace,
EqualsCondition,
Float,
Integer,
)
from ConfigSpace import (Categorical, Configuration, ConfigurationSpace,
EqualsCondition, Float, Integer)
from flax.training.train_state import TrainState

from arlbench.core import running_statistics
from arlbench.core.algorithms.algorithm import Algorithm
from arlbench.core.algorithms.buffers import uniform_sample
from arlbench.core.algorithms.common import TimeStep
from arlbench.core.algorithms.prioritised_item_buffer import (
make_prioritised_item_buffer,
)
from arlbench.core.algorithms.prioritised_item_buffer import \
make_prioritised_item_buffer

from .models import CNNQ, MLPQ

if TYPE_CHECKING:
import chex
from flashbax.buffers.prioritised_trajectory_buffer import (
PrioritisedTrajectoryBufferState,
)
from flashbax.buffers.prioritised_trajectory_buffer import \
PrioritisedTrajectoryBufferState
from flax.core.frozen_dict import FrozenDict

from arlbench.core.environments import Environment
Expand Down
10 changes: 2 additions & 8 deletions arlbench/core/algorithms/ppo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
from .ppo import (
PPO,
PPOMetrics,
PPORunnerState,
PPOState,
PPOTrainingResult,
PPOTrainReturnT,
)
from .ppo import (PPO, PPOMetrics, PPORunnerState, PPOState, PPOTrainingResult,
PPOTrainReturnT)

__all__ = [
"PPO",
Expand Down
3 changes: 2 additions & 1 deletion arlbench/core/algorithms/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import jax.numpy as jnp
import numpy as np
import optax
from ConfigSpace import Categorical, Configuration, ConfigurationSpace, Float, Integer
from ConfigSpace import (Categorical, Configuration, ConfigurationSpace, Float,
Integer)
from flax.training.train_state import TrainState

from arlbench.core import running_statistics
Expand Down
9 changes: 3 additions & 6 deletions arlbench/core/algorithms/prioritised_item_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,9 @@
from flashbax.buffers.item_buffer import validate_item_buffer_args
from flashbax.buffers.prioritised_flat_buffer import validate_priority_exponent
from flashbax.buffers.prioritised_trajectory_buffer import (
PrioritisedTrajectoryBuffer,
PrioritisedTrajectoryBufferSample,
PrioritisedTrajectoryBufferState,
make_prioritised_trajectory_buffer,
validate_device,
)
PrioritisedTrajectoryBuffer, PrioritisedTrajectoryBufferSample,
PrioritisedTrajectoryBufferState, make_prioritised_trajectory_buffer,
validate_device)
from flashbax.utils import add_dim_to_args

if TYPE_CHECKING:
Expand Down
10 changes: 2 additions & 8 deletions arlbench/core/algorithms/sac/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
from .sac import (
SAC,
SACMetrics,
SACRunnerState,
SACState,
SACTrainingResult,
SACTrainReturnT,
)
from .sac import (SAC, SACMetrics, SACRunnerState, SACState, SACTrainingResult,
SACTrainReturnT)

__all__ = [
"SAC",
Expand Down
32 changes: 9 additions & 23 deletions arlbench/core/algorithms/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,24 @@
import jax.numpy as jnp
import numpy as np
import optax
from ConfigSpace import (
Categorical,
Configuration,
ConfigurationSpace,
EqualsCondition,
Float,
Integer,
)
from ConfigSpace import (Categorical, Configuration, ConfigurationSpace,
EqualsCondition, Float, Integer)
from flax.training.train_state import TrainState

from arlbench.core import running_statistics
from arlbench.core.algorithms.algorithm import Algorithm
from arlbench.core.algorithms.buffers import uniform_sample
from arlbench.core.algorithms.common import TimeStep
from arlbench.core.algorithms.prioritised_item_buffer import (
make_prioritised_item_buffer,
)

from .models import (
AlphaCoef,
SACCNNActor,
SACCNNCritic,
SACMLPActor,
SACMLPCritic,
SACVectorCritic,
)
from arlbench.core.algorithms.prioritised_item_buffer import \
make_prioritised_item_buffer

from .models import (AlphaCoef, SACCNNActor, SACCNNCritic, SACMLPActor,
SACMLPCritic, SACVectorCritic)

if TYPE_CHECKING:
import chex
from flashbax.buffers.prioritised_trajectory_buffer import (
PrioritisedTrajectoryBufferState,
)
from flashbax.buffers.prioritised_trajectory_buffer import \
PrioritisedTrajectoryBufferState
from flax.core.frozen_dict import FrozenDict

from arlbench.core.environments import Environment
Expand Down
11 changes: 3 additions & 8 deletions arlbench/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
from .common import (
config_space_to_gymnasium_space,
config_space_to_yaml,
gymnasium_space_to_gymnax_space,
recursive_concat,
save_defaults_to_yaml,
tuple_concat,
)
from .common import (config_space_to_gymnasium_space, config_space_to_yaml,
gymnasium_space_to_gymnax_space, recursive_concat,
save_defaults_to_yaml, tuple_concat)
from .handle_termination import HandleTermination

__all__ = [
Expand Down
1 change: 1 addition & 0 deletions tests/autorl/test_autorl_env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import pytest

from arlbench import AutoRLEnv
from arlbench.core.algorithms import DQN

Expand Down
1 change: 1 addition & 0 deletions tests/autorl/test_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import jax
import numpy as np

from arlbench.autorl import AutoRLEnv
from arlbench.autorl.checkpointing import Checkpointer
from arlbench.core.algorithms import DQN
Expand Down
1 change: 1 addition & 0 deletions tests/autorl/test_objectives.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import numpy as np

from arlbench import AutoRLEnv
from arlbench.core.algorithms import DQN

Expand Down
1 change: 1 addition & 0 deletions tests/core/algorithms/test_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import warnings

import jax

from arlbench.core.algorithms import DQN
from arlbench.core.environments import make_env

Expand Down
1 change: 1 addition & 0 deletions tests/core/algorithms/test_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import warnings

import jax

from arlbench.core.algorithms import PPO
from arlbench.core.environments import make_env

Expand Down
1 change: 1 addition & 0 deletions tests/core/algorithms/test_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import warnings

import jax

from arlbench.core.algorithms import SAC
from arlbench.core.environments import make_env

Expand Down
1 change: 1 addition & 0 deletions tests/core/environments/test_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jax
import numpy as np

from arlbench.core.algorithms import DQN, PPO
from arlbench.core.environments import make_env

Expand Down
1 change: 1 addition & 0 deletions tests/core/environments/test_gymnasium.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jax
import numpy as np

from arlbench.core.algorithms import DQN, PPO, SAC
from arlbench.core.environments import make_env

Expand Down
3 changes: 2 additions & 1 deletion tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import gymnasium
import numpy as np
from arlbench.utils import config_space_to_gymnasium_space
from ConfigSpace import Categorical, ConfigurationSpace, Float, Integer

from arlbench.utils import config_space_to_gymnasium_space


def test_config_space_to_gymnasium_space():
config_space = ConfigurationSpace(
Expand Down

0 comments on commit 97103ca

Please sign in to comment.