diff --git a/app.go b/app.go index 378a3426b9..9b40b4084b 100644 --- a/app.go +++ b/app.go @@ -117,6 +117,8 @@ type App struct { newCtxFunc func(app *App) CustomCtx // TLS handler tlsHandler *tlsHandler + // bind decoder cache + bindDecoderCache sync.Map } // Config is a struct holding the server settings. @@ -329,6 +331,17 @@ type Config struct { // Default: xml.Marshal XMLEncoder utils.XMLMarshal `json:"-"` + // XMLDecoder set by an external client of Fiber it will use the provided implementation of a + // XMLUnmarshal + // + // Allowing for flexibility in using another XML library for encoding + // Default: utils.XMLUnmarshal + XMLDecoder utils.XMLUnmarshal `json:"-"` + + // App validate. if nil, and context.EnableValidate will always return a error. + // Default: nil + Validator Validator + // Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only) // WARNING: When prefork is set to true, only "tcp4" and "tcp6" can be chose. // @@ -513,9 +526,14 @@ func New(config ...Config) *App { if app.config.JSONDecoder == nil { app.config.JSONDecoder = json.Unmarshal } + if app.config.XMLEncoder == nil { app.config.XMLEncoder = xml.Marshal } + if app.config.XMLDecoder == nil { + app.config.XMLDecoder = xml.Unmarshal + } + if app.config.Network == "" { app.config.Network = NetworkTCP4 } diff --git a/bind.go b/bind.go new file mode 100644 index 0000000000..cce399203b --- /dev/null +++ b/bind.go @@ -0,0 +1,59 @@ +package fiber + +import ( + "fmt" + "reflect" + + "github.com/gofiber/fiber/v3/internal/bind" +) + +type Binder interface { + UnmarshalFiberCtx(ctx Ctx) error +} + +// decoder should set a field on reqValue +// it's created with field index +type decoder interface { + Decode(ctx Ctx, reqValue reflect.Value) error +} + +type fieldCtxDecoder struct { + index int + fieldName string + fieldType reflect.Type +} + +func (d *fieldCtxDecoder) Decode(ctx Ctx, reqValue reflect.Value) error { + v := reflect.New(d.fieldType) + unmarshaler := v.Interface().(Binder) + + if err := unmarshaler.UnmarshalFiberCtx(ctx); err != nil { + return err + } + + reqValue.Field(d.index).Set(v.Elem()) + return nil +} + +type fieldTextDecoder struct { + index int + fieldName string + tag string // query,param,header,respHeader ... + reqField string + dec bind.TextDecoder + get func(c Ctx, key string, defaultValue ...string) string +} + +func (d *fieldTextDecoder) Decode(ctx Ctx, reqValue reflect.Value) error { + text := d.get(ctx, d.reqField) + if text == "" { + return nil + } + + err := d.dec.UnmarshalString(text, reqValue.Field(d.index)) + if err != nil { + return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.reqField, err) + } + + return nil +} diff --git a/bind_readme.md b/bind_readme.md new file mode 100644 index 0000000000..77cc5773bc --- /dev/null +++ b/bind_readme.md @@ -0,0 +1,172 @@ +# Fiber Binders + +Bind is new request/response binding feature for Fiber. +By against old Fiber parsers, it supports custom binder registration, +struct validation with high performance and easy to use. + +It's introduced in Fiber v3 and a replacement of: + +- BodyParser +- ParamsParser +- GetReqHeaders +- GetRespHeaders +- AllParams +- QueryParser +- ReqHeaderParser + +## Guides + +### Binding basic request info + +Fiber supports binding basic request data into the struct: + +all tags you can use are: + +- respHeader +- header +- query +- param +- cookie + +(binding for Request/Response header are case in-sensitive) + +private and anonymous fields will be ignored. + +```go +package main + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "time" + + fiber "github.com/gofiber/fiber/v3" +) + +type Req struct { + ID int `param:"id"` + Q int `query:"q"` + Likes []int `query:"likes"` + T time.Time `header:"x-time"` + Token string `header:"x-auth"` +} + +func main() { + app := fiber.New() + + app.Get("/:id", func(c fiber.Ctx) error { + var req Req + if err := c.Bind().Req(&req).Err(); err != nil { + return err + } + return c.JSON(req) + }) + + req := httptest.NewRequest(http.MethodGet, "/1?&s=a,b,c&q=47&likes=1&likes=2", http.NoBody) + req.Header.Set("x-auth", "ttt") + req.Header.Set("x-time", "2022-08-08T08:11:39+08:00") + resp, err := app.Test(req) + if err != nil { + panic(err) + } + defer resp.Body.Close() + + b, err := io.ReadAll(resp.Body) + if err != nil { + panic(err) + } + + fmt.Println(resp.StatusCode, string(b)) + // Output: 200 {"ID":1,"S":["a","b","c"],"Q":47,"Likes":[1,2],"T":"2022-08-08T08:11:39+08:00","Token":"ttt"} +} + +``` + +### Defining Custom Binder + +We support 2 types of Custom Binder + +#### a `encoding.TextUnmarshaler` with basic tag config. + +like the `time.Time` field in the previous example, if a field implement `encoding.TextUnmarshaler`, it will be called +to +unmarshal raw string we get from request's query/header/... + +#### a `fiber.Binder` interface. + +You don't need to set a field tag and it's binding tag will be ignored. + +``` +type Binder interface { + UnmarshalFiberCtx(ctx fiber.Ctx) error +} +``` + +If your type implement `fiber.Binder`, bind will pass current request Context to your and you can unmarshal the info +you need. + +### Parse Request Body + +you can call `ctx.BodyJSON(v any) error` or `BodyXML(v any) error` + +These methods will check content-type HTTP header and call configured JSON or XML decoder to unmarshal. + +```golang +package main + +type Body struct { + ID int `json:"..."` + Q int `json:"..."` + Likes []int `json:"..."` + T time.Time `json:"..."` + Token string `json:"..."` +} + +func main() { + app := fiber.New() + + app.Get("/:id", func(c fiber.Ctx) error { + var data Body + if err := c.Bind().JSON(&data).Err(); err != nil { + return err + } + return c.JSON(data) + }) +} +``` + +### Bind With validation + +Normally, `bind` will only try to unmarshal data from request and pass it to request handler. + +you can call `.Validate()` to validate previous binding. + +And you will need to set a validator in app Config, otherwise it will always return an error. + +```go +package main + +type Validator struct{} + +func (validator *Validator) Validate(v any) error { + return nil +} + +func main() { + app := fiber.New(fiber.Config{ + Validator: &Validator{}, + }) + + app.Get("/:id", func(c fiber.Ctx) error { + var req struct{} + var body struct{} + if err := c.Bind().Req(&req).Validate().JSON(&body).Validate().Err(); err != nil { + return err + } + + return nil + }) +} +``` diff --git a/bind_test.go b/bind_test.go new file mode 100644 index 0000000000..1f1f4aca4d --- /dev/null +++ b/bind_test.go @@ -0,0 +1,331 @@ +package fiber + +import ( + "net/url" + "regexp" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +// go test -run Test_Bind_BasicType -v +func Test_Bind_BasicType(t *testing.T) { + t.Parallel() + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + + type Query struct { + Flag bool `query:"enable"` + + I8 int8 `query:"i8"` + I16 int16 `query:"i16"` + I32 int32 `query:"i32"` + I64 int64 `query:"i64"` + I int `query:"i"` + + U8 uint8 `query:"u8"` + U16 uint16 `query:"u16"` + U32 uint32 `query:"u32"` + U64 uint64 `query:"u64"` + U uint `query:"u"` + + S string `query:"s"` + } + + var q Query + + const qs = "i8=88&i16=166&i32=322&i64=644&i=101&u8=77&u16=165&u32=321&u64=643&u=99&s=john&enable=true" + c.Request().URI().SetQueryString(qs) + require.NoError(t, c.Bind().Req(&q).Err()) + + require.Equal(t, Query{ + Flag: true, + I8: 88, + I16: 166, + I32: 322, + I64: 644, + I: 101, + U8: 77, + U16: 165, + U32: 321, + U64: 643, + U: 99, + S: "john", + }, q) + + type Query2 struct { + Flag []bool `query:"enable"` + + I8 []int8 `query:"i8"` + I16 []int16 `query:"i16"` + I32 []int32 `query:"i32"` + I64 []int64 `query:"i64"` + I []int `query:"i"` + + U8 []uint8 `query:"u8"` + U16 []uint16 `query:"u16"` + U32 []uint32 `query:"u32"` + U64 []uint64 `query:"u64"` + U []uint `query:"u"` + + S []string `query:"s"` + } + + var q2 Query2 + + c.Request().URI().SetQueryString(qs) + require.NoError(t, c.Bind().Req(&q2).Err()) + + require.Equal(t, Query2{ + Flag: []bool{true}, + I8: []int8{88}, + I16: []int16{166}, + I32: []int32{322}, + I64: []int64{644}, + I: []int{101}, + U8: []uint8{77}, + U16: []uint16{165}, + U32: []uint32{321}, + U64: []uint64{643}, + U: []uint{99}, + S: []string{"john"}, + }, q2) + +} + +// go test -run Test_Bind_Query -v +func Test_Bind_Query(t *testing.T) { + t.Parallel() + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + + type Query struct { + ID int `query:"id"` + Name string `query:"name"` + Hobby []string `query:"hobby"` + } + + var q Query + + c.Request().SetBody([]byte{}) + c.Request().Header.SetContentType("") + c.Request().URI().SetQueryString("id=1&name=tom&hobby=basketball&hobby=football") + + require.NoError(t, c.Bind().Req(&q).Err()) + require.Equal(t, 2, len(q.Hobby)) + + c.Request().URI().SetQueryString("id=1&name=tom&hobby=basketball,football") + require.NoError(t, c.Bind().Req(&q).Err()) + require.Equal(t, 1, len(q.Hobby)) + + c.Request().URI().SetQueryString("id=1&name=tom&hobby=scoccer&hobby=basketball,football") + require.NoError(t, c.Bind().Req(&q).Err()) + require.Equal(t, 2, len(q.Hobby)) + + c.Request().URI().SetQueryString("") + require.NoError(t, c.Bind().Req(&q).Err()) + require.Equal(t, 0, len(q.Hobby)) + + type Query2 struct { + Bool bool `query:"bool"` + ID int `query:"id"` + Name string `query:"name"` + Hobby string `query:"hobby"` + FavouriteDrinks string `query:"favouriteDrinks"` + Empty []string `query:"empty"` + Alloc []string `query:"alloc"` + No []int64 `query:"no"` + } + + var q2 Query2 + + c.Request().URI().SetQueryString("id=1&name=tom&hobby=basketball,football&favouriteDrinks=milo,coke,pepsi&alloc=&no=1") + require.NoError(t, c.Bind().Req(&q2).Err()) + require.Equal(t, "basketball,football", q2.Hobby) + require.Equal(t, "tom", q2.Name) // check value get overwritten + require.Equal(t, "milo,coke,pepsi", q2.FavouriteDrinks) + require.Equal(t, []string{}, q2.Empty) + require.Equal(t, []string{""}, q2.Alloc) + require.Equal(t, []int64{1}, q2.No) + + type ArrayQuery struct { + Data []string `query:"data[]"` + } + var aq ArrayQuery + c.Request().URI().SetQueryString("data[]=john&data[]=doe") + require.NoError(t, c.Bind().Req(&aq).Err()) + require.Equal(t, ArrayQuery{Data: []string{"john", "doe"}}, aq) +} + +// go test -run Test_Bind_Resp_Header -v +func Test_Bind_Resp_Header(t *testing.T) { + t.Parallel() + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + + type resHeader struct { + Key string `respHeader:"k"` + + Keys []string `respHeader:"keys"` + } + + c.Set("k", "vv") + c.Response().Header.Add("keys", "v1") + c.Response().Header.Add("keys", "v2") + + var q resHeader + require.NoError(t, c.Bind().Req(&q).Err()) + require.Equal(t, "vv", q.Key) + require.Equal(t, []string{"v1", "v2"}, q.Keys) +} + +var _ Binder = (*userCtxUnmarshaler)(nil) + +type userCtxUnmarshaler struct { + V int +} + +func (u *userCtxUnmarshaler) UnmarshalFiberCtx(ctx Ctx) error { + u.V++ + return nil +} + +// go test -run Test_Bind_CustomizedUnmarshaler -v +func Test_Bind_CustomizedUnmarshaler(t *testing.T) { + t.Parallel() + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + + type Req struct { + Key userCtxUnmarshaler + } + + var r Req + require.NoError(t, c.Bind().Req(&r).Err()) + require.Equal(t, 1, r.Key.V) + + require.NoError(t, c.Bind().Req(&r).Err()) + require.Equal(t, 1, r.Key.V) +} + +// go test -run Test_Bind_TextUnmarshaler -v +func Test_Bind_TextUnmarshaler(t *testing.T) { + t.Parallel() + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + + type Req struct { + Time time.Time `query:"time"` + } + + now := time.Now() + + c.Request().URI().SetQueryString(url.Values{ + "time": []string{now.Format(time.RFC3339Nano)}, + }.Encode()) + + var q Req + require.NoError(t, c.Bind().Req(&q).Err()) + require.Equal(t, false, q.Time.IsZero(), "time should not be zero") + require.Equal(t, true, q.Time.Before(now.Add(time.Second))) + require.Equal(t, true, q.Time.After(now.Add(-time.Second))) +} + +// go test -run Test_Bind_error_message -v +func Test_Bind_error_message(t *testing.T) { + t.Parallel() + app := New() + c := app.NewCtx(&fasthttp.RequestCtx{}) + + type Req struct { + Time time.Time `query:"time"` + } + + c.Request().URI().SetQueryString("time=john") + + err := c.Bind().Req(&Req{}).Err() + + require.Error(t, err) + require.Regexp(t, regexp.MustCompile(`unable to decode 'john' as time`), err.Error()) +} + +type Req struct { + ID int `query:"id"` + + I int `query:"I"` + J int `query:"j"` + K int `query:"k"` + + Token string `header:"x-auth"` +} + +func getCtx() Ctx { + app := New() + + // TODO: also bench params + ctx := app.NewCtx(&fasthttp.RequestCtx{}) + + var u = fasthttp.URI{} + u.SetQueryString("j=1&j=123&k=-1") + ctx.Request().SetURI(&u) + + ctx.Request().Header.Set("a-auth", "bearer tt") + + return ctx +} + +func Benchmark_Bind_by_hand(b *testing.B) { + ctx := getCtx() + for i := 0; i < b.N; i++ { + var req Req + var err error + if raw := ctx.Query("id"); raw != "" { + req.ID, err = strconv.Atoi(raw) + if err != nil { + b.Error(err) + b.FailNow() + } + } + + if raw := ctx.Query("i"); raw != "" { + req.I, err = strconv.Atoi(raw) + if err != nil { + b.Error(err) + b.FailNow() + } + } + + if raw := ctx.Query("j"); raw != "" { + req.J, err = strconv.Atoi(raw) + if err != nil { + b.Error(err) + b.FailNow() + } + } + + if raw := ctx.Query("k"); raw != "" { + req.K, err = strconv.Atoi(raw) + if err != nil { + b.Error(err) + b.FailNow() + } + } + + req.Token = ctx.Get("x-auth") + } +} + +func Benchmark_Bind(b *testing.B) { + ctx := getCtx() + for i := 0; i < b.N; i++ { + var v = Req{} + err := ctx.Bind().Req(&v) + if err != nil { + b.Error(err) + b.FailNow() + } + } +} diff --git a/binder.go b/binder.go new file mode 100644 index 0000000000..4ce2f1b6b7 --- /dev/null +++ b/binder.go @@ -0,0 +1,124 @@ +package fiber + +import ( + "bytes" + "net/http" + "reflect" + + "github.com/gofiber/fiber/v3/internal/reflectunsafe" + "github.com/gofiber/fiber/v3/utils" +) + +type Bind struct { + err error + ctx Ctx + val any // last decoded val +} + +func (b *Bind) setErr(err error) *Bind { + b.err = err + return b +} + +func (b *Bind) HTTPErr() error { + if b.err != nil { + if fe, ok := b.err.(*Error); ok { + return fe + } + + return NewError(http.StatusBadRequest, b.err.Error()) + } + + return nil +} + +func (b *Bind) Err() error { + return b.err +} + +// JSON unmarshal body as json +// unlike `ctx.BodyJSON`, this will also check "content-type" HTTP header. +func (b *Bind) JSON(v any) *Bind { + if b.err != nil { + return b + } + + if !bytes.HasPrefix(b.ctx.Request().Header.ContentType(), utils.UnsafeBytes(MIMEApplicationJSON)) { + return b.setErr(NewError(http.StatusUnsupportedMediaType, "expecting content-type \"application/json\"")) + } + + if err := b.ctx.BodyJSON(v); err != nil { + return b.setErr(err) + } + + b.val = v + return b +} + +// XML unmarshal body as xml +// unlike `ctx.BodyXML`, this will also check "content-type" HTTP header. +func (b *Bind) XML(v any) *Bind { + if b.err != nil { + return b + } + + if !bytes.HasPrefix(b.ctx.Request().Header.ContentType(), utils.UnsafeBytes(MIMEApplicationXML)) { + return b.setErr(NewError(http.StatusUnsupportedMediaType, "expecting content-type \"application/xml\"")) + } + + if err := b.ctx.BodyXML(v); err != nil { + return b.setErr(err) + } + + b.val = v + return b +} + +func (b *Bind) Req(v any) *Bind { + if b.err != nil { + return b + } + + if err := b.decode(v); err != nil { + return b.setErr(err) + } + return b +} + +func (b *Bind) Validate() *Bind { + if b.err != nil { + return b + } + + if b.val == nil { + return b + } + + if err := b.ctx.Validate(b.val); err != nil { + return b.setErr(err) + } + + return b +} + +func (b *Bind) decode(v any) error { + rv, typeID := reflectunsafe.ValueAndTypeID(v) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return &InvalidBinderError{Type: reflect.TypeOf(v)} + } + + cached, ok := b.ctx.App().bindDecoderCache.Load(typeID) + if ok { + // cached decoder, fast path + decoder := cached.(Decoder) + return decoder(b.ctx, rv.Elem()) + } + + decoder, err := compileReqParser(rv.Type()) + if err != nil { + return err + } + + b.ctx.App().bindDecoderCache.Store(typeID, decoder) + return decoder(b.ctx, rv.Elem()) +} diff --git a/binder_compile.go b/binder_compile.go new file mode 100644 index 0000000000..68eb47a2e1 --- /dev/null +++ b/binder_compile.go @@ -0,0 +1,164 @@ +package fiber + +import ( + "bytes" + "encoding" + "errors" + "fmt" + "reflect" + "strconv" + + "github.com/gofiber/fiber/v3/internal/bind" + "github.com/gofiber/fiber/v3/utils" +) + +type Decoder func(c Ctx, rv reflect.Value) error + +const bindTagRespHeader = "respHeader" +const bindTagHeader = "header" +const bindTagQuery = "query" +const bindTagParam = "param" +const bindTagCookie = "cookie" + +var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() +var bindUnmarshalerType = reflect.TypeOf((*Binder)(nil)).Elem() + +func compileReqParser(rt reflect.Type) (Decoder, error) { + var decoders []decoder + + el := rt.Elem() + if el.Kind() != reflect.Struct { + panic("wrapped request need to struct") + } + + for i := 0; i < el.NumField(); i++ { + if !el.Field(i).IsExported() { + // ignore unexported field + continue + } + + dec, err := compileFieldDecoder(el.Field(i), i) + if err != nil { + return nil, err + } + + if dec != nil { + decoders = append(decoders, dec) + } + } + + return func(c Ctx, rv reflect.Value) error { + for _, decoder := range decoders { + err := decoder.Decode(c, rv) + if err != nil { + return err + } + } + + return nil + }, nil +} + +func compileFieldDecoder(field reflect.StructField, index int) (decoder, error) { + if reflect.PtrTo(field.Type).Implements(bindUnmarshalerType) { + return &fieldCtxDecoder{index: index, fieldName: field.Name, fieldType: field.Type}, nil + } + + var tagScope = "" + for _, loopTagScope := range []string{bindTagRespHeader, bindTagQuery, bindTagParam, bindTagHeader, bindTagCookie} { + if _, ok := field.Tag.Lookup(loopTagScope); ok { + tagScope = loopTagScope + break + } + } + + if tagScope == "" { + return nil, nil + } + + tagContent := field.Tag.Get(tagScope) + + if reflect.PtrTo(field.Type).Implements(textUnmarshalerType) { + return compileTextBasedDecoder(field, index, tagScope, tagContent) + } + + if field.Type.Kind() == reflect.Slice { + return compileSliceFieldTextBasedDecoder(field, index, tagScope, tagContent) + } + + return compileTextBasedDecoder(field, index, tagScope, tagContent) +} + +func compileTextBasedDecoder(field reflect.StructField, index int, tagScope, tagContent string) (decoder, error) { + var get func(ctx Ctx, key string, defaultValue ...string) string + switch tagScope { + case bindTagQuery: + get = Ctx.Query + case bindTagHeader: + get = Ctx.Get + case bindTagRespHeader: + get = Ctx.GetRespHeader + case bindTagParam: + get = Ctx.Params + case bindTagCookie: + get = Ctx.Cookies + default: + return nil, errors.New("unexpected tag scope " + strconv.Quote(tagScope)) + } + + textDecoder, err := bind.CompileTextDecoder(field.Type) + if err != nil { + return nil, err + } + + return &fieldTextDecoder{ + index: index, + fieldName: field.Name, + tag: tagScope, + reqField: tagContent, + dec: textDecoder, + get: get, + }, nil +} + +func compileSliceFieldTextBasedDecoder(field reflect.StructField, index int, tagScope string, tagContent string) (decoder, error) { + if field.Type.Kind() != reflect.Slice { + panic("BUG: unexpected type, expecting slice " + field.Type.String()) + } + + et := field.Type.Elem() + elementUnmarshaler, err := bind.CompileTextDecoder(et) + if err != nil { + return nil, fmt.Errorf("failed to build slice binder: %w", err) + } + + var eqBytes = bytes.Equal + var visitAll func(Ctx, func(key, value []byte)) + switch tagScope { + case bindTagQuery: + visitAll = visitQuery + case bindTagHeader: + visitAll = visitHeader + eqBytes = utils.EqualFold[[]byte] + case bindTagRespHeader: + visitAll = visitResHeader + eqBytes = utils.EqualFold[[]byte] + case bindTagCookie: + visitAll = visitCookie + case bindTagParam: + return nil, errors.New("using params with slice type is not supported") + default: + return nil, errors.New("unexpected tag scope " + strconv.Quote(tagScope)) + } + + return &fieldSliceDecoder{ + fieldIndex: index, + eqBytes: eqBytes, + fieldName: field.Name, + visitAll: visitAll, + reqKey: []byte(tagContent), + fieldType: field.Type, + elementType: et, + elementDecoder: elementUnmarshaler, + }, nil +} diff --git a/binder_slice.go b/binder_slice.go new file mode 100644 index 0000000000..c9031abfe9 --- /dev/null +++ b/binder_slice.go @@ -0,0 +1,76 @@ +package fiber + +import ( + "reflect" + + "github.com/gofiber/fiber/v3/internal/bind" + "github.com/gofiber/fiber/v3/utils" +) + +var _ decoder = (*fieldSliceDecoder)(nil) + +type fieldSliceDecoder struct { + fieldIndex int + fieldName string + fieldType reflect.Type + reqKey []byte + // [utils.EqualFold] for headers and [bytes.Equal] for query/params. + eqBytes func([]byte, []byte) bool + elementType reflect.Type + elementDecoder bind.TextDecoder + visitAll func(Ctx, func(key []byte, value []byte)) +} + +func (d *fieldSliceDecoder) Decode(ctx Ctx, reqValue reflect.Value) error { + count := 0 + d.visitAll(ctx, func(key, value []byte) { + if d.eqBytes(key, d.reqKey) { + count++ + } + }) + + rv := reflect.MakeSlice(d.fieldType, 0, count) + + if count == 0 { + reqValue.Field(d.fieldIndex).Set(rv) + return nil + } + + var err error + d.visitAll(ctx, func(key, value []byte) { + if err != nil { + return + } + if d.eqBytes(key, d.reqKey) { + ev := reflect.New(d.elementType) + if ee := d.elementDecoder.UnmarshalString(utils.UnsafeString(value), ev.Elem()); ee != nil { + err = ee + } + + rv = reflect.Append(rv, ev.Elem()) + } + }) + + if err != nil { + return err + } + + reqValue.Field(d.fieldIndex).Set(rv) + return nil +} + +func visitQuery(ctx Ctx, f func(key []byte, value []byte)) { + ctx.Context().QueryArgs().VisitAll(f) +} + +func visitHeader(ctx Ctx, f func(key []byte, value []byte)) { + ctx.Request().Header.VisitAll(f) +} + +func visitResHeader(ctx Ctx, f func(key []byte, value []byte)) { + ctx.Response().Header.VisitAll(f) +} + +func visitCookie(ctx Ctx, f func(key []byte, value []byte)) { + ctx.Request().Header.VisitAllCookie(f) +} diff --git a/binder_test.go b/binder_test.go new file mode 100644 index 0000000000..862969a334 --- /dev/null +++ b/binder_test.go @@ -0,0 +1,32 @@ +package fiber + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func Test_Binder(t *testing.T) { + t.Parallel() + app := New() + + ctx := app.NewCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) + ctx.values = [maxParams]string{"id string"} + ctx.route = &Route{Params: []string{"id"}} + ctx.Request().SetBody([]byte(`{"name": "john doe"}`)) + ctx.Request().Header.Set("content-type", "application/json") + + var req struct { + ID string `param:"id"` + } + + var body struct { + Name string `json:"name"` + } + + err := ctx.Bind().Req(&req).JSON(&body).Err() + require.NoError(t, err) + require.Equal(t, "id string", req.ID) + require.Equal(t, "john doe", body.Name) +} diff --git a/client_test.go b/client_test.go index 987b0c3cbb..daf78c0baf 100644 --- a/client_test.go +++ b/client_test.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/tls" "encoding/base64" + "encoding/json" "errors" "fmt" "io" @@ -16,8 +17,6 @@ import ( "testing" "time" - "encoding/json" - "github.com/gofiber/fiber/v3/internal/tlstest" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp/fasthttputil" @@ -431,7 +430,7 @@ func Test_Client_Agent_BasicAuth(t *testing.T) { handler := func(c Ctx) error { // Get authorization header auth := c.Get(HeaderAuthorization) - // Decode the header contents + // Req the header contents raw, err := base64.StdEncoding.DecodeString(auth[6:]) require.NoError(t, err) diff --git a/ctx.go b/ctx.go index 9f88d1f9ed..1887363552 100644 --- a/ctx.go +++ b/ctx.go @@ -213,6 +213,25 @@ func (c *DefaultCtx) BaseURL() string { return c.baseURI } +func (c *DefaultCtx) Bind() *Bind { + return &Bind{ctx: c} +} + +// func (c *DefaultCtx) BindWithValidate(v any) error { +// if err := c.Bind(v); err != nil { +// return err +// } +// +// return c.EnableValidate(v) +// } + +func (c *DefaultCtx) Validate(v any) error { + if c.app.config.Validator == nil { + return NilValidatorError{} + } + return c.app.config.Validator.Validate(v) +} + // Body contains the raw body submitted in a POST request. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. @@ -245,6 +264,14 @@ func (c *DefaultCtx) Body() []byte { return body } +func (c *DefaultCtx) BodyJSON(v any) error { + return c.app.config.JSONDecoder(c.Body(), v) +} + +func (c *DefaultCtx) BodyXML(v any) error { + return c.app.config.XMLDecoder(c.Body(), v) +} + // ClearCookie expires a specific cookie by key on the client side. // If no key is provided it expires all cookies that came with the request. func (c *DefaultCtx) ClearCookie(key ...string) { @@ -836,7 +863,7 @@ func (c *DefaultCtx) Redirect(location string, status ...int) error { return nil } -// Add vars to default view var map binding to template engine. +// BindVars Add vars to default view var map binding to template engine. // Variables are read by the Render method and may be overwritten. func (c *DefaultCtx) BindVars(vars Map) error { // init viewBindMap - lazy map diff --git a/ctx_interface.go b/ctx_interface.go index d18ae3d18b..98e9cd3122 100644 --- a/ctx_interface.go +++ b/ctx_interface.go @@ -42,11 +42,28 @@ type Ctx interface { // BaseURL returns (protocol + host + base path). BaseURL() string + // Bind unmarshal request data from context add assign to struct fields. + // You can bind cookie, headers etc. into basic type, slice, or any customized binders by + // implementing [encoding.TextUnmarshaler] or [bind.Unmarshaler]. + // Replacement of: BodyParser, ParamsParser, GetReqHeaders, GetRespHeaders, AllParams, QueryParser, ReqHeaderParser + Bind() *Bind + + // BindWithValidate is an alias for `context.Bind` and `context.EnableValidate` + // BindWithValidate(v any) error + + Validate(v any) error + // Body contains the raw body submitted in a POST request. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. Body() []byte + // BodyJSON will unmarshal request body with Config.JSONDecoder + BodyJSON(v any) error + + // BodyXML will unmarshal request body with Config.XMLDecoder + BodyXML(v any) error + // ClearCookie expires a specific cookie by key on the client side. // If no key is provided it expires all cookies that came with the request. ClearCookie(key ...string) @@ -227,7 +244,7 @@ type Ctx interface { // If status is not specified, status defaults to 302 Found. Redirect(location string, status ...int) error - // Add vars to default view var map binding to template engine. + // BindVars Add vars to default view var map binding to template engine. // Variables are read by the Render method and may be overwritten. BindVars(vars Map) error diff --git a/error.go b/error.go index d6aee39d99..965d712450 100644 --- a/error.go +++ b/error.go @@ -1,11 +1,36 @@ package fiber import ( - goErrors "errors" + "errors" + "reflect" ) // Range errors var ( - ErrRangeMalformed = goErrors.New("range: malformed range header string") - ErrRangeUnsatisfiable = goErrors.New("range: unsatisfiable range") + ErrRangeMalformed = errors.New("range: malformed range header string") + ErrRangeUnsatisfiable = errors.New("range: unsatisfiable range") ) + +// NilValidatorError is the validate error when context.EnableValidate is called but no validator is set in config. +type NilValidatorError struct { +} + +func (n NilValidatorError) Error() string { + return "fiber: ctx.EnableValidate(v any) is called without validator" +} + +// InvalidBinderError is the error when try to bind unsupported type. +type InvalidBinderError struct { + Type reflect.Type +} + +func (e *InvalidBinderError) Error() string { + if e.Type == nil { + return "fiber: Bind(nil)" + } + + if e.Type.Kind() != reflect.Pointer { + return "fiber: Unmarshal(non-pointer " + e.Type.String() + ")" + } + return "fiber: Bind(nil " + e.Type.String() + ")" +} diff --git a/internal/bind/bool.go b/internal/bind/bool.go new file mode 100644 index 0000000000..a7f207cea3 --- /dev/null +++ b/internal/bind/bool.go @@ -0,0 +1,18 @@ +package bind + +import ( + "reflect" + "strconv" +) + +type boolDecoder struct { +} + +func (d *boolDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + fieldValue.SetBool(v) + return nil +} diff --git a/internal/bind/compile.go b/internal/bind/compile.go new file mode 100644 index 0000000000..da5ca7ae66 --- /dev/null +++ b/internal/bind/compile.go @@ -0,0 +1,49 @@ +package bind + +import ( + "encoding" + "errors" + "reflect" +) + +type TextDecoder interface { + UnmarshalString(s string, fieldValue reflect.Value) error +} + +var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + +func CompileTextDecoder(rt reflect.Type) (TextDecoder, error) { + // encoding.TextUnmarshaler + if reflect.PtrTo(rt).Implements(textUnmarshalerType) { + return &textUnmarshalEncoder{fieldType: rt}, nil + } + + switch rt.Kind() { + case reflect.Bool: + return &boolDecoder{}, nil + case reflect.Uint8: + return &uintDecoder{bitSize: 8}, nil + case reflect.Uint16: + return &uintDecoder{bitSize: 16}, nil + case reflect.Uint32: + return &uintDecoder{bitSize: 32}, nil + case reflect.Uint64: + return &uintDecoder{bitSize: 64}, nil + case reflect.Uint: + return &uintDecoder{}, nil + case reflect.Int8: + return &intDecoder{bitSize: 8}, nil + case reflect.Int16: + return &intDecoder{bitSize: 16}, nil + case reflect.Int32: + return &intDecoder{bitSize: 32}, nil + case reflect.Int64: + return &intDecoder{bitSize: 64}, nil + case reflect.Int: + return &intDecoder{}, nil + case reflect.String: + return &stringDecoder{}, nil + } + + return nil, errors.New("unsupported type " + rt.String()) +} diff --git a/internal/bind/int.go b/internal/bind/int.go new file mode 100644 index 0000000000..6b1cb4855d --- /dev/null +++ b/internal/bind/int.go @@ -0,0 +1,19 @@ +package bind + +import ( + "reflect" + "strconv" +) + +type intDecoder struct { + bitSize int +} + +func (d *intDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + v, err := strconv.ParseInt(s, 10, d.bitSize) + if err != nil { + return err + } + fieldValue.SetInt(v) + return nil +} diff --git a/internal/bind/string.go b/internal/bind/string.go new file mode 100644 index 0000000000..521b2277b7 --- /dev/null +++ b/internal/bind/string.go @@ -0,0 +1,15 @@ +package bind + +import ( + "reflect" + + "github.com/gofiber/fiber/v3/utils" +) + +type stringDecoder struct { +} + +func (d *stringDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + fieldValue.SetString(utils.CopyString(s)) + return nil +} diff --git a/internal/bind/text_unmarshaler.go b/internal/bind/text_unmarshaler.go new file mode 100644 index 0000000000..55b5b5811d --- /dev/null +++ b/internal/bind/text_unmarshaler.go @@ -0,0 +1,27 @@ +package bind + +import ( + "encoding" + "reflect" +) + +type textUnmarshalEncoder struct { + fieldType reflect.Type +} + +func (d *textUnmarshalEncoder) UnmarshalString(s string, fieldValue reflect.Value) error { + if s == "" { + return nil + } + + v := reflect.New(d.fieldType) + unmarshaler := v.Interface().(encoding.TextUnmarshaler) + + if err := unmarshaler.UnmarshalText([]byte(s)); err != nil { + return err + } + + fieldValue.Set(v.Elem()) + + return nil +} diff --git a/internal/bind/uint.go b/internal/bind/uint.go new file mode 100644 index 0000000000..8cccc95378 --- /dev/null +++ b/internal/bind/uint.go @@ -0,0 +1,19 @@ +package bind + +import ( + "reflect" + "strconv" +) + +type uintDecoder struct { + bitSize int +} + +func (d *uintDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + v, err := strconv.ParseUint(s, 10, d.bitSize) + if err != nil { + return err + } + fieldValue.SetUint(v) + return nil +} diff --git a/internal/reflectunsafe/reflectunsafe.go b/internal/reflectunsafe/reflectunsafe.go new file mode 100644 index 0000000000..7416da003b --- /dev/null +++ b/internal/reflectunsafe/reflectunsafe.go @@ -0,0 +1,12 @@ +package reflectunsafe + +import ( + "reflect" + "unsafe" +) + +func ValueAndTypeID(v any) (reflect.Value, uintptr) { + rv := reflect.ValueOf(v) + rt := rv.Type() + return rv, (*[2]uintptr)(unsafe.Pointer(&rt))[1] +} diff --git a/utils/xml.go b/utils/xml.go index 9cc23512b0..f205cb6633 100644 --- a/utils/xml.go +++ b/utils/xml.go @@ -2,3 +2,9 @@ package utils // XMLMarshal returns the XML encoding of v. type XMLMarshal func(v any) ([]byte, error) + +// XMLUnmarshal parses the XML-encoded data and stores the result in +// the value pointed to by v, which must be an arbitrary struct, +// slice, or string. Well-formed data that does not fit into v is +// discarded. +type XMLUnmarshal func([]byte, any) error diff --git a/validate.go b/validate.go new file mode 100644 index 0000000000..72dfee6ca9 --- /dev/null +++ b/validate.go @@ -0,0 +1,5 @@ +package fiber + +type Validator interface { + Validate(v any) error +}