Multinode batch inference #2

Open. Wants to merge 6 commits into base: main.
@@ -21,3 +21,4 @@ RUN pip install -r requirements.txt
COPY model-engine /workspace/model-engine
RUN pip install -e /workspace/model-engine
COPY model-engine/model_engine_server/inference/batch_inference/vllm_batch.py /workspace/vllm_batch.py
COPY model-engine/model_engine_server/inference/batch_inference/init_ray.py /workspace/init_ray.py
@@ -14,8 +14,14 @@ if [ -z "$2" ]; then
exit 1;
fi

if [ -z "$3" ]; then
echo "Must supply the repo name"
exit 1;
fi

REPO_NAME=$3
IMAGE_TAG=$2
ACCOUNT=$1
aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com
-DOCKER_BUILDKIT=1 docker build -t $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/llm-engine/batch-infer-vllm:$IMAGE_TAG -f Dockerfile_vllm ../../../../
-docker push $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/llm-engine/batch-infer-vllm:$IMAGE_TAG
+DOCKER_BUILDKIT=1 docker build -t $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/$REPO_NAME:$IMAGE_TAG -f Dockerfile_vllm ../../../../
+docker push $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/$REPO_NAME:$IMAGE_TAG
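A hypothetical invocation of the updated build script, now that the repo name is a required third argument. The script filename and all values below are placeholders, not taken from this PR; the command is only echoed as a dry run so the sketch is self-contained.

```shell
# Placeholder values; substitute a real AWS account ID, image tag, and ECR repo name.
ACCOUNT=123456789012
IMAGE_TAG=v0.1.0
REPO_NAME=llm-engine/batch-infer-vllm
echo "./build_and_upload_image.sh $ACCOUNT $IMAGE_TAG $REPO_NAME"
```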
@@ -0,0 +1,81 @@
# TODO: this should initialize multinode Ray; it should look similar to the setup for multinode online serving.
# TODO: a few things differ here; we need to inspect several env vars to determine (1) whether we're the leader, and (2) the Ray address, port, cluster_size, and own_address.
# In one case, we have JOB_COMPLETION_INDEX, NUM_INSTANCES, MASTER_ADDR, MASTER_PORT as available env vars.
# (May) need to get own_address from somewhere. In serving, it comes from a few env vars and is a k8s DNS name.

import argparse
import os
import subprocess
import sys
import time

RAY_INIT_TIMEOUT = 1200


def start_worker(ray_address, ray_port, ray_init_timeout):
for i in range(0, ray_init_timeout, 5):
result = subprocess.run(
[
"ray",
"start",
"--address",
f"{ray_address}:{ray_port}",
"--block",
# "--node-ip-address",
# own_address,
],
capture_output=True,
)
if result.returncode == 0:
print(f"Worker: Ray runtime started with head address {ray_address}:{ray_port}")
sys.exit(0)
print(result.returncode)
[review comment] style: Print the error message from the result object for better debugging.

print("Waiting until the ray worker is active...")
time.sleep(5)
print(f"Ray worker start timed out, head address: {ray_address}:{ray_port}")
sys.exit(1)
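The review comment above suggests surfacing the error text rather than only the return code. A minimal sketch of that change, using a deliberately failing stand-in command instead of `ray start` so the snippet is self-contained:

```python
import subprocess

# Stand-in for a failing `ray start` attempt: it writes to stderr and
# exits nonzero, purely for illustration.
result = subprocess.run(
    ["python3", "-c", "import sys; sys.stderr.write('boom'); sys.exit(1)"],
    capture_output=True,
    text=True,
)
if result.returncode != 0:
    # Print stderr alongside the return code for easier debugging.
    print(f"ray start failed (code {result.returncode}): {result.stderr}")
```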


def start_leader(ray_port, ray_cluster_size, ray_init_timeout):
subprocess.run(
["ray", "start", "--head", "--port", str(ray_port)] # , "--node-ip-address", own_address]
)
for i in range(0, ray_init_timeout, 5):
active_nodes = subprocess.run(
[
"python3",
"-c",
'import ray; ray.init(); print(sum(node["Alive"] for node in ray.nodes()))',
],
capture_output=True,
text=True,
)
active_nodes = int(active_nodes.stdout.strip())
if active_nodes == ray_cluster_size:
print("All ray workers are active and the ray cluster is initialized successfully.")
sys.exit(0)
print(f"Waiting for all ray workers to be active; {active_nodes}/{ray_cluster_size} active")
time.sleep(5)
print("Waiting for all ray workers to be active timed out.")
sys.exit(1)
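`start_leader` counts alive nodes by running a `ray.nodes()` one-liner in a subprocess. The counting logic itself can be seen in isolation on example node records shaped like the dicts `ray.nodes()` returns (the records below are fabricated for illustration):

```python
# Each dict mimics one entry from ray.nodes(); only the "Alive" key matters here.
nodes = [{"Alive": True}, {"Alive": False}, {"Alive": True}]
active_nodes = sum(node["Alive"] for node in nodes)
print(active_nodes)  # → 2
```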


def main():
parser = argparse.ArgumentParser(description="Ray cluster initialization script")
parser.add_argument("--ray_init_timeout", type=int, default=RAY_INIT_TIMEOUT)

args = parser.parse_args()

is_leader = os.getenv("JOB_COMPLETION_INDEX") == "0"
ray_address = os.getenv("MASTER_ADDR")
ray_port = os.getenv("MASTER_PORT")
ray_cluster_size = os.getenv("NUM_INSTANCES")
[review comment] logic: Ensure NUM_INSTANCES is converted to an integer.


if is_leader:
start_leader(ray_port, ray_cluster_size, args.ray_init_timeout)
else:
start_worker(ray_address, ray_port, args.ray_init_timeout)


if __name__ == "__main__":
main()
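A minimal sketch of the int-conversion fix suggested in the review: without it, the string from `os.getenv` can never compare equal to the integer node count. The env value is set in the snippet only so it is self-contained.

```python
import os

# Simulate the Kubernetes-provided env var, purely for illustration.
os.environ["NUM_INSTANCES"] = "4"

# Convert to int so comparisons like `active_nodes == ray_cluster_size` hold.
ray_cluster_size = int(os.getenv("NUM_INSTANCES", "1"))
print(ray_cluster_size)  # → 4
```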
@@ -312,12 +312,14 @@ def tool_func(text: str, past_context: Optional[str]):


async def batch_inference():
-job_index = int(os.getenv("JOB_COMPLETION_INDEX", 0))
+job_index = int(os.getenv("JOB_COMPLETION_INDEX", 0))  # TODO this conflicts with multinode
[review comment] logic: This TODO suggests a conflict with multinode setup. Consider resolving this before merging.


request = CreateBatchCompletionsEngineRequest.parse_file(CONFIG_FILE)

if request.model_cfg.checkpoint_path is not None:
-download_model(request.model_cfg.checkpoint_path, MODEL_WEIGHTS_FOLDER)
+download_model(
+    request.model_cfg.checkpoint_path, MODEL_WEIGHTS_FOLDER
+)  # TODO move this out
[review comment, on lines +320 to +322] style: Moving model download out of this function could improve performance for multinode setups.


content = request.content
if content is None:
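One way to act on the "move this out" TODO and the review comment is to gate the download on the node index so only the leader fetches weights. The sketch below is hypothetical: `maybe_download`, the checkpoint path, and the env value are placeholders, not code from this PR.

```python
import os

# Simulate the Kubernetes-provided completion index, purely for illustration.
os.environ["JOB_COMPLETION_INDEX"] = "0"

def maybe_download(checkpoint_path: str, dest: str) -> bool:
    """Download weights only on the leader node; workers skip the fetch."""
    if os.getenv("JOB_COMPLETION_INDEX", "0") == "0":
        print(f"leader downloads {checkpoint_path} -> {dest}")
        return True
    print("worker skips download")
    return False

did_download = maybe_download("s3://example-bucket/ckpt", "/workspace/model_files")
```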