Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Plugin duplicate protection (identifiers and directory names) #237

Merged
merged 9 commits into from
Feb 22, 2024
59 changes: 52 additions & 7 deletions benchmarks/views/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import logging
import os
import zipfile
import re
from typing import Tuple, Union, List
from io import TextIOWrapper

import boto3
import requests
Expand Down Expand Up @@ -226,22 +228,37 @@ def is_submission_original(file, submitter: User) -> Tuple[bool, Union[None, Lis
namelist = archive.infolist()
plugins = plugins_exist(namelist)[1]

# grab identifiers from inits of all plugins
plugin_identifiers = extract_identifiers(archive)

# for each plugin submitted, make sure that the identifier does not exist already:
for plugin in plugins:
identifiers = plugin_has_instances(namelist, plugin)[1]
plugin_directory_names = plugin_has_instances(namelist, plugin)[1]
db_table = plugin_db_mapping[plugin]

# Determine the lookup field name based on the plugin type
field_name = 'name' if plugin == "models" else 'identifier'

for identifier in identifiers:
query_filter = {field_name: identifier}
# plugin_name corresponds to the directory name, plugin_identifier corresponds to actual identifiers from inits
all_plugin_ids = plugin_directory_names + list(plugin_identifiers[plugin])
for plugin_name_or_identifier in all_plugin_ids:
query_filter = {field_name: plugin_name_or_identifier}

# check for tutorial
if "resnet50_tutorial" in plugin_name_or_identifier:
return False, [plugin, plugin_name_or_identifier]

# Check if an entry with the given identifier exists
if db_table.objects.filter(**query_filter).exists() or "resnet50_tutorial" in identifier:
return False, [plugin, identifier]
# check if an entry with the given identifier exists
if db_table.objects.filter(**query_filter).exists():
owner_obj = db_table.objects.get(**query_filter)
owner_id = getattr(owner_obj, 'owner_id', None) or getattr(owner_obj, 'owner').id

return True, None # Passes all checks, then the submission is original -> good to go
# check to see if the submitter is the owner (or superuser)
if owner_id != submitter.id and not submitter.is_superuser:
return False, [plugin, plugin_name_or_identifier]
# else, versioning will occur here

return True, [] # Passes all checks, then the submission is original -> good to go


def validate_zip(file: InMemoryUploadedFile) -> Tuple[bool, str]:
Expand Down Expand Up @@ -358,6 +375,34 @@ def _is_instance_path(path: str, plugin: str) -> bool:
return len(parts) > 2 and parts[1] == plugin and path.endswith("/")


def extract_identifiers(zip_ref):
    """Collect the registry identifiers declared in each plugin's ``__init__.py``.

    Every archive member whose path has exactly four ``/``-separated segments
    ``[zip root, plugin type, plugin name, "__init__.py"]`` and whose plugin
    type is a supported one is scanned line by line for registry
    subscriptions such as ``model_registry['some-id']`` or
    ``benchmark_registry["some-id"]``. Both single- and double-quoted keys
    are recognised.

    Args:
        zip_ref: An open zip archive exposing ``infolist()`` and ``open()``
            (e.g. a ``zipfile.ZipFile``).

    Returns:
        A dict mapping each supported plugin type (``"models"``,
        ``"benchmarks"``) to the set of identifier strings found for it.
        Plugin types with no matches map to an empty set.
    """
    # define patterns for each plugin type (data and metrics to be added later).
    # Group 1 captures the opening quote so the closing quote must match it;
    # group 2 is the identifier itself.
    registry_patterns = {
        "models": re.compile(r"""model_registry\[\s*(['"])(.+?)\1\s*\]"""),
        "benchmarks": re.compile(r"""benchmark_registry\[\s*(['"])(.+?)\1\s*\]"""),
    }

    # dictionary to hold identifiers for each plugin type found
    identifiers = {plugin: set() for plugin in registry_patterns}

    for file_info in zip_ref.infolist():
        path_segments = file_info.filename.split('/')
        # ensure the path has 4 segments [zip root, plugin, plugin_name, __init__.py]
        if len(path_segments) != 4 or path_segments[-1] != '__init__.py':
            continue

        plugin = path_segments[1]
        pattern = registry_patterns.get(plugin)
        if pattern is None:
            # not a plugin type we extract identifiers for
            continue

        with zip_ref.open(file_info) as file:
            for line in TextIOWrapper(file, encoding='utf-8'):
                # ignore both inline and own-line comments: everything after
                # the first '#' on the line is discarded before matching
                line_code = line.split('#', 1)[0]
                identifiers[plugin].update(
                    match.group(2) for match in pattern.finditer(line_code)
                )

    return identifiers


def collect_models_benchmarks(request):
assert request.method == 'POST'

Expand Down