Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plugin duplicate protection (identifiers and directory names) #237

Merged
merged 9 commits into from
Feb 22, 2024
56 changes: 50 additions & 6 deletions benchmarks/views/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import logging
import os
import zipfile
import re
from typing import Tuple, Union, List
from io import TextIOWrapper

import boto3
import requests
Expand Down Expand Up @@ -213,30 +215,43 @@ def post(self, request):
return render(request, 'benchmarks/success.html', {"domain": self.domain})


def is_submission_original(file, submitter: User) -> Tuple[bool, Union[None, List[str]]]:
samwinebrake marked this conversation as resolved.
Show resolved Hide resolved
def is_submission_original(file, submitter):
# add metrics and data eventually
plugin_db_mapping = {"models": Model, "benchmarks": BenchmarkType}

with zipfile.ZipFile(file, mode="r") as archive:
namelist = archive.infolist()
plugins = plugins_exist(namelist)[1]

# grab identifiers from inits of all plugins
plugin_identifiers = extract_identifiers(archive)

# for each plugin submitted, make sure that the identifier does not exist already:
for plugin in plugins:
identifiers = plugin_has_instances(namelist, plugin)[1]
plugin_directory_names = plugin_has_instances(namelist, plugin)[1]
db_table = plugin_db_mapping[plugin]

# Determine the field name based on the plugin type
field_name = 'name' if plugin == "models" else 'identifier'

for identifier in identifiers:
query_filter = {field_name: identifier}
# plugin_name corresponds to the directory name, plugin_identifier corresponds to actual identifiers from inits
all_plugin_ids = plugin_directory_names + list(plugin_identifiers[plugin])
for plugin_name_or_identifier in all_plugin_ids:
query_filter = {field_name: plugin_name_or_identifier}

# Check if an entry with the given identifier exists
if db_table.objects.filter(**query_filter).exists():
return False, [plugin, identifier]
owner_obj = db_table.objects.get(**query_filter)
owner_id = getattr(owner_obj, 'owner_id', None) or getattr(owner_obj, 'owner').id

# Check to see if the submitter is the owner (or superuser)
if owner_id == submitter.id or submitter.is_superuser:
# Khaled versioning here
print(owner_id, submitter)
else:
return False, [plugin, plugin_name_or_identifier]

return True, None # Passes all checks, then the submission is original -> good to go
return True, [] # Passes all checks, then the submission is original -> good to go


def validate_zip(file):
Expand Down Expand Up @@ -315,6 +330,35 @@ def instance_has_files(namelist, instances):
return True, files_list, None


def extract_identifiers(zip_ref):
# define patterns for each plugin type (data and metrics to be added later)
possible_plugins = ["models", "benchmarks"]
registry_patterns = {
"models": re.compile(r"model_registry\['(.+?)'\]"),
"benchmarks": re.compile(r"benchmark_registry\['(.+?)'\]"),
samwinebrake marked this conversation as resolved.
Show resolved Hide resolved
}

# dictionary to hold identifiers for each plugin type found
identifiers = {plugin: set() for plugin in possible_plugins}

for file_info in zip_ref.infolist():
path_segments = file_info.filename.split('/')
for plugin in possible_plugins:
# check if __init__.py under any of the possible plugins' directories
if plugin in path_segments and '__init__.py' in path_segments[-1]:
samwinebrake marked this conversation as resolved.
Show resolved Hide resolved
with zip_ref.open(file_info) as file:
# extract identifier pattern matches
for line in TextIOWrapper(file, encoding='utf-8'):
line_code = line.split('#', 1)[0] # ignore both inline and own line comments
pattern = registry_patterns.get(plugin)
if pattern:
matches = pattern.findall(line_code)
identifiers[plugin].update(matches)
break

return identifiers


def collect_models_benchmarks(request):
assert request.method == 'POST'

Expand Down