Skip to content

Commit

Permalink
v0.5.0: Introduce RTSP stream support and multi-stream handling
Browse files Browse the repository at this point in the history
- Added RTSP stream support in `YOFLO` class.
- Enabled multiple RTSP URLs with `--rtsp_urls` argument.
- Modified `start_webcam_detection` to prioritize RTSP streams over local webcams.
- Improved error handling and logging for RTSP streams.
- Refactored `_webcam_detection_thread` to support both RTSP and webcam sources.
  • Loading branch information
CharlesCNorton committed Jul 23, 2024
1 parent 3795a1f commit 4270cbc
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 28 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='yoflo',
version='0.4.3',
version='0.5.0',
packages=find_packages(),
include_package_data=True,
install_requires=[
Expand Down
70 changes: 43 additions & 27 deletions yoflo/yoflo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
inference_limit=None,
class_names=None,
webcam_indices=None,
rtsp_urls=None,
):
"""Initialize the YO-FLO class with configuration options."""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -49,6 +50,7 @@ def __init__(
self.last_inference_time = 0
self.inference_phrases = []
self.webcam_indices = webcam_indices if webcam_indices else [0]
self.rtsp_urls = rtsp_urls if rtsp_urls else []
if model_path:
self.init_model(model_path)

Expand All @@ -62,7 +64,6 @@ def init_model(self, model_path):
f"Model path {os.path.abspath(model_path)} is not a directory."
)
return

try:
logging.info(f"Attempting to load model from {os.path.abspath(model_path)}")
self.model = (
Expand Down Expand Up @@ -216,34 +217,45 @@ def download_model(self):
logging.error(f"Error downloading model: {e}")

def start_webcam_detection(self):
"""Start separate threads for each specified webcam."""
"""Start separate threads for each specified webcam or RTSP stream."""
try:
if self.webcam_threads:
logging.warning("Webcam detection is already running.")
return
self.stop_webcam_flag.clear()
for index in self.webcam_indices:
thread = threading.Thread(
target=self._webcam_detection_thread, args=(index,)
)
thread.start()
self.webcam_threads.append(thread)
if self.rtsp_urls:
for rtsp_url in self.rtsp_urls:
thread = threading.Thread(
target=self._webcam_detection_thread, args=(rtsp_url,)
)
thread.start()
self.webcam_threads.append(thread)
else:
for index in self.webcam_indices:
thread = threading.Thread(
target=self._webcam_detection_thread, args=(index,)
)
thread.start()
self.webcam_threads.append(thread)
except Exception as e:
logging.error(f"Error starting webcam detection: {e}")

def _webcam_detection_thread(self, index):
"""Run the webcam detection loop in a separate thread for a specific webcam."""
def _webcam_detection_thread(self, source):
"""Run the webcam detection loop in a separate thread for a specific webcam or RTSP stream."""
try:
cap = cv2.VideoCapture(index)
if isinstance(source, str):
cap = cv2.VideoCapture(source)
else:
cap = cv2.VideoCapture(source)
if not cap.isOpened():
logging.error(f"Error: Could not open webcam {index}.")
logging.error(f"Error: Could not open video source {source}.")
return
window_name = f"Object Detection Webcam {index}"
window_name = f"Object Detection Source {source}"
while not self.stop_webcam_flag.is_set():
ret, frame = cap.read()
if not ret:
logging.error(
f"Error: Failed to capture image from webcam {index}."
f"Error: Failed to capture image from source {source}."
)
break
image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
Expand All @@ -268,7 +280,7 @@ def _webcam_detection_thread(self, index):
self.pretty_print_detections(filtered_detections)
else:
logging.info(
f"Detections from webcam {index}: {filtered_detections}"
f"Detections from source {source}: {filtered_detections}"
)
if not self.headless:
frame = self.plot_bbox(frame, filtered_detections)
Expand All @@ -279,7 +291,7 @@ def _webcam_detection_thread(self, index):
self.save_screenshot(frame)
if self.log_to_file_active:
self.log_alert(
f"Detections from webcam {index}: {filtered_detections}"
f"Detections from source {source}: {filtered_detections}"
)
elif self.phrase:
results = self.run_expression_comprehension(image_pil, self.phrase)
Expand All @@ -296,19 +308,19 @@ def _webcam_detection_thread(self, index):
if clean_result in ["yes", "no"] and self.log_to_file_active:
if self.log_to_file_active:
self.log_alert(
f"Expression Comprehension from webcam {index}: {clean_result} at {datetime.now()}"
f"Expression Comprehension from source {source}: {clean_result} at {datetime.now()}"
)
if self.inference_phrases:
inference_result, phrase_results = self.evaluate_inference_chain(
image_pil
)
logging.info(
f"Inference Chain result from webcam {index}: {inference_result}, Details: {phrase_results}"
f"Inference Chain result from source {source}: {inference_result}, Details: {phrase_results}"
)
if self.pretty_print:
for idx, result in enumerate(phrase_results):
logging.info(
f"Inference {idx + 1} from webcam {index}: {'PASS' if result else 'FAIL'}"
f"Inference {idx + 1} from source {source}: {'PASS' if result else 'FAIL'}"
)
self.inference_count += 1
self.update_inference_rate()
Expand All @@ -321,9 +333,9 @@ def _webcam_detection_thread(self, index):
if not self.headless:
cv2.destroyWindow(window_name)
except cv2.error as e:
logging.error(f"OpenCV error in webcam detection thread {index}: {e}")
logging.error(f"OpenCV error in detection thread {source}: {e}")
except Exception as e:
logging.error(f"Error in webcam detection thread {index}: {e}")
logging.error(f"Error in detection thread {source}: {e}")

def stop_webcam_detection(self):
"""Stop all webcam detection threads."""
Expand Down Expand Up @@ -399,7 +411,6 @@ def evaluate_inference_chain(self, image):
if not self.inference_phrases:
logging.error("No inference phrases set.")
return "FAIL", []

results = []
for phrase in self.inference_phrases:
result = self.run_expression_comprehension(image, phrase)
Expand All @@ -408,7 +419,6 @@ def evaluate_inference_chain(self, image):
results.append(True)
else:
results.append(False)

overall_result = "PASS" if results.count(True) >= 2 else "FAIL"
return overall_result, results
except Exception as e:
Expand Down Expand Up @@ -483,6 +493,13 @@ def main():
type=int,
help="Specify the indices of the webcams to use (e.g., 0 1 2).",
)
parser.add_argument(
"-rtsp",
"--rtsp_urls",
nargs="+",
type=str,
help="Specify the RTSP URLs for the video streams.",
)

group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
Expand All @@ -501,17 +518,18 @@ def main():
args = parser.parse_args()
if not args.model_path and not args.download_model:
parser.error("You must specify either --model_path or --download_model.")

try:
setup_logging(args.log_to_file)
webcam_indices = args.webcam_indices if args.webcam_indices else [0]
rtsp_urls = args.rtsp_urls if args.rtsp_urls else []
if args.download_model:
yo_flo = YOFLO(
display_inference_speed=args.display_inference_speed,
pretty_print=args.pretty_print,
inference_limit=args.inference_limit,
class_names=args.object_detection,
webcam_indices=webcam_indices,
rtsp_urls=rtsp_urls,
)
yo_flo.download_model()
else:
Expand All @@ -528,8 +546,8 @@ def main():
inference_limit=args.inference_limit,
class_names=args.object_detection,
webcam_indices=webcam_indices,
rtsp_urls=rtsp_urls,
)

if args.phrase:
yo_flo.phrase = args.phrase
if args.inference_chain:
Expand All @@ -545,10 +563,8 @@ def main():
time.sleep(1)
except KeyboardInterrupt:
yo_flo.stop_webcam_detection()

except Exception as e:
logging.error(f"An error occurred during main loop: {e}")

else:
input("Press Enter to stop...")
yo_flo.stop_webcam_detection()
Expand Down

0 comments on commit 4270cbc

Please sign in to comment.