Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Discriminative metric #23

Merged
merged 3 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 42 additions & 7 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,38 @@ def test_statistics():
assert (tsgm.metrics.statistics.axis_mode_s(ts_tf, axis=None) == [1]).all()


def test_similarity_metric():
def test_distance_metric():
ts = np.array([[[0, 2], [11, -11], [1, 2]], [[10, 21], [1, -1], [6, 8]]])
diff_ts = np.array([[[0, 2], [11, -11], [1, 2]], [[10, 21], [1, -1], [6, 8]]])
sim_ts = ts + 1e-7
diff_ts = 10 * ts
y = np.ones((ts.shape[0], 1))

statistics = [functools.partial(tsgm.metrics.statistics.axis_max_s, axis=None),
functools.partial(tsgm.metrics.statistics.axis_min_s, axis=None),
functools.partial(tsgm.metrics.statistics.axis_max_s, axis=1),
functools.partial(tsgm.metrics.statistics.axis_min_s, axis=1)]

sim_metric = tsgm.metrics.SimilarityMetric(
dist_metric = tsgm.metrics.DistanceMetric(
statistics=statistics, discrepancy=lambda x, y: np.linalg.norm(x - y)
)
assert sim_metric(ts, diff_ts) < sim_metric(ts, sim_ts)
stat_results = sim_metric.stats(ts)
assert dist_metric(ts, diff_ts) > dist_metric(ts, sim_ts)
stat_results = dist_metric.stats(ts)

assert len(stat_results) == 6
assert sim_metric._discrepancy(sim_metric.stats(ts), sim_metric.stats(diff_ts)) == sim_metric(ts, diff_ts)
assert sim_metric(ts, diff_ts) == sim_metric(diff_ts, ts)
assert dist_metric._discrepancy(dist_metric.stats(ts), dist_metric.stats(sim_ts)) == dist_metric(ts, sim_ts)
assert dist_metric(ts, sim_ts) != dist_metric(ts, diff_ts)
assert dist_metric(ts, ts) == 0
assert dist_metric(diff_ts, ts) == dist_metric(ts, diff_ts)

# with labels
ds = tsgm.dataset.Dataset(ts, y)
ds_diff = tsgm.dataset.Dataset(diff_ts, y)
ds_sim = tsgm.dataset.Dataset(sim_ts, y)
assert dist_metric(ts, diff_ts) != 0
assert dist_metric(ds, ds) == 0
assert dist_metric(ds_sim, ds) < dist_metric(ds_diff, ds)
assert dist_metric(ds, ds_diff) == dist_metric(ds_diff, ds)



class MockEvaluator:
Expand Down Expand Up @@ -99,6 +113,10 @@ def test_downstream_performance_metric():
assert downstream_perf_metric(D1, D2, D_test) == 0

assert downstream_perf_metric(D1, D2, D_test) == downstream_perf_metric(ts, diff_ts, test_ts)
assert downstream_perf_metric(D1, D2, D_test) == downstream_perf_metric(D1, diff_ts, D_test)
assert downstream_perf_metric(D1, D2, D_test) == downstream_perf_metric(ts, D2, D_test)
mean, std = downstream_perf_metric(D1, D2, D_test, return_std=True)
assert mean == 0 and std == 0


class FlattenTSOneClassSVM:
Expand Down Expand Up @@ -145,3 +163,20 @@ def test_mmd_metric():

assert mmd_metric(D1, D2) == mmd_metric(ts, diff_ts)
assert mmd_metric(D1, D1) == 0 and mmd_metric(D2, D2) == 0


def test_discriminative_metric():
ts = np.array([[[0, 2], [11, -11], [1, 2]], [[10, 21], [1, -1], [6, 8]]]).astype(np.float32)
D1 = tsgm.dataset.Dataset(ts, y=None)

diff_ts = np.array([[[12, 13], [10, 10], [-1, -2]], [[-1, 32], [2, 1], [10, 8]]]).astype(np.float32)
D2 = tsgm.dataset.Dataset(diff_ts, y=None)

model = tsgm.models.zoo["clf_cl_n"](seq_len=ts.shape[1], feat_dim=ts.shape[2], output_dim=1).model
model.compile(
tf.keras.optimizers.Adam(),
tf.keras.losses.CategoricalCrossentropy(from_logits=True)
)
discr_metric = tsgm.metrics.DiscriminativeMetric()
assert discr_metric(d_hist=D1, d_syn=D2, model=model, test_size=0.2, random_seed=42, n_epochs=10) == 1.0
assert discr_metric(d_hist=D1, d_syn=D2, model=model, metric=sklearn.metrics.precision_score, test_size=0.2, random_seed=42, n_epochs=10) == 1.0
4 changes: 2 additions & 2 deletions tsgm/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import tsgm.metrics.statistics
from tsgm.metrics.metrics import (
SimilarityMetric, ConsistencyMetric, BaseDownstreamEvaluator,
DistanceMetric, ConsistencyMetric, BaseDownstreamEvaluator,
DownstreamPerformanceMetric, PrivacyMembershipInferenceMetric,
MMDMetric
MMDMetric, DiscriminativeMetric
)
28 changes: 21 additions & 7 deletions tsgm/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __call__(self, *args, **kwargs) -> float:
pass


class SimilarityMetric(Metric):
class DistanceMetric(Metric):
"""
Metric that measures similarity between synthetic and real time series
"""
Expand Down Expand Up @@ -73,10 +73,6 @@ def __call__(self, D1: tsgm.dataset.DatasetOrTensor, D2: tsgm.dataset.DatasetOrT
:returns: similarity metric between D1 & D2.
"""

# TODO: check compatibility of this metric in different versions of python
# typing.get_args() can be used instead
# assert isinstance(D1, tsgm.dataset.Dataset) and isinstance(D2, tsgm.dataset.Dataset) or\
# isinstance(D1, tsgm.types.Tensor.__args__) and isinstance(D2, tsgm.types.Tensor.__args__)
if isinstance(D1, tsgm.dataset.Dataset) and isinstance(D2, tsgm.dataset.Dataset):
X1, X2 = D1.Xy_concat, D2.Xy_concat
else:
Expand Down Expand Up @@ -151,10 +147,12 @@ def __call__(self, D1: tsgm.dataset.DatasetOrTensor, D2: tsgm.dataset.DatasetOrT

:returns: downstream performance metric between D1 & D2.
"""
if isinstance(D1, tsgm.dataset.Dataset):
if isinstance(D1, tsgm.dataset.Dataset) and isinstance(D2, tsgm.dataset.Dataset):
D1D2 = D1 | D2
else:
if isinstance(D2, tsgm.dataset.Dataset):
if isinstance(D1, tsgm.dataset.Dataset):
D1D2 = np.concatenate((D1.X, D2))
elif isinstance(D2, tsgm.dataset.Dataset):
D1D2 = np.concatenate((D1, D2.X))
else:
D1D2 = np.concatenate((D1, D2))
Expand Down Expand Up @@ -211,3 +209,19 @@ def __call__(self, D1: tsgm.dataset.DatasetOrTensor, D2: tsgm.dataset.DatasetOrT
logger.warning("It is currently impossible to run MMD for labeled time series. Labels will be ignored!")
X1, X2 = _dataset_or_tensor_to_tensor(D1), _dataset_or_tensor_to_tensor(D2)
return tsgm.utils.mmd.MMD(X1, X2, kernel=self.kernel)


class DiscriminativeMetric(Metric):
"""
The discriminative metric measures how accurately a discriminative model can separate synthetic and real data.
"""
def __call__(self, d_hist: tsgm.dataset.DatasetOrTensor, d_syn: tsgm.dataset.DatasetOrTensor, model, test_size, n_epochs, metric=None, random_seed=None) -> float:
X_hist, X_syn = _dataset_or_tensor_to_tensor(d_hist), _dataset_or_tensor_to_tensor(d_syn)
X_all, y_all = np.concatenate([X_hist, X_syn]), np.concatenate([[1] * len(d_hist), [0] * len(d_syn)])
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(X_all, y_all, test_size=test_size, random_state=random_seed)
model.fit(X_train, y_train, epochs=n_epochs)
y_pred = model.predict(X_test)
if metric is None:
return sklearn.metrics.accuracy_score(y_test, y_pred)
else:
return metric(y_test, y_pred)
6 changes: 3 additions & 3 deletions tutorials/GANs/RCGAN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,11 @@
"statistics = [functools.partial(tsgm.metrics.statistics.axis_max_s, axis=1),\n",
" functools.partial(tsgm.metrics.statistics.axis_min_s, axis=1)]\n",
"\n",
"sim_metric = tsgm.metrics.SimilarityMetric(\n",
"sim_metric = tsgm.metrics.DistanceMetric(\n",
" statistics=statistics, discrepancy=lambda x, y: np.linalg.norm(x - y)\n",
")\n",
"\n",
"print(f\"Similarity metric: {sim_metric(X, X_gen)}\")"
"print(f\"Distance metric: {sim_metric(X, X_gen)}\")"
]
},
{
Expand Down Expand Up @@ -259,7 +259,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.10.8"
}
},
"nbformat": 4,
Expand Down
10 changes: 5 additions & 5 deletions tutorials/Metrics Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@
"id": "a121d403",
"metadata": {},
"source": [
"## Similarity metric\n",
"## Distance metric\n",
"\n",
"First, we define a list of summary statistics that reflect the similarity between the datasets. Module `tss.metrics.statistics` defines a set of handy statistics."
"First, we define a list of summary statistics that reflect the distance between the datasets. Module `tss.metrics.statistics` defines a set of handy statistics."
]
},
{
Expand Down Expand Up @@ -103,7 +103,7 @@
"id": "b2a72a7c",
"metadata": {},
"source": [
"Finally, we are putting all together using `tss.metrics.SimilarityMetric` object."
"Finally, we are putting all together using `tss.metrics.DistanceMetric` object."
]
},
{
Expand All @@ -113,7 +113,7 @@
"metadata": {},
"outputs": [],
"source": [
"sim_metric = tsgm.metrics.SimilarityMetric(\n",
"sim_metric = tsgm.metrics.DistanceMetric(\n",
" statistics=statistics, discrepancy=discrepancy_func\n",
")"
]
Expand Down Expand Up @@ -322,7 +322,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.10.8"
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions tutorials/Model Selection.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
"metadata": {},
"outputs": [],
"source": [
"metric_to_optimize = tsgm.metrics.metrics.SimilarityMetric(\n",
"metric_to_optimize = tsgm.metrics.metrics.DistanceMetric(\n",
" statistics=[\n",
" functools.partial(tsgm.metrics.statistics.axis_max_s, axis=None),\n",
" functools.partial(tsgm.metrics.statistics.axis_min_s, axis=None),\n",
Expand Down Expand Up @@ -265,7 +265,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.10.8"
}
},
"nbformat": 4,
Expand Down