diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 234e788..e03364c 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -212,3 +212,24 @@ def test_entropy_metric(): D1 = tsgm.dataset.Dataset(ts, y=None) spec_entropy_metric = tsgm.metrics.EntropyMetric() assert spec_entropy_metric(D1) == 2.6402430161833763 + + +def test_demographic_parity(): + ts = np.array([[[0, 2], [11, -11], [1, 2]], [[0, 2], [11, -11], [1, 2]], [[10, 21], [1, -1], [6, 8]]]).astype(np.float32) + y = np.array([0, 1, 1]) + groups = np.array([0, 1 ,2]) + D = tsgm.dataset.Dataset(ts, y) + + synth_ts = ts + synth_y = np.array([0, 1, 1]) + synth_groups = np.array([1, 2, 3]) + D_synth = tsgm.dataset.Dataset(synth_ts, synth_y) + demographic_parity_metric = tsgm.metrics.DemographicParityMetric() + result = demographic_parity_metric(D, groups, D_synth, synth_groups) + + assert result == { + 0: np.inf, + 1: 1.0, + 2: 0, + 3: -np.inf + } diff --git a/tsgm/metrics/__init__.py b/tsgm/metrics/__init__.py index fe0866b..926c637 100644 --- a/tsgm/metrics/__init__.py +++ b/tsgm/metrics/__init__.py @@ -2,5 +2,5 @@ from tsgm.metrics.metrics import ( DistanceMetric, ConsistencyMetric, BaseDownstreamEvaluator, DownstreamPerformanceMetric, PrivacyMembershipInferenceMetric, - MMDMetric, DiscriminativeMetric, EntropyMetric + MMDMetric, DiscriminativeMetric, EntropyMetric, DemographicParityMetric ) diff --git a/tsgm/metrics/metrics.py b/tsgm/metrics/metrics.py index a7c3edf..c360460 100644 --- a/tsgm/metrics/metrics.py +++ b/tsgm/metrics/metrics.py @@ -5,6 +5,7 @@ import numpy as np import itertools import sklearn +import scipy from tqdm import tqdm from tensorflow.python.types.core import TensorLike @@ -328,3 +329,68 @@ def __call__(self, d: tsgm.dataset.DatasetOrTensor) -> float: """ X = _dataset_or_tensor_to_tensor(d) return np.sum(_spectral_entropy_sum(X), axis=None) + + +class DemographicParityMetric(Metric): + _DEFAULT_KS_METRIC = lambda data1, data2: scipy.stats.ks_2samp(data1, data2).statistic # noqa: E731 + + """ + Measuring demographic parity between two datasets. + + This metric assesses the disparity in the distributions of a target variable among different groups in two datasets. + By default, it uses the Kolmogorov-Smirnov statistic to quantify the maximum vertical deviation between the cumulative distribution functions + of the target variable for the historical and synthetic data within each group. + + Args: + d_hist (tsgm.dataset.DatasetOrTensor): The historical input dataset or tensor. + groups_hist (TensorLike): The group assignments for the historical data. + d_synth (tsgm.dataset.DatasetOrTensor): The synthetic input dataset or tensor. + groups_synth (TensorLike): The group assignments for the synthetic data. + metric (callable, optional): The metric used to compare the target variable distributions within each group. + Default is the Kolmogorov-Smirnov statistic. + + Returns: + dict: A dictionary mapping each group to the computed demographic parity metric. + + Example: + >>> metric = DemographicParityMetric() + >>> dataset_hist = tsgm.dataset.Dataset(...) + >>> dataset_synth = tsgm.dataset.Dataset(...) + >>> groups_hist = [0, 1, 0, 1, 1, 0] + >>> groups_synth = [1, 1, 0, 0, 0, 1] + >>> result = metric(dataset_hist, groups_hist, dataset_synth, groups_synth) + >>> print(result) + """ + def __call__(self, d_hist: tsgm.dataset.DatasetOrTensor, groups_hist: TensorLike, d_synth: tsgm.dataset.DatasetOrTensor, groups_synth: TensorLike, metric: T.Callable = _DEFAULT_KS_METRIC) -> T.Dict: + """ + Calculate the demographic parity metric for the input datasets. + + Args: + d_hist (tsgm.dataset.DatasetOrTensor): The historical input dataset or tensor. + groups_hist (TensorLike): The group assignments for the historical data. + d_synth (tsgm.dataset.DatasetOrTensor): The synthetic input dataset or tensor. + groups_synth (TensorLike): The group assignments for the synthetic data. + metric (callable, optional): The metric used to compare the target variable distributions within each group. + Default is the Kolmogorov-Smirnov statistic. + + Returns: + dict: A dictionary mapping each group to the computed demographic parity metric. + """ + + y_hist, y_synth = d_hist.y, d_synth.y + + unique_groups_hist, unique_groups_synth = set(groups_hist), set(groups_synth) + all_groups = unique_groups_hist | unique_groups_synth + if len(all_groups) > len(unique_groups_hist) or len(all_groups) > len(unique_groups_synth): + logger.warning("Groups in historical and synthetic data are not entirely identical.") + + result = {} + for g in all_groups: + y_g_hist, y_g_synth = y_hist[groups_hist == g], y_synth[groups_synth == g] + if not len(y_g_synth): + result[g] = np.inf + elif not len(y_g_hist): + result[g] = -np.inf + else: + result[g] = metric(y_g_hist, y_g_synth) + return result