Skip to content

Commit

Permalink
Feat 51: Enable xlsx file upload (#62)
Browse files Browse the repository at this point in the history
* add input file type validation on UI

* enable xlsx file upload, add server-side convert to csv

* adjust unit tests to also run on windows

* add unit tests for validate_csv for xlsx upload

* revert AttributeError msg comparison in unit test for linux
  • Loading branch information
helen-m-lin authored Jan 25, 2024
1 parent b3f6a56 commit 7028075
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 26 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ server = [
'starlette_wtf',
'uvicorn[standard]',
'wtforms',
'requests==2.25.0'
'requests==2.25.0',
'openpyxl'
]

[tool.setuptools.packages.find]
Expand Down
38 changes: 25 additions & 13 deletions src/aind_data_transfer_service/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from asyncio import sleep
from pathlib import Path

import openpyxl
from fastapi import Request
from fastapi.responses import JSONResponse
from fastapi.templating import Jinja2Templates
Expand Down Expand Up @@ -45,22 +46,33 @@


async def validate_csv(request: Request):
"""Validate a csv file. Return parsed contents as json."""
"""Validate a csv or xlsx file. Return parsed contents as json."""
async with request.form() as form:
content = await form["file"].read()
# A few csv files created from excel have extra unicode byte chars.
# Adding "utf-8-sig" should remove them.
data = content.decode("utf-8-sig")
csv_reader = csv.DictReader(io.StringIO(data))
basic_jobs = []
errors = []
for row in csv_reader:
try:
job = BasicUploadJobConfigs.from_csv_row(row=row)
# Construct hpc job setting most of the vars from the env
basic_jobs.append(job.json())
except Exception as e:
errors.append(repr(e))
if not form["file"].filename.endswith((".csv", ".xlsx")):
errors.append("Invalid input file type")
else:
content = await form["file"].read()
if form["file"].filename.endswith(".csv"):
# A few csv files created from excel have extra unicode
# byte chars. Adding "utf-8-sig" should remove them.
data = content.decode("utf-8-sig")
else:
xlsx_sheet = openpyxl.load_workbook(io.BytesIO(content)).active
csv_io = io.StringIO()
csv_writer = csv.writer(csv_io)
for r in xlsx_sheet.rows:
csv_writer.writerow([cell.value for cell in r])
data = csv_io.getvalue()
csv_reader = csv.DictReader(io.StringIO(data))
for row in csv_reader:
try:
job = BasicUploadJobConfigs.from_csv_row(row=row)
# Construct hpc job setting most of the vars from the env
basic_jobs.append(job.json())
except Exception as e:
errors.append(repr(e))
message = "There were errors" if len(errors) > 0 else "Valid Data"
status_code = 406 if len(errors) > 0 else 200
content = {
Expand Down
12 changes: 10 additions & 2 deletions src/aind_data_transfer_service/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ <h2>Submit Jobs</h2>
</fieldset>
</div><br><br>
<form id="preview_form" method="post" enctype="multipart/form-data">
<label for="file">Please select a CSV file:</label>
<input type="file" id="file" name="file"><br><br>
<label for="file">Please select a .csv or .xlsx file:</label>
<input type="file" id="file" name="file" accept=".csv,.xlsx"><br><br>
<input type="submit" id="preview" value="preview"><br><br>
</form>
<button type="button" onclick="submitJobs()">Submit</button>
Expand All @@ -63,6 +63,14 @@ <h2>Submit Jobs</h2>
$(function() {
$("#preview_form").on("submit", function(e) {
e.preventDefault();
if ($("#file").prop("files").length != 1) {
alert("No file selected. Please attach a .csv or .xlsx file.");
return;
}
if (![".csv", ".xlsx"].some(ext => $("#file").prop("files")[0].name.endsWith(ext))) {
alert("Invalid file type. Please attach a .csv or .xlsx file.");
return;
}
var formData = new FormData(this);
$.ajax({
url: "/api/validate_csv",
Expand Down
Binary file added tests/resources/sample.xlsx
Binary file not shown.
4 changes: 4 additions & 0 deletions tests/resources/sample_invalid_ext.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
modality0, modality0.source, modality1, modality1.source, s3-bucket, subject-id, platform, acq-datetime
ECEPHYS, dir/data_set_1, ,, some_bucket, 123454, ecephys, 2020-10-10 14:10:10
BEHAVIOR_VIDEOS, dir/data_set_2, MRI, dir/data_set_3, some_bucket2, 123456, BEHAVIOR, 10/13/2020 1:10:10 PM
BEHAVIOR_VIDEOS, dir/data_set_2, BEHAVIOR_VIDEOS, dir/data_set_3, some_bucket2, 123456, BEHAVIOR, 10/13/2020 1:10:10 PM
Binary file added tests/resources/sample_malformed.xlsx
Binary file not shown.
3 changes: 2 additions & 1 deletion tests/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ def test_from_job_and_server_configs(self):
'[{"modality": {"name": '
'"Extracellular electrophysiology", '
'"abbreviation": "ecephys"}, '
'"source": "dir/data_set_1", "compress_raw_data": true, '
f'"source": "{repr(str(Path("dir/data_set_1")))[1:-1]}", '
'"compress_raw_data": true, '
'"extra_configs": null,'
' "skip_staging": false}],'
' "subject_id": "123454",'
Expand Down
4 changes: 2 additions & 2 deletions tests/test_hpc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ def test_from_upload_job_configs(self):
"ecephys_123454_2020-10-10_14-10-10", hpc_settings.name
)
self.assertEqual(
"dir/logs/ecephys_123454_2020-10-10_14-10-10_error.out",
str(Path("dir/logs/ecephys_123454_2020-10-10_14-10-10_error.out")),
hpc_settings.standard_error,
)
self.assertEqual(
"dir/logs/ecephys_123454_2020-10-10_14-10-10.out",
str(Path("dir/logs/ecephys_123454_2020-10-10_14-10-10.out")),
hpc_settings.standard_out,
)
self.assertEqual(180, hpc_settings.time_limit)
Expand Down
64 changes: 57 additions & 7 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
from tests.test_configs import TestJobConfigs

TEST_DIRECTORY = Path(os.path.dirname(os.path.realpath(__file__)))
SAMPLE_FILE = TEST_DIRECTORY / "resources" / "sample.csv"
MALFORMED_SAMPLE_FILE = TEST_DIRECTORY / "resources" / "sample_malformed.csv"
SAMPLE_INVALID_EXT = TEST_DIRECTORY / "resources" / "sample_invalid_ext.txt"
SAMPLE_CSV = TEST_DIRECTORY / "resources" / "sample.csv"
MALFORMED_SAMPLE_CSV = TEST_DIRECTORY / "resources" / "sample_malformed.csv"
SAMPLE_XLSX = TEST_DIRECTORY / "resources" / "sample.xlsx"
MALFORMED_SAMPLE_XLSX = TEST_DIRECTORY / "resources" / "sample_malformed.xlsx"
MOCK_DB_FILE = TEST_DIRECTORY / "test_server" / "db.json"


Expand All @@ -40,7 +43,7 @@ class TestServer(unittest.TestCase):
"HPC_AWS_PARAM_STORE_NAME": "/some/param/store",
}

with open(SAMPLE_FILE, "r") as file:
with open(SAMPLE_CSV, "r") as file:
csv_content = file.read()

with open(MOCK_DB_FILE) as f:
Expand All @@ -54,7 +57,24 @@ class TestServer(unittest.TestCase):
def test_validate_csv(self):
"""Tests that valid csv file is returned."""
with TestClient(app) as client:
with open(SAMPLE_FILE, "rb") as f:
with open(SAMPLE_CSV, "rb") as f:
files = {
"file": f,
}
response = client.post(url="/api/validate_csv", files=files)
expected_jobs = [j.json() for j in self.expected_job_configs]
expected_response = {
"message": "Valid Data",
"data": {"jobs": expected_jobs, "errors": []},
}
self.assertEqual(response.status_code, 200)
self.assertEqual(expected_response, response.json())

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
def test_validate_csv_xlsx(self):
"""Tests that valid xlsx file is returned."""
with TestClient(app) as client:
with open(SAMPLE_XLSX, "rb") as f:
files = {
"file": f,
}
Expand All @@ -79,7 +99,7 @@ def test_submit_jobs(
mock_response._content = b'{"message": "success"}'
mock_submit_job.return_value = mock_response
with TestClient(app) as client:
with open(SAMPLE_FILE, "rb") as f:
with open(SAMPLE_CSV, "rb") as f:
files = {
"file": f,
}
Expand Down Expand Up @@ -118,7 +138,7 @@ def test_submit_jobs_server_error(
mock_response.status_code = 500
mock_submit_job.return_value = mock_response
with TestClient(app) as client:
with open(SAMPLE_FILE, "rb") as f:
with open(SAMPLE_CSV, "rb") as f:
files = {
"file": f,
}
Expand Down Expand Up @@ -179,11 +199,41 @@ def test_submit_jobs_malformed_json(
self.assertEqual(0, mock_sleep.call_count)
self.assertEqual(0, mock_log_error.call_count)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
def test_validate_null_csv(self):
"""Tests that invalid file type returns FileNotFoundError"""
with TestClient(app) as client:
with open(SAMPLE_INVALID_EXT, "rb") as f:
files = {
"file": f,
}
response = client.post(url="/api/validate_csv", files=files)
self.assertEqual(response.status_code, 406)
self.assertEqual(
["Invalid input file type"],
response.json()["data"]["errors"],
)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
def test_validate_malformed_csv(self):
"""Tests that invalid csv returns errors"""
with TestClient(app) as client:
with open(MALFORMED_SAMPLE_FILE, "rb") as f:
with open(MALFORMED_SAMPLE_CSV, "rb") as f:
files = {
"file": f,
}
response = client.post(url="/api/validate_csv", files=files)
self.assertEqual(response.status_code, 406)
self.assertEqual(
["AttributeError('WRONG_MODALITY_HERE')"],
response.json()["data"]["errors"],
)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
def test_validate_malformed_xlsx(self):
"""Tests that invalid xlsx returns errors"""
with TestClient(app) as client:
with open(MALFORMED_SAMPLE_XLSX, "rb") as f:
files = {
"file": f,
}
Expand Down

0 comments on commit 7028075

Please sign in to comment.