From 41ae1a54fe522abad47ccbc49b7b537cb4d975d6 Mon Sep 17 00:00:00 2001 From: Muhammed Efe Cetin Date: Mon, 7 Aug 2023 18:23:23 +0300 Subject: [PATCH] improve test coverage --- client/client.go | 20 +--- client/client_test.go | 226 ++++++++++++++++++++++++++++------------ client/hooks.go | 7 +- client/hooks_test.go | 88 ++++++++++++++++ client/jar.go | 2 +- client/jar_test.go | 25 +++++ client/request.go | 6 ++ client/request_test.go | 24 +++++ client/response.go | 6 +- client/response_test.go | 3 +- 10 files changed, 316 insertions(+), 91 deletions(-) diff --git a/client/client.go b/client/client.go index 33265bc239..cbc944b126 100644 --- a/client/client.go +++ b/client/client.go @@ -6,6 +6,8 @@ import ( "crypto/x509" "encoding/json" "encoding/xml" + "github.com/gofiber/fiber/v3/log" + "github.com/gofiber/utils/v2" "io" "net/url" "os" @@ -14,7 +16,6 @@ import ( "sync" "time" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -239,6 +240,8 @@ func (c *Client) SetRootCertificateFromString(pem string) *Client { func (c *Client) SetProxyURL(proxyURL string) *Client { pUrl, err := url.Parse(proxyURL) if err != nil { + log.Errorf("%v", err) + return c } c.proxyURL = pUrl.String() @@ -474,21 +477,6 @@ func (c *Client) SetTimeout(t time.Duration) *Client { return c } -func (c *Client) Logger() Logger { - if c.logger == nil { - return &disableLogger{} - } - - return c.logger -} - -// SetLogger set logger field in client. -// The logger would output relate info with request. -func (c *Client) SetLogger(logger Logger) *Client { - c.logger = logger - return c -} - // Debug enable log debug level output. func (c *Client) Debug() *Client { c.debug = true diff --git a/client/client_test.go b/client/client_test.go index 9cdf6348ae..3da79c2596 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,12 +1,18 @@ package client import ( + "bytes" "context" "crypto/tls" "fmt" + "github.com/gofiber/fiber/v3/addon/retry" + "github.com/gofiber/fiber/v3/log" + "io" "net" + "os" "reflect" "testing" + "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/tlstest" @@ -813,37 +819,38 @@ func Test_Client_PathParam_With_Server(t *testing.T) { require.Equal(t, "ok", resp.String()) } -// func Test_Client_Cert(t *testing.T) { -// t.Parallel() +func Test_Client_TLS(t *testing.T) { + t.Parallel() -// serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() -// require.Nil(t, err) + serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() + require.Nil(t, err) -// ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") -// require.Nil(t, err) + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + require.Nil(t, err) -// ln = tls.NewListener(ln, serverTLSConf) + ln = tls.NewListener(ln, serverTLSConf) -// app := fiber.New() -// app.Get("/", func(c fiber.Ctx) error { -// return c.SendString("tls") -// }) + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("tls") + }) -// go func() { -// require.Nil(t, nil, app.Listener(ln, fiber.ListenConfig{ -// DisableStartupMessage: true, -// })) -// }() + go func() { + require.Nil(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() -// client := AcquireClient().SetCertificates(clientTLSConf.Certificates...) -// resp, err := client.Get("https://" + ln.Addr().String()) + client := AcquireClient() + resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) -// require.Nil(t, err) -// require.Equal(t, fiber.StatusOK, resp.StatusCode()) -// require.Equal(t, "tls", resp.String()) -// } + require.Nil(t, err) + require.Equal(t, clientTLSConf, client.TLSConfig()) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "tls", resp.String()) +} -func Test_Client_TLS(t *testing.T) { +func Test_Client_TLS_Empty_TLSConfig(t *testing.T) { t.Parallel() serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() @@ -866,12 +873,42 @@ func Test_Client_TLS(t *testing.T) { }() client := AcquireClient() - resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) + resp, err := client.Get("https://" + ln.Addr().String()) + require.Error(t, err) + require.NotEqual(t, clientTLSConf, client.TLSConfig()) + require.Nil(t, resp) +} + +func Test_Client_SetCertificates(t *testing.T) { + t.Parallel() + + serverTLSConf, _, err := tlstest.GetTLSConfigs() require.Nil(t, err) - require.Equal(t, clientTLSConf, client.TLSConfig()) - require.Equal(t, fiber.StatusOK, resp.StatusCode()) - require.Equal(t, "tls", resp.String()) + + client := AcquireClient().SetCertificates(serverTLSConf.Certificates...) + require.Equal(t, 1, len(client.tlsConfig.Certificates)) +} + +func Test_Client_SetRootCertificate(t *testing.T) { + t.Parallel() + + client := AcquireClient().SetRootCertificate("../.github/testdata/ssl.pem") + require.NotNil(t, client.tlsConfig.RootCAs) +} + +func Test_Client_SetRootCertificateFromString(t *testing.T) { + t.Parallel() + + file, err := os.Open("../.github/testdata/ssl.pem") + defer func() { _ = file.Close() }() + require.NoError(t, err) + + pem, err := io.ReadAll(file) + require.NoError(t, err) + + client := AcquireClient().SetRootCertificateFromString(string(pem)) + require.NotNil(t, client.tlsConfig.RootCAs) } func Test_Client_R(t *testing.T) { @@ -968,70 +1005,125 @@ func Test_Set_Config_To_Request(t *testing.T) { require.Equal(t, "v1", req.Param("k1")[0]) }) - // t.Run("set ctx", func(t *testing.T) { - // key := struct{}{} + t.Run("set cookies", func(t *testing.T) { + req := AcquireRequest() + + setConfigToRequest(req, Config{Cookie: map[string]string{ + "k1": "v1", + }}) + + require.Equal(t, "v1", req.Cookie("k1")) + }) + + t.Run("set pathparam", func(t *testing.T) { + req := AcquireRequest() + + setConfigToRequest(req, Config{PathParam: map[string]string{ + "k1": "v1", + }}) + + require.Equal(t, "v1", req.PathParam("k1")) + }) + + t.Run("set timeout", func(t *testing.T) { + req := AcquireRequest() + + setConfigToRequest(req, Config{Timeout: 1 * time.Second}) + + require.Equal(t, 1*time.Second, req.Timeout()) + }) + + t.Run("set maxredirects", func(t *testing.T) { + req := AcquireRequest() - // ctx := context.Background() - // ctx = context.WithValue(ctx, key, "v1") + setConfigToRequest(req, Config{MaxRedirects: 1}) - // req := AcquireRequest() + require.Equal(t, 1, req.MaxRedirects()) + }) - // setConfigToRequest(req, Config{Ctx: ctx}) + t.Run("set body", func(t *testing.T) { + req := AcquireRequest() - // require.Equal(t, "v1", req.Context().Value(key)) - // }) + setConfigToRequest(req, Config{Body: "test"}) - // t.Run("set ctx", func(t *testing.T) { - // key := struct{}{} + require.Equal(t, "test", req.body) + }) - // ctx := context.Background() - // ctx = context.WithValue(ctx, key, "v1") + t.Run("set file", func(t *testing.T) { + req := AcquireRequest() - // req := AcquireRequest() + setConfigToRequest(req, Config{File: []*File{ + { + name: "test", + path: "path", + }, + }}) - // setConfigToRequest(req, Config{Ctx: ctx}) + require.Equal(t, "path", req.File("test").path) + }) +} - // require.Equal(t, "v1", req.Context().Value(key)) - // }) +func Test_Client_SetProxyURL(t *testing.T) { + t.Parallel() - // t.Run("set ctx", func(t *testing.T) { - // key := struct{}{} + app, dial, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("hello world") + }) - // ctx := context.Background() - // ctx = context.WithValue(ctx, key, "v1") + go start() - // req := AcquireRequest() + defer func(app *fiber.App) { + _ = app.Shutdown() + }(app) - // setConfigToRequest(req, Config{Ctx: ctx}) + time.Sleep(1 * time.Second) - // require.Equal(t, "v1", req.Context().Value(key)) - // }) + t.Run("success", func(t *testing.T) { + client := AcquireClient() + client.SetProxyURL("http://test.com") + _, err := client.Get("http://localhost:3000", Config{Dial: dial}) - // t.Run("set ctx", func(t *testing.T) { - // key := struct{}{} + require.NoError(t, err) + }) - // ctx := context.Background() - // ctx = context.WithValue(ctx, key, "v1") + t.Run("wrong url", func(t *testing.T) { + var buf bytes.Buffer + log.SetOutput(&buf) - // req := AcquireRequest() + client := AcquireClient() + client.SetProxyURL(":this is not a url") + _, err := client.Get("http://localhost:3000", Config{Dial: dial}) - // setConfigToRequest(req, Config{Ctx: ctx}) + require.Contains(t, buf.String(), "missing protocol scheme") + require.NoError(t, err) + }) - // require.Equal(t, "v1", req.Context().Value(key)) - // }) + t.Run("error", func(t *testing.T) { + client := AcquireClient() + client.SetProxyURL("htgdftp://test.com") + _, err := client.Get("http://localhost:3000", Config{Dial: dial}) - // t.Run("set ctx", func(t *testing.T) { - // key := struct{}{} + require.Error(t, err) + }) +} - // ctx := context.Background() - // ctx = context.WithValue(ctx, key, "v1") +func Test_Client_SetRetryConfig(t *testing.T) { + t.Parallel() - // req := AcquireRequest() + retryConfig := &retry.Config{ + InitialInterval: 1 * time.Second, + MaxRetryCount: 3, + } - // setConfigToRequest(req, Config{Ctx: ctx}) + core, client, req := newCore(), AcquireClient(), AcquireRequest() + req.SetURL("http://example.com") + client.SetRetryConfig(retryConfig) + _, err := core.execute(context.Background(), client, req) - // require.Equal(t, "v1", req.Context().Value(key)) - // }) + require.NoError(t, err) + require.Equal(t, retryConfig.InitialInterval, client.RetryConfig().InitialInterval) + require.Equal(t, retryConfig.MaxRetryCount, client.RetryConfig().MaxRetryCount) } func Benchmark_Client_Request(b *testing.B) { diff --git a/client/hooks.go b/client/hooks.go index c2c5713d58..15d13ddf72 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -298,10 +299,8 @@ func logger(c *Client, resp *Response, req *Request) (err error) { return } - logger := c.Logger() - - logger.Debugf("%s\n", req.RawRequest.String()) - logger.Debugf("%s\n", resp.RawResponse.String()) + log.Debugf("%s\n", req.RawRequest.String()) + log.Debugf("%s\n", resp.RawResponse.String()) return } diff --git a/client/hooks_test.go b/client/hooks_test.go index b2e1ab9ca8..d144b29e30 100644 --- a/client/hooks_test.go +++ b/client/hooks_test.go @@ -1,11 +1,14 @@ package client import ( + "bytes" "encoding/xml" + "github.com/gofiber/fiber/v3/log" "io" "net/url" "strings" "testing" + "time" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" @@ -424,6 +427,17 @@ func Test_Parser_Request_Body(t *testing.T) { require.Equal(t, "ball=cricle+and+square", string(req.RawRequest.Body())) }) + t.Run("form data body error", func(t *testing.T) { + client := AcquireClient() + req := AcquireRequest(). + SetFormDatas(map[string]string{ + "": "", + }) + + err := parserRequestBody(client, req) + require.NoError(t, err) + }) + t.Run("file body", func(t *testing.T) { client := AcquireClient() req := AcquireRequest(). @@ -457,4 +471,78 @@ func Test_Parser_Request_Body(t *testing.T) { require.NoError(t, err) require.Equal(t, []byte("hello world"), req.RawRequest.Body()) }) + + t.Run("raw body error", func(t *testing.T) { + client := AcquireClient() + req := AcquireRequest(). + SetRawBody([]byte("hello world")) + + req.body = nil + + err := parserRequestBody(client, req) + require.ErrorIs(t, err, ErrBodyType) + }) +} + +func Test_Client_Logger_Debug(t *testing.T) { + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("response") + }) + + go func() { + require.Nil(t, app.Listen(":3000", fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + + defer func(app *fiber.App) { + _ = app.Shutdown() + }(app) + + time.Sleep(1 * time.Second) + + var buf bytes.Buffer + log.SetOutput(&buf) + + client := AcquireClient() + client.Debug() + + resp, err := client.Get("http://localhost:3000") + defer resp.Close() + + require.NoError(t, err) + require.Contains(t, buf.String(), "Host: localhost:3000") + require.Contains(t, buf.String(), "Content-Length: 8") +} + +func Test_Client_Logger_DisableDebug(t *testing.T) { + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("response") + }) + + go func() { + require.Nil(t, app.Listen(":3000", fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + + defer func(app *fiber.App) { + _ = app.Shutdown() + }(app) + + time.Sleep(1 * time.Second) + + var buf bytes.Buffer + log.SetOutput(&buf) + + client := AcquireClient() + client.DisableDebug() + + resp, err := client.Get("http://localhost:3000") + defer resp.Close() + + require.NoError(t, err) + require.Len(t, buf.String(), 0) } diff --git a/client/jar.go b/client/jar.go index 648fce0a25..71f711fff7 100644 --- a/client/jar.go +++ b/client/jar.go @@ -244,7 +244,7 @@ func newEntry(c *fasthttp.Cookie, now time.Time, path []byte) (*entry, bool) { e := acquireEntry() e.Key = utils.CopyBytes(c.Key()) - + fmt.Println(c.Path()) if len(c.Path()) != 0 || c.Path()[0] != '/' { e.Path = utils.CopyBytes(path) } else { diff --git a/client/jar_test.go b/client/jar_test.go index cf8947192e..c5c6825f38 100644 --- a/client/jar_test.go +++ b/client/jar_test.go @@ -87,6 +87,31 @@ func TestHasDotSuffix(t *testing.T) { } } +var canonicalHostTests = map[string]string{ + "www.example.com": "www.example.com", + "WWW.EXAMPLE.COM": "www.example.com", + "wWw.eXAmple.CoM": "www.example.com", + "www.example.com:80": "www.example.com", + "192.168.0.10": "192.168.0.10", + "192.168.0.5:8080": "192.168.0.5", + "2001:4860:0:2001::68": "2001:4860:0:2001::68", + "[2001:4860:0:::68]:8080": "2001:4860:0:::68", + "www.bücher.de": "www.xn--bcher-kva.de", + "www.example.com.": "www.example.com", + // TODO: Fix canonicalHost so that all of the following malformed + // domain names trigger an error. (This list is not exhaustive, e.g. + // malformed internationalized domain names are missing.) + ".": "", + "..": ".", + "...": "..", + ".net": ".net", + ".net.": ".net", + "a..": "a.", + "b.a..": "b.a.", + "weird.stuff...": "weird.stuff..", + "[bad.unmatched.bracket:": "error", +} + var jarKeyTests = map[string]string{ "foo.www.example.com": "example.com", "www.example.com": "example.com", diff --git a/client/request.go b/client/request.go index c2d8cad74f..5e55f1dccb 100644 --- a/client/request.go +++ b/client/request.go @@ -324,6 +324,12 @@ func (r *Request) DelPathParams(key ...string) *Request { return r } +// ResetPathParams deletes all path params. +func (r *Request) ResetPathParams() *Request { + r.path.Reset() + return r +} + // SetJSON method sets json body in request. func (r *Request) SetJSON(v any) *Request { r.body = v diff --git a/client/request_test.go b/client/request_test.go index 1598033b9c..78d3ef6c35 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -388,6 +388,20 @@ func Test_Request_PathParam(t *testing.T) { require.Equal(t, "", req.PathParam("foo")) require.Equal(t, "foo", req.PathParam("bar")) }) + + t.Run("clear path params", func(t *testing.T) { + req := AcquireRequest(). + SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + + req.ResetPathParams() + require.Equal(t, "", req.PathParam("foo")) + require.Equal(t, "", req.PathParam("bar")) + }) } func Test_Request_FormData(t *testing.T) { @@ -522,6 +536,8 @@ func Test_Request_File(t *testing.T) { require.Equal(t, "../.github/index.html", req.File("index.html").path) require.Equal(t, "../.github/index.html", req.FileByPath("../.github/index.html").path) require.Equal(t, "tmp.txt", req.File("tmp.txt").name) + require.Nil(t, req.File("tmp2.txt")) + require.Nil(t, req.FileByPath("tmp2.txt")) }) t.Run("add file by reader", func(t *testing.T) { @@ -1185,6 +1201,14 @@ func Test_Request_MaxRedirects(t *testing.T) { require.Nil(t, resp) require.Equal(t, "too many redirects detected when doing the request", err.Error()) }) + + t.Run("MaxRedirects", func(t *testing.T) { + req := AcquireRequest(). + SetDial(func(addr string) (net.Conn, error) { return ln.Dial() }). + SetMaxRedirects(3) + + require.Equal(t, req.MaxRedirects(), 3) + }) } // // readErrorConn is a struct for testing retryIf diff --git a/client/response.go b/client/response.go index e9fb40826c..170821683b 100644 --- a/client/response.go +++ b/client/response.go @@ -2,7 +2,9 @@ package client import ( "bytes" + "errors" "io" + "io/fs" "os" "path/filepath" "strings" @@ -83,9 +85,9 @@ func (r *Response) Save(v any) error { file := filepath.Clean(p) dir := filepath.Dir(file) - // create director + // create directory if _, err := os.Stat(dir); err != nil { - if !os.IsNotExist(err) { + if !errors.Is(err, fs.ErrNotExist) { return err } diff --git a/client/response_test.go b/client/response_test.go index e1d24503fa..e67f6ac334 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -252,12 +252,13 @@ func Test_Response_Save(t *testing.T) { }() file, err := os.Open("./test/tmp.json") + defer file.Close() + require.NoError(t, err) data, err := io.ReadAll(file) require.NoError(t, err) require.Equal(t, "{\"status\":\"success\"}", string(data)) - file.Close() }) t.Run("io.Writer", func(t *testing.T) {