From 960554b1f105ffd5664835bc9bdab2796e5c3527 Mon Sep 17 00:00:00 2001 From: Josh Bailey Date: Thu, 12 Oct 2023 04:42:47 +0000 Subject: [PATCH] output images to torchserve. --- codecheck-requirements.txt | 2 +- grc/iqtlabs_image_inference.block.yml | 12 +++-- include/gnuradio/iqtlabs/image_inference.h | 3 +- lib/image_inference_impl.cc | 51 +++++++++++-------- lib/image_inference_impl.h | 15 +++++- .../bindings/image_inference_python.cc | 4 +- python/iqtlabs/qa_image_inference.py | 8 ++- 7 files changed, 64 insertions(+), 31 deletions(-) diff --git a/codecheck-requirements.txt b/codecheck-requirements.txt index 4e420db9..9d46434d 100644 --- a/codecheck-requirements.txt +++ b/codecheck-requirements.txt @@ -1,3 +1,3 @@ -pytype==2023.10.5 black==23.9.1 +pytype==2023.10.5 zstandard==0.21.0 diff --git a/grc/iqtlabs_image_inference.block.yml b/grc/iqtlabs_image_inference.block.yml index d69bca6d..864d33ed 100644 --- a/grc/iqtlabs_image_inference.block.yml +++ b/grc/iqtlabs_image_inference.block.yml @@ -9,7 +9,7 @@ templates: make: > iqtlabs.image_inference(${tag}, ${vlen}, ${x}, ${y}, ${image_dir}, ${convert_alpha}, ${norm_alpha}, ${norm_beta}, ${norm_type}, ${colormap}, - ${interpolation}) + ${interpolation}, ${model_server}, ${model_name}) cpp_templates: includes: ['#include '] @@ -17,7 +17,8 @@ cpp_templates: make: > this->${id} = gr::iqtlabs::image_inference::make(${tag}, ${vlen}, ${x}, ${y}, ${image_dir}, ${convert_alpha}, ${norm_alpha}, ${norm_beta}, - ${norm_type}, ${colormap}, ${interpolation}); + ${norm_type}, ${colormap}, ${interpolation}, ${model_server}, + ${model_name}); link: ['libgnuradio-iqtlabs.so'] @@ -56,10 +57,15 @@ parameters: default: 99 # cv::flip(), or 99 for no flip - id: min_peak_points dtype: float + - id: model_name + dtype: str + - id: model_server + dtype: str asserts: - ${ tag != "" } - ${ vlen > 0 } + - ${ !model_server || (model_server && model_name) } inputs: - label: input @@ -71,6 +77,6 @@ outputs: - label: input domain: stream dtype: byte - vlen: ${ x * y } + vlen: 1 file_format: 1 diff --git a/include/gnuradio/iqtlabs/image_inference.h b/include/gnuradio/iqtlabs/image_inference.h index 4eaaf65f..9f91e125 100644 --- a/include/gnuradio/iqtlabs/image_inference.h +++ b/include/gnuradio/iqtlabs/image_inference.h @@ -232,7 +232,8 @@ class IQTLABS_API image_inference : virtual public gr::block { const std::string &image_dir, double convert_alpha, double norm_alpha, double norm_beta, int norm_type, int colormap, int interpolation, int flip, - double min_peak_points); + double min_peak_points, const std::string &model_name, + const std::string &model_server); }; } // namespace iqtlabs diff --git a/lib/image_inference_impl.cc b/lib/image_inference_impl.cc index 051ccca3..8fcdd651 100644 --- a/lib/image_inference_impl.cc +++ b/lib/image_inference_impl.cc @@ -216,27 +216,32 @@ image_inference::make(const std::string &tag, int vlen, int x, int y, const std::string &image_dir, double convert_alpha, double norm_alpha, double norm_beta, int norm_type, int colormap, int interpolation, int flip, - double min_peak_points) { + double min_peak_points, const std::string &model_name, + const std::string &model_server) { return gnuradio::make_block_sptr( tag, vlen, x, y, image_dir, convert_alpha, norm_alpha, norm_beta, - norm_type, colormap, interpolation, flip, min_peak_points); + norm_type, colormap, interpolation, flip, min_peak_points, model_name, + model_server); } image_inference_impl::image_inference_impl( const std::string &tag, int vlen, int x, int y, const std::string &image_dir, double convert_alpha, double norm_alpha, double norm_beta, int norm_type, int colormap, int interpolation, int flip, - double min_peak_points) + double min_peak_points, const std::string &model_name, + const std::string &model_server) : gr::block("image_inference", gr::io_signature::make(1 /* min inputs */, 1 /* max inputs */, vlen * sizeof(input_type)), gr::io_signature::make(1 /* min outputs */, 1 /*max outputs */, - x * y * sizeof(output_type) * 3)), + sizeof(output_type))), tag_(pmt::intern(tag)), x_(x), y_(y), vlen_(vlen), last_rx_freq_(0), last_rx_time_(0), image_dir_(image_dir), convert_alpha_(convert_alpha), norm_alpha_(norm_alpha), norm_beta_(norm_beta), norm_type_(norm_type), colormap_(colormap), interpolation_(interpolation), flip_(flip), - min_peak_points_(min_peak_points) { + min_peak_points_(min_peak_points), model_name_(model_name), + model_server_(model_server) { + image_buffer_.reset(new std::vector()); points_buffer_.reset( new cv::Mat(cv::Size(vlen, 0), CV_32F, cv::Scalar::all(0))); cmapped_buffer_.reset( @@ -286,13 +291,14 @@ void image_inference_impl::create_image_() { if (flip_ == -1 || flip_ == 0 || flip_ == 1) { cv::flip(*output_item.buffer, *output_item.buffer, flip_); } + cv::cvtColor(*output_item.buffer, *output_item.buffer, cv::COLOR_RGB2BGR); output_q_.insert(output_q_.begin(), output_item); } points_buffer_->resize(0); } } -void image_inference_impl::output_image_(output_type *out) { +void image_inference_impl::output_image_() { output_item_type output_item = output_q_.back(); void *resized_buffer_p = output_item.buffer->ptr(0); std::stringstream str; @@ -305,25 +311,19 @@ void image_inference_impl::output_image_(output_type *out) { pmt::from_double(output_item.rx_freq), _id); const size_t buffer_size = output_item.buffer->total() * output_item.buffer->elemSize(); - std::memcpy(out, resized_buffer_p, buffer_size); + // std::memcpy(out, resized_buffer_p, buffer_size); std::string image_file_base = "image_" + host_now_str_(output_item.ts) + "_" + std::to_string(uint64_t(x_)) + "x" + std::to_string(uint64_t(y_)) + "_" + std::to_string(uint64_t(output_item.rx_freq)) + "Hz"; - // TODO: re-enable if non-PNG image required. - // std::string image_file = image_file_base + ".bin"; - // std::string dot_image_file = image_dir_ + "/." + image_file; - // std::string full_image_file = image_dir_ + "/" + image_file; - // std::ofstream image_out; - // image_out.open(dot_image_file, std::ios::binary | std::ios::out); - // image_out.write((const char *)resized_buffer_p, buffer_size); - // image_out.close(); - // rename(dot_image_file.c_str(), full_image_file.c_str()); - std::string image_file_png = image_file_base + ".png"; + std::string image_file_png = image_file_base + IMAGE_EXT; std::string dot_image_file_png = image_dir_ + "/." + image_file_png; std::string full_image_file_png = image_dir_ + "/" + image_file_png; - cv::cvtColor(*output_item.buffer, *output_item.buffer, cv::COLOR_RGB2BGR); - cv::imwrite(dot_image_file_png, *output_item.buffer); + cv::imencode(IMAGE_EXT, *output_item.buffer, *image_buffer_); + std::ofstream image_out; + image_out.open(dot_image_file_png, std::ios::binary | std::ios::out); + image_out.write((const char *)image_buffer_->data(), image_buffer_->size()); + image_out.close(); rename(dot_image_file_png.c_str(), full_image_file_png.c_str()); delete_output_(); } @@ -337,8 +337,17 @@ int image_inference_impl::general_work(int noutput_items, size_t in_first = nitems_read(0); if (!output_q_.empty()) { - output_image_(static_cast(output_items[0])); - return 1; + output_image_(); + } + + if (!out_buf_.empty()) { + auto out = static_cast(output_items[0]); + const size_t leftover = std::min(out_buf_.size(), (size_t)noutput_items); + auto from = out_buf_.begin(); + auto to = from + leftover; + std::copy(from, to, out); + out_buf_.erase(from, to); + return leftover; } std::vector all_tags, rx_freq_tags; diff --git a/lib/image_inference_impl.h b/lib/image_inference_impl.h index 7d8eb96d..44cc1469 100644 --- a/lib/image_inference_impl.h +++ b/lib/image_inference_impl.h @@ -206,6 +206,11 @@ #define INCLUDED_IQTLABS_IMAGE_INFERENCE_IMPL_H #include "base_impl.h" +#include +#include +#include +#include +#include #include #include #include @@ -216,6 +221,7 @@ namespace iqtlabs { using input_type = float; using output_type = unsigned char; +const std::string IMAGE_EXT = ".png"; typedef struct output_item { uint64_t rx_freq; @@ -230,13 +236,17 @@ class image_inference_impl : public image_inference, base_impl { double convert_alpha_, norm_alpha_, norm_beta_, last_rx_time_, min_peak_points_; std::vector output_q_; + boost::scoped_ptr> image_buffer_; boost::scoped_ptr points_buffer_, cmapped_buffer_; std::string image_dir_; pmt::pmt_t tag_; + boost::asio::io_context ioc_; + std::deque out_buf_; + std::string model_name_, model_server_; void process_items_(size_t c, const input_type *&in); void create_image_(); - void output_image_(output_type *out); + void output_image_(); void delete_output_(); public: @@ -244,7 +254,8 @@ class image_inference_impl : public image_inference, base_impl { const std::string &image_dir, double convert_alpha, double norm_alpha, double norm_beta, int norm_type, int colormap, int interpolation, int flip, - double min_peak_points); + double min_peak_points, const std::string &model_name, + const std::string &model_server); ~image_inference_impl(); int general_work(int noutput_items, gr_vector_int &ninput_items, gr_vector_const_void_star &input_items, diff --git a/python/iqtlabs/bindings/image_inference_python.cc b/python/iqtlabs/bindings/image_inference_python.cc index 31551566..330c0fc4 100644 --- a/python/iqtlabs/bindings/image_inference_python.cc +++ b/python/iqtlabs/bindings/image_inference_python.cc @@ -14,7 +14,7 @@ /* BINDTOOL_GEN_AUTOMATIC(0) */ /* BINDTOOL_USE_PYGCCXML(0) */ /* BINDTOOL_HEADER_FILE(image_inference.h) */ -/* BINDTOOL_HEADER_FILE_HASH(8b70b8d06de725d327d7d86b793e4360) */ +/* BINDTOOL_HEADER_FILE_HASH(22b15057623b3d662df741b8aecd2dc9) */ /***********************************************************************************/ #include @@ -52,6 +52,8 @@ void bind_image_inference(py::module& m) py::arg("interpolation"), py::arg("flip"), py::arg("min_peak_points"), + py::arg("model_name"), + py::arg("model_server"), D(image_inference, make)) diff --git a/python/iqtlabs/qa_image_inference.py b/python/iqtlabs/qa_image_inference.py index 985cbbc6..fcf0dadd 100755 --- a/python/iqtlabs/qa_image_inference.py +++ b/python/iqtlabs/qa_image_inference.py @@ -203,6 +203,7 @@ # limitations under the License. # +import imghdr import glob import os import pmt @@ -257,12 +258,14 @@ def test_instance(self): 2, 0, -1e9, + "", + "", ) c2r = blocks.complex_to_real(1) stream2vector = blocks.stream_to_vector(gr.sizeof_float, fft_size) throttle = blocks.throttle(gr.sizeof_float, samp_rate, True) fs = blocks.file_sink( - gr.sizeof_char * output_vlen, os.path.join(tmpdir, test_file), False + gr.sizeof_char, os.path.join(tmpdir, test_file), False ) self.tb.msg_connect((strobe, "strobe"), (source, "cmd")) @@ -281,7 +284,8 @@ def test_instance(self): for image_file in image_files: stat = os.stat(image_file) self.assertTrue(stat.st_size) - self.assertTrue(os.stat(test_file).st_size) + self.assertEqual(imghdr.what(image_file), "png") + # self.assertTrue(os.stat(test_file).st_size) if __name__ == "__main__":