diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index e10a7cef5c481..fdddb039d46f1 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -195,7 +195,13 @@ To build the Milvus project, run the following command: $ make ``` -If this command succeed, you will now have an executable at `bin/milvus` off of your Milvus project directory. +If this command succeeds, you will now have an executable at `bin/milvus` in your Milvus project directory. + +If you want to run the `bin/milvus` executable on the host machine, you need to set `LD_LIBRARY_PATH` temporarily: + +```shell +$ LD_LIBRARY_PATH=./internal/core/output/lib:lib:$LD_LIBRARY_PATH ./bin/milvus +``` If you want to update proto file before `make`, we can use the following command: @@ -500,7 +506,6 @@ $ ./build/build_image.sh // build milvus latest docker image $ docker images // check if milvus latest image is ready REPOSITORY TAG IMAGE ID CREATED SIZE milvusdb/milvus latest 63c62ff7c1b7 52 minutes ago 570MB -$ install with docker compose ``` ## GitHub Flow @@ -602,4 +607,4 @@ A: Reinstall llvm@15 brew reinstall llvm@15 export LDFLAGS="-L/opt/homebrew/opt/llvm@15/lib" export CPPFLAGS="-I/opt/homebrew/opt/llvm@15/include" -``` \ No newline at end of file +``` diff --git a/README.md b/README.md index 09858de499297..9e0bf7607554a 100644 --- a/README.md +++ b/README.md @@ -56,13 +56,13 @@ Milvus was released under the [open-source Apache License 2.0](https://github.co ## Quick start ### Start with Zilliz Cloud -Zilliz Cloud is a fully managed service on cloud and the simplest way to deploy LF AI Milvus®, See [Zilliz Cloud Quick Start Guide](https://zilliz.com/doc/quick_start) and start your [free trial](https://cloud.zilliz.com/signup). +Zilliz Cloud is a fully managed service on cloud and the simplest way to deploy LF AI Milvus®, See [Zilliz Cloud](https://zilliz.com/) and start your [free trial](https://cloud.zilliz.com/signup). 
### Install Milvus -- [Standalone Quick Start Guide](https://milvus.io/docs/v2.0.x/install_standalone-docker.md) +- [Standalone Quick Start Guide](https://milvus.io/docs/install_standalone-docker.md) -- [Cluster Quick Start Guide](https://milvus.io/docs/v2.0.x/install_cluster-docker.md) +- [Cluster Quick Start Guide](https://milvus.io/docs/install_cluster-docker.md) - [Advanced Deployment](https://github.com/milvus-io/milvus/wiki) @@ -169,7 +169,7 @@ Contributions to Milvus are welcome from everyone. See [Guidelines for Contribut ### All contributors
-
+
@@ -284,6 +284,7 @@ Contributions to Milvus are welcome from everyone. See [Guidelines for Contribut + @@ -371,6 +372,7 @@ Contributions to Milvus are welcome from everyone. See [Guidelines for Contribut + diff --git a/README_CN.md b/README_CN.md index 1aee72e54ee74..26207c0f21fbb 100644 --- a/README_CN.md +++ b/README_CN.md @@ -154,7 +154,7 @@ Milvus [训练营](https://github.com/milvus-io/bootcamp)能够帮助你了解 ### All contributors
-
+
@@ -269,6 +269,7 @@ Milvus [训练营](https://github.com/milvus-io/bootcamp)能够帮助你了解 + @@ -356,6 +357,7 @@ Milvus [训练营](https://github.com/milvus-io/bootcamp)能够帮助你了解 + diff --git a/build/docker/builder/gpu/ubuntu20.04/Dockerfile b/build/docker/builder/gpu/ubuntu20.04/Dockerfile index ba86136227817..5550e3fc82391 100644 --- a/build/docker/builder/gpu/ubuntu20.04/Dockerfile +++ b/build/docker/builder/gpu/ubuntu20.04/Dockerfile @@ -34,7 +34,7 @@ RUN /opt/vcpkg/bootstrap-vcpkg.sh -disableMetrics && ln -s /opt/vcpkg/vcpkg /usr RUN vcpkg install azure-identity-cpp azure-storage-blobs-cpp gtest -# Instal openblas +# Install openblas # RUN wget https://github.com/xianyi/OpenBLAS/archive/v0.3.21.tar.gz && \ # tar zxvf v0.3.21.tar.gz && cd OpenBLAS-0.3.21 && \ # make NO_STATIC=1 NO_LAPACK=1 NO_LAPACKE=1 NO_CBLAS=1 NO_AFFINITY=1 USE_OPENMP=1 \ diff --git a/configs/milvus.yaml b/configs/milvus.yaml index 22212fec29aef..6cf78cc058e63 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -530,6 +530,8 @@ dataNode: serverMaxRecvSize: 268435456 clientMaxSendSize: 268435456 clientMaxRecvSize: 536870912 + slot: + slotCap: 2 # The maximum number of tasks(e.g. compaction, importing) allowed to run concurrently on a datanode. # Configures the system log output. log: diff --git a/internal/core/conanfile.py b/internal/core/conanfile.py index cdeb40a1d7f53..2b4e31ed002c5 100644 --- a/internal/core/conanfile.py +++ b/internal/core/conanfile.py @@ -42,6 +42,7 @@ class MilvusConan(ConanFile): "opentelemetry-cpp/1.8.1.1@milvus/dev", "librdkafka/1.9.1", "abseil/20230125.3", + "roaring/3.0.0", ) generators = ("cmake", "cmake_find_package") default_options = { diff --git a/internal/core/src/index/BitmapIndex.cpp b/internal/core/src/index/BitmapIndex.cpp new file mode 100644 index 0000000000000..5d0a4aabec3cd --- /dev/null +++ b/internal/core/src/index/BitmapIndex.cpp @@ -0,0 +1,666 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "index/BitmapIndex.h" + +#include "common/Slice.h" +#include "index/Meta.h" +#include "index/ScalarIndex.h" +#include "index/Utils.h" +#include "storage/Util.h" +#include "storage/space.h" + +namespace milvus { +namespace index { + +template +BitmapIndex::BitmapIndex( + const storage::FileManagerContext& file_manager_context) + : is_built_(false) { + if (file_manager_context.Valid()) { + file_manager_ = + std::make_shared(file_manager_context); + AssertInfo(file_manager_ != nullptr, "create file manager failed!"); + } +} + +template +BitmapIndex::BitmapIndex( + const storage::FileManagerContext& file_manager_context, + std::shared_ptr space) + : is_built_(false), data_(), space_(space) { + if (file_manager_context.Valid()) { + file_manager_ = std::make_shared( + file_manager_context, space); + AssertInfo(file_manager_ != nullptr, "create file manager failed!"); + } +} + +template +void +BitmapIndex::Build(const Config& config) { + if (is_built_) { + return; + } + auto insert_files = + GetValueFromConfig>(config, "insert_files"); + AssertInfo(insert_files.has_value(), + "insert file paths is empty when build index"); + + auto field_datas = + file_manager_->CacheRawDataToMemory(insert_files.value()); + + int total_num_rows = 0; + for (const auto& field_data 
: field_datas) { + total_num_rows += field_data->get_num_rows(); + } + if (total_num_rows == 0) { + throw SegcoreError(DataIsEmpty, + "scalar bitmap index can not build null values"); + } + + total_num_rows_ = total_num_rows; + + int64_t offset = 0; + for (const auto& data : field_datas) { + auto slice_row_num = data->get_num_rows(); + for (size_t i = 0; i < slice_row_num; ++i) { + auto val = reinterpret_cast(data->RawValue(i)); + data_[*val].add(offset); + offset++; + } + } + is_built_ = true; +} + +template +void +BitmapIndex::Build(size_t n, const T* data) { + if (is_built_) { + return; + } + if (n == 0) { + throw SegcoreError(DataIsEmpty, + "BitmapIndex can not build null values"); + } + + T* p = const_cast(data); + for (int i = 0; i < n; ++i, ++p) { + data_[*p].add(i); + } + total_num_rows_ = n; + + for (auto it = data_.begin(); it != data_.end(); ++it) { + bitsets_[it->first] = ConvertRoaringToBitset(it->second); + } + + is_built_ = true; +} + +template +void +BitmapIndex::BuildV2(const Config& config) { + if (is_built_) { + return; + } + auto field_name = file_manager_->GetIndexMeta().field_name; + auto reader = space_->ScanData(); + std::vector field_datas; + for (auto rec = reader->Next(); rec != nullptr; rec = reader->Next()) { + if (!rec.ok()) { + PanicInfo(DataFormatBroken, "failed to read data"); + } + auto data = rec.ValueUnsafe(); + auto total_num_rows = data->num_rows(); + auto col_data = data->GetColumnByName(field_name); + auto field_data = storage::CreateFieldData( + DataType(GetDType()), 0, total_num_rows); + field_data->FillFieldData(col_data); + field_datas.push_back(field_data); + } + + int total_num_rows = 0; + for (auto& field_data : field_datas) { + total_num_rows += field_data->get_num_rows(); + } + if (total_num_rows == 0) { + throw SegcoreError(DataIsEmpty, + "scalar bitmap index can not build null values"); + } + + total_num_rows_ = total_num_rows; + + int64_t offset = 0; + for (const auto& data : field_datas) { + auto slice_row_num = 
data->get_num_rows(); + for (size_t i = 0; i < slice_row_num; ++i) { + auto val = reinterpret_cast(data->RawValue(i)); + data_[*val].add(offset); + offset++; + } + } + is_built_ = true; +} + +template +size_t +BitmapIndex::GetIndexDataSize() { + auto index_data_size = 0; + for (auto& pair : data_) { + index_data_size += pair.second.getSizeInBytes() + sizeof(T); + } + return index_data_size; +} + +template <> +size_t +BitmapIndex::GetIndexDataSize() { + auto index_data_size = 0; + for (auto& pair : data_) { + index_data_size += + pair.second.getSizeInBytes() + pair.first.size() + sizeof(size_t); + } + return index_data_size; +} + +template +void +BitmapIndex::SerializeIndexData(uint8_t* data_ptr) { + for (auto& pair : data_) { + memcpy(data_ptr, &pair.first, sizeof(T)); + data_ptr += sizeof(T); + + pair.second.write(reinterpret_cast(data_ptr)); + data_ptr += pair.second.getSizeInBytes(); + } +} + +template <> +void +BitmapIndex::SerializeIndexData(uint8_t* data_ptr) { + for (auto& pair : data_) { + size_t key_size = pair.first.size(); + memcpy(data_ptr, &key_size, sizeof(size_t)); + data_ptr += sizeof(size_t); + + memcpy(data_ptr, pair.first.data(), key_size); + data_ptr += key_size; + + pair.second.write(reinterpret_cast(data_ptr)); + data_ptr += pair.second.getSizeInBytes(); + } +} + +template +BinarySet +BitmapIndex::Serialize(const Config& config) { + AssertInfo(is_built_, "index has not been built yet"); + + auto index_data_size = GetIndexDataSize(); + + std::shared_ptr index_data(new uint8_t[index_data_size]); + uint8_t* data_ptr = index_data.get(); + SerializeIndexData(data_ptr); + + std::shared_ptr index_length(new uint8_t[sizeof(size_t)]); + auto index_size = data_.size(); + memcpy(index_length.get(), &index_size, sizeof(size_t)); + + std::shared_ptr num_rows(new uint8_t[sizeof(size_t)]); + memcpy(num_rows.get(), &total_num_rows_, sizeof(size_t)); + + BinarySet ret_set; + ret_set.Append(BITMAP_INDEX_DATA, index_data, index_data_size); + 
ret_set.Append(BITMAP_INDEX_LENGTH, index_length, sizeof(size_t)); + ret_set.Append(BITMAP_INDEX_NUM_ROWS, num_rows, sizeof(size_t)); + + LOG_INFO("build bitmap index with cardinality = {}, num_rows = {}", + index_size, + total_num_rows_); + return ret_set; +} + +template +BinarySet +BitmapIndex::Upload(const Config& config) { + auto binary_set = Serialize(config); + + file_manager_->AddFile(binary_set); + + auto remote_path_to_size = file_manager_->GetRemotePathsToFileSize(); + BinarySet ret; + for (auto& file : remote_path_to_size) { + ret.Append(file.first, nullptr, file.second); + } + return ret; +} + +template +BinarySet +BitmapIndex::UploadV2(const Config& config) { + auto binary_set = Serialize(config); + + file_manager_->AddFileV2(binary_set); + + auto remote_path_to_size = file_manager_->GetRemotePathsToFileSize(); + BinarySet ret; + for (auto& file : remote_path_to_size) { + ret.Append(file.first, nullptr, file.second); + } + return ret; +} + +template +void +BitmapIndex::Load(const BinarySet& binary_set, const Config& config) { + milvus::Assemble(const_cast(binary_set)); + LoadWithoutAssemble(binary_set, config); +} + +template +TargetBitmap +BitmapIndex::ConvertRoaringToBitset(const roaring::Roaring& values) { + AssertInfo(total_num_rows_ != 0, "total num rows should not be 0"); + TargetBitmap res(total_num_rows_, false); + for (const auto& val : values) { + res.set(val); + } + return res; +} + +template +void +BitmapIndex::DeserializeIndexData(const uint8_t* data_ptr, + size_t index_length) { + for (size_t i = 0; i < index_length; ++i) { + T key; + memcpy(&key, data_ptr, sizeof(T)); + data_ptr += sizeof(T); + + roaring::Roaring value; + value = roaring::Roaring::read(reinterpret_cast(data_ptr)); + data_ptr += value.getSizeInBytes(); + + bitsets_[key] = ConvertRoaringToBitset(value); + } +} + +template <> +void +BitmapIndex::DeserializeIndexData(const uint8_t* data_ptr, + size_t index_length) { + for (size_t i = 0; i < index_length; ++i) { + size_t 
key_size; + memcpy(&key_size, data_ptr, sizeof(size_t)); + data_ptr += sizeof(size_t); + + std::string key(reinterpret_cast(data_ptr), key_size); + data_ptr += key_size; + + roaring::Roaring value; + value = roaring::Roaring::read(reinterpret_cast(data_ptr)); + data_ptr += value.getSizeInBytes(); + + bitsets_[key] = ConvertRoaringToBitset(value); + } +} + +template +void +BitmapIndex::LoadWithoutAssemble(const BinarySet& binary_set, + const Config& config) { + size_t index_length; + auto index_length_buffer = binary_set.GetByName(BITMAP_INDEX_LENGTH); + memcpy(&index_length, + index_length_buffer->data.get(), + (size_t)index_length_buffer->size); + + auto num_rows_buffer = binary_set.GetByName(BITMAP_INDEX_NUM_ROWS); + memcpy(&total_num_rows_, + num_rows_buffer->data.get(), + (size_t)num_rows_buffer->size); + + auto index_data_buffer = binary_set.GetByName(BITMAP_INDEX_DATA); + const uint8_t* data_ptr = index_data_buffer->data.get(); + + DeserializeIndexData(data_ptr, index_length); + + LOG_INFO("load bitmap index with cardinality = {}, num_rows = {}", + Cardinality(), + total_num_rows_); + + is_built_ = true; +} + +template +void +BitmapIndex::LoadV2(const Config& config) { + auto blobs = space_->StatisticsBlobs(); + std::vector index_files; + auto prefix = file_manager_->GetRemoteIndexObjectPrefixV2(); + for (auto& b : blobs) { + if (b.name.rfind(prefix, 0) == 0) { + index_files.push_back(b.name); + } + } + std::map index_datas{}; + for (auto& file_name : index_files) { + auto res = space_->GetBlobByteSize(file_name); + if (!res.ok()) { + PanicInfo(S3Error, "unable to read index blob"); + } + auto index_blob_data = + std::shared_ptr(new uint8_t[res.value()]); + auto status = space_->ReadBlob(file_name, index_blob_data.get()); + if (!status.ok()) { + PanicInfo(S3Error, "unable to read index blob"); + } + auto raw_index_blob = + storage::DeserializeFileData(index_blob_data, res.value()); + auto key = file_name.substr(file_name.find_last_of('/') + 1); + 
index_datas[key] = raw_index_blob->GetFieldData(); + } + AssembleIndexDatas(index_datas); + + BinarySet binary_set; + for (auto& [key, data] : index_datas) { + auto size = data->Size(); + auto deleter = [&](uint8_t*) {}; // avoid repeated deconstruction + auto buf = std::shared_ptr( + (uint8_t*)const_cast(data->Data()), deleter); + binary_set.Append(key, buf, size); + } + + LoadWithoutAssemble(binary_set, config); +} + +template +void +BitmapIndex::Load(milvus::tracer::TraceContext ctx, const Config& config) { + auto index_files = + GetValueFromConfig>(config, "index_files"); + AssertInfo(index_files.has_value(), + "index file paths is empty when load bitmap index"); + auto index_datas = file_manager_->LoadIndexToMemory(index_files.value()); + AssembleIndexDatas(index_datas); + BinarySet binary_set; + for (auto& [key, data] : index_datas) { + auto size = data->Size(); + auto deleter = [&](uint8_t*) {}; // avoid repeated deconstruction + auto buf = std::shared_ptr( + (uint8_t*)const_cast(data->Data()), deleter); + binary_set.Append(key, buf, size); + } + + LoadWithoutAssemble(binary_set, config); +} + +template +const TargetBitmap +BitmapIndex::In(const size_t n, const T* values) { + AssertInfo(is_built_, "index has not been built"); + TargetBitmap res(total_num_rows_, false); + +#if 0 + roaring::Roaring result; + for (size_t i = 0; i < n; ++i) { + auto val = values[i]; + auto it = data_.find(val); + if (it != data_.end()) { + result |= it->second; + } + } + for (auto& val : result) { + res.set(val); + } +#else + for (size_t i = 0; i < n; ++i) { + auto val = values[i]; + if (bitsets_.find(val) != bitsets_.end()) { + res |= bitsets_.at(val); + } + } +#endif + return res; +} + +template +const TargetBitmap +BitmapIndex::NotIn(const size_t n, const T* values) { + AssertInfo(is_built_, "index has not been built"); + TargetBitmap res(total_num_rows_, false); + +#if 0 + roaring::Roaring result; + for (int i = 0; i < n; ++i) { + auto val = values[i]; + auto it = 
data_.find(val); + if (it != data_.end()) { + result |= it->second; + } + } + + for (auto& val : result) { + bitset.reset(val); + } +#else + for (size_t i = 0; i < n; ++i) { + auto val = values[i]; + if (bitsets_.find(val) != bitsets_.end()) { + res |= bitsets_.at(val); + } + } +#endif + res.flip(); + return res; +} + +template +const TargetBitmap +BitmapIndex::Range(const T value, const OpType op) { + AssertInfo(is_built_, "index has not been built"); + TargetBitmap res(total_num_rows_, false); + if (ShouldSkip(value, value, op)) { + return res; + } + auto lb = bitsets_.begin(); + auto ub = bitsets_.end(); + + switch (op) { + case OpType::LessThan: { + ub = std::lower_bound(bitsets_.begin(), + bitsets_.end(), + std::make_pair(value, TargetBitmap()), + [](const auto& lhs, const auto& rhs) { + return lhs.first < rhs.first; + }); + break; + } + case OpType::LessEqual: { + ub = std::upper_bound(bitsets_.begin(), + bitsets_.end(), + std::make_pair(value, TargetBitmap()), + [](const auto& lhs, const auto& rhs) { + return lhs.first < rhs.first; + }); + break; + } + case OpType::GreaterThan: { + lb = std::upper_bound(bitsets_.begin(), + bitsets_.end(), + std::make_pair(value, TargetBitmap()), + [](const auto& lhs, const auto& rhs) { + return lhs.first < rhs.first; + }); + break; + } + case OpType::GreaterEqual: { + lb = std::lower_bound(bitsets_.begin(), + bitsets_.end(), + std::make_pair(value, TargetBitmap()), + [](const auto& lhs, const auto& rhs) { + return lhs.first < rhs.first; + }); + break; + } + default: { + throw SegcoreError(OpTypeInvalid, + fmt::format("Invalid OperatorType: {}", op)); + } + } + + for (; lb != ub; lb++) { + res |= lb->second; + } + return res; +} + +template +const TargetBitmap +BitmapIndex::Range(const T lower_value, + bool lb_inclusive, + const T upper_value, + bool ub_inclusive) { + AssertInfo(is_built_, "index has not been built"); + TargetBitmap res(total_num_rows_, false); + if (lower_value > upper_value || + (lower_value == upper_value 
&& !(lb_inclusive && ub_inclusive))) { + return res; + } + if (ShouldSkip(lower_value, upper_value, OpType::Range)) { + return res; + } + + auto lb = bitsets_.begin(); + auto ub = bitsets_.end(); + + if (lb_inclusive) { + lb = std::lower_bound(bitsets_.begin(), + bitsets_.end(), + std::make_pair(lower_value, TargetBitmap()), + [](const auto& lhs, const auto& rhs) { + return lhs.first < rhs.first; + }); + } else { + lb = std::upper_bound(bitsets_.begin(), + bitsets_.end(), + std::make_pair(lower_value, TargetBitmap()), + [](const auto& lhs, const auto& rhs) { + return lhs.first < rhs.first; + }); + } + + if (ub_inclusive) { + ub = std::upper_bound(bitsets_.begin(), + bitsets_.end(), + std::make_pair(upper_value, TargetBitmap()), + [](const auto& lhs, const auto& rhs) { + return lhs.first < rhs.first; + }); + } else { + ub = std::lower_bound(bitsets_.begin(), + bitsets_.end(), + std::make_pair(upper_value, TargetBitmap()), + [](const auto& lhs, const auto& rhs) { + return lhs.first < rhs.first; + }); + } + + for (; lb != ub; lb++) { + res |= lb->second; + } + return res; +} + +template +T +BitmapIndex::Reverse_Lookup(size_t idx) const { + AssertInfo(is_built_, "index has not been built"); + AssertInfo(idx < total_num_rows_, "out of range of total coun"); + + for (auto it = bitsets_.begin(); it != bitsets_.end(); it++) { + if (it->second[idx]) { + return it->first; + } + } + throw SegcoreError( + UnexpectedError, + fmt::format( + "scalar bitmap index can not lookup target value of index {}", + idx)); +} + +template +bool +BitmapIndex::ShouldSkip(const T lower_value, + const T upper_value, + const OpType op) { + if (!bitsets_.empty()) { + auto lower_bound = bitsets_.begin()->first; + auto upper_bound = bitsets_.rbegin()->first; + bool should_skip = false; + switch (op) { + case OpType::LessThan: { + // lower_value == upper_value + should_skip = lower_bound >= lower_value; + break; + } + case OpType::LessEqual: { + // lower_value == upper_value + should_skip = 
lower_bound > lower_value; + break; + } + case OpType::GreaterThan: { + // lower_value == upper_value + should_skip = upper_bound <= lower_value; + break; + } + case OpType::GreaterEqual: { + // lower_value == upper_value + should_skip = upper_bound < lower_value; + break; + } + case OpType::Range: { + // lower_value == upper_value + should_skip = + lower_bound > upper_value || upper_bound < lower_value; + break; + } + default: + throw SegcoreError( + OpTypeInvalid, + fmt::format("Invalid OperatorType for " + "checking scalar index optimization: {}", + op)); + } + return should_skip; + } + return true; +} + +template class BitmapIndex; +template class BitmapIndex; +template class BitmapIndex; +template class BitmapIndex; +template class BitmapIndex; +template class BitmapIndex; +template class BitmapIndex; +template class BitmapIndex; + +} // namespace index +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/index/BitmapIndex.h b/internal/core/src/index/BitmapIndex.h new file mode 100644 index 0000000000000..38ea6004495ff --- /dev/null +++ b/internal/core/src/index/BitmapIndex.h @@ -0,0 +1,144 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include + +#include "index/ScalarIndex.h" +#include "storage/FileManager.h" +#include "storage/DiskFileManagerImpl.h" +#include "storage/MemFileManagerImpl.h" +#include "storage/space.h" + +namespace milvus { +namespace index { + +/* +* @brief Implementation of Bitmap Index +* @details This index only for scalar Integral type. +*/ +template +class BitmapIndex : public ScalarIndex { + public: + explicit BitmapIndex( + const storage::FileManagerContext& file_manager_context = + storage::FileManagerContext()); + + explicit BitmapIndex( + const storage::FileManagerContext& file_manager_context, + std::shared_ptr space); + + ~BitmapIndex() override = default; + + BinarySet + Serialize(const Config& config) override; + + void + Load(const BinarySet& index_binary, const Config& config = {}) override; + + void + Load(milvus::tracer::TraceContext ctx, const Config& config = {}) override; + + void + LoadV2(const Config& config = {}) override; + + int64_t + Count() override { + return bitsets_.begin()->second.size(); + } + + void + Build(size_t n, const T* values) override; + + void + Build(const Config& config = {}) override; + + void + BuildV2(const Config& config = {}) override; + + const TargetBitmap + In(size_t n, const T* values) override; + + const TargetBitmap + NotIn(size_t n, const T* values) override; + + const TargetBitmap + Range(T value, OpType op) override; + + const TargetBitmap + Range(T lower_bound_value, + bool lb_inclusive, + T upper_bound_value, + bool ub_inclusive) override; + + T + Reverse_Lookup(size_t offset) const override; + + int64_t + Size() override { + return Count(); + } + + BinarySet + Upload(const Config& config = {}) override; + BinarySet + UploadV2(const Config& config = {}) override; + + const bool + HasRawData() const override { + return true; + } + + int64_t + Cardinality() { + return bitsets_.size(); + } + + private: + size_t + GetIndexDataSize(); + + void + SerializeIndexData(uint8_t* 
index_data_ptr); + + void + DeserializeIndexData(const uint8_t* data_ptr, size_t index_length); + + bool + ShouldSkip(const T lower_value, const T upper_value, const OpType op); + + TargetBitmap + ConvertRoaringToBitset(const roaring::Roaring& values); + + void + LoadWithoutAssemble(const BinarySet& binary_set, const Config& config); + + private: + bool is_built_; + Config config_; + std::map data_; + std::map bitsets_; + size_t total_num_rows_; + std::shared_ptr file_manager_; + std::shared_ptr space_; +}; + +} // namespace index +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/index/CMakeLists.txt b/internal/core/src/index/CMakeLists.txt index c49e9477fba02..ed0f600587bd2 100644 --- a/internal/core/src/index/CMakeLists.txt +++ b/internal/core/src/index/CMakeLists.txt @@ -19,6 +19,7 @@ set(INDEX_FILES ScalarIndexSort.cpp SkipIndex.cpp InvertedIndexTantivy.cpp + BitmapIndex.cpp ) milvus_add_pkg_config("milvus_index") diff --git a/internal/core/src/index/IndexFactory.cpp b/internal/core/src/index/IndexFactory.cpp index e6e8a4cf93443..6d133adc96204 100644 --- a/internal/core/src/index/IndexFactory.cpp +++ b/internal/core/src/index/IndexFactory.cpp @@ -27,6 +27,7 @@ #include "index/StringIndexMarisa.h" #include "index/BoolIndex.h" #include "index/InvertedIndexTantivy.h" +#include "index/BitmapIndex.h" namespace milvus::index { @@ -42,6 +43,9 @@ IndexFactory::CreateScalarIndex( return std::make_unique>(cfg, file_manager_context); } + if (index_type == BITMAP_INDEX_TYPE) { + return std::make_unique>(file_manager_context); + } return CreateScalarIndexSort(file_manager_context); } @@ -65,6 +69,9 @@ IndexFactory::CreateScalarIndex( return std::make_unique>( cfg, file_manager_context); } + if (index_type == BITMAP_INDEX_TYPE) { + return std::make_unique>(file_manager_context); + } return CreateStringIndexMarisa(file_manager_context); #else throw SegcoreError(Unsupported, "unsupported platform"); @@ -84,6 +91,9 @@ 
IndexFactory::CreateScalarIndex( return std::make_unique>( cfg, file_manager_context, space); } + if (index_type == BITMAP_INDEX_TYPE) { + return std::make_unique>(file_manager_context, space); + } return CreateScalarIndexSort(file_manager_context, space); } @@ -101,6 +111,10 @@ IndexFactory::CreateScalarIndex( return std::make_unique>( cfg, file_manager_context, space); } + if (index_type == BITMAP_INDEX_TYPE) { + return std::make_unique>(file_manager_context, + space); + } return CreateStringIndexMarisa(file_manager_context, space); #else throw SegcoreError(Unsupported, "unsupported platform"); diff --git a/internal/core/src/index/Meta.h b/internal/core/src/index/Meta.h index c8fad3dbf4e1d..e44eb6d87a1ea 100644 --- a/internal/core/src/index/Meta.h +++ b/internal/core/src/index/Meta.h @@ -30,6 +30,12 @@ constexpr const char* PREFIX_VALUE = "prefix_value"; constexpr const char* MARISA_TRIE_INDEX = "marisa_trie_index"; constexpr const char* MARISA_STR_IDS = "marisa_trie_str_ids"; +// below meta key of store bitmap indexes +constexpr const char* BITMAP_INDEX_DATA = "bitmap_index_data"; +constexpr const char* BITMAP_INDEX_META = "bitmap_index_meta"; +constexpr const char* BITMAP_INDEX_LENGTH = "bitmap_index_length"; +constexpr const char* BITMAP_INDEX_NUM_ROWS = "bitmap_index_num_rows"; + constexpr const char* INDEX_TYPE = "index_type"; constexpr const char* METRIC_TYPE = "metric_type"; @@ -37,6 +43,7 @@ constexpr const char* METRIC_TYPE = "metric_type"; constexpr const char* ASCENDING_SORT = "STL_SORT"; constexpr const char* MARISA_TRIE = "Trie"; constexpr const char* INVERTED_INDEX_TYPE = "INVERTED"; +constexpr const char* BITMAP_INDEX_TYPE = "BITMAP"; // index meta constexpr const char* COLLECTION_ID = "collection_id"; diff --git a/internal/core/thirdparty/knowhere/CMakeLists.txt b/internal/core/thirdparty/knowhere/CMakeLists.txt index 24d610767d51c..9a146ffe273d1 100644 --- a/internal/core/thirdparty/knowhere/CMakeLists.txt +++ 
b/internal/core/thirdparty/knowhere/CMakeLists.txt @@ -12,7 +12,7 @@ #------------------------------------------------------------------------------- # Update KNOWHERE_VERSION for the first occurrence -set( KNOWHERE_VERSION v2.3.2 ) +set( KNOWHERE_VERSION 89657b08 ) set( GIT_REPOSITORY "https://github.com/zilliztech/knowhere.git") message(STATUS "Knowhere repo: ${GIT_REPOSITORY}") message(STATUS "Knowhere version: ${KNOWHERE_VERSION}") diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt index 657198c9b88c2..be78b2b36c43b 100644 --- a/internal/core/unittest/CMakeLists.txt +++ b/internal/core/unittest/CMakeLists.txt @@ -32,6 +32,7 @@ set(MILVUS_TEST_FILES test_growing.cpp test_growing_index.cpp test_indexing.cpp + test_bitmap_index.cpp test_index_c_api.cpp test_index_wrapper.cpp test_init.cpp diff --git a/internal/core/unittest/test_bitmap_index.cpp b/internal/core/unittest/test_bitmap_index.cpp new file mode 100644 index 0000000000000..99d877d744587 --- /dev/null +++ b/internal/core/unittest/test_bitmap_index.cpp @@ -0,0 +1,274 @@ +// Copyright(C) 2019 - 2020 Zilliz.All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include +#include +#include + +#include "common/Tracer.h" +#include "index/BitmapIndex.h" +#include "storage/Util.h" +#include "storage/InsertData.h" +#include "indexbuilder/IndexFactory.h" +#include "index/IndexFactory.h" +#include "test_utils/indexbuilder_test_utils.h" +#include "index/Meta.h" + +using namespace milvus::index; +using namespace milvus::indexbuilder; +using namespace milvus; +using namespace milvus::index; + +template +static std::vector +GenerateData(const size_t size, const size_t cardinality) { + std::vector result; + for (size_t i = 0; i < size; ++i) { + result.push_back(rand() % cardinality); + } + return result; +} + +template <> +std::vector +GenerateData(const size_t size, const size_t cardinality) { + std::vector result; + for (size_t i = 0; i < size; ++i) { + result.push_back(rand() % 2 == 0); + } + return result; +} + +template <> +std::vector +GenerateData(const size_t size, const size_t cardinality) { + std::vector result; + for (size_t i = 0; i < size; ++i) { + result.push_back(std::to_string(rand() % cardinality)); + } + return result; +} + +template +class BitmapIndexTest : public testing::Test { + protected: + void + Init(int64_t collection_id, + int64_t partition_id, + int64_t segment_id, + int64_t field_id, + int64_t index_build_id, + int64_t index_version) { + auto field_meta = storage::FieldDataMeta{ + collection_id, partition_id, segment_id, field_id}; + auto index_meta = storage::IndexMeta{ + segment_id, field_id, index_build_id, index_version}; + + std::vector data_gen; + data_gen = GenerateData(nb_, cardinality_); + for (auto x : data_gen) { + data_.push_back(x); + } + + auto field_data = storage::CreateFieldData(type_); + field_data->FillFieldData(data_.data(), data_.size()); + storage::InsertData insert_data(field_data); + insert_data.SetFieldDataMeta(field_meta); + insert_data.SetTimestamps(0, 100); + 
+ auto serialized_bytes = insert_data.Serialize(storage::Remote); + + auto log_path = fmt::format("{}/{}/{}/{}/{}", + collection_id, + partition_id, + segment_id, + field_id, + 0); + chunk_manager_->Write( + log_path, serialized_bytes.data(), serialized_bytes.size()); + + storage::FileManagerContext ctx(field_meta, index_meta, chunk_manager_); + std::vector index_files; + + Config config; + config["index_type"] = milvus::index::BITMAP_INDEX_TYPE; + config["insert_files"] = std::vector{log_path}; + + auto build_index = + indexbuilder::IndexFactory::GetInstance().CreateIndex( + type_, config, ctx); + build_index->Build(); + + auto binary_set = build_index->Upload(); + for (const auto& [key, _] : binary_set.binary_map_) { + index_files.push_back(key); + } + + index::CreateIndexInfo index_info{}; + index_info.index_type = milvus::index::BITMAP_INDEX_TYPE; + index_info.field_type = type_; + + config["index_files"] = index_files; + + index_ = + index::IndexFactory::GetInstance().CreateIndex(index_info, ctx); + index_->Load(milvus::tracer::TraceContext{}, config); + } + + void + SetUp() override { + nb_ = 10000; + cardinality_ = 30; + + if constexpr (std::is_same_v) { + type_ = DataType::INT8; + } else if constexpr (std::is_same_v) { + type_ = DataType::INT16; + } else if constexpr (std::is_same_v) { + type_ = DataType::INT32; + } else if constexpr (std::is_same_v) { + type_ = DataType::INT64; + } else if constexpr (std::is_same_v) { + type_ = DataType::VARCHAR; + } + int64_t collection_id = 1; + int64_t partition_id = 2; + int64_t segment_id = 3; + int64_t field_id = 101; + int64_t index_build_id = 1000; + int64_t index_version = 10000; + std::string root_path = "/tmp/test-bitmap-index/"; + + storage::StorageConfig storage_config; + storage_config.storage_type = "local"; + storage_config.root_path = root_path; + chunk_manager_ = storage::CreateChunkManager(storage_config); + + Init(collection_id, + partition_id, + segment_id, + field_id, + index_build_id, + 
index_version); + } + + virtual ~BitmapIndexTest() override { + boost::filesystem::remove_all(chunk_manager_->GetRootPath()); + } + + public: + void + TestInFunc() { + boost::container::vector test_data; + std::unordered_set s; + size_t nq = 10; + for (size_t i = 0; i < nq; i++) { + test_data.push_back(data_[i]); + s.insert(data_[i]); + } + auto index_ptr = dynamic_cast*>(index_.get()); + auto bitset = index_ptr->In(test_data.size(), test_data.data()); + for (size_t i = 0; i < bitset.size(); i++) { + ASSERT_EQ(bitset[i], s.find(data_[i]) != s.end()); + } + } + + void + TestNotInFunc() { + boost::container::vector test_data; + std::unordered_set s; + size_t nq = 10; + for (size_t i = 0; i < nq; i++) { + test_data.push_back(data_[i]); + s.insert(data_[i]); + } + auto index_ptr = dynamic_cast*>(index_.get()); + auto bitset = index_ptr->NotIn(test_data.size(), test_data.data()); + for (size_t i = 0; i < bitset.size(); i++) { + ASSERT_EQ(bitset[i], s.find(data_[i]) == s.end()); + } + } + + void + TestCompareValueFunc() { + if constexpr (!std::is_same_v) { + using RefFunc = std::function; + std::vector> test_cases{ + {10, + OpType::GreaterThan, + [&](int64_t i) -> bool { return data_[i] > 10; }}, + {10, + OpType::GreaterEqual, + [&](int64_t i) -> bool { return data_[i] >= 10; }}, + {10, + OpType::LessThan, + [&](int64_t i) -> bool { return data_[i] < 10; }}, + {10, + OpType::LessEqual, + [&](int64_t i) -> bool { return data_[i] <= 10; }}, + }; + for (const auto& [test_value, op, ref] : test_cases) { + auto index_ptr = + dynamic_cast*>(index_.get()); + auto bitset = index_ptr->Range(test_value, op); + for (size_t i = 0; i < bitset.size(); i++) { + auto ans = bitset[i]; + auto should = ref(i); + ASSERT_EQ(ans, should) + << "op: " << op << ", @" << i << ", ans: " << ans + << ", ref: " << should; + } + } + } + } + + private: + std::shared_ptr chunk_manager_; + + public: + IndexBasePtr index_; + DataType type_; + size_t nb_; + size_t cardinality_; + boost::container::vector 
data_; +}; + +TYPED_TEST_SUITE_P(BitmapIndexTest); + +TYPED_TEST_P(BitmapIndexTest, CountFuncTest) { + auto count = this->index_->Count(); + EXPECT_EQ(count, this->nb_); +} + +TYPED_TEST_P(BitmapIndexTest, INFuncTest) { + this->TestInFunc(); +} + +TYPED_TEST_P(BitmapIndexTest, NotINFuncTest) { + this->TestNotInFunc(); +} + +TYPED_TEST_P(BitmapIndexTest, CompareValFuncTest) { + this->TestCompareValueFunc(); +} + +using BitmapType = + testing::Types; + +REGISTER_TYPED_TEST_SUITE_P(BitmapIndexTest, + CountFuncTest, + INFuncTest, + NotINFuncTest, + CompareValFuncTest); + +INSTANTIATE_TYPED_TEST_SUITE_P(BitmapE2ECheck, BitmapIndexTest, BitmapType); diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 3b7991abee1ce..379bf7792e622 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -5212,4 +5212,4 @@ TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_AND_RANGE_FILTER_WHEN_IP_BFLOAT16) { TEST(CApiTest, IsLoadWithDisk) { ASSERT_TRUE(IsLoadWithDisk(INVERTED_INDEX_TYPE, 0)); -} \ No newline at end of file +} diff --git a/internal/core/unittest/test_scalar_index.cpp b/internal/core/unittest/test_scalar_index.cpp index 8b11c89530e9b..2fc943b57b505 100644 --- a/internal/core/unittest/test_scalar_index.cpp +++ b/internal/core/unittest/test_scalar_index.cpp @@ -382,7 +382,7 @@ TYPED_TEST_P(TypedScalarIndexTestV2, Base) { auto new_scalar_index = dynamic_cast*>(new_index.get()); new_scalar_index->LoadV2(); - ASSERT_EQ(nb, scalar_index->Count()); + ASSERT_EQ(nb, new_scalar_index->Count()); } } diff --git a/internal/core/unittest/test_utils/indexbuilder_test_utils.h b/internal/core/unittest/test_utils/indexbuilder_test_utils.h index 2c4a283cd039a..8581c0453c8d0 100644 --- a/internal/core/unittest/test_utils/indexbuilder_test_utils.h +++ b/internal/core/unittest/test_utils/indexbuilder_test_utils.h @@ -478,26 +478,30 @@ GenDsFromPB(const google::protobuf::Message& msg) { template inline std::vector 
GetIndexTypes() { - return std::vector{"sort"}; + return std::vector{"sort", milvus::index::BITMAP_INDEX_TYPE}; } template <> inline std::vector GetIndexTypes() { - return std::vector{"sort", "marisa"}; + return std::vector{ + "sort", "marisa", milvus::index::BITMAP_INDEX_TYPE}; } template inline std::vector GetIndexTypesV2() { - return std::vector{"sort", milvus::index::INVERTED_INDEX_TYPE}; + return std::vector{"sort", + milvus::index::INVERTED_INDEX_TYPE, + milvus::index::BITMAP_INDEX_TYPE}; } template <> inline std::vector GetIndexTypesV2() { - return std::vector{milvus::index::INVERTED_INDEX_TYPE, - "marisa"}; + return std::vector{"marisa", + milvus::index::INVERTED_INDEX_TYPE, + milvus::index::BITMAP_INDEX_TYPE}; } } // namespace diff --git a/internal/datacoord/cluster.go b/internal/datacoord/cluster.go index ee07d7be41d81..b47142b1084b2 100644 --- a/internal/datacoord/cluster.go +++ b/internal/datacoord/cluster.go @@ -19,6 +19,7 @@ package datacoord import ( "context" "fmt" + "sync" "github.com/samber/lo" "go.uber.org/zap" @@ -31,6 +32,8 @@ import ( ) // Cluster provides interfaces to interact with datanode cluster +// +//go:generate mockery --name=Cluster --structname=MockCluster --output=./ --filename=mock_cluster.go --with-expecter --inpackage type Cluster interface { Startup(ctx context.Context, nodes []*NodeInfo) error Register(node *NodeInfo) error @@ -43,6 +46,7 @@ type Cluster interface { QueryPreImport(nodeID int64, in *datapb.QueryPreImportRequest) (*datapb.QueryPreImportResponse, error) QueryImport(nodeID int64, in *datapb.QueryImportRequest) (*datapb.QueryImportResponse, error) DropImport(nodeID int64, in *datapb.DropImportRequest) error + QuerySlots() map[int64]int64 GetSessions() []*Session Close() } @@ -175,6 +179,30 @@ func (c *ClusterImpl) DropImport(nodeID int64, in *datapb.DropImportRequest) err return c.sessionManager.DropImport(nodeID, in) } +func (c *ClusterImpl) QuerySlots() map[int64]int64 { + nodeIDs := 
c.sessionManager.GetSessionIDs() + nodeSlots := make(map[int64]int64) + mu := &sync.Mutex{} + wg := &sync.WaitGroup{} + for _, nodeID := range nodeIDs { + wg.Add(1) + go func(nodeID int64) { + defer wg.Done() + resp, err := c.sessionManager.QuerySlot(nodeID) + if err != nil { + log.Warn("query slot failed", zap.Int64("nodeID", nodeID), zap.Error(err)) + return + } + mu.Lock() + defer mu.Unlock() + nodeSlots[nodeID] = resp.GetNumSlots() + }(nodeID) + } + wg.Wait() + log.Debug("query slot done", zap.Any("nodeSlots", nodeSlots)) + return nodeSlots +} + // GetSessions returns all sessions func (c *ClusterImpl) GetSessions() []*Session { return c.sessionManager.GetSessions() diff --git a/internal/datacoord/cluster_test.go b/internal/datacoord/cluster_test.go index 145cd11662098..fee6d1938673d 100644 --- a/internal/datacoord/cluster_test.go +++ b/internal/datacoord/cluster_test.go @@ -20,6 +20,7 @@ import ( "context" "testing" + "github.com/cockroachdb/errors" "github.com/samber/lo" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -28,6 +29,7 @@ import ( "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/kv/mocks" + "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/util/testutils" ) @@ -175,3 +177,29 @@ func (suite *ClusterSuite) TestFlushChannels() { suite.NoError(err) }) } + +func (suite *ClusterSuite) TestQuerySlot() { + suite.Run("query slot failed", func() { + suite.SetupTest() + suite.mockSession.EXPECT().GetSessionIDs().Return([]int64{1}).Once() + suite.mockSession.EXPECT().QuerySlot(int64(1)).Return(nil, errors.New("mock err")).Once() + cluster := NewClusterImpl(suite.mockSession, suite.mockChManager) + nodeSlots := cluster.QuerySlots() + suite.Equal(0, len(nodeSlots)) + }) + + suite.Run("normal", func() { + suite.SetupTest() + suite.mockSession.EXPECT().GetSessionIDs().Return([]int64{1, 2, 3, 4}).Once() + 
suite.mockSession.EXPECT().QuerySlot(int64(1)).Return(&datapb.QuerySlotResponse{NumSlots: 1}, nil).Once() + suite.mockSession.EXPECT().QuerySlot(int64(2)).Return(&datapb.QuerySlotResponse{NumSlots: 2}, nil).Once() + suite.mockSession.EXPECT().QuerySlot(int64(3)).Return(&datapb.QuerySlotResponse{NumSlots: 3}, nil).Once() + suite.mockSession.EXPECT().QuerySlot(int64(4)).Return(&datapb.QuerySlotResponse{NumSlots: 4}, nil).Once() + cluster := NewClusterImpl(suite.mockSession, suite.mockChManager) + nodeSlots := cluster.QuerySlots() + suite.Equal(int64(1), nodeSlots[1]) + suite.Equal(int64(2), nodeSlots[2]) + suite.Equal(int64(3), nodeSlots[3]) + suite.Equal(int64(4), nodeSlots[4]) + }) +} diff --git a/internal/datacoord/compaction.go b/internal/datacoord/compaction.go index fef788b659035..d4d7a0ef82751 100644 --- a/internal/datacoord/compaction.go +++ b/internal/datacoord/compaction.go @@ -127,7 +127,7 @@ type compactionPlanHandler struct { stopWg sync.WaitGroup } -func newCompactionPlanHandler(sessions SessionManager, cm ChannelManager, meta CompactionMeta, allocator allocator, +func newCompactionPlanHandler(cluster Cluster, sessions SessionManager, cm ChannelManager, meta CompactionMeta, allocator allocator, ) *compactionPlanHandler { return &compactionPlanHandler{ plans: make(map[int64]*compactionTask), @@ -135,7 +135,7 @@ func newCompactionPlanHandler(sessions SessionManager, cm ChannelManager, meta C meta: meta, sessions: sessions, allocator: allocator, - scheduler: NewCompactionScheduler(), + scheduler: NewCompactionScheduler(cluster), } } @@ -199,7 +199,7 @@ func (c *compactionPlanHandler) start() { // influence the schedule go func() { defer c.stopWg.Done() - scheduleTicker := time.NewTicker(200 * time.Millisecond) + scheduleTicker := time.NewTicker(2 * time.Second) defer scheduleTicker.Stop() log.Info("compaction handler start schedule") for { diff --git a/internal/datacoord/compaction_scheduler.go b/internal/datacoord/compaction_scheduler.go index 
89b8308dee50f..5e592d5e3033f 100644 --- a/internal/datacoord/compaction_scheduler.go +++ b/internal/datacoord/compaction_scheduler.go @@ -35,15 +35,17 @@ type CompactionScheduler struct { taskGuard lock.RWMutex planHandler *compactionPlanHandler + cluster Cluster } var _ Scheduler = (*CompactionScheduler)(nil) -func NewCompactionScheduler() *CompactionScheduler { +func NewCompactionScheduler(cluster Cluster) *CompactionScheduler { return &CompactionScheduler{ taskNumber: atomic.NewInt32(0), queuingTasks: make([]*compactionTask, 0), parallelTasks: make(map[int64][]*compactionTask), + cluster: cluster, } } @@ -62,22 +64,27 @@ func (s *CompactionScheduler) Submit(tasks ...*compactionTask) { // Schedule pick 1 or 0 tasks for 1 node func (s *CompactionScheduler) Schedule() []*compactionTask { - nodeTasks := make(map[int64][]*compactionTask) // nodeID - s.taskGuard.Lock() - defer s.taskGuard.Unlock() - for _, task := range s.queuingTasks { - if _, ok := nodeTasks[task.dataNodeID]; !ok { - nodeTasks[task.dataNodeID] = make([]*compactionTask, 0) - } - - nodeTasks[task.dataNodeID] = append(nodeTasks[task.dataNodeID], task) + nodeTasks := lo.GroupBy(s.queuingTasks, func(t *compactionTask) int64 { + return t.dataNodeID + }) + s.taskGuard.Unlock() + if len(nodeTasks) == 0 { + return nil // To mitigate the need for frequent slot querying } + nodeSlots := s.cluster.QuerySlots() + executable := make(map[int64]*compactionTask) pickPriorPolicy := func(tasks []*compactionTask, exclusiveChannels []string, executing []string) *compactionTask { for _, task := range tasks { + // TODO: sheep, replace pickShardNode with pickAnyNode + if nodeID := s.pickShardNode(task.dataNodeID, nodeSlots); nodeID == NullNodeID { + log.Warn("cannot find datanode for compaction task", zap.Int64("planID", task.plan.PlanID), zap.String("vchannel", task.plan.Channel)) + continue + } + if lo.Contains(exclusiveChannels, task.plan.GetChannel()) { continue } @@ -100,13 +107,11 @@ func (s *CompactionScheduler) 
Schedule() []*compactionTask { return nil } + s.taskGuard.Lock() + defer s.taskGuard.Unlock() // pick 1 or 0 task for 1 node for node, tasks := range nodeTasks { parallel := s.parallelTasks[node] - if len(parallel) >= calculateParallel() { - log.Info("Compaction parallel in DataNode reaches the limit", zap.Int64("nodeID", node), zap.Int("parallel", len(parallel))) - continue - } var ( executing = typeutil.NewSet[string]() @@ -122,6 +127,7 @@ func (s *CompactionScheduler) Schedule() []*compactionTask { picked := pickPriorPolicy(tasks, channelsExecPrior.Collect(), executing.Collect()) if picked != nil { executable[node] = picked + nodeSlots[node]-- } } @@ -211,3 +217,24 @@ func (s *CompactionScheduler) LogStatus() { func (s *CompactionScheduler) GetTaskCount() int { return int(s.taskNumber.Load()) } + +func (s *CompactionScheduler) pickAnyNode(nodeSlots map[int64]int64) int64 { + var ( + nodeID int64 = NullNodeID + maxSlots int64 = -1 + ) + for id, slots := range nodeSlots { + if slots > 0 && slots > maxSlots { + nodeID = id + maxSlots = slots + } + } + return nodeID +} + +func (s *CompactionScheduler) pickShardNode(nodeID int64, nodeSlots map[int64]int64) int64 { + if nodeSlots[nodeID] > 0 { + return nodeID + } + return NullNodeID +} diff --git a/internal/datacoord/compaction_scheduler_test.go b/internal/datacoord/compaction_scheduler_test.go index b96c76e7e1433..37f64f740b2f7 100644 --- a/internal/datacoord/compaction_scheduler_test.go +++ b/internal/datacoord/compaction_scheduler_test.go @@ -22,7 +22,8 @@ type SchedulerSuite struct { } func (s *SchedulerSuite) SetupTest() { - s.scheduler = NewCompactionScheduler() + cluster := NewMockCluster(s.T()) + s.scheduler = NewCompactionScheduler(cluster) s.scheduler.parallelTasks = map[int64][]*compactionTask{ 100: { {dataNodeID: 100, plan: &datapb.CompactionPlan{PlanID: 1, Channel: "ch-1", Type: datapb.CompactionType_MixCompaction}}, @@ -39,7 +40,8 @@ func (s *SchedulerSuite) SetupTest() { } func (s *SchedulerSuite) 
TestScheduleEmpty() { - emptySch := NewCompactionScheduler() + cluster := NewMockCluster(s.T()) + emptySch := NewCompactionScheduler(cluster) tasks := emptySch.Schedule() s.Empty(tasks) @@ -72,6 +74,12 @@ func (s *SchedulerSuite) TestScheduleParallelTaskFull() { s.SetupTest() s.Require().Equal(4, s.scheduler.GetTaskCount()) + if len(test.tasks) > 0 { + cluster := NewMockCluster(s.T()) + cluster.EXPECT().QuerySlots().Return(map[int64]int64{100: 0}) + s.scheduler.cluster = cluster + } + // submit the testing tasks s.scheduler.Submit(test.tasks...) s.Equal(4+len(test.tasks), s.scheduler.GetTaskCount()) @@ -111,6 +119,12 @@ func (s *SchedulerSuite) TestScheduleNodeWith1ParallelTask() { s.SetupTest() s.Require().Equal(4, s.scheduler.GetTaskCount()) + if len(test.tasks) > 0 { + cluster := NewMockCluster(s.T()) + cluster.EXPECT().QuerySlots().Return(map[int64]int64{101: 2}) + s.scheduler.cluster = cluster + } + // submit the testing tasks s.scheduler.Submit(test.tasks...) s.Equal(4+len(test.tasks), s.scheduler.GetTaskCount()) @@ -120,7 +134,12 @@ func (s *SchedulerSuite) TestScheduleNodeWith1ParallelTask() { return t.plan.PlanID })) - // the second schedule returns empty for full paralleTasks + // the second schedule returns empty for no slot + if len(test.tasks) > 0 { + cluster := NewMockCluster(s.T()) + cluster.EXPECT().QuerySlots().Return(map[int64]int64{101: 0}) + s.scheduler.cluster = cluster + } gotTasks = s.scheduler.Schedule() s.Empty(gotTasks) @@ -158,6 +177,12 @@ func (s *SchedulerSuite) TestScheduleNodeWithL0Executing() { s.SetupTest() s.Require().Equal(4, s.scheduler.GetTaskCount()) + if len(test.tasks) > 0 { + cluster := NewMockCluster(s.T()) + cluster.EXPECT().QuerySlots().Return(map[int64]int64{102: 2}) + s.scheduler.cluster = cluster + } + // submit the testing tasks s.scheduler.Submit(test.tasks...) 
s.Equal(4+len(test.tasks), s.scheduler.GetTaskCount()) @@ -167,7 +192,12 @@ func (s *SchedulerSuite) TestScheduleNodeWithL0Executing() { return t.plan.PlanID })) - // the second schedule returns empty for full paralleTasks + // the second schedule returns empty for no slot + if len(test.tasks) > 0 { + cluster := NewMockCluster(s.T()) + cluster.EXPECT().QuerySlots().Return(map[int64]int64{101: 0}) + s.scheduler.cluster = cluster + } if len(gotTasks) > 0 { gotTasks = s.scheduler.Schedule() s.Empty(gotTasks) @@ -215,3 +245,17 @@ func (s *SchedulerSuite) TestFinish() { s.MetricsEqual(taskNum, 1) }) } + +func (s *SchedulerSuite) TestPickNode() { + s.Run("test pickAnyNode", func() { + nodeSlots := map[int64]int64{ + 100: 2, + 101: 6, + } + node := s.scheduler.pickAnyNode(nodeSlots) + s.Equal(int64(101), node) + + node = s.scheduler.pickAnyNode(map[int64]int64{}) + s.Equal(int64(NullNodeID), node) + }) +} diff --git a/internal/datacoord/compaction_test.go b/internal/datacoord/compaction_test.go index e21f06f29f626..0936e7f8adf24 100644 --- a/internal/datacoord/compaction_test.go +++ b/internal/datacoord/compaction_test.go @@ -57,7 +57,7 @@ func (s *CompactionPlanHandlerSuite) SetupTest() { func (s *CompactionPlanHandlerSuite) TestRemoveTasksByChannel() { s.mockSch.EXPECT().Finish(mock.Anything, mock.Anything).Return().Once() - handler := newCompactionPlanHandler(nil, nil, nil, nil) + handler := newCompactionPlanHandler(nil, nil, nil, nil, nil) handler.scheduler = s.mockSch var ch string = "ch1" @@ -87,13 +87,13 @@ func (s *CompactionPlanHandlerSuite) TestCheckResult() { s.mockSessMgr.EXPECT().SyncSegments(int64(100), mock.Anything).Return(nil).Once() { s.mockAlloc.EXPECT().allocTimestamp(mock.Anything).Return(0, errors.New("mock")).Once() - handler := newCompactionPlanHandler(s.mockSessMgr, nil, nil, s.mockAlloc) + handler := newCompactionPlanHandler(nil, s.mockSessMgr, nil, nil, s.mockAlloc) handler.checkResult() } { 
s.mockAlloc.EXPECT().allocTimestamp(mock.Anything).Return(19530, nil).Once() - handler := newCompactionPlanHandler(s.mockSessMgr, nil, nil, s.mockAlloc) + handler := newCompactionPlanHandler(nil, s.mockSessMgr, nil, nil, s.mockAlloc) handler.checkResult() } } @@ -195,7 +195,7 @@ func (s *CompactionPlanHandlerSuite) TestHandleL0CompactionResults() { }, } - handler := newCompactionPlanHandler(nil, nil, s.mockMeta, s.mockAlloc) + handler := newCompactionPlanHandler(nil, nil, nil, s.mockMeta, s.mockAlloc) err := handler.handleL0CompactionResult(plan, result) s.NoError(err) } @@ -258,7 +258,7 @@ func (s *CompactionPlanHandlerSuite) TestRefreshL0Plan() { dataNodeID: 1, } - handler := newCompactionPlanHandler(nil, nil, s.mockMeta, s.mockAlloc) + handler := newCompactionPlanHandler(nil, nil, nil, s.mockMeta, s.mockAlloc) err := handler.RefreshPlan(task) s.Require().NoError(err) @@ -293,7 +293,7 @@ func (s *CompactionPlanHandlerSuite) TestRefreshL0Plan() { dataNodeID: 1, } - handler := newCompactionPlanHandler(nil, nil, s.mockMeta, s.mockAlloc) + handler := newCompactionPlanHandler(nil, nil, nil, s.mockMeta, s.mockAlloc) err := handler.RefreshPlan(task) s.Error(err) s.ErrorIs(err, merr.ErrSegmentNotFound) @@ -337,7 +337,7 @@ func (s *CompactionPlanHandlerSuite) TestRefreshL0Plan() { dataNodeID: 1, } - handler := newCompactionPlanHandler(nil, nil, s.mockMeta, s.mockAlloc) + handler := newCompactionPlanHandler(nil, nil, nil, s.mockMeta, s.mockAlloc) err := handler.RefreshPlan(task) s.Error(err) }) @@ -382,7 +382,7 @@ func (s *CompactionPlanHandlerSuite) TestRefreshPlanMixCompaction() { dataNodeID: 1, } - handler := newCompactionPlanHandler(nil, nil, s.mockMeta, s.mockAlloc) + handler := newCompactionPlanHandler(nil, nil, nil, s.mockMeta, s.mockAlloc) err := handler.RefreshPlan(task) s.Require().NoError(err) @@ -423,7 +423,7 @@ func (s *CompactionPlanHandlerSuite) TestRefreshPlanMixCompaction() { dataNodeID: 1, } - handler := newCompactionPlanHandler(nil, nil, s.mockMeta, 
s.mockAlloc) + handler := newCompactionPlanHandler(nil, nil, nil, s.mockMeta, s.mockAlloc) err := handler.RefreshPlan(task) s.Error(err) s.ErrorIs(err, merr.ErrSegmentNotFound) @@ -449,7 +449,7 @@ func (s *CompactionPlanHandlerSuite) TestExecCompactionPlan() { {"channel with no error", "ch-2", false}, } - handler := newCompactionPlanHandler(s.mockSessMgr, s.mockCm, s.mockMeta, s.mockAlloc) + handler := newCompactionPlanHandler(nil, s.mockSessMgr, s.mockCm, s.mockMeta, s.mockAlloc) handler.scheduler = s.mockSch for idx, test := range tests { @@ -482,7 +482,7 @@ func (s *CompactionPlanHandlerSuite) TestHandleMergeCompactionResult() { s.Run("illegal nil result", func() { s.SetupTest() - handler := newCompactionPlanHandler(s.mockSessMgr, s.mockCm, s.mockMeta, s.mockAlloc) + handler := newCompactionPlanHandler(nil, s.mockSessMgr, s.mockCm, s.mockMeta, s.mockAlloc) err := handler.handleMergeCompactionResult(nil, nil) s.Error(err) }) @@ -498,7 +498,7 @@ func (s *CompactionPlanHandlerSuite) TestHandleMergeCompactionResult() { }).Once() s.mockSessMgr.EXPECT().SyncSegments(mock.Anything, mock.Anything).Return(nil).Once() - handler := newCompactionPlanHandler(s.mockSessMgr, s.mockCm, s.mockMeta, s.mockAlloc) + handler := newCompactionPlanHandler(nil, s.mockSessMgr, s.mockCm, s.mockMeta, s.mockAlloc) handler.plans[plan.PlanID] = &compactionTask{dataNodeID: 111, plan: plan} compactionResult := &datapb.CompactionPlanResult{ @@ -518,7 +518,7 @@ func (s *CompactionPlanHandlerSuite) TestHandleMergeCompactionResult() { s.mockMeta.EXPECT().CompleteCompactionMutation(mock.Anything, mock.Anything).Return( nil, nil, errors.New("mock error")).Once() - handler := newCompactionPlanHandler(s.mockSessMgr, s.mockCm, s.mockMeta, s.mockAlloc) + handler := newCompactionPlanHandler(nil, s.mockSessMgr, s.mockCm, s.mockMeta, s.mockAlloc) handler.plans[plan.PlanID] = &compactionTask{dataNodeID: 111, plan: plan} compactionResult := &datapb.CompactionPlanResult{ PlanID: plan.PlanID, @@ -540,7 +540,7 
@@ func (s *CompactionPlanHandlerSuite) TestHandleMergeCompactionResult() { &segMetricMutation{}, nil).Once() s.mockSessMgr.EXPECT().SyncSegments(mock.Anything, mock.Anything).Return(errors.New("mock error")).Once() - handler := newCompactionPlanHandler(s.mockSessMgr, s.mockCm, s.mockMeta, s.mockAlloc) + handler := newCompactionPlanHandler(nil, s.mockSessMgr, s.mockCm, s.mockMeta, s.mockAlloc) handler.plans[plan.PlanID] = &compactionTask{dataNodeID: 111, plan: plan} compactionResult := &datapb.CompactionPlanResult{ PlanID: plan.PlanID, @@ -556,7 +556,7 @@ func (s *CompactionPlanHandlerSuite) TestHandleMergeCompactionResult() { func (s *CompactionPlanHandlerSuite) TestCompleteCompaction() { s.Run("test not exists compaction task", func() { - handler := newCompactionPlanHandler(nil, nil, nil, nil) + handler := newCompactionPlanHandler(nil, nil, nil, nil, nil) err := handler.completeCompaction(&datapb.CompactionPlanResult{PlanID: 2}) s.Error(err) }) @@ -636,7 +636,7 @@ func (s *CompactionPlanHandlerSuite) TestCompleteCompaction() { }, } - c := newCompactionPlanHandler(s.mockSessMgr, s.mockCm, s.mockMeta, s.mockAlloc) + c := newCompactionPlanHandler(nil, s.mockSessMgr, s.mockCm, s.mockMeta, s.mockAlloc) c.scheduler = s.mockSch c.plans = plans @@ -734,7 +734,7 @@ func (s *CompactionPlanHandlerSuite) TestUpdateCompaction() { s.mockCm.EXPECT().Match(int64(111), "ch-1").Return(true) s.mockCm.EXPECT().Match(int64(111), "ch-2").Return(false).Once() - handler := newCompactionPlanHandler(s.mockSessMgr, s.mockCm, s.mockMeta, s.mockAlloc) + handler := newCompactionPlanHandler(nil, s.mockSessMgr, s.mockCm, s.mockMeta, s.mockAlloc) handler.plans = inPlans _, ok := handler.plans[5] diff --git a/internal/datacoord/compaction_trigger_test.go b/internal/datacoord/compaction_trigger_test.go index 417bf92b04d94..19d4146a65e14 100644 --- a/internal/datacoord/compaction_trigger_test.go +++ b/internal/datacoord/compaction_trigger_test.go @@ -2000,12 +2000,12 @@ func 
Test_compactionTrigger_new(t *testing.T) { } func Test_compactionTrigger_allocTs(t *testing.T) { - got := newCompactionTrigger(&meta{segments: NewSegmentsInfo()}, &compactionPlanHandler{scheduler: NewCompactionScheduler()}, newMockAllocator(), newMockHandler(), newMockVersionManager()) + got := newCompactionTrigger(&meta{segments: NewSegmentsInfo()}, &compactionPlanHandler{scheduler: NewCompactionScheduler(nil)}, newMockAllocator(), newMockHandler(), newMockVersionManager()) ts, err := got.allocTs() assert.NoError(t, err) assert.True(t, ts > 0) - got = newCompactionTrigger(&meta{segments: NewSegmentsInfo()}, &compactionPlanHandler{scheduler: NewCompactionScheduler()}, &FailsAllocator{}, newMockHandler(), newMockVersionManager()) + got = newCompactionTrigger(&meta{segments: NewSegmentsInfo()}, &compactionPlanHandler{scheduler: NewCompactionScheduler(nil)}, &FailsAllocator{}, newMockHandler(), newMockVersionManager()) ts, err = got.allocTs() assert.Error(t, err) assert.Equal(t, uint64(0), ts) @@ -2032,7 +2032,7 @@ func Test_compactionTrigger_getCompactTime(t *testing.T) { } m := &meta{segments: NewSegmentsInfo(), collections: collections} - got := newCompactionTrigger(m, &compactionPlanHandler{scheduler: NewCompactionScheduler()}, newMockAllocator(), + got := newCompactionTrigger(m, &compactionPlanHandler{scheduler: NewCompactionScheduler(nil)}, newMockAllocator(), &ServerHandler{ &Server{ meta: m, diff --git a/internal/datacoord/mock_cluster.go b/internal/datacoord/mock_cluster.go index 77a9b56633767..e35f1e1fee0ab 100644 --- a/internal/datacoord/mock_cluster.go +++ b/internal/datacoord/mock_cluster.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.30.1. DO NOT EDIT. 
package datacoord @@ -74,8 +74,8 @@ type MockCluster_DropImport_Call struct { } // DropImport is a helper method to define mock.On call -// - nodeID int64 -// - in *datapb.DropImportRequest +// - nodeID int64 +// - in *datapb.DropImportRequest func (_e *MockCluster_Expecter) DropImport(nodeID interface{}, in interface{}) *MockCluster_DropImport_Call { return &MockCluster_DropImport_Call{Call: _e.mock.On("DropImport", nodeID, in)} } @@ -117,10 +117,10 @@ type MockCluster_Flush_Call struct { } // Flush is a helper method to define mock.On call -// - ctx context.Context -// - nodeID int64 -// - channel string -// - segments []*datapb.SegmentInfo +// - ctx context.Context +// - nodeID int64 +// - channel string +// - segments []*datapb.SegmentInfo func (_e *MockCluster_Expecter) Flush(ctx interface{}, nodeID interface{}, channel interface{}, segments interface{}) *MockCluster_Flush_Call { return &MockCluster_Flush_Call{Call: _e.mock.On("Flush", ctx, nodeID, channel, segments)} } @@ -162,10 +162,10 @@ type MockCluster_FlushChannels_Call struct { } // FlushChannels is a helper method to define mock.On call -// - ctx context.Context -// - nodeID int64 -// - flushTs uint64 -// - channels []string +// - ctx context.Context +// - nodeID int64 +// - flushTs uint64 +// - channels []string func (_e *MockCluster_Expecter) FlushChannels(ctx interface{}, nodeID interface{}, flushTs interface{}, channels interface{}) *MockCluster_FlushChannels_Call { return &MockCluster_FlushChannels_Call{Call: _e.mock.On("FlushChannels", ctx, nodeID, flushTs, channels)} } @@ -250,8 +250,8 @@ type MockCluster_ImportV2_Call struct { } // ImportV2 is a helper method to define mock.On call -// - nodeID int64 -// - in *datapb.ImportRequest +// - nodeID int64 +// - in *datapb.ImportRequest func (_e *MockCluster_Expecter) ImportV2(nodeID interface{}, in interface{}) *MockCluster_ImportV2_Call { return &MockCluster_ImportV2_Call{Call: _e.mock.On("ImportV2", nodeID, in)} } @@ -293,8 +293,8 @@ type 
MockCluster_PreImport_Call struct { } // PreImport is a helper method to define mock.On call -// - nodeID int64 -// - in *datapb.PreImportRequest +// - nodeID int64 +// - in *datapb.PreImportRequest func (_e *MockCluster_Expecter) PreImport(nodeID interface{}, in interface{}) *MockCluster_PreImport_Call { return &MockCluster_PreImport_Call{Call: _e.mock.On("PreImport", nodeID, in)} } @@ -348,8 +348,8 @@ type MockCluster_QueryImport_Call struct { } // QueryImport is a helper method to define mock.On call -// - nodeID int64 -// - in *datapb.QueryImportRequest +// - nodeID int64 +// - in *datapb.QueryImportRequest func (_e *MockCluster_Expecter) QueryImport(nodeID interface{}, in interface{}) *MockCluster_QueryImport_Call { return &MockCluster_QueryImport_Call{Call: _e.mock.On("QueryImport", nodeID, in)} } @@ -403,8 +403,8 @@ type MockCluster_QueryPreImport_Call struct { } // QueryPreImport is a helper method to define mock.On call -// - nodeID int64 -// - in *datapb.QueryPreImportRequest +// - nodeID int64 +// - in *datapb.QueryPreImportRequest func (_e *MockCluster_Expecter) QueryPreImport(nodeID interface{}, in interface{}) *MockCluster_QueryPreImport_Call { return &MockCluster_QueryPreImport_Call{Call: _e.mock.On("QueryPreImport", nodeID, in)} } @@ -426,6 +426,49 @@ func (_c *MockCluster_QueryPreImport_Call) RunAndReturn(run func(int64, *datapb. 
return _c } +// QuerySlots provides a mock function with given fields: +func (_m *MockCluster) QuerySlots() map[int64]int64 { + ret := _m.Called() + + var r0 map[int64]int64 + if rf, ok := ret.Get(0).(func() map[int64]int64); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[int64]int64) + } + } + + return r0 +} + +// MockCluster_QuerySlots_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QuerySlots' +type MockCluster_QuerySlots_Call struct { + *mock.Call +} + +// QuerySlots is a helper method to define mock.On call +func (_e *MockCluster_Expecter) QuerySlots() *MockCluster_QuerySlots_Call { + return &MockCluster_QuerySlots_Call{Call: _e.mock.On("QuerySlots")} +} + +func (_c *MockCluster_QuerySlots_Call) Run(run func()) *MockCluster_QuerySlots_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockCluster_QuerySlots_Call) Return(_a0 map[int64]int64) *MockCluster_QuerySlots_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCluster_QuerySlots_Call) RunAndReturn(run func() map[int64]int64) *MockCluster_QuerySlots_Call { + _c.Call.Return(run) + return _c +} + // Register provides a mock function with given fields: node func (_m *MockCluster) Register(node *NodeInfo) error { ret := _m.Called(node) @@ -446,7 +489,7 @@ type MockCluster_Register_Call struct { } // Register is a helper method to define mock.On call -// - node *NodeInfo +// - node *NodeInfo func (_e *MockCluster_Expecter) Register(node interface{}) *MockCluster_Register_Call { return &MockCluster_Register_Call{Call: _e.mock.On("Register", node)} } @@ -488,8 +531,8 @@ type MockCluster_Startup_Call struct { } // Startup is a helper method to define mock.On call -// - ctx context.Context -// - nodes []*NodeInfo +// - ctx context.Context +// - nodes []*NodeInfo func (_e *MockCluster_Expecter) Startup(ctx interface{}, nodes interface{}) *MockCluster_Startup_Call { return 
&MockCluster_Startup_Call{Call: _e.mock.On("Startup", ctx, nodes)} } @@ -531,7 +574,7 @@ type MockCluster_UnRegister_Call struct { } // UnRegister is a helper method to define mock.On call -// - node *NodeInfo +// - node *NodeInfo func (_e *MockCluster_Expecter) UnRegister(node interface{}) *MockCluster_UnRegister_Call { return &MockCluster_UnRegister_Call{Call: _e.mock.On("UnRegister", node)} } @@ -573,8 +616,8 @@ type MockCluster_Watch_Call struct { } // Watch is a helper method to define mock.On call -// - ctx context.Context -// - ch RWChannel +// - ctx context.Context +// - ch RWChannel func (_e *MockCluster_Expecter) Watch(ctx interface{}, ch interface{}) *MockCluster_Watch_Call { return &MockCluster_Watch_Call{Call: _e.mock.On("Watch", ctx, ch)} } diff --git a/internal/datacoord/mock_session_manager.go b/internal/datacoord/mock_session_manager.go index b35ead232a23d..a7d8e7f679c59 100644 --- a/internal/datacoord/mock_session_manager.go +++ b/internal/datacoord/mock_session_manager.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.30.1. DO NOT EDIT. 
package datacoord @@ -35,7 +35,7 @@ type MockSessionManager_AddSession_Call struct { } // AddSession is a helper method to define mock.On call -// - node *NodeInfo +// - node *NodeInfo func (_e *MockSessionManager_Expecter) AddSession(node interface{}) *MockSessionManager_AddSession_Call { return &MockSessionManager_AddSession_Call{Call: _e.mock.On("AddSession", node)} } @@ -89,9 +89,9 @@ type MockSessionManager_CheckChannelOperationProgress_Call struct { } // CheckChannelOperationProgress is a helper method to define mock.On call -// - ctx context.Context -// - nodeID int64 -// - info *datapb.ChannelWatchInfo +// - ctx context.Context +// - nodeID int64 +// - info *datapb.ChannelWatchInfo func (_e *MockSessionManager_Expecter) CheckChannelOperationProgress(ctx interface{}, nodeID interface{}, info interface{}) *MockSessionManager_CheckChannelOperationProgress_Call { return &MockSessionManager_CheckChannelOperationProgress_Call{Call: _e.mock.On("CheckChannelOperationProgress", ctx, nodeID, info)} } @@ -133,7 +133,7 @@ type MockSessionManager_CheckHealth_Call struct { } // CheckHealth is a helper method to define mock.On call -// - ctx context.Context +// - ctx context.Context func (_e *MockSessionManager_Expecter) CheckHealth(ctx interface{}) *MockSessionManager_CheckHealth_Call { return &MockSessionManager_CheckHealth_Call{Call: _e.mock.On("CheckHealth", ctx)} } @@ -207,9 +207,9 @@ type MockSessionManager_Compaction_Call struct { } // Compaction is a helper method to define mock.On call -// - ctx context.Context -// - nodeID int64 -// - plan *datapb.CompactionPlan +// - ctx context.Context +// - nodeID int64 +// - plan *datapb.CompactionPlan func (_e *MockSessionManager_Expecter) Compaction(ctx interface{}, nodeID interface{}, plan interface{}) *MockSessionManager_Compaction_Call { return &MockSessionManager_Compaction_Call{Call: _e.mock.On("Compaction", ctx, nodeID, plan)} } @@ -242,7 +242,7 @@ type MockSessionManager_DeleteSession_Call struct { } // 
DeleteSession is a helper method to define mock.On call -// - node *NodeInfo +// - node *NodeInfo func (_e *MockSessionManager_Expecter) DeleteSession(node interface{}) *MockSessionManager_DeleteSession_Call { return &MockSessionManager_DeleteSession_Call{Call: _e.mock.On("DeleteSession", node)} } @@ -284,8 +284,8 @@ type MockSessionManager_DropImport_Call struct { } // DropImport is a helper method to define mock.On call -// - nodeID int64 -// - in *datapb.DropImportRequest +// - nodeID int64 +// - in *datapb.DropImportRequest func (_e *MockSessionManager_Expecter) DropImport(nodeID interface{}, in interface{}) *MockSessionManager_DropImport_Call { return &MockSessionManager_DropImport_Call{Call: _e.mock.On("DropImport", nodeID, in)} } @@ -318,9 +318,9 @@ type MockSessionManager_Flush_Call struct { } // Flush is a helper method to define mock.On call -// - ctx context.Context -// - nodeID int64 -// - req *datapb.FlushSegmentsRequest +// - ctx context.Context +// - nodeID int64 +// - req *datapb.FlushSegmentsRequest func (_e *MockSessionManager_Expecter) Flush(ctx interface{}, nodeID interface{}, req interface{}) *MockSessionManager_Flush_Call { return &MockSessionManager_Flush_Call{Call: _e.mock.On("Flush", ctx, nodeID, req)} } @@ -362,9 +362,9 @@ type MockSessionManager_FlushChannels_Call struct { } // FlushChannels is a helper method to define mock.On call -// - ctx context.Context -// - nodeID int64 -// - req *datapb.FlushChannelsRequest +// - ctx context.Context +// - nodeID int64 +// - req *datapb.FlushChannelsRequest func (_e *MockSessionManager_Expecter) FlushChannels(ctx interface{}, nodeID interface{}, req interface{}) *MockSessionManager_FlushChannels_Call { return &MockSessionManager_FlushChannels_Call{Call: _e.mock.On("FlushChannels", ctx, nodeID, req)} } @@ -545,8 +545,8 @@ type MockSessionManager_ImportV2_Call struct { } // ImportV2 is a helper method to define mock.On call -// - nodeID int64 -// - in *datapb.ImportRequest +// - nodeID int64 +// - in 
*datapb.ImportRequest func (_e *MockSessionManager_Expecter) ImportV2(nodeID interface{}, in interface{}) *MockSessionManager_ImportV2_Call { return &MockSessionManager_ImportV2_Call{Call: _e.mock.On("ImportV2", nodeID, in)} } @@ -588,9 +588,9 @@ type MockSessionManager_NotifyChannelOperation_Call struct { } // NotifyChannelOperation is a helper method to define mock.On call -// - ctx context.Context -// - nodeID int64 -// - req *datapb.ChannelOperationsRequest +// - ctx context.Context +// - nodeID int64 +// - req *datapb.ChannelOperationsRequest func (_e *MockSessionManager_Expecter) NotifyChannelOperation(ctx interface{}, nodeID interface{}, req interface{}) *MockSessionManager_NotifyChannelOperation_Call { return &MockSessionManager_NotifyChannelOperation_Call{Call: _e.mock.On("NotifyChannelOperation", ctx, nodeID, req)} } @@ -632,8 +632,8 @@ type MockSessionManager_PreImport_Call struct { } // PreImport is a helper method to define mock.On call -// - nodeID int64 -// - in *datapb.PreImportRequest +// - nodeID int64 +// - in *datapb.PreImportRequest func (_e *MockSessionManager_Expecter) PreImport(nodeID interface{}, in interface{}) *MockSessionManager_PreImport_Call { return &MockSessionManager_PreImport_Call{Call: _e.mock.On("PreImport", nodeID, in)} } @@ -687,8 +687,8 @@ type MockSessionManager_QueryImport_Call struct { } // QueryImport is a helper method to define mock.On call -// - nodeID int64 -// - in *datapb.QueryImportRequest +// - nodeID int64 +// - in *datapb.QueryImportRequest func (_e *MockSessionManager_Expecter) QueryImport(nodeID interface{}, in interface{}) *MockSessionManager_QueryImport_Call { return &MockSessionManager_QueryImport_Call{Call: _e.mock.On("QueryImport", nodeID, in)} } @@ -742,8 +742,8 @@ type MockSessionManager_QueryPreImport_Call struct { } // QueryPreImport is a helper method to define mock.On call -// - nodeID int64 -// - in *datapb.QueryPreImportRequest +// - nodeID int64 +// - in *datapb.QueryPreImportRequest func (_e 
*MockSessionManager_Expecter) QueryPreImport(nodeID interface{}, in interface{}) *MockSessionManager_QueryPreImport_Call { return &MockSessionManager_QueryPreImport_Call{Call: _e.mock.On("QueryPreImport", nodeID, in)} } @@ -765,6 +765,60 @@ func (_c *MockSessionManager_QueryPreImport_Call) RunAndReturn(run func(int64, * return _c } +// QuerySlot provides a mock function with given fields: nodeID +func (_m *MockSessionManager) QuerySlot(nodeID int64) (*datapb.QuerySlotResponse, error) { + ret := _m.Called(nodeID) + + var r0 *datapb.QuerySlotResponse + var r1 error + if rf, ok := ret.Get(0).(func(int64) (*datapb.QuerySlotResponse, error)); ok { + return rf(nodeID) + } + if rf, ok := ret.Get(0).(func(int64) *datapb.QuerySlotResponse); ok { + r0 = rf(nodeID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.QuerySlotResponse) + } + } + + if rf, ok := ret.Get(1).(func(int64) error); ok { + r1 = rf(nodeID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockSessionManager_QuerySlot_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QuerySlot' +type MockSessionManager_QuerySlot_Call struct { + *mock.Call +} + +// QuerySlot is a helper method to define mock.On call +// - nodeID int64 +func (_e *MockSessionManager_Expecter) QuerySlot(nodeID interface{}) *MockSessionManager_QuerySlot_Call { + return &MockSessionManager_QuerySlot_Call{Call: _e.mock.On("QuerySlot", nodeID)} +} + +func (_c *MockSessionManager_QuerySlot_Call) Run(run func(nodeID int64)) *MockSessionManager_QuerySlot_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockSessionManager_QuerySlot_Call) Return(_a0 *datapb.QuerySlotResponse, _a1 error) *MockSessionManager_QuerySlot_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSessionManager_QuerySlot_Call) RunAndReturn(run func(int64) (*datapb.QuerySlotResponse, error)) *MockSessionManager_QuerySlot_Call { + 
_c.Call.Return(run) + return _c +} + // SyncSegments provides a mock function with given fields: nodeID, req func (_m *MockSessionManager) SyncSegments(nodeID int64, req *datapb.SyncSegmentsRequest) error { ret := _m.Called(nodeID, req) @@ -785,8 +839,8 @@ type MockSessionManager_SyncSegments_Call struct { } // SyncSegments is a helper method to define mock.On call -// - nodeID int64 -// - req *datapb.SyncSegmentsRequest +// - nodeID int64 +// - req *datapb.SyncSegmentsRequest func (_e *MockSessionManager_Expecter) SyncSegments(nodeID interface{}, req interface{}) *MockSessionManager_SyncSegments_Call { return &MockSessionManager_SyncSegments_Call{Call: _e.mock.On("SyncSegments", nodeID, req)} } diff --git a/internal/datacoord/mock_test.go b/internal/datacoord/mock_test.go index 70c0c9482cd7d..bac8735fd394a 100644 --- a/internal/datacoord/mock_test.go +++ b/internal/datacoord/mock_test.go @@ -319,6 +319,10 @@ func (c *mockDataNodeClient) DropImport(ctx context.Context, req *datapb.DropImp return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil } +func (c *mockDataNodeClient) QuerySlot(ctx context.Context, req *datapb.QuerySlotRequest, opts ...grpc.CallOption) (*datapb.QuerySlotResponse, error) { + return &datapb.QuerySlotResponse{Status: merr.Success()}, nil +} + func (c *mockDataNodeClient) Stop() error { c.state = commonpb.StateCode_Abnormal return nil diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index 5662b47a1b74f..50ccc8d37ca58 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -523,7 +523,7 @@ func (s *Server) SetIndexNodeCreator(f func(context.Context, string, int64) (typ } func (s *Server) createCompactionHandler() { - s.compactionHandler = newCompactionPlanHandler(s.sessionManager, s.channelManager, s.meta, s.allocator) + s.compactionHandler = newCompactionPlanHandler(s.cluster, s.sessionManager, s.channelManager, s.meta, s.allocator) triggerv2 := NewCompactionTriggerManager(s.allocator, 
s.compactionHandler) s.compactionViewManager = NewCompactionViewManager(s.meta, triggerv2, s.allocator) } diff --git a/internal/datacoord/session_manager.go b/internal/datacoord/session_manager.go index 7b3e01892093f..14e4e23d5ea04 100644 --- a/internal/datacoord/session_manager.go +++ b/internal/datacoord/session_manager.go @@ -44,8 +44,10 @@ import ( const ( flushTimeout = 15 * time.Second importTaskTimeout = 10 * time.Second + querySlotTimeout = 10 * time.Second ) +//go:generate mockery --name=SessionManager --structname=MockSessionManager --output=./ --filename=mock_session_manager.go --with-expecter --inpackage type SessionManager interface { AddSession(node *NodeInfo) DeleteSession(node *NodeInfo) @@ -65,6 +67,7 @@ type SessionManager interface { QueryImport(nodeID int64, in *datapb.QueryImportRequest) (*datapb.QueryImportResponse, error) DropImport(nodeID int64, in *datapb.DropImportRequest) error CheckHealth(ctx context.Context) error + QuerySlot(nodeID int64) (*datapb.QuerySlotResponse, error) Close() } @@ -474,6 +477,22 @@ func (c *SessionManagerImpl) CheckHealth(ctx context.Context) error { return group.Wait() } +func (c *SessionManagerImpl) QuerySlot(nodeID int64) (*datapb.QuerySlotResponse, error) { + log := log.With(zap.Int64("nodeID", nodeID)) + ctx, cancel := context.WithTimeout(context.Background(), querySlotTimeout) + defer cancel() + cli, err := c.getClient(ctx, nodeID) + if err != nil { + log.Info("failed to get client", zap.Error(err)) + return nil, err + } + resp, err := cli.QuerySlot(ctx, &datapb.QuerySlotRequest{}) + if err = VerifyResponse(resp.GetStatus(), err); err != nil { + return nil, err + } + return resp, nil +} + // Close release sessions func (c *SessionManagerImpl) Close() { c.sessions.Lock() diff --git a/internal/datanode/binlog_io.go b/internal/datanode/binlog_io.go index c6ff5425f6d57..506c614a1c9b9 100644 --- a/internal/datanode/binlog_io.go +++ b/internal/datanode/binlog_io.go @@ -80,22 +80,18 @@ func genDeltaBlobs(b 
io.BinlogIO, allocator allocator.Allocator, data *DeleteDat } // genInsertBlobs returns insert-paths and save blob to kvs -func genInsertBlobs(b io.BinlogIO, allocator allocator.Allocator, data *InsertData, collectionID, partID, segID UniqueID, iCodec *storage.InsertCodec, kvs map[string][]byte) (map[UniqueID]*datapb.FieldBinlog, error) { - inlogs, err := iCodec.Serialize(partID, segID, data) - if err != nil { - return nil, err - } - +func genInsertBlobs(b io.BinlogIO, allocator allocator.Allocator, data []*Blob, collectionID, partID, segID UniqueID, kvs map[string][]byte, +) (map[UniqueID]*datapb.FieldBinlog, error) { inpaths := make(map[UniqueID]*datapb.FieldBinlog) notifyGenIdx := make(chan struct{}) defer close(notifyGenIdx) - generator, err := allocator.GetGenerator(len(inlogs), notifyGenIdx) + generator, err := allocator.GetGenerator(len(data), notifyGenIdx) if err != nil { return nil, err } - for _, blob := range inlogs { + for _, blob := range data { // Blob Key is generated by Serialize from int64 fieldID in collection schema, which won't raise error in ParseInt fID, _ := strconv.ParseInt(blob.GetKey(), 10, 64) k := metautil.JoinIDPath(collectionID, partID, segID, fID, <-generator) @@ -177,22 +173,21 @@ func uploadInsertLog( collectionID UniqueID, partID UniqueID, segID UniqueID, - iData *InsertData, - iCodec *storage.InsertCodec, + data []*Blob, ) (map[UniqueID]*datapb.FieldBinlog, error) { ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "UploadInsertLog") defer span.End() kvs := make(map[string][]byte) - if iData.IsEmpty() { + if len(data) <= 0 || data[0].RowNum <= 0 { log.Warn("binlog io uploading empty insert data", zap.Int64("segmentID", segID), - zap.Int64("collectionID", iCodec.Schema.GetID()), + zap.Int64("collectionID", collectionID), ) return nil, nil } - inpaths, err := genInsertBlobs(b, allocator, iData, collectionID, partID, segID, iCodec, kvs) + inpaths, err := genInsertBlobs(b, allocator, data, collectionID, partID, segID, kvs) if 
err != nil { return nil, err } diff --git a/internal/datanode/binlog_io_test.go b/internal/datanode/binlog_io_test.go index eea1b18291e81..038978ac0464c 100644 --- a/internal/datanode/binlog_io_test.go +++ b/internal/datanode/binlog_io_test.go @@ -124,21 +124,17 @@ func TestBinlogIOInterfaceMethods(t *testing.T) { f := &MetaFactory{} meta := f.GetCollectionMeta(UniqueID(10001), "test_gen_blobs", schemapb.DataType_Int64) - t.Run("empty insert", func(t *testing.T) { - alloc := allocator.NewMockAllocator(t) - binlogIO := io.NewBinlogIO(cm, getOrCreateIOPool()) - iCodec := storage.NewInsertCodecWithSchema(meta) - paths, err := uploadInsertLog(context.Background(), binlogIO, alloc, meta.GetID(), 10, 1, genEmptyInsertData(), iCodec) - assert.NoError(t, err) - assert.Nil(t, paths) - }) - t.Run("gen insert blob failed", func(t *testing.T) { alloc := allocator.NewMockAllocator(t) binlogIO := io.NewBinlogIO(cm, getOrCreateIOPool()) iCodec := storage.NewInsertCodecWithSchema(meta) + var partId int64 = 10 + var segId int64 = 1 + iData := genInsertData(2) + blobs, err := iCodec.Serialize(10, 1, iData) + assert.NoError(t, err) alloc.EXPECT().GetGenerator(mock.Anything, mock.Anything).Call.Return(nil, fmt.Errorf("mock err")) - _, err := uploadInsertLog(context.Background(), binlogIO, alloc, meta.GetID(), 10, 1, genInsertData(2), iCodec) + _, err = uploadInsertLog(context.Background(), binlogIO, alloc, meta.GetID(), partId, segId, blobs) assert.Error(t, err) }) @@ -147,13 +143,18 @@ func TestBinlogIOInterfaceMethods(t *testing.T) { alloc := allocator.NewMockAllocator(t) binlogIO := io.NewBinlogIO(mkc, getOrCreateIOPool()) iCodec := storage.NewInsertCodecWithSchema(meta) + var partId int64 = 1 + var segId int64 = 10 + iData := genInsertData(2) + blobs, err := iCodec.Serialize(10, 1, iData) + assert.NoError(t, err) alloc.EXPECT().GetGenerator(mock.Anything, mock.Anything).Call.Return(validGeneratorFn, nil) ctx, cancel := context.WithTimeout(context.Background(), 
100*time.Millisecond) defer cancel() - _, err := uploadInsertLog(ctx, binlogIO, alloc, meta.GetID(), 1, 10, genInsertData(2), iCodec) + _, err = uploadInsertLog(ctx, binlogIO, alloc, meta.GetID(), partId, segId, blobs) assert.Error(t, err) }) }) @@ -256,9 +257,13 @@ func TestBinlogIOInnerMethods(t *testing.T) { t.Run(test.description, func(t *testing.T) { meta := f.GetCollectionMeta(UniqueID(10001), "test_gen_blobs", test.pkType) iCodec := storage.NewInsertCodecWithSchema(meta) - + var partId int64 = 10 + var segId int64 = 1 + iData := genInsertData(2) + blobs, err := iCodec.Serialize(10, 1, iData) + assert.NoError(t, err) kvs := make(map[string][]byte) - pin, err := genInsertBlobs(binlogIO, alloc, genInsertData(2), meta.GetID(), 10, 1, iCodec, kvs) + pin, err := genInsertBlobs(binlogIO, alloc, blobs, meta.GetID(), partId, segId, kvs) assert.NoError(t, err) assert.Equal(t, 12, len(pin)) @@ -277,30 +282,22 @@ func TestBinlogIOInnerMethods(t *testing.T) { cm := storage.NewLocalChunkManager(storage.RootPath(binlogTestDir)) defer cm.RemoveWithPrefix(ctx, cm.RootPath()) - t.Run("serialize error", func(t *testing.T) { - iCodec := storage.NewInsertCodecWithSchema(nil) - - alloc := allocator.NewMockAllocator(t) - binlogIO := io.NewBinlogIO(cm, getOrCreateIOPool()) - kvs := make(map[string][]byte) - pin, err := genInsertBlobs(binlogIO, alloc, genEmptyInsertData(), 0, 10, 1, iCodec, kvs) - - assert.Error(t, err) - assert.Empty(t, kvs) - assert.Empty(t, pin) - }) - t.Run("GetGenerator error", func(t *testing.T) { f := &MetaFactory{} meta := f.GetCollectionMeta(UniqueID(10001), "test_gen_blobs", schemapb.DataType_Int64) iCodec := storage.NewInsertCodecWithSchema(meta) + var partId int64 = 10 + var segId int64 = 1 + iData := genInsertData(2) + blobs, err := iCodec.Serialize(partId, segId, iData) + assert.NoError(t, err) alloc := allocator.NewMockAllocator(t) alloc.EXPECT().GetGenerator(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("mock GetGenerator error")) binlogIO := 
io.NewBinlogIO(cm, getOrCreateIOPool()) kvs := make(map[string][]byte) - pin, err := genInsertBlobs(binlogIO, alloc, genInsertData(2), meta.GetID(), 10, 1, iCodec, kvs) + pin, err := genInsertBlobs(binlogIO, alloc, blobs, meta.GetID(), partId, segId, kvs) assert.Error(t, err) assert.Empty(t, kvs) diff --git a/internal/datanode/compactor.go b/internal/datanode/compactor.go index 44e292bc76da1..e99642316e6f4 100644 --- a/internal/datanode/compactor.go +++ b/internal/datanode/compactor.go @@ -55,8 +55,6 @@ var ( errContext = errors.New("context done or timeout") ) -type iterator = storage.Iterator - type compactor interface { complete() compact() (*datapb.CompactionPlanResult, error) @@ -174,48 +172,15 @@ func (t *compactionTask) mergeDeltalogs(dBlobs map[UniqueID][]*Blob) (map[interf return pk2ts, nil } -func (t *compactionTask) uploadRemainLog( - ctxTimeout context.Context, - targetSegID UniqueID, - partID UniqueID, - meta *etcdpb.CollectionMeta, - stats *storage.PrimaryKeyStats, - totRows int64, - writeBuffer *storage.InsertData, -) (map[UniqueID]*datapb.FieldBinlog, map[UniqueID]*datapb.FieldBinlog, error) { - iCodec := storage.NewInsertCodecWithSchema(meta) - inPaths := make(map[int64]*datapb.FieldBinlog, 0) - var err error - if !writeBuffer.IsEmpty() { - inPaths, err = uploadInsertLog(ctxTimeout, t.binlogIO, t.Allocator, meta.GetID(), partID, targetSegID, writeBuffer, iCodec) - if err != nil { - return nil, nil, err - } - } - - statPaths, err := uploadStatsLog(ctxTimeout, t.binlogIO, t.Allocator, meta.GetID(), partID, targetSegID, stats, totRows, iCodec) - if err != nil { - return nil, nil, err +func newBinlogWriter(collectionId, partitionId, segmentId UniqueID, schema *schemapb.CollectionSchema, +) (writer *storage.SerializeWriter[*storage.Value], closers []func() (*Blob, error), err error) { + fieldWriters := storage.NewBinlogStreamWriters(collectionId, partitionId, segmentId, schema.Fields) + closers = make([]func() (*Blob, error), 0, len(fieldWriters)) + for 
_, w := range fieldWriters { + closers = append(closers, w.Finalize) } - - return inPaths, statPaths, nil -} - -func (t *compactionTask) uploadSingleInsertLog( - ctxTimeout context.Context, - targetSegID UniqueID, - partID UniqueID, - meta *etcdpb.CollectionMeta, - writeBuffer *storage.InsertData, -) (map[UniqueID]*datapb.FieldBinlog, error) { - iCodec := storage.NewInsertCodecWithSchema(meta) - - inPaths, err := uploadInsertLog(ctxTimeout, t.binlogIO, t.Allocator, meta.GetID(), partID, targetSegID, writeBuffer, iCodec) - if err != nil { - return nil, err - } - - return inPaths, nil + writer, err = storage.NewBinlogSerializeWriter(schema, partitionId, segmentId, fieldWriters, 1024) + return } func (t *compactionTask) merge( @@ -231,10 +196,15 @@ func (t *compactionTask) merge( log := log.With(zap.Int64("planID", t.getPlanID())) mergeStart := time.Now() + writer, finalizers, err := newBinlogWriter(meta.GetID(), partID, targetSegID, meta.GetSchema()) + if err != nil { + return nil, nil, 0, err + } + var ( - numBinlogs int // binlog number - numRows int64 // the number of rows uploaded - expired int64 // the number of expired entity + numBinlogs int // binlog number + numRows uint64 // the number of rows uploaded + expired int64 // the number of expired entity insertField2Path = make(map[UniqueID]*datapb.FieldBinlog) insertPaths = make([]*datapb.FieldBinlog, 0) @@ -242,10 +212,6 @@ func (t *compactionTask) merge( statField2Path = make(map[UniqueID]*datapb.FieldBinlog) statPaths = make([]*datapb.FieldBinlog, 0) ) - writeBuffer, err := storage.NewInsertData(meta.GetSchema()) - if err != nil { - return nil, nil, -1, err - } isDeletedValue := func(v *storage.Value) bool { ts, ok := delta[v.PK.GetValue()] @@ -306,7 +272,7 @@ func (t *compactionTask) merge( numRows = 0 numBinlogs = 0 currentTs := t.GetCurrentTime() - currentRows := 0 + unflushedRows := 0 downloadTimeCost := time.Duration(0) uploadInsertTimeCost := time.Duration(0) @@ -325,6 +291,30 @@ func (t 
*compactionTask) merge( timestampFrom int64 = -1 ) + flush := func() error { + uploadInsertStart := time.Now() + writer.Close() + fieldData := make([]*Blob, len(finalizers)) + + for i, f := range finalizers { + blob, err := f() + if err != nil { + return err + } + fieldData[i] = blob + } + inPaths, err := uploadInsertLog(ctx, t.binlogIO, t.Allocator, meta.ID, partID, targetSegID, fieldData) + if err != nil { + log.Warn("failed to upload single insert log", zap.Error(err)) + return err + } + numBinlogs += len(inPaths) + uploadInsertTimeCost += time.Since(uploadInsertStart) + addInsertFieldPath(inPaths, timestampFrom, timestampTo) + unflushedRows = 0 + return nil + } + for _, path := range unMergedInsertlogs { downloadStart := time.Now() data, err := downloadBlobs(ctx, t.binlogIO, path) @@ -370,55 +360,50 @@ func (t *compactionTask) merge( timestampTo = v.Timestamp } - row, ok := v.Value.(map[UniqueID]interface{}) - if !ok { - log.Warn("transfer interface to map wrong", zap.Strings("path", path)) - return nil, nil, 0, errors.New("unexpected error") - } - - err = writeBuffer.Append(row) + err = writer.Write(v) if err != nil { return nil, nil, 0, err } + numRows++ + unflushedRows++ - currentRows++ stats.Update(v.PK) // check size every 100 rows in case of too many `GetMemorySize` call - if (currentRows+1)%100 == 0 && writeBuffer.GetMemorySize() > paramtable.Get().DataNodeCfg.BinLogMaxSize.GetAsInt() { - numRows += int64(writeBuffer.GetRowNum()) - uploadInsertStart := time.Now() - inPaths, err := t.uploadSingleInsertLog(ctx, targetSegID, partID, meta, writeBuffer) - if err != nil { - log.Warn("failed to upload single insert log", zap.Error(err)) - return nil, nil, 0, err + if (unflushedRows+1)%100 == 0 { + writer.Flush() // Flush to update memory size + + if writer.WrittenMemorySize() > paramtable.Get().DataNodeCfg.BinLogMaxSize.GetAsUint64() { + if err := flush(); err != nil { + return nil, nil, 0, err + } + timestampFrom = -1 + timestampTo = -1 + + writer, finalizers, 
err = newBinlogWriter(meta.ID, targetSegID, partID, meta.Schema) + if err != nil { + return nil, nil, 0, err + } } - uploadInsertTimeCost += time.Since(uploadInsertStart) - addInsertFieldPath(inPaths, timestampFrom, timestampTo) - timestampFrom = -1 - timestampTo = -1 - - writeBuffer, _ = storage.NewInsertData(meta.GetSchema()) - currentRows = 0 - numBinlogs++ } } } - // upload stats log and remain insert rows - if writeBuffer.GetRowNum() > 0 || numRows > 0 { - numRows += int64(writeBuffer.GetRowNum()) - uploadStart := time.Now() - inPaths, statsPaths, err := t.uploadRemainLog(ctx, targetSegID, partID, meta, - stats, numRows+int64(currentRows), writeBuffer) - if err != nil { + // final flush if there is unflushed rows + if unflushedRows > 0 { + if err := flush(); err != nil { return nil, nil, 0, err } + } - uploadInsertTimeCost += time.Since(uploadStart) - addInsertFieldPath(inPaths, timestampFrom, timestampTo) + // upload stats log + if numRows > 0 { + iCodec := storage.NewInsertCodecWithSchema(meta) + statsPaths, err := uploadStatsLog(ctx, t.binlogIO, t.Allocator, meta.GetID(), partID, targetSegID, stats, int64(numRows), iCodec) + if err != nil { + return nil, nil, 0, err + } addStatFieldPath(statsPaths) - numBinlogs += len(inPaths) } for _, path := range insertField2Path { @@ -430,14 +415,14 @@ func (t *compactionTask) merge( } log.Info("compact merge end", - zap.Int64("remaining insert numRows", numRows), + zap.Uint64("remaining insert numRows", numRows), zap.Int64("expired entities", expired), zap.Int("binlog file number", numBinlogs), zap.Duration("download insert log elapse", downloadTimeCost), zap.Duration("upload insert log elapse", uploadInsertTimeCost), zap.Duration("merge elapse", time.Since(mergeStart))) - return insertPaths, statPaths, numRows, nil + return insertPaths, statPaths, int64(numRows), nil } func (t *compactionTask) compact() (*datapb.CompactionPlanResult, error) { diff --git a/internal/datanode/compactor_test.go 
b/internal/datanode/compactor_test.go index ccae34bebae2a..efea77b55fbbb 100644 --- a/internal/datanode/compactor_test.go +++ b/internal/datanode/compactor_test.go @@ -21,7 +21,6 @@ import ( "fmt" "math" "testing" - "time" "github.com/cockroachdb/errors" "github.com/samber/lo" @@ -296,8 +295,12 @@ func TestCompactionTaskInnerMethods(t *testing.T) { paramtable.Get().Save(Params.CommonCfg.EntityExpirationTTL.Key, "0") iData := genInsertDataWithExpiredTS() iCodec := storage.NewInsertCodecWithSchema(meta) + var partId int64 = 0 + var segmentId int64 = 1 + blobs, err := iCodec.Serialize(partId, 0, iData) + assert.NoError(t, err) var allPaths [][]string - inpath, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), 0, 1, iData, iCodec) + inpath, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), partId, segmentId, blobs) assert.NoError(t, err) assert.Equal(t, 12, len(inpath)) binlogNum := len(inpath[0].GetBinlogs()) @@ -336,18 +339,22 @@ func TestCompactionTaskInnerMethods(t *testing.T) { }) t.Run("Merge without expiration2", func(t *testing.T) { mockbIO := io.NewBinlogIO(cm, getOrCreateIOPool()) + iData := genInsertDataWithExpiredTS() iCodec := storage.NewInsertCodecWithSchema(meta) + var partId int64 = 0 + var segmentId int64 = 1 + blobs, err := iCodec.Serialize(partId, 0, iData) + assert.NoError(t, err) paramtable.Get().Save(Params.CommonCfg.EntityExpirationTTL.Key, "0") BinLogMaxSize := Params.DataNodeCfg.BinLogMaxSize.GetValue() defer func() { Params.Save(Params.DataNodeCfg.BinLogMaxSize.Key, BinLogMaxSize) }() paramtable.Get().Save(Params.DataNodeCfg.BinLogMaxSize.Key, "64") - iData := genInsertDataWithExpiredTS() meta := NewMetaFactory().GetCollectionMeta(1, "test", schemapb.DataType_Int64) var allPaths [][]string - inpath, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), 0, 1, iData, iCodec) + inpath, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), partId, 
segmentId, blobs) assert.NoError(t, err) assert.Equal(t, 12, len(inpath)) binlogNum := len(inpath[0].GetBinlogs()) @@ -394,9 +401,13 @@ func TestCompactionTaskInnerMethods(t *testing.T) { }() paramtable.Get().Save(Params.DataNodeCfg.BinLogMaxSize.Key, "1") iData := genInsertData(101) + var partId int64 = 0 + var segmentId int64 = 1 + blobs, err := iCodec.Serialize(partId, segmentId, iData) + assert.NoError(t, err) var allPaths [][]string - inpath, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), 0, 1, iData, iCodec) + inpath, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), partId, segmentId, blobs) assert.NoError(t, err) assert.Equal(t, 12, len(inpath)) binlogNum := len(inpath[0].GetBinlogs()) @@ -440,10 +451,14 @@ func TestCompactionTaskInnerMethods(t *testing.T) { mockbIO := io.NewBinlogIO(cm, getOrCreateIOPool()) iCodec := storage.NewInsertCodecWithSchema(meta) iData := genInsertDataWithExpiredTS() + var partId int64 = 0 + var segmentId int64 = 1 + blobs, err := iCodec.Serialize(partId, 0, iData) + assert.NoError(t, err) meta := NewMetaFactory().GetCollectionMeta(1, "test", schemapb.DataType_Int64) var allPaths [][]string - inpath, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), 0, 1, iData, iCodec) + inpath, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), partId, segmentId, blobs) assert.NoError(t, err) assert.Equal(t, 12, len(inpath)) binlogNum := len(inpath[0].GetBinlogs()) @@ -485,6 +500,10 @@ func TestCompactionTaskInnerMethods(t *testing.T) { mockbIO := io.NewBinlogIO(cm, getOrCreateIOPool()) iData := genInsertDataWithExpiredTS() iCodec := storage.NewInsertCodecWithSchema(meta) + var partId int64 = 0 + var segmentId int64 = 1 + blobs, err := iCodec.Serialize(partId, 0, iData) + assert.NoError(t, err) meta := NewMetaFactory().GetCollectionMeta(1, "test", schemapb.DataType_Int64) metaCache := metacache.NewMockMetaCache(t) 
metaCache.EXPECT().Schema().Return(meta.GetSchema()).Maybe() @@ -499,7 +518,7 @@ func TestCompactionTaskInnerMethods(t *testing.T) { }) var allPaths [][]string - inpath, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), 0, 1, iData, iCodec) + inpath, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), partId, segmentId, blobs) assert.NoError(t, err) assert.Equal(t, 12, len(inpath)) binlogNum := len(inpath[0].GetBinlogs()) @@ -539,10 +558,14 @@ func TestCompactionTaskInnerMethods(t *testing.T) { iCodec := storage.NewInsertCodecWithSchema(meta) paramtable.Get().Save(Params.CommonCfg.EntityExpirationTTL.Key, "0") iData := genInsertDataWithExpiredTS() + var partId int64 = 0 + var segmentId int64 = 1 + blobs, err := iCodec.Serialize(partId, 0, iData) + assert.NoError(t, err) meta := NewMetaFactory().GetCollectionMeta(1, "test", schemapb.DataType_Int64) var allPaths [][]string - inpath, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), 0, 1, iData, iCodec) + inpath, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), partId, segmentId, blobs) assert.NoError(t, err) assert.Equal(t, 12, len(inpath)) binlogNum := len(inpath[0].GetBinlogs()) @@ -586,10 +609,14 @@ func TestCompactionTaskInnerMethods(t *testing.T) { iCodec := storage.NewInsertCodecWithSchema(meta) paramtable.Get().Save(Params.CommonCfg.EntityExpirationTTL.Key, "0") iData := genInsertDataWithExpiredTS() + var partId int64 = 0 + var segmentId int64 = 1 + blobs, err := iCodec.Serialize(partId, 0, iData) + assert.NoError(t, err) meta := NewMetaFactory().GetCollectionMeta(1, "test", schemapb.DataType_Int64) var allPaths [][]string - inpath, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), 0, 1, iData, iCodec) + inpath, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), partId, segmentId, blobs) assert.NoError(t, err) assert.Equal(t, 12, len(inpath)) binlogNum := 
len(inpath[0].GetBinlogs()) @@ -714,32 +741,6 @@ func TestCompactionTaskInnerMethods(t *testing.T) { _, err := ct.getNumRows() assert.Error(t, err, "segment not found") }) - - t.Run("Test uploadRemainLog error", func(t *testing.T) { - f := &MetaFactory{} - - t.Run("upload failed", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - - alloc := allocator.NewMockAllocator(t) - alloc.EXPECT().AllocOne().Call.Return(int64(11111), nil) - - meta := f.GetCollectionMeta(UniqueID(10001), "test_upload_remain_log", schemapb.DataType_Int64) - stats, err := storage.NewPrimaryKeyStats(106, int64(schemapb.DataType_Int64), 10) - - require.NoError(t, err) - - ct := &compactionTask{ - binlogIO: io.NewBinlogIO(&mockCm{errSave: true}, getOrCreateIOPool()), - Allocator: alloc, - done: make(chan struct{}, 1), - } - - _, _, err = ct.uploadRemainLog(ctx, 1, 2, meta, stats, 10, nil) - assert.Error(t, err) - }) - }) } func getInt64DeltaBlobs(segID UniqueID, pks []UniqueID, tss []Timestamp) ([]*Blob, error) { @@ -924,12 +925,16 @@ func TestCompactorInterfaceMethods(t *testing.T) { metaCache.EXPECT().GetSegmentByID(mock.Anything).Return(nil, false) iData1 := genInsertDataWithPKs(c.pks1, c.pkType) + iblobs1, err := iCodec.Serialize(c.parID, 0, iData1) + assert.NoError(t, err) dData1 := &DeleteData{ Pks: []storage.PrimaryKey{c.pks1[0]}, Tss: []Timestamp{20000}, RowCount: 1, } iData2 := genInsertDataWithPKs(c.pks2, c.pkType) + iblobs2, err := iCodec.Serialize(c.parID, 3, iData2) + assert.NoError(t, err) dData2 := &DeleteData{ Pks: []storage.PrimaryKey{c.pks2[0]}, Tss: []Timestamp{30000}, @@ -938,7 +943,7 @@ func TestCompactorInterfaceMethods(t *testing.T) { stats1, err := storage.NewPrimaryKeyStats(1, int64(c.pkType), 1) require.NoError(t, err) - iPaths1, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), c.parID, c.segID1, iData1, iCodec) + iPaths1, err := uploadInsertLog(context.Background(), mockbIO, 
alloc, meta.GetID(), c.parID, c.segID1, iblobs1) require.NoError(t, err) sPaths1, err := uploadStatsLog(context.Background(), mockbIO, alloc, meta.GetID(), c.parID, c.segID1, stats1, 2, iCodec) require.NoError(t, err) @@ -948,7 +953,7 @@ func TestCompactorInterfaceMethods(t *testing.T) { stats2, err := storage.NewPrimaryKeyStats(1, int64(c.pkType), 1) require.NoError(t, err) - iPaths2, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), c.parID, c.segID2, iData2, iCodec) + iPaths2, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), c.parID, c.segID2, iblobs2) require.NoError(t, err) sPaths2, err := uploadStatsLog(context.Background(), mockbIO, alloc, meta.GetID(), c.parID, c.segID2, stats2, 2, iCodec) require.NoError(t, err) @@ -1067,7 +1072,11 @@ func TestCompactorInterfaceMethods(t *testing.T) { // the same pk for segmentI and segmentII pks := [2]storage.PrimaryKey{storage.NewInt64PrimaryKey(1), storage.NewInt64PrimaryKey(2)} iData1 := genInsertDataWithPKs(pks, schemapb.DataType_Int64) + iblobs1, err := iCodec.Serialize(partID, 0, iData1) + assert.NoError(t, err) iData2 := genInsertDataWithPKs(pks, schemapb.DataType_Int64) + iblobs2, err := iCodec.Serialize(partID, 1, iData2) + assert.NoError(t, err) pk1 := storage.NewInt64PrimaryKey(1) dData1 := &DeleteData{ @@ -1084,7 +1093,7 @@ func TestCompactorInterfaceMethods(t *testing.T) { stats1, err := storage.NewPrimaryKeyStats(1, int64(schemapb.DataType_Int64), 1) require.NoError(t, err) - iPaths1, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), partID, segID1, iData1, iCodec) + iPaths1, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), partID, segID1, iblobs1) require.NoError(t, err) sPaths1, err := uploadStatsLog(context.Background(), mockbIO, alloc, meta.GetID(), partID, segID1, stats1, 1, iCodec) require.NoError(t, err) @@ -1094,7 +1103,7 @@ func TestCompactorInterfaceMethods(t *testing.T) { stats2, err := 
storage.NewPrimaryKeyStats(1, int64(schemapb.DataType_Int64), 1) require.NoError(t, err) - iPaths2, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), partID, segID2, iData2, iCodec) + iPaths2, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), partID, segID2, iblobs2) require.NoError(t, err) sPaths2, err := uploadStatsLog(context.Background(), mockbIO, alloc, meta.GetID(), partID, segID2, stats2, 1, iCodec) require.NoError(t, err) @@ -1160,3 +1169,78 @@ func TestInjectDone(t *testing.T) { task.injectDone() task.injectDone() } + +func BenchmarkCompaction(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cm := storage.NewLocalChunkManager(storage.RootPath(compactTestDir)) + defer cm.RemoveWithPrefix(ctx, cm.RootPath()) + + collectionID := int64(1) + meta := NewMetaFactory().GetCollectionMeta(collectionID, "test", schemapb.DataType_Int64) + mockbIO := io.NewBinlogIO(cm, getOrCreateIOPool()) + paramtable.Get().Save(Params.CommonCfg.EntityExpirationTTL.Key, "0") + iData := genInsertDataWithExpiredTS() + iCodec := storage.NewInsertCodecWithSchema(meta) + var partId int64 = 0 + var segmentId int64 = 1 + blobs, err := iCodec.Serialize(partId, 0, iData) + assert.NoError(b, err) + var allPaths [][]string + alloc := allocator.NewMockAllocator(b) + alloc.EXPECT().GetGenerator(mock.Anything, mock.Anything).Call.Return(validGeneratorFn, nil) + alloc.EXPECT().AllocOne().Call.Return(int64(19530), nil) + inpath, err := uploadInsertLog(context.Background(), mockbIO, alloc, meta.GetID(), partId, segmentId, blobs) + assert.NoError(b, err) + assert.Equal(b, 12, len(inpath)) + binlogNum := len(inpath[0].GetBinlogs()) + assert.Equal(b, 1, binlogNum) + + for idx := 0; idx < binlogNum; idx++ { + var ps []string + for _, path := range inpath { + ps = append(ps, path.GetBinlogs()[idx].GetLogPath()) + } + allPaths = append(allPaths, ps) + } + + dm := map[interface{}]Timestamp{ + 1: 10000, + } + + 
metaCache := metacache.NewMockMetaCache(b) + metaCache.EXPECT().Schema().Return(meta.GetSchema()).Maybe() + metaCache.EXPECT().GetSegmentByID(mock.Anything).RunAndReturn(func(id int64, filters ...metacache.SegmentFilter) (*metacache.SegmentInfo, bool) { + segment := metacache.NewSegmentInfo(&datapb.SegmentInfo{ + CollectionID: 1, + PartitionID: 0, + ID: id, + NumOfRows: 10, + }, nil) + return segment, true + }) + + ct := &compactionTask{ + metaCache: metaCache, + binlogIO: mockbIO, + Allocator: alloc, + done: make(chan struct{}, 1), + plan: &datapb.CompactionPlan{ + SegmentBinlogs: []*datapb.CompactionSegmentBinlogs{ + {SegmentID: 1}, + }, + }, + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + inPaths, statsPaths, numOfRow, err := ct.merge(context.Background(), allPaths, 2, 0, meta, dm) + assert.NoError(b, err) + assert.Equal(b, int64(2), numOfRow) + assert.Equal(b, 1, len(inPaths[0].GetBinlogs())) + assert.Equal(b, 1, len(statsPaths)) + assert.NotEqual(b, -1, inPaths[0].GetBinlogs()[0].GetTimestampFrom()) + assert.NotEqual(b, -1, inPaths[0].GetBinlogs()[0].GetTimestampTo()) + } +} diff --git a/internal/datanode/data_sync_service.go b/internal/datanode/data_sync_service.go index 1620d25734ae5..ca744d239f1bf 100644 --- a/internal/datanode/data_sync_service.go +++ b/internal/datanode/data_sync_service.go @@ -254,13 +254,9 @@ func loadStats(ctx context.Context, chunkManager storage.ChunkManager, schema *s log := log.With(zap.Int64("segmentID", segmentID)) log.Info("begin to init pk bloom filter", zap.Int("statsBinLogsLen", len(statsBinlogs))) - // get pkfield id - pkField := int64(-1) - for _, field := range schema.Fields { - if field.IsPrimaryKey { - pkField = field.FieldID - break - } + pkField, err := typeutil.GetPrimaryFieldSchema(schema) + if err != nil { + return nil, err } // filter stats binlog files which is pk field stats log @@ -268,7 +264,7 @@ func loadStats(ctx context.Context, chunkManager storage.ChunkManager, schema *s logType := 
storage.DefaultStatsType for _, binlog := range statsBinlogs { - if binlog.FieldID != pkField { + if binlog.FieldID != pkField.GetFieldID() { continue } Loop: diff --git a/internal/datanode/io/binlog_io.go b/internal/datanode/io/binlog_io.go index bd470d21e6b9f..c60af8e992dda 100644 --- a/internal/datanode/io/binlog_io.go +++ b/internal/datanode/io/binlog_io.go @@ -99,7 +99,6 @@ func (b *BinlogIoImpl) Upload(ctx context.Context, kvs map[string][]byte) error } return err }) - return struct{}{}, err }) diff --git a/internal/datanode/l0_compactor.go b/internal/datanode/l0_compactor.go index b211c78749211..75bde780323c7 100644 --- a/internal/datanode/l0_compactor.go +++ b/internal/datanode/l0_compactor.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "math" + "sync" "time" "github.com/samber/lo" @@ -38,6 +39,7 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/hardware" "github.com/milvus-io/milvus/pkg/util/merr" @@ -54,6 +56,7 @@ type levelZeroCompactionTask struct { allocator allocator.Allocator metacache metacache.MetaCache syncmgr syncmgr.SyncManager + cm storage.ChunkManager plan *datapb.CompactionPlan @@ -70,6 +73,7 @@ func newLevelZeroCompactionTask( alloc allocator.Allocator, metaCache metacache.MetaCache, syncmgr syncmgr.SyncManager, + cm storage.ChunkManager, plan *datapb.CompactionPlan, ) *levelZeroCompactionTask { ctx, cancel := context.WithCancel(ctx) @@ -81,6 +85,7 @@ func newLevelZeroCompactionTask( allocator: alloc, metacache: metaCache, syncmgr: syncmgr, + cm: cm, plan: plan, tr: timerecord.NewTimeRecorder("levelzero compaction"), done: make(chan struct{}, 1), @@ -129,13 +134,10 @@ func (t *levelZeroCompactionTask) compact() (*datapb.CompactionPlanResult, error return s.Level == datapb.SegmentLevel_L0 }) - targetSegIDs := 
lo.FilterMap(t.plan.GetSegmentBinlogs(), func(s *datapb.CompactionSegmentBinlogs, _ int) (int64, bool) { - if s.Level == datapb.SegmentLevel_L1 { - return s.GetSegmentID(), true - } - return 0, false + targetSegments := lo.Filter(t.plan.GetSegmentBinlogs(), func(s *datapb.CompactionSegmentBinlogs, _ int) bool { + return s.Level != datapb.SegmentLevel_L0 }) - if len(targetSegIDs) == 0 { + if len(targetSegments) == 0 { log.Warn("compact wrong, not target sealed segments") return nil, errIllegalCompactionPlan } @@ -165,9 +167,9 @@ func (t *levelZeroCompactionTask) compact() (*datapb.CompactionPlanResult, error var resultSegments []*datapb.CompactionSegment if float64(hardware.GetFreeMemoryCount())*paramtable.Get().DataNodeCfg.L0BatchMemoryRatio.GetAsFloat() < float64(totalSize) { - resultSegments, err = t.linearProcess(ctxTimeout, targetSegIDs, totalDeltalogs) + resultSegments, err = t.linearProcess(ctxTimeout, targetSegments, totalDeltalogs) } else { - resultSegments, err = t.batchProcess(ctxTimeout, targetSegIDs, lo.Values(totalDeltalogs)...) + resultSegments, err = t.batchProcess(ctxTimeout, targetSegments, lo.Values(totalDeltalogs)...) 
} if err != nil { return nil, err @@ -188,65 +190,87 @@ func (t *levelZeroCompactionTask) compact() (*datapb.CompactionPlanResult, error return result, nil } -func (t *levelZeroCompactionTask) linearProcess(ctx context.Context, targetSegments []int64, totalDeltalogs map[int64][]string) ([]*datapb.CompactionSegment, error) { +func (t *levelZeroCompactionTask) linearProcess(ctx context.Context, targetSegments []*datapb.CompactionSegmentBinlogs, totalDeltalogs map[int64][]string) ([]*datapb.CompactionSegment, error) { log := log.Ctx(t.ctx).With( zap.Int64("planID", t.plan.GetPlanID()), zap.String("type", t.plan.GetType().String()), zap.Int("target segment counts", len(targetSegments)), ) + + // just for logging + targetSegmentIDs := lo.Map(targetSegments, func(segment *datapb.CompactionSegmentBinlogs, _ int) int64 { + return segment.GetSegmentID() + }) + var ( resultSegments = make(map[int64]*datapb.CompactionSegment) alteredSegments = make(map[int64]*storage.DeleteData) ) + + segmentBFs, err := t.loadBF(targetSegments) + if err != nil { + return nil, err + } for segID, deltaLogs := range totalDeltalogs { log := log.With(zap.Int64("levelzero segment", segID)) log.Info("Linear L0 compaction start processing segment") allIters, err := t.loadDelta(ctx, deltaLogs) if err != nil { - log.Warn("Linear L0 compaction loadDelta fail", zap.Int64s("target segments", targetSegments), zap.Error(err)) + log.Warn("Linear L0 compaction loadDelta fail", zap.Int64s("target segments", targetSegmentIDs), zap.Error(err)) return nil, err } - t.splitDelta(ctx, allIters, alteredSegments, targetSegments) + t.splitDelta(ctx, allIters, alteredSegments, segmentBFs) err = t.uploadByCheck(ctx, true, alteredSegments, resultSegments) if err != nil { - log.Warn("Linear L0 compaction upload buffer fail", zap.Int64s("target segments", targetSegments), zap.Error(err)) + log.Warn("Linear L0 compaction upload buffer fail", zap.Int64s("target segments", targetSegmentIDs), zap.Error(err)) return nil, err } } 
- err := t.uploadByCheck(ctx, false, alteredSegments, resultSegments) + err = t.uploadByCheck(ctx, false, alteredSegments, resultSegments) if err != nil { - log.Warn("Linear L0 compaction upload all buffer fail", zap.Int64s("target segment", targetSegments), zap.Error(err)) + log.Warn("Linear L0 compaction upload all buffer fail", zap.Int64s("target segment", targetSegmentIDs), zap.Error(err)) return nil, err } log.Info("Linear L0 compaction finished", zap.Duration("elapse", t.tr.RecordSpan())) return lo.Values(resultSegments), nil } -func (t *levelZeroCompactionTask) batchProcess(ctx context.Context, targetSegments []int64, deltaLogs ...[]string) ([]*datapb.CompactionSegment, error) { +func (t *levelZeroCompactionTask) batchProcess(ctx context.Context, targetSegments []*datapb.CompactionSegmentBinlogs, deltaLogs ...[]string) ([]*datapb.CompactionSegment, error) { log := log.Ctx(t.ctx).With( zap.Int64("planID", t.plan.GetPlanID()), zap.String("type", t.plan.GetType().String()), zap.Int("target segment counts", len(targetSegments)), ) + + // just for logging + targetSegmentIDs := lo.Map(targetSegments, func(segment *datapb.CompactionSegmentBinlogs, _ int) int64 { + return segment.GetSegmentID() + }) + log.Info("Batch L0 compaction start processing") resultSegments := make(map[int64]*datapb.CompactionSegment) iters, err := t.loadDelta(ctx, lo.Flatten(deltaLogs)) if err != nil { - log.Warn("Batch L0 compaction loadDelta fail", zap.Int64s("target segments", targetSegments), zap.Error(err)) + log.Warn("Batch L0 compaction loadDelta fail", zap.Int64s("target segments", targetSegmentIDs), zap.Error(err)) + return nil, err + } + + segmentBFs, err := t.loadBF(targetSegments) + if err != nil { return nil, err } alteredSegments := make(map[int64]*storage.DeleteData) - t.splitDelta(ctx, iters, alteredSegments, targetSegments) + t.splitDelta(ctx, iters, alteredSegments, segmentBFs) err = t.uploadByCheck(ctx, false, alteredSegments, resultSegments) if err != nil { - 
log.Warn("Batch L0 compaction upload fail", zap.Int64s("target segments", targetSegments), zap.Error(err)) + log.Warn("Batch L0 compaction upload fail", zap.Int64s("target segments", targetSegmentIDs), zap.Error(err)) return nil, err } log.Info("Batch L0 compaction finished", zap.Duration("elapse", t.tr.RecordSpan())) @@ -271,18 +295,20 @@ func (t *levelZeroCompactionTask) splitDelta( ctx context.Context, allIters []*iter.DeltalogIterator, targetSegBuffer map[int64]*storage.DeleteData, - targetSegIDs []int64, + segmentBfs map[int64]*metacache.BloomFilterSet, ) { _, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "L0Compact splitDelta") defer span.End() - // segments shall be safe to read outside - segments := t.metacache.GetSegmentsBy(metacache.WithSegmentIDs(targetSegIDs...)) split := func(pk storage.PrimaryKey) []int64 { lc := storage.NewLocationsCache(pk) - return lo.FilterMap(segments, func(segment *metacache.SegmentInfo, _ int) (int64, bool) { - return segment.SegmentID(), segment.GetBloomFilterSet().PkExists(lc) - }) + predicts := make([]int64, 0, len(segmentBfs)) + for segmentID, bf := range segmentBfs { + if bf.PkExists(lc) { + predicts = append(predicts, segmentID) + } + } + return predicts } // spilt all delete data to segments @@ -395,3 +421,41 @@ func (t *levelZeroCompactionTask) uploadByCheck(ctx context.Context, requireChec return nil } + +func (t *levelZeroCompactionTask) loadBF(targetSegments []*datapb.CompactionSegmentBinlogs) (map[int64]*metacache.BloomFilterSet, error) { + log := log.Ctx(t.ctx).With( + zap.Int64("planID", t.plan.GetPlanID()), + zap.String("type", t.plan.GetType().String()), + ) + + var ( + futures = make([]*conc.Future[any], 0, len(targetSegments)) + pool = getOrCreateStatsPool() + + mu = &sync.Mutex{} + bfs = make(map[int64]*metacache.BloomFilterSet) + ) + + for _, segment := range targetSegments { + segment := segment + future := pool.Submit(func() (any, error) { + _ = binlog.DecompressBinLog(storage.StatsBinlog, 
segment.GetCollectionID(), + segment.GetPartitionID(), segment.GetSegmentID(), segment.GetField2StatslogPaths()) + pks, err := loadStats(t.ctx, t.cm, + t.metacache.Schema(), segment.GetSegmentID(), segment.GetField2StatslogPaths()) + if err != nil { + log.Warn("failed to load segment stats log", zap.Error(err)) + return err, err + } + bf := metacache.NewBloomFilterSet(pks...) + mu.Lock() + defer mu.Unlock() + bfs[segment.GetSegmentID()] = bf + return nil, nil + }) + futures = append(futures, future) + } + + err := conc.AwaitAll(futures...) + return bfs, err +} diff --git a/internal/datanode/l0_compactor_test.go b/internal/datanode/l0_compactor_test.go index 80b9241285c14..59b66086d67fa 100644 --- a/internal/datanode/l0_compactor_test.go +++ b/internal/datanode/l0_compactor_test.go @@ -27,10 +27,12 @@ import ( "github.com/stretchr/testify/suite" "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/datanode/allocator" "github.com/milvus-io/milvus/internal/datanode/io" iter "github.com/milvus-io/milvus/internal/datanode/iterators" "github.com/milvus-io/milvus/internal/datanode/metacache" + "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/common" @@ -61,7 +63,7 @@ func (s *LevelZeroCompactionTaskSuite) SetupTest() { s.mockBinlogIO = io.NewMockBinlogIO(s.T()) s.mockMeta = metacache.NewMockMetaCache(s.T()) // plan of the task is unset - s.task = newLevelZeroCompactionTask(context.Background(), s.mockBinlogIO, s.mockAlloc, s.mockMeta, nil, nil) + s.task = newLevelZeroCompactionTask(context.Background(), s.mockBinlogIO, s.mockAlloc, s.mockMeta, nil, nil, nil) pk2ts := map[int64]uint64{ 1: 20000, @@ -105,7 +107,17 @@ func (s *LevelZeroCompactionTaskSuite) TestLinearBatchLoadDeltaFail() { s.task.tr = timerecord.NewTimeRecorder("test") s.mockBinlogIO.EXPECT().Download(mock.Anything, 
mock.Anything).Return(nil, errors.New("mock download fail")).Twice() - targetSegments := []int64{200} + s.mockMeta.EXPECT().Schema().Return(&schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + IsPrimaryKey: true, + }, + }, + }) + + targetSegments := lo.Filter(plan.SegmentBinlogs, func(s *datapb.CompactionSegmentBinlogs, _ int) bool { + return s.Level == datapb.SegmentLevel_L1 + }) deltaLogs := map[int64][]string{100: {"a/b/c1"}} segments, err := s.task.linearProcess(context.Background(), targetSegments, deltaLogs) @@ -134,24 +146,43 @@ func (s *LevelZeroCompactionTaskSuite) TestLinearBatchUploadByCheckFail() { }, }, }, - {SegmentID: 200, Level: datapb.SegmentLevel_L1}, + {SegmentID: 200, Level: datapb.SegmentLevel_L1, Field2StatslogPaths: []*datapb.FieldBinlog{ + { + Binlogs: []*datapb.Binlog{ + {LogID: 9999, LogSize: 100}, + }, + }, + }}, }, } s.task.plan = plan s.task.tr = timerecord.NewTimeRecorder("test") + + data := &storage.Int64FieldData{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + } + sw := &storage.StatsWriter{} + err := sw.GenerateByData(common.RowIDField, schemapb.DataType_Int64, data) + s.NoError(err) + cm := mocks.NewChunkManager(s.T()) + cm.EXPECT().MultiRead(mock.Anything, mock.Anything).Return([][]byte{sw.GetBuffer()}, nil) + s.task.cm = cm + s.mockBinlogIO.EXPECT().Download(mock.Anything, mock.Anything).Return([][]byte{s.dBlob}, nil).Times(2) s.mockMeta.EXPECT().Collection().Return(1) s.mockMeta.EXPECT().GetSegmentByID(mock.Anything).Return(nil, false).Twice() - s.mockMeta.EXPECT().GetSegmentsBy(mock.Anything).RunAndReturn( - func(filters ...metacache.SegmentFilter) []*metacache.SegmentInfo { - bfs1 := metacache.NewBloomFilterSetWithBatchSize(100) - bfs1.UpdatePKRange(&storage.Int64FieldData{Data: []int64{1, 2}}) - segment1 := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 200}, bfs1) - return []*metacache.SegmentInfo{segment1} - }).Twice() - - targetSegments := []int64{200} + 
s.mockMeta.EXPECT().Schema().Return(&schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + IsPrimaryKey: true, + }, + }, + }) + + targetSegments := lo.Filter(plan.SegmentBinlogs, func(s *datapb.CompactionSegmentBinlogs, _ int) bool { + return s.Level == datapb.SegmentLevel_L1 + }) deltaLogs := map[int64][]string{100: {"a/b/c1"}} segments, err := s.task.linearProcess(context.Background(), targetSegments, deltaLogs) @@ -192,28 +223,49 @@ func (s *LevelZeroCompactionTaskSuite) TestCompactLinear() { }, }, }, - {SegmentID: 200, Level: datapb.SegmentLevel_L1}, - {SegmentID: 201, Level: datapb.SegmentLevel_L1}, + {SegmentID: 200, Level: datapb.SegmentLevel_L1, Field2StatslogPaths: []*datapb.FieldBinlog{ + { + Binlogs: []*datapb.Binlog{ + {LogID: 9999, LogSize: 100}, + }, + }, + }}, + {SegmentID: 201, Level: datapb.SegmentLevel_L1, Field2StatslogPaths: []*datapb.FieldBinlog{ + { + Binlogs: []*datapb.Binlog{ + {LogID: 9999, LogSize: 100}, + }, + }, + }}, }, } s.task.plan = plan s.task.tr = timerecord.NewTimeRecorder("test") - bfs1 := metacache.NewBloomFilterSetWithBatchSize(100) - bfs1.UpdatePKRange(&storage.Int64FieldData{Data: []int64{1, 2}}) - segment1 := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 200}, bfs1) - bfs2 := metacache.NewBloomFilterSetWithBatchSize(100) - bfs2.UpdatePKRange(&storage.Int64FieldData{Data: []int64{1, 2}}) - segment2 := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 201}, bfs2) + data := &storage.Int64FieldData{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + } + sw := &storage.StatsWriter{} + err := sw.GenerateByData(common.RowIDField, schemapb.DataType_Int64, data) + s.NoError(err) + cm := mocks.NewChunkManager(s.T()) + cm.EXPECT().MultiRead(mock.Anything, mock.Anything).Return([][]byte{sw.GetBuffer()}, nil) + s.task.cm = cm s.mockBinlogIO.EXPECT().Download(mock.Anything, mock.Anything).Return([][]byte{s.dBlob}, nil).Times(2) - s.mockMeta.EXPECT().GetSegmentsBy(mock.Anything).Return([]*metacache.SegmentInfo{segment1, 
segment2}) s.mockMeta.EXPECT().Collection().Return(1) s.mockMeta.EXPECT().GetSegmentByID(mock.Anything, mock.Anything). RunAndReturn(func(id int64, filters ...metacache.SegmentFilter) (*metacache.SegmentInfo, bool) { return metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: id, PartitionID: 10}, nil), true }) + s.mockMeta.EXPECT().Schema().Return(&schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + IsPrimaryKey: true, + }, + }, + }) s.mockAlloc.EXPECT().AllocOne().Return(19530, nil).Times(2) s.mockBinlogIO.EXPECT().JoinFullPath(mock.Anything, mock.Anything). @@ -230,11 +282,8 @@ func (s *LevelZeroCompactionTaskSuite) TestCompactLinear() { return s.Level == datapb.SegmentLevel_L0 }) - targetSegIDs := lo.FilterMap(s.task.plan.GetSegmentBinlogs(), func(s *datapb.CompactionSegmentBinlogs, _ int) (int64, bool) { - if s.Level == datapb.SegmentLevel_L1 { - return s.GetSegmentID(), true - } - return 0, false + targetSegments := lo.Filter(s.task.plan.GetSegmentBinlogs(), func(s *datapb.CompactionSegmentBinlogs, _ int) bool { + return s.Level == datapb.SegmentLevel_L1 }) totalDeltalogs := make(map[UniqueID][]string) @@ -249,7 +298,7 @@ func (s *LevelZeroCompactionTaskSuite) TestCompactLinear() { totalDeltalogs[s.GetSegmentID()] = paths } } - segments, err := s.task.linearProcess(context.Background(), targetSegIDs, totalDeltalogs) + segments, err := s.task.linearProcess(context.Background(), targetSegments, totalDeltalogs) s.NoError(err) s.NotEmpty(segments) s.Equal(2, len(segments)) @@ -257,6 +306,9 @@ func (s *LevelZeroCompactionTaskSuite) TestCompactLinear() { lo.Map(segments, func(seg *datapb.CompactionSegment, _ int) int64 { return seg.GetSegmentID() })) + for _, segment := range segments { + s.NotNil(segment.GetDeltalogs()) + } log.Info("test segment results", zap.Any("result", segments)) } @@ -290,25 +342,35 @@ func (s *LevelZeroCompactionTaskSuite) TestCompactBatch() { }, }, }, - {SegmentID: 200, Level: datapb.SegmentLevel_L1}, - {SegmentID: 201, Level: 
datapb.SegmentLevel_L1}, + {SegmentID: 200, Level: datapb.SegmentLevel_L1, Field2StatslogPaths: []*datapb.FieldBinlog{ + { + Binlogs: []*datapb.Binlog{ + {LogID: 9999, LogSize: 100}, + }, + }, + }}, + {SegmentID: 201, Level: datapb.SegmentLevel_L1, Field2StatslogPaths: []*datapb.FieldBinlog{ + { + Binlogs: []*datapb.Binlog{ + {LogID: 9999, LogSize: 100}, + }, + }, + }}, }, } s.task.plan = plan s.task.tr = timerecord.NewTimeRecorder("test") - s.mockMeta.EXPECT().GetSegmentsBy(mock.Anything).RunAndReturn( - func(filters ...metacache.SegmentFilter) []*metacache.SegmentInfo { - bfs1 := metacache.NewBloomFilterSetWithBatchSize(100) - bfs1.UpdatePKRange(&storage.Int64FieldData{Data: []int64{1, 2, 3}}) - segment1 := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 200}, bfs1) - bfs2 := metacache.NewBloomFilterSetWithBatchSize(100) - bfs2.UpdatePKRange(&storage.Int64FieldData{Data: []int64{1, 2, 3}}) - segment2 := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 201}, bfs2) - - return []*metacache.SegmentInfo{segment1, segment2} - }) + data := &storage.Int64FieldData{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + } + sw := &storage.StatsWriter{} + err := sw.GenerateByData(common.RowIDField, schemapb.DataType_Int64, data) + s.NoError(err) + cm := mocks.NewChunkManager(s.T()) + cm.EXPECT().MultiRead(mock.Anything, mock.Anything).Return([][]byte{sw.GetBuffer()}, nil) + s.task.cm = cm s.mockBinlogIO.EXPECT().Download(mock.Anything, mock.Anything).Return([][]byte{s.dBlob}, nil).Once() s.mockMeta.EXPECT().Collection().Return(1) @@ -316,6 +378,13 @@ func (s *LevelZeroCompactionTaskSuite) TestCompactBatch() { RunAndReturn(func(id int64, filters ...metacache.SegmentFilter) (*metacache.SegmentInfo, bool) { return metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: id, PartitionID: 10}, nil), true }) + s.mockMeta.EXPECT().Schema().Return(&schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + IsPrimaryKey: true, + }, + }, + }) s.mockAlloc.EXPECT().AllocOne().Return(19530, 
nil).Times(2) s.mockBinlogIO.EXPECT().JoinFullPath(mock.Anything, mock.Anything). @@ -328,11 +397,8 @@ func (s *LevelZeroCompactionTaskSuite) TestCompactBatch() { return s.Level == datapb.SegmentLevel_L0 }) - targetSegIDs := lo.FilterMap(s.task.plan.GetSegmentBinlogs(), func(s *datapb.CompactionSegmentBinlogs, _ int) (int64, bool) { - if s.Level == datapb.SegmentLevel_L1 { - return s.GetSegmentID(), true - } - return 0, false + targetSegments := lo.Filter(s.task.plan.GetSegmentBinlogs(), func(s *datapb.CompactionSegmentBinlogs, _ int) bool { + return s.Level == datapb.SegmentLevel_L1 }) totalDeltalogs := make(map[UniqueID][]string) @@ -347,7 +413,7 @@ func (s *LevelZeroCompactionTaskSuite) TestCompactBatch() { totalDeltalogs[s.GetSegmentID()] = paths } } - segments, err := s.task.batchProcess(context.TODO(), targetSegIDs, lo.Values(totalDeltalogs)...) + segments, err := s.task.batchProcess(context.TODO(), targetSegments, lo.Values(totalDeltalogs)...) s.NoError(err) s.NotEmpty(segments) s.Equal(2, len(segments)) @@ -355,6 +421,9 @@ func (s *LevelZeroCompactionTaskSuite) TestCompactBatch() { lo.Map(segments, func(seg *datapb.CompactionSegment, _ int) int64 { return seg.GetSegmentID() })) + for _, segment := range segments { + s.NotNil(segment.GetDeltalogs()) + } log.Info("test segment results", zap.Any("result", segments)) } @@ -506,23 +575,23 @@ func (s *LevelZeroCompactionTaskSuite) TestComposeDeltalog() { func (s *LevelZeroCompactionTaskSuite) TestSplitDelta() { bfs1 := metacache.NewBloomFilterSetWithBatchSize(100) bfs1.UpdatePKRange(&storage.Int64FieldData{Data: []int64{1, 3}}) - segment1 := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 100}, bfs1) bfs2 := metacache.NewBloomFilterSetWithBatchSize(100) bfs2.UpdatePKRange(&storage.Int64FieldData{Data: []int64{3}}) - segment2 := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 101}, bfs2) bfs3 := metacache.NewBloomFilterSetWithBatchSize(100) bfs3.UpdatePKRange(&storage.Int64FieldData{Data: []int64{3}}) - segment3 
:= metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 102}, bfs3) predicted := []int64{100, 101, 102} - s.mockMeta.EXPECT().GetSegmentsBy(mock.Anything).Return([]*metacache.SegmentInfo{segment1, segment2, segment3}) diter := iter.NewDeltalogIterator([][]byte{s.dBlob}, nil) s.Require().NotNil(diter) targetSegBuffer := make(map[int64]*storage.DeleteData) - targetSegIDs := predicted - s.task.splitDelta(context.TODO(), []*iter.DeltalogIterator{diter}, targetSegBuffer, targetSegIDs) + segmentBFs := map[int64]*metacache.BloomFilterSet{ + 100: bfs1, + 101: bfs2, + 102: bfs3, + } + s.task.splitDelta(context.TODO(), []*iter.DeltalogIterator{diter}, targetSegBuffer, segmentBFs) s.NotEmpty(targetSegBuffer) s.ElementsMatch(predicted, lo.Keys(targetSegBuffer)) @@ -601,3 +670,94 @@ func (s *LevelZeroCompactionTaskSuite) TestLoadDelta() { } } } + +func (s *LevelZeroCompactionTaskSuite) TestLoadBF() { + plan := &datapb.CompactionPlan{ + PlanID: 19530, + Type: datapb.CompactionType_Level0DeleteCompaction, + SegmentBinlogs: []*datapb.CompactionSegmentBinlogs{ + {SegmentID: 201, Level: datapb.SegmentLevel_L1, Field2StatslogPaths: []*datapb.FieldBinlog{ + { + Binlogs: []*datapb.Binlog{ + {LogID: 9999, LogSize: 100}, + }, + }, + }}, + }, + } + + s.task.plan = plan + + data := &storage.Int64FieldData{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + } + sw := &storage.StatsWriter{} + err := sw.GenerateByData(common.RowIDField, schemapb.DataType_Int64, data) + s.NoError(err) + cm := mocks.NewChunkManager(s.T()) + cm.EXPECT().MultiRead(mock.Anything, mock.Anything).Return([][]byte{sw.GetBuffer()}, nil) + s.task.cm = cm + + s.mockMeta.EXPECT().Schema().Return(&schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + IsPrimaryKey: true, + }, + }, + }) + + bfs, err := s.task.loadBF(plan.SegmentBinlogs) + s.NoError(err) + + s.Len(bfs, 1) + for _, pk := range s.dData.Pks { + lc := storage.NewLocationsCache(pk) + s.True(bfs[201].PkExists(lc)) + } +} + +func (s 
*LevelZeroCompactionTaskSuite) TestFailed() { + s.Run("no primary key", func() { + plan := &datapb.CompactionPlan{ + PlanID: 19530, + Type: datapb.CompactionType_Level0DeleteCompaction, + SegmentBinlogs: []*datapb.CompactionSegmentBinlogs{ + {SegmentID: 201, Level: datapb.SegmentLevel_L1, Field2StatslogPaths: []*datapb.FieldBinlog{ + { + Binlogs: []*datapb.Binlog{ + {LogID: 9999, LogSize: 100}, + }, + }, + }}, + }, + } + + s.task.plan = plan + + s.mockMeta.EXPECT().Schema().Return(&schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + IsPrimaryKey: false, + }, + }, + }) + + _, err := s.task.loadBF(plan.SegmentBinlogs) + s.Error(err) + }) + + s.Run("no l1 segments", func() { + plan := &datapb.CompactionPlan{ + PlanID: 19530, + Type: datapb.CompactionType_Level0DeleteCompaction, + SegmentBinlogs: []*datapb.CompactionSegmentBinlogs{ + {SegmentID: 201, Level: datapb.SegmentLevel_L0}, + }, + } + + s.task.plan = plan + + _, err := s.task.compact() + s.Error(err) + }) +} diff --git a/internal/datanode/services.go b/internal/datanode/services.go index 1b66f157f9770..ad8cb3039e7ec 100644 --- a/internal/datanode/services.go +++ b/internal/datanode/services.go @@ -245,6 +245,7 @@ func (node *DataNode) Compaction(ctx context.Context, req *datapb.CompactionPlan node.allocator, ds.metacache, node.syncMgr, + node.chunkManager, req, ) case datapb.CompactionType_MixCompaction: @@ -515,3 +516,16 @@ func (node *DataNode) DropImport(ctx context.Context, req *datapb.DropImportRequ return merr.Success(), nil } + +func (node *DataNode) QuerySlot(ctx context.Context, req *datapb.QuerySlotRequest) (*datapb.QuerySlotResponse, error) { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return &datapb.QuerySlotResponse{ + Status: merr.Status(err), + }, nil + } + + return &datapb.QuerySlotResponse{ + Status: merr.Success(), + NumSlots: Params.DataNodeCfg.SlotCap.GetAsInt64() - int64(node.compactionExecutor.executing.Len()), + }, nil +} diff --git 
a/internal/datanode/services_test.go b/internal/datanode/services_test.go index 00a803642a675..94eed7f5193e5 100644 --- a/internal/datanode/services_test.go +++ b/internal/datanode/services_test.go @@ -654,3 +654,25 @@ func (s *DataNodeServicesSuite) TestRPCWatch() { s.False(merr.Ok(resp.GetStatus())) }) } + +func (s *DataNodeServicesSuite) TestQuerySlot() { + s.Run("node not healthy", func() { + s.SetupTest() + s.node.UpdateStateCode(commonpb.StateCode_Abnormal) + + ctx := context.Background() + resp, err := s.node.QuerySlot(ctx, nil) + s.NoError(err) + s.False(merr.Ok(resp.GetStatus())) + s.ErrorIs(merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) + }) + + s.Run("normal case", func() { + s.SetupTest() + ctx := context.Background() + resp, err := s.node.QuerySlot(ctx, nil) + s.NoError(err) + s.True(merr.Ok(resp.GetStatus())) + s.NoError(merr.Error(resp.GetStatus())) + }) +} diff --git a/internal/distributed/datanode/client/client.go b/internal/distributed/datanode/client/client.go index 3734daf444f01..824f7762bd7a7 100644 --- a/internal/distributed/datanode/client/client.go +++ b/internal/distributed/datanode/client/client.go @@ -255,3 +255,9 @@ func (c *Client) DropImport(ctx context.Context, req *datapb.DropImportRequest, return client.DropImport(ctx, req) }) } + +func (c *Client) QuerySlot(ctx context.Context, req *datapb.QuerySlotRequest, opts ...grpc.CallOption) (*datapb.QuerySlotResponse, error) { + return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*datapb.QuerySlotResponse, error) { + return client.QuerySlot(ctx, req) + }) +} diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go index 5bbf12224e577..492b7d490b915 100644 --- a/internal/distributed/datanode/service.go +++ b/internal/distributed/datanode/service.go @@ -402,3 +402,7 @@ func (s *Server) QueryImport(ctx context.Context, req *datapb.QueryImportRequest func (s *Server) DropImport(ctx context.Context, req *datapb.DropImportRequest) 
(*commonpb.Status, error) { return s.datanode.DropImport(ctx, req) } + +func (s *Server) QuerySlot(ctx context.Context, req *datapb.QuerySlotRequest) (*datapb.QuerySlotResponse, error) { + return s.datanode.QuerySlot(ctx, req) +} diff --git a/internal/distributed/datanode/service_test.go b/internal/distributed/datanode/service_test.go index b0200ebf996e8..0be3c5a493112 100644 --- a/internal/distributed/datanode/service_test.go +++ b/internal/distributed/datanode/service_test.go @@ -177,6 +177,10 @@ func (m *MockDataNode) DropImport(ctx context.Context, req *datapb.DropImportReq return m.status, m.err } +func (m *MockDataNode) QuerySlot(ctx context.Context, req *datapb.QuerySlotRequest) (*datapb.QuerySlotResponse, error) { + return &datapb.QuerySlotResponse{}, m.err +} + // ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// func Test_NewServer(t *testing.T) { paramtable.Init() diff --git a/internal/metastore/kv/datacoord/kv_catalog.go b/internal/metastore/kv/datacoord/kv_catalog.go index a7cae134c9ac3..f6b4190f75b38 100644 --- a/internal/metastore/kv/datacoord/kv_catalog.go +++ b/internal/metastore/kv/datacoord/kv_catalog.go @@ -45,8 +45,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) -var maxEtcdTxnNum = 128 - var paginationSize = 2000 type Catalog struct { @@ -341,32 +339,10 @@ func (kc *Catalog) SaveByBatch(kvs map[string]string) error { saveFn := func(partialKvs map[string]string) error { return kc.MetaKv.MultiSave(partialKvs) } - if len(kvs) <= maxEtcdTxnNum { - if err := etcd.SaveByBatch(kvs, saveFn); err != nil { - log.Error("failed to save by batch", zap.Error(err)) - return err - } - } else { - // Split kvs into multiple operations to avoid over-sized operations. - // Also make sure kvs of the same segment are not split into different operations. 
- batch := make(map[string]string) - for k, v := range kvs { - if len(batch) == maxEtcdTxnNum { - if err := etcd.SaveByBatch(batch, saveFn); err != nil { - log.Error("failed to save by batch", zap.Error(err)) - return err - } - maps.Clear(batch) - } - batch[k] = v - } - - if len(batch) > 0 { - if err := etcd.SaveByBatch(batch, saveFn); err != nil { - log.Error("failed to save by batch", zap.Error(err)) - return err - } - } + err := etcd.SaveByBatchWithLimit(kvs, util.MaxEtcdTxnNum, saveFn) + if err != nil { + log.Error("failed to save by batch", zap.Error(err)) + return err } return nil } @@ -434,7 +410,7 @@ func (kc *Catalog) SaveDroppedSegmentsInBatch(ctx context.Context, segments []*d saveFn := func(partialKvs map[string]string) error { return kc.MetaKv.MultiSave(partialKvs) } - if err := etcd.SaveByBatch(kvs, saveFn); err != nil { + if err := etcd.SaveByBatchWithLimit(kvs, util.MaxEtcdTxnNum, saveFn); err != nil { return err } diff --git a/internal/metastore/kv/rootcoord/kv_catalog.go b/internal/metastore/kv/rootcoord/kv_catalog.go index 7bf240c0b5fc9..9edcfe13f6be5 100644 --- a/internal/metastore/kv/rootcoord/kv_catalog.go +++ b/internal/metastore/kv/rootcoord/kv_catalog.go @@ -27,10 +27,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) -const ( - maxTxnNum = 64 -) - // prefix/collection/collection_id -> CollectionInfo // prefix/partitions/collection_id/partition_id -> PartitionInfo // prefix/aliases/alias_name -> AliasInfo @@ -87,11 +83,13 @@ func BuildAliasPrefixWithDB(dbID int64) string { return fmt.Sprintf("%s/%s/%d", DatabaseMetaPrefix, Aliases, dbID) } -func batchMultiSaveAndRemoveWithPrefix(snapshot kv.SnapShotKV, maxTxnNum int, saves map[string]string, removals []string, ts typeutil.Timestamp) error { +// since SnapshotKV may save both snapshot key and the original key if the original key is newest +// MaxEtcdTxnNum need to divided by 2 +func batchMultiSaveAndRemoveWithPrefix(snapshot kv.SnapShotKV, limit int, saves map[string]string, 
removals []string, ts typeutil.Timestamp) error { saveFn := func(partialKvs map[string]string) error { return snapshot.MultiSave(partialKvs, ts) } - if err := etcd.SaveByBatchWithLimit(saves, maxTxnNum, saveFn); err != nil { + if err := etcd.SaveByBatchWithLimit(saves, limit, saveFn); err != nil { return err } @@ -104,7 +102,7 @@ func batchMultiSaveAndRemoveWithPrefix(snapshot kv.SnapShotKV, maxTxnNum int, sa removeFn := func(partialKeys []string) error { return snapshot.MultiSaveAndRemoveWithPrefix(nil, partialKeys, ts) } - return etcd.RemoveByBatchWithLimit(removals, maxTxnNum/2, removeFn) + return etcd.RemoveByBatchWithLimit(removals, limit, removeFn) } func (kc *Catalog) CreateDatabase(ctx context.Context, db *model.Database, ts typeutil.Timestamp) error { @@ -200,7 +198,9 @@ func (kc *Catalog) CreateCollection(ctx context.Context, coll *model.Collection, // Though batchSave is not atomic enough, we can promise the atomicity outside. // Recovering from failure, if we found collection is creating, we should remove all these related meta. - return etcd.SaveByBatchWithLimit(kvs, maxTxnNum/2, func(partialKvs map[string]string) error { + // since SnapshotKV may save both snapshot key and the original key if the original key is newest + // MaxEtcdTxnNum need to divided by 2 + return etcd.SaveByBatchWithLimit(kvs, util.MaxEtcdTxnNum/2, func(partialKvs map[string]string) error { return kc.Snapshot.MultiSave(partialKvs, ts) }) } @@ -453,9 +453,9 @@ func (kc *Catalog) DropCollection(ctx context.Context, collectionInfo *model.Col // Though batchMultiSaveAndRemoveWithPrefix is not atomic enough, we can promise atomicity outside. // If we found collection under dropping state, we'll know that gc is not completely on this collection. // However, if we remove collection first, we cannot remove other metas. - // We set maxTxnNum to 64, since SnapshotKV may save both snapshot key and the original key if the original key is - // newest. 
- if err := batchMultiSaveAndRemoveWithPrefix(kc.Snapshot, maxTxnNum, nil, delMetakeysSnap, ts); err != nil { + // since SnapshotKV may save both snapshot key and the original key if the original key is newest + // MaxEtcdTxnNum need to divided by 2 + if err := batchMultiSaveAndRemoveWithPrefix(kc.Snapshot, util.MaxEtcdTxnNum/2, nil, delMetakeysSnap, ts); err != nil { return err } diff --git a/internal/metastore/kv/rootcoord/kv_catalog_test.go b/internal/metastore/kv/rootcoord/kv_catalog_test.go index b5c4502eb6729..7523c821677d5 100644 --- a/internal/metastore/kv/rootcoord/kv_catalog_test.go +++ b/internal/metastore/kv/rootcoord/kv_catalog_test.go @@ -949,7 +949,7 @@ func Test_batchMultiSaveAndRemoveWithPrefix(t *testing.T) { return errors.New("error mock MultiSave") } saves := map[string]string{"k": "v"} - err := batchMultiSaveAndRemoveWithPrefix(snapshot, maxTxnNum, saves, []string{}, 0) + err := batchMultiSaveAndRemoveWithPrefix(snapshot, util.MaxEtcdTxnNum/2, saves, []string{}, 0) assert.Error(t, err) }) t.Run("failed to remove", func(t *testing.T) { @@ -962,7 +962,7 @@ func Test_batchMultiSaveAndRemoveWithPrefix(t *testing.T) { } saves := map[string]string{"k": "v"} removals := []string{"prefix1", "prefix2"} - err := batchMultiSaveAndRemoveWithPrefix(snapshot, maxTxnNum, saves, removals, 0) + err := batchMultiSaveAndRemoveWithPrefix(snapshot, util.MaxEtcdTxnNum/2, saves, removals, 0) assert.Error(t, err) }) t.Run("normal case", func(t *testing.T) { @@ -983,7 +983,7 @@ func Test_batchMultiSaveAndRemoveWithPrefix(t *testing.T) { saves[fmt.Sprintf("k%d", i)] = fmt.Sprintf("v%d", i) removals = append(removals, fmt.Sprintf("k%d", i)) } - err := batchMultiSaveAndRemoveWithPrefix(snapshot, 64, saves, removals, 0) + err := batchMultiSaveAndRemoveWithPrefix(snapshot, util.MaxEtcdTxnNum/2, saves, removals, 0) assert.NoError(t, err) }) } diff --git a/internal/metastore/kv/rootcoord/suffix_snapshot.go b/internal/metastore/kv/rootcoord/suffix_snapshot.go index 
832bfd45d9aa1..f945dc958d3b7 100644 --- a/internal/metastore/kv/rootcoord/suffix_snapshot.go +++ b/internal/metastore/kv/rootcoord/suffix_snapshot.go @@ -33,6 +33,7 @@ import ( "github.com/milvus-io/milvus/internal/kv" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/tsoutil" @@ -596,7 +597,7 @@ func (ss *SuffixSnapshot) batchRemoveExpiredKvs(keyGroup []string, originalKey s removeFn := func(partialKeys []string) error { return ss.MetaKv.MultiRemove(keyGroup) } - return etcd.RemoveByBatch(keyGroup, removeFn) + return etcd.RemoveByBatchWithLimit(keyGroup, util.MaxEtcdTxnNum, removeFn) } func (ss *SuffixSnapshot) removeExpiredKvs(now time.Time) error { diff --git a/internal/mocks/mock_datanode.go b/internal/mocks/mock_datanode.go index 1a6b281c933b4..3392028c1bd48 100644 --- a/internal/mocks/mock_datanode.go +++ b/internal/mocks/mock_datanode.go @@ -64,8 +64,8 @@ type MockDataNode_CheckChannelOperationProgress_Call struct { } // CheckChannelOperationProgress is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 *datapb.ChannelWatchInfo +// - _a0 context.Context +// - _a1 *datapb.ChannelWatchInfo func (_e *MockDataNode_Expecter) CheckChannelOperationProgress(_a0 interface{}, _a1 interface{}) *MockDataNode_CheckChannelOperationProgress_Call { return &MockDataNode_CheckChannelOperationProgress_Call{Call: _e.mock.On("CheckChannelOperationProgress", _a0, _a1)} } @@ -119,8 +119,8 @@ type MockDataNode_Compaction_Call struct { } // Compaction is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 *datapb.CompactionPlan +// - _a0 context.Context +// - _a1 *datapb.CompactionPlan func (_e *MockDataNode_Expecter) Compaction(_a0 interface{}, _a1 interface{}) *MockDataNode_Compaction_Call { return &MockDataNode_Compaction_Call{Call: 
_e.mock.On("Compaction", _a0, _a1)} } @@ -174,8 +174,8 @@ type MockDataNode_DropImport_Call struct { } // DropImport is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 *datapb.DropImportRequest +// - _a0 context.Context +// - _a1 *datapb.DropImportRequest func (_e *MockDataNode_Expecter) DropImport(_a0 interface{}, _a1 interface{}) *MockDataNode_DropImport_Call { return &MockDataNode_DropImport_Call{Call: _e.mock.On("DropImport", _a0, _a1)} } @@ -229,8 +229,8 @@ type MockDataNode_FlushChannels_Call struct { } // FlushChannels is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 *datapb.FlushChannelsRequest +// - _a0 context.Context +// - _a1 *datapb.FlushChannelsRequest func (_e *MockDataNode_Expecter) FlushChannels(_a0 interface{}, _a1 interface{}) *MockDataNode_FlushChannels_Call { return &MockDataNode_FlushChannels_Call{Call: _e.mock.On("FlushChannels", _a0, _a1)} } @@ -284,8 +284,8 @@ type MockDataNode_FlushSegments_Call struct { } // FlushSegments is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 *datapb.FlushSegmentsRequest +// - _a0 context.Context +// - _a1 *datapb.FlushSegmentsRequest func (_e *MockDataNode_Expecter) FlushSegments(_a0 interface{}, _a1 interface{}) *MockDataNode_FlushSegments_Call { return &MockDataNode_FlushSegments_Call{Call: _e.mock.On("FlushSegments", _a0, _a1)} } @@ -380,8 +380,8 @@ type MockDataNode_GetCompactionState_Call struct { } // GetCompactionState is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 *datapb.CompactionStateRequest +// - _a0 context.Context +// - _a1 *datapb.CompactionStateRequest func (_e *MockDataNode_Expecter) GetCompactionState(_a0 interface{}, _a1 interface{}) *MockDataNode_GetCompactionState_Call { return &MockDataNode_GetCompactionState_Call{Call: _e.mock.On("GetCompactionState", _a0, _a1)} } @@ -435,8 +435,8 @@ type MockDataNode_GetComponentStates_Call struct { } // GetComponentStates is a helper 
method to define mock.On call -// - _a0 context.Context -// - _a1 *milvuspb.GetComponentStatesRequest +// - _a0 context.Context +// - _a1 *milvuspb.GetComponentStatesRequest func (_e *MockDataNode_Expecter) GetComponentStates(_a0 interface{}, _a1 interface{}) *MockDataNode_GetComponentStates_Call { return &MockDataNode_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", _a0, _a1)} } @@ -490,8 +490,8 @@ type MockDataNode_GetMetrics_Call struct { } // GetMetrics is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 *milvuspb.GetMetricsRequest +// - _a0 context.Context +// - _a1 *milvuspb.GetMetricsRequest func (_e *MockDataNode_Expecter) GetMetrics(_a0 interface{}, _a1 interface{}) *MockDataNode_GetMetrics_Call { return &MockDataNode_GetMetrics_Call{Call: _e.mock.On("GetMetrics", _a0, _a1)} } @@ -627,8 +627,8 @@ type MockDataNode_GetStatisticsChannel_Call struct { } // GetStatisticsChannel is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 *internalpb.GetStatisticsChannelRequest +// - _a0 context.Context +// - _a1 *internalpb.GetStatisticsChannelRequest func (_e *MockDataNode_Expecter) GetStatisticsChannel(_a0 interface{}, _a1 interface{}) *MockDataNode_GetStatisticsChannel_Call { return &MockDataNode_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", _a0, _a1)} } @@ -682,8 +682,8 @@ type MockDataNode_ImportV2_Call struct { } // ImportV2 is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 *datapb.ImportRequest +// - _a0 context.Context +// - _a1 *datapb.ImportRequest func (_e *MockDataNode_Expecter) ImportV2(_a0 interface{}, _a1 interface{}) *MockDataNode_ImportV2_Call { return &MockDataNode_ImportV2_Call{Call: _e.mock.On("ImportV2", _a0, _a1)} } @@ -778,8 +778,8 @@ type MockDataNode_NotifyChannelOperation_Call struct { } // NotifyChannelOperation is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 *datapb.ChannelOperationsRequest +// - 
_a0 context.Context +// - _a1 *datapb.ChannelOperationsRequest func (_e *MockDataNode_Expecter) NotifyChannelOperation(_a0 interface{}, _a1 interface{}) *MockDataNode_NotifyChannelOperation_Call { return &MockDataNode_NotifyChannelOperation_Call{Call: _e.mock.On("NotifyChannelOperation", _a0, _a1)} } @@ -833,8 +833,8 @@ type MockDataNode_PreImport_Call struct { } // PreImport is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 *datapb.PreImportRequest +// - _a0 context.Context +// - _a1 *datapb.PreImportRequest func (_e *MockDataNode_Expecter) PreImport(_a0 interface{}, _a1 interface{}) *MockDataNode_PreImport_Call { return &MockDataNode_PreImport_Call{Call: _e.mock.On("PreImport", _a0, _a1)} } @@ -888,8 +888,8 @@ type MockDataNode_QueryImport_Call struct { } // QueryImport is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 *datapb.QueryImportRequest +// - _a0 context.Context +// - _a1 *datapb.QueryImportRequest func (_e *MockDataNode_Expecter) QueryImport(_a0 interface{}, _a1 interface{}) *MockDataNode_QueryImport_Call { return &MockDataNode_QueryImport_Call{Call: _e.mock.On("QueryImport", _a0, _a1)} } @@ -943,8 +943,8 @@ type MockDataNode_QueryPreImport_Call struct { } // QueryPreImport is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 *datapb.QueryPreImportRequest +// - _a0 context.Context +// - _a1 *datapb.QueryPreImportRequest func (_e *MockDataNode_Expecter) QueryPreImport(_a0 interface{}, _a1 interface{}) *MockDataNode_QueryPreImport_Call { return &MockDataNode_QueryPreImport_Call{Call: _e.mock.On("QueryPreImport", _a0, _a1)} } @@ -966,6 +966,61 @@ func (_c *MockDataNode_QueryPreImport_Call) RunAndReturn(run func(context.Contex return _c } +// QuerySlot provides a mock function with given fields: _a0, _a1 +func (_m *MockDataNode) QuerySlot(_a0 context.Context, _a1 *datapb.QuerySlotRequest) (*datapb.QuerySlotResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 
*datapb.QuerySlotResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.QuerySlotRequest) (*datapb.QuerySlotResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.QuerySlotRequest) *datapb.QuerySlotResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.QuerySlotResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.QuerySlotRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNode_QuerySlot_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QuerySlot' +type MockDataNode_QuerySlot_Call struct { + *mock.Call +} + +// QuerySlot is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *datapb.QuerySlotRequest +func (_e *MockDataNode_Expecter) QuerySlot(_a0 interface{}, _a1 interface{}) *MockDataNode_QuerySlot_Call { + return &MockDataNode_QuerySlot_Call{Call: _e.mock.On("QuerySlot", _a0, _a1)} +} + +func (_c *MockDataNode_QuerySlot_Call) Run(run func(_a0 context.Context, _a1 *datapb.QuerySlotRequest)) *MockDataNode_QuerySlot_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*datapb.QuerySlotRequest)) + }) + return _c +} + +func (_c *MockDataNode_QuerySlot_Call) Return(_a0 *datapb.QuerySlotResponse, _a1 error) *MockDataNode_QuerySlot_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNode_QuerySlot_Call) RunAndReturn(run func(context.Context, *datapb.QuerySlotRequest) (*datapb.QuerySlotResponse, error)) *MockDataNode_QuerySlot_Call { + _c.Call.Return(run) + return _c +} + // Register provides a mock function with given fields: func (_m *MockDataNode) Register() error { ret := _m.Called() @@ -1039,8 +1094,8 @@ type MockDataNode_ResendSegmentStats_Call struct { } // ResendSegmentStats is a helper method to define mock.On call -// - _a0 context.Context 
-// - _a1 *datapb.ResendSegmentStatsRequest +// - _a0 context.Context +// - _a1 *datapb.ResendSegmentStatsRequest func (_e *MockDataNode_Expecter) ResendSegmentStats(_a0 interface{}, _a1 interface{}) *MockDataNode_ResendSegmentStats_Call { return &MockDataNode_ResendSegmentStats_Call{Call: _e.mock.On("ResendSegmentStats", _a0, _a1)} } @@ -1073,7 +1128,7 @@ type MockDataNode_SetAddress_Call struct { } // SetAddress is a helper method to define mock.On call -// - address string +// - address string func (_e *MockDataNode_Expecter) SetAddress(address interface{}) *MockDataNode_SetAddress_Call { return &MockDataNode_SetAddress_Call{Call: _e.mock.On("SetAddress", address)} } @@ -1115,7 +1170,7 @@ type MockDataNode_SetDataCoordClient_Call struct { } // SetDataCoordClient is a helper method to define mock.On call -// - dataCoord types.DataCoordClient +// - dataCoord types.DataCoordClient func (_e *MockDataNode_Expecter) SetDataCoordClient(dataCoord interface{}) *MockDataNode_SetDataCoordClient_Call { return &MockDataNode_SetDataCoordClient_Call{Call: _e.mock.On("SetDataCoordClient", dataCoord)} } @@ -1148,7 +1203,7 @@ type MockDataNode_SetEtcdClient_Call struct { } // SetEtcdClient is a helper method to define mock.On call -// - etcdClient *clientv3.Client +// - etcdClient *clientv3.Client func (_e *MockDataNode_Expecter) SetEtcdClient(etcdClient interface{}) *MockDataNode_SetEtcdClient_Call { return &MockDataNode_SetEtcdClient_Call{Call: _e.mock.On("SetEtcdClient", etcdClient)} } @@ -1190,7 +1245,7 @@ type MockDataNode_SetRootCoordClient_Call struct { } // SetRootCoordClient is a helper method to define mock.On call -// - rootCoord types.RootCoordClient +// - rootCoord types.RootCoordClient func (_e *MockDataNode_Expecter) SetRootCoordClient(rootCoord interface{}) *MockDataNode_SetRootCoordClient_Call { return &MockDataNode_SetRootCoordClient_Call{Call: _e.mock.On("SetRootCoordClient", rootCoord)} } @@ -1244,8 +1299,8 @@ type MockDataNode_ShowConfigurations_Call struct { 
} // ShowConfigurations is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 *internalpb.ShowConfigurationsRequest +// - _a0 context.Context +// - _a1 *internalpb.ShowConfigurationsRequest func (_e *MockDataNode_Expecter) ShowConfigurations(_a0 interface{}, _a1 interface{}) *MockDataNode_ShowConfigurations_Call { return &MockDataNode_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", _a0, _a1)} } @@ -1381,8 +1436,8 @@ type MockDataNode_SyncSegments_Call struct { } // SyncSegments is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 *datapb.SyncSegmentsRequest +// - _a0 context.Context +// - _a1 *datapb.SyncSegmentsRequest func (_e *MockDataNode_Expecter) SyncSegments(_a0 interface{}, _a1 interface{}) *MockDataNode_SyncSegments_Call { return &MockDataNode_SyncSegments_Call{Call: _e.mock.On("SyncSegments", _a0, _a1)} } @@ -1415,7 +1470,7 @@ type MockDataNode_UpdateStateCode_Call struct { } // UpdateStateCode is a helper method to define mock.On call -// - stateCode commonpb.StateCode +// - stateCode commonpb.StateCode func (_e *MockDataNode_Expecter) UpdateStateCode(stateCode interface{}) *MockDataNode_UpdateStateCode_Call { return &MockDataNode_UpdateStateCode_Call{Call: _e.mock.On("UpdateStateCode", stateCode)} } @@ -1469,8 +1524,8 @@ type MockDataNode_WatchDmChannels_Call struct { } // WatchDmChannels is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 *datapb.WatchDmChannelsRequest +// - _a0 context.Context +// - _a1 *datapb.WatchDmChannelsRequest func (_e *MockDataNode_Expecter) WatchDmChannels(_a0 interface{}, _a1 interface{}) *MockDataNode_WatchDmChannels_Call { return &MockDataNode_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", _a0, _a1)} } diff --git a/internal/mocks/mock_datanode_client.go b/internal/mocks/mock_datanode_client.go index ead0d9136b850..78f7aeec32131 100644 --- a/internal/mocks/mock_datanode_client.go +++ b/internal/mocks/mock_datanode_client.go 
@@ -70,9 +70,9 @@ type MockDataNodeClient_CheckChannelOperationProgress_Call struct { } // CheckChannelOperationProgress is a helper method to define mock.On call -// - ctx context.Context -// - in *datapb.ChannelWatchInfo -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *datapb.ChannelWatchInfo +// - opts ...grpc.CallOption func (_e *MockDataNodeClient_Expecter) CheckChannelOperationProgress(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_CheckChannelOperationProgress_Call { return &MockDataNodeClient_CheckChannelOperationProgress_Call{Call: _e.mock.On("CheckChannelOperationProgress", append([]interface{}{ctx, in}, opts...)...)} @@ -181,9 +181,9 @@ type MockDataNodeClient_Compaction_Call struct { } // Compaction is a helper method to define mock.On call -// - ctx context.Context -// - in *datapb.CompactionPlan -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *datapb.CompactionPlan +// - opts ...grpc.CallOption func (_e *MockDataNodeClient_Expecter) Compaction(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_Compaction_Call { return &MockDataNodeClient_Compaction_Call{Call: _e.mock.On("Compaction", append([]interface{}{ctx, in}, opts...)...)} @@ -251,9 +251,9 @@ type MockDataNodeClient_DropImport_Call struct { } // DropImport is a helper method to define mock.On call -// - ctx context.Context -// - in *datapb.DropImportRequest -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *datapb.DropImportRequest +// - opts ...grpc.CallOption func (_e *MockDataNodeClient_Expecter) DropImport(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_DropImport_Call { return &MockDataNodeClient_DropImport_Call{Call: _e.mock.On("DropImport", append([]interface{}{ctx, in}, opts...)...)} @@ -321,9 +321,9 @@ type MockDataNodeClient_FlushChannels_Call struct { } // FlushChannels is a helper method to define mock.On call -// - ctx context.Context -// - in 
*datapb.FlushChannelsRequest -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *datapb.FlushChannelsRequest +// - opts ...grpc.CallOption func (_e *MockDataNodeClient_Expecter) FlushChannels(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_FlushChannels_Call { return &MockDataNodeClient_FlushChannels_Call{Call: _e.mock.On("FlushChannels", append([]interface{}{ctx, in}, opts...)...)} @@ -391,9 +391,9 @@ type MockDataNodeClient_FlushSegments_Call struct { } // FlushSegments is a helper method to define mock.On call -// - ctx context.Context -// - in *datapb.FlushSegmentsRequest -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *datapb.FlushSegmentsRequest +// - opts ...grpc.CallOption func (_e *MockDataNodeClient_Expecter) FlushSegments(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_FlushSegments_Call { return &MockDataNodeClient_FlushSegments_Call{Call: _e.mock.On("FlushSegments", append([]interface{}{ctx, in}, opts...)...)} @@ -461,9 +461,9 @@ type MockDataNodeClient_GetCompactionState_Call struct { } // GetCompactionState is a helper method to define mock.On call -// - ctx context.Context -// - in *datapb.CompactionStateRequest -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *datapb.CompactionStateRequest +// - opts ...grpc.CallOption func (_e *MockDataNodeClient_Expecter) GetCompactionState(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_GetCompactionState_Call { return &MockDataNodeClient_GetCompactionState_Call{Call: _e.mock.On("GetCompactionState", append([]interface{}{ctx, in}, opts...)...)} @@ -531,9 +531,9 @@ type MockDataNodeClient_GetComponentStates_Call struct { } // GetComponentStates is a helper method to define mock.On call -// - ctx context.Context -// - in *milvuspb.GetComponentStatesRequest -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *milvuspb.GetComponentStatesRequest +// - opts ...grpc.CallOption func 
(_e *MockDataNodeClient_Expecter) GetComponentStates(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_GetComponentStates_Call { return &MockDataNodeClient_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", append([]interface{}{ctx, in}, opts...)...)} @@ -601,9 +601,9 @@ type MockDataNodeClient_GetMetrics_Call struct { } // GetMetrics is a helper method to define mock.On call -// - ctx context.Context -// - in *milvuspb.GetMetricsRequest -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *milvuspb.GetMetricsRequest +// - opts ...grpc.CallOption func (_e *MockDataNodeClient_Expecter) GetMetrics(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_GetMetrics_Call { return &MockDataNodeClient_GetMetrics_Call{Call: _e.mock.On("GetMetrics", append([]interface{}{ctx, in}, opts...)...)} @@ -671,9 +671,9 @@ type MockDataNodeClient_GetStatisticsChannel_Call struct { } // GetStatisticsChannel is a helper method to define mock.On call -// - ctx context.Context -// - in *internalpb.GetStatisticsChannelRequest -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *internalpb.GetStatisticsChannelRequest +// - opts ...grpc.CallOption func (_e *MockDataNodeClient_Expecter) GetStatisticsChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_GetStatisticsChannel_Call { return &MockDataNodeClient_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", append([]interface{}{ctx, in}, opts...)...)} @@ -741,9 +741,9 @@ type MockDataNodeClient_ImportV2_Call struct { } // ImportV2 is a helper method to define mock.On call -// - ctx context.Context -// - in *datapb.ImportRequest -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *datapb.ImportRequest +// - opts ...grpc.CallOption func (_e *MockDataNodeClient_Expecter) ImportV2(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_ImportV2_Call { return 
&MockDataNodeClient_ImportV2_Call{Call: _e.mock.On("ImportV2", append([]interface{}{ctx, in}, opts...)...)} @@ -811,9 +811,9 @@ type MockDataNodeClient_NotifyChannelOperation_Call struct { } // NotifyChannelOperation is a helper method to define mock.On call -// - ctx context.Context -// - in *datapb.ChannelOperationsRequest -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *datapb.ChannelOperationsRequest +// - opts ...grpc.CallOption func (_e *MockDataNodeClient_Expecter) NotifyChannelOperation(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_NotifyChannelOperation_Call { return &MockDataNodeClient_NotifyChannelOperation_Call{Call: _e.mock.On("NotifyChannelOperation", append([]interface{}{ctx, in}, opts...)...)} @@ -881,9 +881,9 @@ type MockDataNodeClient_PreImport_Call struct { } // PreImport is a helper method to define mock.On call -// - ctx context.Context -// - in *datapb.PreImportRequest -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *datapb.PreImportRequest +// - opts ...grpc.CallOption func (_e *MockDataNodeClient_Expecter) PreImport(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_PreImport_Call { return &MockDataNodeClient_PreImport_Call{Call: _e.mock.On("PreImport", append([]interface{}{ctx, in}, opts...)...)} @@ -951,9 +951,9 @@ type MockDataNodeClient_QueryImport_Call struct { } // QueryImport is a helper method to define mock.On call -// - ctx context.Context -// - in *datapb.QueryImportRequest -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *datapb.QueryImportRequest +// - opts ...grpc.CallOption func (_e *MockDataNodeClient_Expecter) QueryImport(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_QueryImport_Call { return &MockDataNodeClient_QueryImport_Call{Call: _e.mock.On("QueryImport", append([]interface{}{ctx, in}, opts...)...)} @@ -1021,9 +1021,9 @@ type MockDataNodeClient_QueryPreImport_Call struct { } // QueryPreImport 
is a helper method to define mock.On call -// - ctx context.Context -// - in *datapb.QueryPreImportRequest -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *datapb.QueryPreImportRequest +// - opts ...grpc.CallOption func (_e *MockDataNodeClient_Expecter) QueryPreImport(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_QueryPreImport_Call { return &MockDataNodeClient_QueryPreImport_Call{Call: _e.mock.On("QueryPreImport", append([]interface{}{ctx, in}, opts...)...)} @@ -1052,6 +1052,76 @@ func (_c *MockDataNodeClient_QueryPreImport_Call) RunAndReturn(run func(context. return _c } +// QuerySlot provides a mock function with given fields: ctx, in, opts +func (_m *MockDataNodeClient) QuerySlot(ctx context.Context, in *datapb.QuerySlotRequest, opts ...grpc.CallOption) (*datapb.QuerySlotResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.QuerySlotResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.QuerySlotRequest, ...grpc.CallOption) (*datapb.QuerySlotResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.QuerySlotRequest, ...grpc.CallOption) *datapb.QuerySlotResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.QuerySlotResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.QuerySlotRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNodeClient_QuerySlot_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QuerySlot' +type MockDataNodeClient_QuerySlot_Call struct { + *mock.Call +} + +// QuerySlot is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.QuerySlotRequest +// - opts ...grpc.CallOption +func (_e *MockDataNodeClient_Expecter) QuerySlot(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_QuerySlot_Call { + return &MockDataNodeClient_QuerySlot_Call{Call: _e.mock.On("QuerySlot", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataNodeClient_QuerySlot_Call) Run(run func(ctx context.Context, in *datapb.QuerySlotRequest, opts ...grpc.CallOption)) *MockDataNodeClient_QuerySlot_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.QuerySlotRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataNodeClient_QuerySlot_Call) Return(_a0 *datapb.QuerySlotResponse, _a1 error) *MockDataNodeClient_QuerySlot_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNodeClient_QuerySlot_Call) RunAndReturn(run func(context.Context, *datapb.QuerySlotRequest, ...grpc.CallOption) (*datapb.QuerySlotResponse, error)) *MockDataNodeClient_QuerySlot_Call { + _c.Call.Return(run) + return _c +} + // ResendSegmentStats provides a mock function with given fields: ctx, in, opts func (_m *MockDataNodeClient) ResendSegmentStats(ctx context.Context, in *datapb.ResendSegmentStatsRequest, opts ...grpc.CallOption) (*datapb.ResendSegmentStatsResponse, error) { _va := make([]interface{}, len(opts)) @@ -1091,9 +1161,9 @@ type MockDataNodeClient_ResendSegmentStats_Call struct { } // ResendSegmentStats is a helper method to define mock.On call -// - ctx context.Context -// - in *datapb.ResendSegmentStatsRequest -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *datapb.ResendSegmentStatsRequest +// - opts ...grpc.CallOption func (_e *MockDataNodeClient_Expecter) ResendSegmentStats(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_ResendSegmentStats_Call { return &MockDataNodeClient_ResendSegmentStats_Call{Call: _e.mock.On("ResendSegmentStats", append([]interface{}{ctx, in}, opts...)...)} @@ -1161,9 +1231,9 @@ type MockDataNodeClient_ShowConfigurations_Call struct { } // ShowConfigurations is a helper method to define mock.On call -// - ctx context.Context -// - in *internalpb.ShowConfigurationsRequest -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *internalpb.ShowConfigurationsRequest +// - opts ...grpc.CallOption func (_e *MockDataNodeClient_Expecter) ShowConfigurations(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_ShowConfigurations_Call { return &MockDataNodeClient_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", append([]interface{}{ctx, 
in}, opts...)...)} @@ -1231,9 +1301,9 @@ type MockDataNodeClient_SyncSegments_Call struct { } // SyncSegments is a helper method to define mock.On call -// - ctx context.Context -// - in *datapb.SyncSegmentsRequest -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *datapb.SyncSegmentsRequest +// - opts ...grpc.CallOption func (_e *MockDataNodeClient_Expecter) SyncSegments(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_SyncSegments_Call { return &MockDataNodeClient_SyncSegments_Call{Call: _e.mock.On("SyncSegments", append([]interface{}{ctx, in}, opts...)...)} @@ -1301,9 +1371,9 @@ type MockDataNodeClient_WatchDmChannels_Call struct { } // WatchDmChannels is a helper method to define mock.On call -// - ctx context.Context -// - in *datapb.WatchDmChannelsRequest -// - opts ...grpc.CallOption +// - ctx context.Context +// - in *datapb.WatchDmChannelsRequest +// - opts ...grpc.CallOption func (_e *MockDataNodeClient_Expecter) WatchDmChannels(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_WatchDmChannels_Call { return &MockDataNodeClient_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", append([]interface{}{ctx, in}, opts...)...)} diff --git a/internal/proto/data_coord.proto b/internal/proto/data_coord.proto index da8a36597ae16..ecb29e3be162c 100644 --- a/internal/proto/data_coord.proto +++ b/internal/proto/data_coord.proto @@ -127,6 +127,8 @@ service DataNode { rpc QueryPreImport(QueryPreImportRequest) returns(QueryPreImportResponse) {} rpc QueryImport(QueryImportRequest) returns(QueryImportResponse) {} rpc DropImport(DropImportRequest) returns(common.Status) {} + + rpc QuerySlot(QuerySlotRequest) returns(QuerySlotResponse) {} } message FlushRequest { @@ -832,3 +834,10 @@ message GcControlRequest { GcCommand command = 2; repeated common.KeyValuePair params = 3; } + +message QuerySlotRequest {} + +message QuerySlotResponse { + common.Status status = 1; + int64 num_slots = 2; +} diff --git 
a/internal/proxy/validate_util.go b/internal/proxy/validate_util.go index f1807a9c170a1..53c6b89f86f08 100644 --- a/internal/proxy/validate_util.go +++ b/internal/proxy/validate_util.go @@ -759,6 +759,9 @@ func (v *validateUtil) checkArrayElement(array *schemapb.ArrayArray, field *sche switch field.GetElementType() { case schemapb.DataType_Bool: for _, row := range array.GetData() { + if row.GetData() == nil { + return merr.WrapErrParameterInvalid("bool array", "nil array", "insert data does not match") + } actualType := reflect.TypeOf(row.GetData()) if actualType != reflect.TypeOf((*schemapb.ScalarField_BoolData)(nil)) { return merr.WrapErrParameterInvalid("bool array", @@ -767,6 +770,9 @@ func (v *validateUtil) checkArrayElement(array *schemapb.ArrayArray, field *sche } case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32: for _, row := range array.GetData() { + if row.GetData() == nil { + return merr.WrapErrParameterInvalid("int array", "nil array", "insert data does not match") + } actualType := reflect.TypeOf(row.GetData()) if actualType != reflect.TypeOf((*schemapb.ScalarField_IntData)(nil)) { return merr.WrapErrParameterInvalid("int array", @@ -787,6 +793,9 @@ func (v *validateUtil) checkArrayElement(array *schemapb.ArrayArray, field *sche } case schemapb.DataType_Int64: for _, row := range array.GetData() { + if row.GetData() == nil { + return merr.WrapErrParameterInvalid("int64 array", "nil array", "insert data does not match") + } actualType := reflect.TypeOf(row.GetData()) if actualType != reflect.TypeOf((*schemapb.ScalarField_LongData)(nil)) { return merr.WrapErrParameterInvalid("int64 array", @@ -795,6 +804,9 @@ func (v *validateUtil) checkArrayElement(array *schemapb.ArrayArray, field *sche } case schemapb.DataType_Float: for _, row := range array.GetData() { + if row.GetData() == nil { + return merr.WrapErrParameterInvalid("float array", "nil array", "insert data does not match") + } actualType := 
reflect.TypeOf(row.GetData()) if actualType != reflect.TypeOf((*schemapb.ScalarField_FloatData)(nil)) { return merr.WrapErrParameterInvalid("float array", @@ -803,6 +815,9 @@ func (v *validateUtil) checkArrayElement(array *schemapb.ArrayArray, field *sche } case schemapb.DataType_Double: for _, row := range array.GetData() { + if row.GetData() == nil { + return merr.WrapErrParameterInvalid("double array", "nil array", "insert data does not match") + } actualType := reflect.TypeOf(row.GetData()) if actualType != reflect.TypeOf((*schemapb.ScalarField_DoubleData)(nil)) { return merr.WrapErrParameterInvalid("double array", @@ -811,6 +826,9 @@ func (v *validateUtil) checkArrayElement(array *schemapb.ArrayArray, field *sche } case schemapb.DataType_VarChar, schemapb.DataType_String: for _, row := range array.GetData() { + if row.GetData() == nil { + return merr.WrapErrParameterInvalid("string array", "nil array", "insert data does not match") + } actualType := reflect.TypeOf(row.GetData()) if actualType != reflect.TypeOf((*schemapb.ScalarField_StringData)(nil)) { return merr.WrapErrParameterInvalid("string array", diff --git a/internal/proxy/validate_util_test.go b/internal/proxy/validate_util_test.go index 2d02cf7a24dc9..5c4079dbe171d 100644 --- a/internal/proxy/validate_util_test.go +++ b/internal/proxy/validate_util_test.go @@ -6026,3 +6026,19 @@ func Test_validateUtil_checkDoubleFieldData(t *testing.T) { }, }, nil)) } + +func TestCheckArrayElementNilData(t *testing.T) { + data := &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{nil}, + } + + fieldSchema := &schemapb.FieldSchema{ + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int64, + } + + v := newValidateUtil() + err := v.checkArrayElement(data, fieldSchema) + assert.True(t, merr.ErrParameterInvalid.Is(err)) +} diff --git a/internal/querycoordv2/balance/channel_level_score_balancer.go b/internal/querycoordv2/balance/channel_level_score_balancer.go index 
5e5e69d7c4f3e..cb59eb67a15ae 100644 --- a/internal/querycoordv2/balance/channel_level_score_balancer.go +++ b/internal/querycoordv2/balance/channel_level_score_balancer.go @@ -77,68 +77,51 @@ func (b *ChannelLevelScoreBalancer) BalanceReplica(replica *meta.Replica) ([]Seg return nil, nil } - onlineNodes := make([]int64, 0) - offlineNodes := make([]int64, 0) - // read only nodes is offline in current replica. - if replica.RONodesCount() > 0 { - // if node is stop or transfer to other rg - log.RatedInfo(10, "meet read only node, try to move out all segment/channel", zap.Int64s("node", replica.GetRONodes())) - offlineNodes = append(offlineNodes, replica.GetRONodes()...) - } + rwNodes := replica.GetChannelRWNodes(channelName) + roNodes := replica.GetRONodes() // mark channel's outbound access node as offline - channelRWNode := typeutil.NewUniqueSet(replica.GetChannelRWNodes(channelName)...) + channelRWNode := typeutil.NewUniqueSet(rwNodes...) channelDist := b.dist.ChannelDistManager.GetByFilter(meta.WithChannelName2Channel(channelName), meta.WithReplica2Channel(replica)) for _, channel := range channelDist { if !channelRWNode.Contain(channel.Node) { - offlineNodes = append(offlineNodes, channel.Node) + roNodes = append(roNodes, channel.Node) } } segmentDist := b.dist.SegmentDistManager.GetByFilter(meta.WithChannel(channelName), meta.WithReplica(replica)) for _, segment := range segmentDist { if !channelRWNode.Contain(segment.Node) { - offlineNodes = append(offlineNodes, segment.Node) - } - } - - for nid := range channelRWNode { - if isStopping, err := b.nodeManager.IsStoppingNode(nid); err != nil { - log.Info("not existed node", zap.Int64("nid", nid), zap.Error(err)) - continue - } else if isStopping { - offlineNodes = append(offlineNodes, nid) - } else { - onlineNodes = append(onlineNodes, nid) + roNodes = append(roNodes, segment.Node) } } - if len(onlineNodes) == 0 { + if len(rwNodes) == 0 { // no available nodes to balance return nil, nil } - if len(offlineNodes) != 
0 { + if len(roNodes) != 0 { if !paramtable.Get().QueryCoordCfg.EnableStoppingBalance.GetAsBool() { - log.RatedInfo(10, "stopping balance is disabled!", zap.Int64s("stoppingNode", offlineNodes)) + log.RatedInfo(10, "stopping balance is disabled!", zap.Int64s("stoppingNode", roNodes)) return nil, nil } log.Info("Handle stopping nodes", - zap.Any("stopping nodes", offlineNodes), - zap.Any("available nodes", onlineNodes), + zap.Any("stopping nodes", roNodes), + zap.Any("available nodes", rwNodes), ) // handle stopped nodes here, have to assign segments on stopping nodes to nodes with the smallest score - channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, channelName, onlineNodes, offlineNodes)...) + channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, channelName, rwNodes, roNodes)...) if len(channelPlans) == 0 { - segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, channelName, onlineNodes, offlineNodes)...) + segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, channelName, rwNodes, roNodes)...) } } else { if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() { - channelPlans = append(channelPlans, b.genChannelPlan(replica, channelName, onlineNodes)...) + channelPlans = append(channelPlans, b.genChannelPlan(replica, channelName, rwNodes)...) } if len(channelPlans) == 0 { - segmentPlans = append(segmentPlans, b.genSegmentPlan(replica, channelName, onlineNodes)...) + segmentPlans = append(segmentPlans, b.genSegmentPlan(replica, channelName, rwNodes)...) 
} } } diff --git a/internal/querycoordv2/balance/channel_level_score_balancer_test.go b/internal/querycoordv2/balance/channel_level_score_balancer_test.go index 87c0841c71ef0..219ee694d349a 100644 --- a/internal/querycoordv2/balance/channel_level_score_balancer_test.go +++ b/internal/querycoordv2/balance/channel_level_score_balancer_test.go @@ -1162,8 +1162,13 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Nod }, }...) - suite.balancer.nodeManager.Stopping(ch1Nodes[0]) - suite.balancer.nodeManager.Stopping(ch2Nodes[0]) + balancer.nodeManager.Stopping(ch1Nodes[0]) + balancer.nodeManager.Stopping(ch2Nodes[0]) + suite.balancer.meta.ResourceManager.HandleNodeStopping(ch1Nodes[0]) + suite.balancer.meta.ResourceManager.HandleNodeStopping(ch2Nodes[0]) + utils.RecoverAllCollection(balancer.meta) + + replica = balancer.meta.ReplicaManager.Get(replica.GetID()) sPlans, cPlans := balancer.BalanceReplica(replica) suite.Len(sPlans, 0) suite.Len(cPlans, 2) diff --git a/internal/querycoordv2/balance/multi_target_balance.go b/internal/querycoordv2/balance/multi_target_balance.go index cdab89ded2523..8874ee0bbb2ec 100644 --- a/internal/querycoordv2/balance/multi_target_balance.go +++ b/internal/querycoordv2/balance/multi_target_balance.go @@ -466,67 +466,49 @@ func (b *MultiTargetBalancer) BalanceReplica(replica *meta.Replica) ([]SegmentAs return nil, nil } - onlineNodes := make([]int64, 0) - offlineNodes := make([]int64, 0) + rwNodes := replica.GetRWNodes() + roNodes := replica.GetRONodes() - // read only nodes is offline in current replica. - if replica.RONodesCount() > 0 { - // if node is stop or transfer to other rg - log.RatedInfo(10, "meet read only node, try to move out all segment/channel", zap.Int64s("node", replica.GetRONodes())) - offlineNodes = append(offlineNodes, replica.GetRONodes()...) 
- } - - for _, nid := range replica.GetNodes() { - if isStopping, err := b.nodeManager.IsStoppingNode(nid); err != nil { - log.Info("not existed node", zap.Int64("nid", nid), zap.Error(err)) - continue - } else if isStopping { - offlineNodes = append(offlineNodes, nid) - } else { - onlineNodes = append(onlineNodes, nid) - } - } - - if len(onlineNodes) == 0 { + if len(rwNodes) == 0 { // no available nodes to balance return nil, nil } // print current distribution before generating plans segmentPlans, channelPlans := make([]SegmentAssignPlan, 0), make([]ChannelAssignPlan, 0) - if len(offlineNodes) != 0 { + if len(roNodes) != 0 { if !paramtable.Get().QueryCoordCfg.EnableStoppingBalance.GetAsBool() { - log.RatedInfo(10, "stopping balance is disabled!", zap.Int64s("stoppingNode", offlineNodes)) + log.RatedInfo(10, "stopping balance is disabled!", zap.Int64s("stoppingNode", roNodes)) return nil, nil } log.Info("Handle stopping nodes", - zap.Any("stopping nodes", offlineNodes), - zap.Any("available nodes", onlineNodes), + zap.Any("stopping nodes", roNodes), + zap.Any("available nodes", rwNodes), ) // handle stopped nodes here, have to assign segments on stopping nodes to nodes with the smallest score - channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, onlineNodes, offlineNodes)...) + channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, rwNodes, roNodes)...) if len(channelPlans) == 0 { - segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, onlineNodes, offlineNodes)...) + segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, rwNodes, roNodes)...) } } else { if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() { - channelPlans = append(channelPlans, b.genChannelPlan(replica, onlineNodes)...) + channelPlans = append(channelPlans, b.genChannelPlan(replica, rwNodes)...) 
} if len(channelPlans) == 0 { - segmentPlans = b.genSegmentPlan(replica) + segmentPlans = b.genSegmentPlan(replica, rwNodes) } } return segmentPlans, channelPlans } -func (b *MultiTargetBalancer) genSegmentPlan(replica *meta.Replica) []SegmentAssignPlan { +func (b *MultiTargetBalancer) genSegmentPlan(replica *meta.Replica, rwNodes []int64) []SegmentAssignPlan { // get segments distribution on replica level and global level nodeSegments := make(map[int64][]*meta.Segment) globalNodeSegments := make(map[int64][]*meta.Segment) - for _, node := range replica.GetNodes() { + for _, node := range rwNodes { dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node)) segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool { return b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil && diff --git a/internal/querycoordv2/balance/rowcount_based_balancer.go b/internal/querycoordv2/balance/rowcount_based_balancer.go index cab6bd5488c6d..15ee6f80f8ae0 100644 --- a/internal/querycoordv2/balance/rowcount_based_balancer.go +++ b/internal/querycoordv2/balance/rowcount_based_balancer.go @@ -126,9 +126,7 @@ func (b *RowCountBasedBalancer) AssignChannel(channels []*meta.DmChannel, nodes func (b *RowCountBasedBalancer) convertToNodeItemsBySegment(nodeIDs []int64) []*nodeItem { ret := make([]*nodeItem, 0, len(nodeIDs)) - for _, nodeInfo := range b.getNodes(nodeIDs) { - node := nodeInfo.ID() - + for _, node := range nodeIDs { // calculate sealed segment row count on node segments := b.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(node)) rowcnt := 0 @@ -151,8 +149,7 @@ func (b *RowCountBasedBalancer) convertToNodeItemsBySegment(nodeIDs []int64) []* func (b *RowCountBasedBalancer) convertToNodeItemsByChannel(nodeIDs []int64) []*nodeItem { ret := make([]*nodeItem, 0, len(nodeIDs)) - for _, nodeInfo := range b.getNodes(nodeIDs) { - node := nodeInfo.ID() + for _, node := 
range nodeIDs { channels := b.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(node)) // more channel num, less priority @@ -172,71 +169,52 @@ func (b *RowCountBasedBalancer) BalanceReplica(replica *meta.Replica) ([]Segment return nil, nil } - onlineNodes := make([]int64, 0) - offlineNodes := make([]int64, 0) - - // read only nodes is offline in current replica. - if replica.RONodesCount() > 0 { - // if node is stop or transfer to other rg - log.RatedInfo(10, "meet read only node, try to move out all segment/channel", zap.Int64s("node", replica.GetRONodes())) - offlineNodes = append(offlineNodes, replica.GetRONodes()...) - } - - for _, nid := range replica.GetNodes() { - if isStopping, err := b.nodeManager.IsStoppingNode(nid); err != nil { - log.Info("not existed node", zap.Int64("nid", nid), zap.Error(err)) - continue - } else if isStopping { - offlineNodes = append(offlineNodes, nid) - } else { - onlineNodes = append(onlineNodes, nid) - } - } - - if len(onlineNodes) == 0 { + rwNodes := replica.GetRWNodes() + roNodes := replica.GetRONodes() + if len(rwNodes) == 0 { // no available nodes to balance return nil, nil } segmentPlans, channelPlans := make([]SegmentAssignPlan, 0), make([]ChannelAssignPlan, 0) - if len(offlineNodes) != 0 { + if len(roNodes) != 0 { if !paramtable.Get().QueryCoordCfg.EnableStoppingBalance.GetAsBool() { - log.RatedInfo(10, "stopping balance is disabled!", zap.Int64s("stoppingNode", offlineNodes)) + log.RatedInfo(10, "stopping balance is disabled!", zap.Int64s("stoppingNode", roNodes)) return nil, nil } log.Info("Handle stopping nodes", - zap.Any("stopping nodes", offlineNodes), - zap.Any("available nodes", onlineNodes), + zap.Any("stopping nodes", roNodes), + zap.Any("available nodes", rwNodes), ) // handle stopped nodes here, have to assign segments on stopping nodes to nodes with the smallest score - channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, onlineNodes, offlineNodes)...) 
+ channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, rwNodes, roNodes)...) if len(channelPlans) == 0 { - segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, onlineNodes, offlineNodes)...) + segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, rwNodes, roNodes)...) } } else { if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() { - channelPlans = append(channelPlans, b.genChannelPlan(replica, onlineNodes)...) + channelPlans = append(channelPlans, b.genChannelPlan(replica, rwNodes)...) } if len(channelPlans) == 0 { - segmentPlans = append(segmentPlans, b.genSegmentPlan(replica, onlineNodes)...) + segmentPlans = append(segmentPlans, b.genSegmentPlan(replica, rwNodes)...) } } return segmentPlans, channelPlans } -func (b *RowCountBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []SegmentAssignPlan { +func (b *RowCountBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, rwNodes []int64, roNodes []int64) []SegmentAssignPlan { segmentPlans := make([]SegmentAssignPlan, 0) - for _, nodeID := range offlineNodes { + for _, nodeID := range roNodes { dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(nodeID)) segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool { return b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil && b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil && segment.GetLevel() != datapb.SegmentLevel_L0 }) - plans := b.AssignSegment(replica.GetCollectionID(), segments, onlineNodes, false) + plans := b.AssignSegment(replica.GetCollectionID(), segments, rwNodes, false) for i := range plans { plans[i].From = nodeID plans[i].Replica = replica @@ -246,13 +224,13 @@ func (b *RowCountBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, on return segmentPlans } -func 
(b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNodes []int64) []SegmentAssignPlan { +func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, rwNodes []int64) []SegmentAssignPlan { segmentsToMove := make([]*meta.Segment, 0) nodeRowCount := make(map[int64]int, 0) segmentDist := make(map[int64][]*meta.Segment) totalRowCount := 0 - for _, node := range onlineNodes { + for _, node := range rwNodes { dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node)) segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool { return b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil && @@ -273,7 +251,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNode } // find nodes with less row count than average - average := totalRowCount / len(onlineNodes) + average := totalRowCount / len(rwNodes) nodesWithLessRow := make([]int64, 0) for node, segments := range segmentDist { sort.Slice(segments, func(i, j int) bool { @@ -313,11 +291,11 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNode return segmentPlans } -func (b *RowCountBasedBalancer) genStoppingChannelPlan(replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []ChannelAssignPlan { +func (b *RowCountBasedBalancer) genStoppingChannelPlan(replica *meta.Replica, rwNodes []int64, roNodes []int64) []ChannelAssignPlan { channelPlans := make([]ChannelAssignPlan, 0) - for _, nodeID := range offlineNodes { + for _, nodeID := range roNodes { dmChannels := b.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID2Channel(nodeID)) - plans := b.AssignChannel(dmChannels, onlineNodes, false) + plans := b.AssignChannel(dmChannels, rwNodes, false) for i := range plans { plans[i].From = nodeID plans[i].Replica = replica @@ -327,20 +305,20 @@ func (b *RowCountBasedBalancer) 
genStoppingChannelPlan(replica *meta.Replica, on return channelPlans } -func (b *RowCountBasedBalancer) genChannelPlan(replica *meta.Replica, onlineNodes []int64) []ChannelAssignPlan { +func (b *RowCountBasedBalancer) genChannelPlan(replica *meta.Replica, rwNodes []int64) []ChannelAssignPlan { channelPlans := make([]ChannelAssignPlan, 0) - if len(onlineNodes) > 1 { + if len(rwNodes) > 1 { // start to balance channels on all available nodes channelDist := b.dist.ChannelDistManager.GetByFilter(meta.WithReplica2Channel(replica)) if len(channelDist) == 0 { return nil } - average := int(math.Ceil(float64(len(channelDist)) / float64(len(onlineNodes)))) + average := int(math.Ceil(float64(len(channelDist)) / float64(len(rwNodes)))) // find nodes with less channel count than average nodeWithLessChannel := make([]int64, 0) channelsToMove := make([]*meta.DmChannel, 0) - for _, node := range onlineNodes { + for _, node := range rwNodes { channels := b.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID2Channel(node)) if len(channels) <= average { diff --git a/internal/querycoordv2/balance/rowcount_based_balancer_test.go b/internal/querycoordv2/balance/rowcount_based_balancer_test.go index 8cd220bbc9be3..41662d00d3f60 100644 --- a/internal/querycoordv2/balance/rowcount_based_balancer_test.go +++ b/internal/querycoordv2/balance/rowcount_based_balancer_test.go @@ -409,8 +409,8 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() { segmentCnts: []int{1, 2}, states: []session.State{session.NodeStateNormal, session.NodeStateNormal}, distributions: map[int64][]*meta.Segment{ - 1: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 30}, Node: 11}}, - 2: { + 11: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 30}, Node: 11}}, + 22: { {SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, Node: 22}, {SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 30}, Node: 
22}, }, @@ -455,7 +455,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() { collection.LoadType = querypb.LoadType_LoadCollection balancer.meta.CollectionManager.PutCollection(collection) balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...))) + balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, c.nodes)) suite.broker.ExpectedCalls = nil suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, segments, nil) balancer.targetMgr.UpdateCollectionNextTarget(int64(1)) @@ -481,6 +481,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() { suite.balancer.nodeManager.Add(nodeInfo) suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i]) } + utils.RecoverAllCollection(balancer.meta) segmentPlans, channelPlans := suite.getCollectionBalancePlans(balancer, 1) if !c.multiple { @@ -492,10 +493,11 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() { } // clear distribution - for node := range c.distributions { + + for _, node := range c.nodes { + balancer.meta.ResourceManager.HandleNodeDown(node) + balancer.nodeManager.Remove(node) balancer.dist.SegmentDistManager.Update(node) - } - for node := range c.distributionChannels { balancer.dist.ChannelDistManager.Update(node) } }) @@ -693,6 +695,8 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnPartStopping() { suite.balancer.nodeManager.Add(nodeInfo) suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i]) } + utils.RecoverAllCollection(balancer.meta) + segmentPlans, channelPlans := suite.getCollectionBalancePlans(balancer, 1) assertSegmentAssignPlanElementMatch(&suite.Suite, c.expectPlans, segmentPlans) assertChannelAssignPlanElementMatch(&suite.Suite, c.expectChannelPlans, channelPlans) diff --git a/internal/querycoordv2/balance/score_based_balancer.go b/internal/querycoordv2/balance/score_based_balancer.go index 
6b0554e397630..2730645fa9ab1 100644 --- a/internal/querycoordv2/balance/score_based_balancer.go +++ b/internal/querycoordv2/balance/score_based_balancer.go @@ -141,8 +141,7 @@ func (b *ScoreBasedBalancer) hasEnoughBenefit(sourceNode *nodeItem, targetNode * func (b *ScoreBasedBalancer) convertToNodeItems(collectionID int64, nodeIDs []int64) []*nodeItem { ret := make([]*nodeItem, 0, len(nodeIDs)) - for _, nodeInfo := range b.getNodes(nodeIDs) { - node := nodeInfo.ID() + for _, node := range nodeIDs { priority := b.calculateScore(collectionID, node) nodeItem := newNodeItem(priority, node) ret = append(ret, &nodeItem) @@ -195,56 +194,38 @@ func (b *ScoreBasedBalancer) BalanceReplica(replica *meta.Replica) ([]SegmentAss return nil, nil } - onlineNodes := make([]int64, 0) - offlineNodes := make([]int64, 0) + rwNodes := replica.GetRWNodes() + roNodes := replica.GetRONodes() - // read only nodes is offline in current replica. - if replica.RONodesCount() > 0 { - // if node is stop or transfer to other rg - log.RatedInfo(10, "meet read only node, try to move out all segment/channel", zap.Int64s("node", replica.GetRONodes())) - offlineNodes = append(offlineNodes, replica.GetRONodes()...) 
- } - - for _, nid := range replica.GetNodes() { - if isStopping, err := b.nodeManager.IsStoppingNode(nid); err != nil { - log.Info("not existed node", zap.Int64("nid", nid), zap.Error(err)) - continue - } else if isStopping { - offlineNodes = append(offlineNodes, nid) - } else { - onlineNodes = append(onlineNodes, nid) - } - } - - if len(onlineNodes) == 0 { + if len(rwNodes) == 0 { // no available nodes to balance return nil, nil } // print current distribution before generating plans segmentPlans, channelPlans := make([]SegmentAssignPlan, 0), make([]ChannelAssignPlan, 0) - if len(offlineNodes) != 0 { + if len(roNodes) != 0 { if !paramtable.Get().QueryCoordCfg.EnableStoppingBalance.GetAsBool() { - log.RatedInfo(10, "stopping balance is disabled!", zap.Int64s("stoppingNode", offlineNodes)) + log.RatedInfo(10, "stopping balance is disabled!", zap.Int64s("stoppingNode", roNodes)) return nil, nil } log.Info("Handle stopping nodes", - zap.Any("stopping nodes", offlineNodes), - zap.Any("available nodes", onlineNodes), + zap.Any("stopping nodes", roNodes), + zap.Any("available nodes", rwNodes), ) // handle stopped nodes here, have to assign segments on stopping nodes to nodes with the smallest score - channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, onlineNodes, offlineNodes)...) + channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, rwNodes, roNodes)...) if len(channelPlans) == 0 { - segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, onlineNodes, offlineNodes)...) + segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, rwNodes, roNodes)...) } } else { if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() { - channelPlans = append(channelPlans, b.genChannelPlan(replica, onlineNodes)...) + channelPlans = append(channelPlans, b.genChannelPlan(replica, rwNodes)...) } if len(channelPlans) == 0 { - segmentPlans = append(segmentPlans, b.genSegmentPlan(replica, onlineNodes)...) 
+ segmentPlans = append(segmentPlans, b.genSegmentPlan(replica, rwNodes)...) } } diff --git a/internal/querycoordv2/balance/score_based_balancer_test.go b/internal/querycoordv2/balance/score_based_balancer_test.go index b52af132111c3..90401cec640ac 100644 --- a/internal/querycoordv2/balance/score_based_balancer_test.go +++ b/internal/querycoordv2/balance/score_based_balancer_test.go @@ -439,6 +439,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() { suite.balancer.nodeManager.Add(nodeInfo) suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i]) } + utils.RecoverAllCollection(balancer.meta) // 4. balance and verify result segmentPlans, channelPlans := suite.getCollectionBalancePlans(balancer, c.collectionID) diff --git a/internal/querycoordv2/checkers/balance_checker.go b/internal/querycoordv2/checkers/balance_checker.go index 1b268cf4e025a..f611bdef1887f 100644 --- a/internal/querycoordv2/checkers/balance_checker.go +++ b/internal/querycoordv2/checkers/balance_checker.go @@ -101,12 +101,8 @@ func (b *BalanceChecker) replicasToBalance() []int64 { } replicas := b.meta.ReplicaManager.GetByCollection(cid) for _, replica := range replicas { - for _, nodeID := range replica.GetNodes() { - isStopping, _ := b.nodeManager.IsStoppingNode(nodeID) - if isStopping { - stoppingReplicas = append(stoppingReplicas, replica.GetID()) - break - } + if replica.RONodesCount() > 0 { + stoppingReplicas = append(stoppingReplicas, replica.GetID()) } } } diff --git a/internal/querycoordv2/checkers/balance_checker_test.go b/internal/querycoordv2/checkers/balance_checker_test.go index 2b759681155bb..6cc52b58145d4 100644 --- a/internal/querycoordv2/checkers/balance_checker_test.go +++ b/internal/querycoordv2/checkers/balance_checker_test.go @@ -278,6 +278,14 @@ func (suite *BalanceCheckerTestSuite) TestStoppingBalance() { suite.targetMgr.UpdateCollectionNextTarget(int64(cid2)) suite.targetMgr.UpdateCollectionCurrentTarget(int64(cid2)) + mr1 := replica1.CopyForWrite() + 
mr1.AddRONode(1) + suite.checker.meta.ReplicaManager.Put(mr1.IntoReplica()) + + mr2 := replica2.CopyForWrite() + mr2.AddRONode(1) + suite.checker.meta.ReplicaManager.Put(mr2.IntoReplica()) + // test stopping balance idsToBalance := []int64{int64(replicaID1), int64(replicaID2)} replicasToBalance := suite.checker.replicasToBalance() @@ -348,6 +356,14 @@ func (suite *BalanceCheckerTestSuite) TestTargetNotReady() { suite.checker.meta.CollectionManager.PutCollection(collection2, partition2) suite.checker.meta.ReplicaManager.Put(replica2) + mr1 := replica1.CopyForWrite() + mr1.AddRONode(1) + suite.checker.meta.ReplicaManager.Put(mr1.IntoReplica()) + + mr2 := replica2.CopyForWrite() + mr2.AddRONode(1) + suite.checker.meta.ReplicaManager.Put(mr2.IntoReplica()) + // test stopping balance idsToBalance := []int64{int64(replicaID1)} replicasToBalance := suite.checker.replicasToBalance() diff --git a/internal/querycoordv2/checkers/channel_checker.go b/internal/querycoordv2/checkers/channel_checker.go index d22069741605a..9ba0761107b2f 100644 --- a/internal/querycoordv2/checkers/channel_checker.go +++ b/internal/querycoordv2/checkers/channel_checker.go @@ -130,7 +130,7 @@ func (c *ChannelChecker) getDmChannelDiff(collectionID int64, return } - dist := c.getChannelDist(replica) + dist := c.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithReplica2Channel(replica)) distMap := typeutil.NewSet[string]() for _, ch := range dist { distMap.Insert(ch.GetChannelName()) @@ -159,14 +159,6 @@ func (c *ChannelChecker) getDmChannelDiff(collectionID int64, return } -func (c *ChannelChecker) getChannelDist(replica *meta.Replica) []*meta.DmChannel { - dist := make([]*meta.DmChannel, 0) - for _, nodeID := range replica.GetNodes() { - dist = append(dist, c.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID2Channel(nodeID))...) 
- } - return dist -} - func (c *ChannelChecker) findRepeatedChannels(ctx context.Context, replicaID int64) []*meta.DmChannel { log := log.Ctx(ctx).WithRateGroup("ChannelChecker.findRepeatedChannels", 1, 60) replica := c.meta.Get(replicaID) @@ -176,7 +168,7 @@ func (c *ChannelChecker) findRepeatedChannels(ctx context.Context, replicaID int log.Info("replica does not exist, skip it") return ret } - dist := c.getChannelDist(replica) + dist := c.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithReplica2Channel(replica)) targets := c.targetMgr.GetSealedSegmentsByCollection(replica.GetCollectionID(), meta.CurrentTarget) versionsMap := make(map[string]*meta.DmChannel) @@ -221,7 +213,7 @@ func (c *ChannelChecker) createChannelLoadTask(ctx context.Context, channels []* for _, ch := range channels { rwNodes := replica.GetChannelRWNodes(ch.GetChannelName()) if len(rwNodes) == 0 { - rwNodes = replica.GetNodes() + rwNodes = replica.GetRWNodes() } plan := c.balancer.AssignChannel([]*meta.DmChannel{ch}, rwNodes, false) plans = append(plans, plan...) diff --git a/internal/querycoordv2/checkers/index_checker.go b/internal/querycoordv2/checkers/index_checker.go index 4ffe71b4ab5f6..83ffe54db2617 100644 --- a/internal/querycoordv2/checkers/index_checker.go +++ b/internal/querycoordv2/checkers/index_checker.go @@ -102,16 +102,17 @@ func (c *IndexChecker) checkReplica(ctx context.Context, collection *meta.Collec ) var tasks []task.Task - segments := c.getSealedSegmentsDist(replica) + segments := c.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithReplica(replica)) idSegments := make(map[int64]*meta.Segment) + roNodeSet := typeutil.NewUniqueSet(replica.GetRONodes()...) 
targets := make(map[int64][]int64) // segmentID => FieldID for _, segment := range segments { - // skip update index in stopping node - if ok, _ := c.nodeMgr.IsStoppingNode(segment.Node); ok { + // skip update index in read only node + if roNodeSet.Contain(segment.Node) { continue } - missing := c.checkSegment(ctx, segment, indexInfos) + missing := c.checkSegment(segment, indexInfos) if len(missing) > 0 { targets[segment.GetID()] = missing idSegments[segment.GetID()] = segment @@ -142,7 +143,7 @@ func (c *IndexChecker) checkReplica(ctx context.Context, collection *meta.Collec return tasks } -func (c *IndexChecker) checkSegment(ctx context.Context, segment *meta.Segment, indexInfos []*indexpb.IndexInfo) (fieldIDs []int64) { +func (c *IndexChecker) checkSegment(segment *meta.Segment, indexInfos []*indexpb.IndexInfo) (fieldIDs []int64) { var result []int64 for _, indexInfo := range indexInfos { fieldID, indexID := indexInfo.FieldID, indexInfo.IndexID @@ -158,14 +159,6 @@ func (c *IndexChecker) checkSegment(ctx context.Context, segment *meta.Segment, return result } -func (c *IndexChecker) getSealedSegmentsDist(replica *meta.Replica) []*meta.Segment { - var ret []*meta.Segment - for _, node := range replica.GetNodes() { - ret = append(ret, c.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node))...) 
- } - return ret -} - func (c *IndexChecker) createSegmentUpdateTask(ctx context.Context, segment *meta.Segment, replica *meta.Replica) (task.Task, bool) { action := task.NewSegmentActionWithScope(segment.Node, task.ActionTypeUpdate, segment.GetInsertChannel(), segment.GetID(), querypb.DataScope_Historical) t, err := task.NewSegmentTask( diff --git a/internal/querycoordv2/checkers/index_checker_test.go b/internal/querycoordv2/checkers/index_checker_test.go index c7b43a5945805..ef9b80e50de3b 100644 --- a/internal/querycoordv2/checkers/index_checker_test.go +++ b/internal/querycoordv2/checkers/index_checker_test.go @@ -134,9 +134,12 @@ func (suite *IndexCheckerSuite) TestLoadIndex() { suite.Equal(task.ActionTypeUpdate, action.Type()) suite.EqualValues(2, action.SegmentID()) - // test skip load index for stopping node + // test skip load index for read only node suite.nodeMgr.Stopping(1) suite.nodeMgr.Stopping(2) + suite.meta.ResourceManager.HandleNodeStopping(1) + suite.meta.ResourceManager.HandleNodeStopping(2) + utils.RecoverAllCollection(suite.meta) tasks = checker.Check(context.Background()) suite.Require().Len(tasks, 0) } diff --git a/internal/querycoordv2/checkers/leader_checker.go b/internal/querycoordv2/checkers/leader_checker.go index 84d7ad8f52101..7c4a1ae899bc8 100644 --- a/internal/querycoordv2/checkers/leader_checker.go +++ b/internal/querycoordv2/checkers/leader_checker.go @@ -93,12 +93,7 @@ func (c *LeaderChecker) Check(ctx context.Context) []task.Task { replicas := c.meta.ReplicaManager.GetByCollection(collectionID) for _, replica := range replicas { - for _, node := range replica.GetNodes() { - if ok, _ := c.nodeMgr.IsStoppingNode(node); ok { - // no need to correct leader's view which is loaded on stopping node - continue - } - + for _, node := range replica.GetRWNodes() { leaderViews := c.dist.LeaderViewManager.GetByFilter(meta.WithCollectionID2LeaderView(replica.GetCollectionID()), meta.WithNodeID2LeaderView(node)) for _, leaderView := range 
leaderViews { dist := c.dist.SegmentDistManager.GetByFilter(meta.WithChannel(leaderView.Channel), meta.WithReplica(replica)) diff --git a/internal/querycoordv2/checkers/leader_checker_test.go b/internal/querycoordv2/checkers/leader_checker_test.go index f40563483ee2a..f01b8bb34d7e3 100644 --- a/internal/querycoordv2/checkers/leader_checker_test.go +++ b/internal/querycoordv2/checkers/leader_checker_test.go @@ -237,7 +237,8 @@ func (suite *LeaderCheckerTestSuite) TestStoppingNode() { observer := suite.checker observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + replica := utils.CreateTestReplica(1, 1, []int64{1, 2}) + observer.meta.ReplicaManager.Put(replica) segments := []*datapb.SegmentInfo{ { ID: 1, @@ -261,12 +262,9 @@ func (suite *LeaderCheckerTestSuite) TestStoppingNode() { view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget) observer.dist.LeaderViewManager.Update(2, view) - suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ - NodeID: 2, - Address: "localhost", - Hostname: "localhost", - })) - suite.nodeMgr.Stopping(2) + mutableReplica := replica.CopyForWrite() + mutableReplica.AddRONode(2) + observer.meta.ReplicaManager.Put(mutableReplica.IntoReplica()) tasks := suite.checker.Check(context.TODO()) suite.Len(tasks, 0) diff --git a/internal/querycoordv2/checkers/segment_checker.go b/internal/querycoordv2/checkers/segment_checker.go index 17e9e7346fa1f..1c85aef177df3 100644 --- a/internal/querycoordv2/checkers/segment_checker.go +++ b/internal/querycoordv2/checkers/segment_checker.go @@ -204,7 +204,7 @@ func (c *SegmentChecker) getSealedSegmentDiff( log.Info("replica does not exist, skip it") return } - dist := c.getSealedSegmentsDist(replica) + dist := 
c.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithReplica(replica)) sort.Slice(dist, func(i, j int) bool { return dist[i].Version < dist[j].Version }) @@ -293,14 +293,6 @@ func (c *SegmentChecker) getSealedSegmentDiff( return } -func (c *SegmentChecker) getSealedSegmentsDist(replica *meta.Replica) []*meta.Segment { - ret := make([]*meta.Segment, 0) - for _, node := range replica.GetNodes() { - ret = append(ret, c.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node))...) - } - return ret -} - func (c *SegmentChecker) findRepeatedSealedSegments(replicaID int64) []*meta.Segment { segments := make([]*meta.Segment, 0) replica := c.meta.Get(replicaID) @@ -308,7 +300,7 @@ func (c *SegmentChecker) findRepeatedSealedSegments(replicaID int64) []*meta.Seg log.Info("replica does not exist, skip it") return segments } - dist := c.getSealedSegmentsDist(replica) + dist := c.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithReplica(replica)) versions := make(map[int64]*meta.Segment) for _, s := range dist { // l0 segment should be release with channel together @@ -398,25 +390,12 @@ func (c *SegmentChecker) createSegmentLoadTasks(ctx context.Context, segments [] rwNodes := replica.GetChannelRWNodes(shard) if len(rwNodes) == 0 { - rwNodes = replica.GetNodes() - } - - // filter out stopping nodes. 
- availableNodes := lo.Filter(rwNodes, func(node int64, _ int) bool { - stop, err := c.nodeMgr.IsStoppingNode(node) - if err != nil { - return false - } - return !stop - }) - - if len(availableNodes) == 0 { - return nil + rwNodes = replica.GetRWNodes() } // L0 segment can only be assign to shard leader's node if isLevel0 { - availableNodes = []int64{leader.ID} + rwNodes = []int64{leader.ID} } segmentInfos := lo.Map(segments, func(s *datapb.SegmentInfo, _ int) *meta.Segment { @@ -424,7 +403,7 @@ func (c *SegmentChecker) createSegmentLoadTasks(ctx context.Context, segments [] SegmentInfo: s, } }) - shardPlans := c.balancer.AssignSegment(replica.GetCollectionID(), segmentInfos, availableNodes, false) + shardPlans := c.balancer.AssignSegment(replica.GetCollectionID(), segmentInfos, rwNodes, false) for i := range shardPlans { shardPlans[i].Replica = replica } diff --git a/internal/querycoordv2/handlers.go b/internal/querycoordv2/handlers.go index afd6df93cad9b..e3387ae6b785a 100644 --- a/internal/querycoordv2/handlers.go +++ b/internal/querycoordv2/handlers.go @@ -46,7 +46,7 @@ import ( func (s *Server) checkAnyReplicaAvailable(collectionID int64) bool { for _, replica := range s.meta.ReplicaManager.GetByCollection(collectionID) { isAvailable := true - for _, node := range replica.GetNodes() { + for _, node := range replica.GetRONodes() { if s.nodeMgr.Get(node) == nil { isAvailable = false break diff --git a/internal/querycoordv2/job/job_load.go b/internal/querycoordv2/job/job_load.go index d28813f08e9db..03b4de5e332d1 100644 --- a/internal/querycoordv2/job/job_load.go +++ b/internal/querycoordv2/job/job_load.go @@ -159,16 +159,12 @@ func (job *LoadCollectionJob) Execute() error { // API of LoadCollection is wired, we should use map[resourceGroupNames]replicaNumber as input, to keep consistency with `TransferReplica` API. // Then we can implement dynamic replica changed in different resource group independently. 
- replicas, err = utils.SpawnReplicasWithRG(job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber(), collectionInfo.GetVirtualChannelNames()) + _, err = utils.SpawnReplicasWithRG(job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber(), collectionInfo.GetVirtualChannelNames()) if err != nil { msg := "failed to spawn replica for collection" log.Warn(msg, zap.Error(err)) return errors.Wrap(err, msg) } - for _, replica := range replicas { - log.Info("replica created", zap.Int64("replicaID", replica.GetID()), - zap.Int64s("nodes", replica.GetNodes()), zap.String("resourceGroup", replica.GetResourceGroup())) - } job.undo.IsReplicaCreated = true } @@ -346,16 +342,12 @@ func (job *LoadPartitionJob) Execute() error { if err != nil { return err } - replicas, err = utils.SpawnReplicasWithRG(job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber(), collectionInfo.GetVirtualChannelNames()) + _, err = utils.SpawnReplicasWithRG(job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber(), collectionInfo.GetVirtualChannelNames()) if err != nil { msg := "failed to spawn replica for collection" log.Warn(msg, zap.Error(err)) return errors.Wrap(err, msg) } - for _, replica := range replicas { - log.Info("replica created", zap.Int64("replicaID", replica.GetID()), - zap.Int64s("nodes", replica.GetNodes()), zap.String("resourceGroup", replica.GetResourceGroup())) - } job.undo.IsReplicaCreated = true } diff --git a/internal/querycoordv2/meta/replica.go b/internal/querycoordv2/meta/replica.go index 4aeb08bfb801f..387dc910d57d1 100644 --- a/internal/querycoordv2/meta/replica.go +++ b/internal/querycoordv2/meta/replica.go @@ -39,7 +39,7 @@ func NewReplica(replica *querypb.Replica, nodes ...typeutil.UniqueSet) *Replica } // newReplica creates a new replica from pb. 
-func newReplica(replica *querypb.Replica) *Replica {
 	return &Replica{
 		replicaPB: proto.Clone(replica).(*querypb.Replica),
 		rwNodes:   typeutil.NewUniqueSet(replica.Nodes...),
@@ -65,7 +65,10 @@ func (replica *Replica) GetResourceGroup() string {
 // GetNodes returns the rw nodes of the replica.
 // readonly, don't modify the returned slice.
 func (replica *Replica) GetNodes() []int64 {
-	return replica.replicaPB.GetNodes()
+	nodes := make([]int64, 0)
+	nodes = append(nodes, replica.replicaPB.GetRoNodes()...)
+	nodes = append(nodes, replica.replicaPB.GetNodes()...)
+	return nodes
 }
 
 // GetRONodes returns the ro nodes of the replica.
@@ -74,6 +77,12 @@ func (replica *Replica) GetRONodes() []int64 {
 	return replica.replicaPB.GetRoNodes()
 }
 
+// GetRWNodes returns the rw nodes of the replica.
+// readonly, don't modify the returned slice.
+func (replica *Replica) GetRWNodes() []int64 {
+	return replica.replicaPB.GetNodes()
+}
+
 // RangeOverRWNodes iterates over the read and write nodes of the replica.
 func (replica *Replica) RangeOverRWNodes(f func(node int64) bool) {
 	replica.rwNodes.Range(f)
@@ -131,8 +140,8 @@ func (replica *Replica) GetChannelRWNodes(channelName string) []int64 {
 	return replica.replicaPB.ChannelNodeInfos[channelName].GetRwNodes()
 }
 
-// copyForWrite returns a mutable replica for write operations.
+// CopyForWrite returns a mutable replica for write operations.
+func (replica *Replica) CopyForWrite() *mutableReplica { exclusiveRWNodeToChannel := make(map[int64]string) for name, channelNodeInfo := range replica.replicaPB.GetChannelNodeInfos() { for _, nodeID := range channelNodeInfo.GetRwNodes() { diff --git a/internal/querycoordv2/meta/replica_manager.go b/internal/querycoordv2/meta/replica_manager.go index 91824b18c350e..2a947a6532188 100644 --- a/internal/querycoordv2/meta/replica_manager.go +++ b/internal/querycoordv2/meta/replica_manager.go @@ -195,7 +195,7 @@ func (m *ReplicaManager) TransferReplica(collectionID typeutil.UniqueID, srcRGNa // Node Change will be executed by replica_observer in background. replicas := make([]*Replica, 0, replicaNum) for i := 0; i < replicaNum; i++ { - mutableReplica := srcReplicas[i].copyForWrite() + mutableReplica := srcReplicas[i].CopyForWrite() mutableReplica.SetResourceGroup(dstRGName) replicas = append(replicas, mutableReplica.IntoReplica()) } @@ -350,7 +350,7 @@ func (m *ReplicaManager) RecoverNodesInCollection(collectionID typeutil.UniqueID // nothing to do. return } - mutableReplica := m.replicas[assignment.GetReplicaID()].copyForWrite() + mutableReplica := m.replicas[assignment.GetReplicaID()].CopyForWrite() mutableReplica.AddRONode(roNodes...) // rw -> ro mutableReplica.AddRWNode(recoverableNodes...) // ro -> rw mutableReplica.AddRWNode(incomingNode...) // unused -> rw @@ -414,7 +414,7 @@ func (m *ReplicaManager) RemoveNode(replicaID typeutil.UniqueID, nodes ...typeut return merr.WrapErrReplicaNotFound(replicaID) } - mutableReplica := replica.copyForWrite() + mutableReplica := replica.CopyForWrite() mutableReplica.RemoveNode(nodes...) 
// ro -> unused return m.put(mutableReplica.IntoReplica()) } diff --git a/internal/querycoordv2/meta/replica_test.go b/internal/querycoordv2/meta/replica_test.go index 01ea9e6cc19a5..31c1194ac023b 100644 --- a/internal/querycoordv2/meta/replica_test.go +++ b/internal/querycoordv2/meta/replica_test.go @@ -30,13 +30,13 @@ func (suite *ReplicaSuite) TestReadOperations() { r := newReplica(suite.replicaPB) suite.testRead(r) // keep same after clone. - mutableReplica := r.copyForWrite() + mutableReplica := r.CopyForWrite() suite.testRead(mutableReplica.IntoReplica()) } func (suite *ReplicaSuite) TestClone() { r := newReplica(suite.replicaPB) - r2 := r.copyForWrite() + r2 := r.CopyForWrite() suite.testRead(r) // after apply write operation on copy, the original should not be affected. @@ -68,7 +68,7 @@ func (suite *ReplicaSuite) TestRange() { }) suite.Equal(1, count) - mr := r.copyForWrite() + mr := r.CopyForWrite() mr.AddRONode(1) count = 0 @@ -81,7 +81,7 @@ func (suite *ReplicaSuite) TestRange() { func (suite *ReplicaSuite) TestWriteOperation() { r := newReplica(suite.replicaPB) - mr := r.copyForWrite() + mr := r.CopyForWrite() // test add available node. suite.False(mr.Contains(5)) @@ -158,7 +158,7 @@ func (suite *ReplicaSuite) testRead(r *Replica) { suite.Equal(suite.replicaPB.GetResourceGroup(), r.GetResourceGroup()) // Test GetNodes() - suite.ElementsMatch(suite.replicaPB.GetNodes(), r.GetNodes()) + suite.ElementsMatch(suite.replicaPB.GetNodes(), r.GetRWNodes()) // Test GetRONodes() suite.ElementsMatch(suite.replicaPB.GetRoNodes(), r.GetRONodes()) @@ -195,7 +195,7 @@ func (suite *ReplicaSuite) TestChannelExclusiveMode() { }, }) - mutableReplica := r.copyForWrite() + mutableReplica := r.CopyForWrite() // add 10 rw nodes, exclusive mode is false. 
for i := 0; i < 10; i++ { mutableReplica.AddRWNode(int64(i)) @@ -205,7 +205,7 @@ func (suite *ReplicaSuite) TestChannelExclusiveMode() { suite.Equal(0, len(channelNodeInfo.GetRwNodes())) } - mutableReplica = r.copyForWrite() + mutableReplica = r.CopyForWrite() // add 10 rw nodes, exclusive mode is true. for i := 10; i < 20; i++ { mutableReplica.AddRWNode(int64(i)) @@ -216,7 +216,7 @@ func (suite *ReplicaSuite) TestChannelExclusiveMode() { } // 4 node become read only, exclusive mode still be true - mutableReplica = r.copyForWrite() + mutableReplica = r.CopyForWrite() for i := 0; i < 4; i++ { mutableReplica.AddRONode(int64(i)) } @@ -226,7 +226,7 @@ func (suite *ReplicaSuite) TestChannelExclusiveMode() { } // 4 node has been removed, exclusive mode back to false - mutableReplica = r.copyForWrite() + mutableReplica = r.CopyForWrite() for i := 4; i < 8; i++ { mutableReplica.RemoveNode(int64(i)) } diff --git a/internal/querycoordv2/meta/resource_manager.go b/internal/querycoordv2/meta/resource_manager.go index 0fea92ef905d3..3df8755629716 100644 --- a/internal/querycoordv2/meta/resource_manager.go +++ b/internal/querycoordv2/meta/resource_manager.go @@ -453,7 +453,6 @@ func (rm *ResourceManager) HandleNodeDown(node int64) { rm.rwmutex.Lock() defer rm.rwmutex.Unlock() - // failure of node down can be ignored, node down can be done by `RemoveAllDownNode`. rm.incomingNode.Remove(node) // for stopping query node becomes offline, node change won't be triggered, @@ -470,6 +469,19 @@ func (rm *ResourceManager) HandleNodeDown(node int64) { ) } +func (rm *ResourceManager) HandleNodeStopping(node int64) { + rm.rwmutex.Lock() + defer rm.rwmutex.Unlock() + + rm.incomingNode.Remove(node) + rgName, err := rm.unassignNode(node) + log.Info("HandleNodeStopping: remove node from resource group", + zap.String("rgName", rgName), + zap.Int64("node", node), + zap.Error(err), + ) +} + // ListenResourceGroupChanged return a listener for resource group changed. 
func (rm *ResourceManager) ListenResourceGroupChanged() *syncutil.VersionedListener { return rm.rgChangedNotifier.Listen(syncutil.VersionedListenAtEarliest) @@ -495,25 +507,6 @@ func (rm *ResourceManager) AssignPendingIncomingNode() { } } -// RemoveAllDownNode remove all down node from resource group. -func (rm *ResourceManager) RemoveAllDownNode() { - rm.rwmutex.Lock() - defer rm.rwmutex.Unlock() - - for nodeID := range rm.nodeIDMap { - if node := rm.nodeMgr.Get(nodeID); node == nil || node.IsStoppingState() { - // unassignNode failure can be skip. - rgName, err := rm.unassignNode(nodeID) - log.Info("remove down node from resource group", - zap.Bool("nodeExist", node != nil), - zap.Int64("nodeID", nodeID), - zap.String("rgName", rgName), - zap.Error(err), - ) - } - } -} - // AutoRecoverResourceGroup auto recover rg, return recover used node num func (rm *ResourceManager) AutoRecoverResourceGroup(rgName string) error { rm.rwmutex.Lock() @@ -847,7 +840,8 @@ func (rm *ResourceManager) unassignNode(node int64) (string, error) { rm.nodeChangedNotifier.NotifyAll() return rg.GetName(), nil } - return "", nil + + return "", errors.Errorf("node %d not found in any resource group", node) } // validateResourceGroupConfig validate resource group config. diff --git a/internal/querycoordv2/meta/resource_manager_test.go b/internal/querycoordv2/meta/resource_manager_test.go index 969e09cde6c1f..37a6a4f647187 100644 --- a/internal/querycoordv2/meta/resource_manager_test.go +++ b/internal/querycoordv2/meta/resource_manager_test.go @@ -524,16 +524,6 @@ func (suite *ResourceManagerSuite) TestAutoRecover() { suite.Equal(80, suite.manager.GetResourceGroup("rg2").NodeNum()) suite.Equal(5, suite.manager.GetResourceGroup("rg3").NodeNum()) suite.Equal(5, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) - - // Test down all nodes. 
- for i := 1; i <= 100; i++ { - suite.manager.nodeMgr.Remove(int64(i)) - } - suite.manager.RemoveAllDownNode() - suite.Zero(suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Zero(suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Zero(suite.manager.GetResourceGroup("rg3").NodeNum()) - suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) } func (suite *ResourceManagerSuite) testTransferNode() { diff --git a/internal/querycoordv2/observers/replica_observer.go b/internal/querycoordv2/observers/replica_observer.go index e60a988a5e545..96180fb72ec54 100644 --- a/internal/querycoordv2/observers/replica_observer.go +++ b/internal/querycoordv2/observers/replica_observer.go @@ -100,6 +100,7 @@ func (ob *ReplicaObserver) checkNodesInReplica() { replicas := ob.meta.ReplicaManager.GetByCollection(collectionID) for _, replica := range replicas { roNodes := replica.GetRONodes() + rwNodes := replica.GetRWNodes() if len(roNodes) == 0 { continue } @@ -124,7 +125,7 @@ func (ob *ReplicaObserver) checkNodesInReplica() { zap.Int64("replicaID", replica.GetID()), zap.Int64s("removedNodes", removeNodes), zap.Int64s("roNodes", roNodes), - zap.Int64s("availableNodes", replica.GetNodes()), + zap.Int64s("rwNodes", rwNodes), ) if err := ob.meta.ReplicaManager.RemoveNode(replica.GetID(), removeNodes...); err != nil { logger.Warn("fail to remove node from replica", zap.Error(err)) diff --git a/internal/querycoordv2/observers/resource_observer.go b/internal/querycoordv2/observers/resource_observer.go index cf15f6657d67f..bfad63e28aea6 100644 --- a/internal/querycoordv2/observers/resource_observer.go +++ b/internal/querycoordv2/observers/resource_observer.go @@ -98,10 +98,6 @@ func (ob *ResourceObserver) checkAndRecoverResourceGroup() { manager.AssignPendingIncomingNode() } - // Remove all down nodes in resource group manager. 
- log.Debug("remove all down nodes in resource group manager...") - ob.meta.RemoveAllDownNode() - log.Debug("recover resource groups...") // Recover all resource group into expected configuration. for _, rgName := range rgNames { diff --git a/internal/querycoordv2/observers/resource_observer_test.go b/internal/querycoordv2/observers/resource_observer_test.go index b9e4872ff9319..f75dae11e2a9f 100644 --- a/internal/querycoordv2/observers/resource_observer_test.go +++ b/internal/querycoordv2/observers/resource_observer_test.go @@ -136,6 +136,7 @@ func (suite *ResourceObserverSuite) TestObserverRecoverOperation() { suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg2")) suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg3")) // new node is down, rg3 cannot use that node anymore. + suite.meta.ResourceManager.HandleNodeDown(10) suite.observer.checkAndRecoverResourceGroup() suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg1")) suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg2")) diff --git a/internal/querycoordv2/ops_services.go b/internal/querycoordv2/ops_services.go index 2d643b18c2ff3..46b3792207706 100644 --- a/internal/querycoordv2/ops_services.go +++ b/internal/querycoordv2/ops_services.go @@ -276,7 +276,7 @@ func (s *Server) TransferSegment(ctx context.Context, req *querypb.TransferSegme // when no dst node specified, default to use all other nodes in same dstNodeSet := typeutil.NewUniqueSet() if req.GetToAllNodes() { - dstNodeSet.Insert(replica.GetNodes()...) + dstNodeSet.Insert(replica.GetRWNodes()...) } else { // check whether dstNode is healthy if err := s.isStoppingNode(req.GetTargetNodeID()); err != nil { @@ -348,7 +348,7 @@ func (s *Server) TransferChannel(ctx context.Context, req *querypb.TransferChann // when no dst node specified, default to use all other nodes in same dstNodeSet := typeutil.NewUniqueSet() if req.GetToAllNodes() { - dstNodeSet.Insert(replica.GetNodes()...) 
+ dstNodeSet.Insert(replica.GetRWNodes()...) } else { // check whether dstNode is healthy if err := s.isStoppingNode(req.GetTargetNodeID()); err != nil { diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index 127844f8e6ad4..d115c4ceb7cf7 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -441,7 +441,6 @@ func (s *Server) startQueryCoord() error { s.nodeMgr.Stopping(node.ServerID) } } - s.checkReplicas() for _, node := range sessions { s.handleNodeUp(node.ServerID) } @@ -685,6 +684,7 @@ func (s *Server) watchNodes(revision int64) { ) s.nodeMgr.Stopping(nodeID) s.checkerController.Check() + s.meta.ResourceManager.HandleNodeStopping(nodeID) case sessionutil.SessionDelEvent: nodeID := event.Session.ServerID @@ -748,7 +748,6 @@ func (s *Server) handleNodeUp(node int64) { } func (s *Server) handleNodeDown(node int64) { - log := log.With(zap.Int64("nodeID", node)) s.taskScheduler.RemoveExecutor(node) s.distController.Remove(node) @@ -757,57 +756,12 @@ func (s *Server) handleNodeDown(node int64) { s.dist.ChannelDistManager.Update(node) s.dist.SegmentDistManager.Update(node) - // Clear meta - for _, collection := range s.meta.CollectionManager.GetAll() { - log := log.With(zap.Int64("collectionID", collection)) - replica := s.meta.ReplicaManager.GetByCollectionAndNode(collection, node) - if replica == nil { - continue - } - err := s.meta.ReplicaManager.RemoveNode(replica.GetID(), node) - if err != nil { - log.Warn("failed to remove node from collection's replicas", - zap.Int64("replicaID", replica.GetID()), - zap.Error(err), - ) - } - log.Info("remove node from replica", - zap.Int64("replicaID", replica.GetID())) - } - // Clear tasks s.taskScheduler.RemoveByNode(node) s.meta.ResourceManager.HandleNodeDown(node) } -// checkReplicas checks whether replica contains offline node, and remove those nodes -func (s *Server) checkReplicas() { - for _, collection := range s.meta.CollectionManager.GetAll() { - log := 
log.With(zap.Int64("collectionID", collection)) - replicas := s.meta.ReplicaManager.GetByCollection(collection) - for _, replica := range replicas { - toRemove := make([]int64, 0) - for _, node := range replica.GetNodes() { - if s.nodeMgr.Get(node) == nil { - toRemove = append(toRemove, node) - } - } - - if len(toRemove) > 0 { - log := log.With( - zap.Int64("replicaID", replica.GetID()), - zap.Int64s("offlineNodes", toRemove), - ) - log.Info("some nodes are offline, remove them from replica", zap.Any("toRemove", toRemove)) - if err := s.meta.ReplicaManager.RemoveNode(replica.GetID(), toRemove...); err != nil { - log.Warn("failed to remove offline nodes from replica") - } - } - } - } -} - func (s *Server) updateBalanceConfigLoop(ctx context.Context) { success := s.updateBalanceConfig() if success { diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go index 93606f6d3bebd..dea9817a2777f 100644 --- a/internal/querycoordv2/services.go +++ b/internal/querycoordv2/services.go @@ -686,7 +686,7 @@ func (s *Server) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques // when no dst node specified, default to use all other nodes in same dstNodeSet := typeutil.NewUniqueSet() if len(req.GetDstNodeIDs()) == 0 { - dstNodeSet.Insert(replica.GetNodes()...) + dstNodeSet.Insert(replica.GetRWNodes()...) 
} else { for _, dstNode := range req.GetDstNodeIDs() { if !replica.Contains(dstNode) { @@ -1075,7 +1075,7 @@ func (s *Server) DescribeResourceGroup(ctx context.Context, req *querypb.Describ replicasInRG := s.meta.GetByResourceGroup(req.GetResourceGroup()) for _, replica := range replicasInRG { loadedReplicas[replica.GetCollectionID()]++ - for _, node := range replica.GetNodes() { + for _, node := range replica.GetRONodes() { if !s.meta.ContainsNode(replica.GetResourceGroup(), node) { outgoingNodes[replica.GetCollectionID()]++ } @@ -1090,7 +1090,7 @@ func (s *Server) DescribeResourceGroup(ctx context.Context, req *querypb.Describ if replica.GetResourceGroup() == req.GetResourceGroup() { continue } - for _, node := range replica.GetNodes() { + for _, node := range replica.GetRONodes() { if s.meta.ContainsNode(req.GetResourceGroup(), node) { incomingNodes[collection]++ } @@ -1101,8 +1101,7 @@ func (s *Server) DescribeResourceGroup(ctx context.Context, req *querypb.Describ nodes := make([]*commonpb.NodeInfo, 0, len(rg.GetNodes())) for _, nodeID := range rg.GetNodes() { nodeSessionInfo := s.nodeMgr.Get(nodeID) - // Filter offline nodes and nodes in stopping state - if nodeSessionInfo != nil && !nodeSessionInfo.IsStoppingState() { + if nodeSessionInfo != nil { nodes = append(nodes, &commonpb.NodeInfo{ NodeId: nodeSessionInfo.ID(), Address: nodeSessionInfo.Addr(), diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go index e33d673f095ba..9c486a26b9c81 100644 --- a/internal/querycoordv2/services_test.go +++ b/internal/querycoordv2/services_test.go @@ -432,7 +432,8 @@ func (suite *ServiceSuite) TestResourceGroup() { server.meta.ReplicaManager.Put(meta.NewReplica(&querypb.Replica{ ID: 1, CollectionID: 1, - Nodes: []int64{1011, 1013}, + Nodes: []int64{1011}, + RoNodes: []int64{1013}, ResourceGroup: "rg11", }, typeutil.NewUniqueSet(1011, 1013)), @@ -440,7 +441,8 @@ func (suite *ServiceSuite) TestResourceGroup() { 
server.meta.ReplicaManager.Put(meta.NewReplica(&querypb.Replica{ ID: 2, CollectionID: 2, - Nodes: []int64{1012, 1014}, + Nodes: []int64{1014}, + RoNodes: []int64{1012}, ResourceGroup: "rg12", }, typeutil.NewUniqueSet(1012, 1014)), diff --git a/internal/querycoordv2/utils/meta.go b/internal/querycoordv2/utils/meta.go index 4dff731286226..b6ac15839e0b2 100644 --- a/internal/querycoordv2/utils/meta.go +++ b/internal/querycoordv2/utils/meta.go @@ -22,7 +22,6 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus/internal/querycoordv2/meta" - "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -35,19 +34,6 @@ var ( ErrUseWrongNumRG = errors.New("resource group num can only be 0, 1 or same as replica number") ) -func GetReplicaNodesInfo(replicaMgr *meta.ReplicaManager, nodeMgr *session.NodeManager, replicaID int64) []*session.NodeInfo { - replica := replicaMgr.Get(replicaID) - if replica == nil { - return nil - } - - nodes := make([]*session.NodeInfo, 0, len(replica.GetNodes())) - for _, node := range replica.GetNodes() { - nodes = append(nodes, nodeMgr.Get(node)) - } - return nodes -} - func GetPartitions(collectionMgr *meta.CollectionManager, collectionID int64) ([]int64, error) { collection := collectionMgr.GetCollection(collectionID) if collection != nil { diff --git a/internal/storage/serde.go b/internal/storage/serde.go index 7a64edf79b653..636e505b83c76 100644 --- a/internal/storage/serde.go +++ b/internal/storage/serde.go @@ -28,6 +28,7 @@ import ( "github.com/apache/arrow/go/v12/arrow" "github.com/apache/arrow/go/v12/arrow/array" "github.com/apache/arrow/go/v12/arrow/memory" + "github.com/apache/arrow/go/v12/parquet" "github.com/apache/arrow/go/v12/parquet/pqarrow" "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" @@ -749,18 +750,17 @@ var _ RecordWriter = (*singleFieldRecordWriter)(nil) type 
singleFieldRecordWriter struct { fw *pqarrow.FileWriter fieldId FieldID + schema *arrow.Schema - grouped bool + numRows int } func (sfw *singleFieldRecordWriter) Write(r Record) error { - if !sfw.grouped { - sfw.grouped = true - sfw.fw.NewRowGroup() - } - // TODO: adding row group support by calling fw.NewRowGroup() + sfw.numRows += r.Len() a := r.Column(sfw.fieldId) - return sfw.fw.WriteColumnData(a) + rec := array.NewRecord(sfw.schema, []arrow.Array{a}, int64(r.Len())) + defer rec.Release() + return sfw.fw.WriteBuffered(rec) } func (sfw *singleFieldRecordWriter) Close() { @@ -769,13 +769,16 @@ func (sfw *singleFieldRecordWriter) Close() { func newSingleFieldRecordWriter(fieldId FieldID, field arrow.Field, writer io.Writer) (*singleFieldRecordWriter, error) { schema := arrow.NewSchema([]arrow.Field{field}, nil) - fw, err := pqarrow.NewFileWriter(schema, writer, nil, pqarrow.DefaultWriterProps()) + fw, err := pqarrow.NewFileWriter(schema, writer, + parquet.NewWriterProperties(parquet.WithMaxRowGroupLength(math.MaxInt64)), // No additional grouping for now. 
+ pqarrow.DefaultWriterProps()) if err != nil { return nil, err } return &singleFieldRecordWriter{ fw: fw, fieldId: fieldId, + schema: schema, }, nil } @@ -790,15 +793,18 @@ type SerializeWriter[T any] struct { } func (sw *SerializeWriter[T]) Flush() error { + if sw.pos == 0 { + return nil + } buf := sw.buffer[:sw.pos] r, size, err := sw.serializer(buf) if err != nil { return err } + defer r.Release() if err := sw.rw.Write(r); err != nil { return err } - r.Release() sw.pos = 0 sw.writtenMemorySize += size return nil @@ -823,8 +829,11 @@ func (sw *SerializeWriter[T]) WrittenMemorySize() uint64 { } func (sw *SerializeWriter[T]) Close() error { + if err := sw.Flush(); err != nil { + return err + } sw.rw.Close() - return sw.Flush() + return nil } func NewSerializeRecordWriter[T any](rw RecordWriter, serializer Serializer[T], batchSize int) *SerializeWriter[T] { @@ -881,7 +890,7 @@ type BinlogStreamWriter struct { memorySize int // To be updated on the fly buf bytes.Buffer - rw RecordWriter + rw *singleFieldRecordWriter } func (bsw *BinlogStreamWriter) GetRecordWriter() (RecordWriter, error) { @@ -916,8 +925,9 @@ func (bsw *BinlogStreamWriter) Finalize() (*Blob, error) { return nil, err } return &Blob{ - Key: strconv.Itoa(int(bsw.fieldSchema.FieldID)), - Value: b.Bytes(), + Key: strconv.Itoa(int(bsw.fieldSchema.FieldID)), + Value: b.Bytes(), + RowNum: int64(bsw.rw.numRows), }, nil } diff --git a/internal/storage/serde_test.go b/internal/storage/serde_test.go index 17a10e3a2104e..21a871cb5e606 100644 --- a/internal/storage/serde_test.go +++ b/internal/storage/serde_test.go @@ -124,7 +124,7 @@ func TestBinlogSerializeWriter(t *testing.T) { }) t.Run("test serialize", func(t *testing.T) { - size := 3 + size := 16 blobs, err := generateTestData(size) assert.NoError(t, err) reader, err := NewBinlogDeserializeReader(blobs, common.RowIDField) @@ -134,7 +134,7 @@ func TestBinlogSerializeWriter(t *testing.T) { schema := generateTestSchema() // Copy write the generated data 
writers := NewBinlogStreamWriters(0, 0, 0, schema.Fields) - writer, err := NewBinlogSerializeWriter(schema, 0, 0, writers, 1024) + writer, err := NewBinlogSerializeWriter(schema, 0, 0, writers, 7) assert.NoError(t, err) for i := 1; i <= size; i++ { @@ -143,7 +143,8 @@ func TestBinlogSerializeWriter(t *testing.T) { value := reader.Value() assertTestData(t, i, value) - writer.Write(value) + err := writer.Write(value) + assert.NoError(t, err) } err = reader.Next() diff --git a/internal/util/importutilv2/json/reader_test.go b/internal/util/importutilv2/json/reader_test.go index 89d033356723f..c46954ead8418 100644 --- a/internal/util/importutilv2/json/reader_test.go +++ b/internal/util/importutilv2/json/reader_test.go @@ -143,14 +143,7 @@ func (suite *ReaderSuite) run(dataType schemapb.DataType, elemType schemapb.Data data[fieldID] = typeutil.BFloat16BytesToFloat32Vector(bytes) case schemapb.DataType_SparseFloatVector: bytes := v.GetRow(i).([]byte) - elemCount := len(bytes) / 8 - values := make(map[uint32]float32) - for j := 0; j < elemCount; j++ { - idx := common.Endian.Uint32(bytes[j*8:]) - f := typeutil.BytesToFloat32(bytes[j*8+4:]) - values[idx] = f - } - data[fieldID] = values + data[fieldID] = typeutil.SparseFloatBytesToMap(bytes) default: data[fieldID] = v.GetRow(i) } diff --git a/internal/util/importutilv2/json/row_parser.go b/internal/util/importutilv2/json/row_parser.go index 0bd65db599b61..c3f4bcb1c4287 100644 --- a/internal/util/importutilv2/json/row_parser.go +++ b/internal/util/importutilv2/json/row_parser.go @@ -354,7 +354,7 @@ func (r *rowParser) parseEntity(fieldID int64, obj any) (any, error) { if !ok { return nil, r.wrapTypeError(obj, fieldID) } - vec, err := typeutil.CreateSparseFloatRowFromJSON(arr) + vec, err := typeutil.CreateSparseFloatRowFromMap(arr) if err != nil { return nil, err } diff --git a/internal/util/importutilv2/parquet/field_reader.go b/internal/util/importutilv2/parquet/field_reader.go index 858027bfa2c51..090a5e2a638fe 100644 --- 
a/internal/util/importutilv2/parquet/field_reader.go +++ b/internal/util/importutilv2/parquet/field_reader.go @@ -97,31 +97,7 @@ func (c *FieldReader) Next(count int64) (any, error) { case schemapb.DataType_VarChar, schemapb.DataType_String: return ReadStringData(c, count) case schemapb.DataType_JSON: - // JSON field read data from string array Parquet - data, err := ReadStringData(c, count) - if err != nil { - return nil, err - } - if data == nil { - return nil, nil - } - byteArr := make([][]byte, 0) - for _, str := range data.([]string) { - var dummy interface{} - err = json.Unmarshal([]byte(str), &dummy) - if err != nil { - return nil, err - } - if c.field.GetIsDynamic() { - var dummy2 map[string]interface{} - err = json.Unmarshal([]byte(str), &dummy2) - if err != nil { - return nil, err - } - } - byteArr = append(byteArr, []byte(str)) - } - return byteArr, nil + return ReadJSONData(c, count) case schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector: return ReadBinaryData(c, count) case schemapb.DataType_FloatVector: @@ -135,152 +111,9 @@ func (c *FieldReader) Next(count int64) (any, error) { vectors := lo.Flatten(arrayData.([][]float32)) return vectors, nil case schemapb.DataType_SparseFloatVector: - return ReadBinaryDataForSparseFloatVector(c, count) + return ReadSparseFloatVectorData(c, count) case schemapb.DataType_Array: - data := make([]*schemapb.ScalarField, 0, count) - elementType := c.field.GetElementType() - switch elementType { - case schemapb.DataType_Bool: - boolArray, err := ReadBoolArrayData(c, count) - if err != nil { - return nil, err - } - if boolArray == nil { - return nil, nil - } - for _, elementArray := range boolArray.([][]bool) { - data = append(data, &schemapb.ScalarField{ - Data: &schemapb.ScalarField_BoolData{ - BoolData: &schemapb.BoolArray{ - Data: elementArray, - }, - }, - }) - } - case schemapb.DataType_Int8: - int8Array, err := ReadIntegerOrFloatArrayData[int32](c, count) - if err != 
nil { - return nil, err - } - if int8Array == nil { - return nil, nil - } - for _, elementArray := range int8Array.([][]int32) { - data = append(data, &schemapb.ScalarField{ - Data: &schemapb.ScalarField_IntData{ - IntData: &schemapb.IntArray{ - Data: elementArray, - }, - }, - }) - } - case schemapb.DataType_Int16: - int16Array, err := ReadIntegerOrFloatArrayData[int32](c, count) - if err != nil { - return nil, err - } - if int16Array == nil { - return nil, nil - } - for _, elementArray := range int16Array.([][]int32) { - data = append(data, &schemapb.ScalarField{ - Data: &schemapb.ScalarField_IntData{ - IntData: &schemapb.IntArray{ - Data: elementArray, - }, - }, - }) - } - case schemapb.DataType_Int32: - int32Array, err := ReadIntegerOrFloatArrayData[int32](c, count) - if err != nil { - return nil, err - } - if int32Array == nil { - return nil, nil - } - for _, elementArray := range int32Array.([][]int32) { - data = append(data, &schemapb.ScalarField{ - Data: &schemapb.ScalarField_IntData{ - IntData: &schemapb.IntArray{ - Data: elementArray, - }, - }, - }) - } - case schemapb.DataType_Int64: - int64Array, err := ReadIntegerOrFloatArrayData[int64](c, count) - if err != nil { - return nil, err - } - if int64Array == nil { - return nil, nil - } - for _, elementArray := range int64Array.([][]int64) { - data = append(data, &schemapb.ScalarField{ - Data: &schemapb.ScalarField_LongData{ - LongData: &schemapb.LongArray{ - Data: elementArray, - }, - }, - }) - } - case schemapb.DataType_Float: - float32Array, err := ReadIntegerOrFloatArrayData[float32](c, count) - if err != nil { - return nil, err - } - if float32Array == nil { - return nil, nil - } - for _, elementArray := range float32Array.([][]float32) { - data = append(data, &schemapb.ScalarField{ - Data: &schemapb.ScalarField_FloatData{ - FloatData: &schemapb.FloatArray{ - Data: elementArray, - }, - }, - }) - } - case schemapb.DataType_Double: - float64Array, err := ReadIntegerOrFloatArrayData[float64](c, count) - if 
err != nil { - return nil, err - } - if float64Array == nil { - return nil, nil - } - for _, elementArray := range float64Array.([][]float64) { - data = append(data, &schemapb.ScalarField{ - Data: &schemapb.ScalarField_DoubleData{ - DoubleData: &schemapb.DoubleArray{ - Data: elementArray, - }, - }, - }) - } - case schemapb.DataType_VarChar, schemapb.DataType_String: - stringArray, err := ReadStringArrayData(c, count) - if err != nil { - return nil, err - } - if stringArray == nil { - return nil, nil - } - for _, elementArray := range stringArray.([][]string) { - data = append(data, &schemapb.ScalarField{ - Data: &schemapb.ScalarField_StringData{ - StringData: &schemapb.StringArray{ - Data: elementArray, - }, - }, - }) - } - default: - return nil, merr.WrapErrImportFailed(fmt.Sprintf("unsupported data type '%s' for array field '%s'", - elementType.String(), c.field.GetName())) - } - return data, nil + return ReadArrayData(c, count) default: return nil, merr.WrapErrImportFailed(fmt.Sprintf("unsupported data type '%s' for field '%s'", c.field.GetDataType().String(), c.field.GetName())) @@ -388,6 +221,34 @@ func ReadStringData(pcr *FieldReader, count int64) (any, error) { return data, nil } +func ReadJSONData(pcr *FieldReader, count int64) (any, error) { + // JSON field read data from string array Parquet + data, err := ReadStringData(pcr, count) + if err != nil { + return nil, err + } + if data == nil { + return nil, nil + } + byteArr := make([][]byte, 0) + for _, str := range data.([]string) { + var dummy interface{} + err = json.Unmarshal([]byte(str), &dummy) + if err != nil { + return nil, err + } + if pcr.field.GetIsDynamic() { + var dummy2 map[string]interface{} + err = json.Unmarshal([]byte(str), &dummy2) + if err != nil { + return nil, err + } + } + byteArr = append(byteArr, []byte(str)) + } + return byteArr, nil +} + func ReadBinaryData(pcr *FieldReader, count int64) (any, error) { dataType := pcr.field.GetDataType() chunked, err := 
pcr.columnReader.NextBatch(count) @@ -423,38 +284,32 @@ func ReadBinaryData(pcr *FieldReader, count int64) (any, error) { return data, nil } -func ReadBinaryDataForSparseFloatVector(pcr *FieldReader, count int64) (any, error) { - chunked, err := pcr.columnReader.NextBatch(count) +func ReadSparseFloatVectorData(pcr *FieldReader, count int64) (any, error) { + data, err := ReadStringData(pcr, count) if err != nil { return nil, err } - data := make([][]byte, 0, count) + if data == nil { + return nil, nil + } + byteArr := make([][]byte, 0, count) maxDim := uint32(0) - for _, chunk := range chunked.Chunks() { - listReader := chunk.(*array.List) - offsets := listReader.Offsets() - if !isVectorAligned(offsets, pcr.dim, schemapb.DataType_SparseFloatVector) { - return nil, merr.WrapErrImportFailed("%s not aligned", schemapb.DataType_SparseFloatVector.String()) - } - uint8Reader, ok := listReader.ListValues().(*array.Uint8) - if !ok { - return nil, WrapTypeErr("binary", listReader.ListValues().DataType().Name(), pcr.field) + for _, str := range data.([]string) { + rowVec, err := typeutil.CreateSparseFloatRowFromJSON([]byte(str)) + if err != nil { + return nil, merr.WrapErrImportFailed(fmt.Sprintf("Invalid JSON string for SparseFloatVector: '%s'", str)) } - vecData := uint8Reader.Uint8Values() - for i := 1; i < len(offsets); i++ { - elemCount := int((offsets[i] - offsets[i-1]) / 8) - rowVec := vecData[offsets[i-1]:offsets[i]] - data = append(data, rowVec) - maxIdx := typeutil.SparseFloatRowIndexAt(rowVec, elemCount-1) - if maxIdx+1 > maxDim { - maxDim = maxIdx + 1 - } + byteArr = append(byteArr, rowVec) + elemCount := len(rowVec) / 8 + maxIdx := typeutil.SparseFloatRowIndexAt(rowVec, elemCount-1) + if maxIdx+1 > maxDim { + maxDim = maxIdx + 1 } } return &storage.SparseFloatVectorFieldData{ SparseFloatArray: schemapb.SparseFloatArray{ Dim: int64(maxDim), - Contents: data, + Contents: byteArr, }, }, nil } @@ -468,16 +323,6 @@ func checkVectorAlignWithDim(offsets []int32, dim 
int32) bool { return true } -func checkSparseFloatVectorAlign(offsets []int32) bool { - // index: 4 bytes, value: 4 bytes - for i := 1; i < len(offsets); i++ { - if (offsets[i]-offsets[i-1])%8 != 0 { - return false - } - } - return true -} - func isVectorAligned(offsets []int32, dim int, dataType schemapb.DataType) bool { if len(offsets) < 1 { return false @@ -490,7 +335,8 @@ func isVectorAligned(offsets []int32, dim int, dataType schemapb.DataType) bool case schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector: return checkVectorAlignWithDim(offsets, int32(dim*2)) case schemapb.DataType_SparseFloatVector: - return checkSparseFloatVectorAlign(offsets) + // JSON format, skip alignment check + return true default: return false } @@ -626,3 +472,150 @@ func ReadStringArrayData(pcr *FieldReader, count int64) (any, error) { } return data, nil } + +func ReadArrayData(pcr *FieldReader, count int64) (any, error) { + data := make([]*schemapb.ScalarField, 0, count) + elementType := pcr.field.GetElementType() + switch elementType { + case schemapb.DataType_Bool: + boolArray, err := ReadBoolArrayData(pcr, count) + if err != nil { + return nil, err + } + if boolArray == nil { + return nil, nil + } + for _, elementArray := range boolArray.([][]bool) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: elementArray, + }, + }, + }) + } + case schemapb.DataType_Int8: + int8Array, err := ReadIntegerOrFloatArrayData[int32](pcr, count) + if err != nil { + return nil, err + } + if int8Array == nil { + return nil, nil + } + for _, elementArray := range int8Array.([][]int32) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: elementArray, + }, + }, + }) + } + case schemapb.DataType_Int16: + int16Array, err := ReadIntegerOrFloatArrayData[int32](pcr, count) + if err != nil { + return nil, err + } + if int16Array == nil { + return 
nil, nil + } + for _, elementArray := range int16Array.([][]int32) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: elementArray, + }, + }, + }) + } + case schemapb.DataType_Int32: + int32Array, err := ReadIntegerOrFloatArrayData[int32](pcr, count) + if err != nil { + return nil, err + } + if int32Array == nil { + return nil, nil + } + for _, elementArray := range int32Array.([][]int32) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: elementArray, + }, + }, + }) + } + case schemapb.DataType_Int64: + int64Array, err := ReadIntegerOrFloatArrayData[int64](pcr, count) + if err != nil { + return nil, err + } + if int64Array == nil { + return nil, nil + } + for _, elementArray := range int64Array.([][]int64) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: elementArray, + }, + }, + }) + } + case schemapb.DataType_Float: + float32Array, err := ReadIntegerOrFloatArrayData[float32](pcr, count) + if err != nil { + return nil, err + } + if float32Array == nil { + return nil, nil + } + for _, elementArray := range float32Array.([][]float32) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: elementArray, + }, + }, + }) + } + case schemapb.DataType_Double: + float64Array, err := ReadIntegerOrFloatArrayData[float64](pcr, count) + if err != nil { + return nil, err + } + if float64Array == nil { + return nil, nil + } + for _, elementArray := range float64Array.([][]float64) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: elementArray, + }, + }, + }) + } + case schemapb.DataType_VarChar, schemapb.DataType_String: + stringArray, err := ReadStringArrayData(pcr, count) + if err != nil { + 
return nil, err + } + if stringArray == nil { + return nil, nil + } + for _, elementArray := range stringArray.([][]string) { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: elementArray, + }, + }, + }) + } + default: + return nil, merr.WrapErrImportFailed(fmt.Sprintf("unsupported data type '%s' for array field '%s'", + elementType.String(), pcr.field.GetName())) + } + return data, nil +} diff --git a/internal/util/importutilv2/parquet/util.go b/internal/util/importutilv2/parquet/util.go index 4164ff4f6ed03..d74b293474c17 100644 --- a/internal/util/importutilv2/parquet/util.go +++ b/internal/util/importutilv2/parquet/util.go @@ -183,7 +183,7 @@ func convertToArrowDataType(field *schemapb.FieldSchema, isArray bool) (arrow.Da Nullable: true, Metadata: arrow.Metadata{}, }), nil - case schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector, schemapb.DataType_SparseFloatVector: + case schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector: return arrow.ListOfField(arrow.Field{ Name: "item", Type: &arrow.Uint8Type{}, @@ -197,6 +197,8 @@ func convertToArrowDataType(field *schemapb.FieldSchema, isArray bool) (arrow.Da Nullable: true, Metadata: arrow.Metadata{}, }), nil + case schemapb.DataType_SparseFloatVector: + return &arrow.StringType{}, nil default: return nil, merr.WrapErrParameterInvalidMsg("unsupported data type %v", dataType.String()) } diff --git a/internal/util/mock/grpc_datanode_client.go b/internal/util/mock/grpc_datanode_client.go index 7ec47c7c3085d..6226286bbef2c 100644 --- a/internal/util/mock/grpc_datanode_client.go +++ b/internal/util/mock/grpc_datanode_client.go @@ -104,3 +104,7 @@ func (m *GrpcDataNodeClient) QueryImport(ctx context.Context, req *datapb.QueryI func (m *GrpcDataNodeClient) DropImport(ctx context.Context, req *datapb.DropImportRequest, opts ...grpc.CallOption) 
(*commonpb.Status, error) { return &commonpb.Status{}, m.Err } + +func (m *GrpcDataNodeClient) QuerySlot(ctx context.Context, req *datapb.QuerySlotRequest, opts ...grpc.CallOption) (*datapb.QuerySlotResponse, error) { + return &datapb.QuerySlotResponse{}, m.Err +} diff --git a/internal/util/testutil/test_util.go b/internal/util/testutil/test_util.go index a9da8ca3b8ac6..d84affd043f7d 100644 --- a/internal/util/testutil/test_util.go +++ b/internal/util/testutil/test_util.go @@ -1,6 +1,7 @@ package testutil import ( + "encoding/json" "fmt" "math/rand" "strconv" @@ -333,23 +334,24 @@ func BuildArrayData(schema *schemapb.CollectionSchema, insertData *storage.Inser builder.AppendValues(offsets, valid) columns = append(columns, builder.NewListArray()) case schemapb.DataType_SparseFloatVector: - sparseFloatVecData := make([]byte, 0) - builder := array.NewListBuilder(mem, &arrow.Uint8Type{}) + builder := array.NewStringBuilder(mem) contents := insertData.Data[fieldID].(*storage.SparseFloatVectorFieldData).GetContents() rows := len(contents) - offsets := make([]int32, 0, rows) - valid := make([]bool, 0, rows) - currOffset := int32(0) + jsonBytesData := make([][]byte, 0) for i := 0; i < rows; i++ { rowVecData := contents[i] - sparseFloatVecData = append(sparseFloatVecData, rowVecData...) 
- offsets = append(offsets, currOffset) - currOffset = currOffset + int32(len(rowVecData)) - valid = append(valid, true) + mapData := typeutil.SparseFloatBytesToMap(rowVecData) + // convert to JSON format + jsonBytes, err := json.Marshal(mapData) + if err != nil { + return nil, err + } + jsonBytesData = append(jsonBytesData, jsonBytes) } - builder.ValueBuilder().(*array.Uint8Builder).AppendValues(sparseFloatVecData, nil) - builder.AppendValues(offsets, valid) - columns = append(columns, builder.NewListArray()) + builder.AppendValues(lo.Map(jsonBytesData, func(bs []byte, _ int) string { + return string(bs) + }), nil) + columns = append(columns, builder.NewStringArray()) case schemapb.DataType_JSON: builder := array.NewStringBuilder(mem) jsonData := insertData.Data[fieldID].(*storage.JSONFieldData).Data diff --git a/pkg/util/constant.go b/pkg/util/constant.go index 7bdad8a371f02..36e52d83dce45 100644 --- a/pkg/util/constant.go +++ b/pkg/util/constant.go @@ -70,6 +70,8 @@ const ( RoleConfigObjectName = "object_name" RoleConfigDBName = "db_name" RoleConfigPrivilege = "privilege" + + MaxEtcdTxnNum = 128 ) const ( diff --git a/pkg/util/etcd/etcd_util.go b/pkg/util/etcd/etcd_util.go index b92d651a2f8d5..37717ed4d252e 100644 --- a/pkg/util/etcd/etcd_util.go +++ b/pkg/util/etcd/etcd_util.go @@ -33,8 +33,6 @@ import ( "github.com/milvus-io/milvus/pkg/log" ) -var maxTxnNum = 128 - // GetEtcdClient returns etcd client // should only used for test func GetEtcdClient( @@ -191,11 +189,6 @@ func SaveByBatchWithLimit(kvs map[string]string, limit int, op func(partialKvs m return nil } -// SaveByBatch there will not guarantee atomicity. 
-func SaveByBatch(kvs map[string]string, op func(partialKvs map[string]string) error) error { - return SaveByBatchWithLimit(kvs, maxTxnNum, op) -} - func RemoveByBatchWithLimit(removals []string, limit int, op func(partialKeys []string) error) error { if len(removals) == 0 { return nil @@ -211,10 +204,6 @@ func RemoveByBatchWithLimit(removals []string, limit int, op func(partialKeys [] return nil } -func RemoveByBatch(removals []string, op func(partialKeys []string) error) error { - return RemoveByBatchWithLimit(removals, maxTxnNum, op) -} - func buildKvGroup(keys, values []string) (map[string]string, error) { if len(keys) != len(values) { return nil, fmt.Errorf("length of keys (%d) and values (%d) are not equal", len(keys), len(values)) diff --git a/pkg/util/etcd/etcd_util_test.go b/pkg/util/etcd/etcd_util_test.go index 86a60ae4eab2f..aa94d49dcb9c7 100644 --- a/pkg/util/etcd/etcd_util_test.go +++ b/pkg/util/etcd/etcd_util_test.go @@ -104,8 +104,8 @@ func Test_SaveByBatch(t *testing.T) { return nil } - maxTxnNum = 2 - err := SaveByBatch(kvs, saveFn) + limit := 2 + err := SaveByBatchWithLimit(kvs, limit, saveFn) assert.NoError(t, err) assert.Equal(t, 0, group) assert.Equal(t, 0, count) @@ -126,8 +126,8 @@ func Test_SaveByBatch(t *testing.T) { return nil } - maxTxnNum = 2 - err := SaveByBatch(kvs, saveFn) + limit := 2 + err := SaveByBatchWithLimit(kvs, limit, saveFn) assert.NoError(t, err) assert.Equal(t, 2, group) assert.Equal(t, 3, count) @@ -142,8 +142,8 @@ func Test_SaveByBatch(t *testing.T) { "k2": "v2", "k3": "v3", } - maxTxnNum = 2 - err := SaveByBatch(kvs, saveFn) + limit := 2 + err := SaveByBatchWithLimit(kvs, limit, saveFn) assert.Error(t, err) }) } @@ -160,8 +160,8 @@ func Test_RemoveByBatch(t *testing.T) { return nil } - maxTxnNum = 2 - err := RemoveByBatch(kvs, removeFn) + limit := 2 + err := RemoveByBatchWithLimit(kvs, limit, removeFn) assert.NoError(t, err) assert.Equal(t, 0, group) assert.Equal(t, 0, count) @@ -178,8 +178,8 @@ func 
Test_RemoveByBatch(t *testing.T) { return nil } - maxTxnNum = 2 - err := RemoveByBatch(kvs, removeFn) + limit := 2 + err := RemoveByBatchWithLimit(kvs, limit, removeFn) assert.NoError(t, err) assert.Equal(t, 3, group) assert.Equal(t, 5, count) @@ -190,8 +190,8 @@ func Test_RemoveByBatch(t *testing.T) { return errors.New("mock") } kvs := []string{"k1", "k2", "k3", "k4", "k5"} - maxTxnNum = 2 - err := RemoveByBatch(kvs, removeFn) + limit := 2 + err := RemoveByBatchWithLimit(kvs, limit, removeFn) assert.Error(t, err) }) } diff --git a/pkg/util/indexparamcheck/bitmap_checker_test.go b/pkg/util/indexparamcheck/bitmap_checker_test.go new file mode 100644 index 0000000000000..4b0cca2bf3309 --- /dev/null +++ b/pkg/util/indexparamcheck/bitmap_checker_test.go @@ -0,0 +1,22 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Test_BitmapIndexChecker(t *testing.T) { + c := newBITMAPChecker() + + assert.NoError(t, c.CheckTrain(map[string]string{})) + + assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Int64)) + assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Float)) + assert.NoError(t, c.CheckValidDataType(schemapb.DataType_String)) + + assert.Error(t, c.CheckValidDataType(schemapb.DataType_JSON)) + assert.Error(t, c.CheckValidDataType(schemapb.DataType_Array)) +} diff --git a/pkg/util/indexparamcheck/bitmap_index_checker.go b/pkg/util/indexparamcheck/bitmap_index_checker.go new file mode 100644 index 0000000000000..da90a7d06db3a --- /dev/null +++ b/pkg/util/indexparamcheck/bitmap_index_checker.go @@ -0,0 +1,28 @@ +package indexparamcheck + +import ( + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// BITMAPChecker checks if a BITMAP index can be built. 
+type BITMAPChecker struct { + scalarIndexChecker +} + +func (c *BITMAPChecker) CheckTrain(params map[string]string) error { + return c.scalarIndexChecker.CheckTrain(params) +} + +func (c *BITMAPChecker) CheckValidDataType(dType schemapb.DataType) error { + if !typeutil.IsArithmetic(dType) && !typeutil.IsStringType(dType) { + return fmt.Errorf("bitmap index are only supported on numeric and string field") + } + return nil +} + +func newBITMAPChecker() *BITMAPChecker { + return &BITMAPChecker{} +} diff --git a/pkg/util/indexparamcheck/conf_adapter_mgr.go b/pkg/util/indexparamcheck/conf_adapter_mgr.go index 9fdc1a1af6086..d79196f72a619 100644 --- a/pkg/util/indexparamcheck/conf_adapter_mgr.go +++ b/pkg/util/indexparamcheck/conf_adapter_mgr.go @@ -65,6 +65,7 @@ func (mgr *indexCheckerMgrImpl) registerIndexChecker() { mgr.checkers["Asceneding"] = newSTLSORTChecker() mgr.checkers[IndexTRIE] = newTRIEChecker() mgr.checkers[IndexTrie] = newTRIEChecker() + mgr.checkers[IndexBitmap] = newBITMAPChecker() mgr.checkers["marisa-trie"] = newTRIEChecker() mgr.checkers[AutoIndex] = newAUTOINDEXChecker() } diff --git a/pkg/util/indexparamcheck/index_type.go b/pkg/util/indexparamcheck/index_type.go index 7b24202f02a16..e752057ea4e85 100644 --- a/pkg/util/indexparamcheck/index_type.go +++ b/pkg/util/indexparamcheck/index_type.go @@ -37,6 +37,7 @@ const ( IndexSTLSORT IndexType = "STL_SORT" IndexTRIE IndexType = "TRIE" IndexTrie IndexType = "Trie" + IndexBitmap IndexType = "BITMAP" AutoIndex IndexType = "AUTOINDEX" ) diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index 79cbf64c412cf..117757815a39b 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -3406,6 +3406,9 @@ type dataNodeConfig struct { L0BatchMemoryRatio ParamItem `refreshable:"true"` GracefulStopTimeout ParamItem `refreshable:"true"` + + // slot + SlotCap ParamItem `refreshable:"true"` } func (p *dataNodeConfig) init(base 
*BaseTable) { @@ -3711,6 +3714,15 @@ if this parameter <= 0, will set it as 10`, Export: true, } p.GracefulStopTimeout.Init(base.mgr) + + p.SlotCap = ParamItem{ + Key: "dataNode.slot.slotCap", + Version: "2.4.2", + DefaultValue: "2", + Doc: "The maximum number of tasks(e.g. compaction, importing) allowed to run concurrently on a datanode", + Export: true, + } + p.SlotCap.Init(base.mgr) } // ///////////////////////////////////////////////////////////////////////////// diff --git a/pkg/util/paramtable/component_param_test.go b/pkg/util/paramtable/component_param_test.go index c7aea9fb820d0..d3918c9d432e3 100644 --- a/pkg/util/paramtable/component_param_test.go +++ b/pkg/util/paramtable/component_param_test.go @@ -481,6 +481,7 @@ func TestComponentParam(t *testing.T) { assert.Equal(t, 16, Params.ReadBufferSizeInMB.GetAsInt()) params.Save("datanode.gracefulStopTimeout", "100") assert.Equal(t, 100*time.Second, Params.GracefulStopTimeout.GetAsDuration(time.Second)) + assert.Equal(t, 2, Params.SlotCap.GetAsInt()) }) t.Run("test indexNodeConfig", func(t *testing.T) { diff --git a/pkg/util/typeutil/convension.go b/pkg/util/typeutil/convension.go index 0bde57a7eb5b8..95e138b5c50a8 100644 --- a/pkg/util/typeutil/convension.go +++ b/pkg/util/typeutil/convension.go @@ -154,3 +154,14 @@ func BFloat16BytesToFloat32Vector(b []byte) []float32 { } return vec } + +func SparseFloatBytesToMap(b []byte) map[uint32]float32 { + elemCount := len(b) / 8 + values := make(map[uint32]float32) + for j := 0; j < elemCount; j++ { + idx := common.Endian.Uint32(b[j*8:]) + f := BytesToFloat32(b[j*8+4:]) + values[idx] = f + } + return values +} diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index 7a2af87d2e242..8277ccbe438a1 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -17,6 +17,7 @@ package typeutil import ( + "bytes" "encoding/binary" "encoding/json" "fmt" @@ -1535,18 +1536,13 @@ func CreateSparseFloatRow(indices []uint32, values []float32) 
[]byte { return row } -type sparseFloatVectorJSONRepresentation struct { - Indices []uint32 `json:"indices"` - Values []float32 `json:"values"` -} - // accepted format: // - {"indices": [1, 2, 3], "values": [0.1, 0.2, 0.3]} # format1 // - {"1": 0.1, "2": 0.2, "3": 0.3} # format2 // // we don't require the indices to be sorted from user input, but the returned // byte representation must have indices sorted -func CreateSparseFloatRowFromJSON(input map[string]interface{}) ([]byte, error) { +func CreateSparseFloatRowFromMap(input map[string]interface{}) ([]byte, error) { var indices []uint32 var values []float32 @@ -1601,6 +1597,17 @@ func CreateSparseFloatRowFromJSON(input map[string]interface{}) ([]byte, error) return row, nil } +func CreateSparseFloatRowFromJSON(input []byte) ([]byte, error) { + var vec map[string]interface{} + decoder := json.NewDecoder(bytes.NewReader(input)) + decoder.DisallowUnknownFields() + err := decoder.Decode(&vec) + if err != nil { + return nil, err + } + return CreateSparseFloatRowFromMap(vec) +} + // dim of a sparse float vector is the maximum/last index + 1 func SparseFloatRowDim(row []byte) int64 { if len(row) == 0 { diff --git a/pkg/util/typeutil/schema_test.go b/pkg/util/typeutil/schema_test.go index 606ba7b4ec69d..67601a719d9e9 100644 --- a/pkg/util/typeutil/schema_test.go +++ b/pkg/util/typeutil/schema_test.go @@ -2121,89 +2121,89 @@ func TestValidateSparseFloatRows(t *testing.T) { func TestParseJsonSparseFloatRow(t *testing.T) { t.Run("valid row 1", func(t *testing.T) { row := map[string]interface{}{"indices": []uint32{1, 3, 5}, "values": []float32{1.0, 2.0, 3.0}} - res, err := CreateSparseFloatRowFromJSON(row) + res, err := CreateSparseFloatRowFromMap(row) assert.NoError(t, err) assert.Equal(t, CreateSparseFloatRow([]uint32{1, 3, 5}, []float32{1.0, 2.0, 3.0}), res) }) t.Run("valid row 2", func(t *testing.T) { row := map[string]interface{}{"indices": []uint32{3, 1, 5}, "values": []float32{1.0, 2.0, 3.0}} - res, err := 
CreateSparseFloatRowFromJSON(row) + res, err := CreateSparseFloatRowFromMap(row) assert.NoError(t, err) assert.Equal(t, CreateSparseFloatRow([]uint32{1, 3, 5}, []float32{2.0, 1.0, 3.0}), res) }) t.Run("invalid row 1", func(t *testing.T) { row := map[string]interface{}{"indices": []uint32{1, 3, 5}, "values": []float32{1.0, 2.0}} - _, err := CreateSparseFloatRowFromJSON(row) + _, err := CreateSparseFloatRowFromMap(row) assert.Error(t, err) }) t.Run("invalid row 2", func(t *testing.T) { row := map[string]interface{}{"indices": []uint32{1}, "values": []float32{1.0, 2.0}} - _, err := CreateSparseFloatRowFromJSON(row) + _, err := CreateSparseFloatRowFromMap(row) assert.Error(t, err) }) t.Run("invalid row 3", func(t *testing.T) { row := map[string]interface{}{"indices": []uint32{}, "values": []float32{}} - _, err := CreateSparseFloatRowFromJSON(row) + _, err := CreateSparseFloatRowFromMap(row) assert.Error(t, err) }) t.Run("invalid row 4", func(t *testing.T) { row := map[string]interface{}{"indices": []uint32{3}, "values": []float32{-0.2}} - _, err := CreateSparseFloatRowFromJSON(row) + _, err := CreateSparseFloatRowFromMap(row) assert.Error(t, err) }) t.Run("valid dict row 1", func(t *testing.T) { row := map[string]interface{}{"1": 1.0, "3": 2.0, "5": 3.0} - res, err := CreateSparseFloatRowFromJSON(row) + res, err := CreateSparseFloatRowFromMap(row) assert.NoError(t, err) assert.Equal(t, CreateSparseFloatRow([]uint32{1, 3, 5}, []float32{1.0, 2.0, 3.0}), res) }) t.Run("valid dict row 2", func(t *testing.T) { row := map[string]interface{}{"3": 1.0, "1": 2.0, "5": 3.0} - res, err := CreateSparseFloatRowFromJSON(row) + res, err := CreateSparseFloatRowFromMap(row) assert.NoError(t, err) assert.Equal(t, CreateSparseFloatRow([]uint32{1, 3, 5}, []float32{2.0, 1.0, 3.0}), res) }) t.Run("invalid dict row 1", func(t *testing.T) { row := map[string]interface{}{"a": 1.0, "3": 2.0, "5": 3.0} - _, err := CreateSparseFloatRowFromJSON(row) + _, err := CreateSparseFloatRowFromMap(row) 
assert.Error(t, err) }) t.Run("invalid dict row 2", func(t *testing.T) { row := map[string]interface{}{"1": "a", "3": 2.0, "5": 3.0} - _, err := CreateSparseFloatRowFromJSON(row) + _, err := CreateSparseFloatRowFromMap(row) assert.Error(t, err) }) t.Run("invalid dict row 3", func(t *testing.T) { row := map[string]interface{}{"1": "1.0", "3": 2.0, "5": 3.0} - _, err := CreateSparseFloatRowFromJSON(row) + _, err := CreateSparseFloatRowFromMap(row) assert.Error(t, err) }) t.Run("invalid dict row 4", func(t *testing.T) { row := map[string]interface{}{"-1": 1.0, "3": 2.0, "5": 3.0} - _, err := CreateSparseFloatRowFromJSON(row) + _, err := CreateSparseFloatRowFromMap(row) assert.Error(t, err) }) t.Run("invalid dict row 5", func(t *testing.T) { row := map[string]interface{}{"1": -1.0, "3": 2.0, "5": 3.0} - _, err := CreateSparseFloatRowFromJSON(row) + _, err := CreateSparseFloatRowFromMap(row) assert.Error(t, err) }) t.Run("invalid dict row 6", func(t *testing.T) { row := map[string]interface{}{} - _, err := CreateSparseFloatRowFromJSON(row) + _, err := CreateSparseFloatRowFromMap(row) assert.Error(t, err) }) } diff --git a/scripts/README.md b/scripts/README.md index 6b702620fe483..8cb64fbca7dc4 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -23,6 +23,14 @@ $ go get github.com/golang/protobuf/protoc-gen-go@v1.3.2 Install OpenBlas library +install using apt + +```shell +sudo apt install -y libopenblas-dev +``` + +or build from source code + ```shell $ wget https://github.com/xianyi/OpenBLAS/archive/v0.3.9.tar.gz && \ $ tar zxvf v0.3.9.tar.gz && cd OpenBLAS-0.3.9 && \ diff --git a/scripts/install_deps.sh b/scripts/install_deps.sh index 608262f425c06..8ae371aa2a075 100755 --- a/scripts/install_deps.sh +++ b/scripts/install_deps.sh @@ -22,7 +22,7 @@ function install_linux_deps() { sudo apt install -y wget curl ca-certificates gnupg2 \ g++ gcc gfortran git make ccache libssl-dev zlib1g-dev zip unzip \ clang-format-10 clang-tidy-10 lcov libtool m4 autoconf automake 
python3 python3-pip \ - pkg-config uuid-dev libaio-dev libgoogle-perftools-dev + pkg-config uuid-dev libaio-dev libopenblas-dev libgoogle-perftools-dev sudo pip3 install conan==1.61.0 elif [[ -x "$(command -v yum)" ]]; then @@ -31,7 +31,7 @@ function install_linux_deps() { sudo yum install -y wget curl which \ git make automake python3-devel \ devtoolset-11-gcc devtoolset-11-gcc-c++ devtoolset-11-gcc-gfortran devtoolset-11-libatomic-devel \ - llvm-toolset-11.0-clang llvm-toolset-11.0-clang-tools-extra \ + llvm-toolset-11.0-clang llvm-toolset-11.0-clang-tools-extra openblas-devel \ libaio libuuid-devel zip unzip \ ccache lcov libtool m4 autoconf automake diff --git a/tests/integration/import/util_test.go b/tests/integration/import/util_test.go index 8df58c0541bd3..237a705ec2474 100644 --- a/tests/integration/import/util_test.go +++ b/tests/integration/import/util_test.go @@ -37,7 +37,6 @@ import ( "github.com/milvus-io/milvus/internal/storage" pq "github.com/milvus-io/milvus/internal/util/importutilv2/parquet" "github.com/milvus-io/milvus/internal/util/testutil" - "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -234,14 +233,7 @@ func GenerateJSONFile(t *testing.T, filePath string, schema *schemapb.Collection data[fieldID] = typeutil.BFloat16BytesToFloat32Vector(bytes) case schemapb.DataType_SparseFloatVector: bytes := v.GetRow(i).([]byte) - elemCount := len(bytes) / 8 - values := make(map[uint32]float32) - for j := 0; j < elemCount; j++ { - idx := common.Endian.Uint32(bytes[j*8:]) - f := typeutil.BytesToFloat32(bytes[j*8+4:]) - values[idx] = f - } - data[fieldID] = values + data[fieldID] = typeutil.SparseFloatBytesToMap(bytes) default: data[fieldID] = v.GetRow(i) } diff --git a/tests/python_client/requirements.txt b/tests/python_client/requirements.txt index 84976974b8c84..177e44cd3692f 100644 --- a/tests/python_client/requirements.txt +++ 
b/tests/python_client/requirements.txt @@ -12,6 +12,7 @@ allure-pytest==2.7.0 pytest-print==0.2.1 pytest-level==0.1.1 pytest-xdist==2.5.0 +pymilvus==2.5.0rc31 pymilvus[bulk_writer]==2.5.0rc31 pytest-rerunfailures==9.1.1 git+https://github.com/Projectplace/pytest-tags diff --git a/tests/python_client/testcases/test_utility.py b/tests/python_client/testcases/test_utility.py index d077d83d7f056..f4eccf19597cc 100644 --- a/tests/python_client/testcases/test_utility.py +++ b/tests/python_client/testcases/test_utility.py @@ -57,42 +57,67 @@ def get_invalid_value_collection_name(self, request): """ @pytest.mark.tags(CaseLabel.L2) - def test_has_collection_name_invalid(self, get_invalid_collection_name): + def test_has_collection_name_type_invalid(self, get_invalid_type_collection_name): """ target: test has_collection with error collection name method: input invalid name expected: raise exception """ self._connect() - c_name = get_invalid_collection_name - if isinstance(c_name, str) and c_name: - self.utility_wrap.has_collection( - c_name, - check_task=CheckTasks.err_res, - check_items={ct.err_code: 1100, - ct.err_msg: "collection name should not be empty: invalid parameter"}) - # elif not isinstance(c_name, str): self.utility_wrap.has_collection(c_name, check_task=CheckTasks.err_res, - # check_items={ct.err_code: 1, ct.err_msg: "illegal"}) + c_name = get_invalid_type_collection_name + self.utility_wrap.has_collection(c_name, check_task=CheckTasks.err_res, + check_items={ct.err_code: 999, + ct.err_msg: f"`collection_name` value {c_name} is illegal"}) @pytest.mark.tags(CaseLabel.L2) - def test_has_partition_collection_name_invalid(self, get_invalid_collection_name): + def test_has_collection_name_value_invalid(self, get_invalid_value_collection_name): + """ + target: test has_collection with error collection name + method: input invalid name + expected: raise exception + """ + self._connect() + c_name = get_invalid_value_collection_name + error = {ct.err_code: 999, 
ct.err_msg: f"Invalid collection name: {c_name}"} + if c_name in [None, ""]: + error = {ct.err_code: 999, ct.err_msg: f"`collection_name` value {c_name} is illegal"} + elif c_name == " ": + error = {ct.err_code: 999, ct.err_msg: "collection name should not be empty: invalid parameter"} + self.utility_wrap.has_collection(c_name, check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L2) + def test_has_partition_collection_name_type_invalid(self, get_invalid_type_collection_name): """ target: test has_partition with error collection name method: input invalid name expected: raise exception """ self._connect() - c_name = get_invalid_collection_name + c_name = get_invalid_type_collection_name p_name = cf.gen_unique_str(prefix) - if isinstance(c_name, str) and c_name: - self.utility_wrap.has_partition( - c_name, p_name, - check_task=CheckTasks.err_res, - check_items={ct.err_code: 1100, - ct.err_msg: "collection name should not be empty: invalid parameter"}) + self.utility_wrap.has_partition(c_name, p_name, check_task=CheckTasks.err_res, + check_items={ct.err_code: 999, + ct.err_msg: f"`collection_name` value {c_name} is illegal"}) @pytest.mark.tags(CaseLabel.L2) - def test_has_partition_name_invalid(self, get_invalid_partition_name): + def test_has_partition_collection_name_value_invalid(self, get_invalid_value_collection_name): + """ + target: test has_partition with error collection name + method: input invalid name + expected: raise exception + """ + self._connect() + c_name = get_invalid_value_collection_name + p_name = cf.gen_unique_str(prefix) + error = {ct.err_code: 999, ct.err_msg: f"Invalid collection name: {c_name}"} + if c_name in [None, ""]: + error = {ct.err_code: 999, ct.err_msg: f"`collection_name` value {c_name} is illegal"} + elif c_name == " ": + error = {ct.err_code: 999, ct.err_msg: "collection name should not be empty: invalid parameter"} + self.utility_wrap.has_partition(c_name, p_name, check_task=CheckTasks.err_res, 
check_items=error) + + @pytest.mark.tags(CaseLabel.L2) + def test_has_partition_name_type_invalid(self, get_invalid_type_collection_name): """ target: test has_partition with error partition name method: input invalid name @@ -101,21 +126,49 @@ def test_has_partition_name_invalid(self, get_invalid_partition_name): self._connect() ut = ApiUtilityWrapper() c_name = cf.gen_unique_str(prefix) - p_name = get_invalid_partition_name - if isinstance(p_name, str) and p_name: - ex, _ = ut.has_partition( - c_name, p_name, - check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: "Invalid"}) + p_name = get_invalid_type_collection_name + ut.has_partition(c_name, p_name, check_task=CheckTasks.err_res, + check_items={ct.err_code: 999, + ct.err_msg: f"`partition_name` value {p_name} is illegal"}) @pytest.mark.tags(CaseLabel.L2) - def test_drop_collection_name_invalid(self, get_invalid_collection_name): + def test_has_partition_name_value_invalid(self, get_invalid_value_collection_name): + """ + target: test has_partition with error partition name + method: input invalid name + expected: raise exception + """ self._connect() - error1 = {ct.err_code: 1, ct.err_msg: f"`collection_name` value {get_invalid_collection_name} is illegal"} - error2 = {ct.err_code: 1100, ct.err_msg: f"Invalid collection name: {get_invalid_collection_name}."} - error = error1 if get_invalid_collection_name in [[], 1, [1, '2', 3], (1,), {1: 1}, None, ""] else error2 - self.utility_wrap.drop_collection(get_invalid_collection_name, check_task=CheckTasks.err_res, - check_items=error) + ut = ApiUtilityWrapper() + c_name = cf.gen_unique_str(prefix) + p_name = get_invalid_value_collection_name + if p_name == "12name": + pytest.skip("partition name 12name is legal") + error = {ct.err_code: 999, ct.err_msg: f"Invalid partition name: {p_name}"} + if p_name in [None]: + error = {ct.err_code: 999, ct.err_msg: f"`partition_name` value {p_name} is illegal"} + elif p_name in [" ", ""]: + error = 
{ct.err_code: 999, ct.err_msg: "Invalid partition name: . Partition name should not be empty."} + ut.has_partition(c_name, p_name, check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L2) + def test_drop_collection_name_type_invalid(self, get_invalid_type_collection_name): + self._connect() + c_name = get_invalid_type_collection_name + self.utility_wrap.drop_collection(c_name, check_task=CheckTasks.err_res, + check_items={ct.err_code: 999, + ct.err_msg: f"`collection_name` value {c_name} is illegal"}) + + @pytest.mark.tags(CaseLabel.L2) + def test_drop_collection_name_value_invalid(self, get_invalid_value_collection_name): + self._connect() + c_name = get_invalid_value_collection_name + error = {ct.err_code: 999, ct.err_msg: f"Invalid collection name: {c_name}"} + if c_name in [None, ""]: + error = {ct.err_code: 999, ct.err_msg: f"`collection_name` value {c_name} is illegal"} + elif c_name == " ": + error = {ct.err_code: 999, ct.err_msg: "collection name should not be empty: invalid parameter"} + self.utility_wrap.drop_collection(c_name, check_task=CheckTasks.err_res, check_items=error) # TODO: enable @pytest.mark.tags(CaseLabel.L2) @@ -162,7 +215,8 @@ def test_index_process_invalid_index_name(self, invalid_index_name): check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) - def test_wait_index_invalid_name(self, get_invalid_collection_name): + @pytest.mark.skip("not ready") + def test_wait_index_invalid_name(self, get_invalid_type_collection_name): """ target: test wait_index method: input invalid name @@ -436,13 +490,12 @@ def test_rename_collection_new_invalid_value(self, get_invalid_value_collection_ collection_w, vectors, _, insert_ids, _ = self.init_collection_general(prefix) old_collection_name = collection_w.name new_collection_name = get_invalid_value_collection_name + error = {"err_code": 1100, "err_msg": "Invalid collection name: %s. 
the first character of a collection name mu" + "st be an underscore or letter: invalid parameter" % new_collection_name} + if new_collection_name in [None, ""]: + error = {"err_code": 999, "err_msg": f"`collection_name` value {new_collection_name} is illegal"} self.utility_wrap.rename_collection(old_collection_name, new_collection_name, - check_task=CheckTasks.err_res, - check_items={"err_code": 1100, - "err_msg": "Invalid collection name: %s. the first " - "character of a collection name must be an " - "underscore or letter: invalid parameter" - % new_collection_name}) + check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_rename_collection_not_existed_collection(self):