Skip to content

Commit

Permalink
Merge pull request #126 from anarkiwi/msg
Browse files Browse the repository at this point in the history
Implement a simple TorchServe API integration for testing.
  • Loading branch information
anarkiwi authored Oct 15, 2023
2 parents d9c1e8d + 1d39900 commit e31f81e
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 41 deletions.
1 change: 1 addition & 0 deletions .github/workflows/config/test-shell.nix
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ mkShell {
gnuradio.python.pkgs.pybind11
libsndfile
opencv
python3Packages.flask
soapysdr
spdlog
uhd
Expand Down
1 change: 1 addition & 0 deletions bin/apt_get.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ sudo apt-get update && \
libvulkan-dev \
libzstd-dev \
mesa-vulkan-drivers \
python3-flask \
python3-numpy \
python3-packaging \
python3-pandas \
Expand Down
1 change: 1 addition & 0 deletions bin/apt_get_39.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ sudo apt-get -y update && \
pybind11-dev \
python3-click \
python3-click-plugins \
python3-flask \
python3-gi \
python3-gi-cairo \
python3-lxml \
Expand Down
2 changes: 1 addition & 1 deletion codecheck-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pytype==2023.10.5
black==23.9.1
pytype==2023.10.5
zstandard==0.21.0
12 changes: 9 additions & 3 deletions grc/iqtlabs_image_inference.block.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@ templates:
make: >
iqtlabs.image_inference(${tag}, ${vlen}, ${x}, ${y}, ${image_dir},
${convert_alpha}, ${norm_alpha}, ${norm_beta}, ${norm_type}, ${colormap},
${interpolation})
${interpolation}, ${model_server}, ${model_name})
cpp_templates:
includes: ['#include <gnuradio/iqtlabs/image_inference.h>']
declarations: 'gr::iqtlabs::image_inference::sptr ${id};'
make: >
this->${id} = gr::iqtlabs::image_inference::make(${tag}, ${vlen},
${x}, ${y}, ${image_dir}, ${convert_alpha}, ${norm_alpha}, ${norm_beta},
${norm_type}, ${colormap}, ${interpolation});
${norm_type}, ${colormap}, ${interpolation}, ${model_server},
${model_name});
link: ['libgnuradio-iqtlabs.so']


Expand Down Expand Up @@ -56,10 +57,15 @@ parameters:
default: 99 # cv::flip(), or 99 for no flip
- id: min_peak_points
dtype: float
- id: model_name
dtype: str
- id: model_server
dtype: str

asserts:
- ${ tag != "" }
- ${ vlen > 0 }
- ${ !model_server || (model_server && model_name) }

inputs:
- label: input
Expand All @@ -71,6 +77,6 @@ outputs:
- label: input
domain: stream
dtype: byte
vlen: ${ x * y }
vlen: 1

file_format: 1
3 changes: 2 additions & 1 deletion include/gnuradio/iqtlabs/image_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,8 @@ class IQTLABS_API image_inference : virtual public gr::block {
const std::string &image_dir, double convert_alpha,
double norm_alpha, double norm_beta, int norm_type,
int colormap, int interpolation, int flip,
double min_peak_points);
double min_peak_points, const std::string &model_server,
const std::string &model_name);
};

} // namespace iqtlabs
Expand Down
108 changes: 76 additions & 32 deletions lib/image_inference_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@
*/

#include "image_inference_impl.h"
#include <boost/algorithm/string.hpp>
#include <boost/lexical_cast.hpp>
#include <fstream>
#include <gnuradio/io_signature.h>
#include <ios>
Expand All @@ -216,27 +218,39 @@ image_inference::make(const std::string &tag, int vlen, int x, int y,
const std::string &image_dir, double convert_alpha,
double norm_alpha, double norm_beta, int norm_type,
int colormap, int interpolation, int flip,
double min_peak_points) {
double min_peak_points, const std::string &model_server,
const std::string &model_name) {
return gnuradio::make_block_sptr<image_inference_impl>(
tag, vlen, x, y, image_dir, convert_alpha, norm_alpha, norm_beta,
norm_type, colormap, interpolation, flip, min_peak_points);
norm_type, colormap, interpolation, flip, min_peak_points, model_server,
model_name);
}

image_inference_impl::image_inference_impl(
const std::string &tag, int vlen, int x, int y,
const std::string &image_dir, double convert_alpha, double norm_alpha,
double norm_beta, int norm_type, int colormap, int interpolation, int flip,
double min_peak_points)
double min_peak_points, const std::string &model_server,
const std::string &model_name)
: gr::block("image_inference",
gr::io_signature::make(1 /* min inputs */, 1 /* max inputs */,
vlen * sizeof(input_type)),
gr::io_signature::make(1 /* min outputs */, 1 /*max outputs */,
x * y * sizeof(output_type) * 3)),
sizeof(output_type))),
tag_(pmt::intern(tag)), x_(x), y_(y), vlen_(vlen), last_rx_freq_(0),
last_rx_time_(0), image_dir_(image_dir), convert_alpha_(convert_alpha),
norm_alpha_(norm_alpha), norm_beta_(norm_beta), norm_type_(norm_type),
colormap_(colormap), interpolation_(interpolation), flip_(flip),
min_peak_points_(min_peak_points) {
min_peak_points_(min_peak_points), model_name_(model_name) {
// TODO: IPv6 IP addresses
std::vector<std::string> model_server_parts_;
boost::split(model_server_parts_, model_server, boost::is_any_of(":"),
boost::token_compress_on);
if (model_server_parts_.size() == 2) {
host_ = model_server_parts_[0];
port_ = model_server_parts_[1];
}
image_buffer_.reset(new std::vector<unsigned char>());
points_buffer_.reset(
new cv::Mat(cv::Size(vlen, 0), CV_32F, cv::Scalar::all(0)));
cmapped_buffer_.reset(
Expand Down Expand Up @@ -286,45 +300,66 @@ void image_inference_impl::create_image_() {
if (flip_ == -1 || flip_ == 0 || flip_ == 1) {
cv::flip(*output_item.buffer, *output_item.buffer, flip_);
}
cv::cvtColor(*output_item.buffer, *output_item.buffer, cv::COLOR_RGB2BGR);
output_q_.insert(output_q_.begin(), output_item);
}
points_buffer_->resize(0);
}
}

void image_inference_impl::output_image_(output_type *out) {
void image_inference_impl::output_image_() {
output_item_type output_item = output_q_.back();
void *resized_buffer_p = output_item.buffer->ptr(0);
std::stringstream str;
str << name() << unique_id();
pmt::pmt_t _id = pmt::string_to_symbol(str.str());
// TODO: add more metadata as needed for inference.
this->add_item_tag(0, nitems_written(0), RX_TIME_KEY,
make_rx_time_key_(output_item.ts), _id);
this->add_item_tag(0, nitems_written(0), RX_FREQ_KEY,
pmt::from_double(output_item.rx_freq), _id);
const size_t buffer_size =
output_item.buffer->total() * output_item.buffer->elemSize();
std::memcpy(out, resized_buffer_p, buffer_size);
// write image file
std::string image_file_base =
"image_" + host_now_str_(output_item.ts) + "_" +
std::to_string(uint64_t(x_)) + "x" + std::to_string(uint64_t(y_)) + "_" +
std::to_string(uint64_t(output_item.rx_freq)) + "Hz";
// TODO: re-enable if non-PNG image required.
// std::string image_file = image_file_base + ".bin";
// std::string dot_image_file = image_dir_ + "/." + image_file;
// std::string full_image_file = image_dir_ + "/" + image_file;
// std::ofstream image_out;
// image_out.open(dot_image_file, std::ios::binary | std::ios::out);
// image_out.write((const char *)resized_buffer_p, buffer_size);
// image_out.close();
// rename(dot_image_file.c_str(), full_image_file.c_str());
std::string image_file_png = image_file_base + ".png";
std::string image_file_png = image_file_base + IMAGE_EXT;
std::string dot_image_file_png = image_dir_ + "/." + image_file_png;
std::string full_image_file_png = image_dir_ + "/" + image_file_png;
cv::cvtColor(*output_item.buffer, *output_item.buffer, cv::COLOR_RGB2BGR);
cv::imwrite(dot_image_file_png, *output_item.buffer);
cv::imencode(IMAGE_EXT, *output_item.buffer, *image_buffer_);
std::ofstream image_out;
image_out.open(dot_image_file_png, std::ios::binary | std::ios::out);
image_out.write((const char *)image_buffer_->data(), image_buffer_->size());
image_out.close();
rename(dot_image_file_png.c_str(), full_image_file_png.c_str());
std::stringstream ss("", std::ios_base::app | std::ios_base::out);
ss << "{"
<< "\"ts\": " << host_now_str_(output_item.ts)
<< ", \"rx_freq\": " << output_item.rx_freq;
// TODO: synchronous requests for testing. Should be parallel.
if (host_.size() && port_.size()) {
boost::asio::io_context ioc;
boost::asio::ip::tcp::resolver resolver(ioc);
boost::beast::tcp_stream stream(ioc);
const std::string_view body(
reinterpret_cast<char const *>(image_buffer_->data()),
image_buffer_->size());
boost::beast::http::request<boost::beast::http::string_body> req{
boost::beast::http::verb::post, "/predictions/" + model_name_, 11};
req.set(boost::beast::http::field::host, host_);
req.set(boost::beast::http::field::user_agent, BOOST_BEAST_VERSION_STRING);
req.set(boost::beast::http::field::content_type, "image/" + IMAGE_TYPE);
req.body() = body;
req.prepare_payload();
boost::beast::flat_buffer buffer;
boost::beast::http::response<boost::beast::http::string_body> res;

try {
auto const results = resolver.resolve(host_, port_);
stream.connect(results);
boost::beast::http::write(stream, req);
boost::beast::http::read(stream, buffer, res);
ss << ", \"predictions\": " << res.body().data();
boost::beast::error_code ec;
stream.socket().shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec);
} catch (std::exception &ex) {
ss << ", \"error\": \"" << ex.what() << "\"";
}
}
ss << "}" << std::endl;
const std::string s = ss.str();
out_buf_.insert(out_buf_.end(), s.begin(), s.end());
delete_output_();
}

Expand All @@ -337,8 +372,17 @@ int image_inference_impl::general_work(int noutput_items,
size_t in_first = nitems_read(0);

if (!output_q_.empty()) {
output_image_(static_cast<output_type *>(output_items[0]));
return 1;
output_image_();
}

if (!out_buf_.empty()) {
auto out = static_cast<output_type *>(output_items[0]);
const size_t leftover = std::min(out_buf_.size(), (size_t)noutput_items);
auto from = out_buf_.begin();
auto to = from + leftover;
std::copy(from, to, out);
out_buf_.erase(from, to);
return leftover;
}

std::vector<tag_t> all_tags, rx_freq_tags;
Expand Down
15 changes: 13 additions & 2 deletions lib/image_inference_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,11 @@
#define INCLUDED_IQTLABS_IMAGE_INFERENCE_IMPL_H

#include "base_impl.h"
#include <boost/asio/connect.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/beast/core.hpp>
#include <boost/beast/http.hpp>
#include <boost/beast/version.hpp>
#include <boost/scoped_ptr.hpp>
#include <gnuradio/iqtlabs/image_inference.h>
#include <opencv2/imgcodecs.hpp>
Expand All @@ -216,6 +221,8 @@ namespace iqtlabs {

using input_type = float;
using output_type = unsigned char;
const std::string IMAGE_TYPE = "png";
const std::string IMAGE_EXT = "." + IMAGE_TYPE;

typedef struct output_item {
uint64_t rx_freq;
Expand All @@ -230,21 +237,25 @@ class image_inference_impl : public image_inference, base_impl {
double convert_alpha_, norm_alpha_, norm_beta_, last_rx_time_,
min_peak_points_;
std::vector<output_item_type> output_q_;
boost::scoped_ptr<std::vector<unsigned char>> image_buffer_;
boost::scoped_ptr<cv::Mat> points_buffer_, cmapped_buffer_;
std::string image_dir_;
pmt::pmt_t tag_;
std::deque<output_type> out_buf_;
std::string model_name_, host_, port_;

void process_items_(size_t c, const input_type *&in);
void create_image_();
void output_image_(output_type *out);
void output_image_();
void delete_output_();

public:
image_inference_impl(const std::string &tag, int vlen, int x, int y,
const std::string &image_dir, double convert_alpha,
double norm_alpha, double norm_beta, int norm_type,
int colormap, int interpolation, int flip,
double min_peak_points);
double min_peak_points, const std::string &model_server,
const std::string &model_name);
~image_inference_impl();
int general_work(int noutput_items, gr_vector_int &ninput_items,
gr_vector_const_void_star &input_items,
Expand Down
4 changes: 3 additions & 1 deletion python/iqtlabs/bindings/image_inference_python.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
/* BINDTOOL_GEN_AUTOMATIC(0) */
/* BINDTOOL_USE_PYGCCXML(0) */
/* BINDTOOL_HEADER_FILE(image_inference.h) */
/* BINDTOOL_HEADER_FILE_HASH(8b70b8d06de725d327d7d86b793e4360) */
/* BINDTOOL_HEADER_FILE_HASH(ed9a9d38e8c7f50546d907bfa44a223b) */
/***********************************************************************************/

#include <pybind11/complex.h>
Expand Down Expand Up @@ -52,6 +52,8 @@ void bind_image_inference(py::module& m)
py::arg("interpolation"),
py::arg("flip"),
py::arg("min_peak_points"),
py::arg("model_server"),
py::arg("model_name"),
D(image_inference, make))


Expand Down
35 changes: 34 additions & 1 deletion python/iqtlabs/qa_image_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,15 @@
# limitations under the License.
#

import concurrent.futures
import imghdr
import json
import glob
import os
import pmt
import time
import tempfile
from flask import Flask
from gnuradio import gr, gr_unittest
from gnuradio import analog, blocks

Expand All @@ -225,11 +229,33 @@
class qa_image_inference(gr_unittest.TestCase):
def setUp(self):
self.tb = gr.top_block()
self.pid = os.fork()

def tearDown(self):
self.tb = None
if self.pid:
os.kill(self.pid, 15)

def simulate_torchserve(self, port, model_name, result):
app = Flask(__name__)

# nosemgrep:github.workflows.config.useless-inner-function
@app.route(f"/predictions/{model_name}", methods=["POST"])
def predictions_test():
return json.dumps(result), 200

try:
app.run(host="127.0.0.1", port=11001)
except RuntimeError:
return

def test_instance(self):
port = 11001
model_name = "testmodel"
predictions_result = {"modulation": 999}
if self.pid == 0:
self.simulate_torchserve(port, model_name, predictions_result)
return
x = 800
y = 600
fft_size = 1024
Expand Down Expand Up @@ -257,12 +283,14 @@ def test_instance(self):
2,
0,
-1e9,
f"localhost:{port}",
model_name,
)
c2r = blocks.complex_to_real(1)
stream2vector = blocks.stream_to_vector(gr.sizeof_float, fft_size)
throttle = blocks.throttle(gr.sizeof_float, samp_rate, True)
fs = blocks.file_sink(
gr.sizeof_char * output_vlen, os.path.join(tmpdir, test_file), False
gr.sizeof_char, os.path.join(tmpdir, test_file), False
)

self.tb.msg_connect((strobe, "strobe"), (source, "cmd"))
Expand All @@ -281,7 +309,12 @@ def test_instance(self):
for image_file in image_files:
stat = os.stat(image_file)
self.assertTrue(stat.st_size)
self.assertEqual(imghdr.what(image_file), "png")
self.assertTrue(os.stat(test_file).st_size)
with open(test_file) as f:
for line in f.readlines():
result = json.loads(line)
self.assertEqual(result["predictions"], predictions_result)


if __name__ == "__main__":
Expand Down

0 comments on commit e31f81e

Please sign in to comment.