Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace threading.Lock with asyncio.Lock when batching to avoid deadlocks #1270

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
f16f366
Fix deadlocks in the batching algo by replacing `threading.Lock` with…
tsmith023 Aug 29, 2024
4e8863f
Remove lower bound requirement on fixed size batching
tsmith023 Aug 29, 2024
057b3dc
change time.sleep to asyncio.sleep in async func
tsmith023 Aug 29, 2024
3150959
Ensure `asyncio.Lock`s are opened in the event loop thread to handle …
tsmith023 Aug 29, 2024
d45621f
Remove stifling locks in _add_objects
tsmith023 Aug 29, 2024
336f8ee
Merge branch 'fix-deadlocks-in-batching' of https://github.com/weavia…
tsmith023 Aug 29, 2024
7cdcfc0
Move objs and refs inits back to __init__
tsmith023 Aug 29, 2024
d4a3a9f
Add missing props in collection/base.pyi stubs
tsmith023 Aug 29, 2024
c7bc642
Fix formatting
tsmith023 Aug 29, 2024
0539e98
Release `asyncio.Lock`s in the event-loop's context
tsmith023 Aug 30, 2024
c1daea8
Fix linter
tsmith023 Aug 30, 2024
86f9431
Log errors in __send_batch
tsmith023 Sep 2, 2024
8866503
Change `ErrorX` classes to refer to `BatchX` instead of internal `_Ba…
tsmith023 Sep 16, 2024
176984e
Merge branch 'main' of https://github.com/weaviate/weaviate-python-cl…
tsmith023 Sep 16, 2024
da9603a
Fix wrong <3.10 syntax
tsmith023 Sep 16, 2024
3eafa77
Fix missing default of retry_count in BatchObject
tsmith023 Sep 16, 2024
7a98f55
Fix parsing of `XReference`
tsmith023 Sep 16, 2024
ee17a4e
Merge branch 'main' of https://github.com/weaviate/weaviate-python-cl…
tsmith023 Nov 1, 2024
feb148a
Move futures inside executor to avoid races
tsmith023 Nov 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ image.png
scratch/

*-test.sh
*.hdf5
*.hdf5
*.jsonl
14 changes: 9 additions & 5 deletions integration/test_batch_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,8 +596,8 @@ def batch_insert(batch: BatchClient) -> None:
with concurrent.futures.ThreadPoolExecutor() as executor:
with client.batch.dynamic() as batch:
futures = [executor.submit(batch_insert, batch) for _ in range(nr_threads)]
for future in concurrent.futures.as_completed(futures):
future.result()
for future in concurrent.futures.as_completed(futures):
future.result()
objs = client.collections.get(name).query.fetch_objects(limit=nr_objects * nr_threads).objects
assert len(objs) == nr_objects * nr_threads

Expand Down Expand Up @@ -687,9 +687,13 @@ def test_batching_error_logs(
for obj in [{"name": i} for i in range(100)]:
batch.add_object(properties=obj, collection=name)
assert (
"Failed to send 100 objects in a batch of 100. Please inspect client.batch.failed_objects or collection.batch.failed_objects for the failed objects."
in caplog.text
)
("Failed to send" in caplog.text)
and ("objects in a batch of" in caplog.text)
and (
"Please inspect client.batch.failed_objects or collection.batch.failed_objects for the failed objects."
in caplog.text
)
) # number of objects sent per batch is not fixed for less than 100 objects


def test_references_with_to_uuids(client_factory: ClientFactory) -> None:
Expand Down
6 changes: 3 additions & 3 deletions profiling/test_sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def test_sphere(collection_factory: CollectionFactory) -> None:
sphere_file = get_file_path("sphere.100k.jsonl")
sphere_file = get_file_path("sphere.1m.jsonl")

collection = collection_factory(
properties=[
Expand All @@ -26,7 +26,7 @@ def test_sphere(collection_factory: CollectionFactory) -> None:
)
start = time.time()

import_objects = 50000
import_objects = 1000000
with collection.batch.dynamic() as batch:
with open(sphere_file) as jsonl_file:
for i, jsonl in enumerate(jsonl_file):
Expand All @@ -45,7 +45,7 @@ def test_sphere(collection_factory: CollectionFactory) -> None:
vector=json_parsed["vector"],
)
if i % 1000 == 0:
print(f"Imported {i} objects")
print(f"Imported {len(collection)} objects")
assert len(collection.batch.failed_objects) == 0
assert len(collection) == import_objects
print(f"Imported {import_objects} objects in {time.time() - start}")
71 changes: 50 additions & 21 deletions weaviate/collections/batch/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import math
import threading
import time
Expand Down Expand Up @@ -161,11 +162,12 @@ def __init__(
batch_mode: _BatchMode,
event_loop: _EventLoop,
vectorizer_batching: bool,
objects_: Optional[ObjectsBatchRequest] = None,
objects: Optional[ObjectsBatchRequest] = None,
references: Optional[ReferencesBatchRequest] = None,
) -> None:
self.__batch_objects = objects_ or ObjectsBatchRequest()
self.__batch_objects = objects or ObjectsBatchRequest()
self.__batch_references = references or ReferencesBatchRequest()

self.__connection = connection
self.__consistency_level: Optional[ConsistencyLevel] = consistency_level
self.__vectorizer_batching = vectorizer_batching
Expand All @@ -174,15 +176,12 @@ def __init__(
self.__batch_rest = _BatchREST(connection, self.__consistency_level)

# lookup table for objects that are currently being processed - is used to not send references from objects that have not been added yet
self.__uuid_lookup_lock = threading.Lock()
self.__uuid_lookup: Set[str] = set()

# we do not want users to access the results directly, as they are not thread-safe
self.__results_for_wrapper_backup = results
self.__results_for_wrapper = _BatchDataWrapper()

self.__results_lock = threading.Lock()

self.__cluster = _ClusterBatch(self.__connection)

self.__batching_mode: _BatchMode = batch_mode
Expand Down Expand Up @@ -221,7 +220,6 @@ def __init__(
self.__recommended_num_refs: int = 50

self.__active_requests = 0
self.__active_requests_lock = threading.Lock()

# dynamic batching
self.__time_last_scale_up: float = 0
Expand All @@ -233,9 +231,21 @@ def __init__(
# do 62 secs to give us some buffer to the "per-minute" calculation
self.__fix_rate_batching_base_time = 62

self.__loop.run_until_complete(self.__make_asyncio_locks)

self.__bg_thread = self.__start_bg_threads()
self.__bg_thread_exception: Optional[Exception] = None

async def __make_asyncio_locks(self) -> None:
    """Instantiate the `asyncio.Lock`s used for the results, the uuid lookup and the active-request counter.

    This is a coroutine so that the locks are constructed while the event loop is running,
    which allows internal `asyncio.get_event_loop()` calls to work.
    """
    self.__results_lock = asyncio.Lock()
    self.__uuid_lookup_lock = asyncio.Lock()
    self.__active_requests_lock = asyncio.Lock()

async def __release_asyncio_lock(self, lock: asyncio.Lock) -> None:
    """Release `lock` from within a coroutine.

    Wrapping the release in a coroutine means it executes in the context of the running
    event loop, which allows internal `asyncio.get_event_loop()` calls to work.
    """
    lock.release()

@property
def number_errors(self) -> int:
"""Return the number of errors in the batch."""
Expand Down Expand Up @@ -292,16 +302,17 @@ def __batch_send(self) -> None:
self.__time_stamp_last_request = time.time()

self._batch_send = True
self.__active_requests_lock.acquire()
self.__loop.run_until_complete(self.__active_requests_lock.acquire)
self.__active_requests += 1
self.__active_requests_lock.release()
self.__loop.run_until_complete(
self.__release_asyncio_lock, self.__active_requests_lock
)

objs = self.__batch_objects.pop_items(self.__recommended_num_objects)
self.__uuid_lookup_lock.acquire()
refs = self.__batch_references.pop_items(
self.__recommended_num_refs, uuid_lookup=self.__uuid_lookup
self.__recommended_num_refs,
uuid_lookup=self.__uuid_lookup,
)
self.__uuid_lookup_lock.release()
# do not block the thread - the results are written to a central (locked) list and we want to have multiple concurrent batch-requests
self.__loop.schedule(
self.__send_batch,
Expand Down Expand Up @@ -349,6 +360,7 @@ def batch_send_wrapper() -> None:
try:
self.__batch_send()
except Exception as e:
logger.error(e)
self.__bg_thread_exception = e

demonBatchSend = threading.Thread(
Expand All @@ -357,6 +369,7 @@ def batch_send_wrapper() -> None:
name="BgBatchScheduler",
)
demonBatchSend.start()

return demonBatchSend

def __dynamic_batching(self) -> None:
Expand Down Expand Up @@ -459,10 +472,24 @@ async def __send_batch(
response_obj = await self.__batch_grpc.objects(
objects=objs, timeout=DEFAULT_REQUEST_TIMEOUT
)
if response_obj.has_errors:
logger.error(
{
"message": f"Failed to send {len(response_obj.errors)} in a batch of {len(objs)}",
"errors": {err.message for err in response_obj.errors.values()},
}
)
except Exception as e:
errors_obj = {
idx: ErrorObject(message=repr(e), object_=obj) for idx, obj in enumerate(objs)
idx: ErrorObject(message=repr(e), object_=BatchObject._from_internal(obj))
for idx, obj in enumerate(objs)
}
logger.error(
{
"message": f"Failed to send all objects in a batch of {len(objs)}",
"error": repr(e),
}
)
response_obj = BatchObjectReturn(
_all_responses=list(errors_obj.values()),
elapsed_seconds=time.time() - start,
Expand Down Expand Up @@ -509,7 +536,9 @@ async def __send_batch(
)

readd_objects = [
err.object_ for i, err in response_obj.errors.items() if i in readded_objects
err.object_._to_internal()
for i, err in response_obj.errors.items()
if i in readded_objects
]
readded_uuids = {obj.uuid for obj in readd_objects}

Expand Down Expand Up @@ -541,8 +570,8 @@ async def __send_batch(
)
else:
# sleep a bit to recover from the rate limit in other cases
time.sleep(2**highest_retry_count)
self.__uuid_lookup_lock.acquire()
await asyncio.sleep(2**highest_retry_count)
await self.__uuid_lookup_lock.acquire()
self.__uuid_lookup.difference_update(
obj.uuid for obj in objs if obj.uuid not in readded_uuids
)
Expand All @@ -561,7 +590,7 @@ async def __send_batch(
"message": "There have been more than 30 failed object batches. Further errors will not be logged.",
}
)
self.__results_lock.acquire()
await self.__results_lock.acquire()
self.__results_for_wrapper.results.objs += response_obj
self.__results_for_wrapper.failed_objects.extend(response_obj.errors.values())
self.__results_lock.release()
Expand All @@ -573,7 +602,9 @@ async def __send_batch(
response_ref = await self.__batch_rest.references(references=refs)
except Exception as e:
errors_ref = {
idx: ErrorReference(message=repr(e), reference=ref)
idx: ErrorReference(
message=repr(e), reference=BatchReference._from_internal(ref)
)
for idx, ref in enumerate(refs)
}
response_ref = BatchReferenceReturn(
Expand All @@ -595,12 +626,12 @@ async def __send_batch(
"message": "There have been more than 30 failed reference batches. Further errors will not be logged.",
}
)
self.__results_lock.acquire()
await self.__results_lock.acquire()
self.__results_for_wrapper.results.refs += response_ref
self.__results_for_wrapper.failed_references.extend(response_ref.errors.values())
self.__results_lock.release()

self.__active_requests_lock.acquire()
await self.__active_requests_lock.acquire()
self.__active_requests -= 1
self.__active_requests_lock.release()

Expand Down Expand Up @@ -641,9 +672,7 @@ def _add_object(
)
except ValidationError as e:
raise WeaviateBatchValidationError(repr(e))
self.__uuid_lookup_lock.acquire()
self.__uuid_lookup.add(str(batch_object.uuid))
self.__uuid_lookup_lock.release()
self.__batch_objects.add(batch_object._to_internal())

# block if queue gets too long or weaviate is overloaded - reading files is faster than sending them so we do
Expand Down
5 changes: 4 additions & 1 deletion weaviate/collections/batch/grpc_batch_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from weaviate.collections.classes.batch import (
ErrorObject,
_BatchObject,
BatchObject,
BatchObjectReturn,
)
from weaviate.collections.classes.config import ConsistencyLevel
Expand Down Expand Up @@ -116,7 +117,9 @@ async def objects(
for idx, weav_obj in enumerate(weaviate_objs):
obj = objects[idx]
if idx in errors:
error = ErrorObject(errors[idx], obj, original_uuid=obj.uuid)
error = ErrorObject(
errors[idx], BatchObject._from_internal(obj), original_uuid=obj.uuid
)
return_errors[obj.index] = error
all_responses[idx] = error
else:
Expand Down
3 changes: 2 additions & 1 deletion weaviate/collections/batch/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from weaviate.collections.classes.batch import (
ErrorReference,
BatchReference,
_BatchReference,
BatchReferenceReturn,
)
Expand Down Expand Up @@ -45,7 +46,7 @@ async def references(self, references: List[_BatchReference]) -> BatchReferenceR
errors = {
idx: ErrorReference(
message=entry["result"]["errors"]["error"][0]["message"],
reference=references[idx],
reference=BatchReference._from_internal(references[idx]),
)
for idx, entry in enumerate(payload)
if entry["result"]["status"] == "FAILED"
Expand Down
39 changes: 36 additions & 3 deletions weaviate/collections/classes/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class _BatchReference:
to: str
tenant: Optional[str]
from_uuid: str
to_uuid: Optional[str] = None
to_uuid: Union[str, None]


class BatchObject(BaseModel):
Expand All @@ -49,6 +49,7 @@ class BatchObject(BaseModel):
vector: Optional[VECTORS] = Field(default=None)
tenant: Optional[str] = Field(default=None)
index: int
retry_count: int = 0

def __init__(self, **data: Any) -> None:
v = data.get("vector")
Expand Down Expand Up @@ -76,6 +77,19 @@ def _to_internal(self) -> _BatchObject:
index=self.index,
)

@classmethod
def _from_internal(cls, obj: _BatchObject) -> "BatchObject":
    """Build the public `BatchObject` model from its internal `_BatchObject` counterpart.

    Every field is copied over unchanged except `uuid`, which is parsed from `obj.uuid`
    into a `UUID` instance.
    """
    parsed_uuid = uuid_package.UUID(obj.uuid)
    return BatchObject(
        index=obj.index,
        retry_count=obj.retry_count,
        collection=obj.collection,
        tenant=obj.tenant,
        uuid=parsed_uuid,
        properties=obj.properties,
        references=obj.references,
        vector=obj.vector,
    )

@field_validator("collection")
def _validate_collection(cls, v: str) -> str:
return _capitalize_first_letter(v)
Expand Down Expand Up @@ -136,13 +150,32 @@ def _to_internal(self) -> _BatchReference:
tenant=self.tenant,
)

@classmethod
def _from_internal(cls, ref: _BatchReference) -> "BatchReference":
    """Rebuild the public `BatchReference` model from an internal `_BatchReference`.

    The `from_` and `to` beacons are stripped of their `weaviate://` scheme and split
    on `/` into path segments, from which the collection and property names are read.

    Raises:
        ValueError: If the `to` beacon does not split into one or two segments.
    """
    from_ = ref.from_.split("weaviate://")[1].split("/")
    to = ref.to.split("weaviate://")[1].split("/")
    # NOTE(review): `from_[1]` is read as the collection, which suggests segment 0 is the host.
    # If `to` beacons carry the same host segment, these length checks may be off by one - confirm
    # against the beacon format produced by `_to_internal`.
    if len(to) == 2:
        to_object_collection = to[1]
    elif len(to) == 1:
        to_object_collection = None
    else:
        raise ValueError(f"Invalid reference 'to' value in _BatchReference object {ref}")
    return BatchReference(
        from_object_collection=from_[1],
        from_object_uuid=ref.from_uuid,
        # use the last segment of the parsed beacon; indexing the raw `ref.from_` string
        # would only yield its final character
        from_property_name=from_[-1],
        to_object_uuid=ref.to_uuid if ref.to_uuid is not None else uuid_package.UUID(to[-1]),
        to_object_collection=to_object_collection,
        tenant=ref.tenant,
    )


@dataclass
class ErrorObject:
"""This class contains the error information for a single object in a batch operation."""

message: str
object_: _BatchObject
object_: BatchObject
original_uuid: Optional[UUID] = None


Expand All @@ -151,7 +184,7 @@ class ErrorReference:
"""This class contains the error information for a single reference in a batch operation."""

message: str
reference: _BatchReference
reference: BatchReference


@dataclass
Expand Down