Skip to content

Commit

Permalink
Extend test coverage of the validate_file script
Browse files Browse the repository at this point in the history
* Extend test coverage of validate_file script
* Add type hints to validate_file script
  • Loading branch information
replaceafill authored Aug 14, 2024
1 parent 6bacc0b commit 01293dc
Show file tree
Hide file tree
Showing 3 changed files with 517 additions and 60 deletions.
4 changes: 2 additions & 2 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ explicit_package_bases = True
warn_redundant_casts = True
warn_unused_configs = True

[mypy-src.MCPClient.lib.client.*,src.MCPClient.*.normalize]
[mypy-src.MCPClient.lib.client.*,src.MCPClient.*.normalize,src.MCPClient.*.validate_file]
check_untyped_defs = True
disallow_any_generics = True
disallow_incomplete_defs = True
Expand All @@ -18,7 +18,7 @@ strict_equality = True
warn_return_any = True
warn_unused_ignores = True

[mypy-tests.MCPClient.conftest,tests.MCPClient.test_normalize]
[mypy-tests.MCPClient.conftest,tests.MCPClient.test_normalize,tests.MCPClient.test_validate_file]
check_untyped_defs = True
disallow_any_generics = True
disallow_incomplete_defs = True
Expand Down
63 changes: 42 additions & 21 deletions src/MCPClient/lib/clientScripts/validate_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,22 @@
import os
import sys
from pprint import pformat
from typing import Any
from typing import List
from typing import Mapping
from typing import Optional

import django
from django.core.exceptions import ValidationError
from django.db import transaction

django.setup()

import databaseFunctions
from client.job import Job
from custom_handlers import get_script_logger
from dicts import replace_string_values
from django.conf import settings as mcpclient_settings
from django.core.exceptions import ValidationError
from django.db import transaction
from executeOrRunSubProcess import executeOrRun
from fpr.models import FormatVersion
from fpr.models import FPRule
Expand All @@ -44,7 +50,14 @@
DERIVATIVE_TYPES = {"preservation", "access"}


def main(job, file_path, file_uuid, sip_uuid, shared_path, file_type):
def main(
job: Job,
file_path: str,
file_uuid: str,
sip_uuid: str,
shared_path: Optional[str],
file_type: str,
) -> int:
setup_dicts(mcpclient_settings)

validator = Validator(job, file_path, file_uuid, sip_uuid, shared_path, file_type)
Expand All @@ -67,18 +80,26 @@ class Validator:
determine whether a given file conforms to a given specification.
"""

def __init__(self, job, file_path, file_uuid, sip_uuid, shared_path, file_type):
def __init__(
self,
job: Job,
file_path: str,
file_uuid: str,
sip_uuid: str,
shared_path: Optional[str],
file_type: str,
):
self.job = job
self.file_path = file_path
self.file_uuid = file_uuid
self.sip_uuid = sip_uuid
self.shared_path = shared_path
self.shared_path = shared_path if shared_path else ""
self.file_type = file_type
self.purpose = "validation"
self._sip_logs_dir = None
self._sip_pres_val_dir = None
self._sip_logs_dir: Optional[str] = None
self._sip_pres_val_dir: Optional[str] = None

def validate(self):
def validate(self) -> int:
"""Validate the file identified by ``self.file_uuid``, using all rules
that apply. Return an error code (1 or 0), which the script as a whole
also returns. Side effects include printing to stdout/stderr (which
Expand All @@ -101,7 +122,7 @@ def validate(self):

return SUCCESS_CODE

def _get_rules(self):
def _get_rules(self) -> FPRule:
"""Return all FPR rules that apply to files of this type."""
try:
fmt = FormatVersion.active.get(fileformatversion__file_uuid=self.file_uuid)
Expand All @@ -114,7 +135,7 @@ def _get_rules(self):
rules = FPRule.active.filter(purpose=f"default_{self.purpose}")
return rules

def _execute_rule_command(self, rule):
def _execute_rule_command(self, rule: FPRule) -> str:
"""Run the command against the file and return either 'passed' or
'failed'. If the command errors or determines that the file is invalid,
return 'failed'. Non-errors will result in the creation of an Event
Expand Down Expand Up @@ -149,7 +170,7 @@ def _execute_rule_command(self, rule):
# Parse output and generate an Event
# TODO: Evaluating a python string from a user-definable script seems
# insecure practice; should be JSON.
output = ast.literal_eval(stdout)
output: Mapping[str, Any] = ast.literal_eval(stdout)
event_detail = (
f'program="{rule.command.tool.description}";'
f' version="{rule.command.tool.version}"'
Expand Down Expand Up @@ -190,7 +211,7 @@ def _execute_rule_command(self, rule):
)
return result

def _save_stdout_to_logs_dir(self, output):
def _save_stdout_to_logs_dir(self, output: Mapping[str, Any]) -> None:
"""Save the validation command's output from validating the file to a
file at logs/implementationChecks/<input_filename>.xml in the SIP.
``output`` is expected to be a dict with a ``stdout`` key.
Expand All @@ -202,15 +223,15 @@ def _save_stdout_to_logs_dir(self, output):
with open(stdout_path, "w") as f:
f.write(stdout)

def _file_is_derivative(self):
def _file_is_derivative(self) -> bool:
"""Return ``True`` if the file we are validating is a derivative, i.e.,
a modified version created for preservation or access.
"""
if self.file_type == "preservation":
return self._file_is_preservation_derivative()
return self._file_is_access_derivative()

def _file_is_preservation_derivative(self):
def _file_is_preservation_derivative(self) -> bool:
"""Returns ``True`` if the file with UUID ``self.file_uuid`` is a
preservation derivative.
"""
Expand All @@ -222,7 +243,7 @@ def _file_is_preservation_derivative(self):
except (Derivation.DoesNotExist, ValidationError):
return False

def _file_is_access_derivative(self):
def _file_is_access_derivative(self) -> bool:
"""Returns ``True`` if the file with UUID ``self.file_uuid`` is an
access derivative.
"""
Expand All @@ -241,14 +262,14 @@ def _file_is_access_derivative(self):
except (File.DoesNotExist, ValidationError):
return False

def _not_derivative_msg(self):
def _not_derivative_msg(self) -> str:
"""Return the message to print if the file is not a derivative."""
if self.file_type == "preservation":
return "is not a preservation derivative"
return "is not an access derivative"

@property
def sip_logs_dir(self):
def sip_logs_dir(self) -> Optional[str]:
"""Return the absolute path to the logs/ directory of the SIP that the
target file is a part of.
"""
Expand Down Expand Up @@ -277,7 +298,7 @@ def sip_logs_dir(self):
return None

@property
def sip_pres_val_dir(self):
def sip_pres_val_dir(self) -> Optional[str]:
"""Return the full path to the directory within the SIP where stdout
from preservation derivative validation output should be written to
disk.
Expand All @@ -303,21 +324,21 @@ def sip_pres_val_dir(self):
return self._sip_pres_val_dir


def _get_shared_path(argv):
def _get_shared_path(argv: List[str]) -> Optional[str]:
try:
return argv[4]
except IndexError:
return None


def _get_file_type(argv):
def _get_file_type(argv: List[str]) -> str:
try:
return argv[5]
except IndexError:
return "original"


def call(jobs):
def call(jobs: List[Job]) -> None:
with transaction.atomic():
for job in jobs:
with job.JobContext(logger=logger):
Expand Down
Loading

0 comments on commit 01293dc

Please sign in to comment.