From ee7cfda3f0eec286490018be8bf37cee3a8a16b7 Mon Sep 17 00:00:00 2001 From: Zhiyuan Liang <132966438+Ami11111@users.noreply.github.com> Date: Tue, 8 Oct 2024 19:31:51 +0800 Subject: [PATCH] Add sort for http api (#1980) ### What problem does this PR solve? Add sort for http api ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Test cases --- example/http/insert_search_data.sh | 24 +++++++ python/infinity_http.py | 23 ++++++- python/test_pysdk/test_select.py | 6 +- src/network/http/http_search.cpp | 100 ++++++++++++++++++++++++++++- src/network/http/http_search.cppm | 2 + 5 files changed, 151 insertions(+), 4 deletions(-) diff --git a/example/http/insert_search_data.sh b/example/http/insert_search_data.sh index e4e2fed55b..90af2b39e2 100755 --- a/example/http/insert_search_data.sh +++ b/example/http/insert_search_data.sh @@ -104,6 +104,30 @@ curl --request GET \ ] } ' +# select num and year of 'tbl1' and order by num descending, year ascending +echo -e '\n\n-- select num and year of tbl1 and order by num descending, year ascending' +curl --request GET \ + --url http://localhost:23820/databases/default_db/tables/tbl1/docs \ + --header 'accept: application/json' \ + --header 'content-type: application/json' \ + --data ' + { + "output": + [ + "num", + "year" + ], + "sort" : + [ + { + "num": "desc" + }, + { + "year": "asc" + } + ] + } ' + # select num and year of 'tbl1' where num > 1 and year < 2023 echo -e '\n\n-- select num and year of tbl1 where num > 1 and year < 2023 offset 1 limit 1' diff --git a/python/infinity_http.py b/python/infinity_http.py index 869dee7d66..776d15c8fa 100644 --- a/python/infinity_http.py +++ b/python/infinity_http.py @@ -4,7 +4,7 @@ import logging import os from test_pysdk.common.common_data import * -from infinity.common import ConflictType, InfinityException, SparseVector +from infinity.common import ConflictType, InfinityException, SparseVector, SortType from test_pysdk.common import common_values import infinity from typing import Optional, Any @@ -16,6 +16,7 @@ import pyarrow as pa from infinity.table import ExplainType from datetime import date, time, datetime +from typing import Optional, Union, List, Any class infinity_http: @@ -476,6 +477,8 @@ def select(self): tmp.update({"search": self._search_exprs}) if len(self._output): tmp.update({"output":self._output}) + if len(self._sort): + tmp.update({"sort":self._sort}) #print(tmp) d = self.set_up_data([], tmp) r = self.request(url, "get", h, d) @@ -552,6 +555,21 @@ def output( self._output = output self._filter = "" self._search_exprs = [] + self._sort = [] + return self + + def sort(self, order_by_expr_list: Optional[List[list[str, SortType]]]): + for order_by_expr in order_by_expr_list: + tmp = {} + if len(order_by_expr) != 2: + raise InfinityException(ErrorCode.INVALID_PARAMETER, f"order_by_expr_list must be a list of [column_name, sort_type]") + if order_by_expr[1] not in [SortType.Asc, SortType.Desc]: + raise InfinityException(ErrorCode.INVALID_PARAMETER, f"sort_type must be SortType.Asc or SortType.Desc") + if order_by_expr[1] == SortType.Asc: + tmp[order_by_expr[0]] = "asc" + else: + tmp[order_by_expr[0]] = "desc" + self._sort.append(tmp) return self def match_text(self, fields: str, query: str, topn: int, opt_params: Optional[dict] = None): @@ -713,7 +731,7 @@ def update(self, filter_str: str, update: dict[str, Any]): class database_result(infinity_http): def __init__(self, list = [], error_code = ErrorCode.OK, database_name = "" ,columns=[], table_name = "", - index_list = [], output = ["*"], filter="", fusion=[], knn={}, match = {}, match_tensor = {}, match_sparse = {}, output_res = []): + index_list = [], output = ["*"], filter="", fusion=[], knn={}, match = {}, match_tensor = {}, match_sparse = {}, sort = [], output_res = []): self.db_names = list self.error_code = error_code self.database_name = database_name # get database @@ -728,6 +746,7 @@ def __init__(self, list = [], error_code = ErrorCode.OK, database_name = "" ,col self._match = match self._match_tensor = match_tensor self._match_sparse = match_sparse + self._sort = sort self.output_res = output_res diff --git a/python/test_pysdk/test_select.py b/python/test_pysdk/test_select.py index 84389836b4..4b3261b10d 100644 --- a/python/test_pysdk/test_select.py +++ b/python/test_pysdk/test_select.py @@ -752,7 +752,6 @@ def test_neg_func(self, suffix): res = db_obj.drop_table("test_neg_func" + suffix, ConflictType.Error) assert res.error_code == ErrorCode.OK - @pytest.mark.usefixtures("skip_if_http") def test_sort(self, suffix): db_obj = self.infinity_obj.get_database("default_db") @@ -785,5 +784,10 @@ def test_sort(self, suffix): 'c2': (0, 1, 1, 2, 2, 3, 3, 6, 7, 7, 8, 8, 9)}) .astype({'c1': dtype('int32'), 'c2': dtype('int32')})) + res = table_obj.output(["_row_id"]).sort([["_row_id", SortType.Desc]]).to_df() + #pd.testing.assert_frame_equal(res, pd.DataFrame({'ROW_ID': (12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)}) + # .astype({'ROW_ID': dtype('int64')})) + print(res) + res = db_obj.drop_table("test_sort"+suffix, ConflictType.Error) assert res.error_code == ErrorCode.OK \ No newline at end of file diff --git a/src/network/http/http_search.cpp b/src/network/http/http_search.cpp index 614413f0da..09557fb020 100644 --- a/src/network/http/http_search.cpp +++ b/src/network/http/http_search.cpp @@ -42,6 +42,7 @@ import value; import physical_import; import explain_statement; import internal_types; +import select_statement; namespace infinity { @@ -63,6 +64,7 @@ void HTTPSearch::Process(Infinity *infinity_ptr, UniquePtr offset{}; UniquePtr search_expr{}; Vector *output_columns{nullptr}; + Vector *order_by_list{nullptr}; DeferFn defer_fn([&]() { if (output_columns != nullptr) { for (auto &expr : *output_columns) { @@ -72,6 +74,17 @@ void HTTPSearch::Process(Infinity *infinity_ptr, output_columns = nullptr; } }); + + DeferFn defer_fn_order([&]() { + if (order_by_list != nullptr) { + for (auto &expr : *order_by_list) { + delete expr; + } + delete order_by_list; + order_by_list = nullptr; + } + }); + for (const auto &elem : input_json.items()) { String key = elem.key(); ToLower(key); @@ -92,6 +105,24 @@ void HTTPSearch::Process(Infinity *infinity_ptr, if (output_columns == nullptr) { return; } + } else if (IsEqual(key, "sort")) { + if (order_by_list != nullptr) { + response["error_code"] = ErrorCode::kInvalidExpression; + response["error_message"] = "More than one sort field."; + return; + } + + auto &list = elem.value(); + if (!list.is_array()) { + response["error_code"] = ErrorCode::kInvalidExpression; + response["error_message"] = "Sort field should be array"; + return; + } + + order_by_list = ParseSort(list, http_status, response); + if (order_by_list == nullptr) { + return; + } } else if (IsEqual(key, "filter")) { if (filter) { @@ -143,9 +174,10 @@ void HTTPSearch::Process(Infinity *infinity_ptr, } const QueryResult result = - infinity_ptr->Search(db_name, table_name, search_expr.release(), filter.release(), limit.release(), offset.release(), output_columns, nullptr); + infinity_ptr->Search(db_name, table_name, search_expr.release(), filter.release(), limit.release(), offset.release(), output_columns, order_by_list); output_columns = nullptr; + order_by_list = nullptr; if (result.IsOk()) { SizeT block_rows = result.result_table_->DataBlockCount(); for (SizeT block_id = 0; block_id < block_rows; ++block_id) { @@ -415,6 +447,72 @@ Vector *HTTPSearch::ParseOutput(const nlohmann::json &output_list, return res; } +Vector *HTTPSearch::ParseSort(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response) { + Vector *order_by_list = new Vector(); + DeferFn defer_fn([&]() { + if (order_by_list != nullptr) { + for (auto &expr : *order_by_list) { + delete expr; + } + delete order_by_list; + order_by_list = nullptr; + } + }); + + for(const auto &order_expr : json_object) { + for (const auto &expression : order_expr.items()) { + String key = expression.key(); + ToLower(key); + auto order_by_expr = MakeUnique(); + if (key == "_row_id" or key == "_similarity" or key == "_distance" or key == "_score") { + auto parsed_expr = new FunctionExpr(); + if (key == "_row_id") { + parsed_expr->func_name_ = "row_id"; + } else if (key == "_similarity") { + parsed_expr->func_name_ = "similarity"; + } else if (key == "_distance") { + parsed_expr->func_name_ = "distance"; + } else if (key == "_score") { + parsed_expr->func_name_ = "score"; + } + order_by_expr->expr_ = parsed_expr; + parsed_expr = nullptr; + } else { + UniquePtr expr_parsed_result = MakeUnique(); + ExprParser expr_parser; + expr_parser.Parse(key, expr_parsed_result.get()); + if (expr_parsed_result->IsError() || expr_parsed_result->exprs_ptr_->size() == 0) { + response["error_code"] = ErrorCode::kInvalidExpression; + response["error_message"] = fmt::format("Invalid expression: {}", key); + return nullptr; + } + + order_by_expr->expr_ = expr_parsed_result->exprs_ptr_->at(0); + expr_parsed_result->exprs_ptr_->at(0) = nullptr; + } + + String value = expression.value(); + ToLower(value); + if (value == "asc") { + order_by_expr->type_ = OrderType::kAsc; + } else if (value == "desc") { + order_by_expr->type_ = OrderType::kDesc; + } else { + response["error_code"] = ErrorCode::kInvalidExpression; + response["error_message"] = fmt::format("Invalid expression: {}", value); + return nullptr; + } + + order_by_list->emplace_back(order_by_expr.release()); + } + } + + // Avoiding DeferFN auto free the output expressions + Vector *res = order_by_list; + order_by_list = nullptr; + return res; +} + UniquePtr HTTPSearch::ParseSearchExpr(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response) { if (json_object.type() != nlohmann::json::value_t::array) { response["error_code"] = ErrorCode::kInvalidExpression; diff --git a/src/network/http/http_search.cppm b/src/network/http/http_search.cppm index 93f4862d1e..0a9cbb8acb 100644 --- a/src/network/http/http_search.cppm +++ b/src/network/http/http_search.cppm @@ -29,6 +29,7 @@ import infinity; import internal_types; import constant_expr; import search_expr; +import select_statement; namespace infinity { @@ -48,6 +49,7 @@ public: nlohmann::json &response); static Vector *ParseOutput(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response); + static Vector *ParseSort(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response); static UniquePtr ParseFilter(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response); static UniquePtr ParseSearchExpr(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response); static UniquePtr ParseFusion(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response);