Skip to content

Commit

Permalink
Make video_det pseudolabeler take in a filtered_raw/ dir as input
Browse files Browse the repository at this point in the history
  • Loading branch information
kdu4108 committed Aug 5, 2024
1 parent e50fa4d commit dab7df1
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions pseudolabeling/pseudolabel_video_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,26 @@

# Set up argument parser
parser = argparse.ArgumentParser(description='Process video frames with YOLO')
parser.add_argument('--source_dir', type=str, required=True, help='Path to the source dir containing tar files of video shards')
parser.add_argument('--yolo_path', type=str, default="/store/swissai/a08/pseudolabelers/yolov8n.pt", help='Path to the YOLO model')
parser.add_argument("-I", '--input_dir', type=str, required=True, help='Path to the source dir containing tar files of video shards. Should be a subdir of `filtered_raw/`.')
parser.add_argument("-O", '--output_dir', type=str, default=None, help='Path to the target dir to save the bounding box outputs. Default None means it will be inferred.')
parser.add_argument("-M", '--yolo_path', type=str, default="/store/swissai/a08/pseudolabelers/yolov8n.pt", help='Path to the YOLO model')
parser.add_argument('--nth_frame', type=int, default=30, help='Select every nth frame (default: 30)')
parser.add_argument('--max_frames', type=int, default=None, help='Maximum number of frames to process (default: None, process all)')
parser.add_argument('--save_frames', type=bool, default=False, help='Whether to save frames')
args = parser.parse_args()

SOURCE_DIR = args.source_dir
if "filtered_raw" not in args.input_dir:
raise ValueError(f"Expected input dir to be a subdir of `filtered_raw/`, instead received {args.input_dir}.")

SOURCE_DIR = args.input_dir
NTH_FRAME = args.nth_frame
MAX_FRAMES = args.max_frames
SAVE_FRAMES = args.save_frames
JSON_OUTPUT_DIR = Path(SOURCE_DIR).parent.absolute() / "video_det/"
JSON_OUTPUT_DIR = (
args.output_dir
if args.output_dir is not None
else os.path.join(args.input_dir.replace("filtered_raw", "4m"), "video_det")
)

# Ensure output directories exist
os.makedirs(JSON_OUTPUT_DIR, exist_ok=True)
Expand Down

0 comments on commit dab7df1

Please sign in to comment.