diff --git a/main.go b/main.go index 2de31b0..6b578d4 100644 --- a/main.go +++ b/main.go @@ -34,6 +34,7 @@ func main() { proxy.InitHttpClient() object.InitSiteMap() object.InitRuleMap() + object.InitActionMap() run.InitAppMap() run.InitSelfStart() object.StartMonitorSitesLoop() diff --git a/object/action_cache.go b/object/action_cache.go index f7e86d4..4d738dd 100644 --- a/object/action_cache.go +++ b/object/action_cache.go @@ -44,7 +44,7 @@ func refreshActionMap() error { return nil } -func GetActionsByActionIds(ids []string) ([]*Action, error) { +func GetActionsByIds(ids []string) ([]*Action, error) { var res []*Action for _, id := range ids { action, ok := actionMap[id] @@ -55,3 +55,11 @@ func GetActionsByActionIds(ids []string) ([]*Action, error) { } return res, nil } + +func GetActionById(id string) (*Action, error) { + action, ok := actionMap[id] + if !ok { + return nil, fmt.Errorf("action: %s not found", id) + } + return action, nil +} diff --git a/rule/rule.go b/rule/rule.go index 0e48beb..58f942c 100644 --- a/rule/rule.go +++ b/rule/rule.go @@ -25,10 +25,11 @@ type Rule interface { checkRule(expressions []*object.Expression, req *http.Request) (bool, string, string, error) } -func CheckRules(ruleIds []string, r *http.Request) (string, string, error) { +func CheckRules(ruleIds []string, r *http.Request) (*object.Action, string, error) { + var actionObj *object.Action rules, err := object.GetRulesByRuleIds(ruleIds) if err != nil { - return "", "", err + return nil, "", err } for i, rule := range rules { var ruleObj Rule @@ -46,15 +47,36 @@ func CheckRules(ruleIds []string, r *http.Request) (string, string, error) { case "Compound": ruleObj = &CompoundRule{} default: - return "", "", fmt.Errorf("unknown rule type: %s for rule: %s", rule.Type, rule.GetId()) + return nil, "", fmt.Errorf("unknown rule type: %s for rule: %s", rule.Type, rule.GetId()) } isHit, action, reason, err := ruleObj.checkRule(rule.Expressions, r) if err != nil { - return "", "", err + return nil, "", err } if action == "" { - action = rule.Action + actionObj, err = object.GetActionById(rule.Action) + if err != nil { + return nil, "", err + } + action = actionObj.Type + } else { + switch action { + case "Block": + actionObj.Type = "Block" + actionObj.StatusCode = 403 + case "Drop": + actionObj.Type = "Drop" + actionObj.StatusCode = 400 + case "Allow": + actionObj.Type = "Allow" + actionObj.StatusCode = 200 + case "Captcha": + actionObj.Type = "Captcha" + actionObj.StatusCode = 302 + default: + return nil, "", fmt.Errorf("unknown rule action: %s for rule: %s", action, rule.GetId()) + } } if isHit { if action == "Block" || action == "Drop" { @@ -63,16 +85,17 @@ func CheckRules(ruleIds []string, r *http.Request) (string, string, error) { } else { reason = fmt.Sprintf("hit rule %s: %s", ruleIds[i], reason) } - return action, reason, nil + return actionObj, reason, nil } else if action == "Allow" { - return action, reason, nil + return actionObj, reason, nil } else if action == "Captcha" { - return action, reason, nil + return actionObj, reason, nil } else { - return "", "", fmt.Errorf("unknown rule action: %s for rule: %s", action, rule.GetId()) + return nil, "", fmt.Errorf("unknown rule action: %s for rule: %s", action, rule.GetId()) } } } - - return "", "", nil + actionObj.Type = "Allow" + actionObj.StatusCode = 200 + return actionObj, "", nil } diff --git a/rule/rule_compound.go b/rule/rule_compound.go index a6f51c2..34093e6 100644 --- a/rule/rule_compound.go +++ b/rule/rule_compound.go @@ -33,7 +33,7 @@ func (r *CompoundRule) checkRule(expressions []*object.Expression, req *http.Req if err != nil { return false, "", "", err } - if action == "" { + if action.Type == "" { isHit = false } switch expression.Operator { diff --git a/service/proxy.go b/service/proxy.go index ea3502e..9556438 100644 --- a/service/proxy.go +++ b/service/proxy.go @@ -205,15 +205,15 @@ func handleRequest(w http.ResponseWriter, r *http.Request) { reason = "the rule has been hit" } - switch action { + switch action.Type { case "", "Allow": - w.WriteHeader(http.StatusOK) + w.WriteHeader(action.StatusCode) case "Block": responseError(w, "Blocked by CasWAF: %s", reason) - w.WriteHeader(http.StatusForbidden) + w.WriteHeader(action.StatusCode) case "Drop": responseError(w, "Dropped by CasWAF: %s", reason) - w.WriteHeader(http.StatusBadRequest) + w.WriteHeader(action.StatusCode) case "Captcha": ok := isVerifiedSession(r) if ok { diff --git a/web/src/RuleEditPage.js b/web/src/RuleEditPage.js index 388a73e..014bcba 100644 --- a/web/src/RuleEditPage.js +++ b/web/src/RuleEditPage.js @@ -13,9 +13,10 @@ // limitations under the License. import React from "react"; -import {Button, Card, Col, Input, InputNumber, Row, Select} from "antd"; +import {Button, Card, Col, Input, Row, Select} from "antd"; import * as Setting from "./Setting"; import * as RuleBackend from "./backend/RuleBackend"; +import * as ActionBackend from "./backend/ActionBackend"; import i18next from "i18next"; import WafRuleTable from "./components/WafRuleTable"; import IpRuleTable from "./components/IpRuleTable"; @@ -33,11 +34,13 @@ class RuleEditPage extends React.Component { owner: props.match.params.owner, ruleName: props.match.params.ruleName, rule: null, + actions: [], }; } UNSAFE_componentWillMount() { this.getRule(); + this.getActions(); } getRule() { @@ -48,6 +51,14 @@ class RuleEditPage extends React.Component { }); } + getActions() { + ActionBackend.getActions(this.state.owner).then((res) => { + this.setState({ + actions: res.data, + }); + }); + } + updateRuleField(key, value) { const rule = Setting.deepCopy(this.state.rule); rule[key] = value; @@ -172,17 +183,14 @@ class RuleEditPage extends React.Component { {i18next.t("general:Action")}: