Skip to content

Commit

Permalink
align_egocentrical and functions
Browse files Browse the repository at this point in the history
  • Loading branch information
luiztauffer committed Sep 12, 2024
1 parent 26dc697 commit 4f5a7a5
Show file tree
Hide file tree
Showing 3 changed files with 570 additions and 323 deletions.
196 changes: 139 additions & 57 deletions src/vame/schemas/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,89 +6,169 @@
from enum import Enum
from vame.schemas.project import Parametrizations


class StatesEnum(str, Enum):
success = 'success'
failed = 'failed'
running = 'running'
aborted = 'aborted'
success = "success"
failed = "failed"
running = "running"
aborted = "aborted"


class GenerativeModelModeEnum(str, Enum):
sampling = 'sampling'
reconstruction = 'reconstruction'
centers = 'centers'
motifs = 'motifs'
sampling = "sampling"
reconstruction = "reconstruction"
centers = "centers"
motifs = "motifs"

class BaseStateSchema(BaseModel):
config: str = Field(title='Configuration file path')
execution_state: StatesEnum | None = Field(title='Method execution state', default=None)

class BaseStateSchema(BaseModel):
config: str = Field(title="Configuration file path")
execution_state: StatesEnum | None = Field(
title="Method execution state",
default=None,
)


class EgocentricAlignmentFunctionSchema(BaseStateSchema):
pose_ref_index: list = Field(title='Pose reference index', default=[0, 5])
crop_size: tuple = Field(title='Crop size', default=(300, 300))
use_video: bool = Field(title='Use video', default=False)
video_format: str = Field(title='Video format', default='.mp4')
check_video: bool = Field(title='Check video', default=False)


class PoseToNumpyFunctionSchema(BaseStateSchema):
...
pose_ref_index: list = Field(
title="Pose reference index",
default=[0, 5],
)
crop_size: tuple = Field(
title="Crop size",
default=(300, 300),
)
use_video: bool = Field(
title="Use video",
default=False,
)
video_format: str = Field(
title="Video format",
default=".mp4",
)
check_video: bool = Field(
title="Check video",
default=False,
)


class PoseToNumpyFunctionSchema(BaseStateSchema): ...


class CreateTrainsetFunctionSchema(BaseStateSchema):
pose_ref_index: Optional[list] = Field(title='Pose reference index', default=None)
check_parameter: bool = Field(title='Check parameter', default=False)
pose_ref_index: Optional[list] = Field(
title="Pose reference index",
default=None,
)
check_parameter: bool = Field(
title="Check parameter",
default=False,
)


class TrainModelFunctionSchema(BaseStateSchema):
...
class TrainModelFunctionSchema(BaseStateSchema): ...


class EvaluateModelFunctionSchema(BaseStateSchema):
use_snapshots: bool = Field(title='Use snapshots', default=False)
use_snapshots: bool = Field(
title="Use snapshots",
default=False,
)


class PoseSegmentationFunctionSchema(BaseStateSchema):
...
class PoseSegmentationFunctionSchema(BaseStateSchema): ...


class MotifVideosFunctionSchema(BaseStateSchema):
videoType: str = Field(title='Type of video', default='.mp4')
parametrization: Parametrizations = Field(title='Parametrization')
output_video_type: str = Field(title='Type of output video', default='.mp4')
videoType: str = Field(
title="Type of video",
default=".mp4",
)
parametrization: Parametrizations = Field(title="Parametrization")
output_video_type: str = Field(
title="Type of output video",
default=".mp4",
)


class CommunityFunctionSchema(BaseStateSchema):
cohort: bool = Field(title='Cohort', default=True)
parametrization: Parametrizations = Field(title='Parametrization')
cut_tree: int | None = Field(title='Cut tree', default=None)
cohort: bool = Field(title="Cohort", default=True)
parametrization: Parametrizations = Field(title="Parametrization")
cut_tree: int | None = Field(
title="Cut tree",
default=None,
)


class CommunityVideosFunctionSchema(BaseStateSchema):
parametrization: Parametrizations = Field(title='Parametrization')
videoType: str = Field(title='Type of video', default='.mp4')
parametrization: Parametrizations = Field(title="Parametrization")
videoType: str = Field(
title="Type of video",
default=".mp4",
)


class VisualizationFunctionSchema(BaseStateSchema):
parametrization: Parametrizations = Field(title='Parametrization')
label: Optional[str] = Field(title='Type of labels to visualize', default=None)
parametrization: Parametrizations = Field(title="Parametrization")
label: Optional[str] = Field(
title="Type of labels to visualize",
default=None,
)


class GenerativeModelFunctionSchema(BaseStateSchema):
parametrization: Parametrizations = Field(title='Parametrization')
mode: GenerativeModelModeEnum = Field(title='Mode for generating samples', default=GenerativeModelModeEnum.sampling)
parametrization: Parametrizations = Field(title="Parametrization")
mode: GenerativeModelModeEnum = Field(
title="Mode for generating samples",
default=GenerativeModelModeEnum.sampling,
)


class VAMEPipelineStatesSchema(BaseModel):
egocentric_alignment: Optional[EgocentricAlignmentFunctionSchema | Dict] = Field(title='Egocentric alignment', default={})
pose_to_numpy: Optional[PoseToNumpyFunctionSchema | Dict] = Field(title='CSV to numpy', default={})
create_trainset: Optional[CreateTrainsetFunctionSchema | Dict] = Field(title='Create trainset', default={})
train_model: Optional[TrainModelFunctionSchema | Dict] = Field(title='Train model', default={})
evaluate_model: Optional[EvaluateModelFunctionSchema | Dict] = Field(title='Evaluate model', default={})
pose_segmentation: Optional[PoseSegmentationFunctionSchema | Dict] = Field(title='Pose segmentation', default={})
motif_videos: Optional[MotifVideosFunctionSchema | Dict] = Field(title='Motif videos', default={})
community: Optional[CommunityFunctionSchema | Dict] = Field(title='Community', default={})
community_videos: Optional[CommunityVideosFunctionSchema | Dict] = Field(title='Community videos', default={})
visualization: Optional[VisualizationFunctionSchema | Dict] = Field(title='Visualization', default={})
generative_model: Optional[GenerativeModelFunctionSchema | Dict] = Field(title='Generative model', default={})
egocentric_alignment: Optional[EgocentricAlignmentFunctionSchema | Dict] = Field(
title="Egocentric alignment",
default={},
)
pose_to_numpy: Optional[PoseToNumpyFunctionSchema | Dict] = Field(
title="CSV to numpy",
default={},
)
create_trainset: Optional[CreateTrainsetFunctionSchema | Dict] = Field(
title="Create trainset",
default={},
)
train_model: Optional[TrainModelFunctionSchema | Dict] = Field(
title="Train model",
default={},
)
evaluate_model: Optional[EvaluateModelFunctionSchema | Dict] = Field(
title="Evaluate model",
default={},
)
pose_segmentation: Optional[PoseSegmentationFunctionSchema | Dict] = Field(
title="Pose segmentation",
default={},
)
motif_videos: Optional[MotifVideosFunctionSchema | Dict] = Field(
title="Motif videos",
default={},
)
community: Optional[CommunityFunctionSchema | Dict] = Field(
title="Community",
default={},
)
community_videos: Optional[CommunityVideosFunctionSchema | Dict] = Field(
title="Community videos",
default={},
)
visualization: Optional[VisualizationFunctionSchema | Dict] = Field(
title="Visualization",
default={},
)
generative_model: Optional[GenerativeModelFunctionSchema | Dict] = Field(
title="Generative model",
default={},
)


def _save_state(model: BaseModel, function_name: str, state: StatesEnum) -> None:
Expand All @@ -97,16 +177,16 @@ def _save_state(model: BaseModel, function_name: str, state: StatesEnum) -> None
"""
config_file_path = Path(model.config)
project_path = config_file_path.parent
states_file_path = project_path / 'states/states.json'
states_file_path = project_path / "states/states.json"

with open(states_file_path, 'r') as f:
with open(states_file_path, "r") as f:
states = json.load(f)

pipeline_states = VAMEPipelineStatesSchema(**states)
model.execution_state = state
setattr(pipeline_states, function_name, model.model_dump())

with open(states_file_path, 'w') as f:
with open(states_file_path, "w") as f:
json.dump(pipeline_states.model_dump(), f, indent=4)


Expand All @@ -119,13 +199,13 @@ def decorator(func: callable):
@wraps(func)
def wrapper(*args, **kwargs):
# Create an instance of the Pydantic model using provided args and kwargs
function_name = func.__name__
function_name = func.__name__
attribute_names = list(model.model_fields.keys())

kwargs_dict = {}
for attr in attribute_names:
if attr == 'execution_state':
kwargs_dict[attr] = 'running'
if attr == "execution_state":
kwargs_dict[attr] = "running"
continue
kwargs_dict[attr] = kwargs.get(attr, model.model_fields[attr].default)

Expand All @@ -145,5 +225,7 @@ def wrapper(*args, **kwargs):
except KeyboardInterrupt as e:
_save_state(kwargs_model, function_name, state=StatesEnum.aborted)
raise e

return wrapper
return decorator

return decorator
Loading

0 comments on commit 4f5a7a5

Please sign in to comment.