Skip to content

Commit

Permalink
feat(setup): add --size option
Browse files Browse the repository at this point in the history
  • Loading branch information
Caceresenzo committed Oct 31, 2024
1 parent 2e3a81e commit f56e758
Show file tree
Hide file tree
Showing 9 changed files with 83 additions and 18 deletions.
7 changes: 5 additions & 2 deletions crunch/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,12 @@ def __init__(
self,
api_base_url: str,
web_base_url: str,
auth: Auth
auth: Auth,
project_info: typing.Optional[utils.ProjectInfo] = None
):
self.api = EndpointClient(api_base_url, auth)
self.web_base_url = web_base_url
self.project_info = project_info

@property
def competitions(self):
Expand Down Expand Up @@ -258,7 +260,8 @@ def from_project() -> typing.Tuple["Client", Project]:
client = Client(
store.api_base_url,
store.web_base_url,
PushTokenAuth(push_token)
PushTokenAuth(push_token),
project_info,
)

competition = client.competitions.get(project_info.competition_name)
Expand Down
1 change: 1 addition & 0 deletions crunch/api/domain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
OriginalFiles,
DataFilesUnion,
KnownData,
SizeVariant,
)
from .enum_ import (
Language,
Expand Down
9 changes: 9 additions & 0 deletions crunch/api/domain/data_release.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,15 @@ def from_dict_array(
]


class SizeVariant(enum.Enum):
    """Identify which size variant of the data is being requested.

    Every member's value is the same string as its name, so a member can
    be looked up either by name (``SizeVariant["LARGE"]``) or by value
    (``SizeVariant("LARGE")``).
    """

    DEFAULT = "DEFAULT"
    LARGE = "LARGE"

    def __repr__(self) -> str:
        # Render as the bare member name rather than the default
        # "<SizeVariant.LARGE: 'LARGE'>" form produced by enum.Enum.
        return self.name


class DataRelease(Model):

resource_identifier_attribute = "number"
Expand Down
16 changes: 13 additions & 3 deletions crunch/api/domain/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,16 @@ def crunches(self):
client=self._client
)

def get_data_release(self):
def get_data_release(
self,
size_variant: typing.Optional["data_release.SizeVariant"]
):
from .data_release import DataReleaseCollection, DataRelease

attrs = self._client.api.get_submission_phase_data_release(
self.round.competition.resource_identifier,
self.round.resource_identifier,
size_variant.name if size_variant else None,
)

competition = self.round.competition
Expand Down Expand Up @@ -173,11 +177,17 @@ def get_phase(
def get_submission_phase_data_release(
self,
competition_identifier,
round_identifier
round_identifier,
size_variant,
):
params = {}
if size_variant:
params["sizeVariant"] = size_variant

return self._result(
self.get(
f"/v2/competitions/{competition_identifier}/rounds/{round_identifier}/phases/submission/data-release"
f"/v2/competitions/{competition_identifier}/rounds/{round_identifier}/phases/submission/data-release",
params=params
),
json=True
)
34 changes: 26 additions & 8 deletions crunch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@
ENVIRONMENT_DEVELOPMENT: (constants.API_BASE_URL_DEVELOPMENT, constants.WEB_BASE_URL_DEVELOPMENT),
}

# Lower-cased size variant names offered as choices by the CLI size options.
# The first entry is what those options use as their default value.
DATA_SIZE_VARIANTS = [
    variant.name.lower()
    for variant in (api.SizeVariant.DEFAULT, api.SizeVariant.LARGE)
]


def _format_directory(directory: str, competition_name: str, project_name: str):
directory = directory \
Expand Down Expand Up @@ -107,12 +112,12 @@ def init(

try:
command.init(
clone_token=clone_token,
competition_name=competition_name,
project_name=project_name,
directory=directory,
model_directory=model_directory_path,
force=force,
clone_token,
competition_name,
project_name,
directory,
model_directory_path,
force,
)

if not no_data:
Expand Down Expand Up @@ -140,6 +145,7 @@ def init(
@click.option("--quickstarter-name", type=str, help="Pre-select a quickstarter.")
@click.option("--show-notebook-quickstarters", is_flag=True, help="Show quickstarters notebook in selection.")
@click.option("--notebook", is_flag=True, help="Setup everything for a notebook environment.")
@click.option("--size", "data_size_variant_raw", type=click.Choice(DATA_SIZE_VARIANTS), default=DATA_SIZE_VARIANTS[0], help="Use another data variant.")
@click.argument("competition-name", required=True)
@click.argument("project-name", required=True)
@click.argument("directory", default=DIRECTORY_DEFAULT_FORMAT)
Expand All @@ -157,6 +163,7 @@ def setup(
quickstarter_name: str,
show_notebook_quickstarters: bool,
notebook: bool,
data_size_variant_raw: str,
):
if notebook:
if force:
Expand All @@ -182,6 +189,8 @@ def setup(
else:
directory = _format_directory(directory, competition_name, project_name)

data_size_variant = api.SizeVariant[data_size_variant_raw.upper()]

try:
command.setup(
clone_token,
Expand All @@ -195,6 +204,7 @@ def setup(
not no_quickstarter,
quickstarter_name,
show_notebook_quickstarters,
data_size_variant,
)

if not no_data:
Expand Down Expand Up @@ -336,14 +346,22 @@ def test(

@cli.command(help="Download the data locally.")
@click.option("--round-number", default="@current")
@click.option("--force", is_flag=True, help="Force the download of the data.")
@click.option("--size-variant", "size_variant_raw", type=click.Choice(DATA_SIZE_VARIANTS), default=DATA_SIZE_VARIANTS[0], help="Use alternative version of the data.")
def download(
round_number: str
round_number: str,
force: bool,
size_variant_raw: str,
):
utils.change_root()

size_variant = api.SizeVariant[size_variant_raw.upper()]

try:
command.download(
round_number=round_number
round_number,
force,
size_variant,
)
except api.CrunchNotFoundException:
command.download_no_data_available()
Expand Down
17 changes: 14 additions & 3 deletions crunch/command/download.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import typing

from .. import api, constants, container, downloader
from .. import api, constants, container, downloader, utils


def _get_data_urls(
round: api.Round,
data_directory_path: str,
size_variant: api.SizeVariant,
) -> typing.Tuple[
int,
int,
Expand All @@ -15,7 +16,7 @@ def _get_data_urls(
api.ColumnNames,
typing.Dict[str, downloader.PreparedDataFile]
]:
data_release = round.phases.get_submission().get_data_release()
data_release = round.phases.get_submission().get_data_release(size_variant=size_variant)

embargo = data_release.embargo
number_of_features = data_release.number_of_features
Expand Down Expand Up @@ -46,8 +47,17 @@ def _get_data_urls(
def download(
round_number: api.RoundIdentifierType = "@current",
force=False,
size_variant: typing.Optional[api.SizeVariant] = None,
):
_, project = api.Client.from_project()
client, project = api.Client.from_project()

project_info = client.project_info
if size_variant is None:
size_variant = project_info.size_variant
elif project_info.size_variant != size_variant:
project_info.size_variant = size_variant
utils.write_project_info(project_info)
print(f"repository set default size variant: {size_variant.name.lower()}")

competition = project.competition
round = competition.rounds.get(round_number)
Expand All @@ -65,6 +75,7 @@ def download(
) = _get_data_urls(
round,
data_directory_path,
size_variant,
)

file_paths = downloader.save_all(
Expand Down
4 changes: 3 additions & 1 deletion crunch/command/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def init(
directory: str,
model_directory: str,
force: bool,
data_size_variant=api.SizeVariant.DEFAULT
):
should_delete = _check_if_already_exists(directory, force)

Expand All @@ -47,7 +48,8 @@ def init(
project_info = utils.ProjectInfo(
competition_name,
project_name,
user_id
user_id,
data_size_variant,
)

utils.write_project_info(project_info, directory)
Expand Down
5 changes: 4 additions & 1 deletion crunch/command/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@ def setup(
show_quickstarters: bool,
quickstarter_name: typing.Optional[str],
show_notebook_quickstarters: bool,
data_size_variant: api.SizeVariant,
):
command.init(
clone_token,
competition_name,
project_name,
directory,
model_directory,
force
force,
data_size_variant,
)

_, project = api.Client.from_project()

try:
Expand Down
8 changes: 8 additions & 0 deletions crunch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class ProjectInfo:
competition_name: str
project_name: str
user_id: str
size_variant: api.SizeVariant


def write_project_info(info: ProjectInfo, directory=".") -> ProjectInfo:
Expand All @@ -81,6 +82,7 @@ def write_project_info(info: ProjectInfo, directory=".") -> ProjectInfo:
"competitionName": info.competition_name,
"projectName": info.project_name,
"userId": info.user_id,
"sizeVariant": info.size_variant.name,
}, fd)


Expand All @@ -99,11 +101,17 @@ def read_project_info(raise_if_missing=True) -> ProjectInfo:

root = json.loads(content)

try:
size_variant = api.SizeVariant[root["sizeVariant"]]
except:
size_variant = api.SizeVariant.DEFAULT

# TODO: need of a better system for handling file versions
return ProjectInfo(
root["competitionName"],
root.get("projectName") or "default", # backward compatibility
root["userId"],
size_variant,
)


Expand Down

0 comments on commit f56e758

Please sign in to comment.