Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nsqd: support POST auth #1487

Merged
merged 1 commit into from
May 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions apps/nsqd/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func nsqdFlagSet(opts *nsqd.Options) *flag.FlagSet {

authHTTPAddresses := app.StringArray{}
flagSet.Var(&authHTTPAddresses, "auth-http-address", "<addr>:<port> or a full url to query auth server (may be given multiple times)")
flagSet.String("auth-http-request-method", opts.AuthHTTPRequestMethod, "HTTP method to use for auth server requests")
danrjohnson marked this conversation as resolved.
Show resolved Hide resolved
flagSet.String("broadcast-address", opts.BroadcastAddress, "address that will be registered with lookupd (defaults to the OS hostname)")
flagSet.Int("broadcast-tcp-port", opts.BroadcastTCPPort, "TCP port that will be registered with lookupd (defaults to the TCP port that this nsqd is listening on)")
flagSet.Int("broadcast-http-port", opts.BroadcastHTTPPort, "HTTP port that will be registered with lookupd (defaults to the HTTP port that this nsqd is listening on)")
Expand Down
23 changes: 15 additions & 8 deletions internal/auth/authorizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,13 @@ func (a *State) IsExpired() bool {
}

func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestMethod string) (*State, error) {
var retErr error
start := rand.Int()
n := len(authd)
for i := 0; i < n; i++ {
a := authd[(i+start)%n]
authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout)
authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout, httpRequestMethod)
if err != nil {
es := fmt.Sprintf("failed to auth against %s - %s", a, err)
if retErr != nil {
Expand All @@ -97,7 +97,8 @@ func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName
}

func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestMethod string) (*State, error) {
var authState State
v := url.Values{}
v.Set("remote_ip", remoteIP)
if tlsEnabled {
Expand All @@ -110,15 +111,21 @@ func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName strin

var endpoint string
if strings.Contains(authd, "://") {
endpoint = fmt.Sprintf("%s?%s", authd, v.Encode())
endpoint = authd
} else {
endpoint = fmt.Sprintf("http://%s/auth?%s", authd, v.Encode())
endpoint = fmt.Sprintf("http://%s/auth", authd)
}

var authState State
client := http_api.NewClient(clientTLSConfig, connectTimeout, requestTimeout)
if err := client.GETV1(endpoint, &authState); err != nil {
return nil, err
if httpRequestMethod == "post" {
if err := client.POSTV1(endpoint, v, &authState); err != nil {
return nil, err
}
} else {
endpoint = fmt.Sprintf("%s?%s", endpoint, v.Encode())
if err := client.GETV1(endpoint, &authState); err != nil {
return nil, err
}
}

// validation on response
Expand Down
4 changes: 2 additions & 2 deletions internal/clusterinfo/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,7 @@ func (c *ClusterInfo) nsqlookupdPOST(addrs []string, uri string, qs string) erro
for _, addr := range addrs {
endpoint := fmt.Sprintf("http://%s/%s?%s", addr, uri, qs)
c.logf("CI: querying nsqlookupd %s", endpoint)
err := c.client.POSTV1(endpoint)
err := c.client.POSTV1(endpoint, nil, nil)
if err != nil {
errs = append(errs, err)
}
Expand All @@ -894,7 +894,7 @@ func (c *ClusterInfo) producersPOST(pl Producers, uri string, qs string) error {
for _, p := range pl {
endpoint := fmt.Sprintf("http://%s/%s?%s", p.HTTPAddress(), uri, qs)
c.logf("CI: querying nsqd %s", endpoint)
err := c.client.POSTV1(endpoint)
err := c.client.POSTV1(endpoint, nil, nil)
if err != nil {
errs = append(errs, err)
}
Expand Down
21 changes: 19 additions & 2 deletions internal/http_api/api_request.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package http_api

import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -86,14 +87,26 @@ retry:

// PostV1 is a helper function to perform a V1 HTTP request
// and parse our NSQ daemon's expected response format, with deadlines.
func (c *Client) POSTV1(endpoint string) error {
func (c *Client) POSTV1(endpoint string, data url.Values, v interface{}) error {
retry:
req, err := http.NewRequest("POST", endpoint, nil)
var reqBody io.Reader
if data != nil {
js, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("failed to marshal POST data to endpoint: %v", endpoint)
}
reqBody = bytes.NewBuffer(js)
}

req, err := http.NewRequest("POST", endpoint, reqBody)
if err != nil {
return err
}

req.Header.Add("Accept", "application/vnd.nsq; version=1.0")
if reqBody != nil {
req.Header.Add("Content-Type", "application/json")
}

resp, err := c.c.Do(req)
if err != nil {
Expand All @@ -116,6 +129,10 @@ retry:
return fmt.Errorf("got response %s %q", resp.Status, body)
}

if v != nil {
return json.Unmarshal(body, &v)
}

return nil
}

Expand Down
4 changes: 3 additions & 1 deletion nsqd/client_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,9 @@ func (c *clientV2) QueryAuthd() error {
remoteIP, tlsEnabled, commonName, c.AuthSecret,
c.nsqd.clientTLSConfig,
c.nsqd.getOpts().HTTPClientConnectTimeout,
c.nsqd.getOpts().HTTPClientRequestTimeout)
c.nsqd.getOpts().HTTPClientRequestTimeout,
c.nsqd.getOpts().AuthHTTPRequestMethod,
)
if err != nil {
return err
}
Expand Down
4 changes: 4 additions & 0 deletions nsqd/nsqd.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ func New(opts *Options) (*NSQD, error) {
}
n.clientTLSConfig = clientTLSConfig

if opts.AuthHTTPRequestMethod != "post" && opts.AuthHTTPRequestMethod != "get" {
return nil, errors.New("--auth-http-request-method must be post or get")
}

for _, v := range opts.E2EProcessingLatencyPercentiles {
if v <= 0 || v > 1 {
return nil, fmt.Errorf("invalid E2E processing latency percentile: %v", v)
Expand Down
6 changes: 3 additions & 3 deletions nsqd/nsqd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,11 @@ func TestCluster(t *testing.T) {
test.Nil(t, err)

url := fmt.Sprintf("http://%s/topic/create?topic=%s", nsqd.RealHTTPAddr(), topicName)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil)
test.Nil(t, err)

url = fmt.Sprintf("http://%s/channel/create?topic=%s&channel=ch", nsqd.RealHTTPAddr(), topicName)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil)
test.Nil(t, err)

// allow some time for nsqd to push info to nsqlookupd
Expand Down Expand Up @@ -394,7 +394,7 @@ func TestCluster(t *testing.T) {
test.Equal(t, "ch", lr.Channels[0])

url = fmt.Sprintf("http://%s/topic/delete?topic=%s", nsqd.RealHTTPAddr(), topicName)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil)
test.Nil(t, err)

// allow some time for nsqd to push info to nsqlookupd
Expand Down
2 changes: 2 additions & 0 deletions nsqd/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type Options struct {
BroadcastHTTPPort int `flag:"broadcast-http-port"`
NSQLookupdTCPAddresses []string `flag:"lookupd-tcp-address" cfg:"nsqlookupd_tcp_addresses"`
AuthHTTPAddresses []string `flag:"auth-http-address" cfg:"auth_http_addresses"`
AuthHTTPRequestMethod string `flag:"auth-http-request-method" cfg:"auth_http_request_method"`
HTTPClientConnectTimeout time.Duration `flag:"http-client-connect-timeout" cfg:"http_client_connect_timeout"`
HTTPClientRequestTimeout time.Duration `flag:"http-client-request-timeout" cfg:"http_client_request_timeout"`

Expand Down Expand Up @@ -110,6 +111,7 @@ func NewOptions() *Options {

NSQLookupdTCPAddresses: make([]string, 0),
AuthHTTPAddresses: make([]string, 0),
AuthHTTPRequestMethod: "get",

HTTPClientConnectTimeout: 2 * time.Second,
HTTPClientRequestTimeout: 5 * time.Second,
Expand Down
38 changes: 29 additions & 9 deletions nsqd/protocol_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"os"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -1476,24 +1477,30 @@ func TestClientAuth(t *testing.T) {
authSuccess := ""
tlsEnabled := false
commonName := ""
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)
httpAuthRequestMethod := "get"
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

// now one that will succeed
authResponse = `{"ttl":10, "authorizations":
[{"topic":"test", "channels":[".*"], "permissions":["subscribe","publish"]}]
}`
authError = ""
authSuccess = `{"identity":"","identity_url":"","permission_count":1}`
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

// one with TLS enabled
tlsEnabled = true
commonName = "test.local"
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

// test POST based authentication
httpAuthRequestMethod = "post"
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

}

func runAuthTest(t *testing.T, authResponse string, authSecret string, authError string,
authSuccess string, tlsEnabled bool, commonName string) {
authSuccess string, tlsEnabled bool, commonName string, httpAuthRequestMethod string) {
var err error
var expectedRemoteIP string
expectedTLS := "false"
Expand All @@ -1503,11 +1510,23 @@ func runAuthTest(t *testing.T, authResponse string, authSecret string, authError

authd := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("in test auth handler %s", r.RequestURI)
r.ParseForm()
test.Equal(t, expectedRemoteIP, r.Form.Get("remote_ip"))
test.Equal(t, expectedTLS, r.Form.Get("tls"))
test.Equal(t, commonName, r.Form.Get("common_name"))
test.Equal(t, authSecret, r.Form.Get("secret"))
test.Equal(t, httpAuthRequestMethod, strings.ToLower(r.Method))

var values url.Values

if r.Method == "POST" {
err = json.NewDecoder(r.Body).Decode(&values)
if err != nil {
t.Error(err)
}
} else {
r.ParseForm()
values = r.Form
}
test.Equal(t, expectedRemoteIP, values.Get("remote_ip"))
test.Equal(t, expectedTLS, values.Get("tls"))
test.Equal(t, commonName, values.Get("common_name"))
test.Equal(t, authSecret, values.Get("secret"))
fmt.Fprint(w, authResponse)
}))
defer authd.Close()
Expand All @@ -1519,6 +1538,7 @@ func runAuthTest(t *testing.T, authResponse string, authSecret string, authError
opts.Logger = test.NewTestLogger(t)
opts.LogLevel = LOG_DEBUG
opts.AuthHTTPAddresses = []string{addr.Host}
opts.AuthHTTPRequestMethod = httpAuthRequestMethod
if tlsEnabled {
opts.TLSCert = "./test/certs/server.pem"
opts.TLSKey = "./test/certs/server.key"
Expand Down
6 changes: 3 additions & 3 deletions nsqlookupd/nsqlookupd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func TestTombstoneRecover(t *testing.T) {

endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d",
httpAddr, topicName, HostAddr, HTTPPort)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil)
test.Nil(t, err)

pr := ProducersDoc{}
Expand Down Expand Up @@ -263,7 +263,7 @@ func TestTombstoneUnregister(t *testing.T) {

endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d",
httpAddr, topicName, HostAddr, HTTPPort)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil)
test.Nil(t, err)

pr := ProducersDoc{}
Expand Down Expand Up @@ -348,7 +348,7 @@ func TestTombstonedNodes(t *testing.T) {

endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d",
httpAddr, topicName, HostAddr, HTTPPort)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil)
test.Nil(t, err)

producers, _ = ci.GetLookupdProducers(lookupdHTTPAddrs)
Expand Down
Loading