From 2ba2ee77ab3050006899876692a73ccc6299182d Mon Sep 17 00:00:00 2001 From: binho Date: Mon, 22 Jan 2024 11:39:59 -0600 Subject: [PATCH] Add support for random seed --- .../interface/cmdbox/cmd_parse.py | 21 ++- .../interface/cmdbox/cmd_parse.pyi | 4 +- src/phylojunction/interface/pjcli/pj_cli.py | 28 +++- .../interface/pysidegui/pj_gui.py | 40 ++++- .../pysidegui/pjguipages/gui_pages.py | 70 +++++--- .../pysidegui/pjguipages/pjgui_pages.ui | 154 ++++++++++++------ src/phylojunction/pgm/pgm.py | 26 ++- src/phylojunction/pgm/pgm.pyi | 5 + .../utility/exception_classes.py | 10 ++ .../utility/exception_classes.pyi | 5 + 10 files changed, 265 insertions(+), 98 deletions(-) diff --git a/src/phylojunction/interface/cmdbox/cmd_parse.py b/src/phylojunction/interface/cmdbox/cmd_parse.py index 642d0fd..cf1ecdf 100644 --- a/src/phylojunction/interface/cmdbox/cmd_parse.py +++ b/src/phylojunction/interface/cmdbox/cmd_parse.py @@ -20,15 +20,19 @@ def script2dag(script_file_path_or_model_spec: str, - in_pj_file: bool = True) -> pgm.DirectedAcyclicGraph: + in_pj_file: bool = True, + random_seed: ty.Optional[int] = None) -> pgm.DirectedAcyclicGraph: """Go through .pj lines and populate and return DAG. Args: script_file_path_or_model_spec (str): Path to script .pj - file or full string specifying model directly + file or full string specifying model directly. + random_seed (int): Random seed (integer) for simulation. + Defaults to None. Returns: - (DirectedAcyclicGraph): DAG object built from .pj commands. + (DirectedAcyclicGraph): DAG object (model) built from .pj + script commands. """ def _execute_spec_lines( @@ -40,14 +44,13 @@ def _execute_spec_lines( line = line.lstrip().rstrip() _ = cmdline2dag(dag_obj, line) - # handle seed if specified - np.random.seed(seed=123) - random.seed(123) - + # initialize model DAG dag = pgm.DirectedAcyclicGraph() + if random_seed is not None: + dag.random_seed = random_seed - all_lines_list: ty.List[str] = [] - + # get script strings + all_lines_list: ty.List[str] = list() if in_pj_file: with open(script_file_path_or_model_spec, "r") as infile: all_lines_list = infile.readlines() diff --git a/src/phylojunction/interface/cmdbox/cmd_parse.pyi b/src/phylojunction/interface/cmdbox/cmd_parse.pyi index 324bf41..e4c8e58 100644 --- a/src/phylojunction/interface/cmdbox/cmd_parse.pyi +++ b/src/phylojunction/interface/cmdbox/cmd_parse.pyi @@ -1,6 +1,8 @@ +import typing as ty + import phylojunction.pgm.pgm as pgm -def script2dag(script_file_path: str) -> pgm.DirectedAcyclicGraph: ... +def script2dag(script_file_path: str, in_pj_file: bool=True, random_seed: ty.Optional[int]=None) -> pgm.DirectedAcyclicGraph: ... def cmdline2dag(dag_obj: pgm.DirectedAcyclicGraph, cmd_line: str): ... def parse_variable_assignment(dag_obj: pgm.DirectedAcyclicGraph, stoch_node_name: str, stoch_node_spec: str, cmd_line: str) -> None: ... def parse_samp_dn_assignment(dag_obj, stoch_node_name, stoch_node_dn_spec, cmd_line) -> None: ... diff --git a/src/phylojunction/interface/pjcli/pj_cli.py b/src/phylojunction/interface/pjcli/pj_cli.py index 66948c5..61535fc 100644 --- a/src/phylojunction/interface/pjcli/pj_cli.py +++ b/src/phylojunction/interface/pjcli/pj_cli.py @@ -20,20 +20,25 @@ def execute_pj_script( write_data: bool = False, write_figures: str = "", write_inference: bool = False, - write_nex_states: bool = False) -> None: + write_nex_states: bool = False, + a_random_seed: ty.Optional[int] = None) -> None: """ - Execute .pj script + Execute .pj script. This is called by application 'pjcli' from the terminal, but can also be called from a .py script importing phylojunction """ # Reading model # - dag_obj: pgm.DirectedAcyclicGraph = \ - pgm.DirectedAcyclicGraph() + random_seed: int = None + if a_random_seed is not None: + random_seed = int(a_random_seed) print("Reading script " + model) - dag_obj = cmd.script2dag(model, in_pj_file=True) + dag_obj: pgm.DirectedAcyclicGraph \ + = cmd.script2dag(model, + in_pj_file=True, + random_seed=random_seed) print(" ... done!") n_samples = dag_obj.sample_size @@ -135,8 +140,8 @@ def call_cli() -> None: dest="out_dir", type=str, default="./", - help=("Path to project root directory, where automatic " - "subdirectories will be created")) + help=("Path to project root directory, where subdirectories " + "will be automatically created")) parser.add_argument( "-p", "--prefix", @@ -150,6 +155,12 @@ def call_cli() -> None: action="store_true", default=False, help="Toggle states nexus file output") + parser.add_argument( + "-r", "--random-seed", + dest="random_seed", + action="store", + default=None, + help="Random seed (integer)") args = parser.parse_args() @@ -160,7 +171,8 @@ def call_cli() -> None: write_data=args.write_data, write_figures=args.write_figures, write_inference=args.write_inference, - write_nex_states=args.write_nex_states + write_nex_states=args.write_nex_states, + a_random_seed=args.random_seed ) # if one wants to run pj_cli.py for some reason diff --git a/src/phylojunction/interface/pysidegui/pj_gui.py b/src/phylojunction/interface/pysidegui/pj_gui.py index 53877e5..0fe3650 100644 --- a/src/phylojunction/interface/pysidegui/pj_gui.py +++ b/src/phylojunction/interface/pysidegui/pj_gui.py @@ -34,11 +34,15 @@ def __init__(self): self.dag_obj = DirectedAcyclicGraph() self.cmd_log_list = [] + def update_dag_random_seed(self, a_random_seed: int): + self.dag_obj.random_seed = a_random_seed + def parse_cmd_update_pgm( self, cmd_line_list, gui_main_window_obj, clear_cmd_log_list: bool = False): + # in case we read two scripts # consecutively, we want to clear # the comand line history list @@ -53,6 +57,17 @@ def parse_cmd_update_pgm( line = line.strip() print(" " + line) + # will set random seed once if DAG does not already have one + if not isinstance(self.dag_obj.random_seed, int): + a_random_seed: str = \ + gui_main_window_obj.ui.ui_pages. \ + random_seed_prefix_textbox.toPlainText() + + # random seed is not None and not empty string + if a_random_seed: + random_seed: int = int(a_random_seed) + self.dag_obj.random_seed = random_seed + # side-effect in cmdline2dag try: valid_cmd_line = cmdp.cmdline2dag(self.dag_obj, line) @@ -317,7 +332,7 @@ def show_compare_page(self): def show_coverage_page(self): self.reset_selection() self.ui.pages.setCurrentWidget(self.ui.ui_pages.coverage_page) - self.ui.compare_button.set_active(True) + self.ui.covg_button.set_active(True) self.ui.top_label_left.setText("COVERAGE VALIDATION") def show_cmd_log_page(self): @@ -704,6 +719,14 @@ def read_execute_script(self): # reset everything # self.clean_disable_everything() + # set seed + a_random_seed = self.ui.ui_pages.random_seed_prefix_textbox.toPlainText() + # is not None and is not an empty string + if a_random_seed: + random_seed: int = int(a_random_seed) + self.gui_modeling.update_dag_random_seed(random_seed) + + # read all command line strings cmd_line_list = pjread.read_text_file(script_fp) # (side-effect: gui_modeling stores cmd hist) # @@ -800,12 +823,17 @@ def read_coverage_hpd_csv(self): def load_model(self): # read file path # - model_fp, filter = QFileDialog.getOpenFileName(parent=self, caption="Load model", dir=".", filter="*.pickle") + model_fp, filter = QFileDialog.getOpenFileName( + parent=self, + caption="Load model", + dir=".", + filter="*.pickle") if model_fp: self.clean_disable_everything() self.ui.ui_pages.cmd_log_textbox.clear() - self.gui_modeling.dag_obj, self.gui_modeling.cmd_log_list = pjread.read_serialized_pgm(model_fp) + self.gui_modeling.dag_obj, self.gui_modeling.cmd_log_list = \ + pjread.read_serialized_pgm(model_fp) self.refresh_node_lists() self.ui.ui_pages.cmd_log_textbox.setText(self.gui_modeling.cmd_log()) @@ -1059,10 +1087,10 @@ def _prepare_for_tree(potential_repl: bool = False): self.ui.ui_pages.sample_idx_spin.setMinimum(1) self.ui.ui_pages.sample_idx_spin.setValue(1) + self.ui.ui_pages.repl_idx_spin.setMinimum(1) + self.ui.ui_pages.repl_idx_spin.setValue(1) if potential_repl: self.ui.ui_pages.repl_idx_spin.setEnabled(True) - self.ui.ui_pages.repl_idx_spin.setMinimum(1) - self.ui.ui_pages.repl_idx_spin.setValue(1) def _prepare_for_scalar(potential_repl: bool = False): # radio # @@ -1109,7 +1137,7 @@ def _nothing_to_spin_through(): # (basically: non-deterministic nodes) if isinstance(node_dag.value, list): if isinstance(node_dag.value[0], pjdt.AnnotatedTree): - if node_dag.repl_size > 1: + if node_dag.repl_size >= 2: _prepare_for_tree(potential_repl=True) else: diff --git a/src/phylojunction/interface/pysidegui/pjguipages/gui_pages.py b/src/phylojunction/interface/pysidegui/pjguipages/gui_pages.py index 64d2e96..553a451 100644 --- a/src/phylojunction/interface/pysidegui/pjguipages/gui_pages.py +++ b/src/phylojunction/interface/pysidegui/pjguipages/gui_pages.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- -# Form implementation generated from reading ui file 'src/phylojunction/interface/pysidegui/pjguipages/pjgui_pages.ui' +# Form implementation generated from reading ui file 'pjgui_pages.ui' # -# Created by: PyQt5 UI code generator 5.15.7 +# Created by: PyQt5 UI code generator 5.15.9 # # WARNING: Any manual changes made to this file will be lost when pyuic5 is # run again. Do not edit this file unless you know what you are doing. @@ -25,26 +25,45 @@ def setupUi(self, PJGUIPages): self.settings_page_grid_layout = QtWidgets.QGridLayout(self.gridLayoutWidget_4) self.settings_page_grid_layout.setContentsMargins(0, 0, 0, 0) self.settings_page_grid_layout.setObjectName("settings_page_grid_layout") - self.label = QtWidgets.QLabel(self.gridLayoutWidget_4) - self.label.setMinimumSize(QtCore.QSize(0, 24)) - self.label.setMaximumSize(QtCore.QSize(16777215, 24)) - self.label.setObjectName("label") - self.settings_page_grid_layout.addWidget(self.label, 0, 0, 1, 1) - self.filename_prefix_hor_layout = QtWidgets.QHBoxLayout() - self.filename_prefix_hor_layout.setObjectName("filename_prefix_hor_layout") self.filename_prefix_label = QtWidgets.QLabel(self.gridLayoutWidget_4) self.filename_prefix_label.setMinimumSize(QtCore.QSize(110, 24)) self.filename_prefix_label.setMaximumSize(QtCore.QSize(110, 24)) self.filename_prefix_label.setObjectName("filename_prefix_label") - self.filename_prefix_hor_layout.addWidget(self.filename_prefix_label, 0, QtCore.Qt.AlignTop) + self.settings_page_grid_layout.addWidget(self.filename_prefix_label, 1, 0, 1, 1) + self.random_seed_prefix_label = QtWidgets.QLabel(self.gridLayoutWidget_4) + self.random_seed_prefix_label.setMinimumSize(QtCore.QSize(110, 24)) + self.random_seed_prefix_label.setMaximumSize(QtCore.QSize(110, 24)) + self.random_seed_prefix_label.setObjectName("random_seed_prefix_label") + self.settings_page_grid_layout.addWidget(self.random_seed_prefix_label, 2, 0, 1, 1) + self.settings_label = QtWidgets.QLabel(self.gridLayoutWidget_4) + self.settings_label.setMinimumSize(QtCore.QSize(0, 24)) + self.settings_label.setMaximumSize(QtCore.QSize(16777215, 24)) + font = QtGui.QFont() + font.setBold(True) + self.settings_label.setFont(font) + self.settings_label.setObjectName("settings_label") + self.settings_page_grid_layout.addWidget(self.settings_label, 0, 0, 1, 1) + spacerItem = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding) + self.settings_page_grid_layout.addItem(spacerItem, 3, 0, 1, 1) self.filename_prefix_textbox = QtWidgets.QTextEdit(self.gridLayoutWidget_4) self.filename_prefix_textbox.setMinimumSize(QtCore.QSize(0, 24)) self.filename_prefix_textbox.setMaximumSize(QtCore.QSize(16777215, 24)) self.filename_prefix_textbox.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff) self.filename_prefix_textbox.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff) self.filename_prefix_textbox.setObjectName("filename_prefix_textbox") - self.filename_prefix_hor_layout.addWidget(self.filename_prefix_textbox, 0, QtCore.Qt.AlignTop) - self.settings_page_grid_layout.addLayout(self.filename_prefix_hor_layout, 1, 0, 1, 1) + self.settings_page_grid_layout.addWidget(self.filename_prefix_textbox, 1, 1, 1, 1) + self.random_seed_prefix_textbox = QtWidgets.QTextEdit(self.gridLayoutWidget_4) + self.random_seed_prefix_textbox.setMinimumSize(QtCore.QSize(0, 24)) + self.random_seed_prefix_textbox.setMaximumSize(QtCore.QSize(16777215, 24)) + self.random_seed_prefix_textbox.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff) + self.random_seed_prefix_textbox.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff) + self.random_seed_prefix_textbox.setObjectName("random_seed_prefix_textbox") + self.settings_page_grid_layout.addWidget(self.random_seed_prefix_textbox, 2, 1, 1, 1) + self.line = QtWidgets.QFrame(self.gridLayoutWidget_4) + self.line.setFrameShape(QtWidgets.QFrame.HLine) + self.line.setFrameShadow(QtWidgets.QFrame.Sunken) + self.line.setObjectName("line") + self.settings_page_grid_layout.addWidget(self.line, 0, 1, 1, 1) PJGUIPages.addWidget(self.settings_page) self.pgm_page = QtWidgets.QWidget() self.pgm_page.setMinimumSize(QtCore.QSize(980, 700)) @@ -108,7 +127,7 @@ def setupUi(self, PJGUIPages): " color: #ec4a8a;\n" "}") icon = QtGui.QIcon() - icon.addPixmap(QtGui.QPixmap("src/phylojunction/interface/pysidegui/pjguipages/../images/icons/icon_clear.svg"), QtGui.QIcon.Normal, QtGui.QIcon.Off) + icon.addPixmap(QtGui.QPixmap("../images/icons/icon_clear.svg"), QtGui.QIcon.Normal, QtGui.QIcon.Off) self.clear_model.setIcon(icon) self.clear_model.setIconSize(QtCore.QSize(20, 20)) self.clear_model.setObjectName("clear_model") @@ -131,16 +150,16 @@ def setupUi(self, PJGUIPages): self.all_samples_radio.setAutoExclusive(True) self.all_samples_radio.setObjectName("all_samples_radio") self.radio_spin_hor_layout.addWidget(self.all_samples_radio) - spacerItem = QtWidgets.QSpacerItem(20, 20, QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Minimum) - self.radio_spin_hor_layout.addItem(spacerItem) + spacerItem1 = QtWidgets.QSpacerItem(20, 20, QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Minimum) + self.radio_spin_hor_layout.addItem(spacerItem1) self.sample_idx_spin = QtWidgets.QSpinBox(self.gridLayoutWidget) self.sample_idx_spin.setEnabled(False) self.sample_idx_spin.setStyleSheet("color: black;") self.sample_idx_spin.setAccelerated(True) self.sample_idx_spin.setObjectName("sample_idx_spin") self.radio_spin_hor_layout.addWidget(self.sample_idx_spin) - spacerItem1 = QtWidgets.QSpacerItem(20, 20, QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Minimum) - self.radio_spin_hor_layout.addItem(spacerItem1) + spacerItem2 = QtWidgets.QSpacerItem(20, 20, QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Minimum) + self.radio_spin_hor_layout.addItem(spacerItem2) self.repl_idx_spin = QtWidgets.QSpinBox(self.gridLayoutWidget) self.repl_idx_spin.setEnabled(False) self.repl_idx_spin.setStyleSheet("color: black;\n" @@ -149,8 +168,8 @@ def setupUi(self, PJGUIPages): self.repl_idx_spin.setAccelerated(True) self.repl_idx_spin.setObjectName("repl_idx_spin") self.radio_spin_hor_layout.addWidget(self.repl_idx_spin) - spacerItem2 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum) - self.radio_spin_hor_layout.addItem(spacerItem2) + spacerItem3 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum) + self.radio_spin_hor_layout.addItem(spacerItem3) self.save_pgm_node_plot = QtWidgets.QPushButton(self.gridLayoutWidget) self.save_pgm_node_plot.setEnabled(True) self.save_pgm_node_plot.setMinimumSize(QtCore.QSize(130, 24)) @@ -290,8 +309,8 @@ def setupUi(self, PJGUIPages): self.avg_replicate_check_button.setCursor(QtGui.QCursor(QtCore.Qt.PointingHandCursor)) self.avg_replicate_check_button.setObjectName("avg_replicate_check_button") self.node_stat_vert_layout.addWidget(self.compare_node_frame) - spacerItem3 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding) - self.node_stat_vert_layout.addItem(spacerItem3) + spacerItem4 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding) + self.node_stat_vert_layout.addItem(spacerItem4) self.compare_stats_label = QtWidgets.QLabel(self.gridLayoutWidget_2) self.compare_stats_label.setMinimumSize(QtCore.QSize(0, 16)) self.compare_stats_label.setMaximumSize(QtCore.QSize(16777215, 16)) @@ -521,8 +540,8 @@ def setupUi(self, PJGUIPages): self.coverage_node_label.setAlignment(QtCore.Qt.AlignCenter) self.coverage_node_label.setObjectName("coverage_node_label") self.cov_node_stat_vert_layout.addWidget(self.coverage_frame, 0, QtCore.Qt.AlignHCenter) - spacerItem4 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding) - self.cov_node_stat_vert_layout.addItem(spacerItem4) + spacerItem5 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding) + self.cov_node_stat_vert_layout.addItem(spacerItem5) self.compare_stats_label_2 = QtWidgets.QLabel(self.gridLayoutWidget_3) self.compare_stats_label_2.setMinimumSize(QtCore.QSize(0, 16)) self.compare_stats_label_2.setMaximumSize(QtCore.QSize(16777215, 16)) @@ -586,14 +605,15 @@ def setupUi(self, PJGUIPages): self.retranslateUi(PJGUIPages) PJGUIPages.setCurrentIndex(1) - self.node_content_tabs.setCurrentIndex(0) + self.node_content_tabs.setCurrentIndex(1) QtCore.QMetaObject.connectSlotsByName(PJGUIPages) def retranslateUi(self, PJGUIPages): _translate = QtCore.QCoreApplication.translate PJGUIPages.setWindowTitle(_translate("PJGUIPages", "StackedWidget")) - self.label.setText(_translate("PJGUIPages", "Output configuration")) self.filename_prefix_label.setText(_translate("PJGUIPages", "File name prefix:")) + self.random_seed_prefix_label.setText(_translate("PJGUIPages", "Random seed:")) + self.settings_label.setText(_translate("PJGUIPages", "Output configuration")) self.model_label.setText(_translate("PJGUIPages", "Model nodes")) self.clear_model.setText(_translate("PJGUIPages", " Clear model")) self.one_sample_radio.setText(_translate("PJGUIPages", "One sample")) diff --git a/src/phylojunction/interface/pysidegui/pjguipages/pjgui_pages.ui b/src/phylojunction/interface/pysidegui/pjguipages/pjgui_pages.ui index 4b47adb..46cb412 100644 --- a/src/phylojunction/interface/pysidegui/pjguipages/pjgui_pages.ui +++ b/src/phylojunction/interface/pysidegui/pjguipages/pjgui_pages.ui @@ -39,8 +39,46 @@ + + + + + 110 + 24 + + + + + 110 + 24 + + + + File name prefix: + + + + + + + + 110 + 24 + + + + + 110 + 24 + + + + Random seed: + + + - + 0 @@ -53,55 +91,79 @@ 24 + + + true + + Output configuration - - - - - - - 110 - 24 - - - - - 110 - 24 - - - - File name prefix: - - - - - - - - 0 - 24 - - - - - 16777215 - 24 - - - - Qt::ScrollBarAlwaysOff - - - Qt::ScrollBarAlwaysOff - - - - + + + + Qt::Vertical + + + + 20 + 40 + + + + + + + + + 0 + 24 + + + + + 16777215 + 24 + + + + Qt::ScrollBarAlwaysOff + + + Qt::ScrollBarAlwaysOff + + + + + + + + 0 + 24 + + + + + 16777215 + 24 + + + + Qt::ScrollBarAlwaysOff + + + Qt::ScrollBarAlwaysOff + + + + + + + Qt::Horizontal + + @@ -542,7 +604,7 @@ font: 14pt ; QTabWidget::Rounded - 0 + 1 diff --git a/src/phylojunction/pgm/pgm.py b/src/phylojunction/pgm/pgm.py index df08030..cc99138 100644 --- a/src/phylojunction/pgm/pgm.py +++ b/src/phylojunction/pgm/pgm.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing as ty import numpy as np +import random import matplotlib.pyplot as plt # type: ignore import matplotlib.ticker as mticker # type: ignore import statistics as stat # type: ignore @@ -50,16 +51,35 @@ class DirectedAcyclicGraph(): name_node_dict: ty.Dict[str, NodeDAG] n_nodes: int sample_size: int + _random_seed: int def __init__(self) -> None: # keys are proper DAG nodes, values are their values self.node_val_dict = dict() - # keys are DAG node names, vals are NodeDAG instances self.name_node_dict = dict() - self.n_nodes = 0 self.sample_size = 0 # how many simulations will be run + self._random_seed = None + + @property + def random_seed(self) -> int: + return self._random_seed + + @random_seed.setter + def random_seed(self, a_seed) -> None: + # handle seed if not None, not empty string, and is integer + if a_seed and isinstance(a_seed, int): + self._random_seed = a_seed + + # now we execute the seed (for two random number generators)! + np.random.seed(seed=a_seed) + random.seed(a_seed) + + else: + raise ec.DAGCannotInitialize( + "seed was 'None'. It must be an integer.") + def add_node(self, node_dag: NodeDAG) -> None: # check that nodes carry the right number of values @@ -358,7 +378,7 @@ def get_start2end_str(self, return "\n".join( str(v) for v in self._value[start:end]) - # not a tree + # tree, and repl_idx is always 0 else: return str(self._value[start + repl_idx]) diff --git a/src/phylojunction/pgm/pgm.pyi b/src/phylojunction/pgm/pgm.pyi index 1d33e7d..5e3c295 100644 --- a/src/phylojunction/pgm/pgm.pyi +++ b/src/phylojunction/pgm/pgm.pyi @@ -15,7 +15,12 @@ class DirectedAcyclicGraph: name_node_dict: ty.Dict[str, NodeDAG] n_nodes: int sample_size: int + _random_seed: int def __init__(self) -> None: ... + @property + def random_seed(self) -> int: ... + @random_seed.setter + def random_seed(self, a_seed: int) -> None: ... def add_node(self, node_dag: NodeDAG) -> None: ... def get_node_dag_by_name(self, node_name): ... def get_display_str_by_name(self, node_name, sample_idx=None, repl_size=1): ... diff --git a/src/phylojunction/utility/exception_classes.py b/src/phylojunction/utility/exception_classes.py index d57cca9..c68585c 100644 --- a/src/phylojunction/utility/exception_classes.py +++ b/src/phylojunction/utility/exception_classes.py @@ -701,6 +701,16 @@ def __str__(self) -> str: return self.message +class DAGCannotInitialize(Exception): + message: str + + def __init__(self, message: str) -> None: + self.message = "ERROR: When initializing DAG object, " + message + + def __str__(self) -> str: + return self.message + + class DAGCannotAddNodeError(Exception): message: str diff --git a/src/phylojunction/utility/exception_classes.pyi b/src/phylojunction/utility/exception_classes.pyi index 4be2f0e..24754f2 100644 --- a/src/phylojunction/utility/exception_classes.pyi +++ b/src/phylojunction/utility/exception_classes.pyi @@ -219,6 +219,11 @@ class NodeDAGNodeStatCantFloatError(Exception): def __init__(self, node_name: str) -> None: ... def __str__(self) -> str: ... +class DAGCannotInitialize(Exception): + message: str + def __init__(self, message: str) -> None: ... + def __str__(self) -> str: ... + class DAGCannotAddNodeError(Exception): message: str def __init__(self, node_name: str, message: str) -> None: ...