From e934664168e4edf924aa5cdd3ea4b788fd7135ed Mon Sep 17 00:00:00 2001
From: itsknk
Date: Sat, 20 Jul 2024 08:05:34 -0700
Subject: [PATCH] implement dynamic inference engine selection

implement the system detection and inference engine selection

implement dynamic inference engine selection

implement dynamic inference engine selection

implement dynamic inference engine selection

remove inconsistency

implement dynamic inference engine selection
---
 exo/helpers.py | 26 +++++++++++++++++++++++++-
 main.py        | 33 +++++++--------------------------
 2 files changed, 32 insertions(+), 27 deletions(-)

diff --git a/exo/helpers.py b/exo/helpers.py
index 7ca2b1f07..97b1d8e8f 100644
--- a/exo/helpers.py
+++ b/exo/helpers.py
@@ -3,6 +3,8 @@
 from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, Tuple
 import socket
 import random
+import platform
+import psutil
 
 DEBUG = int(os.getenv("DEBUG", default="0"))
 DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -15,6 +17,28 @@
   \___/_/\_\___/
     """
 
+def get_system_info():
+    if psutil.MACOS:
+        if platform.machine() == 'arm64':
+            return "Apple Silicon Mac"
+        elif platform.machine() in ['x86_64', 'i386']:
+            return "Intel Mac"
+        else:
+            return "Unknown Mac architecture"
+    elif psutil.LINUX:
+        return "Linux"
+    else:
+        return "Non-Mac, non-Linux system"
+
+def get_inference_engine():
+    system_info = get_system_info()
+    if system_info == "Apple Silicon Mac":
+        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+        return MLXDynamicShardInferenceEngine()
+    else:
+        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+        return TinygradDynamicShardInferenceEngine()
+
 def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int:
     used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports')
 
@@ -113,4 +137,4 @@ def trigger(self, name: K, *args: T) -> None:
 
     def trigger_all(self, *args: T) -> None:
         for callback in self.callbacks.values():
-            callback.set(*args)
\ No newline at end of file
+            callback.set(*args)
diff --git a/main.py b/main.py
index 2758473a2..20cc15535 100644
--- a/main.py
+++ b/main.py
@@ -2,16 +2,13 @@
 import asyncio
 import signal
 import uuid
-import platform
-import psutil
-import os
 from typing import List
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -27,33 +24,17 @@
 args = parser.parse_args()
 
 print_yellow_exo()
-print(f"Starting exo {platform.system()=} {psutil.virtual_memory()=}")
-if args.inference_engine is None:
-    if psutil.MACOS:
-        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-        inference_engine = MLXDynamicShardInferenceEngine()
-    else:
-        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-        import tinygrad.helpers
-        tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-        inference_engine = TinygradDynamicShardInferenceEngine()
-else:
-    if args.inference_engine == "mlx":
-        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-        inference_engine = MLXDynamicShardInferenceEngine()
-    elif args.inference_engine == "tinygrad":
-        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-        import tinygrad.helpers
-        tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-        inference_engine = TinygradDynamicShardInferenceEngine()
-    else:
-        raise ValueError(f"Inference engine {args.inference_engine} not supported")
-print(f"Using inference engine {inference_engine.__class__.__name__}")
+system_info = get_system_info()
+print(f"Detected system: {system_info}")
+
+inference_engine = get_inference_engine()
+print(f"Using inference engine: {inference_engine.__class__.__name__}")
 
 if args.node_port is None:
     args.node_port = find_available_port(args.node_host)
     if DEBUG >= 1: print(f"Using available port: {args.node_port}")
+
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
 node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}")
 server = GRPCServer(node, args.node_host, args.node_port)