From 86cef96959926181b4c75bb425fb3c1bdc2588b4 Mon Sep 17 00:00:00 2001 From: Enzo Caceres Date: Sun, 27 Oct 2024 15:53:44 +0100 Subject: [PATCH] feat(custom): support checking and scoring (#143) Infrastructure and tools required for the new unstructured competition format. --- crunch/api/__init__.py | 7 ++ crunch/api/domain/__init__.py | 4 +- crunch/api/domain/data_release.py | 14 ++- crunch/api/domain/phase.py | 13 +++ crunch/api/resource.py | 2 +- crunch/cli.py | 187 ++++++++++++++++++++++++++++++ crunch/command/download.py | 96 ++------------- crunch/constants.py | 4 + crunch/custom/__init__.py | 13 +++ crunch/custom/code_loader.py | 78 +++++++++++++ crunch/custom/scoring.py | 127 ++++++++++++++++++++ crunch/downloader.py | 102 ++++++++++++++++ crunch/scoring/score.py | 2 +- 13 files changed, 558 insertions(+), 91 deletions(-) create mode 100644 crunch/custom/__init__.py create mode 100644 crunch/custom/code_loader.py create mode 100644 crunch/custom/scoring.py create mode 100644 crunch/downloader.py diff --git a/crunch/api/__init__.py b/crunch/api/__init__.py index 6b7b9c3..8e41802 100644 --- a/crunch/api/__init__.py +++ b/crunch/api/__init__.py @@ -28,3 +28,10 @@ LatestRoundNotFoundException, NextRoundNotFoundException ) + +from .identifiers import ( + CompetitionIdentifierType, + RoundIdentifierType, + PhaseIdentifierType, + CrunchIdentifierType, +) diff --git a/crunch/api/domain/__init__.py b/crunch/api/domain/__init__.py index 6576796..f4e6314 100644 --- a/crunch/api/domain/__init__.py +++ b/crunch/api/domain/__init__.py @@ -14,13 +14,15 @@ ColumnNames, TargetColumnNames, DataReleaseTargetResolution, + SplitKeyPythonType, DataReleaseSplit, DataReleaseSplitGroup, DataReleaseSplitReduced, DataReleaseFeature, + DataFile, DataFiles, OriginalFiles, - DataFile, + DataFilesUnion, KnownData, ) from .enum_ import ( diff --git a/crunch/api/domain/data_release.py b/crunch/api/domain/data_release.py index fa2ed51..27ffcd4 100644 --- a/crunch/api/domain/data_release.py +++ b/crunch/api/domain/data_release.py @@ -77,6 +77,9 @@ def items(self): return vars(self).items() +DataFilesUnion = typing.Union[DataFiles, OriginalFiles, typing.Dict[str, DataFile]] + + class DataReleaseSplitGroup(enum.Enum): TRAIN = "TRAIN" @@ -95,6 +98,9 @@ def __repr__(self): return self.name +SplitKeyPythonType = typing.Union[str, int] + + @dataclasses_json.dataclass_json( letter_case=dataclasses_json.LetterCase.CAMEL, undefined=dataclasses_json.Undefined.EXCLUDE, @@ -102,7 +108,7 @@ def __repr__(self): @dataclasses.dataclass(frozen=True) class DataReleaseSplit: - key: typing.Union[str, int] + key: SplitKeyPythonType group: DataReleaseSplitGroup reduced: typing.Optional[DataReleaseSplitReduced] = None @@ -180,7 +186,7 @@ def target_resolution(self): return DataReleaseTargetResolution[self._attrs["target_resolution"]] @property - def data_files(self) -> typing.Union[DataFiles, OriginalFiles, typing.Dict[str, DataFile]]: + def data_files(self) -> DataFilesUnion: files = self._attrs.get("dataFiles") if not files: self.reload() @@ -198,13 +204,13 @@ def data_files(self) -> typing.Union[DataFiles, OriginalFiles, typing.Dict[str, }) @property - def splits(self) -> typing.Tuple[DataReleaseSplit]: + def splits(self) -> typing.List[DataReleaseSplit]: splits = self._attrs.get("splits") if splits is None: self.reload(include_splits=True) splits = self._attrs["splits"] - return tuple(DataReleaseSplit.from_dict_array(splits)) + return list(DataReleaseSplit.from_dict_array(splits)) @property def default_feature_group(self) -> str: diff --git a/crunch/api/domain/phase.py b/crunch/api/domain/phase.py index 322d219..3db68d4 100644 --- a/crunch/api/domain/phase.py +++ b/crunch/api/domain/phase.py @@ -14,6 +14,15 @@ class PhaseType(enum.Enum): def __repr__(self): return self.name + def slug(self): + if self == PhaseType.SUBMISSION: + return "submission" + + if self == PhaseType.OUT_OF_SAMPLE: + return "out-of-sample" + + return self._default_slug() + def pretty(self): if self == PhaseType.SUBMISSION: return "Submission" @@ -21,9 +30,13 @@ def pretty(self): if self == PhaseType.OUT_OF_SAMPLE: return "Out-of-Sample" + return self._default_slug() + + def _default_slug(self): return str(self).lower().replace("_", "-") + class Phase(Model): resource_identifier_attribute = "type" diff --git a/crunch/api/resource.py b/crunch/api/resource.py index bffaf7f..79f0e50 100644 --- a/crunch/api/resource.py +++ b/crunch/api/resource.py @@ -42,7 +42,7 @@ def __hash__(self): return hash(f"{self.__class__.__name__}:{self.id}") @property - def id(self): + def id(self) -> typing.Union[int, str]: return self._attrs.get(self.id_attribute) @property diff --git a/crunch/cli.py b/crunch/cli.py index 85d276e..d5f5788 100644 --- a/crunch/cli.py +++ b/crunch/cli.py @@ -2,6 +2,7 @@ import logging import os import sys +import traceback import typing import click @@ -679,5 +680,191 @@ def cloud_executor( executor.start() +@cli.group(name="organizer") +@click.argument('competition_name') +@click.pass_context +def organize_group( + context: click.Context, + competition_name: str, +): + client = api.Client.from_env() + + try: + competition = client.competitions.get(competition_name) + except api.errors.CompetitionNameNotFoundException: + print(f"competition {competition_name} not found", file=sys.stderr) + raise click.Abort() + except api.ApiException as error: + utils.exit_via(error) + + context.obj = competition + + +@organize_group.command() +@click.pass_context +def x( + context: click.Context, +): + print(context.obj) + + +@organize_group.group(name="test") +def organize_test_group(): + pass + + +@organize_test_group.group(name="scoring") +@click.option("--script-file", "script_file_path", type=click.Path(dir_okay=False, readable=True), required=False) +@click.option("--github-repository", default=constants.COMPETITIONS_REPOSITORY, required=False) +@click.option("--github-branch", default=constants.COMPETITIONS_BRANCH, required=False) +@click.pass_context +def scoring_group( + context: click.Context, + script_file_path: str, + github_repository: str, + github_branch: str, +): + from . import custom + + competition: api.Competition = context.obj + + if script_file_path is None: + loader = custom.GithubCodeLoader( + competition.name, + repository=github_repository, + branch=github_branch, + ) + else: + loader = custom.LocalCodeLoader( + script_file_path, + ) + + context.obj = (competition, loader) + + +LOWER_PHASE_TYPES = list(map(lambda x: x.name, [ + api.PhaseType.SUBMISSION, + api.PhaseType.OUT_OF_SAMPLE, +])) + + +@scoring_group.command(name="check") +@click.option("--data-directory", "data_directory_path", type=click.Path(file_okay=False, readable=True), required=True) +@click.option("--prediction-file", "prediction_file_path", type=click.Path(dir_okay=False, readable=True), required=True) +@click.option("--phase-type", "phase_type_string", type=click.Choice(LOWER_PHASE_TYPES), default=LOWER_PHASE_TYPES[0]) +@click.pass_context +def scoring_check( + context: click.Context, + data_directory_path: str, + prediction_file_path: str, + phase_type_string: str, +): + from . import custom + + competition, loader = typing.cast( + typing.Tuple[ + api.Competition, + custom.CodeLoader, + ], + context.obj + ) + + phase_type = api.PhaseType[phase_type_string] + + try: + custom.check( + custom.ScoringModule.load(loader), + phase_type, + competition.metrics.list(), + utils.read(prediction_file_path), + data_directory_path + ) + + print(f"\n\nPrediction is valid!") + except custom.ParticipantVisibleError as error: + print(f"\n\nPrediction is not valid: {error}") + except api.ApiException as error: + utils.exit_via(error) + except BaseException as error: + print(f"\n\nPrediction check function failed: {error}") + + traceback.print_exc() + + +@scoring_group.command(name="score") +@click.option("--data-directory", "data_directory_path", type=click.Path(file_okay=False, readable=True), required=True) +@click.option("--prediction-file", "prediction_file_path", type=click.Path(dir_okay=False, readable=True), required=True) +@click.option("--phase-type", "phase_type_string", type=click.Choice(LOWER_PHASE_TYPES), default=LOWER_PHASE_TYPES[0]) +@click.pass_context +def scoring_score( + context: click.Context, + data_directory_path: str, + prediction_file_path: str, + phase_type_string: str, +): + from . import custom + + competition, loader = typing.cast( + typing.Tuple[ + api.Competition, + custom.CodeLoader, + ], + context.obj + ) + + phase_type = api.PhaseType[phase_type_string] + + try: + metrics = competition.metrics.list() + results = custom.score( + custom.ScoringModule.load(loader), + phase_type, + metrics, + utils.read(prediction_file_path), + data_directory_path, + ) + + metric_by_id = { + metric.id: metric + for metric in metrics + } + + print(f"\n\nPrediction is scorable!") + + rows = [ + ( + metric_by_id[metric_id].target.name, + metric_by_id[metric_id].name, + str(scored_metric.value) + ) + for metric_id, scored_metric in results.items() + ] + + rows.insert(0, ("Target", "Metric", "Score")) + + max_length_per_columns = [ + max((len(row[index]) for row in rows)) + for index in range(3) + ] + + print(f"\nResults:") + for row in rows: + print(" ", end="") + + for column_index, value in enumerate(row): + width = max_length_per_columns[column_index] + 3 + print(value.ljust(width), end="") + + print() + except custom.ParticipantVisibleError as error: + print(f"\n\nPrediction is not scorable: {error}") + except api.ApiException as error: + utils.exit_via(error) + except BaseException as error: + print(f"\n\nPrediction score function failed: {error}") + + traceback.print_exc() + + if __name__ == '__main__': cli() diff --git a/crunch/command/download.py b/crunch/command/download.py index fefc194..0464f78 100644 --- a/crunch/command/download.py +++ b/crunch/command/download.py @@ -1,32 +1,7 @@ import os import typing -import dataclasses -import click - -from .. import constants, utils, api, container - - -# TODO Remove me -LEGACY_NAME_MAPPING = { - "x_train": "X_train", - "y_train": "y_train", - "x_test": "X_test", - "y_test": "y_test", - "example_prediction": "example_prediction", -} - -@dataclasses.dataclass -class DataFile: - - path: str - url: str - size: int - signed: bool - - @property - def has_size(self): - return self.size != -1 +from .. import api, constants, container, downloader def _get_data_urls( @@ -35,10 +10,10 @@ def _get_data_urls( ) -> typing.Tuple[ int, int, - typing.List[int], + typing.List[api.SplitKeyPythonType], container.Features, api.ColumnNames, - typing.Dict[str, DataFile] + typing.Dict[str, downloader.PreparedDataFile] ]: data_release = round.phases.get_submission().get_data_release() @@ -58,63 +33,18 @@ def _get_data_urls( ) ] - def get_file(data_file: api.DataFile, key: str) -> DataFile: - url = data_file.url - path = os.path.join( - data_directory_path, - data_file.name or (f"{LEGACY_NAME_MAPPING[key]}.{utils.get_extension(url)}") - ) - - return DataFile( - path, - url, - data_file.size, - data_file.signed - ) - return ( embargo, number_of_features, split_keys, features, column_names, - { - key: get_file(value, key) - for key, value in data_files.items() - } + downloader.prepare_all(data_directory_path, data_files), ) -def _download( - data_file: DataFile, - force: bool -): - if data_file is None: - return - - file_length_str = f" ({data_file.size} bytes)" if data_file.has_size else "" - print(f"download {data_file.path} from {utils.cut_url(data_file.url)}" + file_length_str) - - if not data_file.has_size: - print(f"skip: not given by server") - return - - exists = os.path.exists(data_file.path) - if not force and exists: - stat = os.stat(data_file.path) - if stat.st_size == data_file.size: - print(f"already exists: file length match") - return - - if not data_file.signed: - print(f"signature missing: cannot download file without being authenticated") - raise click.Abort() - - utils.download(data_file.url, data_file.path, log=False) - - def download( - round_number="@current", + round_number: api.RoundIdentifierType = "@current", force=False, ): _, project = api.Client.from_project() @@ -131,14 +61,16 @@ def download( split_keys, features, column_names, - data_files, + prepared_data_files, ) = _get_data_urls( round, data_directory_path, ) - - for data_file in data_files.values(): - _download(data_file, force) + + file_paths = downloader.save_all( + prepared_data_files, + force, + ) return ( embargo, @@ -147,11 +79,7 @@ def download( features, column_names, data_directory_path, - { - key: value - for key, value in data_files.items() - if value.has_size - } + file_paths, ) diff --git a/crunch/constants.py b/crunch/constants.py index 1b4210e..d0d2dbe 100644 --- a/crunch/constants.py +++ b/crunch/constants.py @@ -29,3 +29,7 @@ WEB_BASE_URL_PRODUCTION = "https://hub.crunchdao.com/" WEB_BASE_URL_STAGING = "https://hub.crunchdao.io/" WEB_BASE_URL_DEVELOPMENT = "http://localhost:3000/" + +# TODO Change me when renamed +COMPETITIONS_REPOSITORY = "crunchdao/quickstarters" +COMPETITIONS_BRANCH = "feat/broad-1" diff --git a/crunch/custom/__init__.py b/crunch/custom/__init__.py new file mode 100644 index 0000000..4d57c56 --- /dev/null +++ b/crunch/custom/__init__.py @@ -0,0 +1,13 @@ +from .code_loader import ( + CodeLoader, + CodeLoadError, + GithubCodeLoader, + LocalCodeLoader +) + +from .scoring import ( + ScoringModule, + ParticipantVisibleError, + check, + score, +) diff --git a/crunch/custom/code_loader.py b/crunch/custom/code_loader.py new file mode 100644 index 0000000..3df8854 --- /dev/null +++ b/crunch/custom/code_loader.py @@ -0,0 +1,78 @@ +import abc +import os +import types + +import requests + +from .. import constants + + +class CodeLoadError(ImportError): + pass + + +class CodeLoader(abc.ABC): + + def load(self): + name = "scoring.py" + path = self.path + + try: + module = types.ModuleType(name) + module.__loader__ = self + module.__file__ = path + module.__path__ = [os.path.dirname(path)] + module.__package__ = name.rpartition('.')[0] + + code = compile(self.source, path, 'exec') + exec(code, module.__dict__) + except BaseException as exception: + raise CodeLoadError(f"could not load {path}") from exception + + return module + + @property + @abc.abstractmethod + def path(self) -> str: + pass + + @property + @abc.abstractmethod + def source(self) -> str: + pass + + +class GithubCodeLoader(CodeLoader): + + def __init__( + self, + competition_name: str, + repository=constants.COMPETITIONS_REPOSITORY, + branch=constants.COMPETITIONS_BRANCH, + ): + self._path = f"https://raw.githubusercontent.com/{repository}/refs/heads/{branch}/competitions/{competition_name}/scoring/scoring.py" + + @property + def path(self): + return self._path + + @property + def source(self): + response = requests.get(self._path) + response.raise_for_status() + return response.text + + +class LocalCodeLoader(CodeLoader): + + def __init__(self, path: str): + self._path = path + + @property + def path(self): + return self._path + + @property + def source(self): + with open(self._path, "r") as fd: + return fd.read() diff --git a/crunch/custom/scoring.py b/crunch/custom/scoring.py new file mode 100644 index 0000000..d1f2df4 --- /dev/null +++ b/crunch/custom/scoring.py @@ -0,0 +1,127 @@ +import itertools +import typing + +import pandas + +from .. import api, scoring, utils +from .code_loader import CodeLoader + + +class ScoringModule: + + check: typing.Callable + score: typing.Callable + + @staticmethod + def load(loader: CodeLoader): + module = loader.load() + + assert hasattr(module, "check"), "`check` function is missing" + assert hasattr(module, "score"), "`score` function is missing" + + return typing.cast(ScoringModule, module) + + +class ParticipantVisibleError(Exception): + pass + + +def _call( + function: typing.Callable, + phase_type: api.PhaseType, + metrics: typing.List[api.Metric], + prediction: pandas.DataFrame, + data_directory_path: str, + print=print, +): + target_and_metrics = [ + (target, list(metrics)) + for target, metrics in itertools.groupby( + sorted( + metrics, + key=lambda x: x.target.id + ), + lambda x: x.target + ) + ] + + target_names = list({ + target.name + for target, _ in target_and_metrics + }) + + try: + print(f"\n\ncalling {function}\n") + + return utils.smart_call( + function, + { + "phase_type": phase_type, + "prediction": prediction, + "data_directory_path": data_directory_path, + "target_names": target_names, + "target_and_metrics": target_and_metrics, + } + ) + except Exception as exception: + if exception.__class__.__name__ == 'ParticipantVisibleError': + raise ParticipantVisibleError(str(exception)) from exception + + raise + + +def check( + scoring: ScoringModule, + phase_type: api.PhaseType, + metrics: typing.List[api.Metric], + prediction: pandas.DataFrame, + data_directory_path: str, + logger=print, +): + _call( + scoring.check, + phase_type, + metrics, + prediction, + data_directory_path, + logger, + ) + + +def score( + scoring_module: ScoringModule, + phase_type: api.PhaseType, + metrics: typing.List[api.Metric], + prediction: pandas.DataFrame, + data_directory_path: str, + logger=print, +): + metric_ids = { + metric.id + for metric in metrics + } + + results = _call( + scoring_module.score, + phase_type, + metrics, + prediction, + data_directory_path, + logger, + ) + + if not isinstance(results, dict): + raise ValueError(f"return results must be a dict, got: {results.__class__}") + + for metric_id, scored_metric in results.items(): + if metric_id not in metric_ids: + raise ValueError(f"metric id {metric_id} does not exists") + + if not isinstance(scored_metric, scoring.ScoredMetric): + raise ValueError(f"results[{metric_id}] must be a ScoredMetric, got: {scored_metric.__class__}") + + value = scored_metric.value + if not isinstance(value, float): + raise ValueError(f"results[{metric_id}].value must be a float, got: {value.__class__}") + + return results diff --git a/crunch/downloader.py b/crunch/downloader.py new file mode 100644 index 0000000..d8b5495 --- /dev/null +++ b/crunch/downloader.py @@ -0,0 +1,102 @@ +import dataclasses +import os +import typing + +import click + +from . import api, constants, container, utils + +# TODO Remove me +LEGACY_NAME_MAPPING = { + "x_train": "X_train", + "y_train": "y_train", + "x_test": "X_test", + "y_test": "y_test", + "example_prediction": "example_prediction", +} + + +@dataclasses.dataclass +class PreparedDataFile: + + path: str + url: str + size: int + signed: bool + + @property + def has_size(self): + return self.size != -1 + + +def prepare_all( + data_directory_path: str, + data_files: api.DataFilesUnion, +): + return { + key: prepare_one(data_directory_path, value, key) + for key, value in data_files.items() + } + + +def prepare_one( + data_directory_path: str, + data_file: api.DataFile, + key: str +): + url = data_file.url + path = os.path.join( + data_directory_path, + data_file.name or (f"{LEGACY_NAME_MAPPING[key]}.{utils.get_extension(url)}") + ) + + return PreparedDataFile( + path, + url, + data_file.size, + data_file.signed + ) + + +def save_one( + data_file: PreparedDataFile, + force: bool, + print=print, +): + if data_file is None: + return + + file_length_str = f" ({data_file.size} bytes)" if data_file.has_size else "" + print(f"download {data_file.path} from {utils.cut_url(data_file.url)}" + file_length_str) + + if not data_file.has_size: + print(f"skip: not given by server") + return + + exists = os.path.exists(data_file.path) + if not force and exists: + stat = os.stat(data_file.path) + if stat.st_size == data_file.size: + print(f"already exists: file length match") + return + + if not data_file.signed: + print(f"signature missing: cannot download file without being authenticated") + raise click.Abort() + + utils.download(data_file.url, data_file.path, log=False) + + +def save_all( + data_files: typing.Dict[str, PreparedDataFile], + force: bool, + print=print, +): + for data_file in data_files.values(): + save_one(data_file, force, print) + + return { + key: value.path + for key, value in data_files.items() + if value.has_size + } diff --git a/crunch/scoring/score.py b/crunch/scoring/score.py index b5a0ae9..c2dd0f3 100644 --- a/crunch/scoring/score.py +++ b/crunch/scoring/score.py @@ -13,7 +13,7 @@ @dataclasses.dataclass class ScoredMetric: value: typing.Optional[float] - details: typing.List["ScoredMetricDetail"] + details: typing.List["ScoredMetricDetail"] = dataclasses.field(default_factory=list) @dataclasses.dataclass