Skip to content

Commit

Permalink
Merge pull request #280 from BradyAJohnston/dev-att-entity
Browse files Browse the repository at this point in the history
Add `entity_id` attribute
  • Loading branch information
BradyAJohnston authored Aug 3, 2023
2 parents 6b6ec46 + 7b850a0 commit 3ed5b52
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 30 deletions.
2 changes: 1 addition & 1 deletion MolecularNodes/density.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def map_to_vdb(file: str, invert: bool = False, world_scale=0.01, overwrite=Fals

# Rotate and scale the grid for import into Blender
grid.transform.rotate(np.pi / 2, vdb.Axis(1))
grid.transform.scale(np.array((-1, 1, 1)) * world_scale * voxel_size)

# Write the grid to a .vdb file
vdb.write(file_path, grid)
Expand Down
44 changes: 42 additions & 2 deletions MolecularNodes/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,38 @@ def molecule_local(

return mol_object

def get_chain_entity_id(file):
entities = file['entityList']
chain_names = file['chainIdList']
n_chains = len(chain_names)

arr_entity = np.zeros(n_chains, dtype = int)

counter = 0
for i, entity in enumerate(entities):
chain_idxs = entity['chainIndexList']

mask = np.array(range(len(chain_idxs))) + counter

arr_entity[mask] = i
# arr_entity[mask, 1] = chain_idxs
counter += len(chain_idxs)

return arr_entity

def set_atom_entity_id(mol, file):
mol.add_annotation('entity_id', int)
chain_names = file['chainNameList']
chain_entity_id = get_chain_entity_id(file)

chain_ids = np.array(list(map(
lambda x: np.where(x == chain_names)[0][0],
mol.chain_id
)))

entity_ids = chain_entity_id[chain_ids]
mol.set_annotation('entity_id', entity_ids)
return entity_ids

def open_structure_rcsb(pdb_code, cache_dir = None, include_bonds = True):
import biotite.structure.io.mmtf as mmtf
Expand All @@ -198,10 +230,9 @@ def open_structure_rcsb(pdb_code, cache_dir = None, include_bonds = True):
# returns a numpy array stack, where each array in the stack is a model in the
# the file. The stack will be of length = 1 if there is only one model in the file
mol = mmtf.get_structure(file, extra_fields = ["b_factor", "charge"], include_bonds = include_bonds)
set_atom_entity_id(mol, file)
return mol, file



def open_structure_local_pdb(file_path, include_bonds = True):
import biotite.structure.io.pdb as pdb

Expand Down Expand Up @@ -435,6 +466,9 @@ def att_chain_id():
chain_id = np.searchsorted(np.unique(mol_array.chain_id), mol_array.chain_id)
return chain_id

def att_entity_id():
return mol_array.entity_id

def att_b_factor():
return mol_array.b_factor

Expand Down Expand Up @@ -545,6 +579,7 @@ def att_sec_struct():
{'name': 'b_factor', 'value': att_b_factor, 'type': 'FLOAT', 'domain': 'POINT'},
{'name': 'vdw_radii', 'value': att_vdw_radii, 'type': 'FLOAT', 'domain': 'POINT'},
{'name': 'chain_id', 'value': att_chain_id, 'type': 'INT', 'domain': 'POINT'},
{'name': 'entity_id', 'value': att_entity_id, 'type': 'INT', 'domain': 'POINT'},
{'name': 'atom_name', 'value': att_atom_name, 'type': 'INT', 'domain': 'POINT'},
{'name': 'lipophobicity', 'value': att_lipophobicity, 'type': 'FLOAT', 'domain': 'POINT'},
{'name': 'charge', 'value': att_charge, 'type': 'FLOAT', 'domain': 'POINT'},
Expand Down Expand Up @@ -598,4 +633,9 @@ def att_sec_struct():
except:
warnings.warn('No chain information detected.')

try:
mol_object['entity_names'] = [ent['description'] for ent in file['entityList']]
except:
pass

return mol_object, coll_frames
6 changes: 3 additions & 3 deletions MolecularNodes/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ def rotation_matrix(node_group, mat, location = [0,0], world_scale = 0.01):

return node

def chain_selection(node_name, input_list, attribute, starting_value = 0, label_prefix = ""):
def chain_selection(node_name, input_list, attribute = 'chain_id', starting_value = 0, label_prefix = ""):
"""
Given a an input_list, will create a node which takes an Integer input,
and has a boolean tick box for each item in the input list. The outputs will
Expand Down Expand Up @@ -667,7 +667,7 @@ def chain_selection(node_name, input_list, attribute, starting_value = 0, label_
# these are custom properties that are associated with the object when it is initial created
return chain_group

def chain_color(node_name, input_list, label_prefix = "Chain "):
def chain_color(node_name, input_list, label_prefix = "Chain ", field = "chain_id"):
"""
Given the input list of chain names, will create a node group which uses
the chain_id named attribute to manually set the colours for each of the chains.
Expand Down Expand Up @@ -702,7 +702,7 @@ def chain_color(node_name, input_list, label_prefix = "Chain "):
chain_number_node = chain_group.nodes.new("GeometryNodeInputNamedAttribute")
chain_number_node.data_type = 'INT'
chain_number_node.location = [-200, 400]
chain_number_node.inputs[0].default_value = 'chain_id'
chain_number_node.inputs[0].default_value = field
chain_number_node.outputs.get('Attribute')

# shortcut for creating new nodes
Expand Down
91 changes: 67 additions & 24 deletions MolecularNodes/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,35 +438,51 @@ def menu_item_surface_custom(layout_function, label):
emboss = True,
depress = True)

def menu_item_color_chains(layout_function, label):
op = layout_function.operator('mol.color_chains',
text = label,
emboss = True,
depress = True)

class MOL_OT_Color_Chain(bpy.types.Operator):
bl_idname = "mol.color_chains"
bl_label = "My Class Name"
bl_description = "Create a custom node for coloring each chain of a structure \
individually.\nRequires chain information to be available from the structure"
class MOL_OT_Custom_Color_Node(bpy.types.Operator):
bl_idname = "mol.custom_color_node"
bl_label = "Custom color by field node."
bl_options = {"REGISTER", "UNDO"}

node_name: bpy.props.StringProperty(
name = "node_name",
default = ""
)

node_property: bpy.props.StringProperty(
name = "node_property",
default = ""
)

field: bpy.props.StringProperty(
name = "field",
default = "chain_id"
)

prefix: bpy.props.StringProperty(
name = "prefix",
default = "Chain"
)

@classmethod
def poll(cls, context):
return True

return not False
def execute(self, context):
obj = context.active_object
try:
node_color_chain = nodes.chain_color(
node_name = f"MOL_color_chains_{obj.name}",
input_list = obj['chain_id_unique']
node_color = nodes.chain_color(
node_name = f"MOL_color_{self.node_name}_{obj.name}",
input_list = obj[self.node_property],
field = self.field,
label_prefix= self.prefix
)
mol_add_node(node_color_chain.name)
mol_add_node(node_color.name)
except:
self.report({'WARNING'}, message = 'Unable to detect chain information.')

self.report({"WARNING"}, message = f"{self.node_propperty} not available for object.")
return {"FINISHED"}

def invoke(self, context, event):
return self.execute(context)

def menu_chain_selection_custom(layout_function):
obj = bpy.context.view_layer.objects.active
Expand All @@ -486,18 +502,23 @@ class MOL_OT_Chain_Selection_Custom(bpy.types.Operator):
no chain information is available this node will not work"
bl_options = {"REGISTER", "UNDO"}

field: bpy.props.StringProperty(name = "field", default = "chain_id")
prefix: bpy.props.StringProperty(name = "prefix", default = "Chain ")
node_property: bpy.props.StringProperty(name = "node_property", default = "chain_id_unique")
node_name: bpy.props.StringProperty(name = "node_name", default = "chain")

@classmethod
def poll(cls, context):
return True

def execute(self, context):
obj = bpy.context.view_layer.objects.active
node_chains = nodes.chain_selection(
node_name = 'MOL_sel_' + str(obj.name) + "_chains",
input_list = obj['chain_id_unique'],
node_name = f'MOL_sel_{self.node_name}_{obj.name}',
input_list = obj[self.node_property],
starting_value = 0,
attribute = 'chain_id',
label_prefix = "Chain "
attribute = self.field,
label_prefix = self.prefix
)

mol_add_node(node_chains.name)
Expand Down Expand Up @@ -615,7 +636,17 @@ def draw(self, context):
"Creates a color based on atomic_number field")
menu_item_interface(layout, 'Color by Element', 'MOL_color_element',
"Choose a color for each of the first 20 elements")
menu_item_color_chains(layout, 'Color by Chains')
# menu_item_color_chains(layout, 'Color by Chains')
op = layout.operator('mol.custom_color_node', text = 'Color by Chain')
op.node_property = 'chain_id_unique'
op.node_name = "chain"
op.prefix = 'Chain '
op.field = 'chain_id'
op = layout.operator('mol.custom_color_node', text = 'Color by Entity')
op.node_property = 'entity_names'
op.node_name = "chain"
op.prefix = ""
op.field = 'entity_id'

class MOL_MT_Add_Node_Menu_Bonds(bpy.types.Menu):
bl_idname = 'MOL_MT_ADD_NODE_MENU_BONDS'
Expand Down Expand Up @@ -699,6 +730,18 @@ def draw(self, context):
"Outputs for protein, nucleic & sugars")
layout.separator()
menu_chain_selection_custom(layout)
op = layout.operator('mol.chain_selection_custom', text = 'Chain Selection')
op.field = 'chain_id'
op.prefix = 'Chain '
op.node_property = 'chain_id_unique'
op.field = 'chain_id'
op.node_name = 'chain'
op = layout.operator('mol.chain_selection_custom', text = 'Entity Selection')
op.field = 'entity_id'
op.prefix = ''
op.node_property = 'entity_names'
op.field = 'entity_id'
op.node_name = 'entity'
menu_ligand_selection_custom(layout)
layout.separator()
menu_item_interface(layout, 'Backbone', 'MOL_sel_backbone',
Expand Down
9 changes: 9 additions & 0 deletions tests/test_attributes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import MolecularNodes as mn
import numpy as np

def test_entity_id():
mol = mn.load.molecule_rcsb('1cd3')
ents = mn.obj.get_attribute(mol, 'entity_id')

assert np.all(np.isin(ents, np.array(range(4))))

0 comments on commit 3ed5b52

Please sign in to comment.