Skip to content

Commit

Permalink
Merge pull request #11 from samblau/sam_dev
Browse files Browse the repository at this point in the history
Sam dev
  • Loading branch information
samblau authored Nov 6, 2023
2 parents 62d5f2e + bad3ecf commit 7792880
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 51 deletions.
18 changes: 18 additions & 0 deletions HiPRGen/initial_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,22 @@ def find_mol_entry_by_entry_id(mol_entries, entry_id):
);
"""

create_interrupt_state_table = """
CREATE TABLE interrupt_state (
seed INTEGER NOT NULL,
species_id INTEGER NOT NULL,
count INTEGER NOT NULL
);
"""

create_interrupt_cutoff_table = """
CREATE TABLE interrupt_cutoff (
seed INTEGER NOT NULL,
step INTEGER NOT NULL,
time INTEGER NOT NULL
);
"""


def insert_initial_state(
initial_state,
Expand All @@ -83,6 +99,8 @@ def insert_initial_state(
rn_cur.execute(create_initial_state_table)
rn_cur.execute(create_trajectories_table)
rn_cur.execute(create_factors_table)
rn_cur.execute(create_interrupt_state_table)
rn_cur.execute(create_interrupt_cutoff_table)
rn_con.commit()

rn_cur.execute(
Expand Down
45 changes: 28 additions & 17 deletions HiPRGen/mc_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def reaction_tally_report(
for (reaction_index, number) in sorted(
reaction_tally.items(), key=lambda pair: -pair[1]):
if number > cutoff:
report_generator.emit_text(str(number) + " occourances of:")
report_generator.emit_text(str(number) + " occurrences of:")
report_generator.emit_reaction(
network_loader.index_to_reaction(reaction_index),
label=str(reaction_index)
Expand Down Expand Up @@ -839,7 +839,7 @@ def compute_expected_final_state(self):
)

for seed in self.network_loader.trajectories:
state = np.copy(self.network_loader.initial_state_array)
state = np.copy(self.network_loader.initial_state_array[seed])
for step in self.network_loader.trajectories[seed]:
reaction_index = self.network_loader.trajectories[seed][step][0]
time = self.network_loader.trajectories[seed][step][1]
Expand All @@ -861,7 +861,7 @@ def compute_expected_final_state(self):
def compute_trajectory_final_states(self):
self.final_states = {}
for seed in self.network_loader.trajectories:
state = np.copy(self.network_loader.initial_state_array)
state = np.copy(self.network_loader.initial_state_array[seed])
for step in self.network_loader.trajectories[seed]:
reaction_index = self.network_loader.trajectories[seed][step][0]
time = self.network_loader.trajectories[seed][step][1]
Expand Down Expand Up @@ -906,13 +906,13 @@ def compute_production_consumption_info(self):
self.producing_reactions[product_index][reaction_index] += 1

def compute_state_time_series(self, seed):
state_dimension_size = len(self.network_loader.initial_state_array)
state_dimension_size = len(self.network_loader.initial_state_array[seed])
step_dimension_size = len(self.network_loader.trajectories[seed])
time_series = np.zeros(
(step_dimension_size, state_dimension_size),
dtype=int)

state = np.copy(self.network_loader.initial_state_array)
state = np.copy(self.network_loader.initial_state_array[seed])
for step in self.network_loader.trajectories[seed]:
reaction_index = self.network_loader.trajectories[seed][step][0]
time = self.network_loader.trajectories[seed][step][1]
Expand All @@ -936,9 +936,11 @@ def time_series_graph(
seeds,
species_of_interest,
path,
custom_y_max=None,
custom_colorstyle_list=None,
colors = list(mcolors.TABLEAU_COLORS.values()),
styles = ['solid', 'dotted', 'dashed', 'dashdot'],
internal_index_labels=True
styles = ['solid', 'dotted', 'dashed', 'dashdot','solid', 'dotted', 'dashed', 'dashdot','solid', 'dotted', 'dashed', 'dashdot','solid', 'dotted', 'dashed', 'dashdot'],
internal_index_labels=True,
):


Expand Down Expand Up @@ -970,11 +972,16 @@ def time_series_graph(

line_dict = {}
i = 0
for species_index in species_of_interest:
r = i % len(colors)
q = i // len(colors)
line_dict[species_index] = (colors[r], styles[q])
i += 1
if custom_colorstyle_list is None:
for species_index in species_of_interest:
r = i % len(colors)
q = i // len(colors)
line_dict[species_index] = (colors[r], styles[q])
i += 1
else:
for species_index in species_of_interest:
line_dict[species_index] = (custom_colorstyle_list[i][0], custom_colorstyle_list[i][1])
i += 1


fig, (ax0, ax1, ax2) = plt.subplots(
Expand All @@ -983,15 +990,19 @@ def time_series_graph(
gridspec_kw={'height_ratios':[2,2,1]})

y_max = 0
for step in range(total_time_series.shape[0]):
for species_index in species_of_interest:
y_max = max(y_max, total_time_series[step,species_index])
if custom_y_max is None:
for step in range(total_time_series.shape[0]):
for species_index in species_of_interest:
y_max = max(y_max, total_time_series[step,species_index])
y_max += 1
else:
y_max = custom_y_max

ax0.set_xlim([0,total_time_series.shape[0]])
ax0.set_ylim([0,y_max+1])
ax0.set_ylim([0,y_max])

ax1.set_xlim([0,total_time_series.shape[0]])
ax1.set_ylim([0,(y_max+1)/10])
ax1.set_ylim([0,(y_max)/10])


ticks = np.arange(0, total_time_series.shape[0])
Expand Down
37 changes: 16 additions & 21 deletions HiPRGen/network_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,23 +161,7 @@ def index_to_reaction(self, reaction_index):
return reaction


def load_trajectories(self):

cur = self.initial_state_con.cursor()

for row in cur.execute(sql_get_trajectory):
seed = row[0]
step = row[1]
reaction_id = row[2]
time = row[3]

if seed not in self.trajectories:
self.trajectories[seed] = {}

self.trajectories[seed][step] = (reaction_id, time)


def load_initial_state(self):
def load_initial_state_and_trajectories(self):

cur = self.initial_state_con.cursor()
initial_state_dict = {}
Expand All @@ -193,15 +177,26 @@ def load_initial_state(self):
for i in range(self.number_of_species):
initial_state_array[i] = initial_state_dict[i]

if self.initial_state_dict == {} and self.initial_state_array == {}:
if self.initial_state_dict == {}:
self.initial_state_dict = initial_state_dict
self.initial_state_array = initial_state_array
else:
for i in range(self.number_of_species):
if initial_state_array[i] > self.initial_state_array[i]:
self.initial_state_array[i] = initial_state_array[i]
if initial_state_dict[i] > self.initial_state_dict[i]:
self.initial_state_dict[i] = initial_state_dict[i]

for row in cur.execute(sql_get_trajectory):
seed = row[0]
step = row[1]
reaction_id = row[2]
time = row[3]

if seed not in self.trajectories:
self.trajectories[seed] = {}
self.initial_state_array[seed] = initial_state_array

self.trajectories[seed][step] = (reaction_id, time)



def set_initial_state_db(self, initial_state_database):
# NOTE: switching to a new initial state database and loading in trajectory
Expand Down
62 changes: 49 additions & 13 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import subprocess
import sqlite3
import pickle
import copy


import matplotlib.colors as mcolors
from HiPRGen.network_loader import NetworkLoader
from HiPRGen.initial_state import find_mol_entry_from_xyz_and_charge
from monty.serialization import loadfn, dumpfn
Expand Down Expand Up @@ -218,8 +219,7 @@ def li_test():
folder + "/initial_state.sqlite",
)

network_loader.load_trajectories()
network_loader.load_initial_state()
network_loader.load_initial_state_and_trajectories()

# HiPRGen has analysis tools to understand what happened in our simulation.
# The output files are written into the same folder in which the reaction
Expand Down Expand Up @@ -385,8 +385,7 @@ def mg_test():
folder + "/initial_state.sqlite",
)

network_loader.load_trajectories()
network_loader.load_initial_state()
network_loader.load_initial_state_and_trajectories()

report_generator = ReportGenerator(
network_loader.mol_entries, folder + "/dummy.tex", rebuild_mol_pictures=True
Expand Down Expand Up @@ -503,8 +502,7 @@ def mg_test():
# folder + "/initial_state.sqlite",
# )

# network_loader.load_trajectories()
# network_loader.load_initial_state()
# network_loader.load_initial_state_and_trajectories()

# report_generator = ReportGenerator(
# network_loader.mol_entries, folder + "/dummy.tex", rebuild_mol_pictures=True
Expand Down Expand Up @@ -673,8 +671,7 @@ def euvl_phase1_test():
folder + "/initial_state.sqlite",
)

network_loader.load_trajectories()
network_loader.load_initial_state()
network_loader.load_initial_state_and_trajectories()

report_generator = ReportGenerator(
network_loader.mol_entries, folder + "/dummy.tex", rebuild_mol_pictures=True
Expand Down Expand Up @@ -770,8 +767,7 @@ def euvl_phase2_test():
phase1_folder + "/mol_entries.pickle",
phase1_folder + f"/initial_state.sqlite",
)
phase1_network_loader.load_trajectories()
phase1_network_loader.load_initial_state()
phase1_network_loader.load_initial_state_and_trajectories()
phase1_simulation_replayer = SimulationReplayer(phase1_network_loader)
phase1_simulation_replayer.compute_trajectory_final_states()

Expand Down Expand Up @@ -807,8 +803,7 @@ def euvl_phase2_test():

for seed in range(1000, 2000):
network_loader.set_initial_state_db(folder + "/initial_state_"+str(seed)+".sqlite")
network_loader.load_trajectories()
network_loader.load_initial_state()
network_loader.load_initial_state_and_trajectories()

report_generator = ReportGenerator(
network_loader.mol_entries, folder + "/dummy.tex", rebuild_mol_pictures=True
Expand All @@ -820,6 +815,47 @@ def euvl_phase2_test():

sink_report(simulation_replayer, folder + "/sink_report.tex")

tps_plus1_id = find_mol_entry_from_xyz_and_charge(mol_entries, "./xyz_files/tps.xyz", 1)
phs_0_id = find_mol_entry_from_xyz_and_charge(mol_entries, "./xyz_files/phs.xyz", 0)
tba_0_id = find_mol_entry_from_xyz_and_charge(mol_entries, "./xyz_files/tba.xyz", 0)
nf_minus1_id = find_mol_entry_from_xyz_and_charge(mol_entries, "./xyz_files/nf.xyz", -1)

phase2_important_species = [tps_plus1_id, phs_0_id, tba_0_id, nf_minus1_id]

colors = list(mcolors.TABLEAU_COLORS.values())
phase2_colorstyle_list = []
for ii, species in enumerate(phase2_important_species):
phase2_colorstyle_list.append([colors[ii], "solid"])

ii = 0
for mol_id in simulation_replayer.sinks:
if mol_id not in phase2_important_species:
phase2_important_species.append(mol_id)
phase2_colorstyle_list.append([colors[ii%len(colors)], "dashed"])
ii += 1

phase1_important_species = copy.deepcopy(phase2_important_species)
phase1_important_species.append(len(mol_entries))

phase1_colorstyle_list = copy.deepcopy(phase2_colorstyle_list)
phase1_colorstyle_list.append(["black", "dotted"])

phase1_simulation_replayer.time_series_graph(
seeds=[i for i in range(1000,2000)],
species_of_interest=phase1_important_species,
path=os.path.join(folder,"phase1_time_series"),
custom_y_max=36,
custom_colorstyle_list=phase1_colorstyle_list
)

simulation_replayer.time_series_graph(
seeds=[i for i in range(1000,1000+1000*int(number_of_threads))],
species_of_interest=phase2_important_species,
path=os.path.join(folder,"phase2_time_series"),
custom_y_max=36,
custom_colorstyle_list=phase2_colorstyle_list
)

tests_passed = True
print("Number of species:", network_loader.number_of_species)
if network_loader.number_of_species == 103:
Expand Down

0 comments on commit 7792880

Please sign in to comment.