Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added pseudolabel_frames.py #19

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions pseudolabeling/pseudolabel_video_det.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import tarfile
import os
import cv2
from ultralytics import YOLO
import argparse
import json
import jsonlines
from pathlib import Path
import tempfile


# Set up argument parser
parser = argparse.ArgumentParser(description='Process video frames with YOLO')
parser.add_argument("-I", '--input_dir', type=str, required=True, help='Path to the source dir containing tar files of video shards. Should be a subdir of `filtered_raw/`.')
parser.add_argument("-O", '--output_dir', type=str, default=None, help='Path to the target dir to save the bounding box outputs. Default None means it will be inferred.')
parser.add_argument("-M", '--yolo_path', type=str, default="/store/swissai/a08/pseudolabelers/yolov8n.pt", help='Path to the YOLO model')
parser.add_argument('--nth_frame', type=int, default=30, help='Select every nth frame (default: 30)')
parser.add_argument('--max_frames', type=int, default=None, help='Maximum number of frames to process (default: None, process all)')
parser.add_argument('--save_frames', type=bool, default=False, help='Whether to save frames')
args = parser.parse_args()

if "filtered_raw" not in args.input_dir:
raise ValueError(f"Expected input dir to be a subdir of `filtered_raw/`, instead received {args.input_dir}.")

SOURCE_DIR = args.input_dir
NTH_FRAME = args.nth_frame
MAX_FRAMES = args.max_frames
SAVE_FRAMES = args.save_frames
JSON_OUTPUT_DIR = (
args.output_dir
if args.output_dir is not None
else os.path.join(args.input_dir.replace("filtered_raw", "4m"), "video_det")
)

# Ensure output directories exist
os.makedirs(JSON_OUTPUT_DIR, exist_ok=True)

# Load the YOLO model
model = YOLO(args.yolo_path) # pretrained YOLOv8n model

for tfile in sorted(os.listdir(SOURCE_DIR)):
print(tfile)
if tfile.endswith(".tar"):
# Get the shard number from the input file name and create the output tar file
shard_number = os.path.splitext(os.path.basename(tfile))[0]
output_tar_path = os.path.join(JSON_OUTPUT_DIR, f"{shard_number}.tar")

# Extract the tar file
with tempfile.TemporaryDirectory() as tmpdirname:
print('created temporary directory', tmpdirname)
output_dir = os.path.join(tmpdirname, "extracted_frames")
labeled_output_dir = os.path.join(tmpdirname, "labeled_frames")

os.makedirs(output_dir, exist_ok=True)
os.makedirs(labeled_output_dir, exist_ok=True)

with tarfile.open(os.path.join(SOURCE_DIR, tfile), "r") as tar:
tar.extractall(path=tmpdirname, numeric_owner=True)



with tarfile.open(output_tar_path, "w") as output_tar:
# Iterate through extracted files
for root, dirs, files in sorted(os.walk(tmpdirname)):
for file in files:
if file.endswith(".mp4"):
video_path = os.path.join(root, file)
video = cv2.VideoCapture(video_path)
frame_paths = []
frame_count = 0
processed_frames = 0
json_data_list = []

while True:
success, frame = video.read()
if not success:
break
if frame_count % NTH_FRAME == 0:
# Save the frame as an image
frame_path = os.path.join(output_dir, f"{file[:-4]}_frame_{frame_count}.jpg")
cv2.imwrite(frame_path, frame)
frame_paths.append(frame_path)
processed_frames += 1
frame_count += 1
if MAX_FRAMES and processed_frames >= MAX_FRAMES:
break
video.release()

# Apply pseudolabeling to the extracted frames
import pdb; pdb.set_trace()
results = model(frame_paths, project=labeled_output_dir, name=file[:-4])

for i, result in enumerate(results):
# Save labeled image
if SAVE_FRAMES:
result.save(filename=f'{file[:-4]}_labeled_frame_{i}.jpg')

# Extract bounding box information
boxes = result.boxes
frame_data = []
for box in boxes:
xyxy = box.xyxy[0].tolist() # get box coordinates
conf = box.conf.item() # get confidence score
cls = int(box.cls.item()) # get class id
frame_data.append({
"bbox": xyxy,
"confidence": conf,
"class": cls,
"class_name": result.names[cls]
})
json_data_list.append(frame_data)

# Save JSONL file
jsonl_filename = f"{file[:-4]}.jsonl"
jsonl_path = os.path.join(JSON_OUTPUT_DIR, jsonl_filename)
with jsonlines.open(jsonl_path, mode='w') as writer:
writer.write_all(json_data_list)

# Add JSONL file to the tar archive
output_tar.add(jsonl_path, arcname=jsonl_filename)

# Remove the temporary JSONL file
os.remove(jsonl_path)

print("Frame extraction, pseudolabeling, and JSONL export complete.")