Skip to content

Commit

Permalink
nsqd: support for multiple Auth HTTP Methods
Browse files Browse the repository at this point in the history
Adds simple config option and flag to allow for auth to occur via POST
request in addition to GET. Rationale: Errors from net/http requests are
bubbled to nsqd when there is an error during authentication, such as if
the nsq authentication server is unavailable. These errors include the
full path, including any GET parameter, thus causing the authentication
secret to be logged. This does not occur by default for the POST body
thus helping protect secrets in transit between nsqd and the
authentication server.
  • Loading branch information
danrjohnson authored and mreiferson committed May 12, 2024
1 parent 62fa868 commit f162880
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 18 deletions.
1 change: 1 addition & 0 deletions apps/nsqd/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func nsqdFlagSet(opts *nsqd.Options) *flag.FlagSet {

authHTTPAddresses := app.StringArray{}
flagSet.Var(&authHTTPAddresses, "auth-http-address", "<addr>:<port> or a full url to query auth server (may be given multiple times)")
flagSet.String("auth-http-request-method", opts.AuthHTTPRequestMethod, "HTTP method to use for auth server requests")
flagSet.String("broadcast-address", opts.BroadcastAddress, "address that will be registered with lookupd (defaults to the OS hostname)")
flagSet.Int("broadcast-tcp-port", opts.BroadcastTCPPort, "TCP port that will be registered with lookupd (defaults to the TCP port that this nsqd is listening on)")
flagSet.Int("broadcast-http-port", opts.BroadcastHTTPPort, "HTTP port that will be registered with lookupd (defaults to the HTTP port that this nsqd is listening on)")
Expand Down
23 changes: 15 additions & 8 deletions internal/auth/authorizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,13 @@ func (a *State) IsExpired() bool {
}

func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestMethod string) (*State, error) {
var retErr error
start := rand.Int()
n := len(authd)
for i := 0; i < n; i++ {
a := authd[(i+start)%n]
authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout)
authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout, httpRequestMethod)
if err != nil {
es := fmt.Sprintf("failed to auth against %s - %s", a, err)
if retErr != nil {
Expand All @@ -97,7 +97,8 @@ func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName
}

func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestMethod string) (*State, error) {
var authState State
v := url.Values{}
v.Set("remote_ip", remoteIP)
if tlsEnabled {
Expand All @@ -110,15 +111,21 @@ func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName strin

var endpoint string
if strings.Contains(authd, "://") {
endpoint = fmt.Sprintf("%s?%s", authd, v.Encode())
endpoint = authd
} else {
endpoint = fmt.Sprintf("http://%s/auth?%s", authd, v.Encode())
endpoint = fmt.Sprintf("http://%s/auth", authd)
}

var authState State
client := http_api.NewClient(clientTLSConfig, connectTimeout, requestTimeout)
if err := client.GETV1(endpoint, &authState); err != nil {
return nil, err
if httpRequestMethod == "post" {
if err := client.POSTV1(endpoint, v, &authState); err != nil {
return nil, err
}
} else {
endpoint = fmt.Sprintf("%s?%s", endpoint, v.Encode())
if err := client.GETV1(endpoint, &authState); err != nil {
return nil, err
}
}

// validation on response
Expand Down
43 changes: 43 additions & 0 deletions internal/http_api/api_request.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package http_api

import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -119,6 +120,48 @@ retry:
return nil
}

func (c *Client) POSTV1(endpoint string, data url.Values, v interface{}) error {

Check failure on line 123 in internal/http_api/api_request.go

View workflow job for this annotation

GitHub Actions / test (1.18.x, amd64)

Client.POSTV1 redeclared in this block

Check failure on line 123 in internal/http_api/api_request.go

View workflow job for this annotation

GitHub Actions / test (1.19.x, amd64)

Client.POSTV1 redeclared in this block

Check failure on line 123 in internal/http_api/api_request.go

View workflow job for this annotation

GitHub Actions / test (1.19.x, 386)

Client.POSTV1 redeclared in this block

Check failure on line 123 in internal/http_api/api_request.go

View workflow job for this annotation

GitHub Actions / test (1.20.x, amd64)

method Client.POSTV1 already declared at internal/http_api/api_request.go:90:18

Check failure on line 123 in internal/http_api/api_request.go

View workflow job for this annotation

GitHub Actions / test (1.18.x, 386)

Client.POSTV1 redeclared in this block

Check failure on line 123 in internal/http_api/api_request.go

View workflow job for this annotation

GitHub Actions / test (1.20.x, 386)

method Client.POSTV1 already declared at internal/http_api/api_request.go:90:18

Check failure on line 123 in internal/http_api/api_request.go

View workflow job for this annotation

GitHub Actions / staticcheck

method Client.POSTV1 already declared at internal/http_api/api_request.go:90:18 (compile)
retry:
reqBody, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("failed to marshal POST data to endpoint: %v", endpoint)
}
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(reqBody))
if err != nil {
return err
}

req.Header.Add("Accept", "application/vnd.nsq; version=1.0")
req.Header.Add("Content-Type", "application/json")

resp, err := c.c.Do(req)
if err != nil {
return err
}

body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return err
}
if resp.StatusCode != 200 {
if resp.StatusCode == 403 && !strings.HasPrefix(endpoint, "https") {
endpoint, err = httpsEndpoint(endpoint, body)
if err != nil {
return err
}
goto retry
}
return fmt.Errorf("got response %s %q", resp.Status, body)
}
err = json.Unmarshal(body, &v)
if err != nil {
return err
}

return nil
}

func httpsEndpoint(endpoint string, body []byte) (string, error) {
var forbiddenResp struct {
HTTPSPort int `json:"https_port"`
Expand Down
4 changes: 3 additions & 1 deletion nsqd/client_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,9 @@ func (c *clientV2) QueryAuthd() error {
remoteIP, tlsEnabled, commonName, c.AuthSecret,
c.nsqd.clientTLSConfig,
c.nsqd.getOpts().HTTPClientConnectTimeout,
c.nsqd.getOpts().HTTPClientRequestTimeout)
c.nsqd.getOpts().HTTPClientRequestTimeout,
c.nsqd.getOpts().AuthHTTPRequestMethod,
)
if err != nil {
return err
}
Expand Down
4 changes: 4 additions & 0 deletions nsqd/nsqd.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ func New(opts *Options) (*NSQD, error) {
}
n.clientTLSConfig = clientTLSConfig

if opts.AuthHTTPRequestMethod != "post" && opts.AuthHTTPRequestMethod != "get" {
return nil, errors.New("--auth-http-request-method must be post or get")
}

for _, v := range opts.E2EProcessingLatencyPercentiles {
if v <= 0 || v > 1 {
return nil, fmt.Errorf("invalid E2E processing latency percentile: %v", v)
Expand Down
2 changes: 2 additions & 0 deletions nsqd/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type Options struct {
BroadcastHTTPPort int `flag:"broadcast-http-port"`
NSQLookupdTCPAddresses []string `flag:"lookupd-tcp-address" cfg:"nsqlookupd_tcp_addresses"`
AuthHTTPAddresses []string `flag:"auth-http-address" cfg:"auth_http_addresses"`
AuthHTTPRequestMethod string `flag:"auth-http-request-method" cfg:"auth_http_request_method"`
HTTPClientConnectTimeout time.Duration `flag:"http-client-connect-timeout" cfg:"http_client_connect_timeout"`
HTTPClientRequestTimeout time.Duration `flag:"http-client-request-timeout" cfg:"http_client_request_timeout"`

Expand Down Expand Up @@ -110,6 +111,7 @@ func NewOptions() *Options {

NSQLookupdTCPAddresses: make([]string, 0),
AuthHTTPAddresses: make([]string, 0),
AuthHTTPRequestMethod: "get",

HTTPClientConnectTimeout: 2 * time.Second,
HTTPClientRequestTimeout: 5 * time.Second,
Expand Down
38 changes: 29 additions & 9 deletions nsqd/protocol_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"os"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -1476,24 +1477,30 @@ func TestClientAuth(t *testing.T) {
authSuccess := ""
tlsEnabled := false
commonName := ""
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)
httpAuthRequestMethod := "get"
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

// now one that will succeed
authResponse = `{"ttl":10, "authorizations":
[{"topic":"test", "channels":[".*"], "permissions":["subscribe","publish"]}]
}`
authError = ""
authSuccess = `{"identity":"","identity_url":"","permission_count":1}`
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

// one with TLS enabled
tlsEnabled = true
commonName = "test.local"
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

// test POST based authentication
httpAuthRequestMethod = "post"
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

}

func runAuthTest(t *testing.T, authResponse string, authSecret string, authError string,
authSuccess string, tlsEnabled bool, commonName string) {
authSuccess string, tlsEnabled bool, commonName string, httpAuthRequestMethod string) {
var err error
var expectedRemoteIP string
expectedTLS := "false"
Expand All @@ -1503,11 +1510,23 @@ func runAuthTest(t *testing.T, authResponse string, authSecret string, authError

authd := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("in test auth handler %s", r.RequestURI)
r.ParseForm()
test.Equal(t, expectedRemoteIP, r.Form.Get("remote_ip"))
test.Equal(t, expectedTLS, r.Form.Get("tls"))
test.Equal(t, commonName, r.Form.Get("common_name"))
test.Equal(t, authSecret, r.Form.Get("secret"))
test.Equal(t, httpAuthRequestMethod, strings.ToLower(r.Method))

var values url.Values

if r.Method == "POST" {
err = json.NewDecoder(r.Body).Decode(&values)
if err != nil {
t.Error(err)
}
} else {
r.ParseForm()
values = r.Form
}
test.Equal(t, expectedRemoteIP, values.Get("remote_ip"))
test.Equal(t, expectedTLS, values.Get("tls"))
test.Equal(t, commonName, values.Get("common_name"))
test.Equal(t, authSecret, values.Get("secret"))
fmt.Fprint(w, authResponse)
}))
defer authd.Close()
Expand All @@ -1519,6 +1538,7 @@ func runAuthTest(t *testing.T, authResponse string, authSecret string, authError
opts.Logger = test.NewTestLogger(t)
opts.LogLevel = LOG_DEBUG
opts.AuthHTTPAddresses = []string{addr.Host}
opts.AuthHTTPRequestMethod = httpAuthRequestMethod
if tlsEnabled {
opts.TLSCert = "./test/certs/server.pem"
opts.TLSKey = "./test/certs/server.key"
Expand Down

0 comments on commit f162880

Please sign in to comment.