From c44016cbe26efb628f25edfb09b3ee8481d593ff Mon Sep 17 00:00:00 2001 From: Alexander Nikitin <1243786+AlexanderVNikitin@users.noreply.github.com> Date: Tue, 29 Aug 2023 13:03:29 +0300 Subject: [PATCH] similarity -> distance, add tests --- tests/test_metrics.py | 19 ++++++++++++++++++- tsgm/metrics/__init__.py | 4 ++-- tsgm/metrics/metrics.py | 19 ++++++++++++++++++- tutorials/GANs/RCGAN.ipynb | 6 +++--- tutorials/Metrics Tutorial.ipynb | 10 +++++----- tutorials/Model Selection.ipynb | 4 ++-- 6 files changed, 48 insertions(+), 14 deletions(-) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index afc13e6..d50143b 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -47,7 +47,7 @@ def test_similarity_metric(): functools.partial(tsgm.metrics.statistics.axis_max_s, axis=1), functools.partial(tsgm.metrics.statistics.axis_min_s, axis=1)] - sim_metric = tsgm.metrics.SimilarityMetric( + sim_metric = tsgm.metrics.DistanceMetric( statistics=statistics, discrepancy=lambda x, y: np.linalg.norm(x - y) ) assert sim_metric(ts, diff_ts) < sim_metric(ts, sim_ts) @@ -145,3 +145,20 @@ def test_mmd_metric(): assert mmd_metric(D1, D2) == mmd_metric(ts, diff_ts) assert mmd_metric(D1, D1) == 0 and mmd_metric(D2, D2) == 0 + + +def test_discriminative_metric(): + ts = np.array([[[0, 2], [11, -11], [1, 2]], [[10, 21], [1, -1], [6, 8]]]).astype(np.float32) + D1 = tsgm.dataset.Dataset(ts, y=None) + + diff_ts = np.array([[[12, 13], [10, 10], [-1, -2]], [[-1, 32], [2, 1], [10, 8]]]).astype(np.float32) + D2 = tsgm.dataset.Dataset(diff_ts, y=None) + + model = tsgm.models.zoo["clf_cl_n"](seq_len=ts.shape[1], feat_dim=ts.shape[2], output_dim=1).model + model.compile( + tf.keras.optimizers.Adam(), + tf.keras.losses.CategoricalCrossentropy(from_logits=True) + ) + discr_metric = tsgm.metrics.DiscriminativeMetric() + assert discr_metric(d_hist=D1, d_syn=D2, model=model, test_size=0.2, random_seed=42, n_epochs=10) == 1.0 + assert discr_metric(d_hist=D1, d_syn=D2, model=model, metric=sklearn.metrics.precision_score, test_size=0.2, random_seed=42, n_epochs=10) == 1.0 diff --git a/tsgm/metrics/__init__.py b/tsgm/metrics/__init__.py index 8501fab..c33bd92 100644 --- a/tsgm/metrics/__init__.py +++ b/tsgm/metrics/__init__.py @@ -1,6 +1,6 @@ import tsgm.metrics.statistics from tsgm.metrics.metrics import ( - SimilarityMetric, ConsistencyMetric, BaseDownstreamEvaluator, + DistanceMetric, ConsistencyMetric, BaseDownstreamEvaluator, DownstreamPerformanceMetric, PrivacyMembershipInferenceMetric, - MMDMetric + MMDMetric, DiscriminativeMetric ) diff --git a/tsgm/metrics/metrics.py b/tsgm/metrics/metrics.py index e5d46b0..a2fa4ed 100644 --- a/tsgm/metrics/metrics.py +++ b/tsgm/metrics/metrics.py @@ -29,7 +29,7 @@ def __call__(self, *args, **kwargs) -> float: pass -class SimilarityMetric(Metric): +class DistanceMetric(Metric): """ Metric that measures similarity between synthetic and real time series """ @@ -211,3 +211,20 @@ def __call__(self, D1: tsgm.dataset.DatasetOrTensor, D2: tsgm.dataset.DatasetOrT logger.warning("It is currently impossible to run MMD for labeled time series. Labels will be ignored!") X1, X2 = _dataset_or_tensor_to_tensor(D1), _dataset_or_tensor_to_tensor(D2) return tsgm.utils.mmd.MMD(X1, X2, kernel=self.kernel) + + +class DiscriminativeMetric(Metric): + """ + The discriminative metric measures how accurately a discriminative model can separate synthetic and real data. + """ + def __call__(self, d_hist: tsgm.dataset.DatasetOrTensor, d_syn: tsgm.dataset.DatasetOrTensor, model, test_size, n_epochs, metric=None, random_seed=None) -> float: + X_hist, X_syn = _dataset_or_tensor_to_tensor(d_hist), _dataset_or_tensor_to_tensor(d_syn) + X_all, y_all = np.concatenate([X_hist, X_syn]), np.concatenate([[1] * len(d_hist), [0] * len(d_syn)]) + X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(X_all, y_all, test_size=test_size, random_state=random_seed) + model.fit(X_all, y_all, epochs=n_epochs) + import pdb; pdb.set_trace() + y_pred = model.predict(X_test) + if metric == None: + return sklearn.metrics.accuracy_score(y_test, y_pred) + else: + return metric(y_test, y_pred) diff --git a/tutorials/GANs/RCGAN.ipynb b/tutorials/GANs/RCGAN.ipynb index 2a4854e..56d6d6a 100644 --- a/tutorials/GANs/RCGAN.ipynb +++ b/tutorials/GANs/RCGAN.ipynb @@ -178,11 +178,11 @@ "statistics = [functools.partial(tsgm.metrics.statistics.axis_max_s, axis=1),\n", " functools.partial(tsgm.metrics.statistics.axis_min_s, axis=1)]\n", "\n", - "sim_metric = tsgm.metrics.SimilarityMetric(\n", + "sim_metric = tsgm.metrics.DistanceMetric(\n", " statistics=statistics, discrepancy=lambda x, y: np.linalg.norm(x - y)\n", ")\n", "\n", - "print(f\"Similarity metric: {sim_metric(X, X_gen)}\")" + "print(f\"Distance metric: {sim_metric(X, X_gen)}\")" ] }, { @@ -259,7 +259,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.8" } }, "nbformat": 4, diff --git a/tutorials/Metrics Tutorial.ipynb b/tutorials/Metrics Tutorial.ipynb index ff47cfd..18dd2ed 100644 --- a/tutorials/Metrics Tutorial.ipynb +++ b/tutorials/Metrics Tutorial.ipynb @@ -62,9 +62,9 @@ "id": "a121d403", "metadata": {}, "source": [ - "## Similarity metric\n", + "## Distance metric\n", "\n", - "First, we define a list of summary statistics that reflect the similarity between the datasets. Module `tss.metrics.statistics` defines a set of handy statistics." + "First, we define a list of summary statistics that reflect the distance between the datasets. Module `tss.metrics.statistics` defines a set of handy statistics." ] }, { @@ -103,7 +103,7 @@ "id": "b2a72a7c", "metadata": {}, "source": [ - "Finally, we are putting all together using `tss.metrics.SimilarityMetric` object." + "Finally, we are putting all together using `tss.metrics.DistanceMetric` object." ] }, { @@ -113,7 +113,7 @@ "metadata": {}, "outputs": [], "source": [ - "sim_metric = tsgm.metrics.SimilarityMetric(\n", + "sim_metric = tsgm.metrics.DistanceMetric(\n", " statistics=statistics, discrepancy=discrepancy_func\n", ")" ] @@ -322,7 +322,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.8" } }, "nbformat": 4, diff --git a/tutorials/Model Selection.ipynb b/tutorials/Model Selection.ipynb index d198d39..8fbde33 100644 --- a/tutorials/Model Selection.ipynb +++ b/tutorials/Model Selection.ipynb @@ -109,7 +109,7 @@ "metadata": {}, "outputs": [], "source": [ - "metric_to_optimize = tsgm.metrics.metrics.SimilarityMetric(\n", + "metric_to_optimize = tsgm.metrics.metrics.DistanceMetric(\n", " statistics=[\n", " functools.partial(tsgm.metrics.statistics.axis_max_s, axis=None),\n", " functools.partial(tsgm.metrics.statistics.axis_min_s, axis=None),\n", @@ -265,7 +265,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.8" } }, "nbformat": 4,