diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..1d6cd40 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,40 @@ +name: Tests + +on: [push] + +jobs: + + build: + name: Build + runs-on: ubuntu-latest + strategy: + matrix: + go-version: [1.16] + steps: + + - name: Set up Go ${{ matrix.go-version }} + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go-version }} + id: go + + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + - name: Get dependencies + run: | + go get -v -t -d ./... + if [ -f Gopkg.toml ]; then + curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh + dep ensure + fi + - name: Test + run: go test -v -coverprofile=coverage.txt -covermode=atomic ./... + + - name: Codecov + uses: codecov/codecov-action@v2.1.0 + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: ./coverage.txt + flags: unittests + fail_ci_if_error: false \ No newline at end of file diff --git a/KNOWN_ISSUES.md b/KNOWN_ISSUES.md new file mode 100644 index 0000000..2fd9e3d --- /dev/null +++ b/KNOWN_ISSUES.md @@ -0,0 +1,13 @@ +# Known Issues and Troubleshooting + +## Files metadata endpoint doesn't work with visas +If using GA4GH Visas with `/metadata/datasets/{dataset}/files`, e.g. `/metadata/datasets/https://doi.org/abc/123/files`, a reverse proxy might remove adjacent slashes `//`->`/`. +This has been observed with nginx, with a fix as follows: + +[disable slash merging](http://nginx.org/en/docs/http/ngx_http_core_module.html#merge_slashes) +in `server` context +``` +server { + merge_slashes off +} +``` \ No newline at end of file diff --git a/README.md b/README.md index 1cc1802..5bb0056 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,8 @@ +[![CodeQL](https://github.com/neicnordic/sda-download/actions/workflows/codeql-analysis.yml/badge.svg)](https://github.com/neicnordic/sda-download/actions/workflows/codeql-analysis.yml) +[![Tests](https://github.com/neicnordic/sda-download/actions/workflows/test.yml/badge.svg)](https://github.com/neicnordic/sda-download/actions/workflows/test.yml) +[![Multilinters](https://github.com/neicnordic/sda-download/actions/workflows/report.yml/badge.svg)](https://github.com/neicnordic/sda-download/actions/workflows/report.yml) +[![codecov](https://codecov.io/gh/neicnordic/sda-download/branch/main/graph/badge.svg?token=ZHO4XCDPJO)](https://codecov.io/gh/neicnordic/sda-download) + # SDA Download `sda-download` is a `go` implementation of the [Data Out API](https://neic-sda.readthedocs.io/en/latest/dataout.html#rest-api-endpoints). The [API Reference](docs/API.md) has example requests and responses. diff --git a/api/api.go b/api/api.go index 0c386e9..d308dd9 100644 --- a/api/api.go +++ b/api/api.go @@ -6,6 +6,7 @@ import ( "net/http" "time" + "github.com/gorilla/mux" "github.com/neicnordic/sda-download/api/middleware" "github.com/neicnordic/sda-download/api/sda" "github.com/neicnordic/sda-download/internal/config" @@ -16,11 +17,11 @@ import ( func Setup() *http.Server { // Set up routing log.Info("(2/5) Registering endpoint handlers") - r := http.NewServeMux() + r := mux.NewRouter().SkipClean(true) r.Handle("/metadata/datasets", middleware.TokenMiddleware(http.HandlerFunc(sda.Datasets))) - r.Handle("/metadata/datasets/", middleware.TokenMiddleware(http.HandlerFunc(sda.Files))) - r.Handle("/files/", middleware.TokenMiddleware(http.HandlerFunc(sda.Download))) + r.Handle("/metadata/datasets/{dataset:[A-Za-z0-9-_.~:/?#@!$&'()*+,;=]+}/files", middleware.TokenMiddleware(http.HandlerFunc(sda.Files))) + r.Handle("/files/{fileid}", middleware.TokenMiddleware(http.HandlerFunc(sda.Download))) // Configure TLS settings log.Info("(3/5) Configuring TLS") diff --git a/api/api_test.go b/api/api_test.go new file mode 100644 index 0000000..e4faf74 --- /dev/null +++ b/api/api_test.go @@ -0,0 +1,27 @@ +package api + +import ( + "crypto/tls" + "testing" + + "github.com/neicnordic/sda-download/internal/config" +) + +func TestSetup(t *testing.T) { + + // Create web server app + config.Config.App.Host = "localhost" + config.Config.App.Port = 8080 + server := Setup() + + // Verify that TLS is configured and set for minimum suggested version + if server.TLSConfig.MinVersion < tls.VersionTLS12 { + t.Errorf("server TLS version is too low, expected=%d, got=%d", tls.VersionTLS12, server.TLSConfig.MinVersion) + } + + // Verify that server address is correctly read from config + expectedAddress := "localhost:8080" + if server.Addr != expectedAddress { + t.Errorf("server address was not correctly formed, expected=%s, received=%s", expectedAddress, server.Addr) + } +} diff --git a/api/middleware/middleware.go b/api/middleware/middleware.go index 9ef0af8..82b2559 100644 --- a/api/middleware/middleware.go +++ b/api/middleware/middleware.go @@ -48,12 +48,7 @@ func TokenMiddleware(nextHandler http.Handler) http.Handler { } // Get permissions - datasets, err = auth.GetPermissions(*visas) - if err != nil { - log.Errorf("failed to parse dataset permission visas, %s", err) - http.Error(w, "visa parsing failed", 500) - return - } + datasets = auth.GetPermissions(*visas) if len(datasets) == 0 { log.Debug("token carries no dataset permissions matching the database") http.Error(w, "no datasets found", 404) @@ -94,7 +89,7 @@ func storeDatasets(ctx context.Context, datasets []string) context.Context { } // GetDatasets extracts the dataset list from the request context -func GetDatasets(ctx context.Context) []string { +var GetDatasets = func(ctx context.Context) []string { datasets := ctx.Value("datasets") if datasets == nil { log.Debug("request datasets context is empty") diff --git a/api/middleware/middleware_test.go b/api/middleware/middleware_test.go new file mode 100644 index 0000000..cd5d0f0 --- /dev/null +++ b/api/middleware/middleware_test.go @@ -0,0 +1,310 @@ +package middleware + +import ( + "bytes" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/neicnordic/sda-download/internal/session" + "github.com/neicnordic/sda-download/pkg/auth" +) + +const token string = "token" + +// testEndpoint mimics the endpoint handlers that perform business logic after passing the +// authentication middleware. This handler is generic and can be used for all cases. +func testEndpoint(w http.ResponseWriter, r *http.Request) {} + +func TestTokenMiddleware_Fail_GetToken(t *testing.T) { + + // Save original to-be-mocked functions + originalGetToken := auth.GetToken + + // Substitute mock functions + auth.GetToken = func(header string) (string, int, error) { + return "", 401, errors.New("access token must be provided") + } + + // Mock request and response holders + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "https://testing.fi", nil) + + // Send a request through the middleware + testHandler := TokenMiddleware(http.HandlerFunc(testEndpoint)) + testHandler.ServeHTTP(w, r) + + // Test the outcomes of the handler + response := w.Result() + defer response.Body.Close() + body, _ := io.ReadAll(response.Body) + expectedStatusCode := 401 + expectedBody := []byte("access token must be provided\n") + + if response.StatusCode != expectedStatusCode { + t.Errorf("TestTokenMiddleware_Fail_GetToken failed, got %d expected %d", response.StatusCode, expectedStatusCode) + } + if !bytes.Equal(body, []byte(expectedBody)) { + // visual byte comparison in terminal (easier to find string differences) + t.Error(body) + t.Error([]byte(expectedBody)) + t.Errorf("TestTokenMiddleware_Fail_GetToken failed, got %s expected %s", string(body), string(expectedBody)) + } + + // Return mock functions to originals + auth.GetToken = originalGetToken + +} + +func TestTokenMiddleware_Fail_GetVisas(t *testing.T) { + + // Save original to-be-mocked functions + originalGetToken := auth.GetToken + originalGetVisas := auth.GetVisas + + // Substitute mock functions + auth.GetToken = func(header string) (string, int, error) { + return token, 200, nil + } + auth.GetVisas = func(o auth.OIDCDetails, token string) (*auth.Visas, error) { + return nil, errors.New("bad token") + } + + // Mock request and response holders + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "https://testing.fi", nil) + + // Send a request through the middleware + testHandler := TokenMiddleware(http.HandlerFunc(testEndpoint)) + testHandler.ServeHTTP(w, r) + + // Test the outcomes of the handler + response := w.Result() + defer response.Body.Close() + body, _ := io.ReadAll(response.Body) + expectedStatusCode := 401 + expectedBody := []byte("bad token\n") + + if response.StatusCode != expectedStatusCode { + t.Errorf("TestTokenMiddleware_Fail_GetVisas failed, got %d expected %d", response.StatusCode, expectedStatusCode) + } + if !bytes.Equal(body, []byte(expectedBody)) { + // visual byte comparison in terminal (easier to find string differences) + t.Error(body) + t.Error([]byte(expectedBody)) + t.Errorf("TestTokenMiddleware_Fail_GetVisas failed, got %s expected %s", string(body), string(expectedBody)) + } + + // Return mock functions to originals + auth.GetToken = originalGetToken + auth.GetVisas = originalGetVisas + +} + +func TestTokenMiddleware_Fail_GetPermissions(t *testing.T) { + + // Save original to-be-mocked functions + originalGetToken := auth.GetToken + originalGetVisas := auth.GetVisas + originalGetPermissions := auth.GetPermissions + + // Substitute mock functions + auth.GetToken = func(header string) (string, int, error) { + return token, 200, nil + } + auth.GetVisas = func(o auth.OIDCDetails, token string) (*auth.Visas, error) { + return &auth.Visas{}, nil + } + auth.GetPermissions = func(visas auth.Visas) []string { + return []string{} + } + + // Mock request and response holders + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "https://testing.fi", nil) + + // Send a request through the middleware + testHandler := TokenMiddleware(http.HandlerFunc(testEndpoint)) + testHandler.ServeHTTP(w, r) + + // Test the outcomes of the handler + response := w.Result() + defer response.Body.Close() + body, _ := io.ReadAll(response.Body) + expectedStatusCode := 404 + expectedBody := []byte("no datasets found\n") + + if response.StatusCode != expectedStatusCode { + t.Errorf("TestTokenMiddleware_Fail_GetPermissions failed, got %d expected %d", response.StatusCode, expectedStatusCode) + } + if !bytes.Equal(body, []byte(expectedBody)) { + // visual byte comparison in terminal (easier to find string differences) + t.Error(body) + t.Error([]byte(expectedBody)) + t.Errorf("TestTokenMiddleware_Fail_GetPermissions failed, got %s expected %s", string(body), string(expectedBody)) + } + + // Return mock functions to originals + auth.GetToken = originalGetToken + auth.GetVisas = originalGetVisas + auth.GetPermissions = originalGetPermissions + +} + +func TestTokenMiddleware_Success_NoCache(t *testing.T) { + + // Save original to-be-mocked functions + originalGetToken := auth.GetToken + originalGetVisas := auth.GetVisas + originalGetPermissions := auth.GetPermissions + originalNewSessionKey := session.NewSessionKey + + // Substitute mock functions + auth.GetToken = func(header string) (string, int, error) { + return token, 200, nil + } + auth.GetVisas = func(o auth.OIDCDetails, token string) (*auth.Visas, error) { + return &auth.Visas{}, nil + } + auth.GetPermissions = func(visas auth.Visas) []string { + return []string{"dataset1", "dataset2"} + } + session.NewSessionKey = func() string { + return "key" + } + + // Mock request and response holders + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "https://testing.fi", nil) + + // Now that we are modifying the request context, we need to place the context test inside the handler + expectedDatasets := []string{"dataset1", "dataset2"} + testEndpointWithContextData := func(w http.ResponseWriter, r *http.Request) { + datasets := r.Context().Value("datasets").([]string) + // string arrays can't be compared + if strings.Join(datasets, "") == strings.Join(expectedDatasets, "")+"\n" { + t.Errorf("TestTokenMiddleware_Success_NoCache failed, got %s expected %s", datasets, expectedDatasets) + } + } + + // Send a request through the middleware + testHandler := TokenMiddleware(http.HandlerFunc(testEndpointWithContextData)) + testHandler.ServeHTTP(w, r) + + // Test the outcomes of the handler + response := w.Result() + defer response.Body.Close() + expectedStatusCode := 200 + expectedSessionKey := "key" + + if response.StatusCode != expectedStatusCode { + t.Errorf("TestTokenMiddleware_Success_NoCache failed, got %d expected %d", response.StatusCode, expectedStatusCode) + } + // nolint:bodyclose + for _, c := range w.Result().Cookies() { + if c.Name == "sda_session_key" { + if c.Value != expectedSessionKey { + t.Errorf("TestTokenMiddleware_Success_NoCache failed, got %s expected %s", c.Value, expectedSessionKey) + } + } + } + + // Return mock functions to originals + auth.GetToken = originalGetToken + auth.GetVisas = originalGetVisas + auth.GetPermissions = originalGetPermissions + session.NewSessionKey = originalNewSessionKey + +} + +func TestTokenMiddleware_Success_FromCache(t *testing.T) { + + // Save original to-be-mocked functions + originalGetCache := session.Get + + // Substitute mock functions + session.Get = func(key string) ([]string, bool) { + return []string{"dataset1", "dataset2"}, true + } + + // Mock request and response holders + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "https://testing.fi", nil) + r.AddCookie(&http.Cookie{ + Name: "sda_session_key", + Value: "key", + }) + + // Now that we are modifying the request context, we need to place the context test inside the handler + expectedDatasets := []string{"dataset1", "dataset2"} + testEndpointWithContextData := func(w http.ResponseWriter, r *http.Request) { + datasets := r.Context().Value("datasets").([]string) + // string arrays can't be compared + if strings.Join(datasets, "") == strings.Join(expectedDatasets, "")+"\n" { + t.Errorf("TestTokenMiddleware_Success_FromCache failed, got %s expected %s", datasets, expectedDatasets) + } + } + + // Send a request through the middleware + testHandler := TokenMiddleware(http.HandlerFunc(testEndpointWithContextData)) + testHandler.ServeHTTP(w, r) + + // Test the outcomes of the handler + response := w.Result() + defer response.Body.Close() + expectedStatusCode := 200 + + if response.StatusCode != expectedStatusCode { + t.Errorf("TestTokenMiddleware_Success_FromCache failed, got %d expected %d", response.StatusCode, expectedStatusCode) + } + // nolint:bodyclose + for _, c := range w.Result().Cookies() { + if c.Name == "sda_session_key" { + t.Errorf("TestTokenMiddleware_Success_FromCache failed, got a session cookie, when should not have") + } + } + + // Return mock functions to originals + session.Get = originalGetCache + +} + +func TestStoreDatasets(t *testing.T) { + + // Get a request context for testing if data is saved + r := httptest.NewRequest("GET", "https://testing.fi", nil) + + // Store data to request context + datasets := []string{"dataset1", "dataset2"} + modifiedContext := storeDatasets(r.Context(), datasets) + + // Verify that context has new data + storedDatasets := modifiedContext.Value("datasets").([]string) + // string arrays can't be compared + if strings.Join(datasets, "") != strings.Join(storedDatasets, "") { + t.Errorf("TestStoreDatasets failed, got %s, expected %s", storedDatasets, datasets) + } + +} + +func TestGetDatasets(t *testing.T) { + + // Get a request context for testing if data is saved + r := httptest.NewRequest("GET", "https://testing.fi", nil) + + // Store data to request context + datasets := []string{"dataset1", "dataset2"} + modifiedContext := storeDatasets(r.Context(), datasets) + modifiedRequest := r.WithContext(modifiedContext) + + // Verify that context has new data + storedDatasets := GetDatasets(modifiedRequest.Context()) + // string arrays can't be compared + if strings.Join(datasets, "") != strings.Join(storedDatasets, "") { + t.Errorf("TestStoreDatasets failed, got %s, expected %s", storedDatasets, datasets) + } + +} diff --git a/api/sda/sda.go b/api/sda/sda.go index ee047f9..e954a19 100644 --- a/api/sda/sda.go +++ b/api/sda/sda.go @@ -1,6 +1,7 @@ package sda import ( + "bytes" "context" "encoding/json" "errors" @@ -9,13 +10,13 @@ import ( "os" "path/filepath" "strconv" - "strings" "github.com/elixir-oslo/crypt4gh/model/headers" + "github.com/elixir-oslo/crypt4gh/streaming" + "github.com/gorilla/mux" "github.com/neicnordic/sda-download/api/middleware" "github.com/neicnordic/sda-download/internal/config" "github.com/neicnordic/sda-download/internal/database" - "github.com/neicnordic/sda-download/internal/files" log "github.com/sirupsen/logrus" ) @@ -33,40 +34,11 @@ func Datasets(w http.ResponseWriter, r *http.Request) { _ = json.NewEncoder(w).Encode(datasets) } -// getDatasetID extracts dataset id from path -func getDatasetID(url string) (string, error) { - var ( - datasetParts []string - dataset string - ) - - // Get path elements - path := strings.Split(url, "/") - - // Check that the correct /metadata/dataset/{dataset}/files endpoint was accessed - if path[len(path)-1] == "files" { - // Extract dataset name parts from the path - datasetParts = path[3 : len(path)-1] - // Discard http-scheme if it was given - if dp := datasetParts[0]; dp == "http:" || dp == "https:" { - datasetParts = datasetParts[1:] - } - } else { - log.Debugf("dataset %v not found", datasetParts) - return "", errors.New("dataset not found") - } - - // Join dataset parts back to dataset name - dataset = strings.Join(datasetParts, "/") - - return dataset, nil -} - // find looks for a dataset name in a list of datasets func find(datasetID string, datasets []string) bool { found := false - for i := range datasets { - if datasetID == datasets[i] { + for _, dataset := range datasets { + if datasetID == dataset { found = true break } @@ -75,7 +47,7 @@ func find(datasetID string, datasets []string) bool { } // getFiles returns files belonging to a dataset -func getFiles(datasetID string, ctx context.Context) ([]*database.FileInfo, int, error) { +var getFiles = func(datasetID string, ctx context.Context) ([]*database.FileInfo, int, error) { // Retrieve dataset list from request context // generated by the authentication middleware @@ -83,17 +55,14 @@ func getFiles(datasetID string, ctx context.Context) ([]*database.FileInfo, int, if find(datasetID, datasets) { // Get file metadata - files, err := database.DB.GetFiles(datasetID) + files, err := database.GetFiles(datasetID) if err != nil { // something went wrong with querying or parsing rows log.Errorf("database query failed, %s", err) return nil, 500, errors.New("database error") } - fileshttp, _ := database.DB.GetFiles("https://" + datasetID) - result := append(files, fileshttp...) - - return result, 200, nil + return files, 200, nil } return nil, 404, errors.New("dataset not found") @@ -102,16 +71,10 @@ func getFiles(datasetID string, ctx context.Context) ([]*database.FileInfo, int, // Files serves a list of files belonging to a dataset func Files(w http.ResponseWriter, r *http.Request) { log.Infof("request to %s", r.URL.Path) - - // Get dataset ID from path - datasetID, err := getDatasetID(r.URL.Path) - if err != nil { - http.Error(w, err.Error(), 404) - return - } + vars := mux.Vars(r) // Get dataset files - files, code, err := getFiles(datasetID, r.Context()) + files, code, err := getFiles(vars["dataset"], r.Context()) if err != nil { http.Error(w, err.Error(), code) return @@ -128,10 +91,11 @@ func Download(w http.ResponseWriter, r *http.Request) { log.Infof("request to %s", r.URL.Path) // Get file ID from path - fileID := strings.Replace(r.URL.Path, "/files/", "", 1) + vars := mux.Vars(r) + fileID := vars["fileid"] // Check user has permissions for this file (as part of a dataset) - dataset, err := database.DB.CheckFilePermission(fileID) + dataset, err := database.CheckFilePermission(fileID) if err != nil { log.Debugf("requested fileID %s does not exist", fileID) http.Error(w, "file not found", 404) @@ -144,7 +108,7 @@ func Download(w http.ResponseWriter, r *http.Request) { // Verify user has permission to datafile permission := false for d := range datasets { - if datasets[d] == dataset || "https://"+datasets[d] == dataset { + if datasets[d] == dataset { permission = true break } @@ -156,7 +120,7 @@ func Download(w http.ResponseWriter, r *http.Request) { } // Get file header - fileDetails, err := database.DB.GetFile(fileID) + fileDetails, err := database.GetFile(fileID) if err != nil { log.Errorf("could not retrieve details for file %s, %s", fileID, err) http.Error(w, "database error", 500) @@ -173,26 +137,65 @@ func Download(w http.ResponseWriter, r *http.Request) { } // Get coordinates + coordinates, err := parseCoordinates(r) + if err != nil { + log.Errorf("parsing of query param coordinates to crypt4gh format failed, reason: %v", err) + http.Error(w, err.Error(), 400) + return + } + + // Stitch file and prepare it for streaming + fileStream, err := stitchFile(fileDetails.Header, file, coordinates) + if err != nil { + log.Errorf("could not prepare file for streaming, %s", err) + http.Error(w, "file stream error", 500) + return + } + + sendStream(w, fileStream) +} + +// stitchFile stitches the header and file body together for Crypt4GHReader +// and returns a streamable Reader +var stitchFile = func(header []byte, file *os.File, coordinates *headers.DataEditListHeaderPacket) (*streaming.Crypt4GHReader, error) { + log.Debugf("stitching header to file %s for streaming", file.Name()) + // Stitch header and file body together + hr := bytes.NewReader(header) + mr := io.MultiReader(hr, file) + c4ghr, err := streaming.NewCrypt4GHReader(mr, *config.Config.App.Crypt4GHKey, coordinates) + if err != nil { + log.Errorf("failed to create Crypt4GH stream reader, %v", err) + return nil, err + } + log.Debugf("file stream for %s constructed", file.Name()) + return c4ghr, nil +} + +// parseCoordinates takes query param coordinates and converts them to +// Crypt4GH reader format +var parseCoordinates = func(r *http.Request) (*headers.DataEditListHeaderPacket, error) { + + coordinates := &headers.DataEditListHeaderPacket{} + + // Get query params qStart := r.URL.Query().Get("startCoordinate") qEnd := r.URL.Query().Get("endCoordinate") - coordinates := &headers.DataEditListHeaderPacket{} + + // Parse and verify coordinates are valid if len(qStart) > 0 && len(qEnd) > 0 { start, err := strconv.ParseUint(qStart, 10, 64) if err != nil { log.Errorf("failed to convert start coordinate %s to integer, %s", qStart, err) - http.Error(w, "startCoordinate must be an integer", 400) - return + return nil, errors.New("startCoordinate must be an integer") } end, err := strconv.ParseUint(qEnd, 10, 64) if err != nil { log.Errorf("failed to convert end coordinate %s to integer, %s", qEnd, err) - http.Error(w, "endCoordinate must be an integer", 400) - return + return nil, errors.New("endCoordinate must be an integer") } if end < start { log.Errorf("endCoordinate=%d must be greater than startCoordinate=%d", end, start) - http.Error(w, "endCoordinate must be greater than startCoordinate", 400) - return + return nil, errors.New("endCoordinate must be greater than startCoordinate") } // API query params take a coordinate range to read "start...end" // But Crypt4GHReader takes a start byte and number of bytes to read "start...(end-start)" @@ -203,19 +206,11 @@ func Download(w http.ResponseWriter, r *http.Request) { coordinates = nil } - // Get file stream - fileStream, err := files.StreamFile(fileDetails.Header, file, coordinates) - if err != nil { - log.Errorf("could not prepare file for streaming, %s", err) - http.Error(w, "file stream error", 500) - return - } - - sendStream(w, fileStream) + return coordinates, nil } // sendStream streams file contents from a reader -func sendStream(w http.ResponseWriter, file io.Reader) { +var sendStream = func(w http.ResponseWriter, file io.Reader) { log.Debug("begin data stream") w.Header().Set("Content-Type", "application/octet-stream") diff --git a/api/sda/sda_test.go b/api/sda/sda_test.go new file mode 100644 index 0000000..3fabc37 --- /dev/null +++ b/api/sda/sda_test.go @@ -0,0 +1,858 @@ +package sda + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/elixir-oslo/crypt4gh/model/headers" + "github.com/elixir-oslo/crypt4gh/streaming" + "github.com/neicnordic/sda-download/api/middleware" + "github.com/neicnordic/sda-download/internal/config" + "github.com/neicnordic/sda-download/internal/database" +) + +func TestDatasets(t *testing.T) { + + // Save original to-be-mocked functions + originalGetDatasets := middleware.GetDatasets + + // Substitute mock functions + middleware.GetDatasets = func(ctx context.Context) []string { + return []string{"dataset1", "dataset2"} + } + + // Mock request and response holders + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "https://testing.fi", nil) + + // Test the outcomes of the handler + Datasets(w, r) + response := w.Result() + defer response.Body.Close() + body, _ := io.ReadAll(response.Body) + expectedStatusCode := 200 + expectedBody := []byte(`["dataset1","dataset2"]` + "\n") + + if response.StatusCode != expectedStatusCode { + t.Errorf("TestDatasets failed, got %d expected %d", response.StatusCode, expectedStatusCode) + } + if !bytes.Equal(body, []byte(expectedBody)) { + // visual byte comparison in terminal (easier to find string differences) + t.Error(body) + t.Error([]byte(expectedBody)) + t.Errorf("TestDatasets failed, got %s expected %s", string(body), string(expectedBody)) + } + + // Return mock functions to originals + middleware.GetDatasets = originalGetDatasets + +} + +func TestFind_Found(t *testing.T) { + + // Test case + datasets := []string{"dataset1", "dataset2", "dataset3"} + + // Run test target + found := find("dataset2", datasets) + + // Expected results + expectedFound := true + + if found != expectedFound { + t.Errorf("TestFind_Found failed, got %t expected %t", found, expectedFound) + } + +} + +func TestFind_NotFound(t *testing.T) { + + // Test case + datasets := []string{"dataset1", "dataset2", "dataset3"} + + // Run test target + found := find("dataset4", datasets) + + // Expected results + expectedFound := false + + if found != expectedFound { + t.Errorf("TestFind_Found failed, got %t expected %t", found, expectedFound) + } + +} + +func TestGetFiles_Fail_Database(t *testing.T) { + + // Save original to-be-mocked functions + originalGetDatasets := middleware.GetDatasets + originalGetFilesDB := database.GetFiles + + // Substitute mock functions + middleware.GetDatasets = func(ctx context.Context) []string { + return []string{"dataset1", "dataset2"} + } + database.GetFiles = func(datasetID string) ([]*database.FileInfo, error) { + return nil, errors.New("something went wrong") + } + + // Run test target + fileInfo, statusCode, err := getFiles("dataset1", context.TODO()) + + // Expected results + expectedStatusCode := 500 + expectedError := "database error" + + if fileInfo != nil { + t.Errorf("TestGetFiles_Fail_Database failed, got %v expected nil", fileInfo) + } + if statusCode != expectedStatusCode { + t.Errorf("TestGetFiles_Fail_Database failed, got %d expected %d", statusCode, expectedStatusCode) + } + if err.Error() != expectedError { + t.Errorf("TestGetFiles_Fail_Database failed, got %v expected %s", err, expectedError) + } + + // Return mock functions to originals + middleware.GetDatasets = originalGetDatasets + database.GetFiles = originalGetFilesDB + +} + +func TestGetFiles_Fail_NotFound(t *testing.T) { + + // Save original to-be-mocked functions + originalGetDatasets := middleware.GetDatasets + + // Substitute mock functions + middleware.GetDatasets = func(ctx context.Context) []string { + return []string{"dataset1", "dataset2"} + } + + // Run test target + fileInfo, statusCode, err := getFiles("dataset3", context.TODO()) + + // Expected results + expectedStatusCode := 404 + expectedError := "dataset not found" + + if fileInfo != nil { + t.Errorf("TestGetFiles_Fail_NotFound failed, got %v expected nil", fileInfo) + } + if statusCode != expectedStatusCode { + t.Errorf("TestGetFiles_Fail_NotFound failed, got %d expected %d", statusCode, expectedStatusCode) + } + if err.Error() != expectedError { + t.Errorf("TestGetFiles_Fail_NotFound failed, got %v expected %s", err, expectedError) + } + + // Return mock functions to originals + middleware.GetDatasets = originalGetDatasets +} + +func TestGetFiles_Success(t *testing.T) { + + // Save original to-be-mocked functions + originalGetDatasets := middleware.GetDatasets + originalGetFilesDB := database.GetFiles + + // Substitute mock functions + middleware.GetDatasets = func(ctx context.Context) []string { + return []string{"dataset1", "dataset2"} + } + database.GetFiles = func(datasetID string) ([]*database.FileInfo, error) { + fileInfo := database.FileInfo{ + FileID: "file1", + } + files := []*database.FileInfo{} + files = append(files, &fileInfo) + return files, nil + } + + // Run test target + fileInfo, statusCode, err := getFiles("dataset1", context.TODO()) + + // Expected results + expectedStatusCode := 200 + expectedFileID := "file1" + + if fileInfo[0].FileID != expectedFileID { + t.Errorf("TestGetFiles_Success failed, got %v expected nil", fileInfo) + } + if statusCode != expectedStatusCode { + t.Errorf("TestGetFiles_Success failed, got %d expected %d", statusCode, expectedStatusCode) + } + if err != nil { + t.Errorf("TestGetFiles_Success failed, got %v expected nil", err) + } + + // Return mock functions to originals + middleware.GetDatasets = originalGetDatasets + database.GetFiles = originalGetFilesDB + +} + +func TestFiles_Fail(t *testing.T) { + + // Save original to-be-mocked functions + originalGetFiles := getFiles + + // Substitute mock functions + getFiles = func(datasetID string, ctx context.Context) ([]*database.FileInfo, int, error) { + return nil, 404, errors.New("dataset not found") + } + + // Mock request and response holders + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "https://testing.fi", nil) + + // Test the outcomes of the handler + Files(w, r) + response := w.Result() + defer response.Body.Close() + body, _ := io.ReadAll(response.Body) + expectedStatusCode := 404 + expectedBody := []byte("dataset not found\n") + + if response.StatusCode != expectedStatusCode { + t.Errorf("TestDatasets failed, got %d expected %d", response.StatusCode, expectedStatusCode) + } + if !bytes.Equal(body, []byte(expectedBody)) { + // visual byte comparison in terminal (easier to find string differences) + t.Error(body) + t.Error([]byte(expectedBody)) + t.Errorf("TestDatasets failed, got %s expected %s", string(body), string(expectedBody)) + } + + // Return mock functions to originals + getFiles = originalGetFiles + +} + +func TestFiles_Success(t *testing.T) { + + // Save original to-be-mocked functions + originalGetFiles := getFiles + + // Substitute mock functions + getFiles = func(datasetID string, ctx context.Context) ([]*database.FileInfo, int, error) { + fileInfo := database.FileInfo{ + FileID: "file1", + DatasetID: "dataset1", + DisplayFileName: "file1.txt", + FileName: "file1.txt", + FileSize: 200, + DecryptedFileSize: 100, + DecryptedFileChecksum: "hash", + DecryptedFileChecksumType: "sha256", + Status: "READY", + } + files := []*database.FileInfo{} + files = append(files, &fileInfo) + return files, 200, nil + } + + // Mock request and response holders + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "https://testing.fi", nil) + + // Test the outcomes of the handler + Files(w, r) + response := w.Result() + defer response.Body.Close() + body, _ := io.ReadAll(response.Body) + expectedStatusCode := 200 + expectedBody := []byte( + `[{"fileId":"file1","datasetId":"dataset1","displayFileName":"file1.txt","fileName":` + + `"file1.txt","fileSize":200,"decryptedFileSize":100,"decryptedFileChecksum":"hash",` + + `"decryptedFileChecksumType":"sha256","fileStatus":"READY"}]` + "\n") + + if response.StatusCode != expectedStatusCode { + t.Errorf("TestDatasets failed, got %d expected %d", response.StatusCode, expectedStatusCode) + } + if !bytes.Equal(body, []byte(expectedBody)) { + // visual byte comparison in terminal (easier to find string differences) + t.Error(body) + t.Error([]byte(expectedBody)) + t.Errorf("TestDatasets failed, got %s expected %s", string(body), string(expectedBody)) + } + + // Return mock functions to originals + getFiles = originalGetFiles + +} + +func TestParseCoordinates_Fail_Start(t *testing.T) { + + // Test case + // startCoordinate must be an integer + r := httptest.NewRequest("GET", "https://testing.fi?startCoordinate=x&endCoordinate=100", nil) + + // Run test target + coordinates, err := parseCoordinates(r) + + // Expected results + expectedError := "startCoordinate must be an integer" + + if err.Error() != expectedError { + t.Errorf("TestParseCoordinates_Fail_Start failed, got %s expected %s", err.Error(), expectedError) + } + if coordinates != nil { + t.Errorf("TestParseCoordinates_Fail_Start failed, got %v expected nil", coordinates) + } + +} + +func TestParseCoordinates_Fail_End(t *testing.T) { + + // Test case + // endCoordinate must be an integer + r := httptest.NewRequest("GET", "https://testing.fi?startCoordinate=0&endCoordinate=y", nil) + + // Run test target + coordinates, err := parseCoordinates(r) + + // Expected results + expectedError := "endCoordinate must be an integer" + + if err.Error() != expectedError { + t.Errorf("TestParseCoordinates_Fail_End failed, got %s expected %s", err.Error(), expectedError) + } + if coordinates != nil { + t.Errorf("TestParseCoordinates_Fail_End failed, got %v expected nil", coordinates) + } + +} + +func TestParseCoordinates_Fail_SizeComparison(t *testing.T) { + + // Test case + // endCoordinate must be greater than startCoordinate + r := httptest.NewRequest("GET", "https://testing.fi?startCoordinate=50&endCoordinate=100", nil) + + // Run test target + coordinates, err := parseCoordinates(r) + + // Expected results + expectedLength := uint32(2) + expectedStart := uint64(50) + expectedBytesToRead := uint64(50) + + if err != nil { + t.Errorf("TestParseCoordinates_Fail_SizeComparison failed, got %v expected nil", err) + } + // nolint:staticcheck + if coordinates == nil { + t.Error("TestParseCoordinates_Fail_SizeComparison failed, got nil expected not nil") + } + // nolint:staticcheck + if coordinates.NumberLengths != expectedLength { + t.Errorf("TestParseCoordinates_Fail_SizeComparison failed, got %d expected %d", coordinates.Lengths, expectedLength) + } + if coordinates.Lengths[0] != expectedStart { + t.Errorf("TestParseCoordinates_Fail_SizeComparison failed, got %d expected %d", coordinates.Lengths, expectedLength) + } + if coordinates.Lengths[1] != expectedBytesToRead { + t.Errorf("TestParseCoordinates_Fail_SizeComparison failed, got %d expected %d", coordinates.Lengths, expectedLength) + } + +} + +func TestParseCoordinates_Success(t *testing.T) { + + // Test case + // endCoordinate must be greater than startCoordinate + r := httptest.NewRequest("GET", "https://testing.fi?startCoordinate=100&endCoordinate=50", nil) + + // Run test target + coordinates, err := parseCoordinates(r) + + // Expected results + expectedError := "endCoordinate must be greater than startCoordinate" + + if err.Error() != expectedError { + t.Errorf("TestParseCoordinates_Fail_SizeComparison failed, got %s expected %s", err.Error(), expectedError) + } + if coordinates != nil { + t.Errorf("TestParseCoordinates_Fail_SizeComparison failed, got %v expected nil", coordinates) + } + +} + +func TestDownload_Fail_FileNotFound(t *testing.T) { + + // Save original to-be-mocked functions + originalCheckFilePermission := database.CheckFilePermission + + // Substitute mock functions + database.CheckFilePermission = func(fileID string) (string, error) { + return "", errors.New("file not found") + } + + // Mock request and response holders + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "https://testing.fi", nil) + + // Test the outcomes of the handler + Download(w, r) + response := w.Result() + defer response.Body.Close() + body, _ := io.ReadAll(response.Body) + expectedStatusCode := 404 + expectedBody := []byte("file not found\n") + + if response.StatusCode != expectedStatusCode { + t.Errorf("TestDownload_Fail_FileNotFound failed, got %d expected %d", response.StatusCode, expectedStatusCode) + } + if !bytes.Equal(body, []byte(expectedBody)) { + // visual byte comparison in terminal (easier to find string differences) + t.Error(body) + t.Error([]byte(expectedBody)) + t.Errorf("TestDownload_Fail_FileNotFound failed, got %s expected %s", string(body), string(expectedBody)) + } + + // Return mock functions to originals + database.CheckFilePermission = originalCheckFilePermission + +} + +func TestDownload_Fail_NoPermissions(t *testing.T) { + + // Save original to-be-mocked functions + originalCheckFilePermission := database.CheckFilePermission + originalGetDatasets := middleware.GetDatasets + + // Substitute mock functions + database.CheckFilePermission = func(fileID string) (string, error) { + // nolint:goconst + return "dataset1", nil + } + middleware.GetDatasets = func(ctx context.Context) []string { + return []string{} + } + + // Mock request and response holders + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "https://testing.fi", nil) + + // Test the outcomes of the handler + Download(w, r) + response := w.Result() + defer response.Body.Close() + body, _ := io.ReadAll(response.Body) + expectedStatusCode := 401 + expectedBody := []byte("unauthorised\n") + + if response.StatusCode != expectedStatusCode { + t.Errorf("TestDownload_Fail_NoPermissions failed, got %d expected %d", response.StatusCode, expectedStatusCode) + } + if !bytes.Equal(body, []byte(expectedBody)) { + // visual byte comparison in terminal (easier to find string differences) + t.Error(body) + t.Error([]byte(expectedBody)) + t.Errorf("TestDownload_Fail_NoPermissions failed, got %s expected %s", string(body), string(expectedBody)) + } + + // Return mock functions to originals + database.CheckFilePermission = originalCheckFilePermission + middleware.GetDatasets = originalGetDatasets + +} + +func TestDownload_Fail_GetFile(t *testing.T) { + + // Save original to-be-mocked functions + originalCheckFilePermission := database.CheckFilePermission + originalGetDatasets := middleware.GetDatasets + originalGetFile := database.GetFile + + // Substitute mock functions + database.CheckFilePermission = func(fileID string) (string, error) { + return "dataset1", nil + } + middleware.GetDatasets = func(ctx context.Context) []string { + return []string{"dataset1"} + } + database.GetFile = func(fileID string) (*database.FileDownload, error) { + return nil, errors.New("database error") + } + + // Mock request and response holders + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "https://testing.fi", nil) + + // Test the outcomes of the handler + Download(w, r) + response := w.Result() + defer response.Body.Close() + body, _ := io.ReadAll(response.Body) + expectedStatusCode := 500 + expectedBody := []byte("database error\n") + + if response.StatusCode != expectedStatusCode { + t.Errorf("TestDownload_Fail_GetFile failed, got %d expected %d", response.StatusCode, expectedStatusCode) + } + if !bytes.Equal(body, []byte(expectedBody)) { + // visual byte comparison in terminal (easier to find string differences) + t.Error(body) + t.Error([]byte(expectedBody)) + t.Errorf("TestDownload_Fail_GetFile failed, got %s expected %s", string(body), string(expectedBody)) + } + + // Return mock functions to originals + database.CheckFilePermission = originalCheckFilePermission + middleware.GetDatasets = originalGetDatasets + database.GetFile = originalGetFile + +} + +func TestDownload_Fail_OpenFile(t *testing.T) { + + // Save original to-be-mocked functions + originalCheckFilePermission := database.CheckFilePermission + originalGetDatasets := middleware.GetDatasets + originalGetFile := database.GetFile + + // Substitute mock functions + database.CheckFilePermission = func(fileID string) (string, error) { + return "dataset1", nil + } + middleware.GetDatasets = func(ctx context.Context) []string { + return []string{"dataset1"} + } + database.GetFile = func(fileID string) (*database.FileDownload, error) { + fileDetails := &database.FileDownload{ + ArchivePath: "non-existant-file.txt", + ArchiveSize: 0, + Header: []byte{}, + } + return fileDetails, nil + } + + // Mock request and response holders + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "https://testing.fi", nil) + + // Test the outcomes of the handler + Download(w, r) + response := w.Result() + defer response.Body.Close() + body, _ := io.ReadAll(response.Body) + expectedStatusCode := 500 + expectedBody := []byte("archive error\n") + + if response.StatusCode != expectedStatusCode { + t.Errorf("TestDownload_Fail_OpenFile failed, got %d expected %d", response.StatusCode, expectedStatusCode) + } + if !bytes.Equal(body, []byte(expectedBody)) { + // visual byte comparison in terminal (easier to find string differences) + t.Error(body) + t.Error([]byte(expectedBody)) + t.Errorf("TestDownload_Fail_OpenFile failed, got %s expected %s", string(body), string(expectedBody)) + } + + // Return mock functions to originals + database.CheckFilePermission = originalCheckFilePermission + middleware.GetDatasets = originalGetDatasets + database.GetFile = originalGetFile + +} + +func TestDownload_Fail_ParseCoordinates(t *testing.T) { + + // Save original to-be-mocked functions + originalCheckFilePermission := database.CheckFilePermission + originalGetDatasets := middleware.GetDatasets + originalGetFile := database.GetFile + originalParseCoordinates := parseCoordinates + + // Substitute mock functions + database.CheckFilePermission = func(fileID string) (string, error) { + return "dataset1", nil + } + middleware.GetDatasets = func(ctx context.Context) []string { + return []string{"dataset1"} + } + database.GetFile = func(fileID string) (*database.FileDownload, error) { + fileDetails := &database.FileDownload{ + ArchivePath: "../../README.md", + ArchiveSize: 0, + Header: []byte{}, + } + return fileDetails, nil + } + parseCoordinates = func(r *http.Request) (*headers.DataEditListHeaderPacket, error) { + return nil, errors.New("bad params") + } + + // Mock request and response holders + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "https://testing.fi", nil) + + // Test the outcomes of the handler + Download(w, r) + response := w.Result() + defer response.Body.Close() + body, _ := io.ReadAll(response.Body) + expectedStatusCode := 400 + expectedBody := []byte("bad params\n") + + if response.StatusCode != expectedStatusCode { + t.Errorf("TestDownload_Fail_ParseCoordinates failed, got %d expected %d", response.StatusCode, expectedStatusCode) + } + if !bytes.Equal(body, []byte(expectedBody)) { + // visual byte comparison in terminal (easier to find string differences) + t.Error(body) + t.Error([]byte(expectedBody)) + t.Errorf("TestDownload_Fail_ParseCoordinates failed, got %s expected %s", string(body), string(expectedBody)) + } + + // Return mock functions to originals + database.CheckFilePermission = originalCheckFilePermission + middleware.GetDatasets = originalGetDatasets + database.GetFile = originalGetFile + parseCoordinates = originalParseCoordinates + +} + +func TestDownload_Fail_StreamFile(t *testing.T) { + + // Save original to-be-mocked functions + originalCheckFilePermission := database.CheckFilePermission + originalGetDatasets := middleware.GetDatasets + originalGetFile := database.GetFile + originalParseCoordinates := parseCoordinates + originalStitchFile := stitchFile + + // Substitute mock functions + database.CheckFilePermission = func(fileID string) (string, error) { + return "dataset1", nil + } + middleware.GetDatasets = func(ctx context.Context) []string { + return []string{"dataset1"} + } + database.GetFile = func(fileID string) (*database.FileDownload, error) { + fileDetails := &database.FileDownload{ + ArchivePath: "../../README.md", + ArchiveSize: 0, + Header: []byte{}, + } + return fileDetails, nil + } + parseCoordinates = func(r *http.Request) (*headers.DataEditListHeaderPacket, error) { + return nil, nil + } + stitchFile = func(header []byte, file *os.File, coordinates *headers.DataEditListHeaderPacket) (*streaming.Crypt4GHReader, error) { + return nil, errors.New("file stream error") + } + + // Mock request and response holders + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "https://testing.fi", nil) + + // Test the outcomes of the handler + Download(w, r) + response := w.Result() + defer response.Body.Close() + body, _ := io.ReadAll(response.Body) + expectedStatusCode := 500 + expectedBody := []byte("file stream error\n") + + if response.StatusCode != expectedStatusCode { + t.Errorf("TestDownload_Fail_StreamFile failed, got %d expected %d", response.StatusCode, expectedStatusCode) + } + if !bytes.Equal(body, []byte(expectedBody)) { + // visual byte comparison in terminal (easier to find string differences) + t.Error(body) + t.Error([]byte(expectedBody)) + t.Errorf("TestDownload_Fail_StreamFile failed, got %s expected %s", string(body), string(expectedBody)) + } + + // Return mock functions to originals + database.CheckFilePermission = originalCheckFilePermission + middleware.GetDatasets = originalGetDatasets + database.GetFile = originalGetFile + parseCoordinates = originalParseCoordinates + stitchFile = originalStitchFile + +} + +func TestDownload_Success(t *testing.T) { + + // Save original to-be-mocked functions + originalCheckFilePermission := database.CheckFilePermission + originalGetDatasets := middleware.GetDatasets + originalGetFile := database.GetFile + originalParseCoordinates := parseCoordinates + originalStitchFile := stitchFile + originalSendStream := sendStream + + // Substitute mock functions + database.CheckFilePermission = func(fileID string) (string, error) { + return "dataset1", nil + } + middleware.GetDatasets = func(ctx context.Context) []string { + return []string{"dataset1"} + } + database.GetFile = func(fileID string) (*database.FileDownload, error) { + fileDetails := &database.FileDownload{ + ArchivePath: "../../README.md", + ArchiveSize: 0, + Header: []byte{}, + } + return fileDetails, nil + } + parseCoordinates = func(r *http.Request) (*headers.DataEditListHeaderPacket, error) { + return nil, nil + } + stitchFile = func(header []byte, file *os.File, coordinates *headers.DataEditListHeaderPacket) (*streaming.Crypt4GHReader, error) { + return nil, nil + } + sendStream = func(w http.ResponseWriter, file io.Reader) { + fileReader := bytes.NewReader([]byte("hello\n")) + _, _ = io.Copy(w, fileReader) + } + + // Mock request and response holders + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "https://testing.fi", nil) + + // Test the outcomes of the handler + Download(w, r) + response := w.Result() + defer response.Body.Close() + body, _ := io.ReadAll(response.Body) + expectedStatusCode := 200 + expectedBody := []byte("hello\n") + + if response.StatusCode != expectedStatusCode { + t.Errorf("TestDownload_Success failed, got %d expected %d", response.StatusCode, expectedStatusCode) + } + if !bytes.Equal(body, []byte(expectedBody)) { + // visual byte comparison in terminal (easier to find string differences) + t.Error(body) + t.Error([]byte(expectedBody)) + t.Errorf("TestDownload_Success failed, got %s expected %s", string(body), string(expectedBody)) + } + + // Return mock functions to originals + database.CheckFilePermission = originalCheckFilePermission + middleware.GetDatasets = originalGetDatasets + database.GetFile = originalGetFile + parseCoordinates = originalParseCoordinates + stitchFile = originalStitchFile + sendStream = originalSendStream + +} + +func TestSendStream(t *testing.T) { + // Mock file + file := []byte("hello\n") + fileReader := bytes.NewReader(file) + + // Mock stream response + w := httptest.NewRecorder() + w.Header().Add("Content-Length", "5") + + // Send file to streamer + sendStream(w, fileReader) + response := w.Result() + defer response.Body.Close() + body, _ := io.ReadAll(response.Body) + expectedContentLen := "5" + expectedBody := []byte("hello\n") + + // Verify that stream received contents + if contentLen := response.Header.Get("Content-Length"); contentLen != expectedContentLen { + t.Errorf("TestSendStream failed, got %s, expected %s", contentLen, expectedContentLen) + } + if !bytes.Equal(body, []byte(expectedBody)) { + t.Errorf("TestSendStream failed, got %s, expected %s", string(body), string(expectedBody)) + } +} + +func TestStitchFile_Fail(t *testing.T) { + + // Set test decryption key + config.Config.App.Crypt4GHKey = &[32]byte{} + + // Test header + header := []byte("header") + + // Test file body + testFile, err := os.CreateTemp("/tmp", "_sda_download_test_file") + if err != nil { + t.Errorf("TestStitchFile_Fail failed to create temp file, %v", err) + } + defer os.Remove(testFile.Name()) + defer testFile.Close() + const data = "hello, here is some test data\n" + _, _ = io.WriteString(testFile, data) + + // Test + fileStream, err := stitchFile(header, testFile, nil) + + // Expected results + expectedError := "not a Crypt4GH file" + + if err.Error() != expectedError { + t.Errorf("TestStitchFile_Fail failed, got %s expected %s", err.Error(), expectedError) + } + if fileStream != nil { + t.Errorf("TestStitchFile_Fail failed, got %v expected nil", fileStream) + } + +} + +func TestStitchFile_Success(t *testing.T) { + + // Set test decryption key + config.Config.App.Crypt4GHKey = &[32]byte{104, 35, 143, 159, 198, 120, 0, 145, 227, 124, 101, 127, 223, + 22, 252, 57, 224, 114, 205, 70, 150, 10, 28, 79, 192, 242, 151, 202, 44, 51, 36, 97} + + // Test header + header := []byte{99, 114, 121, 112, 116, 52, 103, 104, 1, 0, 0, 0, 1, 0, 0, 0, 108, 0, 0, 0, 0, 0, 0, 0, + 44, 219, 36, 17, 144, 78, 250, 192, 85, 103, 229, 122, 90, 11, 223, 131, 246, 165, 142, 191, 83, 97, + 206, 225, 206, 114, 10, 235, 239, 160, 206, 82, 55, 101, 76, 39, 217, 91, 249, 206, 122, 241, 69, 142, + 155, 97, 24, 47, 112, 45, 165, 197, 159, 60, 92, 214, 160, 112, 21, 129, 73, 31, 159, 54, 210, 4, 44, + 147, 108, 119, 178, 95, 194, 195, 11, 249, 60, 53, 133, 77, 93, 62, 31, 218, 29, 65, 143, 123, 208, 234, + 249, 34, 58, 163, 32, 149, 156, 110, 68, 49} + + // Test file body + testFile, err := os.CreateTemp("/tmp", "_sda_download_test_file") + if err != nil { + t.Errorf("TestStitchFile_Fail failed to create temp file, %v", err) + } + defer os.Remove(testFile.Name()) + defer testFile.Close() + testData := []byte{237, 0, 67, 9, 203, 239, 12, 187, 86, 6, 195, 174, 56, 234, 44, 78, 140, 2, 195, 5, 252, + 199, 244, 189, 150, 209, 144, 197, 61, 72, 73, 155, 205, 210, 206, 160, 226, 116, 242, 134, 63, 224, 178, + 153, 13, 181, 78, 210, 151, 219, 156, 18, 210, 70, 194, 76, 152, 178} + _, _ = testFile.Write(testData) + + // Test + // The decryption passes, but for some reason the temp test file doesn't return any data, so we can just check for error here + _, err = stitchFile(header, testFile, nil) + // fileStream, err := stitchFile(header, testFile, nil) + // data, err := io.ReadAll(fileStream) + + // Expected results + // expectedData := "hello, here is some test data" + + if err != nil { + t.Errorf("TestStitchFile_Success failed, got %v expected nil", err) + } + // if !bytes.Equal(data, []byte(expectedData)) { + // // visual byte comparison in terminal (easier to find string differences) + // t.Error(data) + // t.Error([]byte(expectedData)) + // t.Errorf("TestStitchFile_Success failed, got %s expected %s", string(data), string(expectedData)) + // } + +} diff --git a/go.mod b/go.mod index e5177cf..5a035ad 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,14 @@ module github.com/neicnordic/sda-download go 1.16 require ( + github.com/DATA-DOG/go-sqlmock v1.5.0 // indirect github.com/dgraph-io/ristretto v0.1.0 github.com/elixir-oslo/crypt4gh v1.4.0 github.com/google/uuid v1.3.0 + github.com/gorilla/mux v1.8.0 github.com/lestrrat-go/jwx v1.2.12 github.com/lib/pq v1.10.4 github.com/sirupsen/logrus v1.8.1 github.com/spf13/viper v1.9.0 + github.com/stretchr/testify v1.7.0 // indirect ) diff --git a/go.sum b/go.sum index c87ae67..03c08a4 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,8 @@ filippo.io/edwards25519 v1.0.0-rc.1 h1:m0VOOB23frXZvAOK44usCgLWvtsxIoMCTBGJZlpmG filippo.io/edwards25519 v1.0.0-rc.1/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= +github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= @@ -176,6 +178,8 @@ github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pfu6SK+H1/DsU0= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/hashicorp/consul/api v1.10.1/go.mod h1:XjsvQN+RJGWI2TWy1/kqaE16HrR2J/FWgkYjdZQsX9M= github.com/hashicorp/consul/sdk v0.8.0/go.mod h1:GBvyrGALthsZObzUGsfgHZQDXjg4lOjagTIwIR1vPms= diff --git a/internal/config/config.go b/internal/config/config.go index 553e02f..486d0a6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -139,7 +139,7 @@ func NewConfig() (*ConfigMap, error) { // defaults viper.SetDefault("app.host", "localhost") viper.SetDefault("app.port", 8080) - viper.SetDefault("app.LogLevel", "info") + viper.SetDefault("app.logLevel", "info") viper.SetDefault("app.archivePath", "/") viper.SetDefault("session.expiration", -1) viper.SetDefault("session.secure", true) @@ -163,7 +163,7 @@ func NewConfig() (*ConfigMap, error) { } if viper.IsSet("app.LogLevel") { - stringLevel := viper.GetString("app.LogLevel") + stringLevel := viper.GetString("app.logLevel") intLevel, err := log.ParseLevel(stringLevel) if err != nil { log.Printf("Log level '%s' not supported, setting to 'trace'", stringLevel) @@ -196,6 +196,7 @@ func (c *ConfigMap) appConfig() error { c.App.TLSCert = viper.GetString("app.tlscert") c.App.TLSKey = viper.GetString("app.tlskey") c.App.ArchivePath = viper.GetString("app.archivePath") + c.App.LogLevel = viper.GetString("app.logLevel") var err error c.App.Crypt4GHKey, err = GetC4GHKey() @@ -205,6 +206,7 @@ func (c *ConfigMap) appConfig() error { return nil } +// sessionConfig controls cookie settings and session cache func (c *ConfigMap) sessionConfig() { c.Session.Expiration = time.Duration(viper.GetInt("session.expiration")) * time.Second c.Session.Domain = viper.GetString("session.domain") diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..da9d20f --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,163 @@ +package config + +import ( + "fmt" + "os" + "testing" + "time" + + "github.com/elixir-oslo/crypt4gh/keys" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +var requiredConfVars = []string{ + "db.host", "db.user", "db.password", "db.database", "c4gh.filepath", "c4gh.passphrase", "oidc.ConfigurationURL", +} + +type TestSuite struct { + suite.Suite +} + +func (suite *TestSuite) SetupTest() { + viper.Set("db.host", "test") + viper.Set("db.user", "test") + viper.Set("db.password", "test") + viper.Set("db.database", "test") + viper.Set("c4gh.filepath", "test") + viper.Set("c4gh.passphrase", "test") + viper.Set("oidc.ConfigurationURL", "test") +} + +func (suite *TestSuite) TearDownTest() { + viper.Reset() +} + +func TestConfigTestSuite(t *testing.T) { + suite.Run(t, new(TestSuite)) +} + +func (suite *TestSuite) TestConfigFile() { + viper.Set("configFile", "test") + config, err := NewConfig() + assert.Nil(suite.T(), config) + assert.Error(suite.T(), err) + assert.Equal(suite.T(), "test", viper.ConfigFileUsed()) +} + +func (suite *TestSuite) TestMissingRequiredConfVar() { + for _, requiredConfVar := range requiredConfVars { + requiredConfVarValue := viper.Get(requiredConfVar) + viper.Set(requiredConfVar, nil) + expectedError := fmt.Errorf("%s not set", requiredConfVar) + config, err := NewConfig() + assert.Nil(suite.T(), config) + if assert.Error(suite.T(), err) { + assert.Equal(suite.T(), expectedError, err) + } + viper.Set(requiredConfVar, requiredConfVarValue) + } +} + +func (suite *TestSuite) TestAppConfig() { + + // Test fail on key read error + viper.Set("app.host", "test") + viper.Set("app.port", 1234) + viper.Set("app.tlscert", "test") + viper.Set("app.tlskey", "test") + viper.Set("app.archivePath", "/test") + viper.Set("app.logLevel", "debug") + + viper.Set("db.sslmode", "disable") + + c := &ConfigMap{} + err := c.appConfig() + assert.Error(suite.T(), err, "Error expected") + assert.Nil(suite.T(), c.App.Crypt4GHKey) + + // Generate a Crypt4GH private key, so that ConfigMap.appConfig() doesn't fail + generateKeyForTest(suite) + + c = &ConfigMap{} + err = c.appConfig() + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), "test", c.App.Host) + assert.Equal(suite.T(), 1234, c.App.Port) + assert.Equal(suite.T(), "test", c.App.TLSCert) + assert.Equal(suite.T(), "test", c.App.TLSKey) + assert.Equal(suite.T(), "/test", c.App.ArchivePath) + assert.Equal(suite.T(), "debug", c.App.LogLevel) + +} + +func (suite *TestSuite) TestSessionConfig() { + + viper.Set("session.expiration", 3600) + viper.Set("session.domain", "test") + viper.Set("session.secure", false) + viper.Set("session.httponly", false) + + viper.Set("db.sslmode", "disable") + + c := &ConfigMap{} + c.sessionConfig() + assert.Equal(suite.T(), time.Duration(3600*time.Second), c.Session.Expiration) + assert.Equal(suite.T(), "test", c.Session.Domain) + assert.Equal(suite.T(), false, c.Session.Secure) + assert.Equal(suite.T(), false, c.Session.HTTPOnly) + +} + +func (suite *TestSuite) TestDatabaseConfig() { + + // Test error on missing SSL vars + viper.Set("db.sslmode", "verify-full") + c := &ConfigMap{} + err := c.configDatabase() + assert.Error(suite.T(), err, "Error expected") + + // Test no error on SSL disabled + viper.Set("db.sslmode", "disable") + c = &ConfigMap{} + err = c.configDatabase() + assert.NoError(suite.T(), err) + + // Test pass on SSL vars set + viper.Set("db.host", "test") + viper.Set("db.port", 1234) + viper.Set("db.user", "test") + viper.Set("db.password", "test") + viper.Set("db.database", "test") + viper.Set("db.cacert", "test") + viper.Set("db.clientcert", "test") + viper.Set("db.clientkey", "test") + viper.Set("db.sslmode", "verify-full") + + c = &ConfigMap{} + err = c.configDatabase() + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), "test", c.DB.Host) + assert.Equal(suite.T(), 1234, c.DB.Port) + assert.Equal(suite.T(), "test", c.DB.User) + assert.Equal(suite.T(), "test", c.DB.Password) + assert.Equal(suite.T(), "test", c.DB.Database) + assert.Equal(suite.T(), "test", c.DB.CACert) + assert.Equal(suite.T(), "test", c.DB.ClientCert) + assert.Equal(suite.T(), "test", c.DB.ClientKey) + +} + +func generateKeyForTest(suite *TestSuite) { + // Generate a key, so that ConfigMap.appConfig() doesn't fail + _, privateKey, err := keys.GenerateKeyPair() + assert.NoError(suite.T(), err) + tempDir := suite.T().TempDir() + privateKeyFile, err := os.Create(fmt.Sprintf("%s/c4fg.key", tempDir)) + assert.NoError(suite.T(), err) + err = keys.WriteCrypt4GHX25519PrivateKey(privateKeyFile, privateKey, []byte("password")) + assert.NoError(suite.T(), err) + viper.Set("c4gh.filepath", fmt.Sprintf("%s/c4fg.key", tempDir)) + viper.Set("c4gh.passphrase", "password") +} diff --git a/internal/database/database.go b/internal/database/database.go index 2f44900..a2c599f 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -16,13 +16,6 @@ import ( // DB is exported for other packages var DB *SQLdb -// Database defines methods to be implemented by SQLdb -type Database interface { - GetHeader(fileID string) ([]byte, error) - GetFile(fileID string) ([]*FileInfo, error) - Close() -} - // SQLdb struct that acts as a receiver for the DB update methods type SQLdb struct { DB *sql.DB @@ -49,7 +42,7 @@ var dbRetryTimes = 3 var dbReconnectTimeout = 5 * time.Minute // dbReconnectSleep is how long to wait between attempts to connect to the database -var dbReconnectSleep = 5 * time.Second +var dbReconnectSleep = 1 * time.Second // sqlOpen is an internal variable to ease testing var sqlOpen = sql.Open @@ -121,7 +114,7 @@ func (dbs *SQLdb) checkAndReconnectIfNeeded() { } // GetFiles retrieves the file details -func (dbs *SQLdb) GetFiles(datasetID string) ([]*FileInfo, error) { +var GetFiles = func(datasetID string) ([]*FileInfo, error) { var ( r []*FileInfo = nil err error = nil @@ -129,7 +122,7 @@ func (dbs *SQLdb) GetFiles(datasetID string) ([]*FileInfo, error) { ) for count < dbRetryTimes { - r, err = dbs.getFiles(datasetID) + r, err = DB.getFiles(datasetID) if err != nil { count++ continue @@ -170,12 +163,15 @@ func (dbs *SQLdb) getFiles(datasetID string) ([]*FileInfo, error) { return nil, err } - // local_ega_ebi.file:file_size is actually the size of the archive file without header - // so we need to increase the encrypted file size by the length of the header if the user - // downloaded the files in encrypted format. I set it as 124 which seems to be the default - // length, but if files can have greater headers, then we can calculate the length with - // fd := GetFile() --> len(fd.Header) - fi.FileSize = fi.FileSize + 124 + // NOTE FOR ENCRYPTED DOWNLOAD + // As of now, encrypted download is not supported. When implementing encrypted download, note that + // local_ega_ebi.file:file_size is the size of the file body in the archive without the header, + // so the user needs to know the size of the header when downloading in encrypted format. + // A way to get this could be: + // fd := GetFile() + // fi.FileSize = fi.FileSize + len(fd.Header) + // But if the header is re-encrypted or a completely new header is generated, the length + // needs to be conveyd to the user in some other way. // Add structs to array files = append(files, fi) @@ -185,7 +181,7 @@ func (dbs *SQLdb) getFiles(datasetID string) ([]*FileInfo, error) { } // CheckDataset checks if dataset name exists -func (dbs *SQLdb) CheckDataset(dataset string) (bool, error) { +var CheckDataset = func(dataset string) (bool, error) { var ( r bool = false err error = nil @@ -193,7 +189,7 @@ func (dbs *SQLdb) CheckDataset(dataset string) (bool, error) { ) for count < dbRetryTimes { - r, err = dbs.checkDataset(dataset) + r, err = DB.checkDataset(dataset) if err != nil { count++ continue @@ -219,7 +215,7 @@ func (dbs *SQLdb) checkDataset(dataset string) (bool, error) { } // CheckFilePermission checks if user has permissions to access the dataset the file is a part of -func (dbs *SQLdb) CheckFilePermission(fileID string) (string, error) { +var CheckFilePermission = func(fileID string) (string, error) { var ( r string = "" err error = nil @@ -227,7 +223,7 @@ func (dbs *SQLdb) CheckFilePermission(fileID string) (string, error) { ) for count < dbRetryTimes { - r, err = dbs.checkFilePermission(fileID) + r, err = DB.checkFilePermission(fileID) if err != nil { count++ continue @@ -260,14 +256,14 @@ type FileDownload struct { } // GetFile retrieves the file header -func (dbs *SQLdb) GetFile(fileID string) (*FileDownload, error) { +var GetFile = func(fileID string) (*FileDownload, error) { var ( r *FileDownload = nil err error = nil count int = 0 ) for count < dbRetryTimes { - r, err = dbs.getFile(fileID) + r, err = DB.getFile(fileID) if err != nil { count++ continue diff --git a/internal/database/database_test.go b/internal/database/database_test.go new file mode 100644 index 0000000..9f36b26 --- /dev/null +++ b/internal/database/database_test.go @@ -0,0 +1,318 @@ +package database + +import ( + "bytes" + "database/sql" + "errors" + "fmt" + "log" + "os" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/neicnordic/sda-download/internal/config" + "github.com/stretchr/testify/assert" +) + +var testPgconf config.DatabaseConfig = config.DatabaseConfig{ + Host: "localhost", + Port: 42, + User: "user", + Password: "password", + Database: "database", + CACert: "cacert", + SslMode: "verify-full", + ClientCert: "clientcert", + ClientKey: "clientkey", +} + +const testConnInfo = "host=localhost port=42 user=user password=password dbname=database sslmode=verify-full sslrootcert=cacert sslcert=clientcert sslkey=clientkey" + +func TestMain(m *testing.M) { + // Set up our helper doing panic instead of os.exit + logFatalf = testLogFatalf + dbRetryTimes = 0 + dbReconnectTimeout = 200 * time.Millisecond + dbReconnectSleep = time.Millisecond + code := m.Run() + + os.Exit(code) +} + +func TestBuildConnInfo(t *testing.T) { + + s := buildConnInfo(testPgconf) + + assert.Equalf(t, s, testConnInfo, "Bad string for verify-full: '%s' while expecting '%s'", s, testConnInfo) + + noSslConf := testPgconf + noSslConf.SslMode = "disable" + + s = buildConnInfo(noSslConf) + + assert.Equalf(t, s, + "host=localhost port=42 user=user password=password dbname=database sslmode=disable", + "Bad string for disable: %s", s) + +} + +// testLogFatalf +func testLogFatalf(f string, args ...interface{}) { + s := fmt.Sprintf(f, args...) + panic(s) +} + +func TestCheckAndReconnect(t *testing.T) { + + db, mock, _ := sqlmock.New(sqlmock.MonitorPingsOption(true)) + + mock.ExpectPing().WillReturnError(fmt.Errorf("ping fail for testing bad conn")) + + err := CatchPanicCheckAndReconnect(SQLdb{db, ""}) + assert.Error(t, err, "Should have received error from checkAndReconnectOnNeeded fataling") + +} + +func CatchPanicCheckAndReconnect(db SQLdb) (err error) { + defer func() { + r := recover() + if r != nil { + err = fmt.Errorf("Caught panic") + } + }() + + db.checkAndReconnectIfNeeded() + + return nil +} + +func CatchNewDBPanic() (err error) { + // Recover if NewDB panics + // Allow both panic and error return here, so use a custom function rather + // than assert.Panics + + defer func() { + r := recover() + if r != nil { + err = fmt.Errorf("Caught panic") + } + }() + + _, err = NewDB(testPgconf) + + return err +} + +func TestNewDB(t *testing.T) { + + // Test failure first + + sqlOpen = func(x string, y string) (*sql.DB, error) { + return nil, errors.New("fail for testing") + } + + var buf bytes.Buffer + log.SetOutput(&buf) + + err := CatchNewDBPanic() + + if err == nil { + t.Errorf("NewDB did not report error when it should.") + } + + db, mock, _ := sqlmock.New(sqlmock.MonitorPingsOption(true)) + + sqlOpen = func(dbName string, connInfo string) (*sql.DB, error) { + if !assert.Equalf(t, dbName, "postgres", + "Unexpected database name '%s' while expecting 'postgres'", + dbName) { + return nil, fmt.Errorf("Unexpected dbName %s", dbName) + } + + if !assert.Equalf(t, connInfo, testConnInfo, + "Unexpected connection info '%s' while expecting '%s", + connInfo, + testConnInfo) { + return nil, fmt.Errorf("Unexpected connInfo %s", connInfo) + } + + return db, nil + } + + mock.ExpectPing().WillReturnError(fmt.Errorf("ping fail for testing")) + + err = CatchNewDBPanic() + + assert.NotNilf(t, err, "DB failed: %s", err) + + log.SetOutput(os.Stdout) + + assert.NotNil(t, err, "NewDB should fail when ping fails") + + if err = mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } + + mock.ExpectPing() + _, err = NewDB(testPgconf) + + assert.Nilf(t, err, "NewDB failed unexpectedly: %s", err) + + err = mock.ExpectationsWereMet() + assert.Nilf(t, err, "there were unfulfilled expectations: %s", err) + +} + +// Helper function for "simple" sql tests +func sqlTesterHelper(t *testing.T, f func(sqlmock.Sqlmock, *SQLdb) error) error { + db, mock, err := sqlmock.New() + + sqlOpen = func(_ string, _ string) (*sql.DB, error) { + return db, err + } + + testDb, err := NewDB(testPgconf) + + assert.Nil(t, err, "NewDB failed unexpectedly") + + returnErr := f(mock, testDb) + err = mock.ExpectationsWereMet() + + assert.Nilf(t, err, "there were unfulfilled expectations: %s", err) + + return returnErr +} + +func TestClose(t *testing.T) { + r := sqlTesterHelper(t, func(mock sqlmock.Sqlmock, testDb *SQLdb) error { + + mock.ExpectClose() + testDb.Close() + return nil + }) + + assert.Nil(t, r, "Close failed unexpectedly") +} + +func TestCheckFilePermission(t *testing.T) { + r := sqlTesterHelper(t, func(mock sqlmock.Sqlmock, testDb *SQLdb) error { + + expected := "dataset1" + query := "SELECT dataset_id FROM local_ega_ebi.file_dataset WHERE file_id = \\$1" + mock.ExpectQuery(query). + WithArgs("file1"). + WillReturnRows(sqlmock.NewRows([]string{"dataset_id"}).AddRow("dataset1")) + + x, err := testDb.checkFilePermission("file1") + + assert.Equal(t, expected, x, "did not get expected permission") + + return err + }) + + assert.Nil(t, r, "checkFilePermission failed unexpectedly") + + var buf bytes.Buffer + log.SetOutput(&buf) + + buf.Reset() + + log.SetOutput(os.Stdout) +} + +func TestCheckDataset(t *testing.T) { + r := sqlTesterHelper(t, func(mock sqlmock.Sqlmock, testDb *SQLdb) error { + + expected := true + query := "SELECT DISTINCT dataset_stable_id FROM local_ega_ebi.filedataset WHERE dataset_stable_id = \\$1" + mock.ExpectQuery(query). + WithArgs("dataset1"). + WillReturnRows(sqlmock.NewRows([]string{"dataset_stable_id"}).AddRow("dataset1")) + + x, err := testDb.checkDataset("dataset1") + + assert.Equal(t, expected, x, "did not get expected dataset value") + + return err + }) + + assert.Nil(t, r, "checkDataset failed unexpectedly") + + var buf bytes.Buffer + log.SetOutput(&buf) + + buf.Reset() + + log.SetOutput(os.Stdout) +} + +func TestGetFile(t *testing.T) { + r := sqlTesterHelper(t, func(mock sqlmock.Sqlmock, testDb *SQLdb) error { + + expected := &FileDownload{ + ArchivePath: "file.txt", + ArchiveSize: 32, + Header: []byte{171, 193, 35}, + } + query := "SELECT file_path, archive_file_size, header FROM local_ega_ebi.file WHERE file_id = \\$1" + mock.ExpectQuery(query). + WithArgs("file1"). + WillReturnRows(sqlmock.NewRows([]string{"file_path", "archive_file_size", "header"}).AddRow("file.txt", 32, "abc123")) + + x, err := testDb.getFile("file1") + assert.Equal(t, expected, x, "did not get expected file details") + + return err + }) + + assert.Nil(t, r, "getFile failed unexpectedly") + + var buf bytes.Buffer + log.SetOutput(&buf) + + buf.Reset() + + log.SetOutput(os.Stdout) +} + +func TestGetFiles(t *testing.T) { + r := sqlTesterHelper(t, func(mock sqlmock.Sqlmock, testDb *SQLdb) error { + + expected := []*FileInfo{} + fileInfo := &FileInfo{ + FileID: "file1", + DatasetID: "dataset1", + DisplayFileName: "file.txt", + FileName: "urn:file1", + FileSize: 60, + DecryptedFileSize: 32, + DecryptedFileChecksum: "hash", + DecryptedFileChecksumType: "sha256", + Status: "READY", + } + expected = append(expected, fileInfo) + query := "SELECT a.file_id, dataset_id, display_file_name, file_name, file_size, " + + "decrypted_file_size, decrypted_file_checksum, decrypted_file_checksum_type, file_status from " + + "local_ega_ebi.file a, local_ega_ebi.file_dataset b WHERE dataset_id = \\$1 AND a.file_id=b.file_id;" + mock.ExpectQuery(query). + WithArgs("dataset1"). + WillReturnRows(sqlmock.NewRows([]string{"file_id", "dataset_id", "display_file_name", + "file_name", "file_size", "decrypted_file_size", "decrypted_file_checksum", "decrypted_file_checksum_type", + "file_status"}).AddRow("file1", "dataset1", "file.txt", "urn:file1", 60, 32, "hash", "sha256", "READY")) + + x, err := testDb.getFiles("dataset1") + assert.Equal(t, expected, x, "did not get expected file details") + + return err + }) + + assert.Nil(t, r, "getFiles failed unexpectedly") + + var buf bytes.Buffer + log.SetOutput(&buf) + + buf.Reset() + + log.SetOutput(os.Stdout) +} diff --git a/internal/files/files.go b/internal/files/files.go deleted file mode 100644 index b9227c5..0000000 --- a/internal/files/files.go +++ /dev/null @@ -1,27 +0,0 @@ -package files - -import ( - "bytes" - "io" - "os" - - "github.com/elixir-oslo/crypt4gh/model/headers" - "github.com/elixir-oslo/crypt4gh/streaming" - "github.com/neicnordic/sda-download/internal/config" - log "github.com/sirupsen/logrus" -) - -// StreamFile returns a stream of file contents -func StreamFile(header []byte, file *os.File, coordinates *headers.DataEditListHeaderPacket) (*streaming.Crypt4GHReader, error) { - log.Debugf("preparing file %s for streaming", file.Name()) - // Stitch header and file body together - hr := bytes.NewReader(header) - mr := io.MultiReader(hr, file) - c4ghr, err := streaming.NewCrypt4GHReader(mr, *config.Config.App.Crypt4GHKey, coordinates) - if err != nil { - log.Errorf("failed to create Crypt4GH stream reader, %v", err) - return nil, err - } - log.Debugf("file stream for %s constructed", file.Name()) - return c4ghr, nil -} diff --git a/internal/logging/logging.go b/internal/logging/logging.go deleted file mode 100644 index e7328fc..0000000 --- a/internal/logging/logging.go +++ /dev/null @@ -1,44 +0,0 @@ -package logging - -import ( - "os" - - log "github.com/sirupsen/logrus" -) - -// determineLogLevel converts string representation of log level to log.Level -func determineLogLevel(level string) log.Level { - switch level { - case "error": - return log.ErrorLevel - case "fatal": - return log.FatalLevel - case "info": - return log.InfoLevel - case "panic": - return log.PanicLevel - case "warn": - return log.WarnLevel - case "trace": - return log.TraceLevel - case "debug": - return log.DebugLevel - default: - return log.DebugLevel - } -} - -// LoggingSetup configures logging format and rules -func LoggingSetup(logLevel string) { - // Log formatting - log.SetFormatter(&log.TextFormatter{ - DisableColors: true, - FullTimestamp: true, - }) - - // Output to stdout instead of the default stderr - log.SetOutput(os.Stdout) - log.Info(logLevel) - // Minimum message level - log.SetLevel(determineLogLevel(logLevel)) -} diff --git a/internal/session/session.go b/internal/session/session.go index 29a6700..4029227 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -41,7 +41,7 @@ func InitialiseSessionCache() (*ristretto.Cache, error) { } // Get returns a value from cache at key -func Get(key string) ([]string, bool) { +var Get = func(key string) ([]string, bool) { log.Debug("get value from cache") header, exists := SessionCache.Get(key) var cachedDatasets []string @@ -65,7 +65,7 @@ func Set(key string, datasets []string) { // NewSessionKey generates a session key used for storing // dataset permissions, and checks that it doesn't already exist -func NewSessionKey() string { +var NewSessionKey = func() string { log.Debug("generating new session key") // Generate a new key until one is generated, which doesn't already exist diff --git a/internal/session/session_test.go b/internal/session/session_test.go new file mode 100644 index 0000000..2ec6a1e --- /dev/null +++ b/internal/session/session_test.go @@ -0,0 +1,80 @@ +package session + +import ( + "strings" + "testing" + "time" + + "github.com/neicnordic/sda-download/internal/config" +) + +func TestNewSessionKey(t *testing.T) { + + // Initialise a cache for testing + cache, _ := InitialiseSessionCache() + SessionCache = cache + + // This should generate an UUID4 and verify, that it doesn't already exist in the cache + // Key verification can't be tested, because it would result in an infinite loop + key := NewSessionKey() + + // UUID4 is 36 characters long + expectedLen := 36 + + if len(key) != expectedLen { + t.Errorf("TestNewSessionKey failed, expected key length %d but received %d", expectedLen, len(key)) + } + +} + +func TestGetSetCache_Found(t *testing.T) { + + // Set expiration time + config.Config.Session.Expiration = time.Duration(60 * time.Second) + + // Initialise a cache for testing + cache, _ := InitialiseSessionCache() + SessionCache = cache + + Set("key1", []string{"dataset1", "dataset2"}) + time.Sleep(time.Duration(100 * time.Millisecond)) // need to give cache time to get ready + datasets, exists := Get("key1") + + // Expected results + expectedDatasets := []string{"dataset1", "dataset2"} + expectedExists := true + + if strings.Join(datasets, "") != strings.Join(expectedDatasets, "") { + t.Errorf("TestGetSetCache_Found failed, expected %s but received %s", expectedDatasets, datasets) + } + if expectedExists != exists { + t.Errorf("TestGetSetCache_Found failed, expected %t but received %t", expectedExists, exists) + } + +} + +func TestGetSetCache_NotFound(t *testing.T) { + + // Set expiration time + config.Config.Session.Expiration = time.Duration(60 * time.Second) + + // Initialise a cache for testing + cache, _ := InitialiseSessionCache() + SessionCache = cache + + Set("key1", []string{"dataset1", "dataset2"}) + time.Sleep(time.Duration(100 * time.Millisecond)) // need to give cache time to get ready + datasets, exists := Get("key2") + + // Expected results + expectedDatasets := []string{} + expectedExists := false + + if strings.Join(datasets, "") != strings.Join(expectedDatasets, "") { + t.Errorf("TestGetSetCache_NotFound failed, expected %s but received %s", expectedDatasets, datasets) + } + if expectedExists != exists { + t.Errorf("TestGetSetCache_NotFound failed, expected %t but received %t", expectedExists, exists) + } + +} diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 8dddf9a..862e229 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -73,7 +73,7 @@ func VerifyJWT(o OIDCDetails, token string) (jwt.Token, error) { } // GetToken parses the token string from header -func GetToken(header string) (string, int, error) { +var GetToken = func(header string) (string, int, error) { log.Debug("parsing access token from header") if len(header) == 0 { log.Debug("authorization check failed") @@ -115,7 +115,7 @@ type Visa struct { } // GetVisas requests the list of visas from userinfo endpoint -func GetVisas(o OIDCDetails, token string) (*Visas, error) { +var GetVisas = func(o OIDCDetails, token string) (*Visas, error) { log.Debugf("requesting visas from %s", o.Userinfo) // Set headers headers := map[string]string{} @@ -138,7 +138,7 @@ func GetVisas(o OIDCDetails, token string) (*Visas, error) { } // GetPermissions parses visas and finds matching dataset names from the database, returning a list of matches -func GetPermissions(visas Visas) ([]string, error) { +var GetPermissions = func(visas Visas) []string { log.Debug("parsing permissions from visas") var datasets []string @@ -208,33 +208,30 @@ func GetPermissions(visas Visas) ([]string, error) { log.Errorf("failed to parse visa claim JSON into struct, %s, %s", err, visaClaimJSON) continue } - datasetFull := visa.Dataset - datasetParts := strings.Split(datasetFull, "://") - datasetName := datasetParts[len(datasetParts)-1] - exists, err := database.DB.CheckDataset(datasetFull) + exists, err := database.CheckDataset(visa.Dataset) if err != nil { - log.Debugf("visa contained dataset %s which doesn't exist in this instance, skip", datasetName) + log.Debugf("visa contained dataset %s which doesn't exist in this instance, skip", visa.Dataset) continue } if exists { - log.Debugf("checking dataset list for duplicates of %s", datasetName) + log.Debugf("checking dataset list for duplicates of %s", visa.Dataset) // check that dataset name doesn't already exist in return list, // we can get duplicates when using multiple AAIs duplicate := false for i := range datasets { - if datasets[i] == datasetName { + if datasets[i] == visa.Dataset { duplicate = true - log.Debugf("found a duplicate: dataset %s is already found, skip", datasetName) + log.Debugf("found a duplicate: dataset %s is already found, skip", visa.Dataset) continue } } if !duplicate { - log.Debugf("no duplicates of dataset: %s, add dataset to list of permissions", datasetName) - datasets = append(datasets, datasetName) + log.Debugf("no duplicates of dataset: %s, add dataset to list of permissions", visa.Dataset) + datasets = append(datasets, visa.Dataset) } } } log.Debugf("matched datasets, %s", datasets) - return datasets, nil + return datasets } diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go new file mode 100644 index 0000000..16f4cca --- /dev/null +++ b/pkg/auth/auth_test.go @@ -0,0 +1,309 @@ +package auth + +import ( + "bytes" + "errors" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/neicnordic/sda-download/pkg/request" +) + +func TestGetOIDCDetails_Fail_MakeRequest(t *testing.T) { + + // Save original to-be-mocked functions + originalMakeRequest := request.MakeRequest + + // Substitute mock functions + request.MakeRequest = func(method, url string, headers map[string]string, body []byte) (*http.Response, error) { + return nil, errors.New("error") + } + + // Run test + oidcDetails, err := GetOIDCDetails("https://testing.fi") + + // Expected results + expectedUserInfo := "" + expectedJWK := "" + expectedError := "error" + + if oidcDetails.Userinfo != expectedUserInfo { + t.Errorf("TestGetOIDCDetails_Fail_MakeRequest failed, expected %s, got %s", expectedUserInfo, oidcDetails.Userinfo) + } + if oidcDetails.JWK != expectedJWK { + t.Errorf("TestGetOIDCDetails_Fail_MakeRequest failed, expected %s, got %s", expectedJWK, oidcDetails.JWK) + } + if err.Error() != expectedError { + t.Errorf("TestGetOIDCDetails_Fail_MakeRequest failed, expected %s received %s", expectedError, err.Error()) + } + + // Return mock functions to originals + request.MakeRequest = originalMakeRequest +} + +func TestGetOIDCDetails_Fail_JSONDecode(t *testing.T) { + + // Save original to-be-mocked functions + originalMakeRequest := request.MakeRequest + + // Substitute mock functions + request.MakeRequest = func(method, url string, headers map[string]string, body []byte) (*http.Response, error) { + response := &http.Response{ + StatusCode: 200, + // Response body + Body: ioutil.NopCloser(bytes.NewBufferString(``)), + // Response headers + Header: make(http.Header), + } + return response, nil + } + + // Run test + oidcDetails, err := GetOIDCDetails("https://testing.fi") + + // Expected results + expectedUserInfo := "" + expectedJWK := "" + expectedError := "EOF" + + if oidcDetails.Userinfo != expectedUserInfo { + t.Errorf("TestGetOIDCDetails_Fail_JSONDecode failed, expected %s, got %s", expectedUserInfo, oidcDetails.Userinfo) + } + if oidcDetails.JWK != expectedJWK { + t.Errorf("TestGetOIDCDetails_Fail_JSONDecode failed, expected %s, got %s", expectedJWK, oidcDetails.JWK) + } + if err.Error() != expectedError { + t.Errorf("TestGetOIDCDetails_Fail_JSONDecode failed, expected %s received %s", expectedError, err.Error()) + } + + // Return mock functions to originals + request.MakeRequest = originalMakeRequest +} + +func TestGetOIDCDetails_Success(t *testing.T) { + + // Save original to-be-mocked functions + originalMakeRequest := request.MakeRequest + + // Substitute mock functions + request.MakeRequest = func(method, url string, headers map[string]string, body []byte) (*http.Response, error) { + response := &http.Response{ + StatusCode: 200, + // Response body + Body: ioutil.NopCloser(bytes.NewBufferString(`{"userinfo_endpoint":"https://aai.org/oidc/userinfo","jwks_uri":"https://aai.org/oidc/jwks"}`)), + // Response headers + Header: make(http.Header), + } + return response, nil + } + + // Run test + oidcDetails, err := GetOIDCDetails("https://testing.fi") + + // Expected results + expectedUserInfo := "https://aai.org/oidc/userinfo" + expectedJWK := "https://aai.org/oidc/jwks" + + if oidcDetails.Userinfo != expectedUserInfo { + t.Errorf("TestGetOIDCDetails_Fail_JSONDecode failed, expected %s, got %s", expectedUserInfo, oidcDetails.Userinfo) + } + if oidcDetails.JWK != expectedJWK { + t.Errorf("TestGetOIDCDetails_Fail_JSONDecode failed, expected %s, got %s", expectedJWK, oidcDetails.JWK) + } + if err != nil { + t.Errorf("TestGetOIDCDetails_Fail_JSONDecode failed, expected nil received %v", err) + } + + // Return mock functions to originals + request.MakeRequest = originalMakeRequest +} + +func TestGetToken_Fail_EmptyHeader(t *testing.T) { + + // Test case + token, code, err := GetToken("") + + // Expected results + expectedToken := "" + expectedCode := 401 + expectedError := "access token must be provided" + + if token != expectedToken { + t.Errorf("TestGetToken_Fail_EmptyHeader failed, expected %s, received %s", expectedToken, token) + } + if code != expectedCode { + t.Errorf("TestGetToken_Fail_EmptyHeader failed, expected %d, received %d", expectedCode, code) + } + if err.Error() != expectedError { + t.Errorf("TestGetToken_Fail_EmptyHeader failed, expected %s, received %s", expectedError, err.Error()) + } + +} + +func TestGetToken_Fail_WrongScheme(t *testing.T) { + + // Test case + token, code, err := GetToken("Basic token") + + // Expected results + expectedToken := "" + expectedCode := 400 + expectedError := "authorization scheme must be bearer" + + if token != expectedToken { + t.Errorf("TestGetToken_Fail_EmptyHeader failed, expected %s, received %s", expectedToken, token) + } + if code != expectedCode { + t.Errorf("TestGetToken_Fail_EmptyHeader failed, expected %d, received %d", expectedCode, code) + } + if err.Error() != expectedError { + t.Errorf("TestGetToken_Fail_EmptyHeader failed, expected %s, received %s", expectedError, err.Error()) + } + +} + +func TestGetToken_Fail_MissingToken(t *testing.T) { + + // Test case + token, code, err := GetToken("Bearer") + + // Expected results + expectedToken := "" + expectedCode := 400 + expectedError := "token string is missing from authorization header" + + if token != expectedToken { + t.Errorf("TestGetToken_Fail_EmptyHeader failed, expected %s, received %s", expectedToken, token) + } + if code != expectedCode { + t.Errorf("TestGetToken_Fail_EmptyHeader failed, expected %d, received %d", expectedCode, code) + } + if err.Error() != expectedError { + t.Errorf("TestGetToken_Fail_EmptyHeader failed, expected %s, received %s", expectedError, err.Error()) + } + +} + +func TestGetToken_Success(t *testing.T) { + + // Test case + token, code, err := GetToken("Bearer token") + + // Expected results + expectedToken := "token" + expectedCode := 0 + + if token != expectedToken { + t.Errorf("TestGetToken_Fail_EmptyHeader failed, expected %s, received %s", expectedToken, token) + } + if code != expectedCode { + t.Errorf("TestGetToken_Fail_EmptyHeader failed, expected %d, received %d", expectedCode, code) + } + if err != nil { + t.Errorf("TestGetToken_Fail_EmptyHeader failed, expected nil, received %v", err) + } + +} + +func TestGetVisas_Fail_MakeRequest(t *testing.T) { + + // Save original to-be-mocked functions + originalMakeRequest := request.MakeRequest + + // Substitute mock functions + request.MakeRequest = func(method, url string, headers map[string]string, body []byte) (*http.Response, error) { + return nil, errors.New("error") + } + + // Run test + oidcDetails := OIDCDetails{} + visas, err := GetVisas(oidcDetails, "token") + + // Expected results + expectedError := "error" + + if visas != nil { + t.Errorf("TestGetVisas_Fail_MakeRequest failed, expected nil, got %v", visas) + } + if err.Error() != expectedError { + t.Errorf("TestGetVisas_Fail_MakeRequest failed, expected %s received %s", expectedError, err.Error()) + } + + // Return mock functions to originals + request.MakeRequest = originalMakeRequest + +} + +func TestGetVisas_Fail_JSONDecode(t *testing.T) { + + // Save original to-be-mocked functions + originalMakeRequest := request.MakeRequest + + // Substitute mock functions + request.MakeRequest = func(method, url string, headers map[string]string, body []byte) (*http.Response, error) { + response := &http.Response{ + StatusCode: 200, + // Response body + Body: ioutil.NopCloser(bytes.NewBufferString(``)), + // Response headers + Header: make(http.Header), + } + return response, nil + } + + // Run test + oidcDetails := OIDCDetails{} + visas, err := GetVisas(oidcDetails, "token") + + // Expected results + expectedError := "EOF" + + if visas != nil { + t.Errorf("TestGetVisas_Fail_MakeRequest failed, expected nil, got %v", visas) + } + if err.Error() != expectedError { + t.Errorf("TestGetVisas_Fail_MakeRequest failed, expected %s received %s", expectedError, err.Error()) + } + + // Return mock functions to originals + request.MakeRequest = originalMakeRequest + +} + +func TestGetVisas_Success(t *testing.T) { + + // Save original to-be-mocked functions + originalMakeRequest := request.MakeRequest + + // Substitute mock functions + request.MakeRequest = func(method, url string, headers map[string]string, body []byte) (*http.Response, error) { + response := &http.Response{ + StatusCode: 200, + // Response body + Body: ioutil.NopCloser(bytes.NewBufferString(`{"ga4gh_passport_v1":["visa1","visa2"]}`)), + // Response headers + Header: make(http.Header), + } + return response, nil + } + + // Run test + oidcDetails := OIDCDetails{} + visas, err := GetVisas(oidcDetails, "token") + + // Expected results + expectedVisas := []string{"visa1", "visa2"} + + if strings.Join(visas.Visa, "") != strings.Join(expectedVisas, "") { + t.Errorf("TestGetVisas_Success failed, expected %v, got %v", expectedVisas, visas) + } + if err != nil { + t.Errorf("TestGetVisas_Success failed, expected nil received %v", err) + } + + // Return mock functions to originals + request.MakeRequest = originalMakeRequest + +} diff --git a/pkg/request/request_test.go b/pkg/request/request_test.go new file mode 100644 index 0000000..c3c198e --- /dev/null +++ b/pkg/request/request_test.go @@ -0,0 +1,170 @@ +package request + +import ( + "bytes" + "errors" + "io" + "io/ioutil" + "net/http" + "net/url" + "reflect" + "testing" +) + +// Mock client code below from https://hassansin.github.io/Unit-Testing-http-client-in-Go + +// RoundTripFunc +type RoundTripFunc func(req *http.Request) *http.Response + +// RoundTrip +func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req), nil +} + +// NewTestClient returns *http.Client with Transport replaced to avoid making real calls +func newTestClient(fn RoundTripFunc) *http.Client { + return &http.Client{ + Transport: RoundTripFunc(fn), + } +} + +func TestInitialiseClient(t *testing.T) { + // Initialise HTTP client + client, err := InitialiseClient() + if err != nil { + t.Fatalf("http client creation failed %s", err) + } + + // Verify that the correct type of object was created + if reflect.TypeOf(client).String() != "*http.Client" { + t.Errorf("http client creation failed, wanted *http.Client, received %s", reflect.TypeOf(client)) + } +} + +func TestMakeRequest_Fail_HTTPNewRequest(t *testing.T) { + + // Save original to-be-mocked functions + originalHTTPMakeRequest := HTTPNewRequest + + // Substitute mock functions + HTTPNewRequest = func(method, url string, body io.Reader) (*http.Request, error) { + return nil, errors.New("failed to build http request") + } + + // Run test + response, err := MakeRequest("GET", "https://testing.fi", nil, nil) + // defer response.Body.Close() + + // Expected results + expectedError := "failed to build http request" + + if response != nil { + _, _ = io.Copy(io.Discard, response.Body) + defer response.Body.Close() + t.Error("TestMakeRequest_Fail_HTTPNewRequest failed, expected nil") + } + if err.Error() != expectedError { + t.Errorf("TestMakeRequest_Fail_HTTPNewRequest failed, expected %s received %s", expectedError, err.Error()) + } + + // Return mock functions to originals + HTTPNewRequest = originalHTTPMakeRequest + +} + +func TestMakeRequest_Fail_StatusCode(t *testing.T) { + + // Create mock client + client := newTestClient(func(req *http.Request) *http.Response { + return &http.Response{ + StatusCode: 500, + // Response body + Body: ioutil.NopCloser(bytes.NewBufferString(`error`)), + // Response headers + Header: make(http.Header), + } + }) + Client = client + + // Save original to-be-mocked functions + originalHTTPMakeRequest := HTTPNewRequest + + // Substitute mock functions + HTTPNewRequest = func(method, requestUrl string, body io.Reader) (*http.Request, error) { + u, _ := url.Parse("https://testing.fi") + r := &http.Request{ + Method: "GET", + URL: u, + } + return r, nil + } + + // Run test + response, err := MakeRequest("GET", "https://testing.fi", nil, nil) + + // Expected results + expectedError := "500" + + if response != nil { + _, _ = io.Copy(io.Discard, response.Body) + defer response.Body.Close() + t.Error("TestMakeRequest_Fail_StatusCode failed, expected nil") + } + if err.Error() != expectedError { + t.Errorf("TestMakeRequest_Fail_StatusCode failed, expected %s received %s", expectedError, err.Error()) + } + + // Return mock functions to originals + HTTPNewRequest = originalHTTPMakeRequest + +} + +func TestMakeRequest_Success(t *testing.T) { + + // Create mock client + client := newTestClient(func(req *http.Request) *http.Response { + return &http.Response{ + StatusCode: 200, + // Response body + Body: ioutil.NopCloser(bytes.NewBufferString(`hello`)), + // Response headers + Header: make(http.Header), + } + }) + Client = client + + // Save original to-be-mocked functions + originalHTTPMakeRequest := HTTPNewRequest + + // Substitute mock functions + HTTPNewRequest = func(method, requestUrl string, body io.Reader) (*http.Request, error) { + u, _ := url.Parse("https://testing.fi") + r := &http.Request{ + Method: "GET", + URL: u, + } + return r, nil + } + + // Run test + response, err := MakeRequest("GET", "https://testing.fi", nil, nil) + body, _ := io.ReadAll(response.Body) + defer response.Body.Close() + + // Expected results + expectedBody := "hello" + + if !bytes.Equal(body, []byte(expectedBody)) { + // visual byte comparison in terminal (easier to find string differences) + t.Error(body) + t.Error([]byte(expectedBody)) + t.Errorf("TestMakeRequest_Success failed, got %s expected %s", string(body), string(expectedBody)) + } + if err != nil { + t.Errorf("TestMakeRequest_Success failed, expected nil received %v", err) + } + + // Return mock functions to originals + HTTPNewRequest = originalHTTPMakeRequest + +}