From 7adf180ead991b66236d91daad5ff6c5919c48d9 Mon Sep 17 00:00:00 2001 From: Rosa Bulo Date: Thu, 21 Nov 2024 17:12:35 +0100 Subject: [PATCH 1/5] Rosa Bulo (REB) SCMSUITE-- SO102: Added 2D depiction molecules/reactions + improved stability to_rdmol - to_rdmol now has an argument presanitize, which iteratively sanitized - improved bonding/atomic charges, until sanitization succeeds. - There are the functions to_image and get_reaction_image. --- __init__.py | 2 + interfaces/molecule/rdkit.py | 579 ++++++++++++++++++++++++++++++++++- 2 files changed, 577 insertions(+), 4 deletions(-) diff --git a/__init__.py b/__init__.py index 47248dd3a..c25e50bea 100644 --- a/__init__.py +++ b/__init__.py @@ -98,6 +98,8 @@ to_smiles, writepdb, yield_coords, + to_image, + get_reaction_image, ) from scm.plams.interfaces.thirdparty.cp2k import Cp2kJob, Cp2kResults, Cp2kSettings2Mol from scm.plams.interfaces.thirdparty.crystal import CrystalJob, mol2CrystalConf diff --git a/interfaces/molecule/rdkit.py b/interfaces/molecule/rdkit.py index 01525cfa6..aac6b811a 100644 --- a/interfaces/molecule/rdkit.py +++ b/interfaces/molecule/rdkit.py @@ -114,7 +114,7 @@ def from_rdmol(rdkit_mol, confid=-1, properties=True): return plams_mol -def to_rdmol(plams_mol, sanitize=True, properties=True, assignChirality=False): +def to_rdmol(plams_mol, sanitize=True, properties=True, assignChirality=False, presanitize=False): """ Translate a PLAMS molecule into an RDKit molecule type. PLAMS |Molecule|, |Atom| or |Bond| properties are pickled if they are neither booleans, floats, @@ -218,10 +218,11 @@ def plams_to_rd_bonds(bo): if sanitize: try: - Chem.SanitizeMol(rdmol) + if presanitize: + rdmol = _presanitize(plams_mol, rdmol) + else: + Chem.SanitizeMol(rdmol) except ValueError as exc: - # rdkit_flag = Chem.SanitizeMol(rdmol,catchErrors=True) - # log ('RDKit Sanitization Error. Failed Operation Flag = %s'%(rdkit_flag)) log("RDKit Sanitization Error.") text = "Most likely this is a problem with the assigned bond orders: " text += "Use chemical insight to adjust them." @@ -1270,3 +1271,573 @@ def canonicalize_mol(mol, inplace=False, **kwargs): ret = mol.copy() ret.atoms = [at for _, at in sorted(zip(idx_rank, ret.atoms), reverse=True)] return ret + +def to_image (mol, remove_hydrogens=True, filename=None, format="svg", as_string=True): + """ + Convert single molecule to single image object + + * ``mol`` -- PLAMS Molecule object + * ``remove_hydrogens`` -- Wether or not to remove the H-atoms from the image + * ``filename`` -- Optional: Name of image file to be created. + * ``as_string`` -- Returns the image as a string or bytestring. If set to False, the original format + will be returned, which can be either a PIL image or SVG text + We do this because after converting a PIL image to a bytestring it is not possible + to further edit it (with our version of PIL). + * Returns -- SVG image text file. + """ + from io import BytesIO + from PIL import Image + from rdkit import Chem + from rdkit.Chem import Draw + from rdkit.Chem.Draw import rdMolDraw2D + from rdkit.Chem.Draw import MolsToGridImage + + extensions = ["svg", "png", "eps", "pdf", "jpeg"] + + classes = {} + classes["svg"] = _MolsToGridSVG + for ext in extensions[1:]: + classes[ext] = None + # PNG can only be created in this way with later version of RDKit + if hasattr(rdMolDraw2D,"MolDraw2DCairo"): + for ext in extensions[1:]: + classes[ext] = MolsToGridImage + + # Determine the type of image file + if filename is not None: + if "." in filename: + extension = filename.split(".")[-1] + if extension in classes.keys(): + format = extension + else: + msg = ["Image type %s not available."%(extension)] + msg += ["Available extensions are: %s"%(" ".join(extensions))] + raise Exception("\n".join(msg)) + else: + filename = ".".join(filename,format) + + if format not in classes.keys(): + raise Exception("Image type %s not available."%(format)) + + rdmol = _rdmol_for_image(mol, remove_hydrogens) + + # Draw the image + if classes[format.lower()] is None: + # With AMS version of RDKit MolsToGridImage fails for eps (because of paste) + img = Draw.MolToImage(rdmol,size=(200,100)) + buf = BytesIO() + img.save(buf, format=format) + img_text = buf.getvalue() + else: + # This fails for a C=C=O molecule, with AMS rdkit version + img = classes[format.lower()]([rdmol], molsPerRow=1, subImgSize=(200, 100)) + img_text = img + if isinstance(img,Image.Image): + buf = BytesIO() + img.save(buf, format=format) + img_text = buf.getvalue() + + # Write to file, if required + if filename is not None: + mode = "w" + if isinstance(img_text,bytes): + mode = "wb" + with open(filename,mode) as outfile: + outfile.write(img_text) + + if as_string: + img = img_text + return img + +def get_reaction_image (reactants, products, filename=None, format="svg", as_string=True): + """ + Create a 2D reaction image from reactants and products (PLAMS molecules) + + * ``reactants`` -- Iterable of PLAMS Molecule objects representing the reactants. + * ``products`` -- Iterable of PLAMS Molecule objects representig the products. + * ``filename`` -- Optional: Name of image file to be created. + * ``format`` -- The format of the image (svg, png, eps). + The extension in the filename, if provided, takes precedence. + * ``as_string`` -- Returns the image as a string or bytestring. If set to False, the original format + will be returned, which can be either a PIL image or SVG text + * Returns -- SVG image text file. + """ + extensions = ["svg", "png", "eps", "pdf", "jpeg"] + + # Determine the type of image file + if filename is not None: + if "." in filename: + extension = filename.split(".")[-1] + if extension in extensions: + format = extension + else: + msg = ["Image type %s not available."%(extension)] + msg += ["Available extensions are: %s"%(" ".join(extensions))] + raise Exception("\n".join(msg)) + else: + filename = ".".join(filename,format) + + if format.lower() not in extensions: + raise Exception("Image type %s not available."%(format)) + + # Get the actual image + if format.lower() == "svg": + img_text = get_reaction_image_svg (reactants, products) + else: + img_text = get_reaction_image_pil(reactants, products, format, as_string=as_string) + + # Write to file, if required + if filename is not None: + mode = "w" + if isinstance(img_text,bytes): + mode = "wb" + with open(filename,mode) as outfile: + outfile.write(img_text) + return img_text + +def get_reaction_image_svg (reactants, products, width=200, height=100): + """ + Create a 2D reaction image from reactants and products (PLAMS molecules) + + * ``reactants`` -- Iterable of PLAMS Molecule objects representing the reactants. + * ``products`` -- Iterable of PLAMS Molecule objects representig the products. + * Returns -- SVG image text file. + """ + from rdkit import Chem + + def svg_arrow (x1, y1, x2, y2, prefix=""): + """ + The reaction arrow in html format + """ + # The arrow head + l = ['<%sdefs> <%smarker id="arrow" viewBox="0 0 10 10" refX="5" refY="5" '%(prefix, prefix)] + l += ['markerWidth="6" markerHeight="6" '] + l += ['orient="auto-start-reverse"> <%spath d="M 0 0 L 10 5 L 0 10 z" />'%(prefix, prefix, prefix)] + arrow = ''.join(l) + # The line + l = ['<%sline x1="%i" y1="%i" x2="%i" y2="%i" '%(prefix, x1, y1, x2, y2)] + l += ['stroke="black" marker-end="url(#arrow)" />'] + line = ''.join(l) + return [arrow, line] + + def add_plus_signs_svg (img_text, width, height, nmols, nreactants, prefix=""): + """ + Add the lines with + signs to the SVG image + """ + y = int(0.55 * height) + t = [] + for i in range(nmols-1): + x = int(((i + 1) * width) - (0.1 * width)) + if i + 1 in (nreactants, nreactants + 1): + continue + t += ['<%stext x="%i" y="%i" font-size="16">+'%(prefix,x,y,prefix)] + lines = img_text.split("\n") + lines = lines[:-2] + t + lines[-2:] + return "\n".join(lines) + + def add_arrow_svg (img_text, width, height, nreactants, prefix=""): + """ + Add the arrow to the SVG image + """ + y = int(0.5 * height) + x1 = int((nreactants * width) + (0.3 * width)) + x2 = int((nreactants * width) + (0.7 * width)) + t = svg_arrow (x1, y, x2, y, prefix) + lines = img_text.split("\n") + lines = lines[:-2] + t + lines[-2:] + return "\n".join(lines) + + # Get the rdkit molecules + rdmols = [_rdmol_for_image(mol) for mol in reactants] + rdmols += [Chem.Mol()] # This is where the arrow will go + rdmols += [_rdmol_for_image(mol) for mol in products] + nmols = len(rdmols) + + # Place the molecules in a row of images + subimg_size = [width, height] + kwargs = {"legendFontSize":16} #,"legendFraction":0.1} + img_text = _MolsToGridSVG(rdmols, molsPerRow=nmols, subImgSize=subimg_size, **kwargs) + img_text = _correct_svg (img_text) + + # Add + and => + nreactants = len(reactants) + img_text = add_plus_signs_svg(img_text, width, height, nmols, nreactants) + img_text = add_arrow_svg (img_text, width, height, nreactants) + + return img_text + +def get_reaction_image_pil (reactants, products, format, width=200, height=100, as_string=True): + """ + Create a 2D reaction image from reactants and products (PLAMS molecules) + + * ``reactants`` -- Iterable of PLAMS Molecule objects representing the reactants. + * ``products`` -- Iterable of PLAMS Molecule objects representig the products. + * Returns -- SVG image text file. + """ + from io import BytesIO + from PIL import Image + from PIL import ImageDraw + from rdkit import Chem + from rdkit.Chem.Draw import rdMolDraw2D, MolsToGridImage + + def add_arrow_pil (img, width, height, nreactants): + """ + Add the arrow to the PIL image + """ + y1 = int(0.5 * height) + y2 = y1 + x1 = int((nreactants * width) + (0.3 * width)) + x2 = int((nreactants * width) + (0.7 * width)) + + # Draw a line + black = (0, 0, 0) + draw = ImageDraw.Draw(img) + draw.line(((x1, y1), (x2, y2)), fill=128) + + # Draw the arrow head + headscale = 20 + xshift = width // headscale + yshift = img.size[1] // headscale + p1 = (x2, y2 + yshift) + p2 = (x2, y2 - yshift) + p3 = (x2 + xshift, y2) + draw.polygon((p1, p2, p3), fill=black) + + return img + + def add_plus_signs_pil (img, width, height, nmols, nreactants): + """ + Add the lines with + signs to the SVG image + """ + white = (255, 255, 255) + black = (0, 0, 0) + + I1 = ImageDraw.Draw(img) + y = int(0.5 * height) + #myfont = ImageFont.truetype('FreeMono.ttf', 25) + for i in range(nmols): + x = int(((i + 1) * width) - (0.05 * width)) + if i + 1 in (nreactants,nreactants + 1): + continue + #I1.text((x, y), "+", font=myfont, fill=black) + I1.text((x, y), "+", fill=black) + return img + + def join_pil_images (pil_images): + """ + Create a new image which connects the ones above with text + """ + white = (255, 255, 255) + + widths = [img.width for img in pil_images] + width = sum(widths) + height = max([img.height for img in pil_images]) + final_img = Image.new('RGB', (width, height), white) + + # Concatenate the PIL images + for i,img in enumerate(pil_images): + pos = sum(widths[:i]) + h = int((height - img.height) / 2) + final_img.paste(img, (pos, h)) + + return final_img + + nreactants = len(reactants) + nmols = nreactants + len(products) + + if not hasattr(rdMolDraw2D, 'MolDraw2DCairo'): + # We are working with the old AMS version of RDKit + white = (255, 255, 255) + rimages = [to_image(mol, format=format, as_string=False) for i,mol in enumerate(reactants)] + pimages = [to_image(mol, format=format, as_string=False) for i,mol in enumerate(products)] + blanc = Image.new('RGB', (width, height), white) + + # Get the image (with arrow) + nreactants = len(reactants) + all_images = rimages + [blanc] + pimages + img = join_pil_images(all_images) + + else: + # We have a later version of RDKit that can regulate the font sizes + rdmols = [_rdmol_for_image(mol) for mol in reactants] + rdmols += [Chem.Mol()] # This is where the arrow will go + rdmols += [_rdmol_for_image(mol) for mol in products] + + # Place the molecules in a row of images + subimg_size = [width, height] + kwargs = {"legendFontSize":16} #,"legendFraction":0.1} + img = MolsToGridImage(rdmols, molsPerRow=nmols + 1, subImgSize=subimg_size, **kwargs) + + # Add + and => + img = add_plus_signs_pil(img, width, height, nmols, nreactants) + img = add_arrow_pil(img, width, height, nreactants) + + # Get the bytestring + img_text = img + if as_string: + buf = BytesIO() + img.save(buf, format=format) + img_text = buf.getvalue() + + return img_text + +def _correct_svg (image): + """ + Correct for a bug in the AMS rdkit created SVG file + """ + if not "svg:" in image: + return image + image = image.replace("svg:","") + lines = image.split("\n") + for iline,line in enumerate(lines): + if "xmlns:svg=" in line: + lines[iline] = line.replace("xmlns:svg","xmlns") + break + image = "\n".join(lines) + return image + +def _presanitize (mol, rdmol): + """ + Change bonding and atom charges to avoid failed sanitization + + Note: Used by to_rdmol + """ + from rdkit import Chem + mol = mol.copy() + for i in range(10): + try: + Chem.SanitizeMol(rdmol) + stored_exc = None + break + except ValueError as exc: + stored_exc = exc + text = repr(exc) + bonds, charges = _kekulize(mol, text) + rdmol = _update_system_for_sanitation(rdmol, bonds, charges) + if stored_exc is not None: + raise stored_exc + return rdmol + +def _update_system_for_sanitation(rdmol, altered_bonds, altered_charge): + """ + Change bond orders and charges if rdkit molecule, so that sanitiation will succeed. + """ + from rdkit import Chem + from rdkit.Chem import Atom + + emol = Chem.RWMol(rdmol) + for (iat, jat), order in altered_bonds.items(): + bond = emol.GetBondBetweenAtoms(iat, jat) + bond.SetBondType(Chem.BondType(order)) + for ind in (iat, jat): + at = emol.GetAtomWithIdx(ind) + if not at.GetIsAromatic: + continue + # Only change this if we are sure the atom is not aromatic anymore + aromatic = False + for bond in at.GetBonds(): + if str(bond.GetBondType()) == "AROMATIC": + aromatic = True + break + if not aromatic: + at.SetIsAromatic(False) + rdmol = emol.GetMol() + for iat, q in altered_charge.items(): + atom = rdmol.GetAtomWithIdx(iat) + atom.SetFormalCharge(q) + return rdmol + +def _kekulize (mol, text): + """ + Kekulize the atoms indicated as problematic by RDKit + + * ``mol`` - PLAMS molecule + * ``text`` - Sanitation error produced by RDKit + + Note: Returns the changes in bond orders and atomic charges, that will make sanitation succeed. + """ + from scm.plams import PeriodicTable as PT + + # Find the indices of the problematic atoms + indices = _find_aromatic_sequence (mol, text) + + # Order the indices, so that they are consecutive in the molecule + start = 0 + for i,iat in enumerate(indices): + neighbors = [mol.index(at)-1 for at in mol.neighbors(mol.atoms[iat])] + relevant_neighbors = [jat for jat in neighbors if jat in indices] + if len(relevant_neighbors) == 1: + start = i + break + iat = indices[start] + atoms = [iat] + while(1): + neighbors = [mol.index(at)-1 for at in mol.neighbors(mol.atoms[iat])] + relevant_neighbors = [jat for jat in neighbors if jat in indices and not jat in atoms] + if len(relevant_neighbors) == 0: + break + iat = relevant_neighbors[0] + atoms.append(iat) + if len(atoms) < len(indices): + raise Exception("The unkekulized atoms are not in a consecutive chain") + indices = atoms + + if len(indices) > 1: + # The first bond order is set to 2. + new_order = 2 + shift = 0 + # Unless this breaks valence rules. + iat = indices[0] + at = mol.atoms[iat] + valence = PT.get_connectors(at.atnum) + orders = [b.order for b in at.bonds] + jat = indices[1] + at_next = mol.atoms[jat] + valence_next = PT.get_connectors(at_next.atnum) + orders_next = [b.order for b in at_next.bonds] + if sum(orders) > valence or sum(orders_next) > valence_next: + new_order = 1 + shift = 1 + # Set the bond orders along the chain to 2, 1, 2, 1,... + altered_bonds = {} + for i,iat in enumerate(indices[:-1]): + bonds = mol.atoms[iat].bonds + bond = [b for b in bonds if b.other_end(mol.atoms[iat]) == mol.atoms[indices[i+1]]][0] + bond.order = new_order + altered_bonds[((mol.index(bond.atom1) - 1 , mol.index(bond.atom2) - 1))] = new_order + new_order = ((i+shift) % 2) + 1 + + # If the atom at the chain end has the wrong bond order, give it a charge. + for iat in indices[::-1]: + at = mol.atoms[iat] + valence = PT.get_connectors(at.atnum) + bonds = at.bonds + orders = [b.order for b in bonds] + charge = int(valence - sum(orders)) + if charge != 0: + break + + # Adjust the sign of the charge (taking already assigned atomic charges into account) + charges = [] + for i,at in enumerate(mol.atoms): + q = 0. + if "rdkit" in at.properties: + if "charge" in at.properties.rdkit: + q = at.properties.rdkit.charge + charges.append(q) + totcharge = sum(charges) - charges[iat] + # Assign the charge + altered_charge = {} + if charge != 0: + sign = charge / (abs(charge)) + # Here I hope that the estimated charge will be more reliable than the + # actual (user defined) system charge, but am not sure + est_charges = mol.guess_atomic_charges(adjust_to_systemcharge=False, depth=0) + molcharge = int(sum(est_charges)) + molcharge = molcharge - totcharge + if molcharge != 0: + sign = molcharge / (abs(molcharge)) + elif est_charges[iat] != 0: + sign = est_charges[iat] / abs(est_charges[iat]) + charge = int(sign * abs(charge)) + # Perhaps we tried this already + if charge == charges[iat]: + charge = -charge + mol.atoms[iat].properties.rdkit.charge = charge + altered_charge[iat] = charge + + return altered_bonds, altered_charge + +def _find_aromatic_sequence (mol, text): + """ + Find the sequence of atoms with 1.5 bond orders + """ + lines = text.split("\n") + line = lines[-1] + if "Unkekulized atoms:" in text: + text = line.split("Unkekulized atoms:")[-1].split("\\n")[0] + if '"' in text: + text = text.split('"')[0] + indices = [int(w) for w in text.split()] + if len(indices) > 1: + return indices + line = "atom %i marked aromatic"%(indices[0]) + text = line + if "marked aromatic" in text: + iat = int(line.split("atom")[-1].split()[0]) + indices = [iat] + while(iat is not None): + at = mol.atoms[iat] + iat = None + for bond in at.bonds: + if bond.order == 1.5: + nextat = bond.other_end(at) + iat = mol.index(nextat) - 1 + if iat in indices: + iat = None + continue + indices.append(iat) + break + elif "Explicit valence for atom" in text: + iat = int(line.split("#")[-1].split()[0]) + indices = [iat] + return indices + +def _rdmol_for_image (mol, remove_hydrogens=True): + """ + Convert PLAMS molecule to an RDKit molecule specifically for a 2D image + """ + from rdkit.Chem import AllChem + from rdkit.Chem import RemoveHs + + #rdmol = mol.to_rdmol_new() + rdmol = to_rdmol(mol, presanitize=True) + + # Flatten the molecule + AllChem.Compute2DCoords(rdmol) + # Remove the Hs only if there are carbon atoms in this system + # Otherwise this will turn an OH radical into a water molecule. + carbons = [i for i,at in enumerate(mol.atoms) if at.symbol in ["C", "Si"]] + if remove_hydrogens and len(carbons) > 0: + rdmol = RemoveHs(rdmol) + else: + for atom in rdmol.GetAtoms(): + atom.SetNoImplicit(True) + + ids = [c.GetId() for c in rdmol.GetConformers()] + for cid in ids: + rdmol.RemoveConformer(cid) + return rdmol + +def _MolsToGridSVG(mols, molsPerRow=3, subImgSize=(200, 200), legends=None, highlightAtomLists=None, + highlightBondLists=None, drawOptions=None, **kwargs): + """ + Replaces the old version of this function in our RDKit for a more recent one, with more options + """ + from rdkit.Chem.Draw import rdMolDraw2D + from rdkit.Geometry.rdGeometry import Point2D + + if legends is None: + legends = [''] * len(mols) + + nRows = len(mols) // molsPerRow + if len(mols) % molsPerRow: + nRows += 1 + + fullSize = (molsPerRow * subImgSize[0], nRows * subImgSize[1]) + + d2d = rdMolDraw2D.MolDraw2DSVG(fullSize[0], fullSize[1], subImgSize[0], subImgSize[1]) + if drawOptions is not None: + d2d.SetDrawOptions(drawOptions) + else: + dops = d2d.drawOptions() + for k, v in list(kwargs.items()): + if hasattr(dops, k): + setattr(dops, k, v) + del kwargs[k] + + d2d.DrawMolecules(list(mols), legends=legends or None, highlightAtoms=highlightAtomLists or [], + highlightBonds=highlightBondLists or [], **kwargs) + d2d.FinishDrawing() + res = d2d.GetDrawingText() + return res From 2d8a295e95a8fc4ef0362fcf508574e7aaffaad0 Mon Sep 17 00:00:00 2001 From: Rosa Bulo Date: Fri, 22 Nov 2024 12:38:15 +0100 Subject: [PATCH 2/5] Rosa Bulo (REB) SCMSUITE-- SO102: Cleaned up to_image and get_reaction_image. - to_image now works with svg in JupyterLab --- interfaces/molecule/rdkit.py | 207 ++++++++++++++++++----------------- 1 file changed, 106 insertions(+), 101 deletions(-) diff --git a/interfaces/molecule/rdkit.py b/interfaces/molecule/rdkit.py index 2c312c6af..c5fa72389 100644 --- a/interfaces/molecule/rdkit.py +++ b/interfaces/molecule/rdkit.py @@ -1296,13 +1296,15 @@ def canonicalize_mol(mol, inplace=False, **kwargs): ret.atoms = [at for _, at in sorted(zip(idx_rank, ret.atoms), reverse=True)] return ret -def to_image (mol, remove_hydrogens=True, filename=None, format="svg", as_string=True): - """ +@requires_optional_package("rdkit") +def to_image(mol, remove_hydrogens=True, filename=None, format="svg", as_string=True): + """ Convert single molecule to single image object * ``mol`` -- PLAMS Molecule object * ``remove_hydrogens`` -- Wether or not to remove the H-atoms from the image * ``filename`` -- Optional: Name of image file to be created. + * ``format`` -- One of "svg", "png", "eps", "pdf", "jpeg" * ``as_string`` -- Returns the image as a string or bytestring. If set to False, the original format will be returned, which can be either a PIL image or SVG text We do this because after converting a PIL image to a bytestring it is not possible @@ -1311,7 +1313,6 @@ def to_image (mol, remove_hydrogens=True, filename=None, format="svg", as_string """ from io import BytesIO from PIL import Image - from rdkit import Chem from rdkit.Chem import Draw from rdkit.Chem.Draw import rdMolDraw2D from rdkit.Chem.Draw import MolsToGridImage @@ -1323,7 +1324,7 @@ def to_image (mol, remove_hydrogens=True, filename=None, format="svg", as_string for ext in extensions[1:]: classes[ext] = None # PNG can only be created in this way with later version of RDKit - if hasattr(rdMolDraw2D,"MolDraw2DCairo"): + if hasattr(rdMolDraw2D, "MolDraw2DCairo"): for ext in extensions[1:]: classes[ext] = MolsToGridImage @@ -1334,21 +1335,21 @@ def to_image (mol, remove_hydrogens=True, filename=None, format="svg", as_string if extension in classes.keys(): format = extension else: - msg = ["Image type %s not available."%(extension)] - msg += ["Available extensions are: %s"%(" ".join(extensions))] + msg = ["Image type %s not available." % (extension)] + msg += ["Available extensions are: %s" % (" ".join(extensions))] raise Exception("\n".join(msg)) else: - filename = ".".join(filename,format) + filename = ".".join(filename, format) if format not in classes.keys(): - raise Exception("Image type %s not available."%(format)) + raise Exception("Image type %s not available." % (format)) rdmol = _rdmol_for_image(mol, remove_hydrogens) # Draw the image if classes[format.lower()] is None: # With AMS version of RDKit MolsToGridImage fails for eps (because of paste) - img = Draw.MolToImage(rdmol,size=(200,100)) + img = Draw.MolToImage(rdmol, size=(200, 100)) buf = BytesIO() img.save(buf, format=format) img_text = buf.getvalue() @@ -1356,31 +1357,36 @@ def to_image (mol, remove_hydrogens=True, filename=None, format="svg", as_string # This fails for a C=C=O molecule, with AMS rdkit version img = classes[format.lower()]([rdmol], molsPerRow=1, subImgSize=(200, 100)) img_text = img - if isinstance(img,Image.Image): + if isinstance(img, Image.Image): buf = BytesIO() img.save(buf, format=format) img_text = buf.getvalue() - + # If I do not make this correction to the SVG text, it is not readable in JupyterLab + if format.lower() == "svg": + img_text = _correct_svg(img_text) + # Write to file, if required if filename is not None: mode = "w" - if isinstance(img_text,bytes): + if isinstance(img_text, bytes): mode = "wb" - with open(filename,mode) as outfile: + with open(filename, mode) as outfile: outfile.write(img_text) if as_string: img = img_text return img -def get_reaction_image (reactants, products, filename=None, format="svg", as_string=True): - """ + +@requires_optional_package("rdkit") +def get_reaction_image(reactants, products, filename=None, format="svg", as_string=True): + """ Create a 2D reaction image from reactants and products (PLAMS molecules) * ``reactants`` -- Iterable of PLAMS Molecule objects representing the reactants. * ``products`` -- Iterable of PLAMS Molecule objects representig the products. * ``filename`` -- Optional: Name of image file to be created. - * ``format`` -- The format of the image (svg, png, eps). + * ``format`` -- The format of the image (svg, png, eps, pdf, jpeg). The extension in the filename, if provided, takes precedence. * ``as_string`` -- Returns the image as a string or bytestring. If set to False, the original format will be returned, which can be either a PIL image or SVG text @@ -1395,119 +1401,121 @@ def get_reaction_image (reactants, products, filename=None, format="svg", as_str if extension in extensions: format = extension else: - msg = ["Image type %s not available."%(extension)] - msg += ["Available extensions are: %s"%(" ".join(extensions))] + msg = ["Image type %s not available." % (extension)] + msg += ["Available extensions are: %s" % (" ".join(extensions))] raise Exception("\n".join(msg)) else: - filename = ".".join(filename,format) + filename = ".".join(filename, format) if format.lower() not in extensions: - raise Exception("Image type %s not available."%(format)) + raise Exception("Image type %s not available." % (format)) # Get the actual image if format.lower() == "svg": - img_text = get_reaction_image_svg (reactants, products) + img_text = get_reaction_image_svg(reactants, products) else: img_text = get_reaction_image_pil(reactants, products, format, as_string=as_string) # Write to file, if required if filename is not None: - mode = "w" - if isinstance(img_text,bytes): + mode = "w" + if isinstance(img_text, bytes): mode = "wb" - with open(filename,mode) as outfile: + with open(filename, mode) as outfile: outfile.write(img_text) return img_text -def get_reaction_image_svg (reactants, products, width=200, height=100): - """ + +def get_reaction_image_svg(reactants, products, width=200, height=100): + """ Create a 2D reaction image from reactants and products (PLAMS molecules) * ``reactants`` -- Iterable of PLAMS Molecule objects representing the reactants. * ``products`` -- Iterable of PLAMS Molecule objects representig the products. * Returns -- SVG image text file. - """ + """ from rdkit import Chem - def svg_arrow (x1, y1, x2, y2, prefix=""): + def svg_arrow(x1, y1, x2, y2, prefix=""): """ The reaction arrow in html format """ # The arrow head - l = ['<%sdefs> <%smarker id="arrow" viewBox="0 0 10 10" refX="5" refY="5" '%(prefix, prefix)] + l = ['<%sdefs> <%smarker id="arrow" viewBox="0 0 10 10" refX="5" refY="5" ' % (prefix, prefix)] l += ['markerWidth="6" markerHeight="6" '] - l += ['orient="auto-start-reverse"> <%spath d="M 0 0 L 10 5 L 0 10 z" />'%(prefix, prefix, prefix)] - arrow = ''.join(l) + l += ['orient="auto-start-reverse"> '] + l += ['<%spath d="M 0 0 L 10 5 L 0 10 z" />' % (prefix, prefix, prefix)] + arrow = "".join(l) # The line - l = ['<%sline x1="%i" y1="%i" x2="%i" y2="%i" '%(prefix, x1, y1, x2, y2)] + l = ['<%sline x1="%i" y1="%i" x2="%i" y2="%i" ' % (prefix, x1, y1, x2, y2)] l += ['stroke="black" marker-end="url(#arrow)" />'] - line = ''.join(l) + line = "".join(l) return [arrow, line] - def add_plus_signs_svg (img_text, width, height, nmols, nreactants, prefix=""): + def add_plus_signs_svg(img_text, width, height, nmols, nreactants, prefix=""): """ Add the lines with + signs to the SVG image """ y = int(0.55 * height) t = [] - for i in range(nmols-1): + for i in range(nmols - 1): x = int(((i + 1) * width) - (0.1 * width)) if i + 1 in (nreactants, nreactants + 1): continue - t += ['<%stext x="%i" y="%i" font-size="16">+'%(prefix,x,y,prefix)] + t += ['<%stext x="%i" y="%i" font-size="16">+' % (prefix, x, y, prefix)] lines = img_text.split("\n") - lines = lines[:-2] + t + lines[-2:] + lines = lines[:-2] + t + lines[-2:] return "\n".join(lines) - def add_arrow_svg (img_text, width, height, nreactants, prefix=""): + def add_arrow_svg(img_text, width, height, nreactants, prefix=""): """ Add the arrow to the SVG image """ y = int(0.5 * height) x1 = int((nreactants * width) + (0.3 * width)) x2 = int((nreactants * width) + (0.7 * width)) - t = svg_arrow (x1, y, x2, y, prefix) + t = svg_arrow(x1, y, x2, y, prefix) lines = img_text.split("\n") lines = lines[:-2] + t + lines[-2:] return "\n".join(lines) # Get the rdkit molecules rdmols = [_rdmol_for_image(mol) for mol in reactants] - rdmols += [Chem.Mol()] # This is where the arrow will go + rdmols += [Chem.Mol()] # This is where the arrow will go rdmols += [_rdmol_for_image(mol) for mol in products] nmols = len(rdmols) # Place the molecules in a row of images subimg_size = [width, height] - kwargs = {"legendFontSize":16} #,"legendFraction":0.1} + kwargs = {"legendFontSize": 16} # ,"legendFraction":0.1} img_text = _MolsToGridSVG(rdmols, molsPerRow=nmols, subImgSize=subimg_size, **kwargs) - img_text = _correct_svg (img_text) + img_text = _correct_svg(img_text) # Add + and => nreactants = len(reactants) img_text = add_plus_signs_svg(img_text, width, height, nmols, nreactants) - img_text = add_arrow_svg (img_text, width, height, nreactants) + img_text = add_arrow_svg(img_text, width, height, nreactants) return img_text -def get_reaction_image_pil (reactants, products, format, width=200, height=100, as_string=True): - """ +def get_reaction_image_pil(reactants, products, format, width=200, height=100, as_string=True): + """ Create a 2D reaction image from reactants and products (PLAMS molecules) * ``reactants`` -- Iterable of PLAMS Molecule objects representing the reactants. * ``products`` -- Iterable of PLAMS Molecule objects representig the products. * Returns -- SVG image text file. - """ + """ from io import BytesIO from PIL import Image from PIL import ImageDraw from rdkit import Chem from rdkit.Chem.Draw import rdMolDraw2D, MolsToGridImage - def add_arrow_pil (img, width, height, nreactants): + def add_arrow_pil(img, width, height, nreactants): """ Add the arrow to the PIL image - """ + """ y1 = int(0.5 * height) y2 = y1 x1 = int((nreactants * width) + (0.3 * width)) @@ -1529,25 +1537,24 @@ def add_arrow_pil (img, width, height, nreactants): return img - def add_plus_signs_pil (img, width, height, nmols, nreactants): + def add_plus_signs_pil(img, width, height, nmols, nreactants): """ Add the lines with + signs to the SVG image """ - white = (255, 255, 255) black = (0, 0, 0) I1 = ImageDraw.Draw(img) y = int(0.5 * height) - #myfont = ImageFont.truetype('FreeMono.ttf', 25) + # myfont = ImageFont.truetype('FreeMono.ttf', 25) for i in range(nmols): x = int(((i + 1) * width) - (0.05 * width)) - if i + 1 in (nreactants,nreactants + 1): + if i + 1 in (nreactants, nreactants + 1): continue - #I1.text((x, y), "+", font=myfont, fill=black) + # I1.text((x, y), "+", font=myfont, fill=black) I1.text((x, y), "+", fill=black) return img - def join_pil_images (pil_images): + def join_pil_images(pil_images): """ Create a new image which connects the ones above with text """ @@ -1556,10 +1563,10 @@ def join_pil_images (pil_images): widths = [img.width for img in pil_images] width = sum(widths) height = max([img.height for img in pil_images]) - final_img = Image.new('RGB', (width, height), white) + final_img = Image.new("RGB", (width, height), white) # Concatenate the PIL images - for i,img in enumerate(pil_images): + for i, img in enumerate(pil_images): pos = sum(widths[:i]) h = int((height - img.height) / 2) final_img.paste(img, (pos, h)) @@ -1569,13 +1576,13 @@ def join_pil_images (pil_images): nreactants = len(reactants) nmols = nreactants + len(products) - if not hasattr(rdMolDraw2D, 'MolDraw2DCairo'): + if not hasattr(rdMolDraw2D, "MolDraw2DCairo"): # We are working with the old AMS version of RDKit white = (255, 255, 255) - rimages = [to_image(mol, format=format, as_string=False) for i,mol in enumerate(reactants)] - pimages = [to_image(mol, format=format, as_string=False) for i,mol in enumerate(products)] - blanc = Image.new('RGB', (width, height), white) - + rimages = [to_image(mol, format=format, as_string=False) for i, mol in enumerate(reactants)] + pimages = [to_image(mol, format=format, as_string=False) for i, mol in enumerate(products)] + blanc = Image.new("RGB", (width, height), white) + # Get the image (with arrow) nreactants = len(reactants) all_images = rimages + [blanc] + pimages @@ -1584,19 +1591,19 @@ def join_pil_images (pil_images): else: # We have a later version of RDKit that can regulate the font sizes rdmols = [_rdmol_for_image(mol) for mol in reactants] - rdmols += [Chem.Mol()] # This is where the arrow will go + rdmols += [Chem.Mol()] # This is where the arrow will go rdmols += [_rdmol_for_image(mol) for mol in products] - + # Place the molecules in a row of images - subimg_size = [width, height] - kwargs = {"legendFontSize":16} #,"legendFraction":0.1} + subimg_size = [width, height] + kwargs = {"legendFontSize": 16} # ,"legendFraction":0.1} img = MolsToGridImage(rdmols, molsPerRow=nmols + 1, subImgSize=subimg_size, **kwargs) - + # Add + and => img = add_plus_signs_pil(img, width, height, nmols, nreactants) img = add_arrow_pil(img, width, height, nreactants) - # Get the bytestring + # Get the bytestring img_text = img if as_string: buf = BytesIO() @@ -1605,22 +1612,22 @@ def join_pil_images (pil_images): return img_text -def _correct_svg (image): +def _correct_svg(image): """ Correct for a bug in the AMS rdkit created SVG file """ if not "svg:" in image: return image - image = image.replace("svg:","") + image = image.replace("svg:", "") lines = image.split("\n") - for iline,line in enumerate(lines): + for iline, line in enumerate(lines): if "xmlns:svg=" in line: - lines[iline] = line.replace("xmlns:svg","xmlns") + lines[iline] = line.replace("xmlns:svg", "xmlns") break image = "\n".join(lines) return image -def _presanitize (mol, rdmol): +def _presanitize(mol, rdmol): """ Change bonding and atom charges to avoid failed sanitization @@ -1646,8 +1653,7 @@ def _update_system_for_sanitation(rdmol, altered_bonds, altered_charge): """ Change bond orders and charges if rdkit molecule, so that sanitiation will succeed. """ - from rdkit import Chem - from rdkit.Chem import Atom + from rdkit import Chem emol = Chem.RWMol(rdmol) for (iat, jat), order in altered_bonds.items(): @@ -1671,7 +1677,7 @@ def _update_system_for_sanitation(rdmol, altered_bonds, altered_charge): atom.SetFormalCharge(q) return rdmol -def _kekulize (mol, text): +def _kekulize(mol, text): """ Kekulize the atoms indicated as problematic by RDKit @@ -1683,20 +1689,20 @@ def _kekulize (mol, text): from scm.plams import PeriodicTable as PT # Find the indices of the problematic atoms - indices = _find_aromatic_sequence (mol, text) + indices = _find_aromatic_sequence(mol, text) # Order the indices, so that they are consecutive in the molecule start = 0 - for i,iat in enumerate(indices): - neighbors = [mol.index(at)-1 for at in mol.neighbors(mol.atoms[iat])] + for i, iat in enumerate(indices): + neighbors = [mol.index(at) - 1 for at in mol.neighbors(mol.atoms[iat])] relevant_neighbors = [jat for jat in neighbors if jat in indices] if len(relevant_neighbors) == 1: start = i break iat = indices[start] atoms = [iat] - while(1): - neighbors = [mol.index(at)-1 for at in mol.neighbors(mol.atoms[iat])] + while 1: + neighbors = [mol.index(at) - 1 for at in mol.neighbors(mol.atoms[iat])] relevant_neighbors = [jat for jat in neighbors if jat in indices and not jat in atoms] if len(relevant_neighbors) == 0: break @@ -1705,8 +1711,8 @@ def _kekulize (mol, text): if len(atoms) < len(indices): raise Exception("The unkekulized atoms are not in a consecutive chain") indices = atoms - - if len(indices) > 1: + + if len(indices) > 1: # The first bond order is set to 2. new_order = 2 shift = 0 @@ -1724,13 +1730,13 @@ def _kekulize (mol, text): shift = 1 # Set the bond orders along the chain to 2, 1, 2, 1,... altered_bonds = {} - for i,iat in enumerate(indices[:-1]): + for i, iat in enumerate(indices[:-1]): bonds = mol.atoms[iat].bonds - bond = [b for b in bonds if b.other_end(mol.atoms[iat]) == mol.atoms[indices[i+1]]][0] + bond = [b for b in bonds if b.other_end(mol.atoms[iat]) == mol.atoms[indices[i + 1]]][0] bond.order = new_order - altered_bonds[((mol.index(bond.atom1) - 1 , mol.index(bond.atom2) - 1))] = new_order - new_order = ((i+shift) % 2) + 1 - + altered_bonds[((mol.index(bond.atom1) - 1, mol.index(bond.atom2) - 1))] = new_order + new_order = ((i + shift) % 2) + 1 + # If the atom at the chain end has the wrong bond order, give it a charge. for iat in indices[::-1]: at = mol.atoms[iat] @@ -1743,8 +1749,8 @@ def _kekulize (mol, text): # Adjust the sign of the charge (taking already assigned atomic charges into account) charges = [] - for i,at in enumerate(mol.atoms): - q = 0. + for i, at in enumerate(mol.atoms): + q = 0.0 if "rdkit" in at.properties: if "charge" in at.properties.rdkit: q = at.properties.rdkit.charge @@ -1772,7 +1778,7 @@ def _kekulize (mol, text): return altered_bonds, altered_charge -def _find_aromatic_sequence (mol, text): +def _find_aromatic_sequence(mol, text): """ Find the sequence of atoms with 1.5 bond orders """ @@ -1785,12 +1791,12 @@ def _find_aromatic_sequence (mol, text): indices = [int(w) for w in text.split()] if len(indices) > 1: return indices - line = "atom %i marked aromatic"%(indices[0]) - text = line + line = "atom %i marked aromatic" % (indices[0]) + text = line if "marked aromatic" in text: iat = int(line.split("atom")[-1].split()[0]) - indices = [iat] - while(iat is not None): + indices = [iat] + while iat is not None: at = mol.atoms[iat] iat = None for bond in at.bonds: @@ -1807,23 +1813,23 @@ def _find_aromatic_sequence (mol, text): indices = [iat] return indices -def _rdmol_for_image (mol, remove_hydrogens=True): +def _rdmol_for_image(mol, remove_hydrogens=True): """ Convert PLAMS molecule to an RDKit molecule specifically for a 2D image """ from rdkit.Chem import AllChem from rdkit.Chem import RemoveHs - #rdmol = mol.to_rdmol_new() + # rdmol = mol.to_rdmol_new() rdmol = to_rdmol(mol, presanitize=True) - + # Flatten the molecule AllChem.Compute2DCoords(rdmol) # Remove the Hs only if there are carbon atoms in this system # Otherwise this will turn an OH radical into a water molecule. - carbons = [i for i,at in enumerate(mol.atoms) if at.symbol in ["C", "Si"]] + carbons = [i for i, at in enumerate(mol.atoms) if at.symbol in ["C", "Si"]] if remove_hydrogens and len(carbons) > 0: - rdmol = RemoveHs(rdmol) + rdmol = RemoveHs(rdmol) else: for atom in rdmol.GetAtoms(): atom.SetNoImplicit(True) @@ -1839,10 +1845,9 @@ def _MolsToGridSVG(mols, molsPerRow=3, subImgSize=(200, 200), legends=None, high Replaces the old version of this function in our RDKit for a more recent one, with more options """ from rdkit.Chem.Draw import rdMolDraw2D - from rdkit.Geometry.rdGeometry import Point2D if legends is None: - legends = [''] * len(mols) + legends = [""] * len(mols) nRows = len(mols) // molsPerRow if len(mols) % molsPerRow: From df2653d34c2e49b331b4b5fffc651cb3032dfc4b Mon Sep 17 00:00:00 2001 From: Rosa Bulo Date: Fri, 22 Nov 2024 13:23:40 +0100 Subject: [PATCH 3/5] Rosa Bulo (REB) SCMSUITE-- SO102: Cleaned up formatting some more --- interfaces/molecule/rdkit.py | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/interfaces/molecule/rdkit.py b/interfaces/molecule/rdkit.py index c5fa72389..706d83d73 100644 --- a/interfaces/molecule/rdkit.py +++ b/interfaces/molecule/rdkit.py @@ -1296,6 +1296,7 @@ def canonicalize_mol(mol, inplace=False, **kwargs): ret.atoms = [at for _, at in sorted(zip(idx_rank, ret.atoms), reverse=True)] return ret + @requires_optional_package("rdkit") def to_image(mol, remove_hydrogens=True, filename=None, format="svg", as_string=True): """ @@ -1498,6 +1499,7 @@ def add_arrow_svg(img_text, width, height, nreactants, prefix=""): return img_text + def get_reaction_image_pil(reactants, products, format, width=200, height=100, as_string=True): """ Create a 2D reaction image from reactants and products (PLAMS molecules) @@ -1612,6 +1614,7 @@ def join_pil_images(pil_images): return img_text + def _correct_svg(image): """ Correct for a bug in the AMS rdkit created SVG file @@ -1627,6 +1630,7 @@ def _correct_svg(image): image = "\n".join(lines) return image + def _presanitize(mol, rdmol): """ Change bonding and atom charges to avoid failed sanitization @@ -1634,6 +1638,7 @@ def _presanitize(mol, rdmol): Note: Used by to_rdmol """ from rdkit import Chem + mol = mol.copy() for i in range(10): try: @@ -1649,6 +1654,7 @@ def _presanitize(mol, rdmol): raise stored_exc return rdmol + def _update_system_for_sanitation(rdmol, altered_bonds, altered_charge): """ Change bond orders and charges if rdkit molecule, so that sanitiation will succeed. @@ -1677,6 +1683,7 @@ def _update_system_for_sanitation(rdmol, altered_bonds, altered_charge): atom.SetFormalCharge(q) return rdmol + def _kekulize(mol, text): """ Kekulize the atoms indicated as problematic by RDKit @@ -1778,6 +1785,7 @@ def _kekulize(mol, text): return altered_bonds, altered_charge + def _find_aromatic_sequence(mol, text): """ Find the sequence of atoms with 1.5 bond orders @@ -1813,6 +1821,7 @@ def _find_aromatic_sequence(mol, text): indices = [iat] return indices + def _rdmol_for_image(mol, remove_hydrogens=True): """ Convert PLAMS molecule to an RDKit molecule specifically for a 2D image @@ -1839,9 +1848,18 @@ def _rdmol_for_image(mol, remove_hydrogens=True): rdmol.RemoveConformer(cid) return rdmol -def _MolsToGridSVG(mols, molsPerRow=3, subImgSize=(200, 200), legends=None, highlightAtomLists=None, - highlightBondLists=None, drawOptions=None, **kwargs): - """ + +def _MolsToGridSVG( + mols, + molsPerRow=3, + subImgSize=(200, 200), + legends=None, + highlightAtomLists=None, + highlightBondLists=None, + drawOptions=None, + **kwargs, +): + """ Replaces the old version of this function in our RDKit for a more recent one, with more options """ from rdkit.Chem.Draw import rdMolDraw2D @@ -1865,8 +1883,13 @@ def _MolsToGridSVG(mols, molsPerRow=3, subImgSize=(200, 200), legends=None, high setattr(dops, k, v) del kwargs[k] - d2d.DrawMolecules(list(mols), legends=legends or None, highlightAtoms=highlightAtomLists or [], - highlightBonds=highlightBondLists or [], **kwargs) + d2d.DrawMolecules( + list(mols), + legends=legends or None, + highlightAtoms=highlightAtomLists or [], + highlightBonds=highlightBondLists or [], + **kwargs, + ) d2d.FinishDrawing() res = d2d.GetDrawingText() return res From 438c6dddcadbdc3170898da9ea801cf386006f15 Mon Sep 17 00:00:00 2001 From: Rosa Bulo Date: Fri, 22 Nov 2024 14:17:09 +0100 Subject: [PATCH 4/5] Rosa Bulo (REB) SCMSUITE-- SO102: Cleaned file for mypy --- interfaces/molecule/rdkit.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/interfaces/molecule/rdkit.py b/interfaces/molecule/rdkit.py index 706d83d73..d137e6b3a 100644 --- a/interfaces/molecule/rdkit.py +++ b/interfaces/molecule/rdkit.py @@ -1,4 +1,6 @@ from typing import List, Literal, Optional, overload, TYPE_CHECKING +from typing import Dict +from typing import Any import random import sys from warnings import warn @@ -1320,7 +1322,7 @@ def to_image(mol, remove_hydrogens=True, filename=None, format="svg", as_string= extensions = ["svg", "png", "eps", "pdf", "jpeg"] - classes = {} + classes: Dict[str, Any] = {} classes["svg"] = _MolsToGridSVG for ext in extensions[1:]: classes[ext] = None @@ -1340,7 +1342,7 @@ def to_image(mol, remove_hydrogens=True, filename=None, format="svg", as_string= msg += ["Available extensions are: %s" % (" ".join(extensions))] raise Exception("\n".join(msg)) else: - filename = ".".join(filename, format) + filename = ".".join([filename, format]) if format not in classes.keys(): raise Exception("Image type %s not available." % (format)) @@ -1406,7 +1408,7 @@ def get_reaction_image(reactants, products, filename=None, format="svg", as_stri msg += ["Available extensions are: %s" % (" ".join(extensions))] raise Exception("\n".join(msg)) else: - filename = ".".join(filename, format) + filename = ".".join([filename, format]) if format.lower() not in extensions: raise Exception("Image type %s not available." % (format)) @@ -1801,6 +1803,7 @@ def _find_aromatic_sequence(mol, text): return indices line = "atom %i marked aromatic" % (indices[0]) text = line + iat: Any if "marked aromatic" in text: iat = int(line.split("atom")[-1].split()[0]) indices = [iat] From 1b9c2f85766c119749741ea0682415bd71a5be4f Mon Sep 17 00:00:00 2001 From: David Ormrod Morley Date: Mon, 25 Nov 2024 16:35:15 +0100 Subject: [PATCH 5/5] Minor changes and add unit test SO102 --- interfaces/molecule/rdkit.py | 90 ++++++++++++++++---------- pyproject.toml | 3 +- unit_tests/test_molecule_interfaces.py | 42 +++++++++++- unit_tests/xyz/product1.xyz | 5 ++ unit_tests/xyz/product2.xyz | 18 ++++++ unit_tests/xyz/reactant1.xyz | 6 ++ unit_tests/xyz/reactant2.xyz | 17 +++++ 7 files changed, 145 insertions(+), 36 deletions(-) create mode 100644 unit_tests/xyz/product1.xyz create mode 100644 unit_tests/xyz/product2.xyz create mode 100644 unit_tests/xyz/reactant1.xyz create mode 100644 unit_tests/xyz/reactant2.xyz diff --git a/interfaces/molecule/rdkit.py b/interfaces/molecule/rdkit.py index d137e6b3a..18dfe5e9d 100644 --- a/interfaces/molecule/rdkit.py +++ b/interfaces/molecule/rdkit.py @@ -1,6 +1,4 @@ -from typing import List, Literal, Optional, overload, TYPE_CHECKING -from typing import Dict -from typing import Any +from typing import List, Literal, Optional, overload, TYPE_CHECKING, Sequence, Dict, Any import random import sys from warnings import warn @@ -34,6 +32,8 @@ "get_conformations", "yield_coords", "canonicalize_mol", + "to_image", + "get_reaction_image", ] @@ -1300,7 +1300,14 @@ def canonicalize_mol(mol, inplace=False, **kwargs): @requires_optional_package("rdkit") -def to_image(mol, remove_hydrogens=True, filename=None, format="svg", as_string=True): +@requires_optional_package("PIL") +def to_image( + mol: Molecule, + remove_hydrogens: bool = True, + filename: Optional[str] = None, + fmt: str = "svg", + as_string: bool = True, +): """ Convert single molecule to single image object @@ -1336,36 +1343,36 @@ def to_image(mol, remove_hydrogens=True, filename=None, format="svg", as_string= if "." in filename: extension = filename.split(".")[-1] if extension in classes.keys(): - format = extension + fmt = extension else: - msg = ["Image type %s not available." % (extension)] - msg += ["Available extensions are: %s" % (" ".join(extensions))] + msg = [f"Image type {extension} not available."] + msg += [f"Available extensions are: {' '.join(extensions)}"] raise Exception("\n".join(msg)) else: - filename = ".".join([filename, format]) + filename = ".".join([filename, fmt]) - if format not in classes.keys(): - raise Exception("Image type %s not available." % (format)) + if fmt not in classes.keys(): + raise Exception(f"Image type {fmt} not available.") rdmol = _rdmol_for_image(mol, remove_hydrogens) # Draw the image - if classes[format.lower()] is None: + if classes[fmt.lower()] is None: # With AMS version of RDKit MolsToGridImage fails for eps (because of paste) img = Draw.MolToImage(rdmol, size=(200, 100)) buf = BytesIO() - img.save(buf, format=format) + img.save(buf, format=fmt) img_text = buf.getvalue() else: # This fails for a C=C=O molecule, with AMS rdkit version - img = classes[format.lower()]([rdmol], molsPerRow=1, subImgSize=(200, 100)) + img = classes[fmt.lower()]([rdmol], molsPerRow=1, subImgSize=(200, 100)) img_text = img if isinstance(img, Image.Image): buf = BytesIO() - img.save(buf, format=format) + img.save(buf, format=fmt) img_text = buf.getvalue() # If I do not make this correction to the SVG text, it is not readable in JupyterLab - if format.lower() == "svg": + if fmt.lower() == "svg": img_text = _correct_svg(img_text) # Write to file, if required @@ -1382,14 +1389,21 @@ def to_image(mol, remove_hydrogens=True, filename=None, format="svg", as_string= @requires_optional_package("rdkit") -def get_reaction_image(reactants, products, filename=None, format="svg", as_string=True): +@requires_optional_package("PIL") +def get_reaction_image( + reactants: Sequence[Molecule], + products: Sequence[Molecule], + filename: Optional[str] = None, + fmt: str = "svg", + as_string: bool = True, +): """ Create a 2D reaction image from reactants and products (PLAMS molecules) * ``reactants`` -- Iterable of PLAMS Molecule objects representing the reactants. - * ``products`` -- Iterable of PLAMS Molecule objects representig the products. + * ``products`` -- Iterable of PLAMS Molecule objects representing the products. * ``filename`` -- Optional: Name of image file to be created. - * ``format`` -- The format of the image (svg, png, eps, pdf, jpeg). + * ``fmt`` -- The format of the image (svg, png, eps, pdf, jpeg). The extension in the filename, if provided, takes precedence. * ``as_string`` -- Returns the image as a string or bytestring. If set to False, the original format will be returned, which can be either a PIL image or SVG text @@ -1402,22 +1416,22 @@ def get_reaction_image(reactants, products, filename=None, format="svg", as_stri if "." in filename: extension = filename.split(".")[-1] if extension in extensions: - format = extension + fmt = extension else: - msg = ["Image type %s not available." % (extension)] - msg += ["Available extensions are: %s" % (" ".join(extensions))] + msg = [f"Image type {extension} not available."] + msg += [f"Available extensions are: {' '.join(extensions)}"] raise Exception("\n".join(msg)) else: - filename = ".".join([filename, format]) + filename = ".".join([filename, fmt]) - if format.lower() not in extensions: - raise Exception("Image type %s not available." % (format)) + if fmt.lower() not in extensions: + raise Exception(f"Image type {fmt} not available.") # Get the actual image - if format.lower() == "svg": + if fmt.lower() == "svg": img_text = get_reaction_image_svg(reactants, products) else: - img_text = get_reaction_image_pil(reactants, products, format, as_string=as_string) + img_text = get_reaction_image_pil(reactants, products, fmt, as_string=as_string) # Write to file, if required if filename is not None: @@ -1429,12 +1443,14 @@ def get_reaction_image(reactants, products, filename=None, format="svg", as_stri return img_text -def get_reaction_image_svg(reactants, products, width=200, height=100): +def get_reaction_image_svg( + reactants: Sequence[Molecule], products: Sequence[Molecule], width: int = 200, height: int = 100 +): """ Create a 2D reaction image from reactants and products (PLAMS molecules) * ``reactants`` -- Iterable of PLAMS Molecule objects representing the reactants. - * ``products`` -- Iterable of PLAMS Molecule objects representig the products. + * ``products`` -- Iterable of PLAMS Molecule objects representing the products. * Returns -- SVG image text file. """ from rdkit import Chem @@ -1502,12 +1518,19 @@ def add_arrow_svg(img_text, width, height, nreactants, prefix=""): return img_text -def get_reaction_image_pil(reactants, products, format, width=200, height=100, as_string=True): +def get_reaction_image_pil( + reactants: Sequence[Molecule], + products: Sequence[Molecule], + fmt: str, + width: int = 200, + height: int = 100, + as_string: bool = True, +): """ Create a 2D reaction image from reactants and products (PLAMS molecules) * ``reactants`` -- Iterable of PLAMS Molecule objects representing the reactants. - * ``products`` -- Iterable of PLAMS Molecule objects representig the products. + * ``products`` -- Iterable of PLAMS Molecule objects representing the products. * Returns -- SVG image text file. """ from io import BytesIO @@ -1583,8 +1606,8 @@ def join_pil_images(pil_images): if not hasattr(rdMolDraw2D, "MolDraw2DCairo"): # We are working with the old AMS version of RDKit white = (255, 255, 255) - rimages = [to_image(mol, format=format, as_string=False) for i, mol in enumerate(reactants)] - pimages = [to_image(mol, format=format, as_string=False) for i, mol in enumerate(products)] + rimages = [to_image(mol, fmt=fmt, as_string=False) for i, mol in enumerate(reactants)] + pimages = [to_image(mol, fmt=fmt, as_string=False) for i, mol in enumerate(products)] blanc = Image.new("RGB", (width, height), white) # Get the image (with arrow) @@ -1611,7 +1634,7 @@ def join_pil_images(pil_images): img_text = img if as_string: buf = BytesIO() - img.save(buf, format=format) + img.save(buf, format=fmt) img_text = buf.getvalue() return img_text @@ -1832,7 +1855,6 @@ def _rdmol_for_image(mol, remove_hydrogens=True): from rdkit.Chem import AllChem from rdkit.Chem import RemoveHs - # rdmol = mol.to_rdmol_new() rdmol = to_rdmol(mol, presanitize=True) # Flatten the molecule diff --git a/pyproject.toml b/pyproject.toml index a81978c15..6b43c4896 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,8 @@ analysis = [ "matplotlib>=3.5.1", "pandas>=1.5.2", "networkx>=2.7.1", - "natsort>=8.1.0" + "natsort>=8.1.0", + "pillow>=9.2.0" ] docs = [ diff --git a/unit_tests/test_molecule_interfaces.py b/unit_tests/test_molecule_interfaces.py index 0f705672c..9bef0d52b 100644 --- a/unit_tests/test_molecule_interfaces.py +++ b/unit_tests/test_molecule_interfaces.py @@ -6,7 +6,15 @@ from scm.plams.interfaces.molecule.ase import toASE, fromASE from scm.plams.unit_tests.test_helpers import get_mock_find_spec, get_mock_open_function from scm.plams.core.errors import MissingOptionalPackageError -from scm.plams.interfaces.molecule.rdkit import from_rdmol, to_rdmol, from_smiles, to_smiles, from_smarts +from scm.plams.interfaces.molecule.rdkit import ( + from_rdmol, + to_rdmol, + from_smiles, + to_smiles, + from_smarts, + to_image, + get_reaction_image, +) from scm.plams.interfaces.molecule.packmol import packmol @@ -230,6 +238,38 @@ def test_to_smiles_and_from_smiles_requires_rdkit_package(self, plams_mols, smil with pytest.raises(MissingOptionalPackageError): from_smiles(smiles[0]) + def test_to_image_and_get_reaction_image_can_generate_img_files(self, xyz_folder): + from pathlib import Path + import shutil + + # Given molecules + reactants = [Molecule(f"{xyz_folder}/reactant{i}.xyz") for i in range(1, 3)] + products = [Molecule(f"{xyz_folder}/product{i}.xyz") for i in range(1, 3)] + + # When create images for molecules and reactions + result_dir = Path("result_images/rdkit") + try: + shutil.rmtree(result_dir) + except FileNotFoundError: + pass + result_dir.mkdir(parents=True, exist_ok=True) + + for i, m in enumerate(reactants): + m.guess_bonds() + to_image(m, filename=f"{result_dir}/reactant{i+1}.png") + + for i, m in enumerate(products): + m.guess_bonds() + to_image(m, filename=f"{result_dir}/product{i+1}.png") + + get_reaction_image(reactants, products, filename=f"{result_dir}/reaction.png") + + # Then image files are successfully created + # N.B. for this test just check the files are generated, not that the contents is correct + for f in ["reactant1.png", "reactant2.png", "product1.png", "product2.png", "reaction.png"]: + file = result_dir / f + assert file.exists() + class TestPackmol: """ diff --git a/unit_tests/xyz/product1.xyz b/unit_tests/xyz/product1.xyz new file mode 100644 index 000000000..9c479f92e --- /dev/null +++ b/unit_tests/xyz/product1.xyz @@ -0,0 +1,5 @@ +3 + +O -0.0898298722 -0.1823078766 -0.4030695566 +H 0.0892757418 0.7223466689 -0.0923666662 +H 0.0251541304 -0.7229767923 0.3978492228 \ No newline at end of file diff --git a/unit_tests/xyz/product2.xyz b/unit_tests/xyz/product2.xyz new file mode 100644 index 000000000..6d6089de6 --- /dev/null +++ b/unit_tests/xyz/product2.xyz @@ -0,0 +1,18 @@ +16 + +C -0.5058742951 0.7197846249 1.4873502327 +O -1.8467123470 1.0371356691 1.5578617129 +O 0.2426104621 0.8031328420 2.4444063939 +H -1.2336711895 -0.1871011373 -3.0847664179 +H 1.1181735289 -0.9674586781 -3.3575709511 +O -2.2996282296 0.6761194013 -0.9441961430 +H 1.8534242875 -0.1774201547 0.8154069061 +H -2.5401517031 0.9226546484 -0.0164069755 +H 2.6629785819 -0.9669531279 -1.4040236989 +H -2.0724918515 1.3639645877 2.4489574275 +C -1.0116321368 0.2602883042 -0.9839670361 +C -0.5476048834 -0.1835433773 -2.2354228233 +C 0.7686298918 -0.6221529110 -2.3859139467 +C 1.6373720529 -0.6205255079 -1.2863630033 +C 1.1825135366 -0.1770749455 -0.0425597678 +C -0.1335817057 0.2638847621 0.1265310905 \ No newline at end of file diff --git a/unit_tests/xyz/reactant1.xyz b/unit_tests/xyz/reactant1.xyz new file mode 100644 index 000000000..81b5ba591 --- /dev/null +++ b/unit_tests/xyz/reactant1.xyz @@ -0,0 +1,6 @@ +4 + +H -0.2376962379 -0.7861223708 -0.7251893655 +H 0.2018867953 0.7880914552 -0.7285471665 +O -0.2208830412 0.0589617158 -0.1564035523 +H 0.2343924838 -0.0668018003 0.7570340842 \ No newline at end of file diff --git a/unit_tests/xyz/reactant2.xyz b/unit_tests/xyz/reactant2.xyz new file mode 100644 index 000000000..f2a51a242 --- /dev/null +++ b/unit_tests/xyz/reactant2.xyz @@ -0,0 +1,17 @@ +15 + +C -0.7714286028 0.5721884475 1.5648994528 +O -1.9163545429 1.2651552598 1.3585245940 +O -0.3232871439 0.3453806305 2.6856889384 +H -0.2386198927 0.0473169758 -3.1356983004 +H 1.8106086100 -1.3053060069 -2.7858132139 +O -1.6768410850 1.0650295484 -1.2558970875 +H 1.4319019123 -0.8043435938 1.4748215886 +H 2.6580511680 -1.7004062192 -0.4892887246 +H -2.2596018166 1.4887930327 2.2433128310 +C -0.6549422587 0.4025760289 -1.0187446070 +C 0.1233087332 -0.1571411828 -2.1280100126 +C 1.2660074098 -0.8988286619 -1.9347794419 +C 1.7441534022 -1.1272237399 -0.6323129834 +C 1.0490198939 -0.6209878769 0.4726114196 +C -0.1137347868 0.1066333578 0.3111295469 \ No newline at end of file