Skip to content

Commit

Permalink
Feature/pretty print objective (#175)
Browse files Browse the repository at this point in the history
This PR implements the pretty print concept on `objective`.

- It implements the `__str__()` for the classes `objective`, `targets` and `parameters`
- It also fixes the problem of printing too many line breaks in case a `searchspace` has an empty subspace
  • Loading branch information
RimRihana authored Mar 25, 2024
2 parents 10d1372 + d028d22 commit be2dca2
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Simulation user guide
- Example for transfer learning backtest utility
- `pyupgrade` pre-commit hook
- Better human readable `__str__` representation of objective and targets

### Changed
- More detailed and sophisticated search space user guide
Expand Down
15 changes: 15 additions & 0 deletions baybe/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,21 @@ class Objective(SerialMixin):
)
"""The function used to combine the different targets."""

def __str__(self) -> str:
start_bold = "\033[1m"
end_bold = "\033[0m"

# Convert the targets list to a dataframe to have a tabular output
targets_list = [target.summary() for target in self.targets]
targets_df = pd.DataFrame(targets_list)
targets_df["Weight"] = self.weights

objective_str = f"""{start_bold}Objective{end_bold}
\n{start_bold}Mode: {end_bold}{self.mode}
\n{start_bold}Targets {end_bold}\n{targets_df}
\n{start_bold}Combine Function: {end_bold}{self.combine_func}"""
return objective_str.replace("\n", "\n ")

@weights.default
def _default_weights(self) -> list[float]:
"""Create the default weights."""
Expand Down
3 changes: 3 additions & 0 deletions baybe/parameters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def is_in_range(self, item: Any) -> bool:
def summary(self) -> dict:
"""Return a custom summarization of the parameter."""

def __str__(self) -> str:
return str(self.summary())


@define(frozen=True, slots=False)
class DiscreteParameter(Parameter, ABC):
Expand Down
14 changes: 10 additions & 4 deletions baybe/searchspace/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,16 @@ class SearchSpace(SerialMixin):
def __str__(self) -> str:
start_bold = "\033[1m"
end_bold = "\033[0m"
searchspace_str = f"""{start_bold}Search Space{end_bold}
\n{start_bold}Search Space Type: {end_bold}{self.type.name}
\n{self.discrete}
\n{self.continuous}"""
head_str = f"""{start_bold}Search Space{end_bold}
\n{start_bold}Search Space Type: {end_bold}{self.type.name}"""

# Check the sub space size to avoid adding unwanted break lines
# if the sub space is empty
discrete_str = f"\n\n{self.discrete}" if not self.discrete.is_empty else ""
continuous_str = (
f"\n\n{self.continuous}" if not self.continuous.is_empty else ""
)
searchspace_str = f"{head_str}{discrete_str}{continuous_str}"
return searchspace_str.replace("\n", "\n ").replace("\r", "\r ")

def __attrs_post_init__(self):
Expand Down
7 changes: 7 additions & 0 deletions baybe/targets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ def transform(self, data: pd.DataFrame) -> pd.DataFrame:
A dataframe containing the transformed data.
"""

@abstractmethod
def summary(self) -> dict:
"""Return a custom summarization of the target."""

def __str__(self) -> str:
return str(self.summary())


def _add_missing_type_hook(hook):
"""Adjust the structuring hook such that it auto-fills missing target types.
Expand Down
12 changes: 12 additions & 0 deletions baybe/targets/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,15 @@ def transform(self, data: pd.DataFrame) -> pd.DataFrame: # noqa: D102
transformed = data.copy()

return transformed

def summary(self) -> dict: # noqa: D102
# See base class.
target_dict = dict(
Type=self.__class__.__name__,
Name=self.name,
Mode=self.mode.name,
Lower_Bound=self.bounds.lower,
Upper_Bound=self.bounds.upper,
Transformation=self.transformation.name if self.transformation else "None",
)
return target_dict

0 comments on commit be2dca2

Please sign in to comment.