diff --git a/openapi/components.go b/openapi/components.go index 55b4ad17..e81922a8 100644 --- a/openapi/components.go +++ b/openapi/components.go @@ -47,7 +47,7 @@ func newComponents() *components { } } -func (t *Parameter) addComponents(c *components, in string) { +func (t *Parameter) addToComponents(c *components, in string) { if t.Ref == nil { return } @@ -72,11 +72,11 @@ func (t *Parameter) addComponents(c *components, in string) { } if t.Schema != nil { - t.Schema.addComponents(c) + t.Schema.addToComponents(c) } } -func (s *Schema) addComponents(c *components) { +func (s *Schema) addToComponents(c *components) { if s.Ref != nil { if _, found := c.schemas[s.Ref.Ref]; !found { c.schemas[s.Ref.Ref] = s @@ -84,27 +84,27 @@ func (s *Schema) addComponents(c *components) { } for _, item := range s.AllOf { - item.addComponents(c) + item.addToComponents(c) } for _, item := range s.OneOf { - item.addComponents(c) + item.addToComponents(c) } for _, item := range s.AnyOf { - item.addComponents(c) + item.addToComponents(c) } if s.Items != nil { - s.Items.addComponents(c) + s.Items.addToComponents(c) } for _, item := range s.Properties { - item.addComponents(c) + item.addToComponents(c) } } -func (resp *Callback) addComponents(c *components) { +func (resp *Callback) addToComponents(c *components) { if resp.Ref != nil { if _, found := c.callbacks[resp.Ref.Ref]; !found { c.callbacks[resp.Ref.Ref] = resp @@ -112,42 +112,42 @@ func (resp *Callback) addComponents(c *components) { } for _, item := range resp.Callback { - item.addComponents(c) + item.addToComponents(c) } } // 所有带 $ref 的字段如果还未存在于 c,则会写入。 -func (o *Operation) addComponents(c *components) { +func (o *Operation) addToComponents(c *components) { for _, p := range o.Paths { - p.addComponents(c, InPath) + p.addToComponents(c, InPath) } for _, p := range o.Queries { - p.addComponents(c, InQuery) + p.addToComponents(c, InQuery) } for _, p := range o.Cookies { - p.addComponents(c, InCookie) + p.addToComponents(c, InCookie) } for _, p := range o.Headers { - p.addComponents(c, InHeader) + p.addToComponents(c, InHeader) } if o.RequestBody != nil { - o.RequestBody.addComponents(c) + o.RequestBody.addToComponents(c) } for _, r := range o.Responses { - r.addComponents(c) + r.addToComponents(c) } for _, r := range o.Callbacks { - r.addComponents(c) + r.addToComponents(c) } } -func (resp *Response) addComponents(c *components) { +func (resp *Response) addToComponents(c *components) { if resp.Ref != nil { if _, found := c.responses[resp.Ref.Ref]; !found { c.responses[resp.Ref.Ref] = resp @@ -155,19 +155,19 @@ func (resp *Response) addComponents(c *components) { } for _, h := range resp.Headers { - h.addComponents(c, InHeader) + h.addToComponents(c, InHeader) } if resp.Body != nil { - resp.Body.addComponents(c) + resp.Body.addToComponents(c) } for _, s := range resp.Content { - s.addComponents(c) + s.addToComponents(c) } } -func (req *Request) addComponents(c *components) { +func (req *Request) addToComponents(c *components) { if req.Ref != nil { if _, found := c.requests[req.Ref.Ref]; !found { c.requests[req.Ref.Ref] = req @@ -175,17 +175,17 @@ func (req *Request) addComponents(c *components) { } if req.Body != nil { - req.Body.addComponents(c) + req.Body.addToComponents(c) } if len(req.Content) > 0 { for _, s := range req.Content { - s.addComponents(c) + s.addToComponents(c) } } } -func (item *PathItem) addComponents(c *components) { +func (item *PathItem) addToComponents(c *components) { if item.Ref != nil { if _, found := c.pathItems[item.Ref.Ref]; !found { c.pathItems[item.Ref.Ref] = item @@ -193,15 +193,15 @@ func (item *PathItem) addComponents(c *components) { } for _, p := range item.Paths { - p.addComponents(c, InPath) + p.addToComponents(c, InPath) } for _, p := range item.Queries { - p.addComponents(c, InQuery) + p.addToComponents(c, InQuery) } for _, p := range item.Headers { - p.addComponents(c, InHeader) + p.addToComponents(c, InHeader) } for _, p := range item.Cookies { - p.addComponents(c, InCookie) + p.addToComponents(c, InCookie) } } diff --git a/openapi/components_test.go b/openapi/components_test.go index 1feaaf64..a963e8b3 100644 --- a/openapi/components_test.go +++ b/openapi/components_test.go @@ -18,15 +18,15 @@ func TestParameter_addComponents(t *testing.T) { d := New(ss, web.Phrase("desc")) p := &Parameter{} - p.addComponents(d.components, InPath) + p.addToComponents(d.components, InPath) a.Empty(d.paths) p = &Parameter{Schema: &Schema{Type: TypeString}} - p.addComponents(d.components, InPath) + p.addToComponents(d.components, InPath) a.Empty(d.paths) p = &Parameter{Schema: &Schema{Type: TypeString}, Ref: &Ref{Ref: "string"}} - p.addComponents(d.components, InPath) + p.addToComponents(d.components, InPath) a.Equal(d.components.paths["string"], p) } @@ -36,28 +36,28 @@ func TestSchema_addComponents(t *testing.T) { d := New(ss, web.Phrase("desc")) s := &Schema{} - s.addComponents(d.components) + s.addToComponents(d.components) a.Empty(d.components.schemas) s = &Schema{Type: TypeString} - s.addComponents(d.components) + s.addToComponents(d.components) a.Empty(d.components.schemas) s1 := &Schema{Type: TypeString, Ref: &Ref{Ref: "t1"}} - s1.addComponents(d.components) + s1.addToComponents(d.components) a.Length(d.components.schemas, 1).Equal(d.components.schemas["t1"], s1) // 同名不会再添加 s2 := &Schema{Type: TypeString, Ref: &Ref{Ref: "t1"}} - s2.addComponents(d.components) + s2.addToComponents(d.components) a.Length(d.components.schemas, 1).Equal(d.components.schemas["t1"], s1) s2 = &Schema{Type: TypeString, Ref: &Ref{Ref: "t2"}, Items: s1} - s2.addComponents(d.components) + s2.addToComponents(d.components) a.Length(d.components.schemas, 2).Equal(d.components.schemas["t2"], s2) s3 := &Schema{Type: TypeString, Ref: &Ref{Ref: "t3"}, Items: &Schema{Type: TypeNumber, Ref: &Ref{Ref: "t4"}}} - s3.addComponents(d.components) + s3.addToComponents(d.components) a.Length(d.components.schemas, 4). Equal(d.components.schemas["t3"], s3). Equal(d.components.schemas["t4"].Type, TypeNumber) @@ -74,7 +74,7 @@ func TestPathItem_addComponents(t *testing.T) { Headers: []*Parameter{{Name: "h1", Ref: &Ref{Ref: "h1"}}, {Name: "h2"}}, Cookies: []*Parameter{{Name: "c1", Ref: &Ref{Ref: "c1"}}, {Name: "c2"}}, } - item.addComponents(d.components) + item.addToComponents(d.components) a.Length(d.components.paths, 1). Length(d.components.cookies, 1). Length(d.components.queries, 1). diff --git a/openapi/middleware.go b/openapi/middleware.go index 355baec4..f61eb45e 100644 --- a/openapi/middleware.go +++ b/openapi/middleware.go @@ -85,8 +85,15 @@ func (o *Operation) PathID(name string, desc web.LocaleStringer) *Operation { }) } -func (o *Operation) PathRef(ref string) *Operation { - o.Paths = append(o.Paths, &Parameter{Ref: &Ref{Ref: ref}, Required: true}) +func (o *Operation) PathRef(ref string, summary, description web.LocaleStringer) *Operation { + if _, found := o.d.components.paths[ref]; !found { + panic(fmt.Sprintf("未找到引用 %s", ref)) + } + + o.Paths = append(o.Paths, &Parameter{ + Ref: &Ref{Ref: ref, Summary: summary, Description: description}, + Required: true, + }) return o } @@ -96,13 +103,19 @@ func (o *Operation) Query(name, typ string, desc web.LocaleStringer, f func(*Par return o } -func (o *Operation) QueryRef(ref string) *Operation { - o.Queries = append(o.Queries, &Parameter{Ref: &Ref{Ref: ref}}) +func (o *Operation) QueryRef(ref string, summary, description web.LocaleStringer) *Operation { + if _, found := o.d.components.queries[ref]; !found { + panic(fmt.Sprintf("未找到引用 %s", ref)) + } + + o.Queries = append(o.Queries, &Parameter{Ref: &Ref{ + Ref: ref, + Summary: summary, + Description: description, + }}) return o } -var timeType = reflect.TypeFor[time.Time]() - // QueryObject 从参数 o 中获取相应的查询参数 // // 对于 obj 的要求与 [web.Context.QueryObject] 是相同的。 @@ -121,14 +134,14 @@ func (o *Operation) queryObject(t reflect.Type, f func(*Parameter)) *Operation { } if t.Kind() != reflect.Struct { - panic("o 必须得是 struct 类型") + panic("t 必须得是 struct 类型") } for i := 0; i < t.NumField(); i++ { field := t.Field(i) if field.Anonymous { - o.QueryObject(field.Type, f) + o.queryObject(field.Type, f) continue } @@ -157,35 +170,13 @@ func (o *Operation) queryObject(t reflect.Type, f func(*Parameter)) *Operation { } o.Queries = append(o.Queries, p) } - switch field.Type.Kind() { - case reflect.String: - p.Schema = &Schema{Type: TypeString} - q(p) - case reflect.Bool: - p.Schema = &Schema{Type: TypeBoolean} - q(p) - case reflect.Float32, reflect.Float64: - p.Schema = &Schema{Type: TypeNumber} - q(p) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - p.Schema = &Schema{Type: TypeInteger} - q(p) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - p.Schema = &Schema{Type: TypeInteger, Minimum: 0} - q(p) - case reflect.Array, reflect.Slice: - p.Schema = &Schema{Type: TypeArray, Items: schemaFromType(o.d, reflect.TypeOf(field.Type.Elem()), false, "", nil)} - q(p) - case reflect.Struct: - if field.Type == timeType { - p.Schema = &Schema{Type: TypeString, Format: FormatDateTime} - q(p) - } else { - panic(fmt.Sprintf("查询参数不支持复杂的类型 %v:%v", field.Type.Kind(), field.Name)) - } - default: - panic(fmt.Sprintf("查询参数不支持复杂的类型 %v:%v", field.Type.Kind(), field.Name)) + + p.Schema = &Schema{} + schemaFromType(nil, field.Type, true, "", p.Schema) + if !p.Schema.isBasicType() { + panic("不支持复杂类型") } + q(p) } return o @@ -202,8 +193,16 @@ func (o *Operation) Header(name, typ string, desc web.LocaleStringer, f func(*Pa return o } -func (o *Operation) HeaderRef(ref string) *Operation { - o.Headers = append(o.Headers, &Parameter{Ref: &Ref{Ref: ref}}) +func (o *Operation) HeaderRef(ref string, summary, description web.LocaleStringer) *Operation { + if _, found := o.d.components.headers[ref]; !found { + panic(fmt.Sprintf("未找到引用 %s", ref)) + } + + o.Headers = append(o.Headers, &Parameter{Ref: &Ref{ + Ref: ref, + Summary: summary, + Description: description, + }}) return o } @@ -218,8 +217,16 @@ func (o *Operation) Cookie(name, typ string, desc web.LocaleStringer, f func(*Pa return o } -func (o *Operation) CookieRef(ref string) *Operation { - o.Cookies = append(o.Cookies, &Parameter{Ref: &Ref{Ref: ref}}) +func (o *Operation) CookieRef(ref string, summary, description web.LocaleStringer) *Operation { + if _, found := o.d.components.cookies[ref]; !found { + panic(fmt.Sprintf("未找到引用 %s", ref)) + } + + o.Cookies = append(o.Cookies, &Parameter{Ref: &Ref{ + Ref: ref, + Summary: summary, + Description: description, + }}) return o } @@ -244,8 +251,16 @@ func (o *Operation) Body(body any, ignorable bool, desc web.LocaleStringer, f fu return o } -func (o *Operation) BodyRef(ref string) *Operation { - o.RequestBody = &Request{Ref: &Ref{Ref: ref}} +func (o *Operation) BodyRef(ref string, summary, description web.LocaleStringer) *Operation { + if _, found := o.d.components.requests[ref]; !found { + panic(fmt.Sprintf("未找到引用 %s", ref)) + } + + o.RequestBody = &Request{Ref: &Ref{ + Ref: ref, + Summary: summary, + Description: description, + }} return o } @@ -269,17 +284,33 @@ func (o *Operation) Response(status string, resp any, desc web.LocaleStringer, f return o } -func (o *Operation) ResponseRef(status string, ref string) *Operation { - o.Responses[status] = &Response{Ref: &Ref{Ref: ref}} +func (o *Operation) ResponseRef(status, ref string, summary, description web.LocaleStringer) *Operation { + if _, found := o.d.components.responses[ref]; !found { + panic(fmt.Sprintf("未找到引用 %s", ref)) + } + + o.Responses[status] = &Response{Ref: &Ref{ + Ref: ref, + Summary: summary, + Description: description, + }} return o } // CallbackRef 引用 components 中定义的回调对象 -func (o *Operation) CallbackRef(name, ref string) *Operation { +func (o *Operation) CallbackRef(name, ref string, summary, description web.LocaleStringer) *Operation { + if _, found := o.d.components.callbacks[ref]; !found { + panic(fmt.Sprintf("未找到引用 %s", ref)) + } + if o.Callbacks == nil { o.Callbacks = make(map[string]*Callback, 1) } - o.Callbacks[name] = &Callback{Ref: &Ref{Ref: ref}} + o.Callbacks[name] = &Callback{Ref: &Ref{ + Ref: ref, + Summary: summary, + Description: description, + }} return o } @@ -308,7 +339,10 @@ func (o *Operation) Callback(name, path, method string, f func(*Operation)) *Ope if !found { opt = &Operation{d: o.d} } - f(opt) + + if f != nil { + f(opt) + } return o } @@ -325,7 +359,7 @@ func (d *Document) API(f func(o *Operation)) web.Middleware { (d.enableOptions || method != http.MethodOptions) { o := &Operation{ d: d, - Responses: make(map[string]*Response, 1), // 必然存在的字段,直接初始化了。 + Responses: make(map[string]*Response, len(d.responses)+1), // 必然存在的字段,直接初始化了。 } f(o) @@ -348,7 +382,7 @@ func (d *Document) addOperation(method, pattern, _ string, opt *Operation) { d.paths = make(map[string]*PathItem, 50) } - opt.addComponents(d.components) + opt.addToComponents(d.components) for _, ref := range d.headers { // 添加公共报头的定义 opt.Headers = append(opt.Headers, &Parameter{ diff --git a/openapi/middleware_test.go b/openapi/middleware_test.go index 55f0b49a..a710111f 100644 --- a/openapi/middleware_test.go +++ b/openapi/middleware_test.go @@ -19,6 +19,127 @@ type q struct { Q3 int } +func newOperation(a *assert.Assertion) *Operation { + ss := newServer(a) + d := New(ss, web.Phrase("title")) + o := &Operation{ + d: d, + Responses: make(map[string]*Response, len(d.responses)+1), // 必然存在的字段,直接初始化了。 + } + return o +} + +func TestOperation_Server(t *testing.T) { + a := assert.New(t, false) + o := newOperation(a) + + o.Server("https://example.com", web.Phrase("lang")) + a.Length(o.Servers, 1) + + a.Panic(func() { + o.Server("https://example.com/{id}", web.Phrase("lang")) + }) +} + +func TestOperation_Path(t *testing.T) { + a := assert.New(t, false) + o := newOperation(a) + + o.PathID("p1", nil) + a.Length(o.Paths, 1) + + o.Path("p2", TypeInteger, nil, nil) + a.Length(o.Paths, 2) + + a.PanicString(func() { + o.PathRef("p3", nil, nil) + }, "未找到引用 p3") + + o.d.components.paths["p3"] = &Parameter{Schema: &Schema{Type: TypeInteger}} + o.PathRef("p3", nil, nil) + a.Length(o.Paths, 3) +} + +func TestOperation_Query(t *testing.T) { + a := assert.New(t, false) + o := newOperation(a) + + o.Query("q1", TypeInteger, nil, nil) + a.Length(o.Queries, 1) + + a.PanicString(func() { + o.QueryRef("q2", nil, nil) + }, "未找到引用 q2") + + o.d.components.queries["q2"] = &Parameter{Schema: &Schema{Type: TypeInteger}} + o.QueryRef("q2", nil, nil) + a.Length(o.Queries, 2) +} + +func TestOperation_Cookie(t *testing.T) { + a := assert.New(t, false) + o := newOperation(a) + + o.Cookie("c1", TypeInteger, nil, nil) + a.Length(o.Cookies, 1) + + a.PanicString(func() { + o.CookieRef("c2", nil, nil) + }, "未找到引用 c2") + + o.d.components.cookies["c2"] = &Parameter{Schema: &Schema{Type: TypeInteger}} + o.CookieRef("c2", nil, nil) + a.Length(o.Cookies, 2) +} + +func TestOperation_Header(t *testing.T) { + a := assert.New(t, false) + o := newOperation(a) + + o.Header("h1", TypeInteger, nil, nil) + a.Length(o.Headers, 1) + + a.PanicString(func() { + o.HeaderRef("h2", nil, nil) + }, "未找到引用 h2") + + o.d.components.headers["h2"] = &Parameter{Schema: &Schema{Type: TypeInteger}} + o.HeaderRef("h2", nil, nil) + a.Length(o.Headers, 2) +} + +func TestOperation_Response(t *testing.T) { + a := assert.New(t, false) + o := newOperation(a) + + o.Response("2xx", object{}, nil, nil) + a.Length(o.Responses, 1) + + a.PanicString(func() { + o.ResponseRef("301", "3xx", nil, nil) + }, "未找到引用 3xx") + + o.d.components.responses["3xx"] = &Response{Body: &Schema{Type: TypeInteger}} + o.ResponseRef("301", "3xx", nil, nil) + a.Length(o.Responses, 2) +} + +func TestOperation_Callback(t *testing.T) { + a := assert.New(t, false) + o := newOperation(a) + + o.Callback("c1", "/path1", http.MethodGet, nil) + a.Length(o.Callbacks, 1) + + a.PanicString(func() { + o.CallbackRef("c2", "c2", nil, nil) + }, "未找到引用 c2") + + o.d.components.callbacks["c2"] = &Callback{} + o.CallbackRef("c2", "c2", nil, nil) + a.Length(o.Callbacks, 2) +} + func TestDocument_API(t *testing.T) { a := assert.New(t, false) ss := newServer(a) diff --git a/openapi/openapi.go b/openapi/openapi.go index b2219dbb..26575494 100644 --- a/openapi/openapi.go +++ b/openapi/openapi.go @@ -31,12 +31,12 @@ const ( const ( TypeString = "string" - TypeObject = "object" TypeNull = "null" TypeBoolean = "boolean" TypeArray = "array" TypeNumber = "number" TypeInteger = "integer" + TypeObject = "object" ) const ( diff --git a/openapi/option.go b/openapi/option.go index e2db7f1d..2f1036b6 100644 --- a/openapi/option.go +++ b/openapi/option.go @@ -59,7 +59,7 @@ func WithResponse(resp *Response, status ...string) Option { if resp.Ref == nil || resp.Ref.Ref == "" { panic("resp 必须存在 ref") } - resp.addComponents(d.components) + resp.addToComponents(d.components) for _, s := range status { d.responses[s] = resp.Ref.Ref @@ -118,7 +118,7 @@ func WithHeader(global bool, p ...*Parameter) Option { panic(err) } - pp.addComponents(d.components, InHeader) + pp.addToComponents(d.components, InHeader) if global { d.headers = append(d.headers, pp.Ref.Ref) @@ -140,7 +140,7 @@ func WithCookie(global bool, p ...*Parameter) Option { if pp.Ref == nil || pp.Ref.Ref == "" { panic("必须存在 ref") } - pp.addComponents(d.components, InCookie) + pp.addToComponents(d.components, InCookie) if global { d.cookies = append(d.cookies, pp.Ref.Ref) } @@ -159,7 +159,7 @@ func WithQuery(p ...*Parameter) Option { if pp.Ref == nil || pp.Ref.Ref == "" { panic("必须存在 ref") } - pp.addComponents(d.components, InQuery) + pp.addToComponents(d.components, InQuery) } } } @@ -175,7 +175,7 @@ func WithPath(global bool, p ...*Parameter) Option { if pp.Ref == nil || pp.Ref.Ref == "" { panic("必须存在 ref") } - pp.addComponents(d.components, InPath) + pp.addToComponents(d.components, InPath) } } } @@ -191,7 +191,7 @@ func WithCallback(c ...*Callback) Option { panic("Callback 不能为空") } - cc.addComponents(d.components) + cc.addToComponents(d.components) } } } diff --git a/openapi/render.go b/openapi/render.go index 33bb6a62..28e9e91d 100644 --- a/openapi/render.go +++ b/openapi/render.go @@ -48,6 +48,10 @@ func (r *renderer[T]) MarshalYAML() (any, error) { return r.obj, nil } +func (o *openAPIRenderer) MarshalHTML() (name string, data any) { + return o.templateName, o +} + type documentQuery struct { Tags []string `query:"tag"` } @@ -80,11 +84,6 @@ func (d *Document) Handler(ctx *web.Context) web.Responser { return ctx.Problem(web.ProblemNotAcceptable) } - dataURL := ctx.Request().URL.Path - if len(q.Tags) > 0 { - dataURL += "?tag=" + strings.Join(q.Tags, ",") - } - return web.NotModified(func() (string, bool) { slices.Sort(q.Tags) @@ -101,7 +100,3 @@ func (d *Document) Handler(ctx *web.Context) web.Responser { return d.build(ctx.LocalePrinter(), ctx.LanguageTag(), q.Tags), nil }) } - -func (o *openAPIRenderer) MarshalHTML() (name string, data any) { - return o.templateName, o -} diff --git a/openapi/schema.go b/openapi/schema.go index 1d567d04..33ab0a69 100644 --- a/openapi/schema.go +++ b/openapi/schema.go @@ -6,6 +6,7 @@ package openapi import ( "reflect" + "time" orderedmap "github.com/wk8/go-ordered-map/v2" @@ -79,68 +80,91 @@ type schemaRenderer struct { } func (d *Document) newSchema(t reflect.Type) *Schema { - return schemaFromType(d, t, true, "", nil) + s := &Schema{} + schemaFromType(d, t, true, "", s) + return s } // NewSchema 根据 [reflect.Type] 生成 [Schema] 对象 func NewSchema(t reflect.Type, title, desc web.LocaleStringer) *Schema { - s := schemaFromType(nil, t, true, "", desc) - s.Title = title + s := &Schema{ + Title: title, + Description: desc, + } + schemaFromType(nil, t, true, "", s) return s } +var timeType = reflect.TypeFor[time.Time]() + // d 仅用于查找其关联的 components/schema 中是否存在相同名称的对象,如果存在则直接生成引用对象。 // // desc 表示类型 t 的 Description 属性 // rootName 根结构体的名称,主要是为了解决子元素又引用了根元素的类型引起的循环引用。 -func schemaFromType(d *Document, t reflect.Type, isRoot bool, rootName string, desc web.LocaleStringer) *Schema { +func schemaFromType(d *Document, t reflect.Type, isRoot bool, rootName string, s *Schema) { for t.Kind() == reflect.Pointer { t = t.Elem() } switch t.Kind() { case reflect.String: - return &Schema{Type: TypeString, Description: desc} + s.Type = TypeString case reflect.Bool: - return &Schema{Type: TypeBoolean, Description: desc} + s.Type = TypeBoolean case reflect.Float32, reflect.Float64: - return &Schema{Type: TypeNumber, Description: desc} + s.Type = TypeNumber case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return &Schema{Type: TypeInteger, Description: desc} + s.Type = TypeInteger case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return &Schema{Type: TypeInteger, Minimum: 0, Description: desc} + s.Type = TypeInteger + s.Minimum = 0 case reflect.Array, reflect.Slice: - return &Schema{Type: TypeArray, Items: schemaFromType(d, t.Elem(), false, rootName, nil), Description: desc} + s.Type = TypeArray + s.Items = &Schema{} + schemaFromType(d, t.Elem(), false, rootName, s.Items) case reflect.Map: - return &Schema{Type: TypeObject, AdditionalProperties: schemaFromType(d, t.Elem(), false, rootName, nil), Description: desc} + s.Type = TypeObject + s.AdditionalProperties = &Schema{} + schemaFromType(d, t.Elem(), false, rootName, s.AdditionalProperties) case reflect.Struct: - return schemaFromObject(d, t, isRoot, rootName, desc) + if t == timeType { // 对时间作特殊处理 + s.Type = TypeString + s.Format = FormatDateTime + return + } + schemaFromObjectType(d, t, isRoot, rootName, s) } - return nil } -func schemaFromObject(d *Document, t reflect.Type, isRoot bool, rootName string, desc web.LocaleStringer) *Schema { +func schemaFromObjectType(d *Document, t reflect.Type, isRoot bool, rootName string, s *Schema) { typeName := getTypeName(t) if d != nil { - if s, found := d.components.schemas[typeName]; found { // 已经存在于 components - return s + if _, found := d.components.schemas[typeName]; found { // 已经存在于 components + s.Ref = &Ref{Ref: typeName} + return } } - ref := &Ref{Ref: typeName} + s.Ref = &Ref{Ref: typeName} if isRoot { rootName = typeName // isRoot == true 时,rootName 必然为空 } else if typeName == rootName { // 在字段中引用了根对象 - return &Schema{Ref: ref} + return } - ps := make(map[string]*Schema, t.NumField()) - req := make([]string, 0, t.NumField()) + s.Type = TypeObject + s.Properties = make(map[string]*Schema, t.NumField()) + for i := 0; i < t.NumField(); i++ { f := t.Field(i) k := f.Type.Kind() - var itemTitle web.LocaleStringer + var itemDesc web.LocaleStringer + + if f.Anonymous { + schemaFromType(d, f.Type, isRoot, rootName, s) + continue + } if f.IsExported() && k != reflect.Chan && k != reflect.Func && k != reflect.Complex64 && k != reflect.Complex128 { name := f.Name @@ -154,7 +178,7 @@ func schemaFromObject(d *Document, t reflect.Type, isRoot bool, rootName string, } if !omitempty { - req = append(req, name) + s.Required = append(s.Required, name) } if xmlName, _, attr := getTagName(f, "xml"); xmlName != "" && xmlName != name { @@ -163,27 +187,31 @@ func schemaFromObject(d *Document, t reflect.Type, isRoot bool, rootName string, comment := f.Tag.Get(CommentTag) if comment != "" { - itemTitle = web.Phrase(comment) + itemDesc = web.Phrase(comment) } } - s := schemaFromType(d, t.Field(i).Type, false, rootName, itemTitle) - if s == nil { + item := &Schema{Description: itemDesc} + schemaFromType(d, t.Field(i).Type, false, rootName, item) + if item.Type == "" { continue } if xml != nil { - s.XML = xml + item.XML = xml } - ps[name] = s + s.Properties[name] = item } } +} - return &Schema{ - Type: TypeObject, - Properties: ps, - Ref: ref, - Required: req, - Description: desc, +func (s *Schema) isBasicType() bool { + switch s.Type { + case TypeObject: + return false + case TypeArray: + return s.Items.isBasicType() + default: + return true } } diff --git a/openapi/schema_test.go b/openapi/schema_test.go index 50d35146..df25da2b 100644 --- a/openapi/schema_test.go +++ b/openapi/schema_test.go @@ -7,13 +7,24 @@ package openapi import ( "reflect" "testing" + "time" "github.com/issue9/assert/v4" "github.com/issue9/web" ) -func TestDocument_NewSchema(t *testing.T) { +type schemaObject1 struct { + object + Root string + T time.Time +} + +type schemaObject2 struct { + schemaObject1 +} + +func TestDocument_newSchema(t *testing.T) { a := assert.New(t, false) ss := newServer(a) d := New(ss, web.Phrase("desc")) @@ -22,6 +33,10 @@ func TestDocument_NewSchema(t *testing.T) { a.Equal(s.Type, TypeInteger). Nil(s.Ref) + s = d.newSchema(reflect.TypeFor[[]int]()) + a.Equal(s.Type, TypeArray). + Equal(s.Items.Type, TypeInteger) + s = d.newSchema(reflect.TypeFor[map[string]float32]()) a.Equal(s.Type, TypeObject). Nil(s.Ref). @@ -37,4 +52,35 @@ func TestDocument_NewSchema(t *testing.T) { Nil(s.Properties["Items"].XML). Equal(s.Properties["Items"].Type, TypeArray). NotZero(s.Properties["Items"].Items.Ref.Ref) // 引用了 object + + s = d.newSchema(reflect.ValueOf(schemaObject1{}).Type()) + a.Equal(s.Type, TypeObject). + NotZero(s.Ref.Ref). + Length(s.Properties, 5). + Equal(s.Properties["id"].Type, TypeInteger). + Equal(s.Properties["Root"].Type, TypeString). + Equal(s.Properties["T"].Type, TypeString). + Equal(s.Properties["T"].Format, FormatDateTime) + + s = d.newSchema(reflect.ValueOf(schemaObject2{}).Type()) + a.Equal(s.Type, TypeObject). + NotZero(s.Ref.Ref). + Length(s.Properties, 5). + Equal(s.Properties["id"].Type, TypeInteger). + Equal(s.Properties["Root"].Type, TypeString). + Equal(s.Properties["T"].Type, TypeString). + Equal(s.Properties["T"].Format, FormatDateTime) +} + +func TestSchema_isBasicType(t *testing.T) { + a := assert.New(t, false) + + s := NewSchema(reflect.TypeFor[int](), nil, nil) + a.True(s.isBasicType()) + + s = NewSchema(reflect.TypeFor[object](), nil, nil) + a.False(s.isBasicType()) + + s = NewSchema(reflect.TypeFor[[]string](), nil, nil) + a.True(s.isBasicType()) } diff --git a/openapi/valid.go b/openapi/valid.go index 4dc3b876..c815c49d 100644 --- a/openapi/valid.go +++ b/openapi/valid.go @@ -99,7 +99,7 @@ func (t *Parameter) valid(skipRefNotNil bool) *web.FieldError { return err } - if t.Schema.Type == TypeObject { + if !t.Schema.isBasicType() { return web.NewFieldError("Schema", "不支持复杂类型") }