
High Level API

Dominik Jain edited this page Nov 8, 2023 · 2 revisions

Introducing the Tianshou High-Level API

The new high-level library was created based on object-oriented design principles with two primary design goals:

  • ease of use for the end user (without sacrificing generality)

    This is achieved through:

    • a single, well-defined point of interaction (ExperimentBuilder), which uses declarative semantics, allowing the user to focus on what to do rather than how to do it.

    • easily injectable parametrisation.

      For complex parametrisation involving objects, the respective library classes are easily discoverable, keeping the need to browse reference documentation (or, even worse, inspect code or class hierarchies) to an absolute minimum.

    • reduced points of failure.

      Because the high-level API operates at a higher level of abstraction, where more knowledge is available, we can centrally define reasonable defaults and apply consistency checks in order to ensure that illegal configurations result in meaningful errors (and are completely avoided as long as the user does not modify default behaviour). For example, we can consider interactions between the nature of the action space and the neural networks being used.

  • maintainability for developers

    This is achieved through:

    • a modular design with strong separation of concerns
    • a high level of factorisation, which largely avoids duplication, partly through the use of mixins and multiple inheritance. This invariably makes the code slightly more complex, but it greatly reduces the amount of code to be written or updated, which makes it a reasonable compromise in this case.

Changeset

The entire high-level library is in its own subpackage, tianshou.highlevel, and almost no changes were made to the original library in order to support the new APIs. For the most part, the changes were typing-related: they align type annotations with existing example applications or make previously implicit interfaces explicit.

Furthermore, some helper modules were added to the tianshou.util package (all of which were copied from the sensAI library).

Many example applications were added, based on the existing MuJoCo and Atari examples (see below).

User-Facing Interface

User Experience Example

To illustrate the UX, consider this video recording (IntelliJ IDEA):


Observe how conveniently relevant classes can be discovered via the IDE's auto-completion function. Discoverability is markedly enhanced by using a prefix-based naming convention, where classes that can be used as parameters use the base class name as a prefix, allowing all potentially relevant subclasses to be straightforwardly auto-completed.

Declarative Semantics

A key design principle for the user-facing interface was to achieve declarative semantics, where the user is no longer concerned with writing a lengthy procedure that sequentially constructs components that build upon each other. Instead, the user focuses purely on declaring the properties of the learning task they would like to run.

  • This essentially reduces boilerplate code to zero, as every part of the code defines essential, experiment-specific configuration.
  • This makes it possible to centrally handle interdependent configuration and detect/avoid misspecification.

In order to enable the configuration of interdependent objects without requiring the user to instantiate the respective objects sequentially, we heavily employ the factory pattern.
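To illustrate why factories help, here is a minimal, library-independent sketch. The classes EnvInfo and CriticFactory below are hypothetical stand-ins, not Tianshou's API: the user declares only the hidden layer sizes, and the factory fills in the input dimension once the environment is known at build time.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EnvInfo:
    """Hypothetical summary of environment properties known only at build time."""
    obs_dim: int
    action_dim: int


@dataclass(frozen=True)
class CriticFactory:
    """Hypothetical factory: stores configuration only and creates the model later."""
    hidden_sizes: tuple[int, ...] = (64, 64)

    def create_layer_sizes(self, env_info: EnvInfo) -> list[int]:
        # The input size depends on the environment, which the user never has
        # to specify; a critic outputs a single value estimate.
        return [env_info.obs_dim, *self.hidden_sizes, 1]


# Declaration: no environment needs to exist yet.
critic_factory = CriticFactory(hidden_sizes=(128, 128))

# Later, at build time, the environment properties become available.
sizes = critic_factory.create_layer_sizes(EnvInfo(obs_dim=27, action_dim=8))
print(sizes)  # [27, 128, 128, 1]
```

The factory holds no reference to an environment, so declaring it imposes no ordering constraint on the user.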

Experiment Builders

The end user's primary entry point is an ExperimentBuilder, which is specialised for each algorithm. As the name suggests, it uses the builder pattern in order to create an Experiment object, which is then used to run the learning task.

  • At builder construction, the user is required to provide only essential configuration, particularly the environment factory.
  • The bulk of the algorithm-specific parameters can be provided via an algorithm-specific parameter object. For instance, PPOExperimentBuilder has the method with_ppo_params, which expects an object of type PPOParams.
  • Parametrisation that requires the provision of more complex interfaces (e.g. where multiple specification variants exist) is handled via dedicated builder methods. For example, for the specification of the critic component in an actor-critic algorithm, the following group of functions is provided:
    • with_critic_factory (where the user can provide any (user-defined) factory for the critic component)
    • with_critic_factory_default (with which the user specifies that the default, Net-based critic architecture shall be used and has the option to parametrise it)
    • with_critic_factory_use_actor (with which the user indicates that the critic component shall reuse the preprocessing network from the actor component)
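The builder semantics described above can be sketched as follows. This is a simplified, hypothetical stand-in rather than Tianshou's implementation (CriticSpec and SketchExperimentBuilder are invented names): each with_* method records a choice and returns the builder itself so that calls can be chained, and build() fills in defaults for anything left unspecified.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class CriticSpec:
    """Hypothetical record of how the critic should be created."""
    kind: str  # "default" or "use_actor"
    hidden_sizes: tuple[int, ...] = ()


class SketchExperimentBuilder:
    def __init__(self, env_name: str):
        # Only essential configuration is required at construction time.
        self._env_name = env_name
        self._critic: Optional[CriticSpec] = None

    def with_critic_factory_default(self, hidden_sizes=(64, 64)) -> "SketchExperimentBuilder":
        self._critic = CriticSpec("default", tuple(hidden_sizes))
        return self  # returning self enables method chaining

    def with_critic_factory_use_actor(self) -> "SketchExperimentBuilder":
        self._critic = CriticSpec("use_actor")
        return self

    def build(self) -> dict:
        # Fall back to a default when the user did not specify the critic.
        critic = self._critic or CriticSpec("default", (64, 64))
        return {"env": self._env_name, "critic": critic}


experiment = (
    SketchExperimentBuilder("Ant-v4")
    .with_critic_factory_default(hidden_sizes=(256, 256))
    .build()
)
print(experiment["critic"])  # CriticSpec(kind='default', hidden_sizes=(256, 256))
```

Because every variant writes to the same slot, the last call wins, and the builder can validate the final combination centrally in build().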

Examples

Minimal Example

In the simplest of cases, where the user wants to use the default parametrisation for everything, a user could run a PPO learning task as follows,

from tianshou.highlevel.experiment import PPOExperimentBuilder

experiment = PPOExperimentBuilder(MyEnvFactory()).build()
experiment.run()

where MyEnvFactory is a factory for the agent's environment. The default behaviour will adapt depending on whether the factory creates environments with discrete or continuous action spaces.
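How a default can adapt to the action space is sketched below. The space classes and default_actor_head are hypothetical stand-ins that only illustrate the idea; they are not Tianshou's ActorFactoryDefault or gymnasium's spaces, whose actual logic may differ. A discrete space yields a categorical (softmax) head, a continuous space yields a Gaussian head, and anything else produces a meaningful error.

```python
import math
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class DiscreteSpace:
    """Hypothetical stand-in for a discrete action space with n actions."""
    n: int


@dataclass(frozen=True)
class BoxSpace:
    """Hypothetical stand-in for a continuous action space."""
    shape: tuple[int, ...]


def default_actor_head(action_space: Union[DiscreteSpace, BoxSpace]) -> dict:
    """Choose an output head that matches the action space (illustrative only)."""
    if isinstance(action_space, DiscreteSpace):
        # One logit per action, turned into probabilities via softmax.
        return {"type": "categorical", "outputs": action_space.n}
    if isinstance(action_space, BoxSpace):
        # One Gaussian mean per action dimension.
        return {"type": "gaussian", "outputs": math.prod(action_space.shape)}
    # Fail loudly instead of silently building a mismatched network.
    raise TypeError(f"Unsupported action space: {action_space!r}")


print(default_actor_head(DiscreteSpace(4)))  # {'type': 'categorical', 'outputs': 4}
print(default_actor_head(BoxSpace((8,))))    # {'type': 'gaussian', 'outputs': 8}
```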

Fully Parametrised MuJoCo Example

Importantly, the user still has the option to configure all the details. Consider this example, which is from the high-level version of the mujoco_ppo example:

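import os

import torch

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import PPOExperimentBuilder
from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryIndependentGaussians
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils.logging import datetime_tag

# experiment_config, task, epoch, etc. are parameters of the example script's main function
log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag())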

sampling_config = SamplingConfig(
    num_epochs=epoch,
    step_per_epoch=step_per_epoch,
    batch_size=batch_size,
    num_train_envs=training_num,
    num_test_envs=test_num,
    buffer_size=buffer_size,
    step_per_collect=step_per_collect,
    repeat_per_collect=repeat_per_collect,
)

env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)

experiment = (
    PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
    .with_ppo_params(
        PPOParams(
            discount_factor=gamma,
            gae_lambda=gae_lambda,
            action_bound_method=bound_action_method,
            reward_normalization=rew_norm,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            value_clip=value_clip,
            advantage_normalization=norm_adv,
            eps_clip=eps_clip,
            dual_clip=dual_clip,
            recompute_advantage=recompute_adv,
            lr=lr,
            lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
            if lr_decay
            else None,
            dist_fn=DistributionFunctionFactoryIndependentGaussians(),
        ),
    )
    .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
    .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
    .build()
)
experiment.run(log_name)

This is functionally equivalent to the procedural, low-level mujoco_ppo example; compare the two scripts in the examples/mujoco folder.

In general, example applications of the high-level API can be found in the examples/ folder, in scripts with the _hl.py suffix.

Experiments

The Experiment representation contains:

  • the agent factory,
  • the environment factory,
  • further definitions pertaining to storage & logging.

An experiment may be run several times, assigning a name (and corresponding storage location) to each run.
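A minimal sketch of how a run name might map to its own storage location. It assumes, as the persistence_base_dir='log' entry in the string representation of an experiment suggests, that runs are stored beneath a base directory; the helper run_directory is hypothetical and not part of Tianshou.

```python
from pathlib import PurePosixPath


def run_directory(persistence_base_dir: str, run_name: str) -> PurePosixPath:
    """Hypothetical helper: each named run gets its own storage location."""
    return PurePosixPath(persistence_base_dir) / run_name


# The same experiment, run twice under different names.
first = run_directory("log", "Ant-v4/ppo/42/run1")
second = run_directory("log", "Ant-v4/ppo/42/run2")
print(first)   # log/Ant-v4/ppo/42/run1
print(second)  # log/Ant-v4/ppo/42/run2
```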

Persistence and Logging

Experiments can be serialized and later reloaded:

    from tianshou.highlevel.experiment import Experiment

    experiment = Experiment.from_directory("log/my_experiment")

Because the experiment representation is composed purely of configuration and factories, which themselves are composed purely of configuration and factories, persisted objects are compact and do not contain state.
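Why configuration-only objects persist compactly can be illustrated with a small round trip. The classes below are hypothetical stand-ins, and JSON is used only for readability; Tianshou's actual persistence format is not shown here. Because the experiment holds only parameters and factories, the serialized form contains no network weights or buffers, and objects are re-created from the restored factories on demand.

```python
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class EnvFactorySketch:
    """Hypothetical environment factory holding configuration only."""
    task: str
    seed: int

    def create_env(self) -> dict:
        # Stand-in for environment creation; returns fresh state on every call.
        return {"task": self.task, "seed": self.seed, "episode_step": 0}


@dataclass(frozen=True)
class ExperimentSketch:
    env_factory: EnvFactorySketch
    num_epochs: int

    def to_json(self) -> str:
        return json.dumps(asdict(self))  # asdict also converts nested dataclasses

    @classmethod
    def from_json(cls, text: str) -> "ExperimentSketch":
        data = json.loads(text)
        return cls(EnvFactorySketch(**data["env_factory"]), data["num_epochs"])


experiment = ExperimentSketch(EnvFactorySketch("Ant-v4", seed=42), num_epochs=100)
text = experiment.to_json()
print(text)  # {"env_factory": {"task": "Ant-v4", "seed": 42}, "num_epochs": 100}
restored = ExperimentSketch.from_json(text)
env = restored.env_factory.create_env()  # state exists only after re-creation
```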

Every experiment run produces the following artifacts:

  • the serialized experiment
  • the serialized best policy found during training
  • a log file
  • (optionally) user-defined data, as the persistence handlers are modular

Running a reloaded experiment can optionally resume training of the serialized policy.

All relevant objects have meaningful string representations that can appear in logs, which is conveniently achieved through the use of ToStringMixin (from sensAI). Its use furthermore prevents string representations of recurring objects from being printed more than once. For example, consider this string representation, which was generated for the fully parametrised PPO experiment from the example above:

Experiment[
    config=ExperimentConfig(
        seed=42, 
        device='cuda', 
        policy_restore_directory=None, 
        train=True, 
        watch=True, 
        watch_render=0.0, 
        persistence_base_dir='log', 
        persistence_enabled=True), 
    sampling_config=SamplingConfig[
        num_epochs=100, 
        step_per_epoch=30000, 
        batch_size=64, 
        num_train_envs=64, 
        num_test_envs=10, 
        buffer_size=4096, 
        step_per_collect=2048, 
        repeat_per_collect=10, 
        update_per_step=1.0, 
        start_timesteps=0, 
        start_timesteps_random=False, 
        replay_buffer_ignore_obs_next=False, 
        replay_buffer_save_only_last_obs=False, 
        replay_buffer_stack_num=1], 
    env_factory=MujocoEnvFactory[
        task=Ant-v4, 
        seed=42, 
        obs_norm=True], 
    agent_factory=PPOAgentFactory[
        sampling_config=SamplingConfig[<<], 
        optim_factory=OptimizerFactoryAdam[
            weight_decay=0, 
            eps=1e-08, 
            betas=(0.9, 0.999)], 
        policy_wrapper_factory=None, 
        trainer_callbacks=TrainerCallbacks(
            epoch_callback_train=None, 
            epoch_callback_test=None, 
            stop_callback=None), 
        params=PPOParams[
            gae_lambda=0.95, 
            max_batchsize=256, 
            lr=0.0003, 
            lr_scheduler_factory=LRSchedulerFactoryLinear[sampling_config=SamplingConfig[<<]], 
            action_scaling=default, 
            action_bound_method=clip, 
            discount_factor=0.99, 
            reward_normalization=True, 
            deterministic_eval=False, 
            dist_fn=DistributionFunctionFactoryIndependentGaussians[], 
            vf_coef=0.25, 
            ent_coef=0.0, 
            max_grad_norm=0.5, 
            eps_clip=0.2, 
            dual_clip=None, 
            value_clip=False, 
            advantage_normalization=False, 
            recompute_advantage=True], 
        actor_factory=ActorFactoryTransientStorageDecorator[
            actor_factory=ActorFactoryDefault[
                continuous_actor_type=ContinuousActorType.GAUSSIAN, 
                continuous_unbounded=True, 
                continuous_conditioned_sigma=False, 
                hidden_sizes=[64, 64], 
                hidden_activation=<class 'torch.nn.modules.activation.Tanh'>, 
                discrete_softmax=True]], 
        critic_factory=CriticFactoryDefault[
            hidden_sizes=[64, 64], 
            hidden_activation=<class 'torch.nn.modules.activation.Tanh'>], 
        critic_use_action=False], 
    logger_factory=LoggerFactoryDefault[
        logger_type=tensorboard, 
        wandb_project=None], 
    env_config=None]
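The elision of repeated objects (shown as SamplingConfig[<<] above) can be sketched as follows. This is a simplified, hypothetical re-implementation for illustration and not sensAI's ToStringMixin: while rendering one top-level object, it remembers which objects have already been printed and abbreviates any later occurrence, which also guards against infinite recursion on cyclic references.

```python
class ToStringSketchMixin:
    """Hypothetical mixin: renders attributes, abbreviating repeated objects."""

    def to_string(self, _seen=None) -> str:
        seen = set() if _seen is None else _seen
        name = type(self).__name__
        if id(self) in seen:
            return f"{name}[<<]"  # already printed earlier in this output
        seen.add(id(self))
        parts = []
        for key, value in vars(self).items():
            if isinstance(value, ToStringSketchMixin):
                rendered = value.to_string(seen)  # share the seen-set
            else:
                rendered = str(value)
            parts.append(f"{key}={rendered}")
        return f"{name}[{', '.join(parts)}]"

    def __str__(self) -> str:
        return self.to_string()  # fresh seen-set for each top-level call


class Sampling(ToStringSketchMixin):
    def __init__(self, num_epochs: int):
        self.num_epochs = num_epochs


class Scheduler(ToStringSketchMixin):
    def __init__(self, sampling: Sampling):
        self.sampling = sampling


class Agent(ToStringSketchMixin):
    def __init__(self, sampling: Sampling, scheduler: Scheduler):
        self.sampling = sampling
        self.scheduler = scheduler


sampling = Sampling(num_epochs=100)
print(Agent(sampling, Scheduler(sampling)))
# Agent[sampling=Sampling[num_epochs=100], scheduler=Scheduler[sampling=Sampling[<<]]]
```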

Library Developer Perspective

The presentation thus far has focussed on the user's perspective. From the perspective of a Tianshou developer, it is important that the high-level API be clearly structured and maintainable. Here are the most relevant representations:

  • Policy parameters are represented as dataclasses (base class Params).

    The goal is for the parameters to be ultimately passed to the corresponding policy class (e.g. PPOParams contains parameters for PPOPolicy).

    • Parameter transformation: Some parameter dataclass attributes correspond directly to policy class parameters. In many cases, however, the high-level interface must abstract away from the low-level interface, so we establish the notion of a ParamTransformer, which transforms one or more parameters into the form required by the policy class. The dictionary representation of the dataclass is successively transformed via ParamTransformers, such that the resulting dictionary can ultimately be used as keyword arguments for the policy. To achieve maintainability, parameter transformations are declared alongside the parameters they affect. Tests ensure that naming issues are detected.

    • Composition and inheritance: We use inheritance and mixins to reduce duplication.

  • Factories are an essential principle of the library. Because the creation of objects may depend on objects that are not yet created, a declarative approach necessitates that we transition from the objects themselves to factories.

    • The EnvFactory was already mentioned above, as it is a user-facing abstraction. Its purpose is to create the (vectorized) environments that will be used in the experiments.
    • An AgentFactory is the central component that creates the policy, the trainer and the necessary collectors. To support a new type of policy, a subclass that handles the policy creation is required. In turn, the main task when implementing a new algorithm-specific ExperimentBuilder is the creation of the corresponding AgentFactory.
    • Several types of factories serve to parametrize policies and training processes, e.g.
      • OptimizerFactory for the creation of torch optimizers
      • ActorFactory for the creation of actor models
      • CriticFactory for the creation of critic models
      • IntermediateModuleFactory for the creation of models that produce intermediate/latent representations
      • EnvParamFactory for the creation of parameters based on properties of the environment
      • NoiseFactory for the creation of BaseNoise instances
      • DistributionFunctionFactory for the creation of functions that create torch distributions from tensors
      • LRSchedulerFactory for learning rate schedulers
      • PolicyWrapperFactory for policy wrappers that extend the functionality of the regular policy (e.g. intrinsic curiosity)
      • AutoAlphaFactory for automatically tuned regularization coefficients (as supported by SAC or REDQ)
    • A LoggerFactory handles the creation of the experiment logger; the default implementation already covers the cases used in the examples.
  • The ExperimentBuilder implementations make use of mixins to add common functionality. As mentioned above, the main task in an algorithm-specific specialization is to create the AgentFactory.
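The parameter transformation idea can be sketched as below. PPOParamsSketch, ParamTransformerSketch, its subclasses and make_policy are hypothetical and heavily simplified; they are not the classes in tianshou.highlevel. The dataclass is converted to a dictionary, each transformer rewrites part of it in turn, and the result is passed as keyword arguments to a stand-in policy constructor.

```python
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import Any, Optional


class ParamTransformerSketch(ABC):
    """Hypothetical base class: rewrites a parameter dictionary in place."""

    @abstractmethod
    def transform(self, params: dict[str, Any]) -> None: ...


class LRToOptimizer(ParamTransformerSketch):
    """Replaces the high-level `lr` with an optimizer description (stand-in)."""

    def transform(self, params: dict[str, Any]) -> None:
        lr = params.pop("lr")
        # A real transformer would create e.g. a torch optimizer here.
        params["optim"] = {"type": "Adam", "lr": lr}


class DropIfNone(ParamTransformerSketch):
    """Removes an unset optional parameter so the policy's own default applies."""

    def __init__(self, key: str):
        self.key = key

    def transform(self, params: dict[str, Any]) -> None:
        if self.key in params and params[self.key] is None:
            del params[self.key]


@dataclass
class PPOParamsSketch:
    discount_factor: float = 0.99
    eps_clip: float = 0.2
    dual_clip: Optional[float] = None
    lr: float = 3e-4

    def transformers(self) -> list[ParamTransformerSketch]:
        # Declared alongside the fields they affect.
        return [LRToOptimizer(), DropIfNone("dual_clip")]

    def create_kwargs(self) -> dict[str, Any]:
        params = asdict(self)
        for transformer in self.transformers():
            transformer.transform(params)
        return params


def make_policy(discount_factor: float, eps_clip: float, optim: dict,
                dual_clip: Optional[float] = None) -> dict:
    """Stand-in for a low-level policy constructor."""
    return {"discount_factor": discount_factor, "eps_clip": eps_clip,
            "optim": optim, "dual_clip": dual_clip}


kwargs = PPOParamsSketch(lr=1e-3).create_kwargs()
print(kwargs)  # {'discount_factor': 0.99, 'eps_clip': 0.2, 'optim': {'type': 'Adam', 'lr': 0.001}}
policy = make_policy(**kwargs)
```

If a field name drifted out of sync with the policy signature, make_policy(**kwargs) would raise a TypeError, which is the kind of naming issue a test can catch.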

Supporting a New Algorithm

In order to support a new algorithm in the high-level API, follow these steps:

  1. Create the algorithm-specific parameter representation (specialization of class Params).

    • Try to reuse existing mixins, factorising as much as is reasonable.
    • If the algorithm uses a new type of complex parameter, for which no representation yet exists, create new factory-style classes, which can be easily persisted.
    • Take care to apply all the necessary parameter transformations in order to produce the representations that can ultimately be passed to the low-level policy instance.
  2. Implement the algorithm-specific agent factory (specialization of AgentFactory).

    Inherit from an appropriate base class for algorithms with similar properties to minimize the implementation effort. In many cases, the implementation will be just a few lines.

  3. Implement the algorithm-specific experiment builder (specialization of ExperimentBuilder).

    Make use of the right mixins to provide the required configuration methods.
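The three steps can be sketched as a skeleton. All class names here (ParamsBase, AgentFactoryBase, ExperimentBuilderBase, CriticFactoryMixin and the MyAlgo* classes) are hypothetical stand-ins, not Tianshou's actual base classes, whose names and abstract methods may differ; the point is the division of labour between parameters, agent factory and builder.

```python
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import Any


# Hypothetical stand-ins for the library's base classes.
@dataclass
class ParamsBase:
    lr: float = 1e-3

    def create_kwargs(self) -> dict[str, Any]:
        return asdict(self)  # parameter transformations would be applied here


class AgentFactoryBase(ABC):
    def __init__(self, params: ParamsBase):
        self.params = params

    @abstractmethod
    def create_policy(self) -> Any: ...


class ExperimentBuilderBase(ABC):
    def __init__(self, env_name: str):
        self.env_name = env_name

    @abstractmethod
    def create_agent_factory(self) -> AgentFactoryBase: ...

    def build(self) -> dict:
        return {"env": self.env_name, "policy": self.create_agent_factory().create_policy()}


class CriticFactoryMixin:
    """Reusable configuration method shared by several builders."""
    critic_hidden_sizes: tuple[int, ...] = (64, 64)

    def with_critic_factory_default(self, hidden_sizes):
        self.critic_hidden_sizes = tuple(hidden_sizes)
        return self


# Step 1: algorithm-specific parameters.
@dataclass
class MyAlgoParams(ParamsBase):
    temperature: float = 0.1


# Step 2: algorithm-specific agent factory (often only a few lines).
class MyAlgoAgentFactory(AgentFactoryBase):
    def __init__(self, params: MyAlgoParams, critic_hidden_sizes: tuple[int, ...]):
        super().__init__(params)
        self.critic_hidden_sizes = critic_hidden_sizes

    def create_policy(self) -> dict:
        # A real factory would instantiate the low-level policy class here.
        return {"algo": "MyAlgo", "critic": self.critic_hidden_sizes,
                **self.params.create_kwargs()}


# Step 3: algorithm-specific experiment builder, reusing a mixin.
class MyAlgoExperimentBuilder(ExperimentBuilderBase, CriticFactoryMixin):
    def __init__(self, env_name: str):
        super().__init__(env_name)
        self.params = MyAlgoParams()

    def with_my_algo_params(self, params: MyAlgoParams) -> "MyAlgoExperimentBuilder":
        self.params = params
        return self

    def create_agent_factory(self) -> MyAlgoAgentFactory:
        return MyAlgoAgentFactory(self.params, self.critic_hidden_sizes)


experiment = (
    MyAlgoExperimentBuilder("CartPole-v1")
    .with_my_algo_params(MyAlgoParams(temperature=0.5))
    .with_critic_factory_default([32, 32])
    .build()
)
print(experiment)
```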