Skip to content

Commit

Permalink
Plugin duplicate protection (identifiers and directory names) (#237)
Browse files Browse the repository at this point in the history
* - added extract_identifiers to find patterns within plugin inits
- changed variable naming within is_submission_original
- duplicate protection of directory names and plugin identifiers

* the assignment expression operator (:=) is not compatible with Python 3.7, so the code was changed to avoid it

* ignore possible commented out identifier patterns

* make zip structure 'zip_root/plugin/plugin_name/__init__.py' mandatory

* re-fix the := usage

* update check for submitter id

* re-add type hints
  • Loading branch information
samwinebrake authored Feb 22, 2024
1 parent 84fe649 commit 523d04c
Showing 1 changed file with 52 additions and 7 deletions.
59 changes: 52 additions & 7 deletions benchmarks/views/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import logging
import os
import zipfile
import re
from typing import Tuple, Union, List
from io import TextIOWrapper

import boto3
import requests
Expand Down Expand Up @@ -226,22 +228,37 @@ def is_submission_original(file, submitter: User) -> Tuple[bool, Union[None, Lis
namelist = archive.infolist()
plugins = plugins_exist(namelist)[1]

# grab identifiers from inits of all plugins
plugin_identifiers = extract_identifiers(archive)

# for each plugin submitted, make sure that the identifier does not exist already:
for plugin in plugins:
identifiers = plugin_has_instances(namelist, plugin)[1]
plugin_directory_names = plugin_has_instances(namelist, plugin)[1]
db_table = plugin_db_mapping[plugin]

# Determine the lookup field name based on the plugin type
field_name = 'name' if plugin == "models" else 'identifier'

for identifier in identifiers:
query_filter = {field_name: identifier}
# plugin_name corresponds to the directory name, plugin_identifier corresponds to actual identifiers from inits
all_plugin_ids = plugin_directory_names + list(plugin_identifiers[plugin])
for plugin_name_or_identifier in all_plugin_ids:
query_filter = {field_name: plugin_name_or_identifier}

# check for tutorial
if "resnet50_tutorial" in plugin_name_or_identifier:
return False, [plugin, plugin_name_or_identifier]

# Check if an entry with the given identifier exists
if db_table.objects.filter(**query_filter).exists() or "resnet50_tutorial" in identifier:
return False, [plugin, identifier]
# check if an entry with the given identifier exists
if db_table.objects.filter(**query_filter).exists():
owner_obj = db_table.objects.get(**query_filter)
owner_id = getattr(owner_obj, 'owner_id', None) or getattr(owner_obj, 'owner').id

return True, None # Passes all checks, then the submission is original -> good to go
# check to see if the submitter is the owner (or superuser)
if owner_id != submitter.id and not submitter.is_superuser:
return False, [plugin, plugin_name_or_identifier]
# else, versioning will occur here

return True, [] # Passes all checks, then the submission is original -> good to go


def validate_zip(file: InMemoryUploadedFile) -> Tuple[bool, str]:
Expand Down Expand Up @@ -358,6 +375,34 @@ def _is_instance_path(path: str, plugin: str) -> bool:
return len(parts) > 2 and parts[1] == plugin and path.endswith("/")


def extract_identifiers(zip_ref: zipfile.ZipFile) -> dict:
    """Collect the registry identifiers declared in plugin ``__init__.py`` files.

    Only archive members laid out as ``zip_root/plugin/plugin_name/__init__.py``
    (exactly four path segments) are scanned, and only for the plugin types
    that have a registry pattern below.

    A registration is recognised when a line contains
    ``model_registry[<quoted identifier>]`` (for ``models``) or
    ``benchmark_registry[<quoted identifier>]`` (for ``benchmarks``). The
    identifier may be wrapped in single or double quotes; accepting only one
    quote style would let a submission slip an existing identifier past the
    duplicate check simply by switching quotes.

    Everything after the first ``#`` on a line is treated as a comment and
    dropped before matching, so commented-out registrations are ignored.
    NOTE: as a consequence an identifier that itself contains ``#`` is cut off
    at that character and will not be matched.

    :param zip_ref: an opened zip archive containing the submitted plugins.
    :return: a dict mapping each supported plugin type (``"models"``,
        ``"benchmarks"``) to the set of identifiers found for it; a plugin
        type with no matches maps to an empty set.
    """
    # define patterns for each plugin type (data and metrics to be added later).
    # group 1 captures the opening quote so that `\1` requires the same closing
    # quote; group 2 is the identifier itself.
    registry_patterns = {
        "models": re.compile(r"model_registry\[\s*(['\"])(.+?)\1\s*\]"),
        "benchmarks": re.compile(r"benchmark_registry\[\s*(['\"])(.+?)\1\s*\]"),
    }

    # dictionary to hold identifiers for each plugin type found
    identifiers = {plugin: set() for plugin in registry_patterns}

    for file_info in zip_ref.infolist():
        path_segments = file_info.filename.split('/')
        # ensure the path has 4 segments [zip root, plugin, plugin_name, __init__.py]
        if len(path_segments) != 4 or path_segments[-1] != '__init__.py':
            continue
        plugin = path_segments[1]
        pattern = registry_patterns.get(plugin)
        if pattern is None:
            continue  # not a plugin type we extract identifiers for

        with zip_ref.open(file_info) as file:
            for line in TextIOWrapper(file, encoding='utf-8'):
                # ignore both inline and own line comments
                line_code = line.split('#', 1)[0].strip()
                identifiers[plugin].update(
                    match.group(2) for match in pattern.finditer(line_code)
                )

    return identifiers


def collect_models_benchmarks(request):
assert request.method == 'POST'

Expand Down

0 comments on commit 523d04c

Please sign in to comment.