From 9f3975fa4c364737ad9ffa3f544f94ac16562277 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Wed, 27 Jul 2022 23:12:25 +0800 Subject: [PATCH 001/118] =?UTF-8?q?=E2=9C=A8=20v3:=20Move=20the=20client?= =?UTF-8?q?=20module=20to=20the=20client=20folder=20and=20fix=20the=20erro?= =?UTF-8?q?r?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client.go => client/client.go | 23 +-- client_test.go => client/client_test.go | 185 ++++++++++++------------ middleware/proxy/proxy_test.go | 7 +- 3 files changed, 109 insertions(+), 106 deletions(-) rename client.go => client/client.go (97%) rename client_test.go => client/client_test.go (84%) diff --git a/client.go b/client/client.go similarity index 97% rename from client.go rename to client/client.go index 9a8a74758a..f9bac47647 100644 --- a/client.go +++ b/client/client.go @@ -1,4 +1,4 @@ -package fiber +package client import ( "bytes" @@ -16,6 +16,7 @@ import ( "sync" "time" + "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/utils" "github.com/valyala/fasthttp" ) @@ -84,7 +85,7 @@ func Get(url string) *Agent { return defaultClient.Get(url) } // Get returns a agent with http method GET. func (c *Client) Get(url string) *Agent { - return c.createAgent(MethodGet, url) + return c.createAgent(fiber.MethodGet, url) } // Head returns a agent with http method HEAD. @@ -92,7 +93,7 @@ func Head(url string) *Agent { return defaultClient.Head(url) } // Head returns a agent with http method GET. func (c *Client) Head(url string) *Agent { - return c.createAgent(MethodHead, url) + return c.createAgent(fiber.MethodHead, url) } // Post sends POST request to the given url. @@ -100,7 +101,7 @@ func Post(url string) *Agent { return defaultClient.Post(url) } // Post sends POST request to the given url. func (c *Client) Post(url string) *Agent { - return c.createAgent(MethodPost, url) + return c.createAgent(fiber.MethodPost, url) } // Put sends PUT request to the given url. @@ -108,7 +109,7 @@ func Put(url string) *Agent { return defaultClient.Put(url) } // Put sends PUT request to the given url. func (c *Client) Put(url string) *Agent { - return c.createAgent(MethodPut, url) + return c.createAgent(fiber.MethodPut, url) } // Patch sends PATCH request to the given url. @@ -116,7 +117,7 @@ func Patch(url string) *Agent { return defaultClient.Patch(url) } // Patch sends PATCH request to the given url. func (c *Client) Patch(url string) *Agent { - return c.createAgent(MethodPatch, url) + return c.createAgent(fiber.MethodPatch, url) } // Delete sends DELETE request to the given url. @@ -124,7 +125,7 @@ func Delete(url string) *Agent { return defaultClient.Delete(url) } // Delete sends DELETE request to the given url. func (c *Client) Delete(url string) *Agent { - return c.createAgent(MethodDelete, url) + return c.createAgent(fiber.MethodDelete, url) } func (c *Client) createAgent(method, url string) *Agent { @@ -478,7 +479,7 @@ func (a *Agent) JSON(v any) *Agent { a.jsonEncoder = json.Marshal } - a.req.Header.SetContentType(MIMEApplicationJSON) + a.req.Header.SetContentType(fiber.MIMEApplicationJSON) if body, err := a.jsonEncoder(v); err != nil { a.errs = append(a.errs, err) @@ -491,7 +492,7 @@ func (a *Agent) JSON(v any) *Agent { // XML sends an XML request. func (a *Agent) XML(v any) *Agent { - a.req.Header.SetContentType(MIMEApplicationXML) + a.req.Header.SetContentType(fiber.MIMEApplicationXML) if body, err := xml.Marshal(v); err != nil { a.errs = append(a.errs, err) @@ -507,7 +508,7 @@ func (a *Agent) XML(v any) *Agent { // It is recommended obtaining args via AcquireArgs and release it // manually in performance-critical code. func (a *Agent) Form(args *Args) *Agent { - a.req.Header.SetContentType(MIMEApplicationForm) + a.req.Header.SetContentType(fiber.MIMEApplicationForm) if args != nil { a.req.SetBody(args.QueryString()) @@ -785,7 +786,7 @@ func (a *Agent) Bytes() (code int, body []byte, errs []error) { errs = append(errs, err) return } - } else if a.maxRedirectsCount > 0 && (string(req.Header.Method()) == MethodGet || string(req.Header.Method()) == MethodHead) { + } else if a.maxRedirectsCount > 0 && (string(req.Header.Method()) == fiber.MethodGet || string(req.Header.Method()) == fiber.MethodHead) { if err := a.HostClient.DoRedirects(req, resp, a.maxRedirectsCount); err != nil { errs = append(errs, err) return diff --git a/client_test.go b/client/client_test.go similarity index 84% rename from client_test.go rename to client/client_test.go index e0c682f8c6..a56dcb64c7 100644 --- a/client_test.go +++ b/client/client_test.go @@ -1,4 +1,4 @@ -package fiber +package client import ( "bytes" @@ -18,6 +18,7 @@ import ( "encoding/json" + "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/tlstest" "github.com/gofiber/fiber/v3/utils" "github.com/valyala/fasthttp/fasthttputil" @@ -28,9 +29,9 @@ func Test_Client_Invalid_URL(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c Ctx) error { + app.Get("/", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) @@ -65,9 +66,9 @@ func Test_Client_Get(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c Ctx) error { + app.Get("/", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) @@ -80,7 +81,7 @@ func Test_Client_Get(t *testing.T) { code, body, errs := a.String() - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, "example.com", body) utils.AssertEqual(t, 0, len(errs)) } @@ -91,9 +92,9 @@ func Test_Client_Head(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c Ctx) error { + app.Get("/", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) @@ -106,7 +107,7 @@ func Test_Client_Head(t *testing.T) { code, body, errs := a.String() - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, "", body) utils.AssertEqual(t, 0, len(errs)) } @@ -117,10 +118,10 @@ func Test_Client_Post(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Post("/", func(c Ctx) error { - return c.Status(StatusCreated). + app.Post("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusCreated). SendString(c.FormValue("foo")) }) @@ -138,7 +139,7 @@ func Test_Client_Post(t *testing.T) { code, body, errs := a.String() - utils.AssertEqual(t, StatusCreated, code) + utils.AssertEqual(t, fiber.StatusCreated, code) utils.AssertEqual(t, "bar", body) utils.AssertEqual(t, 0, len(errs)) @@ -151,9 +152,9 @@ func Test_Client_Put(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Put("/", func(c Ctx) error { + app.Put("/", func(c fiber.Ctx) error { return c.SendString(c.FormValue("foo")) }) @@ -171,7 +172,7 @@ func Test_Client_Put(t *testing.T) { code, body, errs := a.String() - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, "bar", body) utils.AssertEqual(t, 0, len(errs)) @@ -184,9 +185,9 @@ func Test_Client_Patch(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Patch("/", func(c Ctx) error { + app.Patch("/", func(c fiber.Ctx) error { return c.SendString(c.FormValue("foo")) }) @@ -204,7 +205,7 @@ func Test_Client_Patch(t *testing.T) { code, body, errs := a.String() - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, "bar", body) utils.AssertEqual(t, 0, len(errs)) @@ -217,10 +218,10 @@ func Test_Client_Delete(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Delete("/", func(c Ctx) error { - return c.Status(StatusNoContent). + app.Delete("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusNoContent). SendString("deleted") }) @@ -235,7 +236,7 @@ func Test_Client_Delete(t *testing.T) { code, body, errs := a.String() - utils.AssertEqual(t, StatusNoContent, code) + utils.AssertEqual(t, fiber.StatusNoContent, code) utils.AssertEqual(t, "", body) utils.AssertEqual(t, 0, len(errs)) @@ -248,9 +249,9 @@ func Test_Client_UserAgent(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c Ctx) error { + app.Get("/", func(c fiber.Ctx) error { return c.Send(c.Request().Header.UserAgent()) }) @@ -264,7 +265,7 @@ func Test_Client_UserAgent(t *testing.T) { code, body, errs := a.String() - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, defaultUserAgent, body) utils.AssertEqual(t, 0, len(errs)) } @@ -281,7 +282,7 @@ func Test_Client_UserAgent(t *testing.T) { code, body, errs := a.String() - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, "ua", body) utils.AssertEqual(t, 0, len(errs)) ReleaseClient(c) @@ -290,7 +291,7 @@ func Test_Client_UserAgent(t *testing.T) { } func Test_Client_Agent_Set_Or_Add_Headers(t *testing.T) { - handler := func(c Ctx) error { + handler := func(c fiber.Ctx) error { c.Request().Header.VisitAll(func(key, value []byte) { if k := string(key); k == "K1" || k == "K2" { _, _ = c.Write(key) @@ -315,7 +316,7 @@ func Test_Client_Agent_Set_Or_Add_Headers(t *testing.T) { } func Test_Client_Agent_Connection_Close(t *testing.T) { - handler := func(c Ctx) error { + handler := func(c fiber.Ctx) error { if c.Request().Header.ConnectionClose() { return c.SendString("close") } @@ -330,7 +331,7 @@ func Test_Client_Agent_Connection_Close(t *testing.T) { } func Test_Client_Agent_UserAgent(t *testing.T) { - handler := func(c Ctx) error { + handler := func(c fiber.Ctx) error { return c.Send(c.Request().Header.UserAgent()) } @@ -343,7 +344,7 @@ func Test_Client_Agent_UserAgent(t *testing.T) { } func Test_Client_Agent_Cookie(t *testing.T) { - handler := func(c Ctx) error { + handler := func(c fiber.Ctx) error { return c.SendString( c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) } @@ -360,7 +361,7 @@ func Test_Client_Agent_Cookie(t *testing.T) { } func Test_Client_Agent_Referer(t *testing.T) { - handler := func(c Ctx) error { + handler := func(c fiber.Ctx) error { return c.Send(c.Request().Header.Referer()) } @@ -373,7 +374,7 @@ func Test_Client_Agent_Referer(t *testing.T) { } func Test_Client_Agent_ContentType(t *testing.T) { - handler := func(c Ctx) error { + handler := func(c fiber.Ctx) error { return c.Send(c.Request().Header.ContentType()) } @@ -390,9 +391,9 @@ func Test_Client_Agent_Host(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c Ctx) error { + app.Get("/", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) @@ -408,13 +409,13 @@ func Test_Client_Agent_Host(t *testing.T) { code, body, errs := a.String() - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, "example.com", body) utils.AssertEqual(t, 0, len(errs)) } func Test_Client_Agent_QueryString(t *testing.T) { - handler := func(c Ctx) error { + handler := func(c fiber.Ctx) error { return c.Send(c.Request().URI().QueryString()) } @@ -427,9 +428,9 @@ func Test_Client_Agent_QueryString(t *testing.T) { } func Test_Client_Agent_BasicAuth(t *testing.T) { - handler := func(c Ctx) error { + handler := func(c fiber.Ctx) error { // Get authorization header - auth := c.Get(HeaderAuthorization) + auth := c.Get(fiber.HeaderAuthorization) // Decode the header contents raw, err := base64.StdEncoding.DecodeString(auth[6:]) utils.AssertEqual(t, nil, err) @@ -446,7 +447,7 @@ func Test_Client_Agent_BasicAuth(t *testing.T) { } func Test_Client_Agent_BodyString(t *testing.T) { - handler := func(c Ctx) error { + handler := func(c fiber.Ctx) error { return c.Send(c.Request().Body()) } @@ -458,7 +459,7 @@ func Test_Client_Agent_BodyString(t *testing.T) { } func Test_Client_Agent_Body(t *testing.T) { - handler := func(c Ctx) error { + handler := func(c fiber.Ctx) error { return c.Send(c.Request().Body()) } @@ -470,7 +471,7 @@ func Test_Client_Agent_Body(t *testing.T) { } func Test_Client_Agent_BodyStream(t *testing.T) { - handler := func(c Ctx) error { + handler := func(c fiber.Ctx) error { return c.Send(c.Request().Body()) } @@ -486,9 +487,9 @@ func Test_Client_Agent_Custom_Response(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c Ctx) error { + app.Get("/", func(c fiber.Ctx) error { return c.SendString("custom") }) @@ -499,7 +500,7 @@ func Test_Client_Agent_Custom_Response(t *testing.T) { resp := AcquireResponse() req := a.Request() - req.Header.SetMethod(MethodGet) + req.Header.SetMethod(fiber.MethodGet) req.SetRequestURI("http://example.com") utils.AssertEqual(t, nil, a.Parse()) @@ -509,7 +510,7 @@ func Test_Client_Agent_Custom_Response(t *testing.T) { code, body, errs := a.SetResponse(resp). String() - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, "custom", body) utils.AssertEqual(t, "custom", string(resp.Body())) utils.AssertEqual(t, 0, len(errs)) @@ -523,9 +524,9 @@ func Test_Client_Agent_Dest(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c Ctx) error { + app.Get("/", func(c fiber.Ctx) error { return c.SendString("dest") }) @@ -540,7 +541,7 @@ func Test_Client_Agent_Dest(t *testing.T) { code, body, errs := a.Dest(dest[:0]).String() - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, "dest", body) utils.AssertEqual(t, "de", string(dest)) utils.AssertEqual(t, 0, len(errs)) @@ -555,7 +556,7 @@ func Test_Client_Agent_Dest(t *testing.T) { code, body, errs := a.Dest(dest[:0]).String() - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, "dest", body) utils.AssertEqual(t, "destar", string(dest)) utils.AssertEqual(t, 0, len(errs)) @@ -591,7 +592,7 @@ func Test_Client_Agent_RetryIf(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() @@ -623,8 +624,8 @@ func Test_Client_Agent_RetryIf(t *testing.T) { } func Test_Client_Agent_Json(t *testing.T) { - handler := func(c Ctx) error { - utils.AssertEqual(t, MIMEApplicationJSON, string(c.Request().Header.ContentType())) + handler := func(c fiber.Ctx) error { + utils.AssertEqual(t, fiber.MIMEApplicationJSON, string(c.Request().Header.ContentType())) return c.Send(c.Request().Body()) } @@ -649,8 +650,8 @@ func Test_Client_Agent_Json_Error(t *testing.T) { } func Test_Client_Agent_XML(t *testing.T) { - handler := func(c Ctx) error { - utils.AssertEqual(t, MIMEApplicationXML, string(c.Request().Header.ContentType())) + handler := func(c fiber.Ctx) error { + utils.AssertEqual(t, fiber.MIMEApplicationXML, string(c.Request().Header.ContentType())) return c.Send(c.Request().Body()) } @@ -674,8 +675,8 @@ func Test_Client_Agent_XML_Error(t *testing.T) { } func Test_Client_Agent_Form(t *testing.T) { - handler := func(c Ctx) error { - utils.AssertEqual(t, MIMEApplicationForm, string(c.Request().Header.ContentType())) + handler := func(c fiber.Ctx) error { + utils.AssertEqual(t, fiber.MIMEApplicationForm, string(c.Request().Header.ContentType())) return c.Send(c.Request().Body()) } @@ -698,10 +699,10 @@ func Test_Client_Agent_MultipartForm(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Post("/", func(c Ctx) error { - utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(HeaderContentType)) + app.Post("/", func(c fiber.Ctx) error { + utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) mf, err := c.MultipartForm() utils.AssertEqual(t, nil, err) @@ -724,7 +725,7 @@ func Test_Client_Agent_MultipartForm(t *testing.T) { code, body, errs := a.String() - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, "--myBoundary\r\nContent-Disposition: form-data; name=\"foo\"\r\n\r\nbar\r\n--myBoundary--\r\n", body) utils.AssertEqual(t, 0, len(errs)) ReleaseArgs(args) @@ -753,10 +754,10 @@ func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Post("/", func(c Ctx) error { - utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(HeaderContentType)) + app.Post("/", func(c fiber.Ctx) error { + utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) fh1, err := c.FormFile("field1") utils.AssertEqual(t, nil, err) @@ -798,7 +799,7 @@ func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) { code, body, errs := a.String() - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, "multipart form files", body) utils.AssertEqual(t, 0, len(errs)) @@ -832,7 +833,7 @@ func Test_Client_Agent_Multipart_Random_Boundary(t *testing.T) { reg := regexp.MustCompile(`multipart/form-data; boundary=\w{30}`) - utils.AssertEqual(t, true, reg.Match(a.req.Header.Peek(HeaderContentType))) + utils.AssertEqual(t, true, reg.Match(a.req.Header.Peek(fiber.HeaderContentType))) } func Test_Client_Agent_Multipart_Invalid_Boundary(t *testing.T) { @@ -857,7 +858,7 @@ func Test_Client_Agent_SendFile_Error(t *testing.T) { } func Test_Client_Debug(t *testing.T) { - handler := func(c Ctx) error { + handler := func(c fiber.Ctx) error { return c.SendString("debug") } @@ -884,9 +885,9 @@ func Test_Client_Agent_Timeout(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c Ctx) error { + app.Get("/", func(c fiber.Ctx) error { time.Sleep(time.Millisecond * 200) return c.SendString("timeout") }) @@ -910,9 +911,9 @@ func Test_Client_Agent_Reuse(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c Ctx) error { + app.Get("/", func(c fiber.Ctx) error { return c.SendString("reuse") }) @@ -925,13 +926,13 @@ func Test_Client_Agent_Reuse(t *testing.T) { code, body, errs := a.String() - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, "reuse", body) utils.AssertEqual(t, 0, len(errs)) code, body, errs = a.String() - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, "reuse", body) utils.AssertEqual(t, 0, len(errs)) } @@ -946,14 +947,14 @@ func Test_Client_Agent_InsecureSkipVerify(t *testing.T) { Certificates: []tls.Certificate{cer}, } - ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0") + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") utils.AssertEqual(t, nil, err) ln = tls.NewListener(ln, serverTLSConf) - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c Ctx) error { + app.Get("/", func(c fiber.Ctx) error { return c.SendString("ignore tls") }) @@ -965,7 +966,7 @@ func Test_Client_Agent_InsecureSkipVerify(t *testing.T) { String() utils.AssertEqual(t, 0, len(errs)) - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, "ignore tls", body) } @@ -975,14 +976,14 @@ func Test_Client_Agent_TLS(t *testing.T) { serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() utils.AssertEqual(t, nil, err) - ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0") + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") utils.AssertEqual(t, nil, err) ln = tls.NewListener(ln, serverTLSConf) - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c Ctx) error { + app.Get("/", func(c fiber.Ctx) error { return c.SendString("tls") }) @@ -993,7 +994,7 @@ func Test_Client_Agent_TLS(t *testing.T) { String() utils.AssertEqual(t, 0, len(errs)) - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, "tls", body) } @@ -1002,15 +1003,15 @@ func Test_Client_Agent_MaxRedirectsCount(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c Ctx) error { + app.Get("/", func(c fiber.Ctx) error { if c.Request().URI().QueryArgs().Has("foo") { return c.Redirect("/foo") } return c.Redirect("/") }) - app.Get("/foo", func(c Ctx) error { + app.Get("/foo", func(c fiber.Ctx) error { return c.SendString("redirect") }) @@ -1048,13 +1049,13 @@ func Test_Client_Agent_Struct(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c Ctx) error { + app.Get("/", func(c fiber.Ctx) error { return c.JSON(data{true}) }) - app.Get("/error", func(c Ctx) error { + app.Get("/error", func(c fiber.Ctx) error { return c.SendString(`{"success"`) }) @@ -1071,7 +1072,7 @@ func Test_Client_Agent_Struct(t *testing.T) { code, body, errs := a.Struct(&d) - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, `{"success":true}`, string(body)) utils.AssertEqual(t, 0, len(errs)) utils.AssertEqual(t, true, d.Success) @@ -1102,7 +1103,7 @@ func Test_Client_Agent_Struct(t *testing.T) { code, body, errs := a.JSONDecoder(json.Unmarshal).Struct(&d) - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, `{"success"`, string(body)) utils.AssertEqual(t, 1, len(errs)) utils.AssertEqual(t, "unexpected end of JSON input", errs[0].Error()) @@ -1122,12 +1123,12 @@ func Test_AddMissingPort_TLS(t *testing.T) { utils.AssertEqual(t, "example.com:443", addr) } -func testAgent(t *testing.T, handler Handler, wrapAgent func(agent *Agent), excepted string, count ...int) { +func testAgent(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Agent), excepted string, count ...int) { t.Parallel() ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) app.Get("/", handler) @@ -1147,7 +1148,7 @@ func testAgent(t *testing.T, handler Handler, wrapAgent func(agent *Agent), exce code, body, errs := a.String() - utils.AssertEqual(t, StatusOK, code) + utils.AssertEqual(t, fiber.StatusOK, code) utils.AssertEqual(t, excepted, body) utils.AssertEqual(t, 0, len(errs)) } diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index ebc1b4eb01..85158aea84 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -12,6 +12,7 @@ import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/tlstest" "github.com/gofiber/fiber/v3/utils" + fiberClient "github.com/gofiber/fiber/v3/client" ) func createProxyTestServer(handler fiber.Handler, t *testing.T) (*fiber.App, string) { @@ -115,7 +116,7 @@ func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) { go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - code, body, errs := fiber.Get("https://" + addr + "/tlsbalaner").TLSConfig(clientTLSConf).String() + code, body, errs := fiberClient.Get("https://" + addr + "/tlsbalaner").TLSConfig(clientTLSConf).String() utils.AssertEqual(t, 0, len(errs)) utils.AssertEqual(t, fiber.StatusOK, code) @@ -146,7 +147,7 @@ func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) { go func() { utils.AssertEqual(t, nil, app.Listener(proxyServerLn)) }() - code, body, errs := fiber.Get("https://" + proxyAddr). + code, body, errs := fiberClient.Get("https://" + proxyAddr). InsecureSkipVerify(). Timeout(5 * time.Second). String() @@ -204,7 +205,7 @@ func Test_Proxy_Forward_WithTlsConfig(t *testing.T) { go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - code, body, errs := fiber.Get("https://" + addr).TLSConfig(clientTLSConf).String() + code, body, errs := fiberClient.Get("https://" + addr).TLSConfig(clientTLSConf).String() utils.AssertEqual(t, 0, len(errs)) utils.AssertEqual(t, fiber.StatusOK, code) From 1af81b8f15e7942fadd6b22f91c7344632bf9d01 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Fri, 29 Jul 2022 22:14:11 +0800 Subject: [PATCH 002/118] =?UTF-8?q?=E2=9C=A8=20v3:=20add=20xml=20encoder?= =?UTF-8?q?=20and=20decoder?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- utils/xml.go | 9 ++++++ utils/xml_test.go | 76 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 utils/xml.go create mode 100644 utils/xml_test.go diff --git a/utils/xml.go b/utils/xml.go new file mode 100644 index 0000000000..6e28f84d04 --- /dev/null +++ b/utils/xml.go @@ -0,0 +1,9 @@ +package utils + +// XMLMarshal returns the XML encoding of v. +type XMLMarshal func(v any) ([]byte, error) + +// XMLUnmarshal parses the XML-encoded data and stores the result +// in the value pointed to by v. If v is nil or not a pointer, +// Unmarshal returns an InvalidUnmarshalError. +type XMLUnmarshal func(data []byte, v any) error diff --git a/utils/xml_test.go b/utils/xml_test.go new file mode 100644 index 0000000000..55d2b4227e --- /dev/null +++ b/utils/xml_test.go @@ -0,0 +1,76 @@ +package utils + +import ( + "encoding/xml" + "testing" +) + +type serversXMLStructure struct { + XMLName xml.Name `xml:"servers"` + Version string `xml:"version,attr"` + Servers []serverXMLStructure `xml:"server"` +} + +type serverXMLStructure struct { + XMLName xml.Name `xml:"server"` + Name string `xml:"name"` +} + +var xmlString = `fiber onefiber two` + +func Test_GolangXMLEncoder(t *testing.T) { + t.Parallel() + + var ( + ss = &serversXMLStructure{ + Version: "1", + Servers: []serverXMLStructure{ + {Name: "fiber one"}, + {Name: "fiber two"}, + }, + } + xmlEncoder XMLMarshal = xml.Marshal + ) + + raw, err := xmlEncoder(ss) + AssertEqual(t, err, nil) + + AssertEqual(t, string(raw), xmlString) +} + +func Test_DefaultXMLEncoder(t *testing.T) { + t.Parallel() + + var ( + ss = &serversXMLStructure{ + Version: "1", + Servers: []serverXMLStructure{ + {Name: "fiber one"}, + {Name: "fiber two"}, + }, + } + xmlEncoder XMLMarshal = xml.Marshal + ) + + raw, err := xmlEncoder(ss) + AssertEqual(t, err, nil) + + AssertEqual(t, string(raw), xmlString) +} + +func Test_DefaultXMLDecoder(t *testing.T) { + t.Parallel() + + var ( + ss serversXMLStructure + xmlBytes = []byte(xmlString) + xmlDecoder XMLUnmarshal = xml.Unmarshal + ) + + err := xmlDecoder(xmlBytes, &ss) + AssertEqual(t, err, nil) + AssertEqual(t, len(ss.Servers), 2) + AssertEqual(t, ss.Version, "1") + AssertEqual(t, ss.Servers[0].Name, "fiber one") + AssertEqual(t, ss.Servers[1].Name, "fiber two") +} From 894777f0fc3268aa42f8001102a2bd3771ef62cf Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Sat, 30 Jul 2022 23:21:38 +0800 Subject: [PATCH 003/118] =?UTF-8?q?=F0=9F=9A=A7=20v3:=20design=20plugin=20?= =?UTF-8?q?and=20hook=20mechanism,=20complete=20simple=20get=20request?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 1006 ++---------------------- client/client_test.go | 1696 +++++++++++++++++++++-------------------- client/core.go | 230 ++++++ client/hooks.go | 47 ++ client/plugins.go | 1 + client/request.go | 77 ++ client/respose.go | 43 ++ 7 files changed, 1311 insertions(+), 1789 deletions(-) create mode 100644 client/core.go create mode 100644 client/hooks.go create mode 100644 client/plugins.go create mode 100644 client/request.go create mode 100644 client/respose.go diff --git a/client/client.go b/client/client.go index f9bac47647..3bb86920d2 100644 --- a/client/client.go +++ b/client/client.go @@ -1,1000 +1,90 @@ package client import ( - "bytes" - "crypto/tls" - "encoding/json" - "encoding/xml" - "fmt" - "io" - "mime/multipart" - "net" - "os" - "path/filepath" - "strconv" - "strings" "sync" - "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/utils" "github.com/valyala/fasthttp" ) -// Request represents HTTP request. -// -// It is forbidden copying Request instances. Create new instances -// and use CopyTo instead. -// -// Request instance MUST NOT be used from concurrently running goroutines. -// Copy from fasthttp -type Request = fasthttp.Request - -// Response represents HTTP response. -// -// It is forbidden copying Response instances. Create new instances -// and use CopyTo instead. -// -// Response instance MUST NOT be used from concurrently running goroutines. -// Copy from fasthttp -type Response = fasthttp.Response - -// Args represents query arguments. -// -// It is forbidden copying Args instances. Create new instances instead -// and use CopyTo(). -// -// Args instance MUST NOT be used from concurrently running goroutines. -// Copy from fasthttp -type Args = fasthttp.Args - -// RetryIfFunc signature of retry if function -// Request argument passed to RetryIfFunc, if there are any request errors. -// Copy from fasthttp -type RetryIfFunc = fasthttp.RetryIfFunc - -var defaultClient Client - -// Client implements http client. -// -// It is safe calling Client methods from concurrently running goroutines. type Client struct { - mutex sync.RWMutex - // UserAgent is used in User-Agent request header. - UserAgent string - - // NoDefaultUserAgentHeader when set to true, causes the default - // User-Agent header to be excluded from the Request. - NoDefaultUserAgentHeader bool - - // When set by an external client of Fiber it will use the provided implementation of a - // JSONMarshal - // - // Allowing for flexibility in using another json library for encoding - JSONEncoder utils.JSONMarshal - - // When set by an external client of Fiber it will use the provided implementation of a - // JSONUnmarshal - // - // Allowing for flexibility in using another json library for decoding - JSONDecoder utils.JSONUnmarshal -} - -// Get returns a agent with http method GET. -func Get(url string) *Agent { return defaultClient.Get(url) } - -// Get returns a agent with http method GET. -func (c *Client) Get(url string) *Agent { - return c.createAgent(fiber.MethodGet, url) -} - -// Head returns a agent with http method HEAD. -func Head(url string) *Agent { return defaultClient.Head(url) } - -// Head returns a agent with http method GET. -func (c *Client) Head(url string) *Agent { - return c.createAgent(fiber.MethodHead, url) -} - -// Post sends POST request to the given url. -func Post(url string) *Agent { return defaultClient.Post(url) } - -// Post sends POST request to the given url. -func (c *Client) Post(url string) *Agent { - return c.createAgent(fiber.MethodPost, url) -} - -// Put sends PUT request to the given url. -func Put(url string) *Agent { return defaultClient.Put(url) } - -// Put sends PUT request to the given url. -func (c *Client) Put(url string) *Agent { - return c.createAgent(fiber.MethodPut, url) -} - -// Patch sends PATCH request to the given url. -func Patch(url string) *Agent { return defaultClient.Patch(url) } - -// Patch sends PATCH request to the given url. -func (c *Client) Patch(url string) *Agent { - return c.createAgent(fiber.MethodPatch, url) -} - -// Delete sends DELETE request to the given url. -func Delete(url string) *Agent { return defaultClient.Delete(url) } - -// Delete sends DELETE request to the given url. -func (c *Client) Delete(url string) *Agent { - return c.createAgent(fiber.MethodDelete, url) -} - -func (c *Client) createAgent(method, url string) *Agent { - a := AcquireAgent() - a.req.Header.SetMethod(method) - a.req.SetRequestURI(url) - - c.mutex.RLock() - a.Name = c.UserAgent - a.NoDefaultUserAgentHeader = c.NoDefaultUserAgentHeader - a.jsonDecoder = c.JSONDecoder - a.jsonEncoder = c.JSONEncoder - if a.jsonDecoder == nil { - a.jsonDecoder = json.Unmarshal - } - c.mutex.RUnlock() - - if err := a.Parse(); err != nil { - a.errs = append(a.errs, err) - } - - return a -} - -// Agent is an object storing all request data for client. -// Agent instance MUST NOT be used from concurrently running goroutines. -type Agent struct { - // Name is used in User-Agent request header. - Name string - - // NoDefaultUserAgentHeader when set to true, causes the default - // User-Agent header to be excluded from the Request. - NoDefaultUserAgentHeader bool - - // HostClient is an embedded fasthttp HostClient - *fasthttp.HostClient - - req *Request - resp *Response - dest []byte - args *Args - timeout time.Duration - errs []error - formFiles []*FormFile - debugWriter io.Writer - mw multipartWriter - jsonEncoder utils.JSONMarshal - jsonDecoder utils.JSONUnmarshal - maxRedirectsCount int - boundary string - reuse bool - parsed bool -} - -// Parse initializes URI and HostClient. -func (a *Agent) Parse() error { - if a.parsed { - return nil - } - a.parsed = true - - uri := a.req.URI() - - isTLS := false - scheme := uri.Scheme() - if bytes.Equal(scheme, strHTTPS) { - isTLS = true - } else if !bytes.Equal(scheme, strHTTP) { - return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme) - } - - name := a.Name - if name == "" && !a.NoDefaultUserAgentHeader { - name = defaultUserAgent - } - - a.HostClient = &fasthttp.HostClient{ - Addr: addMissingPort(string(uri.Host()), isTLS), - Name: name, - NoDefaultUserAgentHeader: a.NoDefaultUserAgentHeader, - IsTLS: isTLS, - } - - return nil -} - -func addMissingPort(addr string, isTLS bool) string { - n := strings.Index(addr, ":") - if n >= 0 { - return addr - } - port := 80 - if isTLS { - port = 443 - } - return net.JoinHostPort(addr, strconv.Itoa(port)) -} - -/************************** Header Setting **************************/ - -// Set sets the given 'key: value' header. -// -// Use Add for setting multiple header values under the same key. -func (a *Agent) Set(k, v string) *Agent { - a.req.Header.Set(k, v) - - return a -} - -// SetBytesK sets the given 'key: value' header. -// -// Use AddBytesK for setting multiple header values under the same key. -func (a *Agent) SetBytesK(k []byte, v string) *Agent { - a.req.Header.SetBytesK(k, v) - - return a -} - -// SetBytesV sets the given 'key: value' header. -// -// Use AddBytesV for setting multiple header values under the same key. -func (a *Agent) SetBytesV(k string, v []byte) *Agent { - a.req.Header.SetBytesV(k, v) - - return a -} - -// SetBytesKV sets the given 'key: value' header. -// -// Use AddBytesKV for setting multiple header values under the same key. -func (a *Agent) SetBytesKV(k []byte, v []byte) *Agent { - a.req.Header.SetBytesKV(k, v) - - return a -} - -// Add adds the given 'key: value' header. -// -// Multiple headers with the same key may be added with this function. -// Use Set for setting a single header for the given key. -func (a *Agent) Add(k, v string) *Agent { - a.req.Header.Add(k, v) - - return a -} - -// AddBytesK adds the given 'key: value' header. -// -// Multiple headers with the same key may be added with this function. -// Use SetBytesK for setting a single header for the given key. -func (a *Agent) AddBytesK(k []byte, v string) *Agent { - a.req.Header.AddBytesK(k, v) - - return a -} - -// AddBytesV adds the given 'key: value' header. -// -// Multiple headers with the same key may be added with this function. -// Use SetBytesV for setting a single header for the given key. -func (a *Agent) AddBytesV(k string, v []byte) *Agent { - a.req.Header.AddBytesV(k, v) - - return a -} - -// AddBytesKV adds the given 'key: value' header. -// -// Multiple headers with the same key may be added with this function. -// Use SetBytesKV for setting a single header for the given key. -func (a *Agent) AddBytesKV(k []byte, v []byte) *Agent { - a.req.Header.AddBytesKV(k, v) - - return a -} - -// ConnectionClose sets 'Connection: close' header. -func (a *Agent) ConnectionClose() *Agent { - a.req.Header.SetConnectionClose() - - return a -} - -// UserAgent sets User-Agent header value. -func (a *Agent) UserAgent(userAgent string) *Agent { - a.req.Header.SetUserAgent(userAgent) - - return a -} - -// UserAgentBytes sets User-Agent header value. -func (a *Agent) UserAgentBytes(userAgent []byte) *Agent { - a.req.Header.SetUserAgentBytes(userAgent) - - return a -} - -// Cookie sets one 'key: value' cookie. -func (a *Agent) Cookie(key, value string) *Agent { - a.req.Header.SetCookie(key, value) - - return a -} - -// CookieBytesK sets one 'key: value' cookie. -func (a *Agent) CookieBytesK(key []byte, value string) *Agent { - a.req.Header.SetCookieBytesK(key, value) - - return a -} - -// CookieBytesKV sets one 'key: value' cookie. -func (a *Agent) CookieBytesKV(key, value []byte) *Agent { - a.req.Header.SetCookieBytesKV(key, value) - - return a -} - -// Cookies sets multiple 'key: value' cookies. -func (a *Agent) Cookies(kv ...string) *Agent { - for i := 1; i < len(kv); i += 2 { - a.req.Header.SetCookie(kv[i-1], kv[i]) - } - - return a -} - -// CookiesBytesKV sets multiple 'key: value' cookies. -func (a *Agent) CookiesBytesKV(kv ...[]byte) *Agent { - for i := 1; i < len(kv); i += 2 { - a.req.Header.SetCookieBytesKV(kv[i-1], kv[i]) - } - - return a -} - -// Referer sets Referer header value. -func (a *Agent) Referer(referer string) *Agent { - a.req.Header.SetReferer(referer) - - return a -} - -// RefererBytes sets Referer header value. -func (a *Agent) RefererBytes(referer []byte) *Agent { - a.req.Header.SetRefererBytes(referer) - - return a -} - -// ContentType sets Content-Type header value. -func (a *Agent) ContentType(contentType string) *Agent { - a.req.Header.SetContentType(contentType) - - return a -} - -// ContentTypeBytes sets Content-Type header value. -func (a *Agent) ContentTypeBytes(contentType []byte) *Agent { - a.req.Header.SetContentTypeBytes(contentType) - - return a -} - -/************************** End Header Setting **************************/ - -/************************** URI Setting **************************/ - -// Host sets host for the uri. -func (a *Agent) Host(host string) *Agent { - a.req.URI().SetHost(host) - - return a -} - -// HostBytes sets host for the URI. -func (a *Agent) HostBytes(host []byte) *Agent { - a.req.URI().SetHostBytes(host) - - return a -} - -// QueryString sets URI query string. -func (a *Agent) QueryString(queryString string) *Agent { - a.req.URI().SetQueryString(queryString) - - return a -} - -// QueryStringBytes sets URI query string. -func (a *Agent) QueryStringBytes(queryString []byte) *Agent { - a.req.URI().SetQueryStringBytes(queryString) - - return a -} - -// BasicAuth sets URI username and password. -func (a *Agent) BasicAuth(username, password string) *Agent { - a.req.URI().SetUsername(username) - a.req.URI().SetPassword(password) - - return a -} - -// BasicAuthBytes sets URI username and password. -func (a *Agent) BasicAuthBytes(username, password []byte) *Agent { - a.req.URI().SetUsernameBytes(username) - a.req.URI().SetPasswordBytes(password) - - return a -} - -/************************** End URI Setting **************************/ + core -/************************** Request Setting **************************/ - -// BodyString sets request body. -func (a *Agent) BodyString(bodyString string) *Agent { - a.req.SetBodyString(bodyString) - - return a -} - -// Body sets request body. -func (a *Agent) Body(body []byte) *Agent { - a.req.SetBody(body) - - return a -} - -// BodyStream sets request body stream and, optionally body size. -// -// If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes -// before returning io.EOF. -// -// If bodySize < 0, then bodyStream is read until io.EOF. -// -// bodyStream.Close() is called after finishing reading all body data -// if it implements io.Closer. -// -// Note that GET and HEAD requests cannot have body. -func (a *Agent) BodyStream(bodyStream io.Reader, bodySize int) *Agent { - a.req.SetBodyStream(bodyStream, bodySize) - - return a -} - -// JSON sends a JSON request. -func (a *Agent) JSON(v any) *Agent { - if a.jsonEncoder == nil { - a.jsonEncoder = json.Marshal - } - - a.req.Header.SetContentType(fiber.MIMEApplicationJSON) - - if body, err := a.jsonEncoder(v); err != nil { - a.errs = append(a.errs, err) - } else { - a.req.SetBody(body) - } - - return a -} - -// XML sends an XML request. -func (a *Agent) XML(v any) *Agent { - a.req.Header.SetContentType(fiber.MIMEApplicationXML) - - if body, err := xml.Marshal(v); err != nil { - a.errs = append(a.errs, err) - } else { - a.req.SetBody(body) - } - - return a -} - -// Form sends form request with body if args is non-nil. -// -// It is recommended obtaining args via AcquireArgs and release it -// manually in performance-critical code. -func (a *Agent) Form(args *Args) *Agent { - a.req.Header.SetContentType(fiber.MIMEApplicationForm) - - if args != nil { - a.req.SetBody(args.QueryString()) - } - - return a -} - -// FormFile represents multipart form file -type FormFile struct { - // Fieldname is form file's field name - Fieldname string - // Name is form file's name - Name string - // Content is form file's content - Content []byte - // autoRelease indicates if returns the object - // acquired via AcquireFormFile to the pool. - autoRelease bool -} - -// FileData appends files for multipart form request. -// -// It is recommended obtaining formFile via AcquireFormFile and release it -// manually in performance-critical code. -func (a *Agent) FileData(formFiles ...*FormFile) *Agent { - a.formFiles = append(a.formFiles, formFiles...) - - return a -} - -// SendFile reads file and appends it to multipart form request. -func (a *Agent) SendFile(filename string, fieldname ...string) *Agent { - content, err := os.ReadFile(filepath.Clean(filename)) - if err != nil { - a.errs = append(a.errs, err) - return a - } - - ff := AcquireFormFile() - if len(fieldname) > 0 && fieldname[0] != "" { - ff.Fieldname = fieldname[0] - } else { - ff.Fieldname = "file" + strconv.Itoa(len(a.formFiles)+1) - } - ff.Name = filepath.Base(filename) - ff.Content = append(ff.Content, content...) - ff.autoRelease = true - - a.formFiles = append(a.formFiles, ff) - - return a -} - -// SendFiles reads files and appends them to multipart form request. -// -// Examples: -// SendFile("/path/to/file1", "fieldname1", "/path/to/file2") -func (a *Agent) SendFiles(filenamesAndFieldnames ...string) *Agent { - pairs := len(filenamesAndFieldnames) - if pairs&1 == 1 { - filenamesAndFieldnames = append(filenamesAndFieldnames, "") - } - - for i := 0; i < pairs; i += 2 { - a.SendFile(filenamesAndFieldnames[i], filenamesAndFieldnames[i+1]) - } - - return a -} - -// Boundary sets boundary for multipart form request. -func (a *Agent) Boundary(boundary string) *Agent { - a.boundary = boundary - - return a -} - -// MultipartForm sends multipart form request with k-v and files. -// -// It is recommended obtaining args via AcquireArgs and release it -// manually in performance-critical code. -func (a *Agent) MultipartForm(args *Args) *Agent { - if a.mw == nil { - a.mw = multipart.NewWriter(a.req.BodyWriter()) - } - - if a.boundary != "" { - if err := a.mw.SetBoundary(a.boundary); err != nil { - a.errs = append(a.errs, err) - return a - } - } - - a.req.Header.SetMultipartFormBoundary(a.mw.Boundary()) - - if args != nil { - args.VisitAll(func(key, value []byte) { - if err := a.mw.WriteField(utils.UnsafeString(key), utils.UnsafeString(value)); err != nil { - a.errs = append(a.errs, err) - } - }) - } - - for _, ff := range a.formFiles { - w, err := a.mw.CreateFormFile(ff.Fieldname, ff.Name) - if err != nil { - a.errs = append(a.errs, err) - continue - } - if _, err = w.Write(ff.Content); err != nil { - a.errs = append(a.errs, err) - } - } - - if err := a.mw.Close(); err != nil { - a.errs = append(a.errs, err) - } - - return a -} - -/************************** End Request Setting **************************/ - -/************************** Agent Setting **************************/ - -// Debug mode enables logging request and response detail -func (a *Agent) Debug(w ...io.Writer) *Agent { - a.debugWriter = os.Stdout - if len(w) > 0 { - a.debugWriter = w[0] - } - - return a -} - -// Timeout sets request timeout duration. -func (a *Agent) Timeout(timeout time.Duration) *Agent { - a.timeout = timeout - - return a -} - -// Reuse enables the Agent instance to be used again after one request. -// -// If agent is reusable, then it should be released manually when it is no -// longer used. -func (a *Agent) Reuse() *Agent { - a.reuse = true - - return a + baseUrl string + header map[string][]string } -// InsecureSkipVerify controls whether the Agent verifies the server -// certificate chain and host name. -func (a *Agent) InsecureSkipVerify() *Agent { - if a.HostClient.TLSConfig == nil { - /* #nosec G402 */ - a.HostClient.TLSConfig = &tls.Config{InsecureSkipVerify: true} // #nosec G402 - } else { - /* #nosec G402 */ - a.HostClient.TLSConfig.InsecureSkipVerify = true - } - - return a +// Add user-defined request hooks. +func (c *Client) AddRequestHook(h ...RequestHook) *Client { + c.userRequestHooks = append(c.userRequestHooks, h...) + return c } -// TLSConfig sets tls config. -func (a *Agent) TLSConfig(config *tls.Config) *Agent { - a.HostClient.TLSConfig = config - - return a +// Add user-defined response hooks. +func (c *Client) AddResponseHook(h ...ResponseHook) *Client { + c.userResponseHooks = append(c.userResponseHooks, h...) + return c } -// MaxRedirectsCount sets max redirect count for GET and HEAD. -func (a *Agent) MaxRedirectsCount(count int) *Agent { - a.maxRedirectsCount = count - - return a +func (c *Client) SetDial(f fasthttp.DialFunc) *Client { + c.client.Dial = f + return c } -// JSONEncoder sets custom json encoder. -func (a *Agent) JSONEncoder(jsonEncoder utils.JSONMarshal) *Agent { - a.jsonEncoder = jsonEncoder - - return a +// Set json encoder. +func (c *Client) SetJSONMarshal(f utils.JSONMarshal) *Client { + c.jsonMarshal = f + return c } -// JSONDecoder sets custom json decoder. -func (a *Agent) JSONDecoder(jsonDecoder utils.JSONUnmarshal) *Agent { - a.jsonDecoder = jsonDecoder - - return a +// Set json decoder. +func (c *Client) SetJSONUnmarshal(f utils.JSONUnmarshal) *Client { + c.jsonUnmarshal = f + return c } -// Request returns Agent request instance. -func (a *Agent) Request() *Request { - return a.req +// Set xml encoder. +func (c *Client) SetXMLMarshal(f utils.XMLMarshal) *Client { + c.xmlMarshal = f + return c } -// SetResponse sets custom response for the Agent instance. -// -// It is recommended obtaining custom response via AcquireResponse and release it -// manually in performance-critical code. -func (a *Agent) SetResponse(customResp *Response) *Agent { - a.resp = customResp - - return a +// Set xml decoder. +func (c *Client) SetXMLUnmarshal(f utils.XMLUnmarshal) *Client { + c.xmlUnmarshal = f + return c } -// Dest sets custom dest. -// -// The contents of dest will be replaced by the response body, if the dest -// is too small a new slice will be allocated. -func (a *Agent) Dest(dest []byte) *Agent { - a.dest = dest +func (c *Client) Get(url string) (*Response, error) { + req := AcquireRequest(). + SetURL(url). + SetMethod(fiber.MethodGet) - return a -} - -// RetryIf controls whether a retry should be attempted after an error. -// -// By default, will use isIdempotent function from fasthttp -func (a *Agent) RetryIf(retryIf RetryIfFunc) *Agent { - a.HostClient.RetryIf = retryIf - return a -} - -/************************** End Agent Setting **************************/ -var warnOnce sync.Once - -// Bytes returns the status code, bytes body and errors of url. -func (a *Agent) Bytes() (code int, body []byte, errs []error) { - warnOnce.Do(func() { - fmt.Println("[Warning] client is still in beta, API might change in the future!") - }) - - defer a.release() - - if errs = append(errs, a.errs...); len(errs) > 0 { - return - } - - var ( - req = a.req - resp *Response - nilResp bool - ) - - if a.resp == nil { - resp = AcquireResponse() - nilResp = true - } else { - resp = a.resp - } - - defer func() { - if a.debugWriter != nil { - printDebugInfo(req, resp, a.debugWriter) - } - - if len(errs) == 0 { - code = resp.StatusCode() - } - - body = append(a.dest, resp.Body()...) - - if nilResp { - ReleaseResponse(resp) - } - }() - - if a.timeout > 0 { - if err := a.HostClient.DoTimeout(req, resp, a.timeout); err != nil { - errs = append(errs, err) - return - } - } else if a.maxRedirectsCount > 0 && (string(req.Header.Method()) == fiber.MethodGet || string(req.Header.Method()) == fiber.MethodHead) { - if err := a.HostClient.DoRedirects(req, resp, a.maxRedirectsCount); err != nil { - errs = append(errs, err) - return - } - } else if err := a.HostClient.Do(req, resp); err != nil { - errs = append(errs, err) - } - - return -} - -func printDebugInfo(req *Request, resp *Response, w io.Writer) { - msg := fmt.Sprintf("Connected to %s(%s)\r\n\r\n", req.URI().Host(), resp.RemoteAddr()) - _, _ = w.Write(utils.UnsafeBytes(msg)) - _, _ = req.WriteTo(w) - _, _ = resp.WriteTo(w) -} - -// String returns the status code, string body and errors of url. -func (a *Agent) String() (int, string, []error) { - code, body, errs := a.Bytes() - - return code, utils.UnsafeString(body), errs -} - -// Struct returns the status code, bytes body and errors of url. -// And bytes body will be unmarshalled to given v. -func (a *Agent) Struct(v any) (code int, body []byte, errs []error) { - if code, body, errs = a.Bytes(); len(errs) > 0 { - return - } - - if err := a.jsonDecoder(body, v); err != nil { - errs = append(errs, err) - } - - return -} - -func (a *Agent) release() { - if !a.reuse { - ReleaseAgent(a) - } else { - a.errs = a.errs[:0] - } -} - -func (a *Agent) reset() { - a.HostClient = nil - a.req.Reset() - a.resp = nil - a.dest = nil - a.timeout = 0 - a.args = nil - a.errs = a.errs[:0] - a.debugWriter = nil - a.mw = nil - a.reuse = false - a.parsed = false - a.maxRedirectsCount = 0 - a.boundary = "" - a.Name = "" - a.NoDefaultUserAgentHeader = false - for i, ff := range a.formFiles { - if ff.autoRelease { - ReleaseFormFile(ff) - } - a.formFiles[i] = nil - } - a.formFiles = a.formFiles[:0] + return c.execute(req.Context(), c, req) } var ( - clientPool sync.Pool - agentPool sync.Pool - responsePool sync.Pool - argsPool sync.Pool - formFilePool sync.Pool + defaultClient *Client + clientPool sync.Pool ) -// AcquireClient returns an empty Client instance from client pool. -// -// The returned Client instance may be passed to ReleaseClient when it is -// no longer needed. This allows Client recycling, reduces GC pressure -// and usually improves performance. -func AcquireClient() *Client { - v := clientPool.Get() - if v == nil { - return &Client{} - } - return v.(*Client) -} - -// ReleaseClient returns c acquired via AcquireClient to client pool. -// -// It is forbidden accessing req and/or its' members after returning -// it to client pool. -func ReleaseClient(c *Client) { - c.UserAgent = "" - c.NoDefaultUserAgentHeader = false - c.JSONEncoder = nil - c.JSONDecoder = nil - - clientPool.Put(c) -} - -// AcquireAgent returns an empty Agent instance from Agent pool. -// -// The returned Agent instance may be passed to ReleaseAgent when it is -// no longer needed. This allows Agent recycling, reduces GC pressure -// and usually improves performance. -func AcquireAgent() *Agent { - v := agentPool.Get() - if v == nil { - return &Agent{req: &Request{}} - } - return v.(*Agent) -} - -// ReleaseAgent returns a acquired via AcquireAgent to Agent pool. -// -// It is forbidden accessing req and/or its' members after returning -// it to Agent pool. -func ReleaseAgent(a *Agent) { - a.reset() - agentPool.Put(a) -} - -// AcquireResponse returns an empty Response instance from response pool. -// -// The returned Response instance may be passed to ReleaseResponse when it is -// no longer needed. This allows Response recycling, reduces GC pressure -// and usually improves performance. -// Copy from fasthttp -func AcquireResponse() *Response { - v := responsePool.Get() - if v == nil { - return &Response{} - } - return v.(*Response) +func init() { + defaultClient = AcquireClient() } -// ReleaseResponse return resp acquired via AcquireResponse to response pool. -// -// It is forbidden accessing resp and/or its' members after returning -// it to response pool. -// Copy from fasthttp -func ReleaseResponse(resp *Response) { - resp.Reset() - responsePool.Put(resp) -} - -// AcquireArgs returns an empty Args object from the pool. -// -// The returned Args may be returned to the pool with ReleaseArgs -// when no longer needed. This allows reducing GC load. -// Copy from fasthttp -func AcquireArgs() *Args { - v := argsPool.Get() - if v == nil { - return &Args{} - } - return v.(*Args) -} - -// ReleaseArgs returns the object acquired via AcquireArgs to the pool. -// -// String not access the released Args object, otherwise data races may occur. -// Copy from fasthttp -func ReleaseArgs(a *Args) { - a.Reset() - argsPool.Put(a) -} - -// AcquireFormFile returns an empty FormFile object from the pool. -// -// The returned FormFile may be returned to the pool with ReleaseFormFile -// when no longer needed. This allows reducing GC load. -func AcquireFormFile() *FormFile { - v := formFilePool.Get() - if v == nil { - return &FormFile{} +func AcquireClient() *Client { + return &Client{ + core: *acquireCore(), + header: map[string][]string{}, } - return v.(*FormFile) } -// ReleaseFormFile returns the object acquired via AcquireFormFile to the pool. -// -// String not access the released FormFile object, otherwise data races may occur. -func ReleaseFormFile(ff *FormFile) { - ff.Fieldname = "" - ff.Name = "" - ff.Content = ff.Content[:0] - ff.autoRelease = false - - formFilePool.Put(ff) +// Get default client. +func C() *Client { + return defaultClient } -var ( - strHTTP = []byte("http") - strHTTPS = []byte("https") - defaultUserAgent = "fiber" -) - -type multipartWriter interface { - Boundary() string - SetBoundary(boundary string) error - CreateFormFile(fieldname, filename string) (io.Writer, error) - WriteField(fieldname, value string) error - Close() error +func Get(url string) (*Response, error) { + return defaultClient.Get(url) } diff --git a/client/client_test.go b/client/client_test.go index a56dcb64c7..040cce0c11 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,1179 +1,1213 @@ package client import ( - "bytes" - "crypto/tls" - "encoding/base64" - "errors" - "fmt" - "io" - "mime/multipart" "net" - "os" - "path/filepath" - "regexp" - "strings" "testing" - "time" - - "encoding/json" "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/internal/tlstest" "github.com/gofiber/fiber/v3/utils" "github.com/valyala/fasthttp/fasthttputil" ) -func Test_Client_Invalid_URL(t *testing.T) { - t.Parallel() +// import ( +// "bytes" +// "crypto/tls" +// "encoding/base64" +// "errors" +// "fmt" +// "io" +// "mime/multipart" +// "net" +// "os" +// "path/filepath" +// "regexp" +// "strings" +// "testing" +// "time" - ln := fasthttputil.NewInmemoryListener() +// "encoding/json" - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// "github.com/gofiber/fiber/v3" +// "github.com/gofiber/fiber/v3/internal/tlstest" +// "github.com/gofiber/fiber/v3/utils" +// "github.com/valyala/fasthttp/fasthttputil" +// ) - app.Get("/", func(c fiber.Ctx) error { - return c.SendString(c.Hostname()) - }) +// func Test_Client_Invalid_URL(t *testing.T) { +// t.Parallel() - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// ln := fasthttputil.NewInmemoryListener() - a := Get("http://example.com\r\n\r\nGET /\r\n\r\n") +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// app.Get("/", func(c fiber.Ctx) error { +// return c.SendString(c.Hostname()) +// }) - _, body, errs := a.String() +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - utils.AssertEqual(t, "", body) - utils.AssertEqual(t, 1, len(errs)) - utils.AssertEqual(t, "missing required Host header in request", errs[0].Error()) -} +// a := Get("http://example.com\r\n\r\nGET /\r\n\r\n") -func Test_Client_Unsupported_Protocol(t *testing.T) { - t.Parallel() +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - a := Get("ftp://example.com") +// _, body, errs := a.String() - _, body, errs := a.String() +// utils.AssertEqual(t, "", body) +// utils.AssertEqual(t, 1, len(errs)) +// utils.AssertEqual(t, "missing required Host header in request", errs[0].Error()) +// } - utils.AssertEqual(t, "", body) - utils.AssertEqual(t, 1, len(errs)) - utils.AssertEqual(t, `unsupported protocol "ftp". http and https are supported`, - errs[0].Error()) -} +// func Test_Client_Unsupported_Protocol(t *testing.T) { +// t.Parallel() + +// a := Get("ftp://example.com") + +// _, body, errs := a.String() -func Test_Client_Get(t *testing.T) { +// utils.AssertEqual(t, "", body) +// utils.AssertEqual(t, 1, len(errs)) +// utils.AssertEqual(t, `unsupported protocol "ftp". http and https are supported`, +// errs[0].Error()) +// } + +func TestGet(t *testing.T) { t.Parallel() ln := fasthttputil.NewInmemoryListener() app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - - for i := 0; i < 5; i++ { - a := Get("http://example.com") + go func() { + utils.AssertEqual(t, nil, app.Listener(ln)) + }() - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() + t.Run("global get function", func(t *testing.T) { + C().SetDial(func(addr string) (net.Conn, error) { + return ln.Dial() + }) - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, "example.com", body) - utils.AssertEqual(t, 0, len(errs)) - } + resp, err := Get("http://example.com") + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "example.com", utils.UnsafeString(resp.rawResponse.Body())) + }) } -func Test_Client_Head(t *testing.T) { - t.Parallel() +// func Test_Client_Get(t *testing.T) { +// t.Parallel() - ln := fasthttputil.NewInmemoryListener() +// ln := fasthttputil.NewInmemoryListener() - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c fiber.Ctx) error { - return c.SendString(c.Hostname()) - }) +// app.Get("/", func(c fiber.Ctx) error { +// return c.SendString(c.Hostname()) +// }) - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - for i := 0; i < 5; i++ { - a := Head("http://example.com") +// for i := 0; i < 5; i++ { +// a := Get("http://example.com") - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - code, body, errs := a.String() +// code, body, errs := a.String() - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, "", body) - utils.AssertEqual(t, 0, len(errs)) - } -} +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "example.com", body) +// utils.AssertEqual(t, 0, len(errs)) +// } +// } -func Test_Client_Post(t *testing.T) { - t.Parallel() +// func Test_Client_Head(t *testing.T) { +// t.Parallel() - ln := fasthttputil.NewInmemoryListener() +// ln := fasthttputil.NewInmemoryListener() - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Post("/", func(c fiber.Ctx) error { - return c.Status(fiber.StatusCreated). - SendString(c.FormValue("foo")) - }) +// app.Get("/", func(c fiber.Ctx) error { +// return c.SendString(c.Hostname()) +// }) - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - for i := 0; i < 5; i++ { - args := AcquireArgs() +// for i := 0; i < 5; i++ { +// a := Head("http://example.com") - args.Set("foo", "bar") +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - a := Post("http://example.com"). - Form(args) +// code, body, errs := a.String() - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "", body) +// utils.AssertEqual(t, 0, len(errs)) +// } +// } - code, body, errs := a.String() +// func Test_Client_Post(t *testing.T) { +// t.Parallel() - utils.AssertEqual(t, fiber.StatusCreated, code) - utils.AssertEqual(t, "bar", body) - utils.AssertEqual(t, 0, len(errs)) +// ln := fasthttputil.NewInmemoryListener() - ReleaseArgs(args) - } -} +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) -func Test_Client_Put(t *testing.T) { - t.Parallel() +// app.Post("/", func(c fiber.Ctx) error { +// return c.Status(fiber.StatusCreated). +// SendString(c.FormValue("foo")) +// }) - ln := fasthttputil.NewInmemoryListener() +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// for i := 0; i < 5; i++ { +// args := AcquireArgs() - app.Put("/", func(c fiber.Ctx) error { - return c.SendString(c.FormValue("foo")) - }) +// args.Set("foo", "bar") - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// a := Post("http://example.com"). +// Form(args) - for i := 0; i < 5; i++ { - args := AcquireArgs() +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - args.Set("foo", "bar") +// code, body, errs := a.String() - a := Put("http://example.com"). - Form(args) +// utils.AssertEqual(t, fiber.StatusCreated, code) +// utils.AssertEqual(t, "bar", body) +// utils.AssertEqual(t, 0, len(errs)) - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// ReleaseArgs(args) +// } +// } - code, body, errs := a.String() +// func Test_Client_Put(t *testing.T) { +// t.Parallel() - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, "bar", body) - utils.AssertEqual(t, 0, len(errs)) +// ln := fasthttputil.NewInmemoryListener() - ReleaseArgs(args) - } -} +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) -func Test_Client_Patch(t *testing.T) { - t.Parallel() +// app.Put("/", func(c fiber.Ctx) error { +// return c.SendString(c.FormValue("foo")) +// }) - ln := fasthttputil.NewInmemoryListener() +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// for i := 0; i < 5; i++ { +// args := AcquireArgs() - app.Patch("/", func(c fiber.Ctx) error { - return c.SendString(c.FormValue("foo")) - }) +// args.Set("foo", "bar") - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// a := Put("http://example.com"). +// Form(args) - for i := 0; i < 5; i++ { - args := AcquireArgs() +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - args.Set("foo", "bar") +// code, body, errs := a.String() - a := Patch("http://example.com"). - Form(args) +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "bar", body) +// utils.AssertEqual(t, 0, len(errs)) - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// ReleaseArgs(args) +// } +// } - code, body, errs := a.String() +// func Test_Client_Patch(t *testing.T) { +// t.Parallel() - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, "bar", body) - utils.AssertEqual(t, 0, len(errs)) +// ln := fasthttputil.NewInmemoryListener() - ReleaseArgs(args) - } -} +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) -func Test_Client_Delete(t *testing.T) { - t.Parallel() +// app.Patch("/", func(c fiber.Ctx) error { +// return c.SendString(c.FormValue("foo")) +// }) - ln := fasthttputil.NewInmemoryListener() +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// for i := 0; i < 5; i++ { +// args := AcquireArgs() - app.Delete("/", func(c fiber.Ctx) error { - return c.Status(fiber.StatusNoContent). - SendString("deleted") - }) +// args.Set("foo", "bar") - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// a := Patch("http://example.com"). +// Form(args) - for i := 0; i < 5; i++ { - args := AcquireArgs() +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - a := Delete("http://example.com") +// code, body, errs := a.String() - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "bar", body) +// utils.AssertEqual(t, 0, len(errs)) - code, body, errs := a.String() +// ReleaseArgs(args) +// } +// } - utils.AssertEqual(t, fiber.StatusNoContent, code) - utils.AssertEqual(t, "", body) - utils.AssertEqual(t, 0, len(errs)) +// func Test_Client_Delete(t *testing.T) { +// t.Parallel() - ReleaseArgs(args) - } -} +// ln := fasthttputil.NewInmemoryListener() -func Test_Client_UserAgent(t *testing.T) { - t.Parallel() +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - ln := fasthttputil.NewInmemoryListener() +// app.Delete("/", func(c fiber.Ctx) error { +// return c.Status(fiber.StatusNoContent). +// SendString("deleted") +// }) - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - app.Get("/", func(c fiber.Ctx) error { - return c.Send(c.Request().Header.UserAgent()) - }) +// for i := 0; i < 5; i++ { +// args := AcquireArgs() - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// a := Delete("http://example.com") - t.Run("default", func(t *testing.T) { - for i := 0; i < 5; i++ { - a := Get("http://example.com") +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// code, body, errs := a.String() - code, body, errs := a.String() +// utils.AssertEqual(t, fiber.StatusNoContent, code) +// utils.AssertEqual(t, "", body) +// utils.AssertEqual(t, 0, len(errs)) - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, defaultUserAgent, body) - utils.AssertEqual(t, 0, len(errs)) - } - }) +// ReleaseArgs(args) +// } +// } - t.Run("custom", func(t *testing.T) { - for i := 0; i < 5; i++ { - c := AcquireClient() - c.UserAgent = "ua" +// func Test_Client_UserAgent(t *testing.T) { +// t.Parallel() - a := c.Get("http://example.com") +// ln := fasthttputil.NewInmemoryListener() - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - code, body, errs := a.String() +// app.Get("/", func(c fiber.Ctx) error { +// return c.Send(c.Request().Header.UserAgent()) +// }) - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, "ua", body) - utils.AssertEqual(t, 0, len(errs)) - ReleaseClient(c) - } - }) -} +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() -func Test_Client_Agent_Set_Or_Add_Headers(t *testing.T) { - handler := func(c fiber.Ctx) error { - c.Request().Header.VisitAll(func(key, value []byte) { - if k := string(key); k == "K1" || k == "K2" { - _, _ = c.Write(key) - _, _ = c.Write(value) - } - }) - return nil - } - - wrapAgent := func(a *Agent) { - a.Set("k1", "v1"). - SetBytesK([]byte("k1"), "v1"). - SetBytesV("k1", []byte("v1")). - AddBytesK([]byte("k1"), "v11"). - AddBytesV("k1", []byte("v22")). - AddBytesKV([]byte("k1"), []byte("v33")). - SetBytesKV([]byte("k2"), []byte("v2")). - Add("k2", "v22") - } - - testAgent(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") -} +// t.Run("default", func(t *testing.T) { +// for i := 0; i < 5; i++ { +// a := Get("http://example.com") -func Test_Client_Agent_Connection_Close(t *testing.T) { - handler := func(c fiber.Ctx) error { - if c.Request().Header.ConnectionClose() { - return c.SendString("close") - } - return c.SendString("not close") - } +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - wrapAgent := func(a *Agent) { - a.ConnectionClose() - } +// code, body, errs := a.String() - testAgent(t, handler, wrapAgent, "close") -} +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, defaultUserAgent, body) +// utils.AssertEqual(t, 0, len(errs)) +// } +// }) -func Test_Client_Agent_UserAgent(t *testing.T) { - handler := func(c fiber.Ctx) error { - return c.Send(c.Request().Header.UserAgent()) - } +// t.Run("custom", func(t *testing.T) { +// for i := 0; i < 5; i++ { +// c := AcquireClient() +// c.UserAgent = "ua" - wrapAgent := func(a *Agent) { - a.UserAgent("ua"). - UserAgentBytes([]byte("ua")) - } +// a := c.Get("http://example.com") - testAgent(t, handler, wrapAgent, "ua") -} +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } -func Test_Client_Agent_Cookie(t *testing.T) { - handler := func(c fiber.Ctx) error { - return c.SendString( - c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) - } - - wrapAgent := func(a *Agent) { - a.Cookie("k1", "v1"). - CookieBytesK([]byte("k2"), "v2"). - CookieBytesKV([]byte("k2"), []byte("v2")). - Cookies("k3", "v3", "k4", "v4"). - CookiesBytesKV([]byte("k3"), []byte("v3"), []byte("k4"), []byte("v4")) - } - - testAgent(t, handler, wrapAgent, "v1v2v3v4") -} +// code, body, errs := a.String() -func Test_Client_Agent_Referer(t *testing.T) { - handler := func(c fiber.Ctx) error { - return c.Send(c.Request().Header.Referer()) - } +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "ua", body) +// utils.AssertEqual(t, 0, len(errs)) +// ReleaseClient(c) +// } +// }) +// } - wrapAgent := func(a *Agent) { - a.Referer("http://referer.com"). - RefererBytes([]byte("http://referer.com")) - } +// func Test_Client_Agent_Set_Or_Add_Headers(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// c.Request().Header.VisitAll(func(key, value []byte) { +// if k := string(key); k == "K1" || k == "K2" { +// _, _ = c.Write(key) +// _, _ = c.Write(value) +// } +// }) +// return nil +// } - testAgent(t, handler, wrapAgent, "http://referer.com") -} +// wrapAgent := func(a *Agent) { +// a.Set("k1", "v1"). +// SetBytesK([]byte("k1"), "v1"). +// SetBytesV("k1", []byte("v1")). +// AddBytesK([]byte("k1"), "v11"). +// AddBytesV("k1", []byte("v22")). +// AddBytesKV([]byte("k1"), []byte("v33")). +// SetBytesKV([]byte("k2"), []byte("v2")). +// Add("k2", "v22") +// } -func Test_Client_Agent_ContentType(t *testing.T) { - handler := func(c fiber.Ctx) error { - return c.Send(c.Request().Header.ContentType()) - } +// testAgent(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") +// } - wrapAgent := func(a *Agent) { - a.ContentType("custom-type"). - ContentTypeBytes([]byte("custom-type")) - } +// func Test_Client_Agent_Connection_Close(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// if c.Request().Header.ConnectionClose() { +// return c.SendString("close") +// } +// return c.SendString("not close") +// } - testAgent(t, handler, wrapAgent, "custom-type") -} +// wrapAgent := func(a *Agent) { +// a.ConnectionClose() +// } -func Test_Client_Agent_Host(t *testing.T) { - t.Parallel() +// testAgent(t, handler, wrapAgent, "close") +// } - ln := fasthttputil.NewInmemoryListener() +// func Test_Client_Agent_UserAgent(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.Send(c.Request().Header.UserAgent()) +// } + +// wrapAgent := func(a *Agent) { +// a.UserAgent("ua"). +// UserAgentBytes([]byte("ua")) +// } + +// testAgent(t, handler, wrapAgent, "ua") +// } + +// func Test_Client_Agent_Cookie(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.SendString( +// c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) +// } + +// wrapAgent := func(a *Agent) { +// a.Cookie("k1", "v1"). +// CookieBytesK([]byte("k2"), "v2"). +// CookieBytesKV([]byte("k2"), []byte("v2")). +// Cookies("k3", "v3", "k4", "v4"). +// CookiesBytesKV([]byte("k3"), []byte("v3"), []byte("k4"), []byte("v4")) +// } + +// testAgent(t, handler, wrapAgent, "v1v2v3v4") +// } + +// func Test_Client_Agent_Referer(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.Send(c.Request().Header.Referer()) +// } + +// wrapAgent := func(a *Agent) { +// a.Referer("http://referer.com"). +// RefererBytes([]byte("http://referer.com")) +// } + +// testAgent(t, handler, wrapAgent, "http://referer.com") +// } + +// func Test_Client_Agent_ContentType(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.Send(c.Request().Header.ContentType()) +// } + +// wrapAgent := func(a *Agent) { +// a.ContentType("custom-type"). +// ContentTypeBytes([]byte("custom-type")) +// } + +// testAgent(t, handler, wrapAgent, "custom-type") +// } + +// func Test_Client_Agent_Host(t *testing.T) { +// t.Parallel() + +// ln := fasthttputil.NewInmemoryListener() + +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + +// app.Get("/", func(c fiber.Ctx) error { +// return c.SendString(c.Hostname()) +// }) + +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// a := Get("http://1.1.1.1:8080"). +// Host("example.com"). +// HostBytes([]byte("example.com")) - app.Get("/", func(c fiber.Ctx) error { - return c.SendString(c.Hostname()) - }) +// utils.AssertEqual(t, "1.1.1.1:8080", a.HostClient.Addr) - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - a := Get("http://1.1.1.1:8080"). - Host("example.com"). - HostBytes([]byte("example.com")) +// code, body, errs := a.String() - utils.AssertEqual(t, "1.1.1.1:8080", a.HostClient.Addr) +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "example.com", body) +// utils.AssertEqual(t, 0, len(errs)) +// } - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// func Test_Client_Agent_QueryString(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.Send(c.Request().URI().QueryString()) +// } - code, body, errs := a.String() +// wrapAgent := func(a *Agent) { +// a.QueryString("foo=bar&bar=baz"). +// QueryStringBytes([]byte("foo=bar&bar=baz")) +// } - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, "example.com", body) - utils.AssertEqual(t, 0, len(errs)) -} +// testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") +// } -func Test_Client_Agent_QueryString(t *testing.T) { - handler := func(c fiber.Ctx) error { - return c.Send(c.Request().URI().QueryString()) - } +// func Test_Client_Agent_BasicAuth(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// // Get authorization header +// auth := c.Get(fiber.HeaderAuthorization) +// // Decode the header contents +// raw, err := base64.StdEncoding.DecodeString(auth[6:]) +// utils.AssertEqual(t, nil, err) - wrapAgent := func(a *Agent) { - a.QueryString("foo=bar&bar=baz"). - QueryStringBytes([]byte("foo=bar&bar=baz")) - } +// return c.Send(raw) +// } - testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") -} +// wrapAgent := func(a *Agent) { +// a.BasicAuth("foo", "bar"). +// BasicAuthBytes([]byte("foo"), []byte("bar")) +// } -func Test_Client_Agent_BasicAuth(t *testing.T) { - handler := func(c fiber.Ctx) error { - // Get authorization header - auth := c.Get(fiber.HeaderAuthorization) - // Decode the header contents - raw, err := base64.StdEncoding.DecodeString(auth[6:]) - utils.AssertEqual(t, nil, err) +// testAgent(t, handler, wrapAgent, "foo:bar") +// } - return c.Send(raw) - } +// func Test_Client_Agent_BodyString(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.Send(c.Request().Body()) +// } - wrapAgent := func(a *Agent) { - a.BasicAuth("foo", "bar"). - BasicAuthBytes([]byte("foo"), []byte("bar")) - } +// wrapAgent := func(a *Agent) { +// a.BodyString("foo=bar&bar=baz") +// } - testAgent(t, handler, wrapAgent, "foo:bar") -} +// testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") +// } -func Test_Client_Agent_BodyString(t *testing.T) { - handler := func(c fiber.Ctx) error { - return c.Send(c.Request().Body()) - } +// func Test_Client_Agent_Body(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.Send(c.Request().Body()) +// } - wrapAgent := func(a *Agent) { - a.BodyString("foo=bar&bar=baz") - } +// wrapAgent := func(a *Agent) { +// a.Body([]byte("foo=bar&bar=baz")) +// } - testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") -} +// testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") +// } -func Test_Client_Agent_Body(t *testing.T) { - handler := func(c fiber.Ctx) error { - return c.Send(c.Request().Body()) - } +// func Test_Client_Agent_BodyStream(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.Send(c.Request().Body()) +// } - wrapAgent := func(a *Agent) { - a.Body([]byte("foo=bar&bar=baz")) - } +// wrapAgent := func(a *Agent) { +// a.BodyStream(strings.NewReader("body stream"), -1) +// } - testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") -} +// testAgent(t, handler, wrapAgent, "body stream") +// } -func Test_Client_Agent_BodyStream(t *testing.T) { - handler := func(c fiber.Ctx) error { - return c.Send(c.Request().Body()) - } +// func Test_Client_Agent_Custom_Response(t *testing.T) { +// t.Parallel() - wrapAgent := func(a *Agent) { - a.BodyStream(strings.NewReader("body stream"), -1) - } +// ln := fasthttputil.NewInmemoryListener() - testAgent(t, handler, wrapAgent, "body stream") -} +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) -func Test_Client_Agent_Custom_Response(t *testing.T) { - t.Parallel() +// app.Get("/", func(c fiber.Ctx) error { +// return c.SendString("custom") +// }) - ln := fasthttputil.NewInmemoryListener() +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// for i := 0; i < 5; i++ { +// a := AcquireAgent() +// resp := AcquireResponse() - app.Get("/", func(c fiber.Ctx) error { - return c.SendString("custom") - }) +// req := a.Request() +// req.Header.SetMethod(fiber.MethodGet) +// req.SetRequestURI("http://example.com") - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// utils.AssertEqual(t, nil, a.Parse()) - for i := 0; i < 5; i++ { - a := AcquireAgent() - resp := AcquireResponse() +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - req := a.Request() - req.Header.SetMethod(fiber.MethodGet) - req.SetRequestURI("http://example.com") +// code, body, errs := a.SetResponse(resp). +// String() - utils.AssertEqual(t, nil, a.Parse()) +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "custom", body) +// utils.AssertEqual(t, "custom", string(resp.Body())) +// utils.AssertEqual(t, 0, len(errs)) - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// ReleaseResponse(resp) +// } +// } - code, body, errs := a.SetResponse(resp). - String() +// func Test_Client_Agent_Dest(t *testing.T) { +// t.Parallel() - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, "custom", body) - utils.AssertEqual(t, "custom", string(resp.Body())) - utils.AssertEqual(t, 0, len(errs)) +// ln := fasthttputil.NewInmemoryListener() - ReleaseResponse(resp) - } -} +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) -func Test_Client_Agent_Dest(t *testing.T) { - t.Parallel() +// app.Get("/", func(c fiber.Ctx) error { +// return c.SendString("dest") +// }) - ln := fasthttputil.NewInmemoryListener() +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// t.Run("small dest", func(t *testing.T) { +// dest := []byte("de") - app.Get("/", func(c fiber.Ctx) error { - return c.SendString("dest") - }) +// a := Get("http://example.com") - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - t.Run("small dest", func(t *testing.T) { - dest := []byte("de") +// code, body, errs := a.Dest(dest[:0]).String() - a := Get("http://example.com") +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "dest", body) +// utils.AssertEqual(t, "de", string(dest)) +// utils.AssertEqual(t, 0, len(errs)) +// }) - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// t.Run("enough dest", func(t *testing.T) { +// dest := []byte("foobar") - code, body, errs := a.Dest(dest[:0]).String() +// a := Get("http://example.com") - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, "dest", body) - utils.AssertEqual(t, "de", string(dest)) - utils.AssertEqual(t, 0, len(errs)) - }) +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - t.Run("enough dest", func(t *testing.T) { - dest := []byte("foobar") +// code, body, errs := a.Dest(dest[:0]).String() - a := Get("http://example.com") +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "dest", body) +// utils.AssertEqual(t, "destar", string(dest)) +// utils.AssertEqual(t, 0, len(errs)) +// }) +// } - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// // readErrorConn is a struct for testing retryIf +// type readErrorConn struct { +// net.Conn +// } - code, body, errs := a.Dest(dest[:0]).String() +// func (r *readErrorConn) Read(p []byte) (int, error) { +// return 0, fmt.Errorf("error") +// } - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, "dest", body) - utils.AssertEqual(t, "destar", string(dest)) - utils.AssertEqual(t, 0, len(errs)) - }) -} +// func (r *readErrorConn) Write(p []byte) (int, error) { +// return len(p), nil +// } -// readErrorConn is a struct for testing retryIf -type readErrorConn struct { - net.Conn -} +// func (r *readErrorConn) Close() error { +// return nil +// } -func (r *readErrorConn) Read(p []byte) (int, error) { - return 0, fmt.Errorf("error") -} +// func (r *readErrorConn) LocalAddr() net.Addr { +// return nil +// } -func (r *readErrorConn) Write(p []byte) (int, error) { - return len(p), nil -} +// func (r *readErrorConn) RemoteAddr() net.Addr { +// return nil +// } +// func Test_Client_Agent_RetryIf(t *testing.T) { +// t.Parallel() -func (r *readErrorConn) Close() error { - return nil -} +// ln := fasthttputil.NewInmemoryListener() -func (r *readErrorConn) LocalAddr() net.Addr { - return nil -} +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) -func (r *readErrorConn) RemoteAddr() net.Addr { - return nil -} -func Test_Client_Agent_RetryIf(t *testing.T) { - t.Parallel() +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - ln := fasthttputil.NewInmemoryListener() +// a := Post("http://example.com"). +// RetryIf(func(req *Request) bool { +// return true +// }) +// dialsCount := 0 +// a.HostClient.Dial = func(addr string) (net.Conn, error) { +// dialsCount++ +// switch dialsCount { +// case 1: +// return &readErrorConn{}, nil +// case 2: +// return &readErrorConn{}, nil +// case 3: +// return &readErrorConn{}, nil +// case 4: +// return ln.Dial() +// default: +// t.Fatalf("unexpected number of dials: %d", dialsCount) +// } +// panic("unreachable") +// } - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// _, _, errs := a.String() +// utils.AssertEqual(t, dialsCount, 4) +// utils.AssertEqual(t, 0, len(errs)) +// } - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// func Test_Client_Agent_Json(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// utils.AssertEqual(t, fiber.MIMEApplicationJSON, string(c.Request().Header.ContentType())) - a := Post("http://example.com"). - RetryIf(func(req *Request) bool { - return true - }) - dialsCount := 0 - a.HostClient.Dial = func(addr string) (net.Conn, error) { - dialsCount++ - switch dialsCount { - case 1: - return &readErrorConn{}, nil - case 2: - return &readErrorConn{}, nil - case 3: - return &readErrorConn{}, nil - case 4: - return ln.Dial() - default: - t.Fatalf("unexpected number of dials: %d", dialsCount) - } - panic("unreachable") - } - - _, _, errs := a.String() - utils.AssertEqual(t, dialsCount, 4) - utils.AssertEqual(t, 0, len(errs)) -} +// return c.Send(c.Request().Body()) +// } -func Test_Client_Agent_Json(t *testing.T) { - handler := func(c fiber.Ctx) error { - utils.AssertEqual(t, fiber.MIMEApplicationJSON, string(c.Request().Header.ContentType())) +// wrapAgent := func(a *Agent) { +// a.JSON(data{Success: true}) +// } - return c.Send(c.Request().Body()) - } +// testAgent(t, handler, wrapAgent, `{"success":true}`) +// } - wrapAgent := func(a *Agent) { - a.JSON(data{Success: true}) - } +// func Test_Client_Agent_Json_Error(t *testing.T) { +// a := Get("http://example.com"). +// JSONEncoder(json.Marshal). +// JSON(complex(1, 1)) - testAgent(t, handler, wrapAgent, `{"success":true}`) -} +// _, body, errs := a.String() -func Test_Client_Agent_Json_Error(t *testing.T) { - a := Get("http://example.com"). - JSONEncoder(json.Marshal). - JSON(complex(1, 1)) +// utils.AssertEqual(t, "", body) +// utils.AssertEqual(t, 1, len(errs)) +// utils.AssertEqual(t, "json: unsupported type: complex128", errs[0].Error()) +// } - _, body, errs := a.String() +// func Test_Client_Agent_XML(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// utils.AssertEqual(t, fiber.MIMEApplicationXML, string(c.Request().Header.ContentType())) - utils.AssertEqual(t, "", body) - utils.AssertEqual(t, 1, len(errs)) - utils.AssertEqual(t, "json: unsupported type: complex128", errs[0].Error()) -} +// return c.Send(c.Request().Body()) +// } -func Test_Client_Agent_XML(t *testing.T) { - handler := func(c fiber.Ctx) error { - utils.AssertEqual(t, fiber.MIMEApplicationXML, string(c.Request().Header.ContentType())) +// wrapAgent := func(a *Agent) { +// a.XML(data{Success: true}) +// } - return c.Send(c.Request().Body()) - } +// testAgent(t, handler, wrapAgent, "true") +// } - wrapAgent := func(a *Agent) { - a.XML(data{Success: true}) - } +// func Test_Client_Agent_XML_Error(t *testing.T) { +// a := Get("http://example.com"). +// XML(complex(1, 1)) - testAgent(t, handler, wrapAgent, "true") -} +// _, body, errs := a.String() -func Test_Client_Agent_XML_Error(t *testing.T) { - a := Get("http://example.com"). - XML(complex(1, 1)) +// utils.AssertEqual(t, "", body) +// utils.AssertEqual(t, 1, len(errs)) +// utils.AssertEqual(t, "xml: unsupported type: complex128", errs[0].Error()) +// } - _, body, errs := a.String() +// func Test_Client_Agent_Form(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// utils.AssertEqual(t, fiber.MIMEApplicationForm, string(c.Request().Header.ContentType())) - utils.AssertEqual(t, "", body) - utils.AssertEqual(t, 1, len(errs)) - utils.AssertEqual(t, "xml: unsupported type: complex128", errs[0].Error()) -} +// return c.Send(c.Request().Body()) +// } -func Test_Client_Agent_Form(t *testing.T) { - handler := func(c fiber.Ctx) error { - utils.AssertEqual(t, fiber.MIMEApplicationForm, string(c.Request().Header.ContentType())) +// args := AcquireArgs() - return c.Send(c.Request().Body()) - } +// args.Set("foo", "bar") - args := AcquireArgs() +// wrapAgent := func(a *Agent) { +// a.Form(args) +// } - args.Set("foo", "bar") +// testAgent(t, handler, wrapAgent, "foo=bar") - wrapAgent := func(a *Agent) { - a.Form(args) - } +// ReleaseArgs(args) +// } - testAgent(t, handler, wrapAgent, "foo=bar") +// func Test_Client_Agent_MultipartForm(t *testing.T) { +// t.Parallel() - ReleaseArgs(args) -} +// ln := fasthttputil.NewInmemoryListener() -func Test_Client_Agent_MultipartForm(t *testing.T) { - t.Parallel() +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - ln := fasthttputil.NewInmemoryListener() +// app.Post("/", func(c fiber.Ctx) error { +// utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// mf, err := c.MultipartForm() +// utils.AssertEqual(t, nil, err) +// utils.AssertEqual(t, "bar", mf.Value["foo"][0]) - app.Post("/", func(c fiber.Ctx) error { - utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) +// return c.Send(c.Request().Body()) +// }) - mf, err := c.MultipartForm() - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "bar", mf.Value["foo"][0]) +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - return c.Send(c.Request().Body()) - }) +// args := AcquireArgs() - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// args.Set("foo", "bar") - args := AcquireArgs() +// a := Post("http://example.com"). +// Boundary("myBoundary"). +// MultipartForm(args) - args.Set("foo", "bar") +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - a := Post("http://example.com"). - Boundary("myBoundary"). - MultipartForm(args) +// code, body, errs := a.String() - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "--myBoundary\r\nContent-Disposition: form-data; name=\"foo\"\r\n\r\nbar\r\n--myBoundary--\r\n", body) +// utils.AssertEqual(t, 0, len(errs)) +// ReleaseArgs(args) +// } - code, body, errs := a.String() +// func Test_Client_Agent_MultipartForm_Errors(t *testing.T) { +// t.Parallel() - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, "--myBoundary\r\nContent-Disposition: form-data; name=\"foo\"\r\n\r\nbar\r\n--myBoundary--\r\n", body) - utils.AssertEqual(t, 0, len(errs)) - ReleaseArgs(args) -} - -func Test_Client_Agent_MultipartForm_Errors(t *testing.T) { - t.Parallel() +// a := AcquireAgent() +// a.mw = &errorMultipartWriter{} - a := AcquireAgent() - a.mw = &errorMultipartWriter{} +// args := AcquireArgs() +// args.Set("foo", "bar") - args := AcquireArgs() - args.Set("foo", "bar") +// ff1 := &FormFile{"", "name1", []byte("content"), false} +// ff2 := &FormFile{"", "name2", []byte("content"), false} +// a.FileData(ff1, ff2). +// MultipartForm(args) - ff1 := &FormFile{"", "name1", []byte("content"), false} - ff2 := &FormFile{"", "name2", []byte("content"), false} - a.FileData(ff1, ff2). - MultipartForm(args) - - utils.AssertEqual(t, 4, len(a.errs)) - ReleaseArgs(args) -} +// utils.AssertEqual(t, 4, len(a.errs)) +// ReleaseArgs(args) +// } -func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) { - t.Parallel() +// func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) { +// t.Parallel() - ln := fasthttputil.NewInmemoryListener() +// ln := fasthttputil.NewInmemoryListener() - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Post("/", func(c fiber.Ctx) error { - utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) +// app.Post("/", func(c fiber.Ctx) error { +// utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) - fh1, err := c.FormFile("field1") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fh1.Filename, "name") - buf := make([]byte, fh1.Size) - f, err := fh1.Open() - utils.AssertEqual(t, nil, err) - defer func() { _ = f.Close() }() - _, err = f.Read(buf) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "form file", string(buf)) +// fh1, err := c.FormFile("field1") +// utils.AssertEqual(t, nil, err) +// utils.AssertEqual(t, fh1.Filename, "name") +// buf := make([]byte, fh1.Size) +// f, err := fh1.Open() +// utils.AssertEqual(t, nil, err) +// defer func() { _ = f.Close() }() +// _, err = f.Read(buf) +// utils.AssertEqual(t, nil, err) +// utils.AssertEqual(t, "form file", string(buf)) - fh2, err := c.FormFile("index") - utils.AssertEqual(t, nil, err) - checkFormFile(t, fh2, ".github/testdata/index.html") +// fh2, err := c.FormFile("index") +// utils.AssertEqual(t, nil, err) +// checkFormFile(t, fh2, ".github/testdata/index.html") - fh3, err := c.FormFile("file3") - utils.AssertEqual(t, nil, err) - checkFormFile(t, fh3, ".github/testdata/index.tmpl") +// fh3, err := c.FormFile("file3") +// utils.AssertEqual(t, nil, err) +// checkFormFile(t, fh3, ".github/testdata/index.tmpl") - return c.SendString("multipart form files") - }) +// return c.SendString("multipart form files") +// }) - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - for i := 0; i < 5; i++ { - ff := AcquireFormFile() - ff.Fieldname = "field1" - ff.Name = "name" - ff.Content = []byte("form file") +// for i := 0; i < 5; i++ { +// ff := AcquireFormFile() +// ff.Fieldname = "field1" +// ff.Name = "name" +// ff.Content = []byte("form file") - a := Post("http://example.com"). - Boundary("myBoundary"). - FileData(ff). - SendFiles(".github/testdata/index.html", "index", ".github/testdata/index.tmpl"). - MultipartForm(nil) +// a := Post("http://example.com"). +// Boundary("myBoundary"). +// FileData(ff). +// SendFiles(".github/testdata/index.html", "index", ".github/testdata/index.tmpl"). +// MultipartForm(nil) - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - code, body, errs := a.String() +// code, body, errs := a.String() - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, "multipart form files", body) - utils.AssertEqual(t, 0, len(errs)) +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "multipart form files", body) +// utils.AssertEqual(t, 0, len(errs)) - ReleaseFormFile(ff) - } -} +// ReleaseFormFile(ff) +// } +// } -func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) { - t.Helper() +// func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) { +// t.Helper() - basename := filepath.Base(filename) - utils.AssertEqual(t, fh.Filename, basename) +// basename := filepath.Base(filename) +// utils.AssertEqual(t, fh.Filename, basename) - b1, err := os.ReadFile(filename) - utils.AssertEqual(t, nil, err) +// b1, err := os.ReadFile(filename) +// utils.AssertEqual(t, nil, err) - b2 := make([]byte, fh.Size) - f, err := fh.Open() - utils.AssertEqual(t, nil, err) - defer func() { _ = f.Close() }() - _, err = f.Read(b2) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, b1, b2) -} +// b2 := make([]byte, fh.Size) +// f, err := fh.Open() +// utils.AssertEqual(t, nil, err) +// defer func() { _ = f.Close() }() +// _, err = f.Read(b2) +// utils.AssertEqual(t, nil, err) +// utils.AssertEqual(t, b1, b2) +// } -func Test_Client_Agent_Multipart_Random_Boundary(t *testing.T) { - t.Parallel() +// func Test_Client_Agent_Multipart_Random_Boundary(t *testing.T) { +// t.Parallel() - a := Post("http://example.com"). - MultipartForm(nil) +// a := Post("http://example.com"). +// MultipartForm(nil) - reg := regexp.MustCompile(`multipart/form-data; boundary=\w{30}`) +// reg := regexp.MustCompile(`multipart/form-data; boundary=\w{30}`) - utils.AssertEqual(t, true, reg.Match(a.req.Header.Peek(fiber.HeaderContentType))) -} +// utils.AssertEqual(t, true, reg.Match(a.req.Header.Peek(fiber.HeaderContentType))) +// } -func Test_Client_Agent_Multipart_Invalid_Boundary(t *testing.T) { - t.Parallel() +// func Test_Client_Agent_Multipart_Invalid_Boundary(t *testing.T) { +// t.Parallel() - a := Post("http://example.com"). - Boundary("*"). - MultipartForm(nil) +// a := Post("http://example.com"). +// Boundary("*"). +// MultipartForm(nil) - utils.AssertEqual(t, 1, len(a.errs)) - utils.AssertEqual(t, "mime: invalid boundary character", a.errs[0].Error()) -} +// utils.AssertEqual(t, 1, len(a.errs)) +// utils.AssertEqual(t, "mime: invalid boundary character", a.errs[0].Error()) +// } -func Test_Client_Agent_SendFile_Error(t *testing.T) { - t.Parallel() +// func Test_Client_Agent_SendFile_Error(t *testing.T) { +// t.Parallel() - a := Post("http://example.com"). - SendFile("non-exist-file!", "") +// a := Post("http://example.com"). +// SendFile("non-exist-file!", "") - utils.AssertEqual(t, 1, len(a.errs)) - utils.AssertEqual(t, true, strings.Contains(a.errs[0].Error(), "open non-exist-file!")) -} +// utils.AssertEqual(t, 1, len(a.errs)) +// utils.AssertEqual(t, true, strings.Contains(a.errs[0].Error(), "open non-exist-file!")) +// } -func Test_Client_Debug(t *testing.T) { - handler := func(c fiber.Ctx) error { - return c.SendString("debug") - } +// func Test_Client_Debug(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.SendString("debug") +// } - var output bytes.Buffer +// var output bytes.Buffer - wrapAgent := func(a *Agent) { - a.Debug(&output) - } +// wrapAgent := func(a *Agent) { +// a.Debug(&output) +// } - testAgent(t, handler, wrapAgent, "debug", 1) +// testAgent(t, handler, wrapAgent, "debug", 1) - str := output.String() +// str := output.String() - utils.AssertEqual(t, true, strings.Contains(str, "Connected to example.com(pipe)")) - utils.AssertEqual(t, true, strings.Contains(str, "GET / HTTP/1.1")) - utils.AssertEqual(t, true, strings.Contains(str, "User-Agent: fiber")) - utils.AssertEqual(t, true, strings.Contains(str, "Host: example.com\r\n\r\n")) - utils.AssertEqual(t, true, strings.Contains(str, "HTTP/1.1 200 OK")) - utils.AssertEqual(t, true, strings.Contains(str, "Content-Type: text/plain; charset=utf-8\r\nContent-Length: 5\r\n\r\ndebug")) -} +// utils.AssertEqual(t, true, strings.Contains(str, "Connected to example.com(pipe)")) +// utils.AssertEqual(t, true, strings.Contains(str, "GET / HTTP/1.1")) +// utils.AssertEqual(t, true, strings.Contains(str, "User-Agent: fiber")) +// utils.AssertEqual(t, true, strings.Contains(str, "Host: example.com\r\n\r\n")) +// utils.AssertEqual(t, true, strings.Contains(str, "HTTP/1.1 200 OK")) +// utils.AssertEqual(t, true, strings.Contains(str, "Content-Type: text/plain; charset=utf-8\r\nContent-Length: 5\r\n\r\ndebug")) +// } -func Test_Client_Agent_Timeout(t *testing.T) { - t.Parallel() +// func Test_Client_Agent_Timeout(t *testing.T) { +// t.Parallel() - ln := fasthttputil.NewInmemoryListener() +// ln := fasthttputil.NewInmemoryListener() - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c fiber.Ctx) error { - time.Sleep(time.Millisecond * 200) - return c.SendString("timeout") - }) +// app.Get("/", func(c fiber.Ctx) error { +// time.Sleep(time.Millisecond * 200) +// return c.SendString("timeout") +// }) - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - a := Get("http://example.com"). - Timeout(time.Millisecond * 50) +// a := Get("http://example.com"). +// Timeout(time.Millisecond * 50) - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - _, body, errs := a.String() +// _, body, errs := a.String() - utils.AssertEqual(t, "", body) - utils.AssertEqual(t, 1, len(errs)) - utils.AssertEqual(t, "timeout", errs[0].Error()) -} +// utils.AssertEqual(t, "", body) +// utils.AssertEqual(t, 1, len(errs)) +// utils.AssertEqual(t, "timeout", errs[0].Error()) +// } -func Test_Client_Agent_Reuse(t *testing.T) { - t.Parallel() +// func Test_Client_Agent_Reuse(t *testing.T) { +// t.Parallel() - ln := fasthttputil.NewInmemoryListener() +// ln := fasthttputil.NewInmemoryListener() - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c fiber.Ctx) error { - return c.SendString("reuse") - }) +// app.Get("/", func(c fiber.Ctx) error { +// return c.SendString("reuse") +// }) - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - a := Get("http://example.com"). - Reuse() +// a := Get("http://example.com"). +// Reuse() - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - code, body, errs := a.String() +// code, body, errs := a.String() - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, "reuse", body) - utils.AssertEqual(t, 0, len(errs)) +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "reuse", body) +// utils.AssertEqual(t, 0, len(errs)) - code, body, errs = a.String() +// code, body, errs = a.String() - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, "reuse", body) - utils.AssertEqual(t, 0, len(errs)) -} +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "reuse", body) +// utils.AssertEqual(t, 0, len(errs)) +// } -func Test_Client_Agent_InsecureSkipVerify(t *testing.T) { - t.Parallel() +// func Test_Client_Agent_InsecureSkipVerify(t *testing.T) { +// t.Parallel() - cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key") - utils.AssertEqual(t, nil, err) +// cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key") +// utils.AssertEqual(t, nil, err) - serverTLSConf := &tls.Config{ - Certificates: []tls.Certificate{cer}, - } +// serverTLSConf := &tls.Config{ +// Certificates: []tls.Certificate{cer}, +// } - ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") - utils.AssertEqual(t, nil, err) +// ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") +// utils.AssertEqual(t, nil, err) - ln = tls.NewListener(ln, serverTLSConf) +// ln = tls.NewListener(ln, serverTLSConf) - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c fiber.Ctx) error { - return c.SendString("ignore tls") - }) +// app.Get("/", func(c fiber.Ctx) error { +// return c.SendString("ignore tls") +// }) - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - code, body, errs := Get("https://" + ln.Addr().String()). - InsecureSkipVerify(). - InsecureSkipVerify(). - String() +// code, body, errs := Get("https://" + ln.Addr().String()). +// InsecureSkipVerify(). +// InsecureSkipVerify(). +// String() - utils.AssertEqual(t, 0, len(errs)) - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, "ignore tls", body) -} +// utils.AssertEqual(t, 0, len(errs)) +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "ignore tls", body) +// } -func Test_Client_Agent_TLS(t *testing.T) { - t.Parallel() +// func Test_Client_Agent_TLS(t *testing.T) { +// t.Parallel() - serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() - utils.AssertEqual(t, nil, err) +// serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() +// utils.AssertEqual(t, nil, err) - ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") - utils.AssertEqual(t, nil, err) +// ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") +// utils.AssertEqual(t, nil, err) - ln = tls.NewListener(ln, serverTLSConf) +// ln = tls.NewListener(ln, serverTLSConf) - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c fiber.Ctx) error { - return c.SendString("tls") - }) +// app.Get("/", func(c fiber.Ctx) error { +// return c.SendString("tls") +// }) - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - code, body, errs := Get("https://" + ln.Addr().String()). - TLSConfig(clientTLSConf). - String() +// code, body, errs := Get("https://" + ln.Addr().String()). +// TLSConfig(clientTLSConf). +// String() - utils.AssertEqual(t, 0, len(errs)) - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, "tls", body) -} +// utils.AssertEqual(t, 0, len(errs)) +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "tls", body) +// } -func Test_Client_Agent_MaxRedirectsCount(t *testing.T) { - t.Parallel() +// func Test_Client_Agent_MaxRedirectsCount(t *testing.T) { +// t.Parallel() - ln := fasthttputil.NewInmemoryListener() +// ln := fasthttputil.NewInmemoryListener() - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c fiber.Ctx) error { - if c.Request().URI().QueryArgs().Has("foo") { - return c.Redirect("/foo") - } - return c.Redirect("/") - }) - app.Get("/foo", func(c fiber.Ctx) error { - return c.SendString("redirect") - }) +// app.Get("/", func(c fiber.Ctx) error { +// if c.Request().URI().QueryArgs().Has("foo") { +// return c.Redirect("/foo") +// } +// return c.Redirect("/") +// }) +// app.Get("/foo", func(c fiber.Ctx) error { +// return c.SendString("redirect") +// }) - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - t.Run("success", func(t *testing.T) { - a := Get("http://example.com?foo"). - MaxRedirectsCount(1) +// t.Run("success", func(t *testing.T) { +// a := Get("http://example.com?foo"). +// MaxRedirectsCount(1) - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - code, body, errs := a.String() +// code, body, errs := a.String() - utils.AssertEqual(t, 200, code) - utils.AssertEqual(t, "redirect", body) - utils.AssertEqual(t, 0, len(errs)) - }) +// utils.AssertEqual(t, 200, code) +// utils.AssertEqual(t, "redirect", body) +// utils.AssertEqual(t, 0, len(errs)) +// }) - t.Run("error", func(t *testing.T) { - a := Get("http://example.com"). - MaxRedirectsCount(1) +// t.Run("error", func(t *testing.T) { +// a := Get("http://example.com"). +// MaxRedirectsCount(1) - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - _, body, errs := a.String() +// _, body, errs := a.String() - utils.AssertEqual(t, "", body) - utils.AssertEqual(t, 1, len(errs)) - utils.AssertEqual(t, "too many redirects detected when doing the request", errs[0].Error()) - }) -} +// utils.AssertEqual(t, "", body) +// utils.AssertEqual(t, 1, len(errs)) +// utils.AssertEqual(t, "too many redirects detected when doing the request", errs[0].Error()) +// }) +// } -func Test_Client_Agent_Struct(t *testing.T) { - t.Parallel() +// func Test_Client_Agent_Struct(t *testing.T) { +// t.Parallel() - ln := fasthttputil.NewInmemoryListener() +// ln := fasthttputil.NewInmemoryListener() - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", func(c fiber.Ctx) error { - return c.JSON(data{true}) - }) +// app.Get("/", func(c fiber.Ctx) error { +// return c.JSON(data{true}) +// }) - app.Get("/error", func(c fiber.Ctx) error { - return c.SendString(`{"success"`) - }) +// app.Get("/error", func(c fiber.Ctx) error { +// return c.SendString(`{"success"`) +// }) - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - t.Run("success", func(t *testing.T) { - t.Parallel() +// t.Run("success", func(t *testing.T) { +// t.Parallel() - a := Get("http://example.com") +// a := Get("http://example.com") - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - var d data +// var d data - code, body, errs := a.Struct(&d) +// code, body, errs := a.Struct(&d) - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, `{"success":true}`, string(body)) - utils.AssertEqual(t, 0, len(errs)) - utils.AssertEqual(t, true, d.Success) - }) +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, `{"success":true}`, string(body)) +// utils.AssertEqual(t, 0, len(errs)) +// utils.AssertEqual(t, true, d.Success) +// }) - t.Run("pre error", func(t *testing.T) { - t.Parallel() - a := Get("http://example.com") +// t.Run("pre error", func(t *testing.T) { +// t.Parallel() +// a := Get("http://example.com") - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - a.errs = append(a.errs, errors.New("pre errors")) +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// a.errs = append(a.errs, errors.New("pre errors")) - var d data - _, body, errs := a.Struct(&d) +// var d data +// _, body, errs := a.Struct(&d) - utils.AssertEqual(t, "", string(body)) - utils.AssertEqual(t, 1, len(errs)) - utils.AssertEqual(t, "pre errors", errs[0].Error()) - utils.AssertEqual(t, false, d.Success) - }) +// utils.AssertEqual(t, "", string(body)) +// utils.AssertEqual(t, 1, len(errs)) +// utils.AssertEqual(t, "pre errors", errs[0].Error()) +// utils.AssertEqual(t, false, d.Success) +// }) - t.Run("error", func(t *testing.T) { - a := Get("http://example.com/error") +// t.Run("error", func(t *testing.T) { +// a := Get("http://example.com/error") - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - var d data +// var d data - code, body, errs := a.JSONDecoder(json.Unmarshal).Struct(&d) +// code, body, errs := a.JSONDecoder(json.Unmarshal).Struct(&d) - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, `{"success"`, string(body)) - utils.AssertEqual(t, 1, len(errs)) - utils.AssertEqual(t, "unexpected end of JSON input", errs[0].Error()) - }) -} +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, `{"success"`, string(body)) +// utils.AssertEqual(t, 1, len(errs)) +// utils.AssertEqual(t, "unexpected end of JSON input", errs[0].Error()) +// }) +// } -func Test_Client_Agent_Parse(t *testing.T) { - t.Parallel() +// func Test_Client_Agent_Parse(t *testing.T) { +// t.Parallel() - a := Get("https://example.com:10443") +// a := Get("https://example.com:10443") - utils.AssertEqual(t, nil, a.Parse()) -} +// utils.AssertEqual(t, nil, a.Parse()) +// } -func Test_AddMissingPort_TLS(t *testing.T) { - addr := addMissingPort("example.com", true) - utils.AssertEqual(t, "example.com:443", addr) -} +// func Test_AddMissingPort_TLS(t *testing.T) { +// addr := addMissingPort("example.com", true) +// utils.AssertEqual(t, "example.com:443", addr) +// } -func testAgent(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Agent), excepted string, count ...int) { - t.Parallel() +// func testAgent(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Agent), excepted string, count ...int) { +// t.Parallel() - ln := fasthttputil.NewInmemoryListener() +// ln := fasthttputil.NewInmemoryListener() - app := fiber.New(fiber.Config{DisableStartupMessage: true}) +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - app.Get("/", handler) +// app.Get("/", handler) - go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - c := 1 - if len(count) > 0 { - c = count[0] - } +// c := 1 +// if len(count) > 0 { +// c = count[0] +// } - for i := 0; i < c; i++ { - a := Get("http://example.com") +// for i := 0; i < c; i++ { +// a := Get("http://example.com") - wrapAgent(a) +// wrapAgent(a) - a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - code, body, errs := a.String() +// code, body, errs := a.String() - utils.AssertEqual(t, fiber.StatusOK, code) - utils.AssertEqual(t, excepted, body) - utils.AssertEqual(t, 0, len(errs)) - } -} +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, excepted, body) +// utils.AssertEqual(t, 0, len(errs)) +// } +// } -type data struct { - Success bool `json:"success" xml:"success"` -} +// type data struct { +// Success bool `json:"success" xml:"success"` +// } -type errorMultipartWriter struct { - count int -} +// type errorMultipartWriter struct { +// count int +// } -func (e *errorMultipartWriter) Boundary() string { return "myBoundary" } -func (e *errorMultipartWriter) SetBoundary(_ string) error { return nil } -func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) { - if e.count == 0 { - e.count++ - return nil, errors.New("CreateFormFile error") - } - return errorWriter{}, nil -} -func (e *errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") } -func (e *errorMultipartWriter) Close() error { return errors.New("Close error") } +// func (e *errorMultipartWriter) Boundary() string { return "myBoundary" } +// func (e *errorMultipartWriter) SetBoundary(_ string) error { return nil } +// func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) { +// if e.count == 0 { +// e.count++ +// return nil, errors.New("CreateFormFile error") +// } +// return errorWriter{}, nil +// } +// func (e *errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") } +// func (e *errorMultipartWriter) Close() error { return errors.New("Close error") } -type errorWriter struct{} +// type errorWriter struct{} -func (errorWriter) Write(_ []byte) (int, error) { return 0, errors.New("Write error") } +// func (errorWriter) Write(_ []byte) (int, error) { return 0, errors.New("Write error") } diff --git a/client/core.go b/client/core.go new file mode 100644 index 0000000000..7507e5e7ce --- /dev/null +++ b/client/core.go @@ -0,0 +1,230 @@ +package client + +import ( + "context" + "encoding/json" + "encoding/xml" + "fmt" + "sync" + + "github.com/gofiber/fiber/v3/utils" + "github.com/valyala/fasthttp" +) + +// RequestHook is a function that receives Agent and Request, +// it can change the data in Request and Agent. +// +// Called before a request is sent. +type RequestHook func(*Client, *Request) error + +// ResponseHook is a function that receives Agent, Respose and Request, +// it can change the data is Respose or deal with some effects. +// +// Called after a respose has been received. +type ResponseHook func(*Client, *Response, *Request) error + +// ExecuteFunc will actually execute the request via fasthttp. +type ExecuteFunc func(context.Context, *Client, *Request) (*Response, error) + +// Plugin can change the execution flow of requests. +type Plugin interface { + // Return the plugin name and the name should be different. + Name() string + + // Determine if the plugin should be executed based on the conditions. + Check() bool + + // Modify specific request execution methods, + // such as adding timeouts, cancellations, retries and other operations. + GenerateExecute(ExecuteFunc) (ExecuteFunc, error) +} + +// `core` stores middleware and plugin definitions, +// and defines the execution process +type core struct { + client *fasthttp.HostClient + + // user defined request hooks + userRequestHooks []RequestHook + + // client package defined request hooks + buildinRequestHooks []RequestHook + + // user defined response hooks + userResponseHooks []ResponseHook + + // client package defined respose hooks + buildinResposeHooks []ResponseHook + + // store plugins + plugins []Plugin + pluginMap map[string]Plugin + + jsonMarshal utils.JSONMarshal + jsonUnmarshal utils.JSONUnmarshal + xmlMarshal utils.XMLMarshal + xmlUnmarshal utils.XMLUnmarshal +} + +// execute will exec each hooks and plugins. +func (c *core) execute(ctx context.Context, agent *Client, req *Request) (*Response, error) { + var execFunc ExecuteFunc = func(ctx context.Context, a *Client, r *Request) (*Response, error) { + resp := AcquireResponse() + + // To avoid memory allocation reuse of data structures such as errch. + errCh, reqv, respv := acquireErrChan(), fasthttp.AcquireRequest(), fasthttp.AcquireResponse() + defer func() { + releaseErrChan(errCh) + fasthttp.ReleaseRequest(reqv) + fasthttp.ReleaseResponse(respv) + }() + + req.rawRequest.CopyTo(reqv) + go func() { + err := c.client.Do(reqv, respv) + if err != nil { + errCh <- err + return + } + respv.CopyTo(resp.rawResponse) + errCh <- nil + }() + + select { + case err := <-errCh: + if err != nil { + // When get error should release Response + ReleaseResponse(resp) + return nil, err + } + return resp, nil + case <-ctx.Done(): + return nil, fmt.Errorf("timeout error") + } + } + + // The built-in hooks will be executed only + // after the user-defined hooks are executed。 + for _, f := range c.userRequestHooks { + err := f(agent, req) + if err != nil { + return nil, err + } + } + + for _, f := range c.buildinRequestHooks { + err := f(agent, req) + if err != nil { + return nil, err + } + } + + // Call the plugins to generate the real request function. + for _, p := range c.plugins { + if !p.Check() { + continue + } + + var err error + execFunc, err = p.GenerateExecute(execFunc) + if err != nil { + return nil, err + } + } + + // Do http request + resp, err := execFunc(ctx, agent, req) + if err != nil { + return nil, err + } + + // The built-in hooks will be executed only + // before the user-defined hooks are executed. + for _, f := range c.buildinResposeHooks { + err := f(agent, resp, req) + if err != nil { + return nil, err + } + } + + for _, f := range c.userResponseHooks { + err := f(agent, resp, req) + if err != nil { + return nil, err + } + } + + return resp, nil +} + +// reset clears core object. +// It will not clear buildin hooks. +func (c *core) reset() { + c.userRequestHooks = c.userRequestHooks[:0] + c.userResponseHooks = c.userResponseHooks[:0] + c.plugins = c.plugins[:0] + + for k := range c.pluginMap { + delete(c.pluginMap, k) + } +} + +var errChanPool sync.Pool + +// acquireErrChan returns an empty error chan from the pool. +// +// The returned error chan may be returned to the pool with releaseErrChan when no longer needed. +// This allows reducing GC load. +func acquireErrChan() (ch chan error) { + chv := errChanPool.Get() + if chv != nil { + ch = chv.(chan error) + return + } + ch = make(chan error, 1) + return +} + +// releaseErrChan returns the object acquired via acquireErrChan to the pool. +// +// Do not access the released core object, otherwise data races may occur. +func releaseErrChan(ch chan error) { + errChanPool.Put(ch) +} + +var corePool sync.Pool + +// acquireCore returns an empty core object from the pool. +// +// The returned core may be returned to the pool with releaseCore when no longer needed. +// This allows reducing GC load. +func acquireCore() (c *core) { + cv := corePool.Get() + if cv != nil { + c = cv.(*core) + return + } + c = &core{ + client: &fasthttp.HostClient{}, + userRequestHooks: []RequestHook{}, + buildinRequestHooks: []RequestHook{parserURL}, + userResponseHooks: []ResponseHook{}, + buildinResposeHooks: []ResponseHook{}, + plugins: []Plugin{}, + pluginMap: map[string]Plugin{}, + jsonMarshal: json.Marshal, + jsonUnmarshal: json.Unmarshal, + xmlMarshal: xml.Marshal, + xmlUnmarshal: xml.Unmarshal, + } + + return +} + +// releaseCore returns the object acquired via acquireCore to the pool. +// +// Do not access the released core object, otherwise data races may occur. +func releaseCore(c *core) { + c.reset() + corePool.Put(c) +} diff --git a/client/hooks.go b/client/hooks.go new file mode 100644 index 0000000000..c37e2981b5 --- /dev/null +++ b/client/hooks.go @@ -0,0 +1,47 @@ +package client + +import ( + "bytes" + "fmt" + "net" + "strconv" + "strings" +) + +var ( + httpBytes = []byte("http") + httpsBytes = []byte("https") +) + +// parserURL will set the options for the hostclient +// and normalize the url. +func parserURL(c *Client, req *Request) error { + req.rawRequest.SetRequestURI(req.url) + + uri := req.rawRequest.URI() + + isTLS, scheme := false, uri.Scheme() + if bytes.Equal(httpsBytes, scheme) { + isTLS = true + } else if !bytes.Equal(httpBytes, scheme) { + return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme) + } + + c.client.Addr = addMissingPort(string(uri.Host()), isTLS) + c.client.IsTLS = isTLS + + return nil +} + +// addMissingPort will add the corresponding port number for host. +func addMissingPort(addr string, isTLS bool) string { + n := strings.Index(addr, ":") + if n >= 0 { + return addr + } + port := 80 + if isTLS { + port = 443 + } + return net.JoinHostPort(addr, strconv.Itoa(port)) +} diff --git a/client/plugins.go b/client/plugins.go new file mode 100644 index 0000000000..da13c8ef3c --- /dev/null +++ b/client/plugins.go @@ -0,0 +1 @@ +package client diff --git a/client/request.go b/client/request.go new file mode 100644 index 0000000000..de3fb0e412 --- /dev/null +++ b/client/request.go @@ -0,0 +1,77 @@ +package client + +import ( + "context" + "sync" + + "github.com/valyala/fasthttp" +) + +type Request struct { + url string + method string + ctx context.Context + rawRequest *fasthttp.Request +} + +func (r *Request) SetURL(url string) *Request { + r.url = url + return r +} + +func (r *Request) SetMethod(method string) *Request { + r.method = method + return r +} + +// Context returns the Context if its already set in request +// otherwise it creates new one using `context.Background()`. +func (r *Request) Context() context.Context { + if r.ctx == nil { + return context.Background() + } + return r.ctx +} + +// SetContext sets the context.Context for current Request. It allows +// to interrupt the request execution if ctx.Done() channel is closed. +// See https://blog.golang.org/context article and the "context" package +// documentation. +func (r *Request) SetContext(ctx context.Context) *Request { + r.ctx = ctx + return r +} + +// Reset clear Request object, used by ReleaseRequest method. +func (r *Request) Reset() { + r.url = "" + + r.rawRequest.Reset() +} + +var requestPool sync.Pool + +// AcquireRequest returns an empty core object from the pool. +// +// The returned core may be returned to the pool with ReleaseRequest when no longer needed. +// This allows reducing GC load. +func AcquireRequest() (req *Request) { + reqv := requestPool.Get() + if reqv != nil { + req = reqv.(*Request) + return + } + + req = &Request{ + rawRequest: fasthttp.AcquireRequest(), + } + return +} + +// ReleaseRequest returns the object acquired via AcquireRequest to the pool. +// +// Do not access the released core object, otherwise data races may occur. +func ReleaseRequest(req *Request) { + req.Reset() + requestPool.Put(req) +} diff --git a/client/respose.go b/client/respose.go new file mode 100644 index 0000000000..6a9aeab4de --- /dev/null +++ b/client/respose.go @@ -0,0 +1,43 @@ +package client + +import ( + "sync" + + "github.com/valyala/fasthttp" +) + +type Response struct { + rawResponse *fasthttp.Response +} + +// Reset clear Response object. +func (r *Response) Reset() { + r.rawResponse.Reset() +} + +var responsePool sync.Pool + +// AcquireResponse returns an empty core object from the pool. +// +// The returned core may be returned to the pool with ReleaseResponse when no longer needed. +// This allows reducing GC load. +func AcquireResponse() (resp *Response) { + respv := responsePool.Get() + if respv != nil { + resp = respv.(*Response) + return + } + resp = &Response{ + rawResponse: fasthttp.AcquireResponse(), + } + + return +} + +// ReleaseResponse returns the object acquired via AcquireResponse to the pool. +// +// Do not access the released core object, otherwise data races may occur. +func ReleaseResponse(resp *Response) { + resp.Reset() + responsePool.Put(resp) +} From cf5fa5e264786841c00037e3822243189a056583 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Sun, 31 Jul 2022 08:41:24 +0800 Subject: [PATCH 004/118] =?UTF-8?q?=F0=9F=9A=A7=20v3:=20reset=20add=20some?= =?UTF-8?q?=20field?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/request.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/client/request.go b/client/request.go index de3fb0e412..40d7d7b087 100644 --- a/client/request.go +++ b/client/request.go @@ -4,6 +4,7 @@ import ( "context" "sync" + "github.com/gofiber/fiber/v3" "github.com/valyala/fasthttp" ) @@ -45,6 +46,8 @@ func (r *Request) SetContext(ctx context.Context) *Request { // Reset clear Request object, used by ReleaseRequest method. func (r *Request) Reset() { r.url = "" + r.method = fiber.MethodGet + r.ctx = nil r.rawRequest.Reset() } From b6c6f2af626e5d4bf5d39924b2e6ac073ed0f65d Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Sun, 31 Jul 2022 15:57:48 +0800 Subject: [PATCH 005/118] =?UTF-8?q?=F0=9F=9A=A7=20v3:=20add=20doc=20and=20?= =?UTF-8?q?fix=20some=20error?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 58 +++++++++++++++++++++++++++++++++++------------ client/core.go | 22 +++++++++--------- client/hooks.go | 4 ++-- client/request.go | 17 ++++++++------ client/respose.go | 6 ++--- 5 files changed, 70 insertions(+), 37 deletions(-) diff --git a/client/client.go b/client/client.go index 3bb86920d2..78021664a4 100644 --- a/client/client.go +++ b/client/client.go @@ -9,7 +9,7 @@ import ( ) type Client struct { - core + core *Core baseUrl string header map[string][]string @@ -17,51 +17,62 @@ type Client struct { // Add user-defined request hooks. func (c *Client) AddRequestHook(h ...RequestHook) *Client { - c.userRequestHooks = append(c.userRequestHooks, h...) + c.core.userRequestHooks = append(c.core.userRequestHooks, h...) return c } // Add user-defined response hooks. func (c *Client) AddResponseHook(h ...ResponseHook) *Client { - c.userResponseHooks = append(c.userResponseHooks, h...) + c.core.userResponseHooks = append(c.core.userResponseHooks, h...) return c } +// Set HostClient dial, this method for unit test, +// maybe don't use it. func (c *Client) SetDial(f fasthttp.DialFunc) *Client { - c.client.Dial = f + c.core.client.Dial = f return c } // Set json encoder. func (c *Client) SetJSONMarshal(f utils.JSONMarshal) *Client { - c.jsonMarshal = f + c.core.jsonMarshal = f return c } // Set json decoder. func (c *Client) SetJSONUnmarshal(f utils.JSONUnmarshal) *Client { - c.jsonUnmarshal = f + c.core.jsonUnmarshal = f return c } // Set xml encoder. func (c *Client) SetXMLMarshal(f utils.XMLMarshal) *Client { - c.xmlMarshal = f + c.core.xmlMarshal = f return c } // Set xml decoder. func (c *Client) SetXMLUnmarshal(f utils.XMLUnmarshal) *Client { - c.xmlUnmarshal = f + c.core.xmlUnmarshal = f return c } +// Reset clear Client object. +func (c *Client) Reset() { + c.baseUrl = "" + c.header = map[string][]string{} + + c.core.reset() +} + +// Get provide a API like axios which send get request. func (c *Client) Get(url string) (*Response, error) { req := AcquireRequest(). - SetURL(url). - SetMethod(fiber.MethodGet) + setMethod(fiber.MethodGet). + SetURL(url) - return c.execute(req.Context(), c, req) + return c.core.execute(req.Context(), c, req) } var ( @@ -73,11 +84,29 @@ func init() { defaultClient = AcquireClient() } -func AcquireClient() *Client { - return &Client{ - core: *acquireCore(), +// AcquireClient returns an empty Client object from the pool. +// +// The returned Client object may be returned to the pool with ReleaseClient when no longer needed. +// This allows reducing GC load. +func AcquireClient() (c *Client) { + cv := clientPool.Get() + if cv != nil { + c = cv.(*Client) + return + } + c = &Client{ + core: AcquireCore(), header: map[string][]string{}, } + return +} + +// ReleaseClient returns the object acquired via AcquireClient to the pool. +// +// Do not access the released Client object, otherwise data races may occur. +func ReleaseClient(c *Client) { + c.Reset() + clientPool.Put(c) } // Get default client. @@ -85,6 +114,7 @@ func C() *Client { return defaultClient } +// Get send a get request use defaultClient, a convenient method. func Get(url string) (*Response, error) { return defaultClient.Get(url) } diff --git a/client/core.go b/client/core.go index 7507e5e7ce..f24952be6f 100644 --- a/client/core.go +++ b/client/core.go @@ -39,9 +39,9 @@ type Plugin interface { GenerateExecute(ExecuteFunc) (ExecuteFunc, error) } -// `core` stores middleware and plugin definitions, +// `Core` stores middleware and plugin definitions, // and defines the execution process -type core struct { +type Core struct { client *fasthttp.HostClient // user defined request hooks @@ -67,7 +67,7 @@ type core struct { } // execute will exec each hooks and plugins. -func (c *core) execute(ctx context.Context, agent *Client, req *Request) (*Response, error) { +func (c *Core) execute(ctx context.Context, agent *Client, req *Request) (*Response, error) { var execFunc ExecuteFunc = func(ctx context.Context, a *Client, r *Request) (*Response, error) { resp := AcquireResponse() @@ -159,7 +159,7 @@ func (c *core) execute(ctx context.Context, agent *Client, req *Request) (*Respo // reset clears core object. // It will not clear buildin hooks. -func (c *core) reset() { +func (c *Core) reset() { c.userRequestHooks = c.userRequestHooks[:0] c.userResponseHooks = c.userResponseHooks[:0] c.plugins = c.plugins[:0] @@ -194,17 +194,17 @@ func releaseErrChan(ch chan error) { var corePool sync.Pool -// acquireCore returns an empty core object from the pool. +// AcquireCore returns an empty core object from the pool. // -// The returned core may be returned to the pool with releaseCore when no longer needed. +// The returned core may be returned to the pool with ReleaseCore when no longer needed. // This allows reducing GC load. -func acquireCore() (c *core) { +func AcquireCore() (c *Core) { cv := corePool.Get() if cv != nil { - c = cv.(*core) + c = cv.(*Core) return } - c = &core{ + c = &Core{ client: &fasthttp.HostClient{}, userRequestHooks: []RequestHook{}, buildinRequestHooks: []RequestHook{parserURL}, @@ -221,10 +221,10 @@ func acquireCore() (c *core) { return } -// releaseCore returns the object acquired via acquireCore to the pool. +// ReleaseCore returns the object acquired via AcquireCore to the pool. // // Do not access the released core object, otherwise data races may occur. -func releaseCore(c *core) { +func ReleaseCore(c *Core) { c.reset() corePool.Put(c) } diff --git a/client/hooks.go b/client/hooks.go index c37e2981b5..5d7863c3ce 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -27,8 +27,8 @@ func parserURL(c *Client, req *Request) error { return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme) } - c.client.Addr = addMissingPort(string(uri.Host()), isTLS) - c.client.IsTLS = isTLS + c.core.client.Addr = addMissingPort(string(uri.Host()), isTLS) + c.core.client.IsTLS = isTLS return nil } diff --git a/client/request.go b/client/request.go index 40d7d7b087..b0aa342071 100644 --- a/client/request.go +++ b/client/request.go @@ -15,13 +15,16 @@ type Request struct { rawRequest *fasthttp.Request } -func (r *Request) SetURL(url string) *Request { - r.url = url +// setMethod will set method for Request object, +// user should use request method to set method. +func (r *Request) setMethod(method string) *Request { + r.method = method return r } -func (r *Request) SetMethod(method string) *Request { - r.method = method +// SetURL will set url for Request object. +func (r *Request) SetURL(url string) *Request { + r.url = url return r } @@ -54,9 +57,9 @@ func (r *Request) Reset() { var requestPool sync.Pool -// AcquireRequest returns an empty core object from the pool. +// AcquireRequest returns an empty request object from the pool. // -// The returned core may be returned to the pool with ReleaseRequest when no longer needed. +// The returned request may be returned to the pool with ReleaseRequest when no longer needed. // This allows reducing GC load. func AcquireRequest() (req *Request) { reqv := requestPool.Get() @@ -73,7 +76,7 @@ func AcquireRequest() (req *Request) { // ReleaseRequest returns the object acquired via AcquireRequest to the pool. // -// Do not access the released core object, otherwise data races may occur. +// Do not access the released Request object, otherwise data races may occur. func ReleaseRequest(req *Request) { req.Reset() requestPool.Put(req) diff --git a/client/respose.go b/client/respose.go index 6a9aeab4de..97b7bb29bf 100644 --- a/client/respose.go +++ b/client/respose.go @@ -17,9 +17,9 @@ func (r *Response) Reset() { var responsePool sync.Pool -// AcquireResponse returns an empty core object from the pool. +// AcquireResponse returns an empty response object from the pool. // -// The returned core may be returned to the pool with ReleaseResponse when no longer needed. +// The returned response may be returned to the pool with ReleaseResponse when no longer needed. // This allows reducing GC load. func AcquireResponse() (resp *Response) { respv := responsePool.Get() @@ -36,7 +36,7 @@ func AcquireResponse() (resp *Response) { // ReleaseResponse returns the object acquired via AcquireResponse to the pool. // -// Do not access the released core object, otherwise data races may occur. +// Do not access the released Response object, otherwise data races may occur. func ReleaseResponse(resp *Response) { resp.Reset() responsePool.Put(resp) From f5c8e52c556f5332dbee1c6682102feb002a4994 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Sun, 31 Jul 2022 22:39:30 +0800 Subject: [PATCH 006/118] =?UTF-8?q?=F0=9F=9A=A7=20v3:=20add=20header=20mer?= =?UTF-8?q?ge?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 52 ++++++++++++++++++++-- client/core.go | 2 +- client/hooks.go | 36 ++++++++++++---- client/hooks_test.go | 100 +++++++++++++++++++++++++++++++++++++++++++ client/request.go | 53 +++++++++++++++++++++++ 5 files changed, 229 insertions(+), 14 deletions(-) create mode 100644 client/hooks_test.go diff --git a/client/client.go b/client/client.go index 78021664a4..b0c45b6a28 100644 --- a/client/client.go +++ b/client/client.go @@ -1,6 +1,7 @@ package client import ( + "net/http" "sync" "github.com/gofiber/fiber/v3" @@ -12,7 +13,7 @@ type Client struct { core *Core baseUrl string - header map[string][]string + header *Header } // Add user-defined request hooks. @@ -58,10 +59,51 @@ func (c *Client) SetXMLUnmarshal(f utils.XMLUnmarshal) *Client { return c } +// Set baseUrl which is prefix of real url. +func (c *Client) SetBaseURL(url string) *Client { + c.baseUrl = url + return c +} + +// AddHeader method adds a single header field and its value in the client instance. +// These headers will be applied to all requests raised from this client instance. +// Also it can be overridden at request level header options. +func (c *Client) AddHeader(key, val string) *Client { + c.header.Add(key, val) + return c +} + +// SetHeader method sets a single header field and its value in the client instance. +// These headers will be applied to all requests raised from this client instance. +// Also it can be overridden at request level header options. +func (c *Client) SetHeader(key, val string) *Client { + c.header.Set(key, val) + return c +} + +// AddHeaders method adds multiple headers field and its values at one go in the client instance. +// These headers will be applied to all requests raised from this client instance. Also it can be +// overridden at request level headers options. +func (c *Client) AddHeaders(h map[string][]string) *Client { + c.header.AddHeaders(h) + return c +} + +// SetHeaders method sets multiple headers field and its values at one go in the client instance. +// These headers will be applied to all requests raised from this client instance. Also it can be +// overridden at request level headers options. +func (c *Client) SetHeaders(h map[string]string) *Client { + c.header.SetHeaders(h) + return c +} + // Reset clear Client object. func (c *Client) Reset() { c.baseUrl = "" - c.header = map[string][]string{} + + for k := range c.header.Header { + delete(c.header.Header, k) + } c.core.reset() } @@ -95,8 +137,10 @@ func AcquireClient() (c *Client) { return } c = &Client{ - core: AcquireCore(), - header: map[string][]string{}, + core: AcquireCore(), + header: &Header{ + Header: make(http.Header), + }, } return } diff --git a/client/core.go b/client/core.go index f24952be6f..46d0dd7682 100644 --- a/client/core.go +++ b/client/core.go @@ -207,7 +207,7 @@ func AcquireCore() (c *Core) { c = &Core{ client: &fasthttp.HostClient{}, userRequestHooks: []RequestHook{}, - buildinRequestHooks: []RequestHook{parserURL}, + buildinRequestHooks: []RequestHook{parserURL, parserHeader}, userResponseHooks: []ResponseHook{}, buildinResposeHooks: []ResponseHook{}, plugins: []Plugin{}, diff --git a/client/hooks.go b/client/hooks.go index 5d7863c3ce..6a482af106 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -13,8 +13,23 @@ var ( httpsBytes = []byte("https") ) +// addMissingPort will add the corresponding port number for host. +func addMissingPort(addr string, isTLS bool) string { + n := strings.Index(addr, ":") + if n >= 0 { + return addr + } + port := 80 + if isTLS { + port = 443 + } + return net.JoinHostPort(addr, strconv.Itoa(port)) +} + // parserURL will set the options for the hostclient // and normalize the url. +// TODO: The baseUrl should be merge with request uri. +// TODO: Query params and path params should be deal in this function. func parserURL(c *Client, req *Request) error { req.rawRequest.SetRequestURI(req.url) @@ -33,15 +48,18 @@ func parserURL(c *Client, req *Request) error { return nil } -// addMissingPort will add the corresponding port number for host. -func addMissingPort(addr string, isTLS bool) string { - n := strings.Index(addr, ":") - if n >= 0 { - return addr +// parserHeader will make request header up. +// It will merge headers from client and request. +// TODO: Header should be set automatically based on data. +// TODO: User-Agent should be set? +func parserHeader(c *Client, req *Request) error { + for k, v := range c.header.Header { + req.rawRequest.Header.Set(k, strings.Join(v, ", ")) } - port := 80 - if isTLS { - port = 443 + + for k, v := range req.header.Header { + req.rawRequest.Header.Set(k, strings.Join(v, ", ")) } - return net.JoinHostPort(addr, strconv.Itoa(port)) + + return nil } diff --git a/client/hooks_test.go b/client/hooks_test.go new file mode 100644 index 0000000000..023e1a00e2 --- /dev/null +++ b/client/hooks_test.go @@ -0,0 +1,100 @@ +package client + +import ( + "net/http" + "testing" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/utils" + "github.com/valyala/fasthttp" +) + +func TestParserURL(t *testing.T) { + type args struct { + c *Client + req *Request + } + tests := []struct { + name string + args args + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := parserURL(tt.args.c, tt.args.req); (err != nil) != tt.wantErr { + t.Errorf("parserURL() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestParserHeader(t *testing.T) { + t.Parallel() + + t.Run("client header should be set", func(t *testing.T) { + client := &Client{ + header: &Header{ + Header: map[string][]string{ + fiber.HeaderContentType: {"application/json"}, + }, + }, + } + + req := &Request{ + header: &Header{ + Header: make(http.Header), + }, + rawRequest: fasthttp.AcquireRequest(), + } + + err := parserHeader(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte("application/json"), req.rawRequest.Header.ContentType()) + }) + + t.Run("request header should be set", func(t *testing.T) { + client := &Client{ + header: &Header{ + Header: make(http.Header), + }, + } + + req := &Request{ + header: &Header{ + Header: map[string][]string{ + fiber.HeaderContentType: {"application/json", "utf-8"}, + }, + }, + rawRequest: fasthttp.AcquireRequest(), + } + + err := parserHeader(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte("application/json, utf-8"), req.rawRequest.Header.ContentType()) + }) + + t.Run("request header should override client header", func(t *testing.T) { + client := &Client{ + header: &Header{ + Header: map[string][]string{ + fiber.HeaderContentType: {"application/xml"}, + }, + }, + } + + req := &Request{ + header: &Header{ + Header: map[string][]string{ + fiber.HeaderContentType: {"application/json", "utf-8"}, + }, + }, + rawRequest: fasthttp.AcquireRequest(), + } + + err := parserHeader(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte("application/json, utf-8"), req.rawRequest.Header.ContentType()) + }) +} diff --git a/client/request.go b/client/request.go index b0aa342071..b090d3f036 100644 --- a/client/request.go +++ b/client/request.go @@ -2,6 +2,7 @@ package client import ( "context" + "net/http" "sync" "github.com/gofiber/fiber/v3" @@ -12,6 +13,7 @@ type Request struct { url string method string ctx context.Context + header *Header rawRequest *fasthttp.Request } @@ -46,15 +48,65 @@ func (r *Request) SetContext(ctx context.Context) *Request { return r } +// AddHeader method adds a single header field and its value in the request instance. +// It will override header which set in client instance. +func (r *Request) AddHeader(key, val string) *Request { + r.header.Add(key, val) + return r +} + +// SetHeader method sets a single header field and its value in the request instance. +// It will override header which set in client instance. +func (r *Request) SetHeader(key, val string) *Request { + r.header.Set(key, val) + return r +} + +// AddHeaders method adds multiple headers field and its values at one go in the request instance. +// It will override header which set in client instance. +func (r *Request) AddHeaders(h map[string][]string) *Request { + r.header.AddHeaders(h) + return r +} + +// SetHeaders method sets multiple headers field and its values at one go in the request instance. +// It will override header which set in client instance. +func (r *Request) SetHeaders(h map[string]string) *Request { + r.header.SetHeaders(h) + return r +} + // Reset clear Request object, used by ReleaseRequest method. func (r *Request) Reset() { r.url = "" r.method = fiber.MethodGet r.ctx = nil + for k := range r.header.Header { + delete(r.header.Header, k) + } + r.rawRequest.Reset() } +type Header struct { + http.Header +} + +func (h *Header) AddHeaders(r map[string][]string) { + for k, v := range r { + for _, vv := range v { + h.Header.Add(k, vv) + } + } +} + +func (h *Header) SetHeaders(r map[string]string) { + for k, v := range r { + h.Header.Set(k, v) + } +} + var requestPool sync.Pool // AcquireRequest returns an empty request object from the pool. @@ -69,6 +121,7 @@ func AcquireRequest() (req *Request) { } req = &Request{ + header: &Header{Header: make(http.Header)}, rawRequest: fasthttp.AcquireRequest(), } return From 4017c76e3bd7c39c7ca8e34f050c929070d67b21 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Mon, 1 Aug 2022 22:08:49 +0800 Subject: [PATCH 007/118] =?UTF-8?q?=F0=9F=9A=A7=20v3:=20add=20query=20para?= =?UTF-8?q?m?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 57 +++++++++++++++++ client/hooks.go | 49 +++++++++++++-- client/hooks_test.go | 111 +++++++++++++++++++++++++++----- client/request.go | 139 ++++++++++++++++++++++++++++++++++++++++- client/request_test.go | 136 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 468 insertions(+), 24 deletions(-) create mode 100644 client/request_test.go diff --git a/client/client.go b/client/client.go index b0c45b6a28..3e5f8d96c1 100644 --- a/client/client.go +++ b/client/client.go @@ -2,6 +2,7 @@ package client import ( "net/http" + "net/url" "sync" "github.com/gofiber/fiber/v3" @@ -14,6 +15,7 @@ type Client struct { baseUrl string header *Header + params *Params } // Add user-defined request hooks. @@ -97,6 +99,54 @@ func (c *Client) SetHeaders(h map[string]string) *Client { return c } +// AddParam method adds a single query param field and its value in the client instance. +// These params will be applied to all requests raised from this client instance. +// Also it can be overridden at request level param options. +func (c *Client) AddParam(key, val string) *Client { + c.params.Add(key, val) + return c +} + +// SetParam method sets a single query param field and its value in the client instance. +// These params will be applied to all requests raised from this client instance. +// Also it can be overridden at request level param options. +func (c *Client) SetParam(key, val string) *Client { + c.params.Set(key, val) + return c +} + +// AddParams method adds multiple query params field and its values at one go in the client instance. +// These params will be applied to all requests raised from this client instance. Also it can be +// overridden at request level params options. +func (c *Client) AddParams(m map[string][]string) *Client { + c.params.AddParams(m) + return c +} + +// SetParams method sets multiple params field and its values at one go in the client instance. +// These params will be applied to all requests raised from this client instance. Also it can be +// overridden at request level params options. +func (c *Client) SetParams(m map[string]string) *Client { + c.params.SetParams(m) + return c +} + +// SetParamsWithStruct method sets multiple params field and its values at one go in the client instance. +// These params will be applied to all requests raised from this client instance. Also it can be +// overridden at request level params options. +func (c *Client) SetParamsWithStruct(v any) *Client { + c.params.SetParamsWithStruct(v) + return c +} + +// DelParams method deletes single or multiple params field and its valus in client. +func (c *Client) DelParams(key ...string) *Client { + for _, v := range key { + c.params.Del(v) + } + return c +} + // Reset clear Client object. func (c *Client) Reset() { c.baseUrl = "" @@ -105,6 +155,10 @@ func (c *Client) Reset() { delete(c.header.Header, k) } + for k := range c.params.Values { + delete(c.params.Values, k) + } + c.core.reset() } @@ -141,6 +195,9 @@ func AcquireClient() (c *Client) { header: &Header{ Header: make(http.Header), }, + params: &Params{ + Values: make(url.Values), + }, } return } diff --git a/client/hooks.go b/client/hooks.go index 6a482af106..b7d24658dd 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -4,6 +4,8 @@ import ( "bytes" "fmt" "net" + "net/url" + "regexp" "strconv" "strings" ) @@ -11,6 +13,8 @@ import ( var ( httpBytes = []byte("http") httpsBytes = []byte("https") + + protocolCheck = regexp.MustCompile(`^https?://.*$`) ) // addMissingPort will add the corresponding port number for host. @@ -28,23 +32,58 @@ func addMissingPort(addr string, isTLS bool) string { // parserURL will set the options for the hostclient // and normalize the url. -// TODO: The baseUrl should be merge with request uri. +// The baseUrl will be merge with request uri. // TODO: Query params and path params should be deal in this function. func parserURL(c *Client, req *Request) error { - req.rawRequest.SetRequestURI(req.url) + splitUrl := strings.Split(req.url, "?") + // I don't want to judege splitUrl length. + splitUrl = append(splitUrl, "") - uri := req.rawRequest.URI() + // Determine whether to superimpose baseurl based on + // whether the URL starts with the protocol + uri := splitUrl[0] + if !protocolCheck.MatchString(uri) { + uri = c.baseUrl + uri + if !protocolCheck.MatchString(uri) { + return fmt.Errorf("url format error") + } + } - isTLS, scheme := false, uri.Scheme() + // set uri to request and orther related setting + req.rawRequest.SetRequestURI(uri) + rawUri := req.rawRequest.URI() + isTLS, scheme := false, rawUri.Scheme() if bytes.Equal(httpsBytes, scheme) { isTLS = true } else if !bytes.Equal(httpBytes, scheme) { return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme) } - c.core.client.Addr = addMissingPort(string(uri.Host()), isTLS) + c.core.client.Addr = addMissingPort(string(rawUri.Host()), isTLS) c.core.client.IsTLS = isTLS + // merge query params + hashSplit := strings.Split(splitUrl[1], "#") + hashSplit = append(hashSplit, "") + queryParams, err := url.ParseQuery(hashSplit[0]) + if err != nil { + return err + } + for k, v := range c.params.Values { + for _, vv := range v { + queryParams.Add(k, vv) + } + } + + for k, v := range req.params.Values { + for _, vv := range v { + queryParams.Add(k, vv) + } + } + + req.rawRequest.URI().SetQueryString(queryParams.Encode()) + req.rawRequest.URI().SetHash(hashSplit[1]) + return nil } diff --git a/client/hooks_test.go b/client/hooks_test.go index 023e1a00e2..320c119533 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -1,7 +1,9 @@ package client import ( + "fmt" "net/http" + "net/url" "testing" "github.com/gofiber/fiber/v3" @@ -10,24 +12,99 @@ import ( ) func TestParserURL(t *testing.T) { - type args struct { - c *Client - req *Request - } - tests := []struct { - name string - args args - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := parserURL(tt.args.c, tt.args.req); (err != nil) != tt.wantErr { - t.Errorf("parserURL() error = %v, wantErr %v", err, tt.wantErr) + t.Parallel() + + t.Run("client baseurl should be set", func(t *testing.T) { + client := AcquireClient().SetBaseURL("http://example.com/api") + req := AcquireRequest().SetURL("") + + err := parserURL(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "http://example.com/api", req.rawRequest.URI().String()) + }) + + t.Run("request url should be set", func(t *testing.T) { + client := AcquireClient() + req := AcquireRequest().SetURL("http://example.com/api") + + err := parserURL(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "http://example.com/api", req.rawRequest.URI().String()) + }) + + t.Run("the request url will override baseurl with protocol", func(t *testing.T) { + client := AcquireClient().SetBaseURL("http://example.com/api") + req := AcquireRequest().SetURL("http://example.com/api/v1") + + err := parserURL(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "http://example.com/api/v1", req.rawRequest.URI().String()) + }) + + t.Run("the request url should be append after baseurl without protocol", func(t *testing.T) { + client := AcquireClient().SetBaseURL("http://example.com/api") + req := AcquireRequest().SetURL("/v1") + + err := parserURL(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "http://example.com/api/v1", req.rawRequest.URI().String()) + }) + + t.Run("the url is error", func(t *testing.T) { + client := AcquireClient().SetBaseURL("example.com/api") + req := AcquireRequest().SetURL("/v1") + + err := parserURL(client, req) + utils.AssertEqual(t, fmt.Errorf("url format error"), err) + }) + + t.Run("query params from client should be set", func(t *testing.T) { + client := AcquireClient(). + SetParam("foo", "bar") + req := AcquireRequest().SetURL("http://example.com/api/v1") + + err := parserURL(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte("foo=bar"), req.rawRequest.URI().QueryString()) + }) + + t.Run("query params from request should be set", func(t *testing.T) { + client := AcquireClient() + req := AcquireRequest(). + SetURL("http://example.com/api/v1"). + SetParam("bar", "foo") + + err := parserURL(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte("bar=foo"), req.rawRequest.URI().QueryString()) + }) + + t.Run("query params should be merged", func(t *testing.T) { + client := AcquireClient(). + SetParam("bar", "foo1") + req := AcquireRequest(). + SetURL("http://example.com/api/v1?bar=foo2"). + SetParam("bar", "foo") + + err := parserURL(client, req) + utils.AssertEqual(t, nil, err) + + values, _ := url.ParseQuery(string(req.rawRequest.URI().QueryString())) + + flag1, flag2, flag3 := false, false, false + for _, v := range values["bar"] { + if v == "foo1" { + flag1 = true + } else if v == "foo2" { + flag2 = true + } else if v == "foo" { + flag3 = true } - }) - } + } + utils.AssertEqual(t, true, flag1) + utils.AssertEqual(t, true, flag2) + utils.AssertEqual(t, true, flag3) + }) } func TestParserHeader(t *testing.T) { diff --git a/client/request.go b/client/request.go index b090d3f036..5df2d020b3 100644 --- a/client/request.go +++ b/client/request.go @@ -3,6 +3,9 @@ package client import ( "context" "net/http" + "net/url" + "reflect" + "strconv" "sync" "github.com/gofiber/fiber/v3" @@ -14,6 +17,7 @@ type Request struct { method string ctx context.Context header *Header + params *Params rawRequest *fasthttp.Request } @@ -76,6 +80,49 @@ func (r *Request) SetHeaders(h map[string]string) *Request { return r } +// AddParam method adds a single param field and its value in the request instance. +// It will override param which set in client instance. +func (r *Request) AddParam(key, val string) *Request { + r.params.Add(key, val) + return r +} + +// SetParam method sets a single param field and its value in the request instance. +// It will override param which set in client instance. +func (r *Request) SetParam(key, val string) *Request { + r.params.Set(key, val) + return r +} + +// AddParams method adds multiple params field and its values at one go in the request instance. +// It will override param which set in client instance. +func (r *Request) AddParams(m map[string][]string) *Request { + r.params.AddParams(m) + return r +} + +// SetParams method sets multiple params field and its values at one go in the request instance. +// It will override param which set in client instance. +func (r *Request) SetParams(m map[string]string) *Request { + r.params.SetParams(m) + return r +} + +// SetParamWithStruct method sets multiple params field and its values at one go in the request instance. +// It will override param which set in client instance. +func (r *Request) SetParamsWithStruct(v any) *Request { + r.params.SetParamsWithStruct(v) + return r +} + +// DelParams method deletes single or multiple params field ant its values. +func (r *Request) DelParams(key ...string) *Request { + for _, v := range key { + r.params.Del(v) + } + return r +} + // Reset clear Request object, used by ReleaseRequest method. func (r *Request) Reset() { r.url = "" @@ -86,24 +133,111 @@ func (r *Request) Reset() { delete(r.header.Header, k) } + for k := range r.params.Values { + delete(r.params.Values, k) + } + r.rawRequest.Reset() } +// Header is a wrapper which wrap http.Header, +// the header in client and request will store in it. type Header struct { http.Header } +// AddHeaders receive a map and add each value to header. func (h *Header) AddHeaders(r map[string][]string) { for k, v := range r { for _, vv := range v { - h.Header.Add(k, vv) + h.Add(k, vv) } } } +// SetHeaders will override all headers. func (h *Header) SetHeaders(r map[string]string) { for k, v := range r { - h.Header.Set(k, v) + h.Set(k, v) + } +} + +// Params is a wrapper which wrap url.Values, +// the query string and formdata in client and request will store in it. +type Params struct { + url.Values +} + +// AddParams receive a map and add each value to param. +func (p *Params) AddParams(r map[string][]string) { + for k, v := range r { + for _, vv := range v { + p.Add(k, vv) + } + } +} + +// SetParams will override all params. +func (p *Params) SetParams(r map[string]string) { + for k, v := range r { + p.Set(k, v) + } +} + +func (p *Params) SetParamsWithStruct(v any) { + valueOfV := reflect.ValueOf(v) + typeOfV := reflect.TypeOf(v) + // The v should be struct or point of struct + + if typeOfV.Kind() == reflect.Pointer && typeOfV.Elem().Kind() == reflect.Struct { + valueOfV = valueOfV.Elem() + typeOfV = typeOfV.Elem() + } else if typeOfV.Kind() != reflect.Struct { + return + } + + // Boring type judge. + // TODO: cover more types and complex data structure. + var setVal func(name string, value reflect.Value) + setVal = func(name string, val reflect.Value) { + switch val.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + p.Add(name, strconv.Itoa(int(val.Int()))) + case reflect.Bool: + if val.Bool() { + p.Add(name, "true") + } else { + p.Add(name, "false") + } + case reflect.String: + p.Add(name, val.String()) + case reflect.Float32, reflect.Float64: + p.Add(name, strconv.FormatFloat(val.Float(), 'f', -1, 64)) + case reflect.Slice, reflect.Array: + for i := 0; i < val.Len(); i++ { + setVal(name, val.Index(i)) + } + default: + } + } + + for i := 0; i < typeOfV.NumField(); i++ { + field := typeOfV.Field(i) + if !field.IsExported() { + continue + } + + name := field.Tag.Get("param") + if name == "" { + name = field.Name + } + val := valueOfV.Field(i) + if val.IsZero() { + continue + } + // To cover slice and array, we delete the val then add it. + p.Del(name) + setVal(name, val) } } @@ -122,6 +256,7 @@ func AcquireRequest() (req *Request) { req = &Request{ header: &Header{Header: make(http.Header)}, + params: &Params{Values: make(url.Values)}, rawRequest: fasthttp.AcquireRequest(), } return diff --git a/client/request_test.go b/client/request_test.go new file mode 100644 index 0000000000..c650975eba --- /dev/null +++ b/client/request_test.go @@ -0,0 +1,136 @@ +package client + +import ( + "net/url" + "testing" + + "github.com/gofiber/fiber/v3/utils" +) + +func TestParamsSetParamsWithStruct(t *testing.T) { + t.Parallel() + + type args struct { + TInt int + TString string + TFloat float64 + TBool bool + TSlice []string + TIntSlice []int `param:"int_slice"` + } + + t.Run("the struct should be applied", func(t *testing.T) { + p := &Params{ + Values: make(url.Values), + } + p.SetParamsWithStruct(args{ + TInt: 5, + TString: "string", + TFloat: 3.1, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + + utils.AssertEqual(t, "5", p.Get("TInt")) + utils.AssertEqual(t, "string", p.Get("TString")) + utils.AssertEqual(t, "3.1", p.Get("TFloat")) + utils.AssertEqual(t, true, func() bool { + for _, v := range p.Values["TSlice"] { + if v == "foo" { + return true + } + } + return false + }()) + utils.AssertEqual(t, true, func() bool { + for _, v := range p.Values["TSlice"] { + if v == "bar" { + return true + } + } + return false + }()) + utils.AssertEqual(t, true, func() bool { + for _, v := range p.Values["int_slice"] { + if v == "1" { + return true + } + } + return false + }()) + utils.AssertEqual(t, true, func() bool { + for _, v := range p.Values["int_slice"] { + if v == "2" { + return true + } + } + return false + }()) + }) + + t.Run("the pointer of a struct should be applied", func(t *testing.T) { + p := &Params{ + Values: make(url.Values), + } + p.SetParamsWithStruct(&args{ + TInt: 5, + TString: "string", + TFloat: 3.1, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + + utils.AssertEqual(t, "5", p.Get("TInt")) + utils.AssertEqual(t, "string", p.Get("TString")) + utils.AssertEqual(t, "3.1", p.Get("TFloat")) + utils.AssertEqual(t, true, func() bool { + for _, v := range p.Values["TSlice"] { + if v == "foo" { + return true + } + } + return false + }()) + utils.AssertEqual(t, true, func() bool { + for _, v := range p.Values["TSlice"] { + if v == "bar" { + return true + } + } + return false + }()) + utils.AssertEqual(t, true, func() bool { + for _, v := range p.Values["int_slice"] { + if v == "1" { + return true + } + } + return false + }()) + utils.AssertEqual(t, true, func() bool { + for _, v := range p.Values["int_slice"] { + if v == "2" { + return true + } + } + return false + }()) + }) + + t.Run("the zero val should be ignore", func(t *testing.T) { + p := &Params{ + Values: make(url.Values), + } + p.SetParamsWithStruct(&args{ + TInt: 0, + TString: "", + TFloat: 0.0, + }) + + utils.AssertEqual(t, "", p.Get("TInt")) + utils.AssertEqual(t, "", p.Get("TString")) + utils.AssertEqual(t, "", p.Get("TFloat")) + utils.AssertEqual(t, 0, len(p.Values["TSlice"])) + utils.AssertEqual(t, 0, len(p.Values["int_slice"])) + }) +} From 3d92f090b8dd016dd2c2975a69c87ff764b2ba16 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Tue, 2 Aug 2022 09:08:42 +0800 Subject: [PATCH 008/118] =?UTF-8?q?=F0=9F=9A=A7=20v3:=20change=20to=20fast?= =?UTF-8?q?http's=20header=20and=20args?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 16 +++-------- client/hooks.go | 44 ++++++++++++++---------------- client/hooks_test.go | 57 ++++++++++---------------------------- client/request.go | 20 ++++---------- client/request_test.go | 62 +++++++++++++++++++++--------------------- 5 files changed, 76 insertions(+), 123 deletions(-) diff --git a/client/client.go b/client/client.go index 3e5f8d96c1..843445d16a 100644 --- a/client/client.go +++ b/client/client.go @@ -1,8 +1,6 @@ package client import ( - "net/http" - "net/url" "sync" "github.com/gofiber/fiber/v3" @@ -151,15 +149,9 @@ func (c *Client) DelParams(key ...string) *Client { func (c *Client) Reset() { c.baseUrl = "" - for k := range c.header.Header { - delete(c.header.Header, k) - } - - for k := range c.params.Values { - delete(c.params.Values, k) - } - c.core.reset() + c.header.Reset() + c.params.Reset() } // Get provide a API like axios which send get request. @@ -193,10 +185,10 @@ func AcquireClient() (c *Client) { c = &Client{ core: AcquireCore(), header: &Header{ - Header: make(http.Header), + RequestHeader: &fasthttp.RequestHeader{}, }, params: &Params{ - Values: make(url.Values), + Args: fasthttp.AcquireArgs(), }, } return diff --git a/client/hooks.go b/client/hooks.go index b7d24658dd..4fa0bcb605 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -4,10 +4,12 @@ import ( "bytes" "fmt" "net" - "net/url" "regexp" "strconv" "strings" + + "github.com/gofiber/fiber/v3/utils" + "github.com/valyala/fasthttp" ) var ( @@ -65,23 +67,19 @@ func parserURL(c *Client, req *Request) error { // merge query params hashSplit := strings.Split(splitUrl[1], "#") hashSplit = append(hashSplit, "") - queryParams, err := url.ParseQuery(hashSplit[0]) - if err != nil { - return err - } - for k, v := range c.params.Values { - for _, vv := range v { - queryParams.Add(k, vv) - } - } + args := fasthttp.AcquireArgs() + defer func() { + fasthttp.ReleaseArgs(args) + }() - for k, v := range req.params.Values { - for _, vv := range v { - queryParams.Add(k, vv) - } - } - - req.rawRequest.URI().SetQueryString(queryParams.Encode()) + args.Parse(hashSplit[0]) + c.params.VisitAll(func(key, value []byte) { + args.AddBytesKV(key, value) + }) + req.params.VisitAll(func(key, value []byte) { + args.AddBytesKV(key, value) + }) + req.rawRequest.URI().SetQueryStringBytes(utils.CopyBytes(args.QueryString())) req.rawRequest.URI().SetHash(hashSplit[1]) return nil @@ -92,13 +90,13 @@ func parserURL(c *Client, req *Request) error { // TODO: Header should be set automatically based on data. // TODO: User-Agent should be set? func parserHeader(c *Client, req *Request) error { - for k, v := range c.header.Header { - req.rawRequest.Header.Set(k, strings.Join(v, ", ")) - } + c.header.VisitAll(func(key, value []byte) { + req.rawRequest.Header.SetBytesKV(key, value) + }) - for k, v := range req.header.Header { - req.rawRequest.Header.Set(k, strings.Join(v, ", ")) - } + req.header.VisitAll(func(key, value []byte) { + req.rawRequest.Header.SetBytesKV(key, value) + }) return nil } diff --git a/client/hooks_test.go b/client/hooks_test.go index 320c119533..f6fa4738c8 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -2,13 +2,11 @@ package client import ( "fmt" - "net/http" "net/url" "testing" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/utils" - "github.com/valyala/fasthttp" ) func TestParserURL(t *testing.T) { @@ -111,20 +109,12 @@ func TestParserHeader(t *testing.T) { t.Parallel() t.Run("client header should be set", func(t *testing.T) { - client := &Client{ - header: &Header{ - Header: map[string][]string{ - fiber.HeaderContentType: {"application/json"}, - }, - }, - } + client := AcquireClient(). + SetHeaders(map[string]string{ + fiber.HeaderContentType: "application/json", + }) - req := &Request{ - header: &Header{ - Header: make(http.Header), - }, - rawRequest: fasthttp.AcquireRequest(), - } + req := AcquireRequest() err := parserHeader(client, req) utils.AssertEqual(t, nil, err) @@ -132,20 +122,12 @@ func TestParserHeader(t *testing.T) { }) t.Run("request header should be set", func(t *testing.T) { - client := &Client{ - header: &Header{ - Header: make(http.Header), - }, - } + client := AcquireClient() - req := &Request{ - header: &Header{ - Header: map[string][]string{ - fiber.HeaderContentType: {"application/json", "utf-8"}, - }, - }, - rawRequest: fasthttp.AcquireRequest(), - } + req := AcquireRequest(). + SetHeaders(map[string]string{ + fiber.HeaderContentType: "application/json, utf-8", + }) err := parserHeader(client, req) utils.AssertEqual(t, nil, err) @@ -153,22 +135,11 @@ func TestParserHeader(t *testing.T) { }) t.Run("request header should override client header", func(t *testing.T) { - client := &Client{ - header: &Header{ - Header: map[string][]string{ - fiber.HeaderContentType: {"application/xml"}, - }, - }, - } + client := AcquireClient(). + SetHeader(fiber.HeaderContentType, "application/xml") - req := &Request{ - header: &Header{ - Header: map[string][]string{ - fiber.HeaderContentType: {"application/json", "utf-8"}, - }, - }, - rawRequest: fasthttp.AcquireRequest(), - } + req := AcquireRequest(). + SetHeader(fiber.HeaderContentType, "application/json, utf-8") err := parserHeader(client, req) utils.AssertEqual(t, nil, err) diff --git a/client/request.go b/client/request.go index 5df2d020b3..4d2bc738ec 100644 --- a/client/request.go +++ b/client/request.go @@ -2,8 +2,6 @@ package client import ( "context" - "net/http" - "net/url" "reflect" "strconv" "sync" @@ -129,21 +127,15 @@ func (r *Request) Reset() { r.method = fiber.MethodGet r.ctx = nil - for k := range r.header.Header { - delete(r.header.Header, k) - } - - for k := range r.params.Values { - delete(r.params.Values, k) - } - + r.header.Reset() + r.params.Reset() r.rawRequest.Reset() } // Header is a wrapper which wrap http.Header, // the header in client and request will store in it. type Header struct { - http.Header + *fasthttp.RequestHeader } // AddHeaders receive a map and add each value to header. @@ -165,7 +157,7 @@ func (h *Header) SetHeaders(r map[string]string) { // Params is a wrapper which wrap url.Values, // the query string and formdata in client and request will store in it. type Params struct { - url.Values + *fasthttp.Args } // AddParams receive a map and add each value to param. @@ -255,8 +247,8 @@ func AcquireRequest() (req *Request) { } req = &Request{ - header: &Header{Header: make(http.Header)}, - params: &Params{Values: make(url.Values)}, + header: &Header{RequestHeader: &fasthttp.RequestHeader{}}, + params: &Params{Args: fasthttp.AcquireArgs()}, rawRequest: fasthttp.AcquireRequest(), } return diff --git a/client/request_test.go b/client/request_test.go index c650975eba..69669b8ced 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -1,10 +1,10 @@ package client import ( - "net/url" "testing" "github.com/gofiber/fiber/v3/utils" + "github.com/valyala/fasthttp" ) func TestParamsSetParamsWithStruct(t *testing.T) { @@ -21,7 +21,7 @@ func TestParamsSetParamsWithStruct(t *testing.T) { t.Run("the struct should be applied", func(t *testing.T) { p := &Params{ - Values: make(url.Values), + Args: fasthttp.AcquireArgs(), } p.SetParamsWithStruct(args{ TInt: 5, @@ -31,36 +31,36 @@ func TestParamsSetParamsWithStruct(t *testing.T) { TIntSlice: []int{1, 2}, }) - utils.AssertEqual(t, "5", p.Get("TInt")) - utils.AssertEqual(t, "string", p.Get("TString")) - utils.AssertEqual(t, "3.1", p.Get("TFloat")) + utils.AssertEqual(t, []byte("5"), p.Peek("TInt")) + utils.AssertEqual(t, []byte("string"), p.Peek("TString")) + utils.AssertEqual(t, []byte("3.1"), p.Peek("TFloat")) utils.AssertEqual(t, true, func() bool { - for _, v := range p.Values["TSlice"] { - if v == "foo" { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "foo" { return true } } return false }()) utils.AssertEqual(t, true, func() bool { - for _, v := range p.Values["TSlice"] { - if v == "bar" { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "bar" { return true } } return false }()) utils.AssertEqual(t, true, func() bool { - for _, v := range p.Values["int_slice"] { - if v == "1" { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "1" { return true } } return false }()) utils.AssertEqual(t, true, func() bool { - for _, v := range p.Values["int_slice"] { - if v == "2" { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "2" { return true } } @@ -70,7 +70,7 @@ func TestParamsSetParamsWithStruct(t *testing.T) { t.Run("the pointer of a struct should be applied", func(t *testing.T) { p := &Params{ - Values: make(url.Values), + Args: fasthttp.AcquireArgs(), } p.SetParamsWithStruct(&args{ TInt: 5, @@ -80,36 +80,36 @@ func TestParamsSetParamsWithStruct(t *testing.T) { TIntSlice: []int{1, 2}, }) - utils.AssertEqual(t, "5", p.Get("TInt")) - utils.AssertEqual(t, "string", p.Get("TString")) - utils.AssertEqual(t, "3.1", p.Get("TFloat")) + utils.AssertEqual(t, []byte("5"), p.Peek("TInt")) + utils.AssertEqual(t, []byte("string"), p.Peek("TString")) + utils.AssertEqual(t, []byte("3.1"), p.Peek("TFloat")) utils.AssertEqual(t, true, func() bool { - for _, v := range p.Values["TSlice"] { - if v == "foo" { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "foo" { return true } } return false }()) utils.AssertEqual(t, true, func() bool { - for _, v := range p.Values["TSlice"] { - if v == "bar" { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "bar" { return true } } return false }()) utils.AssertEqual(t, true, func() bool { - for _, v := range p.Values["int_slice"] { - if v == "1" { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "1" { return true } } return false }()) utils.AssertEqual(t, true, func() bool { - for _, v := range p.Values["int_slice"] { - if v == "2" { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "2" { return true } } @@ -119,7 +119,7 @@ func TestParamsSetParamsWithStruct(t *testing.T) { t.Run("the zero val should be ignore", func(t *testing.T) { p := &Params{ - Values: make(url.Values), + Args: fasthttp.AcquireArgs(), } p.SetParamsWithStruct(&args{ TInt: 0, @@ -127,10 +127,10 @@ func TestParamsSetParamsWithStruct(t *testing.T) { TFloat: 0.0, }) - utils.AssertEqual(t, "", p.Get("TInt")) - utils.AssertEqual(t, "", p.Get("TString")) - utils.AssertEqual(t, "", p.Get("TFloat")) - utils.AssertEqual(t, 0, len(p.Values["TSlice"])) - utils.AssertEqual(t, 0, len(p.Values["int_slice"])) + utils.AssertEqual(t, "", string(p.Peek("TInt"))) + utils.AssertEqual(t, "", string(p.Peek("TString"))) + utils.AssertEqual(t, "", string(p.Peek("TFloat"))) + utils.AssertEqual(t, 0, len(p.PeekMulti("TSlice"))) + utils.AssertEqual(t, 0, len(p.PeekMulti("int_slice"))) }) } From 3dc96044e26c24b4c2095601e96b1760a359b5ff Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Tue, 2 Aug 2022 14:46:53 +0800 Subject: [PATCH 009/118] =?UTF-8?q?=E2=9C=A8=20v3:=20add=20body=20and=20ua?= =?UTF-8?q?=20setting?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 21 ++++-- client/core.go | 2 +- client/hooks.go | 64 +++++++++++++++++- client/hooks_test.go | 144 +++++++++++++++++++++++++++++++++++++++++ client/request.go | 59 +++++++++++++++-- client/request_test.go | 15 +++++ 6 files changed, 290 insertions(+), 15 deletions(-) diff --git a/client/client.go b/client/client.go index 843445d16a..733a40ee3a 100644 --- a/client/client.go +++ b/client/client.go @@ -11,9 +11,10 @@ import ( type Client struct { core *Core - baseUrl string - header *Header - params *Params + baseUrl string + header *Header + params *Params + userAgent string } // Add user-defined request hooks. @@ -145,9 +146,18 @@ func (c *Client) DelParams(key ...string) *Client { return c } +// SetUserAgent method sets userAgent field and its value in the client instance. +// This ua will be applied to all requests raised from this client instance. +// Also it can be overridden at request level ua options. +func (c *Client) SetUserAgent(ua string) *Client { + c.userAgent = ua + return c +} + // Reset clear Client object. func (c *Client) Reset() { c.baseUrl = "" + c.userAgent = "" c.core.reset() c.header.Reset() @@ -164,8 +174,9 @@ func (c *Client) Get(url string) (*Response, error) { } var ( - defaultClient *Client - clientPool sync.Pool + defaultClient *Client + defaultUserAgent = "fiber" + clientPool sync.Pool ) func init() { diff --git a/client/core.go b/client/core.go index 46d0dd7682..f4e85bb2e9 100644 --- a/client/core.go +++ b/client/core.go @@ -207,7 +207,7 @@ func AcquireCore() (c *Core) { c = &Core{ client: &fasthttp.HostClient{}, userRequestHooks: []RequestHook{}, - buildinRequestHooks: []RequestHook{parserURL, parserHeader}, + buildinRequestHooks: []RequestHook{parserURL, parserHeader, parserBody}, userResponseHooks: []ResponseHook{}, buildinResposeHooks: []ResponseHook{}, plugins: []Plugin{}, diff --git a/client/hooks.go b/client/hooks.go index 4fa0bcb605..bfdd53a1be 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "net" + "reflect" "regexp" "strconv" "strings" @@ -17,6 +18,13 @@ var ( httpsBytes = []byte("https") protocolCheck = regexp.MustCompile(`^https?://.*$`) + + headerAccept = "Accept" + + applicationJSON = "application/json" + applicationXML = "application/xml" + applicationForm = "application/x-www-form-urlencoded" + multipartFormData = "multipart/form-data" ) // addMissingPort will add the corresponding port number for host. @@ -87,9 +95,10 @@ func parserURL(c *Client, req *Request) error { // parserHeader will make request header up. // It will merge headers from client and request. -// TODO: Header should be set automatically based on data. -// TODO: User-Agent should be set? +// Header should be set automatically based on data. +// User-Agent should be set. func parserHeader(c *Client, req *Request) error { + // merge header c.header.VisitAll(func(key, value []byte) { req.rawRequest.Header.SetBytesKV(key, value) }) @@ -98,5 +107,56 @@ func parserHeader(c *Client, req *Request) error { req.rawRequest.Header.SetBytesKV(key, value) }) + // according to data set content-type + switch req.bodyType { + case jsonBody: + req.rawRequest.Header.SetContentType(applicationJSON) + req.rawRequest.Header.Set(headerAccept, applicationJSON) + case xmlBody: + req.rawRequest.Header.SetContentType(applicationXML) + case formBody: + req.rawRequest.Header.SetContentType(applicationForm) + case filesBody: + req.rawRequest.Header.SetContentType(multipartFormData) + default: + } + + // set useragent + req.rawRequest.Header.SetUserAgent(defaultUserAgent) + if c.userAgent != "" { + req.rawRequest.Header.SetUserAgent(c.userAgent) + } + if req.userAgent != "" { + req.rawRequest.Header.SetUserAgent(req.userAgent) + } + + return nil +} + +// parserBody automatically serializes the data according to +// the data type and stores it in the body of the rawRequest +func parserBody(c *Client, req *Request) error { + switch req.bodyType { + case jsonBody: + body, err := c.core.jsonMarshal(req.body) + if err != nil { + return err + } + req.rawRequest.SetBody(body) + case xmlBody: + body, err := c.core.xmlMarshal(req.body) + if err != nil { + return err + } + req.rawRequest.SetBody(body) + case formBody: + case filesBody: + case rawBody: + if body, ok := req.body.([]byte); ok { + req.rawRequest.SetBody(body) + } else { + return fmt.Errorf("the raw body should be []byte, but we receive %s", reflect.TypeOf(req.body).Kind().String()) + } + } return nil } diff --git a/client/hooks_test.go b/client/hooks_test.go index f6fa4738c8..ee248997ae 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -1,6 +1,7 @@ package client import ( + "encoding/xml" "fmt" "net/url" "testing" @@ -9,6 +10,46 @@ import ( "github.com/gofiber/fiber/v3/utils" ) +func TestAddMissingPort(t *testing.T) { + type args struct { + addr string + isTLS bool + } + tests := []struct { + name string + args args + want string + }{ + { + name: "do anything", + args: args{ + addr: "example.com:1234", + }, + want: "example.com:1234", + }, + { + name: "add 80 port", + args: args{ + addr: "example.com", + }, + want: "example.com:80", + }, + { + name: "add 443 port", + args: args{ + addr: "example.com", + isTLS: true, + }, + want: "example.com:443", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + utils.AssertEqual(t, tt.want, addMissingPort(tt.args.addr, tt.args.isTLS)) + }) + } +} + func TestParserURL(t *testing.T) { t.Parallel() @@ -145,4 +186,107 @@ func TestParserHeader(t *testing.T) { utils.AssertEqual(t, nil, err) utils.AssertEqual(t, []byte("application/json, utf-8"), req.rawRequest.Header.ContentType()) }) + + t.Run("auto set json header", func(t *testing.T) { + type jsonData struct { + Name string `json:"name"` + } + client := AcquireClient() + req := AcquireRequest(). + SetJSON(jsonData{ + Name: "foo", + }) + + err := parserHeader(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte(applicationJSON), req.rawRequest.Header.ContentType()) + }) + + t.Run("auto set xml header", func(t *testing.T) { + type xmlData struct { + XMLName xml.Name `xml:"body"` + Name string `xml:"name"` + } + client := AcquireClient() + req := AcquireRequest(). + SetXML(xmlData{ + Name: "foo", + }) + + err := parserHeader(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte(applicationXML), req.rawRequest.Header.ContentType()) + }) + + t.Run("ua should have default value", func(t *testing.T) { + client := AcquireClient() + req := AcquireRequest() + + err := parserHeader(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte("fiber"), req.rawRequest.Header.UserAgent()) + }) + + t.Run("ua in client should be set", func(t *testing.T) { + client := AcquireClient().SetUserAgent("foo") + req := AcquireRequest() + + err := parserHeader(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte("foo"), req.rawRequest.Header.UserAgent()) + }) + + t.Run("ua in request should have higher level", func(t *testing.T) { + client := AcquireClient().SetUserAgent("foo") + req := AcquireRequest().SetUserAgent("bar") + + err := parserHeader(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte("bar"), req.rawRequest.Header.UserAgent()) + }) +} + +func TestParserBody(t *testing.T) { + t.Parallel() + + t.Run("json body", func(t *testing.T) { + type jsonData struct { + Name string `json:"name"` + } + client := AcquireClient() + req := AcquireRequest(). + SetJSON(jsonData{ + Name: "foo", + }) + + err := parserBody(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte("{\"name\":\"foo\"}"), req.rawRequest.Body()) + }) + + t.Run("xml body", func(t *testing.T) { + type xmlData struct { + XMLName xml.Name `xml:"body"` + Name string `xml:"name"` + } + client := AcquireClient() + req := AcquireRequest(). + SetXML(xmlData{ + Name: "foo", + }) + + err := parserBody(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte("foo"), req.rawRequest.Body()) + }) + + t.Run("raw body", func(t *testing.T) { + client := AcquireClient() + req := AcquireRequest(). + SetRawBody([]byte("hello world")) + + err := parserBody(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte("hello world"), req.rawRequest.Body()) + }) } diff --git a/client/request.go b/client/request.go index 4d2bc738ec..dfe09457b3 100644 --- a/client/request.go +++ b/client/request.go @@ -10,12 +10,28 @@ import ( "github.com/valyala/fasthttp" ) +type bodyType int + +const ( + noBody bodyType = iota + jsonBody + xmlBody + formBody + filesBody + rawBody +) + type Request struct { - url string - method string - ctx context.Context - header *Header - params *Params + url string + method string + ctx context.Context + userAgent string + header *Header + params *Params + + body any + bodyType bodyType + rawRequest *fasthttp.Request } @@ -121,11 +137,42 @@ func (r *Request) DelParams(key ...string) *Request { return r } +// SetUserAgent method sets user agent in request. +// It will override user agent which set in client instance. +func (r *Request) SetUserAgent(ua string) *Request { + r.userAgent = ua + return r +} + +// SetJSON method sets json body in request. +func (r *Request) SetJSON(v any) *Request { + r.body = v + r.bodyType = jsonBody + return r +} + +// SetXML method sets xml body in request. +func (r *Request) SetXML(v any) *Request { + r.body = v + r.bodyType = xmlBody + return r +} + +// SetRawBody method sets body with raw data in request. +func (r *Request) SetRawBody(v []byte) *Request { + r.body = v + r.bodyType = rawBody + return r +} + // Reset clear Request object, used by ReleaseRequest method. func (r *Request) Reset() { r.url = "" r.method = fiber.MethodGet r.ctx = nil + r.userAgent = "" + r.body = nil + r.bodyType = noBody r.header.Reset() r.params.Reset() @@ -198,8 +245,6 @@ func (p *Params) SetParamsWithStruct(v any) { case reflect.Bool: if val.Bool() { p.Add(name, "true") - } else { - p.Add(name, "false") } case reflect.String: p.Add(name, val.String()) diff --git a/client/request_test.go b/client/request_test.go index 69669b8ced..26bba6b04f 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -11,6 +11,7 @@ func TestParamsSetParamsWithStruct(t *testing.T) { t.Parallel() type args struct { + unexport int TInt int TString string TFloat float64 @@ -24,16 +25,20 @@ func TestParamsSetParamsWithStruct(t *testing.T) { Args: fasthttp.AcquireArgs(), } p.SetParamsWithStruct(args{ + unexport: 5, TInt: 5, TString: "string", TFloat: 3.1, + TBool: false, TSlice: []string{"foo", "bar"}, TIntSlice: []int{1, 2}, }) + utils.AssertEqual(t, "", string(p.Peek("unexport"))) utils.AssertEqual(t, []byte("5"), p.Peek("TInt")) utils.AssertEqual(t, []byte("string"), p.Peek("TString")) utils.AssertEqual(t, []byte("3.1"), p.Peek("TFloat")) + utils.AssertEqual(t, "", string(p.Peek("TBool"))) utils.AssertEqual(t, true, func() bool { for _, v := range p.PeekMulti("TSlice") { if string(v) == "foo" { @@ -76,6 +81,7 @@ func TestParamsSetParamsWithStruct(t *testing.T) { TInt: 5, TString: "string", TFloat: 3.1, + TBool: true, TSlice: []string{"foo", "bar"}, TIntSlice: []int{1, 2}, }) @@ -83,6 +89,7 @@ func TestParamsSetParamsWithStruct(t *testing.T) { utils.AssertEqual(t, []byte("5"), p.Peek("TInt")) utils.AssertEqual(t, []byte("string"), p.Peek("TString")) utils.AssertEqual(t, []byte("3.1"), p.Peek("TFloat")) + utils.AssertEqual(t, "true", string(p.Peek("TBool"))) utils.AssertEqual(t, true, func() bool { for _, v := range p.PeekMulti("TSlice") { if string(v) == "foo" { @@ -133,4 +140,12 @@ func TestParamsSetParamsWithStruct(t *testing.T) { utils.AssertEqual(t, 0, len(p.PeekMulti("TSlice"))) utils.AssertEqual(t, 0, len(p.PeekMulti("int_slice"))) }) + + t.Run("error type should ignore", func(t *testing.T) { + p := &Params{ + Args: fasthttp.AcquireArgs(), + } + p.SetParamsWithStruct(5) + utils.AssertEqual(t, 0, p.Len()) + }) } From f6777380eb4ba58a7930414975c2c06c12f50f07 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Wed, 3 Aug 2022 15:29:16 +0800 Subject: [PATCH 010/118] =?UTF-8?q?=F0=9F=9A=A7=20v3:=20add=20cookie=20sup?= =?UTF-8?q?port?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 43 ++++++++++++ client/hooks.go | 9 +++ client/hooks_test.go | 65 +++++++++++++++++ client/request.go | 161 ++++++++++++++++++++++++++++++++++--------- 4 files changed, 247 insertions(+), 31 deletions(-) diff --git a/client/client.go b/client/client.go index 733a40ee3a..64dfeb4e27 100644 --- a/client/client.go +++ b/client/client.go @@ -15,6 +15,7 @@ type Client struct { header *Header params *Params userAgent string + cookies *Cookie } // Add user-defined request hooks. @@ -154,11 +155,42 @@ func (c *Client) SetUserAgent(ua string) *Client { return c } +// SetCookie method sets a single cookie field and its value in the client instance. +// These cookies will be applied to all requests raised from this client instance. +// Also it can be overridden at request level cookie options. +func (c *Client) SetCookie(key, val string) *Client { + c.cookies.SetCookie(key, val) + return c +} + +// SetCookies method sets multiple cookies field and its values at one go in the client instance. +// These cookies will be applied to all requests raised from this client instance. Also it can be +// overridden at request level cookie options. +func (c *Client) SetCookies(m map[string]string) *Client { + c.cookies.SetCookies(m) + return c +} + +// SetCookiesWithStruct method sets multiple cookies field and its values at one go in the client instance. +// These cookies will be applied to all requests raised from this client instance. Also it can be +// overridden at request level cookies options. +func (c *Client) SetCookiesWithStruct(v any) *Client { + c.cookies.SetCookiesWithStruct(v) + return c +} + +// DelCookies method deletes single or multiple cookies field and its valus in client. +func (c *Client) DelCookies(key ...string) *Client { + c.cookies.DelCookies(key...) + return c +} + // Reset clear Client object. func (c *Client) Reset() { c.baseUrl = "" c.userAgent = "" + c.cookies.Reset() c.core.reset() c.header.Reset() c.params.Reset() @@ -201,6 +233,7 @@ func AcquireClient() (c *Client) { params: &Params{ Args: fasthttp.AcquireArgs(), }, + cookies: &Cookie{}, } return } @@ -218,6 +251,16 @@ func C() *Client { return defaultClient } +// Replce the defaultClient, the returned function can undo. +func Replace(c *Client) func() { + oldClient := defaultClient + defaultClient = c + + return func() { + defaultClient = oldClient + } +} + // Get send a get request use defaultClient, a convenient method. func Get(url string) (*Response, error) { return defaultClient.Get(url) diff --git a/client/hooks.go b/client/hooks.go index bfdd53a1be..6e24f9098e 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -130,6 +130,15 @@ func parserHeader(c *Client, req *Request) error { req.rawRequest.Header.SetUserAgent(req.userAgent) } + // set cookie + c.cookies.VisitAll(func(key, val string) { + req.rawRequest.Header.SetCookie(key, val) + }) + + req.cookies.VisitAll(func(key, val string) { + req.rawRequest.Header.SetCookie(key, val) + }) + return nil } diff --git a/client/hooks_test.go b/client/hooks_test.go index ee248997ae..474c14f7ce 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -244,6 +244,71 @@ func TestParserHeader(t *testing.T) { utils.AssertEqual(t, nil, err) utils.AssertEqual(t, []byte("bar"), req.rawRequest.Header.UserAgent()) }) + + t.Run("client cookie should be set", func(t *testing.T) { + client := AcquireClient(). + SetCookie("foo", "bar"). + SetCookies(map[string]string{ + "bar": "foo", + "bar1": "foo1", + }). + DelCookies("bar1") + + req := AcquireRequest() + + err := parserHeader(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "bar", string(req.rawRequest.Header.Cookie("foo"))) + utils.AssertEqual(t, "foo", string(req.rawRequest.Header.Cookie("bar"))) + utils.AssertEqual(t, "", string(req.rawRequest.Header.Cookie("bar1"))) + }) + + t.Run("request cookie should be set", func(t *testing.T) { + type cookies struct { + Foo string `cookie:"foo"` + Bar int `cookie:"bar"` + } + + client := AcquireClient() + + req := AcquireRequest(). + SetCookiesWithStruct(&cookies{ + Foo: "bar", + Bar: 67, + }) + + err := parserHeader(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "bar", string(req.rawRequest.Header.Cookie("foo"))) + utils.AssertEqual(t, "67", string(req.rawRequest.Header.Cookie("bar"))) + utils.AssertEqual(t, "", string(req.rawRequest.Header.Cookie("bar1"))) + }) + + t.Run("request cookie will override client cookie", func(t *testing.T) { + type cookies struct { + Foo string `cookie:"foo"` + Bar int `cookie:"bar"` + } + + client := AcquireClient(). + SetCookie("foo", "bar"). + SetCookies(map[string]string{ + "bar": "foo", + "bar1": "foo1", + }) + + req := AcquireRequest(). + SetCookiesWithStruct(&cookies{ + Foo: "bar", + Bar: 67, + }) + + err := parserHeader(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "bar", string(req.rawRequest.Header.Cookie("foo"))) + utils.AssertEqual(t, "67", string(req.rawRequest.Header.Cookie("bar"))) + utils.AssertEqual(t, "foo1", string(req.rawRequest.Header.Cookie("bar1"))) + }) } func TestParserBody(t *testing.T) { diff --git a/client/request.go b/client/request.go index dfe09457b3..ee454dcb75 100644 --- a/client/request.go +++ b/client/request.go @@ -10,8 +10,16 @@ import ( "github.com/valyala/fasthttp" ) +// Implementing this interface allows data to be passed through the structure. +type WithStruct interface { + Add(string, string) + Del(string) +} + +// Types of request bodies. type bodyType int +// Enumeration definition of the request body type. const ( noBody bodyType = iota jsonBody @@ -28,6 +36,7 @@ type Request struct { userAgent string header *Header params *Params + cookies *Cookie body any bodyType bodyType @@ -144,6 +153,33 @@ func (r *Request) SetUserAgent(ua string) *Request { return r } +// SetCookie method sets a single cookie field and its value in the request instance. +// It will override cookie which set in client instance. +func (r *Request) SetCookie(key, val string) *Request { + r.cookies.SetCookie(key, val) + return r +} + +// SetCookies method sets multiple cookie field and its values at one go in the request instance. +// It will override cookie which set in client instance. +func (r *Request) SetCookies(m map[string]string) *Request { + r.cookies.SetCookies(m) + return r +} + +// SetCookiesWithStruct method sets multiple cookies field and its values at one go in the request instance. +// It will override cookie which set in client instance. +func (r *Request) SetCookiesWithStruct(v any) *Request { + r.cookies.SetCookiesWithStruct(v) + return r +} + +// DelParams method deletes single or multiple cookies field ant its values. +func (r *Request) DelCookies(key ...string) *Request { + r.cookies.DelCookies(key...) + return r +} + // SetJSON method sets json body in request. func (r *Request) SetJSON(v any) *Request { r.body = v @@ -174,6 +210,7 @@ func (r *Request) Reset() { r.body = nil r.bodyType = noBody + r.cookies.Reset() r.header.Reset() r.params.Reset() r.rawRequest.Reset() @@ -223,11 +260,102 @@ func (p *Params) SetParams(r map[string]string) { } } +// SetParamsWithStruct will override all params with struct or pointer of struct. +// Now nested structs are not currently supported. func (p *Params) SetParamsWithStruct(v any) { + SetValWithStruct(p, "param", v) +} + +// Cookie is a map which to store the cookies. +type Cookie map[string]string + +// Add method impl the method in WithStruct interface. +func (c Cookie) Add(key, val string) { + c[key] = val +} + +// Del method impl the method in WithStruct interface. +func (c Cookie) Del(key string) { + delete(c, key) +} + +// SetCookie method sets a signle val in Cookie. +func (c Cookie) SetCookie(key, val string) { + c[key] = val +} + +// SetCookies method sets multiple val in Cookie. +func (c Cookie) SetCookies(m map[string]string) { + for k, v := range m { + c[k] = v + } +} + +// SetCookiesWithStruct method sets multiple val in Cookie via a struct. +func (c Cookie) SetCookiesWithStruct(v any) { + SetValWithStruct(c, "cookie", v) +} + +// DelCookies method deletes mutiple val in Cookie. +func (c Cookie) DelCookies(key ...string) { + for _, v := range key { + c.Del(v) + } +} + +// VisitAll method receive a function which can travel the all val. +func (c Cookie) VisitAll(f func(key, val string)) { + for k, v := range c { + f(k, v) + } +} + +// Reset clear the Cookie object. +func (c Cookie) Reset() { + for k := range c { + delete(c, k) + } +} + +var requestPool sync.Pool + +// AcquireRequest returns an empty request object from the pool. +// +// The returned request may be returned to the pool with ReleaseRequest when no longer needed. +// This allows reducing GC load. +func AcquireRequest() (req *Request) { + reqv := requestPool.Get() + if reqv != nil { + req = reqv.(*Request) + return + } + + req = &Request{ + header: &Header{RequestHeader: &fasthttp.RequestHeader{}}, + params: &Params{Args: fasthttp.AcquireArgs()}, + cookies: &Cookie{}, + rawRequest: fasthttp.AcquireRequest(), + } + return +} + +// ReleaseRequest returns the object acquired via AcquireRequest to the pool. +// +// Do not access the released Request object, otherwise data races may occur. +func ReleaseRequest(req *Request) { + req.Reset() + requestPool.Put(req) +} + +// Set some values using structs. +// `p` is a structure that implements the WithStruct interface, +// The field name can be specified by `tagName`. +// `v` is a struct include some data. +func SetValWithStruct(p WithStruct, tagName string, v any) { valueOfV := reflect.ValueOf(v) typeOfV := reflect.TypeOf(v) - // The v should be struct or point of struct + // The v should be struct or point of struct if typeOfV.Kind() == reflect.Pointer && typeOfV.Elem().Kind() == reflect.Struct { valueOfV = valueOfV.Elem() typeOfV = typeOfV.Elem() @@ -264,7 +392,7 @@ func (p *Params) SetParamsWithStruct(v any) { continue } - name := field.Tag.Get("param") + name := field.Tag.Get(tagName) if name == "" { name = field.Name } @@ -277,32 +405,3 @@ func (p *Params) SetParamsWithStruct(v any) { setVal(name, val) } } - -var requestPool sync.Pool - -// AcquireRequest returns an empty request object from the pool. -// -// The returned request may be returned to the pool with ReleaseRequest when no longer needed. -// This allows reducing GC load. -func AcquireRequest() (req *Request) { - reqv := requestPool.Get() - if reqv != nil { - req = reqv.(*Request) - return - } - - req = &Request{ - header: &Header{RequestHeader: &fasthttp.RequestHeader{}}, - params: &Params{Args: fasthttp.AcquireArgs()}, - rawRequest: fasthttp.AcquireRequest(), - } - return -} - -// ReleaseRequest returns the object acquired via AcquireRequest to the pool. -// -// Do not access the released Request object, otherwise data races may occur. -func ReleaseRequest(req *Request) { - req.Reset() - requestPool.Put(req) -} From 108ce0843f5791f20192fe6c1884763dc7fab585 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Wed, 3 Aug 2022 22:48:37 +0800 Subject: [PATCH 011/118] =?UTF-8?q?=F0=9F=9A=A7=20v3:=20add=20path=20param?= =?UTF-8?q?=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 37 +++++++++++++++- client/hooks.go | 8 ++++ client/hooks_test.go | 45 ++++++++++++++++++++ client/request.go | 97 ++++++++++++++++++++++++++++++++++++++---- client/request_test.go | 8 ++-- 5 files changed, 181 insertions(+), 14 deletions(-) diff --git a/client/client.go b/client/client.go index 64dfeb4e27..a40523be7f 100644 --- a/client/client.go +++ b/client/client.go @@ -13,9 +13,10 @@ type Client struct { baseUrl string header *Header - params *Params + params *QueryParam userAgent string cookies *Cookie + path *PathParam } // Add user-defined request hooks. @@ -155,6 +156,36 @@ func (c *Client) SetUserAgent(ua string) *Client { return c } +// SetPathParam method sets a single path param field and its value in the client instance. +// These path params will be applied to all requests raised from this client instance. +// Also it can be overridden at request level path params options. +func (c *Client) SetPathParam(key, val string) *Client { + c.path.SetParam(key, val) + return c +} + +// SetPathParams method sets multiple path params field and its values at one go in the client instance. +// These path params will be applied to all requests raised from this client instance. Also it can be +// overridden at request level path params options. +func (c *Client) SetPathParams(m map[string]string) *Client { + c.path.SetParams(m) + return c +} + +// SetPathParamsWithStruct method sets multiple path params field and its values at one go in the client instance. +// These path params will be applied to all requests raised from this client instance. Also it can be +// overridden at request level path params options. +func (c *Client) SetPathParamsWithStruct(v any) *Client { + c.path.SetParamsWithStruct(v) + return c +} + +// DelPathParams method deletes single or multiple path params field and its valus in client. +func (c *Client) DelPathParams(key ...string) *Client { + c.path.DelParams(key...) + return c +} + // SetCookie method sets a single cookie field and its value in the client instance. // These cookies will be applied to all requests raised from this client instance. // Also it can be overridden at request level cookie options. @@ -190,6 +221,7 @@ func (c *Client) Reset() { c.baseUrl = "" c.userAgent = "" + c.path.Reset() c.cookies.Reset() c.core.reset() c.header.Reset() @@ -230,10 +262,11 @@ func AcquireClient() (c *Client) { header: &Header{ RequestHeader: &fasthttp.RequestHeader{}, }, - params: &Params{ + params: &QueryParam{ Args: fasthttp.AcquireArgs(), }, cookies: &Cookie{}, + path: &PathParam{}, } return } diff --git a/client/hooks.go b/client/hooks.go index 6e24f9098e..32740f5087 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -59,6 +59,14 @@ func parserURL(c *Client, req *Request) error { } } + // set path params + req.path.VisitAll(func(key, val string) { + uri = strings.Replace(uri, "{"+key+"}", val, -1) + }) + c.path.VisitAll(func(key, val string) { + uri = strings.Replace(uri, "{"+key+"}", val, -1) + }) + // set uri to request and orther related setting req.rawRequest.SetRequestURI(uri) rawUri := req.rawRequest.URI() diff --git a/client/hooks_test.go b/client/hooks_test.go index 474c14f7ce..a657184276 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -97,6 +97,51 @@ func TestParserURL(t *testing.T) { utils.AssertEqual(t, fmt.Errorf("url format error"), err) }) + t.Run("the path param from client", func(t *testing.T) { + client := AcquireClient(). + SetBaseURL("http://example.com/api/{id}"). + SetPathParam("id", "5") + req := AcquireRequest() + + err := parserURL(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "http://example.com/api/5", req.rawRequest.URI().String()) + }) + + t.Run("the path param from request", func(t *testing.T) { + client := AcquireClient(). + SetBaseURL("http://example.com/api/{id}/{name}"). + SetPathParam("id", "5") + req := AcquireRequest(). + SetURL("/{key}"). + SetPathParams(map[string]string{ + "name": "fiber", + "key": "val", + }). + DelPathParams("key") + + err := parserURL(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "http://example.com/api/5/fiber/%7Bkey%7D", req.rawRequest.URI().String()) + }) + + t.Run("the path param from request and client", func(t *testing.T) { + client := AcquireClient(). + SetBaseURL("http://example.com/api/{id}/{name}"). + SetPathParam("id", "5") + req := AcquireRequest(). + SetURL("/{key}"). + SetPathParams(map[string]string{ + "name": "fiber", + "key": "val", + "id": "12", + }) + + err := parserURL(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "http://example.com/api/12/fiber/%7Bkey%7D", req.rawRequest.URI().String()) + }) + t.Run("query params from client should be set", func(t *testing.T) { client := AcquireClient(). SetParam("foo", "bar") diff --git a/client/request.go b/client/request.go index ee454dcb75..67fe8aa7f8 100644 --- a/client/request.go +++ b/client/request.go @@ -35,8 +35,9 @@ type Request struct { ctx context.Context userAgent string header *Header - params *Params + params *QueryParam cookies *Cookie + path *PathParam body any bodyType bodyType @@ -174,12 +175,39 @@ func (r *Request) SetCookiesWithStruct(v any) *Request { return r } -// DelParams method deletes single or multiple cookies field ant its values. +// DelCookies method deletes single or multiple cookies field ant its values. func (r *Request) DelCookies(key ...string) *Request { r.cookies.DelCookies(key...) return r } +// SetPathParam method sets a single path param field and its value in the request instance. +// It will override path param which set in client instance. +func (r *Request) SetPathParam(key, val string) *Request { + r.path.SetParam(key, val) + return r +} + +// SetPathParams method sets multiple path params field and its values at one go in the request instance. +// It will override path param which set in client instance. +func (r *Request) SetPathParams(m map[string]string) *Request { + r.path.SetParams(m) + return r +} + +// SetParamsWithStruct method sets multiple path params field and its values at one go in the request instance. +// It will override path param which set in client instance. +func (r *Request) SetPathParamsWithStruct(v any) *Request { + r.path.SetParamsWithStruct(v) + return r +} + +// DelPathParams method deletes single or multiple path params field ant its values. +func (r *Request) DelPathParams(key ...string) *Request { + r.path.DelParams(key...) + return r +} + // SetJSON method sets json body in request. func (r *Request) SetJSON(v any) *Request { r.body = v @@ -210,6 +238,7 @@ func (r *Request) Reset() { r.body = nil r.bodyType = noBody + r.path.Reset() r.cookies.Reset() r.header.Reset() r.params.Reset() @@ -238,14 +267,14 @@ func (h *Header) SetHeaders(r map[string]string) { } } -// Params is a wrapper which wrap url.Values, +// QueryParam is a wrapper which wrap url.Values, // the query string and formdata in client and request will store in it. -type Params struct { +type QueryParam struct { *fasthttp.Args } // AddParams receive a map and add each value to param. -func (p *Params) AddParams(r map[string][]string) { +func (p *QueryParam) AddParams(r map[string][]string) { for k, v := range r { for _, vv := range v { p.Add(k, vv) @@ -254,7 +283,7 @@ func (p *Params) AddParams(r map[string][]string) { } // SetParams will override all params. -func (p *Params) SetParams(r map[string]string) { +func (p *QueryParam) SetParams(r map[string]string) { for k, v := range r { p.Set(k, v) } @@ -262,7 +291,7 @@ func (p *Params) SetParams(r map[string]string) { // SetParamsWithStruct will override all params with struct or pointer of struct. // Now nested structs are not currently supported. -func (p *Params) SetParamsWithStruct(v any) { +func (p *QueryParam) SetParamsWithStruct(v any) { SetValWithStruct(p, "param", v) } @@ -317,6 +346,57 @@ func (c Cookie) Reset() { } } +// PathParam is a map which to store the cookies. +type PathParam map[string]string + +// Add method impl the method in WithStruct interface. +func (p PathParam) Add(key, val string) { + p[key] = val +} + +// Del method impl the method in WithStruct interface. +func (p PathParam) Del(key string) { + delete(p, key) +} + +// SetParam method sets a signle val in PathParam. +func (p PathParam) SetParam(key, val string) { + p[key] = val +} + +// SetParams method sets multiple val in PathParam. +func (p PathParam) SetParams(m map[string]string) { + for k, v := range m { + p[k] = v + } +} + +// SetParamsWithStruct method sets multiple val in PathParam via a struct. +func (p PathParam) SetParamsWithStruct(v any) { + SetValWithStruct(p, "path", v) +} + +// DelParams method deletes mutiple val in PathParams. +func (p PathParam) DelParams(key ...string) { + for _, v := range key { + p.Del(v) + } +} + +// VisitAll method receive a function which can travel the all val. +func (p PathParam) VisitAll(f func(key, val string)) { + for k, v := range p { + f(k, v) + } +} + +// Reset clear the PathParams object. +func (p PathParam) Reset() { + for k := range p { + delete(p, k) + } +} + var requestPool sync.Pool // AcquireRequest returns an empty request object from the pool. @@ -332,8 +412,9 @@ func AcquireRequest() (req *Request) { req = &Request{ header: &Header{RequestHeader: &fasthttp.RequestHeader{}}, - params: &Params{Args: fasthttp.AcquireArgs()}, + params: &QueryParam{Args: fasthttp.AcquireArgs()}, cookies: &Cookie{}, + path: &PathParam{}, rawRequest: fasthttp.AcquireRequest(), } return diff --git a/client/request_test.go b/client/request_test.go index 26bba6b04f..11e3306fd1 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -21,7 +21,7 @@ func TestParamsSetParamsWithStruct(t *testing.T) { } t.Run("the struct should be applied", func(t *testing.T) { - p := &Params{ + p := &QueryParam{ Args: fasthttp.AcquireArgs(), } p.SetParamsWithStruct(args{ @@ -74,7 +74,7 @@ func TestParamsSetParamsWithStruct(t *testing.T) { }) t.Run("the pointer of a struct should be applied", func(t *testing.T) { - p := &Params{ + p := &QueryParam{ Args: fasthttp.AcquireArgs(), } p.SetParamsWithStruct(&args{ @@ -125,7 +125,7 @@ func TestParamsSetParamsWithStruct(t *testing.T) { }) t.Run("the zero val should be ignore", func(t *testing.T) { - p := &Params{ + p := &QueryParam{ Args: fasthttp.AcquireArgs(), } p.SetParamsWithStruct(&args{ @@ -142,7 +142,7 @@ func TestParamsSetParamsWithStruct(t *testing.T) { }) t.Run("error type should ignore", func(t *testing.T) { - p := &Params{ + p := &QueryParam{ Args: fasthttp.AcquireArgs(), } p.SetParamsWithStruct(5) From 51d9780df3906a31cca424babd99d48395625b8a Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Wed, 3 Aug 2022 22:51:13 +0800 Subject: [PATCH 012/118] =?UTF-8?q?=E2=9C=85=20v3:=20fix=20error=20test=20?= =?UTF-8?q?case?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/hooks_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/hooks_test.go b/client/hooks_test.go index a657184276..cb5044451d 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -139,7 +139,7 @@ func TestParserURL(t *testing.T) { err := parserURL(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "http://example.com/api/12/fiber/%7Bkey%7D", req.rawRequest.URI().String()) + utils.AssertEqual(t, "http://example.com/api/12/fiber/val", req.rawRequest.URI().String()) }) t.Run("query params from client should be set", func(t *testing.T) { From dc0e374c722008a4ca1e32499bb31fd8f6b79fac Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Thu, 4 Aug 2022 15:44:39 +0800 Subject: [PATCH 013/118] =?UTF-8?q?=F0=9F=9A=A7=20v3:=20add=20formdata=20a?= =?UTF-8?q?nd=20file=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/hooks.go | 104 ++++++++++++++++- client/hooks_test.go | 81 ++++++++++++++ client/request.go | 261 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 433 insertions(+), 13 deletions(-) diff --git a/client/hooks.go b/client/hooks.go index 32740f5087..e3488e2a1f 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -3,11 +3,17 @@ package client import ( "bytes" "fmt" + "io" + "math/rand" + "mime/multipart" "net" + "os" + "path/filepath" "reflect" "regexp" "strconv" "strings" + "time" "github.com/gofiber/fiber/v3/utils" "github.com/valyala/fasthttp" @@ -25,6 +31,12 @@ var ( applicationXML = "application/xml" applicationForm = "application/x-www-form-urlencoded" multipartFormData = "multipart/form-data" + + src = rand.NewSource(time.Now().UnixNano()) + letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + letterIdxBits = 6 // 6 bits to represent a letter index + letterIdxMask = 1<= 0; { + if remain == 0 { + cache, remain = src.Int63(), letterIdxMax + } + + if idx := int(cache & int64(letterIdxMask)); idx < length { + b[i] = letterBytes[idx] + i-- + } + cache >>= int64(letterIdxBits) + remain-- + } + + return utils.UnsafeString(b) +} + // parserURL will set the options for the hostclient // and normalize the url. // The baseUrl will be merge with request uri. @@ -126,6 +157,8 @@ func parserHeader(c *Client, req *Request) error { req.rawRequest.Header.SetContentType(applicationForm) case filesBody: req.rawRequest.Header.SetContentType(multipartFormData) + // set boundary + req.rawRequest.Header.SetMultipartFormBoundary(req.boundary) default: } @@ -152,7 +185,7 @@ func parserHeader(c *Client, req *Request) error { // parserBody automatically serializes the data according to // the data type and stores it in the body of the rawRequest -func parserBody(c *Client, req *Request) error { +func parserBody(c *Client, req *Request) (err error) { switch req.bodyType { case jsonBody: body, err := c.core.jsonMarshal(req.body) @@ -167,7 +200,76 @@ func parserBody(c *Client, req *Request) error { } req.rawRequest.SetBody(body) case formBody: + req.rawRequest.SetBody(req.formData.QueryString()) case filesBody: + mw := multipart.NewWriter(req.rawRequest.BodyWriter()) + mw.SetBoundary(req.boundary) + defer func() { + err = mw.Close() + if err != nil { + return + } + }() + + // add formdata + req.formData.VisitAll(func(key, value []byte) { + if err != nil { + return + } + err = mw.WriteField(utils.UnsafeString(key), utils.UnsafeString(value)) + }) + if err != nil { + return + } + + // add file + b := make([]byte, 512) + for i, v := range req.files { + if v.name == "" && v.path == "" { + return fmt.Errorf("the file should have a name") + } + + // if name is not exist, set name + if v.name == "" && v.path != "" { + v.path = filepath.Clean(v.path) + v.name = filepath.Base(v.name) + } + + // if param is not exist, set it + if v.paramName == "" { + v.paramName = "file" + fmt.Sprint(i) + } + + // check the reader + if v.reader == nil { + v.reader, err = os.Open(v.path) + if err != nil { + return + } + } + + // wirte file + w, err := mw.CreateFormFile(v.paramName, v.name) + if err != nil { + return err + } + + for { + _, err := v.reader.Read(b) + if err != nil && err != io.EOF { + return err + } + + if err == io.EOF { + break + } + + w.Write(b) + } + + // ignore err + v.reader.Close() + } case rawBody: if body, ok := req.body.([]byte); ok { req.rawRequest.SetBody(body) diff --git a/client/hooks_test.go b/client/hooks_test.go index cb5044451d..6b30e512b0 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -3,7 +3,9 @@ package client import ( "encoding/xml" "fmt" + "io" "net/url" + "strings" "testing" "github.com/gofiber/fiber/v3" @@ -50,6 +52,24 @@ func TestAddMissingPort(t *testing.T) { } } +func TestRandString(t *testing.T) { + tests := []struct { + name string + args int + }{ + { + name: "test generate", + args: 16, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := randString(tt.args) + utils.AssertEqual(t, 16, len(got)) + }) + } +} + func TestParserURL(t *testing.T) { t.Parallel() @@ -263,6 +283,31 @@ func TestParserHeader(t *testing.T) { utils.AssertEqual(t, []byte(applicationXML), req.rawRequest.Header.ContentType()) }) + t.Run("auto set form data header", func(t *testing.T) { + client := AcquireClient() + req := AcquireRequest(). + SetFormDatas(map[string]string{ + "foo": "bar", + "ball": "cricle and square", + }) + + err := parserHeader(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, applicationForm, string(req.rawRequest.Header.ContentType())) + }) + + t.Run("auto set file header", func(t *testing.T) { + client := AcquireClient() + req := AcquireRequest(). + AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))). + SetFormData("foo", "bar") + + err := parserHeader(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Header.MultipartFormBoundary()), "--FiberFormBoundary")) + utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Header.ContentType()), multipartFormData)) + }) + t.Run("ua should have default value", func(t *testing.T) { client := AcquireClient() req := AcquireRequest() @@ -390,6 +435,42 @@ func TestParserBody(t *testing.T) { utils.AssertEqual(t, []byte("foo"), req.rawRequest.Body()) }) + t.Run("form data body", func(t *testing.T) { + client := AcquireClient() + req := AcquireRequest(). + SetFormDatas(map[string]string{ + "ball": "cricle and square", + }) + + err := parserBody(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "ball=cricle+and+square", string(req.rawRequest.Body())) + }) + + t.Run("file body", func(t *testing.T) { + client := AcquireClient() + req := AcquireRequest(). + AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))) + + err := parserBody(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Body()), "----FiberFormBoundary")) + utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Body()), "world")) + }) + + t.Run("file and form data", func(t *testing.T) { + client := AcquireClient() + req := AcquireRequest(). + AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))). + SetFormData("foo", "bar") + + err := parserBody(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Body()), "----FiberFormBoundary")) + utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Body()), "world")) + utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Body()), "bar")) + }) + t.Run("raw body", func(t *testing.T) { client := AcquireClient() req := AcquireRequest(). diff --git a/client/request.go b/client/request.go index 67fe8aa7f8..ed8e7f2ad5 100644 --- a/client/request.go +++ b/client/request.go @@ -2,6 +2,7 @@ package client import ( "context" + "io" "reflect" "strconv" "sync" @@ -34,12 +35,15 @@ type Request struct { method string ctx context.Context userAgent string + boundary string header *Header params *QueryParam cookies *Cookie path *PathParam body any + formData *FormData + files []*File bodyType bodyType rawRequest *fasthttp.Request @@ -90,14 +94,14 @@ func (r *Request) SetHeader(key, val string) *Request { return r } -// AddHeaders method adds multiple headers field and its values at one go in the request instance. +// AddHeaders method adds multiple header fields and its values at one go in the request instance. // It will override header which set in client instance. func (r *Request) AddHeaders(h map[string][]string) *Request { r.header.AddHeaders(h) return r } -// SetHeaders method sets multiple headers field and its values at one go in the request instance. +// SetHeaders method sets multiple header fields and its values at one go in the request instance. // It will override header which set in client instance. func (r *Request) SetHeaders(h map[string]string) *Request { r.header.SetHeaders(h) @@ -118,28 +122,28 @@ func (r *Request) SetParam(key, val string) *Request { return r } -// AddParams method adds multiple params field and its values at one go in the request instance. +// AddParams method adds multiple param fields and its values at one go in the request instance. // It will override param which set in client instance. func (r *Request) AddParams(m map[string][]string) *Request { r.params.AddParams(m) return r } -// SetParams method sets multiple params field and its values at one go in the request instance. +// SetParams method sets multiple param fields and its values at one go in the request instance. // It will override param which set in client instance. func (r *Request) SetParams(m map[string]string) *Request { r.params.SetParams(m) return r } -// SetParamWithStruct method sets multiple params field and its values at one go in the request instance. +// SetParamWithStruct method sets multiple param fields and its values at one go in the request instance. // It will override param which set in client instance. func (r *Request) SetParamsWithStruct(v any) *Request { r.params.SetParamsWithStruct(v) return r } -// DelParams method deletes single or multiple params field ant its values. +// DelParams method deletes single or multiple param fields ant its values. func (r *Request) DelParams(key ...string) *Request { for _, v := range key { r.params.Del(v) @@ -161,21 +165,21 @@ func (r *Request) SetCookie(key, val string) *Request { return r } -// SetCookies method sets multiple cookie field and its values at one go in the request instance. +// SetCookies method sets multiple cookie fields and its values at one go in the request instance. // It will override cookie which set in client instance. func (r *Request) SetCookies(m map[string]string) *Request { r.cookies.SetCookies(m) return r } -// SetCookiesWithStruct method sets multiple cookies field and its values at one go in the request instance. +// SetCookiesWithStruct method sets multiple cookie fields and its values at one go in the request instance. // It will override cookie which set in client instance. func (r *Request) SetCookiesWithStruct(v any) *Request { r.cookies.SetCookiesWithStruct(v) return r } -// DelCookies method deletes single or multiple cookies field ant its values. +// DelCookies method deletes single or multiple cookie fields ant its values. func (r *Request) DelCookies(key ...string) *Request { r.cookies.DelCookies(key...) return r @@ -188,21 +192,21 @@ func (r *Request) SetPathParam(key, val string) *Request { return r } -// SetPathParams method sets multiple path params field and its values at one go in the request instance. +// SetPathParams method sets multiple path param fields and its values at one go in the request instance. // It will override path param which set in client instance. func (r *Request) SetPathParams(m map[string]string) *Request { r.path.SetParams(m) return r } -// SetParamsWithStruct method sets multiple path params field and its values at one go in the request instance. +// SetParamsWithStruct method sets multiple path param fields and its values at one go in the request instance. // It will override path param which set in client instance. func (r *Request) SetPathParamsWithStruct(v any) *Request { r.path.SetParamsWithStruct(v) return r } -// DelPathParams method deletes single or multiple path params field ant its values. +// DelPathParams method deletes single or multiple path param fields ant its values. func (r *Request) DelPathParams(key ...string) *Request { r.path.DelParams(key...) return r @@ -229,6 +233,84 @@ func (r *Request) SetRawBody(v []byte) *Request { return r } +// resetBody will clear body object and set bodyType +func (r *Request) resetBody(t bodyType) { + r.body = nil + + // Set form data after set file ignore. + if r.bodyType == filesBody && t == formBody { + return + } + r.bodyType = t +} + +// AddFormData method adds a single form data field and its value in the request instance. +func (r *Request) AddFormData(key, val string) *Request { + r.formData.AddData(key, val) + r.resetBody(formBody) + return r +} + +// SetFormData method sets a single form data field and its value in the request instance. +func (r *Request) SetFormData(key, val string) *Request { + r.formData.SetData(key, val) + r.resetBody(formBody) + return r +} + +// AddFormDatas method adds multiple form data fields and its values in the request instance. +func (r *Request) AddFormDatas(m map[string][]string) *Request { + r.formData.AddDatas(m) + r.resetBody(formBody) + return r +} + +// SetFormDatas method sets multiple form data fields and its values in the request instance. +func (r *Request) SetFormDatas(m map[string]string) *Request { + r.formData.SetDatas(m) + r.resetBody(formBody) + return r +} + +// SetFormDatasWithStruct method sets multiple form data fields +// and its values in the request instance via struct. +func (r *Request) SetFormDatasWithStruct(v any) *Request { + r.formData.SetDatasWithStruct(v) + r.resetBody(formBody) + return r +} + +// DelFormDatas method deletes multiple form data fields and its value in the request instance. +func (r *Request) DelFormDatas(key ...string) *Request { + r.formData.DelDatas(key...) + r.resetBody(formBody) + return r +} + +// AddFile method adds single file field +// and its value in the request instance via file path. +func (r *Request) AddFile(path string) *Request { + r.files = append(r.files, AcquireFile(SetFilePath(path))) + r.resetBody(filesBody) + return r +} + +// AddFileWithReader method adds single field +// and its value in the request instance via reader. +func (r *Request) AddFileWithReader(name string, reader io.ReadCloser) *Request { + r.files = append(r.files, AcquireFile(SetFileName(name), SetFileReader(reader))) + r.resetBody(filesBody) + return r +} + +// AddFile method adds multiple file fields +// and its value in the request instance via File instance. +func (r *Request) AddFiles(files ...*File) *Request { + r.files = append(r.files, files...) + r.resetBody(filesBody) + return r +} + // Reset clear Request object, used by ReleaseRequest method. func (r *Request) Reset() { r.url = "" @@ -238,6 +320,13 @@ func (r *Request) Reset() { r.body = nil r.bodyType = noBody + copiedFile := r.files + r.files = r.files[0:0] + for _, v := range copiedFile { + ReleaseFile(v) + } + + r.formData.Reset() r.path.Reset() r.cookies.Reset() r.header.Reset() @@ -397,6 +486,92 @@ func (p PathParam) Reset() { } } +// FormData is a wrapper of fasthttp.Args, +// and it be used for url encode body and file body. +type FormData struct { + *fasthttp.Args +} + +// AddData method is a wrapper of Args's Add method. +func (f *FormData) AddData(key, val string) { + f.Add(key, val) +} + +// SetData method is a wrapper of Args's Set method. +func (f *FormData) SetData(key, val string) { + f.Set(key, val) +} + +// AddDatas method supports add multiple fields. +func (f *FormData) AddDatas(m map[string][]string) { + for k, v := range m { + for _, vv := range v { + f.Add(k, vv) + } + } +} + +// SetDatas method supports set multiple fields. +func (f *FormData) SetDatas(m map[string]string) { + for k, v := range m { + f.Set(k, v) + } +} + +// SetDatasWithStruct method supports set mutiple fields via a struct. +func (f *FormData) SetDatasWithStruct(v any) { + SetValWithStruct(f, "form", v) +} + +// DelDatas method deletes multiple fields. +func (f *FormData) DelDatas(key ...string) { + for _, v := range key { + f.Del(v) + } +} + +// Reset clear the FormData object. +func (f *FormData) Reset() { + f.Args.Reset() +} + +// File is a struct which support send files via request. +type File struct { + name string + paramName string + path string + reader io.ReadCloser +} + +// SetName method sets file name. +func (f *File) SetName(n string) { + f.name = n +} + +// SetParamName method sets key of file in the body. +func (f *File) SetParamName(n string) { + f.paramName = n +} + +// SetPath method set file path. +func (f *File) SetPath(p string) { + f.path = p +} + +// SetReader method can reveive a io.ReadCloser +// which will be closed in parserBody hook. +func (f *File) SetReader(r io.ReadCloser) { + f.reader = r +} + +// Reset clear the File object. +func (f *File) Reset() { + f.name = "" + f.paramName = "" + f.path = "" + f.reader = nil +} + var requestPool sync.Pool // AcquireRequest returns an empty request object from the pool. @@ -415,6 +590,9 @@ func AcquireRequest() (req *Request) { params: &QueryParam{Args: fasthttp.AcquireArgs()}, cookies: &Cookie{}, path: &PathParam{}, + boundary: "--FiberFormBoundary" + randString(16), + formData: &FormData{Args: fasthttp.AcquireArgs()}, + files: make([]*File, 0), rawRequest: fasthttp.AcquireRequest(), } return @@ -428,6 +606,65 @@ func ReleaseRequest(req *Request) { requestPool.Put(req) } +var filePool sync.Pool + +// The methods as follows is used by AcquireFile method. +// You can set file field via these method. +type SetFileFunc func(f *File) + +func SetFileName(n string) SetFileFunc { + return func(f *File) { + f.SetName(n) + } +} + +func SetFileParamName(p string) SetFileFunc { + return func(f *File) { + f.SetParamName(p) + } +} + +func SetFilePath(p string) SetFileFunc { + return func(f *File) { + f.SetPath(p) + } +} + +func SetFileReader(r io.ReadCloser) SetFileFunc { + return func(f *File) { + f.SetReader(r) + } +} + +// AcquireFile returns an File object from the pool. +// And you can set field in the File with SetFileFunc. +// +// The returned file may be returned to the pool with ReleaseFile when no longer needed. +// This allows reducing GC load. +func AcquireFile(setter ...SetFileFunc) (f *File) { + fv := filePool.Get() + if fv != nil { + f = fv.(*File) + for _, v := range setter { + v(f) + } + return + } + f = &File{} + for _, v := range setter { + v(f) + } + return +} + +// ReleaseFile returns the object acquired via AcquireFile to the pool. +// +// Do not access the released File object, otherwise data races may occur. +func ReleaseFile(f *File) { + f.Reset() + filePool.Put(f) +} + // Set some values using structs. // `p` is a structure that implements the WithStruct interface, // The field name can be specified by `tagName`. From c7afe2de2401c8b8561f04958ede3e0bd5e719c9 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Thu, 4 Aug 2022 20:27:47 +0800 Subject: [PATCH 014/118] =?UTF-8?q?=F0=9F=9A=A7=20v3:=20referer=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 14 ++++++++++++-- client/hooks.go | 8 +++++++- client/hooks_test.go | 18 ++++++++++++++++++ client/request.go | 13 +++++++++++-- 4 files changed, 48 insertions(+), 5 deletions(-) diff --git a/client/client.go b/client/client.go index a40523be7f..d107e7bff6 100644 --- a/client/client.go +++ b/client/client.go @@ -12,9 +12,10 @@ type Client struct { core *Core baseUrl string + userAgent string + referer string header *Header params *QueryParam - userAgent string cookies *Cookie path *PathParam } @@ -156,6 +157,14 @@ func (c *Client) SetUserAgent(ua string) *Client { return c } +// SetReferer method sets referer field and its value in the client instance. +// This referer will be applied to all requests raised from this client instance. +// Also it can be overridden at request level referer options. +func (c *Client) SetReferer(r string) *Client { + c.referer = r + return c +} + // SetPathParam method sets a single path param field and its value in the client instance. // These path params will be applied to all requests raised from this client instance. // Also it can be overridden at request level path params options. @@ -220,6 +229,7 @@ func (c *Client) DelCookies(key ...string) *Client { func (c *Client) Reset() { c.baseUrl = "" c.userAgent = "" + c.referer = "" c.path.Reset() c.cookies.Reset() @@ -266,7 +276,7 @@ func AcquireClient() (c *Client) { Args: fasthttp.AcquireArgs(), }, cookies: &Cookie{}, - path: &PathParam{}, + path: &PathParam{}, } return } diff --git a/client/hooks.go b/client/hooks.go index e3488e2a1f..352981d71d 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -74,7 +74,7 @@ func randString(n int) string { // parserURL will set the options for the hostclient // and normalize the url. // The baseUrl will be merge with request uri. -// TODO: Query params and path params should be deal in this function. +// Query params and path params deal in this function. func parserURL(c *Client, req *Request) error { splitUrl := strings.Split(req.url, "?") // I don't want to judege splitUrl length. @@ -171,6 +171,12 @@ func parserHeader(c *Client, req *Request) error { req.rawRequest.Header.SetUserAgent(req.userAgent) } + // set referer + req.rawRequest.Header.SetReferer(c.referer) + if req.referer != "" { + req.rawRequest.Header.SetReferer(req.referer) + } + // set cookie c.cookies.VisitAll(func(key, val string) { req.rawRequest.Header.SetCookie(key, val) diff --git a/client/hooks_test.go b/client/hooks_test.go index 6b30e512b0..2bcb5aa6f1 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -335,6 +335,24 @@ func TestParserHeader(t *testing.T) { utils.AssertEqual(t, []byte("bar"), req.rawRequest.Header.UserAgent()) }) + t.Run("referer in client should be set", func(t *testing.T) { + client := AcquireClient().SetReferer("https://example.com") + req := AcquireRequest() + + err := parserHeader(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte("https://example.com"), req.rawRequest.Header.Referer()) + }) + + t.Run("referer in request should have higher level", func(t *testing.T) { + client := AcquireClient().SetReferer("http://example.com") + req := AcquireRequest().SetReferer("https://example.com") + + err := parserHeader(client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte("https://example.com"), req.rawRequest.Header.Referer()) + }) + t.Run("client cookie should be set", func(t *testing.T) { client := AcquireClient(). SetCookie("foo", "bar"). diff --git a/client/request.go b/client/request.go index ed8e7f2ad5..855340c0ea 100644 --- a/client/request.go +++ b/client/request.go @@ -33,9 +33,10 @@ const ( type Request struct { url string method string - ctx context.Context userAgent string boundary string + referer string + ctx context.Context header *Header params *QueryParam cookies *Cookie @@ -158,6 +159,13 @@ func (r *Request) SetUserAgent(ua string) *Request { return r } +// SetReferer method sets referer in request. +// It will override referer which set in client instance. +func (r *Request) SetReferer(referer string) *Request { + r.referer = referer + return r +} + // SetCookie method sets a single cookie field and its value in the request instance. // It will override cookie which set in client instance. func (r *Request) SetCookie(key, val string) *Request { @@ -315,8 +323,9 @@ func (r *Request) AddFiles(files ...*File) *Request { func (r *Request) Reset() { r.url = "" r.method = fiber.MethodGet - r.ctx = nil r.userAgent = "" + r.referer = "" + r.ctx = nil r.body = nil r.bodyType = noBody From 96f562b281fe9da336c3e04a82c91c9f4d0e00c1 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Thu, 4 Aug 2022 21:18:30 +0800 Subject: [PATCH 015/118] =?UTF-8?q?=F0=9F=9A=A7=20v3:=20reponse=20unmarsha?= =?UTF-8?q?l?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/core.go | 8 +++-- client/hooks.go | 25 ++++++++++---- client/hooks_test.go | 70 +++++++++++++++++++------------------- client/respose.go | 81 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 140 insertions(+), 44 deletions(-) diff --git a/client/core.go b/client/core.go index f4e85bb2e9..809d80125d 100644 --- a/client/core.go +++ b/client/core.go @@ -70,6 +70,8 @@ type Core struct { func (c *Core) execute(ctx context.Context, agent *Client, req *Request) (*Response, error) { var execFunc ExecuteFunc = func(ctx context.Context, a *Client, r *Request) (*Response, error) { resp := AcquireResponse() + resp.setClient(a) + resp.setRequest(r) // To avoid memory allocation reuse of data structures such as errch. errCh, reqv, respv := acquireErrChan(), fasthttp.AcquireRequest(), fasthttp.AcquireResponse() @@ -99,7 +101,7 @@ func (c *Core) execute(ctx context.Context, agent *Client, req *Request) (*Respo } return resp, nil case <-ctx.Done(): - return nil, fmt.Errorf("timeout error") + return nil, fmt.Errorf("timeout or cancel error") } } @@ -207,9 +209,9 @@ func AcquireCore() (c *Core) { c = &Core{ client: &fasthttp.HostClient{}, userRequestHooks: []RequestHook{}, - buildinRequestHooks: []RequestHook{parserURL, parserHeader, parserBody}, + buildinRequestHooks: []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody}, userResponseHooks: []ResponseHook{}, - buildinResposeHooks: []ResponseHook{}, + buildinResposeHooks: []ResponseHook{parserResponseCookie}, plugins: []Plugin{}, pluginMap: map[string]Plugin{}, jsonMarshal: json.Marshal, diff --git a/client/hooks.go b/client/hooks.go index 352981d71d..9a52c90415 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -71,11 +71,11 @@ func randString(n int) string { return utils.UnsafeString(b) } -// parserURL will set the options for the hostclient +// parserRequestURL will set the options for the hostclient // and normalize the url. // The baseUrl will be merge with request uri. // Query params and path params deal in this function. -func parserURL(c *Client, req *Request) error { +func parserRequestURL(c *Client, req *Request) error { splitUrl := strings.Split(req.url, "?") // I don't want to judege splitUrl length. splitUrl = append(splitUrl, "") @@ -132,11 +132,11 @@ func parserURL(c *Client, req *Request) error { return nil } -// parserHeader will make request header up. +// parserRequestHeader will make request header up. // It will merge headers from client and request. // Header should be set automatically based on data. // User-Agent should be set. -func parserHeader(c *Client, req *Request) error { +func parserRequestHeader(c *Client, req *Request) error { // merge header c.header.VisitAll(func(key, value []byte) { req.rawRequest.Header.SetBytesKV(key, value) @@ -189,9 +189,9 @@ func parserHeader(c *Client, req *Request) error { return nil } -// parserBody automatically serializes the data according to +// parserRequestBody automatically serializes the data according to // the data type and stores it in the body of the rawRequest -func parserBody(c *Client, req *Request) (err error) { +func parserRequestBody(c *Client, req *Request) (err error) { switch req.bodyType { case jsonBody: body, err := c.core.jsonMarshal(req.body) @@ -285,3 +285,16 @@ func parserBody(c *Client, req *Request) (err error) { } return nil } + +func parserResponseCookie(c *Client, resp *Response, req *Request) (err error) { + resp.rawResponse.Header.VisitAllCookie(func(key, value []byte) { + cookie := fasthttp.AcquireCookie() + err = cookie.ParseBytes(value) + if err != nil { + return + } + cookie.SetKeyBytes(key) + }) + + return +} diff --git a/client/hooks_test.go b/client/hooks_test.go index 2bcb5aa6f1..96bfe2e66a 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -70,14 +70,14 @@ func TestRandString(t *testing.T) { } } -func TestParserURL(t *testing.T) { +func TestParserRequestURL(t *testing.T) { t.Parallel() t.Run("client baseurl should be set", func(t *testing.T) { client := AcquireClient().SetBaseURL("http://example.com/api") req := AcquireRequest().SetURL("") - err := parserURL(client, req) + err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "http://example.com/api", req.rawRequest.URI().String()) }) @@ -86,7 +86,7 @@ func TestParserURL(t *testing.T) { client := AcquireClient() req := AcquireRequest().SetURL("http://example.com/api") - err := parserURL(client, req) + err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "http://example.com/api", req.rawRequest.URI().String()) }) @@ -95,7 +95,7 @@ func TestParserURL(t *testing.T) { client := AcquireClient().SetBaseURL("http://example.com/api") req := AcquireRequest().SetURL("http://example.com/api/v1") - err := parserURL(client, req) + err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "http://example.com/api/v1", req.rawRequest.URI().String()) }) @@ -104,7 +104,7 @@ func TestParserURL(t *testing.T) { client := AcquireClient().SetBaseURL("http://example.com/api") req := AcquireRequest().SetURL("/v1") - err := parserURL(client, req) + err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "http://example.com/api/v1", req.rawRequest.URI().String()) }) @@ -113,7 +113,7 @@ func TestParserURL(t *testing.T) { client := AcquireClient().SetBaseURL("example.com/api") req := AcquireRequest().SetURL("/v1") - err := parserURL(client, req) + err := parserRequestURL(client, req) utils.AssertEqual(t, fmt.Errorf("url format error"), err) }) @@ -123,7 +123,7 @@ func TestParserURL(t *testing.T) { SetPathParam("id", "5") req := AcquireRequest() - err := parserURL(client, req) + err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "http://example.com/api/5", req.rawRequest.URI().String()) }) @@ -140,7 +140,7 @@ func TestParserURL(t *testing.T) { }). DelPathParams("key") - err := parserURL(client, req) + err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "http://example.com/api/5/fiber/%7Bkey%7D", req.rawRequest.URI().String()) }) @@ -157,7 +157,7 @@ func TestParserURL(t *testing.T) { "id": "12", }) - err := parserURL(client, req) + err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "http://example.com/api/12/fiber/val", req.rawRequest.URI().String()) }) @@ -167,7 +167,7 @@ func TestParserURL(t *testing.T) { SetParam("foo", "bar") req := AcquireRequest().SetURL("http://example.com/api/v1") - err := parserURL(client, req) + err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, []byte("foo=bar"), req.rawRequest.URI().QueryString()) }) @@ -178,7 +178,7 @@ func TestParserURL(t *testing.T) { SetURL("http://example.com/api/v1"). SetParam("bar", "foo") - err := parserURL(client, req) + err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, []byte("bar=foo"), req.rawRequest.URI().QueryString()) }) @@ -190,7 +190,7 @@ func TestParserURL(t *testing.T) { SetURL("http://example.com/api/v1?bar=foo2"). SetParam("bar", "foo") - err := parserURL(client, req) + err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) values, _ := url.ParseQuery(string(req.rawRequest.URI().QueryString())) @@ -211,7 +211,7 @@ func TestParserURL(t *testing.T) { }) } -func TestParserHeader(t *testing.T) { +func TestParserRequestHeader(t *testing.T) { t.Parallel() t.Run("client header should be set", func(t *testing.T) { @@ -222,7 +222,7 @@ func TestParserHeader(t *testing.T) { req := AcquireRequest() - err := parserHeader(client, req) + err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, []byte("application/json"), req.rawRequest.Header.ContentType()) }) @@ -235,7 +235,7 @@ func TestParserHeader(t *testing.T) { fiber.HeaderContentType: "application/json, utf-8", }) - err := parserHeader(client, req) + err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, []byte("application/json, utf-8"), req.rawRequest.Header.ContentType()) }) @@ -247,7 +247,7 @@ func TestParserHeader(t *testing.T) { req := AcquireRequest(). SetHeader(fiber.HeaderContentType, "application/json, utf-8") - err := parserHeader(client, req) + err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, []byte("application/json, utf-8"), req.rawRequest.Header.ContentType()) }) @@ -262,7 +262,7 @@ func TestParserHeader(t *testing.T) { Name: "foo", }) - err := parserHeader(client, req) + err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, []byte(applicationJSON), req.rawRequest.Header.ContentType()) }) @@ -278,7 +278,7 @@ func TestParserHeader(t *testing.T) { Name: "foo", }) - err := parserHeader(client, req) + err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, []byte(applicationXML), req.rawRequest.Header.ContentType()) }) @@ -291,7 +291,7 @@ func TestParserHeader(t *testing.T) { "ball": "cricle and square", }) - err := parserHeader(client, req) + err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, applicationForm, string(req.rawRequest.Header.ContentType())) }) @@ -302,7 +302,7 @@ func TestParserHeader(t *testing.T) { AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))). SetFormData("foo", "bar") - err := parserHeader(client, req) + err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Header.MultipartFormBoundary()), "--FiberFormBoundary")) utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Header.ContentType()), multipartFormData)) @@ -312,7 +312,7 @@ func TestParserHeader(t *testing.T) { client := AcquireClient() req := AcquireRequest() - err := parserHeader(client, req) + err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, []byte("fiber"), req.rawRequest.Header.UserAgent()) }) @@ -321,7 +321,7 @@ func TestParserHeader(t *testing.T) { client := AcquireClient().SetUserAgent("foo") req := AcquireRequest() - err := parserHeader(client, req) + err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, []byte("foo"), req.rawRequest.Header.UserAgent()) }) @@ -330,7 +330,7 @@ func TestParserHeader(t *testing.T) { client := AcquireClient().SetUserAgent("foo") req := AcquireRequest().SetUserAgent("bar") - err := parserHeader(client, req) + err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, []byte("bar"), req.rawRequest.Header.UserAgent()) }) @@ -339,7 +339,7 @@ func TestParserHeader(t *testing.T) { client := AcquireClient().SetReferer("https://example.com") req := AcquireRequest() - err := parserHeader(client, req) + err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, []byte("https://example.com"), req.rawRequest.Header.Referer()) }) @@ -348,7 +348,7 @@ func TestParserHeader(t *testing.T) { client := AcquireClient().SetReferer("http://example.com") req := AcquireRequest().SetReferer("https://example.com") - err := parserHeader(client, req) + err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, []byte("https://example.com"), req.rawRequest.Header.Referer()) }) @@ -364,7 +364,7 @@ func TestParserHeader(t *testing.T) { req := AcquireRequest() - err := parserHeader(client, req) + err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "bar", string(req.rawRequest.Header.Cookie("foo"))) utils.AssertEqual(t, "foo", string(req.rawRequest.Header.Cookie("bar"))) @@ -385,7 +385,7 @@ func TestParserHeader(t *testing.T) { Bar: 67, }) - err := parserHeader(client, req) + err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "bar", string(req.rawRequest.Header.Cookie("foo"))) utils.AssertEqual(t, "67", string(req.rawRequest.Header.Cookie("bar"))) @@ -411,7 +411,7 @@ func TestParserHeader(t *testing.T) { Bar: 67, }) - err := parserHeader(client, req) + err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "bar", string(req.rawRequest.Header.Cookie("foo"))) utils.AssertEqual(t, "67", string(req.rawRequest.Header.Cookie("bar"))) @@ -419,7 +419,7 @@ func TestParserHeader(t *testing.T) { }) } -func TestParserBody(t *testing.T) { +func TestParserRequestBody(t *testing.T) { t.Parallel() t.Run("json body", func(t *testing.T) { @@ -432,7 +432,7 @@ func TestParserBody(t *testing.T) { Name: "foo", }) - err := parserBody(client, req) + err := parserRequestBody(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, []byte("{\"name\":\"foo\"}"), req.rawRequest.Body()) }) @@ -448,7 +448,7 @@ func TestParserBody(t *testing.T) { Name: "foo", }) - err := parserBody(client, req) + err := parserRequestBody(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, []byte("foo"), req.rawRequest.Body()) }) @@ -460,7 +460,7 @@ func TestParserBody(t *testing.T) { "ball": "cricle and square", }) - err := parserBody(client, req) + err := parserRequestBody(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "ball=cricle+and+square", string(req.rawRequest.Body())) }) @@ -470,7 +470,7 @@ func TestParserBody(t *testing.T) { req := AcquireRequest(). AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))) - err := parserBody(client, req) + err := parserRequestBody(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Body()), "----FiberFormBoundary")) utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Body()), "world")) @@ -482,7 +482,7 @@ func TestParserBody(t *testing.T) { AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))). SetFormData("foo", "bar") - err := parserBody(client, req) + err := parserRequestBody(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Body()), "----FiberFormBoundary")) utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Body()), "world")) @@ -494,7 +494,7 @@ func TestParserBody(t *testing.T) { req := AcquireRequest(). SetRawBody([]byte("hello world")) - err := parserBody(client, req) + err := parserRequestBody(client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, []byte("hello world"), req.rawRequest.Body()) }) diff --git a/client/respose.go b/client/respose.go index 97b7bb29bf..97b93679cb 100644 --- a/client/respose.go +++ b/client/respose.go @@ -1,20 +1,100 @@ package client import ( + "strings" "sync" "github.com/valyala/fasthttp" ) type Response struct { + client *Client + request *Request + cookie []*fasthttp.Cookie rawResponse *fasthttp.Response } +// setClient method sets client object in response instance. +// Use core object in the client. +func (r *Response) setClient(c *Client) { + r.client = c +} + +// setRequest method sets Request object in response instance. +// The request will be released when the Response.Close is called. +func (r *Response) setRequest(req *Request) { + r.request = req +} + +// Status method returns the HTTP status string for the executed request. +func (r *Response) Status() string { + return string(r.rawResponse.Header.StatusMessage()) +} + +// StatusCode method returns the HTTP status code for the executed request. +func (r *Response) StatusCode() int { + return r.rawResponse.StatusCode() +} + +// Protocol method returns the HTTP response protocol used for the request. +func (r *Response) Protocol() string { + return string(r.rawResponse.Header.Protocol()) +} + +// Header method returns the response headers. +func (r *Response) Header() fasthttp.ResponseHeader { + return r.rawResponse.Header +} + +// Cookies method to access all the response cookies. +func (r *Response) Cookies() []*fasthttp.Cookie { + return r.cookie +} + +// Body method returns HTTP response as []byte array for the executed request. +func (r *Response) Body() []byte { + return r.rawResponse.Body() +} + +// String method returns the body of the server response as String. +func (r *Response) String() string { + return strings.TrimSpace(string(r.Body())) +} + +// JSON method will unmarshal body to json. +func (r *Response) JSON(v any) error { + return r.client.core.jsonUnmarshal(r.Body(), v) +} + +// XML method will unmarshal body to xml. +func (r *Response) XML(v any) error { + return r.client.core.xmlUnmarshal(r.Body(), v) +} + // Reset clear Response object. func (r *Response) Reset() { + r.client = nil + r.request = nil + copied := r.cookie + r.cookie = []*fasthttp.Cookie{} + for _, v := range copied { + fasthttp.ReleaseCookie(v) + } + r.rawResponse.Reset() } +// Close method will release Request object and Response object, +// after call Close please don't use these object. +func (r *Response) Close() { + if r.request != nil { + tmp := r.request + r.request = nil + ReleaseRequest(tmp) + } + ReleaseResponse(r) +} + var responsePool sync.Pool // AcquireResponse returns an empty response object from the pool. @@ -28,6 +108,7 @@ func AcquireResponse() (resp *Response) { return } resp = &Response{ + cookie: []*fasthttp.Cookie{}, rawResponse: fasthttp.AcquireResponse(), } From 951b6fb5fb03d604ba8b95426e5c17d14852d177 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Fri, 5 Aug 2022 10:10:54 +0800 Subject: [PATCH 016/118] =?UTF-8?q?=E2=9C=A8=20v3:=20finish=20API=20design?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 171 +++++++++++++++++++++++++++++++++++++++++++--- client/request.go | 75 +++++++++++++++++++- 2 files changed, 234 insertions(+), 12 deletions(-) diff --git a/client/client.go b/client/client.go index d107e7bff6..473d86631b 100644 --- a/client/client.go +++ b/client/client.go @@ -3,7 +3,6 @@ package client import ( "sync" - "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/utils" "github.com/valyala/fasthttp" ) @@ -20,6 +19,10 @@ type Client struct { path *PathParam } +func (c *Client) R() *Request { + return AcquireRequest().SetClient(c) +} + // Add user-defined request hooks. func (c *Client) AddRequestHook(h ...RequestHook) *Client { c.core.userRequestHooks = append(c.core.userRequestHooks, h...) @@ -225,7 +228,84 @@ func (c *Client) DelCookies(key ...string) *Client { return c } -// Reset clear Client object. +// Get provide a API like axios which send get request. +func (c *Client) Get(url string, setter ...SetRequestOptionFunc) (*Response, error) { + req := AcquireRequest().SetClient(c) + + for _, v := range setter { + v(req) + } + + return req.Get(url) +} + +// Post provide a API like axios which send post request. +func (c *Client) Post(url string, setter ...SetRequestOptionFunc) (*Response, error) { + req := AcquireRequest().SetClient(c) + + for _, v := range setter { + v(req) + } + + return req.Post(url) +} + +// Head provide a API like axios which send head request. +func (c *Client) Head(url string, setter ...SetRequestOptionFunc) (*Response, error) { + req := AcquireRequest().SetClient(c) + + for _, v := range setter { + v(req) + } + + return req.Head(url) +} + +// Put provide a API like axios which send put request. +func (c *Client) Put(url string, setter ...SetRequestOptionFunc) (*Response, error) { + req := AcquireRequest().SetClient(c) + + for _, v := range setter { + v(req) + } + + return req.Put(url) +} + +// Delete provide a API like axios which send delete request. +func (c *Client) Delete(url string, setter ...SetRequestOptionFunc) (*Response, error) { + req := AcquireRequest().SetClient(c) + + for _, v := range setter { + v(req) + } + + return req.Delete(url) +} + +// Options provide a API like axios which send options request. +func (c *Client) Options(url string, setter ...SetRequestOptionFunc) (*Response, error) { + req := AcquireRequest().SetClient(c) + + for _, v := range setter { + v(req) + } + + return req.Options(url) +} + +// Patch provide a API like axios which send patch request. +func (c *Client) Patch(url string, setter ...SetRequestOptionFunc) (*Response, error) { + req := AcquireRequest().SetClient(c) + + for _, v := range setter { + v(req) + } + + return req.Patch(url) +} + +// Reset clear Client object func (c *Client) Reset() { c.baseUrl = "" c.userAgent = "" @@ -238,13 +318,54 @@ func (c *Client) Reset() { c.params.Reset() } -// Get provide a API like axios which send get request. -func (c *Client) Get(url string) (*Response, error) { - req := AcquireRequest(). - setMethod(fiber.MethodGet). - SetURL(url) +type SetRequestOptionFunc func(r *Request) + +func SetRequestHeaders(m map[string]string) SetRequestOptionFunc { + return func(r *Request) { + r.SetHeaders(m) + } +} + +func SetRequestQueryParams(m map[string]string) SetRequestOptionFunc { + return func(r *Request) { + r.SetParams(m) + } +} + +func SetRequestUserAgent(ua string) SetRequestOptionFunc { + return func(r *Request) { + r.SetUserAgent(ua) + } +} + +func SetRequestReferer(referer string) SetRequestOptionFunc { + return func(r *Request) { + r.SetReferer(referer) + } +} + +func SetRequestData(v any) SetRequestOptionFunc { + return func(r *Request) { + r.SetJSON(v) + } +} + +func SetRequestFormDatas(m map[string]string) SetRequestOptionFunc { + return func(r *Request) { + r.SetFormDatas(m) + } +} - return c.core.execute(req.Context(), c, req) +func SetRequestPathParams(m map[string]string) SetRequestOptionFunc { + return func(r *Request) { + r.SetPathParams(m) + } +} + +func SetRequestFiles(files ...*File) SetRequestOptionFunc { + return func(r *Request) { + r.AddFiles(files...) + } } var ( @@ -305,6 +426,36 @@ func Replace(c *Client) func() { } // Get send a get request use defaultClient, a convenient method. -func Get(url string) (*Response, error) { - return defaultClient.Get(url) +func Get(url string, setter ...SetRequestOptionFunc) (*Response, error) { + return defaultClient.Get(url, setter...) +} + +// Post send a post request use defaultClient, a convenient method. +func Post(url string, setter ...SetRequestOptionFunc) (*Response, error) { + return defaultClient.Post(url, setter...) +} + +// Head send a head request use defaultClient, a convenient method. +func Head(url string, setter ...SetRequestOptionFunc) (*Response, error) { + return defaultClient.Head(url, setter...) +} + +// Put send a put request use defaultClient, a convenient method. +func Put(url string, setter ...SetRequestOptionFunc) (*Response, error) { + return defaultClient.Put(url, setter...) +} + +// Delete send a delete request use defaultClient, a convenient method. +func Delete(url string, setter ...SetRequestOptionFunc) (*Response, error) { + return defaultClient.Delete(url, setter...) +} + +// Options send a options request use defaultClient, a convenient method. +func Options(url string, setter ...SetRequestOptionFunc) (*Response, error) { + return defaultClient.Options(url, setter...) +} + +// Patch send a patch request use defaultClient, a convenient method. +func Patch(url string, setter ...SetRequestOptionFunc) (*Response, error) { + return defaultClient.Patch(url, setter...) } diff --git a/client/request.go b/client/request.go index 855340c0ea..7386de6cc5 100644 --- a/client/request.go +++ b/client/request.go @@ -42,6 +42,8 @@ type Request struct { cookies *Cookie path *PathParam + client *Client + body any formData *FormData files []*File @@ -50,9 +52,9 @@ type Request struct { rawRequest *fasthttp.Request } -// setMethod will set method for Request object, +// SetMethod will set method for Request object, // user should use request method to set method. -func (r *Request) setMethod(method string) *Request { +func (r *Request) SetMethod(method string) *Request { r.method = method return r } @@ -63,6 +65,12 @@ func (r *Request) SetURL(url string) *Request { return r } +// SetClient method sets client in request instance. +func (r *Request) SetClient(c *Client) *Request { + r.client = c + return r +} + // Context returns the Context if its already set in request // otherwise it creates new one using `context.Background()`. func (r *Request) Context() context.Context { @@ -319,6 +327,69 @@ func (r *Request) AddFiles(files ...*File) *Request { return r } +// checkClient method checks whether the client has been set in request. +func (r *Request) checkClient() { + if r.client == nil { + r.SetClient(defaultClient) + } +} + +// Send get request. +func (r *Request) Get(url string) (*Response, error) { + r.SetURL(url).SetMethod(fiber.MethodGet).checkClient() + + return r.client.core.execute(r.Context(), r.client, r) +} + +// Send post request. +func (r *Request) Post(url string) (*Response, error) { + r.SetURL(url).SetMethod(fiber.MethodPost).checkClient() + + return r.client.core.execute(r.Context(), r.client, r) +} + +// Send head request. +func (r *Request) Head(url string) (*Response, error) { + r.SetURL(url).SetMethod(fiber.MethodHead).checkClient() + + return r.client.core.execute(r.Context(), r.client, r) +} + +// Send put request. +func (r *Request) Put(url string) (*Response, error) { + r.SetURL(url).SetMethod(fiber.MethodPut).checkClient() + + return r.client.core.execute(r.Context(), r.client, r) +} + +// Send Delete request. +func (r *Request) Delete(url string) (*Response, error) { + r.SetURL(url).SetMethod(fiber.MethodDelete).checkClient() + + return r.client.core.execute(r.Context(), r.client, r) +} + +// Send Options reuqest. +func (r *Request) Options(url string) (*Response, error) { + r.SetURL(url).SetMethod(fiber.MethodOptions).checkClient() + + return r.client.core.execute(r.Context(), r.client, r) +} + +// Send patch request. +func (r *Request) Patch(url string) (*Response, error) { + r.SetURL(url).SetMethod(fiber.MethodPatch).checkClient() + + return r.client.core.execute(r.Context(), r.client, r) +} + +// Send a request. +func (r *Request) Send() (*Response, error) { + r.checkClient() + + return r.client.core.execute(r.Context(), r.client, r) +} + // Reset clear Request object, used by ReleaseRequest method. func (r *Request) Reset() { r.url = "" From f4799489a0ac77663090720f85664c20b8539271 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Fri, 5 Aug 2022 20:35:45 +0800 Subject: [PATCH 017/118] =?UTF-8?q?=F0=9F=94=A5=20v3:=20remove=20plugin=20?= =?UTF-8?q?mechanism?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/core.go | 42 +----------------------------------------- client/plugins.go | 1 - client/request.go | 2 +- 3 files changed, 2 insertions(+), 43 deletions(-) delete mode 100644 client/plugins.go diff --git a/client/core.go b/client/core.go index 809d80125d..0687a7b3a1 100644 --- a/client/core.go +++ b/client/core.go @@ -23,22 +23,6 @@ type RequestHook func(*Client, *Request) error // Called after a respose has been received. type ResponseHook func(*Client, *Response, *Request) error -// ExecuteFunc will actually execute the request via fasthttp. -type ExecuteFunc func(context.Context, *Client, *Request) (*Response, error) - -// Plugin can change the execution flow of requests. -type Plugin interface { - // Return the plugin name and the name should be different. - Name() string - - // Determine if the plugin should be executed based on the conditions. - Check() bool - - // Modify specific request execution methods, - // such as adding timeouts, cancellations, retries and other operations. - GenerateExecute(ExecuteFunc) (ExecuteFunc, error) -} - // `Core` stores middleware and plugin definitions, // and defines the execution process type Core struct { @@ -56,10 +40,6 @@ type Core struct { // client package defined respose hooks buildinResposeHooks []ResponseHook - // store plugins - plugins []Plugin - pluginMap map[string]Plugin - jsonMarshal utils.JSONMarshal jsonUnmarshal utils.JSONUnmarshal xmlMarshal utils.XMLMarshal @@ -68,7 +48,7 @@ type Core struct { // execute will exec each hooks and plugins. func (c *Core) execute(ctx context.Context, agent *Client, req *Request) (*Response, error) { - var execFunc ExecuteFunc = func(ctx context.Context, a *Client, r *Request) (*Response, error) { + execFunc := func(ctx context.Context, a *Client, r *Request) (*Response, error) { resp := AcquireResponse() resp.setClient(a) resp.setRequest(r) @@ -121,19 +101,6 @@ func (c *Core) execute(ctx context.Context, agent *Client, req *Request) (*Respo } } - // Call the plugins to generate the real request function. - for _, p := range c.plugins { - if !p.Check() { - continue - } - - var err error - execFunc, err = p.GenerateExecute(execFunc) - if err != nil { - return nil, err - } - } - // Do http request resp, err := execFunc(ctx, agent, req) if err != nil { @@ -164,11 +131,6 @@ func (c *Core) execute(ctx context.Context, agent *Client, req *Request) (*Respo func (c *Core) reset() { c.userRequestHooks = c.userRequestHooks[:0] c.userResponseHooks = c.userResponseHooks[:0] - c.plugins = c.plugins[:0] - - for k := range c.pluginMap { - delete(c.pluginMap, k) - } } var errChanPool sync.Pool @@ -212,8 +174,6 @@ func AcquireCore() (c *Core) { buildinRequestHooks: []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody}, userResponseHooks: []ResponseHook{}, buildinResposeHooks: []ResponseHook{parserResponseCookie}, - plugins: []Plugin{}, - pluginMap: map[string]Plugin{}, jsonMarshal: json.Marshal, jsonUnmarshal: json.Unmarshal, xmlMarshal: xml.Marshal, diff --git a/client/plugins.go b/client/plugins.go deleted file mode 100644 index da13c8ef3c..0000000000 --- a/client/plugins.go +++ /dev/null @@ -1 +0,0 @@ -package client diff --git a/client/request.go b/client/request.go index 7386de6cc5..2d364d95c5 100644 --- a/client/request.go +++ b/client/request.go @@ -401,7 +401,7 @@ func (r *Request) Reset() { r.bodyType = noBody copiedFile := r.files - r.files = r.files[0:0] + r.files = r.files[:0] for _, v := range copiedFile { ReleaseFile(v) } From 02ddc9b7904525370102d144e35fd5548037b988 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Sat, 6 Aug 2022 16:42:24 +0800 Subject: [PATCH 018/118] =?UTF-8?q?=F0=9F=9A=A7=20v3:=20add=20timeout?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 12 ++++++ client/core.go | 97 ++++++++++++++++++++++++++++------------------- client/request.go | 10 +++++ 3 files changed, 79 insertions(+), 40 deletions(-) diff --git a/client/client.go b/client/client.go index 473d86631b..7792467fd0 100644 --- a/client/client.go +++ b/client/client.go @@ -2,6 +2,7 @@ package client import ( "sync" + "time" "github.com/gofiber/fiber/v3/utils" "github.com/valyala/fasthttp" @@ -17,6 +18,8 @@ type Client struct { params *QueryParam cookies *Cookie path *PathParam + + timeout time.Duration } func (c *Client) R() *Request { @@ -228,6 +231,14 @@ func (c *Client) DelCookies(key ...string) *Client { return c } +// SetTimeout method sets timeout val in client instance. +// This value will be applied to all requests raised from this client instance. +// Also it can be overridden at request level timeout options. +func (c *Client) SetTimeout(t time.Duration) *Client { + c.timeout = t + return c +} + // Get provide a API like axios which send get request. func (c *Client) Get(url string, setter ...SetRequestOptionFunc) (*Response, error) { req := AcquireRequest().SetClient(c) @@ -308,6 +319,7 @@ func (c *Client) Patch(url string, setter ...SetRequestOptionFunc) (*Response, e // Reset clear Client object func (c *Client) Reset() { c.baseUrl = "" + c.timeout = 0 c.userAgent = "" c.referer = "" diff --git a/client/core.go b/client/core.go index 0687a7b3a1..43b45ebee1 100644 --- a/client/core.go +++ b/client/core.go @@ -46,63 +46,80 @@ type Core struct { xmlUnmarshal utils.XMLUnmarshal } -// execute will exec each hooks and plugins. -func (c *Core) execute(ctx context.Context, agent *Client, req *Request) (*Response, error) { - execFunc := func(ctx context.Context, a *Client, r *Request) (*Response, error) { - resp := AcquireResponse() - resp.setClient(a) - resp.setRequest(r) - - // To avoid memory allocation reuse of data structures such as errch. - errCh, reqv, respv := acquireErrChan(), fasthttp.AcquireRequest(), fasthttp.AcquireResponse() - defer func() { - releaseErrChan(errCh) - fasthttp.ReleaseRequest(reqv) - fasthttp.ReleaseResponse(respv) - }() - - req.rawRequest.CopyTo(reqv) - go func() { - err := c.client.Do(reqv, respv) - if err != nil { - errCh <- err - return - } - respv.CopyTo(resp.rawResponse) - errCh <- nil - }() +func (c *Core) execFunc(ctx context.Context, client *Client, req *Request) (*Response, error) { + resp := AcquireResponse() + resp.setClient(client) + resp.setRequest(req) + + // To avoid memory allocation reuse of data structures such as errch. + errCh, reqv, respv := acquireErrChan(), fasthttp.AcquireRequest(), fasthttp.AcquireResponse() + defer func() { + releaseErrChan(errCh) + fasthttp.ReleaseRequest(reqv) + fasthttp.ReleaseResponse(respv) + }() + + req.rawRequest.CopyTo(reqv) + go func() { + err := c.client.Do(reqv, respv) + if err != nil { + errCh <- err + return + } + respv.CopyTo(resp.rawResponse) + errCh <- nil + }() - select { - case err := <-errCh: - if err != nil { - // When get error should release Response - ReleaseResponse(resp) - return nil, err - } - return resp, nil - case <-ctx.Done(): - return nil, fmt.Errorf("timeout or cancel error") + select { + case err := <-errCh: + if err != nil { + // When get error should release Response + ReleaseResponse(resp) + return nil, err } + return resp, nil + case <-ctx.Done(): + return nil, fmt.Errorf("timeout or cancel error") } +} +// execute will exec each hooks and plugins. +func (c *Core) execute(ctx context.Context, client *Client, req *Request) (*Response, error) { // The built-in hooks will be executed only // after the user-defined hooks are executed。 for _, f := range c.userRequestHooks { - err := f(agent, req) + err := f(client, req) if err != nil { return nil, err } } for _, f := range c.buildinRequestHooks { - err := f(agent, req) + err := f(client, req) if err != nil { return nil, err } } + // deal with timeout + if req.timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, req.timeout) + defer func() { + cancel() + }() + } else { + if client.timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, client.timeout) + defer func() { + cancel() + }() + } + } + // Do http request - resp, err := execFunc(ctx, agent, req) + resp, err := c.execFunc(ctx, client, req) if err != nil { return nil, err } @@ -110,14 +127,14 @@ func (c *Core) execute(ctx context.Context, agent *Client, req *Request) (*Respo // The built-in hooks will be executed only // before the user-defined hooks are executed. for _, f := range c.buildinResposeHooks { - err := f(agent, resp, req) + err := f(client, resp, req) if err != nil { return nil, err } } for _, f := range c.userResponseHooks { - err := f(agent, resp, req) + err := f(client, resp, req) if err != nil { return nil, err } diff --git a/client/request.go b/client/request.go index 2d364d95c5..7307da58cd 100644 --- a/client/request.go +++ b/client/request.go @@ -6,6 +6,7 @@ import ( "reflect" "strconv" "sync" + "time" "github.com/gofiber/fiber/v3" "github.com/valyala/fasthttp" @@ -42,6 +43,8 @@ type Request struct { cookies *Cookie path *PathParam + timeout time.Duration + client *Client body any @@ -327,6 +330,13 @@ func (r *Request) AddFiles(files ...*File) *Request { return r } +// SetTimeout method sets timeout field and its values at one go in the request instance. +// It will override timeout which set in client instance. +func (r *Request) SetTimeout(t time.Duration) *Request { + r.timeout = t + return r +} + // checkClient method checks whether the client has been set in request. func (r *Request) checkClient() { if r.client == nil { From a3d0296ea5d12207d501da60e2d4f393fe7668c8 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Sun, 7 Aug 2022 16:18:01 +0800 Subject: [PATCH 019/118] =?UTF-8?q?=F0=9F=9A=A7=20v3:=20change=20path=20pa?= =?UTF-8?q?rams=20pattern=20and=20add=20unit=20test=20for=20core?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/core.go | 54 +++++------- client/core_test.go | 194 +++++++++++++++++++++++++++++++++++++++++++ client/hooks.go | 13 ++- client/hooks_test.go | 11 ++- 4 files changed, 224 insertions(+), 48 deletions(-) create mode 100644 client/core_test.go diff --git a/client/core.go b/client/core.go index 43b45ebee1..e2e44ac2da 100644 --- a/client/core.go +++ b/client/core.go @@ -4,7 +4,7 @@ import ( "context" "encoding/json" "encoding/xml" - "fmt" + "errors" "sync" "github.com/gofiber/fiber/v3/utils" @@ -23,9 +23,9 @@ type RequestHook func(*Client, *Request) error // Called after a respose has been received. type ResponseHook func(*Client, *Response, *Request) error -// `Core` stores middleware and plugin definitions, +// `core` stores middleware and plugin definitions, // and defines the execution process -type Core struct { +type core struct { client *fasthttp.HostClient // user defined request hooks @@ -46,7 +46,7 @@ type Core struct { xmlUnmarshal utils.XMLUnmarshal } -func (c *Core) execFunc(ctx context.Context, client *Client, req *Request) (*Response, error) { +func (c *core) execFunc(ctx context.Context, client *Client, req *Request) (*Response, error) { resp := AcquireResponse() resp.setClient(client) resp.setRequest(req) @@ -79,12 +79,13 @@ func (c *Core) execFunc(ctx context.Context, client *Client, req *Request) (*Res } return resp, nil case <-ctx.Done(): - return nil, fmt.Errorf("timeout or cancel error") + ReleaseResponse(resp) + return nil, ErrTimeoutOrCancel } } // execute will exec each hooks and plugins. -func (c *Core) execute(ctx context.Context, client *Client, req *Request) (*Response, error) { +func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Response, error) { // The built-in hooks will be executed only // after the user-defined hooks are executed。 for _, f := range c.userRequestHooks { @@ -102,14 +103,14 @@ func (c *Core) execute(ctx context.Context, client *Client, req *Request) (*Resp } // deal with timeout - if req.timeout != 0 { + if req.timeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, req.timeout) defer func() { cancel() }() } else { - if client.timeout != 0 { + if client.timeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, client.timeout) defer func() { @@ -143,13 +144,6 @@ func (c *Core) execute(ctx context.Context, client *Client, req *Request) (*Resp return resp, nil } -// reset clears core object. -// It will not clear buildin hooks. -func (c *Core) reset() { - c.userRequestHooks = c.userRequestHooks[:0] - c.userResponseHooks = c.userResponseHooks[:0] -} - var errChanPool sync.Pool // acquireErrChan returns an empty error chan from the pool. @@ -173,19 +167,9 @@ func releaseErrChan(ch chan error) { errChanPool.Put(ch) } -var corePool sync.Pool - -// AcquireCore returns an empty core object from the pool. -// -// The returned core may be returned to the pool with ReleaseCore when no longer needed. -// This allows reducing GC load. -func AcquireCore() (c *Core) { - cv := corePool.Get() - if cv != nil { - c = cv.(*Core) - return - } - c = &Core{ +// newCore returns an empty core object. +func newCore() (c *core) { + c = &core{ client: &fasthttp.HostClient{}, userRequestHooks: []RequestHook{}, buildinRequestHooks: []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody}, @@ -200,10 +184,10 @@ func AcquireCore() (c *Core) { return } -// ReleaseCore returns the object acquired via AcquireCore to the pool. -// -// Do not access the released core object, otherwise data races may occur. -func ReleaseCore(c *Core) { - c.reset() - corePool.Put(c) -} +var ( + ErrTimeoutOrCancel = errors.New("timeout or cancel") + ErrURLForamt = errors.New("the url is a mistake") + ErrNotSupportSchema = errors.New("the protocol is not support, only http or https") + ErrFileNoName = errors.New("the file should have name") + ErrBodyType = errors.New("the body type should be []byte") +) diff --git a/client/core_test.go b/client/core_test.go new file mode 100644 index 0000000000..59eccad08c --- /dev/null +++ b/client/core_test.go @@ -0,0 +1,194 @@ +package client + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/utils" + "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttputil" +) + +func TestExecFunc(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + app := fiber.New(fiber.Config{DisableStartupMessage: true}) + + app.Get("/normal", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + + app.Get("/return-error", func(c fiber.Ctx) error { + return fmt.Errorf("the request is error") + }) + + app.Get("/hang-up", func(c fiber.Ctx) error { + time.Sleep(time.Second) + return c.SendString(c.Hostname() + " hang up") + }) + + go func() { + utils.AssertEqual(t, nil, app.Listener(ln)) + }() + + t.Run("normal request", func(t *testing.T) { + core, client, req := newCore(), AcquireClient(), AcquireRequest() + core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + req.rawRequest.SetRequestURI("http://example.com/normal") + + resp, err := core.execFunc(context.Background(), client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 200, resp.rawResponse.StatusCode()) + utils.AssertEqual(t, "example.com", string(resp.rawResponse.Body())) + }) + + t.Run("the request return an error", func(t *testing.T) { + core, client, req := newCore(), AcquireClient(), AcquireRequest() + core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + req.rawRequest.SetRequestURI("http://example.com/return-error") + + resp, err := core.execFunc(context.Background(), client, req) + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 500, resp.rawResponse.StatusCode()) + utils.AssertEqual(t, "the request is error", string(resp.rawResponse.Body())) + }) + + t.Run("there is no connect", func(t *testing.T) { + core, client := newCore(), AcquireClient() + core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + core.client.SetMaxConns(1) + + go func() { + req := AcquireRequest() + req.rawRequest.SetRequestURI("http://example.com/normal") + _, err := core.execFunc(context.Background(), client, req) + utils.AssertEqual(t, fasthttp.ErrNoFreeConns, err) + }() + + req := AcquireRequest() + req.rawRequest.SetRequestURI("http://example.com/hang-up") + core.execFunc(context.Background(), client, req) + }) + + t.Run("the request timeout", func(t *testing.T) { + core, client, req := newCore(), AcquireClient(), AcquireRequest() + core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + req.rawRequest.SetRequestURI("http://example.com/hang-up") + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + _, err := core.execFunc(ctx, client, req) + + utils.AssertEqual(t, ErrTimeoutOrCancel, err) + }) +} + +func TestExecute(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + app := fiber.New(fiber.Config{DisableStartupMessage: true}) + + app.Get("/normal", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + + app.Get("/return-error", func(c fiber.Ctx) error { + return fmt.Errorf("the request is error") + }) + + app.Get("/hang-up", func(c fiber.Ctx) error { + time.Sleep(time.Second) + return c.SendString(c.Hostname() + " hang up") + }) + + go func() { + utils.AssertEqual(t, nil, app.Listener(ln)) + }() + + t.Run("add user request hooks", func(t *testing.T) { + client, req := AcquireClient(), AcquireRequest() + client.AddRequestHook(func(c *Client, r *Request) error { + utils.AssertEqual(t, "http://example.com", req.URL()) + return nil + }).SetDial(func(addr string) (net.Conn, error) { + return ln.Dial() + }) + req.SetURL("http://example.com") + + resp, err := client.core.execute(context.Background(), client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "Cannot GET /", string(resp.rawResponse.Body())) + }) + + t.Run("add user response hooks", func(t *testing.T) { + client, req := AcquireClient(), AcquireRequest() + client.AddResponseHook(func(c *Client, resp *Response, req *Request) error { + utils.AssertEqual(t, "http://example.com", req.URL()) + return nil + }).SetDial(func(addr string) (net.Conn, error) { + return ln.Dial() + }) + req.SetURL("http://example.com") + + resp, err := client.core.execute(context.Background(), client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "Cannot GET /", string(resp.rawResponse.Body())) + }) + + t.Run("no timeout", func(t *testing.T) { + client, req := AcquireClient(), AcquireRequest() + client.SetDial(func(addr string) (net.Conn, error) { + return ln.Dial() + }) + req.SetURL("http://example.com/hang-up") + + resp, err := client.core.execute(context.Background(), client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "example.com hang up", string(resp.rawResponse.Body())) + }) + + t.Run("client timeout", func(t *testing.T) { + client, req := AcquireClient(), AcquireRequest() + client.SetDial(func(addr string) (net.Conn, error) { + return ln.Dial() + }).SetTimeout(500 * time.Millisecond) + req.SetURL("http://example.com/hang-up") + + _, err := client.core.execute(context.Background(), client, req) + utils.AssertEqual(t, ErrTimeoutOrCancel, err) + }) + + t.Run("request timeout", func(t *testing.T) { + client, req := AcquireClient(), AcquireRequest() + client.SetDial(func(addr string) (net.Conn, error) { + return ln.Dial() + }) + req.SetURL("http://example.com/hang-up"). + SetTimeout(300 * time.Millisecond) + + _, err := client.core.execute(context.Background(), client, req) + utils.AssertEqual(t, ErrTimeoutOrCancel, err) + }) + + t.Run("request timeout has higher level", func(t *testing.T) { + client, req := AcquireClient(), AcquireRequest() + client.SetDial(func(addr string) (net.Conn, error) { + return ln.Dial() + }). + SetTimeout(30 * time.Millisecond) + req.SetURL("http://example.com/hang-up"). + SetTimeout(3000 * time.Millisecond) + + resp, err := client.core.execute(context.Background(), client, req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "example.com hang up", string(resp.rawResponse.Body())) + }) +} diff --git a/client/hooks.go b/client/hooks.go index 9a52c90415..b85e85cf19 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -9,7 +9,6 @@ import ( "net" "os" "path/filepath" - "reflect" "regexp" "strconv" "strings" @@ -86,16 +85,16 @@ func parserRequestURL(c *Client, req *Request) error { if !protocolCheck.MatchString(uri) { uri = c.baseUrl + uri if !protocolCheck.MatchString(uri) { - return fmt.Errorf("url format error") + return ErrURLForamt } } // set path params req.path.VisitAll(func(key, val string) { - uri = strings.Replace(uri, "{"+key+"}", val, -1) + uri = strings.Replace(uri, ":"+key, val, -1) }) c.path.VisitAll(func(key, val string) { - uri = strings.Replace(uri, "{"+key+"}", val, -1) + uri = strings.Replace(uri, ":"+key, val, -1) }) // set uri to request and orther related setting @@ -105,7 +104,7 @@ func parserRequestURL(c *Client, req *Request) error { if bytes.Equal(httpsBytes, scheme) { isTLS = true } else if !bytes.Equal(httpBytes, scheme) { - return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme) + return ErrNotSupportSchema } c.core.client.Addr = addMissingPort(string(rawUri.Host()), isTLS) @@ -232,7 +231,7 @@ func parserRequestBody(c *Client, req *Request) (err error) { b := make([]byte, 512) for i, v := range req.files { if v.name == "" && v.path == "" { - return fmt.Errorf("the file should have a name") + return ErrFileNoName } // if name is not exist, set name @@ -280,7 +279,7 @@ func parserRequestBody(c *Client, req *Request) (err error) { if body, ok := req.body.([]byte); ok { req.rawRequest.SetBody(body) } else { - return fmt.Errorf("the raw body should be []byte, but we receive %s", reflect.TypeOf(req.body).Kind().String()) + return ErrBodyType } } return nil diff --git a/client/hooks_test.go b/client/hooks_test.go index 96bfe2e66a..bdad986fcd 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -2,7 +2,6 @@ package client import ( "encoding/xml" - "fmt" "io" "net/url" "strings" @@ -114,12 +113,12 @@ func TestParserRequestURL(t *testing.T) { req := AcquireRequest().SetURL("/v1") err := parserRequestURL(client, req) - utils.AssertEqual(t, fmt.Errorf("url format error"), err) + utils.AssertEqual(t, ErrURLForamt, err) }) t.Run("the path param from client", func(t *testing.T) { client := AcquireClient(). - SetBaseURL("http://example.com/api/{id}"). + SetBaseURL("http://example.com/api/:id"). SetPathParam("id", "5") req := AcquireRequest() @@ -130,7 +129,7 @@ func TestParserRequestURL(t *testing.T) { t.Run("the path param from request", func(t *testing.T) { client := AcquireClient(). - SetBaseURL("http://example.com/api/{id}/{name}"). + SetBaseURL("http://example.com/api/:id/:name"). SetPathParam("id", "5") req := AcquireRequest(). SetURL("/{key}"). @@ -147,10 +146,10 @@ func TestParserRequestURL(t *testing.T) { t.Run("the path param from request and client", func(t *testing.T) { client := AcquireClient(). - SetBaseURL("http://example.com/api/{id}/{name}"). + SetBaseURL("http://example.com/api/:id/:name"). SetPathParam("id", "5") req := AcquireRequest(). - SetURL("/{key}"). + SetURL("/:key"). SetPathParams(map[string]string{ "name": "fiber", "key": "val", From 5ebb21d831a07c58c37f5388689e2ebade0eb03d Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Mon, 8 Aug 2022 20:02:56 +0800 Subject: [PATCH 020/118] =?UTF-8?q?=E2=9C=8F=EF=B8=8F=20v3:=20error=20spel?= =?UTF-8?q?l?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/{respose.go => response.go} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename client/{respose.go => response.go} (100%) diff --git a/client/respose.go b/client/response.go similarity index 100% rename from client/respose.go rename to client/response.go From cdeb94ebe352ab8840af8276109f6fc79b9f298a Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Tue, 9 Aug 2022 21:04:59 +0800 Subject: [PATCH 021/118] =?UTF-8?q?=E2=9C=85=20v3:=20improve=20test=20cove?= =?UTF-8?q?rage?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/README.md | 9 + client/client.go | 81 +- client/client_test.go | 132 +++- client/core.go | 16 +- client/hooks.go | 2 + client/request.go | 137 +++- client/request_test.go | 1623 +++++++++++++++++++++++++++++++++++++++- client/response.go | 21 +- 8 files changed, 1931 insertions(+), 90 deletions(-) create mode 100644 client/README.md diff --git a/client/README.md b/client/README.md new file mode 100644 index 0000000000..b8719aa21b --- /dev/null +++ b/client/README.md @@ -0,0 +1,9 @@ +

Fiber Client

+

Easy-to-use HTTP client based on fasthttp (inspired by resty and axios)

+

Features section describes in detail about Resty capabilities

+ +## Features + +- GET, POST, PUT, DELETE, HEAD, PATCH, OPTIONS, etc. +- Simple and chainable methods for settings and request +- \ No newline at end of file diff --git a/client/client.go b/client/client.go index 7792467fd0..a6d1bcfeb1 100644 --- a/client/client.go +++ b/client/client.go @@ -8,8 +8,14 @@ import ( "github.com/valyala/fasthttp" ) +// The Client is used to create a Fiber Client with +// client-level settings that apply to all requests +// raise from the client. +// +// Fiber Client also provides an option to override +// or merge most of the client settings at the request. type Client struct { - core *Core + core *core baseUrl string userAgent string @@ -22,16 +28,27 @@ type Client struct { timeout time.Duration } +// R raise a request from the client. func (c *Client) R() *Request { return AcquireRequest().SetClient(c) } +// Request returns user-defined request hooks. +func (c *Client) RequestHook() []RequestHook { + return c.core.userRequestHooks +} + // Add user-defined request hooks. func (c *Client) AddRequestHook(h ...RequestHook) *Client { c.core.userRequestHooks = append(c.core.userRequestHooks, h...) return c } +// ResponseHook return user-define reponse hooks. +func (c *Client) ResponseHook() []ResponseHook { + return c.core.userResponseHooks +} + // Add user-defined response hooks. func (c *Client) AddResponseHook(h ...ResponseHook) *Client { c.core.userResponseHooks = append(c.core.userResponseHooks, h...) @@ -45,30 +62,55 @@ func (c *Client) SetDial(f fasthttp.DialFunc) *Client { return c } +// JSONMarshal returns json marshal function in Core. +func (c *Client) JSONMarshal() utils.JSONMarshal { + return c.core.jsonMarshal +} + // Set json encoder. func (c *Client) SetJSONMarshal(f utils.JSONMarshal) *Client { c.core.jsonMarshal = f return c } +// JSONUnmarshal returns json unmarshal function in Core. +func (c *Client) JSONUnmarshal() utils.JSONUnmarshal { + return c.core.jsonUnmarshal +} + // Set json decoder. func (c *Client) SetJSONUnmarshal(f utils.JSONUnmarshal) *Client { c.core.jsonUnmarshal = f return c } +// XMLMarshal returns xml marshal function in Core. +func (c *Client) XMLMarshal() utils.XMLMarshal { + return c.core.xmlMarshal +} + // Set xml encoder. func (c *Client) SetXMLMarshal(f utils.XMLMarshal) *Client { c.core.xmlMarshal = f return c } +// XMLUnmarshal returns xml unmarshal function in Core. +func (c *Client) XMLUnmarshal() utils.XMLUnmarshal { + return c.core.xmlUnmarshal +} + // Set xml decoder. func (c *Client) SetXMLUnmarshal(f utils.XMLUnmarshal) *Client { c.core.xmlUnmarshal = f return c } +// BaseURL returns baseurl in Client instance. +func (c *Client) BaseURL() string { + return c.baseUrl +} + // Set baseUrl which is prefix of real url. func (c *Client) SetBaseURL(url string) *Client { c.baseUrl = url @@ -325,7 +367,6 @@ func (c *Client) Reset() { c.path.Reset() c.cookies.Reset() - c.core.reset() c.header.Reset() c.params.Reset() } @@ -383,7 +424,21 @@ func SetRequestFiles(files ...*File) SetRequestOptionFunc { var ( defaultClient *Client defaultUserAgent = "fiber" - clientPool sync.Pool + clientPool = &sync.Pool{ + New: func() any { + return &Client{ + core: newCore(), + header: &Header{ + RequestHeader: &fasthttp.RequestHeader{}, + }, + params: &QueryParam{ + Args: fasthttp.AcquireArgs(), + }, + cookies: &Cookie{}, + path: &PathParam{}, + } + }, + } ) func init() { @@ -394,24 +449,8 @@ func init() { // // The returned Client object may be returned to the pool with ReleaseClient when no longer needed. // This allows reducing GC load. -func AcquireClient() (c *Client) { - cv := clientPool.Get() - if cv != nil { - c = cv.(*Client) - return - } - c = &Client{ - core: AcquireCore(), - header: &Header{ - RequestHeader: &fasthttp.RequestHeader{}, - }, - params: &QueryParam{ - Args: fasthttp.AcquireArgs(), - }, - cookies: &Cookie{}, - path: &PathParam{}, - } - return +func AcquireClient() *Client { + return clientPool.Get().(*Client) } // ReleaseClient returns the object acquired via AcquireClient to the pool. diff --git a/client/client_test.go b/client/client_test.go index 040cce0c11..f471047bf9 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,7 +1,9 @@ package client import ( + "fmt" "net" + "reflect" "testing" "github.com/gofiber/fiber/v3" @@ -9,30 +11,6 @@ import ( "github.com/valyala/fasthttp/fasthttputil" ) -// import ( -// "bytes" -// "crypto/tls" -// "encoding/base64" -// "errors" -// "fmt" -// "io" -// "mime/multipart" -// "net" -// "os" -// "path/filepath" -// "regexp" -// "strings" -// "testing" -// "time" - -// "encoding/json" - -// "github.com/gofiber/fiber/v3" -// "github.com/gofiber/fiber/v3/internal/tlstest" -// "github.com/gofiber/fiber/v3/utils" -// "github.com/valyala/fasthttp/fasthttputil" -// ) - // func Test_Client_Invalid_URL(t *testing.T) { // t.Parallel() @@ -1211,3 +1189,109 @@ func TestGet(t *testing.T) { // type errorWriter struct{} // func (errorWriter) Write(_ []byte) (int, error) { return 0, errors.New("Write error") } + +func TestClientR(t *testing.T) { + t.Parallel() + + client := AcquireClient() + req := client.R() + + utils.AssertEqual(t, "Request", reflect.TypeOf(req).Elem().Name()) + utils.AssertEqual(t, client, req.Client()) +} + +func TestClientAddHook(t *testing.T) { + t.Parallel() + + t.Run("add request hooks", func(t *testing.T) { + client := AcquireClient().AddRequestHook(func(c *Client, r *Request) error { + return nil + }) + + utils.AssertEqual(t, 1, len(client.RequestHook())) + + client.AddRequestHook(func(c *Client, r *Request) error { + return nil + }, func(c *Client, r *Request) error { + return nil + }) + + utils.AssertEqual(t, 3, len(client.RequestHook())) + }) + + t.Run("add response hooks", func(t *testing.T) { + client := AcquireClient().AddResponseHook(func(c *Client, resp *Response, r *Request) error { + return nil + }) + + utils.AssertEqual(t, 1, len(client.ResponseHook())) + + client.AddResponseHook(func(c *Client, resp *Response, r *Request) error { + return nil + }, func(c *Client, resp *Response, r *Request) error { + return nil + }) + + utils.AssertEqual(t, 3, len(client.ResponseHook())) + }) +} + +func TestClientMarshal(t *testing.T) { + t.Run("set json marshal", func(t *testing.T) { + client := AcquireClient(). + SetJSONMarshal(func(v any) ([]byte, error) { + return []byte("hello"), nil + }) + val, err := client.JSONMarshal()(nil) + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte("hello"), val) + }) + + t.Run("set json unmarshal", func(t *testing.T) { + client := AcquireClient(). + SetJSONUnmarshal(func(data []byte, v any) error { + return fmt.Errorf("empty json") + }) + + err := client.JSONUnmarshal()(nil, nil) + utils.AssertEqual(t, fmt.Errorf("empty json"), err) + }) + + t.Run("set xml marshal", func(t *testing.T) { + client := AcquireClient(). + SetXMLMarshal(func(v any) ([]byte, error) { + return []byte("hello"), nil + }) + val, err := client.XMLMarshal()(nil) + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte("hello"), val) + }) + + t.Run("set xml unmarshal", func(t *testing.T) { + client := AcquireClient(). + SetXMLUnmarshal(func(data []byte, v any) error { + return fmt.Errorf("empty xml") + }) + + err := client.XMLUnmarshal()(nil, nil) + utils.AssertEqual(t, fmt.Errorf("empty xml"), err) + }) +} + +func TestClientSetBaseURL(t *testing.T) { + t.Parallel() + + client := AcquireClient().SetBaseURL("http://example.com") + + utils.AssertEqual(t, "http://example.com", client.BaseURL()) +} + +func TestClientHeader(t *testing.T) { + t.Parallel() + + t.Run("", func(t *testing.T) { + + }) +} diff --git a/client/core.go b/client/core.go index e2e44ac2da..5871e0a901 100644 --- a/client/core.go +++ b/client/core.go @@ -144,20 +144,18 @@ func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Resp return resp, nil } -var errChanPool sync.Pool +var errChanPool = &sync.Pool{ + New: func() any { + return make(chan error, 1) + }, +} // acquireErrChan returns an empty error chan from the pool. // // The returned error chan may be returned to the pool with releaseErrChan when no longer needed. // This allows reducing GC load. -func acquireErrChan() (ch chan error) { - chv := errChanPool.Get() - if chv != nil { - ch = chv.(chan error) - return - } - ch = make(chan error, 1) - return +func acquireErrChan() chan error { + return errChanPool.Get().(chan error) } // releaseErrChan returns the object acquired via acquireErrChan to the pool. diff --git a/client/hooks.go b/client/hooks.go index b85e85cf19..768f3d1c1f 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -136,6 +136,8 @@ func parserRequestURL(c *Client, req *Request) error { // Header should be set automatically based on data. // User-Agent should be set. func parserRequestHeader(c *Client, req *Request) error { + // set method + req.rawRequest.Header.SetMethod(req.Method()) // merge header c.header.VisitAll(func(key, value []byte) { req.rawRequest.Header.SetBytesKV(key, value) diff --git a/client/request.go b/client/request.go index 7307da58cd..5926606615 100644 --- a/client/request.go +++ b/client/request.go @@ -1,18 +1,22 @@ package client import ( + "bytes" "context" "io" "reflect" + "sort" "strconv" "sync" "time" "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/utils" "github.com/valyala/fasthttp" ) -// Implementing this interface allows data to be passed through the structure. +// Implementing this interface allows data to +// be stored from a struct via reflect. type WithStruct interface { Add(string, string) Del(string) @@ -55,6 +59,11 @@ type Request struct { rawRequest *fasthttp.Request } +// Method returns http method in request. +func (r *Request) Method() string { + return r.method +} + // SetMethod will set method for Request object, // user should use request method to set method. func (r *Request) SetMethod(method string) *Request { @@ -62,12 +71,22 @@ func (r *Request) SetMethod(method string) *Request { return r } +// URL returns request url in Request instance. +func (r *Request) URL() string { + return r.url +} + // SetURL will set url for Request object. func (r *Request) SetURL(url string) *Request { r.url = url return r } +// Client get Client instance in Request. +func (r *Request) Client() *Client { + return r.client +} + // SetClient method sets client in request instance. func (r *Request) SetClient(c *Client) *Request { r.client = c @@ -92,6 +111,13 @@ func (r *Request) SetContext(ctx context.Context) *Request { return r } +// Header method returns header value via key, +// this method will visit all field in the header, +// then sort them. +func (r *Request) Header(key string) []string { + return r.header.PeekMultiple(key) +} + // AddHeader method adds a single header field and its value in the request instance. // It will override header which set in client instance. func (r *Request) AddHeader(key, val string) *Request { @@ -102,6 +128,7 @@ func (r *Request) AddHeader(key, val string) *Request { // SetHeader method sets a single header field and its value in the request instance. // It will override header which set in client instance. func (r *Request) SetHeader(key, val string) *Request { + r.header.Del(key) r.header.Set(key, val) return r } @@ -120,6 +147,20 @@ func (r *Request) SetHeaders(h map[string]string) *Request { return r } +// Param method returns params value via key, +// this method will visit all field in the query param, +// then sort them. +func (r *Request) Param(key string) []string { + res := []string{} + tmp := r.params.PeekMulti(key) + for _, v := range tmp { + res = append(res, utils.UnsafeString(v)) + } + sort.Strings(res) + + return res +} + // AddParam method adds a single param field and its value in the request instance. // It will override param which set in client instance. func (r *Request) AddParam(key, val string) *Request { @@ -163,6 +204,11 @@ func (r *Request) DelParams(key ...string) *Request { return r } +// UserAgent returns user agent in request instance. +func (r *Request) UserAgent() string { + return r.userAgent +} + // SetUserAgent method sets user agent in request. // It will override user agent which set in client instance. func (r *Request) SetUserAgent(ua string) *Request { @@ -170,6 +216,11 @@ func (r *Request) SetUserAgent(ua string) *Request { return r } +// Referer returns referer in request instance. +func (r *Request) Referer() string { + return r.referer +} + // SetReferer method sets referer in request. // It will override referer which set in client instance. func (r *Request) SetReferer(referer string) *Request { @@ -177,6 +228,15 @@ func (r *Request) SetReferer(referer string) *Request { return r } +// Cookie returns the cookie be set in request instance. +// if cookie doesn't exist, return empty string. +func (r *Request) Cookie(key string) string { + if val, ok := (*r.cookies)[key]; ok { + return val + } + return "" +} + // SetCookie method sets a single cookie field and its value in the request instance. // It will override cookie which set in client instance. func (r *Request) SetCookie(key, val string) *Request { @@ -204,6 +264,16 @@ func (r *Request) DelCookies(key ...string) *Request { return r } +// PathParam returns the path param be set in request instance. +// if path param doesn't exist, return empty string. +func (r *Request) PathParam(key string) string { + if val, ok := (*r.path)[key]; ok { + return val + } + + return "" +} + // SetPathParam method sets a single path param field and its value in the request instance. // It will override path param which set in client instance. func (r *Request) SetPathParam(key, val string) *Request { @@ -263,6 +333,20 @@ func (r *Request) resetBody(t bodyType) { r.bodyType = t } +// FormData method returns form data value via key, +// this method will visit all field in the form data, +// then sort them. +func (r *Request) FormData(key string) []string { + res := []string{} + tmp := r.formData.PeekMulti(key) + for _, v := range tmp { + res = append(res, utils.UnsafeString(v)) + } + sort.Strings(res) + + return res +} + // AddFormData method adds a single form data field and its value in the request instance. func (r *Request) AddFormData(key, val string) *Request { r.formData.AddData(key, val) @@ -430,6 +514,20 @@ type Header struct { *fasthttp.RequestHeader } +// Peekmutiple methods returns multiple field in header with same key. +func (h *Header) PeekMultiple(key string) []string { + res := []string{} + byteKey := []byte(key) + h.RequestHeader.VisitAll(func(key, value []byte) { + if bytes.EqualFold(key, byteKey) { + res = append(res, utils.UnsafeString(value)) + } + }) + sort.Strings(res) + + return res +} + // AddHeaders receive a map and add each value to header. func (h *Header) AddHeaders(r map[string][]string) { for k, v := range r { @@ -442,6 +540,7 @@ func (h *Header) AddHeaders(r map[string][]string) { // SetHeaders will override all headers. func (h *Header) SetHeaders(r map[string]string) { for k, v := range r { + h.Del(k) h.Set(k, v) } } @@ -662,30 +761,30 @@ func (f *File) Reset() { f.reader = nil } -var requestPool sync.Pool +var requestPool = &sync.Pool{ + New: func() any { + return &Request{ + header: &Header{RequestHeader: &fasthttp.RequestHeader{}}, + params: &QueryParam{Args: fasthttp.AcquireArgs()}, + cookies: &Cookie{}, + path: &PathParam{}, + boundary: "--FiberFormBoundary" + randString(16), + formData: &FormData{Args: fasthttp.AcquireArgs()}, + files: make([]*File, 0), + rawRequest: fasthttp.AcquireRequest(), + } + }, +} // AcquireRequest returns an empty request object from the pool. // // The returned request may be returned to the pool with ReleaseRequest when no longer needed. // This allows reducing GC load. -func AcquireRequest() (req *Request) { - reqv := requestPool.Get() - if reqv != nil { - req = reqv.(*Request) - return - } +func AcquireRequest() *Request { + req := requestPool.Get().(*Request) + req.boundary = "--FiberFormBoundary" + randString(16) - req = &Request{ - header: &Header{RequestHeader: &fasthttp.RequestHeader{}}, - params: &QueryParam{Args: fasthttp.AcquireArgs()}, - cookies: &Cookie{}, - path: &PathParam{}, - boundary: "--FiberFormBoundary" + randString(16), - formData: &FormData{Args: fasthttp.AcquireArgs()}, - files: make([]*File, 0), - rawRequest: fasthttp.AcquireRequest(), - } - return + return req } // ReleaseRequest returns the object acquired via AcquireRequest to the pool. diff --git a/client/request_test.go b/client/request_test.go index 11e3306fd1..132961f082 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -1,15 +1,1626 @@ package client import ( + "context" + "net" "testing" + "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/utils" "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttputil" ) -func TestParamsSetParamsWithStruct(t *testing.T) { +func TestRequestMethod(t *testing.T) { t.Parallel() + req := AcquireRequest() + req.SetMethod("GET") + utils.AssertEqual(t, "GET", req.Method()) + + req.SetMethod("POST") + utils.AssertEqual(t, "POST", req.Method()) + + req.SetMethod("PUT") + utils.AssertEqual(t, "PUT", req.Method()) + + req.SetMethod("DELETE") + utils.AssertEqual(t, "DELETE", req.Method()) +} + +func TestRequestURL(t *testing.T) { + t.Parallel() + + req := AcquireRequest() + + req.SetURL("http://example.com/normal") + utils.AssertEqual(t, "http://example.com/normal", req.URL()) + + req.SetURL("https://example.com/normal") + utils.AssertEqual(t, "https://example.com/normal", req.URL()) +} + +func TestRequestClient(t *testing.T) { + t.Parallel() + + client := AcquireClient() + req := AcquireRequest() + + req.SetClient(client) + utils.AssertEqual(t, client, req.Client()) +} + +func TestRequestContext(t *testing.T) { + t.Parallel() + + req := AcquireRequest() + ctx := req.Context() + key := struct{}{} + + utils.AssertEqual(t, nil, ctx.Value(key)) + + ctx = context.WithValue(ctx, key, "string") + req.SetContext(ctx) + ctx = req.Context() + + utils.AssertEqual(t, "string", ctx.Value(key).(string)) +} + +func TestRequestHeader(t *testing.T) { + t.Parallel() + + t.Run("add header", func(t *testing.T) { + req := AcquireRequest() + req.AddHeader("foo", "bar").AddHeader("foo", "fiber") + + res := req.Header("foo") + utils.AssertEqual(t, 2, len(res)) + utils.AssertEqual(t, "bar", res[0]) + utils.AssertEqual(t, "fiber", res[1]) + }) + + t.Run("set header", func(t *testing.T) { + req := AcquireRequest() + req.AddHeader("foo", "bar").SetHeader("foo", "fiber") + + res := req.Header("foo") + utils.AssertEqual(t, 1, len(res)) + utils.AssertEqual(t, "fiber", res[0]) + }) + + t.Run("add headers", func(t *testing.T) { + req := AcquireRequest() + req.SetHeader("foo", "bar"). + AddHeaders(map[string][]string{ + "foo": {"fiber", "buaa"}, + "bar": {"foo"}, + }) + + res := req.Header("foo") + utils.AssertEqual(t, 3, len(res)) + utils.AssertEqual(t, "bar", res[0]) + utils.AssertEqual(t, "buaa", res[1]) + utils.AssertEqual(t, "fiber", res[2]) + + res = req.Header("bar") + utils.AssertEqual(t, 1, len(res)) + utils.AssertEqual(t, "foo", res[0]) + }) + + t.Run("set headers", func(t *testing.T) { + req := AcquireRequest() + req.SetHeader("foo", "bar"). + SetHeaders(map[string]string{ + "foo": "fiber", + "bar": "foo", + }) + + res := req.Header("foo") + utils.AssertEqual(t, 1, len(res)) + utils.AssertEqual(t, "fiber", res[0]) + + res = req.Header("bar") + utils.AssertEqual(t, 1, len(res)) + utils.AssertEqual(t, "foo", res[0]) + }) +} + +func TestRequestQueryParam(t *testing.T) { + t.Parallel() + + t.Run("add param", func(t *testing.T) { + req := AcquireRequest() + req.AddParam("foo", "bar").AddParam("foo", "fiber") + + res := req.Param("foo") + utils.AssertEqual(t, 2, len(res)) + utils.AssertEqual(t, "bar", res[0]) + utils.AssertEqual(t, "fiber", res[1]) + }) + + t.Run("set param", func(t *testing.T) { + req := AcquireRequest() + req.AddParam("foo", "bar").SetParam("foo", "fiber") + + res := req.Param("foo") + utils.AssertEqual(t, 1, len(res)) + utils.AssertEqual(t, "fiber", res[0]) + }) + + t.Run("add params", func(t *testing.T) { + req := AcquireRequest() + req.SetParam("foo", "bar"). + AddParams(map[string][]string{ + "foo": {"fiber", "buaa"}, + "bar": {"foo"}, + }) + + res := req.Param("foo") + utils.AssertEqual(t, 3, len(res)) + utils.AssertEqual(t, "bar", res[0]) + utils.AssertEqual(t, "buaa", res[1]) + utils.AssertEqual(t, "fiber", res[2]) + + res = req.Param("bar") + utils.AssertEqual(t, 1, len(res)) + utils.AssertEqual(t, "foo", res[0]) + }) + + t.Run("set headers", func(t *testing.T) { + req := AcquireRequest() + req.SetParam("foo", "bar"). + SetParams(map[string]string{ + "foo": "fiber", + "bar": "foo", + }) + + res := req.Param("foo") + utils.AssertEqual(t, 1, len(res)) + utils.AssertEqual(t, "fiber", res[0]) + + res = req.Param("bar") + utils.AssertEqual(t, 1, len(res)) + utils.AssertEqual(t, "foo", res[0]) + }) + + t.Run("set params with struct", func(t *testing.T) { + t.Parallel() + + type args struct { + TInt int + TString string + TFloat float64 + TBool bool + TSlice []string + TIntSlice []int `param:"int_slice"` + } + + p := AcquireRequest() + p.SetParamsWithStruct(&args{ + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: true, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + + utils.AssertEqual(t, 0, len(p.Param("unexport"))) + + utils.AssertEqual(t, 1, len(p.Param("TInt"))) + utils.AssertEqual(t, "5", p.Param("TInt")[0]) + + utils.AssertEqual(t, 1, len(p.Param("TString"))) + utils.AssertEqual(t, "string", p.Param("TString")[0]) + + utils.AssertEqual(t, 1, len(p.Param("TFloat"))) + utils.AssertEqual(t, "3.1", p.Param("TFloat")[0]) + + utils.AssertEqual(t, 1, len(p.Param("TBool"))) + + tslice := p.Param("TSlice") + utils.AssertEqual(t, 2, len(tslice)) + utils.AssertEqual(t, "bar", tslice[0]) + utils.AssertEqual(t, "foo", tslice[1]) + + tint := p.Param("TSlice") + utils.AssertEqual(t, 2, len(tint)) + utils.AssertEqual(t, "bar", tint[0]) + utils.AssertEqual(t, "foo", tint[1]) + }) + + t.Run("del params", func(t *testing.T) { + req := AcquireRequest() + req.SetParam("foo", "bar"). + SetParams(map[string]string{ + "foo": "fiber", + "bar": "foo", + }).DelParams("foo", "bar") + + res := req.Param("foo") + utils.AssertEqual(t, 0, len(res)) + + res = req.Param("bar") + utils.AssertEqual(t, 0, len(res)) + }) +} + +func TestRequestUA(t *testing.T) { + t.Parallel() + + req := AcquireRequest().SetUserAgent("fiber") + utils.AssertEqual(t, "fiber", req.UserAgent()) + + req.SetUserAgent("foo") + utils.AssertEqual(t, "foo", req.UserAgent()) +} + +func TestReferer(t *testing.T) { + t.Parallel() + + req := AcquireRequest().SetReferer("http://example.com") + utils.AssertEqual(t, "http://example.com", req.Referer()) + + req.SetReferer("https://example.com") + utils.AssertEqual(t, "https://example.com", req.Referer()) +} + +func TestRequestCookie(t *testing.T) { + t.Parallel() + + t.Run("set cookie", func(t *testing.T) { + req := AcquireRequest(). + SetCookie("foo", "bar") + utils.AssertEqual(t, "bar", req.Cookie("foo")) + + req.SetCookie("foo", "bar1") + utils.AssertEqual(t, "bar1", req.Cookie("foo")) + }) + + t.Run("set cookies", func(t *testing.T) { + req := AcquireRequest(). + SetCookies(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + utils.AssertEqual(t, "bar", req.Cookie("foo")) + utils.AssertEqual(t, "foo", req.Cookie("bar")) + + req.SetCookies(map[string]string{ + "foo": "bar1", + }) + utils.AssertEqual(t, "bar1", req.Cookie("foo")) + utils.AssertEqual(t, "foo", req.Cookie("bar")) + }) + + t.Run("set cookies with struct", func(t *testing.T) { + type args struct { + CookieInt int `cookie:"int"` + CookieString string `cookie:"string"` + } + + req := AcquireRequest().SetCookiesWithStruct(&args{ + CookieInt: 5, + CookieString: "foo", + }) + + utils.AssertEqual(t, "5", req.Cookie("int")) + utils.AssertEqual(t, "foo", req.Cookie("string")) + }) + + t.Run("del cookies", func(t *testing.T) { + req := AcquireRequest(). + SetCookies(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + utils.AssertEqual(t, "bar", req.Cookie("foo")) + utils.AssertEqual(t, "foo", req.Cookie("bar")) + + req.DelCookies("foo") + utils.AssertEqual(t, "", req.Cookie("foo")) + utils.AssertEqual(t, "foo", req.Cookie("bar")) + }) +} + +func TestRequestPathParam(t *testing.T) { + t.Parallel() + + t.Run("set path param", func(t *testing.T) { + req := AcquireRequest(). + SetPathParam("foo", "bar") + utils.AssertEqual(t, "bar", req.PathParam("foo")) + + req.SetPathParam("foo", "bar1") + utils.AssertEqual(t, "bar1", req.PathParam("foo")) + }) + + t.Run("set path params", func(t *testing.T) { + req := AcquireRequest(). + SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + utils.AssertEqual(t, "bar", req.PathParam("foo")) + utils.AssertEqual(t, "foo", req.PathParam("bar")) + + req.SetPathParams(map[string]string{ + "foo": "bar1", + }) + utils.AssertEqual(t, "bar1", req.PathParam("foo")) + utils.AssertEqual(t, "foo", req.PathParam("bar")) + }) + + t.Run("set path params with struct", func(t *testing.T) { + type args struct { + CookieInt int `path:"int"` + CookieString string `path:"string"` + } + + req := AcquireRequest().SetPathParamsWithStruct(&args{ + CookieInt: 5, + CookieString: "foo", + }) + + utils.AssertEqual(t, "5", req.PathParam("int")) + utils.AssertEqual(t, "foo", req.PathParam("string")) + }) + + t.Run("del path params", func(t *testing.T) { + req := AcquireRequest(). + SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + utils.AssertEqual(t, "bar", req.PathParam("foo")) + utils.AssertEqual(t, "foo", req.PathParam("bar")) + + req.DelPathParams("foo") + utils.AssertEqual(t, "", req.PathParam("foo")) + utils.AssertEqual(t, "foo", req.PathParam("bar")) + }) +} + +func TestRequestFormData(t *testing.T) { + t.Parallel() + + t.Run("add form data", func(t *testing.T) { + req := AcquireRequest() + req.AddFormData("foo", "bar").AddFormData("foo", "fiber") + + res := req.FormData("foo") + utils.AssertEqual(t, 2, len(res)) + utils.AssertEqual(t, "bar", res[0]) + utils.AssertEqual(t, "fiber", res[1]) + }) + + t.Run("set param", func(t *testing.T) { + req := AcquireRequest() + req.AddFormData("foo", "bar").SetFormData("foo", "fiber") + + res := req.FormData("foo") + utils.AssertEqual(t, 1, len(res)) + utils.AssertEqual(t, "fiber", res[0]) + }) + + t.Run("add params", func(t *testing.T) { + req := AcquireRequest() + req.SetFormData("foo", "bar"). + AddFormDatas(map[string][]string{ + "foo": {"fiber", "buaa"}, + "bar": {"foo"}, + }) + + res := req.FormData("foo") + utils.AssertEqual(t, 3, len(res)) + utils.AssertEqual(t, "bar", res[0]) + utils.AssertEqual(t, "buaa", res[1]) + utils.AssertEqual(t, "fiber", res[2]) + + res = req.FormData("bar") + utils.AssertEqual(t, 1, len(res)) + utils.AssertEqual(t, "foo", res[0]) + }) + + t.Run("set headers", func(t *testing.T) { + req := AcquireRequest() + req.SetFormData("foo", "bar"). + SetFormDatas(map[string]string{ + "foo": "fiber", + "bar": "foo", + }) + + res := req.FormData("foo") + utils.AssertEqual(t, 1, len(res)) + utils.AssertEqual(t, "fiber", res[0]) + + res = req.FormData("bar") + utils.AssertEqual(t, 1, len(res)) + utils.AssertEqual(t, "foo", res[0]) + }) + + t.Run("set params with struct", func(t *testing.T) { + t.Parallel() + + type args struct { + TInt int + TString string + TFloat float64 + TBool bool + TSlice []string + TIntSlice []int `form:"int_slice"` + } + + p := AcquireRequest() + p.SetFormDatasWithStruct(&args{ + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: true, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + + utils.AssertEqual(t, 0, len(p.FormData("unexport"))) + + utils.AssertEqual(t, 1, len(p.FormData("TInt"))) + utils.AssertEqual(t, "5", p.FormData("TInt")[0]) + + utils.AssertEqual(t, 1, len(p.FormData("TString"))) + utils.AssertEqual(t, "string", p.FormData("TString")[0]) + + utils.AssertEqual(t, 1, len(p.FormData("TFloat"))) + utils.AssertEqual(t, "3.1", p.FormData("TFloat")[0]) + + utils.AssertEqual(t, 1, len(p.FormData("TBool"))) + + tslice := p.FormData("TSlice") + utils.AssertEqual(t, 2, len(tslice)) + utils.AssertEqual(t, "bar", tslice[0]) + utils.AssertEqual(t, "foo", tslice[1]) + + tint := p.FormData("TSlice") + utils.AssertEqual(t, 2, len(tint)) + utils.AssertEqual(t, "bar", tint[0]) + utils.AssertEqual(t, "foo", tint[1]) + + }) + + t.Run("del params", func(t *testing.T) { + req := AcquireRequest() + req.SetFormData("foo", "bar"). + SetFormDatas(map[string]string{ + "foo": "fiber", + "bar": "foo", + }).DelFormDatas("foo", "bar") + + res := req.FormData("foo") + utils.AssertEqual(t, 0, len(res)) + + res = req.FormData("bar") + utils.AssertEqual(t, 0, len(res)) + }) +} + +func TestRequestFile(t *testing.T) { + t.Parallel() + + t.Run("add file", func(t *testing.T) { + + }) +} + +func createHelperServer(t *testing.T) (*fiber.App, *Client, func()) { + t.Helper() + + ln := fasthttputil.NewInmemoryListener() + app := fiber.New(fiber.Config{DisableStartupMessage: true}) + + client := AcquireClient().SetDial(func(addr string) (net.Conn, error) { + return ln.Dial() + }) + + return app, client, func() { + utils.AssertEqual(t, nil, app.Listener(ln)) + } +} + +func TestRequestInvalidURL(t *testing.T) { + t.Parallel() + + resp, err := AcquireRequest(). + Get("http://example.com\r\n\r\nGET /\r\n\r\n") + + utils.AssertEqual(t, ErrURLForamt, err) + utils.AssertEqual(t, (*Response)(nil), resp) +} + +func TestRequestUnsupportProtocol(t *testing.T) { + t.Parallel() + + resp, err := AcquireRequest(). + Get("ftp://example.com") + utils.AssertEqual(t, ErrURLForamt, err) + utils.AssertEqual(t, (*Response)(nil), resp) +} + +func TestRequestGet(t *testing.T) { + t.Parallel() + + app, client, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + go start() + + for i := 0; i < 5; i++ { + req := AcquireRequest().SetClient(client) + + resp, err := req.Get("http://example.com") + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) + utils.AssertEqual(t, "example.com", resp.String()) + resp.Close() + } +} + +func TestRequestPost(t *testing.T) { + t.Parallel() + + app, client, start := createHelperServer(t) + app.Post("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusCreated). + SendString(c.FormValue("foo")) + }) + go start() + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + SetFormData("foo", "bar"). + Post("http://example.com") + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusCreated, resp.StatusCode()) + utils.AssertEqual(t, "bar", resp.String()) + resp.Close() + } +} + +func TestRequestHead(t *testing.T) { + t.Parallel() + + app, client, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + + go start() + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + Head("http://example.com") + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) + utils.AssertEqual(t, "", resp.String()) + resp.Close() + } +} + +func TestRequestPut(t *testing.T) { + t.Parallel() + + app, client, start := createHelperServer(t) + app.Put("/", func(c fiber.Ctx) error { + return c.SendString(c.FormValue("foo")) + }) + + go start() + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + SetFormData("foo", "bar"). + Put("http://example.com") + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) + utils.AssertEqual(t, "bar", resp.String()) + + resp.Close() + } +} + +func TestRequestPatch(t *testing.T) { + t.Parallel() + + app, client, start := createHelperServer(t) + + app.Patch("/", func(c fiber.Ctx) error { + return c.SendString(c.FormValue("foo")) + }) + + go start() + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + SetFormData("foo", "bar"). + Patch("http://example.com") + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) + utils.AssertEqual(t, "bar", resp.String()) + + resp.Close() + } +} + +func TestRequestDelete(t *testing.T) { + t.Parallel() + + app, client, start := createHelperServer(t) + + app.Delete("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusNoContent). + SendString("deleted") + }) + + go start() + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + Delete("http://example.com") + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusNoContent, resp.StatusCode()) + utils.AssertEqual(t, "", resp.String()) + + resp.Close() + } +} + +// func Test_Client_UserAgent(t *testing.T) { +// t.Parallel() + +// ln := fasthttputil.NewInmemoryListener() + +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + +// app.Get("/", func(c fiber.Ctx) error { +// return c.Send(c.Request().Header.UserAgent()) +// }) + +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + +// t.Run("default", func(t *testing.T) { +// for i := 0; i < 5; i++ { +// a := Get("http://example.com") + +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + +// code, body, errs := a.String() + +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, defaultUserAgent, body) +// utils.AssertEqual(t, 0, len(errs)) +// } +// }) + +// t.Run("custom", func(t *testing.T) { +// for i := 0; i < 5; i++ { +// c := AcquireClient() +// c.UserAgent = "ua" + +// a := c.Get("http://example.com") + +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + +// code, body, errs := a.String() + +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "ua", body) +// utils.AssertEqual(t, 0, len(errs)) +// ReleaseClient(c) +// } +// }) +// } + +// func Test_Client_Agent_Set_Or_Add_Headers(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// c.Request().Header.VisitAll(func(key, value []byte) { +// if k := string(key); k == "K1" || k == "K2" { +// _, _ = c.Write(key) +// _, _ = c.Write(value) +// } +// }) +// return nil +// } + +// wrapAgent := func(a *Agent) { +// a.Set("k1", "v1"). +// SetBytesK([]byte("k1"), "v1"). +// SetBytesV("k1", []byte("v1")). +// AddBytesK([]byte("k1"), "v11"). +// AddBytesV("k1", []byte("v22")). +// AddBytesKV([]byte("k1"), []byte("v33")). +// SetBytesKV([]byte("k2"), []byte("v2")). +// Add("k2", "v22") +// } + +// testAgent(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") +// } + +// func Test_Client_Agent_Connection_Close(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// if c.Request().Header.ConnectionClose() { +// return c.SendString("close") +// } +// return c.SendString("not close") +// } + +// wrapAgent := func(a *Agent) { +// a.ConnectionClose() +// } + +// testAgent(t, handler, wrapAgent, "close") +// } + +// func Test_Client_Agent_UserAgent(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.Send(c.Request().Header.UserAgent()) +// } + +// wrapAgent := func(a *Agent) { +// a.UserAgent("ua"). +// UserAgentBytes([]byte("ua")) +// } + +// testAgent(t, handler, wrapAgent, "ua") +// } + +// func Test_Client_Agent_Cookie(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.SendString( +// c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) +// } + +// wrapAgent := func(a *Agent) { +// a.Cookie("k1", "v1"). +// CookieBytesK([]byte("k2"), "v2"). +// CookieBytesKV([]byte("k2"), []byte("v2")). +// Cookies("k3", "v3", "k4", "v4"). +// CookiesBytesKV([]byte("k3"), []byte("v3"), []byte("k4"), []byte("v4")) +// } + +// testAgent(t, handler, wrapAgent, "v1v2v3v4") +// } + +// func Test_Client_Agent_Referer(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.Send(c.Request().Header.Referer()) +// } + +// wrapAgent := func(a *Agent) { +// a.Referer("http://referer.com"). +// RefererBytes([]byte("http://referer.com")) +// } + +// testAgent(t, handler, wrapAgent, "http://referer.com") +// } + +// func Test_Client_Agent_ContentType(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.Send(c.Request().Header.ContentType()) +// } + +// wrapAgent := func(a *Agent) { +// a.ContentType("custom-type"). +// ContentTypeBytes([]byte("custom-type")) +// } + +// testAgent(t, handler, wrapAgent, "custom-type") +// } + +// func Test_Client_Agent_Host(t *testing.T) { +// t.Parallel() + +// ln := fasthttputil.NewInmemoryListener() + +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + +// app.Get("/", func(c fiber.Ctx) error { +// return c.SendString(c.Hostname()) +// }) + +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + +// a := Get("http://1.1.1.1:8080"). +// Host("example.com"). +// HostBytes([]byte("example.com")) + +// utils.AssertEqual(t, "1.1.1.1:8080", a.HostClient.Addr) + +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + +// code, body, errs := a.String() + +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "example.com", body) +// utils.AssertEqual(t, 0, len(errs)) +// } + +// func Test_Client_Agent_QueryString(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.Send(c.Request().URI().QueryString()) +// } + +// wrapAgent := func(a *Agent) { +// a.QueryString("foo=bar&bar=baz"). +// QueryStringBytes([]byte("foo=bar&bar=baz")) +// } + +// testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") +// } + +// func Test_Client_Agent_BasicAuth(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// // Get authorization header +// auth := c.Get(fiber.HeaderAuthorization) +// // Decode the header contents +// raw, err := base64.StdEncoding.DecodeString(auth[6:]) +// utils.AssertEqual(t, nil, err) + +// return c.Send(raw) +// } + +// wrapAgent := func(a *Agent) { +// a.BasicAuth("foo", "bar"). +// BasicAuthBytes([]byte("foo"), []byte("bar")) +// } + +// testAgent(t, handler, wrapAgent, "foo:bar") +// } + +// func Test_Client_Agent_BodyString(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.Send(c.Request().Body()) +// } + +// wrapAgent := func(a *Agent) { +// a.BodyString("foo=bar&bar=baz") +// } + +// testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") +// } + +// func Test_Client_Agent_Body(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.Send(c.Request().Body()) +// } + +// wrapAgent := func(a *Agent) { +// a.Body([]byte("foo=bar&bar=baz")) +// } + +// testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") +// } + +// func Test_Client_Agent_BodyStream(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.Send(c.Request().Body()) +// } + +// wrapAgent := func(a *Agent) { +// a.BodyStream(strings.NewReader("body stream"), -1) +// } + +// testAgent(t, handler, wrapAgent, "body stream") +// } + +// func Test_Client_Agent_Custom_Response(t *testing.T) { +// t.Parallel() + +// ln := fasthttputil.NewInmemoryListener() + +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + +// app.Get("/", func(c fiber.Ctx) error { +// return c.SendString("custom") +// }) + +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + +// for i := 0; i < 5; i++ { +// a := AcquireAgent() +// resp := AcquireResponse() + +// req := a.Request() +// req.Header.SetMethod(fiber.MethodGet) +// req.SetRequestURI("http://example.com") + +// utils.AssertEqual(t, nil, a.Parse()) + +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + +// code, body, errs := a.SetResponse(resp). +// String() + +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "custom", body) +// utils.AssertEqual(t, "custom", string(resp.Body())) +// utils.AssertEqual(t, 0, len(errs)) + +// ReleaseResponse(resp) +// } +// } + +// func Test_Client_Agent_Dest(t *testing.T) { +// t.Parallel() + +// ln := fasthttputil.NewInmemoryListener() + +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + +// app.Get("/", func(c fiber.Ctx) error { +// return c.SendString("dest") +// }) + +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + +// t.Run("small dest", func(t *testing.T) { +// dest := []byte("de") + +// a := Get("http://example.com") + +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + +// code, body, errs := a.Dest(dest[:0]).String() + +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "dest", body) +// utils.AssertEqual(t, "de", string(dest)) +// utils.AssertEqual(t, 0, len(errs)) +// }) + +// t.Run("enough dest", func(t *testing.T) { +// dest := []byte("foobar") + +// a := Get("http://example.com") + +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + +// code, body, errs := a.Dest(dest[:0]).String() + +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "dest", body) +// utils.AssertEqual(t, "destar", string(dest)) +// utils.AssertEqual(t, 0, len(errs)) +// }) +// } + +// // readErrorConn is a struct for testing retryIf +// type readErrorConn struct { +// net.Conn +// } + +// func (r *readErrorConn) Read(p []byte) (int, error) { +// return 0, fmt.Errorf("error") +// } + +// func (r *readErrorConn) Write(p []byte) (int, error) { +// return len(p), nil +// } + +// func (r *readErrorConn) Close() error { +// return nil +// } + +// func (r *readErrorConn) LocalAddr() net.Addr { +// return nil +// } + +// func (r *readErrorConn) RemoteAddr() net.Addr { +// return nil +// } +// func Test_Client_Agent_RetryIf(t *testing.T) { +// t.Parallel() + +// ln := fasthttputil.NewInmemoryListener() + +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + +// a := Post("http://example.com"). +// RetryIf(func(req *Request) bool { +// return true +// }) +// dialsCount := 0 +// a.HostClient.Dial = func(addr string) (net.Conn, error) { +// dialsCount++ +// switch dialsCount { +// case 1: +// return &readErrorConn{}, nil +// case 2: +// return &readErrorConn{}, nil +// case 3: +// return &readErrorConn{}, nil +// case 4: +// return ln.Dial() +// default: +// t.Fatalf("unexpected number of dials: %d", dialsCount) +// } +// panic("unreachable") +// } + +// _, _, errs := a.String() +// utils.AssertEqual(t, dialsCount, 4) +// utils.AssertEqual(t, 0, len(errs)) +// } + +// func Test_Client_Agent_Json(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// utils.AssertEqual(t, fiber.MIMEApplicationJSON, string(c.Request().Header.ContentType())) + +// return c.Send(c.Request().Body()) +// } + +// wrapAgent := func(a *Agent) { +// a.JSON(data{Success: true}) +// } + +// testAgent(t, handler, wrapAgent, `{"success":true}`) +// } + +// func Test_Client_Agent_Json_Error(t *testing.T) { +// a := Get("http://example.com"). +// JSONEncoder(json.Marshal). +// JSON(complex(1, 1)) + +// _, body, errs := a.String() + +// utils.AssertEqual(t, "", body) +// utils.AssertEqual(t, 1, len(errs)) +// utils.AssertEqual(t, "json: unsupported type: complex128", errs[0].Error()) +// } + +// func Test_Client_Agent_XML(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// utils.AssertEqual(t, fiber.MIMEApplicationXML, string(c.Request().Header.ContentType())) + +// return c.Send(c.Request().Body()) +// } + +// wrapAgent := func(a *Agent) { +// a.XML(data{Success: true}) +// } + +// testAgent(t, handler, wrapAgent, "true") +// } + +// func Test_Client_Agent_XML_Error(t *testing.T) { +// a := Get("http://example.com"). +// XML(complex(1, 1)) + +// _, body, errs := a.String() + +// utils.AssertEqual(t, "", body) +// utils.AssertEqual(t, 1, len(errs)) +// utils.AssertEqual(t, "xml: unsupported type: complex128", errs[0].Error()) +// } + +// func Test_Client_Agent_Form(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// utils.AssertEqual(t, fiber.MIMEApplicationForm, string(c.Request().Header.ContentType())) + +// return c.Send(c.Request().Body()) +// } + +// args := AcquireArgs() + +// args.Set("foo", "bar") + +// wrapAgent := func(a *Agent) { +// a.Form(args) +// } + +// testAgent(t, handler, wrapAgent, "foo=bar") + +// ReleaseArgs(args) +// } + +// func Test_Client_Agent_MultipartForm(t *testing.T) { +// t.Parallel() + +// ln := fasthttputil.NewInmemoryListener() + +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + +// app.Post("/", func(c fiber.Ctx) error { +// utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) + +// mf, err := c.MultipartForm() +// utils.AssertEqual(t, nil, err) +// utils.AssertEqual(t, "bar", mf.Value["foo"][0]) + +// return c.Send(c.Request().Body()) +// }) + +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + +// args := AcquireArgs() + +// args.Set("foo", "bar") + +// a := Post("http://example.com"). +// Boundary("myBoundary"). +// MultipartForm(args) + +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + +// code, body, errs := a.String() + +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "--myBoundary\r\nContent-Disposition: form-data; name=\"foo\"\r\n\r\nbar\r\n--myBoundary--\r\n", body) +// utils.AssertEqual(t, 0, len(errs)) +// ReleaseArgs(args) +// } + +// func Test_Client_Agent_MultipartForm_Errors(t *testing.T) { +// t.Parallel() + +// a := AcquireAgent() +// a.mw = &errorMultipartWriter{} + +// args := AcquireArgs() +// args.Set("foo", "bar") + +// ff1 := &FormFile{"", "name1", []byte("content"), false} +// ff2 := &FormFile{"", "name2", []byte("content"), false} +// a.FileData(ff1, ff2). +// MultipartForm(args) + +// utils.AssertEqual(t, 4, len(a.errs)) +// ReleaseArgs(args) +// } + +// func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) { +// t.Parallel() + +// ln := fasthttputil.NewInmemoryListener() + +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + +// app.Post("/", func(c fiber.Ctx) error { +// utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) + +// fh1, err := c.FormFile("field1") +// utils.AssertEqual(t, nil, err) +// utils.AssertEqual(t, fh1.Filename, "name") +// buf := make([]byte, fh1.Size) +// f, err := fh1.Open() +// utils.AssertEqual(t, nil, err) +// defer func() { _ = f.Close() }() +// _, err = f.Read(buf) +// utils.AssertEqual(t, nil, err) +// utils.AssertEqual(t, "form file", string(buf)) + +// fh2, err := c.FormFile("index") +// utils.AssertEqual(t, nil, err) +// checkFormFile(t, fh2, ".github/testdata/index.html") + +// fh3, err := c.FormFile("file3") +// utils.AssertEqual(t, nil, err) +// checkFormFile(t, fh3, ".github/testdata/index.tmpl") + +// return c.SendString("multipart form files") +// }) + +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + +// for i := 0; i < 5; i++ { +// ff := AcquireFormFile() +// ff.Fieldname = "field1" +// ff.Name = "name" +// ff.Content = []byte("form file") + +// a := Post("http://example.com"). +// Boundary("myBoundary"). +// FileData(ff). +// SendFiles(".github/testdata/index.html", "index", ".github/testdata/index.tmpl"). +// MultipartForm(nil) + +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + +// code, body, errs := a.String() + +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "multipart form files", body) +// utils.AssertEqual(t, 0, len(errs)) + +// ReleaseFormFile(ff) +// } +// } + +// func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) { +// t.Helper() + +// basename := filepath.Base(filename) +// utils.AssertEqual(t, fh.Filename, basename) + +// b1, err := os.ReadFile(filename) +// utils.AssertEqual(t, nil, err) + +// b2 := make([]byte, fh.Size) +// f, err := fh.Open() +// utils.AssertEqual(t, nil, err) +// defer func() { _ = f.Close() }() +// _, err = f.Read(b2) +// utils.AssertEqual(t, nil, err) +// utils.AssertEqual(t, b1, b2) +// } + +// func Test_Client_Agent_Multipart_Random_Boundary(t *testing.T) { +// t.Parallel() + +// a := Post("http://example.com"). +// MultipartForm(nil) + +// reg := regexp.MustCompile(`multipart/form-data; boundary=\w{30}`) + +// utils.AssertEqual(t, true, reg.Match(a.req.Header.Peek(fiber.HeaderContentType))) +// } + +// func Test_Client_Agent_Multipart_Invalid_Boundary(t *testing.T) { +// t.Parallel() + +// a := Post("http://example.com"). +// Boundary("*"). +// MultipartForm(nil) + +// utils.AssertEqual(t, 1, len(a.errs)) +// utils.AssertEqual(t, "mime: invalid boundary character", a.errs[0].Error()) +// } + +// func Test_Client_Agent_SendFile_Error(t *testing.T) { +// t.Parallel() + +// a := Post("http://example.com"). +// SendFile("non-exist-file!", "") + +// utils.AssertEqual(t, 1, len(a.errs)) +// utils.AssertEqual(t, true, strings.Contains(a.errs[0].Error(), "open non-exist-file!")) +// } + +// func Test_Client_Debug(t *testing.T) { +// handler := func(c fiber.Ctx) error { +// return c.SendString("debug") +// } + +// var output bytes.Buffer + +// wrapAgent := func(a *Agent) { +// a.Debug(&output) +// } + +// testAgent(t, handler, wrapAgent, "debug", 1) + +// str := output.String() + +// utils.AssertEqual(t, true, strings.Contains(str, "Connected to example.com(pipe)")) +// utils.AssertEqual(t, true, strings.Contains(str, "GET / HTTP/1.1")) +// utils.AssertEqual(t, true, strings.Contains(str, "User-Agent: fiber")) +// utils.AssertEqual(t, true, strings.Contains(str, "Host: example.com\r\n\r\n")) +// utils.AssertEqual(t, true, strings.Contains(str, "HTTP/1.1 200 OK")) +// utils.AssertEqual(t, true, strings.Contains(str, "Content-Type: text/plain; charset=utf-8\r\nContent-Length: 5\r\n\r\ndebug")) +// } + +// func Test_Client_Agent_Timeout(t *testing.T) { +// t.Parallel() + +// ln := fasthttputil.NewInmemoryListener() + +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + +// app.Get("/", func(c fiber.Ctx) error { +// time.Sleep(time.Millisecond * 200) +// return c.SendString("timeout") +// }) + +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + +// a := Get("http://example.com"). +// Timeout(time.Millisecond * 50) + +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + +// _, body, errs := a.String() + +// utils.AssertEqual(t, "", body) +// utils.AssertEqual(t, 1, len(errs)) +// utils.AssertEqual(t, "timeout", errs[0].Error()) +// } + +// func Test_Client_Agent_Reuse(t *testing.T) { +// t.Parallel() + +// ln := fasthttputil.NewInmemoryListener() + +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + +// app.Get("/", func(c fiber.Ctx) error { +// return c.SendString("reuse") +// }) + +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + +// a := Get("http://example.com"). +// Reuse() + +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + +// code, body, errs := a.String() + +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "reuse", body) +// utils.AssertEqual(t, 0, len(errs)) + +// code, body, errs = a.String() + +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "reuse", body) +// utils.AssertEqual(t, 0, len(errs)) +// } + +// func Test_Client_Agent_InsecureSkipVerify(t *testing.T) { +// t.Parallel() + +// cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key") +// utils.AssertEqual(t, nil, err) + +// serverTLSConf := &tls.Config{ +// Certificates: []tls.Certificate{cer}, +// } + +// ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") +// utils.AssertEqual(t, nil, err) + +// ln = tls.NewListener(ln, serverTLSConf) + +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + +// app.Get("/", func(c fiber.Ctx) error { +// return c.SendString("ignore tls") +// }) + +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + +// code, body, errs := Get("https://" + ln.Addr().String()). +// InsecureSkipVerify(). +// InsecureSkipVerify(). +// String() + +// utils.AssertEqual(t, 0, len(errs)) +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "ignore tls", body) +// } + +// func Test_Client_Agent_TLS(t *testing.T) { +// t.Parallel() + +// serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() +// utils.AssertEqual(t, nil, err) + +// ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") +// utils.AssertEqual(t, nil, err) + +// ln = tls.NewListener(ln, serverTLSConf) + +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + +// app.Get("/", func(c fiber.Ctx) error { +// return c.SendString("tls") +// }) + +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + +// code, body, errs := Get("https://" + ln.Addr().String()). +// TLSConfig(clientTLSConf). +// String() + +// utils.AssertEqual(t, 0, len(errs)) +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, "tls", body) +// } + +// func Test_Client_Agent_MaxRedirectsCount(t *testing.T) { +// t.Parallel() + +// ln := fasthttputil.NewInmemoryListener() + +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + +// app.Get("/", func(c fiber.Ctx) error { +// if c.Request().URI().QueryArgs().Has("foo") { +// return c.Redirect("/foo") +// } +// return c.Redirect("/") +// }) +// app.Get("/foo", func(c fiber.Ctx) error { +// return c.SendString("redirect") +// }) + +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + +// t.Run("success", func(t *testing.T) { +// a := Get("http://example.com?foo"). +// MaxRedirectsCount(1) + +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + +// code, body, errs := a.String() + +// utils.AssertEqual(t, 200, code) +// utils.AssertEqual(t, "redirect", body) +// utils.AssertEqual(t, 0, len(errs)) +// }) + +// t.Run("error", func(t *testing.T) { +// a := Get("http://example.com"). +// MaxRedirectsCount(1) + +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + +// _, body, errs := a.String() + +// utils.AssertEqual(t, "", body) +// utils.AssertEqual(t, 1, len(errs)) +// utils.AssertEqual(t, "too many redirects detected when doing the request", errs[0].Error()) +// }) +// } + +// func Test_Client_Agent_Struct(t *testing.T) { +// t.Parallel() + +// ln := fasthttputil.NewInmemoryListener() + +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + +// app.Get("/", func(c fiber.Ctx) error { +// return c.JSON(data{true}) +// }) + +// app.Get("/error", func(c fiber.Ctx) error { +// return c.SendString(`{"success"`) +// }) + +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + +// t.Run("success", func(t *testing.T) { +// t.Parallel() + +// a := Get("http://example.com") + +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + +// var d data + +// code, body, errs := a.Struct(&d) + +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, `{"success":true}`, string(body)) +// utils.AssertEqual(t, 0, len(errs)) +// utils.AssertEqual(t, true, d.Success) +// }) + +// t.Run("pre error", func(t *testing.T) { +// t.Parallel() +// a := Get("http://example.com") + +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } +// a.errs = append(a.errs, errors.New("pre errors")) + +// var d data +// _, body, errs := a.Struct(&d) + +// utils.AssertEqual(t, "", string(body)) +// utils.AssertEqual(t, 1, len(errs)) +// utils.AssertEqual(t, "pre errors", errs[0].Error()) +// utils.AssertEqual(t, false, d.Success) +// }) + +// t.Run("error", func(t *testing.T) { +// a := Get("http://example.com/error") + +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + +// var d data + +// code, body, errs := a.JSONDecoder(json.Unmarshal).Struct(&d) + +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, `{"success"`, string(body)) +// utils.AssertEqual(t, 1, len(errs)) +// utils.AssertEqual(t, "unexpected end of JSON input", errs[0].Error()) +// }) +// } + +// func Test_Client_Agent_Parse(t *testing.T) { +// t.Parallel() + +// a := Get("https://example.com:10443") + +// utils.AssertEqual(t, nil, a.Parse()) +// } + +// func Test_AddMissingPort_TLS(t *testing.T) { +// addr := addMissingPort("example.com", true) +// utils.AssertEqual(t, "example.com:443", addr) +// } + +// func testAgent(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Agent), excepted string, count ...int) { +// t.Parallel() + +// ln := fasthttputil.NewInmemoryListener() + +// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + +// app.Get("/", handler) + +// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + +// c := 1 +// if len(count) > 0 { +// c = count[0] +// } + +// for i := 0; i < c; i++ { +// a := Get("http://example.com") + +// wrapAgent(a) + +// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + +// code, body, errs := a.String() + +// utils.AssertEqual(t, fiber.StatusOK, code) +// utils.AssertEqual(t, excepted, body) +// utils.AssertEqual(t, 0, len(errs)) +// } +// } + +// type data struct { +// Success bool `json:"success" xml:"success"` +// } + +// type errorMultipartWriter struct { +// count int +// } + +// func (e *errorMultipartWriter) Boundary() string { return "myBoundary" } +// func (e *errorMultipartWriter) SetBoundary(_ string) error { return nil } +// func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) { +// if e.count == 0 { +// e.count++ +// return nil, errors.New("CreateFormFile error") +// } +// return errorWriter{}, nil +// } +// func (e *errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") } +// func (e *errorMultipartWriter) Close() error { return errors.New("Close error") } + +// type errorWriter struct{} + +// func (errorWriter) Write(_ []byte) (int, error) { return 0, errors.New("Write error") } + +func TestSetValWithStruct(t *testing.T) { + t.Parallel() + + // test SetValWithStruct vai QueryParam struct. type args struct { unexport int TInt int @@ -24,7 +1635,8 @@ func TestParamsSetParamsWithStruct(t *testing.T) { p := &QueryParam{ Args: fasthttp.AcquireArgs(), } - p.SetParamsWithStruct(args{ + + SetValWithStruct(p, "param", args{ unexport: 5, TInt: 5, TString: "string", @@ -77,7 +1689,8 @@ func TestParamsSetParamsWithStruct(t *testing.T) { p := &QueryParam{ Args: fasthttp.AcquireArgs(), } - p.SetParamsWithStruct(&args{ + + SetValWithStruct(p, "param", &args{ TInt: 5, TString: "string", TFloat: 3.1, @@ -128,7 +1741,7 @@ func TestParamsSetParamsWithStruct(t *testing.T) { p := &QueryParam{ Args: fasthttp.AcquireArgs(), } - p.SetParamsWithStruct(&args{ + SetValWithStruct(p, "param", &args{ TInt: 0, TString: "", TFloat: 0.0, @@ -145,7 +1758,7 @@ func TestParamsSetParamsWithStruct(t *testing.T) { p := &QueryParam{ Args: fasthttp.AcquireArgs(), } - p.SetParamsWithStruct(5) + SetValWithStruct(p, "param", 5) utils.AssertEqual(t, 0, p.Len()) }) } diff --git a/client/response.go b/client/response.go index 97b93679cb..1bbacbc910 100644 --- a/client/response.go +++ b/client/response.go @@ -95,24 +95,21 @@ func (r *Response) Close() { ReleaseResponse(r) } -var responsePool sync.Pool +var responsePool = &sync.Pool{ + New: func() any { + return &Response{ + cookie: []*fasthttp.Cookie{}, + rawResponse: fasthttp.AcquireResponse(), + } + }, +} // AcquireResponse returns an empty response object from the pool. // // The returned response may be returned to the pool with ReleaseResponse when no longer needed. // This allows reducing GC load. func AcquireResponse() (resp *Response) { - respv := responsePool.Get() - if respv != nil { - resp = respv.(*Response) - return - } - resp = &Response{ - cookie: []*fasthttp.Cookie{}, - rawResponse: fasthttp.AcquireResponse(), - } - - return + return responsePool.Get().(*Response) } // ReleaseResponse returns the object acquired via AcquireResponse to the pool. From 89b42d91597393f0fbd1f9bd5fda6ab76365f224 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Fri, 12 Aug 2022 10:49:02 +0800 Subject: [PATCH 022/118] =?UTF-8?q?=E2=9C=85=20perf:=20change=20test=20fun?= =?UTF-8?q?c=20name=20to=20fit=20project=20format?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client_test.go | 12 +- client/core_test.go | 4 +- client/hooks.go | 4 +- client/hooks_test.go | 10 +- client/request_test.go | 244 ++++++++++++++++++++++++++--------------- 5 files changed, 173 insertions(+), 101 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index f471047bf9..a29b1a5c5b 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -48,7 +48,7 @@ import ( // errs[0].Error()) // } -func TestGet(t *testing.T) { +func Test_Get(t *testing.T) { t.Parallel() ln := fasthttputil.NewInmemoryListener() @@ -1190,7 +1190,7 @@ func TestGet(t *testing.T) { // func (errorWriter) Write(_ []byte) (int, error) { return 0, errors.New("Write error") } -func TestClientR(t *testing.T) { +func Test_Client_R(t *testing.T) { t.Parallel() client := AcquireClient() @@ -1200,7 +1200,7 @@ func TestClientR(t *testing.T) { utils.AssertEqual(t, client, req.Client()) } -func TestClientAddHook(t *testing.T) { +func Test_Client_Add_Hook(t *testing.T) { t.Parallel() t.Run("add request hooks", func(t *testing.T) { @@ -1236,7 +1236,7 @@ func TestClientAddHook(t *testing.T) { }) } -func TestClientMarshal(t *testing.T) { +func Test_Client_Marshal(t *testing.T) { t.Run("set json marshal", func(t *testing.T) { client := AcquireClient(). SetJSONMarshal(func(v any) ([]byte, error) { @@ -1280,7 +1280,7 @@ func TestClientMarshal(t *testing.T) { }) } -func TestClientSetBaseURL(t *testing.T) { +func Test_Client_SetBaseURL(t *testing.T) { t.Parallel() client := AcquireClient().SetBaseURL("http://example.com") @@ -1288,7 +1288,7 @@ func TestClientSetBaseURL(t *testing.T) { utils.AssertEqual(t, "http://example.com", client.BaseURL()) } -func TestClientHeader(t *testing.T) { +func Test_Client_Header(t *testing.T) { t.Parallel() t.Run("", func(t *testing.T) { diff --git a/client/core_test.go b/client/core_test.go index 59eccad08c..194cf60ed9 100644 --- a/client/core_test.go +++ b/client/core_test.go @@ -13,7 +13,7 @@ import ( "github.com/valyala/fasthttp/fasthttputil" ) -func TestExecFunc(t *testing.T) { +func Test_Exec_Func(t *testing.T) { t.Parallel() ln := fasthttputil.NewInmemoryListener() @@ -90,7 +90,7 @@ func TestExecFunc(t *testing.T) { }) } -func TestExecute(t *testing.T) { +func Test_Execute(t *testing.T) { t.Parallel() ln := fasthttputil.NewInmemoryListener() diff --git a/client/hooks.go b/client/hooks.go index 768f3d1c1f..534188513a 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -140,11 +140,11 @@ func parserRequestHeader(c *Client, req *Request) error { req.rawRequest.Header.SetMethod(req.Method()) // merge header c.header.VisitAll(func(key, value []byte) { - req.rawRequest.Header.SetBytesKV(key, value) + req.rawRequest.Header.AddBytesKV(key, value) }) req.header.VisitAll(func(key, value []byte) { - req.rawRequest.Header.SetBytesKV(key, value) + req.rawRequest.Header.AddBytesKV(key, value) }) // according to data set content-type diff --git a/client/hooks_test.go b/client/hooks_test.go index bdad986fcd..7f3af4c794 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -11,7 +11,7 @@ import ( "github.com/gofiber/fiber/v3/utils" ) -func TestAddMissingPort(t *testing.T) { +func Test_AddMissing_Port(t *testing.T) { type args struct { addr string isTLS bool @@ -51,7 +51,7 @@ func TestAddMissingPort(t *testing.T) { } } -func TestRandString(t *testing.T) { +func Test_Rand_String(t *testing.T) { tests := []struct { name string args int @@ -69,7 +69,7 @@ func TestRandString(t *testing.T) { } } -func TestParserRequestURL(t *testing.T) { +func Test_Parser_Request_URL(t *testing.T) { t.Parallel() t.Run("client baseurl should be set", func(t *testing.T) { @@ -210,7 +210,7 @@ func TestParserRequestURL(t *testing.T) { }) } -func TestParserRequestHeader(t *testing.T) { +func Test_Parser_Request_Header(t *testing.T) { t.Parallel() t.Run("client header should be set", func(t *testing.T) { @@ -418,7 +418,7 @@ func TestParserRequestHeader(t *testing.T) { }) } -func TestParserRequestBody(t *testing.T) { +func Test_Parser_Request_Body(t *testing.T) { t.Parallel() t.Run("json body", func(t *testing.T) { diff --git a/client/request_test.go b/client/request_test.go index 132961f082..5d47093439 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -2,6 +2,7 @@ package client import ( "context" + "fmt" "net" "testing" @@ -11,7 +12,7 @@ import ( "github.com/valyala/fasthttp/fasthttputil" ) -func TestRequestMethod(t *testing.T) { +func Test_Request_Method(t *testing.T) { t.Parallel() req := AcquireRequest() @@ -28,7 +29,7 @@ func TestRequestMethod(t *testing.T) { utils.AssertEqual(t, "DELETE", req.Method()) } -func TestRequestURL(t *testing.T) { +func Test_Request_URL(t *testing.T) { t.Parallel() req := AcquireRequest() @@ -40,7 +41,7 @@ func TestRequestURL(t *testing.T) { utils.AssertEqual(t, "https://example.com/normal", req.URL()) } -func TestRequestClient(t *testing.T) { +func Test_Request_Client(t *testing.T) { t.Parallel() client := AcquireClient() @@ -50,7 +51,7 @@ func TestRequestClient(t *testing.T) { utils.AssertEqual(t, client, req.Client()) } -func TestRequestContext(t *testing.T) { +func Test_Request_Context(t *testing.T) { t.Parallel() req := AcquireRequest() @@ -66,7 +67,7 @@ func TestRequestContext(t *testing.T) { utils.AssertEqual(t, "string", ctx.Value(key).(string)) } -func TestRequestHeader(t *testing.T) { +func Test_Request_Header(t *testing.T) { t.Parallel() t.Run("add header", func(t *testing.T) { @@ -125,7 +126,7 @@ func TestRequestHeader(t *testing.T) { }) } -func TestRequestQueryParam(t *testing.T) { +func Test_Request_QueryParam(t *testing.T) { t.Parallel() t.Run("add param", func(t *testing.T) { @@ -245,7 +246,7 @@ func TestRequestQueryParam(t *testing.T) { }) } -func TestRequestUA(t *testing.T) { +func Test_Request_UA(t *testing.T) { t.Parallel() req := AcquireRequest().SetUserAgent("fiber") @@ -255,7 +256,7 @@ func TestRequestUA(t *testing.T) { utils.AssertEqual(t, "foo", req.UserAgent()) } -func TestReferer(t *testing.T) { +func Test_Request_Referer(t *testing.T) { t.Parallel() req := AcquireRequest().SetReferer("http://example.com") @@ -265,7 +266,7 @@ func TestReferer(t *testing.T) { utils.AssertEqual(t, "https://example.com", req.Referer()) } -func TestRequestCookie(t *testing.T) { +func Test_Request_Cookie(t *testing.T) { t.Parallel() t.Run("set cookie", func(t *testing.T) { @@ -323,7 +324,7 @@ func TestRequestCookie(t *testing.T) { }) } -func TestRequestPathParam(t *testing.T) { +func Test_Request_PathParam(t *testing.T) { t.Parallel() t.Run("set path param", func(t *testing.T) { @@ -381,7 +382,7 @@ func TestRequestPathParam(t *testing.T) { }) } -func TestRequestFormData(t *testing.T) { +func Test_Request_FormData(t *testing.T) { t.Parallel() t.Run("add form data", func(t *testing.T) { @@ -502,7 +503,7 @@ func TestRequestFormData(t *testing.T) { }) } -func TestRequestFile(t *testing.T) { +func Test_Request_File(t *testing.T) { t.Parallel() t.Run("add file", func(t *testing.T) { @@ -525,7 +526,7 @@ func createHelperServer(t *testing.T) (*fiber.App, *Client, func()) { } } -func TestRequestInvalidURL(t *testing.T) { +func Test_Request_Invalid_URL(t *testing.T) { t.Parallel() resp, err := AcquireRequest(). @@ -535,7 +536,7 @@ func TestRequestInvalidURL(t *testing.T) { utils.AssertEqual(t, (*Response)(nil), resp) } -func TestRequestUnsupportProtocol(t *testing.T) { +func Test_Request_Unsupport_Protocol(t *testing.T) { t.Parallel() resp, err := AcquireRequest(). @@ -544,7 +545,7 @@ func TestRequestUnsupportProtocol(t *testing.T) { utils.AssertEqual(t, (*Response)(nil), resp) } -func TestRequestGet(t *testing.T) { +func Test_Request_Get(t *testing.T) { t.Parallel() app, client, start := createHelperServer(t) @@ -564,7 +565,7 @@ func TestRequestGet(t *testing.T) { } } -func TestRequestPost(t *testing.T) { +func Test_Request_Post(t *testing.T) { t.Parallel() app, client, start := createHelperServer(t) @@ -587,7 +588,7 @@ func TestRequestPost(t *testing.T) { } } -func TestRequestHead(t *testing.T) { +func Test_Request_Head(t *testing.T) { t.Parallel() app, client, start := createHelperServer(t) @@ -609,7 +610,7 @@ func TestRequestHead(t *testing.T) { } } -func TestRequestPut(t *testing.T) { +func Test_Request_Put(t *testing.T) { t.Parallel() app, client, start := createHelperServer(t) @@ -632,14 +633,39 @@ func TestRequestPut(t *testing.T) { resp.Close() } } +func Test_Request_Delete(t *testing.T) { + t.Parallel() + + app, client, start := createHelperServer(t) + + app.Delete("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusNoContent). + SendString("deleted") + }) + + go start() -func TestRequestPatch(t *testing.T) { + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + Delete("http://example.com") + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusNoContent, resp.StatusCode()) + utils.AssertEqual(t, "", resp.String()) + + resp.Close() + } +} + +func Test_Request_Options(t *testing.T) { t.Parallel() app, client, start := createHelperServer(t) - app.Patch("/", func(c fiber.Ctx) error { - return c.SendString(c.FormValue("foo")) + app.Options("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusOK). + SendString("options") }) go start() @@ -647,25 +673,24 @@ func TestRequestPatch(t *testing.T) { for i := 0; i < 5; i++ { resp, err := AcquireRequest(). SetClient(client). - SetFormData("foo", "bar"). - Patch("http://example.com") + Options("http://example.com") utils.AssertEqual(t, nil, err) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) - utils.AssertEqual(t, "bar", resp.String()) + utils.AssertEqual(t, "options", resp.String()) resp.Close() } } -func TestRequestDelete(t *testing.T) { +func Test_Request_Send(t *testing.T) { t.Parallel() app, client, start := createHelperServer(t) - app.Delete("/", func(c fiber.Ctx) error { - return c.Status(fiber.StatusNoContent). - SendString("deleted") + app.Post("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusOK). + SendString("post") }) go start() @@ -673,86 +698,133 @@ func TestRequestDelete(t *testing.T) { for i := 0; i < 5; i++ { resp, err := AcquireRequest(). SetClient(client). - Delete("http://example.com") + SetURL("http://example.com"). + SetMethod(fiber.MethodPost). + Send() utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusNoContent, resp.StatusCode()) - utils.AssertEqual(t, "", resp.String()) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) + utils.AssertEqual(t, "post", resp.String()) resp.Close() } } -// func Test_Client_UserAgent(t *testing.T) { -// t.Parallel() +func Test_Request_Patch(t *testing.T) { + t.Parallel() -// ln := fasthttputil.NewInmemoryListener() + app, client, start := createHelperServer(t) -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + app.Patch("/", func(c fiber.Ctx) error { + return c.SendString(c.FormValue("foo")) + }) -// app.Get("/", func(c fiber.Ctx) error { -// return c.Send(c.Request().Header.UserAgent()) -// }) + go start() -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + SetFormData("foo", "bar"). + Patch("http://example.com") -// t.Run("default", func(t *testing.T) { -// for i := 0; i < 5; i++ { -// a := Get("http://example.com") + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) + utils.AssertEqual(t, "bar", resp.String()) -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + resp.Close() + } +} -// code, body, errs := a.String() +func testAgent(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted string, count ...int) { + t.Parallel() -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, defaultUserAgent, body) -// utils.AssertEqual(t, 0, len(errs)) -// } -// }) + app, client, start := createHelperServer(t) + app.Get("/", handler) + go start() -// t.Run("custom", func(t *testing.T) { -// for i := 0; i < 5; i++ { -// c := AcquireClient() -// c.UserAgent = "ua" + c := 1 + if len(count) > 0 { + c = count[0] + } -// a := c.Get("http://example.com") + for i := 0; i < c; i++ { + req := AcquireRequest().SetClient(client) + wrapAgent(req) -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + resp, err := req.Get("http://example.com") -// code, body, errs := a.String() + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) + utils.AssertEqual(t, excepted, resp.String()) + } +} -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "ua", body) -// utils.AssertEqual(t, 0, len(errs)) -// ReleaseClient(c) -// } -// }) -// } +func Test_Request_UserAgent_With_Server(t *testing.T) { + t.Parallel() -// func Test_Client_Agent_Set_Or_Add_Headers(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// c.Request().Header.VisitAll(func(key, value []byte) { -// if k := string(key); k == "K1" || k == "K2" { -// _, _ = c.Write(key) -// _, _ = c.Write(value) -// } -// }) -// return nil -// } + app, client, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + return c.Send(c.Request().Header.UserAgent()) + }) -// wrapAgent := func(a *Agent) { -// a.Set("k1", "v1"). -// SetBytesK([]byte("k1"), "v1"). -// SetBytesV("k1", []byte("v1")). -// AddBytesK([]byte("k1"), "v11"). -// AddBytesV("k1", []byte("v22")). -// AddBytesKV([]byte("k1"), []byte("v33")). -// SetBytesKV([]byte("k2"), []byte("v2")). -// Add("k2", "v22") -// } + go start() -// testAgent(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") -// } + t.Run("default", func(t *testing.T) { + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com") + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) + utils.AssertEqual(t, defaultUserAgent, resp.String()) + + resp.Close() + } + }) + + t.Run("custom", func(t *testing.T) { + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + SetUserAgent("ua"). + Get("http://example.com") + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) + utils.AssertEqual(t, "ua", resp.String()) + + resp.Close() + } + }) +} + +func Test_Request_Header_With_Server(t *testing.T) { + handler := func(c fiber.Ctx) error { + fmt.Println(c.Request().Header.String()) + c.Request().Header.VisitAll(func(key, value []byte) { + if k := string(key); k == "K1" || k == "K2" { + _, _ = c.Write(key) + _, _ = c.Write(value) + } + }) + return nil + } + + wrapAgent := func(r *Request) { + r.SetHeader("k1", "v1"). + AddHeader("k1", "v11"). + AddHeaders(map[string][]string{ + "k1": {"v22", "v33"}, + }). + SetHeaders(map[string]string{ + "k2": "v2", + }). + AddHeader("k2", "v22") + } + + testAgent(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") +} // func Test_Client_Agent_Connection_Close(t *testing.T) { // handler := func(c fiber.Ctx) error { @@ -1617,7 +1689,7 @@ func TestRequestDelete(t *testing.T) { // func (errorWriter) Write(_ []byte) (int, error) { return 0, errors.New("Write error") } -func TestSetValWithStruct(t *testing.T) { +func Test_SetValWithStruct(t *testing.T) { t.Parallel() // test SetValWithStruct vai QueryParam struct. From 2b075be5096150539dda37deed7f2314a9d5f8f6 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Fri, 12 Aug 2022 14:25:24 +0800 Subject: [PATCH 023/118] =?UTF-8?q?=F0=9F=9A=A7=20v3:=20handle=20error?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/hooks.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/client/hooks.go b/client/hooks.go index 534188513a..e7982acb8c 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -210,7 +210,10 @@ func parserRequestBody(c *Client, req *Request) (err error) { req.rawRequest.SetBody(req.formData.QueryString()) case filesBody: mw := multipart.NewWriter(req.rawRequest.BodyWriter()) - mw.SetBoundary(req.boundary) + err = mw.SetBoundary(req.boundary) + if err != nil { + return + } defer func() { err = mw.Close() if err != nil { @@ -271,11 +274,14 @@ func parserRequestBody(c *Client, req *Request) (err error) { break } - w.Write(b) + _, err = w.Write(b) + if err != nil { + return err + } } // ignore err - v.reader.Close() + _ = v.reader.Close() } case rawBody: if body, ok := req.body.([]byte); ok { From b005de79911415aff139c9a1b93039c420365605 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Sat, 13 Aug 2022 21:02:22 +0800 Subject: [PATCH 024/118] =?UTF-8?q?=F0=9F=9A=A7=20v3:=20add=20unit=20test?= =?UTF-8?q?=20and=20fix=20error?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/hooks.go | 24 +- client/request.go | 34 +- client/request_test.go | 711 ++++++++++++++++------------------------- client/response.go | 14 +- 4 files changed, 321 insertions(+), 462 deletions(-) diff --git a/client/hooks.go b/client/hooks.go index e7982acb8c..0231a8387e 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -192,7 +192,7 @@ func parserRequestHeader(c *Client, req *Request) error { // parserRequestBody automatically serializes the data according to // the data type and stores it in the body of the rawRequest -func parserRequestBody(c *Client, req *Request) (err error) { +func parserRequestBody(c *Client, req *Request) error { switch req.bodyType { case jsonBody: body, err := c.core.jsonMarshal(req.body) @@ -210,12 +210,12 @@ func parserRequestBody(c *Client, req *Request) (err error) { req.rawRequest.SetBody(req.formData.QueryString()) case filesBody: mw := multipart.NewWriter(req.rawRequest.BodyWriter()) - err = mw.SetBoundary(req.boundary) + err := mw.SetBoundary(req.boundary) if err != nil { - return + return err } defer func() { - err = mw.Close() + err := mw.Close() if err != nil { return } @@ -229,7 +229,7 @@ func parserRequestBody(c *Client, req *Request) (err error) { err = mw.WriteField(utils.UnsafeString(key), utils.UnsafeString(value)) }) if err != nil { - return + return err } // add file @@ -242,30 +242,30 @@ func parserRequestBody(c *Client, req *Request) (err error) { // if name is not exist, set name if v.name == "" && v.path != "" { v.path = filepath.Clean(v.path) - v.name = filepath.Base(v.name) + v.name = filepath.Base(v.path) } // if param is not exist, set it - if v.paramName == "" { - v.paramName = "file" + fmt.Sprint(i) + if v.fieldName == "" { + v.fieldName = "file" + fmt.Sprint(i+1) } // check the reader if v.reader == nil { v.reader, err = os.Open(v.path) if err != nil { - return + return err } } // wirte file - w, err := mw.CreateFormFile(v.paramName, v.name) + w, err := mw.CreateFormFile(v.fieldName, v.name) if err != nil { return err } for { - _, err := v.reader.Read(b) + n, err := v.reader.Read(b) if err != nil && err != io.EOF { return err } @@ -274,7 +274,7 @@ func parserRequestBody(c *Client, req *Request) (err error) { break } - _, err = w.Write(b) + _, err = w.Write(b[:n]) if err != nil { return err } diff --git a/client/request.go b/client/request.go index 5926606615..750b63efa8 100644 --- a/client/request.go +++ b/client/request.go @@ -216,6 +216,18 @@ func (r *Request) SetUserAgent(ua string) *Request { return r } +// Boundary returns bounday in multipart boundary. +func (r *Request) Boundary() string { + return r.boundary +} + +// SetBoundary method sets multipart boundary. +func (r *Request) SetBoundary(b string) *Request { + r.boundary = b + + return r +} + // Referer returns referer in request instance. func (r *Request) Referer() string { return r.referer @@ -494,10 +506,10 @@ func (r *Request) Reset() { r.body = nil r.bodyType = noBody - copiedFile := r.files - r.files = r.files[:0] - for _, v := range copiedFile { - ReleaseFile(v) + for len(r.files) != 0 { + t := r.files[0] + r.files = r.files[1:] + ReleaseFile(t) } r.formData.Reset() @@ -727,7 +739,7 @@ func (f *FormData) Reset() { // File is a struct which support send files via request. type File struct { name string - paramName string + fieldName string path string reader io.ReadCloser } @@ -737,9 +749,9 @@ func (f *File) SetName(n string) { f.name = n } -// SetParamName method sets key of file in the body. -func (f *File) SetParamName(n string) { - f.paramName = n +// SetFieldName method sets key of file in the body. +func (f *File) SetFieldName(n string) { + f.fieldName = n } // SetPath method set file path. @@ -756,7 +768,7 @@ func (f *File) SetReader(r io.ReadCloser) { // Reset clear the File object. func (f *File) Reset() { f.name = "" - f.paramName = "" + f.fieldName = "" f.path = "" f.reader = nil } @@ -807,9 +819,9 @@ func SetFileName(n string) SetFileFunc { } } -func SetFileParamName(p string) SetFileFunc { +func SetFileFieldName(p string) SetFileFunc { return func(f *File) { - f.SetParamName(p) + f.SetFieldName(p) } } diff --git a/client/request_test.go b/client/request_test.go index 5d47093439..5d3b0c4e59 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -1,9 +1,16 @@ package client import ( + "bytes" "context" - "fmt" + "errors" + "io" + "mime/multipart" "net" + "os" + "path/filepath" + "regexp" + "strings" "testing" "github.com/gofiber/fiber/v3" @@ -756,52 +763,34 @@ func testAgent(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Reques utils.AssertEqual(t, nil, err) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) utils.AssertEqual(t, excepted, resp.String()) + resp.Close() } } -func Test_Request_UserAgent_With_Server(t *testing.T) { +func testAgentFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted error, count ...int) { t.Parallel() app, client, start := createHelperServer(t) - app.Get("/", func(c fiber.Ctx) error { - return c.Send(c.Request().Header.UserAgent()) - }) - + app.Get("/", handler) go start() - t.Run("default", func(t *testing.T) { - for i := 0; i < 5; i++ { - resp, err := AcquireRequest(). - SetClient(client). - Get("http://example.com") - - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) - utils.AssertEqual(t, defaultUserAgent, resp.String()) - - resp.Close() - } - }) + c := 1 + if len(count) > 0 { + c = count[0] + } - t.Run("custom", func(t *testing.T) { - for i := 0; i < 5; i++ { - resp, err := AcquireRequest(). - SetClient(client). - SetUserAgent("ua"). - Get("http://example.com") + for i := 0; i < c; i++ { + req := AcquireRequest().SetClient(client) + wrapAgent(req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) - utils.AssertEqual(t, "ua", resp.String()) + _, err := req.Get("http://example.com") - resp.Close() - } - }) + utils.AssertEqual(t, excepted.Error(), err.Error()) + } } func Test_Request_Header_With_Server(t *testing.T) { handler := func(c fiber.Ctx) error { - fmt.Println(c.Request().Header.String()) c.Request().Header.VisitAll(func(key, value []byte) { if k := string(key); k == "K1" || k == "K2" { _, _ = c.Write(key) @@ -841,61 +830,53 @@ func Test_Request_Header_With_Server(t *testing.T) { // testAgent(t, handler, wrapAgent, "close") // } -// func Test_Client_Agent_UserAgent(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.Send(c.Request().Header.UserAgent()) -// } - -// wrapAgent := func(a *Agent) { -// a.UserAgent("ua"). -// UserAgentBytes([]byte("ua")) -// } - -// testAgent(t, handler, wrapAgent, "ua") -// } +func Test_Request_UserAgent_With_Server(t *testing.T) { + t.Parallel() -// func Test_Client_Agent_Cookie(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.SendString( -// c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) -// } + handler := func(c fiber.Ctx) error { + return c.Send(c.Request().Header.UserAgent()) + } -// wrapAgent := func(a *Agent) { -// a.Cookie("k1", "v1"). -// CookieBytesK([]byte("k2"), "v2"). -// CookieBytesKV([]byte("k2"), []byte("v2")). -// Cookies("k3", "v3", "k4", "v4"). -// CookiesBytesKV([]byte("k3"), []byte("v3"), []byte("k4"), []byte("v4")) -// } + t.Run("default", func(t *testing.T) { + testAgent(t, handler, func(agent *Request) {}, defaultUserAgent, 5) + }) -// testAgent(t, handler, wrapAgent, "v1v2v3v4") -// } + t.Run("custom", func(t *testing.T) { + testAgent(t, handler, func(agent *Request) { + agent.SetUserAgent("ua") + }, "ua", 5) + }) +} -// func Test_Client_Agent_Referer(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.Send(c.Request().Header.Referer()) -// } +func Test_Request_Cookie_With_Server(t *testing.T) { + handler := func(c fiber.Ctx) error { + return c.SendString( + c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) + } -// wrapAgent := func(a *Agent) { -// a.Referer("http://referer.com"). -// RefererBytes([]byte("http://referer.com")) -// } + wrapAgent := func(req *Request) { + req.SetCookie("k1", "v1"). + SetCookies(map[string]string{ + "k2": "v2", + "k3": "v3", + "k4": "v4", + }).DelCookies("k4") + } -// testAgent(t, handler, wrapAgent, "http://referer.com") -// } + testAgent(t, handler, wrapAgent, "v1v2v3") +} -// func Test_Client_Agent_ContentType(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.Send(c.Request().Header.ContentType()) -// } +func Test_Request_Referer_With_Server(t *testing.T) { + handler := func(c fiber.Ctx) error { + return c.Send(c.Request().Header.Referer()) + } -// wrapAgent := func(a *Agent) { -// a.ContentType("custom-type"). -// ContentTypeBytes([]byte("custom-type")) -// } + wrapAgent := func(req *Request) { + req.SetReferer("http://referer.com") + } -// testAgent(t, handler, wrapAgent, "custom-type") -// } + testAgent(t, handler, wrapAgent, "http://referer.com") +} // func Test_Client_Agent_Host(t *testing.T) { // t.Parallel() @@ -925,18 +906,20 @@ func Test_Request_Header_With_Server(t *testing.T) { // utils.AssertEqual(t, 0, len(errs)) // } -// func Test_Client_Agent_QueryString(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.Send(c.Request().URI().QueryString()) -// } +func Test_Request_QueryString_With_Server(t *testing.T) { + handler := func(c fiber.Ctx) error { + return c.Send(c.Request().URI().QueryString()) + } -// wrapAgent := func(a *Agent) { -// a.QueryString("foo=bar&bar=baz"). -// QueryStringBytes([]byte("foo=bar&bar=baz")) -// } + wrapAgent := func(req *Request) { + req.SetParam("foo", "bar"). + SetParams(map[string]string{ + "bar": "baz", + }) + } -// testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") -// } + testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") +} // func Test_Client_Agent_BasicAuth(t *testing.T) { // handler := func(c fiber.Ctx) error { @@ -957,121 +940,248 @@ func Test_Request_Header_With_Server(t *testing.T) { // testAgent(t, handler, wrapAgent, "foo:bar") // } -// func Test_Client_Agent_BodyString(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.Send(c.Request().Body()) -// } +func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) { + t.Helper() -// wrapAgent := func(a *Agent) { -// a.BodyString("foo=bar&bar=baz") -// } + basename := filepath.Base(filename) + utils.AssertEqual(t, fh.Filename, basename) -// testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") -// } + b1, err := os.ReadFile(filename) + utils.AssertEqual(t, nil, err) -// func Test_Client_Agent_Body(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.Send(c.Request().Body()) -// } + b2 := make([]byte, fh.Size) + f, err := fh.Open() + utils.AssertEqual(t, nil, err) + defer func() { _ = f.Close() }() + _, err = f.Read(b2) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, b1, b2) +} -// wrapAgent := func(a *Agent) { -// a.Body([]byte("foo=bar&bar=baz")) -// } +func Test_Request_Body_With_Server(t *testing.T) { + t.Parallel() -// testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") -// } + t.Run("json body", func(t *testing.T) { + testAgent(t, + func(c fiber.Ctx) error { + utils.AssertEqual(t, "application/json", string(c.Request().Header.ContentType())) + return c.SendString(string(c.Request().Body())) + }, + func(agent *Request) { + agent.SetJSON(map[string]string{ + "success": "hello", + }) + }, + "{\"success\":\"hello\"}", + ) + }) -// func Test_Client_Agent_BodyStream(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.Send(c.Request().Body()) -// } + t.Run("json error", func(t *testing.T) { + testAgentFail(t, + func(c fiber.Ctx) error { + return c.SendString("") + }, + func(agent *Request) { + agent.SetJSON(complex(1, 1)) + }, + errors.New("json: unsupported type: complex128"), + ) + }) -// wrapAgent := func(a *Agent) { -// a.BodyStream(strings.NewReader("body stream"), -1) -// } + t.Run("xml body", func(t *testing.T) { + testAgent(t, + func(c fiber.Ctx) error { + utils.AssertEqual(t, "application/xml", string(c.Request().Header.ContentType())) + return c.SendString(string(c.Request().Body())) + }, + func(agent *Request) { + type args struct { + Content string `xml:"content"` + } + agent.SetXML(args{ + Content: "hello", + }) + }, + "hello", + ) + }) -// testAgent(t, handler, wrapAgent, "body stream") -// } + t.Run("xml error", func(t *testing.T) { + testAgentFail(t, + func(c fiber.Ctx) error { + return c.SendString("") + }, + func(agent *Request) { + agent.SetXML(complex(1, 1)) + }, + errors.New("xml: unsupported type: complex128"), + ) + }) -// func Test_Client_Agent_Custom_Response(t *testing.T) { -// t.Parallel() + t.Run("formdata", func(t *testing.T) { + testAgent(t, + func(c fiber.Ctx) error { + utils.AssertEqual(t, fiber.MIMEApplicationForm, string(c.Request().Header.ContentType())) + return c.Send(c.Request().Body()) + }, + func(agent *Request) { + agent.SetFormData("foo", "bar"). + SetFormDatas(map[string]string{ + "bar": "baz", + "fiber": "fast", + }) + }, + "foo=bar&bar=baz&fiber=fast") + }) -// ln := fasthttputil.NewInmemoryListener() + t.Run("multipart form", func(t *testing.T) { + t.Parallel() -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + app, client, start := createHelperServer(t) + app.Post("/", func(c fiber.Ctx) error { + utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) -// app.Get("/", func(c fiber.Ctx) error { -// return c.SendString("custom") -// }) + mf, err := c.MultipartForm() + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "bar", mf.Value["foo"][0]) -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + return c.Send(c.Request().Body()) + }) -// for i := 0; i < 5; i++ { -// a := AcquireAgent() -// resp := AcquireResponse() + go start() -// req := a.Request() -// req.Header.SetMethod(fiber.MethodGet) -// req.SetRequestURI("http://example.com") + req := AcquireRequest(). + SetClient(client). + SetBoundary("myBoundary"). + SetFormData("foo", "bar"). + AddFiles(AcquireFile( + SetFileName("hello.txt"), + SetFileFieldName("foo"), + SetFileReader(io.NopCloser(strings.NewReader("world"))), + )) -// utils.AssertEqual(t, nil, a.Parse()) + resp, err := req.Post("http://exmaple.com") + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + form, err := multipart.NewReader(bytes.NewReader(resp.Body()), "myBoundary").ReadForm(1024 * 1024) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "bar", form.Value["foo"][0]) + resp.Close() + }) -// code, body, errs := a.SetResponse(resp). -// String() + t.Run("multipart form send file", func(t *testing.T) { + t.Parallel() -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "custom", body) -// utils.AssertEqual(t, "custom", string(resp.Body())) -// utils.AssertEqual(t, 0, len(errs)) + app, client, start := createHelperServer(t) + app.Post("/", func(c fiber.Ctx) error { + utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) -// ReleaseResponse(resp) -// } -// } + fh1, err := c.FormFile("field1") + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fh1.Filename, "name") + buf := make([]byte, fh1.Size) + f, err := fh1.Open() + utils.AssertEqual(t, nil, err) + defer func() { _ = f.Close() }() + _, err = f.Read(buf) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "form file", string(buf)) -// func Test_Client_Agent_Dest(t *testing.T) { -// t.Parallel() + fh2, err := c.FormFile("file2") + utils.AssertEqual(t, nil, err) + checkFormFile(t, fh2, "../.github/testdata/index.html") -// ln := fasthttputil.NewInmemoryListener() + fh3, err := c.FormFile("file3") + utils.AssertEqual(t, nil, err) + checkFormFile(t, fh3, "../.github/testdata/index.tmpl") -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + return c.SendString("multipart form files") + }) -// app.Get("/", func(c fiber.Ctx) error { -// return c.SendString("dest") -// }) + go start() -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + for i := 0; i < 5; i++ { + req := AcquireRequest(). + SetClient(client). + AddFiles( + AcquireFile( + SetFileFieldName("field1"), + SetFileName("name"), + SetFileReader(io.NopCloser(bytes.NewReader([]byte("form file")))), + ), + ). + AddFile("../.github/testdata/index.html"). + AddFile("../.github/testdata/index.tmpl"). + SetBoundary("myBoundary") + + resp, err := req.Post("http://example.com") + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "multipart form files", resp.String()) -// t.Run("small dest", func(t *testing.T) { -// dest := []byte("de") + resp.Close() + } + }) -// a := Get("http://example.com") + t.Run("multipart random boundary", func(t *testing.T) { + t.Parallel() -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + app, client, start := createHelperServer(t) + app.Post("/", func(c fiber.Ctx) error { + reg := regexp.MustCompile(`multipart/form-data; boundary=[\-\w]{35}`) + utils.AssertEqual(t, true, reg.MatchString(c.Get(fiber.HeaderContentType))) -// code, body, errs := a.Dest(dest[:0]).String() + return c.Send(c.Request().Body()) + }) -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "dest", body) -// utils.AssertEqual(t, "de", string(dest)) -// utils.AssertEqual(t, 0, len(errs)) -// }) + go start() -// t.Run("enough dest", func(t *testing.T) { -// dest := []byte("foobar") + req := AcquireRequest(). + SetClient(client). + SetFormData("foo", "bar"). + AddFiles(AcquireFile( + SetFileName("hello.txt"), + SetFileFieldName("foo"), + SetFileReader(io.NopCloser(strings.NewReader("world"))), + )) -// a := Get("http://example.com") + resp, err := req.Post("http://exmaple.com") + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) + }) -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + t.Run("raw body", func(t *testing.T) { + testAgent(t, + func(c fiber.Ctx) error { + return c.SendString(string(c.Request().Body())) + }, + func(agent *Request) { + agent.SetRawBody([]byte("hello")) + }, + "hello", + ) + }) +} + +// func Test_Client_Agent_Multipart_Invalid_Boundary(t *testing.T) { +// t.Parallel() -// code, body, errs := a.Dest(dest[:0]).String() +// a := Post("http://example.com"). +// Boundary("*"). +// MultipartForm(nil) -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "dest", body) -// utils.AssertEqual(t, "destar", string(dest)) -// utils.AssertEqual(t, 0, len(errs)) -// }) +// utils.AssertEqual(t, 1, len(a.errs)) +// utils.AssertEqual(t, "mime: invalid boundary character", a.errs[0].Error()) +// } + +// func Test_Client_Agent_SendFile_Error(t *testing.T) { +// t.Parallel() + +// a := Post("http://example.com"). +// SendFile("non-exist-file!", "") + +// utils.AssertEqual(t, 1, len(a.errs)) +// utils.AssertEqual(t, true, strings.Contains(a.errs[0].Error(), "open non-exist-file!")) // } // // readErrorConn is a struct for testing retryIf @@ -1134,240 +1244,6 @@ func Test_Request_Header_With_Server(t *testing.T) { // utils.AssertEqual(t, 0, len(errs)) // } -// func Test_Client_Agent_Json(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// utils.AssertEqual(t, fiber.MIMEApplicationJSON, string(c.Request().Header.ContentType())) - -// return c.Send(c.Request().Body()) -// } - -// wrapAgent := func(a *Agent) { -// a.JSON(data{Success: true}) -// } - -// testAgent(t, handler, wrapAgent, `{"success":true}`) -// } - -// func Test_Client_Agent_Json_Error(t *testing.T) { -// a := Get("http://example.com"). -// JSONEncoder(json.Marshal). -// JSON(complex(1, 1)) - -// _, body, errs := a.String() - -// utils.AssertEqual(t, "", body) -// utils.AssertEqual(t, 1, len(errs)) -// utils.AssertEqual(t, "json: unsupported type: complex128", errs[0].Error()) -// } - -// func Test_Client_Agent_XML(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// utils.AssertEqual(t, fiber.MIMEApplicationXML, string(c.Request().Header.ContentType())) - -// return c.Send(c.Request().Body()) -// } - -// wrapAgent := func(a *Agent) { -// a.XML(data{Success: true}) -// } - -// testAgent(t, handler, wrapAgent, "true") -// } - -// func Test_Client_Agent_XML_Error(t *testing.T) { -// a := Get("http://example.com"). -// XML(complex(1, 1)) - -// _, body, errs := a.String() - -// utils.AssertEqual(t, "", body) -// utils.AssertEqual(t, 1, len(errs)) -// utils.AssertEqual(t, "xml: unsupported type: complex128", errs[0].Error()) -// } - -// func Test_Client_Agent_Form(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// utils.AssertEqual(t, fiber.MIMEApplicationForm, string(c.Request().Header.ContentType())) - -// return c.Send(c.Request().Body()) -// } - -// args := AcquireArgs() - -// args.Set("foo", "bar") - -// wrapAgent := func(a *Agent) { -// a.Form(args) -// } - -// testAgent(t, handler, wrapAgent, "foo=bar") - -// ReleaseArgs(args) -// } - -// func Test_Client_Agent_MultipartForm(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Post("/", func(c fiber.Ctx) error { -// utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) - -// mf, err := c.MultipartForm() -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, "bar", mf.Value["foo"][0]) - -// return c.Send(c.Request().Body()) -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// args := AcquireArgs() - -// args.Set("foo", "bar") - -// a := Post("http://example.com"). -// Boundary("myBoundary"). -// MultipartForm(args) - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.String() - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "--myBoundary\r\nContent-Disposition: form-data; name=\"foo\"\r\n\r\nbar\r\n--myBoundary--\r\n", body) -// utils.AssertEqual(t, 0, len(errs)) -// ReleaseArgs(args) -// } - -// func Test_Client_Agent_MultipartForm_Errors(t *testing.T) { -// t.Parallel() - -// a := AcquireAgent() -// a.mw = &errorMultipartWriter{} - -// args := AcquireArgs() -// args.Set("foo", "bar") - -// ff1 := &FormFile{"", "name1", []byte("content"), false} -// ff2 := &FormFile{"", "name2", []byte("content"), false} -// a.FileData(ff1, ff2). -// MultipartForm(args) - -// utils.AssertEqual(t, 4, len(a.errs)) -// ReleaseArgs(args) -// } - -// func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Post("/", func(c fiber.Ctx) error { -// utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) - -// fh1, err := c.FormFile("field1") -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, fh1.Filename, "name") -// buf := make([]byte, fh1.Size) -// f, err := fh1.Open() -// utils.AssertEqual(t, nil, err) -// defer func() { _ = f.Close() }() -// _, err = f.Read(buf) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, "form file", string(buf)) - -// fh2, err := c.FormFile("index") -// utils.AssertEqual(t, nil, err) -// checkFormFile(t, fh2, ".github/testdata/index.html") - -// fh3, err := c.FormFile("file3") -// utils.AssertEqual(t, nil, err) -// checkFormFile(t, fh3, ".github/testdata/index.tmpl") - -// return c.SendString("multipart form files") -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// for i := 0; i < 5; i++ { -// ff := AcquireFormFile() -// ff.Fieldname = "field1" -// ff.Name = "name" -// ff.Content = []byte("form file") - -// a := Post("http://example.com"). -// Boundary("myBoundary"). -// FileData(ff). -// SendFiles(".github/testdata/index.html", "index", ".github/testdata/index.tmpl"). -// MultipartForm(nil) - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.String() - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "multipart form files", body) -// utils.AssertEqual(t, 0, len(errs)) - -// ReleaseFormFile(ff) -// } -// } - -// func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) { -// t.Helper() - -// basename := filepath.Base(filename) -// utils.AssertEqual(t, fh.Filename, basename) - -// b1, err := os.ReadFile(filename) -// utils.AssertEqual(t, nil, err) - -// b2 := make([]byte, fh.Size) -// f, err := fh.Open() -// utils.AssertEqual(t, nil, err) -// defer func() { _ = f.Close() }() -// _, err = f.Read(b2) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, b1, b2) -// } - -// func Test_Client_Agent_Multipart_Random_Boundary(t *testing.T) { -// t.Parallel() - -// a := Post("http://example.com"). -// MultipartForm(nil) - -// reg := regexp.MustCompile(`multipart/form-data; boundary=\w{30}`) - -// utils.AssertEqual(t, true, reg.Match(a.req.Header.Peek(fiber.HeaderContentType))) -// } - -// func Test_Client_Agent_Multipart_Invalid_Boundary(t *testing.T) { -// t.Parallel() - -// a := Post("http://example.com"). -// Boundary("*"). -// MultipartForm(nil) - -// utils.AssertEqual(t, 1, len(a.errs)) -// utils.AssertEqual(t, "mime: invalid boundary character", a.errs[0].Error()) -// } - -// func Test_Client_Agent_SendFile_Error(t *testing.T) { -// t.Parallel() - -// a := Post("http://example.com"). -// SendFile("non-exist-file!", "") - -// utils.AssertEqual(t, 1, len(a.errs)) -// utils.AssertEqual(t, true, strings.Contains(a.errs[0].Error(), "open non-exist-file!")) -// } - // func Test_Client_Debug(t *testing.T) { // handler := func(c fiber.Ctx) error { // return c.SendString("debug") @@ -1634,37 +1510,6 @@ func Test_Request_Header_With_Server(t *testing.T) { // utils.AssertEqual(t, "example.com:443", addr) // } -// func testAgent(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Agent), excepted string, count ...int) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", handler) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// c := 1 -// if len(count) > 0 { -// c = count[0] -// } - -// for i := 0; i < c; i++ { -// a := Get("http://example.com") - -// wrapAgent(a) - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.String() - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, excepted, body) -// utils.AssertEqual(t, 0, len(errs)) -// } -// } - // type data struct { // Success bool `json:"success" xml:"success"` // } diff --git a/client/response.go b/client/response.go index 1bbacbc910..56aa0c2216 100644 --- a/client/response.go +++ b/client/response.go @@ -4,6 +4,7 @@ import ( "strings" "sync" + "github.com/gofiber/fiber/v3/utils" "github.com/valyala/fasthttp" ) @@ -42,8 +43,8 @@ func (r *Response) Protocol() string { } // Header method returns the response headers. -func (r *Response) Header() fasthttp.ResponseHeader { - return r.rawResponse.Header +func (r *Response) Header(key string) string { + return utils.UnsafeString(r.rawResponse.Header.Peek(key)) } // Cookies method to access all the response cookies. @@ -75,10 +76,11 @@ func (r *Response) XML(v any) error { func (r *Response) Reset() { r.client = nil r.request = nil - copied := r.cookie - r.cookie = []*fasthttp.Cookie{} - for _, v := range copied { - fasthttp.ReleaseCookie(v) + + for len(r.cookie) != 0 { + t := r.cookie[0] + r.cookie = r.cookie[1:] + fasthttp.ReleaseCookie(t) } r.rawResponse.Reset() From 84f3145ab698a5e8ac7b017424b34cfeb56ba6f6 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Thu, 18 Aug 2022 22:07:45 +0800 Subject: [PATCH 025/118] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20chore:=20change=20?= =?UTF-8?q?func=20to=20improve=20performance?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/hooks.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/client/hooks.go b/client/hooks.go index 0231a8387e..6be46d2361 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -2,7 +2,6 @@ package client import ( "bytes" - "fmt" "io" "math/rand" "mime/multipart" @@ -245,9 +244,9 @@ func parserRequestBody(c *Client, req *Request) error { v.name = filepath.Base(v.path) } - // if param is not exist, set it + // if field name is not exist, set it if v.fieldName == "" { - v.fieldName = "file" + fmt.Sprint(i+1) + v.fieldName = "file" + strconv.Itoa(i+1) } // check the reader From 2b785e974420f4ffa36677fdd7ea01cb7f962b87 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Thu, 18 Aug 2022 22:08:09 +0800 Subject: [PATCH 026/118] =?UTF-8?q?=E2=9C=85=20v3:=20add=20some=20unit=20t?= =?UTF-8?q?est?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/request.go | 33 ++++++++ client/request_test.go | 188 ++++++++++++++++++++--------------------- 2 files changed, 125 insertions(+), 96 deletions(-) diff --git a/client/request.go b/client/request.go index 750b63efa8..fb1baf4580 100644 --- a/client/request.go +++ b/client/request.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "path/filepath" "reflect" "sort" "strconv" @@ -402,6 +403,33 @@ func (r *Request) DelFormDatas(key ...string) *Request { return r } +// File returns file ptr store in request obj by name. +// If name field is empty, it will try to match path. +func (r *Request) File(name string) *File { + for _, v := range r.files { + if v.name == "" { + if filepath.Base(v.path) == name { + return v + } + } else if v.name == name { + return v + } + } + + return nil +} + +// File returns file ptr store in request obj by path. +func (r *Request) FileByPath(path string) *File { + for _, v := range r.files { + if v.path == path { + return v + } + } + + return nil +} + // AddFile method adds single file field // and its value in the request instance via file path. func (r *Request) AddFile(path string) *Request { @@ -426,6 +454,11 @@ func (r *Request) AddFiles(files ...*File) *Request { return r } +// Timeout returns the length of timeout in request. +func (r *Request) Timeout() time.Duration { + return r.timeout +} + // SetTimeout method sets timeout field and its values at one go in the request instance. // It will override timeout which set in client instance. func (r *Request) SetTimeout(t time.Duration) *Request { diff --git a/client/request_test.go b/client/request_test.go index 5d3b0c4e59..4a26507b4d 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -5,6 +5,7 @@ import ( "context" "errors" "io" + "io/ioutil" "mime/multipart" "net" "os" @@ -12,6 +13,7 @@ import ( "regexp" "strings" "testing" + "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/utils" @@ -514,10 +516,43 @@ func Test_Request_File(t *testing.T) { t.Parallel() t.Run("add file", func(t *testing.T) { + req := AcquireRequest(). + AddFile("../.github/index.html"). + AddFiles(AcquireFile(SetFileName("tmp.txt"))) + + utils.AssertEqual(t, "../.github/index.html", req.File("index.html").path) + utils.AssertEqual(t, "../.github/index.html", req.FileByPath("../.github/index.html").path) + utils.AssertEqual(t, "tmp.txt", req.File("tmp.txt").name) + }) + + t.Run("add file by reader", func(t *testing.T) { + req := AcquireRequest(). + AddFileWithReader("tmp.txt", io.NopCloser(strings.NewReader("world"))) + + utils.AssertEqual(t, "tmp.txt", req.File("tmp.txt").name) + + content, err := ioutil.ReadAll(req.File("tmp.txt").reader) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "world", string(content)) + }) + + t.Run("add files", func(t *testing.T) { + req := AcquireRequest(). + AddFiles(AcquireFile(SetFileName("tmp.txt")), AcquireFile(SetFileName("foo.txt"))) + utils.AssertEqual(t, "tmp.txt", req.File("tmp.txt").name) + utils.AssertEqual(t, "foo.txt", req.File("foo.txt").name) }) } +func Test_Request_Timeout(t *testing.T) { + t.Parallel() + + req := AcquireRequest().SetTimeout(5 * time.Second) + + utils.AssertEqual(t, 5*time.Second, req.Timeout()) +} + func createHelperServer(t *testing.T) (*fiber.App, *Client, func()) { t.Helper() @@ -976,18 +1011,6 @@ func Test_Request_Body_With_Server(t *testing.T) { ) }) - t.Run("json error", func(t *testing.T) { - testAgentFail(t, - func(c fiber.Ctx) error { - return c.SendString("") - }, - func(agent *Request) { - agent.SetJSON(complex(1, 1)) - }, - errors.New("json: unsupported type: complex128"), - ) - }) - t.Run("xml body", func(t *testing.T) { testAgent(t, func(c fiber.Ctx) error { @@ -1006,18 +1029,6 @@ func Test_Request_Body_With_Server(t *testing.T) { ) }) - t.Run("xml error", func(t *testing.T) { - testAgentFail(t, - func(c fiber.Ctx) error { - return c.SendString("") - }, - func(agent *Request) { - agent.SetXML(complex(1, 1)) - }, - errors.New("xml: unsupported type: complex128"), - ) - }) - t.Run("formdata", func(t *testing.T) { testAgent(t, func(c fiber.Ctx) error { @@ -1163,26 +1174,68 @@ func Test_Request_Body_With_Server(t *testing.T) { }) } -// func Test_Client_Agent_Multipart_Invalid_Boundary(t *testing.T) { -// t.Parallel() +func Test_Request_Error_Body_With_Server(t *testing.T) { + t.Run("json error", func(t *testing.T) { + testAgentFail(t, + func(c fiber.Ctx) error { + return c.SendString("") + }, + func(agent *Request) { + agent.SetJSON(complex(1, 1)) + }, + errors.New("json: unsupported type: complex128"), + ) + }) -// a := Post("http://example.com"). -// Boundary("*"). -// MultipartForm(nil) + t.Run("xml error", func(t *testing.T) { + testAgentFail(t, + func(c fiber.Ctx) error { + return c.SendString("") + }, + func(agent *Request) { + agent.SetXML(complex(1, 1)) + }, + errors.New("xml: unsupported type: complex128"), + ) + }) -// utils.AssertEqual(t, 1, len(a.errs)) -// utils.AssertEqual(t, "mime: invalid boundary character", a.errs[0].Error()) -// } + t.Run("form body with invalid boundary", func(t *testing.T) { + t.Parallel() -// func Test_Client_Agent_SendFile_Error(t *testing.T) { -// t.Parallel() + _, err := AcquireRequest(). + SetBoundary("*"). + AddFileWithReader("t.txt", io.NopCloser(strings.NewReader("world"))). + Get("http://example.com") + utils.AssertEqual(t, "mime: invalid boundary character", err.Error()) + }) -// a := Post("http://example.com"). -// SendFile("non-exist-file!", "") + t.Run("open non exist file", func(t *testing.T) { + t.Parallel() -// utils.AssertEqual(t, 1, len(a.errs)) -// utils.AssertEqual(t, true, strings.Contains(a.errs[0].Error(), "open non-exist-file!")) -// } + _, err := AcquireRequest(). + AddFile("non-exist-file!"). + Get("http://example.com") + utils.AssertEqual(t, "open non-exist-file!: The system cannot find the file specified.", err.Error()) + }) +} + +func Test_Request_Timeout_With_Server(t *testing.T) { + t.Parallel() + + app, client, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + time.Sleep(time.Millisecond * 200) + return c.SendString("timeout") + }) + go start() + + _, err := AcquireRequest(). + SetClient(client). + SetTimeout(50 * time.Millisecond). + Get("http://example.com") + + utils.AssertEqual(t, ErrTimeoutOrCancel, err) +} // // readErrorConn is a struct for testing retryIf // type readErrorConn struct { @@ -1267,63 +1320,6 @@ func Test_Request_Body_With_Server(t *testing.T) { // utils.AssertEqual(t, true, strings.Contains(str, "Content-Type: text/plain; charset=utf-8\r\nContent-Length: 5\r\n\r\ndebug")) // } -// func Test_Client_Agent_Timeout(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// time.Sleep(time.Millisecond * 200) -// return c.SendString("timeout") -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// a := Get("http://example.com"). -// Timeout(time.Millisecond * 50) - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// _, body, errs := a.String() - -// utils.AssertEqual(t, "", body) -// utils.AssertEqual(t, 1, len(errs)) -// utils.AssertEqual(t, "timeout", errs[0].Error()) -// } - -// func Test_Client_Agent_Reuse(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// return c.SendString("reuse") -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// a := Get("http://example.com"). -// Reuse() - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.String() - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "reuse", body) -// utils.AssertEqual(t, 0, len(errs)) - -// code, body, errs = a.String() - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "reuse", body) -// utils.AssertEqual(t, 0, len(errs)) -// } - // func Test_Client_Agent_InsecureSkipVerify(t *testing.T) { // t.Parallel() From c3f40b7d43caa7701180add8f60f30f157d6b2d6 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Fri, 19 Aug 2022 15:44:06 +0800 Subject: [PATCH 027/118] =?UTF-8?q?=E2=9C=85=20v3:=20fix=20error=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/core_test.go | 17 ++++++---- client/helper_test.go | 72 ++++++++++++++++++++++++++++++++++++++++++ client/request_test.go | 66 +------------------------------------- 3 files changed, 83 insertions(+), 72 deletions(-) create mode 100644 client/helper_test.go diff --git a/client/core_test.go b/client/core_test.go index 194cf60ed9..700969e579 100644 --- a/client/core_test.go +++ b/client/core_test.go @@ -14,8 +14,6 @@ import ( ) func Test_Exec_Func(t *testing.T) { - t.Parallel() - ln := fasthttputil.NewInmemoryListener() app := fiber.New(fiber.Config{DisableStartupMessage: true}) @@ -37,7 +35,8 @@ func Test_Exec_Func(t *testing.T) { }() t.Run("normal request", func(t *testing.T) { - core, client, req := newCore(), AcquireClient(), AcquireRequest() + client, req := AcquireClient(), AcquireRequest() + core := client.core core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } req.rawRequest.SetRequestURI("http://example.com/normal") @@ -48,7 +47,8 @@ func Test_Exec_Func(t *testing.T) { }) t.Run("the request return an error", func(t *testing.T) { - core, client, req := newCore(), AcquireClient(), AcquireRequest() + client, req := AcquireClient(), AcquireRequest() + core := client.core core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } req.rawRequest.SetRequestURI("http://example.com/return-error") @@ -60,7 +60,8 @@ func Test_Exec_Func(t *testing.T) { }) t.Run("there is no connect", func(t *testing.T) { - core, client := newCore(), AcquireClient() + client := AcquireClient() + core := client.core core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } core.client.SetMaxConns(1) @@ -77,11 +78,13 @@ func Test_Exec_Func(t *testing.T) { }) t.Run("the request timeout", func(t *testing.T) { - core, client, req := newCore(), AcquireClient(), AcquireRequest() + client, req := AcquireClient(), AcquireRequest() + core := client.core + core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } req.rawRequest.SetRequestURI("http://example.com/hang-up") - ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() _, err := core.execFunc(ctx, client, req) diff --git a/client/helper_test.go b/client/helper_test.go new file mode 100644 index 0000000000..3b0f82b0c4 --- /dev/null +++ b/client/helper_test.go @@ -0,0 +1,72 @@ +package client + +import ( + "net" + "testing" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/utils" + "github.com/valyala/fasthttp/fasthttputil" +) + +func createHelperServer(t *testing.T) (*fiber.App, *Client, func()) { + t.Helper() + + ln := fasthttputil.NewInmemoryListener() + app := fiber.New(fiber.Config{DisableStartupMessage: true}) + + client := AcquireClient().SetDial(func(addr string) (net.Conn, error) { + return ln.Dial() + }) + + return app, client, func() { + utils.AssertEqual(t, nil, app.Listener(ln)) + } +} + +func testAgent(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted string, count ...int) { + t.Parallel() + + app, client, start := createHelperServer(t) + app.Get("/", handler) + go start() + + c := 1 + if len(count) > 0 { + c = count[0] + } + + for i := 0; i < c; i++ { + req := AcquireRequest().SetClient(client) + wrapAgent(req) + + resp, err := req.Get("http://example.com") + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) + utils.AssertEqual(t, excepted, resp.String()) + resp.Close() + } +} + +func testAgentFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted error, count ...int) { + t.Parallel() + + app, client, start := createHelperServer(t) + app.Get("/", handler) + go start() + + c := 1 + if len(count) > 0 { + c = count[0] + } + + for i := 0; i < c; i++ { + req := AcquireRequest().SetClient(client) + wrapAgent(req) + + _, err := req.Get("http://example.com") + + utils.AssertEqual(t, excepted.Error(), err.Error()) + } +} diff --git a/client/request_test.go b/client/request_test.go index 4a26507b4d..6a7e196ebf 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -7,7 +7,6 @@ import ( "io" "io/ioutil" "mime/multipart" - "net" "os" "path/filepath" "regexp" @@ -18,7 +17,6 @@ import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/utils" "github.com/valyala/fasthttp" - "github.com/valyala/fasthttp/fasthttputil" ) func Test_Request_Method(t *testing.T) { @@ -553,21 +551,6 @@ func Test_Request_Timeout(t *testing.T) { utils.AssertEqual(t, 5*time.Second, req.Timeout()) } -func createHelperServer(t *testing.T) (*fiber.App, *Client, func()) { - t.Helper() - - ln := fasthttputil.NewInmemoryListener() - app := fiber.New(fiber.Config{DisableStartupMessage: true}) - - client := AcquireClient().SetDial(func(addr string) (net.Conn, error) { - return ln.Dial() - }) - - return app, client, func() { - utils.AssertEqual(t, nil, app.Listener(ln)) - } -} - func Test_Request_Invalid_URL(t *testing.T) { t.Parallel() @@ -777,53 +760,6 @@ func Test_Request_Patch(t *testing.T) { } } -func testAgent(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted string, count ...int) { - t.Parallel() - - app, client, start := createHelperServer(t) - app.Get("/", handler) - go start() - - c := 1 - if len(count) > 0 { - c = count[0] - } - - for i := 0; i < c; i++ { - req := AcquireRequest().SetClient(client) - wrapAgent(req) - - resp, err := req.Get("http://example.com") - - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) - utils.AssertEqual(t, excepted, resp.String()) - resp.Close() - } -} - -func testAgentFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted error, count ...int) { - t.Parallel() - - app, client, start := createHelperServer(t) - app.Get("/", handler) - go start() - - c := 1 - if len(count) > 0 { - c = count[0] - } - - for i := 0; i < c; i++ { - req := AcquireRequest().SetClient(client) - wrapAgent(req) - - _, err := req.Get("http://example.com") - - utils.AssertEqual(t, excepted.Error(), err.Error()) - } -} - func Test_Request_Header_With_Server(t *testing.T) { handler := func(c fiber.Ctx) error { c.Request().Header.VisitAll(func(key, value []byte) { @@ -1033,7 +969,7 @@ func Test_Request_Body_With_Server(t *testing.T) { testAgent(t, func(c fiber.Ctx) error { utils.AssertEqual(t, fiber.MIMEApplicationForm, string(c.Request().Header.ContentType())) - return c.Send(c.Request().Body()) + return c.Send([]byte("foo=" + c.FormValue("foo") + "&bar=" + c.FormValue("bar") + "&fiber=" + c.FormValue("fiber"))) }, func(agent *Request) { agent.SetFormData("foo", "bar"). From 7706d5a88fd06948553709ff20cb2a596fc08784 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Fri, 19 Aug 2022 16:02:19 +0800 Subject: [PATCH 028/118] =?UTF-8?q?=F0=9F=90=9B=20fix:=20add=20cookie=20to?= =?UTF-8?q?=20response?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/hooks.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/client/hooks.go b/client/hooks.go index 6be46d2361..f768ef4371 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -295,11 +295,10 @@ func parserRequestBody(c *Client, req *Request) error { func parserResponseCookie(c *Client, resp *Response, req *Request) (err error) { resp.rawResponse.Header.VisitAllCookie(func(key, value []byte) { cookie := fasthttp.AcquireCookie() - err = cookie.ParseBytes(value) - if err != nil { - return - } + _ = cookie.ParseBytes(value) cookie.SetKeyBytes(key) + + resp.cookie = append(resp.cookie, cookie) }) return From 03ce5f798cbeabc3f49da04a2b0920275dc21b9b Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Fri, 19 Aug 2022 16:02:38 +0800 Subject: [PATCH 029/118] =?UTF-8?q?=E2=9C=85=20v3:=20add=20unit=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/response_test.go | 228 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 228 insertions(+) create mode 100644 client/response_test.go diff --git a/client/response_test.go b/client/response_test.go new file mode 100644 index 0000000000..9245044fd1 --- /dev/null +++ b/client/response_test.go @@ -0,0 +1,228 @@ +package client + +import ( + "encoding/xml" + "fmt" + "testing" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/utils" +) + +func Test_Response_Status(t *testing.T) { + t.Parallel() + + app, client, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("foo") + }) + app.Get("/fail", func(c fiber.Ctx) error { + return c.SendStatus(407) + }) + go start() + + t.Run("success", func(t *testing.T) { + t.Parallel() + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example") + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "OK", resp.Status()) + resp.Close() + }) + + t.Run("fail", func(t *testing.T) { + t.Parallel() + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example/fail") + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "Proxy Authentication Required", resp.Status()) + resp.Close() + }) +} + +func Test_Response_Status_Code(t *testing.T) { + t.Parallel() + + app, client, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("foo") + }) + app.Get("/fail", func(c fiber.Ctx) error { + return c.SendStatus(407) + }) + go start() + + t.Run("success", func(t *testing.T) { + t.Parallel() + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example") + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 200, resp.StatusCode()) + resp.Close() + }) + + t.Run("fail", func(t *testing.T) { + t.Parallel() + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example/fail") + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 407, resp.StatusCode()) + resp.Close() + }) +} + +func Test_Response_Protocol(t *testing.T) { + t.Parallel() + + t.Run("http", func(t *testing.T) { + app, client, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("foo") + }) + go start() + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example") + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "HTTP/1.1", resp.Protocol()) + resp.Close() + }) + + // TODO: add https test after support https + t.Run("https", func(t *testing.T) { + t.Parallel() + }) +} + +func Test_Response_Header(t *testing.T) { + t.Parallel() + + app, client, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + c.Response().Header.Add("foo", "bar") + return c.SendString("helo world") + }) + go start() + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com") + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "bar", resp.Header("foo")) + resp.Close() +} + +func Test_Response_Cookie(t *testing.T) { + t.Parallel() + + app, client, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + c.Cookie(&fiber.Cookie{ + Name: "foo", + Value: "bar", + }) + return c.SendString("helo world") + }) + go start() + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com") + + fmt.Println(resp.rawResponse.String()) + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "bar", string(resp.Cookies()[0].Value())) + resp.Close() +} + +func Test_Response_Body(t *testing.T) { + t.Parallel() + + app, client, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("hello world") + }) + app.Get("/json", func(c fiber.Ctx) error { + return c.SendString("{\"status\":\"success\"}") + }) + app.Get("/xml", func(c fiber.Ctx) error { + return c.SendString("success") + }) + + go start() + + t.Run("raw body", func(t *testing.T) { + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com") + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, []byte("hello world"), resp.Body()) + resp.Close() + }) + + t.Run("string body", func(t *testing.T) { + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com") + + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "hello world", resp.String()) + resp.Close() + }) + + t.Run("json body", func(t *testing.T) { + type body struct { + Status string `json:"status"` + } + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com/json") + + utils.AssertEqual(t, nil, err) + + tmp := &body{} + err = resp.JSON(tmp) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "success", tmp.Status) + resp.Close() + }) + + t.Run("xml body", func(t *testing.T) { + type body struct { + Name xml.Name `xml:"status"` + Status string `xml:"name"` + } + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com/xml") + + fmt.Println(resp.rawResponse.String()) + + utils.AssertEqual(t, nil, err) + + tmp := &body{} + err = resp.XML(tmp) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "success", tmp.Status) + resp.Close() + }) +} From 1abea22e523f01a9d65cffe7f7093bacd23e02a7 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Fri, 19 Aug 2022 16:05:21 +0800 Subject: [PATCH 030/118] =?UTF-8?q?=E2=9C=A8=20v3:=20export=20raw=20field?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client_test.go | 2 +- client/core.go | 4 +- client/core_test.go | 26 ++++++------- client/hooks.go | 52 +++++++++++++------------- client/hooks_test.go | 82 ++++++++++++++++++++--------------------- client/request.go | 6 +-- client/response.go | 23 ++++++------ client/response_test.go | 5 --- 8 files changed, 98 insertions(+), 102 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index a29b1a5c5b..ce051f2c74 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -69,7 +69,7 @@ func Test_Get(t *testing.T) { resp, err := Get("http://example.com") utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "example.com", utils.UnsafeString(resp.rawResponse.Body())) + utils.AssertEqual(t, "example.com", utils.UnsafeString(resp.RawResponse.Body())) }) } diff --git a/client/core.go b/client/core.go index 5871e0a901..83d2e97314 100644 --- a/client/core.go +++ b/client/core.go @@ -59,14 +59,14 @@ func (c *core) execFunc(ctx context.Context, client *Client, req *Request) (*Res fasthttp.ReleaseResponse(respv) }() - req.rawRequest.CopyTo(reqv) + req.RawRequest.CopyTo(reqv) go func() { err := c.client.Do(reqv, respv) if err != nil { errCh <- err return } - respv.CopyTo(resp.rawResponse) + respv.CopyTo(resp.RawResponse) errCh <- nil }() diff --git a/client/core_test.go b/client/core_test.go index 700969e579..106b46f629 100644 --- a/client/core_test.go +++ b/client/core_test.go @@ -38,25 +38,25 @@ func Test_Exec_Func(t *testing.T) { client, req := AcquireClient(), AcquireRequest() core := client.core core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - req.rawRequest.SetRequestURI("http://example.com/normal") + req.RawRequest.SetRequestURI("http://example.com/normal") resp, err := core.execFunc(context.Background(), client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, 200, resp.rawResponse.StatusCode()) - utils.AssertEqual(t, "example.com", string(resp.rawResponse.Body())) + utils.AssertEqual(t, 200, resp.RawResponse.StatusCode()) + utils.AssertEqual(t, "example.com", string(resp.RawResponse.Body())) }) t.Run("the request return an error", func(t *testing.T) { client, req := AcquireClient(), AcquireRequest() core := client.core core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - req.rawRequest.SetRequestURI("http://example.com/return-error") + req.RawRequest.SetRequestURI("http://example.com/return-error") resp, err := core.execFunc(context.Background(), client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, 500, resp.rawResponse.StatusCode()) - utils.AssertEqual(t, "the request is error", string(resp.rawResponse.Body())) + utils.AssertEqual(t, 500, resp.RawResponse.StatusCode()) + utils.AssertEqual(t, "the request is error", string(resp.RawResponse.Body())) }) t.Run("there is no connect", func(t *testing.T) { @@ -67,13 +67,13 @@ func Test_Exec_Func(t *testing.T) { go func() { req := AcquireRequest() - req.rawRequest.SetRequestURI("http://example.com/normal") + req.RawRequest.SetRequestURI("http://example.com/normal") _, err := core.execFunc(context.Background(), client, req) utils.AssertEqual(t, fasthttp.ErrNoFreeConns, err) }() req := AcquireRequest() - req.rawRequest.SetRequestURI("http://example.com/hang-up") + req.RawRequest.SetRequestURI("http://example.com/hang-up") core.execFunc(context.Background(), client, req) }) @@ -82,7 +82,7 @@ func Test_Exec_Func(t *testing.T) { core := client.core core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - req.rawRequest.SetRequestURI("http://example.com/hang-up") + req.RawRequest.SetRequestURI("http://example.com/hang-up") ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() @@ -128,7 +128,7 @@ func Test_Execute(t *testing.T) { resp, err := client.core.execute(context.Background(), client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "Cannot GET /", string(resp.rawResponse.Body())) + utils.AssertEqual(t, "Cannot GET /", string(resp.RawResponse.Body())) }) t.Run("add user response hooks", func(t *testing.T) { @@ -143,7 +143,7 @@ func Test_Execute(t *testing.T) { resp, err := client.core.execute(context.Background(), client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "Cannot GET /", string(resp.rawResponse.Body())) + utils.AssertEqual(t, "Cannot GET /", string(resp.RawResponse.Body())) }) t.Run("no timeout", func(t *testing.T) { @@ -155,7 +155,7 @@ func Test_Execute(t *testing.T) { resp, err := client.core.execute(context.Background(), client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "example.com hang up", string(resp.rawResponse.Body())) + utils.AssertEqual(t, "example.com hang up", string(resp.RawResponse.Body())) }) t.Run("client timeout", func(t *testing.T) { @@ -192,6 +192,6 @@ func Test_Execute(t *testing.T) { resp, err := client.core.execute(context.Background(), client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "example.com hang up", string(resp.rawResponse.Body())) + utils.AssertEqual(t, "example.com hang up", string(resp.RawResponse.Body())) }) } diff --git a/client/hooks.go b/client/hooks.go index f768ef4371..94cb604bd0 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -97,8 +97,8 @@ func parserRequestURL(c *Client, req *Request) error { }) // set uri to request and orther related setting - req.rawRequest.SetRequestURI(uri) - rawUri := req.rawRequest.URI() + req.RawRequest.SetRequestURI(uri) + rawUri := req.RawRequest.URI() isTLS, scheme := false, rawUri.Scheme() if bytes.Equal(httpsBytes, scheme) { isTLS = true @@ -124,8 +124,8 @@ func parserRequestURL(c *Client, req *Request) error { req.params.VisitAll(func(key, value []byte) { args.AddBytesKV(key, value) }) - req.rawRequest.URI().SetQueryStringBytes(utils.CopyBytes(args.QueryString())) - req.rawRequest.URI().SetHash(hashSplit[1]) + req.RawRequest.URI().SetQueryStringBytes(utils.CopyBytes(args.QueryString())) + req.RawRequest.URI().SetHash(hashSplit[1]) return nil } @@ -136,54 +136,54 @@ func parserRequestURL(c *Client, req *Request) error { // User-Agent should be set. func parserRequestHeader(c *Client, req *Request) error { // set method - req.rawRequest.Header.SetMethod(req.Method()) + req.RawRequest.Header.SetMethod(req.Method()) // merge header c.header.VisitAll(func(key, value []byte) { - req.rawRequest.Header.AddBytesKV(key, value) + req.RawRequest.Header.AddBytesKV(key, value) }) req.header.VisitAll(func(key, value []byte) { - req.rawRequest.Header.AddBytesKV(key, value) + req.RawRequest.Header.AddBytesKV(key, value) }) // according to data set content-type switch req.bodyType { case jsonBody: - req.rawRequest.Header.SetContentType(applicationJSON) - req.rawRequest.Header.Set(headerAccept, applicationJSON) + req.RawRequest.Header.SetContentType(applicationJSON) + req.RawRequest.Header.Set(headerAccept, applicationJSON) case xmlBody: - req.rawRequest.Header.SetContentType(applicationXML) + req.RawRequest.Header.SetContentType(applicationXML) case formBody: - req.rawRequest.Header.SetContentType(applicationForm) + req.RawRequest.Header.SetContentType(applicationForm) case filesBody: - req.rawRequest.Header.SetContentType(multipartFormData) + req.RawRequest.Header.SetContentType(multipartFormData) // set boundary - req.rawRequest.Header.SetMultipartFormBoundary(req.boundary) + req.RawRequest.Header.SetMultipartFormBoundary(req.boundary) default: } // set useragent - req.rawRequest.Header.SetUserAgent(defaultUserAgent) + req.RawRequest.Header.SetUserAgent(defaultUserAgent) if c.userAgent != "" { - req.rawRequest.Header.SetUserAgent(c.userAgent) + req.RawRequest.Header.SetUserAgent(c.userAgent) } if req.userAgent != "" { - req.rawRequest.Header.SetUserAgent(req.userAgent) + req.RawRequest.Header.SetUserAgent(req.userAgent) } // set referer - req.rawRequest.Header.SetReferer(c.referer) + req.RawRequest.Header.SetReferer(c.referer) if req.referer != "" { - req.rawRequest.Header.SetReferer(req.referer) + req.RawRequest.Header.SetReferer(req.referer) } // set cookie c.cookies.VisitAll(func(key, val string) { - req.rawRequest.Header.SetCookie(key, val) + req.RawRequest.Header.SetCookie(key, val) }) req.cookies.VisitAll(func(key, val string) { - req.rawRequest.Header.SetCookie(key, val) + req.RawRequest.Header.SetCookie(key, val) }) return nil @@ -198,17 +198,17 @@ func parserRequestBody(c *Client, req *Request) error { if err != nil { return err } - req.rawRequest.SetBody(body) + req.RawRequest.SetBody(body) case xmlBody: body, err := c.core.xmlMarshal(req.body) if err != nil { return err } - req.rawRequest.SetBody(body) + req.RawRequest.SetBody(body) case formBody: - req.rawRequest.SetBody(req.formData.QueryString()) + req.RawRequest.SetBody(req.formData.QueryString()) case filesBody: - mw := multipart.NewWriter(req.rawRequest.BodyWriter()) + mw := multipart.NewWriter(req.RawRequest.BodyWriter()) err := mw.SetBoundary(req.boundary) if err != nil { return err @@ -284,7 +284,7 @@ func parserRequestBody(c *Client, req *Request) error { } case rawBody: if body, ok := req.body.([]byte); ok { - req.rawRequest.SetBody(body) + req.RawRequest.SetBody(body) } else { return ErrBodyType } @@ -293,7 +293,7 @@ func parserRequestBody(c *Client, req *Request) error { } func parserResponseCookie(c *Client, resp *Response, req *Request) (err error) { - resp.rawResponse.Header.VisitAllCookie(func(key, value []byte) { + resp.RawResponse.Header.VisitAllCookie(func(key, value []byte) { cookie := fasthttp.AcquireCookie() _ = cookie.ParseBytes(value) cookie.SetKeyBytes(key) diff --git a/client/hooks_test.go b/client/hooks_test.go index 7f3af4c794..7fc3b43c8d 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -78,7 +78,7 @@ func Test_Parser_Request_URL(t *testing.T) { err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "http://example.com/api", req.rawRequest.URI().String()) + utils.AssertEqual(t, "http://example.com/api", req.RawRequest.URI().String()) }) t.Run("request url should be set", func(t *testing.T) { @@ -87,7 +87,7 @@ func Test_Parser_Request_URL(t *testing.T) { err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "http://example.com/api", req.rawRequest.URI().String()) + utils.AssertEqual(t, "http://example.com/api", req.RawRequest.URI().String()) }) t.Run("the request url will override baseurl with protocol", func(t *testing.T) { @@ -96,7 +96,7 @@ func Test_Parser_Request_URL(t *testing.T) { err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "http://example.com/api/v1", req.rawRequest.URI().String()) + utils.AssertEqual(t, "http://example.com/api/v1", req.RawRequest.URI().String()) }) t.Run("the request url should be append after baseurl without protocol", func(t *testing.T) { @@ -105,7 +105,7 @@ func Test_Parser_Request_URL(t *testing.T) { err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "http://example.com/api/v1", req.rawRequest.URI().String()) + utils.AssertEqual(t, "http://example.com/api/v1", req.RawRequest.URI().String()) }) t.Run("the url is error", func(t *testing.T) { @@ -124,7 +124,7 @@ func Test_Parser_Request_URL(t *testing.T) { err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "http://example.com/api/5", req.rawRequest.URI().String()) + utils.AssertEqual(t, "http://example.com/api/5", req.RawRequest.URI().String()) }) t.Run("the path param from request", func(t *testing.T) { @@ -141,7 +141,7 @@ func Test_Parser_Request_URL(t *testing.T) { err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "http://example.com/api/5/fiber/%7Bkey%7D", req.rawRequest.URI().String()) + utils.AssertEqual(t, "http://example.com/api/5/fiber/%7Bkey%7D", req.RawRequest.URI().String()) }) t.Run("the path param from request and client", func(t *testing.T) { @@ -158,7 +158,7 @@ func Test_Parser_Request_URL(t *testing.T) { err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "http://example.com/api/12/fiber/val", req.rawRequest.URI().String()) + utils.AssertEqual(t, "http://example.com/api/12/fiber/val", req.RawRequest.URI().String()) }) t.Run("query params from client should be set", func(t *testing.T) { @@ -168,7 +168,7 @@ func Test_Parser_Request_URL(t *testing.T) { err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("foo=bar"), req.rawRequest.URI().QueryString()) + utils.AssertEqual(t, []byte("foo=bar"), req.RawRequest.URI().QueryString()) }) t.Run("query params from request should be set", func(t *testing.T) { @@ -179,7 +179,7 @@ func Test_Parser_Request_URL(t *testing.T) { err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("bar=foo"), req.rawRequest.URI().QueryString()) + utils.AssertEqual(t, []byte("bar=foo"), req.RawRequest.URI().QueryString()) }) t.Run("query params should be merged", func(t *testing.T) { @@ -192,7 +192,7 @@ func Test_Parser_Request_URL(t *testing.T) { err := parserRequestURL(client, req) utils.AssertEqual(t, nil, err) - values, _ := url.ParseQuery(string(req.rawRequest.URI().QueryString())) + values, _ := url.ParseQuery(string(req.RawRequest.URI().QueryString())) flag1, flag2, flag3 := false, false, false for _, v := range values["bar"] { @@ -223,7 +223,7 @@ func Test_Parser_Request_Header(t *testing.T) { err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("application/json"), req.rawRequest.Header.ContentType()) + utils.AssertEqual(t, []byte("application/json"), req.RawRequest.Header.ContentType()) }) t.Run("request header should be set", func(t *testing.T) { @@ -236,7 +236,7 @@ func Test_Parser_Request_Header(t *testing.T) { err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("application/json, utf-8"), req.rawRequest.Header.ContentType()) + utils.AssertEqual(t, []byte("application/json, utf-8"), req.RawRequest.Header.ContentType()) }) t.Run("request header should override client header", func(t *testing.T) { @@ -248,7 +248,7 @@ func Test_Parser_Request_Header(t *testing.T) { err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("application/json, utf-8"), req.rawRequest.Header.ContentType()) + utils.AssertEqual(t, []byte("application/json, utf-8"), req.RawRequest.Header.ContentType()) }) t.Run("auto set json header", func(t *testing.T) { @@ -263,7 +263,7 @@ func Test_Parser_Request_Header(t *testing.T) { err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte(applicationJSON), req.rawRequest.Header.ContentType()) + utils.AssertEqual(t, []byte(applicationJSON), req.RawRequest.Header.ContentType()) }) t.Run("auto set xml header", func(t *testing.T) { @@ -279,7 +279,7 @@ func Test_Parser_Request_Header(t *testing.T) { err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte(applicationXML), req.rawRequest.Header.ContentType()) + utils.AssertEqual(t, []byte(applicationXML), req.RawRequest.Header.ContentType()) }) t.Run("auto set form data header", func(t *testing.T) { @@ -292,7 +292,7 @@ func Test_Parser_Request_Header(t *testing.T) { err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, applicationForm, string(req.rawRequest.Header.ContentType())) + utils.AssertEqual(t, applicationForm, string(req.RawRequest.Header.ContentType())) }) t.Run("auto set file header", func(t *testing.T) { @@ -303,8 +303,8 @@ func Test_Parser_Request_Header(t *testing.T) { err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Header.MultipartFormBoundary()), "--FiberFormBoundary")) - utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Header.ContentType()), multipartFormData)) + utils.AssertEqual(t, true, strings.Contains(string(req.RawRequest.Header.MultipartFormBoundary()), "--FiberFormBoundary")) + utils.AssertEqual(t, true, strings.Contains(string(req.RawRequest.Header.ContentType()), multipartFormData)) }) t.Run("ua should have default value", func(t *testing.T) { @@ -313,7 +313,7 @@ func Test_Parser_Request_Header(t *testing.T) { err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("fiber"), req.rawRequest.Header.UserAgent()) + utils.AssertEqual(t, []byte("fiber"), req.RawRequest.Header.UserAgent()) }) t.Run("ua in client should be set", func(t *testing.T) { @@ -322,7 +322,7 @@ func Test_Parser_Request_Header(t *testing.T) { err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("foo"), req.rawRequest.Header.UserAgent()) + utils.AssertEqual(t, []byte("foo"), req.RawRequest.Header.UserAgent()) }) t.Run("ua in request should have higher level", func(t *testing.T) { @@ -331,7 +331,7 @@ func Test_Parser_Request_Header(t *testing.T) { err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("bar"), req.rawRequest.Header.UserAgent()) + utils.AssertEqual(t, []byte("bar"), req.RawRequest.Header.UserAgent()) }) t.Run("referer in client should be set", func(t *testing.T) { @@ -340,7 +340,7 @@ func Test_Parser_Request_Header(t *testing.T) { err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("https://example.com"), req.rawRequest.Header.Referer()) + utils.AssertEqual(t, []byte("https://example.com"), req.RawRequest.Header.Referer()) }) t.Run("referer in request should have higher level", func(t *testing.T) { @@ -349,7 +349,7 @@ func Test_Parser_Request_Header(t *testing.T) { err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("https://example.com"), req.rawRequest.Header.Referer()) + utils.AssertEqual(t, []byte("https://example.com"), req.RawRequest.Header.Referer()) }) t.Run("client cookie should be set", func(t *testing.T) { @@ -365,9 +365,9 @@ func Test_Parser_Request_Header(t *testing.T) { err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "bar", string(req.rawRequest.Header.Cookie("foo"))) - utils.AssertEqual(t, "foo", string(req.rawRequest.Header.Cookie("bar"))) - utils.AssertEqual(t, "", string(req.rawRequest.Header.Cookie("bar1"))) + utils.AssertEqual(t, "bar", string(req.RawRequest.Header.Cookie("foo"))) + utils.AssertEqual(t, "foo", string(req.RawRequest.Header.Cookie("bar"))) + utils.AssertEqual(t, "", string(req.RawRequest.Header.Cookie("bar1"))) }) t.Run("request cookie should be set", func(t *testing.T) { @@ -386,9 +386,9 @@ func Test_Parser_Request_Header(t *testing.T) { err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "bar", string(req.rawRequest.Header.Cookie("foo"))) - utils.AssertEqual(t, "67", string(req.rawRequest.Header.Cookie("bar"))) - utils.AssertEqual(t, "", string(req.rawRequest.Header.Cookie("bar1"))) + utils.AssertEqual(t, "bar", string(req.RawRequest.Header.Cookie("foo"))) + utils.AssertEqual(t, "67", string(req.RawRequest.Header.Cookie("bar"))) + utils.AssertEqual(t, "", string(req.RawRequest.Header.Cookie("bar1"))) }) t.Run("request cookie will override client cookie", func(t *testing.T) { @@ -412,9 +412,9 @@ func Test_Parser_Request_Header(t *testing.T) { err := parserRequestHeader(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "bar", string(req.rawRequest.Header.Cookie("foo"))) - utils.AssertEqual(t, "67", string(req.rawRequest.Header.Cookie("bar"))) - utils.AssertEqual(t, "foo1", string(req.rawRequest.Header.Cookie("bar1"))) + utils.AssertEqual(t, "bar", string(req.RawRequest.Header.Cookie("foo"))) + utils.AssertEqual(t, "67", string(req.RawRequest.Header.Cookie("bar"))) + utils.AssertEqual(t, "foo1", string(req.RawRequest.Header.Cookie("bar1"))) }) } @@ -433,7 +433,7 @@ func Test_Parser_Request_Body(t *testing.T) { err := parserRequestBody(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("{\"name\":\"foo\"}"), req.rawRequest.Body()) + utils.AssertEqual(t, []byte("{\"name\":\"foo\"}"), req.RawRequest.Body()) }) t.Run("xml body", func(t *testing.T) { @@ -449,7 +449,7 @@ func Test_Parser_Request_Body(t *testing.T) { err := parserRequestBody(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("foo"), req.rawRequest.Body()) + utils.AssertEqual(t, []byte("foo"), req.RawRequest.Body()) }) t.Run("form data body", func(t *testing.T) { @@ -461,7 +461,7 @@ func Test_Parser_Request_Body(t *testing.T) { err := parserRequestBody(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "ball=cricle+and+square", string(req.rawRequest.Body())) + utils.AssertEqual(t, "ball=cricle+and+square", string(req.RawRequest.Body())) }) t.Run("file body", func(t *testing.T) { @@ -471,8 +471,8 @@ func Test_Parser_Request_Body(t *testing.T) { err := parserRequestBody(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Body()), "----FiberFormBoundary")) - utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Body()), "world")) + utils.AssertEqual(t, true, strings.Contains(string(req.RawRequest.Body()), "----FiberFormBoundary")) + utils.AssertEqual(t, true, strings.Contains(string(req.RawRequest.Body()), "world")) }) t.Run("file and form data", func(t *testing.T) { @@ -483,9 +483,9 @@ func Test_Parser_Request_Body(t *testing.T) { err := parserRequestBody(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Body()), "----FiberFormBoundary")) - utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Body()), "world")) - utils.AssertEqual(t, true, strings.Contains(string(req.rawRequest.Body()), "bar")) + utils.AssertEqual(t, true, strings.Contains(string(req.RawRequest.Body()), "----FiberFormBoundary")) + utils.AssertEqual(t, true, strings.Contains(string(req.RawRequest.Body()), "world")) + utils.AssertEqual(t, true, strings.Contains(string(req.RawRequest.Body()), "bar")) }) t.Run("raw body", func(t *testing.T) { @@ -495,6 +495,6 @@ func Test_Parser_Request_Body(t *testing.T) { err := parserRequestBody(client, req) utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("hello world"), req.rawRequest.Body()) + utils.AssertEqual(t, []byte("hello world"), req.RawRequest.Body()) }) } diff --git a/client/request.go b/client/request.go index fb1baf4580..f445423cad 100644 --- a/client/request.go +++ b/client/request.go @@ -57,7 +57,7 @@ type Request struct { files []*File bodyType bodyType - rawRequest *fasthttp.Request + RawRequest *fasthttp.Request } // Method returns http method in request. @@ -550,7 +550,7 @@ func (r *Request) Reset() { r.cookies.Reset() r.header.Reset() r.params.Reset() - r.rawRequest.Reset() + r.RawRequest.Reset() } // Header is a wrapper which wrap http.Header, @@ -816,7 +816,7 @@ var requestPool = &sync.Pool{ boundary: "--FiberFormBoundary" + randString(16), formData: &FormData{Args: fasthttp.AcquireArgs()}, files: make([]*File, 0), - rawRequest: fasthttp.AcquireRequest(), + RawRequest: fasthttp.AcquireRequest(), } }, } diff --git a/client/response.go b/client/response.go index 56aa0c2216..4ac064012b 100644 --- a/client/response.go +++ b/client/response.go @@ -9,10 +9,11 @@ import ( ) type Response struct { - client *Client - request *Request - cookie []*fasthttp.Cookie - rawResponse *fasthttp.Response + client *Client + request *Request + cookie []*fasthttp.Cookie + + RawResponse *fasthttp.Response } // setClient method sets client object in response instance. @@ -29,22 +30,22 @@ func (r *Response) setRequest(req *Request) { // Status method returns the HTTP status string for the executed request. func (r *Response) Status() string { - return string(r.rawResponse.Header.StatusMessage()) + return string(r.RawResponse.Header.StatusMessage()) } // StatusCode method returns the HTTP status code for the executed request. func (r *Response) StatusCode() int { - return r.rawResponse.StatusCode() + return r.RawResponse.StatusCode() } // Protocol method returns the HTTP response protocol used for the request. func (r *Response) Protocol() string { - return string(r.rawResponse.Header.Protocol()) + return string(r.RawResponse.Header.Protocol()) } // Header method returns the response headers. func (r *Response) Header(key string) string { - return utils.UnsafeString(r.rawResponse.Header.Peek(key)) + return utils.UnsafeString(r.RawResponse.Header.Peek(key)) } // Cookies method to access all the response cookies. @@ -54,7 +55,7 @@ func (r *Response) Cookies() []*fasthttp.Cookie { // Body method returns HTTP response as []byte array for the executed request. func (r *Response) Body() []byte { - return r.rawResponse.Body() + return r.RawResponse.Body() } // String method returns the body of the server response as String. @@ -83,7 +84,7 @@ func (r *Response) Reset() { fasthttp.ReleaseCookie(t) } - r.rawResponse.Reset() + r.RawResponse.Reset() } // Close method will release Request object and Response object, @@ -101,7 +102,7 @@ var responsePool = &sync.Pool{ New: func() any { return &Response{ cookie: []*fasthttp.Cookie{}, - rawResponse: fasthttp.AcquireResponse(), + RawResponse: fasthttp.AcquireResponse(), } }, } diff --git a/client/response_test.go b/client/response_test.go index 9245044fd1..2a83d21540 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -2,7 +2,6 @@ package client import ( "encoding/xml" - "fmt" "testing" "github.com/gofiber/fiber/v3" @@ -144,8 +143,6 @@ func Test_Response_Cookie(t *testing.T) { SetClient(client). Get("http://example.com") - fmt.Println(resp.rawResponse.String()) - utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "bar", string(resp.Cookies()[0].Value())) resp.Close() @@ -215,8 +212,6 @@ func Test_Response_Body(t *testing.T) { SetClient(client). Get("http://example.com/xml") - fmt.Println(resp.rawResponse.String()) - utils.AssertEqual(t, nil, err) tmp := &body{} From 6f48694b5a44af6b44ca49a301c92b682bcbd394 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Sun, 21 Aug 2022 10:08:57 +0800 Subject: [PATCH 031/118] =?UTF-8?q?=F0=9F=90=9B=20fix:=20fix=20data=20race?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/core.go | 17 ++++++++++++----- client/hooks.go | 3 ++- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/client/core.go b/client/core.go index 83d2e97314..afc9897df6 100644 --- a/client/core.go +++ b/client/core.go @@ -6,6 +6,7 @@ import ( "encoding/xml" "errors" "sync" + "sync/atomic" "github.com/gofiber/fiber/v3/utils" "github.com/valyala/fasthttp" @@ -52,6 +53,7 @@ func (c *core) execFunc(ctx context.Context, client *Client, req *Request) (*Res resp.setRequest(req) // To avoid memory allocation reuse of data structures such as errch. + done := int32(0) errCh, reqv, respv := acquireErrChan(), fasthttp.AcquireRequest(), fasthttp.AcquireResponse() defer func() { releaseErrChan(errCh) @@ -62,12 +64,14 @@ func (c *core) execFunc(ctx context.Context, client *Client, req *Request) (*Res req.RawRequest.CopyTo(reqv) go func() { err := c.client.Do(reqv, respv) - if err != nil { - errCh <- err - return + + if atomic.CompareAndSwapInt32(&done, 0, 1) { + if err != nil { + errCh <- err + return + } + errCh <- nil } - respv.CopyTo(resp.RawResponse) - errCh <- nil }() select { @@ -77,8 +81,11 @@ func (c *core) execFunc(ctx context.Context, client *Client, req *Request) (*Res ReleaseResponse(resp) return nil, err } + + respv.CopyTo(resp.RawResponse) return resp, nil case <-ctx.Done(): + atomic.SwapInt32(&done, 1) ReleaseResponse(resp) return nil, ErrTimeoutOrCancel } diff --git a/client/hooks.go b/client/hooks.go index 94cb604bd0..8fba3d2b4e 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -30,7 +30,6 @@ var ( applicationForm = "application/x-www-form-urlencoded" multipartFormData = "multipart/form-data" - src = rand.NewSource(time.Now().UnixNano()) letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" letterIdxBits = 6 // 6 bits to represent a letter index letterIdxMask = 1<= 0; { if remain == 0 { cache, remain = src.Int63(), letterIdxMax From 35742a77a566983d6af9dd9444c7f32840b7fbf9 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Sun, 21 Aug 2022 10:09:43 +0800 Subject: [PATCH 032/118] =?UTF-8?q?=F0=9F=94=92=EF=B8=8F=20chore:=20change?= =?UTF-8?q?=20package?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/request_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/client/request_test.go b/client/request_test.go index 6a7e196ebf..485ab1077a 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -5,7 +5,6 @@ import ( "context" "errors" "io" - "io/ioutil" "mime/multipart" "os" "path/filepath" @@ -529,7 +528,7 @@ func Test_Request_File(t *testing.T) { utils.AssertEqual(t, "tmp.txt", req.File("tmp.txt").name) - content, err := ioutil.ReadAll(req.File("tmp.txt").reader) + content, err := io.ReadAll(req.File("tmp.txt").reader) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "world", string(content)) }) From b13310347575ada5596f5f0de7470b21d3707941 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Sun, 21 Aug 2022 10:20:04 +0800 Subject: [PATCH 033/118] =?UTF-8?q?=F0=9F=90=9B=20fix:=20data=20race?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/core.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/client/core.go b/client/core.go index afc9897df6..f719c14338 100644 --- a/client/core.go +++ b/client/core.go @@ -54,22 +54,26 @@ func (c *core) execFunc(ctx context.Context, client *Client, req *Request) (*Res // To avoid memory allocation reuse of data structures such as errch. done := int32(0) - errCh, reqv, respv := acquireErrChan(), fasthttp.AcquireRequest(), fasthttp.AcquireResponse() + errCh, reqv := acquireErrChan(), fasthttp.AcquireRequest() defer func() { releaseErrChan(errCh) - fasthttp.ReleaseRequest(reqv) - fasthttp.ReleaseResponse(respv) }() req.RawRequest.CopyTo(reqv) go func() { + respv := fasthttp.AcquireResponse() err := c.client.Do(reqv, respv) + defer func() { + fasthttp.ReleaseRequest(reqv) + fasthttp.ReleaseResponse(respv) + }() if atomic.CompareAndSwapInt32(&done, 0, 1) { if err != nil { errCh <- err return } + respv.CopyTo(resp.RawResponse) errCh <- nil } }() @@ -81,8 +85,6 @@ func (c *core) execFunc(ctx context.Context, client *Client, req *Request) (*Res ReleaseResponse(resp) return nil, err } - - respv.CopyTo(resp.RawResponse) return resp, nil case <-ctx.Done(): atomic.SwapInt32(&done, 1) From 1a578728134955cb21dd8f53e7c4daa80cb11129 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Sun, 21 Aug 2022 10:41:14 +0800 Subject: [PATCH 034/118] =?UTF-8?q?=F0=9F=90=9B=20fix:=20test=20fail?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/request_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/request_test.go b/client/request_test.go index 485ab1077a..6a38f67e20 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -1150,7 +1150,7 @@ func Test_Request_Error_Body_With_Server(t *testing.T) { _, err := AcquireRequest(). AddFile("non-exist-file!"). Get("http://example.com") - utils.AssertEqual(t, "open non-exist-file!: The system cannot find the file specified.", err.Error()) + utils.AssertEqual(t, "open non-exist-file!: no such file or directory", err.Error()) }) } From 43b6d28a5ff6d72fb5a47a09fe133b92b0ec6e55 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Sun, 21 Aug 2022 15:11:27 +0800 Subject: [PATCH 035/118] =?UTF-8?q?=E2=9C=A8=20feat:=20move=20core=20to=20?= =?UTF-8?q?req?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 69 +++++++++++++++++++++++++------------ client/client_test.go | 6 ++-- client/core.go | 38 +++------------------ client/core_test.go | 75 +++++++++++++++-------------------------- client/helper_test.go | 22 ++++++------ client/hooks.go | 8 ++--- client/request.go | 26 +++++++++----- client/request_test.go | 48 +++++++++++++------------- client/response.go | 4 +-- client/response_test.go | 34 +++++++++---------- 10 files changed, 157 insertions(+), 173 deletions(-) diff --git a/client/client.go b/client/client.go index a6d1bcfeb1..81de5f6d1d 100644 --- a/client/client.go +++ b/client/client.go @@ -1,6 +1,9 @@ package client import ( + "encoding/json" + "encoding/xml" + "net" "sync" "time" @@ -15,8 +18,6 @@ import ( // Fiber Client also provides an option to override // or merge most of the client settings at the request. type Client struct { - core *core - baseUrl string userAgent string referer string @@ -26,6 +27,23 @@ type Client struct { path *PathParam timeout time.Duration + + // user defined request hooks + userRequestHooks []RequestHook + + // client package defined request hooks + buildinRequestHooks []RequestHook + + // user defined response hooks + userResponseHooks []ResponseHook + + // client package defined respose hooks + buildinResposeHooks []ResponseHook + + jsonMarshal utils.JSONMarshal + jsonUnmarshal utils.JSONUnmarshal + xmlMarshal utils.XMLMarshal + xmlUnmarshal utils.XMLUnmarshal } // R raise a request from the client. @@ -35,74 +53,67 @@ func (c *Client) R() *Request { // Request returns user-defined request hooks. func (c *Client) RequestHook() []RequestHook { - return c.core.userRequestHooks + return c.userRequestHooks } // Add user-defined request hooks. func (c *Client) AddRequestHook(h ...RequestHook) *Client { - c.core.userRequestHooks = append(c.core.userRequestHooks, h...) + c.userRequestHooks = append(c.userRequestHooks, h...) return c } // ResponseHook return user-define reponse hooks. func (c *Client) ResponseHook() []ResponseHook { - return c.core.userResponseHooks + return c.userResponseHooks } // Add user-defined response hooks. func (c *Client) AddResponseHook(h ...ResponseHook) *Client { - c.core.userResponseHooks = append(c.core.userResponseHooks, h...) - return c -} - -// Set HostClient dial, this method for unit test, -// maybe don't use it. -func (c *Client) SetDial(f fasthttp.DialFunc) *Client { - c.core.client.Dial = f + c.userResponseHooks = append(c.userResponseHooks, h...) return c } // JSONMarshal returns json marshal function in Core. func (c *Client) JSONMarshal() utils.JSONMarshal { - return c.core.jsonMarshal + return c.jsonMarshal } // Set json encoder. func (c *Client) SetJSONMarshal(f utils.JSONMarshal) *Client { - c.core.jsonMarshal = f + c.jsonMarshal = f return c } // JSONUnmarshal returns json unmarshal function in Core. func (c *Client) JSONUnmarshal() utils.JSONUnmarshal { - return c.core.jsonUnmarshal + return c.jsonUnmarshal } // Set json decoder. func (c *Client) SetJSONUnmarshal(f utils.JSONUnmarshal) *Client { - c.core.jsonUnmarshal = f + c.jsonUnmarshal = f return c } // XMLMarshal returns xml marshal function in Core. func (c *Client) XMLMarshal() utils.XMLMarshal { - return c.core.xmlMarshal + return c.xmlMarshal } // Set xml encoder. func (c *Client) SetXMLMarshal(f utils.XMLMarshal) *Client { - c.core.xmlMarshal = f + c.xmlMarshal = f return c } // XMLUnmarshal returns xml unmarshal function in Core. func (c *Client) XMLUnmarshal() utils.XMLUnmarshal { - return c.core.xmlUnmarshal + return c.xmlUnmarshal } // Set xml decoder. func (c *Client) SetXMLUnmarshal(f utils.XMLUnmarshal) *Client { - c.core.xmlUnmarshal = f + c.xmlUnmarshal = f return c } @@ -421,13 +432,18 @@ func SetRequestFiles(files ...*File) SetRequestOptionFunc { } } +func SetDial(f func(addr string) (net.Conn, error)) SetRequestOptionFunc { + return func(r *Request) { + r.core.client.Dial = f + } +} + var ( defaultClient *Client defaultUserAgent = "fiber" clientPool = &sync.Pool{ New: func() any { return &Client{ - core: newCore(), header: &Header{ RequestHeader: &fasthttp.RequestHeader{}, }, @@ -436,6 +452,15 @@ var ( }, cookies: &Cookie{}, path: &PathParam{}, + + userRequestHooks: []RequestHook{}, + buildinRequestHooks: []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody}, + userResponseHooks: []ResponseHook{}, + buildinResposeHooks: []ResponseHook{parserResponseCookie}, + jsonMarshal: json.Marshal, + jsonUnmarshal: json.Unmarshal, + xmlMarshal: xml.Marshal, + xmlUnmarshal: xml.Unmarshal, } }, } diff --git a/client/client_test.go b/client/client_test.go index ce051f2c74..484af50876 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -63,11 +63,9 @@ func Test_Get(t *testing.T) { }() t.Run("global get function", func(t *testing.T) { - C().SetDial(func(addr string) (net.Conn, error) { + resp, err := Get("http://example.com", SetDial(func(addr string) (net.Conn, error) { return ln.Dial() - }) - - resp, err := Get("http://example.com") + })) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "example.com", utils.UnsafeString(resp.RawResponse.Body())) }) diff --git a/client/core.go b/client/core.go index f719c14338..11e363b4ba 100644 --- a/client/core.go +++ b/client/core.go @@ -2,13 +2,10 @@ package client import ( "context" - "encoding/json" - "encoding/xml" "errors" "sync" "sync/atomic" - "github.com/gofiber/fiber/v3/utils" "github.com/valyala/fasthttp" ) @@ -28,23 +25,6 @@ type ResponseHook func(*Client, *Response, *Request) error // and defines the execution process type core struct { client *fasthttp.HostClient - - // user defined request hooks - userRequestHooks []RequestHook - - // client package defined request hooks - buildinRequestHooks []RequestHook - - // user defined response hooks - userResponseHooks []ResponseHook - - // client package defined respose hooks - buildinResposeHooks []ResponseHook - - jsonMarshal utils.JSONMarshal - jsonUnmarshal utils.JSONUnmarshal - xmlMarshal utils.XMLMarshal - xmlUnmarshal utils.XMLUnmarshal } func (c *core) execFunc(ctx context.Context, client *Client, req *Request) (*Response, error) { @@ -97,14 +77,14 @@ func (c *core) execFunc(ctx context.Context, client *Client, req *Request) (*Res func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Response, error) { // The built-in hooks will be executed only // after the user-defined hooks are executed。 - for _, f := range c.userRequestHooks { + for _, f := range client.userRequestHooks { err := f(client, req) if err != nil { return nil, err } } - for _, f := range c.buildinRequestHooks { + for _, f := range client.buildinRequestHooks { err := f(client, req) if err != nil { return nil, err @@ -136,14 +116,14 @@ func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Resp // The built-in hooks will be executed only // before the user-defined hooks are executed. - for _, f := range c.buildinResposeHooks { + for _, f := range client.buildinResposeHooks { err := f(client, resp, req) if err != nil { return nil, err } } - for _, f := range c.userResponseHooks { + for _, f := range client.userResponseHooks { err := f(client, resp, req) if err != nil { return nil, err @@ -177,15 +157,7 @@ func releaseErrChan(ch chan error) { // newCore returns an empty core object. func newCore() (c *core) { c = &core{ - client: &fasthttp.HostClient{}, - userRequestHooks: []RequestHook{}, - buildinRequestHooks: []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody}, - userResponseHooks: []ResponseHook{}, - buildinResposeHooks: []ResponseHook{parserResponseCookie}, - jsonMarshal: json.Marshal, - jsonUnmarshal: json.Unmarshal, - xmlMarshal: xml.Marshal, - xmlUnmarshal: xml.Unmarshal, + client: &fasthttp.HostClient{}, } return diff --git a/client/core_test.go b/client/core_test.go index 106b46f629..344f475dee 100644 --- a/client/core_test.go +++ b/client/core_test.go @@ -9,7 +9,6 @@ import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/utils" - "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttputil" ) @@ -36,7 +35,7 @@ func Test_Exec_Func(t *testing.T) { t.Run("normal request", func(t *testing.T) { client, req := AcquireClient(), AcquireRequest() - core := client.core + core := req.core core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } req.RawRequest.SetRequestURI("http://example.com/normal") @@ -48,7 +47,7 @@ func Test_Exec_Func(t *testing.T) { t.Run("the request return an error", func(t *testing.T) { client, req := AcquireClient(), AcquireRequest() - core := client.core + core := req.core core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } req.RawRequest.SetRequestURI("http://example.com/return-error") @@ -59,27 +58,9 @@ func Test_Exec_Func(t *testing.T) { utils.AssertEqual(t, "the request is error", string(resp.RawResponse.Body())) }) - t.Run("there is no connect", func(t *testing.T) { - client := AcquireClient() - core := client.core - core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - core.client.SetMaxConns(1) - - go func() { - req := AcquireRequest() - req.RawRequest.SetRequestURI("http://example.com/normal") - _, err := core.execFunc(context.Background(), client, req) - utils.AssertEqual(t, fasthttp.ErrNoFreeConns, err) - }() - - req := AcquireRequest() - req.RawRequest.SetRequestURI("http://example.com/hang-up") - core.execFunc(context.Background(), client, req) - }) - t.Run("the request timeout", func(t *testing.T) { client, req := AcquireClient(), AcquireRequest() - core := client.core + core := req.core core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } req.RawRequest.SetRequestURI("http://example.com/hang-up") @@ -121,12 +102,12 @@ func Test_Execute(t *testing.T) { client.AddRequestHook(func(c *Client, r *Request) error { utils.AssertEqual(t, "http://example.com", req.URL()) return nil - }).SetDial(func(addr string) (net.Conn, error) { - return ln.Dial() }) - req.SetURL("http://example.com") + req.SetDial(func(addr string) (net.Conn, error) { + return ln.Dial() + }).SetURL("http://example.com") - resp, err := client.core.execute(context.Background(), client, req) + resp, err := req.core.execute(context.Background(), client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "Cannot GET /", string(resp.RawResponse.Body())) }) @@ -136,61 +117,61 @@ func Test_Execute(t *testing.T) { client.AddResponseHook(func(c *Client, resp *Response, req *Request) error { utils.AssertEqual(t, "http://example.com", req.URL()) return nil - }).SetDial(func(addr string) (net.Conn, error) { - return ln.Dial() }) - req.SetURL("http://example.com") + req.SetDial(func(addr string) (net.Conn, error) { + return ln.Dial() + }).SetURL("http://example.com") - resp, err := client.core.execute(context.Background(), client, req) + resp, err := req.core.execute(context.Background(), client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "Cannot GET /", string(resp.RawResponse.Body())) }) t.Run("no timeout", func(t *testing.T) { client, req := AcquireClient(), AcquireRequest() - client.SetDial(func(addr string) (net.Conn, error) { + + req.SetDial(func(addr string) (net.Conn, error) { return ln.Dial() - }) - req.SetURL("http://example.com/hang-up") + }).SetURL("http://example.com/hang-up") - resp, err := client.core.execute(context.Background(), client, req) + resp, err := req.core.execute(context.Background(), client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "example.com hang up", string(resp.RawResponse.Body())) }) t.Run("client timeout", func(t *testing.T) { client, req := AcquireClient(), AcquireRequest() - client.SetDial(func(addr string) (net.Conn, error) { + client.SetTimeout(500 * time.Millisecond) + req.SetDial(func(addr string) (net.Conn, error) { return ln.Dial() - }).SetTimeout(500 * time.Millisecond) - req.SetURL("http://example.com/hang-up") + }).SetURL("http://example.com/hang-up") - _, err := client.core.execute(context.Background(), client, req) + _, err := req.core.execute(context.Background(), client, req) utils.AssertEqual(t, ErrTimeoutOrCancel, err) }) t.Run("request timeout", func(t *testing.T) { client, req := AcquireClient(), AcquireRequest() - client.SetDial(func(addr string) (net.Conn, error) { + + req.SetDial(func(addr string) (net.Conn, error) { return ln.Dial() - }) - req.SetURL("http://example.com/hang-up"). + }).SetURL("http://example.com/hang-up"). SetTimeout(300 * time.Millisecond) - _, err := client.core.execute(context.Background(), client, req) + _, err := req.core.execute(context.Background(), client, req) utils.AssertEqual(t, ErrTimeoutOrCancel, err) }) t.Run("request timeout has higher level", func(t *testing.T) { client, req := AcquireClient(), AcquireRequest() - client.SetDial(func(addr string) (net.Conn, error) { + client.SetTimeout(30 * time.Millisecond) + + req.SetDial(func(addr string) (net.Conn, error) { return ln.Dial() - }). - SetTimeout(30 * time.Millisecond) - req.SetURL("http://example.com/hang-up"). + }).SetURL("http://example.com/hang-up"). SetTimeout(3000 * time.Millisecond) - resp, err := client.core.execute(context.Background(), client, req) + resp, err := req.core.execute(context.Background(), client, req) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "example.com hang up", string(resp.RawResponse.Body())) }) diff --git a/client/helper_test.go b/client/helper_test.go index 3b0f82b0c4..b85bc0758f 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -9,25 +9,23 @@ import ( "github.com/valyala/fasthttp/fasthttputil" ) -func createHelperServer(t *testing.T) (*fiber.App, *Client, func()) { +func createHelperServer(t *testing.T) (*fiber.App, func(addr string) (net.Conn, error), func()) { t.Helper() ln := fasthttputil.NewInmemoryListener() app := fiber.New(fiber.Config{DisableStartupMessage: true}) - client := AcquireClient().SetDial(func(addr string) (net.Conn, error) { - return ln.Dial() - }) - - return app, client, func() { - utils.AssertEqual(t, nil, app.Listener(ln)) - } + return app, func(addr string) (net.Conn, error) { + return ln.Dial() + }, func() { + utils.AssertEqual(t, nil, app.Listener(ln)) + } } func testAgent(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted string, count ...int) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Get("/", handler) go start() @@ -37,7 +35,7 @@ func testAgent(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Reques } for i := 0; i < c; i++ { - req := AcquireRequest().SetClient(client) + req := AcquireRequest().SetDial(ln) wrapAgent(req) resp, err := req.Get("http://example.com") @@ -52,7 +50,7 @@ func testAgent(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Reques func testAgentFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted error, count ...int) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Get("/", handler) go start() @@ -62,7 +60,7 @@ func testAgentFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Re } for i := 0; i < c; i++ { - req := AcquireRequest().SetClient(client) + req := AcquireRequest().SetDial(ln) wrapAgent(req) _, err := req.Get("http://example.com") diff --git a/client/hooks.go b/client/hooks.go index 8fba3d2b4e..bf9c889c33 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -107,8 +107,8 @@ func parserRequestURL(c *Client, req *Request) error { return ErrNotSupportSchema } - c.core.client.Addr = addMissingPort(string(rawUri.Host()), isTLS) - c.core.client.IsTLS = isTLS + req.core.client.Addr = addMissingPort(string(rawUri.Host()), isTLS) + req.core.client.IsTLS = isTLS // merge query params hashSplit := strings.Split(splitUrl[1], "#") @@ -195,13 +195,13 @@ func parserRequestHeader(c *Client, req *Request) error { func parserRequestBody(c *Client, req *Request) error { switch req.bodyType { case jsonBody: - body, err := c.core.jsonMarshal(req.body) + body, err := c.jsonMarshal(req.body) if err != nil { return err } req.RawRequest.SetBody(body) case xmlBody: - body, err := c.core.xmlMarshal(req.body) + body, err := c.xmlMarshal(req.body) if err != nil { return err } diff --git a/client/request.go b/client/request.go index f445423cad..7350d6a350 100644 --- a/client/request.go +++ b/client/request.go @@ -37,6 +37,8 @@ const ( ) type Request struct { + core *core + url string method string userAgent string @@ -60,6 +62,13 @@ type Request struct { RawRequest *fasthttp.Request } +// Set HostClient dial, this method for unit test, +// maybe don't use it. +func (r *Request) SetDial(f fasthttp.DialFunc) *Request { + r.core.client.Dial = f + return r +} + // Method returns http method in request. func (r *Request) Method() string { return r.method @@ -477,56 +486,56 @@ func (r *Request) checkClient() { func (r *Request) Get(url string) (*Response, error) { r.SetURL(url).SetMethod(fiber.MethodGet).checkClient() - return r.client.core.execute(r.Context(), r.client, r) + return r.core.execute(r.Context(), r.client, r) } // Send post request. func (r *Request) Post(url string) (*Response, error) { r.SetURL(url).SetMethod(fiber.MethodPost).checkClient() - return r.client.core.execute(r.Context(), r.client, r) + return r.core.execute(r.Context(), r.client, r) } // Send head request. func (r *Request) Head(url string) (*Response, error) { r.SetURL(url).SetMethod(fiber.MethodHead).checkClient() - return r.client.core.execute(r.Context(), r.client, r) + return r.core.execute(r.Context(), r.client, r) } // Send put request. func (r *Request) Put(url string) (*Response, error) { r.SetURL(url).SetMethod(fiber.MethodPut).checkClient() - return r.client.core.execute(r.Context(), r.client, r) + return r.core.execute(r.Context(), r.client, r) } // Send Delete request. func (r *Request) Delete(url string) (*Response, error) { r.SetURL(url).SetMethod(fiber.MethodDelete).checkClient() - return r.client.core.execute(r.Context(), r.client, r) + return r.core.execute(r.Context(), r.client, r) } // Send Options reuqest. func (r *Request) Options(url string) (*Response, error) { r.SetURL(url).SetMethod(fiber.MethodOptions).checkClient() - return r.client.core.execute(r.Context(), r.client, r) + return r.core.execute(r.Context(), r.client, r) } // Send patch request. func (r *Request) Patch(url string) (*Response, error) { r.SetURL(url).SetMethod(fiber.MethodPatch).checkClient() - return r.client.core.execute(r.Context(), r.client, r) + return r.core.execute(r.Context(), r.client, r) } // Send a request. func (r *Request) Send() (*Response, error) { r.checkClient() - return r.client.core.execute(r.Context(), r.client, r) + return r.core.execute(r.Context(), r.client, r) } // Reset clear Request object, used by ReleaseRequest method. @@ -809,6 +818,7 @@ func (f *File) Reset() { var requestPool = &sync.Pool{ New: func() any { return &Request{ + core: newCore(), header: &Header{RequestHeader: &fasthttp.RequestHeader{}}, params: &QueryParam{Args: fasthttp.AcquireArgs()}, cookies: &Cookie{}, diff --git a/client/request_test.go b/client/request_test.go index 485ab1077a..e0b468cff6 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -572,14 +572,14 @@ func Test_Request_Unsupport_Protocol(t *testing.T) { func Test_Request_Get(t *testing.T) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Get("/", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) go start() for i := 0; i < 5; i++ { - req := AcquireRequest().SetClient(client) + req := AcquireRequest().SetDial(ln) resp, err := req.Get("http://example.com") utils.AssertEqual(t, nil, err) @@ -592,7 +592,7 @@ func Test_Request_Get(t *testing.T) { func Test_Request_Post(t *testing.T) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Post("/", func(c fiber.Ctx) error { return c.Status(fiber.StatusCreated). SendString(c.FormValue("foo")) @@ -601,7 +601,7 @@ func Test_Request_Post(t *testing.T) { for i := 0; i < 5; i++ { resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). SetFormData("foo", "bar"). Post("http://example.com") @@ -615,7 +615,7 @@ func Test_Request_Post(t *testing.T) { func Test_Request_Head(t *testing.T) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Get("/", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) @@ -624,7 +624,7 @@ func Test_Request_Head(t *testing.T) { for i := 0; i < 5; i++ { resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). Head("http://example.com") utils.AssertEqual(t, nil, err) @@ -637,7 +637,7 @@ func Test_Request_Head(t *testing.T) { func Test_Request_Put(t *testing.T) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Put("/", func(c fiber.Ctx) error { return c.SendString(c.FormValue("foo")) }) @@ -646,7 +646,7 @@ func Test_Request_Put(t *testing.T) { for i := 0; i < 5; i++ { resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). SetFormData("foo", "bar"). Put("http://example.com") @@ -660,7 +660,7 @@ func Test_Request_Put(t *testing.T) { func Test_Request_Delete(t *testing.T) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Delete("/", func(c fiber.Ctx) error { return c.Status(fiber.StatusNoContent). @@ -671,7 +671,7 @@ func Test_Request_Delete(t *testing.T) { for i := 0; i < 5; i++ { resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). Delete("http://example.com") utils.AssertEqual(t, nil, err) @@ -685,7 +685,7 @@ func Test_Request_Delete(t *testing.T) { func Test_Request_Options(t *testing.T) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Options("/", func(c fiber.Ctx) error { return c.Status(fiber.StatusOK). @@ -696,7 +696,7 @@ func Test_Request_Options(t *testing.T) { for i := 0; i < 5; i++ { resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). Options("http://example.com") utils.AssertEqual(t, nil, err) @@ -710,7 +710,7 @@ func Test_Request_Options(t *testing.T) { func Test_Request_Send(t *testing.T) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Post("/", func(c fiber.Ctx) error { return c.Status(fiber.StatusOK). @@ -721,7 +721,7 @@ func Test_Request_Send(t *testing.T) { for i := 0; i < 5; i++ { resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). SetURL("http://example.com"). SetMethod(fiber.MethodPost). Send() @@ -737,7 +737,7 @@ func Test_Request_Send(t *testing.T) { func Test_Request_Patch(t *testing.T) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Patch("/", func(c fiber.Ctx) error { return c.SendString(c.FormValue("foo")) @@ -747,7 +747,7 @@ func Test_Request_Patch(t *testing.T) { for i := 0; i < 5; i++ { resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). SetFormData("foo", "bar"). Patch("http://example.com") @@ -983,7 +983,7 @@ func Test_Request_Body_With_Server(t *testing.T) { t.Run("multipart form", func(t *testing.T) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Post("/", func(c fiber.Ctx) error { utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) @@ -997,7 +997,7 @@ func Test_Request_Body_With_Server(t *testing.T) { go start() req := AcquireRequest(). - SetClient(client). + SetDial(ln). SetBoundary("myBoundary"). SetFormData("foo", "bar"). AddFiles(AcquireFile( @@ -1019,7 +1019,7 @@ func Test_Request_Body_With_Server(t *testing.T) { t.Run("multipart form send file", func(t *testing.T) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Post("/", func(c fiber.Ctx) error { utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) @@ -1049,7 +1049,7 @@ func Test_Request_Body_With_Server(t *testing.T) { for i := 0; i < 5; i++ { req := AcquireRequest(). - SetClient(client). + SetDial(ln). AddFiles( AcquireFile( SetFileFieldName("field1"), @@ -1072,7 +1072,7 @@ func Test_Request_Body_With_Server(t *testing.T) { t.Run("multipart random boundary", func(t *testing.T) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Post("/", func(c fiber.Ctx) error { reg := regexp.MustCompile(`multipart/form-data; boundary=[\-\w]{35}`) utils.AssertEqual(t, true, reg.MatchString(c.Get(fiber.HeaderContentType))) @@ -1083,7 +1083,7 @@ func Test_Request_Body_With_Server(t *testing.T) { go start() req := AcquireRequest(). - SetClient(client). + SetDial(ln). SetFormData("foo", "bar"). AddFiles(AcquireFile( SetFileName("hello.txt"), @@ -1157,7 +1157,7 @@ func Test_Request_Error_Body_With_Server(t *testing.T) { func Test_Request_Timeout_With_Server(t *testing.T) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Get("/", func(c fiber.Ctx) error { time.Sleep(time.Millisecond * 200) return c.SendString("timeout") @@ -1165,7 +1165,7 @@ func Test_Request_Timeout_With_Server(t *testing.T) { go start() _, err := AcquireRequest(). - SetClient(client). + SetDial(ln). SetTimeout(50 * time.Millisecond). Get("http://example.com") diff --git a/client/response.go b/client/response.go index 4ac064012b..48c722d153 100644 --- a/client/response.go +++ b/client/response.go @@ -65,12 +65,12 @@ func (r *Response) String() string { // JSON method will unmarshal body to json. func (r *Response) JSON(v any) error { - return r.client.core.jsonUnmarshal(r.Body(), v) + return r.client.jsonUnmarshal(r.Body(), v) } // XML method will unmarshal body to xml. func (r *Response) XML(v any) error { - return r.client.core.xmlUnmarshal(r.Body(), v) + return r.client.xmlUnmarshal(r.Body(), v) } // Reset clear Response object. diff --git a/client/response_test.go b/client/response_test.go index 2a83d21540..4cb08e7c89 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -11,7 +11,7 @@ import ( func Test_Response_Status(t *testing.T) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Get("/", func(c fiber.Ctx) error { return c.SendString("foo") }) @@ -24,7 +24,7 @@ func Test_Response_Status(t *testing.T) { t.Parallel() resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). Get("http://example") utils.AssertEqual(t, nil, err) @@ -36,7 +36,7 @@ func Test_Response_Status(t *testing.T) { t.Parallel() resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). Get("http://example/fail") utils.AssertEqual(t, nil, err) @@ -48,7 +48,7 @@ func Test_Response_Status(t *testing.T) { func Test_Response_Status_Code(t *testing.T) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Get("/", func(c fiber.Ctx) error { return c.SendString("foo") }) @@ -61,7 +61,7 @@ func Test_Response_Status_Code(t *testing.T) { t.Parallel() resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). Get("http://example") utils.AssertEqual(t, nil, err) @@ -73,7 +73,7 @@ func Test_Response_Status_Code(t *testing.T) { t.Parallel() resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). Get("http://example/fail") utils.AssertEqual(t, nil, err) @@ -86,14 +86,14 @@ func Test_Response_Protocol(t *testing.T) { t.Parallel() t.Run("http", func(t *testing.T) { - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Get("/", func(c fiber.Ctx) error { return c.SendString("foo") }) go start() resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). Get("http://example") utils.AssertEqual(t, nil, err) @@ -110,7 +110,7 @@ func Test_Response_Protocol(t *testing.T) { func Test_Response_Header(t *testing.T) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Get("/", func(c fiber.Ctx) error { c.Response().Header.Add("foo", "bar") return c.SendString("helo world") @@ -118,7 +118,7 @@ func Test_Response_Header(t *testing.T) { go start() resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). Get("http://example.com") utils.AssertEqual(t, nil, err) @@ -129,7 +129,7 @@ func Test_Response_Header(t *testing.T) { func Test_Response_Cookie(t *testing.T) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Get("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "foo", @@ -140,7 +140,7 @@ func Test_Response_Cookie(t *testing.T) { go start() resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). Get("http://example.com") utils.AssertEqual(t, nil, err) @@ -151,7 +151,7 @@ func Test_Response_Cookie(t *testing.T) { func Test_Response_Body(t *testing.T) { t.Parallel() - app, client, start := createHelperServer(t) + app, ln, start := createHelperServer(t) app.Get("/", func(c fiber.Ctx) error { return c.SendString("hello world") }) @@ -166,7 +166,7 @@ func Test_Response_Body(t *testing.T) { t.Run("raw body", func(t *testing.T) { resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). Get("http://example.com") utils.AssertEqual(t, nil, err) @@ -176,7 +176,7 @@ func Test_Response_Body(t *testing.T) { t.Run("string body", func(t *testing.T) { resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). Get("http://example.com") utils.AssertEqual(t, nil, err) @@ -190,7 +190,7 @@ func Test_Response_Body(t *testing.T) { } resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). Get("http://example.com/json") utils.AssertEqual(t, nil, err) @@ -209,7 +209,7 @@ func Test_Response_Body(t *testing.T) { } resp, err := AcquireRequest(). - SetClient(client). + SetDial(ln). Get("http://example.com/xml") utils.AssertEqual(t, nil, err) From 9156dc56d0d2bd882c07bf122c3c016ef6ef4ef7 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Sun, 21 Aug 2022 17:37:49 +0800 Subject: [PATCH 036/118] =?UTF-8?q?=F0=9F=90=9B=20fix:=20connection=20reus?= =?UTF-8?q?e?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/request.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/client/request.go b/client/request.go index 7350d6a350..0229bd5a36 100644 --- a/client/request.go +++ b/client/request.go @@ -540,6 +540,7 @@ func (r *Request) Send() (*Response, error) { // Reset clear Request object, used by ReleaseRequest method. func (r *Request) Reset() { + r.core = nil r.url = "" r.method = fiber.MethodGet r.userAgent = "" @@ -818,7 +819,6 @@ func (f *File) Reset() { var requestPool = &sync.Pool{ New: func() any { return &Request{ - core: newCore(), header: &Header{RequestHeader: &fasthttp.RequestHeader{}}, params: &QueryParam{Args: fasthttp.AcquireArgs()}, cookies: &Cookie{}, @@ -838,6 +838,7 @@ var requestPool = &sync.Pool{ func AcquireRequest() *Request { req := requestPool.Get().(*Request) req.boundary = "--FiberFormBoundary" + randString(16) + req.core = newCore() return req } From 9d5949184d0e69c8d9217cecfa961d14ed7f98bd Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Sun, 21 Aug 2022 18:48:01 +0800 Subject: [PATCH 037/118] =?UTF-8?q?=F0=9F=90=9B=20fix:=20data=20race?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 2 ++ client/core.go | 58 ++++++++++++++++++++++++++++++++---------------- 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/client/client.go b/client/client.go index 81de5f6d1d..9910781252 100644 --- a/client/client.go +++ b/client/client.go @@ -18,6 +18,8 @@ import ( // Fiber Client also provides an option to override // or merge most of the client settings at the request. type Client struct { + mu sync.RWMutex + baseUrl string userAgent string referer string diff --git a/client/core.go b/client/core.go index 11e363b4ba..5377663bea 100644 --- a/client/core.go +++ b/client/core.go @@ -76,19 +76,29 @@ func (c *core) execFunc(ctx context.Context, client *Client, req *Request) (*Res // execute will exec each hooks and plugins. func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Response, error) { // The built-in hooks will be executed only - // after the user-defined hooks are executed。 - for _, f := range client.userRequestHooks { - err := f(client, req) - if err != nil { - return nil, err + // after the user-defined hooks are executed. + err := func() error { + client.mu.RLock() + defer client.mu.RUnlock() + + for _, f := range client.userRequestHooks { + err := f(client, req) + if err != nil { + return err + } } - } - for _, f := range client.buildinRequestHooks { - err := f(client, req) - if err != nil { - return nil, err + for _, f := range client.buildinRequestHooks { + err := f(client, req) + if err != nil { + return err + } } + + return nil + }() + if err != nil { + return nil, err } // deal with timeout @@ -116,18 +126,28 @@ func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Resp // The built-in hooks will be executed only // before the user-defined hooks are executed. - for _, f := range client.buildinResposeHooks { - err := f(client, resp, req) - if err != nil { - return nil, err + err = func() error { + client.mu.RLock() + defer client.mu.RUnlock() + for _, f := range client.buildinResposeHooks { + err := f(client, resp, req) + if err != nil { + return err + } } - } - for _, f := range client.userResponseHooks { - err := f(client, resp, req) - if err != nil { - return nil, err + for _, f := range client.userResponseHooks { + err := f(client, resp, req) + if err != nil { + return err + } } + + return nil + }() + if err != nil { + resp.Close() + return nil, err } return resp, nil From 3a8f6093a790698b09c29b6dcf92e4bb65bcc504 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Sun, 21 Aug 2022 18:58:43 +0800 Subject: [PATCH 038/118] =?UTF-8?q?=F0=9F=90=9B=20fix:=20data=20race?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 2 +- client/core.go | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/client/client.go b/client/client.go index 9910781252..29bf394b69 100644 --- a/client/client.go +++ b/client/client.go @@ -18,7 +18,7 @@ import ( // Fiber Client also provides an option to override // or merge most of the client settings at the request. type Client struct { - mu sync.RWMutex + mu sync.Mutex baseUrl string userAgent string diff --git a/client/core.go b/client/core.go index 5377663bea..c163efa513 100644 --- a/client/core.go +++ b/client/core.go @@ -78,8 +78,8 @@ func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Resp // The built-in hooks will be executed only // after the user-defined hooks are executed. err := func() error { - client.mu.RLock() - defer client.mu.RUnlock() + client.mu.Lock() + defer client.mu.Unlock() for _, f := range client.userRequestHooks { err := f(client, req) @@ -127,8 +127,8 @@ func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Resp // The built-in hooks will be executed only // before the user-defined hooks are executed. err = func() error { - client.mu.RLock() - defer client.mu.RUnlock() + client.mu.Lock() + defer client.mu.Unlock() for _, f := range client.buildinResposeHooks { err := f(client, resp, req) if err != nil { From 22c7d20792e362d339e4dc10e50eb46e540af963 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Mon, 22 Aug 2022 21:59:06 +0800 Subject: [PATCH 039/118] =?UTF-8?q?=F0=9F=94=80=20fix:=20change=20to=20tes?= =?UTF-8?q?tify?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client_test.go | 33 ++-- client/core_test.go | 44 ++--- client/helper_test.go | 12 +- client/hooks_test.go | 156 +++++++-------- client/request_test.go | 420 ++++++++++++++++++++-------------------- client/response_test.go | 50 ++--- utils/xml_test.go | 10 +- 7 files changed, 367 insertions(+), 358 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index 484af50876..2d21ba088e 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -8,6 +8,7 @@ import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/utils" + "github.com/stretchr/testify/require" "github.com/valyala/fasthttp/fasthttputil" ) @@ -59,15 +60,15 @@ func Test_Get(t *testing.T) { }) go func() { - utils.AssertEqual(t, nil, app.Listener(ln)) + require.Nil(t, app.Listener(ln)) }() t.Run("global get function", func(t *testing.T) { resp, err := Get("http://example.com", SetDial(func(addr string) (net.Conn, error) { return ln.Dial() })) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "example.com", utils.UnsafeString(resp.RawResponse.Body())) + require.NoError(t, err) + require.Equal(t, "example.com", utils.UnsafeString(resp.RawResponse.Body())) }) } @@ -1194,8 +1195,8 @@ func Test_Client_R(t *testing.T) { client := AcquireClient() req := client.R() - utils.AssertEqual(t, "Request", reflect.TypeOf(req).Elem().Name()) - utils.AssertEqual(t, client, req.Client()) + require.Equal(t, "Request", reflect.TypeOf(req).Elem().Name()) + require.Equal(t, client, req.Client()) } func Test_Client_Add_Hook(t *testing.T) { @@ -1206,7 +1207,7 @@ func Test_Client_Add_Hook(t *testing.T) { return nil }) - utils.AssertEqual(t, 1, len(client.RequestHook())) + require.Equal(t, 1, len(client.RequestHook())) client.AddRequestHook(func(c *Client, r *Request) error { return nil @@ -1214,7 +1215,7 @@ func Test_Client_Add_Hook(t *testing.T) { return nil }) - utils.AssertEqual(t, 3, len(client.RequestHook())) + require.Equal(t, 3, len(client.RequestHook())) }) t.Run("add response hooks", func(t *testing.T) { @@ -1222,7 +1223,7 @@ func Test_Client_Add_Hook(t *testing.T) { return nil }) - utils.AssertEqual(t, 1, len(client.ResponseHook())) + require.Equal(t, 1, len(client.ResponseHook())) client.AddResponseHook(func(c *Client, resp *Response, r *Request) error { return nil @@ -1230,7 +1231,7 @@ func Test_Client_Add_Hook(t *testing.T) { return nil }) - utils.AssertEqual(t, 3, len(client.ResponseHook())) + require.Equal(t, 3, len(client.ResponseHook())) }) } @@ -1242,8 +1243,8 @@ func Test_Client_Marshal(t *testing.T) { }) val, err := client.JSONMarshal()(nil) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("hello"), val) + require.NoError(t, err) + require.Equal(t, []byte("hello"), val) }) t.Run("set json unmarshal", func(t *testing.T) { @@ -1253,7 +1254,7 @@ func Test_Client_Marshal(t *testing.T) { }) err := client.JSONUnmarshal()(nil, nil) - utils.AssertEqual(t, fmt.Errorf("empty json"), err) + require.Equal(t, fmt.Errorf("empty json"), err) }) t.Run("set xml marshal", func(t *testing.T) { @@ -1263,8 +1264,8 @@ func Test_Client_Marshal(t *testing.T) { }) val, err := client.XMLMarshal()(nil) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("hello"), val) + require.NoError(t, err) + require.Equal(t, []byte("hello"), val) }) t.Run("set xml unmarshal", func(t *testing.T) { @@ -1274,7 +1275,7 @@ func Test_Client_Marshal(t *testing.T) { }) err := client.XMLUnmarshal()(nil, nil) - utils.AssertEqual(t, fmt.Errorf("empty xml"), err) + require.Equal(t, fmt.Errorf("empty xml"), err) }) } @@ -1283,7 +1284,7 @@ func Test_Client_SetBaseURL(t *testing.T) { client := AcquireClient().SetBaseURL("http://example.com") - utils.AssertEqual(t, "http://example.com", client.BaseURL()) + require.Equal(t, "http://example.com", client.BaseURL()) } func Test_Client_Header(t *testing.T) { diff --git a/client/core_test.go b/client/core_test.go index 344f475dee..29ee2994d2 100644 --- a/client/core_test.go +++ b/client/core_test.go @@ -8,7 +8,7 @@ import ( "time" "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/utils" + "github.com/stretchr/testify/require" "github.com/valyala/fasthttp/fasthttputil" ) @@ -30,7 +30,7 @@ func Test_Exec_Func(t *testing.T) { }) go func() { - utils.AssertEqual(t, nil, app.Listener(ln)) + require.Nil(t, app.Listener(ln)) }() t.Run("normal request", func(t *testing.T) { @@ -40,9 +40,9 @@ func Test_Exec_Func(t *testing.T) { req.RawRequest.SetRequestURI("http://example.com/normal") resp, err := core.execFunc(context.Background(), client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, 200, resp.RawResponse.StatusCode()) - utils.AssertEqual(t, "example.com", string(resp.RawResponse.Body())) + require.NoError(t, err) + require.Equal(t, 200, resp.RawResponse.StatusCode()) + require.Equal(t, "example.com", string(resp.RawResponse.Body())) }) t.Run("the request return an error", func(t *testing.T) { @@ -53,9 +53,9 @@ func Test_Exec_Func(t *testing.T) { resp, err := core.execFunc(context.Background(), client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, 500, resp.RawResponse.StatusCode()) - utils.AssertEqual(t, "the request is error", string(resp.RawResponse.Body())) + require.NoError(t, err) + require.Equal(t, 500, resp.RawResponse.StatusCode()) + require.Equal(t, "the request is error", string(resp.RawResponse.Body())) }) t.Run("the request timeout", func(t *testing.T) { @@ -70,7 +70,7 @@ func Test_Exec_Func(t *testing.T) { _, err := core.execFunc(ctx, client, req) - utils.AssertEqual(t, ErrTimeoutOrCancel, err) + require.Equal(t, ErrTimeoutOrCancel, err) }) } @@ -94,13 +94,13 @@ func Test_Execute(t *testing.T) { }) go func() { - utils.AssertEqual(t, nil, app.Listener(ln)) + require.Nil(t, app.Listener(ln)) }() t.Run("add user request hooks", func(t *testing.T) { client, req := AcquireClient(), AcquireRequest() client.AddRequestHook(func(c *Client, r *Request) error { - utils.AssertEqual(t, "http://example.com", req.URL()) + require.Equal(t, "http://example.com", req.URL()) return nil }) req.SetDial(func(addr string) (net.Conn, error) { @@ -108,14 +108,14 @@ func Test_Execute(t *testing.T) { }).SetURL("http://example.com") resp, err := req.core.execute(context.Background(), client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "Cannot GET /", string(resp.RawResponse.Body())) + require.NoError(t, err) + require.Equal(t, "Cannot GET /", string(resp.RawResponse.Body())) }) t.Run("add user response hooks", func(t *testing.T) { client, req := AcquireClient(), AcquireRequest() client.AddResponseHook(func(c *Client, resp *Response, req *Request) error { - utils.AssertEqual(t, "http://example.com", req.URL()) + require.Equal(t, "http://example.com", req.URL()) return nil }) req.SetDial(func(addr string) (net.Conn, error) { @@ -123,8 +123,8 @@ func Test_Execute(t *testing.T) { }).SetURL("http://example.com") resp, err := req.core.execute(context.Background(), client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "Cannot GET /", string(resp.RawResponse.Body())) + require.NoError(t, err) + require.Equal(t, "Cannot GET /", string(resp.RawResponse.Body())) }) t.Run("no timeout", func(t *testing.T) { @@ -135,8 +135,8 @@ func Test_Execute(t *testing.T) { }).SetURL("http://example.com/hang-up") resp, err := req.core.execute(context.Background(), client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "example.com hang up", string(resp.RawResponse.Body())) + require.NoError(t, err) + require.Equal(t, "example.com hang up", string(resp.RawResponse.Body())) }) t.Run("client timeout", func(t *testing.T) { @@ -147,7 +147,7 @@ func Test_Execute(t *testing.T) { }).SetURL("http://example.com/hang-up") _, err := req.core.execute(context.Background(), client, req) - utils.AssertEqual(t, ErrTimeoutOrCancel, err) + require.Equal(t, ErrTimeoutOrCancel, err) }) t.Run("request timeout", func(t *testing.T) { @@ -159,7 +159,7 @@ func Test_Execute(t *testing.T) { SetTimeout(300 * time.Millisecond) _, err := req.core.execute(context.Background(), client, req) - utils.AssertEqual(t, ErrTimeoutOrCancel, err) + require.Equal(t, ErrTimeoutOrCancel, err) }) t.Run("request timeout has higher level", func(t *testing.T) { @@ -172,7 +172,7 @@ func Test_Execute(t *testing.T) { SetTimeout(3000 * time.Millisecond) resp, err := req.core.execute(context.Background(), client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "example.com hang up", string(resp.RawResponse.Body())) + require.NoError(t, err) + require.Equal(t, "example.com hang up", string(resp.RawResponse.Body())) }) } diff --git a/client/helper_test.go b/client/helper_test.go index b85bc0758f..575943e05d 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/utils" + "github.com/stretchr/testify/require" "github.com/valyala/fasthttp/fasthttputil" ) @@ -18,7 +18,7 @@ func createHelperServer(t *testing.T) (*fiber.App, func(addr string) (net.Conn, return app, func(addr string) (net.Conn, error) { return ln.Dial() }, func() { - utils.AssertEqual(t, nil, app.Listener(ln)) + require.Nil(t, app.Listener(ln)) } } @@ -40,9 +40,9 @@ func testAgent(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Reques resp, err := req.Get("http://example.com") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) - utils.AssertEqual(t, excepted, resp.String()) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, excepted, resp.String()) resp.Close() } } @@ -65,6 +65,6 @@ func testAgentFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Re _, err := req.Get("http://example.com") - utils.AssertEqual(t, excepted.Error(), err.Error()) + require.Equal(t, excepted.Error(), err.Error()) } } diff --git a/client/hooks_test.go b/client/hooks_test.go index 7fc3b43c8d..f755181a3c 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/utils" + "github.com/stretchr/testify/require" ) func Test_AddMissing_Port(t *testing.T) { @@ -46,7 +46,7 @@ func Test_AddMissing_Port(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - utils.AssertEqual(t, tt.want, addMissingPort(tt.args.addr, tt.args.isTLS)) + require.Equal(t, tt.want, addMissingPort(tt.args.addr, tt.args.isTLS)) }) } } @@ -64,7 +64,7 @@ func Test_Rand_String(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := randString(tt.args) - utils.AssertEqual(t, 16, len(got)) + require.Equal(t, 16, len(got)) }) } } @@ -77,8 +77,8 @@ func Test_Parser_Request_URL(t *testing.T) { req := AcquireRequest().SetURL("") err := parserRequestURL(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "http://example.com/api", req.RawRequest.URI().String()) + require.NoError(t, err) + require.Equal(t, "http://example.com/api", req.RawRequest.URI().String()) }) t.Run("request url should be set", func(t *testing.T) { @@ -86,8 +86,8 @@ func Test_Parser_Request_URL(t *testing.T) { req := AcquireRequest().SetURL("http://example.com/api") err := parserRequestURL(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "http://example.com/api", req.RawRequest.URI().String()) + require.NoError(t, err) + require.Equal(t, "http://example.com/api", req.RawRequest.URI().String()) }) t.Run("the request url will override baseurl with protocol", func(t *testing.T) { @@ -95,8 +95,8 @@ func Test_Parser_Request_URL(t *testing.T) { req := AcquireRequest().SetURL("http://example.com/api/v1") err := parserRequestURL(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "http://example.com/api/v1", req.RawRequest.URI().String()) + require.NoError(t, err) + require.Equal(t, "http://example.com/api/v1", req.RawRequest.URI().String()) }) t.Run("the request url should be append after baseurl without protocol", func(t *testing.T) { @@ -104,8 +104,8 @@ func Test_Parser_Request_URL(t *testing.T) { req := AcquireRequest().SetURL("/v1") err := parserRequestURL(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "http://example.com/api/v1", req.RawRequest.URI().String()) + require.NoError(t, err) + require.Equal(t, "http://example.com/api/v1", req.RawRequest.URI().String()) }) t.Run("the url is error", func(t *testing.T) { @@ -113,7 +113,7 @@ func Test_Parser_Request_URL(t *testing.T) { req := AcquireRequest().SetURL("/v1") err := parserRequestURL(client, req) - utils.AssertEqual(t, ErrURLForamt, err) + require.Equal(t, ErrURLForamt, err) }) t.Run("the path param from client", func(t *testing.T) { @@ -123,8 +123,8 @@ func Test_Parser_Request_URL(t *testing.T) { req := AcquireRequest() err := parserRequestURL(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "http://example.com/api/5", req.RawRequest.URI().String()) + require.NoError(t, err) + require.Equal(t, "http://example.com/api/5", req.RawRequest.URI().String()) }) t.Run("the path param from request", func(t *testing.T) { @@ -140,8 +140,8 @@ func Test_Parser_Request_URL(t *testing.T) { DelPathParams("key") err := parserRequestURL(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "http://example.com/api/5/fiber/%7Bkey%7D", req.RawRequest.URI().String()) + require.NoError(t, err) + require.Equal(t, "http://example.com/api/5/fiber/%7Bkey%7D", req.RawRequest.URI().String()) }) t.Run("the path param from request and client", func(t *testing.T) { @@ -157,8 +157,8 @@ func Test_Parser_Request_URL(t *testing.T) { }) err := parserRequestURL(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "http://example.com/api/12/fiber/val", req.RawRequest.URI().String()) + require.NoError(t, err) + require.Equal(t, "http://example.com/api/12/fiber/val", req.RawRequest.URI().String()) }) t.Run("query params from client should be set", func(t *testing.T) { @@ -167,8 +167,8 @@ func Test_Parser_Request_URL(t *testing.T) { req := AcquireRequest().SetURL("http://example.com/api/v1") err := parserRequestURL(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("foo=bar"), req.RawRequest.URI().QueryString()) + require.NoError(t, err) + require.Equal(t, []byte("foo=bar"), req.RawRequest.URI().QueryString()) }) t.Run("query params from request should be set", func(t *testing.T) { @@ -178,8 +178,8 @@ func Test_Parser_Request_URL(t *testing.T) { SetParam("bar", "foo") err := parserRequestURL(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("bar=foo"), req.RawRequest.URI().QueryString()) + require.NoError(t, err) + require.Equal(t, []byte("bar=foo"), req.RawRequest.URI().QueryString()) }) t.Run("query params should be merged", func(t *testing.T) { @@ -190,7 +190,7 @@ func Test_Parser_Request_URL(t *testing.T) { SetParam("bar", "foo") err := parserRequestURL(client, req) - utils.AssertEqual(t, nil, err) + require.NoError(t, err) values, _ := url.ParseQuery(string(req.RawRequest.URI().QueryString())) @@ -204,9 +204,9 @@ func Test_Parser_Request_URL(t *testing.T) { flag3 = true } } - utils.AssertEqual(t, true, flag1) - utils.AssertEqual(t, true, flag2) - utils.AssertEqual(t, true, flag3) + require.True(t, flag1) + require.True(t, flag2) + require.True(t, flag3) }) } @@ -222,8 +222,8 @@ func Test_Parser_Request_Header(t *testing.T) { req := AcquireRequest() err := parserRequestHeader(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("application/json"), req.RawRequest.Header.ContentType()) + require.NoError(t, err) + require.Equal(t, []byte("application/json"), req.RawRequest.Header.ContentType()) }) t.Run("request header should be set", func(t *testing.T) { @@ -235,8 +235,8 @@ func Test_Parser_Request_Header(t *testing.T) { }) err := parserRequestHeader(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("application/json, utf-8"), req.RawRequest.Header.ContentType()) + require.NoError(t, err) + require.Equal(t, []byte("application/json, utf-8"), req.RawRequest.Header.ContentType()) }) t.Run("request header should override client header", func(t *testing.T) { @@ -247,8 +247,8 @@ func Test_Parser_Request_Header(t *testing.T) { SetHeader(fiber.HeaderContentType, "application/json, utf-8") err := parserRequestHeader(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("application/json, utf-8"), req.RawRequest.Header.ContentType()) + require.NoError(t, err) + require.Equal(t, []byte("application/json, utf-8"), req.RawRequest.Header.ContentType()) }) t.Run("auto set json header", func(t *testing.T) { @@ -262,8 +262,8 @@ func Test_Parser_Request_Header(t *testing.T) { }) err := parserRequestHeader(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte(applicationJSON), req.RawRequest.Header.ContentType()) + require.NoError(t, err) + require.Equal(t, []byte(applicationJSON), req.RawRequest.Header.ContentType()) }) t.Run("auto set xml header", func(t *testing.T) { @@ -278,8 +278,8 @@ func Test_Parser_Request_Header(t *testing.T) { }) err := parserRequestHeader(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte(applicationXML), req.RawRequest.Header.ContentType()) + require.NoError(t, err) + require.Equal(t, []byte(applicationXML), req.RawRequest.Header.ContentType()) }) t.Run("auto set form data header", func(t *testing.T) { @@ -291,8 +291,8 @@ func Test_Parser_Request_Header(t *testing.T) { }) err := parserRequestHeader(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, applicationForm, string(req.RawRequest.Header.ContentType())) + require.NoError(t, err) + require.Equal(t, applicationForm, string(req.RawRequest.Header.ContentType())) }) t.Run("auto set file header", func(t *testing.T) { @@ -302,9 +302,9 @@ func Test_Parser_Request_Header(t *testing.T) { SetFormData("foo", "bar") err := parserRequestHeader(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, true, strings.Contains(string(req.RawRequest.Header.MultipartFormBoundary()), "--FiberFormBoundary")) - utils.AssertEqual(t, true, strings.Contains(string(req.RawRequest.Header.ContentType()), multipartFormData)) + require.NoError(t, err) + require.True(t, strings.Contains(string(req.RawRequest.Header.MultipartFormBoundary()), "--FiberFormBoundary")) + require.True(t, strings.Contains(string(req.RawRequest.Header.ContentType()), multipartFormData)) }) t.Run("ua should have default value", func(t *testing.T) { @@ -312,8 +312,8 @@ func Test_Parser_Request_Header(t *testing.T) { req := AcquireRequest() err := parserRequestHeader(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("fiber"), req.RawRequest.Header.UserAgent()) + require.NoError(t, err) + require.Equal(t, []byte("fiber"), req.RawRequest.Header.UserAgent()) }) t.Run("ua in client should be set", func(t *testing.T) { @@ -321,8 +321,8 @@ func Test_Parser_Request_Header(t *testing.T) { req := AcquireRequest() err := parserRequestHeader(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("foo"), req.RawRequest.Header.UserAgent()) + require.NoError(t, err) + require.Equal(t, []byte("foo"), req.RawRequest.Header.UserAgent()) }) t.Run("ua in request should have higher level", func(t *testing.T) { @@ -330,8 +330,8 @@ func Test_Parser_Request_Header(t *testing.T) { req := AcquireRequest().SetUserAgent("bar") err := parserRequestHeader(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("bar"), req.RawRequest.Header.UserAgent()) + require.NoError(t, err) + require.Equal(t, []byte("bar"), req.RawRequest.Header.UserAgent()) }) t.Run("referer in client should be set", func(t *testing.T) { @@ -339,8 +339,8 @@ func Test_Parser_Request_Header(t *testing.T) { req := AcquireRequest() err := parserRequestHeader(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("https://example.com"), req.RawRequest.Header.Referer()) + require.NoError(t, err) + require.Equal(t, []byte("https://example.com"), req.RawRequest.Header.Referer()) }) t.Run("referer in request should have higher level", func(t *testing.T) { @@ -348,8 +348,8 @@ func Test_Parser_Request_Header(t *testing.T) { req := AcquireRequest().SetReferer("https://example.com") err := parserRequestHeader(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("https://example.com"), req.RawRequest.Header.Referer()) + require.NoError(t, err) + require.Equal(t, []byte("https://example.com"), req.RawRequest.Header.Referer()) }) t.Run("client cookie should be set", func(t *testing.T) { @@ -364,10 +364,10 @@ func Test_Parser_Request_Header(t *testing.T) { req := AcquireRequest() err := parserRequestHeader(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "bar", string(req.RawRequest.Header.Cookie("foo"))) - utils.AssertEqual(t, "foo", string(req.RawRequest.Header.Cookie("bar"))) - utils.AssertEqual(t, "", string(req.RawRequest.Header.Cookie("bar1"))) + require.NoError(t, err) + require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo"))) + require.Equal(t, "foo", string(req.RawRequest.Header.Cookie("bar"))) + require.Equal(t, "", string(req.RawRequest.Header.Cookie("bar1"))) }) t.Run("request cookie should be set", func(t *testing.T) { @@ -385,10 +385,10 @@ func Test_Parser_Request_Header(t *testing.T) { }) err := parserRequestHeader(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "bar", string(req.RawRequest.Header.Cookie("foo"))) - utils.AssertEqual(t, "67", string(req.RawRequest.Header.Cookie("bar"))) - utils.AssertEqual(t, "", string(req.RawRequest.Header.Cookie("bar1"))) + require.NoError(t, err) + require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo"))) + require.Equal(t, "67", string(req.RawRequest.Header.Cookie("bar"))) + require.Equal(t, "", string(req.RawRequest.Header.Cookie("bar1"))) }) t.Run("request cookie will override client cookie", func(t *testing.T) { @@ -411,10 +411,10 @@ func Test_Parser_Request_Header(t *testing.T) { }) err := parserRequestHeader(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "bar", string(req.RawRequest.Header.Cookie("foo"))) - utils.AssertEqual(t, "67", string(req.RawRequest.Header.Cookie("bar"))) - utils.AssertEqual(t, "foo1", string(req.RawRequest.Header.Cookie("bar1"))) + require.NoError(t, err) + require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo"))) + require.Equal(t, "67", string(req.RawRequest.Header.Cookie("bar"))) + require.Equal(t, "foo1", string(req.RawRequest.Header.Cookie("bar1"))) }) } @@ -432,8 +432,8 @@ func Test_Parser_Request_Body(t *testing.T) { }) err := parserRequestBody(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("{\"name\":\"foo\"}"), req.RawRequest.Body()) + require.NoError(t, err) + require.Equal(t, []byte("{\"name\":\"foo\"}"), req.RawRequest.Body()) }) t.Run("xml body", func(t *testing.T) { @@ -448,8 +448,8 @@ func Test_Parser_Request_Body(t *testing.T) { }) err := parserRequestBody(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("foo"), req.RawRequest.Body()) + require.NoError(t, err) + require.Equal(t, []byte("foo"), req.RawRequest.Body()) }) t.Run("form data body", func(t *testing.T) { @@ -460,8 +460,8 @@ func Test_Parser_Request_Body(t *testing.T) { }) err := parserRequestBody(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "ball=cricle+and+square", string(req.RawRequest.Body())) + require.NoError(t, err) + require.Equal(t, "ball=cricle+and+square", string(req.RawRequest.Body())) }) t.Run("file body", func(t *testing.T) { @@ -470,9 +470,9 @@ func Test_Parser_Request_Body(t *testing.T) { AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))) err := parserRequestBody(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, true, strings.Contains(string(req.RawRequest.Body()), "----FiberFormBoundary")) - utils.AssertEqual(t, true, strings.Contains(string(req.RawRequest.Body()), "world")) + require.NoError(t, err) + require.True(t, strings.Contains(string(req.RawRequest.Body()), "----FiberFormBoundary")) + require.True(t, strings.Contains(string(req.RawRequest.Body()), "world")) }) t.Run("file and form data", func(t *testing.T) { @@ -482,10 +482,10 @@ func Test_Parser_Request_Body(t *testing.T) { SetFormData("foo", "bar") err := parserRequestBody(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, true, strings.Contains(string(req.RawRequest.Body()), "----FiberFormBoundary")) - utils.AssertEqual(t, true, strings.Contains(string(req.RawRequest.Body()), "world")) - utils.AssertEqual(t, true, strings.Contains(string(req.RawRequest.Body()), "bar")) + require.NoError(t, err) + require.True(t, strings.Contains(string(req.RawRequest.Body()), "----FiberFormBoundary")) + require.True(t, strings.Contains(string(req.RawRequest.Body()), "world")) + require.True(t, strings.Contains(string(req.RawRequest.Body()), "bar")) }) t.Run("raw body", func(t *testing.T) { @@ -494,7 +494,7 @@ func Test_Parser_Request_Body(t *testing.T) { SetRawBody([]byte("hello world")) err := parserRequestBody(client, req) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("hello world"), req.RawRequest.Body()) + require.NoError(t, err) + require.Equal(t, []byte("hello world"), req.RawRequest.Body()) }) } diff --git a/client/request_test.go b/client/request_test.go index 02a19e181b..cef02adc91 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -14,7 +14,7 @@ import ( "time" "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/utils" + "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) @@ -23,16 +23,16 @@ func Test_Request_Method(t *testing.T) { req := AcquireRequest() req.SetMethod("GET") - utils.AssertEqual(t, "GET", req.Method()) + require.Equal(t, "GET", req.Method()) req.SetMethod("POST") - utils.AssertEqual(t, "POST", req.Method()) + require.Equal(t, "POST", req.Method()) req.SetMethod("PUT") - utils.AssertEqual(t, "PUT", req.Method()) + require.Equal(t, "PUT", req.Method()) req.SetMethod("DELETE") - utils.AssertEqual(t, "DELETE", req.Method()) + require.Equal(t, "DELETE", req.Method()) } func Test_Request_URL(t *testing.T) { @@ -41,10 +41,10 @@ func Test_Request_URL(t *testing.T) { req := AcquireRequest() req.SetURL("http://example.com/normal") - utils.AssertEqual(t, "http://example.com/normal", req.URL()) + require.Equal(t, "http://example.com/normal", req.URL()) req.SetURL("https://example.com/normal") - utils.AssertEqual(t, "https://example.com/normal", req.URL()) + require.Equal(t, "https://example.com/normal", req.URL()) } func Test_Request_Client(t *testing.T) { @@ -54,7 +54,7 @@ func Test_Request_Client(t *testing.T) { req := AcquireRequest() req.SetClient(client) - utils.AssertEqual(t, client, req.Client()) + require.Equal(t, client, req.Client()) } func Test_Request_Context(t *testing.T) { @@ -64,13 +64,13 @@ func Test_Request_Context(t *testing.T) { ctx := req.Context() key := struct{}{} - utils.AssertEqual(t, nil, ctx.Value(key)) + require.Nil(t, ctx.Value(key)) ctx = context.WithValue(ctx, key, "string") req.SetContext(ctx) ctx = req.Context() - utils.AssertEqual(t, "string", ctx.Value(key).(string)) + require.Equal(t, "string", ctx.Value(key).(string)) } func Test_Request_Header(t *testing.T) { @@ -81,9 +81,9 @@ func Test_Request_Header(t *testing.T) { req.AddHeader("foo", "bar").AddHeader("foo", "fiber") res := req.Header("foo") - utils.AssertEqual(t, 2, len(res)) - utils.AssertEqual(t, "bar", res[0]) - utils.AssertEqual(t, "fiber", res[1]) + require.Equal(t, 2, len(res)) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) }) t.Run("set header", func(t *testing.T) { @@ -91,8 +91,8 @@ func Test_Request_Header(t *testing.T) { req.AddHeader("foo", "bar").SetHeader("foo", "fiber") res := req.Header("foo") - utils.AssertEqual(t, 1, len(res)) - utils.AssertEqual(t, "fiber", res[0]) + require.Equal(t, 1, len(res)) + require.Equal(t, "fiber", res[0]) }) t.Run("add headers", func(t *testing.T) { @@ -104,14 +104,14 @@ func Test_Request_Header(t *testing.T) { }) res := req.Header("foo") - utils.AssertEqual(t, 3, len(res)) - utils.AssertEqual(t, "bar", res[0]) - utils.AssertEqual(t, "buaa", res[1]) - utils.AssertEqual(t, "fiber", res[2]) + require.Equal(t, 3, len(res)) + require.Equal(t, "bar", res[0]) + require.Equal(t, "buaa", res[1]) + require.Equal(t, "fiber", res[2]) res = req.Header("bar") - utils.AssertEqual(t, 1, len(res)) - utils.AssertEqual(t, "foo", res[0]) + require.Equal(t, 1, len(res)) + require.Equal(t, "foo", res[0]) }) t.Run("set headers", func(t *testing.T) { @@ -123,12 +123,12 @@ func Test_Request_Header(t *testing.T) { }) res := req.Header("foo") - utils.AssertEqual(t, 1, len(res)) - utils.AssertEqual(t, "fiber", res[0]) + require.Equal(t, 1, len(res)) + require.Equal(t, "fiber", res[0]) res = req.Header("bar") - utils.AssertEqual(t, 1, len(res)) - utils.AssertEqual(t, "foo", res[0]) + require.Equal(t, 1, len(res)) + require.Equal(t, "foo", res[0]) }) } @@ -140,9 +140,9 @@ func Test_Request_QueryParam(t *testing.T) { req.AddParam("foo", "bar").AddParam("foo", "fiber") res := req.Param("foo") - utils.AssertEqual(t, 2, len(res)) - utils.AssertEqual(t, "bar", res[0]) - utils.AssertEqual(t, "fiber", res[1]) + require.Equal(t, 2, len(res)) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) }) t.Run("set param", func(t *testing.T) { @@ -150,8 +150,8 @@ func Test_Request_QueryParam(t *testing.T) { req.AddParam("foo", "bar").SetParam("foo", "fiber") res := req.Param("foo") - utils.AssertEqual(t, 1, len(res)) - utils.AssertEqual(t, "fiber", res[0]) + require.Equal(t, 1, len(res)) + require.Equal(t, "fiber", res[0]) }) t.Run("add params", func(t *testing.T) { @@ -163,14 +163,14 @@ func Test_Request_QueryParam(t *testing.T) { }) res := req.Param("foo") - utils.AssertEqual(t, 3, len(res)) - utils.AssertEqual(t, "bar", res[0]) - utils.AssertEqual(t, "buaa", res[1]) - utils.AssertEqual(t, "fiber", res[2]) + require.Equal(t, 3, len(res)) + require.Equal(t, "bar", res[0]) + require.Equal(t, "buaa", res[1]) + require.Equal(t, "fiber", res[2]) res = req.Param("bar") - utils.AssertEqual(t, 1, len(res)) - utils.AssertEqual(t, "foo", res[0]) + require.Equal(t, 1, len(res)) + require.Equal(t, "foo", res[0]) }) t.Run("set headers", func(t *testing.T) { @@ -182,12 +182,12 @@ func Test_Request_QueryParam(t *testing.T) { }) res := req.Param("foo") - utils.AssertEqual(t, 1, len(res)) - utils.AssertEqual(t, "fiber", res[0]) + require.Equal(t, 1, len(res)) + require.Equal(t, "fiber", res[0]) res = req.Param("bar") - utils.AssertEqual(t, 1, len(res)) - utils.AssertEqual(t, "foo", res[0]) + require.Equal(t, 1, len(res)) + require.Equal(t, "foo", res[0]) }) t.Run("set params with struct", func(t *testing.T) { @@ -212,28 +212,28 @@ func Test_Request_QueryParam(t *testing.T) { TIntSlice: []int{1, 2}, }) - utils.AssertEqual(t, 0, len(p.Param("unexport"))) + require.Equal(t, 0, len(p.Param("unexport"))) - utils.AssertEqual(t, 1, len(p.Param("TInt"))) - utils.AssertEqual(t, "5", p.Param("TInt")[0]) + require.Equal(t, 1, len(p.Param("TInt"))) + require.Equal(t, "5", p.Param("TInt")[0]) - utils.AssertEqual(t, 1, len(p.Param("TString"))) - utils.AssertEqual(t, "string", p.Param("TString")[0]) + require.Equal(t, 1, len(p.Param("TString"))) + require.Equal(t, "string", p.Param("TString")[0]) - utils.AssertEqual(t, 1, len(p.Param("TFloat"))) - utils.AssertEqual(t, "3.1", p.Param("TFloat")[0]) + require.Equal(t, 1, len(p.Param("TFloat"))) + require.Equal(t, "3.1", p.Param("TFloat")[0]) - utils.AssertEqual(t, 1, len(p.Param("TBool"))) + require.Equal(t, 1, len(p.Param("TBool"))) tslice := p.Param("TSlice") - utils.AssertEqual(t, 2, len(tslice)) - utils.AssertEqual(t, "bar", tslice[0]) - utils.AssertEqual(t, "foo", tslice[1]) + require.Equal(t, 2, len(tslice)) + require.Equal(t, "bar", tslice[0]) + require.Equal(t, "foo", tslice[1]) tint := p.Param("TSlice") - utils.AssertEqual(t, 2, len(tint)) - utils.AssertEqual(t, "bar", tint[0]) - utils.AssertEqual(t, "foo", tint[1]) + require.Equal(t, 2, len(tint)) + require.Equal(t, "bar", tint[0]) + require.Equal(t, "foo", tint[1]) }) t.Run("del params", func(t *testing.T) { @@ -245,10 +245,10 @@ func Test_Request_QueryParam(t *testing.T) { }).DelParams("foo", "bar") res := req.Param("foo") - utils.AssertEqual(t, 0, len(res)) + require.Equal(t, 0, len(res)) res = req.Param("bar") - utils.AssertEqual(t, 0, len(res)) + require.Equal(t, 0, len(res)) }) } @@ -256,20 +256,20 @@ func Test_Request_UA(t *testing.T) { t.Parallel() req := AcquireRequest().SetUserAgent("fiber") - utils.AssertEqual(t, "fiber", req.UserAgent()) + require.Equal(t, "fiber", req.UserAgent()) req.SetUserAgent("foo") - utils.AssertEqual(t, "foo", req.UserAgent()) + require.Equal(t, "foo", req.UserAgent()) } func Test_Request_Referer(t *testing.T) { t.Parallel() req := AcquireRequest().SetReferer("http://example.com") - utils.AssertEqual(t, "http://example.com", req.Referer()) + require.Equal(t, "http://example.com", req.Referer()) req.SetReferer("https://example.com") - utils.AssertEqual(t, "https://example.com", req.Referer()) + require.Equal(t, "https://example.com", req.Referer()) } func Test_Request_Cookie(t *testing.T) { @@ -278,10 +278,10 @@ func Test_Request_Cookie(t *testing.T) { t.Run("set cookie", func(t *testing.T) { req := AcquireRequest(). SetCookie("foo", "bar") - utils.AssertEqual(t, "bar", req.Cookie("foo")) + require.Equal(t, "bar", req.Cookie("foo")) req.SetCookie("foo", "bar1") - utils.AssertEqual(t, "bar1", req.Cookie("foo")) + require.Equal(t, "bar1", req.Cookie("foo")) }) t.Run("set cookies", func(t *testing.T) { @@ -290,14 +290,14 @@ func Test_Request_Cookie(t *testing.T) { "foo": "bar", "bar": "foo", }) - utils.AssertEqual(t, "bar", req.Cookie("foo")) - utils.AssertEqual(t, "foo", req.Cookie("bar")) + require.Equal(t, "bar", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) req.SetCookies(map[string]string{ "foo": "bar1", }) - utils.AssertEqual(t, "bar1", req.Cookie("foo")) - utils.AssertEqual(t, "foo", req.Cookie("bar")) + require.Equal(t, "bar1", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) }) t.Run("set cookies with struct", func(t *testing.T) { @@ -311,8 +311,8 @@ func Test_Request_Cookie(t *testing.T) { CookieString: "foo", }) - utils.AssertEqual(t, "5", req.Cookie("int")) - utils.AssertEqual(t, "foo", req.Cookie("string")) + require.Equal(t, "5", req.Cookie("int")) + require.Equal(t, "foo", req.Cookie("string")) }) t.Run("del cookies", func(t *testing.T) { @@ -321,12 +321,12 @@ func Test_Request_Cookie(t *testing.T) { "foo": "bar", "bar": "foo", }) - utils.AssertEqual(t, "bar", req.Cookie("foo")) - utils.AssertEqual(t, "foo", req.Cookie("bar")) + require.Equal(t, "bar", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) req.DelCookies("foo") - utils.AssertEqual(t, "", req.Cookie("foo")) - utils.AssertEqual(t, "foo", req.Cookie("bar")) + require.Equal(t, "", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) }) } @@ -336,10 +336,10 @@ func Test_Request_PathParam(t *testing.T) { t.Run("set path param", func(t *testing.T) { req := AcquireRequest(). SetPathParam("foo", "bar") - utils.AssertEqual(t, "bar", req.PathParam("foo")) + require.Equal(t, "bar", req.PathParam("foo")) req.SetPathParam("foo", "bar1") - utils.AssertEqual(t, "bar1", req.PathParam("foo")) + require.Equal(t, "bar1", req.PathParam("foo")) }) t.Run("set path params", func(t *testing.T) { @@ -348,14 +348,14 @@ func Test_Request_PathParam(t *testing.T) { "foo": "bar", "bar": "foo", }) - utils.AssertEqual(t, "bar", req.PathParam("foo")) - utils.AssertEqual(t, "foo", req.PathParam("bar")) + require.Equal(t, "bar", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) req.SetPathParams(map[string]string{ "foo": "bar1", }) - utils.AssertEqual(t, "bar1", req.PathParam("foo")) - utils.AssertEqual(t, "foo", req.PathParam("bar")) + require.Equal(t, "bar1", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) }) t.Run("set path params with struct", func(t *testing.T) { @@ -369,8 +369,8 @@ func Test_Request_PathParam(t *testing.T) { CookieString: "foo", }) - utils.AssertEqual(t, "5", req.PathParam("int")) - utils.AssertEqual(t, "foo", req.PathParam("string")) + require.Equal(t, "5", req.PathParam("int")) + require.Equal(t, "foo", req.PathParam("string")) }) t.Run("del path params", func(t *testing.T) { @@ -379,12 +379,12 @@ func Test_Request_PathParam(t *testing.T) { "foo": "bar", "bar": "foo", }) - utils.AssertEqual(t, "bar", req.PathParam("foo")) - utils.AssertEqual(t, "foo", req.PathParam("bar")) + require.Equal(t, "bar", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) req.DelPathParams("foo") - utils.AssertEqual(t, "", req.PathParam("foo")) - utils.AssertEqual(t, "foo", req.PathParam("bar")) + require.Equal(t, "", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) }) } @@ -396,9 +396,9 @@ func Test_Request_FormData(t *testing.T) { req.AddFormData("foo", "bar").AddFormData("foo", "fiber") res := req.FormData("foo") - utils.AssertEqual(t, 2, len(res)) - utils.AssertEqual(t, "bar", res[0]) - utils.AssertEqual(t, "fiber", res[1]) + require.Equal(t, 2, len(res)) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) }) t.Run("set param", func(t *testing.T) { @@ -406,8 +406,8 @@ func Test_Request_FormData(t *testing.T) { req.AddFormData("foo", "bar").SetFormData("foo", "fiber") res := req.FormData("foo") - utils.AssertEqual(t, 1, len(res)) - utils.AssertEqual(t, "fiber", res[0]) + require.Equal(t, 1, len(res)) + require.Equal(t, "fiber", res[0]) }) t.Run("add params", func(t *testing.T) { @@ -419,14 +419,14 @@ func Test_Request_FormData(t *testing.T) { }) res := req.FormData("foo") - utils.AssertEqual(t, 3, len(res)) - utils.AssertEqual(t, "bar", res[0]) - utils.AssertEqual(t, "buaa", res[1]) - utils.AssertEqual(t, "fiber", res[2]) + require.Equal(t, 3, len(res)) + require.Equal(t, "bar", res[0]) + require.Equal(t, "buaa", res[1]) + require.Equal(t, "fiber", res[2]) res = req.FormData("bar") - utils.AssertEqual(t, 1, len(res)) - utils.AssertEqual(t, "foo", res[0]) + require.Equal(t, 1, len(res)) + require.Equal(t, "foo", res[0]) }) t.Run("set headers", func(t *testing.T) { @@ -438,12 +438,12 @@ func Test_Request_FormData(t *testing.T) { }) res := req.FormData("foo") - utils.AssertEqual(t, 1, len(res)) - utils.AssertEqual(t, "fiber", res[0]) + require.Equal(t, 1, len(res)) + require.Equal(t, "fiber", res[0]) res = req.FormData("bar") - utils.AssertEqual(t, 1, len(res)) - utils.AssertEqual(t, "foo", res[0]) + require.Equal(t, 1, len(res)) + require.Equal(t, "foo", res[0]) }) t.Run("set params with struct", func(t *testing.T) { @@ -468,28 +468,28 @@ func Test_Request_FormData(t *testing.T) { TIntSlice: []int{1, 2}, }) - utils.AssertEqual(t, 0, len(p.FormData("unexport"))) + require.Equal(t, 0, len(p.FormData("unexport"))) - utils.AssertEqual(t, 1, len(p.FormData("TInt"))) - utils.AssertEqual(t, "5", p.FormData("TInt")[0]) + require.Equal(t, 1, len(p.FormData("TInt"))) + require.Equal(t, "5", p.FormData("TInt")[0]) - utils.AssertEqual(t, 1, len(p.FormData("TString"))) - utils.AssertEqual(t, "string", p.FormData("TString")[0]) + require.Equal(t, 1, len(p.FormData("TString"))) + require.Equal(t, "string", p.FormData("TString")[0]) - utils.AssertEqual(t, 1, len(p.FormData("TFloat"))) - utils.AssertEqual(t, "3.1", p.FormData("TFloat")[0]) + require.Equal(t, 1, len(p.FormData("TFloat"))) + require.Equal(t, "3.1", p.FormData("TFloat")[0]) - utils.AssertEqual(t, 1, len(p.FormData("TBool"))) + require.Equal(t, 1, len(p.FormData("TBool"))) tslice := p.FormData("TSlice") - utils.AssertEqual(t, 2, len(tslice)) - utils.AssertEqual(t, "bar", tslice[0]) - utils.AssertEqual(t, "foo", tslice[1]) + require.Equal(t, 2, len(tslice)) + require.Equal(t, "bar", tslice[0]) + require.Equal(t, "foo", tslice[1]) tint := p.FormData("TSlice") - utils.AssertEqual(t, 2, len(tint)) - utils.AssertEqual(t, "bar", tint[0]) - utils.AssertEqual(t, "foo", tint[1]) + require.Equal(t, 2, len(tint)) + require.Equal(t, "bar", tint[0]) + require.Equal(t, "foo", tint[1]) }) @@ -502,10 +502,10 @@ func Test_Request_FormData(t *testing.T) { }).DelFormDatas("foo", "bar") res := req.FormData("foo") - utils.AssertEqual(t, 0, len(res)) + require.Equal(t, 0, len(res)) res = req.FormData("bar") - utils.AssertEqual(t, 0, len(res)) + require.Equal(t, 0, len(res)) }) } @@ -517,28 +517,28 @@ func Test_Request_File(t *testing.T) { AddFile("../.github/index.html"). AddFiles(AcquireFile(SetFileName("tmp.txt"))) - utils.AssertEqual(t, "../.github/index.html", req.File("index.html").path) - utils.AssertEqual(t, "../.github/index.html", req.FileByPath("../.github/index.html").path) - utils.AssertEqual(t, "tmp.txt", req.File("tmp.txt").name) + require.Equal(t, "../.github/index.html", req.File("index.html").path) + require.Equal(t, "../.github/index.html", req.FileByPath("../.github/index.html").path) + require.Equal(t, "tmp.txt", req.File("tmp.txt").name) }) t.Run("add file by reader", func(t *testing.T) { req := AcquireRequest(). AddFileWithReader("tmp.txt", io.NopCloser(strings.NewReader("world"))) - utils.AssertEqual(t, "tmp.txt", req.File("tmp.txt").name) + require.Equal(t, "tmp.txt", req.File("tmp.txt").name) content, err := io.ReadAll(req.File("tmp.txt").reader) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "world", string(content)) + require.NoError(t, err) + require.Equal(t, "world", string(content)) }) t.Run("add files", func(t *testing.T) { req := AcquireRequest(). AddFiles(AcquireFile(SetFileName("tmp.txt")), AcquireFile(SetFileName("foo.txt"))) - utils.AssertEqual(t, "tmp.txt", req.File("tmp.txt").name) - utils.AssertEqual(t, "foo.txt", req.File("foo.txt").name) + require.Equal(t, "tmp.txt", req.File("tmp.txt").name) + require.Equal(t, "foo.txt", req.File("foo.txt").name) }) } @@ -547,7 +547,7 @@ func Test_Request_Timeout(t *testing.T) { req := AcquireRequest().SetTimeout(5 * time.Second) - utils.AssertEqual(t, 5*time.Second, req.Timeout()) + require.Equal(t, 5*time.Second, req.Timeout()) } func Test_Request_Invalid_URL(t *testing.T) { @@ -556,8 +556,8 @@ func Test_Request_Invalid_URL(t *testing.T) { resp, err := AcquireRequest(). Get("http://example.com\r\n\r\nGET /\r\n\r\n") - utils.AssertEqual(t, ErrURLForamt, err) - utils.AssertEqual(t, (*Response)(nil), resp) + require.Equal(t, ErrURLForamt, err) + require.Equal(t, (*Response)(nil), resp) } func Test_Request_Unsupport_Protocol(t *testing.T) { @@ -565,8 +565,8 @@ func Test_Request_Unsupport_Protocol(t *testing.T) { resp, err := AcquireRequest(). Get("ftp://example.com") - utils.AssertEqual(t, ErrURLForamt, err) - utils.AssertEqual(t, (*Response)(nil), resp) + require.Equal(t, ErrURLForamt, err) + require.Equal(t, (*Response)(nil), resp) } func Test_Request_Get(t *testing.T) { @@ -582,9 +582,9 @@ func Test_Request_Get(t *testing.T) { req := AcquireRequest().SetDial(ln) resp, err := req.Get("http://example.com") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) - utils.AssertEqual(t, "example.com", resp.String()) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "example.com", resp.String()) resp.Close() } } @@ -605,9 +605,9 @@ func Test_Request_Post(t *testing.T) { SetFormData("foo", "bar"). Post("http://example.com") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusCreated, resp.StatusCode()) - utils.AssertEqual(t, "bar", resp.String()) + require.NoError(t, err) + require.Equal(t, fiber.StatusCreated, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) resp.Close() } } @@ -627,9 +627,9 @@ func Test_Request_Head(t *testing.T) { SetDial(ln). Head("http://example.com") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) - utils.AssertEqual(t, "", resp.String()) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "", resp.String()) resp.Close() } } @@ -650,9 +650,9 @@ func Test_Request_Put(t *testing.T) { SetFormData("foo", "bar"). Put("http://example.com") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) - utils.AssertEqual(t, "bar", resp.String()) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) resp.Close() } @@ -674,9 +674,9 @@ func Test_Request_Delete(t *testing.T) { SetDial(ln). Delete("http://example.com") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusNoContent, resp.StatusCode()) - utils.AssertEqual(t, "", resp.String()) + require.NoError(t, err) + require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) + require.Equal(t, "", resp.String()) resp.Close() } @@ -699,9 +699,9 @@ func Test_Request_Options(t *testing.T) { SetDial(ln). Options("http://example.com") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) - utils.AssertEqual(t, "options", resp.String()) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "options", resp.String()) resp.Close() } @@ -726,9 +726,9 @@ func Test_Request_Send(t *testing.T) { SetMethod(fiber.MethodPost). Send() - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) - utils.AssertEqual(t, "post", resp.String()) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "post", resp.String()) resp.Close() } @@ -751,9 +751,9 @@ func Test_Request_Patch(t *testing.T) { SetFormData("foo", "bar"). Patch("http://example.com") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) - utils.AssertEqual(t, "bar", resp.String()) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) resp.Close() } @@ -914,18 +914,18 @@ func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) { t.Helper() basename := filepath.Base(filename) - utils.AssertEqual(t, fh.Filename, basename) + require.Equal(t, fh.Filename, basename) b1, err := os.ReadFile(filename) - utils.AssertEqual(t, nil, err) + require.NoError(t, err) b2 := make([]byte, fh.Size) f, err := fh.Open() - utils.AssertEqual(t, nil, err) + require.NoError(t, err) defer func() { _ = f.Close() }() _, err = f.Read(b2) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, b1, b2) + require.NoError(t, err) + require.Equal(t, b1, b2) } func Test_Request_Body_With_Server(t *testing.T) { @@ -934,7 +934,7 @@ func Test_Request_Body_With_Server(t *testing.T) { t.Run("json body", func(t *testing.T) { testAgent(t, func(c fiber.Ctx) error { - utils.AssertEqual(t, "application/json", string(c.Request().Header.ContentType())) + require.Equal(t, "application/json", string(c.Request().Header.ContentType())) return c.SendString(string(c.Request().Body())) }, func(agent *Request) { @@ -949,7 +949,7 @@ func Test_Request_Body_With_Server(t *testing.T) { t.Run("xml body", func(t *testing.T) { testAgent(t, func(c fiber.Ctx) error { - utils.AssertEqual(t, "application/xml", string(c.Request().Header.ContentType())) + require.Equal(t, "application/xml", string(c.Request().Header.ContentType())) return c.SendString(string(c.Request().Body())) }, func(agent *Request) { @@ -967,7 +967,7 @@ func Test_Request_Body_With_Server(t *testing.T) { t.Run("formdata", func(t *testing.T) { testAgent(t, func(c fiber.Ctx) error { - utils.AssertEqual(t, fiber.MIMEApplicationForm, string(c.Request().Header.ContentType())) + require.Equal(t, fiber.MIMEApplicationForm, string(c.Request().Header.ContentType())) return c.Send([]byte("foo=" + c.FormValue("foo") + "&bar=" + c.FormValue("bar") + "&fiber=" + c.FormValue("fiber"))) }, func(agent *Request) { @@ -985,11 +985,11 @@ func Test_Request_Body_With_Server(t *testing.T) { app, ln, start := createHelperServer(t) app.Post("/", func(c fiber.Ctx) error { - utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) + require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) mf, err := c.MultipartForm() - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "bar", mf.Value["foo"][0]) + require.NoError(t, err) + require.Equal(t, "bar", mf.Value["foo"][0]) return c.Send(c.Request().Body()) }) @@ -1007,12 +1007,12 @@ func Test_Request_Body_With_Server(t *testing.T) { )) resp, err := req.Post("http://exmaple.com") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) form, err := multipart.NewReader(bytes.NewReader(resp.Body()), "myBoundary").ReadForm(1024 * 1024) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "bar", form.Value["foo"][0]) + require.NoError(t, err) + require.Equal(t, "bar", form.Value["foo"][0]) resp.Close() }) @@ -1021,25 +1021,25 @@ func Test_Request_Body_With_Server(t *testing.T) { app, ln, start := createHelperServer(t) app.Post("/", func(c fiber.Ctx) error { - utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) + require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) fh1, err := c.FormFile("field1") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fh1.Filename, "name") + require.NoError(t, err) + require.Equal(t, fh1.Filename, "name") buf := make([]byte, fh1.Size) f, err := fh1.Open() - utils.AssertEqual(t, nil, err) + require.NoError(t, err) defer func() { _ = f.Close() }() _, err = f.Read(buf) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "form file", string(buf)) + require.NoError(t, err) + require.Equal(t, "form file", string(buf)) fh2, err := c.FormFile("file2") - utils.AssertEqual(t, nil, err) + require.NoError(t, err) checkFormFile(t, fh2, "../.github/testdata/index.html") fh3, err := c.FormFile("file3") - utils.AssertEqual(t, nil, err) + require.NoError(t, err) checkFormFile(t, fh3, "../.github/testdata/index.tmpl") return c.SendString("multipart form files") @@ -1062,8 +1062,8 @@ func Test_Request_Body_With_Server(t *testing.T) { SetBoundary("myBoundary") resp, err := req.Post("http://example.com") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "multipart form files", resp.String()) + require.NoError(t, err) + require.Equal(t, "multipart form files", resp.String()) resp.Close() } @@ -1075,7 +1075,7 @@ func Test_Request_Body_With_Server(t *testing.T) { app, ln, start := createHelperServer(t) app.Post("/", func(c fiber.Ctx) error { reg := regexp.MustCompile(`multipart/form-data; boundary=[\-\w]{35}`) - utils.AssertEqual(t, true, reg.MatchString(c.Get(fiber.HeaderContentType))) + require.True(t, reg.MatchString(c.Get(fiber.HeaderContentType))) return c.Send(c.Request().Body()) }) @@ -1092,8 +1092,8 @@ func Test_Request_Body_With_Server(t *testing.T) { )) resp, err := req.Post("http://exmaple.com") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode()) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) }) t.Run("raw body", func(t *testing.T) { @@ -1141,7 +1141,7 @@ func Test_Request_Error_Body_With_Server(t *testing.T) { SetBoundary("*"). AddFileWithReader("t.txt", io.NopCloser(strings.NewReader("world"))). Get("http://example.com") - utils.AssertEqual(t, "mime: invalid boundary character", err.Error()) + require.Equal(t, "mime: invalid boundary character", err.Error()) }) t.Run("open non exist file", func(t *testing.T) { @@ -1150,7 +1150,7 @@ func Test_Request_Error_Body_With_Server(t *testing.T) { _, err := AcquireRequest(). AddFile("non-exist-file!"). Get("http://example.com") - utils.AssertEqual(t, "open non-exist-file!: no such file or directory", err.Error()) + require.Equal(t, "open non-exist-file!: no such file or directory", err.Error()) }) } @@ -1169,7 +1169,7 @@ func Test_Request_Timeout_With_Server(t *testing.T) { SetTimeout(50 * time.Millisecond). Get("http://example.com") - utils.AssertEqual(t, ErrTimeoutOrCancel, err) + require.Equal(t, ErrTimeoutOrCancel, err) } // // readErrorConn is a struct for testing retryIf @@ -1494,12 +1494,12 @@ func Test_SetValWithStruct(t *testing.T) { TIntSlice: []int{1, 2}, }) - utils.AssertEqual(t, "", string(p.Peek("unexport"))) - utils.AssertEqual(t, []byte("5"), p.Peek("TInt")) - utils.AssertEqual(t, []byte("string"), p.Peek("TString")) - utils.AssertEqual(t, []byte("3.1"), p.Peek("TFloat")) - utils.AssertEqual(t, "", string(p.Peek("TBool"))) - utils.AssertEqual(t, true, func() bool { + require.Equal(t, "", string(p.Peek("unexport"))) + require.Equal(t, []byte("5"), p.Peek("TInt")) + require.Equal(t, []byte("string"), p.Peek("TString")) + require.Equal(t, []byte("3.1"), p.Peek("TFloat")) + require.Equal(t, "", string(p.Peek("TBool"))) + require.True(t, func() bool { for _, v := range p.PeekMulti("TSlice") { if string(v) == "foo" { return true @@ -1507,7 +1507,8 @@ func Test_SetValWithStruct(t *testing.T) { } return false }()) - utils.AssertEqual(t, true, func() bool { + + require.True(t, func() bool { for _, v := range p.PeekMulti("TSlice") { if string(v) == "bar" { return true @@ -1515,7 +1516,8 @@ func Test_SetValWithStruct(t *testing.T) { } return false }()) - utils.AssertEqual(t, true, func() bool { + + require.True(t, func() bool { for _, v := range p.PeekMulti("int_slice") { if string(v) == "1" { return true @@ -1523,7 +1525,8 @@ func Test_SetValWithStruct(t *testing.T) { } return false }()) - utils.AssertEqual(t, true, func() bool { + + require.True(t, func() bool { for _, v := range p.PeekMulti("int_slice") { if string(v) == "2" { return true @@ -1531,6 +1534,7 @@ func Test_SetValWithStruct(t *testing.T) { } return false }()) + }) t.Run("the pointer of a struct should be applied", func(t *testing.T) { @@ -1547,11 +1551,11 @@ func Test_SetValWithStruct(t *testing.T) { TIntSlice: []int{1, 2}, }) - utils.AssertEqual(t, []byte("5"), p.Peek("TInt")) - utils.AssertEqual(t, []byte("string"), p.Peek("TString")) - utils.AssertEqual(t, []byte("3.1"), p.Peek("TFloat")) - utils.AssertEqual(t, "true", string(p.Peek("TBool"))) - utils.AssertEqual(t, true, func() bool { + require.Equal(t, []byte("5"), p.Peek("TInt")) + require.Equal(t, []byte("string"), p.Peek("TString")) + require.Equal(t, []byte("3.1"), p.Peek("TFloat")) + require.Equal(t, "true", string(p.Peek("TBool"))) + require.True(t, func() bool { for _, v := range p.PeekMulti("TSlice") { if string(v) == "foo" { return true @@ -1559,7 +1563,8 @@ func Test_SetValWithStruct(t *testing.T) { } return false }()) - utils.AssertEqual(t, true, func() bool { + + require.True(t, func() bool { for _, v := range p.PeekMulti("TSlice") { if string(v) == "bar" { return true @@ -1567,7 +1572,8 @@ func Test_SetValWithStruct(t *testing.T) { } return false }()) - utils.AssertEqual(t, true, func() bool { + + require.True(t, func() bool { for _, v := range p.PeekMulti("int_slice") { if string(v) == "1" { return true @@ -1575,7 +1581,8 @@ func Test_SetValWithStruct(t *testing.T) { } return false }()) - utils.AssertEqual(t, true, func() bool { + + require.True(t, func() bool { for _, v := range p.PeekMulti("int_slice") { if string(v) == "2" { return true @@ -1583,6 +1590,7 @@ func Test_SetValWithStruct(t *testing.T) { } return false }()) + }) t.Run("the zero val should be ignore", func(t *testing.T) { @@ -1595,11 +1603,11 @@ func Test_SetValWithStruct(t *testing.T) { TFloat: 0.0, }) - utils.AssertEqual(t, "", string(p.Peek("TInt"))) - utils.AssertEqual(t, "", string(p.Peek("TString"))) - utils.AssertEqual(t, "", string(p.Peek("TFloat"))) - utils.AssertEqual(t, 0, len(p.PeekMulti("TSlice"))) - utils.AssertEqual(t, 0, len(p.PeekMulti("int_slice"))) + require.Equal(t, "", string(p.Peek("TInt"))) + require.Equal(t, "", string(p.Peek("TString"))) + require.Equal(t, "", string(p.Peek("TFloat"))) + require.Equal(t, 0, len(p.PeekMulti("TSlice"))) + require.Equal(t, 0, len(p.PeekMulti("int_slice"))) }) t.Run("error type should ignore", func(t *testing.T) { @@ -1607,6 +1615,6 @@ func Test_SetValWithStruct(t *testing.T) { Args: fasthttp.AcquireArgs(), } SetValWithStruct(p, "param", 5) - utils.AssertEqual(t, 0, p.Len()) + require.Equal(t, 0, p.Len()) }) } diff --git a/client/response_test.go b/client/response_test.go index 4cb08e7c89..503a6168f2 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/utils" + "github.com/stretchr/testify/require" ) func Test_Response_Status(t *testing.T) { @@ -27,8 +27,8 @@ func Test_Response_Status(t *testing.T) { SetDial(ln). Get("http://example") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "OK", resp.Status()) + require.NoError(t, err) + require.Equal(t, "OK", resp.Status()) resp.Close() }) @@ -39,8 +39,8 @@ func Test_Response_Status(t *testing.T) { SetDial(ln). Get("http://example/fail") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "Proxy Authentication Required", resp.Status()) + require.NoError(t, err) + require.Equal(t, "Proxy Authentication Required", resp.Status()) resp.Close() }) } @@ -64,8 +64,8 @@ func Test_Response_Status_Code(t *testing.T) { SetDial(ln). Get("http://example") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, 200, resp.StatusCode()) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode()) resp.Close() }) @@ -76,8 +76,8 @@ func Test_Response_Status_Code(t *testing.T) { SetDial(ln). Get("http://example/fail") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, 407, resp.StatusCode()) + require.NoError(t, err) + require.Equal(t, 407, resp.StatusCode()) resp.Close() }) } @@ -96,8 +96,8 @@ func Test_Response_Protocol(t *testing.T) { SetDial(ln). Get("http://example") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "HTTP/1.1", resp.Protocol()) + require.NoError(t, err) + require.Equal(t, "HTTP/1.1", resp.Protocol()) resp.Close() }) @@ -121,8 +121,8 @@ func Test_Response_Header(t *testing.T) { SetDial(ln). Get("http://example.com") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "bar", resp.Header("foo")) + require.NoError(t, err) + require.Equal(t, "bar", resp.Header("foo")) resp.Close() } @@ -143,8 +143,8 @@ func Test_Response_Cookie(t *testing.T) { SetDial(ln). Get("http://example.com") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "bar", string(resp.Cookies()[0].Value())) + require.NoError(t, err) + require.Equal(t, "bar", string(resp.Cookies()[0].Value())) resp.Close() } @@ -169,8 +169,8 @@ func Test_Response_Body(t *testing.T) { SetDial(ln). Get("http://example.com") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, []byte("hello world"), resp.Body()) + require.NoError(t, err) + require.Equal(t, []byte("hello world"), resp.Body()) resp.Close() }) @@ -179,8 +179,8 @@ func Test_Response_Body(t *testing.T) { SetDial(ln). Get("http://example.com") - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "hello world", resp.String()) + require.NoError(t, err) + require.Equal(t, "hello world", resp.String()) resp.Close() }) @@ -193,12 +193,12 @@ func Test_Response_Body(t *testing.T) { SetDial(ln). Get("http://example.com/json") - utils.AssertEqual(t, nil, err) + require.NoError(t, err) tmp := &body{} err = resp.JSON(tmp) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "success", tmp.Status) + require.NoError(t, err) + require.Equal(t, "success", tmp.Status) resp.Close() }) @@ -212,12 +212,12 @@ func Test_Response_Body(t *testing.T) { SetDial(ln). Get("http://example.com/xml") - utils.AssertEqual(t, nil, err) + require.NoError(t, err) tmp := &body{} err = resp.XML(tmp) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, "success", tmp.Status) + require.NoError(t, err) + require.Equal(t, "success", tmp.Status) resp.Close() }) } diff --git a/utils/xml_test.go b/utils/xml_test.go index 6aaba0deea..ff5575326c 100644 --- a/utils/xml_test.go +++ b/utils/xml_test.go @@ -70,9 +70,9 @@ func Test_DefaultXMLDecoder(t *testing.T) { ) err := xmlDecoder(xmlBytes, &ss) - AssertEqual(t, err, nil) - AssertEqual(t, len(ss.Servers), 2) - AssertEqual(t, ss.Version, "1") - AssertEqual(t, ss.Servers[0].Name, "fiber one") - AssertEqual(t, ss.Servers[1].Name, "fiber two") + require.Nil(t, err) + require.Equal(t, 2, len(ss.Servers)) + require.Equal(t, "1", ss.Version) + require.Equal(t, "fiber one", ss.Servers[0].Name) + require.Equal(t, "fiber two", ss.Servers[1].Name) } From 5d05cce223ea7cc341e8a09d16944e9b9de0d59c Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Mon, 22 Aug 2022 22:10:23 +0800 Subject: [PATCH 040/118] =?UTF-8?q?=E2=9C=85=20fix:=20fail=20test=20in=20w?= =?UTF-8?q?indows?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/request_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/request_test.go b/client/request_test.go index cef02adc91..9d32e50a95 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -1150,7 +1150,7 @@ func Test_Request_Error_Body_With_Server(t *testing.T) { _, err := AcquireRequest(). AddFile("non-exist-file!"). Get("http://example.com") - require.Equal(t, "open non-exist-file!: no such file or directory", err.Error()) + require.Contains(t, err.Error(), "open non-exist-file!") }) } From 484be572634d74a8f0ce4d80166fabe45eb85c03 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Tue, 23 Aug 2022 14:41:24 +0800 Subject: [PATCH 041/118] =?UTF-8?q?=E2=9C=A8=20feat:=20response=20body=20s?= =?UTF-8?q?ave=20to=20file?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/core.go | 11 +++---- client/response.go | 51 ++++++++++++++++++++++++++++++++ client/response_test.go | 64 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 121 insertions(+), 5 deletions(-) diff --git a/client/core.go b/client/core.go index c163efa513..9b8d65eaaa 100644 --- a/client/core.go +++ b/client/core.go @@ -184,9 +184,10 @@ func newCore() (c *core) { } var ( - ErrTimeoutOrCancel = errors.New("timeout or cancel") - ErrURLForamt = errors.New("the url is a mistake") - ErrNotSupportSchema = errors.New("the protocol is not support, only http or https") - ErrFileNoName = errors.New("the file should have name") - ErrBodyType = errors.New("the body type should be []byte") + ErrTimeoutOrCancel = errors.New("timeout or cancel") + ErrURLForamt = errors.New("the url is a mistake") + ErrNotSupportSchema = errors.New("the protocol is not support, only http or https") + ErrFileNoName = errors.New("the file should have name") + ErrBodyType = errors.New("the body type should be []byte") + ErrNotSupportSaveMethod = errors.New("file path and io.Writer are supported") ) diff --git a/client/response.go b/client/response.go index 48c722d153..23de97c1ca 100644 --- a/client/response.go +++ b/client/response.go @@ -1,6 +1,10 @@ package client import ( + "bytes" + "io" + "os" + "path/filepath" "strings" "sync" @@ -73,6 +77,53 @@ func (r *Response) XML(v any) error { return r.client.xmlUnmarshal(r.Body(), v) } +func (r *Response) Save(v any) error { + switch p := v.(type) { + case string: + file := filepath.Clean(p) + dir := filepath.Dir(file) + + // create director + if _, err := os.Stat(dir); err != nil { + if !os.IsNotExist(err) { + return err + } + + if err = os.MkdirAll(dir, 0755); err != nil { + return err + } + } + + // create file + outFile, err := os.Create(file) + if err != nil { + return err + } + defer func() { outFile.Close() }() + + _, err = io.Copy(outFile, bytes.NewReader(r.Body())) + if err != nil { + return err + } + + return nil + case io.Writer: + _, err := io.Copy(p, bytes.NewReader(r.Body())) + if err != nil { + return err + } + defer func() { + if pc, ok := p.(io.WriteCloser); ok { + pc.Close() + } + }() + + return nil + default: + return ErrNotSupportSaveMethod + } +} + // Reset clear Response object. func (r *Response) Reset() { r.client = nil diff --git a/client/response_test.go b/client/response_test.go index 503a6168f2..e1d24503fa 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -1,7 +1,10 @@ package client import ( + "bytes" "encoding/xml" + "io" + "os" "testing" "github.com/gofiber/fiber/v3" @@ -221,3 +224,64 @@ func Test_Response_Body(t *testing.T) { resp.Close() }) } + +func Test_Response_Save(t *testing.T) { + + app, ln, start := createHelperServer(t) + app.Get("/json", func(c fiber.Ctx) error { + return c.SendString("{\"status\":\"success\"}") + }) + + go start() + + t.Run("file path", func(t *testing.T) { + resp, err := AcquireRequest(). + SetDial(ln). + Get("http://example.com/json") + + require.NoError(t, err) + + err = resp.Save("./test/tmp.json") + require.NoError(t, err) + defer func() { + if _, err := os.Stat("./test/tmp.json"); err != nil { + return + } + + os.RemoveAll("./test") + }() + + file, err := os.Open("./test/tmp.json") + require.NoError(t, err) + + data, err := io.ReadAll(file) + require.NoError(t, err) + require.Equal(t, "{\"status\":\"success\"}", string(data)) + file.Close() + }) + + t.Run("io.Writer", func(t *testing.T) { + resp, err := AcquireRequest(). + SetDial(ln). + Get("http://example.com/json") + + require.NoError(t, err) + + buf := &bytes.Buffer{} + + err = resp.Save(buf) + require.NoError(t, err) + require.Equal(t, "{\"status\":\"success\"}", buf.String()) + }) + + t.Run("error type", func(t *testing.T) { + resp, err := AcquireRequest(). + SetDial(ln). + Get("http://example.com/json") + + require.NoError(t, err) + + err = resp.Save(nil) + require.Error(t, err) + }) +} From 1198e93a26a47a26e911736e4317c8f8315d169a Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Wed, 24 Aug 2022 16:40:12 +0800 Subject: [PATCH 042/118] =?UTF-8?q?=E2=9C=A8=20feat:=20support=20tls=20con?= =?UTF-8?q?fig?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 65 ++++++++++- client/core.go | 196 ++++++++++++++++++++++----------- client/core_test.go | 97 ++++++++++++---- client/hooks.go | 28 ----- client/hooks_test.go | 40 ------- client/request.go | 38 ++----- middleware/proxy/proxy_test.go | 40 ++++--- 7 files changed, 307 insertions(+), 197 deletions(-) diff --git a/client/client.go b/client/client.go index 29bf394b69..1db9f3bc63 100644 --- a/client/client.go +++ b/client/client.go @@ -1,9 +1,13 @@ package client import ( + "crypto/tls" + "crypto/x509" "encoding/json" "encoding/xml" + "io" "net" + "os" "sync" "time" @@ -46,6 +50,9 @@ type Client struct { jsonUnmarshal utils.JSONUnmarshal xmlMarshal utils.XMLMarshal xmlUnmarshal utils.XMLUnmarshal + + // tls config + tlsConfig *tls.Config } // R raise a request from the client. @@ -119,6 +126,62 @@ func (c *Client) SetXMLUnmarshal(f utils.XMLUnmarshal) *Client { return c } +// TLSConfig returns tlsConfig in client. +// If client don't have tlsConfig, this function will init it. +func (c *Client) TLSConfig() *tls.Config { + if c.tlsConfig == nil { + c.tlsConfig = &tls.Config{} + } + + return c.tlsConfig +} + +// SetTLSConfig sets tlsConfig in client. +func (c *Client) SetTLSConfig(config *tls.Config) *Client { + c.tlsConfig = config + return c +} + +// SetCertificates method sets client certificates into client. +func (c *Client) SetCertificates(certs ...tls.Certificate) *Client { + config := c.TLSConfig() + config.Certificates = append(config.Certificates, certs...) + return c +} + +// SetRootCertificate adds one or more root certificates into client. +func (c *Client) SetRootCertificate(path string) *Client { + file, err := os.Open(path) + if err != nil { + return c + } + defer file.Close() + pem, err := io.ReadAll(file) + if err != nil { + return c + } + + config := c.TLSConfig() + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() + } + config.RootCAs.AppendCertsFromPEM(pem) + + return c +} + +// SetRootCertificateFromString method adds one or more root certificates into client. +func (c *Client) SetRootCertificateFromString(pem string) *Client { + config := c.TLSConfig() + + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() + } + config.RootCAs.AppendCertsFromPEM([]byte(pem)) + + return c +} + // BaseURL returns baseurl in Client instance. func (c *Client) BaseURL() string { return c.baseUrl @@ -436,7 +499,7 @@ func SetRequestFiles(files ...*File) SetRequestOptionFunc { func SetDial(f func(addr string) (net.Conn, error)) SetRequestOptionFunc { return func(r *Request) { - r.core.client.Dial = f + r.dial = f } } diff --git a/client/core.go b/client/core.go index 9b8d65eaaa..8c11f9f4f7 100644 --- a/client/core.go +++ b/client/core.go @@ -1,14 +1,36 @@ package client import ( + "bytes" "context" "errors" + "net" + "strconv" + "strings" "sync" "sync/atomic" "github.com/valyala/fasthttp" ) +var ( + httpBytes = []byte("http") + httpsBytes = []byte("https") +) + +// addMissingPort will add the corresponding port number for host. +func addMissingPort(addr string, isTLS bool) string { + n := strings.Index(addr, ":") + if n >= 0 { + return addr + } + port := 80 + if isTLS { + port = 443 + } + return net.JoinHostPort(addr, strconv.Itoa(port)) +} + // RequestHook is a function that receives Agent and Request, // it can change the data in Request and Agent. // @@ -24,13 +46,17 @@ type ResponseHook func(*Client, *Response, *Request) error // `core` stores middleware and plugin definitions, // and defines the execution process type core struct { - client *fasthttp.HostClient + host *fasthttp.HostClient + + client *Client + req *Request + ctx context.Context } -func (c *core) execFunc(ctx context.Context, client *Client, req *Request) (*Response, error) { +func (c *core) execFunc() (*Response, error) { resp := AcquireResponse() - resp.setClient(client) - resp.setRequest(req) + resp.setClient(c.client) + resp.setRequest(c.req) // To avoid memory allocation reuse of data structures such as errch. done := int32(0) @@ -39,10 +65,10 @@ func (c *core) execFunc(ctx context.Context, client *Client, req *Request) (*Res releaseErrChan(errCh) }() - req.RawRequest.CopyTo(reqv) + c.req.RawRequest.CopyTo(reqv) go func() { respv := fasthttp.AcquireResponse() - err := c.client.Do(reqv, respv) + err := c.host.Do(reqv, respv) defer func() { fasthttp.ReleaseRequest(reqv) fasthttp.ReleaseResponse(respv) @@ -66,85 +92,131 @@ func (c *core) execFunc(ctx context.Context, client *Client, req *Request) (*Res return nil, err } return resp, nil - case <-ctx.Done(): + case <-c.ctx.Done(): atomic.SwapInt32(&done, 1) ReleaseResponse(resp) return nil, ErrTimeoutOrCancel } } -// execute will exec each hooks and plugins. -func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Response, error) { - // The built-in hooks will be executed only - // after the user-defined hooks are executed. - err := func() error { - client.mu.Lock() - defer client.mu.Unlock() +// Exec request hook +func (c *core) preHooks() error { + c.client.mu.Lock() + defer c.client.mu.Unlock() - for _, f := range client.userRequestHooks { - err := f(client, req) - if err != nil { - return err - } + for _, f := range c.client.userRequestHooks { + err := f(c.client, c.req) + if err != nil { + return err } + } - for _, f := range client.buildinRequestHooks { - err := f(client, req) - if err != nil { - return err - } + for _, f := range c.client.buildinRequestHooks { + err := f(c.client, c.req) + if err != nil { + return err } + } - return nil - }() - if err != nil { - return nil, err + return nil +} + +// Exec response hooks +func (c *core) afterHooks(resp *Response) error { + c.client.mu.Lock() + defer c.client.mu.Unlock() + for _, f := range c.client.buildinResposeHooks { + err := f(c.client, resp, c.req) + if err != nil { + return err + } } - // deal with timeout - if req.timeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, req.timeout) - defer func() { - cancel() - }() + for _, f := range c.client.userResponseHooks { + err := f(c.client, resp, c.req) + if err != nil { + return err + } + } + + return nil +} + +// timeout deals with timeout +func (c *core) timeout() context.CancelFunc { + var cancel context.CancelFunc + + if c.req.timeout > 0 { + c.ctx, cancel = context.WithTimeout(c.ctx, c.req.timeout) } else { - if client.timeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, client.timeout) - defer func() { - cancel() - }() + if c.client.timeout > 0 { + c.ctx, cancel = context.WithTimeout(c.ctx, c.client.timeout) } } + return cancel +} + +// dial set dial in host. +func (c *core) dial() { + c.host.Dial = c.req.dial +} + +// tls sets tls config. +func (c *core) tls() { + c.host.TLSConfig = c.client.tlsConfig.Clone() +} + +// TODO now set url with Request uri, need cover with proxy url +func (c *core) proxy() error { + rawUri := c.req.RawRequest.URI() + isTLS, scheme := false, rawUri.Scheme() + if bytes.Equal(httpsBytes, scheme) { + isTLS = true + } else if !bytes.Equal(httpBytes, scheme) { + return ErrNotSupportSchema + } + + c.host.Addr = addMissingPort(string(rawUri.Host()), isTLS) + c.host.IsTLS = isTLS + + return nil +} + +// execute will exec each hooks and plugins. +func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Response, error) { + // keep a reference, because pass param is boring + c.ctx = ctx + c.client = client + c.req = req + + // The built-in hooks will be executed only + // after the user-defined hooks are executed. + err := c.preHooks() + if err != nil { + return nil, err + } + + cancel := c.timeout() + if cancel != nil { + defer cancel() + } + + c.tls() + + c.dial() + + c.proxy() + // Do http request - resp, err := c.execFunc(ctx, client, req) + resp, err := c.execFunc() if err != nil { return nil, err } // The built-in hooks will be executed only // before the user-defined hooks are executed. - err = func() error { - client.mu.Lock() - defer client.mu.Unlock() - for _, f := range client.buildinResposeHooks { - err := f(client, resp, req) - if err != nil { - return err - } - } - - for _, f := range client.userResponseHooks { - err := f(client, resp, req) - if err != nil { - return err - } - } - - return nil - }() + err = c.afterHooks(resp) if err != nil { resp.Close() return nil, err @@ -177,7 +249,7 @@ func releaseErrChan(ch chan error) { // newCore returns an empty core object. func newCore() (c *core) { c = &core{ - client: &fasthttp.HostClient{}, + host: &fasthttp.HostClient{}, } return diff --git a/client/core_test.go b/client/core_test.go index 29ee2994d2..2c6c28f845 100644 --- a/client/core_test.go +++ b/client/core_test.go @@ -12,6 +12,46 @@ import ( "github.com/valyala/fasthttp/fasthttputil" ) +func Test_AddMissing_Port(t *testing.T) { + type args struct { + addr string + isTLS bool + } + tests := []struct { + name string + args args + want string + }{ + { + name: "do anything", + args: args{ + addr: "example.com:1234", + }, + want: "example.com:1234", + }, + { + name: "add 80 port", + args: args{ + addr: "example.com", + }, + want: "example.com:80", + }, + { + name: "add 443 port", + args: args{ + addr: "example.com", + isTLS: true, + }, + want: "example.com:443", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, addMissingPort(tt.args.addr, tt.args.isTLS)) + }) + } +} + func Test_Exec_Func(t *testing.T) { ln := fasthttputil.NewInmemoryListener() app := fiber.New(fiber.Config{DisableStartupMessage: true}) @@ -34,24 +74,30 @@ func Test_Exec_Func(t *testing.T) { }() t.Run("normal request", func(t *testing.T) { - client, req := AcquireClient(), AcquireRequest() - core := req.core - core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + core, client, req := newCore(), AcquireClient(), AcquireRequest() + core.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } req.RawRequest.SetRequestURI("http://example.com/normal") - resp, err := core.execFunc(context.Background(), client, req) + core.ctx = context.Background() + core.client = client + core.req = req + + resp, err := core.execFunc() require.NoError(t, err) require.Equal(t, 200, resp.RawResponse.StatusCode()) require.Equal(t, "example.com", string(resp.RawResponse.Body())) }) t.Run("the request return an error", func(t *testing.T) { - client, req := AcquireClient(), AcquireRequest() - core := req.core - core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + core, client, req := newCore(), AcquireClient(), AcquireRequest() + core.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } req.RawRequest.SetRequestURI("http://example.com/return-error") - resp, err := core.execFunc(context.Background(), client, req) + core.ctx = context.Background() + core.client = client + core.req = req + + resp, err := core.execFunc() require.NoError(t, err) require.Equal(t, 500, resp.RawResponse.StatusCode()) @@ -59,16 +105,19 @@ func Test_Exec_Func(t *testing.T) { }) t.Run("the request timeout", func(t *testing.T) { - client, req := AcquireClient(), AcquireRequest() - core := req.core + core, client, req := newCore(), AcquireClient(), AcquireRequest() - core.client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + core.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } req.RawRequest.SetRequestURI("http://example.com/hang-up") ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - _, err := core.execFunc(ctx, client, req) + core.ctx = ctx + core.client = client + core.req = req + + _, err := core.execFunc() require.Equal(t, ErrTimeoutOrCancel, err) }) @@ -98,7 +147,7 @@ func Test_Execute(t *testing.T) { }() t.Run("add user request hooks", func(t *testing.T) { - client, req := AcquireClient(), AcquireRequest() + core, client, req := newCore(), AcquireClient(), AcquireRequest() client.AddRequestHook(func(c *Client, r *Request) error { require.Equal(t, "http://example.com", req.URL()) return nil @@ -107,13 +156,13 @@ func Test_Execute(t *testing.T) { return ln.Dial() }).SetURL("http://example.com") - resp, err := req.core.execute(context.Background(), client, req) + resp, err := core.execute(context.Background(), client, req) require.NoError(t, err) require.Equal(t, "Cannot GET /", string(resp.RawResponse.Body())) }) t.Run("add user response hooks", func(t *testing.T) { - client, req := AcquireClient(), AcquireRequest() + core, client, req := newCore(), AcquireClient(), AcquireRequest() client.AddResponseHook(func(c *Client, resp *Response, req *Request) error { require.Equal(t, "http://example.com", req.URL()) return nil @@ -122,48 +171,48 @@ func Test_Execute(t *testing.T) { return ln.Dial() }).SetURL("http://example.com") - resp, err := req.core.execute(context.Background(), client, req) + resp, err := core.execute(context.Background(), client, req) require.NoError(t, err) require.Equal(t, "Cannot GET /", string(resp.RawResponse.Body())) }) t.Run("no timeout", func(t *testing.T) { - client, req := AcquireClient(), AcquireRequest() + core, client, req := newCore(), AcquireClient(), AcquireRequest() req.SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }).SetURL("http://example.com/hang-up") - resp, err := req.core.execute(context.Background(), client, req) + resp, err := core.execute(context.Background(), client, req) require.NoError(t, err) require.Equal(t, "example.com hang up", string(resp.RawResponse.Body())) }) t.Run("client timeout", func(t *testing.T) { - client, req := AcquireClient(), AcquireRequest() + core, client, req := newCore(), AcquireClient(), AcquireRequest() client.SetTimeout(500 * time.Millisecond) req.SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }).SetURL("http://example.com/hang-up") - _, err := req.core.execute(context.Background(), client, req) + _, err := core.execute(context.Background(), client, req) require.Equal(t, ErrTimeoutOrCancel, err) }) t.Run("request timeout", func(t *testing.T) { - client, req := AcquireClient(), AcquireRequest() + core, client, req := newCore(), AcquireClient(), AcquireRequest() req.SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }).SetURL("http://example.com/hang-up"). SetTimeout(300 * time.Millisecond) - _, err := req.core.execute(context.Background(), client, req) + _, err := core.execute(context.Background(), client, req) require.Equal(t, ErrTimeoutOrCancel, err) }) t.Run("request timeout has higher level", func(t *testing.T) { - client, req := AcquireClient(), AcquireRequest() + core, client, req := newCore(), AcquireClient(), AcquireRequest() client.SetTimeout(30 * time.Millisecond) req.SetDial(func(addr string) (net.Conn, error) { @@ -171,7 +220,7 @@ func Test_Execute(t *testing.T) { }).SetURL("http://example.com/hang-up"). SetTimeout(3000 * time.Millisecond) - resp, err := req.core.execute(context.Background(), client, req) + resp, err := core.execute(context.Background(), client, req) require.NoError(t, err) require.Equal(t, "example.com hang up", string(resp.RawResponse.Body())) }) diff --git a/client/hooks.go b/client/hooks.go index bf9c889c33..49c4315e22 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -1,11 +1,9 @@ package client import ( - "bytes" "io" "math/rand" "mime/multipart" - "net" "os" "path/filepath" "regexp" @@ -18,9 +16,6 @@ import ( ) var ( - httpBytes = []byte("http") - httpsBytes = []byte("https") - protocolCheck = regexp.MustCompile(`^https?://.*$`) headerAccept = "Accept" @@ -36,19 +31,6 @@ var ( letterIdxMax = 63 / letterIdxBits // # of letter indices fitting in 63 bits ) -// addMissingPort will add the corresponding port number for host. -func addMissingPort(addr string, isTLS bool) string { - n := strings.Index(addr, ":") - if n >= 0 { - return addr - } - port := 80 - if isTLS { - port = 443 - } - return net.JoinHostPort(addr, strconv.Itoa(port)) -} - func randString(n int) string { b := make([]byte, n) length := len(letterBytes) @@ -99,16 +81,6 @@ func parserRequestURL(c *Client, req *Request) error { // set uri to request and orther related setting req.RawRequest.SetRequestURI(uri) - rawUri := req.RawRequest.URI() - isTLS, scheme := false, rawUri.Scheme() - if bytes.Equal(httpsBytes, scheme) { - isTLS = true - } else if !bytes.Equal(httpBytes, scheme) { - return ErrNotSupportSchema - } - - req.core.client.Addr = addMissingPort(string(rawUri.Host()), isTLS) - req.core.client.IsTLS = isTLS // merge query params hashSplit := strings.Split(splitUrl[1], "#") diff --git a/client/hooks_test.go b/client/hooks_test.go index f755181a3c..2d1284c843 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -11,46 +11,6 @@ import ( "github.com/stretchr/testify/require" ) -func Test_AddMissing_Port(t *testing.T) { - type args struct { - addr string - isTLS bool - } - tests := []struct { - name string - args args - want string - }{ - { - name: "do anything", - args: args{ - addr: "example.com:1234", - }, - want: "example.com:1234", - }, - { - name: "add 80 port", - args: args{ - addr: "example.com", - }, - want: "example.com:80", - }, - { - name: "add 443 port", - args: args{ - addr: "example.com", - isTLS: true, - }, - want: "example.com:443", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - require.Equal(t, tt.want, addMissingPort(tt.args.addr, tt.args.isTLS)) - }) - } -} - func Test_Rand_String(t *testing.T) { tests := []struct { name string diff --git a/client/request.go b/client/request.go index 0229bd5a36..8a986b4c08 100644 --- a/client/request.go +++ b/client/request.go @@ -37,8 +37,6 @@ const ( ) type Request struct { - core *core - url string method string userAgent string @@ -59,13 +57,15 @@ type Request struct { files []*File bodyType bodyType + dial fasthttp.DialFunc + RawRequest *fasthttp.Request } // Set HostClient dial, this method for unit test, // maybe don't use it. func (r *Request) SetDial(f fasthttp.DialFunc) *Request { - r.core.client.Dial = f + r.dial = f return r } @@ -484,63 +484,48 @@ func (r *Request) checkClient() { // Send get request. func (r *Request) Get(url string) (*Response, error) { - r.SetURL(url).SetMethod(fiber.MethodGet).checkClient() - - return r.core.execute(r.Context(), r.client, r) + return r.SetURL(url).SetMethod(fiber.MethodGet).Send() } // Send post request. func (r *Request) Post(url string) (*Response, error) { - r.SetURL(url).SetMethod(fiber.MethodPost).checkClient() - - return r.core.execute(r.Context(), r.client, r) + return r.SetURL(url).SetMethod(fiber.MethodPost).Send() } // Send head request. func (r *Request) Head(url string) (*Response, error) { - r.SetURL(url).SetMethod(fiber.MethodHead).checkClient() - - return r.core.execute(r.Context(), r.client, r) + return r.SetURL(url).SetMethod(fiber.MethodHead).Send() } // Send put request. func (r *Request) Put(url string) (*Response, error) { - r.SetURL(url).SetMethod(fiber.MethodPut).checkClient() - - return r.core.execute(r.Context(), r.client, r) + return r.SetURL(url).SetMethod(fiber.MethodPut).Send() } // Send Delete request. func (r *Request) Delete(url string) (*Response, error) { - r.SetURL(url).SetMethod(fiber.MethodDelete).checkClient() - - return r.core.execute(r.Context(), r.client, r) + return r.SetURL(url).SetMethod(fiber.MethodDelete).Send() } // Send Options reuqest. func (r *Request) Options(url string) (*Response, error) { - r.SetURL(url).SetMethod(fiber.MethodOptions).checkClient() - - return r.core.execute(r.Context(), r.client, r) + return r.SetURL(url).SetMethod(fiber.MethodOptions).Send() } // Send patch request. func (r *Request) Patch(url string) (*Response, error) { - r.SetURL(url).SetMethod(fiber.MethodPatch).checkClient() - - return r.core.execute(r.Context(), r.client, r) + return r.SetURL(url).SetMethod(fiber.MethodPatch).Send() } // Send a request. func (r *Request) Send() (*Response, error) { r.checkClient() - return r.core.execute(r.Context(), r.client, r) + return newCore().execute(r.Context(), r.Client(), r) } // Reset clear Request object, used by ReleaseRequest method. func (r *Request) Reset() { - r.core = nil r.url = "" r.method = fiber.MethodGet r.userAgent = "" @@ -838,7 +823,6 @@ var requestPool = &sync.Pool{ func AcquireRequest() *Request { req := requestPool.Get().(*Request) req.boundary = "--FiberFormBoundary" + randString(16) - req.core = newCore() return req } diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index dca171c94f..afbe2e3713 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/gofiber/fiber/v3" + fiberClient "github.com/gofiber/fiber/v3/client" "github.com/gofiber/fiber/v3/internal/tlstest" "github.com/gofiber/fiber/v3/utils" "github.com/stretchr/testify/require" @@ -117,11 +118,15 @@ func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) { go func() { require.Nil(t, app.Listener(ln)) }() - code, body, errs := fiberClient.Get("https://" + addr + "/tlsbalaner").TLSConfig(clientTLSConf).String() + resp, err := fiberClient.AcquireClient(). + SetTLSConfig(clientTLSConf). + R(). + Get("https://" + addr + "/tlsbalaner") - require.Equal(t, 0, len(errs)) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "tls balancer", body) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "tls balancer", resp.String()) + resp.Close() } // go test -run Test_Proxy_Forward_WithTlsConfig_To_Http @@ -148,14 +153,16 @@ func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) { go func() { require.Nil(t, app.Listener(proxyServerLn)) }() - code, body, errs := fiberClient.Get("https://" + proxyAddr). - InsecureSkipVerify(). - Timeout(5 * time.Second). - String() + resp, err := fiberClient.AcquireClient().SetTLSConfig(&tls.Config{ + InsecureSkipVerify: true, + }).R(). + SetTimeout(5 * time.Second). + Get("https://" + proxyAddr) - require.Equal(t, 0, len(errs)) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "hello from target", body) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "hello from target", resp.String()) + resp.Close() } // go test -run Test_Proxy_Forward @@ -206,11 +213,14 @@ func Test_Proxy_Forward_WithTlsConfig(t *testing.T) { go func() { require.Nil(t, app.Listener(ln)) }() - code, body, errs := fiberClient.Get("https://" + addr).TLSConfig(clientTLSConf).String() + resp, err := fiberClient.AcquireClient(). + SetTLSConfig(clientTLSConf). + R(). + Get("https://" + addr) - require.Equal(t, 0, len(errs)) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "tls forward", body) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "tls forward", resp.String()) } // go test -run Test_Proxy_Modify_Response From c14408e8393832bf6c021ce634a8fa66ecf7abd5 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Wed, 24 Aug 2022 16:43:24 +0800 Subject: [PATCH 043/118] =?UTF-8?q?=F0=9F=90=9B=20fix:=20add=20err=20check?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/core.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/client/core.go b/client/core.go index 8c11f9f4f7..469b90aabc 100644 --- a/client/core.go +++ b/client/core.go @@ -206,7 +206,10 @@ func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Resp c.dial() - c.proxy() + err = c.proxy() + if err != nil { + return nil, err + } // Do http request resp, err := c.execFunc() From a336a2a223f7111670e6745fe5345e6731e0f138 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Wed, 24 Aug 2022 17:11:20 +0800 Subject: [PATCH 044/118] =?UTF-8?q?=F0=9F=8E=A8=20perf:=20fix=20some=20sta?= =?UTF-8?q?tic=20check?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 8 ++++++-- client/response.go | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/client/client.go b/client/client.go index 1db9f3bc63..c8e84caf63 100644 --- a/client/client.go +++ b/client/client.go @@ -8,6 +8,7 @@ import ( "io" "net" "os" + "path/filepath" "sync" "time" @@ -130,7 +131,9 @@ func (c *Client) SetXMLUnmarshal(f utils.XMLUnmarshal) *Client { // If client don't have tlsConfig, this function will init it. func (c *Client) TLSConfig() *tls.Config { if c.tlsConfig == nil { - c.tlsConfig = &tls.Config{} + c.tlsConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + } } return c.tlsConfig @@ -151,11 +154,12 @@ func (c *Client) SetCertificates(certs ...tls.Certificate) *Client { // SetRootCertificate adds one or more root certificates into client. func (c *Client) SetRootCertificate(path string) *Client { + path = filepath.Clean(path) file, err := os.Open(path) if err != nil { return c } - defer file.Close() + defer func() { _ = file.Close() }() pem, err := io.ReadAll(file) if err != nil { return c diff --git a/client/response.go b/client/response.go index 23de97c1ca..e453f0d349 100644 --- a/client/response.go +++ b/client/response.go @@ -99,7 +99,7 @@ func (r *Response) Save(v any) error { if err != nil { return err } - defer func() { outFile.Close() }() + defer func() { _ = outFile.Close() }() _, err = io.Copy(outFile, bytes.NewReader(r.Body())) if err != nil { @@ -114,7 +114,7 @@ func (r *Response) Save(v any) error { } defer func() { if pc, ok := p.(io.WriteCloser); ok { - pc.Close() + _ = pc.Close() } }() From adcca7d07f65c4b9fdcfc19829a43e5644ddc8f9 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Fri, 26 Aug 2022 09:16:41 +0800 Subject: [PATCH 045/118] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20proxy=20suppo?= =?UTF-8?q?rt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 15 +++++++++++++++ client/core.go | 6 ++++++ 2 files changed, 21 insertions(+) diff --git a/client/client.go b/client/client.go index c8e84caf63..4888ece633 100644 --- a/client/client.go +++ b/client/client.go @@ -7,6 +7,7 @@ import ( "encoding/xml" "io" "net" + "net/url" "os" "path/filepath" "sync" @@ -54,6 +55,9 @@ type Client struct { // tls config tlsConfig *tls.Config + + // proxy + proxyURL string } // R raise a request from the client. @@ -186,6 +190,17 @@ func (c *Client) SetRootCertificateFromString(pem string) *Client { return c } +// SetProxyURL sets proxy url in client. It will apply via core to hostclient. +func (c *Client) SetProxyURL(proxyURL string) *Client { + pUrl, err := url.Parse(proxyURL) + if err != nil { + return c + } + c.proxyURL = pUrl.String() + + return c +} + // BaseURL returns baseurl in Client instance. func (c *Client) BaseURL() string { return c.baseUrl diff --git a/client/core.go b/client/core.go index 469b90aabc..3a927d0cff 100644 --- a/client/core.go +++ b/client/core.go @@ -170,6 +170,12 @@ func (c *core) tls() { // TODO now set url with Request uri, need cover with proxy url func (c *core) proxy() error { rawUri := c.req.RawRequest.URI() + if c.client.proxyURL != "" { + rawUri := fasthttp.AcquireURI() + rawUri.Update(c.client.proxyURL) + defer fasthttp.ReleaseURI(rawUri) + } + isTLS, scheme := false, rawUri.Scheme() if bytes.Equal(httpsBytes, scheme) { isTLS = true From d8fde62fab488709c20dd1553330945902e12ce6 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Fri, 26 Aug 2022 20:57:22 +0800 Subject: [PATCH 046/118] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20retry=20featu?= =?UTF-8?q?re?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 13 ++++++++++++ client/core.go | 54 ++++++++++++++++++++++++++++++++++++------------ 2 files changed, 54 insertions(+), 13 deletions(-) diff --git a/client/client.go b/client/client.go index 4888ece633..6a9f5e720d 100644 --- a/client/client.go +++ b/client/client.go @@ -58,6 +58,9 @@ type Client struct { // proxy proxyURL string + + // retry + retryConfig *RetryConfig } // R raise a request from the client. @@ -201,6 +204,16 @@ func (c *Client) SetProxyURL(proxyURL string) *Client { return c } +func (c *Client) RetryConfig() *RetryConfig { + return c.retryConfig +} + +// SetRetryConfig sets retry config in client which is impl by addon/retry package. +func (c *Client) SetRetryConfig(config *RetryConfig) *Client { + c.retryConfig = config + return c +} + // BaseURL returns baseurl in Client instance. func (c *Client) BaseURL() string { return c.baseUrl diff --git a/client/core.go b/client/core.go index 3a927d0cff..28a82c0e63 100644 --- a/client/core.go +++ b/client/core.go @@ -10,6 +10,7 @@ import ( "sync" "sync/atomic" + "github.com/gofiber/fiber/v3/addon/retry" "github.com/valyala/fasthttp" ) @@ -18,6 +19,21 @@ var ( httpsBytes = []byte("https") ) +// RequestHook is a function that receives Agent and Request, +// it can change the data in Request and Agent. +// +// Called before a request is sent. +type RequestHook func(*Client, *Request) error + +// ResponseHook is a function that receives Agent, Respose and Request, +// it can change the data is Respose or deal with some effects. +// +// Called after a respose has been received. +type ResponseHook func(*Client, *Response, *Request) error + +// RetryConfig is an alias for config in the `addon/retry` package. +type RetryConfig = retry.Config + // addMissingPort will add the corresponding port number for host. func addMissingPort(addr string, isTLS bool) string { n := strings.Index(addr, ":") @@ -31,18 +47,6 @@ func addMissingPort(addr string, isTLS bool) string { return net.JoinHostPort(addr, strconv.Itoa(port)) } -// RequestHook is a function that receives Agent and Request, -// it can change the data in Request and Agent. -// -// Called before a request is sent. -type RequestHook func(*Client, *Request) error - -// ResponseHook is a function that receives Agent, Respose and Request, -// it can change the data is Respose or deal with some effects. -// -// Called after a respose has been received. -type ResponseHook func(*Client, *Response, *Request) error - // `core` stores middleware and plugin definitions, // and defines the execution process type core struct { @@ -66,9 +70,33 @@ func (c *core) execFunc() (*Response, error) { }() c.req.RawRequest.CopyTo(reqv) + cfg := func() *RetryConfig { + c.client.mu.Lock() + defer c.client.mu.Unlock() + + c := c.client.RetryConfig() + if c == nil { + return nil + } + + return &RetryConfig{ + InitialInterval: c.InitialInterval, + MaxBackoffTime: c.MaxBackoffTime, + Multiplier: c.Multiplier, + MaxRetryCount: c.MaxRetryCount, + } + }() + go func() { + var err error respv := fasthttp.AcquireResponse() - err := c.host.Do(reqv, respv) + if cfg != nil { + retry.NewExponentialBackoff(*cfg).Retry(func() error { + return c.host.Do(reqv, respv) + }) + } else { + err = c.host.Do(reqv, respv) + } defer func() { fasthttp.ReleaseRequest(reqv) fasthttp.ReleaseResponse(respv) From a231a9280969e5178f88aa2903074da353ad2d78 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Fri, 26 Aug 2022 21:23:00 +0800 Subject: [PATCH 047/118] =?UTF-8?q?=F0=9F=90=9B=20fix:=20static=20check=20?= =?UTF-8?q?error?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 4 ++-- client/core.go | 2 +- client/response.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/client/client.go b/client/client.go index 6a9f5e720d..7ec814868f 100644 --- a/client/client.go +++ b/client/client.go @@ -161,8 +161,8 @@ func (c *Client) SetCertificates(certs ...tls.Certificate) *Client { // SetRootCertificate adds one or more root certificates into client. func (c *Client) SetRootCertificate(path string) *Client { - path = filepath.Clean(path) - file, err := os.Open(path) + cleanPath := filepath.Clean(path) + file, err := os.Open(cleanPath) if err != nil { return c } diff --git a/client/core.go b/client/core.go index 28a82c0e63..36d42f344e 100644 --- a/client/core.go +++ b/client/core.go @@ -91,7 +91,7 @@ func (c *core) execFunc() (*Response, error) { var err error respv := fasthttp.AcquireResponse() if cfg != nil { - retry.NewExponentialBackoff(*cfg).Retry(func() error { + err = retry.NewExponentialBackoff(*cfg).Retry(func() error { return c.host.Do(reqv, respv) }) } else { diff --git a/client/response.go b/client/response.go index e453f0d349..f835f238ce 100644 --- a/client/response.go +++ b/client/response.go @@ -89,7 +89,7 @@ func (r *Response) Save(v any) error { return err } - if err = os.MkdirAll(dir, 0755); err != nil { + if err = os.MkdirAll(dir, 0750); err != nil { return err } } From 88d77cd6a9248ee0778bdc37f64f1af906bf4a8b Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Sat, 27 Aug 2022 10:26:39 +0800 Subject: [PATCH 048/118] =?UTF-8?q?=F0=9F=8E=A8=20refactor:=20move=20som?= =?UTF-8?q?=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/core.go | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/client/core.go b/client/core.go index 36d42f344e..72d5e0d1bd 100644 --- a/client/core.go +++ b/client/core.go @@ -57,6 +57,23 @@ type core struct { ctx context.Context } +func (c *core) getRetryConfig() *RetryConfig { + c.client.mu.Lock() + defer c.client.mu.Unlock() + + cfg := c.client.RetryConfig() + if cfg == nil { + return nil + } + + return &RetryConfig{ + InitialInterval: cfg.InitialInterval, + MaxBackoffTime: cfg.MaxBackoffTime, + Multiplier: cfg.Multiplier, + MaxRetryCount: cfg.MaxRetryCount, + } +} + func (c *core) execFunc() (*Response, error) { resp := AcquireResponse() resp.setClient(c.client) @@ -70,22 +87,7 @@ func (c *core) execFunc() (*Response, error) { }() c.req.RawRequest.CopyTo(reqv) - cfg := func() *RetryConfig { - c.client.mu.Lock() - defer c.client.mu.Unlock() - - c := c.client.RetryConfig() - if c == nil { - return nil - } - - return &RetryConfig{ - InitialInterval: c.InitialInterval, - MaxBackoffTime: c.MaxBackoffTime, - Multiplier: c.Multiplier, - MaxRetryCount: c.MaxRetryCount, - } - }() + cfg := c.getRetryConfig() go func() { var err error From b52b3dc4af5c182fcc2395950bd88204f3a7379f Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Sat, 17 Sep 2022 21:29:34 +0800 Subject: [PATCH 049/118] docs: change readme --- client/README.md | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/client/README.md b/client/README.md index b8719aa21b..a992fcc6fb 100644 --- a/client/README.md +++ b/client/README.md @@ -4,6 +4,32 @@ ## Features +> The characteristics have not yet been written. + - GET, POST, PUT, DELETE, HEAD, PATCH, OPTIONS, etc. - Simple and chainable methods for settings and request -- \ No newline at end of file +- Request Body can be `string`, `[]byte`, `map`, `slice` + - Auto detects `Content-Type` + - Buffer processing for `files` + - Native `*fasthttp.Request` instance can be accessed during middleware and request execution via `Request.RawRequest` + - Request Body can be read multiple time via `Request.RawRequest.GetBody()` +- Response object gives you more possibility + - Access as `[]byte` by `response.Body()` or access as `string` by `response.String()` +- Automatic marshal and unmarshal for JSON and XML content type + - Default is JSON, if you supply struct/map without header Content-Type + - For auto-unmarshal, refer to - + - Success scenario Request.SetResult() and Response.Result(). + - Error scenario Request.SetError() and Response.Error(). + - Supports RFC7807 - application/problem+json & application/problem+xml + - Provide an option to override JSON Marshal/Unmarshal and XML Marshal/Unmarshal + +## Usage + +The following samples will assist you to become as comfortable as possible with `Fiber Client` library. + +```go +// Import Fiber Client into your code and refer it as `client`. +import "github.com/gofiber/fiber/client" +``` + +### Simple GET From 21881a4655b9b8427e72f3686c7ed64ac7b9e450 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Sat, 17 Sep 2022 22:12:04 +0800 Subject: [PATCH 050/118] =?UTF-8?q?=E2=9C=A8=20feat:=20extend=20axios=20AP?= =?UTF-8?q?I?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 182 ++++++++++++++++++++++-------------------- client/client_test.go | 8 +- client/core.go | 4 +- 3 files changed, 103 insertions(+), 91 deletions(-) diff --git a/client/client.go b/client/client.go index 7ec814868f..4bfc0cff09 100644 --- a/client/client.go +++ b/client/client.go @@ -1,12 +1,12 @@ package client import ( + "context" "crypto/tls" "crypto/x509" "encoding/json" "encoding/xml" "io" - "net" "net/url" "os" "path/filepath" @@ -390,78 +390,57 @@ func (c *Client) SetTimeout(t time.Duration) *Client { } // Get provide a API like axios which send get request. -func (c *Client) Get(url string, setter ...SetRequestOptionFunc) (*Response, error) { +func (c *Client) Get(url string, cfg ...Config) (*Response, error) { req := AcquireRequest().SetClient(c) - - for _, v := range setter { - v(req) - } + setConfigToRequest(req, cfg...) return req.Get(url) } // Post provide a API like axios which send post request. -func (c *Client) Post(url string, setter ...SetRequestOptionFunc) (*Response, error) { +func (c *Client) Post(url string, cfg ...Config) (*Response, error) { req := AcquireRequest().SetClient(c) - - for _, v := range setter { - v(req) - } + setConfigToRequest(req, cfg...) return req.Post(url) } // Head provide a API like axios which send head request. -func (c *Client) Head(url string, setter ...SetRequestOptionFunc) (*Response, error) { +func (c *Client) Head(url string, cfg ...Config) (*Response, error) { req := AcquireRequest().SetClient(c) - - for _, v := range setter { - v(req) - } + setConfigToRequest(req, cfg...) return req.Head(url) } // Put provide a API like axios which send put request. -func (c *Client) Put(url string, setter ...SetRequestOptionFunc) (*Response, error) { +func (c *Client) Put(url string, cfg ...Config) (*Response, error) { req := AcquireRequest().SetClient(c) - - for _, v := range setter { - v(req) - } + setConfigToRequest(req, cfg...) return req.Put(url) } // Delete provide a API like axios which send delete request. -func (c *Client) Delete(url string, setter ...SetRequestOptionFunc) (*Response, error) { +func (c *Client) Delete(url string, cfg ...Config) (*Response, error) { req := AcquireRequest().SetClient(c) - - for _, v := range setter { - v(req) - } + setConfigToRequest(req, cfg...) return req.Delete(url) } // Options provide a API like axios which send options request. -func (c *Client) Options(url string, setter ...SetRequestOptionFunc) (*Response, error) { +func (c *Client) Options(url string, cfg ...Config) (*Response, error) { req := AcquireRequest().SetClient(c) - - for _, v := range setter { - v(req) - } + setConfigToRequest(req, cfg...) return req.Options(url) } // Patch provide a API like axios which send patch request. -func (c *Client) Patch(url string, setter ...SetRequestOptionFunc) (*Response, error) { +func (c *Client) Patch(url string, cfg ...Config) (*Response, error) { req := AcquireRequest().SetClient(c) - - for _, v := range setter { - v(req) - } + setConfigToRequest(req, cfg...) return req.Patch(url) } @@ -479,64 +458,92 @@ func (c *Client) Reset() { c.params.Reset() } -type SetRequestOptionFunc func(r *Request) +// Config for easy to set the request parameters, it should be +// noted that when setting the request body will use JSON as +// the default serialization mechanism, while the priority of +// Body is higher than FormData, and the priority of FormData +// is higher than File. +type Config struct { + Ctx context.Context -func SetRequestHeaders(m map[string]string) SetRequestOptionFunc { - return func(r *Request) { - r.SetHeaders(m) - } + UserAgent string + Referer string + Header map[string]string + Param map[string]string + Cookie map[string]string + PathParam map[string]string + + Timeout time.Duration + + Body any + FormData map[string]string + File []*File + + dial fasthttp.DialFunc } -func SetRequestQueryParams(m map[string]string) SetRequestOptionFunc { - return func(r *Request) { - r.SetParams(m) +// Set the parameters passed via Config to Request. +func setConfigToRequest(req *Request, config ...Config) { + if len(config) == 0 { + return } -} + cfg := config[0] -func SetRequestUserAgent(ua string) SetRequestOptionFunc { - return func(r *Request) { - r.SetUserAgent(ua) + if cfg.Ctx != nil { + req.SetContext(cfg.Ctx) } -} -func SetRequestReferer(referer string) SetRequestOptionFunc { - return func(r *Request) { - r.SetReferer(referer) + if cfg.UserAgent != "" { + req.SetUserAgent(cfg.UserAgent) } -} -func SetRequestData(v any) SetRequestOptionFunc { - return func(r *Request) { - r.SetJSON(v) + if cfg.Referer != "" { + req.SetReferer(cfg.Referer) } -} -func SetRequestFormDatas(m map[string]string) SetRequestOptionFunc { - return func(r *Request) { - r.SetFormDatas(m) + if cfg.Header != nil { + req.SetHeaders(cfg.Header) } -} -func SetRequestPathParams(m map[string]string) SetRequestOptionFunc { - return func(r *Request) { - r.SetPathParams(m) + if cfg.Param != nil { + req.SetParams(cfg.Param) } -} -func SetRequestFiles(files ...*File) SetRequestOptionFunc { - return func(r *Request) { - r.AddFiles(files...) + if cfg.Cookie != nil { + req.SetCookies(cfg.Cookie) + } + + if cfg.PathParam != nil { + req.SetPathParams(cfg.PathParam) } -} -func SetDial(f func(addr string) (net.Conn, error)) SetRequestOptionFunc { - return func(r *Request) { - r.dial = f + if cfg.Timeout != 0 { + req.SetTimeout(cfg.Timeout) + } + + if cfg.dial != nil { + req.SetDial(cfg.dial) + } + + if cfg.Body != nil { + req.SetJSON(cfg.Body) + return + } + + if cfg.FormData != nil { + req.SetFormDatas(cfg.FormData) + return + } + + if cfg.File != nil && len(cfg.File) != 0 { + req.AddFiles(cfg.File...) + return } } var ( defaultClient *Client + replaceMu = sync.Mutex{} defaultUserAgent = "fiber" clientPool = &sync.Pool{ New: func() any { @@ -590,6 +597,9 @@ func C() *Client { // Replce the defaultClient, the returned function can undo. func Replace(c *Client) func() { + replaceMu.Lock() + defer replaceMu.Unlock() + oldClient := defaultClient defaultClient = c @@ -599,36 +609,36 @@ func Replace(c *Client) func() { } // Get send a get request use defaultClient, a convenient method. -func Get(url string, setter ...SetRequestOptionFunc) (*Response, error) { - return defaultClient.Get(url, setter...) +func Get(url string, cfg ...Config) (*Response, error) { + return defaultClient.Get(url, cfg...) } // Post send a post request use defaultClient, a convenient method. -func Post(url string, setter ...SetRequestOptionFunc) (*Response, error) { - return defaultClient.Post(url, setter...) +func Post(url string, cfg ...Config) (*Response, error) { + return defaultClient.Post(url, cfg...) } // Head send a head request use defaultClient, a convenient method. -func Head(url string, setter ...SetRequestOptionFunc) (*Response, error) { - return defaultClient.Head(url, setter...) +func Head(url string, cfg ...Config) (*Response, error) { + return defaultClient.Head(url, cfg...) } // Put send a put request use defaultClient, a convenient method. -func Put(url string, setter ...SetRequestOptionFunc) (*Response, error) { - return defaultClient.Put(url, setter...) +func Put(url string, cfg ...Config) (*Response, error) { + return defaultClient.Put(url, cfg...) } // Delete send a delete request use defaultClient, a convenient method. -func Delete(url string, setter ...SetRequestOptionFunc) (*Response, error) { - return defaultClient.Delete(url, setter...) +func Delete(url string, cfg ...Config) (*Response, error) { + return defaultClient.Delete(url, cfg...) } // Options send a options request use defaultClient, a convenient method. -func Options(url string, setter ...SetRequestOptionFunc) (*Response, error) { - return defaultClient.Options(url, setter...) +func Options(url string, cfg ...Config) (*Response, error) { + return defaultClient.Options(url, cfg...) } // Patch send a patch request use defaultClient, a convenient method. -func Patch(url string, setter ...SetRequestOptionFunc) (*Response, error) { - return defaultClient.Patch(url, setter...) +func Patch(url string, cfg ...Config) (*Response, error) { + return defaultClient.Patch(url, cfg...) } diff --git a/client/client_test.go b/client/client_test.go index 2d21ba088e..c7b37a8919 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -64,9 +64,11 @@ func Test_Get(t *testing.T) { }() t.Run("global get function", func(t *testing.T) { - resp, err := Get("http://example.com", SetDial(func(addr string) (net.Conn, error) { - return ln.Dial() - })) + resp, err := Get("http://example.com", Config{ + dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + }) require.NoError(t, err) require.Equal(t, "example.com", utils.UnsafeString(resp.RawResponse.Body())) }) diff --git a/client/core.go b/client/core.go index 72d5e0d1bd..f0826cdd45 100644 --- a/client/core.go +++ b/client/core.go @@ -197,11 +197,11 @@ func (c *core) tls() { c.host.TLSConfig = c.client.tlsConfig.Clone() } -// TODO now set url with Request uri, need cover with proxy url +// proxy set proxy in host. func (c *core) proxy() error { rawUri := c.req.RawRequest.URI() if c.client.proxyURL != "" { - rawUri := fasthttp.AcquireURI() + rawUri = fasthttp.AcquireURI() rawUri.Update(c.client.proxyURL) defer fasthttp.ReleaseURI(rawUri) } From 2ac05755729247005d9135bfb919714a71f38882 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Sun, 18 Sep 2022 15:52:52 +0800 Subject: [PATCH 051/118] perf: change field to export field --- client/client.go | 6 +++--- client/client_test.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/client/client.go b/client/client.go index 4bfc0cff09..c34d14331e 100644 --- a/client/client.go +++ b/client/client.go @@ -479,7 +479,7 @@ type Config struct { FormData map[string]string File []*File - dial fasthttp.DialFunc + Dial fasthttp.DialFunc } // Set the parameters passed via Config to Request. @@ -521,8 +521,8 @@ func setConfigToRequest(req *Request, config ...Config) { req.SetTimeout(cfg.Timeout) } - if cfg.dial != nil { - req.SetDial(cfg.dial) + if cfg.Dial != nil { + req.SetDial(cfg.Dial) } if cfg.Body != nil { diff --git a/client/client_test.go b/client/client_test.go index c7b37a8919..14b2184e85 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -65,7 +65,7 @@ func Test_Get(t *testing.T) { t.Run("global get function", func(t *testing.T) { resp, err := Get("http://example.com", Config{ - dial: func(addr string) (net.Conn, error) { + Dial: func(addr string) (net.Conn, error) { return ln.Dial() }, }) From 382cacc1c1f62212fd42439cdedfa29d2759dbf8 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Sun, 18 Sep 2022 16:43:09 +0800 Subject: [PATCH 052/118] =?UTF-8?q?=E2=9C=85=20chore:=20disable=20startup?= =?UTF-8?q?=20message?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client_test.go | 2 +- client/core_test.go | 4 ++-- client/helper_test.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index 3be2f91bd0..375882880f 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -60,7 +60,7 @@ func Test_Get(t *testing.T) { }) go func() { - require.Nil(t, app.Listener(ln)) + require.Nil(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) }() t.Run("global get function", func(t *testing.T) { diff --git a/client/core_test.go b/client/core_test.go index 460991c0b6..e862ced508 100644 --- a/client/core_test.go +++ b/client/core_test.go @@ -70,7 +70,7 @@ func Test_Exec_Func(t *testing.T) { }) go func() { - require.Nil(t, app.Listener(ln)) + require.Nil(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) }() t.Run("normal request", func(t *testing.T) { @@ -143,7 +143,7 @@ func Test_Execute(t *testing.T) { }) go func() { - require.Nil(t, app.Listener(ln)) + require.Nil(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) }() t.Run("add user request hooks", func(t *testing.T) { diff --git a/client/helper_test.go b/client/helper_test.go index dd1747ff3a..080d8482ff 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -18,7 +18,7 @@ func createHelperServer(t *testing.T) (*fiber.App, func(addr string) (net.Conn, return app, func(addr string) (net.Conn, error) { return ln.Dial() }, func() { - require.Nil(t, app.Listener(ln)) + require.Nil(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) } } From 9a491250027f60907a86b271e0a2c0c70d85308d Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Thu, 6 Oct 2022 16:51:17 +0800 Subject: [PATCH 053/118] =?UTF-8?q?=F0=9F=90=9B=20fix:=20fix=20test=20erro?= =?UTF-8?q?r?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- listen_test.go | 117 ++++++++++++++++++++++++++----------------------- 1 file changed, 62 insertions(+), 55 deletions(-) diff --git a/listen_test.go b/listen_test.go index 797d90747f..00160f7603 100644 --- a/listen_test.go +++ b/listen_test.go @@ -2,6 +2,7 @@ package fiber import ( "bytes" + "context" "crypto/tls" "errors" "fmt" @@ -15,6 +16,7 @@ import ( "time" "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttputil" ) @@ -33,61 +35,66 @@ func Test_Listen(t *testing.T) { } // go test -run Test_Listen_Graceful_Shutdown -// func Test_Listen_Graceful_Shutdown(t *testing.T) { -// var mu sync.Mutex -// var shutdown bool - -// app := New() - -// app.Get("/", func(c Ctx) error { -// return c.SendString(c.Hostname()) -// }) - -// ln := fasthttputil.NewInmemoryListener() - -// go func() { -// ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) -// defer cancel() - -// err := app.Listener(ln, ListenConfig{ -// DisableStartupMessage: true, -// GracefulContext: ctx, -// OnShutdownSuccess: func() { -// mu.Lock() -// shutdown = true -// mu.Unlock() -// }, -// }) - -// require.NoError(t, err) -// }() - -// testCases := []struct { -// Time time.Duration -// ExpectedBody string -// ExpectedStatusCode int -// ExceptedErrsLen int -// }{ -// {Time: 100 * time.Millisecond, ExpectedBody: "example.com", ExpectedStatusCode: StatusOK, ExceptedErrsLen: 0}, -// {Time: 500 * time.Millisecond, ExpectedBody: "", ExpectedStatusCode: 0, ExceptedErrsLen: 1}, -// } - -// for _, tc := range testCases { -// time.Sleep(tc.Time) - -// a := Get("http://example.com") -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } -// code, body, errs := a.String() - -// require.Equal(t, tc.ExpectedStatusCode, code) -// require.Equal(t, tc.ExpectedBody, body) -// require.Equal(t, tc.ExceptedErrsLen, len(errs)) -// } - -// mu.Lock() -// require.True(t, shutdown) -// mu.Unlock() -// } +func Test_Listen_Graceful_Shutdown(t *testing.T) { + var mu sync.Mutex + var shutdown bool + + app := New() + + app.Get("/", func(c Ctx) error { + return c.SendString(c.Hostname()) + }) + + ln := fasthttputil.NewInmemoryListener() + + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + defer cancel() + + err := app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + GracefulContext: ctx, + OnShutdownSuccess: func() { + mu.Lock() + shutdown = true + mu.Unlock() + }, + }) + + require.NoError(t, err) + }() + + testCases := []struct { + Time time.Duration + ExpectedBody string + ExpectedStatusCode int + ExceptedErr error + }{ + {Time: 100 * time.Millisecond, ExpectedBody: "example.com", ExpectedStatusCode: StatusOK, ExceptedErr: nil}, + {Time: 500 * time.Millisecond, ExpectedBody: "", ExpectedStatusCode: StatusOK, ExceptedErr: errors.New("InmemoryListener is already closed: use of closed network connection")}, + } + + for _, tc := range testCases { + time.Sleep(tc.Time) + + req := fasthttp.AcquireRequest() + req.SetRequestURI("http://example.com") + + client := fasthttp.HostClient{} + client.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + + resp := fasthttp.AcquireResponse() + err := client.Do(req, resp) + + require.Equal(t, tc.ExceptedErr, err) + require.Equal(t, tc.ExpectedStatusCode, resp.StatusCode()) + require.Equal(t, tc.ExpectedBody, string(resp.Body())) + } + + mu.Lock() + require.True(t, shutdown) + mu.Unlock() +} // go test -run Test_Listen_Prefork func Test_Listen_Prefork(t *testing.T) { From 9d0560b909f86fda16426cf8ee5783b2d51fb8c2 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Thu, 6 Oct 2022 19:10:04 +0800 Subject: [PATCH 054/118] chore: fix error test --- client/request_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/request_test.go b/client/request_test.go index 9d32e50a95..69e5ed04b9 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -616,7 +616,7 @@ func Test_Request_Head(t *testing.T) { t.Parallel() app, ln, start := createHelperServer(t) - app.Get("/", func(c fiber.Ctx) error { + app.Head("/", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) From fdb2468ff29ce80b665c2fdd62b773340699fe93 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Thu, 6 Oct 2022 19:18:00 +0800 Subject: [PATCH 055/118] chore: fix test case --- redirect_test.go | 170 +++++++++++++++++++++++++---------------------- 1 file changed, 89 insertions(+), 81 deletions(-) diff --git a/redirect_test.go b/redirect_test.go index ad7989e4e3..d2a04152c4 100644 --- a/redirect_test.go +++ b/redirect_test.go @@ -5,11 +5,15 @@ package fiber import ( + "context" + "net" "net/url" "testing" + "time" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttputil" ) // go test -run Test_Redirect_To @@ -233,87 +237,91 @@ func Test_Redirect_setFlash(t *testing.T) { } // go test -run Test_Redirect_Request -// func Test_Redirect_Request(t *testing.T) { -// t.Parallel() - -// app := New() - -// app.Get("/", func(c Ctx) error { -// return c.Redirect().With("key", "value").With("key2", "value2").With("co\\:m\\,ma", "Fi\\:ber\\, v3").Route("name") -// }) - -// app.Get("/with-inputs", func(c Ctx) error { -// return c.Redirect().WithInput().With("key", "value").With("key2", "value2").Route("name") -// }) - -// app.Get("/just-inputs", func(c Ctx) error { -// return c.Redirect().WithInput().Route("name") -// }) - -// app.Get("/redirected", func(c Ctx) error { -// return c.JSON(Map{ -// "messages": c.Redirect().Messages(), -// "inputs": c.Redirect().OldInputs(), -// }) -// }).Name("name") - -// // Start test server -// ln := fasthttputil.NewInmemoryListener() -// go func() { -// ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) -// defer cancel() - -// err := app.Listener(ln, ListenConfig{ -// DisableStartupMessage: true, -// GracefulContext: ctx, -// }) - -// require.NoError(t, err) -// }() - -// // Test cases -// testCases := []struct { -// URL string -// CookieValue string -// ExpectedBody string -// ExpectedStatusCode int -// ExceptedErrsLen int -// }{ -// { -// URL: "/", -// CookieValue: "key:value,key2:value2,co\\:m\\,ma:Fi\\:ber\\, v3", -// ExpectedBody: `{"inputs":{},"messages":{"co:m,ma":"Fi:ber, v3","key":"value","key2":"value2"}}`, -// ExpectedStatusCode: StatusOK, -// ExceptedErrsLen: 0, -// }, -// { -// URL: "/with-inputs?name=john&surname=doe", -// CookieValue: "key:value,key2:value2,key:value,key2:value2,old_input_data_name:john,old_input_data_surname:doe", -// ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{"key":"value","key2":"value2"}}`, -// ExpectedStatusCode: StatusOK, -// ExceptedErrsLen: 0, -// }, -// { -// URL: "/just-inputs?name=john&surname=doe", -// CookieValue: "old_input_data_name:john,old_input_data_surname:doe", -// ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{}}`, -// ExpectedStatusCode: StatusOK, -// ExceptedErrsLen: 0, -// }, -// } - -// for _, tc := range testCases { -// a := Get("http://example.com" + tc.URL) -// a.Cookie(FlashCookieName, tc.CookieValue) -// a.MaxRedirectsCount(1) -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } -// code, body, errs := a.String() - -// require.Equal(t, tc.ExpectedStatusCode, code) -// require.Equal(t, tc.ExpectedBody, body) -// require.Equal(t, tc.ExceptedErrsLen, len(errs)) -// } -// } +func Test_Redirect_Request(t *testing.T) { + t.Parallel() + + app := New() + + app.Get("/", func(c Ctx) error { + return c.Redirect().With("key", "value").With("key2", "value2").With("co\\:m\\,ma", "Fi\\:ber\\, v3").Route("name") + }) + + app.Get("/with-inputs", func(c Ctx) error { + return c.Redirect().WithInput().With("key", "value").With("key2", "value2").Route("name") + }) + + app.Get("/just-inputs", func(c Ctx) error { + return c.Redirect().WithInput().Route("name") + }) + + app.Get("/redirected", func(c Ctx) error { + return c.JSON(Map{ + "messages": c.Redirect().Messages(), + "inputs": c.Redirect().OldInputs(), + }) + }).Name("name") + + // Start test server + ln := fasthttputil.NewInmemoryListener() + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + defer cancel() + + err := app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + GracefulContext: ctx, + }) + + require.NoError(t, err) + }() + + // Test cases + testCases := []struct { + URL string + CookieValue string + ExpectedBody string + ExpectedStatusCode int + ExceptedErr error + }{ + { + URL: "/", + CookieValue: "key:value,key2:value2,co\\:m\\,ma:Fi\\:ber\\, v3", + ExpectedBody: `{"inputs":{},"messages":{"co:m,ma":"Fi:ber, v3","key":"value","key2":"value2"}}`, + ExpectedStatusCode: StatusOK, + ExceptedErr: nil, + }, + { + URL: "/with-inputs?name=john&surname=doe", + CookieValue: "key:value,key2:value2,key:value,key2:value2,old_input_data_name:john,old_input_data_surname:doe", + ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{"key":"value","key2":"value2"}}`, + ExpectedStatusCode: StatusOK, + ExceptedErr: nil, + }, + { + URL: "/just-inputs?name=john&surname=doe", + CookieValue: "old_input_data_name:john,old_input_data_surname:doe", + ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{}}`, + ExpectedStatusCode: StatusOK, + ExceptedErr: nil, + }, + } + + for _, tc := range testCases { + client := &fasthttp.HostClient{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + req, resp := fasthttp.AcquireRequest(), fasthttp.AcquireResponse() + req.SetRequestURI("http://example.com" + tc.URL) + req.Header.SetCookie(FlashCookieName, tc.CookieValue) + err := client.DoRedirects(req, resp, 1) + require.NoError(t, err) + + require.Equal(t, tc.ExpectedBody, string(resp.Body())) + require.Equal(t, tc.ExpectedStatusCode, resp.StatusCode()) + } +} // go test -v -run=^$ -bench=Benchmark_Redirect_Route -benchmem -count=4 func Benchmark_Redirect_Route(b *testing.B) { From c5d2df90a25b19a9dd436d8c215746ebd22bcf8a Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Sat, 15 Oct 2022 11:44:08 +0800 Subject: [PATCH 056/118] feat: add some test to client --- client/client_test.go | 467 ++++++++++++++++++++++-------------------- 1 file changed, 250 insertions(+), 217 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index c009dae650..6f0c45a853 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -2,306 +2,339 @@ package client import ( "fmt" - "net" "reflect" "testing" "github.com/gofiber/fiber/v3" "github.com/gofiber/utils" "github.com/stretchr/testify/require" - "github.com/valyala/fasthttp/fasthttputil" ) -// func Test_Client_Invalid_URL(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// return c.SendString(c.Hostname()) -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() +func Test_Client_Invalid_URL(t *testing.T) { + t.Parallel() -// a := Get("http://example.com\r\n\r\nGET /\r\n\r\n") + app, dial, start := createHelperServer(t) -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) -// _, body, errs := a.String() + go start() -// utils.AssertEqual(t, "", body) -// utils.AssertEqual(t, 1, len(errs)) -// utils.AssertEqual(t, "missing required Host header in request", errs[0].Error()) -// } + _, err := AcquireClient(). + R(). + SetDial(dial). + Get("http://example.com\r\n\r\nGET /\r\n\r\n") -// func Test_Client_Unsupported_Protocol(t *testing.T) { -// t.Parallel() + require.ErrorIs(t, err, ErrURLForamt) +} -// a := Get("ftp://example.com") +func Test_Client_Unsupported_Protocol(t *testing.T) { + t.Parallel() -// _, body, errs := a.String() + _, err := AcquireClient(). + R(). + Get("ftp://example.com") -// utils.AssertEqual(t, "", body) -// utils.AssertEqual(t, 1, len(errs)) -// utils.AssertEqual(t, `unsupported protocol "ftp". http and https are supported`, -// errs[0].Error()) -// } + require.ErrorIs(t, err, ErrURLForamt) +} func Test_Get(t *testing.T) { t.Parallel() - ln := fasthttputil.NewInmemoryListener() + app, dial, start := createHelperServer(t) - app := fiber.New() app.Get("/", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) - go func() { - require.Nil(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) - }() + go start() t.Run("global get function", func(t *testing.T) { resp, err := Get("http://example.com", Config{ - Dial: func(addr string) (net.Conn, error) { - return ln.Dial() - }, + Dial: dial, }) require.NoError(t, err) require.Equal(t, "example.com", utils.UnsafeString(resp.RawResponse.Body())) }) -} - -// func Test_Client_Get(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// return c.SendString(c.Hostname()) -// }) -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// for i := 0; i < 5; i++ { -// a := Get("http://example.com") - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.String() - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "example.com", body) -// utils.AssertEqual(t, 0, len(errs)) -// } -// } - -// func Test_Client_Head(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// return c.SendString(c.Hostname()) -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// for i := 0; i < 5; i++ { -// a := Head("http://example.com") - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.String() - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "", body) -// utils.AssertEqual(t, 0, len(errs)) -// } -// } - -// func Test_Client_Post(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Post("/", func(c fiber.Ctx) error { -// return c.Status(fiber.StatusCreated). -// SendString(c.FormValue("foo")) -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// for i := 0; i < 5; i++ { -// args := AcquireArgs() - -// args.Set("foo", "bar") - -// a := Post("http://example.com"). -// Form(args) - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.String() - -// utils.AssertEqual(t, fiber.StatusCreated, code) -// utils.AssertEqual(t, "bar", body) -// utils.AssertEqual(t, 0, len(errs)) + t.Run("client get", func(t *testing.T) { + resp, err := AcquireClient().Get("http://example.com", Config{ + Dial: dial, + }) + require.NoError(t, err) + require.Equal(t, "example.com", utils.UnsafeString(resp.RawResponse.Body())) + }) +} -// ReleaseArgs(args) -// } -// } +func Test_Head(t *testing.T) { + t.Parallel() -// func Test_Client_Put(t *testing.T) { -// t.Parallel() + app, dial, start := createHelperServer(t) -// ln := fasthttputil.NewInmemoryListener() + app.Head("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + go start() -// app.Put("/", func(c fiber.Ctx) error { -// return c.SendString(c.FormValue("foo")) -// }) + t.Run("global head function", func(t *testing.T) { + resp, err := Head("http://example.com", Config{ + Dial: dial, + }) + require.NoError(t, err) + require.Equal(t, "", utils.UnsafeString(resp.RawResponse.Body())) + }) -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + t.Run("client head", func(t *testing.T) { + resp, err := AcquireClient().Head("http://example.com", Config{ + Dial: dial, + }) + require.NoError(t, err) + require.Equal(t, "", utils.UnsafeString(resp.RawResponse.Body())) + }) +} -// for i := 0; i < 5; i++ { -// args := AcquireArgs() +func Test_Post(t *testing.T) { + t.Parallel() -// args.Set("foo", "bar") + app, dial, start := createHelperServer(t) + app.Post("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusCreated). + SendString(c.FormValue("foo")) + }) -// a := Put("http://example.com"). -// Form(args) + go start() -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + t.Run("global post function", func(t *testing.T) { + for i := 0; i < 5; i++ { + resp, err := Post("http://example.com", Config{ + Dial: dial, + FormData: map[string]string{ + "foo": "bar", + }, + }) -// code, body, errs := a.String() + require.Nil(t, err) + require.Equal(t, fiber.StatusCreated, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "bar", body) -// utils.AssertEqual(t, 0, len(errs)) + t.Run("client post", func(t *testing.T) { + for i := 0; i < 5; i++ { + resp, err := AcquireClient().Post("http://example.com", Config{ + Dial: dial, + FormData: map[string]string{ + "foo": "bar", + }, + }) -// ReleaseArgs(args) -// } -// } + require.Nil(t, err) + require.Equal(t, fiber.StatusCreated, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) +} -// func Test_Client_Patch(t *testing.T) { -// t.Parallel() +func Test_Put(t *testing.T) { + t.Parallel() -// ln := fasthttputil.NewInmemoryListener() + app, dial, start := createHelperServer(t) + app.Put("/", func(c fiber.Ctx) error { + return c.SendString(c.FormValue("foo")) + }) -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + go start() -// app.Patch("/", func(c fiber.Ctx) error { -// return c.SendString(c.FormValue("foo")) -// }) + t.Run("global put function", func(t *testing.T) { + for i := 0; i < 5; i++ { + resp, err := Put("http://example.com", Config{ + Dial: dial, + FormData: map[string]string{ + "foo": "bar", + }, + }) -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + require.Nil(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) -// for i := 0; i < 5; i++ { -// args := AcquireArgs() + t.Run("client put", func(t *testing.T) { + for i := 0; i < 5; i++ { + resp, err := AcquireClient().Put("http://example.com", Config{ + Dial: dial, + FormData: map[string]string{ + "foo": "bar", + }, + }) -// args.Set("foo", "bar") + require.Nil(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) +} -// a := Patch("http://example.com"). -// Form(args) +func Test_Delete(t *testing.T) { + t.Parallel() -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + app, dial, start := createHelperServer(t) + app.Delete("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusNoContent). + SendString("deleted") + }) -// code, body, errs := a.String() + go start() -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "bar", body) -// utils.AssertEqual(t, 0, len(errs)) + t.Run("global delete function", func(t *testing.T) { + for i := 0; i < 5; i++ { + resp, err := Delete("http://example.com", Config{ + Dial: dial, + FormData: map[string]string{ + "foo": "bar", + }, + }) -// ReleaseArgs(args) -// } -// } + require.Nil(t, err) + require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) + require.Equal(t, "", resp.String()) + } + }) -// func Test_Client_Delete(t *testing.T) { -// t.Parallel() + t.Run("client delete", func(t *testing.T) { + for i := 0; i < 5; i++ { + resp, err := AcquireClient().Delete("http://example.com", Config{ + Dial: dial, + FormData: map[string]string{ + "foo": "bar", + }, + }) -// ln := fasthttputil.NewInmemoryListener() + require.Nil(t, err) + require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) + require.Equal(t, "", resp.String()) + } + }) +} -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) +func Test_Options(t *testing.T) { + t.Parallel() -// app.Delete("/", func(c fiber.Ctx) error { -// return c.Status(fiber.StatusNoContent). -// SendString("deleted") -// }) + app, dial, start := createHelperServer(t) + app.Options("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusNoContent).SendString("") + }) -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + go start() -// for i := 0; i < 5; i++ { -// args := AcquireArgs() + t.Run("global options function", func(t *testing.T) { + for i := 0; i < 5; i++ { + resp, err := Options("http://example.com", Config{ + Dial: dial, + }) -// a := Delete("http://example.com") + require.Nil(t, err) + require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) + require.Equal(t, "", resp.String()) + } + }) -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + t.Run("client options", func(t *testing.T) { + for i := 0; i < 5; i++ { + resp, err := AcquireClient().Options("http://example.com", Config{ + Dial: dial, + }) -// code, body, errs := a.String() + require.Nil(t, err) + require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) + require.Equal(t, "", resp.String()) + } + }) +} +func Test_Patch(t *testing.T) { + t.Parallel() -// utils.AssertEqual(t, fiber.StatusNoContent, code) -// utils.AssertEqual(t, "", body) -// utils.AssertEqual(t, 0, len(errs)) + app, dial, start := createHelperServer(t) -// ReleaseArgs(args) -// } -// } + app.Patch("/", func(c fiber.Ctx) error { + return c.SendString(c.FormValue("foo")) + }) -// func Test_Client_UserAgent(t *testing.T) { -// t.Parallel() + go start() -// ln := fasthttputil.NewInmemoryListener() + t.Run("global patch function", func(t *testing.T) { + for i := 0; i < 5; i++ { + resp, err := Patch("http://example.com", Config{ + Dial: dial, + FormData: map[string]string{ + "foo": "bar", + }, + }) -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) + require.Nil(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) -// app.Get("/", func(c fiber.Ctx) error { -// return c.Send(c.Request().Header.UserAgent()) -// }) + t.Run("client patch", func(t *testing.T) { + for i := 0; i < 5; i++ { + resp, err := AcquireClient().Patch("http://example.com", Config{ + Dial: dial, + FormData: map[string]string{ + "foo": "bar", + }, + }) -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + require.Nil(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) +} -// t.Run("default", func(t *testing.T) { -// for i := 0; i < 5; i++ { -// a := Get("http://example.com") +func Test_Client_UserAgent(t *testing.T) { + t.Parallel() -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + app, dial, start := createHelperServer(t) -// code, body, errs := a.String() + app.Get("/", func(c fiber.Ctx) error { + return c.Send(c.Request().Header.UserAgent()) + }) -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, defaultUserAgent, body) -// utils.AssertEqual(t, 0, len(errs)) -// } -// }) + go start() -// t.Run("custom", func(t *testing.T) { -// for i := 0; i < 5; i++ { -// c := AcquireClient() -// c.UserAgent = "ua" + t.Run("default", func(t *testing.T) { + for i := 0; i < 5; i++ { + resp, err := Get("http://example.com", Config{ + Dial: dial, + }) -// a := c.Get("http://example.com") + require.Nil(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, defaultUserAgent, resp.String()) + } + }) -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + t.Run("custom", func(t *testing.T) { + for i := 0; i < 5; i++ { + c := AcquireClient(). + SetUserAgent("ua") -// code, body, errs := a.String() + resp, err := c.Get("http://example.com", Config{ + Dial: dial, + }) -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "ua", body) -// utils.AssertEqual(t, 0, len(errs)) -// ReleaseClient(c) -// } -// }) -// } + require.Nil(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "ua", resp.String()) + ReleaseClient(c) + } + }) +} // func Test_Client_Agent_Set_Or_Add_Headers(t *testing.T) { // handler := func(c fiber.Ctx) error { From 6a2f0ab59986180f14d7d6b4c56302dd558a0f3f Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Sun, 16 Oct 2022 14:45:17 +0800 Subject: [PATCH 057/118] chore: add test case --- client/client.go | 3 + client/client_test.go | 136 ++++++++++++++--------------------------- client/helper_test.go | 51 +++++++++++++++- client/request_test.go | 24 ++++---- 4 files changed, 111 insertions(+), 103 deletions(-) diff --git a/client/client.go b/client/client.go index aad371312f..c1817a3501 100644 --- a/client/client.go +++ b/client/client.go @@ -604,6 +604,9 @@ func Replace(c *Client) func() { defaultClient = c return func() { + replaceMu.Lock() + defer replaceMu.Unlock() + defaultClient = oldClient } } diff --git a/client/client_test.go b/client/client_test.go index 6f0c45a853..9975c41689 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -336,101 +336,59 @@ func Test_Client_UserAgent(t *testing.T) { }) } -// func Test_Client_Agent_Set_Or_Add_Headers(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// c.Request().Header.VisitAll(func(key, value []byte) { -// if k := string(key); k == "K1" || k == "K2" { -// _, _ = c.Write(key) -// _, _ = c.Write(value) -// } -// }) -// return nil -// } - -// wrapAgent := func(a *Agent) { -// a.Set("k1", "v1"). -// SetBytesK([]byte("k1"), "v1"). -// SetBytesV("k1", []byte("v1")). -// AddBytesK([]byte("k1"), "v11"). -// AddBytesV("k1", []byte("v22")). -// AddBytesKV([]byte("k1"), []byte("v33")). -// SetBytesKV([]byte("k2"), []byte("v2")). -// Add("k2", "v22") -// } - -// testAgent(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") -// } - -// func Test_Client_Agent_Connection_Close(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// if c.Request().Header.ConnectionClose() { -// return c.SendString("close") -// } -// return c.SendString("not close") -// } - -// wrapAgent := func(a *Agent) { -// a.ConnectionClose() -// } - -// testAgent(t, handler, wrapAgent, "close") -// } - -// func Test_Client_Agent_UserAgent(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.Send(c.Request().Header.UserAgent()) -// } - -// wrapAgent := func(a *Agent) { -// a.UserAgent("ua"). -// UserAgentBytes([]byte("ua")) -// } - -// testAgent(t, handler, wrapAgent, "ua") -// } - -// func Test_Client_Agent_Cookie(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.SendString( -// c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) -// } - -// wrapAgent := func(a *Agent) { -// a.Cookie("k1", "v1"). -// CookieBytesK([]byte("k2"), "v2"). -// CookieBytesKV([]byte("k2"), []byte("v2")). -// Cookies("k3", "v3", "k4", "v4"). -// CookiesBytesKV([]byte("k3"), []byte("v3"), []byte("k4"), []byte("v4")) -// } - -// testAgent(t, handler, wrapAgent, "v1v2v3v4") -// } +func Test_Client_Headers(t *testing.T) { + handler := func(c fiber.Ctx) error { + c.Request().Header.VisitAll(func(key, value []byte) { + if k := string(key); k == "K1" || k == "K2" { + _, _ = c.Write(key) + _, _ = c.Write(value) + } + }) + return nil + } + + wrapAgent := func(c *Client) { + c.SetHeader("k1", "v1"). + AddHeader("k1", "v2"). + SetHeaders(map[string]string{ + "k2": "v2", + }). + AddHeaders(map[string][]string{ + "k2": {"v22"}, + }) + } -// func Test_Client_Agent_Referer(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.Send(c.Request().Header.Referer()) -// } + testClient(t, handler, wrapAgent, "K1v1K1v2K2v2K2v22") +} -// wrapAgent := func(a *Agent) { -// a.Referer("http://referer.com"). -// RefererBytes([]byte("http://referer.com")) -// } +func Test_Client_Cookie(t *testing.T) { + handler := func(c fiber.Ctx) error { + return c.SendString( + c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3")) + } + + wrapAgent := func(c *Client) { + c.SetCookie("k1", "v1"). + SetCookies(map[string]string{ + "k2": "v2", + "k3": "v3", + }) + } -// testAgent(t, handler, wrapAgent, "http://referer.com") -// } + testClient(t, handler, wrapAgent, "v1v2v3") +} -// func Test_Client_Agent_ContentType(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.Send(c.Request().Header.ContentType()) -// } +func Test_Client_Referer(t *testing.T) { + handler := func(c fiber.Ctx) error { + return c.Send(c.Request().Header.Referer()) + } -// wrapAgent := func(a *Agent) { -// a.ContentType("custom-type"). -// ContentTypeBytes([]byte("custom-type")) -// } + wrapAgent := func(c *Client) { + c.SetReferer("http://referer.com") + } -// testAgent(t, handler, wrapAgent, "custom-type") -// } + testClient(t, handler, wrapAgent, "http://referer.com") +} // func Test_Client_Agent_Host(t *testing.T) { // t.Parallel() diff --git a/client/helper_test.go b/client/helper_test.go index 080d8482ff..914945fcd1 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -22,7 +22,7 @@ func createHelperServer(t *testing.T) (*fiber.App, func(addr string) (net.Conn, } } -func testAgent(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted string, count ...int) { +func testRequest(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted string, count ...int) { t.Parallel() app, ln, start := createHelperServer(t) @@ -47,7 +47,54 @@ func testAgent(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Reques } } -func testAgentFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted error, count ...int) { +func testRequestFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted error, count ...int) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Get("/", handler) + go start() + + c := 1 + if len(count) > 0 { + c = count[0] + } + + for i := 0; i < c; i++ { + req := AcquireRequest().SetDial(ln) + wrapAgent(req) + + _, err := req.Get("http://example.com") + + require.Equal(t, excepted.Error(), err.Error()) + } +} + +func testClient(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Client), excepted string, count ...int) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Get("/", handler) + go start() + + c := 1 + if len(count) > 0 { + c = count[0] + } + + for i := 0; i < c; i++ { + client := AcquireClient() + wrapAgent(client) + + resp, err := client.Get("http://example.com", Config{Dial: ln}) + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, excepted, resp.String()) + resp.Close() + } +} + +func testClientFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted error, count ...int) { t.Parallel() app, ln, start := createHelperServer(t) diff --git a/client/request_test.go b/client/request_test.go index 69e5ed04b9..708746e08d 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -782,7 +782,7 @@ func Test_Request_Header_With_Server(t *testing.T) { AddHeader("k2", "v22") } - testAgent(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") + testRequest(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") } // func Test_Client_Agent_Connection_Close(t *testing.T) { @@ -808,11 +808,11 @@ func Test_Request_UserAgent_With_Server(t *testing.T) { } t.Run("default", func(t *testing.T) { - testAgent(t, handler, func(agent *Request) {}, defaultUserAgent, 5) + testRequest(t, handler, func(agent *Request) {}, defaultUserAgent, 5) }) t.Run("custom", func(t *testing.T) { - testAgent(t, handler, func(agent *Request) { + testRequest(t, handler, func(agent *Request) { agent.SetUserAgent("ua") }, "ua", 5) }) @@ -833,7 +833,7 @@ func Test_Request_Cookie_With_Server(t *testing.T) { }).DelCookies("k4") } - testAgent(t, handler, wrapAgent, "v1v2v3") + testRequest(t, handler, wrapAgent, "v1v2v3") } func Test_Request_Referer_With_Server(t *testing.T) { @@ -845,7 +845,7 @@ func Test_Request_Referer_With_Server(t *testing.T) { req.SetReferer("http://referer.com") } - testAgent(t, handler, wrapAgent, "http://referer.com") + testRequest(t, handler, wrapAgent, "http://referer.com") } // func Test_Client_Agent_Host(t *testing.T) { @@ -888,7 +888,7 @@ func Test_Request_QueryString_With_Server(t *testing.T) { }) } - testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") + testRequest(t, handler, wrapAgent, "foo=bar&bar=baz") } // func Test_Client_Agent_BasicAuth(t *testing.T) { @@ -932,7 +932,7 @@ func Test_Request_Body_With_Server(t *testing.T) { t.Parallel() t.Run("json body", func(t *testing.T) { - testAgent(t, + testRequest(t, func(c fiber.Ctx) error { require.Equal(t, "application/json", string(c.Request().Header.ContentType())) return c.SendString(string(c.Request().Body())) @@ -947,7 +947,7 @@ func Test_Request_Body_With_Server(t *testing.T) { }) t.Run("xml body", func(t *testing.T) { - testAgent(t, + testRequest(t, func(c fiber.Ctx) error { require.Equal(t, "application/xml", string(c.Request().Header.ContentType())) return c.SendString(string(c.Request().Body())) @@ -965,7 +965,7 @@ func Test_Request_Body_With_Server(t *testing.T) { }) t.Run("formdata", func(t *testing.T) { - testAgent(t, + testRequest(t, func(c fiber.Ctx) error { require.Equal(t, fiber.MIMEApplicationForm, string(c.Request().Header.ContentType())) return c.Send([]byte("foo=" + c.FormValue("foo") + "&bar=" + c.FormValue("bar") + "&fiber=" + c.FormValue("fiber"))) @@ -1097,7 +1097,7 @@ func Test_Request_Body_With_Server(t *testing.T) { }) t.Run("raw body", func(t *testing.T) { - testAgent(t, + testRequest(t, func(c fiber.Ctx) error { return c.SendString(string(c.Request().Body())) }, @@ -1111,7 +1111,7 @@ func Test_Request_Body_With_Server(t *testing.T) { func Test_Request_Error_Body_With_Server(t *testing.T) { t.Run("json error", func(t *testing.T) { - testAgentFail(t, + testClientFail(t, func(c fiber.Ctx) error { return c.SendString("") }, @@ -1123,7 +1123,7 @@ func Test_Request_Error_Body_With_Server(t *testing.T) { }) t.Run("xml error", func(t *testing.T) { - testAgentFail(t, + testClientFail(t, func(c fiber.Ctx) error { return c.SendString("") }, From d27665fc4d8ed5ae69c2e59769cdcc9f55b40948 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Sun, 16 Oct 2022 22:18:39 +0800 Subject: [PATCH 058/118] chore: add test case --- client/client.go | 14 +- client/client_test.go | 816 ++--------------------------------------- client/request_test.go | 34 -- 3 files changed, 28 insertions(+), 836 deletions(-) diff --git a/client/client.go b/client/client.go index c1817a3501..4ba26108cb 100644 --- a/client/client.go +++ b/client/client.go @@ -613,35 +613,35 @@ func Replace(c *Client) func() { // Get send a get request use defaultClient, a convenient method. func Get(url string, cfg ...Config) (*Response, error) { - return defaultClient.Get(url, cfg...) + return C().Get(url, cfg...) } // Post send a post request use defaultClient, a convenient method. func Post(url string, cfg ...Config) (*Response, error) { - return defaultClient.Post(url, cfg...) + return C().Post(url, cfg...) } // Head send a head request use defaultClient, a convenient method. func Head(url string, cfg ...Config) (*Response, error) { - return defaultClient.Head(url, cfg...) + return C().Head(url, cfg...) } // Put send a put request use defaultClient, a convenient method. func Put(url string, cfg ...Config) (*Response, error) { - return defaultClient.Put(url, cfg...) + return C().Put(url, cfg...) } // Delete send a delete request use defaultClient, a convenient method. func Delete(url string, cfg ...Config) (*Response, error) { - return defaultClient.Delete(url, cfg...) + return C().Delete(url, cfg...) } // Options send a options request use defaultClient, a convenient method. func Options(url string, cfg ...Config) (*Response, error) { - return defaultClient.Options(url, cfg...) + return C().Options(url, cfg...) } // Patch send a patch request use defaultClient, a convenient method. func Patch(url string, cfg ...Config) (*Response, error) { - return defaultClient.Patch(url, cfg...) + return C().Patch(url, cfg...) } diff --git a/client/client_test.go b/client/client_test.go index 9975c41689..2b7b83e431 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -349,22 +349,23 @@ func Test_Client_Headers(t *testing.T) { wrapAgent := func(c *Client) { c.SetHeader("k1", "v1"). - AddHeader("k1", "v2"). + AddHeader("k1", "v11"). + AddHeaders(map[string][]string{ + "k1": {"v22", "v33"}, + }). SetHeaders(map[string]string{ "k2": "v2", }). - AddHeaders(map[string][]string{ - "k2": {"v22"}, - }) + AddHeader("k2", "v22") } - testClient(t, handler, wrapAgent, "K1v1K1v2K2v2K2v22") + testClient(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") } func Test_Client_Cookie(t *testing.T) { handler := func(c fiber.Ctx) error { return c.SendString( - c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3")) + c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) } wrapAgent := func(c *Client) { @@ -372,7 +373,8 @@ func Test_Client_Cookie(t *testing.T) { SetCookies(map[string]string{ "k2": "v2", "k3": "v3", - }) + "k4": "v4", + }).DelCookies("k4") } testClient(t, handler, wrapAgent, "v1v2v3") @@ -390,797 +392,21 @@ func Test_Client_Referer(t *testing.T) { testClient(t, handler, wrapAgent, "http://referer.com") } -// func Test_Client_Agent_Host(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// return c.SendString(c.Hostname()) -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// a := Get("http://1.1.1.1:8080"). -// Host("example.com"). -// HostBytes([]byte("example.com")) - -// utils.AssertEqual(t, "1.1.1.1:8080", a.HostClient.Addr) - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.String() - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "example.com", body) -// utils.AssertEqual(t, 0, len(errs)) -// } - -// func Test_Client_Agent_QueryString(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.Send(c.Request().URI().QueryString()) -// } - -// wrapAgent := func(a *Agent) { -// a.QueryString("foo=bar&bar=baz"). -// QueryStringBytes([]byte("foo=bar&bar=baz")) -// } - -// testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") -// } - -// func Test_Client_Agent_BasicAuth(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// // Get authorization header -// auth := c.Get(fiber.HeaderAuthorization) -// // Decode the header contents -// raw, err := base64.StdEncoding.DecodeString(auth[6:]) -// utils.AssertEqual(t, nil, err) - -// return c.Send(raw) -// } - -// wrapAgent := func(a *Agent) { -// a.BasicAuth("foo", "bar"). -// BasicAuthBytes([]byte("foo"), []byte("bar")) -// } - -// testAgent(t, handler, wrapAgent, "foo:bar") -// } - -// func Test_Client_Agent_BodyString(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.Send(c.Request().Body()) -// } - -// wrapAgent := func(a *Agent) { -// a.BodyString("foo=bar&bar=baz") -// } - -// testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") -// } - -// func Test_Client_Agent_Body(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.Send(c.Request().Body()) -// } - -// wrapAgent := func(a *Agent) { -// a.Body([]byte("foo=bar&bar=baz")) -// } - -// testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") -// } - -// func Test_Client_Agent_BodyStream(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.Send(c.Request().Body()) -// } - -// wrapAgent := func(a *Agent) { -// a.BodyStream(strings.NewReader("body stream"), -1) -// } - -// testAgent(t, handler, wrapAgent, "body stream") -// } - -// func Test_Client_Agent_Custom_Response(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// return c.SendString("custom") -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// for i := 0; i < 5; i++ { -// a := AcquireAgent() -// resp := AcquireResponse() - -// req := a.Request() -// req.Header.SetMethod(fiber.MethodGet) -// req.SetRequestURI("http://example.com") - -// utils.AssertEqual(t, nil, a.Parse()) - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.SetResponse(resp). -// String() - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "custom", body) -// utils.AssertEqual(t, "custom", string(resp.Body())) -// utils.AssertEqual(t, 0, len(errs)) - -// ReleaseResponse(resp) -// } -// } - -// func Test_Client_Agent_Dest(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// return c.SendString("dest") -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// t.Run("small dest", func(t *testing.T) { -// dest := []byte("de") - -// a := Get("http://example.com") - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.Dest(dest[:0]).String() - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "dest", body) -// utils.AssertEqual(t, "de", string(dest)) -// utils.AssertEqual(t, 0, len(errs)) -// }) - -// t.Run("enough dest", func(t *testing.T) { -// dest := []byte("foobar") - -// a := Get("http://example.com") - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.Dest(dest[:0]).String() - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "dest", body) -// utils.AssertEqual(t, "destar", string(dest)) -// utils.AssertEqual(t, 0, len(errs)) -// }) -// } - -// // readErrorConn is a struct for testing retryIf -// type readErrorConn struct { -// net.Conn -// } - -// func (r *readErrorConn) Read(p []byte) (int, error) { -// return 0, fmt.Errorf("error") -// } - -// func (r *readErrorConn) Write(p []byte) (int, error) { -// return len(p), nil -// } - -// func (r *readErrorConn) Close() error { -// return nil -// } - -// func (r *readErrorConn) LocalAddr() net.Addr { -// return nil -// } - -// func (r *readErrorConn) RemoteAddr() net.Addr { -// return nil -// } -// func Test_Client_Agent_RetryIf(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// a := Post("http://example.com"). -// RetryIf(func(req *Request) bool { -// return true -// }) -// dialsCount := 0 -// a.HostClient.Dial = func(addr string) (net.Conn, error) { -// dialsCount++ -// switch dialsCount { -// case 1: -// return &readErrorConn{}, nil -// case 2: -// return &readErrorConn{}, nil -// case 3: -// return &readErrorConn{}, nil -// case 4: -// return ln.Dial() -// default: -// t.Fatalf("unexpected number of dials: %d", dialsCount) -// } -// panic("unreachable") -// } - -// _, _, errs := a.String() -// utils.AssertEqual(t, dialsCount, 4) -// utils.AssertEqual(t, 0, len(errs)) -// } - -// func Test_Client_Agent_Json(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// utils.AssertEqual(t, fiber.MIMEApplicationJSON, string(c.Request().Header.ContentType())) - -// return c.Send(c.Request().Body()) -// } - -// wrapAgent := func(a *Agent) { -// a.JSON(data{Success: true}) -// } - -// testAgent(t, handler, wrapAgent, `{"success":true}`) -// } - -// func Test_Client_Agent_Json_Error(t *testing.T) { -// a := Get("http://example.com"). -// JSONEncoder(json.Marshal). -// JSON(complex(1, 1)) - -// _, body, errs := a.String() - -// utils.AssertEqual(t, "", body) -// utils.AssertEqual(t, 1, len(errs)) -// utils.AssertEqual(t, "json: unsupported type: complex128", errs[0].Error()) -// } - -// func Test_Client_Agent_XML(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// utils.AssertEqual(t, fiber.MIMEApplicationXML, string(c.Request().Header.ContentType())) - -// return c.Send(c.Request().Body()) -// } - -// wrapAgent := func(a *Agent) { -// a.XML(data{Success: true}) -// } - -// testAgent(t, handler, wrapAgent, "true") -// } - -// func Test_Client_Agent_XML_Error(t *testing.T) { -// a := Get("http://example.com"). -// XML(complex(1, 1)) - -// _, body, errs := a.String() - -// utils.AssertEqual(t, "", body) -// utils.AssertEqual(t, 1, len(errs)) -// utils.AssertEqual(t, "xml: unsupported type: complex128", errs[0].Error()) -// } - -// func Test_Client_Agent_Form(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// utils.AssertEqual(t, fiber.MIMEApplicationForm, string(c.Request().Header.ContentType())) - -// return c.Send(c.Request().Body()) -// } - -// args := AcquireArgs() - -// args.Set("foo", "bar") - -// wrapAgent := func(a *Agent) { -// a.Form(args) -// } - -// testAgent(t, handler, wrapAgent, "foo=bar") - -// ReleaseArgs(args) -// } - -// func Test_Client_Agent_MultipartForm(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Post("/", func(c fiber.Ctx) error { -// utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) - -// mf, err := c.MultipartForm() -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, "bar", mf.Value["foo"][0]) - -// return c.Send(c.Request().Body()) -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// args := AcquireArgs() - -// args.Set("foo", "bar") - -// a := Post("http://example.com"). -// Boundary("myBoundary"). -// MultipartForm(args) - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.String() - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "--myBoundary\r\nContent-Disposition: form-data; name=\"foo\"\r\n\r\nbar\r\n--myBoundary--\r\n", body) -// utils.AssertEqual(t, 0, len(errs)) -// ReleaseArgs(args) -// } - -// func Test_Client_Agent_MultipartForm_Errors(t *testing.T) { -// t.Parallel() - -// a := AcquireAgent() -// a.mw = &errorMultipartWriter{} - -// args := AcquireArgs() -// args.Set("foo", "bar") - -// ff1 := &FormFile{"", "name1", []byte("content"), false} -// ff2 := &FormFile{"", "name2", []byte("content"), false} -// a.FileData(ff1, ff2). -// MultipartForm(args) - -// utils.AssertEqual(t, 4, len(a.errs)) -// ReleaseArgs(args) -// } - -// func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Post("/", func(c fiber.Ctx) error { -// utils.AssertEqual(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) - -// fh1, err := c.FormFile("field1") -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, fh1.Filename, "name") -// buf := make([]byte, fh1.Size) -// f, err := fh1.Open() -// utils.AssertEqual(t, nil, err) -// defer func() { _ = f.Close() }() -// _, err = f.Read(buf) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, "form file", string(buf)) - -// fh2, err := c.FormFile("index") -// utils.AssertEqual(t, nil, err) -// checkFormFile(t, fh2, ".github/testdata/index.html") - -// fh3, err := c.FormFile("file3") -// utils.AssertEqual(t, nil, err) -// checkFormFile(t, fh3, ".github/testdata/index.tmpl") - -// return c.SendString("multipart form files") -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// for i := 0; i < 5; i++ { -// ff := AcquireFormFile() -// ff.Fieldname = "field1" -// ff.Name = "name" -// ff.Content = []byte("form file") - -// a := Post("http://example.com"). -// Boundary("myBoundary"). -// FileData(ff). -// SendFiles(".github/testdata/index.html", "index", ".github/testdata/index.tmpl"). -// MultipartForm(nil) - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.String() - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "multipart form files", body) -// utils.AssertEqual(t, 0, len(errs)) - -// ReleaseFormFile(ff) -// } -// } - -// func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) { -// t.Helper() - -// basename := filepath.Base(filename) -// utils.AssertEqual(t, fh.Filename, basename) - -// b1, err := os.ReadFile(filename) -// utils.AssertEqual(t, nil, err) - -// b2 := make([]byte, fh.Size) -// f, err := fh.Open() -// utils.AssertEqual(t, nil, err) -// defer func() { _ = f.Close() }() -// _, err = f.Read(b2) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, b1, b2) -// } - -// func Test_Client_Agent_Multipart_Random_Boundary(t *testing.T) { -// t.Parallel() - -// a := Post("http://example.com"). -// MultipartForm(nil) - -// reg := regexp.MustCompile(`multipart/form-data; boundary=\w{30}`) - -// utils.AssertEqual(t, true, reg.Match(a.req.Header.Peek(fiber.HeaderContentType))) -// } - -// func Test_Client_Agent_Multipart_Invalid_Boundary(t *testing.T) { -// t.Parallel() - -// a := Post("http://example.com"). -// Boundary("*"). -// MultipartForm(nil) - -// utils.AssertEqual(t, 1, len(a.errs)) -// utils.AssertEqual(t, "mime: invalid boundary character", a.errs[0].Error()) -// } - -// func Test_Client_Agent_SendFile_Error(t *testing.T) { -// t.Parallel() - -// a := Post("http://example.com"). -// SendFile("non-exist-file!", "") - -// utils.AssertEqual(t, 1, len(a.errs)) -// utils.AssertEqual(t, true, strings.Contains(a.errs[0].Error(), "open non-exist-file!")) -// } - -// func Test_Client_Debug(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.SendString("debug") -// } - -// var output bytes.Buffer - -// wrapAgent := func(a *Agent) { -// a.Debug(&output) -// } - -// testAgent(t, handler, wrapAgent, "debug", 1) - -// str := output.String() - -// utils.AssertEqual(t, true, strings.Contains(str, "Connected to example.com(pipe)")) -// utils.AssertEqual(t, true, strings.Contains(str, "GET / HTTP/1.1")) -// utils.AssertEqual(t, true, strings.Contains(str, "User-Agent: fiber")) -// utils.AssertEqual(t, true, strings.Contains(str, "Host: example.com\r\n\r\n")) -// utils.AssertEqual(t, true, strings.Contains(str, "HTTP/1.1 200 OK")) -// utils.AssertEqual(t, true, strings.Contains(str, "Content-Type: text/plain; charset=utf-8\r\nContent-Length: 5\r\n\r\ndebug")) -// } - -// func Test_Client_Agent_Timeout(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// time.Sleep(time.Millisecond * 200) -// return c.SendString("timeout") -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// a := Get("http://example.com"). -// Timeout(time.Millisecond * 50) - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// _, body, errs := a.String() - -// utils.AssertEqual(t, "", body) -// utils.AssertEqual(t, 1, len(errs)) -// utils.AssertEqual(t, "timeout", errs[0].Error()) -// } - -// func Test_Client_Agent_Reuse(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// return c.SendString("reuse") -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// a := Get("http://example.com"). -// Reuse() - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.String() - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "reuse", body) -// utils.AssertEqual(t, 0, len(errs)) - -// code, body, errs = a.String() - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "reuse", body) -// utils.AssertEqual(t, 0, len(errs)) -// } - -// func Test_Client_Agent_InsecureSkipVerify(t *testing.T) { -// t.Parallel() - -// cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key") -// utils.AssertEqual(t, nil, err) - -// serverTLSConf := &tls.Config{ -// Certificates: []tls.Certificate{cer}, -// } - -// ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") -// utils.AssertEqual(t, nil, err) - -// ln = tls.NewListener(ln, serverTLSConf) - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// return c.SendString("ignore tls") -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// code, body, errs := Get("https://" + ln.Addr().String()). -// InsecureSkipVerify(). -// InsecureSkipVerify(). -// String() - -// utils.AssertEqual(t, 0, len(errs)) -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "ignore tls", body) -// } - -// func Test_Client_Agent_TLS(t *testing.T) { -// t.Parallel() - -// serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() -// utils.AssertEqual(t, nil, err) - -// ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") -// utils.AssertEqual(t, nil, err) - -// ln = tls.NewListener(ln, serverTLSConf) - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// return c.SendString("tls") -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// code, body, errs := Get("https://" + ln.Addr().String()). -// TLSConfig(clientTLSConf). -// String() - -// utils.AssertEqual(t, 0, len(errs)) -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "tls", body) -// } - -// func Test_Client_Agent_MaxRedirectsCount(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// if c.Request().URI().QueryArgs().Has("foo") { -// return c.Redirect("/foo") -// } -// return c.Redirect("/") -// }) -// app.Get("/foo", func(c fiber.Ctx) error { -// return c.SendString("redirect") -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// t.Run("success", func(t *testing.T) { -// a := Get("http://example.com?foo"). -// MaxRedirectsCount(1) - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.String() - -// utils.AssertEqual(t, 200, code) -// utils.AssertEqual(t, "redirect", body) -// utils.AssertEqual(t, 0, len(errs)) -// }) - -// t.Run("error", func(t *testing.T) { -// a := Get("http://example.com"). -// MaxRedirectsCount(1) - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// _, body, errs := a.String() - -// utils.AssertEqual(t, "", body) -// utils.AssertEqual(t, 1, len(errs)) -// utils.AssertEqual(t, "too many redirects detected when doing the request", errs[0].Error()) -// }) -// } - -// func Test_Client_Agent_Struct(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// return c.JSON(data{true}) -// }) - -// app.Get("/error", func(c fiber.Ctx) error { -// return c.SendString(`{"success"`) -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// t.Run("success", func(t *testing.T) { -// t.Parallel() - -// a := Get("http://example.com") - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// var d data - -// code, body, errs := a.Struct(&d) - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, `{"success":true}`, string(body)) -// utils.AssertEqual(t, 0, len(errs)) -// utils.AssertEqual(t, true, d.Success) -// }) - -// t.Run("pre error", func(t *testing.T) { -// t.Parallel() -// a := Get("http://example.com") - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } -// a.errs = append(a.errs, errors.New("pre errors")) - -// var d data -// _, body, errs := a.Struct(&d) - -// utils.AssertEqual(t, "", string(body)) -// utils.AssertEqual(t, 1, len(errs)) -// utils.AssertEqual(t, "pre errors", errs[0].Error()) -// utils.AssertEqual(t, false, d.Success) -// }) - -// t.Run("error", func(t *testing.T) { -// a := Get("http://example.com/error") - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// var d data - -// code, body, errs := a.JSONDecoder(json.Unmarshal).Struct(&d) - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, `{"success"`, string(body)) -// utils.AssertEqual(t, 1, len(errs)) -// utils.AssertEqual(t, "unexpected end of JSON input", errs[0].Error()) -// }) -// } - -// func Test_Client_Agent_Parse(t *testing.T) { -// t.Parallel() - -// a := Get("https://example.com:10443") - -// utils.AssertEqual(t, nil, a.Parse()) -// } - -// func Test_AddMissingPort_TLS(t *testing.T) { -// addr := addMissingPort("example.com", true) -// utils.AssertEqual(t, "example.com:443", addr) -// } - -// func testAgent(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Agent), excepted string, count ...int) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", handler) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// c := 1 -// if len(count) > 0 { -// c = count[0] -// } - -// for i := 0; i < c; i++ { -// a := Get("http://example.com") - -// wrapAgent(a) - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.String() - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, excepted, body) -// utils.AssertEqual(t, 0, len(errs)) -// } -// } - -// type data struct { -// Success bool `json:"success" xml:"success"` -// } - -// type errorMultipartWriter struct { -// count int -// } +func Test_Client_Params(t *testing.T) { + handler := func(c fiber.Ctx) error { + c.WriteString(c.Query("k1")) + c.WriteString(c.Query("k2")) -// func (e *errorMultipartWriter) Boundary() string { return "myBoundary" } -// func (e *errorMultipartWriter) SetBoundary(_ string) error { return nil } -// func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) { -// if e.count == 0 { -// e.count++ -// return nil, errors.New("CreateFormFile error") -// } -// return errorWriter{}, nil -// } -// func (e *errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") } -// func (e *errorMultipartWriter) Close() error { return errors.New("Close error") } + return nil + } -// type errorWriter struct{} + wrapAgent := func(c *Client) { + c.SetParam("k1", "v1"). + AddParam("k2", "v2") + } -// func (errorWriter) Write(_ []byte) (int, error) { return 0, errors.New("Write error") } + testClient(t, handler, wrapAgent, "v1v2") +} func Test_Client_R(t *testing.T) { t.Parallel() diff --git a/client/request_test.go b/client/request_test.go index 708746e08d..8ea1442d3d 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -785,21 +785,6 @@ func Test_Request_Header_With_Server(t *testing.T) { testRequest(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") } -// func Test_Client_Agent_Connection_Close(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// if c.Request().Header.ConnectionClose() { -// return c.SendString("close") -// } -// return c.SendString("not close") -// } - -// wrapAgent := func(a *Agent) { -// a.ConnectionClose() -// } - -// testAgent(t, handler, wrapAgent, "close") -// } - func Test_Request_UserAgent_With_Server(t *testing.T) { t.Parallel() @@ -891,25 +876,6 @@ func Test_Request_QueryString_With_Server(t *testing.T) { testRequest(t, handler, wrapAgent, "foo=bar&bar=baz") } -// func Test_Client_Agent_BasicAuth(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// // Get authorization header -// auth := c.Get(fiber.HeaderAuthorization) -// // Decode the header contents -// raw, err := base64.StdEncoding.DecodeString(auth[6:]) -// utils.AssertEqual(t, nil, err) - -// return c.Send(raw) -// } - -// wrapAgent := func(a *Agent) { -// a.BasicAuth("foo", "bar"). -// BasicAuthBytes([]byte("foo"), []byte("bar")) -// } - -// testAgent(t, handler, wrapAgent, "foo:bar") -// } - func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) { t.Helper() From af5cd0b0012ae176415ba50cc3e678f53e232f1d Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Sun, 16 Oct 2022 22:25:31 +0800 Subject: [PATCH 059/118] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20peek=20for=20?= =?UTF-8?q?client?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 22 +++++ client/client_test.go | 191 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 203 insertions(+), 10 deletions(-) diff --git a/client/client.go b/client/client.go index 4ba26108cb..cff1ee38b5 100644 --- a/client/client.go +++ b/client/client.go @@ -10,6 +10,7 @@ import ( "net/url" "os" "path/filepath" + "sort" "sync" "time" @@ -225,6 +226,13 @@ func (c *Client) SetBaseURL(url string) *Client { return c } +// Header method returns header value via key, +// this method will visit all field in the header, +// then sort them. +func (c *Client) Header(key string) []string { + return c.header.PeekMultiple(key) +} + // AddHeader method adds a single header field and its value in the client instance. // These headers will be applied to all requests raised from this client instance. // Also it can be overridden at request level header options. @@ -257,6 +265,20 @@ func (c *Client) SetHeaders(h map[string]string) *Client { return c } +// Param method returns params value via key, +// this method will visit all field in the query param, +// then sort them. +func (c *Client) Param(key string) []string { + res := []string{} + tmp := c.params.PeekMulti(key) + for _, v := range tmp { + res = append(res, utils.UnsafeString(v)) + } + sort.Strings(res) + + return res +} + // AddParam method adds a single query param field and its value in the client instance. // These params will be applied to all requests raised from this client instance. // Also it can be overridden at request level param options. diff --git a/client/client_test.go b/client/client_test.go index 2b7b83e431..590cee49f8 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -336,7 +336,66 @@ func Test_Client_UserAgent(t *testing.T) { }) } -func Test_Client_Headers(t *testing.T) { +func Test_Client_Header(t *testing.T) { + t.Parallel() + + t.Run("add header", func(t *testing.T) { + req := AcquireClient() + req.AddHeader("foo", "bar").AddHeader("foo", "fiber") + + res := req.Header("foo") + require.Equal(t, 2, len(res)) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + }) + + t.Run("set header", func(t *testing.T) { + req := AcquireClient() + req.AddHeader("foo", "bar").SetHeader("foo", "fiber") + + res := req.Header("foo") + require.Equal(t, 1, len(res)) + require.Equal(t, "fiber", res[0]) + }) + + t.Run("add headers", func(t *testing.T) { + req := AcquireClient() + req.SetHeader("foo", "bar"). + AddHeaders(map[string][]string{ + "foo": {"fiber", "buaa"}, + "bar": {"foo"}, + }) + + res := req.Header("foo") + require.Equal(t, 3, len(res)) + require.Equal(t, "bar", res[0]) + require.Equal(t, "buaa", res[1]) + require.Equal(t, "fiber", res[2]) + + res = req.Header("bar") + require.Equal(t, 1, len(res)) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set headers", func(t *testing.T) { + req := AcquireClient() + req.SetHeader("foo", "bar"). + SetHeaders(map[string]string{ + "foo": "fiber", + "bar": "foo", + }) + + res := req.Header("foo") + require.Equal(t, 1, len(res)) + require.Equal(t, "fiber", res[0]) + + res = req.Header("bar") + require.Equal(t, 1, len(res)) + require.Equal(t, "foo", res[0]) + }) +} + +func Test_Client_Header_With_Server(t *testing.T) { handler := func(c fiber.Ctx) error { c.Request().Header.VisitAll(func(key, value []byte) { if k := string(key); k == "K1" || k == "K2" { @@ -392,7 +451,127 @@ func Test_Client_Referer(t *testing.T) { testClient(t, handler, wrapAgent, "http://referer.com") } -func Test_Client_Params(t *testing.T) { +func Test_Client_QueryParam(t *testing.T) { + t.Parallel() + + t.Run("add param", func(t *testing.T) { + req := AcquireClient() + req.AddParam("foo", "bar").AddParam("foo", "fiber") + + res := req.Param("foo") + require.Equal(t, 2, len(res)) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + }) + + t.Run("set param", func(t *testing.T) { + req := AcquireClient() + req.AddParam("foo", "bar").SetParam("foo", "fiber") + + res := req.Param("foo") + require.Equal(t, 1, len(res)) + require.Equal(t, "fiber", res[0]) + }) + + t.Run("add params", func(t *testing.T) { + req := AcquireClient() + req.SetParam("foo", "bar"). + AddParams(map[string][]string{ + "foo": {"fiber", "buaa"}, + "bar": {"foo"}, + }) + + res := req.Param("foo") + require.Equal(t, 3, len(res)) + require.Equal(t, "bar", res[0]) + require.Equal(t, "buaa", res[1]) + require.Equal(t, "fiber", res[2]) + + res = req.Param("bar") + require.Equal(t, 1, len(res)) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set headers", func(t *testing.T) { + req := AcquireClient() + req.SetParam("foo", "bar"). + SetParams(map[string]string{ + "foo": "fiber", + "bar": "foo", + }) + + res := req.Param("foo") + require.Equal(t, 1, len(res)) + require.Equal(t, "fiber", res[0]) + + res = req.Param("bar") + require.Equal(t, 1, len(res)) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set params with struct", func(t *testing.T) { + t.Parallel() + + type args struct { + TInt int + TString string + TFloat float64 + TBool bool + TSlice []string + TIntSlice []int `param:"int_slice"` + } + + p := AcquireClient() + p.SetParamsWithStruct(&args{ + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: true, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + + require.Equal(t, 0, len(p.Param("unexport"))) + + require.Equal(t, 1, len(p.Param("TInt"))) + require.Equal(t, "5", p.Param("TInt")[0]) + + require.Equal(t, 1, len(p.Param("TString"))) + require.Equal(t, "string", p.Param("TString")[0]) + + require.Equal(t, 1, len(p.Param("TFloat"))) + require.Equal(t, "3.1", p.Param("TFloat")[0]) + + require.Equal(t, 1, len(p.Param("TBool"))) + + tslice := p.Param("TSlice") + require.Equal(t, 2, len(tslice)) + require.Equal(t, "bar", tslice[0]) + require.Equal(t, "foo", tslice[1]) + + tint := p.Param("TSlice") + require.Equal(t, 2, len(tint)) + require.Equal(t, "bar", tint[0]) + require.Equal(t, "foo", tint[1]) + }) + + t.Run("del params", func(t *testing.T) { + req := AcquireClient() + req.SetParam("foo", "bar"). + SetParams(map[string]string{ + "foo": "fiber", + "bar": "foo", + }).DelParams("foo", "bar") + + res := req.Param("foo") + require.Equal(t, 0, len(res)) + + res = req.Param("bar") + require.Equal(t, 0, len(res)) + }) +} + +func Test_Client_QueryParam_With_Server(t *testing.T) { handler := func(c fiber.Ctx) error { c.WriteString(c.Query("k1")) c.WriteString(c.Query("k2")) @@ -505,11 +684,3 @@ func Test_Client_SetBaseURL(t *testing.T) { require.Equal(t, "http://example.com", client.BaseURL()) } - -func Test_Client_Header(t *testing.T) { - t.Parallel() - - t.Run("", func(t *testing.T) { - - }) -} From 37358db6db762352a8f1c2fe76f2ffd84de164de Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Mon, 17 Oct 2022 15:36:29 +0800 Subject: [PATCH 060/118] =?UTF-8?q?=E2=9C=85=20chore:=20add=20test=20case?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 19 +++ client/client_test.go | 300 ++++++++++++++++++++++++++++++++---------- 2 files changed, 252 insertions(+), 67 deletions(-) diff --git a/client/client.go b/client/client.go index cff1ee38b5..de7ef75f48 100644 --- a/client/client.go +++ b/client/client.go @@ -343,6 +343,16 @@ func (c *Client) SetReferer(r string) *Client { return c } +// PathParam returns the path param be set in request instance. +// if path param doesn't exist, return empty string. +func (c *Client) PathParam(key string) string { + if val, ok := (*c.path)[key]; ok { + return val + } + + return "" +} + // SetPathParam method sets a single path param field and its value in the client instance. // These path params will be applied to all requests raised from this client instance. // Also it can be overridden at request level path params options. @@ -373,6 +383,15 @@ func (c *Client) DelPathParams(key ...string) *Client { return c } +// Cookie returns the cookie be set in request instance. +// if cookie doesn't exist, return empty string. +func (c *Client) Cookie(key string) string { + if val, ok := (*c.cookies)[key]; ok { + return val + } + return "" +} + // SetCookie method sets a single cookie field and its value in the client instance. // These cookies will be applied to all requests raised from this client instance. // Also it can be overridden at request level cookie options. diff --git a/client/client_test.go b/client/client_test.go index 590cee49f8..20d841d2d0 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -10,6 +10,94 @@ import ( "github.com/stretchr/testify/require" ) +func Test_Client_Add_Hook(t *testing.T) { + t.Parallel() + + t.Run("add request hooks", func(t *testing.T) { + client := AcquireClient().AddRequestHook(func(c *Client, r *Request) error { + return nil + }) + + require.Equal(t, 1, len(client.RequestHook())) + + client.AddRequestHook(func(c *Client, r *Request) error { + return nil + }, func(c *Client, r *Request) error { + return nil + }) + + require.Equal(t, 3, len(client.RequestHook())) + }) + + t.Run("add response hooks", func(t *testing.T) { + client := AcquireClient().AddResponseHook(func(c *Client, resp *Response, r *Request) error { + return nil + }) + + require.Equal(t, 1, len(client.ResponseHook())) + + client.AddResponseHook(func(c *Client, resp *Response, r *Request) error { + return nil + }, func(c *Client, resp *Response, r *Request) error { + return nil + }) + + require.Equal(t, 3, len(client.ResponseHook())) + }) +} + +func Test_Client_Marshal(t *testing.T) { + t.Run("set json marshal", func(t *testing.T) { + client := AcquireClient(). + SetJSONMarshal(func(v any) ([]byte, error) { + return []byte("hello"), nil + }) + val, err := client.JSONMarshal()(nil) + + require.NoError(t, err) + require.Equal(t, []byte("hello"), val) + }) + + t.Run("set json unmarshal", func(t *testing.T) { + client := AcquireClient(). + SetJSONUnmarshal(func(data []byte, v any) error { + return fmt.Errorf("empty json") + }) + + err := client.JSONUnmarshal()(nil, nil) + require.Equal(t, fmt.Errorf("empty json"), err) + }) + + t.Run("set xml marshal", func(t *testing.T) { + client := AcquireClient(). + SetXMLMarshal(func(v any) ([]byte, error) { + return []byte("hello"), nil + }) + val, err := client.XMLMarshal()(nil) + + require.NoError(t, err) + require.Equal(t, []byte("hello"), val) + }) + + t.Run("set xml unmarshal", func(t *testing.T) { + client := AcquireClient(). + SetXMLUnmarshal(func(data []byte, v any) error { + return fmt.Errorf("empty xml") + }) + + err := client.XMLUnmarshal()(nil, nil) + require.Equal(t, fmt.Errorf("empty xml"), err) + }) +} + +func Test_Client_SetBaseURL(t *testing.T) { + t.Parallel() + + client := AcquireClient().SetBaseURL("http://example.com") + + require.Equal(t, "http://example.com", client.BaseURL()) +} + func Test_Client_Invalid_URL(t *testing.T) { t.Parallel() @@ -422,6 +510,64 @@ func Test_Client_Header_With_Server(t *testing.T) { } func Test_Client_Cookie(t *testing.T) { + t.Parallel() + + t.Run("set cookie", func(t *testing.T) { + req := AcquireClient(). + SetCookie("foo", "bar") + require.Equal(t, "bar", req.Cookie("foo")) + + req.SetCookie("foo", "bar1") + require.Equal(t, "bar1", req.Cookie("foo")) + }) + + t.Run("set cookies", func(t *testing.T) { + req := AcquireClient(). + SetCookies(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + + req.SetCookies(map[string]string{ + "foo": "bar1", + }) + require.Equal(t, "bar1", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + }) + + t.Run("set cookies with struct", func(t *testing.T) { + type args struct { + CookieInt int `cookie:"int"` + CookieString string `cookie:"string"` + } + + req := AcquireClient().SetCookiesWithStruct(&args{ + CookieInt: 5, + CookieString: "foo", + }) + + require.Equal(t, "5", req.Cookie("int")) + require.Equal(t, "foo", req.Cookie("string")) + }) + + t.Run("del cookies", func(t *testing.T) { + req := AcquireClient(). + SetCookies(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + + req.DelCookies("foo") + require.Equal(t, "", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + }) +} + +func Test_Client_Cookie_With_Server(t *testing.T) { handler := func(c fiber.Ctx) error { return c.SendString( c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) @@ -587,100 +733,120 @@ func Test_Client_QueryParam_With_Server(t *testing.T) { testClient(t, handler, wrapAgent, "v1v2") } -func Test_Client_R(t *testing.T) { +func Test_Client_PathParam(t *testing.T) { t.Parallel() - client := AcquireClient() - req := client.R() + t.Run("set path param", func(t *testing.T) { + req := AcquireClient(). + SetPathParam("foo", "bar") + require.Equal(t, "bar", req.PathParam("foo")) - require.Equal(t, "Request", reflect.TypeOf(req).Elem().Name()) - require.Equal(t, client, req.Client()) -} + req.SetPathParam("foo", "bar1") + require.Equal(t, "bar1", req.PathParam("foo")) + }) -func Test_Client_Add_Hook(t *testing.T) { - t.Parallel() + t.Run("set path params", func(t *testing.T) { + req := AcquireClient(). + SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) - t.Run("add request hooks", func(t *testing.T) { - client := AcquireClient().AddRequestHook(func(c *Client, r *Request) error { - return nil + req.SetPathParams(map[string]string{ + "foo": "bar1", }) + require.Equal(t, "bar1", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + }) - require.Equal(t, 1, len(client.RequestHook())) + t.Run("set path params with struct", func(t *testing.T) { + type args struct { + CookieInt int `path:"int"` + CookieString string `path:"string"` + } - client.AddRequestHook(func(c *Client, r *Request) error { - return nil - }, func(c *Client, r *Request) error { - return nil + req := AcquireClient().SetPathParamsWithStruct(&args{ + CookieInt: 5, + CookieString: "foo", }) - require.Equal(t, 3, len(client.RequestHook())) + require.Equal(t, "5", req.PathParam("int")) + require.Equal(t, "foo", req.PathParam("string")) }) - t.Run("add response hooks", func(t *testing.T) { - client := AcquireClient().AddResponseHook(func(c *Client, resp *Response, r *Request) error { - return nil - }) + t.Run("del path params", func(t *testing.T) { + req := AcquireClient(). + SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) - require.Equal(t, 1, len(client.ResponseHook())) + req.DelPathParams("foo") + require.Equal(t, "", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + }) +} - client.AddResponseHook(func(c *Client, resp *Response, r *Request) error { - return nil - }, func(c *Client, resp *Response, r *Request) error { - return nil - }) +func Test_Client_PathParam_With_Server(t *testing.T) { + app, dial, start := createHelperServer(t) - require.Equal(t, 3, len(client.ResponseHook())) + app.Get("/test", func(c fiber.Ctx) error { + return c.SendString("ok") }) + + go start() + + resp, err := AcquireClient(). + SetPathParam("path", "test"). + Get("http://example.com/:path", Config{Dial: dial}) + + require.Nil(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "ok", resp.String()) } -func Test_Client_Marshal(t *testing.T) { - t.Run("set json marshal", func(t *testing.T) { - client := AcquireClient(). - SetJSONMarshal(func(v any) ([]byte, error) { - return []byte("hello"), nil - }) - val, err := client.JSONMarshal()(nil) +func Test_Client_R(t *testing.T) { + t.Parallel() - require.NoError(t, err) - require.Equal(t, []byte("hello"), val) - }) + client := AcquireClient() + req := client.R() - t.Run("set json unmarshal", func(t *testing.T) { - client := AcquireClient(). - SetJSONUnmarshal(func(data []byte, v any) error { - return fmt.Errorf("empty json") - }) + require.Equal(t, "Request", reflect.TypeOf(req).Elem().Name()) + require.Equal(t, client, req.Client()) +} - err := client.JSONUnmarshal()(nil, nil) - require.Equal(t, fmt.Errorf("empty json"), err) +func Test_Replace(t *testing.T) { + app, dial, start := createHelperServer(t) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(string(c.Request().Header.Peek("k1"))) }) - t.Run("set xml marshal", func(t *testing.T) { - client := AcquireClient(). - SetXMLMarshal(func(v any) ([]byte, error) { - return []byte("hello"), nil - }) - val, err := client.XMLMarshal()(nil) + go start() - require.NoError(t, err) - require.Equal(t, []byte("hello"), val) - }) + resp, err := Get("http://example.com", Config{Dial: dial}) - t.Run("set xml unmarshal", func(t *testing.T) { - client := AcquireClient(). - SetXMLUnmarshal(func(data []byte, v any) error { - return fmt.Errorf("empty xml") - }) + require.Nil(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "", resp.String()) - err := client.XMLUnmarshal()(nil, nil) - require.Equal(t, fmt.Errorf("empty xml"), err) - }) -} + r := AcquireClient().SetHeader("k1", "v1") + clean := Replace(r) + resp, err = Get("http://example.com", Config{Dial: dial}) + require.Nil(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "v1", resp.String()) -func Test_Client_SetBaseURL(t *testing.T) { - t.Parallel() + clean() + ReleaseClient(r) - client := AcquireClient().SetBaseURL("http://example.com") + resp, err = Get("http://example.com", Config{Dial: dial}) - require.Equal(t, "http://example.com", client.BaseURL()) + require.Nil(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "", resp.String()) } From dd51324dbf8066b58f102f4b5a528a13ceb65e37 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Tue, 18 Oct 2022 11:08:25 +0800 Subject: [PATCH 061/118] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20feat:=20lazy=20gen?= =?UTF-8?q?erate=20rand=20string?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client_test.go | 82 ++++++++++++++++++++++++++++++++++++++++++ client/core.go | 1 + client/helper_test.go | 24 +------------ client/hooks.go | 3 ++ client/request.go | 4 +-- client/request_test.go | 4 +-- 6 files changed, 91 insertions(+), 27 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index 20d841d2d0..26f490b2ac 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,11 +1,14 @@ package client import ( + "crypto/tls" "fmt" + "net" "reflect" "testing" "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/internal/tlstest" "github.com/gofiber/utils" "github.com/stretchr/testify/require" ) @@ -809,6 +812,67 @@ func Test_Client_PathParam_With_Server(t *testing.T) { require.Equal(t, "ok", resp.String()) } +// func Test_Client_Cert(t *testing.T) { +// t.Parallel() + +// serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() +// require.Nil(t, err) + +// ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") +// require.Nil(t, err) + +// ln = tls.NewListener(ln, serverTLSConf) + +// app := fiber.New() +// app.Get("/", func(c fiber.Ctx) error { +// return c.SendString("tls") +// }) + +// go func() { +// require.Nil(t, nil, app.Listener(ln, fiber.ListenConfig{ +// DisableStartupMessage: true, +// })) +// }() + +// client := AcquireClient().SetCertificates(clientTLSConf.Certificates...) +// resp, err := client.Get("https://" + ln.Addr().String()) + +// require.Nil(t, err) +// require.Equal(t, fiber.StatusOK, resp.StatusCode()) +// require.Equal(t, "tls", resp.String()) +// } + +func Test_Client_TLS(t *testing.T) { + t.Parallel() + + serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() + require.Nil(t, err) + + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + require.Nil(t, err) + + ln = tls.NewListener(ln, serverTLSConf) + + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("tls") + }) + + go func() { + require.Nil(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + + client := AcquireClient() + resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) + + require.Nil(t, err) + require.Equal(t, clientTLSConf, client.TLSConfig()) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "tls", resp.String()) +} + func Test_Client_R(t *testing.T) { t.Parallel() @@ -850,3 +914,21 @@ func Test_Replace(t *testing.T) { require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "", resp.String()) } + +func Benchmark_Client_Request(b *testing.B) { + app, dial, start := createHelperServer(b) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("hello world") + }) + + go start() + + b.ResetTimer() + b.ReportAllocs() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + resp, _ := Get("http://example.com", Config{Dial: dial}) + resp.Close() + } +} diff --git a/client/core.go b/client/core.go index f0826cdd45..1231c03c34 100644 --- a/client/core.go +++ b/client/core.go @@ -17,6 +17,7 @@ import ( var ( httpBytes = []byte("http") httpsBytes = []byte("https") + boundary = "--FiberFormBoundary" ) // RequestHook is a function that receives Agent and Request, diff --git a/client/helper_test.go b/client/helper_test.go index 914945fcd1..c27eb46105 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -9,7 +9,7 @@ import ( "github.com/valyala/fasthttp/fasthttputil" ) -func createHelperServer(t *testing.T) (*fiber.App, func(addr string) (net.Conn, error), func()) { +func createHelperServer(t testing.TB) (*fiber.App, func(addr string) (net.Conn, error), func()) { t.Helper() ln := fasthttputil.NewInmemoryListener() @@ -93,25 +93,3 @@ func testClient(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Clien resp.Close() } } - -func testClientFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted error, count ...int) { - t.Parallel() - - app, ln, start := createHelperServer(t) - app.Get("/", handler) - go start() - - c := 1 - if len(count) > 0 { - c = count[0] - } - - for i := 0; i < c; i++ { - req := AcquireRequest().SetDial(ln) - wrapAgent(req) - - _, err := req.Get("http://example.com") - - require.Equal(t, excepted.Error(), err.Error()) - } -} diff --git a/client/hooks.go b/client/hooks.go index 48ce7693f0..d9202bccd6 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -131,6 +131,9 @@ func parserRequestHeader(c *Client, req *Request) error { case filesBody: req.RawRequest.Header.SetContentType(multipartFormData) // set boundary + if req.boundary == boundary { + req.boundary = req.boundary + randString(16) + } req.RawRequest.Header.SetMultipartFormBoundary(req.boundary) default: } diff --git a/client/request.go b/client/request.go index e938092583..b1afbcd2d6 100644 --- a/client/request.go +++ b/client/request.go @@ -533,6 +533,7 @@ func (r *Request) Reset() { r.ctx = nil r.body = nil r.bodyType = noBody + r.boundary = boundary for len(r.files) != 0 { t := r.files[0] @@ -808,7 +809,7 @@ var requestPool = &sync.Pool{ params: &QueryParam{Args: fasthttp.AcquireArgs()}, cookies: &Cookie{}, path: &PathParam{}, - boundary: "--FiberFormBoundary" + randString(16), + boundary: "--FiberFormBoundary", formData: &FormData{Args: fasthttp.AcquireArgs()}, files: make([]*File, 0), RawRequest: fasthttp.AcquireRequest(), @@ -822,7 +823,6 @@ var requestPool = &sync.Pool{ // This allows reducing GC load. func AcquireRequest() *Request { req := requestPool.Get().(*Request) - req.boundary = "--FiberFormBoundary" + randString(16) return req } diff --git a/client/request_test.go b/client/request_test.go index 8ea1442d3d..2dadb7f498 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -1077,7 +1077,7 @@ func Test_Request_Body_With_Server(t *testing.T) { func Test_Request_Error_Body_With_Server(t *testing.T) { t.Run("json error", func(t *testing.T) { - testClientFail(t, + testRequestFail(t, func(c fiber.Ctx) error { return c.SendString("") }, @@ -1089,7 +1089,7 @@ func Test_Request_Error_Body_With_Server(t *testing.T) { }) t.Run("xml error", func(t *testing.T) { - testClientFail(t, + testRequestFail(t, func(c fiber.Ctx) error { return c.SendString("") }, From aac6425dd694b193dac25dc4d0511b51cda70839 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Thu, 10 Nov 2022 13:08:26 +0800 Subject: [PATCH 062/118] =?UTF-8?q?=F0=9F=9A=A7=20perf:=20add=20config=20t?= =?UTF-8?q?est=20case?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client_test.go | 119 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) diff --git a/client/client_test.go b/client/client_test.go index 26f490b2ac..6c8273b1ea 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,6 +1,7 @@ package client import ( + "context" "crypto/tls" "fmt" "net" @@ -915,6 +916,124 @@ func Test_Replace(t *testing.T) { require.Equal(t, "", resp.String()) } +func Test_Set_Config_To_Request(t *testing.T) { + t.Parallel() + + t.Run("set ctx", func(t *testing.T) { + key := struct{}{} + + ctx := context.Background() + ctx = context.WithValue(ctx, key, "v1") + + req := AcquireRequest() + + setConfigToRequest(req, Config{Ctx: ctx}) + + require.Equal(t, "v1", req.Context().Value(key)) + }) + + t.Run("set useragent", func(t *testing.T) { + req := AcquireRequest() + + setConfigToRequest(req, Config{UserAgent: "agent"}) + + require.Equal(t, "agent", req.UserAgent()) + }) + + t.Run("set referer", func(t *testing.T) { + req := AcquireRequest() + + setConfigToRequest(req, Config{Referer: "referer"}) + + require.Equal(t, "referer", req.Referer()) + }) + + t.Run("set header", func(t *testing.T) { + req := AcquireRequest() + + setConfigToRequest(req, Config{Header: map[string]string{ + "k1": "v1", + }}) + + require.Equal(t, "v1", req.Header("k1")[0]) + }) + + t.Run("set params", func(t *testing.T) { + req := AcquireRequest() + + setConfigToRequest(req, Config{Param: map[string]string{ + "k1": "v1", + }}) + + require.Equal(t, "v1", req.Param("k1")[0]) + }) + + // t.Run("set ctx", func(t *testing.T) { + // key := struct{}{} + + // ctx := context.Background() + // ctx = context.WithValue(ctx, key, "v1") + + // req := AcquireRequest() + + // setConfigToRequest(req, Config{Ctx: ctx}) + + // require.Equal(t, "v1", req.Context().Value(key)) + // }) + + // t.Run("set ctx", func(t *testing.T) { + // key := struct{}{} + + // ctx := context.Background() + // ctx = context.WithValue(ctx, key, "v1") + + // req := AcquireRequest() + + // setConfigToRequest(req, Config{Ctx: ctx}) + + // require.Equal(t, "v1", req.Context().Value(key)) + // }) + + // t.Run("set ctx", func(t *testing.T) { + // key := struct{}{} + + // ctx := context.Background() + // ctx = context.WithValue(ctx, key, "v1") + + // req := AcquireRequest() + + // setConfigToRequest(req, Config{Ctx: ctx}) + + // require.Equal(t, "v1", req.Context().Value(key)) + // }) + + // t.Run("set ctx", func(t *testing.T) { + // key := struct{}{} + + // ctx := context.Background() + // ctx = context.WithValue(ctx, key, "v1") + + // req := AcquireRequest() + + // setConfigToRequest(req, Config{Ctx: ctx}) + + // require.Equal(t, "v1", req.Context().Value(key)) + // }) + + // t.Run("set ctx", func(t *testing.T) { + // key := struct{}{} + + // ctx := context.Background() + // ctx = context.WithValue(ctx, key, "v1") + + // req := AcquireRequest() + + // setConfigToRequest(req, Config{Ctx: ctx}) + + // require.Equal(t, "v1", req.Context().Value(key)) + // }) +} + func Benchmark_Client_Request(b *testing.B) { app, dial, start := createHelperServer(b) app.Get("/", func(c fiber.Ctx) error { From 7e9564da9af8bbb0e442a77c2c7c715dd2b60c42 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Thu, 10 Nov 2022 13:16:29 +0800 Subject: [PATCH 063/118] =?UTF-8?q?=F0=9F=90=9B=20fix:=20fix=20merge=20err?= =?UTF-8?q?or?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 2 +- client/hooks.go | 2 +- go.mod | 1 + go.sum | 2 ++ middleware/proxy/proxy_test.go | 24 ++++++++++++++++-------- 5 files changed, 21 insertions(+), 10 deletions(-) diff --git a/client/client.go b/client/client.go index de7ef75f48..b64142a7f5 100644 --- a/client/client.go +++ b/client/client.go @@ -14,7 +14,7 @@ import ( "sync" "time" - "github.com/gofiber/utils" + "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) diff --git a/client/hooks.go b/client/hooks.go index d9202bccd6..2b495471cd 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -11,7 +11,7 @@ import ( "strings" "time" - "github.com/gofiber/utils" + "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) diff --git a/go.mod b/go.mod index 56e11d3baa..020c45263a 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/gofiber/fiber/v3 go 1.19 require ( + github.com/gofiber/utils v1.0.1 github.com/gofiber/utils/v2 v2.0.0-beta.1 github.com/google/uuid v1.3.0 github.com/mattn/go-colorable v0.1.13 diff --git a/go.sum b/go.sum index 7412a8ab1b..840582417f 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHG github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gofiber/utils v1.0.1 h1:knct4cXwBipWQqFrOy1Pv6UcgPM+EXo9jDgc66V1Qio= +github.com/gofiber/utils v1.0.1/go.mod h1:pacRFtghAE3UoknMOUiXh2Io/nLWSUHtQCi/3QASsOc= github.com/gofiber/utils/v2 v2.0.0-beta.1 h1:ACfPdqeclx+BFIja19UjkKx7k3r5tmpILpNgzrfPLKs= github.com/gofiber/utils/v2 v2.0.0-beta.1/go.mod h1:CG89nDoIkEFIJaw5LdLO9AmBM11odse/LC79KQujm74= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 41b88d91b0..9941a457bc 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -441,10 +441,14 @@ func Test_Proxy_Forward_Global_Client(t *testing.T) { })) }() - code, body, errs := fiber.Get("http://" + addr).String() - require.Equal(t, 0, len(errs)) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "test_global_client", body) + resp, err := fiberClient.AcquireClient(). + R(). + Get("https://" + addr) + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "test_global_client", resp.String()) + resp.Close() } // go test -race -run Test_Proxy_Forward_Local_Client @@ -471,10 +475,14 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) { })) }() - code, body, errs := fiber.Get("http://" + addr).String() - require.Equal(t, 0, len(errs)) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "test_local_client", body) + resp, err := fiberClient.AcquireClient(). + R(). + Get("https://" + addr) + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "test_local_client", resp.String()) + resp.Close() } // go test -run Test_ProxyBalancer_Custom_Client From 7e5445f2c0acae156e6551ef78a5bab9fbe8919c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Muhammed=20Efe=20=C3=87etin?= Date: Sat, 12 Nov 2022 10:09:41 +0300 Subject: [PATCH 064/118] :bug: fix utils error --- client/client.go | 2 +- client/client_test.go | 2 +- client/hooks.go | 2 +- client/request.go | 2 +- client/request_test.go | 2 ++ client/response.go | 2 +- middleware/proxy/proxy_test.go | 12 ++++++++---- 7 files changed, 15 insertions(+), 9 deletions(-) diff --git a/client/client.go b/client/client.go index de7ef75f48..b64142a7f5 100644 --- a/client/client.go +++ b/client/client.go @@ -14,7 +14,7 @@ import ( "sync" "time" - "github.com/gofiber/utils" + "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) diff --git a/client/client_test.go b/client/client_test.go index 6c8273b1ea..23c94a4cfc 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -10,7 +10,7 @@ import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/tlstest" - "github.com/gofiber/utils" + "github.com/gofiber/utils/v2" "github.com/stretchr/testify/require" ) diff --git a/client/hooks.go b/client/hooks.go index d9202bccd6..2b495471cd 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -11,7 +11,7 @@ import ( "strings" "time" - "github.com/gofiber/utils" + "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) diff --git a/client/request.go b/client/request.go index b1afbcd2d6..786cc2231a 100644 --- a/client/request.go +++ b/client/request.go @@ -12,7 +12,7 @@ import ( "time" "github.com/gofiber/fiber/v3" - "github.com/gofiber/utils" + "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) diff --git a/client/request_test.go b/client/request_test.go index 2dadb7f498..c90dafd68c 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -972,6 +972,8 @@ func Test_Request_Body_With_Server(t *testing.T) { SetFileReader(io.NopCloser(strings.NewReader("world"))), )) + require.Equal(t, req.Boundary(), "myBoundary") + resp, err := req.Post("http://exmaple.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) diff --git a/client/response.go b/client/response.go index 7c9f40f0d9..e9fb40826c 100644 --- a/client/response.go +++ b/client/response.go @@ -8,7 +8,7 @@ import ( "strings" "sync" - "github.com/gofiber/utils" + "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 41b88d91b0..5bdef6ae98 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -441,8 +441,10 @@ func Test_Proxy_Forward_Global_Client(t *testing.T) { })) }() - code, body, errs := fiber.Get("http://" + addr).String() - require.Equal(t, 0, len(errs)) + resp, err := fiberClient.Get("http://" + addr) + body := resp.String() + code := resp.StatusCode() + require.NoError(t, err) require.Equal(t, fiber.StatusOK, code) require.Equal(t, "test_global_client", body) } @@ -471,8 +473,10 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) { })) }() - code, body, errs := fiber.Get("http://" + addr).String() - require.Equal(t, 0, len(errs)) + resp, err := fiberClient.Get("http://" + addr) + body := resp.String() + code := resp.StatusCode() + require.NoError(t, err) require.Equal(t, fiber.StatusOK, code) require.Equal(t, "test_local_client", body) } From 4d6d79e1490c44d7a8541e7bfc42b9fc5f52b2c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Muhammed=20Efe=20=C3=87etin?= Date: Sat, 12 Nov 2022 11:09:02 +0300 Subject: [PATCH 065/118] :sparkles: add redirection --- client/client.go | 7 ++++++- client/core.go | 11 +++++++++- client/request.go | 17 ++++++++++++++- client/request_test.go | 47 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 79 insertions(+), 3 deletions(-) diff --git a/client/client.go b/client/client.go index b64142a7f5..ac14355bfd 100644 --- a/client/client.go +++ b/client/client.go @@ -514,7 +514,8 @@ type Config struct { Cookie map[string]string PathParam map[string]string - Timeout time.Duration + Timeout time.Duration + MaxRedirects int Body any FormData map[string]string @@ -562,6 +563,10 @@ func setConfigToRequest(req *Request, config ...Config) { req.SetTimeout(cfg.Timeout) } + if cfg.MaxRedirects != 0 { + req.SetMaxRedirects(cfg.MaxRedirects) + } + if cfg.Dial != nil { req.SetDial(cfg.Dial) } diff --git a/client/core.go b/client/core.go index 1231c03c34..7ed3dd9b25 100644 --- a/client/core.go +++ b/client/core.go @@ -10,6 +10,7 @@ import ( "sync" "sync/atomic" + "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/addon/retry" "github.com/valyala/fasthttp" ) @@ -95,10 +96,18 @@ func (c *core) execFunc() (*Response, error) { respv := fasthttp.AcquireResponse() if cfg != nil { err = retry.NewExponentialBackoff(*cfg).Retry(func() error { + if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { + return c.host.DoRedirects(reqv, respv, c.req.maxRedirects) + } + return c.host.Do(reqv, respv) }) } else { - err = c.host.Do(reqv, respv) + if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { + err = c.host.DoRedirects(reqv, respv, c.req.maxRedirects) + } else { + err = c.host.Do(reqv, respv) + } } defer func() { fasthttp.ReleaseRequest(reqv) diff --git a/client/request.go b/client/request.go index 786cc2231a..9e5b31dc18 100644 --- a/client/request.go +++ b/client/request.go @@ -48,7 +48,8 @@ type Request struct { cookies *Cookie path *PathParam - timeout time.Duration + timeout time.Duration + maxRedirects int client *Client @@ -475,6 +476,18 @@ func (r *Request) SetTimeout(t time.Duration) *Request { return r } +// MaxRedirects returns the max redirects count in request. +func (r *Request) MaxRedirects() int { + return r.maxRedirects +} + +// SetMaxRedirects method sets the maximum number of redirects at one go in the request instance. +// It will override max redirect which set in client instance. +func (r *Request) SetMaxRedirects(count int) *Request { + r.maxRedirects = count + return r +} + // checkClient method checks whether the client has been set in request. func (r *Request) checkClient() { if r.client == nil { @@ -532,6 +545,8 @@ func (r *Request) Reset() { r.referer = "" r.ctx = nil r.body = nil + r.timeout = 0 + r.maxRedirects = 0 r.bodyType = noBody r.boundary = boundary diff --git a/client/request_test.go b/client/request_test.go index c90dafd68c..b896ccb193 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -6,6 +6,7 @@ import ( "errors" "io" "mime/multipart" + "net" "os" "path/filepath" "regexp" @@ -16,6 +17,7 @@ import ( "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttputil" ) func Test_Request_Method(t *testing.T) { @@ -1140,6 +1142,51 @@ func Test_Request_Timeout_With_Server(t *testing.T) { require.Equal(t, ErrTimeoutOrCancel, err) } +func Test_Request_MaxRedirects(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := fiber.New() + + app.Get("/", func(c fiber.Ctx) error { + if c.Request().URI().QueryArgs().Has("foo") { + return c.Redirect().To("/foo") + } + return c.Redirect().To("/") + }) + app.Get("/foo", func(c fiber.Ctx) error { + return c.SendString("redirect") + }) + + go func() { require.Equal(t, nil, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) }() + + t.Run("success", func(t *testing.T) { + resp, err := AcquireRequest(). + SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }). + SetMaxRedirects(1). + Get("http://example.com?foo") + body := resp.String() + code := resp.StatusCode() + + require.Equal(t, 200, code) + require.Equal(t, "redirect", body) + require.NoError(t, err) + + resp.Close() + }) + + t.Run("error", func(t *testing.T) { + resp, err := AcquireRequest(). + SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }). + SetMaxRedirects(1). + Get("http://example.com") + + require.Nil(t, resp) + require.Equal(t, "too many redirects detected when doing the request", err.Error()) + }) +} + // // readErrorConn is a struct for testing retryIf // type readErrorConn struct { // net.Conn From 9fa8c943702508c8ed89a633d9b895065c09c837 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Sat, 12 Nov 2022 19:01:04 +0800 Subject: [PATCH 066/118] =?UTF-8?q?=F0=9F=94=A5=20chore:=20delete=20deps?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 1 - go.sum | 2 -- 2 files changed, 3 deletions(-) diff --git a/go.mod b/go.mod index 020c45263a..56e11d3baa 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/gofiber/fiber/v3 go 1.19 require ( - github.com/gofiber/utils v1.0.1 github.com/gofiber/utils/v2 v2.0.0-beta.1 github.com/google/uuid v1.3.0 github.com/mattn/go-colorable v0.1.13 diff --git a/go.sum b/go.sum index 840582417f..7412a8ab1b 100644 --- a/go.sum +++ b/go.sum @@ -3,8 +3,6 @@ github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHG github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gofiber/utils v1.0.1 h1:knct4cXwBipWQqFrOy1Pv6UcgPM+EXo9jDgc66V1Qio= -github.com/gofiber/utils v1.0.1/go.mod h1:pacRFtghAE3UoknMOUiXh2Io/nLWSUHtQCi/3QASsOc= github.com/gofiber/utils/v2 v2.0.0-beta.1 h1:ACfPdqeclx+BFIja19UjkKx7k3r5tmpILpNgzrfPLKs= github.com/gofiber/utils/v2 v2.0.0-beta.1/go.mod h1:CG89nDoIkEFIJaw5LdLO9AmBM11odse/LC79KQujm74= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= From 4a73d3b9351adb2887273ee7f00fd7401334738f Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Sat, 12 Nov 2022 19:13:27 +0800 Subject: [PATCH 067/118] perf: fix spell error --- client/client_test.go | 4 ++-- client/core.go | 2 +- client/hooks.go | 2 +- client/hooks_test.go | 2 +- client/jar.go | 3 +++ client/request_test.go | 4 ++-- 6 files changed, 10 insertions(+), 7 deletions(-) create mode 100644 client/jar.go diff --git a/client/client_test.go b/client/client_test.go index 23c94a4cfc..9cdf6348ae 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -118,7 +118,7 @@ func Test_Client_Invalid_URL(t *testing.T) { SetDial(dial). Get("http://example.com\r\n\r\nGET /\r\n\r\n") - require.ErrorIs(t, err, ErrURLForamt) + require.ErrorIs(t, err, ErrURLFormat) } func Test_Client_Unsupported_Protocol(t *testing.T) { @@ -128,7 +128,7 @@ func Test_Client_Unsupported_Protocol(t *testing.T) { R(). Get("ftp://example.com") - require.ErrorIs(t, err, ErrURLForamt) + require.ErrorIs(t, err, ErrURLFormat) } func Test_Get(t *testing.T) { diff --git a/client/core.go b/client/core.go index 7ed3dd9b25..2364799e0b 100644 --- a/client/core.go +++ b/client/core.go @@ -306,7 +306,7 @@ func newCore() (c *core) { var ( ErrTimeoutOrCancel = errors.New("timeout or cancel") - ErrURLForamt = errors.New("the url is a mistake") + ErrURLFormat = errors.New("the url is a mistake") ErrNotSupportSchema = errors.New("the protocol is not support, only http or https") ErrFileNoName = errors.New("the file should have name") ErrBodyType = errors.New("the body type should be []byte") diff --git a/client/hooks.go b/client/hooks.go index 2b495471cd..2185de9d6b 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -67,7 +67,7 @@ func parserRequestURL(c *Client, req *Request) error { if !protocolCheck.MatchString(uri) { uri = c.baseUrl + uri if !protocolCheck.MatchString(uri) { - return ErrURLForamt + return ErrURLFormat } } diff --git a/client/hooks_test.go b/client/hooks_test.go index 2d1284c843..b2e1ab9ca8 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -73,7 +73,7 @@ func Test_Parser_Request_URL(t *testing.T) { req := AcquireRequest().SetURL("/v1") err := parserRequestURL(client, req) - require.Equal(t, ErrURLForamt, err) + require.Equal(t, ErrURLFormat, err) }) t.Run("the path param from client", func(t *testing.T) { diff --git a/client/jar.go b/client/jar.go new file mode 100644 index 0000000000..e8d08339cb --- /dev/null +++ b/client/jar.go @@ -0,0 +1,3 @@ +package client + + diff --git a/client/request_test.go b/client/request_test.go index b896ccb193..1598033b9c 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -558,7 +558,7 @@ func Test_Request_Invalid_URL(t *testing.T) { resp, err := AcquireRequest(). Get("http://example.com\r\n\r\nGET /\r\n\r\n") - require.Equal(t, ErrURLForamt, err) + require.Equal(t, ErrURLFormat, err) require.Equal(t, (*Response)(nil), resp) } @@ -567,7 +567,7 @@ func Test_Request_Unsupport_Protocol(t *testing.T) { resp, err := AcquireRequest(). Get("ftp://example.com") - require.Equal(t, ErrURLForamt, err) + require.Equal(t, ErrURLFormat, err) require.Equal(t, (*Response)(nil), resp) } From d84da4131745f4fa8ccc0cd81edd0bddf2635ef2 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Sat, 12 Nov 2022 19:31:30 +0800 Subject: [PATCH 068/118] =?UTF-8?q?=F0=9F=8E=A8=20perf:=20spell=20error?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 30 +++++++++++++++--------------- client/core.go | 8 ++++---- client/hooks.go | 6 +++--- client/request.go | 18 +++++++++--------- 4 files changed, 31 insertions(+), 31 deletions(-) diff --git a/client/client.go b/client/client.go index ac14355bfd..aacaa123fa 100644 --- a/client/client.go +++ b/client/client.go @@ -46,8 +46,8 @@ type Client struct { // user defined response hooks userResponseHooks []ResponseHook - // client package defined respose hooks - buildinResposeHooks []ResponseHook + // client package defined response hooks + buildinResponseHooks []ResponseHook jsonMarshal utils.JSONMarshal jsonUnmarshal utils.JSONUnmarshal @@ -80,7 +80,7 @@ func (c *Client) AddRequestHook(h ...RequestHook) *Client { return c } -// ResponseHook return user-define reponse hooks. +// ResponseHook return user-define response hooks. func (c *Client) ResponseHook() []ResponseHook { return c.userResponseHooks } @@ -319,7 +319,7 @@ func (c *Client) SetParamsWithStruct(v any) *Client { return c } -// DelParams method deletes single or multiple params field and its valus in client. +// DelParams method deletes single or multiple params field and its values in client. func (c *Client) DelParams(key ...string) *Client { for _, v := range key { c.params.Del(v) @@ -377,7 +377,7 @@ func (c *Client) SetPathParamsWithStruct(v any) *Client { return c } -// DelPathParams method deletes single or multiple path params field and its valus in client. +// DelPathParams method deletes single or multiple path params field and its values in client. func (c *Client) DelPathParams(key ...string) *Client { c.path.DelParams(key...) return c @@ -416,7 +416,7 @@ func (c *Client) SetCookiesWithStruct(v any) *Client { return c } -// DelCookies method deletes single or multiple cookies field and its valus in client. +// DelCookies method deletes single or multiple cookies field and its values in client. func (c *Client) DelCookies(key ...string) *Client { c.cookies.DelCookies(key...) return c @@ -603,14 +603,14 @@ var ( cookies: &Cookie{}, path: &PathParam{}, - userRequestHooks: []RequestHook{}, - buildinRequestHooks: []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody}, - userResponseHooks: []ResponseHook{}, - buildinResposeHooks: []ResponseHook{parserResponseCookie}, - jsonMarshal: json.Marshal, - jsonUnmarshal: json.Unmarshal, - xmlMarshal: xml.Marshal, - xmlUnmarshal: xml.Unmarshal, + userRequestHooks: []RequestHook{}, + buildinRequestHooks: []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody}, + userResponseHooks: []ResponseHook{}, + buildinResponseHooks: []ResponseHook{parserResponseCookie}, + jsonMarshal: json.Marshal, + jsonUnmarshal: json.Unmarshal, + xmlMarshal: xml.Marshal, + xmlUnmarshal: xml.Unmarshal, } }, } @@ -641,7 +641,7 @@ func C() *Client { return defaultClient } -// Replce the defaultClient, the returned function can undo. +// Replace the defaultClient, the returned function can undo. func Replace(c *Client) func() { replaceMu.Lock() defer replaceMu.Unlock() diff --git a/client/core.go b/client/core.go index 2364799e0b..666a4601ed 100644 --- a/client/core.go +++ b/client/core.go @@ -27,10 +27,10 @@ var ( // Called before a request is sent. type RequestHook func(*Client, *Request) error -// ResponseHook is a function that receives Agent, Respose and Request, -// it can change the data is Respose or deal with some effects. +// ResponseHook is a function that receives Agent, Response and Request, +// it can change the data is Response or deal with some effects. // -// Called after a respose has been received. +// Called after a response has been received. type ResponseHook func(*Client, *Response, *Request) error // RetryConfig is an alias for config in the `addon/retry` package. @@ -165,7 +165,7 @@ func (c *core) preHooks() error { func (c *core) afterHooks(resp *Response) error { c.client.mu.Lock() defer c.client.mu.Unlock() - for _, f := range c.client.buildinResposeHooks { + for _, f := range c.client.buildinResponseHooks { err := f(c.client, resp, c.req) if err != nil { return err diff --git a/client/hooks.go b/client/hooks.go index 2185de9d6b..e996d4737f 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -58,7 +58,7 @@ func randString(n int) string { // Query params and path params deal in this function. func parserRequestURL(c *Client, req *Request) error { splitUrl := strings.Split(req.url, "?") - // I don't want to judege splitUrl length. + // I don't want to judge splitUrl length. splitUrl = append(splitUrl, "") // Determine whether to superimpose baseurl based on @@ -79,7 +79,7 @@ func parserRequestURL(c *Client, req *Request) error { uri = strings.Replace(uri, ":"+key, val, -1) }) - // set uri to request and orther related setting + // set uri to request and other related setting req.RawRequest.SetRequestURI(uri) // merge query params @@ -233,7 +233,7 @@ func parserRequestBody(c *Client, req *Request) error { } } - // wirte file + // write file w, err := mw.CreateFormFile(v.fieldName, v.name) if err != nil { return err diff --git a/client/request.go b/client/request.go index 9e5b31dc18..c2f504148e 100644 --- a/client/request.go +++ b/client/request.go @@ -227,7 +227,7 @@ func (r *Request) SetUserAgent(ua string) *Request { return r } -// Boundary returns bounday in multipart boundary. +// Boundary returns boundary in multipart boundary. func (r *Request) Boundary() string { return r.boundary } @@ -520,7 +520,7 @@ func (r *Request) Delete(url string) (*Response, error) { return r.SetURL(url).SetMethod(fiber.MethodDelete).Send() } -// Send Options reuqest. +// Send Options request. func (r *Request) Options(url string) (*Response, error) { return r.SetURL(url).SetMethod(fiber.MethodOptions).Send() } @@ -570,7 +570,7 @@ type Header struct { *fasthttp.RequestHeader } -// Peekmutiple methods returns multiple field in header with same key. +// Peekmultiple methods returns multiple field in header with same key. func (h *Header) PeekMultiple(key string) []string { res := []string{} byteKey := []byte(key) @@ -642,7 +642,7 @@ func (c Cookie) Del(key string) { delete(c, key) } -// SetCookie method sets a signle val in Cookie. +// SetCookie method sets a single val in Cookie. func (c Cookie) SetCookie(key, val string) { c[key] = val } @@ -659,7 +659,7 @@ func (c Cookie) SetCookiesWithStruct(v any) { SetValWithStruct(c, "cookie", v) } -// DelCookies method deletes mutiple val in Cookie. +// DelCookies method deletes multiple val in Cookie. func (c Cookie) DelCookies(key ...string) { for _, v := range key { c.Del(v) @@ -693,7 +693,7 @@ func (p PathParam) Del(key string) { delete(p, key) } -// SetParam method sets a signle val in PathParam. +// SetParam method sets a single val in PathParam. func (p PathParam) SetParam(key, val string) { p[key] = val } @@ -710,7 +710,7 @@ func (p PathParam) SetParamsWithStruct(v any) { SetValWithStruct(p, "path", v) } -// DelParams method deletes mutiple val in PathParams. +// DelParams method deletes multiple val in PathParams. func (p PathParam) DelParams(key ...string) { for _, v := range key { p.Del(v) @@ -763,7 +763,7 @@ func (f *FormData) SetDatas(m map[string]string) { } } -// SetDatasWithStruct method supports set mutiple fields via a struct. +// SetDatasWithStruct method supports set multiple fields via a struct. func (f *FormData) SetDatasWithStruct(v any) { SetValWithStruct(f, "form", v) } @@ -803,7 +803,7 @@ func (f *File) SetPath(p string) { f.path = p } -// SetReader method can reveive a io.ReadCloser +// SetReader method can receive a io.ReadCloser // which will be closed in parserBody hook. func (f *File) SetReader(r io.ReadCloser) { f.reader = r From 679690b775c52736fcd127fe1a086f59529ca557 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Sat, 12 Nov 2022 22:29:46 +0800 Subject: [PATCH 069/118] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20logger?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 39 +++++++++++++++++++++++++++++++++++++-- client/core.go | 13 +++++++------ client/hooks.go | 9 +++++++++ 3 files changed, 53 insertions(+), 8 deletions(-) diff --git a/client/client.go b/client/client.go index aacaa123fa..1ba6856fa5 100644 --- a/client/client.go +++ b/client/client.go @@ -18,6 +18,14 @@ import ( "github.com/valyala/fasthttp" ) +type Logger interface { + Printf(format string, args ...any) +} + +type disableLogger struct{} + +func (*disableLogger) Printf(format string, args ...any) {} + // The Client is used to create a Fiber Client with // client-level settings that apply to all requests // raise from the client. @@ -25,7 +33,7 @@ import ( // Fiber Client also provides an option to override // or merge most of the client settings at the request. type Client struct { - mu sync.Mutex + mu sync.RWMutex baseUrl string userAgent string @@ -35,6 +43,8 @@ type Client struct { cookies *Cookie path *PathParam + logger Logger + timeout time.Duration // user defined request hooks @@ -76,6 +86,9 @@ func (c *Client) RequestHook() []RequestHook { // Add user-defined request hooks. func (c *Client) AddRequestHook(h ...RequestHook) *Client { + c.mu.Lock() + defer c.mu.Unlock() + c.userRequestHooks = append(c.userRequestHooks, h...) return c } @@ -87,6 +100,9 @@ func (c *Client) ResponseHook() []ResponseHook { // Add user-defined response hooks. func (c *Client) AddResponseHook(h ...ResponseHook) *Client { + c.mu.Lock() + defer c.mu.Unlock() + c.userResponseHooks = append(c.userResponseHooks, h...) return c } @@ -211,6 +227,9 @@ func (c *Client) RetryConfig() *RetryConfig { // SetRetryConfig sets retry config in client which is impl by addon/retry package. func (c *Client) SetRetryConfig(config *RetryConfig) *Client { + c.mu.Lock() + defer c.mu.Unlock() + c.retryConfig = config return c } @@ -430,6 +449,21 @@ func (c *Client) SetTimeout(t time.Duration) *Client { return c } +func (c *Client) Logger() Logger { + if c.logger == nil { + return &disableLogger{} + } + + return c.logger +} + +// SetLogger set logger field in client. +// The logger would output relate info with request. +func (c *Client) SetLogger(logger Logger) *Client { + c.logger = logger + return c +} + // Get provide a API like axios which send get request. func (c *Client) Get(url string, cfg ...Config) (*Response, error) { req := AcquireRequest().SetClient(c) @@ -602,11 +636,12 @@ var ( }, cookies: &Cookie{}, path: &PathParam{}, + logger: &disableLogger{}, userRequestHooks: []RequestHook{}, buildinRequestHooks: []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody}, userResponseHooks: []ResponseHook{}, - buildinResponseHooks: []ResponseHook{parserResponseCookie}, + buildinResponseHooks: []ResponseHook{parserResponseCookie, logger}, jsonMarshal: json.Marshal, jsonUnmarshal: json.Unmarshal, xmlMarshal: xml.Marshal, diff --git a/client/core.go b/client/core.go index 666a4601ed..a255a6ca79 100644 --- a/client/core.go +++ b/client/core.go @@ -60,8 +60,8 @@ type core struct { } func (c *core) getRetryConfig() *RetryConfig { - c.client.mu.Lock() - defer c.client.mu.Unlock() + c.client.mu.RLock() + defer c.client.mu.RUnlock() cfg := c.client.RetryConfig() if cfg == nil { @@ -141,8 +141,8 @@ func (c *core) execFunc() (*Response, error) { // Exec request hook func (c *core) preHooks() error { - c.client.mu.Lock() - defer c.client.mu.Unlock() + c.client.mu.RLock() + defer c.client.mu.RUnlock() for _, f := range c.client.userRequestHooks { err := f(c.client, c.req) @@ -163,8 +163,9 @@ func (c *core) preHooks() error { // Exec response hooks func (c *core) afterHooks(resp *Response) error { - c.client.mu.Lock() - defer c.client.mu.Unlock() + c.client.mu.RLock() + defer c.client.mu.RUnlock() + for _, f := range c.client.buildinResponseHooks { err := f(c.client, resp, c.req) if err != nil { diff --git a/client/hooks.go b/client/hooks.go index e996d4737f..f6c427568a 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -279,3 +279,12 @@ func parserResponseCookie(c *Client, resp *Response, req *Request) (err error) { return } + +func logger(c *Client, resp *Response, req *Request) (err error) { + logger := c.Logger() + + logger.Printf("%s\n", req.RawRequest.String()) + logger.Printf("%s\n", resp.RawResponse.String()) + + return +} From bddf796cfbdae97c8ceea07c3b36a6d90cc54771 Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Sun, 13 Nov 2022 11:38:15 +0800 Subject: [PATCH 070/118] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20cookie=20jar?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 14 +++ client/hooks.go | 13 ++ client/jar.go | 303 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 330 insertions(+) diff --git a/client/client.go b/client/client.go index 1ba6856fa5..88f60072b7 100644 --- a/client/client.go +++ b/client/client.go @@ -64,6 +64,8 @@ type Client struct { xmlMarshal utils.XMLMarshal xmlUnmarshal utils.XMLUnmarshal + jar *jar + // tls config tlsConfig *tls.Config @@ -464,6 +466,18 @@ func (c *Client) SetLogger(logger Logger) *Client { return c } +// enable cookie jar +func (c *Client) EnableJar() *Client { + c.jar = newJar() + return c +} + +// disable cookie jar +func (c *Client) DisableJar() *Client { + c.jar = nil + return c +} + // Get provide a API like axios which send get request. func (c *Client) Get(url string, cfg ...Config) (*Response, error) { req := AcquireRequest().SetClient(c) diff --git a/client/hooks.go b/client/hooks.go index f6c427568a..8747f70a9b 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -154,6 +154,14 @@ func parserRequestHeader(c *Client, req *Request) error { } // set cookie + // add cookie form jar to req + if c.jar != nil { + cookies := c.jar.Cookies(req.RawRequest.URI()) + for _, c := range cookies { + req.RawRequest.Header.SetCookieBytesKV(c.Key, c.Value) + } + } + c.cookies.VisitAll(func(key, val string) { req.RawRequest.Header.SetCookie(key, val) }) @@ -277,6 +285,11 @@ func parserResponseCookie(c *Client, resp *Response, req *Request) (err error) { resp.cookie = append(resp.cookie, cookie) }) + // store cookies to jar + if c.jar != nil { + c.jar.SetCookies(req.RawRequest.URI(), resp.cookie) + } + return } diff --git a/client/jar.go b/client/jar.go index e8d08339cb..648fce0a25 100644 --- a/client/jar.go +++ b/client/jar.go @@ -1,3 +1,306 @@ package client +import ( + "fmt" + "sort" + "strings" + "sync" + "time" + "github.com/gofiber/utils/v2" + "github.com/valyala/fasthttp" +) + +var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC) + +type jar struct { + mu sync.Mutex + + entries map[string]map[string]*entry + + nextSeqNum uint64 +} + +type entry struct { + Key []byte + Value []byte + Domain []byte + Path []byte + SameSite fasthttp.CookieSameSite + Secure bool + HttpOnly bool + Persistent bool + Expires time.Time + Creation time.Time + LastAccess time.Time + + seqNum uint64 +} + +func (e *entry) id() string { + return fmt.Sprintf("%s:%s:%s", utils.UnsafeString(e.Domain), utils.UnsafeString(e.Path), utils.UnsafeString(e.Key)) +} +func (e *entry) shouldSend(https bool, host, path []byte) bool { + return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure) +} + +func (e *entry) domainMatch(host []byte) bool { + if utils.EqualFold(e.Domain, host) { + return true + } + + return hasDotSuffix(host, e.Domain) +} + +func (e *entry) pathMatch(path []byte) bool { + if utils.EqualFold(e.Path, path) { + return true + } + if strings.HasPrefix(utils.UnsafeString(path), utils.UnsafeString(e.Path)) { + if e.Path[len(e.Path)-1] == '/' { + return true // The "/any/" matches "/any/path" case. + } else if path[len(e.Path)] == '/' { + return true // The "/any" matches "/any/path" case. + } + } + return false +} + +func (e *entry) reset() { + e.Key = []byte{} + e.Value = []byte{} + e.Domain = []byte{} + e.Path = []byte{} + e.SameSite = fasthttp.CookieSameSiteDefaultMode + e.Secure = true + e.HttpOnly = true + e.Persistent = false + + now := time.Now() + e.Expires = now + e.Creation = now + e.LastAccess = now + + e.seqNum = 0 +} + +func (j *jar) Cookies(u *fasthttp.URI) []*entry { + return j.cookies(u, time.Now()) +} + +func (j *jar) cookies(u *fasthttp.URI, now time.Time) (cookies []*entry) { + if !utils.EqualFold(u.Scheme(), []byte("http")) && !utils.EqualFold(u.Scheme(), []byte("https")) { + return + } + + host := u.Host() + key := jarKey(host) + + j.mu.Lock() + defer j.mu.Unlock() + + subMap := j.entries[key] + if subMap == nil { + return + } + + https := utils.EqualFold(u.Scheme(), []byte("https")) + path := u.Path() + if len(path) == 0 { + path = []byte("/") + } + + modified := false + for id, e := range subMap { + if e.Persistent && !e.Expires.After(now) { + ee := subMap[id] + delete(subMap, id) + releaseEntry(ee) + modified = true + continue + } + + if !e.shouldSend(https, host, path) { + continue + } + + e.LastAccess = now + subMap[id] = e + cookies = append(cookies, e) + modified = true + } + + if modified { + if len(subMap) == 0 { + delete(j.entries, key) + } else { + j.entries[key] = subMap + } + } + + sort.Slice(cookies, func(i, j int) bool { + s := cookies + if len(s[i].Path) != len(s[j].Path) { + return len(s[i].Path) > len(s[j].Path) + } + + if s[i].Creation != s[j].Creation { + return s[i].Creation.Before(s[j].Creation) + } + + return s[i].seqNum < s[j].seqNum + }) + + return +} + +func (j *jar) SetCookies(u *fasthttp.URI, cookies []*fasthttp.Cookie) { + j.setCookies(u, cookies, time.Now()) +} + +func (j *jar) setCookies(u *fasthttp.URI, cookies []*fasthttp.Cookie, now time.Time) { + if len(cookies) == 0 { + return + } + + if !utils.EqualFold(u.Scheme(), []byte("http")) && !utils.EqualFold(u.Scheme(), []byte("https")) { + return + } + + host := u.Host() + path := u.Path() + key := jarKey(host) + + j.mu.Lock() + defer j.mu.Unlock() + + subMap := j.entries[key] + + modified := false + for _, cookie := range cookies { + e, remove := newEntry(cookie, now, path) + id := e.id() + + if remove { + if subMap != nil { + if _, ok := subMap[id]; ok { + ee := subMap[id] + delete(subMap, id) + releaseEntry(ee) + modified = true + } + + continue + } + } + + if subMap == nil { + subMap = make(map[string]*entry) + } + + if old, ok := subMap[id]; ok { + e.Creation = old.Creation + e.seqNum = old.seqNum + } else { + e.Creation = now + e.seqNum = j.nextSeqNum + j.nextSeqNum++ + } + + e.LastAccess = now + subMap[id] = e + modified = true + } + + if modified { + if len(subMap) == 0 { + delete(j.entries, key) + } else { + j.entries[key] = subMap + } + } +} + +func jarKey(h []byte) string { + host := utils.UnsafeString(h) + if utils.IsIPv4(host) || utils.IsIPv6(host) { + return host + } + + i := strings.LastIndex(host, ".") + if i <= 0 { + return host + } + + prevDot := strings.LastIndex(host[:i-1], ".") + return host[prevDot+1:] +} + +func hasDotSuffix(s, suffix []byte) bool { + return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && utils.EqualFold(s[len(s)-len(suffix):], suffix) +} + +func newEntry(c *fasthttp.Cookie, now time.Time, path []byte) (*entry, bool) { + e := acquireEntry() + + e.Key = utils.CopyBytes(c.Key()) + + if len(c.Path()) != 0 || c.Path()[0] != '/' { + e.Path = utils.CopyBytes(path) + } else { + e.Path = utils.CopyBytes(c.Path()) + } + + e.Domain = utils.CopyBytes(c.Domain()) + + if c.MaxAge() < 0 { + return e, true + } else if c.MaxAge() > 0 { + e.Expires = now.Add(time.Duration(c.MaxAge()) * time.Second) + e.Persistent = true + } else { + if c.Expire().IsZero() { + e.Expires = endOfTime + e.Persistent = false + } else { + if !c.Expire().After(now) { + return e, true + } + + e.Expires = c.Expire() + e.Persistent = true + } + } + + e.Value = utils.CopyBytes(c.Value()) + e.Secure = c.Secure() + e.HttpOnly = c.HTTPOnly() + + e.SameSite = c.SameSite() + + return e, false +} + +func newJar() *jar { + return &jar{ + mu: sync.Mutex{}, + entries: map[string]map[string]*entry{}, + nextSeqNum: 0, + } +} + +var entryPool = &sync.Pool{ + New: func() any { + return &entry{} + }, +} + +func acquireEntry() *entry { + e := entryPool.Get().(*entry) + return e +} + +func releaseEntry(e *entry) { + e.reset() + entryPool.Put(e) +} From fc9fdb59d53afc337dc9e29c4685c6ee3271e94d Mon Sep 17 00:00:00 2001 From: Jinquan Wang Date: Mon, 14 Nov 2022 11:52:39 +0800 Subject: [PATCH 071/118] =?UTF-8?q?=E2=9C=A8=20feat:=20logger=20with=20lev?= =?UTF-8?q?el?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 39 +++++++++++++++++++++++++++++++++++++-- client/hooks.go | 8 ++++++-- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/client/client.go b/client/client.go index 88f60072b7..aaf40be349 100644 --- a/client/client.go +++ b/client/client.go @@ -18,13 +18,35 @@ import ( "github.com/valyala/fasthttp" ) +// Define the logger interface so that users can +// use different log implements to output logs. type Logger interface { - Printf(format string, args ...any) + // The log with error level + Errorf(format string, v ...any) + + // The log with warn level + Warnf(format string, v ...any) + + // The log with info level + Infof(format string, v ...any) + + // The log with debug level + Debugf(format string, v ...any) } +var _ (Logger) = (*disableLogger)(nil) + +// Implement a Logger interface. +// All logs are turned off by default. type disableLogger struct{} -func (*disableLogger) Printf(format string, args ...any) {} +func (*disableLogger) Errorf(format string, args ...any) {} + +func (*disableLogger) Warnf(format string, args ...any) {} + +func (*disableLogger) Infof(format string, args ...any) {} + +func (*disableLogger) Debugf(format string, args ...any) {} // The Client is used to create a Fiber Client with // client-level settings that apply to all requests @@ -43,6 +65,7 @@ type Client struct { cookies *Cookie path *PathParam + debug bool logger Logger timeout time.Duration @@ -466,6 +489,18 @@ func (c *Client) SetLogger(logger Logger) *Client { return c } +// Debug enable log debug level output. +func (c *Client) Debug() *Client { + c.debug = true + return c +} + +// DisableDebug disenable log debug level output. +func (c *Client) DisableDebug() *Client { + c.debug = false + return c +} + // enable cookie jar func (c *Client) EnableJar() *Client { c.jar = newJar() diff --git a/client/hooks.go b/client/hooks.go index 8747f70a9b..c2c5713d58 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -294,10 +294,14 @@ func parserResponseCookie(c *Client, resp *Response, req *Request) (err error) { } func logger(c *Client, resp *Response, req *Request) (err error) { + if !c.debug { + return + } + logger := c.Logger() - logger.Printf("%s\n", req.RawRequest.String()) - logger.Printf("%s\n", resp.RawResponse.String()) + logger.Debugf("%s\n", req.RawRequest.String()) + logger.Debugf("%s\n", resp.RawResponse.String()) return } From b66fb0d075d43f4fbda1001ac1f3201384d3e611 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Sat, 24 Dec 2022 20:58:55 +0800 Subject: [PATCH 072/118] =?UTF-8?q?=F0=9F=8E=A8=20perf:=20change=20the=20f?= =?UTF-8?q?ield=20name?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client/client.go | 8 ++++---- client/core.go | 4 ++-- client/jar_test.go | 1 + client/request.go | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) create mode 100644 client/jar_test.go diff --git a/client/client.go b/client/client.go index aaf40be349..33265bc239 100644 --- a/client/client.go +++ b/client/client.go @@ -74,13 +74,13 @@ type Client struct { userRequestHooks []RequestHook // client package defined request hooks - buildinRequestHooks []RequestHook + builtinRequestHooks []RequestHook // user defined response hooks userResponseHooks []ResponseHook // client package defined response hooks - buildinResponseHooks []ResponseHook + builtinResponseHooks []ResponseHook jsonMarshal utils.JSONMarshal jsonUnmarshal utils.JSONUnmarshal @@ -688,9 +688,9 @@ var ( logger: &disableLogger{}, userRequestHooks: []RequestHook{}, - buildinRequestHooks: []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody}, + builtinRequestHooks: []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody}, userResponseHooks: []ResponseHook{}, - buildinResponseHooks: []ResponseHook{parserResponseCookie, logger}, + builtinResponseHooks: []ResponseHook{parserResponseCookie, logger}, jsonMarshal: json.Marshal, jsonUnmarshal: json.Unmarshal, xmlMarshal: xml.Marshal, diff --git a/client/core.go b/client/core.go index a255a6ca79..851e30bb7d 100644 --- a/client/core.go +++ b/client/core.go @@ -151,7 +151,7 @@ func (c *core) preHooks() error { } } - for _, f := range c.client.buildinRequestHooks { + for _, f := range c.client.builtinRequestHooks { err := f(c.client, c.req) if err != nil { return err @@ -166,7 +166,7 @@ func (c *core) afterHooks(resp *Response) error { c.client.mu.RLock() defer c.client.mu.RUnlock() - for _, f := range c.client.buildinResponseHooks { + for _, f := range c.client.builtinResponseHooks { err := f(c.client, resp, c.req) if err != nil { return err diff --git a/client/jar_test.go b/client/jar_test.go new file mode 100644 index 0000000000..e169c0b61a --- /dev/null +++ b/client/jar_test.go @@ -0,0 +1 @@ +package client_test diff --git a/client/request.go b/client/request.go index c2f504148e..c2d8cad74f 100644 --- a/client/request.go +++ b/client/request.go @@ -570,7 +570,7 @@ type Header struct { *fasthttp.RequestHeader } -// Peekmultiple methods returns multiple field in header with same key. +// PeekMultiple methods returns multiple field in header with same key. func (h *Header) PeekMultiple(key string) []string { res := []string{} byteKey := []byte(key) From ecf4a8f633f38e88ef87dad0c0fc6635881188e1 Mon Sep 17 00:00:00 2001 From: wangjq4214 Date: Sun, 1 Jan 2023 22:45:49 +0800 Subject: [PATCH 073/118] perf: add jar test --- client/jar_test.go | 207 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 206 insertions(+), 1 deletion(-) diff --git a/client/jar_test.go b/client/jar_test.go index e169c0b61a..cf8947192e 100644 --- a/client/jar_test.go +++ b/client/jar_test.go @@ -1 +1,206 @@ -package client_test +package client + +import ( + "fmt" + "strings" + "testing" + "time" + + "github.com/gofiber/utils/v2" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +// tNow is the synthetic current time used as now during testing. +var tNow = time.Date(2013, 1, 1, 12, 0, 0, 0, time.UTC) + +var hasDotSuffixTests = [...]struct { + s, suffix string +}{ + {"", ""}, + {"", "."}, + {"", "x"}, + {".", ""}, + {".", "."}, + {".", ".."}, + {".", "x"}, + {".", "x."}, + {".", ".x"}, + {".", ".x."}, + {"x", ""}, + {"x", "."}, + {"x", ".."}, + {"x", "x"}, + {"x", "x."}, + {"x", ".x"}, + {"x", ".x."}, + {".x", ""}, + {".x", "."}, + {".x", ".."}, + {".x", "x"}, + {".x", "x."}, + {".x", ".x"}, + {".x", ".x."}, + {"x.", ""}, + {"x.", "."}, + {"x.", ".."}, + {"x.", "x"}, + {"x.", "x."}, + {"x.", ".x"}, + {"x.", ".x."}, + {"com", ""}, + {"com", "m"}, + {"com", "om"}, + {"com", "com"}, + {"com", ".com"}, + {"com", "x.com"}, + {"com", "xcom"}, + {"com", "xorg"}, + {"com", "org"}, + {"com", "rg"}, + {"foo.com", ""}, + {"foo.com", "m"}, + {"foo.com", "om"}, + {"foo.com", "com"}, + {"foo.com", ".com"}, + {"foo.com", "o.com"}, + {"foo.com", "oo.com"}, + {"foo.com", "foo.com"}, + {"foo.com", ".foo.com"}, + {"foo.com", "x.foo.com"}, + {"foo.com", "xfoo.com"}, + {"foo.com", "xfoo.org"}, + {"foo.com", "foo.org"}, + {"foo.com", "oo.org"}, + {"foo.com", "o.org"}, + {"foo.com", ".org"}, + {"foo.com", "org"}, + {"foo.com", "rg"}, +} + +func TestHasDotSuffix(t *testing.T) { + for _, tc := range hasDotSuffixTests { + got := hasDotSuffix([]byte(tc.s), []byte(tc.suffix)) + + want := strings.HasSuffix(tc.s, "."+tc.suffix) + require.Equal(t, want, got) + } +} + +var jarKeyTests = map[string]string{ + "foo.www.example.com": "example.com", + "www.example.com": "example.com", + "example.com": "example.com", + "com": "com", + "foo.www.bbc.co.uk": "co.uk", + "www.bbc.co.uk": "co.uk", + "bbc.co.uk": "co.uk", + "co.uk": "co.uk", + "uk": "uk", + "192.168.0.5": "192.168.0.5", + // The following are actual outputs of canonicalHost for + // malformed inputs to canonicalHost. + "": "", + ".": ".", + "..": "..", + ".net": ".net", + "a.": "a.", + "b.a.": "a.", + "weird.stuff..": "stuff..", +} + +func TestJarKey(t *testing.T) { + for host, want := range jarKeyTests { + got := jarKey([]byte(host)) + + require.Equal(t, want, got) + } +} + +// expiresIn creates an expires attribute delta seconds from tNow. +func expiresIn(delta int) string { + t := tNow.Add(time.Duration(delta) * time.Second) + return "expires=" + t.Format(time.RFC1123) +} + +// mustParseURL parses s to an URL and panics on error. +func mustParseURL(s string) *fasthttp.URI { + u := fasthttp.AcquireURI() + err := u.Parse(nil, utils.UnsafeBytes(s)) + + if err != nil || utils.UnsafeString(u.Scheme()) == "" || utils.UnsafeString(u.Hash()) == "" { + panic(fmt.Sprintf("Unable to parse URL %s.", s)) + } + return u +} + +// jarTest encapsulates the following actions on a jar: +// 1. Perform SetCookies with fromURL and the cookies from setCookies. +// (Done at time tNow + 0 ms.) +// 2. Check that the entries in the jar matches content. +// (Done at time tNow + 1001 ms.) +// 3. For each query in tests: Check that Cookies with toURL yields the +// cookies in want. +// (Query n done at tNow + (n+2)*1001 ms.) +type jarTest struct { + description string // The description of what this test is supposed to test + fromURL string // The full URL of the request from which Set-Cookie headers where received + setCookies []string // All the cookies received from fromURL + content string // The whole (non-expired) content of the jar + queries []query // Queries to test the Jar.Cookies method +} + +// query contains one test of the cookies returned from Jar.Cookies. +type query struct { + toURL string // the URL in the Cookies call + want string // the expected list of cookies (order matters) +} + +// run runs the jarTest. +func (test jarTest) run(t *testing.T, jar *jar) { + // now := tNow + + // // Populate jar with cookies. + // setCookies := make([]*fasthttp.Cookie, len(test.setCookies)) + // for i, cs := range test.setCookies { + // resp := fasthttp.AcquireResponse() + // cookies := (&http.Response{Header: http.Header{"Set-Cookie": {cs}}}).Cookies() + // if len(cookies) != 1 { + // panic(fmt.Sprintf("Wrong cookie line %q: %#v", cs, cookies)) + // } + // setCookies[i] = cookies[0] + // } + // jar.setCookies(mustParseURL(test.fromURL), setCookies, now) + // now = now.Add(1001 * time.Millisecond) + + // // Serialize non-expired entries in the form "name1=val1 name2=val2". + // var cs []string + // for _, submap := range jar.entries { + // for _, cookie := range submap { + // if !cookie.Expires.After(now) { + // continue + // } + // cs = append(cs, cookie.Name+"="+cookie.Value) + // } + // } + // sort.Strings(cs) + // got := strings.Join(cs, " ") + + // // Make sure jar content matches our expectations. + // if got != test.content { + // t.Errorf("Test %q Content\ngot %q\nwant %q", + // test.description, got, test.content) + // } + + // // Test different calls to Cookies. + // for i, query := range test.queries { + // now = now.Add(1001 * time.Millisecond) + // var s []string + // for _, c := range jar.cookies(mustParseURL(query.toURL), now) { + // s = append(s, c.Name+"="+c.Value) + // } + // if got := strings.Join(s, " "); got != query.want { + // t.Errorf("Test %q #%d\ngot %q\nwant %q", test.description, i, got, query.want) + // } + // } +} From e4a79a6d52e6124eadb3208e3817d44f8ba1ecc7 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sun, 6 Aug 2023 16:25:50 +0300 Subject: [PATCH 074/118] fix proxy test --- middleware/proxy/proxy_test.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 2a7ad204bc..77a4d0e160 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -700,10 +700,12 @@ func Test_Proxy_Domain_Forward_Local(t *testing.T) { go func() { require.NoError(t, app.Listener(ln)) }() go func() { require.NoError(t, app1.Listener(ln1)) }() - code, body, errs := fiber.Get("http://" + localDomain + "/test?query_test=true").String() - require.Equal(t, 0, len(errs)) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "test_local_client:true", body) + resp, err := fiberClient.Get("http://" + localDomain + "/test?query_test=true") + defer resp.Close() + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "test_local_client:true", string(resp.Body())) } // go test -run Test_Proxy_Balancer_Forward_Local From 41ae1a54fe522abad47ccbc49b7b537cb4d975d6 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Mon, 7 Aug 2023 18:23:23 +0300 Subject: [PATCH 075/118] improve test coverage --- client/client.go | 20 +--- client/client_test.go | 226 ++++++++++++++++++++++++++++------------ client/hooks.go | 7 +- client/hooks_test.go | 88 ++++++++++++++++ client/jar.go | 2 +- client/jar_test.go | 25 +++++ client/request.go | 6 ++ client/request_test.go | 24 +++++ client/response.go | 6 +- client/response_test.go | 3 +- 10 files changed, 316 insertions(+), 91 deletions(-) diff --git a/client/client.go b/client/client.go index 33265bc239..cbc944b126 100644 --- a/client/client.go +++ b/client/client.go @@ -6,6 +6,8 @@ import ( "crypto/x509" "encoding/json" "encoding/xml" + "github.com/gofiber/fiber/v3/log" + "github.com/gofiber/utils/v2" "io" "net/url" "os" @@ -14,7 +16,6 @@ import ( "sync" "time" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -239,6 +240,8 @@ func (c *Client) SetRootCertificateFromString(pem string) *Client { func (c *Client) SetProxyURL(proxyURL string) *Client { pUrl, err := url.Parse(proxyURL) if err != nil { + log.Errorf("%v", err) + return c } c.proxyURL = pUrl.String() @@ -474,21 +477,6 @@ func (c *Client) SetTimeout(t time.Duration) *Client { return c } -func (c *Client) Logger() Logger { - if c.logger == nil { - return &disableLogger{} - } - - return c.logger -} - -// SetLogger set logger field in client. -// The logger would output relate info with request. -func (c *Client) SetLogger(logger Logger) *Client { - c.logger = logger - return c -} - // Debug enable log debug level output. func (c *Client) Debug() *Client { c.debug = true diff --git a/client/client_test.go b/client/client_test.go index 9cdf6348ae..3da79c2596 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,12 +1,18 @@ package client import ( + "bytes" "context" "crypto/tls" "fmt" + "github.com/gofiber/fiber/v3/addon/retry" + "github.com/gofiber/fiber/v3/log" + "io" "net" + "os" "reflect" "testing" + "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/tlstest" @@ -813,37 +819,38 @@ func Test_Client_PathParam_With_Server(t *testing.T) { require.Equal(t, "ok", resp.String()) } -// func Test_Client_Cert(t *testing.T) { -// t.Parallel() +func Test_Client_TLS(t *testing.T) { + t.Parallel() -// serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() -// require.Nil(t, err) + serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() + require.Nil(t, err) -// ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") -// require.Nil(t, err) + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + require.Nil(t, err) -// ln = tls.NewListener(ln, serverTLSConf) + ln = tls.NewListener(ln, serverTLSConf) -// app := fiber.New() -// app.Get("/", func(c fiber.Ctx) error { -// return c.SendString("tls") -// }) + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("tls") + }) -// go func() { -// require.Nil(t, nil, app.Listener(ln, fiber.ListenConfig{ -// DisableStartupMessage: true, -// })) -// }() + go func() { + require.Nil(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() -// client := AcquireClient().SetCertificates(clientTLSConf.Certificates...) -// resp, err := client.Get("https://" + ln.Addr().String()) + client := AcquireClient() + resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) -// require.Nil(t, err) -// require.Equal(t, fiber.StatusOK, resp.StatusCode()) -// require.Equal(t, "tls", resp.String()) -// } + require.Nil(t, err) + require.Equal(t, clientTLSConf, client.TLSConfig()) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "tls", resp.String()) +} -func Test_Client_TLS(t *testing.T) { +func Test_Client_TLS_Empty_TLSConfig(t *testing.T) { t.Parallel() serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() @@ -866,12 +873,42 @@ func Test_Client_TLS(t *testing.T) { }() client := AcquireClient() - resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) + resp, err := client.Get("https://" + ln.Addr().String()) + require.Error(t, err) + require.NotEqual(t, clientTLSConf, client.TLSConfig()) + require.Nil(t, resp) +} + +func Test_Client_SetCertificates(t *testing.T) { + t.Parallel() + + serverTLSConf, _, err := tlstest.GetTLSConfigs() require.Nil(t, err) - require.Equal(t, clientTLSConf, client.TLSConfig()) - require.Equal(t, fiber.StatusOK, resp.StatusCode()) - require.Equal(t, "tls", resp.String()) + + client := AcquireClient().SetCertificates(serverTLSConf.Certificates...) + require.Equal(t, 1, len(client.tlsConfig.Certificates)) +} + +func Test_Client_SetRootCertificate(t *testing.T) { + t.Parallel() + + client := AcquireClient().SetRootCertificate("../.github/testdata/ssl.pem") + require.NotNil(t, client.tlsConfig.RootCAs) +} + +func Test_Client_SetRootCertificateFromString(t *testing.T) { + t.Parallel() + + file, err := os.Open("../.github/testdata/ssl.pem") + defer func() { _ = file.Close() }() + require.NoError(t, err) + + pem, err := io.ReadAll(file) + require.NoError(t, err) + + client := AcquireClient().SetRootCertificateFromString(string(pem)) + require.NotNil(t, client.tlsConfig.RootCAs) } func Test_Client_R(t *testing.T) { @@ -968,70 +1005,125 @@ func Test_Set_Config_To_Request(t *testing.T) { require.Equal(t, "v1", req.Param("k1")[0]) }) - // t.Run("set ctx", func(t *testing.T) { - // key := struct{}{} + t.Run("set cookies", func(t *testing.T) { + req := AcquireRequest() + + setConfigToRequest(req, Config{Cookie: map[string]string{ + "k1": "v1", + }}) + + require.Equal(t, "v1", req.Cookie("k1")) + }) + + t.Run("set pathparam", func(t *testing.T) { + req := AcquireRequest() + + setConfigToRequest(req, Config{PathParam: map[string]string{ + "k1": "v1", + }}) + + require.Equal(t, "v1", req.PathParam("k1")) + }) + + t.Run("set timeout", func(t *testing.T) { + req := AcquireRequest() + + setConfigToRequest(req, Config{Timeout: 1 * time.Second}) + + require.Equal(t, 1*time.Second, req.Timeout()) + }) + + t.Run("set maxredirects", func(t *testing.T) { + req := AcquireRequest() - // ctx := context.Background() - // ctx = context.WithValue(ctx, key, "v1") + setConfigToRequest(req, Config{MaxRedirects: 1}) - // req := AcquireRequest() + require.Equal(t, 1, req.MaxRedirects()) + }) - // setConfigToRequest(req, Config{Ctx: ctx}) + t.Run("set body", func(t *testing.T) { + req := AcquireRequest() - // require.Equal(t, "v1", req.Context().Value(key)) - // }) + setConfigToRequest(req, Config{Body: "test"}) - // t.Run("set ctx", func(t *testing.T) { - // key := struct{}{} + require.Equal(t, "test", req.body) + }) - // ctx := context.Background() - // ctx = context.WithValue(ctx, key, "v1") + t.Run("set file", func(t *testing.T) { + req := AcquireRequest() - // req := AcquireRequest() + setConfigToRequest(req, Config{File: []*File{ + { + name: "test", + path: "path", + }, + }}) - // setConfigToRequest(req, Config{Ctx: ctx}) + require.Equal(t, "path", req.File("test").path) + }) +} - // require.Equal(t, "v1", req.Context().Value(key)) - // }) +func Test_Client_SetProxyURL(t *testing.T) { + t.Parallel() - // t.Run("set ctx", func(t *testing.T) { - // key := struct{}{} + app, dial, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("hello world") + }) - // ctx := context.Background() - // ctx = context.WithValue(ctx, key, "v1") + go start() - // req := AcquireRequest() + defer func(app *fiber.App) { + _ = app.Shutdown() + }(app) - // setConfigToRequest(req, Config{Ctx: ctx}) + time.Sleep(1 * time.Second) - // require.Equal(t, "v1", req.Context().Value(key)) - // }) + t.Run("success", func(t *testing.T) { + client := AcquireClient() + client.SetProxyURL("http://test.com") + _, err := client.Get("http://localhost:3000", Config{Dial: dial}) - // t.Run("set ctx", func(t *testing.T) { - // key := struct{}{} + require.NoError(t, err) + }) - // ctx := context.Background() - // ctx = context.WithValue(ctx, key, "v1") + t.Run("wrong url", func(t *testing.T) { + var buf bytes.Buffer + log.SetOutput(&buf) - // req := AcquireRequest() + client := AcquireClient() + client.SetProxyURL(":this is not a url") + _, err := client.Get("http://localhost:3000", Config{Dial: dial}) - // setConfigToRequest(req, Config{Ctx: ctx}) + require.Contains(t, buf.String(), "missing protocol scheme") + require.NoError(t, err) + }) - // require.Equal(t, "v1", req.Context().Value(key)) - // }) + t.Run("error", func(t *testing.T) { + client := AcquireClient() + client.SetProxyURL("htgdftp://test.com") + _, err := client.Get("http://localhost:3000", Config{Dial: dial}) - // t.Run("set ctx", func(t *testing.T) { - // key := struct{}{} + require.Error(t, err) + }) +} - // ctx := context.Background() - // ctx = context.WithValue(ctx, key, "v1") +func Test_Client_SetRetryConfig(t *testing.T) { + t.Parallel() - // req := AcquireRequest() + retryConfig := &retry.Config{ + InitialInterval: 1 * time.Second, + MaxRetryCount: 3, + } - // setConfigToRequest(req, Config{Ctx: ctx}) + core, client, req := newCore(), AcquireClient(), AcquireRequest() + req.SetURL("http://example.com") + client.SetRetryConfig(retryConfig) + _, err := core.execute(context.Background(), client, req) - // require.Equal(t, "v1", req.Context().Value(key)) - // }) + require.NoError(t, err) + require.Equal(t, retryConfig.InitialInterval, client.RetryConfig().InitialInterval) + require.Equal(t, retryConfig.MaxRetryCount, client.RetryConfig().MaxRetryCount) } func Benchmark_Client_Request(b *testing.B) { diff --git a/client/hooks.go b/client/hooks.go index c2c5713d58..15d13ddf72 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -298,10 +299,8 @@ func logger(c *Client, resp *Response, req *Request) (err error) { return } - logger := c.Logger() - - logger.Debugf("%s\n", req.RawRequest.String()) - logger.Debugf("%s\n", resp.RawResponse.String()) + log.Debugf("%s\n", req.RawRequest.String()) + log.Debugf("%s\n", resp.RawResponse.String()) return } diff --git a/client/hooks_test.go b/client/hooks_test.go index b2e1ab9ca8..d144b29e30 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -1,11 +1,14 @@ package client import ( + "bytes" "encoding/xml" + "github.com/gofiber/fiber/v3/log" "io" "net/url" "strings" "testing" + "time" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" @@ -424,6 +427,17 @@ func Test_Parser_Request_Body(t *testing.T) { require.Equal(t, "ball=cricle+and+square", string(req.RawRequest.Body())) }) + t.Run("form data body error", func(t *testing.T) { + client := AcquireClient() + req := AcquireRequest(). + SetFormDatas(map[string]string{ + "": "", + }) + + err := parserRequestBody(client, req) + require.NoError(t, err) + }) + t.Run("file body", func(t *testing.T) { client := AcquireClient() req := AcquireRequest(). @@ -457,4 +471,78 @@ func Test_Parser_Request_Body(t *testing.T) { require.NoError(t, err) require.Equal(t, []byte("hello world"), req.RawRequest.Body()) }) + + t.Run("raw body error", func(t *testing.T) { + client := AcquireClient() + req := AcquireRequest(). + SetRawBody([]byte("hello world")) + + req.body = nil + + err := parserRequestBody(client, req) + require.ErrorIs(t, err, ErrBodyType) + }) +} + +func Test_Client_Logger_Debug(t *testing.T) { + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("response") + }) + + go func() { + require.Nil(t, app.Listen(":3000", fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + + defer func(app *fiber.App) { + _ = app.Shutdown() + }(app) + + time.Sleep(1 * time.Second) + + var buf bytes.Buffer + log.SetOutput(&buf) + + client := AcquireClient() + client.Debug() + + resp, err := client.Get("http://localhost:3000") + defer resp.Close() + + require.NoError(t, err) + require.Contains(t, buf.String(), "Host: localhost:3000") + require.Contains(t, buf.String(), "Content-Length: 8") +} + +func Test_Client_Logger_DisableDebug(t *testing.T) { + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("response") + }) + + go func() { + require.Nil(t, app.Listen(":3000", fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + + defer func(app *fiber.App) { + _ = app.Shutdown() + }(app) + + time.Sleep(1 * time.Second) + + var buf bytes.Buffer + log.SetOutput(&buf) + + client := AcquireClient() + client.DisableDebug() + + resp, err := client.Get("http://localhost:3000") + defer resp.Close() + + require.NoError(t, err) + require.Len(t, buf.String(), 0) } diff --git a/client/jar.go b/client/jar.go index 648fce0a25..71f711fff7 100644 --- a/client/jar.go +++ b/client/jar.go @@ -244,7 +244,7 @@ func newEntry(c *fasthttp.Cookie, now time.Time, path []byte) (*entry, bool) { e := acquireEntry() e.Key = utils.CopyBytes(c.Key()) - + fmt.Println(c.Path()) if len(c.Path()) != 0 || c.Path()[0] != '/' { e.Path = utils.CopyBytes(path) } else { diff --git a/client/jar_test.go b/client/jar_test.go index cf8947192e..c5c6825f38 100644 --- a/client/jar_test.go +++ b/client/jar_test.go @@ -87,6 +87,31 @@ func TestHasDotSuffix(t *testing.T) { } } +var canonicalHostTests = map[string]string{ + "www.example.com": "www.example.com", + "WWW.EXAMPLE.COM": "www.example.com", + "wWw.eXAmple.CoM": "www.example.com", + "www.example.com:80": "www.example.com", + "192.168.0.10": "192.168.0.10", + "192.168.0.5:8080": "192.168.0.5", + "2001:4860:0:2001::68": "2001:4860:0:2001::68", + "[2001:4860:0:::68]:8080": "2001:4860:0:::68", + "www.bücher.de": "www.xn--bcher-kva.de", + "www.example.com.": "www.example.com", + // TODO: Fix canonicalHost so that all of the following malformed + // domain names trigger an error. (This list is not exhaustive, e.g. + // malformed internationalized domain names are missing.) + ".": "", + "..": ".", + "...": "..", + ".net": ".net", + ".net.": ".net", + "a..": "a.", + "b.a..": "b.a.", + "weird.stuff...": "weird.stuff..", + "[bad.unmatched.bracket:": "error", +} + var jarKeyTests = map[string]string{ "foo.www.example.com": "example.com", "www.example.com": "example.com", diff --git a/client/request.go b/client/request.go index c2d8cad74f..5e55f1dccb 100644 --- a/client/request.go +++ b/client/request.go @@ -324,6 +324,12 @@ func (r *Request) DelPathParams(key ...string) *Request { return r } +// ResetPathParams deletes all path params. +func (r *Request) ResetPathParams() *Request { + r.path.Reset() + return r +} + // SetJSON method sets json body in request. func (r *Request) SetJSON(v any) *Request { r.body = v diff --git a/client/request_test.go b/client/request_test.go index 1598033b9c..78d3ef6c35 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -388,6 +388,20 @@ func Test_Request_PathParam(t *testing.T) { require.Equal(t, "", req.PathParam("foo")) require.Equal(t, "foo", req.PathParam("bar")) }) + + t.Run("clear path params", func(t *testing.T) { + req := AcquireRequest(). + SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + + req.ResetPathParams() + require.Equal(t, "", req.PathParam("foo")) + require.Equal(t, "", req.PathParam("bar")) + }) } func Test_Request_FormData(t *testing.T) { @@ -522,6 +536,8 @@ func Test_Request_File(t *testing.T) { require.Equal(t, "../.github/index.html", req.File("index.html").path) require.Equal(t, "../.github/index.html", req.FileByPath("../.github/index.html").path) require.Equal(t, "tmp.txt", req.File("tmp.txt").name) + require.Nil(t, req.File("tmp2.txt")) + require.Nil(t, req.FileByPath("tmp2.txt")) }) t.Run("add file by reader", func(t *testing.T) { @@ -1185,6 +1201,14 @@ func Test_Request_MaxRedirects(t *testing.T) { require.Nil(t, resp) require.Equal(t, "too many redirects detected when doing the request", err.Error()) }) + + t.Run("MaxRedirects", func(t *testing.T) { + req := AcquireRequest(). + SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }). + SetMaxRedirects(3) + + require.Equal(t, req.MaxRedirects(), 3) + }) } // // readErrorConn is a struct for testing retryIf diff --git a/client/response.go b/client/response.go index e9fb40826c..170821683b 100644 --- a/client/response.go +++ b/client/response.go @@ -2,7 +2,9 @@ package client import ( "bytes" + "errors" "io" + "io/fs" "os" "path/filepath" "strings" @@ -83,9 +85,9 @@ func (r *Response) Save(v any) error { file := filepath.Clean(p) dir := filepath.Dir(file) - // create director + // create directory if _, err := os.Stat(dir); err != nil { - if !os.IsNotExist(err) { + if !errors.Is(err, fs.ErrNotExist) { return err } diff --git a/client/response_test.go b/client/response_test.go index e1d24503fa..e67f6ac334 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -252,12 +252,13 @@ func Test_Response_Save(t *testing.T) { }() file, err := os.Open("./test/tmp.json") + defer file.Close() + require.NoError(t, err) data, err := io.ReadAll(file) require.NoError(t, err) require.Equal(t, "{\"status\":\"success\"}", string(data)) - file.Close() }) t.Run("io.Writer", func(t *testing.T) { From 6ea941035789a72dde334d957a6b5cfb105fead1 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Mon, 7 Aug 2023 18:59:56 +0300 Subject: [PATCH 076/118] fix proxy tests --- middleware/proxy/proxy_test.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 77a4d0e160..e0a8c0f787 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -596,9 +596,11 @@ func Test_Proxy_Forward_Global_Client(t *testing.T) { })) }() + time.Sleep(1 * time.Second) + resp, err := fiberClient.AcquireClient(). R(). - Get("https://" + addr) + Get("http://" + addr) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) @@ -629,9 +631,11 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) { })) }() + time.Sleep(1 * time.Second) + resp, err := fiberClient.AcquireClient(). R(). - Get("https://" + addr) + Get("http://" + addr) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) From c34321af268651c3aa06e089a89b471e54078553 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sun, 17 Dec 2023 20:07:27 +0300 Subject: [PATCH 077/118] add cookiejar support from pending fasthttp PR --- client/client.go | 14 +- client/client_test.go | 100 +++++++++++++ client/cookiejar.go | 240 ++++++++++++++++++++++++++++++ client/cookiejar_test.go | 212 +++++++++++++++++++++++++++ client/hooks.go | 11 +- client/jar.go | 306 --------------------------------------- client/jar_test.go | 231 ----------------------------- 7 files changed, 560 insertions(+), 554 deletions(-) create mode 100644 client/cookiejar.go create mode 100644 client/cookiejar_test.go delete mode 100644 client/jar.go delete mode 100644 client/jar_test.go diff --git a/client/client.go b/client/client.go index cbc944b126..d3e30848ef 100644 --- a/client/client.go +++ b/client/client.go @@ -88,7 +88,7 @@ type Client struct { xmlMarshal utils.XMLMarshal xmlUnmarshal utils.XMLUnmarshal - jar *jar + cookieJar *CookieJar // tls config tlsConfig *tls.Config @@ -489,15 +489,9 @@ func (c *Client) DisableDebug() *Client { return c } -// enable cookie jar -func (c *Client) EnableJar() *Client { - c.jar = newJar() - return c -} - -// disable cookie jar -func (c *Client) DisableJar() *Client { - c.jar = nil +// SetCookieJar sets cookie jar in client instance. +func (c *Client) SetCookieJar(cookieJar *CookieJar) *Client { + c.cookieJar = cookieJar return c } diff --git a/client/client_test.go b/client/client_test.go index 3da79c2596..331e1cfeae 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -595,6 +595,106 @@ func Test_Client_Cookie_With_Server(t *testing.T) { testClient(t, handler, wrapAgent, "v1v2v3") } +func Test_Client_CookieJar(t *testing.T) { + handler := func(c fiber.Ctx) error { + return c.SendString( + c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3")) + } + + jar := AcquireCookieJar() + defer ReleaseCookieJar(jar) + + jar.SetKeyValue("example.com", "k1", "v1") + jar.SetKeyValue("example.com", "k2", "v2") + jar.SetKeyValue("example", "k3", "v3") + + wrapAgent := func(c *Client) { + c.SetCookieJar(jar) + } + testClient(t, handler, wrapAgent, "v1v2") +} + +func Test_Client_CookieJar_Response(t *testing.T) { + t.Run("without expiration", func(t *testing.T) { + handler := func(c fiber.Ctx) error { + c.Cookie(&fiber.Cookie{ + Name: "k4", + Value: "v4", + }) + return c.SendString( + c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3")) + } + + jar := AcquireCookieJar() + defer ReleaseCookieJar(jar) + + jar.SetKeyValue("example.com", "k1", "v1") + jar.SetKeyValue("example.com", "k2", "v2") + jar.SetKeyValue("example", "k3", "v3") + + wrapAgent := func(c *Client) { + c.SetCookieJar(jar) + } + testClient(t, handler, wrapAgent, "v1v2") + + require.Len(t, jar.getCookiesByHost("example.com"), 3) + }) + + t.Run("with expiration", func(t *testing.T) { + handler := func(c fiber.Ctx) error { + c.Cookie(&fiber.Cookie{ + Name: "k4", + Value: "v4", + Expires: time.Now().Add(1 * time.Nanosecond), + }) + return c.SendString( + c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3")) + } + + jar := AcquireCookieJar() + defer ReleaseCookieJar(jar) + + jar.SetKeyValue("example.com", "k1", "v1") + jar.SetKeyValue("example.com", "k2", "v2") + jar.SetKeyValue("example", "k3", "v3") + + wrapAgent := func(c *Client) { + c.SetCookieJar(jar) + } + testClient(t, handler, wrapAgent, "v1v2") + + require.Len(t, jar.getCookiesByHost("example.com"), 2) + }) + + t.Run("override cookie value", func(t *testing.T) { + handler := func(c fiber.Ctx) error { + c.Cookie(&fiber.Cookie{ + Name: "k1", + Value: "v2", + }) + return c.SendString( + c.Cookies("k1") + c.Cookies("k2")) + } + + jar := AcquireCookieJar() + defer ReleaseCookieJar(jar) + + jar.SetKeyValue("example.com", "k1", "v1") + jar.SetKeyValue("example.com", "k2", "v2") + + wrapAgent := func(c *Client) { + c.SetCookieJar(jar) + } + testClient(t, handler, wrapAgent, "v1v2") + + for _, cookie := range jar.getCookiesByHost("example.com") { + if string(cookie.Key()) == "k1" { + require.Equal(t, "v2", string(cookie.Value())) + } + } + }) +} + func Test_Client_Referer(t *testing.T) { handler := func(c fiber.Ctx) error { return c.Send(c.Request().Header.Referer()) diff --git a/client/cookiejar.go b/client/cookiejar.go new file mode 100644 index 0000000000..3db88d8e46 --- /dev/null +++ b/client/cookiejar.go @@ -0,0 +1,240 @@ +// The code has been taken from https://github.com/valyala/fasthttp/pull/526 originally. +package client + +import ( + "bytes" + "github.com/gofiber/utils/v2" + "github.com/valyala/fasthttp" + "net" + "sync" + "time" +) + +var cookieJarPool = sync.Pool{ + New: func() interface{} { + return &CookieJar{} + }, +} + +// AcquireCookieJar returns an empty CookieJar object from pool. +func AcquireCookieJar() *CookieJar { + return cookieJarPool.Get().(*CookieJar) +} + +// ReleaseCookieJar returns CookieJar to the pool. +func ReleaseCookieJar(c *CookieJar) { + c.Release() + cookieJarPool.Put(c) +} + +// CookieJar manages cookie storage. It is used by the client to store cookies. +type CookieJar struct { + mu sync.Mutex + hostCookies map[string][]*fasthttp.Cookie +} + +// Get returns the cookies stored from a specific domain. +// If there were no cookies related with host returned slice will be nil. +// +// CookieJar keeps a copy of the cookies, so the returned cookies can be released safely. +func (cj *CookieJar) Get(uri *fasthttp.URI) []*fasthttp.Cookie { + if uri == nil { + return nil + } + + return cj.get(uri.Host(), uri.Path()) +} + +// get returns the cookies stored from a specific host and path. +func (cj *CookieJar) get(host, path []byte) []*fasthttp.Cookie { + if cj.hostCookies == nil { + return nil + } + + var ( + err error + cookies []*fasthttp.Cookie + hostStr = utils.UnsafeString(host) + ) + + // port must not be included. + hostStr, _, err = net.SplitHostPort(hostStr) + if err != nil { + hostStr = utils.UnsafeString(host) + } + // get cookies deleting expired ones + cookies = cj.getCookiesByHost(hostStr) + + newCookies := make([]*fasthttp.Cookie, 0, len(cookies)) + for i := 0; i < len(cookies); i++ { + cookie := cookies[i] + if len(path) > 1 && len(cookie.Path()) > 1 && !bytes.HasPrefix(cookie.Path(), path) { + continue + } + newCookies = append(newCookies, cookie) + } + + return newCookies +} + +// getCookiesByHost returns the cookies stored from a specific host. +// If cookies are expired they will be deleted. +func (cj *CookieJar) getCookiesByHost(host string) []*fasthttp.Cookie { + cj.mu.Lock() + defer cj.mu.Unlock() + + now := time.Now() + cookies := cj.hostCookies[host] + + for i := 0; i < len(cookies); i++ { + c := cookies[i] + if !c.Expire().Equal(fasthttp.CookieExpireUnlimited) && c.Expire().Before(now) { // release cookie if expired + cookies = append(cookies[:i], cookies[i+1:]...) + fasthttp.ReleaseCookie(c) + i-- + } + } + + return cookies +} + +// Set sets cookies for a specific host. +// The host is get from uri.Host(). +// If the cookie key already exists it will be replaced by the new cookie value. +// +// CookieJar keeps a copy of the cookies, so the parsed cookies can be released safely. +func (cj *CookieJar) Set(uri *fasthttp.URI, cookies ...*fasthttp.Cookie) { + if uri == nil { + return + } + + cj.set(uri.Host(), cookies...) +} + +// SetByHost sets cookies for a specific host. +// If the cookie key already exists it will be replaced by the new cookie value. +// +// CookieJar keeps a copy of the cookies, so the parsed cookies can be released safely. +func (cj *CookieJar) SetByHost(host []byte, cookies ...*fasthttp.Cookie) { + cj.set(host, cookies...) +} + +func (cj *CookieJar) set(host []byte, cookies ...*fasthttp.Cookie) { + hostStr := utils.UnsafeString(host) + + cj.mu.Lock() + defer cj.mu.Unlock() + + if cj.hostCookies == nil { + cj.hostCookies = make(map[string][]*fasthttp.Cookie) + } + + hostCookies, ok := cj.hostCookies[hostStr] + if !ok { + // If the key does not exist in the map, then we must make a copy for the key to avoid unsafe usage. + hostStr = string(host) + } + + for _, cookie := range cookies { + c := searchCookieByKeyAndPath(cookie.Key(), cookie.Path(), hostCookies) + if c == nil { + // If the cookie does not exist in the slice, let's acquire new cookie and store it. + c = fasthttp.AcquireCookie() + hostCookies = append(hostCookies, c) + } + c.CopyTo(cookie) // override cookie properties + } + cj.hostCookies[hostStr] = hostCookies +} + +// SetKeyValue sets a cookie by key and value for a specific host. +// +// This function prevents extra allocations by making repeated cookies +// not being duplicated. +func (cj *CookieJar) SetKeyValue(host, key, value string) { + cj.SetKeyValueBytes(host, utils.UnsafeBytes(key), utils.UnsafeBytes(value)) +} + +// SetKeyValueBytes sets a cookie by key and value for a specific host. +// +// This function prevents extra allocations by making repeated cookies +// not being duplicated. +func (cj *CookieJar) SetKeyValueBytes(host string, key, value []byte) { + cj.setKeyValue(host, key, value) +} + +func (cj *CookieJar) setKeyValue(host string, key, value []byte) { + c := fasthttp.AcquireCookie() + c.SetKeyBytes(key) + c.SetValueBytes(value) + + cj.set(utils.UnsafeBytes(host), c) +} + +// dumpCookiesToReq dumps the stored cookies to the request. +func (cj *CookieJar) dumpCookiesToReq(req *fasthttp.Request) { + uri := req.URI() + + cookies := cj.get(uri.Host(), uri.Path()) + for _, cookie := range cookies { + req.Header.SetCookieBytesKV(cookie.Key(), cookie.Value()) + } +} + +// getCookiesFromResp parses the response cookies and stores them. +func (cj *CookieJar) getCookiesFromResp(host, path []byte, resp *fasthttp.Response) { + hostStr := utils.UnsafeString(host) + + cj.mu.Lock() + defer cj.mu.Unlock() + + if cj.hostCookies == nil { + cj.hostCookies = make(map[string][]*fasthttp.Cookie) + } + cookies, ok := cj.hostCookies[hostStr] + if !ok { + // If the key does not exist in the map then + // we must make a copy for the key to avoid unsafe usage. + hostStr = string(host) + } + + now := time.Now() + resp.Header.VisitAllCookie(func(key, value []byte) { + isCreated := false + c := searchCookieByKeyAndPath(key, path, cookies) + if c == nil { + c, isCreated = fasthttp.AcquireCookie(), true + } + + _ = c.ParseBytes(value) + if c.Expire().Equal(fasthttp.CookieExpireUnlimited) || c.Expire().After(now) { + cookies = append(cookies, c) + } else if isCreated { + fasthttp.ReleaseCookie(c) + } + }) + cj.hostCookies[hostStr] = cookies +} + +// Release releases all cookie values. +func (cj *CookieJar) Release() { + for _, v := range cj.hostCookies { + for _, c := range v { + fasthttp.ReleaseCookie(c) + } + } + cj.hostCookies = nil +} + +// searchCookieByKeyAndPath searches for a cookie by key and path. +func searchCookieByKeyAndPath(key, path []byte, cookies []*fasthttp.Cookie) *fasthttp.Cookie { + for _, c := range cookies { + if bytes.Equal(key, c.Key()) { + if len(path) <= 1 || bytes.HasPrefix(c.Path(), path) { + return c + } + } + } + + return nil +} diff --git a/client/cookiejar_test.go b/client/cookiejar_test.go new file mode 100644 index 0000000000..1cfd7510bd --- /dev/null +++ b/client/cookiejar_test.go @@ -0,0 +1,212 @@ +package client + +import ( + "bytes" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" + "testing" + "time" +) + +func checkKeyValue(t *testing.T, cj *CookieJar, cookie *fasthttp.Cookie, uri *fasthttp.URI, n int) { + t.Helper() + + cs := cj.Get(uri) + require.True(t, len(cs) >= n) + + c := cs[n-1] + require.NotNil(t, c) + + require.Equal(t, string(c.Key()), string(cookie.Key())) + require.Equal(t, string(c.Value()), string(cookie.Value())) +} + +func TestCookieJarGet(t *testing.T) { + t.Parallel() + + url := []byte("http://fasthttp.com/") + url1 := []byte("http://fasthttp.com/make") + url11 := []byte("http://fasthttp.com/hola") + url2 := []byte("http://fasthttp.com/make/fasthttp") + url3 := []byte("http://fasthttp.com/make/fasthttp/great") + prefix := []byte("/") + prefix1 := []byte("/make") + prefix2 := []byte("/make/fasthttp") + prefix3 := []byte("/make/fasthttp/great") + cj := &CookieJar{} + + c1 := &fasthttp.Cookie{} + c1.SetKey("k") + c1.SetValue("v") + c1.SetPath("/make/") + + c2 := &fasthttp.Cookie{} + c2.SetKey("kk") + c2.SetValue("vv") + c2.SetPath("/make/fasthttp") + + c3 := &fasthttp.Cookie{} + c3.SetKey("kkk") + c3.SetValue("vvv") + c3.SetPath("/make/fasthttp/great") + + uri := fasthttp.AcquireURI() + require.NoError(t, uri.Parse(nil, url)) + + uri1 := fasthttp.AcquireURI() + require.NoError(t, uri1.Parse(nil, url1)) + + uri11 := fasthttp.AcquireURI() + require.NoError(t, uri11.Parse(nil, url11)) + + uri2 := fasthttp.AcquireURI() + require.NoError(t, uri2.Parse(nil, url2)) + + uri3 := fasthttp.AcquireURI() + require.NoError(t, uri3.Parse(nil, url3)) + + cj.Set(uri1, c1, c2, c3) + + cookies := cj.Get(uri1) + require.Len(t, cookies, 3) + for _, cookie := range cookies { + require.True(t, bytes.HasPrefix(cookie.Path(), prefix1)) + } + + cookies = cj.Get(uri11) + require.Len(t, cookies, 0) + + cookies = cj.Get(uri2) + require.Len(t, cookies, 2) + for _, cookie := range cookies { + require.True(t, bytes.HasPrefix(cookie.Path(), prefix2)) + } + + cookies = cj.Get(uri3) + require.Len(t, cookies, 1) + + for _, cookie := range cookies { + require.True(t, bytes.HasPrefix(cookie.Path(), prefix3)) + } + + cookies = cj.Get(uri) + require.Len(t, cookies, 3) + for _, cookie := range cookies { + require.True(t, bytes.HasPrefix(cookie.Path(), prefix)) + } +} + +func TestCookieJarGetExpired(t *testing.T) { + t.Parallel() + + url1 := []byte("http://fasthttp.com/make/") + uri1 := fasthttp.AcquireURI() + require.NoError(t, uri1.Parse(nil, url1)) + + c1 := &fasthttp.Cookie{} + c1.SetKey("k") + c1.SetValue("v") + c1.SetExpire(time.Now().Add(-time.Hour)) + + cj := &CookieJar{} + cj.Set(uri1, c1) + + cookies := cj.Get(uri1) + require.Len(t, cookies, 0) +} + +func TestCookieJarSet(t *testing.T) { + t.Parallel() + + url := []byte("http://fasthttp.com/hello/world") + cj := &CookieJar{} + + cookie := &fasthttp.Cookie{} + cookie.SetKey("k") + cookie.SetValue("v") + + uri := fasthttp.AcquireURI() + require.NoError(t, uri.Parse(nil, url)) + + cj.Set(uri, cookie) + checkKeyValue(t, cj, cookie, uri, 1) +} + +func TestCookieJarSetRepeatedCookieKeys(t *testing.T) { + t.Parallel() + + host := "fast.http" + cj := &CookieJar{} + + uri := fasthttp.AcquireURI() + uri.SetHost(host) + + cookie := &fasthttp.Cookie{} + cookie.SetKey("k") + cookie.SetValue("v") + + cookie2 := &fasthttp.Cookie{} + cookie2.SetKey("k") + cookie2.SetValue("v2") + + cookie3 := &fasthttp.Cookie{} + cookie3.SetKey("key") + cookie3.SetValue("value") + + cj.Set(uri, cookie, cookie2, cookie3) + + cookies := cj.Get(uri) + require.Len(t, cookies, 2) + require.Equal(t, cookies[0], cookie2) + require.True(t, bytes.Equal(cookies[0].Value(), cookie2.Value())) +} + +func TestCookieJarSetKeyValue(t *testing.T) { + t.Parallel() + + host := "fast.http" + cj := &CookieJar{} + + uri := fasthttp.AcquireURI() + uri.SetHost(host) + + cj.SetKeyValue(host, "k", "v") + cj.SetKeyValue(host, "key", "value") + cj.SetKeyValue(host, "k", "vv") + cj.SetKeyValue(host, "key", "value2") + + cookies := cj.Get(uri) + require.Len(t, cookies, 2) +} + +func TestCookieJarGetFromResponse(t *testing.T) { + t.Parallel() + + res := fasthttp.AcquireResponse() + host := []byte("fast.http") + uri := fasthttp.AcquireURI() + uri.SetHostBytes(host) + + c := &fasthttp.Cookie{} + c.SetKey("key") + c.SetValue("val") + + c2 := &fasthttp.Cookie{} + c2.SetKey("k") + c2.SetValue("v") + + c3 := &fasthttp.Cookie{} + c3.SetKey("kk") + c3.SetValue("vv") + + res.Header.SetStatusCode(200) + res.Header.SetCookie(c) + res.Header.SetCookie(c2) + res.Header.SetCookie(c3) + + cj := &CookieJar{} + cj.getCookiesFromResp(host, nil, res) + + cookies := cj.Get(uri) + require.Len(t, cookies, 3) +} diff --git a/client/hooks.go b/client/hooks.go index 15d13ddf72..73ba0a79e5 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -156,11 +156,8 @@ func parserRequestHeader(c *Client, req *Request) error { // set cookie // add cookie form jar to req - if c.jar != nil { - cookies := c.jar.Cookies(req.RawRequest.URI()) - for _, c := range cookies { - req.RawRequest.Header.SetCookieBytesKV(c.Key, c.Value) - } + if c.cookieJar != nil { + c.cookieJar.dumpCookiesToReq(req.RawRequest) } c.cookies.VisitAll(func(key, val string) { @@ -287,8 +284,8 @@ func parserResponseCookie(c *Client, resp *Response, req *Request) (err error) { }) // store cookies to jar - if c.jar != nil { - c.jar.SetCookies(req.RawRequest.URI(), resp.cookie) + if c.cookieJar != nil { + c.cookieJar.getCookiesFromResp(req.RawRequest.URI().Host(), req.RawRequest.URI().Path(), resp.RawResponse) } return diff --git a/client/jar.go b/client/jar.go deleted file mode 100644 index 71f711fff7..0000000000 --- a/client/jar.go +++ /dev/null @@ -1,306 +0,0 @@ -package client - -import ( - "fmt" - "sort" - "strings" - "sync" - "time" - - "github.com/gofiber/utils/v2" - "github.com/valyala/fasthttp" -) - -var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC) - -type jar struct { - mu sync.Mutex - - entries map[string]map[string]*entry - - nextSeqNum uint64 -} - -type entry struct { - Key []byte - Value []byte - Domain []byte - Path []byte - SameSite fasthttp.CookieSameSite - Secure bool - HttpOnly bool - Persistent bool - Expires time.Time - Creation time.Time - LastAccess time.Time - - seqNum uint64 -} - -func (e *entry) id() string { - return fmt.Sprintf("%s:%s:%s", utils.UnsafeString(e.Domain), utils.UnsafeString(e.Path), utils.UnsafeString(e.Key)) -} -func (e *entry) shouldSend(https bool, host, path []byte) bool { - return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure) -} - -func (e *entry) domainMatch(host []byte) bool { - if utils.EqualFold(e.Domain, host) { - return true - } - - return hasDotSuffix(host, e.Domain) -} - -func (e *entry) pathMatch(path []byte) bool { - if utils.EqualFold(e.Path, path) { - return true - } - if strings.HasPrefix(utils.UnsafeString(path), utils.UnsafeString(e.Path)) { - if e.Path[len(e.Path)-1] == '/' { - return true // The "/any/" matches "/any/path" case. - } else if path[len(e.Path)] == '/' { - return true // The "/any" matches "/any/path" case. - } - } - return false -} - -func (e *entry) reset() { - e.Key = []byte{} - e.Value = []byte{} - e.Domain = []byte{} - e.Path = []byte{} - e.SameSite = fasthttp.CookieSameSiteDefaultMode - e.Secure = true - e.HttpOnly = true - e.Persistent = false - - now := time.Now() - e.Expires = now - e.Creation = now - e.LastAccess = now - - e.seqNum = 0 -} - -func (j *jar) Cookies(u *fasthttp.URI) []*entry { - return j.cookies(u, time.Now()) -} - -func (j *jar) cookies(u *fasthttp.URI, now time.Time) (cookies []*entry) { - if !utils.EqualFold(u.Scheme(), []byte("http")) && !utils.EqualFold(u.Scheme(), []byte("https")) { - return - } - - host := u.Host() - key := jarKey(host) - - j.mu.Lock() - defer j.mu.Unlock() - - subMap := j.entries[key] - if subMap == nil { - return - } - - https := utils.EqualFold(u.Scheme(), []byte("https")) - path := u.Path() - if len(path) == 0 { - path = []byte("/") - } - - modified := false - for id, e := range subMap { - if e.Persistent && !e.Expires.After(now) { - ee := subMap[id] - delete(subMap, id) - releaseEntry(ee) - modified = true - continue - } - - if !e.shouldSend(https, host, path) { - continue - } - - e.LastAccess = now - subMap[id] = e - cookies = append(cookies, e) - modified = true - } - - if modified { - if len(subMap) == 0 { - delete(j.entries, key) - } else { - j.entries[key] = subMap - } - } - - sort.Slice(cookies, func(i, j int) bool { - s := cookies - if len(s[i].Path) != len(s[j].Path) { - return len(s[i].Path) > len(s[j].Path) - } - - if s[i].Creation != s[j].Creation { - return s[i].Creation.Before(s[j].Creation) - } - - return s[i].seqNum < s[j].seqNum - }) - - return -} - -func (j *jar) SetCookies(u *fasthttp.URI, cookies []*fasthttp.Cookie) { - j.setCookies(u, cookies, time.Now()) -} - -func (j *jar) setCookies(u *fasthttp.URI, cookies []*fasthttp.Cookie, now time.Time) { - if len(cookies) == 0 { - return - } - - if !utils.EqualFold(u.Scheme(), []byte("http")) && !utils.EqualFold(u.Scheme(), []byte("https")) { - return - } - - host := u.Host() - path := u.Path() - key := jarKey(host) - - j.mu.Lock() - defer j.mu.Unlock() - - subMap := j.entries[key] - - modified := false - for _, cookie := range cookies { - e, remove := newEntry(cookie, now, path) - id := e.id() - - if remove { - if subMap != nil { - if _, ok := subMap[id]; ok { - ee := subMap[id] - delete(subMap, id) - releaseEntry(ee) - modified = true - } - - continue - } - } - - if subMap == nil { - subMap = make(map[string]*entry) - } - - if old, ok := subMap[id]; ok { - e.Creation = old.Creation - e.seqNum = old.seqNum - } else { - e.Creation = now - e.seqNum = j.nextSeqNum - j.nextSeqNum++ - } - - e.LastAccess = now - subMap[id] = e - modified = true - } - - if modified { - if len(subMap) == 0 { - delete(j.entries, key) - } else { - j.entries[key] = subMap - } - } -} - -func jarKey(h []byte) string { - host := utils.UnsafeString(h) - if utils.IsIPv4(host) || utils.IsIPv6(host) { - return host - } - - i := strings.LastIndex(host, ".") - if i <= 0 { - return host - } - - prevDot := strings.LastIndex(host[:i-1], ".") - return host[prevDot+1:] -} - -func hasDotSuffix(s, suffix []byte) bool { - return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && utils.EqualFold(s[len(s)-len(suffix):], suffix) -} - -func newEntry(c *fasthttp.Cookie, now time.Time, path []byte) (*entry, bool) { - e := acquireEntry() - - e.Key = utils.CopyBytes(c.Key()) - fmt.Println(c.Path()) - if len(c.Path()) != 0 || c.Path()[0] != '/' { - e.Path = utils.CopyBytes(path) - } else { - e.Path = utils.CopyBytes(c.Path()) - } - - e.Domain = utils.CopyBytes(c.Domain()) - - if c.MaxAge() < 0 { - return e, true - } else if c.MaxAge() > 0 { - e.Expires = now.Add(time.Duration(c.MaxAge()) * time.Second) - e.Persistent = true - } else { - if c.Expire().IsZero() { - e.Expires = endOfTime - e.Persistent = false - } else { - if !c.Expire().After(now) { - return e, true - } - - e.Expires = c.Expire() - e.Persistent = true - } - } - - e.Value = utils.CopyBytes(c.Value()) - e.Secure = c.Secure() - e.HttpOnly = c.HTTPOnly() - - e.SameSite = c.SameSite() - - return e, false -} - -func newJar() *jar { - return &jar{ - mu: sync.Mutex{}, - entries: map[string]map[string]*entry{}, - nextSeqNum: 0, - } -} - -var entryPool = &sync.Pool{ - New: func() any { - return &entry{} - }, -} - -func acquireEntry() *entry { - e := entryPool.Get().(*entry) - return e -} - -func releaseEntry(e *entry) { - e.reset() - entryPool.Put(e) -} diff --git a/client/jar_test.go b/client/jar_test.go deleted file mode 100644 index c5c6825f38..0000000000 --- a/client/jar_test.go +++ /dev/null @@ -1,231 +0,0 @@ -package client - -import ( - "fmt" - "strings" - "testing" - "time" - - "github.com/gofiber/utils/v2" - "github.com/stretchr/testify/require" - "github.com/valyala/fasthttp" -) - -// tNow is the synthetic current time used as now during testing. -var tNow = time.Date(2013, 1, 1, 12, 0, 0, 0, time.UTC) - -var hasDotSuffixTests = [...]struct { - s, suffix string -}{ - {"", ""}, - {"", "."}, - {"", "x"}, - {".", ""}, - {".", "."}, - {".", ".."}, - {".", "x"}, - {".", "x."}, - {".", ".x"}, - {".", ".x."}, - {"x", ""}, - {"x", "."}, - {"x", ".."}, - {"x", "x"}, - {"x", "x."}, - {"x", ".x"}, - {"x", ".x."}, - {".x", ""}, - {".x", "."}, - {".x", ".."}, - {".x", "x"}, - {".x", "x."}, - {".x", ".x"}, - {".x", ".x."}, - {"x.", ""}, - {"x.", "."}, - {"x.", ".."}, - {"x.", "x"}, - {"x.", "x."}, - {"x.", ".x"}, - {"x.", ".x."}, - {"com", ""}, - {"com", "m"}, - {"com", "om"}, - {"com", "com"}, - {"com", ".com"}, - {"com", "x.com"}, - {"com", "xcom"}, - {"com", "xorg"}, - {"com", "org"}, - {"com", "rg"}, - {"foo.com", ""}, - {"foo.com", "m"}, - {"foo.com", "om"}, - {"foo.com", "com"}, - {"foo.com", ".com"}, - {"foo.com", "o.com"}, - {"foo.com", "oo.com"}, - {"foo.com", "foo.com"}, - {"foo.com", ".foo.com"}, - {"foo.com", "x.foo.com"}, - {"foo.com", "xfoo.com"}, - {"foo.com", "xfoo.org"}, - {"foo.com", "foo.org"}, - {"foo.com", "oo.org"}, - {"foo.com", "o.org"}, - {"foo.com", ".org"}, - {"foo.com", "org"}, - {"foo.com", "rg"}, -} - -func TestHasDotSuffix(t *testing.T) { - for _, tc := range hasDotSuffixTests { - got := hasDotSuffix([]byte(tc.s), []byte(tc.suffix)) - - want := strings.HasSuffix(tc.s, "."+tc.suffix) - require.Equal(t, want, got) - } -} - -var canonicalHostTests = map[string]string{ - "www.example.com": "www.example.com", - "WWW.EXAMPLE.COM": "www.example.com", - "wWw.eXAmple.CoM": "www.example.com", - "www.example.com:80": "www.example.com", - "192.168.0.10": "192.168.0.10", - "192.168.0.5:8080": "192.168.0.5", - "2001:4860:0:2001::68": "2001:4860:0:2001::68", - "[2001:4860:0:::68]:8080": "2001:4860:0:::68", - "www.bücher.de": "www.xn--bcher-kva.de", - "www.example.com.": "www.example.com", - // TODO: Fix canonicalHost so that all of the following malformed - // domain names trigger an error. (This list is not exhaustive, e.g. - // malformed internationalized domain names are missing.) - ".": "", - "..": ".", - "...": "..", - ".net": ".net", - ".net.": ".net", - "a..": "a.", - "b.a..": "b.a.", - "weird.stuff...": "weird.stuff..", - "[bad.unmatched.bracket:": "error", -} - -var jarKeyTests = map[string]string{ - "foo.www.example.com": "example.com", - "www.example.com": "example.com", - "example.com": "example.com", - "com": "com", - "foo.www.bbc.co.uk": "co.uk", - "www.bbc.co.uk": "co.uk", - "bbc.co.uk": "co.uk", - "co.uk": "co.uk", - "uk": "uk", - "192.168.0.5": "192.168.0.5", - // The following are actual outputs of canonicalHost for - // malformed inputs to canonicalHost. - "": "", - ".": ".", - "..": "..", - ".net": ".net", - "a.": "a.", - "b.a.": "a.", - "weird.stuff..": "stuff..", -} - -func TestJarKey(t *testing.T) { - for host, want := range jarKeyTests { - got := jarKey([]byte(host)) - - require.Equal(t, want, got) - } -} - -// expiresIn creates an expires attribute delta seconds from tNow. -func expiresIn(delta int) string { - t := tNow.Add(time.Duration(delta) * time.Second) - return "expires=" + t.Format(time.RFC1123) -} - -// mustParseURL parses s to an URL and panics on error. -func mustParseURL(s string) *fasthttp.URI { - u := fasthttp.AcquireURI() - err := u.Parse(nil, utils.UnsafeBytes(s)) - - if err != nil || utils.UnsafeString(u.Scheme()) == "" || utils.UnsafeString(u.Hash()) == "" { - panic(fmt.Sprintf("Unable to parse URL %s.", s)) - } - return u -} - -// jarTest encapsulates the following actions on a jar: -// 1. Perform SetCookies with fromURL and the cookies from setCookies. -// (Done at time tNow + 0 ms.) -// 2. Check that the entries in the jar matches content. -// (Done at time tNow + 1001 ms.) -// 3. For each query in tests: Check that Cookies with toURL yields the -// cookies in want. -// (Query n done at tNow + (n+2)*1001 ms.) -type jarTest struct { - description string // The description of what this test is supposed to test - fromURL string // The full URL of the request from which Set-Cookie headers where received - setCookies []string // All the cookies received from fromURL - content string // The whole (non-expired) content of the jar - queries []query // Queries to test the Jar.Cookies method -} - -// query contains one test of the cookies returned from Jar.Cookies. -type query struct { - toURL string // the URL in the Cookies call - want string // the expected list of cookies (order matters) -} - -// run runs the jarTest. -func (test jarTest) run(t *testing.T, jar *jar) { - // now := tNow - - // // Populate jar with cookies. - // setCookies := make([]*fasthttp.Cookie, len(test.setCookies)) - // for i, cs := range test.setCookies { - // resp := fasthttp.AcquireResponse() - // cookies := (&http.Response{Header: http.Header{"Set-Cookie": {cs}}}).Cookies() - // if len(cookies) != 1 { - // panic(fmt.Sprintf("Wrong cookie line %q: %#v", cs, cookies)) - // } - // setCookies[i] = cookies[0] - // } - // jar.setCookies(mustParseURL(test.fromURL), setCookies, now) - // now = now.Add(1001 * time.Millisecond) - - // // Serialize non-expired entries in the form "name1=val1 name2=val2". - // var cs []string - // for _, submap := range jar.entries { - // for _, cookie := range submap { - // if !cookie.Expires.After(now) { - // continue - // } - // cs = append(cs, cookie.Name+"="+cookie.Value) - // } - // } - // sort.Strings(cs) - // got := strings.Join(cs, " ") - - // // Make sure jar content matches our expectations. - // if got != test.content { - // t.Errorf("Test %q Content\ngot %q\nwant %q", - // test.description, got, test.content) - // } - - // // Test different calls to Cookies. - // for i, query := range test.queries { - // now = now.Add(1001 * time.Millisecond) - // var s []string - // for _, c := range jar.cookies(mustParseURL(query.toURL), now) { - // s = append(s, c.Name+"="+c.Value) - // } - // if got := strings.Join(s, " "); got != query.want { - // t.Errorf("Test %q #%d\ngot %q\nwant %q", test.description, i, got, query.want) - // } - // } -} From de4b39adba2cac0ef932641a86904bb4dfe6327d Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Mon, 18 Dec 2023 15:56:27 +0300 Subject: [PATCH 078/118] fix some lint errors. --- client/client.go | 41 +++--- client/client_test.go | 13 +- client/cookiejar.go | 14 +- client/helper_test.go | 10 +- client/hooks.go | 3 +- client/request_test.go | 293 ---------------------------------------- client/response_test.go | 43 +++++- 7 files changed, 86 insertions(+), 331 deletions(-) diff --git a/client/client.go b/client/client.go index d3e30848ef..5ac8fed8d6 100644 --- a/client/client.go +++ b/client/client.go @@ -6,16 +6,18 @@ import ( "crypto/x509" "encoding/json" "encoding/xml" - "github.com/gofiber/fiber/v3/log" - "github.com/gofiber/utils/v2" + "fmt" "io" - "net/url" + urlPkg "net/url" "os" "path/filepath" "sort" "sync" "time" + "github.com/gofiber/fiber/v3/log" + "github.com/gofiber/utils/v2" + "github.com/valyala/fasthttp" ) @@ -41,13 +43,13 @@ var _ (Logger) = (*disableLogger)(nil) // All logs are turned off by default. type disableLogger struct{} -func (*disableLogger) Errorf(format string, args ...any) {} +func (*disableLogger) Errorf(_ string, _ ...any) {} -func (*disableLogger) Warnf(format string, args ...any) {} +func (*disableLogger) Warnf(_ string, _ ...any) {} -func (*disableLogger) Infof(format string, args ...any) {} +func (*disableLogger) Infof(_ string, _ ...any) {} -func (*disableLogger) Debugf(format string, args ...any) {} +func (*disableLogger) Debugf(_ string, _ ...any) {} // The Client is used to create a Fiber Client with // client-level settings that apply to all requests @@ -58,7 +60,7 @@ func (*disableLogger) Debugf(format string, args ...any) {} type Client struct { mu sync.RWMutex - baseUrl string + baseURL string userAgent string referer string header *Header @@ -209,7 +211,9 @@ func (c *Client) SetRootCertificate(path string) *Client { if err != nil { return c } - defer func() { _ = file.Close() }() + defer func() { + _ = file.Close() //nolint:errcheck // It is fine to ignore the error here + }() pem, err := io.ReadAll(file) if err != nil { return c @@ -238,13 +242,13 @@ func (c *Client) SetRootCertificateFromString(pem string) *Client { // SetProxyURL sets proxy url in client. It will apply via core to hostclient. func (c *Client) SetProxyURL(proxyURL string) *Client { - pUrl, err := url.Parse(proxyURL) + pURL, err := urlPkg.Parse(proxyURL) if err != nil { log.Errorf("%v", err) return c } - c.proxyURL = pUrl.String() + c.proxyURL = pURL.String() return c } @@ -264,12 +268,12 @@ func (c *Client) SetRetryConfig(config *RetryConfig) *Client { // BaseURL returns baseurl in Client instance. func (c *Client) BaseURL() string { - return c.baseUrl + return c.baseURL } // Set baseUrl which is prefix of real url. func (c *Client) SetBaseURL(url string) *Client { - c.baseUrl = url + c.baseURL = url return c } @@ -282,7 +286,7 @@ func (c *Client) Header(key string) []string { // AddHeader method adds a single header field and its value in the client instance. // These headers will be applied to all requests raised from this client instance. -// Also it can be overridden at request level header options. +// Also, it can be overridden at request level header options. func (c *Client) AddHeader(key, val string) *Client { c.header.Add(key, val) return c @@ -553,7 +557,7 @@ func (c *Client) Patch(url string, cfg ...Config) (*Response, error) { // Reset clear Client object func (c *Client) Reset() { - c.baseUrl = "" + c.baseURL = "" c.timeout = 0 c.userAgent = "" c.referer = "" @@ -691,7 +695,12 @@ func init() { // The returned Client object may be returned to the pool with ReleaseClient when no longer needed. // This allows reducing GC load. func AcquireClient() *Client { - return clientPool.Get().(*Client) + client, ok := clientPool.Get().(*Client) + if !ok { + panic(fmt.Errorf("failed to type-assert to *Client")) + } + + return client } // ReleaseClient returns the object acquired via AcquireClient to the pool. diff --git a/client/client_test.go b/client/client_test.go index 331e1cfeae..9eff93c8ea 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -5,8 +5,6 @@ import ( "context" "crypto/tls" "fmt" - "github.com/gofiber/fiber/v3/addon/retry" - "github.com/gofiber/fiber/v3/log" "io" "net" "os" @@ -14,6 +12,9 @@ import ( "testing" "time" + "github.com/gofiber/fiber/v3/addon/retry" + "github.com/gofiber/fiber/v3/log" + "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/tlstest" "github.com/gofiber/utils/v2" @@ -497,8 +498,8 @@ func Test_Client_Header_With_Server(t *testing.T) { handler := func(c fiber.Ctx) error { c.Request().Header.VisitAll(func(key, value []byte) { if k := string(key); k == "K1" || k == "K2" { - _, _ = c.Write(key) - _, _ = c.Write(value) + _, _ = c.Write(key) //nolint:errcheck // It is fine to ignore the error here + _, _ = c.Write(value) //nolint:errcheck // It is fine to ignore the error here } }) return nil @@ -829,8 +830,8 @@ func Test_Client_QueryParam(t *testing.T) { func Test_Client_QueryParam_With_Server(t *testing.T) { handler := func(c fiber.Ctx) error { - c.WriteString(c.Query("k1")) - c.WriteString(c.Query("k2")) + _, _ = c.WriteString(c.Query("k1")) //nolint:errcheck // It is fine to ignore the error here + _, _ = c.WriteString(c.Query("k2")) //nolint:errcheck // It is fine to ignore the error here return nil } diff --git a/client/cookiejar.go b/client/cookiejar.go index 3db88d8e46..4c5106046f 100644 --- a/client/cookiejar.go +++ b/client/cookiejar.go @@ -42,11 +42,11 @@ func (cj *CookieJar) Get(uri *fasthttp.URI) []*fasthttp.Cookie { return nil } - return cj.get(uri.Host(), uri.Path()) + return cj.getByHostAndPath(uri.Host(), uri.Path()) } // get returns the cookies stored from a specific host and path. -func (cj *CookieJar) get(host, path []byte) []*fasthttp.Cookie { +func (cj *CookieJar) getByHostAndPath(host, path []byte) []*fasthttp.Cookie { if cj.hostCookies == nil { return nil } @@ -108,7 +108,7 @@ func (cj *CookieJar) Set(uri *fasthttp.URI, cookies ...*fasthttp.Cookie) { return } - cj.set(uri.Host(), cookies...) + cj.SetByHost(uri.Host(), cookies...) } // SetByHost sets cookies for a specific host. @@ -116,10 +116,6 @@ func (cj *CookieJar) Set(uri *fasthttp.URI, cookies ...*fasthttp.Cookie) { // // CookieJar keeps a copy of the cookies, so the parsed cookies can be released safely. func (cj *CookieJar) SetByHost(host []byte, cookies ...*fasthttp.Cookie) { - cj.set(host, cookies...) -} - -func (cj *CookieJar) set(host []byte, cookies ...*fasthttp.Cookie) { hostStr := utils.UnsafeString(host) cj.mu.Lock() @@ -168,14 +164,14 @@ func (cj *CookieJar) setKeyValue(host string, key, value []byte) { c.SetKeyBytes(key) c.SetValueBytes(value) - cj.set(utils.UnsafeBytes(host), c) + cj.SetByHost(utils.UnsafeBytes(host), c) } // dumpCookiesToReq dumps the stored cookies to the request. func (cj *CookieJar) dumpCookiesToReq(req *fasthttp.Request) { uri := req.URI() - cookies := cj.get(uri.Host(), uri.Path()) + cookies := cj.getByHostAndPath(uri.Host(), uri.Path()) for _, cookie := range cookies { req.Header.SetCookieBytesKV(cookie.Key(), cookie.Value()) } diff --git a/client/helper_test.go b/client/helper_test.go index c27eb46105..8636796e21 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -9,11 +9,17 @@ import ( "github.com/valyala/fasthttp/fasthttputil" ) -func createHelperServer(t testing.TB) (*fiber.App, func(addr string) (net.Conn, error), func()) { +func createHelperServer(t testing.TB, config ...fiber.Config) (*fiber.App, func(addr string) (net.Conn, error), func()) { t.Helper() ln := fasthttputil.NewInmemoryListener() - app := fiber.New() + + var cfg fiber.Config + if len(config) > 0 { + cfg = config[0] + } + + app := fiber.New(cfg) return app, func(addr string) (net.Conn, error) { return ln.Dial() diff --git a/client/hooks.go b/client/hooks.go index 73ba0a79e5..5c7089b2f1 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -66,7 +66,7 @@ func parserRequestURL(c *Client, req *Request) error { // whether the URL starts with the protocol uri := splitUrl[0] if !protocolCheck.MatchString(uri) { - uri = c.baseUrl + uri + uri = c.baseURL + uri if !protocolCheck.MatchString(uri) { return ErrURLFormat } @@ -271,6 +271,7 @@ func parserRequestBody(c *Client, req *Request) error { return ErrBodyType } } + return nil } diff --git a/client/request_test.go b/client/request_test.go index 78d3ef6c35..de02bad41b 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -1211,299 +1211,6 @@ func Test_Request_MaxRedirects(t *testing.T) { }) } -// // readErrorConn is a struct for testing retryIf -// type readErrorConn struct { -// net.Conn -// } - -// func (r *readErrorConn) Read(p []byte) (int, error) { -// return 0, fmt.Errorf("error") -// } - -// func (r *readErrorConn) Write(p []byte) (int, error) { -// return len(p), nil -// } - -// func (r *readErrorConn) Close() error { -// return nil -// } - -// func (r *readErrorConn) LocalAddr() net.Addr { -// return nil -// } - -// func (r *readErrorConn) RemoteAddr() net.Addr { -// return nil -// } -// func Test_Client_Agent_RetryIf(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// a := Post("http://example.com"). -// RetryIf(func(req *Request) bool { -// return true -// }) -// dialsCount := 0 -// a.HostClient.Dial = func(addr string) (net.Conn, error) { -// dialsCount++ -// switch dialsCount { -// case 1: -// return &readErrorConn{}, nil -// case 2: -// return &readErrorConn{}, nil -// case 3: -// return &readErrorConn{}, nil -// case 4: -// return ln.Dial() -// default: -// t.Fatalf("unexpected number of dials: %d", dialsCount) -// } -// panic("unreachable") -// } - -// _, _, errs := a.String() -// utils.AssertEqual(t, dialsCount, 4) -// utils.AssertEqual(t, 0, len(errs)) -// } - -// func Test_Client_Debug(t *testing.T) { -// handler := func(c fiber.Ctx) error { -// return c.SendString("debug") -// } - -// var output bytes.Buffer - -// wrapAgent := func(a *Agent) { -// a.Debug(&output) -// } - -// testAgent(t, handler, wrapAgent, "debug", 1) - -// str := output.String() - -// utils.AssertEqual(t, true, strings.Contains(str, "Connected to example.com(pipe)")) -// utils.AssertEqual(t, true, strings.Contains(str, "GET / HTTP/1.1")) -// utils.AssertEqual(t, true, strings.Contains(str, "User-Agent: fiber")) -// utils.AssertEqual(t, true, strings.Contains(str, "Host: example.com\r\n\r\n")) -// utils.AssertEqual(t, true, strings.Contains(str, "HTTP/1.1 200 OK")) -// utils.AssertEqual(t, true, strings.Contains(str, "Content-Type: text/plain; charset=utf-8\r\nContent-Length: 5\r\n\r\ndebug")) -// } - -// func Test_Client_Agent_InsecureSkipVerify(t *testing.T) { -// t.Parallel() - -// cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key") -// utils.AssertEqual(t, nil, err) - -// serverTLSConf := &tls.Config{ -// Certificates: []tls.Certificate{cer}, -// } - -// ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") -// utils.AssertEqual(t, nil, err) - -// ln = tls.NewListener(ln, serverTLSConf) - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// return c.SendString("ignore tls") -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// code, body, errs := Get("https://" + ln.Addr().String()). -// InsecureSkipVerify(). -// InsecureSkipVerify(). -// String() - -// utils.AssertEqual(t, 0, len(errs)) -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "ignore tls", body) -// } - -// func Test_Client_Agent_TLS(t *testing.T) { -// t.Parallel() - -// serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() -// utils.AssertEqual(t, nil, err) - -// ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") -// utils.AssertEqual(t, nil, err) - -// ln = tls.NewListener(ln, serverTLSConf) - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// return c.SendString("tls") -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// code, body, errs := Get("https://" + ln.Addr().String()). -// TLSConfig(clientTLSConf). -// String() - -// utils.AssertEqual(t, 0, len(errs)) -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "tls", body) -// } - -// func Test_Client_Agent_MaxRedirectsCount(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// if c.Request().URI().QueryArgs().Has("foo") { -// return c.Redirect("/foo") -// } -// return c.Redirect("/") -// }) -// app.Get("/foo", func(c fiber.Ctx) error { -// return c.SendString("redirect") -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// t.Run("success", func(t *testing.T) { -// a := Get("http://example.com?foo"). -// MaxRedirectsCount(1) - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.String() - -// utils.AssertEqual(t, 200, code) -// utils.AssertEqual(t, "redirect", body) -// utils.AssertEqual(t, 0, len(errs)) -// }) - -// t.Run("error", func(t *testing.T) { -// a := Get("http://example.com"). -// MaxRedirectsCount(1) - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// _, body, errs := a.String() - -// utils.AssertEqual(t, "", body) -// utils.AssertEqual(t, 1, len(errs)) -// utils.AssertEqual(t, "too many redirects detected when doing the request", errs[0].Error()) -// }) -// } - -// func Test_Client_Agent_Struct(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// return c.JSON(data{true}) -// }) - -// app.Get("/error", func(c fiber.Ctx) error { -// return c.SendString(`{"success"`) -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// t.Run("success", func(t *testing.T) { -// t.Parallel() - -// a := Get("http://example.com") - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// var d data - -// code, body, errs := a.Struct(&d) - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, `{"success":true}`, string(body)) -// utils.AssertEqual(t, 0, len(errs)) -// utils.AssertEqual(t, true, d.Success) -// }) - -// t.Run("pre error", func(t *testing.T) { -// t.Parallel() -// a := Get("http://example.com") - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } -// a.errs = append(a.errs, errors.New("pre errors")) - -// var d data -// _, body, errs := a.Struct(&d) - -// utils.AssertEqual(t, "", string(body)) -// utils.AssertEqual(t, 1, len(errs)) -// utils.AssertEqual(t, "pre errors", errs[0].Error()) -// utils.AssertEqual(t, false, d.Success) -// }) - -// t.Run("error", func(t *testing.T) { -// a := Get("http://example.com/error") - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// var d data - -// code, body, errs := a.JSONDecoder(json.Unmarshal).Struct(&d) - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, `{"success"`, string(body)) -// utils.AssertEqual(t, 1, len(errs)) -// utils.AssertEqual(t, "unexpected end of JSON input", errs[0].Error()) -// }) -// } - -// func Test_Client_Agent_Parse(t *testing.T) { -// t.Parallel() - -// a := Get("https://example.com:10443") - -// utils.AssertEqual(t, nil, a.Parse()) -// } - -// func Test_AddMissingPort_TLS(t *testing.T) { -// addr := addMissingPort("example.com", true) -// utils.AssertEqual(t, "example.com:443", addr) -// } - -// type data struct { -// Success bool `json:"success" xml:"success"` -// } - -// type errorMultipartWriter struct { -// count int -// } - -// func (e *errorMultipartWriter) Boundary() string { return "myBoundary" } -// func (e *errorMultipartWriter) SetBoundary(_ string) error { return nil } -// func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) { -// if e.count == 0 { -// e.count++ -// return nil, errors.New("CreateFormFile error") -// } -// return errorWriter{}, nil -// } -// func (e *errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") } -// func (e *errorMultipartWriter) Close() error { return errors.New("Close error") } - -// type errorWriter struct{} - -// func (errorWriter) Write(_ []byte) (int, error) { return 0, errors.New("Write error") } - func Test_SetValWithStruct(t *testing.T) { t.Parallel() diff --git a/client/response_test.go b/client/response_test.go index e67f6ac334..c32dc8dece 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -2,8 +2,11 @@ package client import ( "bytes" + "crypto/tls" "encoding/xml" + "github.com/gofiber/fiber/v3/internal/tlstest" "io" + "net" "os" "testing" @@ -104,9 +107,38 @@ func Test_Response_Protocol(t *testing.T) { resp.Close() }) - // TODO: add https test after support https t.Run("https", func(t *testing.T) { t.Parallel() + + serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() + require.Nil(t, err) + + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + require.Nil(t, err) + + ln = tls.NewListener(ln, serverTLSConf) + + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(c.Scheme()) + }) + + go func() { + require.Nil(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + + client := AcquireClient() + resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) + + require.Nil(t, err) + require.Equal(t, clientTLSConf, client.TLSConfig()) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "https", resp.String()) + require.Equal(t, "HTTP/1.1", resp.Protocol()) + + resp.Close() }) } @@ -248,13 +280,16 @@ func Test_Response_Save(t *testing.T) { return } - os.RemoveAll("./test") + err := os.RemoveAll("./test") + require.NoError(t, err) }() file, err := os.Open("./test/tmp.json") - defer file.Close() - require.NoError(t, err) + defer func(file *os.File) { + err := file.Close() + require.NoError(t, err) + }(file) data, err := io.ReadAll(file) require.NoError(t, err) From fa2f85839bc615e7a1c90a6c9c3fdc6de54335a8 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Mon, 18 Dec 2023 16:18:08 +0300 Subject: [PATCH 079/118] add benchmark for SetValWithStruct --- client/request_test.go | 203 +++++++++++++++++++++++++++++++++++------ 1 file changed, 175 insertions(+), 28 deletions(-) diff --git a/client/request_test.go b/client/request_test.go index de02bad41b..442e330fab 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -851,34 +851,6 @@ func Test_Request_Referer_With_Server(t *testing.T) { testRequest(t, handler, wrapAgent, "http://referer.com") } -// func Test_Client_Agent_Host(t *testing.T) { -// t.Parallel() - -// ln := fasthttputil.NewInmemoryListener() - -// app := fiber.New(fiber.Config{DisableStartupMessage: true}) - -// app.Get("/", func(c fiber.Ctx) error { -// return c.SendString(c.Hostname()) -// }) - -// go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() - -// a := Get("http://1.1.1.1:8080"). -// Host("example.com"). -// HostBytes([]byte("example.com")) - -// utils.AssertEqual(t, "1.1.1.1:8080", a.HostClient.Addr) - -// a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - -// code, body, errs := a.String() - -// utils.AssertEqual(t, fiber.StatusOK, code) -// utils.AssertEqual(t, "example.com", body) -// utils.AssertEqual(t, 0, len(errs)) -// } - func Test_Request_QueryString_With_Server(t *testing.T) { handler := func(c fiber.Ctx) error { return c.Send(c.Request().URI().QueryString()) @@ -1364,3 +1336,178 @@ func Test_SetValWithStruct(t *testing.T) { require.Equal(t, 0, p.Len()) }) } + +func Benchmark_SetValWithStruct(b *testing.B) { + // test SetValWithStruct vai QueryParam struct. + type args struct { + unexport int + TInt int + TString string + TFloat float64 + TBool bool + TSlice []string + TIntSlice []int `param:"int_slice"` + } + + b.Run("the struct should be applied", func(b *testing.B) { + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + + b.ReportAllocs() + b.StartTimer() + + for i := 0; i < b.N; i++ { + SetValWithStruct(p, "param", args{ + unexport: 5, + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: false, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + } + + require.Equal(b, "", string(p.Peek("unexport"))) + require.Equal(b, []byte("5"), p.Peek("TInt")) + require.Equal(b, []byte("string"), p.Peek("TString")) + require.Equal(b, []byte("3.1"), p.Peek("TFloat")) + require.Equal(b, "", string(p.Peek("TBool"))) + require.True(b, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "foo" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "bar" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "1" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "2" { + return true + } + } + return false + }()) + + }) + + b.Run("the pointer of a struct should be applied", func(b *testing.B) { + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + + b.ReportAllocs() + b.StartTimer() + + for i := 0; i < b.N; i++ { + SetValWithStruct(p, "param", &args{ + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: true, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + } + + require.Equal(b, []byte("5"), p.Peek("TInt")) + require.Equal(b, []byte("string"), p.Peek("TString")) + require.Equal(b, []byte("3.1"), p.Peek("TFloat")) + require.Equal(b, "true", string(p.Peek("TBool"))) + require.True(b, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "foo" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "bar" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "1" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "2" { + return true + } + } + return false + }()) + + }) + + b.Run("the zero val should be ignore", func(b *testing.B) { + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + + b.ReportAllocs() + b.StartTimer() + + for i := 0; i < b.N; i++ { + SetValWithStruct(p, "param", &args{ + TInt: 0, + TString: "", + TFloat: 0.0, + }) + } + + require.Equal(b, "", string(p.Peek("TInt"))) + require.Equal(b, "", string(p.Peek("TString"))) + require.Equal(b, "", string(p.Peek("TFloat"))) + require.Equal(b, 0, len(p.PeekMulti("TSlice"))) + require.Equal(b, 0, len(p.PeekMulti("int_slice"))) + }) + + b.Run("error type should ignore", func(b *testing.B) { + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + + b.ReportAllocs() + b.StartTimer() + + for i := 0; i < b.N; i++ { + SetValWithStruct(p, "param", 5) + } + + require.Equal(b, 0, p.Len()) + }) +} From 2a21c31c3d64607a9e94538587c3df8e516ccac0 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sun, 21 Jan 2024 19:01:42 +0300 Subject: [PATCH 080/118] optimize --- client/client.go | 14 ++++++++++++++ client/client_test.go | 16 +++++++++++++++- client/core.go | 31 ++++++++++++++++++------------- client/core_test.go | 20 ++++++++++---------- client/response_test.go | 3 ++- 5 files changed, 59 insertions(+), 25 deletions(-) diff --git a/client/client.go b/client/client.go index 5ac8fed8d6..3d4d9125dd 100644 --- a/client/client.go +++ b/client/client.go @@ -60,6 +60,8 @@ func (*disableLogger) Debugf(_ string, _ ...any) {} type Client struct { mu sync.RWMutex + host *fasthttp.HostClient + baseURL string userAgent string referer string @@ -135,6 +137,17 @@ func (c *Client) AddResponseHook(h ...ResponseHook) *Client { return c } +// HostClient returns host client in client. +func (c *Client) HostClient() *fasthttp.HostClient { + return c.host +} + +// SetHostClient sets host client in client. +func (c *Client) SetHostClient(host *fasthttp.HostClient) *Client { + c.host = host + return c +} + // JSONMarshal returns json marshal function in Core. func (c *Client) JSONMarshal() utils.JSONMarshal { return c.jsonMarshal @@ -663,6 +676,7 @@ var ( clientPool = &sync.Pool{ New: func() any { return &Client{ + host: &fasthttp.HostClient{}, header: &Header{ RequestHeader: &fasthttp.RequestHeader{}, }, diff --git a/client/client_test.go b/client/client_test.go index 9eff93c8ea..a6ff0fdf2b 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -14,6 +14,7 @@ import ( "github.com/gofiber/fiber/v3/addon/retry" "github.com/gofiber/fiber/v3/log" + "github.com/valyala/fasthttp" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/tlstest" @@ -109,6 +110,20 @@ func Test_Client_SetBaseURL(t *testing.T) { require.Equal(t, "http://example.com", client.BaseURL()) } +func Test_Client_SetHostClient(t *testing.T) { + t.Parallel() + + hostClient := &fasthttp.HostClient{} + hostClient.Name = "test" + + client := AcquireClient() + defer ReleaseClient(client) + + client.SetHostClient(hostClient) + + require.Equal(t, "test", client.HostClient().Name) +} + func Test_Client_Invalid_URL(t *testing.T) { t.Parallel() @@ -1237,7 +1252,6 @@ func Benchmark_Client_Request(b *testing.B) { b.ResetTimer() b.ReportAllocs() - b.ReportAllocs() for i := 0; i < b.N; i++ { resp, _ := Get("http://example.com", Config{Dial: dial}) diff --git a/client/core.go b/client/core.go index 851e30bb7d..6c989aabdf 100644 --- a/client/core.go +++ b/client/core.go @@ -52,8 +52,6 @@ func addMissingPort(addr string, isTLS bool) string { // `core` stores middleware and plugin definitions, // and defines the execution process type core struct { - host *fasthttp.HostClient - client *Client req *Request ctx context.Context @@ -92,21 +90,23 @@ func (c *core) execFunc() (*Response, error) { cfg := c.getRetryConfig() go func() { + c.client.mu.Lock() + var err error respv := fasthttp.AcquireResponse() if cfg != nil { err = retry.NewExponentialBackoff(*cfg).Retry(func() error { if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { - return c.host.DoRedirects(reqv, respv, c.req.maxRedirects) + return c.client.host.DoRedirects(reqv, respv, c.req.maxRedirects) } - return c.host.Do(reqv, respv) + return c.client.host.Do(reqv, respv) }) } else { if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { - err = c.host.DoRedirects(reqv, respv, c.req.maxRedirects) + err = c.client.host.DoRedirects(reqv, respv, c.req.maxRedirects) } else { - err = c.host.Do(reqv, respv) + err = c.client.host.Do(reqv, respv) } } defer func() { @@ -122,6 +122,7 @@ func (c *core) execFunc() (*Response, error) { respv.CopyTo(resp.RawResponse) errCh <- nil } + c.client.mu.Unlock() }() select { @@ -200,12 +201,16 @@ func (c *core) timeout() context.CancelFunc { // dial set dial in host. func (c *core) dial() { - c.host.Dial = c.req.dial + c.client.mu.Lock() + c.client.host.Dial = c.req.dial + c.client.mu.Unlock() } // tls sets tls config. func (c *core) tls() { - c.host.TLSConfig = c.client.tlsConfig.Clone() + c.client.mu.Lock() + c.client.host.TLSConfig = c.client.tlsConfig.Clone() + c.client.mu.Unlock() } // proxy set proxy in host. @@ -224,8 +229,10 @@ func (c *core) proxy() error { return ErrNotSupportSchema } - c.host.Addr = addMissingPort(string(rawUri.Host()), isTLS) - c.host.IsTLS = isTLS + c.client.mu.Lock() + c.client.host.Addr = addMissingPort(string(rawUri.Host()), isTLS) + c.client.host.IsTLS = isTLS + c.client.mu.Unlock() return nil } @@ -298,9 +305,7 @@ func releaseErrChan(ch chan error) { // newCore returns an empty core object. func newCore() (c *core) { - c = &core{ - host: &fasthttp.HostClient{}, - } + c = &core{} return } diff --git a/client/core_test.go b/client/core_test.go index e862ced508..b4b2d10a86 100644 --- a/client/core_test.go +++ b/client/core_test.go @@ -75,14 +75,15 @@ func Test_Exec_Func(t *testing.T) { t.Run("normal request", func(t *testing.T) { core, client, req := newCore(), AcquireClient(), AcquireRequest() - core.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - req.RawRequest.SetRequestURI("http://example.com/normal") - core.ctx = context.Background() core.client = client core.req = req + core.client.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + req.RawRequest.SetRequestURI("http://example.com/normal") + resp, err := core.execFunc() + fmt.Print(string(resp.Body())) require.NoError(t, err) require.Equal(t, 200, resp.RawResponse.StatusCode()) require.Equal(t, "example.com", string(resp.RawResponse.Body())) @@ -90,13 +91,13 @@ func Test_Exec_Func(t *testing.T) { t.Run("the request return an error", func(t *testing.T) { core, client, req := newCore(), AcquireClient(), AcquireRequest() - core.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - req.RawRequest.SetRequestURI("http://example.com/return-error") - core.ctx = context.Background() core.client = client core.req = req + core.client.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + req.RawRequest.SetRequestURI("http://example.com/return-error") + resp, err := core.execFunc() require.NoError(t, err) @@ -106,10 +107,6 @@ func Test_Exec_Func(t *testing.T) { t.Run("the request timeout", func(t *testing.T) { core, client, req := newCore(), AcquireClient(), AcquireRequest() - - core.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } - req.RawRequest.SetRequestURI("http://example.com/hang-up") - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() @@ -117,6 +114,9 @@ func Test_Exec_Func(t *testing.T) { core.client = client core.req = req + core.client.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + req.RawRequest.SetRequestURI("http://example.com/hang-up") + _, err := core.execFunc() require.Equal(t, ErrTimeoutOrCancel, err) diff --git a/client/response_test.go b/client/response_test.go index c32dc8dece..6d59f093a9 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -4,12 +4,13 @@ import ( "bytes" "crypto/tls" "encoding/xml" - "github.com/gofiber/fiber/v3/internal/tlstest" "io" "net" "os" "testing" + "github.com/gofiber/fiber/v3/internal/tlstest" + "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" ) From 1b602da3b893791faea0f97014c7a84b822e7da4 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sun, 21 Jan 2024 19:56:52 +0300 Subject: [PATCH 081/118] update --- client/client.go | 31 ++++++++++++++++++++++--------- client/client_test.go | 12 ++++++++++++ client/request.go | 10 ++-------- 3 files changed, 36 insertions(+), 17 deletions(-) diff --git a/client/client.go b/client/client.go index 3d4d9125dd..ca80743b39 100644 --- a/client/client.go +++ b/client/client.go @@ -6,12 +6,12 @@ import ( "crypto/x509" "encoding/json" "encoding/xml" + "errors" "fmt" "io" urlPkg "net/url" "os" "path/filepath" - "sort" "sync" "time" @@ -21,6 +21,9 @@ import ( "github.com/valyala/fasthttp" ) +var ErrInvalidProxyURL = errors.New("invalid proxy url scheme") +var ErrFailedToAppendCert = errors.New("failed to append certificate") + // Define the logger interface so that users can // use different log implements to output logs. type Logger interface { @@ -222,21 +225,25 @@ func (c *Client) SetRootCertificate(path string) *Client { cleanPath := filepath.Clean(path) file, err := os.Open(cleanPath) if err != nil { - return c + log.Errorf("client: %v", err) } defer func() { _ = file.Close() //nolint:errcheck // It is fine to ignore the error here }() + pem, err := io.ReadAll(file) if err != nil { - return c + log.Errorf("client: %v", err) } config := c.TLSConfig() if config.RootCAs == nil { config.RootCAs = x509.NewCertPool() } - config.RootCAs.AppendCertsFromPEM(pem) + + if !config.RootCAs.AppendCertsFromPEM(pem) { + log.Errorf("client: %v", ErrFailedToAppendCert) + } return c } @@ -248,7 +255,10 @@ func (c *Client) SetRootCertificateFromString(pem string) *Client { if config.RootCAs == nil { config.RootCAs = x509.NewCertPool() } - config.RootCAs.AppendCertsFromPEM([]byte(pem)) + + if !config.RootCAs.AppendCertsFromPEM([]byte(pem)) { + log.Errorf("client: %v", ErrFailedToAppendCert) + } return c } @@ -257,10 +267,15 @@ func (c *Client) SetRootCertificateFromString(pem string) *Client { func (c *Client) SetProxyURL(proxyURL string) *Client { pURL, err := urlPkg.Parse(proxyURL) if err != nil { - log.Errorf("%v", err) + log.Errorf("client: %v", err) + return c + } + if pURL.Scheme != "http" && pURL.Scheme != "https" { + log.Errorf("client: %v", ErrInvalidProxyURL) return c } + c.proxyURL = pURL.String() return c @@ -330,15 +345,13 @@ func (c *Client) SetHeaders(h map[string]string) *Client { } // Param method returns params value via key, -// this method will visit all field in the query param, -// then sort them. +// this method will visit all field in the query param. func (c *Client) Param(key string) []string { res := []string{} tmp := c.params.PeekMulti(key) for _, v := range tmp { res = append(res, utils.UnsafeString(v)) } - sort.Strings(res) return res } diff --git a/client/client_test.go b/client/client_test.go index a6ff0fdf2b..9b623f31fc 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1215,6 +1215,18 @@ func Test_Client_SetProxyURL(t *testing.T) { require.NoError(t, err) }) + t.Run("wrong url scheme", func(t *testing.T) { + var buf bytes.Buffer + log.SetOutput(&buf) + + client := AcquireClient() + client.SetProxyURL("x://test.com") + _, err := client.Get("http://localhost:3000", Config{Dial: dial}) + + require.Contains(t, buf.String(), "client: invalid proxy url scheme") + require.NoError(t, err) + }) + t.Run("error", func(t *testing.T) { client := AcquireClient() client.SetProxyURL("htgdftp://test.com") diff --git a/client/request.go b/client/request.go index 5e55f1dccb..574d435426 100644 --- a/client/request.go +++ b/client/request.go @@ -6,7 +6,6 @@ import ( "io" "path/filepath" "reflect" - "sort" "strconv" "sync" "time" @@ -159,15 +158,13 @@ func (r *Request) SetHeaders(h map[string]string) *Request { } // Param method returns params value via key, -// this method will visit all field in the query param, -// then sort them. +// this method will visit all field in the query param. func (r *Request) Param(key string) []string { res := []string{} tmp := r.params.PeekMulti(key) for _, v := range tmp { res = append(res, utils.UnsafeString(v)) } - sort.Strings(res) return res } @@ -363,15 +360,13 @@ func (r *Request) resetBody(t bodyType) { } // FormData method returns form data value via key, -// this method will visit all field in the form data, -// then sort them. +// this method will visit all field in the form data. func (r *Request) FormData(key string) []string { res := []string{} tmp := r.formData.PeekMulti(key) for _, v := range tmp { res = append(res, utils.UnsafeString(v)) } - sort.Strings(res) return res } @@ -585,7 +580,6 @@ func (h *Header) PeekMultiple(key string) []string { res = append(res, utils.UnsafeString(value)) } }) - sort.Strings(res) return res } From 568152ff24afc81f9587c69a6a59427202951cef Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Mon, 22 Jan 2024 12:26:29 +0300 Subject: [PATCH 082/118] fix proxy middleware --- middleware/proxy/proxy.go | 8 --- middleware/proxy/proxy_test.go | 105 +++++++++++++++++---------------- 2 files changed, 53 insertions(+), 60 deletions(-) diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index 284b67c8f5..167a0e6f31 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -2,7 +2,6 @@ package proxy import ( "bytes" - "crypto/tls" "net/url" "strings" "sync" @@ -105,13 +104,6 @@ var client = &fasthttp.Client{ var lock sync.RWMutex -// WithTLSConfig update http client with a user specified tls.config -// This function should be called before Do and Forward. -// Deprecated: use WithClient instead. -func WithTLSConfig(tlsConfig *tls.Config) { - client.TLSConfig = tlsConfig -} - // WithClient sets the global proxy client. // This function should be called before Do and Forward. func WithClient(cli *fasthttp.Client) { diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 34deaf4b52..3594638931 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -3,6 +3,9 @@ package proxy import ( "crypto/tls" "errors" + "github.com/gofiber/fiber/v3" + fiberClient "github.com/gofiber/fiber/v3/client" + "github.com/stretchr/testify/require" "io" "net" "net/http/httptest" @@ -10,10 +13,7 @@ import ( "testing" "time" - "github.com/gofiber/fiber/v3" - fiberClient "github.com/gofiber/fiber/v3/client" "github.com/gofiber/fiber/v3/internal/tlstest" - "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) @@ -26,8 +26,6 @@ func createProxyTestServer(t *testing.T, handler fiber.Handler) (*fiber.App, str ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) - addr := ln.Addr().String() - go func() { require.NoError(t, target.Listener(ln, fiber.ListenConfig{ DisableStartupMessage: true, @@ -35,6 +33,7 @@ func createProxyTestServer(t *testing.T, handler fiber.Handler) (*fiber.App, str }() time.Sleep(2 * time.Second) + addr := ln.Addr().String() return target, addr } @@ -105,8 +104,8 @@ func Test_Proxy(t *testing.T) { require.Equal(t, fiber.StatusTeapot, resp.StatusCode) } -// go test -run Test_Proxy_Balancer_WithTLSConfig -func Test_Proxy_Balancer_WithTLSConfig(t *testing.T) { +// go test -run Test_Proxy_Balancer_WithTlsConfig +func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) { t.Parallel() serverTLSConf, _, err := tlstest.GetTLSConfigs() @@ -119,7 +118,7 @@ func Test_Proxy_Balancer_WithTLSConfig(t *testing.T) { app := fiber.New() - app.Get("/tlsbalaner", func(c fiber.Ctx) error { + app.Get("/tlsbalancer", func(c fiber.Ctx) error { return c.SendString("tls balancer") }) @@ -138,21 +137,19 @@ func Test_Proxy_Balancer_WithTLSConfig(t *testing.T) { })) }() - time.Sleep(500 * time.Second) - - resp, err := fiberClient.AcquireClient(). - SetTLSConfig(clientTLSConf). - R(). - Get("https://" + addr + "/tlsbalaner") + client := fiberClient.AcquireClient() + defer fiberClient.ReleaseClient(client) + client.SetTLSConfig(clientTLSConf) + resp, err := client.Get("https://" + addr + "/tlsbalancer") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "tls balancer", string(resp.Body())) resp.Close() } -// go test -run Test_Proxy_Forward_WithTLSConfig_To_Http -func Test_Proxy_Forward_WithTLSConfig_To_Http(t *testing.T) { +// go test -run Test_Proxy_Forward_WithTlsConfig_To_Http +func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) { t.Parallel() _, targetAddr := createProxyTestServer(t, func(c fiber.Ctx) error { @@ -179,14 +176,12 @@ func Test_Proxy_Forward_WithTLSConfig_To_Http(t *testing.T) { })) }() - time.Sleep(500 * time.Second) - - resp, err := fiberClient.AcquireClient().SetTLSConfig(&tls.Config{ - InsecureSkipVerify: true, - }).R(). - SetTimeout(5 * time.Second). - Get("https://" + proxyAddr) + client := fiberClient.AcquireClient() + defer fiberClient.ReleaseClient(client) + client.SetTimeout(5 * time.Second) + client.TLSConfig().InsecureSkipVerify = true //nolint:gosec // We're in a test func, so this is fine + resp, err := client.Get("https://" + proxyAddr) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "hello from target", string(resp.Body())) @@ -214,8 +209,8 @@ func Test_Proxy_Forward(t *testing.T) { require.Equal(t, "forwarded", string(b)) } -// go test -run Test_Proxy_Forward_WithTLSConfig -func Test_Proxy_Forward_WithTLSConfig(t *testing.T) { +// go test -run Test_Proxy_Forward_WithClient_TLSConfig +func Test_Proxy_Forward_WithClient_TLSConfig(t *testing.T) { t.Parallel() serverTLSConf, _, err := tlstest.GetTLSConfigs() @@ -236,7 +231,9 @@ func Test_Proxy_Forward_WithTLSConfig(t *testing.T) { clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine // disable certificate verification - WithTLSConfig(clientTLSConf) + WithClient(&fasthttp.Client{ + TLSConfig: clientTLSConf, + }) app.Use(Forward("https://" + addr + "/tlsfwd")) go func() { @@ -245,13 +242,11 @@ func Test_Proxy_Forward_WithTLSConfig(t *testing.T) { })) }() - time.Sleep(500 * time.Second) - - resp, err := fiberClient.AcquireClient(). - SetTLSConfig(clientTLSConf). - R(). - Get("https://" + addr) + client := fiberClient.AcquireClient() + defer fiberClient.ReleaseClient(client) + client.SetTLSConfig(clientTLSConf) + resp, err := client.Get("https://" + addr) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "tls forward", string(resp.Body())) @@ -432,7 +427,7 @@ func Test_Proxy_Do_WithRedirect(t *testing.T) { return Do(c, "https://google.com") }) - resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), 1500) + resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) require.NoError(t, err1) body, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -448,7 +443,7 @@ func Test_Proxy_DoRedirects_RestoreOriginalURL(t *testing.T) { return DoRedirects(c, "http://google.com", 1) }) - resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), 1500) + resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) require.NoError(t, err1) _, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -464,7 +459,7 @@ func Test_Proxy_DoRedirects_TooManyRedirects(t *testing.T) { return DoRedirects(c, "http://google.com", 0) }) - resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), 1500) + resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) require.NoError(t, err1) body, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -487,7 +482,7 @@ func Test_Proxy_DoTimeout_RestoreOriginalURL(t *testing.T) { }) resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) - require.NoError(t, err1) + require.NoError(t, nil, err1) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "proxied", string(body)) @@ -603,12 +598,10 @@ func Test_Proxy_Forward_Global_Client(t *testing.T) { })) }() - time.Sleep(500 * time.Second) - - resp, err := fiberClient.AcquireClient(). - R(). - Get("http://" + addr) + client := fiberClient.AcquireClient() + defer fiberClient.ReleaseClient(client) + resp, err := client.Get("http://" + addr) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "test_global_client", string(resp.Body())) @@ -638,12 +631,10 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) { })) }() - time.Sleep(500 * time.Second) - - resp, err := fiberClient.AcquireClient(). - R(). - Get("http://" + addr) + client := fiberClient.AcquireClient() + defer fiberClient.ReleaseClient(client) + resp, err := client.Get("http://" + addr) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "test_local_client", string(resp.Body())) @@ -695,7 +686,7 @@ func Test_Proxy_Domain_Forward_Local(t *testing.T) { app1 := fiber.New() app1.Get("/test", func(c fiber.Ctx) error { - return c.SendString("test_local_client:" + fiber.Query[string](c, "query_test")) + return c.SendString("test_local_client:" + c.Query("query_test")) }) proxyAddr := ln.Addr().String() @@ -708,15 +699,25 @@ func Test_Proxy_Domain_Forward_Local(t *testing.T) { Dial: fasthttp.Dial, })) - go func() { require.NoError(t, app.Listener(ln)) }() - go func() { require.NoError(t, app1.Listener(ln1)) }() + go func() { + require.NoError(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + go func() { + require.NoError(t, app1.Listener(ln1, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() - resp, err := fiberClient.Get("http://" + localDomain + "/test?query_test=true") - defer resp.Close() + client := fiberClient.AcquireClient() + defer fiberClient.ReleaseClient(client) + resp, err := client.Get("http://" + localDomain + "/test?query_test=true") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "test_local_client:true", string(resp.Body())) + resp.Close() } // go test -run Test_Proxy_Balancer_Forward_Local @@ -738,5 +739,5 @@ func Test_Proxy_Balancer_Forward_Local(t *testing.T) { b, err := io.ReadAll(resp.Body) require.NoError(t, err) - require.Equal(t, "forwarded", string(b)) + require.Equal(t, string(b), "forwarded") } From fc4f9d57ae7ed2bd0b6b184fc975ecef20161114 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Mon, 22 Jan 2024 12:52:11 +0300 Subject: [PATCH 083/118] use panicf instead of errorf and fix panic on default logger --- client/client.go | 46 +++++++------------------------------------ client/client_test.go | 32 ++++++++---------------------- log/default.go | 22 ++++++++++++++++++--- log/default_test.go | 9 ++++++--- 4 files changed, 40 insertions(+), 69 deletions(-) diff --git a/client/client.go b/client/client.go index ca80743b39..c9181a3552 100644 --- a/client/client.go +++ b/client/client.go @@ -24,36 +24,6 @@ import ( var ErrInvalidProxyURL = errors.New("invalid proxy url scheme") var ErrFailedToAppendCert = errors.New("failed to append certificate") -// Define the logger interface so that users can -// use different log implements to output logs. -type Logger interface { - // The log with error level - Errorf(format string, v ...any) - - // The log with warn level - Warnf(format string, v ...any) - - // The log with info level - Infof(format string, v ...any) - - // The log with debug level - Debugf(format string, v ...any) -} - -var _ (Logger) = (*disableLogger)(nil) - -// Implement a Logger interface. -// All logs are turned off by default. -type disableLogger struct{} - -func (*disableLogger) Errorf(_ string, _ ...any) {} - -func (*disableLogger) Warnf(_ string, _ ...any) {} - -func (*disableLogger) Infof(_ string, _ ...any) {} - -func (*disableLogger) Debugf(_ string, _ ...any) {} - // The Client is used to create a Fiber Client with // client-level settings that apply to all requests // raise from the client. @@ -73,8 +43,7 @@ type Client struct { cookies *Cookie path *PathParam - debug bool - logger Logger + debug bool timeout time.Duration @@ -225,7 +194,7 @@ func (c *Client) SetRootCertificate(path string) *Client { cleanPath := filepath.Clean(path) file, err := os.Open(cleanPath) if err != nil { - log.Errorf("client: %v", err) + log.Panicf("client: %v", err) } defer func() { _ = file.Close() //nolint:errcheck // It is fine to ignore the error here @@ -233,7 +202,7 @@ func (c *Client) SetRootCertificate(path string) *Client { pem, err := io.ReadAll(file) if err != nil { - log.Errorf("client: %v", err) + log.Panicf("client: %v", err) } config := c.TLSConfig() @@ -242,7 +211,7 @@ func (c *Client) SetRootCertificate(path string) *Client { } if !config.RootCAs.AppendCertsFromPEM(pem) { - log.Errorf("client: %v", ErrFailedToAppendCert) + log.Panicf("client: %v", ErrFailedToAppendCert) } return c @@ -257,7 +226,7 @@ func (c *Client) SetRootCertificateFromString(pem string) *Client { } if !config.RootCAs.AppendCertsFromPEM([]byte(pem)) { - log.Errorf("client: %v", ErrFailedToAppendCert) + log.Panicf("client: %v", ErrFailedToAppendCert) } return c @@ -267,12 +236,12 @@ func (c *Client) SetRootCertificateFromString(pem string) *Client { func (c *Client) SetProxyURL(proxyURL string) *Client { pURL, err := urlPkg.Parse(proxyURL) if err != nil { - log.Errorf("client: %v", err) + log.Panicf("client: %v", err) return c } if pURL.Scheme != "http" && pURL.Scheme != "https" { - log.Errorf("client: %v", ErrInvalidProxyURL) + log.Panicf("client: %v", ErrInvalidProxyURL) return c } @@ -698,7 +667,6 @@ var ( }, cookies: &Cookie{}, path: &PathParam{}, - logger: &disableLogger{}, userRequestHooks: []RequestHook{}, builtinRequestHooks: []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody}, diff --git a/client/client_test.go b/client/client_test.go index 9b623f31fc..5678697cb0 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,7 +1,6 @@ package client import ( - "bytes" "context" "crypto/tls" "fmt" @@ -13,7 +12,6 @@ import ( "time" "github.com/gofiber/fiber/v3/addon/retry" - "github.com/gofiber/fiber/v3/log" "github.com/valyala/fasthttp" "github.com/gofiber/fiber/v3" @@ -1204,35 +1202,21 @@ func Test_Client_SetProxyURL(t *testing.T) { }) t.Run("wrong url", func(t *testing.T) { - var buf bytes.Buffer - log.SetOutput(&buf) - client := AcquireClient() - client.SetProxyURL(":this is not a url") - _, err := client.Get("http://localhost:3000", Config{Dial: dial}) + defer ReleaseClient(client) - require.Contains(t, buf.String(), "missing protocol scheme") - require.NoError(t, err) - }) - - t.Run("wrong url scheme", func(t *testing.T) { - var buf bytes.Buffer - log.SetOutput(&buf) - - client := AcquireClient() - client.SetProxyURL("x://test.com") - _, err := client.Get("http://localhost:3000", Config{Dial: dial}) - - require.Contains(t, buf.String(), "client: invalid proxy url scheme") - require.NoError(t, err) + require.Panics(t, func() { + client.SetProxyURL(":this is not a url") + }) }) t.Run("error", func(t *testing.T) { client := AcquireClient() - client.SetProxyURL("htgdftp://test.com") - _, err := client.Get("http://localhost:3000", Config{Dial: dial}) + defer ReleaseClient(client) - require.Error(t, err) + require.Panics(t, func() { + client.SetProxyURL("htgdftp://test.com") + }) }) } diff --git a/log/default.go b/log/default.go index abc9c8f4d7..690f734602 100644 --- a/log/default.go +++ b/log/default.go @@ -30,7 +30,12 @@ func (l *defaultLogger) privateLog(lv Level, fmtArgs []any) { _, _ = buf.WriteString(level) //nolint:errcheck // It is fine to ignore the error _, _ = buf.WriteString(fmt.Sprint(fmtArgs...)) //nolint:errcheck // It is fine to ignore the error - _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error + if lv == LevelPanic { + panic(buf.String()) + } else { + _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error + } + buf.Reset() bytebufferpool.Put(buf) if lv == LevelFatal { @@ -53,7 +58,13 @@ func (l *defaultLogger) privateLogf(lv Level, format string, fmtArgs []any) { } else { _, _ = fmt.Fprint(buf, fmtArgs...) } - _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error + + if lv == LevelPanic { + panic(buf.String()) + } else { + _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error + } + buf.Reset() bytebufferpool.Put(buf) if lv == LevelFatal { @@ -95,7 +106,12 @@ func (l *defaultLogger) privateLogw(lv Level, format string, keysAndValues []any } } - _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error + if lv == LevelPanic { + panic(buf.String()) + } else { + _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error + } + buf.Reset() bytebufferpool.Put(buf) if lv == LevelFatal { diff --git a/log/default_test.go b/log/default_test.go index 78ef0204ca..8cdb289c97 100644 --- a/log/default_test.go +++ b/log/default_test.go @@ -39,13 +39,16 @@ func Test_DefaultLogger(t *testing.T) { Info("starting work") Warn("work may fail") Error("work failed") - Panic("work panic") + + require.Panics(t, func() { + Panic("work panic") + }) + require.Equal(t, "[Trace] trace work\n"+ "[Debug] received work order\n"+ "[Info] starting work\n"+ "[Warn] work may fail\n"+ - "[Error] work failed\n"+ - "[Panic] work panic\n", string(w.b)) + "[Error] work failed\n", string(w.b)) } func Test_DefaultFormatLogger(t *testing.T) { From 7af7e4be06e608de1a4cbaebb32b86b535bcfe56 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Mon, 22 Jan 2024 14:44:16 +0300 Subject: [PATCH 084/118] update --- client/core_test.go | 2 ++ client/request_test.go | 10 ++++++++++ client/response_test.go | 2 ++ 3 files changed, 14 insertions(+) diff --git a/client/core_test.go b/client/core_test.go index b4b2d10a86..37bdcf4412 100644 --- a/client/core_test.go +++ b/client/core_test.go @@ -73,6 +73,8 @@ func Test_Exec_Func(t *testing.T) { require.Nil(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) }() + time.Sleep(300 * time.Millisecond) + t.Run("normal request", func(t *testing.T) { core, client, req := newCore(), AcquireClient(), AcquireRequest() core.ctx = context.Background() diff --git a/client/request_test.go b/client/request_test.go index 442e330fab..d26133e76f 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -594,7 +594,9 @@ func Test_Request_Get(t *testing.T) { app.Get("/", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) + go start() + time.Sleep(100 * time.Millisecond) for i := 0; i < 5; i++ { req := AcquireRequest().SetDial(ln) @@ -615,7 +617,9 @@ func Test_Request_Post(t *testing.T) { return c.Status(fiber.StatusCreated). SendString(c.FormValue("foo")) }) + go start() + time.Sleep(100 * time.Millisecond) for i := 0; i < 5; i++ { resp, err := AcquireRequest(). @@ -639,6 +643,7 @@ func Test_Request_Head(t *testing.T) { }) go start() + time.Sleep(100 * time.Millisecond) for i := 0; i < 5; i++ { resp, err := AcquireRequest(). @@ -661,6 +666,7 @@ func Test_Request_Put(t *testing.T) { }) go start() + time.Sleep(100 * time.Millisecond) for i := 0; i < 5; i++ { resp, err := AcquireRequest(). @@ -686,6 +692,7 @@ func Test_Request_Delete(t *testing.T) { }) go start() + time.Sleep(100 * time.Millisecond) for i := 0; i < 5; i++ { resp, err := AcquireRequest(). @@ -711,6 +718,7 @@ func Test_Request_Options(t *testing.T) { }) go start() + time.Sleep(100 * time.Millisecond) for i := 0; i < 5; i++ { resp, err := AcquireRequest(). @@ -736,6 +744,7 @@ func Test_Request_Send(t *testing.T) { }) go start() + time.Sleep(100 * time.Millisecond) for i := 0; i < 5; i++ { resp, err := AcquireRequest(). @@ -762,6 +771,7 @@ func Test_Request_Patch(t *testing.T) { }) go start() + time.Sleep(100 * time.Millisecond) for i := 0; i < 5; i++ { resp, err := AcquireRequest(). diff --git a/client/response_test.go b/client/response_test.go index 6d59f093a9..4839615cf6 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -8,6 +8,7 @@ import ( "net" "os" "testing" + "time" "github.com/gofiber/fiber/v3/internal/tlstest" @@ -266,6 +267,7 @@ func Test_Response_Save(t *testing.T) { }) go start() + time.Sleep(300 * time.Millisecond) t.Run("file path", func(t *testing.T) { resp, err := AcquireRequest(). From a6191759bd6e2bf21533b89a4a1fb47362ced671 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Mon, 22 Jan 2024 14:54:47 +0300 Subject: [PATCH 085/118] update --- client/core.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/core.go b/client/core.go index 6c989aabdf..8b3ecf2636 100644 --- a/client/core.go +++ b/client/core.go @@ -164,8 +164,8 @@ func (c *core) preHooks() error { // Exec response hooks func (c *core) afterHooks(resp *Response) error { - c.client.mu.RLock() - defer c.client.mu.RUnlock() + c.client.mu.Lock() + defer c.client.mu.Unlock() for _, f := range c.client.builtinResponseHooks { err := f(c.client, resp, c.req) From f371b386394324b6fa650d7f89a4a22fe0b11cb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Werner?= Date: Mon, 5 Feb 2024 09:31:18 +0100 Subject: [PATCH 086/118] cleanup comments --- client/client.go | 40 +++++++++++++++++++++------------------- client/cookiejar.go | 1 + client/core.go | 7 +++++-- client/hooks.go | 3 +++ client/request.go | 41 +++++++++++++++++++++++------------------ client/response.go | 2 ++ go.sum | 12 ------------ 7 files changed, 55 insertions(+), 51 deletions(-) diff --git a/client/client.go b/client/client.go index c9181a3552..eb9152c2e2 100644 --- a/client/client.go +++ b/client/client.go @@ -81,12 +81,12 @@ func (c *Client) R() *Request { return AcquireRequest().SetClient(c) } -// Request returns user-defined request hooks. +// RequestHook Request returns user-defined request hooks. func (c *Client) RequestHook() []RequestHook { return c.userRequestHooks } -// Add user-defined request hooks. +// AddRequestHook Add user-defined request hooks. func (c *Client) AddRequestHook(h ...RequestHook) *Client { c.mu.Lock() defer c.mu.Unlock() @@ -100,7 +100,7 @@ func (c *Client) ResponseHook() []ResponseHook { return c.userResponseHooks } -// Add user-defined response hooks. +// AddResponseHook Add user-defined response hooks. func (c *Client) AddResponseHook(h ...ResponseHook) *Client { c.mu.Lock() defer c.mu.Unlock() @@ -125,7 +125,7 @@ func (c *Client) JSONMarshal() utils.JSONMarshal { return c.jsonMarshal } -// Set json encoder. +// SetJSONMarshal Set json encoder. func (c *Client) SetJSONMarshal(f utils.JSONMarshal) *Client { c.jsonMarshal = f return c @@ -147,7 +147,7 @@ func (c *Client) XMLMarshal() utils.XMLMarshal { return c.xmlMarshal } -// Set xml encoder. +// SetXMLMarshal Set xml encoder. func (c *Client) SetXMLMarshal(f utils.XMLMarshal) *Client { c.xmlMarshal = f return c @@ -158,7 +158,7 @@ func (c *Client) XMLUnmarshal() utils.XMLUnmarshal { return c.xmlUnmarshal } -// Set xml decoder. +// SetXMLUnmarshal Set xml decoder. func (c *Client) SetXMLUnmarshal(f utils.XMLUnmarshal) *Client { c.xmlUnmarshal = f return c @@ -250,6 +250,7 @@ func (c *Client) SetProxyURL(proxyURL string) *Client { return c } +// RetryConfig returns retry config in client. func (c *Client) RetryConfig() *RetryConfig { return c.retryConfig } @@ -268,7 +269,7 @@ func (c *Client) BaseURL() string { return c.baseURL } -// Set baseUrl which is prefix of real url. +// SetBaseURL Set baseUrl which is prefix of real url. func (c *Client) SetBaseURL(url string) *Client { c.baseURL = url return c @@ -291,7 +292,7 @@ func (c *Client) AddHeader(key, val string) *Client { // SetHeader method sets a single header field and its value in the client instance. // These headers will be applied to all requests raised from this client instance. -// Also it can be overridden at request level header options. +// Also, it can be overridden at request level header options. func (c *Client) SetHeader(key, val string) *Client { c.header.Set(key, val) return c @@ -327,7 +328,7 @@ func (c *Client) Param(key string) []string { // AddParam method adds a single query param field and its value in the client instance. // These params will be applied to all requests raised from this client instance. -// Also it can be overridden at request level param options. +// Also, it can be overridden at request level param options. func (c *Client) AddParam(key, val string) *Client { c.params.Add(key, val) return c @@ -335,7 +336,7 @@ func (c *Client) AddParam(key, val string) *Client { // SetParam method sets a single query param field and its value in the client instance. // These params will be applied to all requests raised from this client instance. -// Also it can be overridden at request level param options. +// Also, it can be overridden at request level param options. func (c *Client) SetParam(key, val string) *Client { c.params.Set(key, val) return c @@ -470,7 +471,7 @@ func (c *Client) DelCookies(key ...string) *Client { // SetTimeout method sets timeout val in client instance. // This value will be applied to all requests raised from this client instance. -// Also it can be overridden at request level timeout options. +// Also, it can be overridden at request level timeout options. func (c *Client) SetTimeout(t time.Duration) *Client { c.timeout = t return c @@ -494,7 +495,7 @@ func (c *Client) SetCookieJar(cookieJar *CookieJar) *Client { return c } -// Get provide a API like axios which send get request. +// Get provide an API like axios which send get request. func (c *Client) Get(url string, cfg ...Config) (*Response, error) { req := AcquireRequest().SetClient(c) setConfigToRequest(req, cfg...) @@ -502,7 +503,7 @@ func (c *Client) Get(url string, cfg ...Config) (*Response, error) { return req.Get(url) } -// Post provide a API like axios which send post request. +// Post provide an API like axios which send post request. func (c *Client) Post(url string, cfg ...Config) (*Response, error) { req := AcquireRequest().SetClient(c) setConfigToRequest(req, cfg...) @@ -518,7 +519,7 @@ func (c *Client) Head(url string, cfg ...Config) (*Response, error) { return req.Head(url) } -// Put provide a API like axios which send put request. +// Put provide an API like axios which send put request. func (c *Client) Put(url string, cfg ...Config) (*Response, error) { req := AcquireRequest().SetClient(c) setConfigToRequest(req, cfg...) @@ -526,7 +527,7 @@ func (c *Client) Put(url string, cfg ...Config) (*Response, error) { return req.Put(url) } -// Delete provide a API like axios which send delete request. +// Delete provide an API like axios which send delete request. func (c *Client) Delete(url string, cfg ...Config) (*Response, error) { req := AcquireRequest().SetClient(c) setConfigToRequest(req, cfg...) @@ -534,7 +535,7 @@ func (c *Client) Delete(url string, cfg ...Config) (*Response, error) { return req.Delete(url) } -// Options provide a API like axios which send options request. +// Options provide an API like axios which send options request. func (c *Client) Options(url string, cfg ...Config) (*Response, error) { req := AcquireRequest().SetClient(c) setConfigToRequest(req, cfg...) @@ -542,7 +543,7 @@ func (c *Client) Options(url string, cfg ...Config) (*Response, error) { return req.Options(url) } -// Patch provide a API like axios which send patch request. +// Patch provide an API like axios which send patch request. func (c *Client) Patch(url string, cfg ...Config) (*Response, error) { req := AcquireRequest().SetClient(c) setConfigToRequest(req, cfg...) @@ -588,7 +589,7 @@ type Config struct { Dial fasthttp.DialFunc } -// Set the parameters passed via Config to Request. +// setConfigToRequest Set the parameters passed via Config to Request. func setConfigToRequest(req *Request, config ...Config) { if len(config) == 0 { return @@ -681,6 +682,7 @@ var ( } ) +// init acquire a default client. func init() { defaultClient = AcquireClient() } @@ -706,7 +708,7 @@ func ReleaseClient(c *Client) { clientPool.Put(c) } -// Get default client. +// C get default client. func C() *Client { return defaultClient } diff --git a/client/cookiejar.go b/client/cookiejar.go index 4c5106046f..353f266d40 100644 --- a/client/cookiejar.go +++ b/client/cookiejar.go @@ -159,6 +159,7 @@ func (cj *CookieJar) SetKeyValueBytes(host string, key, value []byte) { cj.setKeyValue(host, key, value) } +// setKeyValue sets a cookie by key and value for a specific host. func (cj *CookieJar) setKeyValue(host string, key, value []byte) { c := fasthttp.AcquireCookie() c.SetKeyBytes(key) diff --git a/client/core.go b/client/core.go index 8b3ecf2636..d2675ce2f9 100644 --- a/client/core.go +++ b/client/core.go @@ -57,6 +57,7 @@ type core struct { ctx context.Context } +// getRetryConfig returns the retry configuration of the client. func (c *core) getRetryConfig() *RetryConfig { c.client.mu.RLock() defer c.client.mu.RUnlock() @@ -74,6 +75,8 @@ func (c *core) getRetryConfig() *RetryConfig { } } +// execFunc is the core function of the client. +// It sends the request and receives the response. func (c *core) execFunc() (*Response, error) { resp := AcquireResponse() resp.setClient(c.client) @@ -140,7 +143,7 @@ func (c *core) execFunc() (*Response, error) { } } -// Exec request hook +// preHooks Exec request hook func (c *core) preHooks() error { c.client.mu.RLock() defer c.client.mu.RUnlock() @@ -162,7 +165,7 @@ func (c *core) preHooks() error { return nil } -// Exec response hooks +// afterHooks Exec response hooks func (c *core) afterHooks(resp *Response) error { c.client.mu.Lock() defer c.client.mu.Unlock() diff --git a/client/hooks.go b/client/hooks.go index 5c7089b2f1..190dda1840 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -32,6 +32,7 @@ var ( letterIdxMax = 63 / letterIdxBits // # of letter indices fitting in 63 bits ) +// randString returns a random string with n length func randString(n int) string { b := make([]byte, n) length := len(letterBytes) @@ -275,6 +276,7 @@ func parserRequestBody(c *Client, req *Request) error { return nil } +// parserResponseHeader will parse the response header and store it in the response func parserResponseCookie(c *Client, resp *Response, req *Request) (err error) { resp.RawResponse.Header.VisitAllCookie(func(key, value []byte) { cookie := fasthttp.AcquireCookie() @@ -292,6 +294,7 @@ func parserResponseCookie(c *Client, resp *Response, req *Request) (err error) { return } +// logger is a response hook that logs the request and response func logger(c *Client, resp *Response, req *Request) (err error) { if !c.debug { return diff --git a/client/request.go b/client/request.go index 574d435426..ef0c0bbe33 100644 --- a/client/request.go +++ b/client/request.go @@ -15,7 +15,7 @@ import ( "github.com/valyala/fasthttp" ) -// Implementing this interface allows data to +// WithStruct Implementing this interface allows data to // be stored from a struct via reflect. type WithStruct interface { Add(string, string) @@ -35,6 +35,7 @@ const ( rawBody ) +// Request is a struct which contains the request data. type Request struct { url string method string @@ -62,7 +63,7 @@ type Request struct { RawRequest *fasthttp.Request } -// Set HostClient dial, this method for unit test, +// SetDial set HostClient dial, this method for unit test, // maybe don't use it. func (r *Request) SetDial(f fasthttp.DialFunc) *Request { r.dial = f @@ -160,7 +161,7 @@ func (r *Request) SetHeaders(h map[string]string) *Request { // Param method returns params value via key, // this method will visit all field in the query param. func (r *Request) Param(key string) []string { - res := []string{} + var res []string tmp := r.params.PeekMulti(key) for _, v := range tmp { res = append(res, utils.UnsafeString(v)) @@ -197,7 +198,7 @@ func (r *Request) SetParams(m map[string]string) *Request { return r } -// SetParamWithStruct method sets multiple param fields and its values at one go in the request instance. +// SetParamsWithStruct method sets multiple param fields and its values at one go in the request instance. // It will override param which set in client instance. func (r *Request) SetParamsWithStruct(v any) *Request { r.params.SetParamsWithStruct(v) @@ -308,7 +309,7 @@ func (r *Request) SetPathParams(m map[string]string) *Request { return r } -// SetParamsWithStruct method sets multiple path param fields and its values at one go in the request instance. +// SetPathParamsWithStruct method sets multiple path param fields and its values at one go in the request instance. // It will override path param which set in client instance. func (r *Request) SetPathParamsWithStruct(v any) *Request { r.path.SetParamsWithStruct(v) @@ -362,7 +363,7 @@ func (r *Request) resetBody(t bodyType) { // FormData method returns form data value via key, // this method will visit all field in the form data. func (r *Request) FormData(key string) []string { - res := []string{} + var res []string tmp := r.formData.PeekMulti(key) for _, v := range tmp { res = append(res, utils.UnsafeString(v)) @@ -430,7 +431,7 @@ func (r *Request) File(name string) *File { return nil } -// File returns file ptr store in request obj by path. +// FileByPath returns file ptr store in request obj by path. func (r *Request) FileByPath(path string) *File { for _, v := range r.files { if v.path == path { @@ -457,7 +458,7 @@ func (r *Request) AddFileWithReader(name string, reader io.ReadCloser) *Request return r } -// AddFile method adds multiple file fields +// AddFiles method adds multiple file fields // and its value in the request instance via File instance. func (r *Request) AddFiles(files ...*File) *Request { r.files = append(r.files, files...) @@ -496,37 +497,37 @@ func (r *Request) checkClient() { } } -// Send get request. +// Get Send get request. func (r *Request) Get(url string) (*Response, error) { return r.SetURL(url).SetMethod(fiber.MethodGet).Send() } -// Send post request. +// Post Send post request. func (r *Request) Post(url string) (*Response, error) { return r.SetURL(url).SetMethod(fiber.MethodPost).Send() } -// Send head request. +// Head Send head request. func (r *Request) Head(url string) (*Response, error) { return r.SetURL(url).SetMethod(fiber.MethodHead).Send() } -// Send put request. +// Put Send put request. func (r *Request) Put(url string) (*Response, error) { return r.SetURL(url).SetMethod(fiber.MethodPut).Send() } -// Send Delete request. +// Delete Send Delete request. func (r *Request) Delete(url string) (*Response, error) { return r.SetURL(url).SetMethod(fiber.MethodDelete).Send() } -// Send Options request. +// Options Send Options request. func (r *Request) Options(url string) (*Response, error) { return r.SetURL(url).SetMethod(fiber.MethodOptions).Send() } -// Send patch request. +// Patch Send patch request. func (r *Request) Patch(url string) (*Response, error) { return r.SetURL(url).SetMethod(fiber.MethodPatch).Send() } @@ -573,7 +574,7 @@ type Header struct { // PeekMultiple methods returns multiple field in header with same key. func (h *Header) PeekMultiple(key string) []string { - res := []string{} + var res []string byteKey := []byte(key) h.RequestHeader.VisitAll(func(key, value []byte) { if bytes.EqualFold(key, byteKey) { @@ -852,28 +853,32 @@ func ReleaseRequest(req *Request) { var filePool sync.Pool -// The methods as follows is used by AcquireFile method. +// SetFileFunc The methods as follows is used by AcquireFile method. // You can set file field via these method. type SetFileFunc func(f *File) +// SetFileName method sets file name. func SetFileName(n string) SetFileFunc { return func(f *File) { f.SetName(n) } } +// SetFileFieldName method sets key of file in the body. func SetFileFieldName(p string) SetFileFunc { return func(f *File) { f.SetFieldName(p) } } +// SetFilePath method set file path. func SetFilePath(p string) SetFileFunc { return func(f *File) { f.SetPath(p) } } +// SetFileReader method can receive a io.ReadCloser func SetFileReader(r io.ReadCloser) SetFileFunc { return func(f *File) { f.SetReader(r) @@ -909,7 +914,7 @@ func ReleaseFile(f *File) { filePool.Put(f) } -// Set some values using structs. +// SetValWithStruct Set some values using structs. // `p` is a structure that implements the WithStruct interface, // The field name can be specified by `tagName`. // `v` is a struct include some data. diff --git a/client/response.go b/client/response.go index 170821683b..d21e218102 100644 --- a/client/response.go +++ b/client/response.go @@ -14,6 +14,7 @@ import ( "github.com/valyala/fasthttp" ) +// Response is the result of a request. This object is used to access the response data. type Response struct { client *Client request *Request @@ -79,6 +80,7 @@ func (r *Response) XML(v any) error { return r.client.xmlUnmarshal(r.Body(), v) } +// Save method will save the body to a file or io.Writer. func (r *Response) Save(v any) error { switch p := v.(type) { case string: diff --git a/go.sum b/go.sum index babb349db9..e26704c90c 100644 --- a/go.sum +++ b/go.sum @@ -1,24 +1,16 @@ -github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= -github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gofiber/utils/v2 v2.0.0-beta.3 h1:pfOhUDDVjBJpkWv6C5jaDyYLvpui7zQ97zpyFFsUOKw= github.com/gofiber/utils/v2 v2.0.0-beta.3/go.mod h1:jsl17+MsKfwJjM3ONCE9Rzji/j8XNbwjhUVTjzgfDCo= -github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= -github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU= github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4= -github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/philhofer/fwd v1.1.2 h1:bnDivRJ1EWPjUIRXV5KfORO897HTbpFAQddBdE8t7Gw= @@ -31,8 +23,6 @@ github.com/tinylib/msgp v1.1.8 h1:FCXC1xanKO4I8plpHGH2P7koL/RzZs12l/+r7vakfm0= github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.50.0 h1:H7fweIlBm0rXLs2q0XbalvJ6r0CUPFWK3/bB4N13e9M= -github.com/valyala/fasthttp v1.50.0/go.mod h1:k2zXd82h/7UZc3VOdJ2WaUqt1uZ/XpXAfE9i+HBC3lA= github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g= github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= @@ -57,8 +47,6 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= -golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= From e62a8df29a965af7231cbe9e86e1e26f340cd091 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Werner?= Date: Mon, 5 Feb 2024 10:03:54 +0100 Subject: [PATCH 087/118] cleanup comments --- client/core.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/client/core.go b/client/core.go index d2675ce2f9..18f1125ace 100644 --- a/client/core.go +++ b/client/core.go @@ -218,14 +218,14 @@ func (c *core) tls() { // proxy set proxy in host. func (c *core) proxy() error { - rawUri := c.req.RawRequest.URI() + rawURI := c.req.RawRequest.URI() if c.client.proxyURL != "" { - rawUri = fasthttp.AcquireURI() - rawUri.Update(c.client.proxyURL) - defer fasthttp.ReleaseURI(rawUri) + rawURI = fasthttp.AcquireURI() + rawURI.Update(c.client.proxyURL) + defer fasthttp.ReleaseURI(rawURI) } - isTLS, scheme := false, rawUri.Scheme() + isTLS, scheme := false, rawURI.Scheme() if bytes.Equal(httpsBytes, scheme) { isTLS = true } else if !bytes.Equal(httpBytes, scheme) { @@ -233,7 +233,7 @@ func (c *core) proxy() error { } c.client.mu.Lock() - c.client.host.Addr = addMissingPort(string(rawUri.Host()), isTLS) + c.client.host.Addr = addMissingPort(string(rawURI.Host()), isTLS) c.client.host.IsTLS = isTLS c.client.mu.Unlock() @@ -310,7 +310,7 @@ func releaseErrChan(ch chan error) { func newCore() (c *core) { c = &core{} - return + return c } var ( From 2fbdc898aff20b5805526b7db04f3ca851845c3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Werner?= Date: Mon, 5 Feb 2024 11:09:49 +0100 Subject: [PATCH 088/118] fix golang-lint errors --- client/client.go | 10 +- client/client_test.go | 166 ++++++++++++++++++++++----------- client/cookiejar.go | 11 ++- client/cookiejar_test.go | 13 +-- client/core.go | 10 +- client/core_test.go | 18 +++- client/helper_test.go | 8 +- client/hooks.go | 12 +-- client/hooks_test.go | 47 +++++++++- client/request_test.go | 133 +++++++++++++++++--------- client/response.go | 10 +- client/response_test.go | 18 +++- log/default.go | 12 +-- middleware/proxy/proxy_test.go | 37 ++++---- 14 files changed, 338 insertions(+), 167 deletions(-) diff --git a/client/client.go b/client/client.go index eb9152c2e2..78cc8eb18e 100644 --- a/client/client.go +++ b/client/client.go @@ -9,7 +9,7 @@ import ( "errors" "fmt" "io" - urlPkg "net/url" + urlpkg "net/url" "os" "path/filepath" "sync" @@ -21,8 +21,10 @@ import ( "github.com/valyala/fasthttp" ) -var ErrInvalidProxyURL = errors.New("invalid proxy url scheme") -var ErrFailedToAppendCert = errors.New("failed to append certificate") +var ( + ErrInvalidProxyURL = errors.New("invalid proxy url scheme") + ErrFailedToAppendCert = errors.New("failed to append certificate") +) // The Client is used to create a Fiber Client with // client-level settings that apply to all requests @@ -234,7 +236,7 @@ func (c *Client) SetRootCertificateFromString(pem string) *Client { // SetProxyURL sets proxy url in client. It will apply via core to hostclient. func (c *Client) SetProxyURL(proxyURL string) *Client { - pURL, err := urlPkg.Parse(proxyURL) + pURL, err := urlpkg.Parse(proxyURL) if err != nil { log.Panicf("client: %v", err) return c diff --git a/client/client_test.go b/client/client_test.go index 5678697cb0..cb165376c9 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -24,11 +24,12 @@ func Test_Client_Add_Hook(t *testing.T) { t.Parallel() t.Run("add request hooks", func(t *testing.T) { + t.Parallel() client := AcquireClient().AddRequestHook(func(c *Client, r *Request) error { return nil }) - require.Equal(t, 1, len(client.RequestHook())) + require.Len(t, client.RequestHook(), 1) client.AddRequestHook(func(c *Client, r *Request) error { return nil @@ -36,15 +37,16 @@ func Test_Client_Add_Hook(t *testing.T) { return nil }) - require.Equal(t, 3, len(client.RequestHook())) + require.Len(t, client.RequestHook(), 3) }) t.Run("add response hooks", func(t *testing.T) { + t.Parallel() client := AcquireClient().AddResponseHook(func(c *Client, resp *Response, r *Request) error { return nil }) - require.Equal(t, 1, len(client.ResponseHook())) + require.Len(t, client.ResponseHook(), 1) client.AddResponseHook(func(c *Client, resp *Response, r *Request) error { return nil @@ -52,12 +54,15 @@ func Test_Client_Add_Hook(t *testing.T) { return nil }) - require.Equal(t, 3, len(client.ResponseHook())) + require.Len(t, client.ResponseHook(), 3) }) } func Test_Client_Marshal(t *testing.T) { + t.Parallel() + t.Run("set json marshal", func(t *testing.T) { + t.Parallel() client := AcquireClient(). SetJSONMarshal(func(v any) ([]byte, error) { return []byte("hello"), nil @@ -69,6 +74,7 @@ func Test_Client_Marshal(t *testing.T) { }) t.Run("set json unmarshal", func(t *testing.T) { + t.Parallel() client := AcquireClient(). SetJSONUnmarshal(func(data []byte, v any) error { return fmt.Errorf("empty json") @@ -79,6 +85,7 @@ func Test_Client_Marshal(t *testing.T) { }) t.Run("set xml marshal", func(t *testing.T) { + t.Parallel() client := AcquireClient(). SetXMLMarshal(func(v any) ([]byte, error) { return []byte("hello"), nil @@ -90,6 +97,7 @@ func Test_Client_Marshal(t *testing.T) { }) t.Run("set xml unmarshal", func(t *testing.T) { + t.Parallel() client := AcquireClient(). SetXMLUnmarshal(func(data []byte, v any) error { return fmt.Errorf("empty xml") @@ -163,6 +171,7 @@ func Test_Get(t *testing.T) { go start() t.Run("global get function", func(t *testing.T) { + t.Parallel() resp, err := Get("http://example.com", Config{ Dial: dial, }) @@ -171,6 +180,7 @@ func Test_Get(t *testing.T) { }) t.Run("client get", func(t *testing.T) { + t.Parallel() resp, err := AcquireClient().Get("http://example.com", Config{ Dial: dial, }) @@ -191,6 +201,7 @@ func Test_Head(t *testing.T) { go start() t.Run("global head function", func(t *testing.T) { + t.Parallel() resp, err := Head("http://example.com", Config{ Dial: dial, }) @@ -199,6 +210,7 @@ func Test_Head(t *testing.T) { }) t.Run("client head", func(t *testing.T) { + t.Parallel() resp, err := AcquireClient().Head("http://example.com", Config{ Dial: dial, }) @@ -219,6 +231,7 @@ func Test_Post(t *testing.T) { go start() t.Run("global post function", func(t *testing.T) { + t.Parallel() for i := 0; i < 5; i++ { resp, err := Post("http://example.com", Config{ Dial: dial, @@ -227,13 +240,14 @@ func Test_Post(t *testing.T) { }, }) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, fiber.StatusCreated, resp.StatusCode()) require.Equal(t, "bar", resp.String()) } }) t.Run("client post", func(t *testing.T) { + t.Parallel() for i := 0; i < 5; i++ { resp, err := AcquireClient().Post("http://example.com", Config{ Dial: dial, @@ -242,7 +256,7 @@ func Test_Post(t *testing.T) { }, }) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, fiber.StatusCreated, resp.StatusCode()) require.Equal(t, "bar", resp.String()) } @@ -260,6 +274,7 @@ func Test_Put(t *testing.T) { go start() t.Run("global put function", func(t *testing.T) { + t.Parallel() for i := 0; i < 5; i++ { resp, err := Put("http://example.com", Config{ Dial: dial, @@ -268,13 +283,14 @@ func Test_Put(t *testing.T) { }, }) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "bar", resp.String()) } }) t.Run("client put", func(t *testing.T) { + t.Parallel() for i := 0; i < 5; i++ { resp, err := AcquireClient().Put("http://example.com", Config{ Dial: dial, @@ -283,7 +299,7 @@ func Test_Put(t *testing.T) { }, }) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "bar", resp.String()) } @@ -302,6 +318,7 @@ func Test_Delete(t *testing.T) { go start() t.Run("global delete function", func(t *testing.T) { + t.Parallel() for i := 0; i < 5; i++ { resp, err := Delete("http://example.com", Config{ Dial: dial, @@ -310,13 +327,14 @@ func Test_Delete(t *testing.T) { }, }) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) require.Equal(t, "", resp.String()) } }) t.Run("client delete", func(t *testing.T) { + t.Parallel() for i := 0; i < 5; i++ { resp, err := AcquireClient().Delete("http://example.com", Config{ Dial: dial, @@ -325,7 +343,7 @@ func Test_Delete(t *testing.T) { }, }) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) require.Equal(t, "", resp.String()) } @@ -343,29 +361,32 @@ func Test_Options(t *testing.T) { go start() t.Run("global options function", func(t *testing.T) { + t.Parallel() for i := 0; i < 5; i++ { resp, err := Options("http://example.com", Config{ Dial: dial, }) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) require.Equal(t, "", resp.String()) } }) t.Run("client options", func(t *testing.T) { + t.Parallel() for i := 0; i < 5; i++ { resp, err := AcquireClient().Options("http://example.com", Config{ Dial: dial, }) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) require.Equal(t, "", resp.String()) } }) } + func Test_Patch(t *testing.T) { t.Parallel() @@ -378,6 +399,7 @@ func Test_Patch(t *testing.T) { go start() t.Run("global patch function", func(t *testing.T) { + t.Parallel() for i := 0; i < 5; i++ { resp, err := Patch("http://example.com", Config{ Dial: dial, @@ -386,13 +408,14 @@ func Test_Patch(t *testing.T) { }, }) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "bar", resp.String()) } }) t.Run("client patch", func(t *testing.T) { + t.Parallel() for i := 0; i < 5; i++ { resp, err := AcquireClient().Patch("http://example.com", Config{ Dial: dial, @@ -401,7 +424,7 @@ func Test_Patch(t *testing.T) { }, }) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "bar", resp.String()) } @@ -420,18 +443,20 @@ func Test_Client_UserAgent(t *testing.T) { go start() t.Run("default", func(t *testing.T) { + t.Parallel() for i := 0; i < 5; i++ { resp, err := Get("http://example.com", Config{ Dial: dial, }) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, defaultUserAgent, resp.String()) } }) t.Run("custom", func(t *testing.T) { + t.Parallel() for i := 0; i < 5; i++ { c := AcquireClient(). SetUserAgent("ua") @@ -440,7 +465,7 @@ func Test_Client_UserAgent(t *testing.T) { Dial: dial, }) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "ua", resp.String()) ReleaseClient(c) @@ -452,25 +477,28 @@ func Test_Client_Header(t *testing.T) { t.Parallel() t.Run("add header", func(t *testing.T) { + t.Parallel() req := AcquireClient() req.AddHeader("foo", "bar").AddHeader("foo", "fiber") res := req.Header("foo") - require.Equal(t, 2, len(res)) + require.Len(t, res, 2) require.Equal(t, "bar", res[0]) require.Equal(t, "fiber", res[1]) }) t.Run("set header", func(t *testing.T) { + t.Parallel() req := AcquireClient() req.AddHeader("foo", "bar").SetHeader("foo", "fiber") res := req.Header("foo") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) }) t.Run("add headers", func(t *testing.T) { + t.Parallel() req := AcquireClient() req.SetHeader("foo", "bar"). AddHeaders(map[string][]string{ @@ -479,17 +507,18 @@ func Test_Client_Header(t *testing.T) { }) res := req.Header("foo") - require.Equal(t, 3, len(res)) + require.Len(t, res, 3) require.Equal(t, "bar", res[0]) require.Equal(t, "buaa", res[1]) require.Equal(t, "fiber", res[2]) res = req.Header("bar") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) t.Run("set headers", func(t *testing.T) { + t.Parallel() req := AcquireClient() req.SetHeader("foo", "bar"). SetHeaders(map[string]string{ @@ -498,11 +527,11 @@ func Test_Client_Header(t *testing.T) { }) res := req.Header("foo") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) res = req.Header("bar") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) } @@ -537,6 +566,7 @@ func Test_Client_Cookie(t *testing.T) { t.Parallel() t.Run("set cookie", func(t *testing.T) { + t.Parallel() req := AcquireClient(). SetCookie("foo", "bar") require.Equal(t, "bar", req.Cookie("foo")) @@ -546,6 +576,7 @@ func Test_Client_Cookie(t *testing.T) { }) t.Run("set cookies", func(t *testing.T) { + t.Parallel() req := AcquireClient(). SetCookies(map[string]string{ "foo": "bar", @@ -562,6 +593,7 @@ func Test_Client_Cookie(t *testing.T) { }) t.Run("set cookies with struct", func(t *testing.T) { + t.Parallel() type args struct { CookieInt int `cookie:"int"` CookieString string `cookie:"string"` @@ -577,6 +609,7 @@ func Test_Client_Cookie(t *testing.T) { }) t.Run("del cookies", func(t *testing.T) { + t.Parallel() req := AcquireClient(). SetCookies(map[string]string{ "foo": "bar", @@ -592,6 +625,8 @@ func Test_Client_Cookie(t *testing.T) { } func Test_Client_Cookie_With_Server(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { return c.SendString( c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) @@ -629,7 +664,10 @@ func Test_Client_CookieJar(t *testing.T) { } func Test_Client_CookieJar_Response(t *testing.T) { + t.Parallel() + t.Run("without expiration", func(t *testing.T) { + t.Parallel() handler := func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "k4", @@ -655,6 +693,7 @@ func Test_Client_CookieJar_Response(t *testing.T) { }) t.Run("with expiration", func(t *testing.T) { + t.Parallel() handler := func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "k4", @@ -681,6 +720,7 @@ func Test_Client_CookieJar_Response(t *testing.T) { }) t.Run("override cookie value", func(t *testing.T) { + t.Parallel() handler := func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "k1", @@ -725,25 +765,28 @@ func Test_Client_QueryParam(t *testing.T) { t.Parallel() t.Run("add param", func(t *testing.T) { + t.Parallel() req := AcquireClient() req.AddParam("foo", "bar").AddParam("foo", "fiber") res := req.Param("foo") - require.Equal(t, 2, len(res)) + require.Len(t, res, 2) require.Equal(t, "bar", res[0]) require.Equal(t, "fiber", res[1]) }) t.Run("set param", func(t *testing.T) { + t.Parallel() req := AcquireClient() req.AddParam("foo", "bar").SetParam("foo", "fiber") res := req.Param("foo") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) }) t.Run("add params", func(t *testing.T) { + t.Parallel() req := AcquireClient() req.SetParam("foo", "bar"). AddParams(map[string][]string{ @@ -752,17 +795,18 @@ func Test_Client_QueryParam(t *testing.T) { }) res := req.Param("foo") - require.Equal(t, 3, len(res)) + require.Len(t, res, 3) require.Equal(t, "bar", res[0]) require.Equal(t, "buaa", res[1]) require.Equal(t, "fiber", res[2]) res = req.Param("bar") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) t.Run("set headers", func(t *testing.T) { + t.Parallel() req := AcquireClient() req.SetParam("foo", "bar"). SetParams(map[string]string{ @@ -771,11 +815,11 @@ func Test_Client_QueryParam(t *testing.T) { }) res := req.Param("foo") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) res = req.Param("bar") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) @@ -801,31 +845,32 @@ func Test_Client_QueryParam(t *testing.T) { TIntSlice: []int{1, 2}, }) - require.Equal(t, 0, len(p.Param("unexport"))) + require.Empty(t, p.Param("unexport")) - require.Equal(t, 1, len(p.Param("TInt"))) + require.Len(t, p.Param("TInt"), 1) require.Equal(t, "5", p.Param("TInt")[0]) - require.Equal(t, 1, len(p.Param("TString"))) + require.Len(t, p.Param("TString"), 1) require.Equal(t, "string", p.Param("TString")[0]) - require.Equal(t, 1, len(p.Param("TFloat"))) + require.Len(t, p.Param("TFloat"), 1) require.Equal(t, "3.1", p.Param("TFloat")[0]) - require.Equal(t, 1, len(p.Param("TBool"))) + require.Len(t, p.Param("TBool"), 1) tslice := p.Param("TSlice") - require.Equal(t, 2, len(tslice)) + require.Len(t, tslice, 2) require.Equal(t, "bar", tslice[0]) require.Equal(t, "foo", tslice[1]) tint := p.Param("TSlice") - require.Equal(t, 2, len(tint)) + require.Len(t, tint, 2) require.Equal(t, "bar", tint[0]) require.Equal(t, "foo", tint[1]) }) t.Run("del params", func(t *testing.T) { + t.Parallel() req := AcquireClient() req.SetParam("foo", "bar"). SetParams(map[string]string{ @@ -834,10 +879,10 @@ func Test_Client_QueryParam(t *testing.T) { }).DelParams("foo", "bar") res := req.Param("foo") - require.Equal(t, 0, len(res)) + require.Empty(t, res) res = req.Param("bar") - require.Equal(t, 0, len(res)) + require.Empty(t, res) }) } @@ -861,6 +906,7 @@ func Test_Client_PathParam(t *testing.T) { t.Parallel() t.Run("set path param", func(t *testing.T) { + t.Parallel() req := AcquireClient(). SetPathParam("foo", "bar") require.Equal(t, "bar", req.PathParam("foo")) @@ -870,6 +916,7 @@ func Test_Client_PathParam(t *testing.T) { }) t.Run("set path params", func(t *testing.T) { + t.Parallel() req := AcquireClient(). SetPathParams(map[string]string{ "foo": "bar", @@ -886,6 +933,7 @@ func Test_Client_PathParam(t *testing.T) { }) t.Run("set path params with struct", func(t *testing.T) { + t.Parallel() type args struct { CookieInt int `path:"int"` CookieString string `path:"string"` @@ -901,6 +949,7 @@ func Test_Client_PathParam(t *testing.T) { }) t.Run("del path params", func(t *testing.T) { + t.Parallel() req := AcquireClient(). SetPathParams(map[string]string{ "foo": "bar", @@ -928,7 +977,7 @@ func Test_Client_PathParam_With_Server(t *testing.T) { SetPathParam("path", "test"). Get("http://example.com/:path", Config{Dial: dial}) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "ok", resp.String()) } @@ -937,10 +986,10 @@ func Test_Client_TLS(t *testing.T) { t.Parallel() serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() - require.Nil(t, err) + require.NoError(t, err) ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") - require.Nil(t, err) + require.NoError(t, err) ln = tls.NewListener(ln, serverTLSConf) @@ -950,7 +999,7 @@ func Test_Client_TLS(t *testing.T) { }) go func() { - require.Nil(t, app.Listener(ln, fiber.ListenConfig{ + require.NoError(t, app.Listener(ln, fiber.ListenConfig{ DisableStartupMessage: true, })) }() @@ -958,7 +1007,7 @@ func Test_Client_TLS(t *testing.T) { client := AcquireClient() resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, clientTLSConf, client.TLSConfig()) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "tls", resp.String()) @@ -968,10 +1017,10 @@ func Test_Client_TLS_Empty_TLSConfig(t *testing.T) { t.Parallel() serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() - require.Nil(t, err) + require.NoError(t, err) ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") - require.Nil(t, err) + require.NoError(t, err) ln = tls.NewListener(ln, serverTLSConf) @@ -981,7 +1030,7 @@ func Test_Client_TLS_Empty_TLSConfig(t *testing.T) { }) go func() { - require.Nil(t, app.Listener(ln, fiber.ListenConfig{ + require.NoError(t, app.Listener(ln, fiber.ListenConfig{ DisableStartupMessage: true, })) }() @@ -998,10 +1047,10 @@ func Test_Client_SetCertificates(t *testing.T) { t.Parallel() serverTLSConf, _, err := tlstest.GetTLSConfigs() - require.Nil(t, err) + require.NoError(t, err) client := AcquireClient().SetCertificates(serverTLSConf.Certificates...) - require.Equal(t, 1, len(client.tlsConfig.Certificates)) + require.Len(t, client.tlsConfig.Certificates, 1) } func Test_Client_SetRootCertificate(t *testing.T) { @@ -1046,14 +1095,14 @@ func Test_Replace(t *testing.T) { resp, err := Get("http://example.com", Config{Dial: dial}) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "", resp.String()) r := AcquireClient().SetHeader("k1", "v1") clean := Replace(r) resp, err = Get("http://example.com", Config{Dial: dial}) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "v1", resp.String()) @@ -1062,7 +1111,7 @@ func Test_Replace(t *testing.T) { resp, err = Get("http://example.com", Config{Dial: dial}) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "", resp.String()) } @@ -1071,6 +1120,7 @@ func Test_Set_Config_To_Request(t *testing.T) { t.Parallel() t.Run("set ctx", func(t *testing.T) { + t.Parallel() key := struct{}{} ctx := context.Background() @@ -1084,6 +1134,7 @@ func Test_Set_Config_To_Request(t *testing.T) { }) t.Run("set useragent", func(t *testing.T) { + t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{UserAgent: "agent"}) @@ -1092,6 +1143,7 @@ func Test_Set_Config_To_Request(t *testing.T) { }) t.Run("set referer", func(t *testing.T) { + t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{Referer: "referer"}) @@ -1110,6 +1162,7 @@ func Test_Set_Config_To_Request(t *testing.T) { }) t.Run("set params", func(t *testing.T) { + t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{Param: map[string]string{ @@ -1120,6 +1173,7 @@ func Test_Set_Config_To_Request(t *testing.T) { }) t.Run("set cookies", func(t *testing.T) { + t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{Cookie: map[string]string{ @@ -1130,6 +1184,7 @@ func Test_Set_Config_To_Request(t *testing.T) { }) t.Run("set pathparam", func(t *testing.T) { + t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{PathParam: map[string]string{ @@ -1140,6 +1195,7 @@ func Test_Set_Config_To_Request(t *testing.T) { }) t.Run("set timeout", func(t *testing.T) { + t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{Timeout: 1 * time.Second}) @@ -1148,6 +1204,7 @@ func Test_Set_Config_To_Request(t *testing.T) { }) t.Run("set maxredirects", func(t *testing.T) { + t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{MaxRedirects: 1}) @@ -1156,6 +1213,7 @@ func Test_Set_Config_To_Request(t *testing.T) { }) t.Run("set body", func(t *testing.T) { + t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{Body: "test"}) @@ -1164,6 +1222,7 @@ func Test_Set_Config_To_Request(t *testing.T) { }) t.Run("set file", func(t *testing.T) { + t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{File: []*File{ @@ -1187,13 +1246,14 @@ func Test_Client_SetProxyURL(t *testing.T) { go start() - defer func(app *fiber.App) { + t.Cleanup(func() { _ = app.Shutdown() - }(app) + }) time.Sleep(1 * time.Second) t.Run("success", func(t *testing.T) { + t.Parallel() client := AcquireClient() client.SetProxyURL("http://test.com") _, err := client.Get("http://localhost:3000", Config{Dial: dial}) @@ -1202,6 +1262,7 @@ func Test_Client_SetProxyURL(t *testing.T) { }) t.Run("wrong url", func(t *testing.T) { + t.Parallel() client := AcquireClient() defer ReleaseClient(client) @@ -1211,6 +1272,7 @@ func Test_Client_SetProxyURL(t *testing.T) { }) t.Run("error", func(t *testing.T) { + t.Parallel() client := AcquireClient() defer ReleaseClient(client) diff --git a/client/cookiejar.go b/client/cookiejar.go index 353f266d40..6f2088c482 100644 --- a/client/cookiejar.go +++ b/client/cookiejar.go @@ -3,15 +3,16 @@ package client import ( "bytes" - "github.com/gofiber/utils/v2" - "github.com/valyala/fasthttp" "net" "sync" "time" + + "github.com/gofiber/utils/v2" + "github.com/valyala/fasthttp" ) var cookieJarPool = sync.Pool{ - New: func() interface{} { + New: func() any { return &CookieJar{} }, } @@ -178,8 +179,8 @@ func (cj *CookieJar) dumpCookiesToReq(req *fasthttp.Request) { } } -// getCookiesFromResp parses the response cookies and stores them. -func (cj *CookieJar) getCookiesFromResp(host, path []byte, resp *fasthttp.Response) { +// parseCookiesFromResp parses the response cookies and stores them. +func (cj *CookieJar) parseCookiesFromResp(host, path []byte, resp *fasthttp.Response) { hostStr := utils.UnsafeString(host) cj.mu.Lock() diff --git a/client/cookiejar_test.go b/client/cookiejar_test.go index 1cfd7510bd..3b6fdcda83 100644 --- a/client/cookiejar_test.go +++ b/client/cookiejar_test.go @@ -2,17 +2,18 @@ package client import ( "bytes" - "github.com/stretchr/testify/require" - "github.com/valyala/fasthttp" "testing" "time" + + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" ) func checkKeyValue(t *testing.T, cj *CookieJar, cookie *fasthttp.Cookie, uri *fasthttp.URI, n int) { t.Helper() cs := cj.Get(uri) - require.True(t, len(cs) >= n) + require.GreaterOrEqual(t, len(cs), n) c := cs[n-1] require.NotNil(t, c) @@ -74,7 +75,7 @@ func TestCookieJarGet(t *testing.T) { } cookies = cj.Get(uri11) - require.Len(t, cookies, 0) + require.Empty(t, cookies) cookies = cj.Get(uri2) require.Len(t, cookies, 2) @@ -112,7 +113,7 @@ func TestCookieJarGetExpired(t *testing.T) { cj.Set(uri1, c1) cookies := cj.Get(uri1) - require.Len(t, cookies, 0) + require.Empty(t, cookies) } func TestCookieJarSet(t *testing.T) { @@ -205,7 +206,7 @@ func TestCookieJarGetFromResponse(t *testing.T) { res.Header.SetCookie(c3) cj := &CookieJar{} - cj.getCookiesFromResp(host, nil, res) + cj.parseCookiesFromResp(host, nil, res) cookies := cj.Get(uri) require.Len(t, cookies, 3) diff --git a/client/core.go b/client/core.go index 18f1125ace..bdd90692f7 100644 --- a/client/core.go +++ b/client/core.go @@ -193,10 +193,8 @@ func (c *core) timeout() context.CancelFunc { if c.req.timeout > 0 { c.ctx, cancel = context.WithTimeout(c.ctx, c.req.timeout) - } else { - if c.client.timeout > 0 { - c.ctx, cancel = context.WithTimeout(c.ctx, c.client.timeout) - } + } else if c.client.timeout > 0 { + c.ctx, cancel = context.WithTimeout(c.ctx, c.client.timeout) } return cancel @@ -307,8 +305,8 @@ func releaseErrChan(ch chan error) { } // newCore returns an empty core object. -func newCore() (c *core) { - c = &core{} +func newCore() *core { + c := &core{} return c } diff --git a/client/core_test.go b/client/core_test.go index 37bdcf4412..b8d39b0480 100644 --- a/client/core_test.go +++ b/client/core_test.go @@ -13,6 +13,8 @@ import ( ) func Test_AddMissing_Port(t *testing.T) { + t.Parallel() + type args struct { addr string isTLS bool @@ -47,12 +49,15 @@ func Test_AddMissing_Port(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() require.Equal(t, tt.want, addMissingPort(tt.args.addr, tt.args.isTLS)) }) } } func Test_Exec_Func(t *testing.T) { + t.Parallel() + ln := fasthttputil.NewInmemoryListener() app := fiber.New() @@ -70,12 +75,13 @@ func Test_Exec_Func(t *testing.T) { }) go func() { - require.Nil(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) + require.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) }() time.Sleep(300 * time.Millisecond) t.Run("normal request", func(t *testing.T) { + t.Parallel() core, client, req := newCore(), AcquireClient(), AcquireRequest() core.ctx = context.Background() core.client = client @@ -92,6 +98,7 @@ func Test_Exec_Func(t *testing.T) { }) t.Run("the request return an error", func(t *testing.T) { + t.Parallel() core, client, req := newCore(), AcquireClient(), AcquireRequest() core.ctx = context.Background() core.client = client @@ -108,6 +115,7 @@ func Test_Exec_Func(t *testing.T) { }) t.Run("the request timeout", func(t *testing.T) { + t.Parallel() core, client, req := newCore(), AcquireClient(), AcquireRequest() ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() @@ -145,10 +153,11 @@ func Test_Execute(t *testing.T) { }) go func() { - require.Nil(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) + require.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) }() t.Run("add user request hooks", func(t *testing.T) { + t.Parallel() core, client, req := newCore(), AcquireClient(), AcquireRequest() client.AddRequestHook(func(c *Client, r *Request) error { require.Equal(t, "http://example.com", req.URL()) @@ -164,6 +173,7 @@ func Test_Execute(t *testing.T) { }) t.Run("add user response hooks", func(t *testing.T) { + t.Parallel() core, client, req := newCore(), AcquireClient(), AcquireRequest() client.AddResponseHook(func(c *Client, resp *Response, req *Request) error { require.Equal(t, "http://example.com", req.URL()) @@ -179,6 +189,7 @@ func Test_Execute(t *testing.T) { }) t.Run("no timeout", func(t *testing.T) { + t.Parallel() core, client, req := newCore(), AcquireClient(), AcquireRequest() req.SetDial(func(addr string) (net.Conn, error) { @@ -191,6 +202,7 @@ func Test_Execute(t *testing.T) { }) t.Run("client timeout", func(t *testing.T) { + t.Parallel() core, client, req := newCore(), AcquireClient(), AcquireRequest() client.SetTimeout(500 * time.Millisecond) req.SetDial(func(addr string) (net.Conn, error) { @@ -202,6 +214,7 @@ func Test_Execute(t *testing.T) { }) t.Run("request timeout", func(t *testing.T) { + t.Parallel() core, client, req := newCore(), AcquireClient(), AcquireRequest() req.SetDial(func(addr string) (net.Conn, error) { @@ -214,6 +227,7 @@ func Test_Execute(t *testing.T) { }) t.Run("request timeout has higher level", func(t *testing.T) { + t.Parallel() core, client, req := newCore(), AcquireClient(), AcquireRequest() client.SetTimeout(30 * time.Millisecond) diff --git a/client/helper_test.go b/client/helper_test.go index 8636796e21..92a120d4d3 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -24,12 +24,12 @@ func createHelperServer(t testing.TB, config ...fiber.Config) (*fiber.App, func( return app, func(addr string) (net.Conn, error) { return ln.Dial() }, func() { - require.Nil(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) + require.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) } } func testRequest(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted string, count ...int) { - t.Parallel() + t.Helper() app, ln, start := createHelperServer(t) app.Get("/", handler) @@ -54,7 +54,7 @@ func testRequest(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Requ } func testRequestFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted error, count ...int) { - t.Parallel() + t.Helper() app, ln, start := createHelperServer(t) app.Get("/", handler) @@ -76,7 +76,7 @@ func testRequestFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent * } func testClient(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Client), excepted string, count ...int) { - t.Parallel() + t.Helper() app, ln, start := createHelperServer(t) app.Get("/", handler) diff --git a/client/hooks.go b/client/hooks.go index 190dda1840..f93ed46803 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -277,7 +277,7 @@ func parserRequestBody(c *Client, req *Request) error { } // parserResponseHeader will parse the response header and store it in the response -func parserResponseCookie(c *Client, resp *Response, req *Request) (err error) { +func parserResponseCookie(c *Client, resp *Response, req *Request) error { resp.RawResponse.Header.VisitAllCookie(func(key, value []byte) { cookie := fasthttp.AcquireCookie() _ = cookie.ParseBytes(value) @@ -288,20 +288,20 @@ func parserResponseCookie(c *Client, resp *Response, req *Request) (err error) { // store cookies to jar if c.cookieJar != nil { - c.cookieJar.getCookiesFromResp(req.RawRequest.URI().Host(), req.RawRequest.URI().Path(), resp.RawResponse) + c.cookieJar.parseCookiesFromResp(req.RawRequest.URI().Host(), req.RawRequest.URI().Path(), resp.RawResponse) } - return + return nil } // logger is a response hook that logs the request and response -func logger(c *Client, resp *Response, req *Request) (err error) { +func logger(c *Client, resp *Response, req *Request) error { if !c.debug { - return + return nil } log.Debugf("%s\n", req.RawRequest.String()) log.Debugf("%s\n", resp.RawResponse.String()) - return + return nil } diff --git a/client/hooks_test.go b/client/hooks_test.go index d144b29e30..2c89868998 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -3,18 +3,20 @@ package client import ( "bytes" "encoding/xml" - "github.com/gofiber/fiber/v3/log" "io" "net/url" "strings" "testing" "time" + "github.com/gofiber/fiber/v3/log" + "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" ) func Test_Rand_String(t *testing.T) { + t.Parallel() tests := []struct { name string args int @@ -26,8 +28,9 @@ func Test_Rand_String(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() got := randString(tt.args) - require.Equal(t, 16, len(got)) + require.Len(t, got, 16) }) } } @@ -36,6 +39,7 @@ func Test_Parser_Request_URL(t *testing.T) { t.Parallel() t.Run("client baseurl should be set", func(t *testing.T) { + t.Parallel() client := AcquireClient().SetBaseURL("http://example.com/api") req := AcquireRequest().SetURL("") @@ -45,6 +49,7 @@ func Test_Parser_Request_URL(t *testing.T) { }) t.Run("request url should be set", func(t *testing.T) { + t.Parallel() client := AcquireClient() req := AcquireRequest().SetURL("http://example.com/api") @@ -54,6 +59,7 @@ func Test_Parser_Request_URL(t *testing.T) { }) t.Run("the request url will override baseurl with protocol", func(t *testing.T) { + t.Parallel() client := AcquireClient().SetBaseURL("http://example.com/api") req := AcquireRequest().SetURL("http://example.com/api/v1") @@ -63,6 +69,7 @@ func Test_Parser_Request_URL(t *testing.T) { }) t.Run("the request url should be append after baseurl without protocol", func(t *testing.T) { + t.Parallel() client := AcquireClient().SetBaseURL("http://example.com/api") req := AcquireRequest().SetURL("/v1") @@ -72,6 +79,7 @@ func Test_Parser_Request_URL(t *testing.T) { }) t.Run("the url is error", func(t *testing.T) { + t.Parallel() client := AcquireClient().SetBaseURL("example.com/api") req := AcquireRequest().SetURL("/v1") @@ -80,6 +88,7 @@ func Test_Parser_Request_URL(t *testing.T) { }) t.Run("the path param from client", func(t *testing.T) { + t.Parallel() client := AcquireClient(). SetBaseURL("http://example.com/api/:id"). SetPathParam("id", "5") @@ -91,6 +100,7 @@ func Test_Parser_Request_URL(t *testing.T) { }) t.Run("the path param from request", func(t *testing.T) { + t.Parallel() client := AcquireClient(). SetBaseURL("http://example.com/api/:id/:name"). SetPathParam("id", "5") @@ -108,6 +118,7 @@ func Test_Parser_Request_URL(t *testing.T) { }) t.Run("the path param from request and client", func(t *testing.T) { + t.Parallel() client := AcquireClient(). SetBaseURL("http://example.com/api/:id/:name"). SetPathParam("id", "5") @@ -125,6 +136,7 @@ func Test_Parser_Request_URL(t *testing.T) { }) t.Run("query params from client should be set", func(t *testing.T) { + t.Parallel() client := AcquireClient(). SetParam("foo", "bar") req := AcquireRequest().SetURL("http://example.com/api/v1") @@ -135,6 +147,7 @@ func Test_Parser_Request_URL(t *testing.T) { }) t.Run("query params from request should be set", func(t *testing.T) { + t.Parallel() client := AcquireClient() req := AcquireRequest(). SetURL("http://example.com/api/v1"). @@ -146,6 +159,7 @@ func Test_Parser_Request_URL(t *testing.T) { }) t.Run("query params should be merged", func(t *testing.T) { + t.Parallel() client := AcquireClient(). SetParam("bar", "foo1") req := AcquireRequest(). @@ -177,6 +191,7 @@ func Test_Parser_Request_Header(t *testing.T) { t.Parallel() t.Run("client header should be set", func(t *testing.T) { + t.Parallel() client := AcquireClient(). SetHeaders(map[string]string{ fiber.HeaderContentType: "application/json", @@ -190,6 +205,7 @@ func Test_Parser_Request_Header(t *testing.T) { }) t.Run("request header should be set", func(t *testing.T) { + t.Parallel() client := AcquireClient() req := AcquireRequest(). @@ -203,6 +219,7 @@ func Test_Parser_Request_Header(t *testing.T) { }) t.Run("request header should override client header", func(t *testing.T) { + t.Parallel() client := AcquireClient(). SetHeader(fiber.HeaderContentType, "application/xml") @@ -215,6 +232,7 @@ func Test_Parser_Request_Header(t *testing.T) { }) t.Run("auto set json header", func(t *testing.T) { + t.Parallel() type jsonData struct { Name string `json:"name"` } @@ -230,6 +248,7 @@ func Test_Parser_Request_Header(t *testing.T) { }) t.Run("auto set xml header", func(t *testing.T) { + t.Parallel() type xmlData struct { XMLName xml.Name `xml:"body"` Name string `xml:"name"` @@ -246,6 +265,7 @@ func Test_Parser_Request_Header(t *testing.T) { }) t.Run("auto set form data header", func(t *testing.T) { + t.Parallel() client := AcquireClient() req := AcquireRequest(). SetFormDatas(map[string]string{ @@ -259,6 +279,7 @@ func Test_Parser_Request_Header(t *testing.T) { }) t.Run("auto set file header", func(t *testing.T) { + t.Parallel() client := AcquireClient() req := AcquireRequest(). AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))). @@ -271,6 +292,7 @@ func Test_Parser_Request_Header(t *testing.T) { }) t.Run("ua should have default value", func(t *testing.T) { + t.Parallel() client := AcquireClient() req := AcquireRequest() @@ -280,6 +302,7 @@ func Test_Parser_Request_Header(t *testing.T) { }) t.Run("ua in client should be set", func(t *testing.T) { + t.Parallel() client := AcquireClient().SetUserAgent("foo") req := AcquireRequest() @@ -289,6 +312,7 @@ func Test_Parser_Request_Header(t *testing.T) { }) t.Run("ua in request should have higher level", func(t *testing.T) { + t.Parallel() client := AcquireClient().SetUserAgent("foo") req := AcquireRequest().SetUserAgent("bar") @@ -298,6 +322,7 @@ func Test_Parser_Request_Header(t *testing.T) { }) t.Run("referer in client should be set", func(t *testing.T) { + t.Parallel() client := AcquireClient().SetReferer("https://example.com") req := AcquireRequest() @@ -307,6 +332,7 @@ func Test_Parser_Request_Header(t *testing.T) { }) t.Run("referer in request should have higher level", func(t *testing.T) { + t.Parallel() client := AcquireClient().SetReferer("http://example.com") req := AcquireRequest().SetReferer("https://example.com") @@ -316,6 +342,7 @@ func Test_Parser_Request_Header(t *testing.T) { }) t.Run("client cookie should be set", func(t *testing.T) { + t.Parallel() client := AcquireClient(). SetCookie("foo", "bar"). SetCookies(map[string]string{ @@ -334,6 +361,7 @@ func Test_Parser_Request_Header(t *testing.T) { }) t.Run("request cookie should be set", func(t *testing.T) { + t.Parallel() type cookies struct { Foo string `cookie:"foo"` Bar int `cookie:"bar"` @@ -355,6 +383,7 @@ func Test_Parser_Request_Header(t *testing.T) { }) t.Run("request cookie will override client cookie", func(t *testing.T) { + t.Parallel() type cookies struct { Foo string `cookie:"foo"` Bar int `cookie:"bar"` @@ -385,6 +414,7 @@ func Test_Parser_Request_Body(t *testing.T) { t.Parallel() t.Run("json body", func(t *testing.T) { + t.Parallel() type jsonData struct { Name string `json:"name"` } @@ -400,6 +430,7 @@ func Test_Parser_Request_Body(t *testing.T) { }) t.Run("xml body", func(t *testing.T) { + t.Parallel() type xmlData struct { XMLName xml.Name `xml:"body"` Name string `xml:"name"` @@ -416,6 +447,7 @@ func Test_Parser_Request_Body(t *testing.T) { }) t.Run("form data body", func(t *testing.T) { + t.Parallel() client := AcquireClient() req := AcquireRequest(). SetFormDatas(map[string]string{ @@ -428,6 +460,7 @@ func Test_Parser_Request_Body(t *testing.T) { }) t.Run("form data body error", func(t *testing.T) { + t.Parallel() client := AcquireClient() req := AcquireRequest(). SetFormDatas(map[string]string{ @@ -439,6 +472,7 @@ func Test_Parser_Request_Body(t *testing.T) { }) t.Run("file body", func(t *testing.T) { + t.Parallel() client := AcquireClient() req := AcquireRequest(). AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))) @@ -450,6 +484,7 @@ func Test_Parser_Request_Body(t *testing.T) { }) t.Run("file and form data", func(t *testing.T) { + t.Parallel() client := AcquireClient() req := AcquireRequest(). AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))). @@ -463,6 +498,7 @@ func Test_Parser_Request_Body(t *testing.T) { }) t.Run("raw body", func(t *testing.T) { + t.Parallel() client := AcquireClient() req := AcquireRequest(). SetRawBody([]byte("hello world")) @@ -473,6 +509,7 @@ func Test_Parser_Request_Body(t *testing.T) { }) t.Run("raw body error", func(t *testing.T) { + t.Parallel() client := AcquireClient() req := AcquireRequest(). SetRawBody([]byte("hello world")) @@ -485,13 +522,14 @@ func Test_Parser_Request_Body(t *testing.T) { } func Test_Client_Logger_Debug(t *testing.T) { + t.Parallel() app := fiber.New() app.Get("/", func(c fiber.Ctx) error { return c.SendString("response") }) go func() { - require.Nil(t, app.Listen(":3000", fiber.ListenConfig{ + require.NoError(t, app.Listen(":3000", fiber.ListenConfig{ DisableStartupMessage: true, })) }() @@ -517,13 +555,14 @@ func Test_Client_Logger_Debug(t *testing.T) { } func Test_Client_Logger_DisableDebug(t *testing.T) { + t.Parallel() app := fiber.New() app.Get("/", func(c fiber.Ctx) error { return c.SendString("response") }) go func() { - require.Nil(t, app.Listen(":3000", fiber.ListenConfig{ + require.NoError(t, app.Listen(":3000", fiber.ListenConfig{ DisableStartupMessage: true, })) }() diff --git a/client/request_test.go b/client/request_test.go index d26133e76f..8bfe3cbbe2 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -79,25 +79,28 @@ func Test_Request_Header(t *testing.T) { t.Parallel() t.Run("add header", func(t *testing.T) { + t.Parallel() req := AcquireRequest() req.AddHeader("foo", "bar").AddHeader("foo", "fiber") res := req.Header("foo") - require.Equal(t, 2, len(res)) + require.Len(t, res, 2) require.Equal(t, "bar", res[0]) require.Equal(t, "fiber", res[1]) }) t.Run("set header", func(t *testing.T) { + t.Parallel() req := AcquireRequest() req.AddHeader("foo", "bar").SetHeader("foo", "fiber") res := req.Header("foo") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) }) t.Run("add headers", func(t *testing.T) { + t.Parallel() req := AcquireRequest() req.SetHeader("foo", "bar"). AddHeaders(map[string][]string{ @@ -106,17 +109,18 @@ func Test_Request_Header(t *testing.T) { }) res := req.Header("foo") - require.Equal(t, 3, len(res)) + require.Len(t, res, 3) require.Equal(t, "bar", res[0]) require.Equal(t, "buaa", res[1]) require.Equal(t, "fiber", res[2]) res = req.Header("bar") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) t.Run("set headers", func(t *testing.T) { + t.Parallel() req := AcquireRequest() req.SetHeader("foo", "bar"). SetHeaders(map[string]string{ @@ -125,11 +129,11 @@ func Test_Request_Header(t *testing.T) { }) res := req.Header("foo") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) res = req.Header("bar") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) } @@ -138,25 +142,28 @@ func Test_Request_QueryParam(t *testing.T) { t.Parallel() t.Run("add param", func(t *testing.T) { + t.Parallel() req := AcquireRequest() req.AddParam("foo", "bar").AddParam("foo", "fiber") res := req.Param("foo") - require.Equal(t, 2, len(res)) + require.Len(t, res, 2) require.Equal(t, "bar", res[0]) require.Equal(t, "fiber", res[1]) }) t.Run("set param", func(t *testing.T) { + t.Parallel() req := AcquireRequest() req.AddParam("foo", "bar").SetParam("foo", "fiber") res := req.Param("foo") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) }) t.Run("add params", func(t *testing.T) { + t.Parallel() req := AcquireRequest() req.SetParam("foo", "bar"). AddParams(map[string][]string{ @@ -165,17 +172,18 @@ func Test_Request_QueryParam(t *testing.T) { }) res := req.Param("foo") - require.Equal(t, 3, len(res)) + require.Len(t, res, 3) require.Equal(t, "bar", res[0]) require.Equal(t, "buaa", res[1]) require.Equal(t, "fiber", res[2]) res = req.Param("bar") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) t.Run("set headers", func(t *testing.T) { + t.Parallel() req := AcquireRequest() req.SetParam("foo", "bar"). SetParams(map[string]string{ @@ -184,11 +192,11 @@ func Test_Request_QueryParam(t *testing.T) { }) res := req.Param("foo") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) res = req.Param("bar") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) @@ -214,31 +222,32 @@ func Test_Request_QueryParam(t *testing.T) { TIntSlice: []int{1, 2}, }) - require.Equal(t, 0, len(p.Param("unexport"))) + require.Empty(t, p.Param("unexport")) - require.Equal(t, 1, len(p.Param("TInt"))) + require.Len(t, p.Param("TInt"), 1) require.Equal(t, "5", p.Param("TInt")[0]) - require.Equal(t, 1, len(p.Param("TString"))) + require.Len(t, p.Param("TString"), 1) require.Equal(t, "string", p.Param("TString")[0]) - require.Equal(t, 1, len(p.Param("TFloat"))) + require.Len(t, p.Param("TFloat"), 1) require.Equal(t, "3.1", p.Param("TFloat")[0]) - require.Equal(t, 1, len(p.Param("TBool"))) + require.Len(t, p.Param("TBool"), 1) tslice := p.Param("TSlice") - require.Equal(t, 2, len(tslice)) + require.Len(t, tslice, 2) require.Equal(t, "bar", tslice[0]) require.Equal(t, "foo", tslice[1]) tint := p.Param("TSlice") - require.Equal(t, 2, len(tint)) + require.Len(t, tint, 2) require.Equal(t, "bar", tint[0]) require.Equal(t, "foo", tint[1]) }) t.Run("del params", func(t *testing.T) { + t.Parallel() req := AcquireRequest() req.SetParam("foo", "bar"). SetParams(map[string]string{ @@ -247,10 +256,10 @@ func Test_Request_QueryParam(t *testing.T) { }).DelParams("foo", "bar") res := req.Param("foo") - require.Equal(t, 0, len(res)) + require.Empty(t, res) res = req.Param("bar") - require.Equal(t, 0, len(res)) + require.Empty(t, res) }) } @@ -278,6 +287,7 @@ func Test_Request_Cookie(t *testing.T) { t.Parallel() t.Run("set cookie", func(t *testing.T) { + t.Parallel() req := AcquireRequest(). SetCookie("foo", "bar") require.Equal(t, "bar", req.Cookie("foo")) @@ -287,6 +297,7 @@ func Test_Request_Cookie(t *testing.T) { }) t.Run("set cookies", func(t *testing.T) { + t.Parallel() req := AcquireRequest(). SetCookies(map[string]string{ "foo": "bar", @@ -303,6 +314,7 @@ func Test_Request_Cookie(t *testing.T) { }) t.Run("set cookies with struct", func(t *testing.T) { + t.Parallel() type args struct { CookieInt int `cookie:"int"` CookieString string `cookie:"string"` @@ -318,6 +330,7 @@ func Test_Request_Cookie(t *testing.T) { }) t.Run("del cookies", func(t *testing.T) { + t.Parallel() req := AcquireRequest(). SetCookies(map[string]string{ "foo": "bar", @@ -336,6 +349,7 @@ func Test_Request_PathParam(t *testing.T) { t.Parallel() t.Run("set path param", func(t *testing.T) { + t.Parallel() req := AcquireRequest(). SetPathParam("foo", "bar") require.Equal(t, "bar", req.PathParam("foo")) @@ -345,6 +359,7 @@ func Test_Request_PathParam(t *testing.T) { }) t.Run("set path params", func(t *testing.T) { + t.Parallel() req := AcquireRequest(). SetPathParams(map[string]string{ "foo": "bar", @@ -361,6 +376,7 @@ func Test_Request_PathParam(t *testing.T) { }) t.Run("set path params with struct", func(t *testing.T) { + t.Parallel() type args struct { CookieInt int `path:"int"` CookieString string `path:"string"` @@ -376,6 +392,7 @@ func Test_Request_PathParam(t *testing.T) { }) t.Run("del path params", func(t *testing.T) { + t.Parallel() req := AcquireRequest(). SetPathParams(map[string]string{ "foo": "bar", @@ -390,6 +407,7 @@ func Test_Request_PathParam(t *testing.T) { }) t.Run("clear path params", func(t *testing.T) { + t.Parallel() req := AcquireRequest(). SetPathParams(map[string]string{ "foo": "bar", @@ -408,25 +426,28 @@ func Test_Request_FormData(t *testing.T) { t.Parallel() t.Run("add form data", func(t *testing.T) { + t.Parallel() req := AcquireRequest() req.AddFormData("foo", "bar").AddFormData("foo", "fiber") res := req.FormData("foo") - require.Equal(t, 2, len(res)) + require.Len(t, res, 2) require.Equal(t, "bar", res[0]) require.Equal(t, "fiber", res[1]) }) t.Run("set param", func(t *testing.T) { + t.Parallel() req := AcquireRequest() req.AddFormData("foo", "bar").SetFormData("foo", "fiber") res := req.FormData("foo") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) }) t.Run("add params", func(t *testing.T) { + t.Parallel() req := AcquireRequest() req.SetFormData("foo", "bar"). AddFormDatas(map[string][]string{ @@ -435,17 +456,18 @@ func Test_Request_FormData(t *testing.T) { }) res := req.FormData("foo") - require.Equal(t, 3, len(res)) + require.Len(t, res, 3) require.Equal(t, "bar", res[0]) require.Equal(t, "buaa", res[1]) require.Equal(t, "fiber", res[2]) res = req.FormData("bar") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) t.Run("set headers", func(t *testing.T) { + t.Parallel() req := AcquireRequest() req.SetFormData("foo", "bar"). SetFormDatas(map[string]string{ @@ -454,11 +476,11 @@ func Test_Request_FormData(t *testing.T) { }) res := req.FormData("foo") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) res = req.FormData("bar") - require.Equal(t, 1, len(res)) + require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) @@ -484,32 +506,32 @@ func Test_Request_FormData(t *testing.T) { TIntSlice: []int{1, 2}, }) - require.Equal(t, 0, len(p.FormData("unexport"))) + require.Empty(t, p.FormData("unexport")) - require.Equal(t, 1, len(p.FormData("TInt"))) + require.Len(t, p.FormData("TInt"), 1) require.Equal(t, "5", p.FormData("TInt")[0]) - require.Equal(t, 1, len(p.FormData("TString"))) + require.Len(t, p.FormData("TString"), 1) require.Equal(t, "string", p.FormData("TString")[0]) - require.Equal(t, 1, len(p.FormData("TFloat"))) + require.Len(t, p.FormData("TFloat"), 1) require.Equal(t, "3.1", p.FormData("TFloat")[0]) - require.Equal(t, 1, len(p.FormData("TBool"))) + require.Len(t, p.FormData("TBool"), 1) tslice := p.FormData("TSlice") - require.Equal(t, 2, len(tslice)) + require.Len(t, tslice, 2) require.Equal(t, "bar", tslice[0]) require.Equal(t, "foo", tslice[1]) tint := p.FormData("TSlice") - require.Equal(t, 2, len(tint)) + require.Len(t, tint, 2) require.Equal(t, "bar", tint[0]) require.Equal(t, "foo", tint[1]) - }) t.Run("del params", func(t *testing.T) { + t.Parallel() req := AcquireRequest() req.SetFormData("foo", "bar"). SetFormDatas(map[string]string{ @@ -518,10 +540,10 @@ func Test_Request_FormData(t *testing.T) { }).DelFormDatas("foo", "bar") res := req.FormData("foo") - require.Equal(t, 0, len(res)) + require.Empty(t, res) res = req.FormData("bar") - require.Equal(t, 0, len(res)) + require.Empty(t, res) }) } @@ -529,6 +551,7 @@ func Test_Request_File(t *testing.T) { t.Parallel() t.Run("add file", func(t *testing.T) { + t.Parallel() req := AcquireRequest(). AddFile("../.github/index.html"). AddFiles(AcquireFile(SetFileName("tmp.txt"))) @@ -541,6 +564,7 @@ func Test_Request_File(t *testing.T) { }) t.Run("add file by reader", func(t *testing.T) { + t.Parallel() req := AcquireRequest(). AddFileWithReader("tmp.txt", io.NopCloser(strings.NewReader("world"))) @@ -552,6 +576,7 @@ func Test_Request_File(t *testing.T) { }) t.Run("add files", func(t *testing.T) { + t.Parallel() req := AcquireRequest(). AddFiles(AcquireFile(SetFileName("tmp.txt")), AcquireFile(SetFileName("foo.txt"))) @@ -681,6 +706,7 @@ func Test_Request_Put(t *testing.T) { resp.Close() } } + func Test_Request_Delete(t *testing.T) { t.Parallel() @@ -788,6 +814,7 @@ func Test_Request_Patch(t *testing.T) { } func Test_Request_Header_With_Server(t *testing.T) { + t.Parallel() handler := func(c fiber.Ctx) error { c.Request().Header.VisitAll(func(key, value []byte) { if k := string(key); k == "K1" || k == "K2" { @@ -821,10 +848,12 @@ func Test_Request_UserAgent_With_Server(t *testing.T) { } t.Run("default", func(t *testing.T) { + t.Parallel() testRequest(t, handler, func(agent *Request) {}, defaultUserAgent, 5) }) t.Run("custom", func(t *testing.T) { + t.Parallel() testRequest(t, handler, func(agent *Request) { agent.SetUserAgent("ua") }, "ua", 5) @@ -832,6 +861,7 @@ func Test_Request_UserAgent_With_Server(t *testing.T) { } func Test_Request_Cookie_With_Server(t *testing.T) { + t.Parallel() handler := func(c fiber.Ctx) error { return c.SendString( c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) @@ -850,6 +880,7 @@ func Test_Request_Cookie_With_Server(t *testing.T) { } func Test_Request_Referer_With_Server(t *testing.T) { + t.Parallel() handler := func(c fiber.Ctx) error { return c.Send(c.Request().Header.Referer()) } @@ -862,6 +893,7 @@ func Test_Request_Referer_With_Server(t *testing.T) { } func Test_Request_QueryString_With_Server(t *testing.T) { + t.Parallel() handler := func(c fiber.Ctx) error { return c.Send(c.Request().URI().QueryString()) } @@ -898,6 +930,7 @@ func Test_Request_Body_With_Server(t *testing.T) { t.Parallel() t.Run("json body", func(t *testing.T) { + t.Parallel() testRequest(t, func(c fiber.Ctx) error { require.Equal(t, "application/json", string(c.Request().Header.ContentType())) @@ -913,6 +946,7 @@ func Test_Request_Body_With_Server(t *testing.T) { }) t.Run("xml body", func(t *testing.T) { + t.Parallel() testRequest(t, func(c fiber.Ctx) error { require.Equal(t, "application/xml", string(c.Request().Header.ContentType())) @@ -931,6 +965,7 @@ func Test_Request_Body_With_Server(t *testing.T) { }) t.Run("formdata", func(t *testing.T) { + t.Parallel() testRequest(t, func(c fiber.Ctx) error { require.Equal(t, fiber.MIMEApplicationForm, string(c.Request().Header.ContentType())) @@ -972,7 +1007,7 @@ func Test_Request_Body_With_Server(t *testing.T) { SetFileReader(io.NopCloser(strings.NewReader("world"))), )) - require.Equal(t, req.Boundary(), "myBoundary") + require.Equal(t, "myBoundary", req.Boundary()) resp, err := req.Post("http://exmaple.com") require.NoError(t, err) @@ -993,7 +1028,7 @@ func Test_Request_Body_With_Server(t *testing.T) { fh1, err := c.FormFile("field1") require.NoError(t, err) - require.Equal(t, fh1.Filename, "name") + require.Equal(t, "name", fh1.Filename) buf := make([]byte, fh1.Size) f, err := fh1.Open() require.NoError(t, err) @@ -1065,6 +1100,7 @@ func Test_Request_Body_With_Server(t *testing.T) { }) t.Run("raw body", func(t *testing.T) { + t.Parallel() testRequest(t, func(c fiber.Ctx) error { return c.SendString(string(c.Request().Body())) @@ -1078,7 +1114,9 @@ func Test_Request_Body_With_Server(t *testing.T) { } func Test_Request_Error_Body_With_Server(t *testing.T) { + t.Parallel() t.Run("json error", func(t *testing.T) { + t.Parallel() testRequestFail(t, func(c fiber.Ctx) error { return c.SendString("") @@ -1091,6 +1129,7 @@ func Test_Request_Error_Body_With_Server(t *testing.T) { }) t.Run("xml error", func(t *testing.T) { + t.Parallel() testRequestFail(t, func(c fiber.Ctx) error { return c.SendString("") @@ -1157,9 +1196,10 @@ func Test_Request_MaxRedirects(t *testing.T) { return c.SendString("redirect") }) - go func() { require.Equal(t, nil, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) }() + go func() { require.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) }() t.Run("success", func(t *testing.T) { + t.Parallel() resp, err := AcquireRequest(). SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }). SetMaxRedirects(1). @@ -1175,6 +1215,7 @@ func Test_Request_MaxRedirects(t *testing.T) { }) t.Run("error", func(t *testing.T) { + t.Parallel() resp, err := AcquireRequest(). SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }). SetMaxRedirects(1). @@ -1208,6 +1249,7 @@ func Test_SetValWithStruct(t *testing.T) { } t.Run("the struct should be applied", func(t *testing.T) { + t.Parallel() p := &QueryParam{ Args: fasthttp.AcquireArgs(), } @@ -1262,10 +1304,10 @@ func Test_SetValWithStruct(t *testing.T) { } return false }()) - }) t.Run("the pointer of a struct should be applied", func(t *testing.T) { + t.Parallel() p := &QueryParam{ Args: fasthttp.AcquireArgs(), } @@ -1318,10 +1360,10 @@ func Test_SetValWithStruct(t *testing.T) { } return false }()) - }) t.Run("the zero val should be ignore", func(t *testing.T) { + t.Parallel() p := &QueryParam{ Args: fasthttp.AcquireArgs(), } @@ -1334,11 +1376,12 @@ func Test_SetValWithStruct(t *testing.T) { require.Equal(t, "", string(p.Peek("TInt"))) require.Equal(t, "", string(p.Peek("TString"))) require.Equal(t, "", string(p.Peek("TFloat"))) - require.Equal(t, 0, len(p.PeekMulti("TSlice"))) - require.Equal(t, 0, len(p.PeekMulti("int_slice"))) + require.Empty(t, p.PeekMulti("TSlice")) + require.Empty(t, p.PeekMulti("int_slice")) }) t.Run("error type should ignore", func(t *testing.T) { + t.Parallel() p := &QueryParam{ Args: fasthttp.AcquireArgs(), } @@ -1419,7 +1462,6 @@ func Benchmark_SetValWithStruct(b *testing.B) { } return false }()) - }) b.Run("the pointer of a struct should be applied", func(b *testing.B) { @@ -1480,7 +1522,6 @@ func Benchmark_SetValWithStruct(b *testing.B) { } return false }()) - }) b.Run("the zero val should be ignore", func(b *testing.B) { diff --git a/client/response.go b/client/response.go index d21e218102..47b3da82bc 100644 --- a/client/response.go +++ b/client/response.go @@ -93,7 +93,7 @@ func (r *Response) Save(v any) error { return err } - if err = os.MkdirAll(dir, 0750); err != nil { + if err = os.MkdirAll(dir, 0o750); err != nil { return err } } @@ -166,8 +166,12 @@ var responsePool = &sync.Pool{ // // The returned response may be returned to the pool with ReleaseResponse when no longer needed. // This allows reducing GC load. -func AcquireResponse() (resp *Response) { - return responsePool.Get().(*Response) +func AcquireResponse() *Response { + resp, ok := responsePool.Get().(*Response) + if !ok { + panic("unexpected type from responsePool.Get()") + } + return resp } // ReleaseResponse returns the object acquired via AcquireResponse to the pool. diff --git a/client/response_test.go b/client/response_test.go index 4839615cf6..7db76a07f6 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -94,6 +94,7 @@ func Test_Response_Protocol(t *testing.T) { t.Parallel() t.Run("http", func(t *testing.T) { + t.Parallel() app, ln, start := createHelperServer(t) app.Get("/", func(c fiber.Ctx) error { return c.SendString("foo") @@ -113,10 +114,10 @@ func Test_Response_Protocol(t *testing.T) { t.Parallel() serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() - require.Nil(t, err) + require.NoError(t, err) ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") - require.Nil(t, err) + require.NoError(t, err) ln = tls.NewListener(ln, serverTLSConf) @@ -126,7 +127,7 @@ func Test_Response_Protocol(t *testing.T) { }) go func() { - require.Nil(t, app.Listener(ln, fiber.ListenConfig{ + require.NoError(t, app.Listener(ln, fiber.ListenConfig{ DisableStartupMessage: true, })) }() @@ -134,7 +135,7 @@ func Test_Response_Protocol(t *testing.T) { client := AcquireClient() resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, clientTLSConf, client.TLSConfig()) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "https", resp.String()) @@ -202,6 +203,7 @@ func Test_Response_Body(t *testing.T) { go start() t.Run("raw body", func(t *testing.T) { + t.Parallel() resp, err := AcquireRequest(). SetDial(ln). Get("http://example.com") @@ -212,6 +214,7 @@ func Test_Response_Body(t *testing.T) { }) t.Run("string body", func(t *testing.T) { + t.Parallel() resp, err := AcquireRequest(). SetDial(ln). Get("http://example.com") @@ -222,6 +225,7 @@ func Test_Response_Body(t *testing.T) { }) t.Run("json body", func(t *testing.T) { + t.Parallel() type body struct { Status string `json:"status"` } @@ -240,6 +244,7 @@ func Test_Response_Body(t *testing.T) { }) t.Run("xml body", func(t *testing.T) { + t.Parallel() type body struct { Name xml.Name `xml:"status"` Status string `xml:"name"` @@ -260,7 +265,7 @@ func Test_Response_Body(t *testing.T) { } func Test_Response_Save(t *testing.T) { - + t.Parallel() app, ln, start := createHelperServer(t) app.Get("/json", func(c fiber.Ctx) error { return c.SendString("{\"status\":\"success\"}") @@ -270,6 +275,7 @@ func Test_Response_Save(t *testing.T) { time.Sleep(300 * time.Millisecond) t.Run("file path", func(t *testing.T) { + t.Parallel() resp, err := AcquireRequest(). SetDial(ln). Get("http://example.com/json") @@ -300,6 +306,7 @@ func Test_Response_Save(t *testing.T) { }) t.Run("io.Writer", func(t *testing.T) { + t.Parallel() resp, err := AcquireRequest(). SetDial(ln). Get("http://example.com/json") @@ -314,6 +321,7 @@ func Test_Response_Save(t *testing.T) { }) t.Run("error type", func(t *testing.T) { + t.Parallel() resp, err := AcquireRequest(). SetDial(ln). Get("http://example.com/json") diff --git a/log/default.go b/log/default.go index 690f734602..b2cae0665d 100644 --- a/log/default.go +++ b/log/default.go @@ -32,10 +32,10 @@ func (l *defaultLogger) privateLog(lv Level, fmtArgs []any) { if lv == LevelPanic { panic(buf.String()) - } else { - _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error } + _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error + buf.Reset() bytebufferpool.Put(buf) if lv == LevelFatal { @@ -61,10 +61,10 @@ func (l *defaultLogger) privateLogf(lv Level, format string, fmtArgs []any) { if lv == LevelPanic { panic(buf.String()) - } else { - _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error } + _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error + buf.Reset() bytebufferpool.Put(buf) if lv == LevelFatal { @@ -108,10 +108,10 @@ func (l *defaultLogger) privateLogw(lv Level, format string, keysAndValues []any if lv == LevelPanic { panic(buf.String()) - } else { - _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error } + _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error + buf.Reset() bytebufferpool.Put(buf) if lv == LevelFatal { diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 3594638931..978d33523d 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -3,9 +3,6 @@ package proxy import ( "crypto/tls" "errors" - "github.com/gofiber/fiber/v3" - fiberClient "github.com/gofiber/fiber/v3/client" - "github.com/stretchr/testify/require" "io" "net" "net/http/httptest" @@ -13,6 +10,10 @@ import ( "testing" "time" + "github.com/gofiber/fiber/v3" + clientpkg "github.com/gofiber/fiber/v3/client" + "github.com/stretchr/testify/require" + "github.com/gofiber/fiber/v3/internal/tlstest" "github.com/valyala/fasthttp" ) @@ -137,8 +138,8 @@ func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) { })) }() - client := fiberClient.AcquireClient() - defer fiberClient.ReleaseClient(client) + client := clientpkg.AcquireClient() + defer clientpkg.ReleaseClient(client) client.SetTLSConfig(clientTLSConf) resp, err := client.Get("https://" + addr + "/tlsbalancer") @@ -176,10 +177,10 @@ func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) { })) }() - client := fiberClient.AcquireClient() - defer fiberClient.ReleaseClient(client) + client := clientpkg.AcquireClient() + defer clientpkg.ReleaseClient(client) client.SetTimeout(5 * time.Second) - client.TLSConfig().InsecureSkipVerify = true //nolint:gosec // We're in a test func, so this is fine + client.TLSConfig().InsecureSkipVerify = true resp, err := client.Get("https://" + proxyAddr) require.NoError(t, err) @@ -242,8 +243,8 @@ func Test_Proxy_Forward_WithClient_TLSConfig(t *testing.T) { })) }() - client := fiberClient.AcquireClient() - defer fiberClient.ReleaseClient(client) + client := clientpkg.AcquireClient() + defer clientpkg.ReleaseClient(client) client.SetTLSConfig(clientTLSConf) resp, err := client.Get("https://" + addr) @@ -482,7 +483,7 @@ func Test_Proxy_DoTimeout_RestoreOriginalURL(t *testing.T) { }) resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) - require.NoError(t, nil, err1) + require.NotErrorIs(t, nil, err1) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "proxied", string(body)) @@ -598,8 +599,8 @@ func Test_Proxy_Forward_Global_Client(t *testing.T) { })) }() - client := fiberClient.AcquireClient() - defer fiberClient.ReleaseClient(client) + client := clientpkg.AcquireClient() + defer clientpkg.ReleaseClient(client) resp, err := client.Get("http://" + addr) require.NoError(t, err) @@ -631,8 +632,8 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) { })) }() - client := fiberClient.AcquireClient() - defer fiberClient.ReleaseClient(client) + client := clientpkg.AcquireClient() + defer clientpkg.ReleaseClient(client) resp, err := client.Get("http://" + addr) require.NoError(t, err) @@ -710,8 +711,8 @@ func Test_Proxy_Domain_Forward_Local(t *testing.T) { })) }() - client := fiberClient.AcquireClient() - defer fiberClient.ReleaseClient(client) + client := clientpkg.AcquireClient() + defer clientpkg.ReleaseClient(client) resp, err := client.Get("http://" + localDomain + "/test?query_test=true") require.NoError(t, err) @@ -739,5 +740,5 @@ func Test_Proxy_Balancer_Forward_Local(t *testing.T) { b, err := io.ReadAll(resp.Body) require.NoError(t, err) - require.Equal(t, string(b), "forwarded") + require.Equal(t, "forwarded", string(b)) } From 6c413d1c88e6f6eb66b0049f08a06610309e7bec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Efe=20=C3=87etin?= Date: Sat, 10 Feb 2024 02:13:45 +0300 Subject: [PATCH 089/118] Update helper_test.go --- client/helper_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/client/helper_test.go b/client/helper_test.go index 92a120d4d3..75d438c34a 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -13,7 +13,8 @@ func createHelperServer(t testing.TB, config ...fiber.Config) (*fiber.App, func( t.Helper() ln := fasthttputil.NewInmemoryListener() - + defer ln.Close() + var cfg fiber.Config if len(config) > 0 { cfg = config[0] From 7a6ca57e6b885dc03740bab5f540ef8636d67978 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Werner?= Date: Sun, 11 Feb 2024 00:54:18 +0100 Subject: [PATCH 090/118] add more test cases --- client/client.go | 8 ++++++++ client/client_test.go | 28 ++++++++++++++++++++++++++++ client/helper_test.go | 4 ++-- client/request.go | 5 +++++ 4 files changed, 43 insertions(+), 2 deletions(-) diff --git a/client/client.go b/client/client.go index 78cc8eb18e..a4c142eede 100644 --- a/client/client.go +++ b/client/client.go @@ -553,6 +553,14 @@ func (c *Client) Patch(url string, cfg ...Config) (*Response, error) { return req.Patch(url) } +// Custom provide an API like axios which send custom request. +func (c *Client) Custom(url string, method string, cfg ...Config) (*Response, error) { + req := AcquireRequest().SetClient(c) + setConfigToRequest(req, cfg...) + + return req.Custom(url, method) +} + // Reset clear Client object func (c *Client) Reset() { c.baseURL = "" diff --git a/client/client_test.go b/client/client_test.go index cb165376c9..8b373bdd36 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -8,6 +8,7 @@ import ( "net" "os" "reflect" + "sync" "testing" "time" @@ -159,6 +160,33 @@ func Test_Client_Unsupported_Protocol(t *testing.T) { require.ErrorIs(t, err, ErrURLFormat) } +func Test_Client_ConcurrencyRequests(t *testing.T) { + t.Parallel() + + app, dial, start := createHelperServer(t) + app.All("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname() + " " + c.Method()) + }) + go start() + + wg := sync.WaitGroup{} + for i := 0; i < 5; i++ { + for _, method := range []string{"GET", "POST", "PUT", "DELETE", "PATCH"} { + wg.Add(1) + go func(m string) { + defer wg.Done() + resp, err := C().Custom("http://example.com", m, Config{ + Dial: dial, + }) + require.NoError(t, err) + require.Equal(t, "example.com "+m, utils.UnsafeString(resp.RawResponse.Body())) + }(method) + } + } + + wg.Wait() +} + func Test_Get(t *testing.T) { t.Parallel() diff --git a/client/helper_test.go b/client/helper_test.go index 75d438c34a..e5cc0508b5 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -13,8 +13,7 @@ func createHelperServer(t testing.TB, config ...fiber.Config) (*fiber.App, func( t.Helper() ln := fasthttputil.NewInmemoryListener() - defer ln.Close() - + var cfg fiber.Config if len(config) > 0 { cfg = config[0] @@ -27,6 +26,7 @@ func createHelperServer(t testing.TB, config ...fiber.Config) (*fiber.App, func( }, func() { require.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) } + // TODO: add closer fn } func testRequest(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted string, count ...int) { diff --git a/client/request.go b/client/request.go index ef0c0bbe33..99ec40f738 100644 --- a/client/request.go +++ b/client/request.go @@ -532,6 +532,11 @@ func (r *Request) Patch(url string) (*Response, error) { return r.SetURL(url).SetMethod(fiber.MethodPatch).Send() } +// Custom Send custom request. +func (r *Request) Custom(url string, method string) (*Response, error) { + return r.SetURL(url).SetMethod(method).Send() +} + // Send a request. func (r *Request) Send() (*Response, error) { r.checkClient() From de0ccb899d72fe27baee9b09e3e8aa3b20262e1f Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sun, 18 Feb 2024 17:22:25 +0300 Subject: [PATCH 091/118] add hostclient pool --- client/client.go | 54 ++++++++++++++++++++++++++++++++----------- client/client_test.go | 15 ------------ client/core.go | 23 ++++++++++-------- client/core_test.go | 6 ++--- 4 files changed, 57 insertions(+), 41 deletions(-) diff --git a/client/client.go b/client/client.go index a4c142eede..ed1f4d1cb1 100644 --- a/client/client.go +++ b/client/client.go @@ -35,8 +35,6 @@ var ( type Client struct { mu sync.RWMutex - host *fasthttp.HostClient - baseURL string userAgent string referer string @@ -111,17 +109,6 @@ func (c *Client) AddResponseHook(h ...ResponseHook) *Client { return c } -// HostClient returns host client in client. -func (c *Client) HostClient() *fasthttp.HostClient { - return c.host -} - -// SetHostClient sets host client in client. -func (c *Client) SetHostClient(host *fasthttp.HostClient) *Client { - c.host = host - return c -} - // JSONMarshal returns json marshal function in Core. func (c *Client) JSONMarshal() utils.JSONMarshal { return c.jsonMarshal @@ -669,7 +656,6 @@ var ( clientPool = &sync.Pool{ New: func() any { return &Client{ - host: &fasthttp.HostClient{}, header: &Header{ RequestHeader: &fasthttp.RequestHeader{}, }, @@ -773,3 +759,43 @@ func Options(url string, cfg ...Config) (*Response, error) { func Patch(url string, cfg ...Config) (*Response, error) { return C().Patch(url, cfg...) } + +var hostClienPool = &sync.Pool{ + New: func() any { + return &fasthttp.HostClient{} + }, +} + +// AcquireHostClient returns an empty HostClient object from the pool. +// +// The returned HostClient object may be returned to the pool with ReleaseHostClient when no longer needed. +// This allows reducing GC load. +func AcquireHostClient() *fasthttp.HostClient { + hostClient, ok := hostClienPool.Get().(*fasthttp.HostClient) + if !ok { + panic(fmt.Errorf("failed to type-assert to *fasthttp.HostClient")) + } + + return hostClient +} + +// ReleaseHostClient returns the object acquired via AcquireHostClient to the pool. +// +// Do not access the released HostClient object, otherwise data +func ReleaseHostClient(h *fasthttp.HostClient) { + // reset host client + h.Addr = "" + h.Name = "" + h.Dial = nil + h.MaxConns = 0 + h.MaxIdleConnDuration = 0 + h.ReadTimeout = 0 + h.WriteTimeout = 0 + h.ReadBufferSize = 0 + h.WriteBufferSize = 0 + h.DisableHeaderNamesNormalizing = false + h.DisablePathNormalizing = false + h.DisablePathNormalizing = false + + hostClienPool.Put(h) +} diff --git a/client/client_test.go b/client/client_test.go index 8b373bdd36..139af66bbc 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -13,7 +13,6 @@ import ( "time" "github.com/gofiber/fiber/v3/addon/retry" - "github.com/valyala/fasthttp" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/tlstest" @@ -117,20 +116,6 @@ func Test_Client_SetBaseURL(t *testing.T) { require.Equal(t, "http://example.com", client.BaseURL()) } -func Test_Client_SetHostClient(t *testing.T) { - t.Parallel() - - hostClient := &fasthttp.HostClient{} - hostClient.Name = "test" - - client := AcquireClient() - defer ReleaseClient(client) - - client.SetHostClient(hostClient) - - require.Equal(t, "test", client.HostClient().Name) -} - func Test_Client_Invalid_URL(t *testing.T) { t.Parallel() diff --git a/client/core.go b/client/core.go index bdd90692f7..db5085033f 100644 --- a/client/core.go +++ b/client/core.go @@ -52,6 +52,7 @@ func addMissingPort(addr string, isTLS bool) string { // `core` stores middleware and plugin definitions, // and defines the execution process type core struct { + host *fasthttp.HostClient client *Client req *Request ctx context.Context @@ -92,6 +93,8 @@ func (c *core) execFunc() (*Response, error) { c.req.RawRequest.CopyTo(reqv) cfg := c.getRetryConfig() + defer ReleaseHostClient(c.host) + go func() { c.client.mu.Lock() @@ -100,16 +103,16 @@ func (c *core) execFunc() (*Response, error) { if cfg != nil { err = retry.NewExponentialBackoff(*cfg).Retry(func() error { if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { - return c.client.host.DoRedirects(reqv, respv, c.req.maxRedirects) + return c.host.DoRedirects(reqv, respv, c.req.maxRedirects) } - return c.client.host.Do(reqv, respv) + return c.host.Do(reqv, respv) }) } else { if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { - err = c.client.host.DoRedirects(reqv, respv, c.req.maxRedirects) + err = c.host.DoRedirects(reqv, respv, c.req.maxRedirects) } else { - err = c.client.host.Do(reqv, respv) + err = c.host.Do(reqv, respv) } } defer func() { @@ -203,14 +206,14 @@ func (c *core) timeout() context.CancelFunc { // dial set dial in host. func (c *core) dial() { c.client.mu.Lock() - c.client.host.Dial = c.req.dial + c.host.Dial = c.req.dial c.client.mu.Unlock() } // tls sets tls config. func (c *core) tls() { c.client.mu.Lock() - c.client.host.TLSConfig = c.client.tlsConfig.Clone() + c.host.TLSConfig = c.client.tlsConfig.Clone() c.client.mu.Unlock() } @@ -231,8 +234,8 @@ func (c *core) proxy() error { } c.client.mu.Lock() - c.client.host.Addr = addMissingPort(string(rawURI.Host()), isTLS) - c.client.host.IsTLS = isTLS + c.host.Addr = addMissingPort(string(rawURI.Host()), isTLS) + c.host.IsTLS = isTLS c.client.mu.Unlock() return nil @@ -306,7 +309,9 @@ func releaseErrChan(ch chan error) { // newCore returns an empty core object. func newCore() *core { - c := &core{} + c := &core{ + host: AcquireHostClient(), + } return c } diff --git a/client/core_test.go b/client/core_test.go index b8d39b0480..b576d90294 100644 --- a/client/core_test.go +++ b/client/core_test.go @@ -87,7 +87,7 @@ func Test_Exec_Func(t *testing.T) { core.client = client core.req = req - core.client.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + core.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } req.RawRequest.SetRequestURI("http://example.com/normal") resp, err := core.execFunc() @@ -104,7 +104,7 @@ func Test_Exec_Func(t *testing.T) { core.client = client core.req = req - core.client.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + core.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } req.RawRequest.SetRequestURI("http://example.com/return-error") resp, err := core.execFunc() @@ -124,7 +124,7 @@ func Test_Exec_Func(t *testing.T) { core.client = client core.req = req - core.client.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + core.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } req.RawRequest.SetRequestURI("http://example.com/hang-up") _, err := core.execFunc() From 8a8086a7a806b095c6327b73e760d6f44a921c0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Werner?= Date: Tue, 20 Feb 2024 14:17:19 +0100 Subject: [PATCH 092/118] make it more thread safe -> there is still something which is shared between the requests --- client/core.go | 54 +++++++++++++++++---------------------------- client/core_test.go | 6 ++--- 2 files changed, 23 insertions(+), 37 deletions(-) diff --git a/client/core.go b/client/core.go index db5085033f..d47026a636 100644 --- a/client/core.go +++ b/client/core.go @@ -52,7 +52,6 @@ func addMissingPort(addr string, isTLS bool) string { // `core` stores middleware and plugin definitions, // and defines the execution process type core struct { - host *fasthttp.HostClient client *Client req *Request ctx context.Context @@ -93,26 +92,31 @@ func (c *core) execFunc() (*Response, error) { c.req.RawRequest.CopyTo(reqv) cfg := c.getRetryConfig() - defer ReleaseHostClient(c.host) + var host = AcquireHostClient() + err := c.configureHostClient(host) + if err != nil { + defer ReleaseHostClient(host) + return nil, err + } go func() { + defer ReleaseHostClient(host) c.client.mu.Lock() - var err error respv := fasthttp.AcquireResponse() if cfg != nil { err = retry.NewExponentialBackoff(*cfg).Retry(func() error { if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { - return c.host.DoRedirects(reqv, respv, c.req.maxRedirects) + return host.DoRedirects(reqv, respv, c.req.maxRedirects) } - return c.host.Do(reqv, respv) + return host.Do(reqv, respv) }) } else { if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { - err = c.host.DoRedirects(reqv, respv, c.req.maxRedirects) + err = host.DoRedirects(reqv, respv, c.req.maxRedirects) } else { - err = c.host.Do(reqv, respv) + err = host.Do(reqv, respv) } } defer func() { @@ -203,22 +207,14 @@ func (c *core) timeout() context.CancelFunc { return cancel } -// dial set dial in host. -func (c *core) dial() { - c.client.mu.Lock() - c.host.Dial = c.req.dial - c.client.mu.Unlock() -} - -// tls sets tls config. -func (c *core) tls() { +// configureHostClient set configureHostClient in host. +func (c *core) configureHostClient(hostClient *fasthttp.HostClient) error { + // tls and dial configuration c.client.mu.Lock() - c.host.TLSConfig = c.client.tlsConfig.Clone() + hostClient.TLSConfig = c.client.tlsConfig.Clone() + hostClient.Dial = c.req.dial c.client.mu.Unlock() -} -// proxy set proxy in host. -func (c *core) proxy() error { rawURI := c.req.RawRequest.URI() if c.client.proxyURL != "" { rawURI = fasthttp.AcquireURI() @@ -233,9 +229,10 @@ func (c *core) proxy() error { return ErrNotSupportSchema } + // proxy configuration c.client.mu.Lock() - c.host.Addr = addMissingPort(string(rawURI.Host()), isTLS) - c.host.IsTLS = isTLS + hostClient.Addr = addMissingPort(string(rawURI.Host()), isTLS) + hostClient.IsTLS = isTLS c.client.mu.Unlock() return nil @@ -260,15 +257,6 @@ func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Resp defer cancel() } - c.tls() - - c.dial() - - err = c.proxy() - if err != nil { - return nil, err - } - // Do http request resp, err := c.execFunc() if err != nil { @@ -309,9 +297,7 @@ func releaseErrChan(ch chan error) { // newCore returns an empty core object. func newCore() *core { - c := &core{ - host: AcquireHostClient(), - } + c := &core{} return c } diff --git a/client/core_test.go b/client/core_test.go index b576d90294..0b181a97c2 100644 --- a/client/core_test.go +++ b/client/core_test.go @@ -87,7 +87,7 @@ func Test_Exec_Func(t *testing.T) { core.client = client core.req = req - core.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + core.req.dial = func(addr string) (net.Conn, error) { return ln.Dial() } req.RawRequest.SetRequestURI("http://example.com/normal") resp, err := core.execFunc() @@ -104,7 +104,7 @@ func Test_Exec_Func(t *testing.T) { core.client = client core.req = req - core.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + core.req.dial = func(addr string) (net.Conn, error) { return ln.Dial() } req.RawRequest.SetRequestURI("http://example.com/return-error") resp, err := core.execFunc() @@ -124,7 +124,7 @@ func Test_Exec_Func(t *testing.T) { core.client = client core.req = req - core.host.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + core.req.dial = func(addr string) (net.Conn, error) { return ln.Dial() } req.RawRequest.SetRequestURI("http://example.com/hang-up") _, err := core.execFunc() From 1b42170ae22f7959b8269255ad16eba833768a42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Werner?= Date: Tue, 20 Feb 2024 14:34:17 +0100 Subject: [PATCH 093/118] fixed some golangci-lint errors --- client/client_test.go | 6 +++--- client/core_test.go | 6 +++--- client/helper_test.go | 9 ++------- client/hooks.go | 10 +++++----- client/hooks_test.go | 2 +- client/request.go | 14 +++++++------- client/request_test.go | 12 ++++++------ middleware/proxy/proxy_test.go | 2 +- 8 files changed, 28 insertions(+), 33 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index 139af66bbc..503f5980b3 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -64,7 +64,7 @@ func Test_Client_Marshal(t *testing.T) { t.Run("set json marshal", func(t *testing.T) { t.Parallel() client := AcquireClient(). - SetJSONMarshal(func(v any) ([]byte, error) { + SetJSONMarshal(func(_ any) ([]byte, error) { return []byte("hello"), nil }) val, err := client.JSONMarshal()(nil) @@ -87,7 +87,7 @@ func Test_Client_Marshal(t *testing.T) { t.Run("set xml marshal", func(t *testing.T) { t.Parallel() client := AcquireClient(). - SetXMLMarshal(func(v any) ([]byte, error) { + SetXMLMarshal(func(_ any) ([]byte, error) { return []byte("hello"), nil }) val, err := client.XMLMarshal()(nil) @@ -99,7 +99,7 @@ func Test_Client_Marshal(t *testing.T) { t.Run("set xml unmarshal", func(t *testing.T) { t.Parallel() client := AcquireClient(). - SetXMLUnmarshal(func(data []byte, v any) error { + SetXMLUnmarshal(func(_ []byte, _ any) error { return fmt.Errorf("empty xml") }) diff --git a/client/core_test.go b/client/core_test.go index 0b181a97c2..51d15e76da 100644 --- a/client/core_test.go +++ b/client/core_test.go @@ -65,7 +65,7 @@ func Test_Exec_Func(t *testing.T) { return c.SendString(c.Hostname()) }) - app.Get("/return-error", func(c fiber.Ctx) error { + app.Get("/return-error", func(_ fiber.Ctx) error { return fmt.Errorf("the request is error") }) @@ -87,7 +87,7 @@ func Test_Exec_Func(t *testing.T) { core.client = client core.req = req - core.req.dial = func(addr string) (net.Conn, error) { return ln.Dial() } + core.req.dial = func(_ string) (net.Conn, error) { return ln.Dial() } req.RawRequest.SetRequestURI("http://example.com/normal") resp, err := core.execFunc() @@ -159,7 +159,7 @@ func Test_Execute(t *testing.T) { t.Run("add user request hooks", func(t *testing.T) { t.Parallel() core, client, req := newCore(), AcquireClient(), AcquireRequest() - client.AddRequestHook(func(c *Client, r *Request) error { + client.AddRequestHook(func(_ *Client, _ *Request) error { require.Equal(t, "http://example.com", req.URL()) return nil }) diff --git a/client/helper_test.go b/client/helper_test.go index e5cc0508b5..1337dc43bc 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -9,17 +9,12 @@ import ( "github.com/valyala/fasthttp/fasthttputil" ) -func createHelperServer(t testing.TB, config ...fiber.Config) (*fiber.App, func(addr string) (net.Conn, error), func()) { +func createHelperServer(t testing.TB) (*fiber.App, func(addr string) (net.Conn, error), func()) { t.Helper() ln := fasthttputil.NewInmemoryListener() - var cfg fiber.Config - if len(config) > 0 { - cfg = config[0] - } - - app := fiber.New(cfg) + app := fiber.New() return app, func(addr string) (net.Conn, error) { return ln.Dial() diff --git a/client/hooks.go b/client/hooks.go index f93ed46803..8b61d8c12a 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -59,13 +59,13 @@ func randString(n int) string { // The baseUrl will be merge with request uri. // Query params and path params deal in this function. func parserRequestURL(c *Client, req *Request) error { - splitUrl := strings.Split(req.url, "?") - // I don't want to judge splitUrl length. - splitUrl = append(splitUrl, "") + splitURL := strings.Split(req.url, "?") + // I don't want to judge splitURL length. + splitURL = append(splitURL, "") // Determine whether to superimpose baseurl based on // whether the URL starts with the protocol - uri := splitUrl[0] + uri := splitURL[0] if !protocolCheck.MatchString(uri) { uri = c.baseURL + uri if !protocolCheck.MatchString(uri) { @@ -85,7 +85,7 @@ func parserRequestURL(c *Client, req *Request) error { req.RawRequest.SetRequestURI(uri) // merge query params - hashSplit := strings.Split(splitUrl[1], "#") + hashSplit := strings.Split(splitURL[1], "#") hashSplit = append(hashSplit, "") args := fasthttp.AcquireArgs() defer func() { diff --git a/client/hooks_test.go b/client/hooks_test.go index 2c89868998..85d1f83fb8 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -583,5 +583,5 @@ func Test_Client_Logger_DisableDebug(t *testing.T) { defer resp.Close() require.NoError(t, err) - require.Len(t, buf.String(), 0) + require.Empty(t, buf.String()) } diff --git a/client/request.go b/client/request.go index 99ec40f738..1ebd959111 100644 --- a/client/request.go +++ b/client/request.go @@ -18,8 +18,8 @@ import ( // WithStruct Implementing this interface allows data to // be stored from a struct via reflect. type WithStruct interface { - Add(string, string) - Del(string) + Add(name string, obj string) + Del(name string) } // Types of request bodies. @@ -895,20 +895,20 @@ func SetFileReader(r io.ReadCloser) SetFileFunc { // // The returned file may be returned to the pool with ReleaseFile when no longer needed. // This allows reducing GC load. -func AcquireFile(setter ...SetFileFunc) (f *File) { +func AcquireFile(setter ...SetFileFunc) *File { fv := filePool.Get() if fv != nil { - f = fv.(*File) + f := fv.(*File) for _, v := range setter { v(f) } - return + return f } - f = &File{} + f := &File{} for _, v := range setter { v(f) } - return + return f } // ReleaseFile returns the object acquired via AcquireFile to the pool. diff --git a/client/request_test.go b/client/request_test.go index 8bfe3cbbe2..76db732808 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -849,7 +849,7 @@ func Test_Request_UserAgent_With_Server(t *testing.T) { t.Run("default", func(t *testing.T) { t.Parallel() - testRequest(t, handler, func(agent *Request) {}, defaultUserAgent, 5) + testRequest(t, handler, func(_ *Request) {}, defaultUserAgent, 5) }) t.Run("custom", func(t *testing.T) { @@ -1540,11 +1540,11 @@ func Benchmark_SetValWithStruct(b *testing.B) { }) } - require.Equal(b, "", string(p.Peek("TInt"))) - require.Equal(b, "", string(p.Peek("TString"))) - require.Equal(b, "", string(p.Peek("TFloat"))) - require.Equal(b, 0, len(p.PeekMulti("TSlice"))) - require.Equal(b, 0, len(p.PeekMulti("int_slice"))) + require.Empty(b, string(p.Peek("TInt"))) + require.Empty(b, string(p.Peek("TString"))) + require.Empty(b, string(p.Peek("TFloat"))) + require.Empty(b, len(p.PeekMulti("TSlice"))) + require.Empty(b, len(p.PeekMulti("int_slice"))) }) b.Run("error type should ignore", func(b *testing.B) { diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 978d33523d..793f5a09ae 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -483,7 +483,7 @@ func Test_Proxy_DoTimeout_RestoreOriginalURL(t *testing.T) { }) resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) - require.NotErrorIs(t, nil, err1) + require.Error(t, err1) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "proxied", string(body)) From 5b31ffbca27fd3180133a97de3ac90af159b54eb Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Thu, 22 Feb 2024 13:35:00 +0300 Subject: [PATCH 094/118] fix Test_Request_FormData test --- client/request_test.go | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/client/request_test.go b/client/request_test.go index 76db732808..5edd843722 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -428,6 +428,7 @@ func Test_Request_FormData(t *testing.T) { t.Run("add form data", func(t *testing.T) { t.Parallel() req := AcquireRequest() + defer ReleaseRequest(req) req.AddFormData("foo", "bar").AddFormData("foo", "fiber") res := req.FormData("foo") @@ -439,6 +440,7 @@ func Test_Request_FormData(t *testing.T) { t.Run("set param", func(t *testing.T) { t.Parallel() req := AcquireRequest() + defer ReleaseRequest(req) req.AddFormData("foo", "bar").SetFormData("foo", "fiber") res := req.FormData("foo") @@ -449,6 +451,7 @@ func Test_Request_FormData(t *testing.T) { t.Run("add params", func(t *testing.T) { t.Parallel() req := AcquireRequest() + defer ReleaseRequest(req) req.SetFormData("foo", "bar"). AddFormDatas(map[string][]string{ "foo": {"fiber", "buaa"}, @@ -457,9 +460,9 @@ func Test_Request_FormData(t *testing.T) { res := req.FormData("foo") require.Len(t, res, 3) - require.Equal(t, "bar", res[0]) - require.Equal(t, "buaa", res[1]) - require.Equal(t, "fiber", res[2]) + require.Contains(t, res, "bar") + require.Contains(t, res, "buaa") + require.Contains(t, res, "fiber") res = req.FormData("bar") require.Len(t, res, 1) @@ -469,6 +472,7 @@ func Test_Request_FormData(t *testing.T) { t.Run("set headers", func(t *testing.T) { t.Parallel() req := AcquireRequest() + defer ReleaseRequest(req) req.SetFormData("foo", "bar"). SetFormDatas(map[string]string{ "foo": "fiber", @@ -497,6 +501,7 @@ func Test_Request_FormData(t *testing.T) { } p := AcquireRequest() + defer ReleaseRequest(p) p.SetFormDatasWithStruct(&args{ TInt: 5, TString: "string", @@ -521,18 +526,19 @@ func Test_Request_FormData(t *testing.T) { tslice := p.FormData("TSlice") require.Len(t, tslice, 2) - require.Equal(t, "bar", tslice[0]) - require.Equal(t, "foo", tslice[1]) + require.Contains(t, tslice, "bar") + require.Contains(t, tslice, "foo") tint := p.FormData("TSlice") require.Len(t, tint, 2) - require.Equal(t, "bar", tint[0]) - require.Equal(t, "foo", tint[1]) + require.Contains(t, tint, "bar") + require.Contains(t, tint, "foo") }) t.Run("del params", func(t *testing.T) { t.Parallel() req := AcquireRequest() + defer ReleaseRequest(req) req.SetFormData("foo", "bar"). SetFormDatas(map[string]string{ "foo": "fiber", From 88d4cfe89e8a7acbbe9de262271e4cef705a95e3 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sat, 24 Feb 2024 14:44:22 +0300 Subject: [PATCH 095/118] create new test suite --- client/helper_test.go | 53 ++++++++++++++ client/response_test.go | 156 +++++++++++++++++++++++++--------------- 2 files changed, 152 insertions(+), 57 deletions(-) diff --git a/client/helper_test.go b/client/helper_test.go index 1337dc43bc..a5e0742318 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -3,12 +3,65 @@ package client import ( "net" "testing" + "time" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp/fasthttputil" ) +type testServer struct { + app *fiber.App + ch chan struct{} + ln *fasthttputil.InmemoryListener + tb testing.TB +} + +func startTestServer(tb testing.TB) *testServer { + tb.Helper() + + ln := fasthttputil.NewInmemoryListener() + app := fiber.New() + + ch := make(chan struct{}) + go func() { + if err := app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true}); err != nil { + tb.Fatal(err) + } + + close(ch) + }() + + return &testServer{ + app: app, + ch: ch, + ln: ln, + tb: tb, + } +} + +func (ts *testServer) stop() { + ts.tb.Helper() + + if err := ts.app.Shutdown(); err != nil { + ts.tb.Fatal(err) + } + + select { + case <-ts.ch: + case <-time.After(time.Second): + ts.tb.Fatalf("timeout when waiting for server close") + } +} + +func (ts *testServer) dial() func(addr string) (net.Conn, error) { + ts.tb.Helper() + + return func(addr string) (net.Conn, error) { + return ts.ln.Dial() + } +} + func createHelperServer(t testing.TB) (*fiber.App, func(addr string) (net.Conn, error), func()) { t.Helper() diff --git a/client/response_test.go b/client/response_test.go index 7db76a07f6..41869aef01 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -8,7 +8,6 @@ import ( "net" "os" "testing" - "time" "github.com/gofiber/fiber/v3/internal/tlstest" @@ -19,20 +18,26 @@ import ( func Test_Response_Status(t *testing.T) { t.Parallel() - app, ln, start := createHelperServer(t) - app.Get("/", func(c fiber.Ctx) error { - return c.SendString("foo") - }) - app.Get("/fail", func(c fiber.Ctx) error { - return c.SendStatus(407) - }) - go start() + setupApp := func() *testServer { + server := startTestServer(t) + + server.app.Get("/", func(c fiber.Ctx) error { + return c.SendString("foo") + }) + server.app.Get("/fail", func(c fiber.Ctx) error { + return c.SendStatus(407) + }) + + return server + } t.Run("success", func(t *testing.T) { t.Parallel() + server := setupApp() + defer server.stop() resp, err := AcquireRequest(). - SetDial(ln). + SetDial(server.dial()). Get("http://example") require.NoError(t, err) @@ -43,8 +48,10 @@ func Test_Response_Status(t *testing.T) { t.Run("fail", func(t *testing.T) { t.Parallel() + server := setupApp() + defer server.stop() resp, err := AcquireRequest(). - SetDial(ln). + SetDial(server.dial()). Get("http://example/fail") require.NoError(t, err) @@ -56,20 +63,26 @@ func Test_Response_Status(t *testing.T) { func Test_Response_Status_Code(t *testing.T) { t.Parallel() - app, ln, start := createHelperServer(t) - app.Get("/", func(c fiber.Ctx) error { - return c.SendString("foo") - }) - app.Get("/fail", func(c fiber.Ctx) error { - return c.SendStatus(407) - }) - go start() + setupApp := func() *testServer { + server := startTestServer(t) + + server.app.Get("/", func(c fiber.Ctx) error { + return c.SendString("foo") + }) + server.app.Get("/fail", func(c fiber.Ctx) error { + return c.SendStatus(407) + }) + + return server + } t.Run("success", func(t *testing.T) { t.Parallel() + server := setupApp() + defer server.stop() resp, err := AcquireRequest(). - SetDial(ln). + SetDial(server.dial()). Get("http://example") require.NoError(t, err) @@ -80,8 +93,10 @@ func Test_Response_Status_Code(t *testing.T) { t.Run("fail", func(t *testing.T) { t.Parallel() + server := setupApp() + defer server.stop() resp, err := AcquireRequest(). - SetDial(ln). + SetDial(server.dial()). Get("http://example/fail") require.NoError(t, err) @@ -95,14 +110,16 @@ func Test_Response_Protocol(t *testing.T) { t.Run("http", func(t *testing.T) { t.Parallel() - app, ln, start := createHelperServer(t) - app.Get("/", func(c fiber.Ctx) error { + + server := startTestServer(t) + defer server.stop() + + server.app.Get("/", func(c fiber.Ctx) error { return c.SendString("foo") }) - go start() resp, err := AcquireRequest(). - SetDial(ln). + SetDial(server.dial()). Get("http://example") require.NoError(t, err) @@ -148,15 +165,15 @@ func Test_Response_Protocol(t *testing.T) { func Test_Response_Header(t *testing.T) { t.Parallel() - app, ln, start := createHelperServer(t) - app.Get("/", func(c fiber.Ctx) error { + server := startTestServer(t) + defer server.stop() + server.app.Get("/", func(c fiber.Ctx) error { c.Response().Header.Add("foo", "bar") return c.SendString("helo world") }) - go start() resp, err := AcquireRequest(). - SetDial(ln). + SetDial(server.dial()). Get("http://example.com") require.NoError(t, err) @@ -167,18 +184,18 @@ func Test_Response_Header(t *testing.T) { func Test_Response_Cookie(t *testing.T) { t.Parallel() - app, ln, start := createHelperServer(t) - app.Get("/", func(c fiber.Ctx) error { + server := startTestServer(t) + defer server.stop() + server.app.Get("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "foo", Value: "bar", }) return c.SendString("helo world") }) - go start() resp, err := AcquireRequest(). - SetDial(ln). + SetDial(server.dial()). Get("http://example.com") require.NoError(t, err) @@ -189,23 +206,29 @@ func Test_Response_Cookie(t *testing.T) { func Test_Response_Body(t *testing.T) { t.Parallel() - app, ln, start := createHelperServer(t) - app.Get("/", func(c fiber.Ctx) error { - return c.SendString("hello world") - }) - app.Get("/json", func(c fiber.Ctx) error { - return c.SendString("{\"status\":\"success\"}") - }) - app.Get("/xml", func(c fiber.Ctx) error { - return c.SendString("success") - }) + setupApp := func() *testServer { + server := startTestServer(t) + + server.app.Get("/", func(c fiber.Ctx) error { + return c.SendString("hello world") + }) + server.app.Get("/json", func(c fiber.Ctx) error { + return c.SendString("{\"status\":\"success\"}") + }) + server.app.Get("/xml", func(c fiber.Ctx) error { + return c.SendString("success") + }) - go start() + return server + } t.Run("raw body", func(t *testing.T) { t.Parallel() + + server := setupApp() + defer server.stop() resp, err := AcquireRequest(). - SetDial(ln). + SetDial(server.dial()). Get("http://example.com") require.NoError(t, err) @@ -215,8 +238,11 @@ func Test_Response_Body(t *testing.T) { t.Run("string body", func(t *testing.T) { t.Parallel() + + server := setupApp() + defer server.stop() resp, err := AcquireRequest(). - SetDial(ln). + SetDial(server.dial()). Get("http://example.com") require.NoError(t, err) @@ -230,8 +256,10 @@ func Test_Response_Body(t *testing.T) { Status string `json:"status"` } + server := setupApp() + defer server.stop() resp, err := AcquireRequest(). - SetDial(ln). + SetDial(server.dial()). Get("http://example.com/json") require.NoError(t, err) @@ -250,8 +278,10 @@ func Test_Response_Body(t *testing.T) { Status string `xml:"name"` } + server := setupApp() + defer server.stop() resp, err := AcquireRequest(). - SetDial(ln). + SetDial(server.dial()). Get("http://example.com/xml") require.NoError(t, err) @@ -266,18 +296,24 @@ func Test_Response_Body(t *testing.T) { func Test_Response_Save(t *testing.T) { t.Parallel() - app, ln, start := createHelperServer(t) - app.Get("/json", func(c fiber.Ctx) error { - return c.SendString("{\"status\":\"success\"}") - }) - go start() - time.Sleep(300 * time.Millisecond) + setupApp := func() *testServer { + server := startTestServer(t) + + server.app.Get("/json", func(c fiber.Ctx) error { + return c.SendString("{\"status\":\"success\"}") + }) + + return server + } t.Run("file path", func(t *testing.T) { t.Parallel() + + server := setupApp() + defer server.stop() resp, err := AcquireRequest(). - SetDial(ln). + SetDial(server.dial()). Get("http://example.com/json") require.NoError(t, err) @@ -307,8 +343,11 @@ func Test_Response_Save(t *testing.T) { t.Run("io.Writer", func(t *testing.T) { t.Parallel() + + server := setupApp() + defer server.stop() resp, err := AcquireRequest(). - SetDial(ln). + SetDial(server.dial()). Get("http://example.com/json") require.NoError(t, err) @@ -322,8 +361,11 @@ func Test_Response_Save(t *testing.T) { t.Run("error type", func(t *testing.T) { t.Parallel() + + server := setupApp() + defer server.stop() resp, err := AcquireRequest(). - SetDial(ln). + SetDial(server.dial()). Get("http://example.com/json") require.NoError(t, err) From 1cc045cd6cd81405a7c86b8a145f809e65269310 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sat, 24 Feb 2024 17:33:51 +0300 Subject: [PATCH 096/118] just create client for once --- client/client.go | 69 ++++++------------------- client/client_test.go | 109 +++++++++++++++++----------------------- client/core.go | 26 ++++------ client/core_test.go | 36 +++++++------ client/helper_test.go | 14 ++++-- client/request.go | 9 ---- client/request_test.go | 95 +++++++++++++++++++++++++--------- client/response_test.go | 81 +++++++++++++++++++++++------ 8 files changed, 242 insertions(+), 197 deletions(-) diff --git a/client/client.go b/client/client.go index ed1f4d1cb1..2952a99765 100644 --- a/client/client.go +++ b/client/client.go @@ -35,6 +35,8 @@ var ( type Client struct { mu sync.RWMutex + client *fasthttp.Client + baseURL string userAgent string referer string @@ -66,9 +68,6 @@ type Client struct { cookieJar *CookieJar - // tls config - tlsConfig *tls.Config - // proxy proxyURL string @@ -156,18 +155,18 @@ func (c *Client) SetXMLUnmarshal(f utils.XMLUnmarshal) *Client { // TLSConfig returns tlsConfig in client. // If client don't have tlsConfig, this function will init it. func (c *Client) TLSConfig() *tls.Config { - if c.tlsConfig == nil { - c.tlsConfig = &tls.Config{ + if c.client.TLSConfig == nil { + c.client.TLSConfig = &tls.Config{ MinVersion: tls.VersionTLS12, } } - return c.tlsConfig + return c.client.TLSConfig } // SetTLSConfig sets tlsConfig in client. func (c *Client) SetTLSConfig(config *tls.Config) *Client { - c.tlsConfig = config + c.client.TLSConfig = config return c } @@ -548,8 +547,17 @@ func (c *Client) Custom(url string, method string, cfg ...Config) (*Response, er return req.Custom(url, method) } +func (c *Client) SetDial(dial fasthttp.DialFunc) *Client { + c.mu.Lock() + defer c.mu.Unlock() + + c.client.Dial = dial + return c +} + // Reset clear Client object func (c *Client) Reset() { + c.client = &fasthttp.Client{} c.baseURL = "" c.timeout = 0 c.userAgent = "" @@ -582,8 +590,6 @@ type Config struct { Body any FormData map[string]string File []*File - - Dial fasthttp.DialFunc } // setConfigToRequest Set the parameters passed via Config to Request. @@ -629,10 +635,6 @@ func setConfigToRequest(req *Request, config ...Config) { req.SetMaxRedirects(cfg.MaxRedirects) } - if cfg.Dial != nil { - req.SetDial(cfg.Dial) - } - if cfg.Body != nil { req.SetJSON(cfg.Body) return @@ -656,6 +658,7 @@ var ( clientPool = &sync.Pool{ New: func() any { return &Client{ + client: &fasthttp.Client{}, header: &Header{ RequestHeader: &fasthttp.RequestHeader{}, }, @@ -759,43 +762,3 @@ func Options(url string, cfg ...Config) (*Response, error) { func Patch(url string, cfg ...Config) (*Response, error) { return C().Patch(url, cfg...) } - -var hostClienPool = &sync.Pool{ - New: func() any { - return &fasthttp.HostClient{} - }, -} - -// AcquireHostClient returns an empty HostClient object from the pool. -// -// The returned HostClient object may be returned to the pool with ReleaseHostClient when no longer needed. -// This allows reducing GC load. -func AcquireHostClient() *fasthttp.HostClient { - hostClient, ok := hostClienPool.Get().(*fasthttp.HostClient) - if !ok { - panic(fmt.Errorf("failed to type-assert to *fasthttp.HostClient")) - } - - return hostClient -} - -// ReleaseHostClient returns the object acquired via AcquireHostClient to the pool. -// -// Do not access the released HostClient object, otherwise data -func ReleaseHostClient(h *fasthttp.HostClient) { - // reset host client - h.Addr = "" - h.Name = "" - h.Dial = nil - h.MaxConns = 0 - h.MaxIdleConnDuration = 0 - h.ReadTimeout = 0 - h.WriteTimeout = 0 - h.ReadBufferSize = 0 - h.WriteBufferSize = 0 - h.DisableHeaderNamesNormalizing = false - h.DisablePathNormalizing = false - h.DisablePathNormalizing = false - - hostClienPool.Put(h) -} diff --git a/client/client_test.go b/client/client_test.go index 503f5980b3..0773e83da4 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -127,9 +127,8 @@ func Test_Client_Invalid_URL(t *testing.T) { go start() - _, err := AcquireClient(). + _, err := AcquireClient().SetDial(dial). R(). - SetDial(dial). Get("http://example.com\r\n\r\nGET /\r\n\r\n") require.ErrorIs(t, err, ErrURLFormat) @@ -156,13 +155,12 @@ func Test_Client_ConcurrencyRequests(t *testing.T) { wg := sync.WaitGroup{} for i := 0; i < 5; i++ { + C().SetDial(dial) for _, method := range []string{"GET", "POST", "PUT", "DELETE", "PATCH"} { wg.Add(1) go func(m string) { defer wg.Done() - resp, err := C().Custom("http://example.com", m, Config{ - Dial: dial, - }) + resp, err := C().Custom("http://example.com", m) require.NoError(t, err) require.Equal(t, "example.com "+m, utils.UnsafeString(resp.RawResponse.Body())) }(method) @@ -185,18 +183,15 @@ func Test_Get(t *testing.T) { t.Run("global get function", func(t *testing.T) { t.Parallel() - resp, err := Get("http://example.com", Config{ - Dial: dial, - }) + C().SetDial(dial) + resp, err := Get("http://example.com") require.NoError(t, err) require.Equal(t, "example.com", utils.UnsafeString(resp.RawResponse.Body())) }) t.Run("client get", func(t *testing.T) { t.Parallel() - resp, err := AcquireClient().Get("http://example.com", Config{ - Dial: dial, - }) + resp, err := AcquireClient().SetDial(dial).Get("http://example.com") require.NoError(t, err) require.Equal(t, "example.com", utils.UnsafeString(resp.RawResponse.Body())) }) @@ -215,18 +210,15 @@ func Test_Head(t *testing.T) { t.Run("global head function", func(t *testing.T) { t.Parallel() - resp, err := Head("http://example.com", Config{ - Dial: dial, - }) + C().SetDial(dial) + resp, err := Head("http://example.com") require.NoError(t, err) require.Equal(t, "", utils.UnsafeString(resp.RawResponse.Body())) }) t.Run("client head", func(t *testing.T) { t.Parallel() - resp, err := AcquireClient().Head("http://example.com", Config{ - Dial: dial, - }) + resp, err := AcquireClient().SetDial(dial).Head("http://example.com") require.NoError(t, err) require.Equal(t, "", utils.UnsafeString(resp.RawResponse.Body())) }) @@ -246,8 +238,8 @@ func Test_Post(t *testing.T) { t.Run("global post function", func(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { + C().SetDial(dial) resp, err := Post("http://example.com", Config{ - Dial: dial, FormData: map[string]string{ "foo": "bar", }, @@ -262,8 +254,7 @@ func Test_Post(t *testing.T) { t.Run("client post", func(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { - resp, err := AcquireClient().Post("http://example.com", Config{ - Dial: dial, + resp, err := AcquireClient().SetDial(dial).Post("http://example.com", Config{ FormData: map[string]string{ "foo": "bar", }, @@ -290,7 +281,6 @@ func Test_Put(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { resp, err := Put("http://example.com", Config{ - Dial: dial, FormData: map[string]string{ "foo": "bar", }, @@ -305,8 +295,7 @@ func Test_Put(t *testing.T) { t.Run("client put", func(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { - resp, err := AcquireClient().Put("http://example.com", Config{ - Dial: dial, + resp, err := AcquireClient().SetDial(dial).Put("http://example.com", Config{ FormData: map[string]string{ "foo": "bar", }, @@ -334,7 +323,6 @@ func Test_Delete(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { resp, err := Delete("http://example.com", Config{ - Dial: dial, FormData: map[string]string{ "foo": "bar", }, @@ -349,8 +337,7 @@ func Test_Delete(t *testing.T) { t.Run("client delete", func(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { - resp, err := AcquireClient().Delete("http://example.com", Config{ - Dial: dial, + resp, err := AcquireClient().SetDial(dial).Delete("http://example.com", Config{ FormData: map[string]string{ "foo": "bar", }, @@ -376,9 +363,8 @@ func Test_Options(t *testing.T) { t.Run("global options function", func(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { - resp, err := Options("http://example.com", Config{ - Dial: dial, - }) + C().SetDial(dial) + resp, err := Options("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) @@ -389,9 +375,7 @@ func Test_Options(t *testing.T) { t.Run("client options", func(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { - resp, err := AcquireClient().Options("http://example.com", Config{ - Dial: dial, - }) + resp, err := AcquireClient().SetDial(dial).Options("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) @@ -414,8 +398,8 @@ func Test_Patch(t *testing.T) { t.Run("global patch function", func(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { + C().SetDial(dial) resp, err := Patch("http://example.com", Config{ - Dial: dial, FormData: map[string]string{ "foo": "bar", }, @@ -430,8 +414,7 @@ func Test_Patch(t *testing.T) { t.Run("client patch", func(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { - resp, err := AcquireClient().Patch("http://example.com", Config{ - Dial: dial, + resp, err := AcquireClient().SetDial(dial).Patch("http://example.com", Config{ FormData: map[string]string{ "foo": "bar", }, @@ -458,9 +441,8 @@ func Test_Client_UserAgent(t *testing.T) { t.Run("default", func(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { - resp, err := Get("http://example.com", Config{ - Dial: dial, - }) + C().SetDial(dial) + resp, err := Get("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) @@ -471,12 +453,10 @@ func Test_Client_UserAgent(t *testing.T) { t.Run("custom", func(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { - c := AcquireClient(). + c := AcquireClient().SetDial(dial). SetUserAgent("ua") - resp, err := c.Get("http://example.com", Config{ - Dial: dial, - }) + resp, err := c.Get("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) @@ -522,8 +502,8 @@ func Test_Client_Header(t *testing.T) { res := req.Header("foo") require.Len(t, res, 3) require.Equal(t, "bar", res[0]) - require.Equal(t, "buaa", res[1]) - require.Equal(t, "fiber", res[2]) + require.Equal(t, "fiber", res[1]) + require.Equal(t, "buaa", res[2]) res = req.Header("bar") require.Len(t, res, 1) @@ -810,8 +790,8 @@ func Test_Client_QueryParam(t *testing.T) { res := req.Param("foo") require.Len(t, res, 3) require.Equal(t, "bar", res[0]) - require.Equal(t, "buaa", res[1]) - require.Equal(t, "fiber", res[2]) + require.Equal(t, "fiber", res[1]) + require.Equal(t, "buaa", res[2]) res = req.Param("bar") require.Len(t, res, 1) @@ -873,13 +853,13 @@ func Test_Client_QueryParam(t *testing.T) { tslice := p.Param("TSlice") require.Len(t, tslice, 2) - require.Equal(t, "bar", tslice[0]) - require.Equal(t, "foo", tslice[1]) + require.Equal(t, "foo", tslice[0]) + require.Equal(t, "bar", tslice[1]) tint := p.Param("TSlice") require.Len(t, tint, 2) - require.Equal(t, "bar", tint[0]) - require.Equal(t, "foo", tint[1]) + require.Equal(t, "foo", tint[0]) + require.Equal(t, "bar", tint[1]) }) t.Run("del params", func(t *testing.T) { @@ -986,9 +966,9 @@ func Test_Client_PathParam_With_Server(t *testing.T) { go start() - resp, err := AcquireClient(). + resp, err := AcquireClient().SetDial(dial). SetPathParam("path", "test"). - Get("http://example.com/:path", Config{Dial: dial}) + Get("http://example.com/:path") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) @@ -1063,14 +1043,14 @@ func Test_Client_SetCertificates(t *testing.T) { require.NoError(t, err) client := AcquireClient().SetCertificates(serverTLSConf.Certificates...) - require.Len(t, client.tlsConfig.Certificates, 1) + require.Len(t, client.TLSConfig().Certificates, 1) } func Test_Client_SetRootCertificate(t *testing.T) { t.Parallel() client := AcquireClient().SetRootCertificate("../.github/testdata/ssl.pem") - require.NotNil(t, client.tlsConfig.RootCAs) + require.NotNil(t, client.TLSConfig().RootCAs) } func Test_Client_SetRootCertificateFromString(t *testing.T) { @@ -1084,7 +1064,7 @@ func Test_Client_SetRootCertificateFromString(t *testing.T) { require.NoError(t, err) client := AcquireClient().SetRootCertificateFromString(string(pem)) - require.NotNil(t, client.tlsConfig.RootCAs) + require.NotNil(t, client.TLSConfig().RootCAs) } func Test_Client_R(t *testing.T) { @@ -1106,15 +1086,16 @@ func Test_Replace(t *testing.T) { go start() - resp, err := Get("http://example.com", Config{Dial: dial}) + C().SetDial(dial) + resp, err := Get("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "", resp.String()) - r := AcquireClient().SetHeader("k1", "v1") + r := AcquireClient().SetDial(dial).SetHeader("k1", "v1") clean := Replace(r) - resp, err = Get("http://example.com", Config{Dial: dial}) + resp, err = Get("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "v1", resp.String()) @@ -1122,7 +1103,8 @@ func Test_Replace(t *testing.T) { clean() ReleaseClient(r) - resp, err = Get("http://example.com", Config{Dial: dial}) + C().SetDial(dial) + resp, err = Get("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) @@ -1267,9 +1249,9 @@ func Test_Client_SetProxyURL(t *testing.T) { t.Run("success", func(t *testing.T) { t.Parallel() - client := AcquireClient() + client := AcquireClient().SetDial(dial) client.SetProxyURL("http://test.com") - _, err := client.Get("http://localhost:3000", Config{Dial: dial}) + _, err := client.Get("http://localhost:3000") require.NoError(t, err) }) @@ -1321,11 +1303,14 @@ func Benchmark_Client_Request(b *testing.B) { go start() + client := AcquireClient().SetDial(dial) + defer ReleaseClient(client) + b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { - resp, _ := Get("http://example.com", Config{Dial: dial}) + resp, _ := client.Get("http://example.com") resp.Close() } } diff --git a/client/core.go b/client/core.go index d47026a636..44b4e5f71d 100644 --- a/client/core.go +++ b/client/core.go @@ -1,7 +1,6 @@ package client import ( - "bytes" "context" "errors" "net" @@ -92,31 +91,24 @@ func (c *core) execFunc() (*Response, error) { c.req.RawRequest.CopyTo(reqv) cfg := c.getRetryConfig() - var host = AcquireHostClient() - - err := c.configureHostClient(host) - if err != nil { - defer ReleaseHostClient(host) - return nil, err - } + var err error go func() { - defer ReleaseHostClient(host) - c.client.mu.Lock() + //c.client.mu.Lock() respv := fasthttp.AcquireResponse() if cfg != nil { err = retry.NewExponentialBackoff(*cfg).Retry(func() error { if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { - return host.DoRedirects(reqv, respv, c.req.maxRedirects) + return c.client.client.DoRedirects(reqv, respv, c.req.maxRedirects) } - return host.Do(reqv, respv) + return c.client.client.Do(reqv, respv) }) } else { if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { - err = host.DoRedirects(reqv, respv, c.req.maxRedirects) + err = c.client.client.DoRedirects(reqv, respv, c.req.maxRedirects) } else { - err = host.Do(reqv, respv) + err = c.client.client.Do(reqv, respv) } } defer func() { @@ -132,7 +124,7 @@ func (c *core) execFunc() (*Response, error) { respv.CopyTo(resp.RawResponse) errCh <- nil } - c.client.mu.Unlock() + //c.client.mu.Unlock() }() select { @@ -208,7 +200,7 @@ func (c *core) timeout() context.CancelFunc { } // configureHostClient set configureHostClient in host. -func (c *core) configureHostClient(hostClient *fasthttp.HostClient) error { +/*func (c *core) configureHostClient(hostClient *fasthttp.Client) error { // tls and dial configuration c.client.mu.Lock() hostClient.TLSConfig = c.client.tlsConfig.Clone() @@ -236,7 +228,7 @@ func (c *core) configureHostClient(hostClient *fasthttp.HostClient) error { c.client.mu.Unlock() return nil -} +}*/ // execute will exec each hooks and plugins. func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Response, error) { diff --git a/client/core_test.go b/client/core_test.go index 51d15e76da..e8b3ce3eb5 100644 --- a/client/core_test.go +++ b/client/core_test.go @@ -87,7 +87,7 @@ func Test_Exec_Func(t *testing.T) { core.client = client core.req = req - core.req.dial = func(_ string) (net.Conn, error) { return ln.Dial() } + client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.RawRequest.SetRequestURI("http://example.com/normal") resp, err := core.execFunc() @@ -104,7 +104,7 @@ func Test_Exec_Func(t *testing.T) { core.client = client core.req = req - core.req.dial = func(addr string) (net.Conn, error) { return ln.Dial() } + client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.RawRequest.SetRequestURI("http://example.com/return-error") resp, err := core.execFunc() @@ -124,7 +124,7 @@ func Test_Exec_Func(t *testing.T) { core.client = client core.req = req - core.req.dial = func(addr string) (net.Conn, error) { return ln.Dial() } + client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.RawRequest.SetRequestURI("http://example.com/hang-up") _, err := core.execFunc() @@ -163,9 +163,10 @@ func Test_Execute(t *testing.T) { require.Equal(t, "http://example.com", req.URL()) return nil }) - req.SetDial(func(addr string) (net.Conn, error) { + client.SetDial(func(addr string) (net.Conn, error) { return ln.Dial() - }).SetURL("http://example.com") + }) + req.SetURL("http://example.com") resp, err := core.execute(context.Background(), client, req) require.NoError(t, err) @@ -179,9 +180,10 @@ func Test_Execute(t *testing.T) { require.Equal(t, "http://example.com", req.URL()) return nil }) - req.SetDial(func(addr string) (net.Conn, error) { + client.SetDial(func(addr string) (net.Conn, error) { return ln.Dial() - }).SetURL("http://example.com") + }) + req.SetURL("http://example.com") resp, err := core.execute(context.Background(), client, req) require.NoError(t, err) @@ -192,9 +194,10 @@ func Test_Execute(t *testing.T) { t.Parallel() core, client, req := newCore(), AcquireClient(), AcquireRequest() - req.SetDial(func(addr string) (net.Conn, error) { + client.SetDial(func(addr string) (net.Conn, error) { return ln.Dial() - }).SetURL("http://example.com/hang-up") + }) + req.SetURL("http://example.com/hang-up") resp, err := core.execute(context.Background(), client, req) require.NoError(t, err) @@ -205,9 +208,10 @@ func Test_Execute(t *testing.T) { t.Parallel() core, client, req := newCore(), AcquireClient(), AcquireRequest() client.SetTimeout(500 * time.Millisecond) - req.SetDial(func(addr string) (net.Conn, error) { + client.SetDial(func(addr string) (net.Conn, error) { return ln.Dial() - }).SetURL("http://example.com/hang-up") + }) + req.SetURL("http://example.com/hang-up") _, err := core.execute(context.Background(), client, req) require.Equal(t, ErrTimeoutOrCancel, err) @@ -217,9 +221,10 @@ func Test_Execute(t *testing.T) { t.Parallel() core, client, req := newCore(), AcquireClient(), AcquireRequest() - req.SetDial(func(addr string) (net.Conn, error) { + client.SetDial(func(addr string) (net.Conn, error) { return ln.Dial() - }).SetURL("http://example.com/hang-up"). + }) + req.SetURL("http://example.com/hang-up"). SetTimeout(300 * time.Millisecond) _, err := core.execute(context.Background(), client, req) @@ -231,9 +236,10 @@ func Test_Execute(t *testing.T) { core, client, req := newCore(), AcquireClient(), AcquireRequest() client.SetTimeout(30 * time.Millisecond) - req.SetDial(func(addr string) (net.Conn, error) { + client.SetDial(func(addr string) (net.Conn, error) { return ln.Dial() - }).SetURL("http://example.com/hang-up"). + }) + req.SetURL("http://example.com/hang-up"). SetTimeout(3000 * time.Millisecond) resp, err := core.execute(context.Background(), client, req) diff --git a/client/helper_test.go b/client/helper_test.go index a5e0742318..fe5b70c65d 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -89,8 +89,11 @@ func testRequest(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Requ c = count[0] } + client := AcquireClient().SetDial(ln) + defer ReleaseClient(client) + for i := 0; i < c; i++ { - req := AcquireRequest().SetDial(ln) + req := AcquireRequest().SetClient(client) wrapAgent(req) resp, err := req.Get("http://example.com") @@ -114,8 +117,11 @@ func testRequestFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent * c = count[0] } + client := AcquireClient().SetDial(ln) + defer ReleaseClient(client) + for i := 0; i < c; i++ { - req := AcquireRequest().SetDial(ln) + req := AcquireRequest().SetClient(client) wrapAgent(req) _, err := req.Get("http://example.com") @@ -137,10 +143,10 @@ func testClient(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Clien } for i := 0; i < c; i++ { - client := AcquireClient() + client := AcquireClient().SetDial(ln) wrapAgent(client) - resp, err := client.Get("http://example.com", Config{Dial: ln}) + resp, err := client.Get("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) diff --git a/client/request.go b/client/request.go index 1ebd959111..d082f00e28 100644 --- a/client/request.go +++ b/client/request.go @@ -58,18 +58,9 @@ type Request struct { files []*File bodyType bodyType - dial fasthttp.DialFunc - RawRequest *fasthttp.Request } -// SetDial set HostClient dial, this method for unit test, -// maybe don't use it. -func (r *Request) SetDial(f fasthttp.DialFunc) *Request { - r.dial = f - return r -} - // Method returns http method in request. func (r *Request) Method() string { return r.method diff --git a/client/request_test.go b/client/request_test.go index 5edd843722..9df77b1422 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -111,8 +111,8 @@ func Test_Request_Header(t *testing.T) { res := req.Header("foo") require.Len(t, res, 3) require.Equal(t, "bar", res[0]) - require.Equal(t, "buaa", res[1]) - require.Equal(t, "fiber", res[2]) + require.Equal(t, "fiber", res[1]) + require.Equal(t, "buaa", res[2]) res = req.Header("bar") require.Len(t, res, 1) @@ -174,8 +174,8 @@ func Test_Request_QueryParam(t *testing.T) { res := req.Param("foo") require.Len(t, res, 3) require.Equal(t, "bar", res[0]) - require.Equal(t, "buaa", res[1]) - require.Equal(t, "fiber", res[2]) + require.Equal(t, "fiber", res[1]) + require.Equal(t, "buaa", res[2]) res = req.Param("bar") require.Len(t, res, 1) @@ -237,13 +237,13 @@ func Test_Request_QueryParam(t *testing.T) { tslice := p.Param("TSlice") require.Len(t, tslice, 2) - require.Equal(t, "bar", tslice[0]) - require.Equal(t, "foo", tslice[1]) + require.Equal(t, "foo", tslice[0]) + require.Equal(t, "bar", tslice[1]) tint := p.Param("TSlice") require.Len(t, tint, 2) - require.Equal(t, "bar", tint[0]) - require.Equal(t, "foo", tint[1]) + require.Equal(t, "foo", tint[0]) + require.Equal(t, "bar", tint[1]) }) t.Run("del params", func(t *testing.T) { @@ -629,8 +629,11 @@ func Test_Request_Get(t *testing.T) { go start() time.Sleep(100 * time.Millisecond) + client := AcquireClient().SetDial(ln) + defer ReleaseClient(client) + for i := 0; i < 5; i++ { - req := AcquireRequest().SetDial(ln) + req := AcquireRequest().SetClient(client) resp, err := req.Get("http://example.com") require.NoError(t, err) @@ -652,9 +655,12 @@ func Test_Request_Post(t *testing.T) { go start() time.Sleep(100 * time.Millisecond) + client := AcquireClient().SetDial(ln) + defer ReleaseClient(client) + for i := 0; i < 5; i++ { resp, err := AcquireRequest(). - SetDial(ln). + SetClient(client). SetFormData("foo", "bar"). Post("http://example.com") @@ -676,9 +682,12 @@ func Test_Request_Head(t *testing.T) { go start() time.Sleep(100 * time.Millisecond) + client := AcquireClient().SetDial(ln) + defer ReleaseClient(client) + for i := 0; i < 5; i++ { resp, err := AcquireRequest(). - SetDial(ln). + SetClient(client). Head("http://example.com") require.NoError(t, err) @@ -699,9 +708,12 @@ func Test_Request_Put(t *testing.T) { go start() time.Sleep(100 * time.Millisecond) + client := AcquireClient().SetDial(ln) + defer ReleaseClient(client) + for i := 0; i < 5; i++ { resp, err := AcquireRequest(). - SetDial(ln). + SetClient(client). SetFormData("foo", "bar"). Put("http://example.com") @@ -726,9 +738,12 @@ func Test_Request_Delete(t *testing.T) { go start() time.Sleep(100 * time.Millisecond) + client := AcquireClient().SetDial(ln) + defer ReleaseClient(client) + for i := 0; i < 5; i++ { resp, err := AcquireRequest(). - SetDial(ln). + SetClient(client). Delete("http://example.com") require.NoError(t, err) @@ -752,9 +767,12 @@ func Test_Request_Options(t *testing.T) { go start() time.Sleep(100 * time.Millisecond) + client := AcquireClient().SetDial(ln) + defer ReleaseClient(client) + for i := 0; i < 5; i++ { resp, err := AcquireRequest(). - SetDial(ln). + SetClient(client). Options("http://example.com") require.NoError(t, err) @@ -778,9 +796,12 @@ func Test_Request_Send(t *testing.T) { go start() time.Sleep(100 * time.Millisecond) + client := AcquireClient().SetDial(ln) + defer ReleaseClient(client) + for i := 0; i < 5; i++ { resp, err := AcquireRequest(). - SetDial(ln). + SetClient(client). SetURL("http://example.com"). SetMethod(fiber.MethodPost). Send() @@ -805,9 +826,12 @@ func Test_Request_Patch(t *testing.T) { go start() time.Sleep(100 * time.Millisecond) + client := AcquireClient().SetDial(ln) + defer ReleaseClient(client) + for i := 0; i < 5; i++ { resp, err := AcquireRequest(). - SetDial(ln). + SetClient(client). SetFormData("foo", "bar"). Patch("http://example.com") @@ -1003,8 +1027,11 @@ func Test_Request_Body_With_Server(t *testing.T) { go start() + client := AcquireClient().SetDial(ln) + defer ReleaseClient(client) + req := AcquireRequest(). - SetDial(ln). + SetClient(client). SetBoundary("myBoundary"). SetFormData("foo", "bar"). AddFiles(AcquireFile( @@ -1056,9 +1083,12 @@ func Test_Request_Body_With_Server(t *testing.T) { go start() + client := AcquireClient().SetDial(ln) + defer ReleaseClient(client) + for i := 0; i < 5; i++ { req := AcquireRequest(). - SetDial(ln). + SetClient(client). AddFiles( AcquireFile( SetFileFieldName("field1"), @@ -1091,8 +1121,11 @@ func Test_Request_Body_With_Server(t *testing.T) { go start() + client := AcquireClient().SetDial(ln) + defer ReleaseClient(client) + req := AcquireRequest(). - SetDial(ln). + SetClient(client). SetFormData("foo", "bar"). AddFiles(AcquireFile( SetFileName("hello.txt"), @@ -1177,8 +1210,11 @@ func Test_Request_Timeout_With_Server(t *testing.T) { }) go start() + client := AcquireClient().SetDial(ln) + defer ReleaseClient(client) + _, err := AcquireRequest(). - SetDial(ln). + SetClient(client). SetTimeout(50 * time.Millisecond). Get("http://example.com") @@ -1206,8 +1242,12 @@ func Test_Request_MaxRedirects(t *testing.T) { t.Run("success", func(t *testing.T) { t.Parallel() + + client := AcquireClient().SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }) + defer ReleaseClient(client) + resp, err := AcquireRequest(). - SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }). + SetClient(client). SetMaxRedirects(1). Get("http://example.com?foo") body := resp.String() @@ -1222,8 +1262,12 @@ func Test_Request_MaxRedirects(t *testing.T) { t.Run("error", func(t *testing.T) { t.Parallel() + + client := AcquireClient().SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }) + defer ReleaseClient(client) + resp, err := AcquireRequest(). - SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }). + SetClient(client). SetMaxRedirects(1). Get("http://example.com") @@ -1232,8 +1276,13 @@ func Test_Request_MaxRedirects(t *testing.T) { }) t.Run("MaxRedirects", func(t *testing.T) { + t.Parallel() + + client := AcquireClient().SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }) + defer ReleaseClient(client) + req := AcquireRequest(). - SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }). + SetClient(client). SetMaxRedirects(3) require.Equal(t, req.MaxRedirects(), 3) diff --git a/client/response_test.go b/client/response_test.go index 41869aef01..47646c4f5d 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -36,8 +36,12 @@ func Test_Response_Status(t *testing.T) { server := setupApp() defer server.stop() + + client := AcquireClient().SetDial(server.dial()) + defer ReleaseClient(client) + resp, err := AcquireRequest(). - SetDial(server.dial()). + SetClient(client). Get("http://example") require.NoError(t, err) @@ -50,8 +54,12 @@ func Test_Response_Status(t *testing.T) { server := setupApp() defer server.stop() + + client := AcquireClient().SetDial(server.dial()) + defer ReleaseClient(client) + resp, err := AcquireRequest(). - SetDial(server.dial()). + SetClient(client). Get("http://example/fail") require.NoError(t, err) @@ -81,8 +89,12 @@ func Test_Response_Status_Code(t *testing.T) { server := setupApp() defer server.stop() + + client := AcquireClient().SetDial(server.dial()) + defer ReleaseClient(client) + resp, err := AcquireRequest(). - SetDial(server.dial()). + SetClient(client). Get("http://example") require.NoError(t, err) @@ -95,8 +107,12 @@ func Test_Response_Status_Code(t *testing.T) { server := setupApp() defer server.stop() + + client := AcquireClient().SetDial(server.dial()) + defer ReleaseClient(client) + resp, err := AcquireRequest(). - SetDial(server.dial()). + SetClient(client). Get("http://example/fail") require.NoError(t, err) @@ -118,8 +134,11 @@ func Test_Response_Protocol(t *testing.T) { return c.SendString("foo") }) + client := AcquireClient().SetDial(server.dial()) + defer ReleaseClient(client) + resp, err := AcquireRequest(). - SetDial(server.dial()). + SetClient(client). Get("http://example") require.NoError(t, err) @@ -172,8 +191,11 @@ func Test_Response_Header(t *testing.T) { return c.SendString("helo world") }) + client := AcquireClient().SetDial(server.dial()) + defer ReleaseClient(client) + resp, err := AcquireRequest(). - SetDial(server.dial()). + SetClient(client). Get("http://example.com") require.NoError(t, err) @@ -194,8 +216,11 @@ func Test_Response_Cookie(t *testing.T) { return c.SendString("helo world") }) + client := AcquireClient().SetDial(server.dial()) + defer ReleaseClient(client) + resp, err := AcquireRequest(). - SetDial(server.dial()). + SetClient(client). Get("http://example.com") require.NoError(t, err) @@ -227,8 +252,12 @@ func Test_Response_Body(t *testing.T) { server := setupApp() defer server.stop() + + client := AcquireClient().SetDial(server.dial()) + defer ReleaseClient(client) + resp, err := AcquireRequest(). - SetDial(server.dial()). + SetClient(client). Get("http://example.com") require.NoError(t, err) @@ -241,8 +270,12 @@ func Test_Response_Body(t *testing.T) { server := setupApp() defer server.stop() + + client := AcquireClient().SetDial(server.dial()) + defer ReleaseClient(client) + resp, err := AcquireRequest(). - SetDial(server.dial()). + SetClient(client). Get("http://example.com") require.NoError(t, err) @@ -258,8 +291,12 @@ func Test_Response_Body(t *testing.T) { server := setupApp() defer server.stop() + + client := AcquireClient().SetDial(server.dial()) + defer ReleaseClient(client) + resp, err := AcquireRequest(). - SetDial(server.dial()). + SetClient(client). Get("http://example.com/json") require.NoError(t, err) @@ -280,8 +317,12 @@ func Test_Response_Body(t *testing.T) { server := setupApp() defer server.stop() + + client := AcquireClient().SetDial(server.dial()) + defer ReleaseClient(client) + resp, err := AcquireRequest(). - SetDial(server.dial()). + SetClient(client). Get("http://example.com/xml") require.NoError(t, err) @@ -312,8 +353,12 @@ func Test_Response_Save(t *testing.T) { server := setupApp() defer server.stop() + + client := AcquireClient().SetDial(server.dial()) + defer ReleaseClient(client) + resp, err := AcquireRequest(). - SetDial(server.dial()). + SetClient(client). Get("http://example.com/json") require.NoError(t, err) @@ -346,8 +391,12 @@ func Test_Response_Save(t *testing.T) { server := setupApp() defer server.stop() + + client := AcquireClient().SetDial(server.dial()) + defer ReleaseClient(client) + resp, err := AcquireRequest(). - SetDial(server.dial()). + SetClient(client). Get("http://example.com/json") require.NoError(t, err) @@ -364,8 +413,12 @@ func Test_Response_Save(t *testing.T) { server := setupApp() defer server.stop() + + client := AcquireClient().SetDial(server.dial()) + defer ReleaseClient(client) + resp, err := AcquireRequest(). - SetDial(server.dial()). + SetClient(client). Get("http://example.com/json") require.NoError(t, err) From ebc7d50c5b3e78c308cf98eb05c36e2f4409388f Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sat, 24 Feb 2024 17:38:39 +0300 Subject: [PATCH 097/118] use random port instead of 3000 --- client/hooks_test.go | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/client/hooks_test.go b/client/hooks_test.go index 85d1f83fb8..6ec5669ea6 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/xml" "io" + "net" "net/url" "strings" "testing" @@ -528,9 +529,14 @@ func Test_Client_Logger_Debug(t *testing.T) { return c.SendString("response") }) + var url string + go func() { - require.NoError(t, app.Listen(":3000", fiber.ListenConfig{ + require.NoError(t, app.Listen(":0", fiber.ListenConfig{ DisableStartupMessage: true, + ListenerAddrFunc: func(addr net.Addr) { + url = addr.String() + }, })) }() @@ -546,11 +552,11 @@ func Test_Client_Logger_Debug(t *testing.T) { client := AcquireClient() client.Debug() - resp, err := client.Get("http://localhost:3000") + resp, err := client.Get("http://" + url) defer resp.Close() require.NoError(t, err) - require.Contains(t, buf.String(), "Host: localhost:3000") + require.Contains(t, buf.String(), "Host: "+url) require.Contains(t, buf.String(), "Content-Length: 8") } @@ -561,9 +567,14 @@ func Test_Client_Logger_DisableDebug(t *testing.T) { return c.SendString("response") }) + var url string + go func() { - require.NoError(t, app.Listen(":3000", fiber.ListenConfig{ + require.NoError(t, app.Listen(":0", fiber.ListenConfig{ DisableStartupMessage: true, + ListenerAddrFunc: func(addr net.Addr) { + url = addr.String() + }, })) }() @@ -579,7 +590,7 @@ func Test_Client_Logger_DisableDebug(t *testing.T) { client := AcquireClient() client.DisableDebug() - resp, err := client.Get("http://localhost:3000") + resp, err := client.Get("http://" + url) defer resp.Close() require.NoError(t, err) From 62644827da104b7d765d9984f0fdaa518e30c28e Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sun, 25 Feb 2024 00:58:45 +0300 Subject: [PATCH 098/118] remove client pooling and fix test suite --- client/client.go | 66 +++++++------------ client/client_test.go | 101 ++++++++++++++--------------- client/core_test.go | 18 +++--- client/helper_test.go | 14 ++-- client/hooks_test.go | 72 ++++++++++----------- client/request_test.go | 47 +++++--------- client/response_test.go | 138 +++++++++++++++++++--------------------- 7 files changed, 203 insertions(+), 253 deletions(-) diff --git a/client/client.go b/client/client.go index 2952a99765..10b11de166 100644 --- a/client/client.go +++ b/client/client.go @@ -7,7 +7,6 @@ import ( "encoding/json" "encoding/xml" "errors" - "fmt" "io" urlpkg "net/url" "os" @@ -655,56 +654,35 @@ var ( defaultClient *Client replaceMu = sync.Mutex{} defaultUserAgent = "fiber" - clientPool = &sync.Pool{ - New: func() any { - return &Client{ - client: &fasthttp.Client{}, - header: &Header{ - RequestHeader: &fasthttp.RequestHeader{}, - }, - params: &QueryParam{ - Args: fasthttp.AcquireArgs(), - }, - cookies: &Cookie{}, - path: &PathParam{}, - - userRequestHooks: []RequestHook{}, - builtinRequestHooks: []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody}, - userResponseHooks: []ResponseHook{}, - builtinResponseHooks: []ResponseHook{parserResponseCookie, logger}, - jsonMarshal: json.Marshal, - jsonUnmarshal: json.Unmarshal, - xmlMarshal: xml.Marshal, - xmlUnmarshal: xml.Unmarshal, - } - }, - } ) // init acquire a default client. func init() { - defaultClient = AcquireClient() + defaultClient = NewClient() } -// AcquireClient returns an empty Client object from the pool. -// -// The returned Client object may be returned to the pool with ReleaseClient when no longer needed. -// This allows reducing GC load. -func AcquireClient() *Client { - client, ok := clientPool.Get().(*Client) - if !ok { - panic(fmt.Errorf("failed to type-assert to *Client")) +// NewClient creates and returns a new Client object. +func NewClient() *Client { + return &Client{ + client: &fasthttp.Client{}, + header: &Header{ + RequestHeader: &fasthttp.RequestHeader{}, + }, + params: &QueryParam{ + Args: fasthttp.AcquireArgs(), + }, + cookies: &Cookie{}, + path: &PathParam{}, + + userRequestHooks: []RequestHook{}, + builtinRequestHooks: []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody}, + userResponseHooks: []ResponseHook{}, + builtinResponseHooks: []ResponseHook{parserResponseCookie, logger}, + jsonMarshal: json.Marshal, + jsonUnmarshal: json.Unmarshal, + xmlMarshal: xml.Marshal, + xmlUnmarshal: xml.Unmarshal, } - - return client -} - -// ReleaseClient returns the object acquired via AcquireClient to the pool. -// -// Do not access the released Client object, otherwise data races may occur. -func ReleaseClient(c *Client) { - c.Reset() - clientPool.Put(c) } // C get default client. diff --git a/client/client_test.go b/client/client_test.go index 0773e83da4..e7e4ac0847 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -25,7 +25,7 @@ func Test_Client_Add_Hook(t *testing.T) { t.Run("add request hooks", func(t *testing.T) { t.Parallel() - client := AcquireClient().AddRequestHook(func(c *Client, r *Request) error { + client := NewClient().AddRequestHook(func(c *Client, r *Request) error { return nil }) @@ -42,7 +42,7 @@ func Test_Client_Add_Hook(t *testing.T) { t.Run("add response hooks", func(t *testing.T) { t.Parallel() - client := AcquireClient().AddResponseHook(func(c *Client, resp *Response, r *Request) error { + client := NewClient().AddResponseHook(func(c *Client, resp *Response, r *Request) error { return nil }) @@ -63,7 +63,7 @@ func Test_Client_Marshal(t *testing.T) { t.Run("set json marshal", func(t *testing.T) { t.Parallel() - client := AcquireClient(). + client := NewClient(). SetJSONMarshal(func(_ any) ([]byte, error) { return []byte("hello"), nil }) @@ -75,7 +75,7 @@ func Test_Client_Marshal(t *testing.T) { t.Run("set json unmarshal", func(t *testing.T) { t.Parallel() - client := AcquireClient(). + client := NewClient(). SetJSONUnmarshal(func(data []byte, v any) error { return fmt.Errorf("empty json") }) @@ -86,7 +86,7 @@ func Test_Client_Marshal(t *testing.T) { t.Run("set xml marshal", func(t *testing.T) { t.Parallel() - client := AcquireClient(). + client := NewClient(). SetXMLMarshal(func(_ any) ([]byte, error) { return []byte("hello"), nil }) @@ -98,7 +98,7 @@ func Test_Client_Marshal(t *testing.T) { t.Run("set xml unmarshal", func(t *testing.T) { t.Parallel() - client := AcquireClient(). + client := NewClient(). SetXMLUnmarshal(func(_ []byte, _ any) error { return fmt.Errorf("empty xml") }) @@ -111,7 +111,7 @@ func Test_Client_Marshal(t *testing.T) { func Test_Client_SetBaseURL(t *testing.T) { t.Parallel() - client := AcquireClient().SetBaseURL("http://example.com") + client := NewClient().SetBaseURL("http://example.com") require.Equal(t, "http://example.com", client.BaseURL()) } @@ -127,7 +127,7 @@ func Test_Client_Invalid_URL(t *testing.T) { go start() - _, err := AcquireClient().SetDial(dial). + _, err := NewClient().SetDial(dial). R(). Get("http://example.com\r\n\r\nGET /\r\n\r\n") @@ -137,7 +137,7 @@ func Test_Client_Invalid_URL(t *testing.T) { func Test_Client_Unsupported_Protocol(t *testing.T) { t.Parallel() - _, err := AcquireClient(). + _, err := NewClient(). R(). Get("ftp://example.com") @@ -191,7 +191,7 @@ func Test_Get(t *testing.T) { t.Run("client get", func(t *testing.T) { t.Parallel() - resp, err := AcquireClient().SetDial(dial).Get("http://example.com") + resp, err := NewClient().SetDial(dial).Get("http://example.com") require.NoError(t, err) require.Equal(t, "example.com", utils.UnsafeString(resp.RawResponse.Body())) }) @@ -218,7 +218,7 @@ func Test_Head(t *testing.T) { t.Run("client head", func(t *testing.T) { t.Parallel() - resp, err := AcquireClient().SetDial(dial).Head("http://example.com") + resp, err := NewClient().SetDial(dial).Head("http://example.com") require.NoError(t, err) require.Equal(t, "", utils.UnsafeString(resp.RawResponse.Body())) }) @@ -254,7 +254,7 @@ func Test_Post(t *testing.T) { t.Run("client post", func(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { - resp, err := AcquireClient().SetDial(dial).Post("http://example.com", Config{ + resp, err := NewClient().SetDial(dial).Post("http://example.com", Config{ FormData: map[string]string{ "foo": "bar", }, @@ -295,7 +295,7 @@ func Test_Put(t *testing.T) { t.Run("client put", func(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { - resp, err := AcquireClient().SetDial(dial).Put("http://example.com", Config{ + resp, err := NewClient().SetDial(dial).Put("http://example.com", Config{ FormData: map[string]string{ "foo": "bar", }, @@ -337,7 +337,7 @@ func Test_Delete(t *testing.T) { t.Run("client delete", func(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { - resp, err := AcquireClient().SetDial(dial).Delete("http://example.com", Config{ + resp, err := NewClient().SetDial(dial).Delete("http://example.com", Config{ FormData: map[string]string{ "foo": "bar", }, @@ -375,7 +375,7 @@ func Test_Options(t *testing.T) { t.Run("client options", func(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { - resp, err := AcquireClient().SetDial(dial).Options("http://example.com") + resp, err := NewClient().SetDial(dial).Options("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) @@ -414,7 +414,7 @@ func Test_Patch(t *testing.T) { t.Run("client patch", func(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { - resp, err := AcquireClient().SetDial(dial).Patch("http://example.com", Config{ + resp, err := NewClient().SetDial(dial).Patch("http://example.com", Config{ FormData: map[string]string{ "foo": "bar", }, @@ -453,7 +453,7 @@ func Test_Client_UserAgent(t *testing.T) { t.Run("custom", func(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { - c := AcquireClient().SetDial(dial). + c := NewClient().SetDial(dial). SetUserAgent("ua") resp, err := c.Get("http://example.com") @@ -461,7 +461,6 @@ func Test_Client_UserAgent(t *testing.T) { require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "ua", resp.String()) - ReleaseClient(c) } }) } @@ -471,7 +470,7 @@ func Test_Client_Header(t *testing.T) { t.Run("add header", func(t *testing.T) { t.Parallel() - req := AcquireClient() + req := NewClient() req.AddHeader("foo", "bar").AddHeader("foo", "fiber") res := req.Header("foo") @@ -482,7 +481,7 @@ func Test_Client_Header(t *testing.T) { t.Run("set header", func(t *testing.T) { t.Parallel() - req := AcquireClient() + req := NewClient() req.AddHeader("foo", "bar").SetHeader("foo", "fiber") res := req.Header("foo") @@ -492,7 +491,7 @@ func Test_Client_Header(t *testing.T) { t.Run("add headers", func(t *testing.T) { t.Parallel() - req := AcquireClient() + req := NewClient() req.SetHeader("foo", "bar"). AddHeaders(map[string][]string{ "foo": {"fiber", "buaa"}, @@ -512,7 +511,7 @@ func Test_Client_Header(t *testing.T) { t.Run("set headers", func(t *testing.T) { t.Parallel() - req := AcquireClient() + req := NewClient() req.SetHeader("foo", "bar"). SetHeaders(map[string]string{ "foo": "fiber", @@ -560,7 +559,7 @@ func Test_Client_Cookie(t *testing.T) { t.Run("set cookie", func(t *testing.T) { t.Parallel() - req := AcquireClient(). + req := NewClient(). SetCookie("foo", "bar") require.Equal(t, "bar", req.Cookie("foo")) @@ -570,7 +569,7 @@ func Test_Client_Cookie(t *testing.T) { t.Run("set cookies", func(t *testing.T) { t.Parallel() - req := AcquireClient(). + req := NewClient(). SetCookies(map[string]string{ "foo": "bar", "bar": "foo", @@ -592,7 +591,7 @@ func Test_Client_Cookie(t *testing.T) { CookieString string `cookie:"string"` } - req := AcquireClient().SetCookiesWithStruct(&args{ + req := NewClient().SetCookiesWithStruct(&args{ CookieInt: 5, CookieString: "foo", }) @@ -603,7 +602,7 @@ func Test_Client_Cookie(t *testing.T) { t.Run("del cookies", func(t *testing.T) { t.Parallel() - req := AcquireClient(). + req := NewClient(). SetCookies(map[string]string{ "foo": "bar", "bar": "foo", @@ -759,7 +758,7 @@ func Test_Client_QueryParam(t *testing.T) { t.Run("add param", func(t *testing.T) { t.Parallel() - req := AcquireClient() + req := NewClient() req.AddParam("foo", "bar").AddParam("foo", "fiber") res := req.Param("foo") @@ -770,7 +769,7 @@ func Test_Client_QueryParam(t *testing.T) { t.Run("set param", func(t *testing.T) { t.Parallel() - req := AcquireClient() + req := NewClient() req.AddParam("foo", "bar").SetParam("foo", "fiber") res := req.Param("foo") @@ -780,7 +779,7 @@ func Test_Client_QueryParam(t *testing.T) { t.Run("add params", func(t *testing.T) { t.Parallel() - req := AcquireClient() + req := NewClient() req.SetParam("foo", "bar"). AddParams(map[string][]string{ "foo": {"fiber", "buaa"}, @@ -800,7 +799,7 @@ func Test_Client_QueryParam(t *testing.T) { t.Run("set headers", func(t *testing.T) { t.Parallel() - req := AcquireClient() + req := NewClient() req.SetParam("foo", "bar"). SetParams(map[string]string{ "foo": "fiber", @@ -828,7 +827,7 @@ func Test_Client_QueryParam(t *testing.T) { TIntSlice []int `param:"int_slice"` } - p := AcquireClient() + p := NewClient() p.SetParamsWithStruct(&args{ TInt: 5, TString: "string", @@ -864,7 +863,7 @@ func Test_Client_QueryParam(t *testing.T) { t.Run("del params", func(t *testing.T) { t.Parallel() - req := AcquireClient() + req := NewClient() req.SetParam("foo", "bar"). SetParams(map[string]string{ "foo": "fiber", @@ -900,7 +899,7 @@ func Test_Client_PathParam(t *testing.T) { t.Run("set path param", func(t *testing.T) { t.Parallel() - req := AcquireClient(). + req := NewClient(). SetPathParam("foo", "bar") require.Equal(t, "bar", req.PathParam("foo")) @@ -910,7 +909,7 @@ func Test_Client_PathParam(t *testing.T) { t.Run("set path params", func(t *testing.T) { t.Parallel() - req := AcquireClient(). + req := NewClient(). SetPathParams(map[string]string{ "foo": "bar", "bar": "foo", @@ -932,7 +931,7 @@ func Test_Client_PathParam(t *testing.T) { CookieString string `path:"string"` } - req := AcquireClient().SetPathParamsWithStruct(&args{ + req := NewClient().SetPathParamsWithStruct(&args{ CookieInt: 5, CookieString: "foo", }) @@ -943,7 +942,7 @@ func Test_Client_PathParam(t *testing.T) { t.Run("del path params", func(t *testing.T) { t.Parallel() - req := AcquireClient(). + req := NewClient(). SetPathParams(map[string]string{ "foo": "bar", "bar": "foo", @@ -966,7 +965,7 @@ func Test_Client_PathParam_With_Server(t *testing.T) { go start() - resp, err := AcquireClient().SetDial(dial). + resp, err := NewClient().SetDial(dial). SetPathParam("path", "test"). Get("http://example.com/:path") @@ -997,7 +996,7 @@ func Test_Client_TLS(t *testing.T) { })) }() - client := AcquireClient() + client := NewClient() resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) require.NoError(t, err) @@ -1028,7 +1027,7 @@ func Test_Client_TLS_Empty_TLSConfig(t *testing.T) { })) }() - client := AcquireClient() + client := NewClient() resp, err := client.Get("https://" + ln.Addr().String()) require.Error(t, err) @@ -1042,14 +1041,14 @@ func Test_Client_SetCertificates(t *testing.T) { serverTLSConf, _, err := tlstest.GetTLSConfigs() require.NoError(t, err) - client := AcquireClient().SetCertificates(serverTLSConf.Certificates...) + client := NewClient().SetCertificates(serverTLSConf.Certificates...) require.Len(t, client.TLSConfig().Certificates, 1) } func Test_Client_SetRootCertificate(t *testing.T) { t.Parallel() - client := AcquireClient().SetRootCertificate("../.github/testdata/ssl.pem") + client := NewClient().SetRootCertificate("../.github/testdata/ssl.pem") require.NotNil(t, client.TLSConfig().RootCAs) } @@ -1063,14 +1062,14 @@ func Test_Client_SetRootCertificateFromString(t *testing.T) { pem, err := io.ReadAll(file) require.NoError(t, err) - client := AcquireClient().SetRootCertificateFromString(string(pem)) + client := NewClient().SetRootCertificateFromString(string(pem)) require.NotNil(t, client.TLSConfig().RootCAs) } func Test_Client_R(t *testing.T) { t.Parallel() - client := AcquireClient() + client := NewClient() req := client.R() require.Equal(t, "Request", reflect.TypeOf(req).Elem().Name()) @@ -1093,7 +1092,7 @@ func Test_Replace(t *testing.T) { require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "", resp.String()) - r := AcquireClient().SetDial(dial).SetHeader("k1", "v1") + r := NewClient().SetDial(dial).SetHeader("k1", "v1") clean := Replace(r) resp, err = Get("http://example.com") require.NoError(t, err) @@ -1101,7 +1100,6 @@ func Test_Replace(t *testing.T) { require.Equal(t, "v1", resp.String()) clean() - ReleaseClient(r) C().SetDial(dial) resp, err = Get("http://example.com") @@ -1249,7 +1247,7 @@ func Test_Client_SetProxyURL(t *testing.T) { t.Run("success", func(t *testing.T) { t.Parallel() - client := AcquireClient().SetDial(dial) + client := NewClient().SetDial(dial) client.SetProxyURL("http://test.com") _, err := client.Get("http://localhost:3000") @@ -1258,8 +1256,7 @@ func Test_Client_SetProxyURL(t *testing.T) { t.Run("wrong url", func(t *testing.T) { t.Parallel() - client := AcquireClient() - defer ReleaseClient(client) + client := NewClient() require.Panics(t, func() { client.SetProxyURL(":this is not a url") @@ -1268,8 +1265,7 @@ func Test_Client_SetProxyURL(t *testing.T) { t.Run("error", func(t *testing.T) { t.Parallel() - client := AcquireClient() - defer ReleaseClient(client) + client := NewClient() require.Panics(t, func() { client.SetProxyURL("htgdftp://test.com") @@ -1285,7 +1281,7 @@ func Test_Client_SetRetryConfig(t *testing.T) { MaxRetryCount: 3, } - core, client, req := newCore(), AcquireClient(), AcquireRequest() + core, client, req := newCore(), NewClient(), AcquireRequest() req.SetURL("http://example.com") client.SetRetryConfig(retryConfig) _, err := core.execute(context.Background(), client, req) @@ -1303,8 +1299,7 @@ func Benchmark_Client_Request(b *testing.B) { go start() - client := AcquireClient().SetDial(dial) - defer ReleaseClient(client) + client := NewClient().SetDial(dial) b.ResetTimer() b.ReportAllocs() diff --git a/client/core_test.go b/client/core_test.go index e8b3ce3eb5..f3ebb8f667 100644 --- a/client/core_test.go +++ b/client/core_test.go @@ -82,7 +82,7 @@ func Test_Exec_Func(t *testing.T) { t.Run("normal request", func(t *testing.T) { t.Parallel() - core, client, req := newCore(), AcquireClient(), AcquireRequest() + core, client, req := newCore(), NewClient(), AcquireRequest() core.ctx = context.Background() core.client = client core.req = req @@ -99,7 +99,7 @@ func Test_Exec_Func(t *testing.T) { t.Run("the request return an error", func(t *testing.T) { t.Parallel() - core, client, req := newCore(), AcquireClient(), AcquireRequest() + core, client, req := newCore(), NewClient(), AcquireRequest() core.ctx = context.Background() core.client = client core.req = req @@ -116,7 +116,7 @@ func Test_Exec_Func(t *testing.T) { t.Run("the request timeout", func(t *testing.T) { t.Parallel() - core, client, req := newCore(), AcquireClient(), AcquireRequest() + core, client, req := newCore(), NewClient(), AcquireRequest() ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() @@ -158,7 +158,7 @@ func Test_Execute(t *testing.T) { t.Run("add user request hooks", func(t *testing.T) { t.Parallel() - core, client, req := newCore(), AcquireClient(), AcquireRequest() + core, client, req := newCore(), NewClient(), AcquireRequest() client.AddRequestHook(func(_ *Client, _ *Request) error { require.Equal(t, "http://example.com", req.URL()) return nil @@ -175,7 +175,7 @@ func Test_Execute(t *testing.T) { t.Run("add user response hooks", func(t *testing.T) { t.Parallel() - core, client, req := newCore(), AcquireClient(), AcquireRequest() + core, client, req := newCore(), NewClient(), AcquireRequest() client.AddResponseHook(func(c *Client, resp *Response, req *Request) error { require.Equal(t, "http://example.com", req.URL()) return nil @@ -192,7 +192,7 @@ func Test_Execute(t *testing.T) { t.Run("no timeout", func(t *testing.T) { t.Parallel() - core, client, req := newCore(), AcquireClient(), AcquireRequest() + core, client, req := newCore(), NewClient(), AcquireRequest() client.SetDial(func(addr string) (net.Conn, error) { return ln.Dial() @@ -206,7 +206,7 @@ func Test_Execute(t *testing.T) { t.Run("client timeout", func(t *testing.T) { t.Parallel() - core, client, req := newCore(), AcquireClient(), AcquireRequest() + core, client, req := newCore(), NewClient(), AcquireRequest() client.SetTimeout(500 * time.Millisecond) client.SetDial(func(addr string) (net.Conn, error) { return ln.Dial() @@ -219,7 +219,7 @@ func Test_Execute(t *testing.T) { t.Run("request timeout", func(t *testing.T) { t.Parallel() - core, client, req := newCore(), AcquireClient(), AcquireRequest() + core, client, req := newCore(), NewClient(), AcquireRequest() client.SetDial(func(addr string) (net.Conn, error) { return ln.Dial() @@ -233,7 +233,7 @@ func Test_Execute(t *testing.T) { t.Run("request timeout has higher level", func(t *testing.T) { t.Parallel() - core, client, req := newCore(), AcquireClient(), AcquireRequest() + core, client, req := newCore(), NewClient(), AcquireRequest() client.SetTimeout(30 * time.Millisecond) client.SetDial(func(addr string) (net.Conn, error) { diff --git a/client/helper_test.go b/client/helper_test.go index fe5b70c65d..2b935a8fab 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -17,12 +17,16 @@ type testServer struct { tb testing.TB } -func startTestServer(tb testing.TB) *testServer { +func startTestServer(tb testing.TB, beforeStarting func(app *fiber.App)) *testServer { tb.Helper() ln := fasthttputil.NewInmemoryListener() app := fiber.New() + if beforeStarting != nil { + beforeStarting(app) + } + ch := make(chan struct{}) go func() { if err := app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true}); err != nil { @@ -89,8 +93,7 @@ func testRequest(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Requ c = count[0] } - client := AcquireClient().SetDial(ln) - defer ReleaseClient(client) + client := NewClient().SetDial(ln) for i := 0; i < c; i++ { req := AcquireRequest().SetClient(client) @@ -117,8 +120,7 @@ func testRequestFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent * c = count[0] } - client := AcquireClient().SetDial(ln) - defer ReleaseClient(client) + client := NewClient().SetDial(ln) for i := 0; i < c; i++ { req := AcquireRequest().SetClient(client) @@ -143,7 +145,7 @@ func testClient(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Clien } for i := 0; i < c; i++ { - client := AcquireClient().SetDial(ln) + client := NewClient().SetDial(ln) wrapAgent(client) resp, err := client.Get("http://example.com") diff --git a/client/hooks_test.go b/client/hooks_test.go index 6ec5669ea6..945e21aefc 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -41,7 +41,7 @@ func Test_Parser_Request_URL(t *testing.T) { t.Run("client baseurl should be set", func(t *testing.T) { t.Parallel() - client := AcquireClient().SetBaseURL("http://example.com/api") + client := NewClient().SetBaseURL("http://example.com/api") req := AcquireRequest().SetURL("") err := parserRequestURL(client, req) @@ -51,7 +51,7 @@ func Test_Parser_Request_URL(t *testing.T) { t.Run("request url should be set", func(t *testing.T) { t.Parallel() - client := AcquireClient() + client := NewClient() req := AcquireRequest().SetURL("http://example.com/api") err := parserRequestURL(client, req) @@ -61,7 +61,7 @@ func Test_Parser_Request_URL(t *testing.T) { t.Run("the request url will override baseurl with protocol", func(t *testing.T) { t.Parallel() - client := AcquireClient().SetBaseURL("http://example.com/api") + client := NewClient().SetBaseURL("http://example.com/api") req := AcquireRequest().SetURL("http://example.com/api/v1") err := parserRequestURL(client, req) @@ -71,7 +71,7 @@ func Test_Parser_Request_URL(t *testing.T) { t.Run("the request url should be append after baseurl without protocol", func(t *testing.T) { t.Parallel() - client := AcquireClient().SetBaseURL("http://example.com/api") + client := NewClient().SetBaseURL("http://example.com/api") req := AcquireRequest().SetURL("/v1") err := parserRequestURL(client, req) @@ -81,7 +81,7 @@ func Test_Parser_Request_URL(t *testing.T) { t.Run("the url is error", func(t *testing.T) { t.Parallel() - client := AcquireClient().SetBaseURL("example.com/api") + client := NewClient().SetBaseURL("example.com/api") req := AcquireRequest().SetURL("/v1") err := parserRequestURL(client, req) @@ -90,7 +90,7 @@ func Test_Parser_Request_URL(t *testing.T) { t.Run("the path param from client", func(t *testing.T) { t.Parallel() - client := AcquireClient(). + client := NewClient(). SetBaseURL("http://example.com/api/:id"). SetPathParam("id", "5") req := AcquireRequest() @@ -102,7 +102,7 @@ func Test_Parser_Request_URL(t *testing.T) { t.Run("the path param from request", func(t *testing.T) { t.Parallel() - client := AcquireClient(). + client := NewClient(). SetBaseURL("http://example.com/api/:id/:name"). SetPathParam("id", "5") req := AcquireRequest(). @@ -120,7 +120,7 @@ func Test_Parser_Request_URL(t *testing.T) { t.Run("the path param from request and client", func(t *testing.T) { t.Parallel() - client := AcquireClient(). + client := NewClient(). SetBaseURL("http://example.com/api/:id/:name"). SetPathParam("id", "5") req := AcquireRequest(). @@ -138,7 +138,7 @@ func Test_Parser_Request_URL(t *testing.T) { t.Run("query params from client should be set", func(t *testing.T) { t.Parallel() - client := AcquireClient(). + client := NewClient(). SetParam("foo", "bar") req := AcquireRequest().SetURL("http://example.com/api/v1") @@ -149,7 +149,7 @@ func Test_Parser_Request_URL(t *testing.T) { t.Run("query params from request should be set", func(t *testing.T) { t.Parallel() - client := AcquireClient() + client := NewClient() req := AcquireRequest(). SetURL("http://example.com/api/v1"). SetParam("bar", "foo") @@ -161,7 +161,7 @@ func Test_Parser_Request_URL(t *testing.T) { t.Run("query params should be merged", func(t *testing.T) { t.Parallel() - client := AcquireClient(). + client := NewClient(). SetParam("bar", "foo1") req := AcquireRequest(). SetURL("http://example.com/api/v1?bar=foo2"). @@ -193,7 +193,7 @@ func Test_Parser_Request_Header(t *testing.T) { t.Run("client header should be set", func(t *testing.T) { t.Parallel() - client := AcquireClient(). + client := NewClient(). SetHeaders(map[string]string{ fiber.HeaderContentType: "application/json", }) @@ -207,7 +207,7 @@ func Test_Parser_Request_Header(t *testing.T) { t.Run("request header should be set", func(t *testing.T) { t.Parallel() - client := AcquireClient() + client := NewClient() req := AcquireRequest(). SetHeaders(map[string]string{ @@ -221,7 +221,7 @@ func Test_Parser_Request_Header(t *testing.T) { t.Run("request header should override client header", func(t *testing.T) { t.Parallel() - client := AcquireClient(). + client := NewClient(). SetHeader(fiber.HeaderContentType, "application/xml") req := AcquireRequest(). @@ -237,7 +237,7 @@ func Test_Parser_Request_Header(t *testing.T) { type jsonData struct { Name string `json:"name"` } - client := AcquireClient() + client := NewClient() req := AcquireRequest(). SetJSON(jsonData{ Name: "foo", @@ -254,7 +254,7 @@ func Test_Parser_Request_Header(t *testing.T) { XMLName xml.Name `xml:"body"` Name string `xml:"name"` } - client := AcquireClient() + client := NewClient() req := AcquireRequest(). SetXML(xmlData{ Name: "foo", @@ -267,7 +267,7 @@ func Test_Parser_Request_Header(t *testing.T) { t.Run("auto set form data header", func(t *testing.T) { t.Parallel() - client := AcquireClient() + client := NewClient() req := AcquireRequest(). SetFormDatas(map[string]string{ "foo": "bar", @@ -281,7 +281,7 @@ func Test_Parser_Request_Header(t *testing.T) { t.Run("auto set file header", func(t *testing.T) { t.Parallel() - client := AcquireClient() + client := NewClient() req := AcquireRequest(). AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))). SetFormData("foo", "bar") @@ -294,7 +294,7 @@ func Test_Parser_Request_Header(t *testing.T) { t.Run("ua should have default value", func(t *testing.T) { t.Parallel() - client := AcquireClient() + client := NewClient() req := AcquireRequest() err := parserRequestHeader(client, req) @@ -304,7 +304,7 @@ func Test_Parser_Request_Header(t *testing.T) { t.Run("ua in client should be set", func(t *testing.T) { t.Parallel() - client := AcquireClient().SetUserAgent("foo") + client := NewClient().SetUserAgent("foo") req := AcquireRequest() err := parserRequestHeader(client, req) @@ -314,7 +314,7 @@ func Test_Parser_Request_Header(t *testing.T) { t.Run("ua in request should have higher level", func(t *testing.T) { t.Parallel() - client := AcquireClient().SetUserAgent("foo") + client := NewClient().SetUserAgent("foo") req := AcquireRequest().SetUserAgent("bar") err := parserRequestHeader(client, req) @@ -324,7 +324,7 @@ func Test_Parser_Request_Header(t *testing.T) { t.Run("referer in client should be set", func(t *testing.T) { t.Parallel() - client := AcquireClient().SetReferer("https://example.com") + client := NewClient().SetReferer("https://example.com") req := AcquireRequest() err := parserRequestHeader(client, req) @@ -334,7 +334,7 @@ func Test_Parser_Request_Header(t *testing.T) { t.Run("referer in request should have higher level", func(t *testing.T) { t.Parallel() - client := AcquireClient().SetReferer("http://example.com") + client := NewClient().SetReferer("http://example.com") req := AcquireRequest().SetReferer("https://example.com") err := parserRequestHeader(client, req) @@ -344,7 +344,7 @@ func Test_Parser_Request_Header(t *testing.T) { t.Run("client cookie should be set", func(t *testing.T) { t.Parallel() - client := AcquireClient(). + client := NewClient(). SetCookie("foo", "bar"). SetCookies(map[string]string{ "bar": "foo", @@ -368,7 +368,7 @@ func Test_Parser_Request_Header(t *testing.T) { Bar int `cookie:"bar"` } - client := AcquireClient() + client := NewClient() req := AcquireRequest(). SetCookiesWithStruct(&cookies{ @@ -390,7 +390,7 @@ func Test_Parser_Request_Header(t *testing.T) { Bar int `cookie:"bar"` } - client := AcquireClient(). + client := NewClient(). SetCookie("foo", "bar"). SetCookies(map[string]string{ "bar": "foo", @@ -419,7 +419,7 @@ func Test_Parser_Request_Body(t *testing.T) { type jsonData struct { Name string `json:"name"` } - client := AcquireClient() + client := NewClient() req := AcquireRequest(). SetJSON(jsonData{ Name: "foo", @@ -436,7 +436,7 @@ func Test_Parser_Request_Body(t *testing.T) { XMLName xml.Name `xml:"body"` Name string `xml:"name"` } - client := AcquireClient() + client := NewClient() req := AcquireRequest(). SetXML(xmlData{ Name: "foo", @@ -449,7 +449,7 @@ func Test_Parser_Request_Body(t *testing.T) { t.Run("form data body", func(t *testing.T) { t.Parallel() - client := AcquireClient() + client := NewClient() req := AcquireRequest(). SetFormDatas(map[string]string{ "ball": "cricle and square", @@ -462,7 +462,7 @@ func Test_Parser_Request_Body(t *testing.T) { t.Run("form data body error", func(t *testing.T) { t.Parallel() - client := AcquireClient() + client := NewClient() req := AcquireRequest(). SetFormDatas(map[string]string{ "": "", @@ -474,7 +474,7 @@ func Test_Parser_Request_Body(t *testing.T) { t.Run("file body", func(t *testing.T) { t.Parallel() - client := AcquireClient() + client := NewClient() req := AcquireRequest(). AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))) @@ -486,7 +486,7 @@ func Test_Parser_Request_Body(t *testing.T) { t.Run("file and form data", func(t *testing.T) { t.Parallel() - client := AcquireClient() + client := NewClient() req := AcquireRequest(). AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))). SetFormData("foo", "bar") @@ -500,7 +500,7 @@ func Test_Parser_Request_Body(t *testing.T) { t.Run("raw body", func(t *testing.T) { t.Parallel() - client := AcquireClient() + client := NewClient() req := AcquireRequest(). SetRawBody([]byte("hello world")) @@ -511,7 +511,7 @@ func Test_Parser_Request_Body(t *testing.T) { t.Run("raw body error", func(t *testing.T) { t.Parallel() - client := AcquireClient() + client := NewClient() req := AcquireRequest(). SetRawBody([]byte("hello world")) @@ -549,7 +549,7 @@ func Test_Client_Logger_Debug(t *testing.T) { var buf bytes.Buffer log.SetOutput(&buf) - client := AcquireClient() + client := NewClient() client.Debug() resp, err := client.Get("http://" + url) @@ -587,7 +587,7 @@ func Test_Client_Logger_DisableDebug(t *testing.T) { var buf bytes.Buffer log.SetOutput(&buf) - client := AcquireClient() + client := NewClient() client.DisableDebug() resp, err := client.Get("http://" + url) diff --git a/client/request_test.go b/client/request_test.go index 9df77b1422..ff0b5b81f6 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -52,7 +52,7 @@ func Test_Request_URL(t *testing.T) { func Test_Request_Client(t *testing.T) { t.Parallel() - client := AcquireClient() + client := NewClient() req := AcquireRequest() req.SetClient(client) @@ -629,8 +629,7 @@ func Test_Request_Get(t *testing.T) { go start() time.Sleep(100 * time.Millisecond) - client := AcquireClient().SetDial(ln) - defer ReleaseClient(client) + client := NewClient().SetDial(ln) for i := 0; i < 5; i++ { req := AcquireRequest().SetClient(client) @@ -655,8 +654,7 @@ func Test_Request_Post(t *testing.T) { go start() time.Sleep(100 * time.Millisecond) - client := AcquireClient().SetDial(ln) - defer ReleaseClient(client) + client := NewClient().SetDial(ln) for i := 0; i < 5; i++ { resp, err := AcquireRequest(). @@ -682,8 +680,7 @@ func Test_Request_Head(t *testing.T) { go start() time.Sleep(100 * time.Millisecond) - client := AcquireClient().SetDial(ln) - defer ReleaseClient(client) + client := NewClient().SetDial(ln) for i := 0; i < 5; i++ { resp, err := AcquireRequest(). @@ -708,8 +705,7 @@ func Test_Request_Put(t *testing.T) { go start() time.Sleep(100 * time.Millisecond) - client := AcquireClient().SetDial(ln) - defer ReleaseClient(client) + client := NewClient().SetDial(ln) for i := 0; i < 5; i++ { resp, err := AcquireRequest(). @@ -738,8 +734,7 @@ func Test_Request_Delete(t *testing.T) { go start() time.Sleep(100 * time.Millisecond) - client := AcquireClient().SetDial(ln) - defer ReleaseClient(client) + client := NewClient().SetDial(ln) for i := 0; i < 5; i++ { resp, err := AcquireRequest(). @@ -767,8 +762,7 @@ func Test_Request_Options(t *testing.T) { go start() time.Sleep(100 * time.Millisecond) - client := AcquireClient().SetDial(ln) - defer ReleaseClient(client) + client := NewClient().SetDial(ln) for i := 0; i < 5; i++ { resp, err := AcquireRequest(). @@ -796,8 +790,7 @@ func Test_Request_Send(t *testing.T) { go start() time.Sleep(100 * time.Millisecond) - client := AcquireClient().SetDial(ln) - defer ReleaseClient(client) + client := NewClient().SetDial(ln) for i := 0; i < 5; i++ { resp, err := AcquireRequest(). @@ -826,8 +819,7 @@ func Test_Request_Patch(t *testing.T) { go start() time.Sleep(100 * time.Millisecond) - client := AcquireClient().SetDial(ln) - defer ReleaseClient(client) + client := NewClient().SetDial(ln) for i := 0; i < 5; i++ { resp, err := AcquireRequest(). @@ -1027,8 +1019,7 @@ func Test_Request_Body_With_Server(t *testing.T) { go start() - client := AcquireClient().SetDial(ln) - defer ReleaseClient(client) + client := NewClient().SetDial(ln) req := AcquireRequest(). SetClient(client). @@ -1083,8 +1074,7 @@ func Test_Request_Body_With_Server(t *testing.T) { go start() - client := AcquireClient().SetDial(ln) - defer ReleaseClient(client) + client := NewClient().SetDial(ln) for i := 0; i < 5; i++ { req := AcquireRequest(). @@ -1121,8 +1111,7 @@ func Test_Request_Body_With_Server(t *testing.T) { go start() - client := AcquireClient().SetDial(ln) - defer ReleaseClient(client) + client := NewClient().SetDial(ln) req := AcquireRequest(). SetClient(client). @@ -1210,8 +1199,7 @@ func Test_Request_Timeout_With_Server(t *testing.T) { }) go start() - client := AcquireClient().SetDial(ln) - defer ReleaseClient(client) + client := NewClient().SetDial(ln) _, err := AcquireRequest(). SetClient(client). @@ -1243,8 +1231,7 @@ func Test_Request_MaxRedirects(t *testing.T) { t.Run("success", func(t *testing.T) { t.Parallel() - client := AcquireClient().SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }) - defer ReleaseClient(client) + client := NewClient().SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }) resp, err := AcquireRequest(). SetClient(client). @@ -1263,8 +1250,7 @@ func Test_Request_MaxRedirects(t *testing.T) { t.Run("error", func(t *testing.T) { t.Parallel() - client := AcquireClient().SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }) - defer ReleaseClient(client) + client := NewClient().SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }) resp, err := AcquireRequest(). SetClient(client). @@ -1278,8 +1264,7 @@ func Test_Request_MaxRedirects(t *testing.T) { t.Run("MaxRedirects", func(t *testing.T) { t.Parallel() - client := AcquireClient().SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }) - defer ReleaseClient(client) + client := NewClient().SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }) req := AcquireRequest(). SetClient(client). diff --git a/client/response_test.go b/client/response_test.go index 47646c4f5d..c474c40a20 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -19,13 +19,13 @@ func Test_Response_Status(t *testing.T) { t.Parallel() setupApp := func() *testServer { - server := startTestServer(t) - - server.app.Get("/", func(c fiber.Ctx) error { - return c.SendString("foo") - }) - server.app.Get("/fail", func(c fiber.Ctx) error { - return c.SendStatus(407) + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("foo") + }) + app.Get("/fail", func(c fiber.Ctx) error { + return c.SendStatus(407) + }) }) return server @@ -37,8 +37,7 @@ func Test_Response_Status(t *testing.T) { server := setupApp() defer server.stop() - client := AcquireClient().SetDial(server.dial()) - defer ReleaseClient(client) + client := NewClient().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). @@ -55,8 +54,7 @@ func Test_Response_Status(t *testing.T) { server := setupApp() defer server.stop() - client := AcquireClient().SetDial(server.dial()) - defer ReleaseClient(client) + client := NewClient().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). @@ -72,13 +70,13 @@ func Test_Response_Status_Code(t *testing.T) { t.Parallel() setupApp := func() *testServer { - server := startTestServer(t) - - server.app.Get("/", func(c fiber.Ctx) error { - return c.SendString("foo") - }) - server.app.Get("/fail", func(c fiber.Ctx) error { - return c.SendStatus(407) + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("foo") + }) + app.Get("/fail", func(c fiber.Ctx) error { + return c.SendStatus(407) + }) }) return server @@ -90,8 +88,7 @@ func Test_Response_Status_Code(t *testing.T) { server := setupApp() defer server.stop() - client := AcquireClient().SetDial(server.dial()) - defer ReleaseClient(client) + client := NewClient().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). @@ -108,8 +105,7 @@ func Test_Response_Status_Code(t *testing.T) { server := setupApp() defer server.stop() - client := AcquireClient().SetDial(server.dial()) - defer ReleaseClient(client) + client := NewClient().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). @@ -127,15 +123,14 @@ func Test_Response_Protocol(t *testing.T) { t.Run("http", func(t *testing.T) { t.Parallel() - server := startTestServer(t) - defer server.stop() - - server.app.Get("/", func(c fiber.Ctx) error { - return c.SendString("foo") + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("foo") + }) }) + defer server.stop() - client := AcquireClient().SetDial(server.dial()) - defer ReleaseClient(client) + client := NewClient().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). @@ -168,7 +163,7 @@ func Test_Response_Protocol(t *testing.T) { })) }() - client := AcquireClient() + client := NewClient() resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) require.NoError(t, err) @@ -184,15 +179,15 @@ func Test_Response_Protocol(t *testing.T) { func Test_Response_Header(t *testing.T) { t.Parallel() - server := startTestServer(t) - defer server.stop() - server.app.Get("/", func(c fiber.Ctx) error { - c.Response().Header.Add("foo", "bar") - return c.SendString("helo world") + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + c.Response().Header.Add("foo", "bar") + return c.SendString("helo world") + }) }) + defer server.stop() - client := AcquireClient().SetDial(server.dial()) - defer ReleaseClient(client) + client := NewClient().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). @@ -206,18 +201,18 @@ func Test_Response_Header(t *testing.T) { func Test_Response_Cookie(t *testing.T) { t.Parallel() - server := startTestServer(t) - defer server.stop() - server.app.Get("/", func(c fiber.Ctx) error { - c.Cookie(&fiber.Cookie{ - Name: "foo", - Value: "bar", + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + c.Cookie(&fiber.Cookie{ + Name: "foo", + Value: "bar", + }) + return c.SendString("helo world") }) - return c.SendString("helo world") }) + defer server.stop() - client := AcquireClient().SetDial(server.dial()) - defer ReleaseClient(client) + client := NewClient().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). @@ -232,16 +227,18 @@ func Test_Response_Body(t *testing.T) { t.Parallel() setupApp := func() *testServer { - server := startTestServer(t) - - server.app.Get("/", func(c fiber.Ctx) error { - return c.SendString("hello world") - }) - server.app.Get("/json", func(c fiber.Ctx) error { - return c.SendString("{\"status\":\"success\"}") - }) - server.app.Get("/xml", func(c fiber.Ctx) error { - return c.SendString("success") + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("hello world") + }) + + app.Get("/json", func(c fiber.Ctx) error { + return c.SendString("{\"status\":\"success\"}") + }) + + app.Get("/xml", func(c fiber.Ctx) error { + return c.SendString("success") + }) }) return server @@ -253,8 +250,7 @@ func Test_Response_Body(t *testing.T) { server := setupApp() defer server.stop() - client := AcquireClient().SetDial(server.dial()) - defer ReleaseClient(client) + client := NewClient().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). @@ -271,8 +267,7 @@ func Test_Response_Body(t *testing.T) { server := setupApp() defer server.stop() - client := AcquireClient().SetDial(server.dial()) - defer ReleaseClient(client) + client := NewClient().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). @@ -292,8 +287,7 @@ func Test_Response_Body(t *testing.T) { server := setupApp() defer server.stop() - client := AcquireClient().SetDial(server.dial()) - defer ReleaseClient(client) + client := NewClient().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). @@ -318,8 +312,7 @@ func Test_Response_Body(t *testing.T) { server := setupApp() defer server.stop() - client := AcquireClient().SetDial(server.dial()) - defer ReleaseClient(client) + client := NewClient().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). @@ -339,10 +332,10 @@ func Test_Response_Save(t *testing.T) { t.Parallel() setupApp := func() *testServer { - server := startTestServer(t) - - server.app.Get("/json", func(c fiber.Ctx) error { - return c.SendString("{\"status\":\"success\"}") + server := startTestServer(t, func(app *fiber.App) { + app.Get("/json", func(c fiber.Ctx) error { + return c.SendString("{\"status\":\"success\"}") + }) }) return server @@ -354,8 +347,7 @@ func Test_Response_Save(t *testing.T) { server := setupApp() defer server.stop() - client := AcquireClient().SetDial(server.dial()) - defer ReleaseClient(client) + client := NewClient().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). @@ -392,8 +384,7 @@ func Test_Response_Save(t *testing.T) { server := setupApp() defer server.stop() - client := AcquireClient().SetDial(server.dial()) - defer ReleaseClient(client) + client := NewClient().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). @@ -414,8 +405,7 @@ func Test_Response_Save(t *testing.T) { server := setupApp() defer server.stop() - client := AcquireClient().SetDial(server.dial()) - defer ReleaseClient(client) + client := NewClient().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). From b978fda33a84bf86cc23cc892d046729f0b94420 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sun, 25 Feb 2024 01:46:08 +0300 Subject: [PATCH 099/118] fix data races on logger tests --- client/client.go | 32 +++++++++++++++---- client/hooks.go | 5 ++- client/hooks_test.go | 76 ++++++++++++++++++++++++++++++++++---------- 3 files changed, 87 insertions(+), 26 deletions(-) diff --git a/client/client.go b/client/client.go index 10b11de166..1b8d107a1f 100644 --- a/client/client.go +++ b/client/client.go @@ -15,6 +15,7 @@ import ( "time" "github.com/gofiber/fiber/v3/log" + "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" @@ -72,6 +73,9 @@ type Client struct { // retry retryConfig *RetryConfig + + // logger + logger log.CommonLogger } // R raise a request from the client. @@ -181,7 +185,7 @@ func (c *Client) SetRootCertificate(path string) *Client { cleanPath := filepath.Clean(path) file, err := os.Open(cleanPath) if err != nil { - log.Panicf("client: %v", err) + c.logger.Panicf("client: %v", err) } defer func() { _ = file.Close() //nolint:errcheck // It is fine to ignore the error here @@ -189,7 +193,7 @@ func (c *Client) SetRootCertificate(path string) *Client { pem, err := io.ReadAll(file) if err != nil { - log.Panicf("client: %v", err) + c.logger.Panicf("client: %v", err) } config := c.TLSConfig() @@ -198,7 +202,7 @@ func (c *Client) SetRootCertificate(path string) *Client { } if !config.RootCAs.AppendCertsFromPEM(pem) { - log.Panicf("client: %v", ErrFailedToAppendCert) + c.logger.Panicf("client: %v", ErrFailedToAppendCert) } return c @@ -213,7 +217,7 @@ func (c *Client) SetRootCertificateFromString(pem string) *Client { } if !config.RootCAs.AppendCertsFromPEM([]byte(pem)) { - log.Panicf("client: %v", ErrFailedToAppendCert) + c.logger.Panicf("client: %v", ErrFailedToAppendCert) } return c @@ -223,12 +227,12 @@ func (c *Client) SetRootCertificateFromString(pem string) *Client { func (c *Client) SetProxyURL(proxyURL string) *Client { pURL, err := urlpkg.Parse(proxyURL) if err != nil { - log.Panicf("client: %v", err) + c.logger.Panicf("client: %v", err) return c } if pURL.Scheme != "http" && pURL.Scheme != "https" { - log.Panicf("client: %v", ErrInvalidProxyURL) + c.logger.Panicf("client: %v", ErrInvalidProxyURL) return c } @@ -546,6 +550,7 @@ func (c *Client) Custom(url string, method string, cfg ...Config) (*Response, er return req.Custom(url, method) } +// SetDial sets dial function in client. func (c *Client) SetDial(dial fasthttp.DialFunc) *Client { c.mu.Lock() defer c.mu.Unlock() @@ -554,6 +559,20 @@ func (c *Client) SetDial(dial fasthttp.DialFunc) *Client { return c } +// SetLogger sets logger instance in client. +func (c *Client) SetLogger(logger log.CommonLogger) *Client { + c.mu.Lock() + defer c.mu.Unlock() + + c.logger = logger + return c +} + +// Logger returns logger instance of client. +func (c *Client) Logger() log.CommonLogger { + return c.logger +} + // Reset clear Client object func (c *Client) Reset() { c.client = &fasthttp.Client{} @@ -682,6 +701,7 @@ func NewClient() *Client { jsonUnmarshal: json.Unmarshal, xmlMarshal: xml.Marshal, xmlUnmarshal: xml.Unmarshal, + logger: log.DefaultLogger(), } } diff --git a/client/hooks.go b/client/hooks.go index 8b61d8c12a..b1c76a326d 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -11,7 +11,6 @@ import ( "strings" "time" - "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -300,8 +299,8 @@ func logger(c *Client, resp *Response, req *Request) error { return nil } - log.Debugf("%s\n", req.RawRequest.String()) - log.Debugf("%s\n", resp.RawResponse.String()) + c.logger.Debugf("%s\n", req.RawRequest.String()) + c.logger.Debugf("%s\n", resp.RawResponse.String()) return nil } diff --git a/client/hooks_test.go b/client/hooks_test.go index 945e21aefc..e5410bed94 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -3,14 +3,12 @@ package client import ( "bytes" "encoding/xml" + "fmt" "io" "net" "net/url" "strings" "testing" - "time" - - "github.com/gofiber/fiber/v3/log" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" @@ -522,6 +520,54 @@ func Test_Parser_Request_Body(t *testing.T) { }) } +type dummyLogger struct { + buf *bytes.Buffer +} + +func (l *dummyLogger) Trace(v ...any) {} + +func (l *dummyLogger) Debug(v ...any) {} + +func (l *dummyLogger) Info(v ...any) {} + +func (l *dummyLogger) Warn(v ...any) {} + +func (l *dummyLogger) Error(v ...any) {} + +func (l *dummyLogger) Fatal(v ...any) {} + +func (l *dummyLogger) Panic(v ...any) {} + +func (l *dummyLogger) Tracef(format string, v ...any) {} + +func (l *dummyLogger) Debugf(format string, v ...any) { + l.buf.WriteString(fmt.Sprintf(format, v...)) +} + +func (l *dummyLogger) Infof(format string, v ...any) {} + +func (l *dummyLogger) Warnf(format string, v ...any) {} + +func (l *dummyLogger) Errorf(format string, v ...any) {} + +func (l *dummyLogger) Fatalf(format string, v ...any) {} + +func (l *dummyLogger) Panicf(format string, v ...any) {} + +func (l *dummyLogger) Tracew(msg string, keysAndValues ...any) {} + +func (l *dummyLogger) Debugw(msg string, keysAndValues ...any) {} + +func (l *dummyLogger) Infow(msg string, keysAndValues ...any) {} + +func (l *dummyLogger) Warnw(msg string, keysAndValues ...any) {} + +func (l *dummyLogger) Errorw(msg string, keysAndValues ...any) {} + +func (l *dummyLogger) Fatalw(msg string, keysAndValues ...any) {} + +func (l *dummyLogger) Panicw(msg string, keysAndValues ...any) {} + func Test_Client_Logger_Debug(t *testing.T) { t.Parallel() app := fiber.New() @@ -529,13 +575,12 @@ func Test_Client_Logger_Debug(t *testing.T) { return c.SendString("response") }) - var url string - + addrChan := make(chan string) go func() { require.NoError(t, app.Listen(":0", fiber.ListenConfig{ DisableStartupMessage: true, ListenerAddrFunc: func(addr net.Addr) { - url = addr.String() + addrChan <- addr.String() }, })) }() @@ -544,14 +589,13 @@ func Test_Client_Logger_Debug(t *testing.T) { _ = app.Shutdown() }(app) - time.Sleep(1 * time.Second) - var buf bytes.Buffer - log.SetOutput(&buf) + logger := &dummyLogger{buf: &buf} client := NewClient() - client.Debug() + client.Debug().SetLogger(logger) + url := <-addrChan resp, err := client.Get("http://" + url) defer resp.Close() @@ -567,13 +611,12 @@ func Test_Client_Logger_DisableDebug(t *testing.T) { return c.SendString("response") }) - var url string - + addrChan := make(chan string) go func() { require.NoError(t, app.Listen(":0", fiber.ListenConfig{ DisableStartupMessage: true, ListenerAddrFunc: func(addr net.Addr) { - url = addr.String() + addrChan <- addr.String() }, })) }() @@ -582,14 +625,13 @@ func Test_Client_Logger_DisableDebug(t *testing.T) { _ = app.Shutdown() }(app) - time.Sleep(1 * time.Second) - var buf bytes.Buffer - log.SetOutput(&buf) + logger := &dummyLogger{buf: &buf} client := NewClient() - client.DisableDebug() + client.DisableDebug().SetLogger(logger) + url := <-addrChan resp, err := client.Get("http://" + url) defer resp.Close() From be94a3f9676d46150d6ad4fe4bcdef8ffb9b2127 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sun, 25 Feb 2024 01:46:27 +0300 Subject: [PATCH 100/118] fix proxy tests --- middleware/proxy/proxy_test.go | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 793f5a09ae..9ded0a6d64 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -138,8 +138,7 @@ func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) { })) }() - client := clientpkg.AcquireClient() - defer clientpkg.ReleaseClient(client) + client := clientpkg.NewClient() client.SetTLSConfig(clientTLSConf) resp, err := client.Get("https://" + addr + "/tlsbalancer") @@ -177,8 +176,7 @@ func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) { })) }() - client := clientpkg.AcquireClient() - defer clientpkg.ReleaseClient(client) + client := clientpkg.NewClient() client.SetTimeout(5 * time.Second) client.TLSConfig().InsecureSkipVerify = true @@ -243,8 +241,7 @@ func Test_Proxy_Forward_WithClient_TLSConfig(t *testing.T) { })) }() - client := clientpkg.AcquireClient() - defer clientpkg.ReleaseClient(client) + client := clientpkg.NewClient() client.SetTLSConfig(clientTLSConf) resp, err := client.Get("https://" + addr) @@ -599,8 +596,7 @@ func Test_Proxy_Forward_Global_Client(t *testing.T) { })) }() - client := clientpkg.AcquireClient() - defer clientpkg.ReleaseClient(client) + client := clientpkg.NewClient() resp, err := client.Get("http://" + addr) require.NoError(t, err) @@ -632,8 +628,7 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) { })) }() - client := clientpkg.AcquireClient() - defer clientpkg.ReleaseClient(client) + client := clientpkg.NewClient() resp, err := client.Get("http://" + addr) require.NoError(t, err) @@ -711,8 +706,7 @@ func Test_Proxy_Domain_Forward_Local(t *testing.T) { })) }() - client := clientpkg.AcquireClient() - defer clientpkg.ReleaseClient(client) + client := clientpkg.NewClient() resp, err := client.Get("http://" + localDomain + "/test?query_test=true") require.NoError(t, err) From ed3c640a85f7837bb892811bfcb78542c06047b7 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sun, 25 Feb 2024 13:13:10 +0300 Subject: [PATCH 101/118] fix global tests --- client/client_test.go | 255 +++++++++++++++++++++++++++++------------- 1 file changed, 180 insertions(+), 75 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index e7e4ac0847..d20e1be563 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -12,14 +12,34 @@ import ( "testing" "time" - "github.com/gofiber/fiber/v3/addon/retry" - "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/addon/retry" "github.com/gofiber/fiber/v3/internal/tlstest" "github.com/gofiber/utils/v2" "github.com/stretchr/testify/require" ) +func startTestServerWithPort(t *testing.T, beforeStarting func(app *fiber.App)) (*fiber.App, string) { + app := fiber.New() + + if beforeStarting != nil { + beforeStarting(app) + } + + addrChan := make(chan string) + go func() { + require.NoError(t, app.Listen(":0", fiber.ListenConfig{ + DisableStartupMessage: true, + ListenerAddrFunc: func(addr net.Addr) { + addrChan <- addr.String() + }, + })) + }() + + addr := <-addrChan + return app, addr +} + func Test_Client_Add_Hook(t *testing.T) { t.Parallel() @@ -153,14 +173,15 @@ func Test_Client_ConcurrencyRequests(t *testing.T) { }) go start() + client := NewClient().SetDial(dial) + wg := sync.WaitGroup{} for i := 0; i < 5; i++ { - C().SetDial(dial) for _, method := range []string{"GET", "POST", "PUT", "DELETE", "PATCH"} { wg.Add(1) go func(m string) { defer wg.Done() - resp, err := C().Custom("http://example.com", m) + resp, err := client.Custom("http://example.com", m) require.NoError(t, err) require.Equal(t, "example.com "+m, utils.UnsafeString(resp.RawResponse.Body())) }(method) @@ -173,52 +194,70 @@ func Test_Client_ConcurrencyRequests(t *testing.T) { func Test_Get(t *testing.T) { t.Parallel() - app, dial, start := createHelperServer(t) - - app.Get("/", func(c fiber.Ctx) error { - return c.SendString(c.Hostname()) - }) + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + }) - go start() + return app, addr + } t.Run("global get function", func(t *testing.T) { t.Parallel() - C().SetDial(dial) - resp, err := Get("http://example.com") + + app, addr := setupApp() + defer app.Shutdown() + + resp, err := Get("http://" + addr) require.NoError(t, err) - require.Equal(t, "example.com", utils.UnsafeString(resp.RawResponse.Body())) + require.Equal(t, "0.0.0.0", utils.UnsafeString(resp.RawResponse.Body())) }) t.Run("client get", func(t *testing.T) { t.Parallel() - resp, err := NewClient().SetDial(dial).Get("http://example.com") + + app, addr := setupApp() + defer app.Shutdown() + + resp, err := NewClient().Get("http://" + addr) require.NoError(t, err) - require.Equal(t, "example.com", utils.UnsafeString(resp.RawResponse.Body())) + require.Equal(t, "0.0.0.0", utils.UnsafeString(resp.RawResponse.Body())) }) } func Test_Head(t *testing.T) { t.Parallel() - app, dial, start := createHelperServer(t) - - app.Head("/", func(c fiber.Ctx) error { - return c.SendString(c.Hostname()) - }) + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Head("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + }) - go start() + return app, addr + } t.Run("global head function", func(t *testing.T) { t.Parallel() - C().SetDial(dial) - resp, err := Head("http://example.com") + + app, addr := setupApp() + defer app.Shutdown() + + resp, err := Head("http://" + addr) require.NoError(t, err) require.Equal(t, "", utils.UnsafeString(resp.RawResponse.Body())) }) t.Run("client head", func(t *testing.T) { t.Parallel() - resp, err := NewClient().SetDial(dial).Head("http://example.com") + + app, addr := setupApp() + defer app.Shutdown() + + resp, err := NewClient().Head("http://" + addr) require.NoError(t, err) require.Equal(t, "", utils.UnsafeString(resp.RawResponse.Body())) }) @@ -227,19 +266,25 @@ func Test_Head(t *testing.T) { func Test_Post(t *testing.T) { t.Parallel() - app, dial, start := createHelperServer(t) - app.Post("/", func(c fiber.Ctx) error { - return c.Status(fiber.StatusCreated). - SendString(c.FormValue("foo")) - }) + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Post("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusCreated). + SendString(c.FormValue("foo")) + }) + }) - go start() + return app, addr + } t.Run("global post function", func(t *testing.T) { t.Parallel() + + app, addr := setupApp() + defer app.Shutdown() + for i := 0; i < 5; i++ { - C().SetDial(dial) - resp, err := Post("http://example.com", Config{ + resp, err := Post("http://"+addr, Config{ FormData: map[string]string{ "foo": "bar", }, @@ -253,8 +298,12 @@ func Test_Post(t *testing.T) { t.Run("client post", func(t *testing.T) { t.Parallel() + + app, addr := setupApp() + defer app.Shutdown() + for i := 0; i < 5; i++ { - resp, err := NewClient().SetDial(dial).Post("http://example.com", Config{ + resp, err := NewClient().Post("http://"+addr, Config{ FormData: map[string]string{ "foo": "bar", }, @@ -270,17 +319,24 @@ func Test_Post(t *testing.T) { func Test_Put(t *testing.T) { t.Parallel() - app, dial, start := createHelperServer(t) - app.Put("/", func(c fiber.Ctx) error { - return c.SendString(c.FormValue("foo")) - }) + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Put("/", func(c fiber.Ctx) error { + return c.SendString(c.FormValue("foo")) + }) + }) - go start() + return app, addr + } t.Run("global put function", func(t *testing.T) { t.Parallel() + + app, addr := setupApp() + defer app.Shutdown() + for i := 0; i < 5; i++ { - resp, err := Put("http://example.com", Config{ + resp, err := Put("http://"+addr, Config{ FormData: map[string]string{ "foo": "bar", }, @@ -294,8 +350,12 @@ func Test_Put(t *testing.T) { t.Run("client put", func(t *testing.T) { t.Parallel() + + app, addr := setupApp() + defer app.Shutdown() + for i := 0; i < 5; i++ { - resp, err := NewClient().SetDial(dial).Put("http://example.com", Config{ + resp, err := NewClient().Put("http://"+addr, Config{ FormData: map[string]string{ "foo": "bar", }, @@ -311,18 +371,27 @@ func Test_Put(t *testing.T) { func Test_Delete(t *testing.T) { t.Parallel() - app, dial, start := createHelperServer(t) - app.Delete("/", func(c fiber.Ctx) error { - return c.Status(fiber.StatusNoContent). - SendString("deleted") - }) + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Delete("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusNoContent). + SendString("deleted") + }) + }) - go start() + return app, addr + } t.Run("global delete function", func(t *testing.T) { t.Parallel() + + app, addr := setupApp() + defer app.Shutdown() + + time.Sleep(1 * time.Second) + for i := 0; i < 5; i++ { - resp, err := Delete("http://example.com", Config{ + resp, err := Delete("http://"+addr, Config{ FormData: map[string]string{ "foo": "bar", }, @@ -336,8 +405,12 @@ func Test_Delete(t *testing.T) { t.Run("client delete", func(t *testing.T) { t.Parallel() + + app, addr := setupApp() + defer app.Shutdown() + for i := 0; i < 5; i++ { - resp, err := NewClient().SetDial(dial).Delete("http://example.com", Config{ + resp, err := NewClient().Delete("http://"+addr, Config{ FormData: map[string]string{ "foo": "bar", }, @@ -353,18 +426,24 @@ func Test_Delete(t *testing.T) { func Test_Options(t *testing.T) { t.Parallel() - app, dial, start := createHelperServer(t) - app.Options("/", func(c fiber.Ctx) error { - return c.Status(fiber.StatusNoContent).SendString("") - }) + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Options("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusNoContent).SendString("") + }) + }) - go start() + return app, addr + } t.Run("global options function", func(t *testing.T) { t.Parallel() + + app, addr := setupApp() + defer app.Shutdown() + for i := 0; i < 5; i++ { - C().SetDial(dial) - resp, err := Options("http://example.com") + resp, err := Options("http://" + addr) require.NoError(t, err) require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) @@ -374,8 +453,12 @@ func Test_Options(t *testing.T) { t.Run("client options", func(t *testing.T) { t.Parallel() + + app, addr := setupApp() + defer app.Shutdown() + for i := 0; i < 5; i++ { - resp, err := NewClient().SetDial(dial).Options("http://example.com") + resp, err := NewClient().Options("http://" + addr) require.NoError(t, err) require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) @@ -387,19 +470,26 @@ func Test_Options(t *testing.T) { func Test_Patch(t *testing.T) { t.Parallel() - app, dial, start := createHelperServer(t) - - app.Patch("/", func(c fiber.Ctx) error { - return c.SendString(c.FormValue("foo")) - }) + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Patch("/", func(c fiber.Ctx) error { + return c.SendString(c.FormValue("foo")) + }) + }) - go start() + return app, addr + } t.Run("global patch function", func(t *testing.T) { t.Parallel() + + app, addr := setupApp() + defer app.Shutdown() + + time.Sleep(1 * time.Second) + for i := 0; i < 5; i++ { - C().SetDial(dial) - resp, err := Patch("http://example.com", Config{ + resp, err := Patch("http://"+addr, Config{ FormData: map[string]string{ "foo": "bar", }, @@ -413,8 +503,12 @@ func Test_Patch(t *testing.T) { t.Run("client patch", func(t *testing.T) { t.Parallel() + + app, addr := setupApp() + defer app.Shutdown() + for i := 0; i < 5; i++ { - resp, err := NewClient().SetDial(dial).Patch("http://example.com", Config{ + resp, err := NewClient().Patch("http://"+addr, Config{ FormData: map[string]string{ "foo": "bar", }, @@ -430,19 +524,24 @@ func Test_Patch(t *testing.T) { func Test_Client_UserAgent(t *testing.T) { t.Parallel() - app, dial, start := createHelperServer(t) - - app.Get("/", func(c fiber.Ctx) error { - return c.Send(c.Request().Header.UserAgent()) - }) + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.Send(c.Request().Header.UserAgent()) + }) + }) - go start() + return app, addr + } t.Run("default", func(t *testing.T) { t.Parallel() + + app, addr := setupApp() + defer app.Shutdown() + for i := 0; i < 5; i++ { - C().SetDial(dial) - resp, err := Get("http://example.com") + resp, err := Get("http://" + addr) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) @@ -452,11 +551,15 @@ func Test_Client_UserAgent(t *testing.T) { t.Run("custom", func(t *testing.T) { t.Parallel() + + app, addr := setupApp() + defer app.Shutdown() + for i := 0; i < 5; i++ { - c := NewClient().SetDial(dial). + c := NewClient(). SetUserAgent("ua") - resp, err := c.Get("http://example.com") + resp, err := c.Get("http://" + addr) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) @@ -1107,6 +1210,8 @@ func Test_Replace(t *testing.T) { require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "", resp.String()) + + C().SetDial(nil) } func Test_Set_Config_To_Request(t *testing.T) { From 8ae4c353f0cc862b16b4326632d02afb794705ff Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sun, 25 Feb 2024 13:14:49 +0300 Subject: [PATCH 102/118] remove unused code --- client/core.go | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/client/core.go b/client/core.go index 44b4e5f71d..726f2a5ecb 100644 --- a/client/core.go +++ b/client/core.go @@ -199,37 +199,6 @@ func (c *core) timeout() context.CancelFunc { return cancel } -// configureHostClient set configureHostClient in host. -/*func (c *core) configureHostClient(hostClient *fasthttp.Client) error { - // tls and dial configuration - c.client.mu.Lock() - hostClient.TLSConfig = c.client.tlsConfig.Clone() - hostClient.Dial = c.req.dial - c.client.mu.Unlock() - - rawURI := c.req.RawRequest.URI() - if c.client.proxyURL != "" { - rawURI = fasthttp.AcquireURI() - rawURI.Update(c.client.proxyURL) - defer fasthttp.ReleaseURI(rawURI) - } - - isTLS, scheme := false, rawURI.Scheme() - if bytes.Equal(httpsBytes, scheme) { - isTLS = true - } else if !bytes.Equal(httpBytes, scheme) { - return ErrNotSupportSchema - } - - // proxy configuration - c.client.mu.Lock() - hostClient.Addr = addMissingPort(string(rawURI.Host()), isTLS) - hostClient.IsTLS = isTLS - c.client.mu.Unlock() - - return nil -}*/ - // execute will exec each hooks and plugins. func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Response, error) { // keep a reference, because pass param is boring From c5901b20e8efd2f626a53854cb1c0f916c7780e3 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sun, 25 Feb 2024 13:20:07 +0300 Subject: [PATCH 103/118] fix logger test --- log/default_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/log/default_test.go b/log/default_test.go index db29602ef6..e9cffb9611 100644 --- a/log/default_test.go +++ b/log/default_test.go @@ -230,7 +230,7 @@ func Test_WithContextCaller(t *testing.T) { WithContext(ctx).Info("") Info("") - require.Equal(t, "default_test.go:223: [Info] \ndefault_test.go:224: [Info] \n", string(w.b)) + require.Equal(t, "default_test.go:230: [Info] \ndefault_test.go:231: [Info] \n", string(w.b)) } func Test_SetLevel(t *testing.T) { From 96cb295a625e7082681bb72f1248ed52e92cd9ea Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sun, 25 Feb 2024 13:23:48 +0300 Subject: [PATCH 104/118] fix proxy tests --- middleware/proxy/proxy_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 9ded0a6d64..e2b58a29d5 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -480,7 +480,7 @@ func Test_Proxy_DoTimeout_RestoreOriginalURL(t *testing.T) { }) resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) - require.Error(t, err1) + require.NoError(t, err1) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "proxied", string(body)) From d8ee1445c28f8ffda1a65d1a18d4df69d8926bfb Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sun, 25 Feb 2024 14:11:14 +0300 Subject: [PATCH 105/118] fix linter --- client/client.go | 4 +- client/client_test.go | 99 +++++++++++++++++++++++++++++------------- client/cookiejar.go | 15 ++++--- client/core.go | 20 ++++----- client/core_test.go | 41 +++++++++-------- client/helper_test.go | 17 ++++---- client/hooks.go | 42 ++++++++++++------ client/hooks_test.go | 61 +++++++++++++------------- client/request.go | 17 +++++--- client/request_test.go | 26 ++++++----- client/response.go | 15 ++++--- 11 files changed, 210 insertions(+), 147 deletions(-) diff --git a/client/client.go b/client/client.go index 1b8d107a1f..0d3db2fa05 100644 --- a/client/client.go +++ b/client/client.go @@ -543,7 +543,7 @@ func (c *Client) Patch(url string, cfg ...Config) (*Response, error) { } // Custom provide an API like axios which send custom request. -func (c *Client) Custom(url string, method string, cfg ...Config) (*Response, error) { +func (c *Client) Custom(url, method string, cfg ...Config) (*Response, error) { req := AcquireRequest().SetClient(c) setConfigToRequest(req, cfg...) @@ -593,7 +593,7 @@ func (c *Client) Reset() { // Body is higher than FormData, and the priority of FormData // is higher than File. type Config struct { - Ctx context.Context + Ctx context.Context //nolint:containedctx // It's needed to be stored in the config. UserAgent string Referer string diff --git a/client/client_test.go b/client/client_test.go index d20e1be563..ae80242f35 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -3,7 +3,7 @@ package client import ( "context" "crypto/tls" - "fmt" + "errors" "io" "net" "os" @@ -20,6 +20,8 @@ import ( ) func startTestServerWithPort(t *testing.T, beforeStarting func(app *fiber.App)) (*fiber.App, string) { + t.Helper() + app := fiber.New() if beforeStarting != nil { @@ -45,15 +47,15 @@ func Test_Client_Add_Hook(t *testing.T) { t.Run("add request hooks", func(t *testing.T) { t.Parallel() - client := NewClient().AddRequestHook(func(c *Client, r *Request) error { + client := NewClient().AddRequestHook(func(_ *Client, _ *Request) error { return nil }) require.Len(t, client.RequestHook(), 1) - client.AddRequestHook(func(c *Client, r *Request) error { + client.AddRequestHook(func(_ *Client, _ *Request) error { return nil - }, func(c *Client, r *Request) error { + }, func(_ *Client, _ *Request) error { return nil }) @@ -62,15 +64,15 @@ func Test_Client_Add_Hook(t *testing.T) { t.Run("add response hooks", func(t *testing.T) { t.Parallel() - client := NewClient().AddResponseHook(func(c *Client, resp *Response, r *Request) error { + client := NewClient().AddResponseHook(func(_ *Client, _ *Response, _ *Request) error { return nil }) require.Len(t, client.ResponseHook(), 1) - client.AddResponseHook(func(c *Client, resp *Response, r *Request) error { + client.AddResponseHook(func(_ *Client, _ *Response, _ *Request) error { return nil - }, func(c *Client, resp *Response, r *Request) error { + }, func(_ *Client, _ *Response, _ *Request) error { return nil }) @@ -96,12 +98,12 @@ func Test_Client_Marshal(t *testing.T) { t.Run("set json unmarshal", func(t *testing.T) { t.Parallel() client := NewClient(). - SetJSONUnmarshal(func(data []byte, v any) error { - return fmt.Errorf("empty json") + SetJSONUnmarshal(func(_ []byte, _ any) error { + return errors.New("empty json") }) err := client.JSONUnmarshal()(nil, nil) - require.Equal(t, fmt.Errorf("empty json"), err) + require.Equal(t, errors.New("empty json"), err) }) t.Run("set xml marshal", func(t *testing.T) { @@ -120,11 +122,11 @@ func Test_Client_Marshal(t *testing.T) { t.Parallel() client := NewClient(). SetXMLUnmarshal(func(_ []byte, _ any) error { - return fmt.Errorf("empty xml") + return errors.New("empty xml") }) err := client.XMLUnmarshal()(nil, nil) - require.Equal(t, fmt.Errorf("empty xml"), err) + require.Equal(t, errors.New("empty xml"), err) }) } @@ -208,7 +210,9 @@ func Test_Get(t *testing.T) { t.Parallel() app, addr := setupApp() - defer app.Shutdown() + defer func() { + require.NoError(t, app.Shutdown()) + }() resp, err := Get("http://" + addr) require.NoError(t, err) @@ -219,7 +223,9 @@ func Test_Get(t *testing.T) { t.Parallel() app, addr := setupApp() - defer app.Shutdown() + defer func() { + require.NoError(t, app.Shutdown()) + }() resp, err := NewClient().Get("http://" + addr) require.NoError(t, err) @@ -244,7 +250,9 @@ func Test_Head(t *testing.T) { t.Parallel() app, addr := setupApp() - defer app.Shutdown() + defer func() { + require.NoError(t, app.Shutdown()) + }() resp, err := Head("http://" + addr) require.NoError(t, err) @@ -255,7 +263,9 @@ func Test_Head(t *testing.T) { t.Parallel() app, addr := setupApp() - defer app.Shutdown() + defer func() { + require.NoError(t, app.Shutdown()) + }() resp, err := NewClient().Head("http://" + addr) require.NoError(t, err) @@ -281,7 +291,9 @@ func Test_Post(t *testing.T) { t.Parallel() app, addr := setupApp() - defer app.Shutdown() + defer func() { + require.NoError(t, app.Shutdown()) + }() for i := 0; i < 5; i++ { resp, err := Post("http://"+addr, Config{ @@ -300,7 +312,9 @@ func Test_Post(t *testing.T) { t.Parallel() app, addr := setupApp() - defer app.Shutdown() + defer func() { + require.NoError(t, app.Shutdown()) + }() for i := 0; i < 5; i++ { resp, err := NewClient().Post("http://"+addr, Config{ @@ -333,7 +347,9 @@ func Test_Put(t *testing.T) { t.Parallel() app, addr := setupApp() - defer app.Shutdown() + defer func() { + require.NoError(t, app.Shutdown()) + }() for i := 0; i < 5; i++ { resp, err := Put("http://"+addr, Config{ @@ -352,7 +368,9 @@ func Test_Put(t *testing.T) { t.Parallel() app, addr := setupApp() - defer app.Shutdown() + defer func() { + require.NoError(t, app.Shutdown()) + }() for i := 0; i < 5; i++ { resp, err := NewClient().Put("http://"+addr, Config{ @@ -386,7 +404,9 @@ func Test_Delete(t *testing.T) { t.Parallel() app, addr := setupApp() - defer app.Shutdown() + defer func() { + require.NoError(t, app.Shutdown()) + }() time.Sleep(1 * time.Second) @@ -407,7 +427,9 @@ func Test_Delete(t *testing.T) { t.Parallel() app, addr := setupApp() - defer app.Shutdown() + defer func() { + require.NoError(t, app.Shutdown()) + }() for i := 0; i < 5; i++ { resp, err := NewClient().Delete("http://"+addr, Config{ @@ -440,7 +462,9 @@ func Test_Options(t *testing.T) { t.Parallel() app, addr := setupApp() - defer app.Shutdown() + defer func() { + require.NoError(t, app.Shutdown()) + }() for i := 0; i < 5; i++ { resp, err := Options("http://" + addr) @@ -455,7 +479,9 @@ func Test_Options(t *testing.T) { t.Parallel() app, addr := setupApp() - defer app.Shutdown() + defer func() { + require.NoError(t, app.Shutdown()) + }() for i := 0; i < 5; i++ { resp, err := NewClient().Options("http://" + addr) @@ -484,7 +510,9 @@ func Test_Patch(t *testing.T) { t.Parallel() app, addr := setupApp() - defer app.Shutdown() + defer func() { + require.NoError(t, app.Shutdown()) + }() time.Sleep(1 * time.Second) @@ -505,7 +533,9 @@ func Test_Patch(t *testing.T) { t.Parallel() app, addr := setupApp() - defer app.Shutdown() + defer func() { + require.NoError(t, app.Shutdown()) + }() for i := 0; i < 5; i++ { resp, err := NewClient().Patch("http://"+addr, Config{ @@ -538,7 +568,9 @@ func Test_Client_UserAgent(t *testing.T) { t.Parallel() app, addr := setupApp() - defer app.Shutdown() + defer func() { + require.NoError(t, app.Shutdown()) + }() for i := 0; i < 5; i++ { resp, err := Get("http://" + addr) @@ -553,7 +585,9 @@ func Test_Client_UserAgent(t *testing.T) { t.Parallel() app, addr := setupApp() - defer app.Shutdown() + defer func() { + require.NoError(t, app.Shutdown()) + }() for i := 0; i < 5; i++ { c := NewClient(). @@ -1159,7 +1193,7 @@ func Test_Client_SetRootCertificateFromString(t *testing.T) { t.Parallel() file, err := os.Open("../.github/testdata/ssl.pem") - defer func() { _ = file.Close() }() + defer func() { require.NoError(t, file.Close()) }() require.NoError(t, err) pem, err := io.ReadAll(file) @@ -1345,7 +1379,7 @@ func Test_Client_SetProxyURL(t *testing.T) { go start() t.Cleanup(func() { - _ = app.Shutdown() + require.NoError(t, app.Shutdown()) }) time.Sleep(1 * time.Second) @@ -1409,8 +1443,11 @@ func Benchmark_Client_Request(b *testing.B) { b.ResetTimer() b.ReportAllocs() + var err error + var resp *Response for i := 0; i < b.N; i++ { - resp, _ := client.Get("http://example.com") + resp, err = client.Get("http://example.com") resp.Close() } + require.NoError(b, err) } diff --git a/client/cookiejar.go b/client/cookiejar.go index 6f2088c482..09a42b37ca 100644 --- a/client/cookiejar.go +++ b/client/cookiejar.go @@ -3,6 +3,7 @@ package client import ( "bytes" + "errors" "net" "sync" "time" @@ -19,7 +20,12 @@ var cookieJarPool = sync.Pool{ // AcquireCookieJar returns an empty CookieJar object from pool. func AcquireCookieJar() *CookieJar { - return cookieJarPool.Get().(*CookieJar) + jar, ok := cookieJarPool.Get().(*CookieJar) + if !ok { + panic(errors.New("failed to type-assert to *CookieJar")) + } + + return jar } // ReleaseCookieJar returns CookieJar to the pool. @@ -157,11 +163,6 @@ func (cj *CookieJar) SetKeyValue(host, key, value string) { // This function prevents extra allocations by making repeated cookies // not being duplicated. func (cj *CookieJar) SetKeyValueBytes(host string, key, value []byte) { - cj.setKeyValue(host, key, value) -} - -// setKeyValue sets a cookie by key and value for a specific host. -func (cj *CookieJar) setKeyValue(host string, key, value []byte) { c := fasthttp.AcquireCookie() c.SetKeyBytes(key) c.SetValueBytes(value) @@ -204,7 +205,7 @@ func (cj *CookieJar) parseCookiesFromResp(host, path []byte, resp *fasthttp.Resp c, isCreated = fasthttp.AcquireCookie(), true } - _ = c.ParseBytes(value) + _ = c.ParseBytes(value) //nolint:errcheck // ignore error if c.Expire().Equal(fasthttp.CookieExpireUnlimited) || c.Expire().After(now) { cookies = append(cookies, c) } else if isCreated { diff --git a/client/core.go b/client/core.go index 726f2a5ecb..d234837ffa 100644 --- a/client/core.go +++ b/client/core.go @@ -14,11 +14,7 @@ import ( "github.com/valyala/fasthttp" ) -var ( - httpBytes = []byte("http") - httpsBytes = []byte("https") - boundary = "--FiberFormBoundary" -) +var boundary = "--FiberFormBoundary" // RequestHook is a function that receives Agent and Request, // it can change the data in Request and Agent. @@ -36,7 +32,7 @@ type ResponseHook func(*Client, *Response, *Request) error type RetryConfig = retry.Config // addMissingPort will add the corresponding port number for host. -func addMissingPort(addr string, isTLS bool) string { +func addMissingPort(addr string, isTLS bool) string { //revive:disable-line:flag-parameter // Accepting a bool param named isTLS if fine here n := strings.Index(addr, ":") if n >= 0 { return addr @@ -53,7 +49,7 @@ func addMissingPort(addr string, isTLS bool) string { type core struct { client *Client req *Request - ctx context.Context + ctx context.Context //nolint:containedctx // It's needed to be stored in the core. } // getRetryConfig returns the retry configuration of the client. @@ -93,8 +89,6 @@ func (c *core) execFunc() (*Response, error) { var err error go func() { - //c.client.mu.Lock() - respv := fasthttp.AcquireResponse() if cfg != nil { err = retry.NewExponentialBackoff(*cfg).Retry(func() error { @@ -124,7 +118,6 @@ func (c *core) execFunc() (*Response, error) { respv.CopyTo(resp.RawResponse) errCh <- nil } - //c.client.mu.Unlock() }() select { @@ -246,7 +239,12 @@ var errChanPool = &sync.Pool{ // The returned error chan may be returned to the pool with releaseErrChan when no longer needed. // This allows reducing GC load. func acquireErrChan() chan error { - return errChanPool.Get().(chan error) + ch, ok := errChanPool.Get().(chan error) + if !ok { + panic(errors.New("failed to type-assert to chan error")) + } + + return ch } // releaseErrChan returns the object acquired via acquireErrChan to the pool. diff --git a/client/core_test.go b/client/core_test.go index f3ebb8f667..1b8ea42b9d 100644 --- a/client/core_test.go +++ b/client/core_test.go @@ -2,7 +2,7 @@ package client import ( "context" - "fmt" + "errors" "net" "testing" "time" @@ -66,7 +66,7 @@ func Test_Exec_Func(t *testing.T) { }) app.Get("/return-error", func(_ fiber.Ctx) error { - return fmt.Errorf("the request is error") + return errors.New("the request is error") }) app.Get("/hang-up", func(c fiber.Ctx) error { @@ -87,11 +87,10 @@ func Test_Exec_Func(t *testing.T) { core.client = client core.req = req - client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) + client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed req.RawRequest.SetRequestURI("http://example.com/normal") resp, err := core.execFunc() - fmt.Print(string(resp.Body())) require.NoError(t, err) require.Equal(t, 200, resp.RawResponse.StatusCode()) require.Equal(t, "example.com", string(resp.RawResponse.Body())) @@ -104,7 +103,7 @@ func Test_Exec_Func(t *testing.T) { core.client = client core.req = req - client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) + client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed req.RawRequest.SetRequestURI("http://example.com/return-error") resp, err := core.execFunc() @@ -124,7 +123,7 @@ func Test_Exec_Func(t *testing.T) { core.client = client core.req = req - client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) + client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed req.RawRequest.SetRequestURI("http://example.com/hang-up") _, err := core.execFunc() @@ -143,8 +142,8 @@ func Test_Execute(t *testing.T) { return c.SendString(c.Hostname()) }) - app.Get("/return-error", func(c fiber.Ctx) error { - return fmt.Errorf("the request is error") + app.Get("/return-error", func(_ fiber.Ctx) error { + return errors.New("the request is error") }) app.Get("/hang-up", func(c fiber.Ctx) error { @@ -163,8 +162,8 @@ func Test_Execute(t *testing.T) { require.Equal(t, "http://example.com", req.URL()) return nil }) - client.SetDial(func(addr string) (net.Conn, error) { - return ln.Dial() + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed }) req.SetURL("http://example.com") @@ -176,12 +175,12 @@ func Test_Execute(t *testing.T) { t.Run("add user response hooks", func(t *testing.T) { t.Parallel() core, client, req := newCore(), NewClient(), AcquireRequest() - client.AddResponseHook(func(c *Client, resp *Response, req *Request) error { + client.AddResponseHook(func(_ *Client, _ *Response, req *Request) error { require.Equal(t, "http://example.com", req.URL()) return nil }) - client.SetDial(func(addr string) (net.Conn, error) { - return ln.Dial() + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed }) req.SetURL("http://example.com") @@ -194,8 +193,8 @@ func Test_Execute(t *testing.T) { t.Parallel() core, client, req := newCore(), NewClient(), AcquireRequest() - client.SetDial(func(addr string) (net.Conn, error) { - return ln.Dial() + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed }) req.SetURL("http://example.com/hang-up") @@ -208,8 +207,8 @@ func Test_Execute(t *testing.T) { t.Parallel() core, client, req := newCore(), NewClient(), AcquireRequest() client.SetTimeout(500 * time.Millisecond) - client.SetDial(func(addr string) (net.Conn, error) { - return ln.Dial() + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed }) req.SetURL("http://example.com/hang-up") @@ -221,8 +220,8 @@ func Test_Execute(t *testing.T) { t.Parallel() core, client, req := newCore(), NewClient(), AcquireRequest() - client.SetDial(func(addr string) (net.Conn, error) { - return ln.Dial() + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed }) req.SetURL("http://example.com/hang-up"). SetTimeout(300 * time.Millisecond) @@ -236,8 +235,8 @@ func Test_Execute(t *testing.T) { core, client, req := newCore(), NewClient(), AcquireRequest() client.SetTimeout(30 * time.Millisecond) - client.SetDial(func(addr string) (net.Conn, error) { - return ln.Dial() + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed }) req.SetURL("http://example.com/hang-up"). SetTimeout(3000 * time.Millisecond) diff --git a/client/helper_test.go b/client/helper_test.go index 2b935a8fab..67380f3470 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -61,24 +61,23 @@ func (ts *testServer) stop() { func (ts *testServer) dial() func(addr string) (net.Conn, error) { ts.tb.Helper() - return func(addr string) (net.Conn, error) { - return ts.ln.Dial() + return func(_ string) (net.Conn, error) { + return ts.ln.Dial() //nolint:wrapcheck // not needed } } -func createHelperServer(t testing.TB) (*fiber.App, func(addr string) (net.Conn, error), func()) { - t.Helper() +func createHelperServer(tb testing.TB) (*fiber.App, func(addr string) (net.Conn, error), func()) { + tb.Helper() ln := fasthttputil.NewInmemoryListener() app := fiber.New() - return app, func(addr string) (net.Conn, error) { - return ln.Dial() + return app, func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed }, func() { - require.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) + require.NoError(tb, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) } - // TODO: add closer fn } func testRequest(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted string, count ...int) { @@ -132,7 +131,7 @@ func testRequestFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent * } } -func testClient(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Client), excepted string, count ...int) { +func testClient(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Client), excepted string, count ...int) { //nolint: unparam // maybe needed t.Helper() app, ln, start := createHelperServer(t) diff --git a/client/hooks.go b/client/hooks.go index b1c76a326d..3545b58448 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -1,6 +1,8 @@ package client import ( + "errors" + "fmt" "io" "math/rand" "mime/multipart" @@ -74,10 +76,10 @@ func parserRequestURL(c *Client, req *Request) error { // set path params req.path.VisitAll(func(key, val string) { - uri = strings.Replace(uri, ":"+key, val, -1) + uri = strings.ReplaceAll(uri, ":"+key, val) }) c.path.VisitAll(func(key, val string) { - uri = strings.Replace(uri, ":"+key, val, -1) + uri = strings.ReplaceAll(uri, ":"+key, val) }) // set uri to request and other related setting @@ -133,7 +135,7 @@ func parserRequestHeader(c *Client, req *Request) error { req.RawRequest.Header.SetContentType(multipartFormData) // set boundary if req.boundary == boundary { - req.boundary = req.boundary + randString(16) + req.boundary += randString(16) } req.RawRequest.Header.SetMultipartFormBoundary(req.boundary) default: @@ -193,7 +195,7 @@ func parserRequestBody(c *Client, req *Request) error { mw := multipart.NewWriter(req.RawRequest.BodyWriter()) err := mw.SetBoundary(req.boundary) if err != nil { - return err + return fmt.Errorf("set boundary error: %w", err) } defer func() { err := mw.Close() @@ -210,7 +212,7 @@ func parserRequestBody(c *Client, req *Request) error { err = mw.WriteField(utils.UnsafeString(key), utils.UnsafeString(value)) }) if err != nil { - return err + return fmt.Errorf("write formdata error: %w", err) } // add file @@ -235,34 +237,36 @@ func parserRequestBody(c *Client, req *Request) error { if v.reader == nil { v.reader, err = os.Open(v.path) if err != nil { - return err + return fmt.Errorf("open file error: %w", err) } } // write file w, err := mw.CreateFormFile(v.fieldName, v.name) if err != nil { - return err + return fmt.Errorf("create file error: %w", err) } for { n, err := v.reader.Read(b) - if err != nil && err != io.EOF { - return err + if err != nil && !errors.Is(err, io.EOF) { + return fmt.Errorf("read file error: %w", err) } - if err == io.EOF { + if errors.Is(err, io.EOF) { break } _, err = w.Write(b[:n]) if err != nil { - return err + return fmt.Errorf("write file error: %w", err) } } - // ignore err - _ = v.reader.Close() + err = v.reader.Close() + if err != nil { + return fmt.Errorf("close file error: %w", err) + } } case rawBody: if body, ok := req.body.([]byte); ok { @@ -270,6 +274,8 @@ func parserRequestBody(c *Client, req *Request) error { } else { return ErrBodyType } + case noBody: + return nil } return nil @@ -277,14 +283,22 @@ func parserRequestBody(c *Client, req *Request) error { // parserResponseHeader will parse the response header and store it in the response func parserResponseCookie(c *Client, resp *Response, req *Request) error { + var err error resp.RawResponse.Header.VisitAllCookie(func(key, value []byte) { cookie := fasthttp.AcquireCookie() - _ = cookie.ParseBytes(value) + err = cookie.ParseBytes(value) + if err != nil { + return + } cookie.SetKeyBytes(key) resp.cookie = append(resp.cookie, cookie) }) + if err != nil { + return err + } + // store cookies to jar if c.cookieJar != nil { c.cookieJar.parseCookiesFromResp(req.RawRequest.URI().Host(), req.RawRequest.URI().Path(), resp.RawResponse) diff --git a/client/hooks_test.go b/client/hooks_test.go index e5410bed94..455ebe3268 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -168,7 +168,8 @@ func Test_Parser_Request_URL(t *testing.T) { err := parserRequestURL(client, req) require.NoError(t, err) - values, _ := url.ParseQuery(string(req.RawRequest.URI().QueryString())) + values, err := url.ParseQuery(string(req.RawRequest.URI().QueryString())) + require.NoError(t, err) flag1, flag2, flag3 := false, false, false for _, v := range values["bar"] { @@ -524,49 +525,49 @@ type dummyLogger struct { buf *bytes.Buffer } -func (l *dummyLogger) Trace(v ...any) {} +func (*dummyLogger) Trace(_ ...any) {} -func (l *dummyLogger) Debug(v ...any) {} +func (*dummyLogger) Debug(_ ...any) {} -func (l *dummyLogger) Info(v ...any) {} +func (*dummyLogger) Info(_ ...any) {} -func (l *dummyLogger) Warn(v ...any) {} +func (*dummyLogger) Warn(_ ...any) {} -func (l *dummyLogger) Error(v ...any) {} +func (*dummyLogger) Error(_ ...any) {} -func (l *dummyLogger) Fatal(v ...any) {} +func (*dummyLogger) Fatal(_ ...any) {} -func (l *dummyLogger) Panic(v ...any) {} +func (*dummyLogger) Panic(_ ...any) {} -func (l *dummyLogger) Tracef(format string, v ...any) {} +func (*dummyLogger) Tracef(_ string, _ ...any) {} func (l *dummyLogger) Debugf(format string, v ...any) { - l.buf.WriteString(fmt.Sprintf(format, v...)) + _, _ = l.buf.WriteString(fmt.Sprintf(format, v...)) //nolint:errcheck // not needed } -func (l *dummyLogger) Infof(format string, v ...any) {} +func (*dummyLogger) Infof(_ string, _ ...any) {} -func (l *dummyLogger) Warnf(format string, v ...any) {} +func (*dummyLogger) Warnf(_ string, _ ...any) {} -func (l *dummyLogger) Errorf(format string, v ...any) {} +func (*dummyLogger) Errorf(_ string, _ ...any) {} -func (l *dummyLogger) Fatalf(format string, v ...any) {} +func (*dummyLogger) Fatalf(_ string, _ ...any) {} -func (l *dummyLogger) Panicf(format string, v ...any) {} +func (*dummyLogger) Panicf(_ string, _ ...any) {} -func (l *dummyLogger) Tracew(msg string, keysAndValues ...any) {} +func (*dummyLogger) Tracew(_ string, _ ...any) {} -func (l *dummyLogger) Debugw(msg string, keysAndValues ...any) {} +func (*dummyLogger) Debugw(_ string, _ ...any) {} -func (l *dummyLogger) Infow(msg string, keysAndValues ...any) {} +func (*dummyLogger) Infow(_ string, _ ...any) {} -func (l *dummyLogger) Warnw(msg string, keysAndValues ...any) {} +func (*dummyLogger) Warnw(_ string, _ ...any) {} -func (l *dummyLogger) Errorw(msg string, keysAndValues ...any) {} +func (*dummyLogger) Errorw(_ string, _ ...any) {} -func (l *dummyLogger) Fatalw(msg string, keysAndValues ...any) {} +func (*dummyLogger) Fatalw(_ string, _ ...any) {} -func (l *dummyLogger) Panicw(msg string, keysAndValues ...any) {} +func (*dummyLogger) Panicw(_ string, _ ...any) {} func Test_Client_Logger_Debug(t *testing.T) { t.Parallel() @@ -586,7 +587,7 @@ func Test_Client_Logger_Debug(t *testing.T) { }() defer func(app *fiber.App) { - _ = app.Shutdown() + require.NoError(t, app.Shutdown()) }(app) var buf bytes.Buffer @@ -595,12 +596,13 @@ func Test_Client_Logger_Debug(t *testing.T) { client := NewClient() client.Debug().SetLogger(logger) - url := <-addrChan - resp, err := client.Get("http://" + url) + addr := <-addrChan + resp, err := client.Get("http://" + addr) + require.NoError(t, err) defer resp.Close() require.NoError(t, err) - require.Contains(t, buf.String(), "Host: "+url) + require.Contains(t, buf.String(), "Host: "+addr) require.Contains(t, buf.String(), "Content-Length: 8") } @@ -622,7 +624,7 @@ func Test_Client_Logger_DisableDebug(t *testing.T) { }() defer func(app *fiber.App) { - _ = app.Shutdown() + require.NoError(t, app.Shutdown()) }(app) var buf bytes.Buffer @@ -631,8 +633,9 @@ func Test_Client_Logger_DisableDebug(t *testing.T) { client := NewClient() client.DisableDebug().SetLogger(logger) - url := <-addrChan - resp, err := client.Get("http://" + url) + addr := <-addrChan + resp, err := client.Get("http://" + addr) + require.NoError(t, err) defer resp.Close() require.NoError(t, err) diff --git a/client/request.go b/client/request.go index d082f00e28..1fca71d756 100644 --- a/client/request.go +++ b/client/request.go @@ -3,6 +3,7 @@ package client import ( "bytes" "context" + "errors" "io" "path/filepath" "reflect" @@ -18,7 +19,7 @@ import ( // WithStruct Implementing this interface allows data to // be stored from a struct via reflect. type WithStruct interface { - Add(name string, obj string) + Add(name, obj string) Del(name string) } @@ -42,7 +43,7 @@ type Request struct { userAgent string boundary string referer string - ctx context.Context + ctx context.Context //nolint:containedctx // It's needed to be stored in the request. header *Header params *QueryParam cookies *Cookie @@ -524,7 +525,7 @@ func (r *Request) Patch(url string) (*Response, error) { } // Custom Send custom request. -func (r *Request) Custom(url string, method string) (*Response, error) { +func (r *Request) Custom(url, method string) (*Response, error) { return r.SetURL(url).SetMethod(method).Send() } @@ -834,7 +835,10 @@ var requestPool = &sync.Pool{ // The returned request may be returned to the pool with ReleaseRequest when no longer needed. // This allows reducing GC load. func AcquireRequest() *Request { - req := requestPool.Get().(*Request) + req, ok := requestPool.Get().(*Request) + if !ok { + panic(errors.New("failed to type-assert to *Request")) + } return req } @@ -889,7 +893,10 @@ func SetFileReader(r io.ReadCloser) SetFileFunc { func AcquireFile(setter ...SetFileFunc) *File { fv := filePool.Get() if fv != nil { - f := fv.(*File) + f, ok := fv.(*File) + if !ok { + panic(errors.New("failed to type-assert to *File")) + } for _, v := range setter { v(f) } diff --git a/client/request_test.go b/client/request_test.go index ff0b5b81f6..bc1a3d5716 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -72,7 +72,9 @@ func Test_Request_Context(t *testing.T) { req.SetContext(ctx) ctx = req.Context() - require.Equal(t, "string", ctx.Value(key).(string)) + v, ok := ctx.Value(key).(string) + require.True(t, ok) + require.Equal(t, "string", v) } func Test_Request_Header(t *testing.T) { @@ -840,8 +842,10 @@ func Test_Request_Header_With_Server(t *testing.T) { handler := func(c fiber.Ctx) error { c.Request().Header.VisitAll(func(key, value []byte) { if k := string(key); k == "K1" || k == "K2" { - _, _ = c.Write(key) - _, _ = c.Write(value) + _, err := c.Write(key) + require.NoError(t, err) + _, err = c.Write(value) + require.NoError(t, err) } }) return nil @@ -936,13 +940,13 @@ func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) { basename := filepath.Base(filename) require.Equal(t, fh.Filename, basename) - b1, err := os.ReadFile(filename) + b1, err := os.ReadFile(filepath.Clean(filename)) require.NoError(t, err) b2 := make([]byte, fh.Size) f, err := fh.Open() require.NoError(t, err) - defer func() { _ = f.Close() }() + defer func() { require.NoError(t, f.Close()) }() _, err = f.Read(b2) require.NoError(t, err) require.Equal(t, b1, b2) @@ -1056,7 +1060,7 @@ func Test_Request_Body_With_Server(t *testing.T) { buf := make([]byte, fh1.Size) f, err := fh1.Open() require.NoError(t, err) - defer func() { _ = f.Close() }() + defer func() { require.NoError(t, f.Close()) }() _, err = f.Read(buf) require.NoError(t, err) require.Equal(t, "form file", string(buf)) @@ -1176,7 +1180,7 @@ func Test_Request_Error_Body_With_Server(t *testing.T) { SetBoundary("*"). AddFileWithReader("t.txt", io.NopCloser(strings.NewReader("world"))). Get("http://example.com") - require.Equal(t, "mime: invalid boundary character", err.Error()) + require.Equal(t, "set boundary error: mime: invalid boundary character", err.Error()) }) t.Run("open non exist file", func(t *testing.T) { @@ -1231,7 +1235,7 @@ func Test_Request_MaxRedirects(t *testing.T) { t.Run("success", func(t *testing.T) { t.Parallel() - client := NewClient().SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }) + client := NewClient().SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed resp, err := AcquireRequest(). SetClient(client). @@ -1250,7 +1254,7 @@ func Test_Request_MaxRedirects(t *testing.T) { t.Run("error", func(t *testing.T) { t.Parallel() - client := NewClient().SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }) + client := NewClient().SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed resp, err := AcquireRequest(). SetClient(client). @@ -1264,13 +1268,13 @@ func Test_Request_MaxRedirects(t *testing.T) { t.Run("MaxRedirects", func(t *testing.T) { t.Parallel() - client := NewClient().SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }) + client := NewClient().SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed req := AcquireRequest(). SetClient(client). SetMaxRedirects(3) - require.Equal(t, req.MaxRedirects(), 3) + require.Equal(t, 3, req.MaxRedirects()) }) } diff --git a/client/response.go b/client/response.go index 47b3da82bc..f6ecd6fcd8 100644 --- a/client/response.go +++ b/client/response.go @@ -3,6 +3,7 @@ package client import ( "bytes" "errors" + "fmt" "io" "io/fs" "os" @@ -90,35 +91,35 @@ func (r *Response) Save(v any) error { // create directory if _, err := os.Stat(dir); err != nil { if !errors.Is(err, fs.ErrNotExist) { - return err + return fmt.Errorf("failed to check directory: %w", err) } if err = os.MkdirAll(dir, 0o750); err != nil { - return err + return fmt.Errorf("failed to create directory: %w", err) } } // create file outFile, err := os.Create(file) if err != nil { - return err + return fmt.Errorf("failed to create file: %w", err) } - defer func() { _ = outFile.Close() }() + defer func() { _ = outFile.Close() }() //nolint:errcheck // not needed _, err = io.Copy(outFile, bytes.NewReader(r.Body())) if err != nil { - return err + return fmt.Errorf("failed to write response body to file: %w", err) } return nil case io.Writer: _, err := io.Copy(p, bytes.NewReader(r.Body())) if err != nil { - return err + return fmt.Errorf("failed to write response body to io.Writer: %w", err) } defer func() { if pc, ok := p.(io.WriteCloser); ok { - _ = pc.Close() + _ = pc.Close() //nolint:errcheck // not needed } }() From d6d0f17eec4b998d093400cf03a6d7b6a1a09a43 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sun, 25 Feb 2024 14:17:55 +0300 Subject: [PATCH 106/118] use lock instead of rlock --- client/core.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/core.go b/client/core.go index d234837ffa..d084f14979 100644 --- a/client/core.go +++ b/client/core.go @@ -137,8 +137,8 @@ func (c *core) execFunc() (*Response, error) { // preHooks Exec request hook func (c *core) preHooks() error { - c.client.mu.RLock() - defer c.client.mu.RUnlock() + c.client.mu.Lock() + defer c.client.mu.Unlock() for _, f := range c.client.userRequestHooks { err := f(c.client, c.req) From 5a9223a877b333e2ea57ff627d52eaf40aa853fb Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sun, 25 Feb 2024 14:57:12 +0300 Subject: [PATCH 107/118] fix cookiejar data-race --- client/cookiejar.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/client/cookiejar.go b/client/cookiejar.go index 09a42b37ca..482642771a 100644 --- a/client/cookiejar.go +++ b/client/cookiejar.go @@ -217,11 +217,6 @@ func (cj *CookieJar) parseCookiesFromResp(host, path []byte, resp *fasthttp.Resp // Release releases all cookie values. func (cj *CookieJar) Release() { - for _, v := range cj.hostCookies { - for _, c := range v { - fasthttp.ReleaseCookie(c) - } - } cj.hostCookies = nil } From d122be3e0efb5afd288c5b05c8c5cc69af24bbbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9?= Date: Wed, 28 Feb 2024 15:25:43 +0100 Subject: [PATCH 108/118] fix(client): race conditions --- client/client.go | 4 ++++ client/cookiejar.go | 13 ++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/client/client.go b/client/client.go index 0d3db2fa05..9a30426732 100644 --- a/client/client.go +++ b/client/client.go @@ -682,6 +682,10 @@ func init() { // NewClient creates and returns a new Client object. func NewClient() *Client { + // FOllOW-UP performance optimization + // trie to use a pool to reduce the cost of memory allocation + // for the fiber client and the fasthttp client + // if possible also for other structs -> request header, cookie, query param, path param... return &Client{ client: &fasthttp.Client{}, header: &Header{ diff --git a/client/cookiejar.go b/client/cookiejar.go index 482642771a..5d52c03344 100644 --- a/client/cookiejar.go +++ b/client/cookiejar.go @@ -155,7 +155,11 @@ func (cj *CookieJar) SetByHost(host []byte, cookies ...*fasthttp.Cookie) { // This function prevents extra allocations by making repeated cookies // not being duplicated. func (cj *CookieJar) SetKeyValue(host, key, value string) { - cj.SetKeyValueBytes(host, utils.UnsafeBytes(key), utils.UnsafeBytes(value)) + c := fasthttp.AcquireCookie() + c.SetKey(key) + c.SetValue(value) + + cj.SetByHost(utils.UnsafeBytes(host), c) } // SetKeyValueBytes sets a cookie by key and value for a specific host. @@ -217,6 +221,13 @@ func (cj *CookieJar) parseCookiesFromResp(host, path []byte, resp *fasthttp.Resp // Release releases all cookie values. func (cj *CookieJar) Release() { + // FOllOW-UP performance optimization + // currently a race condition is found because the reset method modifies a value which is not a copy but a reference -> solution should be to make a copy + //for _, v := range cj.hostCookies { + // for _, c := range v { + // fasthttp.ReleaseCookie(c) + // } + //} cj.hostCookies = nil } From a14086380f5f10a815480b9a4d9f6d34adc4842d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9?= Date: Wed, 28 Feb 2024 15:28:24 +0100 Subject: [PATCH 109/118] fix(client): race conditions --- client/cookiejar.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/client/cookiejar.go b/client/cookiejar.go index 5d52c03344..c66d5f3b7c 100644 --- a/client/cookiejar.go +++ b/client/cookiejar.go @@ -223,11 +223,11 @@ func (cj *CookieJar) parseCookiesFromResp(host, path []byte, resp *fasthttp.Resp func (cj *CookieJar) Release() { // FOllOW-UP performance optimization // currently a race condition is found because the reset method modifies a value which is not a copy but a reference -> solution should be to make a copy - //for _, v := range cj.hostCookies { - // for _, c := range v { + // for _, v := range cj.hostCookies { + // for _, c := range v { // fasthttp.ReleaseCookie(c) - // } - //} + // } + // } cj.hostCookies = nil } From 6b01572dc9aa0f2b7ffaf5ec45faf4e95c4779e7 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sat, 2 Mar 2024 16:26:52 +0300 Subject: [PATCH 110/118] apply some reviews --- client/client.go | 13 +++- client/client_test.go | 103 ++++++++++++++++++++++++++++++- client/core.go | 9 +-- client/hooks.go | 136 ++++++++++++++++++++++------------------- client/hooks_test.go | 11 +++- client/request.go | 8 +++ client/request_test.go | 15 +++++ 7 files changed, 222 insertions(+), 73 deletions(-) diff --git a/client/client.go b/client/client.go index 9a30426732..f632303914 100644 --- a/client/client.go +++ b/client/client.go @@ -188,7 +188,10 @@ func (c *Client) SetRootCertificate(path string) *Client { c.logger.Panicf("client: %v", err) } defer func() { - _ = file.Close() //nolint:errcheck // It is fine to ignore the error here + if err := file.Close(); err != nil { + c.logger.Panicf("client: failed to close file: %v", err) + } + }() pem, err := io.ReadAll(file) @@ -580,6 +583,14 @@ func (c *Client) Reset() { c.timeout = 0 c.userAgent = "" c.referer = "" + c.proxyURL = "" + c.retryConfig = nil + c.debug = false + + if c.cookieJar != nil { + c.cookieJar.Release() + c.cookieJar = nil + } c.path.Reset() c.cookies.Reset() diff --git a/client/client_test.go b/client/client_test.go index ae80242f35..c65693de7b 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -17,6 +17,7 @@ import ( "github.com/gofiber/fiber/v3/internal/tlstest" "github.com/gofiber/utils/v2" "github.com/stretchr/testify/require" + "github.com/valyala/bytebufferpool" ) func startTestServerWithPort(t *testing.T, beforeStarting func(app *fiber.App)) (*fiber.App, string) { @@ -30,12 +31,13 @@ func startTestServerWithPort(t *testing.T, beforeStarting func(app *fiber.App)) addrChan := make(chan string) go func() { - require.NoError(t, app.Listen(":0", fiber.ListenConfig{ + err := app.Listen(":0", fiber.ListenConfig{ DisableStartupMessage: true, ListenerAddrFunc: func(addr net.Addr) { addrChan <- addr.String() }, - })) + }) + require.NoError(t, err) }() addr := <-addrChan @@ -47,19 +49,28 @@ func Test_Client_Add_Hook(t *testing.T) { t.Run("add request hooks", func(t *testing.T) { t.Parallel() + + buf := bytebufferpool.Get() + defer bytebufferpool.Put(buf) + client := NewClient().AddRequestHook(func(_ *Client, _ *Request) error { + buf.WriteString("hook1") return nil }) require.Len(t, client.RequestHook(), 1) client.AddRequestHook(func(_ *Client, _ *Request) error { + buf.WriteString("hook2") return nil }, func(_ *Client, _ *Request) error { + buf.WriteString("hook3") return nil }) require.Len(t, client.RequestHook(), 3) + + client.builtinRequestHooks[0](client, &Request{}) }) t.Run("add response hooks", func(t *testing.T) { @@ -80,6 +91,34 @@ func Test_Client_Add_Hook(t *testing.T) { }) } +func Test_Client_Add_Hook_CheckOrder(t *testing.T) { + t.Parallel() + + buf := bytebufferpool.Get() + defer bytebufferpool.Put(buf) + + client := NewClient(). + AddRequestHook(func(_ *Client, _ *Request) error { + buf.WriteString("hook1") + return nil + }). + AddRequestHook(func(_ *Client, _ *Request) error { + buf.WriteString("hook2") + return nil + }). + AddRequestHook(func(_ *Client, _ *Request) error { + buf.WriteString("hook3") + return nil + }) + + for _, hook := range client.RequestHook() { + require.NoError(t, hook(client, &Request{})) + } + + require.Equal(t, "hook1hook2hook3", buf.String()) + +} + func Test_Client_Marshal(t *testing.T) { t.Parallel() @@ -95,6 +134,18 @@ func Test_Client_Marshal(t *testing.T) { require.Equal(t, []byte("hello"), val) }) + t.Run("set json marshal error", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetJSONMarshal(func(_ any) ([]byte, error) { + return nil, errors.New("empty json") + }) + + val, err := client.JSONMarshal()(nil) + require.Nil(t, val) + require.Equal(t, errors.New("empty json"), err) + }) + t.Run("set json unmarshal", func(t *testing.T) { t.Parallel() client := NewClient(). @@ -106,6 +157,17 @@ func Test_Client_Marshal(t *testing.T) { require.Equal(t, errors.New("empty json"), err) }) + t.Run("set json unmarshal error", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetJSONUnmarshal(func(_ []byte, _ any) error { + return errors.New("empty json") + }) + + err := client.JSONUnmarshal()(nil, nil) + require.Equal(t, errors.New("empty json"), err) + }) + t.Run("set xml marshal", func(t *testing.T) { t.Parallel() client := NewClient(). @@ -118,6 +180,18 @@ func Test_Client_Marshal(t *testing.T) { require.Equal(t, []byte("hello"), val) }) + t.Run("set xml marshal error", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetXMLMarshal(func(_ any) ([]byte, error) { + return nil, errors.New("empty xml") + }) + + val, err := client.XMLMarshal()(nil) + require.Nil(t, val) + require.Equal(t, errors.New("empty xml"), err) + }) + t.Run("set xml unmarshal", func(t *testing.T) { t.Parallel() client := NewClient(). @@ -128,6 +202,17 @@ func Test_Client_Marshal(t *testing.T) { err := client.XMLUnmarshal()(nil, nil) require.Equal(t, errors.New("empty xml"), err) }) + + t.Run("set xml unmarshal error", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetXMLUnmarshal(func(_ []byte, _ any) error { + return errors.New("empty xml") + }) + + err := client.XMLUnmarshal()(nil, nil) + require.Equal(t, errors.New("empty xml"), err) + }) } func Test_Client_SetBaseURL(t *testing.T) { @@ -151,7 +236,7 @@ func Test_Client_Invalid_URL(t *testing.T) { _, err := NewClient().SetDial(dial). R(). - Get("http://example.com\r\n\r\nGET /\r\n\r\n") + Get("http//example") require.ErrorIs(t, err, ErrURLFormat) } @@ -663,6 +748,18 @@ func Test_Client_Header(t *testing.T) { require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) + + t.Run("set header case insensitive", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.SetHeader("foo", "bar"). + AddHeader("FOO", "fiber") + + res := req.Header("foo") + require.Len(t, res, 2) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + }) } func Test_Client_Header_With_Server(t *testing.T) { diff --git a/client/core.go b/client/core.go index d084f14979..129caf6fb9 100644 --- a/client/core.go +++ b/client/core.go @@ -90,6 +90,11 @@ func (c *core) execFunc() (*Response, error) { var err error go func() { respv := fasthttp.AcquireResponse() + defer func() { + fasthttp.ReleaseRequest(reqv) + fasthttp.ReleaseResponse(respv) + }() + if cfg != nil { err = retry.NewExponentialBackoff(*cfg).Retry(func() error { if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { @@ -105,10 +110,6 @@ func (c *core) execFunc() (*Response, error) { err = c.client.client.Do(reqv, respv) } } - defer func() { - fasthttp.ReleaseRequest(reqv) - fasthttp.ReleaseResponse(respv) - }() if atomic.CompareAndSwapInt32(&done, 0, 1) { if err != nil { diff --git a/client/hooks.go b/client/hooks.go index 3545b58448..0ecc970d53 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -192,90 +192,98 @@ func parserRequestBody(c *Client, req *Request) error { case formBody: req.RawRequest.SetBody(req.formData.QueryString()) case filesBody: - mw := multipart.NewWriter(req.RawRequest.BodyWriter()) - err := mw.SetBoundary(req.boundary) - if err != nil { - return fmt.Errorf("set boundary error: %w", err) + return parserRequestBodyFile(req) + case rawBody: + if body, ok := req.body.([]byte); ok { + req.RawRequest.SetBody(body) + } else { + return ErrBodyType } - defer func() { - err := mw.Close() - if err != nil { - return - } - }() + case noBody: + return nil + } - // add formdata - req.formData.VisitAll(func(key, value []byte) { - if err != nil { - return - } - err = mw.WriteField(utils.UnsafeString(key), utils.UnsafeString(value)) - }) + return nil +} + +// parserRequestBodyFile parses request body if body type is file +// this is an addition of parserRequestBody. +func parserRequestBodyFile(req *Request) error { + mw := multipart.NewWriter(req.RawRequest.BodyWriter()) + err := mw.SetBoundary(req.boundary) + if err != nil { + return fmt.Errorf("set boundary error: %w", err) + } + defer func() { + err := mw.Close() if err != nil { - return fmt.Errorf("write formdata error: %w", err) + return } + }() - // add file - b := make([]byte, 512) - for i, v := range req.files { - if v.name == "" && v.path == "" { - return ErrFileNoName - } + // add formdata + req.formData.VisitAll(func(key, value []byte) { + if err != nil { + return + } + err = mw.WriteField(utils.UnsafeString(key), utils.UnsafeString(value)) + }) + if err != nil { + return fmt.Errorf("write formdata error: %w", err) + } - // if name is not exist, set name - if v.name == "" && v.path != "" { - v.path = filepath.Clean(v.path) - v.name = filepath.Base(v.path) - } + // add file + b := make([]byte, 512) + for i, v := range req.files { + if v.name == "" && v.path == "" { + return ErrFileNoName + } - // if field name is not exist, set it - if v.fieldName == "" { - v.fieldName = "file" + strconv.Itoa(i+1) - } + // if name is not exist, set name + if v.name == "" && v.path != "" { + v.path = filepath.Clean(v.path) + v.name = filepath.Base(v.path) + } - // check the reader - if v.reader == nil { - v.reader, err = os.Open(v.path) - if err != nil { - return fmt.Errorf("open file error: %w", err) - } - } + // if field name is not exist, set it + if v.fieldName == "" { + v.fieldName = "file" + strconv.Itoa(i+1) + } - // write file - w, err := mw.CreateFormFile(v.fieldName, v.name) + // check the reader + if v.reader == nil { + v.reader, err = os.Open(v.path) if err != nil { - return fmt.Errorf("create file error: %w", err) + return fmt.Errorf("open file error: %w", err) } + } - for { - n, err := v.reader.Read(b) - if err != nil && !errors.Is(err, io.EOF) { - return fmt.Errorf("read file error: %w", err) - } + // write file + w, err := mw.CreateFormFile(v.fieldName, v.name) + if err != nil { + return fmt.Errorf("create file error: %w", err) + } - if errors.Is(err, io.EOF) { - break - } + for { + n, err := v.reader.Read(b) + if err != nil && !errors.Is(err, io.EOF) { + return fmt.Errorf("read file error: %w", err) + } - _, err = w.Write(b[:n]) - if err != nil { - return fmt.Errorf("write file error: %w", err) - } + if errors.Is(err, io.EOF) { + break } - err = v.reader.Close() + _, err = w.Write(b[:n]) if err != nil { - return fmt.Errorf("close file error: %w", err) + return fmt.Errorf("write file error: %w", err) } } - case rawBody: - if body, ok := req.body.([]byte); ok { - req.RawRequest.SetBody(body) - } else { - return ErrBodyType + + err = v.reader.Close() + if err != nil { + return fmt.Errorf("close file error: %w", err) } - case noBody: - return nil } return nil diff --git a/client/hooks_test.go b/client/hooks_test.go index 455ebe3268..a555bba833 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -24,12 +24,21 @@ func Test_Rand_String(t *testing.T) { name: "test generate", args: 16, }, + { + name: "test generate smaller string", + args: 8, + }, + { + name: "test generate larger string", + args: 32, + }, } for _, tt := range tests { + tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() got := randString(tt.args) - require.Len(t, got, 16) + require.Len(t, got, tt.args) }) } } diff --git a/client/request.go b/client/request.go index 1fca71d756..0bf2fb321c 100644 --- a/client/request.go +++ b/client/request.go @@ -36,6 +36,8 @@ const ( rawBody ) +var ErrClientNil = errors.New("client can not be nil") + // Request is a struct which contains the request data. type Request struct { url string @@ -92,6 +94,10 @@ func (r *Request) Client() *Client { // SetClient method sets client in request instance. func (r *Request) SetClient(c *Client) *Request { + if c == nil { + panic(ErrClientNil) + } + r.client = c return r } @@ -342,6 +348,7 @@ func (r *Request) SetRawBody(v []byte) *Request { } // resetBody will clear body object and set bodyType +// if body type is formBody and filesBody, the new body type will be ignored. func (r *Request) resetBody(t bodyType) { r.body = nil @@ -921,6 +928,7 @@ func ReleaseFile(f *File) { // `p` is a structure that implements the WithStruct interface, // The field name can be specified by `tagName`. // `v` is a struct include some data. +// Note: This method only supports simple types and nested structs are not currently supported. func SetValWithStruct(p WithStruct, tagName string, v any) { valueOfV := reflect.ValueOf(v) typeOfV := reflect.TypeOf(v) diff --git a/client/request_test.go b/client/request_test.go index bc1a3d5716..07e5254e15 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -35,6 +35,21 @@ func Test_Request_Method(t *testing.T) { req.SetMethod("DELETE") require.Equal(t, "DELETE", req.Method()) + + req.SetMethod("PATCH") + require.Equal(t, "PATCH", req.Method()) + + req.SetMethod("OPTIONS") + require.Equal(t, "OPTIONS", req.Method()) + + req.SetMethod("HEAD") + require.Equal(t, "HEAD", req.Method()) + + req.SetMethod("TRACE") + require.Equal(t, "TRACE", req.Method()) + + req.SetMethod("CUSTOM") + require.Equal(t, "CUSTOM", req.Method()) } func Test_Request_URL(t *testing.T) { From 907b8a7f400cb8f71af2f370ba4bb72c68d880d4 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sat, 2 Mar 2024 16:33:59 +0300 Subject: [PATCH 111/118] change client property name --- client/client.go | 16 ++++++++-------- client/core.go | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/client/client.go b/client/client.go index f632303914..f478d0e57e 100644 --- a/client/client.go +++ b/client/client.go @@ -35,7 +35,7 @@ var ( type Client struct { mu sync.RWMutex - client *fasthttp.Client + fasthttp *fasthttp.Client baseURL string userAgent string @@ -158,18 +158,18 @@ func (c *Client) SetXMLUnmarshal(f utils.XMLUnmarshal) *Client { // TLSConfig returns tlsConfig in client. // If client don't have tlsConfig, this function will init it. func (c *Client) TLSConfig() *tls.Config { - if c.client.TLSConfig == nil { - c.client.TLSConfig = &tls.Config{ + if c.fasthttp.TLSConfig == nil { + c.fasthttp.TLSConfig = &tls.Config{ MinVersion: tls.VersionTLS12, } } - return c.client.TLSConfig + return c.fasthttp.TLSConfig } // SetTLSConfig sets tlsConfig in client. func (c *Client) SetTLSConfig(config *tls.Config) *Client { - c.client.TLSConfig = config + c.fasthttp.TLSConfig = config return c } @@ -558,7 +558,7 @@ func (c *Client) SetDial(dial fasthttp.DialFunc) *Client { c.mu.Lock() defer c.mu.Unlock() - c.client.Dial = dial + c.fasthttp.Dial = dial return c } @@ -578,7 +578,7 @@ func (c *Client) Logger() log.CommonLogger { // Reset clear Client object func (c *Client) Reset() { - c.client = &fasthttp.Client{} + c.fasthttp = &fasthttp.Client{} c.baseURL = "" c.timeout = 0 c.userAgent = "" @@ -698,7 +698,7 @@ func NewClient() *Client { // for the fiber client and the fasthttp client // if possible also for other structs -> request header, cookie, query param, path param... return &Client{ - client: &fasthttp.Client{}, + fasthttp: &fasthttp.Client{}, header: &Header{ RequestHeader: &fasthttp.RequestHeader{}, }, diff --git a/client/core.go b/client/core.go index 129caf6fb9..315d12d474 100644 --- a/client/core.go +++ b/client/core.go @@ -98,16 +98,16 @@ func (c *core) execFunc() (*Response, error) { if cfg != nil { err = retry.NewExponentialBackoff(*cfg).Retry(func() error { if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { - return c.client.client.DoRedirects(reqv, respv, c.req.maxRedirects) + return c.client.fasthttp.DoRedirects(reqv, respv, c.req.maxRedirects) } - return c.client.client.Do(reqv, respv) + return c.client.fasthttp.Do(reqv, respv) }) } else { if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { - err = c.client.client.DoRedirects(reqv, respv, c.req.maxRedirects) + err = c.client.fasthttp.DoRedirects(reqv, respv, c.req.maxRedirects) } else { - err = c.client.client.Do(reqv, respv) + err = c.client.fasthttp.Do(reqv, respv) } } From cb33aaefd6abac3e24067b119b048099a202f1e6 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sat, 2 Mar 2024 16:59:00 +0300 Subject: [PATCH 112/118] apply review --- client/client_test.go | 32 ++++++++++++++++++++++++++++++++ client/response_test.go | 7 +++---- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index c65693de7b..9a9614eb75 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1239,6 +1239,38 @@ func Test_Client_TLS(t *testing.T) { require.Equal(t, "tls", resp.String()) } +func Test_Client_TLS_Error(t *testing.T) { + t.Parallel() + + serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() + clientTLSConf.MaxVersion = tls.VersionTLS12 + serverTLSConf.MinVersion = tls.VersionTLS13 + require.NoError(t, err) + + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + require.NoError(t, err) + + ln = tls.NewListener(ln, serverTLSConf) + + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("tls") + }) + + go func() { + require.NoError(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + + client := NewClient() + resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) + + require.Error(t, err) + require.Equal(t, clientTLSConf, client.TLSConfig()) + require.Nil(t, resp) +} + func Test_Client_TLS_Empty_TLSConfig(t *testing.T) { t.Parallel() diff --git a/client/response_test.go b/client/response_test.go index c474c40a20..622e835714 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -358,11 +358,10 @@ func Test_Response_Save(t *testing.T) { err = resp.Save("./test/tmp.json") require.NoError(t, err) defer func() { - if _, err := os.Stat("./test/tmp.json"); err != nil { - return - } + _, err := os.Stat("./test/tmp.json") + require.NoError(t, err) - err := os.RemoveAll("./test") + err = os.RemoveAll("./test") require.NoError(t, err) }() From 9bf3d34e4dfa44dc2900d73c519a22af44794232 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sat, 2 Mar 2024 17:10:57 +0300 Subject: [PATCH 113/118] add parallel benchmark for simple request --- client/client_test.go | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/client/client_test.go b/client/client_test.go index 9a9614eb75..125d84c61e 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1580,3 +1580,27 @@ func Benchmark_Client_Request(b *testing.B) { } require.NoError(b, err) } + +func Benchmark_Client_Request_Parallel(b *testing.B) { + app, dial, start := createHelperServer(b) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("hello world") + }) + + go start() + + client := NewClient().SetDial(dial) + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + var err error + var resp *Response + for pb.Next() { + resp, err = client.Get("http://example.com") + resp.Close() + } + require.NoError(b, err) + }) +} From f56dfd24c8073e814bd52339553e1f49815fa2f7 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sat, 2 Mar 2024 17:38:29 +0300 Subject: [PATCH 114/118] apply review --- client/client_test.go | 20 ++++++++++++++++++++ listen_test.go | 11 +++++++---- redirect_test.go | 8 ++++---- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index 125d84c61e..b0d0729c77 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -973,6 +973,26 @@ func Test_Client_CookieJar_Response(t *testing.T) { } } }) + + t.Run("different domain", func(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { + return c.SendString(c.Cookies("k1")) + } + + jar := AcquireCookieJar() + defer ReleaseCookieJar(jar) + + jar.SetKeyValue("example.com", "k1", "v1") + + wrapAgent := func(c *Client) { + c.SetCookieJar(jar) + } + testClient(t, handler, wrapAgent, "v1") + + require.Len(t, jar.getCookiesByHost("example.com"), 1) + require.Len(t, jar.getCookiesByHost("example"), 0) + }) } func Test_Client_Referer(t *testing.T) { diff --git a/listen_test.go b/listen_test.go index 7d9b8eb3bc..a5d419ac86 100644 --- a/listen_test.go +++ b/listen_test.go @@ -69,10 +69,10 @@ func Test_Listen_Graceful_Shutdown(t *testing.T) { Time time.Duration ExpectedBody string ExpectedStatusCode int - ExceptedErr error + ExpectedErr error }{ - {Time: 100 * time.Millisecond, ExpectedBody: "example.com", ExpectedStatusCode: StatusOK, ExceptedErr: nil}, - {Time: 500 * time.Millisecond, ExpectedBody: "", ExpectedStatusCode: StatusOK, ExceptedErr: errors.New("InmemoryListener is already closed: use of closed network connection")}, + {Time: 100 * time.Millisecond, ExpectedBody: "example.com", ExpectedStatusCode: StatusOK, ExpectedErr: nil}, + {Time: 500 * time.Millisecond, ExpectedBody: "", ExpectedStatusCode: StatusOK, ExpectedErr: errors.New("InmemoryListener is already closed: use of closed network connection")}, } for _, tc := range testCases { @@ -87,9 +87,12 @@ func Test_Listen_Graceful_Shutdown(t *testing.T) { resp := fasthttp.AcquireResponse() err := client.Do(req, resp) - require.Equal(t, tc.ExceptedErr, err) + require.Equal(t, tc.ExpectedErr, err) require.Equal(t, tc.ExpectedStatusCode, resp.StatusCode()) require.Equal(t, tc.ExpectedBody, string(resp.Body())) + + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(resp) } mu.Lock() diff --git a/redirect_test.go b/redirect_test.go index b45aecc8b7..a83a37dd64 100644 --- a/redirect_test.go +++ b/redirect_test.go @@ -291,28 +291,28 @@ func Test_Redirect_Request(t *testing.T) { CookieValue string ExpectedBody string ExpectedStatusCode int - ExceptedErr error + ExpectedErr error }{ { URL: "/", CookieValue: "key:value,key2:value2,co\\:m\\,ma:Fi\\:ber\\, v3", ExpectedBody: `{"inputs":{},"messages":{"co:m,ma":"Fi:ber, v3","key":"value","key2":"value2"}}`, ExpectedStatusCode: StatusOK, - ExceptedErr: nil, + ExpectedErr: nil, }, { URL: "/with-inputs?name=john&surname=doe", CookieValue: "key:value,key2:value2,key:value,key2:value2,old_input_data_name:john,old_input_data_surname:doe", ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{"key":"value","key2":"value2"}}`, ExpectedStatusCode: StatusOK, - ExceptedErr: nil, + ExpectedErr: nil, }, { URL: "/just-inputs?name=john&surname=doe", CookieValue: "old_input_data_name:john,old_input_data_surname:doe", ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{}}`, ExpectedStatusCode: StatusOK, - ExceptedErr: nil, + ExpectedErr: nil, }, } From 96d606820b89ceddbe59054288e91e3715da3881 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sat, 2 Mar 2024 18:06:53 +0300 Subject: [PATCH 115/118] apply review --- client/client_test.go | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index b0d0729c77..778aae850d 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -30,6 +30,7 @@ func startTestServerWithPort(t *testing.T, beforeStarting func(app *fiber.App)) } addrChan := make(chan string) + errChan := make(chan error, 1) go func() { err := app.Listen(":0", fiber.ListenConfig{ DisableStartupMessage: true, @@ -37,11 +38,19 @@ func startTestServerWithPort(t *testing.T, beforeStarting func(app *fiber.App)) addrChan <- addr.String() }, }) - require.NoError(t, err) + if err != nil { + errChan <- err + } }() - addr := <-addrChan - return app, addr + select { + case addr := <-addrChan: + return app, addr + case err := <-errChan: + t.Fatalf("Failed to start test server: %v", err) + } + + return nil, "" } func Test_Client_Add_Hook(t *testing.T) { @@ -136,14 +145,16 @@ func Test_Client_Marshal(t *testing.T) { t.Run("set json marshal error", func(t *testing.T) { t.Parallel() + + emptyErr := errors.New("empty json") client := NewClient(). SetJSONMarshal(func(_ any) ([]byte, error) { - return nil, errors.New("empty json") + return nil, emptyErr }) val, err := client.JSONMarshal()(nil) require.Nil(t, val) - require.Equal(t, errors.New("empty json"), err) + require.ErrorIs(t, err, emptyErr) }) t.Run("set json unmarshal", func(t *testing.T) { @@ -341,6 +352,7 @@ func Test_Head(t *testing.T) { resp, err := Head("http://" + addr) require.NoError(t, err) + require.Equal(t, "7", resp.Header(fiber.HeaderContentLength)) require.Equal(t, "", utils.UnsafeString(resp.RawResponse.Body())) }) @@ -354,6 +366,7 @@ func Test_Head(t *testing.T) { resp, err := NewClient().Head("http://" + addr) require.NoError(t, err) + require.Equal(t, "7", resp.Header(fiber.HeaderContentLength)) require.Equal(t, "", utils.UnsafeString(resp.RawResponse.Body())) }) } @@ -536,6 +549,7 @@ func Test_Options(t *testing.T) { setupApp := func() (*fiber.App, string) { app, addr := startTestServerWithPort(t, func(app *fiber.App) { app.Options("/", func(c fiber.Ctx) error { + c.Set(fiber.HeaderAllow, "GET, POST, PUT, DELETE, PATCH") return c.Status(fiber.StatusNoContent).SendString("") }) }) @@ -555,6 +569,7 @@ func Test_Options(t *testing.T) { resp, err := Options("http://" + addr) require.NoError(t, err) + require.Equal(t, "GET, POST, PUT, DELETE, PATCH", resp.Header(fiber.HeaderAllow)) require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) require.Equal(t, "", resp.String()) } @@ -572,6 +587,7 @@ func Test_Options(t *testing.T) { resp, err := NewClient().Options("http://" + addr) require.NoError(t, err) + require.Equal(t, "GET, POST, PUT, DELETE, PATCH", resp.Header(fiber.HeaderAllow)) require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) require.Equal(t, "", resp.String()) } @@ -1213,8 +1229,8 @@ func Test_Client_PathParam(t *testing.T) { func Test_Client_PathParam_With_Server(t *testing.T) { app, dial, start := createHelperServer(t) - app.Get("/test", func(c fiber.Ctx) error { - return c.SendString("ok") + app.Get("/:test", func(c fiber.Ctx) error { + return c.SendString(c.Params("test")) }) go start() @@ -1225,7 +1241,7 @@ func Test_Client_PathParam_With_Server(t *testing.T) { require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) - require.Equal(t, "ok", resp.String()) + require.Equal(t, "test", resp.String()) } func Test_Client_TLS(t *testing.T) { From 0b6a4e84265c6e675ed34662346f6e3616830f1b Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sat, 2 Mar 2024 18:24:23 +0300 Subject: [PATCH 116/118] fix log tests --- log/default_test.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/log/default_test.go b/log/default_test.go index 8379df1a66..0f57d2009f 100644 --- a/log/default_test.go +++ b/log/default_test.go @@ -64,7 +64,8 @@ func Test_DefaultLogger(t *testing.T) { "[Debug] received work order\n"+ "[Info] starting work\n"+ "[Warn] work may fail\n"+ - "[Error] work failed\n", string(w.b)) + "[Error] work failed\n"+ + "[Panic] work panic\n", string(w.b)) } func Test_DefaultFormatLogger(t *testing.T) { @@ -87,7 +88,8 @@ func Test_DefaultFormatLogger(t *testing.T) { "[Debug] received work order\n"+ "[Info] starting work\n"+ "[Warn] work may fail\n"+ - "[Error] work failed\n", string(w.b)) + "[Error] work failed\n"+ + "[Panic] work panic\n", string(w.b)) } func Test_CtxLogger(t *testing.T) { @@ -112,7 +114,8 @@ func Test_CtxLogger(t *testing.T) { "[Debug] received work order\n"+ "[Info] starting work\n"+ "[Warn] work may fail\n"+ - "[Error] work failed 50\n", string(w.b)) + "[Error] work failed 50\n"+ + "[Panic] work panic\n", string(w.b)) } func Test_LogfKeyAndValues(t *testing.T) { From 557ddcaf336c7e0d58d3862be71f10e1982d8b78 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Sat, 2 Mar 2024 18:26:48 +0300 Subject: [PATCH 117/118] fix linter --- client/client.go | 1 - client/client_test.go | 5 +---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/client/client.go b/client/client.go index f478d0e57e..9ceb238c24 100644 --- a/client/client.go +++ b/client/client.go @@ -191,7 +191,6 @@ func (c *Client) SetRootCertificate(path string) *Client { if err := file.Close(); err != nil { c.logger.Panicf("client: failed to close file: %v", err) } - }() pem, err := io.ReadAll(file) diff --git a/client/client_test.go b/client/client_test.go index 778aae850d..6985822b84 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -78,8 +78,6 @@ func Test_Client_Add_Hook(t *testing.T) { }) require.Len(t, client.RequestHook(), 3) - - client.builtinRequestHooks[0](client, &Request{}) }) t.Run("add response hooks", func(t *testing.T) { @@ -125,7 +123,6 @@ func Test_Client_Add_Hook_CheckOrder(t *testing.T) { } require.Equal(t, "hook1hook2hook3", buf.String()) - } func Test_Client_Marshal(t *testing.T) { @@ -1007,7 +1004,7 @@ func Test_Client_CookieJar_Response(t *testing.T) { testClient(t, handler, wrapAgent, "v1") require.Len(t, jar.getCookiesByHost("example.com"), 1) - require.Len(t, jar.getCookiesByHost("example"), 0) + require.Empty(t, jar.getCookiesByHost("example")) }) } From 2627b521d0a4ed446c6a5cef34c39bc9f067a36c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9?= Date: Mon, 4 Mar 2024 08:12:56 +0100 Subject: [PATCH 118/118] fix(client): return error in SetProxyURL instead of panic --- client/client.go | 11 +++++------ client/client_test.go | 19 +++++++++++-------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/client/client.go b/client/client.go index 9ceb238c24..22cbd19727 100644 --- a/client/client.go +++ b/client/client.go @@ -7,6 +7,7 @@ import ( "encoding/json" "encoding/xml" "errors" + "fmt" "io" urlpkg "net/url" "os" @@ -226,21 +227,19 @@ func (c *Client) SetRootCertificateFromString(pem string) *Client { } // SetProxyURL sets proxy url in client. It will apply via core to hostclient. -func (c *Client) SetProxyURL(proxyURL string) *Client { +func (c *Client) SetProxyURL(proxyURL string) error { pURL, err := urlpkg.Parse(proxyURL) if err != nil { - c.logger.Panicf("client: %v", err) - return c + return fmt.Errorf("client: %w", err) } if pURL.Scheme != "http" && pURL.Scheme != "https" { - c.logger.Panicf("client: %v", ErrInvalidProxyURL) - return c + return fmt.Errorf("client: %w", ErrInvalidProxyURL) } c.proxyURL = pURL.String() - return c + return nil } // RetryConfig returns retry config in client. diff --git a/client/client_test.go b/client/client_test.go index 6985822b84..4fd2e484a6 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1549,8 +1549,11 @@ func Test_Client_SetProxyURL(t *testing.T) { t.Run("success", func(t *testing.T) { t.Parallel() client := NewClient().SetDial(dial) - client.SetProxyURL("http://test.com") - _, err := client.Get("http://localhost:3000") + err := client.SetProxyURL("http://test.com") + + require.NoError(t, err) + + _, err = client.Get("http://localhost:3000") require.NoError(t, err) }) @@ -1559,18 +1562,18 @@ func Test_Client_SetProxyURL(t *testing.T) { t.Parallel() client := NewClient() - require.Panics(t, func() { - client.SetProxyURL(":this is not a url") - }) + err := client.SetProxyURL(":this is not a url") + + require.Error(t, err) }) t.Run("error", func(t *testing.T) { t.Parallel() client := NewClient() - require.Panics(t, func() { - client.SetProxyURL("htgdftp://test.com") - }) + err := client.SetProxyURL("htgdftp://test.com") + + require.Error(t, err) }) }