diff --git a/geost/abstract_classes.py b/geost/abstract_classes.py index edc1027..1608af1 100644 --- a/geost/abstract_classes.py +++ b/geost/abstract_classes.py @@ -104,6 +104,10 @@ def slice_depth_interval(self): def slice_by_values(self): pass + @abstractmethod + def select_by_condition(self): + pass + @abstractmethod def get_cumulative_layer_thickness(self): # Not sure if this should be here, potentially unsuitable with DiscreteData @@ -208,6 +212,10 @@ def select_by_depth(self): def select_by_length(self): pass + @abstractmethod + def select_by_condition(self): + pass + @abstractmethod def get_area_labels(self): pass diff --git a/geost/base.py b/geost/base.py index c395e41..9a1ff77 100644 --- a/geost/base.py +++ b/geost/base.py @@ -1,6 +1,6 @@ import pickle from pathlib import WindowsPath -from typing import Iterable, List +from typing import Any, Iterable, List import geopandas as gpd import numpy as np @@ -824,6 +824,44 @@ def slice_by_values( return self.__class__(sliced, self.has_inclined) + def select_by_condition(self, condition: Any, invert: bool = False): + """ + Select data using a manual condition that results in a boolean mask. Returns the + rows in the data where the 'condition' evaluates to True. + + Parameters + ---------- + condition : list, pd.Series or array like + Boolean array like object with locations at which the values will be + preserved, dtype must be 'bool' and the length must correspond with the + length of the data. + invert : bool, optional + If True, the selection is inverted so rows that evaluate to False will be + returned. The default is False. + + Returns + ------- + :class:`~geost.base.LayeredData` + New instance containing only the data objects selected by this method. + + Examples + -------- + Select rows in borehole data that contain a specific value: + + >>> boreholes.select_by_condition(boreholes["lith"]=="V") + + Or select rows in the borehole data that contain a specific (part of) string or + strings: + + >>> boreholes.select_by_condition(boreholes["column"].str.contains("foo|bar")) + + """ + if invert: + selected = self[~condition] + else: + selected = self[condition] + return self.__class__(selected, self.has_inclined) + def get_cumulative_layer_thickness(self, column: str, values: str | List[str]): """ Get the cumulative thickness of layers where a column contains a specified search @@ -1221,9 +1259,11 @@ def slice_depth_interval(self): def slice_by_values(self): raise NotImplementedError() + def select_by_condition(self): + raise NotImplementedError() + def get_cumulative_layer_thickness(self): raise NotImplementedError() - pass def get_layer_top(self): raise NotImplementedError() @@ -1830,6 +1870,41 @@ def slice_by_values( collection_selected = data_selected.to_collection() return collection_selected + def select_by_condition(self, condition: Any, invert: bool = False): + """ + Select from collection.data using a manual condition that results in a boolean + mask. Returns the rows in the data where the 'condition' evaluates to True. + + Parameters + ---------- + condition : list, pd.Series or array like + Boolean array like object with locations at which the values will be + preserved, dtype must be 'bool' and the length must correspond with the + length of the data. + invert : bool, optional + If True, the selection is inverted so rows that evaluate to False will be + returned. The default is False. + + Returns + ------- + :class:`~geost.base.LayeredData` + New instance containing only the data objects selected by this method. + + Examples + -------- + Select rows in borehole data that contain a specific value: + + >>> boreholes.select_by_condition(boreholes["lith"]=="V") + + Or select rows in the borehole data that contain a specific (part of) string or + strings: + + >>> boreholes.select_by_condition(boreholes["column"].str.contains("foo|bar")) + + """ + selected = self.data.select_by_condition(condition, invert) + return selected.to_collection() + def get_area_labels( self, polygon_gdf: str | WindowsPath | gpd.GeoDataFrame, diff --git a/tests/test_collections.py b/tests/test_collections.py index 49e1e65..4f57c89 100644 --- a/tests/test_collections.py +++ b/tests/test_collections.py @@ -384,6 +384,15 @@ def test_slice_by_values(self, borehole_collection): assert len(sliced.data) == expected_length assert sliced.n_points == 4 + @pytest.mark.unittest + def test_select_by_condition(self, borehole_collection): + # same selection with this method is used in test_data_objects + selected = borehole_collection.select_by_condition( + borehole_collection.data["lith"] == "V" + ) + assert isinstance(selected, BoreholeCollection) + assert selected.n_points == 2 + @pytest.mark.integrationtest def test_validation_pass(self, capfd, borehole_df_ok): LayeredData(borehole_df_ok).to_collection() diff --git a/tests/test_data_objects.py b/tests/test_data_objects.py index dde9996..c925379 100644 --- a/tests/test_data_objects.py +++ b/tests/test_data_objects.py @@ -180,6 +180,20 @@ def test_slice_depth_interval(self, borehole_data): assert len(sliced) == 7 assert_array_equal(bottoms_of_slice, expected_bottoms_of_slice) + @pytest.mark.unittest + def test_select_by_condition(self, borehole_data): + selected = borehole_data.select_by_condition(borehole_data["lith"] == "V") + expected_nrs = ["B", "D"] + assert_array_equal(selected["nr"].unique(), expected_nrs) + assert np.all(selected["lith"] == "V") + assert len(selected) == 4 + + selected = borehole_data.select_by_condition( + borehole_data["lith"] == "V", invert=True + ) + assert len(selected) == 21 + assert ~np.all(selected['lith']=='V') + @pytest.mark.unittest def test_to_multiblock(self, borehole_data): # Test normal to multiblock.