Skip to content

Commit

Permalink
improve handling of user input for not all sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
luiztauffer committed Nov 15, 2024
1 parent d322bc3 commit 0839bc1
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 152 deletions.
33 changes: 11 additions & 22 deletions src/vame/analysis/community_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from typing import List, Tuple

from vame.util.auxiliary import read_config
from vame.analysis.tree_hierarchy import (
graph_to_tree,
draw_tree,
bag_nodes_by_cutline,
)
from vame.util.data_manipulation import consecutive
from typing import List, Tuple
from vame.util.cli import get_sessions_from_user_input
from vame.schemas.states import save_state, CommunityFunctionSchema
from vame.logging.logger import VameLogger
from vame.schemas.project import SegmentationAlgorithms
Expand Down Expand Up @@ -558,29 +560,16 @@ def community(
model_name = cfg["model_name"]
n_clusters = cfg["n_clusters"]

sessions = []
if cfg["all_data"] == "No":
all_flag = input(
"Do you want to write motif videos for your entire dataset? \n"
"If you only want to use a specific dataset type filename: \n"
"yes/no/filename "
)
# Get sessions
if cfg["all_data"] in ["Yes", "yes"]:
sessions = cfg["session_names"]
else:
all_flag = "yes"

if all_flag == "yes" or all_flag == "Yes":
for session in cfg["session_names"]:
sessions.append(session)
elif all_flag == "no" or all_flag == "No":
for session in cfg["session_names"]:
use_session = input("Do you want to quantify " + session + "? yes/no: ")
if use_session == "yes":
sessions.append(session)
if use_session == "no":
continue
# else:
# files.append(all_flag)
sessions = get_sessions_from_user_input(

Check warning on line 567 in src/vame/analysis/community_analysis.py

View check run for this annotation

Codecov / codecov/patch

src/vame/analysis/community_analysis.py#L567

Added line #L567 was not covered by tests
cfg=cfg,
action_message="run community analysis",
)

# Run community analysis for cohort=True
if cohort:
path_to_dir = Path(
os.path.join(
Expand Down
32 changes: 9 additions & 23 deletions src/vame/analysis/generative_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from sklearn.mixture import GaussianMixture

from vame.schemas.states import GenerativeModelFunctionSchema, save_state
from vame.util.auxiliary import read_config
from vame.logging.logger import VameLogger
from vame.util.auxiliary import read_config
from vame.util.model_util import load_model
from vame.util.cli import get_sessions_from_user_input
from vame.schemas.project import SegmentationAlgorithms


Expand Down Expand Up @@ -257,29 +258,14 @@ def generative_model(
model_name = cfg["model_name"]
n_clusters = cfg["n_clusters"]

sessions = []
if cfg["all_data"] == "No":
all_flag = input(
"Do you want to write motif videos for your entire dataset? \n"
"If you only want to use a specific dataset type filename: \n"
"yes/no/filename "
)
# Get sessions
if cfg["all_data"] in ["Yes", "yes"]:
sessions = cfg["session_names"]
else:
all_flag = "yes"

if all_flag == "yes" or all_flag == "Yes":
for session in cfg["session_names"]:
sessions.append(session)

elif all_flag == "no" or all_flag == "No":
for session in cfg["session_names"]:
use_session = input("Do you want to quantify " + session + "? yes/no: ")
if use_session == "yes":
sessions.append(session)
if use_session == "no":
continue
else:
sessions.append(all_flag)
sessions = get_sessions_from_user_input(

Check warning on line 265 in src/vame/analysis/generative_functions.py

View check run for this annotation

Codecov / codecov/patch

src/vame/analysis/generative_functions.py#L265

Added line #L265 was not covered by tests
cfg=cfg,
action_message="generate samples",
)

model = load_model(cfg, model_name, fixed=False)

Expand Down
33 changes: 10 additions & 23 deletions src/vame/analysis/gif_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
from typing import List, Tuple

from vame.util.auxiliary import read_config
from vame.util.gif_pose_helper import get_animal_frames
from typing import List, Tuple
from vame.util.cli import get_sessions_from_user_input
from vame.logging.logger import VameLogger
from vame.schemas.project import SegmentationAlgorithms

Expand Down Expand Up @@ -154,29 +156,14 @@ def gif(
if segmentation_algorithm not in cfg["segmentation_algorithms"]:
raise ValueError("Segmentation algorithm not found in config")

Check warning on line 157 in src/vame/analysis/gif_creator.py

View check run for this annotation

Codecov / codecov/patch

src/vame/analysis/gif_creator.py#L157

Added line #L157 was not covered by tests

sessions = []
if cfg["all_data"] == "No":
all_flag = input(
"Do you want to write motif videos for your entire dataset? \n"
"If you only want to use a specific dataset type filename: \n"
"yes/no/filename "
)
else:
all_flag = "yes"

if all_flag == "yes" or all_flag == "Yes":
for session in cfg["session_names"]:
sessions.append(session)

elif all_flag == "no" or all_flag == "No":
for session in cfg["session_names"]:
use_session = input("Do you want to quantify " + session + "? yes/no: ")
if use_session == "yes":
sessions.append(session)
if use_session == "no":
continue
# Get sessions
if cfg["all_data"] in ["Yes", "yes"]:
sessions = cfg["session_names"]
else:
sessions.append(all_flag)
sessions = get_sessions_from_user_input(

Check warning on line 163 in src/vame/analysis/gif_creator.py

View check run for this annotation

Codecov / codecov/patch

src/vame/analysis/gif_creator.py#L163

Added line #L163 was not covered by tests
cfg=cfg,
action_message="create gifs",
)

for session in sessions:
path_to_file = os.path.join(
Expand Down
33 changes: 10 additions & 23 deletions src/vame/analysis/pose_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import numpy as np
from pathlib import Path
from typing import List, Tuple, Union
from vame.util.data_manipulation import consecutive
from hmmlearn import hmm
from sklearn.cluster import KMeans

from vame.schemas.states import save_state, SegmentSessionFunctionSchema
from vame.logging.logger import VameLogger, TqdmToLogger
from vame.util.auxiliary import read_config
from vame.model.rnn_model import RNN_VAE
from vame.util.auxiliary import read_config
# from vame.util.data_manipulation import consecutive
from vame.util.cli import get_sessions_from_user_input
from vame.util.model_util import load_model


Expand Down Expand Up @@ -336,29 +338,14 @@ def segment_session(
)
)

sessions = []
if cfg["all_data"] == "No":
all_flag = input(
"Do you want to quantify your entire dataset? \n"
"If you only want to use a specific dataset type the session name: \n"
"yes/no/session_name "
)
else:
all_flag = "yes"

if all_flag == "yes" or all_flag == "Yes":
# Get sessions
if cfg["all_data"] in ["Yes", "yes"]:
sessions = cfg["session_names"]
elif all_flag == "no" or all_flag == "No":
for session in cfg["session_names"]:
use_session = input(
"Do you want to quantify " + session + "? yes/no: "
)
if use_session == "yes":
sessions.append(session)
if use_session == "no":
continue
else:
sessions.append(all_flag)
sessions = get_sessions_from_user_input(

Check warning on line 345 in src/vame/analysis/pose_segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/vame/analysis/pose_segmentation.py#L345

Added line #L345 was not covered by tests
cfg=cfg,
action_message="run segmentation",
)

use_gpu = torch.cuda.is_available()
if use_gpu:
Expand Down
27 changes: 8 additions & 19 deletions src/vame/analysis/umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from pathlib import Path
import matplotlib.pyplot as plt
from typing import Optional, Union

from vame.util.auxiliary import read_config
from vame.util.cli import get_sessions_from_user_input
from vame.schemas.states import VisualizationFunctionSchema, save_state
from vame.logging.logger import VameLogger
from vame.schemas.project import SegmentationAlgorithms
Expand Down Expand Up @@ -264,27 +266,14 @@ def visualization(
model_name = cfg["model_name"]
n_clusters = cfg["n_clusters"]

sessions = []
if cfg["all_data"] == "No":
all_flag = input(
"Do you want to write motif videos for your entire dataset? \n"
"If you only want to use a specific dataset, type the session name: \n"
"yes/no/session_name "
)
else:
all_flag = "yes"

if all_flag == "yes" or all_flag == "Yes":
# Get sessions
if cfg["all_data"] in ["Yes", "yes"]:
sessions = cfg["session_names"]
elif all_flag == "no" or all_flag == "No":
for session in cfg["session_names"]:
use_session = input("Do you want to quantify " + session + "? yes/no: ")
if use_session == "yes":
sessions.append(session)
if use_session == "no":
continue
else:
sessions.append(all_flag)
sessions = get_sessions_from_user_input(

Check warning on line 273 in src/vame/analysis/umap.py

View check run for this annotation

Codecov / codecov/patch

src/vame/analysis/umap.py#L273

Added line #L273 was not covered by tests
cfg=cfg,
action_message="generate visualization",
)

for idx, session in enumerate(sessions):
path_to_file = os.path.join(
Expand Down
59 changes: 17 additions & 42 deletions src/vame/analysis/videowriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
import cv2 as cv
import tqdm
from typing import Union
import imageio

from vame.util.auxiliary import read_config
from vame.util.cli import get_sessions_from_user_input
from vame.schemas.states import (
save_state,
MotifVideosFunctionSchema,
CommunityVideosFunctionSchema,
)
import imageio
from vame.logging.logger import VameLogger, TqdmToLogger
from vame.schemas.project import SegmentationAlgorithms

Expand Down Expand Up @@ -214,29 +216,15 @@ def motif_videos(
n_clusters = cfg["n_clusters"]

logger.info(f"Creating motif videos for algorithm: {segmentation_algorithm}...")
sessions = []
if cfg["all_data"] == "No":
all_flag = input(
"Do you want to write motif videos for your entire dataset? \n"
"If you only want to use a specific dataset, type the session name: \n"
"yes/no/session_name"
)
else:
all_flag = "yes"

if all_flag == "yes" or all_flag == "Yes":
for session in cfg["session_names"]:
sessions.append(session)

elif all_flag == "no" or all_flag == "No":
for session in cfg["session_names"]:
use_session = input("Do you want to quantify " + session + "? yes/no: ")
if use_session == "yes":
sessions.append(session)
if use_session == "no":
continue

# Get sessions
if cfg["all_data"] in ["Yes", "yes"]:
sessions = cfg["session_names"]
else:
sessions.append(all_flag)
sessions = get_sessions_from_user_input(

Check warning on line 224 in src/vame/analysis/videowriter.py

View check run for this annotation

Codecov / codecov/patch

src/vame/analysis/videowriter.py#L224

Added line #L224 was not covered by tests
cfg=cfg,
action_message="write motif videos",
)

logger.info("Cluster size is: %d " % n_clusters)
for session in sessions:
Expand Down Expand Up @@ -329,27 +317,14 @@ def community_videos(
model_name = cfg["model_name"]
n_clusters = cfg["n_clusters"]

sessions = []
if cfg["all_data"] == "No":
all_flag = input(
"Do you want to write motif videos for your entire dataset? \n"
"If you only want to use a specific dataset, type the session name: \n"
"yes/no/session_name"
)
else:
all_flag = "yes"

if all_flag == "yes" or all_flag == "Yes":
# Get sessions
if cfg["all_data"] in ["Yes", "yes"]:
sessions = cfg["session_names"]
elif all_flag == "no" or all_flag == "No":
for session in cfg["session_names"]:
use_session = input("Do you want to quantify " + session + "? yes/no: ")
if use_session == "yes":
sessions.append(session)
if use_session == "no":
continue
else:
sessions.append(all_flag)
sessions = get_sessions_from_user_input(

Check warning on line 324 in src/vame/analysis/videowriter.py

View check run for this annotation

Codecov / codecov/patch

src/vame/analysis/videowriter.py#L324

Added line #L324 was not covered by tests
cfg=cfg,
action_message="write community videos",
)

logger.info("Cluster size is: %d " % n_clusters)
for session in sessions:
Expand Down
29 changes: 29 additions & 0 deletions src/vame/util/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import List


def get_sessions_from_user_input(
cfg: dict,
action_message: str = "run this step",
) -> List[str]:
user_input = input(

Check warning on line 8 in src/vame/util/cli.py

View check run for this annotation

Codecov / codecov/patch

src/vame/util/cli.py#L8

Added line #L8 was not covered by tests
f"Do you want to {action_message} for your entire dataset? \n"
"If you only want to use a specific session, type the session name \n"
"yes/no/<session_name>: "
)
if user_input in ["Yes", "yes"]:
sessions = cfg["session_names"]
elif user_input in ["No", "no"]:
sessions = []
for session in cfg["session_names"]:
use_session = input("Do you want to use " + session + "? yes/no: ")
if use_session in ["Yes", "yes"]:
sessions.append(session)

Check warning on line 20 in src/vame/util/cli.py

View check run for this annotation

Codecov / codecov/patch

src/vame/util/cli.py#L13-L20

Added lines #L13 - L20 were not covered by tests
else:
continue

Check warning on line 22 in src/vame/util/cli.py

View check run for this annotation

Codecov / codecov/patch

src/vame/util/cli.py#L22

Added line #L22 was not covered by tests
else:
if user_input in cfg["session_names"]:
sessions = [user_input]

Check warning on line 25 in src/vame/util/cli.py

View check run for this annotation

Codecov / codecov/patch

src/vame/util/cli.py#L24-L25

Added lines #L24 - L25 were not covered by tests
else:
raise ValueError(

Check warning on line 27 in src/vame/util/cli.py

View check run for this annotation

Codecov / codecov/patch

src/vame/util/cli.py#L27

Added line #L27 was not covered by tests
"Invalid input. Please enter yes, no, or a valid session name."
)

0 comments on commit 0839bc1

Please sign in to comment.