Merge pull request #48 from itsknk/intel-mac
Implement dynamic inference engine selection #45
AlexCheema authored Jul 22, 2024
2 parents 1fcbe18 + e934664 commit 2e419ba
Showing 2 changed files with 32 additions and 27 deletions.
exo/helpers.py: 26 changes (25 additions, 1 deletion)
@@ -3,6 +3,8 @@
 from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, Tuple
 import socket
 import random
+import platform
+import psutil
 
 DEBUG = int(os.getenv("DEBUG", default="0"))
 DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -15,6 +17,28 @@
    \___/_/\_\___/
 """
 
+def get_system_info():
+    if psutil.MACOS:
+        if platform.machine() == 'arm64':
+            return "Apple Silicon Mac"
+        elif platform.machine() in ['x86_64', 'i386']:
+            return "Intel Mac"
+        else:
+            return "Unknown Mac architecture"
+    elif psutil.LINUX:
+        return "Linux"
+    else:
+        return "Non-Mac, non-Linux system"
+
+def get_inference_engine():
+    system_info = get_system_info()
+    if system_info == "Apple Silicon Mac":
+        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+        return MLXDynamicShardInferenceEngine()
+    else:
+        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+        return TinygradDynamicShardInferenceEngine()
+
 def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int:
     used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports')
 
@@ -113,4 +137,4 @@ def trigger(self, name: K, *args: T) -> None:
 
     def trigger_all(self, *args: T) -> None:
         for callback in self.callbacks.values():
-            callback.set(*args)
\ No newline at end of file
+            callback.set(*args)
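
For reference, a minimal smoke-test sketch of the two new helpers (not part of this PR; the expected values assume an Apple Silicon machine):

    # hypothetical check, run from the repo root
    from exo.helpers import get_system_info, get_inference_engine

    print(get_system_info())  # "Apple Silicon Mac" here; "Intel Mac" or "Linux" elsewhere
    engine = get_inference_engine()
    print(engine.__class__.__name__)  # "MLXDynamicShardInferenceEngine" on Apple Silicon,
                                      # "TinygradDynamicShardInferenceEngine" otherwise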
main.py: 33 changes (7 additions, 26 deletions)
@@ -2,16 +2,13 @@
 import asyncio
 import signal
 import uuid
-import platform
-import psutil
-import os
 from typing import List
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -28,33 +25,17 @@
 args = parser.parse_args()
 
 print_yellow_exo()
-print(f"Starting exo {platform.system()=} {psutil.virtual_memory()=}")
-if args.inference_engine is None:
-    if psutil.MACOS:
-        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-        inference_engine = MLXDynamicShardInferenceEngine()
-    else:
-        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-        import tinygrad.helpers
-        tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-        inference_engine = TinygradDynamicShardInferenceEngine()
-else:
-    if args.inference_engine == "mlx":
-        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-        inference_engine = MLXDynamicShardInferenceEngine()
-    elif args.inference_engine == "tinygrad":
-        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-        import tinygrad.helpers
-        tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-        inference_engine = TinygradDynamicShardInferenceEngine()
-    else:
-        raise ValueError(f"Inference engine {args.inference_engine} not supported")
-print(f"Using inference engine {inference_engine.__class__.__name__}")
+
+system_info = get_system_info()
+print(f"Detected system: {system_info}")
+
+inference_engine = get_inference_engine()
+print(f"Using inference engine: {inference_engine.__class__.__name__}")
 
 if args.node_port is None:
     args.node_port = find_available_port(args.node_host)
     if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
 node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}", disable_tui=args.disable_tui)
 server = GRPCServer(node, args.node_host, args.node_port)
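
Moving the platform check into exo.helpers keeps main.py free of per-platform imports. With this change, startup output on an Intel Mac would begin with something like the following (a sketch inferred from the two print calls above, after the ASCII banner):

    Detected system: Intel Mac
    Using inference engine: TinygradDynamicShardInferenceEngine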
