# Patch summary (text before the first "diff --git" header is ignored by git apply / patch):
#   * build.gradle: adds io.grpc:grpc-testing as a test dependency.
#   * ChannelManager.healthCheck: the caller-supplied timeout is now used as the call deadline;
#     the configured health-check attempt timeout is only the fallback when timeout is null.
#     (The removed code ignored a non-null timeout and set no deadline at all for a null one.)
#   * ChannelManager.getServerCapabilities: delegates to the new
#     SystemInfoInterceptor.getServerCapabilitiesWithRetryOrThrow, which retries GetSystemInfo
#     through GrpcRetryer and caches the result in the shared CompletableFuture.
#   * ChannelManagerTest: new in-process gRPC test covering success, retry, UNAVAILABLE and
#     UNIMPLEMENTED paths, with and without a prior connect().
#
# NOTE(review): getServerCapabilitiesWithRetryOrThrow calls future.getNow(null) before
#   Objects.requireNonNull(future), so the explicit null check can never be the one that fires.
# NOTE(review): the monitor on `future` is held for the whole retried GetSystemInfo call, so
#   concurrent callers block for up to the full deadline - confirm this is intended.
# NOTE(review): ChannelManagerTest uses the wildcard import org.junit.*.
#
diff --git a/temporal-serviceclient/build.gradle b/temporal-serviceclient/build.gradle
index f476f4890..8b51ba6ba 100644
--- a/temporal-serviceclient/build.gradle
+++ b/temporal-serviceclient/build.gradle
@@ -24,6 +24,7 @@ dependencies {
     api "org.slf4j:slf4j-api:$slf4jVersion"
 
     testImplementation project(':temporal-testing')
+    testImplementation "io.grpc:grpc-testing:${grpcVersion}"
     testImplementation "junit:junit:${junitVersion}"
     testImplementation "org.mockito:mockito-core:${mockitoVersion}"
 
diff --git a/temporal-serviceclient/src/main/java/io/temporal/serviceclient/ChannelManager.java b/temporal-serviceclient/src/main/java/io/temporal/serviceclient/ChannelManager.java
index 4cb88c0f9..38fa4e170 100644
--- a/temporal-serviceclient/src/main/java/io/temporal/serviceclient/ChannelManager.java
+++ b/temporal-serviceclient/src/main/java/io/temporal/serviceclient/ChannelManager.java
@@ -27,8 +27,9 @@
 import io.grpc.health.v1.HealthGrpc;
 import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
 import io.grpc.stub.MetadataUtils;
-import io.temporal.api.workflowservice.v1.GetSystemInfoResponse;
+import io.temporal.api.workflowservice.v1.GetSystemInfoResponse.Capabilities;
 import io.temporal.internal.retryer.GrpcRetryer;
+import io.temporal.internal.retryer.GrpcRetryer.GrpcRetryerOptions;
 import java.time.Duration;
 import java.util.Collection;
 import java.util.List;
@@ -87,7 +88,7 @@ final class ChannelManager {
   private final Channel interceptedChannel;
   private final HealthGrpc.HealthBlockingStub healthBlockingStub;
 
-  private final CompletableFuture<GetSystemInfoResponse.Capabilities> serverCapabilitiesFuture =
+  private final CompletableFuture<Capabilities> serverCapabilitiesFuture =
       new CompletableFuture<>();
 
   public ChannelManager(
@@ -289,8 +290,8 @@ public void connect(String healthCheckServiceName, @Nullable Duration timeout) {
     if (timeout == null) {
       timeout = options.getRpcTimeout();
     }
-    GrpcRetryer.GrpcRetryerOptions grpcRetryerOptions =
-        new GrpcRetryer.GrpcRetryerOptions(
+    GrpcRetryerOptions grpcRetryerOptions =
+        new GrpcRetryerOptions(
             RpcRetryOptions.newBuilder().setExpiration(timeout).validateBuildWithDefaults(), null);
 
     new GrpcRetryer(getServerCapabilities())
@@ -310,30 +311,24 @@ public void connect(String healthCheckServiceName, @Nullable Duration timeout) {
    */
   public HealthCheckResponse healthCheck(
       String healthCheckServiceName, @Nullable Duration timeout) {
-    HealthGrpc.HealthBlockingStub stub;
-    if (timeout != null) {
-      stub =
-          this.healthBlockingStub.withDeadline(
-              Deadline.after(
-                  options.getHealthCheckAttemptTimeout().toMillis(), TimeUnit.MILLISECONDS));
-    } else {
-      stub = this.healthBlockingStub;
+    if (timeout == null) {
+      timeout = options.getHealthCheckAttemptTimeout();
     }
-    return stub.check(HealthCheckRequest.newBuilder().setService(healthCheckServiceName).build());
+    return this.healthBlockingStub
+        .withDeadline(deadlineFrom(timeout))
+        .check(HealthCheckRequest.newBuilder().setService(healthCheckServiceName).build());
   }
 
-  public Supplier<GetSystemInfoResponse.Capabilities> getServerCapabilities() {
-    return () -> {
-      synchronized (serverCapabilitiesFuture) {
-        GetSystemInfoResponse.Capabilities capabilities = serverCapabilitiesFuture.getNow(null);
-        if (capabilities == null) {
-          serverCapabilitiesFuture.complete(
-              SystemInfoInterceptor.getServerCapabilitiesOrThrow(interceptedChannel, null));
-          capabilities = serverCapabilitiesFuture.getNow(null);
-        }
-        return capabilities;
-      }
-    };
+  public Supplier<Capabilities> getServerCapabilities() {
+    return () ->
+        SystemInfoInterceptor.getServerCapabilitiesWithRetryOrThrow(
+            serverCapabilitiesFuture,
+            interceptedChannel,
+            deadlineFrom(options.getHealthCheckAttemptTimeout()));
+  }
+
+  private static Deadline deadlineFrom(Duration duration) {
+    return Deadline.after(duration.toMillis(), TimeUnit.MILLISECONDS);
   }
 
   public void shutdown() {
diff --git a/temporal-serviceclient/src/main/java/io/temporal/serviceclient/SystemInfoInterceptor.java b/temporal-serviceclient/src/main/java/io/temporal/serviceclient/SystemInfoInterceptor.java
index e392219ff..3f65cf2a3 100644
--- a/temporal-serviceclient/src/main/java/io/temporal/serviceclient/SystemInfoInterceptor.java
+++ b/temporal-serviceclient/src/main/java/io/temporal/serviceclient/SystemInfoInterceptor.java
@@ -23,16 +23,22 @@
 import io.grpc.*;
 import io.temporal.api.workflowservice.v1.GetSystemInfoRequest;
 import io.temporal.api.workflowservice.v1.GetSystemInfoResponse;
+import io.temporal.api.workflowservice.v1.GetSystemInfoResponse.Capabilities;
 import io.temporal.api.workflowservice.v1.WorkflowServiceGrpc;
+import io.temporal.internal.retryer.GrpcRetryer;
+import io.temporal.internal.retryer.GrpcRetryer.GrpcRetryerOptions;
+import java.time.Duration;
+import java.util.Objects;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
 public class SystemInfoInterceptor implements ClientInterceptor {
 
-  private final CompletableFuture<GetSystemInfoResponse.Capabilities> serverCapabilitiesFuture;
+  private final CompletableFuture<Capabilities> serverCapabilitiesFuture;
 
-  public SystemInfoInterceptor(
-      CompletableFuture<GetSystemInfoResponse.Capabilities> serverCapabilitiesFuture) {
+  public SystemInfoInterceptor(CompletableFuture<Capabilities> serverCapabilitiesFuture) {
     this.serverCapabilitiesFuture = serverCapabilitiesFuture;
   }
 
@@ -63,8 +69,7 @@ public void onMessage(RespT message) {
               @Override
               public void onClose(Status status, Metadata trailers) {
                 if (Status.UNIMPLEMENTED.getCode().equals(status.getCode())) {
-                  serverCapabilitiesFuture.complete(
-                      GetSystemInfoResponse.Capabilities.getDefaultInstance());
+                  serverCapabilitiesFuture.complete(Capabilities.getDefaultInstance());
                 }
                 super.onClose(status, trailers);
               }
@@ -87,7 +92,39 @@ public void onClose(Status status, Metadata trailers) {
     };
   }
 
-  public static GetSystemInfoResponse.Capabilities getServerCapabilitiesOrThrow(
+  public static Capabilities getServerCapabilitiesWithRetryOrThrow(
+      @Nonnull CompletableFuture<Capabilities> future,
+      @Nonnull Channel channel,
+      @Nullable Deadline deadline) {
+    Capabilities capabilities = future.getNow(null);
+    if (capabilities == null) {
+      synchronized (Objects.requireNonNull(future)) {
+        capabilities = future.getNow(null);
+        if (capabilities == null) {
+          if (deadline == null) {
+            deadline = Deadline.after(30, TimeUnit.SECONDS);
+          }
+          Deadline computedDeadline = deadline;
+          RpcRetryOptions rpcRetryOptions =
+              RpcRetryOptions.newBuilder()
+                  .setExpiration(
+                      Duration.ofMillis(computedDeadline.timeRemaining(TimeUnit.MILLISECONDS)))
+                  .validateBuildWithDefaults();
+          GrpcRetryerOptions grpcRetryerOptions =
+              new GrpcRetryerOptions(rpcRetryOptions, computedDeadline);
+          capabilities =
+              new GrpcRetryer(Capabilities::getDefaultInstance)
+                  .retryWithResult(
+                      () -> getServerCapabilitiesOrThrow(channel, computedDeadline),
+                      grpcRetryerOptions);
+          future.complete(capabilities);
+        }
+      }
+    }
+    return capabilities;
+  }
+
+  public static Capabilities getServerCapabilitiesOrThrow(
       Channel channel, @Nullable Deadline deadline) {
     try {
       return WorkflowServiceGrpc.newBlockingStub(channel)
@@ -96,7 +133,7 @@ public static GetSystemInfoResponse.Capabilities getServerCapabilitiesOrThrow(
           .getCapabilities();
     } catch (StatusRuntimeException ex) {
       if (Status.Code.UNIMPLEMENTED.equals(ex.getStatus().getCode())) {
-        return GetSystemInfoResponse.Capabilities.getDefaultInstance();
+        return Capabilities.getDefaultInstance();
       }
       throw ex;
     }
diff --git a/temporal-serviceclient/src/test/java/io/temporal/serviceclient/ChannelManagerTest.java b/temporal-serviceclient/src/test/java/io/temporal/serviceclient/ChannelManagerTest.java
new file mode 100644
index 000000000..362b1d9e9
--- /dev/null
+++ b/temporal-serviceclient/src/test/java/io/temporal/serviceclient/ChannelManagerTest.java
@@ -0,0 +1,232 @@
+/*
+ * Copyright (C) 2022 Temporal Technologies, Inc. All Rights Reserved.
+ *
+ * Copyright (C) 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Modifications copyright (C) 2017 Uber Technologies, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this material except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.temporal.serviceclient;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import io.grpc.ManagedChannel;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import io.grpc.health.v1.HealthCheckRequest;
+import io.grpc.health.v1.HealthCheckResponse;
+import io.grpc.health.v1.HealthCheckResponse.ServingStatus;
+import io.grpc.health.v1.HealthGrpc.HealthImplBase;
+import io.grpc.inprocess.InProcessChannelBuilder;
+import io.grpc.inprocess.InProcessServerBuilder;
+import io.grpc.stub.StreamObserver;
+import io.grpc.testing.GrpcCleanupRule;
+import io.temporal.api.workflowservice.v1.GetSystemInfoRequest;
+import io.temporal.api.workflowservice.v1.GetSystemInfoResponse;
+import io.temporal.api.workflowservice.v1.GetSystemInfoResponse.Capabilities;
+import io.temporal.api.workflowservice.v1.WorkflowServiceGrpc.WorkflowServiceImplBase;
+import java.time.Duration;
+import java.util.Collections;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.junit.*;
+
+public class ChannelManagerTest {
+
+  private static final String HEALTH_CHECK_NAME = "my-health-check";
+
+  private static final HealthCheckResponse HEALTH_CHECK_SERVING =
+      HealthCheckResponse.newBuilder().setStatus(ServingStatus.SERVING).build();
+
+  private static final Capabilities CAPABILITIES =
+      Capabilities.newBuilder().setInternalErrorDifferentiation(true).build();
+
+  private static final GetSystemInfoResponse GET_SYSTEM_INFO_RESPONSE =
+      GetSystemInfoResponse.newBuilder().setCapabilities(CAPABILITIES).build();
+
+  private static final RpcRetryOptions RPC_RETRY_OPTIONS =
+      RpcRetryOptions.newBuilder()
+          .setInitialInterval(Duration.ofMillis(10))
+          .setBackoffCoefficient(1.0)
+          .setMaximumAttempts(3)
+          .setExpiration(Duration.ofMillis(100))
+          .validateBuildWithDefaults();
+
+  @Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule();
+
+  private final AtomicInteger checkCount = new AtomicInteger(0);
+  private final AtomicInteger checkUnavailable = new AtomicInteger(0);
+  private final AtomicInteger getSystemInfoCount = new AtomicInteger(0);
+  private final AtomicInteger getSystemInfoUnavailable = new AtomicInteger(0);
+  private final AtomicInteger getSystemInfoUnimplemented = new AtomicInteger(0);
+
+  private final HealthImplBase healthImpl =
+      new HealthImplBase() {
+        @Override
+        public void check(
+            HealthCheckRequest request, StreamObserver<HealthCheckResponse> responseObserver) {
+          if (!HEALTH_CHECK_NAME.equals(request.getService())) {
+            responseObserver.onError(Status.fromCode(Status.Code.NOT_FOUND).asException());
+          } else if (checkUnavailable.getAndDecrement() > 0) {
+            responseObserver.onError(Status.fromCode(Status.Code.UNAVAILABLE).asException());
+          } else {
+            checkCount.getAndIncrement();
+            responseObserver.onNext(HEALTH_CHECK_SERVING);
+            responseObserver.onCompleted();
+          }
+        }
+      };
+  private final WorkflowServiceImplBase workflowImpl =
+      new WorkflowServiceImplBase() {
+        @Override
+        public void getSystemInfo(
+            GetSystemInfoRequest request, StreamObserver<GetSystemInfoResponse> responseObserver) {
+          if (getSystemInfoUnavailable.getAndDecrement() > 0) {
+            responseObserver.onError(Status.fromCode(Status.Code.UNAVAILABLE).asException());
+          } else if (getSystemInfoUnimplemented.getAndDecrement() > 0) {
+            responseObserver.onError(Status.fromCode(Status.Code.UNIMPLEMENTED).asException());
+          } else {
+            getSystemInfoCount.getAndIncrement();
+            responseObserver.onNext(GET_SYSTEM_INFO_RESPONSE);
+            responseObserver.onCompleted();
+          }
+        }
+      };
+
+  private ChannelManager channelManager;
+
+  @Before
+  public void setUp() throws Exception {
+    checkCount.set(0);
+    checkUnavailable.set(0);
+    getSystemInfoCount.set(0);
+    getSystemInfoUnavailable.set(0);
+    getSystemInfoUnimplemented.set(0);
+    String serverName = InProcessServerBuilder.generateName();
+    grpcCleanupRule.register(
+        InProcessServerBuilder.forName(serverName)
+            .directExecutor()
+            .addService(healthImpl)
+            .addService(workflowImpl)
+            .build()
+            .start());
+    ManagedChannel channel =
+        grpcCleanupRule.register(
+            InProcessChannelBuilder.forName(serverName).directExecutor().build());
+    WorkflowServiceStubsOptions serviceStubsOptions =
+        WorkflowServiceStubsOptions.newBuilder()
+            .setChannel(channel)
+            .setRpcRetryOptions(RPC_RETRY_OPTIONS)
+            .validateAndBuildWithDefaults();
+    channelManager = new ChannelManager(serviceStubsOptions, Collections.emptyList());
+  }
+
+  @After
+  public void tearDown() throws Exception {
+    if (channelManager != null) {
+      channelManager.shutdownNow();
+    }
+  }
+
+  @Test
+  public void testGetServerCapabilities() throws Exception {
+    Capabilities capabilities = channelManager.getServerCapabilities().get();
+    assertEquals(CAPABILITIES, capabilities);
+    assertEquals(1, getSystemInfoCount.get());
+    assertEquals(-1, getSystemInfoUnavailable.get());
+    assertEquals(-1, getSystemInfoUnimplemented.get());
+  }
+
+  @Test
+  public void testGetServerCapabilitiesRetry() throws Exception {
+    getSystemInfoUnavailable.set(2);
+    Capabilities capabilities = channelManager.getServerCapabilities().get();
+    assertEquals(CAPABILITIES, capabilities);
+    assertEquals(1, getSystemInfoCount.get());
+    assertEquals(-1, getSystemInfoUnavailable.get());
+    assertEquals(-1, getSystemInfoUnimplemented.get());
+  }
+
+  @Test
+  public void testGetServerCapabilitiesUnavailable() throws Exception {
+    getSystemInfoUnavailable.set(Integer.MAX_VALUE);
+    try {
+      Capabilities unused = channelManager.getServerCapabilities().get();
+      Assert.fail("expected StatusRuntimeException");
+    } catch (StatusRuntimeException e) {
+      assertEquals(Status.Code.UNAVAILABLE, e.getStatus().getCode());
+      assertEquals(0, getSystemInfoCount.get());
+      assertTrue(getSystemInfoUnavailable.get() >= 0);
+      assertEquals(0, getSystemInfoUnimplemented.get());
+    }
+  }
+
+  @Test
+  public void testGetServerCapabilitiesUnimplemented() throws Exception {
+    getSystemInfoUnimplemented.set(1);
+    Capabilities capabilities = channelManager.getServerCapabilities().get();
+    assertEquals(Capabilities.getDefaultInstance(), capabilities);
+    assertEquals(0, getSystemInfoCount.get());
+    assertEquals(-1, getSystemInfoUnavailable.get());
+    assertEquals(0, getSystemInfoUnimplemented.get());
+  }
+
+  @Test
+  public void testGetServerCapabilitiesWithConnect() throws Exception {
+    channelManager.connect(HEALTH_CHECK_NAME, Duration.ofMillis(100));
+    Capabilities capabilities = channelManager.getServerCapabilities().get();
+    assertEquals(CAPABILITIES, capabilities);
+    assertEquals(1, getSystemInfoCount.get());
+    assertEquals(-1, getSystemInfoUnavailable.get());
+    assertEquals(-1, getSystemInfoUnimplemented.get());
+  }
+
+  @Test
+  public void testGetServerCapabilitiesRetryWithConnect() throws Exception {
+    getSystemInfoUnavailable.set(2);
+    channelManager.connect(HEALTH_CHECK_NAME, Duration.ofMillis(100));
+    Capabilities capabilities = channelManager.getServerCapabilities().get();
+    assertEquals(CAPABILITIES, capabilities);
+    assertEquals(1, getSystemInfoCount.get());
+    assertEquals(-1, getSystemInfoUnavailable.get());
+    assertEquals(-1, getSystemInfoUnimplemented.get());
+  }
+
+  @Test
+  public void testGetServerCapabilitiesUnavailableWithConnect() throws Exception {
+    getSystemInfoUnavailable.set(Integer.MAX_VALUE);
+    try {
+      channelManager.connect(HEALTH_CHECK_NAME, Duration.ofMillis(100));
+      Capabilities unused = channelManager.getServerCapabilities().get();
+      Assert.fail("expected StatusRuntimeException");
+    } catch (StatusRuntimeException e) {
+      assertEquals(Status.Code.UNAVAILABLE, e.getStatus().getCode());
+      assertEquals(0, getSystemInfoCount.get());
+      assertTrue(getSystemInfoUnavailable.get() >= 0);
+      assertEquals(0, getSystemInfoUnimplemented.get());
+    }
+  }
+
+  @Test
+  public void testGetServerCapabilitiesUnimplementedWithConnect() throws Exception {
+    getSystemInfoUnimplemented.set(1);
+    channelManager.connect(HEALTH_CHECK_NAME, Duration.ofMillis(100));
+    Capabilities capabilities = channelManager.getServerCapabilities().get();
+    assertEquals(Capabilities.getDefaultInstance(), capabilities);
+    assertEquals(0, getSystemInfoCount.get());
+    assertEquals(-1, getSystemInfoUnavailable.get());
+    assertEquals(0, getSystemInfoUnimplemented.get());
+  }
+}