address security concerns and apply black
wistuba committed Feb 13, 2024
1 parent 0fdbb3e commit 0f1b469
Showing 36 changed files with 3,247 additions and 2,838 deletions.
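Most of the hunks below are mechanical reformatting: multi-line conditional expressions passed as arguments are wrapped in explicit parentheses, consistent with the parenthesized-conditional style applied by newer Black releases (the exact Black version is an assumption, it is not stated in the commit). A minimal before/after sketch on a made-up helper, illustrative only and not code from the repository:

    from typing import Callable, Optional


    def pick_uncertainty_fn(
        configured: Optional[Callable], default: Callable
    ) -> Callable:
        # Before this commit the conditional argument was laid out as:
        #     uncertainty_fn=config.monitor.uncertainty_fn
        #     if config.monitor.uncertainty_fn is not None
        #     else self.prob_output_layer.mean,
        # After the commit the whole conditional expression is wrapped in parentheses:
        uncertainty_fn = (
            configured
            if configured is not None
            else default
        )
        return uncertainty_fn


    print(pick_uncertainty_fn(None, max) is max)  # True: falls back to the default

The behavior is unchanged in every hunk; only the layout of the conditional expressions differs.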
8 changes: 5 additions & 3 deletions fortuna/calib_model/classification.py
@@ -182,9 +182,11 @@ def calibrate(
             self._check_output_dim(val_data_loader)
         return self._calibrate(
             calib_data_loader=calib_data_loader,
-            uncertainty_fn=config.monitor.uncertainty_fn
-            if config.monitor.uncertainty_fn is not None
-            else self.prob_output_layer.mean,
+            uncertainty_fn=(
+                config.monitor.uncertainty_fn
+                if config.monitor.uncertainty_fn is not None
+                else self.prob_output_layer.mean
+            ),
             val_data_loader=val_data_loader,
             loss_fn=loss_fn,
             config=config,
16 changes: 10 additions & 6 deletions fortuna/calib_model/loss.py
@@ -95,12 +95,16 @@ def __call__(
         aux = dict()
         if self.likelihood.output_calib_manager is not None:
             outs = self.likelihood.output_calib_manager.apply(
-                params=calib_params["output_calibrator"]
-                if calib_params is not None
-                else None,
-                mutable=calib_mutable["output_calibrator"]
-                if calib_mutable is not None
-                else None,
+                params=(
+                    calib_params["output_calibrator"]
+                    if calib_params is not None
+                    else None
+                ),
+                mutable=(
+                    calib_mutable["output_calibrator"]
+                    if calib_mutable is not None
+                    else None
+                ),
                 outputs=outputs,
                 calib="calib_mutable" in return_aux,
             )
4 changes: 1 addition & 3 deletions fortuna/calib_model/predictive/regression.py
@@ -146,9 +146,7 @@ def credible_interval(
         q = (
             jnp.array([0.5 * error, 1 - 0.5 * error])
             if interval_type == "two-tailed"
-            else error
-            if interval_type == "left-tailed"
-            else 1 - error
+            else error if interval_type == "left-tailed" else 1 - error
         )
         qq = self.quantile(
             q=q,
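The one-lined chained conditional above selects the quantile levels for the requested credible interval. A short sketch of the same selection written as an explicit if/elif/else; the standalone function is a hypothetical helper, only the names `error` and `interval_type` come from the snippet above:

    import jax.numpy as jnp


    def credible_quantiles(error: float, interval_type: str) -> jnp.ndarray:
        # Spelled-out version of the chained conditional that selects `q` above.
        if interval_type == "two-tailed":
            # e.g. error=0.05 -> quantile levels 0.025 and 0.975
            return jnp.array([0.5 * error, 1 - 0.5 * error])
        elif interval_type == "left-tailed":
            # e.g. error=0.05 -> quantile level 0.05
            return jnp.asarray(error)
        else:
            # any other value (e.g. "right-tailed") -> quantile level 0.95
            return jnp.asarray(1 - error)


    print(credible_quantiles(0.05, "two-tailed"))  # [0.025 0.975]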
8 changes: 5 additions & 3 deletions fortuna/calib_model/regression.py
@@ -128,9 +128,11 @@ def calibrate(
             self._check_output_dim(val_data_loader)
         return self._calibrate(
             calib_data_loader=calib_data_loader,
-            uncertainty_fn=config.monitor.uncertainty_fn
-            if config.monitor.uncertainty_fn is not None
-            else self.prob_output_layer.variance,
+            uncertainty_fn=(
+                config.monitor.uncertainty_fn
+                if config.monitor.uncertainty_fn is not None
+                else self.prob_output_layer.variance
+            ),
             val_data_loader=val_data_loader,
             loss_fn=loss_fn,
             config=config,
24 changes: 15 additions & 9 deletions fortuna/calib_model/state.py
@@ -37,9 +37,11 @@ def init(
         return cls(
             apply_fn=None,
             params=params,
-            opt_state=kwargs["opt_state"]
-            if optimizer is None and "opt_state" in kwargs
-            else optimizer.init(params),
+            opt_state=(
+                kwargs["opt_state"]
+                if optimizer is None and "opt_state" in kwargs
+                else optimizer.init(params)
+            ),
             mutable=mutable,
             step=kwargs.get("step", 0),
             tx=optimizer,
@@ -78,11 +80,15 @@ def init_from_dict(
             FrozenDict(d["params"]),
             FrozenDict(d["mutable"]) if d["mutable"] is not None else None,
             optimizer,
-            FrozenDict(d.get("calib_params"))
-            if d["calib_params"] is not None
-            else None,
-            FrozenDict(d.get("calib_mutable"))
-            if d["calib_mutable"] is not None
-            else None,
+            (
+                FrozenDict(d.get("calib_params"))
+                if d["calib_params"] is not None
+                else None
+            ),
+            (
+                FrozenDict(d.get("calib_mutable"))
+                if d["calib_mutable"] is not None
+                else None
+            ),
             **kwargs,
         )
48 changes: 30 additions & 18 deletions fortuna/likelihood/base.py
@@ -223,12 +223,16 @@ def _batched_log_joint_prob(
         aux = dict()
         if self.output_calib_manager is not None:
             outs = self.output_calib_manager.apply(
-                params=calib_params["output_calibrator"]
-                if calib_params is not None
-                else None,
-                mutable=calib_mutable["output_calibrator"]
-                if calib_mutable is not None
-                else None,
+                params=(
+                    calib_params["output_calibrator"]
+                    if calib_params is not None
+                    else None
+                ),
+                mutable=(
+                    calib_mutable["output_calibrator"]
+                    if calib_mutable is not None
+                    else None
+                ),
                 outputs=outputs,
                 calib="calib_mutable" in return_aux,
             )
@@ -407,12 +411,16 @@ def get_calibrated_outputs(
             and self.output_calib_manager.output_calibrator is not None
         ):
             outputs = self.output_calib_manager.apply(
-                params=calib_params["output_calibrator"]
-                if calib_params is not None
-                else None,
-                mutable=calib_mutable["output_calibrator"]
-                if calib_mutable is not None
-                else None,
+                params=(
+                    calib_params["output_calibrator"]
+                    if calib_params is not None
+                    else None
+                ),
+                mutable=(
+                    calib_mutable["output_calibrator"]
+                    if calib_mutable is not None
+                    else None
+                ),
                 outputs=outputs,
                 **kwargs,
             )
@@ -434,12 +442,16 @@ def _get_batched_calibrated_outputs(
             and self.output_calib_manager.output_calibrator is not None
         ):
             outputs = self.output_calib_manager.apply(
-                params=calib_params["output_calibrator"]
-                if calib_params is not None
-                else None,
-                mutable=calib_mutable["output_calibrator"]
-                if calib_mutable is not None
-                else None,
+                params=(
+                    calib_params["output_calibrator"]
+                    if calib_params is not None
+                    else None
+                ),
+                mutable=(
+                    calib_mutable["output_calibrator"]
+                    if calib_mutable is not None
+                    else None
+                ),
                 outputs=outputs,
                 **kwargs,
             )
1 change: 1 addition & 0 deletions fortuna/model/constant.py
@@ -19,6 +19,7 @@ class ConstantModel(nn.Module):
     Function to initialize the model parameters.
     This must be one of the available options in :code:`flax.linen.initializers`.
     """
+
     output_dim: int
     initializer_fun: Optional[Initializer] = nn.initializers.zeros
 
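The single-line additions in this and the following model files all follow the same pattern: the formatter inserts one blank line between a class or module docstring and the first statement after it (this matches Black's 2024 stable style, though the version is an assumption). A minimal sketch on a made-up `Example` module, illustrative only:

    import flax.linen as nn


    class Example(nn.Module):
        """A toy module used only to illustrate the formatting change."""

        # The blank line between the docstring and the first field is the
        # single added line in each of these hunks.
        output_dim: int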
1 change: 1 addition & 0 deletions fortuna/model/hyper.py
@@ -13,6 +13,7 @@ class HyperparameterModel(nn.Module):
     value: Union[float, Array]
         Value of the hyperparameter.
     """
+
     value: Array
 
     def setup(self) -> None:
6 changes: 3 additions & 3 deletions fortuna/model/model_manager/regression.py
@@ -120,9 +120,9 @@ def apply_fn(p, x, m_mutable, llv_mutable):
                 p,
                 x,
                 m_mutable=mutable["model"] if mutable is not None else False,
-                llv_mutable=mutable["lik_log_var"]
-                if mutable is not None
-                else False,
+                llv_mutable=(
+                    mutable["lik_log_var"] if mutable is not None else False
+                ),
             ),
             model_params=params,
             x=inputs,
1 change: 1 addition & 0 deletions fortuna/model/scalar_constant.py
@@ -20,6 +20,7 @@ class ScalarConstantModel(nn.Module):
     Function to initialize the model parameters.
     This must be one of the available options in :code:`flax.linen.initializers`.
     """
+
     output_dim: int
     initializer_fun: Optional[Initializer] = nn.initializers.zeros
 
1 change: 1 addition & 0 deletions fortuna/model/scalar_hyper.py
@@ -16,6 +16,7 @@ class ScalarHyperparameterModel(nn.Module):
     value: float
         Scalar value of the hyperparameter.
     """
+
     output_dim: int
     value: float
 
1 change: 1 addition & 0 deletions fortuna/model/utils/spectral_norm.py
@@ -1,6 +1,7 @@
 """
 The code has been taken from https://github.com/google/edward2/blob/main/edward2/jax/nn/normalization.py
 """
+
 import dataclasses
 from typing import (
     Any,
1 change: 1 addition & 0 deletions fortuna/model/wideresnet.py
@@ -2,6 +2,7 @@
 Wide ResNet model
 (adapted from https://github.com/google/flax/blob/v0.2/examples/cifar10/models/wideresnet.py)
 """
+
 from functools import partial
 from typing import (
     Any,
1 change: 1 addition & 0 deletions fortuna/ood_detection/ddu.py
@@ -1,6 +1,7 @@
 """
 Adapted from https://github.com/omegafragger/DDU/blob/main/utils/gmm_utils.py
 """
+
 import logging
 from typing import (
     Callable,
8 changes: 5 additions & 3 deletions fortuna/output_calib_model/classification.py
@@ -102,9 +102,11 @@ def calibrate(
         if val_outputs is not None:
             self._check_output_dim(val_outputs, val_targets)
         return super()._calibrate(
-            uncertainty_fn=config.monitor.uncertainty_fn
-            if config.monitor.uncertainty_fn is not None
-            else self.prob_output_layer.mean,
+            uncertainty_fn=(
+                config.monitor.uncertainty_fn
+                if config.monitor.uncertainty_fn is not None
+                else self.prob_output_layer.mean
+            ),
             calib_outputs=calib_outputs,
             calib_targets=calib_targets,
             val_outputs=val_outputs,
8 changes: 5 additions & 3 deletions fortuna/output_calib_model/regression.py
@@ -99,9 +99,11 @@ def calibrate(
         if val_outputs is not None:
             self._check_output_dim(val_outputs, val_targets)
         return super()._calibrate(
-            uncertainty_fn=config.monitor.uncertainty_fn
-            if config.monitor.uncertainty_fn is not None
-            else self.prob_output_layer.variance,
+            uncertainty_fn=(
+                config.monitor.uncertainty_fn
+                if config.monitor.uncertainty_fn is not None
+                else self.prob_output_layer.variance
+            ),
             calib_outputs=calib_outputs,
             calib_targets=calib_targets,
             val_outputs=val_outputs,
10 changes: 5 additions & 5 deletions fortuna/output_calib_model/state.py
@@ -52,11 +52,11 @@ def init(
         return cls(
             apply_fn=None,
             params=params,
-            opt_state=kwargs["opt_state"]
-            if optimizer is None and "opt_state" in kwargs
-            else None
-            if optimizer is None
-            else optimizer.init(params),
+            opt_state=(
+                kwargs["opt_state"]
+                if optimizer is None and "opt_state" in kwargs
+                else None if optimizer is None else optimizer.init(params)
+            ),
             mutable=mutable,
             step=kwargs.get("step", 0),
             tx=optimizer,
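The nested conditional above encodes three cases for the optimizer state. A short sketch of the same logic as an explicit if/elif/else; the standalone function and the optax type hint are illustrative assumptions, only the names mirror the snippet above:

    from typing import Any, Optional

    import optax


    def resolve_opt_state(
        params: Any, optimizer: Optional[optax.GradientTransformation], **kwargs
    ) -> Any:
        # Spelled-out version of the nested conditional used for `opt_state` above.
        if optimizer is None and "opt_state" in kwargs:
            # No optimizer, but a previously saved state was passed in: reuse it.
            return kwargs["opt_state"]
        elif optimizer is None:
            # No optimizer and no saved state: there is nothing to initialize.
            return None
        else:
            # An optimizer is available: create a fresh state for the parameters.
            return optimizer.init(params)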
8 changes: 5 additions & 3 deletions fortuna/plot.py
@@ -167,9 +167,11 @@ def plot_2d_classification_predictions_and_uncertainty(
         inputs[:, 0],
         inputs[:, 1],
         s=marker_size,
-        c=[preds_color[0] if i == 1 else preds_color[1] for i in preds]
-        if preds is not None
-        else base_inputs_color,
+        c=(
+            [preds_color[0] if i == 1 else preds_color[1] for i in preds]
+            if preds is not None
+            else base_inputs_color
+        ),
     )
     if colorbar:
         plt.colorbar(im, ax=ax.ravel().tolist() if hasattr(ax, "ravel") else ax)
8 changes: 5 additions & 3 deletions fortuna/prob_model/classification.py
@@ -287,9 +287,11 @@ def calibrate(
         if val_data_loader is not None:
             self._check_output_dim(val_data_loader)
         return super()._calibrate(
-            uncertainty_fn=calib_config.monitor.uncertainty_fn
-            if calib_config.monitor.uncertainty_fn is not None
-            else self.prob_output_layer.mean,
+            uncertainty_fn=(
+                calib_config.monitor.uncertainty_fn
+                if calib_config.monitor.uncertainty_fn is not None
+                else self.prob_output_layer.mean
+            ),
             calib_data_loader=calib_data_loader,
             val_data_loader=val_data_loader,
             calib_config=calib_config,
@@ -172,9 +172,11 @@ def _fit(i):
         else:
             self.state = PosteriorMultiStateRepository(
                 size=self.posterior_approximator.ensemble_size,
-                checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir
-                if fit_config.checkpointer.dump_state is True
-                else None,
+                checkpoint_dir=(
+                    fit_config.checkpointer.save_checkpoint_dir
+                    if fit_config.checkpointer.dump_state is True
+                    else None
+                ),
             )
 
         status = []
16 changes: 10 additions & 6 deletions fortuna/prob_model/posterior/laplace/laplace_posterior.py
@@ -161,12 +161,16 @@ def apply_calib_model_manager(_params, _batch_inputs):
                 _params, _batch_inputs, mutable=mutable, train=False
             )
             outputs = self.joint.likelihood.output_calib_manager.apply(
-                params=calib_params["output_calibrator"]
-                if calib_params is not None
-                else None,
-                mutable=calib_mutable["output_calibrator"]
-                if calib_mutable is not None
-                else None,
+                params=(
+                    calib_params["output_calibrator"]
+                    if calib_params is not None
+                    else None
+                ),
+                mutable=(
+                    calib_mutable["output_calibrator"]
+                    if calib_mutable is not None
+                    else None
+                ),
                 outputs=outputs,
             )
             return outputs
6 changes: 3 additions & 3 deletions fortuna/prob_model/posterior/map/map_posterior.py
@@ -101,9 +101,9 @@ def fit(
             n_epochs=fit_config.optimizer.n_epochs,
             metrics=fit_config.monitor.metrics,
             validation_dataloader=val_data_loader,
-            validation_dataset_size=val_data_loader.size
-            if val_data_loader is not None
-            else None,
+            validation_dataset_size=(
+                val_data_loader.size if val_data_loader is not None else None
+            ),
             verbose=fit_config.monitor.verbose,
             callbacks=fit_config.callbacks,
             max_grad_norm=fit_config.hyperparameters.max_grad_norm,
@@ -178,9 +178,9 @@ def fit(
             n_epochs=fit_config.optimizer.n_epochs,
             metrics=fit_config.monitor.metrics,
             validation_dataloader=val_data_loader,
-            validation_dataset_size=val_data_loader.size
-            if val_data_loader is not None
-            else None,
+            validation_dataset_size=(
+                val_data_loader.size if val_data_loader is not None else None
+            ),
             verbose=fit_config.monitor.verbose,
             unravel=self._unravel,
             n_samples=self.posterior_approximator.n_loss_samples,
@@ -21,9 +21,9 @@ def __init__(self, size: int, checkpoint_dir: Optional[Path] = None):
         self.size = size
         self.state = [
             PosteriorStateRepository(
-                checkpoint_dir=os.path.join(checkpoint_dir, str(i))
-                if checkpoint_dir
-                else None
+                checkpoint_dir=(
+                    os.path.join(checkpoint_dir, str(i)) if checkpoint_dir else None
+                )
             )
             for i in range(size)
         ]