
Commit

Yn update (#304)
* add continuing training from a specific dir [load_model_dir]

* Update model loading to handle different output dimensions in retraining (transfer learning)

* update docs

* update naming format from unimol to Uni-Mol

* update URL from dptech-core to deepmodeling

* update version in setup.py

* update unimol v2 docs

* update split methods

* update split method: group split; kfold=1 for all training

* merge main

* update train docs

* Fix: unimol_tools using unimolv2 sometimes hangs during multiprocessing

* update version 0.1.2

* update logging for conformer generation
emotionor authored Dec 23, 2024
1 parent 90ad6af commit 3669afc
Showing 9 changed files with 117 additions and 33 deletions.
2 changes: 1 addition & 1 deletion unimol_tools/setup.py
@@ -5,7 +5,7 @@

setup(
name="unimol_tools",
version="0.1.1.post1",
version="0.1.2",
description=("unimol_tools is a Python package for property prediction with Uni-Mol in molecule, materials and protein."),
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
27 changes: 20 additions & 7 deletions unimol_tools/unimol_tools/data/conformer.py
@@ -291,6 +291,7 @@ def single_process(self, smiles):
:return: A unimolecular data representation (dictionary) of the molecule.
:raises ValueError: If the conformer generation method is unrecognized.
"""
torch.set_num_threads(1)
if self.method == 'rdkit_random':
mol = inner_smi2coords(smiles, seed=self.seed, mode=self.mode, remove_hs=self.remove_hs, return_mol=True)
return mol2unimolv2(mol, self.max_atoms, remove_hs=self.remove_hs)
@@ -306,14 +307,26 @@ def transform_raw(self, atoms_list, coordinates_list):
return inputs

def transform(self, smiles_list):
pool = Pool()
torch.set_num_threads(1)
pool = Pool(processes=min(8, os.cpu_count()))
logger.info('Start generating conformers...')
inputs = [item for item in tqdm(pool.imap(self.single_process, smiles_list))]
pool.close()
# failed_cnt = np.mean([(item['src_coord']==0.0).all() for item in inputs])
# logger.info('Succeeded in generating conformers for {:.2f}% of molecules.'.format((1-failed_cnt)*100))
# failed_3d_cnt = np.mean([(item['src_coord'][:,2]==0.0).all() for item in inputs])
# logger.info('Succeeded in generating 3d conformers for {:.2f}% of molecules.'.format((1-failed_3d_cnt)*100))

failed_conf = [(item['src_coord']==0.0).all() for item in inputs]
logger.info('Succeeded in generating conformers for {:.2f}% of molecules.'.format((1-np.mean(failed_conf))*100))
failed_conf_indices = [index for index, value in enumerate(failed_conf) if value]
if len(failed_conf_indices) > 0:
logger.info('Failed conformers indices: {}'.format(failed_conf_indices))
logger.debug('Failed conformers SMILES: {}'.format([smiles_list[index] for index in failed_conf_indices]))

failed_conf_3d = [(item['src_coord'][:,2]==0.0).all() for item in inputs]
logger.info('Succeeded in generating 3d conformers for {:.2f}% of molecules.'.format((1-np.mean(failed_conf_3d))*100))
failed_conf_3d_indices = [index for index, value in enumerate(failed_conf_3d) if value]
if len(failed_conf_3d_indices) > 0:
logger.info('Failed 3d conformers indices: {}'.format(failed_conf_3d_indices))
logger.debug('Failed 3d conformers SMILES: {}'.format([smiles_list[index] for index in failed_conf_3d_indices]))

return inputs

def create_mol_from_atoms_and_coords(atoms, coordinates):
@@ -365,13 +378,13 @@ def mol2unimolv2(mol, max_atoms=128, remove_hs=True, **params):
coordinates = coordinates[idx]
# tokens padding
src_tokens = torch.tensor([AllChem.GetPeriodicTable().GetAtomicNumber(item) for item in atoms])
src_pos = torch.tensor(coordinates)
src_coord = torch.tensor(coordinates)
# change AllChem.RemoveHs to AllChem.RemoveAllHs
mol = AllChem.RemoveAllHs(mol)
node_attr, edge_index, edge_attr = get_graph(mol)
feat = get_graph_features(edge_attr, edge_index, node_attr, drop_feat=0)
feat['src_tokens'] = src_tokens
feat['src_pos'] = src_pos
feat['src_coord'] = src_coord
return feat

def safe_index(l, e):
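The hang fix above pins PyTorch to a single intra-op thread in the parent and in every forked worker before the pool starts. A minimal sketch of that pattern, with illustrative names rather than the package's own:

```python
import os
from multiprocessing import Pool

import torch

def worker(x):
    # Keep each forked worker single-threaded; inherited OpenMP/PyTorch
    # thread pools are a common cause of hangs after fork.
    torch.set_num_threads(1)
    return torch.tensor([float(x)]).sum().item()

if __name__ == '__main__':
    torch.set_num_threads(1)  # also cap threads in the parent before forking
    with Pool(processes=min(8, os.cpu_count())) as pool:
        results = list(pool.imap(worker, range(16)))
    print(results[:4])
```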
26 changes: 26 additions & 0 deletions unimol_tools/unimol_tools/data/datahub.py
@@ -7,6 +7,9 @@
from .datareader import MolDataReader
from .datascaler import TargetScaler
from .conformer import ConformerGen, UniMolV2Feature
from .split import Splitter
from ..utils import logger


class DataHub(object):
"""
@@ -31,6 +34,7 @@ def __init__(self, data=None, is_train=True, save_path=None, **params):
self.multiclass_cnt = params.get('multiclass_cnt', None)
self.ss_method = params.get('target_normalize', 'none')
self._init_data(**params)
self._init_split(**params)

def _init_data(self, **params):
"""
@@ -89,3 +93,25 @@ def _init_data(self, **params):
no_h_list = UniMolV2Feature().transform(smiles_list)

self.data['unimol_input'] = no_h_list

def _init_split(self, **params):

self.split_method = params.get('split_method','5fold_random')
kfold, method = int(self.split_method.split('fold')[0]), self.split_method.split('_')[-1] # Nfold_xxxx
self.kfold = params.get('kfold', kfold)
self.method = params.get('split', method)
self.split_seed = params.get('split_seed', 42)
self.data['kfold'] = self.kfold
if not self.is_train:
return
self.splitter = Splitter(self.method, self.kfold, seed=self.split_seed)
split_nfolds = self.splitter.split(**self.data)
if self.kfold == 1:
logger.info(f"Kfold is 1, all data is used for training.")
else:
logger.info(f"Split method: {self.method}, fold: {self.kfold}")
nfolds = np.zeros(len(split_nfolds[0][0])+len(split_nfolds[0][1]), dtype=int)
for enu, (tr_idx, te_idx) in enumerate(split_nfolds):
nfolds[te_idx] = enu
self.data['split_nfolds'] = split_nfolds
return split_nfolds
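For context, `_init_split` still accepts the legacy `'Nfold_method'` string and decomposes it into the new `kfold` and `method` parameters (which explicit `kfold`/`split` arguments then override). A quick illustration of the parsing with the default value:

```python
split_method = '5fold_random'                # legacy 'Nfold_method' format
kfold = int(split_method.split('fold')[0])   # -> 5
method = split_method.split('_')[-1]         # -> 'random'
print(kfold, method)
```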
14 changes: 7 additions & 7 deletions unimol_tools/unimol_tools/data/datareader.py
@@ -82,17 +82,17 @@ def read_data(self, data=None, is_train=True, **params):
target_cols = target_cols.split(',')
elif isinstance(target_cols, list):
pass
else:
else:
for col in target_cols:
if col not in data.columns:
data[target_cols] = -1.0
break

if is_train and anomaly_clean:
data = self.anomaly_clean(data, task, target_cols)

if is_train and task == 'multiclass':
multiclass_cnt = int(data[target_cols].max() + 1)
if is_train:
if anomaly_clean:
data = self.anomaly_clean(data, task, target_cols)
if task == 'multiclass':
multiclass_cnt = int(data[target_cols].max() + 1)

targets = data[target_cols].values.tolist()
num_classes = len(target_cols)
unimol_tools/unimol_tools/data/split.py
@@ -4,26 +4,30 @@

from __future__ import absolute_import, division, print_function

import numpy as np
from sklearn.model_selection import (
GroupKFold,
KFold,
StratifiedKFold,
)
from ..utils import logger


class Splitter(object):
"""
The Splitter class is responsible for splitting a dataset into train and test sets
based on the specified method.
"""
def __init__(self, split_method='5fold_random', seed=42):
def __init__(self, method='random', kfold=5, seed=42, **params):
"""
Initializes the Splitter with a specified split method and random seed.
:param split_method: (str) The method for splitting the dataset, in the format 'Nfold_method'.
Defaults to '5fold_random'.
:param seed: (int) Random seed for reproducibility in random splitting. Defaults to 42.
"""
self.n_splits, self.method = int(split_method.split('fold')[0]), split_method.split('_')[-1] # Nfold_xxxx
self.method = method
self.n_splits = kfold
self.seed = seed
self.splitter = self._init_split()

@@ -34,18 +38,22 @@ def _init_split(self):
:return: The initialized splitter object.
:raises ValueError: If an unknown splitting method is specified.
"""
if self.n_splits == 1:
return None
if self.method == 'random':
splitter = KFold(n_splits=self.n_splits, shuffle=True, random_state=self.seed)
elif self.method == 'scaffold' or self.method == 'group':
splitter = GroupKFold(n_splits=self.n_splits)
elif self.method == 'stratified':
splitter = StratifiedKFold(n_splits=self.n_splits, shuffle=True, random_state=self.seed)
elif self.method == 'select':
splitter = GroupKFold(n_splits=self.n_splits)
else:
raise ValueError('Unknown splitter method: {}fold - {}'.format(self.n_splits, self.method))

return splitter

def split(self, data, target=None, group=None):
def split(self, smiles, target=None, group=None, scaffolds=None, **params):
"""
Splits the dataset into train and test sets based on the initialized method.
@@ -56,7 +64,32 @@
:return: An iterator yielding train and test set indices for each fold.
:raises ValueError: If the splitter method does not support the provided parameters.
"""
try:
return self.splitter.split(data, target, group)
except:
raise ValueError('Unknown splitter method: {}fold - {}'.format(self.n_splits, self.method))
if self.n_splits == 1:
logger.warning('Only one fold is used for training, no splitting is performed.')
return [(np.arange(len(smiles)), ())]
if self.method in ['random']:
self.skf = self.splitter.split(smiles)
elif self.method in ['scaffold']:
self.skf = self.splitter.split(smiles, target, scaffolds)
elif self.method in ['group']:
self.skf = self.splitter.split(smiles, target, group)
elif self.method in ['stratified']:
self.skf = self.splitter.split(smiles, group)
elif self.method in ['select']:
unique_groups = np.unique(group)
if len(unique_groups) == self.n_splits:
split_folds = []
for unique_group in unique_groups:
train_idx = np.where(group != unique_group)[0]
test_idx = np.where(group == unique_group)[0]
split_folds.append((train_idx, test_idx))
self.split_folds = split_folds
return self.split_folds
else:
logger.error('The number of unique groups is not equal to the number of splits.')
exit(1)
else:
logger.error('Unknown splitter method: {}'.format(self.method))
exit(1)
self.split_folds = list(self.skf)
return self.split_folds
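Among the new branches, `'select'` treats the group column itself as the fold assignment. A standalone sketch of that logic on toy data (not the package API):

```python
import numpy as np

# Each row's group value (0..kfold-1) names the fold it is held out in.
group = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
folds = []
for g in np.unique(group):
    train_idx = np.where(group != g)[0]  # everything outside the group trains
    test_idx = np.where(group == g)[0]   # the group itself is the test fold
    folds.append((train_idx, test_idx))
for k, (tr, te) in enumerate(folds):
    print(k, tr.tolist(), te.tolist())
```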
8 changes: 4 additions & 4 deletions unimol_tools/unimol_tools/models/nnmodel.py
@@ -70,7 +70,7 @@ def __init__(self, data, trainer, **params):
self.data_type = params.get('data_type', 'molecule')
self.loss_key = params.get('loss_key', None)
self.trainer = trainer
self.splitter = self.trainer.splitter
#self.splitter = self.trainer.splitter
self.model_params = params.copy()
self.task = params['task']
if self.task in OUTPUT_DIM:
@@ -150,7 +150,7 @@ def run(self):
y.reshape(y.shape[0], self.num_classes)).astype(float)
else:
y_pred = np.zeros((y.shape[0], self.model_params['output_dim']))
for fold, (tr_idx, te_idx) in enumerate(self.splitter.split(X, y, group)):
for fold, (tr_idx, te_idx) in enumerate(self.data['split_nfolds']):
X_train, y_train = X[tr_idx], y[tr_idx]
X_valid, y_valid = X[te_idx], y[te_idx]
traindataset = NNDataset(X_train, y_train)
@@ -220,7 +220,7 @@ def evaluate(self, trainer=None, checkpoints_path=None):
"""
logger.info("start predict NNModel:{}".format(self.model_name))
testdataset = NNDataset(self.features, np.asarray(self.data['target']))
for fold in range(self.splitter.n_splits):
for fold in range(self.data['kfold']):
model_path = os.path.join(checkpoints_path, f'model_{fold}.pth')
self.model.load_state_dict(torch.load(
model_path, map_location=self.trainer.device)['model_state_dict'])
@@ -229,7 +229,7 @@
if fold == 0:
y_pred = np.zeros_like(_y_pred)
y_pred += _y_pred
y_pred /= self.splitter.n_splits
y_pred /= self.data['kfold']
self.cv['test_pred'] = y_pred

def count_parameters(self, model):
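With the Splitter removed from the trainer, `evaluate` now takes the fold count from `self.data['kfold']` and averages the per-fold checkpoint predictions. A sketch of that ensembling, with random arrays standing in for model outputs:

```python
import numpy as np

kfold = 5
n_samples, output_dim = 4, 1
y_pred = np.zeros((n_samples, output_dim))
for fold in range(kfold):
    # stand-in for predicting with the checkpoint saved as model_{fold}.pth
    _y_pred = np.random.rand(n_samples, output_dim)
    y_pred += _y_pred
y_pred /= kfold  # simple mean over the k fold models
print(y_pred.round(3))
```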
8 changes: 4 additions & 4 deletions unimol_tools/unimol_tools/models/unimolv2.py
@@ -165,14 +165,14 @@ def forward(
pair_type,
attn_bias,
src_tokens,
src_pos,
src_coord,
return_repr=False,
return_atomic_reprs=False,
**kwargs
):


pos = src_pos
pos = src_coord

n_mol, n_atom = atom_feat.shape[:2]
token_feat = self.embed_tokens(src_tokens)
@@ -232,7 +232,7 @@ def one_block(x, pos, return_x=False):
filtered_tensors = []
filtered_coords = []

for tokens, coord in zip(src_tokens, src_pos):
for tokens, coord in zip(src_tokens, src_coord):
filtered_tensor = tokens[(tokens != 0) & (tokens != 1) & (tokens != 2)] # filter out BOS(0), EOS(1), PAD(2)
filtered_coord = coord[(tokens != 0) & (tokens != 1) & (tokens != 2)]
filtered_tensors.append(filtered_tensor)
@@ -315,7 +315,7 @@ def batch_collate_fn(self, samples):
v = pad_2d([s[0][k] for s in samples], pad_idx=self.padding_idx)
elif k == 'src_tokens':
v = pad_1d_tokens([s[0][k] for s in samples], pad_idx=self.padding_idx)
elif k == 'src_pos':
elif k == 'src_coord':
v = pad_coords([s[0][k] for s in samples], pad_idx=self.padding_idx)
batch[k] = v
try:
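Beyond the `src_pos` → `src_coord` rename, the forward pass masks out special tokens before computing atomic representations. A toy illustration of that masking, using the token ids named in the diff comment (BOS=0, EOS=1, PAD=2):

```python
import torch

src_tokens = torch.tensor([0, 6, 8, 1, 2, 2])  # BOS, C, O, EOS, PAD, PAD
src_coord = torch.randn(6, 3)                  # toy per-atom coordinates
mask = (src_tokens != 0) & (src_tokens != 1) & (src_tokens != 2)
filtered_tokens = src_tokens[mask]             # keeps only real atoms
filtered_coord = src_coord[mask]
print(filtered_tokens.tolist(), filtered_coord.shape)
```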
2 changes: 0 additions & 2 deletions unimol_tools/unimol_tools/tasks/trainer.py
@@ -15,7 +15,6 @@
# from transformers.optimization import get_linear_schedule_with_warmup
from ..utils import Metrics
from ..utils import logger
from .split import Splitter
from tqdm import tqdm

import time
@@ -46,7 +45,6 @@ def _init_trainer(self, **params):
self.split_seed = params.get('split_seed', 42)
self.seed = params.get('seed', 42)
self.set_seed(self.seed)
self.splitter = Splitter(self.split_method, self.split_seed)
self.logger_level = int(params.get('logger_level', 1))
### init NN trainer params ###
self.learning_rate = float(params.get('learning_rate', 1e-4))
16 changes: 15 additions & 1 deletion unimol_tools/unimol_tools/train.py
@@ -72,9 +72,23 @@ def __init__(self,
- multilabel_regression: mae, mse, r2.
:param split: str, default='random', split method of training dataset. currently support: random, scaffold, group, stratified.
:param split: str, default='random', split method of training dataset. currently support: random, scaffold, group, stratified, select.
- random: random split.
- scaffold: split by scaffold.
- group: split by group. `split_group_col` should be specified.
- stratified: stratified split. `split_group_col` should be specified.
- select: use `split_group_col` to manually select the split group. Column values of `split_group_col` should range from 0 to kfold-1 to indicate the split group.
:param split_group_col: str, default='scaffold', column name of group split.
:param kfold: int, default=5, number of folds for k-fold cross validation.
- 1: no split. all data will be used for training.
:param save_path: str, default='./exp', path to save training results.
:param remove_hs: bool, default=False, whether to remove hydrogens from molecules.
:param smiles_col: str, default='SMILES', column name of SMILES.
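A hedged usage sketch of the documented split options. `MolTrain` is unimol_tools' public training entry point; the file name and column values below are placeholders:

```python
from unimol_tools import MolTrain

clf = MolTrain(
    task='classification',
    data_type='molecule',
    split='group',               # random / scaffold / group / stratified / select
    split_group_col='scaffold',  # required by group, stratified and select splits
    kfold=5,                     # kfold=1 skips splitting and trains on all data
    save_path='./exp',
)
clf.fit('train.csv')             # placeholder CSV with SMILES and target columns
```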
