Skip to content

Commit

Permalink
Merge Dataset and RelBenchDataset classes (#212)
Browse files Browse the repository at this point in the history
.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
rishabh-ranjan and pre-commit-ci[bot] authored Jul 3, 2024
1 parent 30af0a3 commit 68e9700
Show file tree
Hide file tree
Showing 19 changed files with 104 additions and 186 deletions.
2 changes: 1 addition & 1 deletion relbench/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .database import Database
from .dataset import Dataset, RelBenchDataset
from .dataset import Dataset
from .table import Table
from .task_base import BaseTask
from .task_link import LinkTask, RelBenchLinkTask
Expand Down
84 changes: 31 additions & 53 deletions relbench/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import shutil
import tempfile
import time
from functools import lru_cache
from pathlib import Path
from typing import List, Optional, Tuple, Type, Union

Expand All @@ -12,39 +13,24 @@

from relbench import DOWNLOAD_REGISTRY
from relbench.data.database import Database
from relbench.data.task_base import BaseTask
from relbench.utils import unzip_processor


class Dataset:
val_timestamp: pd.Timestamp
test_timestamp: pd.Timestamp
max_eval_time_frames: int = 1

def __init__(
self,
db: Database,
val_timestamp: pd.Timestamp,
test_timestamp: pd.Timestamp,
max_eval_time_frames: int,
cache_dir: Optional[str] = None,
) -> None:
r"""Class holding database and task table construction logic.
Args:
db (Database): The database object.
val_timestamp (pd.Timestamp): The first timestamp for making val table.
test_timestamp (pd.Timestamp): The first timestamp for making test table.
max_eval_time_frames (int): The maximum number of unique timestamps used to build test and val tables.
"""
self._full_db = db
self.val_timestamp = val_timestamp
self.test_timestamp = test_timestamp
self.max_eval_time_frames = max_eval_time_frames

self.db = db.upto(test_timestamp)

self.validate_and_correct_db()
self.cache_dir = cache_dir

def __repr__(self) -> str:
return f"{self.__class__.__name__}()"

# TODO: remove this or db.reindex_pkeys_and_fkeys
def validate_and_correct_db(self):
r"""Validate and correct input db in-place."""
# Validate that all primary keys are consecutively index.
Expand All @@ -67,17 +53,17 @@ def validate_and_correct_db(self):
if mask.any():
table.df.loc[mask, fkey_col] = None

@lru_cache(maxsize=None)
def get_db(self, upto_test_timestamp=True) -> Database:
db_path = f"{self.cache_dir}/db"
if self.cache_dir and Path(db_path).exists() and any(Path(db_path).iterdir()):
print(f"loading Database object from {db_path}...")
tic = time.time()
db = Database.load(db_path)
toc = time.time()
print(f"done in {toc - tic:.2f} seconds.")

class RelBenchDataset(Dataset):
name: str
val_timestamp: pd.Timestamp
test_timestamp: pd.Timestamp

db_dir: str = "db"

def __init__(self, process=None) -> None:
db_path = pooch.os_cache("relbench") / self.name / self.db_dir
if not db_path.exists():
else:
print("making Database object from raw files...")
tic = time.time()
db = self.make_db()
Expand All @@ -90,33 +76,25 @@ def __init__(self, process=None) -> None:
toc = time.time()
print(f"done in {toc - tic:.2f} seconds.")

print(f"caching Database object to {db_path}...")
tic = time.time()
db.save(db_path)
toc = time.time()
print(f"done in {toc - tic:.2f} seconds.")
print(f"use process=False to load from cache.")
if self.cache_dir:
print(f"caching Database object to {db_path}...")
tic = time.time()
db.save(db_path)
toc = time.time()
print(f"done in {toc - tic:.2f} seconds.")

else:
print(f"loading Database object from {db_path}...")
tic = time.time()
db = Database.load(db_path)
toc = time.time()
print(f"done in {toc - tic:.2f} seconds.")
if upto_test_timestamp:
db = db.upto(self.test_timestamp)

super().__init__(
db,
self.val_timestamp,
self.test_timestamp,
self.max_eval_time_frames,
)
return db

def make_db(self) -> Database:
raise NotImplementedError

# TODO: move out of here.
def pack_db(self, root: Union[str, os.PathLike]) -> Tuple[str, str]:
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / self.db_dir
db_path = Path(tmpdir) / "db"
print(f"saving Database object to {db_path}...")
tic = time.time()
self._full_db.save(db_path)
Expand All @@ -125,7 +103,7 @@ def pack_db(self, root: Union[str, os.PathLike]) -> Tuple[str, str]:

print("making zip archive for db...")
tic = time.time()
zip_path = Path(root) / self.name / self.db_dir
zip_path = Path(root) / self.name / "db"
zip_path = shutil.make_archive(zip_path, "zip", db_path)
toc = time.time()
print(f"done in {toc - tic:.2f} seconds.")
Expand All @@ -136,4 +114,4 @@ def pack_db(self, root: Union[str, os.PathLike]) -> Tuple[str, str]:
print(f"upload: {zip_path}")
print(f"sha256: {sha256}")

return f"{self.name}/{self.db_dir}.zip", sha256
return f"{self.name}/db.zip", sha256
29 changes: 12 additions & 17 deletions relbench/data/task_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@

from relbench import DOWNLOAD_REGISTRY
from relbench.data.database import Database
from relbench.data.dataset import Dataset
from relbench.data.table import Table
from relbench.utils import unzip_processor

if TYPE_CHECKING:
from relbench.data import Dataset


class BaseTask:
Expand Down Expand Up @@ -59,10 +56,11 @@ def make_table(
@property
def train_table(self) -> Table:
"""Returns the train table for a task."""
db = self.dataset.get_db()
if "train" not in self._cached_table_dict:
timestamps = pd.date_range(
start=self.dataset.val_timestamp - self.timedelta,
end=self.dataset.db.min_timestamp,
end=db.min_timestamp,
freq=-self.timedelta,
)
if len(timestamps) < 3:
Expand All @@ -71,7 +69,7 @@ def train_table(self) -> Table:
f"({len(timestamps)} given)"
)
table = self.make_table(
self.dataset.db,
db,
timestamps,
)
self._cached_table_dict["train"] = table
Expand All @@ -82,11 +80,9 @@ def train_table(self) -> Table:
@property
def val_table(self) -> Table:
r"""Returns the val table for a task."""
db = self.dataset.get_db()
if "val" not in self._cached_table_dict:
if (
self.dataset.val_timestamp + self.timedelta
> self.dataset.db.max_timestamp
):
if self.dataset.val_timestamp + self.timedelta > db.max_timestamp:
raise RuntimeError(
"val timestamp + timedelta is larger than max timestamp! "
"This would cause val labels to be generated with "
Expand All @@ -101,7 +97,7 @@ def val_table(self) -> Table:
)

table = self.make_table(
self.dataset.db,
db,
pd.date_range(
self.dataset.val_timestamp,
end_timestamp,
Expand All @@ -115,12 +111,10 @@ def val_table(self) -> Table:

@property
def test_table(self) -> Table:
db = self.dataset.get_db(upto_test_timestamp=False)
r"""Returns the test table for a task."""
if "full_test" not in self._cached_table_dict:
if (
self.dataset.test_timestamp + self.timedelta
> self.dataset._full_db.max_timestamp
):
if self.dataset.test_timestamp + self.timedelta > db.max_timestamp:
raise RuntimeError(
"test timestamp + timedelta is larger than max timestamp! "
"This would cause test labels to be generated with "
Expand All @@ -131,11 +125,11 @@ def test_table(self) -> Table:
end_timestamp = min(
self.dataset.test_timestamp
+ self.timedelta * (self.dataset.max_eval_time_frames - 1),
self.dataset._full_db.max_timestamp - self.timedelta,
db.max_timestamp - self.timedelta,
)

full_table = self.make_table(
self.dataset._full_db,
db,
pd.date_range(
self.dataset.test_timestamp,
end_timestamp,
Expand Down Expand Up @@ -196,6 +190,7 @@ class TaskType(Enum):
LINK_PREDICTION = "link_prediction"


# TODO: move somewhere else
def _pack_tables(task, root: Union[str, os.PathLike]) -> Tuple[str, str]:
_dummy_db = Database(
table_dict={
Expand Down
4 changes: 2 additions & 2 deletions relbench/data/task_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ def evaluate(

@property
def num_src_nodes(self) -> int:
return len(self.dataset.db.table_dict[self.src_entity_table])
return len(self.dataset.get_db().table_dict[self.src_entity_table])

@property
def num_dst_nodes(self) -> int:
return len(self.dataset.db.table_dict[self.dst_entity_table])
return len(self.dataset.get_db().table_dict[self.dst_entity_table])

@property
def val_seed_time(self) -> int:
Expand Down
3 changes: 2 additions & 1 deletion relbench/data/task_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(dataset={self.dataset})"

def filter_dangling_entities(self, table: Table) -> Table:
num_entities = len(self.dataset.db.table_dict[self.entity_table])
db = self.dataset.get_db()
num_entities = len(db.table_dict[self.entity_table])
filter_mask = table.df[self.entity_col] >= num_entities

if filter_mask.any():
Expand Down
19 changes: 9 additions & 10 deletions relbench/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pooch

from ..data import Dataset
from . import amazon, avito, event, f1, hm, stack, trial
from . import amazon, avito, event, f1, fake, hm, stack, trial

dataset_registry = {}

Expand All @@ -28,6 +28,8 @@ def register_dataset(
*args,
**kwargs,
):
relbench_cache = pooch.os_cache("relbench")
kwargs = {"cache_dir": f"{relbench_cache}/{name}", **kwargs}
dataset_registry[name] = (cls, args, kwargs)


Expand All @@ -36,18 +38,15 @@ def get_dataset_names():


def download_dataset(name: str) -> None:
try:
DOWNLOAD_REGISTRY.fetch(
f"{name}/db.zip",
processor=pooch.Unzip(extract_dir="db"),
progressbar=True,
)
except ValueError:
print("failed to download, will attempt to make db from raw files")
DOWNLOAD_REGISTRY.fetch(
f"{name}/db.zip",
processor=pooch.Unzip(extract_dir="db"),
progressbar=True,
)


@lru_cache(maxsize=None)
def get_dataset(name: str, download=True) -> Dataset:
def get_dataset(name: str, download=False) -> Dataset:
if download:
download_dataset(name)
cls, args, kwargs = dataset_registry[name]
Expand Down
14 changes: 4 additions & 10 deletions relbench/datasets/amazon.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,16 @@
import pooch
import pyarrow as pa

from relbench.data import Database, RelBenchDataset, Table
from relbench.data import Database, Dataset, Table


class AmazonDataset(RelBenchDataset):
class AmazonDataset(Dataset):
name = "rel-amazon"
val_timestamp = pd.Timestamp("2015-10-01")
test_timestamp = pd.Timestamp("2016-01-01")

max_eval_time_frames = 1

category_list = ["books", "fashion"]

url_prefix = "https://datarepo.eng.ucsd.edu/mcauley_group/data/amazon_v2"
_category_to_url_key = {"books": "Books", "fashion": "AMAZON_FASHION"}

Expand All @@ -28,15 +26,11 @@ def __init__(
self,
category: str = "books",
use_5_core: bool = True,
*,
process: bool = False,
cache_dir: str = None,
):
self.category = category
self.use_5_core = use_5_core

# self.name = f"{self.name}-{category}{'_5_core' if use_5_core else ''}"

super().__init__(process=process)
super().__init__(cache_dir=cache_dir)

def make_db(self) -> Database:
r"""Process the raw files into a database."""
Expand Down
12 changes: 2 additions & 10 deletions relbench/datasets/avito.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import pandas as pd
import pooch

from relbench.data import Database, RelBenchDataset, Table
from relbench.data import Database, Dataset, Table
from relbench.utils import clean_datetime, unzip_processor


class AvitoDataset(RelBenchDataset):
class AvitoDataset(Dataset):
name = "rel-avito"
url = "https://www.kaggle.com/competitions/avito-context-ad-clicks"
err_msg = (
Expand All @@ -20,14 +20,6 @@ class AvitoDataset(RelBenchDataset):
test_timestamp = pd.Timestamp("2015-05-14")
max_eval_time_frames = 1

def __init__(
self,
*,
process: bool = False,
):
self.name = f"{self.name}"
super().__init__(process=process)

def make_db(self) -> Database:
# Customize path as necessary
r"""Process the raw files into a database."""
Expand Down
Loading

0 comments on commit 68e9700

Please sign in to comment.