From b8870857d2fe3f9d24c731b28e666e62ec0ba5b6 Mon Sep 17 00:00:00 2001
From: Rishabh Ranjan
Date: Sat, 6 Jul 2024 11:40:14 -0700
Subject: [PATCH 1/9] add docformatter pre-commit

---
 .pre-commit-config.yaml | 7 +++++++
 examples/baseline_link.py | 4 ++--
 examples/lightgbm_link.py | 14 +++++++-------
 relbench/base/database.py | 15 ++++++++-------
 relbench/base/table.py | 6 ++++--
 relbench/base/task_link.py | 8 +++-----
 relbench/modeling/graph.py | 16 ++++++++--------
 relbench/modeling/loader.py | 17 ++++++++++-------
 relbench/modeling/utils.py | 3 +--
 relbench/tasks/amazon.py | 28 ++++++++++++++--------------
 relbench/tasks/avito.py | 16 ++++++----------
 relbench/tasks/event.py | 12 +++++-------
 relbench/tasks/f1.py | 10 ++++------
 relbench/tasks/hm.py | 8 ++++----
 relbench/tasks/stack.py | 12 ++++++------
 relbench/tasks/trial.py | 3 ++-
 16 files changed, 91 insertions(+), 88 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 45587576..286d15ae 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -26,3 +26,10 @@ repos:
     hooks:
       - id: isort
         args: ["--profile", "black"]
+
+  - repo: https://github.com/PyCQA/docformatter
+    rev: v1.7.5
+    hooks:
+      - id: docformatter
+        additional_dependencies: [tomli]
+        args: ["--in-place", "--black"]
diff --git a/examples/baseline_link.py b/examples/baseline_link.py
index 9e1899ad..d9057856 100644
--- a/examples/baseline_link.py
+++ b/examples/baseline_link.py
@@ -69,8 +69,8 @@ def evaluate(
         )
         pred = np.stack(pred_ser.values)
     elif name == "global_popularity":
-        """Predict the globally most visited dst nodes and predict them across
-        the src nodes."""
+        """Predict the globally most visited dst nodes and predict them across the src
+        nodes."""
         lst_cat = []
         for lst in train_table.df[task.dst_entity_col]:
             lst_cat.extend(lst)
diff --git a/examples/lightgbm_link.py b/examples/lightgbm_link.py
index 077dbb1d..1d1aa34f 100644
--- a/examples/lightgbm_link.py
+++ b/examples/lightgbm_link.py
@@ -130,8 +130,8 @@ def add_past_label_feature(
     train_table_df: pd.DataFrame,
     past_table_df: pd.DataFrame,
 ) -> pd.DataFrame:
-    """Add past visit count and percentage of global popularity to train table
-    df used for lightGBM training, evaluation of testing.
+    """Add past visit count and percentage of global popularity to train table df used
+    for lightGBM training, evaluation, or testing.
 
     Args:
         evaluate_table_df (pd.DataFrame): The dataframe used for evaluation.
@@ -244,8 +244,8 @@ def add_past_label_feature
 def prepare_for_link_pred_eval(
     evaluate_table_df: pd.DataFrame, past_table_df: pd.DataFrame
 ) -> pd.DataFrame:
-    """Transform evaluation dataframe into the correct format for link
-    prediction metric calculation.
+    """Transform evaluation dataframe into the correct format for link prediction metric
+    calculation.
 
     Args:
         pred_table_df (pd.DataFrame): The prediction dataframe.
@@ -375,9 +375,9 @@ def evaluate(
     train_table: Table,
     task: LinkTask,
 ) -> Dict[str, float]:
-    """Given the input dataframe used for lightGBM binary link classification
-    and its output prediction scores and true labels, generate link prediction
-    evaluation metrics.
+    """Given the input dataframe used for lightGBM binary link classification and its
+    output prediction scores and true labels, generate link prediction evaluation
+    metrics.
 
     Args:
         lightgbm_output (pd.DataFrame): The lightGBM input dataframe merged
diff --git a/relbench/base/database.py b/relbench/base/database.py
index e0210b67..d646994b 100644
--- a/relbench/base/database.py
+++ b/relbench/base/database.py
@@ -10,8 +10,8 @@
 
 
 class Database:
-    r"""A database is a collection of named tables linked by foreign key -
-    primary key connections."""
+    r"""A database is a collection of named tables linked by foreign key - primary key
+    connections."""
 
     def __init__(self, table_dict: Dict[str, Table]) -> None:
         r"""Creates a database from a dictionary of tables."""
@@ -22,8 +22,10 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"
 
     def save(self, path: Union[str, os.PathLike]) -> None:
-        r"""Saves the database to a directory. Simply saves each table
-        individually with the table name as base name of file."""
+        r"""Saves the database to a directory.
+
+        Simply saves each table individually with the table name as base name of file.
+        """
         for name, table in self.table_dict.items():
             table.save(f"{path}/{name}.parquet")
 
@@ -80,9 +82,8 @@ def from_(self, time_stamp: pd.Timestamp) -> Self:
         )
 
     def reindex_pkeys_and_fkeys(self) -> None:
-        r"""Mapping primary and foreign keys into indices according to
-        the ordering in the primary key tables.
-        """
+        r"""Mapping primary and foreign keys into indices according to the ordering in
+        the primary key tables."""
         # Get pkey to idx mapping:
         index_map_dict: Dict[str, pd.Series] = {}
         for table_name, table in self.table_dict.items():
diff --git a/relbench/base/table.py b/relbench/base/table.py
index 0564a9e0..9c899aec 100644
--- a/relbench/base/table.py
+++ b/relbench/base/table.py
@@ -49,8 +49,10 @@ def __len__(self) -> int:
         return len(self.df)
 
     def save(self, path: Union[str, os.PathLike]) -> None:
-        r"""Saves the table to a parquet file. Stores other attributes as
-        parquet metadata."""
+        r"""Saves the table to a parquet file.
+
+        Stores other attributes as parquet metadata.
+        """
         assert str(path).endswith(".parquet")
         metadata = {
             "fkey_col_to_pkey_table": self.fkey_col_to_pkey_table,
diff --git a/relbench/base/task_link.py b/relbench/base/task_link.py
index e9d3337b..d6a5ca2a 100644
--- a/relbench/base/task_link.py
+++ b/relbench/base/task_link.py
@@ -96,11 +96,9 @@ def num_dst_nodes(self) -> int:
         return len(self.dataset.get_db().table_dict[self.dst_entity_table])
 
     def stats(self) -> Dict[str, Dict[str, int]]:
-        r"""Get train / val / test table statistics for each timestamp
-        and the whole table, including number of unique source entities,
-        number of unique destination entities, number of destination
-        entities and number of rows.
-        """
+        r"""Get train / val / test table statistics for each timestamp and the whole
+        table, including number of unique source entities, number of unique destination
+        entities, number of destination entities and number of rows."""
         res = {}
         for split in ["train", "val", "test"]:
             split_stats = {}
diff --git a/relbench/modeling/graph.py b/relbench/modeling/graph.py
index 874d9a11..a513a225 100644
--- a/relbench/modeling/graph.py
+++ b/relbench/modeling/graph.py
@@ -23,9 +23,8 @@ def make_pkey_fkey_graph(
     text_embedder_cfg: Optional[TextEmbedderConfig] = None,
    cache_dir: Optional[str] = None,
 ) -> Tuple[HeteroData, Dict[str, Dict[str, Dict[StatType, Any]]]]:
-    r"""Given a :class:`Database` object, construct a heterogeneous graph with
-    primary-foreign key relationships, together with the column stats of each
-    table.
+    r"""Given a :class:`Database` object, construct a heterogeneous graph with
+    primary-foreign key relationships, together with the column stats of each table.
 
     Args:
         db (Database): A database object containing a set of tables.
@@ -114,11 +113,12 @@ def make_pkey_fkey_graph(
 
 class AttachTargetTransform:
     r"""Adds the target label to the heterogeneous mini-batch.
-    The batch consists of disjoins subgraphs loaded via temporal sampling.
-    The same input node can occur twice with different timestamps, and thus
-    different subgraphs and labels. Hence labels cannot be stored in the graph
-    object directly, and must be attached to the batch after the batch is
-    created."""
+
+    The batch consists of disjoint subgraphs loaded via temporal sampling. The same
+    input node can occur twice with different timestamps, and thus different subgraphs
+    and labels. Hence labels cannot be stored in the graph object directly, and must be
+    attached to the batch after the batch is created.
+    """
 
     def __init__(self, entity: str, target: Tensor):
         self.entity = entity
diff --git a/relbench/modeling/loader.py b/relbench/modeling/loader.py
index d75ee9d7..941d5722 100644
--- a/relbench/modeling/loader.py
+++ b/relbench/modeling/loader.py
@@ -12,9 +12,8 @@
 
 
 def batched_arange(count: Tensor) -> Tuple[Tensor, Tensor]:
-    r"""Fast implementation of bached version of torch.arange.
-    It essentially does the following
-    >>> batch = torch.cat([torch.full((c,), i) for i, c in enumerate(count)])
+    r"""Fast implementation of batched version of torch.arange. It essentially does the
+    following >>> batch = torch.cat([torch.full((c,), i) for i, c in enumerate(count)])
     >>> arange = torch.cat([torch.arange(c) for c in count])
 
     Args:
@@ -52,7 +51,8 @@ def __init__(
         self._col_indices = sparse_tensor.col_indices().to(device)
 
     def __getitem__(self, indices: Tensor) -> Tuple[Tensor, Tensor]:
-        r"""Given a tensor of row indices, return a tuple of tensors
+        r"""Given a tensor of row indices, return a tuple of tensors.
+
         - :obj:`row_batch` (Tensor): Batch offset for column indices.
         - :obj:`col_index` (Tensor): Column indices.
           Specifically, :obj:`sparse_tensor[indices[i]]` can be obtained by
@@ -124,8 +124,10 @@ def __len__(self) -> int:
 
 
 class CustomLinkDataset(Dataset):
-    r"""A custom link prediction dataset. Sample source nodes, time, and one
-    positive destination node."""
+    r"""A custom link prediction dataset.
+
+    Sample source nodes, time, and one positive destination node.
+    """
 
     def __init__(
         self,
@@ -143,7 +145,8 @@ def __init__(
         self.src_time = src_time
 
     def __getitem__(self, index) -> Tensor:
-        r"""Returns 1-dim tensor of size 3
+        r"""Returns 1-dim tensor of size 3.
+
         - source node index
         - positive destination node index
         - source node time
diff --git a/relbench/modeling/utils.py b/relbench/modeling/utils.py
index 441cf9f7..6bfb3cd5 100644
--- a/relbench/modeling/utils.py
+++ b/relbench/modeling/utils.py
@@ -9,8 +9,7 @@
 
 
 def to_unix_time(ser: pd.Series) -> np.ndarray:
-    r"""Converts a :class:`pandas.Timestamp` series to UNIX timestamp
-    (in seconds)."""
+    r"""Converts a :class:`pandas.Timestamp` series to UNIX timestamp (in seconds)."""
     assert ser.dtype in [np.dtype("datetime64[s]"), np.dtype("datetime64[ns]")]
     unix_time = ser.astype(int).values
     if ser.dtype == np.dtype("datetime64[ns]"):
diff --git a/relbench/tasks/amazon.py b/relbench/tasks/amazon.py
index 35f9bfb0..ef2f0781 100644
--- a/relbench/tasks/amazon.py
+++ b/relbench/tasks/amazon.py
@@ -17,8 +17,8 @@
 
 
 class UserChurnTask(NodeTask):
-    r"""Churn for a customer is 1 if the customer does not review any product
-    in the time window, else 0."""
+    r"""Churn for a customer is 1 if the customer does not review any product in the
+    time window, else 0."""
 
     task_type = TaskType.BINARY_CLASSIFICATION
     entity_col = "customer_id"
@@ -73,8 +73,8 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab
 
 
 class UserLTVTask(NodeTask):
-    r"""LTV (life-time value) for a customer is the sum of prices of products
-    that the customer reviews in the time window."""
+    r"""LTV (life-time value) for a customer is the sum of prices of products that the
+    customer reviews in the time window."""
 
     task_type = TaskType.REGRESSION
     entity_col = "customer_id"
@@ -132,8 +132,8 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab
 
 
 class ItemChurnTask(NodeTask):
-    r"""Churn for a product is 1 if the product recieves at least one review
-    in the time window, else 0."""
+    r"""Churn for a product is 1 if the product receives at least one review in the time
+    window, else 0."""
 
     task_type = TaskType.BINARY_CLASSIFICATION
     entity_col = "product_id"
@@ -188,8 +188,8 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab
 
 
 class ItemLTVTask(NodeTask):
-    r"""LTV (life-time value) for a product is the numer of times the product
-    is purchased in the time window multiplied by price."""
+    r"""LTV (life-time value) for a product is the number of times the product is
+    purchased in the time window multiplied by price."""
 
     task_type = TaskType.REGRESSION
     entity_col = "product_id"
@@ -234,8 +234,8 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab
 
 
 class UserItemPurchaseTask(LinkTask):
-    r"""Predict the list of distinct items each customer will purchase in the
-    next two years."""
+    r"""Predict the list of distinct items each customer will purchase in the next two
+    years."""
 
     task_type = TaskType.LINK_PREDICTION
     src_entity_col = "customer_id"
@@ -285,8 +285,8 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab
 
 
 class UserItemRateTask(LinkTask):
-    r"""Predict the list of distinct items each customer will purchase and give a 5 star review in the
-    next two years."""
+    r"""Predict the list of distinct items each customer will purchase and give a 5 star
+    review in the next two years."""
 
     task_type = TaskType.LINK_PREDICTION
     src_entity_col = "customer_id"
@@ -338,8 +338,8 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab
 
 
 class UserItemReviewTask(LinkTask):
-    r"""Predict the list of distinct items each customer will purchase and give a detailed review in the
-    next two years."""
+    r"""Predict the list of distinct items each customer will purchase and give a
+    detailed review in the next two years."""
 
     task_type = TaskType.LINK_PREDICTION
     src_entity_col = "customer_id"
diff --git a/relbench/tasks/avito.py b/relbench/tasks/avito.py
index f486cfdc..7970ee7b 100644
--- a/relbench/tasks/avito.py
+++ b/relbench/tasks/avito.py
@@ -17,9 +17,8 @@
 
 
 class AdCTRTask(NodeTask):
-    r"""Assuming the ad will be clicked in the next 4 days, predict the
-    Click-Through-Rate (CTR) for each ad.
-    """
+    r"""Assuming the ad will be clicked in the next 4 days, predict the
+    Click-Through-Rate (CTR) for each ad."""
 
     task_type = TaskType.REGRESSION
     entity_table = "AdsInfo"
@@ -68,9 +67,7 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab
 
 
 class UserVisitsTask(NodeTask):
-    r"""Predict whether each customer will visit more than one ad in the next
-    4 days.
-    """
+    r"""Predict whether each customer will visit more than one ad in the next 4 days."""
 
     task_type = TaskType.BINARY_CLASSIFICATION
     entity_table = "UserInfo"
@@ -118,9 +115,8 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab
 
 
 class UserClicksTask(NodeTask):
-    r"""Predict whether the each customer will click on more than one ads in
-    the next 4 days
-    """
+    r"""Predict whether each customer will click on more than one ad in the next 4
+    days."""
 
     task_type = TaskType.BINARY_CLASSIFICATION
     entity_table = "UserInfo"
@@ -176,7 +172,7 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab
 
 
 class UserAdVisitTask(LinkTask):
-    r"""Predict the distinct list of ads a user will visit in the next 4 days"""
+    r"""Predict the distinct list of ads a user will visit in the next 4 days."""
 
     task_type = TaskType.LINK_PREDICTION
     src_entity_table = "UserInfo"
diff --git a/relbench/tasks/event.py b/relbench/tasks/event.py
index 8bd0245d..42be63a5 100644
--- a/relbench/tasks/event.py
+++ b/relbench/tasks/event.py
@@ -6,8 +6,7 @@
 
 
 class UserAttendanceTask(NodeTask):
-    r"""Predict the number of events a user will go to in the next seven days
-    7 days."""
+    r"""Predict the number of events a user will go to in the next 7 days."""
 
     task_type = TaskType.REGRESSION
     entity_col = "user"
@@ -58,9 +57,8 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab
 
 
 class UserRepeatTask(NodeTask):
-    r"""Predict whether a user will attend an event in the
-    next 7 days if they have already attended an event in the
-    last 14 days."""
+    r"""Predict whether a user will attend an event in the next 7 days if they have
+    already attended an event in the last 14 days."""
 
     task_type = TaskType.BINARY_CLASSIFICATION
     entity_col = "user"
@@ -137,8 +135,8 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab
 
 
 class UserIgnoreTask(NodeTask):
-    r"""Predict whether a user will ignore more than 2 event invitations
-    in the next 7 days."""
+    r"""Predict whether a user will ignore more than 2 event invitations in the next 7
+    days."""
 
     task_type = TaskType.BINARY_CLASSIFICATION
     entity_col = "user"
diff --git a/relbench/tasks/f1.py b/relbench/tasks/f1.py
index d11dd6cb..ec2e4ace 100644
--- a/relbench/tasks/f1.py
+++ b/relbench/tasks/f1.py
@@ -6,9 +6,8 @@
 
 
 class DriverPositionTask(NodeTask):
-    r"""Predict the average finishing position of each driver
-    all races in the next 2 months.
-    """
+    r"""Predict the average finishing position of each driver across all races in the
+    next 2 months."""
 
     task_type = TaskType.REGRESSION
     entity_col = "driverId"
@@ -123,9 +122,8 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab
 
 
 class DriverTop3Task(NodeTask):
-    r"""Predict if each driver will qualify in the top-3 for
-    a race within the next 1 month.
-    """
+    r"""Predict if each driver will qualify in the top-3 for a race within the next 1
+    month."""
 
     task_type = TaskType.BINARY_CLASSIFICATION
     entity_col = "driverId"
diff --git a/relbench/tasks/hm.py b/relbench/tasks/hm.py
index 25743f39..7e93b596 100644
--- a/relbench/tasks/hm.py
+++ b/relbench/tasks/hm.py
@@ -17,8 +17,8 @@
 
 
 class UserItemPurchaseTask(LinkTask):
-    r"""Predict the list of articles each customer will purchase in the next
-    seven days"""
+    r"""Predict the list of articles each customer will purchase in the next seven
+    days."""
 
     task_type = TaskType.LINK_PREDICTION
     src_entity_col = "customer_id"
@@ -120,8 +120,8 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab
 
 
 class ItemSalesTask(NodeTask):
-    r"""Predict the total sales for an article (the sum of prices of the
-    associated transactions) in the next week."""
+    r"""Predict the total sales for an article (the sum of prices of the associated
+    transactions) in the next week."""
 
     task_type = TaskType.REGRESSION
     entity_col = "article_id"
diff --git a/relbench/tasks/stack.py b/relbench/tasks/stack.py
index a203eaa2..f803b7e5 100644
--- a/relbench/tasks/stack.py
+++ b/relbench/tasks/stack.py
@@ -105,8 +105,8 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab
 
 
 class PostVotesTask(NodeTask):
-    r"""Predict the number of upvotes that an existing question will receive in
-    the next 2 years."""
+    r"""Predict the number of upvotes that an existing question will receive in the next
+    2 years."""
 
     task_type = TaskType.REGRESSION
     entity_col = "PostId"
@@ -218,8 +218,8 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab
 
 
 class UserPostCommentTask(LinkTask):
-    r"""Predict a list of existing posts that a user will comment in the next
-    two years."""
+    r"""Predict a list of existing posts that a user will comment in the next two
+    years."""
 
     task_type = TaskType.LINK_PREDICTION
     src_entity_col = "UserId"
@@ -279,8 +279,8 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab
 
 
 class PostPostRelatedTask(LinkTask):
-    r"""Predict a list of existing posts that users will link a given post to in the next
-    two years."""
+    r"""Predict a list of existing posts that users will link a given post to in the
+    next two years."""
 
     task_type = TaskType.LINK_PREDICTION
     src_entity_col = "PostId"
diff --git a/relbench/tasks/trial.py b/relbench/tasks/trial.py
index 1e69c3cb..dd876eb4 100644
--- a/relbench/tasks/trial.py
+++ b/relbench/tasks/trial.py
@@ -78,7 +78,8 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab
 
 
 class StudyAdverseTask(NodeTask):
-    r"""Predict the number of affected patients with severe advsere events/death for the trial in the next 1 year."""
+    r"""Predict the number of affected patients with severe adverse events/death for the
+    trial in the next 1 year."""
 
     task_type = TaskType.REGRESSION
     entity_col = "nct_id"
From cad2d757ad7b693542af831618d9b0bd64f0d61f Mon Sep 17 00:00:00 2001
From: Rishabh Ranjan
Date: Sat, 6 Jul 2024 12:08:45 -0700
Subject: [PATCH 2/9] update

---
 relbench/base/database.py | 24 +++++++-------
 relbench/base/dataset.py | 65 ++++++++++++++++++++++++++++++--------
 relbench/base/table.py | 37 ++++++++++++----------
 relbench/modeling/graph.py | 28 +++++++++++-----
 4 files changed, 105 insertions(+), 49 deletions(-)

diff --git a/relbench/base/database.py b/relbench/base/database.py
index d646994b..4bc1192b 100644
--- a/relbench/base/database.py
+++ b/relbench/base/database.py
@@ -22,7 +22,7 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"
 
     def save(self, path: Union[str, os.PathLike]) -> None:
-        r"""Saves the database to a directory.
+        r"""Save the database to a directory.
 
         Simply saves each table individually with the table name as base name of file.
         """
@@ -32,7 +32,7 @@ def save(self, path: Union[str, os.PathLike]) -> None:
 
     @classmethod
     def load(cls, path: Union[str, os.PathLike]) -> Self:
-        r"""Loads a database from a directory of tables in parquet files."""
+        r"""Load a database from a directory of tables in parquet files."""
 
         table_dict = {}
         for table_path in Path(path).glob("*.parquet"):
@@ -44,7 +44,7 @@ def load(cls, path: Union[str, os.PathLike]) -> Self:
     @property
     @lru_cache(maxsize=None)
     def min_timestamp(self) -> pd.Timestamp:
-        r"""Returns the earliest timestamp in the database."""
+        r"""Return the earliest timestamp in the database."""
 
         return min(
             table.min_timestamp
@@ -55,7 +55,7 @@ def min_timestamp(self) -> pd.Timestamp:
     @property
     @lru_cache(maxsize=None)
     def max_timestamp(self) -> pd.Timestamp:
-        r"""Returns the latest timestamp in the database."""
+        r"""Return the latest timestamp in the database."""
 
         return max(
             table.max_timestamp
@@ -63,27 +63,27 @@ def max_timestamp(self) -> pd.Timestamp:
             if table.time_col is not None
         )
 
-    def upto(self, time_stamp: pd.Timestamp) -> Self:
-        r"""Returns a database with all rows upto time_stamp."""
+    def upto(self, timestamp: pd.Timestamp) -> Self:
+        r"""Return a database with all rows upto timestamp."""
 
         return Database(
             table_dict={
-                name: table.upto(time_stamp) for name, table in self.table_dict.items()
+                name: table.upto(timestamp) for name, table in self.table_dict.items()
             }
         )
 
-    def from_(self, time_stamp: pd.Timestamp) -> Self:
-        r"""Returns a database with all rows from time_stamp."""
+    def from_(self, timestamp: pd.Timestamp) -> Self:
+        r"""Return a database with all rows from timestamp."""
 
         return Database(
             table_dict={
-                name: table.from_(time_stamp) for name, table in self.table_dict.items()
+                name: table.from_(timestamp) for name, table in self.table_dict.items()
             }
         )
 
     def reindex_pkeys_and_fkeys(self) -> None:
-        r"""Mapping primary and foreign keys into indices according to the ordering in
-        the primary key tables."""
+        r"""Map primary and foreign keys into indices according to the ordering in the
+        primary key tables."""
         # Get pkey to idx mapping:
         index_map_dict: Dict[str, pd.Series] = {}
         for table_name, table in self.table_dict.items():
diff --git a/relbench/base/dataset.py b/relbench/base/dataset.py
index 1ffb4e7c..5ae6d6e7 100644
--- a/relbench/base/dataset.py
+++ b/relbench/base/dataset.py
@@ -10,6 +10,17 @@
 
 
 class Dataset:
+    r"""A dataset is a database with validation and test timestamps defined for it.
+
+    val_timestamp: Rows upto this timestamp (inclusive) can be input for validation.
+    test_timestamp: Rows upto this timestamp (inclusive) can be input for testing.
+
+    Validation split of a task involves predicting the target variable for a
+    time period after val_timestamp (exclusive) using data upto val_timestamp.
+    Similarly for test_timestamp.
+    """
+
+    # To be set by subclass.
     val_timestamp: pd.Timestamp
     test_timestamp: pd.Timestamp
 
@@ -17,13 +28,31 @@ def __init__(
         self,
         cache_dir: Optional[str] = None,
     ) -> None:
+        r"""Create a dataset object.
+
+        Args:
+            cache_dir: A directory for caching the database object. If specified,
+                we will either process and cache the file (if not available) or use the cached file. If None,
+                we will not use cached file and re-process everything from scratch
+                without saving the cache.
+        """
+
         self.cache_dir = cache_dir
 
     def __repr__(self) -> str:
-        return f"{self.__class__.__name__}()"
+        return (
+            f"{self.__class__.__name__}(\n"
+            f"val_timestamp={self.val_timestamp},\n"
+            f"test_timestamp={self.test_timestamp},\n"
+            f"cache_dir={self.cache_dir},\n"
+            f")"
+        )
 
     def validate_and_correct_db(self, db):
-        r"""Validate and correct input db in-place."""
+        r"""Validate and correct input db in-place.
+
+        Removing rows after test_timestamp can result in dangling foreign keys.
+        """
 
         # Validate that all primary keys are consecutively index.
         for table_name, table in db.table_dict.items():
@@ -46,33 +75,39 @@ def validate_and_correct_db(self, db):
 
     @lru_cache(maxsize=None)
     def get_db(self, upto_test_timestamp=True) -> Database:
+        r"""Return the database object.
+
+        The returned database object is cached in memory.
+
+        Args:
+            upto_test_timestamp: If True, only return rows upto test_timestamp.
+
+        Returns:
+            Database: The database object.
+        """
+
         db_path = f"{self.cache_dir}/db"
         if self.cache_dir and Path(db_path).exists() and any(Path(db_path).iterdir()):
-            print(f"loading Database object from {db_path}...")
+            print(f"Loading Database object from {db_path}...")
             tic = time.time()
             db = Database.load(db_path)
             toc = time.time()
-            print(f"done in {toc - tic:.2f} seconds.")
+            print(f"Done in {toc - tic:.2f} seconds.")
 
         else:
-            print("making Database object from raw files...")
+            print("Making Database object from scratch...")
             tic = time.time()
             db = self.make_db()
-            toc = time.time()
-            print(f"done in {toc - tic:.2f} seconds.")
-
-            print("reindexing pkeys and fkeys...")
-            tic = time.time()
             db.reindex_pkeys_and_fkeys()
             toc = time.time()
-            print(f"done in {toc - tic:.2f} seconds.")
+            print(f"Done in {toc - tic:.2f} seconds.")
 
             if self.cache_dir:
-                print(f"caching Database object to {db_path}...")
+                print(f"Caching Database object to {db_path}...")
                 tic = time.time()
                 db.save(db_path)
                 toc = time.time()
-                print(f"done in {toc - tic:.2f} seconds.")
+                print(f"Done in {toc - tic:.2f} seconds.")
 
         if upto_test_timestamp:
             db = db.upto(self.test_timestamp)
@@ -82,4 +117,8 @@ def get_db(self, upto_test_timestamp=True) -> Database:
         return db
 
     def make_db(self) -> Database:
+        r"""Make the database object from scratch, i.e. using raw data sources.
+
+        To be implemented by subclass.
+        """
         raise NotImplementedError
diff --git a/relbench/base/table.py b/relbench/base/table.py
index 9c899aec..f8e78d8d 100644
--- a/relbench/base/table.py
+++ b/relbench/base/table.py
@@ -14,13 +14,12 @@ class Table:
     r"""A table in a database.
 
     Args:
-        df (pandas.DataFrame): The underlying data frame of the table.
-        fkey_col_to_pkey_table (Dict[str, str]): A dictionary mapping
+        df: The underlying data frame of the table.
+        fkey_col_to_pkey_table: A dictionary mapping
             foreign key names to table names that contain the foreign keys as
            primary keys.
-        pkey_col (str, optional): The primary key column if it exists.
-            (default: :obj:`None`)
-        time_col (str, optional): The time column. (default: :obj:`None`)
+        pkey_col: The primary key column if it exists.
+        time_col: The time column.
     """
 
     def __init__(
@@ -45,11 +44,11 @@ def __repr__(self) -> str:
         )
 
     def __len__(self) -> int:
-        r"""Returns the number of rows in the table."""
+        r"""Return the number of rows in the table."""
         return len(self.df)
 
     def save(self, path: Union[str, os.PathLike]) -> None:
-        r"""Saves the table to a parquet file.
+        r"""Save the table to a parquet file.
 
         Stores other attributes as parquet metadata.
         """
@@ -78,7 +77,7 @@ def save(self, path: Union[str, os.PathLike]) -> None:
 
     @classmethod
     def load(cls, path: Union[str, os.PathLike]) -> Self:
-        r"""Loads a table from a parquet file."""
+        r"""Load a table from a parquet file."""
         assert str(path).endswith(".parquet")
 
         # Read the Parquet file using pyarrow
@@ -99,27 +98,33 @@ def load(cls, path: Union[str, os.PathLike]) -> Self:
             time_col=metadata["time_col"],
         )
 
-    def upto(self, time_stamp: pd.Timestamp) -> Self:
-        r"""Returns a table with all rows upto time."""
+    def upto(self, timestamp: pd.Timestamp) -> Self:
+        r"""Return a table with all rows upto timestamp (inclusive).
+
+        Tables without time_col are returned as is.
+        """
 
         if self.time_col is None:
             return self
 
         return Table(
-            df=self.df.query(f"{self.time_col} <= @time_stamp"),
+            df=self.df.query(f"{self.time_col} <= @timestamp"),
             fkey_col_to_pkey_table=self.fkey_col_to_pkey_table,
             pkey_col=self.pkey_col,
             time_col=self.time_col,
         )
 
-    def from_(self, time_stamp: pd.Timestamp) -> Self:
-        r"""Returns a table with all rows from time."""
+    def from_(self, timestamp: pd.Timestamp) -> Self:
+        r"""Return a table with all rows from timestamp onwards (inclusive).
+
+        Tables without time_col are returned as is.
+        """
 
         if self.time_col is None:
             return self
 
         return Table(
-            df=self.df.query(f"{self.time_col} >= @time_stamp"),
+            df=self.df.query(f"{self.time_col} >= @timestamp"),
             fkey_col_to_pkey_table=self.fkey_col_to_pkey_table,
             pkey_col=self.pkey_col,
             time_col=self.time_col,
         )
@@ -128,7 +133,7 @@ def from_(self, timestamp: pd.Timestamp) -> Self:
     @property
     @lru_cache(maxsize=None)
     def min_timestamp(self) -> pd.Timestamp:
-        r"""Returns the earliest time in the table."""
+        r"""Return the earliest time in the table."""
 
         if self.time_col is None:
             raise ValueError("Table has no time column.")
@@ -138,7 +143,7 @@ def min_timestamp(self) -> pd.Timestamp:
     @property
     @lru_cache(maxsize=None)
     def max_timestamp(self) -> pd.Timestamp:
-        r"""Returns the latest time in the table."""
+        r"""Return the latest time in the table."""
 
         if self.time_col is None:
             raise ValueError("Table has no time column.")
diff --git a/relbench/modeling/graph.py b/relbench/modeling/graph.py
index a513a225..db25cb0d 100644
--- a/relbench/modeling/graph.py
+++ b/relbench/modeling/graph.py
@@ -27,11 +27,11 @@ def make_pkey_fkey_graph(
     primary-foreign key relationships, together with the column stats of each table.
 
     Args:
-        db (Database): A database object containing a set of tables.
-        col_to_stype_dict (Dict[str, Dict[str, stype]]): Column to stype for
+        db: A database object containing a set of tables.
+        col_to_stype_dict: Column to stype for
             each table.
-        text_embedder_cfg (TextEmbedderConfig): Text embedder config.
-        cache_dir (str, optional): A directory for storing materialized tensor
+        text_embedder_cfg: Text embedder config.
+        cache_dir: A directory for storing materialized tensor
             frames. If specified, we will either cache the file or use the
             cached file. If not specified, we will not use cached file and
             re-process everything from scratch without saving the cache.
@@ -112,12 +112,12 @@
 
 
 class AttachTargetTransform:
-    r"""Adds the target label to the heterogeneous mini-batch.
+    r"""Attach the target label to the heterogeneous mini-batch.
 
     The batch consists of disjoint subgraphs loaded via temporal sampling. The same
-    input node can occur twice with different timestamps, and thus different subgraphs
-    and labels. Hence labels cannot be stored in the graph object directly, and must be
-    attached to the batch after the batch is created.
+    input node can occur multiple times with different timestamps, and thus different
+    subgraphs and labels. Hence labels cannot be stored in the graph object directly,
+    and must be attached to the batch after the batch is created.
     """
 
     def __init__(self, entity: str, target: Tensor):
@@ -130,6 +130,14 @@ def __call__(self, batch: HeteroData) -> HeteroData:
 
 
 class NodeTrainTableInput(NamedTuple):
+    r"""Training table input for node prediction.
+
+    - nodes is a Tensor of node indices.
+    - time is a Tensor of node timestamps.
+    - target is a Tensor of node labels.
+    - transform attaches the target to the batch.
+    """
+
     nodes: Tuple[NodeType, Tensor]
     time: Optional[Tensor]
     target: Optional[Tensor]
@@ -141,6 +149,8 @@ def get_node_train_table_input(
     task: NodeTask,
     multilabel: bool = False,
 ) -> NodeTrainTableInput:
+    r"""Get the training table input for node prediction."""
+
     nodes = torch.from_numpy(table.df[task.entity_col].astype(int).values)
 
     time: Optional[Tensor] = None
@@ -191,6 +201,8 @@ def get_link_train_table_input(
     table: Table,
     task: LinkTask,
 ) -> LinkTrainTableInput:
+    r"""Get the training table input for link prediction."""
+
     src_node_idx: Tensor = torch.from_numpy(
         table.df[task.src_entity_col].astype(int).values
     )
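A small sketch of the timestamp-slicing semantics that PATCH 2/9 documents on `Table`: `upto` keeps rows with `time_col <= timestamp` and `from_` keeps rows with `time_col >= timestamp`, both inclusive, while tables without a `time_col` pass through unchanged. This assumes `Table` is importable from `relbench.base`, as the file paths above suggest; the column names are illustrative:

import pandas as pd

from relbench.base import Table

df = pd.DataFrame(
    {
        "id": [0, 1, 2],
        "timestamp": pd.to_datetime(["2020-01-01", "2021-01-01", "2022-01-01"]),
    }
)
table = Table(df=df, fkey_col_to_pkey_table={}, pkey_col="id", time_col="timestamp")

cutoff = pd.Timestamp("2021-01-01")
past = table.upto(cutoff)     # rows 0 and 1: time_col <= cutoff (inclusive)
future = table.from_(cutoff)  # rows 1 and 2: time_col >= cutoff (inclusive)
assert len(past) == 2 and len(future) == 2  # Table.__len__ counts rows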
""" self.cache_dir = cache_dir @@ -44,7 +45,7 @@ def __repr__(self) -> str: f"{self.__class__.__name__}(\n" f"val_timestamp={self.val_timestamp},\n" f"test_timestamp={self.test_timestamp},\n" - f"cache_dir={self.cache_dir},\n" + # f"cache_dir={self.cache_dir},\n" f")" ) @@ -84,6 +85,8 @@ def get_db(self, upto_test_timestamp=True) -> Database: Returns: Database: The database object. + + `upto_test_timestamp` is True by default to prevent test leakage. """ db_path = f"{self.cache_dir}/db" diff --git a/relbench/base/task_base.py b/relbench/base/task_base.py index 8bbee168..c7f4cff5 100644 --- a/relbench/base/task_base.py +++ b/relbench/base/task_base.py @@ -30,8 +30,18 @@ class TaskType(Enum): class BaseTask: - r"""A task on a dataset.""" + r"""Base class for a task on a dataset. + Attributes: + task_type: The type of the task. + timedelta: The prediction task at `timestamp` is over the time window (timestamp, timestamp + timedelta]. + num_eval_timestamps: The number of evaluation time windows. e.g., test time windows are (test_timestamp, test_timestamp + timedelta] ... (test_timestamp + (num_eval_timestamps - 1) * timedelta, test_timestamp + num_eval_timestamps * timedelta]. + metrics: The metrics to evaluate this task on. + + Inherited by NodeTask and LinkTask. + """ + + # To be set by subclass. task_type: TaskType timedelta: pd.Timedelta num_eval_timestamps: int = 1 @@ -42,6 +52,15 @@ def __init__( dataset: Dataset, cache_dir: Optional[str] = None, ): + r"""Create a task object. + + Args: + dataset: The dataset object on which the task is defined. + cache_dir: A directory for caching the task table objects. If specified, + we will either process and cache the file (if not available) or use + the cached file. If None, we will not use cached file and re-process + everything from scratch without saving the cache. + """ self.dataset = dataset self.cache_dir = cache_dir @@ -54,18 +73,37 @@ def __init__( ) def __repr__(self) -> str: - return f"{self.__class__.__name__}(dataset={self.dataset})" + return ( + f"{self.__class__.__name__}(\n" + f"dataset={self.dataset},\n" + # f"cache_dir={self.cache_dir},\n" + f"task_type={self.task_type},\n" + f"timedelta={self.timedelta},\n" + # f"num_eval_timestamps={self.num_eval_timestamps},\n" + # f"metrics={self.metrics},\n" + ) def make_table( self, db: Database, timestamps: "pd.Series[pd.Timestamp]", ) -> Table: - r"""To be implemented by subclass.""" + r"""Make a table using the task definition. + + Args: + db: The database object to use for (historical) ground truth. + timestamps: Collection of timestamps to compute labels for. A label can be + computed for a timestamp using historical data + upto this timestamp in the database. + + To be implemented by subclass. + """ raise NotImplementedError def _get_table(self, split: str) -> Table: + r"""Helper function to get a table for a split.""" + db = self.dataset.get_db(upto_test_timestamp=split != "test") if split == "train": @@ -120,6 +158,20 @@ def _get_table(self, split: str) -> Table: @lru_cache(maxsize=None) def get_table(self, split, mask_input_cols=None): + r"""Get a table for a split. + + Args: + split: The split to get the table for. One of "train", "val", or "test". + mask_input_cols: If True, keep only the input columns in the table. If + None, mask the input columns only for the test split. This helps + prevent data leakage. + + Returns: + The task table for the split. + + The table is cached in memory. 
+ """ + if mask_input_cols is None: mask_input_cols = split == "test" @@ -149,9 +201,25 @@ def _mask_input_cols(self, table: Table) -> Table: ) def filter_dangling_entities(self, table: Table) -> Table: - r"""Filter out dangling entities from a table.""" + r"""Filter out dangling entities from a table. + + Implemented by NodeTask and LinkTask. + """ raise NotImplementedError - def evaluate(self): - r"""Evaluate a prediction table.""" + def evaluate( + self, + pred: NDArray, + target_table: Optional[Table] = None, + metrics: Optional[List[Callable[[NDArray, NDArray], float]]] = None, + ): + r"""Evaluate predictions on the task. + + Args: + pred: Predictions as a numpy array. + target_table: The target table. If None, use the test table. + metrics: The metrics to evaluate the prediction table. If None, use the default metrics for the task. + + Implemented by NodeTask and LinkTask. + """ raise NotImplementedError From 9c8525bc5ae9e7e1ad9898d029eacad67989b4bc Mon Sep 17 00:00:00 2001 From: Rishabh Ranjan Date: Sat, 6 Jul 2024 12:37:43 -0700 Subject: [PATCH 4/9] update --- relbench/base/task_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/relbench/base/task_base.py b/relbench/base/task_base.py index c7f4cff5..36f1a9d2 100644 --- a/relbench/base/task_base.py +++ b/relbench/base/task_base.py @@ -218,7 +218,8 @@ def evaluate( Args: pred: Predictions as a numpy array. target_table: The target table. If None, use the test table. - metrics: The metrics to evaluate the prediction table. If None, use the default metrics for the task. + metrics: The metrics to evaluate the prediction table. If None, use + the default metrics for the task. Implemented by NodeTask and LinkTask. """ From 975674437fc50ee67ca3682239a7df2ea727fc5d Mon Sep 17 00:00:00 2001 From: Rishabh Ranjan Date: Sat, 6 Jul 2024 12:43:21 -0700 Subject: [PATCH 5/9] all relbench.base docstrings done --- relbench/base/task_base.py | 8 ++++++-- relbench/base/task_link.py | 14 +++++++++++++- relbench/base/task_node.py | 13 +++++++++++-- 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/relbench/base/task_base.py b/relbench/base/task_base.py index 36f1a9d2..a895c186 100644 --- a/relbench/base/task_base.py +++ b/relbench/base/task_base.py @@ -34,8 +34,12 @@ class BaseTask: Attributes: task_type: The type of the task. - timedelta: The prediction task at `timestamp` is over the time window (timestamp, timestamp + timedelta]. - num_eval_timestamps: The number of evaluation time windows. e.g., test time windows are (test_timestamp, test_timestamp + timedelta] ... (test_timestamp + (num_eval_timestamps - 1) * timedelta, test_timestamp + num_eval_timestamps * timedelta]. + timedelta: The prediction task at `timestamp` is over the time window + (timestamp, timestamp + timedelta]. + num_eval_timestamps: The number of evaluation time windows. e.g., test + time windows are (test_timestamp, test_timestamp + timedelta] ... + (test_timestamp + (num_eval_timestamps - 1) * timedelta, test_timestamp + + num_eval_timestamps * timedelta]. metrics: The metrics to evaluate this task on. Inherited by NodeTask and LinkTask. diff --git a/relbench/base/task_link.py b/relbench/base/task_link.py index d6a5ca2a..aac25e69 100644 --- a/relbench/base/task_link.py +++ b/relbench/base/task_link.py @@ -12,7 +12,18 @@ class LinkTask(BaseTask): - r"""A link prediction task on a dataset.""" + r"""A link prediction task on a dataset. + + Attributes: + src_entity_col: The source entity column. + src_entity_table: The source entity table. 
From 975674437fc50ee67ca3682239a7df2ea727fc5d Mon Sep 17 00:00:00 2001
From: Rishabh Ranjan
Date: Sat, 6 Jul 2024 12:43:21 -0700
Subject: [PATCH 5/9] all relbench.base docstrings done

---
 relbench/base/task_base.py | 8 ++++++--
 relbench/base/task_link.py | 14 +++++++++++++-
 relbench/base/task_node.py | 13 +++++++++++--
 3 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/relbench/base/task_base.py b/relbench/base/task_base.py
index 36f1a9d2..a895c186 100644
--- a/relbench/base/task_base.py
+++ b/relbench/base/task_base.py
@@ -34,8 +34,12 @@ class BaseTask:
 
     Attributes:
         task_type: The type of the task.
-        timedelta: The prediction task at `timestamp` is over the time window (timestamp, timestamp + timedelta].
-        num_eval_timestamps: The number of evaluation time windows. e.g., test time windows are (test_timestamp, test_timestamp + timedelta] ... (test_timestamp + (num_eval_timestamps - 1) * timedelta, test_timestamp + num_eval_timestamps * timedelta].
+        timedelta: The prediction task at `timestamp` is over the time window
+            (timestamp, timestamp + timedelta].
+        num_eval_timestamps: The number of evaluation time windows. e.g., test
+            time windows are (test_timestamp, test_timestamp + timedelta] ...
+            (test_timestamp + (num_eval_timestamps - 1) * timedelta, test_timestamp +
+            num_eval_timestamps * timedelta].
         metrics: The metrics to evaluate this task on.
 
     Inherited by NodeTask and LinkTask.
diff --git a/relbench/base/task_link.py b/relbench/base/task_link.py
index d6a5ca2a..aac25e69 100644
--- a/relbench/base/task_link.py
+++ b/relbench/base/task_link.py
@@ -12,7 +12,18 @@
 
 
 class LinkTask(BaseTask):
-    r"""A link prediction task on a dataset."""
+    r"""A link prediction task on a dataset.
+
+    Attributes:
+        src_entity_col: The source entity column.
+        src_entity_table: The source entity table.
+        dst_entity_col: The destination entity column.
+        dst_entity_table: The destination entity table.
+        time_col: The time column.
+        eval_k: k for eval@k metrics.
+
+    Other attributes are inherited from BaseTask.
+    """
 
     src_entity_col: str
     src_entity_table: str
@@ -99,6 +110,7 @@ def stats(self) -> Dict[str, Dict[str, int]]:
         r"""Get train / val / test table statistics for each timestamp and the whole
         table, including number of unique source entities, number of unique destination
         entities, number of destination entities and number of rows."""
+
         res = {}
         for split in ["train", "val", "test"]:
             split_stats = {}
diff --git a/relbench/base/task_node.py b/relbench/base/task_node.py
index 6b3d641a..f6012a49 100644
--- a/relbench/base/task_node.py
+++ b/relbench/base/task_node.py
@@ -11,7 +11,16 @@
 
 
 class NodeTask(BaseTask):
-    r"""A link prediction task on a dataset."""
+    r"""A node prediction task on a dataset.
+
+    Attributes:
+        entity_col: The entity column.
+        entity_table: The entity table.
+        time_col: The time column.
+        target_col: The target column.
+
+    Other attributes are inherited from BaseTask.
+    """
 
     entity_col: str
     entity_table: str
@@ -53,7 +62,7 @@ def evaluate(
 
         return {fn.__name__: fn(target, pred) for fn in metrics}
 
-    def stats(self) -> dict[str, dict[str, Any]]:
+    def stats(self) -> Dict[str, Dict[str, Any]]:
         r"""Get train / val / test table statistics for each timestamp and the whole
         table, including number of rows and number of entities. Tasks with different
         task types have different statistics computed:

From b58b9167d48e222932ee36e3f47a0ad67495cf68 Mon Sep 17 00:00:00 2001
From: Rishabh Ranjan
Date: Sat, 6 Jul 2024 13:02:35 -0700
Subject: [PATCH 6/9] datasets and tasks docstrings and prints

---
 relbench/base/dataset.py | 4 ++++
 relbench/base/task_base.py | 10 +++++++++
 relbench/datasets/__init__.py | 37 +++++++++++++++++++++++++++++++--
 relbench/tasks/__init__.py | 41 ++++++++++++++++++++++++++++++++---
 4 files changed, 87 insertions(+), 5 deletions(-)

diff --git a/relbench/base/dataset.py b/relbench/base/dataset.py
index a685d239..a62e4999 100644
--- a/relbench/base/dataset.py
+++ b/relbench/base/dataset.py
@@ -99,6 +99,10 @@ def get_db(self, upto_test_timestamp=True) -> Database:
 
         else:
             print("Making Database object from scratch...")
+            print(
+                "(You can also use `get_dataset(..., download=True)`"
+                "for datasets prepared by the RelBench team.)"
+            )
             tic = time.time()
             db = self.make_db()
             db.reindex_pkeys_and_fkeys()
diff --git a/relbench/base/task_base.py b/relbench/base/task_base.py
index a895c186..5992062a 100644
--- a/relbench/base/task_base.py
+++ b/relbench/base/task_base.py
@@ -1,3 +1,4 @@
+import time
 from enum import Enum
 from functools import lru_cache
 from pathlib import Path
@@ -183,7 +184,16 @@ def get_table(self, split, mask_input_cols=None):
         if self.cache_dir and Path(table_path).exists():
             table = Table.load(table_path)
         else:
+            print(f"Making task table for {split} split from scratch...")
+            print(
+                "(You can also use `get_task(..., download=True)` "
+                "for tasks prepared by the RelBench team.)"
+            )
+            tic = time.time()
             table = self._get_table(split)
+            toc = time.time()
+            print(f"Done in {toc - tic:.2f} seconds.")
+
             if self.cache_dir:
                 table.save(table_path)
diff --git a/relbench/datasets/__init__.py b/relbench/datasets/__init__.py
index bcef97cf..9de0f358 100644
--- a/relbench/datasets/__init__.py
+++ b/relbench/datasets/__init__.py
@@ -1,6 +1,7 @@
 import json
 import pkgutil
 from functools import lru_cache
+from typing import List
 
 import pooch
 
@@ -24,17 +25,37 @@ def register_dataset(
     cls: Dataset,
     *args,
     **kwargs,
-):
+) -> None:
+    r"""Register an instantiation of a :class:`Dataset` subclass with the given name.
+
+    Args:
+        name: The name of the dataset.
+        cls: The class of the dataset.
+        args: The arguments to instantiate the dataset.
+        kwargs: The keyword arguments to instantiate the dataset.
+
+    The name is used to enable caching and downloading functionalities.
+    `cache_dir` is added to kwargs by default. If you want to override it, you
+    can pass `cache_dir` as a keyword argument in `kwargs`.
+    """
+
     cache_dir = f"{pooch.os_cache('relbench')}/{name}"
     kwargs = {"cache_dir": cache_dir, **kwargs}
     dataset_registry[name] = (cls, args, kwargs)
 
 
-def get_dataset_names():
+def get_dataset_names() -> List[str]:
+    r"""Return a list of names of the registered datasets."""
     return list(dataset_registry.keys())
 
 
 def download_dataset(name: str) -> None:
+    r"""Download dataset from RelBench server into its cache directory.
+
+    The downloaded database will be automatically picked up by the dataset object when
+    `dataset.get_db()` is called.
+    """
+
     DOWNLOAD_REGISTRY.fetch(
         f"{name}/db.zip",
         processor=pooch.Unzip(extract_dir="db"),
@@ -44,6 +65,18 @@ def download_dataset(name: str) -> None:
 
 @lru_cache(maxsize=None)
 def get_dataset(name: str, download=False) -> Dataset:
+    r"""Return a dataset object by name.
+
+    Args:
+        name: The name of the dataset.
+        download: If True, download the dataset from the RelBench server.
+
+    Returns:
+        Dataset: The dataset object.
+
+    If `download` is True, the dataset will be downloaded into the cache.
+    """
+
     if download:
         download_dataset(name)
     cls, args, kwargs = dataset_registry[name]
diff --git a/relbench/tasks/__init__.py b/relbench/tasks/__init__.py
index c096ba80..d4e875d4 100644
--- a/relbench/tasks/__init__.py
+++ b/relbench/tasks/__init__.py
@@ -2,6 +2,7 @@
 import pkgutil
 from collections import defaultdict
 from functools import lru_cache
+from typing import List
 
 import pooch
 
@@ -24,20 +25,41 @@ def register_task(
     dataset_name: str,
     task_name: str,
-    cls,
+    cls: BaseTask,
     *args,
     **kwargs,
-):
+) -> None:
+    r"""Register an instantiation of a :class:`BaseTask` subclass with the given name.
+
+    Args:
+        dataset_name: The name of the dataset.
+        task_name: The name of the task.
+        cls: The class of the task.
+        args: The arguments to instantiate the task.
+        kwargs: The keyword arguments to instantiate the task.
+
+    The name is used to enable caching and downloading functionalities.
+    `cache_dir` is added to kwargs by default. If you want to override it, you
+    can pass `cache_dir` as a keyword argument in `kwargs`.
+    """
+
    cache_dir = f"{pooch.os_cache('relbench')}/{dataset_name}/tasks/{task_name}"
     kwargs = {"cache_dir": cache_dir, **kwargs}
     task_registry[dataset_name][task_name] = (cls, args, kwargs)
 
 
-def get_task_names(dataset_name: str):
+def get_task_names(dataset_name: str) -> List[str]:
+    r"""Return a list of names of the registered tasks for the given dataset."""
     return list(task_registry[dataset_name].keys())
 
 
 def download_task(dataset_name: str, task_name: str) -> None:
+    r"""Download task from RelBench server into its cache directory.
+
+    The downloaded task tables will be automatically picked up by the task object when
+    `task.get_table(split)` is called.
+    """
+
     DOWNLOAD_REGISTRY.fetch(
         f"{dataset_name}/tasks/{task_name}.zip",
         processor=pooch.Unzip(extract_dir=task_name),
@@ -47,6 +69,19 @@ def download_task(dataset_name: str, task_name: str) -> None:
 
 @lru_cache(maxsize=None)
 def get_task(dataset_name: str, task_name: str, download=False) -> BaseTask:
+    r"""Return a task object by name.
+
+    Args:
+        dataset_name: The name of the dataset.
+        task_name: The name of the task.
+        download: If True, download the task from the RelBench server.
+
+    Returns:
+        BaseTask: The task object.
+
+    If `download` is True, the task will be downloaded into the cache.
+    """
+
     if download:
         download_task(dataset_name, task_name)
     dataset = get_dataset(dataset_name)
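Taken together, the registry helpers documented in PATCH 6/9 support the following end-to-end flow. A hedged sketch — the dataset and task names are illustrative placeholders; `get_dataset_names()` and `get_task_names()` report the registered ones:

from relbench.datasets import get_dataset, get_dataset_names
from relbench.tasks import get_task, get_task_names

print(get_dataset_names())  # names registered via register_dataset
dataset = get_dataset("rel-amazon", download=True)  # fetches db.zip into the pooch cache
db = dataset.get_db()  # rows upto test_timestamp by default, to prevent test leakage

print(get_task_names("rel-amazon"))  # names registered via register_task
task = get_task("rel-amazon", "user-churn", download=True)
train_table = task.get_table("train")
test_table = task.get_table("test")  # input columns masked by default for "test"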
+ """ + DOWNLOAD_REGISTRY.fetch( f"{dataset_name}/tasks/{task_name}.zip", processor=pooch.Unzip(extract_dir=task_name), @@ -47,6 +69,19 @@ def download_task(dataset_name: str, task_name: str) -> None: @lru_cache(maxsize=None) def get_task(dataset_name: str, task_name: str, download=False) -> BaseTask: + r"""Return a task object by name. + + Args: + dataset_name: The name of the dataset. + task_name: The name of the task. + download: If True, download the task from the RelBench server. + + Returns: + BaseTask: The task object. + + If `download` is True, the task will be downloaded into the cache. + """ + if download: download_task(dataset_name, task_name) dataset = get_dataset(dataset_name) From 41f39dbdf1245b9ce857e074b86ba70542ba7996 Mon Sep 17 00:00:00 2001 From: Rishabh Ranjan Date: Sat, 6 Jul 2024 14:16:10 -0700 Subject: [PATCH 7/9] revert repr changes --- relbench/base/dataset.py | 10 ++-------- relbench/base/task_base.py | 10 +--------- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/relbench/base/dataset.py b/relbench/base/dataset.py index a62e4999..9c06a17a 100644 --- a/relbench/base/dataset.py +++ b/relbench/base/dataset.py @@ -41,13 +41,7 @@ def __init__( self.cache_dir = cache_dir def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}(\n" - f"val_timestamp={self.val_timestamp},\n" - f"test_timestamp={self.test_timestamp},\n" - # f"cache_dir={self.cache_dir},\n" - f")" - ) + return f"{self.__class__.__name__}()" def validate_and_correct_db(self, db): r"""Validate and correct input db in-place. @@ -100,7 +94,7 @@ def get_db(self, upto_test_timestamp=True) -> Database: else: print("Making Database object from scratch...") print( - "(You can also use `get_dataset(..., download=True)`" + "(You can also use `get_dataset(..., download=True)` " "for datasets prepared by the RelBench team.)" ) tic = time.time() diff --git a/relbench/base/task_base.py b/relbench/base/task_base.py index 5992062a..3ee9e777 100644 --- a/relbench/base/task_base.py +++ b/relbench/base/task_base.py @@ -78,15 +78,7 @@ def __init__( ) def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}(\n" - f"dataset={self.dataset},\n" - # f"cache_dir={self.cache_dir},\n" - f"task_type={self.task_type},\n" - f"timedelta={self.timedelta},\n" - # f"num_eval_timestamps={self.num_eval_timestamps},\n" - # f"metrics={self.metrics},\n" - ) + return f"{self.__class__.__name__}(dataset={repr(self.dataset)})" def make_table( self, From a17f641ecd1866b3a2afd7c4e3df7f9f442168a4 Mon Sep 17 00:00:00 2001 From: Rishabh Ranjan Date: Mon, 8 Jul 2024 11:13:09 -0700 Subject: [PATCH 8/9] v1.0.0-rc1 --- pyproject.toml | 2 +- relbench/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c1ca4c82..882215a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "relbench" -version = "0.2.0" +version = "1.0.0-rc1" description = "RelBench: Relational Deep Learning Benchmark" authors = [{name = "RelBench Team", email = "relbench@cs.stanford.edu"}] readme = "README.md" diff --git a/relbench/__init__.py b/relbench/__init__.py index da35152b..1441ce00 100644 --- a/relbench/__init__.py +++ b/relbench/__init__.py @@ -1,3 +1,3 @@ from relbench import base, datasets, modeling, tasks -__version__ = "1.0.0" +__version__ = "1.0.0-rc1" From 16dd072e98750171e59dcd97e972b82117502b2c Mon Sep 17 00:00:00 2001 From: Rishabh Ranjan Date: Mon, 8 Jul 2024 11:17:16 -0700 Subject: [PATCH 9/9] minor --- 
relbench/base/task_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/relbench/base/task_base.py b/relbench/base/task_base.py index 3ee9e777..27301c00 100644 --- a/relbench/base/task_base.py +++ b/relbench/base/task_base.py @@ -93,7 +93,8 @@ def make_table( computed for a timestamp using historical data upto this timestamp in the database. - To be implemented by subclass. + To be implemented by subclass. The table rows need not be ordered + deterministically. """ raise NotImplementedError
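Finally, the `make_table` contract touched by PATCH 3/9 and PATCH 9/9 is what concrete tasks implement. A hedged sketch of a subclass, pieced together from the class attributes and `make_table` signatures that recur throughout this series — the table and column names are illustrative, and the trivial label computation stands in for the real aggregation logic in relbench/tasks/*.py:

import pandas as pd

from relbench.base import Database, NodeTask, Table, TaskType


class ExampleChurnTask(NodeTask):
    # Illustrative attribute values; `metrics` would also be set here,
    # per the BaseTask attributes documented above.
    task_type = TaskType.BINARY_CLASSIFICATION
    entity_col = "customer_id"
    entity_table = "customer"
    time_col = "timestamp"
    target_col = "churn"
    timedelta = pd.Timedelta(days=7)

    def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Table:
        # One label per (entity, timestamp), computed from rows of `db` upto
        # each timestamp; per PATCH 9/9, row order need not be deterministic.
        df = pd.DataFrame(
            {
                self.entity_col: pd.Series(dtype="int64"),
                self.time_col: pd.Series(dtype="datetime64[ns]"),
                self.target_col: pd.Series(dtype="int64"),
            }
        )  # placeholder: real tasks aggregate over db.table_dict here
        return Table(
            df=df,
            fkey_col_to_pkey_table={self.entity_col: self.entity_table},
            pkey_col=None,
            time_col=self.time_col,
        )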