Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Code to Node Conversion #105

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions extensions/lgpilot/codeToNode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import sys
import re

def create_func(writer, library_name, function_name, function_args):
writer.write(f"from {library_name} import {function_name}\n")
writer.write(f"def {function_name}({function_args}):\n\t")
writer.write(f"y = {function_name}({function_args})\n\t")
writer.write("return y\n")

def create_setup(writer, function_name):
writer.write(f"\nclass {function_name}Node(lg.Node):\n\t")
writer.write("INPUT = lg.Topic(InputMessage)\n\tOUTPUT = lg.Topic(OutputMessage)\n\n\t")
writer.write(f"def setup(self):\n\t\tself.func = {function_name}\n\n\t")

def create_feature(writer, function_name, function_args):
writer.write("@lg.subscriber(INPUT)\n\[email protected](OUTPUT)\n\n\t")
writer.write(f"def {function_name}_feature(self, message: InputMessage):\n\t\t")
# turn a string containing a list of function args (ex: 'x, y, z') into InputMessage attributes (ex: message.x, message.y, message.z)
params = [f"message.{i.strip()}" for i in function_args.split(",")]
# turn a list of parameters (ex: ["message.x", "message.y", "message.z"]) into a string of arguments
# ex: y = self.func(message.x, message.y, message.z)
writer.write("y = self.func(" + ", ".join(params) + ")\n\t\tyield self.OUTPUT, y")


def code_to_node(filename):
""" Take a python file <filename> containing a function and
output a file named node.py containing labgraph node.

"""
library_name, function_name, function_args = "", "", ""
with open(filename, 'r') as reader:
with open("node.py", 'w') as writer:
for line in reader:
# first check if line contains library_name and function name
result = re.search("from (.*) import (.*)", line)
if result is not None: # a match was found
library_name, function_name = result.group(1), result.group(2)
# next check if line contains function arguments
result = re.search(f"[a-zA-z]* = {function_name}\((.*)\)", line)
if result is not None:
function_args = result.group(1)

create_func(writer, library_name, function_name, function_args)
create_setup(writer, function_name)
create_feature(writer, function_name, function_args)


if __name__ == "__main__":
code_to_node(sys.argv[1])
12 changes: 12 additions & 0 deletions extensions/lgpilot/convolve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import numpy as np
from scipy.signal import convolve

# Create two arrays
x = np.array([1, 2, 3, 4])
h = np.array([1, 2, 3])

# Perform convolution
y = convolve(x, h)

# Print result
print(y)
18 changes: 18 additions & 0 deletions extensions/lgpilot/node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from scipy.signal import convolve
def convolve(x, h):
y = convolve(x, h)
return y

class convolveNode(lg.Node):
INPUT = lg.Topic(InputMessage)
OUTPUT = lg.Topic(OutputMessage)

def setup(self):
self.func = convolve

@lg.subscriber(INPUT)
@lg.publisher(OUTPUT)

def convolve_feature(self, message: InputMessage):
y = self.func(message.x, message.h)
yield self.OUTPUT, y
136 changes: 136 additions & 0 deletions extensions/lgpilot/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Import labgraph
import labgraph as lg
# Imports required for this example
from scipy.signal import convolve
import numpy as np

import labgraph as lg
import numpy as np
import pytest
from ...generators.sine_wave_generator import (
SineWaveChannelConfig,
SineWaveGenerator,
)

from ..mixer_one_input_node import MixerOneInputConfig, MixerOneInputNode
from ..signal_capture_node import SignalCaptureConfig, SignalCaptureNode
from ..signal_generator_node import SignalGeneratorNode


# A data type used in streaming, see docs: Messages
class InputMessage(lg.Message):
x: np.ndarray
h: np.ndarray

class OutputMessage(lg.Message):
data: np.ndarray


# ================================= CONVOLUTION ===================================


def convolve(x, h):
y = convolve(x, h)
return y

class ConvolveNode(lg.Node):
INPUT = lg.Topic(InputMessage)
OUTPUT = lg.Topic(OutputMessage)

def setup(self):
self.func = convolve

@lg.subscriber(INPUT)
@lg.publisher(OUTPUT)

def convolve_feature(self, message: InputMessage):
y = self.func(message.x, message.h)
yield self.OUTPUT, y

# ======================================================================


# class MixerOneInputConfig(lg.Config):
# # This is an NxM matrix (for M inputs, N outputs)
# weights: np.ndarray

class ConvolveInputConfig(lg.Config):
array: np.ndarray
kernel: np.ndarray


class MyGraphConfig(lg.Config):
sine_wave_channel_config: SineWaveChannelConfig
convolve_config: ConvolveInputConfig
capture_config: SignalCaptureConfig


class MyGraph(lg.Graph):

sample_source: SignalGeneratorNode
convolve_node: ConvolveNode
capture_node: SignalCaptureNode

def setup(self) -> None:
self.capture_node.configure(self.config.capture_config)
self.sample_source.set_generator(
SineWaveGenerator(self.config.sine_wave_channel_config)
)
self.convolve_node.configure(self.config.convolve_config)

def connections(self) -> lg.Connections:
return (
(self.convolve_node.INPUT, self.sample_source.SAMPLE_TOPIC),
(self.capture_node.SAMPLE_TOPIC, self.mixer_node.OUTPUT),
)


def test_convolve_input_node() -> None:
"""
Tests that node convolves correctly, uses numpy arrays and kernel sizes as input
"""

sample_rate = 1 # Hz
test_duration = 10 # sec

# Test configurations
shape = (2,)
amplitudes = np.array([5.0, 3.0])
frequencies = np.array([5, 10])
phase_shifts = np.array([1.0, 5.0])
midlines = np.array([3.0, -2.5])

test_array = [1, 2, 3]
test_kernel = [2]

# Generate expected values

expected = convolve(test_array, test_kernel) # use the convolve from the library to generate the expected values

# Create the graph
generator_config = SineWaveChannelConfig(
shape, amplitudes, frequencies, phase_shifts, midlines, sample_rate
)
capture_config = SignalCaptureConfig(int(test_duration / sample_rate))

# mixer_weights = np.identity(2)
# mixer_config = MixerOneInputConfig(mixer_weights)

convolve_input_array = [1, 2, 3]
convolve_input_kernel = [2]

convolve_config = ConvolveInputConfig(convolve_input_array, convolve_input_kernel)

my_graph_config = MyGraphConfig(generator_config, convolve_config, capture_config)

graph = MyGraph()
graph.configure(my_graph_config)

runner = lg.LocalRunner(module=graph)
runner.run()
received = np.array(graph.capture_node.samples).T
np.testing.assert_almost_equal(received, expected)

# 1. test the convolve function
# 2. create the graph and run it
# 3. repeat the same thing for other APIs -- just need to create simple test cases