diff --git a/exo/helpers.py b/exo/helpers.py index 7ca2b1f07..97b1d8e8f 100644 --- a/exo/helpers.py +++ b/exo/helpers.py @@ -3,6 +3,8 @@ from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, Tuple import socket import random +import platform +import psutil DEBUG = int(os.getenv("DEBUG", default="0")) DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0")) @@ -15,6 +17,28 @@ \___/_/\_\___/ """ +def get_system_info(): + if psutil.MACOS: + if platform.machine() == 'arm64': + return "Apple Silicon Mac" + elif platform.machine() in ['x86_64', 'i386']: + return "Intel Mac" + else: + return "Unknown Mac architecture" + elif psutil.LINUX: + return "Linux" + else: + return "Non-Mac, non-Linux system" + +def get_inference_engine(): + system_info = get_system_info() + if system_info == "Apple Silicon Mac": + from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine + return MLXDynamicShardInferenceEngine() + else: + from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine + return TinygradDynamicShardInferenceEngine() + def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int: used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports') @@ -113,4 +137,4 @@ def trigger(self, name: K, *args: T) -> None: def trigger_all(self, *args: T) -> None: for callback in self.callbacks.values(): - callback.set(*args) \ No newline at end of file + callback.set(*args) diff --git a/main.py b/main.py index 833f85426..3864a859d 100644 --- a/main.py +++ b/main.py @@ -2,16 +2,13 @@ import asyncio import signal import uuid -import platform -import psutil -import os from typing import List from exo.orchestration.standard_node import StandardNode from exo.networking.grpc.grpc_server import GRPCServer from exo.networking.grpc.grpc_discovery import GRPCDiscovery from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy from exo.api import ChatGPTAPI -from exo.helpers import print_yellow_exo, find_available_port, DEBUG +from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info # parse args parser = argparse.ArgumentParser(description="Initialize GRPC Discovery") @@ -28,33 +25,17 @@ args = parser.parse_args() print_yellow_exo() -print(f"Starting exo {platform.system()=} {psutil.virtual_memory()=}") -if args.inference_engine is None: - if psutil.MACOS: - from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine - inference_engine = MLXDynamicShardInferenceEngine() - else: - from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine - import tinygrad.helpers - tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0")) - inference_engine = TinygradDynamicShardInferenceEngine() -else: - if args.inference_engine == "mlx": - from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine - inference_engine = MLXDynamicShardInferenceEngine() - elif args.inference_engine == "tinygrad": - from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine - import tinygrad.helpers - tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0")) - inference_engine = TinygradDynamicShardInferenceEngine() - else: - raise ValueError(f"Inference engine {args.inference_engine} not supported") -print(f"Using inference engine {inference_engine.__class__.__name__}") +system_info = get_system_info() +print(f"Detected system: {system_info}") + +inference_engine = get_inference_engine() +print(f"Using inference engine: {inference_engine.__class__.__name__}") if args.node_port is None: args.node_port = find_available_port(args.node_host) if DEBUG >= 1: print(f"Using available port: {args.node_port}") + discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port) node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}", disable_tui=args.disable_tui) server = GRPCServer(node, args.node_host, args.node_port)