diff --git a/.gitignore b/.gitignore
index e2d3cd745..b9f39521e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,6 +26,7 @@ datacache/
 tests/temp_cache*
 predictions/
 draft/
+scripts-expts/
 
 # Data and predictions
 graphium/data/ZINC_bench_gnn/
diff --git a/README.md b/README.md
index 35698d2c5..11b707bba 100644
--- a/README.md
+++ b/README.md
@@ -65,9 +65,55 @@ The above step needs to be done once. After that, enable the SDK and the environ
 source enable_ipu.sh .graphium_ipu
 ```
 
-## The Graphium CLI
+## Training a model
 
-Installing `graphium` makes two CLI tools available: `graphium` and `graphium-train`. These CLI tools make it easy to access advanced functionality, such as _training a model_, _extracting fingerprints from a pre-trained model_ or _precomputing the dataset_. For more information, visit [the documentation](https://graphium-docs.datamol.io/stable/cli/reference.html).
+To learn how to train a model, we invite you to look at the documentation or the Jupyter notebooks available [here](https://github.com/datamol-io/graphium/tree/master/docs/tutorials/model_training).
+
+If you are not familiar with [PyTorch](https://pytorch.org/docs) or [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/), we highly recommend going through their tutorials first.
+
+## Running an experiment
+We have set up Graphium with `hydra` for managing config files. To run an experiment, go to the `expts/` folder. For example, to benchmark a GCN on the ToyMix dataset, run
+```bash
+graphium-train dataset=toymix model=gcn
+```
+To change parameters specific to this experiment, such as switching from `fp16` to `fp32` precision, you can either override them directly in the CLI via
+```bash
+graphium-train dataset=toymix model=gcn trainer.trainer.precision=32
+```
+or change them permanently in the dedicated experiment config under `expts/hydra-configs/toymix_gcn.yaml`.
+Integrating `hydra` also allows you to quickly switch between accelerators. E.g., running
+```bash
+graphium-train dataset=toymix model=gcn accelerator=gpu
+```
+automatically selects the correct configs to run the experiment on GPU.
+Finally, you can also run a fine-tuning loop:
+```bash
+graphium-train +finetuning=admet
+```
+
+To use a config file you built from scratch, you can run
+```bash
+graphium-train --config-path [PATH] --config-name [CONFIG]
+```
+Thanks to the modular nature of `hydra`, you can reuse many of our config settings for your own experiments with Graphium.
+
+## Preparing the data in advance
+The data preparation, including the featurization (e.g., converting molecules from SMILES to a PyG-compatible format), is embedded in the pipeline and will be performed when executing `graphium-train [...]`.
+
+However, when working with larger datasets, it is recommended to perform the data preparation in advance using a machine with sufficient allocated memory (e.g., ~400GB in the case of `LargeMix`). Preparing the data in advance is also beneficial when running many concurrent jobs with identical molecular featurization, so that resources aren't wasted and processes don't conflict when reading/writing in the same directory.
+
+The following commands will prepare the data and cache it, then use it to train a model.
+```bash
+# First prepare the data and cache it in `path_to_cached_data`
+graphium data prepare ++datamodule.args.processed_graph_data_path=[path_to_cached_data]
+
+# Then train the model on the prepared data
+graphium-train [...] datamodule.args.processed_graph_data_path=[path_to_cached_data]
+```
+
+**Note** that `datamodule.args.processed_graph_data_path` can also be specified under `expts/hydra-configs/`.
+
+**Note** that, every time the configs of `datamodule.args.featurization` change, you will need to run a new data preparation, which will automatically be saved in a separate directory that uses a hash unique to the configs.
 
 ## License
diff --git a/expts/hydra-configs/README.md b/expts/hydra-configs/README.md
index f695ae20c..40625917d 100644
--- a/expts/hydra-configs/README.md
+++ b/expts/hydra-configs/README.md
@@ -33,7 +33,7 @@ constants:
 trainer:
   model_checkpoint:
-    dirpath: models_checkpoints/neurips2023-small-gin/
+    dirpath: models_checkpoints/neurips2023-small-gin/${now:%Y-%m-%d_%H-%M-%S}/
 ```
 We can now utilize `hydra` to e.g., run a sweep over our models on the ToyMix dataset via
@@ -43,7 +43,7 @@ graphium-train -m model=gcn,gin
 where the ToyMix dataset is pre-configured in `main.yaml`. Read on to find out how to define new datasets and architectures for pre-training and fine-tuning.
 
 ## Pre-training / Fine-tuning
-Say you trained a model with the following command: 
+Say you trained a model with the following command:
 ```bash
 graphium-train --config-name "main"
 ```
diff --git a/expts/hydra-configs/architecture/largemix.yaml b/expts/hydra-configs/architecture/largemix.yaml
new file mode 100644
index 000000000..e56108572
--- /dev/null
+++ b/expts/hydra-configs/architecture/largemix.yaml
@@ -0,0 +1,118 @@
+# @package _global_
+
+architecture:
+  model_type: FullGraphMultiTaskNetwork
+  mup_base_path: null
+  pre_nn:  # Set as null to avoid a pre-nn network
+    out_dim: 64
+    hidden_dims: 256
+    depth: 2
+    activation: relu
+    last_activation: none
+    dropout: &dropout 0.1
+    normalization: &normalization layer_norm
+    last_normalization: *normalization
+    residual_type: none
+
+  pre_nn_edges: null
+
+  pe_encoders:
+    out_dim: 32
+    pool: "sum"  #"mean" "max"
+    last_norm: None  #"batch_norm", "layer_norm"
+    encoders:  #la_pos | rw_pos
+      la_pos:  # Set as null to avoid a pre-nn network
+        encoder_type: "laplacian_pe"
+        input_keys: ["laplacian_eigvec", "laplacian_eigval"]
+        output_keys: ["feat"]
+        hidden_dim: 64
+        out_dim: 32
+        model_type: 'DeepSet'  #'Transformer' or 'DeepSet'
+        num_layers: 2
+        num_layers_post: 1  # Num.
layers to apply after pooling + dropout: 0.1 + first_normalization: "none" #"batch_norm" or "layer_norm" + rw_pos: + encoder_type: "mlp" + input_keys: ["rw_return_probs"] + output_keys: ["feat"] + hidden_dim: 64 + out_dim: 32 + num_layers: 2 + dropout: 0.1 + normalization: "layer_norm" #"batch_norm" or "layer_norm" + first_normalization: "layer_norm" #"batch_norm" or "layer_norm" + + gnn: # Set as null to avoid a post-nn network + in_dim: 64 # or otherwise the correct value + out_dim: &gnn_dim 768 + hidden_dims: *gnn_dim + depth: 4 + activation: gelu + last_activation: none + dropout: 0.1 + normalization: "layer_norm" + last_normalization: *normalization + residual_type: simple + virtual_node: 'none' + + graph_output_nn: + graph: + pooling: [sum] + out_dim: *gnn_dim + hidden_dims: *gnn_dim + depth: 1 + activation: relu + last_activation: none + dropout: *dropout + normalization: *normalization + last_normalization: "none" + residual_type: none + node: + pooling: [sum] + out_dim: *gnn_dim + hidden_dims: *gnn_dim + depth: 1 + activation: relu + last_activation: none + dropout: *dropout + normalization: *normalization + last_normalization: "none" + residual_type: none + +datamodule: + module_type: "MultitaskFromSmilesDataModule" + args: + prepare_dict_or_graph: pyg:graph + featurization_n_jobs: 20 + featurization_progress: True + featurization_backend: "loky" + processed_graph_data_path: "../datacache/large-dataset/" + dataloading_from: "disk" + num_workers: 20 # -1 to use all + persistent_workers: True + featurization: + atom_property_list_onehot: [atomic-number, group, period, total-valence] + atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring] + edge_property_list: [bond-type-onehot, stereo, in-ring] + add_self_loop: False + explicit_H: False # if H is included + use_bonds_weights: False + pos_encoding_as_features: # encoder dropout 0.18 + pos_types: + lap_eigvec: + pos_level: node + pos_type: laplacian_eigvec + num_pos: 8 + normalization: "none" # nomrlization already applied on the eigen vectors + disconnected_comp: True # if eigen values/vector for disconnected graph are included + lap_eigval: + pos_level: node + pos_type: laplacian_eigval + num_pos: 8 + normalization: "none" # nomrlization already applied on the eigen vectors + disconnected_comp: True # if eigen values/vector for disconnected graph are included + rw_pos: # use same name as pe_encoder + pos_level: node + pos_type: rw_return_probs + ksteps: 16 \ No newline at end of file diff --git a/expts/hydra-configs/experiment/toymix_mpnn.yaml b/expts/hydra-configs/experiment/toymix_mpnn.yaml index b8b552b66..d79311335 100644 --- a/expts/hydra-configs/experiment/toymix_mpnn.yaml +++ b/expts/hydra-configs/experiment/toymix_mpnn.yaml @@ -10,4 +10,4 @@ constants: trainer: model_checkpoint: - dirpath: models_checkpoints/neurips2023-small-mpnn/ \ No newline at end of file + dirpath: models_checkpoints/neurips2023-small-mpnn/${now:%Y-%m-%d_%H-%M-%S}/ \ No newline at end of file diff --git a/expts/hydra-configs/model/gine.yaml b/expts/hydra-configs/model/gine.yaml new file mode 100644 index 000000000..50a4638f9 --- /dev/null +++ b/expts/hydra-configs/model/gine.yaml @@ -0,0 +1,26 @@ +# @package _global_ + +architecture: + pre_nn_edges: # Set as null to avoid a pre-nn network + out_dim: 32 + hidden_dims: 128 + depth: 2 + activation: relu + last_activation: none + dropout: ${architecture.pre_nn.dropout} + normalization: ${architecture.pre_nn.normalization} + last_normalization: 
${architecture.pre_nn.normalization} + residual_type: none + + gnn: + out_dim: &gnn_dim 704 + hidden_dims: *gnn_dim + layer_type: 'pyg:gine' + + graph_output_nn: + graph: + out_dim: *gnn_dim + hidden_dims: *gnn_dim + node: + out_dim: *gnn_dim + hidden_dims: *gnn_dim diff --git a/expts/hydra-configs/tasks/l1000_mcf7.yaml b/expts/hydra-configs/tasks/l1000_mcf7.yaml new file mode 100644 index 000000000..b6ffdfde7 --- /dev/null +++ b/expts/hydra-configs/tasks/l1000_mcf7.yaml @@ -0,0 +1,7 @@ +# NOTE: We cannot have a single config, since for fine-tuning we will +# only want to override the loss_metrics_datamodule, whereas for training we will +# want to override both. + +defaults: + - task_heads: l1000_mcf7 + - loss_metrics_datamodule: l1000_mcf7 \ No newline at end of file diff --git a/expts/hydra-configs/tasks/l1000_vcap.yaml b/expts/hydra-configs/tasks/l1000_vcap.yaml new file mode 100644 index 000000000..e212a4594 --- /dev/null +++ b/expts/hydra-configs/tasks/l1000_vcap.yaml @@ -0,0 +1,7 @@ +# NOTE: We cannot have a single config, since for fine-tuning we will +# only want to override the loss_metrics_datamodule, whereas for training we will +# want to override both. + +defaults: + - task_heads: l1000_vcap + - loss_metrics_datamodule: l1000_vcap \ No newline at end of file diff --git a/expts/hydra-configs/tasks/largemix.yaml b/expts/hydra-configs/tasks/largemix.yaml new file mode 100644 index 000000000..07417829c --- /dev/null +++ b/expts/hydra-configs/tasks/largemix.yaml @@ -0,0 +1,7 @@ +# NOTE: We cannot have a single config, since for fine-tuning we will +# only want to override the loss_metrics_datamodule, whereas for training we will +# want to override both. + +defaults: + - task_heads: largemix + - loss_metrics_datamodule: largemix \ No newline at end of file diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_mcf7.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_mcf7.yaml new file mode 100644 index 000000000..43933a7fa --- /dev/null +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_mcf7.yaml @@ -0,0 +1,49 @@ +# @package _global_ + +predictor: + metrics_on_progress_bar: + l1000_mcf7: [] + metrics_on_training_set: + l1000_mcf7: [] + loss_fun: + l1000_mcf7: + name: hybrid_ce_ipu + n_brackets: 3 + alpha: 0.5 + +metrics: + l1000_mcf7: + - name: auroc + metric: auroc + num_classes: 3 + task: multiclass + target_to_int: True + target_nan_mask: -1000 + ignore_index: -1000 + multitask_handling: mean-per-label + threshold_kwargs: null + - name: avpr + metric: averageprecision + num_classes: 3 + task: multiclass + target_to_int: True + target_nan_mask: -1000 + ignore_index: -1000 + multitask_handling: mean-per-label + threshold_kwargs: null + +datamodule: + args: # Matches that in the test_multitask_datamodule.py case. 
+ task_specific_args: # To be replaced by a new class "DatasetParams" + l1000_mcf7: + df: null + df_path: ../data/graphium/large-dataset/LINCS_L1000_MCF7_0-2_th2.csv.gz + # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/LINCS_L1000_MCF7_0-4.csv.gz + # or set path as the URL directly + smiles_col: "SMILES" + label_cols: geneID-* # geneID-* means all columns starting with "geneID-" + # sample_size: 2000 # use sample_size for test + task_level: graph + splits_path: ../data/graphium/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt` + # split_names: [train, val, test_seen] + epoch_sampling_fraction: 1.0 \ No newline at end of file diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_vcap.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_vcap.yaml new file mode 100644 index 000000000..27b89d862 --- /dev/null +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_vcap.yaml @@ -0,0 +1,49 @@ +# @package _global_ + +predictor: + metrics_on_progress_bar: + l1000_vcap: [] + metrics_on_training_set: + l1000_vcap: [] + loss_fun: + l1000_vcap: + name: hybrid_ce_ipu + n_brackets: 3 + alpha: 0.5 + +metrics: + l1000_vcap: + - name: auroc + metric: auroc + num_classes: 3 + task: multiclass + target_to_int: True + target_nan_mask: -1000 + ignore_index: -1000 + multitask_handling: mean-per-label + threshold_kwargs: null + - name: avpr + metric: averageprecision + num_classes: 3 + task: multiclass + target_to_int: True + target_nan_mask: -1000 + ignore_index: -1000 + multitask_handling: mean-per-label + threshold_kwargs: null + +datamodule: + args: # Matches that in the test_multitask_datamodule.py case. 
+ task_specific_args: # To be replaced by a new class "DatasetParams" + l1000_vcap: + df: null + df_path: ../data/graphium/large-dataset/LINCS_L1000_VCAP_0-2_th2.csv.gz + # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/LINCS_L1000_VCAP_0-4.csv.gz + # or set path as the URL directly + smiles_col: "SMILES" + label_cols: geneID-* # geneID-* means all columns starting with "geneID-" + # sample_size: 2000 # use sample_size for test + task_level: graph + splits_path: ../data/graphium/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt` + # split_names: [train, val, test_seen] + epoch_sampling_fraction: 1.0 \ No newline at end of file diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/largemix.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/largemix.yaml new file mode 100644 index 000000000..921960cd1 --- /dev/null +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/largemix.yaml @@ -0,0 +1,155 @@ +# @package _global_ + +predictor: + metrics_on_progress_bar: + l1000_vcap: [] + l1000_mcf7: [] + pcba_1328: [] + pcqm4m_g25: [] + pcqm4m_n4: [] + metrics_on_training_set: + l1000_vcap: [] + l1000_mcf7: [] + pcba_1328: [] + pcqm4m_g25: [] + pcqm4m_n4: [] + loss_fun: + l1000_vcap: + name: hybrid_ce_ipu + n_brackets: 3 + alpha: 0.5 + l1000_mcf7: + name: hybrid_ce_ipu + n_brackets: 3 + alpha: ${predictor.loss_fun.l1000_vcap.alpha} + pcba_1328: bce_logits_ipu + pcqm4m_g25: mae_ipu + pcqm4m_n4: mae_ipu + +metrics: + l1000_vcap: &classif_metrics + - name: auroc + metric: auroc + num_classes: 3 + task: multiclass + target_to_int: True + target_nan_mask: -1000 + ignore_index: -1000 + multitask_handling: mean-per-label + threshold_kwargs: null + - name: avpr + metric: averageprecision + num_classes: 3 + task: multiclass + target_to_int: True + target_nan_mask: -1000 + ignore_index: -1000 + multitask_handling: mean-per-label + threshold_kwargs: null + l1000_mcf7: *classif_metrics + pcba_1328: + # use auroc and averageprecision (non_ipu version) so tha nans are handled correctly + - name: auroc + metric: auroc + task: binary + multitask_handling: mean-per-label + target_nan_mask: ignore + threshold_kwargs: null + - name: avpr + metric: averageprecision + task: binary + multitask_handling: mean-per-label + target_nan_mask: ignore + threshold_kwargs: null + pcqm4m_g25: &pcqm_metrics + - name: mae + metric: mae_ipu + target_nan_mask: null + multitask_handling: mean-per-label + threshold_kwargs: null + - name: pearsonr + metric: pearsonr_ipu + threshold_kwargs: null + target_nan_mask: null + multitask_handling: mean-per-label + - name: r2 + metric: r2_score_ipu + threshold_kwargs: null + target_nan_mask: null + multitask_handling: mean-per-label + pcqm4m_n4: *pcqm_metrics + +datamodule: + args: # Matches that in the test_multitask_datamodule.py case. 
+ task_specific_args: # To be replaced by a new class "DatasetParams" + l1000_vcap: + df: null + df_path: ../data/graphium/large-dataset/LINCS_L1000_VCAP_0-2_th2.csv.gz + # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/LINCS_L1000_VCAP_0-4.csv.gz + # or set path as the URL directly + smiles_col: "SMILES" + label_cols: geneID-* # geneID-* means all columns starting with "geneID-" + # sample_size: 2000 # use sample_size for test + task_level: graph + splits_path: ../data/graphium/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt` + # split_names: [train, val, test_seen] + epoch_sampling_fraction: 1.0 + + l1000_mcf7: + df: null + df_path: ../data/graphium/large-dataset/LINCS_L1000_MCF7_0-2_th2.csv.gz + # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/LINCS_L1000_MCF7_0-4.csv.gz + # or set path as the URL directly + smiles_col: "SMILES" + label_cols: geneID-* # geneID-* means all columns starting with "geneID-" + # sample_size: 2000 # use sample_size for test + task_level: graph + splits_path: ../data/graphium/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt` + # split_names: [train, val, test_seen] + epoch_sampling_fraction: 1.0 + + pcba_1328: + df: null + df_path: ../data/graphium/large-dataset/PCBA_1328_1564k.parquet + # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCBA_1328_1564k.parquet + # or set path as the URL directly + smiles_col: "SMILES" + label_cols: assayID-* # assayID-* means all columns starting with "assayID-" + # sample_size: 2000 # use sample_size for test + task_level: graph + splits_path: ../data/graphium/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt` + # split_names: [train, val, test_seen] + epoch_sampling_fraction: 1.0 + + pcqm4m_g25: + df: null + df_path: ../data/graphium/large-dataset/PCQM4M_G25_N4.parquet + # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet + # or set path as the URL directly + smiles_col: "ordered_smiles" + label_cols: graph_* # graph_* means all columns starting with "graph_" + # sample_size: 2000 # use sample_size for test + task_level: graph + splits_path: ../data/graphium/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt` + # split_names: [train, val, test_seen] + label_normalization: + normalize_val_test: True + method: "normal" + epoch_sampling_fraction: 1.0 + + pcqm4m_n4: + df: null + df_path: ../data/graphium/large-dataset/PCQM4M_G25_N4.parquet + # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet + # or set path as the URL directly + smiles_col: "ordered_smiles" + label_cols: node_* # node_* means all columns starting with "node_" + # sample_size: 2000 # use sample_size for test + task_level: node + splits_path: ../data/graphium/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget 
https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt` + # split_names: [train, val, test_seen] + seed: 42 + label_normalization: + normalize_val_test: True + method: "normal" + epoch_sampling_fraction: 1.0 \ No newline at end of file diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/pcba_1328.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcba_1328.yaml new file mode 100644 index 000000000..adc3321a0 --- /dev/null +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcba_1328.yaml @@ -0,0 +1,41 @@ +# @package _global_ + +predictor: + metrics_on_progress_bar: + pcba_1328: [] + metrics_on_training_set: + pcba_1328: [] + loss_fun: + pcba_1328: bce_logits_ipu + +metrics: + pcba_1328: + # use auroc and averageprecision (non_ipu version) so tha nans are handled correctly + - name: auroc + metric: auroc + task: binary + multitask_handling: mean-per-label + target_nan_mask: ignore + threshold_kwargs: null + - name: avpr + metric: averageprecision + task: binary + multitask_handling: mean-per-label + target_nan_mask: ignore + threshold_kwargs: null + +datamodule: + args: # Matches that in the test_multitask_datamodule.py case. + task_specific_args: # To be replaced by a new class "DatasetParams" + pcba_1328: + df: null + df_path: ../data/graphium/large-dataset/PCBA_1328_1564k.parquet + # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCBA_1328_1564k.parquet + # or set path as the URL directly + smiles_col: "SMILES" + label_cols: assayID-* # assayID-* means all columns starting with "assayID-" + # sample_size: 2000 # use sample_size for test + task_level: graph + splits_path: ../data/graphium/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt` + # split_names: [train, val, test_seen] + epoch_sampling_fraction: 1.0 \ No newline at end of file diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_g25.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_g25.yaml new file mode 100644 index 000000000..047701f6e --- /dev/null +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_g25.yaml @@ -0,0 +1,46 @@ +# @package _global_ + +predictor: + metrics_on_progress_bar: + pcqm4m_g25: [] + metrics_on_training_set: + pcqm4m_g25: [] + loss_fun: + pcqm4m_g25: mae_ipu + +metrics: + pcqm4m_g25: + - name: mae + metric: mae_ipu + target_nan_mask: null + multitask_handling: mean-per-label + threshold_kwargs: null + - name: pearsonr + metric: pearsonr_ipu + threshold_kwargs: null + target_nan_mask: null + multitask_handling: mean-per-label + - name: r2 + metric: r2_score_ipu + threshold_kwargs: null + target_nan_mask: null + multitask_handling: mean-per-label + +datamodule: + args: # Matches that in the test_multitask_datamodule.py case. 
+ task_specific_args: # To be replaced by a new class "DatasetParams" + pcqm4m_g25: + df: null + df_path: ../data/graphium/large-dataset/PCQM4M_G25_N4.parquet + # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet + # or set path as the URL directly + smiles_col: "ordered_smiles" + label_cols: graph_* # graph_* means all columns starting with "graph_" + # sample_size: 2000 # use sample_size for test + task_level: graph + splits_path: ../data/graphium/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt` + # split_names: [train, val, test_seen] + label_normalization: + normalize_val_test: True + method: "normal" + epoch_sampling_fraction: 1.0 \ No newline at end of file diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_n4.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_n4.yaml new file mode 100644 index 000000000..494843464 --- /dev/null +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_n4.yaml @@ -0,0 +1,45 @@ +# @package _global_ + +predictor: + metrics_on_progress_bar: + pcqm4m_n4: [] + metrics_on_training_set: + pcqm4m_n4: [] + loss_fun: + pcqm4m_n4: mae_ipu + +metrics: + pcqm4m_n4: + - name: mae + metric: mae_ipu + target_nan_mask: null + multitask_handling: mean-per-label + threshold_kwargs: null + - name: pearsonr + metric: pearsonr_ipu + threshold_kwargs: null + target_nan_mask: null + multitask_handling: mean-per-label + - name: r2 + metric: r2_score_ipu + threshold_kwargs: null + target_nan_mask: null + multitask_handling: mean-per-label + +datamodule: + pcqm4m_n4: + df: null + df_path: ../data/graphium/large-dataset/PCQM4M_G25_N4.parquet + # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet + # or set path as the URL directly + smiles_col: "ordered_smiles" + label_cols: node_* # node_* means all columns starting with "node_" + # sample_size: 2000 # use sample_size for test + task_level: node + splits_path: ../data/graphium/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt` + # split_names: [train, val, test_seen] + seed: 42 + label_normalization: + normalize_val_test: True + method: "normal" + epoch_sampling_fraction: 1.0 \ No newline at end of file diff --git a/expts/hydra-configs/tasks/pcba_1328.yaml b/expts/hydra-configs/tasks/pcba_1328.yaml new file mode 100644 index 000000000..61b5e7b29 --- /dev/null +++ b/expts/hydra-configs/tasks/pcba_1328.yaml @@ -0,0 +1,7 @@ +# NOTE: We cannot have a single config, since for fine-tuning we will +# only want to override the loss_metrics_datamodule, whereas for training we will +# want to override both. + +defaults: + - task_heads: pcba_1328 + - loss_metrics_datamodule: pcba_1328 \ No newline at end of file diff --git a/expts/hydra-configs/tasks/pcqm4m_g25.yaml b/expts/hydra-configs/tasks/pcqm4m_g25.yaml new file mode 100644 index 000000000..1d5b03469 --- /dev/null +++ b/expts/hydra-configs/tasks/pcqm4m_g25.yaml @@ -0,0 +1,7 @@ +# NOTE: We cannot have a single config, since for fine-tuning we will +# only want to override the loss_metrics_datamodule, whereas for training we will +# want to override both. 
+ +defaults: + - task_heads: pcqm4m_g25 + - loss_metrics_datamodule: pcqm4m_g25 \ No newline at end of file diff --git a/expts/hydra-configs/tasks/pcqm4m_n4.yaml b/expts/hydra-configs/tasks/pcqm4m_n4.yaml new file mode 100644 index 000000000..daa077ccc --- /dev/null +++ b/expts/hydra-configs/tasks/pcqm4m_n4.yaml @@ -0,0 +1,7 @@ +# NOTE: We cannot have a single config, since for fine-tuning we will +# only want to override the loss_metrics_datamodule, whereas for training we will +# want to override both. + +defaults: + - task_heads: pcqm4m_n4 + - loss_metrics_datamodule: pcqm4m_n4 \ No newline at end of file diff --git a/expts/hydra-configs/tasks/task_heads/l1000_mcf7.yaml b/expts/hydra-configs/tasks/task_heads/l1000_mcf7.yaml new file mode 100644 index 000000000..c449a03f1 --- /dev/null +++ b/expts/hydra-configs/tasks/task_heads/l1000_mcf7.yaml @@ -0,0 +1,15 @@ +# @package _global_ + +architecture: + task_heads: + l1000_mcf7: + task_level: graph + out_dim: 2934 + hidden_dims: 128 + depth: 2 + activation: none + last_activation: none + dropout: ${architecture.pre_nn.dropout} + normalization: ${architecture.pre_nn.normalization} + last_normalization: "none" + residual_type: none \ No newline at end of file diff --git a/expts/hydra-configs/tasks/task_heads/l1000_vcap.yaml b/expts/hydra-configs/tasks/task_heads/l1000_vcap.yaml new file mode 100644 index 000000000..a71e75709 --- /dev/null +++ b/expts/hydra-configs/tasks/task_heads/l1000_vcap.yaml @@ -0,0 +1,15 @@ +# @package _global_ + +architecture: + task_heads: + l1000_vcap: + task_level: graph + out_dim: 2934 + hidden_dims: 128 + depth: 2 + activation: none + last_activation: none + dropout: ${architecture.pre_nn.dropout} + normalization: ${architecture.pre_nn.normalization} + last_normalization: "none" + residual_type: none \ No newline at end of file diff --git a/expts/hydra-configs/tasks/task_heads/largemix.yaml b/expts/hydra-configs/tasks/task_heads/largemix.yaml new file mode 100644 index 000000000..e69c38d1d --- /dev/null +++ b/expts/hydra-configs/tasks/task_heads/largemix.yaml @@ -0,0 +1,59 @@ +# @package _global_ + +architecture: + task_heads: + l1000_vcap: + task_level: graph + out_dim: 2934 + hidden_dims: 128 + depth: 2 + activation: none + last_activation: none + dropout: ${architecture.pre_nn.dropout} + normalization: ${architecture.pre_nn.normalization} + last_normalization: "none" + residual_type: none + l1000_mcf7: + task_level: graph + out_dim: 2934 + hidden_dims: 128 + depth: 2 + activation: none + last_activation: none + dropout: ${architecture.pre_nn.dropout} + normalization: ${architecture.pre_nn.normalization} + last_normalization: "none" + residual_type: none + pcba_1328: + task_level: graph + out_dim: 1328 + hidden_dims: 64 + depth: 2 + activation: relu + last_activation: none + dropout: ${architecture.pre_nn.dropout} + normalization: ${architecture.pre_nn.normalization} + last_normalization: "none" + residual_type: none + pcqm4m_g25: + task_level: graph + out_dim: 25 + hidden_dims: 32 + depth: 2 + activation: relu + last_activation: none + dropout: ${architecture.pre_nn.dropout} + normalization: ${architecture.pre_nn.normalization} + last_normalization: "none" + residual_type: none + pcqm4m_n4: + task_level: node + out_dim: 4 + hidden_dims: 32 + depth: 2 + activation: relu + last_activation: none + dropout: ${architecture.pre_nn.dropout} + normalization: ${architecture.pre_nn.normalization} + last_normalization: "none" + residual_type: none \ No newline at end of file diff --git 
a/expts/hydra-configs/tasks/task_heads/pcba_1328.yaml b/expts/hydra-configs/tasks/task_heads/pcba_1328.yaml new file mode 100644 index 000000000..498c89e98 --- /dev/null +++ b/expts/hydra-configs/tasks/task_heads/pcba_1328.yaml @@ -0,0 +1,15 @@ +# @package _global_ + +architecture: + task_heads: + pcba_1328: + task_level: graph + out_dim: 1328 + hidden_dims: 64 + depth: 2 + activation: relu + last_activation: none + dropout: ${architecture.pre_nn.dropout} + normalization: ${architecture.pre_nn.normalization} + last_normalization: "none" + residual_type: none \ No newline at end of file diff --git a/expts/hydra-configs/tasks/task_heads/pcqm4m_g25.yaml b/expts/hydra-configs/tasks/task_heads/pcqm4m_g25.yaml new file mode 100644 index 000000000..813ab1997 --- /dev/null +++ b/expts/hydra-configs/tasks/task_heads/pcqm4m_g25.yaml @@ -0,0 +1,15 @@ +# @package _global_ + +architecture: + task_heads: + pcqm4m_g25: + task_level: graph + out_dim: 25 + hidden_dims: 32 + depth: 2 + activation: relu + last_activation: none + dropout: ${architecture.pre_nn.dropout} + normalization: ${architecture.pre_nn.normalization} + last_normalization: "none" + residual_type: none \ No newline at end of file diff --git a/expts/hydra-configs/tasks/task_heads/pcqm4m_n4.yaml b/expts/hydra-configs/tasks/task_heads/pcqm4m_n4.yaml new file mode 100644 index 000000000..dda781cb0 --- /dev/null +++ b/expts/hydra-configs/tasks/task_heads/pcqm4m_n4.yaml @@ -0,0 +1,15 @@ +# @package _global_ + +architecture: + task_heads: + pcqm4m_n4: + task_level: node + out_dim: 4 + hidden_dims: 32 + depth: 2 + activation: relu + last_activation: none + dropout: ${architecture.pre_nn.dropout} + normalization: ${architecture.pre_nn.normalization} + last_normalization: "none" + residual_type: none \ No newline at end of file diff --git a/expts/hydra-configs/training/accelerator/largemix_cpu.yaml b/expts/hydra-configs/training/accelerator/largemix_cpu.yaml new file mode 100644 index 000000000..6f5e0606a --- /dev/null +++ b/expts/hydra-configs/training/accelerator/largemix_cpu.yaml @@ -0,0 +1,19 @@ +# @package _global_ + +datamodule: + args: + batch_size_training: 200 + batch_size_inference: 200 + featurization_n_jobs: 20 + num_workers: 20 + +predictor: + metrics_every_n_train_steps: 1000 + torch_scheduler_kwargs: + max_num_epochs: ${constants.max_epochs} + +trainer: + trainer: + precision: 32 + accumulate_grad_batches: 2 + max_epochs: ${constants.max_epochs} \ No newline at end of file diff --git a/expts/hydra-configs/training/accelerator/largemix_gpu.yaml b/expts/hydra-configs/training/accelerator/largemix_gpu.yaml new file mode 100644 index 000000000..06f2b9d5e --- /dev/null +++ b/expts/hydra-configs/training/accelerator/largemix_gpu.yaml @@ -0,0 +1,22 @@ +# @package _global_ + +accelerator: + float32_matmul_precision: medium + +datamodule: + args: + batch_size_training: 960 + batch_size_inference: 960 + featurization_n_jobs: 6 + num_workers: 6 + +predictor: + metrics_every_n_train_steps: 1000 + torch_scheduler_kwargs: + max_num_epochs: ${constants.max_epochs} + +trainer: + trainer: + precision: 16-mixed + # accumulate_grad_batches: 2 + max_epochs: ${constants.max_epochs} \ No newline at end of file diff --git a/expts/hydra-configs/training/accelerator/largemix_ipu.yaml b/expts/hydra-configs/training/accelerator/largemix_ipu.yaml new file mode 100644 index 000000000..090600e98 --- /dev/null +++ b/expts/hydra-configs/training/accelerator/largemix_ipu.yaml @@ -0,0 +1,24 @@ +# @package _global_ + +datamodule: + args: + ipu_dataloader_training_opts: 
+ mode: async + max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54 + max_num_edges_per_graph: 100 + ipu_dataloader_inference_opts: + mode: async + max_num_nodes_per_graph: 35 # valid max nodes: 51, max_edges: 118 + max_num_edges_per_graph: 100 + # Data handling-related + batch_size_training: 30 + batch_size_inference: 30 + +predictor: + optim_kwargs: + loss_scaling: 1024 + +trainer: + trainer: + precision: 16-true + accumulate_grad_batches: 2 \ No newline at end of file diff --git a/expts/hydra-configs/training/largemix.yaml b/expts/hydra-configs/training/largemix.yaml new file mode 100644 index 000000000..7c1a67953 --- /dev/null +++ b/expts/hydra-configs/training/largemix.yaml @@ -0,0 +1,33 @@ +# @package _global_ + +predictor: + random_seed: ${constants.seed} + optim_kwargs: + lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs + # weight_decay: 1.e-7 + torch_scheduler_kwargs: + module_type: WarmUpLinearLR + max_num_epochs: &max_epochs 100 + warmup_epochs: 10 + verbose: False + scheduler_kwargs: + target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label + multitask_handling: flatten # flatten, mean-per-label + +trainer: + seed: ${constants.seed} + logger: + save_dir: logs/neurips2023-large/ + name: ${constants.name} + project: ${constants.name} + model_checkpoint: + dirpath: model_checkpoints/large-dataset/${now:%Y-%m-%d_%H-%M-%S}/ + filename: ${constants.name} + save_last: True # saving last model + # save_top_k: 1 # and best model + # monitor: loss/val # wrt validation loss + trainer: + precision: 16-mixed + max_epochs: ${predictor.torch_scheduler_kwargs.max_num_epochs} + min_epochs: 1 + check_val_every_n_epoch: 20 \ No newline at end of file diff --git a/expts/hydra-configs/training/model/largemix_gcn.yaml b/expts/hydra-configs/training/model/largemix_gcn.yaml new file mode 100644 index 000000000..04864ecc9 --- /dev/null +++ b/expts/hydra-configs/training/model/largemix_gcn.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +constants: + name: large_data_gcn + wandb: + name: ${constants.name} + project: neurips2023-expts + entity: multitask-gnn + save_dir: logs/${constants.name} + entity: multitask-gnn + seed: 42 + max_epochs: 200 + data_dir: expts/data/large-dataset + raise_train_error: true + +trainer: + model_checkpoint: + dirpath: model_checkpoints/large-dataset/gcn/${now:%Y-%m-%d_%H-%M-%S}/ \ No newline at end of file diff --git a/expts/hydra-configs/training/model/largemix_gin.yaml b/expts/hydra-configs/training/model/largemix_gin.yaml new file mode 100644 index 000000000..41c12a014 --- /dev/null +++ b/expts/hydra-configs/training/model/largemix_gin.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +constants: + name: large_data_gin + wandb: + name: ${constants.name} + project: neurips2023-expts + entity: multitask-gnn + save_dir: logs/${constants.name} + entity: multitask-gnn + seed: 42 + max_epochs: 200 + data_dir: expts/data/large-dataset + raise_train_error: true + +trainer: + model_checkpoint: + dirpath: model_checkpoints/large-dataset/gin/${now:%Y-%m-%d_%H-%M-%S}/ \ No newline at end of file diff --git a/expts/hydra-configs/training/model/largemix_gine.yaml b/expts/hydra-configs/training/model/largemix_gine.yaml new file mode 100644 index 000000000..99cdee3eb --- /dev/null +++ b/expts/hydra-configs/training/model/largemix_gine.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +constants: + name: large_data_gine + wandb: + name: ${constants.name} + project: neurips2023-expts + entity: multitask-gnn + save_dir: logs/${constants.name} 
+ entity: multitask-gnn + seed: 42 + max_epochs: 200 + data_dir: expts/data/large-dataset + raise_train_error: true + +trainer: + model_checkpoint: + dirpath: model_checkpoints/large-dataset/gine/${now:%Y-%m-%d_%H-%M-%S}/ \ No newline at end of file diff --git a/expts/hydra-configs/training/model/pcqm4m_gpspp.yaml b/expts/hydra-configs/training/model/pcqm4m_gpspp.yaml index 7fb1e1ee5..c156020b5 100644 --- a/expts/hydra-configs/training/model/pcqm4m_gpspp.yaml +++ b/expts/hydra-configs/training/model/pcqm4m_gpspp.yaml @@ -10,4 +10,4 @@ constants: trainer: model_checkpoint: - dirpath: models_checkpoints/PCMQ4Mv2/gpspp/ + dirpath: models_checkpoints/PCMQ4Mv2/gpspp/${now:%Y-%m-%d_%H-%M-%S}/ diff --git a/expts/hydra-configs/training/model/pcqm4m_mpnn.yaml b/expts/hydra-configs/training/model/pcqm4m_mpnn.yaml index ca643fe39..bab3896d1 100644 --- a/expts/hydra-configs/training/model/pcqm4m_mpnn.yaml +++ b/expts/hydra-configs/training/model/pcqm4m_mpnn.yaml @@ -10,4 +10,4 @@ constants: trainer: model_checkpoint: - dirpath: models_checkpoints/PCMQ4Mv2/mpnn/ + dirpath: models_checkpoints/PCMQ4Mv2/mpnn/${now:%Y-%m-%d_%H-%M-%S}/ diff --git a/expts/hydra-configs/training/model/toymix_gcn.yaml b/expts/hydra-configs/training/model/toymix_gcn.yaml index 3c7a13d05..4180515db 100644 --- a/expts/hydra-configs/training/model/toymix_gcn.yaml +++ b/expts/hydra-configs/training/model/toymix_gcn.yaml @@ -10,4 +10,4 @@ constants: trainer: model_checkpoint: - dirpath: models_checkpoints/neurips2023-small-gcn/ \ No newline at end of file + dirpath: models_checkpoints/small-dataset/gcn/${now:%Y-%m-%d_%H-%M-%S}/ \ No newline at end of file diff --git a/expts/hydra-configs/training/model/toymix_gin.yaml b/expts/hydra-configs/training/model/toymix_gin.yaml index 459694c9a..10007bd32 100644 --- a/expts/hydra-configs/training/model/toymix_gin.yaml +++ b/expts/hydra-configs/training/model/toymix_gin.yaml @@ -10,4 +10,4 @@ constants: trainer: model_checkpoint: - dirpath: models_checkpoints/neurips2023-small-gin/ \ No newline at end of file + dirpath: models_checkpoints/neurips2023-small-gin/${now:%Y-%m-%d_%H-%M-%S}/ \ No newline at end of file diff --git a/expts/hydra-configs/training/pcqm4m.yaml b/expts/hydra-configs/training/pcqm4m.yaml index 871a2a5f1..58860f807 100644 --- a/expts/hydra-configs/training/pcqm4m.yaml +++ b/expts/hydra-configs/training/pcqm4m.yaml @@ -28,7 +28,7 @@ trainer: # patience: 10 # mode: &mode min model_checkpoint: - dirpath: models_checkpoints/PCMQ4Mv2/ + dirpath: models_checkpoints/PCMQ4Mv2/${now:%Y-%m-%d_%H-%M-%S}/ filename: ${constants.name} #monitor: *monitor #mode: *mode diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py index 5edca8081..f150ba83d 100644 --- a/graphium/data/datamodule.py +++ b/graphium/data/datamodule.py @@ -1200,14 +1200,15 @@ def setup( labels_size = {} labels_dtype = {} if stage == "fit" or stage is None: - # if self.train_ds is None: - self.train_ds = self._make_multitask_dataset( - self.dataloading_from, "train", save_smiles_and_ids=save_smiles_and_ids - ) - # if self.val_ds is None: - self.val_ds = self._make_multitask_dataset( - self.dataloading_from, "val", save_smiles_and_ids=save_smiles_and_ids - ) + if self.train_ds is None: + self.train_ds = self._make_multitask_dataset( + self.dataloading_from, "train", save_smiles_and_ids=save_smiles_and_ids + ) + + if self.val_ds is None: + self.val_ds = self._make_multitask_dataset( + self.dataloading_from, "val", save_smiles_and_ids=save_smiles_and_ids + ) logger.info(self.train_ds) 
logger.info(self.val_ds) @@ -1219,10 +1220,10 @@ def setup( labels_dtype.update(self.val_ds.labels_dtype) if stage == "test" or stage is None: - # if self.test_ds is None: - self.test_ds = self._make_multitask_dataset( - self.dataloading_from, "test", save_smiles_and_ids=save_smiles_and_ids - ) + if self.test_ds is None: + self.test_ds = self._make_multitask_dataset( + self.dataloading_from, "test", save_smiles_and_ids=save_smiles_and_ids + ) logger.info(self.test_ds) @@ -1341,7 +1342,13 @@ def _save_data_to_files(self, save_smiles_and_ids: bool = False) -> None: self.save_featurized_data(temp_datasets[stage], self._path_to_load_from_file(stage)) temp_datasets[stage].save_metadata(self._path_to_load_from_file(stage)) # self.train_ds, self.val_ds, self.test_ds will be created during `setup()` - del temp_datasets + + if self.dataloading_from == "disk": + del temp_datasets + else: + self.train_ds = temp_datasets["train"] + self.val_ds = temp_datasets["val"] + self.test_ds = temp_datasets["test"] def get_folder_size(self, path): # check if the data items are actually saved into the folders diff --git a/graphium/data/dataset.py b/graphium/data/dataset.py index 03737964f..cc8b7b685 100644 --- a/graphium/data/dataset.py +++ b/graphium/data/dataset.py @@ -192,7 +192,7 @@ def __init__( self.features = None self.labels = None elif dataloading_from == "ram": - logger.info("Transferring data from DISK to RAM...") + logger.info(f"Transferring {about} from DISK to RAM...") self.transfer_from_disk_to_ram() else: diff --git a/graphium/trainer/predictor.py b/graphium/trainer/predictor.py index b56b93e8f..aca1dde52 100644 --- a/graphium/trainer/predictor.py +++ b/graphium/trainer/predictor.py @@ -543,18 +543,16 @@ def validation_step(self, batch: Dict[str, Tensor], to_cpu: bool = True) -> Dict def test_step(self, batch: Dict[str, Tensor], to_cpu: bool = True) -> Dict[str, Any]: return self._general_step(batch=batch, step_name="test", to_cpu=to_cpu) - def _general_epoch_end(self, outputs: Dict[str, Any], step_name: str) -> None: + def _general_epoch_end(self, outputs: Dict[str, Any], step_name: str, device: str) -> None: r"""Common code for training_epoch_end, validation_epoch_end and testing_epoch_end""" # Transform the list of dict of dict, into a dict of list of dict preds = {} targets = {} - device = device = outputs[0]["preds"][self.tasks[0]].device # should be better way to do this - # device = 0 for task in self.tasks: - preds[task] = torch.cat([out["preds"][task].to(device=device) for out in outputs], dim=0) - targets[task] = torch.cat([out["targets"][task].to(device=device) for out in outputs], dim=0) + preds[task] = torch.cat([out["preds"][task].to(device) for out in outputs], dim=0) + targets[task] = torch.cat([out["targets"][task].to(device) for out in outputs], dim=0) if ("weights" in outputs[0].keys()) and (outputs[0]["weights"] is not None): - weights = torch.cat([out["weights"] for out in outputs], dim=0) + weights = torch.cat([out["weights"].to(device) for out in outputs], dim=0) else: weights = None @@ -613,7 +611,9 @@ def on_validation_batch_end( return super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx) def on_validation_epoch_end(self) -> None: - metrics_logs = self._general_epoch_end(outputs=self.validation_step_outputs, step_name="val") + metrics_logs = self._general_epoch_end( + outputs=self.validation_step_outputs, step_name="val", device="cpu" + ) self.validation_step_outputs.clear() concatenated_metrics_logs = 
self.task_epoch_summary.concatenate_metrics_logs(metrics_logs) concatenated_metrics_logs["val/mean_time"] = torch.tensor(self.mean_val_time_tracker.mean_value) @@ -633,7 +633,7 @@ def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader self.test_step_outputs.append(outputs) def on_test_epoch_end(self) -> None: - metrics_logs = self._general_epoch_end(outputs=self.test_step_outputs, step_name="test") + metrics_logs = self._general_epoch_end(outputs=self.test_step_outputs, step_name="test", device="cpu") self.test_step_outputs.clear() concatenated_metrics_logs = self.task_epoch_summary.concatenate_metrics_logs(metrics_logs) diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py index 2bc89200c..b00c042e2 100644 --- a/tests/test_datamodule.py +++ b/tests/test_datamodule.py @@ -55,8 +55,7 @@ def test_ogb_datamodule(self): rm(TEMP_CACHE_DATA_PATH, recursive=True) # Reset the datamodule - ds._data_is_prepared = False - ds._data_is_cached = False + ds = GraphOGBDataModule(task_specific_args, **dm_args) ds.prepare_data(save_smiles_and_ids=True) @@ -299,8 +298,7 @@ def test_caching(self): rm(TEMP_CACHE_DATA_PATH, recursive=True) # Reset the datamodule - ds._data_is_prepared = False - ds._data_is_cached = False + ds = GraphOGBDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, **dm_args) ds.prepare_data(save_smiles_and_ids=True)
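
For reference, below is a minimal, self-contained sketch (with hypothetical class and method names, not Graphium's actual API) of the "build once, reuse if already built" pattern that the `setup()` changes in `graphium/data/datamodule.py` above implement: datasets are only (re)constructed when they are still `None`, so repeated `setup()` calls, or datasets already kept in RAM by `_save_data_to_files`, are not rebuilt.

```python
# Illustrative sketch only; names below are hypothetical stand-ins.

class CachingDataModule:
    def __init__(self, dataloading_from: str = "disk"):
        self.dataloading_from = dataloading_from
        self.train_ds = None
        self.val_ds = None
        self.test_ds = None

    def _make_dataset(self, split: str) -> list:
        # Stand-in for the expensive featurization / dataset construction.
        print(f"building {split} dataset")
        return [split]

    def setup(self, stage: str = None) -> None:
        # Only build a dataset if it has not been built (or preloaded) already.
        if stage in ("fit", None):
            if self.train_ds is None:
                self.train_ds = self._make_dataset("train")
            if self.val_ds is None:
                self.val_ds = self._make_dataset("val")
        if stage in ("test", None):
            if self.test_ds is None:
                self.test_ds = self._make_dataset("test")


dm = CachingDataModule()
dm.setup("fit")  # builds train/val
dm.setup("fit")  # no-op: the existing datasets are reused
```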