Skip to content

Commit

Permalink
Merge pull request #127 from anarkiwi/path
Browse files Browse the repository at this point in the history
Image inference outputs image path, and handles multiline JSON from torchserve.
  • Loading branch information
anarkiwi authored Oct 17, 2023
2 parents e31f81e + 7267880 commit 600e2d3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
7 changes: 5 additions & 2 deletions lib/image_inference_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,8 @@ void image_inference_impl::output_image_() {
std::stringstream ss("", std::ios_base::app | std::ios_base::out);
ss << "{"
<< "\"ts\": " << host_now_str_(output_item.ts)
<< ", \"rx_freq\": " << output_item.rx_freq;
<< ", \"rx_freq\": " << output_item.rx_freq << ", \"image_path\": \""
<< full_image_file_png << "\"";
// TODO: synchronous requests for testing. Should be parallel.
if (host_.size() && port_.size()) {
boost::asio::io_context ioc;
Expand Down Expand Up @@ -357,7 +358,9 @@ void image_inference_impl::output_image_() {
ss << ", \"error\": \"" << ex.what() << "\"";
}
}
ss << "}" << std::endl;
// double newline to facilitate JSON parsing, since the prediction may contain
// newlines.
ss << "}\n" << std::endl;
const std::string s = ss.str();
out_buf_.insert(out_buf_.end(), s.begin(), s.end());
delete_output_();
Expand Down
18 changes: 12 additions & 6 deletions python/iqtlabs/qa_image_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ def simulate_torchserve(self, port, model_name, result):

# nosemgrep:github.workflows.config.useless-inner-function
@app.route(f"/predictions/{model_name}", methods=["POST"])
def predictions_test():
return json.dumps(result), 200
def predictions_test():
return json.dumps(result, indent=2), 200

try:
app.run(host="127.0.0.1", port=11001)
Expand All @@ -252,7 +252,7 @@ def predictions_test():
def test_instance(self):
port = 11001
model_name = "testmodel"
predictions_result = {"modulation": 999}
predictions_result = {"modulation": [{"conf": 0.9, "xywh": [1, 2, 3, 4]}]}
if self.pid == 0:
self.simulate_torchserve(port, model_name, predictions_result)
return
Expand Down Expand Up @@ -312,9 +312,15 @@ def test_instance(self):
self.assertEqual(imghdr.what(image_file), "png")
self.assertTrue(os.stat(test_file).st_size)
with open(test_file) as f:
for line in f.readlines():
result = json.loads(line)
self.assertEqual(result["predictions"], predictions_result)
content = f.read()
json_raw_all = content.split("\n\n")
self.assertTrue(json_raw_all)
for json_raw in json_raw_all:
if not json_raw:
continue
result = json.loads(json_raw)
self.assertTrue(os.path.exists(result["image_path"]))
self.assertEqual(result["predictions"], predictions_result)


if __name__ == "__main__":
Expand Down

0 comments on commit 600e2d3

Please sign in to comment.