Skip to content

Commit

Permalink
Merge branch 'main' of github.com:OpenFreeEnergy/konnektor
Browse files Browse the repository at this point in the history
  • Loading branch information
RiesBen committed Sep 18, 2024
2 parents 4c00c73 + 540ed94 commit c5c4549
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 16 deletions.
19 changes: 13 additions & 6 deletions examples/konnektor_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
"source": [
"from konnektor.utils import toy_data\n",
"\n",
"components, mapper, combo_scorer = toy_data.build_random_dataset(\n",
"components, mapper, scorer = toy_data.build_random_dataset(\n",
" n_compounds=20, rand_seed=42\n",
")"
]
Expand Down Expand Up @@ -107,7 +107,7 @@
"from konnektor.network_planners import StarNetworkGenerator\n",
"\n",
"network_planner = StarNetworkGenerator(\n",
" mappers=mapper, scorer=combo_scorer, n_processes=1\n",
" mappers=mapper, scorer=scorer, n_processes=1\n",
")"
]
},
Expand Down Expand Up @@ -215,7 +215,7 @@
"source": [
"from konnektor.network_planners import CyclicNetworkGenerator\n",
"\n",
"network_planner = CyclicNetworkGenerator(mappers=mapper, scorer=combo_scorer)"
"network_planner = CyclicNetworkGenerator(mappers=mapper, scorer=scorer)"
]
},
{
Expand Down Expand Up @@ -297,7 +297,7 @@
"source": [
"from konnektor.network_planners import StarrySkyNetworkGenerator\n",
"\n",
"network_planner = StarrySkyNetworkGenerator(mappers=mapper, scorer=combo_scorer)"
"network_planner = StarrySkyNetworkGenerator(mappers=mapper, scorer=scorer)"
]
},
{
Expand Down Expand Up @@ -748,6 +748,7 @@
}
],
"source": [
"# NBVAL_SKIP\n",
"## Concatenate Sub-Networks\n",
"from konnektor.network_tools import concatenate_networks\n",
"from konnektor.network_planners import MstConcatenator\n",
Expand Down Expand Up @@ -805,6 +806,8 @@
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_SKIP\n",
"\n",
"from konnektor.visualization import draw_network_widget\n",
"\n",
"draw_network_widget(charged_starry_sky_network, show_molecules=True, show_mappings=True);"
Expand Down Expand Up @@ -832,9 +835,11 @@
},
"outputs": [],
"source": [
"# NBVAL_SKIP\n",
"\n",
"# first let's generate the Star Network (we will use the networker from above).\n",
"ch_radial_networker = StarNetworkGenerator(\n",
" mappers=mapper, scorer=combo_scorer, n_processes=5\n",
" mappers=mapper, scorer=combo_scorer, n_processes=1\n",
")\n",
"charge_star_network = ch_radial_networker(used_componentes)\n",
"charge_star_network.name = \"Star Network\"\n",
Expand All @@ -843,7 +848,7 @@
"from konnektor.network_planners import RedundantMinimalSpanningTreeNetworkGenerator\n",
"\n",
"ch_rmst_networker = RedundantMinimalSpanningTreeNetworkGenerator(\n",
" mappers=mapper, scorer=combo_scorer, n_processes=5, n_redundancy=2\n",
" mappers=mapper, scorer=combo_scorer, n_processes=1, n_redundancy=2\n",
")\n",
"charge_rmst_network = ch_rmst_networker(used_componentes)\n",
"charge_rmst_network.name = \"Redundant MST Network\""
Expand Down Expand Up @@ -923,6 +928,8 @@
}
],
"source": [
"# NBVAL_SKIP\n",
"\n",
"import numpy as np\n",
"from konnektor.network_analysis.network_analysis import (\n",
" get_network_cost,\n",
Expand Down
70 changes: 63 additions & 7 deletions src/konnektor/network_planners/concatenators/max_concatenator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/konnektor

from tqdm import tqdm
import functools
import itertools
import logging
from typing import Iterable, Union
Expand Down Expand Up @@ -84,13 +86,64 @@ def concatenate_networks(
nodesB = ligandNetworkB.nodes
pedges = [(na, nb) for na in nodesA for nb in nodesB]

bipartite_graph_mappings = _parallel_map_scoring(
possible_edges=pedges,
scorer=self.scorer,
mappers=self.mappers,
n_processes=self.n_processes,
show_progress=self.progress,
)
if self.n_processes > 1:
bipartite_graph_mappings = _parallel_map_scoring(
possible_edges=pedges,
scorer=self.scorer,
mappers=self.mappers,
n_processes=self.n_processes,
show_progress=self.progress,
)

else: # serial variant
if self.progress is True:
progress = functools.partial(
tqdm, total=len(pedges), delay=1.5, desc="Mapping Subnets"
)
else:
progress = lambda x: x

bipartite_graph_mappings = []
for component_pair in progress(pedges):
best_score = 0.0
best_mapping = None
molA = component_pair[0]
molB = component_pair[1]

for mapper in self.mappers:
try:
mapping_generator = mapper.suggest_mappings(molA, molB)
except:
continue

if self.scorer:
tmp_mappings = [
mapping.with_annotations(
{"score": self.scorer(mapping)}
)
for mapping in mapping_generator
]

if len(tmp_mappings) > 0:
tmp_best_mapping = min(
tmp_mappings, key=lambda m: m.annotations["score"]
)

if (
tmp_best_mapping.annotations["score"] < best_score
or best_mapping is None
):
best_score = tmp_best_mapping.annotations["score"]
best_mapping = tmp_best_mapping
else:
try:
best_mapping = next(mapping_generator)
except:
print("warning")
continue

if best_mapping is not None:
bipartite_graph_mappings.append(best_mapping)

# Add network connecting edges
selected_edges.extend(bipartite_graph_mappings)
Expand All @@ -107,4 +160,7 @@ def concatenate_networks(

log.info(f"Total Concatenated Edges: {len(selected_edges)} ")

if not concat_LigandNetwork.is_connected():
raise RuntimeError("could not build a connected network!")

return concat_LigandNetwork
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# For details, see https://github.com/OpenFreeEnergy/konnektor

import functools
import logging
import multiprocessing as mult

from gufe import AtomMapper, AtomMapping
Expand All @@ -25,6 +24,7 @@ def thread_mapping(args) -> list[AtomMapping]:
return a list of scored atom mappings
"""

jobID, compound_pairs, mappers, scorer = args

mappings = []
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/konnektor

import pytest
from gufe import LigandNetwork

from konnektor.network_planners.concatenators.max_concatenator import MaxConcatenator
Expand All @@ -13,8 +14,11 @@


# more test here also for the params
def test_max_network_concatenation(ligand_network_ab):
concatenator = MaxConcatenator(mappers=GenAtomMapper(), scorer=genScorer)
@pytest.mark.parametrize("n_process", [1, 2])
def test_max_network_concatenation(ligand_network_ab, n_process):
concatenator = MaxConcatenator(
mappers=GenAtomMapper(), scorer=genScorer, n_processes=n_process
)

ln_a, ln_b = ligand_network_ab
nA = len(ln_a.nodes)
Expand Down

0 comments on commit c5c4549

Please sign in to comment.