Skip to content

Commit

Permalink
Implement Torchserve + bounding boxes.
Browse files Browse the repository at this point in the history
  • Loading branch information
anarkiwi committed Oct 17, 2023
1 parent 7ec97bc commit 3760b2d
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 107 deletions.
11 changes: 7 additions & 4 deletions gamutrf/grscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
sys.exit(1)

from gamutrf.grsource import get_source
from gamutrf.gryolo import yolo_bbox
from gamutrf.utils import endianstr


Expand All @@ -36,6 +37,8 @@ def __init__(
freq_end=1e9,
freq_start=100e6,
igain=0,
inference_min_confidence=0.5,
inference_nms_threshold=0.5,
inference_min_db=-200,
inference_model_server="",
inference_model_name="",
Expand Down Expand Up @@ -207,10 +210,10 @@ def __init__(
self.inference_blocks = [
blocks.stream_to_vector(gr.sizeof_float * nfft, 1),
self.image_inference_block,
blocks.file_sink(
gr.sizeof_char,
str(Path(inference_output_dir, "predictions.txt")),
False,
yolo_bbox(
str(Path(inference_output_dir, "predictions")),
inference_min_confidence,
inference_nms_threshold,
),
]

Expand Down
33 changes: 33 additions & 0 deletions gamutrf/grterminal_sink.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
import numpy as np

Check warning on line 4 in gamutrf/grterminal_sink.py

View check run for this annotation

Codecov / codecov/patch

gamutrf/grterminal_sink.py#L3-L4

Added lines #L3 - L4 were not covered by tests

try:
from gnuradio import gr # pytype: disable=import-error

Check warning on line 7 in gamutrf/grterminal_sink.py

View check run for this annotation

Codecov / codecov/patch

gamutrf/grterminal_sink.py#L6-L7

Added lines #L6 - L7 were not covered by tests
except ModuleNotFoundError as err: # pragma: no cover
print(
"Run from outside a supported environment, please run via Docker (https://github.com/IQTLabs/gamutRF#readme): %s"
% err
)
sys.exit(1)


class terminal_sink(gr.sync_block):

Check warning on line 16 in gamutrf/grterminal_sink.py

View check run for this annotation

Codecov / codecov/patch

gamutrf/grterminal_sink.py#L16

Added line #L16 was not covered by tests
# Prints output layer of classifier to terminal for troubleshooting
def __init__(self, input_vlen, batch_size):
self.input_vlen = input_vlen
self.batch_size = batch_size
gr.sync_block.__init__(

Check warning on line 21 in gamutrf/grterminal_sink.py

View check run for this annotation

Codecov / codecov/patch

gamutrf/grterminal_sink.py#L18-L21

Added lines #L18 - L21 were not covered by tests
self,
name="terminal_sink",
in_sig=[(np.float32, self.input_vlen)],
out_sig=None,
)
self.batch_ctr = 0

Check warning on line 27 in gamutrf/grterminal_sink.py

View check run for this annotation

Codecov / codecov/patch

gamutrf/grterminal_sink.py#L27

Added line #L27 was not covered by tests

def work(self, input_items, output_items):
in0 = input_items[0]
_batch = in0.reshape(self.batch_size, -1)
self.batch_ctr += 1
return len(input_items[0])

Check warning on line 33 in gamutrf/grterminal_sink.py

View check run for this annotation

Codecov / codecov/patch

gamutrf/grterminal_sink.py#L29-L33

Added lines #L29 - L33 were not covered by tests
160 changes: 57 additions & 103 deletions gamutrf/gryolo.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import cv2
import numpy as np
import pmt
import json
import os
import sys
from pathlib import Path
import cv2
import numpy as np

try:
from gnuradio import gr # pytype: disable=import-error
Expand All @@ -16,149 +17,102 @@
sys.exit(1)


class terminal_sink(gr.sync_block):
# Prints output layer of classifier to terminal for troubleshooting
def __init__(self, input_vlen, batch_size):
self.input_vlen = input_vlen
self.batch_size = batch_size
gr.sync_block.__init__(
self,
name="terminal_sink",
in_sig=[(np.float32, self.input_vlen)],
out_sig=None,
)
self.batch_ctr = 0
DELIM = "\n\n"

def work(self, input_items, output_items):
in0 = input_items[0]
batch = in0.reshape(self.batch_size, -1)
self.batch_ctr += 1
return len(input_items[0])


# TODO: refactor to consume predictions feed, and retrieve png file from local disk
# (provided by image_inference block)
class yolo_bbox(gr.sync_block):
def __init__(
self,
image_shape,
prediction_shape,
batch_size,
sample_rate,
output_dir,
confidence_threshold=0.5,
nms_threshold=0.5,
):
self.image_shape = image_shape
self.image_vlen = np.prod(image_shape)
self.prediction_shape = prediction_shape
self.prediction_vlen = np.prod(prediction_shape)
self.batch_size = batch_size
self.sample_rate = sample_rate
self.output_dir = output_dir
self.confidence_threshold = confidence_threshold
self.nms_threshold = nms_threshold
self.yaml_buffer = ""

gr.sync_block.__init__(
self,
name="yolo_bbox",
in_sig=[(np.float32, self.image_vlen), (np.float32, self.prediction_vlen)],
in_sig=[np.ubyte],
out_sig=None,
)

def draw_bounding_box(self, img, class_id, confidence, x, y, x_plus_w, y_plus_h):
# label = f'{CLASSES[class_id]} ({confidence:.2f})'
# label = f'{class_id} ({confidence:.2f})'
label = f"{class_id}"
color = (255, 255, 255) # self.colors[class_id]
def work(self, input_items, output_items):
n = 0
for input_item in input_items:
raw_input_item = input_item.tobytes().decode("utf8")
n += len(raw_input_item)
self.yaml_buffer += raw_input_item
while True:
delim_pos = self.yaml_buffer.find(DELIM)
if delim_pos == -1:
break
raw_item = self.yaml_buffer[:delim_pos]
item = json.loads(raw_item)
self.yaml_buffer = self.yaml_buffer[delim_pos + len(DELIM) :]
self.process_item(item)
return n

Check warning on line 56 in gamutrf/gryolo.py

View check run for this annotation

Codecov / codecov/patch

gamutrf/gryolo.py#L43-L56

Added lines #L43 - L56 were not covered by tests

def draw_bounding_box(self, img, name, confidence, x, y, x_plus_w, y_plus_h):
label = f"{name}"
color = (255, 255, 255)

Check warning on line 60 in gamutrf/gryolo.py

View check run for this annotation

Codecov / codecov/patch

gamutrf/gryolo.py#L59-L60

Added lines #L59 - L60 were not covered by tests
cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2)
cv2.putText(
img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2
)

def work(self, input_items, output_items):
rx_times = [
sum(pmt.to_python(rx_time_pmt.value))
for rx_time_pmt in self.get_tags_in_window(
0, 0, len(input_items[0]), pmt.to_pmt("rx_time")
)
]
rx_freqs = [
pmt.to_python(rx_freq_pmt.value)
for rx_freq_pmt in self.get_tags_in_window(
0, 0, len(input_items[0]), pmt.to_pmt("rx_freq")
)
]

image = input_items[0][0]
image = image.reshape(self.image_shape)
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
original_image = image
[height, width, _] = original_image.shape
length = max((height, width))
image = np.zeros((length, length, 3), np.uint8)
image[0:height, 0:width] = original_image
scale = length / 640

prediction = input_items[1][0]
prediction = prediction.reshape(self.prediction_shape)
prediction = np.array([cv2.transpose(prediction[0])])
rows = prediction.shape[1]
def process_item(self, item):
predictions = item.get("predictions", None)
if not predictions:
return

Check warning on line 69 in gamutrf/gryolo.py

View check run for this annotation

Codecov / codecov/patch

gamutrf/gryolo.py#L67-L69

Added lines #L67 - L69 were not covered by tests

boxes = []
scores = []
class_ids = []
detections = []

Check warning on line 73 in gamutrf/gryolo.py

View check run for this annotation

Codecov / codecov/patch

gamutrf/gryolo.py#L73

Added line #L73 was not covered by tests

for i in range(rows):
classes_scores = prediction[0][i][4:]
(minScore, maxScore, minClassLoc, (x, maxClassIndex)) = cv2.minMaxLoc(
classes_scores
)
if maxScore >= self.confidence_threshold:
for name, prediction_data in predictions.items():
for prediction in prediction_data:
conf = prediction["conf"]
if conf < self.confidence_threshold:
continue
xywh = prediction["xywh"]

Check warning on line 80 in gamutrf/gryolo.py

View check run for this annotation

Codecov / codecov/patch

gamutrf/gryolo.py#L75-L80

Added lines #L75 - L80 were not covered by tests
box = [
prediction[0][i][0] - (0.5 * prediction[0][i][2]),
prediction[0][i][1] - (0.5 * prediction[0][i][3]),
prediction[0][i][2],
prediction[0][i][3],
xywh[0] - (0.5 * xywh[2]),
xywh[1] - (0.5 * xywh[3]),
xywh[2],
xywh[3],
]
detections.append({"box": box, "score": conf, "name": name})

Check warning on line 87 in gamutrf/gryolo.py

View check run for this annotation

Codecov / codecov/patch

gamutrf/gryolo.py#L87

Added line #L87 was not covered by tests
boxes.append(box)
scores.append(maxScore)
class_ids.append(maxClassIndex)
scores.append(conf)

Check warning on line 89 in gamutrf/gryolo.py

View check run for this annotation

Codecov / codecov/patch

gamutrf/gryolo.py#L89

Added line #L89 was not covered by tests

if not detections:
return

Check warning on line 92 in gamutrf/gryolo.py

View check run for this annotation

Codecov / codecov/patch

gamutrf/gryolo.py#L91-L92

Added lines #L91 - L92 were not covered by tests

original_image = cv2.imread(item["image_path"])

Check warning on line 94 in gamutrf/gryolo.py

View check run for this annotation

Codecov / codecov/patch

gamutrf/gryolo.py#L94

Added line #L94 was not covered by tests
result_boxes = cv2.dnn.NMSBoxes(
boxes, scores, self.confidence_threshold, self.nms_threshold, 0.5, 200
)

detections = []
for i in range(len(result_boxes)):
index = result_boxes[i]
box = boxes[index]
detection = {
"class_id": class_ids[index],
#'class_name': CLASSES[class_ids[index]],
"confidence": scores[index],
"box": box,
"scale": scale,
}
detections.append(detection)
# TODO: output to ZMQ
for detection in detections:

Check warning on line 100 in gamutrf/gryolo.py

View check run for this annotation

Codecov / codecov/patch

gamutrf/gryolo.py#L100

Added line #L100 was not covered by tests
self.draw_bounding_box(
original_image,
class_ids[index],
scores[index],
round(box[0] * scale),
round(box[1] * scale),
round((box[0] + box[2]) * scale),
round((box[1] + box[3]) * scale),
detection["name"],
detection["score"],
round(box[0]),
round(box[1]),
round(box[0] + box[2]),
round(box[1] + box[3]),
)

Path(self.output_dir, "predictions").mkdir(parents=True, exist_ok=True)
Path(self.output_dir).mkdir(parents=True, exist_ok=True)

Check warning on line 111 in gamutrf/gryolo.py

View check run for this annotation

Codecov / codecov/patch

gamutrf/gryolo.py#L111

Added line #L111 was not covered by tests
filename = str(
Path(
self.output_dir,
"predictions",
f"prediction_{rx_times[-1]:.3f}_{rx_freqs[-1]:.0f}Hz_{self.sample_rate:.0f}sps.png",
"_".join(["prediction", os.path.basename(item["image_path"])]),
)
)
cv2.imwrite(filename, original_image)

return 1 # len(input_items[0])
14 changes: 14 additions & 0 deletions gamutrf/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,20 @@ def argument_parser():
action="store_true",
help="Use dc_block_cc long form",
)
parser.add_argument(
"--inference_min_confidence",
dest="inference_min_confidence",
type=float,
default=0.5,
help="minimum confidence score to plot",
)
parser.add_argument(
"--inference_nms_confidence",
dest="inference_nms_threshold",
type=float,
default=0.5,
help="NMS threshold",
)
parser.add_argument(
"--inference_min_db",
dest="inference_min_db",
Expand Down

0 comments on commit 3760b2d

Please sign in to comment.