Skip to content

Commit

Permalink
project.py
Browse files Browse the repository at this point in the history
  • Loading branch information
luiztauffer committed Sep 12, 2024
1 parent 4f5a7a5 commit 9adec63
Showing 1 changed file with 268 additions and 73 deletions.
341 changes: 268 additions & 73 deletions src/vame/schemas/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,109 +2,304 @@
from typing import List, Optional
from enum import Enum


class Parametrizations(str, Enum):
hmm = 'hmm'
kmeans = 'kmeans'
hmm = "hmm"
kmeans = "kmeans"

class Config:
use_enum_values = True


class PoseEstimationFiletype(str, Enum):
csv = 'csv'
nwb = 'nwb'
csv = "csv"
nwb = "nwb"

class Config:
use_enum_values = True


class ProjectSchema(BaseModel):
# Project attributes
Project: str = Field(..., title='Project name')
model_name: str = Field(default='VAME', title='Model name')
n_cluster: int = Field(default=15, title='Number of clusters')
pose_confidence: float = Field(default=0.99, title='Pose confidence')
Project: str = Field(
...,
title="Project name",
)
model_name: str = Field(
default="VAME",
title="Model name",
)
n_cluster: int = Field(
default=15,
title="Number of clusters",
)
pose_confidence: float = Field(
default=0.99,
title="Pose confidence",
)

# Project path and videos
project_path: str = Field(..., title='Project path')
video_sets: List[str] = Field(..., title='Video sets')
pose_estimation_filetype: PoseEstimationFiletype = Field(title='Pose estimation filetype')
paths_to_pose_nwb_series_data: Optional[List[str]] = Field(title='Paths to pose series data in nwb files', default=None)
project_path: str = Field(
...,
title="Project path",
)
video_sets: List[str] = Field(
...,
title="Video sets",
)
pose_estimation_filetype: PoseEstimationFiletype = Field(
title="Pose estimation filetype",
)
paths_to_pose_nwb_series_data: Optional[List[str]] = Field(
title="Paths to pose series data in nwb files",
default=None,
)

# Data
all_data: str = Field(default='yes', title='All data')
egocentric_data: bool = Field(default=False, title='Egocentric data')
robust: bool = Field(default=True, title='Robust data')
iqr_factor: int = Field(default=4, title='IQR factor')
axis: str = Field(default='None', title='Axis')
savgol_filter: bool = Field(default=True, title='Savgol filter')
savgol_length: int = Field(default=5, title='Savgol length')
savgol_order: int = Field(default=2, title='Savgol order')
test_fraction: float = Field(default=0.1, title='Test fraction')
all_data: str = Field(
default="yes",
title="All data",
)
egocentric_data: bool = Field(
default=False,
title="Egocentric data",
)
robust: bool = Field(
default=True,
title="Robust data",
)
iqr_factor: int = Field(
default=4,
title="IQR factor",
)
axis: str = Field(
default="None",
title="Axis",
)
savgol_filter: bool = Field(
default=True,
title="Savgol filter",
)
savgol_length: int = Field(
default=5,
title="Savgol length",
)
savgol_order: int = Field(
default=2,
title="Savgol order",
)
test_fraction: float = Field(
default=0.1,
title="Test fraction",
)

# RNN model general hyperparameter:
pretrained_model: str = Field(default='None', title='Pretrained model')
pretrained_weights: bool = Field(default=False, title='Pretrained weights')
num_features: int = Field(default=12, title='Number of features')
batch_size: int = Field(default=256, title='Batch size')
max_epochs: int = Field(default=500, title='Max epochs')
model_snapshot: int = Field(default=50, title='Model snapshot')
model_convergence: int = Field(default=50, title='Model convergence')
transition_function: str = Field(default='GRU', title='Transition function')
beta: float = Field(default=1, title='Beta')
beta_norm: bool = Field(default=False, title='Beta normalization')
zdims: int = Field(default=30, title='Zdims')
learning_rate: float = Field(default=5e-4, title='Learning rate')
time_window: int = Field(default=30, title='Time window')
prediction_decoder: int = Field(default=1, title='Prediction decoder')
prediction_steps: int = Field(default=15, title='Prediction steps')
noise: bool = Field(default=False, title='Noise')
scheduler: int = Field(default=1, title='Scheduler')
scheduler_step_size: int = Field(default=100, title='Scheduler step size')
scheduler_gamma: float = Field(default=0.2, title='Scheduler gamma')
scheduler_threshold: float = Field(default=None, title='Scheduler threshold')
softplus: bool = Field(default=False, title='Softplus')


pretrained_model: str = Field(
default="None",
title="Pretrained model",
)
pretrained_weights: bool = Field(
default=False,
title="Pretrained weights",
)
num_features: int = Field(
default=12,
title="Number of features",
)
batch_size: int = Field(
default=256,
title="Batch size",
)
max_epochs: int = Field(
default=500,
title="Max epochs",
)
model_snapshot: int = Field(
default=50,
title="Model snapshot",
)
model_convergence: int = Field(
default=50,
title="Model convergence",
)
transition_function: str = Field(
default="GRU",
title="Transition function",
)
beta: float = Field(
default=1,
title="Beta",
)
beta_norm: bool = Field(
default=False,
title="Beta normalization",
)
zdims: int = Field(
default=30,
title="Zdims",
)
learning_rate: float = Field(
default=5e-4,
title="Learning rate",
)
time_window: int = Field(
default=30,
title="Time window",
)
prediction_decoder: int = Field(
default=1,
title="Prediction decoder",
)
prediction_steps: int = Field(
default=15,
title="Prediction steps",
)
noise: bool = Field(
default=False,
title="Noise",
)
scheduler: int = Field(
default=1,
title="Scheduler",
)
scheduler_step_size: int = Field(
default=100,
title="Scheduler step size",
)
scheduler_gamma: float = Field(
default=0.2,
title="Scheduler gamma",
)
scheduler_threshold: float = Field(
default=None,
title="Scheduler threshold",
)
softplus: bool = Field(
default=False,
title="Softplus",
)
# Segmentation:
parametrizations: List[Parametrizations] = Field(
title='Parametrizations',
default_factory=lambda: [Parametrizations.hmm.value, Parametrizations.kmeans.value]
title="Parametrizations",
default_factory=lambda: [
Parametrizations.hmm.value,
Parametrizations.kmeans.value,
],
)
hmm_trained: bool = Field(
default=False,
title="HMM trained",
)
load_data: str = Field(
default="-PE-seq-clean",
title="Load data",
)
individual_parametrization: bool = Field(
default=False,
title="Individual parametrization",
)
random_state_kmeans: int = Field(
default=42,
title="Random state kmeans",
)
n_init_kmeans: int = Field(
default=15,
title="N init kmeans",
)
hmm_trained: bool = Field(default=False, title='HMM trained')
load_data: str = Field(default='-PE-seq-clean', title='Load data')
individual_parametrization: bool = Field(default=False, title='Individual parametrization')
random_state_kmeans: int = Field(default=42, title='Random state kmeans')
n_init_kmeans: int = Field(default=15, title='N init kmeans')

# Video writer:
length_of_motif_video: int = Field(default=1000, title='Length of motif video')
length_of_motif_video: int = Field(
default=1000,
title="Length of motif video",
)

# UMAP parameter:
min_dist: float = Field(default=0.1, title='Min dist')
n_neighbors: int = Field(default=200, title='N neighbors')
random_state: int = Field(default=42, title='Random state')
num_points: int = Field(default=30000, title='Num points')
min_dist: float = Field(
default=0.1,
title="Min dist",
)
n_neighbors: int = Field(
default=200,
title="N neighbors",
)
random_state: int = Field(
default=42,
title="Random state",
)
num_points: int = Field(
default=30000,
title="Num points",
)

# RNN encoder hyperparameter:
hidden_size_layer_1: int = Field(default=256, title='Hidden size layer 1')
hidden_size_layer_2: int = Field(default=256, title='Hidden size layer 2')
dropout_encoder: float = Field(default=0, title='Dropout encoder')
hidden_size_layer_1: int = Field(
default=256,
title="Hidden size layer 1",
)
hidden_size_layer_2: int = Field(
default=256,
title="Hidden size layer 2",
)
dropout_encoder: float = Field(
default=0,
title="Dropout encoder",
)

# RNN reconstruction hyperparameter:
hidden_size_rec: int = Field(default=256, title='Hidden size rec')
dropout_rec: float = Field(default=0, title='Dropout rec')
n_layers: int = Field(default=1, title='N layers')
hidden_size_rec: int = Field(
default=256,
title="Hidden size rec",
)
dropout_rec: float = Field(
default=0,
title="Dropout rec",
)
n_layers: int = Field(
default=1,
title="N layers",
)

# RNN prediction hyperparameter:
hidden_size_pred: int = Field(default=256, title='Hidden size pred')
dropout_pred: float = Field(default=0, title='Dropout pred')
hidden_size_pred: int = Field(
default=256,
title="Hidden size pred",
)
dropout_pred: float = Field(
default=0,
title="Dropout pred",
)

# RNN loss hyperparameter:
mse_reconstruction_reduction: str = Field(default='sum', title='MSE reconstruction reduction')
mse_prediction_reduction: str = Field(default='sum', title='MSE prediction reduction')
kmeans_loss: int = Field(default=30, title='Kmeans loss')
kmeans_lambda: float = Field(default=0.1, title='Kmeans lambda')
anneal_function: str = Field(default='linear', title='Anneal function')
kl_start: int = Field(default=2, title='KL start')
annealtime: int = Field(default=4, title='Annealtime')

model_config: ConfigDict = ConfigDict(protected_namespaces=(), use_enum_values=True)
mse_reconstruction_reduction: str = Field(
default="sum",
title="MSE reconstruction reduction",
)
mse_prediction_reduction: str = Field(
default="sum",
title="MSE prediction reduction",
)
kmeans_loss: int = Field(
default=30,
title="Kmeans loss",
)
kmeans_lambda: float = Field(
default=0.1,
title="Kmeans lambda",
)
anneal_function: str = Field(
default="linear",
title="Anneal function",
)
kl_start: int = Field(
default=2,
title="KL start",
)
annealtime: int = Field(
default=4,
title="Annealtime",
)

model_config: ConfigDict = ConfigDict(
protected_namespaces=(),
use_enum_values=True,
)

0 comments on commit 9adec63

Please sign in to comment.