diff --git a/chord_drs/models.py b/chord_drs/models.py index 906f6cc..7f7ac6f 100644 --- a/chord_drs/models.py +++ b/chord_drs/models.py @@ -4,6 +4,7 @@ from pathlib import Path from sqlalchemy.sql import func from sqlalchemy.orm import relationship +from werkzeug.utils import secure_filename from urllib.parse import urlparse from uuid import uuid4 @@ -68,13 +69,18 @@ class DrsBlob(db.Model, DrsMixin): location = db.Column(db.String(500), nullable=False) def __init__(self, *args, **kwargs): + logger = current_app.logger + # If set, we are deduplicating with an existing file object object_to_copy: DrsBlob | None = kwargs.get("object_to_copy") + # If set, we are overriding the filename to save the file to + filename: str | None = kwargs.get("filename") + self.id = str(uuid4()) if object_to_copy: - self.name = object_to_copy.name + self.name = secure_filename(filename) if filename else object_to_copy.name self.location = object_to_copy.location self.size = object_to_copy.size self.checksum = object_to_copy.checksum @@ -88,8 +94,8 @@ def __init__(self, *args, **kwargs): # TODO: we will need to account for URLs at some point raise FileNotFoundError("Provided file path does not exists") - self.name = p.name - new_filename = f"{self.id[:12]}-{p.name}" # TODO: use checksum for filename instead + self.name = secure_filename(filename or p.name) + new_filename = f"{self.id[:12]}-{self.name}" # TODO: use checksum for filename instead backend = get_backend() @@ -100,12 +106,15 @@ def __init__(self, *args, **kwargs): self.size = os.path.getsize(p) self.checksum = drs_file_checksum(location) except Exception as e: - current_app.logger.error(f"Encountered exception during DRS object creation: {e}") + logger.error(f"Encountered exception during DRS object creation: {e}") # TODO: implement more specific exception handling raise Exception("Well if the file is not saved... we can't do squat") - if "location" in kwargs: - del kwargs["location"] + logger.info(f"Creating new DRS object: name={self.name}; size={self.size}; sha256={self.checksum}") + + for key_to_remove in ("location", "filename"): + if key_to_remove in kwargs: + del kwargs[key_to_remove] super().__init__(*args, **kwargs) diff --git a/chord_drs/routes.py b/chord_drs/routes.py index 6c1bdce..a1be118 100644 --- a/chord_drs/routes.py +++ b/chord_drs/routes.py @@ -473,9 +473,12 @@ def object_ingest(): tfh, t_obj_path = tempfile.mkstemp() try: - if file: + filename: str | None = None # no override, use path filename if path is specified instead of a file upload + if file is not None: + logger.debug(f"ingest - recieved file object: {file}") file.save(t_obj_path) obj_path = t_obj_path + filename = file.filename # still may be none, in which case the temporary filename will be used if deduplicate: # Get checksum of original file, and query database for objects that match @@ -507,6 +510,7 @@ def object_ingest(): try: drs_object = DrsBlob( **(dict(object_to_copy=object_to_copy) if object_to_copy else dict(location=obj_path)), + filename=filename, project_id=project_id, dataset_id=dataset_id, data_type=data_type, diff --git a/tests/test_models.py b/tests/test_models.py index 4ac2100..649d9ba 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,10 +4,12 @@ def test_drs_blob_init_bad_file(): + from chord_drs.app import application from chord_drs.models import DrsBlob - with pytest.raises(FileNotFoundError): - DrsBlob(location="path/to/dne") + with application.app_context(): + with pytest.raises(FileNotFoundError): + DrsBlob(location="path/to/dne") def test_drs_blob_init():