Consider extending traintest to mimic sklearn splitter classes #46
The current functionality of `traintest` is to explicitly add the created splits to a given dataframe/array. We could keep that as a separate method. With respect to the iterator discussed in AI4S2S/s2spy#71, we could have something like:

```python
splitter = sklearn.KFold(...)
traintest = s2spy.TrainTest(splitter)
for train, test in traintest.iterate(X, y):
    # do stuff
```
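A minimal sketch of how such a wrapper could work, assuming it simply delegates index generation to the sklearn splitter. The class name `TrainTest` and the `iterate` method come from the snippet above; the internals are an assumption, not an actual s2spy implementation:

```python
import numpy as np


class TrainTest:
    """Wrap an sklearn-style splitter and yield data subsets directly."""

    def __init__(self, splitter):
        self.splitter = splitter  # e.g. sklearn.model_selection.KFold(...)

    def iterate(self, X, y):
        # The sklearn splitter produces index arrays; here we return the
        # corresponding data subsets instead of bare indices.
        for train_idx, test_idx in self.splitter.split(X, y):
            yield (X[train_idx], y[train_idx]), (X[test_idx], y[test_idx])
```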
Very good suggestion. This looks quite clean and logical. It would be great to have a function like this.
Just some thoughts based on our discussion after exploring common ML practice for time series. From a user perspective, this is something I think is quite logical, and it is also consistent with my experience of using other ML packages for cross-validation and model training:

```python
import s2spy.time
import s2spy.traintest
import xarray as xr
from s2spy import RGDR

# Assume that I want to explore the causal relation between sea surface
# temperature (SST) and changes in the Atlantic Meridional Overturning
# Circulation (AMOC), and use SST to predict AMOC.

# Load data
sst = xr.open_dataset("sst_field_from_2010_to_2020.nc")  # daily data [time, lat, lon]
amoc = xr.open_dataset("amoc_rapid_array_obs_from_2010_to_2020.nc")  # daily data [time]

# Create a calendar using s2spy, based on the timescales of interest
calendar = s2spy.time.AdventCalendar(anchor=(10, 15), freq="180d")

# Map the calendar to the data
calendar.map_to_data(sst)

# Resample the data to the preferred timescales
sst_resample = s2spy.time.resample(calendar, sst)
amoc_resample = s2spy.time.resample(calendar, amoc)

######################## cross-validation ###########################
# Train/test splits using k-fold
from sklearn.model_selection import KFold

splitter = KFold(n_splits=3)
traintest_splits = s2spy.traintest.split_groups(splitter, calendar)  # here we make `traintest_splits` a class

# Add labels to the data if the user wants an overview of the split/data pairing
sst_traintest_summary = traintest_splits.add_label(sst_resample)

# Cross-validation, using a linear regression model and MSE as the metric
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

scores = []
for train_data, test_data in traintest_splits.iterate((sst_resample, amoc_resample)):
    sst_train, amoc_train = train_data
    sst_test, amoc_test = test_data

    # Perform dimensionality reduction using RGDR
    rgdr = RGDR(amoc_train, eps_km=600, alpha=0.05, min_area_km2=3000**2)
    sst_clustered_train = rgdr.fit(sst_train)

    # Train the model
    sst_X_train = sst_clustered_train.sel(target="False")
    amoc_y_train = amoc_train.sel(target="True")
    # Note: sklearn's deprecated `normalize` argument is omitted here;
    # standardize the inputs beforehand if needed.
    model_ols = LinearRegression()
    model_ols.fit(sst_X_train, amoc_y_train)

    # Apply the fitted clusters to the test data
    sst_clustered_test = rgdr.transform(sst_test)
    sst_X_test = sst_clustered_test.sel(target="False")
    amoc_y_test = amoc_test.sel(target="True")

    # Make predictions on the test data
    predict_amoc = model_ols.predict(sst_X_test)

    # Calculate the score with MSE
    scores.append(mean_squared_error(amoc_y_test, predict_amoc))
######################## cross-validation ###########################

# Plot the scores to check the cross-validation results
import matplotlib.pyplot as plt

plt.plot(scores)
```

I will explain the code in the posts below (it gets a bit too long...).
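To make the proposal above more concrete, here is a minimal sketch of what the object returned by `split_groups` could look like. Only the names `split_groups`, `add_label`, and `iterate` come from the example; the class name, the constructor signature, and the assumption that resampled data carry an `anchor_year` coordinate are hypothetical:

```python
import numpy as np


class TrainTestSplit:
    """Hypothetical sketch of the class returned by s2spy.traintest.split_groups."""

    def __init__(self, splitter, anchor_years):
        # `anchor_years` would be derived from the calendar; splits are made
        # over whole anchor years, so rows of one anchor year stay together.
        self.splitter = splitter
        self.anchor_years = np.asarray(anchor_years)

    def add_label(self, data):
        """Return `data` with one coordinate per split marking train/test membership."""
        for i, (train_idx, _) in enumerate(self.splitter.split(self.anchor_years)):
            in_train = np.isin(data["anchor_year"], self.anchor_years[train_idx])
            labels = np.where(in_train, "train", "test")
            data = data.assign_coords({f"split_{i}": ("anchor_year", labels)})
        return data

    def iterate(self, datasets):
        """Yield (train_data, test_data) tuples of the given datasets per fold."""
        for train_idx, test_idx in self.splitter.split(self.anchor_years):
            train = tuple(ds.sel(anchor_year=self.anchor_years[train_idx]) for ds in datasets)
            test = tuple(ds.sel(anchor_year=self.anchor_years[test_idx]) for ds in datasets)
            yield train, test
```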
A few concerns relate to the workflow above:

We could make the code in the
Currently, our train-test splitting function mostly serves to show the result of train-test splitting. Though the output datasets can be used directly, anything beyond that, e.g. cross-validation, would require a custom workflow.

In the future, it might be useful to rework it in the form of a class, similar to sklearn's existing splitter classes. The main feature we'd add is that we're a bit more restrictive in how groups are made; e.g. we don't allow splitting up rows from the same anchor year.

Then it would be possible to use it in conjunction with existing cross-validation code, e.g. sklearn's `cross_validate`. Something like the sketch below:
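A minimal sketch of such a splitter, assuming anchor years are passed through sklearn's `groups` mechanism. The class name `AnchorYearKFold` is hypothetical; this is essentially what sklearn's own `GroupKFold` already does, so the point of a dedicated class would mainly be to enforce the grouping:

```python
import numpy as np
from sklearn.model_selection import KFold, cross_validate
from sklearn.linear_model import LinearRegression


class AnchorYearKFold:
    """K-fold splitter that keeps all rows of an anchor year together."""

    def __init__(self, n_splits=5):
        self.n_splits = n_splits

    def get_n_splits(self, X=None, y=None, groups=None):
        return self.n_splits

    def split(self, X, y=None, groups=None):
        # `groups` holds the anchor year of every row in X; folds are made
        # over the unique years, then mapped back to row indices.
        years = np.unique(groups)
        for train_y, test_y in KFold(n_splits=self.n_splits).split(years):
            yield (
                np.flatnonzero(np.isin(groups, years[train_y])),
                np.flatnonzero(np.isin(groups, years[test_y])),
            )


# Usage with sklearn's existing cross-validation machinery, given arrays
# X, y and a per-row `anchor_years` array:
# results = cross_validate(LinearRegression(), X, y,
#                          groups=anchor_years, cv=AnchorYearKFold(n_splits=3))
```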