async model downloading with download progress. fixes #102. related: #16
AlexCheema committed Jul 31, 2024
1 parent 5c67e24 commit d6a7e46
Showing 7 changed files with 78 additions and 21 deletions.
3 changes: 1 addition &amp; 2 deletions exo/helpers.py

@@ -1,7 +1,6 @@
 import os
 import asyncio
-from typing import Any, Callable, TypeVar, Optional, Dict, Generic, Tuple, List
-from collections import defaultdict
+from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
 import socket
 import random
 import platform
10 changes: 7 additions &amp; 3 deletions exo/inference/mlx/sharded_inference_engine.py

@@ -4,12 +4,13 @@
 from .sharded_model import StatefulShardedModel
 from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
-from typing import Optional
+from typing import Optional, Callable


 class MLXDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self):
+  def __init__(self, on_download_progress: Callable[[int, int], None] = None):
     self.shard = None
+    self.on_download_progress = on_download_progress

   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
@@ -32,6 +33,9 @@ async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
       return

-    model_shard, self.tokenizer = await load_shard(shard.model_id, shard)
+    model_shard, self.tokenizer = await load_shard(shard.model_id, shard, on_download_progress=self.on_download_progress)
     self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
     self.shard = shard
+
+  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
+    self.on_download_progress = on_download_progress
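
The engine accepts the callback two ways: through the constructor and through set_on_download_progress. A minimal sketch of wiring one up (the engine class is from this commit; the printing callback is illustrative, not part of the change):

from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine

engine = MLXDynamicShardInferenceEngine()

def print_progress(current: int, total: int) -> None:
  # current and total are byte counts reported by the download monitor
  print(f"downloaded {current}/{total} bytes ({current / total:.1%})")

engine.set_on_download_progress(print_progress)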

65 changes: 54 additions &amp; 11 deletions exo/inference/mlx/sharded_utils.py

@@ -8,16 +8,19 @@
 import aiohttp
 from functools import partial
 from pathlib import Path
-from typing import Optional, Tuple
+import requests
+from typing import Optional, Tuple, Union, List, Callable
 from PIL import Image
 from io import BytesIO
 import base64
+import os

 from exo import DEBUG
 import mlx.core as mx
 import mlx.nn as nn
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download, list_repo_tree, get_paths_info
+from huggingface_hub.utils import filter_repo_objects
+from huggingface_hub.file_download import repo_folder_name
+from huggingface_hub.constants import HF_HUB_CACHE
 from huggingface_hub.utils._errors import RepositoryNotFoundError
 from transformers import AutoProcessor

@@ -144,12 +147,50 @@ def load_model_shard(
   return model


-async def snapshot_download_async(*args, **kwargs):
-  func = partial(snapshot_download, *args, **kwargs)
-  return await asyncio.get_event_loop().run_in_executor(None, func)
-
-
-async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
+async def get_repo_size(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
+  it = await asyncio.to_thread(list_repo_tree, repo_id, revision=revision, repo_type=repo_type)
+  files = list(filter_repo_objects(it, allow_patterns=allow_patterns, key=lambda f: f.path))
+  return sum(file.size for file in files if file.size is not None)
+
+async def monitor_progress(dir, total_size, print_progress=False, on_progress: Callable[[int, int], None] = None):
+  while True:
+    await asyncio.sleep(0.1)
+    current_size = sum(os.path.getsize(os.path.join(root, file))
+                       for root, _, files in os.walk(dir)
+                       for file in files)
+    progress = min(current_size / total_size * 100, 100)
+    if print_progress:
+      print(f"\rProgress: {progress:.2f}% ({current_size}/{total_size} bytes)", end="", flush=True)
+    if on_progress:
+      on_progress(current_size, total_size)
+    if progress >= 100:
+      if print_progress:
+        print("\nDownload complete!")
+      break
+
+async def download_repo(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
+  # Use snapshot_download in a separate thread to not block the event loop
+  return await asyncio.to_thread(snapshot_download, repo_id=repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type)
+
+async def download_async_with_progress(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None, on_progress: Callable[[int, int], None] = None):
+  storage_folder = os.path.join(HF_HUB_CACHE, repo_folder_name(repo_id=repo_id, repo_type="model"))
+  # os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
+  # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
+
+  print(f"Estimating size of repository: {repo_id}")
+  total_size = await get_repo_size(repo_id)
+  print(f"Estimated total size: {total_size} bytes")
+
+  # Create tasks for download and progress checking
+  download_task = asyncio.create_task(download_repo(repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type))
+  progress_task = asyncio.create_task(monitor_progress(storage_folder, total_size, on_progress=on_progress))
+
+  # Wait for both tasks to complete
+  result = await asyncio.gather(download_task, progress_task)
+  return result[0]  # Return the result from download_task


+async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None, on_download_progress: Callable[[int, int], None] = None) -> Path:
   """
   Ensures the model is available locally. If the path does not exist locally,
   it is downloaded from the Hugging Face Hub.
@@ -165,7 +206,7 @@ async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -
   if not model_path.exists():
     try:
       model_path = Path(
-        await snapshot_download_async(
+        await download_async_with_progress(
           repo_id=path_or_hf_repo,
           revision=revision,
           allow_patterns=[
@@ -176,6 +217,7 @@ async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -
             "*.tiktoken",
             "*.txt",
           ],
+          on_progress=on_download_progress,
         )
       )
     except RepositoryNotFoundError:
@@ -196,6 +238,7 @@ async def load_shard(
   model_config={},
   adapter_path: Optional[str] = None,
   lazy: bool = False,
+  on_download_progress: Callable[[int, int], None] = None,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
   """
   Load the model and tokenizer from a given path or a huggingface repository.
@@ -218,7 +261,7 @@ async def load_shard(
     FileNotFoundError: If config file or safetensors are not found.
     ValueError: If model class or args class are not found.
   """
-  model_path = await get_model_path(path_or_hf_repo)
+  model_path = await get_model_path(path_or_hf_repo, on_download_progress=on_download_progress)

   model = load_model_shard(model_path, shard, lazy, model_config)
   if adapter_path is not None:
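
The design avoids hooking huggingface_hub internals: get_repo_size estimates the total from list_repo_tree metadata, snapshot_download runs in a worker thread, and monitor_progress polls the cache folder's on-disk size every 100 ms, so partially downloaded or resumed files count toward progress immediately. A minimal sketch of calling the new helper directly (the repo id, patterns, and callback are example values, not part of the commit):

import asyncio
from exo.inference.mlx.sharded_utils import download_async_with_progress

def report(current: int, total: int) -> None:
  print(f"{current}/{total} bytes")

async def main():
  # Example repo id and patterns; substitute any Hugging Face model repo.
  path = await download_async_with_progress(
    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
    allow_patterns=["*.json", "*.safetensors"],
    on_progress=report,
  )
  print(f"snapshot at {path}")

asyncio.run(main())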

8 changes: 7 additions &amp; 1 deletion exo/orchestration/standard_node.py

@@ -52,8 +52,13 @@ def on_node_status(self, request_id, opaque_status):
         elif status_data.get("status", "").startswith("end_"):
           if status_data.get("node_id") == self.current_topology.active_node_id:
             self.current_topology.active_node_id = None
+      download_progress = None
+      if status_data.get("type", "") == "download_progress":
+        if DEBUG >= 5: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('current')}/{status_data.get('total')} ({round(status_data.get('current') / status_data.get('total') * 100, 2)}%)")
+        if status_data.get("node_id") == self.id:
+          download_progress = (status_data.get('current'), status_data.get('total'))
       if self.topology_viz:
-        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
+        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), download_progress)
     except json.JSONDecodeError:
       pass

@@ -370,6 +375,7 @@ async def send_result_to_peer(peer):
     await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)

   async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
+    if DEBUG >= 5: print(f"Broadcasting opaque status: {request_id=} {status=}")
     async def send_status_to_peer(peer):
       try:
         await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
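
For reference, the opaque status handled above is a JSON string; a download_progress message has the shape produced in main.py below. A sketch with example values:

import json

# Example payload only; node_id and byte counts vary at runtime.
opaque_status = json.dumps({"type": "download_progress", "node_id": "node-a", "current": 52428800, "total": 4294967296})
status_data = json.loads(opaque_status)
print(f"{round(status_data['current'] / status_data['total'] * 100, 2)}%")  # 1.22%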

8 changes: 5 additions &amp; 3 deletions exo/viz/topology_viz.py

@@ -1,5 +1,5 @@
 import math
-from typing import List
+from typing import List, Optional, Tuple
 from exo.helpers import exo_text
 from exo.topology.topology import Topology
 from exo.topology.partitioning_strategy import Partition
@@ -17,22 +17,24 @@ def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None):
     self.web_chat_url = web_chat_url
     self.topology = Topology()
     self.partitions: List[Partition] = []
+    self.download_progress = None

     self.console = Console()
     self.panel = Panel(self._generate_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
     self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
     self.live_panel.start()

-  def update_visualization(self, topology: Topology, partitions: List[Partition]):
+  def update_visualization(self, topology: Topology, partitions: List[Partition], download_progress: Optional[Tuple[int, int]] = None):
     self.topology = topology
     self.partitions = partitions
+    self.download_progress = download_progress
     self.refresh()

   def refresh(self):
     self.panel.renderable = self._generate_layout()
     # Update the panel title with the number of nodes and partitions
     node_count = len(self.topology.nodes)
-    self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
+    self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''}){f' {self.download_progress[0]/self.download_progress[1]:.2%} Downloaded' if self.download_progress else ''}"
     self.live_panel.update(self.panel, refresh=True)

   def _generate_layout(self) -> str:
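
The title suffix uses Python's percent format spec on the (current, total) tuple. A quick sketch of how it renders, with example byte counts:

download_progress = (1073741824, 4294967296)  # example (current, total) bytes
suffix = f" {download_progress[0]/download_progress[1]:.2%} Downloaded" if download_progress else ""
print(f"Exo Cluster (2 nodes){suffix}")  # Exo Cluster (2 nodes) 25.00% Downloaded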

2 changes: 2 additions &amp; 0 deletions main.py

@@ -1,6 +1,7 @@
 import argparse
 import asyncio
 import signal
+import json
 import uuid
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
@@ -58,6 +59,7 @@
 if args.prometheus_client_port:
   from exo.stats.metrics import start_metrics_server
   start_metrics_server(node, args.prometheus_client_port)
+inference_engine.set_on_download_progress(lambda current, total: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "current": current, "total": total}))))

 async def shutdown(signal, loop):
   """Gracefully shutdown the server and close the asyncio loop."""
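
The lambda turns each callback into a cluster-wide broadcast without blocking the download monitor; asyncio.create_task schedules the send on the running loop. Unrolled for readability (equivalent to the line above; node and inference_engine are the objects constructed earlier in main.py):

def on_progress(current: int, total: int) -> None:
  status = json.dumps({"type": "download_progress", "node_id": node.id, "current": current, "total": total})
  # create_task schedules the broadcast without waiting for it to finish
  asyncio.create_task(node.broadcast_opaque_status("", status))

inference_engine.set_on_download_progress(on_progress)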

3 changes: 2 additions &amp; 1 deletion setup.py

@@ -9,7 +9,8 @@
   "blobfile==2.1.1",
   "grpcio==1.64.1",
   "grpcio-tools==1.64.1",
-  "huggingface-hub==0.23.4",
+  "hf-transfer==0.1.8",
+  "huggingface-hub==0.24.5",
   "Jinja2==3.1.4",
   "numpy==2.0.0",
   "pillow==10.4.0",
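
The new hf-transfer dependency pairs with the commented-out HF_HUB_ENABLE_HF_TRANSFER line in sharded_utils.py: when that environment variable is set to 1, huggingface_hub delegates downloads to the Rust-based hf_transfer backend for higher throughput.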
