Skip to content

Commit

Permalink
feat(custom): support checking and scoring (#143)
Browse files Browse the repository at this point in the history
Infrastructure and tools required for the new unstructured competition format.
  • Loading branch information
Caceresenzo authored Oct 27, 2024
1 parent aa36402 commit 86cef96
Show file tree
Hide file tree
Showing 13 changed files with 558 additions and 91 deletions.
7 changes: 7 additions & 0 deletions crunch/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,10 @@
LatestRoundNotFoundException,
NextRoundNotFoundException
)

from .identifiers import (
CompetitionIdentifierType,
RoundIdentifierType,
PhaseIdentifierType,
CrunchIdentifierType,
)
4 changes: 3 additions & 1 deletion crunch/api/domain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
ColumnNames,
TargetColumnNames,
DataReleaseTargetResolution,
SplitKeyPythonType,
DataReleaseSplit,
DataReleaseSplitGroup,
DataReleaseSplitReduced,
DataReleaseFeature,
DataFile,
DataFiles,
OriginalFiles,
DataFile,
DataFilesUnion,
KnownData,
)
from .enum_ import (
Expand Down
14 changes: 10 additions & 4 deletions crunch/api/domain/data_release.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def items(self):
return vars(self).items()


# Alias for the three shapes a set of data files can take: a `DataFiles`,
# an `OriginalFiles`, or a plain mapping from a string key to a `DataFile`.
DataFilesUnion = typing.Union[DataFiles, OriginalFiles, typing.Dict[str, DataFile]]


class DataReleaseSplitGroup(enum.Enum):

TRAIN = "TRAIN"
Expand All @@ -95,14 +98,17 @@ def __repr__(self):
return self.name


# Alias for the Python-side type of a split key: either a string or an integer.
SplitKeyPythonType = typing.Union[str, int]


@dataclasses_json.dataclass_json(
letter_case=dataclasses_json.LetterCase.CAMEL,
undefined=dataclasses_json.Undefined.EXCLUDE,
)
@dataclasses.dataclass(frozen=True)
class DataReleaseSplit:

key: typing.Union[str, int]
key: SplitKeyPythonType
group: DataReleaseSplitGroup
reduced: typing.Optional[DataReleaseSplitReduced] = None

Expand Down Expand Up @@ -180,7 +186,7 @@ def target_resolution(self):
return DataReleaseTargetResolution[self._attrs["target_resolution"]]

@property
def data_files(self) -> typing.Union[DataFiles, OriginalFiles, typing.Dict[str, DataFile]]:
def data_files(self) -> DataFilesUnion:
files = self._attrs.get("dataFiles")
if not files:
self.reload()
Expand All @@ -198,13 +204,13 @@ def data_files(self) -> typing.Union[DataFiles, OriginalFiles, typing.Dict[str,
})

@property
def splits(self) -> typing.Tuple[DataReleaseSplit]:
def splits(self) -> typing.List[DataReleaseSplit]:
splits = self._attrs.get("splits")
if splits is None:
self.reload(include_splits=True)
splits = self._attrs["splits"]

return tuple(DataReleaseSplit.from_dict_array(splits))
return list(DataReleaseSplit.from_dict_array(splits))

@property
def default_feature_group(self) -> str:
Expand Down
13 changes: 13 additions & 0 deletions crunch/api/domain/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,29 @@ class PhaseType(enum.Enum):
def __repr__(self):
return self.name

def slug(self):
    """Return the lower-case, dash-separated identifier of this phase type."""
    known_slugs = {
        PhaseType.SUBMISSION: "submission",
        PhaseType.OUT_OF_SAMPLE: "out-of-sample",
    }

    known = known_slugs.get(self)
    if known is not None:
        return known

    # Members without a dedicated entry fall back to the derived slug.
    return self._default_slug()

def pretty(self):
    """Return the human-readable label of this phase type."""
    labels = {
        PhaseType.SUBMISSION: "Submission",
        PhaseType.OUT_OF_SAMPLE: "Out-of-Sample",
    }

    label = labels.get(self)
    if label is None:
        # No dedicated label for this member: reuse the derived slug.
        label = self._default_slug()

    return label

def _default_slug(self):
    """Derive a slug from the member name: lower-cased, with "_" turned into "-".

    Returns:
        For example ``"out-of-sample"`` for a member named ``OUT_OF_SAMPLE``.
    """
    # Use `self.name` rather than `str(self)`: on a plain `enum.Enum`,
    # `str()` yields "ClassName.MEMBER" unless `__str__` is overridden,
    # which would leak the class name into the slug.
    return self.name.lower().replace("_", "-")



class Phase(Model):

resource_identifier_attribute = "type"
Expand Down
2 changes: 1 addition & 1 deletion crunch/api/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __hash__(self):
return hash(f"{self.__class__.__name__}:{self.id}")

@property
def id(self):
def id(self) -> typing.Union[int, str]:
return self._attrs.get(self.id_attribute)

@property
Expand Down
187 changes: 187 additions & 0 deletions crunch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import sys
import traceback
import typing

import click
Expand Down Expand Up @@ -679,5 +680,191 @@ def cloud_executor(
executor.start()


@cli.group(name="organizer")
@click.argument("competition_name")
@click.pass_context
def organize_group(
    context: click.Context,
    competition_name: str,
):
    # NOTE: documented with comments on purpose; click would surface a
    # docstring as the group's help text.
    #
    # Look up the competition named on the command line and store it in
    # `context.obj` so that sub-commands can read it.
    api_client = api.Client.from_env()

    try:
        competition = api_client.competitions.get(competition_name)
    except api.errors.CompetitionNameNotFoundException:
        message = f"competition {competition_name} not found"
        print(message, file=sys.stderr)
        raise click.Abort()
    except api.ApiException as api_error:
        utils.exit_via(api_error)

    context.obj = competition


@organize_group.command()
@click.pass_context
def x(
    context: click.Context,
):
    # Prints the object stored on the click context by the parent group.
    # NOTE(review): looks like a debugging aid; confirm it is meant to ship.
    stored = context.obj
    print(stored)


# Pure container group named "test" under the organizer group: it does no work
# itself and only exists so that further sub-groups can be attached to it.
# (A comment is used instead of a docstring so the click help text is unchanged.)
@organize_group.group(name="test")
def organize_test_group():
pass


@organize_test_group.group(name="scoring")
@click.option("--script-file", "script_file_path", type=click.Path(dir_okay=False, readable=True), required=False)
@click.option("--github-repository", default=constants.COMPETITIONS_REPOSITORY, required=False)
@click.option("--github-branch", default=constants.COMPETITIONS_BRANCH, required=False)
@click.pass_context
def scoring_group(
    context: click.Context,
    script_file_path: str,
    github_repository: str,
    github_branch: str,
):
    # NOTE: documented with comments on purpose; click would surface a
    # docstring as the group's help text.
    #
    # Pick the code loader for the scoring script and replace `context.obj`
    # with a `(competition, loader)` pair for the sub-commands.
    from . import custom

    competition: api.Competition = context.obj

    if script_file_path is not None:
        # An explicit --script-file takes precedence over the GitHub options.
        loader = custom.LocalCodeLoader(
            script_file_path,
        )
    else:
        loader = custom.GithubCodeLoader(
            competition.name,
            repository=github_repository,
            branch=github_branch,
        )

    context.obj = (competition, loader)


# Member names (e.g. "SUBMISSION") of the two phase types listed below.
LOWER_PHASE_TYPES = [
    phase_type.name
    for phase_type in (
        api.PhaseType.SUBMISSION,
        api.PhaseType.OUT_OF_SAMPLE,
    )
]


@scoring_group.command(name="check")
@click.option("--data-directory", "data_directory_path", type=click.Path(file_okay=False, readable=True), required=True)
@click.option("--prediction-file", "prediction_file_path", type=click.Path(dir_okay=False, readable=True), required=True)
@click.option("--phase-type", "phase_type_string", type=click.Choice(LOWER_PHASE_TYPES), default=LOWER_PHASE_TYPES[0])
@click.pass_context
def scoring_check(
    context: click.Context,
    data_directory_path: str,
    prediction_file_path: str,
    phase_type_string: str,
):
    # NOTE: documented with comments on purpose; click would surface a
    # docstring as the command's help text.
    #
    # Runs the competition's check function on a prediction file and prints
    # whether the prediction is valid.
    #
    #   context:              click context; `obj` holds (competition, loader).
    #   data_directory_path:  directory handed to the check function.
    #   prediction_file_path: prediction file, loaded through `utils.read`.
    #   phase_type_string:    member name of the `api.PhaseType` to check against.
    from . import custom

    competition, loader = typing.cast(
        typing.Tuple[
            api.Competition,
            custom.CodeLoader,
        ],
        context.obj
    )

    phase_type = api.PhaseType[phase_type_string]

    try:
        custom.check(
            custom.ScoringModule.load(loader),
            phase_type,
            competition.metrics.list(),
            utils.read(prediction_file_path),
            data_directory_path
        )

        print("\n\nPrediction is valid!")
    except custom.ParticipantVisibleError as error:
        # Raised by the check code to signal an invalid prediction.
        print(f"\n\nPrediction is not valid: {error}")
    except api.ApiException as error:
        utils.exit_via(error)
    except KeyboardInterrupt:
        # Ctrl-C must abort the command instead of being reported as a
        # failure of the check function by the catch-all below.
        raise
    except BaseException as error:
        # Anything else is treated as a failure of the check code itself
        # and is reported together with its traceback.
        print(f"\n\nPrediction check function failed: {error}")

        traceback.print_exc()


@scoring_group.command(name="score")
@click.option("--data-directory", "data_directory_path", type=click.Path(file_okay=False, readable=True), required=True)
@click.option("--prediction-file", "prediction_file_path", type=click.Path(dir_okay=False, readable=True), required=True)
@click.option("--phase-type", "phase_type_string", type=click.Choice(LOWER_PHASE_TYPES), default=LOWER_PHASE_TYPES[0])
@click.pass_context
def scoring_score(
    context: click.Context,
    data_directory_path: str,
    prediction_file_path: str,
    phase_type_string: str,
):
    # NOTE: documented with comments on purpose; click would surface a
    # docstring as the command's help text.
    #
    # Runs the competition's score function on a prediction file and prints
    # one "Target / Metric / Score" row per scored metric.
    #
    #   context:              click context; `obj` holds (competition, loader).
    #   data_directory_path:  directory handed to the score function.
    #   prediction_file_path: prediction file, loaded through `utils.read`.
    #   phase_type_string:    member name of the `api.PhaseType` to score against.
    from . import custom

    competition, loader = typing.cast(
        typing.Tuple[
            api.Competition,
            custom.CodeLoader,
        ],
        context.obj
    )

    phase_type = api.PhaseType[phase_type_string]

    try:
        metrics = competition.metrics.list()
        results = custom.score(
            custom.ScoringModule.load(loader),
            phase_type,
            metrics,
            utils.read(prediction_file_path),
            data_directory_path,
        )

        # Results are keyed by metric id; index the metrics the same way.
        metric_by_id = {
            metric.id: metric
            for metric in metrics
        }

        print("\n\nPrediction is scorable!")

        rows = [("Target", "Metric", "Score")]
        rows.extend(
            (
                metric_by_id[metric_id].target.name,
                metric_by_id[metric_id].name,
                str(scored_metric.value),
            )
            for metric_id, scored_metric in results.items()
        )

        # Widest cell of each column, header included; derived from the
        # rows themselves instead of a hard-coded column count.
        column_widths = [
            max(map(len, column))
            for column in zip(*rows)
        ]

        print("\nResults:")
        for row in rows:
            # Each cell is padded to its column width plus a 3-space gutter.
            cells = (
                value.ljust(width + 3)
                for value, width in zip(row, column_widths)
            )
            print(" " + "".join(cells))
    except custom.ParticipantVisibleError as error:
        # Raised by the scoring code to signal an unscorable prediction.
        print(f"\n\nPrediction is not scorable: {error}")
    except api.ApiException as error:
        utils.exit_via(error)
    except KeyboardInterrupt:
        # Ctrl-C must abort the command instead of being reported as a
        # failure of the score function by the catch-all below.
        raise
    except BaseException as error:
        # Anything else is treated as a failure of the scoring code itself
        # and is reported together with its traceback.
        print(f"\n\nPrediction score function failed: {error}")

        traceback.print_exc()


# Script entry point: invoke the root `cli` command when this module is run directly.
if __name__ == '__main__':
cli()
Loading

0 comments on commit 86cef96

Please sign in to comment.