Skip to content

Commit

Permalink
Merge pull request #112 from EthoML/data-ingestion
Browse files Browse the repository at this point in the history
adopt movement and xarray for ingestion
  • Loading branch information
luiztauffer authored Nov 21, 2024
2 parents 0839bc1 + 964c213 commit be2a226
Show file tree
Hide file tree
Showing 26 changed files with 887 additions and 237 deletions.
1 change: 1 addition & 0 deletions .github/workflows/testing.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on:
branches:
- main
- dev
- data-ingestion


jobs:
Expand Down
110 changes: 110 additions & 0 deletions examples/pipeline.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# VAME Pipeline example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from vame.pipeline import VAMEPipeline\n",
"from vame.util.sample_data import download_sample_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set up your working directory and project name\n",
"working_directory = '.'\n",
"project_name = 'pipeline_example'\n",
"\n",
"# You can run VAME with data from different sources\n",
"source_software = \"SLEAP\" # \"DeepLabCut\", \"SLEAP\" or \"LightningPose\"\n",
"\n",
"# Download sample data\n",
"ps = download_sample_data(source_software)\n",
"videos = [ps[\"video\"]]\n",
"poses_estimations = [ps[\"poses\"]]\n",
"\n",
"# Instantiate the pipeline\n",
"# this will create a VAME project and prepare the data\n",
"pipeline = VAMEPipeline(\n",
" working_directory=working_directory,\n",
" project_name=project_name,\n",
" videos=videos,\n",
" poses_estimations=poses_estimations,\n",
" source_software=source_software,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Before running the pipeline, you can check the imported datasets\n",
"ds = pipeline.get_raw_datasets()\n",
"ds"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run the pipeline\n",
"pipeline.run_pipeline()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# If for some reason you need to stop the pipeline, you can resume it later from any step\n",
"# Example: resuming from community clustering step\n",
"pipeline.run_pipeline(from_step=6)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "vame",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
3 changes: 0 additions & 3 deletions reinstall.sh

This file was deleted.

3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ pydantic==2.7.4
imageio==2.34.1
imageio-ffmpeg==0.5.1
pynwb==2.8.1
tables>=3.10.0
movement==0.0.20
# netCDF4>=1.7.2
30 changes: 21 additions & 9 deletions src/vame/analysis/generative_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,10 @@ def generative_model(

if mode == "sampling":
latent_vector = np.load(
os.path.join(path_to_file, "latent_vector_" + session + ".npy")
os.path.join(
path_to_file,
"latent_vector_" + session + ".npy",
)
)
return random_generative_samples(
cfg,
Expand All @@ -291,7 +294,10 @@ def generative_model(

if mode == "reconstruction":
latent_vector = np.load(
os.path.join(path_to_file, "latent_vector_" + session + ".npy")
os.path.join(
path_to_file,
"latent_vector_" + session + ".npy",
)
)
return random_reconstruction_samples(
cfg,
Expand All @@ -305,7 +311,10 @@ def generative_model(
f"Algorithm {segmentation_algorithm} not supported for cluster center visualization."
)
cluster_center = np.load(
os.path.join(path_to_file, "cluster_center_" + session + ".npy")
os.path.join(
path_to_file,
"cluster_center_" + session + ".npy",
)
)
return visualize_cluster_center(
cfg,
Expand All @@ -315,7 +324,10 @@ def generative_model(

if mode == "motifs":
latent_vector = np.load(
os.path.join(path_to_file, "latent_vector_" + session + ".npy")
os.path.join(
path_to_file,
"latent_vector_" + session + ".npy",
)
)
labels = np.load(
os.path.join(
Expand All @@ -330,11 +342,11 @@ def generative_model(
)
)
return random_generative_samples_motif(
cfg,
model,
latent_vector,
labels,
n_clusters,
cfg=cfg,
model=model,
latent_vector=latent_vector,
labels=labels,
n_clusters=n_clusters,
)
except Exception as e:
logger.exception(str(e))
Expand Down
9 changes: 8 additions & 1 deletion src/vame/analysis/pose_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from vame.logging.logger import VameLogger, TqdmToLogger
from vame.model.rnn_model import RNN_VAE
from vame.util.auxiliary import read_config

# from vame.util.data_manipulation import consecutive
from vame.util.cli import get_sessions_from_user_input
from vame.util.model_util import load_model
Expand Down Expand Up @@ -66,7 +67,13 @@ def embedd_latent_vectors(
for session in sessions:
logger.info(f"Embedding of latent vector for file {session}")
data = np.load(
os.path.join(project_path, "data", session, session + "-PE-seq-clean.npy")
os.path.join(
project_path,
"data",
"processed",
session,
session + "-PE-seq-clean.npy",
)
)
latent_vector_list = []
with torch.no_grad():
Expand Down
11 changes: 9 additions & 2 deletions src/vame/analysis/videowriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def create_cluster_videos(
)
if flag == "community":
if cohort:
logger.info("Cohort community videos getting created for " + session + " ...")
logger.info(
"Cohort community videos getting created for " + session + " ..."
)
labels = np.load(
os.path.join(
path_to_file,
Expand All @@ -99,7 +101,12 @@ def create_cluster_videos(
)
)

video_file_path = os.path.join(config["project_path"], "videos", session + video_type)
video_file_path = os.path.join(
config["project_path"],
"data",
"raw",
session + video_type,
)
capture = cv.VideoCapture(video_file_path)
if not capture.isOpened():
raise ValueError(
Expand Down
Loading

0 comments on commit be2a226

Please sign in to comment.