Skip to content

Commit

Permalink
[Extensions] Retrieve transport service from SDKTransportService rath…
Browse files Browse the repository at this point in the history
…er than ExtensionsRunner (#923)

* Retrieving transport service from sdkTransportService rather than from extensionrunner

Signed-off-by: Joshua Palis <[email protected]>

* Fixing affected tests

Signed-off-by: Joshua Palis <[email protected]>

---------

Signed-off-by: Joshua Palis <[email protected]>
  • Loading branch information
joshpalis authored Jun 7, 2023
1 parent 45127f7 commit d01f733
Show file tree
Hide file tree
Showing 12 changed files with 28 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ private void registerJobDetailsIfNecessary(RestRequest request) throws IOExcepti
requestBody.field(GetJobDetailsRequest.JOB_TYPE, AnomalyDetectorExtension.AD_JOB_TYPE);
requestBody.field(GetJobDetailsRequest.JOB_PARAMETER_ACTION, ADJobParameterAction.class.getName());
requestBody.field(GetJobDetailsRequest.JOB_RUNNER_ACTION, ADJobRunnerAction.class.getName());
requestBody.field(GetJobDetailsRequest.EXTENSION_UNIQUE_ID, extensionsRunner.getUniqueId());
requestBody.field(GetJobDetailsRequest.EXTENSION_UNIQUE_ID, extensionsRunner.getSdkTransportService().getUniqueId());
requestBody.endObject();

Request registerJobDetailsRequest = new Request(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public RestIndexAnomalyDetectorAction(ExtensionsRunner extensionsRunner, SDKRest
super(extensionsRunner);
this.namedXContentRegistry = extensionsRunner.getNamedXContentRegistry();
this.environmentSettings = extensionsRunner.getEnvironmentSettings();
this.transportService = extensionsRunner.getExtensionTransportService();
this.transportService = extensionsRunner.getSdkTransportService().getTransportService();
this.sdkRestClient = sdkRestClient;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public ADBatchAnomalyResultTransportAction(
ADBatchTaskRunner adBatchTaskRunner
) {
super(ADBatchAnomalyResultAction.NAME, actionFilters, taskManager);
this.transportService = extensionsRunner.getExtensionTransportService();
this.transportService = extensionsRunner.getSdkTransportService().getTransportService();
this.adBatchTaskRunner = adBatchTaskRunner;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public AnomalyDetectorJobTransportAction(
ADTaskManager adTaskManager
) {
super(AnomalyDetectorJobAction.NAME, actionFilters, taskManager);
this.transportService = extensionsRunner.getExtensionTransportService();
this.transportService = extensionsRunner.getSdkTransportService().getTransportService();
this.client = client;
this.clusterService = clusterService;
this.settings = extensionsRunner.getEnvironmentSettings();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ public AnomalyResultTransportAction(
) {
super(AnomalyResultAction.NAME, actionFilters, taskManager);
this.extensionsRunner = extensionsRunner;
this.transportService = extensionsRunner.getExtensionTransportService();
this.transportService = extensionsRunner.getSdkTransportService().getTransportService();
this.settings = extensionsRunner.getEnvironmentSettings();
this.sdkRestClient = sdkRestClient;
this.stateManager = manager;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public DeleteAnomalyDetectorTransportAction(
ADTaskManager adTaskManager
) {
super(DeleteAnomalyDetectorAction.NAME, actionFilters, taskManager);
this.transportService = extensionsRunner.getExtensionTransportService();
this.transportService = extensionsRunner.getSdkTransportService().getTransportService();
this.client = client;
this.clusterService = clusterService;
this.xContentRegistry = xContentRegistry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public ForwardADTaskTransportAction(
) {
super(ForwardADTaskAction.NAME, actionFilters, taskManager);
this.adTaskManager = adTaskManager;
this.transportService = extensionsRunner.getExtensionTransportService();
this.transportService = extensionsRunner.getSdkTransportService().getTransportService();
this.adTaskCacheManager = adTaskCacheManager;
this.featureManager = featureManager;
this.stateManager = stateManager;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public GetAnomalyDetectorTransportAction(
this.nodeFilter = nodeFilter;
filterByEnabled = AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it);
this.transportService = extensionsRunner.getExtensionTransportService();
this.transportService = extensionsRunner.getSdkTransportService().getTransportService();
this.adTaskManager = adTaskManager;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public IndexAnomalyDetectorTransportAction(
) {
super(IndexAnomalyDetectorAction.NAME, actionFilters, taskManager);
this.client = restClient;
this.transportService = extensionsRunner.getExtensionTransportService();
this.transportService = extensionsRunner.getSdkTransportService().getTransportService();
this.clusterService = sdkClusterService;
this.anomalyDetectionIndices = anomalyDetectionIndices;
this.xContentRegistry = namedXContentRegistry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.opensearch.ad.task.ADTaskCacheManager;
import org.opensearch.ad.task.ADTaskManager;
import org.opensearch.sdk.ExtensionsRunner;
import org.opensearch.sdk.SDKTransportService;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskManager;
import org.opensearch.transport.TransportService;
Expand All @@ -50,6 +51,7 @@ public class ForwardADTaskTransportActionTests extends ADUnitTestCase {
private ExtensionsRunner extensionsRunner;
private TaskManager taskManager;
private ActionFilters actionFilters;
private SDKTransportService sdkTransportService;
private TransportService transportService;
private ADTaskManager adTaskManager;
private ADTaskCacheManager adTaskCacheManager;
Expand All @@ -66,12 +68,14 @@ public void setUp() throws Exception {
extensionsRunner = mock(ExtensionsRunner.class);
taskManager = mock(TaskManager.class);
actionFilters = mock(ActionFilters.class);
sdkTransportService = mock(SDKTransportService.class);
transportService = mock(TransportService.class);
adTaskManager = mock(ADTaskManager.class);
adTaskCacheManager = mock(ADTaskCacheManager.class);
featureManager = mock(FeatureManager.class);
stateManager = mock(NodeStateManager.class);
when(extensionsRunner.getExtensionTransportService()).thenReturn(transportService);
when(extensionsRunner.getSdkTransportService()).thenReturn(sdkTransportService);
when(sdkTransportService.getTransportService()).thenReturn(transportService);
forwardADTaskTransportAction = new ForwardADTaskTransportAction(
extensionsRunner,
taskManager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@
import org.opensearch.sdk.SDKClient.SDKRestClient;
import org.opensearch.sdk.SDKClusterService;
import org.opensearch.sdk.SDKNamedXContentRegistry;
import org.opensearch.sdk.SDKTransportService;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskManager;
import org.opensearch.test.OpenSearchSingleNodeTestCase;
import org.opensearch.transport.TransportService;

import com.google.common.collect.ImmutableMap;

Expand Down Expand Up @@ -94,6 +96,11 @@ public void setUp() throws Exception {
when(mockRunner.getNamedXContentRegistry()).thenReturn(sdkNamedXContentRegistry);
when(sdkNamedXContentRegistry.getRegistry()).thenReturn(xContentRegistry());

SDKTransportService mockSdkTransportService = mock(SDKTransportService.class);
TransportService mockTransportService = mock(TransportService.class);
when(mockRunner.getSdkTransportService()).thenReturn(mockSdkTransportService);
when(mockSdkTransportService.getTransportService()).thenReturn(mockTransportService);

action = new GetAnomalyDetectorTransportAction(
mockRunner,
mock(TaskManager.class),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,13 @@
import org.opensearch.sdk.SDKClusterService;
import org.opensearch.sdk.SDKClusterService.SDKClusterSettings;
import org.opensearch.sdk.SDKNamedXContentRegistry;
import org.opensearch.sdk.SDKTransportService;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskManager;
import org.opensearch.test.OpenSearchIntegTestCase;
import org.opensearch.transport.TransportService;

import com.google.common.collect.ImmutableMap;

Expand Down Expand Up @@ -122,6 +124,11 @@ public void setUp() throws Exception {
this.mockSdkXContentRegistry = mock(SDKNamedXContentRegistry.class);
when(mockSdkXContentRegistry.getRegistry()).thenReturn(xContentRegistry());

SDKTransportService mockSdkTransportService = mock(SDKTransportService.class);
TransportService mockTransportService = mock(TransportService.class);
when(mockRunner.getSdkTransportService()).thenReturn(mockSdkTransportService);
when(mockSdkTransportService.getTransportService()).thenReturn(mockTransportService);

action = new IndexAnomalyDetectorTransportAction(
mockRunner,
mock(TaskManager.class),
Expand Down

0 comments on commit d01f733

Please sign in to comment.