diff --git a/pyproject.toml b/pyproject.toml index a423e742..61f7aecc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,3 +28,6 @@ where = ["src"] [tool.pytest.ini_options] pythonpath = [".", "src"] testpaths = ["tests"] + +[tool.black] +line-length = 119 diff --git a/src/vame/__init__.py b/src/vame/__init__.py index 3e43ad94..b3bc56a6 100644 --- a/src/vame/__init__.py +++ b/src/vame/__init__.py @@ -3,6 +3,7 @@ sys.dont_write_bytecode = True from vame.initialize_project import init_new_project +from vame.preprocessing.preprocessing import preprocessing from vame.model import create_trainset from vame.model import train_model from vame.model import evaluate_model diff --git a/src/vame/analysis/community_analysis.py b/src/vame/analysis/community_analysis.py index ca0232ba..60dbcffb 100644 --- a/src/vame/analysis/community_analysis.py +++ b/src/vame/analysis/community_analysis.py @@ -250,12 +250,7 @@ def get_motif_labels( file_labels = np.load( os.path.join( path_to_dir, - str(n_clusters) - + "_" - + segmentation_algorithm - + "_label_" - + session - + ".npy", + str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy", ) ) shape = len(file_labels) @@ -276,12 +271,7 @@ def get_motif_labels( file_labels = np.load( os.path.join( path_to_dir, - str(n_clusters) - + "_" - + segmentation_algorithm - + "_label_" - + session - + ".npy", + str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy", ) )[:min_frames] community_label.extend(file_labels) @@ -390,11 +380,7 @@ def create_cohort_community_bag( add = input("Extend list or add in the end? (ext/end)") if add == "ext": motif_idx = int(input("Which motif number? ")) - list_idx = int( - input( - "At which position in the list? (pythonic indexing starts at 0) " - ) - ) + list_idx = int(input("At which position in the list? (pythonic indexing starts at 0) ")) community_bag[list_idx].append(motif_idx) if add == "end": motif_idx = int(input("Which motif number? ")) @@ -440,9 +426,7 @@ def get_cohort_community_labels( for j in range(len(clust)): find_clust = np.where(motif_labels == clust[j])[0] community_labels[find_clust] = i - community_labels = np.int64( - scipy.signal.medfilt(community_labels, median_filter_size) - ) + community_labels = np.int64(scipy.signal.medfilt(community_labels, median_filter_size)) community_labels_all.append(community_labels) return community_labels_all @@ -468,12 +452,7 @@ def save_cohort_community_labels_per_file( file_labels = np.load( os.path.join( path_to_dir, - str(n_clusters) - + "_" - + segmentation_algorithm - + "_label_" - + session - + ".npy", + str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy", ) ) community_labels = get_cohort_community_labels( @@ -640,9 +619,7 @@ def community( ), cohort_community_bag, ) - with open( - os.path.join(path_to_dir, "hierarchy" + ".pkl"), "wb" - ) as fp: # Pickling + with open(os.path.join(path_to_dir, "hierarchy" + ".pkl"), "wb") as fp: # Pickling pickle.dump(cohort_community_bag, fp) # Added by Luiz - 11/10/2024 @@ -659,9 +636,7 @@ def community( # # Work in Progress - cohort is False else: - raise NotImplementedError( - "Community analysis for cohort=False is not supported yet." - ) + raise NotImplementedError("Community analysis for cohort=False is not supported yet.") # labels = get_labels(cfg, files, model_name, n_clusters, parametrization) # transition_matrices = compute_transition_matrices( # files, diff --git a/src/vame/analysis/generative_functions.py b/src/vame/analysis/generative_functions.py index 6849e5d2..075a6b0e 100644 --- a/src/vame/analysis/generative_functions.py +++ b/src/vame/analysis/generative_functions.py @@ -333,12 +333,7 @@ def generative_model( os.path.join( path_to_file, "", - str(n_clusters) - + "_" - + segmentation_algorithm - + "_label_" - + session - + ".npy", + str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy", ) ) return random_generative_samples_motif( diff --git a/src/vame/analysis/gif_creator.py b/src/vame/analysis/gif_creator.py index 6a864298..d28c4d5c 100644 --- a/src/vame/analysis/gif_creator.py +++ b/src/vame/analysis/gif_creator.py @@ -102,9 +102,7 @@ def create_video( frame = frames[i] ax2.imshow(frame, cmap=cmap_reversed) # ax2.set_title("Motif %d,\n Community: %s" % (lbl, motifs[lbl]), fontsize=10) - fig.savefig( - os.path.join(path_to_file, "gif_frames", session + "gif_%d.png") % i - ) + fig.savefig(os.path.join(path_to_file, "gif_frames", session + "gif_%d.png") % i) def gif( @@ -205,9 +203,7 @@ def gif( random_state=cfg["random_state"], ) - latent_vector = np.load( - os.path.join(path_to_file, "", "latent_vector_" + session + ".npy") - ) + latent_vector = np.load(os.path.join(path_to_file, "", "latent_vector_" + session + ".npy")) num_points = cfg["num_points"] if num_points > latent_vector.shape[0]: @@ -228,12 +224,7 @@ def gif( umap_label = np.load( os.path.join( path_to_file, - str(n_clusters) - + "_" - + segmentation_algorithm - + "_label_" - + session - + ".npy", + str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy", ) ) elif label == "community": diff --git a/src/vame/analysis/pose_segmentation.py b/src/vame/analysis/pose_segmentation.py index 1fce1670..e6520835 100644 --- a/src/vame/analysis/pose_segmentation.py +++ b/src/vame/analysis/pose_segmentation.py @@ -82,15 +82,9 @@ def embedd_latent_vectors( data_sample_np = data[:, i : temp_win + i].T data_sample_np = np.reshape(data_sample_np, (1, temp_win, num_features)) if use_gpu: - h_n = model.encoder( - torch.from_numpy(data_sample_np) - .type("torch.FloatTensor") - .cuda() - ) + h_n = model.encoder(torch.from_numpy(data_sample_np).type("torch.FloatTensor").cuda()) else: - h_n = model.encoder( - torch.from_numpy(data_sample_np).type("torch.FloatTensor").to() - ) + h_n = model.encoder(torch.from_numpy(data_sample_np).type("torch.FloatTensor").to()) mu, _, _ = model.lmbda(h_n) latent_vector_list.append(mu.cpu().data.numpy()) @@ -406,9 +400,7 @@ def segment_session( ) else: - logger.info( - f"\nSegmentation with {n_clusters} k-means clusters already exists for model {model_name}" - ) + logger.info(f"\nSegmentation with {n_clusters} k-means clusters already exists for model {model_name}") if os.path.exists( os.path.join( diff --git a/src/vame/analysis/tree_hierarchy.py b/src/vame/analysis/tree_hierarchy.py index a718e861..652648a4 100644 --- a/src/vame/analysis/tree_hierarchy.py +++ b/src/vame/analysis/tree_hierarchy.py @@ -42,9 +42,7 @@ def hierarchy_pos( raise TypeError("cannot use hierarchy_pos on a graph that is not a tree") if root is None: if isinstance(G, nx.DiGraph): - root = next( - iter(nx.topological_sort(G)) - ) # allows back compatibility with nx version 1.11 + root = next(iter(nx.topological_sort(G))) # allows back compatibility with nx version 1.11 else: root = random.choice(list(G.nodes)) @@ -121,9 +119,7 @@ def merge_func( for i in range(n_clusters): for j in range(n_clusters): try: - cost = motif_norm[i] + motif_norm[j] / np.abs( - transition_matrix[i, j] + transition_matrix[j, i] - ) + cost = motif_norm[i] + motif_norm[j] / np.abs(transition_matrix[i, j] + transition_matrix[j, i]) except ZeroDivisionError: print( "Error: Transition probabilities between motif " diff --git a/src/vame/analysis/umap.py b/src/vame/analysis/umap.py index 2216d43d..1fd67b25 100644 --- a/src/vame/analysis/umap.py +++ b/src/vame/analysis/umap.py @@ -328,12 +328,7 @@ def visualization( os.path.join( path_to_file, "", - str(n_clusters) - + "_" - + segmentation_algorithm - + "_label_" - + session - + ".npy", + str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy", ) ) output_figure = umap_label_vis( diff --git a/src/vame/analysis/videowriter.py b/src/vame/analysis/videowriter.py index 40f78bdf..a03656aa 100644 --- a/src/vame/analysis/videowriter.py +++ b/src/vame/analysis/videowriter.py @@ -71,19 +71,12 @@ def create_cluster_videos( labels = np.load( os.path.join( path_to_file, - str(n_clusters) - + "_" - + segmentation_algorithm - + "_label_" - + session - + ".npy", + str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy", ) ) if flag == "community": if cohort: - logger.info( - "Cohort community videos getting created for " + session + " ..." - ) + logger.info("Cohort community videos getting created for " + session + " ...") labels = np.load( os.path.join( path_to_file, @@ -109,9 +102,7 @@ def create_cluster_videos( ) capture = cv.VideoCapture(video_file_path) if not capture.isOpened(): - raise ValueError( - f"Video capture could not be opened. Ensure the video file is valid.\n {video_file_path}" - ) + raise ValueError(f"Video capture could not be opened. Ensure the video file is valid.\n {video_file_path}") width = capture.get(cv.CAP_PROP_FRAME_WIDTH) height = capture.get(cv.CAP_PROP_FRAME_HEIGHT) fps = 25 # capture.get(cv.CAP_PROP_FPS) diff --git a/src/vame/initialize_project/new.py b/src/vame/initialize_project/new.py index 784d72fa..b6f05328 100644 --- a/src/vame/initialize_project/new.py +++ b/src/vame/initialize_project/new.py @@ -110,20 +110,14 @@ def init_new_project( for i in videos: # Check if it is a folder if os.path.isdir(i): - vids_in_dir = [ - os.path.join(i, vp) for vp in os.listdir(i) if video_type in vp - ] + vids_in_dir = [os.path.join(i, vp) for vp in os.listdir(i) if video_type in vp] vids = vids + vids_in_dir if len(vids_in_dir) == 0: logger.info(f"No videos found in {i}") - logger.info( - f"Perhaps change the video_type, which is currently set to: {video_type}" - ) + logger.info(f"Perhaps change the video_type, which is currently set to: {video_type}") else: videos = vids - logger.info( - f"{len(vids_in_dir)} videos from the directory {i} were added to the project." - ) + logger.info(f"{len(vids_in_dir)} videos from the directory {i} were added to the project.") else: if os.path.isfile(i): vids = vids + [i] @@ -210,9 +204,7 @@ def init_new_project( unique_num_features = list(set(num_features_list)) if len(unique_num_features) > 1: - raise ValueError( - "All pose estimation files must have the same number of features." - ) + raise ValueError("All pose estimation files must have the same number of features.") if config_kwargs is None: config_kwargs = {} diff --git a/src/vame/logging/logger.py b/src/vame/logging/logger.py index 9fd68753..f5fb0b61 100644 --- a/src/vame/logging/logger.py +++ b/src/vame/logging/logger.py @@ -6,8 +6,7 @@ class VameLogger: LOG_FORMAT = ( - "%(asctime)-15s.%(msecs)d %(levelname)-5s --- [%(threadName)s]" - " %(name)-15s : %(lineno)d : %(message)s" + "%(asctime)-15s.%(msecs)d %(levelname)-5s --- [%(threadName)s]" " %(name)-15s : %(lineno)d : %(message)s" ) LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S" @@ -19,9 +18,7 @@ def __init__( ): self.log_level = log_level self.file_handler = None - logging.basicConfig( - level=log_level, format=self.LOG_FORMAT, datefmt=self.LOG_DATE_FORMAT - ) + logging.basicConfig(level=log_level, format=self.LOG_FORMAT, datefmt=self.LOG_DATE_FORMAT) self.logger = logging.getLogger(f"{base_name}") if self.logger.hasHandlers(): self.logger.handlers.clear() @@ -29,9 +26,7 @@ def __init__( self.logger.setLevel(self.log_level) # Stream handler for logging to stdout stream_handler = logging.StreamHandler() - stream_handler.setFormatter( - logging.Formatter(self.LOG_FORMAT, self.LOG_DATE_FORMAT) - ) + stream_handler.setFormatter(logging.Formatter(self.LOG_FORMAT, self.LOG_DATE_FORMAT)) self.logger.addHandler(stream_handler) self.logger.propagate = False @@ -56,9 +51,7 @@ def add_file_handler(self, file_path: str): f.write(f"{line_break}[LOGGING STARTED AT: {handler_datetime}]") self.file_handler = logging.FileHandler(file_path, mode="a") - self.file_handler.setFormatter( - logging.Formatter(self.LOG_FORMAT, self.LOG_DATE_FORMAT) - ) + self.file_handler.setFormatter(logging.Formatter(self.LOG_FORMAT, self.LOG_DATE_FORMAT)) self.logger.addHandler(self.file_handler) def remove_file_handler(self): diff --git a/src/vame/model/create_training.py b/src/vame/model/create_training.py index ea19312f..a9f554c4 100644 --- a/src/vame/model/create_training.py +++ b/src/vame/model/create_training.py @@ -98,9 +98,7 @@ def plot_check_parameter( plt.title("Original signal z-scored") plt.legend() - logger.info( - "Please run the function with check_parameter=False if you are happy with the results" - ) + logger.info("Please run the function with check_parameter=False if you are happy with the results") def traindata_aligned( @@ -170,10 +168,7 @@ def traindata_aligned( if cfg["robust"]: iqr_val = iqr(X_z) - logger.info( - "IQR value: %.2f, IQR cutoff: %.2f" - % (iqr_val, cfg["iqr_factor"] * iqr_val) - ) + logger.info("IQR value: %.2f, IQR cutoff: %.2f" % (iqr_val, cfg["iqr_factor"] * iqr_val)) for i in range(X_z.shape[0]): for marker in range(X_z.shape[1]): if X_z[i, marker] > cfg["iqr_factor"] * iqr_val: @@ -330,10 +325,7 @@ def traindata_fixed( if cfg["robust"]: iqr_val = iqr(X_z) - logger.info( - "IQR value: %.2f, IQR cutoff: %.2f" - % (iqr_val, cfg["iqr_factor"] * iqr_val) - ) + logger.info("IQR value: %.2f, IQR cutoff: %.2f" % (iqr_val, cfg["iqr_factor"] * iqr_val)) for i in range(X_z.shape[0]): for marker in range(X_z.shape[1]): if X_z[i, marker] > cfg["iqr_factor"] * iqr_val: @@ -373,9 +365,7 @@ def traindata_fixed( else: if pose_ref_index is None: - raise ValueError( - "Please provide a pose reference index for training on fixed data. E.g. [0,5]" - ) + raise ValueError("Please provide a pose reference index for training on fixed data. E.g. [0,5]") # save numpy arrays the the test/train info: np.save( os.path.join( @@ -493,15 +483,10 @@ def create_trainset( logger.info("Creating training dataset...") if cfg["robust"]: - logger.info( - "Using robust setting to eliminate outliers! IQR factor: %d" - % cfg["iqr_factor"] - ) + logger.info("Using robust setting to eliminate outliers! IQR factor: %d" % cfg["iqr_factor"]) if not fixed: - logger.info( - "Creating trainset from the vame.egocentrical_alignment() output " - ) + logger.info("Creating trainset from the vame.egocentrical_alignment() output ") traindata_aligned( cfg, sessions, @@ -522,9 +507,7 @@ def create_trainset( ) if not check_parameter: - logger.info( - "A training and test set has been created. Next step: vame.train_model()" - ) + logger.info("A training and test set has been created. Next step: vame.train_model()") except Exception as e: logger.exception(str(e)) diff --git a/src/vame/model/evaluate.py b/src/vame/model/evaluate.py index 5372c922..de2111c4 100644 --- a/src/vame/model/evaluate.py +++ b/src/vame/model/evaluate.py @@ -69,18 +69,10 @@ def plot_reconstruction( x = x.permute(0, 2, 1) if use_gpu: data = x[:, :seq_len_half, :].type("torch.FloatTensor").cuda() - data_fut = ( - x[:, seq_len_half : seq_len_half + FUTURE_STEPS, :] - .type("torch.FloatTensor") - .cuda() - ) + data_fut = x[:, seq_len_half : seq_len_half + FUTURE_STEPS, :].type("torch.FloatTensor").cuda() else: data = x[:, :seq_len_half, :].type("torch.FloatTensor").to() - data_fut = ( - x[:, seq_len_half : seq_len_half + FUTURE_STEPS, :] - .type("torch.FloatTensor") - .to() - ) + data_fut = x[:, seq_len_half : seq_len_half + FUTURE_STEPS, :].type("torch.FloatTensor").to() if FUTURE_DECODER: x_tilde, future, latent, mu, logvar = model(data) @@ -99,9 +91,7 @@ def plot_reconstruction( if FUTURE_DECODER: fig, axs = plt.subplots(2, 5) - fig.suptitle( - "Reconstruction [top] and future prediction [bottom] of input sequence" - ) + fig.suptitle("Reconstruction [top] and future prediction [bottom] of input sequence") for i in range(5): axs[0, i].plot(data_orig[i, ...], color="k", label="Sequence Data") axs[0, i].plot( @@ -129,9 +119,7 @@ def plot_reconstruction( fig.set_tight_layout(True) if not suffix: fig.savefig( - os.path.join( - filepath, "evaluate", "Reconstruction_" + model_name + ".png" - ), + os.path.join(filepath, "evaluate", "Reconstruction_" + model_name + ".png"), bbox_inches="tight", ) elif suffix: @@ -174,12 +162,8 @@ def plot_loss( basepath = os.path.join(cfg["project_path"], "model", "model_losses") train_loss = np.load(os.path.join(basepath, "train_losses_" + model_name + ".npy")) test_loss = np.load(os.path.join(basepath, "test_losses_" + model_name + ".npy")) - mse_loss_train = np.load( - os.path.join(basepath, "mse_train_losses_" + model_name + ".npy") - ) - mse_loss_test = np.load( - os.path.join(basepath, "mse_test_losses_" + model_name + ".npy") - ) + mse_loss_train = np.load(os.path.join(basepath, "mse_train_losses_" + model_name + ".npy")) + mse_loss_test = np.load(os.path.join(basepath, "mse_test_losses_" + model_name + ".npy")) km_losses = np.load(os.path.join(basepath, "kmeans_losses_" + model_name + ".npy")) kl_loss = np.load(os.path.join(basepath, "kl_losses_" + model_name + ".npy")) fut_loss = np.load(os.path.join(basepath, "fut_losses_" + model_name + ".npy")) @@ -196,9 +180,7 @@ def plot_loss( ax1.plot(kl_loss, label="KL-Loss") ax1.plot(fut_loss, label="Prediction-Loss") ax1.legend() - fig.savefig( - os.path.join(filepath, "evaluate", "MSE-and-KL-Loss" + model_name + ".png") - ) + fig.savefig(os.path.join(filepath, "evaluate", "MSE-and-KL-Loss" + model_name + ".png")) def eval_temporal( @@ -308,9 +290,7 @@ def eval_temporal( ) ) elif snapshot: - model.load_state_dict( - torch.load(snapshot), map_location=torch.device("cpu") - ) + model.load_state_dict(torch.load(snapshot), map_location=torch.device("cpu")) model.eval() # toggle evaluation mode testset = SEQUENCE_DATASET( @@ -320,9 +300,7 @@ def eval_temporal( temporal_window=TEMPORAL_WINDOW, logger_config=logger_config, ) - test_loader = Data.DataLoader( - testset, batch_size=TEST_BATCH_SIZE, shuffle=True, drop_last=True - ) + test_loader = Data.DataLoader(testset, batch_size=TEST_BATCH_SIZE, shuffle=True, drop_last=True) if not snapshot: plot_reconstruction( @@ -405,13 +383,9 @@ def evaluate_model( if not use_snapshots: eval_temporal(cfg, use_gpu, model_name, fixed) # suffix=suffix elif use_snapshots: - snapshots = os.listdir( - os.path.join(cfg["project_path"], "model", "best_model", "snapshots") - ) + snapshots = os.listdir(os.path.join(cfg["project_path"], "model", "best_model", "snapshots")) for snap in snapshots: - fullpath = os.path.join( - cfg["project_path"], "model", "best_model", "snapshots", snap - ) + fullpath = os.path.join(cfg["project_path"], "model", "best_model", "snapshots", snap) epoch = snap.split("_")[-1] eval_temporal( cfg, diff --git a/src/vame/model/rnn_vae.py b/src/vame/model/rnn_vae.py index ca0e5b49..cc1c88b6 100644 --- a/src/vame/model/rnn_vae.py +++ b/src/vame/model/rnn_vae.py @@ -175,9 +175,7 @@ def kl_annealing( elif function == "sigmoid": new_weight = float(1 / (1 + np.exp(-0.9 * (epoch - annealtime)))) else: - raise NotImplementedError( - 'currently only "linear" and "sigmoid" are implemented' - ) + raise NotImplementedError('currently only "linear" and "sigmoid" are implemented') return new_weight else: @@ -302,18 +300,10 @@ def train( data_item = data_item.permute(0, 2, 1) if use_gpu: data = data_item[:, :seq_len_half, :].type("torch.FloatTensor").cuda() - fut = ( - data_item[:, seq_len_half : seq_len_half + future_steps, :] - .type("torch.FloatTensor") - .cuda() - ) + fut = data_item[:, seq_len_half : seq_len_half + future_steps, :].type("torch.FloatTensor").cuda() else: data = data_item[:, :seq_len_half, :].type("torch.FloatTensor").to() - fut = ( - data_item[:, seq_len_half : seq_len_half + future_steps, :] - .type("torch.FloatTensor") - .to() - ) + fut = data_item[:, seq_len_half : seq_len_half + future_steps, :].type("torch.FloatTensor").to() if noise is True: data_gaussian = gaussian(data, True, seq_len_half) @@ -327,12 +317,7 @@ def train( kmeans_loss = cluster_loss(latent.T, kloss, klmbda, bsize) kl_loss = kullback_leibler_loss(mu, logvar) kl_weight = kl_annealing(epoch, kl_start, annealtime, anneal_function) - loss = ( - rec_loss - + fut_rec_loss - + BETA * kl_weight * kl_loss - + kl_weight * kmeans_loss - ) + loss = rec_loss + fut_rec_loss + BETA * kl_weight * kl_loss + kl_weight * kmeans_loss fut_loss += fut_rec_loss.item() else: data_tilde, latent, mu, logvar = model(data_gaussian) @@ -536,15 +521,9 @@ def train_model(config: str, save_logs: bool = False) -> None: fixed = cfg["egocentric_data"] logger.info("Train Variational Autoencoder - model name: %s \n" % model_name) - if not os.path.exists( - os.path.join(cfg["project_path"], "model", "best_model", "") - ): + if not os.path.exists(os.path.join(cfg["project_path"], "model", "best_model", "")): os.mkdir(os.path.join(cfg["project_path"], "model", "best_model", "")) - os.mkdir( - os.path.join( - cfg["project_path"], "model", "best_model", "snapshots", "" - ) - ) + os.mkdir(os.path.join(cfg["project_path"], "model", "best_model", "snapshots", "")) os.mkdir(os.path.join(cfg["project_path"], "model", "model_losses", "")) # make sure torch uses cuda for GPU computing @@ -555,9 +534,7 @@ def train_model(config: str, save_logs: bool = False) -> None: logger.info("GPU used: {}".format(torch.cuda.get_device_name(0))) else: torch.device("cpu") - logger.info( - "warning, a GPU was not found... proceeding with CPU (slow!) \n" - ) + logger.info("warning, a GPU was not found... proceeding with CPU (slow!) \n") # raise NotImplementedError('GPU Computing is required!') # HYPERPARAMETERS @@ -687,16 +664,12 @@ def train_model(config: str, save_logs: bool = False) -> None: ) ) try: - logger.info( - "Loading pretrained weights from %s\n" % pretrained_model - ) + logger.info("Loading pretrained weights from %s\n" % pretrained_model) model.load_state_dict(torch.load(pretrained_model)) KL_START = 0 ANNEALTIME = 1 except Exception: - logger.error( - "Could not load pretrained model. Check file path in config.yaml." - ) + logger.error("Could not load pretrained model. Check file path in config.yaml.") """ DATASET """ trainset = SEQUENCE_DATASET( @@ -712,19 +685,14 @@ def train_model(config: str, save_logs: bool = False) -> None: temporal_window=TEMPORAL_WINDOW, ) - train_loader = Data.DataLoader( - trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, drop_last=True - ) - test_loader = Data.DataLoader( - testset, batch_size=TEST_BATCH_SIZE, shuffle=True, drop_last=True - ) + train_loader = Data.DataLoader(trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, drop_last=True) + test_loader = Data.DataLoader(testset, batch_size=TEST_BATCH_SIZE, shuffle=True, drop_last=True) optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, amsgrad=True) if optimizer_scheduler: logger.info( - "Scheduler step size: %d, Scheduler gamma: %.2f\n" - % (scheduler_step_size, cfg["scheduler_gamma"]) + "Scheduler step size: %d, Scheduler gamma: %.2f\n" % (scheduler_step_size, cfg["scheduler_gamma"]) ) # Thanks to @alexcwsmith for the optimized scheduler contribution scheduler = ReduceLROnPlateau( @@ -737,9 +705,7 @@ def train_model(config: str, save_logs: bool = False) -> None: verbose=True, ) else: - scheduler = StepLR( - optimizer, step_size=scheduler_step_size, gamma=1, last_epoch=-1 - ) + scheduler = StepLR(optimizer, step_size=scheduler_step_size, gamma=1, last_epoch=-1) logger.info("Start training... ") for epoch in tqdm( @@ -817,12 +783,7 @@ def train_model(config: str, save_logs: bool = False) -> None: "model", "best_model", "snapshots", - model_name - + "_" - + cfg["project_name"] - + "_epoch_" - + str(epoch) - + ".pkl", + model_name + "_" + cfg["project_name"] + "_epoch_" + str(epoch) + ".pkl", ), ) diff --git a/src/vame/pipeline.py b/src/vame/pipeline.py index 5fb73371..cba1b344 100644 --- a/src/vame/pipeline.py +++ b/src/vame/pipeline.py @@ -64,9 +64,7 @@ def get_raw_datasets(self) -> xr.Dataset: datasets = list() attributes = list() for session in sessions: - ds_path = ( - Path(self.config["project_path"]) / "data" / "raw" / f"{session}.nc" - ) + ds_path = Path(self.config["project_path"]) / "data" / "raw" / f"{session}.nc" ds = load_vame_dataset(ds_path=ds_path) ds = ds.expand_dims({"session": [session]}) datasets.append(ds) @@ -78,9 +76,7 @@ def get_raw_datasets(self) -> xr.Dataset: dss_attrs.setdefault(key, []).append(value) for key, values in dss_attrs.items(): unique_values = unique_in_order(values) # Maintain order of unique values - dss_attrs[key] = ( - unique_values[0] if len(unique_values) == 1 else unique_values - ) + dss_attrs[key] = unique_values[0] if len(unique_values) == 1 else unique_values for key, value in dss_attrs.items(): dss.attrs[key] = value return dss diff --git a/src/vame/preprocessing/align_egocentrical.py b/src/vame/preprocessing/align_egocentrical.py index c2d3b1cb..36ba8240 100644 --- a/src/vame/preprocessing/align_egocentrical.py +++ b/src/vame/preprocessing/align_egocentrical.py @@ -118,9 +118,7 @@ def align_mouse_legacy( pose_list_bordered = [] for i in pose_list: - pose_list_bordered.append( - (int(i[idx][0] + crop_size[0]), int(i[idx][1] + crop_size[1])) - ) + pose_list_bordered.append((int(i[idx][0] + crop_size[0]), int(i[idx][1] + crop_size[1]))) img = cv.copyMakeBorder( frame, @@ -363,10 +361,7 @@ def egocentric_alignment_legacy( # call function and save into your VAME data folder paths_to_pose_nwb_series_data = cfg["paths_to_pose_nwb_series_data"] for i, session in enumerate(sessions): - logger.info( - "Aligning session %s, Pose confidence value: %.2f" - % (session, confidence) - ) + logger.info("Aligning session %s, Pose confidence value: %.2f" % (session, confidence)) egocentric_time_series, frames = alignment_legacy( project_path=project_path, session=session, @@ -405,9 +400,7 @@ def egocentric_alignment_legacy( egocentric_time_series_shifted, ) - logger.info( - "Your data is now in the right format and you can call vame.create_trainset()" - ) + logger.info("Your data is now in the right format and you can call vame.create_trainset()") except Exception as e: logger.exception(f"{e}") raise e @@ -477,10 +470,7 @@ def egocentric_alignment( # call function and save into your VAME data folder for i, session in enumerate(sessions): - logger.info( - "Aligning session %s, Pose confidence value: %.2f" - % (session, confidence) - ) + logger.info("Aligning session %s, Pose confidence value: %.2f" % (session, confidence)) # read out data file_path = str(Path(project_path) / "data" / "raw" / f"{session}.nc") _, data_mat, ds = read_pose_estimation_file(file_path=file_path) @@ -540,9 +530,7 @@ def egocentric_alignment( result_file = Path(project_path) / "data" / "processed" / session / f"{session}-aligned.nc" ds.to_netcdf(result_file, engine="scipy") - logger.info( - "Your data is now in the right format and you can call vame.create_trainset()" - ) + logger.info("Your data is now in the right format and you can call vame.create_trainset()") except Exception as e: logger.exception(f"{e}") raise e @@ -601,9 +589,7 @@ def alignment( pose_list_bordered = [] for i in pose_list: - pose_list_bordered.append( - (int(i[idx][0] + crop_size[0]), int(i[idx][1] + crop_size[1])) - ) + pose_list_bordered.append((int(i[idx][0] + crop_size[0]), int(i[idx][1] + crop_size[1]))) punkte = [] for i in pose_ref_index: diff --git a/src/vame/preprocessing/align_new.py b/src/vame/preprocessing/align_new.py index 94635414..7bce5e91 100644 --- a/src/vame/preprocessing/align_new.py +++ b/src/vame/preprocessing/align_new.py @@ -31,9 +31,7 @@ def align_time_series(data, keypoint1, keypoint2, confidence_threshold): # Loop over individuals for ind in range(positions.shape[1]): - individual_positions = positions[ - :, ind, :, : - ] # Shape: (time, keypoints, space) + individual_positions = positions[:, ind, :, :] # Shape: (time, keypoints, space) individual_confidence = confidence[:, ind, :] # Shape: (time, keypoints) # Replace low-confidence points with NaN @@ -54,9 +52,7 @@ def align_time_series(data, keypoint1, keypoint2, confidence_threshold): ) # Centralize all positions around the first keypoint - centralized_positions = ( - individual_positions - individual_positions[:, idx1, :][:, np.newaxis, :] - ) + centralized_positions = individual_positions - individual_positions[:, idx1, :][:, np.newaxis, :] # Calculate vectors between keypoints vector = centralized_positions[:, idx2, :] # Vector from keypoint1 to keypoint2 diff --git a/src/vame/preprocessing/clean_timeseries.py b/src/vame/preprocessing/clean_timeseries.py new file mode 100644 index 00000000..87351204 --- /dev/null +++ b/src/vame/preprocessing/clean_timeseries.py @@ -0,0 +1,77 @@ +from pathlib import Path +import numpy as np +from scipy.stats import iqr + +from vame.logging.logger import VameLogger +from vame.io.load_poses import load_vame_dataset +from vame.util.data_manipulation import interpolate_nans_with_pandas + + +logger_config = VameLogger(__name__) +logger = logger_config.logger + + +def clean_timeseries( + config: dict, +): + X_all_sessions = [] + pos = [0] + pos_temp = 0 + + session_names = config["session_names"] + for session in session_names: + logger.info("z-scoring of session %s" % session) + + # path_to_file = Path(config["project_path"]) / "data" / "processed" / session / session + "-PE-seq.npy" + # data = np.load(path_to_file) + + path_to_file = Path(config["project_path"]) / "data" / "processed" / session / session + "-aligned.nc" + ds = load_vame_dataset(path_to_file) + X = ds.position_aligned.sel(individuals="individual_0").values + + # Standardize data + X_mean = np.mean(X, axis=0) + X_std = np.std(X, axis=0) + X_z = (X - X_mean) / X_std + + # Robust interquartile range outlier detection + if config["robust"]: + iqr_val = iqr(X_z, axis=0) + logger.info("IQR value: %.2f, IQR cutoff: %.2f" % (iqr_val, config["iqr_factor"] * iqr_val)) + for t in range(X_z.shape[0]): # Iterate over time dimension + for kp in range(X_z.shape[1]): # Iterate over keypoints dimension + for sp in range(X_z.shape[2]): # Iterate over space dimennsion (x, y) + if X_z[t, kp, sp] > config["iqr_factor"] * iqr_val[kp, sp]: + X_z[t, kp, sp] = np.nan + elif X_z[t, kp, sp] < -config["iqr_factor"] * iqr_val[kp, sp]: + X_z[t, kp, sp] = np.nan + X_z = interpolate_nans_with_pandas(X_z) + + X_len = X.shape[0] + pos_temp += X_len + pos.append(pos_temp) + X_all_sessions.append(X_z) + + X_all_sessions = np.concatenate(X_all_sessions, axis=0) + + # Detect and delete anchors + detect_anchors = np.std(X_all_sessions, axis=0) + sort_anchors = np.sort(detect_anchors) + if sort_anchors[0] == sort_anchors[1]: + anchors = np.where(detect_anchors == sort_anchors[0])[0] + anchor_1_temp = anchors[0] + anchor_2_temp = anchors[1] + else: + anchor_1_temp = int(np.where(detect_anchors == sort_anchors[0])[0]) + anchor_2_temp = int(np.where(detect_anchors == sort_anchors[1])[0]) + + if anchor_1_temp > anchor_2_temp: + anchor_1 = anchor_1_temp + anchor_2 = anchor_2_temp + else: + anchor_1 = anchor_2_temp + anchor_2 = anchor_1_temp + + X = np.delete(X, anchor_1, 1) + X = np.delete(X, anchor_2, 1) + X = X.T diff --git a/src/vame/preprocessing/preprocessing.py b/src/vame/preprocessing/preprocessing.py new file mode 100644 index 00000000..fff20f28 --- /dev/null +++ b/src/vame/preprocessing/preprocessing.py @@ -0,0 +1,27 @@ +from pathlib import Path +import xarray as xr + +from vame.logging.logger import VameLogger +from vame.preprocessing.align_egocentrical import ( + egocentric_alignment_legacy, + egocentric_alignment, +) + + +def preprocessing( + config: dict, + pose_ref_1: str = "snout", + pose_ref_2: str = "tailbase", + save_logs: bool = False, +): + + egocentric_alignment( + config=config, + pose_ref_1=pose_ref_1, + pose_ref_2=pose_ref_2, + ) + + clean_timeseries( + config=config, + save_logs=save_logs, + ) diff --git a/src/vame/schemas/project.py b/src/vame/schemas/project.py index 60ef0de3..037f75e1 100644 --- a/src/vame/schemas/project.py +++ b/src/vame/schemas/project.py @@ -29,9 +29,7 @@ class ProjectSchema(BaseModel): title="Project name", ) creation_datetime: str = Field( - default_factory=lambda: datetime.now(timezone.utc).isoformat( - timespec="seconds" - ), + default_factory=lambda: datetime.now(timezone.utc).isoformat(timespec="seconds"), title="Creation datetime", ) model_name: str = Field( diff --git a/src/vame/schemas/states.py b/src/vame/schemas/states.py index aba933a8..00724e41 100644 --- a/src/vame/schemas/states.py +++ b/src/vame/schemas/states.py @@ -84,9 +84,7 @@ class MotifVideosFunctionSchema(BaseStateSchema): title="Type of video", default=".mp4", ) - segmentation_algorithm: SegmentationAlgorithms = Field( - title="Segmentation algorithm" - ) + segmentation_algorithm: SegmentationAlgorithms = Field(title="Segmentation algorithm") output_video_type: str = Field( title="Type of output video", default=".mp4", @@ -95,9 +93,7 @@ class MotifVideosFunctionSchema(BaseStateSchema): class CommunityFunctionSchema(BaseStateSchema): cohort: bool = Field(title="Cohort", default=True) - segmentation_algorithm: SegmentationAlgorithms = Field( - title="Segmentation algorithm" - ) + segmentation_algorithm: SegmentationAlgorithms = Field(title="Segmentation algorithm") cut_tree: int | None = Field( title="Cut tree", default=None, @@ -105,9 +101,7 @@ class CommunityFunctionSchema(BaseStateSchema): class CommunityVideosFunctionSchema(BaseStateSchema): - segmentation_algorithm: SegmentationAlgorithms = Field( - title="Segmentation algorithm" - ) + segmentation_algorithm: SegmentationAlgorithms = Field(title="Segmentation algorithm") cohort: bool = Field(title="Cohort", default=True) video_type: str = Field( title="Type of video", @@ -120,9 +114,7 @@ class CommunityVideosFunctionSchema(BaseStateSchema): class VisualizationFunctionSchema(BaseStateSchema): - segmentation_algorithm: SegmentationAlgorithms = Field( - title="Segmentation algorithm" - ) + segmentation_algorithm: SegmentationAlgorithms = Field(title="Segmentation algorithm") label: Optional[str] = Field( title="Type of labels to visualize", default=None, @@ -130,9 +122,7 @@ class VisualizationFunctionSchema(BaseStateSchema): class GenerativeModelFunctionSchema(BaseStateSchema): - segmentation_algorithm: SegmentationAlgorithms = Field( - title="Segmentation algorithm" - ) + segmentation_algorithm: SegmentationAlgorithms = Field(title="Segmentation algorithm") mode: GenerativeModelModeEnum = Field( title="Mode for generating samples", default=GenerativeModelModeEnum.sampling, diff --git a/src/vame/util/auxiliary.py b/src/vame/util/auxiliary.py index c206818b..560f6293 100644 --- a/src/vame/util/auxiliary.py +++ b/src/vame/util/auxiliary.py @@ -137,10 +137,7 @@ def read_config(configname: str) -> dict: write_config(configname, cfg) except Exception as err: if len(err.args) > 2: - if ( - err.args[2] - == "could not determine a constructor for the tag '!!python/tuple'" - ): + if err.args[2] == "could not determine a constructor for the tag '!!python/tuple'": with open(path, "r") as ymlfile: cfg = yaml.load(ymlfile, Loader=yaml.SafeLoader) write_config(configname, cfg) diff --git a/src/vame/util/cli.py b/src/vame/util/cli.py index e0a09fdf..f7aa5ce7 100644 --- a/src/vame/util/cli.py +++ b/src/vame/util/cli.py @@ -24,6 +24,4 @@ def get_sessions_from_user_input( if user_input in cfg["session_names"]: sessions = [user_input] else: - raise ValueError( - "Invalid input. Please enter yes, no, or a valid session name." - ) + raise ValueError("Invalid input. Please enter yes, no, or a valid session name.") diff --git a/src/vame/util/csv_to_npy.py b/src/vame/util/csv_to_npy.py index 4010c03e..7ce2c382 100644 --- a/src/vame/util/csv_to_npy.py +++ b/src/vame/util/csv_to_npy.py @@ -89,9 +89,7 @@ def pose_to_numpy( i = interpol_first_rows_nans(i) positions = np.concatenate(pose_list, axis=1) - final_positions = np.zeros( - (data_mat.shape[0], int(data_mat.shape[1] / 3) * 2) - ) + final_positions = np.zeros((data_mat.shape[0], int(data_mat.shape[1] / 3) * 2)) jdx = 0 idx = 0 @@ -113,9 +111,7 @@ def pose_to_numpy( ) logger.info("conversion from DeepLabCut csv to numpy complete...") - logger.info( - "Your data is now in right format and you can call vame.create_trainset()" - ) + logger.info("Your data is now in right format and you can call vame.create_trainset()") except Exception as e: logger.exception(f"{e}") raise e diff --git a/src/vame/util/data_manipulation.py b/src/vame/util/data_manipulation.py index 97bba269..563ddcc3 100644 --- a/src/vame/util/data_manipulation.py +++ b/src/vame/util/data_manipulation.py @@ -212,6 +212,32 @@ def interpol_first_rows_nans(arr: np.ndarray) -> np.ndarray: return arr +def interpolate_nans_with_pandas(data: np.ndarray) -> np.ndarray: + """ + Interpolate NaN values along the time axis of a 3D NumPy array using Pandas. + + Parameters: + ----------- + data : numpy.ndarray + Input 3D array of shape (time, keypoints, space). + + Returns: + -------- + numpy.ndarray: + Array with NaN values interpolated. + """ + for kp in range(data.shape[1]): # Loop over keypoints dimension + for sp in range(data.shape[2]): # Loop over space dimension (x, y) + series = pd.Series(data[:, kp, sp]) + series_interpolated = series.interpolate( + method="linear", + limit_direction="both", + axis=0, + ) + data[:, kp, sp] = series_interpolated.values + return data + + def crop_and_flip_legacy( rect: Tuple, src: np.ndarray, @@ -363,9 +389,7 @@ def nc_to_dataframe(nc_data): # Flatten position data position_data = nc_data["position"].isel(individuals=0).values - position_column_names = [ - f"{keypoint}_{sp}" for keypoint in keypoints for sp in space - ] + position_column_names = [f"{keypoint}_{sp}" for keypoint in keypoints for sp in space] position_flattened = position_data.reshape(position_data.shape[0], -1) # Create a DataFrame for position data @@ -383,9 +407,7 @@ def nc_to_dataframe(nc_data): # Reorder columns: keypoint_x, keypoint_y, keypoint_confidence reordered_columns = [] for keypoint in keypoints: - reordered_columns.extend( - [f"{keypoint}_x", f"{keypoint}_y", f"{keypoint}_confidence"] - ) + reordered_columns.extend([f"{keypoint}_x", f"{keypoint}_y", f"{keypoint}_confidence"]) combined_df = combined_df[reordered_columns] diff --git a/src/vame/util/gif_pose_helper.py b/src/vame/util/gif_pose_helper.py index 9c384497..752efcba 100644 --- a/src/vame/util/gif_pose_helper.py +++ b/src/vame/util/gif_pose_helper.py @@ -145,9 +145,7 @@ def get_animal_frames( frame = frame - bg frame[frame <= 0] = 0 except Exception: - logger.info( - f"Couldn't find a frame in capture.read(). #Frame: {idx + start + lag}" - ) + logger.info(f"Couldn't find a frame in capture.read(). #Frame: {idx + start + lag}") continue # Read coordinates and add border diff --git a/src/vame/util/report.py b/src/vame/util/report.py index 1bb36fd3..e13504b5 100644 --- a/src/vame/util/report.py +++ b/src/vame/util/report.py @@ -38,15 +38,10 @@ def report( report_folder.mkdir(exist_ok=True) # Motifs and Communities - if ( - project_states.get("segment_session", {}).get("execution_state", "") - != "success" - ): + if project_states.get("segment_session", {}).get("execution_state", "") != "success": raise Exception("Segmentation failed. Skipping motifs and communities report.") if project_states.get("community", {}).get("execution_state", "") != "success": - raise Exception( - "Community detection failed. Skipping motifs and communities report." - ) + raise Exception("Community detection failed. Skipping motifs and communities report.") ml = np.load( project_path @@ -96,8 +91,7 @@ def report( title=f"Community and Motif Counts - Cohort - {model_name} - {segmentation_algorithm} - {n_clusters}", save_to_file=True, save_path=str( - report_folder - / f"community_motifs_cohort_{model_name}_{segmentation_algorithm}-{n_clusters}.png" + report_folder / f"community_motifs_cohort_{model_name}_{segmentation_algorithm}-{n_clusters}.png" ), ) @@ -141,8 +135,7 @@ def report( title=f"Community and Motif Counts - {session} - {model_name} - {segmentation_algorithm} - {n_clusters}", save_to_file=True, save_path=str( - report_folder - / f"community_motifs_{session}_{model_name}_{segmentation_algorithm}-{n_clusters}.png" + report_folder / f"community_motifs_{session}_{model_name}_{segmentation_algorithm}-{n_clusters}.png" ), ) @@ -165,9 +158,7 @@ def plot_community_motifs( community_indices = [community for community, count in communities] community_counts = [count for community, count in communities] total_community_counts = sum(community_counts) - community_percentages = [ - (count / total_community_counts) * 100 for count in community_counts - ] + community_percentages = [(count / total_community_counts) * 100 for count in community_counts] # Define positions and bar widths bar_width = 0.8 @@ -203,9 +194,7 @@ def plot_community_motifs( ax2 = ax1.twinx() ax2.set_ylim(ax1.get_ylim()) ax2.set_yticks(ax1.get_yticks()) - ax2.set_yticklabels( - [f"{(tick / total_community_counts) * 100:.1f}%" for tick in ax1.get_yticks()] - ) + ax2.set_yticklabels([f"{(tick / total_community_counts) * 100:.1f}%" for tick in ax1.get_yticks()]) ax2.set_ylabel("Percentage") # Overlay motif bars within each community @@ -217,16 +206,12 @@ def plot_community_motifs( motifs_sorted = [motif for motif, count in motif_counts] counts_sorted = [count for motif, count in motif_counts] total_motif_counts = sum(counts_sorted) - motif_percentages = [ - (count / total_motif_counts) * 100 for count in counts_sorted - ] + motif_percentages = [(count / total_motif_counts) * 100 for count in counts_sorted] num_motifs = len(motifs_sorted) # Adjust motif bar width to fill the community bar width if num_motifs > 0: - motif_width = ( - motif_bar_width / num_motifs * 0.9 - ) # Slightly reduce width to create space between bars + motif_width = motif_bar_width / num_motifs * 0.9 # Slightly reduce width to create space between bars else: motif_width = motif_bar_width @@ -240,8 +225,7 @@ def plot_community_motifs( bars = ax1.bar( motif_positions, counts_sorted, - width=motif_width - * 0.9, # Slightly reduce width to create space between bars + width=motif_width * 0.9, # Slightly reduce width to create space between bars label=f"Motifs in Community {community}", ) @@ -254,9 +238,7 @@ def plot_community_motifs( ha="center", va="bottom", fontsize=9, - color=( - "white" if bar.get_facecolor()[0] < 0.5 else "black" - ), # Contrast with bar color + color=("white" if bar.get_facecolor()[0] < 0.5 else "black"), # Contrast with bar color ) # Add percentage values on top of motif bars @@ -268,9 +250,7 @@ def plot_community_motifs( ha="center", va="bottom", fontsize=8, - color=( - "white" if bar.get_facecolor()[0] < 0.5 else "black" - ), # Contrast with bar color + color=("white" if bar.get_facecolor()[0] < 0.5 else "black"), # Contrast with bar color ) # Formatting diff --git a/tests/test_analysis.py b/tests/test_analysis.py index e41206d4..e80502d5 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -27,9 +27,7 @@ def test_pose_segmentation_hmm_files_exists( "individual_segmentation": individual_segmentation, } mock_config["hmm_trained"] = hmm_trained - with patch( - "vame.analysis.pose_segmentation.read_config", return_value=mock_config - ) as mock_read_config: + with patch("vame.analysis.pose_segmentation.read_config", return_value=mock_config) as mock_read_config: with patch("builtins.input", return_value="yes"): vame.segment_session( setup_project_and_train_model["config_path"], @@ -39,13 +37,7 @@ def test_pose_segmentation_hmm_files_exists( file = setup_project_and_train_model["config_data"]["session_names"][0] model_name = setup_project_and_train_model["config_data"]["model_name"] n_clusters = setup_project_and_train_model["config_data"]["n_clusters"] - save_base_path = ( - Path(project_path) - / "results" - / file - / model_name - / f"{segmentation_algorithm}-{n_clusters}" - ) + save_base_path = Path(project_path) / "results" / file / model_name / f"{segmentation_algorithm}-{n_clusters}" latent_vector_path = save_base_path / f"latent_vector_{file}.npy" motif_usage_path = save_base_path / f"motif_usage_{file}.npy" @@ -54,9 +46,7 @@ def test_pose_segmentation_hmm_files_exists( @pytest.mark.parametrize("segmentation_algorithm", ["hmm", "kmeans"]) -def test_motif_videos_mp4_files_exists( - setup_project_and_train_model, segmentation_algorithm -): +def test_motif_videos_mp4_files_exists(setup_project_and_train_model, segmentation_algorithm): vame.motif_videos( setup_project_and_train_model["config_path"], segmentation_algorithm=segmentation_algorithm, @@ -82,9 +72,7 @@ def test_motif_videos_mp4_files_exists( @pytest.mark.parametrize("segmentation_algorithm", ["hmm", "kmeans"]) -def test_motif_videos_avi_files_exists( - setup_project_and_train_model, segmentation_algorithm -): +def test_motif_videos_avi_files_exists(setup_project_and_train_model, segmentation_algorithm): # Check if the files are created vame.motif_videos( setup_project_and_train_model["config_path"], @@ -144,9 +132,7 @@ def test_motif_videos_avi_files_exists( @pytest.mark.parametrize("segmentation_algorithm", ["hmm", "kmeans"]) -def test_cohort_community_files_exists( - setup_project_and_train_model, segmentation_algorithm -): +def test_cohort_community_files_exists(setup_project_and_train_model, segmentation_algorithm): # Check if the files are created vame.community( setup_project_and_train_model["config_path"], @@ -158,17 +144,10 @@ def test_cohort_community_files_exists( project_path = setup_project_and_train_model["config_data"]["project_path"] n_clusters = setup_project_and_train_model["config_data"]["n_clusters"] - base_path = ( - Path(project_path) - / "results" - / "community_cohort" - / f"{segmentation_algorithm}-{n_clusters}" - ) + base_path = Path(project_path) / "results" / "community_cohort" / f"{segmentation_algorithm}-{n_clusters}" cohort_path = base_path / "cohort_transition_matrix.npy" community_path = base_path / "cohort_community_label.npy" - cohort_segmentation_algorithm_path = ( - base_path / f"cohort_{segmentation_algorithm}_label.npy" - ) + cohort_segmentation_algorithm_path = base_path / f"cohort_{segmentation_algorithm}_label.npy" cohort_community_bag_path = base_path / "cohort_community_bag.npy" assert cohort_path.exists() @@ -268,12 +247,7 @@ def test_visualization_output_files( project_path = setup_project_and_train_model["config_data"]["project_path"] save_base_path = ( - Path(project_path) - / "results" - / file - / model_name - / f"{segmentation_algorithm}-{n_clusters}" - / "community" + Path(project_path) / "results" / file / model_name / f"{segmentation_algorithm}-{n_clusters}" / "community" ) assert len(list(save_base_path.glob(f"umap_vis*{file}.png"))) > 0 @@ -316,9 +290,7 @@ def test_report( config=setup_project_and_train_model["config_path"], segmentation_algorithm=segmentation_algorithm, ) - reports_path = ( - Path(setup_project_and_train_model["config_data"]["project_path"]) / "reports" - ) + reports_path = Path(setup_project_and_train_model["config_data"]["project_path"]) / "reports" assert len(list(reports_path.glob("*.png"))) > 0 diff --git a/tests/test_initialize_project.py b/tests/test_initialize_project.py index 0f39cae4..7fea631e 100644 --- a/tests/test_initialize_project.py +++ b/tests/test_initialize_project.py @@ -17,9 +17,7 @@ def test_project_name_config(setup_project_not_aligned_data): """ config = Path(setup_project_not_aligned_data["config_path"]) config_values = read_config(config) - assert ( - config_values["project_name"] == setup_project_not_aligned_data["project_name"] - ) + assert config_values["project_name"] == setup_project_not_aligned_data["project_name"] def test_existing_project(): diff --git a/tests/test_model.py b/tests/test_model.py index f2f827da..5b92489f 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -47,42 +47,14 @@ def test_train_model_losses_files_exists(setup_project_and_train_model): model_name = setup_project_and_train_model["config_data"]["model_name"] # save logged losses - train_losses_path = ( - Path(project_path) / "model" / "model_losses" / f"train_losses_{model_name}.npy" - ) - test_losses_path = ( - Path(project_path) / "model" / "model_losses" / f"test_losses_{model_name}.npy" - ) - kmeans_losses_path = ( - Path(project_path) - / "model" - / "model_losses" - / f"kmeans_losses_{model_name}.npy" - ) - kl_losses_path = ( - Path(project_path) / "model" / "model_losses" / f"kl_losses_{model_name}.npy" - ) - weight_values_path = ( - Path(project_path) - / "model" - / "model_losses" - / f"weight_values_{model_name}.npy" - ) - mse_train_losses_path = ( - Path(project_path) - / "model" - / "model_losses" - / f"mse_train_losses_{model_name}.npy" - ) - mse_test_losses_path = ( - Path(project_path) - / "model" - / "model_losses" - / f"mse_test_losses_{model_name}.npy" - ) - fut_losses_path = ( - Path(project_path) / "model" / "model_losses" / f"fut_losses_{model_name}.npy" - ) + train_losses_path = Path(project_path) / "model" / "model_losses" / f"train_losses_{model_name}.npy" + test_losses_path = Path(project_path) / "model" / "model_losses" / f"test_losses_{model_name}.npy" + kmeans_losses_path = Path(project_path) / "model" / "model_losses" / f"kmeans_losses_{model_name}.npy" + kl_losses_path = Path(project_path) / "model" / "model_losses" / f"kl_losses_{model_name}.npy" + weight_values_path = Path(project_path) / "model" / "model_losses" / f"weight_values_{model_name}.npy" + mse_train_losses_path = Path(project_path) / "model" / "model_losses" / f"mse_train_losses_{model_name}.npy" + mse_test_losses_path = Path(project_path) / "model" / "model_losses" / f"mse_test_losses_{model_name}.npy" + fut_losses_path = Path(project_path) / "model" / "model_losses" / f"fut_losses_{model_name}.npy" assert train_losses_path.exists() assert test_losses_path.exists() @@ -98,9 +70,7 @@ def test_train_model_best_model_file_exists(setup_project_and_train_model): project_path = setup_project_and_train_model["config_data"]["project_path"] model_name = setup_project_and_train_model["config_data"]["model_name"] project_name = setup_project_and_train_model["config_data"]["project_name"] - best_model_path = ( - Path(project_path) / "model" / "best_model" / f"{model_name}_{project_name}.pkl" - ) + best_model_path = Path(project_path) / "model" / "best_model" / f"{model_name}_{project_name}.pkl" assert best_model_path.exists() @@ -108,12 +78,8 @@ def test_train_model_best_model_file_exists(setup_project_and_train_model): def test_evaluate_model_images_exists(setup_project_and_evaluate_model): project_path = setup_project_and_evaluate_model["config_data"]["project_path"] model_name = setup_project_and_evaluate_model["config_data"]["model_name"] - reconstruction_image_path = ( - Path(project_path) / "model" / "evaluate" / "Future_Reconstruction.png" - ) - loss_image_path = ( - Path(project_path) / "model" / "evaluate" / f"MSE-and-KL-Loss{model_name}.png" - ) + reconstruction_image_path = Path(project_path) / "model" / "evaluate" / "Future_Reconstruction.png" + loss_image_path = Path(project_path) / "model" / "evaluate" / f"MSE-and-KL-Loss{model_name}.png" assert reconstruction_image_path.exists() assert loss_image_path.exists() diff --git a/tests/test_util.py b/tests/test_util.py index 28426e53..52f1a769 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -6,12 +6,8 @@ def test_pose_to_numpy_file_exists(setup_project_and_convert_pose_to_numpy): """ Test if the pose-estimation file was converted to a numpy array file. """ - project_path = setup_project_and_convert_pose_to_numpy["config_data"][ - "project_path" - ] - file_name = setup_project_and_convert_pose_to_numpy["config_data"]["session_names"][ - 0 - ] + project_path = setup_project_and_convert_pose_to_numpy["config_data"]["project_path"] + file_name = setup_project_and_convert_pose_to_numpy["config_data"]["session_names"][0] file_path = os.path.join( project_path, "data",