diff --git a/README.md b/README.md
index 068de023..91b65d51 100644
--- a/README.md
+++ b/README.md
@@ -64,32 +64,14 @@ Where ID is the agent's ID given when it's created (`train.py` prints this out,
 To train agents with custom models, environments, etc., you write your own script. The following is a minimal example:
 
 ```python
-from angorapy.agent.ppo_agent import PPOAgent
-from angorapy.common.policies import BetaPolicyDistribution
-from angorapy.common.transformers import RewardNormalizationTransformer, StateNormalizationTransformer
 from angorapy.common.wrappers import make_env
 from angorapy.models import get_model_builder
+from angorapy.agent.ppo_agent import PPOAgent
 
-wrappers = [StateNormalizationTransformer, RewardNormalizationTransformer]
-env = make_env("LunarLanderContinuous-v2", reward_config=None, transformers=wrappers)
-
-# make policy distribution
-distribution = BetaPolicyDistribution(env)
-
-# the agent needs to create the model itself, so we build a method that builds a model
-build_models = get_model_builder(model="simple", model_type="ffn", shared=False)
-
-# given the model builder and the environment we can create an agent
-agent = PPOAgent(build_models, env, horizon=1024, workers=12, distribution=distribution)
-
-# let's check the agents ID, so we can find its saved states after training
-print(f"My Agent's ID: {agent.agent_id}")
-
-# ... and then train that agent for n cycles
-agent.drill(n=100, epochs=3, batch_size=64)
-
-# after training, we can save the agent for analysis or the like
-agent.save_agent_state()
+env = make_env("LunarLanderContinuous-v2")
+model_builder = get_model_builder("simple", "ffn")
+agent = PPOAgent(model_builder, env)
+agent.drill(100, 10, 512)
 ```
 
 For more details, consult the [examples](examples).
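
Note on the simplified example introduced above: `agent.drill(100, 10, 512)` passes its arguments positionally. Below is a minimal sketch of the same script with explicit keywords, assuming `drill()`'s positional order is `(n, epochs, batch_size)` as in the keyword call removed by this diff, and reusing `agent_id` and `save_agent_state()` from the removed lines:

```python
from angorapy.common.wrappers import make_env
from angorapy.models import get_model_builder
from angorapy.agent.ppo_agent import PPOAgent

env = make_env("LunarLanderContinuous-v2")
model_builder = get_model_builder("simple", "ffn")
agent = PPOAgent(model_builder, env)

# print the agent's ID so its saved state can be located after training
print(f"My Agent's ID: {agent.agent_id}")

# assumption: positional order is (n, epochs, batch_size),
# matching the keyword call removed by this diff
agent.drill(n=100, epochs=10, batch_size=512)

# persist the trained agent for later analysis
agent.save_agent_state()
```

Printing `agent.agent_id` before training makes it easy to find the saved state afterwards; this is the ID referenced at the top of this hunk.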