Skip to content

Commit

Permalink
get_inference.
Browse files Browse the repository at this point in the history
  • Loading branch information
anarkiwi committed Sep 17, 2024
1 parent c6d9db0 commit 54ccab8
Showing 1 changed file with 107 additions and 64 deletions.
171 changes: 107 additions & 64 deletions gamutrf/grscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,12 @@ def __init__(

if description:
description = description.strip('"')
stare = freq_end == 0

##################################################
# Parameters
##################################################
self.wavelearner = wavelearner
self.iqtlabs = iqtlabs
self.tag_now = tag_now

##################################################
# Blocks
Expand All @@ -127,7 +125,6 @@ def __init__(
logging.info(f"gamutrf {pbr_version} with gr-iqtlabs {griqtlabs_path}")

tune_step_fft, tune_step_hz, peak_fft_range = self.calc_rates(
stare,
freq_start,
freq_end,
sweep_sec,
Expand All @@ -154,6 +151,7 @@ def __init__(
fft_processor_affinity,
low_power_hold_down,
slew_rx_time,
tag_now,
)
self.fft_blocks = (
self.get_dc_blocks(
Expand Down Expand Up @@ -190,7 +188,7 @@ def __init__(
description=description,
rotate_secs=rotate_secs,
pre_fft=pretune,
tag_now=self.tag_now,
tag_now=tag_now,
low_power_hold_down=(not pretune and low_power_hold_down),
slew_rx_time=slew_rx_time,
peak_fft_range=peak_fft_range,
Expand Down Expand Up @@ -225,7 +223,7 @@ def __init__(
"",
)
self.connect((self.retune_pre_fft, 0), (iq_zmq_block, 0))
self.last_db_block = self.fft_blocks[-1]
last_db_block = self.fft_blocks[-1]
self.samples_blocks = []
self.write_samples_block = None
if write_samples:
Expand Down Expand Up @@ -264,9 +262,86 @@ def __init__(
self.pduzmq_block = pduzmq(fft_zmq_block_addr)
logging.info("serving FFT on %s", fft_zmq_block_addr)

self.inference_blocks = []
self.image_inference_block = None
self.iq_inference_block = None
self.image_inference_block, self.iq_inference_block, self.inference_blocks = (
self.get_inference(
colormap,
inference_batch,
inference_min_confidence,
inference_min_db,
inference_model_name,
inference_model_server,
inference_output_dir,
inference_text_color,
iq_inference_background,
iq_inference_model_name,
iq_inference_model_server,
iq_power_inference,
n_image,
n_inference,
nfft,
rotate_secs,
samp_rate,
tune_step_fft,
)
)

self.inference_output_block = self.connect_inference(
compass,
external_gps_server,
external_gps_server_port,
fft_batch_size,
gps_server,
self.image_inference_block,
inference_addr,
self.inference_blocks,
inference_output_dir,
inference_port,
self.iq_inference_block,
iq_inference_squelch_alpha,
iq_inference_squelch_db,
last_db_block,
mqtt_server,
nfft,
retune_fft,
self.retune_pre_fft,
stare,
use_external_gps,
use_external_heading,
self.write_samples_block,
)

if pretune:
self.msg_connect((self.retune_pre_fft, "tune"), (self.sources[0], cmd_port))
self.msg_connect((self.retune_pre_fft, "tune"), (retune_fft, "cmd"))
else:
self.msg_connect((retune_fft, "tune"), (self.sources[0], cmd_port))
self.msg_connect((retune_fft, "json"), (self.pduzmq_block, "json"))
self.connect_blocks(self.sources[0], self.sources[1:])
self.connect_blocks(self.sources[-1], self.fft_blocks)
self.connect_blocks(self.retune_pre_fft, self.samples_blocks)

def get_inference(
self,
colormap,
inference_batch,
inference_min_confidence,
inference_min_db,
inference_model_name,
inference_model_server,
inference_output_dir,
inference_text_color,
iq_inference_background,
iq_inference_model_name,
iq_inference_model_server,
iq_power_inference,
n_image,
n_inference,
nfft,
rotate_secs,
samp_rate,
tune_step_fft,
):
inference_blocks = []

if inference_output_dir:
Path(inference_output_dir).mkdir(parents=True, exist_ok=True)
Expand All @@ -277,7 +352,7 @@ def __init__(
inference_text_color = ",".join(
[str(c) for c in [wc.blue, wc.green, wc.red]]
)
self.image_inference_block = self.iqtlabs.image_inference(
image_inference_block = self.iqtlabs.image_inference(
tag="rx_freq",
vlen=nfft,
x=640,
Expand All @@ -301,10 +376,9 @@ def __init__(
samp_rate=int(samp_rate),
text_color=inference_text_color,
)
self.inference_blocks.append(self.image_inference_block)

inference_blocks.append(image_inference_block)
if iq_inference_model_server and iq_inference_model_name:
self.iq_inference_block = iqtlabs.iq_inference(
iq_inference_block = self.iqtlabs.iq_inference(
tag="rx_freq",
vlen=nfft,
n_vlen=1,
Expand All @@ -319,67 +393,33 @@ def __init__(
background=iq_inference_background,
batch=inference_batch,
)
self.inference_blocks.append(self.iq_inference_block)

self.inference_output_block = self.connect_inference(
self.inference_blocks,
self.iq_inference_block,
self.image_inference_block,
inference_addr,
inference_port,
mqtt_server,
compass,
gps_server,
use_external_gps,
use_external_heading,
external_gps_server,
external_gps_server_port,
iq_inference_squelch_db,
iq_inference_squelch_alpha,
fft_batch_size,
nfft,
self.retune_pre_fft,
retune_fft,
stare,
self.last_db_block,
self.write_samples_block,
inference_output_dir,
)

if pretune:
self.msg_connect((self.retune_pre_fft, "tune"), (self.sources[0], cmd_port))
self.msg_connect((self.retune_pre_fft, "tune"), (retune_fft, "cmd"))
else:
self.msg_connect((retune_fft, "tune"), (self.sources[0], cmd_port))
self.msg_connect((retune_fft, "json"), (self.pduzmq_block, "json"))
self.connect_blocks(self.sources[0], self.sources[1:])
self.connect_blocks(self.sources[-1], self.fft_blocks)
self.connect_blocks(self.retune_pre_fft, self.samples_blocks)
inference_blocks.append(iq_inference_block)
return (image_inference_block, iq_inference_block, inference_blocks)

def connect_inference(
self,
inference_blocks,
iq_inference_block,
image_inference_block,
inference_addr,
inference_port,
mqtt_server,
compass,
gps_server,
use_external_gps,
use_external_heading,
external_gps_server,
external_gps_server_port,
iq_inference_squelch_db,
iq_inference_squelch_alpha,
fft_batch_size,
gps_server,
image_inference_block,
inference_addr,
inference_blocks,
inference_output_dir,
inference_port,
iq_inference_block,
iq_inference_squelch_alpha,
iq_inference_squelch_db,
last_db_block,
mqtt_server,
nfft,
retune_pre_fft,
retune_fft,
retune_pre_fft,
stare,
last_db_block,
use_external_gps,
use_external_heading,
write_samples_block,
inference_output_dir,
):
if not inference_blocks:
return None
Expand Down Expand Up @@ -433,7 +473,6 @@ def connect_inference(

def calc_rates(
self,
stare,
freq_start,
freq_end,
sweep_sec,
Expand All @@ -444,6 +483,7 @@ def calc_rates(
tune_step_fft,
peak_fft_range,
):
stare = freq_end == 0
tune_step_hz = int(samp_rate * tuneoverlap)
if stare:
freq_range = samp_rate
Expand Down Expand Up @@ -526,6 +566,7 @@ def get_pretune_block(
pretune,
low_power_hold_down,
slew_rx_time,
tag_now,
):
# if pretuning, the pretune block will also do the batching.
if pretune:
Expand All @@ -541,7 +582,7 @@ def get_pretune_block(
tune_step_fft=tune_step_fft,
skip_tune_step_fft=skip_tune_step,
tuning_ranges=tuning_ranges,
tag_now=self.tag_now,
tag_now=tag_now,
low_power_hold_down=low_power_hold_down,
slew_rx_time=slew_rx_time,
)
Expand Down Expand Up @@ -652,6 +693,7 @@ def get_fft_blocks(
fft_processor_affinity,
low_power_hold_down,
slew_rx_time,
tag_now,
):
fft_batch_size, fft_blocks = self.get_offload_fft_blocks(
vkfft,
Expand All @@ -673,6 +715,7 @@ def get_fft_blocks(
pretune,
low_power_hold_down,
slew_rx_time,
tag_now,
)
return (fft_batch_size, retune_pre_fft, [retune_pre_fft] + fft_blocks)

Expand Down

0 comments on commit 54ccab8

Please sign in to comment.