Skip to content

Commit

Permalink
feat(orthogonalization): add
Browse files Browse the repository at this point in the history
  • Loading branch information
Caceresenzo committed Feb 28, 2024
1 parent 2de7287 commit a9a99ad
Show file tree
Hide file tree
Showing 19 changed files with 328 additions and 49 deletions.
1 change: 1 addition & 0 deletions crunch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@

from .inline import load as load_notebook
from .runner import is_inside as is_inside_runner
from .orthogonalization import orthogonalize
16 changes: 0 additions & 16 deletions crunch/api.py

This file was deleted.

1 change: 1 addition & 0 deletions crunch/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .domain import *
Empty file added crunch/api/client/_api_key.py
Empty file.
78 changes: 78 additions & 0 deletions crunch/api/client/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import requests
import typing
import dataclasses
import urllib.parse
import inflection

from ... import store
from .. import domain


@dataclasses.dataclass(frozen=True)
class ClientConfiguration:

web_base_url: str = dataclasses.field(default_factory=lambda: store.web_base_url)
api_base_url: str = dataclasses.field(default_factory=lambda: store.api_base_url)
debug: bool = dataclasses.field(default_factory=lambda: store.debug)


class Client:

def __init__(
self,
configuration: typing.Optional[ClientConfiguration] = None
):
self._configuration = configuration or ClientConfiguration()

self._session = requests.Session()

def _request(self, method, endpoint, *args, **kwargs):
response = super().request(
method,
urllib.parse.urljoin(self._configuration.api_base_url, endpoint),
*args,
**kwargs
)

status_code = response.status_code
if status_code // 100 != 2:
raise self._convert_error(response)

return response

def _convert_error(
self,
response: requests.Response
):
try:
error = response.json()
except:
return ValueError(f"unexpected error: {response.text}")
else:
code = error.pop("code", "")
message = error.pop("message", "")

error_class = self._find_error_class(code, message)
error = error_class(message)

for key, value in error.items():
key = inflection.underscore(key)
setattr(error, key, value)

return error

def _find_error_class(
self,
code: str
):
if code:
base_class_name = inflection.camelize(code)

for suffix in ["Exception", "Error"]:
class_name = base_class_name + suffix

clazz = getattr(domain, class_name, None)
if clazz is not None:
return clazz

return domain.ApiException
20 changes: 20 additions & 0 deletions crunch/api/client/_push_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import typing

from ._base import Client, ClientConfiguration


class PushTokenClient(Client):

def __init__(
self,
push_token: str,
configuration: typing.Optional[ClientConfiguration]=None,
):
super().__init__(configuration)

self._session.params.update({
"pushToken": push_token
})

def orthogonalization():
pass
7 changes: 7 additions & 0 deletions crunch/api/domain/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from ._common import ApiException
from ._crunch import *
from ._common import *
from ._prediction import *
from ._project import *
from ._score import *
from ._submission import *
22 changes: 22 additions & 0 deletions crunch/api/domain/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import datetime

import dataclasses_json
import marshmallow

date_config = dataclasses_json.config(
encoder=datetime.date.isoformat,
decoder=datetime.date.fromisoformat,
mm_field=marshmallow.fields.DateTime(format='iso')
)

datetime_config = dataclasses_json.config(
encoder=datetime.datetime.isoformat,
decoder=datetime.datetime.fromisoformat,
mm_field=marshmallow.fields.Date(format='iso')
)


class ApiException(Exception):

def __init__(self, message: str):
super().__init__(message)
5 changes: 5 additions & 0 deletions crunch/api/domain/_crunch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ._common import ApiException


class CurrentCrunchNotFoundException(ApiException):
pass
29 changes: 29 additions & 0 deletions crunch/api/domain/_prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import dataclasses
import enum
import typing

import dataclasses_json


class PredictionTag(enum.Enum):

USER_RUN_OUTPUT = "USER_RUN_OUTPUT"
MANAGED_RUN_OUTPUT = "MANAGED_RUN_OUTPUT"
USER_ORTHOGONALIZE = "USER_ORTHOGONALIZE"


@dataclasses_json.dataclass_json(
letter_case=dataclasses_json.LetterCase.CAMEL,
undefined=dataclasses_json.Undefined.EXCLUDE
)
@dataclasses.dataclass(frozen=True)
class Prediction:

id: int
name: typing.Optional[str]
success: typing.Optional[bool]
error: typing.Optional[str]
mean: typing.Optional[float]
tag: PredictionTag
orthogonalized: bool
created_at: bool
5 changes: 5 additions & 0 deletions crunch/api/domain/_project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ._common import ApiException


class InvalidProjectTokenException(ApiException):
pass
49 changes: 49 additions & 0 deletions crunch/api/domain/_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import dataclasses
import datetime
import enum
import typing

import dataclasses_json

from ._common import datetime_config


class MetricFunction(enum.Enum):

SPEARMAN = "SPEARMAN"
F1 = "F1"
RECALL = "RECALL"
PRECISION = "PRECISION"
DOT_PRODUCT = "DOT_PRODUCT"


@dataclasses_json.dataclass_json(
letter_case=dataclasses_json.LetterCase.CAMEL,
undefined=dataclasses_json.Undefined.EXCLUDE
)
@dataclasses.dataclass(frozen=True)
class Metric:

id: int
name: str
display_name: str
weight: int
score: bool
multiplier: float
function: MetricFunction
created_at: datetime.datetime = dataclasses.field(metadata=datetime_config)


@dataclasses_json.dataclass_json(
letter_case=dataclasses_json.LetterCase.CAMEL,
undefined=dataclasses_json.Undefined.EXCLUDE
)
@dataclasses.dataclass(frozen=True)
class Score:

id: int
success: bool
metric: Metric
value: typing.Optional[float]
details: typing.Optional[typing.Dict[str, typing.Optional[float]]]
created_at: datetime.datetime = dataclasses.field(metadata=datetime_config)
5 changes: 5 additions & 0 deletions crunch/api/domain/_submission.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ._common import ApiException


class NeverSubmittedException(ApiException):
pass
12 changes: 1 addition & 11 deletions crunch/inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,7 @@ def __init__(self, module: typing.Any, model_directory: str, has_gpu=False):
self.model_directory = model_directory
self.has_gpu = has_gpu

self.session = utils.CustomSession(
os.environ.get(
constants.WEB_BASE_URL_ENV_VAR,
constants.WEB_BASE_URL_DEFAULT
),
os.environ.get(
constants.API_BASE_URL_ENV_VAR,
constants.API_BASE_URL_DEFAULT
),
bool(os.environ.get(constants.DEBUG_ENV_VAR, "False")),
)
self.session = utils.CustomSession.from_env()

print(f"loaded inline runner with module: {module}")

Expand Down
37 changes: 16 additions & 21 deletions crunch/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,30 @@
import logging

import click
import dotenv

from . import command, constants, utils, api, library, tester, __version__
from . import command, constants, utils, api, library, tester, store, __version__

session = None
debug = False

dotenv.load_dotenv(".env", verbose=True)
store.load_from_env()


@click.group()
@click.version_option(
__version__.__version__,
package_name="__version__.__title__"
)
@click.option("--debug", "enable_debug", envvar=constants.DEBUG_ENV_VAR, is_flag=True, help="Enable debug output.")
@click.option("--debug", envvar=constants.DEBUG_ENV_VAR, is_flag=True, help="Enable debug output.")
@click.option("--api-base-url", envvar=constants.API_BASE_URL_ENV_VAR, default=constants.API_BASE_URL_DEFAULT, help="Set the API base url.")
@click.option("--web-base-url", envvar=constants.WEB_BASE_URL_ENV_VAR, default=constants.WEB_BASE_URL_DEFAULT, help="Set the Web base url.")
def cli(
enable_debug: bool,
debug: bool,
api_base_url: str,
web_base_url: str,
):
global debug
debug = enable_debug
store.debug = debug
store.api_base_url = api_base_url
store.web_base_url = web_base_url

global session
session = utils.CustomSession(
store.session = utils.CustomSession(
web_base_url,
api_base_url,
debug,
Expand Down Expand Up @@ -61,7 +57,7 @@ def setup(
directory = os.path.normpath(directory)

command.setup(
session,
store.session,
clone_token=clone_token,
submission_number=submission_number,
competition_name=competition_name,
Expand All @@ -75,7 +71,7 @@ def setup(
os.chdir(directory)

try:
command.download(session, force=True)
command.download(store.session, force=True)
except api.CurrentCrunchNotFoundException:
command.download_no_data_available()

Expand All @@ -84,8 +80,7 @@ def setup(
print(f"Next recommended actions:")

if directory != '.':
print(
f" - To get inside your workspace directory, run: cd {directory}")
print(f" - To get inside your workspace directory, run: cd {directory}")

print(f" - To see all of the available commands of the CrunchDAO CLI, run: crunch --help")

Expand Down Expand Up @@ -124,7 +119,7 @@ def push(

try:
command.push(
session,
store.session,
message=message,
main_file_path=main_file_path,
model_directory_path=model_directory_path,
Expand Down Expand Up @@ -157,11 +152,11 @@ def test(
tester.install_logger()

if not skip_library_check and os.path.exists(constants.REQUIREMENTS_TXT):
library.scan(session, requirements_file=constants.REQUIREMENTS_TXT)
library.scan(store.session, requirements_file=constants.REQUIREMENTS_TXT)
logging.warn('')

command.test(
session,
store.session,
main_file_path=main_file_path,
model_directory_path=model_directory_path,
force_first_train=not no_force_first_train,
Expand All @@ -180,7 +175,7 @@ def download(

try:
command.download(
session,
store.session,
round_number=round_number
)
except api.CurrentCrunchNotFoundException:
Expand Down Expand Up @@ -211,7 +206,7 @@ def update_token(
utils.change_root()

command.update_token(
session,
store.session,
clone_token=clone_token
)

Expand Down
Loading

0 comments on commit a9a99ad

Please sign in to comment.