Skip to content

Commit

Permalink
change neighbor comparison methods to check str representation of key…
Browse files Browse the repository at this point in the history
… if distance is equal
  • Loading branch information
dwelch-spike committed Oct 11, 2024
1 parent 6a30658 commit a6468f8
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 35 deletions.
41 changes: 12 additions & 29 deletions src/aerospike_vector_search/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def __init__(self, *, namespace: str, set: str, key: Any) -> None:

def __repr__(self) -> str:
return (
f"namespace={self.namespace}, "
f"Key(namespace={self.namespace}, "
f"set={self.set}, "
f"key={self.key}"
f"key={self.key})"
)

def __str__(self):
Expand Down Expand Up @@ -137,9 +137,9 @@ def __init__(self, *, key: Key, fields: dict[str, Any], distance: float) -> None

def __repr__(self) -> str:
return (
f"key={self.key}, "
f"Neighbor(key={self.key}, "
f"fields={self.fields}, "
f"distance={self.distance}"
f"distance={self.distance})"
)

def __str__(self):
Expand Down Expand Up @@ -167,7 +167,6 @@ def __str__(self):
def __eq__(self, other) -> bool:
if not isinstance(other, Neighbor):
return NotImplemented

return (
self.distance == other.distance
and self.key == other.key
Expand All @@ -177,38 +176,22 @@ def __eq__(self, other) -> bool:
def __lt__(self, other) -> bool:
if not isinstance(other, Neighbor):
return NotImplemented

if self.distance == other.distance:
return self.key.key < other.key.key

return self.distance < other.distance
if self.distance != other.distance:
return self.distance < other.distance
if self.key.set != other.key.set:
return self.key.set < other.key.set
return str(self.key.key) < str(other.key.key)

def __le__(self, other) -> bool:
if not isinstance(other, Neighbor):
return NotImplemented

if self.distance == other.distance:
return self.key.key <= other.key.key

return self.distance <= other.distance
return self < other or self == other

def __gt__(self, other) -> bool:
if not isinstance(other, Neighbor):
return NotImplemented

if self.distance == other.distance:
return self.key.key > other.key.key

return self.distance > other.distance
return not (self <= other)

def __ge__(self, other) -> bool:
if not isinstance(other, Neighbor):
return NotImplemented

if self.distance == other.distance:
return self.key.key >= other.key.key

return self.distance >= other.distance
return not (self < other)



Expand Down
4 changes: 2 additions & 2 deletions tests/standard/aio/test_vector_client_search_by_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ async def test_vector_search_by_key(
namespace=test_case.key_namespace,
key=key,
record_data=rec,
key_set=test_case.key_set,
set_name=test_case.key_set,
))

tasks.append(
Expand All @@ -306,7 +306,7 @@ async def test_vector_search_by_key(
exclude_fields=test_case.exclude_fields,
)

assert results == test_case.expected_results
assert list.sort(results) == list.sort(test_case.expected_results)

tasks = []
for key in test_case.record_data:
Expand Down
2 changes: 1 addition & 1 deletion tests/standard/aio/test_vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ async def test_vector_search(
exclude_fields=test_case.exclude_fields,
)

assert results == test_case.expected_results
assert list.sort(results) == list.sort(test_case.expected_results)

tasks = []
for key in test_case.record_data:
Expand Down
4 changes: 2 additions & 2 deletions tests/standard/sync/test_vector_client_search_by_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def test_vector_search_by_key(
namespace=test_case.key_namespace,
key=key,
record_data=rec,
key_set=test_case.key_set,
set_name=test_case.key_set,
)

session_vector_client.wait_for_index_completion(
Expand All @@ -300,7 +300,7 @@ def test_vector_search_by_key(
exclude_fields=test_case.exclude_fields,
)

assert results == test_case.expected_results
assert list.sort(results) == list.sort(test_case.expected_results)

for key in test_case.record_data:
session_vector_client.delete(
Expand Down
2 changes: 1 addition & 1 deletion tests/standard/sync/test_vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test_vector_search(
exclude_fields=test_case.exclude_fields,
)

assert results == test_case.expected_results
assert list.sort(results) == list.sort(test_case.expected_results)

for key in test_case.record_data:
session_vector_client.delete(
Expand Down

0 comments on commit a6468f8

Please sign in to comment.