Skip to content

Commit

Permalink
TT-12057 Add Proxy support for Splunk (TykTechnologies#820)
Browse files Browse the repository at this point in the history
* Add Proxy support for Splunk

* Add tests to check for proxy connection

* update deprecated methods and run splunk tests earlier

---------

Co-authored-by: Sredny M <[email protected]>
  • Loading branch information
sedkis and sredxny authored May 27, 2024
1 parent a93a6a3 commit dcba70a
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 10 deletions.
5 changes: 4 additions & 1 deletion pumps/splunk.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,10 @@ func NewSplunkClient(token string, collectorURL string, skipVerify bool, certFil
}
tlsConfig = &tls.Config{Certificates: []tls.Certificate{cert}, ServerName: serverName}
}
http.DefaultClient.Transport = &http.Transport{TLSClientConfig: tlsConfig}
http.DefaultClient.Transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: tlsConfig,
}
// Append the default collector API path:
u.Path = defaultPath
c = &SplunkClient{
Expand Down
80 changes: 71 additions & 9 deletions pumps/splunk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package pumps
import (
"context"
"encoding/json"
"io/ioutil"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -46,11 +48,12 @@ func (h *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Body == nil {
h.test.Fatal("Body is nil")
}
body, err := ioutil.ReadAll(r.Body)

body, err := io.ReadAll(r.Body)
if err != nil {
h.test.Fatal("Couldn't ready body")
}
r.Body.Close()
defer r.Body.Close()

if h.returnErrors >= h.reqCount {
w.WriteHeader(http.StatusInternalServerError)
Expand All @@ -72,7 +75,10 @@ func (h *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
status.Len = len(body)
}

statusJSON, _ := json.Marshal(&status)
statusJSON, err := json.Marshal(&status)
if err != nil {
h.test.Fatalf("Failed to marshal JSON: %v", err)
}
w.Write(statusJSON)
h.responses = append(h.responses, status)
}
Expand All @@ -92,6 +98,59 @@ func TestSplunkInit(t *testing.T) {
}
}

func Test_SplunkProxyFromEnvironment(t *testing.T) {
// Setup a test server to act as a proxy
proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Proxy call successful")
}))
defer proxyServer.Close()

// Set environment variable to use the proxy
os.Setenv("HTTP_PROXY", proxyServer.URL)
defer os.Unsetenv("HTTP_PROXY")

// Initialize client
client, err := NewSplunkClient("token", "https://example.com", true, "", "", "")
if err != nil {
t.Fatal("Failed to create client:", err)
}

// Make a request
resp, err := client.httpClient.Get("http://example.com")
if err != nil {
t.Fatal("Failed to make request:", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal("Failed to read response:", err)
}

// Check if the proxy was called
if string(body) != "Proxy call successful\n" {
t.Errorf("Expected proxy to be called, but it wasn't")
}

}

func Test_SplunkInvalidProxyURL(t *testing.T) {
// Set an invalid proxy URL
os.Setenv("HTTP_PROXY", "htttp://invalid-url")
defer os.Unsetenv("HTTP_PROXY")

// Initialize client
client, err := NewSplunkClient("token", "https://example.com", true, "", "", "")
if err != nil {
t.Fatal("Failed to create client:", err)
}

// Make a request and expect it to fail
_, err = client.httpClient.Get("http://example.com")
if err == nil {
t.Error("Expected error due to invalid proxy URL, but no error occurred")
}
}

func Test_SplunkBackoffRetry(t *testing.T) {
go t.Run("max_retries=1", func(t *testing.T) {
handler := &testHandler{test: t, batched: false, returnErrors: 1}
Expand Down Expand Up @@ -254,7 +313,7 @@ func Test_SplunkWriteDataBatch(t *testing.T) {
cfg["collector_url"] = server.URL
cfg["ssl_insecure_skip_verify"] = true
cfg["enable_batch"] = true
cfg["batch_max_content_length"] = getEventBytes(keys[:2])
cfg["batch_max_content_length"] = getEventBytes(keys[:2], t)

if errInit := pmp.Init(cfg); errInit != nil {
t.Error("Error initializing pump")
Expand All @@ -268,12 +327,12 @@ func Test_SplunkWriteDataBatch(t *testing.T) {

assert.Equal(t, 2, len(handler.responses))

assert.Equal(t, getEventBytes(keys[:2]), handler.responses[0].Len)
assert.Equal(t, getEventBytes(keys[2:]), handler.responses[1].Len)
assert.Equal(t, getEventBytes(keys[:2], t), handler.responses[0].Len)
assert.Equal(t, getEventBytes(keys[2:], t), handler.responses[1].Len)
}

// getEventBytes returns the bytes amount of the marshalled events struct
func getEventBytes(records []interface{}) int {
func getEventBytes(records []interface{}, t *testing.T) int {
result := 0

for _, record := range records {
Expand Down Expand Up @@ -301,7 +360,10 @@ func getEventBytes(records []interface{}) int {
Event map[string]interface{} `json:"event"`
}{Time: decoded.TimeStamp.Unix(), Event: event}

data, _ := json.Marshal(eventWrap)
data, err := json.Marshal(eventWrap)
if err != nil {
t.Fatal("Failed to marshal event:", err) // Adjusted for context that t is not available, consider passing testing.T or handle differently.
}
result += len(data)
}
return result
Expand Down

0 comments on commit dcba70a

Please sign in to comment.