-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge: Fix rare degenerate `comp_df` bug (#399)
- Adds a utility `df_add_noise_to_degenerate_rows` to add noise to degenerate rows of a numerical dataframe
- Adds a test for the utility
- Disallows `CustomDiscreteParameter` accepting `data` that has degenerate rows
- Adds input tests for `CustomDiscreteParameter`
- Adds the fix to the `comp_df` of `SubstanceParameter`
- Adds an end-to-end test for degenerate substance encoding
- Loading branch information
Showing 8 changed files with 147 additions and 14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
"""Tests for dataframe utilities.""" | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
||
from baybe.utils.dataframe import add_noise_to_perturb_degenerate_rows | ||
|
||
|
||
def test_degenerate_rows():
    """Check that noise injection leaves no duplicated rows in a dataframe."""
    # Start from a random 5x3 frame of float values
    values = np.random.randint(0, 100, size=(5, 3))
    df = pd.DataFrame(values).astype(float)

    # Force two pairs of identical rows: row 1 copies row 0, row 3 copies row 2
    for target, source in ((1, 0), (3, 2)):
        df.loc[target] = df.loc[source]

    # A constant last column covers the edge case
    df.iloc[:, -1] = 50.0

    # Apply the utility; the return value is ignored, so this test relies on
    # the dataframe being modified in place
    add_noise_to_perturb_degenerate_rows(df)

    # Afterwards, no row may coincide with another one
    has_duplicates = df.duplicated().any()
    assert not has_duplicates, "Degenerate rows were not fixed by the utility."
|
||
|
||
def test_degenerate_rows_invalid_input():
    """Check that the utility raises ``TypeError`` for non-numeric content."""
    # Start from a random 5x3 frame of float values
    values = np.random.randint(0, 100, size=(5, 3))
    df = pd.DataFrame(values).astype(float)

    # Force two pairs of identical rows: row 1 copies row 0, row 3 copies row 2
    for target, source in ((1, 0), (3, 2)):
        df.loc[target] = df.loc[source]

    # Switch to object dtype before inserting strings so that pandas does not
    # emit column dtype warnings
    df = df.astype(object)
    df["invalid"] = "A"  # an entire column of strings
    df.iloc[1, 0] = "B"  # a single string cell in an otherwise numeric column

    # The utility is expected to reject the non-numeric content
    with pytest.raises(TypeError):
        add_noise_to_perturb_degenerate_rows(df)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters