From ee062539caad5051f84bb2d2f8ae6f2282b0467b Mon Sep 17 00:00:00 2001 From: Sebastian Musslick Date: Mon, 3 Jul 2023 11:22:44 -0400 Subject: [PATCH] added test for iterator conversion --- tests/test_exp_falsification_sampler.py | 47 +++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/test_exp_falsification_sampler.py b/tests/test_exp_falsification_sampler.py index 7368bf3..d1db082 100644 --- a/tests/test_exp_falsification_sampler.py +++ b/tests/test_exp_falsification_sampler.py @@ -4,6 +4,7 @@ from sklearn.linear_model import LinearRegression, LogisticRegression from autora.experimentalist.pipeline import Pipeline +from autora.experimentalist.pooler.grid import grid_pool from autora.experimentalist.sampler.falsification import ( falsification_sample, falsification_score_sample, @@ -276,6 +277,52 @@ def test_falsification_reconstruction_without_model( assert np.round(X_selected[0, 0], 4) == 1.8 or np.round(X_selected[0, 0], 4) == 4.8 assert np.round(X_selected[1, 0], 4) == 1.8 or np.round(X_selected[1, 0], 4) == 4.8 +def test_iterator_input(synthetic_linr_model): + # Import model and data_closed_loop + X_train, Y_train = get_sin_data() + model = synthetic_linr_model + + # specify meta data_closed_loop + + # Specify independent variables + iv = IV( + name="x", + value_range=(0, 2 * np.pi), + allowed_values=(np.linspace(0, 2 * np.pi, 100)), + units="intensity", + variable_label="stimulus", + ) + + # specify dependent variables + dv = DV( + name="y", + value_range=(-1, 1), + units="real", + variable_label="response", + type=ValueType.REAL, + ) + + # Variable collection with ivs and dvs + metadata = VariableCollection( + independent_variables=[iv], + dependent_variables=[dv], + ) + + X = grid_pool(metadata.independent_variables) + + new_conditions, new_scores = falsification_score_sample( + condition_pool=X, + model=model, + reference_conditions=X_train, + reference_observations=Y_train, + metadata=metadata, + num_samples=5, + training_epochs=1000, + training_lr=1e-3, + plot=False, + ) + + assert new_conditions.shape[0] == 5 def test_doc_example(): # Specify X and Y