Skip to content

Commit

Permalink
Add misc.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lmanan committed Oct 11, 2023
1 parent d488781 commit c038aaf
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 5 deletions.
14 changes: 9 additions & 5 deletions cellulus/configs/experiment_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime

import attrs
from attrs.validators import instance_of

Expand All @@ -13,11 +15,11 @@ class ExperimentConfig:
Parameters:
experiment_name:
experiment_name: (default = 'YYYY-MM-DD')
A unique name for the experiment.
object_size:
object_size: (default = 26.0)
A rough estimate of the size of objects in the image, given in
world units. The "patch size" of the network will be chosen based
Expand All @@ -36,13 +38,15 @@ class ExperimentConfig:
Configuration object for prediction.
"""

experiment_name: str = attrs.field(validator=instance_of(str))
object_size: float = attrs.field(validator=instance_of(float))
experiment_name: str = attrs.field(
default=datetime.today().strftime("%Y-%m-%d"), validator=instance_of(str)
)
object_size: float = attrs.field(default=26.0, validator=instance_of(float))

model_config: ModelConfig = attrs.field(converter=to_config(ModelConfig))
train_config: TrainConfig = attrs.field(
default=None, converter=to_config(TrainConfig)
)
inference_config: InferenceConfig = attrs.field(
default=None, converter=to_config(InferenceConfig)
default=None, converter=to_config(InferenceConfig)(default="YYYY-MM-DD")
)
33 changes: 33 additions & 0 deletions cellulus/utils/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os
from io import BytesIO
from urllib.request import urlopen
from zipfile import ZipFile


def extract_data(zip_url, data_dir, project_name):
"""
Extracts data from `zip_url` to the location identified by `data_dir` parameters.
Parameters
----------
zip_url: string
Indicates the external url from where the data is downloaded
data_dir: string
Indicates the path to the directory where the data should be saved.
Returns
-------
"""
if not os.path.exists(os.path.join(data_dir, project_name)):
os.makedirs(data_dir)
print(f"Created new directory {data_dir}")

with urlopen(zip_url) as zipresp:
with ZipFile(BytesIO(zipresp.read())) as zfile:
zfile.extractall(data_dir)
print(f"Downloaded and unzipped data to the location {data_dir}")
else:
print(
"Directory already exists at the location "
f"{os.path.join(data_dir, project_name)}"
)

0 comments on commit c038aaf

Please sign in to comment.