Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add DARTSClassifier #164

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
4d937d9
refactor: remove default_dtype setting from darts dataset
hollandjg Dec 6, 2022
92b1eeb
fix: set default dtype for torch to float64
hollandjg Dec 6, 2022
c0a5080
feat: add outline of DARTSClassifier
hollandjg Dec 6, 2022
605748d
fix: convert OPS dictionary into a factory function
hollandjg Dec 6, 2022
a0696a8
docs: fix datatype in DARTSRegressor docstring
hollandjg Dec 6, 2022
5edcdf7
test: add testcase for DARTSClassifier
hollandjg Dec 6, 2022
34472c7
refactor: remove cast in fit method on Classifier
hollandjg Dec 6, 2022
ad4d24f
Revert "fix: set default dtype for torch to float64"
hollandjg Dec 6, 2022
4d4a85f
feat: add predict and predict_proba to DARTSClassifier
hollandjg Dec 6, 2022
1518bfc
refactor: use predict_proba within predict
hollandjg Dec 6, 2022
da26a36
test: update testcases
hollandjg Dec 6, 2022
6c396c8
Merge branch 'main' into feat-darts-classifier
hollandjg Dec 6, 2022
70b15e4
test: fix testcase
hollandjg Dec 6, 2022
e7dcd4d
Merge branch 'main' into feat-darts-classifier
hollandjg Dec 7, 2022
31eb03a
Merge branch 'main' into feat-darts-classifier
hollandjg Dec 8, 2022
23bafea
Merge branch 'main' into feat-darts-classifier
hollandjg Dec 13, 2022
f3bbc63
Merge branch 'main' into feat-darts-classifier
hollandjg Dec 14, 2022
d02b1e1
Update autora/skl/darts.py
hollandjg Dec 15, 2022
c33bc9a
Update autora/skl/darts.py
hollandjg Dec 15, 2022
d661147
Merge branch 'main' into feat-darts-classifier
hollandjg Dec 15, 2022
70b4f3e
Merge branch 'main' into feat-darts-classifier
hollandjg Dec 15, 2022
547a23f
Merge branch 'main' into feat-darts-classifier
hollandjg Jan 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 214 additions & 7 deletions autora/skl/darts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch.utils.data
import tqdm
from matplotlib import pyplot as plt
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y

from autora.theorist.darts import (
Expand Down Expand Up @@ -42,6 +42,7 @@
"probability",
"probability_sample",
"probability_distribution",
"class",
]


Expand Down Expand Up @@ -126,12 +127,18 @@ def _general_darts(

_logger.info("Starting fit initialization")

data_loader, input_dimensions, output_dimensions = _get_data_loader(
data_loader = _get_data_loader(
X=X,
y=y,
batch_size=batch_size,
)

input_dimensions = X.shape[1]
if output_type == ValueType.CLASS:
output_dimensions = len(np.unique(y))
else:
output_dimensions = y.shape[1]

criterion = get_loss_function(ValueType(output_type))
output_function = get_output_format(ValueType(output_type))

Expand Down Expand Up @@ -320,9 +327,6 @@ def _get_data_loader(
if y_.ndim == 1:
y_ = y_.reshape((y_.size, 1))

input_dimensions = X_.shape[1]
output_dimensions = y_.shape[1]

experimental_data = darts_dataset_from_ndarray(X_, y_)

data_loader = torch.utils.data.DataLoader(
Expand All @@ -332,7 +336,7 @@ def _get_data_loader(
pin_memory=True,
num_workers=0,
)
return data_loader, input_dimensions, output_dimensions
return data_loader


def _get_data_iterator(data_loader: torch.utils.data.DataLoader) -> Iterator:
Expand Down Expand Up @@ -454,7 +458,7 @@ class DARTSRegressor(BaseEstimator, RegressorMixin):
>>> estimator = DARTSRegressor(num_graph_nodes=1)
>>> estimator = estimator.fit(X, y)
>>> estimator.predict([[0.5]])
array([[15.051043]], dtype=float32)
array([[15.051043]])


Attributes:
Expand Down Expand Up @@ -776,6 +780,209 @@ def model_repr(
return model_repr_


class DARTSClassifier(DARTSRegressor, ClassifierMixin):
"""
Differentiable ARchiTecture Search Classifier.

DARTS finds a composition of functions and coefficients to minimize a loss function suitable for
the dependent variable.

This class is intended to be compatible with the
[Scikit-Learn Estimator API](https://scikit-learn.org/stable/developers/develop.html).

Examples:

TODO: add example


Attributes:
network_: represents the optimized network for the architecture search, without the
output function
model_: represents the best-fit model including the output function
after sampling of the network to pick a single computation graph.
By default, this is the computation graph with the maximum weights,
but can be set to a graph based on a sample on the edge weights
hollandjg marked this conversation as resolved.
Show resolved Hide resolved
by running the `resample_model(sample_strategy="sample")` method.
It can be reset by running the `resample_model(sample_strategy="max")` method.



"""

def __init__(
self,
batch_size: int = 64,
num_graph_nodes: int = 2,
output_type: IMPLEMENTED_OUTPUT_TYPES = "class",
classifier_weight_decay: float = 1e-2,
darts_type: IMPLEMENTED_DARTS_TYPES = "original",
init_weights_function: Optional[Callable] = None,
param_updates_per_epoch: int = 10,
param_updates_for_sampled_model: int = 100,
param_learning_rate_max: float = 2.5e-2,
param_learning_rate_min: float = 0.01,
param_momentum: float = 9e-1,
param_weight_decay: float = 3e-4,
arch_updates_per_epoch: int = 1,
arch_learning_rate_max: float = 3e-3,
arch_weight_decay: float = 1e-4,
arch_weight_decay_df: float = 3e-4,
arch_weight_decay_base: float = 0.0,
arch_momentum: float = 9e-1,
fair_darts_loss_weight: int = 1,
max_epochs: int = 10,
grad_clip: float = 5,
primitives: Sequence[str] = PRIMITIVES,
train_classifier_coefficients: bool = False,
train_classifier_bias: bool = False,
execution_monitor: Callable = (lambda *args, **kwargs: None),
sampling_strategy: SAMPLING_STRATEGIES = "max",
) -> None:
"""
Initializes the DARTSRegressor.

Arguments:
batch_size: Batch size for the data loader.
num_graph_nodes: Number of nodes in the desired computation graph.
output_type: Type of output function to use. This function is applied to transform
the output of the mixture architecture.
classifier_weight_decay: Weight decay for the classifier.
darts_type: Type of DARTS to use ('original' or 'fair').
init_weights_function: Function to initialize the parameters of each operation.
param_updates_per_epoch: Number of updates to perform per epoch.
hollandjg marked this conversation as resolved.
Show resolved Hide resolved
for the operation parameters.
param_learning_rate_max: Initial (maximum) learning rate for the operation parameters.
param_learning_rate_min: Final (minimum) learning rate for the operation parameters.
param_momentum: Momentum for the operation parameters.
param_weight_decay: Weight decay for the operation parameters.
arch_updates_per_epoch: Number of architecture weight updates to perform per epoch.
arch_learning_rate_max: Initial (maximum) learning rate for the architecture.
arch_weight_decay: Weight decay for the architecture weights.
arch_weight_decay_df: An additional weight decay that scales with the number of
parameters (degrees of freedom) in the operation. The higher this weight decay,
the more DARTS will prefer simple operations.
arch_weight_decay_base: A base weight decay that is added to the scaled weight decay.
arch_momentum: Momentum for the architecture weights.
fair_darts_loss_weight: Weight of the loss in fair darts which forces architecture
weights to become either 0 or 1.
max_epochs: Maximum number of epochs to train for.
grad_clip: Gradient clipping value for updating the parameters of the operations.
primitives: List of primitives (operations) to use.
train_classifier_coefficients: Whether to train the coefficients of the classifier.
train_classifier_bias: Whether to train the bias of the classifier.
execution_monitor: Function to monitor the execution of the model.
primitives: list of primitive operations used in the DARTS network,
e.g., 'add', 'subtract', 'none'. For details, see
[`autora.theorist.darts.operations`][autora.theorist.darts.operations]
"""

self.batch_size = batch_size

self.num_graph_nodes = num_graph_nodes
self.classifier_weight_decay = classifier_weight_decay
self.darts_type = darts_type
self.init_weights_function = init_weights_function

self.param_updates_per_epoch = param_updates_per_epoch
self.param_updates_for_sampled_model = param_updates_for_sampled_model

self.param_learning_rate_max = param_learning_rate_max
self.param_learning_rate_min = param_learning_rate_min
self.param_momentum = param_momentum
self.arch_momentum = arch_momentum
self.param_weight_decay = param_weight_decay

self.arch_updates_per_epoch = arch_updates_per_epoch
self.arch_weight_decay = arch_weight_decay
self.arch_weight_decay_df = arch_weight_decay_df
self.arch_weight_decay_base = arch_weight_decay_base
self.arch_learning_rate_max = arch_learning_rate_max
self.fair_darts_loss_weight = fair_darts_loss_weight

self.max_epochs = max_epochs
self.grad_clip = grad_clip

self.primitives = primitives

self.output_type = output_type
self.darts_type = darts_type

self.X_: Optional[np.ndarray] = None
self.y_: Optional[np.ndarray] = None
self.network_: Optional[Network] = None
self.model_: Optional[Network] = None

self.train_classifier_coefficients = train_classifier_coefficients
self.train_classifier_bias = train_classifier_bias

self.execution_monitor = execution_monitor

self.sampling_strategy = sampling_strategy

def fit(self, X: np.ndarray, y: np.ndarray):
"""
Runs the optimization for a given set of `X`s and `y`s.

Arguments:
X: independent variables in an n-dimensional array
y: dependent variables in an n-dimensional array

Returns:
self (DARTSRegressor): the fitted estimator
"""

params = self.get_params()

self.X_ = X
self.y_ = y

fit_results = _general_darts(
X=self.X_, y=self.y_, network=self.network_, **params
)
self.network_ = fit_results.network
self.model_ = fit_results.model
return self

def predict(self, X: np.ndarray) -> np.ndarray:
"""
Applies the fitted model to a set of independent variables `X`,
to give predictions for the dependent variable `y`.

Arguments:
X: independent variables in an n-dimensional array

Returns:
y: predicted dependent variable values
"""
probabilities = self.predict_proba(X)
classes = np.argmax(probabilities, axis=1)
y = classes

return y

def predict_proba(self, X: np.ndarray) -> np.ndarray:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious about the use of the function name predict_proba instead of just predict_prob

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a scikit-learn naming convention. No idea why they chose it.

"""
Applies the fitted model to a set of independent variables `X`,
to give predictions for the dependent variable `y`.

Arguments:
X: independent variables in an n-dimensional array

Returns:
y: predicted dependent variable values
"""
X_ = check_array(X)

check_is_fitted(self, attributes=["model_"])

assert self.model_ is not None

probabilities = self.model_(torch.as_tensor(X_))
y = probabilities.detach().numpy()

return y


class DARTSExecutionMonitor:
"""
A monitor of the execution of the DARTS algorithm.
Expand Down
4 changes: 2 additions & 2 deletions autora/theorist/darts/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def darts_dataset_from_ndarray(
"""

obj = DARTSDataset(
torch.tensor(input_data, dtype=torch.get_default_dtype()),
torch.tensor(output_data, dtype=torch.get_default_dtype()),
torch.tensor(input_data),
torch.tensor(output_data),
)
return obj
32 changes: 32 additions & 0 deletions tests/test_darts_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest
import torch
from sklearn.datasets import make_classification
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.model_selection import train_test_split

from autora.skl.darts import DARTSClassifier

torch.set_default_dtype(torch.double)


@pytest.fixture
def classification_data():
x, y = make_classification(random_state=180)
return x, y


def test_darts_classifier(classification_data):
x, y = classification_data

x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=181)
for classifier in [GaussianProcessClassifier(), DARTSClassifier()]:
classifier.fit(x_train, y_train)

predictions = classifier.predict(x_test)
assert predictions is not None

prediction_probabilities = classifier.predict_proba(x_test)
assert prediction_probabilities is not None

score = classifier.score(x_test, y_test)
print(f"\n{classifier=} {score=}")