diff --git a/tests/test_op.py b/tests/test_op.py index 8e821e0058..f41b4eeacc 100644 --- a/tests/test_op.py +++ b/tests/test_op.py @@ -2,6 +2,7 @@ from numpy.testing import assert_allclose import mxnet as mx from mxnet import gluon +from scipy.stats import ks_2samp import pytest from gluonnlp.op import * mx.npx.set_np() @@ -102,12 +103,36 @@ def test_gumbel_softmax(shape): assume_allones = (ret == 1).sum(axis=-1).asnumpy() assert_allclose(assume_allones, np.ones_like(assume_allones)) - +@pytest.mark.parametrize('shape', (50,)) @pytest.mark.seed(1) -def test_trunc_gumbel(): - # TODO(?) Improve the test case here - # It's generally difficult to test whether the samples are generated from a truncated gumbel - # distribution. Thus, we just verify that the samples are smaller than the provided threshold +def test_trunc_gumbel(shape): + # We first just verify that the samples are smaller than the provided threshold (i.e. they are truncated) + # And also attempt to remove the truncation and verify if it is sampled from a gumbel distribution + # using a KS-test with another sampled gumbel distribution + + # Verifying if the distribution is truncated for i in range(1000): - samples = trunc_gumbel(mx.np.ones((10,)), 1.0).asnumpy() + samples = trunc_gumbel(mx.np.ones(shape), 1.0).asnumpy() assert (samples < 1.0).all() + + # perform ks-tests + pvalues = [] + for i in range(1000): + logits = mx.np.random.uniform(-2, -1, shape) + sampled_gumbels = mx.np.random.gumbel(mx.np.zeros_like(logits)) + logits # sample a gumbel distribution + + # sample a potential truncated gumbel distribution + gumbels = mx.np.random.gumbel(mx.np.zeros_like(logits)) + logits + sampled_truncated_gumbels = trunc_gumbel(logits, 0.5) + + # remove the truncation + reconstructed_sample = -mx.np.log(mx.np.exp(-sampled_truncated_gumbels) - mx.np.exp(-0.5)) + + pvalue = ks_2samp(reconstructed_sample.asnumpy(), sampled_gumbels.asnumpy()).pvalue + pvalues.append(pvalue) + + pvalues = np.array(pvalues) + # Statistical inference condition: if out of all the tests, 90% of the resultant p-values > 0.05, + # accept the null hypothesis (i.e. the reconstructed_samples indeed arrive from a gumbel distribution) + assert (len(pvalues[pvalues > 0.05]) > 900) +