diff --git a/autora/skl/darts.py b/autora/skl/darts.py index 7bc3f794d..57675cec5 100644 --- a/autora/skl/darts.py +++ b/autora/skl/darts.py @@ -12,7 +12,7 @@ import torch.utils.data import tqdm from matplotlib import pyplot as plt -from sklearn.base import BaseEstimator, RegressorMixin +from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin from sklearn.utils.validation import check_array, check_is_fitted, check_X_y from autora.theorist.darts import ( @@ -42,6 +42,7 @@ "probability", "probability_sample", "probability_distribution", + "class", ] @@ -126,12 +127,18 @@ def _general_darts( _logger.info("Starting fit initialization") - data_loader, input_dimensions, output_dimensions = _get_data_loader( + data_loader = _get_data_loader( X=X, y=y, batch_size=batch_size, ) + input_dimensions = X.shape[1] + if output_type == ValueType.CLASS: + output_dimensions = len(np.unique(y)) + else: + output_dimensions = y.shape[1] + criterion = get_loss_function(ValueType(output_type)) output_function = get_output_format(ValueType(output_type)) @@ -320,9 +327,6 @@ def _get_data_loader( if y_.ndim == 1: y_ = y_.reshape((y_.size, 1)) - input_dimensions = X_.shape[1] - output_dimensions = y_.shape[1] - experimental_data = darts_dataset_from_ndarray(X_, y_) data_loader = torch.utils.data.DataLoader( @@ -332,7 +336,7 @@ def _get_data_loader( pin_memory=True, num_workers=0, ) - return data_loader, input_dimensions, output_dimensions + return data_loader def _get_data_iterator(data_loader: torch.utils.data.DataLoader) -> Iterator: @@ -454,7 +458,7 @@ class DARTSRegressor(BaseEstimator, RegressorMixin): >>> estimator = DARTSRegressor(num_graph_nodes=1) >>> estimator = estimator.fit(X, y) >>> estimator.predict([[0.5]]) - array([[15.051043]], dtype=float32) + array([[15.051043]]) Attributes: @@ -776,6 +780,209 @@ def model_repr( return model_repr_ +class DARTSClassifier(DARTSRegressor, ClassifierMixin): + """ + Differentiable ARchiTecture Search Classifier. 
+ + DARTS finds a composition of functions and coefficients to minimize a loss function suitable for + the dependent variable. + + This class is intended to be compatible with the + [Scikit-Learn Estimator API](https://scikit-learn.org/stable/developers/develop.html). + + Examples: + + TODO: add example + + + Attributes: + network_: represents the optimized network for the architecture search, without the + output function + model_: represents the best-fit model including the output function + after sampling of the network to pick a single computation graph. + By default, this is the computation graph with the maximum weights, + but can be set to a graph based on a sample of the edge weights + by running the `resample_model(sample_strategy="sample")` method. + It can be reset by running the `resample_model(sample_strategy="max")` method. + + + + """ + + def __init__( + self, + batch_size: int = 64, + num_graph_nodes: int = 2, + output_type: IMPLEMENTED_OUTPUT_TYPES = "class", + classifier_weight_decay: float = 1e-2, + darts_type: IMPLEMENTED_DARTS_TYPES = "original", + init_weights_function: Optional[Callable] = None, + param_updates_per_epoch: int = 10, + param_updates_for_sampled_model: int = 100, + param_learning_rate_max: float = 2.5e-2, + param_learning_rate_min: float = 0.01, + param_momentum: float = 9e-1, + param_weight_decay: float = 3e-4, + arch_updates_per_epoch: int = 1, + arch_learning_rate_max: float = 3e-3, + arch_weight_decay: float = 1e-4, + arch_weight_decay_df: float = 3e-4, + arch_weight_decay_base: float = 0.0, + arch_momentum: float = 9e-1, + fair_darts_loss_weight: int = 1, + max_epochs: int = 10, + grad_clip: float = 5, + primitives: Sequence[str] = PRIMITIVES, + train_classifier_coefficients: bool = False, + train_classifier_bias: bool = False, + execution_monitor: Callable = (lambda *args, **kwargs: None), + sampling_strategy: SAMPLING_STRATEGIES = "max", + ) -> None: + """ + Initializes the DARTSClassifier. 
+ + Arguments: + batch_size: Batch size for the data loader. + num_graph_nodes: Number of nodes in the desired computation graph. + output_type: Type of output function to use. This function is applied to transform + the output of the mixture architecture. + classifier_weight_decay: Weight decay for the classifier. + darts_type: Type of DARTS to use ('original' or 'fair'). + init_weights_function: Function to initialize the parameters of each operation. + param_updates_per_epoch: Number of updates to perform per epoch + for the operation parameters. + param_learning_rate_max: Initial (maximum) learning rate for the operation parameters. + param_learning_rate_min: Final (minimum) learning rate for the operation parameters. + param_momentum: Momentum for the operation parameters. + param_weight_decay: Weight decay for the operation parameters. + arch_updates_per_epoch: Number of architecture weight updates to perform per epoch. + arch_learning_rate_max: Initial (maximum) learning rate for the architecture. + arch_weight_decay: Weight decay for the architecture weights. + arch_weight_decay_df: An additional weight decay that scales with the number of + parameters (degrees of freedom) in the operation. The higher this weight decay, + the more DARTS will prefer simple operations. + arch_weight_decay_base: A base weight decay that is added to the scaled weight decay. + arch_momentum: Momentum for the architecture weights. + fair_darts_loss_weight: Weight of the loss in fair darts which forces architecture + weights to become either 0 or 1. + max_epochs: Maximum number of epochs to train for. + grad_clip: Gradient clipping value for updating the parameters of the operations. + primitives: List of primitives (operations) to use. + train_classifier_coefficients: Whether to train the coefficients of the classifier. + train_classifier_bias: Whether to train the bias of the classifier. + execution_monitor: Function to monitor the execution of the model. 
+ primitives: list of primitive operations used in the DARTS network, + e.g., 'add', 'subtract', 'none'. For details, see + [`autora.theorist.darts.operations`][autora.theorist.darts.operations] + """ + + self.batch_size = batch_size + + self.num_graph_nodes = num_graph_nodes + self.classifier_weight_decay = classifier_weight_decay + self.darts_type = darts_type + self.init_weights_function = init_weights_function + + self.param_updates_per_epoch = param_updates_per_epoch + self.param_updates_for_sampled_model = param_updates_for_sampled_model + + self.param_learning_rate_max = param_learning_rate_max + self.param_learning_rate_min = param_learning_rate_min + self.param_momentum = param_momentum + self.arch_momentum = arch_momentum + self.param_weight_decay = param_weight_decay + + self.arch_updates_per_epoch = arch_updates_per_epoch + self.arch_weight_decay = arch_weight_decay + self.arch_weight_decay_df = arch_weight_decay_df + self.arch_weight_decay_base = arch_weight_decay_base + self.arch_learning_rate_max = arch_learning_rate_max + self.fair_darts_loss_weight = fair_darts_loss_weight + + self.max_epochs = max_epochs + self.grad_clip = grad_clip + + self.primitives = primitives + + self.output_type = output_type + self.darts_type = darts_type + + self.X_: Optional[np.ndarray] = None + self.y_: Optional[np.ndarray] = None + self.network_: Optional[Network] = None + self.model_: Optional[Network] = None + + self.train_classifier_coefficients = train_classifier_coefficients + self.train_classifier_bias = train_classifier_bias + + self.execution_monitor = execution_monitor + + self.sampling_strategy = sampling_strategy + + def fit(self, X: np.ndarray, y: np.ndarray): + """ + Runs the optimization for a given set of `X`s and `y`s. 
+ + Arguments: + X: independent variables in an n-dimensional array + y: dependent variables in an n-dimensional array + + Returns: + self (DARTSClassifier): the fitted estimator + """ + + params = self.get_params() + + self.X_ = X + self.y_ = y + + fit_results = _general_darts( + X=self.X_, y=self.y_, network=self.network_, **params + ) + self.network_ = fit_results.network + self.model_ = fit_results.model + return self + + def predict(self, X: np.ndarray) -> np.ndarray: + """ + Applies the fitted model to a set of independent variables `X`, + to give predictions for the dependent variable `y`. + + Arguments: + X: independent variables in an n-dimensional array + + Returns: + y: predicted class index (argmax over the model outputs) for each row of `X` + """ + probabilities = self.predict_proba(X) + classes = np.argmax(probabilities, axis=1) + y = classes + + return y + + def predict_proba(self, X: np.ndarray) -> np.ndarray: + """ + Applies the fitted model to a set of independent variables `X`, + to give the model outputs from which `predict` takes the argmax. + + Arguments: + X: independent variables in an n-dimensional array + + Returns: + y: output of the fitted model for each row of `X`, as a numpy array + """ + X_ = check_array(X) + + check_is_fitted(self, attributes=["model_"]) + + assert self.model_ is not None + + probabilities = self.model_(torch.as_tensor(X_)) + y = probabilities.detach().numpy() + + return y + + +class DARTSExecutionMonitor: """ A monitor of the execution of the DARTS algorithm. 
diff --git a/autora/theorist/darts/dataset.py b/autora/theorist/darts/dataset.py index d9ae8690d..111cd1ee4 100644 --- a/autora/theorist/darts/dataset.py +++ b/autora/theorist/darts/dataset.py @@ -66,7 +66,7 @@ def darts_dataset_from_ndarray( """ obj = DARTSDataset( - torch.tensor(input_data, dtype=torch.get_default_dtype()), - torch.tensor(output_data, dtype=torch.get_default_dtype()), + torch.tensor(input_data), + torch.tensor(output_data), ) return obj diff --git a/tests/test_darts_classifier.py b/tests/test_darts_classifier.py new file mode 100644 index 000000000..ae3a4fe02 --- /dev/null +++ b/tests/test_darts_classifier.py @@ -0,0 +1,32 @@ +import pytest +import torch +from sklearn.datasets import make_classification +from sklearn.gaussian_process import GaussianProcessClassifier +from sklearn.model_selection import train_test_split + +from autora.skl.darts import DARTSClassifier + +torch.set_default_dtype(torch.double) + + +@pytest.fixture +def classification_data(): + x, y = make_classification(random_state=180) + return x, y + + +def test_darts_classifier(classification_data): + x, y = classification_data + + x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=181) + for classifier in [GaussianProcessClassifier(), DARTSClassifier()]: + classifier.fit(x_train, y_train) + + predictions = classifier.predict(x_test) + assert predictions is not None + + prediction_probabilities = classifier.predict_proba(x_test) + assert prediction_probabilities is not None + + score = classifier.score(x_test, y_test) + print(f"\n{classifier=} {score=}")