diff --git a/client.go b/client.go deleted file mode 100644 index 8825f9d815..0000000000 --- a/client.go +++ /dev/null @@ -1,1021 +0,0 @@ -package fiber - -import ( - "bytes" - "crypto/tls" - "encoding/json" - "encoding/xml" - "errors" - "fmt" - "io" - "mime/multipart" - "os" - "path/filepath" - "strconv" - "sync" - "time" - - "github.com/gofiber/utils/v2" - "github.com/valyala/fasthttp" -) - -// Request represents HTTP request. -// -// It is forbidden copying Request instances. Create new instances -// and use CopyTo instead. -// -// Request instance MUST NOT be used from concurrently running goroutines. -// Copy from fasthttp -type Request = fasthttp.Request - -// Response represents HTTP response. -// -// It is forbidden copying Response instances. Create new instances -// and use CopyTo instead. -// -// Response instance MUST NOT be used from concurrently running goroutines. -// Copy from fasthttp -type Response = fasthttp.Response - -// Args represents query arguments. -// -// It is forbidden copying Args instances. Create new instances instead -// and use CopyTo(). -// -// Args instance MUST NOT be used from concurrently running goroutines. -// Copy from fasthttp -type Args = fasthttp.Args - -// RetryIfFunc signature of retry if function -// Request argument passed to RetryIfFunc, if there are any request errors. -// Copy from fasthttp -type RetryIfFunc = fasthttp.RetryIfFunc - -var defaultClient Client - -// Client implements http client. -// -// It is safe calling Client methods from concurrently running goroutines. -type Client struct { - mutex sync.RWMutex - // UserAgent is used in User-Agent request header. - UserAgent string - - // NoDefaultUserAgentHeader when set to true, causes the default - // User-Agent header to be excluded from the Request. - NoDefaultUserAgentHeader bool - - // When set by an external client of Fiber it will use the provided implementation of a - // JSONMarshal - // - // Allowing for flexibility in using another json library for encoding - JSONEncoder utils.JSONMarshal - - // When set by an external client of Fiber it will use the provided implementation of a - // JSONUnmarshal - // - // Allowing for flexibility in using another json library for decoding - JSONDecoder utils.JSONUnmarshal -} - -// Get returns an agent with http method GET. -func Get(url string) *Agent { return defaultClient.Get(url) } - -// Get returns an agent with http method GET. -func (c *Client) Get(url string) *Agent { - return c.createAgent(MethodGet, url) -} - -// Head returns an agent with http method HEAD. -func Head(url string) *Agent { return defaultClient.Head(url) } - -// Head returns an agent with http method GET. -func (c *Client) Head(url string) *Agent { - return c.createAgent(MethodHead, url) -} - -// Post sends POST request to the given URL. -func Post(url string) *Agent { return defaultClient.Post(url) } - -// Post sends POST request to the given URL. -func (c *Client) Post(url string) *Agent { - return c.createAgent(MethodPost, url) -} - -// Put sends PUT request to the given URL. -func Put(url string) *Agent { return defaultClient.Put(url) } - -// Put sends PUT request to the given URL. -func (c *Client) Put(url string) *Agent { - return c.createAgent(MethodPut, url) -} - -// Patch sends PATCH request to the given URL. -func Patch(url string) *Agent { return defaultClient.Patch(url) } - -// Patch sends PATCH request to the given URL. -func (c *Client) Patch(url string) *Agent { - return c.createAgent(MethodPatch, url) -} - -// Delete sends DELETE request to the given URL. -func Delete(url string) *Agent { return defaultClient.Delete(url) } - -// Delete sends DELETE request to the given URL. -func (c *Client) Delete(url string) *Agent { - return c.createAgent(MethodDelete, url) -} - -func (c *Client) createAgent(method, url string) *Agent { - a := AcquireAgent() - a.req.Header.SetMethod(method) - a.req.SetRequestURI(url) - - c.mutex.RLock() - a.Name = c.UserAgent - a.NoDefaultUserAgentHeader = c.NoDefaultUserAgentHeader - a.jsonDecoder = c.JSONDecoder - a.jsonEncoder = c.JSONEncoder - if a.jsonDecoder == nil { - a.jsonDecoder = json.Unmarshal - } - c.mutex.RUnlock() - - if err := a.Parse(); err != nil { - a.errs = append(a.errs, err) - } - - return a -} - -// Agent is an object storing all request data for client. -// Agent instance MUST NOT be used from concurrently running goroutines. -type Agent struct { - // Name is used in User-Agent request header. - Name string - - // NoDefaultUserAgentHeader when set to true, causes the default - // User-Agent header to be excluded from the Request. - NoDefaultUserAgentHeader bool - - // HostClient is an embedded fasthttp HostClient - *fasthttp.HostClient - - req *Request - resp *Response - dest []byte - args *Args - timeout time.Duration - errs []error - formFiles []*FormFile - debugWriter io.Writer - mw multipartWriter - jsonEncoder utils.JSONMarshal - jsonDecoder utils.JSONUnmarshal - maxRedirectsCount int - boundary string - reuse bool - parsed bool -} - -// Parse initializes URI and HostClient. -func (a *Agent) Parse() error { - if a.parsed { - return nil - } - a.parsed = true - - uri := a.req.URI() - - var isTLS bool - scheme := uri.Scheme() - if bytes.Equal(scheme, []byte(schemeHTTPS)) { - isTLS = true - } else if !bytes.Equal(scheme, []byte(schemeHTTP)) { - return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme) - } - - name := a.Name - if name == "" && !a.NoDefaultUserAgentHeader { - name = defaultUserAgent - } - - a.HostClient = &fasthttp.HostClient{ - Addr: fasthttp.AddMissingPort(string(uri.Host()), isTLS), - Name: name, - NoDefaultUserAgentHeader: a.NoDefaultUserAgentHeader, - IsTLS: isTLS, - } - - return nil -} - -/************************** Header Setting **************************/ - -// Set sets the given 'key: value' header. -// -// Use Add for setting multiple header values under the same key. -func (a *Agent) Set(k, v string) *Agent { - a.req.Header.Set(k, v) - - return a -} - -// SetBytesK sets the given 'key: value' header. -// -// Use AddBytesK for setting multiple header values under the same key. -func (a *Agent) SetBytesK(k []byte, v string) *Agent { - a.req.Header.SetBytesK(k, v) - - return a -} - -// SetBytesV sets the given 'key: value' header. -// -// Use AddBytesV for setting multiple header values under the same key. -func (a *Agent) SetBytesV(k string, v []byte) *Agent { - a.req.Header.SetBytesV(k, v) - - return a -} - -// SetBytesKV sets the given 'key: value' header. -// -// Use AddBytesKV for setting multiple header values under the same key. -func (a *Agent) SetBytesKV(k, v []byte) *Agent { - a.req.Header.SetBytesKV(k, v) - - return a -} - -// Add adds the given 'key: value' header. -// -// Multiple headers with the same key may be added with this function. -// Use Set for setting a single header for the given key. -func (a *Agent) Add(k, v string) *Agent { - a.req.Header.Add(k, v) - - return a -} - -// AddBytesK adds the given 'key: value' header. -// -// Multiple headers with the same key may be added with this function. -// Use SetBytesK for setting a single header for the given key. -func (a *Agent) AddBytesK(k []byte, v string) *Agent { - a.req.Header.AddBytesK(k, v) - - return a -} - -// AddBytesV adds the given 'key: value' header. -// -// Multiple headers with the same key may be added with this function. -// Use SetBytesV for setting a single header for the given key. -func (a *Agent) AddBytesV(k string, v []byte) *Agent { - a.req.Header.AddBytesV(k, v) - - return a -} - -// AddBytesKV adds the given 'key: value' header. -// -// Multiple headers with the same key may be added with this function. -// Use SetBytesKV for setting a single header for the given key. -func (a *Agent) AddBytesKV(k, v []byte) *Agent { - a.req.Header.AddBytesKV(k, v) - - return a -} - -// ConnectionClose sets 'Connection: close' header. -func (a *Agent) ConnectionClose() *Agent { - a.req.Header.SetConnectionClose() - - return a -} - -// UserAgent sets User-Agent header value. -func (a *Agent) UserAgent(userAgent string) *Agent { - a.req.Header.SetUserAgent(userAgent) - - return a -} - -// UserAgentBytes sets User-Agent header value. -func (a *Agent) UserAgentBytes(userAgent []byte) *Agent { - a.req.Header.SetUserAgentBytes(userAgent) - - return a -} - -// Cookie sets one 'key: value' cookie. -func (a *Agent) Cookie(key, value string) *Agent { - a.req.Header.SetCookie(key, value) - - return a -} - -// CookieBytesK sets one 'key: value' cookie. -func (a *Agent) CookieBytesK(key []byte, value string) *Agent { - a.req.Header.SetCookieBytesK(key, value) - - return a -} - -// CookieBytesKV sets one 'key: value' cookie. -func (a *Agent) CookieBytesKV(key, value []byte) *Agent { - a.req.Header.SetCookieBytesKV(key, value) - - return a -} - -// Cookies sets multiple 'key: value' cookies. -func (a *Agent) Cookies(kv ...string) *Agent { - for i := 1; i < len(kv); i += 2 { - a.req.Header.SetCookie(kv[i-1], kv[i]) - } - - return a -} - -// CookiesBytesKV sets multiple 'key: value' cookies. -func (a *Agent) CookiesBytesKV(kv ...[]byte) *Agent { - for i := 1; i < len(kv); i += 2 { - a.req.Header.SetCookieBytesKV(kv[i-1], kv[i]) - } - - return a -} - -// Referer sets Referer header value. -func (a *Agent) Referer(referer string) *Agent { - a.req.Header.SetReferer(referer) - - return a -} - -// RefererBytes sets Referer header value. -func (a *Agent) RefererBytes(referer []byte) *Agent { - a.req.Header.SetRefererBytes(referer) - - return a -} - -// ContentType sets Content-Type header value. -func (a *Agent) ContentType(contentType string) *Agent { - a.req.Header.SetContentType(contentType) - - return a -} - -// ContentTypeBytes sets Content-Type header value. -func (a *Agent) ContentTypeBytes(contentType []byte) *Agent { - a.req.Header.SetContentTypeBytes(contentType) - - return a -} - -/************************** End Header Setting **************************/ - -/************************** URI Setting **************************/ - -// Host sets host for the URI. -func (a *Agent) Host(host string) *Agent { - a.req.URI().SetHost(host) - - return a -} - -// HostBytes sets host for the URI. -func (a *Agent) HostBytes(host []byte) *Agent { - a.req.URI().SetHostBytes(host) - - return a -} - -// QueryString sets URI query string. -func (a *Agent) QueryString(queryString string) *Agent { - a.req.URI().SetQueryString(queryString) - - return a -} - -// QueryStringBytes sets URI query string. -func (a *Agent) QueryStringBytes(queryString []byte) *Agent { - a.req.URI().SetQueryStringBytes(queryString) - - return a -} - -// BasicAuth sets URI username and password. -func (a *Agent) BasicAuth(username, password string) *Agent { - a.req.URI().SetUsername(username) - a.req.URI().SetPassword(password) - - return a -} - -// BasicAuthBytes sets URI username and password. -func (a *Agent) BasicAuthBytes(username, password []byte) *Agent { - a.req.URI().SetUsernameBytes(username) - a.req.URI().SetPasswordBytes(password) - - return a -} - -/************************** End URI Setting **************************/ - -/************************** Request Setting **************************/ - -// BodyString sets request body. -func (a *Agent) BodyString(bodyString string) *Agent { - a.req.SetBodyString(bodyString) - - return a -} - -// Body sets request body. -func (a *Agent) Body(body []byte) *Agent { - a.req.SetBody(body) - - return a -} - -// BodyStream sets request body stream and, optionally body size. -// -// If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes -// before returning io.EOF. -// -// If bodySize < 0, then bodyStream is read until io.EOF. -// -// bodyStream.Close() is called after finishing reading all body data -// if it implements io.Closer. -// -// Note that GET and HEAD requests cannot have body. -func (a *Agent) BodyStream(bodyStream io.Reader, bodySize int) *Agent { - a.req.SetBodyStream(bodyStream, bodySize) - - return a -} - -// JSON sends a JSON request. -func (a *Agent) JSON(v any, ctype ...string) *Agent { - if a.jsonEncoder == nil { - a.jsonEncoder = json.Marshal - } - - if len(ctype) > 0 { - a.req.Header.SetContentType(ctype[0]) - } else { - a.req.Header.SetContentType(MIMEApplicationJSON) - } - - if body, err := a.jsonEncoder(v); err != nil { - a.errs = append(a.errs, err) - } else { - a.req.SetBody(body) - } - - return a -} - -// XML sends an XML request. -func (a *Agent) XML(v any) *Agent { - a.req.Header.SetContentType(MIMEApplicationXML) - - if body, err := xml.Marshal(v); err != nil { - a.errs = append(a.errs, err) - } else { - a.req.SetBody(body) - } - - return a -} - -// Form sends form request with body if args is non-nil. -// -// It is recommended obtaining args via AcquireArgs and release it -// manually in performance-critical code. -func (a *Agent) Form(args *Args) *Agent { - a.req.Header.SetContentType(MIMEApplicationForm) - - if args != nil { - a.req.SetBody(args.QueryString()) - } - - return a -} - -// FormFile represents multipart form file -type FormFile struct { - // Fieldname is form file's field name - Fieldname string - // Name is form file's name - Name string - // Content is form file's content - Content []byte - // autoRelease indicates if returns the object - // acquired via AcquireFormFile to the pool. - autoRelease bool -} - -// FileData appends files for multipart form request. -// -// It is recommended obtaining formFile via AcquireFormFile and release it -// manually in performance-critical code. -func (a *Agent) FileData(formFiles ...*FormFile) *Agent { - a.formFiles = append(a.formFiles, formFiles...) - - return a -} - -// SendFile reads file and appends it to multipart form request. -func (a *Agent) SendFile(filename string, fieldname ...string) *Agent { - content, err := os.ReadFile(filepath.Clean(filename)) - if err != nil { - a.errs = append(a.errs, err) - return a - } - - ff := AcquireFormFile() - if len(fieldname) > 0 && fieldname[0] != "" { - ff.Fieldname = fieldname[0] - } else { - ff.Fieldname = "file" + strconv.Itoa(len(a.formFiles)+1) - } - ff.Name = filepath.Base(filename) - ff.Content = append(ff.Content, content...) - ff.autoRelease = true - - a.formFiles = append(a.formFiles, ff) - - return a -} - -// SendFiles reads files and appends them to multipart form request. -// -// Examples: -// -// SendFile("/path/to/file1", "fieldname1", "/path/to/file2") -func (a *Agent) SendFiles(filenamesAndFieldnames ...string) *Agent { - pairs := len(filenamesAndFieldnames) - if pairs&1 == 1 { - filenamesAndFieldnames = append(filenamesAndFieldnames, "") - } - - for i := 0; i < pairs; i += 2 { - a.SendFile(filenamesAndFieldnames[i], filenamesAndFieldnames[i+1]) - } - - return a -} - -// Boundary sets boundary for multipart form request. -func (a *Agent) Boundary(boundary string) *Agent { - a.boundary = boundary - - return a -} - -// MultipartForm sends multipart form request with k-v and files. -// -// It is recommended obtaining args via AcquireArgs and release it -// manually in performance-critical code. -func (a *Agent) MultipartForm(args *Args) *Agent { - if a.mw == nil { - a.mw = multipart.NewWriter(a.req.BodyWriter()) - } - - if a.boundary != "" { - if err := a.mw.SetBoundary(a.boundary); err != nil { - a.errs = append(a.errs, err) - return a - } - } - - a.req.Header.SetMultipartFormBoundary(a.mw.Boundary()) - - if args != nil { - args.VisitAll(func(key, value []byte) { - if err := a.mw.WriteField(utils.UnsafeString(key), utils.UnsafeString(value)); err != nil { - a.errs = append(a.errs, err) - } - }) - } - - for _, ff := range a.formFiles { - w, err := a.mw.CreateFormFile(ff.Fieldname, ff.Name) - if err != nil { - a.errs = append(a.errs, err) - continue - } - if _, err = w.Write(ff.Content); err != nil { - a.errs = append(a.errs, err) - } - } - - if err := a.mw.Close(); err != nil { - a.errs = append(a.errs, err) - } - - return a -} - -/************************** End Request Setting **************************/ - -/************************** Agent Setting **************************/ - -// Debug mode enables logging request and response detail -func (a *Agent) Debug(w ...io.Writer) *Agent { - a.debugWriter = os.Stdout - if len(w) > 0 { - a.debugWriter = w[0] - } - - return a -} - -// Timeout sets request timeout duration. -func (a *Agent) Timeout(timeout time.Duration) *Agent { - a.timeout = timeout - - return a -} - -// Reuse enables the Agent instance to be used again after one request. -// -// If agent is reusable, then it should be released manually when it is no -// longer used. -func (a *Agent) Reuse() *Agent { - a.reuse = true - - return a -} - -// InsecureSkipVerify controls whether the Agent verifies the server -// certificate chain and host name. -func (a *Agent) InsecureSkipVerify() *Agent { - if a.HostClient.TLSConfig == nil { - a.HostClient.TLSConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We explicitly let the user set insecure mode here - } else { - a.HostClient.TLSConfig.InsecureSkipVerify = true - } - - return a -} - -// TLSConfig sets tls config. -func (a *Agent) TLSConfig(config *tls.Config) *Agent { - a.HostClient.TLSConfig = config - - return a -} - -// MaxRedirectsCount sets max redirect count for GET and HEAD. -func (a *Agent) MaxRedirectsCount(count int) *Agent { - a.maxRedirectsCount = count - - return a -} - -// JSONEncoder sets custom json encoder. -func (a *Agent) JSONEncoder(jsonEncoder utils.JSONMarshal) *Agent { - a.jsonEncoder = jsonEncoder - - return a -} - -// JSONDecoder sets custom json decoder. -func (a *Agent) JSONDecoder(jsonDecoder utils.JSONUnmarshal) *Agent { - a.jsonDecoder = jsonDecoder - - return a -} - -// Request returns Agent request instance. -func (a *Agent) Request() *Request { - return a.req -} - -// SetResponse sets custom response for the Agent instance. -// -// It is recommended obtaining custom response via AcquireResponse and release it -// manually in performance-critical code. -func (a *Agent) SetResponse(customResp *Response) *Agent { - a.resp = customResp - - return a -} - -// Dest sets custom dest. -// -// The contents of dest will be replaced by the response body, if the dest -// is too small a new slice will be allocated. -func (a *Agent) Dest(dest []byte) *Agent { - a.dest = dest - - return a -} - -// RetryIf controls whether a retry should be attempted after an error. -// -// By default, will use isIdempotent function from fasthttp -func (a *Agent) RetryIf(retryIf RetryIfFunc) *Agent { - a.HostClient.RetryIf = retryIf - return a -} - -/************************** End Agent Setting **************************/ - -// Bytes returns the status code, bytes body and errors of url. -// -// it's not safe to use Agent after calling [Agent.Bytes] -func (a *Agent) Bytes() (int, []byte, []error) { - defer a.release() - return a.bytes() -} - -func (a *Agent) bytes() (code int, body []byte, errs []error) { //nolint:nonamedreturns,revive // We want to overwrite the body in a deferred func. TODO: Check if we really need to do this. We eventually want to get rid of all named returns. - if errs = append(errs, a.errs...); len(errs) > 0 { - return code, body, errs - } - - var ( - req = a.req - resp *Response - nilResp bool - ) - - if a.resp == nil { - resp = AcquireResponse() - nilResp = true - } else { - resp = a.resp - } - - defer func() { - if a.debugWriter != nil { - printDebugInfo(req, resp, a.debugWriter) - } - - if len(errs) == 0 { - code = resp.StatusCode() - } - - body = append(a.dest, resp.Body()...) //nolint:gocritic // We want to append to the returned slice here - - if nilResp { - ReleaseResponse(resp) - } - }() - - if a.timeout > 0 { - if err := a.HostClient.DoTimeout(req, resp, a.timeout); err != nil { - errs = append(errs, err) - return code, body, errs - } - } else if a.maxRedirectsCount > 0 && (string(req.Header.Method()) == MethodGet || string(req.Header.Method()) == MethodHead) { - if err := a.HostClient.DoRedirects(req, resp, a.maxRedirectsCount); err != nil { - errs = append(errs, err) - return code, body, errs - } - } else if err := a.HostClient.Do(req, resp); err != nil { - errs = append(errs, err) - } - - return code, body, errs -} - -func printDebugInfo(req *Request, resp *Response, w io.Writer) { - msg := fmt.Sprintf("Connected to %s(%s)\r\n\r\n", req.URI().Host(), resp.RemoteAddr()) - _, _ = w.Write(utils.UnsafeBytes(msg)) //nolint:errcheck // This will never fail - _, _ = req.WriteTo(w) //nolint:errcheck // This will never fail - _, _ = resp.WriteTo(w) //nolint:errcheck // This will never fail -} - -// String returns the status code, string body and errors of url. -// -// it's not safe to use Agent after calling [Agent.String] -func (a *Agent) String() (int, string, []error) { - defer a.release() - code, body, errs := a.bytes() - // TODO: There might be a data race here on body. Maybe use utils.CopyBytes on it? - - return code, utils.UnsafeString(body), errs -} - -// Struct returns the status code, bytes body and errors of URL. -// And bytes body will be unmarshalled to given v. -// -// it's not safe to use Agent after calling [Agent.Struct] -func (a *Agent) Struct(v any) (int, []byte, []error) { - defer a.release() - - code, body, errs := a.bytes() - if len(errs) > 0 { - return code, body, errs - } - - // TODO: This should only be done once - if a.jsonDecoder == nil { - a.jsonDecoder = json.Unmarshal - } - - if err := a.jsonDecoder(body, v); err != nil { - errs = append(errs, err) - } - - return code, body, errs -} - -func (a *Agent) release() { - if !a.reuse { - ReleaseAgent(a) - } else { - a.errs = a.errs[:0] - } -} - -func (a *Agent) reset() { - a.HostClient = nil - a.req.Reset() - a.resp = nil - a.dest = nil - a.timeout = 0 - a.args = nil - a.errs = a.errs[:0] - a.debugWriter = nil - a.mw = nil - a.reuse = false - a.parsed = false - a.maxRedirectsCount = 0 - a.boundary = "" - a.Name = "" - a.NoDefaultUserAgentHeader = false - for i, ff := range a.formFiles { - if ff.autoRelease { - ReleaseFormFile(ff) - } - a.formFiles[i] = nil - } - a.formFiles = a.formFiles[:0] -} - -var ( - clientPool sync.Pool - agentPool = sync.Pool{ - New: func() any { - return &Agent{req: &Request{}} - }, - } - responsePool sync.Pool - argsPool sync.Pool - formFilePool sync.Pool -) - -// AcquireClient returns an empty Client instance from client pool. -// -// The returned Client instance may be passed to ReleaseClient when it is -// no longer needed. This allows Client recycling, reduces GC pressure -// and usually improves performance. -func AcquireClient() *Client { - v := clientPool.Get() - if v == nil { - return &Client{} - } - c, ok := v.(*Client) - if !ok { - panic(errors.New("failed to type-assert to *Client")) - } - return c -} - -// ReleaseClient returns c acquired via AcquireClient to client pool. -// -// It is forbidden accessing req and/or it's members after returning -// it to client pool. -func ReleaseClient(c *Client) { - c.UserAgent = "" - c.NoDefaultUserAgentHeader = false - c.JSONEncoder = nil - c.JSONDecoder = nil - - clientPool.Put(c) -} - -// AcquireAgent returns an empty Agent instance from Agent pool. -// -// The returned Agent instance may be passed to ReleaseAgent when it is -// no longer needed. This allows Agent recycling, reduces GC pressure -// and usually improves performance. -func AcquireAgent() *Agent { - a, ok := agentPool.Get().(*Agent) - if !ok { - panic(errors.New("failed to type-assert to *Agent")) - } - return a -} - -// ReleaseAgent returns an acquired via AcquireAgent to Agent pool. -// -// It is forbidden accessing req and/or it's members after returning -// it to Agent pool. -func ReleaseAgent(a *Agent) { - a.reset() - agentPool.Put(a) -} - -// AcquireResponse returns an empty Response instance from response pool. -// -// The returned Response instance may be passed to ReleaseResponse when it is -// no longer needed. This allows Response recycling, reduces GC pressure -// and usually improves performance. -// Copy from fasthttp -func AcquireResponse() *Response { - v := responsePool.Get() - if v == nil { - return &Response{} - } - r, ok := v.(*Response) - if !ok { - panic(errors.New("failed to type-assert to *Response")) - } - return r -} - -// ReleaseResponse return resp acquired via AcquireResponse to response pool. -// -// It is forbidden accessing resp and/or it's members after returning -// it to response pool. -// Copy from fasthttp -func ReleaseResponse(resp *Response) { - resp.Reset() - responsePool.Put(resp) -} - -// AcquireArgs returns an empty Args object from the pool. -// -// The returned Args may be returned to the pool with ReleaseArgs -// when no longer needed. This allows reducing GC load. -// Copy from fasthttp -func AcquireArgs() *Args { - v := argsPool.Get() - if v == nil { - return &Args{} - } - a, ok := v.(*Args) - if !ok { - panic(errors.New("failed to type-assert to *Args")) - } - return a -} - -// ReleaseArgs returns the object acquired via AcquireArgs to the pool. -// -// String not access the released Args object, otherwise data races may occur. -// Copy from fasthttp -func ReleaseArgs(a *Args) { - a.Reset() - argsPool.Put(a) -} - -// AcquireFormFile returns an empty FormFile object from the pool. -// -// The returned FormFile may be returned to the pool with ReleaseFormFile -// when no longer needed. This allows reducing GC load. -func AcquireFormFile() *FormFile { - v := formFilePool.Get() - if v == nil { - return &FormFile{} - } - ff, ok := v.(*FormFile) - if !ok { - panic(errors.New("failed to type-assert to *FormFile")) - } - return ff -} - -// ReleaseFormFile returns the object acquired via AcquireFormFile to the pool. -// -// String not access the released FormFile object, otherwise data races may occur. -func ReleaseFormFile(ff *FormFile) { - ff.Fieldname = "" - ff.Name = "" - ff.Content = ff.Content[:0] - ff.autoRelease = false - - formFilePool.Put(ff) -} - -const ( - defaultUserAgent = "fiber" -) - -type multipartWriter interface { - Boundary() string - SetBoundary(boundary string) error - CreateFormFile(fieldname, filename string) (io.Writer, error) - WriteField(fieldname, value string) error - Close() error -} diff --git a/client/README.md b/client/README.md new file mode 100644 index 0000000000..a992fcc6fb --- /dev/null +++ b/client/README.md @@ -0,0 +1,35 @@ +

Fiber Client

+

Easy-to-use HTTP client based on fasthttp (inspired by resty and axios)

+

Features section describes in detail about Resty capabilities

+ +## Features + +> The characteristics have not yet been written. + +- GET, POST, PUT, DELETE, HEAD, PATCH, OPTIONS, etc. +- Simple and chainable methods for settings and request +- Request Body can be `string`, `[]byte`, `map`, `slice` + - Auto detects `Content-Type` + - Buffer processing for `files` + - Native `*fasthttp.Request` instance can be accessed during middleware and request execution via `Request.RawRequest` + - Request Body can be read multiple time via `Request.RawRequest.GetBody()` +- Response object gives you more possibility + - Access as `[]byte` by `response.Body()` or access as `string` by `response.String()` +- Automatic marshal and unmarshal for JSON and XML content type + - Default is JSON, if you supply struct/map without header Content-Type + - For auto-unmarshal, refer to - + - Success scenario Request.SetResult() and Response.Result(). + - Error scenario Request.SetError() and Response.Error(). + - Supports RFC7807 - application/problem+json & application/problem+xml + - Provide an option to override JSON Marshal/Unmarshal and XML Marshal/Unmarshal + +## Usage + +The following samples will assist you to become as comfortable as possible with `Fiber Client` library. + +```go +// Import Fiber Client into your code and refer it as `client`. +import "github.com/gofiber/fiber/client" +``` + +### Simple GET diff --git a/client/client.go b/client/client.go new file mode 100644 index 0000000000..22cbd19727 --- /dev/null +++ b/client/client.go @@ -0,0 +1,775 @@ +package client + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "encoding/xml" + "errors" + "fmt" + "io" + urlpkg "net/url" + "os" + "path/filepath" + "sync" + "time" + + "github.com/gofiber/fiber/v3/log" + + "github.com/gofiber/utils/v2" + + "github.com/valyala/fasthttp" +) + +var ( + ErrInvalidProxyURL = errors.New("invalid proxy url scheme") + ErrFailedToAppendCert = errors.New("failed to append certificate") +) + +// The Client is used to create a Fiber Client with +// client-level settings that apply to all requests +// raise from the client. +// +// Fiber Client also provides an option to override +// or merge most of the client settings at the request. +type Client struct { + mu sync.RWMutex + + fasthttp *fasthttp.Client + + baseURL string + userAgent string + referer string + header *Header + params *QueryParam + cookies *Cookie + path *PathParam + + debug bool + + timeout time.Duration + + // user defined request hooks + userRequestHooks []RequestHook + + // client package defined request hooks + builtinRequestHooks []RequestHook + + // user defined response hooks + userResponseHooks []ResponseHook + + // client package defined response hooks + builtinResponseHooks []ResponseHook + + jsonMarshal utils.JSONMarshal + jsonUnmarshal utils.JSONUnmarshal + xmlMarshal utils.XMLMarshal + xmlUnmarshal utils.XMLUnmarshal + + cookieJar *CookieJar + + // proxy + proxyURL string + + // retry + retryConfig *RetryConfig + + // logger + logger log.CommonLogger +} + +// R raise a request from the client. +func (c *Client) R() *Request { + return AcquireRequest().SetClient(c) +} + +// RequestHook Request returns user-defined request hooks. +func (c *Client) RequestHook() []RequestHook { + return c.userRequestHooks +} + +// AddRequestHook Add user-defined request hooks. +func (c *Client) AddRequestHook(h ...RequestHook) *Client { + c.mu.Lock() + defer c.mu.Unlock() + + c.userRequestHooks = append(c.userRequestHooks, h...) + return c +} + +// ResponseHook return user-define response hooks. +func (c *Client) ResponseHook() []ResponseHook { + return c.userResponseHooks +} + +// AddResponseHook Add user-defined response hooks. +func (c *Client) AddResponseHook(h ...ResponseHook) *Client { + c.mu.Lock() + defer c.mu.Unlock() + + c.userResponseHooks = append(c.userResponseHooks, h...) + return c +} + +// JSONMarshal returns json marshal function in Core. +func (c *Client) JSONMarshal() utils.JSONMarshal { + return c.jsonMarshal +} + +// SetJSONMarshal Set json encoder. +func (c *Client) SetJSONMarshal(f utils.JSONMarshal) *Client { + c.jsonMarshal = f + return c +} + +// JSONUnmarshal returns json unmarshal function in Core. +func (c *Client) JSONUnmarshal() utils.JSONUnmarshal { + return c.jsonUnmarshal +} + +// Set json decoder. +func (c *Client) SetJSONUnmarshal(f utils.JSONUnmarshal) *Client { + c.jsonUnmarshal = f + return c +} + +// XMLMarshal returns xml marshal function in Core. +func (c *Client) XMLMarshal() utils.XMLMarshal { + return c.xmlMarshal +} + +// SetXMLMarshal Set xml encoder. +func (c *Client) SetXMLMarshal(f utils.XMLMarshal) *Client { + c.xmlMarshal = f + return c +} + +// XMLUnmarshal returns xml unmarshal function in Core. +func (c *Client) XMLUnmarshal() utils.XMLUnmarshal { + return c.xmlUnmarshal +} + +// SetXMLUnmarshal Set xml decoder. +func (c *Client) SetXMLUnmarshal(f utils.XMLUnmarshal) *Client { + c.xmlUnmarshal = f + return c +} + +// TLSConfig returns tlsConfig in client. +// If client don't have tlsConfig, this function will init it. +func (c *Client) TLSConfig() *tls.Config { + if c.fasthttp.TLSConfig == nil { + c.fasthttp.TLSConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + } + } + + return c.fasthttp.TLSConfig +} + +// SetTLSConfig sets tlsConfig in client. +func (c *Client) SetTLSConfig(config *tls.Config) *Client { + c.fasthttp.TLSConfig = config + return c +} + +// SetCertificates method sets client certificates into client. +func (c *Client) SetCertificates(certs ...tls.Certificate) *Client { + config := c.TLSConfig() + config.Certificates = append(config.Certificates, certs...) + return c +} + +// SetRootCertificate adds one or more root certificates into client. +func (c *Client) SetRootCertificate(path string) *Client { + cleanPath := filepath.Clean(path) + file, err := os.Open(cleanPath) + if err != nil { + c.logger.Panicf("client: %v", err) + } + defer func() { + if err := file.Close(); err != nil { + c.logger.Panicf("client: failed to close file: %v", err) + } + }() + + pem, err := io.ReadAll(file) + if err != nil { + c.logger.Panicf("client: %v", err) + } + + config := c.TLSConfig() + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() + } + + if !config.RootCAs.AppendCertsFromPEM(pem) { + c.logger.Panicf("client: %v", ErrFailedToAppendCert) + } + + return c +} + +// SetRootCertificateFromString method adds one or more root certificates into client. +func (c *Client) SetRootCertificateFromString(pem string) *Client { + config := c.TLSConfig() + + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() + } + + if !config.RootCAs.AppendCertsFromPEM([]byte(pem)) { + c.logger.Panicf("client: %v", ErrFailedToAppendCert) + } + + return c +} + +// SetProxyURL sets proxy url in client. It will apply via core to hostclient. +func (c *Client) SetProxyURL(proxyURL string) error { + pURL, err := urlpkg.Parse(proxyURL) + if err != nil { + return fmt.Errorf("client: %w", err) + } + + if pURL.Scheme != "http" && pURL.Scheme != "https" { + return fmt.Errorf("client: %w", ErrInvalidProxyURL) + } + + c.proxyURL = pURL.String() + + return nil +} + +// RetryConfig returns retry config in client. +func (c *Client) RetryConfig() *RetryConfig { + return c.retryConfig +} + +// SetRetryConfig sets retry config in client which is impl by addon/retry package. +func (c *Client) SetRetryConfig(config *RetryConfig) *Client { + c.mu.Lock() + defer c.mu.Unlock() + + c.retryConfig = config + return c +} + +// BaseURL returns baseurl in Client instance. +func (c *Client) BaseURL() string { + return c.baseURL +} + +// SetBaseURL Set baseUrl which is prefix of real url. +func (c *Client) SetBaseURL(url string) *Client { + c.baseURL = url + return c +} + +// Header method returns header value via key, +// this method will visit all field in the header, +// then sort them. +func (c *Client) Header(key string) []string { + return c.header.PeekMultiple(key) +} + +// AddHeader method adds a single header field and its value in the client instance. +// These headers will be applied to all requests raised from this client instance. +// Also, it can be overridden at request level header options. +func (c *Client) AddHeader(key, val string) *Client { + c.header.Add(key, val) + return c +} + +// SetHeader method sets a single header field and its value in the client instance. +// These headers will be applied to all requests raised from this client instance. +// Also, it can be overridden at request level header options. +func (c *Client) SetHeader(key, val string) *Client { + c.header.Set(key, val) + return c +} + +// AddHeaders method adds multiple headers field and its values at one go in the client instance. +// These headers will be applied to all requests raised from this client instance. Also it can be +// overridden at request level headers options. +func (c *Client) AddHeaders(h map[string][]string) *Client { + c.header.AddHeaders(h) + return c +} + +// SetHeaders method sets multiple headers field and its values at one go in the client instance. +// These headers will be applied to all requests raised from this client instance. Also it can be +// overridden at request level headers options. +func (c *Client) SetHeaders(h map[string]string) *Client { + c.header.SetHeaders(h) + return c +} + +// Param method returns params value via key, +// this method will visit all field in the query param. +func (c *Client) Param(key string) []string { + res := []string{} + tmp := c.params.PeekMulti(key) + for _, v := range tmp { + res = append(res, utils.UnsafeString(v)) + } + + return res +} + +// AddParam method adds a single query param field and its value in the client instance. +// These params will be applied to all requests raised from this client instance. +// Also, it can be overridden at request level param options. +func (c *Client) AddParam(key, val string) *Client { + c.params.Add(key, val) + return c +} + +// SetParam method sets a single query param field and its value in the client instance. +// These params will be applied to all requests raised from this client instance. +// Also, it can be overridden at request level param options. +func (c *Client) SetParam(key, val string) *Client { + c.params.Set(key, val) + return c +} + +// AddParams method adds multiple query params field and its values at one go in the client instance. +// These params will be applied to all requests raised from this client instance. Also it can be +// overridden at request level params options. +func (c *Client) AddParams(m map[string][]string) *Client { + c.params.AddParams(m) + return c +} + +// SetParams method sets multiple params field and its values at one go in the client instance. +// These params will be applied to all requests raised from this client instance. Also it can be +// overridden at request level params options. +func (c *Client) SetParams(m map[string]string) *Client { + c.params.SetParams(m) + return c +} + +// SetParamsWithStruct method sets multiple params field and its values at one go in the client instance. +// These params will be applied to all requests raised from this client instance. Also it can be +// overridden at request level params options. +func (c *Client) SetParamsWithStruct(v any) *Client { + c.params.SetParamsWithStruct(v) + return c +} + +// DelParams method deletes single or multiple params field and its values in client. +func (c *Client) DelParams(key ...string) *Client { + for _, v := range key { + c.params.Del(v) + } + return c +} + +// SetUserAgent method sets userAgent field and its value in the client instance. +// This ua will be applied to all requests raised from this client instance. +// Also it can be overridden at request level ua options. +func (c *Client) SetUserAgent(ua string) *Client { + c.userAgent = ua + return c +} + +// SetReferer method sets referer field and its value in the client instance. +// This referer will be applied to all requests raised from this client instance. +// Also it can be overridden at request level referer options. +func (c *Client) SetReferer(r string) *Client { + c.referer = r + return c +} + +// PathParam returns the path param be set in request instance. +// if path param doesn't exist, return empty string. +func (c *Client) PathParam(key string) string { + if val, ok := (*c.path)[key]; ok { + return val + } + + return "" +} + +// SetPathParam method sets a single path param field and its value in the client instance. +// These path params will be applied to all requests raised from this client instance. +// Also it can be overridden at request level path params options. +func (c *Client) SetPathParam(key, val string) *Client { + c.path.SetParam(key, val) + return c +} + +// SetPathParams method sets multiple path params field and its values at one go in the client instance. +// These path params will be applied to all requests raised from this client instance. Also it can be +// overridden at request level path params options. +func (c *Client) SetPathParams(m map[string]string) *Client { + c.path.SetParams(m) + return c +} + +// SetPathParamsWithStruct method sets multiple path params field and its values at one go in the client instance. +// These path params will be applied to all requests raised from this client instance. Also it can be +// overridden at request level path params options. +func (c *Client) SetPathParamsWithStruct(v any) *Client { + c.path.SetParamsWithStruct(v) + return c +} + +// DelPathParams method deletes single or multiple path params field and its values in client. +func (c *Client) DelPathParams(key ...string) *Client { + c.path.DelParams(key...) + return c +} + +// Cookie returns the cookie be set in request instance. +// if cookie doesn't exist, return empty string. +func (c *Client) Cookie(key string) string { + if val, ok := (*c.cookies)[key]; ok { + return val + } + return "" +} + +// SetCookie method sets a single cookie field and its value in the client instance. +// These cookies will be applied to all requests raised from this client instance. +// Also it can be overridden at request level cookie options. +func (c *Client) SetCookie(key, val string) *Client { + c.cookies.SetCookie(key, val) + return c +} + +// SetCookies method sets multiple cookies field and its values at one go in the client instance. +// These cookies will be applied to all requests raised from this client instance. Also it can be +// overridden at request level cookie options. +func (c *Client) SetCookies(m map[string]string) *Client { + c.cookies.SetCookies(m) + return c +} + +// SetCookiesWithStruct method sets multiple cookies field and its values at one go in the client instance. +// These cookies will be applied to all requests raised from this client instance. Also it can be +// overridden at request level cookies options. +func (c *Client) SetCookiesWithStruct(v any) *Client { + c.cookies.SetCookiesWithStruct(v) + return c +} + +// DelCookies method deletes single or multiple cookies field and its values in client. +func (c *Client) DelCookies(key ...string) *Client { + c.cookies.DelCookies(key...) + return c +} + +// SetTimeout method sets timeout val in client instance. +// This value will be applied to all requests raised from this client instance. +// Also, it can be overridden at request level timeout options. +func (c *Client) SetTimeout(t time.Duration) *Client { + c.timeout = t + return c +} + +// Debug enable log debug level output. +func (c *Client) Debug() *Client { + c.debug = true + return c +} + +// DisableDebug disenable log debug level output. +func (c *Client) DisableDebug() *Client { + c.debug = false + return c +} + +// SetCookieJar sets cookie jar in client instance. +func (c *Client) SetCookieJar(cookieJar *CookieJar) *Client { + c.cookieJar = cookieJar + return c +} + +// Get provide an API like axios which send get request. +func (c *Client) Get(url string, cfg ...Config) (*Response, error) { + req := AcquireRequest().SetClient(c) + setConfigToRequest(req, cfg...) + + return req.Get(url) +} + +// Post provide an API like axios which send post request. +func (c *Client) Post(url string, cfg ...Config) (*Response, error) { + req := AcquireRequest().SetClient(c) + setConfigToRequest(req, cfg...) + + return req.Post(url) +} + +// Head provide a API like axios which send head request. +func (c *Client) Head(url string, cfg ...Config) (*Response, error) { + req := AcquireRequest().SetClient(c) + setConfigToRequest(req, cfg...) + + return req.Head(url) +} + +// Put provide an API like axios which send put request. +func (c *Client) Put(url string, cfg ...Config) (*Response, error) { + req := AcquireRequest().SetClient(c) + setConfigToRequest(req, cfg...) + + return req.Put(url) +} + +// Delete provide an API like axios which send delete request. +func (c *Client) Delete(url string, cfg ...Config) (*Response, error) { + req := AcquireRequest().SetClient(c) + setConfigToRequest(req, cfg...) + + return req.Delete(url) +} + +// Options provide an API like axios which send options request. +func (c *Client) Options(url string, cfg ...Config) (*Response, error) { + req := AcquireRequest().SetClient(c) + setConfigToRequest(req, cfg...) + + return req.Options(url) +} + +// Patch provide an API like axios which send patch request. +func (c *Client) Patch(url string, cfg ...Config) (*Response, error) { + req := AcquireRequest().SetClient(c) + setConfigToRequest(req, cfg...) + + return req.Patch(url) +} + +// Custom provide an API like axios which send custom request. +func (c *Client) Custom(url, method string, cfg ...Config) (*Response, error) { + req := AcquireRequest().SetClient(c) + setConfigToRequest(req, cfg...) + + return req.Custom(url, method) +} + +// SetDial sets dial function in client. +func (c *Client) SetDial(dial fasthttp.DialFunc) *Client { + c.mu.Lock() + defer c.mu.Unlock() + + c.fasthttp.Dial = dial + return c +} + +// SetLogger sets logger instance in client. +func (c *Client) SetLogger(logger log.CommonLogger) *Client { + c.mu.Lock() + defer c.mu.Unlock() + + c.logger = logger + return c +} + +// Logger returns logger instance of client. +func (c *Client) Logger() log.CommonLogger { + return c.logger +} + +// Reset clear Client object +func (c *Client) Reset() { + c.fasthttp = &fasthttp.Client{} + c.baseURL = "" + c.timeout = 0 + c.userAgent = "" + c.referer = "" + c.proxyURL = "" + c.retryConfig = nil + c.debug = false + + if c.cookieJar != nil { + c.cookieJar.Release() + c.cookieJar = nil + } + + c.path.Reset() + c.cookies.Reset() + c.header.Reset() + c.params.Reset() +} + +// Config for easy to set the request parameters, it should be +// noted that when setting the request body will use JSON as +// the default serialization mechanism, while the priority of +// Body is higher than FormData, and the priority of FormData +// is higher than File. +type Config struct { + Ctx context.Context //nolint:containedctx // It's needed to be stored in the config. + + UserAgent string + Referer string + Header map[string]string + Param map[string]string + Cookie map[string]string + PathParam map[string]string + + Timeout time.Duration + MaxRedirects int + + Body any + FormData map[string]string + File []*File +} + +// setConfigToRequest Set the parameters passed via Config to Request. +func setConfigToRequest(req *Request, config ...Config) { + if len(config) == 0 { + return + } + cfg := config[0] + + if cfg.Ctx != nil { + req.SetContext(cfg.Ctx) + } + + if cfg.UserAgent != "" { + req.SetUserAgent(cfg.UserAgent) + } + + if cfg.Referer != "" { + req.SetReferer(cfg.Referer) + } + + if cfg.Header != nil { + req.SetHeaders(cfg.Header) + } + + if cfg.Param != nil { + req.SetParams(cfg.Param) + } + + if cfg.Cookie != nil { + req.SetCookies(cfg.Cookie) + } + + if cfg.PathParam != nil { + req.SetPathParams(cfg.PathParam) + } + + if cfg.Timeout != 0 { + req.SetTimeout(cfg.Timeout) + } + + if cfg.MaxRedirects != 0 { + req.SetMaxRedirects(cfg.MaxRedirects) + } + + if cfg.Body != nil { + req.SetJSON(cfg.Body) + return + } + + if cfg.FormData != nil { + req.SetFormDatas(cfg.FormData) + return + } + + if cfg.File != nil && len(cfg.File) != 0 { + req.AddFiles(cfg.File...) + return + } +} + +var ( + defaultClient *Client + replaceMu = sync.Mutex{} + defaultUserAgent = "fiber" +) + +// init acquire a default client. +func init() { + defaultClient = NewClient() +} + +// NewClient creates and returns a new Client object. +func NewClient() *Client { + // FOllOW-UP performance optimization + // trie to use a pool to reduce the cost of memory allocation + // for the fiber client and the fasthttp client + // if possible also for other structs -> request header, cookie, query param, path param... + return &Client{ + fasthttp: &fasthttp.Client{}, + header: &Header{ + RequestHeader: &fasthttp.RequestHeader{}, + }, + params: &QueryParam{ + Args: fasthttp.AcquireArgs(), + }, + cookies: &Cookie{}, + path: &PathParam{}, + + userRequestHooks: []RequestHook{}, + builtinRequestHooks: []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody}, + userResponseHooks: []ResponseHook{}, + builtinResponseHooks: []ResponseHook{parserResponseCookie, logger}, + jsonMarshal: json.Marshal, + jsonUnmarshal: json.Unmarshal, + xmlMarshal: xml.Marshal, + xmlUnmarshal: xml.Unmarshal, + logger: log.DefaultLogger(), + } +} + +// C get default client. +func C() *Client { + return defaultClient +} + +// Replace the defaultClient, the returned function can undo. +func Replace(c *Client) func() { + replaceMu.Lock() + defer replaceMu.Unlock() + + oldClient := defaultClient + defaultClient = c + + return func() { + replaceMu.Lock() + defer replaceMu.Unlock() + + defaultClient = oldClient + } +} + +// Get send a get request use defaultClient, a convenient method. +func Get(url string, cfg ...Config) (*Response, error) { + return C().Get(url, cfg...) +} + +// Post send a post request use defaultClient, a convenient method. +func Post(url string, cfg ...Config) (*Response, error) { + return C().Post(url, cfg...) +} + +// Head send a head request use defaultClient, a convenient method. +func Head(url string, cfg ...Config) (*Response, error) { + return C().Head(url, cfg...) +} + +// Put send a put request use defaultClient, a convenient method. +func Put(url string, cfg ...Config) (*Response, error) { + return C().Put(url, cfg...) +} + +// Delete send a delete request use defaultClient, a convenient method. +func Delete(url string, cfg ...Config) (*Response, error) { + return C().Delete(url, cfg...) +} + +// Options send a options request use defaultClient, a convenient method. +func Options(url string, cfg ...Config) (*Response, error) { + return C().Options(url, cfg...) +} + +// Patch send a patch request use defaultClient, a convenient method. +func Patch(url string, cfg ...Config) (*Response, error) { + return C().Patch(url, cfg...) +} diff --git a/client/client_test.go b/client/client_test.go new file mode 100644 index 0000000000..4fd2e484a6 --- /dev/null +++ b/client/client_test.go @@ -0,0 +1,1642 @@ +package client + +import ( + "context" + "crypto/tls" + "errors" + "io" + "net" + "os" + "reflect" + "sync" + "testing" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/addon/retry" + "github.com/gofiber/fiber/v3/internal/tlstest" + "github.com/gofiber/utils/v2" + "github.com/stretchr/testify/require" + "github.com/valyala/bytebufferpool" +) + +func startTestServerWithPort(t *testing.T, beforeStarting func(app *fiber.App)) (*fiber.App, string) { + t.Helper() + + app := fiber.New() + + if beforeStarting != nil { + beforeStarting(app) + } + + addrChan := make(chan string) + errChan := make(chan error, 1) + go func() { + err := app.Listen(":0", fiber.ListenConfig{ + DisableStartupMessage: true, + ListenerAddrFunc: func(addr net.Addr) { + addrChan <- addr.String() + }, + }) + if err != nil { + errChan <- err + } + }() + + select { + case addr := <-addrChan: + return app, addr + case err := <-errChan: + t.Fatalf("Failed to start test server: %v", err) + } + + return nil, "" +} + +func Test_Client_Add_Hook(t *testing.T) { + t.Parallel() + + t.Run("add request hooks", func(t *testing.T) { + t.Parallel() + + buf := bytebufferpool.Get() + defer bytebufferpool.Put(buf) + + client := NewClient().AddRequestHook(func(_ *Client, _ *Request) error { + buf.WriteString("hook1") + return nil + }) + + require.Len(t, client.RequestHook(), 1) + + client.AddRequestHook(func(_ *Client, _ *Request) error { + buf.WriteString("hook2") + return nil + }, func(_ *Client, _ *Request) error { + buf.WriteString("hook3") + return nil + }) + + require.Len(t, client.RequestHook(), 3) + }) + + t.Run("add response hooks", func(t *testing.T) { + t.Parallel() + client := NewClient().AddResponseHook(func(_ *Client, _ *Response, _ *Request) error { + return nil + }) + + require.Len(t, client.ResponseHook(), 1) + + client.AddResponseHook(func(_ *Client, _ *Response, _ *Request) error { + return nil + }, func(_ *Client, _ *Response, _ *Request) error { + return nil + }) + + require.Len(t, client.ResponseHook(), 3) + }) +} + +func Test_Client_Add_Hook_CheckOrder(t *testing.T) { + t.Parallel() + + buf := bytebufferpool.Get() + defer bytebufferpool.Put(buf) + + client := NewClient(). + AddRequestHook(func(_ *Client, _ *Request) error { + buf.WriteString("hook1") + return nil + }). + AddRequestHook(func(_ *Client, _ *Request) error { + buf.WriteString("hook2") + return nil + }). + AddRequestHook(func(_ *Client, _ *Request) error { + buf.WriteString("hook3") + return nil + }) + + for _, hook := range client.RequestHook() { + require.NoError(t, hook(client, &Request{})) + } + + require.Equal(t, "hook1hook2hook3", buf.String()) +} + +func Test_Client_Marshal(t *testing.T) { + t.Parallel() + + t.Run("set json marshal", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetJSONMarshal(func(_ any) ([]byte, error) { + return []byte("hello"), nil + }) + val, err := client.JSONMarshal()(nil) + + require.NoError(t, err) + require.Equal(t, []byte("hello"), val) + }) + + t.Run("set json marshal error", func(t *testing.T) { + t.Parallel() + + emptyErr := errors.New("empty json") + client := NewClient(). + SetJSONMarshal(func(_ any) ([]byte, error) { + return nil, emptyErr + }) + + val, err := client.JSONMarshal()(nil) + require.Nil(t, val) + require.ErrorIs(t, err, emptyErr) + }) + + t.Run("set json unmarshal", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetJSONUnmarshal(func(_ []byte, _ any) error { + return errors.New("empty json") + }) + + err := client.JSONUnmarshal()(nil, nil) + require.Equal(t, errors.New("empty json"), err) + }) + + t.Run("set json unmarshal error", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetJSONUnmarshal(func(_ []byte, _ any) error { + return errors.New("empty json") + }) + + err := client.JSONUnmarshal()(nil, nil) + require.Equal(t, errors.New("empty json"), err) + }) + + t.Run("set xml marshal", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetXMLMarshal(func(_ any) ([]byte, error) { + return []byte("hello"), nil + }) + val, err := client.XMLMarshal()(nil) + + require.NoError(t, err) + require.Equal(t, []byte("hello"), val) + }) + + t.Run("set xml marshal error", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetXMLMarshal(func(_ any) ([]byte, error) { + return nil, errors.New("empty xml") + }) + + val, err := client.XMLMarshal()(nil) + require.Nil(t, val) + require.Equal(t, errors.New("empty xml"), err) + }) + + t.Run("set xml unmarshal", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetXMLUnmarshal(func(_ []byte, _ any) error { + return errors.New("empty xml") + }) + + err := client.XMLUnmarshal()(nil, nil) + require.Equal(t, errors.New("empty xml"), err) + }) + + t.Run("set xml unmarshal error", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetXMLUnmarshal(func(_ []byte, _ any) error { + return errors.New("empty xml") + }) + + err := client.XMLUnmarshal()(nil, nil) + require.Equal(t, errors.New("empty xml"), err) + }) +} + +func Test_Client_SetBaseURL(t *testing.T) { + t.Parallel() + + client := NewClient().SetBaseURL("http://example.com") + + require.Equal(t, "http://example.com", client.BaseURL()) +} + +func Test_Client_Invalid_URL(t *testing.T) { + t.Parallel() + + app, dial, start := createHelperServer(t) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + + go start() + + _, err := NewClient().SetDial(dial). + R(). + Get("http//example") + + require.ErrorIs(t, err, ErrURLFormat) +} + +func Test_Client_Unsupported_Protocol(t *testing.T) { + t.Parallel() + + _, err := NewClient(). + R(). + Get("ftp://example.com") + + require.ErrorIs(t, err, ErrURLFormat) +} + +func Test_Client_ConcurrencyRequests(t *testing.T) { + t.Parallel() + + app, dial, start := createHelperServer(t) + app.All("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname() + " " + c.Method()) + }) + go start() + + client := NewClient().SetDial(dial) + + wg := sync.WaitGroup{} + for i := 0; i < 5; i++ { + for _, method := range []string{"GET", "POST", "PUT", "DELETE", "PATCH"} { + wg.Add(1) + go func(m string) { + defer wg.Done() + resp, err := client.Custom("http://example.com", m) + require.NoError(t, err) + require.Equal(t, "example.com "+m, utils.UnsafeString(resp.RawResponse.Body())) + }(method) + } + } + + wg.Wait() +} + +func Test_Get(t *testing.T) { + t.Parallel() + + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + }) + + return app, addr + } + + t.Run("global get function", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + resp, err := Get("http://" + addr) + require.NoError(t, err) + require.Equal(t, "0.0.0.0", utils.UnsafeString(resp.RawResponse.Body())) + }) + + t.Run("client get", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + resp, err := NewClient().Get("http://" + addr) + require.NoError(t, err) + require.Equal(t, "0.0.0.0", utils.UnsafeString(resp.RawResponse.Body())) + }) +} + +func Test_Head(t *testing.T) { + t.Parallel() + + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Head("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + }) + + return app, addr + } + + t.Run("global head function", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + resp, err := Head("http://" + addr) + require.NoError(t, err) + require.Equal(t, "7", resp.Header(fiber.HeaderContentLength)) + require.Equal(t, "", utils.UnsafeString(resp.RawResponse.Body())) + }) + + t.Run("client head", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + resp, err := NewClient().Head("http://" + addr) + require.NoError(t, err) + require.Equal(t, "7", resp.Header(fiber.HeaderContentLength)) + require.Equal(t, "", utils.UnsafeString(resp.RawResponse.Body())) + }) +} + +func Test_Post(t *testing.T) { + t.Parallel() + + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Post("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusCreated). + SendString(c.FormValue("foo")) + }) + }) + + return app, addr + } + + t.Run("global post function", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := Post("http://"+addr, Config{ + FormData: map[string]string{ + "foo": "bar", + }, + }) + + require.NoError(t, err) + require.Equal(t, fiber.StatusCreated, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) + + t.Run("client post", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := NewClient().Post("http://"+addr, Config{ + FormData: map[string]string{ + "foo": "bar", + }, + }) + + require.NoError(t, err) + require.Equal(t, fiber.StatusCreated, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) +} + +func Test_Put(t *testing.T) { + t.Parallel() + + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Put("/", func(c fiber.Ctx) error { + return c.SendString(c.FormValue("foo")) + }) + }) + + return app, addr + } + + t.Run("global put function", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := Put("http://"+addr, Config{ + FormData: map[string]string{ + "foo": "bar", + }, + }) + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) + + t.Run("client put", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := NewClient().Put("http://"+addr, Config{ + FormData: map[string]string{ + "foo": "bar", + }, + }) + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) +} + +func Test_Delete(t *testing.T) { + t.Parallel() + + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Delete("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusNoContent). + SendString("deleted") + }) + }) + + return app, addr + } + + t.Run("global delete function", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + time.Sleep(1 * time.Second) + + for i := 0; i < 5; i++ { + resp, err := Delete("http://"+addr, Config{ + FormData: map[string]string{ + "foo": "bar", + }, + }) + + require.NoError(t, err) + require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) + require.Equal(t, "", resp.String()) + } + }) + + t.Run("client delete", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := NewClient().Delete("http://"+addr, Config{ + FormData: map[string]string{ + "foo": "bar", + }, + }) + + require.NoError(t, err) + require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) + require.Equal(t, "", resp.String()) + } + }) +} + +func Test_Options(t *testing.T) { + t.Parallel() + + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Options("/", func(c fiber.Ctx) error { + c.Set(fiber.HeaderAllow, "GET, POST, PUT, DELETE, PATCH") + return c.Status(fiber.StatusNoContent).SendString("") + }) + }) + + return app, addr + } + + t.Run("global options function", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := Options("http://" + addr) + + require.NoError(t, err) + require.Equal(t, "GET, POST, PUT, DELETE, PATCH", resp.Header(fiber.HeaderAllow)) + require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) + require.Equal(t, "", resp.String()) + } + }) + + t.Run("client options", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := NewClient().Options("http://" + addr) + + require.NoError(t, err) + require.Equal(t, "GET, POST, PUT, DELETE, PATCH", resp.Header(fiber.HeaderAllow)) + require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) + require.Equal(t, "", resp.String()) + } + }) +} + +func Test_Patch(t *testing.T) { + t.Parallel() + + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Patch("/", func(c fiber.Ctx) error { + return c.SendString(c.FormValue("foo")) + }) + }) + + return app, addr + } + + t.Run("global patch function", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + time.Sleep(1 * time.Second) + + for i := 0; i < 5; i++ { + resp, err := Patch("http://"+addr, Config{ + FormData: map[string]string{ + "foo": "bar", + }, + }) + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) + + t.Run("client patch", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := NewClient().Patch("http://"+addr, Config{ + FormData: map[string]string{ + "foo": "bar", + }, + }) + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) +} + +func Test_Client_UserAgent(t *testing.T) { + t.Parallel() + + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.Send(c.Request().Header.UserAgent()) + }) + }) + + return app, addr + } + + t.Run("default", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := Get("http://" + addr) + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, defaultUserAgent, resp.String()) + } + }) + + t.Run("custom", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + c := NewClient(). + SetUserAgent("ua") + + resp, err := c.Get("http://" + addr) + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "ua", resp.String()) + } + }) +} + +func Test_Client_Header(t *testing.T) { + t.Parallel() + + t.Run("add header", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.AddHeader("foo", "bar").AddHeader("foo", "fiber") + + res := req.Header("foo") + require.Len(t, res, 2) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + }) + + t.Run("set header", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.AddHeader("foo", "bar").SetHeader("foo", "fiber") + + res := req.Header("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + }) + + t.Run("add headers", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.SetHeader("foo", "bar"). + AddHeaders(map[string][]string{ + "foo": {"fiber", "buaa"}, + "bar": {"foo"}, + }) + + res := req.Header("foo") + require.Len(t, res, 3) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + require.Equal(t, "buaa", res[2]) + + res = req.Header("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set headers", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.SetHeader("foo", "bar"). + SetHeaders(map[string]string{ + "foo": "fiber", + "bar": "foo", + }) + + res := req.Header("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + + res = req.Header("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set header case insensitive", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.SetHeader("foo", "bar"). + AddHeader("FOO", "fiber") + + res := req.Header("foo") + require.Len(t, res, 2) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + }) +} + +func Test_Client_Header_With_Server(t *testing.T) { + handler := func(c fiber.Ctx) error { + c.Request().Header.VisitAll(func(key, value []byte) { + if k := string(key); k == "K1" || k == "K2" { + _, _ = c.Write(key) //nolint:errcheck // It is fine to ignore the error here + _, _ = c.Write(value) //nolint:errcheck // It is fine to ignore the error here + } + }) + return nil + } + + wrapAgent := func(c *Client) { + c.SetHeader("k1", "v1"). + AddHeader("k1", "v11"). + AddHeaders(map[string][]string{ + "k1": {"v22", "v33"}, + }). + SetHeaders(map[string]string{ + "k2": "v2", + }). + AddHeader("k2", "v22") + } + + testClient(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") +} + +func Test_Client_Cookie(t *testing.T) { + t.Parallel() + + t.Run("set cookie", func(t *testing.T) { + t.Parallel() + req := NewClient(). + SetCookie("foo", "bar") + require.Equal(t, "bar", req.Cookie("foo")) + + req.SetCookie("foo", "bar1") + require.Equal(t, "bar1", req.Cookie("foo")) + }) + + t.Run("set cookies", func(t *testing.T) { + t.Parallel() + req := NewClient(). + SetCookies(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + + req.SetCookies(map[string]string{ + "foo": "bar1", + }) + require.Equal(t, "bar1", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + }) + + t.Run("set cookies with struct", func(t *testing.T) { + t.Parallel() + type args struct { + CookieInt int `cookie:"int"` + CookieString string `cookie:"string"` + } + + req := NewClient().SetCookiesWithStruct(&args{ + CookieInt: 5, + CookieString: "foo", + }) + + require.Equal(t, "5", req.Cookie("int")) + require.Equal(t, "foo", req.Cookie("string")) + }) + + t.Run("del cookies", func(t *testing.T) { + t.Parallel() + req := NewClient(). + SetCookies(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + + req.DelCookies("foo") + require.Equal(t, "", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + }) +} + +func Test_Client_Cookie_With_Server(t *testing.T) { + t.Parallel() + + handler := func(c fiber.Ctx) error { + return c.SendString( + c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) + } + + wrapAgent := func(c *Client) { + c.SetCookie("k1", "v1"). + SetCookies(map[string]string{ + "k2": "v2", + "k3": "v3", + "k4": "v4", + }).DelCookies("k4") + } + + testClient(t, handler, wrapAgent, "v1v2v3") +} + +func Test_Client_CookieJar(t *testing.T) { + handler := func(c fiber.Ctx) error { + return c.SendString( + c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3")) + } + + jar := AcquireCookieJar() + defer ReleaseCookieJar(jar) + + jar.SetKeyValue("example.com", "k1", "v1") + jar.SetKeyValue("example.com", "k2", "v2") + jar.SetKeyValue("example", "k3", "v3") + + wrapAgent := func(c *Client) { + c.SetCookieJar(jar) + } + testClient(t, handler, wrapAgent, "v1v2") +} + +func Test_Client_CookieJar_Response(t *testing.T) { + t.Parallel() + + t.Run("without expiration", func(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { + c.Cookie(&fiber.Cookie{ + Name: "k4", + Value: "v4", + }) + return c.SendString( + c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3")) + } + + jar := AcquireCookieJar() + defer ReleaseCookieJar(jar) + + jar.SetKeyValue("example.com", "k1", "v1") + jar.SetKeyValue("example.com", "k2", "v2") + jar.SetKeyValue("example", "k3", "v3") + + wrapAgent := func(c *Client) { + c.SetCookieJar(jar) + } + testClient(t, handler, wrapAgent, "v1v2") + + require.Len(t, jar.getCookiesByHost("example.com"), 3) + }) + + t.Run("with expiration", func(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { + c.Cookie(&fiber.Cookie{ + Name: "k4", + Value: "v4", + Expires: time.Now().Add(1 * time.Nanosecond), + }) + return c.SendString( + c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3")) + } + + jar := AcquireCookieJar() + defer ReleaseCookieJar(jar) + + jar.SetKeyValue("example.com", "k1", "v1") + jar.SetKeyValue("example.com", "k2", "v2") + jar.SetKeyValue("example", "k3", "v3") + + wrapAgent := func(c *Client) { + c.SetCookieJar(jar) + } + testClient(t, handler, wrapAgent, "v1v2") + + require.Len(t, jar.getCookiesByHost("example.com"), 2) + }) + + t.Run("override cookie value", func(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { + c.Cookie(&fiber.Cookie{ + Name: "k1", + Value: "v2", + }) + return c.SendString( + c.Cookies("k1") + c.Cookies("k2")) + } + + jar := AcquireCookieJar() + defer ReleaseCookieJar(jar) + + jar.SetKeyValue("example.com", "k1", "v1") + jar.SetKeyValue("example.com", "k2", "v2") + + wrapAgent := func(c *Client) { + c.SetCookieJar(jar) + } + testClient(t, handler, wrapAgent, "v1v2") + + for _, cookie := range jar.getCookiesByHost("example.com") { + if string(cookie.Key()) == "k1" { + require.Equal(t, "v2", string(cookie.Value())) + } + } + }) + + t.Run("different domain", func(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { + return c.SendString(c.Cookies("k1")) + } + + jar := AcquireCookieJar() + defer ReleaseCookieJar(jar) + + jar.SetKeyValue("example.com", "k1", "v1") + + wrapAgent := func(c *Client) { + c.SetCookieJar(jar) + } + testClient(t, handler, wrapAgent, "v1") + + require.Len(t, jar.getCookiesByHost("example.com"), 1) + require.Empty(t, jar.getCookiesByHost("example")) + }) +} + +func Test_Client_Referer(t *testing.T) { + handler := func(c fiber.Ctx) error { + return c.Send(c.Request().Header.Referer()) + } + + wrapAgent := func(c *Client) { + c.SetReferer("http://referer.com") + } + + testClient(t, handler, wrapAgent, "http://referer.com") +} + +func Test_Client_QueryParam(t *testing.T) { + t.Parallel() + + t.Run("add param", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.AddParam("foo", "bar").AddParam("foo", "fiber") + + res := req.Param("foo") + require.Len(t, res, 2) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + }) + + t.Run("set param", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.AddParam("foo", "bar").SetParam("foo", "fiber") + + res := req.Param("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + }) + + t.Run("add params", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.SetParam("foo", "bar"). + AddParams(map[string][]string{ + "foo": {"fiber", "buaa"}, + "bar": {"foo"}, + }) + + res := req.Param("foo") + require.Len(t, res, 3) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + require.Equal(t, "buaa", res[2]) + + res = req.Param("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set headers", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.SetParam("foo", "bar"). + SetParams(map[string]string{ + "foo": "fiber", + "bar": "foo", + }) + + res := req.Param("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + + res = req.Param("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set params with struct", func(t *testing.T) { + t.Parallel() + + type args struct { + TInt int + TString string + TFloat float64 + TBool bool + TSlice []string + TIntSlice []int `param:"int_slice"` + } + + p := NewClient() + p.SetParamsWithStruct(&args{ + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: true, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + + require.Empty(t, p.Param("unexport")) + + require.Len(t, p.Param("TInt"), 1) + require.Equal(t, "5", p.Param("TInt")[0]) + + require.Len(t, p.Param("TString"), 1) + require.Equal(t, "string", p.Param("TString")[0]) + + require.Len(t, p.Param("TFloat"), 1) + require.Equal(t, "3.1", p.Param("TFloat")[0]) + + require.Len(t, p.Param("TBool"), 1) + + tslice := p.Param("TSlice") + require.Len(t, tslice, 2) + require.Equal(t, "foo", tslice[0]) + require.Equal(t, "bar", tslice[1]) + + tint := p.Param("TSlice") + require.Len(t, tint, 2) + require.Equal(t, "foo", tint[0]) + require.Equal(t, "bar", tint[1]) + }) + + t.Run("del params", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.SetParam("foo", "bar"). + SetParams(map[string]string{ + "foo": "fiber", + "bar": "foo", + }).DelParams("foo", "bar") + + res := req.Param("foo") + require.Empty(t, res) + + res = req.Param("bar") + require.Empty(t, res) + }) +} + +func Test_Client_QueryParam_With_Server(t *testing.T) { + handler := func(c fiber.Ctx) error { + _, _ = c.WriteString(c.Query("k1")) //nolint:errcheck // It is fine to ignore the error here + _, _ = c.WriteString(c.Query("k2")) //nolint:errcheck // It is fine to ignore the error here + + return nil + } + + wrapAgent := func(c *Client) { + c.SetParam("k1", "v1"). + AddParam("k2", "v2") + } + + testClient(t, handler, wrapAgent, "v1v2") +} + +func Test_Client_PathParam(t *testing.T) { + t.Parallel() + + t.Run("set path param", func(t *testing.T) { + t.Parallel() + req := NewClient(). + SetPathParam("foo", "bar") + require.Equal(t, "bar", req.PathParam("foo")) + + req.SetPathParam("foo", "bar1") + require.Equal(t, "bar1", req.PathParam("foo")) + }) + + t.Run("set path params", func(t *testing.T) { + t.Parallel() + req := NewClient(). + SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + + req.SetPathParams(map[string]string{ + "foo": "bar1", + }) + require.Equal(t, "bar1", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + }) + + t.Run("set path params with struct", func(t *testing.T) { + t.Parallel() + type args struct { + CookieInt int `path:"int"` + CookieString string `path:"string"` + } + + req := NewClient().SetPathParamsWithStruct(&args{ + CookieInt: 5, + CookieString: "foo", + }) + + require.Equal(t, "5", req.PathParam("int")) + require.Equal(t, "foo", req.PathParam("string")) + }) + + t.Run("del path params", func(t *testing.T) { + t.Parallel() + req := NewClient(). + SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + + req.DelPathParams("foo") + require.Equal(t, "", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + }) +} + +func Test_Client_PathParam_With_Server(t *testing.T) { + app, dial, start := createHelperServer(t) + + app.Get("/:test", func(c fiber.Ctx) error { + return c.SendString(c.Params("test")) + }) + + go start() + + resp, err := NewClient().SetDial(dial). + SetPathParam("path", "test"). + Get("http://example.com/:path") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "test", resp.String()) +} + +func Test_Client_TLS(t *testing.T) { + t.Parallel() + + serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() + require.NoError(t, err) + + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + require.NoError(t, err) + + ln = tls.NewListener(ln, serverTLSConf) + + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("tls") + }) + + go func() { + require.NoError(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + + client := NewClient() + resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) + + require.NoError(t, err) + require.Equal(t, clientTLSConf, client.TLSConfig()) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "tls", resp.String()) +} + +func Test_Client_TLS_Error(t *testing.T) { + t.Parallel() + + serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() + clientTLSConf.MaxVersion = tls.VersionTLS12 + serverTLSConf.MinVersion = tls.VersionTLS13 + require.NoError(t, err) + + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + require.NoError(t, err) + + ln = tls.NewListener(ln, serverTLSConf) + + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("tls") + }) + + go func() { + require.NoError(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + + client := NewClient() + resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) + + require.Error(t, err) + require.Equal(t, clientTLSConf, client.TLSConfig()) + require.Nil(t, resp) +} + +func Test_Client_TLS_Empty_TLSConfig(t *testing.T) { + t.Parallel() + + serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() + require.NoError(t, err) + + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + require.NoError(t, err) + + ln = tls.NewListener(ln, serverTLSConf) + + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("tls") + }) + + go func() { + require.NoError(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + + client := NewClient() + resp, err := client.Get("https://" + ln.Addr().String()) + + require.Error(t, err) + require.NotEqual(t, clientTLSConf, client.TLSConfig()) + require.Nil(t, resp) +} + +func Test_Client_SetCertificates(t *testing.T) { + t.Parallel() + + serverTLSConf, _, err := tlstest.GetTLSConfigs() + require.NoError(t, err) + + client := NewClient().SetCertificates(serverTLSConf.Certificates...) + require.Len(t, client.TLSConfig().Certificates, 1) +} + +func Test_Client_SetRootCertificate(t *testing.T) { + t.Parallel() + + client := NewClient().SetRootCertificate("../.github/testdata/ssl.pem") + require.NotNil(t, client.TLSConfig().RootCAs) +} + +func Test_Client_SetRootCertificateFromString(t *testing.T) { + t.Parallel() + + file, err := os.Open("../.github/testdata/ssl.pem") + defer func() { require.NoError(t, file.Close()) }() + require.NoError(t, err) + + pem, err := io.ReadAll(file) + require.NoError(t, err) + + client := NewClient().SetRootCertificateFromString(string(pem)) + require.NotNil(t, client.TLSConfig().RootCAs) +} + +func Test_Client_R(t *testing.T) { + t.Parallel() + + client := NewClient() + req := client.R() + + require.Equal(t, "Request", reflect.TypeOf(req).Elem().Name()) + require.Equal(t, client, req.Client()) +} + +func Test_Replace(t *testing.T) { + app, dial, start := createHelperServer(t) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(string(c.Request().Header.Peek("k1"))) + }) + + go start() + + C().SetDial(dial) + resp, err := Get("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "", resp.String()) + + r := NewClient().SetDial(dial).SetHeader("k1", "v1") + clean := Replace(r) + resp, err = Get("http://example.com") + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "v1", resp.String()) + + clean() + + C().SetDial(dial) + resp, err = Get("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "", resp.String()) + + C().SetDial(nil) +} + +func Test_Set_Config_To_Request(t *testing.T) { + t.Parallel() + + t.Run("set ctx", func(t *testing.T) { + t.Parallel() + key := struct{}{} + + ctx := context.Background() + ctx = context.WithValue(ctx, key, "v1") + + req := AcquireRequest() + + setConfigToRequest(req, Config{Ctx: ctx}) + + require.Equal(t, "v1", req.Context().Value(key)) + }) + + t.Run("set useragent", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{UserAgent: "agent"}) + + require.Equal(t, "agent", req.UserAgent()) + }) + + t.Run("set referer", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{Referer: "referer"}) + + require.Equal(t, "referer", req.Referer()) + }) + + t.Run("set header", func(t *testing.T) { + req := AcquireRequest() + + setConfigToRequest(req, Config{Header: map[string]string{ + "k1": "v1", + }}) + + require.Equal(t, "v1", req.Header("k1")[0]) + }) + + t.Run("set params", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{Param: map[string]string{ + "k1": "v1", + }}) + + require.Equal(t, "v1", req.Param("k1")[0]) + }) + + t.Run("set cookies", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{Cookie: map[string]string{ + "k1": "v1", + }}) + + require.Equal(t, "v1", req.Cookie("k1")) + }) + + t.Run("set pathparam", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{PathParam: map[string]string{ + "k1": "v1", + }}) + + require.Equal(t, "v1", req.PathParam("k1")) + }) + + t.Run("set timeout", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{Timeout: 1 * time.Second}) + + require.Equal(t, 1*time.Second, req.Timeout()) + }) + + t.Run("set maxredirects", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{MaxRedirects: 1}) + + require.Equal(t, 1, req.MaxRedirects()) + }) + + t.Run("set body", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{Body: "test"}) + + require.Equal(t, "test", req.body) + }) + + t.Run("set file", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{File: []*File{ + { + name: "test", + path: "path", + }, + }}) + + require.Equal(t, "path", req.File("test").path) + }) +} + +func Test_Client_SetProxyURL(t *testing.T) { + t.Parallel() + + app, dial, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("hello world") + }) + + go start() + + t.Cleanup(func() { + require.NoError(t, app.Shutdown()) + }) + + time.Sleep(1 * time.Second) + + t.Run("success", func(t *testing.T) { + t.Parallel() + client := NewClient().SetDial(dial) + err := client.SetProxyURL("http://test.com") + + require.NoError(t, err) + + _, err = client.Get("http://localhost:3000") + + require.NoError(t, err) + }) + + t.Run("wrong url", func(t *testing.T) { + t.Parallel() + client := NewClient() + + err := client.SetProxyURL(":this is not a url") + + require.Error(t, err) + }) + + t.Run("error", func(t *testing.T) { + t.Parallel() + client := NewClient() + + err := client.SetProxyURL("htgdftp://test.com") + + require.Error(t, err) + }) +} + +func Test_Client_SetRetryConfig(t *testing.T) { + t.Parallel() + + retryConfig := &retry.Config{ + InitialInterval: 1 * time.Second, + MaxRetryCount: 3, + } + + core, client, req := newCore(), NewClient(), AcquireRequest() + req.SetURL("http://example.com") + client.SetRetryConfig(retryConfig) + _, err := core.execute(context.Background(), client, req) + + require.NoError(t, err) + require.Equal(t, retryConfig.InitialInterval, client.RetryConfig().InitialInterval) + require.Equal(t, retryConfig.MaxRetryCount, client.RetryConfig().MaxRetryCount) +} + +func Benchmark_Client_Request(b *testing.B) { + app, dial, start := createHelperServer(b) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("hello world") + }) + + go start() + + client := NewClient().SetDial(dial) + + b.ResetTimer() + b.ReportAllocs() + + var err error + var resp *Response + for i := 0; i < b.N; i++ { + resp, err = client.Get("http://example.com") + resp.Close() + } + require.NoError(b, err) +} + +func Benchmark_Client_Request_Parallel(b *testing.B) { + app, dial, start := createHelperServer(b) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("hello world") + }) + + go start() + + client := NewClient().SetDial(dial) + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + var err error + var resp *Response + for pb.Next() { + resp, err = client.Get("http://example.com") + resp.Close() + } + require.NoError(b, err) + }) +} diff --git a/client/cookiejar.go b/client/cookiejar.go new file mode 100644 index 0000000000..c66d5f3b7c --- /dev/null +++ b/client/cookiejar.go @@ -0,0 +1,245 @@ +// The code has been taken from https://github.com/valyala/fasthttp/pull/526 originally. +package client + +import ( + "bytes" + "errors" + "net" + "sync" + "time" + + "github.com/gofiber/utils/v2" + "github.com/valyala/fasthttp" +) + +var cookieJarPool = sync.Pool{ + New: func() any { + return &CookieJar{} + }, +} + +// AcquireCookieJar returns an empty CookieJar object from pool. +func AcquireCookieJar() *CookieJar { + jar, ok := cookieJarPool.Get().(*CookieJar) + if !ok { + panic(errors.New("failed to type-assert to *CookieJar")) + } + + return jar +} + +// ReleaseCookieJar returns CookieJar to the pool. +func ReleaseCookieJar(c *CookieJar) { + c.Release() + cookieJarPool.Put(c) +} + +// CookieJar manages cookie storage. It is used by the client to store cookies. +type CookieJar struct { + mu sync.Mutex + hostCookies map[string][]*fasthttp.Cookie +} + +// Get returns the cookies stored from a specific domain. +// If there were no cookies related with host returned slice will be nil. +// +// CookieJar keeps a copy of the cookies, so the returned cookies can be released safely. +func (cj *CookieJar) Get(uri *fasthttp.URI) []*fasthttp.Cookie { + if uri == nil { + return nil + } + + return cj.getByHostAndPath(uri.Host(), uri.Path()) +} + +// get returns the cookies stored from a specific host and path. +func (cj *CookieJar) getByHostAndPath(host, path []byte) []*fasthttp.Cookie { + if cj.hostCookies == nil { + return nil + } + + var ( + err error + cookies []*fasthttp.Cookie + hostStr = utils.UnsafeString(host) + ) + + // port must not be included. + hostStr, _, err = net.SplitHostPort(hostStr) + if err != nil { + hostStr = utils.UnsafeString(host) + } + // get cookies deleting expired ones + cookies = cj.getCookiesByHost(hostStr) + + newCookies := make([]*fasthttp.Cookie, 0, len(cookies)) + for i := 0; i < len(cookies); i++ { + cookie := cookies[i] + if len(path) > 1 && len(cookie.Path()) > 1 && !bytes.HasPrefix(cookie.Path(), path) { + continue + } + newCookies = append(newCookies, cookie) + } + + return newCookies +} + +// getCookiesByHost returns the cookies stored from a specific host. +// If cookies are expired they will be deleted. +func (cj *CookieJar) getCookiesByHost(host string) []*fasthttp.Cookie { + cj.mu.Lock() + defer cj.mu.Unlock() + + now := time.Now() + cookies := cj.hostCookies[host] + + for i := 0; i < len(cookies); i++ { + c := cookies[i] + if !c.Expire().Equal(fasthttp.CookieExpireUnlimited) && c.Expire().Before(now) { // release cookie if expired + cookies = append(cookies[:i], cookies[i+1:]...) + fasthttp.ReleaseCookie(c) + i-- + } + } + + return cookies +} + +// Set sets cookies for a specific host. +// The host is get from uri.Host(). +// If the cookie key already exists it will be replaced by the new cookie value. +// +// CookieJar keeps a copy of the cookies, so the parsed cookies can be released safely. +func (cj *CookieJar) Set(uri *fasthttp.URI, cookies ...*fasthttp.Cookie) { + if uri == nil { + return + } + + cj.SetByHost(uri.Host(), cookies...) +} + +// SetByHost sets cookies for a specific host. +// If the cookie key already exists it will be replaced by the new cookie value. +// +// CookieJar keeps a copy of the cookies, so the parsed cookies can be released safely. +func (cj *CookieJar) SetByHost(host []byte, cookies ...*fasthttp.Cookie) { + hostStr := utils.UnsafeString(host) + + cj.mu.Lock() + defer cj.mu.Unlock() + + if cj.hostCookies == nil { + cj.hostCookies = make(map[string][]*fasthttp.Cookie) + } + + hostCookies, ok := cj.hostCookies[hostStr] + if !ok { + // If the key does not exist in the map, then we must make a copy for the key to avoid unsafe usage. + hostStr = string(host) + } + + for _, cookie := range cookies { + c := searchCookieByKeyAndPath(cookie.Key(), cookie.Path(), hostCookies) + if c == nil { + // If the cookie does not exist in the slice, let's acquire new cookie and store it. + c = fasthttp.AcquireCookie() + hostCookies = append(hostCookies, c) + } + c.CopyTo(cookie) // override cookie properties + } + cj.hostCookies[hostStr] = hostCookies +} + +// SetKeyValue sets a cookie by key and value for a specific host. +// +// This function prevents extra allocations by making repeated cookies +// not being duplicated. +func (cj *CookieJar) SetKeyValue(host, key, value string) { + c := fasthttp.AcquireCookie() + c.SetKey(key) + c.SetValue(value) + + cj.SetByHost(utils.UnsafeBytes(host), c) +} + +// SetKeyValueBytes sets a cookie by key and value for a specific host. +// +// This function prevents extra allocations by making repeated cookies +// not being duplicated. +func (cj *CookieJar) SetKeyValueBytes(host string, key, value []byte) { + c := fasthttp.AcquireCookie() + c.SetKeyBytes(key) + c.SetValueBytes(value) + + cj.SetByHost(utils.UnsafeBytes(host), c) +} + +// dumpCookiesToReq dumps the stored cookies to the request. +func (cj *CookieJar) dumpCookiesToReq(req *fasthttp.Request) { + uri := req.URI() + + cookies := cj.getByHostAndPath(uri.Host(), uri.Path()) + for _, cookie := range cookies { + req.Header.SetCookieBytesKV(cookie.Key(), cookie.Value()) + } +} + +// parseCookiesFromResp parses the response cookies and stores them. +func (cj *CookieJar) parseCookiesFromResp(host, path []byte, resp *fasthttp.Response) { + hostStr := utils.UnsafeString(host) + + cj.mu.Lock() + defer cj.mu.Unlock() + + if cj.hostCookies == nil { + cj.hostCookies = make(map[string][]*fasthttp.Cookie) + } + cookies, ok := cj.hostCookies[hostStr] + if !ok { + // If the key does not exist in the map then + // we must make a copy for the key to avoid unsafe usage. + hostStr = string(host) + } + + now := time.Now() + resp.Header.VisitAllCookie(func(key, value []byte) { + isCreated := false + c := searchCookieByKeyAndPath(key, path, cookies) + if c == nil { + c, isCreated = fasthttp.AcquireCookie(), true + } + + _ = c.ParseBytes(value) //nolint:errcheck // ignore error + if c.Expire().Equal(fasthttp.CookieExpireUnlimited) || c.Expire().After(now) { + cookies = append(cookies, c) + } else if isCreated { + fasthttp.ReleaseCookie(c) + } + }) + cj.hostCookies[hostStr] = cookies +} + +// Release releases all cookie values. +func (cj *CookieJar) Release() { + // FOllOW-UP performance optimization + // currently a race condition is found because the reset method modifies a value which is not a copy but a reference -> solution should be to make a copy + // for _, v := range cj.hostCookies { + // for _, c := range v { + // fasthttp.ReleaseCookie(c) + // } + // } + cj.hostCookies = nil +} + +// searchCookieByKeyAndPath searches for a cookie by key and path. +func searchCookieByKeyAndPath(key, path []byte, cookies []*fasthttp.Cookie) *fasthttp.Cookie { + for _, c := range cookies { + if bytes.Equal(key, c.Key()) { + if len(path) <= 1 || bytes.HasPrefix(c.Path(), path) { + return c + } + } + } + + return nil +} diff --git a/client/cookiejar_test.go b/client/cookiejar_test.go new file mode 100644 index 0000000000..3b6fdcda83 --- /dev/null +++ b/client/cookiejar_test.go @@ -0,0 +1,213 @@ +package client + +import ( + "bytes" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func checkKeyValue(t *testing.T, cj *CookieJar, cookie *fasthttp.Cookie, uri *fasthttp.URI, n int) { + t.Helper() + + cs := cj.Get(uri) + require.GreaterOrEqual(t, len(cs), n) + + c := cs[n-1] + require.NotNil(t, c) + + require.Equal(t, string(c.Key()), string(cookie.Key())) + require.Equal(t, string(c.Value()), string(cookie.Value())) +} + +func TestCookieJarGet(t *testing.T) { + t.Parallel() + + url := []byte("http://fasthttp.com/") + url1 := []byte("http://fasthttp.com/make") + url11 := []byte("http://fasthttp.com/hola") + url2 := []byte("http://fasthttp.com/make/fasthttp") + url3 := []byte("http://fasthttp.com/make/fasthttp/great") + prefix := []byte("/") + prefix1 := []byte("/make") + prefix2 := []byte("/make/fasthttp") + prefix3 := []byte("/make/fasthttp/great") + cj := &CookieJar{} + + c1 := &fasthttp.Cookie{} + c1.SetKey("k") + c1.SetValue("v") + c1.SetPath("/make/") + + c2 := &fasthttp.Cookie{} + c2.SetKey("kk") + c2.SetValue("vv") + c2.SetPath("/make/fasthttp") + + c3 := &fasthttp.Cookie{} + c3.SetKey("kkk") + c3.SetValue("vvv") + c3.SetPath("/make/fasthttp/great") + + uri := fasthttp.AcquireURI() + require.NoError(t, uri.Parse(nil, url)) + + uri1 := fasthttp.AcquireURI() + require.NoError(t, uri1.Parse(nil, url1)) + + uri11 := fasthttp.AcquireURI() + require.NoError(t, uri11.Parse(nil, url11)) + + uri2 := fasthttp.AcquireURI() + require.NoError(t, uri2.Parse(nil, url2)) + + uri3 := fasthttp.AcquireURI() + require.NoError(t, uri3.Parse(nil, url3)) + + cj.Set(uri1, c1, c2, c3) + + cookies := cj.Get(uri1) + require.Len(t, cookies, 3) + for _, cookie := range cookies { + require.True(t, bytes.HasPrefix(cookie.Path(), prefix1)) + } + + cookies = cj.Get(uri11) + require.Empty(t, cookies) + + cookies = cj.Get(uri2) + require.Len(t, cookies, 2) + for _, cookie := range cookies { + require.True(t, bytes.HasPrefix(cookie.Path(), prefix2)) + } + + cookies = cj.Get(uri3) + require.Len(t, cookies, 1) + + for _, cookie := range cookies { + require.True(t, bytes.HasPrefix(cookie.Path(), prefix3)) + } + + cookies = cj.Get(uri) + require.Len(t, cookies, 3) + for _, cookie := range cookies { + require.True(t, bytes.HasPrefix(cookie.Path(), prefix)) + } +} + +func TestCookieJarGetExpired(t *testing.T) { + t.Parallel() + + url1 := []byte("http://fasthttp.com/make/") + uri1 := fasthttp.AcquireURI() + require.NoError(t, uri1.Parse(nil, url1)) + + c1 := &fasthttp.Cookie{} + c1.SetKey("k") + c1.SetValue("v") + c1.SetExpire(time.Now().Add(-time.Hour)) + + cj := &CookieJar{} + cj.Set(uri1, c1) + + cookies := cj.Get(uri1) + require.Empty(t, cookies) +} + +func TestCookieJarSet(t *testing.T) { + t.Parallel() + + url := []byte("http://fasthttp.com/hello/world") + cj := &CookieJar{} + + cookie := &fasthttp.Cookie{} + cookie.SetKey("k") + cookie.SetValue("v") + + uri := fasthttp.AcquireURI() + require.NoError(t, uri.Parse(nil, url)) + + cj.Set(uri, cookie) + checkKeyValue(t, cj, cookie, uri, 1) +} + +func TestCookieJarSetRepeatedCookieKeys(t *testing.T) { + t.Parallel() + + host := "fast.http" + cj := &CookieJar{} + + uri := fasthttp.AcquireURI() + uri.SetHost(host) + + cookie := &fasthttp.Cookie{} + cookie.SetKey("k") + cookie.SetValue("v") + + cookie2 := &fasthttp.Cookie{} + cookie2.SetKey("k") + cookie2.SetValue("v2") + + cookie3 := &fasthttp.Cookie{} + cookie3.SetKey("key") + cookie3.SetValue("value") + + cj.Set(uri, cookie, cookie2, cookie3) + + cookies := cj.Get(uri) + require.Len(t, cookies, 2) + require.Equal(t, cookies[0], cookie2) + require.True(t, bytes.Equal(cookies[0].Value(), cookie2.Value())) +} + +func TestCookieJarSetKeyValue(t *testing.T) { + t.Parallel() + + host := "fast.http" + cj := &CookieJar{} + + uri := fasthttp.AcquireURI() + uri.SetHost(host) + + cj.SetKeyValue(host, "k", "v") + cj.SetKeyValue(host, "key", "value") + cj.SetKeyValue(host, "k", "vv") + cj.SetKeyValue(host, "key", "value2") + + cookies := cj.Get(uri) + require.Len(t, cookies, 2) +} + +func TestCookieJarGetFromResponse(t *testing.T) { + t.Parallel() + + res := fasthttp.AcquireResponse() + host := []byte("fast.http") + uri := fasthttp.AcquireURI() + uri.SetHostBytes(host) + + c := &fasthttp.Cookie{} + c.SetKey("key") + c.SetValue("val") + + c2 := &fasthttp.Cookie{} + c2.SetKey("k") + c2.SetValue("v") + + c3 := &fasthttp.Cookie{} + c3.SetKey("kk") + c3.SetValue("vv") + + res.Header.SetStatusCode(200) + res.Header.SetCookie(c) + res.Header.SetCookie(c2) + res.Header.SetCookie(c3) + + cj := &CookieJar{} + cj.parseCookiesFromResp(host, nil, res) + + cookies := cj.Get(uri) + require.Len(t, cookies, 3) +} diff --git a/client/core.go b/client/core.go new file mode 100644 index 0000000000..315d12d474 --- /dev/null +++ b/client/core.go @@ -0,0 +1,272 @@ +package client + +import ( + "context" + "errors" + "net" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/addon/retry" + "github.com/valyala/fasthttp" +) + +var boundary = "--FiberFormBoundary" + +// RequestHook is a function that receives Agent and Request, +// it can change the data in Request and Agent. +// +// Called before a request is sent. +type RequestHook func(*Client, *Request) error + +// ResponseHook is a function that receives Agent, Response and Request, +// it can change the data is Response or deal with some effects. +// +// Called after a response has been received. +type ResponseHook func(*Client, *Response, *Request) error + +// RetryConfig is an alias for config in the `addon/retry` package. +type RetryConfig = retry.Config + +// addMissingPort will add the corresponding port number for host. +func addMissingPort(addr string, isTLS bool) string { //revive:disable-line:flag-parameter // Accepting a bool param named isTLS if fine here + n := strings.Index(addr, ":") + if n >= 0 { + return addr + } + port := 80 + if isTLS { + port = 443 + } + return net.JoinHostPort(addr, strconv.Itoa(port)) +} + +// `core` stores middleware and plugin definitions, +// and defines the execution process +type core struct { + client *Client + req *Request + ctx context.Context //nolint:containedctx // It's needed to be stored in the core. +} + +// getRetryConfig returns the retry configuration of the client. +func (c *core) getRetryConfig() *RetryConfig { + c.client.mu.RLock() + defer c.client.mu.RUnlock() + + cfg := c.client.RetryConfig() + if cfg == nil { + return nil + } + + return &RetryConfig{ + InitialInterval: cfg.InitialInterval, + MaxBackoffTime: cfg.MaxBackoffTime, + Multiplier: cfg.Multiplier, + MaxRetryCount: cfg.MaxRetryCount, + } +} + +// execFunc is the core function of the client. +// It sends the request and receives the response. +func (c *core) execFunc() (*Response, error) { + resp := AcquireResponse() + resp.setClient(c.client) + resp.setRequest(c.req) + + // To avoid memory allocation reuse of data structures such as errch. + done := int32(0) + errCh, reqv := acquireErrChan(), fasthttp.AcquireRequest() + defer func() { + releaseErrChan(errCh) + }() + + c.req.RawRequest.CopyTo(reqv) + cfg := c.getRetryConfig() + + var err error + go func() { + respv := fasthttp.AcquireResponse() + defer func() { + fasthttp.ReleaseRequest(reqv) + fasthttp.ReleaseResponse(respv) + }() + + if cfg != nil { + err = retry.NewExponentialBackoff(*cfg).Retry(func() error { + if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { + return c.client.fasthttp.DoRedirects(reqv, respv, c.req.maxRedirects) + } + + return c.client.fasthttp.Do(reqv, respv) + }) + } else { + if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { + err = c.client.fasthttp.DoRedirects(reqv, respv, c.req.maxRedirects) + } else { + err = c.client.fasthttp.Do(reqv, respv) + } + } + + if atomic.CompareAndSwapInt32(&done, 0, 1) { + if err != nil { + errCh <- err + return + } + respv.CopyTo(resp.RawResponse) + errCh <- nil + } + }() + + select { + case err := <-errCh: + if err != nil { + // When get error should release Response + ReleaseResponse(resp) + return nil, err + } + return resp, nil + case <-c.ctx.Done(): + atomic.SwapInt32(&done, 1) + ReleaseResponse(resp) + return nil, ErrTimeoutOrCancel + } +} + +// preHooks Exec request hook +func (c *core) preHooks() error { + c.client.mu.Lock() + defer c.client.mu.Unlock() + + for _, f := range c.client.userRequestHooks { + err := f(c.client, c.req) + if err != nil { + return err + } + } + + for _, f := range c.client.builtinRequestHooks { + err := f(c.client, c.req) + if err != nil { + return err + } + } + + return nil +} + +// afterHooks Exec response hooks +func (c *core) afterHooks(resp *Response) error { + c.client.mu.Lock() + defer c.client.mu.Unlock() + + for _, f := range c.client.builtinResponseHooks { + err := f(c.client, resp, c.req) + if err != nil { + return err + } + } + + for _, f := range c.client.userResponseHooks { + err := f(c.client, resp, c.req) + if err != nil { + return err + } + } + + return nil +} + +// timeout deals with timeout +func (c *core) timeout() context.CancelFunc { + var cancel context.CancelFunc + + if c.req.timeout > 0 { + c.ctx, cancel = context.WithTimeout(c.ctx, c.req.timeout) + } else if c.client.timeout > 0 { + c.ctx, cancel = context.WithTimeout(c.ctx, c.client.timeout) + } + + return cancel +} + +// execute will exec each hooks and plugins. +func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Response, error) { + // keep a reference, because pass param is boring + c.ctx = ctx + c.client = client + c.req = req + + // The built-in hooks will be executed only + // after the user-defined hooks are executed. + err := c.preHooks() + if err != nil { + return nil, err + } + + cancel := c.timeout() + if cancel != nil { + defer cancel() + } + + // Do http request + resp, err := c.execFunc() + if err != nil { + return nil, err + } + + // The built-in hooks will be executed only + // before the user-defined hooks are executed. + err = c.afterHooks(resp) + if err != nil { + resp.Close() + return nil, err + } + + return resp, nil +} + +var errChanPool = &sync.Pool{ + New: func() any { + return make(chan error, 1) + }, +} + +// acquireErrChan returns an empty error chan from the pool. +// +// The returned error chan may be returned to the pool with releaseErrChan when no longer needed. +// This allows reducing GC load. +func acquireErrChan() chan error { + ch, ok := errChanPool.Get().(chan error) + if !ok { + panic(errors.New("failed to type-assert to chan error")) + } + + return ch +} + +// releaseErrChan returns the object acquired via acquireErrChan to the pool. +// +// Do not access the released core object, otherwise data races may occur. +func releaseErrChan(ch chan error) { + errChanPool.Put(ch) +} + +// newCore returns an empty core object. +func newCore() *core { + c := &core{} + + return c +} + +var ( + ErrTimeoutOrCancel = errors.New("timeout or cancel") + ErrURLFormat = errors.New("the url is a mistake") + ErrNotSupportSchema = errors.New("the protocol is not support, only http or https") + ErrFileNoName = errors.New("the file should have name") + ErrBodyType = errors.New("the body type should be []byte") + ErrNotSupportSaveMethod = errors.New("file path and io.Writer are supported") +) diff --git a/client/core_test.go b/client/core_test.go new file mode 100644 index 0000000000..1b8ea42b9d --- /dev/null +++ b/client/core_test.go @@ -0,0 +1,248 @@ +package client + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp/fasthttputil" +) + +func Test_AddMissing_Port(t *testing.T) { + t.Parallel() + + type args struct { + addr string + isTLS bool + } + tests := []struct { + name string + args args + want string + }{ + { + name: "do anything", + args: args{ + addr: "example.com:1234", + }, + want: "example.com:1234", + }, + { + name: "add 80 port", + args: args{ + addr: "example.com", + }, + want: "example.com:80", + }, + { + name: "add 443 port", + args: args{ + addr: "example.com", + isTLS: true, + }, + want: "example.com:443", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, addMissingPort(tt.args.addr, tt.args.isTLS)) + }) + } +} + +func Test_Exec_Func(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + app := fiber.New() + + app.Get("/normal", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + + app.Get("/return-error", func(_ fiber.Ctx) error { + return errors.New("the request is error") + }) + + app.Get("/hang-up", func(c fiber.Ctx) error { + time.Sleep(time.Second) + return c.SendString(c.Hostname() + " hang up") + }) + + go func() { + require.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) + }() + + time.Sleep(300 * time.Millisecond) + + t.Run("normal request", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + core.ctx = context.Background() + core.client = client + core.req = req + + client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed + req.RawRequest.SetRequestURI("http://example.com/normal") + + resp, err := core.execFunc() + require.NoError(t, err) + require.Equal(t, 200, resp.RawResponse.StatusCode()) + require.Equal(t, "example.com", string(resp.RawResponse.Body())) + }) + + t.Run("the request return an error", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + core.ctx = context.Background() + core.client = client + core.req = req + + client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed + req.RawRequest.SetRequestURI("http://example.com/return-error") + + resp, err := core.execFunc() + + require.NoError(t, err) + require.Equal(t, 500, resp.RawResponse.StatusCode()) + require.Equal(t, "the request is error", string(resp.RawResponse.Body())) + }) + + t.Run("the request timeout", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + core.ctx = ctx + core.client = client + core.req = req + + client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed + req.RawRequest.SetRequestURI("http://example.com/hang-up") + + _, err := core.execFunc() + + require.Equal(t, ErrTimeoutOrCancel, err) + }) +} + +func Test_Execute(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + app := fiber.New() + + app.Get("/normal", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + + app.Get("/return-error", func(_ fiber.Ctx) error { + return errors.New("the request is error") + }) + + app.Get("/hang-up", func(c fiber.Ctx) error { + time.Sleep(time.Second) + return c.SendString(c.Hostname() + " hang up") + }) + + go func() { + require.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) + }() + + t.Run("add user request hooks", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + client.AddRequestHook(func(_ *Client, _ *Request) error { + require.Equal(t, "http://example.com", req.URL()) + return nil + }) + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed + }) + req.SetURL("http://example.com") + + resp, err := core.execute(context.Background(), client, req) + require.NoError(t, err) + require.Equal(t, "Cannot GET /", string(resp.RawResponse.Body())) + }) + + t.Run("add user response hooks", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + client.AddResponseHook(func(_ *Client, _ *Response, req *Request) error { + require.Equal(t, "http://example.com", req.URL()) + return nil + }) + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed + }) + req.SetURL("http://example.com") + + resp, err := core.execute(context.Background(), client, req) + require.NoError(t, err) + require.Equal(t, "Cannot GET /", string(resp.RawResponse.Body())) + }) + + t.Run("no timeout", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed + }) + req.SetURL("http://example.com/hang-up") + + resp, err := core.execute(context.Background(), client, req) + require.NoError(t, err) + require.Equal(t, "example.com hang up", string(resp.RawResponse.Body())) + }) + + t.Run("client timeout", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + client.SetTimeout(500 * time.Millisecond) + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed + }) + req.SetURL("http://example.com/hang-up") + + _, err := core.execute(context.Background(), client, req) + require.Equal(t, ErrTimeoutOrCancel, err) + }) + + t.Run("request timeout", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed + }) + req.SetURL("http://example.com/hang-up"). + SetTimeout(300 * time.Millisecond) + + _, err := core.execute(context.Background(), client, req) + require.Equal(t, ErrTimeoutOrCancel, err) + }) + + t.Run("request timeout has higher level", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + client.SetTimeout(30 * time.Millisecond) + + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed + }) + req.SetURL("http://example.com/hang-up"). + SetTimeout(3000 * time.Millisecond) + + resp, err := core.execute(context.Background(), client, req) + require.NoError(t, err) + require.Equal(t, "example.com hang up", string(resp.RawResponse.Body())) + }) +} diff --git a/client/helper_test.go b/client/helper_test.go new file mode 100644 index 0000000000..67380f3470 --- /dev/null +++ b/client/helper_test.go @@ -0,0 +1,157 @@ +package client + +import ( + "net" + "testing" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp/fasthttputil" +) + +type testServer struct { + app *fiber.App + ch chan struct{} + ln *fasthttputil.InmemoryListener + tb testing.TB +} + +func startTestServer(tb testing.TB, beforeStarting func(app *fiber.App)) *testServer { + tb.Helper() + + ln := fasthttputil.NewInmemoryListener() + app := fiber.New() + + if beforeStarting != nil { + beforeStarting(app) + } + + ch := make(chan struct{}) + go func() { + if err := app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true}); err != nil { + tb.Fatal(err) + } + + close(ch) + }() + + return &testServer{ + app: app, + ch: ch, + ln: ln, + tb: tb, + } +} + +func (ts *testServer) stop() { + ts.tb.Helper() + + if err := ts.app.Shutdown(); err != nil { + ts.tb.Fatal(err) + } + + select { + case <-ts.ch: + case <-time.After(time.Second): + ts.tb.Fatalf("timeout when waiting for server close") + } +} + +func (ts *testServer) dial() func(addr string) (net.Conn, error) { + ts.tb.Helper() + + return func(_ string) (net.Conn, error) { + return ts.ln.Dial() //nolint:wrapcheck // not needed + } +} + +func createHelperServer(tb testing.TB) (*fiber.App, func(addr string) (net.Conn, error), func()) { + tb.Helper() + + ln := fasthttputil.NewInmemoryListener() + + app := fiber.New() + + return app, func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed + }, func() { + require.NoError(tb, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) + } +} + +func testRequest(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted string, count ...int) { + t.Helper() + + app, ln, start := createHelperServer(t) + app.Get("/", handler) + go start() + + c := 1 + if len(count) > 0 { + c = count[0] + } + + client := NewClient().SetDial(ln) + + for i := 0; i < c; i++ { + req := AcquireRequest().SetClient(client) + wrapAgent(req) + + resp, err := req.Get("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, excepted, resp.String()) + resp.Close() + } +} + +func testRequestFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted error, count ...int) { + t.Helper() + + app, ln, start := createHelperServer(t) + app.Get("/", handler) + go start() + + c := 1 + if len(count) > 0 { + c = count[0] + } + + client := NewClient().SetDial(ln) + + for i := 0; i < c; i++ { + req := AcquireRequest().SetClient(client) + wrapAgent(req) + + _, err := req.Get("http://example.com") + + require.Equal(t, excepted.Error(), err.Error()) + } +} + +func testClient(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Client), excepted string, count ...int) { //nolint: unparam // maybe needed + t.Helper() + + app, ln, start := createHelperServer(t) + app.Get("/", handler) + go start() + + c := 1 + if len(count) > 0 { + c = count[0] + } + + for i := 0; i < c; i++ { + client := NewClient().SetDial(ln) + wrapAgent(client) + + resp, err := client.Get("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, excepted, resp.String()) + resp.Close() + } +} diff --git a/client/hooks.go b/client/hooks.go new file mode 100644 index 0000000000..0ecc970d53 --- /dev/null +++ b/client/hooks.go @@ -0,0 +1,328 @@ +package client + +import ( + "errors" + "fmt" + "io" + "math/rand" + "mime/multipart" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "github.com/gofiber/utils/v2" + "github.com/valyala/fasthttp" +) + +var ( + protocolCheck = regexp.MustCompile(`^https?://.*$`) + + headerAccept = "Accept" + + applicationJSON = "application/json" + applicationXML = "application/xml" + applicationForm = "application/x-www-form-urlencoded" + multipartFormData = "multipart/form-data" + + letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + letterIdxBits = 6 // 6 bits to represent a letter index + letterIdxMask = 1<= 0; { + if remain == 0 { + cache, remain = src.Int63(), letterIdxMax + } + + if idx := int(cache & int64(letterIdxMask)); idx < length { + b[i] = letterBytes[idx] + i-- + } + cache >>= int64(letterIdxBits) + remain-- + } + + return utils.UnsafeString(b) +} + +// parserRequestURL will set the options for the hostclient +// and normalize the url. +// The baseUrl will be merge with request uri. +// Query params and path params deal in this function. +func parserRequestURL(c *Client, req *Request) error { + splitURL := strings.Split(req.url, "?") + // I don't want to judge splitURL length. + splitURL = append(splitURL, "") + + // Determine whether to superimpose baseurl based on + // whether the URL starts with the protocol + uri := splitURL[0] + if !protocolCheck.MatchString(uri) { + uri = c.baseURL + uri + if !protocolCheck.MatchString(uri) { + return ErrURLFormat + } + } + + // set path params + req.path.VisitAll(func(key, val string) { + uri = strings.ReplaceAll(uri, ":"+key, val) + }) + c.path.VisitAll(func(key, val string) { + uri = strings.ReplaceAll(uri, ":"+key, val) + }) + + // set uri to request and other related setting + req.RawRequest.SetRequestURI(uri) + + // merge query params + hashSplit := strings.Split(splitURL[1], "#") + hashSplit = append(hashSplit, "") + args := fasthttp.AcquireArgs() + defer func() { + fasthttp.ReleaseArgs(args) + }() + + args.Parse(hashSplit[0]) + c.params.VisitAll(func(key, value []byte) { + args.AddBytesKV(key, value) + }) + req.params.VisitAll(func(key, value []byte) { + args.AddBytesKV(key, value) + }) + req.RawRequest.URI().SetQueryStringBytes(utils.CopyBytes(args.QueryString())) + req.RawRequest.URI().SetHash(hashSplit[1]) + + return nil +} + +// parserRequestHeader will make request header up. +// It will merge headers from client and request. +// Header should be set automatically based on data. +// User-Agent should be set. +func parserRequestHeader(c *Client, req *Request) error { + // set method + req.RawRequest.Header.SetMethod(req.Method()) + // merge header + c.header.VisitAll(func(key, value []byte) { + req.RawRequest.Header.AddBytesKV(key, value) + }) + + req.header.VisitAll(func(key, value []byte) { + req.RawRequest.Header.AddBytesKV(key, value) + }) + + // according to data set content-type + switch req.bodyType { + case jsonBody: + req.RawRequest.Header.SetContentType(applicationJSON) + req.RawRequest.Header.Set(headerAccept, applicationJSON) + case xmlBody: + req.RawRequest.Header.SetContentType(applicationXML) + case formBody: + req.RawRequest.Header.SetContentType(applicationForm) + case filesBody: + req.RawRequest.Header.SetContentType(multipartFormData) + // set boundary + if req.boundary == boundary { + req.boundary += randString(16) + } + req.RawRequest.Header.SetMultipartFormBoundary(req.boundary) + default: + } + + // set useragent + req.RawRequest.Header.SetUserAgent(defaultUserAgent) + if c.userAgent != "" { + req.RawRequest.Header.SetUserAgent(c.userAgent) + } + if req.userAgent != "" { + req.RawRequest.Header.SetUserAgent(req.userAgent) + } + + // set referer + req.RawRequest.Header.SetReferer(c.referer) + if req.referer != "" { + req.RawRequest.Header.SetReferer(req.referer) + } + + // set cookie + // add cookie form jar to req + if c.cookieJar != nil { + c.cookieJar.dumpCookiesToReq(req.RawRequest) + } + + c.cookies.VisitAll(func(key, val string) { + req.RawRequest.Header.SetCookie(key, val) + }) + + req.cookies.VisitAll(func(key, val string) { + req.RawRequest.Header.SetCookie(key, val) + }) + + return nil +} + +// parserRequestBody automatically serializes the data according to +// the data type and stores it in the body of the rawRequest +func parserRequestBody(c *Client, req *Request) error { + switch req.bodyType { + case jsonBody: + body, err := c.jsonMarshal(req.body) + if err != nil { + return err + } + req.RawRequest.SetBody(body) + case xmlBody: + body, err := c.xmlMarshal(req.body) + if err != nil { + return err + } + req.RawRequest.SetBody(body) + case formBody: + req.RawRequest.SetBody(req.formData.QueryString()) + case filesBody: + return parserRequestBodyFile(req) + case rawBody: + if body, ok := req.body.([]byte); ok { + req.RawRequest.SetBody(body) + } else { + return ErrBodyType + } + case noBody: + return nil + } + + return nil +} + +// parserRequestBodyFile parses request body if body type is file +// this is an addition of parserRequestBody. +func parserRequestBodyFile(req *Request) error { + mw := multipart.NewWriter(req.RawRequest.BodyWriter()) + err := mw.SetBoundary(req.boundary) + if err != nil { + return fmt.Errorf("set boundary error: %w", err) + } + defer func() { + err := mw.Close() + if err != nil { + return + } + }() + + // add formdata + req.formData.VisitAll(func(key, value []byte) { + if err != nil { + return + } + err = mw.WriteField(utils.UnsafeString(key), utils.UnsafeString(value)) + }) + if err != nil { + return fmt.Errorf("write formdata error: %w", err) + } + + // add file + b := make([]byte, 512) + for i, v := range req.files { + if v.name == "" && v.path == "" { + return ErrFileNoName + } + + // if name is not exist, set name + if v.name == "" && v.path != "" { + v.path = filepath.Clean(v.path) + v.name = filepath.Base(v.path) + } + + // if field name is not exist, set it + if v.fieldName == "" { + v.fieldName = "file" + strconv.Itoa(i+1) + } + + // check the reader + if v.reader == nil { + v.reader, err = os.Open(v.path) + if err != nil { + return fmt.Errorf("open file error: %w", err) + } + } + + // write file + w, err := mw.CreateFormFile(v.fieldName, v.name) + if err != nil { + return fmt.Errorf("create file error: %w", err) + } + + for { + n, err := v.reader.Read(b) + if err != nil && !errors.Is(err, io.EOF) { + return fmt.Errorf("read file error: %w", err) + } + + if errors.Is(err, io.EOF) { + break + } + + _, err = w.Write(b[:n]) + if err != nil { + return fmt.Errorf("write file error: %w", err) + } + } + + err = v.reader.Close() + if err != nil { + return fmt.Errorf("close file error: %w", err) + } + } + + return nil +} + +// parserResponseHeader will parse the response header and store it in the response +func parserResponseCookie(c *Client, resp *Response, req *Request) error { + var err error + resp.RawResponse.Header.VisitAllCookie(func(key, value []byte) { + cookie := fasthttp.AcquireCookie() + err = cookie.ParseBytes(value) + if err != nil { + return + } + cookie.SetKeyBytes(key) + + resp.cookie = append(resp.cookie, cookie) + }) + + if err != nil { + return err + } + + // store cookies to jar + if c.cookieJar != nil { + c.cookieJar.parseCookiesFromResp(req.RawRequest.URI().Host(), req.RawRequest.URI().Path(), resp.RawResponse) + } + + return nil +} + +// logger is a response hook that logs the request and response +func logger(c *Client, resp *Response, req *Request) error { + if !c.debug { + return nil + } + + c.logger.Debugf("%s\n", req.RawRequest.String()) + c.logger.Debugf("%s\n", resp.RawResponse.String()) + + return nil +} diff --git a/client/hooks_test.go b/client/hooks_test.go new file mode 100644 index 0000000000..a555bba833 --- /dev/null +++ b/client/hooks_test.go @@ -0,0 +1,652 @@ +package client + +import ( + "bytes" + "encoding/xml" + "fmt" + "io" + "net" + "net/url" + "strings" + "testing" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" +) + +func Test_Rand_String(t *testing.T) { + t.Parallel() + tests := []struct { + name string + args int + }{ + { + name: "test generate", + args: 16, + }, + { + name: "test generate smaller string", + args: 8, + }, + { + name: "test generate larger string", + args: 32, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := randString(tt.args) + require.Len(t, got, tt.args) + }) + } +} + +func Test_Parser_Request_URL(t *testing.T) { + t.Parallel() + + t.Run("client baseurl should be set", func(t *testing.T) { + t.Parallel() + client := NewClient().SetBaseURL("http://example.com/api") + req := AcquireRequest().SetURL("") + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, "http://example.com/api", req.RawRequest.URI().String()) + }) + + t.Run("request url should be set", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest().SetURL("http://example.com/api") + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, "http://example.com/api", req.RawRequest.URI().String()) + }) + + t.Run("the request url will override baseurl with protocol", func(t *testing.T) { + t.Parallel() + client := NewClient().SetBaseURL("http://example.com/api") + req := AcquireRequest().SetURL("http://example.com/api/v1") + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, "http://example.com/api/v1", req.RawRequest.URI().String()) + }) + + t.Run("the request url should be append after baseurl without protocol", func(t *testing.T) { + t.Parallel() + client := NewClient().SetBaseURL("http://example.com/api") + req := AcquireRequest().SetURL("/v1") + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, "http://example.com/api/v1", req.RawRequest.URI().String()) + }) + + t.Run("the url is error", func(t *testing.T) { + t.Parallel() + client := NewClient().SetBaseURL("example.com/api") + req := AcquireRequest().SetURL("/v1") + + err := parserRequestURL(client, req) + require.Equal(t, ErrURLFormat, err) + }) + + t.Run("the path param from client", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetBaseURL("http://example.com/api/:id"). + SetPathParam("id", "5") + req := AcquireRequest() + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, "http://example.com/api/5", req.RawRequest.URI().String()) + }) + + t.Run("the path param from request", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetBaseURL("http://example.com/api/:id/:name"). + SetPathParam("id", "5") + req := AcquireRequest(). + SetURL("/{key}"). + SetPathParams(map[string]string{ + "name": "fiber", + "key": "val", + }). + DelPathParams("key") + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, "http://example.com/api/5/fiber/%7Bkey%7D", req.RawRequest.URI().String()) + }) + + t.Run("the path param from request and client", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetBaseURL("http://example.com/api/:id/:name"). + SetPathParam("id", "5") + req := AcquireRequest(). + SetURL("/:key"). + SetPathParams(map[string]string{ + "name": "fiber", + "key": "val", + "id": "12", + }) + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, "http://example.com/api/12/fiber/val", req.RawRequest.URI().String()) + }) + + t.Run("query params from client should be set", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetParam("foo", "bar") + req := AcquireRequest().SetURL("http://example.com/api/v1") + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, []byte("foo=bar"), req.RawRequest.URI().QueryString()) + }) + + t.Run("query params from request should be set", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + SetURL("http://example.com/api/v1"). + SetParam("bar", "foo") + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, []byte("bar=foo"), req.RawRequest.URI().QueryString()) + }) + + t.Run("query params should be merged", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetParam("bar", "foo1") + req := AcquireRequest(). + SetURL("http://example.com/api/v1?bar=foo2"). + SetParam("bar", "foo") + + err := parserRequestURL(client, req) + require.NoError(t, err) + + values, err := url.ParseQuery(string(req.RawRequest.URI().QueryString())) + require.NoError(t, err) + + flag1, flag2, flag3 := false, false, false + for _, v := range values["bar"] { + if v == "foo1" { + flag1 = true + } else if v == "foo2" { + flag2 = true + } else if v == "foo" { + flag3 = true + } + } + require.True(t, flag1) + require.True(t, flag2) + require.True(t, flag3) + }) +} + +func Test_Parser_Request_Header(t *testing.T) { + t.Parallel() + + t.Run("client header should be set", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetHeaders(map[string]string{ + fiber.HeaderContentType: "application/json", + }) + + req := AcquireRequest() + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte("application/json"), req.RawRequest.Header.ContentType()) + }) + + t.Run("request header should be set", func(t *testing.T) { + t.Parallel() + client := NewClient() + + req := AcquireRequest(). + SetHeaders(map[string]string{ + fiber.HeaderContentType: "application/json, utf-8", + }) + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte("application/json, utf-8"), req.RawRequest.Header.ContentType()) + }) + + t.Run("request header should override client header", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetHeader(fiber.HeaderContentType, "application/xml") + + req := AcquireRequest(). + SetHeader(fiber.HeaderContentType, "application/json, utf-8") + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte("application/json, utf-8"), req.RawRequest.Header.ContentType()) + }) + + t.Run("auto set json header", func(t *testing.T) { + t.Parallel() + type jsonData struct { + Name string `json:"name"` + } + client := NewClient() + req := AcquireRequest(). + SetJSON(jsonData{ + Name: "foo", + }) + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte(applicationJSON), req.RawRequest.Header.ContentType()) + }) + + t.Run("auto set xml header", func(t *testing.T) { + t.Parallel() + type xmlData struct { + XMLName xml.Name `xml:"body"` + Name string `xml:"name"` + } + client := NewClient() + req := AcquireRequest(). + SetXML(xmlData{ + Name: "foo", + }) + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte(applicationXML), req.RawRequest.Header.ContentType()) + }) + + t.Run("auto set form data header", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + SetFormDatas(map[string]string{ + "foo": "bar", + "ball": "cricle and square", + }) + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, applicationForm, string(req.RawRequest.Header.ContentType())) + }) + + t.Run("auto set file header", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))). + SetFormData("foo", "bar") + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.True(t, strings.Contains(string(req.RawRequest.Header.MultipartFormBoundary()), "--FiberFormBoundary")) + require.True(t, strings.Contains(string(req.RawRequest.Header.ContentType()), multipartFormData)) + }) + + t.Run("ua should have default value", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest() + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte("fiber"), req.RawRequest.Header.UserAgent()) + }) + + t.Run("ua in client should be set", func(t *testing.T) { + t.Parallel() + client := NewClient().SetUserAgent("foo") + req := AcquireRequest() + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte("foo"), req.RawRequest.Header.UserAgent()) + }) + + t.Run("ua in request should have higher level", func(t *testing.T) { + t.Parallel() + client := NewClient().SetUserAgent("foo") + req := AcquireRequest().SetUserAgent("bar") + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte("bar"), req.RawRequest.Header.UserAgent()) + }) + + t.Run("referer in client should be set", func(t *testing.T) { + t.Parallel() + client := NewClient().SetReferer("https://example.com") + req := AcquireRequest() + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte("https://example.com"), req.RawRequest.Header.Referer()) + }) + + t.Run("referer in request should have higher level", func(t *testing.T) { + t.Parallel() + client := NewClient().SetReferer("http://example.com") + req := AcquireRequest().SetReferer("https://example.com") + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte("https://example.com"), req.RawRequest.Header.Referer()) + }) + + t.Run("client cookie should be set", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetCookie("foo", "bar"). + SetCookies(map[string]string{ + "bar": "foo", + "bar1": "foo1", + }). + DelCookies("bar1") + + req := AcquireRequest() + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo"))) + require.Equal(t, "foo", string(req.RawRequest.Header.Cookie("bar"))) + require.Equal(t, "", string(req.RawRequest.Header.Cookie("bar1"))) + }) + + t.Run("request cookie should be set", func(t *testing.T) { + t.Parallel() + type cookies struct { + Foo string `cookie:"foo"` + Bar int `cookie:"bar"` + } + + client := NewClient() + + req := AcquireRequest(). + SetCookiesWithStruct(&cookies{ + Foo: "bar", + Bar: 67, + }) + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo"))) + require.Equal(t, "67", string(req.RawRequest.Header.Cookie("bar"))) + require.Equal(t, "", string(req.RawRequest.Header.Cookie("bar1"))) + }) + + t.Run("request cookie will override client cookie", func(t *testing.T) { + t.Parallel() + type cookies struct { + Foo string `cookie:"foo"` + Bar int `cookie:"bar"` + } + + client := NewClient(). + SetCookie("foo", "bar"). + SetCookies(map[string]string{ + "bar": "foo", + "bar1": "foo1", + }) + + req := AcquireRequest(). + SetCookiesWithStruct(&cookies{ + Foo: "bar", + Bar: 67, + }) + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo"))) + require.Equal(t, "67", string(req.RawRequest.Header.Cookie("bar"))) + require.Equal(t, "foo1", string(req.RawRequest.Header.Cookie("bar1"))) + }) +} + +func Test_Parser_Request_Body(t *testing.T) { + t.Parallel() + + t.Run("json body", func(t *testing.T) { + t.Parallel() + type jsonData struct { + Name string `json:"name"` + } + client := NewClient() + req := AcquireRequest(). + SetJSON(jsonData{ + Name: "foo", + }) + + err := parserRequestBody(client, req) + require.NoError(t, err) + require.Equal(t, []byte("{\"name\":\"foo\"}"), req.RawRequest.Body()) + }) + + t.Run("xml body", func(t *testing.T) { + t.Parallel() + type xmlData struct { + XMLName xml.Name `xml:"body"` + Name string `xml:"name"` + } + client := NewClient() + req := AcquireRequest(). + SetXML(xmlData{ + Name: "foo", + }) + + err := parserRequestBody(client, req) + require.NoError(t, err) + require.Equal(t, []byte("foo"), req.RawRequest.Body()) + }) + + t.Run("form data body", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + SetFormDatas(map[string]string{ + "ball": "cricle and square", + }) + + err := parserRequestBody(client, req) + require.NoError(t, err) + require.Equal(t, "ball=cricle+and+square", string(req.RawRequest.Body())) + }) + + t.Run("form data body error", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + SetFormDatas(map[string]string{ + "": "", + }) + + err := parserRequestBody(client, req) + require.NoError(t, err) + }) + + t.Run("file body", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))) + + err := parserRequestBody(client, req) + require.NoError(t, err) + require.True(t, strings.Contains(string(req.RawRequest.Body()), "----FiberFormBoundary")) + require.True(t, strings.Contains(string(req.RawRequest.Body()), "world")) + }) + + t.Run("file and form data", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))). + SetFormData("foo", "bar") + + err := parserRequestBody(client, req) + require.NoError(t, err) + require.True(t, strings.Contains(string(req.RawRequest.Body()), "----FiberFormBoundary")) + require.True(t, strings.Contains(string(req.RawRequest.Body()), "world")) + require.True(t, strings.Contains(string(req.RawRequest.Body()), "bar")) + }) + + t.Run("raw body", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + SetRawBody([]byte("hello world")) + + err := parserRequestBody(client, req) + require.NoError(t, err) + require.Equal(t, []byte("hello world"), req.RawRequest.Body()) + }) + + t.Run("raw body error", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + SetRawBody([]byte("hello world")) + + req.body = nil + + err := parserRequestBody(client, req) + require.ErrorIs(t, err, ErrBodyType) + }) +} + +type dummyLogger struct { + buf *bytes.Buffer +} + +func (*dummyLogger) Trace(_ ...any) {} + +func (*dummyLogger) Debug(_ ...any) {} + +func (*dummyLogger) Info(_ ...any) {} + +func (*dummyLogger) Warn(_ ...any) {} + +func (*dummyLogger) Error(_ ...any) {} + +func (*dummyLogger) Fatal(_ ...any) {} + +func (*dummyLogger) Panic(_ ...any) {} + +func (*dummyLogger) Tracef(_ string, _ ...any) {} + +func (l *dummyLogger) Debugf(format string, v ...any) { + _, _ = l.buf.WriteString(fmt.Sprintf(format, v...)) //nolint:errcheck // not needed +} + +func (*dummyLogger) Infof(_ string, _ ...any) {} + +func (*dummyLogger) Warnf(_ string, _ ...any) {} + +func (*dummyLogger) Errorf(_ string, _ ...any) {} + +func (*dummyLogger) Fatalf(_ string, _ ...any) {} + +func (*dummyLogger) Panicf(_ string, _ ...any) {} + +func (*dummyLogger) Tracew(_ string, _ ...any) {} + +func (*dummyLogger) Debugw(_ string, _ ...any) {} + +func (*dummyLogger) Infow(_ string, _ ...any) {} + +func (*dummyLogger) Warnw(_ string, _ ...any) {} + +func (*dummyLogger) Errorw(_ string, _ ...any) {} + +func (*dummyLogger) Fatalw(_ string, _ ...any) {} + +func (*dummyLogger) Panicw(_ string, _ ...any) {} + +func Test_Client_Logger_Debug(t *testing.T) { + t.Parallel() + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("response") + }) + + addrChan := make(chan string) + go func() { + require.NoError(t, app.Listen(":0", fiber.ListenConfig{ + DisableStartupMessage: true, + ListenerAddrFunc: func(addr net.Addr) { + addrChan <- addr.String() + }, + })) + }() + + defer func(app *fiber.App) { + require.NoError(t, app.Shutdown()) + }(app) + + var buf bytes.Buffer + logger := &dummyLogger{buf: &buf} + + client := NewClient() + client.Debug().SetLogger(logger) + + addr := <-addrChan + resp, err := client.Get("http://" + addr) + require.NoError(t, err) + defer resp.Close() + + require.NoError(t, err) + require.Contains(t, buf.String(), "Host: "+addr) + require.Contains(t, buf.String(), "Content-Length: 8") +} + +func Test_Client_Logger_DisableDebug(t *testing.T) { + t.Parallel() + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("response") + }) + + addrChan := make(chan string) + go func() { + require.NoError(t, app.Listen(":0", fiber.ListenConfig{ + DisableStartupMessage: true, + ListenerAddrFunc: func(addr net.Addr) { + addrChan <- addr.String() + }, + })) + }() + + defer func(app *fiber.App) { + require.NoError(t, app.Shutdown()) + }(app) + + var buf bytes.Buffer + logger := &dummyLogger{buf: &buf} + + client := NewClient() + client.DisableDebug().SetLogger(logger) + + addr := <-addrChan + resp, err := client.Get("http://" + addr) + require.NoError(t, err) + defer resp.Close() + + require.NoError(t, err) + require.Empty(t, buf.String()) +} diff --git a/client/request.go b/client/request.go new file mode 100644 index 0000000000..0bf2fb321c --- /dev/null +++ b/client/request.go @@ -0,0 +1,985 @@ +package client + +import ( + "bytes" + "context" + "errors" + "io" + "path/filepath" + "reflect" + "strconv" + "sync" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/utils/v2" + "github.com/valyala/fasthttp" +) + +// WithStruct Implementing this interface allows data to +// be stored from a struct via reflect. +type WithStruct interface { + Add(name, obj string) + Del(name string) +} + +// Types of request bodies. +type bodyType int + +// Enumeration definition of the request body type. +const ( + noBody bodyType = iota + jsonBody + xmlBody + formBody + filesBody + rawBody +) + +var ErrClientNil = errors.New("client can not be nil") + +// Request is a struct which contains the request data. +type Request struct { + url string + method string + userAgent string + boundary string + referer string + ctx context.Context //nolint:containedctx // It's needed to be stored in the request. + header *Header + params *QueryParam + cookies *Cookie + path *PathParam + + timeout time.Duration + maxRedirects int + + client *Client + + body any + formData *FormData + files []*File + bodyType bodyType + + RawRequest *fasthttp.Request +} + +// Method returns http method in request. +func (r *Request) Method() string { + return r.method +} + +// SetMethod will set method for Request object, +// user should use request method to set method. +func (r *Request) SetMethod(method string) *Request { + r.method = method + return r +} + +// URL returns request url in Request instance. +func (r *Request) URL() string { + return r.url +} + +// SetURL will set url for Request object. +func (r *Request) SetURL(url string) *Request { + r.url = url + return r +} + +// Client get Client instance in Request. +func (r *Request) Client() *Client { + return r.client +} + +// SetClient method sets client in request instance. +func (r *Request) SetClient(c *Client) *Request { + if c == nil { + panic(ErrClientNil) + } + + r.client = c + return r +} + +// Context returns the Context if its already set in request +// otherwise it creates new one using `context.Background()`. +func (r *Request) Context() context.Context { + if r.ctx == nil { + return context.Background() + } + return r.ctx +} + +// SetContext sets the context.Context for current Request. It allows +// to interrupt the request execution if ctx.Done() channel is closed. +// See https://blog.golang.org/context article and the "context" package +// documentation. +func (r *Request) SetContext(ctx context.Context) *Request { + r.ctx = ctx + return r +} + +// Header method returns header value via key, +// this method will visit all field in the header, +// then sort them. +func (r *Request) Header(key string) []string { + return r.header.PeekMultiple(key) +} + +// AddHeader method adds a single header field and its value in the request instance. +// It will override header which set in client instance. +func (r *Request) AddHeader(key, val string) *Request { + r.header.Add(key, val) + return r +} + +// SetHeader method sets a single header field and its value in the request instance. +// It will override header which set in client instance. +func (r *Request) SetHeader(key, val string) *Request { + r.header.Del(key) + r.header.Set(key, val) + return r +} + +// AddHeaders method adds multiple header fields and its values at one go in the request instance. +// It will override header which set in client instance. +func (r *Request) AddHeaders(h map[string][]string) *Request { + r.header.AddHeaders(h) + return r +} + +// SetHeaders method sets multiple header fields and its values at one go in the request instance. +// It will override header which set in client instance. +func (r *Request) SetHeaders(h map[string]string) *Request { + r.header.SetHeaders(h) + return r +} + +// Param method returns params value via key, +// this method will visit all field in the query param. +func (r *Request) Param(key string) []string { + var res []string + tmp := r.params.PeekMulti(key) + for _, v := range tmp { + res = append(res, utils.UnsafeString(v)) + } + + return res +} + +// AddParam method adds a single param field and its value in the request instance. +// It will override param which set in client instance. +func (r *Request) AddParam(key, val string) *Request { + r.params.Add(key, val) + return r +} + +// SetParam method sets a single param field and its value in the request instance. +// It will override param which set in client instance. +func (r *Request) SetParam(key, val string) *Request { + r.params.Set(key, val) + return r +} + +// AddParams method adds multiple param fields and its values at one go in the request instance. +// It will override param which set in client instance. +func (r *Request) AddParams(m map[string][]string) *Request { + r.params.AddParams(m) + return r +} + +// SetParams method sets multiple param fields and its values at one go in the request instance. +// It will override param which set in client instance. +func (r *Request) SetParams(m map[string]string) *Request { + r.params.SetParams(m) + return r +} + +// SetParamsWithStruct method sets multiple param fields and its values at one go in the request instance. +// It will override param which set in client instance. +func (r *Request) SetParamsWithStruct(v any) *Request { + r.params.SetParamsWithStruct(v) + return r +} + +// DelParams method deletes single or multiple param fields ant its values. +func (r *Request) DelParams(key ...string) *Request { + for _, v := range key { + r.params.Del(v) + } + return r +} + +// UserAgent returns user agent in request instance. +func (r *Request) UserAgent() string { + return r.userAgent +} + +// SetUserAgent method sets user agent in request. +// It will override user agent which set in client instance. +func (r *Request) SetUserAgent(ua string) *Request { + r.userAgent = ua + return r +} + +// Boundary returns boundary in multipart boundary. +func (r *Request) Boundary() string { + return r.boundary +} + +// SetBoundary method sets multipart boundary. +func (r *Request) SetBoundary(b string) *Request { + r.boundary = b + + return r +} + +// Referer returns referer in request instance. +func (r *Request) Referer() string { + return r.referer +} + +// SetReferer method sets referer in request. +// It will override referer which set in client instance. +func (r *Request) SetReferer(referer string) *Request { + r.referer = referer + return r +} + +// Cookie returns the cookie be set in request instance. +// if cookie doesn't exist, return empty string. +func (r *Request) Cookie(key string) string { + if val, ok := (*r.cookies)[key]; ok { + return val + } + return "" +} + +// SetCookie method sets a single cookie field and its value in the request instance. +// It will override cookie which set in client instance. +func (r *Request) SetCookie(key, val string) *Request { + r.cookies.SetCookie(key, val) + return r +} + +// SetCookies method sets multiple cookie fields and its values at one go in the request instance. +// It will override cookie which set in client instance. +func (r *Request) SetCookies(m map[string]string) *Request { + r.cookies.SetCookies(m) + return r +} + +// SetCookiesWithStruct method sets multiple cookie fields and its values at one go in the request instance. +// It will override cookie which set in client instance. +func (r *Request) SetCookiesWithStruct(v any) *Request { + r.cookies.SetCookiesWithStruct(v) + return r +} + +// DelCookies method deletes single or multiple cookie fields ant its values. +func (r *Request) DelCookies(key ...string) *Request { + r.cookies.DelCookies(key...) + return r +} + +// PathParam returns the path param be set in request instance. +// if path param doesn't exist, return empty string. +func (r *Request) PathParam(key string) string { + if val, ok := (*r.path)[key]; ok { + return val + } + + return "" +} + +// SetPathParam method sets a single path param field and its value in the request instance. +// It will override path param which set in client instance. +func (r *Request) SetPathParam(key, val string) *Request { + r.path.SetParam(key, val) + return r +} + +// SetPathParams method sets multiple path param fields and its values at one go in the request instance. +// It will override path param which set in client instance. +func (r *Request) SetPathParams(m map[string]string) *Request { + r.path.SetParams(m) + return r +} + +// SetPathParamsWithStruct method sets multiple path param fields and its values at one go in the request instance. +// It will override path param which set in client instance. +func (r *Request) SetPathParamsWithStruct(v any) *Request { + r.path.SetParamsWithStruct(v) + return r +} + +// DelPathParams method deletes single or multiple path param fields ant its values. +func (r *Request) DelPathParams(key ...string) *Request { + r.path.DelParams(key...) + return r +} + +// ResetPathParams deletes all path params. +func (r *Request) ResetPathParams() *Request { + r.path.Reset() + return r +} + +// SetJSON method sets json body in request. +func (r *Request) SetJSON(v any) *Request { + r.body = v + r.bodyType = jsonBody + return r +} + +// SetXML method sets xml body in request. +func (r *Request) SetXML(v any) *Request { + r.body = v + r.bodyType = xmlBody + return r +} + +// SetRawBody method sets body with raw data in request. +func (r *Request) SetRawBody(v []byte) *Request { + r.body = v + r.bodyType = rawBody + return r +} + +// resetBody will clear body object and set bodyType +// if body type is formBody and filesBody, the new body type will be ignored. +func (r *Request) resetBody(t bodyType) { + r.body = nil + + // Set form data after set file ignore. + if r.bodyType == filesBody && t == formBody { + return + } + r.bodyType = t +} + +// FormData method returns form data value via key, +// this method will visit all field in the form data. +func (r *Request) FormData(key string) []string { + var res []string + tmp := r.formData.PeekMulti(key) + for _, v := range tmp { + res = append(res, utils.UnsafeString(v)) + } + + return res +} + +// AddFormData method adds a single form data field and its value in the request instance. +func (r *Request) AddFormData(key, val string) *Request { + r.formData.AddData(key, val) + r.resetBody(formBody) + return r +} + +// SetFormData method sets a single form data field and its value in the request instance. +func (r *Request) SetFormData(key, val string) *Request { + r.formData.SetData(key, val) + r.resetBody(formBody) + return r +} + +// AddFormDatas method adds multiple form data fields and its values in the request instance. +func (r *Request) AddFormDatas(m map[string][]string) *Request { + r.formData.AddDatas(m) + r.resetBody(formBody) + return r +} + +// SetFormDatas method sets multiple form data fields and its values in the request instance. +func (r *Request) SetFormDatas(m map[string]string) *Request { + r.formData.SetDatas(m) + r.resetBody(formBody) + return r +} + +// SetFormDatasWithStruct method sets multiple form data fields +// and its values in the request instance via struct. +func (r *Request) SetFormDatasWithStruct(v any) *Request { + r.formData.SetDatasWithStruct(v) + r.resetBody(formBody) + return r +} + +// DelFormDatas method deletes multiple form data fields and its value in the request instance. +func (r *Request) DelFormDatas(key ...string) *Request { + r.formData.DelDatas(key...) + r.resetBody(formBody) + return r +} + +// File returns file ptr store in request obj by name. +// If name field is empty, it will try to match path. +func (r *Request) File(name string) *File { + for _, v := range r.files { + if v.name == "" { + if filepath.Base(v.path) == name { + return v + } + } else if v.name == name { + return v + } + } + + return nil +} + +// FileByPath returns file ptr store in request obj by path. +func (r *Request) FileByPath(path string) *File { + for _, v := range r.files { + if v.path == path { + return v + } + } + + return nil +} + +// AddFile method adds single file field +// and its value in the request instance via file path. +func (r *Request) AddFile(path string) *Request { + r.files = append(r.files, AcquireFile(SetFilePath(path))) + r.resetBody(filesBody) + return r +} + +// AddFileWithReader method adds single field +// and its value in the request instance via reader. +func (r *Request) AddFileWithReader(name string, reader io.ReadCloser) *Request { + r.files = append(r.files, AcquireFile(SetFileName(name), SetFileReader(reader))) + r.resetBody(filesBody) + return r +} + +// AddFiles method adds multiple file fields +// and its value in the request instance via File instance. +func (r *Request) AddFiles(files ...*File) *Request { + r.files = append(r.files, files...) + r.resetBody(filesBody) + return r +} + +// Timeout returns the length of timeout in request. +func (r *Request) Timeout() time.Duration { + return r.timeout +} + +// SetTimeout method sets timeout field and its values at one go in the request instance. +// It will override timeout which set in client instance. +func (r *Request) SetTimeout(t time.Duration) *Request { + r.timeout = t + return r +} + +// MaxRedirects returns the max redirects count in request. +func (r *Request) MaxRedirects() int { + return r.maxRedirects +} + +// SetMaxRedirects method sets the maximum number of redirects at one go in the request instance. +// It will override max redirect which set in client instance. +func (r *Request) SetMaxRedirects(count int) *Request { + r.maxRedirects = count + return r +} + +// checkClient method checks whether the client has been set in request. +func (r *Request) checkClient() { + if r.client == nil { + r.SetClient(defaultClient) + } +} + +// Get Send get request. +func (r *Request) Get(url string) (*Response, error) { + return r.SetURL(url).SetMethod(fiber.MethodGet).Send() +} + +// Post Send post request. +func (r *Request) Post(url string) (*Response, error) { + return r.SetURL(url).SetMethod(fiber.MethodPost).Send() +} + +// Head Send head request. +func (r *Request) Head(url string) (*Response, error) { + return r.SetURL(url).SetMethod(fiber.MethodHead).Send() +} + +// Put Send put request. +func (r *Request) Put(url string) (*Response, error) { + return r.SetURL(url).SetMethod(fiber.MethodPut).Send() +} + +// Delete Send Delete request. +func (r *Request) Delete(url string) (*Response, error) { + return r.SetURL(url).SetMethod(fiber.MethodDelete).Send() +} + +// Options Send Options request. +func (r *Request) Options(url string) (*Response, error) { + return r.SetURL(url).SetMethod(fiber.MethodOptions).Send() +} + +// Patch Send patch request. +func (r *Request) Patch(url string) (*Response, error) { + return r.SetURL(url).SetMethod(fiber.MethodPatch).Send() +} + +// Custom Send custom request. +func (r *Request) Custom(url, method string) (*Response, error) { + return r.SetURL(url).SetMethod(method).Send() +} + +// Send a request. +func (r *Request) Send() (*Response, error) { + r.checkClient() + + return newCore().execute(r.Context(), r.Client(), r) +} + +// Reset clear Request object, used by ReleaseRequest method. +func (r *Request) Reset() { + r.url = "" + r.method = fiber.MethodGet + r.userAgent = "" + r.referer = "" + r.ctx = nil + r.body = nil + r.timeout = 0 + r.maxRedirects = 0 + r.bodyType = noBody + r.boundary = boundary + + for len(r.files) != 0 { + t := r.files[0] + r.files = r.files[1:] + ReleaseFile(t) + } + + r.formData.Reset() + r.path.Reset() + r.cookies.Reset() + r.header.Reset() + r.params.Reset() + r.RawRequest.Reset() +} + +// Header is a wrapper which wrap http.Header, +// the header in client and request will store in it. +type Header struct { + *fasthttp.RequestHeader +} + +// PeekMultiple methods returns multiple field in header with same key. +func (h *Header) PeekMultiple(key string) []string { + var res []string + byteKey := []byte(key) + h.RequestHeader.VisitAll(func(key, value []byte) { + if bytes.EqualFold(key, byteKey) { + res = append(res, utils.UnsafeString(value)) + } + }) + + return res +} + +// AddHeaders receive a map and add each value to header. +func (h *Header) AddHeaders(r map[string][]string) { + for k, v := range r { + for _, vv := range v { + h.Add(k, vv) + } + } +} + +// SetHeaders will override all headers. +func (h *Header) SetHeaders(r map[string]string) { + for k, v := range r { + h.Del(k) + h.Set(k, v) + } +} + +// QueryParam is a wrapper which wrap url.Values, +// the query string and formdata in client and request will store in it. +type QueryParam struct { + *fasthttp.Args +} + +// AddParams receive a map and add each value to param. +func (p *QueryParam) AddParams(r map[string][]string) { + for k, v := range r { + for _, vv := range v { + p.Add(k, vv) + } + } +} + +// SetParams will override all params. +func (p *QueryParam) SetParams(r map[string]string) { + for k, v := range r { + p.Set(k, v) + } +} + +// SetParamsWithStruct will override all params with struct or pointer of struct. +// Now nested structs are not currently supported. +func (p *QueryParam) SetParamsWithStruct(v any) { + SetValWithStruct(p, "param", v) +} + +// Cookie is a map which to store the cookies. +type Cookie map[string]string + +// Add method impl the method in WithStruct interface. +func (c Cookie) Add(key, val string) { + c[key] = val +} + +// Del method impl the method in WithStruct interface. +func (c Cookie) Del(key string) { + delete(c, key) +} + +// SetCookie method sets a single val in Cookie. +func (c Cookie) SetCookie(key, val string) { + c[key] = val +} + +// SetCookies method sets multiple val in Cookie. +func (c Cookie) SetCookies(m map[string]string) { + for k, v := range m { + c[k] = v + } +} + +// SetCookiesWithStruct method sets multiple val in Cookie via a struct. +func (c Cookie) SetCookiesWithStruct(v any) { + SetValWithStruct(c, "cookie", v) +} + +// DelCookies method deletes multiple val in Cookie. +func (c Cookie) DelCookies(key ...string) { + for _, v := range key { + c.Del(v) + } +} + +// VisitAll method receive a function which can travel the all val. +func (c Cookie) VisitAll(f func(key, val string)) { + for k, v := range c { + f(k, v) + } +} + +// Reset clear the Cookie object. +func (c Cookie) Reset() { + for k := range c { + delete(c, k) + } +} + +// PathParam is a map which to store the cookies. +type PathParam map[string]string + +// Add method impl the method in WithStruct interface. +func (p PathParam) Add(key, val string) { + p[key] = val +} + +// Del method impl the method in WithStruct interface. +func (p PathParam) Del(key string) { + delete(p, key) +} + +// SetParam method sets a single val in PathParam. +func (p PathParam) SetParam(key, val string) { + p[key] = val +} + +// SetParams method sets multiple val in PathParam. +func (p PathParam) SetParams(m map[string]string) { + for k, v := range m { + p[k] = v + } +} + +// SetParamsWithStruct method sets multiple val in PathParam via a struct. +func (p PathParam) SetParamsWithStruct(v any) { + SetValWithStruct(p, "path", v) +} + +// DelParams method deletes multiple val in PathParams. +func (p PathParam) DelParams(key ...string) { + for _, v := range key { + p.Del(v) + } +} + +// VisitAll method receive a function which can travel the all val. +func (p PathParam) VisitAll(f func(key, val string)) { + for k, v := range p { + f(k, v) + } +} + +// Reset clear the PathParams object. +func (p PathParam) Reset() { + for k := range p { + delete(p, k) + } +} + +// FormData is a wrapper of fasthttp.Args, +// and it be used for url encode body and file body. +type FormData struct { + *fasthttp.Args +} + +// AddData method is a wrapper of Args's Add method. +func (f *FormData) AddData(key, val string) { + f.Add(key, val) +} + +// SetData method is a wrapper of Args's Set method. +func (f *FormData) SetData(key, val string) { + f.Set(key, val) +} + +// AddDatas method supports add multiple fields. +func (f *FormData) AddDatas(m map[string][]string) { + for k, v := range m { + for _, vv := range v { + f.Add(k, vv) + } + } +} + +// SetDatas method supports set multiple fields. +func (f *FormData) SetDatas(m map[string]string) { + for k, v := range m { + f.Set(k, v) + } +} + +// SetDatasWithStruct method supports set multiple fields via a struct. +func (f *FormData) SetDatasWithStruct(v any) { + SetValWithStruct(f, "form", v) +} + +// DelDatas method deletes multiple fields. +func (f *FormData) DelDatas(key ...string) { + for _, v := range key { + f.Del(v) + } +} + +// Reset clear the FormData object. +func (f *FormData) Reset() { + f.Args.Reset() +} + +// File is a struct which support send files via request. +type File struct { + name string + fieldName string + path string + reader io.ReadCloser +} + +// SetName method sets file name. +func (f *File) SetName(n string) { + f.name = n +} + +// SetFieldName method sets key of file in the body. +func (f *File) SetFieldName(n string) { + f.fieldName = n +} + +// SetPath method set file path. +func (f *File) SetPath(p string) { + f.path = p +} + +// SetReader method can receive a io.ReadCloser +// which will be closed in parserBody hook. +func (f *File) SetReader(r io.ReadCloser) { + f.reader = r +} + +// Reset clear the File object. +func (f *File) Reset() { + f.name = "" + f.fieldName = "" + f.path = "" + f.reader = nil +} + +var requestPool = &sync.Pool{ + New: func() any { + return &Request{ + header: &Header{RequestHeader: &fasthttp.RequestHeader{}}, + params: &QueryParam{Args: fasthttp.AcquireArgs()}, + cookies: &Cookie{}, + path: &PathParam{}, + boundary: "--FiberFormBoundary", + formData: &FormData{Args: fasthttp.AcquireArgs()}, + files: make([]*File, 0), + RawRequest: fasthttp.AcquireRequest(), + } + }, +} + +// AcquireRequest returns an empty request object from the pool. +// +// The returned request may be returned to the pool with ReleaseRequest when no longer needed. +// This allows reducing GC load. +func AcquireRequest() *Request { + req, ok := requestPool.Get().(*Request) + if !ok { + panic(errors.New("failed to type-assert to *Request")) + } + + return req +} + +// ReleaseRequest returns the object acquired via AcquireRequest to the pool. +// +// Do not access the released Request object, otherwise data races may occur. +func ReleaseRequest(req *Request) { + req.Reset() + requestPool.Put(req) +} + +var filePool sync.Pool + +// SetFileFunc The methods as follows is used by AcquireFile method. +// You can set file field via these method. +type SetFileFunc func(f *File) + +// SetFileName method sets file name. +func SetFileName(n string) SetFileFunc { + return func(f *File) { + f.SetName(n) + } +} + +// SetFileFieldName method sets key of file in the body. +func SetFileFieldName(p string) SetFileFunc { + return func(f *File) { + f.SetFieldName(p) + } +} + +// SetFilePath method set file path. +func SetFilePath(p string) SetFileFunc { + return func(f *File) { + f.SetPath(p) + } +} + +// SetFileReader method can receive a io.ReadCloser +func SetFileReader(r io.ReadCloser) SetFileFunc { + return func(f *File) { + f.SetReader(r) + } +} + +// AcquireFile returns an File object from the pool. +// And you can set field in the File with SetFileFunc. +// +// The returned file may be returned to the pool with ReleaseFile when no longer needed. +// This allows reducing GC load. +func AcquireFile(setter ...SetFileFunc) *File { + fv := filePool.Get() + if fv != nil { + f, ok := fv.(*File) + if !ok { + panic(errors.New("failed to type-assert to *File")) + } + for _, v := range setter { + v(f) + } + return f + } + f := &File{} + for _, v := range setter { + v(f) + } + return f +} + +// ReleaseFile returns the object acquired via AcquireFile to the pool. +// +// Do not access the released File object, otherwise data races may occur. +func ReleaseFile(f *File) { + f.Reset() + filePool.Put(f) +} + +// SetValWithStruct Set some values using structs. +// `p` is a structure that implements the WithStruct interface, +// The field name can be specified by `tagName`. +// `v` is a struct include some data. +// Note: This method only supports simple types and nested structs are not currently supported. +func SetValWithStruct(p WithStruct, tagName string, v any) { + valueOfV := reflect.ValueOf(v) + typeOfV := reflect.TypeOf(v) + + // The v should be struct or point of struct + if typeOfV.Kind() == reflect.Pointer && typeOfV.Elem().Kind() == reflect.Struct { + valueOfV = valueOfV.Elem() + typeOfV = typeOfV.Elem() + } else if typeOfV.Kind() != reflect.Struct { + return + } + + // Boring type judge. + // TODO: cover more types and complex data structure. + var setVal func(name string, value reflect.Value) + setVal = func(name string, val reflect.Value) { + switch val.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + p.Add(name, strconv.Itoa(int(val.Int()))) + case reflect.Bool: + if val.Bool() { + p.Add(name, "true") + } + case reflect.String: + p.Add(name, val.String()) + case reflect.Float32, reflect.Float64: + p.Add(name, strconv.FormatFloat(val.Float(), 'f', -1, 64)) + case reflect.Slice, reflect.Array: + for i := 0; i < val.Len(); i++ { + setVal(name, val.Index(i)) + } + default: + } + } + + for i := 0; i < typeOfV.NumField(); i++ { + field := typeOfV.Field(i) + if !field.IsExported() { + continue + } + + name := field.Tag.Get(tagName) + if name == "" { + name = field.Name + } + val := valueOfV.Field(i) + if val.IsZero() { + continue + } + // To cover slice and array, we delete the val then add it. + p.Del(name) + setVal(name, val) + } +} diff --git a/client/request_test.go b/client/request_test.go new file mode 100644 index 0000000000..07e5254e15 --- /dev/null +++ b/client/request_test.go @@ -0,0 +1,1623 @@ +package client + +import ( + "bytes" + "context" + "errors" + "io" + "mime/multipart" + "net" + "os" + "path/filepath" + "regexp" + "strings" + "testing" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttputil" +) + +func Test_Request_Method(t *testing.T) { + t.Parallel() + + req := AcquireRequest() + req.SetMethod("GET") + require.Equal(t, "GET", req.Method()) + + req.SetMethod("POST") + require.Equal(t, "POST", req.Method()) + + req.SetMethod("PUT") + require.Equal(t, "PUT", req.Method()) + + req.SetMethod("DELETE") + require.Equal(t, "DELETE", req.Method()) + + req.SetMethod("PATCH") + require.Equal(t, "PATCH", req.Method()) + + req.SetMethod("OPTIONS") + require.Equal(t, "OPTIONS", req.Method()) + + req.SetMethod("HEAD") + require.Equal(t, "HEAD", req.Method()) + + req.SetMethod("TRACE") + require.Equal(t, "TRACE", req.Method()) + + req.SetMethod("CUSTOM") + require.Equal(t, "CUSTOM", req.Method()) +} + +func Test_Request_URL(t *testing.T) { + t.Parallel() + + req := AcquireRequest() + + req.SetURL("http://example.com/normal") + require.Equal(t, "http://example.com/normal", req.URL()) + + req.SetURL("https://example.com/normal") + require.Equal(t, "https://example.com/normal", req.URL()) +} + +func Test_Request_Client(t *testing.T) { + t.Parallel() + + client := NewClient() + req := AcquireRequest() + + req.SetClient(client) + require.Equal(t, client, req.Client()) +} + +func Test_Request_Context(t *testing.T) { + t.Parallel() + + req := AcquireRequest() + ctx := req.Context() + key := struct{}{} + + require.Nil(t, ctx.Value(key)) + + ctx = context.WithValue(ctx, key, "string") + req.SetContext(ctx) + ctx = req.Context() + + v, ok := ctx.Value(key).(string) + require.True(t, ok) + require.Equal(t, "string", v) +} + +func Test_Request_Header(t *testing.T) { + t.Parallel() + + t.Run("add header", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.AddHeader("foo", "bar").AddHeader("foo", "fiber") + + res := req.Header("foo") + require.Len(t, res, 2) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + }) + + t.Run("set header", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.AddHeader("foo", "bar").SetHeader("foo", "fiber") + + res := req.Header("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + }) + + t.Run("add headers", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.SetHeader("foo", "bar"). + AddHeaders(map[string][]string{ + "foo": {"fiber", "buaa"}, + "bar": {"foo"}, + }) + + res := req.Header("foo") + require.Len(t, res, 3) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + require.Equal(t, "buaa", res[2]) + + res = req.Header("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set headers", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.SetHeader("foo", "bar"). + SetHeaders(map[string]string{ + "foo": "fiber", + "bar": "foo", + }) + + res := req.Header("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + + res = req.Header("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) +} + +func Test_Request_QueryParam(t *testing.T) { + t.Parallel() + + t.Run("add param", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.AddParam("foo", "bar").AddParam("foo", "fiber") + + res := req.Param("foo") + require.Len(t, res, 2) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + }) + + t.Run("set param", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.AddParam("foo", "bar").SetParam("foo", "fiber") + + res := req.Param("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + }) + + t.Run("add params", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.SetParam("foo", "bar"). + AddParams(map[string][]string{ + "foo": {"fiber", "buaa"}, + "bar": {"foo"}, + }) + + res := req.Param("foo") + require.Len(t, res, 3) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + require.Equal(t, "buaa", res[2]) + + res = req.Param("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set headers", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.SetParam("foo", "bar"). + SetParams(map[string]string{ + "foo": "fiber", + "bar": "foo", + }) + + res := req.Param("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + + res = req.Param("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set params with struct", func(t *testing.T) { + t.Parallel() + + type args struct { + TInt int + TString string + TFloat float64 + TBool bool + TSlice []string + TIntSlice []int `param:"int_slice"` + } + + p := AcquireRequest() + p.SetParamsWithStruct(&args{ + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: true, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + + require.Empty(t, p.Param("unexport")) + + require.Len(t, p.Param("TInt"), 1) + require.Equal(t, "5", p.Param("TInt")[0]) + + require.Len(t, p.Param("TString"), 1) + require.Equal(t, "string", p.Param("TString")[0]) + + require.Len(t, p.Param("TFloat"), 1) + require.Equal(t, "3.1", p.Param("TFloat")[0]) + + require.Len(t, p.Param("TBool"), 1) + + tslice := p.Param("TSlice") + require.Len(t, tslice, 2) + require.Equal(t, "foo", tslice[0]) + require.Equal(t, "bar", tslice[1]) + + tint := p.Param("TSlice") + require.Len(t, tint, 2) + require.Equal(t, "foo", tint[0]) + require.Equal(t, "bar", tint[1]) + }) + + t.Run("del params", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.SetParam("foo", "bar"). + SetParams(map[string]string{ + "foo": "fiber", + "bar": "foo", + }).DelParams("foo", "bar") + + res := req.Param("foo") + require.Empty(t, res) + + res = req.Param("bar") + require.Empty(t, res) + }) +} + +func Test_Request_UA(t *testing.T) { + t.Parallel() + + req := AcquireRequest().SetUserAgent("fiber") + require.Equal(t, "fiber", req.UserAgent()) + + req.SetUserAgent("foo") + require.Equal(t, "foo", req.UserAgent()) +} + +func Test_Request_Referer(t *testing.T) { + t.Parallel() + + req := AcquireRequest().SetReferer("http://example.com") + require.Equal(t, "http://example.com", req.Referer()) + + req.SetReferer("https://example.com") + require.Equal(t, "https://example.com", req.Referer()) +} + +func Test_Request_Cookie(t *testing.T) { + t.Parallel() + + t.Run("set cookie", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + SetCookie("foo", "bar") + require.Equal(t, "bar", req.Cookie("foo")) + + req.SetCookie("foo", "bar1") + require.Equal(t, "bar1", req.Cookie("foo")) + }) + + t.Run("set cookies", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + SetCookies(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + + req.SetCookies(map[string]string{ + "foo": "bar1", + }) + require.Equal(t, "bar1", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + }) + + t.Run("set cookies with struct", func(t *testing.T) { + t.Parallel() + type args struct { + CookieInt int `cookie:"int"` + CookieString string `cookie:"string"` + } + + req := AcquireRequest().SetCookiesWithStruct(&args{ + CookieInt: 5, + CookieString: "foo", + }) + + require.Equal(t, "5", req.Cookie("int")) + require.Equal(t, "foo", req.Cookie("string")) + }) + + t.Run("del cookies", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + SetCookies(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + + req.DelCookies("foo") + require.Equal(t, "", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + }) +} + +func Test_Request_PathParam(t *testing.T) { + t.Parallel() + + t.Run("set path param", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + SetPathParam("foo", "bar") + require.Equal(t, "bar", req.PathParam("foo")) + + req.SetPathParam("foo", "bar1") + require.Equal(t, "bar1", req.PathParam("foo")) + }) + + t.Run("set path params", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + + req.SetPathParams(map[string]string{ + "foo": "bar1", + }) + require.Equal(t, "bar1", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + }) + + t.Run("set path params with struct", func(t *testing.T) { + t.Parallel() + type args struct { + CookieInt int `path:"int"` + CookieString string `path:"string"` + } + + req := AcquireRequest().SetPathParamsWithStruct(&args{ + CookieInt: 5, + CookieString: "foo", + }) + + require.Equal(t, "5", req.PathParam("int")) + require.Equal(t, "foo", req.PathParam("string")) + }) + + t.Run("del path params", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + + req.DelPathParams("foo") + require.Equal(t, "", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + }) + + t.Run("clear path params", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + + req.ResetPathParams() + require.Equal(t, "", req.PathParam("foo")) + require.Equal(t, "", req.PathParam("bar")) + }) +} + +func Test_Request_FormData(t *testing.T) { + t.Parallel() + + t.Run("add form data", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + defer ReleaseRequest(req) + req.AddFormData("foo", "bar").AddFormData("foo", "fiber") + + res := req.FormData("foo") + require.Len(t, res, 2) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + }) + + t.Run("set param", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + defer ReleaseRequest(req) + req.AddFormData("foo", "bar").SetFormData("foo", "fiber") + + res := req.FormData("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + }) + + t.Run("add params", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + defer ReleaseRequest(req) + req.SetFormData("foo", "bar"). + AddFormDatas(map[string][]string{ + "foo": {"fiber", "buaa"}, + "bar": {"foo"}, + }) + + res := req.FormData("foo") + require.Len(t, res, 3) + require.Contains(t, res, "bar") + require.Contains(t, res, "buaa") + require.Contains(t, res, "fiber") + + res = req.FormData("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set headers", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + defer ReleaseRequest(req) + req.SetFormData("foo", "bar"). + SetFormDatas(map[string]string{ + "foo": "fiber", + "bar": "foo", + }) + + res := req.FormData("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + + res = req.FormData("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set params with struct", func(t *testing.T) { + t.Parallel() + + type args struct { + TInt int + TString string + TFloat float64 + TBool bool + TSlice []string + TIntSlice []int `form:"int_slice"` + } + + p := AcquireRequest() + defer ReleaseRequest(p) + p.SetFormDatasWithStruct(&args{ + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: true, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + + require.Empty(t, p.FormData("unexport")) + + require.Len(t, p.FormData("TInt"), 1) + require.Equal(t, "5", p.FormData("TInt")[0]) + + require.Len(t, p.FormData("TString"), 1) + require.Equal(t, "string", p.FormData("TString")[0]) + + require.Len(t, p.FormData("TFloat"), 1) + require.Equal(t, "3.1", p.FormData("TFloat")[0]) + + require.Len(t, p.FormData("TBool"), 1) + + tslice := p.FormData("TSlice") + require.Len(t, tslice, 2) + require.Contains(t, tslice, "bar") + require.Contains(t, tslice, "foo") + + tint := p.FormData("TSlice") + require.Len(t, tint, 2) + require.Contains(t, tint, "bar") + require.Contains(t, tint, "foo") + }) + + t.Run("del params", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + defer ReleaseRequest(req) + req.SetFormData("foo", "bar"). + SetFormDatas(map[string]string{ + "foo": "fiber", + "bar": "foo", + }).DelFormDatas("foo", "bar") + + res := req.FormData("foo") + require.Empty(t, res) + + res = req.FormData("bar") + require.Empty(t, res) + }) +} + +func Test_Request_File(t *testing.T) { + t.Parallel() + + t.Run("add file", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + AddFile("../.github/index.html"). + AddFiles(AcquireFile(SetFileName("tmp.txt"))) + + require.Equal(t, "../.github/index.html", req.File("index.html").path) + require.Equal(t, "../.github/index.html", req.FileByPath("../.github/index.html").path) + require.Equal(t, "tmp.txt", req.File("tmp.txt").name) + require.Nil(t, req.File("tmp2.txt")) + require.Nil(t, req.FileByPath("tmp2.txt")) + }) + + t.Run("add file by reader", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + AddFileWithReader("tmp.txt", io.NopCloser(strings.NewReader("world"))) + + require.Equal(t, "tmp.txt", req.File("tmp.txt").name) + + content, err := io.ReadAll(req.File("tmp.txt").reader) + require.NoError(t, err) + require.Equal(t, "world", string(content)) + }) + + t.Run("add files", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + AddFiles(AcquireFile(SetFileName("tmp.txt")), AcquireFile(SetFileName("foo.txt"))) + + require.Equal(t, "tmp.txt", req.File("tmp.txt").name) + require.Equal(t, "foo.txt", req.File("foo.txt").name) + }) +} + +func Test_Request_Timeout(t *testing.T) { + t.Parallel() + + req := AcquireRequest().SetTimeout(5 * time.Second) + + require.Equal(t, 5*time.Second, req.Timeout()) +} + +func Test_Request_Invalid_URL(t *testing.T) { + t.Parallel() + + resp, err := AcquireRequest(). + Get("http://example.com\r\n\r\nGET /\r\n\r\n") + + require.Equal(t, ErrURLFormat, err) + require.Equal(t, (*Response)(nil), resp) +} + +func Test_Request_Unsupport_Protocol(t *testing.T) { + t.Parallel() + + resp, err := AcquireRequest(). + Get("ftp://example.com") + require.Equal(t, ErrURLFormat, err) + require.Equal(t, (*Response)(nil), resp) +} + +func Test_Request_Get(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + + go start() + time.Sleep(100 * time.Millisecond) + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + req := AcquireRequest().SetClient(client) + + resp, err := req.Get("http://example.com") + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "example.com", resp.String()) + resp.Close() + } +} + +func Test_Request_Post(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Post("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusCreated). + SendString(c.FormValue("foo")) + }) + + go start() + time.Sleep(100 * time.Millisecond) + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + SetFormData("foo", "bar"). + Post("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusCreated, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + resp.Close() + } +} + +func Test_Request_Head(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Head("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + + go start() + time.Sleep(100 * time.Millisecond) + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + Head("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "", resp.String()) + resp.Close() + } +} + +func Test_Request_Put(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Put("/", func(c fiber.Ctx) error { + return c.SendString(c.FormValue("foo")) + }) + + go start() + time.Sleep(100 * time.Millisecond) + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + SetFormData("foo", "bar"). + Put("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + + resp.Close() + } +} + +func Test_Request_Delete(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + + app.Delete("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusNoContent). + SendString("deleted") + }) + + go start() + time.Sleep(100 * time.Millisecond) + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + Delete("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) + require.Equal(t, "", resp.String()) + + resp.Close() + } +} + +func Test_Request_Options(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + + app.Options("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusOK). + SendString("options") + }) + + go start() + time.Sleep(100 * time.Millisecond) + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + Options("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "options", resp.String()) + + resp.Close() + } +} + +func Test_Request_Send(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + + app.Post("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusOK). + SendString("post") + }) + + go start() + time.Sleep(100 * time.Millisecond) + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + SetURL("http://example.com"). + SetMethod(fiber.MethodPost). + Send() + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "post", resp.String()) + + resp.Close() + } +} + +func Test_Request_Patch(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + + app.Patch("/", func(c fiber.Ctx) error { + return c.SendString(c.FormValue("foo")) + }) + + go start() + time.Sleep(100 * time.Millisecond) + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + SetFormData("foo", "bar"). + Patch("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + + resp.Close() + } +} + +func Test_Request_Header_With_Server(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { + c.Request().Header.VisitAll(func(key, value []byte) { + if k := string(key); k == "K1" || k == "K2" { + _, err := c.Write(key) + require.NoError(t, err) + _, err = c.Write(value) + require.NoError(t, err) + } + }) + return nil + } + + wrapAgent := func(r *Request) { + r.SetHeader("k1", "v1"). + AddHeader("k1", "v11"). + AddHeaders(map[string][]string{ + "k1": {"v22", "v33"}, + }). + SetHeaders(map[string]string{ + "k2": "v2", + }). + AddHeader("k2", "v22") + } + + testRequest(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") +} + +func Test_Request_UserAgent_With_Server(t *testing.T) { + t.Parallel() + + handler := func(c fiber.Ctx) error { + return c.Send(c.Request().Header.UserAgent()) + } + + t.Run("default", func(t *testing.T) { + t.Parallel() + testRequest(t, handler, func(_ *Request) {}, defaultUserAgent, 5) + }) + + t.Run("custom", func(t *testing.T) { + t.Parallel() + testRequest(t, handler, func(agent *Request) { + agent.SetUserAgent("ua") + }, "ua", 5) + }) +} + +func Test_Request_Cookie_With_Server(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { + return c.SendString( + c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) + } + + wrapAgent := func(req *Request) { + req.SetCookie("k1", "v1"). + SetCookies(map[string]string{ + "k2": "v2", + "k3": "v3", + "k4": "v4", + }).DelCookies("k4") + } + + testRequest(t, handler, wrapAgent, "v1v2v3") +} + +func Test_Request_Referer_With_Server(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { + return c.Send(c.Request().Header.Referer()) + } + + wrapAgent := func(req *Request) { + req.SetReferer("http://referer.com") + } + + testRequest(t, handler, wrapAgent, "http://referer.com") +} + +func Test_Request_QueryString_With_Server(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { + return c.Send(c.Request().URI().QueryString()) + } + + wrapAgent := func(req *Request) { + req.SetParam("foo", "bar"). + SetParams(map[string]string{ + "bar": "baz", + }) + } + + testRequest(t, handler, wrapAgent, "foo=bar&bar=baz") +} + +func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) { + t.Helper() + + basename := filepath.Base(filename) + require.Equal(t, fh.Filename, basename) + + b1, err := os.ReadFile(filepath.Clean(filename)) + require.NoError(t, err) + + b2 := make([]byte, fh.Size) + f, err := fh.Open() + require.NoError(t, err) + defer func() { require.NoError(t, f.Close()) }() + _, err = f.Read(b2) + require.NoError(t, err) + require.Equal(t, b1, b2) +} + +func Test_Request_Body_With_Server(t *testing.T) { + t.Parallel() + + t.Run("json body", func(t *testing.T) { + t.Parallel() + testRequest(t, + func(c fiber.Ctx) error { + require.Equal(t, "application/json", string(c.Request().Header.ContentType())) + return c.SendString(string(c.Request().Body())) + }, + func(agent *Request) { + agent.SetJSON(map[string]string{ + "success": "hello", + }) + }, + "{\"success\":\"hello\"}", + ) + }) + + t.Run("xml body", func(t *testing.T) { + t.Parallel() + testRequest(t, + func(c fiber.Ctx) error { + require.Equal(t, "application/xml", string(c.Request().Header.ContentType())) + return c.SendString(string(c.Request().Body())) + }, + func(agent *Request) { + type args struct { + Content string `xml:"content"` + } + agent.SetXML(args{ + Content: "hello", + }) + }, + "hello", + ) + }) + + t.Run("formdata", func(t *testing.T) { + t.Parallel() + testRequest(t, + func(c fiber.Ctx) error { + require.Equal(t, fiber.MIMEApplicationForm, string(c.Request().Header.ContentType())) + return c.Send([]byte("foo=" + c.FormValue("foo") + "&bar=" + c.FormValue("bar") + "&fiber=" + c.FormValue("fiber"))) + }, + func(agent *Request) { + agent.SetFormData("foo", "bar"). + SetFormDatas(map[string]string{ + "bar": "baz", + "fiber": "fast", + }) + }, + "foo=bar&bar=baz&fiber=fast") + }) + + t.Run("multipart form", func(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Post("/", func(c fiber.Ctx) error { + require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) + + mf, err := c.MultipartForm() + require.NoError(t, err) + require.Equal(t, "bar", mf.Value["foo"][0]) + + return c.Send(c.Request().Body()) + }) + + go start() + + client := NewClient().SetDial(ln) + + req := AcquireRequest(). + SetClient(client). + SetBoundary("myBoundary"). + SetFormData("foo", "bar"). + AddFiles(AcquireFile( + SetFileName("hello.txt"), + SetFileFieldName("foo"), + SetFileReader(io.NopCloser(strings.NewReader("world"))), + )) + + require.Equal(t, "myBoundary", req.Boundary()) + + resp, err := req.Post("http://exmaple.com") + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + + form, err := multipart.NewReader(bytes.NewReader(resp.Body()), "myBoundary").ReadForm(1024 * 1024) + require.NoError(t, err) + require.Equal(t, "bar", form.Value["foo"][0]) + resp.Close() + }) + + t.Run("multipart form send file", func(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Post("/", func(c fiber.Ctx) error { + require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) + + fh1, err := c.FormFile("field1") + require.NoError(t, err) + require.Equal(t, "name", fh1.Filename) + buf := make([]byte, fh1.Size) + f, err := fh1.Open() + require.NoError(t, err) + defer func() { require.NoError(t, f.Close()) }() + _, err = f.Read(buf) + require.NoError(t, err) + require.Equal(t, "form file", string(buf)) + + fh2, err := c.FormFile("file2") + require.NoError(t, err) + checkFormFile(t, fh2, "../.github/testdata/index.html") + + fh3, err := c.FormFile("file3") + require.NoError(t, err) + checkFormFile(t, fh3, "../.github/testdata/index.tmpl") + + return c.SendString("multipart form files") + }) + + go start() + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + req := AcquireRequest(). + SetClient(client). + AddFiles( + AcquireFile( + SetFileFieldName("field1"), + SetFileName("name"), + SetFileReader(io.NopCloser(bytes.NewReader([]byte("form file")))), + ), + ). + AddFile("../.github/testdata/index.html"). + AddFile("../.github/testdata/index.tmpl"). + SetBoundary("myBoundary") + + resp, err := req.Post("http://example.com") + require.NoError(t, err) + require.Equal(t, "multipart form files", resp.String()) + + resp.Close() + } + }) + + t.Run("multipart random boundary", func(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Post("/", func(c fiber.Ctx) error { + reg := regexp.MustCompile(`multipart/form-data; boundary=[\-\w]{35}`) + require.True(t, reg.MatchString(c.Get(fiber.HeaderContentType))) + + return c.Send(c.Request().Body()) + }) + + go start() + + client := NewClient().SetDial(ln) + + req := AcquireRequest(). + SetClient(client). + SetFormData("foo", "bar"). + AddFiles(AcquireFile( + SetFileName("hello.txt"), + SetFileFieldName("foo"), + SetFileReader(io.NopCloser(strings.NewReader("world"))), + )) + + resp, err := req.Post("http://exmaple.com") + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + }) + + t.Run("raw body", func(t *testing.T) { + t.Parallel() + testRequest(t, + func(c fiber.Ctx) error { + return c.SendString(string(c.Request().Body())) + }, + func(agent *Request) { + agent.SetRawBody([]byte("hello")) + }, + "hello", + ) + }) +} + +func Test_Request_Error_Body_With_Server(t *testing.T) { + t.Parallel() + t.Run("json error", func(t *testing.T) { + t.Parallel() + testRequestFail(t, + func(c fiber.Ctx) error { + return c.SendString("") + }, + func(agent *Request) { + agent.SetJSON(complex(1, 1)) + }, + errors.New("json: unsupported type: complex128"), + ) + }) + + t.Run("xml error", func(t *testing.T) { + t.Parallel() + testRequestFail(t, + func(c fiber.Ctx) error { + return c.SendString("") + }, + func(agent *Request) { + agent.SetXML(complex(1, 1)) + }, + errors.New("xml: unsupported type: complex128"), + ) + }) + + t.Run("form body with invalid boundary", func(t *testing.T) { + t.Parallel() + + _, err := AcquireRequest(). + SetBoundary("*"). + AddFileWithReader("t.txt", io.NopCloser(strings.NewReader("world"))). + Get("http://example.com") + require.Equal(t, "set boundary error: mime: invalid boundary character", err.Error()) + }) + + t.Run("open non exist file", func(t *testing.T) { + t.Parallel() + + _, err := AcquireRequest(). + AddFile("non-exist-file!"). + Get("http://example.com") + require.Contains(t, err.Error(), "open non-exist-file!") + }) +} + +func Test_Request_Timeout_With_Server(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + time.Sleep(time.Millisecond * 200) + return c.SendString("timeout") + }) + go start() + + client := NewClient().SetDial(ln) + + _, err := AcquireRequest(). + SetClient(client). + SetTimeout(50 * time.Millisecond). + Get("http://example.com") + + require.Equal(t, ErrTimeoutOrCancel, err) +} + +func Test_Request_MaxRedirects(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := fiber.New() + + app.Get("/", func(c fiber.Ctx) error { + if c.Request().URI().QueryArgs().Has("foo") { + return c.Redirect().To("/foo") + } + return c.Redirect().To("/") + }) + app.Get("/foo", func(c fiber.Ctx) error { + return c.SendString("redirect") + }) + + go func() { require.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) }() + + t.Run("success", func(t *testing.T) { + t.Parallel() + + client := NewClient().SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed + + resp, err := AcquireRequest(). + SetClient(client). + SetMaxRedirects(1). + Get("http://example.com?foo") + body := resp.String() + code := resp.StatusCode() + + require.Equal(t, 200, code) + require.Equal(t, "redirect", body) + require.NoError(t, err) + + resp.Close() + }) + + t.Run("error", func(t *testing.T) { + t.Parallel() + + client := NewClient().SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed + + resp, err := AcquireRequest(). + SetClient(client). + SetMaxRedirects(1). + Get("http://example.com") + + require.Nil(t, resp) + require.Equal(t, "too many redirects detected when doing the request", err.Error()) + }) + + t.Run("MaxRedirects", func(t *testing.T) { + t.Parallel() + + client := NewClient().SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed + + req := AcquireRequest(). + SetClient(client). + SetMaxRedirects(3) + + require.Equal(t, 3, req.MaxRedirects()) + }) +} + +func Test_SetValWithStruct(t *testing.T) { + t.Parallel() + + // test SetValWithStruct vai QueryParam struct. + type args struct { + unexport int + TInt int + TString string + TFloat float64 + TBool bool + TSlice []string + TIntSlice []int `param:"int_slice"` + } + + t.Run("the struct should be applied", func(t *testing.T) { + t.Parallel() + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + + SetValWithStruct(p, "param", args{ + unexport: 5, + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: false, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + + require.Equal(t, "", string(p.Peek("unexport"))) + require.Equal(t, []byte("5"), p.Peek("TInt")) + require.Equal(t, []byte("string"), p.Peek("TString")) + require.Equal(t, []byte("3.1"), p.Peek("TFloat")) + require.Equal(t, "", string(p.Peek("TBool"))) + require.True(t, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "foo" { + return true + } + } + return false + }()) + + require.True(t, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "bar" { + return true + } + } + return false + }()) + + require.True(t, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "1" { + return true + } + } + return false + }()) + + require.True(t, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "2" { + return true + } + } + return false + }()) + }) + + t.Run("the pointer of a struct should be applied", func(t *testing.T) { + t.Parallel() + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + + SetValWithStruct(p, "param", &args{ + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: true, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + + require.Equal(t, []byte("5"), p.Peek("TInt")) + require.Equal(t, []byte("string"), p.Peek("TString")) + require.Equal(t, []byte("3.1"), p.Peek("TFloat")) + require.Equal(t, "true", string(p.Peek("TBool"))) + require.True(t, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "foo" { + return true + } + } + return false + }()) + + require.True(t, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "bar" { + return true + } + } + return false + }()) + + require.True(t, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "1" { + return true + } + } + return false + }()) + + require.True(t, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "2" { + return true + } + } + return false + }()) + }) + + t.Run("the zero val should be ignore", func(t *testing.T) { + t.Parallel() + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + SetValWithStruct(p, "param", &args{ + TInt: 0, + TString: "", + TFloat: 0.0, + }) + + require.Equal(t, "", string(p.Peek("TInt"))) + require.Equal(t, "", string(p.Peek("TString"))) + require.Equal(t, "", string(p.Peek("TFloat"))) + require.Empty(t, p.PeekMulti("TSlice")) + require.Empty(t, p.PeekMulti("int_slice")) + }) + + t.Run("error type should ignore", func(t *testing.T) { + t.Parallel() + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + SetValWithStruct(p, "param", 5) + require.Equal(t, 0, p.Len()) + }) +} + +func Benchmark_SetValWithStruct(b *testing.B) { + // test SetValWithStruct vai QueryParam struct. + type args struct { + unexport int + TInt int + TString string + TFloat float64 + TBool bool + TSlice []string + TIntSlice []int `param:"int_slice"` + } + + b.Run("the struct should be applied", func(b *testing.B) { + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + + b.ReportAllocs() + b.StartTimer() + + for i := 0; i < b.N; i++ { + SetValWithStruct(p, "param", args{ + unexport: 5, + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: false, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + } + + require.Equal(b, "", string(p.Peek("unexport"))) + require.Equal(b, []byte("5"), p.Peek("TInt")) + require.Equal(b, []byte("string"), p.Peek("TString")) + require.Equal(b, []byte("3.1"), p.Peek("TFloat")) + require.Equal(b, "", string(p.Peek("TBool"))) + require.True(b, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "foo" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "bar" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "1" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "2" { + return true + } + } + return false + }()) + }) + + b.Run("the pointer of a struct should be applied", func(b *testing.B) { + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + + b.ReportAllocs() + b.StartTimer() + + for i := 0; i < b.N; i++ { + SetValWithStruct(p, "param", &args{ + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: true, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + } + + require.Equal(b, []byte("5"), p.Peek("TInt")) + require.Equal(b, []byte("string"), p.Peek("TString")) + require.Equal(b, []byte("3.1"), p.Peek("TFloat")) + require.Equal(b, "true", string(p.Peek("TBool"))) + require.True(b, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "foo" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "bar" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "1" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "2" { + return true + } + } + return false + }()) + }) + + b.Run("the zero val should be ignore", func(b *testing.B) { + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + + b.ReportAllocs() + b.StartTimer() + + for i := 0; i < b.N; i++ { + SetValWithStruct(p, "param", &args{ + TInt: 0, + TString: "", + TFloat: 0.0, + }) + } + + require.Empty(b, string(p.Peek("TInt"))) + require.Empty(b, string(p.Peek("TString"))) + require.Empty(b, string(p.Peek("TFloat"))) + require.Empty(b, len(p.PeekMulti("TSlice"))) + require.Empty(b, len(p.PeekMulti("int_slice"))) + }) + + b.Run("error type should ignore", func(b *testing.B) { + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + + b.ReportAllocs() + b.StartTimer() + + for i := 0; i < b.N; i++ { + SetValWithStruct(p, "param", 5) + } + + require.Equal(b, 0, p.Len()) + }) +} diff --git a/client/response.go b/client/response.go new file mode 100644 index 0000000000..f6ecd6fcd8 --- /dev/null +++ b/client/response.go @@ -0,0 +1,184 @@ +package client + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/gofiber/utils/v2" + "github.com/valyala/fasthttp" +) + +// Response is the result of a request. This object is used to access the response data. +type Response struct { + client *Client + request *Request + cookie []*fasthttp.Cookie + + RawResponse *fasthttp.Response +} + +// setClient method sets client object in response instance. +// Use core object in the client. +func (r *Response) setClient(c *Client) { + r.client = c +} + +// setRequest method sets Request object in response instance. +// The request will be released when the Response.Close is called. +func (r *Response) setRequest(req *Request) { + r.request = req +} + +// Status method returns the HTTP status string for the executed request. +func (r *Response) Status() string { + return string(r.RawResponse.Header.StatusMessage()) +} + +// StatusCode method returns the HTTP status code for the executed request. +func (r *Response) StatusCode() int { + return r.RawResponse.StatusCode() +} + +// Protocol method returns the HTTP response protocol used for the request. +func (r *Response) Protocol() string { + return string(r.RawResponse.Header.Protocol()) +} + +// Header method returns the response headers. +func (r *Response) Header(key string) string { + return utils.UnsafeString(r.RawResponse.Header.Peek(key)) +} + +// Cookies method to access all the response cookies. +func (r *Response) Cookies() []*fasthttp.Cookie { + return r.cookie +} + +// Body method returns HTTP response as []byte array for the executed request. +func (r *Response) Body() []byte { + return r.RawResponse.Body() +} + +// String method returns the body of the server response as String. +func (r *Response) String() string { + return strings.TrimSpace(string(r.Body())) +} + +// JSON method will unmarshal body to json. +func (r *Response) JSON(v any) error { + return r.client.jsonUnmarshal(r.Body(), v) +} + +// XML method will unmarshal body to xml. +func (r *Response) XML(v any) error { + return r.client.xmlUnmarshal(r.Body(), v) +} + +// Save method will save the body to a file or io.Writer. +func (r *Response) Save(v any) error { + switch p := v.(type) { + case string: + file := filepath.Clean(p) + dir := filepath.Dir(file) + + // create directory + if _, err := os.Stat(dir); err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("failed to check directory: %w", err) + } + + if err = os.MkdirAll(dir, 0o750); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + } + + // create file + outFile, err := os.Create(file) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + defer func() { _ = outFile.Close() }() //nolint:errcheck // not needed + + _, err = io.Copy(outFile, bytes.NewReader(r.Body())) + if err != nil { + return fmt.Errorf("failed to write response body to file: %w", err) + } + + return nil + case io.Writer: + _, err := io.Copy(p, bytes.NewReader(r.Body())) + if err != nil { + return fmt.Errorf("failed to write response body to io.Writer: %w", err) + } + defer func() { + if pc, ok := p.(io.WriteCloser); ok { + _ = pc.Close() //nolint:errcheck // not needed + } + }() + + return nil + default: + return ErrNotSupportSaveMethod + } +} + +// Reset clear Response object. +func (r *Response) Reset() { + r.client = nil + r.request = nil + + for len(r.cookie) != 0 { + t := r.cookie[0] + r.cookie = r.cookie[1:] + fasthttp.ReleaseCookie(t) + } + + r.RawResponse.Reset() +} + +// Close method will release Request object and Response object, +// after call Close please don't use these object. +func (r *Response) Close() { + if r.request != nil { + tmp := r.request + r.request = nil + ReleaseRequest(tmp) + } + ReleaseResponse(r) +} + +var responsePool = &sync.Pool{ + New: func() any { + return &Response{ + cookie: []*fasthttp.Cookie{}, + RawResponse: fasthttp.AcquireResponse(), + } + }, +} + +// AcquireResponse returns an empty response object from the pool. +// +// The returned response may be returned to the pool with ReleaseResponse when no longer needed. +// This allows reducing GC load. +func AcquireResponse() *Response { + resp, ok := responsePool.Get().(*Response) + if !ok { + panic("unexpected type from responsePool.Get()") + } + return resp +} + +// ReleaseResponse returns the object acquired via AcquireResponse to the pool. +// +// Do not access the released Response object, otherwise data races may occur. +func ReleaseResponse(resp *Response) { + resp.Reset() + responsePool.Put(resp) +} diff --git a/client/response_test.go b/client/response_test.go new file mode 100644 index 0000000000..622e835714 --- /dev/null +++ b/client/response_test.go @@ -0,0 +1,418 @@ +package client + +import ( + "bytes" + "crypto/tls" + "encoding/xml" + "io" + "net" + "os" + "testing" + + "github.com/gofiber/fiber/v3/internal/tlstest" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" +) + +func Test_Response_Status(t *testing.T) { + t.Parallel() + + setupApp := func() *testServer { + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("foo") + }) + app.Get("/fail", func(c fiber.Ctx) error { + return c.SendStatus(407) + }) + }) + + return server + } + + t.Run("success", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example") + + require.NoError(t, err) + require.Equal(t, "OK", resp.Status()) + resp.Close() + }) + + t.Run("fail", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example/fail") + + require.NoError(t, err) + require.Equal(t, "Proxy Authentication Required", resp.Status()) + resp.Close() + }) +} + +func Test_Response_Status_Code(t *testing.T) { + t.Parallel() + + setupApp := func() *testServer { + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("foo") + }) + app.Get("/fail", func(c fiber.Ctx) error { + return c.SendStatus(407) + }) + }) + + return server + } + + t.Run("success", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example") + + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode()) + resp.Close() + }) + + t.Run("fail", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example/fail") + + require.NoError(t, err) + require.Equal(t, 407, resp.StatusCode()) + resp.Close() + }) +} + +func Test_Response_Protocol(t *testing.T) { + t.Parallel() + + t.Run("http", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("foo") + }) + }) + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example") + + require.NoError(t, err) + require.Equal(t, "HTTP/1.1", resp.Protocol()) + resp.Close() + }) + + t.Run("https", func(t *testing.T) { + t.Parallel() + + serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() + require.NoError(t, err) + + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + require.NoError(t, err) + + ln = tls.NewListener(ln, serverTLSConf) + + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(c.Scheme()) + }) + + go func() { + require.NoError(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + + client := NewClient() + resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) + + require.NoError(t, err) + require.Equal(t, clientTLSConf, client.TLSConfig()) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "https", resp.String()) + require.Equal(t, "HTTP/1.1", resp.Protocol()) + + resp.Close() + }) +} + +func Test_Response_Header(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + c.Response().Header.Add("foo", "bar") + return c.SendString("helo world") + }) + }) + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com") + + require.NoError(t, err) + require.Equal(t, "bar", resp.Header("foo")) + resp.Close() +} + +func Test_Response_Cookie(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + c.Cookie(&fiber.Cookie{ + Name: "foo", + Value: "bar", + }) + return c.SendString("helo world") + }) + }) + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com") + + require.NoError(t, err) + require.Equal(t, "bar", string(resp.Cookies()[0].Value())) + resp.Close() +} + +func Test_Response_Body(t *testing.T) { + t.Parallel() + + setupApp := func() *testServer { + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("hello world") + }) + + app.Get("/json", func(c fiber.Ctx) error { + return c.SendString("{\"status\":\"success\"}") + }) + + app.Get("/xml", func(c fiber.Ctx) error { + return c.SendString("success") + }) + }) + + return server + } + + t.Run("raw body", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com") + + require.NoError(t, err) + require.Equal(t, []byte("hello world"), resp.Body()) + resp.Close() + }) + + t.Run("string body", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com") + + require.NoError(t, err) + require.Equal(t, "hello world", resp.String()) + resp.Close() + }) + + t.Run("json body", func(t *testing.T) { + t.Parallel() + type body struct { + Status string `json:"status"` + } + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com/json") + + require.NoError(t, err) + + tmp := &body{} + err = resp.JSON(tmp) + require.NoError(t, err) + require.Equal(t, "success", tmp.Status) + resp.Close() + }) + + t.Run("xml body", func(t *testing.T) { + t.Parallel() + type body struct { + Name xml.Name `xml:"status"` + Status string `xml:"name"` + } + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com/xml") + + require.NoError(t, err) + + tmp := &body{} + err = resp.XML(tmp) + require.NoError(t, err) + require.Equal(t, "success", tmp.Status) + resp.Close() + }) +} + +func Test_Response_Save(t *testing.T) { + t.Parallel() + + setupApp := func() *testServer { + server := startTestServer(t, func(app *fiber.App) { + app.Get("/json", func(c fiber.Ctx) error { + return c.SendString("{\"status\":\"success\"}") + }) + }) + + return server + } + + t.Run("file path", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com/json") + + require.NoError(t, err) + + err = resp.Save("./test/tmp.json") + require.NoError(t, err) + defer func() { + _, err := os.Stat("./test/tmp.json") + require.NoError(t, err) + + err = os.RemoveAll("./test") + require.NoError(t, err) + }() + + file, err := os.Open("./test/tmp.json") + require.NoError(t, err) + defer func(file *os.File) { + err := file.Close() + require.NoError(t, err) + }(file) + + data, err := io.ReadAll(file) + require.NoError(t, err) + require.Equal(t, "{\"status\":\"success\"}", string(data)) + }) + + t.Run("io.Writer", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com/json") + + require.NoError(t, err) + + buf := &bytes.Buffer{} + + err = resp.Save(buf) + require.NoError(t, err) + require.Equal(t, "{\"status\":\"success\"}", buf.String()) + }) + + t.Run("error type", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com/json") + + require.NoError(t, err) + + err = resp.Save(nil) + require.Error(t, err) + }) +} diff --git a/client_test.go b/client_test.go deleted file mode 100644 index 57cc4e4d2e..0000000000 --- a/client_test.go +++ /dev/null @@ -1,1337 +0,0 @@ -//nolint:wrapcheck // We must not wrap errors in tests -package fiber - -import ( - "bytes" - "crypto/tls" - "encoding/base64" - "encoding/json" - "encoding/xml" - "errors" - "io" - "mime/multipart" - "net" - "os" - "path/filepath" - "regexp" - "strings" - "testing" - "time" - - "github.com/gofiber/fiber/v3/internal/tlstest" - "github.com/stretchr/testify/require" - "github.com/valyala/fasthttp" - "github.com/valyala/fasthttp/fasthttputil" -) - -func Test_Client_Invalid_URL(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - return c.SendString(c.Host()) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - a := Get("http://example.com\r\n\r\nGET /\r\n\r\n") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - _, body, errs := a.String() - - require.Equal(t, "", body) - require.Len(t, errs, 1) - require.Error(t, errs[0], - `Expected error "missing required Host header in request"`) -} - -func Test_Client_Unsupported_Protocol(t *testing.T) { - t.Parallel() - - a := Get("ftp://example.com") - - _, body, errs := a.String() - - require.Equal(t, "", body) - require.Len(t, errs, 1) - require.ErrorContains(t, errs[0], `unsupported protocol "ftp". http and https are supported`) -} - -func Test_Client_Get(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - return c.SendString(c.Host()) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - for i := 0; i < 5; i++ { - a := Get("http://example.com") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "example.com", body) - require.Empty(t, errs) - } -} - -func Test_Client_Head(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Head("/", func(c Ctx) error { - return c.SendStatus(StatusAccepted) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - for i := 0; i < 5; i++ { - a := Head("http://example.com") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusAccepted, code) - require.Equal(t, "", body) - require.Empty(t, errs) - } -} - -func Test_Client_Post(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Post("/", func(c Ctx) error { - return c.Status(StatusCreated). - SendString(c.FormValue("foo")) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - for i := 0; i < 5; i++ { - args := AcquireArgs() - - args.Set("foo", "bar") - - a := Post("http://example.com"). - Form(args) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusCreated, code) - require.Equal(t, "bar", body) - require.Empty(t, errs) - - ReleaseArgs(args) - } -} - -func Test_Client_Put(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Put("/", func(c Ctx) error { - return c.SendString(c.FormValue("foo")) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - for i := 0; i < 5; i++ { - args := AcquireArgs() - - args.Set("foo", "bar") - - a := Put("http://example.com"). - Form(args) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "bar", body) - require.Empty(t, errs) - - ReleaseArgs(args) - } -} - -func Test_Client_Patch(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Patch("/", func(c Ctx) error { - return c.SendString(c.FormValue("foo")) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - for i := 0; i < 5; i++ { - args := AcquireArgs() - - args.Set("foo", "bar") - - a := Patch("http://example.com"). - Form(args) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "bar", body) - require.Empty(t, errs) - - ReleaseArgs(args) - } -} - -func Test_Client_Delete(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Delete("/", func(c Ctx) error { - return c.Status(StatusNoContent). - SendString("deleted") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - for i := 0; i < 5; i++ { - args := AcquireArgs() - - a := Delete("http://example.com") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusNoContent, code) - require.Equal(t, "", body) - require.Empty(t, errs) - - ReleaseArgs(args) - } -} - -func Test_Client_UserAgent(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - return c.Send(c.Request().Header.UserAgent()) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - t.Run("default", func(t *testing.T) { - t.Parallel() - for i := 0; i < 5; i++ { - a := Get("http://example.com") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, defaultUserAgent, body) - require.Empty(t, errs) - } - }) - - t.Run("custom", func(t *testing.T) { - t.Parallel() - for i := 0; i < 5; i++ { - c := AcquireClient() - c.UserAgent = "ua" - - a := c.Get("http://example.com") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "ua", body) - require.Empty(t, errs) - ReleaseClient(c) - } - }) -} - -func Test_Client_Agent_Set_Or_Add_Headers(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - c.Request().Header.VisitAll(func(key, value []byte) { - if k := string(key); k == "K1" || k == "K2" { - _, err := c.Write(key) - require.NoError(t, err) - _, err = c.Write(value) - require.NoError(t, err) - } - }) - return nil - } - - wrapAgent := func(a *Agent) { - a.Set("k1", "v1"). - SetBytesK([]byte("k1"), "v1"). - SetBytesV("k1", []byte("v1")). - AddBytesK([]byte("k1"), "v11"). - AddBytesV("k1", []byte("v22")). - AddBytesKV([]byte("k1"), []byte("v33")). - SetBytesKV([]byte("k2"), []byte("v2")). - Add("k2", "v22") - } - - testAgent(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") -} - -func Test_Client_Agent_Connection_Close(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - if c.Request().Header.ConnectionClose() { - return c.SendString("close") - } - return c.SendString("not close") - } - - wrapAgent := func(a *Agent) { - a.ConnectionClose() - } - - testAgent(t, handler, wrapAgent, "close") -} - -func Test_Client_Agent_UserAgent(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.Send(c.Request().Header.UserAgent()) - } - - wrapAgent := func(a *Agent) { - a.UserAgent("ua"). - UserAgentBytes([]byte("ua")) - } - - testAgent(t, handler, wrapAgent, "ua") -} - -func Test_Client_Agent_Cookie(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.SendString( - c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) - } - - wrapAgent := func(a *Agent) { - a.Cookie("k1", "v1"). - CookieBytesK([]byte("k2"), "v2"). - CookieBytesKV([]byte("k2"), []byte("v2")). - Cookies("k3", "v3", "k4", "v4"). - CookiesBytesKV([]byte("k3"), []byte("v3"), []byte("k4"), []byte("v4")) - } - - testAgent(t, handler, wrapAgent, "v1v2v3v4") -} - -func Test_Client_Agent_Referer(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.Send(c.Request().Header.Referer()) - } - - wrapAgent := func(a *Agent) { - a.Referer("http://referer.com"). - RefererBytes([]byte("http://referer.com")) - } - - testAgent(t, handler, wrapAgent, "http://referer.com") -} - -func Test_Client_Agent_ContentType(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.Send(c.Request().Header.ContentType()) - } - - wrapAgent := func(a *Agent) { - a.ContentType("custom-type"). - ContentTypeBytes([]byte("custom-type")) - } - - testAgent(t, handler, wrapAgent, "custom-type") -} - -func Test_Client_Agent_Host(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - return c.SendString(c.Host()) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - a := Get("http://1.1.1.1:8080"). - Host("example.com"). - HostBytes([]byte("example.com")) - - require.Equal(t, "1.1.1.1:8080", a.HostClient.Addr) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "example.com", body) - require.Empty(t, errs) -} - -func Test_Client_Agent_QueryString(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.Send(c.Request().URI().QueryString()) - } - - wrapAgent := func(a *Agent) { - a.QueryString("foo=bar&bar=baz"). - QueryStringBytes([]byte("foo=bar&bar=baz")) - } - - testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") -} - -func Test_Client_Agent_BasicAuth(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - // Get authorization header - auth := c.Get(HeaderAuthorization) - // Decode the header contents - raw, err := base64.StdEncoding.DecodeString(auth[6:]) - require.NoError(t, err) - - return c.Send(raw) - } - - wrapAgent := func(a *Agent) { - a.BasicAuth("foo", "bar"). - BasicAuthBytes([]byte("foo"), []byte("bar")) - } - - testAgent(t, handler, wrapAgent, "foo:bar") -} - -func Test_Client_Agent_BodyString(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.Send(c.Request().Body()) - } - - wrapAgent := func(a *Agent) { - a.BodyString("foo=bar&bar=baz") - } - - testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") -} - -func Test_Client_Agent_Body(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.Send(c.Request().Body()) - } - - wrapAgent := func(a *Agent) { - a.Body([]byte("foo=bar&bar=baz")) - } - - testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") -} - -func Test_Client_Agent_BodyStream(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.Send(c.Request().Body()) - } - - wrapAgent := func(a *Agent) { - a.BodyStream(strings.NewReader("body stream"), -1) - } - - testAgent(t, handler, wrapAgent, "body stream") -} - -func Test_Client_Agent_Custom_Response(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - return c.SendString("custom") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - for i := 0; i < 5; i++ { - a := AcquireAgent() - resp := AcquireResponse() - - req := a.Request() - req.Header.SetMethod(MethodGet) - req.SetRequestURI("http://example.com") - - require.NoError(t, a.Parse()) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.SetResponse(resp). - String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "custom", body) - require.Equal(t, "custom", string(resp.Body())) - require.Empty(t, errs) - - ReleaseResponse(resp) - } -} - -func Test_Client_Agent_Dest(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - return c.SendString("dest") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - t.Run("small dest", func(t *testing.T) { - t.Parallel() - dest := []byte("de") - - a := Get("http://example.com") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.Dest(dest[:0]).String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "dest", body) - require.Equal(t, "de", string(dest)) - require.Empty(t, errs) - }) - - t.Run("enough dest", func(t *testing.T) { - t.Parallel() - dest := []byte("foobar") - - a := Get("http://example.com") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.Dest(dest[:0]).String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "dest", body) - require.Equal(t, "destar", string(dest)) - require.Empty(t, errs) - }) -} - -// readErrorConn is a struct for testing retryIf -type readErrorConn struct { - net.Conn -} - -func (*readErrorConn) Read(_ []byte) (int, error) { - return 0, errors.New("error") -} - -func (*readErrorConn) Write(p []byte) (int, error) { - return len(p), nil -} - -func (*readErrorConn) Close() error { - return nil -} - -func (*readErrorConn) LocalAddr() net.Addr { - return nil -} - -func (*readErrorConn) RemoteAddr() net.Addr { - return nil -} - -func (*readErrorConn) SetReadDeadline(_ time.Time) error { - return nil -} - -func (*readErrorConn) SetWriteDeadline(_ time.Time) error { - return nil -} - -func Test_Client_Agent_RetryIf(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - a := Post("http://example.com"). - RetryIf(func(_ *Request) bool { - return true - }) - dialsCount := 0 - a.HostClient.Dial = func(_ string) (net.Conn, error) { - dialsCount++ - switch dialsCount { - case 1: - return &readErrorConn{}, nil - case 2: - return &readErrorConn{}, nil - case 3: - return &readErrorConn{}, nil - case 4: - return ln.Dial() - default: - t.Fatalf("unexpected number of dials: %d", dialsCount) - } - panic("unreachable") - } - - _, _, errs := a.String() - require.Equal(t, 4, dialsCount) - require.Empty(t, errs) -} - -func Test_Client_Agent_Json(t *testing.T) { - t.Parallel() - // Test without ctype parameter - handler := func(c Ctx) error { - require.Equal(t, MIMEApplicationJSON, string(c.Request().Header.ContentType())) - - return c.Send(c.Request().Body()) - } - - wrapAgent := func(a *Agent) { - a.JSON(data{Success: true}) - } - - testAgent(t, handler, wrapAgent, `{"success":true}`) - - // Test with ctype parameter - handler = func(c Ctx) error { - require.Equal(t, "application/problem+json", string(c.Request().Header.ContentType())) - - return c.Send(c.Request().Body()) - } - - wrapAgent = func(a *Agent) { - a.JSON(data{Success: true}, "application/problem+json") - } - - testAgent(t, handler, wrapAgent, `{"success":true}`) -} - -func Test_Client_Agent_Json_Error(t *testing.T) { - t.Parallel() - a := Get("http://example.com"). - JSONEncoder(json.Marshal). - JSON(complex(1, 1)) - - _, body, errs := a.String() - - require.Equal(t, "", body) - require.Len(t, errs, 1) - wantErr := new(json.UnsupportedTypeError) - require.ErrorAs(t, errs[0], &wantErr) -} - -func Test_Client_Agent_XML(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - require.Equal(t, MIMEApplicationXML, string(c.Request().Header.ContentType())) - - return c.Send(c.Request().Body()) - } - - wrapAgent := func(a *Agent) { - a.XML(data{Success: true}) - } - - testAgent(t, handler, wrapAgent, "true") -} - -func Test_Client_Agent_XML_Error(t *testing.T) { - t.Parallel() - a := Get("http://example.com"). - XML(complex(1, 1)) - - _, body, errs := a.String() - require.Equal(t, "", body) - require.Len(t, errs, 1) - wantErr := new(xml.UnsupportedTypeError) - require.ErrorAs(t, errs[0], &wantErr) -} - -func Test_Client_Agent_Form(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - require.Equal(t, MIMEApplicationForm, string(c.Request().Header.ContentType())) - - return c.Send(c.Request().Body()) - } - - args := AcquireArgs() - - args.Set("foo", "bar") - - wrapAgent := func(a *Agent) { - a.Form(args) - } - - testAgent(t, handler, wrapAgent, "foo=bar") - - ReleaseArgs(args) -} - -func Test_Client_Agent_MultipartForm(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Post("/", func(c Ctx) error { - require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(HeaderContentType)) - - mf, err := c.MultipartForm() - require.NoError(t, err) - require.Equal(t, "bar", mf.Value["foo"][0]) - - return c.Send(c.Request().Body()) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - args := AcquireArgs() - - args.Set("foo", "bar") - - a := Post("http://example.com"). - Boundary("myBoundary"). - MultipartForm(args) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "--myBoundary\r\nContent-Disposition: form-data; name=\"foo\"\r\n\r\nbar\r\n--myBoundary--\r\n", body) - require.Empty(t, errs) - ReleaseArgs(args) -} - -func Test_Client_Agent_MultipartForm_Errors(t *testing.T) { - t.Parallel() - - a := AcquireAgent() - a.mw = &errorMultipartWriter{} - - args := AcquireArgs() - args.Set("foo", "bar") - - ff1 := &FormFile{"", "name1", []byte("content"), false} - ff2 := &FormFile{"", "name2", []byte("content"), false} - a.FileData(ff1, ff2). - MultipartForm(args) - - require.Len(t, a.errs, 4) - ReleaseArgs(args) -} - -func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Post("/", func(c Ctx) error { - require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(HeaderContentType)) - - fh1, err := c.FormFile("field1") - require.NoError(t, err) - require.Equal(t, "name", fh1.Filename) - buf := make([]byte, fh1.Size) - f, err := fh1.Open() - require.NoError(t, err) - defer func() { - err := f.Close() - require.NoError(t, err) - }() - _, err = f.Read(buf) - require.NoError(t, err) - require.Equal(t, "form file", string(buf)) - - fh2, err := c.FormFile("index") - require.NoError(t, err) - checkFormFile(t, fh2, ".github/testdata/index.html") - - fh3, err := c.FormFile("file3") - require.NoError(t, err) - checkFormFile(t, fh3, ".github/testdata/index.tmpl") - - return c.SendString("multipart form files") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - for i := 0; i < 5; i++ { - ff := AcquireFormFile() - ff.Fieldname = "field1" - ff.Name = "name" - ff.Content = []byte("form file") - - a := Post("http://example.com"). - Boundary("myBoundary"). - FileData(ff). - SendFiles(".github/testdata/index.html", "index", ".github/testdata/index.tmpl"). - MultipartForm(nil) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "multipart form files", body) - require.Empty(t, errs) - - ReleaseFormFile(ff) - } -} - -func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) { - t.Helper() - - basename := filepath.Base(filename) - require.Equal(t, fh.Filename, basename) - - b1, err := os.ReadFile(filename) //nolint:gosec // We're in a test so reading user-provided files by name is fine - require.NoError(t, err) - - b2 := make([]byte, fh.Size) - f, err := fh.Open() - require.NoError(t, err) - defer func() { - err := f.Close() - require.NoError(t, err) - }() - _, err = f.Read(b2) - require.NoError(t, err) - require.Equal(t, b1, b2) -} - -func Test_Client_Agent_Multipart_Random_Boundary(t *testing.T) { - t.Parallel() - - a := Post("http://example.com"). - MultipartForm(nil) - - reg := regexp.MustCompile(`multipart/form-data; boundary=\w{30}`) - - require.True(t, reg.Match(a.req.Header.Peek(HeaderContentType))) -} - -func Test_Client_Agent_Multipart_Invalid_Boundary(t *testing.T) { - t.Parallel() - - a := Post("http://example.com"). - Boundary("*"). - MultipartForm(nil) - - require.Len(t, a.errs, 1) - require.ErrorContains(t, a.errs[0], "mime: invalid boundary character") -} - -func Test_Client_Agent_SendFile_Error(t *testing.T) { - t.Parallel() - - a := Post("http://example.com"). - SendFile("non-exist-file!", "") - - require.Len(t, a.errs, 1) - require.ErrorIs(t, a.errs[0], os.ErrNotExist) -} - -func Test_Client_Debug(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.SendString("debug") - } - - var output bytes.Buffer - - wrapAgent := func(a *Agent) { - a.Debug(&output) - } - - testAgent(t, handler, wrapAgent, "debug", 1) - - str := output.String() - - require.Contains(t, str, "Connected to example.com(InmemoryListener)") - require.Contains(t, str, "GET / HTTP/1.1") - require.Contains(t, str, "User-Agent: fiber") - require.Contains(t, str, "Host: example.com\r\n\r\n") - require.Contains(t, str, "HTTP/1.1 200 OK") - require.Contains(t, str, "Content-Type: text/plain; charset=utf-8\r\nContent-Length: 5\r\n\r\ndebug") -} - -func Test_Client_Agent_Timeout(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - time.Sleep(time.Millisecond * 200) - return c.SendString("timeout") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - a := Get("http://example.com"). - Timeout(time.Millisecond * 50) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - _, body, errs := a.String() - - require.Equal(t, "", body) - require.Len(t, errs, 1) - require.ErrorIs(t, errs[0], fasthttp.ErrTimeout) -} - -func Test_Client_Agent_Reuse(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - return c.SendString("reuse") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - a := Get("http://example.com"). - Reuse() - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "reuse", body) - require.Empty(t, errs) - - code, body, errs = a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "reuse", body) - require.Empty(t, errs) -} - -func Test_Client_Agent_InsecureSkipVerify(t *testing.T) { - t.Parallel() - - cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key") - require.NoError(t, err) - - //nolint:gosec // We're in a test so using old ciphers is fine - serverTLSConf := &tls.Config{ - Certificates: []tls.Certificate{cer}, - } - - ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0") - require.NoError(t, err) - - ln = tls.NewListener(ln, serverTLSConf) - - app := New() - - app.Get("/", func(c Ctx) error { - return c.SendString("ignore tls") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - code, body, errs := Get("https://" + ln.Addr().String()). - InsecureSkipVerify(). - InsecureSkipVerify(). - String() - - require.Empty(t, errs) - require.Equal(t, StatusOK, code) - require.Equal(t, "ignore tls", body) -} - -func Test_Client_Agent_TLS(t *testing.T) { - t.Parallel() - - serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() - require.NoError(t, err) - - ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0") - require.NoError(t, err) - - ln = tls.NewListener(ln, serverTLSConf) - - app := New() - - app.Get("/", func(c Ctx) error { - return c.SendString("tls") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - code, body, errs := Get("https://" + ln.Addr().String()). - TLSConfig(clientTLSConf). - String() - - require.Empty(t, errs) - require.Equal(t, StatusOK, code) - require.Equal(t, "tls", body) -} - -func Test_Client_Agent_MaxRedirectsCount(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - if c.Request().URI().QueryArgs().Has("foo") { - return c.Redirect().To("/foo") - } - return c.Redirect().To("/") - }) - app.Get("/foo", func(c Ctx) error { - return c.SendString("redirect") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - t.Run("success", func(t *testing.T) { - t.Parallel() - a := Get("http://example.com?foo"). - MaxRedirectsCount(1) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, 200, code) - require.Equal(t, "redirect", body) - require.Empty(t, errs) - }) - - t.Run("error", func(t *testing.T) { - t.Parallel() - a := Get("http://example.com"). - MaxRedirectsCount(1) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - _, body, errs := a.String() - - require.Equal(t, "", body) - require.Len(t, errs, 1) - require.ErrorIs(t, errs[0], fasthttp.ErrTooManyRedirects) - }) -} - -func Test_Client_Agent_Struct(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - return c.JSON(data{true}) - }) - - app.Get("/error", func(c Ctx) error { - return c.SendString(`{"success"`) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - t.Run("success", func(t *testing.T) { - t.Parallel() - - a := Get("http://example.com") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - var d data - - code, body, errs := a.Struct(&d) - - require.Equal(t, StatusOK, code) - require.Equal(t, `{"success":true}`, string(body)) - require.Empty(t, errs) - require.True(t, d.Success) - }) - - t.Run("pre error", func(t *testing.T) { - t.Parallel() - a := Get("http://example.com") - - errPre := errors.New("pre errors") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - a.errs = append(a.errs, errPre) - - var d data - _, body, errs := a.Struct(&d) - - require.Equal(t, "", string(body)) - require.Len(t, errs, 1) - require.ErrorIs(t, errs[0], errPre) - require.False(t, d.Success) - }) - - t.Run("error", func(t *testing.T) { - t.Parallel() - a := Get("http://example.com/error") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - var d data - - code, body, errs := a.JSONDecoder(json.Unmarshal).Struct(&d) - - require.Equal(t, StatusOK, code) - require.Equal(t, `{"success"`, string(body)) - require.Len(t, errs, 1) - wantErr := new(json.SyntaxError) - require.ErrorAs(t, errs[0], &wantErr) - require.EqualValues(t, 10, wantErr.Offset) - }) - - t.Run("nil jsonDecoder", func(t *testing.T) { - t.Parallel() - a := AcquireAgent() - defer ReleaseAgent(a) - defer a.ConnectionClose() - request := a.Request() - request.Header.SetMethod(MethodGet) - request.SetRequestURI("http://example.com") - err := a.Parse() - require.NoError(t, err) - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - var d data - code, body, errs := a.Struct(&d) - require.Equal(t, StatusOK, code) - require.Equal(t, `{"success":true}`, string(body)) - require.Empty(t, errs) - require.True(t, d.Success) - }) -} - -func Test_Client_Agent_Parse(t *testing.T) { - t.Parallel() - - a := Get("https://example.com:10443") - - require.NoError(t, a.Parse()) -} - -func testAgent(t *testing.T, handler Handler, wrapAgent func(agent *Agent), excepted string, count ...int) { - t.Helper() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", handler) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - c := 1 - if len(count) > 0 { - c = count[0] - } - - for i := 0; i < c; i++ { - a := Get("http://example.com") - - wrapAgent(a) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, excepted, body) - require.Empty(t, errs) - } -} - -type data struct { - Success bool `json:"success" xml:"success"` -} - -type errorMultipartWriter struct { - count int -} - -func (*errorMultipartWriter) Boundary() string { return "myBoundary" } -func (*errorMultipartWriter) SetBoundary(_ string) error { return nil } -func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) { - if e.count == 0 { - e.count++ - return nil, errors.New("CreateFormFile error") - } - return errorWriter{}, nil -} -func (*errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") } -func (*errorMultipartWriter) Close() error { return errors.New("Close error") } - -type errorWriter struct{} - -func (errorWriter) Write(_ []byte) (int, error) { return 0, errors.New("Write error") } diff --git a/listen_test.go b/listen_test.go index d92b9fb396..a5d419ac86 100644 --- a/listen_test.go +++ b/listen_test.go @@ -17,6 +17,7 @@ import ( "time" "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttputil" ) @@ -68,22 +69,30 @@ func Test_Listen_Graceful_Shutdown(t *testing.T) { Time time.Duration ExpectedBody string ExpectedStatusCode int - ExceptedErrsLen int + ExpectedErr error }{ - {Time: 100 * time.Millisecond, ExpectedBody: "example.com", ExpectedStatusCode: StatusOK, ExceptedErrsLen: 0}, - {Time: 500 * time.Millisecond, ExpectedBody: "", ExpectedStatusCode: 0, ExceptedErrsLen: 1}, + {Time: 100 * time.Millisecond, ExpectedBody: "example.com", ExpectedStatusCode: StatusOK, ExpectedErr: nil}, + {Time: 500 * time.Millisecond, ExpectedBody: "", ExpectedStatusCode: StatusOK, ExpectedErr: errors.New("InmemoryListener is already closed: use of closed network connection")}, } for _, tc := range testCases { time.Sleep(tc.Time) - a := Get("http://example.com") - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - code, body, errs := a.String() + req := fasthttp.AcquireRequest() + req.SetRequestURI("http://example.com") - require.Equal(t, tc.ExpectedStatusCode, code) - require.Equal(t, tc.ExpectedBody, body) - require.Len(t, errs, tc.ExceptedErrsLen) + client := fasthttp.HostClient{} + client.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } + + resp := fasthttp.AcquireResponse() + err := client.Do(req, resp) + + require.Equal(t, tc.ExpectedErr, err) + require.Equal(t, tc.ExpectedStatusCode, resp.StatusCode()) + require.Equal(t, tc.ExpectedBody, string(resp.Body())) + + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(resp) } mu.Lock() diff --git a/log/default.go b/log/default.go index e9c3d1bbbd..67fa137ecf 100644 --- a/log/default.go +++ b/log/default.go @@ -34,6 +34,7 @@ func (l *defaultLogger) privateLog(lv Level, fmtArgs []any) { if lv == LevelPanic { panic(buf.String()) } + buf.Reset() bytebufferpool.Put(buf) if lv == LevelFatal { @@ -56,6 +57,7 @@ func (l *defaultLogger) privateLogf(lv Level, format string, fmtArgs []any) { } else { _, _ = fmt.Fprint(buf, fmtArgs...) } + _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error if lv == LevelPanic { panic(buf.String()) diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index 284b67c8f5..167a0e6f31 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -2,7 +2,6 @@ package proxy import ( "bytes" - "crypto/tls" "net/url" "strings" "sync" @@ -105,13 +104,6 @@ var client = &fasthttp.Client{ var lock sync.RWMutex -// WithTLSConfig update http client with a user specified tls.config -// This function should be called before Do and Forward. -// Deprecated: use WithClient instead. -func WithTLSConfig(tlsConfig *tls.Config) { - client.TLSConfig = tlsConfig -} - // WithClient sets the global proxy client. // This function should be called before Do and Forward. func WithClient(cli *fasthttp.Client) { diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 408ee71a5f..4aa0065040 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -11,8 +11,10 @@ import ( "time" "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/internal/tlstest" + clientpkg "github.com/gofiber/fiber/v3/client" "github.com/stretchr/testify/require" + + "github.com/gofiber/fiber/v3/internal/tlstest" "github.com/valyala/fasthttp" ) @@ -25,8 +27,6 @@ func createProxyTestServer(t *testing.T, handler fiber.Handler) (*fiber.App, str ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) - addr := ln.Addr().String() - go func() { require.NoError(t, target.Listener(ln, fiber.ListenConfig{ DisableStartupMessage: true, @@ -34,6 +34,7 @@ func createProxyTestServer(t *testing.T, handler fiber.Handler) (*fiber.App, str }() time.Sleep(2 * time.Second) + addr := ln.Addr().String() return target, addr } @@ -104,8 +105,8 @@ func Test_Proxy(t *testing.T) { require.Equal(t, fiber.StatusTeapot, resp.StatusCode) } -// go test -run Test_Proxy_Balancer_WithTLSConfig -func Test_Proxy_Balancer_WithTLSConfig(t *testing.T) { +// go test -run Test_Proxy_Balancer_WithTlsConfig +func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) { t.Parallel() serverTLSConf, _, err := tlstest.GetTLSConfigs() @@ -118,7 +119,7 @@ func Test_Proxy_Balancer_WithTLSConfig(t *testing.T) { app := fiber.New() - app.Get("/tlsbalaner", func(c fiber.Ctx) error { + app.Get("/tlsbalancer", func(c fiber.Ctx) error { return c.SendString("tls balancer") }) @@ -137,15 +138,18 @@ func Test_Proxy_Balancer_WithTLSConfig(t *testing.T) { })) }() - code, body, errs := fiber.Get("https://" + addr + "/tlsbalaner").TLSConfig(clientTLSConf).String() + client := clientpkg.NewClient() + client.SetTLSConfig(clientTLSConf) - require.Empty(t, errs) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "tls balancer", body) + resp, err := client.Get("https://" + addr + "/tlsbalancer") + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "tls balancer", string(resp.Body())) + resp.Close() } -// go test -run Test_Proxy_Forward_WithTLSConfig_To_Http -func Test_Proxy_Forward_WithTLSConfig_To_Http(t *testing.T) { +// go test -run Test_Proxy_Forward_WithTlsConfig_To_Http +func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) { t.Parallel() _, targetAddr := createProxyTestServer(t, func(c fiber.Ctx) error { @@ -172,14 +176,15 @@ func Test_Proxy_Forward_WithTLSConfig_To_Http(t *testing.T) { })) }() - code, body, errs := fiber.Get("https://" + proxyAddr). - InsecureSkipVerify(). - Timeout(5 * time.Second). - String() + client := clientpkg.NewClient() + client.SetTimeout(5 * time.Second) + client.TLSConfig().InsecureSkipVerify = true - require.Empty(t, errs) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "hello from target", body) + resp, err := client.Get("https://" + proxyAddr) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "hello from target", string(resp.Body())) + resp.Close() } // go test -run Test_Proxy_Forward @@ -203,8 +208,8 @@ func Test_Proxy_Forward(t *testing.T) { require.Equal(t, "forwarded", string(b)) } -// go test -run Test_Proxy_Forward_WithTLSConfig -func Test_Proxy_Forward_WithTLSConfig(t *testing.T) { +// go test -run Test_Proxy_Forward_WithClient_TLSConfig +func Test_Proxy_Forward_WithClient_TLSConfig(t *testing.T) { t.Parallel() serverTLSConf, _, err := tlstest.GetTLSConfigs() @@ -225,7 +230,9 @@ func Test_Proxy_Forward_WithTLSConfig(t *testing.T) { clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine // disable certificate verification - WithTLSConfig(clientTLSConf) + WithClient(&fasthttp.Client{ + TLSConfig: clientTLSConf, + }) app.Use(Forward("https://" + addr + "/tlsfwd")) go func() { @@ -234,11 +241,14 @@ func Test_Proxy_Forward_WithTLSConfig(t *testing.T) { })) }() - code, body, errs := fiber.Get("https://" + addr).TLSConfig(clientTLSConf).String() + client := clientpkg.NewClient() + client.SetTLSConfig(clientTLSConf) - require.Empty(t, errs) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "tls forward", body) + resp, err := client.Get("https://" + addr) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "tls forward", string(resp.Body())) + resp.Close() } // go test -run Test_Proxy_Modify_Response @@ -415,7 +425,7 @@ func Test_Proxy_Do_WithRedirect(t *testing.T) { return Do(c, "https://google.com") }) - resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), 1500) + resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) require.NoError(t, err1) body, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -431,7 +441,7 @@ func Test_Proxy_DoRedirects_RestoreOriginalURL(t *testing.T) { return DoRedirects(c, "http://google.com", 1) }) - resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), 1500) + resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) require.NoError(t, err1) _, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -447,7 +457,7 @@ func Test_Proxy_DoRedirects_TooManyRedirects(t *testing.T) { return DoRedirects(c, "http://google.com", 0) }) - resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), 1500) + resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) require.NoError(t, err1) body, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -586,10 +596,13 @@ func Test_Proxy_Forward_Global_Client(t *testing.T) { })) }() - code, body, errs := fiber.Get("http://" + addr).String() - require.Empty(t, errs) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "test_global_client", body) + client := clientpkg.NewClient() + + resp, err := client.Get("http://" + addr) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "test_global_client", string(resp.Body())) + resp.Close() } // go test -race -run Test_Proxy_Forward_Local_Client @@ -615,10 +628,13 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) { })) }() - code, body, errs := fiber.Get("http://" + addr).String() - require.Empty(t, errs) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "test_local_client", body) + client := clientpkg.NewClient() + + resp, err := client.Get("http://" + addr) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "test_local_client", string(resp.Body())) + resp.Close() } // go test -run Test_ProxyBalancer_Custom_Client @@ -666,7 +682,7 @@ func Test_Proxy_Domain_Forward_Local(t *testing.T) { app1 := fiber.New() app1.Get("/test", func(c fiber.Ctx) error { - return c.SendString("test_local_client:" + fiber.Query[string](c, "query_test")) + return c.SendString("test_local_client:" + c.Query("query_test")) }) proxyAddr := ln.Addr().String() @@ -679,13 +695,24 @@ func Test_Proxy_Domain_Forward_Local(t *testing.T) { Dial: fasthttp.Dial, })) - go func() { require.NoError(t, app.Listener(ln)) }() - go func() { require.NoError(t, app1.Listener(ln1)) }() + go func() { + require.NoError(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + go func() { + require.NoError(t, app1.Listener(ln1, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + + client := clientpkg.NewClient() - code, body, errs := fiber.Get("http://" + localDomain + "/test?query_test=true").String() - require.Empty(t, errs) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "test_local_client:true", body) + resp, err := client.Get("http://" + localDomain + "/test?query_test=true") + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "test_local_client:true", string(resp.Body())) + resp.Close() } // go test -run Test_Proxy_Balancer_Forward_Local diff --git a/redirect_test.go b/redirect_test.go index d49f526771..6dd2ae6d19 100644 --- a/redirect_test.go +++ b/redirect_test.go @@ -291,41 +291,45 @@ func Test_Redirect_Request(t *testing.T) { CookieValue string ExpectedBody string ExpectedStatusCode int - ExceptedErrsLen int + ExpectedErr error }{ { URL: "/", CookieValue: "key:value,key2:value2,co\\:m\\,ma:Fi\\:ber\\, v3", ExpectedBody: `{"inputs":{},"messages":{"co:m,ma":"Fi:ber, v3","key":"value","key2":"value2"}}`, ExpectedStatusCode: StatusOK, - ExceptedErrsLen: 0, + ExpectedErr: nil, }, { URL: "/with-inputs?name=john&surname=doe", CookieValue: "key:value,key2:value2,key:value,key2:value2,old_input_data_name:john,old_input_data_surname:doe", ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{"key":"value","key2":"value2"}}`, ExpectedStatusCode: StatusOK, - ExceptedErrsLen: 0, + ExpectedErr: nil, }, { URL: "/just-inputs?name=john&surname=doe", CookieValue: "old_input_data_name:john,old_input_data_surname:doe", ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{}}`, ExpectedStatusCode: StatusOK, - ExceptedErrsLen: 0, + ExpectedErr: nil, }, } for _, tc := range testCases { - a := Get("http://example.com" + tc.URL) - a.Cookie(FlashCookieName, tc.CookieValue) - a.MaxRedirectsCount(1) - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - code, body, errs := a.String() - - require.Equal(t, tc.ExpectedStatusCode, code) - require.Equal(t, tc.ExpectedBody, body) - require.Len(t, errs, tc.ExceptedErrsLen) + client := &fasthttp.HostClient{ + Dial: func(_ string) (net.Conn, error) { + return ln.Dial() + }, + } + req, resp := fasthttp.AcquireRequest(), fasthttp.AcquireResponse() + req.SetRequestURI("http://example.com" + tc.URL) + req.Header.SetCookie(FlashCookieName, tc.CookieValue) + err := client.DoRedirects(req, resp, 1) + + require.NoError(t, err) + require.Equal(t, tc.ExpectedBody, string(resp.Body())) + require.Equal(t, tc.ExpectedStatusCode, resp.StatusCode()) } }