diff --git a/balancer/endpointsharding/endpointsharding.go b/balancer/endpointsharding/endpointsharding.go index b5b92143194b..263c024a84c7 100644 --- a/balancer/endpointsharding/endpointsharding.go +++ b/balancer/endpointsharding/endpointsharding.go @@ -28,19 +28,33 @@ package endpointsharding import ( "encoding/json" "errors" + "fmt" + rand "math/rand/v2" "sync" "sync/atomic" - rand "math/rand/v2" - "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/balancer/pickfirst" + "google.golang.org/grpc/balancer/pickfirst/pickfirstleaf" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/internal/balancer/gracefulswitch" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" ) +// PickFirstConfig is a pick first config without shuffling enabled. +var PickFirstConfig string + +func init() { + name := pickfirst.Name + if !envconfig.NewPickFirstEnabled { + name = pickfirstleaf.Name + } + PickFirstConfig = fmt.Sprintf("[{%q: {}}]", name) +} + // ChildState is the balancer state of a child along with the endpoint which // identifies the child balancer. type ChildState struct { @@ -100,9 +114,6 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState // Update/Create new children. for _, endpoint := range state.ResolverState.Endpoints { - if len(endpoint.Addresses) == 0 { - continue - } if _, ok := newChildren.Get(endpoint); ok { // Endpoint child was already created, continue to avoid duplicate // update. @@ -143,6 +154,9 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState } } es.children.Store(newChildren) + if newChildren.Len() == 0 { + return balancer.ErrBadResolverState + } return ret } @@ -306,6 +320,3 @@ func (bw *balancerWrapper) UpdateState(state balancer.State) { func ParseConfig(cfg json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { return gracefulswitch.ParseConfig(cfg) } - -// PickFirstConfig is a pick first config without shuffling enabled. -const PickFirstConfig = "[{\"pick_first\": {}}]" diff --git a/balancer/weightedroundrobin/balancer.go b/balancer/weightedroundrobin/balancer.go index a0511772d2fa..c9c5b576bb0c 100644 --- a/balancer/weightedroundrobin/balancer.go +++ b/balancer/weightedroundrobin/balancer.go @@ -19,9 +19,7 @@ package weightedroundrobin import ( - "context" "encoding/json" - "errors" "fmt" rand "math/rand/v2" "sync" @@ -30,12 +28,13 @@ import ( "unsafe" "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/balancer/endpointsharding" "google.golang.org/grpc/balancer/weightedroundrobin/internal" "google.golang.org/grpc/balancer/weightedtarget" "google.golang.org/grpc/connectivity" estats "google.golang.org/grpc/experimental/stats" "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/grpcsync" iserviceconfig "google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/orca" "google.golang.org/grpc/resolver" @@ -84,23 +83,31 @@ var ( }) ) +// endpointSharding which specifies pick first children. 
+var endpointShardingLBConfig serviceconfig.LoadBalancingConfig + func init() { balancer.Register(bb{}) + var err error + endpointShardingLBConfig, err = endpointsharding.ParseConfig(json.RawMessage(endpointsharding.PickFirstConfig)) + if err != nil { + logger.Fatal(err) + } } type bb struct{} func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { b := &wrrBalancer{ - cc: cc, - subConns: resolver.NewAddressMap(), - csEvltr: &balancer.ConnectivityStateEvaluator{}, - scMap: make(map[balancer.SubConn]*weightedSubConn), - connectivityState: connectivity.Connecting, - target: bOpts.Target.String(), - metricsRecorder: bOpts.MetricsRecorder, + ClientConn: cc, + target: bOpts.Target.String(), + metricsRecorder: bOpts.MetricsRecorder, + addressWeights: resolver.NewAddressMap(), + endpointToWeight: resolver.NewEndpointMap(), + scToWeight: make(map[balancer.SubConn]*endpointWeight), } + b.child = endpointsharding.NewBalancer(b, bOpts) b.logger = prefixLogger(b) b.logger.Infof("Created") return b @@ -141,123 +148,189 @@ func (bb) Name() string { return Name } +// updateEndpointsLocked updates endpoint weight state based off new update, by +// starting and clearing any endpoint weights needed. +// +// Caller must hold b.mu. +func (b *wrrBalancer) updateEndpointsLocked(endpoints []resolver.Endpoint) { + endpointSet := resolver.NewEndpointMap() + addressSet := resolver.NewAddressMap() + for _, endpoint := range endpoints { + endpointSet.Set(endpoint, nil) + for _, addr := range endpoint.Addresses { + addressSet.Set(addr, nil) + } + var ew *endpointWeight + if ewi, ok := b.endpointToWeight.Get(endpoint); ok { + ew = ewi.(*endpointWeight) + } else { + ew = &endpointWeight{ + logger: b.logger, + connectivityState: connectivity.Connecting, + // Initially, we set load reports to off, because they are not + // running upon initial endpointWeight creation. + cfg: &lbConfig{EnableOOBLoadReport: false}, + metricsRecorder: b.metricsRecorder, + target: b.target, + locality: b.locality, + } + for _, addr := range endpoint.Addresses { + b.addressWeights.Set(addr, ew) + } + b.endpointToWeight.Set(endpoint, ew) + } + ew.updateConfig(b.cfg) + } + + for _, endpoint := range b.endpointToWeight.Keys() { + if _, ok := endpointSet.Get(endpoint); ok { + // Existing endpoint also in new endpoint list; skip. + continue + } + b.endpointToWeight.Delete(endpoint) + for _, addr := range endpoint.Addresses { + if _, ok := addressSet.Get(addr); !ok { // old endpoints to be deleted can share addresses with new endpoints, so only delete if necessary + b.addressWeights.Delete(addr) + } + } + // SubConn map will get handled in updateSubConnState + // when receives SHUTDOWN signal. + } +} + // wrrBalancer implements the weighted round robin LB policy. type wrrBalancer struct { - // The following fields are immutable. - cc balancer.ClientConn - logger *grpclog.PrefixLogger - target string - metricsRecorder estats.MetricsRecorder - - // The following fields are only accessed on calls into the LB policy, and - // do not need a mutex. 
- cfg *lbConfig // active config - subConns *resolver.AddressMap // active weightedSubConns mapped by address - scMap map[balancer.SubConn]*weightedSubConn - connectivityState connectivity.State // aggregate state - csEvltr *balancer.ConnectivityStateEvaluator - resolverErr error // the last error reported by the resolver; cleared on successful resolution - connErr error // the last connection error; cleared upon leaving TransientFailure - stopPicker func() - locality string + // The following fields are set at initialization time and read only after that, + // so they do not need to be protected by a mutex. + child balancer.Balancer + balancer.ClientConn // Embed to intercept NewSubConn operation + logger *grpclog.PrefixLogger + target string + metricsRecorder estats.MetricsRecorder + + mu sync.Mutex + cfg *lbConfig // active config + locality string + stopPicker *grpcsync.Event + addressWeights *resolver.AddressMap // addr -> endpointWeight + endpointToWeight *resolver.EndpointMap // endpoint -> endpointWeight + scToWeight map[balancer.SubConn]*endpointWeight } func (b *wrrBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error { b.logger.Infof("UpdateCCS: %v", ccs) - b.resolverErr = nil cfg, ok := ccs.BalancerConfig.(*lbConfig) if !ok { return fmt.Errorf("wrr: received nil or illegal BalancerConfig (type %T): %v", ccs.BalancerConfig, ccs.BalancerConfig) } + // Note: empty endpoints and duplicate addresses across endpoints won't + // explicitly error but will have undefined behavior. + b.mu.Lock() b.cfg = cfg b.locality = weightedtarget.LocalityFromResolverState(ccs.ResolverState) - b.updateAddresses(ccs.ResolverState.Addresses) - - if len(ccs.ResolverState.Addresses) == 0 { - b.ResolverError(errors.New("resolver produced zero addresses")) // will call regeneratePicker - return balancer.ErrBadResolverState - } + b.updateEndpointsLocked(ccs.ResolverState.Endpoints) + b.mu.Unlock() + + // This causes child to update picker inline and will thus cause inline + // picker update. + return b.child.UpdateClientConnState(balancer.ClientConnState{ + BalancerConfig: endpointShardingLBConfig, + ResolverState: ccs.ResolverState, + }) +} - b.regeneratePicker() +func (b *wrrBalancer) UpdateState(state balancer.State) { + b.mu.Lock() + defer b.mu.Unlock() - return nil -} + if b.stopPicker != nil { + b.stopPicker.Fire() + b.stopPicker = nil + } -func (b *wrrBalancer) updateAddresses(addrs []resolver.Address) { - addrsSet := resolver.NewAddressMap() + childStates := endpointsharding.ChildStatesFromPicker(state.Picker) - // Loop through new address list and create subconns for any new addresses. - for _, addr := range addrs { - if _, ok := addrsSet.Get(addr); ok { - // Redundant address; skip. - continue - } - addrsSet.Set(addr, nil) + var readyPickersWeight []pickerWeightedEndpoint - var wsc *weightedSubConn - wsci, ok := b.subConns.Get(addr) - if ok { - wsc = wsci.(*weightedSubConn) - } else { - // addr is a new address (not existing in b.subConns). 
- var sc balancer.SubConn - sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{ - StateListener: func(state balancer.SubConnState) { - b.updateSubConnState(sc, state) - }, - }) - if err != nil { - b.logger.Warningf("Failed to create new SubConn for address %v: %v", addr, err) + for _, childState := range childStates { + if childState.State.ConnectivityState == connectivity.Ready { + ewv, ok := b.endpointToWeight.Get(childState.Endpoint) + if !ok { + // Should never happen, simply continue and ignore this endpoint + // for READY pickers. continue } - wsc = &weightedSubConn{ - SubConn: sc, - logger: b.logger, - connectivityState: connectivity.Idle, - // Initially, we set load reports to off, because they are not - // running upon initial weightedSubConn creation. - cfg: &lbConfig{EnableOOBLoadReport: false}, - - metricsRecorder: b.metricsRecorder, - target: b.target, - locality: b.locality, - } - b.subConns.Set(addr, wsc) - b.scMap[sc] = wsc - b.csEvltr.RecordTransition(connectivity.Shutdown, connectivity.Idle) - sc.Connect() + ew := ewv.(*endpointWeight) + readyPickersWeight = append(readyPickersWeight, pickerWeightedEndpoint{ + picker: childState.State.Picker, + weightedEndpoint: ew, + }) } - // Update config for existing weightedSubConn or send update for first - // time to new one. Ensures an OOB listener is running if needed - // (and stops the existing one if applicable). - wsc.updateConfig(b.cfg) + } + // If no ready pickers are present, simply defer to the round robin picker + // from endpoint sharding, which will round robin across the most relevant + // pick first children in the highest precedence connectivity state. + if len(readyPickersWeight) == 0 { + b.ClientConn.UpdateState(balancer.State{ + ConnectivityState: state.ConnectivityState, + Picker: state.Picker, + }) + return } - // Loop through existing subconns and remove ones that are not in addrs. - for _, addr := range b.subConns.Keys() { - if _, ok := addrsSet.Get(addr); ok { - // Existing address also in new address list; skip. - continue - } - // addr was removed by resolver. Remove. - wsci, _ := b.subConns.Get(addr) - wsc := wsci.(*weightedSubConn) - wsc.SubConn.Shutdown() - b.subConns.Delete(addr) + p := &picker{ + v: rand.Uint32(), // start the scheduler at a random point + cfg: b.cfg, + weightedPickers: readyPickersWeight, + metricsRecorder: b.metricsRecorder, + locality: b.locality, + target: b.target, } + + b.stopPicker = grpcsync.NewEvent() + p.start(b.stopPicker) + + b.ClientConn.UpdateState(balancer.State{ + ConnectivityState: state.ConnectivityState, + Picker: p, + }) } -func (b *wrrBalancer) ResolverError(err error) { - b.resolverErr = err - if b.subConns.Len() == 0 { - b.connectivityState = connectivity.TransientFailure +type pickerWeightedEndpoint struct { + picker balancer.Picker + weightedEndpoint *endpointWeight +} + +func (b *wrrBalancer) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { + addr := addrs[0] // The new pick first policy for DualStack will only ever create a SubConn with one address. + var sc balancer.SubConn + + oldListener := opts.StateListener + opts.StateListener = func(state balancer.SubConnState) { + b.updateSubConnState(sc, state) + oldListener(state) } - if b.connectivityState != connectivity.TransientFailure { - // No need to update the picker since no error is being returned. 
- return + + b.mu.Lock() + defer b.mu.Unlock() + ewi, ok := b.addressWeights.Get(addr) + if !ok { + // SubConn state updates can come in for a no longer relevant endpoint + // weight (from the old system after a new config update is applied). + return nil, fmt.Errorf("balancer is being closed; no new SubConns allowed") + } + sc, err := b.ClientConn.NewSubConn([]resolver.Address{addr}, opts) + if err != nil { + return nil, err } - b.regeneratePicker() + b.scToWeight[sc] = ewi.(*endpointWeight) + return sc, nil +} + +func (b *wrrBalancer) ResolverError(err error) { + // Will cause inline picker update from endpoint sharding. + b.child.ResolverError(err) } func (b *wrrBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { @@ -265,134 +338,84 @@ func (b *wrrBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.Sub } func (b *wrrBalancer) updateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { - wsc := b.scMap[sc] - if wsc == nil { - b.logger.Errorf("UpdateSubConnState called with an unknown SubConn: %p, %v", sc, state) + b.mu.Lock() + ew := b.scToWeight[sc] + // updates from a no longer relevant SubConn update, nothing to do here but + // forward state to state listener, which happens in wrapped listener. Will + // eventually get cleared from scMap once receives Shutdown signal. + if ew == nil { + b.mu.Unlock() return } - if b.logger.V(2) { - logger.Infof("UpdateSubConnState(%+v, %+v)", sc, state) - } - - cs := state.ConnectivityState - - if cs == connectivity.TransientFailure { - // Save error to be reported via picker. - b.connErr = state.ConnectionError - } - - if cs == connectivity.Shutdown { - delete(b.scMap, sc) - // The subconn was removed from b.subConns when the address was removed - // in updateAddresses. + if state.ConnectivityState == connectivity.Shutdown { + delete(b.scToWeight, sc) + } + b.mu.Unlock() + + // On the first READY SubConn/Transition for an endpoint, set pickedSC, + // clear endpoint tracking weight state, and potentially start an OOB watch. + if state.ConnectivityState == connectivity.Ready && ew.pickedSC == nil { + ew.pickedSC = sc + ew.mu.Lock() + ew.nonEmptySince = time.Time{} + ew.lastUpdated = time.Time{} + cfg := ew.cfg + ew.mu.Unlock() + ew.updateORCAListener(cfg) + return } - oldCS := wsc.updateConnectivityState(cs) - b.connectivityState = b.csEvltr.RecordTransition(oldCS, cs) - - // Regenerate picker when one of the following happens: - // - this sc entered or left ready - // - the aggregated state of balancer is TransientFailure - // (may need to update error message) - if (cs == connectivity.Ready) != (oldCS == connectivity.Ready) || - b.connectivityState == connectivity.TransientFailure { - b.regeneratePicker() + // If the pickedSC (the one pick first uses for an endpoint) transitions out + // of READY, stop OOB listener if needed and clear pickedSC so the next + // created SubConn for the endpoint that goes READY will be chosen for + // endpoint as the active SubConn. + if state.ConnectivityState != connectivity.Ready && ew.pickedSC == sc { + // The first SubConn that goes READY for an endpoint is what pick first + // will pick. Only once that SubConn goes not ready will pick first + // restart this cycle of creating SubConns and using the first READY + // one. The lower level endpoint sharding will ping the Pick First once + // this occurs to ExitIdle which will trigger a connection attempt. 
+ if ew.stopORCAListener != nil { + ew.stopORCAListener() + } + ew.pickedSC = nil } } // Close stops the balancer. It cancels any ongoing scheduler updates and // stops any ORCA listeners. func (b *wrrBalancer) Close() { + b.mu.Lock() if b.stopPicker != nil { - b.stopPicker() + b.stopPicker.Fire() b.stopPicker = nil } - for _, wsc := range b.scMap { - // Ensure any lingering OOB watchers are stopped. - wsc.updateConnectivityState(connectivity.Shutdown) - } -} - -// ExitIdle is ignored; we always connect to all backends. -func (b *wrrBalancer) ExitIdle() {} + b.mu.Unlock() -func (b *wrrBalancer) readySubConns() []*weightedSubConn { - var ret []*weightedSubConn - for _, v := range b.subConns.Values() { - wsc := v.(*weightedSubConn) - if wsc.connectivityState == connectivity.Ready { - ret = append(ret, wsc) + // Ensure any lingering OOB watchers are stopped. + for _, ewv := range b.endpointToWeight.Values() { + ew := ewv.(*endpointWeight) + if ew.stopORCAListener != nil { + ew.stopORCAListener() } } - return ret } -// mergeErrors builds an error from the last connection error and the last -// resolver error. Must only be called if b.connectivityState is -// TransientFailure. -func (b *wrrBalancer) mergeErrors() error { - // connErr must always be non-nil unless there are no SubConns, in which - // case resolverErr must be non-nil. - if b.connErr == nil { - return fmt.Errorf("last resolver error: %v", b.resolverErr) +func (b *wrrBalancer) ExitIdle() { + if ei, ok := b.child.(balancer.ExitIdler); ok { // Should always be ok, as child is endpoint sharding. + ei.ExitIdle() } - if b.resolverErr == nil { - return fmt.Errorf("last connection error: %v", b.connErr) - } - return fmt.Errorf("last connection error: %v; last resolver error: %v", b.connErr, b.resolverErr) -} - -func (b *wrrBalancer) regeneratePicker() { - if b.stopPicker != nil { - b.stopPicker() - b.stopPicker = nil - } - - switch b.connectivityState { - case connectivity.TransientFailure: - b.cc.UpdateState(balancer.State{ - ConnectivityState: connectivity.TransientFailure, - Picker: base.NewErrPicker(b.mergeErrors()), - }) - return - case connectivity.Connecting, connectivity.Idle: - // Idle could happen very briefly if all subconns are Idle and we've - // asked them to connect but they haven't reported Connecting yet. - // Report the same as Connecting since this is temporary. - b.cc.UpdateState(balancer.State{ - ConnectivityState: connectivity.Connecting, - Picker: base.NewErrPicker(balancer.ErrNoSubConnAvailable), - }) - return - case connectivity.Ready: - b.connErr = nil - } - - p := &picker{ - v: rand.Uint32(), // start the scheduler at a random point - cfg: b.cfg, - subConns: b.readySubConns(), - metricsRecorder: b.metricsRecorder, - locality: b.locality, - target: b.target, - } - var ctx context.Context - ctx, b.stopPicker = context.WithCancel(context.Background()) - p.start(ctx) - b.cc.UpdateState(balancer.State{ - ConnectivityState: b.connectivityState, - Picker: p, - }) } // picker is the WRR policy's picker. It uses live-updating backend weights to // update the scheduler periodically and ensure picks are routed proportional // to those weights. 
type picker struct { - scheduler unsafe.Pointer // *scheduler; accessed atomically - v uint32 // incrementing value used by the scheduler; accessed atomically - cfg *lbConfig // active config when picker created - subConns []*weightedSubConn // all READY subconns + scheduler unsafe.Pointer // *scheduler; accessed atomically + v uint32 // incrementing value used by the scheduler; accessed atomically + cfg *lbConfig // active config when picker created + + weightedPickers []pickerWeightedEndpoint // all READY pickers // The following fields are immutable. target string @@ -400,14 +423,39 @@ type picker struct { metricsRecorder estats.MetricsRecorder } -func (p *picker) scWeights(recordMetrics bool) []float64 { - ws := make([]float64, len(p.subConns)) +func (p *picker) endpointWeights(recordMetrics bool) []float64 { + wp := make([]float64, len(p.weightedPickers)) now := internal.TimeNow() - for i, wsc := range p.subConns { - ws[i] = wsc.weight(now, time.Duration(p.cfg.WeightExpirationPeriod), time.Duration(p.cfg.BlackoutPeriod), recordMetrics) + for i, wpi := range p.weightedPickers { + wp[i] = wpi.weightedEndpoint.weight(now, time.Duration(p.cfg.WeightExpirationPeriod), time.Duration(p.cfg.BlackoutPeriod), recordMetrics) } + return wp +} + +func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + // Read the scheduler atomically. All scheduler operations are threadsafe, + // and if the scheduler is replaced during this usage, we want to use the + // scheduler that was live when the pick started. + sched := *(*scheduler)(atomic.LoadPointer(&p.scheduler)) - return ws + pickedPicker := p.weightedPickers[sched.nextIndex()] + pr, err := pickedPicker.picker.Pick(info) + if err != nil { + logger.Errorf("ready picker returned error: %v", err) + return balancer.PickResult{}, err + } + if !p.cfg.EnableOOBLoadReport { + oldDone := pr.Done + pr.Done = func(info balancer.DoneInfo) { + if load, ok := info.ServerLoad.(*v3orcapb.OrcaLoadReport); ok && load != nil { + pickedPicker.weightedEndpoint.OnLoadReport(load) + } + if oldDone != nil { + oldDone(info) + } + } + } + return pr, nil } func (p *picker) inc() uint32 { @@ -419,9 +467,9 @@ func (p *picker) regenerateScheduler() { atomic.StorePointer(&p.scheduler, unsafe.Pointer(&s)) } -func (p *picker) start(ctx context.Context) { +func (p *picker) start(stopPicker *grpcsync.Event) { p.regenerateScheduler() - if len(p.subConns) == 1 { + if len(p.weightedPickers) == 1 { // No need to regenerate weights with only one backend. return } @@ -431,7 +479,7 @@ func (p *picker) start(ctx context.Context) { defer ticker.Stop() for { select { - case <-ctx.Done(): + case <-stopPicker.Done(): return case <-ticker.C: p.regenerateScheduler() @@ -440,29 +488,12 @@ func (p *picker) start(ctx context.Context) { }() } -func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) { - // Read the scheduler atomically. All scheduler operations are threadsafe, - // and if the scheduler is replaced during this usage, we want to use the - // scheduler that was live when the pick started. 
- sched := *(*scheduler)(atomic.LoadPointer(&p.scheduler)) - - pickedSC := p.subConns[sched.nextIndex()] - pr := balancer.PickResult{SubConn: pickedSC.SubConn} - if !p.cfg.EnableOOBLoadReport { - pr.Done = func(info balancer.DoneInfo) { - if load, ok := info.ServerLoad.(*v3orcapb.OrcaLoadReport); ok && load != nil { - pickedSC.OnLoadReport(load) - } - } - } - return pr, nil -} - -// weightedSubConn is the wrapper of a subconn that holds the subconn and its -// weight (and other parameters relevant to computing the effective weight). -// When needed, it also tracks connectivity state, listens for metrics updates -// by implementing the orca.OOBListener interface and manages that listener. -type weightedSubConn struct { +// endpointWeight is the weight for an endpoint. It tracks the SubConn that will +// be picked for the endpoint, and other parameters relevant to computing the +// effective weight. When needed, it also tracks connectivity state, listens for +// metrics updates by implementing the orca.OOBListener interface and manages +// that listener. +type endpointWeight struct { // The following fields are immutable. balancer.SubConn logger *grpclog.PrefixLogger @@ -474,6 +505,11 @@ type weightedSubConn struct { // do not need a mutex. connectivityState connectivity.State stopORCAListener func() + // The first SubConn for the endpoint that goes READY when endpoint has no + // READY SubConns yet, cleared on that sc disconnecting (i.e. going out of + // READY). Represents what pick first will use as it's picked SubConn for + // this endpoint. + pickedSC balancer.SubConn // The following fields are accessed asynchronously and are protected by // mu. Note that mu may not be held when calling into the stopORCAListener @@ -487,11 +523,11 @@ type weightedSubConn struct { cfg *lbConfig } -func (w *weightedSubConn) OnLoadReport(load *v3orcapb.OrcaLoadReport) { +func (w *endpointWeight) OnLoadReport(load *v3orcapb.OrcaLoadReport) { if w.logger.V(2) { w.logger.Infof("Received load report for subchannel %v: %v", w.SubConn, load) } - // Update weights of this subchannel according to the reported load + // Update weights of this endpoint according to the reported load. utilization := load.ApplicationUtilization if utilization == 0 { utilization = load.CpuUtilization @@ -520,7 +556,7 @@ func (w *weightedSubConn) OnLoadReport(load *v3orcapb.OrcaLoadReport) { // updateConfig updates the parameters of the WRR policy and // stops/starts/restarts the ORCA OOB listener. -func (w *weightedSubConn) updateConfig(cfg *lbConfig) { +func (w *endpointWeight) updateConfig(cfg *lbConfig) { w.mu.Lock() oldCfg := w.cfg w.cfg = cfg @@ -533,14 +569,12 @@ func (w *weightedSubConn) updateConfig(cfg *lbConfig) { // load reporting disabled, OOBReportingPeriod is always 0.) return } - if w.connectivityState == connectivity.Ready { - // (Re)start the listener to use the new config's settings for OOB - // reporting. - w.updateORCAListener(cfg) - } + // (Re)start the listener to use the new config's settings for OOB + // reporting. + w.updateORCAListener(cfg) } -func (w *weightedSubConn) updateORCAListener(cfg *lbConfig) { +func (w *endpointWeight) updateORCAListener(cfg *lbConfig) { if w.stopORCAListener != nil { w.stopORCAListener() } @@ -548,57 +582,22 @@ func (w *weightedSubConn) updateORCAListener(cfg *lbConfig) { w.stopORCAListener = nil return } + if w.pickedSC == nil { // No picked SC for this endpoint yet, nothing to listen on. 
+ return + } if w.logger.V(2) { - w.logger.Infof("Registering ORCA listener for %v with interval %v", w.SubConn, cfg.OOBReportingPeriod) + w.logger.Infof("Registering ORCA listener for %v with interval %v", w.pickedSC, cfg.OOBReportingPeriod) } opts := orca.OOBListenerOptions{ReportInterval: time.Duration(cfg.OOBReportingPeriod)} - w.stopORCAListener = orca.RegisterOOBListener(w.SubConn, w, opts) -} - -func (w *weightedSubConn) updateConnectivityState(cs connectivity.State) connectivity.State { - switch cs { - case connectivity.Idle: - // Always reconnect when idle. - w.SubConn.Connect() - case connectivity.Ready: - // If we transition back to READY state, reset nonEmptySince so that we - // apply the blackout period after we start receiving load data. Also - // reset lastUpdated to trigger endpoint weight not yet usable in the - // case endpoint gets asked what weight it is before receiving a new - // load report. Note that we cannot guarantee that we will never receive - // lingering callbacks for backend metric reports from the previous - // connection after the new connection has been established, but they - // should be masked by new backend metric reports from the new - // connection by the time the blackout period ends. - w.mu.Lock() - w.nonEmptySince = time.Time{} - w.lastUpdated = time.Time{} - cfg := w.cfg - w.mu.Unlock() - w.updateORCAListener(cfg) - } - - oldCS := w.connectivityState - - if oldCS == connectivity.TransientFailure && - (cs == connectivity.Connecting || cs == connectivity.Idle) { - // Once a subconn enters TRANSIENT_FAILURE, ignore subsequent IDLE or - // CONNECTING transitions to prevent the aggregated state from being - // always CONNECTING when many backends exist but are all down. - return oldCS - } - - w.connectivityState = cs - - return oldCS + w.stopORCAListener = orca.RegisterOOBListener(w.pickedSC, w, opts) } -// weight returns the current effective weight of the subconn, taking into +// weight returns the current effective weight of the endpoint, taking into // account the parameters. Returns 0 for blacked out or expired data, which // will cause the backend weight to be treated as the mean of the weights of the // other backends. If forScheduler is set to true, this function will emit // metrics through the metrics registry. -func (w *weightedSubConn) weight(now time.Time, weightExpirationPeriod, blackoutPeriod time.Duration, recordMetrics bool) (weight float64) { +func (w *endpointWeight) weight(now time.Time, weightExpirationPeriod, blackoutPeriod time.Duration, recordMetrics bool) (weight float64) { w.mu.Lock() defer w.mu.Unlock() @@ -608,7 +607,7 @@ func (w *weightedSubConn) weight(now time.Time, weightExpirationPeriod, blackout }() } - // The SubConn has not received a load report (i.e. just turned READY with + // The endpoint has not received a load report (i.e. just turned READY with // no load report). if w.lastUpdated.Equal(time.Time{}) { endpointWeightNotYetUsableMetric.Record(w.metricsRecorder, 1, w.target, w.locality) diff --git a/balancer/weightedroundrobin/balancer_test.go b/balancer/weightedroundrobin/balancer_test.go index 68d2d5a5c5c8..5e369780764e 100644 --- a/balancer/weightedroundrobin/balancer_test.go +++ b/balancer/weightedroundrobin/balancer_test.go @@ -460,6 +460,65 @@ func (s) TestBalancer_TwoAddresses_OOBThenPerCall(t *testing.T) { checkWeights(ctx, t, srvWeight{srv1, 10}, srvWeight{srv2, 1}) } +// TestEndpoints_SharedAddress tests the case where two endpoints have the same +// address. 
The expected behavior is undefined, however the program should not +// crash. +func (s) TestEndpoints_SharedAddress(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + srv := startServer(t, reportCall) + sc := svcConfig(t, perCallConfig) + if err := srv.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil { + t.Fatalf("Error starting client: %v", err) + } + + endpointsSharedAddress := []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: srv.Address}}}, {Addresses: []resolver.Address{{Addr: srv.Address}}}} + srv.R.UpdateState(resolver.State{Endpoints: endpointsSharedAddress}) + + // Make some RPC's and make sure doesn't crash. It should go to one of the + // endpoints addresses, it's undefined which one it will choose and the load + // reporting might not work, but it should be able to make an RPC. + for i := 0; i < 10; i++ { + if _, err := srv.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("EmptyCall failed with err: %v", err) + } + } +} + +// TestEndpoints_MultipleAddresses tests WRR on endpoints with numerous +// addresses. It configures WRR with two endpoints with one bad address followed +// by a good address. It configures two backends that each report per call +// metrics, each corresponding to the two endpoints good address. It then +// asserts load is distributed as expected corresponding to the call metrics +// received. +func (s) TestEndpoints_MultipleAddresses(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + srv1 := startServer(t, reportCall) + srv2 := startServer(t, reportCall) + + srv1.callMetrics.SetQPS(10.0) + srv1.callMetrics.SetApplicationUtilization(.1) + + srv2.callMetrics.SetQPS(10.0) + srv2.callMetrics.SetApplicationUtilization(1.0) + + sc := svcConfig(t, perCallConfig) + if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil { + t.Fatalf("Error starting client: %v", err) + } + + twoEndpoints := []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: "bad-address-1"}, {Addr: srv1.Address}}}, {Addresses: []resolver.Address{{Addr: "bad-address-2"}, {Addr: srv2.Address}}}} + srv1.R.UpdateState(resolver.State{Endpoints: twoEndpoints}) + + // Call each backend once to ensure the weights have been received. + ensureReached(ctx, t, srv1.Client, 2) + // Wait for the weight update period to allow the new weights to be processed. + time.Sleep(weightUpdatePeriod) + checkWeights(ctx, t, srvWeight{srv1, 10}, srvWeight{srv2, 1}) +} + // Tests two addresses with OOB ORCA reporting enabled and a non-zero error // penalty applied. func (s) TestBalancer_TwoAddresses_ErrorPenalty(t *testing.T) { diff --git a/balancer/weightedroundrobin/metrics_test.go b/balancer/weightedroundrobin/metrics_test.go index 9794a65e044f..79e4d0a145a0 100644 --- a/balancer/weightedroundrobin/metrics_test.go +++ b/balancer/weightedroundrobin/metrics_test.go @@ -109,7 +109,7 @@ func (s) TestWRR_Metrics_SubConnWeight(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { tmr := stats.NewTestMetricsRecorder() - wsc := &weightedSubConn{ + wsc := &endpointWeight{ metricsRecorder: tmr, weightVal: 3, lastUpdated: test.lastUpdated, @@ -137,7 +137,7 @@ func (s) TestWRR_Metrics_SubConnWeight(t *testing.T) { // fallback. 
func (s) TestWRR_Metrics_Scheduler_RR_Fallback(t *testing.T) { tmr := stats.NewTestMetricsRecorder() - wsc := &weightedSubConn{ + ew := &endpointWeight{ metricsRecorder: tmr, weightVal: 0, } @@ -147,7 +147,7 @@ func (s) TestWRR_Metrics_Scheduler_RR_Fallback(t *testing.T) { BlackoutPeriod: iserviceconfig.Duration(10 * time.Second), WeightExpirationPeriod: iserviceconfig.Duration(3 * time.Minute), }, - subConns: []*weightedSubConn{wsc}, + weightedPickers: []pickerWeightedEndpoint{{weightedEndpoint: ew}}, metricsRecorder: tmr, } // There is only one SubConn, so no matter if the SubConn has a weight or @@ -160,12 +160,12 @@ func (s) TestWRR_Metrics_Scheduler_RR_Fallback(t *testing.T) { // With two SubConns, if neither of them have weights, it will also fallback // to round robin. - wsc2 := &weightedSubConn{ + ew2 := &endpointWeight{ target: "target", metricsRecorder: tmr, weightVal: 0, } - p.subConns = append(p.subConns, wsc2) + p.weightedPickers = append(p.weightedPickers, pickerWeightedEndpoint{weightedEndpoint: ew2}) p.regenerateScheduler() if got, _ := tmr.Metric("grpc.lb.wrr.rr_fallback"); got != 1 { t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.wrr.rr_fallback", got, 1) diff --git a/balancer/weightedroundrobin/scheduler.go b/balancer/weightedroundrobin/scheduler.go index 56aa15da10d2..7d3d6815eb7a 100644 --- a/balancer/weightedroundrobin/scheduler.go +++ b/balancer/weightedroundrobin/scheduler.go @@ -26,14 +26,14 @@ type scheduler interface { nextIndex() int } -// newScheduler uses scWeights to create a new scheduler for selecting subconns +// newScheduler uses scWeights to create a new scheduler for selecting endpoints // in a picker. It will return a round robin implementation if at least -// len(scWeights)-1 are zero or there is only a single subconn, otherwise it +// len(scWeights)-1 are zero or there is only a single endpoint, otherwise it // will return an Earliest Deadline First (EDF) scheduler implementation that -// selects the subchannels according to their weights. +// selects the endpoints according to their weights. func (p *picker) newScheduler(recordMetrics bool) scheduler { - scWeights := p.scWeights(recordMetrics) - n := len(scWeights) + epWeights := p.endpointWeights(recordMetrics) + n := len(epWeights) if n == 0 { return nil } @@ -46,7 +46,7 @@ func (p *picker) newScheduler(recordMetrics bool) scheduler { sum := float64(0) numZero := 0 max := float64(0) - for _, w := range scWeights { + for _, w := range epWeights { sum += w if w > max { max = w @@ -68,7 +68,7 @@ func (p *picker) newScheduler(recordMetrics bool) scheduler { weights := make([]uint16, n) allEqual := true - for i, w := range scWeights { + for i, w := range epWeights { if w == 0 { // Backends with weight = 0 use the mean. weights[i] = mean
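
The updateEndpointsLocked bookkeeping in the patch — reuse an existing weight entry per endpoint, index every address to that entry, and drop a stale entry's addresses only when no endpoint in the new update still claims them — condenses to a few maps. The sketch below is illustrative only: endpoints are plain address slices, `entry` stands in for endpointWeight, and it assumes every endpoint has at least one address.

```go
package main

import "fmt"

// entry stands in for the per-endpoint weight state (endpointWeight in the
// patch); only what the illustration needs is kept.
type entry struct{ firstAddr string }

// reconcile rebuilds the endpoint->entry and address->entry indexes for a new
// endpoint list: existing entries are reused, and a stale entry only drops an
// address if no endpoint in the update still claims it.
func reconcile(endpoints [][]string, byEndpoint, byAddr map[string]*entry) {
	keepEndpoint := make(map[string]bool)
	keepAddr := make(map[string]bool)
	for _, ep := range endpoints {
		key := fmt.Sprint(ep) // simplified endpoint identity
		keepEndpoint[key] = true
		for _, a := range ep {
			keepAddr[a] = true
		}
		e, ok := byEndpoint[key]
		if !ok {
			e = &entry{firstAddr: ep[0]}
			byEndpoint[key] = e
		}
		for _, a := range ep {
			byAddr[a] = e
		}
	}
	for key, e := range byEndpoint {
		if keepEndpoint[key] {
			continue // endpoint survived the update
		}
		delete(byEndpoint, key)
		for a, cur := range byAddr {
			// A removed endpoint may share an address with a new one; only
			// delete the index entry if the address is gone entirely.
			if cur == e && !keepAddr[a] {
				delete(byAddr, a)
			}
		}
	}
}

func main() {
	byEndpoint := map[string]*entry{}
	byAddr := map[string]*entry{}
	reconcile([][]string{{"a:1", "a:2"}, {"b:1"}}, byEndpoint, byAddr)
	reconcile([][]string{{"a:1"}}, byEndpoint, byAddr)
	fmt.Println(len(byEndpoint), len(byAddr)) // 1 1
}
```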
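
In the new UpdateState/Pick path, only READY children contribute to the weighted picker, each pick is delegated to the chosen child's picker, and, when per-call load reporting is in use, the per-RPC completion callback is wrapped so the report is credited to the picked endpoint before the child's own callback runs. A minimal stand-alone sketch of that delegation and callback chaining, with `pickFn` and counter-based `loads` standing in for the real picker and endpointWeight types:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// pickFn stands in for a child policy's picker: it returns the chosen backend
// and an optional per-RPC completion callback that may carry a load report.
type pickFn func() (backend string, done func(load float64))

// wrrPicker delegates each pick to one of the READY child pickers (in the
// patch these come from endpointsharding's ChildStatesFromPicker) and chains
// the completion callback so the load report is credited to the picked
// endpoint before the child's own callback runs.
type wrrPicker struct {
	children []pickFn
	loads    []*atomic.Uint64 // per-endpoint load tally, a stand-in for endpointWeight
	next     atomic.Uint32    // trivial round robin; the real picker asks an EDF scheduler
}

func (p *wrrPicker) pick() (string, func(float64)) {
	i := int(p.next.Add(1)-1) % len(p.children)
	backend, childDone := p.children[i]()
	return backend, func(load float64) {
		p.loads[i].Add(uint64(load)) // record the report against the picked endpoint
		if childDone != nil {
			childDone(load) // always run the child's own callback too
		}
	}
}

func main() {
	mk := func(name string) pickFn {
		return func() (string, func(float64)) { return name, nil }
	}
	p := &wrrPicker{
		children: []pickFn{mk("endpoint-1"), mk("endpoint-2")},
		loads:    []*atomic.Uint64{new(atomic.Uint64), new(atomic.Uint64)},
	}
	backend, done := p.pick()
	done(7) // simulate the RPC finishing with a load report of 7
	fmt.Println(backend, p.loads[0].Load()) // endpoint-1 7
}
```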
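
NewSubConn interception plus updateSubConnState together track a single "picked" SubConn per endpoint: the first SubConn to report READY is recorded (and drives the OOB listener), and it is cleared when that same SubConn leaves READY. A simplified sketch of that listener-wrapping pattern, with string-based stand-ins for balancer.SubConn and the connectivity states:

```go
package main

import "fmt"

// subConn and listener are local stand-ins for balancer.SubConn and the
// NewSubConnOptions.StateListener callback.
type subConn string
type listener func(state string)

// endpointTracker keeps the single SubConn that pick_first settled on for an
// endpoint, mirroring pickedSC in the patch: the first SubConn to report READY
// is picked, and it is cleared when that same SubConn leaves READY.
type endpointTracker struct{ picked *subConn }

// wrap intercepts state updates before handing them to the child's listener.
func (t *endpointTracker) wrap(sc subConn, orig listener) listener {
	return func(state string) {
		switch {
		case state == "READY" && t.picked == nil:
			t.picked = &sc // first READY SubConn becomes the endpoint's active one
		case state != "READY" && t.picked != nil && *t.picked == sc:
			t.picked = nil // active SubConn disconnected; wait for the next READY one
		}
		if orig != nil {
			orig(state) // the child policy still sees every update
		}
	}
}

func main() {
	t := &endpointTracker{}
	l := t.wrap("sc-1", func(s string) { fmt.Println("child sees", s) })
	l("CONNECTING")
	l("READY")
	fmt.Println("picked:", *t.picked)
	l("IDLE")
	fmt.Println("picked is nil:", t.picked == nil)
}
```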
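
The picker's background weight-refresh goroutine is now stopped through a fire-once event (grpcsync.Event) rather than a cancelable context. The sketch below reproduces the shape of picker.start with a hypothetical stopEvent built from a channel and sync.Once; grpcsync.Event exposes a similar Fire/Done surface.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// stopEvent is a fire-once signal: Fire may be called any number of times,
// and Done returns a channel that is closed after the first call.
type stopEvent struct {
	once sync.Once
	ch   chan struct{}
}

func newStopEvent() *stopEvent             { return &stopEvent{ch: make(chan struct{})} }
func (e *stopEvent) Fire()                 { e.once.Do(func() { close(e.ch) }) }
func (e *stopEvent) Done() <-chan struct{} { return e.ch }

// start mirrors picker.start in the patch: regenerate the scheduler on a timer
// until the stop event fires (when a new picker replaces this one or the
// balancer closes).
func start(stop *stopEvent, period time.Duration, regenerate func()) {
	go func() {
		t := time.NewTicker(period)
		defer t.Stop()
		for {
			select {
			case <-stop.Done():
				return
			case <-t.C:
				regenerate()
			}
		}
	}()
}

func main() {
	stop := newStopEvent()
	start(stop, 10*time.Millisecond, func() { fmt.Println("regenerate scheduler") })
	time.Sleep(35 * time.Millisecond)
	stop.Fire()
	stop.Fire() // safe to fire again
	time.Sleep(20 * time.Millisecond)
}
```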
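
picker.Pick loads the current scheduler through an atomic pointer so a pick always works against one consistent snapshot even while the background goroutine swaps in a rebuilt scheduler. The same idea, using sync/atomic's typed Pointer instead of the unsafe.Pointer used in the patch, with a trivial round-robin scheduler as the placeholder implementation:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// scheduler is the pick-ordering strategy; it is rebuilt periodically as
// weights change and swapped in atomically so in-flight picks keep using the
// snapshot they loaded.
type scheduler interface{ nextIndex() int }

type rr struct{ n, count uint32 }

func (r *rr) nextIndex() int { return int(atomic.AddUint32(&r.count, 1)-1) % int(r.n) }

type pickerState struct {
	sched atomic.Pointer[scheduler] // the patch stores the same thing via unsafe.Pointer
}

func (p *pickerState) regenerate(n uint32) {
	var s scheduler = &rr{n: n}
	p.sched.Store(&s) // readers see either the old or the new scheduler, never a mix
}

func (p *pickerState) pick() int {
	s := *p.sched.Load() // snapshot once per pick
	return s.nextIndex()
}

func main() {
	p := &pickerState{}
	p.regenerate(3)
	fmt.Println(p.pick(), p.pick(), p.pick(), p.pick()) // 0 1 2 0
}
```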
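
endpointWeight.weight (whose body is largely untouched by this patch and therefore not shown above) treats a weight as unusable — and hence reports 0 so the scheduler substitutes the mean — while no load report has arrived, after the data has expired, or while the post-READY blackout period is still running. A hedged restatement of that rule with illustrative names; the real method also records the weight metrics shown in the diff.

```go
package main

import (
	"fmt"
	"time"
)

// effectiveWeight mirrors the intent of endpointWeight.weight: report 0 (treat
// as "use the mean") while load data is inside the blackout period or after it
// has expired; otherwise use the last computed weight. Names are illustrative,
// not the grpc-go ones.
func effectiveWeight(now, nonEmptySince, lastUpdated time.Time, weight float64, expiration, blackout time.Duration) float64 {
	if lastUpdated.IsZero() {
		return 0 // no load report received yet
	}
	if now.Sub(lastUpdated) >= expiration {
		return 0 // data too old to trust
	}
	if blackout != 0 && (nonEmptySince.IsZero() || now.Sub(nonEmptySince) < blackout) {
		return 0 // still in the post-READY blackout window
	}
	return weight
}

func main() {
	now := time.Now()
	w := effectiveWeight(now, now.Add(-20*time.Second), now.Add(-time.Second), 8, 3*time.Minute, 10*time.Second)
	fmt.Println(w) // 8: reported recently, not expired, past the blackout window
}
```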
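
newScheduler's fallback rule is visible at the end of the diff: with zero endpoints there is nothing to schedule, and with fewer than two non-zero weights the weighted (EDF) scheduler adds nothing, so plain round robin is used instead; when EDF is chosen, zero weights are later replaced by the mean of the non-zero ones. A compact restatement of that decision:

```go
package main

import "fmt"

// chooseScheduler mirrors the fallback rule in scheduler.go: if there are
// fewer than two endpoints with a usable (non-zero) weight, fall back to plain
// round robin; otherwise build the weighted (EDF) scheduler, filling zero
// weights with the mean of the non-zero ones.
func chooseScheduler(weights []float64) string {
	n := len(weights)
	if n == 0 {
		return "no endpoints"
	}
	if n == 1 {
		return "round robin (single endpoint)"
	}
	sum, numZero := 0.0, 0
	for _, w := range weights {
		sum += w
		if w == 0 {
			numZero++
		}
	}
	if numZero >= n-1 {
		return "round robin (fewer than two weighted endpoints)"
	}
	mean := sum / float64(n-numZero)
	return fmt.Sprintf("EDF over weights with zeros replaced by mean %.1f", mean)
}

func main() {
	fmt.Println(chooseScheduler([]float64{0, 0, 5})) // round robin fallback
	fmt.Println(chooseScheduler([]float64{2, 0, 6})) // EDF, mean 4.0
}
```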