Skip to content

Commit

Permalink
[BugFix] Fix getNextWorker overflow (#53213)
Browse files Browse the repository at this point in the history
Signed-off-by: zihe.liu <[email protected]>
  • Loading branch information
ZiheLiu authored Nov 27, 2024
1 parent 4ebd6b8 commit 4e70ef5
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.IntSupplier;
import java.util.stream.Collectors;

import static com.starrocks.qe.WorkerProviderHelper.getNextWorker;

/**
* WorkerProvider for SHARED_DATA mode. Compared to its counterpart for SHARED_NOTHING mode:
* 1. All Backends and ComputeNodes are treated the same as ComputeNodes.
Expand Down Expand Up @@ -271,13 +272,9 @@ static int getNextComputeNodeIndex() {
return NEXT_COMPUTE_NODE_INDEX.getAndIncrement();
}

private static ComputeNode getNextWorker(ImmutableMap<Long, ComputeNode> workers,
IntSupplier getNextWorkerNodeIndex) {
if (workers.isEmpty()) {
return null;
}
int index = getNextWorkerNodeIndex.getAsInt() % workers.size();
return workers.values().asList().get(index);
/**
 * Returns the {@code NEXT_COMPUTE_NODE_INDEX} counter object itself (not a copy of its value),
 * so a caller can read or overwrite the counter directly.
 * Marked {@code @VisibleForTesting}: intended for tests only.
 */
@VisibleForTesting
static AtomicInteger getNextComputeNodeIndexer() {
return NEXT_COMPUTE_NODE_INDEX;
}

private static ImmutableMap<Long, ComputeNode> filterAvailableWorkers(ImmutableMap<Long, ComputeNode> workers) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2021-present StarRocks, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.starrocks.qe;

import com.google.common.collect.ImmutableMap;
import com.starrocks.system.ComputeNode;

import java.util.function.IntSupplier;

/**
 * Static helper shared by the worker provider implementations for picking the next worker.
 */
public class WorkerProviderHelper {
    /**
     * Selects one worker from {@code workers} based on the index produced by {@code getNextWorkerNodeIndex}.
     *
     * <p>The raw index is reduced with {@link Math#floorMod(int, int)}, so any {@code int} value, including a
     * negative one produced after an integer counter wraps around, maps to a valid position in
     * {@code [0, workers.size())}, and consecutive raw indexes keep advancing through the workers in the same
     * direction on both sides of the wrap-around.
     *
     * @param workers candidate workers; the position of a worker is its position in {@code workers.values()}
     * @param getNextWorkerNodeIndex supplier of the next raw index; only invoked when {@code workers} is non-empty
     * @param <C> concrete worker type
     * @return the selected worker, or {@code null} when {@code workers} is empty
     */
    public static <C extends ComputeNode> C getNextWorker(ImmutableMap<Long, C> workers,
                                                          IntSupplier getNextWorkerNodeIndex) {
        if (workers.isEmpty()) {
            return null;
        }
        // Unlike %, whose result takes the sign of the dividend, floorMod with a positive divisor is never negative.
        int index = Math.floorMod(getNextWorkerNodeIndex.getAsInt(), workers.size());
        return workers.values().asList().get(index);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,17 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.IntSupplier;
import java.util.stream.Collectors;

import static com.starrocks.qe.WorkerProviderHelper.getNextWorker;

/**
* DefaultWorkerProvider handles ComputeNode/Backend selection in SHARED_NOTHING mode.
* NOTE: remember to update DefaultSharedDataWorkerProvider if the change applies to both run modes.
*/
public class DefaultWorkerProvider implements WorkerProvider {
private static final Logger LOG = LogManager.getLogger(DefaultWorkerProvider.class);

private static final AtomicInteger NEXT_COMPUTE_NODE_INDEX = new AtomicInteger(0);
private static final AtomicInteger NEXT_BACKEND_INDEX = new AtomicInteger(0);

Expand Down Expand Up @@ -401,19 +403,15 @@ private static ImmutableMap<Long, ComputeNode> buildComputeNodeInfo(SystemInfoSe
return ImmutableMap.copyOf(computeNodes);
}

/**
 * Selects one worker from {@code workers} based on the index produced by {@code getNextWorkerNodeIndex}.
 *
 * <p>The raw index is reduced with {@link Math#floorMod(int, int)} rather than {@code %}: once the supplier
 * returns a negative value (e.g. after an integer counter wraps around), {@code %} yields a negative result
 * and the subsequent list lookup would be given a negative index.
 *
 * @param workers candidate workers; the position of a worker is its position in {@code workers.values()}
 * @param getNextWorkerNodeIndex supplier of the next raw index; only invoked when {@code workers} is non-empty
 * @return the selected worker, or {@code null} when {@code workers} is empty
 */
private static <C extends ComputeNode> C getNextWorker(ImmutableMap<Long, C> workers,
                                                       IntSupplier getNextWorkerNodeIndex) {
    if (workers.isEmpty()) {
        return null;
    }
    // floorMod with a positive divisor always lands in [0, size).
    int index = Math.floorMod(getNextWorkerNodeIndex.getAsInt(), workers.size());
    return workers.values().asList().get(index);
}

/**
 * Reports whether {@code worker} is usable: it must be alive and its id must not be in the
 * {@code SimpleScheduler} blocklist.
 *
 * @param worker the node to check
 * @return {@code true} only if the worker is alive and not blocklisted
 */
public static boolean isWorkerAvailable(ComputeNode worker) {
    // Liveness is checked first; the blocklist is only consulted for live workers.
    if (!worker.isAlive()) {
        return false;
    }
    return !SimpleScheduler.isInBlocklist(worker.getId());
}

/**
 * Returns the {@code NEXT_COMPUTE_NODE_INDEX} counter object itself (not a copy of its value),
 * so a caller can read or overwrite the counter directly.
 * Marked {@code @VisibleForTesting}: intended for tests only.
 */
@VisibleForTesting
static AtomicInteger getNextComputeNodeIndexer() {
return NEXT_COMPUTE_NODE_INDEX;
}

private static <C extends ComputeNode> ImmutableMap<Long, C> filterAvailableWorkers(ImmutableMap<Long, C> workers) {
return ImmutableMap.copyOf(
workers.entrySet().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static org.assertj.core.api.Assertions.assertThat;

public class DefaultSharedDataWorkerProviderTest {
private Map<Long, ComputeNode> id2Backend;
private Map<Long, ComputeNode> id2ComputeNode;
Expand Down Expand Up @@ -658,4 +660,19 @@ public void testCollocationBackendSelectorWithSharedDataWorkerProvider() {
Assert.assertThrows(NonRecoverableException.class, selector::computeScanRangeAssignment);
}
}

/**
 * Checks that selectNextWorker() keeps returning a non-negative worker id both before and after the
 * compute-node counter is forced to Integer.MAX_VALUE.
 */
@Test
public void testNextWorkerOverflow() throws NonRecoverableException {
    WorkerProvider provider = new DefaultSharedDataWorkerProvider(
            ImmutableMap.copyOf(id2AllNodes), ImmutableMap.copyOf(id2AllNodes));

    // Phase 1: selections with the counter in whatever state it currently has.
    int remaining = 100;
    while (remaining-- > 0) {
        Long selectedId = provider.selectNextWorker();
        assertThat(selectedId).isNotNegative();
    }

    // Phase 2: put the counter at the largest int value, then select again.
    AtomicInteger indexer = DefaultSharedDataWorkerProvider.getNextComputeNodeIndexer();
    indexer.set(Integer.MAX_VALUE);
    remaining = 100;
    while (remaining-- > 0) {
        Long selectedId = provider.selectNextWorker();
        assertThat(selectedId).isNotNegative();
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ private static <C extends ComputeNode> ImmutableMap<Long, C> genWorkers(long sta
return ImmutableMap.copyOf(res);
}


@Test
public void testCaptureAvailableWorkers() {

Expand Down Expand Up @@ -129,7 +128,7 @@ public ImmutableMap<Long, ComputeNode> getIdComputeNode() {

workerProvider =
workerProviderFactory.captureAvailableWorkers(GlobalStateMgr.getCurrentState().getNodeMgr().getClusterInfo(),
true, numUsedComputeNodes, ComputationFragmentSchedulingPolicy.COMPUTE_NODES_ONLY,
true, numUsedComputeNodes, ComputationFragmentSchedulingPolicy.COMPUTE_NODES_ONLY,
WarehouseManager.DEFAULT_WAREHOUSE_ID);

int numAvailableComputeNodes = 0;
Expand Down Expand Up @@ -176,7 +175,7 @@ public ImmutableMap<Long, ComputeNode> getIdComputeNode() {
for (Integer numUsedComputeNodes : numUsedComputeNodesList) {
workerProvider =
workerProviderFactory.captureAvailableWorkers(GlobalStateMgr.getCurrentState().getNodeMgr().getClusterInfo(),
true, numUsedComputeNodes, ComputationFragmentSchedulingPolicy.COMPUTE_NODES_ONLY,
true, numUsedComputeNodes, ComputationFragmentSchedulingPolicy.COMPUTE_NODES_ONLY,
WarehouseManager.DEFAULT_WAREHOUSE_ID);
List<Long> selectedWorkerIdsList = workerProvider.getAllAvailableNodes();
for (Long selectedWorkerId : selectedWorkerIdsList) {
Expand All @@ -188,7 +187,7 @@ public ImmutableMap<Long, ComputeNode> getIdComputeNode() {
for (Integer numUsedComputeNodes : numUsedComputeNodesList) {
workerProvider =
workerProviderFactory.captureAvailableWorkers(GlobalStateMgr.getCurrentState().getNodeMgr().getClusterInfo(),
false, numUsedComputeNodes, ComputationFragmentSchedulingPolicy.COMPUTE_NODES_ONLY,
false, numUsedComputeNodes, ComputationFragmentSchedulingPolicy.COMPUTE_NODES_ONLY,
WarehouseManager.DEFAULT_WAREHOUSE_ID);
List<Long> selectedWorkerIdsList = workerProvider.getAllAvailableNodes();
Assert.assertEquals(availableId2Backend.size(), selectedWorkerIdsList.size());
Expand All @@ -201,7 +200,7 @@ public ImmutableMap<Long, ComputeNode> getIdComputeNode() {
for (Integer numUsedComputeNodes : numUsedComputeNodesList) {
workerProvider =
workerProviderFactory.captureAvailableWorkers(GlobalStateMgr.getCurrentState().getNodeMgr().getClusterInfo(),
true, numUsedComputeNodes, ComputationFragmentSchedulingPolicy.ALL_NODES,
true, numUsedComputeNodes, ComputationFragmentSchedulingPolicy.ALL_NODES,
WarehouseManager.DEFAULT_WAREHOUSE_ID);
List<Long> selectedWorkerIdsList = workerProvider.getAllAvailableNodes();
Collections.reverse(selectedWorkerIdsList); //put ComputeNode id to the front,Backend id to the back
Expand Down Expand Up @@ -377,6 +376,21 @@ public void testReportBackendNotFoundException() {
Assert.assertThrows(SchedulerException.class, workerProvider::reportDataNodeNotFoundException);
}

/**
 * Checks that selectNextWorker() keeps returning a non-negative worker id both before and after the
 * compute-node counter is forced to Integer.MAX_VALUE.
 */
@Test
public void testNextWorkerOverflow() throws NonRecoverableException {
    DefaultWorkerProvider workerProvider = new DefaultWorkerProvider(
            id2Backend, id2ComputeNode, availableId2Backend, availableId2ComputeNode, true);

    // Phase 1: selections with the counter in whatever state it currently has.
    int remaining = 100;
    while (remaining-- > 0) {
        Long selectedId = workerProvider.selectNextWorker();
        assertThat(selectedId).isNotNegative();
    }

    // Phase 2: put the counter at the largest int value, then select again.
    AtomicInteger indexer = DefaultWorkerProvider.getNextComputeNodeIndexer();
    indexer.set(Integer.MAX_VALUE);
    remaining = 100;
    while (remaining-- > 0) {
        Long selectedId = workerProvider.selectNextWorker();
        assertThat(selectedId).isNotNegative();
    }
}

public static void testUsingWorkerHelper(DefaultWorkerProvider workerProvider, Long workerId) {
Assert.assertTrue(workerProvider.isWorkerSelected(workerId));
assertThat(workerProvider.getSelectedWorkerIds()).contains(workerId);
Expand Down

0 comments on commit 4e70ef5

Please sign in to comment.