diff --git a/rllib/BUILD b/rllib/BUILD
index 0481d9c46a5a..c84b01a74856 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -2973,15 +2973,15 @@ py_test(
 
 # subdirectory: ray_serve/
 # ....................................
-# TODO (sven): Uncomment once the problem with the path on BAZEL is solved.
-# py_test(
-#     name = "examples/ray_serve/ray_serve_with_rllib",
-#     main = "examples/ray_serve/ray_serve_with_rllib.py",
-#     tags = ["team:rllib", "exclusive", "examples"],
-#     size = "medium",
-#     srcs = ["examples/ray_serve/ray_serve_with_rllib.py"],
-#     args = ["--train-iters=2", "--serve-episodes=2", "--no-render"]
-# )
+py_test(
+    name = "examples/ray_serve/ray_serve_with_rllib",
+    main = "examples/ray_serve/ray_serve_with_rllib.py",
+    tags = ["team:rllib", "exclusive", "examples"],
+    size = "medium",
+    srcs = ["examples/ray_serve/ray_serve_with_rllib.py"],
+    data = glob(["examples/ray_serve/classes/**"]),
+    args = ["--stop-iters=2", "--num-episodes-served=2", "--no-render", "--port=12345"]
+)
 
 # subdirectory: ray_tune/
 # ....................................
diff --git a/rllib/examples/ray_serve/classes/cartpole_deployment.py b/rllib/examples/ray_serve/classes/cartpole_deployment.py
index a58580a7d1e4..41686306c095 100644
--- a/rllib/examples/ray_serve/classes/cartpole_deployment.py
+++ b/rllib/examples/ray_serve/classes/cartpole_deployment.py
@@ -1,17 +1,17 @@
 import json
 from typing import Dict
 
+import numpy as np
 from starlette.requests import Request
+import torch
 
 from ray import serve
+from ray.rllib.core import Columns
+from ray.rllib.core.rl_module.rl_module import RLModule
 from ray.serve.schema import LoggingConfig
-from ray.rllib.algorithms.algorithm import Algorithm
 
 
-@serve.deployment(
-    route_prefix="/rllib-rlmodule",
-    logging_config=LoggingConfig(log_level="WARN"),
-)
+@serve.deployment(logging_config=LoggingConfig(log_level="WARN"))
 class ServeRLlibRLModule:
     """Callable class used by Ray Serve to handle async requests.
 
@@ -21,8 +21,8 @@ class ServeRLlibRLModule:
     (with a current observation).
     """
 
-    def __init__(self, checkpoint):
-        self.algo = Algorithm.from_checkpoint(checkpoint)
+    def __init__(self, rl_module_checkpoint):
+        self.rl_module = RLModule.from_checkpoint(rl_module_checkpoint)
 
     async def __call__(self, starlette_request: Request) -> Dict:
         request = await starlette_request.body()
@@ -30,13 +30,21 @@ async def __call__(self, starlette_request: Request) -> Dict:
         request = json.loads(request)
         obs = request["observation"]
 
-        # Compute and return the action for the given observation.
-        action = self.algo.compute_single_action(obs)
+        # Compute and return the action for the given observation (create a batch
+        # with B=1 and convert to torch).
+        output = self.rl_module.forward_inference(
+            batch={"obs": torch.from_numpy(np.array([obs], np.float32))}
+        )
+        # Extract action logits and unbatch.
+        logits = output[Columns.ACTION_DIST_INPUTS][0]
+        # Act greedily (argmax).
+        action = int(np.argmax(logits))
 
-        return {"action": int(action)}
+        return {"action": action}
 
 
 # Defining the builder function. This is so we can start our deployment via:
 # `serve run [this py module]:rl_module checkpoint=[some algo checkpoint path]`
 def rl_module(args: Dict[str, str]):
-    return ServeRLlibRLModule.bind(args["checkpoint"])
+    serve.start(http_options={"host": "0.0.0.0", "port": args.get("port", 12345)})
+    return ServeRLlibRLModule.bind(args["rl_module_checkpoint"])
diff --git a/rllib/examples/ray_serve/ray_serve_with_rllib.py b/rllib/examples/ray_serve/ray_serve_with_rllib.py
index 6001865a5544..0853151f40fa 100644
--- a/rllib/examples/ray_serve/ray_serve_with_rllib.py
+++ b/rllib/examples/ray_serve/ray_serve_with_rllib.py
@@ -1,55 +1,105 @@
-"""This example script shows how one can use Ray Serve in combination with RLlib.
+"""Example on how to run RLlib in combination with Ray Serve.
 
-Here, we serve an already trained PyTorch RLModule to provide action computations
-to a Ray Serve client.
-"""
-import argparse
-import atexit
-import os
+This example trains an agent with PPO on the CartPole environment, then creates
+an RLModule checkpoint and returns its location. After that, it sends the checkpoint
+to the Serve deployment for serving the trained RLModule (policy).
 
-import requests
-import subprocess
-import time
+This example:
+    - shows how to set up a Ray Serve deployment for serving an already trained
+    RLModule (policy network).
+    - shows how to request new actions from the Ray Serve deployment while actually
+    running through episodes in an environment (on which the RLModule that's served
+    was trained).
 
-import gymnasium as gym
-from pathlib import Path
-import ray
-from ray.rllib.algorithms.algorithm import AlgorithmConfig
-from ray.rllib.algorithms.ppo import PPOConfig
+
+How to run this script
+----------------------
+`python [script file name].py --enable-new-api-stack --stop-reward=200.0`
 
-parser = argparse.ArgumentParser()
-parser.add_argument("--train-iters", type=int, default=3)
-parser.add_argument("--serve-episodes", type=int, default=2)
-parser.add_argument("--no-render", action="store_true")
+Use the `--stop-iters`, `--stop-reward`, and/or `--stop-timesteps` options to
+determine how long to train the policy for. Use the `--num-episodes-served` option to
+set the number of episodes to serve (after training) and the `--no-render` option
+to NOT render the environment during the serving phase.
 
+For debugging, use the following additional command line options
+`--no-tune --num-env-runners=0`
+which should allow you to set breakpoints anywhere in the RLlib code and
+have the execution stop there for inspection and debugging.
 
-def train_rllib_rl_module(config: AlgorithmConfig, train_iters: int = 1):
-    """Trains a PPO (RLModule) on ALE/MsPacman-v5 for n iterations.
+For logging to your WandB account, use:
+`--wandb-key=[your WandB API key] --wandb-project=[some project name]
+--wandb-run-name=[optional: WandB run name (within the defined project)]`
 
-    Saves the trained Algorithm to disk and returns the checkpoint path.
+You can visualize experiment results in ~/ray_results using TensorBoard.
 
-    Args:
-        config: The algo config object for the Algorithm.
-        train_iters: For how many iterations to train the Algorithm.
 
-    Returns:
-        str: The saved checkpoint to restore the RLModule from.
-    """
-    # Create algorithm from config.
-    algo = config.build()
+Results to expect
+-----------------
 
-    # Train for n iterations, then save, stop, and return the checkpoint path.
-    for _ in range(train_iters):
-        print(algo.train())
+You should see something similar to the following on the command line when using the
+options: `--stop-reward=250.0`, `--num-episodes-served=2`, and `--port=12345`:
 
-    # TODO (sven): Change this example to only storing the RLModule checkpoint, NOT
-    # the entire Algorithm.
-    checkpoint_result = algo.save()
+[First, the RLModule is trained through PPO]
 
-    algo.stop()
++-----------------------------+------------+-----------------+--------+
+| Trial name                  | status     | loc             |   iter |
+|                             |            |                 |        |
+|-----------------------------+------------+-----------------+--------+
+| PPO_CartPole-v1_84778_00000 | TERMINATED | 127.0.0.1:40411 |      1 |
++-----------------------------+------------+-----------------+--------+
++------------------+---------------------+------------------------+
+| total time (s)   | episode_return_mean | num_env_steps_sample   |
+|                  |                     | d_lifetime             |
+|------------------+---------------------|------------------------|
+| 2.87052          | 253.2               | 12000                  |
++------------------+---------------------+------------------------+
 
-    return checkpoint_result.checkpoint
+[The RLModule is deployed through Ray Serve on port 12345]
+
+Started Ray Serve with PID: 40458
+
+[A few episodes are played through using the policy service (w/ greedy, non-exploratory
+actions)]
+
+Episode R=500.0
+Episode R=500.0
+"""
+
+import atexit
+import os
+
+import requests
+import subprocess
+import time
+
+import gymnasium as gym
+from pathlib import Path
+
+from ray.rllib.algorithms.ppo import PPOConfig
+from ray.rllib.core import (
+    COMPONENT_LEARNER_GROUP,
+    COMPONENT_LEARNER,
+    COMPONENT_RL_MODULE,
+    DEFAULT_MODULE_ID,
+)
+from ray.rllib.utils.metrics import (
+    ENV_RUNNER_RESULTS,
+    EPISODE_RETURN_MEAN,
+)
+from ray.rllib.utils.test_utils import (
+    add_rllib_example_script_args,
+    run_rllib_example_script_experiment,
+)
+
+parser = add_rllib_example_script_args()
+parser.set_defaults(
+    enable_new_api_stack=True,
+    checkpoint_freq=1,
+    checkpoint_at_end=True,
+)
+parser.add_argument("--num-episodes-served", type=int, default=2)
+parser.add_argument("--no-render", action="store_true")
+parser.add_argument("--port", type=int, default=12345)
 
 
 def kill_proc(proc):
@@ -64,18 +114,23 @@ def kill_proc(proc):
 
 if __name__ == "__main__":
     args = parser.parse_args()
 
-    ray.init(num_cpus=8)
-    # Config for the served RLlib RLModule/Algorithm.
-    config = (
-        PPOConfig()
-        .api_stack(enable_rl_module_and_learner=True)
-        .environment("CartPole-v1")
+    base_config = PPOConfig().environment("CartPole-v1")
+
+    results = run_rllib_example_script_experiment(base_config, args)
+    algo_checkpoint = results.get_best_result(
+        f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}"
+    ).checkpoint.path
+    # We only need the RLModule component from the algorithm checkpoint. It's located
+    # under "[algo checkpoint dir]/learner_group/learner/rl_module/[default policy ID]".
+    rl_module_checkpoint = (
+        Path(algo_checkpoint)
+        / COMPONENT_LEARNER_GROUP
+        / COMPONENT_LEARNER
+        / COMPONENT_RL_MODULE
+        / DEFAULT_MODULE_ID
     )
-
-    # Train the Algorithm for some time, then save it and get the checkpoint path.
-    checkpoint = train_rllib_rl_module(config, train_iters=args.train_iters)
-
     path_of_this_file = Path(__file__).parent
     os.chdir(path_of_this_file)
 
     # Start the serve app with the trained checkpoint.
@@ -84,7 +139,9 @@ def kill_proc(proc):
             "serve",
             "run",
             "classes.cartpole_deployment:rl_module",
-            f"checkpoint={checkpoint.path}",
+            f"rl_module_checkpoint={rl_module_checkpoint}",
+            f"port={args.port}",
+            "route_prefix=/rllib-rlmodule",
         ]
     )
     # Register our `kill_proc` function to be called on exit to stop Ray Serve again.
@@ -97,35 +154,34 @@ def kill_proc(proc):
     # Create the environment that we would like to receive
     # served actions for.
     env = gym.make("CartPole-v1", render_mode="human")
-    obs, info = env.reset()
+    obs, _ = env.reset()
 
     num_episodes = 0
     episode_return = 0.0
 
-    while num_episodes < args.serve_episodes:
+    while num_episodes < args.num_episodes_served:
         # Render env if necessary.
         if not args.no_render:
             env.render()
 
-        # print("-> Requesting action for obs ...")
+        # print(f"-> Requesting action for obs={obs} ...", end="")
 
         # Send a request to serve.
         resp = requests.get(
-            "http://localhost:8000/rllib-rlmodule",
+            f"http://localhost:{args.port}/rllib-rlmodule",
             json={"observation": obs.tolist()},
-            # timeout=5.0,
         )
         response = resp.json()
-        # print("<- Received response {}".format(response))
+        # print(f" received: action={response['action']}")
 
         # Apply the action in the env.
         action = response["action"]
-        obs, reward, done, _, _ = env.step(action)
+        obs, reward, terminated, truncated, _ = env.step(action)
         episode_return += reward
 
         # If episode done -> reset to get initial observation of new episode.
-        if done:
+        if terminated or truncated:
            print(f"Episode R={episode_return}")
-            obs, info = env.reset()
+            obs, _ = env.reset()
            num_episodes += 1
            episode_return = 0.0
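
For a quick manual check of the deployment this patch sets up (outside the example script itself), a client along the following lines should work, assuming the Serve app has already been started via `serve run classes.cartpole_deployment:rl_module rl_module_checkpoint=<RLModule checkpoint dir> port=12345 route_prefix=/rllib-rlmodule`; the port, route prefix, and JSON request/response format are taken from the example above, everything else is illustrative:

import gymnasium as gym
import requests

# Assumes the Serve app from this example is already running locally on port 12345.
env = gym.make("CartPole-v1")
obs, _ = env.reset()
# Send a single observation; the deployment responds with the greedy action.
resp = requests.get(
    "http://localhost:12345/rllib-rlmodule",
    json={"observation": obs.tolist()},
)
print("Served action:", resp.json()["action"])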