Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
[TEST] Add statistical inference for Truncated Gumbel distribution (#…
Browse files Browse the repository at this point in the history
…1578)

* Add statistical inference for Truncated Gumbel distribution

* Fix mxnet issue with test
  • Loading branch information
AetherPrior authored Aug 24, 2021
1 parent f5a9fc1 commit 83780ab
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions tests/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from numpy.testing import assert_allclose
import mxnet as mx
from mxnet import gluon
from scipy.stats import ks_2samp
import pytest
from gluonnlp.op import *
mx.npx.set_np()
Expand Down Expand Up @@ -102,12 +103,36 @@ def test_gumbel_softmax(shape):
assume_allones = (ret == 1).sum(axis=-1).asnumpy()
assert_allclose(assume_allones, np.ones_like(assume_allones))


@pytest.mark.parametrize('shape', (50,))
@pytest.mark.seed(1)
def test_trunc_gumbel(shape):
    """Check that ``trunc_gumbel`` truncates and that its samples are Gumbel-like.

    The test has two parts:

    1. Truncation: every sample drawn with threshold 1.0 must be strictly
       smaller than 1.0.
    2. Distribution: the truncation is undone analytically and the result is
       compared against freshly drawn Gumbel samples (same location
       parameters) with a two-sample Kolmogorov-Smirnov test. If the
       reconstructed samples are Gumbel distributed, roughly 95% of the
       p-values should exceed 0.05; the test requires more than 90%.

    Parameters
    ----------
    shape
        Shape of the logits (and hence of each batch of samples).
    """
    num_trials = 1000

    # Part 1: verify that the samples respect the truncation threshold.
    for _ in range(num_trials):
        samples = trunc_gumbel(mx.np.ones(shape), 1.0).asnumpy()
        assert (samples < 1.0).all()

    # Part 2: KS-tests between de-truncated samples and reference Gumbel samples.
    threshold = 0.5
    significance = 0.05
    pvalues = []
    for _ in range(num_trials):
        logits = mx.np.random.uniform(-2, -1, shape)

        # Reference: Gumbel noise shifted by the logits.
        sampled_gumbels = mx.np.random.gumbel(mx.np.zeros_like(logits)) + logits

        # Candidate: samples that should follow a truncated Gumbel distribution.
        sampled_truncated_gumbels = trunc_gumbel(logits, threshold)

        # Remove the truncation.
        # NOTE(review): assumes trunc_gumbel computes
        # -log(exp(-threshold) + exp(-g)) for an untruncated Gumbel g, which is
        # what this expression inverts -- confirm against gluonnlp.op.
        reconstructed_sample = -mx.np.log(
            mx.np.exp(-sampled_truncated_gumbels) - mx.np.exp(-threshold))

        pvalue = ks_2samp(reconstructed_sample.asnumpy(),
                          sampled_gumbels.asnumpy()).pvalue
        pvalues.append(pvalue)

    pvalues = np.array(pvalues)
    # Statistical inference condition: if more than 90% of the p-values exceed
    # the significance level, accept the null hypothesis (i.e. the reconstructed
    # samples come from the same Gumbel distribution as the reference samples).
    assert (pvalues > significance).sum() > 900

0 comments on commit 83780ab

Please sign in to comment.