Skip to content

Commit

Permalink
Merge branch 'main' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
healthonrails authored Apr 18, 2024
2 parents 33156eb + bfc90ff commit 99b5af9
Show file tree
Hide file tree
Showing 8 changed files with 399 additions and 13 deletions.
271 changes: 271 additions & 0 deletions annolid/annotation/sleap2labelme.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
import os
import h5py
from annolid.annotation.keypoints import save_labels
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from annolid.gui.shape import Shape
from annolid.data.videos import CV2Video
from annolid.utils.shapes import load_zone_json, is_point_in_polygon


def fill_missing(Y, kind="linear"):
"""Fills missing values independently along each dimension after the first.
Reference: https://sleap.ai/notebooks/Analysis_examples.html
"""

# Store initial shape.
initial_shape = Y.shape

# Flatten after first dim.
Y = Y.reshape((initial_shape[0], -1))

# Interpolate along each slice.
for i in range(Y.shape[-1]):
y = Y[:, i]

# Build interpolant.
x = np.flatnonzero(~np.isnan(y))
f = interp1d(x, y[x], kind=kind, fill_value=np.nan, bounds_error=False)

# Fill missing
xq = np.flatnonzero(np.isnan(y))
y[xq] = f(xq)

# Fill leading or trailing NaNs with the nearest non-NaN values
mask = np.isnan(y)
y[mask] = np.interp(np.flatnonzero(
mask), np.flatnonzero(~mask), y[~mask])

# Save slice
Y[:, i] = y

# Restore to initial shape.
Y = Y.reshape(initial_shape)

return Y


def plot_keypoint_locations(filename, keypoint_name):
with h5py.File(filename, "r") as f:
dset_names = list(f.keys())
locations = f["tracks"][:].T
node_names = [n.decode() for n in f["node_names"][:]]

# Create KEYPOINT_MAP from node_names
KEYPOINT_MAP = {name.lower(): i for i, name in enumerate(node_names)}

assert keypoint_name.lower() in KEYPOINT_MAP, "Invalid keypoint name. Supported keypoints: " + \
", ".join(KEYPOINT_MAP.keys())
keypoint_index = KEYPOINT_MAP[keypoint_name.lower()]

print("===filename===")
print(filename)
print()

print("===HDF5 datasets===")
print(dset_names)
print()

print("===locations data shape===")
print(locations.shape)
print()

print("===nodes===")
for i, name in enumerate(node_names):
print(f"{i}: {name}")
print()

frame_count, node_count, _, instance_count = locations.shape

print("frame count:", frame_count)
print("node count:", node_count)
print("instance count:", instance_count)

locations = fill_missing(locations)

keypoint_loc = locations[:, keypoint_index, :, :]

plt.figure()
plt.plot(keypoint_loc[:, 0, 0], 'y', label=f"{keypoint_name}-0")
plt.plot(-1 * keypoint_loc[:, 1, 0], 'y')
plt.legend(loc="center right")
plt.title(f"{keypoint_name.capitalize()} locations")

plt.figure(figsize=(7, 7))
plt.plot(keypoint_loc[:, 0, 0], keypoint_loc[:, 1, 0],
'y', label=f"{keypoint_name}-0")
plt.legend()
plt.xlim(0, 1024)
plt.xticks([])
plt.ylim(0, 1024)
plt.yticks([])
plt.title(f"{keypoint_name.capitalize()} tracks")

plt.show()


def plot_all_keypoints(filename):
"""
Plot the locations and tracks of all keypoints in the provided HDF5 file.
Parameters:
filename (str): Path to the HDF5 file containing the keypoints data.
Returns:
None
"""
with h5py.File(filename, "r") as f:
# Extract datasets
dset_names = list(f.keys())
locations = f["tracks"][:].T
node_names = [n.decode() for n in f["node_names"][:]]

# Create KEYPOINT_MAP from node_names
KEYPOINT_MAP = {name.lower(): i for i, name in enumerate(node_names)}

print("===filename===")
print(filename)
print()

print("===HDF5 datasets===")
print(dset_names)
print()

print("===nodes===")
for i, name in enumerate(node_names):
print(f"{i}: {name}")
print()

frame_count, node_count, _, instance_count = locations.shape

print("===locations data shape===")
print(locations.shape)
print()

print("frame count:", frame_count)
print("node count:", node_count)
print("instance count:", instance_count)

locations = fill_missing(locations)

# Plot each keypoint
plt.figure(figsize=(15, 10))
for keypoint_name, keypoint_index in KEYPOINT_MAP.items():
keypoint_loc = locations[:, keypoint_index, :, :]
plt.plot(keypoint_loc[:, 0, 0], keypoint_loc[:, 1, 0],
label=f"{keypoint_name.capitalize()}-0")

plt.legend()
plt.xlim(0, 1024)
plt.ylim(0, 1024)
plt.title('All Keypoints Tracks')
plt.xlabel('X Coordinate')
plt.ylabel('Y Coordinate')
plt.grid(True)
plt.show()


def get_frame_info(video_file=None):
"""
Get frame information (height, width, and total number of frames) from a video file.
Parameters:
video_file (str, optional): Path to the video file. Defaults to None.
Returns:
tuple: A tuple containing height, width, and number of frames.
"""
if video_file is not None:
video_loader = CV2Video(video_file)
first_frame = video_loader.get_first_frame()
height = video_loader.get_height()
width = video_loader.get_width()
num_frames = video_loader.total_frames()
else:
height, width, num_frames = 600, 800, 89761
return height, width, num_frames


def convert_sleap_h5_to_labelme(h5_file_path,
zone_info=None):
"""
Convert a SLEAP HDF5 file to Labelme JSON files.
Parameters:
h5_file_path (str): Path to the SLEAP HDF5 file.
Returns:
None
"""
# Output folder name without extension
output_folder = os.path.splitext(h5_file_path)[0]
# Create output folder if not exists
os.makedirs(output_folder, exist_ok=True)
video_file = h5_file_path.replace('.h5', 'mp4')
if os.path.exists(video_file):
height, width, _ = get_frame_info(video_file)
else:
height, width, _ = 600, 800, 89761

# Determine image information
video_name = os.path.splitext(os.path.basename(h5_file_path))[0]
if zone_info is None:
zone_info = os.path.join(output_folder, f"{video_name}_000000000.json")
has_zone_info = os.path.exists(zone_info)
zone_shapes = load_zone_json(zone_info)['shapes']
zone_shapes = [
zone_shape for zone_shape in zone_shapes if 'zone' in zone_shape['description'].lower()]

with h5py.File(h5_file_path, 'r') as f:
# Extract relevant datasets
locations = f["tracks"][:].T
locations = fill_missing(locations)
node_names = [n.decode() for n in f["node_names"][:]]

# Get the dimensions from the data
frame_count, node_count, _, instance_count = locations.shape

# Iterate through frames
for frame_idx in range(frame_count):
shape_list = []
for instance_idx in range(instance_count):
# Iterate through nodes in the instance

for node_idx, node_name in enumerate(node_names):
# Extract coordinates
x, y = locations[frame_idx, node_idx, :, instance_idx]
point_in_zone = []
if has_zone_info:
for zone_shape in zone_shapes:
if not is_point_in_polygon((x, y), zone_shape['points']):
point_in_zone.append(False)
if len(point_in_zone) >= len(zone_shapes):
continue
# Create Labelme shape
shape = Shape(label=node_name,
shape_type='point',
group_id=None,
flags={},
visible=True
)
shape.points = [[x, y]]
# Add shape to annotation
shape_list.append(shape)

json_file = os.path.join(
output_folder, f"{video_name}_{frame_idx:0>{9}}.json")
img_path = json_file.replace('.json', '.png')
save_labels(json_file, img_path, shape_list, height, width)
if frame_idx % 100 == 0:
print(f"Saving file {json_file}")


if __name__ == '__main__':
# Example usage
# plot_keypoint_locations("/Downloads/R2311_P4S1_reencoded.h5",
# keypoint_name='head')
# plot_all_keypoints("/Downloads/R2311_P4S1_reencoded.h5")
# plt.show()
convert_sleap_h5_to_labelme(
"/Downloads/R2311_P4S1_reencoded.h5")
14 changes: 14 additions & 0 deletions annolid/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from annolid.gui.widgets.video_slider import VideoSlider, VideoSliderMark
from annolid.gui.widgets.step_size_widget import StepSizeWidget
from annolid.gui.widgets.downsample_videos_dialog import VideoRescaleWidget
from annolid.gui.widgets.convert_sleap_dialog import ConvertSleapDialog
from annolid.postprocessing.quality_control import pred_dict_to_labelme
from annolid.annotation.timestamps import convert_frame_number_to_time
from annolid.segmentation.SAM.edge_sam_bg import VideoProcessor
Expand Down Expand Up @@ -357,6 +358,14 @@ def __init__(self,
self.tr("Downsample Videos")
)

convert_sleap = action(
self.tr("&Load SLEAP h5"),
self.convert_sleap_h5_to_labelme,
None,
"Load SLEAP h5",
self.tr("Load SLEAP h5")
)

step_size = QtWidgets.QWidgetAction(self)
step_size.setIcon(QtGui.QIcon(
str(
Expand Down Expand Up @@ -510,6 +519,7 @@ def __init__(self,
utils.addActions(self.menus.file, (quality_control,))
utils.addActions(self.menus.file, (segment_cells,))
utils.addActions(self.menus.file, (downsample_video,))
utils.addActions(self.menus.file, (convert_sleap,))
utils.addActions(self.menus.file, (advance_params,))

utils.addActions(self.menus.view, (glitter2,))
Expand Down Expand Up @@ -589,6 +599,10 @@ def downsample_videos(self):
video_downsample_widget = VideoRescaleWidget()
video_downsample_widget.exec_()

def convert_sleap_h5_to_labelme(self):
convert_sleap_h5_widget = ConvertSleapDialog()
convert_sleap_h5_widget.exec_()

def openAudio(self):
if self.video_file:
self.audio_widget = AudioWidget(self.video_file)
Expand Down
3 changes: 2 additions & 1 deletion annolid/gui/label_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def load(self, filename):
elif data["imagePath"] is not None and data["imagePath"] != "":
# relative path from label file to relative path from cwd
imagePath = osp.join(osp.dirname(filename), data["imagePath"])
imageData = self.load_image_file(imagePath)
if osp.exists(imagePath):
imageData = self.load_image_file(imagePath)
flags = data.get("flags") or {}
if self.is_video_frame is None and data["imagePath"] is not None:
imagePath = data["imagePath"]
Expand Down
59 changes: 59 additions & 0 deletions annolid/gui/widgets/convert_sleap_dialog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import sys
from qtpy.QtWidgets import (QApplication, QDialog,
QPushButton, QFileDialog,
QMessageBox, QLineEdit)
from annolid.annotation.sleap2labelme import convert_sleap_h5_to_labelme


class ConvertSleapDialog(QDialog):
def __init__(self):
super().__init__()
self.initUI()

def initUI(self):
self.setWindowTitle('Convert H5 to Labelme')
self.setGeometry(100, 100, 400, 200)

self.txt_selected_file = QLineEdit(self)
self.txt_selected_file.setGeometry(50, 20, 300, 20)
self.txt_selected_file.setReadOnly(True)

self.btn_select_file = QPushButton('Select H5 File', self)
self.btn_select_file.setGeometry(50, 50, 200, 30)
self.btn_select_file.clicked.connect(self.select_file)

self.btn_run = QPushButton('Run', self)
self.btn_run.setGeometry(50, 100, 100, 30)
self.btn_run.clicked.connect(self.run_conversion)
self.btn_run.setEnabled(False)

self.btn_close = QPushButton('Close', self)
self.btn_close.setGeometry(200, 100, 100, 30)
self.btn_close.clicked.connect(self.close)

def select_file(self):
file_dialog = QFileDialog(self)
file_dialog.setNameFilter("H5 Files (*.h5)")
file_dialog.setWindowTitle("Select H5 File")
file_dialog.setFileMode(QFileDialog.ExistingFile)
if file_dialog.exec_():
file_paths = file_dialog.selectedFiles()
if file_paths:
self.h5_file_path = file_paths[0]
self.txt_selected_file.setText(self.h5_file_path)
self.btn_run.setEnabled(True)

def run_conversion(self):
try:
convert_sleap_h5_to_labelme(self.h5_file_path)
QMessageBox.information(
self, "Success", "Conversion completed successfully.")
except Exception as e:
QMessageBox.critical(self, "Error", f"An error occurred: {str(e)}")


if __name__ == '__main__':
app = QApplication(sys.argv)
dialog = ConvertSleapDialog()
dialog.show()
sys.exit(app.exec_())
Loading

0 comments on commit 99b5af9

Please sign in to comment.