diff --git a/benchmark/train/configs/financial.yaml b/benchmark/train/configs/financial.yaml index 4b9459cb..e9bd7390 100644 --- a/benchmark/train/configs/financial.yaml +++ b/benchmark/train/configs/financial.yaml @@ -14,6 +14,8 @@ dataset: task: node task_type: classification split: [0.8, 0.1, 0.1] + split_mode: random + split_column: DATE # only needed when split_mode = column encoder: True encoder_name: db encoder_bn: True diff --git a/benchmark/train/configs/financial_regression.yaml b/benchmark/train/configs/financial_regression.yaml index 9ed292fb..ecee5f04 100644 --- a/benchmark/train/configs/financial_regression.yaml +++ b/benchmark/train/configs/financial_regression.yaml @@ -14,6 +14,8 @@ dataset: task: node task_type: regression split: [0.8, 0.1, 0.1] + split_mode: column + split_column: DATE encoder: True encoder_name: db encoder_bn: True diff --git a/kumo/config/config.py b/kumo/config/config.py index fb8e8cfb..203de85c 100644 --- a/kumo/config/config.py +++ b/kumo/config/config.py @@ -45,7 +45,15 @@ def set_cfg(cfg): # Database name cfg.snowflake.database = 'kumo' - # overwrite GraphGym scheduler + # Default to random split + cfg.dataset.split_type = 'random' + # If split_type == 'column', split by the values of the column: + # lower values of this column will be in the training split; + # the highest values of this column will be in the test split. + # Restriction: the split column has to be in the prediction target table. 
+ cfg.dataset.split_column = None + + # Overwrite GraphGym scheduler # (might improve default optimizer after more training experiences # on databases) cfg.optim.scheduler = 'none' diff --git a/kumo/custom/loader/financial.py b/kumo/custom/loader/financial.py index 46f42413..703efcab 100644 --- a/kumo/custom/loader/financial.py +++ b/kumo/custom/loader/financial.py @@ -9,7 +9,7 @@ def load_financial(format, name, dataset_dir, label_table_name, - label_column_name): + label_column_name, split_column_name=None): if name != "Financial": return None @@ -17,6 +17,8 @@ def load_financial(format, name, dataset_dir, label_table_name, root_dir = osp.join(root_dir, "test", "csv_data", "FINANCIAL") dbmeta = DatabaseMetadata.load(osp.join(root_dir, "metadata.yml")) dbmeta.set_label(label_table_name, label_column_name) + if split_column_name is not None: + dbmeta.set_split_column(split_column_name) if format == "snowflake": connector = SnowflakeConnector( @@ -37,7 +39,7 @@ def load_financial(format, name, dataset_dir, label_table_name, # TODO: Temporary work around (merge labels for LOAN.STATUS only) if label_table_name == "LOAN" and label_column_name == "STATUS": - data["LOAN"].y[data["LOAN"].y == 2] = 0 + data['LOAN'].y[data["LOAN"].y == 2] = 0 data["LOAN"].y[data["LOAN"].y == 3] = 1 return data diff --git a/kumo/custom/loader/imdb.py b/kumo/custom/loader/imdb.py index 66ced4aa..23bd84d6 100644 --- a/kumo/custom/loader/imdb.py +++ b/kumo/custom/loader/imdb.py @@ -8,7 +8,8 @@ from kumo.connector import SnowflakeConnector -def load_imdb(format, name, dataset_dir, label_table_name, label_column_name): +def load_imdb(format, name, dataset_dir, label_table_name, label_column_name, + split_column_name=None): if name != 'IMDB': return None @@ -80,6 +81,11 @@ def load_imdb(format, name, dataset_dir, label_table_name, label_column_name): encoders = cfg[label_table_name]['encoders'] cfg[label_table_name]['label_encoder'] = \ [encoder for encoder in encoders if label_column_name == 
encoder[0]][0] + if split_column_name is not None: + # TODO: to enable column split, use dbmeta similar to financial.py to + # construct Kumo store, and set the split column: + # dbmeta.set_split_column(split_column_name) + print('TODO: integrate database metadata for imdb for column split.') if format == 'snowflake': diff --git a/kumo/model/model_builder.py b/kumo/model/model_builder.py index 42ad7796..5fff4f7d 100644 --- a/kumo/model/model_builder.py +++ b/kumo/model/model_builder.py @@ -70,11 +70,11 @@ def _get_pred_int(self, pred_score): @torch.no_grad() def _eval_auroc(self, label, pred): # if pred_value is multi-class classification, compute softmax - if (self.cfg.dataset.task_type == "classification" - and pred.dim() > 1 and pred.size(1) > 2): + if (self.cfg.dataset.task_type == "classification" and pred.dim() > 1 + and pred.size(1) > 2): # TODO: add pred = nn.Softmax(dim=1)(pred) and "ovo" for # multi-class if imbalance eval is needed via auc - return 0 + return None label = label.cpu() pred = pred.cpu() auc = roc_auc_score(label, pred) @@ -97,7 +97,7 @@ def plot_error_distribution(self, pred, label, metric="err", prefix="train"): if metric == "err": if pred.dim() == 1 or pred.size(0) == pred.numel(): - err_raw = pred - label + err_raw = pred.squeeze() - label.squeeze() else: raise ValueError("Raw error can only be applied to scalar") elif metric == "mae": @@ -144,8 +144,9 @@ def training_step(self, batch, batch_idx): self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True) - self.log("train_auc", auc, on_step=False, on_epoch=True, - prog_bar=True) + if auc is not None: + self.log("train_auc", auc, on_step=False, on_epoch=True, + prog_bar=True) elif self.cfg.dataset.task_type == "regression": self._eval_regression(label, pred_value) return loss @@ -174,8 +175,9 @@ def validation_step(self, batch, batch_idx): self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True) - self.log("val_auc", auc, on_step=False, on_epoch=True, - 
prog_bar=True) + if auc is not None: + self.log("val_auc", auc, on_step=False, on_epoch=True, + prog_bar=True) elif self.cfg.dataset.task_type == "regression": if self.current_epoch >= PLOT_EPOCH_THRESHOLD and \ loss < self.val_best_loss: diff --git a/kumo/scan/database_metadata.py b/kumo/scan/database_metadata.py index faa39ded..3f4ced32 100644 --- a/kumo/scan/database_metadata.py +++ b/kumo/scan/database_metadata.py @@ -11,6 +11,7 @@ class DatabaseMetadata: def __init__(self, name: str, tables: List[TableMetadata]): self.name = name + self.split_col_name = None self._table_dict = {table.name: table for table in tables} self.fill_metagraph_() @@ -38,6 +39,9 @@ def config(self) -> DatabaseMetadataConfig: tables = [table.config() for table in self.tables] return DatabaseMetadataConfig(self.name, tables) + def set_split_column(self, column_name: str): + self.split_col_name = column_name + def table(self, table_name: str) -> TableMetadata: return self._table_dict[table_name] diff --git a/kumo/store/store.py b/kumo/store/store.py index dbbb0702..cc6ced4e 100644 --- a/kumo/store/store.py +++ b/kumo/store/store.py @@ -57,6 +57,8 @@ def __init__(self, name: str, tables: List[Table], persist: bool = False): if hasattr(table, 'label'): self[table.name].label = table.label self[table.name].y = self[table.name].label + if hasattr(table, 'split_col'): + self[table.name].split_col = table.split_col # Add graph connectivity: for data in zip(table.foreign_key_info, table.foreign_key_value): diff --git a/kumo/store/table.py b/kumo/store/table.py index 61d4abf3..65ef86ab 100644 --- a/kumo/store/table.py +++ b/kumo/store/table.py @@ -25,19 +25,21 @@ class Table: :obj:`(table_name, column_identifier)`. (default: :obj:`None`) target: ((str or int, Encoder), optional): If not :obj:`None`, will use this column as ground-truth labels. (default: :obj:`None`) + label_encoder: encoder information for the label column. 
+ (default: :obj:`None`) + split_col_name (str, optional): Name of the column used to perform + the dataset split when split_mode=column. Only applicable when + this table contains the label column. + (default: :obj:`None`) shortcut (bool, optional): If set to :obj:`True`, will interpret this table as an edge-level representation, holding many-to-many connections. (default: :obj:`False`) """ - def __init__( - self, - name: str, - df: pd.DataFrame, - encoders: Optional[List[ColumnEncoder]] = None, - foreign_key_info: Optional[List[ForeignKeyInfo]] = None, - label_encoder: Optional[ColumnEncoder] = None, - shortcut: bool = False, - ): + def __init__(self, name: str, df: pd.DataFrame, + encoders: Optional[List[ColumnEncoder]] = None, + foreign_key_info: Optional[List[ForeignKeyInfo]] = None, + label_encoder: Optional[ColumnEncoder] = None, + split_col_name: Optional[str] = None, shortcut: bool = False): if shortcut: raise NotImplementedError(f"'shortcut' option not yet implemented " @@ -58,9 +60,12 @@ def __init__( dtype = torch.int64 if not is_floating_point(enc) else None self.label = encode(df, [label_encoder], dtype) - # Remove any encoders that are also being used as labels: + # Remove any encoders that are also being used as labels encoders = [encoder for encoder in encoders if name != encoder[0]] + if split_col_name is not None: + self.split_col = df[split_col_name] + self.feat_encoders = [(name, enc) for name, enc in encoders if is_floating_point(enc)] self.discrete_feat_encoders = [(name, enc) for name, enc in encoders @@ -113,8 +118,8 @@ def from_connector(cls, connector: Connector, table_name: str, df = connector.read_table(table_name, primary_key, foreign_keys + columns, primary_key_dtype, foreign_dtypes + column_dtypes) - - return cls(table_name, df, encoders, foreign_key_info, label_encoder) + return cls(table_name, df, encoders, foreign_key_info, label_encoder, + split_col_name=dbmeta.split_col_name) def config(self) -> TableConfig: offset: int = 0 diff --git 
a/kumo/train/loader.py b/kumo/train/loader.py index eb0f91a5..351c81fe 100644 --- a/kumo/train/loader.py +++ b/kumo/train/loader.py @@ -1,5 +1,6 @@ import logging import os.path as osp +import pandas as pd import torch import torch_geometric.datasets as pyg_dataset from torch_geometric.loader import NeighborLoader @@ -48,13 +49,8 @@ def load_dataset(cfg: CfgNode, **kwargs): label_column = cfg.dataset.label_column # Try to load customized data format for func in register.loader_dict.values(): - dataset = func( - format, - name, - dataset_dir, - label_table, - label_column, - ) + dataset = func(format, name, dataset_dir, label_table, label_column, + cfg.dataset.split_column) if dataset is not None: return dataset if format == "PyG": @@ -117,6 +113,8 @@ def split( dataset: Store, label_table: str, split_ratio: List = [0.8, 0.1, 0.1], + split_mode: str = 'random', + split_column: str = 'none', shuffle: bool = True, ): r""" @@ -125,7 +123,20 @@ def split( """ n = dataset[label_table].num_nodes - if shuffle: + if split_mode == 'column': + if split_column is None: + raise ValueError( + 'Specify split_column for when split_mode is column.') + split_col = dataset[label_table].split_col + if isinstance(split_col, pd.Series): + sorted_idx = split_col.argsort() + id = sorted_idx.to_numpy() + elif isinstance(split_col, torch.Tensor): + sorted_idx = torch.argsort(dataset[label_table].split_col) + id = sorted_idx.numpy() + else: + raise TypeError('Unknown data type for split column.') + elif shuffle: id = torch.randperm(n) else: id = torch.arange(n) @@ -153,7 +164,8 @@ def create_dataset(cfg: CfgNode, **kwargs): dataset = load_dataset(cfg, **kwargs) for n_t in dataset.metadata()[0]: dataset[n_t].x = dataset[n_t].feat - split(dataset, cfg.dataset.label_table, cfg.dataset.split) + split(dataset, cfg.dataset.label_table, cfg.dataset.split, + cfg.dataset.split_mode, cfg.dataset.split_column) infer_dataset_info(cfg, dataset)