refactor: improve BehaviorDataset class with train/validation split
healthonrails committed Nov 4, 2024
1 parent 6e29ed8 commit b745e72
Showing 4 changed files with 187 additions and 124 deletions.
156 changes: 91 additions & 65 deletions annolid/behavior/data_loading/datasets.py
@@ -2,10 +2,12 @@
 import os
 import cv2
 import pandas as pd
+import numpy as np
 import torch
 from torch.utils.data import Dataset
 from typing import Tuple, Dict, Callable, List, Optional
 import logging
+from sklearn.model_selection import train_test_split

 logger = logging.getLogger(__name__)

@@ -17,89 +19,99 @@

 class BehaviorDataset(Dataset):
     def __init__(self, video_folder: str, num_frames: int = NUM_FRAMES, clip_len: float = CLIP_LEN,
-                 fps: int = FPS, transform: Optional[Callable] = None, video_ext: str = ".mpg"):
+                 fps: int = FPS, transform: Optional[Callable] = None, video_ext: str = ".mpg",
+                 split: str = 'train', val_ratio: float = 0.2, random_seed: int = 42):
         """
-        Initializes the dataset with the folder containing videos and their corresponding annotations.
+        Initializes the dataset with an optional training/validation split.

         :param video_folder: Path to the folder containing video files.
         :param num_frames: Number of frames to extract per video.
         :param clip_len: Length of video clips in seconds.
         :param fps: Frames per second of the video.
         :param transform: Callable transformation to apply to frames.
         :param video_ext: Video file extension (e.g., ".mpg").
+        :param split: Either 'train' or 'val' to specify the dataset split.
+        :param val_ratio: Ratio of data for validation.
+        :param random_seed: Random seed for reproducibility.
         """
         self.video_folder = video_folder
         self.num_frames = num_frames
         self.clip_len = clip_len
         self.fps = fps
         self.video_ext = video_ext
         self.transform = transform or ResizeCenterCropNormalize()
+        self.split = split
+        self.val_ratio = val_ratio
+        self.random_seed = random_seed

         self.video_files, self.all_annotations = self.load_annotations()
+        if not self.video_files or not self.all_annotations:
+            raise ValueError("No video/annotation files found. Check paths and data.")

         self.label_mapping = self.create_label_mapping()
+        self.indices = self.create_split_indices(split, val_ratio, random_seed)
+
+        if len(self.indices) == 0:
+            raise ValueError("No samples after split. Check split ratio and data.")

-    def load_annotations(self) -> Tuple[List[str], Dict[str, pd.DataFrame]]:
-        """
-        Loads video files and their corresponding annotations (CSV).
-
-        :return: A tuple of video files and a dictionary of annotations DataFrames.
-        """
-        video_files = [f for f in os.listdir(
-            self.video_folder) if f.endswith(self.video_ext)]
-        all_annotations = {}
-
-        for video_file in video_files.copy():
-            csv_file = os.path.splitext(video_file)[0] + ".csv"
-            csv_path = os.path.join(self.video_folder, csv_file)
-
-            try:
-                all_annotations[video_file] = pd.read_csv(csv_path)
-            except FileNotFoundError:
-                logger.warning(
-                    f"CSV file not found for {video_file}. Skipping.")
-                video_files.remove(video_file)
-            except pd.errors.ParserError:
-                logger.warning(
-                    f"Error parsing CSV file for {video_file}. Skipping.")
-                video_files.remove(video_file)
-
-        return video_files, all_annotations
+    def create_split_indices(self, split: str, val_ratio: float, random_seed: int) -> List[int]:
+        """
+        Splits dataset indices for training and validation using stratified sampling.
+        """
+        np.random.seed(random_seed)
+        all_indices = np.arange(sum(len(annotations) for annotations in self.all_annotations.values()))
+        labels = []
+        for video_file, annotations in self.all_annotations.items():
+            for _, row in annotations.iterrows():
+                behavior = row.get("Behavior", "unlabeled")
+                labels.append(self.label_mapping.get(behavior, self.label_mapping["unlabeled"]))
+
+        train_indices, val_indices = train_test_split(
+            all_indices, test_size=val_ratio, stratify=labels, random_state=random_seed
+        )
+
+        return train_indices if split == 'train' else val_indices

-    def create_label_mapping(self) -> Dict[str, int]:
-        """
-        Creates a mapping of behavior labels to integers.
-
-        :return: Dictionary mapping behaviors to indices.
-        """
-        behaviors = set()
-        for annotations in self.all_annotations.values():
-            behaviors.update(annotations["Behavior"].unique())
-        return {behavior: i for i, behavior in enumerate(behaviors)}
+    def get_num_classes(self) -> int:
+        return len(self.label_mapping)

     def __len__(self) -> int:
-        """
-        Returns the total number of annotation rows across all videos.
-
-        :return: Total number of data points.
-        """
-        return sum(len(annotations) for annotations in self.all_annotations.values())
+        return len(self.indices)

-    def __getitem__(self, index: int) -> Optional[Tuple[torch.Tensor, int, str]]:
-        """
-        Retrieves the data (frames, label, and video path) for a given index.
-
-        :param index: Index in the dataset.
-        :return: A tuple containing frames, the label, and the video path, or None if an error occurs.
-        """
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int, str]:
+        """
+        Fetches data for the given index within the split. Handles potential errors.
+        """
+        index = self.indices[idx]
+
+        video_file, row_index, annotations = self.get_video_and_row_index(index)
+
+        video_path = video_file  # Path is already complete
+        behavior = annotations.iloc[row_index].get("Behavior", "unlabeled")
+        label = self.label_mapping.get(behavior, self.label_mapping["unlabeled"])
+        frames = self.load_video_frames(video_path, row_index, annotations)
+
+        if frames is None:
+            raise ValueError(f"Failed to load frames for video {video_path} at row {row_index}")
+
+        if self.transform:
+            frames = torch.stack([self.transform(frame) for frame in frames])
+
+        return frames, label, video_path

+    def fetch_data(self, index: int) -> Optional[Tuple[torch.Tensor, int, str]]:
         try:
             video_file, row_index, annotations = self.get_video_and_row_index(
                 index)
             video_path = os.path.join(self.video_folder, video_file)

-            trial_time = annotations.iloc[row_index,
-                                          annotations.columns.get_loc("Trial time")]
-            behavior = annotations.iloc[row_index,
-                                        annotations.columns.get_loc("Behavior")]
-            label = self.label_mapping[behavior]
+            # Default to "unlabeled" if "Behavior" column is missing
+            behavior = annotations.iloc[row_index]["Behavior"] if "Behavior" in annotations.columns else "unlabeled"
+            label = self.label_mapping.get(
+                behavior, self.label_mapping["unlabeled"])

             frames = self.load_video_frames(video_path, row_index, annotations)

@@ -115,15 +127,37 @@ def __getitem__(self, index: int) -> Optional[Tuple[torch.Tensor, int, str]]:
             logger.error(f"Error fetching item at index {index}: {e}")
             return None
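
An aside before the rest of the diff: create_split_indices above delegates the actual partitioning to scikit-learn's stratified train_test_split, so each behavior class keeps roughly its overall proportion in both partitions. A minimal, self-contained sketch of that behavior (toy labels, not from the repository):

    import numpy as np
    from sklearn.model_selection import train_test_split

    labels = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1])  # toy per-row behavior labels
    indices = np.arange(len(labels))

    # stratify=labels keeps the 6:4 class ratio (roughly) in both partitions.
    train_idx, val_idx = train_test_split(
        indices, test_size=0.2, stratify=labels, random_state=42)

    print(labels[val_idx])  # one sample from each class, e.g. [0 1]

One caveat worth flagging: train_test_split raises a ValueError when the least populated class has only one member, so a behavior that appears in a single annotation row will make dataset construction fail under this scheme.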

+    def load_annotations(self) -> Tuple[List[str], Dict[str, pd.DataFrame]]:
+        video_files = []
+        all_annotations = {}
+
+        for root, _, files in os.walk(self.video_folder):
+            for file in files:
+                if file.endswith(self.video_ext):
+                    video_path = os.path.join(root, file)
+                    csv_path = os.path.splitext(video_path)[0] + ".csv"
+                    video_files.append(video_path)
+
+                    try:
+                        all_annotations[video_path] = pd.read_csv(csv_path)
+                    except FileNotFoundError:
+                        logger.warning(
+                            f"CSV not found for {video_path}. Skipping.")
+                    except pd.errors.ParserError:
+                        logger.warning(
+                            f"Error parsing CSV for {video_path}. Skipping.")
+
+        return video_files, all_annotations
+
+    def create_label_mapping(self) -> Dict[str, int]:
+        behaviors = set()
+        for annotations in self.all_annotations.values():
+            behaviors.update(annotations["Behavior"].unique())
+        label_mapping = {behavior: i for i, behavior in enumerate(behaviors)}
+        label_mapping["unlabeled"] = len(label_mapping)
+        return label_mapping
+
     def load_video_frames(self, video_path: str, row_index: int, annotations: pd.DataFrame) -> Optional[torch.Tensor]:
-        """
-        Loads the video frames for a given video and row index.
-
-        :param video_path: Path to the video file.
-        :param row_index: Row index in the annotations DataFrame.
-        :param annotations: DataFrame containing the annotations for the video.
-        :return: A tensor containing the frames, or None if an error occurs.
-        """
         cap = cv2.VideoCapture(video_path)
         if not cap.isOpened():
             logger.error(f"Could not open video: {video_path}")
@@ -133,7 +167,7 @@ def load_video_frames(self, video_path: str, row_index: int, annotations: pd.DataFrame) -> Optional[torch.Tensor]:

         try:
             start_frame = int(
-                annotations.iloc[row_index, annotations.columns.get_loc("Trial time")] * self.fps)
+                annotations.iloc[row_index]["Trial time"] * self.fps)
         except KeyError as e:
             logger.error(
                 f"Missing 'Trial time' column in CSV for {video_path}: {e}")
@@ -160,7 +194,6 @@ def load_video_frames(self, video_path: str, row_index: int, annotations: pd.DataFrame) -> Optional[torch.Tensor]:
                 return None

         cap.release()
-
         if len(frames) != self.num_frames:
             logger.warning(
                 f"Expected {self.num_frames} frames, but got {len(frames)} from {video_path}")
@@ -169,13 +202,6 @@ def load_video_frames(self, video_path: str, row_index: int, annotations: pd.DataFrame) -> Optional[torch.Tensor]:
         return torch.stack(frames)

     def get_video_and_row_index(self, index: int) -> Tuple[str, int, pd.DataFrame]:
-        """
-        Maps the index to a specific video and annotation row.
-
-        :param index: Dataset index.
-        :return: A tuple containing the video file name, row index, and annotation DataFrame.
-        :raises IndexError: If the index is out of range.
-        """
         current_index = 0
         for video_file, annotations in self.all_annotations.items():
             if index < current_index + len(annotations):
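
Taken together, the datasets.py changes make the split reproducible across two dataset instances as long as both are built with the same val_ratio and random_seed, since train_test_split(random_state=...) then returns complementary index sets. A minimal usage sketch under that assumption (hypothetical folder path, not from the commit):

    from torch.utils.data import DataLoader

    train_ds = BehaviorDataset("data/videos", split='train',
                               val_ratio=0.2, random_seed=42)
    val_ds = BehaviorDataset("data/videos", split='val',
                             val_ratio=0.2, random_seed=42)

    # get_num_classes() includes the sentinel "unlabeled" entry as a class.
    num_classes = train_ds.get_num_classes()

    train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=8, shuffle=False)

Note the trade-off in this design: each instance walks the video folder and re-parses every CSV, so annotations are loaded twice; a single shared parse with two index views would avoid that, at the cost of a larger refactor.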
6 changes: 3 additions & 3 deletions annolid/behavior/models/classifier.py
@@ -65,7 +65,7 @@ def __init__(self, feature_extractor: nn.Module, d_model: int = 512, nhead: int
         self.feature_extractor = feature_extractor
         self.positional_encoding = PositionalEncoding(d_model)
         encoder_layer = nn.TransformerEncoderLayer(
-            d_model, nhead, dim_feedforward, dropout)
+            d_model, nhead, dim_feedforward, dropout, batch_first=True)
         self.transformer_encoder = nn.TransformerEncoder(
             encoder_layer, num_layers)
         self.classifier = nn.Linear(d_model, num_classes)
@@ -86,11 +86,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         # Reshape and transpose
         # (frames, batch_size, feature_dim)
-        features = features.view(batch_size, frames, -1).transpose(0, 1)
+        features = features.view(batch_size, frames, -1)

         features = self.positional_encoding(features)
         encoded_features = self.transformer_encoder(features)
         pooled_features = encoded_features.mean(
-            dim=0)  # Global average pooling
+            dim=1)  # Global average pooling
         output = self.classifier(pooled_features)
         return output
33 changes: 14 additions & 19 deletions annolid/behavior/models/feature_extractors.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torchvision.models as models
 import logging

 logger = logging.getLogger(__name__)
@@ -8,34 +9,28 @@

 class ResNetFeatureExtractor(nn.Module):
     """
     Extracts features from images using a ResNet backbone.
-
-    Args:
-        pretrained (bool, optional): Whether to use a pre-trained ResNet model. Defaults to True.
-        feature_dim (int, optional): The desired dimension of the output features. Defaults to 512.
     """

     def __init__(self, pretrained: bool = True, feature_dim: int = 512):
         super().__init__()
-        self.resnet = torch.hub.load(
-            'pytorch/vision:v0.10.0', 'resnet18', pretrained=pretrained)
-        self.resnet_in_features = self.resnet.fc.in_features  # Store the in_features
-        self.resnet.fc = nn.Identity()
+        # Use torchvision.models directly for easier weight handling
+        if pretrained:
+            # or .IMAGENET1K_V1 if you specifically need that
+            self.resnet = models.resnet18(
+                weights=models.ResNet18_Weights.DEFAULT)
+        else:
+            # Explicitly set weights to None if not pretrained
+            self.resnet = models.resnet18(weights=None)
+
+        self.resnet_in_features = self.resnet.fc.in_features
+        self.resnet.fc = nn.Identity()  # Remove the classification head

+        # Use self.resnet_in_features for the projection layer
         self.project_layer = nn.Linear(
             self.resnet_in_features, feature_dim) if feature_dim != self.resnet_in_features else None

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass through the feature extractor.
-
-        Args:
-            x (torch.Tensor): The input image tensor.
-
-        Returns:
-            torch.Tensor: The extracted feature tensor.
-        """
+        """Forward pass."""
         features = self.resnet(x)
-        if self.project_layer:  # Apply projection if feature_dim is different
+        if self.project_layer:
             features = self.project_layer(features)
         return features
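
The feature_extractors.py change swaps the torch.hub pin for torchvision's weights enum, the supported API since the pretrained= flag was deprecated in torchvision 0.13. A small sketch of the resulting extractor behavior (assumes torchvision >= 0.13; the input size is illustrative):

    import torch
    from torchvision import models

    backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    backbone.fc = torch.nn.Identity()  # drop the ImageNet head, keep 512-d features

    with torch.no_grad():
        feats = backbone(torch.randn(1, 3, 224, 224))
    print(feats.shape)  # torch.Size([1, 512])

Using ResNet18_Weights.IMAGENET1K_V1 instead of DEFAULT pins the exact legacy weights, as the inline comment in the diff hints; DEFAULT may change if torchvision ships improved weights.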