Skip to content

Commit

Permalink
feat: disabling S3 syncs on demand.
Browse files: browse the repository at this point in the history
Signed-off-by: Ubuntu <[email protected]>
  • Loading branch information
drugilsberg committed Jul 24, 2023
1 parent 4e28e17 commit bd955fd
Show file tree
Hide file tree
Showing 10 changed files with 8 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,6 @@ def safely_determine_task(self, x: str) -> str:
self.tokenizer.mask_token
in x.split(self.tokenizer.expression_separator)[-1]
):

return "generation"

return "regression"
Expand Down
10 changes: 7 additions & 3 deletions src/gt4sd/algorithms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,17 +759,21 @@ def ensure_artifacts_for_version(cls, algorithm_version: str) -> str:
             cls.get_application_prefix(),
             algorithm_version,
         )
+        local_path = ""
         try:
-            local_path = sync_algorithm_with_s3(prefix)
+            if not gt4sd_configuration_instance.gt4sd_skip_s3_sync_in_inference:
+                local_path = sync_algorithm_with_s3(prefix)
         except (KeyError, S3SyncError) as error:
             logger.info(
                 f"searching S3 raised {error.__class__.__name__}, using local cache only."
             )
             logger.debug(error)
-            local_path = get_cached_algorithm_path(prefix)
+        finally:
+            if not local_path:
+                local_path = get_cached_algorithm_path(prefix)
             if not os.path.isdir(local_path):
                 raise OSError(
-                    f"artifacts directory {local_path} does not exist locally, and syncing with s3 failed: {error}"
+                    f"artifacts directory {local_path} does not exist locally"
                 )
 
         return local_path
Expand Down
1 change: 0 additions & 1 deletion src/gt4sd/cli/argument_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def _add_dataclass_arguments(self, dtype: DataClassType) -> None:
else:
kwargs["required"] = True
elif field.type is bool or field.type == Optional[bool]:

if field.default is True:
parser.add_argument(
f"--no_{field.name}",
Expand Down
5 changes: 0 additions & 5 deletions src/gt4sd/cli/load_arguments_from_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,28 +76,23 @@ def extract_fields_from_class(

# assign default values
for field in fields(dataclass):

if not isinstance(field.default, _MISSING_TYPE):

if field.default is None:
field.default = "none"

arg_fields[field.name]["default"] = field.default

# convert type to str
for field_name in arg_fields:

field_type = find_type(arg_fields[field_name]["type"])

if field_type:

arg_fields[field_name]["type"] = field_type

elif (
hasattr(arg_fields[field_name]["type"], "__origin__")
and arg_fields[field_name]["type"].__origin__ is Union
):

types = [
find_type(type) for type in arg_fields[field_name]["type"].__args__
]
Expand Down
2 changes: 1 addition & 1 deletion src/gt4sd/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class GT4SDConfiguration(BaseSettings):
gt4sd_max_runtime: int = 86400
gt4sd_create_unverified_ssl_context: bool = False
gt4sd_disable_cudnn: bool = False
gt4sd_skip_s3_sync_in_inference: bool = False

gt4sd_s3_host: str = "s3.par01.cloud-object-storage.appdomain.cloud"
gt4sd_s3_access_key: str = "6e9891531d724da89997575a65f4592e"
Expand Down Expand Up @@ -200,7 +201,6 @@ def sync_algorithm_with_s3(
def get_cached_algorithm_path(
prefix: Optional[str] = None, module: str = "algorithms"
) -> str:

if module not in gt4sd_artifact_management_configuration.gt4sd_s3_modules:
raise ValueError(
f"Unknown cache module: {module}. Supported modules: "
Expand Down
1 change: 0 additions & 1 deletion src/gt4sd/training_pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def training_pipeline_name_to_metadata(name: str) -> Dict[str, Any]:

metadata: Dict[str, Any] = {"training_pipeline": name, "parameters": {}}
if name in TRAINING_PIPELINE_ARGUMENTS_MAPPING:

for training_argument_class in TRAINING_PIPELINE_ARGUMENTS_MAPPING[name]:
field_types = extract_fields_from_class(training_argument_class)
metadata["parameters"].update(field_types)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def train( # type: ignore
dataset_args: dataset arguments passed to the configuration.
"""
try:

params = {**training_args, **dataset_args, **model_args}
# Setup logging
logging.basicConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ def prepare_datasets_from_files(
for i, (data, path) in enumerate(
zip([train_data, test_data], [train_path, test_path])
):

if not path.endswith(".csv"):
raise TypeError(f"Please provide a csv file not {path}.")

Expand Down
18 changes: 0 additions & 18 deletions src/gt4sd/training_pipelines/tests/test_argument_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

@dataclass
class TestArguments:

int_arg: int = field(default=0)

float_arg: float = field(default=0.0)
Expand All @@ -50,7 +49,6 @@ class TestArguments:


def test_int_default():

parser = ArgumentParser((TestArguments)) # type: ignore

args = parser.parse_args_into_dataclasses([])
Expand All @@ -60,7 +58,6 @@ def test_int_default():


def test_float_default():

parser = ArgumentParser((TestArguments)) # type: ignore

args = parser.parse_args_into_dataclasses([])
Expand All @@ -70,7 +67,6 @@ def test_float_default():


def test_str_default():

parser = ArgumentParser((TestArguments)) # type: ignore

args = parser.parse_args_into_dataclasses([])
Expand All @@ -80,7 +76,6 @@ def test_str_default():


def test_bool_default():

parser = ArgumentParser((TestArguments)) # type: ignore

args = parser.parse_args_into_dataclasses([])
Expand All @@ -90,7 +85,6 @@ def test_bool_default():


def test_int_assigned():

parser = ArgumentParser((TestArguments)) # type: ignore

args = parser.parse_args_into_dataclasses(["--int_arg", "1"])
Expand All @@ -100,7 +94,6 @@ def test_int_assigned():


def test_float_assigned():

parser = ArgumentParser((TestArguments)) # type: ignore

args = parser.parse_args_into_dataclasses(["--float_arg", "1.0"])
Expand All @@ -110,7 +103,6 @@ def test_float_assigned():


def test_str_assigned():

parser = ArgumentParser((TestArguments)) # type: ignore

args = parser.parse_args_into_dataclasses(["--str_arg", "my_test"])
Expand All @@ -120,7 +112,6 @@ def test_str_assigned():


def test_bool_assigned():

parser = ArgumentParser((TestArguments)) # type: ignore

args = parser.parse_args_into_dataclasses(["--bool_arg", "False"])
Expand All @@ -130,7 +121,6 @@ def test_bool_assigned():


def test_bool_int_assigned():

parser = ArgumentParser((TestArguments)) # type: ignore

args = parser.parse_args_into_dataclasses(["--bool_arg", "0"])
Expand All @@ -140,7 +130,6 @@ def test_bool_int_assigned():


def test_int_none():

parser = ArgumentParser((TestArguments)) # type: ignore

args = parser.parse_args_into_dataclasses([])
Expand All @@ -149,7 +138,6 @@ def test_int_none():


def test_float_none():

parser = ArgumentParser((TestArguments)) # type: ignore

args = parser.parse_args_into_dataclasses([])
Expand All @@ -158,7 +146,6 @@ def test_float_none():


def test_str_none():

parser = ArgumentParser((TestArguments)) # type: ignore

args = parser.parse_args_into_dataclasses([])
Expand All @@ -167,7 +154,6 @@ def test_str_none():


def test_bool_none():

parser = ArgumentParser((TestArguments)) # type: ignore

args = parser.parse_args_into_dataclasses([])
Expand All @@ -176,7 +162,6 @@ def test_bool_none():


def test_int_str_none():

parser = ArgumentParser((TestArguments)) # type: ignore

args = parser.parse_args_into_dataclasses(["--int_none_arg", ""])
Expand All @@ -185,7 +170,6 @@ def test_int_str_none():


def test_float_str_none():

parser = ArgumentParser((TestArguments)) # type: ignore

args = parser.parse_args_into_dataclasses(["--float_none_arg", ""])
Expand All @@ -194,7 +178,6 @@ def test_float_str_none():


def test_str_str_none():

parser = ArgumentParser((TestArguments)) # type: ignore

args = parser.parse_args_into_dataclasses(["--str_none_arg", ""])
Expand All @@ -203,7 +186,6 @@ def test_str_str_none():


def test_bool_str_none():

parser = ArgumentParser((TestArguments)) # type: ignore

args = parser.parse_args_into_dataclasses(["--bool_none_arg", ""])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@
def combine_defaults_and_user_args(
config: Dict[str, Dict[str, Any]]
) -> Dict[str, Dict[str, Any]]:

arguments = TRAINING_PIPELINE_ARGUMENTS_MAPPING["regression-transformer-trainer"]
"""
We need `conflict_handler='resolve'` because the RT relies on the TrainingArguments
Expand Down Expand Up @@ -145,7 +144,6 @@ def combine_defaults_and_user_args(


def test_train():

pipeline = TRAINING_PIPELINE_MAPPING.get("regression-transformer-trainer")
assert pipeline is not None

Expand All @@ -161,7 +159,6 @@ def test_train():
importlib_resources.files("gt4sd")
/ "training_pipelines/tests/regression_transformer_raw.csv"
) as raw_path:

# Test finetuning the QED model
config["model_args"]["model_path"] = mol_path
config["dataset_args"]["train_data_path"] = str(raw_path)
Expand Down Expand Up @@ -200,7 +197,6 @@ def test_train():
importlib_resources.files("gt4sd")
/ "training_pipelines/tests/regression_transformer_copolymer_raw.csv"
) as raw_path:

# Test finetuning the polymer model
polymer_model = RegressionTransformerMolecules(
algorithm_version="block_copolymer"
Expand Down

0 comments on commit bd955fd

Please sign in to comment.