diff --git a/internal/api/upload.go b/internal/api/upload.go index be35f38..1b61d1a 100644 --- a/internal/api/upload.go +++ b/internal/api/upload.go @@ -4,11 +4,11 @@ import ( "bytes" "encoding/json" "github.com/opg-sirius-finance-admin/internal/model" + "io" "net/http" - "os" ) -func (c *Client) Upload(ctx Context, reportUploadType string, uploadDate string, email string, file *os.File) error { +func (c *Client) Upload(ctx Context, reportUploadType string, uploadDate string, email string, file io.Reader) error { var body bytes.Buffer var uploadDateTransformed *model.Date var req *http.Request @@ -18,24 +18,19 @@ func (c *Client) Upload(ctx Context, reportUploadType string, uploadDate string, uploadDateTransformed = &uploadDateFormatted } - err := json.NewEncoder(&body).Encode(model.Upload{ + fileTransformed, err := io.ReadAll(file) + + err = json.NewEncoder(&body).Encode(model.Upload{ ReportUploadType: reportUploadType, UploadDate: uploadDateTransformed, Email: email, - File: file, + File: fileTransformed, }) if err != nil { return err } - switch reportUploadType { - case "DebtChase": - req, err = c.newSiriusRequest(ctx, http.MethodPost, "/finance/reports/upload-fee-chase", &body) - case "DeputySchedule": - req, err = c.newSiriusRequest(ctx, http.MethodPost, "/finance/reports/upload-deputy-billing-schedule", &body) - default: - req, err = c.newBackendRequest(ctx, http.MethodPost, "/uploads", &body) - } + req, err = c.newBackendRequest(ctx, http.MethodPost, "/uploads", &body) if err != nil { return err diff --git a/internal/api/upload_test.go b/internal/api/upload_test.go index 211e001..c4b44b8 100644 --- a/internal/api/upload_test.go +++ b/internal/api/upload_test.go @@ -8,7 +8,7 @@ import ( "net/http" "net/http/httptest" "net/url" - "os" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -18,49 +18,22 @@ func TestUploadUrlSwitching(t *testing.T) { mockClient := &MockClient{} client, _ := NewClient(mockClient, "http://localhost:3000", "") - tempFile, err := os.CreateTemp("", "testfile") - if err != nil { - t.Fatal(err) - } - defer os.Remove(tempFile.Name()) // Clean up after the test - defer tempFile.Close() - // Write some content to the temp file - content := []byte("fake file content") - if _, err := tempFile.Write(content); err != nil { - t.Fatal(err) - } - // Reset the file pointer to the beginning - if _, err := tempFile.Seek(0, io.SeekStart); err != nil { - t.Fatal(err) - } + content := strings.NewReader("something here") - // Define test cases - testCases := map[string]string{ - "DebtChase": "http://localhost:3000/supervision-api/v1/finance/reports/upload-fee-chase", - "DeputySchedule": "http://localhost:3000/supervision-api/v1/finance/reports/upload-deputy-billing-schedule", - "OtherType": "/uploads", - } + var capturedURL *url.URL - for reportUploadType, expectedURL := range testCases { - t.Run(reportUploadType, func(t *testing.T) { - // Variable to capture the request URL - var capturedURL *url.URL - - // Mock the HTTP client's Do function to capture the request URL - GetDoFunc = func(req *http.Request) (*http.Response, error) { - capturedURL = req.URL - return &http.Response{ - StatusCode: http.StatusCreated, - Body: io.NopCloser(bytes.NewReader([]byte{})), - }, nil - } - - err := client.Upload(getContext(nil), reportUploadType, "", "", tempFile) - assert.NoError(t, err) - assert.Equal(t, expectedURL, capturedURL.String()) - }) + GetDoFunc = func(req *http.Request) (*http.Response, error) { + capturedURL = req.URL + return &http.Response{ + StatusCode: http.StatusCreated, + Body: io.NopCloser(bytes.NewReader([]byte{})), + }, nil } + + err := client.Upload(getContext(nil), "", "", "", content) + assert.NoError(t, err) + assert.Equal(t, "/uploads", capturedURL.String()) } func TestSubmitUploadUnauthorised(t *testing.T) { diff --git a/internal/model/upload.go b/internal/model/upload.go index f7f9b71..b228e67 100644 --- a/internal/model/upload.go +++ b/internal/model/upload.go @@ -1,10 +1,8 @@ package model -import "os" - type Upload struct { - ReportUploadType string `json:"reportUploadType"` - UploadDate *Date `json:"uploadDate"` - Email string `json:"email"` - File *os.File `json:"file"` + ReportUploadType string `json:"reportUploadType"` + UploadDate *Date `json:"uploadDate"` + Email string `json:"email"` + File []byte `json:"file"` } diff --git a/internal/server/server.go b/internal/server/server.go index e293219..1ccc3e5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -11,12 +11,11 @@ import ( "log/slog" "net/http" "net/url" - "os" ) type ApiClient interface { Download(api.Context, model.Download) error - Upload(api.Context, string, string, string, *os.File) error + Upload(api.Context, string, string, string, io.Reader) error } type router interface { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index a3bf099..3e03b46 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -5,7 +5,6 @@ import ( "github.com/opg-sirius-finance-admin/internal/model" "io" "net/http" - "os" ) type mockTemplate struct { @@ -53,7 +52,7 @@ type mockApiClient struct { error error //nolint:golint,unused } -func (m mockApiClient) Upload(context api.Context, s string, s2 string, s3 string, file *os.File) error { +func (m mockApiClient) Upload(context api.Context, s string, s2 string, s3 string, file io.Reader) error { return m.error } diff --git a/internal/server/upload.go b/internal/server/upload.go index 82be0e9..278a4d0 100644 --- a/internal/server/upload.go +++ b/internal/server/upload.go @@ -3,12 +3,10 @@ package server import ( "encoding/csv" "errors" - "fmt" "github.com/opg-sirius-finance-admin/internal/api" "github.com/opg-sirius-finance-admin/internal/model" "io" "net/http" - "os" "reflect" "strings" "unicode" @@ -21,55 +19,20 @@ type UploadHandler struct { func (h *UploadHandler) render(v AppVars, w http.ResponseWriter, r *http.Request) error { ctx := getContext(r) - //Create and defer cleanup of a temp directory - dir := "./uploads" - if err := os.MkdirAll(dir, os.ModePerm); err != nil { - http.Error(w, "Failed to create directory", http.StatusInternalServerError) - return err - } - defer func() { - if err := os.RemoveAll(dir); err != nil { - fmt.Println("Could not remove directory:", err) - } - }() - - // Extract form data reportUploadType := r.PostFormValue("reportUploadType") uploadDate := r.PostFormValue("uploadDate") email := r.PostFormValue("email") // Handle file upload - file, headers, err := r.FormFile("fileUpload") + file, _, err := r.FormFile("fileUpload") if err != nil { return h.handleError(w, r, "No file uploaded", http.StatusBadRequest) } defer file.Close() - // Define the destination file path - dst, err := os.Create(headers.Filename) - if err != nil { - return h.handleError(w, r, "Failed to create file", http.StatusInternalServerError) - } - defer dst.Close() - - // Copy the uploaded file to the destination - if _, err := io.Copy(dst, file); err != nil { - return h.handleError(w, r, "Failed to save file", http.StatusInternalServerError) - } - - // Reopen the file to read the headers - if err := dst.Close(); err != nil { - return h.handleError(w, r, "Failed to close file", http.StatusInternalServerError) - } - uploadedFile, err := os.Open(headers.Filename) - if err != nil { - return h.handleError(w, r, "Failed to reopen file", http.StatusInternalServerError) - } - defer uploadedFile.Close() - + csvReader := csv.NewReader(file) expectedHeaders := reportHeadersByType(reportUploadType) - csvReader := csv.NewReader(uploadedFile) readHeaders, err := csvReader.Read() if err != nil { return h.handleError(w, r, "Failed to read CSV headers", http.StatusBadRequest) @@ -84,8 +47,13 @@ func (h *UploadHandler) render(v AppVars, w http.ResponseWriter, r *http.Request return h.handleError(w, r, "CSV headers do not match for the file trying to be uploaded", http.StatusBadRequest) } + _, err = file.Seek(0, io.SeekStart) + if err != nil { + return err + } + // Upload the file - if err := h.Client().Upload(ctx, reportUploadType, uploadDate, email, uploadedFile); err != nil { + if err := h.Client().Upload(ctx, reportUploadType, uploadDate, email, file); err != nil { return h.handleUploadError(w, r, err) }