diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 33b84f4..c445178 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -215,6 +215,27 @@ def test_entropy_metric(): assert spec_entropy_metric(D1) == 2.6402430161833763 +def test_shannon_entropy_metric(): + ts = np.array([[[0, 2], [11, -11], [1, 2]], [[10, 21], [1, -1], [6, 8]]]).astype(np.float32) + y = np.array([1] * ts.shape[0]) + D1 = tsgm.dataset.Dataset(ts, y=y) + sdi_metric = tsgm.metrics.ShannonEntropyMetric() + assert sdi_metric(D1) == 0 + y = np.array([1, 2]) + D2 = tsgm.dataset.Dataset(ts, y=y) + assert sdi_metric(D2) > 0 + + +def test_pairwise_distance_metric(): + ts = np.array([[[0, 2], [11, -11], [1, 2]], [[0, 2], [11, -11], [1, 2]]]).astype(np.float32) + D1 = tsgm.dataset.Dataset(ts, y=None) + pd_metric = tsgm.metrics.PairwiseDistanceMetric() + assert np.mean(pd_metric(D1)) == 0 + ts = np.array([[[0, 2], [11, -11], [1, 2]], [[10, 21], [1, -1], [6, 8]]]).astype(np.float32) + D2 = tsgm.dataset.Dataset(ts, y=None) + assert np.mean(pd_metric(D2)) > 0 + + def test_demographic_parity(): ts = np.array([[[0, 2], [11, -11], [1, 2]], [[0, 2], [11, -11], [1, 2]], [[10, 21], [1, -1], [6, 8]]]).astype(np.float32) y = np.array([0, 1, 1]) diff --git a/tsgm/metrics/__init__.py b/tsgm/metrics/__init__.py index 926c637..c01851c 100644 --- a/tsgm/metrics/__init__.py +++ b/tsgm/metrics/__init__.py @@ -2,5 +2,6 @@ from tsgm.metrics.metrics import ( DistanceMetric, ConsistencyMetric, BaseDownstreamEvaluator, DownstreamPerformanceMetric, PrivacyMembershipInferenceMetric, - MMDMetric, DiscriminativeMetric, EntropyMetric, DemographicParityMetric + MMDMetric, DiscriminativeMetric, EntropyMetric, DemographicParityMetric, + ShannonEntropyMetric, PairwiseDistanceMetric ) diff --git a/tsgm/metrics/metrics.py b/tsgm/metrics/metrics.py index afd3776..009e923 100644 --- a/tsgm/metrics/metrics.py +++ b/tsgm/metrics/metrics.py @@ -6,7 +6,9 @@ import itertools import sklearn import scipy +from scipy.stats import entropy from tqdm import tqdm +from scipy.spatial.distance import pdist, squareform from tensorflow.python.types.core import TensorLike import tsgm @@ -304,7 +306,7 @@ def _spectral_entropy_sum(X: TensorLike) -> TensorLike: class EntropyMetric(Metric): """ - Calculates the spectral entropy of a dataset or tensor. + Calculates the spectral entropy of a dataset or tensor as a sum of individual entropies. Args: d (tsgm.dataset.DatasetOrTensor): The input dataset or tensor. @@ -330,6 +332,76 @@ def __call__(self, d: tsgm.dataset.DatasetOrTensor) -> float: """ X = _dataset_or_tensor_to_tensor(d) return np.sum(_spectral_entropy_sum(X), axis=None) + + +class ShannonEntropyMetric(Metric): + """ + Shannon Entropy calculated over the labels of a dataset. + This index is a measure of diversity that accounts for categories present in a dataset. + """ + + def _shannon_entropy(self, labels): + """ + Private method to calculate the Shannon Entropy for a given set of labels. + + Parameters: + labels (array-like): The labels or categories for which the diversity measure is to be calculated. + + Returns: + float: The Shannon Entropy value. + """ + _, counts = np.unique(labels, return_counts=True) + return entropy(counts) + + def __call__(self, d: tsgm.dataset.DatasetOrTensor) -> float: + """ + Calculate the Shannon entropy for the dataset. + + Parameters: + d (tsgm.dataset.DatasetOrTensor): The dataset or tensor object containing the labels. + + Returns: + float: The Shannon entropy value. + + Raises: + AssertionError: If the dataset does not contain labels. + """ + y = d.y + assert y is not None, "The dataset must contain labels." + + return self._shannon_entropy(y) + + +class PairwiseDistanceMetric(Metric): + """ + Measures pairwise distances in a set of time series. + """ + + def pairwise_euclidean_distances(self, ts: TensorLike) -> TensorLike: + """ + Computes the pairwise Euclidean distances for a set of time series. + + Parameters: + ts (numpy.ndarray): A 2D array where each row represents a time series. + + Returns: + numpy.ndarray: A 2D array representing the pairwise Euclidean distance matrix. + """ + distances = pdist(np.reshape(ts, (ts.shape[0], -1)), metric='euclidean') + return squareform(distances) + + def __call__(self, d: tsgm.dataset.DatasetOrTensor) -> TensorLike: + """ + Calculates the pairwise Euclidean distances for a dataset or tensor. + + Parameters: + d (tsgm.dataset.DatasetOrTensor): The input dataset or tensor containing time series data. + + Returns: + float: The pairwise Euclidean distances of the input data. + """ + X = _dataset_or_tensor_to_tensor(d) + return self.pairwise_euclidean_distances(X) class DemographicParityMetric(Metric):