Skip to content

Commit

Permalink
Merge pull request #594 from anarkiwi/tests2
Browse files Browse the repository at this point in the history
smoke tests for wavelearner.
  • Loading branch information
anarkiwi authored Mar 22, 2023
2 parents f4e934b + 5adc66f commit 8f6e990
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 41 deletions.
78 changes: 39 additions & 39 deletions gamutrf/grscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,38 +20,6 @@


class grscan(gr.top_block):
def connect_blocks(self, first_block, other_blocks):
last_block = first_block
for block in other_blocks:
self.connect((last_block, 0), (block, 0))
last_block = block

@staticmethod
def get_fft_blocks(fft_size, sdr):
if sdr == "SoapyAIRT":
import wavelearner # pytype: disable=import-error

fft_batch_size = 256
return (
[
blocks.stream_to_vector(
gr.sizeof_gr_complex, fft_batch_size * fft_size
),
wavelearner.fft(int(fft_batch_size * fft_size), (fft_size), True),
blocks.vector_to_stream(
gr.sizeof_gr_complex * fft_size, fft_batch_size
),
],
True,
)
return (
[
blocks.stream_to_vector(gr.sizeof_gr_complex, fft_size),
fft.fft_vcc(fft_size, True, window.blackmanharris(fft_size), True, 1),
],
False,
)

def __init__(
self,
freq_end=1e9,
Expand All @@ -73,6 +41,7 @@ def __init__(
inference_output_dir="",
inference_input_len=2048,
iqtlabs=None,
wavelearner=None,
):
gr.top_block.__init__(self, "scan", catch_exceptions=True)

Expand All @@ -83,6 +52,7 @@ def __init__(
self.freq_start = freq_start
self.sweep_sec = sweep_sec
self.fft_size = fft_size
self.wavelearner = wavelearner

##################################################
# Blocks
Expand All @@ -99,9 +69,6 @@ def __init__(
sdrargs=sdrargs,
)

if not iqtlabs:
return

fft_blocks, fft_roll = self.get_fft_blocks(fft_size, sdr)
self.fft_blocks = fft_blocks + [
blocks.complex_to_mag(fft_size),
Expand Down Expand Up @@ -159,16 +126,18 @@ def __init__(
self.fft_blocks.append((zeromq.pub_sink(1, 1, zmq_addr, 100, False, 65536, "")))

self.inference_blocks = []
if sdr == "SoapyAIRT" and inference_plan_file and inference_output_dir:
import wavelearner # pytype: disable=import-error

if inference_plan_file and inference_output_dir:
if not self.wavelearner:
raise ValueError(
"trying to use inference but wavelearner not available"
)
inference_batch_size = 128
output_len = 1
self.inference_blocks = [
blocks.stream_to_vector(
gr.sizeof_gr_complex * 1, inference_batch_size * inference_input_len
),
wavelearner.inference(
self.wavelearner.inference(
inference_plan_file,
True,
inference_input_len * inference_batch_size,
Expand All @@ -194,3 +163,34 @@ def __init__(
self.inference_blocks,
):
self.connect_blocks(self.source_0, pipeline_blocks)

def connect_blocks(self, first_block, other_blocks):
last_block = first_block
for block in other_blocks:
self.connect((last_block, 0), (block, 0))
last_block = block

def get_fft_blocks(self, fft_size, sdr):
if self.wavelearner:
fft_batch_size = 256
return (
[
blocks.stream_to_vector(
gr.sizeof_gr_complex, fft_batch_size * fft_size
),
self.wavelearner.fft(
int(fft_batch_size * fft_size), (fft_size), True
),
blocks.vector_to_stream(
gr.sizeof_gr_complex * fft_size, fft_batch_size
),
],
True,
)
return (
[
blocks.stream_to_vector(gr.sizeof_gr_complex, fft_size),
fft.fft_vcc(fft_size, True, window.blackmanharris(fft_size), True, 1),
],
False,
)
4 changes: 4 additions & 0 deletions gamutrf/grsource.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from urllib.parse import urlparse

try:
import pmt
from gnuradio import blocks
from gnuradio import soapy
from gnuradio import uhd
Expand Down Expand Up @@ -51,7 +52,10 @@ def get_source(
grblock.source_0 = blocks.throttle(sizeof_gr_complex, samp_rate, True)
grblock.connect((grblock.recording_source_0, 0), (grblock.source_0, 0))
# TODO: enable setting frequency change tags on the stream, so can test scanner.
# grblock.source_0.set_msg_handler(pmt.intern(grblock.cmd_port), grblock.freq_setter)
grblock.freq_setter = lambda _x, _y: None
grblock.cmd_port = "command"
grblock.source_0.message_port_register_in(pmt.intern(grblock.cmd_port))
else:
raise ValueError("unsupported/missing file location")
return
Expand Down
10 changes: 10 additions & 0 deletions gamutrf/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,15 @@ def main():
print("Must provide --sample_dir when writing samples/points")
sys.exit(1)

wavelearner = None
try:
import wavelearner as wavelearner_lib # pytype: disable=import-error

wavelearner = wavelearner_lib
print("using wavelearner")
except ModuleNotFoundError:
print("wavelearner not available")

prom_vars = init_prom_vars()
prom_vars["freq_start_hz"].set(options.freq_start)
prom_vars["freq_end_hz"].set(options.freq_end)
Expand All @@ -227,6 +236,7 @@ def main():
inference_output_dir=options.inference_output_dir,
inference_input_len=options.inference_input_len,
iqtlabs=iqtlabs,
wavelearner=wavelearner,
)

def sig_handler(_sig=None, _frame=None):
Expand Down
35 changes: 33 additions & 2 deletions tests/test_grscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import time
import unittest

from gnuradio import iqtlabs
from gnuradio import fft # pytype: disable=import-error
from gnuradio.fft import window # pytype: disable=import-error

from gamutrf.grsource import get_source
from gamutrf.grscan import grscan

Expand All @@ -10,14 +14,41 @@ class FakeTb:
pass


class FakeWaveLearner:
def fft(self, batch_fft_size, _fft_size, forward):
return fft.fft_vcc(
batch_fft_size, forward, window.blackmanharris(batch_fft_size), True, 1
)


class GrscanTestCase(unittest.TestCase):
def test_get_source_smoke(self):
self.assertRaises(RuntimeError, get_source, FakeTb, "ettus", 1e3, 10)
self.assertRaises(RuntimeError, get_source, FakeTb, "bladerf", 1e3, 10)

def test_grscan_smoke(self):
start = time.time()
tb = grscan(sdr="file:/dev/zero", samp_rate=int(1.024e6))
tb = grscan(
sdr="file:/dev/zero",
samp_rate=int(1.024e6),
write_samples=1,
sample_dir="/tmp",
iqtlabs=iqtlabs,
wavelearner=None,
)
tb.start()
time.sleep(15)
tb.stop()
tb.wait()

def test_grscan_wavelearner_smoke(self):
tb = grscan(
sdr="file:/dev/zero",
samp_rate=int(1.024e6),
write_samples=1,
sample_dir="/tmp",
iqtlabs=iqtlabs,
wavelearner=FakeWaveLearner(),
)
tb.start()
time.sleep(15)
tb.stop()
Expand Down

0 comments on commit 8f6e990

Please sign in to comment.