Skip to content

Commit

Permalink
Add sort for http api (#1980)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?
Add sort for http api

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Test cases
  • Loading branch information
Ami11111 authored Oct 8, 2024
1 parent 4bc9374 commit ee7cfda
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 4 deletions.
24 changes: 24 additions & 0 deletions example/http/insert_search_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,30 @@ curl --request GET \
]
} '

# select num and year of 'tbl1' and order by num descending, year ascending
echo -e '\n\n-- select num and year of tbl1 and order by num descending, year ascending'
curl --request GET \
--url http://localhost:23820/databases/default_db/tables/tbl1/docs \
--header 'accept: application/json' \
--header 'content-type: application/json' \
--data '
{
"output":
[
"num",
"year"
],
"sort" :
[
{
"num": "desc"
},
{
"year": "asc"
}
]
} '


# select num and year of 'tbl1' where num > 1 and year < 2023
echo -e '\n\n-- select num and year of tbl1 where num > 1 and year < 2023 offset 1 limit 1'
Expand Down
23 changes: 21 additions & 2 deletions python/infinity_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import os
from test_pysdk.common.common_data import *
from infinity.common import ConflictType, InfinityException, SparseVector
from infinity.common import ConflictType, InfinityException, SparseVector, SortType
from test_pysdk.common import common_values
import infinity
from typing import Optional, Any
Expand All @@ -16,6 +16,7 @@
import pyarrow as pa
from infinity.table import ExplainType
from datetime import date, time, datetime
from typing import Optional, Union, List, Any


class infinity_http:
Expand Down Expand Up @@ -476,6 +477,8 @@ def select(self):
tmp.update({"search": self._search_exprs})
if len(self._output):
tmp.update({"output":self._output})
if len(self._sort):
tmp.update({"sort":self._sort})
#print(tmp)
d = self.set_up_data([], tmp)
r = self.request(url, "get", h, d)
Expand Down Expand Up @@ -552,6 +555,21 @@ def output(
self._output = output
self._filter = ""
self._search_exprs = []
self._sort = []
return self

def sort(self, order_by_expr_list: Optional[List[list[str, SortType]]]):
for order_by_expr in order_by_expr_list:
tmp = {}
if len(order_by_expr) != 2:
raise InfinityException(ErrorCode.INVALID_PARAMETER, f"order_by_expr_list must be a list of [column_name, sort_type]")
if order_by_expr[1] not in [SortType.Asc, SortType.Desc]:
raise InfinityException(ErrorCode.INVALID_PARAMETER, f"sort_type must be SortType.Asc or SortType.Desc")
if order_by_expr[1] == SortType.Asc:
tmp[order_by_expr[0]] = "asc"
else:
tmp[order_by_expr[0]] = "desc"
self._sort.append(tmp)
return self

def match_text(self, fields: str, query: str, topn: int, opt_params: Optional[dict] = None):
Expand Down Expand Up @@ -713,7 +731,7 @@ def update(self, filter_str: str, update: dict[str, Any]):

class database_result(infinity_http):
def __init__(self, list = [], error_code = ErrorCode.OK, database_name = "" ,columns=[], table_name = "",
index_list = [], output = ["*"], filter="", fusion=[], knn={}, match = {}, match_tensor = {}, match_sparse = {}, output_res = []):
index_list = [], output = ["*"], filter="", fusion=[], knn={}, match = {}, match_tensor = {}, match_sparse = {}, sort = [], output_res = []):
self.db_names = list
self.error_code = error_code
self.database_name = database_name # get database
Expand All @@ -728,6 +746,7 @@ def __init__(self, list = [], error_code = ErrorCode.OK, database_name = "" ,col
self._match = match
self._match_tensor = match_tensor
self._match_sparse = match_sparse
self._sort = sort
self.output_res = output_res


Expand Down
6 changes: 5 additions & 1 deletion python/test_pysdk/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,6 @@ def test_neg_func(self, suffix):
res = db_obj.drop_table("test_neg_func" + suffix, ConflictType.Error)
assert res.error_code == ErrorCode.OK

@pytest.mark.usefixtures("skip_if_http")
def test_sort(self, suffix):
db_obj = self.infinity_obj.get_database("default_db")

Expand Down Expand Up @@ -785,5 +784,10 @@ def test_sort(self, suffix):
'c2': (0, 1, 1, 2, 2, 3, 3, 6, 7, 7, 8, 8, 9)})
.astype({'c1': dtype('int32'), 'c2': dtype('int32')}))

res = table_obj.output(["_row_id"]).sort([["_row_id", SortType.Desc]]).to_df()
#pd.testing.assert_frame_equal(res, pd.DataFrame({'ROW_ID': (12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)})
# .astype({'ROW_ID': dtype('int64')}))
print(res)

res = db_obj.drop_table("test_sort"+suffix, ConflictType.Error)
assert res.error_code == ErrorCode.OK
100 changes: 99 additions & 1 deletion src/network/http/http_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import value;
import physical_import;
import explain_statement;
import internal_types;
import select_statement;

namespace infinity {

Expand All @@ -63,6 +64,7 @@ void HTTPSearch::Process(Infinity *infinity_ptr,
UniquePtr<ParsedExpr> offset{};
UniquePtr<SearchExpr> search_expr{};
Vector<ParsedExpr *> *output_columns{nullptr};
Vector<OrderByExpr *> *order_by_list{nullptr};
DeferFn defer_fn([&]() {
if (output_columns != nullptr) {
for (auto &expr : *output_columns) {
Expand All @@ -72,6 +74,17 @@ void HTTPSearch::Process(Infinity *infinity_ptr,
output_columns = nullptr;
}
});

DeferFn defer_fn_order([&]() {
if (order_by_list != nullptr) {
for (auto &expr : *order_by_list) {
delete expr;
}
delete order_by_list;
order_by_list = nullptr;
}
});

for (const auto &elem : input_json.items()) {
String key = elem.key();
ToLower(key);
Expand All @@ -92,6 +105,24 @@ void HTTPSearch::Process(Infinity *infinity_ptr,
if (output_columns == nullptr) {
return;
}
} else if (IsEqual(key, "sort")) {
if (order_by_list != nullptr) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] = "More than one sort field.";
return;
}

auto &list = elem.value();
if (!list.is_array()) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] = "Sort field should be array";
return;
}

order_by_list = ParseSort(list, http_status, response);
if (order_by_list == nullptr) {
return;
}
} else if (IsEqual(key, "filter")) {

if (filter) {
Expand Down Expand Up @@ -143,9 +174,10 @@ void HTTPSearch::Process(Infinity *infinity_ptr,
}

const QueryResult result =
infinity_ptr->Search(db_name, table_name, search_expr.release(), filter.release(), limit.release(), offset.release(), output_columns, nullptr);
infinity_ptr->Search(db_name, table_name, search_expr.release(), filter.release(), limit.release(), offset.release(), output_columns, order_by_list);

output_columns = nullptr;
order_by_list = nullptr;
if (result.IsOk()) {
SizeT block_rows = result.result_table_->DataBlockCount();
for (SizeT block_id = 0; block_id < block_rows; ++block_id) {
Expand Down Expand Up @@ -415,6 +447,72 @@ Vector<ParsedExpr *> *HTTPSearch::ParseOutput(const nlohmann::json &output_list,
return res;
}

Vector<OrderByExpr *> *HTTPSearch::ParseSort(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response) {
Vector<OrderByExpr *> *order_by_list = new Vector<OrderByExpr *>();
DeferFn defer_fn([&]() {
if (order_by_list != nullptr) {
for (auto &expr : *order_by_list) {
delete expr;
}
delete order_by_list;
order_by_list = nullptr;
}
});

for(const auto &order_expr : json_object) {
for (const auto &expression : order_expr.items()) {
String key = expression.key();
ToLower(key);
auto order_by_expr = MakeUnique<OrderByExpr>();
if (key == "_row_id" or key == "_similarity" or key == "_distance" or key == "_score") {
auto parsed_expr = new FunctionExpr();
if (key == "_row_id") {
parsed_expr->func_name_ = "row_id";
} else if (key == "_similarity") {
parsed_expr->func_name_ = "similarity";
} else if (key == "_distance") {
parsed_expr->func_name_ = "distance";
} else if (key == "_score") {
parsed_expr->func_name_ = "score";
}
order_by_expr->expr_ = parsed_expr;
parsed_expr = nullptr;
} else {
UniquePtr<ExpressionParserResult> expr_parsed_result = MakeUnique<ExpressionParserResult>();
ExprParser expr_parser;
expr_parser.Parse(key, expr_parsed_result.get());
if (expr_parsed_result->IsError() || expr_parsed_result->exprs_ptr_->size() == 0) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] = fmt::format("Invalid expression: {}", key);
return nullptr;
}

order_by_expr->expr_ = expr_parsed_result->exprs_ptr_->at(0);
expr_parsed_result->exprs_ptr_->at(0) = nullptr;
}

String value = expression.value();
ToLower(value);
if (value == "asc") {
order_by_expr->type_ = OrderType::kAsc;
} else if (value == "desc") {
order_by_expr->type_ = OrderType::kDesc;
} else {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] = fmt::format("Invalid expression: {}", value);
return nullptr;
}

order_by_list->emplace_back(order_by_expr.release());
}
}

// Avoiding DeferFN auto free the output expressions
Vector<OrderByExpr *> *res = order_by_list;
order_by_list = nullptr;
return res;
}

UniquePtr<SearchExpr> HTTPSearch::ParseSearchExpr(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response) {
if (json_object.type() != nlohmann::json::value_t::array) {
response["error_code"] = ErrorCode::kInvalidExpression;
Expand Down
2 changes: 2 additions & 0 deletions src/network/http/http_search.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import infinity;
import internal_types;
import constant_expr;
import search_expr;
import select_statement;

namespace infinity {

Expand All @@ -48,6 +49,7 @@ public:
nlohmann::json &response);

static Vector<ParsedExpr *> *ParseOutput(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response);
static Vector<OrderByExpr *> *ParseSort(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response);
static UniquePtr<ParsedExpr> ParseFilter(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response);
static UniquePtr<SearchExpr> ParseSearchExpr(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response);
static UniquePtr<FusionExpr> ParseFusion(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response);
Expand Down

0 comments on commit ee7cfda

Please sign in to comment.