diff --git a/.gitignore b/.gitignore
index 94c98aa1d1..1eaf666bf5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -99,8 +99,8 @@ benchmark/csv/csv_config.h
 
 #python sdk
 python/build
 python/dist
-python/infinity.egg-info
 python/.idea
+python/infinity_sdk.egg-info
 
 sift_1m
diff --git a/python/README.md b/python/README.md
index 5a90a25f69..53853c5244 100644
--- a/python/README.md
+++ b/python/README.md
@@ -17,106 +17,35 @@ sudo apt-get install thrift-compiler
 
 # build
 
 ```shell
-python setup.py bdist_wheel
+python setup.py sdist bdist_wheel
 ```
 
 # install
 
 ```shell
-pip install dist/infinity-0.0.1-py3-none-any.whl
+pip install dist/*.whl
 ```
+
+# upload
+
+```shell
+twine upload dist/*
+```
 
 # using
 
 ```python
 import infinity
-from infinity import NetworkAddress
-
-infinity_obj = infinity.connect(NetworkAddress('127.0.0.1', 23817))
-
-# infinity
-res = infinity_obj.create_database("my_db")
-
-res = infinity_obj.list_databases()
-
-res = infinity_obj.drop_database("my_db")
-
-db_obj = infinity_obj.get_database("default")
-res = db_obj.create_table("my_table1", {"c1": "int, primary key"}, None)
-
-res = db_obj.list_tables()
-
-res = db_obj.drop_table("my_table1")
-
-# index
-res = db_obj.create_table("my_table2", {"c1": "vector,1024,float"}, None)
-
-table_obj = db_obj.get_table("my_table2")
-
-res = table_obj.create_index("my_index",
-                             [index.IndexInfo("c1",
-                                              index.IndexType.IVFFlat,
-                                              [index.InitParameter("centroids_count", "128"),
-                                               index.InitParameter("metric", "l2")])], None)
-
-res = table_obj.drop_index("my_index")
-
-res = db_obj.drop_table("my_table2")
-
-# insert
-res = db_obj.create_table("my_table3", {"c1": "int, primary key", "c2": "float"}, None)
-
-table_obj = db_obj.get_table("my_table3")
-
-res = table_obj.insert([{"c1": 1, "c2": 1.1}, {"c1": 2, "c2": 2.2}])
-
-res = db_obj.create_table("test_insert_varchar", {"c1": "varchar"}, None)
-
-table_obj = db_obj.get_table("test_insert_varchar")
-
-res = table_obj.insert([{"c1": "test_insert_varchar"}])
-
-res = db_obj.create_table("test_insert_embedding", {"c1": "vector,3,int"}, None)
-
-table_obj = db_obj.get_table("test_insert_embedding")
-
-res = table_obj.insert([{"c1": [4, 5, 6]}])
-
-res = table_obj.insert([{"c1": [1, 2, 3]}, {"c1": [4, 5, 6]}, {"c1": [7, 8, 9]}, {"c1": [-7, -8, -9]}])
-
-# search
-
-res = table_obj.search().output(["c1 + 0.1"]).to_df()
-
-res = table_obj.search().output(["*"]).filter("c1 > 1").to_df()
-
-# import
-res = db_obj.create_table("my_table4", {"c1": "int", "c2": "vector,3,int"}, None)
-table_obj = db_obj.get_table("my_table4")
-parent_dir = os.path.dirname(os.path.dirname(os.getcwd()))
-test_csv_dir = parent_dir + "/test/data/csv/embedding_int_dim3.csv"
-
-res = table_obj.import_data(test_csv_dir, None)
-
-# search
-res = table_obj.search().output(["c1"]).filter("c1 > 1").to_df()
-
-res = db_obj.drop_table("my_table4")
-
-
-res = db_obj.create_table("table_4", {"c1": "int, primary key, not null", "c2": "int", "c3": "int"}, None)
-
-table_obj = db_obj.get_table("table_4")
-
-res = table_obj.insert(
-    [{"c1": 1, "c2": 10, "c3": 100}, {"c1": 2, "c2": 20, "c3": 200}, {"c1": 3, "c2": 30, "c3": 300},
-     {"c1": 4, "c2": 40, "c3": 400}])
-
-res = table_obj.update("c1 = 1", [{"c2": 90, "c3": 900}])
-
-res = table_obj.delete("c1 = 1")
-
-res = table_obj.delete()
-
-# disconnect
-res = infinity_obj.disconnect()
+from infinity.common import REMOTE_HOST
+
+infinity_obj = infinity.connect(REMOTE_HOST)
+db = infinity_obj.get_database("default")
+db.drop_table("my_table", if_exists=True)
+table = db.create_table(
+    "my_table", {"num": "integer", "body": "varchar", "vec": "vector,5,float"}, None)
+table.insert(
+    [{"num": 1, "body": "undesirable, unnecessary, and harmful", "vec": [1.0] * 5}])
+table.insert(
+    [{"num": 2, "body": "publisher=US National Office for Harmful Algal Blooms", "vec": [4.0] * 5}])
+table.insert(
+    [{"num": 3, "body": "in the case of plants, growth and chemical", "vec": [7.0] * 5}])
+
+res = table.output(["*"]).knn("vec", [3.0] * 5, "float", "ip", 2).to_pl()
+print(res)
 ```
\ No newline at end of file
diff --git a/python/benchmark/remote_benchmark_milvus.py b/python/benchmark/remote_benchmark_milvus.py
deleted file mode 100644
index 6e27607b1e..0000000000
--- a/python/benchmark/remote_benchmark_milvus.py
+++ /dev/null
@@ -1,322 +0,0 @@
-# hello_milvus.py demonstrates the basic operations of PyMilvus, a Python SDK of Milvus.
-# 1. connect to Milvus
-# 2. create collection
-# 3. insert data
-# 4. create index
-# 5. search, query, and hybrid search on entities
-# 6. delete entities by PK
-# 7. drop collection
-import argparse
-import os
-import struct
-import time
-
-import numpy as np
-from pymilvus import (
-    connections,
-    utility,
-    FieldSchema, CollectionSchema, DataType,
-    Collection,
-)
-
-from benchmark.hello_milvus import fmt
-
-
-def test_hello_milvus():
-    fmt = "\n=== {:30} ===\n"
-    search_latency_fmt = "search latency = {:.4f}s"
-    num_entities, dim = 3000, 8
-
-    #################################################################################
-    # 1. connect to Milvus
-    # Add a new connection alias `default` for Milvus server in `localhost:19530`
-    # Actually the "default" alias is a buildin in PyMilvus.
-    # If the address of Milvus is the same as `localhost:19530`, you can omit all
-    # parameters and call the method as: `connections.connect()`.
-    #
-    # Note: the `using` parameter of the following methods is default to "default".
-    print(fmt.format("start connecting to Milvus"))
-    connections.connect("default", host="localhost", port="19530")
-
-    has = utility.has_collection("hello_milvus")
-    print(f"Does collection hello_milvus exist in Milvus: {has}")
-
-    #################################################################################
-    # 2. create collection
-    # We're going to create a collection with 3 fields.
-    # +-+------------+------------+------------------+------------------------------+
-    # | | field name | field type | other attributes |      field description       |
-    # +-+------------+------------+------------------+------------------------------+
-    # |1| "pk"       | VarChar    | is_primary=True  | "primary field"              |
-    # | |            |            | auto_id=False    |                              |
-    # +-+------------+------------+------------------+------------------------------+
-    # |2| "random"   | Double     |                  | "a double field"             |
-    # +-+------------+------------+------------------+------------------------------+
-    # |3|"embeddings"| FloatVector| dim=8            | "float vector with dim 8"    |
-    # +-+------------+------------+------------------+------------------------------+
-    fields = [
-        FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),
-        FieldSchema(name="random", dtype=DataType.DOUBLE),
-        FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim)
-    ]
-
-    schema = CollectionSchema(fields, "hello_milvus is the simplest demo to introduce the APIs")
-
-    print(fmt.format("Create collection `hello_milvus`"))
-    hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong")
-
-    ################################################################################
-    # 3. insert data
-    # We are going to insert 3000 rows of data into `hello_milvus`
-    # Data to be inserted must be organized in fields.
-    #
-    # The insert() method returns:
-    # - either automatically generated primary keys by Milvus if auto_id=True in the schema;
-    # - or the existing primary key field from the entities if auto_id=False in the schema.
-
-    print(fmt.format("Start inserting entities"))
-    rng = np.random.default_rng(seed=19530)
-    entities = [
-        # provide the pk field because `auto_id` is set to False
-        [str(i) for i in range(num_entities)],
-        rng.random(num_entities).tolist(),  # field random, only supports list
-        rng.random((num_entities, dim)),  # field embeddings, supports numpy.ndarray and list
-    ]
-
-    insert_result = hello_milvus.insert(entities)
-
-    hello_milvus.flush()
-    print(f"Number of entities in Milvus: {hello_milvus.num_entities}")  # check the num_entities
-
-    ################################################################################
-    # 4. create index
-    # We are going to create an IVF_FLAT index for hello_milvus collection.
-    # create_index() can only be applied to `FloatVector` and `BinaryVector` fields.
-    print(fmt.format("Start Creating index IVF_FLAT"))
-    index = {
-        "index_type": "IVF_FLAT",
-        "metric_type": "L2",
-        "params": {"nlist": 128},
-    }
-
-    hello_milvus.create_index("embeddings", index)
-
-    ################################################################################
-    # 5. search, query, and hybrid search
-    # After data were inserted into Milvus and indexed, you can perform:
-    # - search based on vector similarity
-    # - query based on scalar filtering(boolean, int, etc.)
-    # - hybrid search based on vector similarity and scalar filtering.
-    #
-
-    # Before conducting a search or a query, you need to load the data in `hello_milvus` into memory.
-    print(fmt.format("Start loading"))
-    hello_milvus.load()
-
-    # -----------------------------------------------------------------------------
-    # search based on vector similarity
-    print(fmt.format("Start searching based on vector similarity"))
-    vectors_to_search = entities[-1][-2:]
-    search_params = {
-        "metric_type": "L2",
-        "params": {"nprobe": 10},
-    }
-
-    start_time = time.time()
-    result = hello_milvus.search(vectors_to_search, "embeddings", search_params, limit=3, output_fields=["random"])
-    end_time = time.time()
-
-    for hits in result:
-        for hit in hits:
-            print(f"hit: {hit}, random field: {hit.entity.get('random')}")
-    print(search_latency_fmt.format(end_time - start_time))
-
-    # -----------------------------------------------------------------------------
-    # query based on scalar filtering(boolean, int, etc.)
-    print(fmt.format("Start querying with `random > 0.5`"))
-
-    start_time = time.time()
-    result = hello_milvus.query(expr="random > 0.5", output_fields=["random", "embeddings"])
-    end_time = time.time()
-
-    print(f"query result:\n-{result[0]}")
-    print(search_latency_fmt.format(end_time - start_time))
-
-    # -----------------------------------------------------------------------------
-    # pagination
-    r1 = hello_milvus.query(expr="random > 0.5", limit=4, output_fields=["random"])
-    r2 = hello_milvus.query(expr="random > 0.5", offset=1, limit=3, output_fields=["random"])
-    print(f"query pagination(limit=4):\n\t{r1}")
-    print(f"query pagination(offset=1, limit=3):\n\t{r2}")
-
-    # -----------------------------------------------------------------------------
-    # hybrid search
-    print(fmt.format("Start hybrid searching with `random > 0.5`"))
-
-    start_time = time.time()
-    result = hello_milvus.search(vectors_to_search, "embeddings", search_params, limit=3, expr="random > 0.5",
-                                 output_fields=["random"])
-    end_time = time.time()
-
-    for hits in result:
-        for hit in hits:
-            print(f"hit: {hit}, random field: {hit.entity.get('random')}")
-    print(search_latency_fmt.format(end_time - start_time))
-
-    ###############################################################################
-    # 6. delete entities by PK
-    # You can delete entities by their PK values using boolean expressions.
-    ids = insert_result.primary_keys
-
-    expr = f'pk in ["{ids[0]}" , "{ids[1]}"]'
-    print(fmt.format(f"Start deleting with expr `{expr}`"))
-
-    result = hello_milvus.query(expr=expr, output_fields=["random", "embeddings"])
-    print(f"query before delete by expr=`{expr}` -> result: \n-{result[0]}\n-{result[1]}\n")
-
-    hello_milvus.delete(expr)
-
-    result = hello_milvus.query(expr=expr, output_fields=["random", "embeddings"])
-    print(f"query after delete by expr=`{expr}` -> result: {result}\n")
-
-    ###############################################################################
-    # 7. drop collection
-    # Finally, drop the hello_milvus collection
-    print(fmt.format("Drop collection `hello_milvus`"))
-    utility.drop_collection("hello_milvus")
-
-
-def fvecs_read_all(filename):
-    vectors = []
-    with open(filename, 'rb') as f:
-        while True:
-            try:
-                dims = struct.unpack('i', f.read(4))[0]
-                vec = struct.unpack('{}f'.format(dims), f.read(4 * dims))
-                assert dims == len(vec)
-                vectors.append(list(vec))
-            except struct.error:
-                break
-    return vectors
-
-
-def import_sift_1m_milvus(path):
-    print(fmt.format("start connecting to Milvus"))
-    connections.connect("default", host="localhost", port="19530")
-
-    has = utility.has_collection("sift_benchmark")
-    print(f"Does collection sift_benchmark exist in Milvus: {has}")
-
-    num_entities, dim = 100_0000, 128
-    fields = [
-        FieldSchema(name="col1", dtype=DataType.FLOAT_VECTOR, dim=dim)
-    ]
-    schema = CollectionSchema(fields, "sift_1m is a collection for benchmark")
-    sift_collection = Collection("sift_benchmark", schema, consistency_level="Strong")
-
-    # import data
-    print("Start importing data")
-    vectors = fvecs_read_all(path)
-    print(f"Number of entities: {len(vectors)}")
-
-    entities = [
-        vectors
-    ]
-
-    start = time.time()
-    sift_collection.insert(entities)
-    end = time.time()
-    print(f"Time to insert {num_entities} entities: {end - start}")
-
-    sift_collection.flush()
-    print(f"Number of entities in Milvus: {sift_collection.num_entities}")
-
-
-def import_gist_1m_milvus(path):
-    num_entities, dim = 100_0000, 960
-    fields = [
-        FieldSchema(name="col1", dtype=DataType.FLOAT_VECTOR, dim=dim)
-    ]
-    schema = CollectionSchema(fields, "gist_1m is a collection for benchmark")
-    gist_collection = Collection("gist_1m", schema, consistency_level="Strong")
-
-    # import data
-    print("Start importing data")
-    vectors = fvecs_read_all(path)
-    print(f"Number of entities: {len(vectors)}")
-
-    entities = [
-        vectors
-    ]
-
-    gist_collection.insert(entities)
-    gist_collection.flush()
-
-    print(f"Number of entities in Milvus: {gist_collection.num_entities}")
-
-
-def import_data(path):
-    if os.path.exists(path + "/sift_base.fvecs"):
-        import_sift_1m_milvus(path + "/sift_base.fvecs")
-    elif os.path.exists(path + "/gist_base.fvecs"):
-        import_gist_1m_milvus(path + "/gist_base.fvecs")
-    else:
-        raise Exception("Invalid data set")
-
-
-def benchmark(threads, rounds, data_set, path):
-    import_data(path)
-    if not os.path.exists(path):
-        print(f"Path: {path} doesn't exist")
-        raise Exception(f"Path: {path} doesn't exist")
-    if data_set == "sift_1m":
-        query_path = path + "/sift_query.fvecs"
-        ground_truth_path = path + "/sift_groundtruth.ivecs"
-
-
-
-    elif data_set == "gist_1m":
-        query_path = path + "/gist_query.fvecs"
-        ground_truth_path = path + "/gist_groundtruth.ivecs"
-    else:
-        raise Exception("Invalid data set")
-
-
-if __name__ == '__main__':
-    current_path = os.getcwd()
-    parent_path = os.path.dirname(current_path)
-    parent_path = os.path.dirname(parent_path)
-
-    print(f"Current Path: {current_path}")
-    print(f"Parent Path: {parent_path}")
-
-    parser = argparse.ArgumentParser(description="Benchmark Infinity")
-
-    parser.add_argument(
-        "-t",
-        "--threads",
-        type=int,
-        default=1,
-        dest="threads",
-    )
-    parser.add_argument(
-        "-r",
-        "--rounds",
-        type=int,
-        default=5,
-        dest="rounds",
-    )
-    parser.add_argument(
-        "-d",
-        "--data",
-        type=str,
-        default='sift_1m',  # gist_1m
-        dest="data_set",
-    )
-
-    data_dir = parent_path + "/test/data/benchmark/" + parser.parse_args().data_set
-    print(f"Data Dir: {data_dir}")
-
-    args = parser.parse_args()
-
-    benchmark(args.threads, args.rounds, args.data_set, path=data_dir)
diff --git a/python/benchmark/test_benchmark_import.py b/python/benchmark/test_benchmark_import.py
deleted file mode 100644
index 88b3d3a7aa..0000000000
--- a/python/benchmark/test_benchmark_import.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import time
-
-import infinity
-from infinity import index
-from infinity.common import REMOTE_HOST
-from infinity.remote_thrift.client import ThriftInfinityClient
-from infinity.remote_thrift.table import RemoteTable
-
-
-def find_path():
-    current_path = os.getcwd()
-    parent_path = os.path.dirname(current_path)
-    parent_path = os.path.dirname(parent_path)
-
-    print(f"Current Path: {current_path}")
-    print(f"Parent Path: {parent_path}")
-
-    data_dir = parent_path + "/test/data/benchmark/sift_1m"
-    print(f"Data Dir: {data_dir}")
-    return data_dir
-
-
-class TestImportBenchmark:
-
-    def test_import(self):
-        """
-        target: test import data to remote server
-        method: connect server, create table, import data, search, drop table, disconnect
-        expect: all operations successfully
-        """
-
-        infinity_obj = infinity.connect(REMOTE_HOST)
-        assert infinity_obj
-
-        st = time.process_time()
-
-        db_obj = infinity_obj.get_database("default")
-        assert db_obj
-        db_obj.drop_table("sift_benchmark", True)
-        db_obj.create_table("sift_benchmark", {"col1": "vector,128,float"}, None)
-        table_obj = db_obj.get_table("sift_benchmark")
-        assert table_obj
-
-        test_fvecs_dir = find_path() + "/sift_base.fvecs"
-        assert os.path.exists(test_fvecs_dir)
-        #
-        res = table_obj.import_data(test_fvecs_dir, None)
-        assert res.success
-
-        end = time.process_time()
-        dur = end - st
-        print(dur)
-
-    def test_create_index(self):
-        st = time.process_time()
-        conn = ThriftInfinityClient(REMOTE_HOST)
-        table = RemoteTable(conn, "default", "sift_benchmark")
-        res = table.create_index("hnsw_index",
-                                 [index.IndexInfo("col1",
-                                                  index.IndexType.Hnsw,
-                                                  [
-                                                      index.InitParameter("M", "16"),
-                                                      index.InitParameter("ef_construction", "200"),
-                                                      index.InitParameter("ef", "200"),
-                                                      index.InitParameter("metric", "l2"),
-                                                      index.InitParameter("encode", "lvq")
-                                                  ])], None)
-
-        assert res.success
-
-        end = time.process_time()
-        dur = end - st
-        print(dur)
diff --git a/python/benchmark/test_benchmark_query.py b/python/benchmark/test_benchmark_query.py
deleted file mode 100644
index 6035366c5f..0000000000
--- a/python/benchmark/test_benchmark_query.py
+++ /dev/null
@@ -1,267 +0,0 @@
-# Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import multiprocessing
-import os
-import struct
-import time
-from concurrent.futures import ThreadPoolExecutor
-
-import polars as pl
-
-from benchmark.test_benchmark import trace_unhandled_exceptions
-from benchmark.test_benchmark_import import find_path
-from infinity.common import REMOTE_HOST
-from infinity.remote_thrift.client import ThriftInfinityClient
-from infinity.remote_thrift.query_builder import InfinityThriftQueryBuilder
-from infinity.remote_thrift.table import RemoteTable
-
-
-def fvecs_read(filename):
-    with open(filename, 'rb') as f:
-        while True:
-            try:
-                dims = struct.unpack('i', f.read(4))[0]
-                vec = struct.unpack('{}f'.format(dims), f.read(4 * dims))
-                assert dims == len(vec)
-                yield list(vec)
-            except struct.error:
-                break
-
-
-def fvecs_read_all(filename):
-    vectors = []
-    with open(filename, 'rb') as f:
-        while True:
-            try:
-                dims = struct.unpack('i', f.read(4))[0]
-                vec = struct.unpack('{}f'.format(dims), f.read(4 * dims))
-                assert dims == len(vec)
-                vectors.append(list(vec))
-            except struct.error:
-                break
-    return vectors
-
-
-def read_groundtruth(filename):
-    vectors = []
-    with open(filename, 'rb') as f:
-        while True:
-            try:
-                dims = struct.unpack('i', f.read(4))[0]
-                vec = struct.unpack('{}i'.format(dims), f.read(4 * dims))
-                assert dims == len(vec)
-                vectors.append(list(vec))
-            except struct.error:
-                break
-
-    print("len(vectors): ", len(vectors))
-
-    gt_count = len(vectors)
-    gt_top_k = len(vectors[0])
-
-    ground_truth_sets_1 = [set() for _ in range(gt_count)]
-    ground_truth_sets_10 = [set() for _ in range(gt_count)]
-    ground_truth_sets_100 = [set() for _ in range(gt_count)]
-
-    for i in range(gt_count):
-        for j in range(gt_top_k):
-            x = vectors[i][j]
-            if j < 1:
-                ground_truth_sets_1[i].add(x)
-            if j < 10:
-                ground_truth_sets_10[i].add(x)
-            if j < 100:
-                ground_truth_sets_100[i].add(x)
-
-    return ground_truth_sets_1, ground_truth_sets_10, ground_truth_sets_100
-
-
-def calculate_recall(ground_truth_sets, result_sets):
-    recall = 0.0
-    for i in range(len(ground_truth_sets)):
-        if ground_truth_sets[i] in result_sets:
-            recall += 1
-    return recall / len(result_sets) * len(ground_truth_sets)
-
-
-def calculate_recall_all(ground_truth_sets_1, ground_truth_sets_10, ground_truth_sets_100, query_results):
-    correct_1 = 0.0
-    correct_10 = 0.0
-    correct_100 = 0.0
-
-    for query_idx in range(len(query_results)):
-        result = query_results[query_idx]
-        ground_truth_1 = ground_truth_sets_1[query_idx]
-        ground_truth_10 = ground_truth_sets_10[query_idx]
-        ground_truth_100 = ground_truth_sets_100[query_idx]
-
-        for i in range(len(result)):
-            if i < 1 and result[i] in ground_truth_1:
-                correct_1 += 1.0
-            if i < 10 and result[i] in ground_truth_10:
-                correct_10 += 1.0
-            if i < 100 and result[i] in ground_truth_100:
-                correct_100 += 1.0
-
-    recall_1 = correct_1 / (len(query_results) * 1)
-    recall_10 = correct_10 / (len(query_results) * 10)
-    recall_100 = correct_100 / (len(query_results) * 100)
-    return recall_1, recall_10, recall_100
-
-
-def test_read_all():
-    sift_query_path = os.getcwd() + "/sift_1m/sift/l2_groundtruth.ivecs"
-    vectors = fvecs_read_all(sift_query_path)
-    print(len(vectors))
-    print(pl.DataFrame(vectors))
-
-
-def test_read():
-    sift_query_path = os.getcwd() + "/sift_1m/sift/l2_groundtruth.ivecs"
-    for idx, query_vec in enumerate(fvecs_read(sift_query_path)):
-        print(pl.DataFrame([query_vec]))
-        if idx == 10:
-            assert idx == 10
-            break
-
-
-def test_read_groundtruth():
-    read_groundtruth_path = os.getcwd() + "/sift_1m/sift/l2_groundtruth.ivecs"
-    ground_truth_sets_1, ground_truth_sets_10, ground_truth_sets_100 = read_groundtruth(read_groundtruth_path)
-    print(len(ground_truth_sets_1))
-    print(len(ground_truth_sets_10))
-    print(len(ground_truth_sets_100))
-    print(len(ground_truth_sets_1[0]))
-    print(len(ground_truth_sets_10[0]))
-    print(len(ground_truth_sets_100[0]))
-
-
-@trace_unhandled_exceptions
-def work(query_vec, topk, metric_type, column_name, data_type):
-    conn = ThriftInfinityClient(REMOTE_HOST)
-    table = RemoteTable(conn, "default", "sift_benchmark")
-    query_builder = InfinityThriftQueryBuilder(table)
-    query_builder.output(["_row_id"])
-    query_builder.knn(column_name, query_vec, data_type, metric_type, topk)
-    query_builder.to_result()
-
-
-class TestQueryBenchmark:
-
-    def test_process_pool(self):
-        round = 1
-        total_times = 10000
-        client_num = 25
-        sift_query_path = find_path() + "/sift_query.fvecs"
-        if not os.path.exists(sift_query_path):
-            print(f"File: {sift_query_path} doesn't exist")
-            raise Exception(f"File: {sift_query_path} doesn't exist")
-
-        start = time.time()
-
-        p = multiprocessing.Pool(client_num)
-
-        for i in range(round):
-            for idx, query_vec in enumerate(fvecs_read(sift_query_path)):
-                p.apply_async(work, args=(query_vec, 100, "l2", "col1", "float"))
-                if idx == total_times:
-                    assert idx == total_times
-                    break
-
-        p.close()
-        p.join()
-
-        end = time.time()
-        dur = end - start
-        print(">>> Query Benchmark End <<<")
-        print(f"Total Times: {total_times * round}")
-        print(f"Total Dur: {dur}")
-        qps = (total_times * round) / dur
-        print(f"QPS: {qps}")
-
-    def test_thread_pool(self):
-        total_times = 10000
-        sift_query_path = find_path() + "/sift_1m/sift_query.fvecs"
-        if not os.path.exists(sift_query_path):
-            print(f"File: {sift_query_path} doesn't exist")
-            return
-
-        start = time.time()
-
-        with ThreadPoolExecutor(max_workers=16) as executor:
-            for idx, query_vec in enumerate(fvecs_read(sift_query_path)):
-                executor.submit(work, query_vec, 100, "l2", "col1", "float")
-                if idx == total_times:
-                    assert idx == total_times
-                    break
-
-        end = time.time()
-        dur = end - start
-        print(">>> Query Benchmark End <<<")
-        print(f"Total Times: {total_times}")
-        print(f"Total Dur: {dur}")
-        qps = total_times / dur
-        print(f"QPS: {qps}")
-
-    def test_query(self):
-        thread_num = 1
-        total_times = 10000
-
-        print(">>> Query Benchmark Start <<<")
-        print(f"Thread Num: {thread_num}, Times: {total_times}")
-
-        sift_query_path = find_path() + "/sift_query.fvecs"
-        if not os.path.exists(sift_query_path):
-            print(f"File: {sift_query_path} doesn't exist")
-            return
-
-        conn = ThriftInfinityClient(REMOTE_HOST)
-        table = RemoteTable(conn, "default", "sift_benchmark")
-        queries = fvecs_read_all(sift_query_path)
-        query_results = [[] for _ in range(len(queries))]
-
-        dur = 0.0
-        for idx, query_vec in enumerate(queries):
-
-            start = time.time()
-
-            query_builder = InfinityThriftQueryBuilder(table)
-            query_builder.output(["_row_id"])
-            query_builder.knn('col1', query_vec, 'float', 'l2', 100)
-            res, _ = query_builder.to_result()
-            end = time.time()
-
-            diff = end - start
-            dur += diff
-
-            res_list = res["ROW_ID"]
-            # print(len(res_list))
-
-            for i in range(len(res_list)):
-                query_results[idx].append(res_list[i][1])
-
-        read_groundtruth_path = os.getcwd() + "/sift_1m/sift_groundtruth.ivecs"
-        ground_truth_sets_1, ground_truth_sets_10, ground_truth_sets_100 = read_groundtruth(read_groundtruth_path)
-
-        recall_1, recall_10, recall_100 = calculate_recall_all(ground_truth_sets_1, ground_truth_sets_10,
-                                                               ground_truth_sets_100, query_results)
-        print("recall_1: ", recall_1)
-        print("recall_10: ", recall_10)
-        print("recall_100: ", recall_100)
-
-        print(">>> Query Benchmark End <<<")
-        qps = total_times / dur
-        print(f"Total Times: {total_times}")
-        print(f"Total Dur: {dur}")
-        print(f"QPS: {qps}")
diff --git a/python/hello_infinity.py b/python/hello_infinity.py
new file mode 100644
index 0000000000..69ba89d820
--- /dev/null
+++ b/python/hello_infinity.py
@@ -0,0 +1,37 @@
+# Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import infinity
+from infinity.common import REMOTE_HOST
+
+
+def main():
+    infinity_obj = infinity.connect(REMOTE_HOST)
+    db = infinity_obj.get_database("default")
+    db.drop_table("my_table", if_exists=True)
+    table = db.create_table(
+        "my_table", {"num": "integer", "body": "varchar", "vec": "vector,5,float"}, None)
+    table.insert(
+        [{"num": 1, "body": "undesirable, unnecessary, and harmful", "vec": [1.0] * 5}])
+    table.insert(
+        [{"num": 2, "body": "publisher=US National Office for Harmful Algal Blooms", "vec": [4.0] * 5}])
+    table.insert(
+        [{"num": 3, "body": "in the case of plants, growth and chemical", "vec": [7.0] * 5}])
+
+    res = table.output(["*"]).knn("vec", [3.0] * 5, "float", "ip", 2).to_pl()
+    print(res)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/python/infinity/__init__.py b/python/infinity/__init__.py
index 9a7b2e2238..433590fabe 100644
--- a/python/infinity/__init__.py
+++ b/python/infinity/__init__.py
@@ -14,7 +14,7 @@
 
 import importlib.metadata
 
-__version__ = importlib.metadata.version("infinity")
+__version__ = importlib.metadata.version("infinity_sdk")
 
 from infinity.common import URI, NetworkAddress, LOCAL_HOST
 from infinity.infinity import InfinityConnection
diff --git a/python/infinity/common.py b/python/infinity/common.py
index 6734f0b2fc..2ba70fc087 100644
--- a/python/infinity/common.py
+++ b/python/infinity/common.py
@@ -28,6 +28,7 @@ def __str__(self):
 
 URI = Union[NetworkAddress, Path]
 VEC = Union[list, np.ndarray]
+INSERT_DATA = dict[str, Union[str, int, float, list[Union[int, float]]]]
 
 REMOTE_HOST = NetworkAddress("127.0.0.1", 23817)
 LOCAL_HOST = NetworkAddress("0.0.0.0", 23817)
diff --git a/python/infinity/remote_thrift/db.py b/python/infinity/remote_thrift/db.py
index 8756ad5894..cb738b9743 100644
--- a/python/infinity/remote_thrift/db.py
+++ b/python/infinity/remote_thrift/db.py
@@ -125,9 +125,13 @@ def create_table(self, table_name: str, columns_definition: dict[str, str], opti
                 get_ordinary_info(
                     column_big_info, column_defs, column_name, index)
         # print(column_defs)
-        return self._conn.create_table(db_name=self._db_name, table_name=table_name,
-                                       column_defs=column_defs,
-                                       option=options)
+        res = self._conn.create_table(db_name=self._db_name, table_name=table_name,
+                                      column_defs=column_defs,
+                                      option=options)
+        if res.success is True:
+            return RemoteTable(self._conn, self._db_name, table_name)
+        else:
+            raise Exception(res.error_msg)
 
     def drop_table(self, table_name, if_exists=True):
         return self._conn.drop_table(db_name=self._db_name, table_name=table_name, if_exists=if_exists)
diff --git a/python/infinity/remote_thrift/infinity.py b/python/infinity/remote_thrift/infinity.py
index 03efe8d9f1..13049af8d5 100644
--- a/python/infinity/remote_thrift/infinity.py
+++ b/python/infinity/remote_thrift/infinity.py
@@ -31,7 +31,11 @@ def __del__(self):
         self.disconnect()
 
     def create_database(self, db_name: str, options=None):
-        return self._client.create_database(db_name=db_name)
+        res = self._client.create_database(db_name=db_name)
+        if res.success is True:
+            return RemoteDatabase(self._client, db_name)
+        else:
+            raise Exception(res.error_msg)
 
     def list_databases(self):
         return self._client.list_databases()
diff --git a/python/infinity/remote_thrift/table.py b/python/infinity/remote_thrift/table.py
index 2ac7fbaa72..e40115c448 100644
--- a/python/infinity/remote_thrift/table.py
+++ b/python/infinity/remote_thrift/table.py
@@ -13,11 +13,12 @@
 # limitations under the License.
 import os
 from abc import ABC
-from typing import Optional, Union, List, Any, Tuple, Dict
+from typing import Optional, Union, List, Any
 
 from sqlglot import condition
 
 import infinity.remote_thrift.infinity_thrift_rpc.ttypes as ttypes
+from infinity.common import INSERT_DATA, VEC
 from infinity.index import IndexInfo
 from infinity.remote_thrift.query_builder import Query, InfinityThriftQueryBuilder
 from infinity.remote_thrift.types import build_result
@@ -31,6 +32,7 @@ def __init__(self, conn, db_name, table_name):
         self._conn = conn
         self._db_name = db_name
         self._table_name = table_name
+        self.query_builder = InfinityThriftQueryBuilder(table=self)
 
     def create_index(self, index_name: str, index_infos: list[IndexInfo],
                      options=None):
@@ -56,12 +58,16 @@ def drop_index(self, index_name: str):
         return self._conn.drop_index(db_name=self._db_name, table_name=self._table_name,
                                      index_name=index_name)
 
-    def insert(self, data: list[dict[str, Union[str, int, float, list[Union[int, float]]]]]):
+    def insert(self, data: Union[INSERT_DATA, list[INSERT_DATA]]):
         # [{"c1": 1, "c2": 1.1}, {"c1": 2, "c2": 2.2}]
         db_name = self._db_name
         table_name = self._table_name
         column_names: list[str] = []
        fields: list[ttypes.Field] = []
+
+        if isinstance(data, dict):
+            data = [data]
+
         for row in data:
             column_names = list(row.keys())
             parse_exprs = []
@@ -191,10 +197,50 @@ def update(self, cond: Optional[str], data: Optional[list[dict[str, Union[str, i
 
         return self._conn.update(db_name=self._db_name, table_name=self._table_name, where_expr=where_expr,
                                  update_expr_array=update_expr_array)
-        pass
 
+    def knn(self, vector_column_name: str, embedding_data: VEC, embedding_data_type: str, distance_type: str,
+            topn: int):
+        self.query_builder.knn(vector_column_name, embedding_data, embedding_data_type, distance_type, topn)
+
+        return self
+
+    def match(self, vector_column_name: str, embedding_data: VEC, embedding_data_type: str, topn: int):
+        self.query_builder.match(vector_column_name, embedding_data, embedding_data_type, topn)
+
+        return self
+
+    def output(self, columns: Optional[List[str]]):
+        self.query_builder.output(columns)
+
+        return self
+
+    def filter(self, filter: Optional[str]):
+        self.query_builder.filter(filter)
+
+        return self
+
+    def limit(self, limit: Optional[int]):
+        self.query_builder.limit(limit)
+
+        return self
+
+    def offset(self, offset: Optional[int]):
+
+        self.query_builder.offset(offset)
+
+        return self
+
+    def to_result(self):
+        return self.query_builder.to_result()
+
+    def to_df(self):
+        return self.query_builder.to_df()
+
+    def to_pl(self):
+        return self.query_builder.to_pl()
+
+    def to_arrow(self):
+        return self.query_builder.to_arrow()
-    def query_builder(self) -> InfinityThriftQueryBuilder:
-        return InfinityThriftQueryBuilder(table=self)
 
     def _execute_query(self, query: Query) -> tuple[dict[str, list[Any]], dict[str, Any]]:
         # process select_list
diff --git a/python/infinity/table.py b/python/infinity/table.py
index 77b0d51252..ace85a6937 100644
--- a/python/infinity/table.py
+++ b/python/infinity/table.py
@@ -44,9 +44,6 @@ def delete(self, cond: Optional[str] = None):
     def update(self, cond: Optional[str], data: Optional[list[dict[str, Union[str, int, float]]]]):
         pass
 
-    @abstractmethod
-    def query_builder(self):
-        pass
 
     @abstractmethod
     def _execute_query(self, query):
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 88f037f674..844d380498 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
-name = "infinity"
-version = "0.0.1"
+name = "infinity_sdk"
+version = "0.1.0-dev2"
 dependencies = [
     "sqlglot==11.7.1",
     "pydantic",
diff --git a/python/run_all_test.py b/python/run_all_test.py
index fc3d761a28..f556d8c547 100644
--- a/python/run_all_test.py
+++ b/python/run_all_test.py
@@ -1,3 +1,17 @@
+# Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
diff --git a/python/setup.py b/python/setup.py
index 1abbd068c1..5f48f725df 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -1,3 +1,17 @@
+# Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import setuptools
 
 if __name__ == "__main__":
diff --git a/python/test/test.py b/python/test/test.py
index e7e6b6253f..b297f3ce78 100644
--- a/python/test/test.py
+++ b/python/test/test.py
@@ -50,7 +50,7 @@ def test_create_db_with_invalid_name(self):
         assert infinity_obj
 
         res = infinity_obj.create_database("")
-        assert res.success == False
+        assert not res.success
         assert res.error_msg
 
         assert infinity_obj.disconnect()
@@ -79,8 +79,8 @@ def test_infinity_thrift(self):
         assert infinity_obj
 
         # infinity
-        res = infinity_obj.create_database("my_db")
-        assert res.success
+        db_obj = infinity_obj.create_database("my_db")
+        assert db_obj is not None
 
         res = infinity_obj.list_databases()
         assert res.success
@@ -93,9 +93,9 @@ def test_infinity_thrift(self):
 
         db_obj = infinity_obj.get_database("default")
         db_obj.drop_table("my_table1", if_exists=True)
-        res = db_obj.create_table(
+        table_obj = db_obj.create_table(
             "my_table1", {"c1": "int, primary key"}, None)
-        assert res.success
+        assert table_obj is not None
 
         res = db_obj.list_tables()
         assert res.success
@@ -104,9 +104,10 @@ def test_infinity_thrift(self):
         assert res.success
 
         # index
-        res = db_obj.create_table(
+        db_obj.drop_table("my_table2", if_exists=True)
+        table_obj = db_obj.create_table(
             "my_table2", {"c1": "vector,1024,float"}, None)
-        assert res.success
+        assert table_obj is not None
 
         table_obj = db_obj.get_table("my_table2")
         assert table_obj
@@ -126,9 +127,9 @@ def test_infinity_thrift(self):
 
         # insert
         db_obj.drop_table("my_table3", if_exists=True)
-        res = db_obj.create_table(
+        table_obj = db_obj.create_table(
             "my_table3", {"c1": "int, primary key", "c2": "float"}, None)
-        assert res.success
+        assert table_obj is not None
 
         table_obj = db_obj.get_table("my_table3")
         assert table_obj
@@ -137,12 +138,12 @@ def test_infinity_thrift(self):
             [{"c1": 1, "c2": 1.1}, {"c1": 2, "c2": 2.2}])
         assert res.success
         # search
-        res = table_obj.query_builder().output(["c1 + 0.1"]).to_df()
+        res = table_obj.output(["c1 + 0.1"]).to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'(c1 + 0.100000)': (1.1, 2.1)}).astype(
             {'(c1 + 0.100000)': dtype('float64')}))
 
-        res = table_obj.query_builder().output(
+        res = table_obj.output(
             ["*"]).filter("c1 > 1").to_df()
 
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': (2,), 'c2': (2.2,)}).astype(
@@ -152,9 +153,10 @@ def test_infinity_thrift(self):
         assert res.success
 
         # import
-        res = db_obj.create_table(
+        db_obj.drop_table("my_table4", if_exists=True)
+        table_obj = db_obj.create_table(
             "my_table4", {"c1": "int", "c2": "vector,3,int"}, None)
-        assert res.success
+        assert table_obj is not None
 
         table_obj = db_obj.get_table("my_table4")
         assert table_obj
@@ -166,7 +168,7 @@ def test_infinity_thrift(self):
         assert res.success
 
         # search
-        res = table_obj.query_builder().output(
+        res = table_obj.output(
             ["c1"]).filter("c1 > 1").to_df()
         print(res)
         res = db_obj.drop_table("my_table4")
diff --git a/python/test/test_covert.py b/python/test/test_covert.py
index b4cc7af1ee..8f7b596613 100644
--- a/python/test/test_covert.py
+++ b/python/test/test_covert.py
@@ -13,11 +13,11 @@ def test_to_pl(self):
         table_obj = db_obj.get_table("test_to_pl")
         table_obj.insert([{"c1": 1, "c2": 2}])
         print()
-        res = table_obj.query_builder().output(["c1", "c2"]).to_pl()
+        res = table_obj.output(["c1", "c2"]).to_pl()
         print(res)
-        res = table_obj.query_builder().output(["c1", "c1"]).to_pl()
+        res = table_obj.output(["c1", "c1"]).to_pl()
         print(res)
-        res = table_obj.query_builder().output(["c1", "c2", "c1"]).to_pl()
+        res = table_obj.output(["c1", "c2", "c1"]).to_pl()
         print(res)
 
@@ -31,9 +31,9 @@ def test_to_pa(self):
         table_obj = db_obj.get_table("test_to_pa")
         table_obj.insert([{"c1": 1, "c2": 2.0}])
         print()
-        res = table_obj.query_builder().output(["c1", "c2"]).to_arrow()
+        res = table_obj.output(["c1", "c2"]).to_arrow()
         print(res)
-        res = table_obj.query_builder().output(["c1", "c1"]).to_arrow()
+        res = table_obj.output(["c1", "c1"]).to_arrow()
         print(res)
-        res = table_obj.query_builder().output(["c1", "c2", "c1"]).to_arrow()
+        res = table_obj.output(["c1", "c2", "c1"]).to_arrow()
         print(res)
diff --git a/python/test/test_create_table.py b/python/test/test_create_table.py
index c3ebd42aa5..98b95c87d3 100644
--- a/python/test/test_create_table.py
+++ b/python/test/test_create_table.py
@@ -27,10 +27,8 @@ def test_create_varchar_table(self):
         infinity_obj = infinity.connect(REMOTE_HOST)
         db_obj = infinity_obj.get_database("default")
         db_obj.drop_table("test_create_varchar_table", True)
-        res = db_obj.create_table("test_create_varchar_table", {
+        table_obj = db_obj.create_table("test_create_varchar_table", {
             "c1": "varchar, primary key", "c2": "float"}, None)
-        assert res.success
-        table_obj = db_obj.get_table("test_create_varchar_table")
         assert table_obj
 
         db_obj.drop_table("test_create_varchar_table")
@@ -44,10 +42,8 @@ def test_create_embedding_table(self):
         infinity_obj = infinity.connect(REMOTE_HOST)
         db_obj = infinity_obj.get_database("default")
         db_obj.drop_table("test_create_embedding_table", True)
-        res = db_obj.create_table("test_create_embedding_table", {
+        table_obj = db_obj.create_table("test_create_embedding_table", {
             "c1": "vector,128,float"}, None)
-        assert res.success
-        table_obj = db_obj.get_table("test_create_embedding_table")
         assert table_obj
 
         db_obj.drop_table("test_create_embedding_table")
diff --git a/python/test/test_database.py b/python/test/test_database.py
index cae8d8c38d..931f909ac1 100644
--- a/python/test/test_database.py
+++ b/python/test/test_database.py
@@ -43,23 +43,32 @@ def test_infinity_thrift(self):
         infinity_obj = infinity.connect(REMOTE_HOST)
 
         # infinity
-        res = infinity_obj.create_database("my_database")
-        assert res.success
-
-        res = infinity_obj.create_database("my_database!@#")
-        assert not res.success
-
-        res = infinity_obj.create_database("my-database-dash")
-        assert not res.success
-
-        res = infinity_obj.create_database("123_database")
-        assert not res.success
-
-        res = infinity_obj.create_database("")
-        assert not res.success
+        infinity_obj.drop_database("my_database", None)
+        db = infinity_obj.create_database("my_database")
+        assert db
+
+        try:
+            db = infinity_obj.create_database("my_database!@#")
+        except Exception as e:
+            print(e)
+
+        try:
+            db = infinity_obj.create_database("my-database-dash")
+        except Exception as e:
+            print(e)
+
+        try:
+            db = infinity_obj.create_database("123_database")
+        except Exception as e:
+            print(e)
+
+        try:
+            db = infinity_obj.create_database("")
+        except Exception as e:
+            print(e)
 
         res = infinity_obj.list_databases()
-        assert res.success
+        assert res is not None
 
         res.db_names.sort()
diff --git a/python/test/test_delete.py b/python/test/test_delete.py
index 60d9212292..b7dcb42080 100644
--- a/python/test/test_delete.py
+++ b/python/test/test_delete.py
@@ -67,14 +67,14 @@ def test_infinity_thrift(self):
         res = table_obj.delete("c1 = 1")
         assert res.success
 
-        res = table_obj.query_builder().output(["*"]).to_df()
+        res = table_obj.output(["*"]).to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': (2, 3, 4), 'c2': (20, 30, 40), 'c3': (200, 300, 400)})
                                       .astype({'c1': dtype('int32'), 'c2': dtype('int32'), 'c3': dtype('int32')}))
 
         res = table_obj.delete()
         assert res.success
 
-        res = table_obj.query_builder().output(["*"]).to_df()
+        res = table_obj.output(["*"]).to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': (), 'c2': (), 'c3': ()})
                                       .astype({'c1': dtype('int32'), 'c2': dtype('int32'), 'c3': dtype('int32')}))
diff --git a/python/test/test_import.py b/python/test/test_import.py
index 0502364023..671f9fba05 100644
--- a/python/test/test_import.py
+++ b/python/test/test_import.py
@@ -29,37 +29,29 @@ def test_import(self):
         method: connect server, create table, import data, search, drop table, disconnect
         expect: all operations successfully
         """
-        ports = [23817]
-        for port in ports:
-            infinity_obj = infinity.connect(REMOTE_HOST)
-            assert infinity_obj
-
-            # infinity
+        infinity_obj = infinity.connect(REMOTE_HOST)
+        assert infinity_obj
 
-            db_obj = infinity_obj.get_database("default")
-            assert db_obj
+        # infinity
 
-            # import
-            db_obj.drop_table("test_import", True)
-            res = db_obj.create_table(
-                "test_import", {"c1": "int", "c2": "vector,3,int"}, None)
-            table_obj = db_obj.get_table("test_import")
-            assert table_obj
+        db_obj = infinity_obj.get_database("default")
+        assert db_obj
 
-            test_dir = "/tmp/infinity/test_data/"
-            test_csv_dir = test_dir + "embedding_int_dim3.csv"
-            assert os.path.exists(test_csv_dir)
+        # import
+        db_obj.drop_table("test_import", True)
+        table_obj = db_obj.create_table(
+            "test_import", {"c1": "int", "c2": "vector,3,int"}, None)
 
-            res = table_obj.import_data(test_csv_dir, None)
-            assert res.success
+        test_dir = "/tmp/infinity/test_data/"
+        test_csv_dir = test_dir + "embedding_int_dim3.csv"
+        assert os.path.exists(test_csv_dir)
 
-            # search
-            res = table_obj.query_builder().output(
-                ["c1"]).filter("c1 > 1").to_df()
-            print(res)
-            res = db_obj.drop_table("test_import")
-            assert res.success
+        res = table_obj.import_data(test_csv_dir, None)
+        assert res.success
 
-            # disconnect
-            res = infinity_obj.disconnect()
-            assert res.success
+        # search
+        res = table_obj.output(["c1"]).filter("c1 > 1").to_df()
+        print(res)
+        res = db_obj.drop_table("test_import")
+        assert res.success
diff --git a/python/test/test_index.py b/python/test/test_index.py
index 7cf2380f95..ad57c02e6e 100644
--- a/python/test/test_index.py
+++ b/python/test/test_index.py
@@ -24,12 +24,9 @@ def test_create_index_IVFFlat(self):
         db_obj = infinity_obj.get_database("default")
         res = db_obj.drop_table("test_index_ivfflat", True)
         assert res.success
-        res = db_obj.create_table("test_index_ivfflat", {
+        table_obj = db_obj.create_table("test_index_ivfflat", {
            "c1": "vector,1024,float"}, None)
-        assert res.success
-
-        table_obj = db_obj.get_table("test_index_ivfflat")
-        assert table_obj
+        assert table_obj is not None
 
         res = table_obj.create_index("my_index",
                                      [index.IndexInfo("c1",
@@ -47,12 +44,9 @@ def test_create_index_HNSW(self):
         db_obj = infinity_obj.get_database("default")
         res = db_obj.drop_table("test_index_hnsw", True)
         assert res.success
-        res = db_obj.create_table(
+        table_obj = db_obj.create_table(
             "test_index_hnsw", {"c1": "vector,1024,float"}, None)
-        assert res.success
-
-        table_obj = db_obj.get_table("test_index_hnsw")
-        assert table_obj
+        assert table_obj is not None
 
         res = table_obj.create_index("my_index",
                                      [index.IndexInfo("c1",
@@ -75,12 +69,9 @@ def test_create_index_fulltext(self):
         db_obj = infinity_obj.get_database("default")
         res = db_obj.drop_table("test_index_fulltext", if_exists=True)
         assert res.success
-        res = db_obj.create_table(
+        table_obj = db_obj.create_table(
             "test_index_fulltext", {"doctitle": "varchar", "docdate": "varchar", "body": "varchar"}, None)
-        assert res.success
-
-        table_obj = db_obj.get_table("test_index_fulltext")
-        assert table_obj
+        assert table_obj is not None
 
         res = table_obj.create_index("my_index",
                                      [index.IndexInfo("body",
diff --git a/python/test/test_insert.py b/python/test/test_insert.py
index 637bf52393..da86f22b67 100644
--- a/python/test/test_insert.py
+++ b/python/test/test_insert.py
@@ -48,23 +48,24 @@ def test_insert_basic(self):
         db_obj.drop_table(table_name="table_2", if_exists=True)
 
         # infinity
-        res = db_obj.create_table(
+        table_obj = db_obj.create_table(
             "table_2", {"c1": "int, primary key, not null", "c2": "int, not null"}, None)
-        assert res.success
-
-        table_obj = db_obj.get_table("table_2")
+        assert table_obj is not None
 
         res = table_obj.insert([{"c1": 0, "c2": 0}])
         assert res.success
 
-        res = table_obj.insert([{"c1": 1, "c2": 2}])
+        res = table_obj.insert([{"c1": 1, "c2": 1}])
         assert res.success
 
-        res = table_obj.insert([{"c2": 1, "c1": 2}])
+        res = table_obj.insert({"c2": 2, "c1": 2})
         assert res.success
 
-        res = table_obj.query_builder().output(["*"]).to_df()
-        pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': (0, 1, 2), 'c2': (0, 2, 1)})
+        res = table_obj.insert([{"c2": 3, "c1": 3}, {"c1": 4, "c2": 4}])
+        assert res.success
+
+        res = table_obj.output(["*"]).to_df()
+        pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': (0, 1, 2, 3, 4), 'c2': (0, 1, 2, 3, 4)})
                                       .astype({'c1': dtype('int32'), 'c2': dtype('int32')}))
 
         res = db_obj.drop_table("table_2")
@@ -84,11 +85,10 @@ def test_insert_varchar(self):
         infinity_obj = infinity.connect(REMOTE_HOST)
         db_obj = infinity_obj.get_database("default")
         db_obj.drop_table("test_insert_varchar", True)
-        res = db_obj.create_table("test_insert_varchar", {
+        table_obj = db_obj.create_table("test_insert_varchar", {
             "c1": "varchar"}, None)
-        assert res.success
-        table_obj = db_obj.get_table("test_insert_varchar")
         assert table_obj
+
         res = table_obj.insert([{"c1": "test_insert_varchar"}])
         assert res.success
         res = table_obj.insert([{"c1": " test insert varchar "}])
@@ -96,7 +96,7 @@ def test_insert_varchar(self):
         res = table_obj.insert([{"c1": "^789$ test insert varchar"}])
         assert res.success
 
-        res = table_obj.query_builder().output(["*"]).to_df()
+        res = table_obj.output(["*"]).to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': ("test_insert_varchar", " test insert varchar ",
                                                                 "^789$ test insert varchar")}))
         db_obj.drop_table("test_insert_varchar")
@@ -110,16 +110,14 @@ def test_insert_big_varchar(self):
         infinity_obj = infinity.connect(REMOTE_HOST)
         db_obj = infinity_obj.get_database("default")
         db_obj.drop_table("test_insert_big_varchar", True)
-        res = db_obj.create_table("test_insert_big_varchar", {
+        table_obj = db_obj.create_table("test_insert_big_varchar", {
             "c1": "varchar"}, None)
-        assert res.success
-        table_obj = db_obj.get_table("test_insert_big_varchar")
         assert table_obj
 
         for i in range(100):
             res = table_obj.insert([{"c1": "test_insert_big_varchar" * 1000}])
             assert res.success
 
-        res = table_obj.query_builder().output(["*"]).to_df()
+        res = table_obj.output(["*"]).to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame(
             {'c1': ["test_insert_big_varchar" * 1000] * 100}))
 
@@ -134,10 +132,8 @@ def test_insert_embedding(self):
         infinity_obj = infinity.connect(REMOTE_HOST)
         db_obj = infinity_obj.get_database("default")
         db_obj.drop_table("test_insert_embedding", True)
-        res = db_obj.create_table("test_insert_embedding", {
+        table_obj = db_obj.create_table("test_insert_embedding", {
             "c1": "vector,3,int"}, None)
-        assert res.success
-        table_obj = db_obj.get_table("test_insert_embedding")
         assert table_obj
         res = table_obj.insert([{"c1": [1, 2, 3]}])
         assert res.success
@@ -147,13 +143,13 @@ def test_insert_embedding(self):
         assert res.success
         res = table_obj.insert([{"c1": [-7, -8, -9]}])
         assert res.success
-        res = table_obj.query_builder().output(["*"]).to_df()
+        res = table_obj.output(["*"]).to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame(
             {'c1': ([1, 2, 3], [4, 5, 6], [7, 8, 9], [-7, -8, -9])}))
         res = table_obj.insert([{"c1": [1, 2, 3]}, {"c1": [4, 5, 6]}, {
             "c1": [7, 8, 9]}, {"c1": [-7, -8, -9]}])
         assert res.success
-        res = table_obj.query_builder().output(["*"]).to_df()
+        res = table_obj.output(["*"]).to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': ([1, 2, 3], [4, 5, 6], [7, 8, 9], [-7, -8, -9],
                                                                 [1, 2, 3], [4, 5, 6], [7, 8, 9], [-7, -8, -9])}))
 
@@ -171,7 +167,7 @@ def test_insert_embedding(self):
         res = table_obj.insert([{"c1": [-7.7, -8.8, -9.9]}])
         assert res.success
 
-        res = table_obj.query_builder().output(["*"]).to_df()
+        res = table_obj.output(["*"]).to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame(
             {'c1': ([1.1, 2.2, 3.3], [4.4, 5.5, 6.6], [7.7, 8.8, 9.9], [-7.7, -8.8, -9.9])}))
 
@@ -186,9 +182,8 @@ def test_insert_big_embedding(self):
         infinity_obj = infinity.connect(REMOTE_HOST)
         db_obj = infinity_obj.get_database("default")
         db_obj.drop_table("test_insert_big_embedding", True)
-        res = db_obj.create_table("test_insert_big_embedding", {
+        table_obj = db_obj.create_table("test_insert_big_embedding", {
             "c1": "vector,65535,int"}, None)
-        table_obj = db_obj.get_table("test_insert_big_embedding")
         assert table_obj
         res = table_obj.insert([{"c1": [1] * 65535}])
         assert res.success
@@ -208,10 +203,8 @@ def test_insert_big_embedding_float(self):
         infinity_obj = infinity.connect(REMOTE_HOST)
         db_obj = infinity_obj.get_database("default")
         db_obj.drop_table("test_insert_big_embedding_float", True)
-        res = db_obj.create_table("test_insert_big_embedding_float", {
+        table_obj = db_obj.create_table("test_insert_big_embedding_float", {
             "c1": "vector,65535,float"}, None)
-        assert res.success
-        table_obj = db_obj.get_table("test_insert_big_embedding_float")
         assert table_obj
         res = table_obj.insert([{"c1": [1] * 65535}])
         assert res.success
diff --git a/python/test/test_select.py b/python/test/test_select.py
index 0cc2e52245..2539c4829f 100644
--- a/python/test/test_select.py
+++ b/python/test/test_select.py
@@ -81,11 +81,10 @@ def test_infinity_select(self):
 
         # infinity
         db_obj.drop_table("test_infinity_select", True)
-        res = db_obj.create_table(
+        table_obj = db_obj.create_table(
             "test_infinity_select", {"c1": "int, primary key, not null", "c2": "int, not null"}, None)
-        assert res.success
 
-        table_obj = db_obj.get_table("test_infinity_select")
+        assert table_obj is not None
 
         res = table_obj.insert(
             [{"c1": -3, "c2": -3}, {"c1": -2, "c2": -2}, {"c1": -1, "c2": -1}, {"c1": 0, "c2": 0}, {"c1": 1, "c2": 1},
@@ -97,53 +96,53 @@ def test_infinity_select(self):
              {"c1": 9, "c2": 9}])
         assert res.success
 
-        res = table_obj.query_builder().output(["*"]).to_df()
+        res = table_obj.output(["*"]).to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': (-3, -2, -1, 0, 1, 2, 3, -8, -7, -6, 7, 8, 9),
                                                          'c2': (-3, -2, -1, 0, 1, 2, 3, -8, -7, -6, 7, 8, 9)})
                                       .astype({'c1': dtype('int32'), 'c2': dtype('int32')}))
 
-        res = table_obj.query_builder().output(["c1", "c2"]).to_df()
+        res = table_obj.output(["c1", "c2"]).to_df()
 
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': (-3, -2, -1, 0, 1, 2, 3, -8, -7, -6, 7, 8, 9),
                                                          'c2': (-3, -2, -1, 0, 1, 2, 3, -8, -7, -6, 7, 8, 9)})
                                       .astype({'c1': dtype('int32'), 'c2': dtype('int32')}))
 
-        res = table_obj.query_builder().output(
+        res = table_obj.output(
             ["c1 + c2"]).filter("c1 = 3").to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'(c1 + c2)': (6,)})
                                       .astype({'(c1 + c2)': dtype('int32')}))
 
-        res = table_obj.query_builder().output(
+        res = table_obj.output(
            ["c1"]).filter("c1 > 2 and c2 < 4").to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': (3,)})
                                       .astype({'c1': dtype('int32')}))
 
-        res = table_obj.query_builder().output(["c2"]).filter(
+        res = table_obj.output(["c2"]).filter(
             "(-7 < c1 or 9 <= c1) and (c1 = 3)").to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c2': (3,)})
                                       .astype({'c2': dtype('int32')}))
 
-        res = table_obj.query_builder().output(["c2"]).filter(
+        res = table_obj.output(["c2"]).filter(
            "(-8 < c1 and c1 <= -7) or (c1 >= 1 and 2 > c1)").to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c2': (1, -7)})
                                       .astype({'c2': dtype('int32')}))
 
-        res = table_obj.query_builder().output(["c2"]).filter(
+        res = table_obj.output(["c2"]).filter(
             "((c1 >= -8 and -4 >= c1) or (c1 >= 0 and 5 > c1)) and ((c1 > 0 and c1 <= 1) or (c1 > -8 and c1 < -6))").to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c2': (1, -7)})
                                       .astype({'c2': dtype('int32')}))
 
-        res = table_obj.query_builder().output(["c2"]).filter(
+        res = table_obj.output(["c2"]).filter(
             "(-7 < c1 or 9 <= c1) and (c2 = 3)").to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c2': (3,)})
                                       .astype({'c2': dtype('int32')}))
 
-        res = table_obj.query_builder().output(["c2"]).filter(
+        res = table_obj.output(["c2"]).filter(
            "(-8 < c1 and c2 <= -7) or (c1 >= 1 and 2 > c2)").to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c2': (1, -7)})
                                       .astype({'c2': dtype('int32')}))
 
         # Need fix rbo caused it Planner Error: Indices must be in order @src/planner/bound/base_table_ref.cppm:45
-        # res = table_obj.query_builder().output(["c2"]).filter(
+        # res = table_obj.output(["c2"]).filter(
         #     "((c2 >= -8 and -4 >= c1) or (c1 >= 0 and 5 > c2)) and ((c2 > 0 and c1 <= 1) or (c1 > -8 and c2 < -6))").to_df()
         # pd.testing.assert_frame_equal(res, pd.DataFrame({'c2': (1, -7)})
         #                               .astype({'c2': dtype('int32')}))
@@ -212,20 +211,20 @@ def test_select_varchar(self):
                           {"c1": 'i', "c2": 'i'}, {"c1": 'j', "c2": 'j'}, {
                               "c1": 'k', "c2": 'k'}, {"c1": 'l', "c2": 'l'},
                           {"c1": 'm', "c2": 'm'}])
-        res = table_obj.query_builder().output(["*"]).to_df()
+        res = table_obj.output(["*"]).to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
                                                                 'i', 'j', 'k', 'l', 'm'),
                                                          'c2': ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
                                                                 'i', 'j', 'k', 'l', 'm')})
                                       .astype({'c1': dtype('O'), 'c2': dtype('O')}))
 
-        res = table_obj.query_builder().output(
+        res = table_obj.output(
             ["c1", "c2"]).filter("c1 = 'a'").to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': ('a',), 'c2': ('a',)}).astype(
             {'c1': dtype('O'), 'c2': dtype('O')}))
 
         # TODO NotImplement Error: Not implement: varchar > varchar
-        # res = table_obj.query_builder().output(["c1"]).filter("c1 > 'a' and c2 < 'c'").to_df()
+        # res = table_obj.output(["c1"]).filter("c1 > 'a' and c2 < 'c'").to_df()
         # pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': ('b',)}).astype({'c1': dtype('O')}))
 
         res = db_obj.drop_table("test_select_varchar")
@@ -277,12 +276,12 @@ def test_select_embedding_int32(self):
         res = table_obj.import_data(test_csv_dir, None)
         assert res.success
 
-        res = table_obj.query_builder().output(["c2"]).to_df()
+        res = table_obj.output(["c2"]).to_df()
         print(res)
         pd.testing.assert_frame_equal(res, pd.DataFrame(
             {'c2': ([2, 3, 4], [6, 7, 8], [10, 11, 12])}))
 
-        res = table_obj.query_builder().output(["*"]).to_df()
+        res = table_obj.output(["*"]).to_df()
         print(res)
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': (1, 5, 9), 'c2': ([2, 3, 4], [6, 7, 8], [10, 11, 12])})
                                       .astype({'c1': dtype('int32'), 'c2': dtype('O')}))
@@ -320,13 +319,13 @@ def test_select_embedding_float(self):
         res = table_obj.import_data(test_csv_dir, None)
         assert res.success
 
-        res = table_obj.query_builder().output(["c2"]).to_df()
+        res = table_obj.output(["c2"]).to_df()
         print(res)
         pd.testing.assert_frame_equal(res, pd.DataFrame(
             {'c2': ([0.1, 0.2, 0.3, -0.2], [0.2, 0.1, 0.3, 0.4],
                     [0.3, 0.2, 0.1, 0.4], [0.4, 0.3, 0.2, 0.1])}))
 
-        res = table_obj.query_builder().output(["*"]).to_df()
+        res = table_obj.output(["*"]).to_df()
         print(res)
 
         pd.testing.assert_frame_equal(res,
@@ -384,11 +383,11 @@ def test_select_same_output(self):
         table_obj = db_obj.get_table("test_select_same_output")
         table_obj.insert([{"c1": 1, "c2": 2}])
         print()
-        res = table_obj.query_builder().output(["c1", "c2"]).to_df()
+        res = table_obj.output(["c1", "c2"]).to_df()
         print(res)
-        res = table_obj.query_builder().output(["c1", "c1"]).to_df()
+        res = table_obj.output(["c1", "c1"]).to_df()
         print(res)
-        res = table_obj.query_builder().output(["c1", "c2", "c1"]).to_df()
+        res = table_obj.output(["c1", "c2", "c1"]).to_df()
         print(res)
 
     def test_empty_table(self):
@@ -400,12 +399,12 @@ def test_empty_table(self):
         table_obj = db_obj.get_table("test_empty_table")
 
         print()
-        res = table_obj.query_builder().output(["c1", "c2"]).to_df()
+        res = table_obj.output(["c1", "c2"]).to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': (), 'c2': ()}).astype(
             {'c1': dtype('int32'), 'c2': dtype('int32')}))
-        res = table_obj.query_builder().output(["c1", "c1"]).to_df()
+        res = table_obj.output(["c1", "c1"]).to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': (), 'c1_2': ()}).astype(
             {'c1': dtype('int32'), 'c1_2': dtype('int32')}))
-        res = table_obj.query_builder().output(["c1", "c2", "c1"]).to_df()
+        res = table_obj.output(["c1", "c2", "c1"]).to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': (), 'c2': (), 'c1_2': ()}).astype(
             {'c1': dtype('int32'), 'c2': dtype('int32'), 'c1_2': dtype('int32')}))
diff --git a/python/test/test_table.py b/python/test/test_table.py
index 48e3f50c19..4515c30a0c 100644
--- a/python/test/test_table.py
+++ b/python/test/test_table.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import infinity
-from infinity.common import NetworkAddress, REMOTE_HOST
+from infinity.common import REMOTE_HOST
 
 
 class TestTable:
@@ -56,29 +56,35 @@ def test_infinity_thrift(self):
         db_obj.drop_table("my_table")
 
         # infinity
-        res = db_obj.create_table(
+        tb = db_obj.create_table(
             "my_table", {"c1": "int, primary key", "c2": "float"}, None)
-        assert res.success
-
-        res = db_obj.create_table(
-            "my_table!@#", {"c1": "int, primary key", "c2": "float"}, None)
-        assert not res.success
-
-        res = db_obj.create_table(
-            "my-table-dash", {"c1": "float, primary key", "c2": "int"}, None)
-        assert not res.success
-
-        res = db_obj.create_table(
-            "123_table", {"c1": "int, primary key", "c1": "float"}, None)
-        assert not res.success
-
-        res = db_obj.create_table(
-            "bad_column", {"123": "int, primary key", "c2": "float"}, None)
-        assert not res.success
-
-        res = db_obj.create_table(
-            "", {"c1": "int, primary key", "c2": "float"}, None)
-        assert not res.success
+        assert tb is not None
+
+        try:
+            tb = db_obj.create_table(
+                "my_table!@#", {"c1": "int, primary key", "c2": "float"}, None)
+        except Exception as e:
+            print(e)
+        try:
+            tb = db_obj.create_table(
+                "my-table-dash", {"c1": "float, primary key", "c2": "int"}, None)
+        except Exception as e:
+            print(e)
+        try:
+            tb = db_obj.create_table(
+                "123_table", {"c1": "int, primary key", "c1": "float"}, None)
+        except Exception as e:
+            print(e)
+        try:
+            tb = db_obj.create_table(
+                "bad_column", {"123": "int, primary key", "c2": "float"}, None)
+        except Exception as e:
+            print(e)
+        try:
+            tb = db_obj.create_table(
+                "", {"c1": "int, primary key", "c2": "float"}, None)
+        except Exception as e:
+            print(e)
 
         # FIXME: res = db_obj.describe_table("my_table")
diff --git a/python/test/test_update.py b/python/test/test_update.py
index c896a1ae71..03be4d5c33 100644
--- a/python/test/test_update.py
+++ b/python/test/test_update.py
@@ -60,11 +60,9 @@ def test_infinity_thrift(self):
         db_obj.drop_table(table_name="table_4", if_exists=True)
 
         # infinity
-        res = db_obj.create_table(
+        table_obj = db_obj.create_table(
             "table_4", {"c1": "int, primary key, not null", "c2": "int", "c3": "int"}, None)
-        assert res.success
-
-        table_obj = db_obj.get_table("table_4")
+        assert table_obj is not None
 
         res = table_obj.insert(
             [{"c1": 1, "c2": 10, "c3": 100}, {"c1": 2, "c2": 20, "c3": 200}, {"c1": 3, "c2": 30, "c3": 300},
@@ -74,7 +72,7 @@ def test_infinity_thrift(self):
         res = table_obj.update("c1 = 1", [{"c2": 90, "c3": 900}])
         assert res.success
 
-        res = table_obj.query_builder().output(["*"]).to_df()
+        res = table_obj.output(["*"]).to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame(
             {'c1': (2, 3, 4, 1), 'c2': (20, 30, 40, 90), 'c3': (200, 300, 400, 900)})
                                       .astype({'c1': dtype('int32'), 'c2': dtype('int32'), 'c3': dtype('int32')}))
@@ -82,7 +80,7 @@ def test_infinity_thrift(self):
         res = table_obj.update(None, [{"c2": 80, "c3": 800}])
         assert res.success is False
 
-        res = table_obj.query_builder().output(["*"]).to_df()
+        res = table_obj.output(["*"]).to_df()
         pd.testing.assert_frame_equal(res, pd.DataFrame(
             {'c1': (2, 3, 4, 1), 'c2': (20, 30, 40, 90), 'c3': (200, 300, 400, 900)})
                                       .astype({'c1': dtype('int32'), 'c2': dtype('int32'), 'c3': dtype('int32')}))
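
---

A quick recap of the API shape this diff introduces, for reviewing the test changes above: `create_database` and `create_table` now return `RemoteDatabase`/`RemoteTable` handles and raise on failure, `insert` accepts a single dict as well as a list of dicts, and the former `table.query_builder()` indirection is replaced by chaining `output`/`filter`/`knn`/`limit`/`offset` directly off the table handle. Below is a minimal sketch, assuming a server is listening at the default `REMOTE_HOST` (127.0.0.1:23817); the table name `demo` is illustrative, not taken from the diff.

```python
import infinity
from infinity.common import REMOTE_HOST

infinity_obj = infinity.connect(REMOTE_HOST)
db = infinity_obj.get_database("default")

# create_table now returns a RemoteTable handle (or raises) instead of a status result
db.drop_table("demo", if_exists=True)
table = db.create_table("demo", {"num": "integer", "vec": "vector,5,float"}, None)

# insert now also accepts a single dict, not only a list of dicts
table.insert({"num": 1, "vec": [1.0] * 5})
table.insert([{"num": 2, "vec": [4.0] * 5}, {"num": 3, "vec": [7.0] * 5}])

# query-builder methods chain directly off the table handle,
# mirroring the patterns used in the updated tests and hello_infinity.py
print(table.output(["num"]).filter("num > 1").to_df())
print(table.output(["*"]).knn("vec", [3.0] * 5, "float", "ip", 2).to_pl())
```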