diff --git a/src/phylojunction/distribution/dn_parametric.py b/src/phylojunction/distribution/dn_parametric.py index 0a69d98..e53f66f 100644 --- a/src/phylojunction/distribution/dn_parametric.py +++ b/src/phylojunction/distribution/dn_parametric.py @@ -101,7 +101,7 @@ def __init__( self.ln_sd_list = ty.cast(ty.List[float], self.vectorized_params[1]) # for inference, we need to keep track of parent node names - if isinstance(parent_node_tracker, pgm.DeterministicNodePGM): + if isinstance(parent_node_tracker, pgm.DeterministicNodeDAG): raise ec.ObjInitInvalidArgError( self.DN_NAME, ("One of the arguments is a deterministic node. " diff --git a/src/phylojunction/inference/revbayes/rb_dn_parametric.py b/src/phylojunction/inference/revbayes/rb_dn_parametric.py index f23fd34..6eafa49 100644 --- a/src/phylojunction/inference/revbayes/rb_dn_parametric.py +++ b/src/phylojunction/inference/revbayes/rb_dn_parametric.py @@ -56,7 +56,9 @@ def get_exponential_rev_inference_spec_info(n_samples: int, exp_scale_or_rate_li # if we can find a node that holds the value of the rate, we use it try: - ith_sim_str += parent_node_tracker["rate"] # key: arg in PJ syntax, value: NodePGM name passed as arg + # key: arg in PJ syntax, value: NodeDAG name passed as arg + ith_sim_str += parent_node_tracker["rate"] + except: ith_sim_str += str(scale_or_rate_list[ith_sim]) ith_sim_str += ")" @@ -84,9 +86,12 @@ def get_gamma_rev_inference_spec_info(n_samples: int, gamma_shape_param_list: ty for ith_sim in range(n_samples): ith_sim_str = "dnGamma(" - # if we can find a node that holds the value of the shape parameter, we use it + # if we can find a node that holds the value of the shape + # parameter, we use it try: - ith_sim_str += parent_node_tracker["shape"] # returns NodePGM, and we grab its name + # returns NodeDAG, and we grab its name + ith_sim_str += parent_node_tracker["shape"] + except: ith_sim_str += str(shape_list[ith_sim]) @@ -94,7 +99,9 @@ def get_gamma_rev_inference_spec_info(n_samples: int, gamma_shape_param_list: ty # if we can find a node that holds the value of the scale parameter, we use it try: - ith_sim_str += parent_node_tracker["scale"] # returns NodePGM, and we grab its name + # returns NodeDAG, and we grab its name + ith_sim_str += parent_node_tracker["scale"] + except: ith_sim_str += str(scale_or_rate_list[ith_sim]) ith_sim_str += ")" @@ -117,7 +124,7 @@ def get_normal_rev_inference_spec_info(n_samples: int, norm_mean_param_list: ty. # if we can find a node that holds the value of the mean, we use it try: - ith_sim_str += parent_node_tracker["mean"] # returns NodePGM, and we grab its name + ith_sim_str += parent_node_tracker["mean"] # returns NodeDAG, and we grab its name except: ith_sim_str += str(real_mean_list[ith_sim]) @@ -125,7 +132,7 @@ def get_normal_rev_inference_spec_info(n_samples: int, norm_mean_param_list: ty. 
# if we can find a node that holds the value of the sd, we use it try: - ith_sim_str += parent_node_tracker["sd"] # returns NodePGM, and we grab its name + ith_sim_str += parent_node_tracker["sd"] # returns NodeDAG, and we grab its name except: ith_sim_str += str(real_sd_list[ith_sim]) ith_sim_str += ")" @@ -153,7 +160,9 @@ def get_ln_rev_inference_spec_info(n_samples: int, ln_mean_list: ty.List[float], # if we can find a node that holds the value of the mean, we use it try: - ith_sim_str += parent_node_tracker["mean"] # returns NodePGM, and we grab its name + # returns NodeDAG, and we grab its name + ith_sim_str += parent_node_tracker["mean"] + except: ith_sim_str += str(real_mean_list[ith_sim]) @@ -161,7 +170,9 @@ def get_ln_rev_inference_spec_info(n_samples: int, ln_mean_list: ty.List[float], # if we can find a node that holds the value of the sd, we use it try: - ith_sim_str += parent_node_tracker["sd"] # returns NodePGM, and we grab its name + # returns NodeDAG, and we grab its name + ith_sim_str += parent_node_tracker["sd"] + except: ith_sim_str += str(real_sd_list[ith_sim]) ith_sim_str += ")" diff --git a/src/phylojunction/inference/revbayes/rb_inference.py b/src/phylojunction/inference/revbayes/rb_inference.py index 2f7b374..ed67726 100644 --- a/src/phylojunction/inference/revbayes/rb_inference.py +++ b/src/phylojunction/inference/revbayes/rb_inference.py @@ -9,30 +9,39 @@ import phylojunction.utility.exception_classes as ec def get_mcmc_logging_spec_list( - a_node_pgm_name: str, + a_node_dag_name: str, moves_str: str, - n_sim: int, + n_samples: int, mcmc_chain_length: int, prefix: str, results_dir: str) -> ty.List[str]: - """Generate list of strings where each element (one per simulation) will configure RevBayes' MCMC (moves) and logging inside a .Rev script + """Generate MCMC move strings for .Rev script. + + This function will produce a list of strings where each element + (one per sample) will configure RevBayes' MCMC (moves) and + logging inside a .Rev script Args: - a_node_pgm_name (str): One of the probabilistic graphical model nodes, required to set the model object - moves_str (str): All moves to be carried out during MCMC, as a string containing new line characters - n_sim (int): Number of simulations - prefix (str): Prefix to preceed MCMC result .log file names - results_dir (str): String specifying directory where MCMC result .log files shall be put by RevBayes + a_node_dag_name (str): One of the probabilistic graphical model + nodes, required to set the model object. + moves_str (str): All moves to be carried out during MCMC, as a + string containing new line characters. + n_samples (int): Number of samples (simulations). + prefix (str): Prefix to precede MCMC result .log file names. + results_dir (str): String specifying directory where MCMC result + .log files shall be put by RevBayes. Returns: - (str): List of strings (one per simulation), each will configure RevBayes' MCMC (moves) and logging inside a .Rev script + (str): List of strings (one per sample), each will configure + RevBayes' MCMC (moves) and logging inside a .Rev script.
""" + mcmc_logging_spec_str: str = "" - mcmc_logging_spec_str += "mymodel = model(" + a_node_pgm_name + ")\n\n" + mcmc_logging_spec_str += "mymodel = model(" + a_node_dag_name + ")\n\n" mcmc_logging_spec_str += "monitors[1] = mnScreen()\n" mcmc_logging_spec_list: ty.List[str] = list() - for i in range(n_sim): + for i in range(n_samples): mcmc_logging_spec_list.append( moves_str + "\n" + mcmc_logging_spec_str + @@ -70,7 +79,7 @@ def dag_obj_to_rev_inference_spec( all_nodes_all_sims_spec_list: ty.List[ty.List[str]] = [] all_nodes_moves_str: str = str() - sorted_node_pgm_list: ty.List[pgm.NodePGM] = dag_obj.get_sorted_node_dag_list() + sorted_node_dag_list: ty.List[pgm.NodeDAG] = dag_obj.get_sorted_node_dag_list() node_name: str = str() n_sim = 0 @@ -78,24 +87,26 @@ def dag_obj_to_rev_inference_spec( # Going over PGM nodes # ######################## node_count = 1 - for node_pgm in sorted_node_pgm_list: - node_name = node_pgm.node_name - is_clamped = node_pgm.is_clamped + for node_dag in sorted_node_dag_list: + node_name = node_dag.node_name + is_clamped = node_dag.is_clamped node_inference_spec_str: str node_inference_spec_list: ty.List[str] = [] # will contain all sims for this node - if isinstance(node_pgm, pgm.StochasticNodePGM): + if isinstance(node_dag, pgm.StochasticNodeDAG): ################ # Sampled node # ################ - if node_pgm.is_sampled: - node_operator_weight = node_pgm.operator_weight - dn_obj = node_pgm.sampling_dn + if node_dag.is_sampled: + node_operator_weight = node_dag.operator_weight + dn_obj = node_dag.sampling_dn - if dn_obj.DN_NAME == "DnSSE": continue # TODO + if dn_obj.DN_NAME == "DnSSE": + continue # TODO # getting distribution spec - n_sim, n_repl, rev_str_list = rbpar.get_rev_str_from_dn_parametric_obj(dn_obj) + n_sim, n_repl, rev_str_list = \ + rbpar.get_rev_str_from_dn_parametric_obj(dn_obj) # all simulations for this PGM node for ith_sim in range(n_sim): @@ -103,14 +114,15 @@ def dag_obj_to_rev_inference_spec( # Observed data prep # ###################### if is_clamped: - node_inference_spec_str = "truth_" + node_name + " <- " + node_inference_spec_str = \ + "truth_" + node_name + " <- " start = ith_sim * n_repl end = start + n_repl - if type(node_pgm.value) in (list, np.ndarray): - if type(node_pgm.value[0]) in (float, int, str): + if type(node_dag.value) in (list, np.ndarray): + if type(node_dag.value[0]) in (float, int, str): if n_repl > 1: - node_inference_spec_str += "[" + ", ".join(str(v) for v in node_pgm.value[start:end]) + "]\n\n" + node_inference_spec_str += "[" + ", ".join(str(v) for v in node_dag.value[start:end]) + "]\n\n" # with replicates, we need a for loop node_inference_spec_str += "for (i in 1:" + str(n_repl) + ") {\n " + node_name + "[i] ~ " + rev_str_list[ith_sim] + "\n" node_inference_spec_str += " " + node_name + "[i].clamp(" + "truth_" + node_name + "[i])" @@ -118,12 +130,12 @@ def dag_obj_to_rev_inference_spec( # no replicates, single value else: - node_inference_spec_str += str(node_pgm.value[0]) + "\n\n" + node_inference_spec_str += str(node_dag.value[0]) + "\n\n" node_inference_spec_str += node_name + ".clamp(" + "truth_" + node_name + ")" else: print("parsing rb script, found list of objects in clamped node... 
ignoring for now") - # print(node_pgm.node_name + ": adding rev spec for sim = " + str(ith_sim)) + # print(node_dag.node_name + ": adding rev spec for sim = " + str(ith_sim)) # with replicates, we need a for loop # if n_repl > 1: @@ -145,8 +157,8 @@ def dag_obj_to_rev_inference_spec( # print(node_inference_spec_str) # TODO - # elif isinstance(node_pgm, DeterministicNodePGM): - # pj2rev_deterministic(node_pgm) # will determine which class inside, and convert it to rev + # elif isinstance(node_dag, DeterministicNodeDAG): + # pj2rev_deterministic(node_dag) # will determine which class inside, and convert it to rev ############################ # Clamped stochastic node, # @@ -155,7 +167,7 @@ def dag_obj_to_rev_inference_spec( # x <- [1, 2] # ############################ else: - n_vals = len(node_pgm.value) + n_vals = len(node_dag.value) n_sim = dag_obj.sample_size # in case user enters clamped node first @@ -180,12 +192,12 @@ def dag_obj_to_rev_inference_spec( # clamped_node <- 1 # first .Rev script # clamped_node <- 2 # second .Rev script if n_vals == n_sim: - node_inference_spec_str = node_name + " <- " + node_pgm.value[ith_sim] + node_inference_spec_str = node_name + " <- " + node_dag.value[ith_sim] # if clamped nodes have a single value, this value goes into each and every # rev scripts elif n_vals == 1: - node_inference_spec_str = node_name + " <- " + node_pgm.value[0] + node_inference_spec_str = node_name + " <- " + node_dag.value[0] # if clamped nodes have more values than sampled stochastic nodes, # all these values go into each and every rev script @@ -195,12 +207,12 @@ def dag_obj_to_rev_inference_spec( # TODO: fix this # node_inference_spec_str = node_name + " <- " - # # node_pgm.value is always a list - # if type(node_pgm.value[0]) in (float, int, str): - # if len(node_pgm.value) == 1: - # node_inference_spec_str += str(node_pgm.value[0]) + # # node_dag.value is always a list + # if type(node_dag.value[0]) in (float, int, str): + # if len(node_dag.value) == 1: + # node_inference_spec_str += str(node_dag.value[0]) # else: - # node_inference_spec_str += "[" + ", ".join(str(v) for v in node_pgm.value) + "]" + # node_inference_spec_str += "[" + ", ".join(str(v) for v in node_dag.value) + "]" node_inference_spec_list.append(node_inference_spec_str) @@ -230,8 +242,8 @@ def dag_obj_to_rev_inference_spec( except: pass # converting into, 1D: sims, 2D: nodes - for j, jth_node_pgm_list_all_sims in enumerate(all_nodes_all_sims_spec_list): - for i, ith_sim_this_node_str in enumerate(jth_node_pgm_list_all_sims): + for j, jth_node_dag_list_all_sims in enumerate(all_nodes_all_sims_spec_list): + for i, ith_sim_this_node_str in enumerate(jth_node_dag_list_all_sims): # if there are > 1 nodes and we are not looking at the last node if len(all_nodes_all_sims_spec_list) > 1 and j < (len(all_nodes_all_sims_spec_list)-1): all_sims_model_spec_list[i] += ith_sim_this_node_str + line_sep diff --git a/src/phylojunction/interface/cmdbox/cmd_parse.py b/src/phylojunction/interface/cmdbox/cmd_parse.py index 9dea170..8533b9e 100644 --- a/src/phylojunction/interface/cmdbox/cmd_parse.py +++ b/src/phylojunction/interface/cmdbox/cmd_parse.py @@ -62,7 +62,7 @@ def _execute_spec_lines( # Debugging space # ################### # seeing trees in script strings in main() - # for node_name, node_pgm in dag_obj.name_node_dict.items(): + # for node_name, node_dag in dag_obj.name_node_dict.items(): # if node_name == "trs": # # note that pjgui uses matplotlib.figure.Figure # # (which is part of Matplotlib's OOP class library) @@ 
-78,10 +78,10 @@ def _execute_spec_lines( # ax.spines['right'].set_visible(False) # ax.spines['top'].set_visible(False) - # print(node_pgm.value[0].tree.as_string(schema="newick")) + # print(node_dag.value[0].tree.as_string(schema="newick")) # pjtr.plot_ann_tree( - # node_pgm.value[0], + # node_dag.value[0], # ax, # use_age=False, # start_at_origin=True, @@ -90,7 +90,7 @@ def _execute_spec_lines( # plt.show() - # print(node_pgm.get_node_stats_str(0, len(node_pgm.value), 0)) + # print(node_dag.get_node_stats_str(0, len(node_dag.value), 0)) # as if we had clicked "See" in the inference tab # all_sims_model_spec_list, all_sims_mcmc_logging_spec_list, dir_list = \ @@ -301,7 +301,7 @@ def parse_variable_assignment( static script or via GUI. """ - def create_add_stoch_node_pgm(a_stoch_node_name: str, + def create_add_stoch_node_dag(a_stoch_node_name: str, sample_size: int, a_val_obj_list: ty.List[ty.Any], a_ct_fn_obj: pgm.ConstantFn): @@ -309,8 +309,8 @@ def create_add_stoch_node_pgm(a_stoch_node_name: str, replicate_size: int = 1 if a_ct_fn_obj is not None: replicate_size = a_ct_fn_obj.n_repl - - stoch_node = pgm.StochasticNodePGM( + + stoch_node = pgm.StochasticNodeDAG( a_stoch_node_name, sample_size=sample_size, replicate_size=replicate_size, @@ -391,13 +391,13 @@ def create_add_stoch_node_pgm(a_stoch_node_name: str, stoch_node_spec, ("Something went wrong during variable assignment " "Could not find both the name of a function " - "(e.g., \'read_tree\') and its specification (e.g., " - "\'(file=\"examples/geosse_dummy_tree1\", node_name_attr=\"index\")\').")) + "(e.g., 'read_tree') and its specification (e.g., " + "'(file=\"examples/geosse_dummy_tree1\", node_name_attr=\"index\")').")) val_obj_list = cmdu.val_or_obj(dag_obj, values_list) n_samples = len(val_obj_list) - create_add_stoch_node_pgm(stoch_node_name, + create_add_stoch_node_dag(stoch_node_name, n_samples, val_obj_list, ct_fn_obj) @@ -431,11 +431,11 @@ def create_add_rv_pgm( sample_size: int, replicate_size: int, a_dn_obj: pgm.DistributionPGM, - parent_pgm_nodes: ty.List[pgm.NodePGM], + parent_pgm_nodes: ty.List[pgm.NodeDAG], clamped: bool): # set dn inside rv, then call .sample - stoch_node_pgm = pgm.StochasticNodePGM( + stoch_node_dag = pgm.StochasticNodeDAG( a_stoch_node_name, sample_size=sample_size, replicate_size=replicate_size, @@ -443,7 +443,7 @@ def create_add_rv_pgm( parent_nodes=parent_pgm_nodes, clamped=clamped) - dag_obj.add_node(stoch_node_pgm) + dag_obj.add_node(stoch_node_dag) if re.search(cmdu.sampling_dn_spec_regex, stoch_node_dn_spec) is None: @@ -528,7 +528,7 @@ def create_add_rv_pgm( # if user passes 'r' as a node, # e.g., 'n_sim <- 2', then 'normal(n=n_sim...) - elif isinstance(spec_dict["n"][0], pgm.StochasticNodePGM): + elif isinstance(spec_dict["n"][0], pgm.StochasticNodeDAG): # this check and the try-except below were already done inside # .create_dn_obj() above # @@ -557,7 +557,7 @@ def create_add_rv_pgm( # if user passes 'nr' as a node, # e.g., 'n_rep <- 2', then 'normal(...nr=n_rep...) - elif isinstance(spec_dict["nr"][0], pgm.StochasticNodePGM): + elif isinstance(spec_dict["nr"][0], pgm.StochasticNodeDAG): # this check and the try-except below were already done inside # .create_dn_obj() above # @@ -593,7 +593,7 @@ def parse_deterministic_function_assignment( det_node_fn_spec: str, cmd_line: str) -> None: """ - Create DeterministicNodePGM instance from command string with + Create DeterministicNodeDAG instance from command string with ':=' operator, then add it to ProbabilisticGraphiclModel instance. 
This node is not sampled (not a random variable) and is deterministically initialized via a deterministic function call. @@ -611,25 +611,25 @@ def parse_deterministic_function_assignment( def create_add_det_nd_pgm(det_nd_name: str, det_obj: ty.Any, - parent_pgm_nodes: ty.List[pgm.NodePGM]): + parent_pgm_nodes: ty.List[pgm.NodeDAG]): - det_nd_pgm = pgm.DeterministicNodePGM(det_nd_name, + det_nd_pgm = pgm.DeterministicNodeDAG(det_nd_name, value=det_obj, parent_nodes=parent_pgm_nodes) dag_obj.add_node(det_nd_pgm) - # deterministic node is of class DeterministicNodePGM, which - # derives NodePGM -- we do not need to initialize NodePGM - # as when a new StochasticNodePGM is created (see above in + # deterministic node is of class DeterministicNodeDAG, which + # derives NodeDAG -- we do not need to initialize NodeDAG + # as when a new StochasticNodeDAG is created (see above in # create_add_rv_pgm()) # # this check is just to make sure we are adding a class - # deriving from NodePGM - # if isinstance(det_obj, NodePGM): + # deriving from NodeDAG + # if isinstance(det_obj, NodeDAG): # det_obj.node_name = det_nd_name # det_obj.parent_nodes = parent_pgm_nodes - # det_nd_pgm = det_obj - # dag_obj.add_node(det_nd_pgm) + # det_nd_dag = det_obj + # dag_obj.add_node(det_nd_dag) if re.search(cmdu.sampling_dn_spec_regex, det_node_fn_spec) is None: raise ec.ScriptSyntaxError( @@ -906,7 +906,7 @@ def create_add_det_nd_pgm(det_nd_name: str, # looking at dag nodes for node_name, node_dag in dag.name_node_dict.items(): - if isinstance(node_dag, pgm.StochasticNodePGM): + if isinstance(node_dag, pgm.StochasticNodeDAG): if isinstance(node_dag.value[0], pjtr.AnnotatedTree): print(node_dag.value[0].tree.as_string(schema="newick")) diff --git a/src/phylojunction/interface/cmdbox/cmd_parse_utils.py b/src/phylojunction/interface/cmdbox/cmd_parse_utils.py index a3dd925..10c4eb6 100644 --- a/src/phylojunction/interface/cmdbox/cmd_parse_utils.py +++ b/src/phylojunction/interface/cmdbox/cmd_parse_utils.py @@ -40,7 +40,7 @@ def val_or_obj(dag_obj: pgm.DirectedAcyclicGraph, - val: ty.List[str]) -> ty.List[ty.Union[pgm.NodePGM, str]]: + val: ty.List[str]) -> ty.List[ty.Union[pgm.NodeDAG, str]]: """Return list of strings with values or node names. Checks if provided values are directly accessible as values @@ -62,16 +62,17 @@ def val_or_obj(dag_obj: pgm.DirectedAcyclicGraph, strings) and/or stochastic node objects. 
""" - val_or_obj_list: ty.List[ty.Union[pgm.NodePGM, str]] = [] + val_or_obj_list: ty.List[ty.Union[pgm.NodeDAG, str]] = list() for v in val: if isinstance(v, str): # checking if string could potentially be a node object that # has a name if re.match(character_value_regex, v): + # if it does find a node with that name, we add the node object try: - # appending StochasticNodePGM + # appending StochasticNodeDAG val_or_obj_list.append(dag_obj.name_node_dict[v]) except KeyError: @@ -90,18 +91,15 @@ def parse_spec( fn_spec_str: str, cmd_line: str) \ -> ty.Tuple[ - ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]], - ty.List[pgm.NodePGM]]: + ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]], + ty.List[pgm.NodeDAG]]: spec_dict: ty.Dict[str, str] = tokenize_fn_spec(fn_spec_str, cmd_line) - # debugging - # print("spec_dict = ", spec_dict) - spec_dict_return: \ - ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]] = dict() + ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]] = dict() - parent_pgm_nodes: ty.List[pgm.NodePGM] = [] + parent_pgm_nodes: ty.List[pgm.NodeDAG] = [] for param_name, an_arg in spec_dict.items(): ############# @@ -112,28 +110,27 @@ def parse_spec( # if argument is a list if re.match(vector_value_regex, an_arg): - # print("\n\n parsing arg " + str(an_arg) + " as vector") arg_list = parse_val_vector(an_arg) + # if scalar variable else: - # print("\n\n parsing arg " + str(an_arg) + " as str") arg_list.append(an_arg) val_obj_list = val_or_obj(dag_obj, arg_list) spec_dict_return[param_name] = val_obj_list # { param_name_str: # [ number_or_quoted_str_str1, - # a_NodePGM1, + # a_NodeDAG1, # number_or_quoted_str_str2, - # aNodePGM2 + # aNodeDAG2 # ] # } for vo in val_obj_list: - if isinstance(vo, pgm.NodePGM): + if isinstance(vo, pgm.NodeDAG): parent_pgm_nodes.append(vo) - # values in spec_dict will be lists of strings or lists of NodePGMs + # values in spec_dict will be lists of strings or lists of NodeDAGs return spec_dict_return, parent_pgm_nodes diff --git a/src/phylojunction/interface/cmdbox/cmd_parse_utils.pyi b/src/phylojunction/interface/cmdbox/cmd_parse_utils.pyi index d5f66a7..8055908 100644 --- a/src/phylojunction/interface/cmdbox/cmd_parse_utils.pyi +++ b/src/phylojunction/interface/cmdbox/cmd_parse_utils.pyi @@ -14,7 +14,7 @@ sampled_as_regex: Incomplete sampling_dn_spec_regex: Incomplete deterministic_regex: Incomplete -def val_or_obj(dag_obj: pgm.DirectedAcyclicGraph, val: ty.List[str]) -> ty.List[ty.Union[pgm.NodePGM, str]]: ... -def parse_spec(dag_obj: pgm.DirectedAcyclicGraph, fn_spec_str: str, cmd_line: str) -> ty.Tuple[ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]], ty.List[pgm.NodePGM]]: ... +def val_or_obj(dag_obj: pgm.DirectedAcyclicGraph, val: ty.List[str]) -> ty.List[ty.Union[pgm.NodeDAG, str]]: ... +def parse_spec(dag_obj: pgm.DirectedAcyclicGraph, fn_spec_str: str, cmd_line: str) -> ty.Tuple[ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]], ty.List[pgm.NodeDAG]]: ... def parse_val_vector(vec_str: str) -> ty.List[str]: ... def tokenize_fn_spec(fn_spec_str: str, cmd_line: str) -> ty.Dict[str, str]: ... 
diff --git a/src/phylojunction/interface/grammar/ct_fn_grammar.py b/src/phylojunction/interface/grammar/ct_fn_grammar.py index e54f0c0..2e7053e 100644 --- a/src/phylojunction/interface/grammar/ct_fn_grammar.py +++ b/src/phylojunction/interface/grammar/ct_fn_grammar.py @@ -30,7 +30,7 @@ def grammar_check(cls, ct_fn_id: str, fn_param: str) -> bool: def init_return_ann_tr( cls, ct_fn_param_dict: - ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) \ + ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) \ -> pgm.ConstantFn: if not ct_fn_param_dict: @@ -49,7 +49,7 @@ def init_return_ann_tr( def create_ct_fn_obj( cls, ct_fn_id: str, - ct_fn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) \ + ct_fn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) \ -> ty.Optional[ # ty.Union[ pgm.ConstantFn @@ -62,7 +62,7 @@ def create_ct_fn_obj( ct_fn_id (str): Name of constant function to being called ct_fn_param_dict (dict): Dictionary containing constant function parameter names (str) as keys and lists (of either - strings or NodePGMs) as values + strings or NodeDAGs) as values Returns: Object: one of a variety of objects to be stored within a clamped diff --git a/src/phylojunction/interface/grammar/ct_fn_treereader_makers.py b/src/phylojunction/interface/grammar/ct_fn_treereader_makers.py index 875d603..3c23a42 100644 --- a/src/phylojunction/interface/grammar/ct_fn_treereader_makers.py +++ b/src/phylojunction/interface/grammar/ct_fn_treereader_makers.py @@ -12,7 +12,7 @@ def make_tree_reader(ct_fn_name: str, ct_fn_param_dict: \ - ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) \ + ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) \ -> pgm.ConstantFn: ############################# # IMPORTANT: Default values # ############################# @@ -26,9 +26,9 @@ def make_tree_reader(ct_fn_name: str, # args were already grammar-checked in constant_fn_grammar for arg, val in ct_fn_param_dict.items(): # if element in val is string, it remains unchanged, - # if NodePGM, we get its string-fied value + # if NodeDAG, we get its string-fied value extracted_val: ty.List[str] = \ - pgm.extract_value_from_nodepgm(val) + pgm.extract_vals_as_str_from_node_dag(val) if True: if arg == "n": diff --git a/src/phylojunction/interface/grammar/det_fn_discrete_sse_makers.py b/src/phylojunction/interface/grammar/det_fn_discrete_sse_makers.py index ccd520c..06f7073 100644 --- a/src/phylojunction/interface/grammar/det_fn_discrete_sse_makers.py +++ b/src/phylojunction/interface/grammar/det_fn_discrete_sse_makers.py @@ -9,12 +9,12 @@ __email__ = "f.mendes@wustl.edu" -def extract_value_from_pgmnodes(pgm_node_list: ty.List[pgm.NodePGM]) \ +def extract_value_from_pgmnodes(dag_node_list: ty.List[pgm.NodeDAG]) \ -> ty.List[float]: """_summary_ Args: - pgm_node_list (NodePGM): List of NodePGM objects (typing includes str because of type-safety) + dag_node_list (NodeDAG): List of NodeDAG objects (typing includes str because of type-safety) Raises: ec.NoPlatingAllowedError: _description_ @@ -23,37 +23,37 @@ def extract_value_from_pgmnodes(pgm_node_list: ty.List[pgm.NodePGM]) \ Returns: ty.List[ty.List[float]]: _description_ """ - many_nodes_pgm = len(pgm_node_list) > 1 + many_nodes_dag = len(dag_node_list) > 1 v_list: ty.List[float] = [] - for node_pgm in pgm_node_list: + for node_dag in dag_node_list: # so mypy won't complain - if isinstance(node_pgm, pgm.NodePGM): + if isinstance(node_dag, pgm.NodeDAG): # no plating supported - if node_pgm.repl_size > 1: + if node_dag.repl_size > 1: raise ec.NoPlatingAllowedError( - "sse_rate",
node_pgm.node_name) + "sse_rate", node_dag.node_name) - v = node_pgm.value # list (I think before I also + v = node_dag.value # list (I think before I also # allowed numpy.ndarray, but not anymore) # so mypy won't complain if isinstance(v, list): - if len(v) > 1 and many_nodes_pgm: + if len(v) > 1 and many_nodes_dag: raise ec.StateDependentParameterMisspec( message=( ("If many variables are passed as arguments " "to initialize another variable, each of these " "variables can contain only a single value"))) - elif len(v) == 1 and many_nodes_pgm: + elif len(v) == 1 and many_nodes_dag: # making list longer # (v should be a list, which is why I don't use append) v_list += v - elif len(v) >= 1 and not many_nodes_pgm: + elif len(v) >= 1 and not many_nodes_dag: return v return v_list @@ -62,7 +62,7 @@ def extract_value_from_pgmnodes(pgm_node_list: ty.List[pgm.NodePGM]) \ def make_DiscreteStateDependentRate( det_fn_name: str, det_fn_param_dict: - ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) \ + ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) \ -> sseobj.DiscreteStateDependentRate: """ Create and return DiscreteStateDependentRate as prompted by @@ -71,7 +71,7 @@ def make_DiscreteStateDependentRate( Args: det_fn_name (str): Name of the function being called det_fn_param_dict (dict): dictionary containing parameter - strings as keys, and lists of either strings or NodePGMs + strings as keys, and lists of either strings or NodeDAGs as value(s) Returns: @@ -100,10 +100,10 @@ def make_DiscreteStateDependentRate( if arg == "value": # val is a list of random variable objects # if type(val[0]) != str: - if isinstance(val[0], pgm.NodePGM): + if isinstance(val[0], pgm.NodeDAG): # need to declare cast_val separately so mypy won't complain - cast_val1: ty.List[pgm.NodePGM] = \ - ty.cast(ty.List[pgm.NodePGM], val) + cast_val1: ty.List[pgm.NodeDAG] = \ + ty.cast(ty.List[pgm.NodeDAG], val) value = extract_value_from_pgmnodes(cast_val1) # val is a list of strings @@ -194,7 +194,7 @@ def make_DiscreteStateDependentRate( def make_DiscreteStateDependentProbability( det_fn_name: str, det_fn_param_dict: - ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) \ + ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) \ -> sseobj.DiscreteStateDependentProbability: """ Create and return DiscreteStateDependentProbability as prompted by @@ -203,7 +203,7 @@ def make_DiscreteStateDependentProbability( Args: det_fn_name (str): Name of the function being called det_fn_param_dict (dict): dictionary containing parameter - strings as keys, and lists of either strings or NodePGMs + strings as keys, and lists of either strings or NodeDAGs as value(s) Returns: @@ -231,10 +231,10 @@ def make_DiscreteStateDependentProbability( if arg == "value": # val is a list of random variable objects # if type(val[0]) != str: - if isinstance(val[0], pgm.NodePGM): + if isinstance(val[0], pgm.NodeDAG): # need to declare cast_val separately so mypy won't complain - cast_val1: ty.List[pgm.NodePGM] = \ - ty.cast(ty.List[pgm.NodePGM], val) + cast_val1: ty.List[pgm.NodeDAG] = \ + ty.cast(ty.List[pgm.NodeDAG], val) value = extract_value_from_pgmnodes(cast_val1) # val is a list of strings @@ -291,7 +291,7 @@ def make_DiscreteStateDependentProbability( def make_SSEStash( det_fn_name: str, det_fn_param_dict: - ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) \ + ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) \ -> sseobj.SSEStash: """ Create SSEStash as prompted by deterministic function call @@ -299,7 +299,7 @@ def make_SSEStash( Args: det_fn_name 
(str): Name of the function being called det_fn_param_dict (dict): dictionary containing parameter - strings as keys, and lists of either strings or NodePGMs + strings as keys, and lists of either strings or NodeDAGs Return: Object holding all discrete state-dependent rates and @@ -310,8 +310,8 @@ def make_SSEStash( n_time_slices: int = 1 time_slice_age_ends: ty.List[float] = [] seed_age_for_time_slicing: ty.Optional[float] = None - flat_state_dep_rate_mat: ty.List[pgm.DeterministicNodePGM] = [] - flat_state_dep_prob_mat: ty.List[pgm.DeterministicNodePGM] = [] + flat_state_dep_rate_mat: ty.List[pgm.DeterministicNodeDAG] = [] + flat_state_dep_prob_mat: ty.List[pgm.DeterministicNodeDAG] = [] ############################################# # Reading all arguments and checking health # @@ -322,7 +322,7 @@ def make_SSEStash( first_element = val[0] extracted_value = first_element # can be scalar or container - if isinstance(first_element, pgm.NodePGM) \ + if isinstance(first_element, pgm.NodeDAG) \ and len(val) == 1: extracted_value = first_element.value @@ -374,7 +374,7 @@ def make_SSEStash( if det_fn_param_dict["flat_prob_mat"]: flat_state_dep_prob_mat = \ [v for v in det_fn_param_dict["flat_prob_mat"] - if isinstance(v, pgm.DeterministicNodePGM)] + if isinstance(v, pgm.DeterministicNodeDAG)] # total number of rates has to be divisible by number of slices # if len(flat_state_dep_prob_mat) % n_time_slices != 0: @@ -383,10 +383,10 @@ def make_SSEStash( elif arg == "flat_rate_mat": if det_fn_param_dict["flat_rate_mat"]: - # list of NodePGM's + # list of NodeDAG's flat_state_dep_rate_mat = \ [v for v in det_fn_param_dict["flat_rate_mat"] - if isinstance(v, pgm.DeterministicNodePGM)] + if isinstance(v, pgm.DeterministicNodeDAG)] # total number of rates has to be divisible by number of slices # if len(flat_state_dep_rate_mat) % n_time_slices != 0: diff --git a/src/phylojunction/interface/grammar/det_fn_discrete_sse_makers.pyi b/src/phylojunction/interface/grammar/det_fn_discrete_sse_makers.pyi index 0c6d902..8b29285 100644 --- a/src/phylojunction/interface/grammar/det_fn_discrete_sse_makers.pyi +++ b/src/phylojunction/interface/grammar/det_fn_discrete_sse_makers.pyi @@ -2,6 +2,6 @@ import phylojunction.pgm.pgm as pgm import phylojunction.calculation.discrete_sse as sseobj import typing as ty -def make_DiscreteStateDependentRate(det_fn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) -> sseobj.DiscreteStateDependentRate: ... -def make_DiscreteStateDependentProbability(det_fn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) -> sseobj.DiscreteStateDependentProbability: ... -def make_SSEStash(det_fn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) -> sseobj.SSEStash: ... +def make_DiscreteStateDependentRate(det_fn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) -> sseobj.DiscreteStateDependentRate: ... +def make_DiscreteStateDependentProbability(det_fn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) -> sseobj.DiscreteStateDependentProbability: ... +def make_SSEStash(det_fn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) -> sseobj.SSEStash: ... 
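The value-extraction rules in extract_value_from_pgmnodes (now taking dag_node_list) are spread over several hunks above, so here is a condensed, runnable sketch of the same logic. FakeNodeDAG is a hypothetical stand-in for pgm.NodeDAG with only the attributes the function reads, and plain ValueError replaces the PJ exception classes.

```python
# Sketch of the extract_value_from_pgmnodes rules (hypothetical stand-ins).
import typing as ty

class FakeNodeDAG:
    def __init__(self, node_name: str, value: ty.List[float], repl_size: int = 1) -> None:
        self.node_name = node_name
        self.value = value
        self.repl_size = repl_size

def extract_value_sketch(dag_node_list: ty.List[FakeNodeDAG]) -> ty.List[float]:
    many_nodes_dag = len(dag_node_list) > 1
    v_list: ty.List[float] = []

    for node_dag in dag_node_list:
        # no plating supported for state-dependent parameters
        # (real code raises ec.NoPlatingAllowedError)
        if node_dag.repl_size > 1:
            raise ValueError("no plating allowed: " + node_dag.node_name)

        v = node_dag.value
        if len(v) > 1 and many_nodes_dag:
            # each of several argument nodes may hold only a single value
            # (real code raises ec.StateDependentParameterMisspec)
            raise ValueError("only one value per node when many nodes are passed")
        elif len(v) == 1 and many_nodes_dag:
            v_list += v   # one value per node, concatenated in order
        elif len(v) >= 1 and not many_nodes_dag:
            return v      # a single node: its whole value list is returned

    return v_list

print(extract_value_sketch([FakeNodeDAG("la", [1.0]), FakeNodeDAG("mu", [0.5])]))  # [1.0, 0.5]
```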
diff --git a/src/phylojunction/interface/grammar/det_fn_grammar.py b/src/phylojunction/interface/grammar/det_fn_grammar.py index 0e5b64d..6cf9acc 100644 --- a/src/phylojunction/interface/grammar/det_fn_grammar.py +++ b/src/phylojunction/interface/grammar/det_fn_grammar.py @@ -34,7 +34,7 @@ def grammar_check(cls, det_fn_id: str, fn_param: str) -> bool: def init_return_state_dep_rate( cls, det_fn_param_dict: - ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) \ + ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) \ -> sseobj.DiscreteStateDependentRate: if not det_fn_param_dict: @@ -59,7 +59,7 @@ def init_return_state_dep_rate( def init_return_state_dep_prob( cls, det_fn_param_dict: - ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) \ + ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) \ -> sseobj.DiscreteStateDependentProbability: if not det_fn_param_dict: @@ -83,7 +83,7 @@ def init_return_state_dep_prob( def init_return_sse_stash( cls, det_fn_param_dict: - ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) \ + ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) \ -> sseobj.SSEStash: if not det_fn_param_dict: @@ -108,7 +108,7 @@ def init_return_sse_stash( def create_det_fn_obj( cls, det_fn_id: str, - det_fn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) \ + det_fn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) \ -> ty.Optional[ ty.Union[ sseobj.DiscreteStateDependentRate, @@ -122,7 +122,7 @@ def create_det_fn_obj( det_fn_id (str): Name of deterministic function being called det_fn_param_dict (dict): Dictionary containing deterministic function parameter names (str) as keys and lists (of either - strings or NodePGMs) as values + strings or NodeDAGs) as values Returns: Object: one of a variety of objects containing information for diff --git a/src/phylojunction/interface/grammar/det_fn_grammar.pyi b/src/phylojunction/interface/grammar/det_fn_grammar.pyi index ccc1ef8..89b18db 100644 --- a/src/phylojunction/interface/grammar/det_fn_grammar.pyi +++ b/src/phylojunction/interface/grammar/det_fn_grammar.pyi @@ -7,8 +7,8 @@ class PJDetFnGrammar: @classmethod def grammar_check(cls, det_fn_id: str, fn_param: str) -> bool: ... @classmethod - def init_return_state_dep_rate(cls, det_fn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) -> sseobj.DiscreteStateDependentRate: ... + def init_return_state_dep_rate(cls, det_fn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) -> sseobj.DiscreteStateDependentRate: ... @classmethod - def init_return_sse_stash(cls, det_fn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) -> ty.Tuple[sseobj.MacroevolEventHandler, sseobj.DiscreteStateDependentProbabilityHandler]: ... + def init_return_sse_stash(cls, det_fn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) -> ty.Tuple[sseobj.MacroevolEventHandler, sseobj.DiscreteStateDependentProbabilityHandler]: ... @classmethod - def create_det_fn_obj(cls, det_fn_id: str, det_fn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) -> ty.Optional[ty.Union[sseobj.DiscreteStateDependentRate, sseobj.MacroevolEventHandler]]: ... + def create_det_fn_obj(cls, det_fn_id: str, det_fn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) -> ty.Optional[ty.Union[sseobj.DiscreteStateDependentRate, sseobj.MacroevolEventHandler]]: ... 
diff --git a/src/phylojunction/interface/grammar/dn_discrete_sse_makers.py b/src/phylojunction/interface/grammar/dn_discrete_sse_makers.py index ee56719..45f9583 100644 --- a/src/phylojunction/interface/grammar/dn_discrete_sse_makers.py +++ b/src/phylojunction/interface/grammar/dn_discrete_sse_makers.py @@ -13,7 +13,7 @@ def make_discrete_SSE_dn( dn_name: str, - dn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) \ + dn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) \ -> pgm.DistributionPGM: ############################# @@ -67,33 +67,33 @@ def make_discrete_SSE_dn( # if element in val is string: remains unchanged # - # if StochasticNodePGM: we get its string-fied value - extracted_val = pgm.extract_value_from_nodepgm(val) + # if StochasticNodeDAG: we get its string-fied value + extracted_val = pgm.extract_vals_as_str_from_node_dag(val) ############################ # Non-vectorized arguments # ############################ # ... thus using only the first value! - first_val: ty.Union[str, pgm.NodePGM] + first_val: ty.Union[str, pgm.NodeDAG] if len(extracted_val) >= 1: first_val = extracted_val[0] - # if DeterministicNodePGM is in val - # e.g., val = [pgm.DeterministicNodePGM] + # if DeterministicNodeDAG is in val + # e.g., val = [pgm.DeterministicNodeDAG] else: first_val = val[0] - if arg == "stash" and isinstance(first_val, pgm.NodePGM): - nodepgm_val = first_val.value + if arg == "stash" and isinstance(first_val, pgm.NodeDAG): + node_dag_val = first_val.value - if isinstance(nodepgm_val, sseobj.SSEStash): - stash = nodepgm_val + if isinstance(node_dag_val, sseobj.SSEStash): + stash = node_dag_val # SSEStash will return None if prob_handler # wasn't created by user through script - # prob_handler = nodepgm_val.get_prob_handler() + # prob_handler = node_dag_val.get_prob_handler() elif arg in ("n", "nr", "runtime_limit", "min_rec_taxa", "max_rec_taxa", "abort_at_obs"): @@ -284,14 +284,14 @@ def make_discrete_SSE_dn( sse_stash = sseobj.SSEStash(event_handler) - # det_nd_pgm = pgm.DeterministicNodePGM( + # det_nd_dag = pgm.DeterministicNodeDAG( # "events", value=event_handler, parent_nodes=None) - det_nd_pgm = pgm.DeterministicNodePGM( + det_nd_pgm = pgm.DeterministicNodeDAG( "sse_stash", value=sse_stash, parent_nodes=None) - dn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]] = dict() + dn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]] = dict() dn_param_dict["n"] = ["1"] dn_param_dict["nr"] = ["1"] dn_param_dict["stash"] = [det_nd_pgm] diff --git a/src/phylojunction/interface/grammar/dn_grammar.py b/src/phylojunction/interface/grammar/dn_grammar.py index bfe90a8..497ae52 100644 --- a/src/phylojunction/interface/grammar/dn_grammar.py +++ b/src/phylojunction/interface/grammar/dn_grammar.py @@ -52,7 +52,7 @@ def grammar_check(cls, dn_id: str, dn_param: str) -> bool: def init_return_parametric_dn( cls, dn_id: str, - dn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) \ + dn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) \ -> pgm.DistributionPGM: """Create and return parametric distributions for sampling. 
@@ -71,7 +71,7 @@ def init_return_parametric_dn( # key: parameter name # value: paramenter's parent (DAG node) name - # e.g., { lambda: node_pgm1_name } + # e.g., { lambda: node_dag1_name } parent_node_tracker: ty.Dict[str, str] = dict() if not dn_param_dict: @@ -84,19 +84,20 @@ def init_return_parametric_dn( ln_sd: ty.List[float] = list() ln_log_space: bool = True - # { mean: node_pgm1_name, sd: node_pgm2_name, ... } + # { mean: node_dag1_name, sd: node_dag2_name, ... } if dn_param_dict: # val is list for arg, val in dn_param_dict.items(): # val = val[0] # TODO: deal with vectorization later # needed for building inference specifications - if isinstance(val[0], pgm.StochasticNodePGM): + if isinstance(val[0], pgm.StochasticNodeDAG): parent_node_tracker[arg] = val[0].node_name # if element in val is string, it remains unchanged, - # if NodePGM, we get its string-fied value - extracted_val_list = pgm.extract_value_from_nodepgm(val) + # if NodeDAG, we get its string-fied value + extracted_val_list = \ + pgm.extract_vals_as_str_from_node_dag(val) if not cls.grammar_check("lognormal", arg): raise ec.ParseNotAParameterError(arg) @@ -170,16 +171,16 @@ def init_return_parametric_dn( norm_mean: ty.List[float] = list() norm_sd: ty.List[float] = list() - # { mean: node_pgm1_name, sd: node_pgm2_name, ... } + # { mean: node_dag1_name, sd: node_dag2_name, ... } if dn_param_dict: # val is list for arg, val in dn_param_dict.items(): # needed for building inference specifications - if isinstance(val[0], pgm.StochasticNodePGM): + if isinstance(val[0], pgm.StochasticNodeDAG): parent_node_tracker[arg] = val[0].node_name - extracted_val_list = pgm.extract_value_from_nodepgm(val) + extracted_val_list = pgm.extract_vals_as_str_from_node_dag(val) if not cls.grammar_check("normal", arg): raise ec.ParseNotAParameterError(arg) @@ -246,11 +247,11 @@ def init_return_parametric_dn( for arg, val in dn_param_dict.items(): # needed for building inference specifications - if isinstance(val[0], pgm.StochasticNodePGM): + if isinstance(val[0], pgm.StochasticNodeDAG): # needed for building inference specifications parent_node_tracker[arg] = val[0].node_name - extracted_val_list = pgm.extract_value_from_nodepgm(val) + extracted_val_list = pgm.extract_vals_as_str_from_node_dag(val) if not cls.grammar_check("exponential", arg): raise ec.ParseNotAParameterError(arg) @@ -297,16 +298,16 @@ def init_return_parametric_dn( gamma_scale_or_rate: ty.List[float] = [] gamma_rate_parameterization: bool = False - # { mean: node_pgm1_name, sd: node_pgm2_name, ... } + # { mean: node_dag1_name, sd: node_dag2_name, ... } if dn_param_dict: # val is list for arg, val in dn_param_dict.items(): # needed for building inference specifications - if isinstance(val[0], pgm.StochasticNodePGM): + if isinstance(val[0], pgm.StochasticNodeDAG): parent_node_tracker[arg] = val[0].node_name - extracted_val_list = pgm.extract_value_from_nodepgm(val) + extracted_val_list = pgm.extract_vals_as_str_from_node_dag(val) if not cls.grammar_check("gamma", arg): raise ec.ParseNotAParameterError(arg) @@ -380,15 +381,15 @@ def init_return_parametric_dn( unif_min: ty.List[float] = [] unif_max: ty.List[float] = [] - # { mean: node_pgm1_name, sd: node_pgm2_name, ... } + # { mean: node_dag1_name, sd: node_dag2_name, ... 
} if dn_param_dict: # val is list for arg, val in dn_param_dict.items(): - if isinstance(val[0], pgm.StochasticNodePGM): + if isinstance(val[0], pgm.StochasticNodeDAG): # needed for building inference specifications parent_node_tracker[arg] = val[0].node_name - extracted_val_list = pgm.extract_value_from_nodepgm(val) + extracted_val_list = pgm.extract_vals_as_str_from_node_dag(val) if not cls.grammar_check("unif", arg): raise ec.ParseNotAParameterError(arg) @@ -449,7 +450,7 @@ def init_return_parametric_dn( @classmethod def init_return_discrete_SSE_dn( cls, - dn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) \ + dn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) \ -> pgm.DistributionPGM: """Create and return SSE distribution for sampling. @@ -470,7 +471,7 @@ def init_return_discrete_SSE_dn( def create_dn_obj( cls, dn_id: str, - dn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) \ + dn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) \ -> pgm.DistributionPGM: """Create and return prob. distribution (for sampling) object. diff --git a/src/phylojunction/interface/grammar/dn_grammar.pyi b/src/phylojunction/interface/grammar/dn_grammar.pyi index 80ba6df..6eec361 100644 --- a/src/phylojunction/interface/grammar/dn_grammar.pyi +++ b/src/phylojunction/interface/grammar/dn_grammar.pyi @@ -7,8 +7,8 @@ class PJDnGrammar: @classmethod def grammar_check(cls, dn_id: str, dn_param: str) -> bool: ... @classmethod - def init_return_parametric_dn(cls, dn_id: str, dn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) -> pgm.DistributionPGM: ... + def init_return_parametric_dn(cls, dn_id: str, dn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) -> pgm.DistributionPGM: ... @classmethod - def init_return_discrete_SSE_dn(cls, dn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) -> pgm.DistributionPGM: ... + def init_return_discrete_SSE_dn(cls, dn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) -> pgm.DistributionPGM: ... @classmethod - def create_dn_obj(cls, dn_id: str, dn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodePGM]]]) -> pgm.DistributionPGM: ... + def create_dn_obj(cls, dn_id: str, dn_param_dict: ty.Dict[str, ty.List[ty.Union[str, pgm.NodeDAG]]]) -> pgm.DistributionPGM: ... 
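Taken together with the rb_dn_parametric.py hunks earlier in this patch, the parent_node_tracker dictionary is the thread connecting simulation and inference: dn_grammar records {PJ argument name: parent DAG node name} whenever an argument's value is a StochasticNodeDAG, and the .Rev writers later prefer that node name over the simulated value. The sketch below illustrates that round trip under stated assumptions: ToyStochasticNodeDAG is a hypothetical stand-in, and the "dnExponential(" prefix only stands for whatever .Rev string the exponential writer assembles.

```python
# Sketch of how parent_node_tracker is built (dn_grammar) and consumed
# (rb_dn_parametric); stand-in class and illustrative .Rev prefix.
import typing as ty

class ToyStochasticNodeDAG:
    def __init__(self, node_name: str, value: ty.List[str]) -> None:
        self.node_name = node_name
        self.value = value

def track_parents(
        dn_param_dict: ty.Dict[str, ty.List[ty.Union[str, ToyStochasticNodeDAG]]]) \
        -> ty.Dict[str, str]:
    # key: arg in PJ syntax, value: name of the DAG node passed as that arg
    parent_node_tracker: ty.Dict[str, str] = dict()
    for arg, val in dn_param_dict.items():
        if isinstance(val[0], ToyStochasticNodeDAG):
            parent_node_tracker[arg] = val[0].node_name  # e.g., {"rate": "r"}
    return parent_node_tracker

def exponential_rev_term(ith_sim: int,
                         scale_or_rate_list: ty.List[str],
                         parent_node_tracker: ty.Dict[str, str]) -> str:
    ith_sim_str = "dnExponential("  # illustrative prefix
    try:
        # if a DAG node holds the rate, its name goes into the .Rev string
        ith_sim_str += parent_node_tracker["rate"]
    except KeyError:
        # otherwise the simulated value itself is written out
        ith_sim_str += str(scale_or_rate_list[ith_sim])
    return ith_sim_str + ")"

tracker = track_parents({"rate": [ToyStochasticNodeDAG("r", ["0.8"])]})
print(exponential_rev_term(0, ["0.8"], tracker))   # dnExponential(r)
print(exponential_rev_term(0, ["0.8"], dict()))    # dnExponential(0.8)
```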
diff --git a/src/phylojunction/interface/pjcli/cli_plotting.py b/src/phylojunction/interface/pjcli/cli_plotting.py index 85fe9bb..df6f945 100644 --- a/src/phylojunction/interface/pjcli/cli_plotting.py +++ b/src/phylojunction/interface/pjcli/cli_plotting.py @@ -6,7 +6,7 @@ from phylojunction.interface.pysidegui.content_main_window \ import ContentGUIMainWindow from phylojunction.pgm.pgm import DirectedAcyclicGraph -from phylojunction.pgm.pgm import NodePGM +from phylojunction.pgm.pgm import NodeDAG import phylojunction.interface.cmdbox.cmd_parse as cmdp import phylojunction.plotting.pj_organize as pjorg import phylojunction.plotting.pj_draw as pjdraw @@ -50,13 +50,13 @@ def selected_node_plot_cli( fig_dir: str, fig_obj: Figure, fig_axes: Axes, - node_pgm: NodePGM, + node_dag: NodeDAG, prefix: str = "", sample_idx: int = 0, repl_idx: int = 0, repl_size: int = 1) -> None: """ - Plot pgm node on provided Axes object (fig_axes), intended + Plot DAG node on provided Axes object (fig_axes), intended to be scoped to pj_cli.execute_pj_script() then update canvas with new plot """ @@ -66,16 +66,16 @@ def selected_node_plot_cli( if prefix: outfile_path += prefix + "_" - outfile_path += node_pgm.node_name + str(sample_idx + 1) \ + outfile_path += node_dag.node_name + str(sample_idx + 1) \ + "_" + str(repl_idx + 1) print("outfile_path = " + outfile_path) # if stochastic or constant, value will be list - if isinstance(node_pgm.value, list): + if isinstance(node_dag.value, list): # if a tree - if isinstance(node_pgm.value[0], pjdt.AnnotatedTree): - node_pgm.plot_node( + if isinstance(node_dag.value[0], pjdt.AnnotatedTree): + node_dag.plot_node( fig_axes, sample_idx=sample_idx, repl_idx=repl_idx, @@ -83,7 +83,7 @@ def selected_node_plot_cli( # when not a tree else: - node_pgm.plot_node( + node_dag.plot_node( fig_axes, sample_idx=sample_idx, repl_size=repl_size) @@ -96,7 +96,7 @@ def selected_node_plot_cli( # deterministic node, value is an Object # else: - # print(type(node_pgm.value)) + # print(type(node_dag.value)) def call_node_plot_cli( @@ -112,8 +112,8 @@ def call_node_plot_cli( for node_name, range_tup in node_range_dict.items(): if node_name in dag_obj.name_node_dict: - node_pgm = dag_obj.get_node_dag_by_name(node_name) - repl_size = node_pgm.repl_size + node_dag = dag_obj.get_node_dag_by_name(node_name) + repl_size = node_dag.repl_size start_idx = 0 end_idx = 1 @@ -140,7 +140,7 @@ def call_node_plot_cli( fig_dir, fig_obj, fig_axes, - node_pgm, + node_dag, prefix=prefix, sample_idx=sample_idx, repl_idx=repl_idx, diff --git a/src/phylojunction/interface/pjcli/pj_cli.py b/src/phylojunction/interface/pjcli/pj_cli.py index 103452e..66948c5 100644 --- a/src/phylojunction/interface/pjcli/pj_cli.py +++ b/src/phylojunction/interface/pjcli/pj_cli.py @@ -54,9 +54,9 @@ def execute_pj_script( os.mkdir(output_dir) # debugging (looking at model) - # for node_name, node_pgm in dag_obj.name_node_dict.items(): + # for node_name, node_dag in dag_obj.name_node_dict.items(): # print("\nnode name = " + node_name) - # print(node_pgm.value) + # print(node_dag.value) # Writing data # if write_data: diff --git a/src/phylojunction/interface/pysidegui/pj_gui.py b/src/phylojunction/interface/pysidegui/pj_gui.py index be6c575..53877e5 100644 --- a/src/phylojunction/interface/pysidegui/pj_gui.py +++ b/src/phylojunction/interface/pysidegui/pj_gui.py @@ -178,7 +178,7 @@ def __init__(self): # which makes it so spin buttons are not properly # initialized self.ui.ui_pages.node_list.itemClicked.connect( - lambda do_node: 
self.do_selected_node_pgm_page( + lambda do_node: self.do_selected_node_dag_page( spin_buttons_clicked=False)) # radio button update # @@ -364,19 +364,19 @@ def parse_cmd_update_gui(self): def selected_node_display( self, - node_pgm, + node_dag, do_all_samples, sample_idx=None, repl_idx=0, repl_size=1): - display_node_pgm_value_str = str() - display_node_pgm_stat_str = str() + display_node_dag_value_str = str() + display_node_dag_stat_str = str() is_tree = False # try: - if isinstance(node_pgm.value, list) and \ - isinstance(node_pgm.value[0], pjdt.AnnotatedTree): + if isinstance(node_dag.value, list) and \ + isinstance(node_dag.value[0], pjdt.AnnotatedTree): is_tree = True # if deterministic, not subscriptable @@ -391,16 +391,16 @@ def selected_node_display( end = start + repl_size # values - display_node_pgm_value_str = \ - node_pgm.get_start2end_str( + display_node_dag_value_str = \ + node_dag.get_start2end_str( start, end, repl_idx=repl_idx, is_tree=is_tree) # summary stats - display_node_pgm_stat_str = \ - node_pgm.get_node_stats_str( + display_node_dag_stat_str = \ + node_dag.get_node_stats_str( start, end, repl_idx) @@ -408,42 +408,42 @@ def selected_node_display( # we get all samples else: # just calling __str__ - display_node_pgm_value_str = \ + display_node_dag_value_str = \ self.gui_modeling.dag_obj \ - .get_display_str_by_name(node_pgm.node_name) + .get_display_str_by_name(node_dag.node_name) # getting all values - display_node_pgm_stat_str = \ - node_pgm.get_node_stats_str( - 0, len(node_pgm.value), repl_idx) # summary stats + display_node_dag_stat_str = \ + node_dag.get_node_stats_str( + 0, len(node_dag.value), repl_idx) # summary stats - # print("Set values_content QLineEdit widget with text: " + display_node_pgm_value_str) - # print("Set summary_content QLineEdit widget with text: " + display_node_pgm_stat_str) - self.ui.ui_pages.values_content.setText(display_node_pgm_value_str) - self.ui.ui_pages.summary_content.setText(display_node_pgm_stat_str) + # print("Set values_content QLineEdit widget with text: " + display_node_dag_value_str) + # print("Set summary_content QLineEdit widget with text: " + display_node_dag_stat_str) + self.ui.ui_pages.values_content.setText(display_node_dag_value_str) + self.ui.ui_pages.summary_content.setText(display_node_dag_stat_str) def selected_node_plot( self, fig_obj, fig_axes, - node_pgm, + node_dag, do_all_samples, sample_idx=None, repl_idx=0, repl_size=1): """ - Plot pgm node on 'node_display_fig_axes' (Axes object) scoped to 'call_gui()', + Plot DAG node on 'node_display_fig_axes' (Axes object) scoped to 'call_gui()', then update canvas with new plot """ # try: # if stochastic or constant, value will be list - if isinstance(node_pgm.value, list): + if isinstance(node_dag.value, list): # if a tree - if isinstance(node_pgm.value[0], pjdt.AnnotatedTree): - self.draw_node_pgm(fig_axes, - node_pgm, + if isinstance(node_dag.value[0], pjdt.AnnotatedTree): + self.draw_node_dag(fig_axes, + node_dag, sample_idx=sample_idx, repl_idx=repl_idx, repl_size=repl_size) @@ -451,13 +451,13 @@ def selected_node_plot( # when not a tree else: if do_all_samples: - self.draw_node_pgm(fig_axes, - node_pgm, + self.draw_node_dag(fig_axes, + node_dag, repl_size=repl_size) else: - self.draw_node_pgm(fig_axes, - node_pgm, + self.draw_node_dag(fig_axes, + node_dag, sample_idx=sample_idx, repl_size=repl_size) @@ -467,19 +467,19 @@ def selected_node_plot( # deterministic node else: - self.draw_node_pgm(fig_axes, node_pgm) + self.draw_node_dag(fig_axes, node_dag) 
fig_obj.canvas.draw() def selected_node_read(self, node_name: str): - node_pgm = self.gui_modeling.dag_obj.get_node_dag_by_name(node_name) + node_dag = self.gui_modeling.dag_obj.get_node_dag_by_name(node_name) # this is n_sim inside sampling distribution classes - sample_size = len(node_pgm) - repl_size = node_pgm.repl_size + sample_size = len(node_dag) + repl_size = node_dag.repl_size - return node_pgm, sample_size, repl_size + return node_dag, sample_size, repl_size - def do_selected_node_pgm_page(self, spin_buttons_clicked: bool = False): + def do_selected_node_dag_page(self, spin_buttons_clicked: bool = False): """ Display selected node's string representation and plot it on canvas if possible, for pgm page @@ -494,7 +494,7 @@ def do_selected_node_pgm_page(self, spin_buttons_clicked: bool = False): if selected_node_name not in ("", None): # reading node information # - node_pgm, sample_size, repl_size = \ + node_dag, sample_size, repl_size = \ self.selected_node_read(selected_node_name) # spin boxes must be up-to-date # @@ -513,7 +513,7 @@ def do_selected_node_pgm_page(self, spin_buttons_clicked: bool = False): # if spin buttons were clicked, no need # to updated radio and spin buttons if not spin_buttons_clicked: - self.init_and_refresh_radio_spin(node_pgm, sample_size, repl_size) + self.init_and_refresh_radio_spin(node_dag, sample_size, repl_size) do_all_samples = \ self.ui.ui_pages.all_samples_radio.isChecked() @@ -529,7 +529,7 @@ def do_selected_node_pgm_page(self, spin_buttons_clicked: bool = False): # Now do node # ############### - self.selected_node_display(node_pgm, + self.selected_node_display(node_dag, do_all_samples, sample_idx=sample_idx, repl_idx=repl_idx, @@ -537,7 +537,7 @@ def do_selected_node_pgm_page(self, spin_buttons_clicked: bool = False): self.selected_node_plot(fig_obj, fig_axes, - node_pgm, + node_dag, do_all_samples, sample_idx=sample_idx, repl_idx=repl_idx, @@ -566,7 +566,7 @@ def do_selected_node_compare_page(self): if selected_node_name not in ("", None): # reading node information # - selected_node_pgm, selected_node_sample_size, selected_node_repl_size = \ + selected_node_dag, selected_node_sample_size, selected_node_repl_size = \ self.selected_node_read(selected_node_name) # could be more efficient, but this @@ -599,7 +599,7 @@ def do_selected_node_compare_page(self): .loc[:, "program"] = "PJ" # scalar was selected # - if isinstance(selected_node_pgm.value[0], (int, float, np.float64)): + if isinstance(selected_node_dag.value[0], (int, float, np.float64)): self.ui.ui_pages.summary_stats_list.clear() if not self.is_avg_repl_check: @@ -612,7 +612,7 @@ def do_selected_node_compare_page(self): self.pj_comparison_df = scalar_repl_summary_df # tree was selected # - elif isinstance(selected_node_pgm.value[0], pjdt.AnnotatedTree): + elif isinstance(selected_node_dag.value[0], pjdt.AnnotatedTree): self.pj_comparison_df = tree_summary_df_dict[selected_node_name] self.ui.ui_pages.summary_stats_list.clear() self.ui.ui_pages.summary_stats_list.addItems( @@ -643,7 +643,7 @@ def do_selected_node_coverage_page(self): if selected_node_name not in ("", None): # reading node information # - selected_node_pgm, selected_node_sample_size, selected_node_repl_size = \ + selected_node_dag, selected_node_sample_size, selected_node_repl_size = \ self.selected_node_read(selected_node_name) # could be more efficient, but this @@ -662,12 +662,12 @@ def do_selected_node_coverage_page(self): tree_internal_nd_states_str_dict = tree_output_stash # str because could be constant set by 
hand - if isinstance(selected_node_pgm.value[0], (str, int, float, np.float64)): + if isinstance(selected_node_dag.value[0], (str, int, float, np.float64)): if selected_node_repl_size <= 1: self.ui.ui_pages.cov_summary_stats_list.clear() # sampled, non-deterministic - if selected_node_pgm.is_sampled: + if selected_node_dag.is_sampled: self.coverage_df = scalar_value_df_dict[selected_node_repl_size] # constant, non-deterministic @@ -877,14 +877,15 @@ def write_data_to_dir(self, prefix: str = ""): # drawing # ################# - def draw_node_pgm( + def draw_node_dag( self, axes, - node_pgm, + node_dag, sample_idx=None, repl_idx=0, repl_size=1) -> None: - return node_pgm.plot_node( + + return node_dag.plot_node( axes, sample_idx=sample_idx, repl_idx=repl_idx, @@ -908,14 +909,14 @@ def draw_violin(self): if selected_node_name not in ("", None): # reading node information # - node_pgm, sample_size, repl_size = \ + node_dag, sample_size, repl_size = \ self.selected_node_read(selected_node_name) if not self.pj_comparison_df.empty and \ not self.other_comparison_df.empty: # scalar - if isinstance(node_pgm.value[0], (int, float, np.float64)): + if isinstance(node_dag.value[0], (int, float, np.float64)): if not self.is_avg_repl_check: thing_to_compare = selected_node_name @@ -940,7 +941,7 @@ def draw_violin(self): pass # debugging - # print("node_pgm.is_sampled = " + str(node_pgm.is_sampled)) + # print("node_dag.is_sampled = " + str(node_dag.is_sampled)) # print(tabulate(self.pj_comparison_df, self.pj_comparison_df.head(), tablefmt="pretty", showindex=False).lstrip()) # print(tabulate(self.other_comparison_df, self.other_comparison_df.head(), tablefmt="pretty", showindex=False).lstrip()) @@ -952,7 +953,7 @@ def draw_violin(self): # something went wrong, # we clear comparison figure - if joint_dataframe.empty or not node_pgm.is_sampled: + if joint_dataframe.empty or not node_dag.is_sampled: # node list should only contain # sampled nodes already, but # just being sure... 
@@ -993,14 +994,14 @@ def draw_cov(self): if selected_node_name not in ("", None): # reading node information # - node_pgm, sample_size, repl_size = \ + node_dag, sample_size, repl_size = \ self.selected_node_read(selected_node_name) if not self.coverage_df.empty and \ not self.hpd_df.empty: # scalar - if isinstance(node_pgm.value[0], (str, int, float, np.float64)): + if isinstance(node_dag.value[0], (str, int, float, np.float64)): thing_to_validate = selected_node_name # debugging @@ -1040,7 +1041,7 @@ def draw_cov(self): # checking # ##################### - def init_and_refresh_radio_spin(self, node_pgm, sample_size, repl_size): + def init_and_refresh_radio_spin(self, node_dag, sample_size, repl_size): def _prepare_for_tree(potential_repl: bool = False): # radio # @@ -1106,9 +1107,9 @@ def _nothing_to_spin_through(): # can we even circulate through something # (basically: non-deterministic nodes) - if isinstance(node_pgm.value, list): - if isinstance(node_pgm.value[0], pjdt.AnnotatedTree): - if node_pgm.repl_size > 1: + if isinstance(node_dag.value, list): + if isinstance(node_dag.value[0], pjdt.AnnotatedTree): + if node_dag.repl_size > 1: _prepare_for_tree(potential_repl=True) else: @@ -1118,7 +1119,7 @@ def _nothing_to_spin_through(): else: if not self.ui.ui_pages.all_samples_radio.isEnabled() and \ not self.ui.ui_pages.one_sample_radio.isEnabled(): - if node_pgm.repl_size > 1: + if node_dag.repl_size > 1: _prepare_for_scalar(potential_repl=True) else: @@ -1133,7 +1134,7 @@ def _nothing_to_spin_through(): # irrespective of all samples or one sample # it should always be possible to check one # sample if there are replicates - if node_pgm.repl_size > 1: + if node_dag.repl_size > 1: # radio self.ui.ui_pages.one_sample_radio.setEnabled(True) self.ui.ui_pages.one_sample_radio.setCheckable(True) @@ -1170,7 +1171,7 @@ def _nothing_to_spin_through(): self.ui.ui_pages.sample_idx_spin.blockSignals(False) self.ui.ui_pages.repl_idx_spin.blockSignals(False) - # def init_and_refresh_radio_spin_working(self, node_pgm, sample_size, repl_size): + # def init_and_refresh_radio_spin_working(self, node_dag, sample_size, repl_size): # def _prepare_for_tree(): # # radio # @@ -1201,9 +1202,9 @@ def _nothing_to_spin_through(): # # Stochastic node # # ################### - # if node_pgm.is_sampled: + # if node_dag.is_sampled: # # tree # - # if isinstance(node_pgm.value[0], pjdt.AnnotatedTree): + # if isinstance(node_dag.value[0], pjdt.AnnotatedTree): # _prepare_for_tree() # # non-tree @@ -1271,9 +1272,9 @@ def _nothing_to_spin_through(): # # non-deterministic node because # # .value will not be list if so - # if isinstance(node_pgm.value, list): + # if isinstance(node_dag.value, list): # # tree # - # if isinstance(node_pgm.value[0], pjdt.AnnotatedTree): + # if isinstance(node_dag.value[0], pjdt.AnnotatedTree): # _prepare_for_tree() # # non-tree @@ -1363,12 +1364,12 @@ def refresh_node_lists(self): def refresh_selected_node_display_plot_radio(self): # if nodes have been created and selected # if self.ui.ui_pages.node_list.currentItem() is not None: - self.do_selected_node_pgm_page() + self.do_selected_node_dag_page() def refresh_selected_node_display_plot_spin(self): # if nodes have been created and selected # if self.ui.ui_pages.node_list.currentItem() is not None: - self.do_selected_node_pgm_page(spin_buttons_clicked=True) + self.do_selected_node_dag_page(spin_buttons_clicked=True) def refresh_cmd_history(self, user_reset=False): if user_reset: diff --git 
a/src/phylojunction/interface/pysimple_pjgui/pj_gui_pysimplegui.py b/src/phylojunction/interface/pysimple_pjgui/pj_gui_pysimplegui.py index 71c8c07..7329571 100644 --- a/src/phylojunction/interface/pysimple_pjgui/pj_gui_pysimplegui.py +++ b/src/phylojunction/interface/pysimple_pjgui/pj_gui_pysimplegui.py @@ -151,60 +151,60 @@ def clean_disable_everything(cmd_line_hist: str, msg: str) -> ty.Tuple[pgm.Direc return dag_obj, ax, comparison_ax, validation_ax - def draw_node_pgm(axes, node_pgm, sample_idx=None, repl_idx=0, repl_size=1): - return node_pgm.plot_node(axes, sample_idx=sample_idx, repl_idx=repl_idx, repl_size=repl_size) + def draw_node_dag(axes, node_dag, sample_idx=None, repl_idx=0, repl_size=1): + return node_dag.plot_node(axes, sample_idx=sample_idx, repl_idx=repl_idx, repl_size=repl_size) def selected_node_read(dag_obj, node_name): - node_pgm = dag_obj.get_node_dag_by_name(node_name) - # display_node_pgm_value_str = pg.get_display_str_by_name(node_name) - sample_size = len(node_pgm) # this is n_sim inside sampling distribution classes - repl_size = node_pgm.repl_size + node_dag = dag_obj.get_node_dag_by_name(node_name) + # display_node_dag_value_str = pg.get_display_str_by_name(node_name) + sample_size = len(node_dag) # this is n_sim inside sampling distribution classes + repl_size = node_dag.repl_size - return node_pgm, sample_size, repl_size + return node_dag, sample_size, repl_size - def selected_node_display(wdw, dag_obj, node_pgm, do_all_samples, sample_idx=None, repl_idx=0, repl_size=1): - display_node_pgm_value_str = str() - display_node_pgm_stat_str = str() + def selected_node_display(wdw, dag_obj, node_dag, do_all_samples, sample_idx=None, repl_idx=0, repl_size=1): + display_node_dag_value_str = str() + display_node_dag_stat_str = str() # first we do values # we care about a specific sample and maybe a specific replicate if not sample_idx == None and not do_all_samples: start = sample_idx * repl_size end = start + repl_size - display_node_pgm_value_str = node_pgm.get_start2end_str(start, end) # values - display_node_pgm_stat_str = node_pgm.get_node_stats_str(start, end, repl_idx) # summary stats + display_node_dag_value_str = node_dag.get_start2end_str(start, end) # values + display_node_dag_stat_str = node_dag.get_node_stats_str(start, end, repl_idx) # summary stats # we get all samples else: # just calling __str__ - display_node_pgm_value_str = dag_obj.get_display_str_by_name(node_pgm.node_name) + display_node_dag_value_str = dag_obj.get_display_str_by_name(node_dag.node_name) # getting all values - display_node_pgm_stat_str = node_pgm.get_node_stats_str(0, len(node_pgm.value), repl_idx) # summary stats + display_node_dag_stat_str = node_dag.get_node_stats_str(0, len(node_dag.value), repl_idx) # summary stats - wdw["-PGM-NODE-DISPLAY-"].update(display_node_pgm_value_str) - wdw["-PGM-NODE-STAT-"].update(display_node_pgm_stat_str) + wdw["-PGM-NODE-DISPLAY-"].update(display_node_dag_value_str) + wdw["-PGM-NODE-STAT-"].update(display_node_dag_stat_str) - def selected_node_plot(fig_obj, node_pgm, do_all_samples, sample_idx=None, repl_idx=0, repl_size=1): + def selected_node_plot(fig_obj, node_dag, do_all_samples, sample_idx=None, repl_idx=0, repl_size=1): """ Plot pgm node on 'node_display_fig_axes' (Axes object) scoped to 'call_gui()', then update canvas with new plot """ try: # if a tree - if isinstance(node_pgm.value[0], pjdt.AnnotatedTree): - draw_node_pgm(node_display_fig_axes, node_pgm, sample_idx=sample_idx, repl_idx=repl_idx) + if isinstance(node_dag.value[0], 
pjdt.AnnotatedTree): + draw_node_dag(node_display_fig_axes, node_dag, sample_idx=sample_idx, repl_idx=repl_idx) # when not a tree else: if do_all_samples: - draw_node_pgm(node_display_fig_axes, node_pgm, repl_size=repl_size) + draw_node_dag(node_display_fig_axes, node_dag, repl_size=repl_size) else: - draw_node_pgm(node_display_fig_axes, node_pgm, sample_idx=sample_idx, repl_size=repl_size) + draw_node_dag(node_display_fig_axes, node_dag, sample_idx=sample_idx, repl_size=repl_size) # when it's deterministic except: - draw_node_pgm(node_display_fig_axes, node_pgm) + draw_node_dag(node_display_fig_axes, node_dag) fig_obj.canvas.draw() @@ -214,14 +214,15 @@ def do_selected_node(dag_obj, wdw, fig_obj, node_name, do_all_samples=True): Given selected node name, display its string representation and plot it on canvas if possible """ - node_pgm, sample_size, repl_size = selected_node_read(dag_obj, node_name) - # updates spin window with number of elements in this node_pgm + node_dag, sample_size, repl_size = selected_node_read(dag_obj, node_name) + + # updates spin window with number of elements in this node_dag # window["-ITH-VAL-"].update(values=[x for x in range(1, sample_size + 1)]) # can only select the number of values this node contains wdw["-ITH-SAMPLE-"].update(values=[x for x in range(1, sample_size + 1)]) # can only select the number of values this node contains - if type(node_pgm.value) == list: - if isinstance(node_pgm.value[0], pjdt.AnnotatedTree): + if type(node_dag.value) == list: + if isinstance(node_dag.value[0], pjdt.AnnotatedTree): wdw["-ITH-REPL-"].update(values=[x for x in range(1, repl_size + 1)]) # can only select the number of values this node contains else: wdw["-ITH-REPL-"].update(disabled=True) @@ -232,12 +233,12 @@ def do_selected_node(dag_obj, wdw, fig_obj, node_name, do_all_samples=True): repl_idx = int(wdw["-ITH-REPL-"].get()) - 1 # (offset) # updating node values on window happens inside - selected_node_display(wdw, dag_obj, node_pgm, do_all_samples, sample_idx=sample_idx, repl_idx=repl_idx, repl_size=repl_size) + selected_node_display(wdw, dag_obj, node_dag, do_all_samples, sample_idx=sample_idx, repl_idx=repl_idx, repl_size=repl_size) # plotting to canvas happens inside - selected_node_plot(fig_obj, node_pgm, do_all_samples, sample_idx=sample_idx, repl_idx=repl_idx, repl_size=repl_size) + selected_node_plot(fig_obj, node_dag, do_all_samples, sample_idx=sample_idx, repl_idx=repl_idx, repl_size=repl_size) - return node_pgm + return node_dag ###################### # Development screen # @@ -742,7 +743,8 @@ def do_selected_node(dag_obj, wdw, fig_obj, node_name, do_all_samples=True): value_str = str() if values["-COPY-ALL-"]: - value_str = dag_obj.get_display_str_by_name(node_pgm.node_name) + value_str = dag_obj.get_display_str_by_name(node_dag.node_name) + else: value_str = values['-PGM-NODE-DISPLAY-'] @@ -768,27 +770,27 @@ def do_selected_node(dag_obj, wdw, fig_obj, node_name, do_all_samples=True): # if nodes have been created and selected if values["-PGM-NODES-"]: - selected_node_pgm_name = values["-PGM-NODES-"][0] + selected_node_dag_name = values["-PGM-NODES-"][0] do_all_samples = window["-ALL-SAMPLES-"].get() # True or False # if selected node is tree, we do not want to show all trees on display by default try: - if isinstance(dag_obj.get_node_dag_by_name(selected_node_pgm_name).value[0], pjdt.AnnotatedTree): + if isinstance(dag_obj.get_node_dag_by_name(selected_node_dag_name).value[0], pjdt.AnnotatedTree): do_all_samples = False - except: pass # the value of 
the node_pgm might be an DiscreteStateDependentRate, which is not subscriptable, so we pass + except: pass # the value of the node_dag might be an DiscreteStateDependentRate, which is not subscriptable, so we pass - node_pgm = do_selected_node(dag_obj, window, node_display_fig, selected_node_pgm_name, do_all_samples=do_all_samples) + node_dag = do_selected_node(dag_obj, window, node_display_fig, selected_node_dag_name, do_all_samples=do_all_samples) # we enable value copying as soon as a node is clicked window["-COPY-VALUE-"].update(disabled=False) window["-COPY-ALL-"].update(disabled=False) # if there is a chance for replicates to exist, we enable the one-sample radio button - if node_pgm.is_sampled: + if node_dag.is_sampled: window["-ONE-SAMPLE-"].update(disabled=False) # cycling through trees can only be done with "one-sample" radio button - if isinstance(node_pgm.value[0], pjdt.AnnotatedTree): + if isinstance(node_dag.value[0], pjdt.AnnotatedTree): window["-ALL-SAMPLES-"].update(disabled=True) window["-ALL-SAMPLES-"].update(False) window["-ONE-SAMPLE-"].update(True) @@ -814,7 +816,7 @@ def do_selected_node(dag_obj, wdw, fig_obj, node_name, do_all_samples=True): window["-ITH-SAMPLE-"].update(disabled=False) # if we are looking at trees, we can cycle through replicates - if isinstance(node_pgm.value[0], pjdt.AnnotatedTree): + if isinstance(node_dag.value[0], pjdt.AnnotatedTree): window["-ITH-REPL-"].update(disabled=False) # otherwise, all replicates will be visualized as histogram (no cycling allowed) else: @@ -827,9 +829,9 @@ def do_selected_node(dag_obj, wdw, fig_obj, node_name, do_all_samples=True): elif event == "-ITH-SAMPLE-": # if nodes have been created and selected if values["-PGM-NODES-"]: - selected_node_pgm_name = values["-PGM-NODES-"][0] + selected_node_dag_name = values["-PGM-NODES-"][0] do_all_samples = window["-ALL-SAMPLES-"].get() # True or False - node_pgm = do_selected_node(dag_obj, window, node_display_fig, selected_node_pgm_name, do_all_samples=do_all_samples) + node_dag = do_selected_node(dag_obj, window, node_display_fig, selected_node_dag_name, do_all_samples=do_all_samples) ####################### @@ -841,15 +843,15 @@ def do_selected_node(dag_obj, wdw, fig_obj, node_name, do_all_samples=True): # repl_idx = values["-ITH-REPL-"] - 1 # (offset) # # only updates display if tree node is selected - # if isinstance(node_pgm.value[0], AnnotatedTree): - # draw_node_pgm(node_display_fig_axes, node_pgm, sample_idx=sample_idx, repl_idx=repl_idx) + # if isinstance(node_dag.value[0], AnnotatedTree): + # draw_node_dag(node_display_fig_axes, node_dag, sample_idx=sample_idx, repl_idx=repl_idx) # node_display_fig.canvas.draw() # if nodes have been created and selected if values["-PGM-NODES-"]: - selected_node_pgm_name = values["-PGM-NODES-"][0] + selected_node_dag_name = values["-PGM-NODES-"][0] do_all_samples = window["-ALL-SAMPLES-"].get() # True or False - node_pgm = do_selected_node(dag_obj, window, node_display_fig, selected_node_pgm_name, do_all_samples=do_all_samples) + node_dag = do_selected_node(dag_obj, window, node_display_fig, selected_node_dag_name, do_all_samples=do_all_samples) ################## diff --git a/src/phylojunction/pgm/pgm.py b/src/phylojunction/pgm/pgm.py index 6ec6cce..79b989a 100644 --- a/src/phylojunction/pgm/pgm.py +++ b/src/phylojunction/pgm/pgm.py @@ -16,10 +16,8 @@ __email__ = "f.mendes@wustl.edu" -# code for @abstract attribute +# code for @abstract_attribute R = ty.TypeVar('R') - - def abstract_attribute(obj: ty.Callable[[ty.Any], R] = None) -> 
R: class DummyAttribute: @@ -36,47 +34,59 @@ class DummyAttribute: class DirectedAcyclicGraph(): - node_val_dict: ty.Dict[NodePGM, ty.Any] - name_node_dict: ty.Dict[str, NodePGM] + """Directed acyclic graph (DAG) class defining the model. + + Attributes: + node_val_dict (dict): Dictionary with keys being DAG nodes, and + values being the values stored in the nodes. + name_node_dict (dict): Dictionary with keys being DAG node + names, and the DAG nodes themselves as values. + n_nodes (int): Total number of DAG nodes in the graph. + sample_size (int): How many samples (simulations) to be either + drawn or specified in every DAG node. + """ + + node_val_dict: ty.Dict[NodeDAG, ty.Any] + name_node_dict: ty.Dict[str, NodeDAG] n_nodes: int sample_size: int def __init__(self) -> None: - # keys are proper PGM nodes, values are their values + # keys are proper DAG nodes, values are their values self.node_val_dict = dict() - # keys are NodePGM names, vals are NodePGM instances + # keys are DAG node names, vals are NodeDAG instances self.name_node_dict = dict() self.n_nodes = 0 self.sample_size = 0 # how many simulations will be run - def add_node(self, node_dag: NodePGM) -> None: + def add_node(self, node_dag: NodeDAG) -> None: # check that nodes carry the right number of values # (the number of simulations) - if isinstance(node_dag, StochasticNodePGM): + if isinstance(node_dag, StochasticNodeDAG): ############# # Important # ############# # note how only sampled nodes have any business in setting - # the sample size of a PGM object; this means we let the users + # the sample size of a DAG object; this means we let the users # fool around with nodes with assigned (fixed, clamped) values # through '<-' if node_dag.is_sampled: # if the pgm's sample size is still 0, # or if we started off with a scalar node but then - # added sampled node, we update the pgm's sample size + # added sampled node, we update the DAG's sample size if not self.sample_size or \ - (self.sample_size == 1 and node_dag.sample_size > 1): - self.sample_size = node_dag.sample_size + (self.sample_size == 1 and node_dag._sample_size > 1): + self.sample_size = node_dag._sample_size # if the number of values in a node is 1, it gets vectorized, # so this is allowed; but if the node is sampled and the number # of values is > 1 and < than that of other nodes, we have a # problem - elif self.sample_size != node_dag.sample_size and \ - node_dag.sample_size > 1: + elif self.sample_size != node_dag._sample_size and \ + node_dag._sample_size > 1: raise ec.DAGCannotAddNodeError( node_dag.node_name, @@ -120,18 +130,17 @@ def get_display_str_by_name( repl_size=1): if node_name in self.name_node_dict: - # calls __str__() of NodePGM + # calls __str__() of NodeDAG return str(self.name_node_dict[node_name]) - def get_sorted_node_dag_list(self) -> ty.List[NodePGM]: - node_dag_list: ty.List[NodePGM] = [node_dag for node_dag in self.node_val_dict] + def get_sorted_node_dag_list(self) -> ty.List[NodeDAG]: + node_dag_list: ty.List[NodeDAG] = [node_dag for node_dag in self.node_val_dict] node_dag_list.sort() return node_dag_list class ValueGenerator(ABC): - @abstract_attribute def n_samples(self): pass @@ -165,42 +174,95 @@ def get_rev_inference_spec_info(self) -> ty.List[str]: class DistributionPGM(ValueGenerator): + """Class for randomly generating (i.e., sampling) values. + + An object of this class is required by stochastic DAG nodes so + values can be sampled. 
+ """ + + # python3's way of "declaring" an (required) abstract attribute @property @abstractmethod - def DN_NAME(self): - pass + def DN_NAME(self) -> str: + raise NotImplementedError class ConstantFn(ValueGenerator): + """Class for deterministically generating values. + + An object is class whenever user input must be parsed and modified + before it can be stored in a DAG node. + """ + @property @abstractmethod def CT_FN_NAME(self): - pass - -############################################################################## - - -class NodePGM(ABC): - # for later, I think - # value: ty.Union[float, ty.List[ty.Union[float,T]]] = None + raise NotImplementedError + + +class NodeDAG(ABC): + """Node class for building directed acyclic graphs (DAGs). + + This class has abstract methods, so it cannot be instantiated. + It is derived by StochasticNodeDAG and DeterministicNodeDAG + + Attributes: + node_name (str): Name of the node, e.g., 'lambda'. + sample_size (int): How many samples (simulations) to be either + drawn or specified for the node. + replicate_size (int): How many replicates (for each sample) to + be drawn or specified the node. This is the size of a + 'plate' in graphical notation. Defaults to 1. + value (object): List of values associated to node. Defaults to + None. + call_order_idx (int): The order at which this node was added to + DAG. If it is the first node of the DAG, for example, this + value is 1. Defaults to None. + sampled (bool): Flag specifying if what the 'value' attribute + stores are stochastic samples from a distribution. Defaults to + 'False'. + deterministic (bool): Flag specifying if what the 'value' + attribute stores is the output of a deterministic function. + Defaults to 'False'. + clamped (bool): Flag specifying if what the 'value' attribute + stores is observed (i.e., data). Defaults to 'False'. + parent_nodes (NodeDAG): List of (parent) NodeDAG objects that + are in the path between this node and the outermost layer + in the DAG. Defaults to None. 
+ """ - # don't want to allow NodePGM to be initialized + node_name: str + _value: ty.List[ty.Any] + _sample_size: int + _repl_size: int + call_order_idx: int + is_sampled: bool + is_deterministic: bool + is_clamped: bool + parent_nd_list: ty.Optional[ty.List[NodeDAG]] + + # why decorate __init__: + # + # (i) don't want to allow NodeDAG to be instantiated + # (ii) want to provide a "default" initializer + # (iii) want to force all daughter classes to define their own + # __init__(), and in there have a super().__init__() call @abstractmethod def __init__(self, node_name: str, - sample_size: int, - value: ty.Optional[ty.List[ty.Any]] = None, + sample_size: ty.Optional[int], replicate_size: int = 1, + value: ty.Optional[ty.List[ty.Any]] = None, call_order_idx: ty.Optional[int] = None, sampled: bool = False, deterministic: bool = False, clamped: bool = False, - parent_nodes: ty.Optional[ty.List[NodePGM]] = None): + parent_nodes: ty.Optional[ty.List[NodeDAG]] = None): self.node_name = node_name - self.value = value - self.sample_size = sample_size - self.repl_size = replicate_size + self._value = value + self._sample_size = sample_size + self._repl_size = replicate_size self.call_order_idx = call_order_idx self.is_sampled = sampled self.is_deterministic = deterministic @@ -209,7 +271,7 @@ def __init__(self, # note that when the dag_obj adds this to its list of nodes, # value will be None (value is populated when we call .sample()), # and self.length will be = 1; we nonetheless add this call here - # for completion (useful in debugging and testing + # for completion (useful in debugging and testing) # # as we build the PGF through a script/gui, self.populate_length() # is in fact called from outside through method get_length() @@ -218,17 +280,42 @@ def __init__(self, # self.full_length = self.length * self.repl_size self.param_of = None - if isinstance(self.value, (list, np.ndarray)) and not \ - isinstance(self.value[0], pjtr.AnnotatedTree): + if isinstance(self._value, (list, np.ndarray)) and not \ + isinstance(self._value[0], pjtr.AnnotatedTree): self._flatten_and_extract_values() - # side-effect updates self.value + @property + def value(self): + return self._value + + @value.setter + def value(self, a_value): + self._value = a_value + + # no setter! + @property + def sample_size(self): + return self._sample_size + + # no setter! 
+ @property + def repl_size(self): + return self._repl_size + + # TODO: maybe later make n_samples and n_repl properties + # and make it so their setters can only raise Errors telling + # callers that they can only be set upon initialization + + # side-effect updates self._value def _flatten_and_extract_values(self) -> None: + """If value member is 2D-list, flatten it to 1D-list.""" + values: ty.List[ty.Any] = [] # so mypy won't complain - if isinstance(self.value, list): - for v in self.value: + if isinstance(self._value, list): + for v in self._value: + # ._value is a 2D-list if not isinstance(v, (int, float, str, np.float64)): values_inside_nodes = v.value @@ -241,42 +328,45 @@ def _flatten_and_extract_values(self) -> None: values.append(val) + # each value in ._value is scalar else: values.append(v) - self.value = values + self._value = values # called by GUI def get_start2end_str(self, start: int, end: int, repl_idx: int = 0, - is_tree: bool = False) -> str: - if isinstance(self.value, np.ndarray): - self.value = \ - ", ".join(str(v) for v in self.value.tolist()[start:end]) + is_tree: bool = False) -> str: + """Get string representation of values within specific range.""" - if isinstance(self.value, list): - if len(self.value) >= 2: + if isinstance(self._value, np.ndarray): + self._value = \ + ", ".join(str(v) for v in self._value.tolist()[start:end]) + + if isinstance(self._value, list): + if len(self._value) >= 2: if not is_tree: - if isinstance(self.value[0], + if isinstance(self._value[0], (int, float, str, np.float64)): return ", ".join( - str(v) for v in self.value[start:end]) + str(v) for v in self._value[start:end]) else: return "\n".join( - str(v) for v in self.value[start:end]) + str(v) for v in self._value[start:end]) # not a tree else: - return str(self.value[start + repl_idx]) + return str(self._value[start + repl_idx]) # single element in value else: - return str(self.value[0]) + return str(self._value[0]) - return str(self.value) + return str(self._value) def __hash__(self): return hash(self.node_name) @@ -308,7 +398,7 @@ def __len__(self) -> int: if isinstance(self.value, list): n_values = len(self.value) - n_repls = self.repl_size + n_repls = self._repl_size if n_values >= 1: if n_values % n_repls == 0: @@ -329,7 +419,7 @@ def __len__(self) -> int: @abstractmethod def plot_node(self, axes: plt.Axes, - sample_idx: ty.Optional[int] = None, + sample_idx: int = 0, repl_idx: int = 0, repl_size: int = 1, branch_attr: ty.Optional[str] = "state") -> None: @@ -401,7 +491,7 @@ def get_node_stats_str(self, float_stat_v = float(stat_v) except ValueError: - raise ec.NodePGMNodeStatCantFloatError( + raise ec.NodeDAGNodeStatCantFloatError( self.node_name) try: @@ -465,12 +555,18 @@ def get_node_stats_str(self, return stats_str -############################################################################## - -class StochasticNodePGM(NodePGM): +class StochasticNodeDAG(NodeDAG): + """Derived DAG node that can generate own value. + + This class is also used when a value is being specified by a user, + or parsed from it. In other words, constant nodes in the graph + are still stochastic DAG nodes under the hood. 
+ """ - random_value: ty.List[ty.Any] + sampling_dn: ty.Optional[DistributionPGM] + constant_fn: ty.Optional[DistributionPGM] + operator_weight: float def __init__(self, node_name: str, @@ -487,16 +583,14 @@ def __init__(self, self.is_sampled = False self.sampling_dn = sampled_from # dn object self.constant_fn = returned_from # constant fn object - - if not value: - self.random_value = self.get_value() - else: - self.random_value = value + # not value checks for both [] and None + generated_or_specified_value: ty.List[ty.Any] = \ + self.generate_value() if not value else value super().__init__(node_name, sample_size=sample_size, - value=self.random_value, + value=generated_or_specified_value, replicate_size=replicate_size, call_order_idx=call_order_idx, deterministic=deterministic, @@ -512,15 +606,19 @@ def __init__(self, if self.sampling_dn: self.is_sampled = True # r.v. value is sampled - def get_value(self) -> ty.List[ty.Any]: + def generate_value(self) -> ty.List[ty.Any]: + """Generate value.""" + if self.sampling_dn: return self.sampling_dn.generate() elif self.constant_fn: + print("inside generate_value()") return self.constant_fn.generate() else: - raise RuntimeError("exiting...") + raise RuntimeError(("Cannot generate value. No distribution nor " + "constant function provided.")) def __str__(self) -> str: return super().__str__() @@ -530,22 +628,22 @@ def __lt__(self, other): def plot_node(self, axes: plt.Axes, - sample_idx: ty.Optional[int] = None, + sample_idx: int = 0, repl_idx: int = 0, repl_size: int = 1, branch_attr: str = "state") -> None: - """_summary_ + """Plot node (side-effect) on provided Axes object Args: - axes (matplotlib.pyplot.Axes): _description_ - sample_idx (int, optional): Which sample to plot. - Defaults to 0. - repl_idx (int, optional): Which tree replicate to plot - (one tree is plotted at a time). Defaults to 0. - repl_size (int, optional): How many scalar random variables + axes (matplotlib.pyplot.Axes): Axes object where we are + drawing the tree. + sample_idx (int): Which sample to plot. Defaults to 0. + repl_idx (int): Which replicate to plot (one + replicated is plotted at a time). Defaults to 0. + repl_size (int): How many scalar random variables to plot at a time. Defaults to 1. - branch_attr (str, optional): Which discrete attribute - associated to a branch length to color by. Defaults to "state". + branch_attr (str, optional): If tree, which branch attribute + to color branch according to. Defaults to 'state'. """ # if list @@ -570,7 +668,7 @@ def plot_node(self, hist_vals = [ty.cast(float, v) for v in self.value] # one sample - if sample_idx is not None and self.sampling_dn: + if sample_idx is not None and self._sampling_dn: plot_node_histogram(axes, hist_vals, sample_idx=sample_idx, @@ -602,22 +700,27 @@ def populate_operator_weight(self): else: # TODO: later see how rev moves 2D-arrays and tree nodes raise RuntimeError( - ("Could not determine dimension of StochasticNodePGM when" + ("Could not determine dimension of StochasticNodeDAG when" " figuring out operator weight. Exiting...")) def get_node_stats_str(self, start: int, end: int, repl_idx: int) -> str: return super().get_node_stats_str(start, end, repl_idx) -############################################################################## +class DeterministicNodeDAG(NodeDAG): + """Derived DAG node that holds value dependent on another node's. + + Deterministic node values depend deterministically on those from + parent nodes. 
These nodes are used to modify and annotate + stochastic node values. + """ -class DeterministicNodePGM(NodePGM): def __init__(self, - node_name, - value=None, - call_order_idx=None, - deterministic=True, - parent_nodes=None): + node_name: str, + value: ty.Optional[ty.List[ty.Any]] = None, + call_order_idx: ty.Optional[int] = None, + deterministic: bool = True, + parent_nodes: ty.Optional[ty.List[NodeDAG]] = None): super().__init__(node_name, sample_size=None, @@ -640,11 +743,10 @@ def __len__(self) -> int: def plot_node(self, axes: plt.Axes, - sample_idx: ty.Optional[int] = None, + sample_idx: ty.Optional[int] = 0, repl_idx: ty.Optional[int] = 0, repl_size: ty.Optional[int] = 1, branch_attr: ty.Optional[str] = "state") -> None: - plot_blank(axes) def populate_operator_weight(self): @@ -707,32 +809,37 @@ def plot_blank(axes: plt.Axes) -> None: #################### # Helper functions # #################### -def extract_value_from_nodepgm( - val_list: ty.List[ty.Union[str, NodePGM]]) -> ty.List[str]: - """ - Return list of values +def extract_vals_as_str_from_node_dag( + val_list: ty.List[ty.Union[str, NodeDAG]]) -> ty.List[str]: + """Get values from DAG node. If all elements are strings, returns copy of 'val_list'. - When elements are StochasticNodePGMs, replaces those objects + When elements are StochasticNodeDAGs, replaces those objects by their values after casting to string (their values must be within a list). - If StochasticNodePGMs objects do not have .value field or if - they cannot be string-fied, raise exception. + Raise: + VariableMisspec: Is raised if DAG node does not have a value + or if it cannot be string-fied. + + Returns: + (str): List of values as strings. """ - extracted_val_list: ty.List[str] = [] + + extracted_val_list: ty.List[str] = list() + for v in val_list: if isinstance(v, str): extracted_val_list.append(v) - elif isinstance(v, StochasticNodePGM) and v.value: + elif isinstance(v, StochasticNodeDAG) and v.value: try: extracted_val_list.extend([str(i) for i in v.value]) except (AttributeError, TypeError) as e: raise ec.VariableMisspec(str(v)) - # will be empty if DeterministicNodePGM is in val_list + # will be empty if DeterministicNodeDAG is in val_list return extracted_val_list diff --git a/src/phylojunction/pgm/pgm.pyi b/src/phylojunction/pgm/pgm.pyi index 9867586..73dbeca 100644 --- a/src/phylojunction/pgm/pgm.pyi +++ b/src/phylojunction/pgm/pgm.pyi @@ -18,15 +18,15 @@ def abstract_attribute(obj: ty.Callable[[ty.Any], R] = None) -> R: return ty.cast(R, _obj) class DirectedAcyclicGraph: - node_val_dict: ty.Dict[NodePGM, ty.Any] - name_node_dict: ty.Dict[str, NodePGM] + node_val_dict: ty.Dict[NodeDAG, ty.Any] + name_node_dict: ty.Dict[str, NodeDAG] n_nodes: int sample_size: int def __init__(self) -> None: ... - def add_node(self, node_dag: NodePGM) -> None: ... + def add_node(self, node_dag: NodeDAG) -> None: ... def get_node_dag_by_name(self, node_name): ... def get_display_str_by_name(self, node_name, sample_idx: Incomplete | None = ..., repl_size: int = ...): ... - def get_sorted_node_dag_list(self) -> ty.List[NodePGM]: ... + def get_sorted_node_dag_list(self) -> ty.List[NodeDAG]: ... 
class ValueGenerator(ABC, metaclass=abc.ABCMeta): @abstract_attribute @@ -56,7 +56,7 @@ class ConstantFn(ValueGenerator): def CT_FN_NAME(self): pass -class NodePGM(ABC, metaclass=abc.ABCMeta): +class NodeDAG(ABC, metaclass=abc.ABCMeta): node_name: str value: ty.Optional[ty.List[ty.Any]] sample_size: int @@ -69,7 +69,7 @@ class NodePGM(ABC, metaclass=abc.ABCMeta): param_of: Incomplete @abstractmethod - def __init__(self, node_name: str, sample_size: int, value: ty.Optional[ty.List[ty.Any]] = ..., replicate_size: int = ..., call_order_idx: ty.Optional[int] = ..., sampled: bool = ..., deterministic: bool = ..., clamped: bool = ..., parent_nodes: ty.Optional[ty.List[NodePGM]] = ...): ... + def __init__(self, node_name: str, sample_size: int, value: ty.Optional[ty.List[ty.Any]] = ..., replicate_size: int = ..., call_order_idx: ty.Optional[int] = ..., sampled: bool = ..., deterministic: bool = ..., clamped: bool = ..., parent_nodes: ty.Optional[ty.List[NodeDAG]] = ...): ... def _flatten_and_extract_values(self) -> None: ... def get_start2end_str(self, start: int, end: int, repl_idx: int=0, is_tree: bool=False) -> str: ... def __str__(self) -> str: ... @@ -84,21 +84,22 @@ class NodePGM(ABC, metaclass=abc.ABCMeta): @abstractmethod def populate_operator_weight(self): ... -class StochasticNodePGM(NodePGM): - random_value: ty.List[ty.Any] +class StochasticNodeDAG(NodeDAG): + _generated_value: ty.List[ty.Any] is_sampled: bool sampling_dn: Incomplete operator_weight: float def __init__(self, node_name: str, sample_size: int, sampled_from: Incomplete | None = ..., value: ty.Optional[ty.List[ty.Any]] = ..., replicate_size: int = ..., call_order_idx: ty.Optional[int] = ..., deterministic: bool = ..., clamped: bool = ..., parent_nodes: ty.Optional[ty.List[ty.Any]] = ...) -> None: ... value: Incomplete - def get_value(self) -> ty.List[ty.Any]: ... + def generate_value(self) -> ty.List[ty.Any]: ... + def generated_value(self) -> ty.List[ty.Any]: ... def sample(self) -> None: ... def __lt__(self, other): ... def get_gcf(self, axes: plt.Axes, sample_idx: ty.Optional[int]=None, repl_idx: int=0, repl_size: int=1, branch_attr: ty.Optional[str]="state") -> None: ... def populate_operator_weight(self) -> None: ... -class DeterministicNodePGM(NodePGM): +class DeterministicNodeDAG(NodeDAG): is_sampled: bool def __init__(self, node_name, value: Incomplete | None = ..., call_order_idx: Incomplete | None = ..., deterministic: bool = ..., parent_nodes: Incomplete | None = ...) -> None: ... @@ -106,6 +107,6 @@ class DeterministicNodePGM(NodePGM): def get_gcf(self, axes: plt.Axes, sample_idx: ty.Optional[int]=None, repl_idx: int=0, repl_size: int=1, branch_attr: ty.Optional[str]="state") -> None: ... def populate_operator_weight(self) -> None: ... -def extract_value_from_nodepgm(val_list: ty.List[ty.Union[str, NodePGM]]) -> ty.List[str]: ... +def extract_vals_as_str_from_node_dag(val_list: ty.List[ty.Union[str, NodeDAG]]) -> ty.List[str]: ... def get_histogram_gcf(axes: plt.Axes, values_list: ty.List[float], sample_idx: ty.Optional[int] = ..., repl_size: int = ...) -> None: ... 
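The NodeDAG changes above replace plain attributes with private fields (_sample_size, _repl_size, _value) exposed through properties, where value keeps a setter but sample_size and repl_size do not. The short sketch below, using an illustrative NodeSketch class rather than the real one, shows why the missing setters matter: reassignment after construction fails loudly.

# Minimal sketch of the getter-only property pattern (illustrative class name).
class NodeSketch:
    def __init__(self, node_name, sample_size, replicate_size=1):
        self.node_name = node_name
        self._sample_size = sample_size
        self._repl_size = replicate_size

    @property
    def sample_size(self):
        # no setter: fixed after initialization
        return self._sample_size

    @property
    def repl_size(self):
        # no setter: fixed after initialization
        return self._repl_size


nd = NodeSketch("mu", sample_size=2, replicate_size=5)
print(nd.sample_size, nd.repl_size)  # 2 5

try:
    nd.sample_size = 10  # rejected: property has no setter
except AttributeError as err:
    print("read-only:", err)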
diff --git a/src/phylojunction/readwrite/pj_write.py b/src/phylojunction/readwrite/pj_write.py index bfff9e2..9baf79e 100644 --- a/src/phylojunction/readwrite/pj_write.py +++ b/src/phylojunction/readwrite/pj_write.py @@ -184,7 +184,7 @@ def prep_data_df( """ sample_size = dag_obj.sample_size - node_pgm_list = dag_obj.get_sorted_node_dag_list() + node_dag_list = dag_obj.get_sorted_node_dag_list() # tree nodes (one dataframe per different tree node) @@ -235,21 +235,21 @@ def prep_data_df( # stats; scalars always have only average and std. dev. # main loop: nodes! - for node_pgm in node_pgm_list: - rv_name = node_pgm.node_name - node_val = node_pgm.value # list - n_repl = node_pgm.repl_size + for node_dag in node_dag_list: + rv_name = node_dag.node_name + node_val = node_dag.value # list + n_repl = node_dag.repl_size # if stochastic node is constant and at least one value was indeed # saved in list - if isinstance(node_pgm, pgm.StochasticNodePGM) and node_val: + if isinstance(node_dag, pgm.StochasticNodeDAG) and node_val: ################################ # Fixed-value stochastic nodes # ################################ # scalar constants (no support for replication via 2D-lists yet) if isinstance(node_val[0], (str, int, float, np.float64)) \ - and not node_pgm.is_sampled: + and not node_dag.is_sampled: if scalar_constant_df.empty: if sample_size == 0: @@ -277,7 +277,7 @@ def prep_data_df( # Scalars # ########### if isinstance(node_val[0], (str, int, float, np.float64)) \ - and node_pgm.is_sampled: + and node_dag.is_sampled: ###################################################### # DataFrame holding scalar values (replicates if # @@ -793,8 +793,8 @@ def dump_pgm_data(dir_string: str, None """ - sorted_node_pgm_list: \ - ty.List[pgm.NodePGM] = dag_obj.get_sorted_node_dag_list() + sorted_node_dag_list: \ + ty.List[pgm.NodeDAG] = dag_obj.get_sorted_node_dag_list() # populating data stashes that will be dumped and their file names scalar_output_stash: \ diff --git a/src/phylojunction/readwrite/pj_write.pyi b/src/phylojunction/readwrite/pj_write.pyi index b9cdc31..bb627f2 100644 --- a/src/phylojunction/readwrite/pj_write.pyi +++ b/src/phylojunction/readwrite/pj_write.pyi @@ -8,7 +8,7 @@ import phylojunction.pgm.pgm as pgm def write_text_output(outfile_handle: ty.IO, content_string_list: ty.List[str]) -> None: ... def write_data_df(outfile_handle: ty.IO, data_df: pd.DataFrame, format: str = ...) -> None: ... def write_fig_to_file(outfile_path: str,fig_obj: plt.Figure) -> None: ... -def prep_data_df(sample_size: int, node_pgm_list: ty.List[pgm.NodePGM], write_nex_states: bool=False) -> ty.Tuple[ty.List[ty.Union[pd.DataFrame, ty.Dict[int, pd.DataFrame]]], ty.List[ty.Union[ty.Dict[str, pd.DataFrame], ty.Dict[str, str]]]]: ... +def prep_data_df(sample_size: int, node_dag_list: ty.List[pgm.NodeDAG], write_nex_states: bool=False) -> ty.Tuple[ty.List[ty.Union[pd.DataFrame, ty.Dict[int, pd.DataFrame]]], ty.List[ty.Union[ty.Dict[str, pd.DataFrame], ty.Dict[str, str]]]]: ... def prep_data_filepaths_dfs(scalar_output_stash: ty.List[ty.Union[pd.DataFrame, ty.Dict[int, pd.DataFrame]]], tree_output_stash: ty.List[ty.Union[ty.Dict[str, pd.DataFrame], ty.Dict[str, str]]] = []) -> ty.Tuple[ty.List[str], ty.List[ty.Union[pd.DataFrame, str]]]: ... def dump_pgm_data(dir_string: str, dag_obj: pgm.DirectedAcyclicGraph, prefix: str, write_nex_states: bool = ...) -> None: ... def dump_serialized_pgm(dir_string: str, dag_obj: pgm.DirectedAcyclicGraph, prefix: str = ...) -> None: ... 
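prep_data_df() above walks the sorted DAG node list and routes scalar values into different dataframes depending on whether the stochastic node holds a fixed value or sampled values. The rough sketch below reproduces only that routing decision, with a stand-in node class and simplified columns; the real function handles replicates, trees, and summary statistics as well.

# Sketch of the scalar-node routing in prep_data_df() (stand-in class, assumes pandas).
import pandas as pd


class StochasticNodeSketch:
    def __init__(self, node_name, value, is_sampled, repl_size=1):
        self.node_name = node_name
        self.value = value
        self.is_sampled = is_sampled
        self.repl_size = repl_size


def prep_scalar_dfs(node_list):
    constant_cols = {}
    sampled_cols = {}

    for nd in node_list:
        if not nd.value:
            continue  # nothing stored in this node
        if not isinstance(nd.value[0], (str, int, float)):
            continue  # trees and other objects are handled elsewhere

        # fixed-value scalars and sampled scalars go to separate dataframes
        target = sampled_cols if nd.is_sampled else constant_cols
        target[nd.node_name] = nd.value

    return pd.DataFrame(constant_cols), pd.DataFrame(sampled_cols)


constant_df, sampled_df = prep_scalar_dfs([
    StochasticNodeSketch("r", [0.5], is_sampled=False),
    StochasticNodeSketch("l0", [0.9, 1.1, 1.0], is_sampled=True),
])
print(constant_df)
print(sampled_df)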
diff --git a/src/phylojunction/utility/exception_classes.py b/src/phylojunction/utility/exception_classes.py index d089b4b..d57cca9 100644 --- a/src/phylojunction/utility/exception_classes.py +++ b/src/phylojunction/utility/exception_classes.py @@ -166,20 +166,20 @@ def __str__(self) -> str: class NoPlatingAllowedError(Exception): det_name: str message: str - node_pgm_name: str + node_dag_name: str def __init__(self, det_name: str, - problematic_node_pgm_name: str, + problematic_node_dag_name: str, message: str = "") -> None: self.det_name = det_name self.message = message - self.node_pgm_name = problematic_node_pgm_name + self.node_dag_name = problematic_node_dag_name super().__init__(self.message) def __str__(self) -> str: return "ERROR: When executing " + self.det_name + "(), replicates " \ - + "were detected for argument " + self.node_pgm_name \ + + "were detected for argument " + self.node_dag_name \ + ". Plating is not supported for this deterministic function." @@ -690,7 +690,7 @@ def __str__(self) -> str: # PGM exceptions # -class NodePGMNodeStatCantFloatError(Exception): +class NodeDAGNodeStatCantFloatError(Exception): message: str def __init__(self, node_name: str) -> None: diff --git a/src/phylojunction/utility/exception_classes.pyi b/src/phylojunction/utility/exception_classes.pyi index f408e2a..4be2f0e 100644 --- a/src/phylojunction/utility/exception_classes.pyi +++ b/src/phylojunction/utility/exception_classes.pyi @@ -55,8 +55,8 @@ class NodeInferenceDimensionalityError(Exception): class NoPlatingAllowedError(Exception): det_name: str message: str - node_pgm_name: str - def __init__(self, det_name: str, problematic_node_pgm_name: str, message: str = ...) -> None: ... + node_dag_name: str + def __init__(self, det_name: str, problematic_node_dag_name: str, message: str = ...) -> None: ... class ObjInitRequireSameParameterTypeError(Exception): message: str @@ -214,7 +214,7 @@ class ParseDetFnInitFailError(Exception): def __str__(self) -> str: ... # PGM exceptions # -class NodePGMNodeStatCantFloatError(Exception): +class NodeDAGNodeStatCantFloatError(Exception): message: str def __init__(self, node_name: str) -> None: ... def __str__(self) -> str: ...
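For reference, this is how the renamed exception is exercised by the float-casting guard in pgm.py's get_node_stats_str() above; the class body below is a simplified stand-in rather than the actual PhyloJunction implementation.

# Sketch of raising/catching the renamed exception (simplified stand-in class).
class NodeDAGNodeStatCantFloatError(Exception):
    def __init__(self, node_name: str) -> None:
        self.message = ("Summary statistic of node '" + node_name
                        + "' could not be cast to float.")
        super().__init__(self.message)


def stat_as_float(node_name: str, stat_v: str) -> float:
    try:
        return float(stat_v)
    except ValueError:
        # mirrors the guard in get_node_stats_str()
        raise NodeDAGNodeStatCantFloatError(node_name)


print(stat_as_float("lambda0", "1.25"))  # 1.25

try:
    stat_as_float("lambda0", "n/a")
except NodeDAGNodeStatCantFloatError as err:
    print(err)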
diff --git a/tests/inference/test_inference_rb_dn_parametric.py b/tests/inference/test_inference_rb_dn_parametric.py index 55e28c8..b3cd50e 100644 --- a/tests/inference/test_inference_rb_dn_parametric.py +++ b/tests/inference/test_inference_rb_dn_parametric.py @@ -27,14 +27,14 @@ def test_pj2rb_uniform(self): stoch_node_name, _, stoch_node_spec = re.split(cmdu.sampled_as_regex, cmd_line) cmdp.parse_samp_dn_assignment(dag_obj, stoch_node_name, stoch_node_spec, cmd_line) - a_node_pgm = dag_obj.get_node_dag_by_name("u") + a_node_dag = dag_obj.get_node_dag_by_name("u") rb_sample_mean = -0.002172273 # mean(u) in RB rb_sample_sd = 0.5781801 # stdev(u) in RB rb_sample_min = -0.9999855 # min(u) in RB rb_sample_max = 0.9999579 # max(u) in RB - pj_unif_values = a_node_pgm.value # 100000 floats in list + pj_unif_values = a_node_dag.value # 100000 floats in list self.assertAlmostEqual(rb_sample_mean, mean(pj_unif_values), delta=1e-2) self.assertAlmostEqual(rb_sample_sd, stdev(pj_unif_values), delta=1e-2) @@ -52,13 +52,18 @@ def test_pj2rb_exponential(self): dag_obj = pgm.DirectedAcyclicGraph() n_repl = 100000 - cmd_line1 = "e1 ~ exponential(n=1, nr=" + str(n_repl) + ", rate=0.5, rate_parameterization=\"true\")" # default is true - - stoch_node_name, _, stoch_node_spec = re.split(cmdu.sampled_as_regex, cmd_line1) - cmdp.parse_samp_dn_assignment(dag_obj, stoch_node_name, stoch_node_spec, cmd_line1) - a_node_pgm1 = dag_obj.get_node_dag_by_name("e1") - - # rev_str_list = a_node_pgm1.sampling_dn.get_rev_inference_spec_info() + cmd_line1 = "e1 ~ exponential(n=1, nr=" + str(n_repl) \ + + ", rate=0.5, rate_parameterization=\"true\")" # default is true + + stoch_node_name, _, stoch_node_spec = \ + re.split(cmdu.sampled_as_regex, cmd_line1) + cmdp.parse_samp_dn_assignment(dag_obj, + stoch_node_name, + stoch_node_spec, + cmd_line1) + a_node_dag1 = dag_obj.get_node_dag_by_name("e1") + + # rev_str_list = a_node_dag1.sampling_dn.get_rev_inference_spec_info() # str_to_run_in_rb_for_unittest1 = "for (i in 1:" + str(n_repl) + ") {\n" + \ # " e1[i] ~ " + rev_str_list[0] + "\n" + \ # "}" @@ -67,15 +72,20 @@ def test_pj2rb_exponential(self): rb_sample_mean1 = 2.013223 # mean(e1) in RB rb_sample_sd1 = 2.017317 # stdev(e1) in RB - pj_exponential_values1 = a_node_pgm1.value # 100000 floats in list + pj_exponential_values1 = a_node_dag1.value # 100000 floats in list - cmd_line2 = "e2 ~ exponential(n=1, nr=" + str(n_repl) + ", rate=0.5, rate_parameterization=\"false\")" + cmd_line2 = "e2 ~ exponential(n=1, nr=" + str(n_repl) \ + + ", rate=0.5, rate_parameterization=\"false\")" - stoch_node_name, _, stoch_node_spec = re.split(cmdu.sampled_as_regex, cmd_line2) - cmdp.parse_samp_dn_assignment(dag_obj, stoch_node_name, stoch_node_spec, cmd_line2) - a_node_pgm2 = dag_obj.get_node_dag_by_name("e2") + stoch_node_name, _, stoch_node_spec = \ + re.split(cmdu.sampled_as_regex, cmd_line2) + cmdp.parse_samp_dn_assignment(dag_obj, + stoch_node_name, + stoch_node_spec, + cmd_line2) + a_node_dag2 = dag_obj.get_node_dag_by_name("e2") - # rev_str_list = a_node_pgm2.sampling_dn.get_rev_inference_spec_info() + # rev_str_list = a_node_dag2.sampling_dn.get_rev_inference_spec_info() # str_to_run_in_rb_for_unittest2 = "for (i in 1:" + str(n_repl) + ") {\n" + \ # " e2[i] ~ " + rev_str_list[0] + "\n" + \ # "}" @@ -84,12 +94,20 @@ def test_pj2rb_exponential(self): rb_sample_mean2 = 0.50126 # mean(e2) in RB rb_sample_sd2 = 0.5042986 # stdev(e2) in RB - pj_exponential_values2 = a_node_pgm2.value # 100000 floats in list + pj_exponential_values2 = 
a_node_dag2.value # 100000 floats in list - self.assertAlmostEqual(rb_sample_mean1, mean(pj_exponential_values1), delta=1e-1) - self.assertAlmostEqual(rb_sample_sd1, stdev(pj_exponential_values1), delta=1e-1) - self.assertAlmostEqual(rb_sample_mean2, mean(pj_exponential_values2), delta=1e-1) - self.assertAlmostEqual(rb_sample_sd2, stdev(pj_exponential_values2), delta=1e-1) + self.assertAlmostEqual(rb_sample_mean1, + mean(pj_exponential_values1), + delta=1e-1) + self.assertAlmostEqual(rb_sample_sd1, + stdev(pj_exponential_values1), + delta=1e-1) + self.assertAlmostEqual(rb_sample_mean2, + mean(pj_exponential_values2), + delta=1e-1) + self.assertAlmostEqual(rb_sample_sd2, + stdev(pj_exponential_values2), + delta=1e-1) def test_pj2rb_gamma(self): @@ -113,14 +131,18 @@ def test_pj2rb_normal(self): n_repl = 100000 cmd_line = "n ~ normal(n=1, nr=" + str(n_repl) + ", mean=0.5, sd=0.1)" - stoch_node_name, _, stoch_node_spec = re.split(cmdu.sampled_as_regex, cmd_line) - cmdp.parse_samp_dn_assignment(dag_obj, stoch_node_name, stoch_node_spec, cmd_line) - a_node_pgm = dag_obj.get_node_dag_by_name("n") + stoch_node_name, _, stoch_node_spec = \ + re.split(cmdu.sampled_as_regex, cmd_line) + cmdp.parse_samp_dn_assignment(dag_obj, + stoch_node_name, + stoch_node_spec, + cmd_line) + a_node_dag = dag_obj.get_node_dag_by_name("n") rb_sample_mean = 0.4998006 # mean(n) in RB rb_sample_sd = 0.09983804 # stdev(n) in RB - pj_normal_values = a_node_pgm.value # 1000 floats in list + pj_normal_values = a_node_dag.value # 1000 floats in list self.assertAlmostEqual(rb_sample_mean, mean(pj_normal_values), delta=1e-2) self.assertAlmostEqual(rb_sample_sd, stdev(pj_normal_values), delta=1e-2) diff --git a/tests/interface/test_cmd_parametric_sampling_dn_assignment.py b/tests/interface/test_cmd_parametric_sampling_dn_assignment.py index 8c06ae9..9fd1f61 100644 --- a/tests/interface/test_cmd_parametric_sampling_dn_assignment.py +++ b/tests/interface/test_cmd_parametric_sampling_dn_assignment.py @@ -26,13 +26,13 @@ def test_sampling_unif_assignment(self): stoch_node_name, _, stoch_node_spec = re.split(cmdu.sampled_as_regex, cmd_line) cmdp.parse_samp_dn_assignment(dag_obj, stoch_node_name, stoch_node_spec, cmd_line) - a_node_pgm = dag_obj.get_node_dag_by_name("u") + a_node_dag = dag_obj.get_node_dag_by_name("u") - self.assertTrue(isinstance(a_node_pgm.value, list)) - self.assertEqual(len(a_node_pgm.value), 100000) - self.assertAlmostEqual(0.0, mean(a_node_pgm.value), delta=1e-2) - self.assertLessEqual(-1.0, min(a_node_pgm.value)) - self.assertGreater(1.0, max(a_node_pgm.value)) + self.assertTrue(isinstance(a_node_dag.value, list)) + self.assertEqual(len(a_node_dag.value), 100000) + self.assertAlmostEqual(0.0, mean(a_node_dag.value), delta=1e-2) + self.assertLessEqual(-1.0, min(a_node_dag.value)) + self.assertGreater(1.0, max(a_node_dag.value)) self.assertEqual(1, dag_obj.n_nodes) @@ -58,12 +58,12 @@ def test_sampling_unif_vectorized_assignment(self): stoch_node_name, _, stoch_node_spec = re.split(cmdu.sampled_as_regex, cmd_line2) cmdp.parse_samp_dn_assignment(dag_obj, stoch_node_name, stoch_node_spec, cmd_line2) - a_node_pgm = dag_obj.get_node_dag_by_name("u") + a_node_dag = dag_obj.get_node_dag_by_name("u") - self.assertTrue(isinstance(a_node_pgm.value, list)) + self.assertTrue(isinstance(a_node_dag.value, list)) for idx, tup in enumerate(tups): - self.assertTrue(tup[0] <= a_node_pgm.value[idx * 2] < tup[1]) - self.assertTrue(tup[0] <= a_node_pgm.value[idx * 2 + 1] < tup[1]) + self.assertTrue(tup[0] <= a_node_dag.value[idx * 
2] < tup[1]) + self.assertTrue(tup[0] <= a_node_dag.value[idx * 2 + 1] < tup[1]) self.assertEqual(2, dag_obj.n_nodes) @@ -84,28 +84,35 @@ def test_sampling_exp_assignment(self): stoch_node_name, _, stoch_node_spec = \ re.split(cmdu.sampled_as_regex, cmd_line1) - cmdp.parse_samp_dn_assignment( - dag_obj, stoch_node_name, stoch_node_spec, cmd_line1) + cmdp.parse_samp_dn_assignment(dag_obj, + stoch_node_name, + stoch_node_spec, + cmd_line1) - a_node_pgm1 = dag_obj.get_node_dag_by_name("e1") + a_node_dag1 = dag_obj.get_node_dag_by_name("e1") - self.assertTrue(isinstance(a_node_pgm1.value, list)) - self.assertEqual(len(a_node_pgm1.value), 100000) - self.assertAlmostEqual(2.0, mean(a_node_pgm1.value), delta=0.05) + self.assertTrue(isinstance(a_node_dag1.value, list)) + self.assertEqual(len(a_node_dag1.value), 100000) + self.assertAlmostEqual(2.0, mean(a_node_dag1.value), delta=0.05) self.assertEqual(1, dag_obj.n_nodes) ####################################### # Exponential, scale parameterization # ####################################### - cmd_line2 = "e2 ~ exponential(n=100000, nr=1, rate=0.5, rate_parameterization=\"false\")" + cmd_line2 = ('e2 ~ exponential(n=100000, nr=1, rate=0.5, ' + 'rate_parameterization=\"false\")') - stoch_node_name, _, stoch_node_spec = re.split(cmdu.sampled_as_regex, cmd_line2) - cmdp.parse_samp_dn_assignment(dag_obj, stoch_node_name, stoch_node_spec, cmd_line2) - a_node_pgm2 = dag_obj.get_node_dag_by_name("e2") - - self.assertTrue(isinstance(a_node_pgm2.value, list)) - self.assertEqual(len(a_node_pgm2.value), 100000) - self.assertAlmostEqual(0.5, mean(a_node_pgm2.value), delta=0.05) + stoch_node_name, _, stoch_node_spec = \ + re.split(cmdu.sampled_as_regex, cmd_line2) + cmdp.parse_samp_dn_assignment(dag_obj, + stoch_node_name, + stoch_node_spec, + cmd_line2) + a_node_dag2 = dag_obj.get_node_dag_by_name("e2") + + self.assertTrue(isinstance(a_node_dag2.value, list)) + self.assertEqual(len(a_node_dag2.value), 100000) + self.assertAlmostEqual(0.5, mean(a_node_dag2.value), delta=0.05) self.assertEqual(2, dag_obj.n_nodes) @@ -125,14 +132,16 @@ def test_sampling_gamma_assignment(self): stoch_node_name, _, stoch_node_spec = \ re.split(cmdu.sampled_as_regex, cmd_line1) - cmdp.parse_samp_dn_assignment( - dag_obj, stoch_node_name, stoch_node_spec, cmd_line1) + cmdp.parse_samp_dn_assignment(dag_obj, + stoch_node_name, + stoch_node_spec, + cmd_line1) - a_node_pgm1 = dag_obj.get_node_dag_by_name("g1") + a_node_dag1 = dag_obj.get_node_dag_by_name("g1") - self.assertTrue(isinstance(a_node_pgm1.value, list)) - self.assertEqual(len(a_node_pgm1.value), 100000) - self.assertAlmostEqual(0.25, mean(a_node_pgm1.value), delta=0.05) + self.assertTrue(isinstance(a_node_dag1.value, list)) + self.assertEqual(len(a_node_dag1.value), 100000) + self.assertAlmostEqual(0.25, mean(a_node_dag1.value), delta=0.05) self.assertEqual(1, dag_obj.n_nodes) ################################ @@ -144,14 +153,16 @@ def test_sampling_gamma_assignment(self): stoch_node_name, _, stoch_node_spec = \ re.split(cmdu.sampled_as_regex, cmd_line2) - cmdp.parse_samp_dn_assignment( - dag_obj, stoch_node_name, stoch_node_spec, cmd_line2) + cmdp.parse_samp_dn_assignment(dag_obj, + stoch_node_name, + stoch_node_spec, + cmd_line2) - a_node_pgm2 = dag_obj.get_node_dag_by_name("g2") + a_node_dag2 = dag_obj.get_node_dag_by_name("g2") - self.assertTrue(isinstance(a_node_pgm2.value, list)) - self.assertEqual(len(a_node_pgm2.value), 100000) - self.assertAlmostEqual(1.0, mean(a_node_pgm2.value), delta=0.05) + 
self.assertTrue(isinstance(a_node_dag2.value, list)) + self.assertEqual(len(a_node_dag2.value), 100000) + self.assertAlmostEqual(1.0, mean(a_node_dag2.value), delta=0.05) self.assertEqual(2, dag_obj.n_nodes) @@ -169,14 +180,16 @@ def test_sampling_normal_assignment(self): stoch_node_name, _, stoch_node_spec = \ re.split(cmdu.sampled_as_regex, cmd_line1) - cmdp.parse_samp_dn_assignment( - dag_obj, stoch_node_name, stoch_node_spec, cmd_line1) + cmdp.parse_samp_dn_assignment(dag_obj, + stoch_node_name, + stoch_node_spec, + cmd_line1) - a_node_pgm = dag_obj.get_node_dag_by_name("n") + a_node_dag = dag_obj.get_node_dag_by_name("n") - self.assertTrue(isinstance(a_node_pgm.value, list)) - self.assertEqual(len(a_node_pgm.value), 100000) - self.assertAlmostEqual(0.5, mean(a_node_pgm.value), delta=0.1) + self.assertTrue(isinstance(a_node_dag.value, list)) + self.assertEqual(len(a_node_dag.value), 100000) + self.assertAlmostEqual(0.5, mean(a_node_dag.value), delta=0.1) self.assertEqual(1, dag_obj.n_nodes) @@ -197,14 +210,16 @@ def test_sampling_ln_assignment(self): stoch_node_name, _, stoch_node_spec = \ re.split(cmdu.sampled_as_regex, cmd_line1) - cmdp.parse_samp_dn_assignment( - dag_obj, stoch_node_name, stoch_node_spec, cmd_line1) + cmdp.parse_samp_dn_assignment(dag_obj, + stoch_node_name, + stoch_node_spec, + cmd_line1) - a_node_pgm = dag_obj.get_node_dag_by_name("ln1") + a_node_dag = dag_obj.get_node_dag_by_name("ln1") - self.assertTrue(isinstance(a_node_pgm.value, list)) - self.assertEqual(len(a_node_pgm.value), 100000) - self.assertAlmostEqual(2.37, mean(a_node_pgm.value), delta=0.1) + self.assertTrue(isinstance(a_node_dag.value, list)) + self.assertEqual(len(a_node_dag.value), 100000) + self.assertAlmostEqual(2.37, mean(a_node_dag.value), delta=0.1) self.assertEqual(1, dag_obj.n_nodes) ########################## @@ -219,15 +234,17 @@ def test_sampling_ln_assignment(self): stoch_node_name, _, stoch_node_spec = \ re.split(cmdu.sampled_as_regex, cmd_line2) - cmdp.parse_samp_dn_assignment( - dag_obj, stoch_node_name, stoch_node_spec, cmd_line2) + cmdp.parse_samp_dn_assignment(dag_obj, + stoch_node_name, + stoch_node_spec, + cmd_line2) - a_node_pgm = dag_obj.get_node_dag_by_name("ln2") + a_node_dag = dag_obj.get_node_dag_by_name("ln2") - self.assertTrue(isinstance(a_node_pgm.value, list)) - # self.assertEqual(len(a_node_pgm.value), 100000) - # self.assertAlmostEqual(2.37, mean(a_node_pgm.value), delta=0.1) - # self.assertEqual(2, dag_obj.n_nodes) + self.assertTrue(isinstance(a_node_dag.value, list)) + self.assertEqual(len(a_node_dag.value), 100000) + self.assertAlmostEqual(2.37, mean(a_node_dag.value), delta=0.1) + self.assertEqual(2, dag_obj.n_nodes) def test_unif_misspec(self): diff --git a/tests/interface/test_cmd_var_assignment.py b/tests/interface/test_cmd_var_assignment.py index 554a0be..b1d449a 100644 --- a/tests/interface/test_cmd_var_assignment.py +++ b/tests/interface/test_cmd_var_assignment.py @@ -14,95 +14,122 @@ class TestVarAssignment(unittest.TestCase): def test_var_assignment(self): """ - Test if a series of different variable assignments are correctly evaluated - and result in the right probabilistic graphical model + Test variable assignment. + + Test a series of different variable assignments produce the + right DAG. 
""" dag_obj = pgm.DirectedAcyclicGraph() cmd_line1 = "a <- 1" - stoch_node_name, _, stoch_node_spec = re.split(cmdu.assign_regex, cmd_line1) - cmd.parse_variable_assignment(dag_obj, stoch_node_name, stoch_node_spec, cmd_line1) - a_node_pgm = dag_obj.get_node_dag_by_name("a") + stoch_node_name, _, stoch_node_spec = \ + re.split(cmdu.assign_regex, cmd_line1) + cmd.parse_variable_assignment(dag_obj, + stoch_node_name, + stoch_node_spec, + cmd_line1) + a_node_dag = dag_obj.get_node_dag_by_name("a") self.assertEqual(1, dag_obj.n_nodes) - self.assertEqual(type(a_node_pgm.value), list) - self.assertEqual(len(a_node_pgm.value), 1) - self.assertEqual(a_node_pgm.value, ["1"]) + self.assertEqual(type(a_node_dag.value), list) + self.assertEqual(len(a_node_dag.value), 1) + self.assertEqual(a_node_dag.value, ["1"]) # --- # cmd_line2 = "b <- [1, 2, 3]" - stoch_node_name, _, stoch_node_spec = re.split(cmdu.assign_regex, cmd_line2) - cmd.parse_variable_assignment(dag_obj, stoch_node_name, stoch_node_spec, cmd_line2) - b_node_pgm = dag_obj.get_node_dag_by_name("b") + stoch_node_name, _, stoch_node_spec = \ + re.split(cmdu.assign_regex, cmd_line2) + cmd.parse_variable_assignment(dag_obj, + stoch_node_name, + stoch_node_spec, + cmd_line2) + b_node_dag = dag_obj.get_node_dag_by_name("b") self.assertEqual(2, dag_obj.n_nodes) - self.assertEqual(type(b_node_pgm.value), list) - self.assertEqual(len(b_node_pgm.value), 3) - self.assertEqual(b_node_pgm.value, ["1", "2", "3"]) + self.assertEqual(type(b_node_dag.value), list) + self.assertEqual(len(b_node_dag.value), 3) + self.assertEqual(b_node_dag.value, ["1", "2", "3"]) # --- # cmd_line3 = "a <- [1]" - stoch_node_name, _, stoch_node_spec = re.split(cmdu.assign_regex, cmd_line3) - cmd.parse_variable_assignment(dag_obj, stoch_node_name, stoch_node_spec, cmd_line3) - a_node_pgm = dag_obj.get_node_dag_by_name("a") + stoch_node_name, _, stoch_node_spec = \ + re.split(cmdu.assign_regex, cmd_line3) + cmd.parse_variable_assignment(dag_obj, + stoch_node_name, + stoch_node_spec, + cmd_line3) + a_node_dag = dag_obj.get_node_dag_by_name("a") self.assertEqual(2, dag_obj.n_nodes) - self.assertEqual(type(a_node_pgm.value), list) - self.assertEqual(len(a_node_pgm.value), 1) - self.assertEqual(a_node_pgm.value, ["1"]) + self.assertEqual(type(a_node_dag.value), list) + self.assertEqual(len(a_node_dag.value), 1) + self.assertEqual(a_node_dag.value, ["1"]) # --- # cmd_line4 = "c <- b" - stoch_node_name, _, stoch_node_spec = re.split(cmdu.assign_regex, cmd_line4) - cmd.parse_variable_assignment(dag_obj, stoch_node_name, stoch_node_spec, cmd_line4) - c_node_pgm = dag_obj.get_node_dag_by_name("c") + stoch_node_name, _, stoch_node_spec = \ + re.split(cmdu.assign_regex, cmd_line4) + cmd.parse_variable_assignment(dag_obj, + stoch_node_name, + stoch_node_spec, + cmd_line4) + c_node_dag = dag_obj.get_node_dag_by_name("c") self.assertEqual(3, dag_obj.n_nodes) - self.assertEqual(type(c_node_pgm.value), list) - self.assertEqual(len(c_node_pgm.value), 3) - self.assertEqual(c_node_pgm.value, ["1", "2", "3"]) + self.assertEqual(type(c_node_dag.value), list) + self.assertEqual(len(c_node_dag.value), 3) + self.assertEqual(c_node_dag.value, ["1", "2", "3"]) # --- # cmd_line5 = "d <- [c, 4, 5, 6]" - stoch_node_name, _, stoch_node_spec = re.split(cmdu.assign_regex, cmd_line5) - cmd.parse_variable_assignment(dag_obj, stoch_node_name, stoch_node_spec, cmd_line5) - d_node_pgm = dag_obj.get_node_dag_by_name("d") + stoch_node_name, _, stoch_node_spec = \ + re.split(cmdu.assign_regex, cmd_line5) + 
cmd.parse_variable_assignment(dag_obj, + stoch_node_name, + stoch_node_spec, + cmd_line5) + d_node_dag = dag_obj.get_node_dag_by_name("d") self.assertEqual(4, dag_obj.n_nodes) - self.assertEqual(type(d_node_pgm.value), list) - self.assertEqual(len(d_node_pgm.value), 6) - self.assertEqual(d_node_pgm.value, ["1", "2", "3", "4", "5", "6"]) + self.assertEqual(type(d_node_dag.value), list) + self.assertEqual(len(d_node_dag.value), 6) + self.assertEqual(d_node_dag.value, ["1", "2", "3", "4", "5", "6"]) def test_var_assignment_read_tree_string(self): - """ - Test if read_tree() calls using newick strings directly - are correctly evaluated and result in the right probabilistic - graphical model + """Test read_tree() from string. + + Test read_tree()calls using Newick strings directly produce + the right DAG. """ dag_obj = pgm.DirectedAcyclicGraph() - cmd_line1 = ('tr <- read_tree(string="((sp1[&index=1]:1.0,sp2[&index=2]:1.0)' - '[&index=4]:1.0,sp3[&index=3]:2.0)[&index=5];", node_name_attr="index")') + cmd_line1 = ('tr <- read_tree(string="((sp1[&index=1]:1.0,sp2' + '[&index=2]:1.0)[&index=4]:1.0,sp3[&index=3]:2.0)' + '[&index=5];", node_name_attr="index")') - stoch_node_name, _, stoch_node_spec = re.split(cmdu.assign_regex, cmd_line1) - cmd.parse_variable_assignment(dag_obj, stoch_node_name, stoch_node_spec, cmd_line1) - a_node_pgm = dag_obj.get_node_dag_by_name("tr") + stoch_node_name, _, stoch_node_spec = \ + re.split(cmdu.assign_regex, cmd_line1) + cmd.parse_variable_assignment(dag_obj, + stoch_node_name, + stoch_node_spec, + cmd_line1) + a_node_dag = dag_obj.get_node_dag_by_name("tr") self.assertEqual(1, dag_obj.n_nodes) - self.assertEqual(type(a_node_pgm.value), list) - self.assertEqual(len(a_node_pgm.value), 1) - self.assertEqual(a_node_pgm.value[0].__str__().rstrip(), + self.assertEqual(type(a_node_dag.value), list) + self.assertEqual(len(a_node_dag.value), 1) + self.assertEqual(a_node_dag.value[0].__str__().rstrip(), ('((nd1:1.0[&index=1],nd2:1.0[&index=2])nd4:1.0' '[&index=4],nd3:2.0[&index=3])nd5[&index=5];')) @@ -110,21 +137,19 @@ def test_var_assignment_read_tree_string(self): stoch_node_name, _, stoch_node_spec = re.split(cmdu.assign_regex, cmd_line2) cmd.parse_variable_assignment(dag_obj, stoch_node_name, stoch_node_spec, cmd_line2) - a_node_pgm = dag_obj.get_node_dag_by_name("tr") - - # print(a_node_pgm.value[0]) + a_node_dag = dag_obj.get_node_dag_by_name("tr") self.assertEqual(1, dag_obj.n_nodes) - self.assertEqual(type(a_node_pgm.value), list) - self.assertEqual(len(a_node_pgm.value), 1) - self.assertEqual(a_node_pgm.value[0].__str__().rstrip(), + self.assertEqual(type(a_node_dag.value), list) + self.assertEqual(len(a_node_dag.value), 1) + self.assertEqual(a_node_dag.value[0].__str__().rstrip(), ('((sp1:1.0,sp2:1.0)nd1:1.0,sp3:2.0)root;')) def test_var_assignment_read_tree_from_file(self): - """ - Test if read_tree() calls using tree files are correctly - evaluated and result in the right probabilistic - graphical model + """Test read_tree() from file. + + Test read_tree()calls using tree files containing Newick + strings directly produce the right DAG. 
""" dag_obj = pgm.DirectedAcyclicGraph() @@ -132,14 +157,18 @@ def test_var_assignment_read_tree_from_file(self): cmd_line1 = ('tr <- read_tree(file_path="examples/trees_maps_files' '/tree_to_read.tre", node_name_attr="index")') - stoch_node_name, _, stoch_node_spec = re.split(cmdu.assign_regex, cmd_line1) - cmd.parse_variable_assignment(dag_obj, stoch_node_name, stoch_node_spec, cmd_line1) - a_node_pgm = dag_obj.get_node_dag_by_name("tr") + stoch_node_name, _, stoch_node_spec = \ + re.split(cmdu.assign_regex, cmd_line1) + cmd.parse_variable_assignment(dag_obj, + stoch_node_name, + stoch_node_spec, + cmd_line1) + a_node_dag = dag_obj.get_node_dag_by_name("tr") self.assertEqual(1, dag_obj.n_nodes) - self.assertEqual(type(a_node_pgm.value), list) - self.assertEqual(len(a_node_pgm.value), 1) - self.assertEqual(a_node_pgm.value[0].__str__().rstrip(), + self.assertEqual(type(a_node_dag.value), list) + self.assertEqual(len(a_node_dag.value), 1) + self.assertEqual(a_node_dag.value[0].__str__().rstrip(), ('((nd1:1.0[&index=1],nd2:1.0[&index=2])nd4:1.0' '[&index=4],nd3:2.0[&index=3])nd5[&index=5];')) @@ -148,20 +177,17 @@ def test_var_assignment_read_tree_from_file(self): stoch_node_name, _, stoch_node_spec = re.split(cmdu.assign_regex, cmd_line2) cmd.parse_variable_assignment(dag_obj, stoch_node_name, stoch_node_spec, cmd_line2) - a_node_pgm = dag_obj.get_node_dag_by_name("tr") + a_node_dag = dag_obj.get_node_dag_by_name("tr") self.assertEqual(1, dag_obj.n_nodes) - self.assertEqual(type(a_node_pgm.value), list) - self.assertEqual(len(a_node_pgm.value), 2) - self.assertEqual(a_node_pgm.value[1].__str__().rstrip(), + self.assertEqual(type(a_node_dag.value), list) + self.assertEqual(len(a_node_dag.value), 2) + self.assertEqual(a_node_dag.value[1].__str__().rstrip(), ('((nd1:0.1[&index=1],nd2:0.1[&index=2])nd4:1.9' '[&index=4],nd3:2.0[&index=3])nd5[&index=5];')) def test_var_assignment_read_tree_string_exceptions(self): - """ - Test if errors with read_tree() calls using newick strings - are handled correctly - """ + """Test read_tree() from string raised exceptions.""" dag_obj = pgm.DirectedAcyclicGraph() @@ -207,10 +233,7 @@ def test_var_assignment_read_tree_string_exceptions(self): expected_exception_message2) def test_var_assignment_read_tree_file_exceptions(self): - """ - Test if errors with read_tree() calls using tree files - are handled correctly - """ + """Test read_tree() from file raised exceptions.""" dag_obj = pgm.DirectedAcyclicGraph() @@ -247,9 +270,9 @@ def test_var_assignment_read_tree_file_exceptions(self): with self.assertRaises(ec.ObjInitInvalidArgError) as exc: cmd.parse_variable_assignment(dag_obj, - stoch_node_name, - stoch_node_spec, - cmd_line2) + stoch_node_name, + stoch_node_spec, + cmd_line2) self.assertEqual(str(exc.exception), expected_exception_message2) diff --git a/tests/readwrite/test_data_dump.py b/tests/readwrite/test_data_dump.py index 5026818..5f8232b 100644 --- a/tests/readwrite/test_data_dump.py +++ b/tests/readwrite/test_data_dump.py @@ -117,41 +117,41 @@ def setUpClass(cls) -> None: # rv cls.bisse_pgm.add_node( - pgm.StochasticNodePGM( + pgm.StochasticNodeDAG( "l0", n_sim, value=l0, sampled_from="Log-normal" ) ) cls.bisse_pgm.add_node( - pgm.StochasticNodePGM( + pgm.StochasticNodeDAG( "mu0", n_sim, value=mu0, sampled_from="Log-normal" ) ) # deterministic - cls.bisse_pgm.add_node(pgm.DeterministicNodePGM("l0r", value=l0rate)) - cls.bisse_pgm.add_node(pgm.DeterministicNodePGM("mu0r", value=mu0rate)) - cls.bisse_pgm.add_node(pgm.DeterministicNodePGM("q01r", 
value=q01rate)) - cls.bisse_pgm.add_node(pgm.DeterministicNodePGM("l1r", value=l1rate)) - cls.bisse_pgm.add_node(pgm.DeterministicNodePGM("mu1r", value=mu1rate)) - cls.bisse_pgm.add_node(pgm.DeterministicNodePGM("q10r", value=q10rate)) - cls.bisse_pgm.add_node(pgm.DeterministicNodePGM("meh", value=meh)) + cls.bisse_pgm.add_node(pgm.DeterministicNodeDAG("l0r", value=l0rate)) + cls.bisse_pgm.add_node(pgm.DeterministicNodeDAG("mu0r", value=mu0rate)) + cls.bisse_pgm.add_node(pgm.DeterministicNodeDAG("q01r", value=q01rate)) + cls.bisse_pgm.add_node(pgm.DeterministicNodeDAG("l1r", value=l1rate)) + cls.bisse_pgm.add_node(pgm.DeterministicNodeDAG("mu1r", value=mu1rate)) + cls.bisse_pgm.add_node(pgm.DeterministicNodeDAG("q10r", value=q10rate)) + cls.bisse_pgm.add_node(pgm.DeterministicNodeDAG("meh", value=meh)) # more rv cls.bisse_pgm.add_node( - pgm.StochasticNodePGM( + pgm.StochasticNodeDAG( "trs", n_sim, value=trs, sampled_from="DnSSE", replicate_size=n_repl ) ) - # sorted_node_pgm_list = bisse_pgm.get_sorted_node_dag_list() + # sorted_node_dag_list = bisse_pgm.get_sorted_node_dag_list() ################### # Output handling # ################### # populating dataframe to be dumped # data_df_names_list, data_df_list = \ - # pjwrite.prep_data_df(sorted_node_pgm_list) + # pjwrite.prep_data_df(sorted_node_dag_list) cls.scalar_output_stash, cls.tree_output_stash = \ pjwrite.prep_data_df(cls.bisse_pgm, write_nex_states=True)
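
For readers skimming the patch, the call pattern that the renamed test code exercises looks roughly like the snippet below. This is a minimal illustrative sketch, not part of the patch: the import paths are assumptions inferred from the pgm, cmd, and cmdu aliases used in the tests, while the calls themselves (DirectedAcyclicGraph, parse_variable_assignment, get_node_dag_by_name, n_nodes, .value) are exactly those shown in the hunks above.

    import re

    # NOTE: module paths below are assumptions based on the test aliases
    # (pgm, cmd, cmdu); adjust to the actual package layout if it differs.
    import phylojunction.pgm.pgm as pgm
    import phylojunction.interface.cmd.cmd_parse as cmd
    import phylojunction.interface.cmd.cmd_parse_utils as cmdu

    dag_obj = pgm.DirectedAcyclicGraph()

    # parse a single PhyloJunction assignment line into a DAG node,
    # exactly as the updated tests do
    cmd_line = "b <- [1, 2, 3]"
    stoch_node_name, _, stoch_node_spec = re.split(cmdu.assign_regex, cmd_line)
    cmd.parse_variable_assignment(dag_obj,
                                  stoch_node_name,
                                  stoch_node_spec,
                                  cmd_line)

    # node values are stored as lists of strings, e.g., ["1", "2", "3"],
    # which is what the assertions above check
    b_node_dag = dag_obj.get_node_dag_by_name("b")
    print(dag_obj.n_nodes, b_node_dag.value)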