Skip to content

Commit

Permalink
add VkFFT vs. software FFT test.
Browse files Browse the repository at this point in the history
  • Loading branch information
anarkiwi committed Sep 25, 2023
1 parent 2bcc3ab commit 045c22e
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 87 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
fetch-depth: 0
- name: build_test
run: |
bin/apt_get.sh && bin/build_test.sh && bin/test_grc310.sh
bin/apt_get.sh && TEST_VKFFT=1 bin/build_test.sh && bin/test_grc310.sh
test-2004-gnuradio39:
runs-on: ubuntu-20.04
steps:
Expand Down
6 changes: 3 additions & 3 deletions lib/libvkfft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ VkFFTResult _transferDataToCPU(char *cpu_arr) {
if (res != VK_SUCCESS)
return VKFFT_ERROR_MALLOC_FAILED;
if (_shift) {
const size_t halfFftBufferSize = fftBufferSize / 2;
for (int i = 0; i < vkConfiguration.numberBatches; ++i) {
memcpy(cpu_arr + fftBufferSize / 2, data, fftBufferSize / 2);
memcpy(cpu_arr, data + fftBufferSize / 2, fftBufferSize / 2);
cpu_arr += fftBufferSize;
memcpy(cpu_arr + halfFftBufferSize, data, halfFftBufferSize);
memcpy(cpu_arr, data + halfFftBufferSize, halfFftBufferSize);
cpu_arr += fftBufferSize;
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion lib/vkfft_short_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ int vkfft_short_impl::work(int noutput_items,

for (int i = 0; i < noutput_items; ++i) {
const int buffer_index = i * vlen_ * 2;
_converter->conv(&in[i], &buffer[0], vlen_);
_converter->conv(&in[buffer_index], &buffer[0], vlen_);
vkfft_offload((char *)&buffer[0], (char *)&out[buffer_index]);
}

Expand Down
157 changes: 85 additions & 72 deletions python/iqtlabs/qa_retune_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@
from gnuradio import gr, gr_unittest
from gnuradio import blocks
from gnuradio import fft
from gnuradio.fft import window

try:
from gnuradio.iqtlabs import (
Expand Down Expand Up @@ -280,14 +279,15 @@ def retune_fft(self, fft_roll):
fft_write_count = 2
bucket_range = 1
fft_min = -1e9
fft_batch_size = 1

with tempfile.TemporaryDirectory() as tmpdir:
test_file = os.path.join(tmpdir, "samples.csv")
iqtlabs_tuneable_test_source_0 = tuneable_test_source(freq_end)

iqtlabs_retune_pre_fft_0 = retune_pre_fft(
points,
1, # fft_batch_size
fft_batch_size,
"rx_freq",
int(freq_start),
int(freq_end),
Expand Down Expand Up @@ -318,8 +318,19 @@ def retune_fft(self, fft_roll):
True,
)
pdu_decoder_0 = pdu_decoder()
fft_vxx_0 = fft.fft_vcc(
points, True, window.blackmanharris(points), True, 1
fft_vxx_0 = fft.fft_vcc(points, True, [], fft_roll, 1)

try:
if os.getenv("TEST_VKFFT", 0):
from gnuradio.iqtlabs import vkfft

fft_vxx_0 = vkfft(fft_batch_size * points, points, fft_roll)
print("using VkFFT")
except ImportError:
print("using software FFT")
window = blocks.multiply_const_vff(
[val for val in fft.window.blackmanharris(points) for _ in range(2)]
* fft_batch_size
)
blocks_throttle_0 = blocks.throttle(
gr.sizeof_gr_complex * 1, samp_rate, True
Expand All @@ -329,7 +340,6 @@ def retune_fft(self, fft_roll):
blocks_complex_to_mag_0 = blocks.complex_to_mag(points)
blocks_nlog10_ff_0 = blocks.nlog10_ff(20, points, 0)
vr1 = vector_roll(points)
vr2 = vector_roll(points)

self.tb.msg_connect(
(iqtlabs_retune_pre_fft_0, "tune"),
Expand All @@ -342,14 +352,14 @@ def retune_fft(self, fft_roll):
self.tb.connect((blocks_complex_to_mag_0, 0), (blocks_nlog10_ff_0, 0))
self.tb.connect((blocks_nlog10_ff_0, 0), (iqtlabs_retune_fft_0, 0))
if fft_roll:
self.tb.connect((fft_vxx_0, 0), (blocks_complex_to_mag_0, 0))
else:
# double roll, is a no-op
self.tb.connect((fft_vxx_0, 0), (vr1, 0))
self.tb.connect((vr1, 0), (vr2, 0))
self.tb.connect((vr2, 0), (blocks_complex_to_mag_0, 0))
else:
self.tb.connect((fft_vxx_0, 0), (blocks_complex_to_mag_0, 0))
self.tb.connect((vr1, 0), (blocks_complex_to_mag_0, 0))
self.tb.connect((iqtlabs_retune_fft_0, 0), (blocks_file_sink_0, 0))
self.tb.connect((iqtlabs_retune_pre_fft_0, 0), (fft_vxx_0, 0))
self.tb.connect((iqtlabs_retune_pre_fft_0, 0), (window, 0))
self.tb.connect((window, 0), (fft_vxx_0, 0))
self.tb.connect((blocks_throttle_0, 0), (iqtlabs_retune_pre_fft_0, 0))
self.tb.connect((iqtlabs_tuneable_test_source_0, 0), (blocks_throttle_0, 0))

Expand All @@ -369,67 +379,70 @@ def retune_fft(self, fft_roll):
self.assertTrue(os.path.exists(test_file))

with open(test_file, encoding="utf8") as f:
linebuffer = ""
last_data = time.time()
last_ts = 0
last_buckets = None
last_tuning_range = None
file_poll_timeout = 0.001
while tuning_range_changes < 5:
self.assertLess(time.time() - last_data, 5)
line = f.readline()
linebuffer = linebuffer + line
if not linebuffer.endswith("\n"):
time.sleep(file_poll_timeout)
continue
last_data = time.time()
line = linebuffer.strip()
try:
linebuffer = ""
record = json.loads(line)
ts = record["ts"]
self.assertGreater(ts, last_ts)
last_ts = ts
self.assertGreaterEqual(ts, record["sweep_start"])
config = record["config"]
self.assertEqual("a text description", config["description"]),
tuning_range_freq_start = config["tuning_range_freq_start"]
tuning_range_freq_end = config["tuning_range_freq_end"]
tuning_range = int(config["tuning_range"])
if tuning_range != last_tuning_range:
tuning_range_changes += 1
print("tuning_range_changes:", tuning_range_changes)
last_tuning_range = tuning_range
self.assertTrue(
(
tuning_range_freq_start == freq_start
and tuning_range_freq_end == freq_mid
)
or (
tuning_range_freq_start == freq_mid + samp_rate
and tuning_range_freq_end == freq_end
last_data = time.time()
last_ts = 0
last_buckets = None
last_tuning_range = None
file_poll_timeout = 0.001
while tuning_range_changes < 5:
self.assertLess(time.time() - last_data, 5)
line = f.readline()
linebuffer = linebuffer + line
if not linebuffer.endswith("\n"):
time.sleep(file_poll_timeout)
continue
last_data = time.time()
line = linebuffer.strip()
linebuffer = ""
record = json.loads(line)
ts = round(record["ts"])
self.assertGreaterEqual(ts, last_ts)
last_ts = ts
config = record["config"]
self.assertEqual("a text description", config["description"]),
tuning_range_freq_start = config["tuning_range_freq_start"]
tuning_range_freq_end = config["tuning_range_freq_end"]
tuning_range = int(config["tuning_range"])
if tuning_range != last_tuning_range:
tuning_range_changes += 1
print("tuning_range_changes:", tuning_range_changes)
last_tuning_range = tuning_range
self.assertTrue(
(
tuning_range_freq_start == freq_start
and tuning_range_freq_end == freq_mid
)
or (
tuning_range_freq_start == freq_mid + samp_rate
and tuning_range_freq_end == freq_end
)
)
)
self.assertEqual(config["freq_start"], freq_start)
self.assertEqual(config["freq_end"], freq_end)
self.assertEqual(config["sample_rate"], samp_rate)
self.assertEqual(config["nfft"], points)
buckets = record["buckets"]
self.assertTrue(buckets, (last_buckets, buckets))
bucket_counts[len(buckets)] += 1
fs = [int(f) for f in buckets.keys()]
self.assertGreaterEqual(min(fs), tuning_range_freq_start)
self.assertLessEqual(max(fs), tuning_range_freq_end)
new_records = [
{
"ts": ts,
"f": int(freq),
"v": float(value),
"t": int(tuning_range),
}
for freq, value in buckets.items()
]
records.extend(new_records)
last_buckets = buckets
self.assertEqual(config["freq_start"], freq_start)
self.assertEqual(config["freq_end"], freq_end)
self.assertEqual(config["sample_rate"], samp_rate)
self.assertEqual(config["nfft"], points)
buckets = record["buckets"]
self.assertTrue(buckets, (last_buckets, buckets))
bucket_counts[len(buckets)] += 1
fs = [int(f) for f in buckets.keys()]
self.assertGreaterEqual(min(fs), tuning_range_freq_start)
self.assertLessEqual(max(fs), tuning_range_freq_end)
new_records = [
{
"ts": ts,
"f": int(freq),
"v": float(value),
"t": int(tuning_range),
}
for freq, value in buckets.items()
]
records.extend(new_records)
last_buckets = buckets
except Exception:
self.tb.stop()
raise

self.tb.stop()
self.tb.wait()
Expand All @@ -446,7 +459,7 @@ def retune_fft(self, fft_roll):

for _, df in all_df.groupby("t"):
# must have plausible unscaled dB value
self.assertTrue(fft_min <= df["v"].min() <= 1, df["v"].min())
self.assertTrue(fft_min <= df["v"].min() <= 1, (fft_min, df["v"].min()))
self.assertTrue(50 <= df["v"].max() <= 61, df["v"].max())
df["m"] = df.groupby("f")["v"].apply(lambda x: x.max() - x.min())
non_unique_v = df[df["m"] > 2]
Expand Down Expand Up @@ -499,7 +512,7 @@ def retune_fft(self, fft_roll):

class qa_retune_fft_no_roll(gr_unittest.TestCase, qa_retune_fft_base):
def setUp(self):
self.tb = gr.top_block(catch_exceptions=False)
self.tb = gr.top_block(catch_exceptions=True)

def tearDown(self):
self.tb = None
Expand All @@ -510,7 +523,7 @@ def test_retune_fft_no_roll(self):

class qa_retune_fft_roll(gr_unittest.TestCase, qa_retune_fft_base):
def setUp(self):
self.tb = gr.top_block(catch_exceptions=False)
self.tb = gr.top_block(catch_exceptions=True)

def tearDown(self):
self.tb = None
Expand Down
55 changes: 45 additions & 10 deletions python/iqtlabs/qa_vkfft.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,22 +203,57 @@
# limitations under the License.
#

from gnuradio import gr, gr_unittest
import numpy as np
from gnuradio import blocks, fft, iqtlabs, gr, gr_unittest

# from gnuradio import blocks
from gnuradio.iqtlabs import vkfft

def run_fft(fft_block, points, fft_roll, input_items):
src1 = blocks.vector_source_c(input_items, vlen=len(input_items))
dst1 = blocks.vector_sink_c(vlen=len(input_items))
tb = gr.top_block()
tb.connect(src1, fft_block)
tb.connect(fft_block, dst1)
tb.run()
data = dst1.data()
tb.stop()
tb.wait()
del tb
return data


class qa_vkfft(gr_unittest.TestCase):
def setUp(self):
self.tb = gr.top_block()
def test_instance(self):
for fft_roll in (True, False):
fft_batch_size = 4
points = 8

def tearDown(self):
self.tb = None
batch_input_items = []
for i in range(1, fft_batch_size + 1):
batch_input_items.extend(
1j * np.arange(points)
) # pytype: disable=wrong-arg-types

def test_instance(self):
# TODO: find workaround llvmpipe simulated gpu crashes under CI testing
instance = vkfft(1024, 1, True)
sw_data = []
for i in range(1, fft_batch_size + 1):
batch = (i - 1) * points
sw_data.extend(
run_fft(
fft.fft_vcc(points, True, [], fft_roll, 1),
points,
fft_roll,
batch_input_items[batch : batch + points],
)
)

vkfft_data = run_fft(
iqtlabs.vkfft(fft_batch_size * points, points, fft_roll),
points,
fft_roll,
batch_input_items,
)

if os.getenv("TEST_VKFFT", 0):
self.assertEqual(vkfft_data, sw_data)


if __name__ == "__main__":
Expand Down

0 comments on commit 045c22e

Please sign in to comment.