Skip to content

Commit

Permalink
disable tui flag
Browse files Browse the repository at this point in the history
  • Loading branch information
apotl committed Jul 20, 2024
1 parent 821f114 commit db583a8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
10 changes: 6 additions & 4 deletions exo/orchestration/standard_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from exo.viz.topology_viz import TopologyViz

class StandardNode(Node):
def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, max_generate_tokens: int = 256, chatgpt_api_endpoint: Optional[str] = None, web_chat_url: Optional[str] = None):
def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, max_generate_tokens: int = 256, chatgpt_api_endpoint: Optional[str] = None, web_chat_url: Optional[str] = None, disable_tui: Optional[bool] = False):
self.id = id
self.inference_engine = inference_engine
self.server = server
Expand All @@ -25,7 +25,7 @@ def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, d
self.topology: Topology = Topology()
self.device_capabilities = device_capabilities()
self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
self.topology_viz = TopologyViz(chatgpt_api_endpoint=chatgpt_api_endpoint, web_chat_url=web_chat_url)
self.topology_viz = TopologyViz(chatgpt_api_endpoint=chatgpt_api_endpoint, web_chat_url=web_chat_url) if not disable_tui else None
self.max_generate_tokens = max_generate_tokens
self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
self._on_opaque_status = AsyncCallbackSystem[str, str]()
Expand All @@ -40,7 +40,8 @@ def on_node_status(self, request_id, opaque_status):
elif status_data.get("status", "").startswith("end_"):
if status_data.get("node_id") == self.current_topology.active_node_id:
self.current_topology.active_node_id = None
self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
if self.topology_viz:
self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
except json.JSONDecodeError:
pass

Expand Down Expand Up @@ -242,7 +243,8 @@ async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4)

next_topology.active_node_id = self.topology.active_node_id # this is not so clean.
self.topology = next_topology
self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
if self.topology_viz:
self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
return next_topology

# TODO: unify this and collect_topology as global actions
Expand Down
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90, help="ChatGPT API response timeout in seconds")
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
args = parser.parse_args()

print_yellow_exo()
Expand Down Expand Up @@ -55,7 +56,7 @@
args.node_port = find_available_port(args.node_host)
if DEBUG >= 1: print(f"Using available port: {args.node_port}")
discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}")
node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}", disable_tui=args.disable_tui)
server = GRPCServer(node, args.node_host, args.node_port)
node.server = server
api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
Expand Down

0 comments on commit db583a8

Please sign in to comment.