diff --git a/gamutrf/grscan.py b/gamutrf/grscan.py index 3e81a594..050c7941 100644 --- a/gamutrf/grscan.py +++ b/gamutrf/grscan.py @@ -20,6 +20,7 @@ sys.exit(1) from gamutrf.grsource import get_source +from gamutrf.gryolo import yolo_bbox from gamutrf.utils import endianstr @@ -36,6 +37,8 @@ def __init__( freq_end=1e9, freq_start=100e6, igain=0, + inference_min_confidence=0.5, + inference_nms_threshold=0.5, inference_min_db=-200, inference_model_server="", inference_model_name="", @@ -207,10 +210,10 @@ def __init__( self.inference_blocks = [ blocks.stream_to_vector(gr.sizeof_float * nfft, 1), self.image_inference_block, - blocks.file_sink( - gr.sizeof_char, - str(Path(inference_output_dir, "predictions.txt")), - False, + yolo_bbox( + str(Path(inference_output_dir, "predictions")), + inference_min_confidence, + inference_nms_threshold, ), ] diff --git a/gamutrf/grterminal_sink.py b/gamutrf/grterminal_sink.py new file mode 100644 index 00000000..13ed99f8 --- /dev/null +++ b/gamutrf/grterminal_sink.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import sys +import numpy as np + +try: + from gnuradio import gr # pytype: disable=import-error +except ModuleNotFoundError as err: # pragma: no cover + print( + "Run from outside a supported environment, please run via Docker (https://github.com/IQTLabs/gamutRF#readme): %s" + % err + ) + sys.exit(1) + + +class terminal_sink(gr.sync_block): + # Prints output layer of classifier to terminal for troubleshooting + def __init__(self, input_vlen, batch_size): + self.input_vlen = input_vlen + self.batch_size = batch_size + gr.sync_block.__init__( + self, + name="terminal_sink", + in_sig=[(np.float32, self.input_vlen)], + out_sig=None, + ) + self.batch_ctr = 0 + + def work(self, input_items, output_items): + in0 = input_items[0] + _batch = in0.reshape(self.batch_size, -1) + self.batch_ctr += 1 + return len(input_items[0]) diff --git a/gamutrf/gryolo.py b/gamutrf/gryolo.py index 86079ed4..3fbfcb6a 100644 --- a/gamutrf/gryolo.py +++ b/gamutrf/gryolo.py @@ -1,10 +1,11 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import cv2 -import numpy as np -import pmt +import json +import os import sys from pathlib import Path +import cv2 +import numpy as np try: from gnuradio import gr # pytype: disable=import-error @@ -16,149 +17,102 @@ sys.exit(1) -class terminal_sink(gr.sync_block): - # Prints output layer of classifier to terminal for troubleshooting - def __init__(self, input_vlen, batch_size): - self.input_vlen = input_vlen - self.batch_size = batch_size - gr.sync_block.__init__( - self, - name="terminal_sink", - in_sig=[(np.float32, self.input_vlen)], - out_sig=None, - ) - self.batch_ctr = 0 +DELIM = "\n\n" - def work(self, input_items, output_items): - in0 = input_items[0] - batch = in0.reshape(self.batch_size, -1) - self.batch_ctr += 1 - return len(input_items[0]) - -# TODO: refactor to consume predictions feed, and retrieve png file from local disk -# (provided by image_inference block) class yolo_bbox(gr.sync_block): def __init__( self, - image_shape, - prediction_shape, - batch_size, - sample_rate, output_dir, confidence_threshold=0.5, nms_threshold=0.5, ): - self.image_shape = image_shape - self.image_vlen = np.prod(image_shape) - self.prediction_shape = prediction_shape - self.prediction_vlen = np.prod(prediction_shape) - self.batch_size = batch_size - self.sample_rate = sample_rate self.output_dir = output_dir self.confidence_threshold = confidence_threshold self.nms_threshold = nms_threshold + self.yaml_buffer = "" gr.sync_block.__init__( self, name="yolo_bbox", - in_sig=[(np.float32, self.image_vlen), (np.float32, self.prediction_vlen)], + in_sig=[np.ubyte], out_sig=None, ) - def draw_bounding_box(self, img, class_id, confidence, x, y, x_plus_w, y_plus_h): - # label = f'{CLASSES[class_id]} ({confidence:.2f})' - # label = f'{class_id} ({confidence:.2f})' - label = f"{class_id}" - color = (255, 255, 255) # self.colors[class_id] + def work(self, input_items, output_items): + n = 0 + for input_item in input_items: + raw_input_item = input_item.tobytes().decode("utf8") + n += len(raw_input_item) + self.yaml_buffer += raw_input_item + while True: + delim_pos = self.yaml_buffer.find(DELIM) + if delim_pos == -1: + break + raw_item = self.yaml_buffer[:delim_pos] + item = json.loads(raw_item) + self.yaml_buffer = self.yaml_buffer[delim_pos + len(DELIM) :] + self.process_item(item) + return n + + def draw_bounding_box(self, img, name, confidence, x, y, x_plus_w, y_plus_h): + label = f"{name}" + color = (255, 255, 255) cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2) cv2.putText( img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2 ) - def work(self, input_items, output_items): - rx_times = [ - sum(pmt.to_python(rx_time_pmt.value)) - for rx_time_pmt in self.get_tags_in_window( - 0, 0, len(input_items[0]), pmt.to_pmt("rx_time") - ) - ] - rx_freqs = [ - pmt.to_python(rx_freq_pmt.value) - for rx_freq_pmt in self.get_tags_in_window( - 0, 0, len(input_items[0]), pmt.to_pmt("rx_freq") - ) - ] - - image = input_items[0][0] - image = image.reshape(self.image_shape) - image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) - original_image = image - [height, width, _] = original_image.shape - length = max((height, width)) - image = np.zeros((length, length, 3), np.uint8) - image[0:height, 0:width] = original_image - scale = length / 640 - - prediction = input_items[1][0] - prediction = prediction.reshape(self.prediction_shape) - prediction = np.array([cv2.transpose(prediction[0])]) - rows = prediction.shape[1] + def process_item(self, item): + predictions = item.get("predictions", None) + if not predictions: + return boxes = [] scores = [] - class_ids = [] + detections = [] - for i in range(rows): - classes_scores = prediction[0][i][4:] - (minScore, maxScore, minClassLoc, (x, maxClassIndex)) = cv2.minMaxLoc( - classes_scores - ) - if maxScore >= self.confidence_threshold: + for name, prediction_data in predictions.items(): + for prediction in prediction_data: + conf = prediction["conf"] + if conf < self.confidence_threshold: + continue + xywh = prediction["xywh"] box = [ - prediction[0][i][0] - (0.5 * prediction[0][i][2]), - prediction[0][i][1] - (0.5 * prediction[0][i][3]), - prediction[0][i][2], - prediction[0][i][3], + xywh[0] - (0.5 * xywh[2]), + xywh[1] - (0.5 * xywh[3]), + xywh[2], + xywh[3], ] + detections.append({"box": box, "score": conf, "name": name}) boxes.append(box) - scores.append(maxScore) - class_ids.append(maxClassIndex) + scores.append(conf) + + if not detections: + return + original_image = cv2.imread(item["image_path"]) result_boxes = cv2.dnn.NMSBoxes( boxes, scores, self.confidence_threshold, self.nms_threshold, 0.5, 200 ) - detections = [] - for i in range(len(result_boxes)): - index = result_boxes[i] - box = boxes[index] - detection = { - "class_id": class_ids[index], - #'class_name': CLASSES[class_ids[index]], - "confidence": scores[index], - "box": box, - "scale": scale, - } - detections.append(detection) + # TODO: output to ZMQ + for detection in detections: self.draw_bounding_box( original_image, - class_ids[index], - scores[index], - round(box[0] * scale), - round(box[1] * scale), - round((box[0] + box[2]) * scale), - round((box[1] + box[3]) * scale), + detection["name"], + detection["score"], + round(box[0]), + round(box[1]), + round(box[0] + box[2]), + round(box[1] + box[3]), ) - Path(self.output_dir, "predictions").mkdir(parents=True, exist_ok=True) + Path(self.output_dir).mkdir(parents=True, exist_ok=True) filename = str( Path( self.output_dir, - "predictions", - f"prediction_{rx_times[-1]:.3f}_{rx_freqs[-1]:.0f}Hz_{self.sample_rate:.0f}sps.png", + "_".join(["prediction", os.path.basename(item["image_path"])]), ) ) cv2.imwrite(filename, original_image) - - return 1 # len(input_items[0]) diff --git a/gamutrf/scan.py b/gamutrf/scan.py index ba247e4b..351484f3 100644 --- a/gamutrf/scan.py +++ b/gamutrf/scan.py @@ -216,6 +216,20 @@ def argument_parser(): action="store_true", help="Use dc_block_cc long form", ) + parser.add_argument( + "--inference_min_confidence", + dest="inference_min_confidence", + type=float, + default=0.5, + help="minimum confidence score to plot", + ) + parser.add_argument( + "--inference_nms_confidence", + dest="inference_nms_threshold", + type=float, + default=0.5, + help="NMS threshold", + ) parser.add_argument( "--inference_min_db", dest="inference_min_db",