Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add p2p download functionality #127

Draft
wants to merge 1 commit into
base: refactor_model_download
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions exo/inference/hf_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
from pathlib import Path
from typing import Generator, Iterable, TypeVar, TypedDict
from dataclasses import dataclass
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
from exo.topology.topology import Topology
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from exo.helpers import DEBUG


T = TypeVar("T")
def filter_repo_objects(
items: Iterable[T],
Expand Down Expand Up @@ -199,6 +202,14 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, speed, eta, status))
if DEBUG >= 2: print(f"Downloaded: {file_path}")

async def download_file_from_peer(file_path: str, save_directory: str, progress_callback: Optional[HFRepoProgressCallback] = None):
topology = Topology()
node_id = await topology.broadcast_file_request(file_path)
if not node_id:
raise ValueError(f"No peer has the file {file_path}")
peer_handle = GRPCPeerHandle(node_id)
await peer_handle.download_file(file_path, save_directory, progress_callback)

async def download_all_files(repo_id: str, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
repo_root = get_repo_root(repo_id)
refs_dir = repo_root / "refs"
Expand All @@ -208,6 +219,14 @@ async def download_all_files(repo_id: str, revision: str = "main", progress_call
refs_dir.mkdir(parents=True, exist_ok=True)
snapshots_dir.mkdir(parents=True, exist_ok=True)

try:
await download_file_from_peer(repo_id, snapshots_dir, progress_callback)
return
except ValueError:
if DEBUG >= 2: print("File not found with peers, downloading from HUB")

# Else download the file from HUB

async with aiohttp.ClientSession() as session:
# Fetch the commit hash for the given revision
api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
Expand Down
42 changes: 34 additions & 8 deletions exo/networking/grpc/grpc_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .grpc_peer_handle import GRPCPeerHandle
from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
from exo import DEBUG_DISCOVERY
from exo.topology.topology import Topology


class ListenProtocol(asyncio.DatagramProtocol):
Expand Down Expand Up @@ -45,6 +46,7 @@ def __init__(
self.listen_task = None
self.cleanup_task = None
self.discovery_timeout = discovery_timeout
self.topology = Topology()

async def start(self):
self.device_capabilities = device_capabilities()
Expand Down Expand Up @@ -97,14 +99,12 @@ async def task_broadcast_presence(self):
sock = transport.get_extra_info("socket")
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

message = json.dumps(
{
"type": "discovery",
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict(),
}
).encode("utf-8")
message = json.dumps({
"type": "discovery",
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict(),
}).encode("utf-8")

while True:
try:
Expand Down Expand Up @@ -161,6 +161,32 @@ async def task_listen_for_peers(self):
if DEBUG_DISCOVERY >= 2:
print("Started listen task")

async def broadcast_file_request(self, file_path: str) -> List[str]:
"""
Sends a file request to all known peers and gathers their responses.
Returns a list of node IDs that have the file.
"""
nodes_with_file = []
for peer_id, (peer_handle, _, _) in self.known_peers.items():
has_file = await peer_handle.check_file(file_path)
if has_file:
self.topology.update_file_ownership(peer_id, file_path)
nodes_with_file.append(peer_id)
return nodes_with_file

async def download_from_peer(self, file_path: str, save_directory: str):
"""
Initiates a file download from a peer that has the file.
If no peer has the file, raises a ValueError.
"""
nodes_with_file = await self.broadcast_file_request(file_path)
if not nodes_with_file:
raise ValueError(f"No peer has the file {file_path}")
# Choose the first node with the file to download from
node_id = nodes_with_file[0]
peer_handle = GRPCPeerHandle(node_id)
await peer_handle.download_file(file_path, save_directory)

async def task_cleanup_peers(self):
while True:
try:
Expand Down
40 changes: 38 additions & 2 deletions exo/topology/topology.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from .device_capabilities import DeviceCapabilities
from typing import Dict, Set, Optional
from typing import Dict, Optional, Set
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
from exo.topology.device_capabilities import DeviceCapabilities


class Topology:
def __init__(self):
self.nodes: Dict[str, DeviceCapabilities] = {} # Maps node IDs to DeviceCapabilities
self.peer_graph: Dict[str, Set[str]] = {} # Adjacency list representing the graph
self.active_node_id: Optional[str] = None
self.file_ownership: Dict[str, Set[str]] = {} # Maps file paths to node IDs

def update_node(self, node_id: str, device_capabilities: DeviceCapabilities):
self.nodes[node_id] = device_capabilities
Expand Down Expand Up @@ -47,3 +49,37 @@ def __str__(self):
nodes_str = ", ".join(f"{node_id}: {cap}" for node_id, cap in self.nodes.items())
edges_str = ", ".join(f"{node}: {neighbors}" for node, neighbors in self.peer_graph.items())
return f"Topology(Nodes: {{{nodes_str}}}, Edges: {{{edges_str}}})"

def update_file_ownership(self, node_id: str, file_path: str):
"""
Updates the file ownership dictionary with the given node ID and file path.
If the file path does not exist, it is added.
"""
if file_path not in self.file_ownership:
self.file_ownership[file_path] = set()
self.file_ownership[file_path].add(node_id)

async def send_broadcast_request(self, node_id: str, file_path: str) -> Optional[str]:
"""
Sends a file request to the specified node using gRPC.
If the node has the file, returns the node ID. Otherwise, returns None.
"""
peer_handle = GRPCPeerHandle(node_id)
has_file = await peer_handle.check_file(file_path)
if has_file:
self.update_file_ownership(node_id, file_path)
return node_id
return None

async def download_from_peer(self, file_path: str, save_directory: str):
"""
Initiates a file download from a peer that has the file.
If no peer has the file, raises a ValueError.
"""
nodes_with_file = await self.broadcast_file_request(file_path)
if not nodes_with_file:
raise ValueError(f"No peer has the file {file_path}")
# Choose the first node with the file to download from
node_id = nodes_with_file[0]
peer_handle = GRPCPeerHandle(node_id)
await peer_handle.download_file(file_path, save_directory)