Skip to content

Commit

Permalink
Refactor python SDK (#2364)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Refactor python SDK code

### Type of change

- [x] Refactoring
- [x] Python SDK impacted, Need to update PyPI

Signed-off-by: Jin Hai <[email protected]>
  • Loading branch information
JinHai-CN authored and vsian committed Dec 13, 2024
1 parent 2f0aa4c commit eac8d7b
Show file tree
Hide file tree
Showing 27 changed files with 250 additions and 118 deletions.
8 changes: 5 additions & 3 deletions python/infinity_embedded/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
# Copyright(C) 2024 InfiniFlow, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -20,12 +20,14 @@
# import pkg_resources
# __version__ = pkg_resources.get_distribution("infinity_sdk").version

from infinity_embedded.common import URI, NetworkAddress, LOCAL_HOST, LOCAL_INFINITY_PATH, InfinityException, LOCAL_INFINITY_CONFIG_PATH
from infinity_embedded.common import URI, NetworkAddress, LOCAL_HOST, LOCAL_INFINITY_PATH, InfinityException, \
LOCAL_INFINITY_CONFIG_PATH
from infinity_embedded.infinity import InfinityConnection
from infinity_embedded.local_infinity.infinity import LocalInfinityConnection
from infinity_embedded.errors import ErrorCode

def connect(uri = LOCAL_INFINITY_PATH, config_path = LOCAL_INFINITY_CONFIG_PATH) -> InfinityConnection:

def connect(uri=LOCAL_INFINITY_PATH, config_path=LOCAL_INFINITY_CONFIG_PATH) -> InfinityConnection:
if isinstance(uri, str) and len(uri) != 0:
return LocalInfinityConnection(uri, config_path)
else:
Expand Down
5 changes: 4 additions & 1 deletion python/infinity_embedded/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
# Copyright(C) 2024 InfiniFlow, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Union
from dataclasses import dataclass
Expand Down Expand Up @@ -75,10 +76,12 @@ class ConflictType(object):
Error = 1
Replace = 2


class SortType(object):
Asc = 0
Desc = 1


class InfinityException(Exception):
def __init__(self, error_code=0, error_message=None):
self.error_code = error_code
Expand Down
3 changes: 2 additions & 1 deletion python/infinity_embedded/db.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
# Copyright(C) 2024 InfiniFlow, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,6 +14,7 @@

from abc import ABC, abstractmethod


class Database(ABC):

@abstractmethod
Expand Down
35 changes: 32 additions & 3 deletions python/infinity_embedded/errors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
# Copyright(C) 2024 InfiniFlow, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -117,6 +117,15 @@ class ErrorCode(IntEnum):
INVALID_EXPLAIN_TYPE = 3081,
CHUNK_NOT_EXIST = 3082,
NAME_MISMATCHED = 3083,
TRANSACTION_NOT_FOUND = 3084,
INVALID_DATABASE_INDEX = 3085,
INVALID_TABLE_INDEX = 3086,
FUNCTION_IS_DISABLE = 3087,
NOT_FOUND = 3088,
ERROR_INIT = 3089,
FILE_IS_OPEN = 3090,
UNKNOWN = 3091,
INVALID_QUERY_OPTION = 3092,

TXN_ROLLBACK = 4001,
TXN_CONFLICT = 4002,
Expand All @@ -126,6 +135,7 @@ class ErrorCode(IntEnum):
TOO_MANY_CONNECTIONS = 5003,
CONFIGURATION_LIMIT_EXCEED = 5004,
QUERY_IS_TOO_COMPLEX = 5005,
FAIL_TO_GET_SYS_INFO = 5006,

QUERY_CANCELLED = 6001,
QUERY_NOT_SUPPORTED = 6002,
Expand All @@ -147,7 +157,26 @@ class ErrorCode(IntEnum):
MUNMAP_FILE_ERROR = 7014,
INVALID_FILE_FLAG = 7015,
INVALID_SERVER_ADDRESS = 7016,
FAIL_TO_FUN_PYTHON = 7017,
CANT_CONNECT_SERVER = 7018,
NOT_EXIST_NODE = 7019,
DUPLICATE_NODE = 7020,
CANT_CONNECT_LEADER = 7021,
MINIO_INVALID_ACCESS_KEY = 7022,
MINIO_BUCKET_NOT_EXISTS = 7023,
INVALID_STORAGE_TYPE = 7024,
NOT_REGISTERED = 7025,
CANT_SWITCH_ROLE = 7026,
TOO_MANY_FOLLOWER = 7027,
TOO_MANY_LEARNER = 7028,

INVALID_ENTRY = 8001,
NOT_FOUND_ENTRY = 8002,
EMPTY_ENTRY_LIST = 8003,
DUPLICATE_ENTRY = 8002
NOT_FOUND_ENTRY = 8003,
EMPTY_ENTRY_LIST = 8004,
NO_WAL_ENTRY_FOUND = 8005,
WRONG_CHECKPOINT_TYPE = 8006,
INVALID_NODE_ROLE = 8007,
INVALID_NODE_STATUS = 8008,
NODE_INFO_UPDATED = 8009,
NODE_NAME_MISMATCH = 8010
3 changes: 1 addition & 2 deletions python/infinity_embedded/index.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
# Copyright(C) 2024 InfiniFlow, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,6 @@

from infinity_embedded.embedded_infinity_ext import IndexType as LocalIndexType, WrapIndexInfo
from infinity_embedded.embedded_infinity_ext import InitParameter as LocalInitParameter
from infinity_embedded.embedded_infinity_ext import WrapIndexInfo as LocalIndexInfo
from infinity_embedded.errors import ErrorCode


Expand Down
4 changes: 3 additions & 1 deletion python/infinity_embedded/infinity.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
# Copyright(C) 2024 InfiniFlow, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod


# abstract class
class InfinityConnection(ABC):
def __init__(self, uri):
Expand Down
7 changes: 4 additions & 3 deletions python/infinity_embedded/local_infinity/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
# Copyright(C) 2024 InfiniFlow, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -22,7 +22,8 @@ class LocalQueryResult:
def __init__(self, error_code: PyErrorCode, error_msg: str, db_names=None, table_names=None, index_names=None,
column_defs=None, column_fields=None, database_name=None, store_dir=None, table_count=None,
comment=None,
table_name=None, index_name=None, index_type=None, index_comment=None, deleted_rows=0, extra_result=None):
table_name=None, index_name=None, index_type=None, index_comment=None, deleted_rows=0,
extra_result=None):
self.error_code = error_code
self.error_msg = error_msg
self.db_names = db_names
Expand All @@ -44,7 +45,7 @@ def __init__(self, error_code: PyErrorCode, error_msg: str, db_names=None, table


class LocalInfinityClient:
def __init__(self, path: str = LOCAL_INFINITY_PATH, config_path = LOCAL_INFINITY_CONFIG_PATH):
def __init__(self, path: str = LOCAL_INFINITY_PATH, config_path=LOCAL_INFINITY_CONFIG_PATH):
self.path = path
Infinity.LocalInit(path, config_path)
self.client = Infinity.LocalConnect()
Expand Down
14 changes: 14 additions & 0 deletions python/infinity_embedded/local_infinity/db.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright(C) 2024 InfiniFlow, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC

from infinity_embedded.db import Database
Expand Down
15 changes: 14 additions & 1 deletion python/infinity_embedded/local_infinity/infinity.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright(C) 2024 InfiniFlow, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from infinity_embedded import InfinityConnection
from abc import ABC
Expand Down Expand Up @@ -90,7 +104,6 @@ def show_current_node(self):
else:
raise InfinityException(res.error_code, res.error_msg)


def search(self, db_name, table_name):
self.check_connect()
res = self._client.search(db_name, table_name, [])
Expand Down
103 changes: 60 additions & 43 deletions python/infinity_embedded/local_infinity/query_builder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright(C) 2024 InfiniFlow, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import ABC
Expand All @@ -18,18 +32,19 @@
from infinity_embedded.table import ExplainType as BaseExplainType
from infinity_embedded.errors import ErrorCode


class Query(ABC):
def __init__(
self,
columns: Optional[List[WrapParsedExpr]],
highlight: Optional[List[WrapParsedExpr]],
search: Optional[WrapSearchExpr],
filter: Optional[WrapParsedExpr],
group_by: Optional[List[WrapParsedExpr]],
limit: Optional[WrapParsedExpr],
offset: Optional[WrapParsedExpr],
sort: Optional[List[WrapOrderByExpr]],
total_hits_count: Optional[bool]
self,
columns: Optional[List[WrapParsedExpr]],
highlight: Optional[List[WrapParsedExpr]],
search: Optional[WrapSearchExpr],
filter: Optional[WrapParsedExpr],
group_by: Optional[List[WrapParsedExpr]],
limit: Optional[WrapParsedExpr],
offset: Optional[WrapParsedExpr],
sort: Optional[List[WrapOrderByExpr]],
total_hits_count: Optional[bool]
):
self.columns = columns
self.highlight = highlight
Expand All @@ -44,16 +59,16 @@ def __init__(

class ExplainQuery(Query):
def __init__(
self,
columns: Optional[List[WrapParsedExpr]],
highlight: Optional[List[WrapParsedExpr]],
search: Optional[WrapSearchExpr],
filter: Optional[WrapParsedExpr],
group_by: Optional[List[WrapParsedExpr]],
limit: Optional[WrapParsedExpr],
offset: Optional[WrapParsedExpr],
sort: Optional[List[WrapOrderByExpr]],
explain_type: Optional[BaseExplainType],
self,
columns: Optional[List[WrapParsedExpr]],
highlight: Optional[List[WrapParsedExpr]],
search: Optional[WrapSearchExpr],
filter: Optional[WrapParsedExpr],
group_by: Optional[List[WrapParsedExpr]],
limit: Optional[WrapParsedExpr],
offset: Optional[WrapParsedExpr],
sort: Optional[List[WrapOrderByExpr]],
explain_type: Optional[BaseExplainType],
):
super().__init__(columns, highlight, search, filter, group_by, limit, offset, sort, None)
self.explain_type = explain_type
Expand Down Expand Up @@ -84,13 +99,13 @@ def reset(self):
self._total_hits_count = None

def match_dense(
self,
vector_column_name: str,
embedding_data: VEC,
embedding_data_type: str,
distance_type: str,
topn: int,
knn_params: {} = None,
self,
vector_column_name: str,
embedding_data: VEC,
embedding_data_type: str,
distance_type: str,
topn: int,
knn_params: {} = None,
) -> InfinityLocalQueryBuilder:
if self._search is None:
self._search = WrapSearchExpr()
Expand All @@ -108,7 +123,8 @@ def match_dense(
if embedding_data_type == "bit":
if len(embedding_data) % 8 != 0:
raise InfinityException(
ErrorCode.INVALID_EMBEDDING_DATA_TYPE, f"Embeddings with data bit must have dimension of times of 8!"
ErrorCode.INVALID_EMBEDDING_DATA_TYPE,
f"Embeddings with data bit must have dimension of times of 8!"
)
else:
new_embedding_data = []
Expand Down Expand Up @@ -174,7 +190,8 @@ def match_dense(
elem_type = EmbeddingDataType.kElemBFloat16
data.bf16_array_value = embedding_data
else:
raise InfinityException(ErrorCode.INVALID_EMBEDDING_DATA_TYPE, f"Invalid embedding {embedding_data[0]} type")
raise InfinityException(ErrorCode.INVALID_EMBEDDING_DATA_TYPE,
f"Invalid embedding {embedding_data[0]} type")

dist_type = KnnDistanceType.kInvalid
if distance_type == "l2":
Expand Down Expand Up @@ -218,12 +235,12 @@ def match_dense(
return self

def match_sparse(
self,
vector_column_name: str,
sparse_data: SparseVector | dict,
metric_type: str,
topn: int,
opt_params: {} = None,
self,
vector_column_name: str,
sparse_data: SparseVector | dict,
metric_type: str,
topn: int,
opt_params: {} = None,
) -> InfinityLocalQueryBuilder:
if self._search is None:
self._search = WrapSearchExpr()
Expand Down Expand Up @@ -298,7 +315,7 @@ def match_sparse(
return self

def match_text(
self, fields: str, matching_text: str, topn: int, extra_options: Optional[dict]
self, fields: str, matching_text: str, topn: int, extra_options: Optional[dict]
) -> InfinityLocalQueryBuilder:
if self._search is None:
self._search = WrapSearchExpr()
Expand All @@ -324,12 +341,12 @@ def match_text(
return self

def match_tensor(
self,
column_name: str,
query_data: VEC,
query_data_type: str,
topn: int,
extra_option: Optional[dict] = None,
self,
column_name: str,
query_data: VEC,
query_data_type: str,
topn: int,
extra_option: Optional[dict] = None,
) -> InfinityLocalQueryBuilder:
if self._search is None:
self._search = WrapSearchExpr()
Expand Down Expand Up @@ -674,7 +691,7 @@ def to_result(self) -> tuple[dict[str, list[Any]], dict[str, Any], {}]:
limit=self._limit,
offset=self._offset,
sort=self._sort,
total_hits_count = self._total_hits_count,
total_hits_count=self._total_hits_count,
)
self.reset()
return self._table._execute_query(query)
Expand Down
Loading

0 comments on commit eac8d7b

Please sign in to comment.