diff --git a/benchmarks/views/user.py b/benchmarks/views/user.py index ab4634591..5a0b340c0 100644 --- a/benchmarks/views/user.py +++ b/benchmarks/views/user.py @@ -2,7 +2,9 @@ import logging import os import zipfile +import re from typing import Tuple, Union, List +from io import TextIOWrapper import boto3 import requests @@ -226,22 +228,37 @@ def is_submission_original(file, submitter: User) -> Tuple[bool, Union[None, Lis namelist = archive.infolist() plugins = plugins_exist(namelist)[1] + # grab identifiers from inits of all plugins + plugin_identifiers = extract_identifiers(archive) + # for each plugin submitted, make sure that the identifier does not exist already: for plugin in plugins: - identifiers = plugin_has_instances(namelist, plugin)[1] + plugin_directory_names = plugin_has_instances(namelist, plugin)[1] db_table = plugin_db_mapping[plugin] # Determine the lookup field name based on the plugin type field_name = 'name' if plugin == "models" else 'identifier' - for identifier in identifiers: - query_filter = {field_name: identifier} + # plugin_name corresponds to the directory name, plugin_identifier corresponds to actual identifiers from inits + all_plugin_ids = plugin_directory_names + list(plugin_identifiers[plugin]) + for plugin_name_or_identifier in all_plugin_ids: + query_filter = {field_name: plugin_name_or_identifier} + + # check for tutorial + if "resnet50_tutorial" in plugin_name_or_identifier: + return False, [plugin, plugin_name_or_identifier] - # Check if an entry with the given identifier exists - if db_table.objects.filter(**query_filter).exists() or "resnet50_tutorial" in identifier: - return False, [plugin, identifier] + # check if an entry with the given identifier exists + if db_table.objects.filter(**query_filter).exists(): + owner_obj = db_table.objects.get(**query_filter) + owner_id = getattr(owner_obj, 'owner_id', None) or getattr(owner_obj, 'owner').id - return True, None # Passes all checks, then the submission is original -> good to go + # check to see if the submitter is the owner (or superuser) + if owner_id != submitter.id and not submitter.is_superuser: + return False, [plugin, plugin_name_or_identifier] + # else, versioning will occur here + + return True, [] # Passes all checks, then the submission is original -> good to go def validate_zip(file: InMemoryUploadedFile) -> Tuple[bool, str]: @@ -358,6 +375,34 @@ def _is_instance_path(path: str, plugin: str) -> bool: return len(parts) > 2 and parts[1] == plugin and path.endswith("/") +def extract_identifiers(zip_ref): + # define patterns for each plugin type (data and metrics to be added later) + possible_plugins = ["models", "benchmarks"] + registry_patterns = { + "models": re.compile(r"model_registry\['(.+?)'\]"), + "benchmarks": re.compile(r"benchmark_registry\['(.+?)'\]"), + } + + # dictionary to hold identifiers for each plugin type found + identifiers = {plugin: set() for plugin in possible_plugins} + + for file_info in zip_ref.infolist(): + path_segments = file_info.filename.split('/') + # ensure the path has 4 segments [zip root, plugin, plugin_name, __init__.py] + if len(path_segments) == 4 and path_segments[1] in possible_plugins and path_segments[-1] == '__init__.py': + plugin = path_segments[1] + with zip_ref.open(file_info) as file: + # extract identifier pattern matches + for line in TextIOWrapper(file, encoding='utf-8'): + line_code = line.split('#', 1)[0].strip() # ignore both inline and own line comments + pattern = registry_patterns.get(plugin) + if pattern: + matches = pattern.findall(line_code) + identifiers[plugin].update(matches) + + return identifiers + + def collect_models_benchmarks(request): assert request.method == 'POST'