From d8fc90bd101be10f3b3aee7d51a50f3dac5fd7d1 Mon Sep 17 00:00:00 2001 From: Josh Bailey Date: Thu, 12 Oct 2023 04:42:47 +0000 Subject: [PATCH] implement simple torchserve API integration for testing. --- codecheck-requirements.txt | 2 +- grc/iqtlabs_image_inference.block.yml | 12 +- include/gnuradio/iqtlabs/image_inference.h | 3 +- lib/image_inference_impl.cc | 108 ++++++++++++------ lib/image_inference_impl.h | 15 ++- .../bindings/image_inference_python.cc | 4 +- python/iqtlabs/qa_image_inference.py | 34 +++++- 7 files changed, 137 insertions(+), 41 deletions(-) diff --git a/codecheck-requirements.txt b/codecheck-requirements.txt index 4e420db9..9d46434d 100644 --- a/codecheck-requirements.txt +++ b/codecheck-requirements.txt @@ -1,3 +1,3 @@ -pytype==2023.10.5 black==23.9.1 +pytype==2023.10.5 zstandard==0.21.0 diff --git a/grc/iqtlabs_image_inference.block.yml b/grc/iqtlabs_image_inference.block.yml index d69bca6d..864d33ed 100644 --- a/grc/iqtlabs_image_inference.block.yml +++ b/grc/iqtlabs_image_inference.block.yml @@ -9,7 +9,7 @@ templates: make: > iqtlabs.image_inference(${tag}, ${vlen}, ${x}, ${y}, ${image_dir}, ${convert_alpha}, ${norm_alpha}, ${norm_beta}, ${norm_type}, ${colormap}, - ${interpolation}) + ${interpolation}, ${model_server}, ${model_name}) cpp_templates: includes: ['#include '] @@ -17,7 +17,8 @@ cpp_templates: make: > this->${id} = gr::iqtlabs::image_inference::make(${tag}, ${vlen}, ${x}, ${y}, ${image_dir}, ${convert_alpha}, ${norm_alpha}, ${norm_beta}, - ${norm_type}, ${colormap}, ${interpolation}); + ${norm_type}, ${colormap}, ${interpolation}, ${model_server}, + ${model_name}); link: ['libgnuradio-iqtlabs.so'] @@ -56,10 +57,15 @@ parameters: default: 99 # cv::flip(), or 99 for no flip - id: min_peak_points dtype: float + - id: model_name + dtype: str + - id: model_server + dtype: str asserts: - ${ tag != "" } - ${ vlen > 0 } + - ${ !model_server || (model_server && model_name) } inputs: - label: input @@ -71,6 +77,6 @@ outputs: - label: input domain: stream dtype: byte - vlen: ${ x * y } + vlen: 1 file_format: 1 diff --git a/include/gnuradio/iqtlabs/image_inference.h b/include/gnuradio/iqtlabs/image_inference.h index 4eaaf65f..142a8fcc 100644 --- a/include/gnuradio/iqtlabs/image_inference.h +++ b/include/gnuradio/iqtlabs/image_inference.h @@ -232,7 +232,8 @@ class IQTLABS_API image_inference : virtual public gr::block { const std::string &image_dir, double convert_alpha, double norm_alpha, double norm_beta, int norm_type, int colormap, int interpolation, int flip, - double min_peak_points); + double min_peak_points, const std::string &model_server, + const std::string &model_name); }; } // namespace iqtlabs diff --git a/lib/image_inference_impl.cc b/lib/image_inference_impl.cc index 051ccca3..fc87463a 100644 --- a/lib/image_inference_impl.cc +++ b/lib/image_inference_impl.cc @@ -203,6 +203,8 @@ */ #include "image_inference_impl.h" +#include +#include #include #include #include @@ -216,27 +218,39 @@ image_inference::make(const std::string &tag, int vlen, int x, int y, const std::string &image_dir, double convert_alpha, double norm_alpha, double norm_beta, int norm_type, int colormap, int interpolation, int flip, - double min_peak_points) { + double min_peak_points, const std::string &model_server, + const std::string &model_name) { return gnuradio::make_block_sptr( tag, vlen, x, y, image_dir, convert_alpha, norm_alpha, norm_beta, - norm_type, colormap, interpolation, flip, min_peak_points); + norm_type, colormap, interpolation, flip, min_peak_points, model_server, + model_name); } image_inference_impl::image_inference_impl( const std::string &tag, int vlen, int x, int y, const std::string &image_dir, double convert_alpha, double norm_alpha, double norm_beta, int norm_type, int colormap, int interpolation, int flip, - double min_peak_points) + double min_peak_points, const std::string &model_server, + const std::string &model_name) : gr::block("image_inference", gr::io_signature::make(1 /* min inputs */, 1 /* max inputs */, vlen * sizeof(input_type)), gr::io_signature::make(1 /* min outputs */, 1 /*max outputs */, - x * y * sizeof(output_type) * 3)), + sizeof(output_type))), tag_(pmt::intern(tag)), x_(x), y_(y), vlen_(vlen), last_rx_freq_(0), last_rx_time_(0), image_dir_(image_dir), convert_alpha_(convert_alpha), norm_alpha_(norm_alpha), norm_beta_(norm_beta), norm_type_(norm_type), colormap_(colormap), interpolation_(interpolation), flip_(flip), - min_peak_points_(min_peak_points) { + min_peak_points_(min_peak_points), model_name_(model_name) { + // TODO: IPv6 IP addresses + std::vector model_server_parts_; + boost::split(model_server_parts_, model_server, boost::is_any_of(":"), + boost::token_compress_on); + if (model_server_parts_.size() == 2) { + host_ = model_server_parts_[0]; + port_ = model_server_parts_[1]; + } + image_buffer_.reset(new std::vector()); points_buffer_.reset( new cv::Mat(cv::Size(vlen, 0), CV_32F, cv::Scalar::all(0))); cmapped_buffer_.reset( @@ -286,45 +300,66 @@ void image_inference_impl::create_image_() { if (flip_ == -1 || flip_ == 0 || flip_ == 1) { cv::flip(*output_item.buffer, *output_item.buffer, flip_); } + cv::cvtColor(*output_item.buffer, *output_item.buffer, cv::COLOR_RGB2BGR); output_q_.insert(output_q_.begin(), output_item); } points_buffer_->resize(0); } } -void image_inference_impl::output_image_(output_type *out) { +void image_inference_impl::output_image_() { output_item_type output_item = output_q_.back(); - void *resized_buffer_p = output_item.buffer->ptr(0); - std::stringstream str; - str << name() << unique_id(); - pmt::pmt_t _id = pmt::string_to_symbol(str.str()); - // TODO: add more metadata as needed for inference. - this->add_item_tag(0, nitems_written(0), RX_TIME_KEY, - make_rx_time_key_(output_item.ts), _id); - this->add_item_tag(0, nitems_written(0), RX_FREQ_KEY, - pmt::from_double(output_item.rx_freq), _id); - const size_t buffer_size = - output_item.buffer->total() * output_item.buffer->elemSize(); - std::memcpy(out, resized_buffer_p, buffer_size); + // write image file std::string image_file_base = "image_" + host_now_str_(output_item.ts) + "_" + std::to_string(uint64_t(x_)) + "x" + std::to_string(uint64_t(y_)) + "_" + std::to_string(uint64_t(output_item.rx_freq)) + "Hz"; - // TODO: re-enable if non-PNG image required. - // std::string image_file = image_file_base + ".bin"; - // std::string dot_image_file = image_dir_ + "/." + image_file; - // std::string full_image_file = image_dir_ + "/" + image_file; - // std::ofstream image_out; - // image_out.open(dot_image_file, std::ios::binary | std::ios::out); - // image_out.write((const char *)resized_buffer_p, buffer_size); - // image_out.close(); - // rename(dot_image_file.c_str(), full_image_file.c_str()); - std::string image_file_png = image_file_base + ".png"; + std::string image_file_png = image_file_base + IMAGE_EXT; std::string dot_image_file_png = image_dir_ + "/." + image_file_png; std::string full_image_file_png = image_dir_ + "/" + image_file_png; - cv::cvtColor(*output_item.buffer, *output_item.buffer, cv::COLOR_RGB2BGR); - cv::imwrite(dot_image_file_png, *output_item.buffer); + cv::imencode(IMAGE_EXT, *output_item.buffer, *image_buffer_); + std::ofstream image_out; + image_out.open(dot_image_file_png, std::ios::binary | std::ios::out); + image_out.write((const char *)image_buffer_->data(), image_buffer_->size()); + image_out.close(); rename(dot_image_file_png.c_str(), full_image_file_png.c_str()); + std::stringstream ss("", std::ios_base::app | std::ios_base::out); + ss << "{" + << "\"ts\": " << host_now_str_(output_item.ts) + << ", \"rx_freq\": " << output_item.rx_freq; + // TODO: synchronous requests for testing. Should be parallel. + if (host_.size() && port_.size()) { + boost::asio::io_context ioc; + boost::asio::ip::tcp::resolver resolver(ioc); + boost::beast::tcp_stream stream(ioc); + const std::string_view body( + reinterpret_cast(image_buffer_->data()), + image_buffer_->size()); + boost::beast::http::request req{ + boost::beast::http::verb::post, "/predictions/" + model_name_, 11}; + req.set(boost::beast::http::field::host, host_); + req.set(boost::beast::http::field::user_agent, BOOST_BEAST_VERSION_STRING); + req.set(boost::beast::http::field::content_type, "image/" + IMAGE_TYPE); + req.body() = body; + req.prepare_payload(); + boost::beast::flat_buffer buffer; + boost::beast::http::response res; + + try { + auto const results = resolver.resolve(host_, port_); + stream.connect(results); + boost::beast::http::write(stream, req); + boost::beast::http::read(stream, buffer, res); + ss << ", \"predictions\": " << res.body().data(); + boost::beast::error_code ec; + stream.socket().shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec); + } catch (std::exception &ex) { + ss << ", \"error\": \"" << ex.what() << "\""; + } + } + ss << "}" << std::endl; + const std::string s = ss.str(); + out_buf_.insert(out_buf_.end(), s.begin(), s.end()); delete_output_(); } @@ -337,8 +372,17 @@ int image_inference_impl::general_work(int noutput_items, size_t in_first = nitems_read(0); if (!output_q_.empty()) { - output_image_(static_cast(output_items[0])); - return 1; + output_image_(); + } + + if (!out_buf_.empty()) { + auto out = static_cast(output_items[0]); + const size_t leftover = std::min(out_buf_.size(), (size_t)noutput_items); + auto from = out_buf_.begin(); + auto to = from + leftover; + std::copy(from, to, out); + out_buf_.erase(from, to); + return leftover; } std::vector all_tags, rx_freq_tags; diff --git a/lib/image_inference_impl.h b/lib/image_inference_impl.h index 7d8eb96d..f3e86d2c 100644 --- a/lib/image_inference_impl.h +++ b/lib/image_inference_impl.h @@ -206,6 +206,11 @@ #define INCLUDED_IQTLABS_IMAGE_INFERENCE_IMPL_H #include "base_impl.h" +#include +#include +#include +#include +#include #include #include #include @@ -216,6 +221,8 @@ namespace iqtlabs { using input_type = float; using output_type = unsigned char; +const std::string IMAGE_TYPE = "png"; +const std::string IMAGE_EXT = "." + IMAGE_TYPE; typedef struct output_item { uint64_t rx_freq; @@ -230,13 +237,16 @@ class image_inference_impl : public image_inference, base_impl { double convert_alpha_, norm_alpha_, norm_beta_, last_rx_time_, min_peak_points_; std::vector output_q_; + boost::scoped_ptr> image_buffer_; boost::scoped_ptr points_buffer_, cmapped_buffer_; std::string image_dir_; pmt::pmt_t tag_; + std::deque out_buf_; + std::string model_name_, host_, port_; void process_items_(size_t c, const input_type *&in); void create_image_(); - void output_image_(output_type *out); + void output_image_(); void delete_output_(); public: @@ -244,7 +254,8 @@ class image_inference_impl : public image_inference, base_impl { const std::string &image_dir, double convert_alpha, double norm_alpha, double norm_beta, int norm_type, int colormap, int interpolation, int flip, - double min_peak_points); + double min_peak_points, const std::string &model_server, + const std::string &model_name); ~image_inference_impl(); int general_work(int noutput_items, gr_vector_int &ninput_items, gr_vector_const_void_star &input_items, diff --git a/python/iqtlabs/bindings/image_inference_python.cc b/python/iqtlabs/bindings/image_inference_python.cc index 31551566..ea5e3bb0 100644 --- a/python/iqtlabs/bindings/image_inference_python.cc +++ b/python/iqtlabs/bindings/image_inference_python.cc @@ -14,7 +14,7 @@ /* BINDTOOL_GEN_AUTOMATIC(0) */ /* BINDTOOL_USE_PYGCCXML(0) */ /* BINDTOOL_HEADER_FILE(image_inference.h) */ -/* BINDTOOL_HEADER_FILE_HASH(8b70b8d06de725d327d7d86b793e4360) */ +/* BINDTOOL_HEADER_FILE_HASH(ed9a9d38e8c7f50546d907bfa44a223b) */ /***********************************************************************************/ #include @@ -52,6 +52,8 @@ void bind_image_inference(py::module& m) py::arg("interpolation"), py::arg("flip"), py::arg("min_peak_points"), + py::arg("model_server"), + py::arg("model_name"), D(image_inference, make)) diff --git a/python/iqtlabs/qa_image_inference.py b/python/iqtlabs/qa_image_inference.py index 985cbbc6..4150f842 100755 --- a/python/iqtlabs/qa_image_inference.py +++ b/python/iqtlabs/qa_image_inference.py @@ -203,11 +203,15 @@ # limitations under the License. # +import concurrent.futures +import imghdr +import json import glob import os import pmt import time import tempfile +from flask import Flask from gnuradio import gr, gr_unittest from gnuradio import analog, blocks @@ -225,11 +229,32 @@ class qa_image_inference(gr_unittest.TestCase): def setUp(self): self.tb = gr.top_block() + self.pid = os.fork() def tearDown(self): self.tb = None + if self.pid: + os.kill(self.pid, 15) + + def simulate_torchserve(self, port, model_name, result): + app = Flask(__name__) + + @app.route("/predictions/testmodel", methods=["POST"]) + def predictions_test(): + return json.dumps(result), 200 + + try: + app.run(host="127.0.0.1", port=11001) + except RuntimeError: + return def test_instance(self): + port = 11001 + model_name = "testmodel" + predictions_result = {"modulation": 999} + if self.pid == 0: + self.simulate_torchserve(port, model_name, predictions_result) + return x = 800 y = 600 fft_size = 1024 @@ -257,12 +282,14 @@ def test_instance(self): 2, 0, -1e9, + f"localhost:{port}", + model_name, ) c2r = blocks.complex_to_real(1) stream2vector = blocks.stream_to_vector(gr.sizeof_float, fft_size) throttle = blocks.throttle(gr.sizeof_float, samp_rate, True) fs = blocks.file_sink( - gr.sizeof_char * output_vlen, os.path.join(tmpdir, test_file), False + gr.sizeof_char, os.path.join(tmpdir, test_file), False ) self.tb.msg_connect((strobe, "strobe"), (source, "cmd")) @@ -281,7 +308,12 @@ def test_instance(self): for image_file in image_files: stat = os.stat(image_file) self.assertTrue(stat.st_size) + self.assertEqual(imghdr.what(image_file), "png") self.assertTrue(os.stat(test_file).st_size) + with open(test_file) as f: + for line in f.readlines(): + result = json.loads(line) + self.assertEqual(result["predictions"], predictions_result) if __name__ == "__main__":