diff --git a/docs/examples.rst b/docs/examples.rst index d7d6a1e2..5101eba8 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -29,15 +29,13 @@ You can also find a simple and quick-start tutorial notebook on Google Colab # Data preprocessing. Tedious, but PyPOTS can help. 🤓 data = load_specific_dataset('physionet_2012') # PyPOTS will automatically download and extract it. - X = data['X'] - num_samples = len(X['RecordID'].unique()) - X = X.drop(['RecordID', 'Time'], axis = 1) - X = StandardScaler().fit_transform(X.to_numpy()) - X = X.reshape(num_samples, 48, -1) + X = data['train_X'] + num_samples = len(X) + X = StandardScaler().fit_transform(X.reshape(-1, X.shape[-1])).reshape(X.shape) X_ori = X # keep X_ori for validation X = mcar(X, 0.1) # randomly hold out 10% observed values as ground truth dataset = {"X": X} # X for model input - print(X.shape) # (11988, 48, 37), 11988 samples, 48 time steps, 37 features + print(X.shape) # (7671, 48, 37), 7671 samples, 48 time steps, 37 features # initialize the model saits = SAITS( @@ -55,7 +53,7 @@ You can also find a simple and quick-start tutorial notebook on Google Colab model_saving_strategy="best", # only save the model with the best validation performance ) - # train the model. Here I use the whole dataset as the training set, because ground truth is not visible to the model. + # train the model. Here I consider the train dataset only, and evaluate on it, because ground truth is not visible to the model. saits.fit(dataset) # impute the originally-missing values and artificially-missing values imputation = saits.impute(dataset) @@ -64,6 +62,6 @@ You can also find a simple and quick-start tutorial notebook on Google Colab mae = calc_mae(imputation, np.nan_to_num(X_ori), indicating_mask) # calculate mean absolute error on the ground truth (artificially-missing values) # the best model has been already saved, but you can still manually save it with function save_model() as below - saits.save_model(saving_dir="examples/saits",file_name="manually_saved_saits_model") + saits.save(saving_path="examples/saits/manually_saved_saits_model") # you can load the saved model into a new initialized model - saits.load_model("examples/saits/manually_saved_saits_model") + saits.load("examples/saits/manually_saved_saits_model.pypots")