[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Oct 16, 2024
1 parent b6e0165 commit a6000af
Showing 36 changed files with 392 additions and 115 deletions.
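
The diffs below are mechanical formatting fixes applied by the pre-commit.ci bot: long statements are rewrapped, `from ... import` lists are split into parenthesized one-name-per-line form, comments gain a space after `#`, and a `# -*- coding: utf-8 -*-` pragma is inserted at the top of several modules. The repository's actual `.pre-commit-config.yaml` is not part of this commit, so the configuration below is only an illustrative sketch of hooks that could produce fixes of this kind; the hook revisions and the roughly 100-character line length are assumptions inferred from the wrapped output.

# Hypothetical .pre-commit-config.yaml (not included in this commit)
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0  # placeholder revision
    hooks:
      - id: fix-encoding-pragma   # inserts the "# -*- coding: utf-8 -*-" header lines
      - id: trailing-whitespace
      - id: end-of-file-fixer
  - repo: https://github.com/psf/black
    rev: 24.8.0  # placeholder revision
    hooks:
      - id: black
        args: ["--line-length=100"]  # rewraps long calls and import lists as in the hunks below

With a config like this in place, running `pre-commit run --all-files` locally applies the same kind of changes that pre-commit.ci pushed in this commit.
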
8 changes: 5 additions & 3 deletions pipelines/precipitation_model/impa/flows.py
@@ -66,15 +66,17 @@
"start_datetime",
default=None,
required=False,
#description="Datetime in YYYY-MM-dd HH:mm:ss format, UTC timezone",
# description="Datetime in YYYY-MM-dd HH:mm:ss format, UTC timezone",
)
num_workers = Parameter(
"num_workers",
default=8,
required=False,
#description="Number of workers to use for parallel processing",
# description="Number of workers to use for parallel processing",
)
cuda = Parameter("cuda", default=False, required=False) #, description="Use CUDA for prediction"
cuda = Parameter(
"cuda", default=False, required=False
) # , description="Use CUDA for prediction"

# Parameters for saving data on GCP
materialize_after_dump = Parameter("materialize_after_dump", default=False, required=False)
37 changes: 30 additions & 7 deletions pipelines/precipitation_model/impa/src/data/HDFDataset2.py
@@ -1,12 +1,21 @@
+# -*- coding: utf-8 -*-
from pathlib import Path

import h5py
import numpy as np
import torch
from torch.utils import data

-from pipelines.precipitation_model.impa.src.utils.dataframe_utils import N_AFTER, N_BEFORE, fetch_future_datetimes, fetch_reversed_past_datetimes
-from pipelines.precipitation_model.impa.src.utils.general_utils import print_ok, print_warning
+from pipelines.precipitation_model.impa.src.utils.dataframe_utils import (
+    N_AFTER,
+    N_BEFORE,
+    fetch_future_datetimes,
+    fetch_reversed_past_datetimes,
+)
+from pipelines.precipitation_model.impa.src.utils.general_utils import (
+    print_ok,
+    print_warning,
+)
from pipelines.precipitation_model.impa.src.utils.hdf_utils import get_dataset_keys

MIN_WEIGHT = 100
@@ -56,16 +65,27 @@ def __init__(
else:
suffix = ""

-if len(set(["latent_field", "motion_field", "intensities"]).intersection(set(get_item_output))) > 0:
+if (
+    len(
+        set(["latent_field", "motion_field", "intensities"]).intersection(
+            set(get_item_output)
+        )
+    )
+    > 0
+):
if autoencoder_hash is None:
autoencoder_hash = "001178e117f50cf17817f336b86a809f"
parent_path = Path(filepath).parents[0]
if "train" in filepath.stem:
self.latent_field_filepath = Path(f"{parent_path}/train_latent{suffix}_{autoencoder_hash}.hdf")
self.latent_field_filepath = Path(
f"{parent_path}/train_latent{suffix}_{autoencoder_hash}.hdf"
)
self.motion_field_filepath = Path(f"{parent_path}/train_motion{suffix}.hdf")
self.intensities_filepath = Path(f"{parent_path}/train_intensities{suffix}.hdf")
elif "val" in filepath.stem:
self.latent_field_filepath = Path(f"{parent_path}/val_latent{suffix}_{autoencoder_hash}.hdf")
self.latent_field_filepath = Path(
f"{parent_path}/val_latent{suffix}_{autoencoder_hash}.hdf"
)
self.motion_field_filepath = Path(f"{parent_path}/val_motion{suffix}.hdf")
self.intensities_filepath = Path(f"{parent_path}/val_intensities{suffix}.hdf")
else:
@@ -267,9 +287,12 @@ def __len__(self):
return len(self.keys)

def get_sample_weights(self, overwrite_if_exists=False, verbose=True):
-assert self.filepath.stem == "train", "Sample weights can only be calculated for the train dataset."
+assert (
+    self.filepath.stem == "train"
+), "Sample weights can only be calculated for the train dataset."
weights_filepath = (
self.filepath.parents[0] / f"train_sample_weights2-n_before={self.n_before}-n_after={self.n_after}.npy"
self.filepath.parents[0]
/ f"train_sample_weights2-n_before={self.n_before}-n_after={self.n_after}.npy"
)
if not overwrite_if_exists:
if weights_filepath.is_file():
(changes to an additional file; file path not captured)
@@ -1,7 +1,13 @@
+# -*- coding: utf-8 -*-
import numpy as np

-from pipelines.precipitation_model.impa.src.data.HDFDatasetMultiple import HDFDatasetMultiple
-from pipelines.precipitation_model.impa.src.utils.dataframe_utils import N_AFTER, N_BEFORE
+from pipelines.precipitation_model.impa.src.data.HDFDatasetMultiple import (
+    HDFDatasetMultiple,
+)
+from pipelines.precipitation_model.impa.src.utils.dataframe_utils import (
+    N_AFTER,
+    N_BEFORE,
+)

elevation_file_small = "data/processed/elevations_data/elevation_{location}-res=2km-256x256.npy"
elevation_file_large = "data/processed/elevations_data/elevation_{location}-res=4km-256x256.npy"
10 changes: 8 additions & 2 deletions pipelines/precipitation_model/impa/src/data/HDFDatasetMerged.py
@@ -1,7 +1,13 @@
+# -*- coding: utf-8 -*-
import numpy as np

-from pipelines.precipitation_model.impa.src.data.HDFDatasetMultiple import HDFDatasetMultiple
-from pipelines.precipitation_model.impa.src.utils.dataframe_utils import N_AFTER, N_BEFORE
+from pipelines.precipitation_model.impa.src.data.HDFDatasetMultiple import (
+    HDFDatasetMultiple,
+)
+from pipelines.precipitation_model.impa.src.utils.dataframe_utils import (
+    N_AFTER,
+    N_BEFORE,
+)

# For Rio de Janeiro only

83 changes: 64 additions & 19 deletions pipelines/precipitation_model/impa/src/data/HDFDatasetMultiple.py
@@ -8,7 +8,10 @@
import torch
from torch.utils import data

-from pipelines.precipitation_model.impa.src.utils.dataframe_utils import fetch_future_datetimes, fetch_reversed_past_datetimes
+from pipelines.precipitation_model.impa.src.utils.dataframe_utils import (
+    fetch_future_datetimes,
+    fetch_reversed_past_datetimes,
+)
from pipelines.precipitation_model.impa.src.utils.hdf_utils import get_dataset_keys


@@ -76,7 +79,9 @@ def __init__(
# print_warning("File not found in /dev/shm, using original path.")
new_dataframe_filepaths_array[i] = str(filepath)

-self.dataframe_shm_filepaths_array = new_dataframe_filepaths_array.reshape(dataframe_filepaths_array.shape)
+self.dataframe_shm_filepaths_array = new_dataframe_filepaths_array.reshape(
+    dataframe_filepaths_array.shape
+)
self.dataframe_filepaths_array = dataframe_filepaths_array
self.n_before_array = n_before_array
self.n_after_array = n_after_array
@@ -126,7 +131,9 @@ def _load_keys(self):
if j == 0:
new_keys = set(hdf["split_info"]["split_datetime_keys"])
else:
new_keys = set(hdf["split_info"]["split_datetime_keys"]).intersection(new_keys)
new_keys = set(hdf["split_info"]["split_datetime_keys"]).intersection(
new_keys
)
else:
if j == 0:
new_keys = set(get_dataset_keys(hdf))
@@ -140,17 +147,30 @@ def _load_keys(self):
self.past_keys = np.vstack(
[
self.past_keys,
-np.array(fetch_reversed_past_datetimes(new_keys, self.n_before_array.max(), timestep)),
+np.array(
+    fetch_reversed_past_datetimes(
+        new_keys, self.n_before_array.max(), timestep
+    )
+),
]
)
except ValueError:
-self.past_keys = np.array(fetch_reversed_past_datetimes(new_keys, self.n_before_array.max(), timestep))
+self.past_keys = np.array(
+    fetch_reversed_past_datetimes(new_keys, self.n_before_array.max(), timestep)
+)
try:
self.future_keys = np.vstack(
-[self.future_keys, np.array(fetch_future_datetimes(new_keys, self.n_after_array.max(), timestep))]
+[
+    self.future_keys,
+    np.array(
+        fetch_future_datetimes(new_keys, self.n_after_array.max(), timestep)
+    ),
+]
)
except ValueError:
-self.future_keys = np.array(fetch_future_datetimes(new_keys, self.n_after_array.max(), timestep))
+self.future_keys = np.array(
+    fetch_future_datetimes(new_keys, self.n_after_array.max(), timestep)
+)
try:
self.ds_indices.append(len(new_keys) + self.ds_indices[-1])
except IndexError:
@@ -162,11 +182,21 @@ def _get_hdf_index(self, index):
return i

def __getitem__(self, index):
-X = torch.ones((self.ni, self.nj, (self.n_before_array * self.n_before_resolution_array).sum())) * torch.inf
+X = (
+    torch.ones(
+        (self.ni, self.nj, (self.n_before_array * self.n_before_resolution_array).sum())
+    )
+    * torch.inf
+)
if self.leadtime_conditioning:
pass
else:
-Y = torch.ones((self.ni, self.nj, (self.n_after_array * self.n_after_resolution_array).sum())) * torch.inf
+Y = (
+    torch.ones(
+        (self.ni, self.nj, (self.n_after_array * self.n_after_resolution_array).sum())
+    )
+    * torch.inf
+)

if self.leadtime_conditioning:
leadtime_index = index % self.n_after_array[0]
@@ -186,33 +216,44 @@ def __getitem__(self, index):
)
try:
X[:, :, tensor_ind : tensor_ind + n_resolution] = torch.as_tensor(
-np.array(hdf[self.past_keys[index][i]]).reshape((self.ni, self.nj, n_resolution))
+np.array(hdf[self.past_keys[index][i]]).reshape(
+    (self.ni, self.nj, n_resolution)
+)
)
except KeyError:
-X[:, :, tensor_ind : tensor_ind + n_resolution] = torch.ones((self.ni, self.nj, 1)) * np.nan
+X[:, :, tensor_ind : tensor_ind + n_resolution] = (
+    torch.ones((self.ni, self.nj, 1)) * np.nan
+)
if self.leadtime_conditioning and j == 0:
try:
-Y = torch.as_tensor(np.array(hdf[self.future_keys[index][leadtime_index]])).reshape(
-    (self.ni, self.nj, self.n_after_resolution_array[0])
-)
+Y = torch.as_tensor(
+    np.array(hdf[self.future_keys[index][leadtime_index]])
+).reshape((self.ni, self.nj, self.n_after_resolution_array[0]))
except KeyError:
-Y = torch.ones((self.ni, self.nj, self.n_after_resolution_array[0])) * np.nan
+Y = (
+    torch.ones((self.ni, self.nj, self.n_after_resolution_array[0]))
+    * np.nan
+)
else:
# for i, key in enumerate(self.future_keys[index]):
n_resolution = self.n_after_resolution_array[j]
for i in range(self.n_after_array[j]):
if j == 0:
cumsum = 0
else:
-cumsum = np.cumsum(self.n_after_array * self.n_after_resolution_array)[j - 1]
+cumsum = np.cumsum(self.n_after_array * self.n_after_resolution_array)[
+    j - 1
+]
tensor_ind = cumsum + i * n_resolution
try:
Y[:, :, tensor_ind : tensor_ind + n_resolution] = torch.as_tensor(
np.array(hdf[self.future_keys[index][i]])
).reshape((self.ni, self.nj, n_resolution))

except KeyError:
-Y[:, :, tensor_ind : tensor_ind + n_resolution] = torch.ones((self.ni, self.nj, 1)) * np.nan
+Y[:, :, tensor_ind : tensor_ind + n_resolution] = (
+    torch.ones((self.ni, self.nj, 1)) * np.nan
+)
if self.x_transform:
X = self.x_transform(X)
if self.y_transform:
@@ -225,7 +266,9 @@ def __getitem__(self, index):
day = int(date[6:8])
hour = int(date[9:11])
minute = int(date[11:13])
-date = torch.tensor([month / 12, day / 31, hour / 24, minute / 60], dtype=torch.float32).reshape((1, 1, -1))
+date = torch.tensor(
+    [month / 12, day / 31, hour / 24, minute / 60], dtype=torch.float32
+).reshape((1, 1, -1))
date_tensor = date.expand((self.ni, self.nj, -1))

metadata_tensor = torch.cat(
@@ -259,7 +302,9 @@ def get_sample_weights(self, min_sum=0):
pre_weights_size = len(pre_weights)
pre_weights = np.append(pre_weights, np.array([pre_weights.min()]))
inds = np.load(inds_filepath).reshape(-1, 1)
-deltas = np.arange(-self.n_before_array[0] + 1, self.n_after_array[0] + 1, dtype=int).reshape(1, -1)
+deltas = np.arange(
+    -self.n_before_array[0] + 1, self.n_after_array[0] + 1, dtype=int
+).reshape(1, -1)
slices = inds + deltas
slices[np.logical_or(slices < 0, slices >= pre_weights_size)] = pre_weights_size
summed_weights = pre_weights[slices].sum(axis=1).reshape(-1)
34 changes: 27 additions & 7 deletions pipelines/precipitation_model/impa/src/data/PredHDFDataset2.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
from pathlib import Path

import h5py
@@ -12,7 +13,10 @@
fetch_pred_keys,
fetch_reversed_past_datetimes,
)
-from pipelines.precipitation_model.impa.src.utils.general_utils import print_ok, print_warning
+from pipelines.precipitation_model.impa.src.utils.general_utils import (
+    print_ok,
+    print_warning,
+)
from pipelines.precipitation_model.impa.src.utils.hdf_utils import get_dataset_keys

MIN_WEIGHT = 100
@@ -59,7 +63,9 @@ def __init__(
f"models/{model}/predictions/{dataset}/predict_{split}-ckpt={ckpt_file.replace('.ckpt','')}.hdf"
)
else:
self.predict_filepath = Path(f"models/{model}/predictions/{dataset}/predict_{split}.hdf")
self.predict_filepath = Path(
f"models/{model}/predictions/{dataset}/predict_{split}.hdf"
)
self.n_predictions = n_predictions
self.n_before = n_before
self.n_after = n_after
@@ -69,16 +75,27 @@

self._load_keys()

-if len(set(["latent_field", "motion_field", "intensities"]).intersection(set(get_item_output))) > 0:
+if (
+    len(
+        set(["latent_field", "motion_field", "intensities"]).intersection(
+            set(get_item_output)
+        )
+    )
+    > 0
+):
if autoencoder_hash is None:
autoencoder_hash = "001178e117f50cf17817f336b86a809f"
parent_path = filepath.parents[0]
if "train" in filepath.stem:
self.latent_field_filepath = Path(f"{parent_path}/train_latent_{autoencoder_hash}.hdf")
self.latent_field_filepath = Path(
f"{parent_path}/train_latent_{autoencoder_hash}.hdf"
)
self.motion_field_filepath = Path(f"{parent_path}/train_motion.hdf")
self.intensities_filepath = Path(f"{parent_path}/train_intensities.hdf")
elif "val" in filepath.stem:
self.latent_field_filepath = Path(f"{parent_path}/val_latent_{autoencoder_hash}.hdf")
self.latent_field_filepath = Path(
f"{parent_path}/val_latent_{autoencoder_hash}.hdf"
)
self.motion_field_filepath = Path(f"{parent_path}/val_motion.hdf")
self.intensities_filepath = Path(f"{parent_path}/val_intensities.hdf")
else:
@@ -287,9 +304,12 @@ def __len__(self):
return len(self.keys)

def get_sample_weights(self, overwrite_if_exists=False, verbose=True):
-assert self.filepath.stem == "train", "Sample weights can only be calculated for the train dataset."
+assert (
+    self.filepath.stem == "train"
+), "Sample weights can only be calculated for the train dataset."
weights_filepath = (
self.filepath.parents[0] / f"train_sample_weights2-n_before={self.n_before}-n_after={self.n_after}.npy"
self.filepath.parents[0]
/ f"train_sample_weights2-n_before={self.n_before}-n_after={self.n_after}.npy"
)
if not overwrite_if_exists:
if weights_filepath.is_file():