Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(middleware): v4 experimental middleware #986

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
6 changes: 6 additions & 0 deletions experimental/plugins/plugintypes/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
package plugintypes

import (
"context"

"github.com/corazawaf/coraza/v3/collection"
"github.com/corazawaf/coraza/v3/debuglog"
"github.com/corazawaf/coraza/v3/types"
Expand Down Expand Up @@ -34,7 +36,11 @@ type TransactionState interface {
// CaptureField captures a field.
CaptureField(idx int, value string)

// LastPhase that was evaluated
LastPhase() types.RulePhase

// Context returns the context of the transaction.
Context() context.Context
}

// TransactionVariables has pointers to all the variables of the transaction
Expand Down
112 changes: 107 additions & 5 deletions http/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/corazawaf/coraza/v3"
"github.com/corazawaf/coraza/v3/experimental"
"github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
"github.com/corazawaf/coraza/v3/types"
)

Expand Down Expand Up @@ -96,16 +97,115 @@ func processRequest(tx types.Transaction, req *http.Request) (*types.Interruptio
return tx.ProcessRequestBody()
}

var (
forbiddenMessage []byte = []byte("Forbidden")
errorMessage []byte = []byte("Internal Server Error")
)

// Options represents the options for the experimental middleware
type Options struct {
// OnInterruption is a function that will be called when an interruption is triggered
// This function will render the error page and write the response
OnInterruption func(types.Interruption, http.ResponseWriter, *http.Request)

// OnError is a function that will be called when an error is triggered
// This function will render the error page and write the response
OnError func(error, http.ResponseWriter, *http.Request)

// BeforeClose is a function that will be called before the transaction is closed
// If this function is overwritten tx.ProcessLogging() has to be called manually
// It is useful to complement observability signals like metrics, traces and logs
// by providing additional context about the transaction and the rules that were matched.
BeforeClose func(types.Transaction, *http.Request)

// OnTransactionStarted is called when a new transaction is started. It is useful to
// complement observability signals like metrics, traces and logs by providing additional
// context about the transaction.
OnTransactionStarted func(tx plugintypes.TransactionState)

// ProcessResponse enables the processing of the response
// If the response is not processed, the middleware will only consume
// request headers and request body, also, response will have to be
// processed by the next handler.
ProcessResponse bool

// WAF represents the WAF instance to use
// New transactions will be created using this WAF instance
WAF coraza.WAF

// SamplingRate represents the rate of sampling for the middleware
// If the rate is 0, the middleware will not sample
// If the rate is 100, the middleware will sample all requests
SamplingRate int
}

var defaultOptions = Options{
OnInterruption: func(i types.Interruption, w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
w.Write(forbiddenMessage) //nolint:errcheck
},
OnError: func(e error, w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write(errorMessage) //nolint:errcheck
// TODO generate log?
},
BeforeClose: func(tx types.Transaction, r *http.Request) {
tx.ProcessLogging()
},
OnTransactionStarted: func(tx plugintypes.TransactionState) {
// Nothing to do here
},
ProcessResponse: false,
SamplingRate: 0,
}

func (o *Options) loadDefaults() {
if o.OnInterruption == nil {
o.OnInterruption = defaultOptions.OnInterruption
}

if o.OnError == nil {
o.OnError = defaultOptions.OnError
}

if o.BeforeClose == nil {
o.BeforeClose = defaultOptions.BeforeClose
}

if o.OnTransactionStarted == nil {
o.OnTransactionStarted = defaultOptions.OnTransactionStarted
}

}

// DefaultOptions returns the default options for the middleware
func DefaultOptions(waf coraza.WAF) Options {
opts := Options{
WAF: waf,
}
opts.loadDefaults()
return opts
}

func WrapHandler(waf coraza.WAF, h http.Handler) http.Handler {
if waf == nil {
return wrapHandler(h, DefaultOptions(waf))
}

func WrapHandlerWithOptions(h http.Handler, opts Options) http.Handler {
opts.loadDefaults()
return wrapHandler(h, opts)
}

func wrapHandler(h http.Handler, opts Options) http.Handler {
if opts.WAF == nil {
return h
}

newTX := func(*http.Request) types.Transaction {
return waf.NewTransaction()
return opts.WAF.NewTransaction()
}

if ctxwaf, ok := waf.(experimental.WAFWithOptions); ok {
if ctxwaf, ok := opts.WAF.(experimental.WAFWithOptions); ok {
newTX = func(r *http.Request) types.Transaction {
return ctxwaf.NewTransactionWithOptions(experimental.Options{
Context: r.Context(),
Expand All @@ -115,9 +215,11 @@ func WrapHandler(waf coraza.WAF, h http.Handler) http.Handler {

fn := func(w http.ResponseWriter, r *http.Request) {
tx := newTX(r)
txs := tx.(plugintypes.TransactionState)
opts.OnTransactionStarted(txs)
defer func() {
// We run phase 5 rules and create audit logs (if enabled)
tx.ProcessLogging()
// BeforeClose should call tx.ProcessLogging() for phase 5 processing
opts.BeforeClose(tx, r)
// we remove temporary files and free some memory
if err := tx.Close(); err != nil {
tx.DebugLogger().Error().Err(err).Msg("Failed to close the transaction")
Expand Down
24 changes: 24 additions & 0 deletions http/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package http
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"mime/multipart"
Expand All @@ -23,6 +24,7 @@ import (
"github.com/corazawaf/coraza/v3"
"github.com/corazawaf/coraza/v3/debuglog"
"github.com/corazawaf/coraza/v3/experimental/plugins/macro"
"github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
"github.com/corazawaf/coraza/v3/internal/corazawaf"
"github.com/corazawaf/coraza/v3/internal/seclang"
"github.com/corazawaf/coraza/v3/types"
Expand Down Expand Up @@ -213,6 +215,7 @@ func TestChainEvaluation(t *testing.T) {
}

func errLogger(t *testing.T) func(rule types.MatchedRule) {
t.Helper()
return func(rule types.MatchedRule) {
t.Log(rule.ErrorLog())
}
Expand Down Expand Up @@ -643,3 +646,24 @@ func TestHandlerAPI(t *testing.T) {
})
}
}

type ctxKey struct{}

func TestWrapHandlerWithOptions(t *testing.T) {
waf, _ := coraza.NewWAF(coraza.NewWAFConfig())
delegateHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})

ctx := context.WithValue(context.Background(), ctxKey{}, "value")
req, _ := http.NewRequestWithContext(ctx, "GET", "https://www.coraza.io/test", nil)

wrappedHandler := WrapHandlerWithOptions(delegateHandler, Options{
WAF: waf,
OnTransactionStarted: func(tx plugintypes.TransactionState) {
if want, have := "value", tx.Context().Value(ctxKey{}).(string); want != have {
t.Errorf("unexpected context value, want: %s, have: %s", want, have)
}
},
}).(http.HandlerFunc)

wrappedHandler(httptest.NewRecorder(), req)
}
4 changes: 4 additions & 0 deletions internal/corazawaf/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ func (tx *Transaction) ID() string {
return tx.id
}

func (tx *Transaction) Context() context.Context {
return tx.context
}

func (tx *Transaction) Variables() plugintypes.TransactionVariables {
return &tx.variables
}
Expand Down
Loading