Add docstrings and print statements #230

Merged
merged 10 commits on Jul 8, 2024
Changes from all commits
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
@@ -26,3 +26,10 @@ repos:
     hooks:
       - id: isort
         args: ["--profile", "black"]
+
+  - repo: https://github.com/PyCQA/docformatter
+    rev: v1.7.5
+    hooks:
+      - id: docformatter
+        additional_dependencies: [tomli]
+        args: ["--in-place", "--black"]
4 changes: 2 additions & 2 deletions examples/baseline_link.py
@@ -69,8 +69,8 @@ def evaluate(
         )
         pred = np.stack(pred_ser.values)
     elif name == "global_popularity":
-        """Predict the globally most visited dst nodes and predict them across
-        the src nodes."""
+        """Predict the globally most visited dst nodes and predict them across the src
+        nodes."""
         lst_cat = []
         for lst in train_table.df[task.dst_entity_col]:
             lst_cat.extend(lst)
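The "global_popularity" branch above concatenates every dst list in the training table and ranks dst nodes by visit count. A hedged sketch of the full idea; train_table and task come from the surrounding script, while eval_k and num_src_nodes are assumptions for illustration:

import numpy as np
import pandas as pd

# Count how often each dst node is visited across the training table.
lst_cat = []
for lst in train_table.df[task.dst_entity_col]:
    lst_cat.extend(lst)
counts = pd.Series(lst_cat).value_counts()

# Predict the same k most popular dst nodes for every src node.
top_dst = counts.index[:eval_k].to_numpy()
pred = np.tile(top_dst, (num_src_nodes, 1))  # one identical row per src node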
14 changes: 7 additions & 7 deletions examples/lightgbm_link.py
@@ -130,8 +130,8 @@ def add_past_label_feature(
     train_table_df: pd.DataFrame,
     past_table_df: pd.DataFrame,
 ) -> pd.DataFrame:
-    """Add past visit count and percentage of global popularity to train table
-    df used for lightGBM training, evaluation of testing.
+    """Add past visit count and percentage of global popularity to train table df used
+    for lightGBM training, evaluation of testing.

     Args:
         evaluate_table_df (pd.DataFrame): The dataframe used for evaluation.

@@ -244,8 +244,8 @@ def add_past_label_feature(
 def prepare_for_link_pred_eval(
     evaluate_table_df: pd.DataFrame, past_table_df: pd.DataFrame
 ) -> pd.DataFrame:
-    """Transform evaluation dataframe into the correct format for link
-    prediction metric calculation.
+    """Transform evaluation dataframe into the correct format for link prediction metric
+    calculation.

     Args:
         pred_table_df (pd.DataFrame): The prediction dataframe.

@@ -375,9 +375,9 @@ def evaluate(
     train_table: Table,
     task: LinkTask,
 ) -> Dict[str, float]:
-    """Given the input dataframe used for lightGBM binary link classification
-    and its output prediction scores and true labels, generate link prediction
-    evaluation metrics.
+    """Given the input dataframe used for lightGBM binary link classification and its
+    output prediction scores and true labels, generate link prediction evaluation
+    metrics.

     Args:
         lightgbm_output (pd.DataFrame): The lightGBM input dataframe merged
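To make the last docstring concrete: LightGBM scores individual (src, dst) candidate pairs, while link prediction metrics expect a ranked dst list per src node. A minimal sketch of that regrouping step; the column names "src", "dst", and "score" are assumptions, not the script's actual ones:

import numpy as np
import pandas as pd

def scores_to_topk(pairs: pd.DataFrame, k: int) -> np.ndarray:
    """Turn per-pair scores into a top-k dst array per src node.

    Assumes every src has at least k scored candidates.
    """
    ranked = pairs.sort_values("score", ascending=False)
    # groupby preserves the sorted (descending-score) order within each group
    return np.stack([grp.to_numpy()[:k] for _, grp in ranked.groupby("src")["dst"]])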
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

 [project]
 name = "relbench"
-version = "0.2.0"
+version = "1.0.0-rc1"
 description = "RelBench: Relational Deep Learning Benchmark"
 authors = [{name = "RelBench Team", email = "[email protected]"}]
 readme = "README.md"
2 changes: 1 addition & 1 deletion relbench/__init__.py
@@ -1,3 +1,3 @@
 from relbench import base, datasets, modeling, tasks

-__version__ = "1.0.0"
+__version__ = "1.0.0-rc1"
33 changes: 17 additions & 16 deletions relbench/base/database.py
@@ -10,8 +10,8 @@


 class Database:
-    r"""A database is a collection of named tables linked by foreign key -
-    primary key connections."""
+    r"""A database is a collection of named tables linked by foreign key - primary key
+    connections."""

     def __init__(self, table_dict: Dict[str, Table]) -> None:
         r"""Creates a database from a dictionary of tables."""

@@ -22,15 +22,17 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"

     def save(self, path: Union[str, os.PathLike]) -> None:
-        r"""Saves the database to a directory. Simply saves each table
-        individually with the table name as base name of file."""
+        r"""Save the database to a directory.
+
+        Simply saves each table individually with the table name as base name of file.
+        """

         for name, table in self.table_dict.items():
             table.save(f"{path}/{name}.parquet")

     @classmethod
     def load(cls, path: Union[str, os.PathLike]) -> Self:
-        r"""Loads a database from a directory of tables in parquet files."""
+        r"""Load a database from a directory of tables in parquet files."""

         table_dict = {}
         for table_path in Path(path).glob("*.parquet"):

@@ -42,7 +44,7 @@ def load(cls, path: Union[str, os.PathLike]) -> Self:
     @property
     @lru_cache(maxsize=None)
     def min_timestamp(self) -> pd.Timestamp:
-        r"""Returns the earliest timestamp in the database."""
+        r"""Return the earliest timestamp in the database."""

         return min(
             table.min_timestamp

@@ -53,36 +55,35 @@ def min_timestamp(self) -> pd.Timestamp:
     @property
     @lru_cache(maxsize=None)
     def max_timestamp(self) -> pd.Timestamp:
-        r"""Returns the latest timestamp in the database."""
+        r"""Return the latest timestamp in the database."""

         return max(
             table.max_timestamp
             for table in self.table_dict.values()
             if table.time_col is not None
         )

-    def upto(self, time_stamp: pd.Timestamp) -> Self:
-        r"""Returns a database with all rows upto time_stamp."""
+    def upto(self, timestamp: pd.Timestamp) -> Self:
+        r"""Return a database with all rows upto timestamp."""

         return Database(
             table_dict={
-                name: table.upto(time_stamp) for name, table in self.table_dict.items()
+                name: table.upto(timestamp) for name, table in self.table_dict.items()
             }
         )

-    def from_(self, time_stamp: pd.Timestamp) -> Self:
-        r"""Returns a database with all rows from time_stamp."""
+    def from_(self, timestamp: pd.Timestamp) -> Self:
+        r"""Return a database with all rows from timestamp."""

         return Database(
             table_dict={
-                name: table.from_(time_stamp) for name, table in self.table_dict.items()
+                name: table.from_(timestamp) for name, table in self.table_dict.items()
             }
         )

     def reindex_pkeys_and_fkeys(self) -> None:
-        r"""Mapping primary and foreign keys into indices according to
-        the ordering in the primary key tables.
-        """
+        r"""Map primary and foreign keys into indices according to the ordering in the
+        primary key tables."""
         # Get pkey to idx mapping:
         index_map_dict: Dict[str, pd.Series] = {}
         for table_name, table in self.table_dict.items():
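Taken together, the renamed methods give Database a small time-slicing API. A usage sketch under the post-PR signatures, assuming Database and Table are importable from relbench.base (the toy table and paths are illustrative):

import pandas as pd
from relbench.base import Database, Table

users = Table(
    df=pd.DataFrame(
        {"user_id": [0, 1], "signup": pd.to_datetime(["2020-01-01", "2021-01-01"])}
    ),
    fkey_col_to_pkey_table={},
    pkey_col="user_id",
    time_col="signup",
)
db = Database(table_dict={"users": users})

db.save("/tmp/db")                            # one parquet file per table
db2 = Database.load("/tmp/db")
early = db2.upto(pd.Timestamp("2020-06-01"))  # rows at or before the timestamp
late = db2.from_(pd.Timestamp("2020-06-01"))  # rows at or after the timestamp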
64 changes: 52 additions & 12 deletions relbench/base/dataset.py
@@ -10,20 +10,44 @@


 class Dataset:
+    r"""A dataset is a database with validation and test timestamps defined for it.
+
+    Attributes:
+        val_timestamp: Rows upto this timestamp (inclusive) can be input for validation.
+        test_timestamp: Rows upto this timestamp (inclusive) can be input for testing.
+
+    Validation split of a task involves predicting the target variable for a
+    time period after val_timestamp (exclusive) using data upto val_timestamp.
+    Similarly for test_timestamp.
+    """
+
     # To be set by subclass.
     val_timestamp: pd.Timestamp
     test_timestamp: pd.Timestamp

     def __init__(
         self,
         cache_dir: Optional[str] = None,
     ) -> None:
+        r"""Create a dataset object.
+
+        Args:
+            cache_dir: A directory for caching the database object. If specified,
+                we will either process and cache the file (if not available) or use
+                the cached file. If None, we will not use cached file and re-process
+                everything from scratch without saving the cache.
+        """
+
         self.cache_dir = cache_dir

     def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"

     def validate_and_correct_db(self, db):
-        r"""Validate and correct input db in-place."""
+        r"""Validate and correct input db in-place.
+
+        Removing rows after test_timestamp can result in dangling foreign keys.
+        """
         # Validate that all primary keys are consecutively index.

         for table_name, table in db.table_dict.items():

@@ -46,33 +70,45 @@ def validate_and_correct_db(self, db):

     @lru_cache(maxsize=None)
     def get_db(self, upto_test_timestamp=True) -> Database:
+        r"""Return the database object.
+
+        The returned database object is cached in memory.
+
+        Args:
+            upto_test_timestamp: If True, only return rows upto test_timestamp.
+
+        Returns:
+            Database: The database object.
+
+        `upto_test_timestamp` is True by default to prevent test leakage.
+        """
+
         db_path = f"{self.cache_dir}/db"
         if self.cache_dir and Path(db_path).exists() and any(Path(db_path).iterdir()):
-            print(f"loading Database object from {db_path}...")
+            print(f"Loading Database object from {db_path}...")
             tic = time.time()
             db = Database.load(db_path)
             toc = time.time()
-            print(f"done in {toc - tic:.2f} seconds.")
+            print(f"Done in {toc - tic:.2f} seconds.")

         else:
-            print("making Database object from raw files...")
+            print("Making Database object from scratch...")
+            print(
+                "(You can also use `get_dataset(..., download=True)` "
+                "for datasets prepared by the RelBench team.)"
+            )
             tic = time.time()
             db = self.make_db()
             toc = time.time()
-            print(f"done in {toc - tic:.2f} seconds.")
-
-            print("reindexing pkeys and fkeys...")
-            tic = time.time()
-            db.reindex_pkeys_and_fkeys()
-            toc = time.time()
-            print(f"done in {toc - tic:.2f} seconds.")
+            print(f"Done in {toc - tic:.2f} seconds.")

             if self.cache_dir:
-                print(f"caching Database object to {db_path}...")
+                print(f"Caching Database object to {db_path}...")
                 tic = time.time()
                 db.save(db_path)
                 toc = time.time()
-                print(f"done in {toc - tic:.2f} seconds.")
+                print(f"Done in {toc - tic:.2f} seconds.")

         if upto_test_timestamp:
             db = db.upto(self.test_timestamp)

@@ -82,4 +118,8 @@ def get_db(self, upto_test_timestamp=True) -> Database:
         return db

     def make_db(self) -> Database:
+        r"""Make the database object from scratch, i.e. using raw data sources.
+
+        To be implemented by subclass.
+        """
         raise NotImplementedError
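The new docstrings spell out the subclass contract: set the two timestamps and implement make_db; get_db layers caching and test-leakage truncation on top. A minimal subclass sketch with placeholder data, assuming the same relbench.base import path as above:

import pandas as pd
from relbench.base import Database, Dataset, Table

class ToyDataset(Dataset):
    val_timestamp = pd.Timestamp("2021-01-01")
    test_timestamp = pd.Timestamp("2022-01-01")

    def make_db(self) -> Database:
        # Build the full database from raw sources; get_db() adds caching
        # and truncates to test_timestamp unless told otherwise.
        df = pd.DataFrame(
            {"id": [0, 1], "t": pd.to_datetime(["2020-05-01", "2021-06-01"])}
        )
        events = Table(df=df, fkey_col_to_pkey_table={}, pkey_col="id", time_col="t")
        return Database(table_dict={"events": events})

dataset = ToyDataset(cache_dir=None)  # None: rebuild from scratch, write no cache
db = dataset.get_db()                 # rows upto test_timestamp only, by default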
41 changes: 24 additions & 17 deletions relbench/base/table.py
@@ -14,13 +14,12 @@ class Table:
     r"""A table in a database.

     Args:
-        df (pandas.DataFrame): The underlying data frame of the table.
-        fkey_col_to_pkey_table (Dict[str, str]): A dictionary mapping
+        df: The underlying data frame of the table.
+        fkey_col_to_pkey_table: A dictionary mapping
             foreign key names to table names that contain the foreign keys as
             primary keys.
-        pkey_col (str, optional): The primary key column if it exists.
-            (default: :obj:`None`)
-        time_col (str, optional): The time column. (default: :obj:`None`)
+        pkey_col: The primary key column if it exists.
+        time_col: The time column.
     """

     def __init__(

@@ -45,12 +44,14 @@ def __repr__(self) -> str:
         )

     def __len__(self) -> int:
-        r"""Returns the number of rows in the table."""
+        r"""Return the number of rows in the table."""
         return len(self.df)

     def save(self, path: Union[str, os.PathLike]) -> None:
-        r"""Saves the table to a parquet file. Stores other attributes as
-        parquet metadata."""
+        r"""Save the table to a parquet file.
+
+        Stores other attributes as parquet metadata.
+        """
         assert str(path).endswith(".parquet")
         metadata = {
             "fkey_col_to_pkey_table": self.fkey_col_to_pkey_table,

@@ -76,7 +77,7 @@ def save(self, path: Union[str, os.PathLike]) -> None:

     @classmethod
     def load(cls, path: Union[str, os.PathLike]) -> Self:
-        r"""Loads a table from a parquet file."""
+        r"""Load a table from a parquet file."""
         assert str(path).endswith(".parquet")

         # Read the Parquet file using pyarrow

@@ -97,27 +98,33 @@ def load(cls, path: Union[str, os.PathLike]) -> Self:
             time_col=metadata["time_col"],
         )

-    def upto(self, time_stamp: pd.Timestamp) -> Self:
-        r"""Returns a table with all rows upto time."""
+    def upto(self, timestamp: pd.Timestamp) -> Self:
+        r"""Return a table with all rows upto timestamp (inclusive).
+
+        Table without time_col are returned as is.
+        """

         if self.time_col is None:
             return self

         return Table(
-            df=self.df.query(f"{self.time_col} <= @time_stamp"),
+            df=self.df.query(f"{self.time_col} <= @timestamp"),
             fkey_col_to_pkey_table=self.fkey_col_to_pkey_table,
             pkey_col=self.pkey_col,
             time_col=self.time_col,
         )

-    def from_(self, time_stamp: pd.Timestamp) -> Self:
-        r"""Returns a table with all rows from time."""
+    def from_(self, timestamp: pd.Timestamp) -> Self:
+        r"""Return a table with all rows from timestamp onwards (inclusive).
+
+        Table without time_col are returned as is.
+        """

         if self.time_col is None:
             return self

         return Table(
-            df=self.df.query(f"{self.time_col} >= @time_stamp"),
+            df=self.df.query(f"{self.time_col} >= @timestamp"),
             fkey_col_to_pkey_table=self.fkey_col_to_pkey_table,
             pkey_col=self.pkey_col,
             time_col=self.time_col,

@@ -126,7 +133,7 @@ def from_(self, time_stamp: pd.Timestamp) -> Self:
     @property
     @lru_cache(maxsize=None)
     def min_timestamp(self) -> pd.Timestamp:
-        r"""Returns the earliest time in the table."""
+        r"""Return the earliest time in the table."""

         if self.time_col is None:
             raise ValueError("Table has no time column.")

@@ -136,7 +143,7 @@ def min_timestamp(self) -> pd.Timestamp:
     @property
     @lru_cache(maxsize=None)
     def max_timestamp(self) -> pd.Timestamp:
-        r"""Returns the latest time in the table."""
+        r"""Return the latest time in the table."""

         if self.time_col is None:
             raise ValueError("Table has no time column.")
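A quick round-trip showing that save and load preserve the key and time metadata, and that the clarified upto semantics are inclusive (path and values are illustrative; import path assumed as above):

import pandas as pd
from relbench.base import Table

table = Table(
    df=pd.DataFrame(
        {"id": [0, 1, 2], "t": pd.to_datetime(["2020-01-01", "2020-06-01", "2021-01-01"])}
    ),
    fkey_col_to_pkey_table={},
    pkey_col="id",
    time_col="t",
)
table.save("/tmp/events.parquet")  # pkey/time cols stored as parquet metadata
loaded = Table.load("/tmp/events.parquet")
assert loaded.pkey_col == "id" and loaded.time_col == "t"
assert len(loaded.upto(pd.Timestamp("2020-06-01"))) == 2  # boundary row included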