From 081921bd58f157e3ab7d653d70527df097dae483 Mon Sep 17 00:00:00 2001 From: David Ormrod Morley Date: Mon, 25 Nov 2024 16:35:15 +0100 Subject: [PATCH] Minor changes and add unit test SO102 --- interfaces/molecule/rdkit.py | 90 ++++++++++++++++---------- pyproject.toml | 3 +- unit_tests/test_molecule_interfaces.py | 42 +++++++++++- unit_tests/xyz/product1.xyz | 5 ++ unit_tests/xyz/product2.xyz | 18 ++++++ unit_tests/xyz/reactant1.xyz | 6 ++ unit_tests/xyz/reactant2.xyz | 17 +++++ 7 files changed, 145 insertions(+), 36 deletions(-) create mode 100644 unit_tests/xyz/product1.xyz create mode 100644 unit_tests/xyz/product2.xyz create mode 100644 unit_tests/xyz/reactant1.xyz create mode 100644 unit_tests/xyz/reactant2.xyz diff --git a/interfaces/molecule/rdkit.py b/interfaces/molecule/rdkit.py index d137e6b3..d00f1ec1 100644 --- a/interfaces/molecule/rdkit.py +++ b/interfaces/molecule/rdkit.py @@ -1,6 +1,4 @@ -from typing import List, Literal, Optional, overload, TYPE_CHECKING -from typing import Dict -from typing import Any +from typing import List, Literal, Optional, overload, TYPE_CHECKING, Sequence, Dict, Any import random import sys from warnings import warn @@ -34,6 +32,8 @@ "get_conformations", "yield_coords", "canonicalize_mol", + "to_image", + "get_reaction_image", ] @@ -1300,7 +1300,14 @@ def canonicalize_mol(mol, inplace=False, **kwargs): @requires_optional_package("rdkit") -def to_image(mol, remove_hydrogens=True, filename=None, format="svg", as_string=True): +@requires_optional_package("PIL") +def to_image( + mol: Molecule, + remove_hydrogens: bool = True, + filename: Optional[str] = None, + fmt: Literal["svg", "png", "eps", "pdf", "jpeg"] = "svg", + as_string: bool = True, +): """ Convert single molecule to single image object @@ -1336,36 +1343,36 @@ def to_image(mol, remove_hydrogens=True, filename=None, format="svg", as_string= if "." 
in filename: extension = filename.split(".")[-1] if extension in classes.keys(): - format = extension + fmt = extension else: - msg = ["Image type %s not available." % (extension)] - msg += ["Available extensions are: %s" % (" ".join(extensions))] + msg = [f"Image type {extension} not available."] + msg += [f"Available extensions are: {' '.join(extensions)}"] raise Exception("\n".join(msg)) else: - filename = ".".join([filename, format]) + filename = ".".join([filename, fmt]) - if format not in classes.keys(): - raise Exception("Image type %s not available." % (format)) + if fmt not in classes.keys(): + raise Exception(f"Image type {fmt} not available.") rdmol = _rdmol_for_image(mol, remove_hydrogens) # Draw the image - if classes[format.lower()] is None: + if classes[fmt.lower()] is None: # With AMS version of RDKit MolsToGridImage fails for eps (because of paste) img = Draw.MolToImage(rdmol, size=(200, 100)) buf = BytesIO() - img.save(buf, format=format) + img.save(buf, format=fmt) img_text = buf.getvalue() else: # This fails for a C=C=O molecule, with AMS rdkit version - img = classes[format.lower()]([rdmol], molsPerRow=1, subImgSize=(200, 100)) + img = classes[fmt.lower()]([rdmol], molsPerRow=1, subImgSize=(200, 100)) img_text = img if isinstance(img, Image.Image): buf = BytesIO() - img.save(buf, format=format) + img.save(buf, format=fmt) img_text = buf.getvalue() # If I do not make this correction to the SVG text, it is not readable in JupyterLab - if format.lower() == "svg": + if fmt.lower() == "svg": img_text = _correct_svg(img_text) # Write to file, if required @@ -1382,14 +1389,21 @@ def to_image(mol, remove_hydrogens=True, filename=None, format="svg", as_string= @requires_optional_package("rdkit") -def get_reaction_image(reactants, products, filename=None, format="svg", as_string=True): +@requires_optional_package("PIL") +def get_reaction_image( + reactants: Sequence[Molecule], + products: Sequence[Molecule], + filename: Optional[str] = None, + fmt: 
Literal["svg", "png", "eps", "pdf", "jpeg"] = "svg", + as_string: bool = True, +): """ Create a 2D reaction image from reactants and products (PLAMS molecules) * ``reactants`` -- Iterable of PLAMS Molecule objects representing the reactants. - * ``products`` -- Iterable of PLAMS Molecule objects representig the products. + * ``products`` -- Iterable of PLAMS Molecule objects representing the products. * ``filename`` -- Optional: Name of image file to be created. - * ``format`` -- The format of the image (svg, png, eps, pdf, jpeg). + * ``fmt`` -- The format of the image (svg, png, eps, pdf, jpeg). The extension in the filename, if provided, takes precedence. * ``as_string`` -- Returns the image as a string or bytestring. If set to False, the original format will be returned, which can be either a PIL image or SVG text @@ -1402,22 +1416,22 @@ def get_reaction_image(reactants, products, filename=None, format="svg", as_stri if "." in filename: extension = filename.split(".")[-1] if extension in extensions: - format = extension + fmt = extension else: - msg = ["Image type %s not available." % (extension)] - msg += ["Available extensions are: %s" % (" ".join(extensions))] + msg = [f"Image type {extension} not available."] + msg += [f"Available extensions are: {' '.join(extensions)}"] raise Exception("\n".join(msg)) else: - filename = ".".join([filename, format]) + filename = ".".join([filename, fmt]) - if format.lower() not in extensions: - raise Exception("Image type %s not available." 
% (format)) + if fmt.lower() not in extensions: + raise Exception(f"Image type {fmt} not available.") # Get the actual image - if format.lower() == "svg": + if fmt.lower() == "svg": img_text = get_reaction_image_svg(reactants, products) else: - img_text = get_reaction_image_pil(reactants, products, format, as_string=as_string) + img_text = get_reaction_image_pil(reactants, products, fmt, as_string=as_string) # Write to file, if required if filename is not None: @@ -1429,12 +1443,14 @@ def get_reaction_image(reactants, products, filename=None, format="svg", as_stri return img_text -def get_reaction_image_svg(reactants, products, width=200, height=100): +def get_reaction_image_svg( + reactants: Sequence[Molecule], products: Sequence[Molecule], width: int = 200, height: int = 100 +): """ Create a 2D reaction image from reactants and products (PLAMS molecules) * ``reactants`` -- Iterable of PLAMS Molecule objects representing the reactants. - * ``products`` -- Iterable of PLAMS Molecule objects representig the products. + * ``products`` -- Iterable of PLAMS Molecule objects representing the products. * Returns -- SVG image text file. """ from rdkit import Chem @@ -1502,12 +1518,19 @@ def add_arrow_svg(img_text, width, height, nreactants, prefix=""): return img_text -def get_reaction_image_pil(reactants, products, format, width=200, height=100, as_string=True): +def get_reaction_image_pil( + reactants: Sequence[Molecule], + products: Sequence[Molecule], + fmt: Literal["svg", "png", "eps", "pdf", "jpeg"], + width: int = 200, + height: int = 100, + as_string: bool = True, +): """ Create a 2D reaction image from reactants and products (PLAMS molecules) * ``reactants`` -- Iterable of PLAMS Molecule objects representing the reactants. - * ``products`` -- Iterable of PLAMS Molecule objects representig the products. + * ``products`` -- Iterable of PLAMS Molecule objects representing the products. * Returns -- SVG image text file. 
""" from io import BytesIO @@ -1583,8 +1606,8 @@ def join_pil_images(pil_images): if not hasattr(rdMolDraw2D, "MolDraw2DCairo"): # We are working with the old AMS version of RDKit white = (255, 255, 255) - rimages = [to_image(mol, format=format, as_string=False) for i, mol in enumerate(reactants)] - pimages = [to_image(mol, format=format, as_string=False) for i, mol in enumerate(products)] + rimages = [to_image(mol, fmt=fmt, as_string=False) for i, mol in enumerate(reactants)] + pimages = [to_image(mol, fmt=fmt, as_string=False) for i, mol in enumerate(products)] blanc = Image.new("RGB", (width, height), white) # Get the image (with arrow) @@ -1611,7 +1634,7 @@ def join_pil_images(pil_images): img_text = img if as_string: buf = BytesIO() - img.save(buf, format=format) + img.save(buf, format=fmt) img_text = buf.getvalue() return img_text @@ -1832,7 +1855,6 @@ def _rdmol_for_image(mol, remove_hydrogens=True): from rdkit.Chem import AllChem from rdkit.Chem import RemoveHs - # rdmol = mol.to_rdmol_new() rdmol = to_rdmol(mol, presanitize=True) # Flatten the molecule diff --git a/pyproject.toml b/pyproject.toml index a81978c1..6b43c489 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,8 @@ analysis = [ "matplotlib>=3.5.1", "pandas>=1.5.2", "networkx>=2.7.1", - "natsort>=8.1.0" + "natsort>=8.1.0", + "pillow>=9.2.0" ] docs = [ diff --git a/unit_tests/test_molecule_interfaces.py b/unit_tests/test_molecule_interfaces.py index 0f705672..9bef0d52 100644 --- a/unit_tests/test_molecule_interfaces.py +++ b/unit_tests/test_molecule_interfaces.py @@ -6,7 +6,15 @@ from scm.plams.interfaces.molecule.ase import toASE, fromASE from scm.plams.unit_tests.test_helpers import get_mock_find_spec, get_mock_open_function from scm.plams.core.errors import MissingOptionalPackageError -from scm.plams.interfaces.molecule.rdkit import from_rdmol, to_rdmol, from_smiles, to_smiles, from_smarts +from scm.plams.interfaces.molecule.rdkit import ( + from_rdmol, + to_rdmol, + from_smiles, + 
to_smiles, + from_smarts, + to_image, + get_reaction_image, +) from scm.plams.interfaces.molecule.packmol import packmol @@ -230,6 +238,38 @@ def test_to_smiles_and_from_smiles_requires_rdkit_package(self, plams_mols, smil with pytest.raises(MissingOptionalPackageError): from_smiles(smiles[0]) + def test_to_image_and_get_reaction_image_can_generate_img_files(self, xyz_folder): + from pathlib import Path + import shutil + + # Given molecules + reactants = [Molecule(f"{xyz_folder}/reactant{i}.xyz") for i in range(1, 3)] + products = [Molecule(f"{xyz_folder}/product{i}.xyz") for i in range(1, 3)] + + # When create images for molecules and reactions + result_dir = Path("result_images/rdkit") + try: + shutil.rmtree(result_dir) + except FileNotFoundError: + pass + result_dir.mkdir(parents=True, exist_ok=True) + + for i, m in enumerate(reactants): + m.guess_bonds() + to_image(m, filename=f"{result_dir}/reactant{i+1}.png") + + for i, m in enumerate(products): + m.guess_bonds() + to_image(m, filename=f"{result_dir}/product{i+1}.png") + + get_reaction_image(reactants, products, filename=f"{result_dir}/reaction.png") + + # Then image files are successfully created + # N.B. 
for this test just check the files are generated, not that the contents are correct + for f in ["reactant1.png", "reactant2.png", "product1.png", "product2.png", "reaction.png"]: + file = result_dir / f + assert file.exists() + class TestPackmol: """ diff --git a/unit_tests/xyz/product1.xyz b/unit_tests/xyz/product1.xyz new file mode 100644 index 00000000..9c479f92 --- /dev/null +++ b/unit_tests/xyz/product1.xyz @@ -0,0 +1,5 @@ +3 + +O -0.0898298722 -0.1823078766 -0.4030695566 +H 0.0892757418 0.7223466689 -0.0923666662 +H 0.0251541304 -0.7229767923 0.3978492228 \ No newline at end of file diff --git a/unit_tests/xyz/product2.xyz b/unit_tests/xyz/product2.xyz new file mode 100644 index 00000000..6d6089de --- /dev/null +++ b/unit_tests/xyz/product2.xyz @@ -0,0 +1,18 @@ +16 + +C -0.5058742951 0.7197846249 1.4873502327 +O -1.8467123470 1.0371356691 1.5578617129 +O 0.2426104621 0.8031328420 2.4444063939 +H -1.2336711895 -0.1871011373 -3.0847664179 +H 1.1181735289 -0.9674586781 -3.3575709511 +O -2.2996282296 0.6761194013 -0.9441961430 +H 1.8534242875 -0.1774201547 0.8154069061 +H -2.5401517031 0.9226546484 -0.0164069755 +H 2.6629785819 -0.9669531279 -1.4040236989 +H -2.0724918515 1.3639645877 2.4489574275 +C -1.0116321368 0.2602883042 -0.9839670361 +C -0.5476048834 -0.1835433773 -2.2354228233 +C 0.7686298918 -0.6221529110 -2.3859139467 +C 1.6373720529 -0.6205255079 -1.2863630033 +C 1.1825135366 -0.1770749455 -0.0425597678 +C -0.1335817057 0.2638847621 0.1265310905 \ No newline at end of file diff --git a/unit_tests/xyz/reactant1.xyz b/unit_tests/xyz/reactant1.xyz new file mode 100644 index 00000000..81b5ba59 --- /dev/null +++ b/unit_tests/xyz/reactant1.xyz @@ -0,0 +1,6 @@ +4 + +H -0.2376962379 -0.7861223708 -0.7251893655 +H 0.2018867953 0.7880914552 -0.7285471665 +O -0.2208830412 0.0589617158 -0.1564035523 +H 0.2343924838 -0.0668018003 0.7570340842 \ No newline at end of file diff --git a/unit_tests/xyz/reactant2.xyz b/unit_tests/xyz/reactant2.xyz new file mode 100644
index 00000000..f2a51a24 --- /dev/null +++ b/unit_tests/xyz/reactant2.xyz @@ -0,0 +1,17 @@ +15 + +C -0.7714286028 0.5721884475 1.5648994528 +O -1.9163545429 1.2651552598 1.3585245940 +O -0.3232871439 0.3453806305 2.6856889384 +H -0.2386198927 0.0473169758 -3.1356983004 +H 1.8106086100 -1.3053060069 -2.7858132139 +O -1.6768410850 1.0650295484 -1.2558970875 +H 1.4319019123 -0.8043435938 1.4748215886 +H 2.6580511680 -1.7004062192 -0.4892887246 +H -2.2596018166 1.4887930327 2.2433128310 +C -0.6549422587 0.4025760289 -1.0187446070 +C 0.1233087332 -0.1571411828 -2.1280100126 +C 1.2660074098 -0.8988286619 -1.9347794419 +C 1.7441534022 -1.1272237399 -0.6323129834 +C 1.0490198939 -0.6209878769 0.4726114196 +C -0.1137347868 0.1066333578 0.3111295469 \ No newline at end of file