diff --git a/.gitignore b/.gitignore index 45b5efd..d5ea1ce 100644 --- a/.gitignore +++ b/.gitignore @@ -114,6 +114,13 @@ ENV/ # data dataset/ +datasets/ +graphormer/data/pyg_datasets/datasets/ +examples/property_prediction/dataset/ +examples/property_prediction/datasets/ +graphormer/data/pyg_datasets/test.py +graphormer/data/pyg_datasets/test.tar +fs-mol/ # reranking /examples/reranking/rerank_data @@ -128,12 +135,14 @@ exps # Weights and Biases logs wandb/ - *.pyc *.log ckpts -examples/dataset examples/property_prediction/ckpts -#examples/property_prediction/dataset + !examples/property_prediction/dataset/pcqm4m-v2/RELEASE_v1.txt !examples/property_prediction/dataset/pcqm4m_kddcup2021/RELEASE_v1.txt + +# for self-testing +examples/property_prediction/pcqv2_pyg.sh +examples/property_prediction/fs_mol.sh \ No newline at end of file diff --git a/examples/property_prediction/fs_mol.sh b/examples/property_prediction/fs_mol.sh new file mode 100644 index 0000000..881a49d --- /dev/null +++ b/examples/property_prediction/fs_mol.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +n_gpu=4 +epoch=300 +max_epoch=$((epoch + 1)) +batch_size=64 +tot_updates=$((500000*epoch/batch_size/n_gpu)) +warmup_updates=$((tot_updates/10)) + +CUDA_VISIBLE_DEVICES=0,1 fairseq-train \ +--user-dir ./graphormer \ +--num-workers 16 \ +--ddp-backend=legacy_ddp \ +--dataset-name fsmol \ +--dataset-source pyg \ +--task graph_prediction \ +--criterion binary_logloss \ +--arch graphormer_base \ +--num-classes 5135 \ +--attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 \ +--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.0 \ +--lr-scheduler polynomial_decay --power 1 --warmup-updates ${warmup_updates} --total-num-update ${tot_updates} \ +--lr 2e-4 --end-learning-rate 1e-9 \ +--batch-size ${batch_size} \ +--data-buffer-size 20 \ +--encoder-layers 12 \ +--encoder-embed-dim 768 \ +--encoder-ffn-embed-dim 768 \ +--encoder-attention-heads 32 \ +--max-epoch ${max_epoch} \ +--no-save \ +--sandwich-norm \ +--fp16 diff --git a/examples/property_prediction/hiv_pre.sh b/examples/property_prediction/hiv_pre.sh index 5cad3fb..a7ae1b8 100644 --- a/examples/property_prediction/hiv_pre.sh +++ b/examples/property_prediction/hiv_pre.sh @@ -17,7 +17,7 @@ CUDA_VISIBLE_DEVICES=3 fairseq-train \ --dataset-source ogb \ --task graph_prediction_with_flag \ --criterion binary_logloss_with_flag \ ---arch graphormer_base \ +--arch graphormer_graphpred_base \ --num-classes 1 \ --attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 \ --optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.0 \ diff --git a/examples/property_prediction/pcqv1.sh b/examples/property_prediction/pcqv1.sh index 4e0fe8b..1d0eb15 100644 --- a/examples/property_prediction/pcqv1.sh +++ b/examples/property_prediction/pcqv1.sh @@ -10,7 +10,7 @@ fairseq-train \ --dataset-source ogb \ --task graph_prediction \ --criterion l1_loss \ ---arch graphormer_base \ +--arch graphormer_graphpred_base \ --num-classes 1 \ --attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 \ --optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.0 \ diff --git a/examples/property_prediction/pcqv2.sh b/examples/property_prediction/pcqv2.sh index 4491635..0f34fef 100644 --- a/examples/property_prediction/pcqv2.sh +++ b/examples/property_prediction/pcqv2.sh @@ -12,7 +12,7 @@ fairseq-train \ --dataset-source ogb \ --task 
graph_prediction \ --criterion l1_loss \ ---arch graphormer_base \ +--arch graphormer_graphpred_base \ --num-classes 1 \ --attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 \ --optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.0 \ diff --git a/examples/property_prediction/pcqv2_pyg.sh b/examples/property_prediction/pcqv2_pyg.sh new file mode 100644 index 0000000..5e4920f --- /dev/null +++ b/examples/property_prediction/pcqv2_pyg.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +ulimit -c unlimited + +fairseq-train \ +--user-dir ../../graphormer \ +--num-workers 16 \ +--ddp-backend=legacy_ddp \ +--dataset-name pcqm4mv2_pyg \ +--dataset-source pyg \ +--task graph_prediction \ +--criterion l1_loss \ +--arch graphormer_graphpred_base \ +--num-classes 1 \ +--attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 \ +--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.0 \ +--lr-scheduler polynomial_decay --power 1 --warmup-updates 60000 --total-num-update 1000000 \ +--lr 2e-4 --end-learning-rate 1e-9 \ +--batch-size 256 \ +--fp16 \ +--data-buffer-size 20 \ +--no-save +#--save-dir ./ckpts \ No newline at end of file diff --git a/examples/property_prediction/zinc.sh b/examples/property_prediction/zinc.sh index c896882..5c63392 100644 --- a/examples/property_prediction/zinc.sh +++ b/examples/property_prediction/zinc.sh @@ -10,7 +10,7 @@ CUDA_VISIBLE_DEVICES=0 fairseq-train \ --dataset-source pyg \ --task graph_prediction \ --criterion l1_loss \ ---arch graphormer_slim \ +--arch graphormer_graphpred_slim \ --num-classes 1 \ --attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 \ --optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.01 \ diff --git a/graphormer/criterions/l1_loss.py b/graphormer/criterions/l1_loss.py index e5e3665..8f83c43 100644 --- a/graphormer/criterions/l1_loss.py +++ b/graphormer/criterions/l1_loss.py @@ -29,7 +29,7 @@ def forward(self, model, sample, reduce=True): natoms = sample["net_input"]["batched_data"]["x"].shape[1] logits = model(**sample["net_input"]) - logits = logits[:, 0, :] + #logits = logits[:, 0, :] # B x C targets = model.get_targets(sample, [logits]) loss = nn.L1Loss(reduction="sum")(logits, targets[: logits.size(0)]) diff --git a/graphormer/data/pyg_datasets/fsmol.py b/graphormer/data/pyg_datasets/fsmol.py new file mode 100644 index 0000000..1713ea5 --- /dev/null +++ b/graphormer/data/pyg_datasets/fsmol.py @@ -0,0 +1,127 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import os +import os.path as osp +import shutil +from ogb.utils import smiles2graph +from ogb.utils.torch_util import replace_numpy_with_torchtensor +from ogb.utils.url import decide_download, download_url, extract_zip +import pandas as pd +import numpy as np +from tqdm import tqdm +import torch + +from torch_geometric.data import InMemoryDataset, Data + +import tarfile +import jsonlines +import gzip + + +# few-shot not implemented yet +class FSmolPYG(InMemoryDataset): + def __init__( + self, + root, + split, + seed: int = 0, + transform=None, + pre_transform=None + ) -> None: + assert split in ['train', 'valid', 'test'] + self.split = split + self.original_root = root + self.url = 'https://figshare.com/ndownloader/files/31345321' + self.task_to_head = {} + self.num_heads_total = 0 + super().__init__(root, transform, pre_transform) + + # calculate the total number of heads & construct the task_to_head dict + # path: datasets/fsmol/raw/train + path = osp.join(self.raw_dir(), split) + self.unzip_gz_and_calc_head(path) + self.data, self.slices = torch.load(self.processed_file_names) + + def unzip_gz_and_calc_head(self, path): + if os.path.exists(path): + head_num = 0 + dirs = os.listdir(path) + for dir in dirs: + if '.gz' in dir: + filename = dir.replace(".gz","") + assert filename not in self.task_to_head, f"Duplicated task {filename} in split {self.split}!" + self.task_to_head['filename'] = head_num + head_num += 1 + gzip_file = gzip.GzipFile(path + dir) + with open(path + filename,'wb+') as f: + f.write(gzip_file.read()) + for dir in dirs: # delete .gz files + if '.gz' in dir: + os.unlink(dir) + self.num_heads_total = head_num + else: + raise Exception("The file to unzip does not exist!") + + @property + def raw_dir(self): # datasets/fsmol/raw/train + return f"{self.root}/fsmol/raw" + + @property + def processed_dir(self): + return f"{self.root}/fsmol/processed" + + @property + def raw_file_names(self): + return 'fsmol.tar' + + @property + def processed_file_names(self): + return f'{self.split}.pt' + + def download(self): + # Download fsmol.tar to `self.raw_dir` & unzip the file. + # datasets/raw/fsmol/train + path = download_url(self.url, self.original_root) + tar = tarfile.open(path) + tar.extractall() + tar.close() + # os.unlink(path) # keep the tar file + + def process(self): + # Read data into huge `Data` list. 
+ path = osp.join(self.raw_dir(), self.split) + data_list = [] + dirs = os.listdir(path) + for dir in dirs: + with open(dir, "r+", encoding="utf8") as f: + filename = dir.replace(".gz","") + head = self.task_to_head[filename] + for item in jsonlines.Reader(f): + data = Data() + data.head = head + data.smiles = item["SMILES"] + data.y = -1 if item["Property"] == 0.0 else head + data_list.append(data) + + print(f"Converting SMILES strings to graphs in split '{self.split}':") + for i, data in enumerate(tqdm(data_list)): + graph = self.smiles2graph(data.smiles) + data.x = torch.from_numpy(graph['node_feat']).to(torch.int64) + data.edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64) + data.edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64) + del data.smiles + + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) + + +if __name__ == '__main__': + dataset = FSmolPYG() + print(dataset) + print(dataset.data.edge_index) + print(dataset.data.edge_index.shape) + print(dataset.data.x.shape) + print(dataset[100]) + print(dataset[100].y) + print(dataset.get_idx_split()) \ No newline at end of file diff --git a/graphormer/data/pyg_datasets/pcqv2_pyg.py b/graphormer/data/pyg_datasets/pcqv2_pyg.py new file mode 100644 index 0000000..df2f481 --- /dev/null +++ b/graphormer/data/pyg_datasets/pcqv2_pyg.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +import os.path as osp +import shutil +from ogb.utils import smiles2graph +from ogb.utils.torch_util import replace_numpy_with_torchtensor +from ogb.utils.url import decide_download, download_url, extract_zip +import pandas as pd +import numpy as np +from tqdm import tqdm +import torch + +from torch_geometric.data import InMemoryDataset, Data + +class PCQv2PYG(InMemoryDataset): + def __init__(self, root='datasets', smiles2graph = smiles2graph, transform=None, pre_transform=None): + self.folder = osp.join(root, 'pcqm4m-v2') + self.original_root = root + self.smiles2graph = smiles2graph + self.url = 'https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m-v2.zip' + super().__init__(self.folder, transform, pre_transform) + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def raw_file_names(self): + return 'data.csv.gz' + + @property + def processed_file_names(self): + return 'geometric_data_processed.pt' + + def download(self): + # Download to `self.raw_dir`. + path = download_url(self.url, self.original_root) + extract_zip(path, self.original_root) + os.unlink(path) + + def process(self): + # Read data into huge `Data` list. 
+ df = pd.read_csv(osp.join(self.raw_dir, 'data.csv.gz')) + smiles_list = df['smiles'][:20000] + homolumogap_list = df['homolumogap'][:20000] + data_list = [] + + print("Converting SMILES strings to graphs...") + for i in tqdm(range(len(smiles_list))): + data = Data() + graph = self.smiles2graph(smiles_list[i]) + homolumogap = homolumogap_list[i] + data.x = torch.from_numpy(graph['node_feat']).to(torch.int64) + data.edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64) + data.edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64) + data.y = torch.Tensor([homolumogap]) + data_list.append(data) + + # double check NaN values + # split_dict = self.get_idx_split() + # assert(all([not torch.isnan(data_list[i].y)[0] for i in split_dict['train']])) + # assert(all([not torch.isnan(data_list[i].y)[0] for i in split_dict['valid']])) + # assert(all([torch.isnan(data_list[i].y)[0] for i in split_dict['test-dev']])) + # assert(all([torch.isnan(data_list[i].y)[0] for i in split_dict['test-challenge']])) + + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) + + def get_idx_split(self): + # split_dict = replace_numpy_with_torchtensor(torch.load(osp.join(self.root, 'split_dict.pt'))) + # return split_dict + split_dict = {'train': None, 'valid': None, 'test-dev': None} + split_dict['train'] = torch.from_numpy(np.arange(0, 16000)).to(torch.int64) + split_dict['valid'] = torch.from_numpy(np.arange(16000, 18000)).to(torch.int64) + split_dict['test-dev'] = torch.from_numpy(np.arange(18000, 20000)).to(torch.int64) + return split_dict + +if __name__ == '__main__': + dataset = PCQv2PYG() + print(dataset) + print(dataset.data.edge_index) + print(dataset.data.edge_index.shape) + print(dataset.data.x.shape) + print(dataset[100]) + print(dataset[100].y) + print(dataset.get_idx_split()) \ No newline at end of file diff --git a/graphormer/data/pyg_datasets/pyg_dataset_lookup_table.py b/graphormer/data/pyg_datasets/pyg_dataset_lookup_table.py index b01f3aa..587798c 100644 --- a/graphormer/data/pyg_datasets/pyg_dataset_lookup_table.py +++ b/graphormer/data/pyg_datasets/pyg_dataset_lookup_table.py @@ -5,8 +5,41 @@ from torch_geometric.datasets import * from torch_geometric.data import Dataset from .pyg_dataset import GraphormerPYGDataset +from .pcqv2_pyg import PCQv2PYG +from .fsmol import FSmolPYG + import torch.distributed as dist +from vpack import breakpt + + +class MyFSmolPYG(FSmolPYG): + def download(self): + if not dist.is_initialized() or dist.get_rank() == 0: + super(MyFSmolPYG, self).download() + if dist.is_initialized(): + dist.barrier() + + def process(self): + if not dist.is_initialized() or dist.get_rank() == 0: + super(MyFSmolPYG, self).process() + if dist.is_initialized(): + dist.barrier() + + +class MyPCQv2PYG(PCQv2PYG): + def download(self): + if not dist.is_initialized() or dist.get_rank() == 0: + super(MyPCQv2PYG, self).download() + if dist.is_initialized(): + dist.barrier() + + def process(self): + if not dist.is_initialized() or dist.get_rank() == 0: + super(MyPCQv2PYG, self).process() + if dist.is_initialized(): + dist.barrier() + class MyQM7b(QM7b): def download(self): @@ -98,6 +131,18 @@ def GetPYGDataset(dataset_spec: str, seed: int) -> Optional[Dataset]: if name == "name": nm = value inner_dataset = MyMoleculeNet(root=root, name=nm) + elif name == "pcqm4mv2_pyg": + root = "datasets" + inner_dataset = MyPCQv2PYG(root=root) + idx_split = inner_dataset.get_idx_split() + train_idx = idx_split["train"] + valid_idx = idx_split["valid"] + 
test_idx = idx_split["test-dev"] + elif name == "fsmol": + root = "datasets" + train_set = MyFSmolPYG(root=root, split="train") + valid_set = MyFSmolPYG(root=root, split="val") + test_set = MyFSmolPYG(root=root, split="test") else: raise ValueError(f"Unknown dataset name {name} for pyg source.") if train_set is not None: @@ -111,6 +156,14 @@ def GetPYGDataset(dataset_spec: str, seed: int) -> Optional[Dataset]: valid_set, test_set, ) + elif train_idx is not None: + return GraphormerPYGDataset( + inner_dataset, + seed, + train_idx, + valid_idx, + test_idx, + ) else: return ( None diff --git a/graphormer/models/__init__.py b/graphormer/models/__init__.py index cdb3b3f..b7c5258 100644 --- a/graphormer/models/__init__.py +++ b/graphormer/models/__init__.py @@ -1,4 +1,3 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .graphormer import GraphormerModel diff --git a/graphormer/models/graphormer.py b/graphormer/models/graphormer.py deleted file mode 100644 index f952c69..0000000 --- a/graphormer/models/graphormer.py +++ /dev/null @@ -1,354 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import logging - -import torch -import torch.nn as nn -import torch.nn.functional as F -from fairseq import utils -from fairseq.models import ( - FairseqEncoder, - FairseqEncoderModel, - register_model, - register_model_architecture, -) -from fairseq.modules import ( - LayerNorm, -) -from fairseq.utils import safe_hasattr - -from ..modules import init_graphormer_params, GraphormerGraphEncoder - -logger = logging.getLogger(__name__) - -from ..pretrain import load_pretrained_model - - -@register_model("graphormer") -class GraphormerModel(FairseqEncoderModel): - def __init__(self, args, encoder): - super().__init__(encoder) - self.args = args - - if getattr(args, "apply_graphormer_init", False): - self.apply(init_graphormer_params) - self.encoder_embed_dim = args.encoder_embed_dim - if args.pretrained_model_name != "none": - self.load_state_dict(load_pretrained_model(args.pretrained_model_name)) - if not args.load_pretrained_model_output_layer: - self.encoder.reset_output_layer_parameters() - - @staticmethod - def add_args(parser): - """Add model-specific arguments to the parser.""" - # Arguments related to dropout - parser.add_argument( - "--dropout", type=float, metavar="D", help="dropout probability" - ) - parser.add_argument( - "--attention-dropout", - type=float, - metavar="D", - help="dropout probability for" " attention weights", - ) - parser.add_argument( - "--act-dropout", - type=float, - metavar="D", - help="dropout probability after" " activation in FFN", - ) - - # Arguments related to hidden states and self-attention - parser.add_argument( - "--encoder-ffn-embed-dim", - type=int, - metavar="N", - help="encoder embedding dimension for FFN", - ) - parser.add_argument( - "--encoder-layers", type=int, metavar="N", help="num encoder layers" - ) - parser.add_argument( - "--encoder-attention-heads", - type=int, - metavar="N", - help="num encoder attention heads", - ) - - # Arguments related to input and output embeddings - parser.add_argument( - "--encoder-embed-dim", - type=int, - metavar="N", - help="encoder embedding dimension", - ) - parser.add_argument( - "--share-encoder-input-output-embed", - action="store_true", - help="share encoder input" " and output 
embeddings", - ) - parser.add_argument( - "--encoder-learned-pos", - action="store_true", - help="use learned positional embeddings in the encoder", - ) - parser.add_argument( - "--no-token-positional-embeddings", - action="store_true", - help="if set, disables positional embeddings" " (outside self attention)", - ) - parser.add_argument( - "--max-positions", type=int, help="number of positional embeddings to learn" - ) - - # Arguments related to parameter initialization - parser.add_argument( - "--apply-graphormer-init", - action="store_true", - help="use custom param initialization for Graphormer", - ) - - # misc params - parser.add_argument( - "--activation-fn", - choices=utils.get_available_activation_fns(), - help="activation function to use", - ) - parser.add_argument( - "--encoder-normalize-before", - action="store_true", - help="apply layernorm before each encoder block", - ) - parser.add_argument( - "--pre-layernorm", - action="store_true", - help="apply layernorm before self-attention and ffn. Without this, post layernorm will used", - ) - - def max_nodes(self): - return self.encoder.max_nodes - - @classmethod - def build_model(cls, args, task): - """Build a new model instance.""" - # make sure all arguments are present in older models - base_architecture(args) - - if not safe_hasattr(args, "max_nodes"): - args.max_nodes = args.tokens_per_sample - - logger.info(args) - - encoder = GraphormerEncoder(args) - return cls(args, encoder) - - def forward(self, batched_data, **kwargs): - return self.encoder(batched_data, **kwargs) - - -class GraphormerEncoder(FairseqEncoder): - def __init__(self, args): - super().__init__(dictionary=None) - self.max_nodes = args.max_nodes - - self.graph_encoder = GraphormerGraphEncoder( - # < for graphormer - num_atoms=args.num_atoms, - num_in_degree=args.num_in_degree, - num_out_degree=args.num_out_degree, - num_edges=args.num_edges, - num_spatial=args.num_spatial, - num_edge_dis=args.num_edge_dis, - edge_type=args.edge_type, - multi_hop_max_dist=args.multi_hop_max_dist, - # > - num_encoder_layers=args.encoder_layers, - embedding_dim=args.encoder_embed_dim, - ffn_embedding_dim=args.encoder_ffn_embed_dim, - num_attention_heads=args.encoder_attention_heads, - dropout=args.dropout, - attention_dropout=args.attention_dropout, - activation_dropout=args.act_dropout, - encoder_normalize_before=args.encoder_normalize_before, - pre_layernorm=args.pre_layernorm, - apply_graphormer_init=args.apply_graphormer_init, - activation_fn=args.activation_fn, - ) - - self.share_input_output_embed = args.share_encoder_input_output_embed - self.embed_out = None - self.lm_output_learned_bias = None - - # Remove head is set to true during fine-tuning - self.load_softmax = not getattr(args, "remove_head", False) - - self.masked_lm_pooler = nn.Linear( - args.encoder_embed_dim, args.encoder_embed_dim - ) - - self.lm_head_transform_weight = nn.Linear( - args.encoder_embed_dim, args.encoder_embed_dim - ) - self.activation_fn = utils.get_activation_fn(args.activation_fn) - self.layer_norm = LayerNorm(args.encoder_embed_dim) - - self.lm_output_learned_bias = None - if self.load_softmax: - self.lm_output_learned_bias = nn.Parameter(torch.zeros(1)) - - if not self.share_input_output_embed: - self.embed_out = nn.Linear( - args.encoder_embed_dim, args.num_classes, bias=False - ) - else: - raise NotImplementedError - - def reset_output_layer_parameters(self): - self.lm_output_learned_bias = nn.Parameter(torch.zeros(1)) - if self.embed_out is not None: - self.embed_out.reset_parameters() - 
- def forward(self, batched_data, perturb=None, masked_tokens=None, **unused): - inner_states, graph_rep = self.graph_encoder( - batched_data, - perturb=perturb, - ) - - x = inner_states[-1].transpose(0, 1) - - # project masked tokens only - if masked_tokens is not None: - raise NotImplementedError - - x = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(x))) - - # project back to size of vocabulary - if self.share_input_output_embed and hasattr( - self.graph_encoder.embed_tokens, "weight" - ): - x = F.linear(x, self.graph_encoder.embed_tokens.weight) - elif self.embed_out is not None: - x = self.embed_out(x) - if self.lm_output_learned_bias is not None: - x = x + self.lm_output_learned_bias - - return x - - def max_nodes(self): - """Maximum output length supported by the encoder.""" - return self.max_nodes - - def upgrade_state_dict_named(self, state_dict, name): - if not self.load_softmax: - for k in list(state_dict.keys()): - if "embed_out.weight" in k or "lm_output_learned_bias" in k: - del state_dict[k] - return state_dict - - -@register_model_architecture("graphormer", "graphormer") -def base_architecture(args): - args.dropout = getattr(args, "dropout", 0.1) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.act_dropout = getattr(args, "act_dropout", 0.0) - - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) - args.encoder_layers = getattr(args, "encoder_layers", 6) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) - - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) - args.share_encoder_input_output_embed = getattr( - args, "share_encoder_input_output_embed", False - ) - args.no_token_positional_embeddings = getattr( - args, "no_token_positional_embeddings", False - ) - - args.apply_graphormer_init = getattr(args, "apply_graphormer_init", False) - - args.activation_fn = getattr(args, "activation_fn", "gelu") - args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) - - -@register_model_architecture("graphormer", "graphormer_base") -def graphormer_base_architecture(args): - if args.pretrained_model_name == "pcqm4mv1_graphormer_base" or \ - args.pretrained_model_name == "pcqm4mv2_graphormer_base" or \ - args.pretrained_model_name == "pcqm4mv1_graphormer_base_for_molhiv": - args.encoder_layers = 12 - args.encoder_attention_heads = 32 - args.encoder_ffn_embed_dim = 768 - args.encoder_embed_dim = 768 - args.dropout = getattr(args, "dropout", 0.0) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.act_dropout = getattr(args, "act_dropout", 0.1) - else: - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) - args.encoder_layers = getattr(args, "encoder_layers", 12) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 32) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768) - args.dropout = getattr(args, "dropout", 0.0) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.act_dropout = getattr(args, "act_dropout", 0.1) - - args.activation_fn = getattr(args, "activation_fn", "gelu") - args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) - args.apply_graphormer_init = getattr(args, "apply_graphormer_init", True) - args.share_encoder_input_output_embed = getattr( - args, "share_encoder_input_output_embed", False - ) - args.no_token_positional_embeddings = getattr( - args, "no_token_positional_embeddings", False - ) - args.pre_layernorm = 
getattr(args, "pre_layernorm", False) - base_architecture(args) - - -@register_model_architecture("graphormer", "graphormer_slim") -def graphormer_slim_architecture(args): - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 80) - - args.encoder_layers = getattr(args, "encoder_layers", 12) - - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 80) - - args.activation_fn = getattr(args, "activation_fn", "gelu") - args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) - args.apply_graphormer_init = getattr(args, "apply_graphormer_init", True) - args.share_encoder_input_output_embed = getattr( - args, "share_encoder_input_output_embed", False - ) - args.no_token_positional_embeddings = getattr( - args, "no_token_positional_embeddings", False - ) - args.pre_layernorm = getattr(args, "pre_layernorm", False) - base_architecture(args) - - -@register_model_architecture("graphormer", "graphormer_large") -def graphormer_large_architecture(args): - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) - - args.encoder_layers = getattr(args, "encoder_layers", 24) - - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 32) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024) - - args.activation_fn = getattr(args, "activation_fn", "gelu") - args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) - args.apply_graphormer_init = getattr(args, "apply_graphormer_init", True) - args.share_encoder_input_output_embed = getattr( - args, "share_encoder_input_output_embed", False - ) - args.no_token_positional_embeddings = getattr( - args, "no_token_positional_embeddings", False - ) - args.pre_layernorm = getattr(args, "pre_layernorm", False) - base_architecture(args) diff --git a/graphormer/models/graphormer_custom_model.py b/graphormer/models/graphormer_custom_model.py new file mode 100644 index 0000000..75eb51b --- /dev/null +++ b/graphormer/models/graphormer_custom_model.py @@ -0,0 +1,165 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +A demonstrative file to show how to define a customized graphormer model. +You can define extra node layers/bias layers or define a predict layer group. +""" + +import logging +import contextlib + +logger = logging.getLogger(__name__) + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.models import ( + FairseqEncoderModel, + register_model, + register_model_architecture, +) +from fairseq.utils import safe_hasattr +from ..modules import init_graphormer_params +from ..modules import PredictLayerGroup +from ..modules import GraphormerGraphEncoder as GraphormerGraphEncoderBase +from .graphormer_encoder import GraphormerEncoder as GraphormerEncoderBase + +from ..utils import ( + graphormer_default_add_args, + guess_first_load, + upgrade_state_dict_named_from_pretrained, +) + + +class GraphormerGraphEncoder(GraphormerGraphEncoderBase): + """ + Define extra node layers or bias layers here if needed. 
+ """ + + def init_extra_node_layers(self, args): + super().init_extra_node_layers(args) + # Your code here + pass + + def init_extra_bias_layers(self, args): + super().init_extra_bias_layers(args) + # Your code here + pass + + def forward_extra_node_layers(self, batched_data, x): + x = super().forward_extra_node_layers(batched_data, x) + # Your code here + return x + + def forward_extra_bias_layers(self, batched_data, attn_bias): + bias = super().forward_extra_bias_layers(batched_data, attn_bias) + # Your code here + return bias + + +class GraphormerEncoder(GraphormerEncoderBase): + def build_graph_encoder(self, args): + return GraphormerGraphEncoder(args) + + +@register_model("graphormer_custom") +class GraphormerCustomModel(FairseqEncoderModel): + """ + Register your customized model architecture here. + """ + + def __init__(self, args, encoder): + super().__init__(encoder) + self.args = args + if getattr(args, "apply_graphormer_init", False): + self.apply(init_graphormer_params) + + print(f"{self.__class__.__name__}: {self}") + + @staticmethod + def add_args(parser): + graphormer_default_add_args(parser) + + def max_nodes(self): + return self.encoder.max_nodes + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure all arguments are present in older models + base_architecture(args) + if not safe_hasattr(args, "max_nodes"): + args.max_nodes = args.tokens_per_sample + + encoder = GraphormerEncoder(args) + return cls(args, encoder) + + # Define customized prediction layers. + def register_predictor(self, out_dims): + self.predictor = PredictLayerGroup( + in_dim=self.args.encoder_embed_dim, + out_dims=out_dims, + activation=utils.get_activation_fn(self.args.activation_fn), + n_layers=2, + ) + + def forward(self, batched_data, **kwargs): + encoder_out = self.encoder(batched_data, **kwargs) + x_cls = encoder_out["encoder_out"][0][0, :, :] # B x d + x = self.predictor(x_cls) + return x + + def upgrade_state_dict_named(self, state_dict, name): + named_parameters = {k: v for k, v in self.named_parameters()} + first_load = guess_first_load(named_parameters, state_dict, name) + if first_load: + msg = upgrade_state_dict_named_from_pretrained( + named_parameters, state_dict, name + ) + logger.warning(f"upgrade_state_dict_named_from_pretrained: {msg}") + + # fill missing keys + for k in named_parameters: + if k not in state_dict: + state_dict[k] = named_parameters[k] + logger.warning( + f"Warning: {k} is missing from the checkpoint, copying from model" + ) + + # remove ununsed keys + for k in list(state_dict.keys()): + if k not in named_parameters: + del state_dict[k] + logger.warning( + f"Warning: {k} is not used in the model, removing from the checkpoint" + ) + + return state_dict + + +@register_model_architecture("graphormer_custom", "graphormer_custom_base") +def base_architecture(args): + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.act_dropout = getattr(args, "act_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 32) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768) + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + 
args.apply_graphormer_init = getattr(args, "apply_graphormer_init", True) + args.share_encoder_input_output_embed = getattr( + args, "share_encoder_input_output_embed", False + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) diff --git a/graphormer/models/graphormer_encoder.py b/graphormer/models/graphormer_encoder.py new file mode 100644 index 0000000..f08091e --- /dev/null +++ b/graphormer/models/graphormer_encoder.py @@ -0,0 +1,133 @@ +from builtins import hasattr +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.models import FairseqEncoder + +from fairseq.modules import ( + LayerNorm, +) +from ..modules import GraphormerGraphEncoder + +logger = logging.getLogger(__name__) + + +class GraphormerEncoder(FairseqEncoder): + def __init__(self, args): + super().__init__(dictionary=None) + self.max_nodes = args.max_nodes + + self.graph_encoder = self.build_graph_encoder(args) + self.lm_output_learned_bias = nn.Parameter(torch.zeros(1)) + self.lm_head_transform_weight = nn.Linear( + args.encoder_embed_dim, args.encoder_embed_dim + ) + self.activation_fn = utils.get_activation_fn(args.activation_fn) + self.layer_norm = LayerNorm(args.encoder_embed_dim) + + def build_graph_encoder(self, args): + return GraphormerGraphEncoder(args) + + def make_padding_mask(self, batched_data): + encoder_padding_mask = (batched_data["x"][:, :, 0]).eq(0) # B x T + # prepend 1 for CLS token + B_zeros = torch.zeros( + (encoder_padding_mask.size(0), 1), + dtype=torch.bool, + device=encoder_padding_mask.device, + ) + encoder_padding_mask = torch.cat( + [B_zeros, encoder_padding_mask], dim=1 + ).contiguous() + return encoder_padding_mask + + def forward(self, batched_data, perturb=None, masked_tokens=None, **unused): + inner_states, graph_rep = self.graph_encoder( + batched_data, + perturb=perturb, + ) + + x = inner_states[-1].transpose(0, 1) + x = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(x))) + x = x + self.lm_output_learned_bias + + encoder_padding_mask = self.make_padding_mask(batched_data) + + src_lengths = ( + (~encoder_padding_mask) + .sum(dim=1, dtype=torch.int32) + .reshape(-1, 1) + .contiguous() + ) + + return { + "encoder_out": [x.transpose(0, 1)], # T x B x C + "encoder_padding_mask": [encoder_padding_mask], # B x T + "encoder_embedding": [x], # B x T x C + "encoder_states": inner_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [src_lengths], + } + + def max_nodes(self): + """Maximum output length supported by the encoder.""" + return self.max_nodes + + def max_positions(self): + return self.max_nodes + + @torch.jit.export + def reorder_encoder_out(self, encoder_out, new_order): + """ + Reorder encoder output according to *new_order*. 
+ + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + if len(encoder_out["encoder_out"]) == 0: + new_encoder_out = [] + else: + new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] + if len(encoder_out["encoder_padding_mask"]) == 0: + new_encoder_padding_mask = [] + else: + new_encoder_padding_mask = [ + encoder_out["encoder_padding_mask"][0].index_select(0, new_order) + ] + if len(encoder_out["encoder_embedding"]) == 0: + new_encoder_embedding = [] + else: + new_encoder_embedding = [ + encoder_out["encoder_embedding"][0].index_select(0, new_order) + ] + + if len(encoder_out["src_tokens"]) == 0: + src_tokens = [] + else: + src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)] + + if len(encoder_out["src_lengths"]) == 0: + src_lengths = [] + else: + src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)] + + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_embedding": new_encoder_embedding, # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": src_tokens, # B x T + "src_lengths": src_lengths, # B x 1 + } diff --git a/graphormer/models/graphormer_graphpred.py b/graphormer/models/graphormer_graphpred.py new file mode 100644 index 0000000..ea4f463 --- /dev/null +++ b/graphormer/models/graphormer_graphpred.py @@ -0,0 +1,211 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Graph property prediction model (classification/regression). 
+""" + +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderModel, + register_model, + register_model_architecture, +) +from fairseq.modules import ( + LayerNorm, +) +from fairseq.utils import safe_hasattr + +from .graphormer_encoder import GraphormerEncoder +from ..modules import init_graphormer_params + +logger = logging.getLogger(__name__) + +from ..pretrain import load_pretrained_model +from ..utils import ( + graphormer_default_add_args, + guess_first_load, + upgrade_state_dict_named_from_pretrained, +) + +@register_model("graphormer_graphpred") +class GraphormerModel(FairseqEncoderModel): + def __init__(self, args, encoder): + super().__init__(encoder) + self.args = args + if getattr(args, "apply_graphormer_init", False): + self.apply(init_graphormer_params) + self.encoder_embed_dim = args.encoder_embed_dim + if args.pretrained_model_name != "none": + self.load_state_dict(load_pretrained_model(args.pretrained_model_name)) + if not args.load_pretrained_model_output_layer: + self.encoder.reset_output_layer_parameters() + self.output_layer = nn.Linear(self.args.encoder_embed_dim, 1) + + def max_nodes(self): + return self.encoder.max_nodes + + @staticmethod + def add_args(parser): + graphormer_default_add_args(parser) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure all arguments are present in older models + base_architecture(args) + + if not safe_hasattr(args, "max_nodes"): + args.max_nodes = args.tokens_per_sample + + logger.info(args) + + encoder = GraphormerEncoder(args) + return cls(args, encoder) + + def forward(self, batched_data, **kwargs): + encoder_out = self.encoder(batched_data, **kwargs) + x = encoder_out["encoder_out"][0].transpose(0, 1) # B x T x C + x = x[:, 0, :] # B x C; Only use cls token for prediction + x = self.output_layer(x) + return x + + def upgrade_state_dict_named(self, state_dict, name): + named_parameters = {k: v for k, v in self.named_parameters()} + first_load = guess_first_load(named_parameters, state_dict, name) + if first_load: + msg = upgrade_state_dict_named_from_pretrained( + named_parameters, state_dict, name + ) + logger.warning(f"upgrade_state_dict_named_from_pretrained: {msg}") + + # fill missing keys + for k in named_parameters: + if k not in state_dict: + state_dict[k] = named_parameters[k] + logger.warning( + f"Warning: {k} is missing from the checkpoint, copying from model" + ) + + # remove ununsed keys + for k in list(state_dict.keys()): + if k not in named_parameters: + del state_dict[k] + logger.warning( + f"Warning: {k} is not used in the model, removing from the checkpoint" + ) + + return state_dict + + +@register_model_architecture("graphormer_graphpred", "graphormer_graphpred") +def base_architecture(args): + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.act_dropout = getattr(args, "act_dropout", 0.0) + + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.share_encoder_input_output_embed = getattr( + args, "share_encoder_input_output_embed", False + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + + 
args.apply_graphormer_init = getattr(args, "apply_graphormer_init", False) + + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + + +@register_model_architecture("graphormer_graphpred", "graphormer_graphpred_base") +def graphormer_base_architecture(args): + if args.pretrained_model_name == "pcqm4mv1_graphormer_base" or \ + args.pretrained_model_name == "pcqm4mv2_graphormer_base" or \ + args.pretrained_model_name == "pcqm4mv1_graphormer_base_for_molhiv": + args.encoder_layers = 12 + args.encoder_attention_heads = 32 + args.encoder_ffn_embed_dim = 768 + args.encoder_embed_dim = 768 + args.dropout = getattr(args, "dropout", 0.0) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.act_dropout = getattr(args, "act_dropout", 0.1) + else: + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 32) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768) + args.dropout = getattr(args, "dropout", 0.0) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.act_dropout = getattr(args, "act_dropout", 0.1) + + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.apply_graphormer_init = getattr(args, "apply_graphormer_init", True) + args.share_encoder_input_output_embed = getattr( + args, "share_encoder_input_output_embed", False + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.pre_layernorm = getattr(args, "pre_layernorm", False) + base_architecture(args) + + +@register_model_architecture("graphormer_graphpred", "graphormer_graphpred_slim") +def graphormer_slim_architecture(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 80) + + args.encoder_layers = getattr(args, "encoder_layers", 12) + + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 80) + + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.apply_graphormer_init = getattr(args, "apply_graphormer_init", True) + args.share_encoder_input_output_embed = getattr( + args, "share_encoder_input_output_embed", False + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.pre_layernorm = getattr(args, "pre_layernorm", False) + base_architecture(args) + + +@register_model_architecture("graphormer_graphpred", "graphormer_graphpred_large") +def graphormer_large_architecture(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + + args.encoder_layers = getattr(args, "encoder_layers", 24) + + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 32) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024) + + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.apply_graphormer_init = getattr(args, "apply_graphormer_init", True) + args.share_encoder_input_output_embed = getattr( + args, "share_encoder_input_output_embed", False + ) + args.no_token_positional_embeddings = getattr( + args, 
"no_token_positional_embeddings", False + ) + args.pre_layernorm = getattr(args, "pre_layernorm", False) + base_architecture(args) diff --git a/graphormer/modules/__init__.py b/graphormer/modules/__init__.py index e12b002..89e2e25 100644 --- a/graphormer/modules/__init__.py +++ b/graphormer/modules/__init__.py @@ -5,3 +5,4 @@ from .graphormer_layers import GraphNodeFeature, GraphAttnBias from .graphormer_graph_encoder_layer import GraphormerGraphEncoderLayer from .graphormer_graph_encoder import GraphormerGraphEncoder, init_graphormer_params +from .predict_layers import PredictLayerGroup \ No newline at end of file diff --git a/graphormer/modules/droppath.py b/graphormer/modules/droppath.py new file mode 100644 index 0000000..86d0740 --- /dev/null +++ b/graphormer/modules/droppath.py @@ -0,0 +1,24 @@ +import torch + + +class DropPath(torch.nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, prob=None): + super(DropPath, self).__init__() + self.drop_prob = prob + + def forward(self, x): + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + def extra_repr(self) -> str: + return f"prob={self.drop_prob}" diff --git a/graphormer/modules/graphormer_graph_encoder.py b/graphormer/modules/graphormer_graph_encoder.py index 4d18cbe..5006734 100644 --- a/graphormer/modules/graphormer_graph_encoder.py +++ b/graphormer/modules/graphormer_graph_encoder.py @@ -43,45 +43,50 @@ def normal_(data): class GraphormerGraphEncoder(nn.Module): - def __init__( - self, - num_atoms: int, - num_in_degree: int, - num_out_degree: int, - num_edges: int, - num_spatial: int, - num_edge_dis: int, - edge_type: str, - multi_hop_max_dist: int, - num_encoder_layers: int = 12, - embedding_dim: int = 768, - ffn_embedding_dim: int = 768, - num_attention_heads: int = 32, - dropout: float = 0.1, - attention_dropout: float = 0.1, - activation_dropout: float = 0.1, - layerdrop: float = 0.0, - encoder_normalize_before: bool = False, - pre_layernorm: bool = False, - apply_graphormer_init: bool = False, - activation_fn: str = "gelu", - embed_scale: float = None, - freeze_embeddings: bool = False, - n_trans_layers_to_freeze: int = 0, - export: bool = False, - traceable: bool = False, - q_noise: float = 0.0, - qn_block_size: int = 8, - ) -> None: - + def __init__(self, args) -> None: super().__init__() + + # for Graphormer + num_atoms = args.num_atoms + num_in_degree = args.num_in_degree + num_out_degree = args.num_out_degree + num_edges = args.num_edges + num_spatial = args.num_spatial + num_edge_dis = args.num_edge_dis + edge_type = args.edge_type + multi_hop_max_dist = args.multi_hop_max_dist + sandwich_norm = args.sandwich_norm + num_encoder_layers = args.encoder_layers + embedding_dim = args.encoder_embed_dim + ffn_embedding_dim = args.encoder_ffn_embed_dim + num_attention_heads = args.encoder_attention_heads + + # Fine-tuning parameters + layer_scale = args.layer_scale + droppath_prob = args.droppath_prob + + # Attention parameters + dropout = args.dropout + attention_dropout = args.attention_dropout + activation_dropout = args.act_dropout + encoder_normalize_before = args.encoder_normalize_before + apply_graphormer_init = 
args.apply_graphormer_init + activation_fn = args.activation_fn + + # Disable original Dropout when using DropPath + if droppath_prob > 0.0: + dropout = attention_dropout = activation_dropout = 0.0 + + # stochastic depth decay rule (linearly increasing) + droppath_probs = [ + x.item() for x in torch.linspace(0, droppath_prob, num_encoder_layers) + ] + self.dropout_module = FairseqDropout( dropout, module_name=self.__class__.__name__ ) - self.layerdrop = layerdrop self.embedding_dim = embedding_dim self.apply_graphormer_init = apply_graphormer_init - self.traceable = traceable self.graph_node_feature = GraphNodeFeature( num_heads=num_attention_heads, @@ -91,6 +96,7 @@ def __init__( hidden_dim=embedding_dim, n_layers=num_encoder_layers, ) + self.init_extra_node_layers(args) self.graph_attn_bias = GraphAttnBias( num_heads=num_attention_heads, @@ -103,46 +109,29 @@ def __init__( hidden_dim=embedding_dim, n_layers=num_encoder_layers, ) - - self.embed_scale = embed_scale - - if q_noise > 0: - self.quant_noise = apply_quant_noise_( - nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), - q_noise, - qn_block_size, - ) - else: - self.quant_noise = None + self.init_extra_bias_layers(args) if encoder_normalize_before: - self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export) + self.emb_layer_norm = LayerNorm(self.embedding_dim) else: self.emb_layer_norm = None - if pre_layernorm: - self.final_layer_norm = LayerNorm(self.embedding_dim, export=export) - - if self.layerdrop > 0.0: - self.layers = LayerDropModuleList(p=self.layerdrop) - else: - self.layers = nn.ModuleList([]) + self.layers = nn.ModuleList([]) self.layers.extend( [ self.build_graphormer_graph_encoder_layer( embedding_dim=self.embedding_dim, ffn_embedding_dim=ffn_embedding_dim, num_attention_heads=num_attention_heads, + layer_scale=layer_scale, + droppath=droppath_probs[i], dropout=self.dropout_module.p, attention_dropout=attention_dropout, activation_dropout=activation_dropout, activation_fn=activation_fn, - export=export, - q_noise=q_noise, - qn_block_size=qn_block_size, - pre_layernorm=pre_layernorm, + sandwich_norm=sandwich_norm, ) - for _ in range(num_encoder_layers) + for i in range(num_encoder_layers) ] ) @@ -150,44 +139,68 @@ def __init__( if self.apply_graphormer_init: self.apply(init_graphormer_params) - def freeze_module_params(m): - if m is not None: - for p in m.parameters(): - p.requires_grad = False - - if freeze_embeddings: - raise NotImplementedError("Freezing embeddings is not implemented yet.") - - for layer in range(n_trans_layers_to_freeze): - freeze_module_params(self.layers[layer]) - def build_graphormer_graph_encoder_layer( self, embedding_dim, ffn_embedding_dim, num_attention_heads, + layer_scale, + droppath, dropout, attention_dropout, activation_dropout, activation_fn, - export, - q_noise, - qn_block_size, - pre_layernorm, + sandwich_norm, ): return GraphormerGraphEncoderLayer( embedding_dim=embedding_dim, ffn_embedding_dim=ffn_embedding_dim, num_attention_heads=num_attention_heads, + layer_scale=layer_scale, + droppath=droppath, dropout=dropout, attention_dropout=attention_dropout, activation_dropout=activation_dropout, activation_fn=activation_fn, - export=export, - q_noise=q_noise, - qn_block_size=qn_block_size, - pre_layernorm=pre_layernorm, + sandwich_norm=sandwich_norm, + ) + + def init_extra_node_layers(self, args): + pass + + def init_extra_bias_layers(self, args): + pass + + def forward_extra_node_layers(self, batched_data, x): + """ + input: + batched_data: dict + x: tensor, B x T x 
C (T = N + 1) + output: + x: tensor, B x T x C (T = N + 1) + """ + return x + + def forward_extra_bias_layers(self, batched_data, attn_bias): + """ + attn_bias: B x H x T x T (T = N + 1) + input: + batched_data: dict + attn_bias: tensor, B x H x T x T (T = N + 1) + output: + attn_bias: tensor, B x H x T x T (T = N + 1) + """ + return attn_bias + + def make_padding_mask(self, batched_data): + data_x = batched_data["x"] + n_graph, n_node = data_x.size()[:2] + padding_mask = (data_x[:, :, 0]).eq(0) # B x N x 1 + padding_mask_cls = torch.zeros( + n_graph, 1, device=padding_mask.device, dtype=padding_mask.dtype ) + padding_mask = torch.cat((padding_mask_cls, padding_mask), dim=1) + return padding_mask.contiguous() def forward( self, @@ -196,68 +209,60 @@ def forward( last_state_only: bool = False, token_embeddings: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, + need_attn: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: - is_tpu = False - # compute padding mask. This is needed for multi-head attention - data_x = batched_data["x"] - n_graph, n_node = data_x.size()[:2] - padding_mask = (data_x[:, :, 0]).eq(0) # B x T x 1 - padding_mask_cls = torch.zeros( - n_graph, 1, device=padding_mask.device, dtype=padding_mask.dtype - ) - padding_mask = torch.cat((padding_mask_cls, padding_mask), dim=1) - # B x (T+1) x 1 + # B x T x 1, T = N + 1 + padding_mask = self.make_padding_mask(batched_data) - if token_embeddings is not None: - x = token_embeddings - else: - x = self.graph_node_feature(batched_data) - - if perturb is not None: - #ic(torch.mean(torch.abs(x[:, 1, :]))) - #ic(torch.mean(torch.abs(perturb))) - x[:, 1:, :] += perturb - - # x: B x T x C + # node features -> B x T x C + x = self.graph_node_feature(batched_data) + # extra layers -> keep the same shape + x = self.forward_extra_node_layers(batched_data, x) + # attn bias -> B x H x T x T attn_bias = self.graph_attn_bias(batched_data) + # extra layers -> keep the same shape + attn_bias = self.forward_extra_bias_layers(batched_data, attn_bias) - if self.embed_scale is not None: - x = x * self.embed_scale - - if self.quant_noise is not None: - x = self.quant_noise(x) + if perturb is not None: + # perturb: B x N x C + x[:, 1:, :] += perturb if self.emb_layer_norm is not None: x = self.emb_layer_norm(x) x = self.dropout_module(x) - # account for padding while computing the representation - # B x T x C -> T x B x C x = x.transpose(0, 1) inner_states = [] + inner_attns = [] + if not last_state_only: inner_states.append(x) for layer in self.layers: - x, _ = layer( + x, attn = layer( x, self_attn_padding_mask=padding_mask, self_attn_mask=attn_mask, self_attn_bias=attn_bias, + need_weights=need_attn, ) if not last_state_only: inner_states.append(x) + if need_attn: + inner_attns.append(attn) graph_rep = x[0, :, :] if last_state_only: inner_states = [x] + if need_attn: + inner_attns = [attn] - if self.traceable: - return torch.stack(inner_states), graph_rep - else: + if not need_attn: return inner_states, graph_rep + else: + return inner_states, graph_rep, inner_attns diff --git a/graphormer/modules/graphormer_graph_encoder_layer.py b/graphormer/modules/graphormer_graph_encoder_layer.py index 6d815c7..efded6f 100644 --- a/graphormer/modules/graphormer_graph_encoder_layer.py +++ b/graphormer/modules/graphormer_graph_encoder_layer.py @@ -16,6 +16,7 @@ from fairseq.modules.quant_noise import quant_noise from .multihead_attention import MultiheadAttention +from .droppath import DropPath class 
GraphormerGraphEncoderLayer(nn.Module): @@ -24,6 +25,8 @@ def __init__( embedding_dim: int = 768, ffn_embedding_dim: int = 3072, num_attention_heads: int = 8, + layer_scale: float = 0.0, + droppath: float = 0.0, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, @@ -32,7 +35,7 @@ def __init__( q_noise: float = 0.0, qn_block_size: int = 8, init_fn: Callable = None, - pre_layernorm: bool = False, + sandwich_norm: bool = False, ) -> None: super().__init__() @@ -45,11 +48,15 @@ def __init__( self.attention_dropout = attention_dropout self.q_noise = q_noise self.qn_block_size = qn_block_size - self.pre_layernorm = pre_layernorm + self.sandwich_norm = sandwich_norm + + if droppath > 0.0: + self.dropout_module = DropPath(droppath) + else: + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) - self.dropout_module = FairseqDropout( - dropout, module_name=self.__class__.__name__ - ) self.activation_dropout_module = FairseqDropout( activation_dropout, module_name=self.__class__.__name__ ) @@ -66,8 +73,21 @@ def __init__( ) # layer norm associated with the self attention layer + self.self_attn_layer_norm_sandwich = ( + LayerNorm(self.embedding_dim, export=export) if self.sandwich_norm else None + ) self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export) + if layer_scale > 0: + self.layer_scale1 = nn.Parameter( + layer_scale * torch.ones(self.embedding_dim) + ) + self.layer_scale2 = nn.Parameter( + layer_scale * torch.ones(self.embedding_dim) + ) + else: + self.layer_scale1 = self.layer_scale2 = 1.0 + self.fc1 = self.build_fc1( self.embedding_dim, ffn_embedding_dim, @@ -82,6 +102,9 @@ def __init__( ) # layer norm associated with the position wise feed-forward NN + self.final_layer_norm_sandwich = ( + LayerNorm(self.embedding_dim, export=export) if self.sandwich_norm else None + ) self.final_layer_norm = LayerNorm(self.embedding_dim, export=export) def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): @@ -114,6 +137,7 @@ def forward( self_attn_bias: Optional[torch.Tensor] = None, self_attn_mask: Optional[torch.Tensor] = None, self_attn_padding_mask: Optional[torch.Tensor] = None, + need_weights=False, ): """ LayerNorm is applied either before or after the self-attention/ffn @@ -121,30 +145,38 @@ def forward( """ # x: T x B x C residual = x - if self.pre_layernorm: - x = self.self_attn_layer_norm(x) + if self.self_attn_layer_norm_sandwich: + x = self.self_attn_layer_norm_sandwich(x) + x, attn = self.self_attn( query=x, key=x, value=x, attn_bias=self_attn_bias, key_padding_mask=self_attn_padding_mask, - need_weights=False, + need_weights=need_weights, attn_mask=self_attn_mask, ) - x = self.dropout_module(x) - x = residual + x - if not self.pre_layernorm: + x = self.dropout_module(self.layer_scale1 * x) + + if self.sandwich_norm: + x = self.self_attn_layer_norm(x) + x = residual + x + else: + x = residual + x x = self.self_attn_layer_norm(x) residual = x - if self.pre_layernorm: - x = self.final_layer_norm(x) + if self.sandwich_norm: + x = self.final_layer_norm_sandwich(x) x = self.activation_fn(self.fc1(x)) x = self.activation_dropout_module(x) x = self.fc2(x) - x = self.dropout_module(x) - x = residual + x - if not self.pre_layernorm: + x = self.dropout_module(self.layer_scale2 * x) + if self.sandwich_norm: + x = self.final_layer_norm(x) + x = residual + x + else: + x = residual + x x = self.final_layer_norm(x) return x, attn diff --git a/graphormer/modules/predict_layers.py 
diff --git a/graphormer/modules/predict_layers.py b/graphormer/modules/predict_layers.py
new file mode 100644
index 0000000..fc47180
--- /dev/null
+++ b/graphormer/modules/predict_layers.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class PredictLayer(nn.Module):
+    def __init__(
+        self, in_dim, out_dim, activation=None, sandwich_norm=False, n_layers=1
+    ):
+        super().__init__()
+        assert sandwich_norm == False, "sandwich norm not supported"
+        self.activation = activation
+
+        self.layers = nn.ModuleList(
+            [nn.Linear(in_dim, in_dim, bias=True) for _ in range(n_layers - 1)]
+        )
+        self.fc_out = nn.Linear(in_dim, out_dim, bias=True)
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+            if self.activation is not None:
+                x = self.activation(x)
+        x = self.fc_out(x)
+        return x
+
+
+# Defines a group of output layers; the output dimension of each layer is given by the corresponding entry in 'out_dims'.
+class PredictLayerGroup(nn.Module):
+    def __init__(self, in_dim, out_dims, activation=None, sandwich_norm=False, n_layers=1):
+        super().__init__()
+        self.layers_list = nn.ModuleList(
+            [PredictLayer(in_dim, out_dim, activation, sandwich_norm, n_layers) for out_dim in out_dims]
+        )
+
+    def forward(self, x):
+        return [layer(x) for layer in self.layers_list]
diff --git a/graphormer/utils/__init__.py b/graphormer/utils/__init__.py
index 59e481e..37b94f2 100644
--- a/graphormer/utils/__init__.py
+++ b/graphormer/utils/__init__.py
@@ -1,2 +1,4 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
+
+from .utils import graphormer_default_add_args, guess_first_load, upgrade_state_dict_named_from_pretrained
\ No newline at end of file
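`PredictLayerGroup` builds one `PredictLayer` per entry of `out_dims`, so a single pooled graph representation can feed several task heads at once. A brief usage sketch (the import assumes the repository root is on PYTHONPATH; the dimensions and task split are illustrative only):

import torch
from graphormer.modules.predict_layers import PredictLayerGroup

# three heads on a shared 768-d graph representation:
# e.g. one regression target and two 2-way classification tasks
heads = PredictLayerGroup(
    in_dim=768, out_dims=[1, 2, 2], activation=torch.nn.GELU(), n_layers=2
)

graph_rep = torch.randn(8, 768)   # B x C pooled graph representation
outputs = heads(graph_rep)        # list of tensors: 8x1, 8x2, 8x2
print([o.shape for o in outputs])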
diff --git a/graphormer/utils/utils.py b/graphormer/utils/utils.py
new file mode 100644
index 0000000..e01c03d
--- /dev/null
+++ b/graphormer/utils/utils.py
@@ -0,0 +1,191 @@
+import torch
+
+import fairseq.utils as fairseq_utils
+
+
+def graphormer_default_add_args(parser):
+    """Add model-specific arguments to the parser."""
+    # Arguments related to dropout
+    parser.add_argument(
+        "--dropout", type=float, metavar="D", help="dropout probability"
+    )
+    parser.add_argument(
+        "--attention-dropout",
+        type=float,
+        metavar="D",
+        help="dropout probability for" " attention weights",
+    )
+    parser.add_argument(
+        "--act-dropout",
+        type=float,
+        metavar="D",
+        help="dropout probability after" " activation in FFN",
+    )
+
+    # Arguments related to hidden states and self-attention
+    parser.add_argument(
+        "--encoder-ffn-embed-dim",
+        type=int,
+        metavar="N",
+        help="encoder embedding dimension for FFN",
+    )
+    parser.add_argument(
+        "--encoder-layers", type=int, metavar="N", help="num encoder layers"
+    )
+    parser.add_argument(
+        "--encoder-attention-heads",
+        type=int,
+        metavar="N",
+        help="num encoder attention heads",
+    )
+
+    # Arguments related to input and output embeddings
+    parser.add_argument(
+        "--encoder-embed-dim",
+        type=int,
+        metavar="N",
+        help="encoder embedding dimension",
+    )
+    parser.add_argument(
+        "--share-encoder-input-output-embed",
+        action="store_true",
+        help="share encoder input" " and output embeddings",
+    )
+    parser.add_argument(
+        "--encoder-learned-pos",
+        action="store_true",
+        help="use learned positional embeddings in the encoder",
+    )
+    parser.add_argument(
+        "--no-token-positional-embeddings",
+        action="store_true",
+        help="if set, disables positional embeddings" " (outside self attention)",
+    )
+    parser.add_argument(
+        "--max-positions", type=int, help="number of positional embeddings to learn"
+    )
+
+    # Arguments related to parameter initialization
+    parser.add_argument(
+        "--apply-graphormer-init",
+        action="store_true",
+        help="use custom param initialization for Graphormer",
+    )
+
+    # Arguments related to finetuning tricks
+    parser.add_argument(
+        "--layer-scale",
+        type=float,
+        default=0.0,
+    )
+
+    parser.add_argument(
+        "--droppath-prob",
+        type=float,
+        default=0.0,
+    )
+
+    # misc params
+    parser.add_argument(
+        "--activation-fn",
+        choices=fairseq_utils.get_available_activation_fns(),
+        help="activation function to use",
+    )
+    parser.add_argument(
+        "--encoder-normalize-before",
+        action="store_true",
+        help="apply layernorm before each encoder block",
+    )
+    parser.add_argument(
+        "--sandwich-norm",
+        default=False,
+        action="store_true",
+        help="use sandwich layernorm for the encoder block",
+    )
+
+
+def guess_first_load(named_parameters, state_dict, name):
+    first_load = False
+    # guess whether this is the first time we are loading the checkpoint
+    if set(named_parameters.keys()) != set(state_dict.keys()):
+        first_load = True
+
+    if not first_load:
+        for k in named_parameters:
+            if named_parameters[k].shape != state_dict[k].shape:
+                first_load = True
+                break
+
+    return first_load
+
+
+def guess_load_from_pm6(named_parameters, state_dict, name):
+    upgrade_from_pm6 = False
+    if any("final_sandwich_layer_norm" in k for k in state_dict.keys()):
+        upgrade_from_pm6 = True
+    return upgrade_from_pm6
+
+
+def upgrade_state_dict_named_from_pretrained(named_parameters, state_dict, name):
+    from_pm6 = guess_load_from_pm6(named_parameters, state_dict, name)
+    if from_pm6:
+        upgrade_state_dict_named_from_pm6_ckpt(named_parameters, state_dict, name)
+        msg = "Upgraded state_dict from pm6 checkpoint"
+    else:
+        upgrade_state_dict_named_from_m3_ckpt(named_parameters, state_dict, name)
+        msg = "Upgraded state_dict from m3 checkpoint"
+    return msg
+
+
+def upgrade_state_dict_named_from_pm6_ckpt(named_parameters, state_dict, name):
+    def upgrade_pm6_keys(key_name):
+        new_key_name = key_name.replace("sentence_encoder", "graph_encoder")
+        new_key_name = new_key_name.replace(
+            "self_attn_sandwich_layer_norm", "self_attn_layer_norm_sandwich"
+        )
+        new_key_name = new_key_name.replace(
+            "final_layer_norm", "final_layer_norm_sandwich"
+        )
+        new_key_name = new_key_name.replace(
+            "final_sandwich_layer_norm", "final_layer_norm"
+        )
+        return new_key_name
+
+    old_keys = list(state_dict.keys())
+    for key in old_keys:
+        new_key = upgrade_pm6_keys(key)
+        if new_key != key:
+            state_dict[new_key] = state_dict.pop(key)
+
+    zero_init_keys = ["role_encoder", "pos_encoder"]
+
+    for key in named_parameters:
+        if any(x in key for x in zero_init_keys):
+            state_dict[key] = torch.zeros_like(named_parameters[key].data)
+
+    to_remove_keys = [
+        "masked_lm_pooler",
+        "regression_lm_head_list",
+        "regression_embed_out_list",
+        "regression_ln_list",
+    ]
+    _tmp = []
+    for key in state_dict.keys():
+        if any(x in key for x in to_remove_keys):
+            _tmp.append(key)
+    for key in _tmp:
+        state_dict.pop(key)
+
+
+def upgrade_state_dict_named_from_m3_ckpt(named_parameters, state_dict, name):
+    to_remove_keys = [
+        "encoder.embed_outs",
+        "encoder.edge_out",
+        "encoder.spatial_out",
+    ]
+    _tmp = []
+    for key in state_dict.keys():
+        if any(x in key for x in to_remove_keys):
+            _tmp.append(key)
+    for key in _tmp:
+        state_dict.pop(key)
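To make the pm6 key upgrade concrete: the renaming in `upgrade_state_dict_named_from_pm6_ckpt` swaps the two final-norm names via the `_sandwich` suffix, which is why the replacement order matters — plain `final_layer_norm` keys are moved aside before `final_sandwich_layer_norm` keys claim that name. A self-contained illustration with toy keys (the helper below mirrors the nested `upgrade_pm6_keys` above; the keys are not from a real checkpoint):

# toy demonstration of the pm6 key renaming rules
def upgrade_pm6_key(key_name):
    key_name = key_name.replace("sentence_encoder", "graph_encoder")
    key_name = key_name.replace("self_attn_sandwich_layer_norm", "self_attn_layer_norm_sandwich")
    key_name = key_name.replace("final_layer_norm", "final_layer_norm_sandwich")
    key_name = key_name.replace("final_sandwich_layer_norm", "final_layer_norm")
    return key_name

toy_keys = [
    "encoder.sentence_encoder.layers.0.final_layer_norm.weight",
    "encoder.sentence_encoder.layers.0.final_sandwich_layer_norm.weight",
    "encoder.sentence_encoder.layers.0.self_attn_sandwich_layer_norm.weight",
]
for k in toy_keys:
    print(k, "->", upgrade_pm6_key(k))
# ...final_layer_norm.weight           -> ...graph_encoder...final_layer_norm_sandwich.weight
# ...final_sandwich_layer_norm.weight  -> ...graph_encoder...final_layer_norm.weight
# ...self_attn_sandwich_layer_norm.weight -> ...graph_encoder...self_attn_layer_norm_sandwich.weight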