diff --git a/src/vame/schemas/states.py b/src/vame/schemas/states.py index db75ff8b..b1003d5c 100644 --- a/src/vame/schemas/states.py +++ b/src/vame/schemas/states.py @@ -6,89 +6,169 @@ from enum import Enum from vame.schemas.project import Parametrizations + class StatesEnum(str, Enum): - success = 'success' - failed = 'failed' - running = 'running' - aborted = 'aborted' + success = "success" + failed = "failed" + running = "running" + aborted = "aborted" + class GenerativeModelModeEnum(str, Enum): - sampling = 'sampling' - reconstruction = 'reconstruction' - centers = 'centers' - motifs = 'motifs' + sampling = "sampling" + reconstruction = "reconstruction" + centers = "centers" + motifs = "motifs" -class BaseStateSchema(BaseModel): - config: str = Field(title='Configuration file path') - execution_state: StatesEnum | None = Field(title='Method execution state', default=None) +class BaseStateSchema(BaseModel): + config: str = Field(title="Configuration file path") + execution_state: StatesEnum | None = Field( + title="Method execution state", + default=None, + ) class EgocentricAlignmentFunctionSchema(BaseStateSchema): - pose_ref_index: list = Field(title='Pose reference index', default=[0, 5]) - crop_size: tuple = Field(title='Crop size', default=(300, 300)) - use_video: bool = Field(title='Use video', default=False) - video_format: str = Field(title='Video format', default='.mp4') - check_video: bool = Field(title='Check video', default=False) - - -class PoseToNumpyFunctionSchema(BaseStateSchema): - ... + pose_ref_index: list = Field( + title="Pose reference index", + default=[0, 5], + ) + crop_size: tuple = Field( + title="Crop size", + default=(300, 300), + ) + use_video: bool = Field( + title="Use video", + default=False, + ) + video_format: str = Field( + title="Video format", + default=".mp4", + ) + check_video: bool = Field( + title="Check video", + default=False, + ) + + +class PoseToNumpyFunctionSchema(BaseStateSchema): ... class CreateTrainsetFunctionSchema(BaseStateSchema): - pose_ref_index: Optional[list] = Field(title='Pose reference index', default=None) - check_parameter: bool = Field(title='Check parameter', default=False) + pose_ref_index: Optional[list] = Field( + title="Pose reference index", + default=None, + ) + check_parameter: bool = Field( + title="Check parameter", + default=False, + ) -class TrainModelFunctionSchema(BaseStateSchema): - ... +class TrainModelFunctionSchema(BaseStateSchema): ... class EvaluateModelFunctionSchema(BaseStateSchema): - use_snapshots: bool = Field(title='Use snapshots', default=False) + use_snapshots: bool = Field( + title="Use snapshots", + default=False, + ) -class PoseSegmentationFunctionSchema(BaseStateSchema): - ... +class PoseSegmentationFunctionSchema(BaseStateSchema): ... 
+ class MotifVideosFunctionSchema(BaseStateSchema): - videoType: str = Field(title='Type of video', default='.mp4') - parametrization: Parametrizations = Field(title='Parametrization') - output_video_type: str = Field(title='Type of output video', default='.mp4') + videoType: str = Field( + title="Type of video", + default=".mp4", + ) + parametrization: Parametrizations = Field(title="Parametrization") + output_video_type: str = Field( + title="Type of output video", + default=".mp4", + ) class CommunityFunctionSchema(BaseStateSchema): - cohort: bool = Field(title='Cohort', default=True) - parametrization: Parametrizations = Field(title='Parametrization') - cut_tree: int | None = Field(title='Cut tree', default=None) + cohort: bool = Field(title="Cohort", default=True) + parametrization: Parametrizations = Field(title="Parametrization") + cut_tree: int | None = Field( + title="Cut tree", + default=None, + ) class CommunityVideosFunctionSchema(BaseStateSchema): - parametrization: Parametrizations = Field(title='Parametrization') - videoType: str = Field(title='Type of video', default='.mp4') + parametrization: Parametrizations = Field(title="Parametrization") + videoType: str = Field( + title="Type of video", + default=".mp4", + ) class VisualizationFunctionSchema(BaseStateSchema): - parametrization: Parametrizations = Field(title='Parametrization') - label: Optional[str] = Field(title='Type of labels to visualize', default=None) + parametrization: Parametrizations = Field(title="Parametrization") + label: Optional[str] = Field( + title="Type of labels to visualize", + default=None, + ) + class GenerativeModelFunctionSchema(BaseStateSchema): - parametrization: Parametrizations = Field(title='Parametrization') - mode: GenerativeModelModeEnum = Field(title='Mode for generating samples', default=GenerativeModelModeEnum.sampling) + parametrization: Parametrizations = Field(title="Parametrization") + mode: GenerativeModelModeEnum = Field( + title="Mode for generating samples", + default=GenerativeModelModeEnum.sampling, + ) + class VAMEPipelineStatesSchema(BaseModel): - egocentric_alignment: Optional[EgocentricAlignmentFunctionSchema | Dict] = Field(title='Egocentric alignment', default={}) - pose_to_numpy: Optional[PoseToNumpyFunctionSchema | Dict] = Field(title='CSV to numpy', default={}) - create_trainset: Optional[CreateTrainsetFunctionSchema | Dict] = Field(title='Create trainset', default={}) - train_model: Optional[TrainModelFunctionSchema | Dict] = Field(title='Train model', default={}) - evaluate_model: Optional[EvaluateModelFunctionSchema | Dict] = Field(title='Evaluate model', default={}) - pose_segmentation: Optional[PoseSegmentationFunctionSchema | Dict] = Field(title='Pose segmentation', default={}) - motif_videos: Optional[MotifVideosFunctionSchema | Dict] = Field(title='Motif videos', default={}) - community: Optional[CommunityFunctionSchema | Dict] = Field(title='Community', default={}) - community_videos: Optional[CommunityVideosFunctionSchema | Dict] = Field(title='Community videos', default={}) - visualization: Optional[VisualizationFunctionSchema | Dict] = Field(title='Visualization', default={}) - generative_model: Optional[GenerativeModelFunctionSchema | Dict] = Field(title='Generative model', default={}) + egocentric_alignment: Optional[EgocentricAlignmentFunctionSchema | Dict] = Field( + title="Egocentric alignment", + default={}, + ) + pose_to_numpy: Optional[PoseToNumpyFunctionSchema | Dict] = Field( + title="CSV to numpy", + default={}, + ) + create_trainset: 
Optional[CreateTrainsetFunctionSchema | Dict] = Field( + title="Create trainset", + default={}, + ) + train_model: Optional[TrainModelFunctionSchema | Dict] = Field( + title="Train model", + default={}, + ) + evaluate_model: Optional[EvaluateModelFunctionSchema | Dict] = Field( + title="Evaluate model", + default={}, + ) + pose_segmentation: Optional[PoseSegmentationFunctionSchema | Dict] = Field( + title="Pose segmentation", + default={}, + ) + motif_videos: Optional[MotifVideosFunctionSchema | Dict] = Field( + title="Motif videos", + default={}, + ) + community: Optional[CommunityFunctionSchema | Dict] = Field( + title="Community", + default={}, + ) + community_videos: Optional[CommunityVideosFunctionSchema | Dict] = Field( + title="Community videos", + default={}, + ) + visualization: Optional[VisualizationFunctionSchema | Dict] = Field( + title="Visualization", + default={}, + ) + generative_model: Optional[GenerativeModelFunctionSchema | Dict] = Field( + title="Generative model", + default={}, + ) def _save_state(model: BaseModel, function_name: str, state: StatesEnum) -> None: @@ -97,16 +177,16 @@ def _save_state(model: BaseModel, function_name: str, state: StatesEnum) -> None """ config_file_path = Path(model.config) project_path = config_file_path.parent - states_file_path = project_path / 'states/states.json' + states_file_path = project_path / "states/states.json" - with open(states_file_path, 'r') as f: + with open(states_file_path, "r") as f: states = json.load(f) pipeline_states = VAMEPipelineStatesSchema(**states) model.execution_state = state setattr(pipeline_states, function_name, model.model_dump()) - with open(states_file_path, 'w') as f: + with open(states_file_path, "w") as f: json.dump(pipeline_states.model_dump(), f, indent=4) @@ -119,13 +199,13 @@ def decorator(func: callable): @wraps(func) def wrapper(*args, **kwargs): # Create an instance of the Pydantic model using provided args and kwargs - function_name = func.__name__ + function_name = func.__name__ attribute_names = list(model.model_fields.keys()) kwargs_dict = {} for attr in attribute_names: - if attr == 'execution_state': - kwargs_dict[attr] = 'running' + if attr == "execution_state": + kwargs_dict[attr] = "running" continue kwargs_dict[attr] = kwargs.get(attr, model.model_fields[attr].default) @@ -145,5 +225,7 @@ def wrapper(*args, **kwargs): except KeyboardInterrupt as e: _save_state(kwargs_model, function_name, state=StatesEnum.aborted) raise e + return wrapper - return decorator \ No newline at end of file + + return decorator diff --git a/src/vame/util/align_egocentrical.py b/src/vame/util/align_egocentrical.py index d08657aa..373cd597 100644 --- a/src/vame/util/align_egocentrical.py +++ b/src/vame/util/align_egocentrical.py @@ -12,7 +12,7 @@ import numpy as np import pandas as pd import tqdm -from typing import Tuple, List +from typing import Tuple, List, Union from vame.logging.logger import VameLogger, TqdmToLogger from pathlib import Path from vame.util.auxiliary import read_config @@ -22,14 +22,14 @@ interpol_first_rows_nans, crop_and_flip, background, - read_pose_estimation_file + read_pose_estimation_file, ) - logger_config = VameLogger(__name__) logger = logger_config.logger + def align_mouse( path_to_file: str, filename: str, @@ -42,170 +42,203 @@ def align_mouse( bg: np.ndarray, frame_count: int, use_video: bool = True, - tqdm_stream: TqdmToLogger = None -) -> Tuple[List[np.ndarray],List[List[np.ndarray]], np.ndarray]: + tqdm_stream: Union[TqdmToLogger, None] = None, +) -> 
Tuple[List[np.ndarray], List[List[np.ndarray]], np.ndarray]: """ Align the mouse in the video frames. - Args: - path_to_file (str): Path to the file directory. - filename (str): Name of the video file without the format. - video_format (str): Format of the video file. - crop_size (Tuple[int, int]): Size to crop the video frames. - pose_list (List[np.ndarray]): List of pose coordinates. - pose_ref_index (Tuple[int, int]): Pose reference indices. - confidence (float): Pose confidence threshold. - pose_flip_ref (Tuple[int, int]): Reference indices for flipping. - bg (np.ndarray): Background image. - frame_count (int): Number of frames to align. - use_video (bool, optional): bool if video should be cropped or DLC points only. Defaults to True. - - Returns: - Tuple[List[np.ndarray], List[List[np.ndarray]], np.ndarray]: List of aligned images, list of aligned DLC points, and time series data. + Parameters + ---------- + path_to_file : str + Path to the file directory. + filename : str + Name of the video file without the format. + video_format : str + Format of the video file. + crop_size : Tuple[int, int] + Size to crop the video frames. + pose_list : List[np.ndarray] + List of pose coordinates. + pose_ref_index : Tuple[int, int] + Pose reference indices. + confidence : float + Pose confidence threshold. + pose_flip_ref : Tuple[int, int] + Reference indices for flipping. + bg : np.ndarray + Background image. + frame_count : int + Number of frames to align. + use_video : bool, optional + Whether to crop the video frames or to use only the DLC points. Defaults to True. + tqdm_stream : Union[TqdmToLogger, None], optional + Tqdm stream to log the progress. Defaults to None. + + Returns + ------- + Tuple[List[np.ndarray], List[List[np.ndarray]], np.ndarray] + List of aligned images, list of aligned DLC points, and aligned time series data. """ - images = [] points = [] - for i in pose_list: for j in i: if j[2] <= confidence: - j[0],j[1] = np.nan, np.nan - + j[0], j[1] = np.nan, np.nan for i in pose_list: i = interpol_first_rows_nans(i) if use_video: - capture = cv.VideoCapture(os.path.join(path_to_file,'videos',filename+video_format)) - + capture = cv.VideoCapture( + os.path.join(path_to_file, "videos", filename + video_format) + ) if not capture.isOpened(): - raise Exception("Unable to open video file: {0}".format(os.path.join(path_to_file,'videos',filename+video_format))) - - for idx in tqdm.tqdm(range(frame_count), disable=not True, file=tqdm_stream, desc='Align frames'): + raise Exception( + "Unable to open video file: {0}".format( + os.path.join(path_to_file, "videos", filename + video_format) + ) + ) + for idx in tqdm.tqdm( + range(frame_count), disable=not True, file=tqdm_stream, desc="Align frames" + ): if use_video: - #Read frame + # Read frame try: ret, frame = capture.read() frame = cv.cvtColor(frame, cv.COLOR_BGR2GRAY) frame = frame - bg frame[frame <= 0] = 0 except Exception: - logger.info("Couldn't find a frame in capture.read(). #Frame: %d" %idx) + logger.info("Couldn't find a frame in capture.read(). 
#Frame: %d" % idx) continue else: - frame=np.zeros((1,1)) + frame = np.zeros((1, 1)) - #Read coordinates and add border + # Read coordinates and add border pose_list_bordered = [] for i in pose_list: - pose_list_bordered.append((int(i[idx][0]+crop_size[0]),int(i[idx][1]+crop_size[1]))) - - img = cv.copyMakeBorder(frame, crop_size[1], crop_size[1], crop_size[0], crop_size[0], cv.BORDER_CONSTANT, 0) + pose_list_bordered.append( + (int(i[idx][0] + crop_size[0]), int(i[idx][1] + crop_size[1])) + ) + img = cv.copyMakeBorder( + frame, + crop_size[1], + crop_size[1], + crop_size[0], + crop_size[0], + cv.BORDER_CONSTANT, + 0, + ) coord_center = [] punkte = [] for i in pose_ref_index: coord = [] - - coord.append(pose_list_bordered[i][0]) # changed from pose_list_bordered[i][0] 2/28/2024 PN - coord.append(pose_list_bordered[i][1]) # changed from pose_list_bordered[i][1] 2/28/2024 PN - + # changed from pose_list_bordered[i][0] 2/28/2024 PN + coord.append(pose_list_bordered[i][0]) + # changed from pose_list_bordered[i][1] 2/28/2024 PN + coord.append(pose_list_bordered[i][1]) punkte.append(coord) - # coord_center.append(pose_list_bordered[5][0]-5) # coord_center.append(pose_list_bordered[5][0]+5) - # coord_center = [coord_center] punkte = [punkte] # coord_center = np.asarray(coord_center) punkte = np.asarray(punkte) - #calculate minimal rectangle around snout and tail + # calculate minimal rectangle around snout and tail rect = cv.minAreaRect(punkte) # rect_belly = cv.minAreaRect(coord_center) - # center_belly, size_belly, theta_belly = rect_belly - - #change size in rect tuple structure to be equal to crop_size + # change size in rect tuple structure to be equal to crop_size lst = list(rect) lst[1] = crop_size # lst[0] = center_belly rect = tuple(lst) - center, size, theta = rect # lst2 = list(rect) # lst2[0][0] = center[0] - size[0]//2 # lst2[0][1] = center[1] - size[1]//2 - # rect = tuple(lst2) - # center[0] -= size[0]//2 # center[1] -= size[0]//2 # added this shift to change center to belly 2/28/2024 - #crop image - out, shifted_points = crop_and_flip(rect, img,pose_list_bordered,pose_flip_ref) + # crop image + out, shifted_points = crop_and_flip( + rect, img, pose_list_bordered, pose_flip_ref + ) - if use_video: #for memory optimization, just save images when video is used. + if use_video: # for memory optimization, just save images when video is used. images.append(out) points.append(shifted_points) if use_video: capture.release() - time_series = np.zeros((len(pose_list)*2,frame_count)) + time_series = np.zeros((len(pose_list) * 2, frame_count)) for i in range(frame_count): idx = 0 for j in range(len(pose_list)): - time_series[idx:idx+2,i] = points[i][j] + time_series[idx : idx + 2, i] = points[i][j] idx += 2 return images, points, time_series -def play_aligned_video(a: List[np.ndarray], n: List[List[np.ndarray]], frame_count: int) -> None: +def play_aligned_video( + a: List[np.ndarray], + n: List[List[np.ndarray]], + frame_count: int, +) -> None: """ Play the aligned video. - Args: - a (List[np.ndarray]): List of aligned images. - n (List[List[np.ndarray]]): List of aligned DLC points. - frame_count (int): Number of frames in the video. + Parameters + ---------- + a : List[np.ndarray] + List of aligned images. + n : List[List[np.ndarray]] + List of aligned DLC points. + frame_count : int + Number of frames in the video. 
""" - colors = [(255,0,0),(0,255,0),(0,0,255),(255,255,0),(255,0,255),(0,255,255),(0,0,0),(255,255,255)] - + colors = [ + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (255, 255, 0), + (255, 0, 255), + (0, 255, 255), + (0, 0, 0), + (255, 255, 255), + ] for i in range(frame_count): # Capture frame-by-frame - ret, frame = True,a[i] - if ret == True: - - # Display the resulting frame - frame = cv.cvtColor(frame.astype('uint8')*255, cv.COLOR_GRAY2BGR) - im_color = cv.applyColorMap(frame, cv.COLORMAP_JET) - - for c,j in enumerate(n[i]): - cv.circle(im_color,(j[0], j[1]), 5, colors[c], -1) - - cv.imshow('Frame',im_color) - - # Press Q on keyboard to exit - if cv.waitKey(25) & 0xFF == ord('q'): - break - - # Break the loop + ret, frame = True, a[i] + if ret is True: + # Display the resulting frame + frame = cv.cvtColor(frame.astype("uint8") * 255, cv.COLOR_GRAY2BGR) + im_color = cv.applyColorMap(frame, cv.COLORMAP_JET) + for c, j in enumerate(n[i]): + cv.circle(im_color, (j[0], j[1]), 5, colors[c], -1) + cv.imshow("Frame", im_color) + # Press Q on keyboard to exit + # Break the loop + if cv.waitKey(25) & 0xFF == ord("q"): + break else: break cv.destroyAllWindows() - def alignment( path_to_file: str, filename: str, @@ -214,65 +247,87 @@ def alignment( crop_size: Tuple[int, int], confidence: float, pose_estimation_filetype: PoseEstimationFiletype, - path_to_pose_nwb_series_data: str = None, + path_to_pose_nwb_series_data: Union[str, None] = None, use_video: bool = False, check_video: bool = False, - tqdm_stream: TqdmToLogger = None, + tqdm_stream: Union[TqdmToLogger, None] = None, ) -> Tuple[np.ndarray, List[np.ndarray]]: """ Perform alignment of egocentric data. - Args: - path_to_file (str): Path to the file directory. - filename (str): Name of the video file without the format. - pose_ref_index (List[int]): Pose reference indices. - video_format (str): Format of the video file. - crop_size (Tuple[int, int]): Size to crop the video frames. - confidence (float): Pose confidence threshold. - use_video (bool, optional): Whether to use video for alignment. Defaults to False. - check_video (bool, optional): Whether to check the aligned video. Defaults to False. - - Returns: - Tuple[np.ndarray, List[np.ndarray]]: Aligned time series data and list of aligned frames. + Parameters: + ----------- + path_to_file : str + Path to the file directory. + filename : str + Name of the video file without the format. + pose_ref_index : List[int] + Pose reference indices. + video_format : str + Format of the video file. + crop_size : Tuple[int, int] + Size to crop the video frames. + confidence : float + Pose confidence threshold. + pose_estimation_filetype : PoseEstimationFiletype + Pose estimation file type. Can be .csv or .nwb. + path_to_pose_nwb_series_data : Union[str, None], optional + Path to the pose series data in nwb files. Defaults to None. + use_video : bool, optional + Whether to use video for alignment. Defaults to False. + check_video : bool, optional + Whether to check the aligned video. Defaults to False. + tqdm_stream : Union[TqdmToLogger, None], optional + Tqdm stream to log the progress. Defaults to None. + + Returns + ------- + Tuple[np.ndarray, List[np.ndarray]] + Aligned time series data and list of aligned frames. 
""" - - #read out data - folder_path = os.path.join(path_to_file,'videos','pose_estimation') + # read out data + folder_path = os.path.join(path_to_file, "videos", "pose_estimation") data, data_mat = read_pose_estimation_file( folder_path=folder_path, filename=filename, filetype=pose_estimation_filetype, - path_to_pose_nwb_series_data=path_to_pose_nwb_series_data + path_to_pose_nwb_series_data=path_to_pose_nwb_series_data, ) # get the coordinates for alignment from data table pose_list = [] - for i in range(int(data_mat.shape[1]/3)): - pose_list.append(data_mat[:,i*3:(i+1)*3]) - - #list of reference coordinate indices for alignment - #0: snout, 1: forehand_left, 2: forehand_right, - #3: hindleft, 4: hindright, 5: tail + for i in range(int(data_mat.shape[1] / 3)): + pose_list.append(data_mat[:, i * 3 : (i + 1) * 3]) + # list of reference coordinate indices for alignment + # 0: snout, 1: forehand_left, 2: forehand_right, + # 3: hindleft, 4: hindright, 5: tail pose_ref_index = pose_ref_index - #list of 2 reference coordinate indices for avoiding flipping + # list of 2 reference coordinate indices for avoiding flipping pose_flip_ref = pose_ref_index if use_video: - #compute background - bg = background(path_to_file,filename,video_format, save_background=False) - capture = cv.VideoCapture(os.path.join(path_to_file,'videos',filename+video_format)) + # compute background + bg = background(path_to_file, filename, video_format, save_background=False) + capture = cv.VideoCapture( + os.path.join(path_to_file, "videos", filename + video_format) + ) if not capture.isOpened(): - raise Exception("Unable to open video file: {0}".format(os.path.join(path_to_file,'videos',filename+video_format))) + raise Exception( + "Unable to open video file: {0}".format( + os.path.join(path_to_file, "videos", filename + video_format) + ) + ) frame_count = int(capture.get(cv.CAP_PROP_FRAME_COUNT)) capture.release() else: bg = 0 - frame_count = len(data) # Change this to an abitrary number if you first want to test the code - + frame_count = len( + data + ) # Change this to an abitrary number if you first want to test the code frames, n, time_series = align_mouse( path_to_file, @@ -298,60 +353,76 @@ def alignment( @save_state(model=EgocentricAlignmentFunctionSchema) def egocentric_alignment( config: str, - pose_ref_index: list = [5,6], - crop_size: tuple = (300,300), + pose_ref_index: list = [5, 6], + crop_size: tuple = (300, 300), use_video: bool = False, - video_format: str = '.mp4', + video_format: str = ".mp4", check_video: bool = False, - save_logs: bool = False + save_logs: bool = False, ) -> None: - """Aligns egocentric data for VAME training - - Args: - config (str): Path for the project config file. - pose_ref_index (list, optional): Pose reference index to be used to align. Defaults to [5,6]. - crop_size (tuple, optional): Size to crop the video. Defaults to (300,300). - use_video (bool, optional): Weather to use video to do the post alignment. Defaults to False. # TODO check what to put in this docstring - video_format (str, optional): Video format, can be .mp4 or .avi. Defaults to '.mp4'. - check_video (bool, optional): Weather to check the video. Defaults to False. + """ + Egocentric alignment of bevarioral videos. + Fills in the values in the "egocentric_alignment" key of the states.json file. + Creates training dataset for VAME at: + - project_name/ + - data/ + - video1/ + - filename-PE-seq.npy + - video2/ + - filename-PE-seq.npy + + Parameters: + config : str + Path for the project config file. 
+ pose_ref_index : list, optional + Pose reference index to be used to align. Defaults to [5,6]. + crop_size : tuple, optional + Size to crop the video. Defaults to (300,300). + use_video : bool, optional + Whether to use the video to perform the alignment. Defaults to False. + video_format : str, optional + Video format, can be .mp4 or .avi. Defaults to '.mp4'. + check_video : bool, optional + Whether to check the aligned video. Defaults to False. Raises: - ValueError: If the config.yaml indicates that the data is not egocentric. + ------ + ValueError + If the config.yaml indicates that the data is already egocentric. """ - - # pose_ref_index changed in this script from [0,5] to [5,6] on 2/7/2024 PN - """ Happy aligning """ - #config parameters - try: config_file = Path(config).resolve() - cfg = read_config(config_file) + cfg = read_config(str(config_file)) tqdm_stream = None if save_logs: - log_path = Path(cfg['project_path']) / 'logs' / 'egocentric_alignment.log' - logger_config.add_file_handler(log_path) + log_path = Path(cfg["project_path"]) / "logs" / "egocentric_alignment.log" + logger_config.add_file_handler(str(log_path)) tqdm_stream = TqdmToLogger(logger=logger) - logger.info('Starting egocentric alignment') - path_to_file = cfg['project_path'] - filename = cfg['video_sets'] - confidence = cfg['pose_confidence'] - num_features = cfg['num_features'] - video_format=video_format - crop_size=crop_size + logger.info("Starting egocentric alignment") + path_to_file = cfg["project_path"] + filename = cfg["video_sets"] + confidence = cfg["pose_confidence"] + num_features = cfg["num_features"] + video_format = video_format + crop_size = crop_size y_shifted_indices = np.arange(0, num_features, 2) x_shifted_indices = np.arange(1, num_features, 2) belly_Y_ind = pose_ref_index[0] * 2 belly_X_ind = (pose_ref_index[0] * 2) + 1 - if cfg['egocentric_data']: - raise ValueError("The config.yaml indicates that the data is egocentric. Please check the parameter egocentric_data") + if cfg["egocentric_data"]: + raise ValueError( + "The config.yaml indicates that the data is egocentric. 
Please check the parameter egocentric_data" + ) # call function and save into your VAME data folder - paths_to_pose_nwb_series_data = cfg['paths_to_pose_nwb_series_data'] + paths_to_pose_nwb_series_data = cfg["paths_to_pose_nwb_series_data"] for i, file in enumerate(filename): - logger.info("Aligning data %s, Pose confidence value: %.2f" %(file, confidence)) + logger.info( + "Aligning data %s, Pose confidence value: %.2f" % (file, confidence) + ) egocentric_time_series, frames = alignment( path_to_file=path_to_file, filename=file, @@ -359,25 +430,34 @@ video_format=video_format, crop_size=crop_size, confidence=confidence, - pose_estimation_filetype=cfg['pose_estimation_filetype'], - path_to_pose_nwb_series_data=paths_to_pose_nwb_series_data if not paths_to_pose_nwb_series_data else paths_to_pose_nwb_series_data[i], + pose_estimation_filetype=cfg["pose_estimation_filetype"], + path_to_pose_nwb_series_data=( + paths_to_pose_nwb_series_data + if not paths_to_pose_nwb_series_data + else paths_to_pose_nwb_series_data[i] + ), use_video=use_video, check_video=check_video, - tqdm_stream=tqdm_stream + tqdm_stream=tqdm_stream, ) # Shifiting section added 2/29/2024 PN egocentric_time_series_shifted = egocentric_time_series - belly_Y_shift = egocentric_time_series[belly_Y_ind,:] - belly_X_shift = egocentric_time_series[belly_X_ind,:] + belly_Y_shift = egocentric_time_series[belly_Y_ind, :] + belly_X_shift = egocentric_time_series[belly_X_ind, :] egocentric_time_series_shifted[y_shifted_indices, :] -= belly_Y_shift egocentric_time_series_shifted[x_shifted_indices, :] -= belly_X_shift - np.save(os.path.join(path_to_file,'data',file,file+'-PE-seq.npy'), egocentric_time_series_shifted) # save new shifted file - # np.save(os.path.join(path_to_file,'data/',file,"",file+'-PE-seq.npy', egocentric_time_series)) + # Save new shifted file + np.save( + os.path.join(path_to_file, "data", file, file + "-PE-seq.npy"), + egocentric_time_series_shifted, + ) - logger.info("Your data is now ine right format and you can call vame.create_trainset()") + logger.info( + "Your data is now in the right format and you can call vame.create_trainset()" + ) except Exception as e: logger.exception(f"{e}") raise e diff --git a/src/vame/util/data_manipulation.py b/src/vame/util/data_manipulation.py index 508e6d42..b522386d 100644 --- a/src/vame/util/data_manipulation.py +++ b/src/vame/util/data_manipulation.py @@ -16,95 +16,154 @@ logger_config = VameLogger(__name__) logger = logger_config.logger -def get_pose_data_from_nwb_file(nwbfile: NWBFile, path_to_pose_nwb_series_data: str) -> LabelledDict: + +def get_pose_data_from_nwb_file( + nwbfile: NWBFile, + path_to_pose_nwb_series_data: str, +) -> LabelledDict: """ Get pose data from nwb file using a inside path to the nwb data. - Args: - nwbfile (NWBFile): NWB file object. - path_to_pose_nwb_series_data (str): Path to the pose data inside the nwb file. - - Returns: - LabelledDict: Pose data. + Parameters + ---------- + nwbfile : NWBFile + NWB file object. + path_to_pose_nwb_series_data : str + Path to the pose data inside the nwb file. + + Returns + ------- + LabelledDict + Pose data. 
""" if not path_to_pose_nwb_series_data: - raise ValueError('Path to pose nwb series data is required.') - + raise ValueError("Path to pose nwb series data is required.") pose_data = nwbfile - for key in path_to_pose_nwb_series_data.split('/'): + for key in path_to_pose_nwb_series_data.split("/"): if isinstance(pose_data, dict): pose_data = pose_data.get(key) continue pose_data = getattr(pose_data, key) return pose_data -def get_dataframe_from_pose_nwb_file(file_path: str, path_to_pose_nwb_series_data: str): - with NWBHDF5IO(file_path, 'r') as io: +def get_dataframe_from_pose_nwb_file( + file_path: str, + path_to_pose_nwb_series_data: str, +) -> pd.DataFrame: + """ + Get pose data from nwb file and return it as a pandas DataFrame. + + Parameters + ---------- + file_path : str + Path to the nwb file. + path_to_pose_nwb_series_data : str + Path to the pose data inside the nwb file. + + Returns + ------- + pd.DataFrame + Pose data as a pandas DataFrame. + """ + with NWBHDF5IO(file_path, "r") as io: nwbfile = io.read() - # Todo change to use variable as path to pose estimation in nwb + # TODO - change to use variable as path to pose estimation in nwb pose = get_pose_data_from_nwb_file(nwbfile, path_to_pose_nwb_series_data) - dataframes = [] for label, pose_series in pose.items(): data = pose_series.data[:] confidence = pose_series.confidence[:] - df = pd.DataFrame(data, columns=[f'{label}_x', f'{label}_y']) - df[f'likelihood_{label}'] = confidence + df = pd.DataFrame(data, columns=[f"{label}_x", f"{label}_y"]) + df[f"likelihood_{label}"] = confidence dataframes.append(df) final_df = pd.concat(dataframes, axis=1) - return final_df + def read_pose_estimation_file( folder_path: str, filename: str, filetype: PoseEstimationFiletype, - path_to_pose_nwb_series_data: Optional[str] = None + path_to_pose_nwb_series_data: Optional[str] = None, ) -> Tuple[pd.DataFrame, np.ndarray]: - - + """ + Read pose estimation file. + + Parameters + ---------- + folder_path : str + Path to the folder containing the pose estimation file. + filename : str + Name of the pose estimation file. + filetype : PoseEstimationFiletype + Type of the pose estimation file. Supported types are 'csv' and 'nwb'. + path_to_pose_nwb_series_data : str, optional + Path to the pose data inside the nwb file, by default None + + Returns + ------- + Tuple[pd.DataFrame, np.ndarray] + Tuple containing the pose estimation data as a pandas DataFrame and a numpy array. 
+ """ if filetype == PoseEstimationFiletype.csv: - file_path = Path(folder_path) / f'{filename}.{filetype}' + file_path = Path(folder_path) / f"{filename}.{filetype}" data = pd.read_csv(file_path, skiprows=2) - if 'coords' in data: - data = data.drop(columns=['coords'], axis=1) + if "coords" in data: + data = data.drop(columns=["coords"], axis=1) data_mat = pd.DataFrame.to_numpy(data) return data, data_mat elif filetype == PoseEstimationFiletype.nwb: - file_path = Path(folder_path) / f'{filename}.{filetype}' - data = get_dataframe_from_pose_nwb_file(file_path=file_path, path_to_pose_nwb_series_data=path_to_pose_nwb_series_data) + file_path = Path(folder_path) / f"{filename}.{filetype}" + if not path_to_pose_nwb_series_data: + raise ValueError("Path to pose nwb series data is required.") + data = get_dataframe_from_pose_nwb_file( + file_path=str(file_path), + path_to_pose_nwb_series_data=path_to_pose_nwb_series_data, + ) data_mat = pd.DataFrame.to_numpy(data) return data, data_mat + raise ValueError(f"Filetype {filetype} not supported") - raise ValueError(f'Filetype {filetype} not supported') - - -def consecutive(data: np.ndarray, stepsize: int = 1) -> List[np.ndarray]: - """Find consecutive sequences in the data array. - Args: - data (np.ndarray): Input array. - stepsize (int, optional): Step size. Defaults to 1. - - Returns: - List[np.ndarray]: List of consecutive sequences. +def consecutive( + data: np.ndarray, + stepsize: int = 1, +) -> List[np.ndarray]: + """ + Find consecutive sequences in the data array. + + Parameters + ---------- + data : np.ndarray + Input array. + stepsize : int, optional + Step size. Defaults to 1. + + Returns + ------- + List[np.ndarray] + List of consecutive sequences. """ data = data[:] - return np.split(data, np.where(np.diff(data) != stepsize)[0]+1) + return np.split(data, np.where(np.diff(data) != stepsize)[0] + 1) def nan_helper(y: np.ndarray) -> Tuple: """ Identifies indices of NaN values in an array and provides a function to convert them to non-NaN indices. - Args: - y (np.ndarray): Input array containing NaN values. - - Returns: - Tuple[np.ndarray, Union[np.ndarray, None]]: A tuple containing two elements: - - An array of boolean values indicating the positions of NaN values. - - A lambda function to convert NaN indices to non-NaN indices. + Parameters + ---------- + y : np.ndarray + Input array containing NaN values. + + Returns + ------- + Tuple[np.ndarray, Union[np.ndarray, None]] + A tuple containing two elements: + - An array of boolean values indicating the positions of NaN values. + - A lambda function to convert NaN indices to non-NaN indices. """ return np.isnan(y), lambda z: z.nonzero()[0] @@ -113,15 +172,19 @@ def interpol_all_nans(arr: np.ndarray) -> np.ndarray: """ Interpolates all NaN values in the given array. - Args: - arr (np.ndarray): Input array containing NaN values. + Parameters + ---------- + arr : np.ndarray + Input array containing NaN values. - Returns: - np.ndarray: Array with NaN values replaced by interpolated values. + Returns + ------- + np.ndarray + Array with NaN values replaced by interpolated values. """ y = np.transpose(arr) nans, x = nan_helper(y) - y[nans]= np.interp(x(nans), x(~nans), y[~nans]) + y[nans] = np.interp(x(nans), x(~nans), y[~nans]) arr = np.transpose(y) return arr @@ -130,142 +193,164 @@ def interpol_first_rows_nans(arr: np.ndarray) -> np.ndarray: """ Interpolates NaN values in the given array. - Args: - arr (np.ndarray): Input array with NaN values. 
+ Parameters + ---------- + arr : np.ndarray + Input array with NaN values. - Returns: - np.ndarray: Array with interpolated NaN values. + Returns + ------- + np.ndarray + Array with interpolated NaN values. """ - y = np.transpose(arr) - nans, x = nan_helper(y[0]) - y[0][nans]= np.interp(x(nans), x(~nans), y[0][~nans]) + y[0][nans] = np.interp(x(nans), x(~nans), y[0][~nans]) nans, x = nan_helper(y[1]) - y[1][nans]= np.interp(x(nans), x(~nans), y[1][~nans]) - + y[1][nans] = np.interp(x(nans), x(~nans), y[1][~nans]) arr = np.transpose(y) - return arr + def crop_and_flip( rect: Tuple, src: np.ndarray, points: List[np.ndarray], - ref_index: Tuple[int, int] + ref_index: Tuple[int, int], ) -> Tuple[np.ndarray, List[np.ndarray]]: """ Crop and flip the image based on the given rectangle and points. - Args: - rect (Tuple): Rectangle coordinates (center, size, theta). - src (np.ndarray): Source image. - points (List[np.ndarray]): List of points. - ref_index (Tuple[int, int]): Reference indices for alignment. - - Returns: - Tuple[np.ndarray, List[np.ndarray]]: Cropped and flipped image, and shifted points. + Parameters + ---------- + rect : Tuple + Rectangle coordinates (center, size, theta). + src: np.ndarray + Source image. + points : List[np.ndarray] + List of points. + ref_index : Tuple[int, int] + Reference indices for alignment. + + Returns + ------- + Tuple[np.ndarray, List[np.ndarray]] + Cropped and flipped image, and shifted points. """ - #Read out rect structures and convert + # Read out rect structures and convert center, size, theta = rect - center, size = tuple(map(int, center)), tuple(map(int, size)) - #Get rotation matrix + # Get rotation matrix M = cv.getRotationMatrix2D(center, theta, 1) - #shift DLC points - x_diff = center[0] - size[0]//2 - y_diff = center[1] - size[1]//2 - + # shift DLC points + x_diff = center[0] - size[0] // 2 + y_diff = center[1] - size[1] // 2 dlc_points_shifted = [] - for i in points: - point=cv.transform(np.array([[[i[0], i[1]]]]),M)[0][0] - + point = cv.transform(np.array([[[i[0], i[1]]]]), M)[0][0] point[0] -= x_diff point[1] -= y_diff - dlc_points_shifted.append(point) # Perform rotation on src image - dst = cv.warpAffine(src.astype('float32'), M, src.shape[:2]) + dst = cv.warpAffine(src.astype("float32"), M, src.shape[:2]) out = cv.getRectSubPix(dst, size, center) - #check if flipped correctly, otherwise flip again + # check if flipped correctly, otherwise flip again if dlc_points_shifted[ref_index[1]][0] >= dlc_points_shifted[ref_index[0]][0]: - rect = ((size[0]//2,size[0]//2),size,180) #should second value be size[1]? Is this relevant to the flip? 3/5/24 KKL + rect = ( + (size[0] // 2, size[0] // 2), + size, + 180, + ) # should second value be size[1]? Is this relevant to the flip? 
3/5/24 KKL center, size, theta = rect center, size = tuple(map(int, center)), tuple(map(int, size)) - #Get rotation matrix + # Get rotation matrix M = cv.getRotationMatrix2D(center, theta, 1) - #shift DLC points - x_diff = center[0] - size[0]//2 - y_diff = center[1] - size[1]//2 + # shift DLC points + x_diff = center[0] - size[0] // 2 + y_diff = center[1] - size[1] // 2 points = dlc_points_shifted dlc_points_shifted = [] for i in points: - point=cv.transform(np.array([[[i[0], i[1]]]]),M)[0][0] - + point = cv.transform(np.array([[[i[0], i[1]]]]), M)[0][0] point[0] -= x_diff point[1] -= y_diff - dlc_points_shifted.append(point) # Perform rotation on src image - dst = cv.warpAffine(out.astype('float32'), M, out.shape[:2]) + dst = cv.warpAffine(out.astype("float32"), M, out.shape[:2]) out = cv.getRectSubPix(dst, size, center) - return out, dlc_points_shifted + def background( path_to_file: str, filename: str, - file_format: str = '.mp4', + file_format: str = ".mp4", num_frames: int = 1000, - save_background: bool = True + save_background: bool = True, ) -> np.ndarray: """ Compute background image from fixed camera. - Args: - path_to_file (str): Path to the directory containing the video files. - filename (str): Name of the video file. - file_format (str, optional): Format of the video file. Defaults to '.mp4'. - num_frames (int, optional): Number of frames to use for background computation. Defaults to 1000. - - Returns: - np.ndarray: Background image. + Parameters + ---------- + path_to_file : str + Path to the directory containing the video files. + filename : str + Name of the video file. + file_format : str, optional + Format of the video file. Defaults to '.mp4'. + num_frames : int, optional + Number of frames to use for background computation. Defaults to 1000. + + Returns + ------- + np.ndarray + Background image. """ - - capture = cv.VideoCapture(os.path.join(path_to_file,"videos",filename+file_format)) - + capture = cv.VideoCapture( + os.path.join(path_to_file, "videos", filename + file_format) + ) if not capture.isOpened(): - raise Exception("Unable to open video file: {0}".format(os.path.join(path_to_file,"videos",filename+file_format))) + raise Exception( + "Unable to open video file: {0}".format( + os.path.join(path_to_file, "videos", filename + file_format) + ) + ) frame_count = int(capture.get(cv.CAP_PROP_FRAME_COUNT)) ret, frame = capture.read() - height, width, _ = frame.shape - frames = np.zeros((height,width,num_frames)) + frames = np.zeros((height, width, num_frames)) - for i in tqdm.tqdm(range(num_frames), disable=not True, desc='Compute background image for video %s' %filename): + for i in tqdm.tqdm( + range(num_frames), + disable=not True, + desc="Compute background image for video %s" % filename, + ): rand = np.random.choice(frame_count, replace=False) - capture.set(1,rand) + capture.set(1, rand) ret, frame = capture.read() gray = cv.cvtColor(frame, cv.COLOR_BGR2GRAY) - frames[...,i] = gray + frames[..., i] = gray - logger.info('Finishing up!') - medFrame = np.median(frames,2) - background = median_filter(medFrame, (5,5)) + logger.info("Finishing up!") + medFrame = np.median(frames, 2) + background = median_filter(medFrame, (5, 5)) if save_background: - np.save(os.path.join(path_to_file,"videos",filename+'-background.npy'),background) + np.save( + os.path.join(path_to_file, "videos", filename + "-background.npy"), + background, + ) capture.release() - return background \ No newline at end of file + return background
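
For reviewers, a minimal usage sketch of the refactored entry point follows. It is not part of the patch: the config path is a placeholder, and it assumes egocentric_alignment is re-exported at the package top level (as the log message referencing vame.create_trainset() suggests); otherwise import it from vame.util.align_egocentrical.

    # Minimal sanity-check sketch (assumes an existing VAME project; path is hypothetical).
    import vame

    vame.egocentric_alignment(
        config="/path/to/my_project/config.yaml",  # placeholder project config
        pose_ref_index=[5, 6],     # indices of the two alignment keypoints
        crop_size=(300, 300),
        use_video=False,           # align DLC points only, no frame cropping
        video_format=".mp4",
        check_video=False,
        save_logs=True,            # writes logs/egocentric_alignment.log
    )
    # On success, the save_state decorator records execution_state "success" under the
    # "egocentric_alignment" key in states/states.json, and each video's shifted series
    # is saved to data/<video>/<video>-PE-seq.npy.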