From 68bad890bf3298b8e9afb71d8d723e30b5e73f2d Mon Sep 17 00:00:00 2001 From: ctuguinay Date: Mon, 1 Apr 2024 16:46:19 +0000 Subject: [PATCH] refactor region id selection and mask label construction in r2d.mask --- echoregions/regions2d/regions2d.py | 46 ++++++++++++------------------ 1 file changed, 19 insertions(+), 27 deletions(-) diff --git a/echoregions/regions2d/regions2d.py b/echoregions/regions2d/regions2d.py index e57aca1..7db75f5 100644 --- a/echoregions/regions2d/regions2d.py +++ b/echoregions/regions2d/regions2d.py @@ -195,6 +195,8 @@ def select_region( if isinstance(region_id, (float, int, str)): region_id = [region_id] elif isinstance(region_id, list): + if len(region_id) == 0: + raise ValueError("region_id list is empty. Cannot be empty.") for value in region_id: if not isinstance(value, (float, int, str)): raise TypeError( @@ -483,36 +485,9 @@ def mask( region_contours : pd.DataFrame DataFrame containing region_id, depth, and time. """ - if isinstance(region_id, list): - if len(region_id) == 0: - raise ValueError("region_id list is empty. Cannot be empty.") - mask_label_region_ids = region_id - elif region_id is None: - # Extract all region_id values - mask_label_region_ids = self.data.region_id.astype(int).to_list() - else: - raise TypeError( - f"region_id must be of type list. Currently is of type {type(region_id)}" - ) - - if mask_labels is None: - # Create mask_labels with each region_id as a key and values starting from 0 - mask_labels = {key: idx for idx, key in enumerate(mask_label_region_ids)} - - # Check that region_id and mask_labels are of the same size - if len(set(mask_label_region_ids) - set(mask_labels.keys())) > 0: - raise ValueError( - "Each region_id' must be a key in 'mask_labels'. " - "If you would prefer 0 based indexing as values for mask_labels, leave " - "mask_labels as None." - ) - # Dataframe containing region information. region_df = self.select_region(region_id, region_class) - # Select only important columns - region_df = region_df[["region_id", "time", "depth"]] - # Filter for rows with depth values within self min and self max depth and # for rows that have positive depth. region_df = region_df[ @@ -529,6 +504,23 @@ def mask( "between min_depth and max_depth." ) else: + # Grab subset region ids + subset_region_ids = region_df.region_id.astype(int).to_list() + + if mask_labels is None: + # Create mask_labels with each subset_region_ids as a key and values starting from 0 + mask_labels = {key: idx for idx, key in enumerate(subset_region_ids)} + + # Check that subset_region_ids and mask_labels are of the same size + if len(set(subset_region_ids) - set(mask_labels.keys())) > 0: + raise ValueError( + "Each value in subset_region_ids must be a key in 'mask_labels'. " + "If you would prefer 0 based indexing as values for mask_labels, leave " + "mask_labels as None." + ) + # Select only important columns + region_df = region_df[["region_id", "time", "depth"]] + # Organize the regions in a format for region mask. df = region_df.explode(["time", "depth"])