diff --git a/clients/venice-samza/src/main/java/com/linkedin/venice/pushmonitor/RouterBasedHybridStoreQuotaMonitor.java b/clients/venice-samza/src/main/java/com/linkedin/venice/pushmonitor/RouterBasedHybridStoreQuotaMonitor.java index 1d0d7f0ec0..2dfebadf33 100644 --- a/clients/venice-samza/src/main/java/com/linkedin/venice/pushmonitor/RouterBasedHybridStoreQuotaMonitor.java +++ b/clients/venice-samza/src/main/java/com/linkedin/venice/pushmonitor/RouterBasedHybridStoreQuotaMonitor.java @@ -6,6 +6,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.linkedin.venice.client.store.transport.TransportClient; import com.linkedin.venice.client.store.transport.TransportClientResponse; +import com.linkedin.venice.exceptions.ErrorType; import com.linkedin.venice.exceptions.VeniceException; import com.linkedin.venice.meta.Version; import com.linkedin.venice.routerapi.HybridStoreQuotaStatusResponse; @@ -14,10 +15,13 @@ import com.linkedin.venice.utils.ObjectMapperFactory; import com.linkedin.venice.utils.Utils; import java.io.Closeable; +import java.io.IOException; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -41,7 +45,8 @@ public RouterBasedHybridStoreQuotaMonitor( TransportClient transportClient, String storeName, Version.PushType pushType, - String topicName) { + String topicName, + TransportClientReinitProvider reinitProvider) { final String requestPath; if (Version.PushType.STREAM.equals(pushType)) { requestPath = buildStreamHybridStoreQuotaRequestPath(storeName); @@ -54,7 +59,7 @@ public RouterBasedHybridStoreQuotaMonitor( + " can monitor hybrid store quota."); } executor = 
Executors.newSingleThreadExecutor(new DaemonThreadFactory("RouterBasedHybridQuotaMonitor")); - hybridQuotaMonitorTask = new HybridQuotaMonitorTask(transportClient, storeName, requestPath, this); + hybridQuotaMonitorTask = new HybridQuotaMonitorTask(transportClient, storeName, requestPath, this, reinitProvider); } public void start() { @@ -66,6 +71,10 @@ public void close() { hybridQuotaMonitorTask.close(); } + protected HybridQuotaMonitorTask getHybridQuotaMonitorTask() { + return hybridQuotaMonitorTask; + } + public void setCurrentStatus(HybridStoreQuotaStatus currentStatus) { this.currentStatus = currentStatus; } @@ -82,12 +91,14 @@ private static String buildStreamReprocessingHybridStoreQuotaRequestPath(String return TYPE_STREAM_REPROCESSING_HYBRID_STORE_QUOTA + "/" + versionTopic; } - private static class HybridQuotaMonitorTask implements Runnable, Closeable { - private static ObjectMapper mapper = ObjectMapperFactory.getInstance(); + protected static class HybridQuotaMonitorTask implements Runnable, Closeable { + private ObjectMapper mapper = ObjectMapperFactory.getInstance(); private final AtomicBoolean isRunning; private final String storeName; - private final TransportClient transportClient; + private TransportClient transportClient; + + private TransportClientReinitProvider reinitProvider; private final String requestPath; private final RouterBasedHybridStoreQuotaMonitor hybridStoreQuotaMonitorService; @@ -95,12 +106,44 @@ public HybridQuotaMonitorTask( TransportClient transportClient, String storeName, String requestPath, - RouterBasedHybridStoreQuotaMonitor hybridStoreQuotaMonitorService) { + RouterBasedHybridStoreQuotaMonitor hybridStoreQuotaMonitorService, + TransportClientReinitProvider reinitProvider) { this.transportClient = transportClient; this.storeName = storeName; this.requestPath = requestPath; this.hybridStoreQuotaMonitorService = hybridStoreQuotaMonitorService; this.isRunning = new AtomicBoolean(true); + this.reinitProvider = reinitProvider; 
+ } + + protected void setMapper(ObjectMapper mapper) { + this.mapper = mapper; + } + + /** Performs one poll: fetches the hybrid store quota status over the transport client, records it on the owning monitor, and replaces the transport client via the reinit provider when the response reports STORE_NOT_FOUND. */ protected void checkStatus() throws ExecutionException, InterruptedException, TimeoutException, IOException { + // Get hybrid store quota status + CompletableFuture responseFuture = transportClient.get(requestPath); + TransportClientResponse response = responseFuture.get(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS); + HybridStoreQuotaStatusResponse quotaStatusResponse = + mapper.readValue(response.getBody(), HybridStoreQuotaStatusResponse.class); + if (quotaStatusResponse.isError()) { + if (ErrorType.STORE_NOT_FOUND.equals(quotaStatusResponse.getErrorType())) { /* constant-first equals: no NPE should an error response carry no error type */ + LOGGER.warn("Store not found, reinitializing client! Error: {}", quotaStatusResponse.getError()); + // TODO: It'd make sense to call shutdown on the transport client, but it's a shared object so that's + // a bit dangerous. + transportClient = reinitProvider.apply(); + } + LOGGER.error("Router was not able to get hybrid quota status: {}", quotaStatusResponse.getError()); + return; + } + hybridStoreQuotaMonitorService.setCurrentStatus(quotaStatusResponse.getQuotaStatus()); + switch (quotaStatusResponse.getQuotaStatus()) { + case QUOTA_VIOLATED: + LOGGER.info("Hybrid job failed with quota violation for store: {}", storeName); + break; + default: + LOGGER.info("Current hybrid job state: {} for store: {}", quotaStatusResponse.getQuotaStatus(), storeName); + } } @Override @@ -108,25 +151,7 @@ public void run() { LOGGER.info("Running {}", this.getClass().getSimpleName()); while (isRunning.get()) { try { - // Get hybrid store quota status - CompletableFuture responseFuture = transportClient.get(requestPath); - TransportClientResponse response = responseFuture.get(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS); - HybridStoreQuotaStatusResponse quotaStatusResponse = - mapper.readValue(response.getBody(), HybridStoreQuotaStatusResponse.class); - if (quotaStatusResponse.isError()) { - LOGGER.error("Router was not able to get hybrid quota 
status: {}", quotaStatusResponse.getError()); - continue; - } - hybridStoreQuotaMonitorService.setCurrentStatus(quotaStatusResponse.getQuotaStatus()); - switch (quotaStatusResponse.getQuotaStatus()) { - case QUOTA_VIOLATED: - LOGGER.info("Hybrid job failed with quota violation for store: {}", storeName); - break; - default: - LOGGER - .info("Current hybrid job state: {} for store: {}", quotaStatusResponse.getQuotaStatus(), storeName); - } - + checkStatus(); Utils.sleep(POLL_CYCLE_DELAY_MS); } catch (Exception e) { if (isRunning.get() && !ExceptionUtils.recursiveClassEquals(e, InterruptedException.class)) { @@ -144,4 +169,9 @@ public void close() { isRunning.set(false); } } + + /** Supplies a replacement {@link TransportClient}; invoked by the monitor task when the current client must be re-created. */ @FunctionalInterface + public interface TransportClientReinitProvider { + TransportClient apply(); + } } diff --git a/clients/venice-samza/src/main/java/com/linkedin/venice/samza/VeniceSystemProducer.java b/clients/venice-samza/src/main/java/com/linkedin/venice/samza/VeniceSystemProducer.java index 7dc177f35b..64fe020659 100644 --- a/clients/venice-samza/src/main/java/com/linkedin/venice/samza/VeniceSystemProducer.java +++ b/clients/venice-samza/src/main/java/com/linkedin/venice/samza/VeniceSystemProducer.java @@ -422,6 +422,7 @@ public synchronized void start() { this.isStarted = true; final TransportClient transportClient; + RouterBasedHybridStoreQuotaMonitor.TransportClientReinitProvider reinitProvider; if (discoveryUrl.isPresent()) { this.controllerClient = ControllerClientFactory.discoverAndConstructControllerClient(storeName, discoveryUrl.get(), sslFactory, 1); @@ -448,10 +449,11 @@ public synchronized void start() { } if (sslFactory.isPresent()) { - transportClient = new HttpsTransportClient(discoveryUrl.get(), 0, 0, false, sslFactory.get()); + reinitProvider = () -> new HttpsTransportClient(discoveryUrl.get(), 0, 0, false, sslFactory.get()); } else { - transportClient = new HttpTransportClient(discoveryUrl.get(), 0, 0); + reinitProvider = () -> new 
HttpTransportClient(discoveryUrl.get(), 0, 0); } + transportClient = reinitProvider.apply(); } else { this.primaryControllerColoD2Client = getStartedD2Client(primaryControllerColoD2ZKHost); this.childColoD2Client = getStartedD2Client(veniceChildD2ZkHost); @@ -461,6 +463,7 @@ public synchronized void start() { () -> D2ControllerClient .discoverCluster(primaryControllerColoD2Client, primaryControllerD2ServiceName, this.storeName), 10); + String clusterName = discoveryResponse.getCluster(); LOGGER.info("Found cluster: {} for store: {}", clusterName, storeName); @@ -496,6 +499,15 @@ public synchronized void start() { primaryControllerColoD2Client, sslFactory); transportClient = new D2TransportClient(discoveryResponse.getD2Service(), childColoD2Client); + + reinitProvider = () -> { + D2ServiceDiscoveryResponse d2DiscoveryResponse = (D2ServiceDiscoveryResponse) controllerRequestWithRetry( + () -> D2ControllerClient + .discoverCluster(primaryControllerColoD2Client, primaryControllerD2ServiceName, this.storeName), + 10); + LOGGER.info("Found cluster: {} for store: {}", d2DiscoveryResponse.getCluster(), storeName); /* log the cluster from this re-discovery, not the clusterName captured when start() first ran */ + return new D2TransportClient(d2DiscoveryResponse.getD2Service(), childColoD2Client); + }; } // Request all the necessary info from Venice Controller @@ -568,8 +580,8 @@ public synchronized void start() { if ((pushType.equals(Version.PushType.STREAM) || pushType.equals(Version.PushType.STREAM_REPROCESSING)) && hybridStoreDiskQuotaEnabled) { - hybridStoreQuotaMonitor = - Optional.of(new RouterBasedHybridStoreQuotaMonitor(transportClient, storeName, pushType, topicName)); + hybridStoreQuotaMonitor = Optional + .of(new RouterBasedHybridStoreQuotaMonitor(transportClient, storeName, pushType, topicName, reinitProvider)); hybridStoreQuotaMonitor.get().start(); } } diff --git a/clients/venice-samza/src/test/java/com/linkedin/venice/pushmonitor/RouterBasedHybridStoreQuotaMonitorTest.java 
b/clients/venice-samza/src/test/java/com/linkedin/venice/pushmonitor/RouterBasedHybridStoreQuotaMonitorTest.java new file mode 100644 index 0000000000..b0e0eef746 --- /dev/null +++ b/clients/venice-samza/src/test/java/com/linkedin/venice/pushmonitor/RouterBasedHybridStoreQuotaMonitorTest.java @@ -0,0 +1,91 @@ +package com.linkedin.venice.pushmonitor; + +import static org.testng.Assert.*; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.linkedin.venice.client.store.transport.TransportClient; +import com.linkedin.venice.client.store.transport.TransportClientResponse; +import com.linkedin.venice.exceptions.ErrorType; +import com.linkedin.venice.meta.Version; +import com.linkedin.venice.routerapi.HybridStoreQuotaStatusResponse; +import java.io.IOException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import org.mockito.Mockito; +import org.testng.Assert; +import org.testng.annotations.Test; + + +public class RouterBasedHybridStoreQuotaMonitorTest { + private static final String STORE_NAME = "fake_Store"; + private static final String TOPIC_NAME = "fake_Store_v1"; + + @Test + public void testTransportClientReinit() + throws IOException, ExecutionException, InterruptedException, TimeoutException { + TransportClient mockTransportclient = Mockito.mock(TransportClient.class); + TransportClientResponse mockResponse = Mockito.mock(TransportClientResponse.class); + ObjectMapper mockMapper = Mockito.mock(ObjectMapper.class); + HybridStoreQuotaStatusResponse mockQuotaStatusResponse = Mockito.mock(HybridStoreQuotaStatusResponse.class); + Mockito.when(mockResponse.getBody()).thenReturn(STORE_NAME.getBytes()); + Mockito.when(mockTransportclient.get(Mockito.anyString())) + .thenReturn(CompletableFuture.completedFuture(mockResponse)); + Mockito + .when(mockMapper.readValue(Mockito.eq(STORE_NAME.getBytes()), Mockito.eq(HybridStoreQuotaStatusResponse.class))) + 
.thenReturn(mockQuotaStatusResponse); + Mockito.when(mockQuotaStatusResponse.isError()).thenReturn(true); + Mockito.when(mockQuotaStatusResponse.getErrorType()).thenReturn(ErrorType.STORE_NOT_FOUND); + + final boolean[] isReinitCalled = { false }; + RouterBasedHybridStoreQuotaMonitor.TransportClientReinitProvider transportClientReinitProvider = () -> { + isReinitCalled[0] = true; + return mockTransportclient; + }; + RouterBasedHybridStoreQuotaMonitor routerBasedHybridStoreQuotaMonitor = new RouterBasedHybridStoreQuotaMonitor( + mockTransportclient, + STORE_NAME, + Version.PushType.STREAM, + TOPIC_NAME, + transportClientReinitProvider); + + routerBasedHybridStoreQuotaMonitor.getHybridQuotaMonitorTask().setMapper(mockMapper); + routerBasedHybridStoreQuotaMonitor.getHybridQuotaMonitorTask().checkStatus(); + + Assert.assertTrue(isReinitCalled[0]); + } + + @Test + public void testStatusChange() throws IOException, ExecutionException, InterruptedException, TimeoutException { + TransportClient mockTransportclient = Mockito.mock(TransportClient.class); + TransportClientResponse mockResponse = Mockito.mock(TransportClientResponse.class); + ObjectMapper mockMapper = Mockito.mock(ObjectMapper.class); + HybridStoreQuotaStatusResponse mockQuotaStatusResponse = Mockito.mock(HybridStoreQuotaStatusResponse.class); + Mockito.when(mockResponse.getBody()).thenReturn(STORE_NAME.getBytes()); + Mockito.when(mockTransportclient.get(Mockito.anyString())) + .thenReturn(CompletableFuture.completedFuture(mockResponse)); + Mockito + .when(mockMapper.readValue(Mockito.eq(STORE_NAME.getBytes()), Mockito.eq(HybridStoreQuotaStatusResponse.class))) + .thenReturn(mockQuotaStatusResponse); + Mockito.when(mockQuotaStatusResponse.isError()).thenReturn(false); + Mockito.when(mockQuotaStatusResponse.getQuotaStatus()).thenReturn(HybridStoreQuotaStatus.QUOTA_VIOLATED); + + final boolean[] isReinitCalled = { false }; + RouterBasedHybridStoreQuotaMonitor.TransportClientReinitProvider 
transportClientReinitProvider = () -> { + isReinitCalled[0] = true; + return mockTransportclient; + }; + RouterBasedHybridStoreQuotaMonitor routerBasedHybridStoreQuotaMonitor = new RouterBasedHybridStoreQuotaMonitor( + mockTransportclient, + STORE_NAME, + Version.PushType.STREAM, + TOPIC_NAME, + transportClientReinitProvider); + + routerBasedHybridStoreQuotaMonitor.getHybridQuotaMonitorTask().setMapper(mockMapper); + routerBasedHybridStoreQuotaMonitor.getHybridQuotaMonitorTask().checkStatus(); + + Assert.assertFalse(isReinitCalled[0]); + Assert.assertEquals(routerBasedHybridStoreQuotaMonitor.getCurrentStatus(), HybridStoreQuotaStatus.QUOTA_VIOLATED); + } +}