Skip to content

Commit

Permalink
Add cls label to plot
Browse files Browse the repository at this point in the history
  • Loading branch information
iegrsy committed Nov 15, 2023
1 parent f4e9791 commit f7a4e16
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,20 @@ def __str__(self) -> str:
x1, y1, x2, y2 = detection.xyxy.tolist()[0]
cropped_image = frame[int(y1):int(y2), int(x1):int(x2)]

# TODO: use custom cls
# obj_cls = model_cls.predict(cropped_image, conf=opts.conf, verbose=opts.verbose, classes=__classes)
# print(f"[{obj_cls[0].probs.top1}] {obj_cls[0].names[obj_cls[0].probs.top1]}")
obj_cls = model_cls.predict(
cropped_image, conf=opts.conf, verbose=opts.verbose, classes=__classes)
clsid = obj_cls[0].probs.top1
box_label = obj_cls[0].names[clsid]
print(f"[{clsid}] {box_label}")

if opts.showpreview:
color = (0, 255, 0)
if clsid == 0:
color = (0, 0, 255)

cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
cv2.putText(frame, box_label, (int(x1), int(y1) - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

ffolder = os.path.join(f"build", f"{o.name}")
if not os.path.exists(ffolder):
Expand All @@ -107,7 +118,7 @@ def __str__(self) -> str:
process_total = process_total + ellapse_time

if opts.showpreview:
frame = detections[0].plot()
# frame = detections[0].plot()
cv2.imshow('Frame', frame)

if cv2.waitKey(1) & 0xFF == ord('q'):
Expand Down

0 comments on commit f7a4e16

Please sign in to comment.