From f5daf2c6774ab633d02a0d940c2691a795676604 Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 21 Nov 2024 14:25:06 +0100 Subject: [PATCH] netCDF4 --- requirements.txt | 3 ++- src/vame/analysis/generative_functions.py | 30 ++++++++++++++++------- src/vame/io/load_poses.py | 2 +- src/vame/util/align_egocentrical.py | 28 ++++++++++++--------- tests/test_analysis.py | 8 +++--- 5 files changed, 44 insertions(+), 27 deletions(-) diff --git a/requirements.txt b/requirements.txt index 4afa2cbe..8faa951f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,4 +15,5 @@ imageio==2.34.1 imageio-ffmpeg==0.5.1 pynwb==2.8.1 tables>=3.10.0 -movement==0.0.20 \ No newline at end of file +movement==0.0.20 +netCDF4>=1.7.2 \ No newline at end of file diff --git a/src/vame/analysis/generative_functions.py b/src/vame/analysis/generative_functions.py index 574a02e2..6849e5d2 100644 --- a/src/vame/analysis/generative_functions.py +++ b/src/vame/analysis/generative_functions.py @@ -281,7 +281,10 @@ def generative_model( if mode == "sampling": latent_vector = np.load( - os.path.join(path_to_file, "latent_vector_" + session + ".npy") + os.path.join( + path_to_file, + "latent_vector_" + session + ".npy", + ) ) return random_generative_samples( cfg, @@ -291,7 +294,10 @@ def generative_model( if mode == "reconstruction": latent_vector = np.load( - os.path.join(path_to_file, "latent_vector_" + session + ".npy") + os.path.join( + path_to_file, + "latent_vector_" + session + ".npy", + ) ) return random_reconstruction_samples( cfg, @@ -305,7 +311,10 @@ def generative_model( f"Algorithm {segmentation_algorithm} not supported for cluster center visualization." ) cluster_center = np.load( - os.path.join(path_to_file, "cluster_center_" + session + ".npy") + os.path.join( + path_to_file, + "cluster_center_" + session + ".npy", + ) ) return visualize_cluster_center( cfg, @@ -315,7 +324,10 @@ def generative_model( if mode == "motifs": latent_vector = np.load( - os.path.join(path_to_file, "latent_vector_" + session + ".npy") + os.path.join( + path_to_file, + "latent_vector_" + session + ".npy", + ) ) labels = np.load( os.path.join( @@ -330,11 +342,11 @@ def generative_model( ) ) return random_generative_samples_motif( - cfg, - model, - latent_vector, - labels, - n_clusters, + cfg=cfg, + model=model, + latent_vector=latent_vector, + labels=labels, + n_clusters=n_clusters, ) except Exception as e: logger.exception(str(e)) diff --git a/src/vame/io/load_poses.py b/src/vame/io/load_poses.py index a873233b..8adeeda4 100644 --- a/src/vame/io/load_poses.py +++ b/src/vame/io/load_poses.py @@ -50,4 +50,4 @@ def load_vame_dataset(ds_path: Path | str) -> xr.Dataset: Returns: -------- """ - return xr.open_dataset(ds_path) + return xr.open_dataset(ds_path, engine="netcdf4") diff --git a/src/vame/util/align_egocentrical.py b/src/vame/util/align_egocentrical.py index 34e55e4d..7ab47099 100644 --- a/src/vame/util/align_egocentrical.py +++ b/src/vame/util/align_egocentrical.py @@ -81,12 +81,14 @@ def align_mouse( i = interpol_first_rows_nans(i) if use_video: - video_path = str(os.path.join( - project_path, - "data", - "raw", - session + video_format, - )) + video_path = str( + os.path.join( + project_path, + "data", + "raw", + session + video_format, + ) + ) capture = cv.VideoCapture(video_path) if not capture.isOpened(): raise Exception(f"Unable to open video file: {video_path}") @@ -304,12 +306,14 @@ def alignment( if use_video: # compute background - video_path = str(os.path.join( - project_path, - "data", - "raw", - session + video_format, - )) + video_path = str( + os.path.join( + project_path, + "data", + "raw", + session + video_format, + ) + ) bg = background( project_path=project_path, session=session, diff --git a/tests/test_analysis.py b/tests/test_analysis.py index d92cc382..e41206d4 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -339,16 +339,16 @@ def test_gif_frames_files_exists(setup_project_and_evaluate_model, label): vame.segment_session(setup_project_and_evaluate_model["config_path"]) def mock_background( - path_to_file=None, - filename=None, + project_path=None, + session=None, video_path=None, num_frames=None, save_background=True, ): num_frames = 100 return background( - project_path=path_to_file, - session=filename, + project_path=project_path, + session=session, video_path=video_path, num_frames=num_frames, save_background=save_background,