From 69d7de5c5268cf3ab1bc5db0c4658628ea4e88c3 Mon Sep 17 00:00:00 2001 From: Yaning Cui <39021445+emotionor@users.noreply.github.com> Date: Thu, 21 Nov 2024 14:46:04 +0800 Subject: [PATCH] Yn unimol2 (#288) * add unimol2 into unimol_tools * update diff modelsize config * update get repr for unimol v2, and fix bug when input is a dict of atoms and coordinates * update verion number * remove unicore dependency * update args for diff model size --- unimol_tools/setup.py | 2 +- unimol_tools/unimol_tools/config/__init__.py | 2 +- .../unimol_tools/config/model_config.py | 10 + unimol_tools/unimol_tools/data/conformer.py | 341 +++++++- unimol_tools/unimol_tools/data/datahub.py | 19 +- unimol_tools/unimol_tools/models/__init__.py | 2 +- unimol_tools/unimol_tools/models/nnmodel.py | 2 + .../unimol_tools/models/transformersv2.py | 762 ++++++++++++++++++ unimol_tools/unimol_tools/models/unimol.py | 55 +- unimol_tools/unimol_tools/models/unimolv2.py | 591 ++++++++++++++ unimol_tools/unimol_tools/predictor.py | 24 +- unimol_tools/unimol_tools/train.py | 6 + unimol_tools/unimol_tools/utils/util.py | 9 +- unimol_tools/unimol_tools/weights/__init__.py | 2 +- .../unimol_tools/weights/weighthub.py | 21 + 15 files changed, 1804 insertions(+), 44 deletions(-) create mode 100644 unimol_tools/unimol_tools/models/transformersv2.py create mode 100644 unimol_tools/unimol_tools/models/unimolv2.py diff --git a/unimol_tools/setup.py b/unimol_tools/setup.py index c96848e..29e316a 100644 --- a/unimol_tools/setup.py +++ b/unimol_tools/setup.py @@ -5,7 +5,7 @@ setup( name="unimol_tools", - version="0.1.0.post4", + version="0.1.1", description=("unimol_tools is a Python package for property prediciton with Uni-Mol in molecule, materials and protein."), long_description=open('README.md').read(), long_description_content_type='text/markdown', diff --git a/unimol_tools/unimol_tools/config/__init__.py b/unimol_tools/unimol_tools/config/__init__.py index cf8718e..29cd5f4 100644 --- a/unimol_tools/unimol_tools/config/__init__.py +++ b/unimol_tools/unimol_tools/config/__init__.py @@ -1 +1 @@ -from .model_config import MODEL_CONFIG \ No newline at end of file +from .model_config import MODEL_CONFIG, MODEL_CONFIG_V2 \ No newline at end of file diff --git a/unimol_tools/unimol_tools/config/model_config.py b/unimol_tools/unimol_tools/config/model_config.py index 6a6c4bc..02cf5e6 100644 --- a/unimol_tools/unimol_tools/config/model_config.py +++ b/unimol_tools/unimol_tools/config/model_config.py @@ -13,4 +13,14 @@ "crystal": "mp.dict.txt", "oled": "oled.dict.txt", }, +} + +MODEL_CONFIG_V2 = { + 'weight': { + '84m': 'modelzoo/84M/checkpoint.pt', + '164m': 'modelzoo/164M/checkpoint.pt', + '310m': 'modelzoo/310M/checkpoint.pt', + '570m': 'modelzoo/570M/checkpoint.pt', + '1.1B': 'modelzoo/1.1B/checkpoint.pt', + }, } \ No newline at end of file diff --git a/unimol_tools/unimol_tools/data/conformer.py b/unimol_tools/unimol_tools/data/conformer.py index e97571a..2c59414 100644 --- a/unimol_tools/unimol_tools/data/conformer.py +++ b/unimol_tools/unimol_tools/data/conformer.py @@ -16,11 +16,44 @@ from .dictionary import Dictionary from multiprocessing import Pool from tqdm import tqdm +import torch +from numba import njit from ..utils import logger from ..config import MODEL_CONFIG from ..weights import weight_download, WEIGHT_DIR +# https://github.com/snap-stanford/ogb/blob/master/ogb/utils/features.py +# allowable multiple choice node and edge features +allowable_features = { + "possible_atomic_num_list": list(range(1, 119)) + ["misc"], + "possible_chirality_list": [ + "CHI_UNSPECIFIED", + "CHI_TETRAHEDRAL_CW", + "CHI_TETRAHEDRAL_CCW", + "CHI_TRIGONALBIPYRAMIDAL", + "CHI_OCTAHEDRAL", + "CHI_SQUAREPLANAR", + "CHI_OTHER", + ], + "possible_degree_list": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, "misc"], + "possible_formal_charge_list": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, "misc"], + "possible_numH_list": [0, 1, 2, 3, 4, 5, 6, 7, 8, "misc"], + "possible_number_radical_e_list": [0, 1, 2, 3, 4, "misc"], + "possible_hybridization_list": ["SP", "SP2", "SP3", "SP3D", "SP3D2", "misc"], + "possible_is_aromatic_list": [False, True], + "possible_is_in_ring_list": [False, True], + "possible_bond_type_list": ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC", "misc"], + "possible_bond_stereo_list": [ + "STEREONONE", + "STEREOZ", + "STEREOE", + "STEREOCIS", + "STEREOTRANS", + "STEREOANY", + ], + "possible_is_conjugated_list": [False, True], +} class ConformerGen(object): ''' @@ -96,7 +129,7 @@ def transform(self, smiles_list): return inputs -def inner_smi2coords(smi, seed=42, mode='fast', remove_hs=True): +def inner_smi2coords(smi, seed=42, mode='fast', remove_hs=True, return_mol=False): ''' This function is responsible for converting a SMILES (Simplified Molecular Input Line Entry System) string into 3D coordinates for each atom in the molecule. It also allows for the generation of 2D coordinates if 3D conformation generation fails, and optionally removes hydrogen atoms and their coordinates from the resulting data. @@ -140,6 +173,10 @@ def inner_smi2coords(smi, seed=42, mode='fast', remove_hs=True): except: print("Failed to generate conformer, replace with zeros.") coordinates = np.zeros((len(atoms),3)) + + if return_mol: + return mol # for unimolv2 + assert len(atoms) == len(coordinates), "coordinates shape is not align with {}".format(smi) if remove_hs: idx = [i for i, atom in enumerate(atoms) if atom != 'H'] @@ -214,3 +251,305 @@ def coords2unimol(atoms, coordinates, dictionary, max_atoms=256, remove_hs=True, 'src_coord': src_coord.astype(np.float32), 'src_edge_type': src_edge_type.astype(int), } + +class UniMolV2Feature(object): + ''' + This class is responsible for generating features for molecules represented as SMILES strings. It uses the ConformerGen class to generate conformers for the molecules and converts the resulting atom symbols and coordinates into a unified molecular representation. + ''' + def __init__(self, **params): + """ + Initializes the neural network model based on the provided model name and parameters. + + :param model_name: (str) The name of the model to initialize. + :param params: Additional parameters for model configuration. + + :return: An instance of the specified neural network model. + :raises ValueError: If the model name is not recognized. + """ + self._init_features(**params) + + def _init_features(self, **params): + """ + Initializes the features of the UniMolV2Feature object based on provided parameters. + + :param params: Arbitrary keyword arguments for feature configuration. + These can include the random seed, maximum number of atoms, data type, + generation method, generation mode, and whether to remove hydrogens. + """ + self.seed = params.get('seed', 42) + self.max_atoms = params.get('max_atoms', 128) + self.data_type = params.get('data_type', 'molecule') + self.method = params.get('method', 'rdkit_random') + self.mode = params.get('mode', 'fast') + self.remove_hs = params.get('remove_hs', True) + + def single_process(self, smiles): + """ + Processes a single SMILES string to generate conformers using the specified method. + + :param smiles: (str) The SMILES string representing the molecule. + :return: A unimolecular data representation (dictionary) of the molecule. + :raises ValueError: If the conformer generation method is unrecognized. + """ + if self.method == 'rdkit_random': + mol = inner_smi2coords(smiles, seed=self.seed, mode=self.mode, remove_hs=self.remove_hs, return_mol=True) + return mol2unimolv2(mol, self.max_atoms, remove_hs=self.remove_hs) + else: + raise ValueError('Unknown conformer generation method: {}'.format(self.method)) + + def transform_raw(self, atoms_list, coordinates_list): + + inputs = [] + for atoms, coordinates in zip(atoms_list, coordinates_list): + mol = create_mol_from_atoms_and_coords(atoms, coordinates) + inputs.append(mol2unimolv2(mol, self.max_atoms, remove_hs=self.remove_hs)) + return inputs + + def transform(self, smiles_list): + pool = Pool() + logger.info('Start generating conformers...') + inputs = [item for item in tqdm(pool.imap(self.single_process, smiles_list))] + pool.close() + # failed_cnt = np.mean([(item['src_coord']==0.0).all() for item in inputs]) + # logger.info('Succeeded in generating conformers for {:.2f}% of molecules.'.format((1-failed_cnt)*100)) + # failed_3d_cnt = np.mean([(item['src_coord'][:,2]==0.0).all() for item in inputs]) + # logger.info('Succeeded in generating 3d conformers for {:.2f}% of molecules.'.format((1-failed_3d_cnt)*100)) + return inputs + +def create_mol_from_atoms_and_coords(atoms, coordinates): + """ + Creates an RDKit molecule object from a list of atom symbols and their corresponding coordinates. + + :param atoms: (list) Atom symbols for the molecule. + :param coordinates: (list) Atomic coordinates for the molecule. + :return: RDKit molecule object. + """ + mol = Chem.RWMol() + atom_indices = [] + + for atom in atoms: + atom_idx = mol.AddAtom(Chem.Atom(atom)) + atom_indices.append(atom_idx) + + conf = Chem.Conformer(len(atoms)) + for i, coord in enumerate(coordinates): + conf.SetAtomPosition(i, coord) + + mol.AddConformer(conf) + Chem.SanitizeMol(mol) + return mol + +def mol2unimolv2(mol, max_atoms=128, remove_hs=True, **params): + """ + Converts atom symbols and coordinates into a unified molecular representation. + + :param mol: (rdkit.Chem.Mol) The molecule object containing atom symbols and coordinates. + :param max_atoms: (int) The maximum number of atoms to consider for the molecule. + :param remove_hs: (bool) Whether to remove hydrogen atoms from the representation. + :param params: Additional parameters. + + :return: A batched data containing the molecular representation. + """ + + mol = AllChem.AddHs(mol, addCoords=True) + atoms_h = np.array([atom.GetSymbol() for atom in mol.GetAtoms()]) + nH_idx = [i for i, atom in enumerate(atoms_h) if atom != 'H'] + atoms = atoms_h[nH_idx] + coordinates_h = mol.GetConformer().GetPositions().astype(np.float32) + coordinates = coordinates_h[nH_idx] + + # cropping atoms and coordinates + if len(atoms) > max_atoms: + idx = np.random.choice(len(atoms), max_atoms, replace=False) + atoms = atoms[idx] + coordinates = coordinates[idx] + # tokens padding + src_tokens = torch.tensor([AllChem.GetPeriodicTable().GetAtomicNumber(item) for item in atoms]) + src_pos = torch.tensor(coordinates) + # change AllChem.RemoveHs to AllChem.RemoveAllHs + mol = AllChem.RemoveAllHs(mol) + node_attr, edge_index, edge_attr = get_graph(mol) + feat = get_graph_features(edge_attr, edge_index, node_attr, drop_feat=0) + feat['src_tokens'] = src_tokens + feat['src_pos'] = src_pos + return feat + +def safe_index(l, e): + """ + Return index of element e in list l. If e is not present, return the last index + """ + try: + return l.index(e) + except: + return len(l) - 1 + +def atom_to_feature_vector(atom): + """ + Converts rdkit atom object to feature list of indices + :param mol: rdkit atom object + :return: list + """ + atom_feature = [ + safe_index(allowable_features["possible_atomic_num_list"], atom.GetAtomicNum()), + allowable_features["possible_chirality_list"].index(str(atom.GetChiralTag())), + safe_index(allowable_features["possible_degree_list"], atom.GetTotalDegree()), + safe_index( + allowable_features["possible_formal_charge_list"], atom.GetFormalCharge() + ), + safe_index(allowable_features["possible_numH_list"], atom.GetTotalNumHs()), + safe_index( + allowable_features["possible_number_radical_e_list"], + atom.GetNumRadicalElectrons(), + ), + safe_index( + allowable_features["possible_hybridization_list"], + str(atom.GetHybridization()), + ), + allowable_features["possible_is_aromatic_list"].index(atom.GetIsAromatic()), + allowable_features["possible_is_in_ring_list"].index(atom.IsInRing()), + ] + return atom_feature + + +def bond_to_feature_vector(bond): + """ + Converts rdkit bond object to feature list of indices + :param mol: rdkit bond object + :return: list + """ + bond_feature = [ + safe_index( + allowable_features["possible_bond_type_list"], str(bond.GetBondType()) + ), + allowable_features["possible_bond_stereo_list"].index(str(bond.GetStereo())), + allowable_features["possible_is_conjugated_list"].index(bond.GetIsConjugated()), + ] + return bond_feature + + +def get_graph(mol): + """ + Converts SMILES string to graph Data object + :input: SMILES string (str) + :return: graph object + """ + atom_features_list = [] + for atom in mol.GetAtoms(): + atom_features_list.append(atom_to_feature_vector(atom)) + x = np.array(atom_features_list, dtype=np.int32) + # bonds + num_bond_features = 3 # bond type, bond stereo, is_conjugated + if len(mol.GetBonds()) > 0: # mol has bonds + edges_list = [] + edge_features_list = [] + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + edge_feature = bond_to_feature_vector(bond) + # add edges in both directions + edges_list.append((i, j)) + edge_features_list.append(edge_feature) + edges_list.append((j, i)) + edge_features_list.append(edge_feature) + # data.edge_index: Graph connectivity in COO format with shape [2, num_edges] + edge_index = np.array(edges_list, dtype=np.int32).T + # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features] + edge_attr = np.array(edge_features_list, dtype=np.int32) + + else: # mol has no bonds + edge_index = np.empty((2, 0), dtype=np.int32) + edge_attr = np.empty((0, num_bond_features), dtype=np.int32) + return x, edge_index, edge_attr + +def get_graph_features(edge_attr, edge_index, node_attr, drop_feat): + # atom_feat_sizes = [128] + [16 for _ in range(8)] + atom_feat_sizes = [16 for _ in range(8)] + edge_feat_sizes = [16, 16, 16] + edge_attr, edge_index, x = edge_attr, edge_index, node_attr + N = x.shape[0] + + # atom feature here + atom_feat = convert_to_single_emb(x[:, 1:], atom_feat_sizes) + + # node adj matrix [N, N] bool + adj = np.zeros([N, N], dtype=np.int32) + adj[edge_index[0, :], edge_index[1, :]] = 1 + degree = adj.sum(axis=-1) + + # edge feature here + if len(edge_attr.shape) == 1: + edge_attr = edge_attr[:, None] + edge_feat = np.zeros([N, N, edge_attr.shape[-1]], dtype=np.int32) + edge_feat[edge_index[0, :], edge_index[1, :]] = ( + convert_to_single_emb(edge_attr, edge_feat_sizes) + 1 + ) + shortest_path_result = floyd_warshall(adj) + # max distance is 509 + if drop_feat: + atom_feat[...] = 1 + edge_feat[...] = 1 + degree[...] = 1 + shortest_path_result[...] = 511 + else: + atom_feat = atom_feat + 2 + edge_feat = edge_feat + 2 + degree = degree + 2 + shortest_path_result = shortest_path_result + 1 + + # combine, plus 1 for padding + feat = {} + feat["atom_feat"] = torch.from_numpy(atom_feat).long() + feat["atom_mask"] = torch.ones(N).long() + feat["edge_feat"] = torch.from_numpy(edge_feat).long() + feat["shortest_path"] = torch.from_numpy((shortest_path_result)).long() + feat["degree"] = torch.from_numpy(degree).long().view(-1) + # pair-type + atoms = feat["atom_feat"][..., 0] + pair_type = torch.cat( + [ + atoms.view(-1, 1, 1).expand(-1, N, -1), + atoms.view(1, -1, 1).expand(N, -1, -1), + ], + dim=-1, + ) + feat["pair_type"] = convert_to_single_emb(pair_type, [128, 128]) + feat["attn_bias"] = torch.zeros((N + 1, N + 1), dtype=torch.float32) + return feat + +def convert_to_single_emb(x, sizes): + assert x.shape[-1] == len(sizes) + offset = 1 + for i in range(len(sizes)): + assert (x[..., i] < sizes[i]).all() + x[..., i] = x[..., i] + offset + offset += sizes[i] + return x + + +@njit +def floyd_warshall(M): + (nrows, ncols) = M.shape + assert nrows == ncols + n = nrows + # set unreachable nodes distance to 510 + for i in range(n): + for j in range(n): + if M[i, j] == 0: + M[i, j] = 510 + + for i in range(n): + M[i, i] = 0 + + # floyed algo + for k in range(n): + for i in range(n): + for j in range(n): + cost_ikkj = M[i, k] + M[k, j] + if M[i, j] > cost_ikkj: + M[i, j] = cost_ikkj + + for i in range(n): + for j in range(n): + if M[i, j] >= 510: + M[i, j] = 510 + return M \ No newline at end of file diff --git a/unimol_tools/unimol_tools/data/datahub.py b/unimol_tools/unimol_tools/data/datahub.py index 704328d..1c76b9e 100644 --- a/unimol_tools/unimol_tools/data/datahub.py +++ b/unimol_tools/unimol_tools/data/datahub.py @@ -6,7 +6,7 @@ import numpy as np from .datareader import MolDataReader from .datascaler import TargetScaler -from .conformer import ConformerGen +from .conformer import ConformerGen, UniMolV2Feature class DataHub(object): """ @@ -75,10 +75,17 @@ def _init_data(self, **params): else: raise ValueError('Unknown task: {}'.format(self.task)) - if 'atoms' in self.data and 'coordinates' in self.data: - no_h_list = ConformerGen(**params).transform_raw(self.data['atoms'], self.data['coordinates']) - else: - smiles_list = self.data["smiles"] - no_h_list = ConformerGen(**params).transform(smiles_list) + if params.get('model_name', None) == 'unimolv1': + if 'atoms' in self.data and 'coordinates' in self.data: + no_h_list = ConformerGen(**params).transform_raw(self.data['atoms'], self.data['coordinates']) + else: + smiles_list = self.data["smiles"] + no_h_list = ConformerGen(**params).transform(smiles_list) + elif params.get('model_name', None) == 'unimolv2': + if 'atoms' in self.data and 'coordinates' in self.data: + no_h_list = UniMolV2Feature().transform_raw(self.data['atoms'], self.data['coordinates']) + else: + smiles_list = self.data["smiles"] + no_h_list = UniMolV2Feature().transform(smiles_list) self.data['unimol_input'] = no_h_list diff --git a/unimol_tools/unimol_tools/models/__init__.py b/unimol_tools/unimol_tools/models/__init__.py index 97f216b..d474a24 100644 --- a/unimol_tools/unimol_tools/models/__init__.py +++ b/unimol_tools/unimol_tools/models/__init__.py @@ -1 +1 @@ -from .nnmodel import NNModel, UniMolModel \ No newline at end of file +from .nnmodel import NNModel, UniMolModel, UniMolV2Model \ No newline at end of file diff --git a/unimol_tools/unimol_tools/models/nnmodel.py b/unimol_tools/unimol_tools/models/nnmodel.py index 0928e93..d52b959 100644 --- a/unimol_tools/unimol_tools/models/nnmodel.py +++ b/unimol_tools/unimol_tools/models/nnmodel.py @@ -13,11 +13,13 @@ import numpy as np from ..utils import logger from .unimol import UniMolModel +from .unimolv2 import UniMolV2Model from .loss import GHMC_Loss, FocalLossWithLogits, myCrossEntropyLoss, MAEwithNan NNMODEL_REGISTER = { 'unimolv1': UniMolModel, + 'unimolv2': UniMolV2Model, } LOSS_RREGISTER = { diff --git a/unimol_tools/unimol_tools/models/transformersv2.py b/unimol_tools/unimol_tools/models/transformersv2.py new file mode 100644 index 0000000..b6aadf4 --- /dev/null +++ b/unimol_tools/unimol_tools/models/transformersv2.py @@ -0,0 +1,762 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None, inplace=True): + """softmax dropout, and mask, bias are optional. + Args: + input (torch.Tensor): input tensor + dropout_prob (float): dropout probability + is_training (bool, optional): is in training or not. Defaults to True. + mask (torch.Tensor, optional): the mask tensor, use as input + mask . Defaults to None. + bias (torch.Tensor, optional): the bias tensor, use as input + bias . Defaults to None. + + Returns: + torch.Tensor: the result after softmax + """ + input = input.contiguous() + if not inplace: + # copy a input for non-inplace case + input = input.clone() + if mask is not None: + input += mask + if bias is not None: + input += bias + return F.dropout(F.softmax(input, dim=-1), p=dropout_prob, training=is_training) + +def permute_final_dims(tensor: torch.Tensor, inds): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + +class Dropout(nn.Module): + def __init__(self, p): + super().__init__() + self.p = p + + def forward(self, x, inplace: bool = False): + if self.p > 0 and self.training: + return F.dropout(x, p=self.p, training=True, inplace=inplace) + else: + return x + +class Linear(nn.Linear): + def __init__( + self, + d_in: int, + d_out: int, + bias: bool = True, + init: str = "default", + ): + super(Linear, self).__init__(d_in, d_out, bias=bias) + + self.use_bias = bias + + if self.use_bias: + with torch.no_grad(): + self.bias.fill_(0) + + if init == "default": + self._trunc_normal_init(1.0) + elif init == "relu": + self._trunc_normal_init(2.0) + elif init == "glorot": + self._glorot_uniform_init() + elif init == "gating": + self._zero_init(self.use_bias) + elif init == "normal": + self._normal_init() + elif init == "final": + self._zero_init(False) + else: + raise ValueError("Invalid init method.") + + def _trunc_normal_init(self, scale=1.0): + # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) + TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978 + _, fan_in = self.weight.shape + scale = scale / max(1, fan_in) + std = (scale**0.5) / TRUNCATED_NORMAL_STDDEV_FACTOR + nn.init.trunc_normal_(self.weight, mean=0.0, std=std) + + def _glorot_uniform_init(self): + nn.init.xavier_uniform_(self.weight, gain=1) + + def _zero_init(self, use_bias=True): + with torch.no_grad(): + self.weight.fill_(0.0) + if use_bias: + with torch.no_grad(): + self.bias.fill_(1.0) + + def _normal_init(self): + torch.nn.init.kaiming_normal_(self.weight, nonlinearity="linear") + + +class Embedding(nn.Embedding): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + ): + super(Embedding, self).__init__( + num_embeddings, embedding_dim, padding_idx=padding_idx + ) + self._normal_init() + + if padding_idx is not None: + self.weight.data[self.padding_idx].zero_() + + def _normal_init(self, std=0.02): + nn.init.normal_(self.weight, mean=0.0, std=std) + +class Transition(nn.Module): + def __init__(self, d_in, n, dropout=0.0): + + super(Transition, self).__init__() + + self.d_in = d_in + self.n = n + + self.linear_1 = Linear(self.d_in, self.n * self.d_in, init="relu") + self.act = nn.GELU() + self.linear_2 = Linear(self.n * self.d_in, d_in, init="final") + self.dropout = dropout + + def _transition(self, x): + x = self.linear_1(x) + x = self.act(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.linear_2(x) + return x + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + + x = self._transition(x=x) + return x + + +class Attention(nn.Module): + def __init__( + self, + q_dim: int, + k_dim: int, + v_dim: int, + pair_dim: int, + head_dim: int, + num_heads: int, + gating: bool = False, + dropout: float = 0.0, + ): + super(Attention, self).__init__() + + self.num_heads = num_heads + total_dim = head_dim * self.num_heads + self.gating = gating + self.linear_q = Linear(q_dim, total_dim, bias=False, init="glorot") + self.linear_k = Linear(k_dim, total_dim, bias=False, init="glorot") + self.linear_v = Linear(v_dim, total_dim, bias=False, init="glorot") + self.linear_o = Linear(total_dim, q_dim, init="final") + self.linear_g = None + if self.gating: + self.linear_g = Linear(q_dim, total_dim, init="gating") + # precompute the 1/sqrt(head_dim) + self.norm = head_dim**-0.5 + self.dropout = dropout + self.linear_bias = Linear(pair_dim, num_heads) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + pair: torch.Tensor, + mask: torch.Tensor = None, + ) -> torch.Tensor: + g = None + if self.linear_g is not None: + # gating, use raw query input + g = self.linear_g(q) + + q = self.linear_q(q) + q *= self.norm + k = self.linear_k(k) + v = self.linear_v(v) + + q = q.view(q.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous() + k = k.view(k.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous() + v = v.view(v.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3) + + attn = torch.matmul(q, k.transpose(-1, -2)) + del q, k + bias = self.linear_bias(pair).permute(0, 3, 1, 2).contiguous() + attn = softmax_dropout(attn, self.dropout, self.training, mask=mask, bias=bias) + o = torch.matmul(attn, v) + del attn, v + + o = o.transpose(-2, -3).contiguous() + o = o.view(*o.shape[:-2], -1) + + if g is not None: + o = torch.sigmoid(g) * o + + # merge heads + o = self.linear_o(o) + return o + + +class OuterProduct(nn.Module): + def __init__(self, d_atom, d_pair, d_hid=32): + super(OuterProduct, self).__init__() + + self.d_atom = d_atom + self.d_pair = d_pair + self.d_hid = d_hid + + self.linear_in = nn.Linear(d_atom, d_hid * 2) + self.linear_out = nn.Linear(d_hid**2, d_pair) + self.act = nn.GELU() + self._memory_efficient = True + + def _opm(self, a, b): + bsz, n, d = a.shape + # outer = torch.einsum("...bc,...de->...bdce", a, b) + a = a.view(bsz, n, 1, d, 1) + b = b.view(bsz, 1, n, 1, d) + outer = a * b + outer = outer.view(outer.shape[:-2] + (-1,)) + outer = self.linear_out(outer) + return outer + + def forward( + self, + m: torch.Tensor, + op_mask: Optional[torch.Tensor] = None, + op_norm: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + ab = self.linear_in(m) + ab = ab * op_mask + a, b = ab.chunk(2, dim=-1) + + if self._memory_efficient and torch.is_grad_enabled(): + z = checkpoint(self._opm, a, b, use_reentrant=False) + else: + z = self._opm(a, b) + + z *= op_norm + return z + +class AtomFeature(nn.Module): + """ + Compute atom features for each atom in the molecule. + """ + + def __init__( + self, + num_atom, + num_degree, + hidden_dim, + ): + super(AtomFeature, self).__init__() + self.atom_encoder = Embedding(num_atom, hidden_dim, padding_idx=0) + self.degree_encoder = Embedding(num_degree, hidden_dim, padding_idx=0) + self.vnode_encoder = Embedding(1, hidden_dim) + + def forward(self, batched_data, token_feat): + x, degree = ( + batched_data["atom_feat"], + batched_data["degree"], + ) + n_graph, n_node = x.size()[:2] + + node_feature = self.atom_encoder(x).sum(dim=-2) # [n_graph, n_node, n_hidden] + dtype = node_feature.dtype + degree_feature = self.degree_encoder(degree) + node_feature = node_feature + degree_feature + token_feat + + graph_token_feature = self.vnode_encoder.weight.unsqueeze(0).repeat( + n_graph, 1, 1 + ) + + graph_node_feature = torch.cat([graph_token_feature, node_feature], dim=1) + return graph_node_feature.type(dtype) + + +class EdgeFeature(nn.Module): + """ + Compute attention bias for each head. + """ + + def __init__( + self, + pair_dim, + num_edge, + num_spatial, + ): + super(EdgeFeature, self).__init__() + self.pair_dim = pair_dim + + self.edge_encoder = Embedding(num_edge, pair_dim, padding_idx=0) + self.shorest_path_encoder = Embedding(num_spatial, pair_dim, padding_idx=0) + self.vnode_virtual_distance = Embedding(1, pair_dim) + + def forward(self, batched_data, graph_attn_bias): + shortest_path = batched_data["shortest_path"] + edge_input = batched_data["edge_feat"] + + graph_attn_bias[:, 1:, 1:, :] = self.shorest_path_encoder(shortest_path) + + # reset spatial pos here + t = self.vnode_virtual_distance.weight.view(1, 1, self.pair_dim) + graph_attn_bias[:, 1:, 0, :] = t + graph_attn_bias[:, 0, :, :] = t + + edge_input = self.edge_encoder(edge_input).mean(-2) + graph_attn_bias[:, 1:, 1:, :] = graph_attn_bias[:, 1:, 1:, :] + edge_input + return graph_attn_bias + + +class SE3InvariantKernel(nn.Module): + """ + Compute 3D attention bias according to the position information for each head. + """ + + def __init__( + self, + pair_dim, + num_pair, + num_kernel, + std_width=1.0, + start=0.0, + stop=9.0, + ): + super(SE3InvariantKernel, self).__init__() + self.num_kernel = num_kernel + + self.gaussian = GaussianKernel( + self.num_kernel, + num_pair, + std_width=std_width, + start=start, + stop=stop, + ) + self.out_proj = NonLinear(self.num_kernel, pair_dim) + + def forward(self, dist, node_type_edge): + edge_feature = self.gaussian( + dist, + node_type_edge.long(), + ) + edge_feature = self.out_proj(edge_feature) + + return edge_feature + + +@torch.jit.script +def gaussian(x, mean, std): + pi = 3.14159 + a = (2 * pi) ** 0.5 + return torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std) + + +class GaussianKernel(nn.Module): + def __init__(self, K=128, num_pair=512, std_width=1.0, start=0.0, stop=9.0): + super().__init__() + self.K = K + std_width = std_width + start = start + stop = stop + mean = torch.linspace(start, stop, K) + self.std = (std_width * (mean[1] - mean[0])) + self.register_buffer("mean", mean) + self.mul = Embedding(num_pair, 1, padding_idx=0) + self.bias = Embedding(num_pair, 1, padding_idx=0) + nn.init.constant_(self.bias.weight, 0) + nn.init.constant_(self.mul.weight, 1.0) + + def forward(self, x, atom_pair): + mul = self.mul(atom_pair).abs().sum(dim=-2) + bias = self.bias(atom_pair).sum(dim=-2) + x = mul * x.unsqueeze(-1) + bias + x = x.expand(-1, -1, -1, self.K) + mean = self.mean.float().view(-1) + return gaussian(x.float(), mean, self.std) + + +class NonLinear(nn.Module): + def __init__(self, input, output_size, hidden=None): + super(NonLinear, self).__init__() + + if hidden is None: + hidden = input + self.layer1 = Linear(input, hidden, init="relu") + self.layer2 = Linear(hidden, output_size, init="final") + + def forward(self, x): + x = self.layer1(x) + x = F.gelu(x) + x = self.layer2(x) + return x + + def zero_init(self): + nn.init.zeros_(self.layer2.weight) + nn.init.zeros_(self.layer2.bias) + + +class MovementPredictionHead(nn.Module): + def __init__( + self, + embed_dim: int, + pair_dim: int, + num_head: int, + ): + super().__init__() + self.layer_norm = nn.LayerNorm(embed_dim) + self.embed_dim = embed_dim + self.q_proj = Linear(embed_dim, embed_dim, bias=False, init="glorot") + self.k_proj = Linear(embed_dim, embed_dim, bias=False, init="glorot") + self.v_proj = Linear(embed_dim, embed_dim, bias=False, init="glorot") + self.num_head = num_head + self.scaling = (embed_dim // num_head) ** -0.5 + self.force_proj1 = Linear(embed_dim, 1, init="final", bias=False) + self.linear_bias = Linear(pair_dim, num_head) + self.pair_layer_norm = nn.LayerNorm(pair_dim) + self.dropout = 0.1 + + def zero_init(self): + nn.init.zeros_(self.force_proj1.weight) + + def forward( + self, + query, + pair, + attn_mask, + delta_pos, + ) -> None: + bsz, n_node, _ = query.size() + query = self.layer_norm(query) + q = ( + self.q_proj(query).view(bsz, n_node, self.num_head, -1).transpose(1, 2) + * self.scaling + ) + k = self.k_proj(query).view(bsz, n_node, self.num_head, -1).transpose(1, 2) + v = self.v_proj(query).view(bsz, n_node, self.num_head, -1).transpose(1, 2) + attn = q @ k.transpose(-1, -2) # [bsz, head, n, n] + + pair = self.pair_layer_norm(pair) + bias = self.linear_bias(pair).permute(0, 3, 1, 2).contiguous() + attn_probs = softmax_dropout( + attn, + self.dropout, + self.training, + mask=attn_mask.contiguous(), + bias=bias.contiguous(), + ).view(bsz, self.num_head, n_node, n_node) + rot_attn_probs = attn_probs.unsqueeze(-1) * delta_pos.unsqueeze(1).type_as( + attn_probs + ) # [bsz, head, n, n, 3] + rot_attn_probs = rot_attn_probs.permute(0, 1, 4, 2, 3) + x = rot_attn_probs @ v.unsqueeze(2) # [bsz, head , 3, n, d] + x = x.permute(0, 3, 2, 1, 4).contiguous().view(bsz, n_node, 3, -1) + cur_force = self.force_proj1(x).view(bsz, n_node, 3) + return cur_force + + +class DropPath(torch.nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, prob=None): + super(DropPath, self).__init__() + self.drop_prob = prob + + def forward(self, x): + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + def extra_repr(self) -> str: + return f"prob={self.drop_prob}" + + +class TriangleMultiplication(nn.Module): + def __init__(self, d_pair, d_hid): + super(TriangleMultiplication, self).__init__() + + self.linear_ab_p = Linear(d_pair, d_hid * 2) + self.linear_ab_g = Linear(d_pair, d_hid * 2, init="gating") + + self.linear_g = Linear(d_pair, d_pair, init="gating") + self.linear_z = Linear(d_hid, d_pair, init="final") + + self.layer_norm_out = nn.LayerNorm(d_hid) + + def forward( + self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + mask = mask.unsqueeze(-1) + mask = mask * (mask.shape[-2] ** -0.5) + + g = self.linear_g(z) + if self.training: + ab = self.linear_ab_p(z) * mask * torch.sigmoid(self.linear_ab_g(z)) + else: + ab = self.linear_ab_p(z) + ab *= mask + ab *= torch.sigmoid(self.linear_ab_g(z)) + a, b = torch.chunk(ab, 2, dim=-1) + del z, ab + + a1 = permute_final_dims(a, (2, 0, 1)) + b1 = b.transpose(-1, -3) + x = torch.matmul(a1, b1) + del a1, b1 + b2 = permute_final_dims(b, (2, 0, 1)) + a2 = a.transpose(-1, -3) + x = x + torch.matmul(a2, b2) + del a, b, a2, b2 + + x = permute_final_dims(x, (1, 2, 0)) + + x = self.layer_norm_out(x) + x = self.linear_z(x) + return g * x + +def get_activation_fn(activation): + """ Returns the activation function corresponding to `activation` """ + + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +class TransformerEncoderLayerV2(nn.Module): + """ + Implements a Transformer-M Encoder Layer. + """ + + def __init__( + self, + embedding_dim: int = 768, + pair_dim: int = 64, + pair_hidden_dim: int = 32, + ffn_embedding_dim: int = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + droppath_prob: float = 0.0, + pair_dropout: float = 0.25, + ) -> None: + super().__init__() + + # Initialize parameters + self.embedding_dim = embedding_dim + self.num_attention_heads = num_attention_heads + self.attention_dropout = attention_dropout + + if droppath_prob > 0.0: + self.dropout_module = DropPath(droppath_prob) + else: + self.dropout_module = Dropout(dropout) + + # Initialize blocks + self.activation_fn = get_activation_fn(activation_fn) + head_dim = self.embedding_dim // self.num_attention_heads + self.self_attn = Attention( + self.embedding_dim, + self.embedding_dim, + self.embedding_dim, + pair_dim=pair_dim, + head_dim=head_dim, + num_heads=self.num_attention_heads, + gating=False, + dropout=attention_dropout, + ) + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim) + + self.ffn = Transition( + self.embedding_dim, + ffn_embedding_dim // self.embedding_dim, + dropout=activation_dropout, + ) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = nn.LayerNorm(self.embedding_dim) + self.x_layer_norm_opm = nn.LayerNorm(self.embedding_dim) + + self.opm = OuterProduct(self.embedding_dim, pair_dim, d_hid=pair_hidden_dim) + # self.pair_layer_norm_opm = nn.LayerNorm(pair_dim) + + self.pair_layer_norm_ffn = nn.LayerNorm(pair_dim) + self.pair_ffn = Transition( + pair_dim, + 1, + dropout=activation_dropout, + ) + + self.pair_dropout = pair_dropout + self.pair_layer_norm_trimul = nn.LayerNorm(pair_dim) + self.pair_tri_mul = TriangleMultiplication(pair_dim, pair_hidden_dim) + + def shared_dropout(self, x, shared_dim, dropout): + shape = list(x.shape) + shape[shared_dim] = 1 + with torch.no_grad(): + mask = x.new_ones(shape) + return F.dropout(mask, p=dropout, training=self.training) * x + + def forward( + self, + x: torch.Tensor, + pair: torch.Tensor, + pair_mask: torch.Tensor, + self_attn_mask: Optional[torch.Tensor] = None, + op_mask: Optional[torch.Tensor] = None, + op_norm: Optional[torch.Tensor] = None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer implementation. + """ + residual = x + x = self.self_attn_layer_norm(x) + x = self.self_attn( + x, + x, + x, + pair=pair, + mask=self_attn_mask, + ) + x = self.dropout_module(x) + x = residual + x + + x = x + self.dropout_module(self.ffn(self.final_layer_norm(x))) + + pair = pair + self.dropout_module( + self.opm(self.x_layer_norm_opm(x), op_mask, op_norm) + ) + + # trimul + pair = pair + self.shared_dropout( + self.pair_tri_mul(self.pair_layer_norm_trimul(pair), pair_mask), + -3, + self.pair_dropout, + ) + + # ffn + pair = pair + self.dropout_module(self.pair_ffn(self.pair_layer_norm_ffn(pair))) + + return x, pair + + +class TransformerEncoderWithPairV2(nn.Module): + def __init__( + self, + num_encoder_layers: int = 6, + embedding_dim: int = 768, + + pair_dim: int = 64, + pair_hidden_dim: int = 32, + + ffn_embedding_dim: int = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.0, + activation_fn: str = "gelu", + droppath_prob: float = 0.0, + pair_dropout: float = 0.25, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.num_head = num_attention_heads + self.layer_norm = nn.LayerNorm(embedding_dim) + self.pair_layer_norm = nn.LayerNorm(pair_dim) + self.layers = nn.ModuleList([]) + + if droppath_prob > 0: + droppath_probs = [ + x.item() for x in torch.linspace(0, droppath_prob, num_encoder_layers) + ] + else: + droppath_probs = None + + self.layers.extend( + [ + TransformerEncoderLayerV2( + embedding_dim=embedding_dim, + pair_dim=pair_dim, + pair_hidden_dim=pair_hidden_dim, + ffn_embedding_dim=ffn_embedding_dim, + num_attention_heads=num_attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + droppath_prob=droppath_probs[i] + if droppath_probs is not None + else 0, + pair_dropout=pair_dropout, + ) + for i in range(num_encoder_layers) + ] + ) + + def forward( + self, + x, + pair, + atom_mask, + pair_mask, + attn_mask=None, + ) -> None: + + x = self.layer_norm(x) + pair = self.pair_layer_norm(pair) + op_mask = atom_mask.unsqueeze(-1) + op_mask = op_mask * (op_mask.size(-2) ** -0.5) + eps = 1e-3 + op_norm = 1.0 / (eps + torch.einsum("...bc,...dc->...bdc", op_mask, op_mask)) + for layer in self.layers: + x, pair = layer( + x, + pair, + pair_mask=pair_mask, + self_attn_mask=attn_mask, + op_mask=op_mask, + op_norm=op_norm, + ) + return x, pair \ No newline at end of file diff --git a/unimol_tools/unimol_tools/models/unimol.py b/unimol_tools/unimol_tools/models/unimol.py index 6617a35..e1dd8b0 100644 --- a/unimol_tools/unimol_tools/models/unimol.py +++ b/unimol_tools/unimol_tools/models/unimol.py @@ -180,33 +180,34 @@ def get_dist_features(dist, et): cls_repr = encoder_rep[:, 0, :] # CLS token repr all_repr = encoder_rep[:, :, :] # all token repr - filtered_tensors = [] - filtered_coords = [] - for tokens, coord in zip(src_tokens, src_coord): - filtered_tensor = tokens[(tokens != 0) & (tokens != 1) & (tokens != 2)] # filter out BOS(0), EOS(1), PAD(2) - filtered_coord = coord[(tokens != 0) & (tokens != 1) & (tokens != 2)] - filtered_tensors.append(filtered_tensor) - filtered_coords.append(filtered_coord) - - lengths = [len(filtered_tensor) for filtered_tensor in filtered_tensors] # Compute the lengths of the filtered tensors - if return_repr and return_atomic_reprs: - cls_atomic_reprs = [] - atomic_symbols = [] - for i in range(len(all_repr)): - atomic_reprs = encoder_rep[i, 1:lengths[i]+1, :] - atomic_symbol = [] - for atomic_num in filtered_tensors[i]: - atomic_symbol.append(self.dictionary.symbols[atomic_num]) - atomic_symbols.append(atomic_symbol) - cls_atomic_reprs.append(atomic_reprs) - return { - 'cls_repr': cls_repr, - 'atomic_symbol': atomic_symbols, - 'atomic_coords': filtered_coords, - 'atomic_reprs': cls_atomic_reprs - } - if return_repr and not return_atomic_reprs: - return {'cls_repr': cls_repr} + if return_repr: + filtered_tensors = [] + filtered_coords = [] + for tokens, coord in zip(src_tokens, src_coord): + filtered_tensor = tokens[(tokens != 0) & (tokens != 1) & (tokens != 2)] # filter out BOS(0), EOS(1), PAD(2) + filtered_coord = coord[(tokens != 0) & (tokens != 1) & (tokens != 2)] + filtered_tensors.append(filtered_tensor) + filtered_coords.append(filtered_coord) + + lengths = [len(filtered_tensor) for filtered_tensor in filtered_tensors] # Compute the lengths of the filtered tensors + if return_atomic_reprs: + cls_atomic_reprs = [] + atomic_symbols = [] + for i in range(len(all_repr)): + atomic_reprs = encoder_rep[i, 1:lengths[i]+1, :] + atomic_symbol = [] + for atomic_num in filtered_tensors[i]: + atomic_symbol.append(self.dictionary.symbols[atomic_num]) + atomic_symbols.append(atomic_symbol) + cls_atomic_reprs.append(atomic_reprs) + return { + 'cls_repr': cls_repr, + 'atomic_symbol': atomic_symbols, + 'atomic_coords': filtered_coords, + 'atomic_reprs': cls_atomic_reprs + } + else: + return {'cls_repr': cls_repr} logits = self.classification_head(cls_repr) return logits diff --git a/unimol_tools/unimol_tools/models/unimolv2.py b/unimol_tools/unimol_tools/models/unimolv2.py new file mode 100644 index 0000000..4fd82d4 --- /dev/null +++ b/unimol_tools/unimol_tools/models/unimolv2.py @@ -0,0 +1,591 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .transformersv2 import TransformerEncoderWithPairV2 +from ..utils import pad_1d_tokens, pad_2d, pad_coords +import argparse +import pathlib +import os + +from .transformersv2 import AtomFeature, EdgeFeature, SE3InvariantKernel, MovementPredictionHead +from ..utils import logger +from ..config import MODEL_CONFIG_V2 +from ..data import Dictionary +from ..weights import weight_download_v2, WEIGHT_DIR + +BACKBONE = { + 'transformer': TransformerEncoderWithPairV2, +} + +class UniMolV2Model(nn.Module): + """ + UniMolModel is a specialized model for molecular, protein, crystal, or MOF (Metal-Organic Frameworks) data. + It dynamically configures its architecture based on the type of data it is intended to work with. The model + supports multiple data types and incorporates various architecture configurations and pretrained weights. + + Attributes: + - output_dim: The dimension of the output layer. + - data_type: The type of data the model is designed to handle. + - remove_hs: Flag to indicate whether hydrogen atoms are removed in molecular data. + - pretrain_path: Path to the pretrained model weights. + - dictionary: The dictionary object used for tokenization and encoding. + - mask_idx: Index of the mask token in the dictionary. + - padding_idx: Index of the padding token in the dictionary. + - embed_tokens: Embedding layer for token embeddings. + - encoder: Transformer encoder backbone of the model. + - gbf_proj, gbf: Layers for Gaussian basis functions or numerical embeddings. + - classification_head: The final classification head of the model. + """ + def __init__(self, output_dim=2, model_size='84m', **params): + """ + Initializes the UniMolModel with specified parameters and data type. + + :param output_dim: (int) The number of output dimensions (classes). + :param data_type: (str) The type of data (e.g., 'molecule', 'protein'). + :param params: Additional parameters for model configuration. + """ + super().__init__() + + self.args = molecule_architecture(model_size=model_size) + self.output_dim = output_dim + self.model_size = model_size + self.remove_hs = params.get('remove_hs', False) + + name = model_size + if not os.path.exists(os.path.join(WEIGHT_DIR, MODEL_CONFIG_V2['weight'][name])): + weight_download_v2(MODEL_CONFIG_V2['weight'][name], WEIGHT_DIR) + + self.pretrain_path = os.path.join(WEIGHT_DIR, MODEL_CONFIG_V2['weight'][name]) + + self.token_num = 128 + self.padding_idx = 0 + self.mask_idx = 127 + self.embed_tokens = nn.Embedding( + self.token_num, self.args.encoder_embed_dim, self.padding_idx + ) + + self.encoder = BACKBONE[self.args.backbone]( + num_encoder_layers = self.args.num_encoder_layers, + embedding_dim = self.args.encoder_embed_dim, + + pair_dim = self.args.pair_embed_dim, + pair_hidden_dim = self.args.pair_hidden_dim, + + ffn_embedding_dim = self.args.ffn_embedding_dim, + num_attention_heads = self.args.num_attention_heads, + dropout = self.args.dropout, + attention_dropout = self.args.attention_dropout, + activation_dropout = self.args.activation_dropout, + activation_fn = self.args.activation_fn, + droppath_prob = self.args.droppath_prob, + pair_dropout = self.args.pair_dropout, + ) + + num_atom = 512 + num_degree = 128 + num_edge = 64 + num_pair = 512 + num_spatial = 512 + + K = 128 + n_edge_type = 1 + + self.atom_feature = AtomFeature( + num_atom=num_atom, + num_degree=num_degree, + hidden_dim=self.args.encoder_embed_dim, + ) + + self.edge_feature = EdgeFeature( + pair_dim=self.args.pair_embed_dim, + num_edge=num_edge, + num_spatial=num_spatial, + ) + + + self.se3_invariant_kernel = SE3InvariantKernel( + pair_dim=self.args.pair_embed_dim, + num_pair=num_pair, + num_kernel=K, + std_width=self.args.gaussian_std_width, + start=self.args.gaussian_mean_start, + stop=self.args.gaussian_mean_stop, + ) + + self.movement_pred_head = MovementPredictionHead( + self.args.encoder_embed_dim, self.args.pair_embed_dim, self.args.encoder_attention_heads + ) + + self.classification_heads = nn.ModuleDict() + self.dtype = torch.float32 + + self.classification_head = ClassificationHead( + input_dim=self.args.encoder_embed_dim, + inner_dim=self.args.encoder_embed_dim, + num_classes=self.output_dim, + activation_fn=self.args.pooler_activation_fn, + pooler_dropout=self.args.pooler_dropout, + ) + self.load_pretrained_weights(path=self.pretrain_path) + + def load_pretrained_weights(self, path): + """ + Loads pretrained weights into the model. + + :param path: (str) Path to the pretrained weight file. + """ + if path is not None: + logger.info("Loading pretrained weights from {}".format(path)) + state_dict = torch.load(path, map_location=lambda storage, loc: storage) + self.load_state_dict(state_dict['model'], strict=False) + + @classmethod + def build_model(cls, args): + """ + Class method to build a new instance of the UniMolModel. + + :param args: Arguments for model configuration. + :return: An instance of UniMolModel. + """ + return cls(args) +#'atom_feat', 'atom_mask', 'edge_feat', 'shortest_path', 'degree', 'pair_type', 'attn_bias', 'src_tokens' + def forward( + self, + atom_feat, + atom_mask, + edge_feat, + shortest_path, + degree, + pair_type, + attn_bias, + src_tokens, + src_pos, + return_repr=False, + return_atomic_reprs=False, + **kwargs + ): + + + pos = src_pos + + n_mol, n_atom = atom_feat.shape[:2] + token_feat = self.embed_tokens(src_tokens) + x = self.atom_feature({'atom_feat': atom_feat, 'degree': degree}, token_feat) + + dtype = self.dtype + + x = x.type(dtype) + + attn_mask = attn_bias.clone() + attn_bias = torch.zeros_like(attn_mask) + attn_mask = attn_mask.unsqueeze(1).repeat(1, self.args.encoder_attention_heads, 1, 1) + attn_bias = attn_bias.unsqueeze(-1).repeat(1, 1, 1, self.args.pair_embed_dim) + attn_bias = self.edge_feature({'shortest_path':shortest_path, 'edge_feat': edge_feat}, attn_bias) + attn_mask = attn_mask.type(self.dtype) + + atom_mask_cls = torch.cat( + [ + torch.ones(n_mol, 1, device=atom_mask.device, dtype=atom_mask.dtype), + atom_mask, + ], + dim=1, + ).type(self.dtype) + + pair_mask = atom_mask_cls.unsqueeze(-1) * atom_mask_cls.unsqueeze(-2) + + def one_block(x, pos, return_x=False): + delta_pos = pos.unsqueeze(1) - pos.unsqueeze(2) + dist = delta_pos.norm(dim=-1) + attn_bias_3d = self.se3_invariant_kernel(dist.detach(), pair_type) + new_attn_bias = attn_bias.clone() + new_attn_bias[:, 1:, 1:, :] = new_attn_bias[:, 1:, 1:, :] + attn_bias_3d + new_attn_bias = new_attn_bias.type(dtype) + x, pair = self.encoder( + x, + new_attn_bias, + atom_mask=atom_mask_cls, + pair_mask=pair_mask, + attn_mask=attn_mask, + ) + node_output = self.movement_pred_head( + x[:, 1:, :], + pair[:, 1:, 1:, :], + attn_mask[:, :, 1:, 1:], + delta_pos.detach(), + ) + if return_x: + return x, pair, pos + node_output + else: + return pos + node_output + + x, pair, pos = one_block(x, pos, return_x=True) + cls_repr = x[:, 0, :] # CLS token repr + all_repr = x[:, :, :] # all token repr + + if return_repr: + filtered_tensors = [] + filtered_coords = [] + + for tokens, coord in zip(src_tokens, src_pos): + filtered_tensor = tokens[(tokens != 0) & (tokens != 1) & (tokens != 2)] # filter out BOS(0), EOS(1), PAD(2) + filtered_coord = coord[(tokens != 0) & (tokens != 1) & (tokens != 2)] + filtered_tensors.append(filtered_tensor) + filtered_coords.append(filtered_coord) + + lengths = [len(filtered_tensor) for filtered_tensor in filtered_tensors] # Compute the lengths of the filtered tensors + if return_atomic_reprs: + cls_atomic_reprs = [] + atomic_symbols = [] + for i in range(len(all_repr)): + atomic_reprs = x[i, 1:lengths[i]+1, :] + atomic_symbol = filtered_tensors[i] + atomic_symbols.append(atomic_symbol) + cls_atomic_reprs.append(atomic_reprs) + return { + 'cls_repr': cls_repr, + 'atomic_symbol': atomic_symbols, + 'atomic_coords': filtered_coords, + 'atomic_reprs': cls_atomic_reprs + } + else: + return {'cls_repr': cls_repr} + + logits = self.classification_head(cls_repr) + return logits + + + def register_classification_head( + self, name, num_classes=None, inner_dim=None, **kwargs + ): + """Register a classification head.""" + if name in self.classification_heads: + prev_num_classes = self.classification_heads[name].out_proj.out_features + prev_inner_dim = self.classification_heads[name].dense.out_features + if num_classes != prev_num_classes or inner_dim != prev_inner_dim: + logger.warning( + 're-registering head "{}" with num_classes {} (prev: {}) ' + "and inner_dim {} (prev: {})".format( + name, num_classes, prev_num_classes, inner_dim, prev_inner_dim + ) + ) + self.classification_heads[name] = ClassificationHead( + input_dim=self.args.encoder_embed_dim, + inner_dim=inner_dim or self.args.encoder_embed_dim, + num_classes=num_classes, + activation_fn=self.args.pooler_activation_fn, + pooler_dropout=self.args.pooler_dropout, + ) + + def set_num_updates(self, num_updates): + """State from trainer to pass along to model at every update.""" + self._num_updates = num_updates + + def get_num_updates(self): + return self._num_updates + + def batch_collate_fn(self, samples): + """ + Custom collate function for batch processing non-MOF data. + + :param samples: A list of sample data. + + :return: A tuple containing a batch dictionary and labels. + """ + batch = {} + for k in samples[0][0].keys(): + if k == 'atom_feat': + v = pad_coords([s[0][k] for s in samples], pad_idx=self.padding_idx, dim=8) + elif k == 'atom_mask': + v = pad_1d_tokens([s[0][k] for s in samples], pad_idx=self.padding_idx) + elif k == 'edge_feat': + v = pad_2d([s[0][k] for s in samples], pad_idx=self.padding_idx, dim=3) + elif k == 'shortest_path': + v = pad_2d([s[0][k] for s in samples], pad_idx=self.padding_idx) + elif k == 'degree': + v = pad_1d_tokens([s[0][k] for s in samples], pad_idx=self.padding_idx) + elif k == 'pair_type': + v = pad_2d([s[0][k] for s in samples], pad_idx=self.padding_idx, dim=2) + elif k == 'attn_bias': + v = pad_2d([s[0][k] for s in samples], pad_idx=self.padding_idx) + elif k == 'src_tokens': + v = pad_1d_tokens([s[0][k] for s in samples], pad_idx=self.padding_idx) + elif k == 'src_pos': + v = pad_coords([s[0][k] for s in samples], pad_idx=self.padding_idx) + batch[k] = v + try: + label = torch.tensor([s[1] for s in samples]) + except: + label = None + return batch, label + +class ClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim, + inner_dim, + num_classes, + activation_fn, + pooler_dropout, + ): + """ + Initialize the classification head. + + :param input_dim: Dimension of input features. + :param inner_dim: Dimension of the inner layer. + :param num_classes: Number of classes for classification. + :param activation_fn: Activation function name. + :param pooler_dropout: Dropout rate for the pooling layer. + """ + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.activation_fn = get_activation_fn(activation_fn) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, features, **kwargs): + """ + Forward pass for the classification head. + + :param features: Input features for classification. + + :return: Output from the classification head. + """ + x = features + x = self.dropout(x) + x = self.dense(x) + x = self.activation_fn(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + +class NonLinearHead(nn.Module): + """ + A neural network module used for simple classification tasks. It consists of a two-layered linear network + with a nonlinear activation function in between. + + Attributes: + - linear1: The first linear layer. + - linear2: The second linear layer that outputs to the desired dimensions. + - activation_fn: The nonlinear activation function. + """ + def __init__( + self, + input_dim, + out_dim, + activation_fn, + hidden=None, + ): + """ + Initializes the NonLinearHead module. + + :param input_dim: Dimension of the input features. + :param out_dim: Dimension of the output. + :param activation_fn: The activation function to use. + :param hidden: Dimension of the hidden layer; defaults to the same as input_dim if not provided. + """ + super().__init__() + hidden = input_dim if not hidden else hidden + self.linear1 = nn.Linear(input_dim, hidden) + self.linear2 = nn.Linear(hidden, out_dim) + self.activation_fn = get_activation_fn(activation_fn) + + def forward(self, x): + """ + Forward pass of the NonLinearHead. + + :param x: Input tensor to the module. + + :return: Tensor after passing through the network. + """ + x = self.linear1(x) + x = self.activation_fn(x) + x = self.linear2(x) + return x + +@torch.jit.script +def gaussian(x, mean, std): + """ + Gaussian function implemented for PyTorch tensors. + + :param x: The input tensor. + :param mean: The mean for the Gaussian function. + :param std: The standard deviation for the Gaussian function. + + :return: The output tensor after applying the Gaussian function. + """ + pi = 3.14159 + a = (2 * pi) ** 0.5 + return torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std) + +def get_activation_fn(activation): + """ Returns the activation function corresponding to `activation` """ + + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + +class GaussianLayer(nn.Module): + """ + A neural network module implementing a Gaussian layer, useful in graph neural networks. + + Attributes: + - K: Number of Gaussian kernels. + - means, stds: Embeddings for the means and standard deviations of the Gaussian kernels. + - mul, bias: Embeddings for scaling and bias parameters. + """ + def __init__(self, K=128, edge_types=1024): + """ + Initializes the GaussianLayer module. + + :param K: Number of Gaussian kernels. + :param edge_types: Number of different edge types to consider. + + :return: An instance of the configured Gaussian kernel and edge types. + """ + super().__init__() + self.K = K + self.means = nn.Embedding(1, K) + self.stds = nn.Embedding(1, K) + self.mul = nn.Embedding(edge_types, 1) + self.bias = nn.Embedding(edge_types, 1) + nn.init.uniform_(self.means.weight, 0, 3) + nn.init.uniform_(self.stds.weight, 0, 3) + nn.init.constant_(self.bias.weight, 0) + nn.init.constant_(self.mul.weight, 1) + + def forward(self, x, edge_type): + """ + Forward pass of the GaussianLayer. + + :param x: Input tensor representing distances or other features. + :param edge_type: Tensor indicating types of edges in the graph. + + :return: Tensor transformed by the Gaussian layer. + """ + mul = self.mul(edge_type).type_as(x) + bias = self.bias(edge_type).type_as(x) + x = mul * x.unsqueeze(-1) + bias + x = x.expand(-1, -1, -1, self.K) + mean = self.means.weight.float().view(-1) + std = self.stds.weight.float().view(-1).abs() + 1e-5 + return gaussian(x.float(), mean, std).type_as(self.means.weight) + +class NumericalEmbed(nn.Module): + """ + Numerical embedding module, typically used for embedding edge features in graph neural networks. + + Attributes: + - K: Output dimension for embeddings. + - mul, bias, w_edge: Embeddings for transformation parameters. + - proj: Projection layer to transform inputs. + - ln: Layer normalization. + """ + def __init__(self, K=128, edge_types=1024, activation_fn='gelu'): + """ + Initializes the NonLinearHead. + + :param input_dim: The input dimension of the first layer. + :param out_dim: The output dimension of the second layer. + :param activation_fn: The activation function to use. + :param hidden: The dimension of the hidden layer; defaults to input_dim if not specified. + """ + super().__init__() + self.K = K + self.mul = nn.Embedding(edge_types, 1) + self.bias = nn.Embedding(edge_types, 1) + self.w_edge = nn.Embedding(edge_types, K) + + self.proj = NonLinearHead(1, K, activation_fn, hidden=2*K) + self.ln = nn.LayerNorm(K) + + nn.init.constant_(self.bias.weight, 0) + nn.init.constant_(self.mul.weight, 1) + nn.init.kaiming_normal_(self.w_edge.weight) + + + def forward(self, x, edge_type): # edge_type, atoms + """ + Forward pass of the NonLinearHead. + + :param x: Input tensor to the classification head. + + :return: The output tensor after passing through the layers. + """ + mul = self.mul(edge_type).type_as(x) + bias = self.bias(edge_type).type_as(x) + w_edge = self.w_edge(edge_type).type_as(x) + edge_emb = w_edge * torch.sigmoid(mul * x.unsqueeze(-1) + bias) + + edge_proj = x.unsqueeze(-1).type_as(self.mul.weight) + edge_proj = self.proj(edge_proj) + edge_proj = self.ln(edge_proj) + + h = edge_proj + edge_emb + h = h.type_as(self.mul.weight) + return h + +def molecule_architecture(model_size='84m'): + args = argparse.ArgumentParser() + if model_size == '84m': + args.num_encoder_layers = getattr(args, "num_encoder_layers", 12) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.num_attention_heads = getattr(args, "num_attention_heads", 48) + args.ffn_embedding_dim = getattr(args, "ffn_embedding_dim", 768) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 48) + elif model_size == '164m': + args.num_encoder_layers = getattr(args, "num_encoder_layers", 24) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.num_attention_heads = getattr(args, "num_attention_heads", 48) + args.ffn_embedding_dim = getattr(args, "ffn_embedding_dim", 768) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 48) + elif model_size == '310m': + args.num_encoder_layers = getattr(args, "num_encoder_layers", 32) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.num_attention_heads = getattr(args, "num_attention_heads", 64) + args.ffn_embedding_dim = getattr(args, "ffn_embedding_dim", 1024) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 64) + elif model_size == '570m': + args.num_encoder_layers = getattr(args, "num_encoder_layers", 32) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1536) + args.num_attention_heads = getattr(args, "num_attention_heads", 96) + args.ffn_embedding_dim = getattr(args, "ffn_embedding_dim", 1536) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 96) + elif model_size == '1.1B': + args.num_encoder_layers = getattr(args, "num_encoder_layers", 64) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1536) + args.num_attention_heads = getattr(args, "num_attention_heads", 96) + args.ffn_embedding_dim = getattr(args, "ffn_embedding_dim", 1536) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 96) + else: + raise ValueError('Current not support data type: {}'.format(model_size)) + args.pair_embed_dim = getattr(args, "pair_embed_dim", 512) + args.pair_hidden_dim = getattr(args, "pair_hidden_dim", 64) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.droppath_prob = getattr(args, "droppath_prob", 0.0) + args.pair_dropout = getattr(args, "pair_dropout", 0.25) + args.backbone = getattr(args, "backbone", "transformer") + args.gaussian_std_width = getattr(args, "gaussian_std_width", 1.0) + args.gaussian_mean_start = getattr(args, "gaussian_mean_start", 0.0) + args.gaussian_mean_stop = getattr(args, "gaussian_mean_stop", 9.0) + args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + return args + diff --git a/unimol_tools/unimol_tools/predictor.py b/unimol_tools/unimol_tools/predictor.py index 495625d..4529185 100644 --- a/unimol_tools/unimol_tools/predictor.py +++ b/unimol_tools/unimol_tools/predictor.py @@ -8,7 +8,7 @@ import torch from torch.utils.data import Dataset from .data import DataHub -from .models import UniMolModel +from .models import UniMolModel, UniMolV2Model from .tasks import Trainer class MolDataset(Dataset): @@ -32,6 +32,8 @@ class UniMolRepr(object): """ def __init__(self, data_type='molecule', remove_hs=False, + model_name='unimolv1', + model_size='84m', use_gpu=True): """ Initialize a :class:`UniMolRepr` class. @@ -39,11 +41,23 @@ def __init__(self, data_type='molecule', :param data_type: str, default='molecule', currently support molecule, oled. :param remove_hs: bool, default=False, whether to remove hydrogens in molecular. :param use_gpu: bool, default=True, whether to use gpu. + :param model_name: str, default='unimolv1', currently support unimolv1, unimolv2. + :param model_size: str, default='84m', model size of unimolv2. """ self.device = torch.device("cuda:0" if torch.cuda.is_available() and use_gpu else "cpu") - self.model = UniMolModel(output_dim=1, data_type=data_type, remove_hs=remove_hs).to(self.device) + if model_name == 'unimolv1': + self.model = UniMolModel(output_dim=1, data_type=data_type, remove_hs=remove_hs).to(self.device) + elif model_name == 'unimolv2': + self.model = UniMolV2Model(output_dim=1, model_size=model_size).to(self.device) + else: + raise ValueError('Unknown model name: {}'.format(model_name)) self.model.eval() - self.params = {'data_type': data_type, 'remove_hs': remove_hs} + self.params = { + 'data_type': data_type, + 'remove_hs': remove_hs, + 'model_name': model_name, + 'model_size': model_size, + } def get_repr(self, data=None, return_atomic_reprs=False): """ @@ -65,15 +79,17 @@ def get_repr(self, data=None, return_atomic_reprs=False): if isinstance(data, str): # single smiles string. data = [data] + data = np.array(data) elif isinstance(data, dict): # custom conformers, should take atoms and coordinates as input. assert 'atoms' in data and 'coordinates' in data elif isinstance(data, list): # list of smiles strings. assert isinstance(data[-1], str) + data = np.array(data) else: raise ValueError('Unknown data type: {}'.format(type(data))) - data = np.array(data) + datahub = DataHub(data=data, task='repr', is_train=False, diff --git a/unimol_tools/unimol_tools/train.py b/unimol_tools/unimol_tools/train.py index 57f3c1f..2bf1a4c 100644 --- a/unimol_tools/unimol_tools/train.py +++ b/unimol_tools/unimol_tools/train.py @@ -43,6 +43,8 @@ def __init__(self, use_amp=True, freeze_layers=None, freeze_layers_reversed=False, + model_name='unimolv1', + model_size='84m', **params, ): """ @@ -85,6 +87,8 @@ def __init__(self, :param freeze_layers: str or list, frozen layers by startwith name list. ['encoder', 'gbf'] will freeze all the layers whose name start with 'encoder' or 'gbf'. :param freeze_layers_reversed: bool, default=False, inverse selection of frozen layers :param params: dict, default=None, other parameters. + :param model_name: str, default='unimolv1', currently support unimolv1, unimolv2. + :param model_size: str, default='84m', model size. work when model_name is unimolv2. avaliable: 84m, 164m, 310m, 570m, 1.1B. """ config_path = os.path.join(os.path.dirname(__file__), 'config/default.yaml') @@ -111,6 +115,8 @@ def __init__(self, config.use_amp = use_amp config.freeze_layers = freeze_layers config.freeze_layers_reversed = freeze_layers_reversed + config.model_name = model_name + config.model_size = model_size self.save_path = save_path self.config = config diff --git a/unimol_tools/unimol_tools/utils/util.py b/unimol_tools/unimol_tools/utils/util.py index ab0843f..3c676fc 100644 --- a/unimol_tools/unimol_tools/utils/util.py +++ b/unimol_tools/unimol_tools/utils/util.py @@ -41,6 +41,7 @@ def copy_tensor(src, dst): def pad_2d( values, pad_idx, + dim=1, left_pad=False, pad_to_length=None, pad_to_multiple=1, @@ -61,7 +62,10 @@ def pad_2d( size = size if pad_to_length is None else max(size, pad_to_length) if pad_to_multiple != 1 and size % pad_to_multiple != 0: size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) - res = values[0].new(len(values), size, size).fill_(pad_idx) + if dim == 1: + res = values[0].new(len(values), size, size).fill_(pad_idx) + else: + res = values[0].new(len(values), size, size, dim).fill_(pad_idx) def copy_tensor(src, dst): assert dst.numel() == src.numel() @@ -75,6 +79,7 @@ def copy_tensor(src, dst): def pad_coords( values, pad_idx, + dim=3, left_pad=False, pad_to_length=None, pad_to_multiple=1, @@ -94,7 +99,7 @@ def pad_coords( size = size if pad_to_length is None else max(size, pad_to_length) if pad_to_multiple != 1 and size % pad_to_multiple != 0: size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) - res = values[0].new(len(values), size, 3).fill_(pad_idx) + res = values[0].new(len(values), size, dim).fill_(pad_idx) def copy_tensor(src, dst): assert dst.numel() == src.numel() diff --git a/unimol_tools/unimol_tools/weights/__init__.py b/unimol_tools/unimol_tools/weights/__init__.py index d54a3e6..25d2fff 100644 --- a/unimol_tools/unimol_tools/weights/__init__.py +++ b/unimol_tools/unimol_tools/weights/__init__.py @@ -1 +1 @@ -from .weighthub import weight_download, WEIGHT_DIR \ No newline at end of file +from .weighthub import weight_download, weight_download_v2, WEIGHT_DIR \ No newline at end of file diff --git a/unimol_tools/unimol_tools/weights/weighthub.py b/unimol_tools/unimol_tools/weights/weighthub.py index e2088d2..e236e04 100644 --- a/unimol_tools/unimol_tools/weights/weighthub.py +++ b/unimol_tools/unimol_tools/weights/weighthub.py @@ -39,6 +39,27 @@ def weight_download(pretrain, save_path, local_dir_use_symlinks=True): #max_workers=8 ) +def weight_download_v2(pretrain, save_path, local_dir_use_symlinks=True): + """ + Downloads the specified pretrained model weights. + + :param pretrain: (str), The name of the pretrained model to download. + :param save_path: (str), The directory where the weights should be saved. + :param local_dir_use_symlinks: (bool, optional), Whether to use symlinks for the local directory. Defaults to True. + """ + if os.path.exists(os.path.join(save_path, pretrain)): + logger.info(f'{pretrain} exists in {save_path}') + return + + logger.info(f'Downloading {pretrain}') + snapshot_download( + repo_id="dptech/Uni-Mol2", + local_dir=save_path, + allow_patterns=pretrain, + local_dir_use_symlinks=local_dir_use_symlinks, + #max_workers=8 + ) + # Download all the weights when this script is run def download_all_weights(local_dir_use_symlinks=False): """