diff --git a/balancer_wrapper.go b/balancer_wrapper.go
index 905817b5fc7b..c2688376ae74 100644
--- a/balancer_wrapper.go
+++ b/balancer_wrapper.go
@@ -34,7 +34,15 @@ import (
 	"google.golang.org/grpc/status"
 )
 
-var setConnectedAddress = internal.SetConnectedAddress.(func(*balancer.SubConnState, resolver.Address))
+var (
+	setConnectedAddress = internal.SetConnectedAddress.(func(*balancer.SubConnState, resolver.Address))
+	// noOpRegisterHealthListenerFn is used when client side health checking is
+	// disabled. It sends a single READY update on the registered listener.
+	noOpRegisterHealthListenerFn = func(_ context.Context, listener func(balancer.SubConnState)) func() {
+		listener(balancer.SubConnState{ConnectivityState: connectivity.Ready})
+		return func() {}
+	}
+)
 
 // ccBalancerWrapper sits between the ClientConn and the Balancer.
 //
@@ -277,10 +285,17 @@ type healthData struct {
 	// to the LB policy. This is stored to avoid sending updates when the
 	// SubConn has already exited connectivity state READY.
 	connectivityState connectivity.State
+	// closeHealthProducer stores a function to close the ref-counted health
+	// producer. The health producer is automatically closed when the SubConn
+	// state changes.
+	closeHealthProducer func()
 }
 
 func newHealthData(s connectivity.State) *healthData {
-	return &healthData{connectivityState: s}
+	return &healthData{
+		connectivityState:   s,
+		closeHealthProducer: func() {},
+	}
 }
 
 // updateState is invoked by grpc to push a subConn state update to the
@@ -413,6 +428,37 @@ func (acbw *acBalancerWrapper) closeProducers() {
 	}
 }
 
+// healthProducerRegisterFn is a type alias for the health producer's function
+// for registering listeners.
+type healthProducerRegisterFn = func(context.Context, balancer.SubConn, string, func(balancer.SubConnState)) func()
+
+// healthListenerRegFn returns a function to register a listener for health
+// updates. If client side health checks are disabled, the registered listener
+// will get a single READY (raw connectivity state) update.
+//
+// Client side health checking is enabled when all the following
+// conditions are satisfied:
+// 1. Health checking is not disabled using the dial option.
+// 2. The health package is imported.
+// 3. The health check config is present in the service config.
+func (acbw *acBalancerWrapper) healthListenerRegFn() func(context.Context, func(balancer.SubConnState)) func() {
+	if acbw.ccb.cc.dopts.disableHealthCheck {
+		return noOpRegisterHealthListenerFn
+	}
+	regHealthLisFn := internal.RegisterClientHealthCheckListener
+	if regHealthLisFn == nil {
+		// The health package is not imported.
+		return noOpRegisterHealthListenerFn
+	}
+	cfg := acbw.ac.cc.healthCheckConfig()
+	if cfg == nil {
+		return noOpRegisterHealthListenerFn
+	}
+	return func(ctx context.Context, listener func(balancer.SubConnState)) func() {
+		return regHealthLisFn.(healthProducerRegisterFn)(ctx, acbw, cfg.ServiceName, listener)
+	}
+}
+
 // RegisterHealthListener accepts a health listener from the LB policy. It sends
 // updates to the health listener as long as the SubConn's connectivity state
 // doesn't change and a new health listener is not registered. To invalidate
@@ -421,6 +467,7 @@ func (acbw *acBalancerWrapper) closeProducers() {
 func (acbw *acBalancerWrapper) RegisterHealthListener(listener func(balancer.SubConnState)) {
 	acbw.healthMu.Lock()
 	defer acbw.healthMu.Unlock()
+	acbw.healthData.closeHealthProducer()
 	// listeners should not be registered when the connectivity state
 	// isn't Ready. This may happen when the balancer registers a listener
 	// after the connectivityState is updated, but before it is notified
@@ -436,6 +483,7 @@ func (acbw *acBalancerWrapper) RegisterHealthListener(listener func(balancer.Sub
 		return
 	}
+	registerFn := acbw.healthListenerRegFn()
 	acbw.ccb.serializer.TrySchedule(func(ctx context.Context) {
 		if ctx.Err() != nil || acbw.ccb.balancer == nil {
 			return
 		}
@@ -443,10 +491,25 @@ func (acbw *acBalancerWrapper) RegisterHealthListener(listener func(balancer.Sub
 		// Don't send updates if a new listener is registered.
 		acbw.healthMu.Lock()
 		defer acbw.healthMu.Unlock()
-		curHD := acbw.healthData
-		if curHD != hd {
+		if acbw.healthData != hd {
 			return
 		}
-		listener(balancer.SubConnState{ConnectivityState: connectivity.Ready})
+		// Serialize the health updates from the health producer with
+		// other calls into the LB policy.
+		listenerWrapper := func(scs balancer.SubConnState) {
+			acbw.ccb.serializer.TrySchedule(func(ctx context.Context) {
+				if ctx.Err() != nil || acbw.ccb.balancer == nil {
+					return
+				}
+				acbw.healthMu.Lock()
+				defer acbw.healthMu.Unlock()
+				if acbw.healthData != hd {
+					return
+				}
+				listener(scs)
+			})
+		}
+
+		hd.closeHealthProducer = registerFn(ctx, listenerWrapper)
 	})
 }
diff --git a/health/producer.go b/health/producer.go
new file mode 100644
index 000000000000..f938e5790c7b
--- /dev/null
+++ b/health/producer.go
@@ -0,0 +1,106 @@
+/*
+ *
+ * Copyright 2024 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package health
+
+import (
+	"context"
+	"sync"
+
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/balancer"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/connectivity"
+	"google.golang.org/grpc/internal"
+	"google.golang.org/grpc/status"
+)
+
+func init() {
+	producerBuilderSingleton = &producerBuilder{}
+	internal.RegisterClientHealthCheckListener = registerClientSideHealthCheckListener
+}
+
+type producerBuilder struct{}
+
+var producerBuilderSingleton *producerBuilder
+
+// Build constructs and returns a producer and its cleanup function.
+func (*producerBuilder) Build(cci any) (balancer.Producer, func()) {
+	p := &healthServiceProducer{
+		cc:     cci.(grpc.ClientConnInterface),
+		cancel: func() {},
+	}
+	return p, func() {
+		p.mu.Lock()
+		defer p.mu.Unlock()
+		p.cancel()
+	}
+}
+
+type healthServiceProducer struct {
+	// The following fields are initialized at build time and read-only after
+	// that and therefore do not need to be guarded by a mutex.
+	cc grpc.ClientConnInterface
+
+	mu     sync.Mutex
+	cancel func()
+}
+
+// registerClientSideHealthCheckListener accepts a listener to provide server
+// health state via the health service.
+func registerClientSideHealthCheckListener(ctx context.Context, sc balancer.SubConn, serviceName string, listener func(balancer.SubConnState)) func() {
+	pr, closeFn := sc.GetOrBuildProducer(producerBuilderSingleton)
+	p := pr.(*healthServiceProducer)
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	p.cancel()
+	if listener == nil {
+		return closeFn
+	}
+
+	ctx, cancel := context.WithCancel(ctx)
+	p.cancel = cancel
+
+	go p.startHealthCheck(ctx, sc, serviceName, listener)
+	return closeFn
+}
+
+func (p *healthServiceProducer) startHealthCheck(ctx context.Context, sc balancer.SubConn, serviceName string, listener func(balancer.SubConnState)) {
+	newStream := func(method string) (any, error) {
+		return p.cc.NewStream(ctx, &grpc.StreamDesc{ServerStreams: true}, method)
+	}
+
+	setConnectivityState := func(state connectivity.State, err error) {
+		listener(balancer.SubConnState{
+			ConnectivityState: state,
+			ConnectionError:   err,
+		})
+	}
+
+	// Call the function through the internal variable as tests use it for
+	// mocking.
+	err := internal.HealthCheckFunc(ctx, newStream, setConnectivityState, serviceName)
+	if err == nil {
+		return
+	}
+	if status.Code(err) == codes.Unimplemented {
+		logger.Errorf("Subchannel health check is unimplemented at server side, thus health check is disabled for SubConn %p", sc)
+	} else {
+		logger.Errorf("Health checking failed for SubConn %p: %v", sc, err)
+	}
+}
diff --git a/internal/internal.go b/internal/internal.go
index 3afc1813440e..c17b98194b3c 100644
--- a/internal/internal.go
+++ b/internal/internal.go
@@ -31,6 +31,10 @@ import (
 var (
 	// HealthCheckFunc is used to provide client-side LB channel health checking
 	HealthCheckFunc HealthChecker
+	// RegisterClientHealthCheckListener is used to provide a listener for
+	// updates from the client-side health checking service. It returns a
+	// function that can be called to stop the health producer.
+	RegisterClientHealthCheckListener any // func(ctx context.Context, sc balancer.SubConn, serviceName string, listener func(balancer.SubConnState)) func()
 	// BalancerUnregister is exported by package balancer to unregister a balancer.
 	BalancerUnregister func(name string)
 	// KeepaliveMinPingTime is the minimum ping interval.  This must be 10s by
diff --git a/test/healthcheck_test.go b/test/healthcheck_test.go
index 424682d09625..fac565240ab7 100644
--- a/test/healthcheck_test.go
+++ b/test/healthcheck_test.go
@@ -28,12 +28,17 @@ import (
 	"time"
 
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/balancer"
+	"google.golang.org/grpc/balancer/pickfirst"
+	"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/credentials/insecure"
 	"google.golang.org/grpc/health"
 	"google.golang.org/grpc/internal"
+	"google.golang.org/grpc/internal/balancer/stub"
 	"google.golang.org/grpc/internal/channelz"
+	"google.golang.org/grpc/internal/envconfig"
 	"google.golang.org/grpc/internal/grpctest"
 	"google.golang.org/grpc/internal/testutils"
 	"google.golang.org/grpc/resolver"
@@ -46,6 +51,47 @@ import (
 	testpb "google.golang.org/grpc/interop/grpc_testing"
 )
 
+const healthCheckingPetiolePolicyName = "health_checking_petiole_policy"
+
+var (
+	// healthCheckTestPolicyName is the LB policy used for testing the health check
+	// service.
+ healthCheckTestPolicyName = "round_robin" +) + +func init() { + balancer.Register(&healthCheckingPetiolePolicyBuilder{}) + // Till dualstack changes are not implemented and round_robin doesn't + // delegate to pickfirst, test a fake petiole policy that delegates to + // the new pickfirst balancer. + // TODO: https://github.com/grpc/grpc-go/issues/7906 - Remove the fake + // petiole policy one round robin starts delegating to pickfirst. + if envconfig.NewPickFirstEnabled { + healthCheckTestPolicyName = healthCheckingPetiolePolicyName + } +} + +type healthCheckingPetiolePolicyBuilder struct{} + +func (bb *healthCheckingPetiolePolicyBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + return &healthCheckingPetiolePolicy{ + Balancer: balancer.Get(pickfirstleaf.Name).Build(cc, opts), + } +} + +func (bb *healthCheckingPetiolePolicyBuilder) Name() string { + return healthCheckingPetiolePolicyName +} + +func (b *healthCheckingPetiolePolicy) UpdateClientConnState(state balancer.ClientConnState) error { + state.ResolverState = pickfirstleaf.EnableHealthListener(state.ResolverState) + return b.Balancer.UpdateClientConnState(state) +} + +type healthCheckingPetiolePolicy struct { + balancer.Balancer +} + func newTestHealthServer() *testHealthServer { return newTestHealthServerWithWatchFunc(defaultWatchFunc) } @@ -261,12 +307,12 @@ func (s) TestHealthCheckHealthServerNotRegistered(t *testing.T) { cc, r := setupClient(t, nil) r.UpdateState(resolver.State{ Addresses: []resolver.Address{{Addr: lis.Addr().String()}}, - ServiceConfig: parseServiceConfig(t, r, `{ - "healthCheckConfig": { - "serviceName": "foo" - }, - "loadBalancingConfig": [{"round_robin":{}}] -}`)}) + ServiceConfig: parseServiceConfig(t, r, fmt.Sprintf(`{ + "healthCheckConfig": { + "serviceName": "foo" + }, + "loadBalancingConfig": [{"%s":{}}] + }`, healthCheckTestPolicyName))}) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -288,12 +334,12 @@ func (s) TestHealthCheckWithGoAway(t *testing.T) { tc := testgrpc.NewTestServiceClient(cc) r.UpdateState(resolver.State{ Addresses: []resolver.Address{{Addr: lis.Addr().String()}}, - ServiceConfig: parseServiceConfig(t, r, `{ - "healthCheckConfig": { - "serviceName": "foo" - }, - "loadBalancingConfig": [{"round_robin":{}}] -}`)}) + ServiceConfig: parseServiceConfig(t, r, fmt.Sprintf(`{ + "healthCheckConfig": { + "serviceName": "foo" + }, + "loadBalancingConfig": [{"%s":{}}] + }`, healthCheckTestPolicyName))}) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -366,12 +412,12 @@ func (s) TestHealthCheckWithConnClose(t *testing.T) { tc := testgrpc.NewTestServiceClient(cc) r.UpdateState(resolver.State{ Addresses: []resolver.Address{{Addr: lis.Addr().String()}}, - ServiceConfig: parseServiceConfig(t, r, `{ - "healthCheckConfig": { - "serviceName": "foo" - }, - "loadBalancingConfig": [{"round_robin":{}}] -}`)}) + ServiceConfig: parseServiceConfig(t, r, fmt.Sprintf(`{ + "healthCheckConfig": { + "serviceName": "foo" + }, + "loadBalancingConfig": [{"%s":{}}] + }`, healthCheckTestPolicyName))}) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -414,12 +460,12 @@ func (s) TestHealthCheckWithAddrConnDrain(t *testing.T) { hcEnterChan, hcExitChan := setupHealthCheckWrapper(t) cc, r := setupClient(t, &clientConfig{}) tc := testgrpc.NewTestServiceClient(cc) - sc := parseServiceConfig(t, r, `{ - "healthCheckConfig": { - "serviceName": "foo" - }, 
- "loadBalancingConfig": [{"round_robin":{}}] -}`) + sc := parseServiceConfig(t, r, fmt.Sprintf(`{ + "healthCheckConfig": { + "serviceName": "foo" + }, + "loadBalancingConfig": [{"%s":{}}] + }`, healthCheckTestPolicyName)) r.UpdateState(resolver.State{ Addresses: []resolver.Address{{Addr: lis.Addr().String()}}, ServiceConfig: sc, @@ -496,12 +542,12 @@ func (s) TestHealthCheckWithClientConnClose(t *testing.T) { tc := testgrpc.NewTestServiceClient(cc) r.UpdateState(resolver.State{ Addresses: []resolver.Address{{Addr: lis.Addr().String()}}, - ServiceConfig: parseServiceConfig(t, r, `{ - "healthCheckConfig": { - "serviceName": "foo" - }, - "loadBalancingConfig": [{"round_robin":{}}] -}`)}) + ServiceConfig: parseServiceConfig(t, r, (fmt.Sprintf(`{ + "healthCheckConfig": { + "serviceName": "foo" + }, + "loadBalancingConfig": [{"%s":{}}] + }`, healthCheckTestPolicyName)))}) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -563,12 +609,12 @@ func (s) TestHealthCheckWithoutSetConnectivityStateCalledAddrConnShutDown(t *tes // The serviceName "delay" is specially handled at server side, where response will not be sent // back to client immediately upon receiving the request (client should receive no response until // test ends). - sc := parseServiceConfig(t, r, `{ - "healthCheckConfig": { - "serviceName": "delay" - }, - "loadBalancingConfig": [{"round_robin":{}}] -}`) + sc := parseServiceConfig(t, r, fmt.Sprintf(`{ + "healthCheckConfig": { + "serviceName": "delay" + }, + "loadBalancingConfig": [{"%s":{}}] + }`, healthCheckTestPolicyName)) r.UpdateState(resolver.State{ Addresses: []resolver.Address{{Addr: lis.Addr().String()}}, ServiceConfig: sc, @@ -628,12 +674,12 @@ func (s) TestHealthCheckWithoutSetConnectivityStateCalled(t *testing.T) { // test ends). 
 	r.UpdateState(resolver.State{
 		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
-		ServiceConfig: parseServiceConfig(t, r, `{
-	"healthCheckConfig": {
-		"serviceName": "delay"
-	},
-	"loadBalancingConfig": [{"round_robin":{}}]
-}`)})
+		ServiceConfig: parseServiceConfig(t, r, fmt.Sprintf(`{
+			"healthCheckConfig": {
+				"serviceName": "delay"
+			},
+			"loadBalancingConfig": [{"%s":{}}]
+		}`, healthCheckTestPolicyName))})
 
 	select {
 	case <-hcExitChan:
@@ -666,12 +712,12 @@ func testHealthCheckDisableWithDialOption(t *testing.T, addr string) {
 	tc := testgrpc.NewTestServiceClient(cc)
 	r.UpdateState(resolver.State{
 		Addresses: []resolver.Address{{Addr: addr}},
-		ServiceConfig: parseServiceConfig(t, r, `{
-	"healthCheckConfig": {
-		"serviceName": "foo"
-	},
-	"loadBalancingConfig": [{"round_robin":{}}]
-}`)})
+		ServiceConfig: parseServiceConfig(t, r, fmt.Sprintf(`{
+			"healthCheckConfig": {
+				"serviceName": "foo"
+			},
+			"loadBalancingConfig": [{"%s":{}}]
+		}`, healthCheckTestPolicyName))})
 
 	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer cancel()
@@ -772,12 +818,12 @@ func (s) TestHealthCheckChannelzCountingCallSuccess(t *testing.T) {
 	_, r := setupClient(t, nil)
 	r.UpdateState(resolver.State{
 		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
-		ServiceConfig: parseServiceConfig(t, r, `{
-	"healthCheckConfig": {
-		"serviceName": "channelzSuccess"
-	},
-	"loadBalancingConfig": [{"round_robin":{}}]
-}`)})
+		ServiceConfig: parseServiceConfig(t, r, fmt.Sprintf(`{
+			"healthCheckConfig": {
+				"serviceName": "channelzSuccess"
+			},
+			"loadBalancingConfig": [{"%s":{}}]
+		}`, healthCheckTestPolicyName))})
 
 	if err := verifyResultWithDelay(func() (bool, error) {
 		cm, _ := channelz.GetTopChannels(0, 0)
@@ -821,12 +867,12 @@ func (s) TestHealthCheckChannelzCountingCallFailure(t *testing.T) {
 	_, r := setupClient(t, nil)
 	r.UpdateState(resolver.State{
 		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
-		ServiceConfig: parseServiceConfig(t, r, `{
-	"healthCheckConfig": {
-		"serviceName": "channelzFailure"
-	},
-	"loadBalancingConfig": [{"round_robin":{}}]
-}`)})
+		ServiceConfig: parseServiceConfig(t, r, fmt.Sprintf(`{
+			"healthCheckConfig": {
+				"serviceName": "channelzFailure"
+			},
+			"loadBalancingConfig": [{"%s":{}}]
+		}`, healthCheckTestPolicyName))})
 
 	if err := verifyResultWithDelay(func() (bool, error) {
 		cm, _ := channelz.GetTopChannels(0, 0)
@@ -935,12 +981,12 @@ func testHealthCheckSuccess(t *testing.T, e env) {
 // TestHealthCheckFailure invokes the unary Check() RPC on the health server
 // with an expired context and expects the RPC to fail.
 func (s) TestHealthCheckFailure(t *testing.T) {
-	for _, e := range listTestEnv() {
-		testHealthCheckFailure(t, e)
+	e := env{
+		name:     "tcp-tls",
+		network:  "tcp",
+		security: "tls",
+		balancer: healthCheckTestPolicyName,
 	}
-}
-
-func testHealthCheckFailure(t *testing.T, e env) {
 	te := newTest(t, e)
 	te.declareLogNoise(
 		"Failed to dial ",
@@ -1166,3 +1212,111 @@ func testHealthCheckServingStatus(t *testing.T, e env) {
 	te.setHealthServingStatus(defaultHealthService, healthpb.HealthCheckResponse_NOT_SERVING)
 	verifyHealthCheckStatus(t, 1*time.Second, cc, defaultHealthService, healthpb.HealthCheckResponse_NOT_SERVING)
 }
+
+// Test verifies that registering a nil health listener closes the health
+// client.
+func (s) TestHealthCheckUnregisterHealthListener(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	hcEnterChan, hcExitChan := setupHealthCheckWrapper(t)
+	scChan := make(chan balancer.SubConn, 1)
+	readyUpdateReceivedCh := make(chan struct{})
+	bf := stub.BalancerFuncs{
+		Init: func(bd *stub.BalancerData) {
+			cc := bd.ClientConn
+			ccw := &subConnStoringCCWrapper{
+				ClientConn: cc,
+				scChan:     scChan,
+				stateListener: func(scs balancer.SubConnState) {
+					if scs.ConnectivityState != connectivity.Ready {
+						return
+					}
+					close(readyUpdateReceivedCh)
+				},
+			}
+			bd.Data = balancer.Get(pickfirst.Name).Build(ccw, bd.BuildOptions)
+		},
+		Close: func(bd *stub.BalancerData) {
+			bd.Data.(balancer.Balancer).Close()
+		},
+		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
+			return bd.Data.(balancer.Balancer).UpdateClientConnState(ccs)
+		},
+	}
+
+	stub.Register(t.Name(), bf)
+	_, lis, ts := setupServer(t, nil)
+	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
+
+	_, r := setupClient(t, nil)
+	svcCfg := fmt.Sprintf(`{
+		"healthCheckConfig": {
+			"serviceName": "foo"
+		},
+		"loadBalancingConfig": [{"%s":{}}]
+	}`, t.Name())
+	r.UpdateState(resolver.State{
+		Addresses:     []resolver.Address{{Addr: lis.Addr().String()}},
+		ServiceConfig: parseServiceConfig(t, r, svcCfg)})
+
+	var sc balancer.SubConn
+	select {
+	case sc = <-scChan:
+	case <-ctx.Done():
+		t.Fatal("Context timed out waiting for SubConn creation")
+	}
+
+	// Wait for the SubConn to enter READY.
+	select {
+	case <-readyUpdateReceivedCh:
+	case <-ctx.Done():
+		t.Fatalf("Context timed out waiting for SubConn to enter READY")
+	}
+
+	// Health check should start only after a health listener is registered.
+	select {
+	case <-hcEnterChan:
+		t.Fatalf("Health service client created prematurely.")
+	case <-time.After(defaultTestShortTimeout):
+	}
+
+	// Register a health listener and verify it receives updates.
+	healthChan := make(chan balancer.SubConnState, 1)
+	sc.RegisterHealthListener(func(scs balancer.SubConnState) {
+		healthChan <- scs
+	})
+
+	select {
+	case <-hcEnterChan:
+	case <-ctx.Done():
+		t.Fatalf("Context timed out waiting for health check to begin.")
+	}
+
+	for readyReceived := false; !readyReceived; {
+		select {
+		case scs := <-healthChan:
+			t.Logf("Received health update: %v", scs)
+			readyReceived = scs.ConnectivityState == connectivity.Ready
+		case <-ctx.Done():
+			t.Fatalf("Context timed out waiting for healthy backend.")
+		}
+	}
+
+	// Registering a nil listener should invalidate the previously registered
+	// listener and close the health service client.
+	sc.RegisterHealthListener(nil)
+	select {
+	case <-hcExitChan:
+	case <-ctx.Done():
+		t.Fatalf("Context timed out waiting for the health client to close.")
+	}
+
+	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_NOT_SERVING)
+
+	// No updates should be received on the listener.
+	select {
+	case scs := <-healthChan:
+		t.Fatalf("Received unexpected health update on the listener: %v", scs)
+	case <-time.After(defaultTestShortTimeout):
+	}
+}
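
Usage sketch (illustrative only, not part of the diff): with this change an LB policy keeps receiving raw connectivity updates through its StateListener and, once a SubConn reports READY, may register a health listener on it; the SubConn should be surfaced to the picker only after that listener also reports READY. The exampleBalancer type below is a placeholder written in the style of a petiole policy, not code from this patch.

	package example

	import (
		"google.golang.org/grpc/balancer"
		"google.golang.org/grpc/connectivity"
		"google.golang.org/grpc/resolver"
	)

	// exampleBalancer is a placeholder petiole-style policy used only for
	// this sketch.
	type exampleBalancer struct {
		cc balancer.ClientConn
	}

	func (b *exampleBalancer) newSubConn(addr resolver.Address) {
		var sc balancer.SubConn
		sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{
			// StateListener keeps receiving raw connectivity updates as before.
			StateListener: func(scs balancer.SubConnState) {
				if scs.ConnectivityState != connectivity.Ready {
					return
				}
				// Once the connection is READY, register a health listener.
				// With this change the listener receives updates from the
				// health-service producer when client-side health checking
				// is configured, and a single synthetic READY update
				// otherwise.
				sc.RegisterHealthListener(func(hs balancer.SubConnState) {
					// Gate picker updates on hs.ConnectivityState here.
				})
			},
		})
		if err != nil {
			return
		}
		sc.Connect()
	}

Registering a new listener (or nil) invalidates the previous one; with this patch it also unrefs the health producer, which in turn cancels the Watch stream for that SubConn.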
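
For reference, the three enabling conditions listed in healthListenerRegFn's doc comment map onto the following client setup. This is a minimal sketch under assumptions: the target address and "your.service.Name" are placeholders, and in practice the service config usually comes from the resolver rather than grpc.WithDefaultServiceConfig.

	package main

	import (
		"log"

		"google.golang.org/grpc"
		"google.golang.org/grpc/credentials/insecure"

		// Condition 2: importing the health package registers the client-side
		// health checking machinery (including the producer added here).
		_ "google.golang.org/grpc/health"
	)

	func main() {
		// Condition 3: healthCheckConfig must be present in the service config.
		svcCfg := `{
			"healthCheckConfig": { "serviceName": "your.service.Name" },
			"loadBalancingConfig": [{"round_robin":{}}]
		}`
		// Condition 1: health checking must not be disabled with a dial option
		// such as grpc.WithDisableHealthCheck().
		cc, err := grpc.NewClient("dns:///backend.example.com:50051",
			grpc.WithTransportCredentials(insecure.NewCredentials()),
			grpc.WithDefaultServiceConfig(svcCfg),
		)
		if err != nil {
			log.Fatalf("grpc.NewClient: %v", err)
		}
		defer cc.Close()
		// RPCs on cc are now routed only to backends whose health service
		// reports SERVING for "your.service.Name".
	}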