-
Notifications
You must be signed in to change notification settings - Fork 257
/
Copy pathroundtripper_test.go
144 lines (120 loc) · 3.35 KB
/
roundtripper_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package retryablehttp
import (
"context"
"errors"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"sync/atomic"
"testing"
)
// TestRoundTripper_implements is a compile-time assertion that *RoundTripper
// satisfies http.RoundTripper; if it compiles, the test has already passed.
func TestRoundTripper_implements(t *testing.T) {
	var _ http.RoundTripper = (*RoundTripper)(nil)
}
// TestRoundTripper_init verifies that a zero-value RoundTripper populates its
// Client field on the first RoundTrip and keeps using that same Client on a
// subsequent RoundTrip.
func TestRoundTripper_init(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
	}))
	defer ts.Close()

	// Start with a new empty RoundTripper.
	rt := &RoundTripper{}

	// roundTrip sends one GET to the test server through rt, failing the test
	// on any error, and closes the response body so the connection is released.
	roundTrip := func() {
		t.Helper()
		req, err := http.NewRequest("GET", ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		resp, err := rt.RoundTrip(req)
		if err != nil {
			t.Fatal(err)
		}
		resp.Body.Close()
	}

	// RoundTrip once.
	roundTrip()

	// Check that the Client was initialized.
	if rt.Client == nil {
		t.Fatal("expected rt.Client to be initialized")
	}

	// Save the Client for later comparison.
	initialClient := rt.Client

	// RoundTrip again.
	roundTrip()

	// Check that the underlying Client is unchanged.
	if rt.Client != initialClient {
		t.Fatalf("expected %v, got %v", initialClient, rt.Client)
	}
}
// TestRoundTripper_RoundTrip verifies that the *http.Client returned by
// StandardClient honors the retrying client's CheckRetry policy. The server
// answers 404 to its first two requests and 200 ("success!") afterwards.
// CheckRetry retries on 404, so the caller should observe the final 200.
func TestRoundTripper_RoundTrip(t *testing.T) {
	var reqCount int32
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reqNo := atomic.AddInt32(&reqCount, 1)
		if reqNo < 3 {
			w.WriteHeader(404)
			return
		}
		w.WriteHeader(200)
		// A write failure here shows up as a body mismatch in the client below.
		_, _ = w.Write([]byte("success!"))
	}))
	defer ts.Close()

	// Make a client with some custom settings to verify they are used.
	retryClient := NewClient()
	retryClient.CheckRetry = func(_ context.Context, resp *http.Response, _ error) (bool, error) {
		// resp is nil when the transport itself failed; never dereference it
		// blindly, or the policy panics instead of letting the test report.
		return resp != nil && resp.StatusCode == 404, nil
	}

	// Get the standard client and execute the request.
	client := retryClient.StandardClient()
	resp, err := client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	// Check the response to ensure the client behaved as expected; include the
	// number of requests the server saw to make a retry failure diagnosable.
	if resp.StatusCode != 200 {
		t.Fatalf("expected 200, got %d after %d requests", resp.StatusCode, atomic.LoadInt32(&reqCount))
	}
	if v, err := io.ReadAll(resp.Body); err != nil {
		t.Fatal(err)
	} else if string(v) != "success!" {
		t.Fatalf("expected %q, got %q", "success!", v)
	}
}
// TestRoundTripper_TransportFailureErrorHandling verifies the error surfaced
// for a transport-level failure (a dial to an unresolvable host) when
// PassthroughErrorHandler is installed. The test expects the caller of the
// standard client to receive exactly the *url.Error / *net.OpError /
// *net.DNSError chain constructed below, after normalizeError clears the
// environment-dependent DNS server field.
func TestRoundTripper_TransportFailureErrorHandling(t *testing.T) {
	// Make a client with some custom settings to verify they are used.
	retryClient := NewClient()
	retryClient.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
		if err != nil {
			return true, err
		}
		return false, nil
	}
	retryClient.ErrorHandler = PassthroughErrorHandler

	expectedError := &url.Error{
		Op:  "Get",
		URL: "http://999.999.999.999:999/",
		Err: &net.OpError{
			Op:  "dial",
			Net: "tcp",
			Err: &net.DNSError{
				Name:       "999.999.999.999",
				Err:        "no such host",
				IsNotFound: true,
			},
		},
	}

	// Get the standard client and execute the request.
	client := retryClient.StandardClient()
	resp, err := client.Get("http://999.999.999.999:999/")
	if resp != nil {
		// Only reachable if the request unexpectedly succeeded; don't leak it.
		resp.Body.Close()
	}
	if err == nil {
		t.Fatal("expected a transport error, got nil")
	}

	// assert expectations
	if got := normalizeError(err); !reflect.DeepEqual(expectedError, got) {
		t.Fatalf("expected %v, got %v", expectedError, got)
	}
}
// normalizeError clears the Server field of the first *net.DNSError found in
// err's chain (located with errors.As) and returns err itself. The DNS server
// is filled in on CI but not locally, so blanking it keeps comparisons
// environment-independent. The DNSError is modified in place; errors with no
// DNSError in their chain (including nil) are returned untouched.
func normalizeError(err error) error {
	var target *net.DNSError
	if !errors.As(err, &target) {
		return err
	}
	target.Server = ""
	return err
}