Skip to content

Commit

Permalink
[opt](function)Some geo functions incorrectly used get. (apache#40107)
Browse files Browse the repository at this point in the history
## Proposed changes
```

mysql [test]>select count(st_distance_sphere(db, db, db, db)) from double_ranges;
+-------------------------------------------+
| count(st_distance_sphere(db, db, db, db)) |
+-------------------------------------------+
|                                         0 |
+-------------------------------------------+
1 row in set (1.25 sec)

mysql [test]>select count(st_distance_sphere(db, db, db, db)) from double_ranges;
+-------------------------------------------+
| count(st_distance_sphere(db, db, db, db)) |
+-------------------------------------------+
|                                         0 |
+-------------------------------------------+
1 row in set (0.33 sec)
```

<!--Describe your changes.-->
  • Loading branch information
Mryange authored Sep 4, 2024
1 parent cfee8c9 commit 9a92b4c
Showing 1 changed file with 39 additions and 29 deletions.
68 changes: 39 additions & 29 deletions be/src/vec/functions/functions_geo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "vec/columns/column.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/columns_number.h"
#include "vec/common/assert_cast.h"
#include "vec/common/string_ref.h"
#include "vec/core/block.h"
#include "vec/core/column_with_type_and_name.h"
Expand Down Expand Up @@ -58,14 +59,16 @@ struct StPoint {
auto res = ColumnString::create();
auto null_map = ColumnUInt8::create(size, 0);
auto& null_map_data = null_map->get_data();
const auto* left_column_f64 = assert_cast<const ColumnFloat64*>(left_column.get());
const auto* right_column_f64 = assert_cast<const ColumnFloat64*>(right_column.get());
GeoPoint point;
std::string buf;
if (left_const) {
const_vector(left_column, right_column, res, null_map_data, size, point, buf);
const_vector(left_column_f64, right_column_f64, res, null_map_data, size, point, buf);
} else if (right_const) {
vector_const(left_column, right_column, res, null_map_data, size, point, buf);
vector_const(left_column_f64, right_column_f64, res, null_map_data, size, point, buf);
} else {
vector_vector(left_column, right_column, res, null_map_data, size, point, buf);
vector_vector(left_column_f64, right_column_f64, res, null_map_data, size, point, buf);
}

block.replace_by_position(result,
Expand All @@ -86,32 +89,32 @@ struct StPoint {
res->insert_data(buf.data(), buf.size());
}

static void const_vector(const ColumnPtr& left_column, const ColumnPtr& right_column,
static void const_vector(const ColumnFloat64* left_column, const ColumnFloat64* right_column,
ColumnString::MutablePtr& res, NullMap& null_map, const size_t size,
GeoPoint& point, std::string& buf) {
double x = left_column->operator[](0).get<Float64>();
double x = left_column->get_element(0);
for (int row = 0; row < size; ++row) {
auto cur_res = point.from_coord(x, right_column->operator[](row).get<Float64>());
auto cur_res = point.from_coord(x, right_column->get_element(row));
loop_do(cur_res, res, null_map, row, point, buf);
}
}

static void vector_const(const ColumnPtr& left_column, const ColumnPtr& right_column,
static void vector_const(const ColumnFloat64* left_column, const ColumnFloat64* right_column,
ColumnString::MutablePtr& res, NullMap& null_map, const size_t size,
GeoPoint& point, std::string& buf) {
double y = right_column->operator[](0).get<Float64>();
double y = right_column->get_element(0);
for (int row = 0; row < size; ++row) {
auto cur_res = point.from_coord(right_column->operator[](row).get<Float64>(), y);
auto cur_res = point.from_coord(right_column->get_element(row), y);
loop_do(cur_res, res, null_map, row, point, buf);
}
}

static void vector_vector(const ColumnPtr& left_column, const ColumnPtr& right_column,
static void vector_vector(const ColumnFloat64* left_column, const ColumnFloat64* right_column,
ColumnString::MutablePtr& res, NullMap& null_map, const size_t size,
GeoPoint& point, std::string& buf) {
for (int row = 0; row < size; ++row) {
auto cur_res = point.from_coord(left_column->operator[](row).get<Float64>(),
right_column->operator[](row).get<Float64>());
auto cur_res =
point.from_coord(left_column->get_element(row), right_column->get_element(row));
loop_do(cur_res, res, null_map, row, point, buf);
}
}
Expand Down Expand Up @@ -246,22 +249,25 @@ struct StDistanceSphere {
DCHECK_EQ(arguments.size(), 4);
auto return_type = block.get_data_type(result);

auto x_lng = block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
auto x_lat = block.get_by_position(arguments[1]).column->convert_to_full_column_if_const();
auto y_lng = block.get_by_position(arguments[2]).column->convert_to_full_column_if_const();
auto y_lat = block.get_by_position(arguments[3]).column->convert_to_full_column_if_const();

const auto* x_lng = check_and_get_column<ColumnFloat64>(
block.get_by_position(arguments[0]).column->convert_to_full_column_if_const());
const auto* x_lat = check_and_get_column<ColumnFloat64>(
block.get_by_position(arguments[1]).column->convert_to_full_column_if_const());
const auto* y_lng = check_and_get_column<ColumnFloat64>(
block.get_by_position(arguments[2]).column->convert_to_full_column_if_const());
const auto* y_lat = check_and_get_column<ColumnFloat64>(
block.get_by_position(arguments[3]).column->convert_to_full_column_if_const());
CHECK(x_lng && x_lat && y_lng && y_lat);
const auto size = x_lng->size();
auto res = ColumnFloat64::create();
res->reserve(size);
auto null_map = ColumnUInt8::create(size, 0);
auto& null_map_data = null_map->get_data();
for (int row = 0; row < size; ++row) {
double distance = 0;
if (!GeoPoint::ComputeDistance(x_lng->operator[](row).get<Float64>(),
x_lat->operator[](row).get<Float64>(),
y_lng->operator[](row).get<Float64>(),
y_lat->operator[](row).get<Float64>(), &distance)) {
if (!GeoPoint::ComputeDistance(x_lng->get_element(row), x_lat->get_element(row),
y_lng->get_element(row), y_lat->get_element(row),
&distance)) {
null_map_data[row] = 1;
res->insert_default();
continue;
Expand All @@ -284,10 +290,15 @@ struct StAngleSphere {
DCHECK_EQ(arguments.size(), 4);
auto return_type = block.get_data_type(result);

auto x_lng = block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
auto x_lat = block.get_by_position(arguments[1]).column->convert_to_full_column_if_const();
auto y_lng = block.get_by_position(arguments[2]).column->convert_to_full_column_if_const();
auto y_lat = block.get_by_position(arguments[3]).column->convert_to_full_column_if_const();
const auto* x_lng = check_and_get_column<ColumnFloat64>(
block.get_by_position(arguments[0]).column->convert_to_full_column_if_const());
const auto* x_lat = check_and_get_column<ColumnFloat64>(
block.get_by_position(arguments[1]).column->convert_to_full_column_if_const());
const auto* y_lng = check_and_get_column<ColumnFloat64>(
block.get_by_position(arguments[2]).column->convert_to_full_column_if_const());
const auto* y_lat = check_and_get_column<ColumnFloat64>(
block.get_by_position(arguments[3]).column->convert_to_full_column_if_const());
CHECK(x_lng && x_lat && y_lng && y_lat);

const auto size = x_lng->size();

Expand All @@ -298,10 +309,9 @@ struct StAngleSphere {

for (int row = 0; row < size; ++row) {
double angle = 0;
if (!GeoPoint::ComputeAngleSphere(x_lng->operator[](row).get<Float64>(),
x_lat->operator[](row).get<Float64>(),
y_lng->operator[](row).get<Float64>(),
y_lat->operator[](row).get<Float64>(), &angle)) {
if (!GeoPoint::ComputeAngleSphere(x_lng->get_element(row), x_lat->get_element(row),
y_lng->get_element(row), y_lat->get_element(row),
&angle)) {
null_map_data[row] = 1;
res->insert_default();
continue;
Expand Down

0 comments on commit 9a92b4c

Please sign in to comment.