diff --git a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py
index 31969e801..526688558 100644
--- a/exo/orchestration/standard_node.py
+++ b/exo/orchestration/standard_node.py
@@ -15,7 +15,7 @@
 from exo.viz.topology_viz import TopologyViz
 
 class StandardNode(Node):
-    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, max_generate_tokens: int = 256, chatgpt_api_endpoint: Optional[str] = None, web_chat_url: Optional[str] = None):
+    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, max_generate_tokens: int = 256, chatgpt_api_endpoint: Optional[str] = None, web_chat_url: Optional[str] = None, disable_tui: Optional[bool] = False):
         self.id = id
         self.inference_engine = inference_engine
         self.server = server
@@ -25,7 +25,7 @@ def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, d
         self.topology: Topology = Topology()
         self.device_capabilities = device_capabilities()
         self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-        self.topology_viz = TopologyViz(chatgpt_api_endpoint=chatgpt_api_endpoint, web_chat_url=web_chat_url)
+        self.topology_viz = TopologyViz(chatgpt_api_endpoint=chatgpt_api_endpoint, web_chat_url=web_chat_url) if not disable_tui else None
         self.max_generate_tokens = max_generate_tokens
         self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
         self._on_opaque_status = AsyncCallbackSystem[str, str]()
@@ -40,7 +40,8 @@ def on_node_status(self, request_id, opaque_status):
                 elif status_data.get("status", "").startswith("end_"):
                     if status_data.get("node_id") == self.current_topology.active_node_id:
                         self.current_topology.active_node_id = None
-            self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
+            if self.topology_viz:
+                self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
         except json.JSONDecodeError:
             pass
 
@@ -242,7 +243,8 @@ async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4)
 
         next_topology.active_node_id = self.topology.active_node_id # this is not so clean.
         self.topology = next_topology
-        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
+        if self.topology_viz:
+            self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
         return next_topology
 
     # TODO: unify this and collect_topology as global actions
diff --git a/main.py b/main.py
index 2758473a2..833f85426 100644
--- a/main.py
+++ b/main.py
@@ -24,6 +24,7 @@
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
+parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 args = parser.parse_args()
 
 print_yellow_exo()
@@ -55,7 +56,7 @@
     args.node_port = find_available_port(args.node_host)
     if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
-node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}")
+node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}", disable_tui=args.disable_tui)
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
 api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
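
Note (not part of the patch): a minimal standalone sketch of how the new flag parses. argparse.BooleanOptionalAction (Python 3.9+) registers both --disable-tui and --no-disable-tui, and leaves args.disable_tui as None when neither is passed, so the `if not disable_tui` check in StandardNode keeps the TUI on for both the unset case and an explicit --no-disable-tui.

    import argparse

    parser = argparse.ArgumentParser()
    # Same registration as the main.py hunk above.
    parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")

    print(parser.parse_args([]).disable_tui)                    # None  -> TUI on
    print(parser.parse_args(["--disable-tui"]).disable_tui)     # True  -> TUI off (topology_viz = None)
    print(parser.parse_args(["--no-disable-tui"]).disable_tui)  # False -> TUI on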