Skip to content

Commit

Permalink
Renaming classes, cleaning up
Browse files Browse the repository at this point in the history
  • Loading branch information
binho authored and binho committed Jan 21, 2024
1 parent fda71fb commit a168682
Show file tree
Hide file tree
Showing 29 changed files with 789 additions and 595 deletions.
2 changes: 1 addition & 1 deletion src/phylojunction/distribution/dn_parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(
self.ln_sd_list = ty.cast(ty.List[float], self.vectorized_params[1])

# for inference, we need to keep track of parent node names
if isinstance(parent_node_tracker, pgm.DeterministicNodePGM):
if isinstance(parent_node_tracker, pgm.DeterministicNodeDAG):
raise ec.ObjInitInvalidArgError(
self.DN_NAME,
("One of the arguments is a deterministic node. "
Expand Down
27 changes: 19 additions & 8 deletions src/phylojunction/inference/revbayes/rb_dn_parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def get_exponential_rev_inference_spec_info(n_samples: int, exp_scale_or_rate_li

# if we can find a node that holds the value of the rate, we use it
try:
ith_sim_str += parent_node_tracker["rate"] # key: arg in PJ syntax, value: NodePGM name passed as arg
# key: arg in PJ syntax, value: NodeDAG name passed as arg
ith_sim_str += parent_node_tracker["rate"]

except:
ith_sim_str += str(scale_or_rate_list[ith_sim])
ith_sim_str += ")"
Expand Down Expand Up @@ -84,17 +86,22 @@ def get_gamma_rev_inference_spec_info(n_samples: int, gamma_shape_param_list: ty
for ith_sim in range(n_samples):
ith_sim_str = "dnGamma("

# if we can find a node that holds the value of the shape parameter, we use it
# if we can find a node that holds the value of the shape
# parameter, we use it
try:
ith_sim_str += parent_node_tracker["shape"] # returns NodePGM, and we grab its name
# returns NodeDAG, and we grab its name
ith_sim_str += parent_node_tracker["shape"]

except:
ith_sim_str += str(shape_list[ith_sim])

ith_sim_str += ", rate="

# if we can find a node that holds the value of the scale parameter, we use it
try:
ith_sim_str += parent_node_tracker["scale"] # returns NodePGM, and we grab its name
# returns NodeDAG, and we grab its name
ith_sim_str += parent_node_tracker["scale"]

except:
ith_sim_str += str(scale_or_rate_list[ith_sim])
ith_sim_str += ")"
Expand All @@ -117,15 +124,15 @@ def get_normal_rev_inference_spec_info(n_samples: int, norm_mean_param_list: ty.

# if we can find a node that holds the value of the mean, we use it
try:
ith_sim_str += parent_node_tracker["mean"] # returns NodePGM, and we grab its name
ith_sim_str += parent_node_tracker["mean"] # returns NodeDAG, and we grab its name
except:
ith_sim_str += str(real_mean_list[ith_sim])

ith_sim_str += ", sd="

# if we can find a node that holds the value of the sd, we use it
try:
ith_sim_str += parent_node_tracker["sd"] # returns NodePGM, and we grab its name
ith_sim_str += parent_node_tracker["sd"] # returns NodeDAG, and we grab its name
except:
ith_sim_str += str(real_sd_list[ith_sim])
ith_sim_str += ")"
Expand Down Expand Up @@ -153,15 +160,19 @@ def get_ln_rev_inference_spec_info(n_samples: int, ln_mean_list: ty.List[float],

# if we can find a node that holds the value of the mean, we use it
try:
ith_sim_str += parent_node_tracker["mean"] # returns NodePGM, and we grab its name
# returns NodeDAG, and we grab its name
ith_sim_str += parent_node_tracker["mean"]

except:
ith_sim_str += str(real_mean_list[ith_sim])

ith_sim_str += ", sd="

# if we can find a node that holds the value of the sd, we use it
try:
ith_sim_str += parent_node_tracker["sd"] # returns NodePGM, and we grab its name
# returns NodeDAG, and we grab its name
ith_sim_str += parent_node_tracker["sd"]

except:
ith_sim_str += str(real_sd_list[ith_sim])
ith_sim_str += ")"
Expand Down
90 changes: 51 additions & 39 deletions src/phylojunction/inference/revbayes/rb_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,39 @@
import phylojunction.utility.exception_classes as ec

def get_mcmc_logging_spec_list(
a_node_pgm_name: str,
a_node_dag_name: str,
moves_str: str,
n_sim: int,
n_samples: int,
mcmc_chain_length: int,
prefix: str,
results_dir: str) -> ty.List[str]:
"""Generate list of strings where each element (one per simulation) will configure RevBayes' MCMC (moves) and logging inside a .Rev script
"""Generate MCMC move strings for .Rev script.
This method will produce a list of strings where each element
(one per sample) will configure RevBayes' MCMC (moves) and
logging inside a .Rev script
Args:
a_node_pgm_name (str): One of the probabilistic graphical model nodes, required to set the model object
moves_str (str): All moves to be carried out during MCMC, as a string containing new line characters
n_sim (int): Number of simulations
prefix (str): Prefix to preceed MCMC result .log file names
results_dir (str): String specifying directory where MCMC result .log files shall be put by RevBayes
a_node_dag_name (str): One of the probabilistic graphical model
nodes, required to set the model object.
moves_str (str): All moves to be carried out during MCMC, as a
string containing new line characters.
n_samples (int): Number of samples (simulations).
prefix (str): Prefix to precede MCMC result .log file names.
results_dir (str): String specifying directory where MCMC result
.log files shall be put by RevBayes.
Returns:
(str): List of strings (one per simulation), each will configure RevBayes' MCMC (moves) and logging inside a .Rev script
(ty.List[str]): List of strings (one per sample); each will configure
RevBayes' MCMC (moves) and logging inside a .Rev script.
"""

mcmc_logging_spec_str: str = ""
mcmc_logging_spec_str += "mymodel = model(" + a_node_pgm_name + ")\n\n"
mcmc_logging_spec_str += "mymodel = model(" + a_node_dag_name + ")\n\n"
mcmc_logging_spec_str += "monitors[1] = mnScreen()\n"

mcmc_logging_spec_list: ty.List[str] = list()
for i in range(n_sim):
for i in range(n_samples):
mcmc_logging_spec_list.append(
moves_str + "\n" +
mcmc_logging_spec_str +
Expand Down Expand Up @@ -70,60 +79,63 @@ def dag_obj_to_rev_inference_spec(

all_nodes_all_sims_spec_list: ty.List[ty.List[str]] = []
all_nodes_moves_str: str = str()
sorted_node_pgm_list: ty.List[pgm.NodePGM] = dag_obj.get_sorted_node_dag_list()
sorted_node_dag_list: ty.List[pgm.NodeDAG] = dag_obj.get_sorted_node_dag_list()
node_name: str = str()
n_sim = 0

########################
# Going over DAG nodes #
########################
node_count = 1
for node_pgm in sorted_node_pgm_list:
node_name = node_pgm.node_name
is_clamped = node_pgm.is_clamped
for node_dag in sorted_node_dag_list:
node_name = node_dag.node_name
is_clamped = node_dag.is_clamped
node_inference_spec_str: str
node_inference_spec_list: ty.List[str] = [] # will contain all sims for this node

if isinstance(node_pgm, pgm.StochasticNodePGM):
if isinstance(node_dag, pgm.StochasticNodeDAG):
################
# Sampled node #
################
if node_pgm.is_sampled:
node_operator_weight = node_pgm.operator_weight
dn_obj = node_pgm.sampling_dn
if node_dag.is_sampled:
node_operator_weight = node_dag.operator_weight
dn_obj = node_dag.sampling_dn

if dn_obj.DN_NAME == "DnSSE": continue # TODO
if dn_obj.DN_NAME == "DnSSE":
continue # TODO

# getting distribution spec
n_sim, n_repl, rev_str_list = rbpar.get_rev_str_from_dn_parametric_obj(dn_obj)
n_sim, n_repl, rev_str_list = \
rbpar.get_rev_str_from_dn_parametric_obj(dn_obj)

# all simulations for this DAG node
for ith_sim in range(n_sim):
######################
# Observed data prep #
######################
if is_clamped:
node_inference_spec_str = "truth_" + node_name + " <- "
node_inference_spec_str = \
"truth_" + node_name + " <- "
start = ith_sim * n_repl
end = start + n_repl

if type(node_pgm.value) in (list, np.ndarray):
if type(node_pgm.value[0]) in (float, int, str):
if type(node_dag.value) in (list, np.ndarray):
if type(node_dag.value[0]) in (float, int, str):
if n_repl > 1:
node_inference_spec_str += "[" + ", ".join(str(v) for v in node_pgm.value[start:end]) + "]\n\n"
node_inference_spec_str += "[" + ", ".join(str(v) for v in node_dag.value[start:end]) + "]\n\n"
# with replicates, we need a for loop
node_inference_spec_str += "for (i in 1:" + str(n_repl) + ") {\n " + node_name + "[i] ~ " + rev_str_list[ith_sim] + "\n"
node_inference_spec_str += " " + node_name + "[i].clamp(" + "truth_" + node_name + "[i])"
node_inference_spec_str += "\n}"

# no replicates, single value
else:
node_inference_spec_str += str(node_pgm.value[0]) + "\n\n"
node_inference_spec_str += str(node_dag.value[0]) + "\n\n"
node_inference_spec_str += node_name + ".clamp(" + "truth_" + node_name + ")"

else:
print("parsing rb script, found list of objects in clamped node... ignoring for now")
# print(node_pgm.node_name + ": adding rev spec for sim = " + str(ith_sim))
# print(node_dag.node_name + ": adding rev spec for sim = " + str(ith_sim))

# with replicates, we need a for loop
# if n_repl > 1:
Expand All @@ -145,8 +157,8 @@ def dag_obj_to_rev_inference_spec(
# print(node_inference_spec_str)

# TODO
# elif isinstance(node_pgm, DeterministicNodePGM):
# pj2rev_deterministic(node_pgm) # will determine which class inside, and convert it to rev
# elif isinstance(node_dag, DeterministicNodeDAG):
# pj2rev_deterministic(node_dag) # will determine which class inside, and convert it to rev

############################
# Clamped stochastic node, #
Expand All @@ -155,7 +167,7 @@ def dag_obj_to_rev_inference_spec(
# x <- [1, 2] #
############################
else:
n_vals = len(node_pgm.value)
n_vals = len(node_dag.value)
n_sim = dag_obj.sample_size

# in case user enters clamped node first
Expand All @@ -180,12 +192,12 @@ def dag_obj_to_rev_inference_spec(
# clamped_node <- 1 # first .Rev script
# clamped_node <- 2 # second .Rev script
if n_vals == n_sim:
node_inference_spec_str = node_name + " <- " + node_pgm.value[ith_sim]
node_inference_spec_str = node_name + " <- " + node_dag.value[ith_sim]

# if clamped nodes have a single value, this value goes into each and every
# rev script
elif n_vals == 1:
node_inference_spec_str = node_name + " <- " + node_pgm.value[0]
node_inference_spec_str = node_name + " <- " + node_dag.value[0]

# if clamped nodes have more values than sampled stochastic nodes,
# all these values go into each and every rev script
Expand All @@ -195,12 +207,12 @@ def dag_obj_to_rev_inference_spec(
# TODO: fix this
# node_inference_spec_str = node_name + " <- "

# # node_pgm.value is always a list
# if type(node_pgm.value[0]) in (float, int, str):
# if len(node_pgm.value) == 1:
# node_inference_spec_str += str(node_pgm.value[0])
# # node_dag.value is always a list
# if type(node_dag.value[0]) in (float, int, str):
# if len(node_dag.value) == 1:
# node_inference_spec_str += str(node_dag.value[0])
# else:
# node_inference_spec_str += "[" + ", ".join(str(v) for v in node_pgm.value) + "]"
# node_inference_spec_str += "[" + ", ".join(str(v) for v in node_dag.value) + "]"

node_inference_spec_list.append(node_inference_spec_str)

Expand Down Expand Up @@ -230,8 +242,8 @@ def dag_obj_to_rev_inference_spec(
except: pass

# converting into, 1D: sims, 2D: nodes
for j, jth_node_pgm_list_all_sims in enumerate(all_nodes_all_sims_spec_list):
for i, ith_sim_this_node_str in enumerate(jth_node_pgm_list_all_sims):
for j, jth_node_dag_list_all_sims in enumerate(all_nodes_all_sims_spec_list):
for i, ith_sim_this_node_str in enumerate(jth_node_dag_list_all_sims):
# if there are > 1 nodes and we are not looking at the last node
if len(all_nodes_all_sims_spec_list) > 1 and j < (len(all_nodes_all_sims_spec_list)-1):
all_sims_model_spec_list[i] += ith_sim_this_node_str + line_sep
Expand Down
Loading

0 comments on commit a168682

Please sign in to comment.