Skip to content

Commit

Permalink
Fix discriminative metric (#45)
Browse files Browse the repository at this point in the history
* fix bugs in SliceAndShuffle

* fix DiscriminativeMetric

* better comment

* better comment

* update test_discriminative_metric

---------

Co-authored-by: liyiersan-server5 <[email protected]>
  • Loading branch information
liyiersan and liyiersan authored May 24, 2024
1 parent 5d02985 commit 6237a0f
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 87 deletions.
9 changes: 5 additions & 4 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,18 +191,19 @@ def test_mmd_metric():


def test_discriminative_metric():
ts = np.array([[[0, 2], [11, -11], [1, 2]], [[10, 21], [1, -1], [6, 8]]]).astype(np.float32)
ts = np.sin(np.arange(10)[:, None, None] + np.arange(6)[None, :, None]) # sin_sequence, [10, 6, 3]
D1 = tsgm.dataset.Dataset(ts, y=None)

diff_ts = np.array([[[12, 13], [10, 10], [-1, -2]], [[-1, 32], [2, 1], [10, 8]]]).astype(np.float32)
diff_ts = np.sin(np.arange(10)[:, None, None] + np.arange(6)[None, :, None]) + 1000 # sin_sequence, [10, 6, 3]
D2 = tsgm.dataset.Dataset(diff_ts, y=None)

model = tsgm.models.zoo["clf_cl_n"](seq_len=ts.shape[1], feat_dim=ts.shape[2], output_dim=1).model
model = tsgm.models.zoo["clf_cl_n"](seq_len=ts.shape[1], feat_dim=ts.shape[2], output_dim=2).model
model.compile(
tf.keras.optimizers.Adam(),
tf.keras.losses.CategoricalCrossentropy(from_logits=True)
tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
)
discr_metric = tsgm.metrics.DiscriminativeMetric()
# should be easy to be classified
assert discr_metric(d_hist=D1, d_syn=D2, model=model, test_size=0.2, random_seed=42, n_epochs=5) == 1.0
assert discr_metric(d_hist=D1, d_syn=D2, model=model, metric=sklearn.metrics.precision_score, test_size=0.2, random_seed=42, n_epochs=5) == 1.0

Expand Down
7 changes: 6 additions & 1 deletion tsgm/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,12 @@ def __call__(self, d_hist: tsgm.dataset.DatasetOrTensor, d_syn: tsgm.dataset.Dat
X_all, y_all = np.concatenate([X_hist, X_syn]), np.concatenate([[1] * len(d_hist), [0] * len(d_syn)])
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(X_all, y_all, test_size=test_size, random_state=random_seed)
model.fit(X_train, y_train, epochs=n_epochs)
y_pred = (model.predict(X_test) > 0.5).astype(int)
pred = model.predict(X_test)
# check the shape, 1D array or N-D arrary
if len(pred.shape) == 1: # binary classification with sigmoid activation
y_pred = (pred > 0.5).astype(int)
else: # multiple classification with softmax activation
y_pred = np.argmax(pred, axis=-1).astype(int)
if metric is None:
return sklearn.metrics.accuracy_score(y_test, y_pred)
else:
Expand Down
Loading

0 comments on commit 6237a0f

Please sign in to comment.