Skip to content

Commit

Permalink
Update Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander März committed Aug 7, 2023
1 parent e6b5346 commit deb7a8e
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 69 deletions.
126 changes: 63 additions & 63 deletions test/test_distribution_utils/test_dist_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,23 @@ def test_univar_dist_select(self):
assert not np.isnan(dist_df["nll"].values).any()
assert not np.isinf(dist_df["nll"].values).any()

# def test_univar_dist_select_plot(self):
# # Create data for testing
# target = np.array([0.2, 0.4, 0.6, 0.8]).reshape(-1, 1)
# candidate_distributions = [Beta, Gaussian, StudentT, Gamma, Cauchy, LogNormal, Weibull, Gumbel, Laplace]
#
# # Call the function
# dist_df = univariate_dist_class().dist_select(
# target, candidate_distributions, n_samples=10, plot=True
# ).reset_index(drop=True)
#
# # Assertions
# assert isinstance(dist_df, pd.DataFrame)
# assert not dist_df.isna().any().any()
# assert isinstance(dist_df["distribution"].values[0], str)
# assert np.issubdtype(dist_df["nll"].dtype, np.float64)
# assert not np.isnan(dist_df["nll"].values).any()
# assert not np.isinf(dist_df["nll"].values).any()
def test_univar_dist_select_plot(self):
# Create data for testing
target = np.array([0.2, 0.4, 0.6, 0.8]).reshape(-1, 1)
candidate_distributions = [Beta, Gaussian, StudentT, Gamma, Cauchy, LogNormal, Weibull, Gumbel, Laplace]

# Call the function
dist_df = univariate_dist_class().dist_select(
target, candidate_distributions, n_samples=10, plot=True
).reset_index(drop=True)

# Assertions
assert isinstance(dist_df, pd.DataFrame)
assert not dist_df.isna().any().any()
assert isinstance(dist_df["distribution"].values[0], str)
assert np.issubdtype(dist_df["nll"].dtype, np.float64)
assert not np.isnan(dist_df["nll"].values).any()
assert not np.isinf(dist_df["nll"].values).any()

####################################################################################################################
# Normalizing Flows
Expand Down Expand Up @@ -79,29 +79,29 @@ def test_flow_select(self):
assert not np.isnan(dist_df["nll"].values).any()
assert not np.isinf(dist_df["nll"].values).any()

# def test_flow_select_plot(self):
# # Create data for testing
# target = np.array([0.2, 0.4, 0.6, 0.8]).reshape(-1, 1)
# bound = np.max([np.abs(target.min()), target.max()])
# target_support = "real"
#
# candidate_flows = [
# SplineFlow(target_support=target_support, count_bins=2, bound=bound, order="linear"),
# SplineFlow(target_support=target_support, count_bins=2, bound=bound, order="quadratic")
# ]
#
# # Call the function
# dist_df = flow_dist_class().flow_select(
# target, candidate_flows, n_samples=10, plot=True
# ).reset_index(drop=True)
#
# # Assertions
# assert isinstance(dist_df, pd.DataFrame)
# assert not dist_df.isna().any().any()
# assert isinstance(dist_df["NormFlow"].values[0], str)
# assert np.issubdtype(dist_df["nll"].dtype, np.float64)
# assert not np.isnan(dist_df["nll"].values).any()
# assert not np.isinf(dist_df["nll"].values).any()
def test_flow_select_plot(self):
# Create data for testing
target = np.array([0.2, 0.4, 0.6, 0.8]).reshape(-1, 1)
bound = np.max([np.abs(target.min()), target.max()])
target_support = "real"

candidate_flows = [
SplineFlow(target_support=target_support, count_bins=2, bound=bound, order="linear"),
SplineFlow(target_support=target_support, count_bins=2, bound=bound, order="quadratic")
]

# Call the function
dist_df = flow_dist_class().flow_select(
target, candidate_flows, n_samples=10, plot=True
).reset_index(drop=True)

# Assertions
assert isinstance(dist_df, pd.DataFrame)
assert not dist_df.isna().any().any()
assert isinstance(dist_df["NormFlow"].values[0], str)
assert np.issubdtype(dist_df["nll"].dtype, np.float64)
assert not np.isnan(dist_df["nll"].values).any()
assert not np.isinf(dist_df["nll"].values).any()

####################################################################################################################
# Multivariate Distribution
Expand Down Expand Up @@ -130,26 +130,26 @@ def test_multivar_dist_select(self):
assert not np.isnan(dist_df["nll"].values).any()
assert not np.isinf(dist_df["nll"].values).any()

# def test_multivar_dist_select_plot(self):
# # Create data for testing
# multivar_dist_class = MVN()
# target = np.arange(0.1, 0.9, 0.1)
# target = multivar_dist_class.target_append(
# target,
# multivar_dist_class.n_targets,
# multivar_dist_class.n_dist_param
# )
# candidate_distributions = [MVN(), MVT(), MVN_LoRa()]
#
# # Call the function
# dist_df = multivariate_dist_class().dist_select(
# target, candidate_distributions, n_samples=10, plot=True
# ).reset_index(drop=True)
#
# # Assertions
# assert isinstance(dist_df, pd.DataFrame)
# assert not dist_df.isna().any().any()
# assert isinstance(dist_df["distribution"].values[0], str)
# assert np.issubdtype(dist_df["nll"].dtype, np.float64)
# assert not np.isnan(dist_df["nll"].values).any()
# assert not np.isinf(dist_df["nll"].values).any()
def test_multivar_dist_select_plot(self):
# Create data for testing
multivar_dist_class = MVN()
target = np.arange(0.1, 0.9, 0.1)
target = multivar_dist_class.target_append(
target,
multivar_dist_class.n_targets,
multivar_dist_class.n_dist_param
)
candidate_distributions = [MVN(), MVT(), MVN_LoRa()]

# Call the function
dist_df = multivariate_dist_class().dist_select(
target, candidate_distributions, n_samples=10, plot=True
).reset_index(drop=True)

# Assertions
assert isinstance(dist_df, pd.DataFrame)
assert not dist_df.isna().any().any()
assert isinstance(dist_df["distribution"].values[0], str)
assert np.issubdtype(dist_df["nll"].dtype, np.float64)
assert not np.isnan(dist_df["nll"].values).any()
assert not np.isinf(dist_df["nll"].values).any()
1 change: 0 additions & 1 deletion test/test_distribution_utils/test_draw_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

class TestClass(BaseTestClass):
def test_draw_samples(self, dist_class, loss_fn):

if dist_class.dist.univariate:
# Create data for testing
predt_params = pd.DataFrame(np.array([0.5 for _ in range(dist_class.dist.n_dist_param)], dtype="float32")).T
Expand Down
4 changes: 1 addition & 3 deletions test/test_distributions/test_Expectile.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,4 @@ def test_expectile_norm():
# Assertions
assert isinstance(out, np.ndarray)
assert not np.isnan(out).any()
assert not np.isinf(out).any(

)
assert not np.isinf(out).any()
4 changes: 2 additions & 2 deletions test/test_distributions/test_multivariate_distns.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pandas as pd

from ..utils import BaseTestClass
from xgboostlss.utils import softplus_fn
from xgboostlss.utils import softplus_fn, softplus_fn_df
import pytest
import torch
import numpy as np
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_create_param_dict(self, multivariate_dist):
if multivariate_dist.__name__ == "MVN_LoRa":
param_dict = multivariate_dist.create_param_dict(n_targets=2, rank=1, response_fn=softplus_fn)
if multivariate_dist.__name__ == "MVT":
param_dict = multivariate_dist.create_param_dict(n_targets=2, response_fn=softplus_fn, response_fn_df=softplus_fn)
param_dict = multivariate_dist.create_param_dict(n_targets=2, response_fn=softplus_fn, response_fn_df=softplus_fn_df)
assert isinstance(param_dict, dict)
assert all(callable(func) for func in param_dict.values())

Expand Down

0 comments on commit deb7a8e

Please sign in to comment.