Skip to content

Commit

Permalink
Add support for random seed
Browse files Browse the repository at this point in the history
  • Loading branch information
binho authored and binho committed Jan 22, 2024
1 parent 677061f commit 2ba2ee7
Show file tree
Hide file tree
Showing 10 changed files with 265 additions and 98 deletions.
21 changes: 12 additions & 9 deletions src/phylojunction/interface/cmdbox/cmd_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,19 @@


def script2dag(script_file_path_or_model_spec: str,
in_pj_file: bool = True) -> pgm.DirectedAcyclicGraph:
in_pj_file: bool = True,
random_seed: ty.Optional[int] = None) -> pgm.DirectedAcyclicGraph:
"""Go through .pj lines and populate and return DAG.
Args:
script_file_path_or_model_spec (str): Path to script .pj
file or full string specifying model directly
file or full string specifying model directly.
random_seed (int): Random seed (integer) for simulation.
Defaults to None.
Returns:
(DirectedAcyclicGraph): DAG object built from .pj commands.
(DirectedAcyclicGraph): DAG object (model) built from .pj
script commands.
"""

def _execute_spec_lines(
Expand All @@ -40,14 +44,13 @@ def _execute_spec_lines(
line = line.lstrip().rstrip()
_ = cmdline2dag(dag_obj, line)

# handle seed if specified
np.random.seed(seed=123)
random.seed(123)

# initialize model DAG
dag = pgm.DirectedAcyclicGraph()
if random_seed is not None:
dag.random_seed = random_seed

all_lines_list: ty.List[str] = []

# get script strings
all_lines_list: ty.List[str] = list()
if in_pj_file:
with open(script_file_path_or_model_spec, "r") as infile:
all_lines_list = infile.readlines()
Expand Down
4 changes: 3 additions & 1 deletion src/phylojunction/interface/cmdbox/cmd_parse.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import typing as ty

import phylojunction.pgm.pgm as pgm

def script2dag(script_file_path: str) -> pgm.DirectedAcyclicGraph: ...
def script2dag(script_file_path: str, in_pj_file: bool=True, random_seed: ty.Optional[int]=None) -> pgm.DirectedAcyclicGraph: ...
def cmdline2dag(dag_obj: pgm.DirectedAcyclicGraph, cmd_line: str): ...
def parse_variable_assignment(dag_obj: pgm.DirectedAcyclicGraph, stoch_node_name: str, stoch_node_spec: str, cmd_line: str) -> None: ...
def parse_samp_dn_assignment(dag_obj, stoch_node_name, stoch_node_dn_spec, cmd_line) -> None: ...
Expand Down
28 changes: 20 additions & 8 deletions src/phylojunction/interface/pjcli/pj_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,25 @@ def execute_pj_script(
write_data: bool = False,
write_figures: str = "",
write_inference: bool = False,
write_nex_states: bool = False) -> None:
write_nex_states: bool = False,
a_random_seed: ty.Optional[int] = None) -> None:
"""
Execute .pj script
Execute .pj script.
This is called by application 'pjcli' from the terminal,
but can also be called from a .py script importing phylojunction
"""

# Reading model #
dag_obj: pgm.DirectedAcyclicGraph = \
pgm.DirectedAcyclicGraph()
random_seed: int = None
if a_random_seed is not None:
random_seed = int(a_random_seed)

print("Reading script " + model)
dag_obj = cmd.script2dag(model, in_pj_file=True)
dag_obj: pgm.DirectedAcyclicGraph \
= cmd.script2dag(model,
in_pj_file=True,
random_seed=random_seed)
print(" ... done!")

n_samples = dag_obj.sample_size
Expand Down Expand Up @@ -135,8 +140,8 @@ def call_cli() -> None:
dest="out_dir",
type=str,
default="./",
help=("Path to project root directory, where automatic "
"subdirectories will be created"))
help=("Path to project root directory, where subdirectories "
"will be automatically created"))
parser.add_argument(
"-p",
"--prefix",
Expand All @@ -150,6 +155,12 @@ def call_cli() -> None:
action="store_true",
default=False,
help="Toggle states nexus file output")
parser.add_argument(
"-r", "--random-seed",
dest="random_seed",
action="store",
default=None,
help="Random seed (integer)")

args = parser.parse_args()

Expand All @@ -160,7 +171,8 @@ def call_cli() -> None:
write_data=args.write_data,
write_figures=args.write_figures,
write_inference=args.write_inference,
write_nex_states=args.write_nex_states
write_nex_states=args.write_nex_states,
a_random_seed=args.random_seed
)

# if one wants to run pj_cli.py for some reason
Expand Down
40 changes: 34 additions & 6 deletions src/phylojunction/interface/pysidegui/pj_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,15 @@ def __init__(self):
self.dag_obj = DirectedAcyclicGraph()
self.cmd_log_list = []

def update_dag_random_seed(self, a_random_seed: int):
self.dag_obj.random_seed = a_random_seed

def parse_cmd_update_pgm(
self,
cmd_line_list,
gui_main_window_obj,
clear_cmd_log_list: bool = False):

# in case we read two scripts
# consecutively, we want to clear
# the comand line history list
Expand All @@ -53,6 +57,17 @@ def parse_cmd_update_pgm(
line = line.strip()
print(" " + line)

# will set random seed once if DAG does not already have one
if not isinstance(self.dag_obj.random_seed, int):
a_random_seed: str = \
gui_main_window_obj.ui.ui_pages. \
random_seed_prefix_textbox.toPlainText()

# random seed is not None and not empty string
if a_random_seed:
random_seed: int = int(a_random_seed)
self.dag_obj.random_seed = random_seed

# side-effect in cmdline2dag
try:
valid_cmd_line = cmdp.cmdline2dag(self.dag_obj, line)
Expand Down Expand Up @@ -317,7 +332,7 @@ def show_compare_page(self):
def show_coverage_page(self):
self.reset_selection()
self.ui.pages.setCurrentWidget(self.ui.ui_pages.coverage_page)
self.ui.compare_button.set_active(True)
self.ui.covg_button.set_active(True)
self.ui.top_label_left.setText("COVERAGE VALIDATION")

def show_cmd_log_page(self):
Expand Down Expand Up @@ -704,6 +719,14 @@ def read_execute_script(self):
# reset everything #
self.clean_disable_everything()

# set seed
a_random_seed = self.ui.ui_pages.random_seed_prefix_textbox.toPlainText()
# is not None and is not an empty string
if a_random_seed:
random_seed: int = int(a_random_seed)
self.gui_modeling.update_dag_random_seed(random_seed)

# read all command line strings
cmd_line_list = pjread.read_text_file(script_fp)

# (side-effect: gui_modeling stores cmd hist) #
Expand Down Expand Up @@ -800,12 +823,17 @@ def read_coverage_hpd_csv(self):

def load_model(self):
# read file path #
model_fp, filter = QFileDialog.getOpenFileName(parent=self, caption="Load model", dir=".", filter="*.pickle")
model_fp, filter = QFileDialog.getOpenFileName(
parent=self,
caption="Load model",
dir=".",
filter="*.pickle")

if model_fp:
self.clean_disable_everything()
self.ui.ui_pages.cmd_log_textbox.clear()
self.gui_modeling.dag_obj, self.gui_modeling.cmd_log_list = pjread.read_serialized_pgm(model_fp)
self.gui_modeling.dag_obj, self.gui_modeling.cmd_log_list = \
pjread.read_serialized_pgm(model_fp)
self.refresh_node_lists()
self.ui.ui_pages.cmd_log_textbox.setText(self.gui_modeling.cmd_log())

Expand Down Expand Up @@ -1059,10 +1087,10 @@ def _prepare_for_tree(potential_repl: bool = False):
self.ui.ui_pages.sample_idx_spin.setMinimum(1)
self.ui.ui_pages.sample_idx_spin.setValue(1)

self.ui.ui_pages.repl_idx_spin.setMinimum(1)
self.ui.ui_pages.repl_idx_spin.setValue(1)
if potential_repl:
self.ui.ui_pages.repl_idx_spin.setEnabled(True)
self.ui.ui_pages.repl_idx_spin.setMinimum(1)
self.ui.ui_pages.repl_idx_spin.setValue(1)

def _prepare_for_scalar(potential_repl: bool = False):
# radio #
Expand Down Expand Up @@ -1109,7 +1137,7 @@ def _nothing_to_spin_through():
# (basically: non-deterministic nodes)
if isinstance(node_dag.value, list):
if isinstance(node_dag.value[0], pjdt.AnnotatedTree):
if node_dag.repl_size > 1:
if node_dag.repl_size >= 2:
_prepare_for_tree(potential_repl=True)

else:
Expand Down
70 changes: 45 additions & 25 deletions src/phylojunction/interface/pysidegui/pjguipages/gui_pages.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-

# Form implementation generated from reading ui file 'src/phylojunction/interface/pysidegui/pjguipages/pjgui_pages.ui'
# Form implementation generated from reading ui file 'pjgui_pages.ui'
#
# Created by: PyQt5 UI code generator 5.15.7
# Created by: PyQt5 UI code generator 5.15.9
#
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
# run again. Do not edit this file unless you know what you are doing.
Expand All @@ -25,26 +25,45 @@ def setupUi(self, PJGUIPages):
self.settings_page_grid_layout = QtWidgets.QGridLayout(self.gridLayoutWidget_4)
self.settings_page_grid_layout.setContentsMargins(0, 0, 0, 0)
self.settings_page_grid_layout.setObjectName("settings_page_grid_layout")
self.label = QtWidgets.QLabel(self.gridLayoutWidget_4)
self.label.setMinimumSize(QtCore.QSize(0, 24))
self.label.setMaximumSize(QtCore.QSize(16777215, 24))
self.label.setObjectName("label")
self.settings_page_grid_layout.addWidget(self.label, 0, 0, 1, 1)
self.filename_prefix_hor_layout = QtWidgets.QHBoxLayout()
self.filename_prefix_hor_layout.setObjectName("filename_prefix_hor_layout")
self.filename_prefix_label = QtWidgets.QLabel(self.gridLayoutWidget_4)
self.filename_prefix_label.setMinimumSize(QtCore.QSize(110, 24))
self.filename_prefix_label.setMaximumSize(QtCore.QSize(110, 24))
self.filename_prefix_label.setObjectName("filename_prefix_label")
self.filename_prefix_hor_layout.addWidget(self.filename_prefix_label, 0, QtCore.Qt.AlignTop)
self.settings_page_grid_layout.addWidget(self.filename_prefix_label, 1, 0, 1, 1)
self.random_seed_prefix_label = QtWidgets.QLabel(self.gridLayoutWidget_4)
self.random_seed_prefix_label.setMinimumSize(QtCore.QSize(110, 24))
self.random_seed_prefix_label.setMaximumSize(QtCore.QSize(110, 24))
self.random_seed_prefix_label.setObjectName("random_seed_prefix_label")
self.settings_page_grid_layout.addWidget(self.random_seed_prefix_label, 2, 0, 1, 1)
self.settings_label = QtWidgets.QLabel(self.gridLayoutWidget_4)
self.settings_label.setMinimumSize(QtCore.QSize(0, 24))
self.settings_label.setMaximumSize(QtCore.QSize(16777215, 24))
font = QtGui.QFont()
font.setBold(True)
self.settings_label.setFont(font)
self.settings_label.setObjectName("settings_label")
self.settings_page_grid_layout.addWidget(self.settings_label, 0, 0, 1, 1)
spacerItem = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding)
self.settings_page_grid_layout.addItem(spacerItem, 3, 0, 1, 1)
self.filename_prefix_textbox = QtWidgets.QTextEdit(self.gridLayoutWidget_4)
self.filename_prefix_textbox.setMinimumSize(QtCore.QSize(0, 24))
self.filename_prefix_textbox.setMaximumSize(QtCore.QSize(16777215, 24))
self.filename_prefix_textbox.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
self.filename_prefix_textbox.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
self.filename_prefix_textbox.setObjectName("filename_prefix_textbox")
self.filename_prefix_hor_layout.addWidget(self.filename_prefix_textbox, 0, QtCore.Qt.AlignTop)
self.settings_page_grid_layout.addLayout(self.filename_prefix_hor_layout, 1, 0, 1, 1)
self.settings_page_grid_layout.addWidget(self.filename_prefix_textbox, 1, 1, 1, 1)
self.random_seed_prefix_textbox = QtWidgets.QTextEdit(self.gridLayoutWidget_4)
self.random_seed_prefix_textbox.setMinimumSize(QtCore.QSize(0, 24))
self.random_seed_prefix_textbox.setMaximumSize(QtCore.QSize(16777215, 24))
self.random_seed_prefix_textbox.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
self.random_seed_prefix_textbox.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
self.random_seed_prefix_textbox.setObjectName("random_seed_prefix_textbox")
self.settings_page_grid_layout.addWidget(self.random_seed_prefix_textbox, 2, 1, 1, 1)
self.line = QtWidgets.QFrame(self.gridLayoutWidget_4)
self.line.setFrameShape(QtWidgets.QFrame.HLine)
self.line.setFrameShadow(QtWidgets.QFrame.Sunken)
self.line.setObjectName("line")
self.settings_page_grid_layout.addWidget(self.line, 0, 1, 1, 1)
PJGUIPages.addWidget(self.settings_page)
self.pgm_page = QtWidgets.QWidget()
self.pgm_page.setMinimumSize(QtCore.QSize(980, 700))
Expand Down Expand Up @@ -108,7 +127,7 @@ def setupUi(self, PJGUIPages):
" color: #ec4a8a;\n"
"}")
icon = QtGui.QIcon()
icon.addPixmap(QtGui.QPixmap("src/phylojunction/interface/pysidegui/pjguipages/../images/icons/icon_clear.svg"), QtGui.QIcon.Normal, QtGui.QIcon.Off)
icon.addPixmap(QtGui.QPixmap("../images/icons/icon_clear.svg"), QtGui.QIcon.Normal, QtGui.QIcon.Off)
self.clear_model.setIcon(icon)
self.clear_model.setIconSize(QtCore.QSize(20, 20))
self.clear_model.setObjectName("clear_model")
Expand All @@ -131,16 +150,16 @@ def setupUi(self, PJGUIPages):
self.all_samples_radio.setAutoExclusive(True)
self.all_samples_radio.setObjectName("all_samples_radio")
self.radio_spin_hor_layout.addWidget(self.all_samples_radio)
spacerItem = QtWidgets.QSpacerItem(20, 20, QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Minimum)
self.radio_spin_hor_layout.addItem(spacerItem)
spacerItem1 = QtWidgets.QSpacerItem(20, 20, QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Minimum)
self.radio_spin_hor_layout.addItem(spacerItem1)
self.sample_idx_spin = QtWidgets.QSpinBox(self.gridLayoutWidget)
self.sample_idx_spin.setEnabled(False)
self.sample_idx_spin.setStyleSheet("color: black;")
self.sample_idx_spin.setAccelerated(True)
self.sample_idx_spin.setObjectName("sample_idx_spin")
self.radio_spin_hor_layout.addWidget(self.sample_idx_spin)
spacerItem1 = QtWidgets.QSpacerItem(20, 20, QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Minimum)
self.radio_spin_hor_layout.addItem(spacerItem1)
spacerItem2 = QtWidgets.QSpacerItem(20, 20, QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Minimum)
self.radio_spin_hor_layout.addItem(spacerItem2)
self.repl_idx_spin = QtWidgets.QSpinBox(self.gridLayoutWidget)
self.repl_idx_spin.setEnabled(False)
self.repl_idx_spin.setStyleSheet("color: black;\n"
Expand All @@ -149,8 +168,8 @@ def setupUi(self, PJGUIPages):
self.repl_idx_spin.setAccelerated(True)
self.repl_idx_spin.setObjectName("repl_idx_spin")
self.radio_spin_hor_layout.addWidget(self.repl_idx_spin)
spacerItem2 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
self.radio_spin_hor_layout.addItem(spacerItem2)
spacerItem3 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
self.radio_spin_hor_layout.addItem(spacerItem3)
self.save_pgm_node_plot = QtWidgets.QPushButton(self.gridLayoutWidget)
self.save_pgm_node_plot.setEnabled(True)
self.save_pgm_node_plot.setMinimumSize(QtCore.QSize(130, 24))
Expand Down Expand Up @@ -290,8 +309,8 @@ def setupUi(self, PJGUIPages):
self.avg_replicate_check_button.setCursor(QtGui.QCursor(QtCore.Qt.PointingHandCursor))
self.avg_replicate_check_button.setObjectName("avg_replicate_check_button")
self.node_stat_vert_layout.addWidget(self.compare_node_frame)
spacerItem3 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding)
self.node_stat_vert_layout.addItem(spacerItem3)
spacerItem4 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding)
self.node_stat_vert_layout.addItem(spacerItem4)
self.compare_stats_label = QtWidgets.QLabel(self.gridLayoutWidget_2)
self.compare_stats_label.setMinimumSize(QtCore.QSize(0, 16))
self.compare_stats_label.setMaximumSize(QtCore.QSize(16777215, 16))
Expand Down Expand Up @@ -521,8 +540,8 @@ def setupUi(self, PJGUIPages):
self.coverage_node_label.setAlignment(QtCore.Qt.AlignCenter)
self.coverage_node_label.setObjectName("coverage_node_label")
self.cov_node_stat_vert_layout.addWidget(self.coverage_frame, 0, QtCore.Qt.AlignHCenter)
spacerItem4 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding)
self.cov_node_stat_vert_layout.addItem(spacerItem4)
spacerItem5 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding)
self.cov_node_stat_vert_layout.addItem(spacerItem5)
self.compare_stats_label_2 = QtWidgets.QLabel(self.gridLayoutWidget_3)
self.compare_stats_label_2.setMinimumSize(QtCore.QSize(0, 16))
self.compare_stats_label_2.setMaximumSize(QtCore.QSize(16777215, 16))
Expand Down Expand Up @@ -586,14 +605,15 @@ def setupUi(self, PJGUIPages):

self.retranslateUi(PJGUIPages)
PJGUIPages.setCurrentIndex(1)
self.node_content_tabs.setCurrentIndex(0)
self.node_content_tabs.setCurrentIndex(1)
QtCore.QMetaObject.connectSlotsByName(PJGUIPages)

def retranslateUi(self, PJGUIPages):
_translate = QtCore.QCoreApplication.translate
PJGUIPages.setWindowTitle(_translate("PJGUIPages", "StackedWidget"))
self.label.setText(_translate("PJGUIPages", "Output configuration"))
self.filename_prefix_label.setText(_translate("PJGUIPages", "File name prefix:"))
self.random_seed_prefix_label.setText(_translate("PJGUIPages", "Random seed:"))
self.settings_label.setText(_translate("PJGUIPages", "Output configuration"))
self.model_label.setText(_translate("PJGUIPages", "Model nodes"))
self.clear_model.setText(_translate("PJGUIPages", " Clear model"))
self.one_sample_radio.setText(_translate("PJGUIPages", "One sample"))
Expand Down
Loading

0 comments on commit 2ba2ee7

Please sign in to comment.