Skip to content

Commit

Permalink
Merge pull request #103 from matrix-org/kegan/callback-can-modify-res…
Browse files Browse the repository at this point in the history
…ponses

mitmproxy: let the callback addon modify response data
  • Loading branch information
kegsay authored Jul 4, 2024
2 parents 60dcf8b + 1afe671 commit 9b17d5d
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 44 deletions.
23 changes: 21 additions & 2 deletions internal/deploy/callback_addon.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,21 @@ type CallbackData struct {
RequestBody json.RawMessage `json:"request_body"`
}

type CallbackResponse struct {
// if set, changes the HTTP response status code for this request.
RespondStatusCode int `json:"respond_status_code,omitempty"`
// if set, changes the HTTP response body for this request.
RespondBody json.RawMessage `json:"respond_body,omitempty"`
}

func (cd CallbackData) String() string {
return fmt.Sprintf("%s %s (token=%s) req_len=%d => HTTP %v", cd.Method, cd.URL, cd.AccessToken, len(cd.RequestBody), cd.ResponseCode)
}

// NewCallbackServer runs a local HTTP server that can read callbacks from mitmproxy.
// Returns the URL of the callback server for use with WithMITMOptions, along with a close function
// which should be called when the test finishes to shut down the HTTP server.
func NewCallbackServer(t *testing.T, hostnameRunningComplement string, cb func(CallbackData)) (callbackURL string, close func()) {
func NewCallbackServer(t *testing.T, hostnameRunningComplement string, cb func(CallbackData) *CallbackResponse) (callbackURL string, close func()) {
if lastTestName != "" {
t.Logf("WARNING[%s]: NewCallbackServer called without closing the last one. Check test '%s'", t.Name(), lastTestName)
}
Expand All @@ -53,8 +60,20 @@ func NewCallbackServer(t *testing.T, hostnameRunningComplement string, cb func(C
}
}
t.Logf("CallbackServer[%s]%s: %v %s", t.Name(), localpart, time.Now(), data)
cb(data)
cbRes := cb(data)
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(200)
if cbRes == nil {
w.Write([]byte(`{}`))
return
}
cbResBytes, err := json.Marshal(cbRes)
if err != nil {
ct.Errorf(t, "failed to marshal callback response: %s", err)
return
}
fmt.Println(string(cbResBytes))
w.Write(cbResBytes)
})
// listen on a random high numbered port
ln, err := net.Listen("tcp", ":0") //nolint
Expand Down
93 changes: 80 additions & 13 deletions internal/deploy/callback_addon_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package deploy

import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"reflect"
Expand All @@ -14,6 +16,7 @@ import (
"github.com/matrix-org/complement/ct"
"github.com/matrix-org/complement/helpers"
"github.com/matrix-org/complement/must"
"github.com/tidwall/gjson"
)

func TestMain(m *testing.M) {
Expand Down Expand Up @@ -130,12 +133,13 @@ func TestCallbackAddon(t *testing.T) {
signalSendUnrelatedRequest := make(chan bool)
signalTestFinished := make(chan bool)
checker.expect(&callbackRequest{
OnCallback: func(cd CallbackData) {
OnCallback: func(cd CallbackData) *CallbackResponse {
if strings.Contains(cd.URL, "capabilities") {
close(signalSendUnrelatedRequest) // send the signal to make the unrelated request
time.Sleep(time.Second) // tarpit this request
close(signalTestFinished) // test is done, cleanup
}
return nil
},
})
beforeSendingRequests := time.Now()
Expand Down Expand Up @@ -167,12 +171,70 @@ func TestCallbackAddon(t *testing.T) {
}
},
},

// TODO: migrate functionality from status_code addon
// TODO: can modify response codes
// TODO: can modify response bodies
{
name: "can modify response codes without modifying the response body",
filter: "~hq " + client.AccessToken,
inner: func(t *testing.T, checker *checker) {
checker.expect(&callbackRequest{
OnCallback: func(cd CallbackData) *CallbackResponse {
return &CallbackResponse{
RespondStatusCode: 404,
}
},
})
res := client.Do(t, "GET", []string{"_matrix", "client", "v3", "capabilities"})
checker.wait()
must.Equal(t, res.StatusCode, 404, "response code was not altered")
body, err := io.ReadAll(res.Body)
must.NotError(t, "failed to read CSAPI response", err)
must.Equal(t, gjson.ParseBytes(body).Get("capabilities").Exists(), true, "response body was modified")
},
},
{
name: "can modify response bodies without modifying the response code",
filter: "~hq " + client.AccessToken,
inner: func(t *testing.T, checker *checker) {
checker.expect(&callbackRequest{
OnCallback: func(cd CallbackData) *CallbackResponse {
return &CallbackResponse{
RespondBody: json.RawMessage(`{
"foo": "bar"
}`),
}
},
})
res := client.Do(t, "GET", []string{"_matrix", "client", "v3", "capabilities"})
checker.wait()
must.Equal(t, res.StatusCode, 200, "response code was modified")
body, err := io.ReadAll(res.Body)
must.NotError(t, "failed to read CSAPI response", err)
must.Equal(t, gjson.ParseBytes(body).Get("foo").Str, "bar", "response body was not altered")
},
},
{
name: "can modify response codes and bodies",
filter: "~hq " + client.AccessToken,
inner: func(t *testing.T, checker *checker) {
checker.expect(&callbackRequest{
OnCallback: func(cd CallbackData) *CallbackResponse {
return &CallbackResponse{
RespondStatusCode: 403,
RespondBody: json.RawMessage(`{
"foo": "bar"
}`),
}
},
})
res := client.Do(t, "GET", []string{"_matrix", "client", "v3", "capabilities"})
checker.wait()
must.Equal(t, res.StatusCode, 403, "response code was not modified")
body, err := io.ReadAll(res.Body)
must.NotError(t, "failed to read CSAPI response", err)
must.Equal(t, gjson.ParseBytes(body).Get("foo").Str, "bar", "response body was not modified")
},
},
// TODO: can block requests
// TODO: can block responses
// TODO: migrate functionality from status_code addon
}

for _, tc := range testCases {
Expand All @@ -184,8 +246,8 @@ func TestCallbackAddon(t *testing.T) {
}
callbackURL, close := NewCallbackServer(
t, deployment.GetConfig().HostnameRunningComplement,
func(cd CallbackData) {
checker.onCallback(cd)
func(cd CallbackData) *CallbackResponse {
return checker.onCallback(cd)
},
)
defer close()
Expand All @@ -212,7 +274,7 @@ type callbackRequest struct {
PathContains string
AccessToken string
ResponseCode int
OnCallback func(cd CallbackData)
OnCallback func(cd CallbackData) *CallbackResponse
}

type checker struct {
Expand All @@ -223,14 +285,14 @@ type checker struct {
noCallbacks bool
}

func (c *checker) onCallback(cd CallbackData) {
func (c *checker) onCallback(cd CallbackData) *CallbackResponse {
c.mu.Lock()
if c.noCallbacks {
ct.Errorf(c.t, "wanted no callbacks but got %+v", cd)
}
if c.want == nil {
c.mu.Unlock()
return
return nil
}
if c.want.AccessToken != "" {
must.Equal(c.t, cd.AccessToken, c.want.AccessToken, "access token mismatch")
Expand All @@ -251,11 +313,13 @@ func (c *checker) onCallback(cd CallbackData) {
// unlock early so we don't block other requests, as custom callbacks are generally
// used for testing tarpitting.
c.mu.Unlock()
var callbackResponse *CallbackResponse
if customCallback != nil {
customCallback(cd)
callbackResponse = customCallback(cd)
}
// signal that we processed the callback
c.ch <- *c.want
return callbackResponse
}

func (c *checker) expect(want *callbackRequest) {
Expand All @@ -271,9 +335,12 @@ func (c *checker) expectNoCallbacks(noCallbacks bool) {
}

func (c *checker) wait() {
c.t.Helper()
select {
case got := <-c.ch:
if !reflect.DeepEqual(got, *c.want) {
// we can't sanity check if there are callbacks involved, as we can't easily
// pair responses up.
if c.want.OnCallback == nil && !reflect.DeepEqual(got, *c.want) {
ct.Fatalf(c.t, "checker: got success from a different request: did you forget to wait?"+
" Received %+v but expected +%v", got, c.want)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/deploy/mitm.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ type MITMPathConfiguration struct {
path string
accessToken string
method string
listener func(cd CallbackData)
listener func(cd CallbackData) *CallbackResponse

blockCount int
blockStatusCode int
Expand Down Expand Up @@ -180,7 +180,7 @@ func (p *MITMPathConfiguration) filter() string {
return s.String()
}

func (p *MITMPathConfiguration) Listen(cb func(cd CallbackData)) *MITMPathConfiguration {
func (p *MITMPathConfiguration) Listen(cb func(cd CallbackData) *CallbackResponse) *MITMPathConfiguration {
p.listener = cb
return p
}
Expand Down
3 changes: 2 additions & 1 deletion tests/delayed_requests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestDelayedInviteResponse(t *testing.T) {

config := tc.Deployment.MITM().Configure(t)
serverHasInvite := helpers.NewWaiter()
config.ForPath("/sync").AccessToken(alice.CurrentAccessToken(t)).Listen(func(cd deploy.CallbackData) {
config.ForPath("/sync").AccessToken(alice.CurrentAccessToken(t)).Listen(func(cd deploy.CallbackData) *deploy.CallbackResponse {
if strings.Contains(
strings.ReplaceAll(string(cd.ResponseBody), " ", ""),
`"membership":"invite"`,
Expand All @@ -50,6 +50,7 @@ func TestDelayedInviteResponse(t *testing.T) {
serverHasInvite.Finish()
time.Sleep(delayTime)
}
return nil
})
config.Execute(func() {
t.Logf("Alice about to /invite Bob")
Expand Down
3 changes: 2 additions & 1 deletion tests/device_keys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ func TestFailedDeviceKeyDownloadRetries(t *testing.T) {

// Given that the first 4 attempts to download device keys will fail
mitmConfiguration := tc.Deployment.MITM().Configure(t)
mitmConfiguration.ForPath("/keys/query").Method("POST").BlockRequest(4, http.StatusGatewayTimeout).Listen(func(data deploy.CallbackData) {
mitmConfiguration.ForPath("/keys/query").Method("POST").BlockRequest(4, http.StatusGatewayTimeout).Listen(func(data deploy.CallbackData) *deploy.CallbackResponse {
queryReceived.Store(true)
return nil
})
mitmConfiguration.Execute(func() {
// And Alice and Bob are in an encrypted room together
Expand Down
32 changes: 23 additions & 9 deletions tests/mitmproxy_addons/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,19 +75,33 @@ async def response(self, flow):
except:
res_body = None
print(f'{datetime.now().strftime("%H:%M:%S.%f")} hitting callback for {flow.request.url}')
callback_body = {
"method": flow.request.method,
"access_token": flow.request.headers.get("Authorization", "").removeprefix("Bearer "),
"url": flow.request.url,
"response_code": flow.response.status_code,
"request_body": req_body,
"response_body": res_body,
}
try:
# use asyncio so we don't block other unrelated requests from being processed
async with aiohttp.request(
method="POST",url=self.config["callback_url"], timeout=aiohttp.ClientTimeout(total=10),
headers={"Content-Type": "application/json"},
json={
"method": flow.request.method,
"access_token": flow.request.headers.get("Authorization", "").removeprefix("Bearer "),
"url": flow.request.url,
"response_code": flow.response.status_code,
"request_body": req_body,
"response_body": res_body,
}) as response:
json=callback_body) as response:
print(f'{datetime.now().strftime("%H:%M:%S.%f")} callback for {flow.request.url} returned HTTP {response.status}')
test_response_body = await response.json()
# if the response includes some keys then we are modifying the response on a per-key basis.
if len(test_response_body) > 0:
respond_status_code = test_response_body.get("respond_status_code", flow.response.status_code)
respond_body = test_response_body.get("respond_body", res_body)
flow.response = Response.make(
respond_status_code, json.dumps(respond_body),
headers={
"MITM-Proxy": "yes", # so we don't reprocess this
"Content-Type": "application/json",
})

except Exception as error:
print(f"ERR: callback returned {error}")
print(f"ERR: callback for {flow.request.url} returned {error}")
print(f"ERR: callback, provided request body was {callback_body}")
5 changes: 3 additions & 2 deletions tests/notification_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -568,9 +568,9 @@ func TestMultiprocessDupeOTKUpload(t *testing.T) {
// at once. If the NSE and main apps are talking to each other, they should be using the same key ID + key.
// If not... well, that's a bug because then the client will forget one of these keys.
mitmConfiguration := tc.Deployment.MITM().Configure(t)
mitmConfiguration.ForPath("/keys/upload").Listen(func(cd deploy.CallbackData) {
mitmConfiguration.ForPath("/keys/upload").Listen(func(cd deploy.CallbackData) *deploy.CallbackResponse {
if cd.AccessToken != aliceAccessToken {
return // let bob upload OTKs
return nil // let bob upload OTKs
}
aliceUploadedNewKeys = true
if cd.ResponseCode != 200 {
Expand All @@ -581,6 +581,7 @@ func TestMultiprocessDupeOTKUpload(t *testing.T) {
// tarpit the response
t.Logf("tarpitting keys/upload response for 4 seconds")
time.Sleep(4 * time.Second)
return nil
})
mitmConfiguration.Execute(func() {
var eventID string
Expand Down
3 changes: 2 additions & 1 deletion tests/one_time_keys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,13 @@ func TestFailedKeysClaimRetries(t *testing.T) {
roomID := tc.CreateNewEncryptedRoom(t, tc.Alice, cc.EncRoomOptions.PresetPublicChat())
// block /keys/claim and join the room, causing the Olm session to be created
mitmConfiguration := tc.Deployment.MITM().Configure(t)
mitmConfiguration.ForPath("/keys/claim").Method("POST").BlockRequest(2, http.StatusGatewayTimeout).Listen(func(cd deploy.CallbackData) {
mitmConfiguration.ForPath("/keys/claim").Method("POST").BlockRequest(2, http.StatusGatewayTimeout).Listen(func(cd deploy.CallbackData) *deploy.CallbackResponse {
t.Logf("%+v", cd)
if cd.ResponseCode == 200 {
waiter.Finish()
stopPoking.Store(true)
}
return nil
})
mitmConfiguration.Execute(func() {
// join the room. This should cause an Olm session to be made but it will fail as we cannot
Expand Down
3 changes: 2 additions & 1 deletion tests/room_keys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ import (

func sniffToDeviceEvent(t *testing.T, tc *cc.TestContext, ch chan deploy.CallbackData) *deploy.MITMConfiguration {
mitmConfiguration := tc.Deployment.MITM().Configure(t)
mitmConfiguration.ForPath("/sendToDevice").Method("PUT").Listen(func(cd deploy.CallbackData) {
mitmConfiguration.ForPath("/sendToDevice").Method("PUT").Listen(func(cd deploy.CallbackData) *deploy.CallbackResponse {
if strings.Contains(cd.URL, "m.room.encrypted") {
// we can't decrypt this, but we know that this should most likely be the m.room_key to-device event.
ch <- cd
}
return nil
})
return mitmConfiguration
}
Expand Down
Loading

0 comments on commit 9b17d5d

Please sign in to comment.