Skip to content

Commit

Permalink
Yn patch 1 (#310)
Browse files Browse the repository at this point in the history
* [Fix] Predicting on input without target_col raised an exception; this regression was introduced in the last update.
[Fix] When smiles is None, use atoms as SMILES for splitting.
[Update] Fix typo and bump version to 0.1.2.post2.

* [Update] Safer way to handle predictions with missing predict_cols
  • Loading branch information
emotionor authored Jan 6, 2025
1 parent 35efc65 commit 5221406
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion unimol_tools/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="unimol_tools",
version="0.1.2.post1",
version="0.1.2.post2",
description=("unimol_tools is a Python package for property prediciton with Uni-Mol in molecule, materials and protein."),
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
Expand Down
13 changes: 7 additions & 6 deletions unimol_tools/unimol_tools/data/datareader.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,18 @@ def read_data(self, data=None, is_train=True, **params):
elif isinstance(target_cols, list):
pass
else:
for col in target_cols:
if col not in data.columns:
data[target_cols] = -1.0
break

raise ValueError('Unknown target_cols type: {}'.format(type(target_cols)))

if is_train:
if anomaly_clean:
data = self.anomaly_clean(data, task, target_cols)
if task == 'multiclass':
multiclass_cnt = int(data[target_cols].max() + 1)

else:
for col in target_cols:
if col not in data.columns or data[col].isnull().any():
data[col] = -1.0

targets = data[target_cols].values.tolist()
num_classes = len(target_cols)

Expand Down
3 changes: 3 additions & 0 deletions unimol_tools/unimol_tools/data/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def split(self, smiles, target=None, group=None, scaffolds=None, **params):
if self.n_splits == 1:
logger.warning('Only one fold is used for training, no splitting is performed.')
return [(np.arange(len(smiles)), ())]
if smiles is None and 'atoms' in params:
smiles = params['atoms']
logger.warning('Atoms are used as SMILES for splitting.')
if self.method in ['random']:
self.skf = self.splitter.split(smiles)
elif self.method in ['scaffold']:
Expand Down
2 changes: 1 addition & 1 deletion unimol_tools/unimol_tools/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def predict(self, data, save_path=None, metrics='none'):
- classification: auc, auprc, log_loss, acc, f1_score, mcc, precision, recall, cohen_kappa.
- regression: mse, pearsonr, spearmanr, mse, r2.
- regression: mae, pearsonr, spearmanr, mse, r2.
- multiclass: log_loss, acc.
Expand Down

0 comments on commit 5221406

Please sign in to comment.