Skip to content

Commit

Permalink
Time split (rusty1s#137)
Browse files Browse the repository at this point in the history
Allows split according to a column values. (e.g. smaller 80% is train; larger 10% is test; rest is validation). Can be date column or other int valued columns that require temporal splits.

Co-authored-by: rusty1s <[email protected]>
  • Loading branch information
RexYing and rusty1s authored Dec 16, 2021
1 parent 599087d commit a4e705e
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 33 deletions.
2 changes: 2 additions & 0 deletions benchmark/train/configs/financial.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ dataset:
task: node
task_type: classification
split: [0.8, 0.1, 0.1]
split_mode: random
split_column: DATE # only needed when split_mode = column
encoder: True
encoder_name: db
encoder_bn: True
Expand Down
2 changes: 2 additions & 0 deletions benchmark/train/configs/financial_regression.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ dataset:
task: node
task_type: regression
split: [0.8, 0.1, 0.1]
split_mode: column
split_column: DATE
encoder: True
encoder_name: db
encoder_bn: True
Expand Down
10 changes: 9 additions & 1 deletion kumo/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,15 @@ def set_cfg(cfg):
# Database name
cfg.snowflake.database = 'kumo'

# overwrite GraphGym scheduler
# Default to random split
cfg.dataset.split_type = 'random'
# If split_type == 'column', split by the values of the column:
# lower values of this column will be in the training split;
# the highest values of this column will be in the test split.
# Restriction: the split column has to be in the prediction target table.
cfg.dataset.split_column = None

# Overwrite GraphGym scheduler
# (might improve default optimizer after more training experiences
# on databases)
cfg.optim.scheduler = 'none'
Expand Down
6 changes: 4 additions & 2 deletions kumo/custom/loader/financial.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@


def load_financial(format, name, dataset_dir, label_table_name,
label_column_name):
label_column_name, split_column_name=None):
if name != "Financial":
return None

root_dir = osp.join(osp.dirname(osp.realpath(__file__)), "..", "..", "..")
root_dir = osp.join(root_dir, "test", "csv_data", "FINANCIAL")
dbmeta = DatabaseMetadata.load(osp.join(root_dir, "metadata.yml"))
dbmeta.set_label(label_table_name, label_column_name)
if split_column_name is not None:
dbmeta.set_split_column(split_column_name)

if format == "snowflake":
connector = SnowflakeConnector(
Expand All @@ -37,7 +39,7 @@ def load_financial(format, name, dataset_dir, label_table_name,

# TODO: Temporary work around (merge labels for LOAN.STATUS only)
if label_table_name == "LOAN" and label_column_name == "STATUS":
data["LOAN"].y[data["LOAN"].y == 2] = 0
data['LOAN'].y[data["LOAN"].y == 2] = 0
data["LOAN"].y[data["LOAN"].y == 3] = 1

return data
Expand Down
8 changes: 7 additions & 1 deletion kumo/custom/loader/imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from kumo.connector import SnowflakeConnector


def load_imdb(format, name, dataset_dir, label_table_name, label_column_name):
def load_imdb(format, name, dataset_dir, label_table_name, label_column_name,
split_column_name=None):
if name != 'IMDB':
return None

Expand Down Expand Up @@ -80,6 +81,11 @@ def load_imdb(format, name, dataset_dir, label_table_name, label_column_name):
encoders = cfg[label_table_name]['encoders']
cfg[label_table_name]['label_encoder'] = \
[encoder for encoder in encoders if label_column_name == encoder[0]][0]
if split_column_name is not None:
# TODO: to enable column split, use dbmeta similar to financial.py to
# construct Kumo store, and set the split column:
# dbmeta.set_split_col(split_column_name)
print('TODO: integrate database metadata for imdb for column split.')

if format == 'snowflake':

Expand Down
18 changes: 10 additions & 8 deletions kumo/model/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ def _get_pred_int(self, pred_score):
@torch.no_grad()
def _eval_auroc(self, label, pred):
# if pred_value is multi-class classification, compute softmax
if (self.cfg.dataset.task_type == "classification"
and pred.dim() > 1 and pred.size(1) > 2):
if (self.cfg.dataset.task_type == "classification" and pred.dim() > 1
and pred.size(1) > 2):
# TODO: add pred = nn.Softmax(dim=1)(pred) and "ovo" for
# multi-class if imbalance eval is needed via auc
return 0
return None
label = label.cpu()
pred = pred.cpu()
auc = roc_auc_score(label, pred)
Expand All @@ -97,7 +97,7 @@ def plot_error_distribution(self, pred, label, metric="err",
prefix="train"):
if metric == "err":
if pred.dim() == 1 or pred.size(0) == pred.numel():
err_raw = pred - label
err_raw = pred.squeeze() - label.squeeze()
else:
raise ValueError("Raw error can only be applied to scalar")
elif metric == "mae":
Expand Down Expand Up @@ -144,8 +144,9 @@ def training_step(self, batch, batch_idx):

self.log("train_acc", acc, on_step=False, on_epoch=True,
prog_bar=True)
self.log("train_auc", auc, on_step=False, on_epoch=True,
prog_bar=True)
if auc is not None:
self.log("train_auc", auc, on_step=False, on_epoch=True,
prog_bar=True)
elif self.cfg.dataset.task_type == "regression":
self._eval_regression(label, pred_value)
return loss
Expand Down Expand Up @@ -174,8 +175,9 @@ def validation_step(self, batch, batch_idx):

self.log("val_acc", acc, on_step=False, on_epoch=True,
prog_bar=True)
self.log("val_auc", auc, on_step=False, on_epoch=True,
prog_bar=True)
if auc is not None:
self.log("val_auc", auc, on_step=False, on_epoch=True,
prog_bar=True)
elif self.cfg.dataset.task_type == "regression":
if self.current_epoch >= PLOT_EPOCH_THRESHOLD and \
loss < self.val_best_loss:
Expand Down
4 changes: 4 additions & 0 deletions kumo/scan/database_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
class DatabaseMetadata:
def __init__(self, name: str, tables: List[TableMetadata]):
self.name = name
self.split_col_name = None

self._table_dict = {table.name: table for table in tables}
self.fill_metagraph_()
Expand Down Expand Up @@ -38,6 +39,9 @@ def config(self) -> DatabaseMetadataConfig:
tables = [table.config() for table in self.tables]
return DatabaseMetadataConfig(self.name, tables)

def set_split_column(self, column_name: str):
self.split_col_name = column_name

def table(self, table_name: str) -> TableMetadata:
return self._table_dict[table_name]

Expand Down
2 changes: 2 additions & 0 deletions kumo/store/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def __init__(self, name: str, tables: List[Table], persist: bool = False):
if hasattr(table, 'label'):
self[table.name].label = table.label
self[table.name].y = self[table.name].label
if hasattr(table, 'split_col'):
self[table.name].split_col = table.split_col

# Add graph connectivity:
for data in zip(table.foreign_key_info, table.foreign_key_value):
Expand Down
29 changes: 17 additions & 12 deletions kumo/store/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,21 @@ class Table:
:obj:`(table_name, column_identifier)`. (default: :obj:`None`)
target: ((str or int, Encoder), optional): If not :obj:`None`, will use
this column as ground-truth labels. (default: :obj:`None`)
label_encoder: encoder information for the label column.
(default: :obj:`None`)
split_col_name: name of the column that is used to perform dataset
split for split_mode=column.
It is only applicable when the current table contains label.
(default: :obj:`None`)
shortcut (bool, optional): If set to :obj:`True`, will interpret this
table as an edge-level representation, holding many-to-many
connections. (default: :obj:`False`)
"""
def __init__(
self,
name: str,
df: pd.DataFrame,
encoders: Optional[List[ColumnEncoder]] = None,
foreign_key_info: Optional[List[ForeignKeyInfo]] = None,
label_encoder: Optional[ColumnEncoder] = None,
shortcut: bool = False,
):
def __init__(self, name: str, df: pd.DataFrame,
encoders: Optional[List[ColumnEncoder]] = None,
foreign_key_info: Optional[List[ForeignKeyInfo]] = None,
label_encoder: Optional[ColumnEncoder] = None,
split_col_name: Optional[str] = None, shortcut: bool = False):

if shortcut:
raise NotImplementedError(f"'shortcut' option not yet implemented "
Expand All @@ -58,9 +60,12 @@ def __init__(
dtype = torch.int64 if not is_floating_point(enc) else None
self.label = encode(df, [label_encoder], dtype)

# Remove any encoders that are also being used as labels:
# Remove any encoders that are also being used as labels
encoders = [encoder for encoder in encoders if name != encoder[0]]

if split_col_name is not None:
self.split_col = df[split_col_name]

self.feat_encoders = [(name, enc) for name, enc in encoders
if is_floating_point(enc)]
self.discrete_feat_encoders = [(name, enc) for name, enc in encoders
Expand Down Expand Up @@ -113,8 +118,8 @@ def from_connector(cls, connector: Connector, table_name: str,
df = connector.read_table(table_name, primary_key,
foreign_keys + columns, primary_key_dtype,
foreign_dtypes + column_dtypes)

return cls(table_name, df, encoders, foreign_key_info, label_encoder)
return cls(table_name, df, encoders, foreign_key_info, label_encoder,
split_col_name=dbmeta.split_col_name)

def config(self) -> TableConfig:
offset: int = 0
Expand Down
30 changes: 21 additions & 9 deletions kumo/train/loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os.path as osp
import pandas as pd
import torch
import torch_geometric.datasets as pyg_dataset
from torch_geometric.loader import NeighborLoader
Expand Down Expand Up @@ -48,13 +49,8 @@ def load_dataset(cfg: CfgNode, **kwargs):
label_column = cfg.dataset.label_column
# Try to load customized data format
for func in register.loader_dict.values():
dataset = func(
format,
name,
dataset_dir,
label_table,
label_column,
)
dataset = func(format, name, dataset_dir, label_table, label_column,
cfg.dataset.split_column)
if dataset is not None:
return dataset
if format == "PyG":
Expand Down Expand Up @@ -117,6 +113,8 @@ def split(
dataset: Store,
label_table: str,
split_ratio: List = [0.8, 0.1, 0.1],
split_mode: str = 'random',
split_column: str = 'none',
shuffle: bool = True,
):
r"""
Expand All @@ -125,7 +123,20 @@ def split(
"""
n = dataset[label_table].num_nodes
if shuffle:
if split_mode == 'column':
if split_column is None:
raise ValueError(
'Specify split_column for when split_mode is column.')
split_col = dataset[label_table].split_col
if isinstance(split_col, pd.Series):
sorted_idx = split_col.argsort()
id = sorted_idx.to_numpy()
elif isinstance(split_col, torch.Tensor):
sorted_idx = torch.argsort(dataset[label_table].split_col)
id = sorted_idx.numpy()
else:
raise TypeError('Unknown data type for split column.')
elif shuffle:
id = torch.randperm(n)
else:
id = torch.arange(n)
Expand Down Expand Up @@ -153,7 +164,8 @@ def create_dataset(cfg: CfgNode, **kwargs):
dataset = load_dataset(cfg, **kwargs)
for n_t in dataset.metadata()[0]:
dataset[n_t].x = dataset[n_t].feat
split(dataset, cfg.dataset.label_table, cfg.dataset.split)
split(dataset, cfg.dataset.label_table, cfg.dataset.split,
cfg.dataset.split_mode, cfg.dataset.split_column)

infer_dataset_info(cfg, dataset)

Expand Down

0 comments on commit a4e705e

Please sign in to comment.