diff --git a/chromadb/api/configuration.py b/chromadb/api/configuration.py index 7a8e1b04896..03488e92909 100644 --- a/chromadb/api/configuration.py +++ b/chromadb/api/configuration.py @@ -239,7 +239,7 @@ class HNSWConfigurationInternal(ConfigurationInternal): name="ef_search", validator=lambda value: isinstance(value, int) and value >= 1, is_static=False, - default_value=10, + default_value=100, ), "num_threads": ConfigurationDefinition( name="num_threads", @@ -328,7 +328,7 @@ def __init__( self, space: str = "l2", ef_construction: int = 100, - ef_search: int = 10, + ef_search: int = 100, num_threads: int = cpu_count(), M: int = 16, resize_factor: float = 1.2, diff --git a/chromadb/errors.py b/chromadb/errors.py index ec8c77aa818..9a3f5d091ab 100644 --- a/chromadb/errors.py +++ b/chromadb/errors.py @@ -194,4 +194,7 @@ def name(cls) -> str: "VersionMismatchError": VersionMismatchError, "RateLimitError": RateLimitError, "AuthError": ChromaAuthError, + "UniqueConstraintError": UniqueConstraintError, + "QuotaError": QuotaError, + "InternalError": InternalError, } diff --git a/chromadb/execution/executor/distributed.py b/chromadb/execution/executor/distributed.py index 3cf5c591c77..6bf795540c1 100644 --- a/chromadb/execution/executor/distributed.py +++ b/chromadb/execution/executor/distributed.py @@ -1,16 +1,12 @@ from typing import Dict, Optional - import grpc from overrides import overrides - from chromadb.api.types import GetResult, Metadata, QueryResult from chromadb.config import System -from chromadb.errors import VersionMismatchError from chromadb.execution.executor.abstract import Executor from chromadb.execution.expression.operator import Scan from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan from chromadb.proto import convert - from chromadb.proto.query_executor_pb2_grpc import QueryExecutorStub from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor from chromadb.segment.impl.manager.distributed import DistributedSegmentManager @@ -170,6 +166,6 @@ def _grpc_executuor_stub(self, scan: Scan) -> QueryExecutorStub: channel = grpc.insecure_channel(grpc_url) interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()] channel = grpc.intercept_channel(channel, *interceptors) - self._grpc_stub_pool[grpc_url] = QueryExecutorStub(channel) # type: ignore[no-untyped-call] + self._grpc_stub_pool[grpc_url] = QueryExecutorStub(channel) return self._grpc_stub_pool[grpc_url] diff --git a/chromadb/segment/distributed/__init__.py b/chromadb/segment/distributed/__init__.py index 08efdafd18c..049c3b54b62 100644 --- a/chromadb/segment/distributed/__init__.py +++ b/chromadb/segment/distributed/__init__.py @@ -1,4 +1,5 @@ from abc import abstractmethod +from dataclasses import dataclass from typing import Any, Callable, List from overrides import EnforceOverrides, overrides @@ -22,7 +23,13 @@ def register_updated_segment_callback( pass -Memberlist = List[str] +@dataclass +class Member: + id: str + ip: str + + +Memberlist = List[Member] class MemberlistProvider(Component, EnforceOverrides): diff --git a/chromadb/segment/impl/distributed/segment_directory.py b/chromadb/segment/impl/distributed/segment_directory.py index 12a6b35fa7d..097b3b38566 100644 --- a/chromadb/segment/impl/distributed/segment_directory.py +++ b/chromadb/segment/impl/distributed/segment_directory.py @@ -8,6 +8,7 @@ from chromadb.config import System from chromadb.segment.distributed import ( + Member, Memberlist, MemberlistProvider, SegmentDirectory, @@ -35,7 +36,11 @@ class 
MockMemberlistProvider(MemberlistProvider, EnforceOverrides): def __init__(self, system: System): super().__init__(system) - self._memberlist = ["a", "b", "c"] + self._memberlist = [ + Member(id="a", ip="10.0.0.1"), + Member(id="b", ip="10.0.0.2"), + Member(id="c", ip="10.0.0.3"), + ] @override def get_memberlist(self) -> Memberlist: @@ -203,7 +208,12 @@ def _parse_response_memberlist( ) -> Memberlist: if "members" not in api_response_spec: return [] - return [m["member_id"] for m in api_response_spec["members"]] + parsed = [] + for m in api_response_spec["members"]: + id = m["member_id"] + ip = m["member_ip"] if "member_ip" in m else "" + parsed.append(Member(id=id, ip=ip)) + return parsed def _notify(self, memberlist: Memberlist) -> None: for callback in self.callbacks: @@ -245,11 +255,23 @@ def get_segment_endpoint(self, segment: Segment) -> str: raise ValueError("Memberlist is not initialized") # Query to the same collection should end up on the same endpoint assignment = assign( - segment["collection"].hex, self._curr_memberlist, murmur3hasher, 1 + segment["collection"].hex, + [m.id for m in self._curr_memberlist], + murmur3hasher, + 1, )[0] service_name = self.extract_service_name(assignment) - assignment = f"{assignment}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051" # TODO: make port configurable - return assignment + + # If the memberlist has an ip, use it, otherwise use the member id with the headless service + # this is for backwards compatibility with the old memberlist which only had ids + for member in self._curr_memberlist: + if member.id == assignment: + if member.ip is not None and member.ip != "": + endpoint = f"{member.ip}:50051" + return endpoint + + endpoint = f"{assignment}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051" # TODO: make port configurable + return endpoint @override def register_updated_segment_callback( @@ -263,7 +285,9 @@ def register_updated_segment_callback( ) def _update_memberlist(self, memberlist: Memberlist) -> None: with self._curr_memberlist_mutex: - add_attributes_to_current_span({"new_memberlist": memberlist}) + add_attributes_to_current_span( + {"new_memberlist": [m.id for m in memberlist]} + ) self._curr_memberlist = memberlist def extract_service_name(self, pod_name: str) -> Optional[str]: diff --git a/chromadb/segment/impl/manager/distributed.py b/chromadb/segment/impl/manager/distributed.py index a90b0b62712..91b9d7e1304 100644 --- a/chromadb/segment/impl/manager/distributed.py +++ b/chromadb/segment/impl/manager/distributed.py @@ -14,13 +14,11 @@ from chromadb.segment.distributed import SegmentDirectory from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams from chromadb.telemetry.opentelemetry import ( - OpenTelemetryClient, OpenTelemetryGranularity, trace_method, ) from chromadb.types import ( Collection, - CollectionAndSegments, Operation, Segment, SegmentScope, @@ -30,18 +28,15 @@ class DistributedSegmentManager(SegmentManager): _sysdb: SysDB _system: System - _opentelemetry_client: OpenTelemetryClient _instances: Dict[UUID, SegmentImplementation] _segment_directory: SegmentDirectory _lock: Lock - # _segment_server_stubs: Dict[str, SegmentServerStub] # grpc_url -> grpc stub def __init__(self, system: System): super().__init__(system) self._sysdb = self.require(SysDB) self._segment_directory = self.require(SegmentDirectory) self._system = system - self._opentelemetry_client = system.require(OpenTelemetryClient) self._instances = {} self._lock = Lock() diff --git 
a/chromadb/segment/impl/vector/hnsw_params.py b/chromadb/segment/impl/vector/hnsw_params.py index b12c4281508..4387f188edf 100644 --- a/chromadb/segment/impl/vector/hnsw_params.py +++ b/chromadb/segment/impl/vector/hnsw_params.py @@ -55,7 +55,7 @@ def __init__(self, metadata: Metadata): metadata = metadata or {} self.space = str(metadata.get("hnsw:space", "l2")) self.construction_ef = int(metadata.get("hnsw:construction_ef", 100)) - self.search_ef = int(metadata.get("hnsw:search_ef", 10)) + self.search_ef = int(metadata.get("hnsw:search_ef", 100)) self.M = int(metadata.get("hnsw:M", 16)) self.num_threads = int( metadata.get("hnsw:num_threads", multiprocessing.cpu_count()) diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index 4f8aeca38a8..9b3e725b2d7 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -16,6 +16,7 @@ CapacityLimiter, ) from fastapi import FastAPI as _FastAPI, Response, Request +from fastapi.openapi.utils import get_openapi from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse from fastapi.routing import APIRoute @@ -61,6 +62,7 @@ ) from starlette.datastructures import Headers import logging +import importlib.metadata from chromadb.telemetry.product.events import ServerStartEvent from chromadb.utils.fastapi import fastapi_json_response, string_to_uuid as _uuid @@ -142,18 +144,6 @@ def validate_model(model: Type[D], data: Any) -> D: # type: ignore return model.parse_obj(data) # pydantic 1.x -def get_openapi_extras_for_model(request_model: Type[D]) -> Dict[str, Any]: - openapi_extra = { - "requestBody": { - "content": { - "application/json": {"schema": request_model.model_json_schema()} - }, - "required": True, - } - } - return openapi_extra - - class ChromaAPIRouter(fastapi.APIRouter): # type: ignore # A simple subclass of fastapi's APIRouter which treats URLs with a # trailing "/" the same as URLs without. 
Docs will only contain URLs @@ -189,6 +179,10 @@ def __init__(self, settings: Settings): self._app = fastapi.FastAPI(debug=True, default_response_class=ORJSONResponse) self._system = System(settings) self._api: ServerAPI = self._system.instance(ServerAPI) + + self._extra_openapi_schemas: Dict[str, Any] = {} + self._app.openapi = self.generate_openapi + self._opentelemetry_client = self._api.require(OpenTelemetryClient) self._capacity_limiter = CapacityLimiter( settings.chroma_server_thread_pool_size @@ -232,6 +226,37 @@ def __init__(self, settings: Settings): telemetry_client = self._system.instance(ProductTelemetryClient) telemetry_client.capture(ServerStartEvent()) + def generate_openapi(self) -> Dict[str, Any]: + """Used instead of the default openapi() generation handler to include manually-populated schemas.""" + schema: Dict[str, Any] = get_openapi( + title="Chroma", + routes=self._app.routes, + version=importlib.metadata.version("chromadb"), + ) + + for key, value in self._extra_openapi_schemas.items(): + schema["components"]["schemas"][key] = value + + return schema + + def get_openapi_extras_for_body_model( + self, request_model: Type[D] + ) -> Dict[str, Any]: + schema = request_model.model_json_schema( + ref_template="#/components/schemas/{model}" + ) + if "$defs" in schema: + for key, value in schema["$defs"].items(): + self._extra_openapi_schemas[key] = value + + openapi_extra = { + "requestBody": { + "content": {"application/json": {"schema": schema}}, + "required": True, + } + } + return openapi_extra + def setup_v2_routes(self) -> None: self.router.add_api_route("/api/v2", self.root, methods=["GET"]) self.router.add_api_route("/api/v2/reset", self.reset, methods=["POST"]) @@ -253,7 +278,7 @@ def setup_v2_routes(self) -> None: self.create_database, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(CreateDatabase), + openapi_extra=self.get_openapi_extras_for_body_model(CreateDatabase), ) self.router.add_api_route( @@ -268,7 +293,7 @@ def setup_v2_routes(self) -> None: self.create_tenant, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(CreateTenant), + openapi_extra=self.get_openapi_extras_for_body_model(CreateTenant), ) self.router.add_api_route( @@ -295,7 +320,7 @@ def setup_v2_routes(self) -> None: self.create_collection, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(CreateCollection), + openapi_extra=self.get_openapi_extras_for_body_model(CreateCollection), ) self.router.add_api_route( @@ -304,35 +329,35 @@ def setup_v2_routes(self) -> None: methods=["POST"], status_code=status.HTTP_201_CREATED, response_model=None, - openapi_extra=get_openapi_extras_for_model(AddEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/update", self.update, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(UpdateEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(UpdateEmbedding), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/upsert", self.upsert, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(AddEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/get", self.get, 
methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(GetEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(GetEmbedding), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/delete", self.delete, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(DeleteEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(DeleteEmbedding), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/count", @@ -345,7 +370,9 @@ def setup_v2_routes(self) -> None: self.get_nearest_neighbors, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(request_model=QueryEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model( + request_model=QueryEmbedding + ), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_name}", @@ -358,7 +385,7 @@ def setup_v2_routes(self) -> None: self.update_collection, methods=["PUT"], response_model=None, - openapi_extra=get_openapi_extras_for_model(UpdateCollection), + openapi_extra=self.get_openapi_extras_for_body_model(UpdateCollection), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_name}", @@ -1138,7 +1165,7 @@ def setup_v1_routes(self) -> None: self.create_database_v1, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(CreateDatabase), + openapi_extra=self.get_openapi_extras_for_body_model(CreateDatabase), ) self.router.add_api_route( @@ -1153,7 +1180,7 @@ def setup_v1_routes(self) -> None: self.create_tenant_v1, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(CreateTenant), + openapi_extra=self.get_openapi_extras_for_body_model(CreateTenant), ) self.router.add_api_route( @@ -1180,7 +1207,7 @@ def setup_v1_routes(self) -> None: self.create_collection_v1, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(CreateCollection), + openapi_extra=self.get_openapi_extras_for_body_model(CreateCollection), ) self.router.add_api_route( @@ -1189,35 +1216,35 @@ def setup_v1_routes(self) -> None: methods=["POST"], status_code=status.HTTP_201_CREATED, response_model=None, - openapi_extra=get_openapi_extras_for_model(AddEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding), ) self.router.add_api_route( "/api/v1/collections/{collection_id}/update", self.update_v1, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(UpdateEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(UpdateEmbedding), ) self.router.add_api_route( "/api/v1/collections/{collection_id}/upsert", self.upsert_v1, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(AddEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding), ) self.router.add_api_route( "/api/v1/collections/{collection_id}/get", self.get_v1, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(GetEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(GetEmbedding), ) self.router.add_api_route( "/api/v1/collections/{collection_id}/delete", self.delete_v1, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(DeleteEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(DeleteEmbedding), ) self.router.add_api_route( 
"/api/v1/collections/{collection_id}/count", @@ -1230,7 +1257,7 @@ def setup_v1_routes(self) -> None: self.get_nearest_neighbors_v1, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(QueryEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(QueryEmbedding), ) self.router.add_api_route( "/api/v1/collections/{collection_name}", @@ -1243,7 +1270,7 @@ def setup_v1_routes(self) -> None: self.update_collection_v1, methods=["PUT"], response_model=None, - openapi_extra=get_openapi_extras_for_model(UpdateCollection), + openapi_extra=self.get_openapi_extras_for_body_model(UpdateCollection), ) self.router.add_api_route( "/api/v1/collections/{collection_name}", @@ -1598,6 +1625,7 @@ async def inner(): ), ) return api_collection_model + return await inner() @trace_method("FastAPI.update_collection_v1", OpenTelemetryGranularity.OPERATION) diff --git a/chromadb/test/api/test_collection.py b/chromadb/test/api/test_collection.py index e367ebc21f8..dee01300e2a 100644 --- a/chromadb/test/api/test_collection.py +++ b/chromadb/test/api/test_collection.py @@ -1,4 +1,5 @@ from chromadb.api import ClientAPI +from chromadb.errors import UniqueConstraintError def test_duplicate_collection_create( @@ -21,7 +22,7 @@ def test_duplicate_collection_create( assert False, "Expected exception" except Exception as e: print("Collection creation failed as expected with error ", e) - assert "already exists" in e.args[0] or "UniqueConstraintError" in e.args[0] + assert "already exists" in e.args[0] or isinstance(e, UniqueConstraintError) def test_not_existing_collection_delete( diff --git a/chromadb/test/distributed/test_reroute.py b/chromadb/test/distributed/test_reroute.py new file mode 100644 index 00000000000..824664aded0 --- /dev/null +++ b/chromadb/test/distributed/test_reroute.py @@ -0,0 +1,74 @@ +from typing import Sequence +from chromadb.test.conftest import ( + reset, + skip_if_not_cluster, +) +from chromadb.api import ClientAPI +from kubernetes import client as k8s_client, config +import time + + +@skip_if_not_cluster() +def test_reroute( + client: ClientAPI, +) -> None: + reset(client) + collection = client.create_collection( + name="test", + metadata={"hnsw:construction_ef": 128, "hnsw:search_ef": 128, "hnsw:M": 128}, + ) + + ids = [str(i) for i in range(10)] + embeddings: list[Sequence[float]] = [ + [float(i), float(i), float(i)] for i in range(10) + ] + collection.add(ids=ids, embeddings=embeddings) + collection.query(query_embeddings=[embeddings[0]]) + + # Restart the query service using k8s api, in order to trigger a reroute + # of the query service + config.load_kube_config() + v1 = k8s_client.CoreV1Api() + # Find all pods with the label "app=query" + res = v1.list_namespaced_pod("chroma", label_selector="app=query-service") + assert len(res.items) > 0 + items = res.items + seen_ids = set() + + # Restart all the pods by deleting them + for item in items: + seen_ids.add(item.metadata.uid) + name = item.metadata.name + namespace = item.metadata.namespace + v1.delete_namespaced_pod(name, namespace) + + # Wait until we have len(seen_ids) pods running with new UIDs + timeout_secs = 10 + start_time = time.time() + while True: + res = v1.list_namespaced_pod("chroma", label_selector="app=query-service") + items = res.items + new_ids = set([item.metadata.uid for item in items]) + if len(new_ids) == len(seen_ids) and len(new_ids.intersection(seen_ids)) == 0: + break + if time.time() - start_time > timeout_secs: + assert False, "Timed out waiting for new pods to start" + 
time.sleep(1) + + # Wait for the query service to be ready, or timeout + while True: + res = v1.list_namespaced_pod("chroma", label_selector="app=query-service") + items = res.items + ready = True + for item in items: + if item.status.phase != "Running": + ready = False + break + if ready: + break + if time.time() - start_time > timeout_secs: + assert False, "Timed out waiting for new pods to be ready" + time.sleep(1) + + time.sleep(1) + collection.query(query_embeddings=[embeddings[0]]) diff --git a/chromadb/test/segment/distributed/test_memberlist_provider.py b/chromadb/test/segment/distributed/test_memberlist_provider.py index c97bcbd06cb..0422d84431a 100644 --- a/chromadb/test/segment/distributed/test_memberlist_provider.py +++ b/chromadb/test/segment/distributed/test_memberlist_provider.py @@ -1,9 +1,10 @@ # Tests the CustomResourceMemberlist provider +from dataclasses import asdict import threading from chromadb.test.conftest import skip_if_not_cluster from kubernetes import client, config from chromadb.config import System, Settings -from chromadb.segment.distributed import Memberlist +from chromadb.segment.distributed import Memberlist, Member from chromadb.segment.impl.distributed.segment_directory import ( CustomResourceMemberlistProvider, KUBERNETES_GROUP, @@ -17,12 +18,12 @@ def update_memberlist(n: int, memberlist_name: str = "test-memberlist") -> Membe config.load_config() api_instance = client.CustomObjectsApi() - members = [{"member_id": f"test-{i}"} for i in range(1, n + 1)] + members = [Member(id=f"test-{i}", ip=f"10.0.0.{i}") for i in range(1, n + 1)] body = { "kind": "MemberList", "metadata": {"name": memberlist_name}, - "spec": {"members": members}, + "spec": {"members": [{"member_id": m.id, "member_ip": m.ip} for m in members]}, } _ = api_instance.patch_namespaced_custom_object( @@ -34,11 +35,13 @@ def update_memberlist(n: int, memberlist_name: str = "test-memberlist") -> Membe body=body, ) - return [m["member_id"] for m in members] + return members def compare_memberlists(m1: Memberlist, m2: Memberlist) -> bool: - return sorted(m1) == sorted(m2) + m1_as_dict = sorted([asdict(m) for m in m1], key=lambda x: x["id"]) + m2_as_dict = sorted([asdict(m) for m in m2], key=lambda x: x["id"]) + return m1_as_dict == m2_as_dict @skip_if_not_cluster() diff --git a/clients/js/src/utils.ts b/clients/js/src/utils.ts index c857ab74510..8ee712e5ba0 100644 --- a/clients/js/src/utils.ts +++ b/clients/js/src/utils.ts @@ -23,13 +23,21 @@ export function toArray<T>(obj: T | T[]): Array<T> { // a function to convert an array to array of arrays export function toArrayOfArrays<T>( - obj: Array<Array<T>> | Array<T>, + obj: Array<Array<T>> | Array<T> ): Array<Array<T>> { + if (obj.length === 0) { + return []; + } + if (Array.isArray(obj[0])) { return obj as Array<Array<T>>; - } else { - return [obj] as Array<Array<T>>; } + + if (obj[0] && typeof (obj[0] as any)[Symbol.iterator] === 'function') { + return (obj as unknown as Array<Iterable<T>>).map(el => Array.from(el)); + } + + return [obj] as Array<Array<T>>; } /** diff --git a/docs/docs.trychroma.com/markdoc/content/docs/collections/configure.md b/docs/docs.trychroma.com/markdoc/content/docs/collections/configure.md index 7b2cf5b389e..05950cafde7 100644 --- a/docs/docs.trychroma.com/markdoc/content/docs/collections/configure.md +++ b/docs/docs.trychroma.com/markdoc/content/docs/collections/configure.md @@ -11,7 +11,7 @@ You can configure the embedding space of a collection by setting special keys on | Cosine similarity | `cosine` | {% Latex %} d = 1.0 - \\frac{\\sum\\left(A_i \\times 
B_i\\right)}{\\sqrt{\\sum\\left(A_i^2\\right)} \\cdot \\sqrt{\\sum\\left(B_i^2\\right)}} {% /Latex %} | * `hnsw:construction_ef` determines the size of the candidate list used to select neighbors during index creation. A higher value improves index quality at the cost of more memory and time, while a lower value speeds up construction with reduced accuracy. The default value is `100`. -* `hnsw:search_ef` determines the size of the dynamic candidate list used while searching for the nearest neighbors. A higher value improves recall and accuracy by exploring more potential neighbors but increases query time and computational cost, while a lower value results in faster but less accurate searches. The default value is `10`. +* `hnsw:search_ef` determines the size of the dynamic candidate list used while searching for the nearest neighbors. A higher value improves recall and accuracy by exploring more potential neighbors but increases query time and computational cost, while a lower value results in faster but less accurate searches. The default value is `100`. * `hnsw:M` is the maximum number of neighbors (connections) that each node in the graph can have during the construction of the index. A higher value results in a denser graph, leading to better recall and accuracy during searches but increases memory usage and construction time. A lower value creates a sparser graph, reducing memory usage and construction time but at the cost of lower search accuracy and recall. The default value is `16`. * `hnsw:num_threads` specifies the number of threads to use during index construction or search operations. The default value is `multiprocessing.cpu_count()` (available CPU cores). diff --git a/go/pkg/memberlist_manager/memberlist_manager.go b/go/pkg/memberlist_manager/memberlist_manager.go index 24a54e5c8c8..588fceeaa26 100644 --- a/go/pkg/memberlist_manager/memberlist_manager.go +++ b/go/pkg/memberlist_manager/memberlist_manager.go @@ -128,6 +128,16 @@ func memberlistSame(oldMemberlist Memberlist, newMemberlist Memberlist) bool { if len(oldMemberlist) != len(newMemberlist) { return false } + oldMemberlistIps := make(map[string]string) + for _, member := range oldMemberlist { + oldMemberlistIps[member.id] = member.ip + } + for _, member := range newMemberlist { + if ip, ok := oldMemberlistIps[member.id]; !ok || ip != member.ip { + return false + } + } + // use a map to check if the new memberlist contains all the old members newMemberlistMap := make(map[string]bool) for _, member := range newMemberlist { diff --git a/go/pkg/memberlist_manager/memberlist_manager_test.go b/go/pkg/memberlist_manager/memberlist_manager_test.go index 9fb9ff1e172..ccea2ee71aa 100644 --- a/go/pkg/memberlist_manager/memberlist_manager_test.go +++ b/go/pkg/memberlist_manager/memberlist_manager_test.go @@ -52,7 +52,7 @@ func TestNodeWatcher(t *testing.T) { t.Fatalf("Error getting node status: %v", err) } - return reflect.DeepEqual(memberlist, Memberlist{Member{id: "test-pod-0"}}) + return reflect.DeepEqual(memberlist, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}}) }, 10, 1*time.Second) if !ok { t.Fatalf("Node status did not update after adding a pod") @@ -83,7 +83,7 @@ func TestNodeWatcher(t *testing.T) { if err != nil { t.Fatalf("Error getting node status: %v", err) } - return reflect.DeepEqual(memberlist, Memberlist{Member{id: "test-pod-0"}}) + return reflect.DeepEqual(memberlist, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}}) }, 10, 1*time.Second) if !ok { t.Fatalf("Node status did not update after adding a 
not ready pod") @@ -108,13 +108,13 @@ func TestMemberlistStore(t *testing.T) { assert.Equal(t, Memberlist{}, memberlist) // Add a member to the memberlist - memberlist_store.UpdateMemberlist(context.Background(), Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}, "0") + memberlist_store.UpdateMemberlist(context.Background(), Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}}, "0") memberlist, _, err = memberlist_store.GetMemberlist(context.Background()) if err != nil { t.Fatalf("Error getting memberlist: %v", err) } // assert the memberlist has the correct members - if !memberlistSame(memberlist, Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}) { + if !memberlistSame(memberlist, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}}) { t.Fatalf("Memberlist did not update after adding a member") } } @@ -184,7 +184,7 @@ func TestMemberlistManager(t *testing.T) { // Get the memberlist ok := retryUntilCondition(func() bool { - return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-0"}}) + return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.49"}}) }, 30, 1*time.Second) if !ok { t.Fatalf("Memberlist did not update after adding a pod") @@ -195,7 +195,7 @@ func TestMemberlistManager(t *testing.T) { // Get the memberlist ok = retryUntilCondition(func() bool { - return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}) + return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.49"}, Member{id: "test-pod-1", ip: "10.0.0.50"}}) }, 30, 1*time.Second) if !ok { t.Fatalf("Memberlist did not update after adding a pod") @@ -206,7 +206,7 @@ func TestMemberlistManager(t *testing.T) { // Get the memberlist ok = retryUntilCondition(func() bool { - return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-1"}}) + return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-1", ip: "10.0.0.50"}}) }, 30, 1*time.Second) if !ok { t.Fatalf("Memberlist did not update after deleting a pod") @@ -217,23 +217,23 @@ func TestMemberlistSame(t *testing.T) { memberlist := Memberlist{} assert.True(t, memberlistSame(memberlist, memberlist)) - newMemberlist := Memberlist{Member{id: "test-pod-0"}} + newMemberlist := Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}} assert.False(t, memberlistSame(memberlist, newMemberlist)) assert.False(t, memberlistSame(newMemberlist, memberlist)) assert.True(t, memberlistSame(newMemberlist, newMemberlist)) - memberlist = Memberlist{Member{id: "test-pod-1"}} + memberlist = Memberlist{Member{id: "test-pod-1", ip: "10.0.0.2"}} assert.False(t, memberlistSame(newMemberlist, memberlist)) assert.False(t, memberlistSame(memberlist, newMemberlist)) assert.True(t, memberlistSame(memberlist, memberlist)) - memberlist = Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}} - newMemberlist = Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}} + memberlist = Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}} + newMemberlist = Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}} assert.True(t, memberlistSame(memberlist, newMemberlist)) assert.True(t, memberlistSame(newMemberlist, memberlist)) - memberlist = Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}} - 
newMemberlist = Memberlist{Member{id: "test-pod-1"}, Member{id: "test-pod-0"}} + memberlist = Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}} + newMemberlist = Memberlist{Member{id: "test-pod-1", ip: "10.0.0.2"}, Member{id: "test-pod-0", ip: "10.0.0.1"}} assert.True(t, memberlistSame(memberlist, newMemberlist)) assert.True(t, memberlistSame(newMemberlist, memberlist)) } diff --git a/go/pkg/memberlist_manager/memberlist_store.go b/go/pkg/memberlist_manager/memberlist_store.go index 42a2efe4261..d7046205431 100644 --- a/go/pkg/memberlist_manager/memberlist_store.go +++ b/go/pkg/memberlist_manager/memberlist_store.go @@ -20,11 +20,13 @@ type IMemberlistStore interface { type Member struct { id string + ip string } // MarshalLogObject implements the zapcore.ObjectMarshaler interface func (m Member) MarshalLogObject(enc zapcore.ObjectEncoder) error { enc.AddString("id", m.id) + enc.AddString("ip", m.ip) return nil } @@ -80,7 +82,14 @@ func (s *CRMemberlistStore) GetMemberlist(ctx context.Context) (return_memberlis if !ok { return nil, "", errors.New("failed to cast member_id to string") } - memberlist = append(memberlist, Member{member_id}) + // If member_ip is in the CR, extract it, otherwise set it to empty string + // This is for backwards compatibility with older CRs that don't have member_ip + member_ip, ok := member_map["member_ip"].(string) + if !ok { + member_ip = "" + } + + memberlist = append(memberlist, Member{member_id, member_ip}) } return memberlist, unstrucuted.GetResourceVersion(), nil } @@ -107,6 +116,7 @@ func (list Memberlist) toCr(namespace string, memberlistName string, resourceVer for i, member := range list { members[i] = map[string]interface{}{ "member_id": member.id, + "member_ip": member.ip, } } diff --git a/go/pkg/memberlist_manager/node_watcher.go b/go/pkg/memberlist_manager/node_watcher.go index 4351255da95..a79b73d59f6 100644 --- a/go/pkg/memberlist_manager/node_watcher.go +++ b/go/pkg/memberlist_manager/node_watcher.go @@ -165,7 +165,7 @@ func (w *KubernetesWatcher) ListReadyMembers() (Memberlist, error) { for _, condition := range pod.Status.Conditions { if condition.Type == v1.PodReady { if condition.Status == v1.ConditionTrue { - memberlist = append(memberlist, Member{pod.Name}) + memberlist = append(memberlist, Member{pod.Name, pod.Status.PodIP}) } break } diff --git a/k8s/distributed-chroma/Chart.yaml b/k8s/distributed-chroma/Chart.yaml index 72d420042f6..ab51db46b14 100644 --- a/k8s/distributed-chroma/Chart.yaml +++ b/k8s/distributed-chroma/Chart.yaml @@ -16,7 +16,7 @@ apiVersion: v2 name: distributed-chroma description: A helm chart for distributed Chroma type: application -version: 0.1.12 +version: 0.1.13 appVersion: "0.4.24" keywords: - chroma diff --git a/k8s/distributed-chroma/crds/memberlist_crd.yaml b/k8s/distributed-chroma/crds/memberlist_crd.yaml index 9cde59ab468..51e34426db6 100644 --- a/k8s/distributed-chroma/crds/memberlist_crd.yaml +++ b/k8s/distributed-chroma/crds/memberlist_crd.yaml @@ -27,6 +27,8 @@ spec: properties: member_id: type: string + member_ip: + type: string scope: Namespaced names: plural: memberlists diff --git a/requirements.txt b/requirements.txt index 422c8062ac5..6953574a00b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,10 +9,10 @@ kubernetes>=28.1.0 mmh3>=4.0.1 numpy>=1.22.5 onnxruntime>=1.14.1 -opentelemetry-api>=1.2.0 +opentelemetry-api>=1.24.0 opentelemetry-exporter-otlp-proto-grpc>=1.24.0 opentelemetry-instrumentation-fastapi>=0.41b0 -opentelemetry-sdk>=1.2.0 
+opentelemetry-sdk>=1.24.0 orjson>=3.9.12 overrides>=7.3.1 posthog>=2.4.0
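The ef_search / search_ef default changes from 10 to 100 in three places above (the collection configuration, the HNSW params parser, and the docs). A minimal sketch of how the hnsw:* metadata keys interact with the new default, assuming an in-process chromadb.Client(); the collection names and override values are illustrative, not taken from the diff:

import chromadb

client = chromadb.Client()

# With no hnsw:search_ef in the metadata, the collection now gets the new
# default of 100 (previously 10); construction_ef, M, and space keep their
# documented defaults (100, 16, "l2").
default_col = client.create_collection(name="uses-new-default")

# Explicit overrides behave exactly as before the change.
tuned_col = client.create_collection(
    name="tuned",
    metadata={
        "hnsw:space": "cosine",
        "hnsw:construction_ef": 128,
        "hnsw:search_ef": 256,
        "hnsw:M": 32,
    },
)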
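get_segment_endpoint in segment_directory.py now prefers a member's IP and only falls back to the pod-name DNS form when the memberlist entry carries no member_ip (older CRs). A standalone sketch of that fallback, not the actual RendezvousHashSegmentDirectory code; the service name, namespace, and headless-service suffix below are placeholders for the constants used in the diff:

from dataclasses import dataclass
from typing import List


@dataclass
class Member:
    id: str
    ip: str


def resolve_endpoint(
    assignment: str,
    memberlist: List[Member],
    service_name: str = "query-service",
    namespace: str = "chroma",
    headless_service: str = "svc.cluster.local",
) -> str:
    # Prefer the IP published in the memberlist CR when it is present.
    for member in memberlist:
        if member.id == assignment and member.ip:
            return f"{member.ip}:50051"
    # Old-style memberlists only carried ids, so fall back to the
    # pod-name-based headless-service address.
    return f"{assignment}.{service_name}.{namespace}.{headless_service}:50051"


print(resolve_endpoint("query-0", [Member(id="query-0", ip="10.0.0.7")]))
# -> 10.0.0.7:50051
print(resolve_endpoint("query-0", [Member(id="query-0", ip="")]))
# -> query-0.query-service.chroma.svc.cluster.local:50051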
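get_openapi_extras_for_body_model replaces the old module-level helper so that nested pydantic models, which pydantic v2 emits under "$defs", end up under components/schemas in the OpenAPI document assembled by generate_openapi. A rough sketch of the same $defs hoisting with stand-in models (Inner and CreateThing are not Chroma's request models):

from typing import Any, Dict, List, Optional

from pydantic import BaseModel


class Inner(BaseModel):
    values: List[float]


class CreateThing(BaseModel):
    name: str
    payload: Optional[Inner] = None


extra_schemas: Dict[str, Any] = {}

# ref_template points nested-model $refs at components/schemas instead of #/$defs.
schema = CreateThing.model_json_schema(ref_template="#/components/schemas/{model}")

# Copy the nested definitions aside so a custom openapi() handler can merge them
# into schema["components"]["schemas"], mirroring generate_openapi in the diff.
for key, value in schema.get("$defs", {}).items():
    extra_schemas[key] = value

openapi_extra = {
    "requestBody": {
        "content": {"application/json": {"schema": schema}},
        "required": True,
    }
}
print(extra_schemas.keys())  # dict_keys(['Inner'])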