diff --git a/README.md b/README.md index 06e5a06..44ff25b 100644 --- a/README.md +++ b/README.md @@ -350,6 +350,35 @@ func TestApi(t *testing.T) { } ``` +#### Provide a multipart/form-data with custom filesystem + +```go +inMemFS := fstest.MapFS{ + "audio.wav": &fstest.MapFile{ + Data: []byte{19,2,123,12,35,1}, + Mode: fs.FileMode(0644), + ModTime: time.Now(), + }, + "audio.mp3": &fstest.MapFile{ + Data: []byte{21,13,88,123,9,8}, + Mode: fs.FileMode(0644), + ModTime: time.Now(), + }, +} + +func TestApi(t *testing.T) { + apitest.Handler(handler). + UseFS(inMemFS). + Post("/hello"). + MultipartFormData("a", "1", "2"). + MultipartFile("file", "audio.wav", "audio.mp3"). + Expect(t). + Status(http.StatusOK). + End() +} +``` + + #### Capture the request and response data ```go diff --git a/apitest.go b/apitest.go index 77f2991..84b4f48 100644 --- a/apitest.go +++ b/apitest.go @@ -7,6 +7,7 @@ import ( "fmt" "hash/fnv" "io" + "io/fs" "io/ioutil" "mime/multipart" "net/http" @@ -14,7 +15,6 @@ import ( "net/http/httputil" "net/textproto" "net/url" - "os" "path/filepath" "runtime/debug" "sort" @@ -56,6 +56,7 @@ type APITest struct { meta map[string]interface{} started time.Time finished time.Time + fileSystem fs.FS } // InboundRequest used to wrap the incoming request with a timestamp @@ -98,6 +99,7 @@ func New(name ...string) *APITest { if len(name) > 0 { apiTest.name = name[0] } + apiTest.fileSystem = OSFS{} return apiTest } @@ -221,6 +223,14 @@ func (a *APITest) Response() *Response { return a.response } +// Use filesystem allows you to change to a custom fs.FS filesystem that you can define (Used by Request.MultipartFile). +// e.g: fstest.MapFS +// Your os filesystem is used by default +func (a *APITest) UseFS(fs fs.FS) *APITest { + a.fileSystem = fs + return a +} + // Request is the user defined request that will be invoked on the handler under test type Request struct { interceptor Intercept @@ -532,13 +542,13 @@ func (r *Request) MultipartFile(name string, ff ...string) *Request { for _, f := range ff { func() { - file, err := os.Open(f) + file, err := r.apiTest.fileSystem.Open(f) if err != nil { r.apiTest.t.Fatal(err) } defer file.Close() - part, err := r.multipart.CreateFormFile(name, filepath.Base(file.Name())) + part, err := r.multipart.CreateFormFile(name, filepath.Base(f)) if err != nil { r.apiTest.t.Fatal(err) } diff --git a/apitest_test.go b/apitest_test.go index 8f9822b..0a0a526 100644 --- a/apitest_test.go +++ b/apitest_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "io/fs" "io/ioutil" "net/http" "net/http/cookiejar" @@ -13,6 +14,7 @@ import ( "reflect" "strings" "testing" + "testing/fstest" "time" "github.com/steinfletcher/apitest" @@ -1285,6 +1287,99 @@ func TestApiTest_AddsMultipartFormData(t *testing.T) { End() } +func TestApiTest_AddsMultipartFormDataWithCustomFS(t *testing.T) { + handler := http.NewServeMux() + + inMemFS := fstest.MapFS{ + "audio.wav": &fstest.MapFile{ + Data: []byte{19, 2, 123, 12, 35, 1}, + Mode: fs.FileMode(0644), + ModTime: time.Now(), + }, + "audio.mp3": &fstest.MapFile{ + Data: []byte{21, 13, 88, 123, 9, 8}, + Mode: fs.FileMode(0644), + ModTime: time.Now(), + }, + } + + handler.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.Header["Content-Type"][0], "multipart/form-data") { + w.WriteHeader(http.StatusBadRequest) + return + } + + err := r.ParseMultipartForm(2 << 32) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + expectedPostFormData := map[string][]string{ + "name": {"John"}, + "age": {"99"}, + "children": {"Jack", "Ann"}, + "pets": {"Toby", "Henry", "Alice"}, + } + + for key := range expectedPostFormData { + if !reflect.DeepEqual(expectedPostFormData[key], r.MultipartForm.Value[key]) { + w.WriteHeader(http.StatusBadRequest) + return + } + } + + for _, exp := range []struct { + formName string + fileNames []string + }{ + { + formName: "audio1", + fileNames: []string{"audio.wav", "audio.mp3"}, + }, + { + formName: "audio2", + fileNames: []string{"audio.mp3"}, + }, + } { + formFiles := r.MultipartForm.File[exp.formName] + assert.Equal(t, len(exp.fileNames), len(formFiles), "Number of files do not match") + for i, fileName := range exp.fileNames { + expFile := inMemFS[fileName] + formFile := formFiles[i] + + assert.Equal(t, fileName, formFile.Filename, "File names do not match") + f, err := formFile.Open() + if err != nil { + t.Fatal(err) + } + data, err := ioutil.ReadAll(f) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, expFile.Data, data) + } + } + + w.WriteHeader(http.StatusOK) + }) + + apitest.New(). + UseFS(inMemFS). + Handler(handler). + Post("/hello"). + MultipartFormData("name", "John"). + MultipartFormData("age", "99"). + MultipartFormData("children", "Jack"). + MultipartFormData("children", "Ann"). + MultipartFormData("pets", "Toby", "Henry", "Alice"). + MultipartFile("audio1", "audio.wav", "audio.mp3"). + MultipartFile("audio2", "audio.mp3"). + Expect(t). + Status(http.StatusOK). + End() +} + func TestApiTest_CombineFormDataWithMultipart(t *testing.T) { if os.Getenv("RUN_FATAL_TEST") == "FormData" { apitest.New(). diff --git a/filesystem.go b/filesystem.go new file mode 100644 index 0000000..07cac94 --- /dev/null +++ b/filesystem.go @@ -0,0 +1,14 @@ +package apitest + +import ( + "io/fs" + "os" +) + +//An implementation of fs.FS that wraps your OS's filesystem +type OSFS struct { +} +//Calls os.Open +func (OSFS) Open(name string) (file fs.File, err error) { + return os.Open(name) +}