Yn update (#306)
* add continuing training from a specific dir [load_model_dir]

* Update model loading to handle different output dimensions in retrain (transfer learning)

* update docs

* update name format from unimol to Uni-Mol

* update URL from dptech-core to deepmodeling

* update version setup

* update unimol v2 docs

* update split methods

* update split method: group split; kfold=1 for all training

* merge main

* update train docs

* Fix: unimol_tools using unimolv2 sometimes hangs in multiprocessing

* update version 0.1.2

* update log for generating conformers

* [update] unimolv2 features calculated using numpy.ndarray instead of torch.tensor

* [update] conformer details for UnimolV1

* [Update] Windows platforms no longer use multiprocessing by default

* [Fix] Exception when using unimolv2 with atom count > 128

* [Update] update version number and requirements

* [Update] Weight download log info
emotionor authored Dec 26, 2024
1 parent 3669afc commit 50bbad3
Showing 4 changed files with 82 additions and 55 deletions.
3 changes: 2 additions & 1 deletion unimol_tools/setup.py
@@ -5,7 +5,7 @@

 setup(
     name="unimol_tools",
-    version="0.1.2",
+    version="0.1.2.post1",
     description=("unimol_tools is a Python package for property prediciton with Uni-Mol in molecule, materials and protein."),
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',
@@ -28,6 +28,7 @@
         "pyyaml",
         "addict",
         "scikit-learn",
+        "numba",
         "tqdm"],
     python_requires=">=3.6",
     include_package_data=True,
112 changes: 69 additions & 43 deletions unimol_tools/unimol_tools/data/conformer.py
@@ -16,7 +16,6 @@
 from .dictionary import Dictionary
 from multiprocessing import Pool
 from tqdm import tqdm
-import torch
 from numba import njit

 from ..utils import logger
@@ -95,6 +94,12 @@ def _init_features(self, **params):
             weight_download(self.dict_name, WEIGHT_DIR)
         self.dictionary = Dictionary.load(os.path.join(WEIGHT_DIR, self.dict_name))
         self.dictionary.add_symbol("[MASK]", is_special=True)
+        if os.name == 'posix':
+            self.multi_process = params.get('multi_process', True)
+        else:
+            self.multi_process = params.get('multi_process', False)
+        if self.multi_process:
+            logger.warning('Please use "if __name__ == "__main__":" to wrap the main function when using multi_process on Windows.')

     def single_process(self, smiles):
         """
@@ -118,14 +123,27 @@ def transform_raw(self, atoms_list, coordinates_list):
         return inputs

     def transform(self, smiles_list):
-        pool = Pool()
         logger.info('Start generating conformers...')
-        inputs = [item for item in tqdm(pool.imap(self.single_process, smiles_list))]
-        pool.close()
-        failed_cnt = np.mean([(item['src_coord']==0.0).all() for item in inputs])
-        logger.info('Succeeded in generating conformers for {:.2f}% of molecules.'.format((1-failed_cnt)*100))
-        failed_3d_cnt = np.mean([(item['src_coord'][:,2]==0.0).all() for item in inputs])
-        logger.info('Succeeded in generating 3d conformers for {:.2f}% of molecules.'.format((1-failed_3d_cnt)*100))
+        if self.multi_process:
+            pool = Pool(processes=min(8, os.cpu_count()))
+            inputs = [item for item in tqdm(pool.imap(self.single_process, smiles_list))]
+            pool.close()
+        else:
+            inputs = [self.single_process(smiles) for smiles in tqdm(smiles_list)]
+
+        failed_conf = [(item['src_coord']==0.0).all() for item in inputs]
+        logger.info('Succeeded in generating conformers for {:.2f}% of molecules.'.format((1-np.mean(failed_conf))*100))
+        failed_conf_indices = [index for index, value in enumerate(failed_conf) if value]
+        if len(failed_conf_indices) > 0:
+            logger.info('Failed conformers indices: {}'.format(failed_conf_indices))
+            logger.debug('Failed conformers SMILES: {}'.format([smiles_list[index] for index in failed_conf_indices]))
+
+        failed_conf_3d = [(item['src_coord'][:,2]==0.0).all() for item in inputs]
+        logger.info('Succeeded in generating 3d conformers for {:.2f}% of molecules.'.format((1-np.mean(failed_conf_3d))*100))
+        failed_conf_3d_indices = [index for index, value in enumerate(failed_conf_3d) if value]
+        if len(failed_conf_3d_indices) > 0:
+            logger.info('Failed 3d conformers indices: {}'.format(failed_conf_3d_indices))
+            logger.debug('Failed 3d conformers SMILES: {}'.format([smiles_list[index] for index in failed_conf_3d_indices]))
         return inputs
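The two success-rate checks above flag different failures: all-zero coordinates mean conformer generation failed outright, while an all-zero z-column means only a flat 2D fallback was produced. A minimal numpy sketch of the same checks, using made-up coordinate arrays:

```python
import numpy as np

coords_ok = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])  # proper 3D conformer
coords_2d = np.array([[0.1, 0.2, 0.0], [0.4, 0.5, 0.0]])  # 2D fallback: z is all zero
coords_bad = np.zeros((2, 3))                              # generation failed: all zero

for src_coord in (coords_ok, coords_2d, coords_bad):
    failed = (src_coord == 0.0).all()            # counted by failed_conf
    failed_3d = (src_coord[:, 2] == 0.0).all()   # counted by failed_conf_3d
    print(failed, failed_3d)
# -> False False, then False True, then True True
```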


@@ -282,6 +300,12 @@ def _init_features(self, **params):
         self.method = params.get('method', 'rdkit_random')
         self.mode = params.get('mode', 'fast')
         self.remove_hs = params.get('remove_hs', True)
+        if os.name == 'posix':
+            self.multi_process = params.get('multi_process', True)
+        else:
+            self.multi_process = params.get('multi_process', False)
+        if self.multi_process:
+            logger.warning('Please use "if __name__ == "__main__":" to wrap the main function when using multi_process on Windows.')

     def single_process(self, smiles):
         """
@@ -291,7 +315,6 @@ def single_process(self, smiles):
         :return: A unimolecular data representation (dictionary) of the molecule.
         :raises ValueError: If the conformer generation method is unrecognized.
         """
-        torch.set_num_threads(1)
         if self.method == 'rdkit_random':
             mol = inner_smi2coords(smiles, seed=self.seed, mode=self.mode, remove_hs=self.remove_hs, return_mol=True)
             return mol2unimolv2(mol, self.max_atoms, remove_hs=self.remove_hs)
@@ -307,11 +330,13 @@ def transform_raw(self, atoms_list, coordinates_list):
         return inputs

     def transform(self, smiles_list):
-        torch.set_num_threads(1)
-        pool = Pool(processes=min(8, os.cpu_count()))
         logger.info('Start generating conformers...')
-        inputs = [item for item in tqdm(pool.imap(self.single_process, smiles_list))]
-        pool.close()
+        if self.multi_process:
+            pool = Pool(processes=min(8, os.cpu_count()))
+            inputs = [item for item in tqdm(pool.imap(self.single_process, smiles_list))]
+            pool.close()
+        else:
+            inputs = [self.single_process(smiles) for smiles in tqdm(smiles_list)]

         failed_conf = [(item['src_coord']==0.0).all() for item in inputs]
         logger.info('Succeeded in generating conformers for {:.2f}% of molecules.'.format((1-np.mean(failed_conf))*100))
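Since multiprocessing now defaults to off outside POSIX, Windows users who opt back in must guard the entry point as the new warning says. A hedged usage sketch (the import path for `UniMolV2Feature` is assumed from this file's location; the SMILES are illustrative):

```python
from unimol_tools.data.conformer import UniMolV2Feature  # import path assumed

def main():
    # Explicitly opt in; on Windows the default is now multi_process=False.
    feature = UniMolV2Feature(multi_process=True)
    inputs = feature.transform(['CCO', 'c1ccccc1', 'CC(=O)O'])
    print(len(inputs))

if __name__ == '__main__':
    # Required on Windows: spawned workers re-import this module, so an
    # unguarded top-level call would recursively create process pools.
    main()
```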
@@ -358,31 +383,31 @@ def mol2unimolv2(mol, max_atoms=128, remove_hs=True, **params):
     :param mol: (rdkit.Chem.Mol) The molecule object containing atom symbols and coordinates.
     :param max_atoms: (int) The maximum number of atoms to consider for the molecule.
-    :param remove_hs: (bool) Whether to remove hydrogen atoms from the representation.
+    :param remove_hs: (bool) Whether to remove hydrogen atoms from the representation. This must be True for UniMolV2.
     :param params: Additional parameters.
     :return: A batched data containing the molecular representation.
     """

-    mol = AllChem.AddHs(mol, addCoords=True)
-    atoms_h = np.array([atom.GetSymbol() for atom in mol.GetAtoms()])
-    nH_idx = [i for i, atom in enumerate(atoms_h) if atom != 'H']
-    atoms = atoms_h[nH_idx]
-    coordinates_h = mol.GetConformer().GetPositions().astype(np.float32)
-    coordinates = coordinates_h[nH_idx]
+    mol = AllChem.RemoveAllHs(mol)
+    atoms = np.array([atom.GetSymbol() for atom in mol.GetAtoms()])
+    coordinates = mol.GetConformer().GetPositions().astype(np.float32)

     # cropping atoms and coordinates
     if len(atoms) > max_atoms:
-        idx = np.random.choice(len(atoms), max_atoms, replace=False)
-        atoms = atoms[idx]
-        coordinates = coordinates[idx]
+        mask = np.zeros(len(atoms), dtype=bool)
+        mask[:max_atoms] = True
+        np.random.shuffle(mask) # shuffle the mask
+        atoms = atoms[mask]
+        coordinates = coordinates[mask]
+    else:
+        mask = np.ones(len(atoms), dtype=bool)
     # tokens padding
-    src_tokens = torch.tensor([AllChem.GetPeriodicTable().GetAtomicNumber(item) for item in atoms])
-    src_coord = torch.tensor(coordinates)
-    # change AllChem.RemoveHs to AllChem.RemoveAllHs
-    mol = AllChem.RemoveAllHs(mol)
+    src_tokens = [AllChem.GetPeriodicTable().GetAtomicNumber(item) for item in atoms]
+    src_coord = coordinates
     #
     node_attr, edge_index, edge_attr = get_graph(mol)
-    feat = get_graph_features(edge_attr, edge_index, node_attr, drop_feat=0)
+    feat = get_graph_features(edge_attr, edge_index, node_attr, drop_feat=0, mask=mask)
     feat['src_tokens'] = src_tokens
     feat['src_coord'] = src_coord
     return feat
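A usage sketch for `mol2unimolv2` under the new code path: it expects an RDKit mol that already carries a conformer (as `inner_smi2coords(..., return_mol=True)` provides), strips all hydrogens up front, and, past `max_atoms`, keeps a random subset of atoms via the shuffled boolean mask. The embedding calls below are standard RDKit; treat the exact preprocessing as an assumption:

```python
from rdkit import Chem
from rdkit.Chem import AllChem

mol = Chem.MolFromSmiles('CCO')
mol = AllChem.AddHs(mol)
AllChem.EmbedMolecule(mol, randomSeed=42)   # mol2unimolv2 reads mol.GetConformer()

feat = mol2unimolv2(mol, max_atoms=128, remove_hs=True)
print(feat['src_tokens'])        # atomic numbers, now a plain list: [6, 6, 8]
print(feat['src_coord'].shape)   # numpy coordinates, (3, 3) after H removal
```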
@@ -474,7 +499,7 @@ def get_graph(mol):
         edge_attr = np.empty((0, num_bond_features), dtype=np.int32)
     return x, edge_index, edge_attr

-def get_graph_features(edge_attr, edge_index, node_attr, drop_feat):
+def get_graph_features(edge_attr, edge_index, node_attr, drop_feat, mask):
     # atom_feat_sizes = [128] + [16 for _ in range(8)]
     atom_feat_sizes = [16 for _ in range(8)]
     edge_feat_sizes = [16, 16, 16]
@@ -511,22 +536,23 @@ def get_graph_features(edge_attr, edge_index, node_attr, drop_feat):

     # combine, plus 1 for padding
     feat = {}
-    feat["atom_feat"] = torch.from_numpy(atom_feat).long()
-    feat["atom_mask"] = torch.ones(N).long()
-    feat["edge_feat"] = torch.from_numpy(edge_feat).long()
-    feat["shortest_path"] = torch.from_numpy((shortest_path_result)).long()
-    feat["degree"] = torch.from_numpy(degree).long().view(-1)
+    feat["atom_feat"] = atom_feat[mask]
+    feat["atom_mask"] = np.ones(N, dtype=np.int64)[mask]
+    feat["edge_feat"] = edge_feat[mask][:, mask]
+    feat["shortest_path"] = shortest_path_result[mask][:, mask]
+    feat["degree"] = degree.reshape(-1)[mask]
     # pair-type
-    atoms = feat["atom_feat"][..., 0]
-    pair_type = torch.cat(
-        [
-            atoms.view(-1, 1, 1).expand(-1, N, -1),
-            atoms.view(1, -1, 1).expand(N, -1, -1),
-        ],
-        dim=-1,
-    )
+    atoms = atom_feat[..., 0]
+    pair_type = np.concatenate(
+        [
+            np.expand_dims(atoms, axis=(1, 2)).repeat(N, axis=1),
+            np.expand_dims(atoms, axis=(0, 2)).repeat(N, axis=0),
+        ],
+        axis=-1,
+    )
+    pair_type = pair_type[mask][:, mask]
     feat["pair_type"] = convert_to_single_emb(pair_type, [128, 128])
-    feat["attn_bias"] = torch.zeros((N + 1, N + 1), dtype=torch.float32)
+    feat["attn_bias"] = np.zeros((mask.sum() + 1, mask.sum() + 1), dtype=np.float32)
     return feat

 def convert_to_single_emb(x, sizes):
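The body of `convert_to_single_emb` is truncated here; in Graphormer-style feature pipelines a helper with this name offsets each feature column into a disjoint id range so one shared embedding table can serve all columns. A sketch of that presumed behavior (an assumption about the helper, not its actual body):

```python
import numpy as np

def convert_to_single_emb_sketch(x, sizes):
    # Shift column i by the cumulative size of earlier columns, so ids
    # from different columns never collide in a shared embedding table.
    offsets = np.concatenate([[0], np.cumsum(sizes)[:-1]]).astype(x.dtype)
    return x + offsets

pair_type = np.array([[3, 7], [1, 2]])
print(convert_to_single_emb_sketch(pair_type, [128, 128]))
# [[  3 135]
#  [  1 130]]
```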
4 changes: 2 additions & 2 deletions unimol_tools/unimol_tools/data/datahub.py
@@ -87,10 +87,10 @@ def _init_data(self, **params):
             no_h_list = ConformerGen(**params).transform(smiles_list)
         elif params.get('model_name', None) == 'unimolv2':
             if 'atoms' in self.data and 'coordinates' in self.data:
-                no_h_list = UniMolV2Feature().transform_raw(self.data['atoms'], self.data['coordinates'])
+                no_h_list = UniMolV2Feature(**params).transform_raw(self.data['atoms'], self.data['coordinates'])
             else:
                 smiles_list = self.data["smiles"]
-                no_h_list = UniMolV2Feature().transform(smiles_list)
+                no_h_list = UniMolV2Feature(**params).transform(smiles_list)

         self.data['unimol_input'] = no_h_list
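This datahub change matters because `UniMolV2Feature()` previously ignored user params entirely; now the same params dict that selects the model also reaches the featurizer. A hedged sketch of the difference, assuming the constructor accepts arbitrary keyword params as the datahub call implies:

```python
params = {'model_name': 'unimolv2', 'multi_process': False}

# Before: UniMolV2Feature() fell back to its internal defaults, so a
# user-supplied multi_process flag never took effect.
# After: the flag is honored because params are forwarded.
feature = UniMolV2Feature(**params)
inputs = feature.transform(['CCO', 'c1ccccc1'])
```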
18 changes: 9 additions & 9 deletions unimol_tools/unimol_tools/models/unimolv2.py
@@ -300,23 +300,23 @@ def batch_collate_fn(self, samples):
         batch = {}
         for k in samples[0][0].keys():
             if k == 'atom_feat':
-                v = pad_coords([s[0][k] for s in samples], pad_idx=self.padding_idx, dim=8)
+                v = pad_coords([torch.tensor(s[0][k]) for s in samples], pad_idx=self.padding_idx, dim=8)
             elif k == 'atom_mask':
-                v = pad_1d_tokens([s[0][k] for s in samples], pad_idx=self.padding_idx)
+                v = pad_1d_tokens([torch.tensor(s[0][k]) for s in samples], pad_idx=self.padding_idx)
             elif k == 'edge_feat':
-                v = pad_2d([s[0][k] for s in samples], pad_idx=self.padding_idx, dim=3)
+                v = pad_2d([torch.tensor(s[0][k]) for s in samples], pad_idx=self.padding_idx, dim=3)
             elif k == 'shortest_path':
-                v = pad_2d([s[0][k] for s in samples], pad_idx=self.padding_idx)
+                v = pad_2d([torch.tensor(s[0][k]) for s in samples], pad_idx=self.padding_idx)
             elif k == 'degree':
-                v = pad_1d_tokens([s[0][k] for s in samples], pad_idx=self.padding_idx)
+                v = pad_1d_tokens([torch.tensor(s[0][k]) for s in samples], pad_idx=self.padding_idx)
             elif k == 'pair_type':
-                v = pad_2d([s[0][k] for s in samples], pad_idx=self.padding_idx, dim=2)
+                v = pad_2d([torch.tensor(s[0][k]) for s in samples], pad_idx=self.padding_idx, dim=2)
             elif k == 'attn_bias':
-                v = pad_2d([s[0][k] for s in samples], pad_idx=self.padding_idx)
+                v = pad_2d([torch.tensor(s[0][k]) for s in samples], pad_idx=self.padding_idx)
             elif k == 'src_tokens':
-                v = pad_1d_tokens([s[0][k] for s in samples], pad_idx=self.padding_idx)
+                v = pad_1d_tokens([torch.tensor(s[0][k]) for s in samples], pad_idx=self.padding_idx)
             elif k == 'src_coord':
-                v = pad_coords([s[0][k] for s in samples], pad_idx=self.padding_idx)
+                v = pad_coords([torch.tensor(s[0][k]) for s in samples], pad_idx=self.padding_idx)
             batch[k] = v
         try:
             label = torch.tensor([s[1] for s in samples])
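The collate change is mechanical: features now arrive as numpy arrays (or lists) and are wrapped in `torch.tensor` before the library's pad helpers run. A self-contained sketch of the pattern, with a toy pad function standing in for `pad_1d_tokens`:

```python
import numpy as np
import torch

def pad_1d(tensors, pad_idx=0):
    # Toy stand-in for pad_1d_tokens: right-pad 1D tensors to a common length.
    max_len = max(t.size(0) for t in tensors)
    out = torch.full((len(tensors), max_len), pad_idx, dtype=tensors[0].dtype)
    for i, t in enumerate(tensors):
        out[i, :t.size(0)] = t
    return out

src_tokens = [np.array([6, 8, 1]), np.array([6, 6])]   # per-sample numpy features
batch = pad_1d([torch.tensor(s) for s in src_tokens])  # convert, then pad
print(batch)
# tensor([[6, 8, 1],
#         [6, 6, 0]])
```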
