diff --git a/client.go b/client.go
deleted file mode 100644
index 8825f9d815..0000000000
--- a/client.go
+++ /dev/null
@@ -1,1021 +0,0 @@
-package fiber
-
-import (
- "bytes"
- "crypto/tls"
- "encoding/json"
- "encoding/xml"
- "errors"
- "fmt"
- "io"
- "mime/multipart"
- "os"
- "path/filepath"
- "strconv"
- "sync"
- "time"
-
- "github.com/gofiber/utils/v2"
- "github.com/valyala/fasthttp"
-)
-
-// Request represents HTTP request.
-//
-// It is forbidden copying Request instances. Create new instances
-// and use CopyTo instead.
-//
-// Request instance MUST NOT be used from concurrently running goroutines.
-// Copy from fasthttp
-type Request = fasthttp.Request
-
-// Response represents HTTP response.
-//
-// It is forbidden copying Response instances. Create new instances
-// and use CopyTo instead.
-//
-// Response instance MUST NOT be used from concurrently running goroutines.
-// Copy from fasthttp
-type Response = fasthttp.Response
-
-// Args represents query arguments.
-//
-// It is forbidden copying Args instances. Create new instances instead
-// and use CopyTo().
-//
-// Args instance MUST NOT be used from concurrently running goroutines.
-// Copy from fasthttp
-type Args = fasthttp.Args
-
-// RetryIfFunc signature of retry if function
-// Request argument passed to RetryIfFunc, if there are any request errors.
-// Copy from fasthttp
-type RetryIfFunc = fasthttp.RetryIfFunc
-
-var defaultClient Client
-
-// Client implements http client.
-//
-// It is safe calling Client methods from concurrently running goroutines.
-type Client struct {
- mutex sync.RWMutex
- // UserAgent is used in User-Agent request header.
- UserAgent string
-
- // NoDefaultUserAgentHeader when set to true, causes the default
- // User-Agent header to be excluded from the Request.
- NoDefaultUserAgentHeader bool
-
- // When set by an external client of Fiber it will use the provided implementation of a
- // JSONMarshal
- //
- // Allowing for flexibility in using another json library for encoding
- JSONEncoder utils.JSONMarshal
-
- // When set by an external client of Fiber it will use the provided implementation of a
- // JSONUnmarshal
- //
- // Allowing for flexibility in using another json library for decoding
- JSONDecoder utils.JSONUnmarshal
-}
-
-// Get returns an agent with http method GET.
-func Get(url string) *Agent { return defaultClient.Get(url) }
-
-// Get returns an agent with http method GET.
-func (c *Client) Get(url string) *Agent {
- return c.createAgent(MethodGet, url)
-}
-
-// Head returns an agent with http method HEAD.
-func Head(url string) *Agent { return defaultClient.Head(url) }
-
-// Head returns an agent with http method GET.
-func (c *Client) Head(url string) *Agent {
- return c.createAgent(MethodHead, url)
-}
-
-// Post sends POST request to the given URL.
-func Post(url string) *Agent { return defaultClient.Post(url) }
-
-// Post sends POST request to the given URL.
-func (c *Client) Post(url string) *Agent {
- return c.createAgent(MethodPost, url)
-}
-
-// Put sends PUT request to the given URL.
-func Put(url string) *Agent { return defaultClient.Put(url) }
-
-// Put sends PUT request to the given URL.
-func (c *Client) Put(url string) *Agent {
- return c.createAgent(MethodPut, url)
-}
-
-// Patch sends PATCH request to the given URL.
-func Patch(url string) *Agent { return defaultClient.Patch(url) }
-
-// Patch sends PATCH request to the given URL.
-func (c *Client) Patch(url string) *Agent {
- return c.createAgent(MethodPatch, url)
-}
-
-// Delete sends DELETE request to the given URL.
-func Delete(url string) *Agent { return defaultClient.Delete(url) }
-
-// Delete sends DELETE request to the given URL.
-func (c *Client) Delete(url string) *Agent {
- return c.createAgent(MethodDelete, url)
-}
-
-func (c *Client) createAgent(method, url string) *Agent {
- a := AcquireAgent()
- a.req.Header.SetMethod(method)
- a.req.SetRequestURI(url)
-
- c.mutex.RLock()
- a.Name = c.UserAgent
- a.NoDefaultUserAgentHeader = c.NoDefaultUserAgentHeader
- a.jsonDecoder = c.JSONDecoder
- a.jsonEncoder = c.JSONEncoder
- if a.jsonDecoder == nil {
- a.jsonDecoder = json.Unmarshal
- }
- c.mutex.RUnlock()
-
- if err := a.Parse(); err != nil {
- a.errs = append(a.errs, err)
- }
-
- return a
-}
-
-// Agent is an object storing all request data for client.
-// Agent instance MUST NOT be used from concurrently running goroutines.
-type Agent struct {
- // Name is used in User-Agent request header.
- Name string
-
- // NoDefaultUserAgentHeader when set to true, causes the default
- // User-Agent header to be excluded from the Request.
- NoDefaultUserAgentHeader bool
-
- // HostClient is an embedded fasthttp HostClient
- *fasthttp.HostClient
-
- req *Request
- resp *Response
- dest []byte
- args *Args
- timeout time.Duration
- errs []error
- formFiles []*FormFile
- debugWriter io.Writer
- mw multipartWriter
- jsonEncoder utils.JSONMarshal
- jsonDecoder utils.JSONUnmarshal
- maxRedirectsCount int
- boundary string
- reuse bool
- parsed bool
-}
-
-// Parse initializes URI and HostClient.
-func (a *Agent) Parse() error {
- if a.parsed {
- return nil
- }
- a.parsed = true
-
- uri := a.req.URI()
-
- var isTLS bool
- scheme := uri.Scheme()
- if bytes.Equal(scheme, []byte(schemeHTTPS)) {
- isTLS = true
- } else if !bytes.Equal(scheme, []byte(schemeHTTP)) {
- return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme)
- }
-
- name := a.Name
- if name == "" && !a.NoDefaultUserAgentHeader {
- name = defaultUserAgent
- }
-
- a.HostClient = &fasthttp.HostClient{
- Addr: fasthttp.AddMissingPort(string(uri.Host()), isTLS),
- Name: name,
- NoDefaultUserAgentHeader: a.NoDefaultUserAgentHeader,
- IsTLS: isTLS,
- }
-
- return nil
-}
-
-/************************** Header Setting **************************/
-
-// Set sets the given 'key: value' header.
-//
-// Use Add for setting multiple header values under the same key.
-func (a *Agent) Set(k, v string) *Agent {
- a.req.Header.Set(k, v)
-
- return a
-}
-
-// SetBytesK sets the given 'key: value' header.
-//
-// Use AddBytesK for setting multiple header values under the same key.
-func (a *Agent) SetBytesK(k []byte, v string) *Agent {
- a.req.Header.SetBytesK(k, v)
-
- return a
-}
-
-// SetBytesV sets the given 'key: value' header.
-//
-// Use AddBytesV for setting multiple header values under the same key.
-func (a *Agent) SetBytesV(k string, v []byte) *Agent {
- a.req.Header.SetBytesV(k, v)
-
- return a
-}
-
-// SetBytesKV sets the given 'key: value' header.
-//
-// Use AddBytesKV for setting multiple header values under the same key.
-func (a *Agent) SetBytesKV(k, v []byte) *Agent {
- a.req.Header.SetBytesKV(k, v)
-
- return a
-}
-
-// Add adds the given 'key: value' header.
-//
-// Multiple headers with the same key may be added with this function.
-// Use Set for setting a single header for the given key.
-func (a *Agent) Add(k, v string) *Agent {
- a.req.Header.Add(k, v)
-
- return a
-}
-
-// AddBytesK adds the given 'key: value' header.
-//
-// Multiple headers with the same key may be added with this function.
-// Use SetBytesK for setting a single header for the given key.
-func (a *Agent) AddBytesK(k []byte, v string) *Agent {
- a.req.Header.AddBytesK(k, v)
-
- return a
-}
-
-// AddBytesV adds the given 'key: value' header.
-//
-// Multiple headers with the same key may be added with this function.
-// Use SetBytesV for setting a single header for the given key.
-func (a *Agent) AddBytesV(k string, v []byte) *Agent {
- a.req.Header.AddBytesV(k, v)
-
- return a
-}
-
-// AddBytesKV adds the given 'key: value' header.
-//
-// Multiple headers with the same key may be added with this function.
-// Use SetBytesKV for setting a single header for the given key.
-func (a *Agent) AddBytesKV(k, v []byte) *Agent {
- a.req.Header.AddBytesKV(k, v)
-
- return a
-}
-
-// ConnectionClose sets 'Connection: close' header.
-func (a *Agent) ConnectionClose() *Agent {
- a.req.Header.SetConnectionClose()
-
- return a
-}
-
-// UserAgent sets User-Agent header value.
-func (a *Agent) UserAgent(userAgent string) *Agent {
- a.req.Header.SetUserAgent(userAgent)
-
- return a
-}
-
-// UserAgentBytes sets User-Agent header value.
-func (a *Agent) UserAgentBytes(userAgent []byte) *Agent {
- a.req.Header.SetUserAgentBytes(userAgent)
-
- return a
-}
-
-// Cookie sets one 'key: value' cookie.
-func (a *Agent) Cookie(key, value string) *Agent {
- a.req.Header.SetCookie(key, value)
-
- return a
-}
-
-// CookieBytesK sets one 'key: value' cookie.
-func (a *Agent) CookieBytesK(key []byte, value string) *Agent {
- a.req.Header.SetCookieBytesK(key, value)
-
- return a
-}
-
-// CookieBytesKV sets one 'key: value' cookie.
-func (a *Agent) CookieBytesKV(key, value []byte) *Agent {
- a.req.Header.SetCookieBytesKV(key, value)
-
- return a
-}
-
-// Cookies sets multiple 'key: value' cookies.
-func (a *Agent) Cookies(kv ...string) *Agent {
- for i := 1; i < len(kv); i += 2 {
- a.req.Header.SetCookie(kv[i-1], kv[i])
- }
-
- return a
-}
-
-// CookiesBytesKV sets multiple 'key: value' cookies.
-func (a *Agent) CookiesBytesKV(kv ...[]byte) *Agent {
- for i := 1; i < len(kv); i += 2 {
- a.req.Header.SetCookieBytesKV(kv[i-1], kv[i])
- }
-
- return a
-}
-
-// Referer sets Referer header value.
-func (a *Agent) Referer(referer string) *Agent {
- a.req.Header.SetReferer(referer)
-
- return a
-}
-
-// RefererBytes sets Referer header value.
-func (a *Agent) RefererBytes(referer []byte) *Agent {
- a.req.Header.SetRefererBytes(referer)
-
- return a
-}
-
-// ContentType sets Content-Type header value.
-func (a *Agent) ContentType(contentType string) *Agent {
- a.req.Header.SetContentType(contentType)
-
- return a
-}
-
-// ContentTypeBytes sets Content-Type header value.
-func (a *Agent) ContentTypeBytes(contentType []byte) *Agent {
- a.req.Header.SetContentTypeBytes(contentType)
-
- return a
-}
-
-/************************** End Header Setting **************************/
-
-/************************** URI Setting **************************/
-
-// Host sets host for the URI.
-func (a *Agent) Host(host string) *Agent {
- a.req.URI().SetHost(host)
-
- return a
-}
-
-// HostBytes sets host for the URI.
-func (a *Agent) HostBytes(host []byte) *Agent {
- a.req.URI().SetHostBytes(host)
-
- return a
-}
-
-// QueryString sets URI query string.
-func (a *Agent) QueryString(queryString string) *Agent {
- a.req.URI().SetQueryString(queryString)
-
- return a
-}
-
-// QueryStringBytes sets URI query string.
-func (a *Agent) QueryStringBytes(queryString []byte) *Agent {
- a.req.URI().SetQueryStringBytes(queryString)
-
- return a
-}
-
-// BasicAuth sets URI username and password.
-func (a *Agent) BasicAuth(username, password string) *Agent {
- a.req.URI().SetUsername(username)
- a.req.URI().SetPassword(password)
-
- return a
-}
-
-// BasicAuthBytes sets URI username and password.
-func (a *Agent) BasicAuthBytes(username, password []byte) *Agent {
- a.req.URI().SetUsernameBytes(username)
- a.req.URI().SetPasswordBytes(password)
-
- return a
-}
-
-/************************** End URI Setting **************************/
-
-/************************** Request Setting **************************/
-
-// BodyString sets request body.
-func (a *Agent) BodyString(bodyString string) *Agent {
- a.req.SetBodyString(bodyString)
-
- return a
-}
-
-// Body sets request body.
-func (a *Agent) Body(body []byte) *Agent {
- a.req.SetBody(body)
-
- return a
-}
-
-// BodyStream sets request body stream and, optionally body size.
-//
-// If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes
-// before returning io.EOF.
-//
-// If bodySize < 0, then bodyStream is read until io.EOF.
-//
-// bodyStream.Close() is called after finishing reading all body data
-// if it implements io.Closer.
-//
-// Note that GET and HEAD requests cannot have body.
-func (a *Agent) BodyStream(bodyStream io.Reader, bodySize int) *Agent {
- a.req.SetBodyStream(bodyStream, bodySize)
-
- return a
-}
-
-// JSON sends a JSON request.
-func (a *Agent) JSON(v any, ctype ...string) *Agent {
- if a.jsonEncoder == nil {
- a.jsonEncoder = json.Marshal
- }
-
- if len(ctype) > 0 {
- a.req.Header.SetContentType(ctype[0])
- } else {
- a.req.Header.SetContentType(MIMEApplicationJSON)
- }
-
- if body, err := a.jsonEncoder(v); err != nil {
- a.errs = append(a.errs, err)
- } else {
- a.req.SetBody(body)
- }
-
- return a
-}
-
-// XML sends an XML request.
-func (a *Agent) XML(v any) *Agent {
- a.req.Header.SetContentType(MIMEApplicationXML)
-
- if body, err := xml.Marshal(v); err != nil {
- a.errs = append(a.errs, err)
- } else {
- a.req.SetBody(body)
- }
-
- return a
-}
-
-// Form sends form request with body if args is non-nil.
-//
-// It is recommended obtaining args via AcquireArgs and release it
-// manually in performance-critical code.
-func (a *Agent) Form(args *Args) *Agent {
- a.req.Header.SetContentType(MIMEApplicationForm)
-
- if args != nil {
- a.req.SetBody(args.QueryString())
- }
-
- return a
-}
-
-// FormFile represents multipart form file
-type FormFile struct {
- // Fieldname is form file's field name
- Fieldname string
- // Name is form file's name
- Name string
- // Content is form file's content
- Content []byte
- // autoRelease indicates if returns the object
- // acquired via AcquireFormFile to the pool.
- autoRelease bool
-}
-
-// FileData appends files for multipart form request.
-//
-// It is recommended obtaining formFile via AcquireFormFile and release it
-// manually in performance-critical code.
-func (a *Agent) FileData(formFiles ...*FormFile) *Agent {
- a.formFiles = append(a.formFiles, formFiles...)
-
- return a
-}
-
-// SendFile reads file and appends it to multipart form request.
-func (a *Agent) SendFile(filename string, fieldname ...string) *Agent {
- content, err := os.ReadFile(filepath.Clean(filename))
- if err != nil {
- a.errs = append(a.errs, err)
- return a
- }
-
- ff := AcquireFormFile()
- if len(fieldname) > 0 && fieldname[0] != "" {
- ff.Fieldname = fieldname[0]
- } else {
- ff.Fieldname = "file" + strconv.Itoa(len(a.formFiles)+1)
- }
- ff.Name = filepath.Base(filename)
- ff.Content = append(ff.Content, content...)
- ff.autoRelease = true
-
- a.formFiles = append(a.formFiles, ff)
-
- return a
-}
-
-// SendFiles reads files and appends them to multipart form request.
-//
-// Examples:
-//
-// SendFile("/path/to/file1", "fieldname1", "/path/to/file2")
-func (a *Agent) SendFiles(filenamesAndFieldnames ...string) *Agent {
- pairs := len(filenamesAndFieldnames)
- if pairs&1 == 1 {
- filenamesAndFieldnames = append(filenamesAndFieldnames, "")
- }
-
- for i := 0; i < pairs; i += 2 {
- a.SendFile(filenamesAndFieldnames[i], filenamesAndFieldnames[i+1])
- }
-
- return a
-}
-
-// Boundary sets boundary for multipart form request.
-func (a *Agent) Boundary(boundary string) *Agent {
- a.boundary = boundary
-
- return a
-}
-
-// MultipartForm sends multipart form request with k-v and files.
-//
-// It is recommended obtaining args via AcquireArgs and release it
-// manually in performance-critical code.
-func (a *Agent) MultipartForm(args *Args) *Agent {
- if a.mw == nil {
- a.mw = multipart.NewWriter(a.req.BodyWriter())
- }
-
- if a.boundary != "" {
- if err := a.mw.SetBoundary(a.boundary); err != nil {
- a.errs = append(a.errs, err)
- return a
- }
- }
-
- a.req.Header.SetMultipartFormBoundary(a.mw.Boundary())
-
- if args != nil {
- args.VisitAll(func(key, value []byte) {
- if err := a.mw.WriteField(utils.UnsafeString(key), utils.UnsafeString(value)); err != nil {
- a.errs = append(a.errs, err)
- }
- })
- }
-
- for _, ff := range a.formFiles {
- w, err := a.mw.CreateFormFile(ff.Fieldname, ff.Name)
- if err != nil {
- a.errs = append(a.errs, err)
- continue
- }
- if _, err = w.Write(ff.Content); err != nil {
- a.errs = append(a.errs, err)
- }
- }
-
- if err := a.mw.Close(); err != nil {
- a.errs = append(a.errs, err)
- }
-
- return a
-}
-
-/************************** End Request Setting **************************/
-
-/************************** Agent Setting **************************/
-
-// Debug mode enables logging request and response detail
-func (a *Agent) Debug(w ...io.Writer) *Agent {
- a.debugWriter = os.Stdout
- if len(w) > 0 {
- a.debugWriter = w[0]
- }
-
- return a
-}
-
-// Timeout sets request timeout duration.
-func (a *Agent) Timeout(timeout time.Duration) *Agent {
- a.timeout = timeout
-
- return a
-}
-
-// Reuse enables the Agent instance to be used again after one request.
-//
-// If agent is reusable, then it should be released manually when it is no
-// longer used.
-func (a *Agent) Reuse() *Agent {
- a.reuse = true
-
- return a
-}
-
-// InsecureSkipVerify controls whether the Agent verifies the server
-// certificate chain and host name.
-func (a *Agent) InsecureSkipVerify() *Agent {
- if a.HostClient.TLSConfig == nil {
- a.HostClient.TLSConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We explicitly let the user set insecure mode here
- } else {
- a.HostClient.TLSConfig.InsecureSkipVerify = true
- }
-
- return a
-}
-
-// TLSConfig sets tls config.
-func (a *Agent) TLSConfig(config *tls.Config) *Agent {
- a.HostClient.TLSConfig = config
-
- return a
-}
-
-// MaxRedirectsCount sets max redirect count for GET and HEAD.
-func (a *Agent) MaxRedirectsCount(count int) *Agent {
- a.maxRedirectsCount = count
-
- return a
-}
-
-// JSONEncoder sets custom json encoder.
-func (a *Agent) JSONEncoder(jsonEncoder utils.JSONMarshal) *Agent {
- a.jsonEncoder = jsonEncoder
-
- return a
-}
-
-// JSONDecoder sets custom json decoder.
-func (a *Agent) JSONDecoder(jsonDecoder utils.JSONUnmarshal) *Agent {
- a.jsonDecoder = jsonDecoder
-
- return a
-}
-
-// Request returns Agent request instance.
-func (a *Agent) Request() *Request {
- return a.req
-}
-
-// SetResponse sets custom response for the Agent instance.
-//
-// It is recommended obtaining custom response via AcquireResponse and release it
-// manually in performance-critical code.
-func (a *Agent) SetResponse(customResp *Response) *Agent {
- a.resp = customResp
-
- return a
-}
-
-// Dest sets custom dest.
-//
-// The contents of dest will be replaced by the response body, if the dest
-// is too small a new slice will be allocated.
-func (a *Agent) Dest(dest []byte) *Agent {
- a.dest = dest
-
- return a
-}
-
-// RetryIf controls whether a retry should be attempted after an error.
-//
-// By default, will use isIdempotent function from fasthttp
-func (a *Agent) RetryIf(retryIf RetryIfFunc) *Agent {
- a.HostClient.RetryIf = retryIf
- return a
-}
-
-/************************** End Agent Setting **************************/
-
-// Bytes returns the status code, bytes body and errors of url.
-//
-// it's not safe to use Agent after calling [Agent.Bytes]
-func (a *Agent) Bytes() (int, []byte, []error) {
- defer a.release()
- return a.bytes()
-}
-
-func (a *Agent) bytes() (code int, body []byte, errs []error) { //nolint:nonamedreturns,revive // We want to overwrite the body in a deferred func. TODO: Check if we really need to do this. We eventually want to get rid of all named returns.
- if errs = append(errs, a.errs...); len(errs) > 0 {
- return code, body, errs
- }
-
- var (
- req = a.req
- resp *Response
- nilResp bool
- )
-
- if a.resp == nil {
- resp = AcquireResponse()
- nilResp = true
- } else {
- resp = a.resp
- }
-
- defer func() {
- if a.debugWriter != nil {
- printDebugInfo(req, resp, a.debugWriter)
- }
-
- if len(errs) == 0 {
- code = resp.StatusCode()
- }
-
- body = append(a.dest, resp.Body()...) //nolint:gocritic // We want to append to the returned slice here
-
- if nilResp {
- ReleaseResponse(resp)
- }
- }()
-
- if a.timeout > 0 {
- if err := a.HostClient.DoTimeout(req, resp, a.timeout); err != nil {
- errs = append(errs, err)
- return code, body, errs
- }
- } else if a.maxRedirectsCount > 0 && (string(req.Header.Method()) == MethodGet || string(req.Header.Method()) == MethodHead) {
- if err := a.HostClient.DoRedirects(req, resp, a.maxRedirectsCount); err != nil {
- errs = append(errs, err)
- return code, body, errs
- }
- } else if err := a.HostClient.Do(req, resp); err != nil {
- errs = append(errs, err)
- }
-
- return code, body, errs
-}
-
-func printDebugInfo(req *Request, resp *Response, w io.Writer) {
- msg := fmt.Sprintf("Connected to %s(%s)\r\n\r\n", req.URI().Host(), resp.RemoteAddr())
- _, _ = w.Write(utils.UnsafeBytes(msg)) //nolint:errcheck // This will never fail
- _, _ = req.WriteTo(w) //nolint:errcheck // This will never fail
- _, _ = resp.WriteTo(w) //nolint:errcheck // This will never fail
-}
-
-// String returns the status code, string body and errors of url.
-//
-// it's not safe to use Agent after calling [Agent.String]
-func (a *Agent) String() (int, string, []error) {
- defer a.release()
- code, body, errs := a.bytes()
- // TODO: There might be a data race here on body. Maybe use utils.CopyBytes on it?
-
- return code, utils.UnsafeString(body), errs
-}
-
-// Struct returns the status code, bytes body and errors of URL.
-// And bytes body will be unmarshalled to given v.
-//
-// it's not safe to use Agent after calling [Agent.Struct]
-func (a *Agent) Struct(v any) (int, []byte, []error) {
- defer a.release()
-
- code, body, errs := a.bytes()
- if len(errs) > 0 {
- return code, body, errs
- }
-
- // TODO: This should only be done once
- if a.jsonDecoder == nil {
- a.jsonDecoder = json.Unmarshal
- }
-
- if err := a.jsonDecoder(body, v); err != nil {
- errs = append(errs, err)
- }
-
- return code, body, errs
-}
-
-func (a *Agent) release() {
- if !a.reuse {
- ReleaseAgent(a)
- } else {
- a.errs = a.errs[:0]
- }
-}
-
-func (a *Agent) reset() {
- a.HostClient = nil
- a.req.Reset()
- a.resp = nil
- a.dest = nil
- a.timeout = 0
- a.args = nil
- a.errs = a.errs[:0]
- a.debugWriter = nil
- a.mw = nil
- a.reuse = false
- a.parsed = false
- a.maxRedirectsCount = 0
- a.boundary = ""
- a.Name = ""
- a.NoDefaultUserAgentHeader = false
- for i, ff := range a.formFiles {
- if ff.autoRelease {
- ReleaseFormFile(ff)
- }
- a.formFiles[i] = nil
- }
- a.formFiles = a.formFiles[:0]
-}
-
-var (
- clientPool sync.Pool
- agentPool = sync.Pool{
- New: func() any {
- return &Agent{req: &Request{}}
- },
- }
- responsePool sync.Pool
- argsPool sync.Pool
- formFilePool sync.Pool
-)
-
-// AcquireClient returns an empty Client instance from client pool.
-//
-// The returned Client instance may be passed to ReleaseClient when it is
-// no longer needed. This allows Client recycling, reduces GC pressure
-// and usually improves performance.
-func AcquireClient() *Client {
- v := clientPool.Get()
- if v == nil {
- return &Client{}
- }
- c, ok := v.(*Client)
- if !ok {
- panic(errors.New("failed to type-assert to *Client"))
- }
- return c
-}
-
-// ReleaseClient returns c acquired via AcquireClient to client pool.
-//
-// It is forbidden accessing req and/or it's members after returning
-// it to client pool.
-func ReleaseClient(c *Client) {
- c.UserAgent = ""
- c.NoDefaultUserAgentHeader = false
- c.JSONEncoder = nil
- c.JSONDecoder = nil
-
- clientPool.Put(c)
-}
-
-// AcquireAgent returns an empty Agent instance from Agent pool.
-//
-// The returned Agent instance may be passed to ReleaseAgent when it is
-// no longer needed. This allows Agent recycling, reduces GC pressure
-// and usually improves performance.
-func AcquireAgent() *Agent {
- a, ok := agentPool.Get().(*Agent)
- if !ok {
- panic(errors.New("failed to type-assert to *Agent"))
- }
- return a
-}
-
-// ReleaseAgent returns an acquired via AcquireAgent to Agent pool.
-//
-// It is forbidden accessing req and/or it's members after returning
-// it to Agent pool.
-func ReleaseAgent(a *Agent) {
- a.reset()
- agentPool.Put(a)
-}
-
-// AcquireResponse returns an empty Response instance from response pool.
-//
-// The returned Response instance may be passed to ReleaseResponse when it is
-// no longer needed. This allows Response recycling, reduces GC pressure
-// and usually improves performance.
-// Copy from fasthttp
-func AcquireResponse() *Response {
- v := responsePool.Get()
- if v == nil {
- return &Response{}
- }
- r, ok := v.(*Response)
- if !ok {
- panic(errors.New("failed to type-assert to *Response"))
- }
- return r
-}
-
-// ReleaseResponse return resp acquired via AcquireResponse to response pool.
-//
-// It is forbidden accessing resp and/or it's members after returning
-// it to response pool.
-// Copy from fasthttp
-func ReleaseResponse(resp *Response) {
- resp.Reset()
- responsePool.Put(resp)
-}
-
-// AcquireArgs returns an empty Args object from the pool.
-//
-// The returned Args may be returned to the pool with ReleaseArgs
-// when no longer needed. This allows reducing GC load.
-// Copy from fasthttp
-func AcquireArgs() *Args {
- v := argsPool.Get()
- if v == nil {
- return &Args{}
- }
- a, ok := v.(*Args)
- if !ok {
- panic(errors.New("failed to type-assert to *Args"))
- }
- return a
-}
-
-// ReleaseArgs returns the object acquired via AcquireArgs to the pool.
-//
-// String not access the released Args object, otherwise data races may occur.
-// Copy from fasthttp
-func ReleaseArgs(a *Args) {
- a.Reset()
- argsPool.Put(a)
-}
-
-// AcquireFormFile returns an empty FormFile object from the pool.
-//
-// The returned FormFile may be returned to the pool with ReleaseFormFile
-// when no longer needed. This allows reducing GC load.
-func AcquireFormFile() *FormFile {
- v := formFilePool.Get()
- if v == nil {
- return &FormFile{}
- }
- ff, ok := v.(*FormFile)
- if !ok {
- panic(errors.New("failed to type-assert to *FormFile"))
- }
- return ff
-}
-
-// ReleaseFormFile returns the object acquired via AcquireFormFile to the pool.
-//
-// String not access the released FormFile object, otherwise data races may occur.
-func ReleaseFormFile(ff *FormFile) {
- ff.Fieldname = ""
- ff.Name = ""
- ff.Content = ff.Content[:0]
- ff.autoRelease = false
-
- formFilePool.Put(ff)
-}
-
-const (
- defaultUserAgent = "fiber"
-)
-
-type multipartWriter interface {
- Boundary() string
- SetBoundary(boundary string) error
- CreateFormFile(fieldname, filename string) (io.Writer, error)
- WriteField(fieldname, value string) error
- Close() error
-}
diff --git a/client/README.md b/client/README.md
new file mode 100644
index 0000000000..a992fcc6fb
--- /dev/null
+++ b/client/README.md
@@ -0,0 +1,35 @@
+
Fiber Client
+Easy-to-use HTTP client based on fasthttp (inspired by resty and axios)
+Features section describes in detail about Resty capabilities
+
+## Features
+
+> The characteristics have not yet been written.
+
+- GET, POST, PUT, DELETE, HEAD, PATCH, OPTIONS, etc.
+- Simple and chainable methods for settings and request
+- Request Body can be `string`, `[]byte`, `map`, `slice`
+ - Auto detects `Content-Type`
+ - Buffer processing for `files`
+ - Native `*fasthttp.Request` instance can be accessed during middleware and request execution via `Request.RawRequest`
+ - Request Body can be read multiple time via `Request.RawRequest.GetBody()`
+- Response object gives you more possibility
+ - Access as `[]byte` by `response.Body()` or access as `string` by `response.String()`
+- Automatic marshal and unmarshal for JSON and XML content type
+ - Default is JSON, if you supply struct/map without header Content-Type
+ - For auto-unmarshal, refer to -
+ - Success scenario Request.SetResult() and Response.Result().
+ - Error scenario Request.SetError() and Response.Error().
+ - Supports RFC7807 - application/problem+json & application/problem+xml
+ - Provide an option to override JSON Marshal/Unmarshal and XML Marshal/Unmarshal
+
+## Usage
+
+The following samples will assist you to become as comfortable as possible with `Fiber Client` library.
+
+```go
+// Import Fiber Client into your code and refer it as `client`.
+import "github.com/gofiber/fiber/client"
+```
+
+### Simple GET
diff --git a/client/client.go b/client/client.go
new file mode 100644
index 0000000000..22cbd19727
--- /dev/null
+++ b/client/client.go
@@ -0,0 +1,775 @@
+package client
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/json"
+ "encoding/xml"
+ "errors"
+ "fmt"
+ "io"
+ urlpkg "net/url"
+ "os"
+ "path/filepath"
+ "sync"
+ "time"
+
+ "github.com/gofiber/fiber/v3/log"
+
+ "github.com/gofiber/utils/v2"
+
+ "github.com/valyala/fasthttp"
+)
+
+var (
+ ErrInvalidProxyURL = errors.New("invalid proxy url scheme")
+ ErrFailedToAppendCert = errors.New("failed to append certificate")
+)
+
+// The Client is used to create a Fiber Client with
+// client-level settings that apply to all requests
+// raise from the client.
+//
+// Fiber Client also provides an option to override
+// or merge most of the client settings at the request.
+type Client struct {
+ mu sync.RWMutex
+
+ fasthttp *fasthttp.Client
+
+ baseURL string
+ userAgent string
+ referer string
+ header *Header
+ params *QueryParam
+ cookies *Cookie
+ path *PathParam
+
+ debug bool
+
+ timeout time.Duration
+
+ // user defined request hooks
+ userRequestHooks []RequestHook
+
+ // client package defined request hooks
+ builtinRequestHooks []RequestHook
+
+ // user defined response hooks
+ userResponseHooks []ResponseHook
+
+ // client package defined response hooks
+ builtinResponseHooks []ResponseHook
+
+ jsonMarshal utils.JSONMarshal
+ jsonUnmarshal utils.JSONUnmarshal
+ xmlMarshal utils.XMLMarshal
+ xmlUnmarshal utils.XMLUnmarshal
+
+ cookieJar *CookieJar
+
+ // proxy
+ proxyURL string
+
+ // retry
+ retryConfig *RetryConfig
+
+ // logger
+ logger log.CommonLogger
+}
+
+// R raise a request from the client.
+func (c *Client) R() *Request {
+ return AcquireRequest().SetClient(c)
+}
+
+// RequestHook Request returns user-defined request hooks.
+func (c *Client) RequestHook() []RequestHook {
+ return c.userRequestHooks
+}
+
+// AddRequestHook Add user-defined request hooks.
+func (c *Client) AddRequestHook(h ...RequestHook) *Client {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ c.userRequestHooks = append(c.userRequestHooks, h...)
+ return c
+}
+
+// ResponseHook return user-define response hooks.
+func (c *Client) ResponseHook() []ResponseHook {
+ return c.userResponseHooks
+}
+
+// AddResponseHook Add user-defined response hooks.
+func (c *Client) AddResponseHook(h ...ResponseHook) *Client {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ c.userResponseHooks = append(c.userResponseHooks, h...)
+ return c
+}
+
+// JSONMarshal returns json marshal function in Core.
+func (c *Client) JSONMarshal() utils.JSONMarshal {
+ return c.jsonMarshal
+}
+
+// SetJSONMarshal Set json encoder.
+func (c *Client) SetJSONMarshal(f utils.JSONMarshal) *Client {
+ c.jsonMarshal = f
+ return c
+}
+
+// JSONUnmarshal returns json unmarshal function in Core.
+func (c *Client) JSONUnmarshal() utils.JSONUnmarshal {
+ return c.jsonUnmarshal
+}
+
+// Set json decoder.
+func (c *Client) SetJSONUnmarshal(f utils.JSONUnmarshal) *Client {
+ c.jsonUnmarshal = f
+ return c
+}
+
+// XMLMarshal returns xml marshal function in Core.
+func (c *Client) XMLMarshal() utils.XMLMarshal {
+ return c.xmlMarshal
+}
+
+// SetXMLMarshal Set xml encoder.
+func (c *Client) SetXMLMarshal(f utils.XMLMarshal) *Client {
+ c.xmlMarshal = f
+ return c
+}
+
+// XMLUnmarshal returns xml unmarshal function in Core.
+func (c *Client) XMLUnmarshal() utils.XMLUnmarshal {
+ return c.xmlUnmarshal
+}
+
+// SetXMLUnmarshal Set xml decoder.
+func (c *Client) SetXMLUnmarshal(f utils.XMLUnmarshal) *Client {
+ c.xmlUnmarshal = f
+ return c
+}
+
+// TLSConfig returns tlsConfig in client.
+// If client don't have tlsConfig, this function will init it.
+func (c *Client) TLSConfig() *tls.Config {
+ if c.fasthttp.TLSConfig == nil {
+ c.fasthttp.TLSConfig = &tls.Config{
+ MinVersion: tls.VersionTLS12,
+ }
+ }
+
+ return c.fasthttp.TLSConfig
+}
+
+// SetTLSConfig sets tlsConfig in client.
+func (c *Client) SetTLSConfig(config *tls.Config) *Client {
+ c.fasthttp.TLSConfig = config
+ return c
+}
+
+// SetCertificates method sets client certificates into client.
+func (c *Client) SetCertificates(certs ...tls.Certificate) *Client {
+ config := c.TLSConfig()
+ config.Certificates = append(config.Certificates, certs...)
+ return c
+}
+
+// SetRootCertificate adds one or more root certificates into client.
+func (c *Client) SetRootCertificate(path string) *Client {
+ cleanPath := filepath.Clean(path)
+ file, err := os.Open(cleanPath)
+ if err != nil {
+ c.logger.Panicf("client: %v", err)
+ }
+ defer func() {
+ if err := file.Close(); err != nil {
+ c.logger.Panicf("client: failed to close file: %v", err)
+ }
+ }()
+
+ pem, err := io.ReadAll(file)
+ if err != nil {
+ c.logger.Panicf("client: %v", err)
+ }
+
+ config := c.TLSConfig()
+ if config.RootCAs == nil {
+ config.RootCAs = x509.NewCertPool()
+ }
+
+ if !config.RootCAs.AppendCertsFromPEM(pem) {
+ c.logger.Panicf("client: %v", ErrFailedToAppendCert)
+ }
+
+ return c
+}
+
+// SetRootCertificateFromString method adds one or more root certificates into client.
+func (c *Client) SetRootCertificateFromString(pem string) *Client {
+ config := c.TLSConfig()
+
+ if config.RootCAs == nil {
+ config.RootCAs = x509.NewCertPool()
+ }
+
+ if !config.RootCAs.AppendCertsFromPEM([]byte(pem)) {
+ c.logger.Panicf("client: %v", ErrFailedToAppendCert)
+ }
+
+ return c
+}
+
+// SetProxyURL sets proxy url in client. It will apply via core to hostclient.
+func (c *Client) SetProxyURL(proxyURL string) error {
+ pURL, err := urlpkg.Parse(proxyURL)
+ if err != nil {
+ return fmt.Errorf("client: %w", err)
+ }
+
+ if pURL.Scheme != "http" && pURL.Scheme != "https" {
+ return fmt.Errorf("client: %w", ErrInvalidProxyURL)
+ }
+
+ c.proxyURL = pURL.String()
+
+ return nil
+}
+
+// RetryConfig returns retry config in client.
+func (c *Client) RetryConfig() *RetryConfig {
+ return c.retryConfig
+}
+
+// SetRetryConfig sets retry config in client which is impl by addon/retry package.
+func (c *Client) SetRetryConfig(config *RetryConfig) *Client {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ c.retryConfig = config
+ return c
+}
+
+// BaseURL returns baseurl in Client instance.
+func (c *Client) BaseURL() string {
+ return c.baseURL
+}
+
+// SetBaseURL Set baseUrl which is prefix of real url.
+func (c *Client) SetBaseURL(url string) *Client {
+ c.baseURL = url
+ return c
+}
+
+// Header method returns header value via key,
+// this method will visit all field in the header,
+// then sort them.
+func (c *Client) Header(key string) []string {
+ return c.header.PeekMultiple(key)
+}
+
+// AddHeader method adds a single header field and its value in the client instance.
+// These headers will be applied to all requests raised from this client instance.
+// Also, it can be overridden at request level header options.
+func (c *Client) AddHeader(key, val string) *Client {
+ c.header.Add(key, val)
+ return c
+}
+
+// SetHeader method sets a single header field and its value in the client instance.
+// These headers will be applied to all requests raised from this client instance.
+// Also, it can be overridden at request level header options.
+func (c *Client) SetHeader(key, val string) *Client {
+ c.header.Set(key, val)
+ return c
+}
+
+// AddHeaders method adds multiple headers field and its values at one go in the client instance.
+// These headers will be applied to all requests raised from this client instance. Also it can be
+// overridden at request level headers options.
+func (c *Client) AddHeaders(h map[string][]string) *Client {
+ c.header.AddHeaders(h)
+ return c
+}
+
+// SetHeaders method sets multiple headers field and its values at one go in the client instance.
+// These headers will be applied to all requests raised from this client instance. Also it can be
+// overridden at request level headers options.
+func (c *Client) SetHeaders(h map[string]string) *Client {
+ c.header.SetHeaders(h)
+ return c
+}
+
+// Param method returns params value via key,
+// this method will visit all field in the query param.
+func (c *Client) Param(key string) []string {
+ res := []string{}
+ tmp := c.params.PeekMulti(key)
+ for _, v := range tmp {
+ res = append(res, utils.UnsafeString(v))
+ }
+
+ return res
+}
+
+// AddParam method adds a single query param field and its value in the client instance.
+// These params will be applied to all requests raised from this client instance.
+// Also, it can be overridden at request level param options.
+func (c *Client) AddParam(key, val string) *Client {
+ c.params.Add(key, val)
+ return c
+}
+
+// SetParam method sets a single query param field and its value in the client instance.
+// These params will be applied to all requests raised from this client instance.
+// Also, it can be overridden at request level param options.
+func (c *Client) SetParam(key, val string) *Client {
+ c.params.Set(key, val)
+ return c
+}
+
+// AddParams method adds multiple query params field and its values at one go in the client instance.
+// These params will be applied to all requests raised from this client instance. Also it can be
+// overridden at request level params options.
+func (c *Client) AddParams(m map[string][]string) *Client {
+ c.params.AddParams(m)
+ return c
+}
+
+// SetParams method sets multiple params field and its values at one go in the client instance.
+// These params will be applied to all requests raised from this client instance. Also it can be
+// overridden at request level params options.
+func (c *Client) SetParams(m map[string]string) *Client {
+ c.params.SetParams(m)
+ return c
+}
+
+// SetParamsWithStruct method sets multiple params field and its values at one go in the client instance.
+// These params will be applied to all requests raised from this client instance. Also it can be
+// overridden at request level params options.
+func (c *Client) SetParamsWithStruct(v any) *Client {
+ c.params.SetParamsWithStruct(v)
+ return c
+}
+
+// DelParams method deletes single or multiple params field and its values in client.
+func (c *Client) DelParams(key ...string) *Client {
+ for _, v := range key {
+ c.params.Del(v)
+ }
+ return c
+}
+
+// SetUserAgent method sets userAgent field and its value in the client instance.
+// This ua will be applied to all requests raised from this client instance.
+// Also it can be overridden at request level ua options.
+func (c *Client) SetUserAgent(ua string) *Client {
+ c.userAgent = ua
+ return c
+}
+
+// SetReferer method sets referer field and its value in the client instance.
+// This referer will be applied to all requests raised from this client instance.
+// Also it can be overridden at request level referer options.
+func (c *Client) SetReferer(r string) *Client {
+ c.referer = r
+ return c
+}
+
+// PathParam returns the path param be set in request instance.
+// if path param doesn't exist, return empty string.
+func (c *Client) PathParam(key string) string {
+ if val, ok := (*c.path)[key]; ok {
+ return val
+ }
+
+ return ""
+}
+
+// SetPathParam method sets a single path param field and its value in the client instance.
+// These path params will be applied to all requests raised from this client instance.
+// Also it can be overridden at request level path params options.
+func (c *Client) SetPathParam(key, val string) *Client {
+ c.path.SetParam(key, val)
+ return c
+}
+
+// SetPathParams method sets multiple path params field and its values at one go in the client instance.
+// These path params will be applied to all requests raised from this client instance. Also it can be
+// overridden at request level path params options.
+func (c *Client) SetPathParams(m map[string]string) *Client {
+ c.path.SetParams(m)
+ return c
+}
+
+// SetPathParamsWithStruct method sets multiple path params field and its values at one go in the client instance.
+// These path params will be applied to all requests raised from this client instance. Also it can be
+// overridden at request level path params options.
+func (c *Client) SetPathParamsWithStruct(v any) *Client {
+ c.path.SetParamsWithStruct(v)
+ return c
+}
+
+// DelPathParams method deletes single or multiple path params field and its values in client.
+func (c *Client) DelPathParams(key ...string) *Client {
+ c.path.DelParams(key...)
+ return c
+}
+
+// Cookie returns the cookie be set in request instance.
+// if cookie doesn't exist, return empty string.
+func (c *Client) Cookie(key string) string {
+ if val, ok := (*c.cookies)[key]; ok {
+ return val
+ }
+ return ""
+}
+
+// SetCookie method sets a single cookie field and its value in the client instance.
+// These cookies will be applied to all requests raised from this client instance.
+// Also it can be overridden at request level cookie options.
+func (c *Client) SetCookie(key, val string) *Client {
+ c.cookies.SetCookie(key, val)
+ return c
+}
+
+// SetCookies method sets multiple cookies field and its values at one go in the client instance.
+// These cookies will be applied to all requests raised from this client instance. Also it can be
+// overridden at request level cookie options.
+func (c *Client) SetCookies(m map[string]string) *Client {
+ c.cookies.SetCookies(m)
+ return c
+}
+
+// SetCookiesWithStruct method sets multiple cookies field and its values at one go in the client instance.
+// These cookies will be applied to all requests raised from this client instance. Also it can be
+// overridden at request level cookies options.
+func (c *Client) SetCookiesWithStruct(v any) *Client {
+ c.cookies.SetCookiesWithStruct(v)
+ return c
+}
+
+// DelCookies method deletes single or multiple cookies field and its values in client.
+func (c *Client) DelCookies(key ...string) *Client {
+ c.cookies.DelCookies(key...)
+ return c
+}
+
+// SetTimeout method sets timeout val in client instance.
+// This value will be applied to all requests raised from this client instance.
+// Also, it can be overridden at request level timeout options.
+func (c *Client) SetTimeout(t time.Duration) *Client {
+ c.timeout = t
+ return c
+}
+
+// Debug enable log debug level output.
+func (c *Client) Debug() *Client {
+ c.debug = true
+ return c
+}
+
+// DisableDebug disenable log debug level output.
+func (c *Client) DisableDebug() *Client {
+ c.debug = false
+ return c
+}
+
+// SetCookieJar sets cookie jar in client instance.
+func (c *Client) SetCookieJar(cookieJar *CookieJar) *Client {
+ c.cookieJar = cookieJar
+ return c
+}
+
+// Get provide an API like axios which send get request.
+func (c *Client) Get(url string, cfg ...Config) (*Response, error) {
+ req := AcquireRequest().SetClient(c)
+ setConfigToRequest(req, cfg...)
+
+ return req.Get(url)
+}
+
+// Post provide an API like axios which send post request.
+func (c *Client) Post(url string, cfg ...Config) (*Response, error) {
+ req := AcquireRequest().SetClient(c)
+ setConfigToRequest(req, cfg...)
+
+ return req.Post(url)
+}
+
+// Head provide a API like axios which send head request.
+func (c *Client) Head(url string, cfg ...Config) (*Response, error) {
+ req := AcquireRequest().SetClient(c)
+ setConfigToRequest(req, cfg...)
+
+ return req.Head(url)
+}
+
+// Put provide an API like axios which send put request.
+func (c *Client) Put(url string, cfg ...Config) (*Response, error) {
+ req := AcquireRequest().SetClient(c)
+ setConfigToRequest(req, cfg...)
+
+ return req.Put(url)
+}
+
+// Delete provide an API like axios which send delete request.
+func (c *Client) Delete(url string, cfg ...Config) (*Response, error) {
+ req := AcquireRequest().SetClient(c)
+ setConfigToRequest(req, cfg...)
+
+ return req.Delete(url)
+}
+
+// Options provide an API like axios which send options request.
+func (c *Client) Options(url string, cfg ...Config) (*Response, error) {
+ req := AcquireRequest().SetClient(c)
+ setConfigToRequest(req, cfg...)
+
+ return req.Options(url)
+}
+
+// Patch provide an API like axios which send patch request.
+func (c *Client) Patch(url string, cfg ...Config) (*Response, error) {
+ req := AcquireRequest().SetClient(c)
+ setConfigToRequest(req, cfg...)
+
+ return req.Patch(url)
+}
+
+// Custom provide an API like axios which send custom request.
+func (c *Client) Custom(url, method string, cfg ...Config) (*Response, error) {
+ req := AcquireRequest().SetClient(c)
+ setConfigToRequest(req, cfg...)
+
+ return req.Custom(url, method)
+}
+
+// SetDial sets dial function in client.
+func (c *Client) SetDial(dial fasthttp.DialFunc) *Client {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ c.fasthttp.Dial = dial
+ return c
+}
+
+// SetLogger sets logger instance in client.
+func (c *Client) SetLogger(logger log.CommonLogger) *Client {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ c.logger = logger
+ return c
+}
+
+// Logger returns logger instance of client.
+func (c *Client) Logger() log.CommonLogger {
+ return c.logger
+}
+
+// Reset clear Client object
+func (c *Client) Reset() {
+ c.fasthttp = &fasthttp.Client{}
+ c.baseURL = ""
+ c.timeout = 0
+ c.userAgent = ""
+ c.referer = ""
+ c.proxyURL = ""
+ c.retryConfig = nil
+ c.debug = false
+
+ if c.cookieJar != nil {
+ c.cookieJar.Release()
+ c.cookieJar = nil
+ }
+
+ c.path.Reset()
+ c.cookies.Reset()
+ c.header.Reset()
+ c.params.Reset()
+}
+
+// Config for easy to set the request parameters, it should be
+// noted that when setting the request body will use JSON as
+// the default serialization mechanism, while the priority of
+// Body is higher than FormData, and the priority of FormData
+// is higher than File.
+type Config struct {
+ Ctx context.Context //nolint:containedctx // It's needed to be stored in the config.
+
+ UserAgent string
+ Referer string
+ Header map[string]string
+ Param map[string]string
+ Cookie map[string]string
+ PathParam map[string]string
+
+ Timeout time.Duration
+ MaxRedirects int
+
+ Body any
+ FormData map[string]string
+ File []*File
+}
+
+// setConfigToRequest Set the parameters passed via Config to Request.
+func setConfigToRequest(req *Request, config ...Config) {
+ if len(config) == 0 {
+ return
+ }
+ cfg := config[0]
+
+ if cfg.Ctx != nil {
+ req.SetContext(cfg.Ctx)
+ }
+
+ if cfg.UserAgent != "" {
+ req.SetUserAgent(cfg.UserAgent)
+ }
+
+ if cfg.Referer != "" {
+ req.SetReferer(cfg.Referer)
+ }
+
+ if cfg.Header != nil {
+ req.SetHeaders(cfg.Header)
+ }
+
+ if cfg.Param != nil {
+ req.SetParams(cfg.Param)
+ }
+
+ if cfg.Cookie != nil {
+ req.SetCookies(cfg.Cookie)
+ }
+
+ if cfg.PathParam != nil {
+ req.SetPathParams(cfg.PathParam)
+ }
+
+ if cfg.Timeout != 0 {
+ req.SetTimeout(cfg.Timeout)
+ }
+
+ if cfg.MaxRedirects != 0 {
+ req.SetMaxRedirects(cfg.MaxRedirects)
+ }
+
+ if cfg.Body != nil {
+ req.SetJSON(cfg.Body)
+ return
+ }
+
+ if cfg.FormData != nil {
+ req.SetFormDatas(cfg.FormData)
+ return
+ }
+
+ if cfg.File != nil && len(cfg.File) != 0 {
+ req.AddFiles(cfg.File...)
+ return
+ }
+}
+
+var (
+ defaultClient *Client
+ replaceMu = sync.Mutex{}
+ defaultUserAgent = "fiber"
+)
+
+// init acquire a default client.
+func init() {
+ defaultClient = NewClient()
+}
+
+// NewClient creates and returns a new Client object.
+func NewClient() *Client {
+ // FOllOW-UP performance optimization
+ // trie to use a pool to reduce the cost of memory allocation
+ // for the fiber client and the fasthttp client
+ // if possible also for other structs -> request header, cookie, query param, path param...
+ return &Client{
+ fasthttp: &fasthttp.Client{},
+ header: &Header{
+ RequestHeader: &fasthttp.RequestHeader{},
+ },
+ params: &QueryParam{
+ Args: fasthttp.AcquireArgs(),
+ },
+ cookies: &Cookie{},
+ path: &PathParam{},
+
+ userRequestHooks: []RequestHook{},
+ builtinRequestHooks: []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody},
+ userResponseHooks: []ResponseHook{},
+ builtinResponseHooks: []ResponseHook{parserResponseCookie, logger},
+ jsonMarshal: json.Marshal,
+ jsonUnmarshal: json.Unmarshal,
+ xmlMarshal: xml.Marshal,
+ xmlUnmarshal: xml.Unmarshal,
+ logger: log.DefaultLogger(),
+ }
+}
+
+// C get default client.
+func C() *Client {
+ return defaultClient
+}
+
+// Replace the defaultClient, the returned function can undo.
+func Replace(c *Client) func() {
+ replaceMu.Lock()
+ defer replaceMu.Unlock()
+
+ oldClient := defaultClient
+ defaultClient = c
+
+ return func() {
+ replaceMu.Lock()
+ defer replaceMu.Unlock()
+
+ defaultClient = oldClient
+ }
+}
+
+// Get send a get request use defaultClient, a convenient method.
+func Get(url string, cfg ...Config) (*Response, error) {
+ return C().Get(url, cfg...)
+}
+
+// Post send a post request use defaultClient, a convenient method.
+func Post(url string, cfg ...Config) (*Response, error) {
+ return C().Post(url, cfg...)
+}
+
+// Head send a head request use defaultClient, a convenient method.
+func Head(url string, cfg ...Config) (*Response, error) {
+ return C().Head(url, cfg...)
+}
+
+// Put send a put request use defaultClient, a convenient method.
+func Put(url string, cfg ...Config) (*Response, error) {
+ return C().Put(url, cfg...)
+}
+
+// Delete send a delete request use defaultClient, a convenient method.
+func Delete(url string, cfg ...Config) (*Response, error) {
+ return C().Delete(url, cfg...)
+}
+
+// Options send a options request use defaultClient, a convenient method.
+func Options(url string, cfg ...Config) (*Response, error) {
+ return C().Options(url, cfg...)
+}
+
+// Patch send a patch request use defaultClient, a convenient method.
+func Patch(url string, cfg ...Config) (*Response, error) {
+ return C().Patch(url, cfg...)
+}
diff --git a/client/client_test.go b/client/client_test.go
new file mode 100644
index 0000000000..4fd2e484a6
--- /dev/null
+++ b/client/client_test.go
@@ -0,0 +1,1642 @@
+package client
+
+import (
+ "context"
+ "crypto/tls"
+ "errors"
+ "io"
+ "net"
+ "os"
+ "reflect"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/gofiber/fiber/v3"
+ "github.com/gofiber/fiber/v3/addon/retry"
+ "github.com/gofiber/fiber/v3/internal/tlstest"
+ "github.com/gofiber/utils/v2"
+ "github.com/stretchr/testify/require"
+ "github.com/valyala/bytebufferpool"
+)
+
+func startTestServerWithPort(t *testing.T, beforeStarting func(app *fiber.App)) (*fiber.App, string) {
+ t.Helper()
+
+ app := fiber.New()
+
+ if beforeStarting != nil {
+ beforeStarting(app)
+ }
+
+ addrChan := make(chan string)
+ errChan := make(chan error, 1)
+ go func() {
+ err := app.Listen(":0", fiber.ListenConfig{
+ DisableStartupMessage: true,
+ ListenerAddrFunc: func(addr net.Addr) {
+ addrChan <- addr.String()
+ },
+ })
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ select {
+ case addr := <-addrChan:
+ return app, addr
+ case err := <-errChan:
+ t.Fatalf("Failed to start test server: %v", err)
+ }
+
+ return nil, ""
+}
+
+func Test_Client_Add_Hook(t *testing.T) {
+ t.Parallel()
+
+ t.Run("add request hooks", func(t *testing.T) {
+ t.Parallel()
+
+ buf := bytebufferpool.Get()
+ defer bytebufferpool.Put(buf)
+
+ client := NewClient().AddRequestHook(func(_ *Client, _ *Request) error {
+ buf.WriteString("hook1")
+ return nil
+ })
+
+ require.Len(t, client.RequestHook(), 1)
+
+ client.AddRequestHook(func(_ *Client, _ *Request) error {
+ buf.WriteString("hook2")
+ return nil
+ }, func(_ *Client, _ *Request) error {
+ buf.WriteString("hook3")
+ return nil
+ })
+
+ require.Len(t, client.RequestHook(), 3)
+ })
+
+ t.Run("add response hooks", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().AddResponseHook(func(_ *Client, _ *Response, _ *Request) error {
+ return nil
+ })
+
+ require.Len(t, client.ResponseHook(), 1)
+
+ client.AddResponseHook(func(_ *Client, _ *Response, _ *Request) error {
+ return nil
+ }, func(_ *Client, _ *Response, _ *Request) error {
+ return nil
+ })
+
+ require.Len(t, client.ResponseHook(), 3)
+ })
+}
+
+func Test_Client_Add_Hook_CheckOrder(t *testing.T) {
+ t.Parallel()
+
+ buf := bytebufferpool.Get()
+ defer bytebufferpool.Put(buf)
+
+ client := NewClient().
+ AddRequestHook(func(_ *Client, _ *Request) error {
+ buf.WriteString("hook1")
+ return nil
+ }).
+ AddRequestHook(func(_ *Client, _ *Request) error {
+ buf.WriteString("hook2")
+ return nil
+ }).
+ AddRequestHook(func(_ *Client, _ *Request) error {
+ buf.WriteString("hook3")
+ return nil
+ })
+
+ for _, hook := range client.RequestHook() {
+ require.NoError(t, hook(client, &Request{}))
+ }
+
+ require.Equal(t, "hook1hook2hook3", buf.String())
+}
+
+func Test_Client_Marshal(t *testing.T) {
+ t.Parallel()
+
+ t.Run("set json marshal", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().
+ SetJSONMarshal(func(_ any) ([]byte, error) {
+ return []byte("hello"), nil
+ })
+ val, err := client.JSONMarshal()(nil)
+
+ require.NoError(t, err)
+ require.Equal(t, []byte("hello"), val)
+ })
+
+ t.Run("set json marshal error", func(t *testing.T) {
+ t.Parallel()
+
+ emptyErr := errors.New("empty json")
+ client := NewClient().
+ SetJSONMarshal(func(_ any) ([]byte, error) {
+ return nil, emptyErr
+ })
+
+ val, err := client.JSONMarshal()(nil)
+ require.Nil(t, val)
+ require.ErrorIs(t, err, emptyErr)
+ })
+
+ t.Run("set json unmarshal", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().
+ SetJSONUnmarshal(func(_ []byte, _ any) error {
+ return errors.New("empty json")
+ })
+
+ err := client.JSONUnmarshal()(nil, nil)
+ require.Equal(t, errors.New("empty json"), err)
+ })
+
+ t.Run("set json unmarshal error", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().
+ SetJSONUnmarshal(func(_ []byte, _ any) error {
+ return errors.New("empty json")
+ })
+
+ err := client.JSONUnmarshal()(nil, nil)
+ require.Equal(t, errors.New("empty json"), err)
+ })
+
+ t.Run("set xml marshal", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().
+ SetXMLMarshal(func(_ any) ([]byte, error) {
+ return []byte("hello"), nil
+ })
+ val, err := client.XMLMarshal()(nil)
+
+ require.NoError(t, err)
+ require.Equal(t, []byte("hello"), val)
+ })
+
+ t.Run("set xml marshal error", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().
+ SetXMLMarshal(func(_ any) ([]byte, error) {
+ return nil, errors.New("empty xml")
+ })
+
+ val, err := client.XMLMarshal()(nil)
+ require.Nil(t, val)
+ require.Equal(t, errors.New("empty xml"), err)
+ })
+
+ t.Run("set xml unmarshal", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().
+ SetXMLUnmarshal(func(_ []byte, _ any) error {
+ return errors.New("empty xml")
+ })
+
+ err := client.XMLUnmarshal()(nil, nil)
+ require.Equal(t, errors.New("empty xml"), err)
+ })
+
+ t.Run("set xml unmarshal error", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().
+ SetXMLUnmarshal(func(_ []byte, _ any) error {
+ return errors.New("empty xml")
+ })
+
+ err := client.XMLUnmarshal()(nil, nil)
+ require.Equal(t, errors.New("empty xml"), err)
+ })
+}
+
+func Test_Client_SetBaseURL(t *testing.T) {
+ t.Parallel()
+
+ client := NewClient().SetBaseURL("http://example.com")
+
+ require.Equal(t, "http://example.com", client.BaseURL())
+}
+
+func Test_Client_Invalid_URL(t *testing.T) {
+ t.Parallel()
+
+ app, dial, start := createHelperServer(t)
+
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.SendString(c.Hostname())
+ })
+
+ go start()
+
+ _, err := NewClient().SetDial(dial).
+ R().
+ Get("http//example")
+
+ require.ErrorIs(t, err, ErrURLFormat)
+}
+
+func Test_Client_Unsupported_Protocol(t *testing.T) {
+ t.Parallel()
+
+ _, err := NewClient().
+ R().
+ Get("ftp://example.com")
+
+ require.ErrorIs(t, err, ErrURLFormat)
+}
+
+func Test_Client_ConcurrencyRequests(t *testing.T) {
+ t.Parallel()
+
+ app, dial, start := createHelperServer(t)
+ app.All("/", func(c fiber.Ctx) error {
+ return c.SendString(c.Hostname() + " " + c.Method())
+ })
+ go start()
+
+ client := NewClient().SetDial(dial)
+
+ wg := sync.WaitGroup{}
+ for i := 0; i < 5; i++ {
+ for _, method := range []string{"GET", "POST", "PUT", "DELETE", "PATCH"} {
+ wg.Add(1)
+ go func(m string) {
+ defer wg.Done()
+ resp, err := client.Custom("http://example.com", m)
+ require.NoError(t, err)
+ require.Equal(t, "example.com "+m, utils.UnsafeString(resp.RawResponse.Body()))
+ }(method)
+ }
+ }
+
+ wg.Wait()
+}
+
+func Test_Get(t *testing.T) {
+ t.Parallel()
+
+ setupApp := func() (*fiber.App, string) {
+ app, addr := startTestServerWithPort(t, func(app *fiber.App) {
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.SendString(c.Hostname())
+ })
+ })
+
+ return app, addr
+ }
+
+ t.Run("global get function", func(t *testing.T) {
+ t.Parallel()
+
+ app, addr := setupApp()
+ defer func() {
+ require.NoError(t, app.Shutdown())
+ }()
+
+ resp, err := Get("http://" + addr)
+ require.NoError(t, err)
+ require.Equal(t, "0.0.0.0", utils.UnsafeString(resp.RawResponse.Body()))
+ })
+
+ t.Run("client get", func(t *testing.T) {
+ t.Parallel()
+
+ app, addr := setupApp()
+ defer func() {
+ require.NoError(t, app.Shutdown())
+ }()
+
+ resp, err := NewClient().Get("http://" + addr)
+ require.NoError(t, err)
+ require.Equal(t, "0.0.0.0", utils.UnsafeString(resp.RawResponse.Body()))
+ })
+}
+
+func Test_Head(t *testing.T) {
+ t.Parallel()
+
+ setupApp := func() (*fiber.App, string) {
+ app, addr := startTestServerWithPort(t, func(app *fiber.App) {
+ app.Head("/", func(c fiber.Ctx) error {
+ return c.SendString(c.Hostname())
+ })
+ })
+
+ return app, addr
+ }
+
+ t.Run("global head function", func(t *testing.T) {
+ t.Parallel()
+
+ app, addr := setupApp()
+ defer func() {
+ require.NoError(t, app.Shutdown())
+ }()
+
+ resp, err := Head("http://" + addr)
+ require.NoError(t, err)
+ require.Equal(t, "7", resp.Header(fiber.HeaderContentLength))
+ require.Equal(t, "", utils.UnsafeString(resp.RawResponse.Body()))
+ })
+
+ t.Run("client head", func(t *testing.T) {
+ t.Parallel()
+
+ app, addr := setupApp()
+ defer func() {
+ require.NoError(t, app.Shutdown())
+ }()
+
+ resp, err := NewClient().Head("http://" + addr)
+ require.NoError(t, err)
+ require.Equal(t, "7", resp.Header(fiber.HeaderContentLength))
+ require.Equal(t, "", utils.UnsafeString(resp.RawResponse.Body()))
+ })
+}
+
+func Test_Post(t *testing.T) {
+ t.Parallel()
+
+ setupApp := func() (*fiber.App, string) {
+ app, addr := startTestServerWithPort(t, func(app *fiber.App) {
+ app.Post("/", func(c fiber.Ctx) error {
+ return c.Status(fiber.StatusCreated).
+ SendString(c.FormValue("foo"))
+ })
+ })
+
+ return app, addr
+ }
+
+ t.Run("global post function", func(t *testing.T) {
+ t.Parallel()
+
+ app, addr := setupApp()
+ defer func() {
+ require.NoError(t, app.Shutdown())
+ }()
+
+ for i := 0; i < 5; i++ {
+ resp, err := Post("http://"+addr, Config{
+ FormData: map[string]string{
+ "foo": "bar",
+ },
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusCreated, resp.StatusCode())
+ require.Equal(t, "bar", resp.String())
+ }
+ })
+
+ t.Run("client post", func(t *testing.T) {
+ t.Parallel()
+
+ app, addr := setupApp()
+ defer func() {
+ require.NoError(t, app.Shutdown())
+ }()
+
+ for i := 0; i < 5; i++ {
+ resp, err := NewClient().Post("http://"+addr, Config{
+ FormData: map[string]string{
+ "foo": "bar",
+ },
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusCreated, resp.StatusCode())
+ require.Equal(t, "bar", resp.String())
+ }
+ })
+}
+
+func Test_Put(t *testing.T) {
+ t.Parallel()
+
+ setupApp := func() (*fiber.App, string) {
+ app, addr := startTestServerWithPort(t, func(app *fiber.App) {
+ app.Put("/", func(c fiber.Ctx) error {
+ return c.SendString(c.FormValue("foo"))
+ })
+ })
+
+ return app, addr
+ }
+
+ t.Run("global put function", func(t *testing.T) {
+ t.Parallel()
+
+ app, addr := setupApp()
+ defer func() {
+ require.NoError(t, app.Shutdown())
+ }()
+
+ for i := 0; i < 5; i++ {
+ resp, err := Put("http://"+addr, Config{
+ FormData: map[string]string{
+ "foo": "bar",
+ },
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "bar", resp.String())
+ }
+ })
+
+ t.Run("client put", func(t *testing.T) {
+ t.Parallel()
+
+ app, addr := setupApp()
+ defer func() {
+ require.NoError(t, app.Shutdown())
+ }()
+
+ for i := 0; i < 5; i++ {
+ resp, err := NewClient().Put("http://"+addr, Config{
+ FormData: map[string]string{
+ "foo": "bar",
+ },
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "bar", resp.String())
+ }
+ })
+}
+
+func Test_Delete(t *testing.T) {
+ t.Parallel()
+
+ setupApp := func() (*fiber.App, string) {
+ app, addr := startTestServerWithPort(t, func(app *fiber.App) {
+ app.Delete("/", func(c fiber.Ctx) error {
+ return c.Status(fiber.StatusNoContent).
+ SendString("deleted")
+ })
+ })
+
+ return app, addr
+ }
+
+ t.Run("global delete function", func(t *testing.T) {
+ t.Parallel()
+
+ app, addr := setupApp()
+ defer func() {
+ require.NoError(t, app.Shutdown())
+ }()
+
+ time.Sleep(1 * time.Second)
+
+ for i := 0; i < 5; i++ {
+ resp, err := Delete("http://"+addr, Config{
+ FormData: map[string]string{
+ "foo": "bar",
+ },
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusNoContent, resp.StatusCode())
+ require.Equal(t, "", resp.String())
+ }
+ })
+
+ t.Run("client delete", func(t *testing.T) {
+ t.Parallel()
+
+ app, addr := setupApp()
+ defer func() {
+ require.NoError(t, app.Shutdown())
+ }()
+
+ for i := 0; i < 5; i++ {
+ resp, err := NewClient().Delete("http://"+addr, Config{
+ FormData: map[string]string{
+ "foo": "bar",
+ },
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusNoContent, resp.StatusCode())
+ require.Equal(t, "", resp.String())
+ }
+ })
+}
+
+func Test_Options(t *testing.T) {
+ t.Parallel()
+
+ setupApp := func() (*fiber.App, string) {
+ app, addr := startTestServerWithPort(t, func(app *fiber.App) {
+ app.Options("/", func(c fiber.Ctx) error {
+ c.Set(fiber.HeaderAllow, "GET, POST, PUT, DELETE, PATCH")
+ return c.Status(fiber.StatusNoContent).SendString("")
+ })
+ })
+
+ return app, addr
+ }
+
+ t.Run("global options function", func(t *testing.T) {
+ t.Parallel()
+
+ app, addr := setupApp()
+ defer func() {
+ require.NoError(t, app.Shutdown())
+ }()
+
+ for i := 0; i < 5; i++ {
+ resp, err := Options("http://" + addr)
+
+ require.NoError(t, err)
+ require.Equal(t, "GET, POST, PUT, DELETE, PATCH", resp.Header(fiber.HeaderAllow))
+ require.Equal(t, fiber.StatusNoContent, resp.StatusCode())
+ require.Equal(t, "", resp.String())
+ }
+ })
+
+ t.Run("client options", func(t *testing.T) {
+ t.Parallel()
+
+ app, addr := setupApp()
+ defer func() {
+ require.NoError(t, app.Shutdown())
+ }()
+
+ for i := 0; i < 5; i++ {
+ resp, err := NewClient().Options("http://" + addr)
+
+ require.NoError(t, err)
+ require.Equal(t, "GET, POST, PUT, DELETE, PATCH", resp.Header(fiber.HeaderAllow))
+ require.Equal(t, fiber.StatusNoContent, resp.StatusCode())
+ require.Equal(t, "", resp.String())
+ }
+ })
+}
+
+func Test_Patch(t *testing.T) {
+ t.Parallel()
+
+ setupApp := func() (*fiber.App, string) {
+ app, addr := startTestServerWithPort(t, func(app *fiber.App) {
+ app.Patch("/", func(c fiber.Ctx) error {
+ return c.SendString(c.FormValue("foo"))
+ })
+ })
+
+ return app, addr
+ }
+
+ t.Run("global patch function", func(t *testing.T) {
+ t.Parallel()
+
+ app, addr := setupApp()
+ defer func() {
+ require.NoError(t, app.Shutdown())
+ }()
+
+ time.Sleep(1 * time.Second)
+
+ for i := 0; i < 5; i++ {
+ resp, err := Patch("http://"+addr, Config{
+ FormData: map[string]string{
+ "foo": "bar",
+ },
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "bar", resp.String())
+ }
+ })
+
+ t.Run("client patch", func(t *testing.T) {
+ t.Parallel()
+
+ app, addr := setupApp()
+ defer func() {
+ require.NoError(t, app.Shutdown())
+ }()
+
+ for i := 0; i < 5; i++ {
+ resp, err := NewClient().Patch("http://"+addr, Config{
+ FormData: map[string]string{
+ "foo": "bar",
+ },
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "bar", resp.String())
+ }
+ })
+}
+
+func Test_Client_UserAgent(t *testing.T) {
+ t.Parallel()
+
+ setupApp := func() (*fiber.App, string) {
+ app, addr := startTestServerWithPort(t, func(app *fiber.App) {
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.Send(c.Request().Header.UserAgent())
+ })
+ })
+
+ return app, addr
+ }
+
+ t.Run("default", func(t *testing.T) {
+ t.Parallel()
+
+ app, addr := setupApp()
+ defer func() {
+ require.NoError(t, app.Shutdown())
+ }()
+
+ for i := 0; i < 5; i++ {
+ resp, err := Get("http://" + addr)
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, defaultUserAgent, resp.String())
+ }
+ })
+
+ t.Run("custom", func(t *testing.T) {
+ t.Parallel()
+
+ app, addr := setupApp()
+ defer func() {
+ require.NoError(t, app.Shutdown())
+ }()
+
+ for i := 0; i < 5; i++ {
+ c := NewClient().
+ SetUserAgent("ua")
+
+ resp, err := c.Get("http://" + addr)
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "ua", resp.String())
+ }
+ })
+}
+
+func Test_Client_Header(t *testing.T) {
+ t.Parallel()
+
+ t.Run("add header", func(t *testing.T) {
+ t.Parallel()
+ req := NewClient()
+ req.AddHeader("foo", "bar").AddHeader("foo", "fiber")
+
+ res := req.Header("foo")
+ require.Len(t, res, 2)
+ require.Equal(t, "bar", res[0])
+ require.Equal(t, "fiber", res[1])
+ })
+
+ t.Run("set header", func(t *testing.T) {
+ t.Parallel()
+ req := NewClient()
+ req.AddHeader("foo", "bar").SetHeader("foo", "fiber")
+
+ res := req.Header("foo")
+ require.Len(t, res, 1)
+ require.Equal(t, "fiber", res[0])
+ })
+
+ t.Run("add headers", func(t *testing.T) {
+ t.Parallel()
+ req := NewClient()
+ req.SetHeader("foo", "bar").
+ AddHeaders(map[string][]string{
+ "foo": {"fiber", "buaa"},
+ "bar": {"foo"},
+ })
+
+ res := req.Header("foo")
+ require.Len(t, res, 3)
+ require.Equal(t, "bar", res[0])
+ require.Equal(t, "fiber", res[1])
+ require.Equal(t, "buaa", res[2])
+
+ res = req.Header("bar")
+ require.Len(t, res, 1)
+ require.Equal(t, "foo", res[0])
+ })
+
+ t.Run("set headers", func(t *testing.T) {
+ t.Parallel()
+ req := NewClient()
+ req.SetHeader("foo", "bar").
+ SetHeaders(map[string]string{
+ "foo": "fiber",
+ "bar": "foo",
+ })
+
+ res := req.Header("foo")
+ require.Len(t, res, 1)
+ require.Equal(t, "fiber", res[0])
+
+ res = req.Header("bar")
+ require.Len(t, res, 1)
+ require.Equal(t, "foo", res[0])
+ })
+
+ t.Run("set header case insensitive", func(t *testing.T) {
+ t.Parallel()
+ req := NewClient()
+ req.SetHeader("foo", "bar").
+ AddHeader("FOO", "fiber")
+
+ res := req.Header("foo")
+ require.Len(t, res, 2)
+ require.Equal(t, "bar", res[0])
+ require.Equal(t, "fiber", res[1])
+ })
+}
+
+func Test_Client_Header_With_Server(t *testing.T) {
+ handler := func(c fiber.Ctx) error {
+ c.Request().Header.VisitAll(func(key, value []byte) {
+ if k := string(key); k == "K1" || k == "K2" {
+ _, _ = c.Write(key) //nolint:errcheck // It is fine to ignore the error here
+ _, _ = c.Write(value) //nolint:errcheck // It is fine to ignore the error here
+ }
+ })
+ return nil
+ }
+
+ wrapAgent := func(c *Client) {
+ c.SetHeader("k1", "v1").
+ AddHeader("k1", "v11").
+ AddHeaders(map[string][]string{
+ "k1": {"v22", "v33"},
+ }).
+ SetHeaders(map[string]string{
+ "k2": "v2",
+ }).
+ AddHeader("k2", "v22")
+ }
+
+ testClient(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22")
+}
+
+func Test_Client_Cookie(t *testing.T) {
+ t.Parallel()
+
+ t.Run("set cookie", func(t *testing.T) {
+ t.Parallel()
+ req := NewClient().
+ SetCookie("foo", "bar")
+ require.Equal(t, "bar", req.Cookie("foo"))
+
+ req.SetCookie("foo", "bar1")
+ require.Equal(t, "bar1", req.Cookie("foo"))
+ })
+
+ t.Run("set cookies", func(t *testing.T) {
+ t.Parallel()
+ req := NewClient().
+ SetCookies(map[string]string{
+ "foo": "bar",
+ "bar": "foo",
+ })
+ require.Equal(t, "bar", req.Cookie("foo"))
+ require.Equal(t, "foo", req.Cookie("bar"))
+
+ req.SetCookies(map[string]string{
+ "foo": "bar1",
+ })
+ require.Equal(t, "bar1", req.Cookie("foo"))
+ require.Equal(t, "foo", req.Cookie("bar"))
+ })
+
+ t.Run("set cookies with struct", func(t *testing.T) {
+ t.Parallel()
+ type args struct {
+ CookieInt int `cookie:"int"`
+ CookieString string `cookie:"string"`
+ }
+
+ req := NewClient().SetCookiesWithStruct(&args{
+ CookieInt: 5,
+ CookieString: "foo",
+ })
+
+ require.Equal(t, "5", req.Cookie("int"))
+ require.Equal(t, "foo", req.Cookie("string"))
+ })
+
+ t.Run("del cookies", func(t *testing.T) {
+ t.Parallel()
+ req := NewClient().
+ SetCookies(map[string]string{
+ "foo": "bar",
+ "bar": "foo",
+ })
+ require.Equal(t, "bar", req.Cookie("foo"))
+ require.Equal(t, "foo", req.Cookie("bar"))
+
+ req.DelCookies("foo")
+ require.Equal(t, "", req.Cookie("foo"))
+ require.Equal(t, "foo", req.Cookie("bar"))
+ })
+}
+
+func Test_Client_Cookie_With_Server(t *testing.T) {
+ t.Parallel()
+
+ handler := func(c fiber.Ctx) error {
+ return c.SendString(
+ c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4"))
+ }
+
+ wrapAgent := func(c *Client) {
+ c.SetCookie("k1", "v1").
+ SetCookies(map[string]string{
+ "k2": "v2",
+ "k3": "v3",
+ "k4": "v4",
+ }).DelCookies("k4")
+ }
+
+ testClient(t, handler, wrapAgent, "v1v2v3")
+}
+
+func Test_Client_CookieJar(t *testing.T) {
+ handler := func(c fiber.Ctx) error {
+ return c.SendString(
+ c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3"))
+ }
+
+ jar := AcquireCookieJar()
+ defer ReleaseCookieJar(jar)
+
+ jar.SetKeyValue("example.com", "k1", "v1")
+ jar.SetKeyValue("example.com", "k2", "v2")
+ jar.SetKeyValue("example", "k3", "v3")
+
+ wrapAgent := func(c *Client) {
+ c.SetCookieJar(jar)
+ }
+ testClient(t, handler, wrapAgent, "v1v2")
+}
+
+func Test_Client_CookieJar_Response(t *testing.T) {
+ t.Parallel()
+
+ t.Run("without expiration", func(t *testing.T) {
+ t.Parallel()
+ handler := func(c fiber.Ctx) error {
+ c.Cookie(&fiber.Cookie{
+ Name: "k4",
+ Value: "v4",
+ })
+ return c.SendString(
+ c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3"))
+ }
+
+ jar := AcquireCookieJar()
+ defer ReleaseCookieJar(jar)
+
+ jar.SetKeyValue("example.com", "k1", "v1")
+ jar.SetKeyValue("example.com", "k2", "v2")
+ jar.SetKeyValue("example", "k3", "v3")
+
+ wrapAgent := func(c *Client) {
+ c.SetCookieJar(jar)
+ }
+ testClient(t, handler, wrapAgent, "v1v2")
+
+ require.Len(t, jar.getCookiesByHost("example.com"), 3)
+ })
+
+ t.Run("with expiration", func(t *testing.T) {
+ t.Parallel()
+ handler := func(c fiber.Ctx) error {
+ c.Cookie(&fiber.Cookie{
+ Name: "k4",
+ Value: "v4",
+ Expires: time.Now().Add(1 * time.Nanosecond),
+ })
+ return c.SendString(
+ c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3"))
+ }
+
+ jar := AcquireCookieJar()
+ defer ReleaseCookieJar(jar)
+
+ jar.SetKeyValue("example.com", "k1", "v1")
+ jar.SetKeyValue("example.com", "k2", "v2")
+ jar.SetKeyValue("example", "k3", "v3")
+
+ wrapAgent := func(c *Client) {
+ c.SetCookieJar(jar)
+ }
+ testClient(t, handler, wrapAgent, "v1v2")
+
+ require.Len(t, jar.getCookiesByHost("example.com"), 2)
+ })
+
+ t.Run("override cookie value", func(t *testing.T) {
+ t.Parallel()
+ handler := func(c fiber.Ctx) error {
+ c.Cookie(&fiber.Cookie{
+ Name: "k1",
+ Value: "v2",
+ })
+ return c.SendString(
+ c.Cookies("k1") + c.Cookies("k2"))
+ }
+
+ jar := AcquireCookieJar()
+ defer ReleaseCookieJar(jar)
+
+ jar.SetKeyValue("example.com", "k1", "v1")
+ jar.SetKeyValue("example.com", "k2", "v2")
+
+ wrapAgent := func(c *Client) {
+ c.SetCookieJar(jar)
+ }
+ testClient(t, handler, wrapAgent, "v1v2")
+
+ for _, cookie := range jar.getCookiesByHost("example.com") {
+ if string(cookie.Key()) == "k1" {
+ require.Equal(t, "v2", string(cookie.Value()))
+ }
+ }
+ })
+
+ t.Run("different domain", func(t *testing.T) {
+ t.Parallel()
+ handler := func(c fiber.Ctx) error {
+ return c.SendString(c.Cookies("k1"))
+ }
+
+ jar := AcquireCookieJar()
+ defer ReleaseCookieJar(jar)
+
+ jar.SetKeyValue("example.com", "k1", "v1")
+
+ wrapAgent := func(c *Client) {
+ c.SetCookieJar(jar)
+ }
+ testClient(t, handler, wrapAgent, "v1")
+
+ require.Len(t, jar.getCookiesByHost("example.com"), 1)
+ require.Empty(t, jar.getCookiesByHost("example"))
+ })
+}
+
+func Test_Client_Referer(t *testing.T) {
+ handler := func(c fiber.Ctx) error {
+ return c.Send(c.Request().Header.Referer())
+ }
+
+ wrapAgent := func(c *Client) {
+ c.SetReferer("http://referer.com")
+ }
+
+ testClient(t, handler, wrapAgent, "http://referer.com")
+}
+
+func Test_Client_QueryParam(t *testing.T) {
+ t.Parallel()
+
+ t.Run("add param", func(t *testing.T) {
+ t.Parallel()
+ req := NewClient()
+ req.AddParam("foo", "bar").AddParam("foo", "fiber")
+
+ res := req.Param("foo")
+ require.Len(t, res, 2)
+ require.Equal(t, "bar", res[0])
+ require.Equal(t, "fiber", res[1])
+ })
+
+ t.Run("set param", func(t *testing.T) {
+ t.Parallel()
+ req := NewClient()
+ req.AddParam("foo", "bar").SetParam("foo", "fiber")
+
+ res := req.Param("foo")
+ require.Len(t, res, 1)
+ require.Equal(t, "fiber", res[0])
+ })
+
+ t.Run("add params", func(t *testing.T) {
+ t.Parallel()
+ req := NewClient()
+ req.SetParam("foo", "bar").
+ AddParams(map[string][]string{
+ "foo": {"fiber", "buaa"},
+ "bar": {"foo"},
+ })
+
+ res := req.Param("foo")
+ require.Len(t, res, 3)
+ require.Equal(t, "bar", res[0])
+ require.Equal(t, "fiber", res[1])
+ require.Equal(t, "buaa", res[2])
+
+ res = req.Param("bar")
+ require.Len(t, res, 1)
+ require.Equal(t, "foo", res[0])
+ })
+
+ t.Run("set headers", func(t *testing.T) {
+ t.Parallel()
+ req := NewClient()
+ req.SetParam("foo", "bar").
+ SetParams(map[string]string{
+ "foo": "fiber",
+ "bar": "foo",
+ })
+
+ res := req.Param("foo")
+ require.Len(t, res, 1)
+ require.Equal(t, "fiber", res[0])
+
+ res = req.Param("bar")
+ require.Len(t, res, 1)
+ require.Equal(t, "foo", res[0])
+ })
+
+ t.Run("set params with struct", func(t *testing.T) {
+ t.Parallel()
+
+ type args struct {
+ TInt int
+ TString string
+ TFloat float64
+ TBool bool
+ TSlice []string
+ TIntSlice []int `param:"int_slice"`
+ }
+
+ p := NewClient()
+ p.SetParamsWithStruct(&args{
+ TInt: 5,
+ TString: "string",
+ TFloat: 3.1,
+ TBool: true,
+ TSlice: []string{"foo", "bar"},
+ TIntSlice: []int{1, 2},
+ })
+
+ require.Empty(t, p.Param("unexport"))
+
+ require.Len(t, p.Param("TInt"), 1)
+ require.Equal(t, "5", p.Param("TInt")[0])
+
+ require.Len(t, p.Param("TString"), 1)
+ require.Equal(t, "string", p.Param("TString")[0])
+
+ require.Len(t, p.Param("TFloat"), 1)
+ require.Equal(t, "3.1", p.Param("TFloat")[0])
+
+ require.Len(t, p.Param("TBool"), 1)
+
+ tslice := p.Param("TSlice")
+ require.Len(t, tslice, 2)
+ require.Equal(t, "foo", tslice[0])
+ require.Equal(t, "bar", tslice[1])
+
+ tint := p.Param("TSlice")
+ require.Len(t, tint, 2)
+ require.Equal(t, "foo", tint[0])
+ require.Equal(t, "bar", tint[1])
+ })
+
+ t.Run("del params", func(t *testing.T) {
+ t.Parallel()
+ req := NewClient()
+ req.SetParam("foo", "bar").
+ SetParams(map[string]string{
+ "foo": "fiber",
+ "bar": "foo",
+ }).DelParams("foo", "bar")
+
+ res := req.Param("foo")
+ require.Empty(t, res)
+
+ res = req.Param("bar")
+ require.Empty(t, res)
+ })
+}
+
+func Test_Client_QueryParam_With_Server(t *testing.T) {
+ handler := func(c fiber.Ctx) error {
+ _, _ = c.WriteString(c.Query("k1")) //nolint:errcheck // It is fine to ignore the error here
+ _, _ = c.WriteString(c.Query("k2")) //nolint:errcheck // It is fine to ignore the error here
+
+ return nil
+ }
+
+ wrapAgent := func(c *Client) {
+ c.SetParam("k1", "v1").
+ AddParam("k2", "v2")
+ }
+
+ testClient(t, handler, wrapAgent, "v1v2")
+}
+
+func Test_Client_PathParam(t *testing.T) {
+ t.Parallel()
+
+ t.Run("set path param", func(t *testing.T) {
+ t.Parallel()
+ req := NewClient().
+ SetPathParam("foo", "bar")
+ require.Equal(t, "bar", req.PathParam("foo"))
+
+ req.SetPathParam("foo", "bar1")
+ require.Equal(t, "bar1", req.PathParam("foo"))
+ })
+
+ t.Run("set path params", func(t *testing.T) {
+ t.Parallel()
+ req := NewClient().
+ SetPathParams(map[string]string{
+ "foo": "bar",
+ "bar": "foo",
+ })
+ require.Equal(t, "bar", req.PathParam("foo"))
+ require.Equal(t, "foo", req.PathParam("bar"))
+
+ req.SetPathParams(map[string]string{
+ "foo": "bar1",
+ })
+ require.Equal(t, "bar1", req.PathParam("foo"))
+ require.Equal(t, "foo", req.PathParam("bar"))
+ })
+
+ t.Run("set path params with struct", func(t *testing.T) {
+ t.Parallel()
+ type args struct {
+ CookieInt int `path:"int"`
+ CookieString string `path:"string"`
+ }
+
+ req := NewClient().SetPathParamsWithStruct(&args{
+ CookieInt: 5,
+ CookieString: "foo",
+ })
+
+ require.Equal(t, "5", req.PathParam("int"))
+ require.Equal(t, "foo", req.PathParam("string"))
+ })
+
+ t.Run("del path params", func(t *testing.T) {
+ t.Parallel()
+ req := NewClient().
+ SetPathParams(map[string]string{
+ "foo": "bar",
+ "bar": "foo",
+ })
+ require.Equal(t, "bar", req.PathParam("foo"))
+ require.Equal(t, "foo", req.PathParam("bar"))
+
+ req.DelPathParams("foo")
+ require.Equal(t, "", req.PathParam("foo"))
+ require.Equal(t, "foo", req.PathParam("bar"))
+ })
+}
+
+func Test_Client_PathParam_With_Server(t *testing.T) {
+ app, dial, start := createHelperServer(t)
+
+ app.Get("/:test", func(c fiber.Ctx) error {
+ return c.SendString(c.Params("test"))
+ })
+
+ go start()
+
+ resp, err := NewClient().SetDial(dial).
+ SetPathParam("path", "test").
+ Get("http://example.com/:path")
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "test", resp.String())
+}
+
+func Test_Client_TLS(t *testing.T) {
+ t.Parallel()
+
+ serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs()
+ require.NoError(t, err)
+
+ ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
+ require.NoError(t, err)
+
+ ln = tls.NewListener(ln, serverTLSConf)
+
+ app := fiber.New()
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.SendString("tls")
+ })
+
+ go func() {
+ require.NoError(t, app.Listener(ln, fiber.ListenConfig{
+ DisableStartupMessage: true,
+ }))
+ }()
+
+ client := NewClient()
+ resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String())
+
+ require.NoError(t, err)
+ require.Equal(t, clientTLSConf, client.TLSConfig())
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "tls", resp.String())
+}
+
+func Test_Client_TLS_Error(t *testing.T) {
+ t.Parallel()
+
+ serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs()
+ clientTLSConf.MaxVersion = tls.VersionTLS12
+ serverTLSConf.MinVersion = tls.VersionTLS13
+ require.NoError(t, err)
+
+ ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
+ require.NoError(t, err)
+
+ ln = tls.NewListener(ln, serverTLSConf)
+
+ app := fiber.New()
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.SendString("tls")
+ })
+
+ go func() {
+ require.NoError(t, app.Listener(ln, fiber.ListenConfig{
+ DisableStartupMessage: true,
+ }))
+ }()
+
+ client := NewClient()
+ resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String())
+
+ require.Error(t, err)
+ require.Equal(t, clientTLSConf, client.TLSConfig())
+ require.Nil(t, resp)
+}
+
+func Test_Client_TLS_Empty_TLSConfig(t *testing.T) {
+ t.Parallel()
+
+ serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs()
+ require.NoError(t, err)
+
+ ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
+ require.NoError(t, err)
+
+ ln = tls.NewListener(ln, serverTLSConf)
+
+ app := fiber.New()
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.SendString("tls")
+ })
+
+ go func() {
+ require.NoError(t, app.Listener(ln, fiber.ListenConfig{
+ DisableStartupMessage: true,
+ }))
+ }()
+
+ client := NewClient()
+ resp, err := client.Get("https://" + ln.Addr().String())
+
+ require.Error(t, err)
+ require.NotEqual(t, clientTLSConf, client.TLSConfig())
+ require.Nil(t, resp)
+}
+
+func Test_Client_SetCertificates(t *testing.T) {
+ t.Parallel()
+
+ serverTLSConf, _, err := tlstest.GetTLSConfigs()
+ require.NoError(t, err)
+
+ client := NewClient().SetCertificates(serverTLSConf.Certificates...)
+ require.Len(t, client.TLSConfig().Certificates, 1)
+}
+
+func Test_Client_SetRootCertificate(t *testing.T) {
+ t.Parallel()
+
+ client := NewClient().SetRootCertificate("../.github/testdata/ssl.pem")
+ require.NotNil(t, client.TLSConfig().RootCAs)
+}
+
+func Test_Client_SetRootCertificateFromString(t *testing.T) {
+ t.Parallel()
+
+ file, err := os.Open("../.github/testdata/ssl.pem")
+ defer func() { require.NoError(t, file.Close()) }()
+ require.NoError(t, err)
+
+ pem, err := io.ReadAll(file)
+ require.NoError(t, err)
+
+ client := NewClient().SetRootCertificateFromString(string(pem))
+ require.NotNil(t, client.TLSConfig().RootCAs)
+}
+
+func Test_Client_R(t *testing.T) {
+ t.Parallel()
+
+ client := NewClient()
+ req := client.R()
+
+ require.Equal(t, "Request", reflect.TypeOf(req).Elem().Name())
+ require.Equal(t, client, req.Client())
+}
+
+func Test_Replace(t *testing.T) {
+ app, dial, start := createHelperServer(t)
+
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.SendString(string(c.Request().Header.Peek("k1")))
+ })
+
+ go start()
+
+ C().SetDial(dial)
+ resp, err := Get("http://example.com")
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "", resp.String())
+
+ r := NewClient().SetDial(dial).SetHeader("k1", "v1")
+ clean := Replace(r)
+ resp, err = Get("http://example.com")
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "v1", resp.String())
+
+ clean()
+
+ C().SetDial(dial)
+ resp, err = Get("http://example.com")
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "", resp.String())
+
+ C().SetDial(nil)
+}
+
+func Test_Set_Config_To_Request(t *testing.T) {
+ t.Parallel()
+
+ t.Run("set ctx", func(t *testing.T) {
+ t.Parallel()
+ key := struct{}{}
+
+ ctx := context.Background()
+ ctx = context.WithValue(ctx, key, "v1")
+
+ req := AcquireRequest()
+
+ setConfigToRequest(req, Config{Ctx: ctx})
+
+ require.Equal(t, "v1", req.Context().Value(key))
+ })
+
+ t.Run("set useragent", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+
+ setConfigToRequest(req, Config{UserAgent: "agent"})
+
+ require.Equal(t, "agent", req.UserAgent())
+ })
+
+ t.Run("set referer", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+
+ setConfigToRequest(req, Config{Referer: "referer"})
+
+ require.Equal(t, "referer", req.Referer())
+ })
+
+ t.Run("set header", func(t *testing.T) {
+ req := AcquireRequest()
+
+ setConfigToRequest(req, Config{Header: map[string]string{
+ "k1": "v1",
+ }})
+
+ require.Equal(t, "v1", req.Header("k1")[0])
+ })
+
+ t.Run("set params", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+
+ setConfigToRequest(req, Config{Param: map[string]string{
+ "k1": "v1",
+ }})
+
+ require.Equal(t, "v1", req.Param("k1")[0])
+ })
+
+ t.Run("set cookies", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+
+ setConfigToRequest(req, Config{Cookie: map[string]string{
+ "k1": "v1",
+ }})
+
+ require.Equal(t, "v1", req.Cookie("k1"))
+ })
+
+ t.Run("set pathparam", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+
+ setConfigToRequest(req, Config{PathParam: map[string]string{
+ "k1": "v1",
+ }})
+
+ require.Equal(t, "v1", req.PathParam("k1"))
+ })
+
+ t.Run("set timeout", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+
+ setConfigToRequest(req, Config{Timeout: 1 * time.Second})
+
+ require.Equal(t, 1*time.Second, req.Timeout())
+ })
+
+ t.Run("set maxredirects", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+
+ setConfigToRequest(req, Config{MaxRedirects: 1})
+
+ require.Equal(t, 1, req.MaxRedirects())
+ })
+
+ t.Run("set body", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+
+ setConfigToRequest(req, Config{Body: "test"})
+
+ require.Equal(t, "test", req.body)
+ })
+
+ t.Run("set file", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+
+ setConfigToRequest(req, Config{File: []*File{
+ {
+ name: "test",
+ path: "path",
+ },
+ }})
+
+ require.Equal(t, "path", req.File("test").path)
+ })
+}
+
+func Test_Client_SetProxyURL(t *testing.T) {
+ t.Parallel()
+
+ app, dial, start := createHelperServer(t)
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.SendString("hello world")
+ })
+
+ go start()
+
+ t.Cleanup(func() {
+ require.NoError(t, app.Shutdown())
+ })
+
+ time.Sleep(1 * time.Second)
+
+ t.Run("success", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().SetDial(dial)
+ err := client.SetProxyURL("http://test.com")
+
+ require.NoError(t, err)
+
+ _, err = client.Get("http://localhost:3000")
+
+ require.NoError(t, err)
+ })
+
+ t.Run("wrong url", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient()
+
+ err := client.SetProxyURL(":this is not a url")
+
+ require.Error(t, err)
+ })
+
+ t.Run("error", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient()
+
+ err := client.SetProxyURL("htgdftp://test.com")
+
+ require.Error(t, err)
+ })
+}
+
+func Test_Client_SetRetryConfig(t *testing.T) {
+ t.Parallel()
+
+ retryConfig := &retry.Config{
+ InitialInterval: 1 * time.Second,
+ MaxRetryCount: 3,
+ }
+
+ core, client, req := newCore(), NewClient(), AcquireRequest()
+ req.SetURL("http://example.com")
+ client.SetRetryConfig(retryConfig)
+ _, err := core.execute(context.Background(), client, req)
+
+ require.NoError(t, err)
+ require.Equal(t, retryConfig.InitialInterval, client.RetryConfig().InitialInterval)
+ require.Equal(t, retryConfig.MaxRetryCount, client.RetryConfig().MaxRetryCount)
+}
+
+func Benchmark_Client_Request(b *testing.B) {
+ app, dial, start := createHelperServer(b)
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.SendString("hello world")
+ })
+
+ go start()
+
+ client := NewClient().SetDial(dial)
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ var err error
+ var resp *Response
+ for i := 0; i < b.N; i++ {
+ resp, err = client.Get("http://example.com")
+ resp.Close()
+ }
+ require.NoError(b, err)
+}
+
+func Benchmark_Client_Request_Parallel(b *testing.B) {
+ app, dial, start := createHelperServer(b)
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.SendString("hello world")
+ })
+
+ go start()
+
+ client := NewClient().SetDial(dial)
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ b.RunParallel(func(pb *testing.PB) {
+ var err error
+ var resp *Response
+ for pb.Next() {
+ resp, err = client.Get("http://example.com")
+ resp.Close()
+ }
+ require.NoError(b, err)
+ })
+}
diff --git a/client/cookiejar.go b/client/cookiejar.go
new file mode 100644
index 0000000000..c66d5f3b7c
--- /dev/null
+++ b/client/cookiejar.go
@@ -0,0 +1,245 @@
+// The code has been taken from https://github.com/valyala/fasthttp/pull/526 originally.
+package client
+
+import (
+ "bytes"
+ "errors"
+ "net"
+ "sync"
+ "time"
+
+ "github.com/gofiber/utils/v2"
+ "github.com/valyala/fasthttp"
+)
+
+var cookieJarPool = sync.Pool{
+ New: func() any {
+ return &CookieJar{}
+ },
+}
+
+// AcquireCookieJar returns an empty CookieJar object from pool.
+func AcquireCookieJar() *CookieJar {
+ jar, ok := cookieJarPool.Get().(*CookieJar)
+ if !ok {
+ panic(errors.New("failed to type-assert to *CookieJar"))
+ }
+
+ return jar
+}
+
+// ReleaseCookieJar returns CookieJar to the pool.
+func ReleaseCookieJar(c *CookieJar) {
+ c.Release()
+ cookieJarPool.Put(c)
+}
+
+// CookieJar manages cookie storage. It is used by the client to store cookies.
+type CookieJar struct {
+ mu sync.Mutex
+ hostCookies map[string][]*fasthttp.Cookie
+}
+
+// Get returns the cookies stored from a specific domain.
+// If there were no cookies related with host returned slice will be nil.
+//
+// CookieJar keeps a copy of the cookies, so the returned cookies can be released safely.
+func (cj *CookieJar) Get(uri *fasthttp.URI) []*fasthttp.Cookie {
+ if uri == nil {
+ return nil
+ }
+
+ return cj.getByHostAndPath(uri.Host(), uri.Path())
+}
+
+// get returns the cookies stored from a specific host and path.
+func (cj *CookieJar) getByHostAndPath(host, path []byte) []*fasthttp.Cookie {
+ if cj.hostCookies == nil {
+ return nil
+ }
+
+ var (
+ err error
+ cookies []*fasthttp.Cookie
+ hostStr = utils.UnsafeString(host)
+ )
+
+ // port must not be included.
+ hostStr, _, err = net.SplitHostPort(hostStr)
+ if err != nil {
+ hostStr = utils.UnsafeString(host)
+ }
+ // get cookies deleting expired ones
+ cookies = cj.getCookiesByHost(hostStr)
+
+ newCookies := make([]*fasthttp.Cookie, 0, len(cookies))
+ for i := 0; i < len(cookies); i++ {
+ cookie := cookies[i]
+ if len(path) > 1 && len(cookie.Path()) > 1 && !bytes.HasPrefix(cookie.Path(), path) {
+ continue
+ }
+ newCookies = append(newCookies, cookie)
+ }
+
+ return newCookies
+}
+
+// getCookiesByHost returns the cookies stored from a specific host.
+// If cookies are expired they will be deleted.
+func (cj *CookieJar) getCookiesByHost(host string) []*fasthttp.Cookie {
+ cj.mu.Lock()
+ defer cj.mu.Unlock()
+
+ now := time.Now()
+ cookies := cj.hostCookies[host]
+
+ for i := 0; i < len(cookies); i++ {
+ c := cookies[i]
+ if !c.Expire().Equal(fasthttp.CookieExpireUnlimited) && c.Expire().Before(now) { // release cookie if expired
+ cookies = append(cookies[:i], cookies[i+1:]...)
+ fasthttp.ReleaseCookie(c)
+ i--
+ }
+ }
+
+ return cookies
+}
+
+// Set sets cookies for a specific host.
+// The host is get from uri.Host().
+// If the cookie key already exists it will be replaced by the new cookie value.
+//
+// CookieJar keeps a copy of the cookies, so the parsed cookies can be released safely.
+func (cj *CookieJar) Set(uri *fasthttp.URI, cookies ...*fasthttp.Cookie) {
+ if uri == nil {
+ return
+ }
+
+ cj.SetByHost(uri.Host(), cookies...)
+}
+
+// SetByHost sets cookies for a specific host.
+// If the cookie key already exists it will be replaced by the new cookie value.
+//
+// CookieJar keeps a copy of the cookies, so the parsed cookies can be released safely.
+func (cj *CookieJar) SetByHost(host []byte, cookies ...*fasthttp.Cookie) {
+ hostStr := utils.UnsafeString(host)
+
+ cj.mu.Lock()
+ defer cj.mu.Unlock()
+
+ if cj.hostCookies == nil {
+ cj.hostCookies = make(map[string][]*fasthttp.Cookie)
+ }
+
+ hostCookies, ok := cj.hostCookies[hostStr]
+ if !ok {
+ // If the key does not exist in the map, then we must make a copy for the key to avoid unsafe usage.
+ hostStr = string(host)
+ }
+
+ for _, cookie := range cookies {
+ c := searchCookieByKeyAndPath(cookie.Key(), cookie.Path(), hostCookies)
+ if c == nil {
+ // If the cookie does not exist in the slice, let's acquire new cookie and store it.
+ c = fasthttp.AcquireCookie()
+ hostCookies = append(hostCookies, c)
+ }
+ c.CopyTo(cookie) // override cookie properties
+ }
+ cj.hostCookies[hostStr] = hostCookies
+}
+
+// SetKeyValue sets a cookie by key and value for a specific host.
+//
+// This function prevents extra allocations by making repeated cookies
+// not being duplicated.
+func (cj *CookieJar) SetKeyValue(host, key, value string) {
+ c := fasthttp.AcquireCookie()
+ c.SetKey(key)
+ c.SetValue(value)
+
+ cj.SetByHost(utils.UnsafeBytes(host), c)
+}
+
+// SetKeyValueBytes sets a cookie by key and value for a specific host.
+//
+// This function prevents extra allocations by making repeated cookies
+// not being duplicated.
+func (cj *CookieJar) SetKeyValueBytes(host string, key, value []byte) {
+ c := fasthttp.AcquireCookie()
+ c.SetKeyBytes(key)
+ c.SetValueBytes(value)
+
+ cj.SetByHost(utils.UnsafeBytes(host), c)
+}
+
+// dumpCookiesToReq dumps the stored cookies to the request.
+func (cj *CookieJar) dumpCookiesToReq(req *fasthttp.Request) {
+ uri := req.URI()
+
+ cookies := cj.getByHostAndPath(uri.Host(), uri.Path())
+ for _, cookie := range cookies {
+ req.Header.SetCookieBytesKV(cookie.Key(), cookie.Value())
+ }
+}
+
+// parseCookiesFromResp parses the response cookies and stores them.
+func (cj *CookieJar) parseCookiesFromResp(host, path []byte, resp *fasthttp.Response) {
+ hostStr := utils.UnsafeString(host)
+
+ cj.mu.Lock()
+ defer cj.mu.Unlock()
+
+ if cj.hostCookies == nil {
+ cj.hostCookies = make(map[string][]*fasthttp.Cookie)
+ }
+ cookies, ok := cj.hostCookies[hostStr]
+ if !ok {
+ // If the key does not exist in the map then
+ // we must make a copy for the key to avoid unsafe usage.
+ hostStr = string(host)
+ }
+
+ now := time.Now()
+ resp.Header.VisitAllCookie(func(key, value []byte) {
+ isCreated := false
+ c := searchCookieByKeyAndPath(key, path, cookies)
+ if c == nil {
+ c, isCreated = fasthttp.AcquireCookie(), true
+ }
+
+ _ = c.ParseBytes(value) //nolint:errcheck // ignore error
+ if c.Expire().Equal(fasthttp.CookieExpireUnlimited) || c.Expire().After(now) {
+ cookies = append(cookies, c)
+ } else if isCreated {
+ fasthttp.ReleaseCookie(c)
+ }
+ })
+ cj.hostCookies[hostStr] = cookies
+}
+
+// Release releases all cookie values.
+func (cj *CookieJar) Release() {
+ // FOllOW-UP performance optimization
+ // currently a race condition is found because the reset method modifies a value which is not a copy but a reference -> solution should be to make a copy
+ // for _, v := range cj.hostCookies {
+ // for _, c := range v {
+ // fasthttp.ReleaseCookie(c)
+ // }
+ // }
+ cj.hostCookies = nil
+}
+
+// searchCookieByKeyAndPath searches for a cookie by key and path.
+func searchCookieByKeyAndPath(key, path []byte, cookies []*fasthttp.Cookie) *fasthttp.Cookie {
+ for _, c := range cookies {
+ if bytes.Equal(key, c.Key()) {
+ if len(path) <= 1 || bytes.HasPrefix(c.Path(), path) {
+ return c
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/client/cookiejar_test.go b/client/cookiejar_test.go
new file mode 100644
index 0000000000..3b6fdcda83
--- /dev/null
+++ b/client/cookiejar_test.go
@@ -0,0 +1,213 @@
+package client
+
+import (
+ "bytes"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+ "github.com/valyala/fasthttp"
+)
+
+func checkKeyValue(t *testing.T, cj *CookieJar, cookie *fasthttp.Cookie, uri *fasthttp.URI, n int) {
+ t.Helper()
+
+ cs := cj.Get(uri)
+ require.GreaterOrEqual(t, len(cs), n)
+
+ c := cs[n-1]
+ require.NotNil(t, c)
+
+ require.Equal(t, string(c.Key()), string(cookie.Key()))
+ require.Equal(t, string(c.Value()), string(cookie.Value()))
+}
+
+func TestCookieJarGet(t *testing.T) {
+ t.Parallel()
+
+ url := []byte("http://fasthttp.com/")
+ url1 := []byte("http://fasthttp.com/make")
+ url11 := []byte("http://fasthttp.com/hola")
+ url2 := []byte("http://fasthttp.com/make/fasthttp")
+ url3 := []byte("http://fasthttp.com/make/fasthttp/great")
+ prefix := []byte("/")
+ prefix1 := []byte("/make")
+ prefix2 := []byte("/make/fasthttp")
+ prefix3 := []byte("/make/fasthttp/great")
+ cj := &CookieJar{}
+
+ c1 := &fasthttp.Cookie{}
+ c1.SetKey("k")
+ c1.SetValue("v")
+ c1.SetPath("/make/")
+
+ c2 := &fasthttp.Cookie{}
+ c2.SetKey("kk")
+ c2.SetValue("vv")
+ c2.SetPath("/make/fasthttp")
+
+ c3 := &fasthttp.Cookie{}
+ c3.SetKey("kkk")
+ c3.SetValue("vvv")
+ c3.SetPath("/make/fasthttp/great")
+
+ uri := fasthttp.AcquireURI()
+ require.NoError(t, uri.Parse(nil, url))
+
+ uri1 := fasthttp.AcquireURI()
+ require.NoError(t, uri1.Parse(nil, url1))
+
+ uri11 := fasthttp.AcquireURI()
+ require.NoError(t, uri11.Parse(nil, url11))
+
+ uri2 := fasthttp.AcquireURI()
+ require.NoError(t, uri2.Parse(nil, url2))
+
+ uri3 := fasthttp.AcquireURI()
+ require.NoError(t, uri3.Parse(nil, url3))
+
+ cj.Set(uri1, c1, c2, c3)
+
+ cookies := cj.Get(uri1)
+ require.Len(t, cookies, 3)
+ for _, cookie := range cookies {
+ require.True(t, bytes.HasPrefix(cookie.Path(), prefix1))
+ }
+
+ cookies = cj.Get(uri11)
+ require.Empty(t, cookies)
+
+ cookies = cj.Get(uri2)
+ require.Len(t, cookies, 2)
+ for _, cookie := range cookies {
+ require.True(t, bytes.HasPrefix(cookie.Path(), prefix2))
+ }
+
+ cookies = cj.Get(uri3)
+ require.Len(t, cookies, 1)
+
+ for _, cookie := range cookies {
+ require.True(t, bytes.HasPrefix(cookie.Path(), prefix3))
+ }
+
+ cookies = cj.Get(uri)
+ require.Len(t, cookies, 3)
+ for _, cookie := range cookies {
+ require.True(t, bytes.HasPrefix(cookie.Path(), prefix))
+ }
+}
+
+func TestCookieJarGetExpired(t *testing.T) {
+ t.Parallel()
+
+ url1 := []byte("http://fasthttp.com/make/")
+ uri1 := fasthttp.AcquireURI()
+ require.NoError(t, uri1.Parse(nil, url1))
+
+ c1 := &fasthttp.Cookie{}
+ c1.SetKey("k")
+ c1.SetValue("v")
+ c1.SetExpire(time.Now().Add(-time.Hour))
+
+ cj := &CookieJar{}
+ cj.Set(uri1, c1)
+
+ cookies := cj.Get(uri1)
+ require.Empty(t, cookies)
+}
+
+func TestCookieJarSet(t *testing.T) {
+ t.Parallel()
+
+ url := []byte("http://fasthttp.com/hello/world")
+ cj := &CookieJar{}
+
+ cookie := &fasthttp.Cookie{}
+ cookie.SetKey("k")
+ cookie.SetValue("v")
+
+ uri := fasthttp.AcquireURI()
+ require.NoError(t, uri.Parse(nil, url))
+
+ cj.Set(uri, cookie)
+ checkKeyValue(t, cj, cookie, uri, 1)
+}
+
+func TestCookieJarSetRepeatedCookieKeys(t *testing.T) {
+ t.Parallel()
+
+ host := "fast.http"
+ cj := &CookieJar{}
+
+ uri := fasthttp.AcquireURI()
+ uri.SetHost(host)
+
+ cookie := &fasthttp.Cookie{}
+ cookie.SetKey("k")
+ cookie.SetValue("v")
+
+ cookie2 := &fasthttp.Cookie{}
+ cookie2.SetKey("k")
+ cookie2.SetValue("v2")
+
+ cookie3 := &fasthttp.Cookie{}
+ cookie3.SetKey("key")
+ cookie3.SetValue("value")
+
+ cj.Set(uri, cookie, cookie2, cookie3)
+
+ cookies := cj.Get(uri)
+ require.Len(t, cookies, 2)
+ require.Equal(t, cookies[0], cookie2)
+ require.True(t, bytes.Equal(cookies[0].Value(), cookie2.Value()))
+}
+
+func TestCookieJarSetKeyValue(t *testing.T) {
+ t.Parallel()
+
+ host := "fast.http"
+ cj := &CookieJar{}
+
+ uri := fasthttp.AcquireURI()
+ uri.SetHost(host)
+
+ cj.SetKeyValue(host, "k", "v")
+ cj.SetKeyValue(host, "key", "value")
+ cj.SetKeyValue(host, "k", "vv")
+ cj.SetKeyValue(host, "key", "value2")
+
+ cookies := cj.Get(uri)
+ require.Len(t, cookies, 2)
+}
+
+func TestCookieJarGetFromResponse(t *testing.T) {
+ t.Parallel()
+
+ res := fasthttp.AcquireResponse()
+ host := []byte("fast.http")
+ uri := fasthttp.AcquireURI()
+ uri.SetHostBytes(host)
+
+ c := &fasthttp.Cookie{}
+ c.SetKey("key")
+ c.SetValue("val")
+
+ c2 := &fasthttp.Cookie{}
+ c2.SetKey("k")
+ c2.SetValue("v")
+
+ c3 := &fasthttp.Cookie{}
+ c3.SetKey("kk")
+ c3.SetValue("vv")
+
+ res.Header.SetStatusCode(200)
+ res.Header.SetCookie(c)
+ res.Header.SetCookie(c2)
+ res.Header.SetCookie(c3)
+
+ cj := &CookieJar{}
+ cj.parseCookiesFromResp(host, nil, res)
+
+ cookies := cj.Get(uri)
+ require.Len(t, cookies, 3)
+}
diff --git a/client/core.go b/client/core.go
new file mode 100644
index 0000000000..315d12d474
--- /dev/null
+++ b/client/core.go
@@ -0,0 +1,272 @@
+package client
+
+import (
+ "context"
+ "errors"
+ "net"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+
+ "github.com/gofiber/fiber/v3"
+ "github.com/gofiber/fiber/v3/addon/retry"
+ "github.com/valyala/fasthttp"
+)
+
+var boundary = "--FiberFormBoundary"
+
+// RequestHook is a function that receives Agent and Request,
+// it can change the data in Request and Agent.
+//
+// Called before a request is sent.
+type RequestHook func(*Client, *Request) error
+
+// ResponseHook is a function that receives Agent, Response and Request,
+// it can change the data is Response or deal with some effects.
+//
+// Called after a response has been received.
+type ResponseHook func(*Client, *Response, *Request) error
+
+// RetryConfig is an alias for config in the `addon/retry` package.
+type RetryConfig = retry.Config
+
+// addMissingPort will add the corresponding port number for host.
+func addMissingPort(addr string, isTLS bool) string { //revive:disable-line:flag-parameter // Accepting a bool param named isTLS if fine here
+ n := strings.Index(addr, ":")
+ if n >= 0 {
+ return addr
+ }
+ port := 80
+ if isTLS {
+ port = 443
+ }
+ return net.JoinHostPort(addr, strconv.Itoa(port))
+}
+
+// `core` stores middleware and plugin definitions,
+// and defines the execution process
+type core struct {
+ client *Client
+ req *Request
+ ctx context.Context //nolint:containedctx // It's needed to be stored in the core.
+}
+
+// getRetryConfig returns the retry configuration of the client.
+func (c *core) getRetryConfig() *RetryConfig {
+ c.client.mu.RLock()
+ defer c.client.mu.RUnlock()
+
+ cfg := c.client.RetryConfig()
+ if cfg == nil {
+ return nil
+ }
+
+ return &RetryConfig{
+ InitialInterval: cfg.InitialInterval,
+ MaxBackoffTime: cfg.MaxBackoffTime,
+ Multiplier: cfg.Multiplier,
+ MaxRetryCount: cfg.MaxRetryCount,
+ }
+}
+
+// execFunc is the core function of the client.
+// It sends the request and receives the response.
+func (c *core) execFunc() (*Response, error) {
+ resp := AcquireResponse()
+ resp.setClient(c.client)
+ resp.setRequest(c.req)
+
+ // To avoid memory allocation reuse of data structures such as errch.
+ done := int32(0)
+ errCh, reqv := acquireErrChan(), fasthttp.AcquireRequest()
+ defer func() {
+ releaseErrChan(errCh)
+ }()
+
+ c.req.RawRequest.CopyTo(reqv)
+ cfg := c.getRetryConfig()
+
+ var err error
+ go func() {
+ respv := fasthttp.AcquireResponse()
+ defer func() {
+ fasthttp.ReleaseRequest(reqv)
+ fasthttp.ReleaseResponse(respv)
+ }()
+
+ if cfg != nil {
+ err = retry.NewExponentialBackoff(*cfg).Retry(func() error {
+ if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) {
+ return c.client.fasthttp.DoRedirects(reqv, respv, c.req.maxRedirects)
+ }
+
+ return c.client.fasthttp.Do(reqv, respv)
+ })
+ } else {
+ if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) {
+ err = c.client.fasthttp.DoRedirects(reqv, respv, c.req.maxRedirects)
+ } else {
+ err = c.client.fasthttp.Do(reqv, respv)
+ }
+ }
+
+ if atomic.CompareAndSwapInt32(&done, 0, 1) {
+ if err != nil {
+ errCh <- err
+ return
+ }
+ respv.CopyTo(resp.RawResponse)
+ errCh <- nil
+ }
+ }()
+
+ select {
+ case err := <-errCh:
+ if err != nil {
+ // When get error should release Response
+ ReleaseResponse(resp)
+ return nil, err
+ }
+ return resp, nil
+ case <-c.ctx.Done():
+ atomic.SwapInt32(&done, 1)
+ ReleaseResponse(resp)
+ return nil, ErrTimeoutOrCancel
+ }
+}
+
+// preHooks Exec request hook
+func (c *core) preHooks() error {
+ c.client.mu.Lock()
+ defer c.client.mu.Unlock()
+
+ for _, f := range c.client.userRequestHooks {
+ err := f(c.client, c.req)
+ if err != nil {
+ return err
+ }
+ }
+
+ for _, f := range c.client.builtinRequestHooks {
+ err := f(c.client, c.req)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// afterHooks Exec response hooks
+func (c *core) afterHooks(resp *Response) error {
+ c.client.mu.Lock()
+ defer c.client.mu.Unlock()
+
+ for _, f := range c.client.builtinResponseHooks {
+ err := f(c.client, resp, c.req)
+ if err != nil {
+ return err
+ }
+ }
+
+ for _, f := range c.client.userResponseHooks {
+ err := f(c.client, resp, c.req)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// timeout deals with timeout
+func (c *core) timeout() context.CancelFunc {
+ var cancel context.CancelFunc
+
+ if c.req.timeout > 0 {
+ c.ctx, cancel = context.WithTimeout(c.ctx, c.req.timeout)
+ } else if c.client.timeout > 0 {
+ c.ctx, cancel = context.WithTimeout(c.ctx, c.client.timeout)
+ }
+
+ return cancel
+}
+
+// execute will exec each hooks and plugins.
+func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Response, error) {
+ // keep a reference, because pass param is boring
+ c.ctx = ctx
+ c.client = client
+ c.req = req
+
+ // The built-in hooks will be executed only
+ // after the user-defined hooks are executed.
+ err := c.preHooks()
+ if err != nil {
+ return nil, err
+ }
+
+ cancel := c.timeout()
+ if cancel != nil {
+ defer cancel()
+ }
+
+ // Do http request
+ resp, err := c.execFunc()
+ if err != nil {
+ return nil, err
+ }
+
+ // The built-in hooks will be executed only
+ // before the user-defined hooks are executed.
+ err = c.afterHooks(resp)
+ if err != nil {
+ resp.Close()
+ return nil, err
+ }
+
+ return resp, nil
+}
+
+var errChanPool = &sync.Pool{
+ New: func() any {
+ return make(chan error, 1)
+ },
+}
+
+// acquireErrChan returns an empty error chan from the pool.
+//
+// The returned error chan may be returned to the pool with releaseErrChan when no longer needed.
+// This allows reducing GC load.
+func acquireErrChan() chan error {
+ ch, ok := errChanPool.Get().(chan error)
+ if !ok {
+ panic(errors.New("failed to type-assert to chan error"))
+ }
+
+ return ch
+}
+
+// releaseErrChan returns the object acquired via acquireErrChan to the pool.
+//
+// Do not access the released core object, otherwise data races may occur.
+func releaseErrChan(ch chan error) {
+ errChanPool.Put(ch)
+}
+
+// newCore returns an empty core object.
+func newCore() *core {
+ c := &core{}
+
+ return c
+}
+
+var (
+ ErrTimeoutOrCancel = errors.New("timeout or cancel")
+ ErrURLFormat = errors.New("the url is a mistake")
+ ErrNotSupportSchema = errors.New("the protocol is not support, only http or https")
+ ErrFileNoName = errors.New("the file should have name")
+ ErrBodyType = errors.New("the body type should be []byte")
+ ErrNotSupportSaveMethod = errors.New("file path and io.Writer are supported")
+)
diff --git a/client/core_test.go b/client/core_test.go
new file mode 100644
index 0000000000..1b8ea42b9d
--- /dev/null
+++ b/client/core_test.go
@@ -0,0 +1,248 @@
+package client
+
+import (
+ "context"
+ "errors"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/gofiber/fiber/v3"
+ "github.com/stretchr/testify/require"
+ "github.com/valyala/fasthttp/fasthttputil"
+)
+
+func Test_AddMissing_Port(t *testing.T) {
+ t.Parallel()
+
+ type args struct {
+ addr string
+ isTLS bool
+ }
+ tests := []struct {
+ name string
+ args args
+ want string
+ }{
+ {
+ name: "do anything",
+ args: args{
+ addr: "example.com:1234",
+ },
+ want: "example.com:1234",
+ },
+ {
+ name: "add 80 port",
+ args: args{
+ addr: "example.com",
+ },
+ want: "example.com:80",
+ },
+ {
+ name: "add 443 port",
+ args: args{
+ addr: "example.com",
+ isTLS: true,
+ },
+ want: "example.com:443",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.want, addMissingPort(tt.args.addr, tt.args.isTLS))
+ })
+ }
+}
+
+func Test_Exec_Func(t *testing.T) {
+ t.Parallel()
+
+ ln := fasthttputil.NewInmemoryListener()
+ app := fiber.New()
+
+ app.Get("/normal", func(c fiber.Ctx) error {
+ return c.SendString(c.Hostname())
+ })
+
+ app.Get("/return-error", func(_ fiber.Ctx) error {
+ return errors.New("the request is error")
+ })
+
+ app.Get("/hang-up", func(c fiber.Ctx) error {
+ time.Sleep(time.Second)
+ return c.SendString(c.Hostname() + " hang up")
+ })
+
+ go func() {
+ require.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true}))
+ }()
+
+ time.Sleep(300 * time.Millisecond)
+
+ t.Run("normal request", func(t *testing.T) {
+ t.Parallel()
+ core, client, req := newCore(), NewClient(), AcquireRequest()
+ core.ctx = context.Background()
+ core.client = client
+ core.req = req
+
+ client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed
+ req.RawRequest.SetRequestURI("http://example.com/normal")
+
+ resp, err := core.execFunc()
+ require.NoError(t, err)
+ require.Equal(t, 200, resp.RawResponse.StatusCode())
+ require.Equal(t, "example.com", string(resp.RawResponse.Body()))
+ })
+
+ t.Run("the request return an error", func(t *testing.T) {
+ t.Parallel()
+ core, client, req := newCore(), NewClient(), AcquireRequest()
+ core.ctx = context.Background()
+ core.client = client
+ core.req = req
+
+ client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed
+ req.RawRequest.SetRequestURI("http://example.com/return-error")
+
+ resp, err := core.execFunc()
+
+ require.NoError(t, err)
+ require.Equal(t, 500, resp.RawResponse.StatusCode())
+ require.Equal(t, "the request is error", string(resp.RawResponse.Body()))
+ })
+
+ t.Run("the request timeout", func(t *testing.T) {
+ t.Parallel()
+ core, client, req := newCore(), NewClient(), AcquireRequest()
+ ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer cancel()
+
+ core.ctx = ctx
+ core.client = client
+ core.req = req
+
+ client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed
+ req.RawRequest.SetRequestURI("http://example.com/hang-up")
+
+ _, err := core.execFunc()
+
+ require.Equal(t, ErrTimeoutOrCancel, err)
+ })
+}
+
+func Test_Execute(t *testing.T) {
+ t.Parallel()
+
+ ln := fasthttputil.NewInmemoryListener()
+ app := fiber.New()
+
+ app.Get("/normal", func(c fiber.Ctx) error {
+ return c.SendString(c.Hostname())
+ })
+
+ app.Get("/return-error", func(_ fiber.Ctx) error {
+ return errors.New("the request is error")
+ })
+
+ app.Get("/hang-up", func(c fiber.Ctx) error {
+ time.Sleep(time.Second)
+ return c.SendString(c.Hostname() + " hang up")
+ })
+
+ go func() {
+ require.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true}))
+ }()
+
+ t.Run("add user request hooks", func(t *testing.T) {
+ t.Parallel()
+ core, client, req := newCore(), NewClient(), AcquireRequest()
+ client.AddRequestHook(func(_ *Client, _ *Request) error {
+ require.Equal(t, "http://example.com", req.URL())
+ return nil
+ })
+ client.SetDial(func(_ string) (net.Conn, error) {
+ return ln.Dial() //nolint:wrapcheck // not needed
+ })
+ req.SetURL("http://example.com")
+
+ resp, err := core.execute(context.Background(), client, req)
+ require.NoError(t, err)
+ require.Equal(t, "Cannot GET /", string(resp.RawResponse.Body()))
+ })
+
+ t.Run("add user response hooks", func(t *testing.T) {
+ t.Parallel()
+ core, client, req := newCore(), NewClient(), AcquireRequest()
+ client.AddResponseHook(func(_ *Client, _ *Response, req *Request) error {
+ require.Equal(t, "http://example.com", req.URL())
+ return nil
+ })
+ client.SetDial(func(_ string) (net.Conn, error) {
+ return ln.Dial() //nolint:wrapcheck // not needed
+ })
+ req.SetURL("http://example.com")
+
+ resp, err := core.execute(context.Background(), client, req)
+ require.NoError(t, err)
+ require.Equal(t, "Cannot GET /", string(resp.RawResponse.Body()))
+ })
+
+ t.Run("no timeout", func(t *testing.T) {
+ t.Parallel()
+ core, client, req := newCore(), NewClient(), AcquireRequest()
+
+ client.SetDial(func(_ string) (net.Conn, error) {
+ return ln.Dial() //nolint:wrapcheck // not needed
+ })
+ req.SetURL("http://example.com/hang-up")
+
+ resp, err := core.execute(context.Background(), client, req)
+ require.NoError(t, err)
+ require.Equal(t, "example.com hang up", string(resp.RawResponse.Body()))
+ })
+
+ t.Run("client timeout", func(t *testing.T) {
+ t.Parallel()
+ core, client, req := newCore(), NewClient(), AcquireRequest()
+ client.SetTimeout(500 * time.Millisecond)
+ client.SetDial(func(_ string) (net.Conn, error) {
+ return ln.Dial() //nolint:wrapcheck // not needed
+ })
+ req.SetURL("http://example.com/hang-up")
+
+ _, err := core.execute(context.Background(), client, req)
+ require.Equal(t, ErrTimeoutOrCancel, err)
+ })
+
+ t.Run("request timeout", func(t *testing.T) {
+ t.Parallel()
+ core, client, req := newCore(), NewClient(), AcquireRequest()
+
+ client.SetDial(func(_ string) (net.Conn, error) {
+ return ln.Dial() //nolint:wrapcheck // not needed
+ })
+ req.SetURL("http://example.com/hang-up").
+ SetTimeout(300 * time.Millisecond)
+
+ _, err := core.execute(context.Background(), client, req)
+ require.Equal(t, ErrTimeoutOrCancel, err)
+ })
+
+ t.Run("request timeout has higher level", func(t *testing.T) {
+ t.Parallel()
+ core, client, req := newCore(), NewClient(), AcquireRequest()
+ client.SetTimeout(30 * time.Millisecond)
+
+ client.SetDial(func(_ string) (net.Conn, error) {
+ return ln.Dial() //nolint:wrapcheck // not needed
+ })
+ req.SetURL("http://example.com/hang-up").
+ SetTimeout(3000 * time.Millisecond)
+
+ resp, err := core.execute(context.Background(), client, req)
+ require.NoError(t, err)
+ require.Equal(t, "example.com hang up", string(resp.RawResponse.Body()))
+ })
+}
diff --git a/client/helper_test.go b/client/helper_test.go
new file mode 100644
index 0000000000..67380f3470
--- /dev/null
+++ b/client/helper_test.go
@@ -0,0 +1,157 @@
+package client
+
+import (
+ "net"
+ "testing"
+ "time"
+
+ "github.com/gofiber/fiber/v3"
+ "github.com/stretchr/testify/require"
+ "github.com/valyala/fasthttp/fasthttputil"
+)
+
+type testServer struct {
+ app *fiber.App
+ ch chan struct{}
+ ln *fasthttputil.InmemoryListener
+ tb testing.TB
+}
+
+func startTestServer(tb testing.TB, beforeStarting func(app *fiber.App)) *testServer {
+ tb.Helper()
+
+ ln := fasthttputil.NewInmemoryListener()
+ app := fiber.New()
+
+ if beforeStarting != nil {
+ beforeStarting(app)
+ }
+
+ ch := make(chan struct{})
+ go func() {
+ if err := app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true}); err != nil {
+ tb.Fatal(err)
+ }
+
+ close(ch)
+ }()
+
+ return &testServer{
+ app: app,
+ ch: ch,
+ ln: ln,
+ tb: tb,
+ }
+}
+
+func (ts *testServer) stop() {
+ ts.tb.Helper()
+
+ if err := ts.app.Shutdown(); err != nil {
+ ts.tb.Fatal(err)
+ }
+
+ select {
+ case <-ts.ch:
+ case <-time.After(time.Second):
+ ts.tb.Fatalf("timeout when waiting for server close")
+ }
+}
+
+func (ts *testServer) dial() func(addr string) (net.Conn, error) {
+ ts.tb.Helper()
+
+ return func(_ string) (net.Conn, error) {
+ return ts.ln.Dial() //nolint:wrapcheck // not needed
+ }
+}
+
+func createHelperServer(tb testing.TB) (*fiber.App, func(addr string) (net.Conn, error), func()) {
+ tb.Helper()
+
+ ln := fasthttputil.NewInmemoryListener()
+
+ app := fiber.New()
+
+ return app, func(_ string) (net.Conn, error) {
+ return ln.Dial() //nolint:wrapcheck // not needed
+ }, func() {
+ require.NoError(tb, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true}))
+ }
+}
+
+func testRequest(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted string, count ...int) {
+ t.Helper()
+
+ app, ln, start := createHelperServer(t)
+ app.Get("/", handler)
+ go start()
+
+ c := 1
+ if len(count) > 0 {
+ c = count[0]
+ }
+
+ client := NewClient().SetDial(ln)
+
+ for i := 0; i < c; i++ {
+ req := AcquireRequest().SetClient(client)
+ wrapAgent(req)
+
+ resp, err := req.Get("http://example.com")
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, excepted, resp.String())
+ resp.Close()
+ }
+}
+
+func testRequestFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted error, count ...int) {
+ t.Helper()
+
+ app, ln, start := createHelperServer(t)
+ app.Get("/", handler)
+ go start()
+
+ c := 1
+ if len(count) > 0 {
+ c = count[0]
+ }
+
+ client := NewClient().SetDial(ln)
+
+ for i := 0; i < c; i++ {
+ req := AcquireRequest().SetClient(client)
+ wrapAgent(req)
+
+ _, err := req.Get("http://example.com")
+
+ require.Equal(t, excepted.Error(), err.Error())
+ }
+}
+
+func testClient(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Client), excepted string, count ...int) { //nolint: unparam // maybe needed
+ t.Helper()
+
+ app, ln, start := createHelperServer(t)
+ app.Get("/", handler)
+ go start()
+
+ c := 1
+ if len(count) > 0 {
+ c = count[0]
+ }
+
+ for i := 0; i < c; i++ {
+ client := NewClient().SetDial(ln)
+ wrapAgent(client)
+
+ resp, err := client.Get("http://example.com")
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, excepted, resp.String())
+ resp.Close()
+ }
+}
diff --git a/client/hooks.go b/client/hooks.go
new file mode 100644
index 0000000000..0ecc970d53
--- /dev/null
+++ b/client/hooks.go
@@ -0,0 +1,328 @@
+package client
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "math/rand"
+ "mime/multipart"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/gofiber/utils/v2"
+ "github.com/valyala/fasthttp"
+)
+
+var (
+ protocolCheck = regexp.MustCompile(`^https?://.*$`)
+
+ headerAccept = "Accept"
+
+ applicationJSON = "application/json"
+ applicationXML = "application/xml"
+ applicationForm = "application/x-www-form-urlencoded"
+ multipartFormData = "multipart/form-data"
+
+ letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+ letterIdxBits = 6 // 6 bits to represent a letter index
+ letterIdxMask = 1<= 0; {
+ if remain == 0 {
+ cache, remain = src.Int63(), letterIdxMax
+ }
+
+ if idx := int(cache & int64(letterIdxMask)); idx < length {
+ b[i] = letterBytes[idx]
+ i--
+ }
+ cache >>= int64(letterIdxBits)
+ remain--
+ }
+
+ return utils.UnsafeString(b)
+}
+
+// parserRequestURL will set the options for the hostclient
+// and normalize the url.
+// The baseUrl will be merge with request uri.
+// Query params and path params deal in this function.
+func parserRequestURL(c *Client, req *Request) error {
+ splitURL := strings.Split(req.url, "?")
+ // I don't want to judge splitURL length.
+ splitURL = append(splitURL, "")
+
+ // Determine whether to superimpose baseurl based on
+ // whether the URL starts with the protocol
+ uri := splitURL[0]
+ if !protocolCheck.MatchString(uri) {
+ uri = c.baseURL + uri
+ if !protocolCheck.MatchString(uri) {
+ return ErrURLFormat
+ }
+ }
+
+ // set path params
+ req.path.VisitAll(func(key, val string) {
+ uri = strings.ReplaceAll(uri, ":"+key, val)
+ })
+ c.path.VisitAll(func(key, val string) {
+ uri = strings.ReplaceAll(uri, ":"+key, val)
+ })
+
+ // set uri to request and other related setting
+ req.RawRequest.SetRequestURI(uri)
+
+ // merge query params
+ hashSplit := strings.Split(splitURL[1], "#")
+ hashSplit = append(hashSplit, "")
+ args := fasthttp.AcquireArgs()
+ defer func() {
+ fasthttp.ReleaseArgs(args)
+ }()
+
+ args.Parse(hashSplit[0])
+ c.params.VisitAll(func(key, value []byte) {
+ args.AddBytesKV(key, value)
+ })
+ req.params.VisitAll(func(key, value []byte) {
+ args.AddBytesKV(key, value)
+ })
+ req.RawRequest.URI().SetQueryStringBytes(utils.CopyBytes(args.QueryString()))
+ req.RawRequest.URI().SetHash(hashSplit[1])
+
+ return nil
+}
+
+// parserRequestHeader will make request header up.
+// It will merge headers from client and request.
+// Header should be set automatically based on data.
+// User-Agent should be set.
+func parserRequestHeader(c *Client, req *Request) error {
+ // set method
+ req.RawRequest.Header.SetMethod(req.Method())
+ // merge header
+ c.header.VisitAll(func(key, value []byte) {
+ req.RawRequest.Header.AddBytesKV(key, value)
+ })
+
+ req.header.VisitAll(func(key, value []byte) {
+ req.RawRequest.Header.AddBytesKV(key, value)
+ })
+
+ // according to data set content-type
+ switch req.bodyType {
+ case jsonBody:
+ req.RawRequest.Header.SetContentType(applicationJSON)
+ req.RawRequest.Header.Set(headerAccept, applicationJSON)
+ case xmlBody:
+ req.RawRequest.Header.SetContentType(applicationXML)
+ case formBody:
+ req.RawRequest.Header.SetContentType(applicationForm)
+ case filesBody:
+ req.RawRequest.Header.SetContentType(multipartFormData)
+ // set boundary
+ if req.boundary == boundary {
+ req.boundary += randString(16)
+ }
+ req.RawRequest.Header.SetMultipartFormBoundary(req.boundary)
+ default:
+ }
+
+ // set useragent
+ req.RawRequest.Header.SetUserAgent(defaultUserAgent)
+ if c.userAgent != "" {
+ req.RawRequest.Header.SetUserAgent(c.userAgent)
+ }
+ if req.userAgent != "" {
+ req.RawRequest.Header.SetUserAgent(req.userAgent)
+ }
+
+ // set referer
+ req.RawRequest.Header.SetReferer(c.referer)
+ if req.referer != "" {
+ req.RawRequest.Header.SetReferer(req.referer)
+ }
+
+ // set cookie
+ // add cookie form jar to req
+ if c.cookieJar != nil {
+ c.cookieJar.dumpCookiesToReq(req.RawRequest)
+ }
+
+ c.cookies.VisitAll(func(key, val string) {
+ req.RawRequest.Header.SetCookie(key, val)
+ })
+
+ req.cookies.VisitAll(func(key, val string) {
+ req.RawRequest.Header.SetCookie(key, val)
+ })
+
+ return nil
+}
+
+// parserRequestBody automatically serializes the data according to
+// the data type and stores it in the body of the rawRequest
+func parserRequestBody(c *Client, req *Request) error {
+ switch req.bodyType {
+ case jsonBody:
+ body, err := c.jsonMarshal(req.body)
+ if err != nil {
+ return err
+ }
+ req.RawRequest.SetBody(body)
+ case xmlBody:
+ body, err := c.xmlMarshal(req.body)
+ if err != nil {
+ return err
+ }
+ req.RawRequest.SetBody(body)
+ case formBody:
+ req.RawRequest.SetBody(req.formData.QueryString())
+ case filesBody:
+ return parserRequestBodyFile(req)
+ case rawBody:
+ if body, ok := req.body.([]byte); ok {
+ req.RawRequest.SetBody(body)
+ } else {
+ return ErrBodyType
+ }
+ case noBody:
+ return nil
+ }
+
+ return nil
+}
+
+// parserRequestBodyFile parses request body if body type is file
+// this is an addition of parserRequestBody.
+func parserRequestBodyFile(req *Request) error {
+ mw := multipart.NewWriter(req.RawRequest.BodyWriter())
+ err := mw.SetBoundary(req.boundary)
+ if err != nil {
+ return fmt.Errorf("set boundary error: %w", err)
+ }
+ defer func() {
+ err := mw.Close()
+ if err != nil {
+ return
+ }
+ }()
+
+ // add formdata
+ req.formData.VisitAll(func(key, value []byte) {
+ if err != nil {
+ return
+ }
+ err = mw.WriteField(utils.UnsafeString(key), utils.UnsafeString(value))
+ })
+ if err != nil {
+ return fmt.Errorf("write formdata error: %w", err)
+ }
+
+ // add file
+ b := make([]byte, 512)
+ for i, v := range req.files {
+ if v.name == "" && v.path == "" {
+ return ErrFileNoName
+ }
+
+ // if name is not exist, set name
+ if v.name == "" && v.path != "" {
+ v.path = filepath.Clean(v.path)
+ v.name = filepath.Base(v.path)
+ }
+
+ // if field name is not exist, set it
+ if v.fieldName == "" {
+ v.fieldName = "file" + strconv.Itoa(i+1)
+ }
+
+ // check the reader
+ if v.reader == nil {
+ v.reader, err = os.Open(v.path)
+ if err != nil {
+ return fmt.Errorf("open file error: %w", err)
+ }
+ }
+
+ // write file
+ w, err := mw.CreateFormFile(v.fieldName, v.name)
+ if err != nil {
+ return fmt.Errorf("create file error: %w", err)
+ }
+
+ for {
+ n, err := v.reader.Read(b)
+ if err != nil && !errors.Is(err, io.EOF) {
+ return fmt.Errorf("read file error: %w", err)
+ }
+
+ if errors.Is(err, io.EOF) {
+ break
+ }
+
+ _, err = w.Write(b[:n])
+ if err != nil {
+ return fmt.Errorf("write file error: %w", err)
+ }
+ }
+
+ err = v.reader.Close()
+ if err != nil {
+ return fmt.Errorf("close file error: %w", err)
+ }
+ }
+
+ return nil
+}
+
+// parserResponseHeader will parse the response header and store it in the response
+func parserResponseCookie(c *Client, resp *Response, req *Request) error {
+ var err error
+ resp.RawResponse.Header.VisitAllCookie(func(key, value []byte) {
+ cookie := fasthttp.AcquireCookie()
+ err = cookie.ParseBytes(value)
+ if err != nil {
+ return
+ }
+ cookie.SetKeyBytes(key)
+
+ resp.cookie = append(resp.cookie, cookie)
+ })
+
+ if err != nil {
+ return err
+ }
+
+ // store cookies to jar
+ if c.cookieJar != nil {
+ c.cookieJar.parseCookiesFromResp(req.RawRequest.URI().Host(), req.RawRequest.URI().Path(), resp.RawResponse)
+ }
+
+ return nil
+}
+
+// logger is a response hook that logs the request and response
+func logger(c *Client, resp *Response, req *Request) error {
+ if !c.debug {
+ return nil
+ }
+
+ c.logger.Debugf("%s\n", req.RawRequest.String())
+ c.logger.Debugf("%s\n", resp.RawResponse.String())
+
+ return nil
+}
diff --git a/client/hooks_test.go b/client/hooks_test.go
new file mode 100644
index 0000000000..a555bba833
--- /dev/null
+++ b/client/hooks_test.go
@@ -0,0 +1,652 @@
+package client
+
+import (
+ "bytes"
+ "encoding/xml"
+ "fmt"
+ "io"
+ "net"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/gofiber/fiber/v3"
+ "github.com/stretchr/testify/require"
+)
+
+func Test_Rand_String(t *testing.T) {
+ t.Parallel()
+ tests := []struct {
+ name string
+ args int
+ }{
+ {
+ name: "test generate",
+ args: 16,
+ },
+ {
+ name: "test generate smaller string",
+ args: 8,
+ },
+ {
+ name: "test generate larger string",
+ args: 32,
+ },
+ }
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got := randString(tt.args)
+ require.Len(t, got, tt.args)
+ })
+ }
+}
+
+func Test_Parser_Request_URL(t *testing.T) {
+ t.Parallel()
+
+ t.Run("client baseurl should be set", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().SetBaseURL("http://example.com/api")
+ req := AcquireRequest().SetURL("")
+
+ err := parserRequestURL(client, req)
+ require.NoError(t, err)
+ require.Equal(t, "http://example.com/api", req.RawRequest.URI().String())
+ })
+
+ t.Run("request url should be set", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient()
+ req := AcquireRequest().SetURL("http://example.com/api")
+
+ err := parserRequestURL(client, req)
+ require.NoError(t, err)
+ require.Equal(t, "http://example.com/api", req.RawRequest.URI().String())
+ })
+
+ t.Run("the request url will override baseurl with protocol", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().SetBaseURL("http://example.com/api")
+ req := AcquireRequest().SetURL("http://example.com/api/v1")
+
+ err := parserRequestURL(client, req)
+ require.NoError(t, err)
+ require.Equal(t, "http://example.com/api/v1", req.RawRequest.URI().String())
+ })
+
+ t.Run("the request url should be append after baseurl without protocol", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().SetBaseURL("http://example.com/api")
+ req := AcquireRequest().SetURL("/v1")
+
+ err := parserRequestURL(client, req)
+ require.NoError(t, err)
+ require.Equal(t, "http://example.com/api/v1", req.RawRequest.URI().String())
+ })
+
+ t.Run("the url is error", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().SetBaseURL("example.com/api")
+ req := AcquireRequest().SetURL("/v1")
+
+ err := parserRequestURL(client, req)
+ require.Equal(t, ErrURLFormat, err)
+ })
+
+ t.Run("the path param from client", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().
+ SetBaseURL("http://example.com/api/:id").
+ SetPathParam("id", "5")
+ req := AcquireRequest()
+
+ err := parserRequestURL(client, req)
+ require.NoError(t, err)
+ require.Equal(t, "http://example.com/api/5", req.RawRequest.URI().String())
+ })
+
+ t.Run("the path param from request", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().
+ SetBaseURL("http://example.com/api/:id/:name").
+ SetPathParam("id", "5")
+ req := AcquireRequest().
+ SetURL("/{key}").
+ SetPathParams(map[string]string{
+ "name": "fiber",
+ "key": "val",
+ }).
+ DelPathParams("key")
+
+ err := parserRequestURL(client, req)
+ require.NoError(t, err)
+ require.Equal(t, "http://example.com/api/5/fiber/%7Bkey%7D", req.RawRequest.URI().String())
+ })
+
+ t.Run("the path param from request and client", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().
+ SetBaseURL("http://example.com/api/:id/:name").
+ SetPathParam("id", "5")
+ req := AcquireRequest().
+ SetURL("/:key").
+ SetPathParams(map[string]string{
+ "name": "fiber",
+ "key": "val",
+ "id": "12",
+ })
+
+ err := parserRequestURL(client, req)
+ require.NoError(t, err)
+ require.Equal(t, "http://example.com/api/12/fiber/val", req.RawRequest.URI().String())
+ })
+
+ t.Run("query params from client should be set", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().
+ SetParam("foo", "bar")
+ req := AcquireRequest().SetURL("http://example.com/api/v1")
+
+ err := parserRequestURL(client, req)
+ require.NoError(t, err)
+ require.Equal(t, []byte("foo=bar"), req.RawRequest.URI().QueryString())
+ })
+
+ t.Run("query params from request should be set", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient()
+ req := AcquireRequest().
+ SetURL("http://example.com/api/v1").
+ SetParam("bar", "foo")
+
+ err := parserRequestURL(client, req)
+ require.NoError(t, err)
+ require.Equal(t, []byte("bar=foo"), req.RawRequest.URI().QueryString())
+ })
+
+ t.Run("query params should be merged", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().
+ SetParam("bar", "foo1")
+ req := AcquireRequest().
+ SetURL("http://example.com/api/v1?bar=foo2").
+ SetParam("bar", "foo")
+
+ err := parserRequestURL(client, req)
+ require.NoError(t, err)
+
+ values, err := url.ParseQuery(string(req.RawRequest.URI().QueryString()))
+ require.NoError(t, err)
+
+ flag1, flag2, flag3 := false, false, false
+ for _, v := range values["bar"] {
+ if v == "foo1" {
+ flag1 = true
+ } else if v == "foo2" {
+ flag2 = true
+ } else if v == "foo" {
+ flag3 = true
+ }
+ }
+ require.True(t, flag1)
+ require.True(t, flag2)
+ require.True(t, flag3)
+ })
+}
+
+func Test_Parser_Request_Header(t *testing.T) {
+ t.Parallel()
+
+ t.Run("client header should be set", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().
+ SetHeaders(map[string]string{
+ fiber.HeaderContentType: "application/json",
+ })
+
+ req := AcquireRequest()
+
+ err := parserRequestHeader(client, req)
+ require.NoError(t, err)
+ require.Equal(t, []byte("application/json"), req.RawRequest.Header.ContentType())
+ })
+
+ t.Run("request header should be set", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient()
+
+ req := AcquireRequest().
+ SetHeaders(map[string]string{
+ fiber.HeaderContentType: "application/json, utf-8",
+ })
+
+ err := parserRequestHeader(client, req)
+ require.NoError(t, err)
+ require.Equal(t, []byte("application/json, utf-8"), req.RawRequest.Header.ContentType())
+ })
+
+ t.Run("request header should override client header", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().
+ SetHeader(fiber.HeaderContentType, "application/xml")
+
+ req := AcquireRequest().
+ SetHeader(fiber.HeaderContentType, "application/json, utf-8")
+
+ err := parserRequestHeader(client, req)
+ require.NoError(t, err)
+ require.Equal(t, []byte("application/json, utf-8"), req.RawRequest.Header.ContentType())
+ })
+
+ t.Run("auto set json header", func(t *testing.T) {
+ t.Parallel()
+ type jsonData struct {
+ Name string `json:"name"`
+ }
+ client := NewClient()
+ req := AcquireRequest().
+ SetJSON(jsonData{
+ Name: "foo",
+ })
+
+ err := parserRequestHeader(client, req)
+ require.NoError(t, err)
+ require.Equal(t, []byte(applicationJSON), req.RawRequest.Header.ContentType())
+ })
+
+ t.Run("auto set xml header", func(t *testing.T) {
+ t.Parallel()
+ type xmlData struct {
+ XMLName xml.Name `xml:"body"`
+ Name string `xml:"name"`
+ }
+ client := NewClient()
+ req := AcquireRequest().
+ SetXML(xmlData{
+ Name: "foo",
+ })
+
+ err := parserRequestHeader(client, req)
+ require.NoError(t, err)
+ require.Equal(t, []byte(applicationXML), req.RawRequest.Header.ContentType())
+ })
+
+ t.Run("auto set form data header", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient()
+ req := AcquireRequest().
+ SetFormDatas(map[string]string{
+ "foo": "bar",
+ "ball": "cricle and square",
+ })
+
+ err := parserRequestHeader(client, req)
+ require.NoError(t, err)
+ require.Equal(t, applicationForm, string(req.RawRequest.Header.ContentType()))
+ })
+
+ t.Run("auto set file header", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient()
+ req := AcquireRequest().
+ AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))).
+ SetFormData("foo", "bar")
+
+ err := parserRequestHeader(client, req)
+ require.NoError(t, err)
+ require.True(t, strings.Contains(string(req.RawRequest.Header.MultipartFormBoundary()), "--FiberFormBoundary"))
+ require.True(t, strings.Contains(string(req.RawRequest.Header.ContentType()), multipartFormData))
+ })
+
+ t.Run("ua should have default value", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient()
+ req := AcquireRequest()
+
+ err := parserRequestHeader(client, req)
+ require.NoError(t, err)
+ require.Equal(t, []byte("fiber"), req.RawRequest.Header.UserAgent())
+ })
+
+ t.Run("ua in client should be set", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().SetUserAgent("foo")
+ req := AcquireRequest()
+
+ err := parserRequestHeader(client, req)
+ require.NoError(t, err)
+ require.Equal(t, []byte("foo"), req.RawRequest.Header.UserAgent())
+ })
+
+ t.Run("ua in request should have higher level", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().SetUserAgent("foo")
+ req := AcquireRequest().SetUserAgent("bar")
+
+ err := parserRequestHeader(client, req)
+ require.NoError(t, err)
+ require.Equal(t, []byte("bar"), req.RawRequest.Header.UserAgent())
+ })
+
+ t.Run("referer in client should be set", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().SetReferer("https://example.com")
+ req := AcquireRequest()
+
+ err := parserRequestHeader(client, req)
+ require.NoError(t, err)
+ require.Equal(t, []byte("https://example.com"), req.RawRequest.Header.Referer())
+ })
+
+ t.Run("referer in request should have higher level", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().SetReferer("http://example.com")
+ req := AcquireRequest().SetReferer("https://example.com")
+
+ err := parserRequestHeader(client, req)
+ require.NoError(t, err)
+ require.Equal(t, []byte("https://example.com"), req.RawRequest.Header.Referer())
+ })
+
+ t.Run("client cookie should be set", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient().
+ SetCookie("foo", "bar").
+ SetCookies(map[string]string{
+ "bar": "foo",
+ "bar1": "foo1",
+ }).
+ DelCookies("bar1")
+
+ req := AcquireRequest()
+
+ err := parserRequestHeader(client, req)
+ require.NoError(t, err)
+ require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo")))
+ require.Equal(t, "foo", string(req.RawRequest.Header.Cookie("bar")))
+ require.Equal(t, "", string(req.RawRequest.Header.Cookie("bar1")))
+ })
+
+ t.Run("request cookie should be set", func(t *testing.T) {
+ t.Parallel()
+ type cookies struct {
+ Foo string `cookie:"foo"`
+ Bar int `cookie:"bar"`
+ }
+
+ client := NewClient()
+
+ req := AcquireRequest().
+ SetCookiesWithStruct(&cookies{
+ Foo: "bar",
+ Bar: 67,
+ })
+
+ err := parserRequestHeader(client, req)
+ require.NoError(t, err)
+ require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo")))
+ require.Equal(t, "67", string(req.RawRequest.Header.Cookie("bar")))
+ require.Equal(t, "", string(req.RawRequest.Header.Cookie("bar1")))
+ })
+
+ t.Run("request cookie will override client cookie", func(t *testing.T) {
+ t.Parallel()
+ type cookies struct {
+ Foo string `cookie:"foo"`
+ Bar int `cookie:"bar"`
+ }
+
+ client := NewClient().
+ SetCookie("foo", "bar").
+ SetCookies(map[string]string{
+ "bar": "foo",
+ "bar1": "foo1",
+ })
+
+ req := AcquireRequest().
+ SetCookiesWithStruct(&cookies{
+ Foo: "bar",
+ Bar: 67,
+ })
+
+ err := parserRequestHeader(client, req)
+ require.NoError(t, err)
+ require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo")))
+ require.Equal(t, "67", string(req.RawRequest.Header.Cookie("bar")))
+ require.Equal(t, "foo1", string(req.RawRequest.Header.Cookie("bar1")))
+ })
+}
+
+func Test_Parser_Request_Body(t *testing.T) {
+ t.Parallel()
+
+ t.Run("json body", func(t *testing.T) {
+ t.Parallel()
+ type jsonData struct {
+ Name string `json:"name"`
+ }
+ client := NewClient()
+ req := AcquireRequest().
+ SetJSON(jsonData{
+ Name: "foo",
+ })
+
+ err := parserRequestBody(client, req)
+ require.NoError(t, err)
+ require.Equal(t, []byte("{\"name\":\"foo\"}"), req.RawRequest.Body())
+ })
+
+ t.Run("xml body", func(t *testing.T) {
+ t.Parallel()
+ type xmlData struct {
+ XMLName xml.Name `xml:"body"`
+ Name string `xml:"name"`
+ }
+ client := NewClient()
+ req := AcquireRequest().
+ SetXML(xmlData{
+ Name: "foo",
+ })
+
+ err := parserRequestBody(client, req)
+ require.NoError(t, err)
+ require.Equal(t, []byte("foo"), req.RawRequest.Body())
+ })
+
+ t.Run("form data body", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient()
+ req := AcquireRequest().
+ SetFormDatas(map[string]string{
+ "ball": "cricle and square",
+ })
+
+ err := parserRequestBody(client, req)
+ require.NoError(t, err)
+ require.Equal(t, "ball=cricle+and+square", string(req.RawRequest.Body()))
+ })
+
+ t.Run("form data body error", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient()
+ req := AcquireRequest().
+ SetFormDatas(map[string]string{
+ "": "",
+ })
+
+ err := parserRequestBody(client, req)
+ require.NoError(t, err)
+ })
+
+ t.Run("file body", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient()
+ req := AcquireRequest().
+ AddFileWithReader("hello", io.NopCloser(strings.NewReader("world")))
+
+ err := parserRequestBody(client, req)
+ require.NoError(t, err)
+ require.True(t, strings.Contains(string(req.RawRequest.Body()), "----FiberFormBoundary"))
+ require.True(t, strings.Contains(string(req.RawRequest.Body()), "world"))
+ })
+
+ t.Run("file and form data", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient()
+ req := AcquireRequest().
+ AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))).
+ SetFormData("foo", "bar")
+
+ err := parserRequestBody(client, req)
+ require.NoError(t, err)
+ require.True(t, strings.Contains(string(req.RawRequest.Body()), "----FiberFormBoundary"))
+ require.True(t, strings.Contains(string(req.RawRequest.Body()), "world"))
+ require.True(t, strings.Contains(string(req.RawRequest.Body()), "bar"))
+ })
+
+ t.Run("raw body", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient()
+ req := AcquireRequest().
+ SetRawBody([]byte("hello world"))
+
+ err := parserRequestBody(client, req)
+ require.NoError(t, err)
+ require.Equal(t, []byte("hello world"), req.RawRequest.Body())
+ })
+
+ t.Run("raw body error", func(t *testing.T) {
+ t.Parallel()
+ client := NewClient()
+ req := AcquireRequest().
+ SetRawBody([]byte("hello world"))
+
+ req.body = nil
+
+ err := parserRequestBody(client, req)
+ require.ErrorIs(t, err, ErrBodyType)
+ })
+}
+
+type dummyLogger struct {
+ buf *bytes.Buffer
+}
+
+func (*dummyLogger) Trace(_ ...any) {}
+
+func (*dummyLogger) Debug(_ ...any) {}
+
+func (*dummyLogger) Info(_ ...any) {}
+
+func (*dummyLogger) Warn(_ ...any) {}
+
+func (*dummyLogger) Error(_ ...any) {}
+
+func (*dummyLogger) Fatal(_ ...any) {}
+
+func (*dummyLogger) Panic(_ ...any) {}
+
+func (*dummyLogger) Tracef(_ string, _ ...any) {}
+
+func (l *dummyLogger) Debugf(format string, v ...any) {
+ _, _ = l.buf.WriteString(fmt.Sprintf(format, v...)) //nolint:errcheck // not needed
+}
+
+func (*dummyLogger) Infof(_ string, _ ...any) {}
+
+func (*dummyLogger) Warnf(_ string, _ ...any) {}
+
+func (*dummyLogger) Errorf(_ string, _ ...any) {}
+
+func (*dummyLogger) Fatalf(_ string, _ ...any) {}
+
+func (*dummyLogger) Panicf(_ string, _ ...any) {}
+
+func (*dummyLogger) Tracew(_ string, _ ...any) {}
+
+func (*dummyLogger) Debugw(_ string, _ ...any) {}
+
+func (*dummyLogger) Infow(_ string, _ ...any) {}
+
+func (*dummyLogger) Warnw(_ string, _ ...any) {}
+
+func (*dummyLogger) Errorw(_ string, _ ...any) {}
+
+func (*dummyLogger) Fatalw(_ string, _ ...any) {}
+
+func (*dummyLogger) Panicw(_ string, _ ...any) {}
+
+func Test_Client_Logger_Debug(t *testing.T) {
+ t.Parallel()
+ app := fiber.New()
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.SendString("response")
+ })
+
+ addrChan := make(chan string)
+ go func() {
+ require.NoError(t, app.Listen(":0", fiber.ListenConfig{
+ DisableStartupMessage: true,
+ ListenerAddrFunc: func(addr net.Addr) {
+ addrChan <- addr.String()
+ },
+ }))
+ }()
+
+ defer func(app *fiber.App) {
+ require.NoError(t, app.Shutdown())
+ }(app)
+
+ var buf bytes.Buffer
+ logger := &dummyLogger{buf: &buf}
+
+ client := NewClient()
+ client.Debug().SetLogger(logger)
+
+ addr := <-addrChan
+ resp, err := client.Get("http://" + addr)
+ require.NoError(t, err)
+ defer resp.Close()
+
+ require.NoError(t, err)
+ require.Contains(t, buf.String(), "Host: "+addr)
+ require.Contains(t, buf.String(), "Content-Length: 8")
+}
+
+func Test_Client_Logger_DisableDebug(t *testing.T) {
+ t.Parallel()
+ app := fiber.New()
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.SendString("response")
+ })
+
+ addrChan := make(chan string)
+ go func() {
+ require.NoError(t, app.Listen(":0", fiber.ListenConfig{
+ DisableStartupMessage: true,
+ ListenerAddrFunc: func(addr net.Addr) {
+ addrChan <- addr.String()
+ },
+ }))
+ }()
+
+ defer func(app *fiber.App) {
+ require.NoError(t, app.Shutdown())
+ }(app)
+
+ var buf bytes.Buffer
+ logger := &dummyLogger{buf: &buf}
+
+ client := NewClient()
+ client.DisableDebug().SetLogger(logger)
+
+ addr := <-addrChan
+ resp, err := client.Get("http://" + addr)
+ require.NoError(t, err)
+ defer resp.Close()
+
+ require.NoError(t, err)
+ require.Empty(t, buf.String())
+}
diff --git a/client/request.go b/client/request.go
new file mode 100644
index 0000000000..0bf2fb321c
--- /dev/null
+++ b/client/request.go
@@ -0,0 +1,985 @@
+package client
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "io"
+ "path/filepath"
+ "reflect"
+ "strconv"
+ "sync"
+ "time"
+
+ "github.com/gofiber/fiber/v3"
+ "github.com/gofiber/utils/v2"
+ "github.com/valyala/fasthttp"
+)
+
+// WithStruct Implementing this interface allows data to
+// be stored from a struct via reflect.
+type WithStruct interface {
+ Add(name, obj string)
+ Del(name string)
+}
+
+// Types of request bodies.
+type bodyType int
+
+// Enumeration definition of the request body type.
+const (
+ noBody bodyType = iota
+ jsonBody
+ xmlBody
+ formBody
+ filesBody
+ rawBody
+)
+
+var ErrClientNil = errors.New("client can not be nil")
+
+// Request is a struct which contains the request data.
+type Request struct {
+ url string
+ method string
+ userAgent string
+ boundary string
+ referer string
+ ctx context.Context //nolint:containedctx // It's needed to be stored in the request.
+ header *Header
+ params *QueryParam
+ cookies *Cookie
+ path *PathParam
+
+ timeout time.Duration
+ maxRedirects int
+
+ client *Client
+
+ body any
+ formData *FormData
+ files []*File
+ bodyType bodyType
+
+ RawRequest *fasthttp.Request
+}
+
+// Method returns http method in request.
+func (r *Request) Method() string {
+ return r.method
+}
+
+// SetMethod will set method for Request object,
+// user should use request method to set method.
+func (r *Request) SetMethod(method string) *Request {
+ r.method = method
+ return r
+}
+
+// URL returns request url in Request instance.
+func (r *Request) URL() string {
+ return r.url
+}
+
+// SetURL will set url for Request object.
+func (r *Request) SetURL(url string) *Request {
+ r.url = url
+ return r
+}
+
+// Client get Client instance in Request.
+func (r *Request) Client() *Client {
+ return r.client
+}
+
+// SetClient method sets client in request instance.
+func (r *Request) SetClient(c *Client) *Request {
+ if c == nil {
+ panic(ErrClientNil)
+ }
+
+ r.client = c
+ return r
+}
+
+// Context returns the Context if its already set in request
+// otherwise it creates new one using `context.Background()`.
+func (r *Request) Context() context.Context {
+ if r.ctx == nil {
+ return context.Background()
+ }
+ return r.ctx
+}
+
+// SetContext sets the context.Context for current Request. It allows
+// to interrupt the request execution if ctx.Done() channel is closed.
+// See https://blog.golang.org/context article and the "context" package
+// documentation.
+func (r *Request) SetContext(ctx context.Context) *Request {
+ r.ctx = ctx
+ return r
+}
+
+// Header method returns header value via key,
+// this method will visit all field in the header,
+// then sort them.
+func (r *Request) Header(key string) []string {
+ return r.header.PeekMultiple(key)
+}
+
+// AddHeader method adds a single header field and its value in the request instance.
+// It will override header which set in client instance.
+func (r *Request) AddHeader(key, val string) *Request {
+ r.header.Add(key, val)
+ return r
+}
+
+// SetHeader method sets a single header field and its value in the request instance.
+// It will override header which set in client instance.
+func (r *Request) SetHeader(key, val string) *Request {
+ r.header.Del(key)
+ r.header.Set(key, val)
+ return r
+}
+
+// AddHeaders method adds multiple header fields and its values at one go in the request instance.
+// It will override header which set in client instance.
+func (r *Request) AddHeaders(h map[string][]string) *Request {
+ r.header.AddHeaders(h)
+ return r
+}
+
+// SetHeaders method sets multiple header fields and its values at one go in the request instance.
+// It will override header which set in client instance.
+func (r *Request) SetHeaders(h map[string]string) *Request {
+ r.header.SetHeaders(h)
+ return r
+}
+
+// Param method returns params value via key,
+// this method will visit all field in the query param.
+func (r *Request) Param(key string) []string {
+ var res []string
+ tmp := r.params.PeekMulti(key)
+ for _, v := range tmp {
+ res = append(res, utils.UnsafeString(v))
+ }
+
+ return res
+}
+
+// AddParam method adds a single param field and its value in the request instance.
+// It will override param which set in client instance.
+func (r *Request) AddParam(key, val string) *Request {
+ r.params.Add(key, val)
+ return r
+}
+
+// SetParam method sets a single param field and its value in the request instance.
+// It will override param which set in client instance.
+func (r *Request) SetParam(key, val string) *Request {
+ r.params.Set(key, val)
+ return r
+}
+
+// AddParams method adds multiple param fields and its values at one go in the request instance.
+// It will override param which set in client instance.
+func (r *Request) AddParams(m map[string][]string) *Request {
+ r.params.AddParams(m)
+ return r
+}
+
+// SetParams method sets multiple param fields and its values at one go in the request instance.
+// It will override param which set in client instance.
+func (r *Request) SetParams(m map[string]string) *Request {
+ r.params.SetParams(m)
+ return r
+}
+
+// SetParamsWithStruct method sets multiple param fields and its values at one go in the request instance.
+// It will override param which set in client instance.
+func (r *Request) SetParamsWithStruct(v any) *Request {
+ r.params.SetParamsWithStruct(v)
+ return r
+}
+
+// DelParams method deletes single or multiple param fields ant its values.
+func (r *Request) DelParams(key ...string) *Request {
+ for _, v := range key {
+ r.params.Del(v)
+ }
+ return r
+}
+
+// UserAgent returns user agent in request instance.
+func (r *Request) UserAgent() string {
+ return r.userAgent
+}
+
+// SetUserAgent method sets user agent in request.
+// It will override user agent which set in client instance.
+func (r *Request) SetUserAgent(ua string) *Request {
+ r.userAgent = ua
+ return r
+}
+
+// Boundary returns boundary in multipart boundary.
+func (r *Request) Boundary() string {
+ return r.boundary
+}
+
+// SetBoundary method sets multipart boundary.
+func (r *Request) SetBoundary(b string) *Request {
+ r.boundary = b
+
+ return r
+}
+
+// Referer returns referer in request instance.
+func (r *Request) Referer() string {
+ return r.referer
+}
+
+// SetReferer method sets referer in request.
+// It will override referer which set in client instance.
+func (r *Request) SetReferer(referer string) *Request {
+ r.referer = referer
+ return r
+}
+
+// Cookie returns the cookie be set in request instance.
+// if cookie doesn't exist, return empty string.
+func (r *Request) Cookie(key string) string {
+ if val, ok := (*r.cookies)[key]; ok {
+ return val
+ }
+ return ""
+}
+
+// SetCookie method sets a single cookie field and its value in the request instance.
+// It will override cookie which set in client instance.
+func (r *Request) SetCookie(key, val string) *Request {
+ r.cookies.SetCookie(key, val)
+ return r
+}
+
+// SetCookies method sets multiple cookie fields and its values at one go in the request instance.
+// It will override cookie which set in client instance.
+func (r *Request) SetCookies(m map[string]string) *Request {
+ r.cookies.SetCookies(m)
+ return r
+}
+
+// SetCookiesWithStruct method sets multiple cookie fields and its values at one go in the request instance.
+// It will override cookie which set in client instance.
+func (r *Request) SetCookiesWithStruct(v any) *Request {
+ r.cookies.SetCookiesWithStruct(v)
+ return r
+}
+
+// DelCookies method deletes single or multiple cookie fields ant its values.
+func (r *Request) DelCookies(key ...string) *Request {
+ r.cookies.DelCookies(key...)
+ return r
+}
+
+// PathParam returns the path param be set in request instance.
+// if path param doesn't exist, return empty string.
+func (r *Request) PathParam(key string) string {
+ if val, ok := (*r.path)[key]; ok {
+ return val
+ }
+
+ return ""
+}
+
+// SetPathParam method sets a single path param field and its value in the request instance.
+// It will override path param which set in client instance.
+func (r *Request) SetPathParam(key, val string) *Request {
+ r.path.SetParam(key, val)
+ return r
+}
+
+// SetPathParams method sets multiple path param fields and its values at one go in the request instance.
+// It will override path param which set in client instance.
+func (r *Request) SetPathParams(m map[string]string) *Request {
+ r.path.SetParams(m)
+ return r
+}
+
+// SetPathParamsWithStruct method sets multiple path param fields and its values at one go in the request instance.
+// It will override path param which set in client instance.
+func (r *Request) SetPathParamsWithStruct(v any) *Request {
+ r.path.SetParamsWithStruct(v)
+ return r
+}
+
+// DelPathParams method deletes single or multiple path param fields ant its values.
+func (r *Request) DelPathParams(key ...string) *Request {
+ r.path.DelParams(key...)
+ return r
+}
+
+// ResetPathParams deletes all path params.
+func (r *Request) ResetPathParams() *Request {
+ r.path.Reset()
+ return r
+}
+
+// SetJSON method sets json body in request.
+func (r *Request) SetJSON(v any) *Request {
+ r.body = v
+ r.bodyType = jsonBody
+ return r
+}
+
+// SetXML method sets xml body in request.
+func (r *Request) SetXML(v any) *Request {
+ r.body = v
+ r.bodyType = xmlBody
+ return r
+}
+
+// SetRawBody method sets body with raw data in request.
+func (r *Request) SetRawBody(v []byte) *Request {
+ r.body = v
+ r.bodyType = rawBody
+ return r
+}
+
+// resetBody will clear body object and set bodyType
+// if body type is formBody and filesBody, the new body type will be ignored.
+func (r *Request) resetBody(t bodyType) {
+ r.body = nil
+
+ // Set form data after set file ignore.
+ if r.bodyType == filesBody && t == formBody {
+ return
+ }
+ r.bodyType = t
+}
+
+// FormData method returns form data value via key,
+// this method will visit all field in the form data.
+func (r *Request) FormData(key string) []string {
+ var res []string
+ tmp := r.formData.PeekMulti(key)
+ for _, v := range tmp {
+ res = append(res, utils.UnsafeString(v))
+ }
+
+ return res
+}
+
+// AddFormData method adds a single form data field and its value in the request instance.
+func (r *Request) AddFormData(key, val string) *Request {
+ r.formData.AddData(key, val)
+ r.resetBody(formBody)
+ return r
+}
+
+// SetFormData method sets a single form data field and its value in the request instance.
+func (r *Request) SetFormData(key, val string) *Request {
+ r.formData.SetData(key, val)
+ r.resetBody(formBody)
+ return r
+}
+
+// AddFormDatas method adds multiple form data fields and its values in the request instance.
+func (r *Request) AddFormDatas(m map[string][]string) *Request {
+ r.formData.AddDatas(m)
+ r.resetBody(formBody)
+ return r
+}
+
+// SetFormDatas method sets multiple form data fields and its values in the request instance.
+func (r *Request) SetFormDatas(m map[string]string) *Request {
+ r.formData.SetDatas(m)
+ r.resetBody(formBody)
+ return r
+}
+
+// SetFormDatasWithStruct method sets multiple form data fields
+// and its values in the request instance via struct.
+func (r *Request) SetFormDatasWithStruct(v any) *Request {
+ r.formData.SetDatasWithStruct(v)
+ r.resetBody(formBody)
+ return r
+}
+
+// DelFormDatas method deletes multiple form data fields and its value in the request instance.
+func (r *Request) DelFormDatas(key ...string) *Request {
+ r.formData.DelDatas(key...)
+ r.resetBody(formBody)
+ return r
+}
+
+// File returns file ptr store in request obj by name.
+// If name field is empty, it will try to match path.
+func (r *Request) File(name string) *File {
+ for _, v := range r.files {
+ if v.name == "" {
+ if filepath.Base(v.path) == name {
+ return v
+ }
+ } else if v.name == name {
+ return v
+ }
+ }
+
+ return nil
+}
+
+// FileByPath returns file ptr store in request obj by path.
+func (r *Request) FileByPath(path string) *File {
+ for _, v := range r.files {
+ if v.path == path {
+ return v
+ }
+ }
+
+ return nil
+}
+
+// AddFile method adds single file field
+// and its value in the request instance via file path.
+func (r *Request) AddFile(path string) *Request {
+ r.files = append(r.files, AcquireFile(SetFilePath(path)))
+ r.resetBody(filesBody)
+ return r
+}
+
+// AddFileWithReader method adds single field
+// and its value in the request instance via reader.
+func (r *Request) AddFileWithReader(name string, reader io.ReadCloser) *Request {
+ r.files = append(r.files, AcquireFile(SetFileName(name), SetFileReader(reader)))
+ r.resetBody(filesBody)
+ return r
+}
+
+// AddFiles method adds multiple file fields
+// and its value in the request instance via File instance.
+func (r *Request) AddFiles(files ...*File) *Request {
+ r.files = append(r.files, files...)
+ r.resetBody(filesBody)
+ return r
+}
+
+// Timeout returns the length of timeout in request.
+func (r *Request) Timeout() time.Duration {
+ return r.timeout
+}
+
+// SetTimeout method sets timeout field and its values at one go in the request instance.
+// It will override timeout which set in client instance.
+func (r *Request) SetTimeout(t time.Duration) *Request {
+ r.timeout = t
+ return r
+}
+
+// MaxRedirects returns the max redirects count in request.
+func (r *Request) MaxRedirects() int {
+ return r.maxRedirects
+}
+
+// SetMaxRedirects method sets the maximum number of redirects at one go in the request instance.
+// It will override max redirect which set in client instance.
+func (r *Request) SetMaxRedirects(count int) *Request {
+ r.maxRedirects = count
+ return r
+}
+
+// checkClient method checks whether the client has been set in request.
+func (r *Request) checkClient() {
+ if r.client == nil {
+ r.SetClient(defaultClient)
+ }
+}
+
+// Get Send get request.
+func (r *Request) Get(url string) (*Response, error) {
+ return r.SetURL(url).SetMethod(fiber.MethodGet).Send()
+}
+
+// Post Send post request.
+func (r *Request) Post(url string) (*Response, error) {
+ return r.SetURL(url).SetMethod(fiber.MethodPost).Send()
+}
+
+// Head Send head request.
+func (r *Request) Head(url string) (*Response, error) {
+ return r.SetURL(url).SetMethod(fiber.MethodHead).Send()
+}
+
+// Put Send put request.
+func (r *Request) Put(url string) (*Response, error) {
+ return r.SetURL(url).SetMethod(fiber.MethodPut).Send()
+}
+
+// Delete Send Delete request.
+func (r *Request) Delete(url string) (*Response, error) {
+ return r.SetURL(url).SetMethod(fiber.MethodDelete).Send()
+}
+
+// Options Send Options request.
+func (r *Request) Options(url string) (*Response, error) {
+ return r.SetURL(url).SetMethod(fiber.MethodOptions).Send()
+}
+
+// Patch Send patch request.
+func (r *Request) Patch(url string) (*Response, error) {
+ return r.SetURL(url).SetMethod(fiber.MethodPatch).Send()
+}
+
+// Custom Send custom request.
+func (r *Request) Custom(url, method string) (*Response, error) {
+ return r.SetURL(url).SetMethod(method).Send()
+}
+
+// Send a request.
+func (r *Request) Send() (*Response, error) {
+ r.checkClient()
+
+ return newCore().execute(r.Context(), r.Client(), r)
+}
+
+// Reset clear Request object, used by ReleaseRequest method.
+func (r *Request) Reset() {
+ r.url = ""
+ r.method = fiber.MethodGet
+ r.userAgent = ""
+ r.referer = ""
+ r.ctx = nil
+ r.body = nil
+ r.timeout = 0
+ r.maxRedirects = 0
+ r.bodyType = noBody
+ r.boundary = boundary
+
+ for len(r.files) != 0 {
+ t := r.files[0]
+ r.files = r.files[1:]
+ ReleaseFile(t)
+ }
+
+ r.formData.Reset()
+ r.path.Reset()
+ r.cookies.Reset()
+ r.header.Reset()
+ r.params.Reset()
+ r.RawRequest.Reset()
+}
+
+// Header is a wrapper which wrap http.Header,
+// the header in client and request will store in it.
+type Header struct {
+ *fasthttp.RequestHeader
+}
+
+// PeekMultiple methods returns multiple field in header with same key.
+func (h *Header) PeekMultiple(key string) []string {
+ var res []string
+ byteKey := []byte(key)
+ h.RequestHeader.VisitAll(func(key, value []byte) {
+ if bytes.EqualFold(key, byteKey) {
+ res = append(res, utils.UnsafeString(value))
+ }
+ })
+
+ return res
+}
+
+// AddHeaders receive a map and add each value to header.
+func (h *Header) AddHeaders(r map[string][]string) {
+ for k, v := range r {
+ for _, vv := range v {
+ h.Add(k, vv)
+ }
+ }
+}
+
+// SetHeaders will override all headers.
+func (h *Header) SetHeaders(r map[string]string) {
+ for k, v := range r {
+ h.Del(k)
+ h.Set(k, v)
+ }
+}
+
+// QueryParam is a wrapper which wrap url.Values,
+// the query string and formdata in client and request will store in it.
+type QueryParam struct {
+ *fasthttp.Args
+}
+
+// AddParams receive a map and add each value to param.
+func (p *QueryParam) AddParams(r map[string][]string) {
+ for k, v := range r {
+ for _, vv := range v {
+ p.Add(k, vv)
+ }
+ }
+}
+
+// SetParams will override all params.
+func (p *QueryParam) SetParams(r map[string]string) {
+ for k, v := range r {
+ p.Set(k, v)
+ }
+}
+
+// SetParamsWithStruct will override all params with struct or pointer of struct.
+// Now nested structs are not currently supported.
+func (p *QueryParam) SetParamsWithStruct(v any) {
+ SetValWithStruct(p, "param", v)
+}
+
+// Cookie is a map which to store the cookies.
+type Cookie map[string]string
+
+// Add method impl the method in WithStruct interface.
+func (c Cookie) Add(key, val string) {
+ c[key] = val
+}
+
+// Del method impl the method in WithStruct interface.
+func (c Cookie) Del(key string) {
+ delete(c, key)
+}
+
+// SetCookie method sets a single val in Cookie.
+func (c Cookie) SetCookie(key, val string) {
+ c[key] = val
+}
+
+// SetCookies method sets multiple val in Cookie.
+func (c Cookie) SetCookies(m map[string]string) {
+ for k, v := range m {
+ c[k] = v
+ }
+}
+
+// SetCookiesWithStruct method sets multiple val in Cookie via a struct.
+func (c Cookie) SetCookiesWithStruct(v any) {
+ SetValWithStruct(c, "cookie", v)
+}
+
+// DelCookies method deletes multiple val in Cookie.
+func (c Cookie) DelCookies(key ...string) {
+ for _, v := range key {
+ c.Del(v)
+ }
+}
+
+// VisitAll method receive a function which can travel the all val.
+func (c Cookie) VisitAll(f func(key, val string)) {
+ for k, v := range c {
+ f(k, v)
+ }
+}
+
+// Reset clear the Cookie object.
+func (c Cookie) Reset() {
+ for k := range c {
+ delete(c, k)
+ }
+}
+
+// PathParam is a map which to store the cookies.
+type PathParam map[string]string
+
+// Add method impl the method in WithStruct interface.
+func (p PathParam) Add(key, val string) {
+ p[key] = val
+}
+
+// Del method impl the method in WithStruct interface.
+func (p PathParam) Del(key string) {
+ delete(p, key)
+}
+
+// SetParam method sets a single val in PathParam.
+func (p PathParam) SetParam(key, val string) {
+ p[key] = val
+}
+
+// SetParams method sets multiple val in PathParam.
+func (p PathParam) SetParams(m map[string]string) {
+ for k, v := range m {
+ p[k] = v
+ }
+}
+
+// SetParamsWithStruct method sets multiple val in PathParam via a struct.
+func (p PathParam) SetParamsWithStruct(v any) {
+ SetValWithStruct(p, "path", v)
+}
+
+// DelParams method deletes multiple val in PathParams.
+func (p PathParam) DelParams(key ...string) {
+ for _, v := range key {
+ p.Del(v)
+ }
+}
+
+// VisitAll method receive a function which can travel the all val.
+func (p PathParam) VisitAll(f func(key, val string)) {
+ for k, v := range p {
+ f(k, v)
+ }
+}
+
+// Reset clear the PathParams object.
+func (p PathParam) Reset() {
+ for k := range p {
+ delete(p, k)
+ }
+}
+
+// FormData is a wrapper of fasthttp.Args,
+// and it be used for url encode body and file body.
+type FormData struct {
+ *fasthttp.Args
+}
+
+// AddData method is a wrapper of Args's Add method.
+func (f *FormData) AddData(key, val string) {
+ f.Add(key, val)
+}
+
+// SetData method is a wrapper of Args's Set method.
+func (f *FormData) SetData(key, val string) {
+ f.Set(key, val)
+}
+
+// AddDatas method supports add multiple fields.
+func (f *FormData) AddDatas(m map[string][]string) {
+ for k, v := range m {
+ for _, vv := range v {
+ f.Add(k, vv)
+ }
+ }
+}
+
+// SetDatas method supports set multiple fields.
+func (f *FormData) SetDatas(m map[string]string) {
+ for k, v := range m {
+ f.Set(k, v)
+ }
+}
+
+// SetDatasWithStruct method supports set multiple fields via a struct.
+func (f *FormData) SetDatasWithStruct(v any) {
+ SetValWithStruct(f, "form", v)
+}
+
+// DelDatas method deletes multiple fields.
+func (f *FormData) DelDatas(key ...string) {
+ for _, v := range key {
+ f.Del(v)
+ }
+}
+
+// Reset clear the FormData object.
+func (f *FormData) Reset() {
+ f.Args.Reset()
+}
+
+// File is a struct which support send files via request.
+type File struct {
+ name string
+ fieldName string
+ path string
+ reader io.ReadCloser
+}
+
+// SetName method sets file name.
+func (f *File) SetName(n string) {
+ f.name = n
+}
+
+// SetFieldName method sets key of file in the body.
+func (f *File) SetFieldName(n string) {
+ f.fieldName = n
+}
+
+// SetPath method set file path.
+func (f *File) SetPath(p string) {
+ f.path = p
+}
+
+// SetReader method can receive a io.ReadCloser
+// which will be closed in parserBody hook.
+func (f *File) SetReader(r io.ReadCloser) {
+ f.reader = r
+}
+
+// Reset clear the File object.
+func (f *File) Reset() {
+ f.name = ""
+ f.fieldName = ""
+ f.path = ""
+ f.reader = nil
+}
+
+var requestPool = &sync.Pool{
+ New: func() any {
+ return &Request{
+ header: &Header{RequestHeader: &fasthttp.RequestHeader{}},
+ params: &QueryParam{Args: fasthttp.AcquireArgs()},
+ cookies: &Cookie{},
+ path: &PathParam{},
+ boundary: "--FiberFormBoundary",
+ formData: &FormData{Args: fasthttp.AcquireArgs()},
+ files: make([]*File, 0),
+ RawRequest: fasthttp.AcquireRequest(),
+ }
+ },
+}
+
+// AcquireRequest returns an empty request object from the pool.
+//
+// The returned request may be returned to the pool with ReleaseRequest when no longer needed.
+// This allows reducing GC load.
+func AcquireRequest() *Request {
+ req, ok := requestPool.Get().(*Request)
+ if !ok {
+ panic(errors.New("failed to type-assert to *Request"))
+ }
+
+ return req
+}
+
+// ReleaseRequest returns the object acquired via AcquireRequest to the pool.
+//
+// Do not access the released Request object, otherwise data races may occur.
+func ReleaseRequest(req *Request) {
+ req.Reset()
+ requestPool.Put(req)
+}
+
+var filePool sync.Pool
+
+// SetFileFunc The methods as follows is used by AcquireFile method.
+// You can set file field via these method.
+type SetFileFunc func(f *File)
+
+// SetFileName method sets file name.
+func SetFileName(n string) SetFileFunc {
+ return func(f *File) {
+ f.SetName(n)
+ }
+}
+
+// SetFileFieldName method sets key of file in the body.
+func SetFileFieldName(p string) SetFileFunc {
+ return func(f *File) {
+ f.SetFieldName(p)
+ }
+}
+
+// SetFilePath method set file path.
+func SetFilePath(p string) SetFileFunc {
+ return func(f *File) {
+ f.SetPath(p)
+ }
+}
+
+// SetFileReader method can receive a io.ReadCloser
+func SetFileReader(r io.ReadCloser) SetFileFunc {
+ return func(f *File) {
+ f.SetReader(r)
+ }
+}
+
+// AcquireFile returns an File object from the pool.
+// And you can set field in the File with SetFileFunc.
+//
+// The returned file may be returned to the pool with ReleaseFile when no longer needed.
+// This allows reducing GC load.
+func AcquireFile(setter ...SetFileFunc) *File {
+ fv := filePool.Get()
+ if fv != nil {
+ f, ok := fv.(*File)
+ if !ok {
+ panic(errors.New("failed to type-assert to *File"))
+ }
+ for _, v := range setter {
+ v(f)
+ }
+ return f
+ }
+ f := &File{}
+ for _, v := range setter {
+ v(f)
+ }
+ return f
+}
+
+// ReleaseFile returns the object acquired via AcquireFile to the pool.
+//
+// Do not access the released File object, otherwise data races may occur.
+func ReleaseFile(f *File) {
+ f.Reset()
+ filePool.Put(f)
+}
+
+// SetValWithStruct Set some values using structs.
+// `p` is a structure that implements the WithStruct interface,
+// The field name can be specified by `tagName`.
+// `v` is a struct include some data.
+// Note: This method only supports simple types and nested structs are not currently supported.
+func SetValWithStruct(p WithStruct, tagName string, v any) {
+ valueOfV := reflect.ValueOf(v)
+ typeOfV := reflect.TypeOf(v)
+
+ // The v should be struct or point of struct
+ if typeOfV.Kind() == reflect.Pointer && typeOfV.Elem().Kind() == reflect.Struct {
+ valueOfV = valueOfV.Elem()
+ typeOfV = typeOfV.Elem()
+ } else if typeOfV.Kind() != reflect.Struct {
+ return
+ }
+
+ // Boring type judge.
+ // TODO: cover more types and complex data structure.
+ var setVal func(name string, value reflect.Value)
+ setVal = func(name string, val reflect.Value) {
+ switch val.Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ p.Add(name, strconv.Itoa(int(val.Int())))
+ case reflect.Bool:
+ if val.Bool() {
+ p.Add(name, "true")
+ }
+ case reflect.String:
+ p.Add(name, val.String())
+ case reflect.Float32, reflect.Float64:
+ p.Add(name, strconv.FormatFloat(val.Float(), 'f', -1, 64))
+ case reflect.Slice, reflect.Array:
+ for i := 0; i < val.Len(); i++ {
+ setVal(name, val.Index(i))
+ }
+ default:
+ }
+ }
+
+ for i := 0; i < typeOfV.NumField(); i++ {
+ field := typeOfV.Field(i)
+ if !field.IsExported() {
+ continue
+ }
+
+ name := field.Tag.Get(tagName)
+ if name == "" {
+ name = field.Name
+ }
+ val := valueOfV.Field(i)
+ if val.IsZero() {
+ continue
+ }
+ // To cover slice and array, we delete the val then add it.
+ p.Del(name)
+ setVal(name, val)
+ }
+}
diff --git a/client/request_test.go b/client/request_test.go
new file mode 100644
index 0000000000..07e5254e15
--- /dev/null
+++ b/client/request_test.go
@@ -0,0 +1,1623 @@
+package client
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "io"
+ "mime/multipart"
+ "net"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/gofiber/fiber/v3"
+ "github.com/stretchr/testify/require"
+ "github.com/valyala/fasthttp"
+ "github.com/valyala/fasthttp/fasthttputil"
+)
+
+func Test_Request_Method(t *testing.T) {
+ t.Parallel()
+
+ req := AcquireRequest()
+ req.SetMethod("GET")
+ require.Equal(t, "GET", req.Method())
+
+ req.SetMethod("POST")
+ require.Equal(t, "POST", req.Method())
+
+ req.SetMethod("PUT")
+ require.Equal(t, "PUT", req.Method())
+
+ req.SetMethod("DELETE")
+ require.Equal(t, "DELETE", req.Method())
+
+ req.SetMethod("PATCH")
+ require.Equal(t, "PATCH", req.Method())
+
+ req.SetMethod("OPTIONS")
+ require.Equal(t, "OPTIONS", req.Method())
+
+ req.SetMethod("HEAD")
+ require.Equal(t, "HEAD", req.Method())
+
+ req.SetMethod("TRACE")
+ require.Equal(t, "TRACE", req.Method())
+
+ req.SetMethod("CUSTOM")
+ require.Equal(t, "CUSTOM", req.Method())
+}
+
+func Test_Request_URL(t *testing.T) {
+ t.Parallel()
+
+ req := AcquireRequest()
+
+ req.SetURL("http://example.com/normal")
+ require.Equal(t, "http://example.com/normal", req.URL())
+
+ req.SetURL("https://example.com/normal")
+ require.Equal(t, "https://example.com/normal", req.URL())
+}
+
+func Test_Request_Client(t *testing.T) {
+ t.Parallel()
+
+ client := NewClient()
+ req := AcquireRequest()
+
+ req.SetClient(client)
+ require.Equal(t, client, req.Client())
+}
+
+func Test_Request_Context(t *testing.T) {
+ t.Parallel()
+
+ req := AcquireRequest()
+ ctx := req.Context()
+ key := struct{}{}
+
+ require.Nil(t, ctx.Value(key))
+
+ ctx = context.WithValue(ctx, key, "string")
+ req.SetContext(ctx)
+ ctx = req.Context()
+
+ v, ok := ctx.Value(key).(string)
+ require.True(t, ok)
+ require.Equal(t, "string", v)
+}
+
+func Test_Request_Header(t *testing.T) {
+ t.Parallel()
+
+ t.Run("add header", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+ req.AddHeader("foo", "bar").AddHeader("foo", "fiber")
+
+ res := req.Header("foo")
+ require.Len(t, res, 2)
+ require.Equal(t, "bar", res[0])
+ require.Equal(t, "fiber", res[1])
+ })
+
+ t.Run("set header", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+ req.AddHeader("foo", "bar").SetHeader("foo", "fiber")
+
+ res := req.Header("foo")
+ require.Len(t, res, 1)
+ require.Equal(t, "fiber", res[0])
+ })
+
+ t.Run("add headers", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+ req.SetHeader("foo", "bar").
+ AddHeaders(map[string][]string{
+ "foo": {"fiber", "buaa"},
+ "bar": {"foo"},
+ })
+
+ res := req.Header("foo")
+ require.Len(t, res, 3)
+ require.Equal(t, "bar", res[0])
+ require.Equal(t, "fiber", res[1])
+ require.Equal(t, "buaa", res[2])
+
+ res = req.Header("bar")
+ require.Len(t, res, 1)
+ require.Equal(t, "foo", res[0])
+ })
+
+ t.Run("set headers", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+ req.SetHeader("foo", "bar").
+ SetHeaders(map[string]string{
+ "foo": "fiber",
+ "bar": "foo",
+ })
+
+ res := req.Header("foo")
+ require.Len(t, res, 1)
+ require.Equal(t, "fiber", res[0])
+
+ res = req.Header("bar")
+ require.Len(t, res, 1)
+ require.Equal(t, "foo", res[0])
+ })
+}
+
+func Test_Request_QueryParam(t *testing.T) {
+ t.Parallel()
+
+ t.Run("add param", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+ req.AddParam("foo", "bar").AddParam("foo", "fiber")
+
+ res := req.Param("foo")
+ require.Len(t, res, 2)
+ require.Equal(t, "bar", res[0])
+ require.Equal(t, "fiber", res[1])
+ })
+
+ t.Run("set param", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+ req.AddParam("foo", "bar").SetParam("foo", "fiber")
+
+ res := req.Param("foo")
+ require.Len(t, res, 1)
+ require.Equal(t, "fiber", res[0])
+ })
+
+ t.Run("add params", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+ req.SetParam("foo", "bar").
+ AddParams(map[string][]string{
+ "foo": {"fiber", "buaa"},
+ "bar": {"foo"},
+ })
+
+ res := req.Param("foo")
+ require.Len(t, res, 3)
+ require.Equal(t, "bar", res[0])
+ require.Equal(t, "fiber", res[1])
+ require.Equal(t, "buaa", res[2])
+
+ res = req.Param("bar")
+ require.Len(t, res, 1)
+ require.Equal(t, "foo", res[0])
+ })
+
+ t.Run("set headers", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+ req.SetParam("foo", "bar").
+ SetParams(map[string]string{
+ "foo": "fiber",
+ "bar": "foo",
+ })
+
+ res := req.Param("foo")
+ require.Len(t, res, 1)
+ require.Equal(t, "fiber", res[0])
+
+ res = req.Param("bar")
+ require.Len(t, res, 1)
+ require.Equal(t, "foo", res[0])
+ })
+
+ t.Run("set params with struct", func(t *testing.T) {
+ t.Parallel()
+
+ type args struct {
+ TInt int
+ TString string
+ TFloat float64
+ TBool bool
+ TSlice []string
+ TIntSlice []int `param:"int_slice"`
+ }
+
+ p := AcquireRequest()
+ p.SetParamsWithStruct(&args{
+ TInt: 5,
+ TString: "string",
+ TFloat: 3.1,
+ TBool: true,
+ TSlice: []string{"foo", "bar"},
+ TIntSlice: []int{1, 2},
+ })
+
+ require.Empty(t, p.Param("unexport"))
+
+ require.Len(t, p.Param("TInt"), 1)
+ require.Equal(t, "5", p.Param("TInt")[0])
+
+ require.Len(t, p.Param("TString"), 1)
+ require.Equal(t, "string", p.Param("TString")[0])
+
+ require.Len(t, p.Param("TFloat"), 1)
+ require.Equal(t, "3.1", p.Param("TFloat")[0])
+
+ require.Len(t, p.Param("TBool"), 1)
+
+ tslice := p.Param("TSlice")
+ require.Len(t, tslice, 2)
+ require.Equal(t, "foo", tslice[0])
+ require.Equal(t, "bar", tslice[1])
+
+ tint := p.Param("TSlice")
+ require.Len(t, tint, 2)
+ require.Equal(t, "foo", tint[0])
+ require.Equal(t, "bar", tint[1])
+ })
+
+ t.Run("del params", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+ req.SetParam("foo", "bar").
+ SetParams(map[string]string{
+ "foo": "fiber",
+ "bar": "foo",
+ }).DelParams("foo", "bar")
+
+ res := req.Param("foo")
+ require.Empty(t, res)
+
+ res = req.Param("bar")
+ require.Empty(t, res)
+ })
+}
+
+func Test_Request_UA(t *testing.T) {
+ t.Parallel()
+
+ req := AcquireRequest().SetUserAgent("fiber")
+ require.Equal(t, "fiber", req.UserAgent())
+
+ req.SetUserAgent("foo")
+ require.Equal(t, "foo", req.UserAgent())
+}
+
+func Test_Request_Referer(t *testing.T) {
+ t.Parallel()
+
+ req := AcquireRequest().SetReferer("http://example.com")
+ require.Equal(t, "http://example.com", req.Referer())
+
+ req.SetReferer("https://example.com")
+ require.Equal(t, "https://example.com", req.Referer())
+}
+
+func Test_Request_Cookie(t *testing.T) {
+ t.Parallel()
+
+ t.Run("set cookie", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest().
+ SetCookie("foo", "bar")
+ require.Equal(t, "bar", req.Cookie("foo"))
+
+ req.SetCookie("foo", "bar1")
+ require.Equal(t, "bar1", req.Cookie("foo"))
+ })
+
+ t.Run("set cookies", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest().
+ SetCookies(map[string]string{
+ "foo": "bar",
+ "bar": "foo",
+ })
+ require.Equal(t, "bar", req.Cookie("foo"))
+ require.Equal(t, "foo", req.Cookie("bar"))
+
+ req.SetCookies(map[string]string{
+ "foo": "bar1",
+ })
+ require.Equal(t, "bar1", req.Cookie("foo"))
+ require.Equal(t, "foo", req.Cookie("bar"))
+ })
+
+ t.Run("set cookies with struct", func(t *testing.T) {
+ t.Parallel()
+ type args struct {
+ CookieInt int `cookie:"int"`
+ CookieString string `cookie:"string"`
+ }
+
+ req := AcquireRequest().SetCookiesWithStruct(&args{
+ CookieInt: 5,
+ CookieString: "foo",
+ })
+
+ require.Equal(t, "5", req.Cookie("int"))
+ require.Equal(t, "foo", req.Cookie("string"))
+ })
+
+ t.Run("del cookies", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest().
+ SetCookies(map[string]string{
+ "foo": "bar",
+ "bar": "foo",
+ })
+ require.Equal(t, "bar", req.Cookie("foo"))
+ require.Equal(t, "foo", req.Cookie("bar"))
+
+ req.DelCookies("foo")
+ require.Equal(t, "", req.Cookie("foo"))
+ require.Equal(t, "foo", req.Cookie("bar"))
+ })
+}
+
+func Test_Request_PathParam(t *testing.T) {
+ t.Parallel()
+
+ t.Run("set path param", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest().
+ SetPathParam("foo", "bar")
+ require.Equal(t, "bar", req.PathParam("foo"))
+
+ req.SetPathParam("foo", "bar1")
+ require.Equal(t, "bar1", req.PathParam("foo"))
+ })
+
+ t.Run("set path params", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest().
+ SetPathParams(map[string]string{
+ "foo": "bar",
+ "bar": "foo",
+ })
+ require.Equal(t, "bar", req.PathParam("foo"))
+ require.Equal(t, "foo", req.PathParam("bar"))
+
+ req.SetPathParams(map[string]string{
+ "foo": "bar1",
+ })
+ require.Equal(t, "bar1", req.PathParam("foo"))
+ require.Equal(t, "foo", req.PathParam("bar"))
+ })
+
+ t.Run("set path params with struct", func(t *testing.T) {
+ t.Parallel()
+ type args struct {
+ CookieInt int `path:"int"`
+ CookieString string `path:"string"`
+ }
+
+ req := AcquireRequest().SetPathParamsWithStruct(&args{
+ CookieInt: 5,
+ CookieString: "foo",
+ })
+
+ require.Equal(t, "5", req.PathParam("int"))
+ require.Equal(t, "foo", req.PathParam("string"))
+ })
+
+ t.Run("del path params", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest().
+ SetPathParams(map[string]string{
+ "foo": "bar",
+ "bar": "foo",
+ })
+ require.Equal(t, "bar", req.PathParam("foo"))
+ require.Equal(t, "foo", req.PathParam("bar"))
+
+ req.DelPathParams("foo")
+ require.Equal(t, "", req.PathParam("foo"))
+ require.Equal(t, "foo", req.PathParam("bar"))
+ })
+
+ t.Run("clear path params", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest().
+ SetPathParams(map[string]string{
+ "foo": "bar",
+ "bar": "foo",
+ })
+ require.Equal(t, "bar", req.PathParam("foo"))
+ require.Equal(t, "foo", req.PathParam("bar"))
+
+ req.ResetPathParams()
+ require.Equal(t, "", req.PathParam("foo"))
+ require.Equal(t, "", req.PathParam("bar"))
+ })
+}
+
+func Test_Request_FormData(t *testing.T) {
+ t.Parallel()
+
+ t.Run("add form data", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+ defer ReleaseRequest(req)
+ req.AddFormData("foo", "bar").AddFormData("foo", "fiber")
+
+ res := req.FormData("foo")
+ require.Len(t, res, 2)
+ require.Equal(t, "bar", res[0])
+ require.Equal(t, "fiber", res[1])
+ })
+
+ t.Run("set param", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+ defer ReleaseRequest(req)
+ req.AddFormData("foo", "bar").SetFormData("foo", "fiber")
+
+ res := req.FormData("foo")
+ require.Len(t, res, 1)
+ require.Equal(t, "fiber", res[0])
+ })
+
+ t.Run("add params", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+ defer ReleaseRequest(req)
+ req.SetFormData("foo", "bar").
+ AddFormDatas(map[string][]string{
+ "foo": {"fiber", "buaa"},
+ "bar": {"foo"},
+ })
+
+ res := req.FormData("foo")
+ require.Len(t, res, 3)
+ require.Contains(t, res, "bar")
+ require.Contains(t, res, "buaa")
+ require.Contains(t, res, "fiber")
+
+ res = req.FormData("bar")
+ require.Len(t, res, 1)
+ require.Equal(t, "foo", res[0])
+ })
+
+ t.Run("set headers", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+ defer ReleaseRequest(req)
+ req.SetFormData("foo", "bar").
+ SetFormDatas(map[string]string{
+ "foo": "fiber",
+ "bar": "foo",
+ })
+
+ res := req.FormData("foo")
+ require.Len(t, res, 1)
+ require.Equal(t, "fiber", res[0])
+
+ res = req.FormData("bar")
+ require.Len(t, res, 1)
+ require.Equal(t, "foo", res[0])
+ })
+
+ t.Run("set params with struct", func(t *testing.T) {
+ t.Parallel()
+
+ type args struct {
+ TInt int
+ TString string
+ TFloat float64
+ TBool bool
+ TSlice []string
+ TIntSlice []int `form:"int_slice"`
+ }
+
+ p := AcquireRequest()
+ defer ReleaseRequest(p)
+ p.SetFormDatasWithStruct(&args{
+ TInt: 5,
+ TString: "string",
+ TFloat: 3.1,
+ TBool: true,
+ TSlice: []string{"foo", "bar"},
+ TIntSlice: []int{1, 2},
+ })
+
+ require.Empty(t, p.FormData("unexport"))
+
+ require.Len(t, p.FormData("TInt"), 1)
+ require.Equal(t, "5", p.FormData("TInt")[0])
+
+ require.Len(t, p.FormData("TString"), 1)
+ require.Equal(t, "string", p.FormData("TString")[0])
+
+ require.Len(t, p.FormData("TFloat"), 1)
+ require.Equal(t, "3.1", p.FormData("TFloat")[0])
+
+ require.Len(t, p.FormData("TBool"), 1)
+
+ tslice := p.FormData("TSlice")
+ require.Len(t, tslice, 2)
+ require.Contains(t, tslice, "bar")
+ require.Contains(t, tslice, "foo")
+
+ tint := p.FormData("TSlice")
+ require.Len(t, tint, 2)
+ require.Contains(t, tint, "bar")
+ require.Contains(t, tint, "foo")
+ })
+
+ t.Run("del params", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+ defer ReleaseRequest(req)
+ req.SetFormData("foo", "bar").
+ SetFormDatas(map[string]string{
+ "foo": "fiber",
+ "bar": "foo",
+ }).DelFormDatas("foo", "bar")
+
+ res := req.FormData("foo")
+ require.Empty(t, res)
+
+ res = req.FormData("bar")
+ require.Empty(t, res)
+ })
+}
+
+func Test_Request_File(t *testing.T) {
+ t.Parallel()
+
+ t.Run("add file", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest().
+ AddFile("../.github/index.html").
+ AddFiles(AcquireFile(SetFileName("tmp.txt")))
+
+ require.Equal(t, "../.github/index.html", req.File("index.html").path)
+ require.Equal(t, "../.github/index.html", req.FileByPath("../.github/index.html").path)
+ require.Equal(t, "tmp.txt", req.File("tmp.txt").name)
+ require.Nil(t, req.File("tmp2.txt"))
+ require.Nil(t, req.FileByPath("tmp2.txt"))
+ })
+
+ t.Run("add file by reader", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest().
+ AddFileWithReader("tmp.txt", io.NopCloser(strings.NewReader("world")))
+
+ require.Equal(t, "tmp.txt", req.File("tmp.txt").name)
+
+ content, err := io.ReadAll(req.File("tmp.txt").reader)
+ require.NoError(t, err)
+ require.Equal(t, "world", string(content))
+ })
+
+ t.Run("add files", func(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest().
+ AddFiles(AcquireFile(SetFileName("tmp.txt")), AcquireFile(SetFileName("foo.txt")))
+
+ require.Equal(t, "tmp.txt", req.File("tmp.txt").name)
+ require.Equal(t, "foo.txt", req.File("foo.txt").name)
+ })
+}
+
+func Test_Request_Timeout(t *testing.T) {
+ t.Parallel()
+
+ req := AcquireRequest().SetTimeout(5 * time.Second)
+
+ require.Equal(t, 5*time.Second, req.Timeout())
+}
+
+func Test_Request_Invalid_URL(t *testing.T) {
+ t.Parallel()
+
+ resp, err := AcquireRequest().
+ Get("http://example.com\r\n\r\nGET /\r\n\r\n")
+
+ require.Equal(t, ErrURLFormat, err)
+ require.Equal(t, (*Response)(nil), resp)
+}
+
+func Test_Request_Unsupport_Protocol(t *testing.T) {
+ t.Parallel()
+
+ resp, err := AcquireRequest().
+ Get("ftp://example.com")
+ require.Equal(t, ErrURLFormat, err)
+ require.Equal(t, (*Response)(nil), resp)
+}
+
+func Test_Request_Get(t *testing.T) {
+ t.Parallel()
+
+ app, ln, start := createHelperServer(t)
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.SendString(c.Hostname())
+ })
+
+ go start()
+ time.Sleep(100 * time.Millisecond)
+
+ client := NewClient().SetDial(ln)
+
+ for i := 0; i < 5; i++ {
+ req := AcquireRequest().SetClient(client)
+
+ resp, err := req.Get("http://example.com")
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "example.com", resp.String())
+ resp.Close()
+ }
+}
+
+func Test_Request_Post(t *testing.T) {
+ t.Parallel()
+
+ app, ln, start := createHelperServer(t)
+ app.Post("/", func(c fiber.Ctx) error {
+ return c.Status(fiber.StatusCreated).
+ SendString(c.FormValue("foo"))
+ })
+
+ go start()
+ time.Sleep(100 * time.Millisecond)
+
+ client := NewClient().SetDial(ln)
+
+ for i := 0; i < 5; i++ {
+ resp, err := AcquireRequest().
+ SetClient(client).
+ SetFormData("foo", "bar").
+ Post("http://example.com")
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusCreated, resp.StatusCode())
+ require.Equal(t, "bar", resp.String())
+ resp.Close()
+ }
+}
+
+func Test_Request_Head(t *testing.T) {
+ t.Parallel()
+
+ app, ln, start := createHelperServer(t)
+ app.Head("/", func(c fiber.Ctx) error {
+ return c.SendString(c.Hostname())
+ })
+
+ go start()
+ time.Sleep(100 * time.Millisecond)
+
+ client := NewClient().SetDial(ln)
+
+ for i := 0; i < 5; i++ {
+ resp, err := AcquireRequest().
+ SetClient(client).
+ Head("http://example.com")
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "", resp.String())
+ resp.Close()
+ }
+}
+
+func Test_Request_Put(t *testing.T) {
+ t.Parallel()
+
+ app, ln, start := createHelperServer(t)
+ app.Put("/", func(c fiber.Ctx) error {
+ return c.SendString(c.FormValue("foo"))
+ })
+
+ go start()
+ time.Sleep(100 * time.Millisecond)
+
+ client := NewClient().SetDial(ln)
+
+ for i := 0; i < 5; i++ {
+ resp, err := AcquireRequest().
+ SetClient(client).
+ SetFormData("foo", "bar").
+ Put("http://example.com")
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "bar", resp.String())
+
+ resp.Close()
+ }
+}
+
+func Test_Request_Delete(t *testing.T) {
+ t.Parallel()
+
+ app, ln, start := createHelperServer(t)
+
+ app.Delete("/", func(c fiber.Ctx) error {
+ return c.Status(fiber.StatusNoContent).
+ SendString("deleted")
+ })
+
+ go start()
+ time.Sleep(100 * time.Millisecond)
+
+ client := NewClient().SetDial(ln)
+
+ for i := 0; i < 5; i++ {
+ resp, err := AcquireRequest().
+ SetClient(client).
+ Delete("http://example.com")
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusNoContent, resp.StatusCode())
+ require.Equal(t, "", resp.String())
+
+ resp.Close()
+ }
+}
+
+func Test_Request_Options(t *testing.T) {
+ t.Parallel()
+
+ app, ln, start := createHelperServer(t)
+
+ app.Options("/", func(c fiber.Ctx) error {
+ return c.Status(fiber.StatusOK).
+ SendString("options")
+ })
+
+ go start()
+ time.Sleep(100 * time.Millisecond)
+
+ client := NewClient().SetDial(ln)
+
+ for i := 0; i < 5; i++ {
+ resp, err := AcquireRequest().
+ SetClient(client).
+ Options("http://example.com")
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "options", resp.String())
+
+ resp.Close()
+ }
+}
+
+func Test_Request_Send(t *testing.T) {
+ t.Parallel()
+
+ app, ln, start := createHelperServer(t)
+
+ app.Post("/", func(c fiber.Ctx) error {
+ return c.Status(fiber.StatusOK).
+ SendString("post")
+ })
+
+ go start()
+ time.Sleep(100 * time.Millisecond)
+
+ client := NewClient().SetDial(ln)
+
+ for i := 0; i < 5; i++ {
+ resp, err := AcquireRequest().
+ SetClient(client).
+ SetURL("http://example.com").
+ SetMethod(fiber.MethodPost).
+ Send()
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "post", resp.String())
+
+ resp.Close()
+ }
+}
+
+func Test_Request_Patch(t *testing.T) {
+ t.Parallel()
+
+ app, ln, start := createHelperServer(t)
+
+ app.Patch("/", func(c fiber.Ctx) error {
+ return c.SendString(c.FormValue("foo"))
+ })
+
+ go start()
+ time.Sleep(100 * time.Millisecond)
+
+ client := NewClient().SetDial(ln)
+
+ for i := 0; i < 5; i++ {
+ resp, err := AcquireRequest().
+ SetClient(client).
+ SetFormData("foo", "bar").
+ Patch("http://example.com")
+
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "bar", resp.String())
+
+ resp.Close()
+ }
+}
+
+func Test_Request_Header_With_Server(t *testing.T) {
+ t.Parallel()
+ handler := func(c fiber.Ctx) error {
+ c.Request().Header.VisitAll(func(key, value []byte) {
+ if k := string(key); k == "K1" || k == "K2" {
+ _, err := c.Write(key)
+ require.NoError(t, err)
+ _, err = c.Write(value)
+ require.NoError(t, err)
+ }
+ })
+ return nil
+ }
+
+ wrapAgent := func(r *Request) {
+ r.SetHeader("k1", "v1").
+ AddHeader("k1", "v11").
+ AddHeaders(map[string][]string{
+ "k1": {"v22", "v33"},
+ }).
+ SetHeaders(map[string]string{
+ "k2": "v2",
+ }).
+ AddHeader("k2", "v22")
+ }
+
+ testRequest(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22")
+}
+
+func Test_Request_UserAgent_With_Server(t *testing.T) {
+ t.Parallel()
+
+ handler := func(c fiber.Ctx) error {
+ return c.Send(c.Request().Header.UserAgent())
+ }
+
+ t.Run("default", func(t *testing.T) {
+ t.Parallel()
+ testRequest(t, handler, func(_ *Request) {}, defaultUserAgent, 5)
+ })
+
+ t.Run("custom", func(t *testing.T) {
+ t.Parallel()
+ testRequest(t, handler, func(agent *Request) {
+ agent.SetUserAgent("ua")
+ }, "ua", 5)
+ })
+}
+
+func Test_Request_Cookie_With_Server(t *testing.T) {
+ t.Parallel()
+ handler := func(c fiber.Ctx) error {
+ return c.SendString(
+ c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4"))
+ }
+
+ wrapAgent := func(req *Request) {
+ req.SetCookie("k1", "v1").
+ SetCookies(map[string]string{
+ "k2": "v2",
+ "k3": "v3",
+ "k4": "v4",
+ }).DelCookies("k4")
+ }
+
+ testRequest(t, handler, wrapAgent, "v1v2v3")
+}
+
+func Test_Request_Referer_With_Server(t *testing.T) {
+ t.Parallel()
+ handler := func(c fiber.Ctx) error {
+ return c.Send(c.Request().Header.Referer())
+ }
+
+ wrapAgent := func(req *Request) {
+ req.SetReferer("http://referer.com")
+ }
+
+ testRequest(t, handler, wrapAgent, "http://referer.com")
+}
+
+func Test_Request_QueryString_With_Server(t *testing.T) {
+ t.Parallel()
+ handler := func(c fiber.Ctx) error {
+ return c.Send(c.Request().URI().QueryString())
+ }
+
+ wrapAgent := func(req *Request) {
+ req.SetParam("foo", "bar").
+ SetParams(map[string]string{
+ "bar": "baz",
+ })
+ }
+
+ testRequest(t, handler, wrapAgent, "foo=bar&bar=baz")
+}
+
+func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) {
+ t.Helper()
+
+ basename := filepath.Base(filename)
+ require.Equal(t, fh.Filename, basename)
+
+ b1, err := os.ReadFile(filepath.Clean(filename))
+ require.NoError(t, err)
+
+ b2 := make([]byte, fh.Size)
+ f, err := fh.Open()
+ require.NoError(t, err)
+ defer func() { require.NoError(t, f.Close()) }()
+ _, err = f.Read(b2)
+ require.NoError(t, err)
+ require.Equal(t, b1, b2)
+}
+
+func Test_Request_Body_With_Server(t *testing.T) {
+ t.Parallel()
+
+ t.Run("json body", func(t *testing.T) {
+ t.Parallel()
+ testRequest(t,
+ func(c fiber.Ctx) error {
+ require.Equal(t, "application/json", string(c.Request().Header.ContentType()))
+ return c.SendString(string(c.Request().Body()))
+ },
+ func(agent *Request) {
+ agent.SetJSON(map[string]string{
+ "success": "hello",
+ })
+ },
+ "{\"success\":\"hello\"}",
+ )
+ })
+
+ t.Run("xml body", func(t *testing.T) {
+ t.Parallel()
+ testRequest(t,
+ func(c fiber.Ctx) error {
+ require.Equal(t, "application/xml", string(c.Request().Header.ContentType()))
+ return c.SendString(string(c.Request().Body()))
+ },
+ func(agent *Request) {
+ type args struct {
+ Content string `xml:"content"`
+ }
+ agent.SetXML(args{
+ Content: "hello",
+ })
+ },
+ "hello",
+ )
+ })
+
+ t.Run("formdata", func(t *testing.T) {
+ t.Parallel()
+ testRequest(t,
+ func(c fiber.Ctx) error {
+ require.Equal(t, fiber.MIMEApplicationForm, string(c.Request().Header.ContentType()))
+ return c.Send([]byte("foo=" + c.FormValue("foo") + "&bar=" + c.FormValue("bar") + "&fiber=" + c.FormValue("fiber")))
+ },
+ func(agent *Request) {
+ agent.SetFormData("foo", "bar").
+ SetFormDatas(map[string]string{
+ "bar": "baz",
+ "fiber": "fast",
+ })
+ },
+ "foo=bar&bar=baz&fiber=fast")
+ })
+
+ t.Run("multipart form", func(t *testing.T) {
+ t.Parallel()
+
+ app, ln, start := createHelperServer(t)
+ app.Post("/", func(c fiber.Ctx) error {
+ require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType))
+
+ mf, err := c.MultipartForm()
+ require.NoError(t, err)
+ require.Equal(t, "bar", mf.Value["foo"][0])
+
+ return c.Send(c.Request().Body())
+ })
+
+ go start()
+
+ client := NewClient().SetDial(ln)
+
+ req := AcquireRequest().
+ SetClient(client).
+ SetBoundary("myBoundary").
+ SetFormData("foo", "bar").
+ AddFiles(AcquireFile(
+ SetFileName("hello.txt"),
+ SetFileFieldName("foo"),
+ SetFileReader(io.NopCloser(strings.NewReader("world"))),
+ ))
+
+ require.Equal(t, "myBoundary", req.Boundary())
+
+ resp, err := req.Post("http://exmaple.com")
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+
+ form, err := multipart.NewReader(bytes.NewReader(resp.Body()), "myBoundary").ReadForm(1024 * 1024)
+ require.NoError(t, err)
+ require.Equal(t, "bar", form.Value["foo"][0])
+ resp.Close()
+ })
+
+ t.Run("multipart form send file", func(t *testing.T) {
+ t.Parallel()
+
+ app, ln, start := createHelperServer(t)
+ app.Post("/", func(c fiber.Ctx) error {
+ require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType))
+
+ fh1, err := c.FormFile("field1")
+ require.NoError(t, err)
+ require.Equal(t, "name", fh1.Filename)
+ buf := make([]byte, fh1.Size)
+ f, err := fh1.Open()
+ require.NoError(t, err)
+ defer func() { require.NoError(t, f.Close()) }()
+ _, err = f.Read(buf)
+ require.NoError(t, err)
+ require.Equal(t, "form file", string(buf))
+
+ fh2, err := c.FormFile("file2")
+ require.NoError(t, err)
+ checkFormFile(t, fh2, "../.github/testdata/index.html")
+
+ fh3, err := c.FormFile("file3")
+ require.NoError(t, err)
+ checkFormFile(t, fh3, "../.github/testdata/index.tmpl")
+
+ return c.SendString("multipart form files")
+ })
+
+ go start()
+
+ client := NewClient().SetDial(ln)
+
+ for i := 0; i < 5; i++ {
+ req := AcquireRequest().
+ SetClient(client).
+ AddFiles(
+ AcquireFile(
+ SetFileFieldName("field1"),
+ SetFileName("name"),
+ SetFileReader(io.NopCloser(bytes.NewReader([]byte("form file")))),
+ ),
+ ).
+ AddFile("../.github/testdata/index.html").
+ AddFile("../.github/testdata/index.tmpl").
+ SetBoundary("myBoundary")
+
+ resp, err := req.Post("http://example.com")
+ require.NoError(t, err)
+ require.Equal(t, "multipart form files", resp.String())
+
+ resp.Close()
+ }
+ })
+
+ t.Run("multipart random boundary", func(t *testing.T) {
+ t.Parallel()
+
+ app, ln, start := createHelperServer(t)
+ app.Post("/", func(c fiber.Ctx) error {
+ reg := regexp.MustCompile(`multipart/form-data; boundary=[\-\w]{35}`)
+ require.True(t, reg.MatchString(c.Get(fiber.HeaderContentType)))
+
+ return c.Send(c.Request().Body())
+ })
+
+ go start()
+
+ client := NewClient().SetDial(ln)
+
+ req := AcquireRequest().
+ SetClient(client).
+ SetFormData("foo", "bar").
+ AddFiles(AcquireFile(
+ SetFileName("hello.txt"),
+ SetFileFieldName("foo"),
+ SetFileReader(io.NopCloser(strings.NewReader("world"))),
+ ))
+
+ resp, err := req.Post("http://exmaple.com")
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ })
+
+ t.Run("raw body", func(t *testing.T) {
+ t.Parallel()
+ testRequest(t,
+ func(c fiber.Ctx) error {
+ return c.SendString(string(c.Request().Body()))
+ },
+ func(agent *Request) {
+ agent.SetRawBody([]byte("hello"))
+ },
+ "hello",
+ )
+ })
+}
+
+func Test_Request_Error_Body_With_Server(t *testing.T) {
+ t.Parallel()
+ t.Run("json error", func(t *testing.T) {
+ t.Parallel()
+ testRequestFail(t,
+ func(c fiber.Ctx) error {
+ return c.SendString("")
+ },
+ func(agent *Request) {
+ agent.SetJSON(complex(1, 1))
+ },
+ errors.New("json: unsupported type: complex128"),
+ )
+ })
+
+ t.Run("xml error", func(t *testing.T) {
+ t.Parallel()
+ testRequestFail(t,
+ func(c fiber.Ctx) error {
+ return c.SendString("")
+ },
+ func(agent *Request) {
+ agent.SetXML(complex(1, 1))
+ },
+ errors.New("xml: unsupported type: complex128"),
+ )
+ })
+
+ t.Run("form body with invalid boundary", func(t *testing.T) {
+ t.Parallel()
+
+ _, err := AcquireRequest().
+ SetBoundary("*").
+ AddFileWithReader("t.txt", io.NopCloser(strings.NewReader("world"))).
+ Get("http://example.com")
+ require.Equal(t, "set boundary error: mime: invalid boundary character", err.Error())
+ })
+
+ t.Run("open non exist file", func(t *testing.T) {
+ t.Parallel()
+
+ _, err := AcquireRequest().
+ AddFile("non-exist-file!").
+ Get("http://example.com")
+ require.Contains(t, err.Error(), "open non-exist-file!")
+ })
+}
+
+func Test_Request_Timeout_With_Server(t *testing.T) {
+ t.Parallel()
+
+ app, ln, start := createHelperServer(t)
+ app.Get("/", func(c fiber.Ctx) error {
+ time.Sleep(time.Millisecond * 200)
+ return c.SendString("timeout")
+ })
+ go start()
+
+ client := NewClient().SetDial(ln)
+
+ _, err := AcquireRequest().
+ SetClient(client).
+ SetTimeout(50 * time.Millisecond).
+ Get("http://example.com")
+
+ require.Equal(t, ErrTimeoutOrCancel, err)
+}
+
+func Test_Request_MaxRedirects(t *testing.T) {
+ t.Parallel()
+
+ ln := fasthttputil.NewInmemoryListener()
+
+ app := fiber.New()
+
+ app.Get("/", func(c fiber.Ctx) error {
+ if c.Request().URI().QueryArgs().Has("foo") {
+ return c.Redirect().To("/foo")
+ }
+ return c.Redirect().To("/")
+ })
+ app.Get("/foo", func(c fiber.Ctx) error {
+ return c.SendString("redirect")
+ })
+
+ go func() { require.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) }()
+
+ t.Run("success", func(t *testing.T) {
+ t.Parallel()
+
+ client := NewClient().SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed
+
+ resp, err := AcquireRequest().
+ SetClient(client).
+ SetMaxRedirects(1).
+ Get("http://example.com?foo")
+ body := resp.String()
+ code := resp.StatusCode()
+
+ require.Equal(t, 200, code)
+ require.Equal(t, "redirect", body)
+ require.NoError(t, err)
+
+ resp.Close()
+ })
+
+ t.Run("error", func(t *testing.T) {
+ t.Parallel()
+
+ client := NewClient().SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed
+
+ resp, err := AcquireRequest().
+ SetClient(client).
+ SetMaxRedirects(1).
+ Get("http://example.com")
+
+ require.Nil(t, resp)
+ require.Equal(t, "too many redirects detected when doing the request", err.Error())
+ })
+
+ t.Run("MaxRedirects", func(t *testing.T) {
+ t.Parallel()
+
+ client := NewClient().SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed
+
+ req := AcquireRequest().
+ SetClient(client).
+ SetMaxRedirects(3)
+
+ require.Equal(t, 3, req.MaxRedirects())
+ })
+}
+
+func Test_SetValWithStruct(t *testing.T) {
+ t.Parallel()
+
+ // test SetValWithStruct vai QueryParam struct.
+ type args struct {
+ unexport int
+ TInt int
+ TString string
+ TFloat float64
+ TBool bool
+ TSlice []string
+ TIntSlice []int `param:"int_slice"`
+ }
+
+ t.Run("the struct should be applied", func(t *testing.T) {
+ t.Parallel()
+ p := &QueryParam{
+ Args: fasthttp.AcquireArgs(),
+ }
+
+ SetValWithStruct(p, "param", args{
+ unexport: 5,
+ TInt: 5,
+ TString: "string",
+ TFloat: 3.1,
+ TBool: false,
+ TSlice: []string{"foo", "bar"},
+ TIntSlice: []int{1, 2},
+ })
+
+ require.Equal(t, "", string(p.Peek("unexport")))
+ require.Equal(t, []byte("5"), p.Peek("TInt"))
+ require.Equal(t, []byte("string"), p.Peek("TString"))
+ require.Equal(t, []byte("3.1"), p.Peek("TFloat"))
+ require.Equal(t, "", string(p.Peek("TBool")))
+ require.True(t, func() bool {
+ for _, v := range p.PeekMulti("TSlice") {
+ if string(v) == "foo" {
+ return true
+ }
+ }
+ return false
+ }())
+
+ require.True(t, func() bool {
+ for _, v := range p.PeekMulti("TSlice") {
+ if string(v) == "bar" {
+ return true
+ }
+ }
+ return false
+ }())
+
+ require.True(t, func() bool {
+ for _, v := range p.PeekMulti("int_slice") {
+ if string(v) == "1" {
+ return true
+ }
+ }
+ return false
+ }())
+
+ require.True(t, func() bool {
+ for _, v := range p.PeekMulti("int_slice") {
+ if string(v) == "2" {
+ return true
+ }
+ }
+ return false
+ }())
+ })
+
+ t.Run("the pointer of a struct should be applied", func(t *testing.T) {
+ t.Parallel()
+ p := &QueryParam{
+ Args: fasthttp.AcquireArgs(),
+ }
+
+ SetValWithStruct(p, "param", &args{
+ TInt: 5,
+ TString: "string",
+ TFloat: 3.1,
+ TBool: true,
+ TSlice: []string{"foo", "bar"},
+ TIntSlice: []int{1, 2},
+ })
+
+ require.Equal(t, []byte("5"), p.Peek("TInt"))
+ require.Equal(t, []byte("string"), p.Peek("TString"))
+ require.Equal(t, []byte("3.1"), p.Peek("TFloat"))
+ require.Equal(t, "true", string(p.Peek("TBool")))
+ require.True(t, func() bool {
+ for _, v := range p.PeekMulti("TSlice") {
+ if string(v) == "foo" {
+ return true
+ }
+ }
+ return false
+ }())
+
+ require.True(t, func() bool {
+ for _, v := range p.PeekMulti("TSlice") {
+ if string(v) == "bar" {
+ return true
+ }
+ }
+ return false
+ }())
+
+ require.True(t, func() bool {
+ for _, v := range p.PeekMulti("int_slice") {
+ if string(v) == "1" {
+ return true
+ }
+ }
+ return false
+ }())
+
+ require.True(t, func() bool {
+ for _, v := range p.PeekMulti("int_slice") {
+ if string(v) == "2" {
+ return true
+ }
+ }
+ return false
+ }())
+ })
+
+ t.Run("the zero val should be ignore", func(t *testing.T) {
+ t.Parallel()
+ p := &QueryParam{
+ Args: fasthttp.AcquireArgs(),
+ }
+ SetValWithStruct(p, "param", &args{
+ TInt: 0,
+ TString: "",
+ TFloat: 0.0,
+ })
+
+ require.Equal(t, "", string(p.Peek("TInt")))
+ require.Equal(t, "", string(p.Peek("TString")))
+ require.Equal(t, "", string(p.Peek("TFloat")))
+ require.Empty(t, p.PeekMulti("TSlice"))
+ require.Empty(t, p.PeekMulti("int_slice"))
+ })
+
+ t.Run("error type should ignore", func(t *testing.T) {
+ t.Parallel()
+ p := &QueryParam{
+ Args: fasthttp.AcquireArgs(),
+ }
+ SetValWithStruct(p, "param", 5)
+ require.Equal(t, 0, p.Len())
+ })
+}
+
+func Benchmark_SetValWithStruct(b *testing.B) {
+ // test SetValWithStruct vai QueryParam struct.
+ type args struct {
+ unexport int
+ TInt int
+ TString string
+ TFloat float64
+ TBool bool
+ TSlice []string
+ TIntSlice []int `param:"int_slice"`
+ }
+
+ b.Run("the struct should be applied", func(b *testing.B) {
+ p := &QueryParam{
+ Args: fasthttp.AcquireArgs(),
+ }
+
+ b.ReportAllocs()
+ b.StartTimer()
+
+ for i := 0; i < b.N; i++ {
+ SetValWithStruct(p, "param", args{
+ unexport: 5,
+ TInt: 5,
+ TString: "string",
+ TFloat: 3.1,
+ TBool: false,
+ TSlice: []string{"foo", "bar"},
+ TIntSlice: []int{1, 2},
+ })
+ }
+
+ require.Equal(b, "", string(p.Peek("unexport")))
+ require.Equal(b, []byte("5"), p.Peek("TInt"))
+ require.Equal(b, []byte("string"), p.Peek("TString"))
+ require.Equal(b, []byte("3.1"), p.Peek("TFloat"))
+ require.Equal(b, "", string(p.Peek("TBool")))
+ require.True(b, func() bool {
+ for _, v := range p.PeekMulti("TSlice") {
+ if string(v) == "foo" {
+ return true
+ }
+ }
+ return false
+ }())
+
+ require.True(b, func() bool {
+ for _, v := range p.PeekMulti("TSlice") {
+ if string(v) == "bar" {
+ return true
+ }
+ }
+ return false
+ }())
+
+ require.True(b, func() bool {
+ for _, v := range p.PeekMulti("int_slice") {
+ if string(v) == "1" {
+ return true
+ }
+ }
+ return false
+ }())
+
+ require.True(b, func() bool {
+ for _, v := range p.PeekMulti("int_slice") {
+ if string(v) == "2" {
+ return true
+ }
+ }
+ return false
+ }())
+ })
+
+ b.Run("the pointer of a struct should be applied", func(b *testing.B) {
+ p := &QueryParam{
+ Args: fasthttp.AcquireArgs(),
+ }
+
+ b.ReportAllocs()
+ b.StartTimer()
+
+ for i := 0; i < b.N; i++ {
+ SetValWithStruct(p, "param", &args{
+ TInt: 5,
+ TString: "string",
+ TFloat: 3.1,
+ TBool: true,
+ TSlice: []string{"foo", "bar"},
+ TIntSlice: []int{1, 2},
+ })
+ }
+
+ require.Equal(b, []byte("5"), p.Peek("TInt"))
+ require.Equal(b, []byte("string"), p.Peek("TString"))
+ require.Equal(b, []byte("3.1"), p.Peek("TFloat"))
+ require.Equal(b, "true", string(p.Peek("TBool")))
+ require.True(b, func() bool {
+ for _, v := range p.PeekMulti("TSlice") {
+ if string(v) == "foo" {
+ return true
+ }
+ }
+ return false
+ }())
+
+ require.True(b, func() bool {
+ for _, v := range p.PeekMulti("TSlice") {
+ if string(v) == "bar" {
+ return true
+ }
+ }
+ return false
+ }())
+
+ require.True(b, func() bool {
+ for _, v := range p.PeekMulti("int_slice") {
+ if string(v) == "1" {
+ return true
+ }
+ }
+ return false
+ }())
+
+ require.True(b, func() bool {
+ for _, v := range p.PeekMulti("int_slice") {
+ if string(v) == "2" {
+ return true
+ }
+ }
+ return false
+ }())
+ })
+
+ b.Run("the zero val should be ignore", func(b *testing.B) {
+ p := &QueryParam{
+ Args: fasthttp.AcquireArgs(),
+ }
+
+ b.ReportAllocs()
+ b.StartTimer()
+
+ for i := 0; i < b.N; i++ {
+ SetValWithStruct(p, "param", &args{
+ TInt: 0,
+ TString: "",
+ TFloat: 0.0,
+ })
+ }
+
+ require.Empty(b, string(p.Peek("TInt")))
+ require.Empty(b, string(p.Peek("TString")))
+ require.Empty(b, string(p.Peek("TFloat")))
+ require.Empty(b, len(p.PeekMulti("TSlice")))
+ require.Empty(b, len(p.PeekMulti("int_slice")))
+ })
+
+ b.Run("error type should ignore", func(b *testing.B) {
+ p := &QueryParam{
+ Args: fasthttp.AcquireArgs(),
+ }
+
+ b.ReportAllocs()
+ b.StartTimer()
+
+ for i := 0; i < b.N; i++ {
+ SetValWithStruct(p, "param", 5)
+ }
+
+ require.Equal(b, 0, p.Len())
+ })
+}
diff --git a/client/response.go b/client/response.go
new file mode 100644
index 0000000000..f6ecd6fcd8
--- /dev/null
+++ b/client/response.go
@@ -0,0 +1,184 @@
+package client
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "io/fs"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+
+ "github.com/gofiber/utils/v2"
+ "github.com/valyala/fasthttp"
+)
+
+// Response is the result of a request. This object is used to access the response data.
+type Response struct {
+ client *Client
+ request *Request
+ cookie []*fasthttp.Cookie
+
+ RawResponse *fasthttp.Response
+}
+
+// setClient method sets client object in response instance.
+// Use core object in the client.
+func (r *Response) setClient(c *Client) {
+ r.client = c
+}
+
+// setRequest method sets Request object in response instance.
+// The request will be released when the Response.Close is called.
+func (r *Response) setRequest(req *Request) {
+ r.request = req
+}
+
+// Status method returns the HTTP status string for the executed request.
+func (r *Response) Status() string {
+ return string(r.RawResponse.Header.StatusMessage())
+}
+
+// StatusCode method returns the HTTP status code for the executed request.
+func (r *Response) StatusCode() int {
+ return r.RawResponse.StatusCode()
+}
+
+// Protocol method returns the HTTP response protocol used for the request.
+func (r *Response) Protocol() string {
+ return string(r.RawResponse.Header.Protocol())
+}
+
+// Header method returns the response headers.
+func (r *Response) Header(key string) string {
+ return utils.UnsafeString(r.RawResponse.Header.Peek(key))
+}
+
+// Cookies method to access all the response cookies.
+func (r *Response) Cookies() []*fasthttp.Cookie {
+ return r.cookie
+}
+
+// Body method returns HTTP response as []byte array for the executed request.
+func (r *Response) Body() []byte {
+ return r.RawResponse.Body()
+}
+
+// String method returns the body of the server response as String.
+func (r *Response) String() string {
+ return strings.TrimSpace(string(r.Body()))
+}
+
+// JSON method will unmarshal body to json.
+func (r *Response) JSON(v any) error {
+ return r.client.jsonUnmarshal(r.Body(), v)
+}
+
+// XML method will unmarshal body to xml.
+func (r *Response) XML(v any) error {
+ return r.client.xmlUnmarshal(r.Body(), v)
+}
+
+// Save method will save the body to a file or io.Writer.
+func (r *Response) Save(v any) error {
+ switch p := v.(type) {
+ case string:
+ file := filepath.Clean(p)
+ dir := filepath.Dir(file)
+
+ // create directory
+ if _, err := os.Stat(dir); err != nil {
+ if !errors.Is(err, fs.ErrNotExist) {
+ return fmt.Errorf("failed to check directory: %w", err)
+ }
+
+ if err = os.MkdirAll(dir, 0o750); err != nil {
+ return fmt.Errorf("failed to create directory: %w", err)
+ }
+ }
+
+ // create file
+ outFile, err := os.Create(file)
+ if err != nil {
+ return fmt.Errorf("failed to create file: %w", err)
+ }
+ defer func() { _ = outFile.Close() }() //nolint:errcheck // not needed
+
+ _, err = io.Copy(outFile, bytes.NewReader(r.Body()))
+ if err != nil {
+ return fmt.Errorf("failed to write response body to file: %w", err)
+ }
+
+ return nil
+ case io.Writer:
+ _, err := io.Copy(p, bytes.NewReader(r.Body()))
+ if err != nil {
+ return fmt.Errorf("failed to write response body to io.Writer: %w", err)
+ }
+ defer func() {
+ if pc, ok := p.(io.WriteCloser); ok {
+ _ = pc.Close() //nolint:errcheck // not needed
+ }
+ }()
+
+ return nil
+ default:
+ return ErrNotSupportSaveMethod
+ }
+}
+
+// Reset clear Response object.
+func (r *Response) Reset() {
+ r.client = nil
+ r.request = nil
+
+ for len(r.cookie) != 0 {
+ t := r.cookie[0]
+ r.cookie = r.cookie[1:]
+ fasthttp.ReleaseCookie(t)
+ }
+
+ r.RawResponse.Reset()
+}
+
+// Close method will release Request object and Response object,
+// after call Close please don't use these object.
+func (r *Response) Close() {
+ if r.request != nil {
+ tmp := r.request
+ r.request = nil
+ ReleaseRequest(tmp)
+ }
+ ReleaseResponse(r)
+}
+
+var responsePool = &sync.Pool{
+ New: func() any {
+ return &Response{
+ cookie: []*fasthttp.Cookie{},
+ RawResponse: fasthttp.AcquireResponse(),
+ }
+ },
+}
+
+// AcquireResponse returns an empty response object from the pool.
+//
+// The returned response may be returned to the pool with ReleaseResponse when no longer needed.
+// This allows reducing GC load.
+func AcquireResponse() *Response {
+ resp, ok := responsePool.Get().(*Response)
+ if !ok {
+ panic("unexpected type from responsePool.Get()")
+ }
+ return resp
+}
+
+// ReleaseResponse returns the object acquired via AcquireResponse to the pool.
+//
+// Do not access the released Response object, otherwise data races may occur.
+func ReleaseResponse(resp *Response) {
+ resp.Reset()
+ responsePool.Put(resp)
+}
diff --git a/client/response_test.go b/client/response_test.go
new file mode 100644
index 0000000000..622e835714
--- /dev/null
+++ b/client/response_test.go
@@ -0,0 +1,418 @@
+package client
+
+import (
+ "bytes"
+ "crypto/tls"
+ "encoding/xml"
+ "io"
+ "net"
+ "os"
+ "testing"
+
+ "github.com/gofiber/fiber/v3/internal/tlstest"
+
+ "github.com/gofiber/fiber/v3"
+ "github.com/stretchr/testify/require"
+)
+
+func Test_Response_Status(t *testing.T) {
+ t.Parallel()
+
+ setupApp := func() *testServer {
+ server := startTestServer(t, func(app *fiber.App) {
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.SendString("foo")
+ })
+ app.Get("/fail", func(c fiber.Ctx) error {
+ return c.SendStatus(407)
+ })
+ })
+
+ return server
+ }
+
+ t.Run("success", func(t *testing.T) {
+ t.Parallel()
+
+ server := setupApp()
+ defer server.stop()
+
+ client := NewClient().SetDial(server.dial())
+
+ resp, err := AcquireRequest().
+ SetClient(client).
+ Get("http://example")
+
+ require.NoError(t, err)
+ require.Equal(t, "OK", resp.Status())
+ resp.Close()
+ })
+
+ t.Run("fail", func(t *testing.T) {
+ t.Parallel()
+
+ server := setupApp()
+ defer server.stop()
+
+ client := NewClient().SetDial(server.dial())
+
+ resp, err := AcquireRequest().
+ SetClient(client).
+ Get("http://example/fail")
+
+ require.NoError(t, err)
+ require.Equal(t, "Proxy Authentication Required", resp.Status())
+ resp.Close()
+ })
+}
+
+func Test_Response_Status_Code(t *testing.T) {
+ t.Parallel()
+
+ setupApp := func() *testServer {
+ server := startTestServer(t, func(app *fiber.App) {
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.SendString("foo")
+ })
+ app.Get("/fail", func(c fiber.Ctx) error {
+ return c.SendStatus(407)
+ })
+ })
+
+ return server
+ }
+
+ t.Run("success", func(t *testing.T) {
+ t.Parallel()
+
+ server := setupApp()
+ defer server.stop()
+
+ client := NewClient().SetDial(server.dial())
+
+ resp, err := AcquireRequest().
+ SetClient(client).
+ Get("http://example")
+
+ require.NoError(t, err)
+ require.Equal(t, 200, resp.StatusCode())
+ resp.Close()
+ })
+
+ t.Run("fail", func(t *testing.T) {
+ t.Parallel()
+
+ server := setupApp()
+ defer server.stop()
+
+ client := NewClient().SetDial(server.dial())
+
+ resp, err := AcquireRequest().
+ SetClient(client).
+ Get("http://example/fail")
+
+ require.NoError(t, err)
+ require.Equal(t, 407, resp.StatusCode())
+ resp.Close()
+ })
+}
+
+func Test_Response_Protocol(t *testing.T) {
+ t.Parallel()
+
+ t.Run("http", func(t *testing.T) {
+ t.Parallel()
+
+ server := startTestServer(t, func(app *fiber.App) {
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.SendString("foo")
+ })
+ })
+ defer server.stop()
+
+ client := NewClient().SetDial(server.dial())
+
+ resp, err := AcquireRequest().
+ SetClient(client).
+ Get("http://example")
+
+ require.NoError(t, err)
+ require.Equal(t, "HTTP/1.1", resp.Protocol())
+ resp.Close()
+ })
+
+ t.Run("https", func(t *testing.T) {
+ t.Parallel()
+
+ serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs()
+ require.NoError(t, err)
+
+ ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
+ require.NoError(t, err)
+
+ ln = tls.NewListener(ln, serverTLSConf)
+
+ app := fiber.New()
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.SendString(c.Scheme())
+ })
+
+ go func() {
+ require.NoError(t, app.Listener(ln, fiber.ListenConfig{
+ DisableStartupMessage: true,
+ }))
+ }()
+
+ client := NewClient()
+ resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String())
+
+ require.NoError(t, err)
+ require.Equal(t, clientTLSConf, client.TLSConfig())
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "https", resp.String())
+ require.Equal(t, "HTTP/1.1", resp.Protocol())
+
+ resp.Close()
+ })
+}
+
+func Test_Response_Header(t *testing.T) {
+ t.Parallel()
+
+ server := startTestServer(t, func(app *fiber.App) {
+ app.Get("/", func(c fiber.Ctx) error {
+ c.Response().Header.Add("foo", "bar")
+ return c.SendString("helo world")
+ })
+ })
+ defer server.stop()
+
+ client := NewClient().SetDial(server.dial())
+
+ resp, err := AcquireRequest().
+ SetClient(client).
+ Get("http://example.com")
+
+ require.NoError(t, err)
+ require.Equal(t, "bar", resp.Header("foo"))
+ resp.Close()
+}
+
+func Test_Response_Cookie(t *testing.T) {
+ t.Parallel()
+
+ server := startTestServer(t, func(app *fiber.App) {
+ app.Get("/", func(c fiber.Ctx) error {
+ c.Cookie(&fiber.Cookie{
+ Name: "foo",
+ Value: "bar",
+ })
+ return c.SendString("helo world")
+ })
+ })
+ defer server.stop()
+
+ client := NewClient().SetDial(server.dial())
+
+ resp, err := AcquireRequest().
+ SetClient(client).
+ Get("http://example.com")
+
+ require.NoError(t, err)
+ require.Equal(t, "bar", string(resp.Cookies()[0].Value()))
+ resp.Close()
+}
+
+func Test_Response_Body(t *testing.T) {
+ t.Parallel()
+
+ setupApp := func() *testServer {
+ server := startTestServer(t, func(app *fiber.App) {
+ app.Get("/", func(c fiber.Ctx) error {
+ return c.SendString("hello world")
+ })
+
+ app.Get("/json", func(c fiber.Ctx) error {
+ return c.SendString("{\"status\":\"success\"}")
+ })
+
+ app.Get("/xml", func(c fiber.Ctx) error {
+ return c.SendString("success")
+ })
+ })
+
+ return server
+ }
+
+ t.Run("raw body", func(t *testing.T) {
+ t.Parallel()
+
+ server := setupApp()
+ defer server.stop()
+
+ client := NewClient().SetDial(server.dial())
+
+ resp, err := AcquireRequest().
+ SetClient(client).
+ Get("http://example.com")
+
+ require.NoError(t, err)
+ require.Equal(t, []byte("hello world"), resp.Body())
+ resp.Close()
+ })
+
+ t.Run("string body", func(t *testing.T) {
+ t.Parallel()
+
+ server := setupApp()
+ defer server.stop()
+
+ client := NewClient().SetDial(server.dial())
+
+ resp, err := AcquireRequest().
+ SetClient(client).
+ Get("http://example.com")
+
+ require.NoError(t, err)
+ require.Equal(t, "hello world", resp.String())
+ resp.Close()
+ })
+
+ t.Run("json body", func(t *testing.T) {
+ t.Parallel()
+ type body struct {
+ Status string `json:"status"`
+ }
+
+ server := setupApp()
+ defer server.stop()
+
+ client := NewClient().SetDial(server.dial())
+
+ resp, err := AcquireRequest().
+ SetClient(client).
+ Get("http://example.com/json")
+
+ require.NoError(t, err)
+
+ tmp := &body{}
+ err = resp.JSON(tmp)
+ require.NoError(t, err)
+ require.Equal(t, "success", tmp.Status)
+ resp.Close()
+ })
+
+ t.Run("xml body", func(t *testing.T) {
+ t.Parallel()
+ type body struct {
+ Name xml.Name `xml:"status"`
+ Status string `xml:"name"`
+ }
+
+ server := setupApp()
+ defer server.stop()
+
+ client := NewClient().SetDial(server.dial())
+
+ resp, err := AcquireRequest().
+ SetClient(client).
+ Get("http://example.com/xml")
+
+ require.NoError(t, err)
+
+ tmp := &body{}
+ err = resp.XML(tmp)
+ require.NoError(t, err)
+ require.Equal(t, "success", tmp.Status)
+ resp.Close()
+ })
+}
+
+func Test_Response_Save(t *testing.T) {
+ t.Parallel()
+
+ setupApp := func() *testServer {
+ server := startTestServer(t, func(app *fiber.App) {
+ app.Get("/json", func(c fiber.Ctx) error {
+ return c.SendString("{\"status\":\"success\"}")
+ })
+ })
+
+ return server
+ }
+
+ t.Run("file path", func(t *testing.T) {
+ t.Parallel()
+
+ server := setupApp()
+ defer server.stop()
+
+ client := NewClient().SetDial(server.dial())
+
+ resp, err := AcquireRequest().
+ SetClient(client).
+ Get("http://example.com/json")
+
+ require.NoError(t, err)
+
+ err = resp.Save("./test/tmp.json")
+ require.NoError(t, err)
+ defer func() {
+ _, err := os.Stat("./test/tmp.json")
+ require.NoError(t, err)
+
+ err = os.RemoveAll("./test")
+ require.NoError(t, err)
+ }()
+
+ file, err := os.Open("./test/tmp.json")
+ require.NoError(t, err)
+ defer func(file *os.File) {
+ err := file.Close()
+ require.NoError(t, err)
+ }(file)
+
+ data, err := io.ReadAll(file)
+ require.NoError(t, err)
+ require.Equal(t, "{\"status\":\"success\"}", string(data))
+ })
+
+ t.Run("io.Writer", func(t *testing.T) {
+ t.Parallel()
+
+ server := setupApp()
+ defer server.stop()
+
+ client := NewClient().SetDial(server.dial())
+
+ resp, err := AcquireRequest().
+ SetClient(client).
+ Get("http://example.com/json")
+
+ require.NoError(t, err)
+
+ buf := &bytes.Buffer{}
+
+ err = resp.Save(buf)
+ require.NoError(t, err)
+ require.Equal(t, "{\"status\":\"success\"}", buf.String())
+ })
+
+ t.Run("error type", func(t *testing.T) {
+ t.Parallel()
+
+ server := setupApp()
+ defer server.stop()
+
+ client := NewClient().SetDial(server.dial())
+
+ resp, err := AcquireRequest().
+ SetClient(client).
+ Get("http://example.com/json")
+
+ require.NoError(t, err)
+
+ err = resp.Save(nil)
+ require.Error(t, err)
+ })
+}
diff --git a/client_test.go b/client_test.go
deleted file mode 100644
index 57cc4e4d2e..0000000000
--- a/client_test.go
+++ /dev/null
@@ -1,1337 +0,0 @@
-//nolint:wrapcheck // We must not wrap errors in tests
-package fiber
-
-import (
- "bytes"
- "crypto/tls"
- "encoding/base64"
- "encoding/json"
- "encoding/xml"
- "errors"
- "io"
- "mime/multipart"
- "net"
- "os"
- "path/filepath"
- "regexp"
- "strings"
- "testing"
- "time"
-
- "github.com/gofiber/fiber/v3/internal/tlstest"
- "github.com/stretchr/testify/require"
- "github.com/valyala/fasthttp"
- "github.com/valyala/fasthttp/fasthttputil"
-)
-
-func Test_Client_Invalid_URL(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Get("/", func(c Ctx) error {
- return c.SendString(c.Host())
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- a := Get("http://example.com\r\n\r\nGET /\r\n\r\n")
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- _, body, errs := a.String()
-
- require.Equal(t, "", body)
- require.Len(t, errs, 1)
- require.Error(t, errs[0],
- `Expected error "missing required Host header in request"`)
-}
-
-func Test_Client_Unsupported_Protocol(t *testing.T) {
- t.Parallel()
-
- a := Get("ftp://example.com")
-
- _, body, errs := a.String()
-
- require.Equal(t, "", body)
- require.Len(t, errs, 1)
- require.ErrorContains(t, errs[0], `unsupported protocol "ftp". http and https are supported`)
-}
-
-func Test_Client_Get(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Get("/", func(c Ctx) error {
- return c.SendString(c.Host())
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- for i := 0; i < 5; i++ {
- a := Get("http://example.com")
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- code, body, errs := a.String()
-
- require.Equal(t, StatusOK, code)
- require.Equal(t, "example.com", body)
- require.Empty(t, errs)
- }
-}
-
-func Test_Client_Head(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Head("/", func(c Ctx) error {
- return c.SendStatus(StatusAccepted)
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
- for i := 0; i < 5; i++ {
- a := Head("http://example.com")
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- code, body, errs := a.String()
-
- require.Equal(t, StatusAccepted, code)
- require.Equal(t, "", body)
- require.Empty(t, errs)
- }
-}
-
-func Test_Client_Post(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Post("/", func(c Ctx) error {
- return c.Status(StatusCreated).
- SendString(c.FormValue("foo"))
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- for i := 0; i < 5; i++ {
- args := AcquireArgs()
-
- args.Set("foo", "bar")
-
- a := Post("http://example.com").
- Form(args)
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- code, body, errs := a.String()
-
- require.Equal(t, StatusCreated, code)
- require.Equal(t, "bar", body)
- require.Empty(t, errs)
-
- ReleaseArgs(args)
- }
-}
-
-func Test_Client_Put(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Put("/", func(c Ctx) error {
- return c.SendString(c.FormValue("foo"))
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- for i := 0; i < 5; i++ {
- args := AcquireArgs()
-
- args.Set("foo", "bar")
-
- a := Put("http://example.com").
- Form(args)
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- code, body, errs := a.String()
-
- require.Equal(t, StatusOK, code)
- require.Equal(t, "bar", body)
- require.Empty(t, errs)
-
- ReleaseArgs(args)
- }
-}
-
-func Test_Client_Patch(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Patch("/", func(c Ctx) error {
- return c.SendString(c.FormValue("foo"))
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- for i := 0; i < 5; i++ {
- args := AcquireArgs()
-
- args.Set("foo", "bar")
-
- a := Patch("http://example.com").
- Form(args)
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- code, body, errs := a.String()
-
- require.Equal(t, StatusOK, code)
- require.Equal(t, "bar", body)
- require.Empty(t, errs)
-
- ReleaseArgs(args)
- }
-}
-
-func Test_Client_Delete(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Delete("/", func(c Ctx) error {
- return c.Status(StatusNoContent).
- SendString("deleted")
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- for i := 0; i < 5; i++ {
- args := AcquireArgs()
-
- a := Delete("http://example.com")
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- code, body, errs := a.String()
-
- require.Equal(t, StatusNoContent, code)
- require.Equal(t, "", body)
- require.Empty(t, errs)
-
- ReleaseArgs(args)
- }
-}
-
-func Test_Client_UserAgent(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Get("/", func(c Ctx) error {
- return c.Send(c.Request().Header.UserAgent())
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- t.Run("default", func(t *testing.T) {
- t.Parallel()
- for i := 0; i < 5; i++ {
- a := Get("http://example.com")
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- code, body, errs := a.String()
-
- require.Equal(t, StatusOK, code)
- require.Equal(t, defaultUserAgent, body)
- require.Empty(t, errs)
- }
- })
-
- t.Run("custom", func(t *testing.T) {
- t.Parallel()
- for i := 0; i < 5; i++ {
- c := AcquireClient()
- c.UserAgent = "ua"
-
- a := c.Get("http://example.com")
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- code, body, errs := a.String()
-
- require.Equal(t, StatusOK, code)
- require.Equal(t, "ua", body)
- require.Empty(t, errs)
- ReleaseClient(c)
- }
- })
-}
-
-func Test_Client_Agent_Set_Or_Add_Headers(t *testing.T) {
- t.Parallel()
- handler := func(c Ctx) error {
- c.Request().Header.VisitAll(func(key, value []byte) {
- if k := string(key); k == "K1" || k == "K2" {
- _, err := c.Write(key)
- require.NoError(t, err)
- _, err = c.Write(value)
- require.NoError(t, err)
- }
- })
- return nil
- }
-
- wrapAgent := func(a *Agent) {
- a.Set("k1", "v1").
- SetBytesK([]byte("k1"), "v1").
- SetBytesV("k1", []byte("v1")).
- AddBytesK([]byte("k1"), "v11").
- AddBytesV("k1", []byte("v22")).
- AddBytesKV([]byte("k1"), []byte("v33")).
- SetBytesKV([]byte("k2"), []byte("v2")).
- Add("k2", "v22")
- }
-
- testAgent(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22")
-}
-
-func Test_Client_Agent_Connection_Close(t *testing.T) {
- t.Parallel()
- handler := func(c Ctx) error {
- if c.Request().Header.ConnectionClose() {
- return c.SendString("close")
- }
- return c.SendString("not close")
- }
-
- wrapAgent := func(a *Agent) {
- a.ConnectionClose()
- }
-
- testAgent(t, handler, wrapAgent, "close")
-}
-
-func Test_Client_Agent_UserAgent(t *testing.T) {
- t.Parallel()
- handler := func(c Ctx) error {
- return c.Send(c.Request().Header.UserAgent())
- }
-
- wrapAgent := func(a *Agent) {
- a.UserAgent("ua").
- UserAgentBytes([]byte("ua"))
- }
-
- testAgent(t, handler, wrapAgent, "ua")
-}
-
-func Test_Client_Agent_Cookie(t *testing.T) {
- t.Parallel()
- handler := func(c Ctx) error {
- return c.SendString(
- c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4"))
- }
-
- wrapAgent := func(a *Agent) {
- a.Cookie("k1", "v1").
- CookieBytesK([]byte("k2"), "v2").
- CookieBytesKV([]byte("k2"), []byte("v2")).
- Cookies("k3", "v3", "k4", "v4").
- CookiesBytesKV([]byte("k3"), []byte("v3"), []byte("k4"), []byte("v4"))
- }
-
- testAgent(t, handler, wrapAgent, "v1v2v3v4")
-}
-
-func Test_Client_Agent_Referer(t *testing.T) {
- t.Parallel()
- handler := func(c Ctx) error {
- return c.Send(c.Request().Header.Referer())
- }
-
- wrapAgent := func(a *Agent) {
- a.Referer("http://referer.com").
- RefererBytes([]byte("http://referer.com"))
- }
-
- testAgent(t, handler, wrapAgent, "http://referer.com")
-}
-
-func Test_Client_Agent_ContentType(t *testing.T) {
- t.Parallel()
- handler := func(c Ctx) error {
- return c.Send(c.Request().Header.ContentType())
- }
-
- wrapAgent := func(a *Agent) {
- a.ContentType("custom-type").
- ContentTypeBytes([]byte("custom-type"))
- }
-
- testAgent(t, handler, wrapAgent, "custom-type")
-}
-
-func Test_Client_Agent_Host(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Get("/", func(c Ctx) error {
- return c.SendString(c.Host())
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- a := Get("http://1.1.1.1:8080").
- Host("example.com").
- HostBytes([]byte("example.com"))
-
- require.Equal(t, "1.1.1.1:8080", a.HostClient.Addr)
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- code, body, errs := a.String()
-
- require.Equal(t, StatusOK, code)
- require.Equal(t, "example.com", body)
- require.Empty(t, errs)
-}
-
-func Test_Client_Agent_QueryString(t *testing.T) {
- t.Parallel()
- handler := func(c Ctx) error {
- return c.Send(c.Request().URI().QueryString())
- }
-
- wrapAgent := func(a *Agent) {
- a.QueryString("foo=bar&bar=baz").
- QueryStringBytes([]byte("foo=bar&bar=baz"))
- }
-
- testAgent(t, handler, wrapAgent, "foo=bar&bar=baz")
-}
-
-func Test_Client_Agent_BasicAuth(t *testing.T) {
- t.Parallel()
- handler := func(c Ctx) error {
- // Get authorization header
- auth := c.Get(HeaderAuthorization)
- // Decode the header contents
- raw, err := base64.StdEncoding.DecodeString(auth[6:])
- require.NoError(t, err)
-
- return c.Send(raw)
- }
-
- wrapAgent := func(a *Agent) {
- a.BasicAuth("foo", "bar").
- BasicAuthBytes([]byte("foo"), []byte("bar"))
- }
-
- testAgent(t, handler, wrapAgent, "foo:bar")
-}
-
-func Test_Client_Agent_BodyString(t *testing.T) {
- t.Parallel()
- handler := func(c Ctx) error {
- return c.Send(c.Request().Body())
- }
-
- wrapAgent := func(a *Agent) {
- a.BodyString("foo=bar&bar=baz")
- }
-
- testAgent(t, handler, wrapAgent, "foo=bar&bar=baz")
-}
-
-func Test_Client_Agent_Body(t *testing.T) {
- t.Parallel()
- handler := func(c Ctx) error {
- return c.Send(c.Request().Body())
- }
-
- wrapAgent := func(a *Agent) {
- a.Body([]byte("foo=bar&bar=baz"))
- }
-
- testAgent(t, handler, wrapAgent, "foo=bar&bar=baz")
-}
-
-func Test_Client_Agent_BodyStream(t *testing.T) {
- t.Parallel()
- handler := func(c Ctx) error {
- return c.Send(c.Request().Body())
- }
-
- wrapAgent := func(a *Agent) {
- a.BodyStream(strings.NewReader("body stream"), -1)
- }
-
- testAgent(t, handler, wrapAgent, "body stream")
-}
-
-func Test_Client_Agent_Custom_Response(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Get("/", func(c Ctx) error {
- return c.SendString("custom")
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- for i := 0; i < 5; i++ {
- a := AcquireAgent()
- resp := AcquireResponse()
-
- req := a.Request()
- req.Header.SetMethod(MethodGet)
- req.SetRequestURI("http://example.com")
-
- require.NoError(t, a.Parse())
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- code, body, errs := a.SetResponse(resp).
- String()
-
- require.Equal(t, StatusOK, code)
- require.Equal(t, "custom", body)
- require.Equal(t, "custom", string(resp.Body()))
- require.Empty(t, errs)
-
- ReleaseResponse(resp)
- }
-}
-
-func Test_Client_Agent_Dest(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Get("/", func(c Ctx) error {
- return c.SendString("dest")
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- t.Run("small dest", func(t *testing.T) {
- t.Parallel()
- dest := []byte("de")
-
- a := Get("http://example.com")
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- code, body, errs := a.Dest(dest[:0]).String()
-
- require.Equal(t, StatusOK, code)
- require.Equal(t, "dest", body)
- require.Equal(t, "de", string(dest))
- require.Empty(t, errs)
- })
-
- t.Run("enough dest", func(t *testing.T) {
- t.Parallel()
- dest := []byte("foobar")
-
- a := Get("http://example.com")
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- code, body, errs := a.Dest(dest[:0]).String()
-
- require.Equal(t, StatusOK, code)
- require.Equal(t, "dest", body)
- require.Equal(t, "destar", string(dest))
- require.Empty(t, errs)
- })
-}
-
-// readErrorConn is a struct for testing retryIf
-type readErrorConn struct {
- net.Conn
-}
-
-func (*readErrorConn) Read(_ []byte) (int, error) {
- return 0, errors.New("error")
-}
-
-func (*readErrorConn) Write(p []byte) (int, error) {
- return len(p), nil
-}
-
-func (*readErrorConn) Close() error {
- return nil
-}
-
-func (*readErrorConn) LocalAddr() net.Addr {
- return nil
-}
-
-func (*readErrorConn) RemoteAddr() net.Addr {
- return nil
-}
-
-func (*readErrorConn) SetReadDeadline(_ time.Time) error {
- return nil
-}
-
-func (*readErrorConn) SetWriteDeadline(_ time.Time) error {
- return nil
-}
-
-func Test_Client_Agent_RetryIf(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- a := Post("http://example.com").
- RetryIf(func(_ *Request) bool {
- return true
- })
- dialsCount := 0
- a.HostClient.Dial = func(_ string) (net.Conn, error) {
- dialsCount++
- switch dialsCount {
- case 1:
- return &readErrorConn{}, nil
- case 2:
- return &readErrorConn{}, nil
- case 3:
- return &readErrorConn{}, nil
- case 4:
- return ln.Dial()
- default:
- t.Fatalf("unexpected number of dials: %d", dialsCount)
- }
- panic("unreachable")
- }
-
- _, _, errs := a.String()
- require.Equal(t, 4, dialsCount)
- require.Empty(t, errs)
-}
-
-func Test_Client_Agent_Json(t *testing.T) {
- t.Parallel()
- // Test without ctype parameter
- handler := func(c Ctx) error {
- require.Equal(t, MIMEApplicationJSON, string(c.Request().Header.ContentType()))
-
- return c.Send(c.Request().Body())
- }
-
- wrapAgent := func(a *Agent) {
- a.JSON(data{Success: true})
- }
-
- testAgent(t, handler, wrapAgent, `{"success":true}`)
-
- // Test with ctype parameter
- handler = func(c Ctx) error {
- require.Equal(t, "application/problem+json", string(c.Request().Header.ContentType()))
-
- return c.Send(c.Request().Body())
- }
-
- wrapAgent = func(a *Agent) {
- a.JSON(data{Success: true}, "application/problem+json")
- }
-
- testAgent(t, handler, wrapAgent, `{"success":true}`)
-}
-
-func Test_Client_Agent_Json_Error(t *testing.T) {
- t.Parallel()
- a := Get("http://example.com").
- JSONEncoder(json.Marshal).
- JSON(complex(1, 1))
-
- _, body, errs := a.String()
-
- require.Equal(t, "", body)
- require.Len(t, errs, 1)
- wantErr := new(json.UnsupportedTypeError)
- require.ErrorAs(t, errs[0], &wantErr)
-}
-
-func Test_Client_Agent_XML(t *testing.T) {
- t.Parallel()
- handler := func(c Ctx) error {
- require.Equal(t, MIMEApplicationXML, string(c.Request().Header.ContentType()))
-
- return c.Send(c.Request().Body())
- }
-
- wrapAgent := func(a *Agent) {
- a.XML(data{Success: true})
- }
-
- testAgent(t, handler, wrapAgent, "true")
-}
-
-func Test_Client_Agent_XML_Error(t *testing.T) {
- t.Parallel()
- a := Get("http://example.com").
- XML(complex(1, 1))
-
- _, body, errs := a.String()
- require.Equal(t, "", body)
- require.Len(t, errs, 1)
- wantErr := new(xml.UnsupportedTypeError)
- require.ErrorAs(t, errs[0], &wantErr)
-}
-
-func Test_Client_Agent_Form(t *testing.T) {
- t.Parallel()
- handler := func(c Ctx) error {
- require.Equal(t, MIMEApplicationForm, string(c.Request().Header.ContentType()))
-
- return c.Send(c.Request().Body())
- }
-
- args := AcquireArgs()
-
- args.Set("foo", "bar")
-
- wrapAgent := func(a *Agent) {
- a.Form(args)
- }
-
- testAgent(t, handler, wrapAgent, "foo=bar")
-
- ReleaseArgs(args)
-}
-
-func Test_Client_Agent_MultipartForm(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Post("/", func(c Ctx) error {
- require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(HeaderContentType))
-
- mf, err := c.MultipartForm()
- require.NoError(t, err)
- require.Equal(t, "bar", mf.Value["foo"][0])
-
- return c.Send(c.Request().Body())
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- args := AcquireArgs()
-
- args.Set("foo", "bar")
-
- a := Post("http://example.com").
- Boundary("myBoundary").
- MultipartForm(args)
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- code, body, errs := a.String()
-
- require.Equal(t, StatusOK, code)
- require.Equal(t, "--myBoundary\r\nContent-Disposition: form-data; name=\"foo\"\r\n\r\nbar\r\n--myBoundary--\r\n", body)
- require.Empty(t, errs)
- ReleaseArgs(args)
-}
-
-func Test_Client_Agent_MultipartForm_Errors(t *testing.T) {
- t.Parallel()
-
- a := AcquireAgent()
- a.mw = &errorMultipartWriter{}
-
- args := AcquireArgs()
- args.Set("foo", "bar")
-
- ff1 := &FormFile{"", "name1", []byte("content"), false}
- ff2 := &FormFile{"", "name2", []byte("content"), false}
- a.FileData(ff1, ff2).
- MultipartForm(args)
-
- require.Len(t, a.errs, 4)
- ReleaseArgs(args)
-}
-
-func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Post("/", func(c Ctx) error {
- require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(HeaderContentType))
-
- fh1, err := c.FormFile("field1")
- require.NoError(t, err)
- require.Equal(t, "name", fh1.Filename)
- buf := make([]byte, fh1.Size)
- f, err := fh1.Open()
- require.NoError(t, err)
- defer func() {
- err := f.Close()
- require.NoError(t, err)
- }()
- _, err = f.Read(buf)
- require.NoError(t, err)
- require.Equal(t, "form file", string(buf))
-
- fh2, err := c.FormFile("index")
- require.NoError(t, err)
- checkFormFile(t, fh2, ".github/testdata/index.html")
-
- fh3, err := c.FormFile("file3")
- require.NoError(t, err)
- checkFormFile(t, fh3, ".github/testdata/index.tmpl")
-
- return c.SendString("multipart form files")
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- for i := 0; i < 5; i++ {
- ff := AcquireFormFile()
- ff.Fieldname = "field1"
- ff.Name = "name"
- ff.Content = []byte("form file")
-
- a := Post("http://example.com").
- Boundary("myBoundary").
- FileData(ff).
- SendFiles(".github/testdata/index.html", "index", ".github/testdata/index.tmpl").
- MultipartForm(nil)
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- code, body, errs := a.String()
-
- require.Equal(t, StatusOK, code)
- require.Equal(t, "multipart form files", body)
- require.Empty(t, errs)
-
- ReleaseFormFile(ff)
- }
-}
-
-func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) {
- t.Helper()
-
- basename := filepath.Base(filename)
- require.Equal(t, fh.Filename, basename)
-
- b1, err := os.ReadFile(filename) //nolint:gosec // We're in a test so reading user-provided files by name is fine
- require.NoError(t, err)
-
- b2 := make([]byte, fh.Size)
- f, err := fh.Open()
- require.NoError(t, err)
- defer func() {
- err := f.Close()
- require.NoError(t, err)
- }()
- _, err = f.Read(b2)
- require.NoError(t, err)
- require.Equal(t, b1, b2)
-}
-
-func Test_Client_Agent_Multipart_Random_Boundary(t *testing.T) {
- t.Parallel()
-
- a := Post("http://example.com").
- MultipartForm(nil)
-
- reg := regexp.MustCompile(`multipart/form-data; boundary=\w{30}`)
-
- require.True(t, reg.Match(a.req.Header.Peek(HeaderContentType)))
-}
-
-func Test_Client_Agent_Multipart_Invalid_Boundary(t *testing.T) {
- t.Parallel()
-
- a := Post("http://example.com").
- Boundary("*").
- MultipartForm(nil)
-
- require.Len(t, a.errs, 1)
- require.ErrorContains(t, a.errs[0], "mime: invalid boundary character")
-}
-
-func Test_Client_Agent_SendFile_Error(t *testing.T) {
- t.Parallel()
-
- a := Post("http://example.com").
- SendFile("non-exist-file!", "")
-
- require.Len(t, a.errs, 1)
- require.ErrorIs(t, a.errs[0], os.ErrNotExist)
-}
-
-func Test_Client_Debug(t *testing.T) {
- t.Parallel()
- handler := func(c Ctx) error {
- return c.SendString("debug")
- }
-
- var output bytes.Buffer
-
- wrapAgent := func(a *Agent) {
- a.Debug(&output)
- }
-
- testAgent(t, handler, wrapAgent, "debug", 1)
-
- str := output.String()
-
- require.Contains(t, str, "Connected to example.com(InmemoryListener)")
- require.Contains(t, str, "GET / HTTP/1.1")
- require.Contains(t, str, "User-Agent: fiber")
- require.Contains(t, str, "Host: example.com\r\n\r\n")
- require.Contains(t, str, "HTTP/1.1 200 OK")
- require.Contains(t, str, "Content-Type: text/plain; charset=utf-8\r\nContent-Length: 5\r\n\r\ndebug")
-}
-
-func Test_Client_Agent_Timeout(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Get("/", func(c Ctx) error {
- time.Sleep(time.Millisecond * 200)
- return c.SendString("timeout")
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- a := Get("http://example.com").
- Timeout(time.Millisecond * 50)
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- _, body, errs := a.String()
-
- require.Equal(t, "", body)
- require.Len(t, errs, 1)
- require.ErrorIs(t, errs[0], fasthttp.ErrTimeout)
-}
-
-func Test_Client_Agent_Reuse(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Get("/", func(c Ctx) error {
- return c.SendString("reuse")
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- a := Get("http://example.com").
- Reuse()
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- code, body, errs := a.String()
-
- require.Equal(t, StatusOK, code)
- require.Equal(t, "reuse", body)
- require.Empty(t, errs)
-
- code, body, errs = a.String()
-
- require.Equal(t, StatusOK, code)
- require.Equal(t, "reuse", body)
- require.Empty(t, errs)
-}
-
-func Test_Client_Agent_InsecureSkipVerify(t *testing.T) {
- t.Parallel()
-
- cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key")
- require.NoError(t, err)
-
- //nolint:gosec // We're in a test so using old ciphers is fine
- serverTLSConf := &tls.Config{
- Certificates: []tls.Certificate{cer},
- }
-
- ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0")
- require.NoError(t, err)
-
- ln = tls.NewListener(ln, serverTLSConf)
-
- app := New()
-
- app.Get("/", func(c Ctx) error {
- return c.SendString("ignore tls")
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- code, body, errs := Get("https://" + ln.Addr().String()).
- InsecureSkipVerify().
- InsecureSkipVerify().
- String()
-
- require.Empty(t, errs)
- require.Equal(t, StatusOK, code)
- require.Equal(t, "ignore tls", body)
-}
-
-func Test_Client_Agent_TLS(t *testing.T) {
- t.Parallel()
-
- serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs()
- require.NoError(t, err)
-
- ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0")
- require.NoError(t, err)
-
- ln = tls.NewListener(ln, serverTLSConf)
-
- app := New()
-
- app.Get("/", func(c Ctx) error {
- return c.SendString("tls")
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- code, body, errs := Get("https://" + ln.Addr().String()).
- TLSConfig(clientTLSConf).
- String()
-
- require.Empty(t, errs)
- require.Equal(t, StatusOK, code)
- require.Equal(t, "tls", body)
-}
-
-func Test_Client_Agent_MaxRedirectsCount(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Get("/", func(c Ctx) error {
- if c.Request().URI().QueryArgs().Has("foo") {
- return c.Redirect().To("/foo")
- }
- return c.Redirect().To("/")
- })
- app.Get("/foo", func(c Ctx) error {
- return c.SendString("redirect")
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- t.Run("success", func(t *testing.T) {
- t.Parallel()
- a := Get("http://example.com?foo").
- MaxRedirectsCount(1)
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- code, body, errs := a.String()
-
- require.Equal(t, 200, code)
- require.Equal(t, "redirect", body)
- require.Empty(t, errs)
- })
-
- t.Run("error", func(t *testing.T) {
- t.Parallel()
- a := Get("http://example.com").
- MaxRedirectsCount(1)
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- _, body, errs := a.String()
-
- require.Equal(t, "", body)
- require.Len(t, errs, 1)
- require.ErrorIs(t, errs[0], fasthttp.ErrTooManyRedirects)
- })
-}
-
-func Test_Client_Agent_Struct(t *testing.T) {
- t.Parallel()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Get("/", func(c Ctx) error {
- return c.JSON(data{true})
- })
-
- app.Get("/error", func(c Ctx) error {
- return c.SendString(`{"success"`)
- })
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- t.Run("success", func(t *testing.T) {
- t.Parallel()
-
- a := Get("http://example.com")
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- var d data
-
- code, body, errs := a.Struct(&d)
-
- require.Equal(t, StatusOK, code)
- require.Equal(t, `{"success":true}`, string(body))
- require.Empty(t, errs)
- require.True(t, d.Success)
- })
-
- t.Run("pre error", func(t *testing.T) {
- t.Parallel()
- a := Get("http://example.com")
-
- errPre := errors.New("pre errors")
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
- a.errs = append(a.errs, errPre)
-
- var d data
- _, body, errs := a.Struct(&d)
-
- require.Equal(t, "", string(body))
- require.Len(t, errs, 1)
- require.ErrorIs(t, errs[0], errPre)
- require.False(t, d.Success)
- })
-
- t.Run("error", func(t *testing.T) {
- t.Parallel()
- a := Get("http://example.com/error")
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- var d data
-
- code, body, errs := a.JSONDecoder(json.Unmarshal).Struct(&d)
-
- require.Equal(t, StatusOK, code)
- require.Equal(t, `{"success"`, string(body))
- require.Len(t, errs, 1)
- wantErr := new(json.SyntaxError)
- require.ErrorAs(t, errs[0], &wantErr)
- require.EqualValues(t, 10, wantErr.Offset)
- })
-
- t.Run("nil jsonDecoder", func(t *testing.T) {
- t.Parallel()
- a := AcquireAgent()
- defer ReleaseAgent(a)
- defer a.ConnectionClose()
- request := a.Request()
- request.Header.SetMethod(MethodGet)
- request.SetRequestURI("http://example.com")
- err := a.Parse()
- require.NoError(t, err)
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
- var d data
- code, body, errs := a.Struct(&d)
- require.Equal(t, StatusOK, code)
- require.Equal(t, `{"success":true}`, string(body))
- require.Empty(t, errs)
- require.True(t, d.Success)
- })
-}
-
-func Test_Client_Agent_Parse(t *testing.T) {
- t.Parallel()
-
- a := Get("https://example.com:10443")
-
- require.NoError(t, a.Parse())
-}
-
-func testAgent(t *testing.T, handler Handler, wrapAgent func(agent *Agent), excepted string, count ...int) {
- t.Helper()
-
- ln := fasthttputil.NewInmemoryListener()
-
- app := New()
-
- app.Get("/", handler)
-
- go func() {
- require.NoError(t, app.Listener(ln, ListenConfig{
- DisableStartupMessage: true,
- }))
- }()
-
- c := 1
- if len(count) > 0 {
- c = count[0]
- }
-
- for i := 0; i < c; i++ {
- a := Get("http://example.com")
-
- wrapAgent(a)
-
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
-
- code, body, errs := a.String()
-
- require.Equal(t, StatusOK, code)
- require.Equal(t, excepted, body)
- require.Empty(t, errs)
- }
-}
-
-type data struct {
- Success bool `json:"success" xml:"success"`
-}
-
-type errorMultipartWriter struct {
- count int
-}
-
-func (*errorMultipartWriter) Boundary() string { return "myBoundary" }
-func (*errorMultipartWriter) SetBoundary(_ string) error { return nil }
-func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) {
- if e.count == 0 {
- e.count++
- return nil, errors.New("CreateFormFile error")
- }
- return errorWriter{}, nil
-}
-func (*errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") }
-func (*errorMultipartWriter) Close() error { return errors.New("Close error") }
-
-type errorWriter struct{}
-
-func (errorWriter) Write(_ []byte) (int, error) { return 0, errors.New("Write error") }
diff --git a/listen_test.go b/listen_test.go
index d92b9fb396..a5d419ac86 100644
--- a/listen_test.go
+++ b/listen_test.go
@@ -17,6 +17,7 @@ import (
"time"
"github.com/stretchr/testify/require"
+ "github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttputil"
)
@@ -68,22 +69,30 @@ func Test_Listen_Graceful_Shutdown(t *testing.T) {
Time time.Duration
ExpectedBody string
ExpectedStatusCode int
- ExceptedErrsLen int
+ ExpectedErr error
}{
- {Time: 100 * time.Millisecond, ExpectedBody: "example.com", ExpectedStatusCode: StatusOK, ExceptedErrsLen: 0},
- {Time: 500 * time.Millisecond, ExpectedBody: "", ExpectedStatusCode: 0, ExceptedErrsLen: 1},
+ {Time: 100 * time.Millisecond, ExpectedBody: "example.com", ExpectedStatusCode: StatusOK, ExpectedErr: nil},
+ {Time: 500 * time.Millisecond, ExpectedBody: "", ExpectedStatusCode: StatusOK, ExpectedErr: errors.New("InmemoryListener is already closed: use of closed network connection")},
}
for _, tc := range testCases {
time.Sleep(tc.Time)
- a := Get("http://example.com")
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
- code, body, errs := a.String()
+ req := fasthttp.AcquireRequest()
+ req.SetRequestURI("http://example.com")
- require.Equal(t, tc.ExpectedStatusCode, code)
- require.Equal(t, tc.ExpectedBody, body)
- require.Len(t, errs, tc.ExceptedErrsLen)
+ client := fasthttp.HostClient{}
+ client.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
+
+ resp := fasthttp.AcquireResponse()
+ err := client.Do(req, resp)
+
+ require.Equal(t, tc.ExpectedErr, err)
+ require.Equal(t, tc.ExpectedStatusCode, resp.StatusCode())
+ require.Equal(t, tc.ExpectedBody, string(resp.Body()))
+
+ fasthttp.ReleaseRequest(req)
+ fasthttp.ReleaseResponse(resp)
}
mu.Lock()
diff --git a/log/default.go b/log/default.go
index e9c3d1bbbd..67fa137ecf 100644
--- a/log/default.go
+++ b/log/default.go
@@ -34,6 +34,7 @@ func (l *defaultLogger) privateLog(lv Level, fmtArgs []any) {
if lv == LevelPanic {
panic(buf.String())
}
+
buf.Reset()
bytebufferpool.Put(buf)
if lv == LevelFatal {
@@ -56,6 +57,7 @@ func (l *defaultLogger) privateLogf(lv Level, format string, fmtArgs []any) {
} else {
_, _ = fmt.Fprint(buf, fmtArgs...)
}
+
_ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error
if lv == LevelPanic {
panic(buf.String())
diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go
index 284b67c8f5..167a0e6f31 100644
--- a/middleware/proxy/proxy.go
+++ b/middleware/proxy/proxy.go
@@ -2,7 +2,6 @@ package proxy
import (
"bytes"
- "crypto/tls"
"net/url"
"strings"
"sync"
@@ -105,13 +104,6 @@ var client = &fasthttp.Client{
var lock sync.RWMutex
-// WithTLSConfig update http client with a user specified tls.config
-// This function should be called before Do and Forward.
-// Deprecated: use WithClient instead.
-func WithTLSConfig(tlsConfig *tls.Config) {
- client.TLSConfig = tlsConfig
-}
-
// WithClient sets the global proxy client.
// This function should be called before Do and Forward.
func WithClient(cli *fasthttp.Client) {
diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go
index 408ee71a5f..4aa0065040 100644
--- a/middleware/proxy/proxy_test.go
+++ b/middleware/proxy/proxy_test.go
@@ -11,8 +11,10 @@ import (
"time"
"github.com/gofiber/fiber/v3"
- "github.com/gofiber/fiber/v3/internal/tlstest"
+ clientpkg "github.com/gofiber/fiber/v3/client"
"github.com/stretchr/testify/require"
+
+ "github.com/gofiber/fiber/v3/internal/tlstest"
"github.com/valyala/fasthttp"
)
@@ -25,8 +27,6 @@ func createProxyTestServer(t *testing.T, handler fiber.Handler) (*fiber.App, str
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
require.NoError(t, err)
- addr := ln.Addr().String()
-
go func() {
require.NoError(t, target.Listener(ln, fiber.ListenConfig{
DisableStartupMessage: true,
@@ -34,6 +34,7 @@ func createProxyTestServer(t *testing.T, handler fiber.Handler) (*fiber.App, str
}()
time.Sleep(2 * time.Second)
+ addr := ln.Addr().String()
return target, addr
}
@@ -104,8 +105,8 @@ func Test_Proxy(t *testing.T) {
require.Equal(t, fiber.StatusTeapot, resp.StatusCode)
}
-// go test -run Test_Proxy_Balancer_WithTLSConfig
-func Test_Proxy_Balancer_WithTLSConfig(t *testing.T) {
+// go test -run Test_Proxy_Balancer_WithTlsConfig
+func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) {
t.Parallel()
serverTLSConf, _, err := tlstest.GetTLSConfigs()
@@ -118,7 +119,7 @@ func Test_Proxy_Balancer_WithTLSConfig(t *testing.T) {
app := fiber.New()
- app.Get("/tlsbalaner", func(c fiber.Ctx) error {
+ app.Get("/tlsbalancer", func(c fiber.Ctx) error {
return c.SendString("tls balancer")
})
@@ -137,15 +138,18 @@ func Test_Proxy_Balancer_WithTLSConfig(t *testing.T) {
}))
}()
- code, body, errs := fiber.Get("https://" + addr + "/tlsbalaner").TLSConfig(clientTLSConf).String()
+ client := clientpkg.NewClient()
+ client.SetTLSConfig(clientTLSConf)
- require.Empty(t, errs)
- require.Equal(t, fiber.StatusOK, code)
- require.Equal(t, "tls balancer", body)
+ resp, err := client.Get("https://" + addr + "/tlsbalancer")
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "tls balancer", string(resp.Body()))
+ resp.Close()
}
-// go test -run Test_Proxy_Forward_WithTLSConfig_To_Http
-func Test_Proxy_Forward_WithTLSConfig_To_Http(t *testing.T) {
+// go test -run Test_Proxy_Forward_WithTlsConfig_To_Http
+func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) {
t.Parallel()
_, targetAddr := createProxyTestServer(t, func(c fiber.Ctx) error {
@@ -172,14 +176,15 @@ func Test_Proxy_Forward_WithTLSConfig_To_Http(t *testing.T) {
}))
}()
- code, body, errs := fiber.Get("https://" + proxyAddr).
- InsecureSkipVerify().
- Timeout(5 * time.Second).
- String()
+ client := clientpkg.NewClient()
+ client.SetTimeout(5 * time.Second)
+ client.TLSConfig().InsecureSkipVerify = true
- require.Empty(t, errs)
- require.Equal(t, fiber.StatusOK, code)
- require.Equal(t, "hello from target", body)
+ resp, err := client.Get("https://" + proxyAddr)
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "hello from target", string(resp.Body()))
+ resp.Close()
}
// go test -run Test_Proxy_Forward
@@ -203,8 +208,8 @@ func Test_Proxy_Forward(t *testing.T) {
require.Equal(t, "forwarded", string(b))
}
-// go test -run Test_Proxy_Forward_WithTLSConfig
-func Test_Proxy_Forward_WithTLSConfig(t *testing.T) {
+// go test -run Test_Proxy_Forward_WithClient_TLSConfig
+func Test_Proxy_Forward_WithClient_TLSConfig(t *testing.T) {
t.Parallel()
serverTLSConf, _, err := tlstest.GetTLSConfigs()
@@ -225,7 +230,9 @@ func Test_Proxy_Forward_WithTLSConfig(t *testing.T) {
clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine
// disable certificate verification
- WithTLSConfig(clientTLSConf)
+ WithClient(&fasthttp.Client{
+ TLSConfig: clientTLSConf,
+ })
app.Use(Forward("https://" + addr + "/tlsfwd"))
go func() {
@@ -234,11 +241,14 @@ func Test_Proxy_Forward_WithTLSConfig(t *testing.T) {
}))
}()
- code, body, errs := fiber.Get("https://" + addr).TLSConfig(clientTLSConf).String()
+ client := clientpkg.NewClient()
+ client.SetTLSConfig(clientTLSConf)
- require.Empty(t, errs)
- require.Equal(t, fiber.StatusOK, code)
- require.Equal(t, "tls forward", body)
+ resp, err := client.Get("https://" + addr)
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "tls forward", string(resp.Body()))
+ resp.Close()
}
// go test -run Test_Proxy_Modify_Response
@@ -415,7 +425,7 @@ func Test_Proxy_Do_WithRedirect(t *testing.T) {
return Do(c, "https://google.com")
})
- resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), 1500)
+ resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
require.NoError(t, err1)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
@@ -431,7 +441,7 @@ func Test_Proxy_DoRedirects_RestoreOriginalURL(t *testing.T) {
return DoRedirects(c, "http://google.com", 1)
})
- resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), 1500)
+ resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
require.NoError(t, err1)
_, err := io.ReadAll(resp.Body)
require.NoError(t, err)
@@ -447,7 +457,7 @@ func Test_Proxy_DoRedirects_TooManyRedirects(t *testing.T) {
return DoRedirects(c, "http://google.com", 0)
})
- resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), 1500)
+ resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
require.NoError(t, err1)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
@@ -586,10 +596,13 @@ func Test_Proxy_Forward_Global_Client(t *testing.T) {
}))
}()
- code, body, errs := fiber.Get("http://" + addr).String()
- require.Empty(t, errs)
- require.Equal(t, fiber.StatusOK, code)
- require.Equal(t, "test_global_client", body)
+ client := clientpkg.NewClient()
+
+ resp, err := client.Get("http://" + addr)
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "test_global_client", string(resp.Body()))
+ resp.Close()
}
// go test -race -run Test_Proxy_Forward_Local_Client
@@ -615,10 +628,13 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) {
}))
}()
- code, body, errs := fiber.Get("http://" + addr).String()
- require.Empty(t, errs)
- require.Equal(t, fiber.StatusOK, code)
- require.Equal(t, "test_local_client", body)
+ client := clientpkg.NewClient()
+
+ resp, err := client.Get("http://" + addr)
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "test_local_client", string(resp.Body()))
+ resp.Close()
}
// go test -run Test_ProxyBalancer_Custom_Client
@@ -666,7 +682,7 @@ func Test_Proxy_Domain_Forward_Local(t *testing.T) {
app1 := fiber.New()
app1.Get("/test", func(c fiber.Ctx) error {
- return c.SendString("test_local_client:" + fiber.Query[string](c, "query_test"))
+ return c.SendString("test_local_client:" + c.Query("query_test"))
})
proxyAddr := ln.Addr().String()
@@ -679,13 +695,24 @@ func Test_Proxy_Domain_Forward_Local(t *testing.T) {
Dial: fasthttp.Dial,
}))
- go func() { require.NoError(t, app.Listener(ln)) }()
- go func() { require.NoError(t, app1.Listener(ln1)) }()
+ go func() {
+ require.NoError(t, app.Listener(ln, fiber.ListenConfig{
+ DisableStartupMessage: true,
+ }))
+ }()
+ go func() {
+ require.NoError(t, app1.Listener(ln1, fiber.ListenConfig{
+ DisableStartupMessage: true,
+ }))
+ }()
+
+ client := clientpkg.NewClient()
- code, body, errs := fiber.Get("http://" + localDomain + "/test?query_test=true").String()
- require.Empty(t, errs)
- require.Equal(t, fiber.StatusOK, code)
- require.Equal(t, "test_local_client:true", body)
+ resp, err := client.Get("http://" + localDomain + "/test?query_test=true")
+ require.NoError(t, err)
+ require.Equal(t, fiber.StatusOK, resp.StatusCode())
+ require.Equal(t, "test_local_client:true", string(resp.Body()))
+ resp.Close()
}
// go test -run Test_Proxy_Balancer_Forward_Local
diff --git a/redirect_test.go b/redirect_test.go
index d49f526771..6dd2ae6d19 100644
--- a/redirect_test.go
+++ b/redirect_test.go
@@ -291,41 +291,45 @@ func Test_Redirect_Request(t *testing.T) {
CookieValue string
ExpectedBody string
ExpectedStatusCode int
- ExceptedErrsLen int
+ ExpectedErr error
}{
{
URL: "/",
CookieValue: "key:value,key2:value2,co\\:m\\,ma:Fi\\:ber\\, v3",
ExpectedBody: `{"inputs":{},"messages":{"co:m,ma":"Fi:ber, v3","key":"value","key2":"value2"}}`,
ExpectedStatusCode: StatusOK,
- ExceptedErrsLen: 0,
+ ExpectedErr: nil,
},
{
URL: "/with-inputs?name=john&surname=doe",
CookieValue: "key:value,key2:value2,key:value,key2:value2,old_input_data_name:john,old_input_data_surname:doe",
ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{"key":"value","key2":"value2"}}`,
ExpectedStatusCode: StatusOK,
- ExceptedErrsLen: 0,
+ ExpectedErr: nil,
},
{
URL: "/just-inputs?name=john&surname=doe",
CookieValue: "old_input_data_name:john,old_input_data_surname:doe",
ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{}}`,
ExpectedStatusCode: StatusOK,
- ExceptedErrsLen: 0,
+ ExpectedErr: nil,
},
}
for _, tc := range testCases {
- a := Get("http://example.com" + tc.URL)
- a.Cookie(FlashCookieName, tc.CookieValue)
- a.MaxRedirectsCount(1)
- a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() }
- code, body, errs := a.String()
-
- require.Equal(t, tc.ExpectedStatusCode, code)
- require.Equal(t, tc.ExpectedBody, body)
- require.Len(t, errs, tc.ExceptedErrsLen)
+ client := &fasthttp.HostClient{
+ Dial: func(_ string) (net.Conn, error) {
+ return ln.Dial()
+ },
+ }
+ req, resp := fasthttp.AcquireRequest(), fasthttp.AcquireResponse()
+ req.SetRequestURI("http://example.com" + tc.URL)
+ req.Header.SetCookie(FlashCookieName, tc.CookieValue)
+ err := client.DoRedirects(req, resp, 1)
+
+ require.NoError(t, err)
+ require.Equal(t, tc.ExpectedBody, string(resp.Body()))
+ require.Equal(t, tc.ExpectedStatusCode, resp.StatusCode())
}
}