Skip to content

Commit

Permalink
similarity -> distance, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderVNikitin committed Aug 29, 2023
1 parent a92a363 commit c44016c
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 14 deletions.
19 changes: 18 additions & 1 deletion tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_similarity_metric():
functools.partial(tsgm.metrics.statistics.axis_max_s, axis=1),
functools.partial(tsgm.metrics.statistics.axis_min_s, axis=1)]

sim_metric = tsgm.metrics.SimilarityMetric(
sim_metric = tsgm.metrics.DistanceMetric(
statistics=statistics, discrepancy=lambda x, y: np.linalg.norm(x - y)
)
assert sim_metric(ts, diff_ts) < sim_metric(ts, sim_ts)
Expand Down Expand Up @@ -145,3 +145,20 @@ def test_mmd_metric():

assert mmd_metric(D1, D2) == mmd_metric(ts, diff_ts)
assert mmd_metric(D1, D1) == 0 and mmd_metric(D2, D2) == 0


def test_discriminative_metric():
    """Check that DiscriminativeMetric returns 1.0 with the default and with a custom metric."""
    # Two tiny datasets of two series each; labels are not used by the metric.
    ts = np.array(
        [[[0, 2], [11, -11], [1, 2]], [[10, 21], [1, -1], [6, 8]]],
        dtype=np.float32,
    )
    diff_ts = np.array(
        [[[12, 13], [10, 10], [-1, -2]], [[-1, 32], [2, 1], [10, 8]]],
        dtype=np.float32,
    )
    D1 = tsgm.dataset.Dataset(ts, y=None)
    D2 = tsgm.dataset.Dataset(diff_ts, y=None)

    # Build the discriminator from the model zoo, sized to the series above.
    _, seq_len, feat_dim = ts.shape
    architecture = tsgm.models.zoo["clf_cl_n"](seq_len=seq_len, feat_dim=feat_dim, output_dim=1)
    model = architecture.model
    model.compile(
        tf.keras.optimizers.Adam(),
        tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    )

    discr_metric = tsgm.metrics.DiscriminativeMetric()
    # Arguments shared by both evaluations below.
    shared_kwargs = dict(d_hist=D1, d_syn=D2, model=model, test_size=0.2, random_seed=42, n_epochs=10)

    # Default metric (accuracy).
    assert discr_metric(**shared_kwargs) == 1.0
    # User-supplied metric.
    assert discr_metric(metric=sklearn.metrics.precision_score, **shared_kwargs) == 1.0
4 changes: 2 additions & 2 deletions tsgm/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import tsgm.metrics.statistics
from tsgm.metrics.metrics import (
SimilarityMetric, ConsistencyMetric, BaseDownstreamEvaluator,
DistanceMetric, ConsistencyMetric, BaseDownstreamEvaluator,
DownstreamPerformanceMetric, PrivacyMembershipInferenceMetric,
MMDMetric
MMDMetric, DiscriminativeMetric
)
19 changes: 18 additions & 1 deletion tsgm/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __call__(self, *args, **kwargs) -> float:
pass


class SimilarityMetric(Metric):
class DistanceMetric(Metric):
"""
Metric that measures the distance between synthetic and real time series
"""
Expand Down Expand Up @@ -211,3 +211,20 @@ def __call__(self, D1: tsgm.dataset.DatasetOrTensor, D2: tsgm.dataset.DatasetOrT
logger.warning("It is currently impossible to run MMD for labeled time series. Labels will be ignored!")
X1, X2 = _dataset_or_tensor_to_tensor(D1), _dataset_or_tensor_to_tensor(D2)
return tsgm.utils.mmd.MMD(X1, X2, kernel=self.kernel)


class DiscriminativeMetric(Metric):
    """
    The discriminative metric measures how accurately a discriminative model can separate synthetic and real data.
    """
    def __call__(self, d_hist: tsgm.dataset.DatasetOrTensor, d_syn: tsgm.dataset.DatasetOrTensor, model, test_size, n_epochs, metric=None, random_seed=None) -> float:
        """
        Train `model` to tell historical from synthetic samples and score it on a held-out split.

        :param d_hist: Historical (real) data; labeled 1 for the discriminator.
        :param d_syn: Synthetic data; labeled 0 for the discriminator.
        :param model: Object exposing `fit(X, y, epochs=...)` and `predict(X)`.
            It is fitted in place.
        :param test_size: Forwarded to `sklearn.model_selection.train_test_split`.
        :param n_epochs: Number of epochs passed to `model.fit`.
        :param metric: Optional callable `metric(y_true, y_pred)`; when None,
            `sklearn.metrics.accuracy_score` is used.
        :param random_seed: Forwarded as `random_state` to `train_test_split`.

        :returns: The value of the chosen metric on the held-out split.
        """
        X_hist, X_syn = _dataset_or_tensor_to_tensor(d_hist), _dataset_or_tensor_to_tensor(d_syn)
        X_all = np.concatenate([X_hist, X_syn])
        # Derive the label counts from the arrays that are actually concatenated,
        # so features and labels always have the same length.
        y_all = np.concatenate([[1] * len(X_hist), [0] * len(X_syn)])
        X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
            X_all, y_all, test_size=test_size, random_state=random_seed
        )
        # Fit on the training split only: fitting on all data would leak the
        # held-out samples into training and inflate the score.
        model.fit(X_train, y_train, epochs=n_epochs)
        # NOTE(review): predictions are passed to the metric exactly as the model
        # returns them; if the model emits scores/logits rather than class labels,
        # label-based metrics may need thresholding first - confirm with callers.
        y_pred = model.predict(X_test)
        if metric is None:
            return sklearn.metrics.accuracy_score(y_test, y_pred)
        return metric(y_test, y_pred)
6 changes: 3 additions & 3 deletions tutorials/GANs/RCGAN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,11 @@
"statistics = [functools.partial(tsgm.metrics.statistics.axis_max_s, axis=1),\n",
" functools.partial(tsgm.metrics.statistics.axis_min_s, axis=1)]\n",
"\n",
"sim_metric = tsgm.metrics.SimilarityMetric(\n",
"sim_metric = tsgm.metrics.DistanceMetric(\n",
" statistics=statistics, discrepancy=lambda x, y: np.linalg.norm(x - y)\n",
")\n",
"\n",
"print(f\"Similarity metric: {sim_metric(X, X_gen)}\")"
"print(f\"Distance metric: {sim_metric(X, X_gen)}\")"
]
},
{
Expand Down Expand Up @@ -259,7 +259,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.10.8"
}
},
"nbformat": 4,
Expand Down
10 changes: 5 additions & 5 deletions tutorials/Metrics Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@
"id": "a121d403",
"metadata": {},
"source": [
"## Similarity metric\n",
"## Distance metric\n",
"\n",
"First, we define a list of summary statistics that reflect the similarity between the datasets. Module `tss.metrics.statistics` defines a set of handy statistics."
"First, we define a list of summary statistics that reflect the distance between the datasets. The module `tsgm.metrics.statistics` defines a set of handy statistics."
]
},
{
Expand Down Expand Up @@ -103,7 +103,7 @@
"id": "b2a72a7c",
"metadata": {},
"source": [
"Finally, we are putting all together using `tss.metrics.SimilarityMetric` object."
"Finally, we put it all together using the `tsgm.metrics.DistanceMetric` object."
]
},
{
Expand All @@ -113,7 +113,7 @@
"metadata": {},
"outputs": [],
"source": [
"sim_metric = tsgm.metrics.SimilarityMetric(\n",
"sim_metric = tsgm.metrics.DistanceMetric(\n",
" statistics=statistics, discrepancy=discrepancy_func\n",
")"
]
Expand Down Expand Up @@ -322,7 +322,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.10.8"
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions tutorials/Model Selection.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
"metadata": {},
"outputs": [],
"source": [
"metric_to_optimize = tsgm.metrics.metrics.SimilarityMetric(\n",
"metric_to_optimize = tsgm.metrics.metrics.DistanceMetric(\n",
" statistics=[\n",
" functools.partial(tsgm.metrics.statistics.axis_max_s, axis=None),\n",
" functools.partial(tsgm.metrics.statistics.axis_min_s, axis=None),\n",
Expand Down Expand Up @@ -265,7 +265,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.10.8"
}
},
"nbformat": 4,
Expand Down

0 comments on commit c44016c

Please sign in to comment.