Skip to content

Commit

Permalink
Added plot options
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander März committed Aug 8, 2023
1 parent ff16ec5 commit d97f4a5
Showing 1 changed file with 49 additions and 1 deletion.
50 changes: 49 additions & 1 deletion tests/test_distribution_utils/test_dist_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,30 @@ def test_flow_select_plot(self):
assert not np.isnan(dist_df["nll"].values).any()
assert not np.isinf(dist_df["nll"].values).any()

def test_flow_select_plot(self):
# Create data for testing
target = np.array([0.2, 0.4, 0.6, 0.8]).reshape(-1, 1)
bound = np.max([np.abs(target.min()), target.max()])
target_support = "real"

candidate_flows = [
SplineFlow(target_support=target_support, count_bins=2, bound=bound, order="linear"),
SplineFlow(target_support=target_support, count_bins=2, bound=bound, order="quadratic")
]

# Call the function
dist_df = flow_dist_class().flow_select(
target, candidate_flows, n_samples=10, plot=True
).reset_index(drop=True)

# Assertions
assert isinstance(dist_df, pd.DataFrame)
assert not dist_df.isna().any().any()
assert isinstance(dist_df["NormFlow"].values[0], str)
assert np.issubdtype(dist_df["nll"].dtype, np.float64)
assert not np.isnan(dist_df["nll"].values).any()
assert not np.isinf(dist_df["nll"].values).any()

####################################################################################################################
# Multivariate Distribution
####################################################################################################################
Expand All @@ -112,7 +136,7 @@ def test_multivar_dist_select(self):
target,
multivar_dist_class.n_targets,
multivar_dist_class.n_dist_param
)
)[:, :multivar_dist_class.n_targets]
candidate_distributions = [MVN(), MVT(), MVN_LoRa()]

# Call the function
Expand All @@ -127,3 +151,27 @@ def test_multivar_dist_select(self):
assert np.issubdtype(dist_df["nll"].dtype, np.float64)
assert not np.isnan(dist_df["nll"].values).any()
assert not np.isinf(dist_df["nll"].values).any()

def test_multivar_dist_select_plot(self):
# Create data for testing
multivar_dist_class = MVN()
target = np.arange(0.1, 0.9, 0.1)
target = multivar_dist_class.target_append(
target,
multivar_dist_class.n_targets,
multivar_dist_class.n_dist_param
)[:, :multivar_dist_class.n_targets]
candidate_distributions = [MVN(), MVT(), MVN_LoRa()]

# Call the function
dist_df = multivariate_dist_class().dist_select(
target, candidate_distributions, n_samples=10, plot=True, ncol=1
).reset_index(drop=True)

# Assertions
assert isinstance(dist_df, pd.DataFrame)
assert not dist_df.isna().any().any()
assert isinstance(dist_df["distribution"].values[0], str)
assert np.issubdtype(dist_df["nll"].dtype, np.float64)
assert not np.isnan(dist_df["nll"].values).any()
assert not np.isinf(dist_df["nll"].values).any()

0 comments on commit d97f4a5

Please sign in to comment.