From 0bc0b8f602591b2d8524e525ca8f1d24680f3d1b Mon Sep 17 00:00:00 2001 From: hammadb Date: Wed, 8 Jan 2025 19:30:28 -0800 Subject: [PATCH] round robin --- chromadb/execution/executor/distributed.py | 5 ++- chromadb/segment/distributed/__init__.py | 8 ++-- .../impl/distributed/segment_directory.py | 39 ++++++++++--------- chromadb/segment/impl/manager/distributed.py | 6 +-- 4 files changed, 32 insertions(+), 26 deletions(-) diff --git a/chromadb/execution/executor/distributed.py b/chromadb/execution/executor/distributed.py index 6bf795540c1..baa691840ef 100644 --- a/chromadb/execution/executor/distributed.py +++ b/chromadb/execution/executor/distributed.py @@ -1,3 +1,4 @@ +import random from typing import Dict, Optional import grpc from overrides import overrides @@ -161,7 +162,9 @@ def knn(self, plan: KNNPlan) -> QueryResult: def _grpc_executuor_stub(self, scan: Scan) -> QueryExecutorStub: # Since grpc endpoint is endpoint is determined by collection uuid, # the endpoint should be the same for all segments of the same collection - grpc_url = self._manager.get_endpoint(scan.record) + # TODO: configure the number of endpoints to fetch + grpc_urls = self._manager.get_endpoints(scan.record, 3) + grpc_url = grpc_urls[random.randint(0, len(grpc_urls) - 1)] if grpc_url not in self._grpc_stub_pool: channel = grpc.insecure_channel(grpc_url) interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()] diff --git a/chromadb/segment/distributed/__init__.py b/chromadb/segment/distributed/__init__.py index 75d602d6f55..beb48102ac6 100644 --- a/chromadb/segment/distributed/__init__.py +++ b/chromadb/segment/distributed/__init__.py @@ -9,11 +9,13 @@ class SegmentDirectory(Component): """A segment directory is a data interface that manages the location of segments. 
Concretely, this - means that for clustered chroma, it provides the grpc endpoint for a segment.""" + means that for distributed chroma, it provides the grpc endpoint for a segment.""" @abstractmethod - def get_segment_endpoint(self, segment: Segment) -> str: - """Return the segment residence for a given segment ID""" + def get_segment_endpoints(self, segment: Segment, n: int) -> List[str]: + """Return the segment residences for a given segment ID. Will return at most n residences. + Should only return fewer than n residences if there are fewer than n residences available. + """ @abstractmethod def register_updated_segment_callback( diff --git a/chromadb/segment/impl/distributed/segment_directory.py b/chromadb/segment/impl/distributed/segment_directory.py index dc1a73e8e11..bf74e0b499e 100644 --- a/chromadb/segment/impl/distributed/segment_directory.py +++ b/chromadb/segment/impl/distributed/segment_directory.py @@ -1,12 +1,10 @@ from enum import Enum import threading import time -from typing import Any, Callable, Dict, Optional, cast - +from typing import Any, Callable, Dict, List, Optional, cast from kubernetes import client, config, watch from kubernetes.client.rest import ApiException from overrides import EnforceOverrides, override - from chromadb.config import System from chromadb.segment.distributed import ( Member, @@ -273,7 +271,7 @@ def stop(self) -> None: return super().stop() @override - def get_segment_endpoint(self, segment: Segment) -> str: + def get_segment_endpoints(self, segment: Segment, n: int) -> List[str]: if self._curr_memberlist is None or len(self._curr_memberlist) == 0: raise ValueError("Memberlist is not initialized") @@ -287,38 +285,41 @@ def get_segment_endpoint(self, segment: Segment) -> str: can_use_node_routing = all([m.node != "" for m in self._curr_memberlist]) if can_use_node_routing and self._routing_mode == RoutingMode.NODE: # If we are using node routing and the segments - assignment = assign( + assignments = assign( 
segment["collection"].hex, [m.node for m in self._curr_memberlist], murmur3hasher, - 1, - )[0] + n, + ) else: # Query to the same collection should end up on the same endpoint - assignment = assign( + assignments = assign( segment["collection"].hex, [m.id for m in self._curr_memberlist], murmur3hasher, - 1, - )[0] + n, + ) - service_name = self.extract_service_name(assignment) - # If the memberlist has an ip, use it, otherwise use the member id with the headless service - # this is for backwards compatibility with the old memberlist which only had ids + assignments_set = set(assignments) + out_endpoints = [] for member in self._curr_memberlist: is_chosen_with_node_routing = ( - can_use_node_routing and member.node == assignment + can_use_node_routing and member.node in assignments_set ) is_chosen_with_id_routing = ( - not can_use_node_routing and member.id == assignment + not can_use_node_routing and member.id in assignments_set ) if is_chosen_with_node_routing or is_chosen_with_id_routing: + # If the memberlist has an ip, use it, otherwise use the member id with the headless service + # this is for backwards compatibility with the old memberlist which only had ids if member.ip is not None and member.ip != "": endpoint = f"{member.ip}:50051" - return endpoint - - endpoint = f"{assignment}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051" # TODO: make port configurable - return endpoint + out_endpoints.append(endpoint) + else: + service_name = self.extract_service_name(member.id) + endpoint = f"{member.id}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051" + out_endpoints.append(endpoint) + return out_endpoints @override def register_updated_segment_callback( diff --git a/chromadb/segment/impl/manager/distributed.py b/chromadb/segment/impl/manager/distributed.py index 4367ab1c44e..5875d7a1929 100644 --- a/chromadb/segment/impl/manager/distributed.py +++ b/chromadb/segment/impl/manager/distributed.py @@ -1,5 +1,5 @@ from threading import 
Lock -from typing import Dict, Sequence +from typing import Dict, List, Sequence from uuid import UUID, uuid4 from overrides import override @@ -87,8 +87,8 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: "DistributedSegmentManager.get_endpoint", OpenTelemetryGranularity.OPERATION_AND_SEGMENT, ) - def get_endpoint(self, segment: Segment) -> str: - return self._segment_directory.get_segment_endpoint(segment) + def get_endpoints(self, segment: Segment, n: int) -> List[str]: + return self._segment_directory.get_segment_endpoints(segment, n) @trace_method( "DistributedSegmentManager.hint_use_collection",