Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: secure filename + keep filename if uploading file bytes #73

Merged
merged 6 commits on Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions chord_drs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from werkzeug.utils import secure_filename
from urllib.parse import urlparse
from uuid import uuid4

Expand Down Expand Up @@ -68,13 +69,18 @@ class DrsBlob(db.Model, DrsMixin):
location = db.Column(db.String(500), nullable=False)

def __init__(self, *args, **kwargs):
logger = current_app.logger

# If set, we are deduplicating with an existing file object
object_to_copy: DrsBlob | None = kwargs.get("object_to_copy")

# If set, we are overriding the filename to save the file to
filename: str | None = kwargs.get("filename")

self.id = str(uuid4())

if object_to_copy:
self.name = object_to_copy.name
self.name = secure_filename(filename) if filename else object_to_copy.name
self.location = object_to_copy.location
self.size = object_to_copy.size
self.checksum = object_to_copy.checksum
Expand All @@ -88,8 +94,8 @@ def __init__(self, *args, **kwargs):
# TODO: we will need to account for URLs at some point
raise FileNotFoundError("Provided file path does not exists")

self.name = p.name
new_filename = f"{self.id[:12]}-{p.name}" # TODO: use checksum for filename instead
self.name = secure_filename(filename or p.name)
new_filename = f"{self.id[:12]}-{self.name}" # TODO: use checksum for filename instead

backend = get_backend()

Expand All @@ -100,12 +106,15 @@ def __init__(self, *args, **kwargs):
self.size = os.path.getsize(p)
self.checksum = drs_file_checksum(location)
except Exception as e:
current_app.logger.error(f"Encountered exception during DRS object creation: {e}")
logger.error(f"Encountered exception during DRS object creation: {e}")
# TODO: implement more specific exception handling
raise Exception("Well if the file is not saved... we can't do squat")

if "location" in kwargs:
del kwargs["location"]
logger.info(f"Creating new DRS object: name={self.name}; size={self.size}; sha256={self.checksum}")

for key_to_remove in ("location", "filename"):
if key_to_remove in kwargs:
del kwargs[key_to_remove]

super().__init__(*args, **kwargs)

Expand Down
6 changes: 5 additions & 1 deletion chord_drs/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,9 +473,12 @@ def object_ingest():

tfh, t_obj_path = tempfile.mkstemp()
try:
if file:
filename: str | None = None # no override, use path filename if path is specified instead of a file upload
if file is not None:
logger.debug(f"ingest - recieved file object: {file}")
file.save(t_obj_path)
obj_path = t_obj_path
filename = file.filename # still may be none, in which case the temporary filename will be used

if deduplicate:
# Get checksum of original file, and query database for objects that match
Expand Down Expand Up @@ -507,6 +510,7 @@ def object_ingest():
try:
drs_object = DrsBlob(
**(dict(object_to_copy=object_to_copy) if object_to_copy else dict(location=obj_path)),
filename=filename,
project_id=project_id,
dataset_id=dataset_id,
data_type=data_type,
Expand Down
6 changes: 4 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@


def test_drs_blob_init_bad_file():
from chord_drs.app import application
from chord_drs.models import DrsBlob

with pytest.raises(FileNotFoundError):
DrsBlob(location="path/to/dne")
with application.app_context():
with pytest.raises(FileNotFoundError):
DrsBlob(location="path/to/dne")


def test_drs_blob_init():
Expand Down