Skip to content

Commit

Permalink
add default deny message (#1347)
Browse files Browse the repository at this point in the history
Co-authored-by: Kent Dong <[email protected]>
  • Loading branch information
rinfx and CH3CHO authored Sep 27, 2024
1 parent ea99159 commit 1b119ed
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 27 deletions.
17 changes: 16 additions & 1 deletion plugins/wasm-go/extensions/ai-security-guard/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,21 @@ description: 阿里云内容安全检测
| `responseContentJsonPath` | string | optional | `choices.0.message.content` | 指定要检测内容在响应body中的jsonpath |
| `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | 指定要检测内容在流式响应body中的jsonpath |
| `denyCode` | int | optional | 200 | 指定内容非法时的响应状态码 |
| `denyMessage` | string | optional | openai格式的流失/非流式响应,回答内容为阿里云内容安全的建议回答 | 指定内容非法时的响应内容 |
| `denyMessage` | string | optional | openai格式的流式/非流式响应 | 指定内容非法时的响应内容 |
| `protocol` | string | optional | openai | 协议格式,非openai协议填`original` |

补充说明一下 `denyMessage`,对于openai格式的请求,对非法请求的处理逻辑为:
- 如果配置了 `denyMessage`
- 优先返回阿里云内容安全的建议回答,格式为openai格式的流式/非流式响应
- 如果阿里云内容安全未返回建议的回答,返回内容为 `denyMessage` 配置内容,格式为openai格式的流式/非流式响应
- 如果没有配置 `denyMessage`
- 优先返回阿里云内容安全的建议回答,格式为openai格式的流式/非流式响应
- 如果阿里云内容安全未返回建议的回答,返回内容为内置的兜底回答,内容为`"很抱歉,我无法回答您的问题"`,格式为openai格式的流式/非流式响应

如果用户使用了非openai格式的协议,应当配置 `denyMessage`,此时对非法请求的处理逻辑为:
- 返回用户配置的 `denyMessage` 内容,用户可以配置其为序列化后的json字符串,以保持与正常请求接口返回格式的一致性
- 如果 `denyMessage` 为空,优先返回阿里云内容安全的建议回答,格式为纯文本
- 如果阿里云内容安全未返回建议回答,返回内置的兜底回答,内容为`"很抱歉,我无法回答您的问题"`,格式为纯文本

## 配置示例
### 前提条件
Expand Down Expand Up @@ -90,6 +104,7 @@ requestContentJsonPath: "input.prompt"
responseContentJsonPath: "output.text"
denyCode: 200
denyMessage: "很抱歉,我无法回答您的问题"
protocol: original
```
## 可观测
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-security-guard/README_EN.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ Plugin Priority: `300`
| `responseContentJsonPath` | string | optional | `choices.0.message.content` | Specify the jsonpath of the content to be detected in the response body |
| `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | Specify the jsonpath of the content to be detected in the streaming response body |
| `denyCode` | int | optional | 200 | Response status code when the specified content is illegal |
| `denyMessage` | string | optional | Drainage/non-streaming response in openai format, the answer content is the suggested answer from Alibaba Cloud content security
| Response content when the specified content is illegal |
| `denyMessage` | string | optional | Drainage/non-streaming response in openai format, the answer content is the suggested answer from Alibaba Cloud content security | Response content when the specified content is illegal |
| `protocol` | string | optional | openai | protocol format, `openai` or `original` |


## Examples of configuration
Expand Down
95 changes: 71 additions & 24 deletions plugins/wasm-go/extensions/ai-security-guard/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ const (
DefaultResponseJsonPath = "choices.0.message.content"
DefaultStreamingResponseJsonPath = "choices.0.delta.content"
DefaultDenyCode = 200
DefaultDenyMessage = "很抱歉,我无法回答您的问题"

AliyunUserAgent = "CIPFrom/AIGateway"
)
Expand All @@ -64,6 +65,7 @@ type AISecurityConfig struct {
responseStreamContentJsonPath string
denyCode int64
denyMessage string
protocolOriginal bool
metrics map[string]proxywasm.MetricCounter
}

Expand Down Expand Up @@ -129,6 +131,7 @@ func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) e
}
config.checkRequest = json.Get("checkRequest").Bool()
config.checkResponse = json.Get("checkResponse").Bool()
config.protocolOriginal = json.Get("protocol").String() == "original"
config.denyMessage = json.Get("denyMessage").String()
if obj := json.Get("denyCode"); obj.Exists() {
config.denyCode = obj.Int()
Expand Down Expand Up @@ -218,29 +221,51 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
reqParams.Add("Signature", signature)
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode != 200 {
log.Error(string(responseBody))
proxywasm.ResumeHttpRequest()
return
}
respData := gjson.GetBytes(responseBody, "Data")
if respData.Exists() {
respAdvice := respData.Get("Advice")
respResult := respData.Get("Result")
if respAdvice.Exists() {
var denyMessage string
if config.protocolOriginal {
// not openai
if config.denyMessage != "" {
denyMessage = config.denyMessage
} else if respAdvice.Exists() {
denyMessage = respAdvice.Array()[0].Get("Answer").Raw
} else {
denyMessage = DefaultDenyMessage
}
} else {
// openai
if respAdvice.Exists() {
denyMessage = respAdvice.Array()[0].Get("Answer").Raw
} else if config.denyMessage != "" {
denyMessage = config.denyMessage
} else {
denyMessage = DefaultDenyMessage
}
}
if respResult.Array()[0].Get("Label").String() != "nonLabel" {
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request"))
config.incrementCounter("ai_sec_request_deny", 1)
if config.denyMessage != "" {
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(config.denyMessage), -1)
if config.protocolOriginal {
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(denyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
answer := respAdvice.Array()[0].Get("Answer").Raw
if gjson.GetBytes(body, "stream").Bool() {
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, answer))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, answer))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, denyMessage))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
} else {
proxywasm.ResumeHttpRequest()
}
Expand Down Expand Up @@ -330,22 +355,44 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
defer proxywasm.ResumeHttpResponse()
if statusCode != 200 {
log.Error(string(responseBody))
return
}
respData := gjson.GetBytes(responseBody, "Data")
if respData.Exists() {
respAdvice := respData.Get("Advice")
respResult := respData.Get("Result")
if respAdvice.Exists() {
var jsonData []byte
var denyMessage string
if config.protocolOriginal {
// not openai
if config.denyMessage != "" {
jsonData = []byte(config.denyMessage)
denyMessage = config.denyMessage
} else if respAdvice.Exists() {
denyMessage = respAdvice.Array()[0].Get("Answer").Raw
} else {
denyMessage = DefaultDenyMessage
}
} else {
// openai
if respAdvice.Exists() {
denyMessage = respAdvice.Array()[0].Get("Answer").Raw
} else if config.denyMessage != "" {
denyMessage = config.denyMessage
} else {
denyMessage = DefaultDenyMessage
}
}
if respResult.Array()[0].Get("Label").String() != "nonLabel" {
var jsonData []byte
if config.protocolOriginal {
jsonData = []byte(denyMessage)
} else if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String(), randomID, model))
} else {
if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String()))
} else {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String()))
}
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String()))
}
delete(hdsMap, "content-length")
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
Expand Down

0 comments on commit 1b119ed

Please sign in to comment.