Skip to content

Commit

Permalink
Explanations on the use of the LinearFractionalParity metric and impo…
Browse files Browse the repository at this point in the history
…rtant documentation update.
  • Loading branch information
maartenbuyl committed Apr 26, 2024
1 parent 209ad6b commit ce91581
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 17 deletions.
10 changes: 5 additions & 5 deletions fairret/loss/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ class ProjectionLoss(FairnessLoss):
The projections are computed using cvxpy. Hence, any subclass is expected to implement the statistical distance
between distributions in both cvxpy and PyTorch by implementing the
:func:`~projection.ProjectionLoss.cvxpy_distance` method and the
:func:`~projection.ProjectionLoss.torch_distance` method respectively.
:py:func:`~projection.ProjectionLoss.cvxpy_distance` method and the
:py:func:`~projection.ProjectionLoss.torch_distance` method respectively.
Optionally, the :func:`~projection.ProjectionLoss.torch_distance_with_logits` method can be overwritten to
Optionally, the :py:func:`~projection.ProjectionLoss.torch_distance_with_logits` method can be overwritten to
provide a more numerically stable handling of predictions that are provided as logits. If left unimplemented,
:func:`~projection.ProjectionLoss.torch_distance` will be called instead, after applying the sigmoid function to
:py:func:`~projection.ProjectionLoss.torch_distance` will be called instead, after applying the sigmoid function to
the predictions.
Note:
Expand Down Expand Up @@ -131,7 +131,7 @@ def torch_distance(self, pred: torch.Tensor, proj: torch.Tensor) -> torch.Tensor

def torch_distance_with_logits(self, pred, proj):
"""
A more numerically stable alternative method to :func:`~projection.ProjectionLoss.torch_distance`, where `pred`
A more numerically stable alternative method to :py:func:`~projection.ProjectionLoss.torch_distance`, where `pred`
is assumed to be logits.
Args:
Expand Down
2 changes: 1 addition & 1 deletion fairret/loss/violation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def forward(self, pred: torch.Tensor, sens: torch.Tensor, *stat_args, pred_as_lo
target_statistic: Optional[torch.Tensor] = None, **stat_kwargs: Any) -> torch.Tensor:
"""
Calculate the violation vector in relation to the `target_statistic` and penalize this violation using the
:func:`~violation.ViolationLoss.penalize_violation` method implemented by the subclass.
:py:func:`~violation.ViolationLoss.penalize_violation` method implemented by the subclass.
Args:
pred (torch.Tensor): Predictions of shape :math:`(N, 1)`, as we assume to be performing binary
Expand Down
18 changes: 11 additions & 7 deletions fairret/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,22 @@ def gap_relative_abs_max(vals: torch.Tensor, target_val: float) -> float:
class LinearFractionalParity(torchmetrics.Metric):
"""
Metric that assesses the fairness of a model's predictions by comparing the gaps between the provided
LinearFractionalStatistic for every sensitive feature.
:py:class:`~fairret.statistic.LinearFractionalStatistic` for every sensitive feature.
The metric maintains two pairs of running sums: one for the statistic for every sensitive feature, and one for the
overall statistic. Each pair of running sums consists of the numerator and the denominator for those statistics.
Observations are added to these sums by calling the `update` method. The final fairness gap is computed by calling
the `compute` method, which also resets the internal state of the metric.
Observations are added to these sums by calling the :py:func:`~fairret.metric.update` method. The final fairness gap
is computed by calling the :py:func:`~fairret.metric.compute` method.
The class is implemented as a subclass of torchmetrics.Metric, so the `torchmetrics` package is required.
The class is implemented as a subclass of :py:class:`torchmetrics.Metric`, so the :py:mod:`torchmetrics` package is
required.
Warning:
It is advised not to mix LinearFractionalParity metrics with different statistics in a single
torchmetrics.MetricCollection with `compute_groups=True`, as this can lead to hard-to-debug errors.
A separate :py:func:`~torchmetrics.Metric.reset() call is required to reset the internal state of the metric
between epochs.
Warning:
It is advised not to mix metrics of this class with different statistics in a single
:py:class:`torchmetrics.MetricCollection` with `compute_groups=True`, as this can lead to hard-to-debug errors.
"""

is_differentiable = True
Expand Down Expand Up @@ -120,7 +124,7 @@ def compute(self) -> float:
final gaps between the groupwise and overall statistics, according to the `gap_fn`.
Warning:
This does NOT reset the internal state of the metric. A separate .reset() call is required to do so.
This does NOT reset the internal state of the metric. A separate `.reset()` call is required to do so.
Returns:
float: The final fairness gap.
Expand Down
8 changes: 4 additions & 4 deletions fairret/statistic/linear_fractional.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ class LinearFractionalStatistic(Statistic):
A LinearFractionalStatistic is constructed in a canonical form by defining the intercept and slope of the
numerator and denominator linear functions, i.e. the functions
:func:`~linear_fractional.LinearFractionalStatistic.num_intercept`,
:func:`~linear_fractional.LinearFractionalStatistic.num_slope`,
:func:`~linear_fractional.LinearFractionalStatistic.denom_intercept`,
and :func:`~linear_fractional.LinearFractionalStatistic.denom_slope`. Each subclass must implement these
:py:func:`~linear_fractional.LinearFractionalStatistic.num_intercept`,
:py:func:`~linear_fractional.LinearFractionalStatistic.num_slope`,
:py:func:`~linear_fractional.LinearFractionalStatistic.denom_intercept`,
and :py:func:`~linear_fractional.LinearFractionalStatistic.denom_slope`. Each subclass must implement these
functions (using any signature).
The statistic is then computed as :math:`\\frac{num\\_intercept + num\\_slope * pred}{denom\\_intercept +
Expand Down

0 comments on commit ce91581

Please sign in to comment.