Skip to content

Commit

Permalink
fixed multiclass and cross_entropy
Browse files Browse the repository at this point in the history
  • Loading branch information
myungkim930 committed Sep 2, 2024
1 parent b6bb302 commit d65d9b6
Showing 1 changed file with 71 additions and 37 deletions.
108 changes: 71 additions & 37 deletions src/carte_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@
from torch import Tensor
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch
from sklearn.model_selection import RepeatedKFold, RepeatedStratifiedKFold, ShuffleSplit, StratifiedShuffleSplit, ParameterGrid, train_test_split
from sklearn.model_selection import (
RepeatedKFold,
RepeatedStratifiedKFold,
ShuffleSplit,
StratifiedShuffleSplit,
ParameterGrid,
train_test_split,
)
from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin
from sklearn.utils.validation import check_is_fitted, check_random_state
from joblib import Parallel, delayed
Expand Down Expand Up @@ -100,7 +107,7 @@ def fit(self, X, y):
# Store the required results that may be used later
self.model_list_ = [model for (model, _) in result_fit]
self.valid_loss_ = [valid_loss for (_, valid_loss) in result_fit]
self.weights_ = np.array([1/self.num_model]*self.num_model)
self.weights_ = np.array([1 / self.num_model] * self.num_model)
self.is_fitted_ = True

return self
Expand Down Expand Up @@ -167,6 +174,8 @@ def _run_step(self, model, data, optimizer):
data.to(self.device_) # Send to device
out = model(data) # Perform a single forward pass.
target = data.y # Set target
if self.loss == "categorical_crossentropy":
target = target.to(torch.long)
if self.output_dim_ == 1:
out = out.view(-1).to(torch.float32) # Reshape output; set head index
target = target.to(torch.float32) # Reshape target
Expand All @@ -183,6 +192,8 @@ def _eval(self, model, ds_eval):
model.eval()
out = model(ds_eval)
target = ds_eval.y
if self.loss == "categorical_crossentropy":
target = target.to(torch.long)
if self.output_dim_ == 1:
out = out.view(-1).to(torch.float32)
target = target.to(torch.float32)
Expand All @@ -196,7 +207,7 @@ def _eval(self, model, ds_eval):

def _set_train_valid_split(self):
"""Train/validation split for the bagging strategy.
The style of split depends on the cross_validate parameter.
Returns the train/validation split with KFold cross-validation.
"""
Expand All @@ -206,23 +217,37 @@ def _set_train_valid_split(self):
n_splits = int(1 / self.val_size)
n_repeats = int(self.num_model / n_splits)
splitter = RepeatedKFold(
n_splits=n_splits, n_repeats=n_repeats, random_state=self.random_state,
n_splits=n_splits,
n_repeats=n_repeats,
random_state=self.random_state,
)
else:
splitter = ShuffleSplit(n_splits = self.num_model, test_size=self.val_size, random_state=self.random_state)
splitter = ShuffleSplit(
n_splits=self.num_model,
test_size=self.val_size,
random_state=self.random_state,
)
splits = [
(train_index, test_index)
for train_index, test_index in splitter.split(np.arange(0, len(self.X_)))
]
(train_index, test_index)
for train_index, test_index in splitter.split(
np.arange(0, len(self.X_))
)
]
else:
if self.cross_validate:
n_splits = int(1 / self.val_size)
n_repeats = int(self.num_model / n_splits)
splitter = RepeatedStratifiedKFold(
n_splits=n_splits, n_repeats=n_repeats, random_state=self.random_state,
n_splits=n_splits,
n_repeats=n_repeats,
random_state=self.random_state,
)
else:
splitter = StratifiedShuffleSplit(n_splits = self.num_model, test_size=self.val_size, random_state=self.random_state)
splitter = StratifiedShuffleSplit(
n_splits=self.num_model,
test_size=self.val_size,
random_state=self.random_state,
)
splits = [
(train_index, test_index)
for train_index, test_index in splitter.split(
Expand Down Expand Up @@ -257,6 +282,14 @@ def _generate_output(self, X, model_list, weights):
model(ds_predict_eval).cpu().detach().numpy() for model in model_list
]
out = np.array(out).squeeze().transpose()

# Transform according to loss
if self.loss == "categorical_crossentropy":
if len(model_list) != 1:
out = out.transpose((1, 2, 0))
else:
out = out.transpose()

if len(model_list) != 1:
out = np.average(out, weights=weights, axis=1)

Expand All @@ -273,8 +306,7 @@ def _generate_output(self, X, model_list, weights):
return out

def _set_task_specific_settings(self):
"""Set task specific settings for regression and classfication.
"""
"""Set task specific settings for regression and classfication."""

if self._estimator_type == "regressor":
if self.loss == "squared_error":
Expand All @@ -289,27 +321,24 @@ def _set_task_specific_settings(self):
self.valid_loss_flag_ = "neg"
self.output_dim_ = 1
elif self._estimator_type == "classifier":
self.classes_ = np.unique(self.y_)
if self.loss == "binary_crossentropy":
self.criterion_ = torch.nn.BCEWithLogitsLoss()
self.output_dim_ = 1
if self.scoring == "auroc":
self.valid_loss_metric_ = BinaryAUROC()
self.valid_loss_flag_ = "neg"
elif self.scoring == "binary_entropy":
self.valid_loss_metric_ = BinaryNormalizedEntropy(from_logits=True)
self.valid_loss_flag_ = "neg"
elif self.scoring == "auprc":
self.valid_loss_metric_ = BinaryAUPRC()
self.valid_loss_flag_ = "neg"
elif self.loss == "categorical_crossentropy":
self.criterion_ = torch.nn.CrossEntropyLoss()
self.output_dim_ = len(np.unique(self.y_))
if self.output_dim_ == 2:
self.output_dim_ -= 1
self.criterion_ = torch.nn.BCEWithLogitsLoss()
if self.scoring == "auroc":
self.valid_loss_metric_ = BinaryAUROC()
self.valid_loss_flag_ = "neg"
elif self.scoring == "binary_entropy":
self.valid_loss_metric_ = BinaryNormalizedEntropy(from_logits = True)
self.valid_loss_flag_ = "neg"
elif self.scoring == "auprc":
self.valid_loss_metric_ = BinaryAUPRC()
self.valid_loss_flag_ = "neg"
if self.loss == "categorical_crossentropy":
self.output_dim_ = len(np.unique(self.y_))
self.valid_loss_metric_ = MulticlassAUROC(num_classes=self.output_dim_)
self.valid_loss_flag_ = "neg"
self.classes_ = np.unique(self.y_)
self.valid_loss_metric_.to(self.device_)

def _load_model(self):
Expand All @@ -327,7 +356,7 @@ def _load_model(self):
model_config["hidden_dim"] = self.X_[0].x.size(1)
model_config["ff_dim"] = self.X_[0].x.size(1)
model_config["num_heads"] = 12
model_config["num_layers"] = self.num_layers-1
model_config["num_layers"] = self.num_layers - 1
model_config["output_dim"] = self.output_dim_
model_config["dropout"] = self.dropout

Expand Down Expand Up @@ -462,10 +491,10 @@ def predict(self, X):
y : ndarray, shape (n_samples,)
The predicted values.
"""

check_is_fitted(self, "is_fitted_")

out = self._generate_output(X=X, model_list = self.model_list_, weights=None)
out = self._generate_output(X=X, model_list=self.model_list_, weights=None)

return out

Expand Down Expand Up @@ -626,7 +655,7 @@ def _get_predict_prob(self, X):
The raw predicted values.
"""

out = self._generate_output(X=X, model_list = self.model_list_, weights=None)
out = self._generate_output(X=X, model_list=self.model_list_, weights=None)

return out

Expand Down Expand Up @@ -1043,8 +1072,8 @@ def __init__(
self.scoring = scoring

def predict(self, X):
"""Predict values for X.
"""Predict values for X.
Returns the weighted average of the singletable model and all pairwise model with 1-source.
Parameters
Expand Down Expand Up @@ -1237,12 +1266,15 @@ def _get_predict_prob(self, X):
model_list = [self.model_list_[idx] for idx in idx_]
out += [self._generate_output(X, model_list=model_list, weights=None)]
out = np.array(out).squeeze().transpose()

out = np.average(out, weights=self.weights_, axis=1)

# Transform according to loss
if self.loss == "binary_crossentropy":
out = 1 / (1 + np.exp(-out))
elif self.loss == "categorical_crossentropy":
out = softmax(out, axis=1)

# Control for nulls in prediction
if np.isnan(out).sum() > 0:
mean_pred = np.mean(self.y_)
Expand All @@ -1255,7 +1287,7 @@ class CARTE_AblationRegressor(CARTERegressor):
This estimator is GNN-based model compatible with the CARTE pretrained model.
Note that this is an implementation for the ablation study of CARTE
Parameters
----------
ablation_method : {'exclude-edge', 'exclude-attention', 'exclude-attention-edge'}, default='exclude-edge'
Expand Down Expand Up @@ -1299,6 +1331,7 @@ class CARTE_AblationRegressor(CARTERegressor):
disable_pbar : bool, default=True
Indicates whether to show progress bars for the training process.
"""

def __init__(
self,
*,
Expand Down Expand Up @@ -1361,7 +1394,7 @@ def _load_model(self):
model_config["hidden_dim"] = self.X_[0].x.size(1)
model_config["ff_dim"] = self.X_[0].x.size(1)
model_config["num_heads"] = 12
model_config["num_layers"] = self.num_layers-1
model_config["num_layers"] = self.num_layers - 1
model_config["output_dim"] = self.output_dim_
model_config["dropout"] = self.dropout

Expand Down Expand Up @@ -1398,7 +1431,7 @@ class CARTE_AblationClassifier(CARTEClassifier):
This estimator is GNN-based model compatible with the CARTE pretrained model.
Note that this is an implementation for the ablation study of CARTE
Parameters
----------
ablation_method : {'exclude-edge', 'exclude-attention', 'exclude-attention-edge'}, default='exclude-edge'
Expand Down Expand Up @@ -1442,6 +1475,7 @@ class CARTE_AblationClassifier(CARTEClassifier):
disable_pbar : bool, default=True
Indicates whether to show progress bars for the training process.
"""

def __init__(
self,
*,
Expand Down Expand Up @@ -1504,7 +1538,7 @@ def _load_model(self):
model_config["hidden_dim"] = self.X_[0].x.size(1)
model_config["ff_dim"] = self.X_[0].x.size(1)
model_config["num_heads"] = 12
model_config["num_layers"] = self.num_layers-1
model_config["num_layers"] = self.num_layers - 1
model_config["output_dim"] = self.output_dim_
model_config["dropout"] = self.dropout

Expand Down

0 comments on commit d65d9b6

Please sign in to comment.