Skip to content

Commit

Permalink
round robin
Browse files Browse the repository at this point in the history
  • Loading branch information
HammadB committed Jan 9, 2025
1 parent 4b047d7 commit 0bc0b8f
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 26 deletions.
5 changes: 4 additions & 1 deletion chromadb/execution/executor/distributed.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
from typing import Dict, Optional
import grpc
from overrides import overrides
Expand Down Expand Up @@ -161,7 +162,9 @@ def knn(self, plan: KNNPlan) -> QueryResult:
def _grpc_executuor_stub(self, scan: Scan) -> QueryExecutorStub:
# Since the grpc endpoint is determined by collection uuid,
# the endpoint should be the same for all segments of the same collection
grpc_url = self._manager.get_endpoint(scan.record)
# TODO: configure the number of endpoints to fetch
grpc_urls = self._manager.get_endpoints(scan.record, 3)
grpc_url = grpc_urls[random.randint(0, len(grpc_urls) - 1)]
if grpc_url not in self._grpc_stub_pool:
channel = grpc.insecure_channel(grpc_url)
interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()]
Expand Down
8 changes: 5 additions & 3 deletions chromadb/segment/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@

class SegmentDirectory(Component):
"""A segment directory is a data interface that manages the location of segments. Concretely, this
means that for clustered chroma, it provides the grpc endpoint for a segment."""
means that for distributed chroma, it provides the grpc endpoint for a segment."""

@abstractmethod
def get_segment_endpoint(self, segment: Segment) -> str:
"""Return the segment residence for a given segment ID"""
def get_segment_endpoints(self, segment: Segment, n: int) -> List[str]:
"""Return the segment residences for a given segment ID. Will return at most n residences.
Should only return fewer than n residences if fewer than n residences are available.
"""

@abstractmethod
def register_updated_segment_callback(
Expand Down
39 changes: 20 additions & 19 deletions chromadb/segment/impl/distributed/segment_directory.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from enum import Enum
import threading
import time
from typing import Any, Callable, Dict, Optional, cast

from typing import Any, Callable, Dict, List, Optional, cast
from kubernetes import client, config, watch
from kubernetes.client.rest import ApiException
from overrides import EnforceOverrides, override

from chromadb.config import System
from chromadb.segment.distributed import (
Member,
Expand Down Expand Up @@ -273,7 +271,7 @@ def stop(self) -> None:
return super().stop()

@override
def get_segment_endpoint(self, segment: Segment) -> str:
def get_segment_endpoints(self, segment: Segment, n: int) -> List[str]:
if self._curr_memberlist is None or len(self._curr_memberlist) == 0:
raise ValueError("Memberlist is not initialized")

Expand All @@ -287,38 +285,41 @@ def get_segment_endpoint(self, segment: Segment) -> str:
can_use_node_routing = all([m.node != "" for m in self._curr_memberlist])
if can_use_node_routing and self._routing_mode == RoutingMode.NODE:
# If we are using node routing and every member reports a node, assign by node name
assignment = assign(
assignments = assign(
segment["collection"].hex,
[m.node for m in self._curr_memberlist],
murmur3hasher,
1,
)[0]
n,
)
else:
# Query to the same collection should end up on the same endpoint
assignment = assign(
assignments = assign(
segment["collection"].hex,
[m.id for m in self._curr_memberlist],
murmur3hasher,
1,
)[0]
n,
)

service_name = self.extract_service_name(assignment)
# If the memberlist has an ip, use it, otherwise use the member id with the headless service
# this is for backwards compatibility with the old memberlist which only had ids
assignments_set = set(assignments)
out_endpoints = []
for member in self._curr_memberlist:
is_chosen_with_node_routing = (
can_use_node_routing and member.node == assignment
can_use_node_routing and member.node in assignments_set
)
is_chosen_with_id_routing = (
not can_use_node_routing and member.id == assignment
not can_use_node_routing and member.id in assignments_set
)
if is_chosen_with_node_routing or is_chosen_with_id_routing:
# If the memberlist has an ip, use it, otherwise use the member id with the headless service
# this is for backwards compatibility with the old memberlist which only had ids
if member.ip is not None and member.ip != "":
endpoint = f"{member.ip}:50051"
return endpoint

endpoint = f"{assignment}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051" # TODO: make port configurable
return endpoint
out_endpoints.append(endpoint)
else:
service_name = self.extract_service_name(member.id)
endpoint = f"{member.id}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051"
out_endpoints.append(endpoint)
return out_endpoints

@override
def register_updated_segment_callback(
Expand Down
6 changes: 3 additions & 3 deletions chromadb/segment/impl/manager/distributed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from threading import Lock
from typing import Dict, Sequence
from typing import Dict, List, Sequence
from uuid import UUID, uuid4

from overrides import override
Expand Down Expand Up @@ -87,8 +87,8 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
"DistributedSegmentManager.get_endpoint",
OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
)
def get_endpoint(self, segment: Segment) -> str:
return self._segment_directory.get_segment_endpoint(segment)
def get_endpoints(self, segment: Segment, n: int) -> List[str]:
return self._segment_directory.get_segment_endpoints(segment, n)

@trace_method(
"DistributedSegmentManager.hint_use_collection",
Expand Down

0 comments on commit 0bc0b8f

Please sign in to comment.