From 9bded0b6c1a7547a043861e0099bf5f8411a3f2a Mon Sep 17 00:00:00 2001 From: Janet Barclay Date: Fri, 21 Oct 2022 13:46:39 -0400 Subject: [PATCH 1/2] remove exclude_file from sample Snakefiles --- workflow_examples/Snakefile_basic.smk | 1 - workflow_examples/Snakefile_gwn.smk | 1 - workflow_examples/Snakefile_pretrain_LSTM.smk | 1 - workflow_examples/Snakefile_rgcn.smk | 1 - workflow_examples/Snakefile_rgcn_hypertune.smk | 1 - workflow_examples/Snakefile_rgcn_pytorch.smk | 1 - 6 files changed, 6 deletions(-) diff --git a/workflow_examples/Snakefile_basic.smk b/workflow_examples/Snakefile_basic.smk index be4c617..ae4801e 100755 --- a/workflow_examples/Snakefile_basic.smk +++ b/workflow_examples/Snakefile_basic.smk @@ -49,7 +49,6 @@ rule prep_io_data: spatial_idx_name='segs_test', time_idx_name='times_test', catch_prop_file=None, - exclude_file=None, train_start_date=config['train_start_date'], train_end_date=config['train_end_date'], val_start_date=config['val_start_date'], diff --git a/workflow_examples/Snakefile_gwn.smk b/workflow_examples/Snakefile_gwn.smk index e1a8512..cca840d 100644 --- a/workflow_examples/Snakefile_gwn.smk +++ b/workflow_examples/Snakefile_gwn.smk @@ -61,7 +61,6 @@ rule prep_io_data: y_vars_pretrain=config['y_vars_pretrain'], y_vars_finetune=config['y_vars_finetune'], catch_prop_file=None, - exclude_file=None, train_start_date=config['train_start_date'], train_end_date=config['train_end_date'], val_start_date=config['val_start_date'], diff --git a/workflow_examples/Snakefile_pretrain_LSTM.smk b/workflow_examples/Snakefile_pretrain_LSTM.smk index a84da80..9bd9140 100644 --- a/workflow_examples/Snakefile_pretrain_LSTM.smk +++ b/workflow_examples/Snakefile_pretrain_LSTM.smk @@ -54,7 +54,6 @@ rule prep_io_data: spatial_idx_name='segs_test', time_idx_name='times_test', catch_prop_file=None, - exclude_file=None, train_start_date=config['train_start_date'], train_end_date=config['train_end_date'], val_start_date=config['val_start_date'], diff --git a/workflow_examples/Snakefile_rgcn.smk b/workflow_examples/Snakefile_rgcn.smk index a9e98b5..ea77412 100644 --- a/workflow_examples/Snakefile_rgcn.smk +++ b/workflow_examples/Snakefile_rgcn.smk @@ -63,7 +63,6 @@ rule prep_io_data: y_vars_pretrain=config['y_vars_pretrain'], y_vars_finetune=config['y_vars_finetune'], catch_prop_file=None, - exclude_file=None, train_start_date=config['train_start_date'], train_end_date=config['train_end_date'], val_start_date=config['val_start_date'], diff --git a/workflow_examples/Snakefile_rgcn_hypertune.smk b/workflow_examples/Snakefile_rgcn_hypertune.smk index 36bea5d..0dc264b 100644 --- a/workflow_examples/Snakefile_rgcn_hypertune.smk +++ b/workflow_examples/Snakefile_rgcn_hypertune.smk @@ -67,7 +67,6 @@ rule prep_io_data: y_vars_pretrain=config['y_vars_pretrain'], y_vars_finetune=config['y_vars_finetune'], catch_prop_file=None, - exclude_file=None, train_start_date=config['train_start_date'], train_end_date=config['train_end_date'], val_start_date=config['val_start_date'], diff --git a/workflow_examples/Snakefile_rgcn_pytorch.smk b/workflow_examples/Snakefile_rgcn_pytorch.smk index 0b52025..ecbc1e3 100644 --- a/workflow_examples/Snakefile_rgcn_pytorch.smk +++ b/workflow_examples/Snakefile_rgcn_pytorch.smk @@ -61,7 +61,6 @@ rule prep_io_data: y_vars_pretrain=config['y_vars_pretrain'], y_vars_finetune=config['y_vars_finetune'], catch_prop_file=None, - exclude_file=None, train_start_date=config['train_start_date'], train_end_date=config['train_end_date'], val_start_date=config['val_start_date'], From 825c3e0f9a61b588256efcfba1518b6101b7d6b9 Mon Sep 17 00:00:00 2001 From: Janet Barclay Date: Fri, 21 Oct 2022 13:53:03 -0400 Subject: [PATCH 2/2] removing exclude_file and associated functions --- river_dl/preproc_utils.py | 144 -------------------------------------- 1 file changed, 144 deletions(-) diff --git a/river_dl/preproc_utils.py b/river_dl/preproc_utils.py index ac8fbdc..ac15bc9 100755 --- a/river_dl/preproc_utils.py +++ b/river_dl/preproc_utils.py @@ -276,102 +276,6 @@ def reshape_for_training(data): return np.reshape(data, [n_batch * n_seg, seq_len, n_feat]) -def get_exclude_start_end(exclude_grp): - """ - get the start and end dates for the exclude group - :param exclude_grp: [dict] dictionary representing the exclude group from - the exclude yml file - :return: [tuple of datetime objects] start date, end date - """ - start = exclude_grp.get("start_date") - if start: - start = datetime.datetime.strptime(start, "%Y-%m-%d") - - end = exclude_grp.get("end_date") - if end: - end = datetime.datetime.strptime(end, "%Y-%m-%d") - return start, end - - -def get_exclude_vars(exclude_grp): - """ - get the variables_to_log to exclude for the exclude group - :param exclude_grp: [dict] dictionary representing the exclude group from - the exclude yml file - :return: [list] variables_to_log to exclude - """ - variable = exclude_grp.get("variable") - if not variable or variable == "both": - return ["seg_tave_water", "seg_outflow"] - elif variable == "temp": - return ["seg_tave_water"] - elif variable == "flow": - return ["seg_outflow"] - else: - raise ValueError("exclude variable must be flow, temp, or both") - - -def get_exclude_seg_ids(exclude_grp, all_segs): - """ - get the segments to exclude - :param exclude_grp: [dict] dictionary representing the exclude group from - the exclude yml file - :param all_segs: [array] all of the segments. this is needed if we are doing - a reverse exclusion - :return: [list like] the segments to exclude - """ - # ex_segs are the sites to exclude - if "seg_id_nats_ex" in exclude_grp.keys(): - ex_segs = exclude_grp["seg_id_nats_ex"] - # exclude all *but* the "seg_id_nats_in" - elif "seg_id_nats_in" in exclude_grp.keys(): - ex_mask = ~all_segs.isin(exclude_grp["seg_id_nats_in"]) - ex_segs = all_segs[ex_mask] - else: - ex_segs = all_segs - return ex_segs - - -def exclude_segments(y_data, exclude_segs): - """ - exclude segments from being trained on by setting their weights as zero - :param y_data:[xr dataset] y_dataset data. this is used to get the dimensions - :param exclude_segs: [list] list of segments to exclude in the loss - calculation - :return: - """ - weights = initialize_weights(y_data, 1) - for seg_grp in exclude_segs: - # get the start and end dates is present - start, end = get_exclude_start_end(seg_grp) - exclude_vars = get_exclude_vars(seg_grp) - segs_to_exclude = get_exclude_seg_ids(seg_grp, weights.seg_id_nat) - - # loop through the data_vars - for v in exclude_vars: - # set those weights to zero - weights[v].load() - weights[v].loc[ - dict(date=slice(start, end), seg_id_nat=segs_to_exclude) - ] = 0 - return weights - - -def initialize_weights(y_data, initial_val=1): - """ - initialize all weights with a value. - :param y_data:[xr dataset] y_dataset data. this is used to get the dimensions - :param initial_val: [num] a number to initialize the weights with. should - be between 0 and 1 (inclusive) - :return: [xr dataset] dataset weights initialized with a uniform value - """ - weights = y_data.copy(deep=True) - for v in y_data.data_vars: - weights[v].load() - weights[v].loc[:, :] = initial_val - return weights - - def reduce_training_data_random( data_file, train_start_date="1980-10-01", @@ -600,7 +504,6 @@ def prep_y_data( time_idx_name="date", seq_len=365, log_vars=None, - exclude_file=None, normalize_y=True, y_type="obs", y_std=None, @@ -637,7 +540,6 @@ def prep_y_data( sites will be witheld from training and validation :param seq_len: [int] length of sequences (e.g., 365) :param log_vars: [list-like] which variables_to_log (if any) to take log of - :param exclude_file: [str] path to exclude file :param normalize_y: [bool] whether or not to normalize the y_dataset values :param y_type: [str] "obs" if observations or "pre" if pretraining :param y_std: [array-like] standard deviations of y_dataset variables_to_log @@ -683,12 +585,6 @@ def prep_y_data( if log_vars: y_trn = log_variables(y_trn, log_vars) - # filter pretrain/finetune y_dataset - if exclude_file: - exclude_segs = read_exclude_segs_file(exclude_file) - y_wgts = exclude_segments(y_trn, exclude_segs=exclude_segs) - else: - y_wgts = initialize_weights(y_trn) # scale y_dataset training data and get the mean and std # scale the validation partition to benchmark epoch performance if normalize_y: @@ -713,9 +609,6 @@ def prep_y_data( "y_obs_trn": convert_batch_reshape( y_trn, spatial_idx_name, time_idx_name, offset=trn_offset, seq_len=seq_len ), - "y_obs_wgts": convert_batch_reshape( - y_wgts, spatial_idx_name, time_idx_name, offset=trn_offset, seq_len=seq_len - ), "y_obs_val": convert_batch_reshape( y_val, spatial_idx_name, time_idx_name, offset=tst_val_offset, seq_len=seq_len ), @@ -768,7 +661,6 @@ def prep_all_data( dist_type="updown", catch_prop_file=None, catch_prop_vars=None, - exclude_file=None, log_y_vars=False, out_file=None, segs=None, @@ -823,7 +715,6 @@ def prep_all_data( left unfilled, the catchment properties will not be included as predictors :param catch_prop_vars: [list of str] list of catchment properties to use. If left unfilled and a catchment property file is supplied all variables will be used. - :param exclude_file: [str] path to exclude file :param log_y_vars: [bool] whether or not to take the log of discharge in training :param segs: [list-like] which segments to prepare the data for @@ -1005,7 +896,6 @@ def prep_all_data( time_idx_name=time_idx_name, seq_len=seq_len, log_vars=log_y_vars, - exclude_file=exclude_file, normalize_y=normalize_y, y_type="obs", trn_offset = trn_offset, @@ -1028,7 +918,6 @@ def prep_all_data( time_idx_name=time_idx_name, seq_len=seq_len, log_vars=log_y_vars, - exclude_file=exclude_file, normalize_y=normalize_y, y_type="pre", y_std=y_obs_data["y_std"], @@ -1053,7 +942,6 @@ def prep_all_data( time_idx_name=time_idx_name, seq_len=seq_len, log_vars=log_y_vars, - exclude_file=exclude_file, normalize_y=normalize_y, y_type="pre", trn_offset = trn_offset, @@ -1118,35 +1006,3 @@ def prep_adj_matrix(infile, dist_type, dist_idx_name, segs=None, out_file=None): np.savez_compressed(out_file, dist_matrix=A_hat) return A_hat - -def read_exclude_segs_file(exclude_file): - """ - read the exclude segs file. should be a yml file with start_date and list of - segments to exclude - -- - example exclude file: - - group_after_2017: - start_date: "2017-10-01" - variable: "temp" - seg_id_nats_ex: - - 1556 - - 1569 - group_2018_water_year: - start_date: "2017-10-01" - end_date: "2018-10-01" - seg_id_nats_ex: - - 1653 - group_all_time: - seg_id_nats_in: - - 1806 - - 2030 - - -- - :param exclude_file: [str] exclude segs file - :return: [list] list of dictionaries of segments to exclude. dict keys must - have 'seg_id_nats' and may also have 'start_date' and 'end_date' - """ - with open(exclude_file, "r") as s: - d = yaml.safe_load(s) - return [val for key, val in d.items()]