diff --git a/.prospector.yml b/.prospector.yml index a59dca9..78bec73 100644 --- a/.prospector.yml +++ b/.prospector.yml @@ -4,7 +4,7 @@ output-format: grouped -strictness: medium +strictness: medium # Note that this can suppress many pylint warnings doc-warnings: false test-warnings: true member-warnings: false diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..5664a7d --- /dev/null +++ b/.pylintrc @@ -0,0 +1,6 @@ +[tool.pylint."messages control"] +# Suppress warnings for bad variables for specific variables: +good-names=X,x,y + +# To disable this warning completely: +# disable = invalid-name diff --git a/notebooks/tutorial_traintest.ipynb b/notebooks/tutorial_traintest.ipynb index 7f82319..7c87111 100644 --- a/notebooks/tutorial_traintest.ipynb +++ b/notebooks/tutorial_traintest.ipynb @@ -4,1086 +4,152 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Tutorial on how to use the train/test method listed in the s2s `traintest` module\n", - "\n", - "For cross-validation, we split data resampled in the s2s `time` module into groups.\n", - "\n", - "We start by importing the required libraries and generating an example `AdventCalendar` along with example data." + "### TrainTest splitters operating on (multiple) xarray dataarrays" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [], - "source": [ - "import s2spy.time\n", - "import s2spy.traintest\n", - "import pandas as pd\n", - "import numpy as np" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "calendar = s2spy.time.AdventCalendar(anchor_date=(10, 15), freq=\"180d\")\n", - "\n", - "time_index = pd.date_range(\"20151020\", \"20211001\", freq=\"60d\")\n", - "test_data = np.random.random(len(time_index))\n", - "df = pd.DataFrame(test_data, index=time_index, columns =[\"data1\"])\n", - "ds = df.to_xarray().rename({\"index\": \"time\"})" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We first need to resample the data using the calendar:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "Index(['anchor_year', 'i_interval', 'interval', 'data1', 'target'], dtype='object')" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "-0.08806 -1.723 -1.14 -0.0535 0.6887 ... 0.6837 -0.523 -1.067 0.5144 0.1049\n", + "Coordinates:\n", + " * time (time) datetime64[ns] 2015-10-20 2015-12-19 ... 2023-11-07\n" + ] } ], "source": [ - "df = calendar.resample(df)\n", - "df.keys()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Example of the `KFold` method.\n", + "# Create dummy data\n", + "import numpy as np\n", + "import pandas as pd\n", + "import xarray as xr\n", "\n", - "All splitter classes from sklearn are supported, a list is available here:\n", + "# Hide the full data when displaying a dataset in the notebook\n", + "xr.set_options(display_expand_data=False) \n", "\n", - "https://scikit-learn.org/stable/modules/classes.html#splitter-classes" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn.model_selection import KFold\n", - "splitter = KFold(n_splits=3)\n", - "df = s2spy.traintest.split_groups(splitter, df)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Get data from all training groups of fold 0:" + "n = 50\n", + "time_index = pd.date_range(\"20151020\", periods=n, freq=\"60d\")\n", + "time_coord = {\"time\": time_index}\n", + "x1 = xr.DataArray(np.random.randn(n), coords=time_coord, name=\"precursor1\")\n", + "x2 = xr.DataArray(np.random.randn(n), coords=time_coord, name=\"precursor2\")\n", + "y = xr.DataArray(np.random.randn(n), coords=time_coord, name=\"target\")\n", + "print(x1)" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "metadata": {}, "outputs": [ { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
anchor_yeari_intervalintervaldata1targetsplit_0split_1split_2
420180.0(2018-04-18, 2018-10-15]0.240913Truetraintesttrain
520181.0(2017-10-20, 2018-04-18]0.669087Falsetraintesttrain
620190.0(2019-04-18, 2019-10-15]0.414739Truetraintesttrain
720191.0(2018-10-20, 2019-04-18]0.287408Falsetraintesttrain
820200.0(2020-04-18, 2020-10-15]0.670571Truetraintraintest
920201.0(2019-10-21, 2020-04-18]0.556028Falsetraintraintest
\n", - "
" - ], - "text/plain": [ - " anchor_year i_interval interval data1 target split_0 \\\n", - "4 2018 0.0 (2018-04-18, 2018-10-15] 0.240913 True train \n", - "5 2018 1.0 (2017-10-20, 2018-04-18] 0.669087 False train \n", - "6 2019 0.0 (2019-04-18, 2019-10-15] 0.414739 True train \n", - "7 2019 1.0 (2018-10-20, 2019-04-18] 0.287408 False train \n", - "8 2020 0.0 (2020-04-18, 2020-10-15] 0.670571 True train \n", - "9 2020 1.0 (2019-10-21, 2020-04-18] 0.556028 False train \n", - "\n", - " split_1 split_2 \n", - "4 test train \n", - "5 test train \n", - "6 test train \n", - "7 test train \n", - "8 train test \n", - "9 train test " - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dimensions: (anchor_year: 8, i_interval: 2)\n", + "Coordinates:\n", + " * anchor_year (anchor_year) int64 2016 2017 2018 2019 2020 2021 2022 2023\n", + " * i_interval (i_interval) int64 0 1\n", + " index (anchor_year, i_interval) int64 0 1 2 3 4 5 ... 11 12 13 14 15\n", + " interval (anchor_year, i_interval) object (2016-04-18, 2016-10-15] .....\n", + " target (i_interval) bool True False\n", + "Data variables:\n", + " precursor1 (anchor_year, i_interval) float64 0.9591 -0.9723 ... 0.5215\n" + ] } ], "source": [ - "training_data_split_0 = df.where(df.split_0 == \"train\")\n", - "training_data_split_0.dropna()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### `xarray` example" + "# Fit to calendar\n", + "import s2spy.time\n", + "\n", + "calendar = s2spy.time.AdventCalendar(anchor=(10, 15), freq=\"180d\")\n", + "calendar.map_to_data(x1) # TODO: would be nice to pass in multiple at once.\n", + "x1 = s2spy.time.resample(calendar, x1)\n", + "x2 = s2spy.time.resample(calendar, x2)\n", + "y = s2spy.time.resample(calendar, y)\n", + "\n", + "print(x1)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "metadata": {}, "outputs": [ { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.Dataset>\n",
-       "Dimensions:      (anchor_year: 5, i_interval: 2)\n",
-       "Coordinates:\n",
-       "    index        (anchor_year, i_interval) int64 0 1 2 3 4 5 6 7 8 9\n",
-       "    interval     (anchor_year, i_interval) object (2016-04-18, 2016-10-15] .....\n",
-       "  * anchor_year  (anchor_year) int64 2016 2017 2018 2019 2020\n",
-       "  * i_interval   (i_interval) int64 0 1\n",
-       "    target       (i_interval) bool True False\n",
-       "Data variables:\n",
-       "    data1        (anchor_year, i_interval) float64 0.6396 0.3708 ... 0.556
" - ], - "text/plain": [ - "\n", - "Dimensions: (anchor_year: 5, i_interval: 2)\n", - "Coordinates:\n", - " index (anchor_year, i_interval) int64 0 1 2 3 4 5 6 7 8 9\n", - " interval (anchor_year, i_interval) object (2016-04-18, 2016-10-15] .....\n", - " * anchor_year (anchor_year) int64 2016 2017 2018 2019 2020\n", - " * i_interval (i_interval) int64 0 1\n", - " target (i_interval) bool True False\n", - "Data variables:\n", - " data1 (anchor_year, i_interval) float64 0.6396 0.3708 ... 0.556" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: [2019 2020 2021 2022 2023]\n", + "Test: [2016 2017 2018]\n", + "Train: [2016 2017 2018 2022 2023]\n", + "Test: [2019 2020 2021]\n", + "Train: [2016 2017 2018 2019 2020 2021]\n", + "Test: [2022 2023]\n", + "\n", + "Dimensions: (anchor_year: 6, i_interval: 2)\n", + "Coordinates:\n", + " * anchor_year (anchor_year) int64 2016 2017 2018 2019 2020 2021\n", + " * i_interval (i_interval) int64 0 1\n", + " index (anchor_year, i_interval) int64 0 1 2 3 4 5 6 7 8 9 10 11\n", + " interval (anchor_year, i_interval) object (2016-04-18, 2016-10-15] .....\n", + " target (i_interval) bool True False\n", + "Data variables:\n", + " precursor1 (anchor_year, i_interval) float64 0.9591 -0.9723 ... 0.7427\n" + ] } ], "source": [ - "ds = calendar.resample(ds)\n", - "ds" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Here we choose the `ShuffleSplit` method:" + "# Cross-validation\n", + "from sklearn.model_selection import KFold\n", + "import s2spy.traintest\n", + "\n", + "kfold = KFold(n_splits=3)\n", + "cv = s2spy.traintest.TrainTestSplit(kfold)\n", + "for (x1_train, x2_train), (x1_test, x2_test), y_train, y_test in cv.split(x1, x2, y=y):\n", + " print(\"Train:\", x1_train.anchor_year.values)\n", + " print(\"Test:\", x1_test.anchor_year.values)\n", + "\n", + "print(x1_train)" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 4, "metadata": {}, "outputs": [ { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.Dataset>\n",
-       "Dimensions:      (anchor_year: 5, i_interval: 2, split: 3)\n",
-       "Coordinates:\n",
-       "    index        (anchor_year, i_interval) int64 0 1 2 3 4 5 6 7 8 9\n",
-       "    interval     (anchor_year, i_interval) object (2016-04-18, 2016-10-15] .....\n",
-       "  * anchor_year  (anchor_year) int64 2016 2017 2018 2019 2020\n",
-       "  * i_interval   (i_interval) int64 0 1\n",
-       "    target       (i_interval) bool True False\n",
-       "  * split        (split) int32 0 1 2\n",
-       "    traintest    (split, anchor_year) <U6 'train' 'test' ... 'train' 'train'\n",
-       "Data variables:\n",
-       "    data1        (anchor_year, i_interval) float64 0.6396 0.3708 ... 0.556
" - ], - "text/plain": [ - "\n", - "Dimensions: (anchor_year: 5, i_interval: 2, split: 3)\n", - "Coordinates:\n", - " index (anchor_year, i_interval) int64 0 1 2 3 4 5 6 7 8 9\n", - " interval (anchor_year, i_interval) object (2016-04-18, 2016-10-15] .....\n", - " * anchor_year (anchor_year) int64 2016 2017 2018 2019 2020\n", - " * i_interval (i_interval) int64 0 1\n", - " target (i_interval) bool True False\n", - " * split (split) int32 0 1 2\n", - " traintest (split, anchor_year) None: + self.splitter = splitter + + def split( + self, + *x_args: xr.DataArray, + y: Optional[xr.DataArray] = None, + dim: str = "anchor_year" + ) -> XMaybeY: + """Iterate over splits. + + Args: + x_args: one or multiple xr.DataArray's that share the same + coordinate along the given dimension + y: (optional) xr.DataArray that shares the same coordinate along the + given dimension + dim: name of the dimension along which to split the data. + + Returns: + Iterator over the splits + """ + # Check that all inputs share the same dim coordinate + coords = [] + x: xr.DataArray # Initialize x to set scope outside loop + for x in x_args: + try: + coords.append(x[dim]) + except KeyError as err: + raise CoordinateMismatch( + f"Not all input data arrays have the {dim} dimension." + ) from err + + if not _all_equal(coords): + raise CoordinateMismatch( + f"Input arrays are not equal along {dim} dimension." + ) + + if y is not None and not np.array_equal(y[dim], x[dim]): + raise CoordinateMismatch( + f"Input arrays are not equal along {dim} dimension." + ) + + if x[dim].size <=1: + raise ValueError( + f"Invalid input: need at least 2 values along dimension {dim}" + ) + + # Now we know that all inputs are equal. + for (train_indices, test_indices) in self.splitter.split(x[dim]): + if len(x_args) == 1: + x_train: XType = x.isel({dim: train_indices}) + x_test: XType = x.isel({dim: test_indices}) + else: + x_train = [da.isel({dim: train_indices}) for da in x_args] + x_test = [da.isel({dim: test_indices}) for da in x_args] + + if y is None: + yield x_train, x_test + else: + y_train = y.isel({dim: train_indices}) + y_test = y.isel({dim: test_indices}) + yield x_train, x_test, y_train, y_test diff --git a/tests/test_traintest.py b/tests/test_traintest.py index 4f7e60a..1ae509f 100644 --- a/tests/test_traintest.py +++ b/tests/test_traintest.py @@ -3,68 +3,94 @@ import numpy as np import pandas as pd import pytest +import xarray as xr from sklearn.model_selection import KFold +import s2spy.time import s2spy.traintest -from s2spy.time import AdventCalendar -from s2spy.time import resample - - -class TestTrainTest: - # Define all required inputs as fixtures: - @pytest.fixture(autouse=True) - def dummy_calendar(self): - return AdventCalendar(anchor=(10, 15), freq="180d") - - @pytest.fixture(autouse=True) - def dummy_dataframe(self): - time_index = pd.date_range("20181020", "20211001", freq="60d") - test_data = np.random.random(len(time_index)) - return pd.DataFrame(test_data, index=time_index, columns=["data1"]) - - @pytest.fixture(autouse=True) - def dummy_dataset(self, dummy_dataframe): - return dummy_dataframe.to_xarray().rename({"index": "time"}) - - @pytest.fixture(autouse=True) - def dummy_dataframe_short(self): - time_index = pd.date_range("20191020", "20211001", freq="60d") - test_data = np.random.random(len(time_index)) - return pd.DataFrame(test_data, index=time_index, columns=["data1"]) - - @pytest.fixture(autouse=True) - def dummy_dataset_short(self, dummy_dataframe_short): - return dummy_dataframe_short.to_xarray().rename({"index": "time"}) - - def test_kfold_df(self, dummy_calendar, dummy_dataframe): - mapped_calendar = dummy_calendar.map_to_data(dummy_dataframe) - df = resample(mapped_calendar, dummy_dataframe) - df = s2spy.traintest.split_groups(KFold(n_splits=2), df) - expected_group = ["test", "test", "train", "train"] - assert np.array_equal(df["split_0"].values, expected_group) - - def test_kfold_ds(self, dummy_calendar, dummy_dataset): - mapped_calendar = dummy_calendar.map_to_data(dummy_dataset) - ds = resample(mapped_calendar, dummy_dataset) - ds = s2spy.traintest.split_groups(KFold(n_splits=2), ds) - expected_group = ["test", "train"] - assert np.array_equal(ds.traintest.values[0], expected_group) - - def test_kfold_df_short(self, dummy_calendar, dummy_dataframe_short): - "Should fail as there is only a single anchor year: no splits can be made" - mapped_calendar = dummy_calendar.map_to_data(dummy_dataframe_short) - df = resample(mapped_calendar, dummy_dataframe_short) - with pytest.raises(ValueError): - df = s2spy.traintest.split_groups(KFold(n_splits=2), df) - - def test_kfold_ds_short(self, dummy_calendar, dummy_dataset_short): - "Should fail as there is only a single anchor year: no splits can be made" - mapped_calendar = dummy_calendar.map_to_data(dummy_dataset_short) - ds = resample(mapped_calendar, dummy_dataset_short) - with pytest.raises(ValueError): - ds = s2spy.traintest.split_groups(KFold(n_splits=2), ds) - - def test_alternative_key(self, dummy_calendar, dummy_dataset): - mapped_calendar = dummy_calendar.map_to_data(dummy_dataset) - ds = resample(mapped_calendar, dummy_dataset) - ds = s2spy.traintest.split_groups(KFold(n_splits=2), ds, key="i_interval") - assert "i_interval" in ds.traintest.dims + + +@pytest.fixture(autouse=True) +def dummy_data(): + # Generate random data + n = 50 + time_index = pd.date_range("20151020", periods=n, freq="60d") + time_coord = {"time": time_index} + x1 = xr.DataArray(np.random.randn(n), coords=time_coord, name="x1") + x2 = xr.DataArray(np.random.randn(n), coords=time_coord, name="x2") + y = xr.DataArray(np.random.randn(n), coords=time_coord, name="y") + + # Map data to calendar and store for later reference + calendar = s2spy.time.AdventCalendar(anchor=(10, 15), freq="180d") + calendar.map_to_data(x1) + x1 = s2spy.time.resample(calendar, x1) + x2 = s2spy.time.resample(calendar, x2) + y = s2spy.time.resample(calendar, y) + return x1, x2, y + + +def test_kfold_x(dummy_data): + """Correctly split x.""" + x1, _, _ = dummy_data + cv = s2spy.traintest.TrainTestSplit(KFold(n_splits=3)) + x_train, x_test = next(cv.split(x1)) + expected_train = [2019, 2020, 2021, 2022, 2023] + expected_test = [2016, 2017, 2018] + assert np.array_equal(x_train.anchor_year, expected_train) + xr.testing.assert_equal(x_test, x1.sel(anchor_year=expected_test)) + + +def test_kfold_xy(dummy_data): + """Correctly split x and y.""" + x1, _, y = dummy_data + cv = s2spy.traintest.TrainTestSplit(KFold(n_splits=3)) + x_train, x_test, y_train, y_test = next(cv.split(x1, y=y)) + expected_train = [2019, 2020, 2021, 2022, 2023] + expected_test = [2016, 2017, 2018] + + assert np.array_equal(x_train.anchor_year, expected_train) + xr.testing.assert_equal(x_test, x1.sel(anchor_year=expected_test)) + assert np.array_equal(y_train.anchor_year, expected_train) + xr.testing.assert_equal(y_test, y.sel(anchor_year=expected_test)) + + +def test_kfold_xxy(dummy_data): + """Correctly split x1, x2, and y.""" + x1, x2, y = dummy_data + cv = s2spy.traintest.TrainTestSplit(KFold(n_splits=3)) + x_train, x_test, y_train, y_test = next(cv.split(x1, x2, y=y)) + expected_train = [2019, 2020, 2021, 2022, 2023] + expected_test = [2016, 2017, 2018] + + assert np.array_equal(x_train[0].anchor_year, expected_train) + xr.testing.assert_equal(x_test[1], x2.sel(anchor_year=expected_test)) + assert np.array_equal(y_train.anchor_year, expected_train) + xr.testing.assert_equal(y_test, y.sel(anchor_year=expected_test)) + + +def test_kfold_too_short(dummy_data): + "Fail if there is only a single anchor year: no splits can be made" + x1, _, _ = dummy_data + x = x1.isel(anchor_year=1) + cv = s2spy.traintest.TrainTestSplit(KFold(n_splits=3)) + + with pytest.raises(ValueError): + next(cv.split(x)) + + +def test_kfold_different_xcoords(dummy_data): + x1, x2, _ = dummy_data + x1 = x1.isel(anchor_year=slice(1, None, None)) + cv = s2spy.traintest.TrainTestSplit(KFold(n_splits=3)) + + with pytest.raises(s2spy.traintest.CoordinateMismatch): + next(cv.split(x1, x2)) + + +def test_custom_dim(dummy_data): + x1, _, _ = dummy_data + x = x1.rename(anchor_year="custom_coord") + cv = s2spy.traintest.TrainTestSplit(KFold(n_splits=3)) + x_train, _ = next(cv.split(x, dim="custom_coord")) + expected_train = [2019, 2020, 2021, 2022, 2023] + + assert np.array_equal(x_train.custom_coord, expected_train)