Skip to content

Commit

Permalink
Add some new function test cases. (#666)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Add some new function test cases.

### Type of change

- [x] Documentation Update
- [x] Test cases
- [x] Python SDK impacted, Need to update PyPI
  • Loading branch information
chrysanthemum-boy authored Feb 27, 2024
1 parent 384e780 commit 4f46119
Show file tree
Hide file tree
Showing 26 changed files with 406 additions and 188 deletions.
2 changes: 1 addition & 1 deletion docs/pysdk_api_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ This method searches for rows that satisfy the search condition and updates them
Gets a query_builder by self.

```python
query_builder=table_obj.query_builder()
query_builder=table_obj.query_builder
```

## Details
Expand Down
10 changes: 5 additions & 5 deletions python/infinity/remote_thrift/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def create_table(self, table_name: str, columns_definition: dict[str, str], opti
if res.error_code == ErrorCode.OK:
return RemoteTable(self._conn, self._db_name, table_name)
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")

def drop_table(self, table_name, if_exists=True):
check_valid_name(table_name, "Table")
Expand All @@ -146,7 +146,7 @@ def list_tables(self):
if res.error_code == ErrorCode.OK:
return res
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")

def describe_table(self, table_name):
check_valid_name(table_name, "Table")
Expand All @@ -155,7 +155,7 @@ def describe_table(self, table_name):
if res.error_code == ErrorCode.OK:
return select_res_to_polars(res)
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")

def get_table(self, table_name):
check_valid_name(table_name, "Table")
Expand All @@ -164,11 +164,11 @@ def get_table(self, table_name):
if res.error_code == ErrorCode.OK:
return RemoteTable(self._conn, self._db_name, table_name)
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")

def show_tables(self):
res = self._conn.show_tables(self._db_name)
if res.error_code == ErrorCode.OK:
return select_res_to_polars(res)
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")
14 changes: 7 additions & 7 deletions python/infinity/remote_thrift/infinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,46 +39,46 @@ def create_database(self, db_name: str, options=None):
if res.error_code == ErrorCode.OK:
return RemoteDatabase(self._client, db_name)
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")

def list_databases(self):
res = self._client.list_databases()
if res.error_code == ErrorCode.OK:
return res
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")

def describe_database(self, db_name: str):
check_valid_name(db_name, "DB")
res = self._client.describe_database(db_name=db_name)
if res.error_code == ErrorCode.OK:
return res
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")

def drop_database(self, db_name: str, options=None):
check_valid_name(db_name, "DB")
res = self._client.drop_database(db_name=db_name)
if res.error_code == ErrorCode.OK:
return res
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")

def get_database(self, db_name: str):
check_valid_name(db_name, "DB")
res = self._client.get_database(db_name)
if res.error_code == ErrorCode.OK:
return RemoteDatabase(self._client, db_name)
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")

def disconnect(self):
res = self._client.disconnect()
if res.error_code == ErrorCode.OK:
self._is_connected = False
return res
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")

@property
def client(self):
Expand All @@ -89,4 +89,4 @@ def show_variable(self, variable: ShowVariable):
if res.error_code == ErrorCode.OK:
return select_res_to_polars(res)
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")
16 changes: 8 additions & 8 deletions python/infinity/remote_thrift/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def create_index(self, index_name: str, index_infos: list[IndexInfo], options=No
if res.error_code == ErrorCode.OK:
return res
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")

def drop_index(self, index_name: str):
check_valid_name(index_name, "Index")
Expand All @@ -68,7 +68,7 @@ def drop_index(self, index_name: str):
if res.error_code == ErrorCode.OK:
return res
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")

def insert(self, data: Union[INSERT_DATA, list[INSERT_DATA]]):
# [{"c1": 1, "c2": 1.1}, {"c1": 2, "c2": 2.2}]
Expand Down Expand Up @@ -120,7 +120,7 @@ def insert(self, data: Union[INSERT_DATA, list[INSERT_DATA]]):
if res.error_code == ErrorCode.OK:
return res
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")

def import_data(self, file_path: str, options=None):

Expand Down Expand Up @@ -173,7 +173,7 @@ def import_data(self, file_path: str, options=None):
if res.error_code == ErrorCode.OK:
return res
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")

def delete(self, cond: Optional[str] = None):
match cond:
Expand All @@ -186,7 +186,7 @@ def delete(self, cond: Optional[str] = None):
if res.error_code == ErrorCode.OK:
return res
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")

def update(self, cond: Optional[str], data: Optional[list[dict[str, Union[str, int, float]]]]):
# {"c1": 1, "c2": 1.1}
Expand Down Expand Up @@ -228,7 +228,7 @@ def update(self, cond: Optional[str], data: Optional[list[dict[str, Union[str, i
if res.error_code == ErrorCode.OK:
return res
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")

def knn(self, vector_column_name: str, embedding_data: VEC, embedding_data_type: str, distance_type: str,
topn: int):
Expand Down Expand Up @@ -291,7 +291,7 @@ def _execute_query(self, query: Query) -> tuple[dict[str, list[Any]], dict[str,
if res.error_code == ErrorCode.OK:
return build_result(res)
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")

def _explain_query(self, query: ExplainQuery) -> Any:
res = self._conn.explain(db_name=self._db_name,
Expand All @@ -306,4 +306,4 @@ def _explain_query(self, query: ExplainQuery) -> Any:
if res.error_code == ErrorCode.OK:
return select_res_to_polars(res)
else:
raise Exception(f"ERROR:{res.error_code}, ", res.error_msg)
raise Exception(f"ERROR:{res.error_code}, {res.error_msg}")
54 changes: 54 additions & 0 deletions python/test/common/common_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class CaseLabel:
"""
Testcase Levels
CI Regression:
L0:
part of CI Regression
triggered by GitHub commit
optional used for dev to verify his fix before submitting a PR(like smoke)
~100 testcases and run in 3 mins
L1:
part of CI Regression
triggered by GitHub commit
must pass before merge
run in 15 mins
Benchmark:
L2:
E2E tests and bug-fix verification
Nightly run triggered by cron job
run in 60 mins
L3:
Stability/Performance/reliability, etc. special tests
Triggered by cron job or manually
run duration depends on test configuration
Loadbalance:
loadbalance testcases which need to be run in multi query nodes
ClusterOnly:
For functions only suitable to cluster mode
GPU:
For GPU supported cases
"""
L0 = "L0"
L1 = "L1"
L2 = "L2"
L3 = "L3"
RBAC = "RBAC"
Loadbalance = "Loadbalance" # loadbalance testcases which need to be run in multi query nodes
ClusterOnly = "ClusterOnly" # For functions only suitable to cluster mode
MultiQueryNodes = "MultiQueryNodes" # for 8 query nodes configs tests, such as resource group
GPU = "GPU"
File renamed without changes.
2 changes: 1 addition & 1 deletion python/test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pandas as pd
from numpy import dtype

import common_values
from python.test.common import common_values
import infinity
import infinity.index as index

Expand Down
2 changes: 1 addition & 1 deletion python/test/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import pytest
import common_values
from python.test.common import common_values
import infinity
from infinity.common import NetworkAddress

Expand Down
21 changes: 19 additions & 2 deletions python/test/test_covert.py → python/test/test_convert.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import common_values
from python.test.common import common_values
import infinity


class TestCovert:
class TestConvert:
def test_to_pl(self):
infinity_obj = infinity.connect(common_values.TEST_REMOTE_HOST)
db_obj = infinity_obj.get_database("default")
Expand Down Expand Up @@ -36,3 +36,20 @@ def test_to_pa(self):
print(res)
res = table_obj.output(["c1", "c2", "c1"]).to_arrow()
print(res)

def test_to_df(self):
infinity_obj = infinity.connect(common_values.TEST_REMOTE_HOST)
db_obj = infinity_obj.get_database("default")
db_obj.drop_table("test_to_pa", True)
db_obj.create_table("test_to_pa", {
"c1": "int", "c2": "float"}, None)

table_obj = db_obj.get_table("test_to_pa")
table_obj.insert([{"c1": 1, "c2": 2.0}])
print()
res = table_obj.output(["c1", "c2"]).to_df()
print(res)
res = table_obj.output(["c1", "c1"]).to_df()
print(res)
res = table_obj.output(["c1", "c2", "c1"]).to_df()
print(res)
Loading

0 comments on commit 4f46119

Please sign in to comment.