Skip to content

Commit

Permalink
Simplify numpy usage
Browse files Browse the repository at this point in the history
  • Loading branch information
tibor-reiss committed Sep 25, 2024
1 parent 1ec45b9 commit 6eb0ad9
Showing 1 changed file with 5 additions and 13 deletions.
18 changes: 5 additions & 13 deletions integration/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import uuid
from typing import Any, Callable, Dict, List, Optional, Sequence, TypedDict, Union

import numpy as np
import pytest

from integration.conftest import CollectionFactory, CollectionFactoryGet, _sanitize_collection_name
Expand Down Expand Up @@ -61,15 +62,6 @@
DATE3 = datetime.datetime.strptime("2019-06-10", "%Y-%m-%d").replace(tzinfo=datetime.timezone.utc)


def get_numpy_vector(input_list: list) -> Any:
try:
import numpy as np

return np.array(input_list)
except ModuleNotFoundError:
return input_list


def test_insert_with_typed_dict_generic(
collection_factory: CollectionFactory,
collection_factory_get: CollectionFactoryGet,
Expand Down Expand Up @@ -309,18 +301,18 @@ class TestInsertManyWithTypedDict(TypedDict):
(
[
DataObject(
properties={"name": "some numpy one"}, vector=get_numpy_vector([1, 2, 3])
properties={"name": "some numpy one"}, vector=np.array([1, 2, 3])
),
],
False,
),
(
[
DataObject(
properties={"name": "some numpy one"}, vector=get_numpy_vector([1, 2, 3])
properties={"name": "some numpy one"}, vector=np.array([1, 2, 3])
),
DataObject(
properties={"name": "some numpy two"}, vector=get_numpy_vector([11, 12, 13])
properties={"name": "some numpy two"}, vector=np.array([11, 12, 13])
),
],
False,
Expand All @@ -329,7 +321,7 @@ class TestInsertManyWithTypedDict(TypedDict):
[
DataObject(
properties={"name": "some numpy 2d"},
vector=get_numpy_vector([[1, 2, 3], [11, 12, 13]]),
vector=np.array([[1, 2, 3], [11, 12, 13]]),
),
],
True,
Expand Down

0 comments on commit 6eb0ad9

Please sign in to comment.