feat(runner): get file path from data release targets (#148)
Makes it easier to loop over the targets in the runner.
Caceresenzo authored Oct 28, 2024
1 parent 03ec73d commit af40118
Showing 7 changed files with 35 additions and 68 deletions.
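The point of the change: each target column entry coming back from the API now carries an optional file_path, so a runner can resolve the data file for every target directly instead of guessing it from the directory layout. A minimal sketch of that loop, assuming a `column_names.targets` collection and a `data_directory_path` like the ones visible in the diffs below (illustrative only, not code from this commit):

    import os

    # Hypothetical objects standing in for what the runner receives from the API.
    for target in column_names.targets:
        data_file_path = (
            os.path.join(data_directory_path, target.file_path)
            if target.file_path
            else None
        )
        print(target.name, "->", data_file_path)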
6 changes: 3 additions & 3 deletions crunch/api/domain/data_release.py
@@ -26,12 +26,11 @@ def __repr__(self):
 @dataclasses.dataclass(frozen=True)
 class DataFile:
 
+    name: str
     url: str
     size: int
     signed: bool
-    # TODO Make me mandatory
-    name: str = None
-    compressed: bool = False
+    compressed: bool
 
 
 @dataclasses.dataclass(frozen=True)
@@ -253,6 +252,7 @@ class TargetColumnNames:
     side: typing.Optional[str]
     input: typing.Optional[str]
     output: typing.Optional[str]
+    file_path: typing.Optional[str]
 
 
 @dataclasses_json.dataclass_json(
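Because TargetColumnNames appears to be a plain dataclass, the new field is simply appended to the positional order, which is why positional constructions elsewhere in this commit gain a trailing argument (see tests/test_container.py further down). A stand-in sketch of the class as it looks after this hunk; only the typed fields shown above come from the diff, the leading integer field's name and the frozen=True flag are assumptions:

    import dataclasses
    import typing

    @dataclasses.dataclass(frozen=True)  # frozen=True assumed, mirroring DataFile above
    class TargetColumnNames:
        id: int  # assumed name for the leading integer argument seen in the tests
        name: str
        side: typing.Optional[str]
        input: typing.Optional[str]
        output: typing.Optional[str]
        file_path: typing.Optional[str]

    # Matches the positional style used in tests/test_container.py below.
    target = TargetColumnNames(0, "a", "side_a", "in_a", "out_a", None)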
15 changes: 11 additions & 4 deletions crunch/cli.py
@@ -590,7 +590,7 @@ def cloud(
 @click.option("--side-column-name", required=True)
 @click.option("--input-column-name", required=True)
 @click.option("--output-column-name", required=True)
-@click.option("--target", "targets", required=True, multiple=True, nargs=4)
+@click.option("--target", "targets", required=True, multiple=True, nargs=5)
 def cloud_executor(
     competition_name: str,
     competition_format: str,
@@ -621,7 +621,7 @@ def cloud_executor(
     side_column_name: str,
     input_column_name: str,
     output_column_name: str,
-    targets: typing.List[typing.Tuple[str, str, str, str]],
+    targets: typing.List[typing.Tuple[str, str, str, str, str]],
 ):
     from .runner import is_inside
     if not is_inside:
@@ -671,8 +671,15 @@ def cloud_executor(
             input_column_name or None,
             output_column_name or None,
             [
-                api.TargetColumnNames(0, target_name, side or None, input or None, output or None)
-                for target_name, side, input, output in targets
+                api.TargetColumnNames(
+                    0,
+                    target_name,
+                    side or None,
+                    input or None,
+                    output or None,
+                    file_path or None,
+                )
+                for target_name, side, input, output, file_path in targets
             ]
         )
     )
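On the CLI side, --target moves from 4 values to 5 per occurrence; with multiple=True and nargs=5, click collects one 5-tuple per repeated flag, which is what the widened typing.Tuple hint reflects. A self-contained sketch of that parsing behavior (hypothetical command name and values, not the real CLI):

    import click

    @click.command()
    @click.option("--target", "targets", required=True, multiple=True, nargs=5)
    def demo(targets):
        # targets arrives as a tuple of 5-tuples, one per --target occurrence.
        for target_name, side, input, output, file_path in targets:
            click.echo(f"{target_name}: {file_path or '<no file path>'}")

    if __name__ == "__main__":
        demo()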
3 changes: 2 additions & 1 deletion crunch/runner/cloud.py
@@ -530,7 +530,8 @@ def sandbox(
                 target_column_names.name,
                 target_column_names.side or "",
                 target_column_names.input or "",
-                target_column_names.output or ""
+                target_column_names.output or "",
+                target_column_names.file_path or ""
             )
             for target_column_names in self.column_names.targets
         ],
26 changes: 5 additions & 21 deletions crunch/runner/cloud_executor.py
@@ -414,15 +414,9 @@ def process_unstructured(
         }
 
         if self.train:
-            # TODO Make dynamic or come from the API
-            train_directory_path = os.path.join(self.data_directory_path, "train")
-
             utils.smart_call(
                 train_function,
                 default_values,
-                {
-                    "train_directory_path": train_directory_path,
-                },
                 log=False
             )
 
@@ -431,26 +425,16 @@
             target_column_names = self.column_names.get_target_by_name(self.loop_key)
             assert target_column_names is not None, f"target not found: {self.loop_key}"
 
-            # TODO Make dynamic or come from the API
-            test_directory_path = os.path.join(self.data_directory_path, "test")
-
-            matching_data_file_name = utils.find_first_file(
-                test_directory_path,
-                target_column_names.name
-            )
-
-            test_data_file_path = os.path.join(
-                test_directory_path,
-                matching_data_file_name
-            ) if matching_data_file_name else None
+            data_file_path = os.path.join(
+                self.data_directory_path,
+                target_column_names.file_path
+            ) if target_column_names.file_path else None
 
             prediction = utils.smart_call(
                 infer_function,
                 default_values,
                 {
-                    "test_directory_path": test_directory_path,
-                    "test_data_file_path": test_data_file_path,
-                    "data_file_path": test_data_file_path,
+                    "data_file_path": data_file_path,
                     "target_name": target_column_names.name,
                 }
             )
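Net effect in the cloud executor: the per-target file is no longer discovered by prefix-matching files under a hard-coded test/ directory (the find_first_file helper removed from crunch/utils.py below); it is joined from data_directory_path and the target's file_path and handed to the user code as data_file_path. The sketch below assumes utils.smart_call forwards only the values whose keys match the user function's parameter names, which this diff suggests but does not show; all names and values here are illustrative:

    import inspect

    def smart_call_sketch(function, values: dict):
        # Forward only the values whose keys match the function's parameters.
        parameters = inspect.signature(function).parameters
        kwargs = {key: value for key, value in values.items() if key in parameters}
        return function(**kwargs)

    def infer(data_file_path, target_name):  # hypothetical user-supplied function
        return f"predicting {target_name} from {data_file_path}"

    print(smart_call_sketch(infer, {
        "data_file_path": "data/a.parquet",    # illustrative value
        "target_name": "a",
        "model_directory_path": "resources",   # ignored: not a parameter of infer
    }))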
37 changes: 12 additions & 25 deletions crunch/runner/local.py
@@ -345,6 +345,7 @@ def stream_loop(
 
     def _get_spatial_default_values(self):
         return {
+            "data_directory_path": self.data_directory_path,
             "model_directory_path": self.model_directory_path,
             "column_names": self.column_names,
             "target_names": self.column_names.target_names,
@@ -355,44 +356,30 @@ def _get_spatial_default_values():
     def spatial_train(
         self,
     ):
-        # TODO Make dynamic or come from the API
-        train_directory_path = os.path.join(self.data_directory_path, "train")
-
-        default_values = self._get_spatial_default_values()
-
         logging.warning('call: train')
-        utils.smart_call(self.train_function, default_values, {
-            "train_directory_path": train_directory_path,
-            "data_directory_path": train_directory_path,
-        })
+        utils.smart_call(
+            self.train_function,
+            self._get_spatial_default_values()
+        )
 
     def spatial_loop(
         self,
         target_column_names: api.TargetColumnNames
     ) -> pandas.DataFrame:
-        # TODO Make dynamic or come from the API
-        test_directory_path = os.path.join(self.data_directory_path, "test")
-
-        matching_data_file_name = utils.find_first_file(
-            test_directory_path,
-            target_column_names.name
-        )
-
-        test_data_file_path = os.path.join(
-            test_directory_path,
-            matching_data_file_name
-        ) if matching_data_file_name else None
+        data_file_path = os.path.join(
+            self.data_directory_path,
+            target_column_names.file_path
+        ) if target_column_names.file_path else None
 
-        logging.warning('call: infer')
+        logging.warning(f'call: infer ({target_column_names.name})')
 
         prediction = utils.smart_call(
             self.infer_function,
             self._get_spatial_default_values(),
             {
-                "test_directory_path": test_directory_path,
-                "test_data_file_path": test_data_file_path,
-                "data_file_path": test_data_file_path,
+                "data_file_path": data_file_path,
                 "target_name": target_column_names.name,
+                "target_name": target_column_names.file_path,
             }
         )
 
12 changes: 0 additions & 12 deletions crunch/utils.py
@@ -457,15 +457,3 @@ def split_at_nans(
         parts.append(dataframe.iloc[start:end])
 
     return parts
-
-
-def find_first_file(directory: str, prefix: str):
-    prefix_with_dot = prefix + "."
-
-    for name in os.listdir(directory):
-        path = os.path.join(directory, name)
-        if os.path.isdir(path):
-            continue
-
-        if name == prefix or name.startswith(prefix_with_dot):
-            return name
4 changes: 2 additions & 2 deletions tests/test_container.py
@@ -61,8 +61,8 @@ def test_from_model(self):
             input=None,
             output=None,
             targets=[
-                TargetColumnNames(0, "a", "side_a", "in_a", "out_a"),
-                TargetColumnNames(0, "b", "side_b", "in_b", "out_b")
+                TargetColumnNames(0, "a", "side_a", "in_a", "out_a", None),
+                TargetColumnNames(0, "b", "side_b", "in_b", "out_b", None)
             ]
         )
 
