Skip to content

Commit

Permalink
Merge pull request #1231 from anarkiwi/dualmodel
Browse files Browse the repository at this point in the history
fix inference combiner for multiple inference blocks.
  • Loading branch information
anarkiwi authored Apr 12, 2024
2 parents dd21291 + c3382df commit 3330527
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
26 changes: 15 additions & 11 deletions gamutrf/grinferenceoutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ def __init__(self, inputs):

def mix(self, input_items):
items = []
n = 0
ns = {}
for i, input_item in enumerate(input_items):
raw_input_item = input_item.tobytes().decode("utf8")
n += len(raw_input_item)
self.json_buffer[i] += raw_input_item
raw_input_item = input_item.tobytes().decode("utf8").split("\x00")[0]
ns[i] = len(raw_input_item)
if len(raw_input_item):
self.json_buffer[i] += raw_input_item
for i in self.json_buffer:
while True:
delim_pos = self.json_buffer[i].find(DELIM)
if delim_pos == -1:
Expand All @@ -47,11 +49,11 @@ def mix(self, input_items):
item = json.loads(raw_item)
items.append(item)
except json.JSONDecodeError as e:
logging.error("cannot decode %s: %s", raw_item, e)
return (n, items)
logging.error("cannot decode %s from source %u: %s", raw_item, i, e)
return (ns, items)


class inferenceoutput(gr.sync_block):
class inferenceoutput(gr.basic_block):
def __init__(
self,
name,
Expand Down Expand Up @@ -85,7 +87,7 @@ def __init__(
),
)
self.reporter_thread.start()
gr.sync_block.__init__(
gr.basic_block.__init__(
self,
name="inferenceoutput",
in_sig=([np.ubyte] * inputs),
Expand Down Expand Up @@ -143,8 +145,10 @@ def reporter_thread(
mqtt_reporter.log(log_path, "inference", start_time, item)
self.q.task_done()

def work(self, input_items, output_items):
n, items = self.mixer.mix(input_items)
def general_work(self, input_items, output_items):
ns, items = self.mixer.mix(input_items)
for i, n in ns.items():
self.consume(i, n)
for item in items:
self.q.put(item)
return n
return 0
4 changes: 2 additions & 2 deletions gamutrf/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def argument_parser():
"--nfft",
dest="nfft",
type=int,
default=2048,
default=1024,
help="FFTI size [default=%(default)r]",
)
parser.add_argument(
Expand Down Expand Up @@ -291,7 +291,7 @@ def argument_parser():
"--iq_inference_len",
dest="iq_inference_len",
type=int,
default=4096,
default=1024,
help="number of samples to send for I/Q inference",
)
parser.add_argument(
Expand Down

0 comments on commit 3330527

Please sign in to comment.