Skip to content

Commit

Permalink
output images to torchserve.
Browse files Browse the repository at this point in the history
  • Loading branch information
anarkiwi committed Oct 13, 2023
1 parent d9c1e8d commit 960554b
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 31 deletions.
2 changes: 1 addition & 1 deletion codecheck-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pytype==2023.10.5
black==23.9.1
pytype==2023.10.5
zstandard==0.21.0
12 changes: 9 additions & 3 deletions grc/iqtlabs_image_inference.block.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@ templates:
make: >
iqtlabs.image_inference(${tag}, ${vlen}, ${x}, ${y}, ${image_dir},
${convert_alpha}, ${norm_alpha}, ${norm_beta}, ${norm_type}, ${colormap},
${interpolation})
${interpolation}, ${model_server}, ${model_name})
cpp_templates:
includes: ['#include <gnuradio/iqtlabs/image_inference.h>']
declarations: 'gr::iqtlabs::image_inference::sptr ${id};'
make: >
this->${id} = gr::iqtlabs::image_inference::make(${tag}, ${vlen},
${x}, ${y}, ${image_dir}, ${convert_alpha}, ${norm_alpha}, ${norm_beta},
${norm_type}, ${colormap}, ${interpolation});
${norm_type}, ${colormap}, ${interpolation}, ${model_server},
${model_name});
link: ['libgnuradio-iqtlabs.so']


Expand Down Expand Up @@ -56,10 +57,15 @@ parameters:
default: 99 # cv::flip(), or 99 for no flip
- id: min_peak_points
dtype: float
- id: model_name
dtype: str
- id: model_server
dtype: str

asserts:
- ${ tag != "" }
- ${ vlen > 0 }
- ${ !model_server || (model_server && model_name) }

inputs:
- label: input
Expand All @@ -71,6 +77,6 @@ outputs:
- label: input
domain: stream
dtype: byte
vlen: ${ x * y }
vlen: 1

file_format: 1
3 changes: 2 additions & 1 deletion include/gnuradio/iqtlabs/image_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,8 @@ class IQTLABS_API image_inference : virtual public gr::block {
const std::string &image_dir, double convert_alpha,
double norm_alpha, double norm_beta, int norm_type,
int colormap, int interpolation, int flip,
double min_peak_points);
double min_peak_points, const std::string &model_name,
const std::string &model_server);
};

} // namespace iqtlabs
Expand Down
51 changes: 30 additions & 21 deletions lib/image_inference_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,27 +216,32 @@ image_inference::make(const std::string &tag, int vlen, int x, int y,
const std::string &image_dir, double convert_alpha,
double norm_alpha, double norm_beta, int norm_type,
int colormap, int interpolation, int flip,
double min_peak_points) {
double min_peak_points, const std::string &model_name,
const std::string &model_server) {
return gnuradio::make_block_sptr<image_inference_impl>(
tag, vlen, x, y, image_dir, convert_alpha, norm_alpha, norm_beta,
norm_type, colormap, interpolation, flip, min_peak_points);
norm_type, colormap, interpolation, flip, min_peak_points, model_name,
model_server);
}

image_inference_impl::image_inference_impl(
const std::string &tag, int vlen, int x, int y,
const std::string &image_dir, double convert_alpha, double norm_alpha,
double norm_beta, int norm_type, int colormap, int interpolation, int flip,
double min_peak_points)
double min_peak_points, const std::string &model_name,
const std::string &model_server)
: gr::block("image_inference",
gr::io_signature::make(1 /* min inputs */, 1 /* max inputs */,
vlen * sizeof(input_type)),
gr::io_signature::make(1 /* min outputs */, 1 /*max outputs */,
x * y * sizeof(output_type) * 3)),
sizeof(output_type))),
tag_(pmt::intern(tag)), x_(x), y_(y), vlen_(vlen), last_rx_freq_(0),
last_rx_time_(0), image_dir_(image_dir), convert_alpha_(convert_alpha),
norm_alpha_(norm_alpha), norm_beta_(norm_beta), norm_type_(norm_type),
colormap_(colormap), interpolation_(interpolation), flip_(flip),
min_peak_points_(min_peak_points) {
min_peak_points_(min_peak_points), model_name_(model_name),
model_server_(model_server) {
image_buffer_.reset(new std::vector<unsigned char>());
points_buffer_.reset(
new cv::Mat(cv::Size(vlen, 0), CV_32F, cv::Scalar::all(0)));
cmapped_buffer_.reset(
Expand Down Expand Up @@ -286,13 +291,14 @@ void image_inference_impl::create_image_() {
if (flip_ == -1 || flip_ == 0 || flip_ == 1) {
cv::flip(*output_item.buffer, *output_item.buffer, flip_);
}
cv::cvtColor(*output_item.buffer, *output_item.buffer, cv::COLOR_RGB2BGR);
output_q_.insert(output_q_.begin(), output_item);
}
points_buffer_->resize(0);
}
}

void image_inference_impl::output_image_(output_type *out) {
void image_inference_impl::output_image_() {
output_item_type output_item = output_q_.back();
void *resized_buffer_p = output_item.buffer->ptr(0);
std::stringstream str;
Expand All @@ -305,25 +311,19 @@ void image_inference_impl::output_image_(output_type *out) {
pmt::from_double(output_item.rx_freq), _id);
const size_t buffer_size =
output_item.buffer->total() * output_item.buffer->elemSize();
std::memcpy(out, resized_buffer_p, buffer_size);
// std::memcpy(out, resized_buffer_p, buffer_size);
std::string image_file_base =
"image_" + host_now_str_(output_item.ts) + "_" +
std::to_string(uint64_t(x_)) + "x" + std::to_string(uint64_t(y_)) + "_" +
std::to_string(uint64_t(output_item.rx_freq)) + "Hz";
// TODO: re-enable if non-PNG image required.
// std::string image_file = image_file_base + ".bin";
// std::string dot_image_file = image_dir_ + "/." + image_file;
// std::string full_image_file = image_dir_ + "/" + image_file;
// std::ofstream image_out;
// image_out.open(dot_image_file, std::ios::binary | std::ios::out);
// image_out.write((const char *)resized_buffer_p, buffer_size);
// image_out.close();
// rename(dot_image_file.c_str(), full_image_file.c_str());
std::string image_file_png = image_file_base + ".png";
std::string image_file_png = image_file_base + IMAGE_EXT;
std::string dot_image_file_png = image_dir_ + "/." + image_file_png;
std::string full_image_file_png = image_dir_ + "/" + image_file_png;
cv::cvtColor(*output_item.buffer, *output_item.buffer, cv::COLOR_RGB2BGR);
cv::imwrite(dot_image_file_png, *output_item.buffer);
cv::imencode(IMAGE_EXT, *output_item.buffer, *image_buffer_);
std::ofstream image_out;
image_out.open(dot_image_file_png, std::ios::binary | std::ios::out);
image_out.write((const char *)image_buffer_->data(), image_buffer_->size());
image_out.close();
rename(dot_image_file_png.c_str(), full_image_file_png.c_str());
delete_output_();
}
Expand All @@ -337,8 +337,17 @@ int image_inference_impl::general_work(int noutput_items,
size_t in_first = nitems_read(0);

if (!output_q_.empty()) {
output_image_(static_cast<output_type *>(output_items[0]));
return 1;
output_image_();
}

if (!out_buf_.empty()) {
auto out = static_cast<output_type *>(output_items[0]);
const size_t leftover = std::min(out_buf_.size(), (size_t)noutput_items);
auto from = out_buf_.begin();
auto to = from + leftover;
std::copy(from, to, out);
out_buf_.erase(from, to);
return leftover;
}

std::vector<tag_t> all_tags, rx_freq_tags;
Expand Down
15 changes: 13 additions & 2 deletions lib/image_inference_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,11 @@
#define INCLUDED_IQTLABS_IMAGE_INFERENCE_IMPL_H

#include "base_impl.h"
#include <boost/asio/connect.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/beast/core.hpp>
#include <boost/beast/http.hpp>
#include <boost/beast/version.hpp>
#include <boost/scoped_ptr.hpp>
#include <gnuradio/iqtlabs/image_inference.h>
#include <opencv2/imgcodecs.hpp>
Expand All @@ -216,6 +221,7 @@ namespace iqtlabs {

using input_type = float;
using output_type = unsigned char;
const std::string IMAGE_EXT = ".png";

typedef struct output_item {
uint64_t rx_freq;
Expand All @@ -230,21 +236,26 @@ class image_inference_impl : public image_inference, base_impl {
double convert_alpha_, norm_alpha_, norm_beta_, last_rx_time_,
min_peak_points_;
std::vector<output_item_type> output_q_;
boost::scoped_ptr<std::vector<unsigned char>> image_buffer_;
boost::scoped_ptr<cv::Mat> points_buffer_, cmapped_buffer_;
std::string image_dir_;
pmt::pmt_t tag_;
boost::asio::io_context ioc_;
std::deque<output_type> out_buf_;
std::string model_name_, model_server_;

void process_items_(size_t c, const input_type *&in);
void create_image_();
void output_image_(output_type *out);
void output_image_();
void delete_output_();

public:
image_inference_impl(const std::string &tag, int vlen, int x, int y,
const std::string &image_dir, double convert_alpha,
double norm_alpha, double norm_beta, int norm_type,
int colormap, int interpolation, int flip,
double min_peak_points);
double min_peak_points, const std::string &model_name,
const std::string &model_server);
~image_inference_impl();
int general_work(int noutput_items, gr_vector_int &ninput_items,
gr_vector_const_void_star &input_items,
Expand Down
4 changes: 3 additions & 1 deletion python/iqtlabs/bindings/image_inference_python.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
/* BINDTOOL_GEN_AUTOMATIC(0) */
/* BINDTOOL_USE_PYGCCXML(0) */
/* BINDTOOL_HEADER_FILE(image_inference.h) */
/* BINDTOOL_HEADER_FILE_HASH(8b70b8d06de725d327d7d86b793e4360) */
/* BINDTOOL_HEADER_FILE_HASH(22b15057623b3d662df741b8aecd2dc9) */
/***********************************************************************************/

#include <pybind11/complex.h>
Expand Down Expand Up @@ -52,6 +52,8 @@ void bind_image_inference(py::module& m)
py::arg("interpolation"),
py::arg("flip"),
py::arg("min_peak_points"),
py::arg("model_name"),
py::arg("model_server"),
D(image_inference, make))


Expand Down
8 changes: 6 additions & 2 deletions python/iqtlabs/qa_image_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@
# limitations under the License.
#

import imghdr
import glob
import os
import pmt
Expand Down Expand Up @@ -257,12 +258,14 @@ def test_instance(self):
2,
0,
-1e9,
"",
"",
)
c2r = blocks.complex_to_real(1)
stream2vector = blocks.stream_to_vector(gr.sizeof_float, fft_size)
throttle = blocks.throttle(gr.sizeof_float, samp_rate, True)
fs = blocks.file_sink(
gr.sizeof_char * output_vlen, os.path.join(tmpdir, test_file), False
gr.sizeof_char, os.path.join(tmpdir, test_file), False
)

self.tb.msg_connect((strobe, "strobe"), (source, "cmd"))
Expand All @@ -281,7 +284,8 @@ def test_instance(self):
for image_file in image_files:
stat = os.stat(image_file)
self.assertTrue(stat.st_size)
self.assertTrue(os.stat(test_file).st_size)
self.assertEqual(imghdr.what(image_file), "png")
# self.assertTrue(os.stat(test_file).st_size)


if __name__ == "__main__":
Expand Down

0 comments on commit 960554b

Please sign in to comment.