diff --git a/molecularnodes/blender/mesh.py b/molecularnodes/blender/mesh.py index f241ef7b..a7386e16 100644 --- a/molecularnodes/blender/mesh.py +++ b/molecularnodes/blender/mesh.py @@ -379,7 +379,7 @@ def create_data_object( ): # still requires a unique call TODO: figure out why # I think this has to do with the bcif instancing extraction - array = np.unique(array) + # array = np.unique(array) locations = array["translation"] * world_scale if not collection: diff --git a/molecularnodes/color.py b/molecularnodes/color.py index 023f69ab..8b835ba9 100644 --- a/molecularnodes/color.py +++ b/molecularnodes/color.py @@ -4,7 +4,113 @@ import numpy as np import numpy.typing as npt +import math +def clamp(value, min_value, max_value): + return max(min_value, min(value, max_value)) + +class Lab: + Kn = 18 + Xn = 0.950470 + Yn = 1 + Zn = 1.088830 + T0 = 0.137931034 + T1 = 0.206896552 + T2 = 0.12841855 + T3 = 0.008856452 + + def __init__(self, l=0.1, a=0.0, b=0.0): + self.lab = [l, a, b] + self.l = l + self.a = a + self.b = b + + @staticmethod + def zero(): + return Lab(0, 0, 0) + + @staticmethod + def distance(a, b): + x = b.l - a.l + y = b.a - a.a + z = b.b - a.b + return math.sqrt(x * x + y * y + z * z) + + @staticmethod + def darken(out, c, amount): + out.l = c.l - Lab.Kn * amount + out.a = c.a + out.b = c.b + return out + + @staticmethod + def lighten(out, c, amount): + return Lab.darken(out, c, -amount) + + @staticmethod + def darken_color(c, amount): + tmp_darken_lab = Lab.from_color(c) + return Lab.to_color(Lab.darken(tmp_darken_lab, tmp_darken_lab, amount)) + + @staticmethod + def lighten_color(c, amount): + return Lab.darken_color(c, -amount) + + @staticmethod + def from_color(color): + r, g, b, a = color * 255 + x, y, z = Lab.rgb_to_xyz(r, g, b) + l = 116 * y - 16 + return Lab(l if l >= 0 else 0, 500 * (x - y), 200 * (y - z)) + + @staticmethod + def to_color(lab): + y = (lab.l + 16) / 116 + x = y if math.isnan(lab.a) else y + lab.a / 500 + z = y if math.isnan(lab.b) else y - lab.b / 200 + + y = Lab.Yn * Lab.lab_xyz(y) + x = Lab.Xn * Lab.lab_xyz(x) + z = Lab.Zn * Lab.lab_xyz(z) + + r = Lab.xyz_rgb(3.2404542 * x - 1.5371385 * y - 0.4985314 * z) + g = Lab.xyz_rgb(-0.9692660 * x + 1.8760108 * y + 0.0415560 * z) + b = Lab.xyz_rgb(0.0556434 * x - 0.2040259 * y + 1.0572252 * z) + + return [ + round(clamp(r, 0, 255)) / 255.0, + round(clamp(g, 0, 255)) / 255.0, + round(clamp(b, 0, 255)) / 255.0, + 1.0 + ] + + @staticmethod + def xyz_rgb(c): + return 255 * (12.92 * c if c <= 0.00304 else 1.055 * math.pow(c, 1 / 2.4) - 0.055) + + @staticmethod + def lab_xyz(t): + return t * t * t if t > Lab.T1 else Lab.T2 * (t - Lab.T0) + + @staticmethod + def rgb_xyz(c): + c /= 255 + return c / 12.92 if c <= 0.04045 else math.pow((c + 0.055) / 1.055, 2.4) + + @staticmethod + def xyz_lab(t): + return math.pow(t, 1 / 3) if t > Lab.T3 else t / Lab.T2 + Lab.T0 + + @staticmethod + def rgb_to_xyz(r, g, b): + r = Lab.rgb_xyz(r) + g = Lab.rgb_xyz(g) + b = Lab.rgb_xyz(b) + x = Lab.xyz_lab((0.4124564 * r + 0.3575761 * g + 0.1804375 * b) / Lab.Xn) + y = Lab.xyz_lab((0.2126729 * r + 0.7151522 * g + 0.0721750 * b) / Lab.Yn) + z = Lab.xyz_lab((0.0193339 * r + 0.1191920 * g + 0.9503041 * b) / Lab.Zn) + return [x, y, z] + def random_rgb(seed=None): """Random Pastel RGB values""" if seed: diff --git a/molecularnodes/entities/ensemble/bcif.py b/molecularnodes/entities/ensemble/bcif.py deleted file mode 100644 index 5aa162ae..00000000 --- a/molecularnodes/entities/ensemble/bcif.py +++ /dev/null @@ -1,583 +0,0 @@ -import numpy as np -from
mathutils import Matrix -from typing import Any, Dict, List, Optional, TypedDict, Union -from biotite.structure import AtomArray - - -class BCIF: - def __init__(self, file_path): - # super().__init__() - self.file_path = file_path - self.file = self.read() - self.array = _atom_array_from_bcif(self.file) - self._transforms_data = _get_ops_from_bcif(self.file) - self.n_models = 1 - self.n_atoms = self.array.shape - self.chain_ids = self._chain_ids() - - def read(self): - # if isinstance(self.file_path, BytesIO): - # open_bcif = self.file_path.getvalue() - # else: - with open(self.file_path, "rb") as data: - open_bcif = loads(data.read()) - - return open_bcif - - def assemblies(self, as_array=True): - return self._transforms_data - - def _chain_ids(self, as_int=False): - if as_int: - return np.unique(self.array.chain_id, return_inverse=True)[1] - return np.unique(self.array.chain_id) - - -def _atom_array_from_bcif(open_bcif): - categories = open_bcif.data_blocks[0] - - # check if a petworld CellPack model or not - is_petworld = False - if "PDB_model_num" in categories["pdbx_struct_assembly_gen"].field_names: - print("PetWorld!") - is_petworld = True - - atom_site = categories["atom_site"] - n_atoms = atom_site.row_count - - # Initialise the atom array that will contain all of the data for the atoms - # in the bcif file. TODO support multi-model bcif files - # we first pull out the coordinates as they are from 3 different fields, but all - # other fields should be single self-contained fields - mol = AtomArray(n_atoms) - coord_field_names = [f"Cartn_{axis}" for axis in "xyz"] - mol.coord = np.hstack( - list( - [ - np.array(atom_site[column]).reshape((n_atoms, 1)) - for column in coord_field_names - ] - ) - ) - - # the list of current - atom_site_lookup = { - # have to make sure the chain_id ends up being the same as the space operatore - "label_asym_id": "chain_id", - "label_atom_id": "atom_name", - "label_comp_id": "res_name", - "type_symbol": "element", - "label_seq_id": "res_id", - "B_iso_or_equiv": "b_factor", - "label_entity_id": "entity_id", - "pdbx_PDB_model_num": "model_id", - "pdbx_formal_charge": "charge", - "occupancy": "occupany", - "id": "atom_id", - } - - if is_petworld: - # annotations[0][1] = 'pdbx_PDB_model_num' - atom_site_lookup.pop("label_asym_id") - atom_site_lookup["pdbx_PDB_model_num"] = "chain_id" - - for name in atom_site.field_names: - # the coordinates have already been extracted so we can skip over those field names - if name in coord_field_names: - continue - - # numpy does a pretty good job of guessing the data types from the fields - data = np.array(atom_site[name]) - - # if a specific name for an annotation is already specified earlier, we can - # use that to ensure consitency. 
All other fields are also still added as we - # may as well do so, in case we want any extra data - annotation_name = atom_site_lookup.get(name) - if not annotation_name: - annotation_name = name - - # TODO this could be expanded to capture fields that are entirely '' and drop them - # or fill them with 0s - if annotation_name == "res_id" and data[0] == "": - data = np.array([0 if x == "" else x for x in data]) - - mol.set_annotation(annotation_name, data) - - return mol - - -def rotation_from_matrix(matrix): - rotation_matrix = np.identity(4, dtype=float) - rotation_matrix[:3, :3] = matrix - translation, rotation, scale = Matrix(rotation_matrix).decompose() - return rotation - - -def _get_ops_from_bcif(open_bcif): - is_petworld = False - cats = open_bcif.data_blocks[0] - assembly_gen = cats["pdbx_struct_assembly_gen"] - gen_arr = np.column_stack( - list([assembly_gen[name] for name in assembly_gen.field_names]) - ) - dtype = [ - ("assembly_id", int), - ("chain_id", "U10"), - ("trans_id", int), - ("rotation", float, 4), # quaternion form rotations - ("translation", float, 3), - ] - ops = cats["pdbx_struct_oper_list"] - ok_names = [ - "matrix[1][1]", - "matrix[1][2]", - "matrix[1][3]", - "matrix[2][1]", - "matrix[2][2]", - "matrix[2][3]", - "matrix[3][1]", - "matrix[3][2]", - "matrix[3][3]", - "vector[1]", - "vector[2]", - "vector[3]", - ] - # test if petworld - if "PDB_model_num" in assembly_gen.field_names: - print("PetWorld!") - is_petworld = True - op_ids = np.array(ops["id"]) - struct_ops = np.column_stack( - list([np.array(ops[name]).reshape((ops.row_count, 1)) for name in ok_names]) - ) - rotations = np.array( - list([rotation_from_matrix(x[0:9].reshape((3, 3))) for x in struct_ops]) - ) - translations = struct_ops[:, 9:12] - - gen_list = [] - for i, gen in enumerate(gen_arr): - ids = [] - if "-" in gen[1]: - if "," in gen[1]: - for gexpr in gen[1].split(","): - if "-" in gexpr: - start, end = [int(x) for x in gexpr.strip("()").split("-")] - ids.extend((np.array(range(start, end + 1))).tolist()) - else: - ids.append(int(gexpr.strip("()"))) - else: - start, end = [int(x) for x in gen[1].strip("()").split("-")] - ids.extend((np.array(range(start, end + 1))).tolist()) - else: - ids = np.array([int(x) for x in gen[1].strip("()").split(",")]).tolist() - real_ids = np.nonzero(np.in1d(op_ids, ids))[0] - chains = np.array(gen[2].strip(" ").split(",")) - if is_petworld: - # all chain of the model receive theses transformation - chains = np.array([gen[3]]) - arr = np.zeros(chains.size * len(real_ids), dtype=dtype) - arr["chain_id"] = np.tile(chains, len(real_ids)) - mask = np.repeat(np.array(real_ids), len(chains)) - try: - arr["trans_id"] = gen[3] - except IndexError: - pass - arr["rotation"] = rotations[mask, :] - arr["translation"] = translations[mask, :] - gen_list.append(arr) - return np.concatenate(gen_list) - - -# This BinaryCIF implementation was taken from here: https://gist.github.com/dsehnal/b06f5555fa9145da69fe69abfeab6eaf - -# BinaryCIF Parser -# Copyright (c) 2021 David Sehnal , licensed under MIT. 
-# -# Resources: -# - https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1008247 -# - https://github.com/molstar/BinaryCIF & https://github.com/molstar/BinaryCIF/blob/master/encoding.md -# -# Implementation based on Mol*: -# - https://github.com/molstar/molstar/blob/master/src/mol-io/common/binary-cif/encoding.ts -# - https://github.com/molstar/molstar/blob/master/src/mol-io/common/binary-cif/decoder.ts -# - https://github.com/molstar/molstar/blob/master/src/mol-io/reader/cif/binary/parser.ts - - -class EncodingBase(TypedDict): - kind: str - - -class EncodedData(TypedDict): - encoding: List[EncodingBase] - data: bytes - - -class EncodedColumn(TypedDict): - name: str - data: EncodedData - mask: Optional[EncodedData] - - -class EncodedCategory(TypedDict): - name: str - rowCount: int - columns: List[EncodedColumn] - - -class EncodedDataBlock(TypedDict): - header: str - categories: List[EncodedCategory] - - -class EncodedFile(TypedDict): - version: str - encoder: str - dataBlocks: List[EncodedDataBlock] - - -def _decode(encoded_data: EncodedData) -> Union[np.ndarray, List[str]]: - result = encoded_data["data"] - for encoding in encoded_data["encoding"][::-1]: - if encoding["kind"] in _decoders: - result = _decoders[encoding["kind"]](result, encoding) # type: ignore - else: - raise ValueError(f"Unsupported encoding '{encoding['kind']}'") - - return result # type: ignore - - -class DataTypes: - Int8 = 1 - Int16 = 2 - Int32 = 3 - Uint8 = 4 - Uint16 = 5 - Uint32 = 6 - Float32 = 32 - Float64 = 33 - - -_dtypes = { - DataTypes.Int8: "i1", - DataTypes.Int16: "i2", - DataTypes.Int32: "i4", - DataTypes.Uint8: "u1", - DataTypes.Uint16: "u2", - DataTypes.Uint32: "u4", - DataTypes.Float32: "f4", - DataTypes.Float64: "f8", -} - - -def _get_dtype(type: int) -> str: - if type in _dtypes: - return _dtypes[type] - - raise ValueError(f"Unsupported data type '{type}'") - - -class ByteArrayEncoding(EncodingBase): - type: int - - -class FixedPointEncoding(EncodingBase): - factor: float - srcType: int - - -class IntervalQuantizationEncoding(EncodingBase): - min: float - max: float - numSteps: int - srcType: int - - -class RunLengthEncoding(EncodingBase): - srcType: int - srcSize: int - - -class DeltaEncoding(EncodingBase): - origin: int - srcType: int - - -class IntegerPackingEncoding(EncodingBase): - byteCount: int - isUnsigned: bool - srcSize: int - - -class StringArrayEncoding(EncodingBase): - dataEncoding: List[EncodingBase] - stringData: str - offsetEncoding: List[EncodingBase] - offsets: bytes - - -def _decode_byte_array(data: bytes, encoding: ByteArrayEncoding) -> np.ndarray: - return np.frombuffer(data, dtype="<" + _get_dtype(encoding["type"])) - - -def _decode_fixed_point(data: np.ndarray, encoding: FixedPointEncoding) -> np.ndarray: - return np.array(data, dtype=_get_dtype(encoding["srcType"])) / encoding["factor"] - - -def _decode_interval_quantization( - data: np.ndarray, encoding: IntervalQuantizationEncoding -) -> np.ndarray: - delta = (encoding["max"] - encoding["min"]) / (encoding["numSteps"] - 1) - return ( - np.array(data, dtype=_get_dtype(encoding["srcType"])) * delta + encoding["min"] - ) - - -def _decode_run_length(data: np.ndarray, encoding: RunLengthEncoding) -> np.ndarray: - return np.repeat( - np.array(data[::2], dtype=_get_dtype(encoding["srcType"])), repeats=data[1::2] - ) - - -def _decode_delta(data: np.ndarray, encoding: DeltaEncoding) -> np.ndarray: - result = np.array(data, dtype=_get_dtype(encoding["srcType"])) - if encoding["origin"]: - result[0] += 
encoding["origin"] - return np.cumsum(result, out=result) - - -def _decode_integer_packing_signed( - data: np.ndarray, encoding: IntegerPackingEncoding -) -> np.ndarray: - upper_limit = 0x7F if encoding["byteCount"] == 1 else 0x7FFF - lower_limit = -upper_limit - 1 - n = len(data) - output = np.zeros(encoding["srcSize"], dtype="i4") - i = 0 - j = 0 - while i < n: - value = 0 - t = data[i] - while t == upper_limit or t == lower_limit: - value += t - i += 1 - t = data[i] - value += t - output[j] = value - i += 1 - j += 1 - return output - - -def _decode_integer_packing_unsigned( - data: np.ndarray, encoding: IntegerPackingEncoding -) -> np.ndarray: - upper_limit = 0xFF if encoding["byteCount"] == 1 else 0xFFFF - n = len(data) - output = np.zeros(encoding["srcSize"], dtype="i4") - i = 0 - j = 0 - while i < n: - value = 0 - t = data[i] - while t == upper_limit: - value += t - i += 1 - t = data[i] - value += t - output[j] = value - i += 1 - j += 1 - return output - - -def _decode_integer_packing( - data: np.ndarray, encoding: IntegerPackingEncoding -) -> np.ndarray: - if len(data) == encoding["srcSize"]: - return data - if encoding["isUnsigned"]: - return _decode_integer_packing_unsigned(data, encoding) - else: - return _decode_integer_packing_signed(data, encoding) - - -def _decode_string_array(data: bytes, encoding: StringArrayEncoding) -> List[str]: - offsets = _decode( - EncodedData(encoding=encoding["offsetEncoding"], data=encoding["offsets"]) - ) - indices = _decode(EncodedData(encoding=encoding["dataEncoding"], data=data)) - - str = encoding["stringData"] - strings = [""] - for i in range(1, len(offsets)): - strings.append(str[offsets[i - 1] : offsets[i]]) # type: ignore - - return [strings[i + 1] for i in indices] # type: ignore - - -_decoders = { - "ByteArray": _decode_byte_array, - "FixedPoint": _decode_fixed_point, - "IntervalQuantization": _decode_interval_quantization, - "RunLength": _decode_run_length, - "Delta": _decode_delta, - "IntegerPacking": _decode_integer_packing, - "StringArray": _decode_string_array, -} - - -############################################################################### - - -class CifValueKind: - Present = 0 - # Expressed in CIF as `.` - NotPresent = 1 - # Expressed in CIF as `?` - Unknown = 2 - - -class CifField: - def __getitem__(self, idx: int) -> Union[str, float, int, None]: - # if self._value_kinds and self._value_kinds[idx]: - # return None - return self._values[idx] - - def __len__(self): - return self.row_count - - @property - def values(self): - """ - A numpy array of numbers or a list of strings. - """ - return self._values - - @property - def value_kinds(self): - """ - value_kinds represent the presence or absence of particular "CIF value". - - If the mask is not set, every value is present: - - 0 = Value is present - - 1 = . = value not specified - - 2 = ? 
= value unknown - """ - return self._value_kinds - - def __init__( - self, - name: str, - values: Union[np.ndarray, List[str]], - value_kinds: Optional[np.ndarray], - ): - self.name = name - self._values = values - self._value_kinds = value_kinds - self.row_count = len(values) - - -class CifCategory: - def __getattr__(self, name: str) -> Any: - return self[name] - - def __getitem__(self, name: str) -> Optional[CifField]: - if name not in self._field_cache: - return None - - if not self._field_cache[name]: - self._field_cache[name] = _decode_column(self._columns[name]) - - return self._field_cache[name] - - def __contains__(self, key: str): - return key in self._columns - - def __init__(self, category: EncodedCategory, lazy: bool): - self.field_names = [c["name"] for c in category["columns"]] - self._field_cache = { - c["name"]: None if lazy else _decode_column(c) for c in category["columns"] - } - self._columns: Dict[str, EncodedColumn] = { - c["name"]: c for c in category["columns"] - } - self.row_count = category["rowCount"] - self.name = category["name"][1:] - - -class CifDataBlock: - def __getattr__(self, name: str) -> Any: - return self.categories[name] - - def __getitem__(self, name: str) -> CifCategory: - return self.categories[name] - - def __contains__(self, key: str): - return key in self.categories - - def __init__(self, header: str, categories: Dict[str, CifCategory]): - self.header = header - self.categories = categories - - -class CifFile: - def __getitem__(self, index_or_name: Union[int, str]): - """ - Access a data block by index or header (case sensitive) - """ - if isinstance(index_or_name, str): - return ( - self._block_map[index_or_name] - if index_or_name in self._block_map - else None - ) - else: - return ( - self.data_blocks[index_or_name] - if index_or_name < len(self.data_blocks) - else None - ) - - def __len__(self): - return len(self.data_blocks) - - def __contains__(self, key: str): - return key in self._block_map - - def __init__(self, data_blocks: List[CifDataBlock]): - self.data_blocks = data_blocks - self._block_map = {b.header: b for b in data_blocks} - - -def _decode_column(column: EncodedColumn) -> CifField: - values = _decode(column["data"]) - value_kinds = _decode(column["mask"]) if column["mask"] else None # type: ignore - # type: ignore - return CifField(name=column["name"], values=values, value_kinds=value_kinds) - - -def loads(data: Union[bytes, EncodedFile], lazy=True) -> CifFile: - """ - - data: msgpack encoded blob or EncodedFile object - - lazy: - - True: individual columns are decoded only when accessed - - False: decode all columns immediately - """ - import msgpack - - file: EncodedFile = ( - data if isinstance(data, dict) and "dataBlocks" in data else msgpack.loads(data) - ) # type: ignore - - data_blocks = [ - CifDataBlock( - header=block["header"], - categories={ - cat["name"][1:]: CifCategory(category=cat, lazy=lazy) - for cat in block["categories"] - }, - ) - for block in file["dataBlocks"] - ] - - return CifFile(data_blocks=data_blocks) diff --git a/molecularnodes/entities/ensemble/cellpack.py b/molecularnodes/entities/ensemble/cellpack.py index 6a56afdb..277cea2d 100644 --- a/molecularnodes/entities/ensemble/cellpack.py +++ b/molecularnodes/entities/ensemble/cellpack.py @@ -1,24 +1,86 @@ +import os +import json from pathlib import Path +from biotite.structure import AtomArray import numpy as np import bpy from .ensemble import Ensemble -from .bcif import BCIF -from .cif import OldCIF +from .cif import CIF from ..molecule import 
molecule from ... import blender as bl from ... import color class CellPack(Ensemble): - def __init__(self, file_path): + def __init__(self, file_path, remove_space=False): super().__init__(file_path) - self.file_type = self._file_type() - self.data = self._read(self.file_path) - self.array = self.data.array + self.data = self._read(self.file_path, remove_space) self.transformations = self.data.assemblies(as_array=True) - self.chain_ids = self.data.chain_ids + self.chain_ids = self.array.asym_id + self.entity_ids = np.unique(self.array.entity_id) + self.entity_chains = {} + + # look up color_palette of entity_id + wpath = os.path.dirname(os.path.abspath(self.file_path)) + self.color_palette = os.path.join(wpath, "color_palette.json") + self.color_entity = {} + if os.path.exists(self.color_palette): + self.color_palette = json.load(open(self.color_palette, "r")) + for entity in np.unique(self.array.entity_id): + ename = self.data.entities[entity] + if ename in self.color_palette: + self.color_entity[entity] = np.array( + [ + self.color_palette[ename]["x"] / 255.0, + self.color_palette[ename]["y"] / 255.0, + self.color_palette[ename]["z"] / 255.0, + 1.0, + ] + ) + else: + self.color_entity[entity] = color.random_rgb(int(entity)) + for i, entity in enumerate(self.entity_ids): + symids = self.array.asym_id[self.array.entity_id == entity] + self.entity_chains[entity] = np.unique(symids) + + @property + def array(self): + return self.data.array + + # @property + # def chain_ids(self): + # return np.unique(self.array) + + def create_transparent_material(self, name="MN Transparent"): + # Create a new material + material_name = name + material = bpy.data.materials.new(name=material_name) + + # Enable 'Use Nodes' + material.use_nodes = True + + # Clear all default nodes + nodes = material.node_tree.nodes + nodes.clear() + + # Add a Material Output node + output_node = nodes.new(type="ShaderNodeOutputMaterial") + output_node.location = (300, 0) + + # Add a Transparent BSDF node + transparent_node = nodes.new(type="ShaderNodeBsdfTransparent") + transparent_node.location = (0, 0) + + # Connect the Transparent BSDF node to the Material Output node + material.node_tree.links.new( + transparent_node.outputs["BSDF"], output_node.inputs["Surface"] + ) + + # Optionally set the color of the transparent BSDF + transparent_node.inputs["Color"].default_value = (1, 1, 1, 1) # RGBA + return material def create_object( self, @@ -34,42 +96,39 @@ def create_object( return self.data_object - def _file_type(self): - return Path(self.file_path).suffix.strip(".") + @property + def file_type(self): + return self.file_path.suffix.strip(".") - def _read(self, file_path): + def _read(self, file_path, remove_space=False): "Read a Cellpack File" - suffix = Path(file_path).suffix - - if suffix in (".bin", ".bcif"): - data = BCIF(file_path) - elif suffix == ".cif": - data = OldCIF(file_path) - else: - raise ValueError(f"Invalid file format: '{suffix}") - + data = CIF(file_path, remove_space) return data def _create_object_instances( self, name: str = "CellPack", node_setup: bool = True ) -> bpy.types.Collection: collection = bl.coll.cellpack(name) - - if self.file_type == "cif": - array = self.array[0] - else: - array = self.array - for i, chain in enumerate(np.unique(array.chain_id)): - chain_atoms = array[array.chain_id == chain] - obj, coll_none = molecule._create_object( + for i, chain in enumerate(np.unique(self.array.asym_id)): + # print(f"Creating chain {chain}...") + chain_atoms = self.array[self.array.asym_id == chain] 
+ model, coll_none = molecule._create_object( array=chain_atoms, name=f"{str(i).rjust(4, '0')}_{chain}", collection=collection, ) - - colors = np.tile(color.random_rgb(i), (len(chain_atoms), 1)) + # random color per chain + # could also color by entity, + chain-lighten + atom-lighten + entity = chain_atoms.entity_id[0] + color_entity = self.color_entity[entity] + # color.random_rgb(int(entity)) + # lighten for each chain + nc = len(self.entity_chains[entity]) + ci = np.where(self.entity_chains[entity] == chain)[0][0] * 2 + color_chain = color.Lab.lighten_color(color_entity, (float(ci) / nc)) + colors = np.tile(color_chain, (len(chain_atoms), 1)) bl.mesh.store_named_attribute( - obj, + model, name="Color", data=colors, data_type="FLOAT_COLOR", @@ -78,7 +137,7 @@ def _create_object_instances( if node_setup: bl.nodes.create_starting_node_tree( - obj, name=f"MN_pack_instance_{name}", color=None + model, name=f"MN_pack_instance_{name}", color=None ) self.data_collection = collection @@ -102,9 +161,81 @@ def _setup_node_tree(self, name="CellPack", fraction=1.0, as_points=False): node_pack = bl.nodes.add_custom(group, "Ensemble Instance", location=[-100, 0]) node_pack.inputs["Instances"].default_value = self.data_collection - node_pack.inputs["Fraction"].default_value = fraction - node_pack.inputs["As Points"].default_value = as_points + # node_pack.inputs['Fraction'].default_value = fraction + # node_pack.inputs['As Points'].default_value = as_points + + # Create the GeometryNodeIsViewport node + node_is_viewport = group.nodes.new("GeometryNodeIsViewport") + node_is_viewport.location = (-490.0, -240.0) + + # Create the GeometryNodeSwitch node + node_switch = group.nodes.new("GeometryNodeSwitch") + node_switch.location = (-303.0, -102.0) + # Set the input type of the switch node to FLOAT + node_switch.input_type = "FLOAT" + # Set the true and false values of the switch node + node_switch.inputs[1].default_value = 1.0 + node_switch.inputs[2].default_value = 0.1 + + group.links.new(node_is_viewport.outputs[0], node_switch.inputs[0]) + + group.links.new(node_switch.outputs[0], node_pack.inputs["Fraction"]) + + group.links.new(node_is_viewport.outputs[0], node_pack.inputs["As Points"]) + + # create a plane primitive node + node_plane = group.nodes.new("GeometryNodeMeshGrid") + node_plane.location = (-1173, 252) + + # create a geometry transform node + node_transform = group.nodes.new("GeometryNodeTransform") + node_transform.location = (-947, 245) + # change mesh translation + node_transform.inputs[1].default_value = (3.0, 0.0, 0.0) + # change mesh rotation + node_transform.inputs[2].default_value = (0.0, 3.14 / 2.0, 0.0) + # change mesh scale + node_transform.inputs[3].default_value = (50.0, 50.0, 1.0) + # link the plane to the transform node + group.links.new(node_plane.outputs[0], node_transform.inputs[0]) + + # create transparent material and setMaterial node + material = self.create_transparent_material() + node_set_material = group.nodes.new("GeometryNodeSetMaterial") + node_set_material.location = (-100, 289) + group.links.new(node_transform.outputs[0], node_set_material.inputs[0]) + node_set_material.inputs[2].default_value = material + # create the join geometry node + node_join = group.nodes.new("GeometryNodeJoinGeometry") + node_join.location = (151, 122) + group.links.new(node_set_material.outputs[0], node_join.inputs[0]) + group.links.new(node_pack.outputs[0], node_join.inputs[0]) + + # create a geometry proximity node and link the plane to it + node_proximity =
group.nodes.new("GeometryNodeProximity") + node_proximity.location = (-586, 269) + group.links.new(node_transform.outputs[0], node_proximity.inputs[0]) + + # get the position attribute node + node_position = group.nodes.new("GeometryNodeInputPosition") + node_position.location = (-796, 86) + + # link it to the posistion sample in proximity + group.links.new(node_position.outputs[0], node_proximity.inputs[2]) + + # create a compare node that take the distance from the proximity node + # and compare it to be greter than 2.0 + node_compare = group.nodes.new("FunctionNodeCompare") + node_compare.location = (-354, 316) + node_compare.data_type = "FLOAT" + node_compare.operation = "GREATER_THAN" + node_compare.inputs[1].default_value = 2.0 + # do the link + group.links.new(node_proximity.outputs[1], node_compare.inputs[0]) + + # link the outpot of the compare node to the selection node_pack + group.links.new(node_compare.outputs[0], node_pack.inputs["Selection"]) link = group.links.new link(bl.nodes.get_input(group).outputs[0], node_pack.inputs[0]) - link(node_pack.outputs[0], bl.nodes.get_output(group).inputs[0]) + link(node_join.outputs[0], bl.nodes.get_output(group).inputs[0]) diff --git a/molecularnodes/entities/ensemble/cif.py b/molecularnodes/entities/ensemble/cif.py index 06be969d..6f3042ca 100644 --- a/molecularnodes/entities/ensemble/cif.py +++ b/molecularnodes/entities/ensemble/cif.py @@ -1,326 +1,269 @@ -import itertools -import warnings - -import biotite.structure as struc -import biotite.structure.io.pdbx as pdbx +from pathlib import Path import numpy as np -from biotite import InvalidFileError - -from ..molecule.assembly import AssemblyParser -from ..molecule.molecule import Molecule - - -class OldCIF(Molecule): - def __init__(self, file_path, extra_fields=None, sec_struct=True): - super().__init__(file_path=file_path) - self.array = self._get_structure( - extra_fields=extra_fields, sec_struct=sec_struct - ) - self.n_atoms = self.array.array_length() - - def _read(self, file_path): - return pdbx.legacy.PDBxFile.read(file_path) - - def _get_structure(self, extra_fields: str = None, sec_struct=True, bonds=True): - fields = ["b_factor", "charge", "occupancy", "atom_id"] - if extra_fields: - [fields.append(x) for x in extra_fields] - - # if the 'atom_site' doesn't exist then it will just be a small molecule - # which can be extracted with the get_component() - try: - array = pdbx.get_structure(self.file, extra_fields=extra_fields) - try: - array.set_annotation( - "sec_struct", _get_secondary_structure(array, self.file) - ) - except KeyError: - warnings.warn("No secondary structure information.") - try: - array.set_annotation("entity_id", _get_entity_id(array, self.file)) - except KeyError: - warnings.warn("Non entity_id information.") - - except InvalidFileError: - array = pdbx.get_component(self.file) - - # pdbx files don't seem to have bond information defined, so connect them based - # on their residue names - if not array.bonds and bonds: - array.bonds = struc.bonds.connect_via_residue_names( - array, inter_residue=True - ) - - return array - - def _entity_ids(self): - entities = self.file["entity"] - if not entities: - return None - - return entities.get("pdbx_description", None) - - def _assemblies(self): - return CIFAssemblyParser(self.file).get_assemblies() +from mathutils import Matrix +from typing import Any, Dict, List, Optional, TypedDict, Union +from biotite.structure import AtomArray +import biotite.structure.io.pdbx as pdbx -def _ss_label_to_int(label): - if "HELX" in 
label: - return 1 - elif "STRN" in label: - return 2 +class CIF: + def __init__(self, file_path, remove_space=False): + # super().__init__() + self.file_path = file_path + self.file = self.read(remove_space) + self.entities = {} + categories = self.file.block + # check if a petworld CellPack model or not + self.is_petworld = False + if "PDB_model_num" in categories["pdbx_struct_assembly_gen"]: + self.is_petworld = True + entity = {} + entityids = [] + pdbx_description = [] + if self.is_petworld: + entity = categories['pdbx_model'] + entityids = [str(i+1) for i in range(len(entity['name']))] + pdbx_description = entity['name'].as_array() + else: + entity = categories['entity'] + entityids = entity['id'].as_array() + pdbx_description = entity['pdbx_description'].as_array() + for i in range(len(entityids)): + self.entities[entityids[i]] = pdbx_description[i] + + self.array = _atom_array_from_cif(categories) + self.lookup = _get_entity_chain_id(self.array, categories) + self._transforms_data = _get_ops_from_cif(categories, + lookup=self.lookup) + self.n_models = 1 + self.n_atoms = self.array.shape + if self.is_petworld: + self.array.asym_id = self.array.chain_id + # self.array.chain_id = self.array.asym_id + self.chain_ids = self._chain_ids() + + # Function to remove leading whitespaces line by line + def remove_leading_whitespace(self, file_path, output_file_path): + # Open the original file for reading and a new file for writing + with open(file_path, 'r') as infile, open(output_file_path, 'w') as outfile: + # Process the file line by line + for line in infile: + # Remove leading whitespace from each line and write it to the new file + outfile.write(line.lstrip()) + + def read(self, remove_space=False): + suffix = Path(self.file_path).suffix + print('reading file', self.file_path) + if suffix in (".bin", ".bcif"): + return pdbx.BinaryCIFFile.read(self.file_path) + elif suffix == ".cif": + if remove_space: + self.remove_leading_whitespace( + str(self.file_path), + str(self.file_path) + ".nw.cif") + self.file_path = str(self.file_path) + ".nw.cif" + return pdbx.CIFFile.read(self.file_path) + # with open(self.file_path, "rb") as data: + # open_bcif = loads(data.read()) + # + # return open_bcif + + def assemblies(self, as_array=True): + return self._transforms_data + + def _chain_ids(self, as_int=False): + if as_int: + return np.unique(self.array.chain_id, return_inverse=True)[1] + return np.unique(self.array.chain_id) + + +def _get_entity_chain_id(array, categories): + if "entity_poly" in categories: + chain_ids = categories["entity_poly"]["pdbx_strand_id"].as_array() + entity_ids = categories["entity_poly"]["entity_id"].as_array() + entity_lookup = dict(zip(chain_ids, entity_ids)) + return entity_lookup else: - return 3 - - -def _get_secondary_structure(array, file): - """ - Get secondary structure information for the array from the file. - - Parameters - ---------- - array : numpy array - The array for which secondary structure information is to be retrieved. - file : object - The file object containing the secondary structure information. - - Returns - ------- - numpy array - A numpy array of secondary structure information, where each element is either 0, 1, 2, or 3. - - 0: Not a peptide - - 1: Alpha helix - - 2: Beta sheet - - 3: Loop - - Raises - ------ - KeyError - If the 'struct_conf' category is not found in the file. - """ - - # get the annotations for the struc_conf cetegory. Provides start and end - # residues for the annotations. 
For most files this will only contain the - # alpha helices, but will sometimes contain also other secondary structure - # information such as in AlphaFold predictions - - conf = file.get_category("struct_conf") - if not conf: - raise KeyError - starts = conf["beg_auth_seq_id"].astype(int) - ends = conf["end_auth_seq_id"].astype(int) - chains = conf["end_auth_asym_id"].astype(str) - id_label = conf["id"].astype(str) - - # most files will have a separate category for the beta sheets - # this can just be appended to the other start / end / id and be processed - # as normal - sheet = file.get_category("struct_sheet_range") - if sheet: - starts = np.append(starts, sheet["beg_auth_seq_id"].astype(int)) - ends = np.append(ends, sheet["end_auth_seq_id"].astype(int)) - chains = np.append(chains, sheet["end_auth_asym_id"].astype(str)) - id_label = np.append(id_label, np.repeat("STRN", len(sheet["id"]))) - - # convert the string labels to integer representations of the SS - # AH: 1, BS: 2, LOOP: 3 - - id_int = np.array([_ss_label_to_int(label) for label in id_label], int) - - # create a lookup dictionary that enables lookup of secondary structure - # based on the chain_id and res_id values - - lookup = dict() - for chain in np.unique(chains): - arrays = [] - mask = chain == chains - start_sub = starts[mask] - end_sub = ends[mask] - id_sub = id_int[mask] - - for start, end, id in zip(start_sub, end_sub, id_sub): - idx = np.arange(start, end + 1, dtype=int) - arr = np.zeros((len(idx), 2), dtype=int) - arr[:, 0] = idx - arr[:, 1] = 3 - arr[:, 1] = id - arrays.append(arr) - - lookup[chain] = dict(np.vstack(arrays).tolist()) - - # use the lookup dictionary to get the SS annotation based on the chain_id and res_id - secondary_structure = np.zeros(len(array.chain_id), int) - for i, (chain, res) in enumerate(zip(array.chain_id, array.res_id)): - try: - secondary_structure[i] = lookup[chain].get(res, 3) - except KeyError: - secondary_structure[i] = 0 - - # assign SS to 0 where not peptide - secondary_structure[~struc.filter_amino_acids(array)] = 0 - return secondary_structure - - -def _get_entity_id(array, file): - entities = file.get_category("entity_poly") - if not entities: - raise KeyError - chain_ids = entities["pdbx_strand_id"] - - # the chain_ids are an array of individual items np.array(['A,B', 'C', 'D,E,F']) - # which need to be categorised as [1, 1, 2, 3, 3, 3] for their belonging to individual - # entities - - chains = [] - idx = [] - for i, chain_str in enumerate(chain_ids): - for chain in chain_str.split(","): - chains.append(chain) - idx.append(i) - - entity_lookup = dict(zip(chains, idx)) - chain_id_int = np.array( - [entity_lookup.get(chain, -1) for chain in array.chain_id], int - ) - return chain_id_int - - -class CIFAssemblyParser(AssemblyParser): - # Implementation adapted from ``biotite.structure.io.pdbx.convert`` - - def __init__(self, file_cif): - self._file = file_cif - - def list_assemblies(self): - return list(pdbx.list_assemblies(self._file).keys()) - - def get_transformations(self, assembly_id): - assembly_gen_category = self._file["pdbx_struct_assembly_gen"] - - struct_oper_category = self._file["pdbx_struct_oper_list"] - - if assembly_id not in assembly_gen_category["assembly_id"]: - raise KeyError(f"File has no Assembly ID '{assembly_id}'") - - # Extract all possible transformations indexed by operation ID - transformation_dict = _get_transformations(struct_oper_category) - - # Get necessary transformations and the affected chain IDs - # NOTE: The chains given here refer to the 
`label_asym_id` field - # of the `atom_site` category - # However, by default `PDBxFile` uses the `auth_asym_id` as - # chain ID - matrices = [] - for id, op_expr, asym_id_expr in zip( - assembly_gen_category["assembly_id"], - assembly_gen_category["oper_expression"], - assembly_gen_category["asym_id_list"], - ): - # Find the operation expressions for given assembly ID - # We already asserted that the ID is actually present - if id == assembly_id: - operations = _parse_operation_expression(op_expr) - affected_chain_ids = asym_id_expr.split(",") - for i, operation in enumerate(operations): - rotations = [] - translations = [] - for op_step in operation: - rotation, translation = transformation_dict[op_step] - rotations.append(rotation) - translations.append(translation) - matrix = _chain_transformations(rotations, translations) - matrices.append((affected_chain_ids, matrix.tolist())) - - return matrices - - def get_assemblies(self): - assembly_dict = {} - for assembly_id in self.list_assemblies(): - assembly_dict[assembly_id] = self.get_transformations(assembly_id) - - return assembly_dict - - -def _chain_transformations(rotations, translations): - """ - Get a total rotation/translation transformation by combining - multiple rotation/translation transformations. - This is done by intermediately combining rotation matrices and - translation vectors into 4x4 matrices in the form - - |r11 r12 r13 t1| - |r21 r22 r23 t2| - |r31 r32 r33 t3| - |0 0 0 1 |. - """ - total_matrix = np.identity(4) - for rotation, translation in zip(rotations, translations): - matrix = np.zeros((4, 4)) - matrix[:3, :3] = rotation - matrix[:3, 3] = translation - matrix[3, 3] = 1 - total_matrix = matrix @ total_matrix - - # return total_matrix[:3, :3], total_matrix[:3, 3] - return matrix - - -def _get_transformations(struct_oper): - """ - Get transformation operation in terms of rotation matrix and - translation for each operation ID in ``pdbx_struct_oper_list``. - """ - transformation_dict = {} - for index, id in enumerate(struct_oper["id"]): - rotation_matrix = np.array( + chain_ids = [] + entity_ids = [] + for i,chain in enumerate(array.asym_id): + if chain not in chain_ids or array.entity_id[i] not in entity_ids: + chain_ids.append(chain) + entity_ids.append(array.entity_id[i]) + entity_lookup = dict(zip(chain_ids, entity_ids)) + return entity_lookup + + +def _atom_array_from_cif(categories): + # check if a petworld CellPack model or not + is_petworld = False + if "PDB_model_num" in categories["pdbx_struct_assembly_gen"]: + is_petworld = True + + atom_site = categories["atom_site"] + n_atoms = atom_site.row_count + + # Initialise the atom array that will contain all of the data for the atoms + # in the bcif file. TODO support multi-model bcif files + # we first pull out the coordinates + # as they are from 3 different fields, but all + # other fields should be single self-contained fields + mol = AtomArray(n_atoms) + coord_field_names = [f"Cartn_{axis}" for axis in "xyz"] + mol.coord = np.hstack( + list( [ - [float(struct_oper[f"matrix[{i}][{j}]"][index]) for j in (1, 2, 3)] - for i in (1, 2, 3) + atom_site[column].as_array().reshape((n_atoms, 1)) + for column in coord_field_names ] ) - translation_vector = np.array( - [float(struct_oper[f"vector[{i}]"][index]) for i in (1, 2, 3)] - ) - transformation_dict[id] = (rotation_matrix, translation_vector) - return transformation_dict - - -def _parse_operation_expression(expression): - """ - Get successive operation steps (IDs) for the given - ``oper_expression``. 
- Form the cartesian product, if necessary. - """ - # Split groups by parentheses: - # use the opening parenthesis as delimiter - # and just remove the closing parenthesis - expressions_per_step = expression.replace(")", "").split("(") - expressions_per_step = [e for e in expressions_per_step if len(e) > 0] - # Important: Operations are applied from right to left - expressions_per_step.reverse() + ) - operations = [] - for expr in expressions_per_step: - if "-" in expr: - if "," in expr: - for gexpr in expr.split(","): + # the mapping of atom_site field names to annotation names + atom_site_lookup = { + # have to make sure the chain_id + # ends up being the same as the space operator + "label_asym_id": "chain_id", + "label_atom_id": "atom_name", + "label_comp_id": "res_name", + "type_symbol": "element", + "label_seq_id": "res_id", + "B_iso_or_equiv": "b_factor", + "label_entity_id": "entity_id", + "pdbx_PDB_model_num": "model_id", + "pdbx_formal_charge": "charge", + "occupancy": "occupancy", + "id": "atom_id", + } + + if is_petworld: + atom_site_lookup.pop("label_asym_id") + atom_site_lookup["pdbx_PDB_model_num"] = "chain_id" + atom_site_lookup.pop("label_entity_id") + + # for name in atom_site.field_names: + for name, column in atom_site.items(): + # the coordinates have already been extracted + # so we can skip over those field names + if name in coord_field_names: + continue + # numpy does a pretty good job of guessing + # the data types from the fields + data = atom_site[name].as_array() + if name == "label_asym_id": + # print("set annotation ", name) + # print(data) + mol.asym_id = data + # if a specific name for an annotation is + # already specified earlier, we can + # use that to ensure consistency. All other + # fields are also still added as we + # may as well do so, in case we want any extra data + annotation_name = atom_site_lookup.get(name) + if not annotation_name: + annotation_name = name + # TODO this could be expanded to capture + # fields that are entirely '' and drop them + # or fill them with 0s + if annotation_name == "res_id" and (data[0] == "" or data[0] == "."): + data = np.array([0 if (x == "" or x == ".") else x for x in data]) + mol.set_annotation(annotation_name, data) + if name == "pdbx_PDB_model_num" and is_petworld: + mol.set_annotation('entity_id', data) + return mol + + +def rotation_from_matrix(matrix): + rotation_matrix = np.identity(4, dtype=float) + rotation_matrix[:3, :3] = matrix + translation, rotation, scale = Matrix(rotation_matrix).decompose() + return rotation + + +def _get_ops_from_cif(categories, lookup=None): + is_petworld = False + assembly_gen = categories["pdbx_struct_assembly_gen"] + gen_arr = np.column_stack( + list([assembly_gen[name].as_array() for name in assembly_gen]) + ) + dtype = [ + ("assembly_id", int), + ("chain_id", "U10"), + ("transform_id", int), + ("rotation", float, 4), # quaternion form rotations + ("translation", float, 3), + ] + ops = categories["pdbx_struct_oper_list"] + ok_names = [ + "matrix[1][1]", + "matrix[1][2]", + "matrix[1][3]", + "matrix[2][1]", + "matrix[2][2]", + "matrix[2][3]", + "matrix[3][1]", + "matrix[3][2]", + "matrix[3][3]", + "vector[1]", + "vector[2]", + "vector[3]", + ] + # test if petworld + if "PDB_model_num" in assembly_gen: + is_petworld = True + # operator ID can be a string + op_ids = ops["id"].as_array(str) + struct_ops = np.column_stack( + list([ops[name].as_array().reshape((ops.row_count, 1)) + for name in ok_names]) + ) + rotations = np.array( + list([rotation_from_matrix(x[0:9].reshape((3, 3))) + for x in struct_ops]) + ) +
translations = struct_ops[:, 9:12] + + gen_list = [] + for i, gen in enumerate(gen_arr): + ids = [] + if "-" in gen[1]: + if "," in gen[1]: + for gexpr in gen[1].split(","): if "-" in gexpr: - first, last = gexpr.split("-") - operations.append( - [str(id) for id in range(int(first), int(last) + 1)] + start, end = [int(x) + for x in gexpr.strip("()").split("-")] + ids.extend((np.array(range(start, end + 1))).tolist()) else: - operations.append([gexpr]) + ids.append(int(gexpr.strip("()"))) else: - # Range of operation IDs, they must be integers - first, last = expr.split("-") - operations.append([str(id) for id in range(int(first), int(last) + 1)]) - elif "," in expr: - # List of operation IDs - operations.append(expr.split(",")) + start, end = [int(x) for x in gen[1].strip("()").split("-")] + ids.extend((np.array(range(start, end + 1))).tolist()) else: - # Single operation ID - operations.append([expr]) - - # Cartesian product of operations - return list(itertools.product(*operations)) + ids = np.array([int(x) + for x in gen[1].strip("()").split(",")]).tolist() + real_ids = np.nonzero(np.in1d(op_ids, [str(num) for num in ids]))[0] + chains = np.array(gen[2].strip(" ").split(",")) + if is_petworld: + # all chains of the model receive these transformations + chains = np.array([gen[3]]) + arr = np.zeros(chains.size * len(real_ids), dtype=dtype) + arr["chain_id"] = np.tile(chains, len(real_ids)) + mask = np.repeat(np.array(real_ids), len(chains)) + if len(mask) == 0: + print("no mask; chains are ", chains, real_ids, mask) + try: + arr["assembly_id"] = gen[0] + except IndexError: + pass + if is_petworld: + arr["transform_id"] = gen[3] + else: + if lookup: + arr["transform_id"] = np.array( + [lookup[chain] for chain in arr["chain_id"]]) + else: + arr["transform_id"] = mask + arr["rotation"] = rotations[mask, :] + arr["translation"] = translations[mask, :] + gen_list.append(arr) + return np.concatenate(gen_list) diff --git a/molecularnodes/entities/ensemble/oldcif.py b/molecularnodes/entities/ensemble/oldcif.py new file mode 100644 index 00000000..c6c6d47b --- /dev/null +++ b/molecularnodes/entities/ensemble/oldcif.py @@ -0,0 +1,330 @@ +import itertools +import warnings + +import biotite.structure as struc +import biotite.structure.io.pdbx as pdbx +import numpy as np +from biotite import InvalidFileError + +from ..molecule.assembly import AssemblyParser +from ..molecule.molecule import Molecule + +# import biotite.structure.io.pdbx as pdbx +# file_path = "D:\\Data\\machineryoflife\\cellpack_atom_instancesApr2024.cif" +# file_path = "D:\\Data\\sarscov2\\SaiLi_atom_instances.cif" +# file = pdbx.legacy.PDBxFile.read(file_path) +# array = pdbx.get_structure(file) +class OldCIF(Molecule): + def __init__(self, file_path, extra_fields=None, sec_struct=True): + super().__init__(file_path=file_path) + self.array = self._get_structure( + extra_fields=extra_fields, sec_struct=sec_struct + ) + self.n_atoms = self.array.array_length() + + def _read(self, file_path): + return pdbx.legacy.PDBxFile.read(file_path) + + def _get_structure(self, extra_fields: str = None, sec_struct=True, bonds=True): + fields = ["b_factor", "charge", "occupancy", "atom_id"] + if extra_fields: + [fields.append(x) for x in extra_fields] + + # if the 'atom_site' doesn't exist then it will just be a small molecule + # which can be extracted with the get_component() + try: + array = pdbx.get_structure(self.file, extra_fields=extra_fields) + try: + array.set_annotation( + "sec_struct", _get_secondary_structure(array, self.file) + ) +
except KeyError: + warnings.warn("No secondary structure information.") + try: + array.set_annotation("entity_id", _get_entity_id(array, self.file)) + except KeyError: + warnings.warn("No entity_id information.") + + except InvalidFileError: + array = pdbx.get_component(self.file) + + # pdbx files don't seem to have bond information defined, so connect them based + # on their residue names + if not array.bonds and bonds: + array.bonds = struc.bonds.connect_via_residue_names( + array, inter_residue=True + ) + + return array + + def _entity_ids(self): + entities = self.file["entity"] + if not entities: + return None + + return entities.get("pdbx_description", None) + + def _assemblies(self): + return CIFAssemblyParser(self.file).get_assemblies() + + +def _ss_label_to_int(label): + if "HELX" in label: + return 1 + elif "STRN" in label: + return 2 + else: + return 3 + + +def _get_secondary_structure(array, file): + """ + Get secondary structure information for the array from the file. + + Parameters + ---------- + array : numpy array + The array for which secondary structure information is to be retrieved. + file : object + The file object containing the secondary structure information. + + Returns + ------- + numpy array + A numpy array of secondary structure information, where each element is either 0, 1, 2, or 3. + - 0: Not a peptide + - 1: Alpha helix + - 2: Beta sheet + - 3: Loop + + Raises + ------ + KeyError + If the 'struct_conf' category is not found in the file. + """ + + # get the annotations for the struct_conf category. Provides start and end + # residues for the annotations. For most files this will only contain the + # alpha helices, but will sometimes also contain other secondary structure + # information such as in AlphaFold predictions + + conf = file.get_category("struct_conf") + if not conf: + raise KeyError + starts = conf["beg_auth_seq_id"].astype(int) + ends = conf["end_auth_seq_id"].astype(int) + chains = conf["end_auth_asym_id"].astype(str) + id_label = conf["id"].astype(str) + + # most files will have a separate category for the beta sheets + # this can just be appended to the other start / end / id and be processed + # as normal + sheet = file.get_category("struct_sheet_range") + if sheet: + starts = np.append(starts, sheet["beg_auth_seq_id"].astype(int)) + ends = np.append(ends, sheet["end_auth_seq_id"].astype(int)) + chains = np.append(chains, sheet["end_auth_asym_id"].astype(str)) + id_label = np.append(id_label, np.repeat("STRN", len(sheet["id"]))) + + # convert the string labels to integer representations of the SS + # AH: 1, BS: 2, LOOP: 3 + + id_int = np.array([_ss_label_to_int(label) for label in id_label], int) + + # create a lookup dictionary that enables lookup of secondary structure + # based on the chain_id and res_id values + + lookup = dict() + for chain in np.unique(chains): + arrays = [] + mask = chain == chains + start_sub = starts[mask] + end_sub = ends[mask] + id_sub = id_int[mask] + + for start, end, id in zip(start_sub, end_sub, id_sub): + idx = np.arange(start, end + 1, dtype=int) + arr = np.zeros((len(idx), 2), dtype=int) + arr[:, 0] = idx + arr[:, 1] = 3 + arr[:, 1] = id + arrays.append(arr) + + lookup[chain] = dict(np.vstack(arrays).tolist()) + + # use the lookup dictionary to get the SS annotation based on the chain_id and res_id + secondary_structure = np.zeros(len(array.chain_id), int) + for i, (chain, res) in enumerate(zip(array.chain_id, array.res_id)): + try: + secondary_structure[i] = lookup[chain].get(res, 3) + except KeyError: +
secondary_structure[i] = 0 + + # assign SS to 0 where not peptide + secondary_structure[~struc.filter_amino_acids(array)] = 0 + return secondary_structure + + +def _get_entity_id(array, file): + entities = file.get_category("entity_poly") + if not entities: + raise KeyError + chain_ids = entities["pdbx_strand_id"] + + # the chain_ids are an array of individual items np.array(['A,B', 'C', 'D,E,F']) + # which need to be categorised as [1, 1, 2, 3, 3, 3] for their belonging to individual + # entities + + chains = [] + idx = [] + for i, chain_str in enumerate(chain_ids): + for chain in chain_str.split(","): + chains.append(chain) + idx.append(i) + + entity_lookup = dict(zip(chains, idx)) + chain_id_int = np.array( + [entity_lookup.get(chain, -1) for chain in array.chain_id], int + ) + return chain_id_int + + +class CIFAssemblyParser(AssemblyParser): + # Implementation adapted from ``biotite.structure.io.pdbx.convert`` + + def __init__(self, file_cif): + self._file = file_cif + + def list_assemblies(self): + return list(pdbx.list_assemblies(self._file).keys()) + + def get_transformations(self, assembly_id): + assembly_gen_category = self._file["pdbx_struct_assembly_gen"] + + struct_oper_category = self._file["pdbx_struct_oper_list"] + + if assembly_id not in assembly_gen_category["assembly_id"]: + raise KeyError(f"File has no Assembly ID '{assembly_id}'") + + # Extract all possible transformations indexed by operation ID + transformation_dict = _get_transformations(struct_oper_category) + + # Get necessary transformations and the affected chain IDs + # NOTE: The chains given here refer to the `label_asym_id` field + # of the `atom_site` category + # However, by default `PDBxFile` uses the `auth_asym_id` as + # chain ID + matrices = [] + for id, op_expr, asym_id_expr in zip( + assembly_gen_category["assembly_id"], + assembly_gen_category["oper_expression"], + assembly_gen_category["asym_id_list"], + ): + # Find the operation expressions for given assembly ID + # We already asserted that the ID is actually present + if id == assembly_id: + operations = _parse_operation_expression(op_expr) + affected_chain_ids = asym_id_expr.split(",") + for i, operation in enumerate(operations): + rotations = [] + translations = [] + for op_step in operation: + rotation, translation = transformation_dict[op_step] + rotations.append(rotation) + translations.append(translation) + matrix = _chain_transformations(rotations, translations) + matrices.append((affected_chain_ids, matrix.tolist())) + + return matrices + + def get_assemblies(self): + assembly_dict = {} + for assembly_id in self.list_assemblies(): + assembly_dict[assembly_id] = self.get_transformations(assembly_id) + + return assembly_dict + + +def _chain_transformations(rotations, translations): + """ + Get a total rotation/translation transformation by combining + multiple rotation/translation transformations. + This is done by intermediately combining rotation matrices and + translation vectors into 4x4 matrices in the form + + |r11 r12 r13 t1| + |r21 r22 r23 t2| + |r31 r32 r33 t3| + |0 0 0 1 |. 
+ """ + total_matrix = np.identity(4) + for rotation, translation in zip(rotations, translations): + matrix = np.zeros((4, 4)) + matrix[:3, :3] = rotation + matrix[:3, 3] = translation + matrix[3, 3] = 1 + total_matrix = matrix @ total_matrix + + # return total_matrix[:3, :3], total_matrix[:3, 3] + return matrix + + +def _get_transformations(struct_oper): + """ + Get transformation operation in terms of rotation matrix and + translation for each operation ID in ``pdbx_struct_oper_list``. + """ + transformation_dict = {} + for index, id in enumerate(struct_oper["id"]): + rotation_matrix = np.array( + [ + [float(struct_oper[f"matrix[{i}][{j}]"][index]) for j in (1, 2, 3)] + for i in (1, 2, 3) + ] + ) + translation_vector = np.array( + [float(struct_oper[f"vector[{i}]"][index]) for i in (1, 2, 3)] + ) + transformation_dict[id] = (rotation_matrix, translation_vector) + return transformation_dict + + +def _parse_operation_expression(expression): + """ + Get successive operation steps (IDs) for the given + ``oper_expression``. + Form the cartesian product, if necessary. + """ + # Split groups by parentheses: + # use the opening parenthesis as delimiter + # and just remove the closing parenthesis + expressions_per_step = expression.replace(")", "").split("(") + expressions_per_step = [e for e in expressions_per_step if len(e) > 0] + # Important: Operations are applied from right to left + expressions_per_step.reverse() + + operations = [] + for expr in expressions_per_step: + if "-" in expr: + if "," in expr: + for gexpr in expr.split(","): + if "-" in gexpr: + first, last = gexpr.split("-") + operations.append( + [str(id) for id in range(int(first), int(last) + 1)] + ) + else: + operations.append([gexpr]) + else: + # Range of operation IDs, they must be integers + first, last = expr.split("-") + operations.append([str(id) for id in range(int(first), int(last) + 1)]) + elif "," in expr: + # List of operation IDs + operations.append(expr.split(",")) + else: + # Single operation ID + operations.append([expr]) + + # Cartesian product of operations + return list(itertools.product(*operations)) diff --git a/molecularnodes/entities/ensemble/ui.py b/molecularnodes/entities/ensemble/ui.py index 8464dd48..1e104cf3 100644 --- a/molecularnodes/entities/ensemble/ui.py +++ b/molecularnodes/entities/ensemble/ui.py @@ -63,24 +63,35 @@ def panel_starfile(layout, scene): subtype="FILE_PATH", maxlen=0, ) + bpy.types.Scene.mol_import_cell_pack_name = bpy.props.StringProperty( name="Name", description="Name of the created object.", - default="NewCellPackModel", + default="NewMesoscaleModel", maxlen=0, ) +bpy.types.Scene.mol_import_cell_pack_remove_space = bpy.props.BoolProperty( + name="Remove Space", + description="Remove spaces from cif file lines.", + default=False, +) + def load_cellpack( file_path, name="NewCellPackModel", + remove_space=False, node_setup=True, world_scale=0.01, - fraction: float = 1, + fraction: float = 1 ): - ensemble = CellPack(file_path) + ensemble = CellPack(file_path, remove_space=remove_space) model = ensemble.create_object( - name=name, node_setup=node_setup, world_scale=world_scale, fraction=fraction + name=name, + node_setup=node_setup, + world_scale=world_scale, + fraction=fraction ) return model @@ -97,7 +108,8 @@ def execute(self, context): load_cellpack( file_path=s.mol_import_cell_pack_path, name=s.mol_import_cell_pack_name, - node_setup=True, + remove_space=s.mol_import_cell_pack_remove_space, + node_setup=True ) return {"FINISHED"} @@ -108,4 +120,5 @@ def panel_cellpack(layout, 
scene): row_import = layout.row() row_import.prop(scene, "mol_import_cell_pack_name") layout.prop(scene, "mol_import_cell_pack_path") + layout.prop(scene, "mol_import_cell_pack_remove_space") row_import.operator("mol.import_cell_pack") diff --git a/molecularnodes/entities/molecule/ui.py b/molecularnodes/entities/molecule/ui.py index f574d7df..fb622496 100644 --- a/molecularnodes/entities/molecule/ui.py +++ b/molecularnodes/entities/molecule/ui.py @@ -7,7 +7,7 @@ from ...download import FileDownloadPDBError, download, CACHE_DIR from ...blender import path_resolve -from ..ensemble.cif import OldCIF +from ..ensemble.oldcif import OldCIF from .molecule import Molecule from .pdb import PDB from .pdbx import BCIF, CIF diff --git a/tests/test_assembly.py b/tests/test_assembly.py index 4705a924..176afe9d 100644 --- a/tests/test_assembly.py +++ b/tests/test_assembly.py @@ -6,7 +6,7 @@ import biotite.structure.io.pdbx as biotite_cif import molecularnodes.entities.molecule.pdb as pdb import molecularnodes.entities.ensemble.cif as cif - +import molecularnodes.entities.ensemble.oldcif as oldcif DATA_DIR = join(dirname(realpath(__file__)), "data") @@ -34,7 +34,7 @@ def test_get_transformations(pdb_id, format): use_author_fields=False, ) ref_assembly = biotite_cif.get_assembly(cif_file, model=1) - test_parser = cif.CIFAssemblyParser(cif_file) + test_parser = oldcif.CIFAssemblyParser(cif_file) else: raise ValueError(f"Format '{format}' does not exist") @@ -62,7 +62,7 @@ def test_get_transformations_cif(assembly_id): ) ref_assembly = biotite_cif.get_assembly(cif_file, model=1, assembly_id=assembly_id) - test_parser = cif.CIFAssemblyParser(cif_file) + test_parser = oldcif.CIFAssemblyParser(cif_file) test_transformations = test_parser.get_transformations(assembly_id) check_transformations(test_transformations, atoms, ref_assembly)
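Reviewer note, not part of the patch: a minimal usage sketch of the reworked import path, based only on the signatures introduced above (`load_cellpack` with the new `remove_space` flag, and the `color.Lab` helpers used for per-chain tinting). The model path is a hypothetical placeholder, and the snippet assumes it runs inside Blender's Python with Molecular Nodes installed.

```python
# Minimal sketch (assumptions: Blender with molecularnodes available;
# the file path below is a placeholder, not a file shipped in this PR).
from molecularnodes.entities.ensemble.ui import load_cellpack
from molecularnodes import color

# Import a CellPack/PetWorld ensemble. remove_space=True pre-processes a
# plain-text .cif whose lines carry leading whitespace by writing a
# stripped copy ("<file>.nw.cif") before parsing; it has no effect on .bcif.
model = load_cellpack(
    file_path="path/to/model.bcif",  # placeholder path
    name="NewMesoscaleModel",
    remove_space=False,
    node_setup=True,
)

# Per-chain tinting as done in CellPack._create_object_instances: take an
# entity's base RGBA colour and lighten it in CIELAB space, with a larger
# amount for each successive chain of that entity.
base = color.random_rgb(seed=6)            # RGBA floats in [0, 1]
tint = color.Lab.lighten_color(base, 0.5)  # lightened RGBA list in [0, 1]
```

If I read the IsViewport/Switch wiring in `_setup_node_tree` correctly, the instancing node now receives Fraction 0.1 and As Points enabled while working in the viewport, and the full instance set at render time, so the hard-coded `fraction` argument above is effectively superseded by that switch.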