-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Config and serialization enhancement (#86)
This PR significantly enhances our serialization machinery: * Search spaces can now be deserialized using a selected classmethod. Currently, this is enabled by manually registering hooks for our search space classes. However, if we add more constructors to other classes in the future, we can consider automating the process using hook predicate functions / factories. Perhaps, there will even be a dedicated [strategy](python-attrs/cattrs#489) available for this mechanism. * Due to this change, there is no more separate "config converter" required, since creation from config can now happen via the regular converter. In fact, in future we can even think about deprecating our "from_config" approach, because there really is no more such thing as a "config" – it's just a "regular" JSON string that goes through the default converter. * The serialization functionality now sits in its own subpackage. * Missing serialization mixins have been added to `Interval` and the two subspace classes. * Added a basic serialization roundtrip test for dataframes and a corresponding hypothesis strategy. * Changed binarization of dataframes to use regular pickle due to some edge cases with the previous `parquet` approach that where detected through the above test. * Added a `SearchSpace.from_dataframe` convenience constructor for consistency, which is also helps to simplify campaign configs.
- Loading branch information
Showing
37 changed files
with
750 additions
and
303 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
"""Parameter utilities.""" | ||
|
||
from typing import Any, Callable, Collection, Dict, List, Optional, TypeVar | ||
|
||
import pandas as pd | ||
|
||
from baybe.parameters.base import Parameter | ||
|
||
_TParameter = TypeVar("_TParameter", bound=Parameter) | ||
|
||
|
||
def get_parameters_from_dataframe( | ||
df: pd.DataFrame, | ||
factory: Callable[[str, Collection[Any]], _TParameter], | ||
parameters: Optional[List[_TParameter]] = None, | ||
) -> List[_TParameter]: | ||
"""Create a list of parameters from a dataframe. | ||
Returns one parameter for each column of the given dataframe. By default, | ||
the parameters are created using the provided factory, which takes the name | ||
of the column and its unique values as arguments. However, there is also | ||
the possibility to provide explicit parameter objects with names matching specific | ||
columns of the dataframe, to bypass the parameter factory creation for those | ||
columns. This allows finer control, for example, to specify custom parameter | ||
attributes (e.g. specific optional arguments) compared to what would be provided | ||
by the factory. Still, the pre-specified parameters are validated to ensure that | ||
they are compatible with the contents of the dataframe. | ||
Args: | ||
df: The dataframe from which to create the parameters. | ||
factory: A parameter factor, creating parameter objects for the columns | ||
from the column name and the unique column values. | ||
parameters: An optional list of parameter objects to bypass the factory | ||
creation for columns whose names match with the parameter names. | ||
Returns: | ||
The combined parameter list, containing both the (validated) pre-specified | ||
parameters and the parameters inferred from the dataframe. | ||
Raises: | ||
ValueError: If several parameters with identical names are provided. | ||
ValueError: If a parameter was specified for which no match was found. | ||
""" | ||
# Turn the pre-specified parameters into a dict and check for duplicate names | ||
specified_params: Dict[str, _TParameter] = {} | ||
if parameters is not None: | ||
for param in parameters: | ||
if param.name in specified_params: | ||
raise ValueError( | ||
f"You provided several parameters with the name '{param.name}'." | ||
) | ||
specified_params[param.name] = param | ||
|
||
# Try to find a parameter match for each dataframe column | ||
parameters = [] | ||
for name, series in df.items(): | ||
assert isinstance( | ||
name, str | ||
), "The given dataframe must only contain string-valued column names." | ||
unique_values = series.unique() | ||
|
||
# If a match is found, assert that the values are in range | ||
if match := specified_params.pop(name, None): | ||
if not all(match.is_in_range(x) for x in unique_values): | ||
raise ValueError( | ||
f"The dataframe column '{name}' contains the values " | ||
f"{unique_values}, which are outside the range of {match}." | ||
) | ||
parameters.append(match) | ||
|
||
# Otherwise, create a new parameter using the factory | ||
else: | ||
param = factory(name, unique_values) | ||
parameters.append(param) | ||
|
||
# By now, all pre-specified parameters must have been used | ||
if specified_params: | ||
raise ValueError( | ||
f"For the parameter(s) {list(specified_params.keys())}, " | ||
f"no match could be found in the given dataframe." | ||
) | ||
|
||
return parameters |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.