From 6fcbce25a48fea26f8bed6ba110c739aa28818e7 Mon Sep 17 00:00:00 2001 From: fraenkel Date: Thu, 1 May 2014 21:10:01 -0400 Subject: [PATCH] Retry on failed connects - Fix all go vet issues [#69786534] Signed-off-by: Matthew Sykes --- access_log/access_log_record.go | 6 +- access_log/access_log_record_test.go | 4 +- ...file_and_loggregator_access_logger_test.go | 14 +- common/component.go | 8 +- main.go | 2 +- main_test.go | 2 + perf_test.go | 7 +- proxy/proxy.go | 206 ++++++++------ proxy/proxy_test.go | 52 +++- proxy/request_handler.go | 127 ++++++--- proxy/responsewriter.go | 70 +++++ registry/registry.go | 232 ++++++---------- registry/registry_test.go | 206 +++++++------- route/endpoint.go | 26 +- route/endpoint_iterator_test.go | 198 ++++++++++++++ route/pool.go | 257 ++++++++++++++++-- route/pool_test.go | 181 ++++++------ router/helper_test.go | 18 +- router/registry_message.go | 10 +- router/router.go | 10 +- router/router_drain_test.go | 4 +- router/router_test.go | 11 +- scripts/test | 10 + varz/varz.go | 20 +- varz/varz_test.go | 12 +- 25 files changed, 1092 insertions(+), 601 deletions(-) create mode 100644 proxy/responsewriter.go create mode 100644 route/endpoint_iterator_test.go diff --git a/access_log/access_log_record.go b/access_log/access_log_record.go index 699139f59..982bb57f6 100644 --- a/access_log/access_log_record.go +++ b/access_log/access_log_record.go @@ -12,7 +12,7 @@ import ( type AccessLogRecord struct { Request *http.Request - Response *http.Response + StatusCode int RouteEndpoint *route.Endpoint StartedAt time.Time FirstByteAt time.Time @@ -42,10 +42,10 @@ func (r *AccessLogRecord) makeRecord() *bytes.Buffer { fmt.Fprintf(b, `[%s] `, r.FormatStartedAt()) fmt.Fprintf(b, `"%s %s %s" `, r.Request.Method, r.Request.URL.RequestURI(), r.Request.Proto) - if r.Response == nil { + if r.StatusCode == 0 { fmt.Fprintf(b, "MissingResponseStatusCode ") } else { - fmt.Fprintf(b, `%d `, r.Response.StatusCode) + fmt.Fprintf(b, `%d `, r.StatusCode) } fmt.Fprintf(b, `%d `, r.BodyBytesSent) diff --git a/access_log/access_log_record_test.go b/access_log/access_log_record_test.go index b6db3ccb1..03b5d7989 100644 --- a/access_log/access_log_record_test.go +++ b/access_log/access_log_record_test.go @@ -92,9 +92,7 @@ func CompleteAccessLogRecord() AccessLogRecord { RemoteAddr: "FakeRemoteAddr", }, BodyBytesSent: 23, - Response: &http.Response{ - StatusCode: 200, - }, + StatusCode: 200, RouteEndpoint: &route.Endpoint{ ApplicationId: "FakeApplicationId", }, diff --git a/access_log/file_and_loggregator_access_logger_test.go b/access_log/file_and_loggregator_access_logger_test.go index d39d3ef3e..0cfb62dd2 100644 --- a/access_log/file_and_loggregator_access_logger_test.go +++ b/access_log/file_and_loggregator_access_logger_test.go @@ -62,11 +62,7 @@ var _ = Describe("AccessLog", func() { testEmitter := NewMockEmitter() accessLogger := NewFileAndLoggregatorAccessLogger(nil, testEmitter) - routeEndpoint := &route.Endpoint{ - ApplicationId: "", - Host: "127.0.0.1", - Port: 4567, - } + routeEndpoint := route.NewEndpoint("", "127.0.0.1", 4567, "", nil) accessLogRecord := CreateAccessLogRecord() accessLogRecord.RouteEndpoint = routeEndpoint @@ -165,15 +161,11 @@ func CreateAccessLogRecord() *AccessLogRecord { StatusCode: http.StatusOK, } - b := &route.Endpoint{ - ApplicationId: "my_awesome_id", - Host: "127.0.0.1", - Port: 4567, - } + b := route.NewEndpoint("my_awesome_id", "127.0.0.1", 4567, "", nil) r := AccessLogRecord{ Request: req, - Response: res, + StatusCode: res.StatusCode, RouteEndpoint: b, StartedAt: time.Unix(10, 100000000), FirstByteAt: time.Unix(10, 200000000), diff --git a/common/component.go b/common/component.go index 8b10697a0..0a5db8041 100644 --- a/common/component.go +++ b/common/component.go @@ -4,13 +4,13 @@ import ( "encoding/json" "errors" "fmt" - . "github.com/cloudfoundry/gorouter/common/http" - steno "github.com/cloudfoundry/gosteno" - "github.com/cloudfoundry/yagnats" "net" "net/http" "runtime" "time" + . "github.com/cloudfoundry/gorouter/common/http" + steno "github.com/cloudfoundry/gosteno" + "github.com/cloudfoundry/yagnats" ) var procStat *ProcessStatus @@ -81,7 +81,7 @@ func (c *VcapComponent) Start() error { return err } - c.Host = fmt.Sprintf("%s:%s", host, port) + c.Host = fmt.Sprintf("%s:%d", host, port) } if c.Credentials == nil || len(c.Credentials) != 2 { diff --git a/main.go b/main.go index 2b4cf080f..e06e5005b 100644 --- a/main.go +++ b/main.go @@ -63,7 +63,7 @@ func main() { logger.Fatalf("Error connecting to NATS: %s\n", err) } - registry := rregistry.NewCFRegistry(c, natsClient) + registry := rregistry.NewRouteRegistry(c, natsClient) varz := rvarz.NewVarz(registry) diff --git a/main_test.go b/main_test.go index f3074bbb5..30272a0f9 100644 --- a/main_test.go +++ b/main_test.go @@ -115,6 +115,7 @@ var _ = Describe("Router Integration", func() { It("waits for all requests to finish", func() { mbusClient, err := newMessageBus(config) + Ω(err).ShouldNot(HaveOccurred()) blocker := make(chan bool) longApp := test.NewTestApp([]route.Uri{"longapp.vcap.me"}, proxyPort, mbusClient, nil) @@ -148,6 +149,7 @@ var _ = Describe("Router Integration", func() { It("will timeout if requests take too long", func() { mbusClient, err := newMessageBus(config) + Ω(err).ShouldNot(HaveOccurred()) blocker := make(chan bool) resultCh := make(chan error, 1) diff --git a/perf_test.go b/perf_test.go index bcf61ff25..3dc4996e1 100644 --- a/perf_test.go +++ b/perf_test.go @@ -19,7 +19,7 @@ var _ = Describe("AccessLogRecord", func() { Measure("Register", func(b Benchmarker) { c := config.DefaultConfig() mbus := fakeyagnats.New() - r := registry.NewCFRegistry(c, mbus) + r := registry.NewRouteRegistry(c, mbus) accesslog, err := access_log.CreateRunningAccessLogger(c) Ω(err).ToNot(HaveOccurred()) @@ -38,10 +38,7 @@ var _ = Describe("AccessLogRecord", func() { str := strconv.Itoa(i) r.Register( route.Uri("bench.vcap.me."+str), - &route.Endpoint{ - Host: "localhost", - Port: uint16(i), - }, + route.NewEndpoint("", "localhost", uint16(i), "", nil), ) } }) diff --git a/proxy/proxy.go b/proxy/proxy.go index 8d8d23256..5ceb590d9 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -1,6 +1,8 @@ package proxy import ( + "errors" + "net" "net/http" "net/http/httputil" "strings" @@ -16,13 +18,17 @@ import ( const ( VcapCookieId = "__VCAP_ID__" StickyCookieKey = "JSESSIONID" + retries = 3 ) +var noEndpointsAvailable = errors.New("No endpoints available") + type LookupRegistry interface { - Lookup(uri route.Uri) (*route.Endpoint, bool) - LookupByPrivateInstanceId(uri route.Uri, p string) (*route.Endpoint, bool) + Lookup(uri route.Uri) *route.Pool } +type AfterRoundTrip func(rsp *http.Response, endpoint *route.Endpoint, err error) + type ProxyReporter interface { CaptureBadRequest(req *http.Request) CaptureBadGateway(req *http.Request) @@ -88,32 +94,32 @@ func (p *proxy) Wait() { p.waitgroup.Wait() } -func (p *proxy) lookup(request *http.Request) (*route.Endpoint, bool) { - uri := route.Uri(hostWithoutPort(request)) - +func (p *proxy) getStickySession(request *http.Request) string { // Try choosing a backend using sticky session if _, err := request.Cookie(StickyCookieKey); err == nil { if sticky, err := request.Cookie(VcapCookieId); err == nil { - routeEndpoint, ok := p.registry.LookupByPrivateInstanceId(uri, sticky.Value) - if ok { - return routeEndpoint, ok - } + return sticky.Value } } + return "" +} +func (p *proxy) lookup(request *http.Request) *route.Pool { + uri := route.Uri(hostWithoutPort(request)) // Choose backend using host alone return p.registry.Lookup(uri) } func (p *proxy) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { startedAt := time.Now() - handler := NewRequestHandler(request, responseWriter) accessLog := access_log.AccessLogRecord{ Request: request, StartedAt: startedAt, } + handler := NewRequestHandler(request, responseWriter, p.reporter, &accessLog) + p.waitgroup.Add(1) defer func() { @@ -131,164 +137,184 @@ func (p *proxy) ServeHTTP(responseWriter http.ResponseWriter, request *http.Requ return } - routeEndpoint, found := p.lookup(request) - if !found { + routePool := p.lookup(request) + if routePool == nil { p.reporter.CaptureBadRequest(request) handler.HandleMissingRoute() return } - handler.logger.Set("RouteEndpoint", routeEndpoint.ToLogData()) + stickyEndpointId := p.getStickySession(request) + iter := &wrappedIterator{ + nested: routePool.Endpoints(stickyEndpointId), - accessLog.RouteEndpoint = routeEndpoint - - p.reporter.CaptureRoutingRequest(routeEndpoint, handler.request) + afterNext: func(endpoint *route.Endpoint) { + if endpoint != nil { + handler.logger.Set("RouteEndpoint", endpoint.ToLogData()) + accessLog.RouteEndpoint = endpoint + p.reporter.CaptureRoutingRequest(endpoint, request) + } + }, + } if isTcpUpgrade(request) { - handler.HandleTcpRequest(routeEndpoint) + handler.HandleTcpRequest(iter) return } if isWebSocketUpgrade(request) { - handler.HandleWebSocketRequest(routeEndpoint) + handler.HandleWebSocketRequest(iter) return } + proxyWriter := newProxyResponseWriter(responseWriter) proxyTransport := &proxyRoundTripper{ transport: p.transport, - after: func(rsp *http.Response, err error) { + iter: iter, + handler: &handler, + + after: func(rsp *http.Response, endpoint *route.Endpoint, err error) { accessLog.FirstByteAt = time.Now() - accessLog.Response = rsp + if rsp != nil { + accessLog.StatusCode = rsp.StatusCode + } // disable keep-alives -- not needed with Go 1.3 responseWriter.Header().Set("Connection", "close") if p.traceKey != "" && request.Header.Get(router_http.VcapTraceHeader) == p.traceKey { - setTraceHeaders(responseWriter, p.ip, routeEndpoint.CanonicalAddr()) + setTraceHeaders(responseWriter, p.ip, endpoint.CanonicalAddr()) } latency := time.Since(startedAt) - p.reporter.CaptureRoutingResponse(routeEndpoint, rsp, startedAt, latency) + p.reporter.CaptureRoutingResponse(endpoint, rsp, startedAt, latency) if err != nil { p.reporter.CaptureBadGateway(request) handler.HandleBadGateway(err) + proxyWriter.Done() return } - if routeEndpoint.PrivateInstanceId != "" { - setupStickySession(responseWriter, rsp, routeEndpoint) + if endpoint.PrivateInstanceId != "" { + setupStickySession(responseWriter, rsp, endpoint) } }, } - proxyWriter := newProxyResponseWriter(responseWriter) - p.newReverseProxy(proxyTransport, routeEndpoint, request).ServeHTTP(proxyWriter, request) + p.newReverseProxy(proxyTransport, request).ServeHTTP(proxyWriter, request) accessLog.FinishedAt = time.Now() accessLog.BodyBytesSent = int64(proxyWriter.Size()) } -func (p *proxy) newReverseProxy(proxyTransport http.RoundTripper, endpoint *route.Endpoint, req *http.Request) http.Handler { +func (p *proxy) newReverseProxy(proxyTransport http.RoundTripper, req *http.Request) http.Handler { rproxy := &httputil.ReverseProxy{ Director: func(request *http.Request) { request.URL.Scheme = "http" - request.URL.Host = endpoint.CanonicalAddr() + request.URL.Host = req.Host request.URL.Opaque = req.URL.Opaque request.URL.RawQuery = req.URL.RawQuery setRequestXRequestStart(req) setRequestXVcapRequestId(req, nil) }, + Transport: proxyTransport, + FlushInterval: 50 * time.Millisecond, } - rproxy.Transport = proxyTransport - rproxy.FlushInterval = 50 * time.Millisecond - return rproxy } -func setupStickySession(responseWriter http.ResponseWriter, response *http.Response, endpoint *route.Endpoint) { - for _, v := range response.Cookies() { - if v.Name == StickyCookieKey { - cookie := &http.Cookie{ - Name: VcapCookieId, - Value: endpoint.PrivateInstanceId, - Path: "/", - } - - http.SetCookie(responseWriter, cookie) - return - } - } -} - type proxyRoundTripper struct { transport http.RoundTripper - after func(response *http.Response, err error) - response *http.Response - err error + after AfterRoundTrip + iter route.EndpointIterator + handler *RequestHandler + + response *http.Response + err error } func (p *proxyRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { - p.response, p.err = p.transport.RoundTrip(request) - if p.after != nil { - p.after(p.response, p.err) - } + var err error + var res *http.Response + var endpoint *route.Endpoint + retry := 0 + for { + endpoint = p.iter.Next() + + if endpoint == nil { + p.handler.reporter.CaptureBadGateway(request) + err = noEndpointsAvailable + p.handler.HandleBadGateway(err) + return nil, err + } - return p.response, p.err -} + request.URL.Host = endpoint.CanonicalAddr() + res, err = p.transport.RoundTrip(request) + if err == nil { + break + } -type proxyResponseWriter struct { - w http.ResponseWriter - status int - size int + if ne, netErr := err.(*net.OpError); !netErr || ne.Op != "dial" { + break + } - flusher http.Flusher -} + p.iter.EndpointFailed() -func newProxyResponseWriter(w http.ResponseWriter) *proxyResponseWriter { - proxyWriter := &proxyResponseWriter{ - w: w, - flusher: w.(http.Flusher), + p.handler.Logger().Set("Error", err.Error()) + p.handler.Logger().Warnf("proxy.endpoint.failed") + + retry++ + if retry == retries { + break + } } - return proxyWriter -} + if p.after != nil { + p.after(res, endpoint, err) + } -func (p *proxyResponseWriter) Header() http.Header { - return p.w.Header() -} + p.response = res + p.err = err -func (p *proxyResponseWriter) Write(b []byte) (int, error) { - if p.status == 0 { - p.WriteHeader(http.StatusOK) - } - size, err := p.w.Write(b) - p.size += size - return size, err + return res, err } -func (p *proxyResponseWriter) WriteHeader(s int) { - p.w.WriteHeader(s) - - if p.status == 0 { - p.status = s - } +type wrappedIterator struct { + nested route.EndpointIterator + afterNext func(*route.Endpoint) } -func (p *proxyResponseWriter) Flush() { - if p.flusher != nil { - p.flusher.Flush() + +func (i *wrappedIterator) Next() *route.Endpoint { + e := i.nested.Next() + if i.afterNext != nil { + i.afterNext(e) } + return e } -func (p *proxyResponseWriter) Status() int { - return p.status +func (i *wrappedIterator) EndpointFailed() { + i.nested.EndpointFailed() } -func (p *proxyResponseWriter) Size() int { - return p.size +func setupStickySession(responseWriter http.ResponseWriter, response *http.Response, endpoint *route.Endpoint) { + for _, v := range response.Cookies() { + if v.Name == StickyCookieKey { + cookie := &http.Cookie{ + Name: VcapCookieId, + Value: endpoint.PrivateInstanceId, + Path: "/", + + HttpOnly: true, + } + + http.SetCookie(responseWriter, cookie) + return + } + } } func isProtocolSupported(request *http.Request) bool { diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 22a3ebff1..9b0952c25 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -33,14 +33,14 @@ type nullVarz struct{} func (_ nullVarz) MarshalJSON() ([]byte, error) { return json.Marshal(nil) } func (_ nullVarz) ActiveApps() *stats.ActiveApps { return stats.NewActiveApps() } -func (_ nullVarz) CaptureBadRequest(req *http.Request) {} -func (_ nullVarz) CaptureBadGateway(req *http.Request) {} +func (_ nullVarz) CaptureBadRequest(*http.Request) {} +func (_ nullVarz) CaptureBadGateway(*http.Request) {} func (_ nullVarz) CaptureRoutingRequest(b *route.Endpoint, req *http.Request) {} func (_ nullVarz) CaptureRoutingResponse(b *route.Endpoint, res *http.Response, t time.Time, d time.Duration) { } var _ = Describe("Proxy", func() { - var r *registry.CFRegistry + var r *registry.RouteRegistry var p Proxy var conf *config.Config var proxyServer net.Listener @@ -54,7 +54,7 @@ var _ = Describe("Proxy", func() { mbus := fakeyagnats.New() - r = registry.NewCFRegistry(conf, mbus) + r = registry.NewRouteRegistry(conf, mbus) accessLogFile = new(test_util.FakeFile) accessLog = access_log.NewFileAndLoggregatorAccessLogger(accessLogFile, nil) @@ -112,6 +112,7 @@ var _ = Describe("Proxy", func() { "HTTP/1.1 200 OK", "Content-Length: 0", }) + }) defer ln.Close() @@ -125,9 +126,10 @@ var _ = Describe("Proxy", func() { x.CheckLine("HTTP/1.0 200 OK") var payload []byte - n, e := accessLogFile.Read(&payload) - Ω(e).ShouldNot(HaveOccurred()) - Ω(n).ShouldNot(BeZero()) + Eventually(func() int { + accessLogFile.Read(&payload) + return len(payload) + }).ShouldNot(BeZero()) Ω(string(payload)).To(MatchRegexp("^test.*\n")) //make sure the record includes all the data //since the building of the log record happens throughout the life of the request @@ -726,6 +728,31 @@ var _ = Describe("Proxy", func() { Ω(err).Should(HaveOccurred()) }) + It("retries when failed endpoints exist", func() { + ln := registerHandler(r, "retries", func(x *test_util.HttpConn) { + x.CheckLine("GET / HTTP/1.1") + resp := test_util.NewResponse(http.StatusOK) + x.WriteResponse(resp) + x.Close() + }) + defer ln.Close() + + ip, err := net.ResolveTCPAddr("tcp", "localhost:81") + Ω(err).Should(BeNil()) + registerAddr(r, "retries", ip) + + for i := 0; i < 5; i++ { + x := dialProxy(proxyServer) + + req := x.NewRequest("GET", "/", nil) + req.Host = "retries" + x.WriteRequest(req) + resp, _ := x.ReadResponse() + + Ω(resp.StatusCode).To(Equal(http.StatusOK)) + } + }) + Context("Wait", func() { It("waits for requests to finish", func() { blocker := make(chan bool) @@ -766,7 +793,7 @@ var _ = Describe("Proxy", func() { }) }) -func registerAddr(r *registry.CFRegistry, u string, a net.Addr) { +func registerAddr(r *registry.RouteRegistry, u string, a net.Addr) { h, p, err := net.SplitHostPort(a.String()) Ω(err).NotTo(HaveOccurred()) @@ -775,14 +802,11 @@ func registerAddr(r *registry.CFRegistry, u string, a net.Addr) { r.Register( route.Uri(u), - &route.Endpoint{ - Host: h, - Port: uint16(x), - }, + route.NewEndpoint("", h, uint16(x), "", nil), ) } -func registerHandler(r *registry.CFRegistry, u string, h connHandler) net.Listener { +func registerHandler(r *registry.RouteRegistry, u string, h connHandler) net.Listener { ln, err := net.Listen("tcp", "127.0.0.1:0") Ω(err).NotTo(HaveOccurred()) @@ -800,7 +824,7 @@ func registerHandler(r *registry.CFRegistry, u string, h connHandler) net.Listen if max := 1 * time.Second; tempDelay > max { tempDelay = max } - println("http: Accept error: %v; retrying in %v", err, tempDelay) + fmt.Printf("http: Accept error: %v; retrying in %v\n", err, tempDelay) time.Sleep(tempDelay) continue } diff --git a/proxy/request_handler.go b/proxy/request_handler.go index 890b023c5..4173cdacf 100644 --- a/proxy/request_handler.go +++ b/proxy/request_handler.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/cloudfoundry/gorouter/access_log" "github.com/cloudfoundry/gorouter/common" router_http "github.com/cloudfoundry/gorouter/common/http" "github.com/cloudfoundry/gorouter/route" @@ -18,15 +19,20 @@ import ( ) type RequestHandler struct { - logger *steno.Logger + logger *steno.Logger + reporter ProxyReporter + logrecord *access_log.AccessLogRecord request *http.Request response http.ResponseWriter } -func NewRequestHandler(request *http.Request, response http.ResponseWriter) RequestHandler { +func NewRequestHandler(request *http.Request, response http.ResponseWriter, r ProxyReporter, + alr *access_log.AccessLogRecord) RequestHandler { return RequestHandler{ - logger: createLogger(request), + logger: createLogger(request), + reporter: r, + logrecord: alr, request: request, response: response, @@ -45,7 +51,12 @@ func createLogger(request *http.Request) *steno.Logger { return logger } +func (h *RequestHandler) Logger() *steno.Logger { + return h.logger +} + func (h *RequestHandler) HandleHeartbeat() { + h.logrecord.StatusCode = http.StatusOK h.response.WriteHeader(http.StatusOK) h.response.Write([]byte("ok\n")) h.request.Close = true @@ -59,6 +70,7 @@ func (h *RequestHandler) HandleUnsupportedProtocol() { return } + h.logrecord.StatusCode = http.StatusBadRequest fmt.Fprintf(buf, "HTTP/1.0 400 Bad Request\r\n\r\n") buf.Flush() conn.Close() @@ -80,28 +92,20 @@ func (h *RequestHandler) HandleBadGateway(err error) { h.writeStatus(http.StatusBadGateway, "Registered endpoint failed to handle the request.") } -func (h *RequestHandler) HandleTcpRequest(endpoint *route.Endpoint) { +func (h *RequestHandler) HandleTcpRequest(iter route.EndpointIterator) { h.logger.Set("Upgrade", "tcp") - err := h.serveTcp(endpoint) + err := h.serveTcp(iter) if err != nil { - h.logger.Set("Error", err.Error()) - h.logger.Warn("proxy.tcp.failed") - h.writeStatus(http.StatusBadRequest, "TCP forwarding to endpoint failed.") } } -func (h *RequestHandler) HandleWebSocketRequest(endpoint *route.Endpoint) { - h.setupRequest(endpoint) - +func (h *RequestHandler) HandleWebSocketRequest(iter route.EndpointIterator) { h.logger.Set("Upgrade", "websocket") - err := h.serveWebSocket(endpoint) + err := h.serveWebSocket(iter) if err != nil { - h.logger.Set("Error", err.Error()) - h.logger.Warn("proxy.websocket.failed") - h.writeStatus(http.StatusBadRequest, "WebSocket request to endpoint failed.") } } @@ -110,6 +114,7 @@ func (h *RequestHandler) writeStatus(code int, message string) { body := fmt.Sprintf("%d %s: %s", code, http.StatusText(code), message) h.logger.Warn(body) + h.logrecord.StatusCode = code http.Error(h.response, body, code) if code > 299 { @@ -117,54 +122,104 @@ func (h *RequestHandler) writeStatus(code int, message string) { } } -func (h *RequestHandler) serveTcp(endpoint *route.Endpoint) error { +func (h *RequestHandler) serveTcp(iter route.EndpointIterator) error { var err error + var connection net.Conn client, _, err := h.hijack() if err != nil { return err } - - connection, err := net.Dial("tcp", endpoint.CanonicalAddr()) - if err != nil { - return err - } - defer func() { client.Close() - connection.Close() + if connection != nil { + connection.Close() + } }() - forwardIO(client, connection) + retry := 0 + for { + endpoint := iter.Next() + if endpoint == nil { + h.reporter.CaptureBadGateway(h.request) + err = noEndpointsAvailable + h.HandleBadGateway(err) + return err + } + + connection, err = net.Dial("tcp", endpoint.CanonicalAddr()) + if err == nil { + break + } + + iter.EndpointFailed() + + h.logger.Set("Error", err.Error()) + h.logger.Warn("proxy.tcp.failed") + + retry++ + if retry == retries { + return err + } + } + + if connection != nil { + forwardIO(client, connection) + } return nil } -func (h *RequestHandler) serveWebSocket(endpoint *route.Endpoint) error { +func (h *RequestHandler) serveWebSocket(iter route.EndpointIterator) error { var err error + var connection net.Conn client, _, err := h.hijack() if err != nil { return err } - - connection, err := net.Dial("tcp", endpoint.CanonicalAddr()) - if err != nil { - return err - } - defer func() { client.Close() - connection.Close() + if connection != nil { + connection.Close() + } }() - err = h.request.Write(connection) - if err != nil { - return err + retry := 0 + for { + endpoint := iter.Next() + if endpoint == nil { + h.reporter.CaptureBadGateway(h.request) + err = noEndpointsAvailable + h.HandleBadGateway(err) + return err + } + + connection, err = net.Dial("tcp", endpoint.CanonicalAddr()) + if err == nil { + h.setupRequest(endpoint) + break + } + + iter.EndpointFailed() + + h.logger.Set("Error", err.Error()) + h.logger.Warn("proxy.websocket.failed") + + retry++ + if retry == retries { + return err + } } - forwardIO(client, connection) + if connection != nil { + err = h.request.Write(connection) + if err != nil { + return err + } + forwardIO(client, connection) + } return nil } diff --git a/proxy/responsewriter.go b/proxy/responsewriter.go new file mode 100644 index 000000000..f36432a17 --- /dev/null +++ b/proxy/responsewriter.go @@ -0,0 +1,70 @@ +package proxy + +import ( + "net/http" +) + +type proxyResponseWriter struct { + w http.ResponseWriter + status int + size int + + flusher http.Flusher + done bool +} + +func newProxyResponseWriter(w http.ResponseWriter) *proxyResponseWriter { + proxyWriter := &proxyResponseWriter{ + w: w, + flusher: w.(http.Flusher), + } + + return proxyWriter +} + +func (p *proxyResponseWriter) Header() http.Header { + return p.w.Header() +} + +func (p *proxyResponseWriter) Write(b []byte) (int, error) { + if p.done { + return 0, nil + } + + if p.status == 0 { + p.WriteHeader(http.StatusOK) + } + size, err := p.w.Write(b) + p.size += size + return size, err +} + +func (p *proxyResponseWriter) WriteHeader(s int) { + if p.done { + return + } + + p.w.WriteHeader(s) + + if p.status == 0 { + p.status = s + } +} + +func (p *proxyResponseWriter) Done() { + p.done = true +} + +func (p *proxyResponseWriter) Flush() { + if p.flusher != nil { + p.flusher.Flush() + } +} + +func (p *proxyResponseWriter) Status() int { + return p.status +} + +func (p *proxyResponseWriter) Size() int { + return p.size +} diff --git a/registry/registry.go b/registry/registry.go index 42db41b65..7c3e3a4b1 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -12,42 +12,29 @@ import ( "github.com/cloudfoundry/gorouter/route" ) -type CFRegistry struct { +type RouteRegistry struct { sync.RWMutex logger *steno.Logger byUri map[route.Uri]*route.Pool - table map[tableKey]*tableEntry - pruneStaleDropletsInterval time.Duration dropletStaleThreshold time.Duration messageBus yagnats.NATSClient + ticker *time.Ticker timeOfLastUpdate time.Time } -type tableKey struct { - addr string - uri route.Uri -} - -type tableEntry struct { - endpoint *route.Endpoint - updatedAt time.Time -} - -func NewCFRegistry(c *config.Config, mbus yagnats.NATSClient) *CFRegistry { - r := &CFRegistry{} +func NewRouteRegistry(c *config.Config, mbus yagnats.NATSClient) *RouteRegistry { + r := &RouteRegistry{} r.logger = steno.NewLogger("router.registry") r.byUri = make(map[route.Uri]*route.Pool) - r.table = make(map[tableKey]*tableEntry) - r.pruneStaleDropletsInterval = c.PruneStaleDropletsInterval r.dropletStaleThreshold = c.DropletStaleThreshold @@ -56,192 +43,145 @@ func NewCFRegistry(c *config.Config, mbus yagnats.NATSClient) *CFRegistry { return r } -func (registry *CFRegistry) Register(uri route.Uri, endpoint *route.Endpoint) { - registry.Lock() - defer registry.Unlock() +func (r *RouteRegistry) Register(uri route.Uri, endpoint *route.Endpoint) { + t := time.Now() + r.Lock() uri = uri.ToLower() - key := tableKey{ - addr: endpoint.CanonicalAddr(), - uri: uri, - } - - var endpointToRegister *route.Endpoint - - entry, found := registry.table[key] - if found { - endpointToRegister = entry.endpoint - } else { - endpointToRegister = endpoint - entry = &tableEntry{endpoint: endpoint} - - registry.table[key] = entry - } - - pool, found := registry.byUri[uri] + pool, found := r.byUri[uri] if !found { - pool = route.NewPool() - registry.byUri[uri] = pool + pool = route.NewPool(r.dropletStaleThreshold / 4) + r.byUri[uri] = pool } - pool.Add(endpointToRegister) - - entry.updatedAt = time.Now() + pool.Put(endpoint) - registry.timeOfLastUpdate = time.Now() + r.timeOfLastUpdate = t + r.Unlock() } -func (registry *CFRegistry) Unregister(uri route.Uri, endpoint *route.Endpoint) { - registry.Lock() - defer registry.Unlock() +func (r *RouteRegistry) Unregister(uri route.Uri, endpoint *route.Endpoint) { + r.Lock() uri = uri.ToLower() - key := tableKey{ - addr: endpoint.CanonicalAddr(), - uri: uri, + pool, found := r.byUri[uri] + if found { + pool.Remove(endpoint) + + if pool.IsEmpty() { + delete(r.byUri, uri) + } } - registry.unregisterUri(key) + r.Unlock() } -func (r *CFRegistry) Lookup(uri route.Uri) (*route.Endpoint, bool) { +func (r *RouteRegistry) Lookup(uri route.Uri) *route.Pool { r.RLock() - defer r.RUnlock() - pool, ok := r.lookupByUri(uri) - if !ok { - return nil, false - } + uri = uri.ToLower() + pool := r.byUri[uri] - return pool.Sample() + r.RUnlock() + + return pool } -func (r *CFRegistry) LookupByPrivateInstanceId(uri route.Uri, p string) (*route.Endpoint, bool) { - r.RLock() - defer r.RUnlock() +func (r *RouteRegistry) StartPruningCycle() { + if r.pruneStaleDropletsInterval > 0 { + r.Lock() + r.ticker = time.NewTicker(r.pruneStaleDropletsInterval) + r.Unlock() - pool, ok := r.lookupByUri(uri) - if !ok { - return nil, false - } + go func() { + for { + select { + case <-r.ticker.C: + r.logger.Debug("Start to check and prune stale droplets") + if r.isStateStale() { + r.logger.Info("State is stale; NOT pruning") + r.pauseStaleTracker() + break + } - return pool.FindByPrivateInstanceId(p) -} + r.pruneStaleDroplets() -func (r *CFRegistry) lookupByUri(uri route.Uri) (*route.Pool, bool) { - uri = uri.ToLower() - pool, ok := r.byUri[uri] - return pool, ok -} - -func (r *CFRegistry) StartPruningCycle() { - go r.checkAndPrune() + } + } + }() + } } -func (r *CFRegistry) PruneStaleDroplets() { - if r.isStateStale() { - r.logger.Info("State is stale; NOT pruning") - r.pauseStaleTracker() - return +func (r *RouteRegistry) StopPruningCycle() { + r.Lock() + if r.ticker != nil { + r.ticker.Stop() } - - r.pruneStaleDroplets() + r.Unlock() } -func (registry *CFRegistry) NumUris() int { +func (registry *RouteRegistry) NumUris() int { registry.RLock() - defer registry.RUnlock() + uriCount := len(registry.byUri) + registry.RUnlock() - return len(registry.byUri) + return uriCount } -func (r *CFRegistry) TimeOfLastUpdate() time.Time { +func (r *RouteRegistry) TimeOfLastUpdate() time.Time { r.RLock() - defer r.RUnlock() - return r.timeOfLastUpdate + t := r.timeOfLastUpdate + r.RUnlock() + + return t } -func (r *CFRegistry) NumEndpoints() int { +func (r *RouteRegistry) NumEndpoints() int { r.RLock() - defer r.RUnlock() - - mapForSize := make(map[string]bool) - for _, entry := range r.table { - mapForSize[entry.endpoint.CanonicalAddr()] = true + uris := make(map[string]struct{}) + f := func(endpoint *route.Endpoint) { + uris[endpoint.CanonicalAddr()] = struct{}{} + } + for _, pool := range r.byUri { + pool.Each(f) } + r.RUnlock() - return len(mapForSize) + return len(uris) } -func (r *CFRegistry) MarshalJSON() ([]byte, error) { +func (r *RouteRegistry) MarshalJSON() ([]byte, error) { r.RLock() defer r.RUnlock() return json.Marshal(r.byUri) } -func (r *CFRegistry) isStateStale() bool { +func (r *RouteRegistry) isStateStale() bool { return !r.messageBus.Ping() } -func (r *CFRegistry) pruneStaleDroplets() { +func (r *RouteRegistry) pruneStaleDroplets() { r.Lock() - defer r.Unlock() - - for key, entry := range r.table { - if !r.isEntryStale(entry) { - continue + pruneTime := time.Now().Add(-r.dropletStaleThreshold) + for k, pool := range r.byUri { + pool.PruneBefore(pruneTime) + if pool.IsEmpty() { + delete(r.byUri, k) } - - r.logger.Infof("Pruning stale droplet: %v, uri: %s", entry, key.uri) - r.unregisterUri(key) } + r.Unlock() } -func (r *CFRegistry) isEntryStale(entry *tableEntry) bool { - return entry.updatedAt.Add(r.dropletStaleThreshold).Before(time.Now()) -} - -func (r *CFRegistry) pauseStaleTracker() { +func (r *RouteRegistry) pauseStaleTracker() { r.Lock() - defer r.Unlock() - - for _, entry := range r.table { - entry.updatedAt = time.Now() - } -} - -func (r *CFRegistry) checkAndPrune() { - if r.pruneStaleDropletsInterval == 0 { - return - } - - tick := time.Tick(r.pruneStaleDropletsInterval) - for { - select { - case <-tick: - r.logger.Debug("Start to check and prune stale droplets") - r.PruneStaleDroplets() - } - } -} - -func (r *CFRegistry) unregisterUri(key tableKey) { - entry, found := r.table[key] - if !found { - return - } + t := time.Now() - endpoints, found := r.byUri[key.uri] - if found { - endpoints.Remove(entry.endpoint) - - if endpoints.IsEmpty() { - delete(r.byUri, key.uri) - } + for _, pool := range r.byUri { + pool.MarkUpdated(t) } - delete(r.table, key) + r.Unlock() } diff --git a/registry/registry_test.go b/registry/registry_test.go index 63af4e1e0..21cd7309c 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -13,8 +13,8 @@ import ( "time" ) -var _ = Describe("Registry", func() { - var r *CFRegistry +var _ = Describe("RouteRegistry", func() { + var r *RouteRegistry var messageBus *fakeyagnats.FakeYagnats var fooEndpoint, barEndpoint, bar2Endpoint *route.Endpoint @@ -22,43 +22,30 @@ var _ = Describe("Registry", func() { BeforeEach(func() { configObj = config.DefaultConfig() + configObj.PruneStaleDropletsInterval = 50 * time.Millisecond configObj.DropletStaleThreshold = 10 * time.Millisecond messageBus = fakeyagnats.New() - r = NewCFRegistry(configObj, messageBus) - fooEndpoint = &route.Endpoint{ - Host: "192.168.1.1", - Port: 1234, - - ApplicationId: "12345", - Tags: map[string]string{ + r = NewRouteRegistry(configObj, messageBus) + fooEndpoint = route.NewEndpoint("12345", "192.168.1.1", 1234, + "id1", map[string]string{ "runtime": "ruby18", "framework": "sinatra", - }, - } - - barEndpoint = &route.Endpoint{ - Host: "192.168.1.2", - Port: 4321, + }) - ApplicationId: "54321", - Tags: map[string]string{ + barEndpoint = route.NewEndpoint("54321", "192.168.1.2", 4321, + "id2", map[string]string{ "runtime": "javascript", "framework": "node", - }, - } - - bar2Endpoint = &route.Endpoint{ - Host: "192.168.1.3", - Port: 1234, + }) - ApplicationId: "54321", - Tags: map[string]string{ + bar2Endpoint = route.NewEndpoint("54321", "192.168.1.3", 1234, + "id3", map[string]string{ "runtime": "javascript", "framework": "node", - }, - } + }) }) + Context("Register", func() { It("records and tracks time of last update", func() { r.Register("foo", fooEndpoint) @@ -89,15 +76,8 @@ var _ = Describe("Registry", func() { }) It("ignores case", func() { - m1 := &route.Endpoint{ - Host: "192.168.1.1", - Port: 1234, - } - - m2 := &route.Endpoint{ - Host: "192.168.1.1", - Port: 1235, - } + m1 := route.NewEndpoint("", "192.168.1.1", 1234, "", nil) + m2 := route.NewEndpoint("", "192.168.1.1", 1235, "", nil) r.Register("foo", m1) r.Register("FOO", m2) @@ -106,15 +86,8 @@ var _ = Describe("Registry", func() { }) It("allows multiple uris for the same endpoint", func() { - m1 := &route.Endpoint{ - Host: "192.168.1.1", - Port: 1234, - } - - m2 := &route.Endpoint{ - Host: "192.168.1.1", - Port: 1234, - } + m1 := route.NewEndpoint("", "192.168.1.1", 1234, "", nil) + m2 := route.NewEndpoint("", "192.168.1.1", 1234, "", nil) r.Register("foo", m1) r.Register("bar", m2) @@ -148,15 +121,8 @@ var _ = Describe("Registry", func() { }) It("ignores uri case and matches endpoint", func() { - m1 := &route.Endpoint{ - Host: "192.168.1.1", - Port: 1234, - } - - m2 := &route.Endpoint{ - Host: "192.168.1.1", - Port: 1234, - } + m1 := route.NewEndpoint("", "192.168.1.1", 1234, "", nil) + m2 := route.NewEndpoint("", "192.168.1.1", 1234, "", nil) r.Register("foo", m1) r.Unregister("FOO", m2) @@ -165,15 +131,8 @@ var _ = Describe("Registry", func() { }) It("removes the specific url/endpoint combo", func() { - m1 := &route.Endpoint{ - Host: "192.168.1.1", - Port: 1234, - } - - m2 := &route.Endpoint{ - Host: "192.168.1.1", - Port: 1234, - } + m1 := route.NewEndpoint("", "192.168.1.1", 1234, "", nil) + m2 := route.NewEndpoint("", "192.168.1.1", 1234, "", nil) r.Register("foo", m1) r.Register("bar", m1) @@ -186,32 +145,21 @@ var _ = Describe("Registry", func() { Context("Lookup", func() { It("case insensitive lookup", func() { - m := &route.Endpoint{ - Host: "192.168.1.1", - Port: 1234, - } + m := route.NewEndpoint("", "192.168.1.1", 1234, "", nil) r.Register("foo", m) - b, ok := r.Lookup("foo") - Ω(ok).To(BeTrue()) - Ω(b.CanonicalAddr()).To(Equal("192.168.1.1:1234")) + p1 := r.Lookup("foo") + p2 := r.Lookup("FOO") + Ω(p1).To(Equal(p2)) - b, ok = r.Lookup("FOO") - Ω(ok).To(BeTrue()) - Ω(b.CanonicalAddr()).To(Equal("192.168.1.1:1234")) + iter := p1.Endpoints("") + Ω(iter.Next().CanonicalAddr()).To(Equal("192.168.1.1:1234")) }) It("selects one of the routes", func() { - m1 := &route.Endpoint{ - Host: "192.168.1.2", - Port: 1234, - } - - m2 := &route.Endpoint{ - Host: "192.168.1.2", - Port: 1235, - } + m1 := route.NewEndpoint("", "192.168.1.1", 1234, "", nil) + m2 := route.NewEndpoint("", "192.168.1.1", 1235, "", nil) r.Register("bar", m1) r.Register("barr", m1) @@ -222,13 +170,19 @@ var _ = Describe("Registry", func() { Ω(r.NumUris()).To(Equal(2)) Ω(r.NumEndpoints()).To(Equal(2)) - b, ok := r.Lookup("bar") - Ω(ok).To(BeTrue()) - Ω(b.Host).To(Equal("192.168.1.2")) - Ω(b.Port == m1.Port || b.Port == m2.Port).To(BeTrue()) + p := r.Lookup("bar") + Ω(p).ShouldNot(BeNil()) + e := p.Endpoints("").Next() + Ω(e).ShouldNot(BeNil()) + Ω(e.CanonicalAddr()).To(MatchRegexp("192.168.1.1:123[4|5]")) }) }) - Context("PruneStaleDropelts", func() { + Context("Prunes Stale Droplets", func() { + + AfterEach(func() { + r.StopPruningCycle() + }) + It("removes stale droplets", func() { r.Register("foo", fooEndpoint) r.Register("fooo", fooEndpoint) @@ -239,18 +193,15 @@ var _ = Describe("Registry", func() { Ω(r.NumUris()).To(Equal(4)) Ω(r.NumEndpoints()).To(Equal(2)) - time.Sleep(configObj.DropletStaleThreshold + 1*time.Millisecond) - r.PruneStaleDroplets() + r.StartPruningCycle() + time.Sleep(configObj.PruneStaleDropletsInterval + 10*time.Millisecond) Ω(r.NumUris()).To(Equal(0)) Ω(r.NumEndpoints()).To(Equal(0)) }) It("skips fresh droplets", func() { - endpoint := &route.Endpoint{ - Host: "192.168.1.1", - Port: 1234, - } + endpoint := route.NewEndpoint("", "192.168.1.1", 1234, "", nil) r.Register("foo", endpoint) r.Register("bar", endpoint) @@ -260,21 +211,21 @@ var _ = Describe("Registry", func() { Ω(r.NumUris()).To(Equal(2)) Ω(r.NumEndpoints()).To(Equal(1)) - time.Sleep(configObj.DropletStaleThreshold + 1*time.Millisecond) + r.StartPruningCycle() + time.Sleep(configObj.PruneStaleDropletsInterval + 10*time.Millisecond) r.Register("foo", endpoint) - r.PruneStaleDroplets() - + r.StopPruningCycle() Ω(r.NumUris()).To(Equal(1)) Ω(r.NumEndpoints()).To(Equal(1)) - foundEndpoint, found := r.Lookup("foo") - Ω(found).To(BeTrue()) - Ω(foundEndpoint).To(Equal(endpoint)) + p := r.Lookup("foo") + Ω(p).ShouldNot(BeNil()) + Ω(p.Endpoints("").Next()).To(Equal(endpoint)) - _, found = r.Lookup("bar") - Ω(found).To(BeFalse()) + p = r.Lookup("bar") + Ω(p).Should(BeNil()) }) It("disables pruning when NATS is unavailable", func() { @@ -287,10 +238,9 @@ var _ = Describe("Registry", func() { Ω(r.NumUris()).To(Equal(4)) Ω(r.NumEndpoints()).To(Equal(2)) - time.Sleep(configObj.DropletStaleThreshold + 1*time.Millisecond) - messageBus.OnPing(func() bool { return false }) - r.PruneStaleDroplets() + r.StartPruningCycle() + time.Sleep(configObj.PruneStaleDropletsInterval + 10*time.Millisecond) Ω(r.NumUris()).To(Equal(4)) Ω(r.NumEndpoints()).To(Equal(2)) @@ -313,25 +263,55 @@ var _ = Describe("Registry", func() { return false }) - go r.PruneStaleDroplets() + r.StartPruningCycle() <-barrier - _, ok := r.Lookup("foo") + p := r.Lookup("foo") barrier <- struct{}{} - Ω(ok).To(BeTrue()) + Ω(p).ShouldNot(BeNil()) }) }) - It("marshals", func() { - m := &route.Endpoint{ - Host: "192.168.1.1", - Port: 1234, - } + Context("Varz data", func() { + It("NumUris", func() { + r.Register("bar", barEndpoint) + r.Register("baar", barEndpoint) + + Ω(r.NumUris()).To(Equal(2)) + r.Register("foo", fooEndpoint) + + Ω(r.NumUris()).To(Equal(3)) + }) + + It("NumEndpoints", func() { + r.Register("bar", barEndpoint) + r.Register("baar", barEndpoint) + + Ω(r.NumEndpoints()).To(Equal(1)) + + r.Register("foo", fooEndpoint) + + Ω(r.NumEndpoints()).To(Equal(2)) + }) + + It("TimeOfLastUpdate", func() { + start := time.Now() + r.Register("bar", barEndpoint) + t := r.TimeOfLastUpdate() + end := time.Now() + + Ω(start.Before(t)).Should(BeTrue()) + Ω(end.After(t)).Should(BeTrue()) + }) + }) + + It("marshals", func() { + m := route.NewEndpoint("", "192.168.1.1", 1234, "", nil) r.Register("foo", m) + marshalled, err := json.Marshal(r) Ω(err).NotTo(HaveOccurred()) - Ω(string(marshalled)).To(Equal(`{"foo":["192.168.1.1:1234"]}`)) }) }) diff --git a/route/endpoint.go b/route/endpoint.go index 7cf5d5475..da3abd89f 100644 --- a/route/endpoint.go +++ b/route/endpoint.go @@ -3,37 +3,41 @@ package route import ( "encoding/json" "fmt" - "sync" ) -type Endpoint struct { - sync.Mutex +func NewEndpoint(appId, host string, port uint16, privateInstanceId string, + tags map[string]string) *Endpoint { + return &Endpoint{ + ApplicationId: appId, + addr: fmt.Sprintf("%s:%d", host, port), + Tags: tags, + PrivateInstanceId: privateInstanceId, + } +} +type Endpoint struct { ApplicationId string - Host string - Port uint16 + addr string Tags map[string]string PrivateInstanceId string } func (e *Endpoint) MarshalJSON() ([]byte, error) { - return json.Marshal(e.CanonicalAddr()) + return json.Marshal(e.addr) } func (e *Endpoint) CanonicalAddr() string { - return fmt.Sprintf("%s:%d", e.Host, e.Port) + return e.addr } func (e *Endpoint) ToLogData() interface{} { return struct { ApplicationId string - Host string - Port uint16 + Addr string Tags map[string]string }{ e.ApplicationId, - e.Host, - e.Port, + e.addr, e.Tags, } } diff --git a/route/endpoint_iterator_test.go b/route/endpoint_iterator_test.go new file mode 100644 index 000000000..1094c803a --- /dev/null +++ b/route/endpoint_iterator_test.go @@ -0,0 +1,198 @@ +package route_test + +import ( + "time" + . "github.com/cloudfoundry/gorouter/route" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("EndpointIterator", func() { + var pool *Pool + + BeforeEach(func() { + pool = NewPool(2 * time.Minute) + }) + + Describe("Next", func() { + It("performs round-robin through the endpoints", func() { + e1 := NewEndpoint("", "1.2.3.4", 5678, "", nil) + e2 := NewEndpoint("", "5.6.7.8", 1234, "", nil) + e3 := NewEndpoint("", "1.2.7.8", 1234, "", nil) + endpoints := []*Endpoint{e1, e2, e3} + + for _, e := range endpoints { + pool.Put(e) + } + + counts := make([]int, len(endpoints)) + + iter := pool.Endpoints("") + + loops := 50 + for i := 0; i < len(endpoints)*loops; i += 1 { + n := iter.Next() + for j, e := range endpoints { + if e == n { + counts[j]++ + break + } + } + } + + for i := 0; i < len(endpoints); i++ { + Ω(counts[i]).To(Equal(loops)) + } + }) + + It("returns nil when no endpoints exist", func() { + iter := pool.Endpoints("") + e := iter.Next() + Ω(e).Should(BeNil()) + }) + + It("finds the initial endpoint by private id", func() { + b := NewEndpoint("", "1.2.3.4", 1235, "b", nil) + pool.Put(NewEndpoint("", "1.2.3.4", 1234, "a", nil)) + pool.Put(b) + pool.Put(NewEndpoint("", "1.2.3.4", 1236, "c", nil)) + pool.Put(NewEndpoint("", "1.2.3.4", 1237, "d", nil)) + + for i := 0; i < 10; i++ { + iter := pool.Endpoints(b.PrivateInstanceId) + e := iter.Next() + Ω(e).ShouldNot(BeNil()) + Ω(e.PrivateInstanceId).To(Equal(b.PrivateInstanceId)) + } + }) + + It("finds the initial endpoint by canonical addr", func() { + b := NewEndpoint("", "1.2.3.4", 1235, "b", nil) + pool.Put(NewEndpoint("", "1.2.3.4", 1234, "a", nil)) + pool.Put(b) + pool.Put(NewEndpoint("", "1.2.3.4", 1236, "c", nil)) + pool.Put(NewEndpoint("", "1.2.3.4", 1237, "d", nil)) + + for i := 0; i < 10; i++ { + iter := pool.Endpoints(b.CanonicalAddr()) + e := iter.Next() + Ω(e).ShouldNot(BeNil()) + Ω(e.CanonicalAddr()).To(Equal(b.CanonicalAddr())) + } + }) + + It("finds when there are multiple private ids", func() { + endpointFoo := NewEndpoint("", "1.2.3.4", 1234, "foo", nil) + endpointBar := NewEndpoint("", "5.6.7.8", 5678, "bar", nil) + + pool.Put(endpointFoo) + pool.Put(endpointBar) + + iter := pool.Endpoints(endpointFoo.PrivateInstanceId) + foundEndpoint := iter.Next() + Ω(foundEndpoint).ToNot(BeNil()) + Ω(foundEndpoint).To(Equal(endpointFoo)) + + iter = pool.Endpoints(endpointBar.PrivateInstanceId) + foundEndpoint = iter.Next() + Ω(foundEndpoint).ToNot(BeNil()) + Ω(foundEndpoint).To(Equal(endpointBar)) + }) + + It("returns the next available endpoint when the initial is not found", func() { + eFoo := NewEndpoint("", "1.2.3.4", 1234, "foo", nil) + pool.Put(eFoo) + + iter := pool.Endpoints("bogus") + e := iter.Next() + Ω(e).ShouldNot(BeNil()) + Ω(e).Should(Equal(eFoo)) + }) + + It("finds the correct endpoint when private ids change", func() { + endpointFoo := NewEndpoint("", "1.2.3.4", 1234, "foo", nil) + pool.Put(endpointFoo) + + iter := pool.Endpoints(endpointFoo.PrivateInstanceId) + foundEndpoint := iter.Next() + Ω(foundEndpoint).ShouldNot(BeNil()) + Ω(foundEndpoint).Should(Equal(endpointFoo)) + + endpointBar := NewEndpoint("", "1.2.3.4", 1234, "bar", nil) + pool.Put(endpointBar) + + iter = pool.Endpoints("foo") + foundEndpoint = iter.Next() + Ω(foundEndpoint).ShouldNot(Equal(endpointFoo)) + + iter = pool.Endpoints("bar") + Ω(foundEndpoint).Should(Equal(endpointBar)) + }) + }) + + Describe("Failed", func() { + It("skips failed endpoints", func() { + e1 := NewEndpoint("", "1.2.3.4", 5678, "", nil) + e2 := NewEndpoint("", "5.6.7.8", 1234, "", nil) + pool.Put(e1) + pool.Put(e2) + + iter := pool.Endpoints("") + n := iter.Next() + Ω(n).ShouldNot(BeNil()) + + iter.EndpointFailed() + + nn1 := iter.Next() + nn2 := iter.Next() + Ω(nn1).ShouldNot(BeNil()) + Ω(nn2).ShouldNot(BeNil()) + Ω(nn1).ShouldNot(Equal(n)) + Ω(nn1).Should(Equal(nn2)) + }) + + It("resets when all endpoints are failed", func() { + e1 := NewEndpoint("", "1.2.3.4", 5678, "", nil) + e2 := NewEndpoint("", "5.6.7.8", 1234, "", nil) + pool.Put(e1) + pool.Put(e2) + + iter := pool.Endpoints("") + n1 := iter.Next() + iter.EndpointFailed() + n2 := iter.Next() + iter.EndpointFailed() + Ω(n1).ShouldNot(Equal(n2)) + + n1 = iter.Next() + n2 = iter.Next() + Ω(n1).ShouldNot(Equal(n2)) + }) + + It("resets failed endpoints after exceeding failure duration", func() { + pool = NewPool(50 * time.Millisecond) + + e1 := NewEndpoint("", "1.2.3.4", 5678, "", nil) + e2 := NewEndpoint("", "5.6.7.8", 1234, "", nil) + pool.Put(e1) + pool.Put(e2) + + iter := pool.Endpoints("") + n1 := iter.Next() + n2 := iter.Next() + Ω(n1).ShouldNot(Equal(n2)) + + iter.EndpointFailed() + + n1 = iter.Next() + n2 = iter.Next() + Ω(n1).Should(Equal(n2)) + + time.Sleep(50 * time.Millisecond) + + n1 = iter.Next() + n2 = iter.Next() + Ω(n1).ShouldNot(Equal(n2)) + }) + }) +}) diff --git a/route/pool.go b/route/pool.go index 1986510ac..a986674f1 100644 --- a/route/pool.go +++ b/route/pool.go @@ -3,65 +3,266 @@ package route import ( "encoding/json" "math/rand" + "sync" + "time" ) +type EndpointIterator interface { + Next() *Endpoint + EndpointFailed() +} + +type endpointIterator struct { + pool *Pool + + initialEndpoint string + lastEndpoint *Endpoint +} + +type endpointElem struct { + endpoint *Endpoint + index int + updated time.Time + failedAt *time.Time +} + type Pool struct { - endpoints map[string]*Endpoint + lock sync.Mutex + endpoints []*endpointElem + index map[string]*endpointElem + + retryAfterFailure time.Duration + nextIdx int } -func NewPool() *Pool { +func NewPool(retryAfterFailure time.Duration) *Pool { return &Pool{ - endpoints: make(map[string]*Endpoint), + endpoints: make([]*endpointElem, 0, 1), + index: make(map[string]*endpointElem), + retryAfterFailure: retryAfterFailure, + nextIdx: -1, } } -func (p *Pool) Add(endpoint *Endpoint) { - p.endpoints[endpoint.CanonicalAddr()] = endpoint +func (p *Pool) Put(endpoint *Endpoint) bool { + p.lock.Lock() + defer p.lock.Unlock() + + e, found := p.index[endpoint.CanonicalAddr()] + if found { + if e.endpoint == endpoint { + return false + } + + oldEndpoint := e.endpoint + e.endpoint = endpoint + + if oldEndpoint.PrivateInstanceId != endpoint.PrivateInstanceId { + delete(p.index, oldEndpoint.PrivateInstanceId) + p.index[endpoint.PrivateInstanceId] = e + } + } else { + e = &endpointElem{ + endpoint: endpoint, + index: len(p.endpoints), + } + + p.endpoints = append(p.endpoints, e) + + p.index[endpoint.CanonicalAddr()] = e + p.index[endpoint.PrivateInstanceId] = e + } + + e.updated = time.Now() + + return !found } -func (p *Pool) Remove(endpoint *Endpoint) { - delete(p.endpoints, endpoint.CanonicalAddr()) +func (p *Pool) Remove(endpoint *Endpoint) bool { + var e *endpointElem + + p.lock.Lock() + l := len(p.endpoints) + if l > 0 { + e = p.index[endpoint.CanonicalAddr()] + if e != nil { + p.removeEndpoint(e) + } + } + p.lock.Unlock() + + return e != nil } -func (p *Pool) Sample() (*Endpoint, bool) { - if len(p.endpoints) == 0 { - return nil, false +func (p *Pool) removeEndpoint(e *endpointElem) { + i := e.index + es := p.endpoints + last := len(es) + // re-ordering delete + es[last-1], es[i], es = nil, es[last-1], es[:last-1] + if i < last-1 { + es[i].index = i } + p.endpoints = es + + delete(p.index, e.endpoint.CanonicalAddr()) + delete(p.index, e.endpoint.PrivateInstanceId) +} + +func (p *Pool) Endpoints(initial string) EndpointIterator { + return newEndpointIterator(p, initial) +} - index := rand.Intn(len(p.endpoints)) +func (p *Pool) next() *Endpoint { + p.lock.Lock() + defer p.lock.Unlock() - ticker := 0 - for _, endpoint := range p.endpoints { - if ticker == index { - return endpoint, true + last := len(p.endpoints) + if last == 0 { + return nil + } + + if p.nextIdx == -1 { + p.nextIdx = rand.Intn(last) + } else if p.nextIdx >= last { + p.nextIdx = 0 + } + + startIdx := p.nextIdx + curIdx := startIdx + for { + e := p.endpoints[curIdx] + + curIdx++ + if curIdx == last { + curIdx = 0 + } + + if e.failedAt != nil { + curTime := time.Now() + if curTime.Sub(*e.failedAt) > p.retryAfterFailure { + // exipired failure window + e.failedAt = nil + } } - ticker += 1 + if e.failedAt == nil { + p.nextIdx = curIdx + return e.endpoint + } + + if curIdx == startIdx { + // all endpoints are marked failed so reset everything to available + for _, e2 := range p.endpoints { + e2.failedAt = nil + } + } + } +} + +func (p *Pool) findById(id string) *Endpoint { + var endpoint *Endpoint + p.lock.Lock() + e := p.index[id] + if e != nil { + endpoint = e.endpoint } + p.lock.Unlock() + + return endpoint +} + +func (p *Pool) IsEmpty() bool { + p.lock.Lock() + l := len(p.endpoints) + p.lock.Unlock() - panic("unreachable") + return l == 0 } -func (p *Pool) FindByPrivateInstanceId(id string) (*Endpoint, bool) { - for _, endpoint := range p.endpoints { - if endpoint.PrivateInstanceId == id { - return endpoint, true +func (p *Pool) PruneBefore(t time.Time) { + p.lock.Lock() + + last := len(p.endpoints) + for i := 0; i < last; { + e := p.endpoints[i] + if e.updated.Before(t) { + p.removeEndpoint(e) + last-- + } else { + i++ } } - return nil, false + p.lock.Unlock() } -func (p *Pool) IsEmpty() bool { - return len(p.endpoints) == 0 +func (p *Pool) MarkUpdated(t time.Time) { + p.lock.Lock() + for _, e := range p.endpoints { + e.updated = t + } + p.lock.Unlock() } -func (p *Pool) MarshalJSON() ([]byte, error) { - addresses := []string{} +func (p *Pool) endpointFailed(endpoint *Endpoint) { + p.lock.Lock() + e := p.index[endpoint.CanonicalAddr()] + if e != nil { + e.failed() + } + p.lock.Unlock() +} + +func (p *Pool) Each(f func(endpoint *Endpoint)) { + p.lock.Lock() + for _, e := range p.endpoints { + f(e.endpoint) + } + p.lock.Unlock() +} - for addr, _ := range p.endpoints { - addresses = append(addresses, addr) +func (p *Pool) MarshalJSON() ([]byte, error) { + p.lock.Lock() + addresses := make([]string, 0, len(p.endpoints)) + for _, e := range p.endpoints { + addresses = append(addresses, e.endpoint.addr) } + p.lock.Unlock() return json.Marshal(addresses) } + +func newEndpointIterator(p *Pool, initial string) EndpointIterator { + return &endpointIterator{ + pool: p, + initialEndpoint: initial, + } +} + +func (i *endpointIterator) Next() *Endpoint { + var e *Endpoint + if i.initialEndpoint != "" { + e = i.pool.findById(i.initialEndpoint) + i.initialEndpoint = "" + } + + if e == nil { + e = i.pool.next() + } + + i.lastEndpoint = e + + return e +} + +func (i *endpointIterator) EndpointFailed() { + if i.lastEndpoint != nil { + i.pool.endpointFailed(i.lastEndpoint) + } +} + +func (e *endpointElem) failed() { + t := time.Now() + e.failedAt = &t +} diff --git a/route/pool_test.go b/route/pool_test.go index 319d95a1b..2d650e56a 100644 --- a/route/pool_test.go +++ b/route/pool_test.go @@ -4,147 +4,148 @@ import ( . "github.com/cloudfoundry/gorouter/route" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + "time" ) -var _ = Describe("Route", func() { - Context("Add", func() { +var _ = Describe("Pool", func() { + var pool *Pool + + BeforeEach(func() { + pool = NewPool(2 * time.Minute) + }) + + Context("Put", func() { It("adds endpoints", func() { - pool := NewPool() endpoint := &Endpoint{} - pool.Add(endpoint) - foundEndpoint, found := pool.Sample() - Ω(found).To(BeTrue()) - Ω(foundEndpoint).To(Equal(endpoint)) + b := pool.Put(endpoint) + Ω(b).Should(BeTrue()) }) It("handles duplicate endpoints", func() { - pool := NewPool() - endpoint := &Endpoint{} - pool.Add(endpoint) - pool.Add(endpoint) - - foundEndpoint, found := pool.Sample() - Ω(found).To(BeTrue()) - Ω(foundEndpoint).To(Equal(endpoint)) - - pool.Remove(endpoint) - - _, found = pool.Sample() - Ω(found).To(BeFalse()) + pool.Put(endpoint) + b := pool.Put(endpoint) + Ω(b).Should(BeFalse()) }) It("handles equivalent (duplicate) endpoints", func() { - pool := NewPool() - - endpoint1 := &Endpoint{Host: "1.2.3.4", Port: 5678} - endpoint2 := &Endpoint{Host: "1.2.3.4", Port: 5678} + endpoint1 := NewEndpoint("", "1.2.3.4", 5678, "", nil) + endpoint2 := NewEndpoint("", "1.2.3.4", 5678, "", nil) - pool.Add(endpoint1) - pool.Add(endpoint2) - - _, found := pool.Sample() - Ω(found).To(BeTrue()) - - pool.Remove(endpoint1) - - _, found = pool.Sample() - Ω(found).To(BeFalse()) + pool.Put(endpoint1) + Ω(pool.Put(endpoint2)).Should(BeFalse()) }) }) + Context("Remove", func() { It("removes endpoints", func() { - pool := NewPool() - endpoint := &Endpoint{} + pool.Put(endpoint) - pool.Add(endpoint) - - foundEndpoint, found := pool.Sample() - Ω(found).To(BeTrue()) - Ω(foundEndpoint).To(Equal(endpoint)) + b := pool.Remove(endpoint) + Ω(b).Should(BeTrue()) + Ω(pool.IsEmpty()).Should(BeTrue()) + }) - pool.Remove(endpoint) + It("fails to remove an endpoint that doesn't exist", func() { + endpoint := &Endpoint{} - _, found = pool.Sample() - Ω(found).To(BeFalse()) + b := pool.Remove(endpoint) + Ω(b).Should(BeFalse()) }) - }) + Context("IsEmpty", func() { It("starts empty", func() { - Ω(NewPool().IsEmpty()).To(BeTrue()) + Ω(pool.IsEmpty()).To(BeTrue()) }) - It("empty after removing everything", func() { - pool := NewPool() - + It("not empty after adding an endpoint", func() { endpoint := &Endpoint{} + pool.Put(endpoint) - pool.Add(endpoint) - - Ω(pool.IsEmpty()).To(BeFalse()) + Ω(pool.IsEmpty()).Should(BeFalse()) + }) + It("is empty after removing everything", func() { + endpoint := &Endpoint{} + pool.Put(endpoint) pool.Remove(endpoint) Ω(pool.IsEmpty()).To(BeTrue()) }) }) - It("finds by private instance id", func() { - pool := NewPool() + Context("PruneBefore", func() { + It("prunes endpoints that haven't been updated", func() { + e1 := NewEndpoint("", "1.2.3.4", 5678, "", nil) + e2 := NewEndpoint("", "5.6.7.8", 1234, "", nil) + pool.Put(e1) + pool.Put(e2) - endpointFoo := &Endpoint{Host: "1.2.3.4", Port: 1234, PrivateInstanceId: "foo"} - endpointBar := &Endpoint{Host: "5.6.7.8", Port: 5678, PrivateInstanceId: "bar"} - - pool.Add(endpointFoo) - pool.Add(endpointBar) + t := time.Now().Add(1 * time.Second) + pool.PruneBefore(t) + Ω(pool.IsEmpty()).Should(BeTrue()) + }) - foundEndpoint, found := pool.FindByPrivateInstanceId("foo") - Ω(found).To(BeTrue()) - Ω(foundEndpoint).To(Equal(endpointFoo)) + It("does not prune updated endpoints", func() { + e1 := NewEndpoint("", "1.2.3.4", 5678, "", nil) + e2 := NewEndpoint("", "5.6.7.8", 1234, "", nil) + pool.Put(e1) + pool.Put(e2) - foundEndpoint, found = pool.FindByPrivateInstanceId("bar") - Ω(found).To(BeTrue()) - Ω(foundEndpoint).To(Equal(endpointBar)) + t := time.Now().Add(-1 * time.Second) + pool.PruneBefore(t) + Ω(pool.IsEmpty()).Should(BeFalse()) - _, found = pool.FindByPrivateInstanceId("quux") - Ω(found).To(BeFalse()) + iter := pool.Endpoints("") + n1 := iter.Next() + n2 := iter.Next() + Ω(n1).ShouldNot(Equal(n2)) + }) }) - It("Sample is randomish", func() { - pool := NewPool() - - endpoint1 := &Endpoint{Host: "1.2.3.4", Port: 5678} - endpoint2 := &Endpoint{Host: "5.6.7.8", Port: 1234} + Context("MarkUpdated", func() { + It("updates all endpoints", func() { + e1 := NewEndpoint("", "1.2.3.4", 5678, "", nil) - pool.Add(endpoint1) - pool.Add(endpoint2) + pool.Put(e1) - var occurrences1, occurrences2 int + t := time.Time{}.Add(1 * time.Second) + pool.PruneBefore(t) + Ω(pool.IsEmpty()).Should(BeFalse()) - for i := 0; i < 200; i += 1 { - foundEndpoint, _ := pool.Sample() - if foundEndpoint == endpoint1 { - occurrences1 += 1 - } else { - occurrences2 += 1 - } - } + pool.MarkUpdated(t) + pool.PruneBefore(t) + Ω(pool.IsEmpty()).Should(BeFalse()) - Ω(occurrences1).ToNot(BeZero()) - Ω(occurrences2).ToNot(BeZero()) + pool.PruneBefore(t.Add(1 * time.Microsecond)) + Ω(pool.IsEmpty()).Should(BeTrue()) + }) + }) - // they should be arbitrarily close - Ω(occurrences1 - occurrences2).To(BeNumerically("~", 0, 50)) + Context("Each", func() { + It("applies a function to each endpoint", func() { + e1 := NewEndpoint("", "1.2.3.4", 5678, "", nil) + e2 := NewEndpoint("", "5.6.7.8", 1234, "", nil) + pool.Put(e1) + pool.Put(e2) + + endpoints := make(map[string]*Endpoint) + pool.Each(func(e *Endpoint) { + endpoints[e.CanonicalAddr()] = e + }) + Ω(endpoints).Should(HaveLen(2)) + Ω(endpoints[e1.CanonicalAddr()]).Should(Equal(e1)) + Ω(endpoints[e2.CanonicalAddr()]).Should(Equal(e2)) + }) }) It("marshals json", func() { - pool := NewPool() - - pool.Add(&Endpoint{Host: "1.2.3.4", Port: 5678}) + e := NewEndpoint("", "1.2.3.4", 5678, "", nil) + pool.Put(e) json, err := pool.MarshalJSON() Ω(err).ToNot(HaveOccurred()) diff --git a/router/helper_test.go b/router/helper_test.go index 4801e2325..7a57e958c 100644 --- a/router/helper_test.go +++ b/router/helper_test.go @@ -8,15 +8,22 @@ import ( "time" ) -func waitMsgReceived(registry *registry.CFRegistry, app *test.TestApp, expectedToBeFound bool, timeout time.Duration) bool { +func waitMsgReceived(registry *registry.RouteRegistry, app *test.TestApp, expectedToBeFound bool, timeout time.Duration) bool { interval := time.Millisecond * 50 repetitions := int(timeout / interval) for j := 0; j < repetitions; j++ { + if j > 0 { + time.Sleep(interval) + } + received := true for _, url := range app.Urls() { - _, ok := registry.Lookup(url) - if ok != expectedToBeFound { + pool := registry.Lookup(url) + if expectedToBeFound && pool == nil { + received = false + break + } else if !expectedToBeFound && pool != nil { received = false break } @@ -24,17 +31,16 @@ func waitMsgReceived(registry *registry.CFRegistry, app *test.TestApp, expectedT if received { return true } - time.Sleep(interval) } return false } -func waitAppRegistered(registry *registry.CFRegistry, app *test.TestApp, timeout time.Duration) bool { +func waitAppRegistered(registry *registry.RouteRegistry, app *test.TestApp, timeout time.Duration) bool { return waitMsgReceived(registry, app, true, timeout) } -func waitAppUnregistered(registry *registry.CFRegistry, app *test.TestApp, timeout time.Duration) bool { +func waitAppUnregistered(registry *registry.RouteRegistry, app *test.TestApp, timeout time.Duration) bool { return waitMsgReceived(registry, app, false, timeout) } diff --git a/router/registry_message.go b/router/registry_message.go index 9bb1bbfda..30532d4a6 100644 --- a/router/registry_message.go +++ b/router/registry_message.go @@ -14,12 +14,6 @@ type registryMessage struct { PrivateInstanceId string `json:"private_instance_id"` } -func (registryMessage *registryMessage) makeEndpoint() *route.Endpoint { - return &route.Endpoint{ - Host: registryMessage.Host, - Port: registryMessage.Port, - ApplicationId: registryMessage.App, - Tags: registryMessage.Tags, - PrivateInstanceId: registryMessage.PrivateInstanceId, - } +func (rm *registryMessage) makeEndpoint() *route.Endpoint { + return route.NewEndpoint(rm.App, rm.Host, rm.Port, rm.PrivateInstanceId, rm.Tags) } diff --git a/router/router.go b/router/router.go index ac7e3d7e0..dcb36f132 100644 --- a/router/router.go +++ b/router/router.go @@ -25,7 +25,7 @@ type Router struct { config *config.Config proxy proxy.Proxy mbusClient *yagnats.Client - registry *registry.CFRegistry + registry *registry.RouteRegistry varz varz.Varz component *vcap.VcapComponent @@ -34,7 +34,7 @@ type Router struct { logger *steno.Logger } -func NewRouter(cfg *config.Config, p proxy.Proxy, mbusClient *yagnats.Client, r *registry.CFRegistry, v varz.Varz, +func NewRouter(cfg *config.Config, p proxy.Proxy, mbusClient *yagnats.Client, r *registry.RouteRegistry, v varz.Varz, logCounter *vcap.LogCounter) (*Router, error) { var host string @@ -260,9 +260,9 @@ func (r *Router) greetMessage() ([]byte, error) { } d := vcap.RouterStart{ - uuid, - []string{host}, - r.config.StartResponseDelayIntervalInSeconds, + Id: uuid, + Hosts: []string{host}, + MinimumRegisterIntervalInSeconds: r.config.StartResponseDelayIntervalInSeconds, } return json.Marshal(d) diff --git a/router/router_drain_test.go b/router/router_drain_test.go index 5cd608b17..cc9fb58a4 100644 --- a/router/router_drain_test.go +++ b/router/router_drain_test.go @@ -25,7 +25,7 @@ var _ = Describe("Router", func() { var config *cfg.Config var mbusClient *yagnats.Client - var registry *rregistry.CFRegistry + var registry *rregistry.RouteRegistry var varz vvarz.Varz var router *Router var natsPort uint16 @@ -42,7 +42,7 @@ var _ = Describe("Router", func() { config.EndpointTimeout = 5 * time.Second mbusClient = natsRunner.MessageBus.(*yagnats.Client) - registry = rregistry.NewCFRegistry(config, mbusClient) + registry = rregistry.NewRouteRegistry(config, mbusClient) varz = vvarz.NewVarz(registry) logcounter := vcap.NewLogCounter() proxy := proxy.NewProxy(proxy.ProxyArgs{ diff --git a/router/router_test.go b/router/router_test.go index 1f28c227e..cf6bf14ba 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -33,7 +33,7 @@ var _ = Describe("Router", func() { var config *cfg.Config var mbusClient *yagnats.Client - var registry *rregistry.CFRegistry + var registry *rregistry.RouteRegistry var varz vvarz.Varz var router *Router @@ -48,7 +48,7 @@ var _ = Describe("Router", func() { config = test_util.SpecConfig(natsPort, statusPort, proxyPort) mbusClient = natsRunner.MessageBus.(*yagnats.Client) - registry = rregistry.NewCFRegistry(config, mbusClient) + registry = rregistry.NewRouteRegistry(config, mbusClient) varz = vvarz.NewVarz(registry) logcounter := vcap.NewLogCounter() proxy := proxy.NewProxy(proxy.ProxyArgs{ @@ -157,15 +157,14 @@ var _ = Describe("Router", func() { app1.Listen() Ω(waitAppRegistered(registry, app1, time.Second*1)).To(BeTrue()) - time.Sleep(2 * time.Second) + time.Sleep(100 * time.Millisecond) initialUpdateTime := fetchRecursively(readVarz(varz), "ms_since_last_registry_update").(float64) - // initialUpdateTime should be roughly 2 seconds. app2 := test.NewGreetApp([]route.Uri{"test2.vcap.me"}, config.Port, mbusClient, nil) app2.Listen() Ω(waitAppRegistered(registry, app2, time.Second*1)).To(BeTrue()) - // updateTime should be roughly 0 seconds + // updateTime should be after initial update time updateTime := fetchRecursively(readVarz(varz), "ms_since_last_registry_update").(float64) Ω(updateTime).To(BeNumerically("<", initialUpdateTime)) }) @@ -422,7 +421,7 @@ func fetchRecursively(x interface{}, s ...string) interface{} { return x } -func verify_health_z(host string, r *rregistry.CFRegistry) { +func verify_health_z(host string, r *rregistry.RouteRegistry) { var req *http.Request path := "/healthz" diff --git a/scripts/test b/scripts/test index 15baf4dc4..4f9ab6a37 100755 --- a/scripts/test +++ b/scripts/test @@ -2,6 +2,16 @@ set -e -x -u +function printStatus { + if [ $? -eq 0 ]; then + echo -e "\nSWEET SUITE SUCCESS" + else + echo -e "\nSUITE FAILURE" + fi + } + +trap printStatus EXIT + . $(dirname $0)/gorequired #Download & Install gnatsd into GOPATH (or use pre-installed version) diff --git a/varz/varz.go b/varz/varz.go index f13c636f7..3c645187b 100644 --- a/varz/varz.go +++ b/varz/varz.go @@ -170,13 +170,13 @@ type Varz interface { type RealVarz struct { sync.Mutex - r *registry.CFRegistry + r *registry.RouteRegistry activeApps *stats.ActiveApps topApps *stats.TopApps varz } -func NewVarz(r *registry.CFRegistry) Varz { +func NewVarz(r *registry.RouteRegistry) Varz { x := &RealVarz{r: r} x.activeApps = stats.NewActiveApps() @@ -227,18 +227,16 @@ func (x *RealVarz) ActiveApps() *stats.ActiveApps { return x.activeApps } -func (x *RealVarz) CaptureBadRequest(req *http.Request) { +func (x *RealVarz) CaptureBadRequest(*http.Request) { x.Lock() - defer x.Unlock() - x.BadRequests++ + x.Unlock() } -func (x *RealVarz) CaptureBadGateway(req *http.Request) { +func (x *RealVarz) CaptureBadGateway(*http.Request) { x.Lock() - defer x.Unlock() - x.BadGateways++ + x.Unlock() } func (x *RealVarz) CaptureAppStats(b *route.Endpoint, t time.Time) { @@ -250,7 +248,6 @@ func (x *RealVarz) CaptureAppStats(b *route.Endpoint, t time.Time) { func (x *RealVarz) CaptureRoutingRequest(b *route.Endpoint, req *http.Request) { x.Lock() - defer x.Unlock() var t string var ok bool @@ -261,11 +258,12 @@ func (x *RealVarz) CaptureRoutingRequest(b *route.Endpoint, req *http.Request) { } x.varz.All.CaptureRequest() + + x.Unlock() } func (x *RealVarz) CaptureRoutingResponse(endpoint *route.Endpoint, response *http.Response, startedAt time.Time, duration time.Duration) { x.Lock() - defer x.Unlock() var tags string var ok bool @@ -277,6 +275,8 @@ func (x *RealVarz) CaptureRoutingResponse(endpoint *route.Endpoint, response *ht x.CaptureAppStats(endpoint, startedAt) x.varz.All.CaptureResponse(response, duration) + + x.Unlock() } func transform(x interface{}, y map[string]interface{}) error { diff --git a/varz/varz_test.go b/varz/varz_test.go index 154a5ca63..4613a4922 100644 --- a/varz/varz_test.go +++ b/varz/varz_test.go @@ -17,10 +17,10 @@ import ( var _ = Describe("Varz", func() { var Varz Varz - var Registry *registry.CFRegistry + var Registry *registry.RouteRegistry BeforeEach(func() { - Registry = registry.NewCFRegistry(config.DefaultConfig(), fakeyagnats.New()) + Registry = registry.NewRouteRegistry(config.DefaultConfig(), fakeyagnats.New()) Varz = NewVarz(Registry) }) @@ -72,13 +72,7 @@ var _ = Describe("Varz", func() { It("has urls", func() { Ω(findValue(Varz, "urls")).To(Equal(float64(0))) - var fooReg = &route.Endpoint{ - Host: "192.168.1.1", - Port: 1234, - Tags: map[string]string{}, - - ApplicationId: "12345", - } + var fooReg = route.NewEndpoint("12345", "192.168.1.1", 1234, "", map[string]string{}) // Add a route Registry.Register("foo.vcap.me", fooReg)