From 2060a95a2b7a7861a8b93580fe5caed28565a3e6 Mon Sep 17 00:00:00 2001 From: caixw Date: Mon, 4 Mar 2024 15:03:08 +0800 Subject: [PATCH] =?UTF-8?q?feat(cmd/web/restdoc):=20=E6=B7=BB=E5=8A=A0=20@?= =?UTF-8?q?security=20=E6=8C=87=E4=BB=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- cmd/web/locales/zh-CN.yaml | 8 +++++ cmd/web/restdoc/parser/api.go | 21 +++++++++++ cmd/web/restdoc/parser/parser_test.go | 2 +- cmd/web/restdoc/parser/restdoc.go | 39 ++++++++++++++++----- cmd/web/restdoc/parser/testdata/api.go | 1 + cmd/web/restdoc/parser/testdata/testdata.go | 1 + cmd/web/restdoc/utils/utils.go | 2 +- 8 files changed, 65 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 9f3388fc..3cd9d4bb 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![Test](https://github.com/issue9/web/actions/workflows/test.yml/badge.svg)](https://github.com/issue9/web/actions/workflows/test.yml) [![Go Report Card](https://goreportcard.com/badge/github.com/issue9/web)](https://goreportcard.com/report/github.com/issue9/web) -[![codecov](https://codecov.io/gh/issue9/web/branch/master/graph/badge.svg)](https://codecov.io/gh/issue9/web) +[![codecov](https://codecov.io/gh/issue9/web/graph/badge.svg?token=D5y3FOJk8A)](https://codecov.io/gh/issue9/web) [![PkgGoDev](https://pkg.go.dev/badge/github.com/issue9/web)](https://pkg.go.dev/github.com/issue9/web) ![Go version](https://img.shields.io/github/go-mod/go-version/issue9/web) ![License](https://img.shields.io/github/license/issue9/web) diff --git a/cmd/web/locales/zh-CN.yaml b/cmd/web/locales/zh-CN.yaml index a1307a95..a5ca9af4 100755 --- a/cmd/web/locales/zh-CN.yaml +++ b/cmd/web/locales/zh-CN.yaml @@ -236,6 +236,10 @@ messages: @media application/json application/xml 指定了所有 api 默认的内容类型; @header name desc 公共报头; @cookie name desc 公共的 cookie; + @security tags name args 指定所有 API 的验证接口。 + tags 为支持的标签,多个标签用逗号分隔,可以采用 * 表示支持所有。 + name 为 @scy-* 系列定义的验证接口,args 是以空格分隔的参数列表。 + args 为传递给由 name 指定的接口参数,可以为空; @resp status media object.path desc 定义了全局返回对象,会被附加在每个 API 上。一般用于指定错误处理对象; @resp-header name header desc 指定了 @resp 中设定的返回对象的报头; @doc url 扩展文档; @@ -260,6 +264,10 @@ messages: @req media object.path desc 定义请求格式,media 如果采用全局指定的默认值,可以用 *,如果是多个值,用逗号分隔; @resp status media object.path desc 指定返回的数据类型; @resp-header status key desc 指定报头; + @security name args 指定所有 API 的验证接口。 + name 为 @scy-* 系列定义的验证接口,args 是以空格分隔的参数列表。 + args 为传递给由 name 指定的接口参数,可以为空。 + 如果 @security 之后为空,则表示当前接口不需要验证; ## callback name method url desc diff --git a/cmd/web/restdoc/parser/api.go b/cmd/web/restdoc/parser/api.go index 750f9e3a..0b81a430 100644 --- a/cmd/web/restdoc/parser/api.go +++ b/cmd/web/restdoc/parser/api.go @@ -35,6 +35,7 @@ func (p *Parser) parseAPI(ctx context.Context, t *openapi.OpenAPI, currPath, suf method, path, summary := words[0], words[1], words[2] opt := openapi3.NewOperation() opt.Summary = summary + opt.Security = openapi3.NewSecurityRequirements() resps := map[string]*openapi3.Response{} @@ -73,6 +74,8 @@ LOOP: if !p.parseResponseHeader(resps, suffix, filename, ln+index) { return } + case "@security": // @security name args + p.parseSecurity(opt, suffix) case "##": // 可能是 ## callback delta := p.parseCallback(ctx, t, opt, currPath, suffix, lines[index:], ln+index, filename) index += delta @@ -173,3 +176,21 @@ func (p *Parser) addCookieHeader(tag string, opt *openapi3.Operation, in, suffix s := openapi3.NewSchemaRef("", openapi3.NewStringSchema()) opt.AddParameter(&openapi3.Parameter{In: in, Schema: s, Name: words[0], Description: words[1]}) } + +// @security name args +func (p *Parser) parseSecurity(opt *openapi3.Operation, suffix string) { + words, l := utils.SplitSpaceN(suffix, 2) + var keys []string + switch l { + case 0: // 相当于取消全局定义的数据 + opt.Security.With(openapi3.NewSecurityRequirement()) + return + case 1: // 只有名称,没有参数 + case 2: + keys = strings.Fields(words[1]) + } + + req := openapi3.NewSecurityRequirement() + req[words[0]] = keys + opt.Security.With(req) +} diff --git a/cmd/web/restdoc/parser/parser_test.go b/cmd/web/restdoc/parser/parser_test.go index 109a710f..8e2a6851 100644 --- a/cmd/web/restdoc/parser/parser_test.go +++ b/cmd/web/restdoc/parser/parser_test.go @@ -47,7 +47,7 @@ func TestParser_Parse(t *testing.T) { a.NotNil(d). Length(l.Records[logs.LevelError], 0). Length(l.Records[logs.LevelWarn], 0). - Length(l.Records[logs.LevelInfo], 2) // scan dir/ add api 的提示 + Length(l.Records[logs.LevelInfo], 2) // scan dir / add api 的提示 a.NotNil(d.Doc().Info). Equal(d.Doc().Info.Version, "1.0.0") diff --git a/cmd/web/restdoc/parser/restdoc.go b/cmd/web/restdoc/parser/restdoc.go index 546f4454..08d38372 100644 --- a/cmd/web/restdoc/parser/restdoc.go +++ b/cmd/web/restdoc/parser/restdoc.go @@ -62,20 +62,16 @@ LOOP: if !p.isIgnoreTag(words[0]) { t.Doc().Tags = append(t.Doc().Tags, &openapi3.Tag{Name: words[0], Description: words[1]}) } - case "@server": // @server tag https://example.com *desc + case "@server": // @server tags https://example.com *desc words, l := utils.SplitSpaceN(suffix, 3) if l < 2 { p.syntaxError("@server", 2, filename, ln+i) continue LOOP } - tag := words[0] - if tag == "*" { - tag = "" - } - if tag != "" && p.isIgnoreTag(strings.Split(tag, ",")...) { - continue + + if tag := words[0]; tag == "" || tag == "*" || !p.isIgnoreTag(strings.Split(tag, ",")...) { + t.Doc().Servers = append(t.Doc().Servers, &openapi3.Server{URL: words[1], Description: words[2]}) } - t.Doc().Servers = append(t.Doc().Servers, &openapi3.Server{URL: words[1], Description: words[2]}) case "@license": // @license MIT *https://example.com/license words, l := utils.SplitSpaceN(suffix, 2) if l < 1 { @@ -122,6 +118,10 @@ LOOP: if !p.parseResponseHeader(p.resps, suffix, filename, ln+i) { continue LOOP } + case "@security": // @security tags securityName args1 args2 + if !p.parseCommonSecurity(t, suffix, filename, ln+i) { + continue + } case "@scy-http": // @scy-http name scheme format *desc words, l := utils.SplitSpaceN(suffix, 4) if l < 3 { @@ -257,6 +257,29 @@ LOOP: t.Doc().Info = info } +// @security tags name args +func (p *Parser) parseCommonSecurity(doc *openapi.OpenAPI, suffix, filename string, ln int) bool { + words, l := utils.SplitSpaceN(suffix, 3) + if l < 2 { + p.syntaxError("@security", 2, filename, ln) + return false + } + + if tag := words[0]; tag != "" && tag != "*" && p.isIgnoreTag(strings.Split(tag, ",")...) { + return true + } + + var keys []string + if l == 3 { + keys = strings.Fields(words[2]) + } + + req := openapi3.NewSecurityRequirement() + req[words[1]] = keys + doc.Doc().Security.With(req) + return true +} + // @version 1.0.0 // 直接指定版本号 // @version git // 采用 git 的版本号 // @version path/pkg.version // 采用指向的常量作为版本号 diff --git a/cmd/web/restdoc/parser/testdata/api.go b/cmd/web/restdoc/parser/testdata/api.go index 965194b5..fdd24524 100644 --- a/cmd/web/restdoc/parser/testdata/api.go +++ b/cmd/web/restdoc/parser/testdata/api.go @@ -16,6 +16,7 @@ package testdata // @resp-header 201 h2011 h1 desc // @resp-header 201 h2012 h2 desc // @resp 200 * resp desc +// @security oauth-code // // ## callback onData POST {$request.query.url} 回调1 // @req * req 登录的账号信息 diff --git a/cmd/web/restdoc/parser/testdata/testdata.go b/cmd/web/restdoc/parser/testdata/testdata.go index e1977789..b47e4bf9 100644 --- a/cmd/web/restdoc/parser/testdata/testdata.go +++ b/cmd/web/restdoc/parser/testdata/testdata.go @@ -23,6 +23,7 @@ // @openapi ./openapi.yaml // // @scy-code oauth-code https://example.com/auth https://example.com/token https://example.com/refresh read:info,write:info +// @security users,admin ouath-code // // # 其它文档说明 // diff --git a/cmd/web/restdoc/utils/utils.go b/cmd/web/restdoc/utils/utils.go index cbf9a84c..e80baeb9 100644 --- a/cmd/web/restdoc/utils/utils.go +++ b/cmd/web/restdoc/utils/utils.go @@ -31,7 +31,7 @@ func CutTag(line string) (tag, suffix string) { // SplitSpaceN 以空格分隔字符串 // // maxSize 表示最多分隔的数量,如果无法达到 maxSize 的数量,则采用空字符串代替剩余的元素, -// 返回值 length 表示实际的元素数量。-1 表示按实际的数量拆分,length 始终等于 len(ret)。 +// 返回值 length 表示返回数组中非空元素的数量。-1 表示按实际的数量拆分,此时 length 始终等于 len(ret)。 func SplitSpaceN(s string, maxSize int) (ret []string, length int) { if maxSize == 0 { panic("参数 maxSize 不能为 0")