Skip to content

Commit

Permalink
Fix get task API does not refresh resource stats
Browse files Browse the repository at this point in the history
Signed-off-by: Gao Binlong <[email protected]>
  • Loading branch information
gaobinlong committed Dec 8, 2023
1 parent c1b3a73 commit 9f77336
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix remote shards balancer and remove unused variables ([#11167](https://github.com/opensearch-project/OpenSearch/pull/11167))
- Fix bug where replication lag grows post primary relocation ([#11238](https://github.com/opensearch-project/OpenSearch/pull/11238))
- Fix template setting override for replication type ([#11417](https://github.com/opensearch-project/OpenSearch/pull/11417))
- Fix get task API does not refresh resource stats

### Security

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskInfo;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.tasks.TaskResult;
import org.opensearch.tasks.TaskResultsService;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -84,21 +85,25 @@ public class TransportGetTaskAction extends HandledTransportAction<GetTaskReques
private final Client client;
private final NamedXContentRegistry xContentRegistry;

private final TaskResourceTrackingService taskResourceTrackingService;

@Inject
public TransportGetTaskAction(
ThreadPool threadPool,
TransportService transportService,
ActionFilters actionFilters,
ClusterService clusterService,
Client client,
NamedXContentRegistry xContentRegistry
NamedXContentRegistry xContentRegistry,
TaskResourceTrackingService taskResourceTrackingService
) {
super(GetTaskAction.NAME, transportService, actionFilters, GetTaskRequest::new);
this.threadPool = threadPool;
this.clusterService = clusterService;
this.transportService = transportService;
this.client = new OriginSettingClient(client, GetTaskAction.TASKS_ORIGIN);
this.xContentRegistry = xContentRegistry;
this.taskResourceTrackingService = taskResourceTrackingService;
}

@Override
Expand Down Expand Up @@ -173,6 +178,7 @@ public void onFailure(Exception e) {
}
});
} else {
taskResourceTrackingService.refreshResourceStats(runningTask);
TaskInfo info = runningTask.taskInfo(clusterService.localNode().getId(), true);
listener.onResponse(new GetTaskResponse(new TaskResult(false, info)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.apache.lucene.util.Constants;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.opensearch.action.admin.cluster.node.tasks.get.GetTaskRequest;
import org.opensearch.action.admin.cluster.node.tasks.get.GetTaskResponse;
import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksRequest;
import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
import org.opensearch.action.support.ActionTestUtils;
Expand Down Expand Up @@ -563,8 +565,57 @@ public void testOnDemandRefreshWhileFetchingTasks() throws InterruptedException

assertNotNull(taskInfo.getResourceStats());
assertNotNull(taskInfo.getResourceStats().getResourceUsageInfo());
assertTrue(taskInfo.getResourceStats().getResourceUsageInfo().get("total") instanceof TaskResourceUsage);
TaskResourceUsage taskResourceUsage = (TaskResourceUsage) taskInfo.getResourceStats().getResourceUsageInfo().get("total");
assertNotNull(taskInfo.getResourceStats().getResourceUsageInfo().get("total"));
TaskResourceUsage taskResourceUsage = taskInfo.getResourceStats().getResourceUsageInfo().get("total");
assertCPUTime(taskResourceUsage.getCpuTimeInNanos());
assertTrue(taskResourceUsage.getMemoryInBytes() > 0);
};

taskTestContext.operationFinishedValidator = (task, threadId) -> { assertEquals(0, resourceTasks.size()); };

startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
taskTestContext.requestCompleteLatch.countDown();
}

@Override
public void onFailure(Exception e) {
throwableReference.set(e);
taskTestContext.requestCompleteLatch.countDown();
}
});

// Waiting for whole request to complete and return successfully till client
taskTestContext.requestCompleteLatch.await();

assertTasksRequestFinishedSuccessfully(responseReference.get(), throwableReference.get());
}

public void testOnDemandRefreshWhileGetTask() throws InterruptedException {
setup(true, false);

final AtomicReference<Throwable> throwableReference = new AtomicReference<>();
final AtomicReference<NodesResponse> responseReference = new AtomicReference<>();

TaskTestContext taskTestContext = new TaskTestContext();

Map<Long, Task> resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks();

taskTestContext.operationStartValidator = (task, threadId) -> {
assertFalse(resourceTasks.isEmpty());
GetTaskResponse getTaskResponse = ActionTestUtils.executeBlocking(
testNodes[0].transportGetTaskAction,
new GetTaskRequest().setTaskId(new TaskId(testNodes[0].getNodeId(), new ArrayList<>(resourceTasks.values()).get(0).getId()))
);

TaskInfo taskInfo = getTaskResponse.getTask().getTask();

assertNotNull(taskInfo.getResourceStats());
assertNotNull(taskInfo.getResourceStats().getResourceUsageInfo());
assertNotNull(taskInfo.getResourceStats().getResourceUsageInfo().get("total"));
TaskResourceUsage taskResourceUsage = taskInfo.getResourceStats().getResourceUsageInfo().get("total");
assertCPUTime(taskResourceUsage.getCpuTimeInNanos());
assertTrue(taskResourceUsage.getMemoryInBytes() > 0);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@
import org.opensearch.Version;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.admin.cluster.node.tasks.cancel.TransportCancelTasksAction;
import org.opensearch.action.admin.cluster.node.tasks.get.TransportGetTaskAction;
import org.opensearch.action.admin.cluster.node.tasks.list.TransportListTasksAction;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.nodes.BaseNodeResponse;
import org.opensearch.action.support.nodes.BaseNodesRequest;
import org.opensearch.action.support.nodes.BaseNodesResponse;
import org.opensearch.action.support.nodes.TransportNodesAction;
import org.opensearch.action.support.replication.ClusterStateCreationUtils;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterModule;
import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.node.DiscoveryNode;
Expand All @@ -57,6 +59,7 @@
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.common.transport.BoundTransportAddress;
import org.opensearch.core.indices.breaker.NoneCircuitBreakerService;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.tasks.TaskCancellationService;
import org.opensearch.tasks.TaskManager;
import org.opensearch.tasks.TaskResourceTrackingService;
Expand Down Expand Up @@ -85,6 +88,7 @@
import static java.util.Collections.emptySet;
import static org.opensearch.test.ClusterServiceUtils.createClusterService;
import static org.opensearch.test.ClusterServiceUtils.setState;
import static org.mockito.Mockito.mock;

/**
* The test case for unit testing task manager and related transport actions
Expand Down Expand Up @@ -249,6 +253,17 @@ protected TaskManager createTaskManager(
taskResourceTrackingService
);
transportCancelTasksAction = new TransportCancelTasksAction(clusterService, transportService, actionFilters);
Client mockClient = mock(Client.class);
NamedXContentRegistry namedXContentRegistry = mock(NamedXContentRegistry.class);
transportGetTaskAction = new TransportGetTaskAction(
threadPool,
transportService,
actionFilters,
clusterService,
mockClient,
namedXContentRegistry,
taskResourceTrackingService
);
transportService.acceptIncomingRequests();
}

Expand All @@ -258,6 +273,7 @@ protected TaskManager createTaskManager(
private final SetOnce<DiscoveryNode> discoveryNode = new SetOnce<>();
public final TransportListTasksAction transportListTasksAction;
public final TransportCancelTasksAction transportCancelTasksAction;
public final TransportGetTaskAction transportGetTaskAction;

@Override
public void close() {
Expand Down

0 comments on commit 9f77336

Please sign in to comment.