Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dumping sympy equations to file. Fix to work on newer versions of python #42

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions msdsl/eqn/lds.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import numpy as np
import scipy.linalg

import sympy as sp
from msdsl.expr.extras import msdsl_ast_to_sympy
from msdsl.assignment import Assignment

class LDS:
def __init__(self, A=None, B=None, C=None, D=None):
# save settings
Expand Down Expand Up @@ -55,6 +59,31 @@ def __str__(self):

# return result
return retval

def convert_to_sympy(self, states, inputs, outputs):
state_strings = list(map(lambda x: str(x), states))
inputs_strings = list(map(lambda x: str(x), inputs))
outputs_strings = list(map(lambda x: str(x), outputs))
# Convert state, input, and output strings to sympy symbols
states = sp.symbols(state_strings)
inputs = sp.symbols(inputs_strings)
outputs = sp.symbols(outputs_strings)

# Convert numpy arrays to sympy matrices
A_sym = sp.Matrix(self.A)
B_sym = sp.Matrix(self.B)
C_sym = sp.Matrix(self.C)
D_sym = sp.Matrix(self.D)

# Define state-space equations
state_eq = A_sym * sp.Matrix(states) + B_sym * sp.Matrix(inputs)
output_eq = C_sym * sp.Matrix(states) + D_sym * sp.Matrix(inputs)

# Explicitly compute the derivatives (dot{x})
state_ode = sp.Matrix([sp.diff(state, 't') for state in states]) - state_eq

return state_ode, output_eq


class LdsCollection:
def __init__(self):
Expand All @@ -75,3 +104,68 @@ def append(self, lds: LDS):
self.B = np.concatenate((self.B, B), axis=2) if self.B is not None else B
self.C = np.concatenate((self.C, C), axis=2) if self.C is not None else C
self.D = np.concatenate((self.D, D), axis=2) if self.D is not None else D


def convert_to_sympy_piecewise(self, states, inputs, outputs, sel_bits, sel_eqns):
# Convert states, inputs, and outputs to sympy symbols
state_strings = list(map(str, states))
inputs_strings = list(map(str, inputs))
outputs_strings = list(map(str, outputs))

states = sp.symbols(state_strings)
inputs = sp.symbols(inputs_strings)
outputs = sp.symbols(outputs_strings)

# Convert sel_bits to SymPy symbols if they aren't already
sel_bits_sympy = [sp.Symbol(str(sel_bit)) if not isinstance(sel_bit, sp.Basic) else sel_bit for sel_bit in sel_bits]

# Initialize lists for state and output equations with default expressions
state_eq_piecewise = [sp.Piecewise((0, True)) for _ in range(len(states))]
output_eq_piecewise = [sp.Piecewise((0, True)) for _ in range(len(outputs))]

# Iterate over all possible configurations of sel_bits

for k in range(self.A.shape[2]): # Number of scenarios
A_sym = sp.Matrix(self.A[:, :, k])
B_sym = sp.Matrix(self.B[:, :, k])
C_sym = sp.Matrix(self.C[:, :, k])
D_sym = sp.Matrix(self.D[:, :, k])

# Define state-space equations
state_eq = A_sym * sp.Matrix(states) + B_sym * sp.Matrix(inputs)
output_eq = C_sym * sp.Matrix(states) + D_sym * sp.Matrix(inputs)

# Compute derivatives (dot{x}) for the state equations
state_ode = sp.Matrix([sp.diff(state, 't') for state in states]) - state_eq

# Create the condition for this scenario
condition = True
for i, sel_bit_sym in enumerate(sel_bits_sympy):
bit_value = (k >> i) & 1
# Use logical AND to build up the condition
condition = sp.And(condition, sp.Eq(sel_bit_sym, bit_value))


# Process sel_eqns to adjust the condition if needed
for sel_eqn in sel_eqns:


if isinstance(sel_eqn, Assignment):
signal = sp.Symbol(sel_eqn.signal.name)
expr = msdsl_ast_to_sympy(sel_eqn.expr)

condition = condition.subs(signal, expr)

# Assign the corresponding equations to the Piecewise objects
for i in range(len(states)):
state_eq_piecewise[i] = sp.Eq( sp.Derivative(states[i],sp.Symbol('t')), sp.Piecewise((state_ode[i], condition), (state_eq_piecewise[i], True)))

for i in range(len(outputs)):
output_eq_piecewise[i] = sp.Eq(outputs[i], sp.Piecewise((output_eq[i], condition), (output_eq_piecewise[i], True)))

# For each equation we need to substitute in
return state_eq_piecewise + output_eq_piecewise




4 changes: 4 additions & 0 deletions msdsl/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from msdsl.expr.format import RealFormat, SIntFormat, UIntFormat, Format, IntFormat

import sympy as sp
# constant wrapping

def wrap_constant(operand):
Expand Down Expand Up @@ -1039,6 +1040,9 @@ def mt19937(clk=None, rst=None, cke=None, seed=None):
def lcg_op(clk=None, rst=None, cke=None, seed=None):
return LCG(clk=clk, rst=rst, cke=cke, seed=seed)




# testing

def main():
Expand Down
38 changes: 36 additions & 2 deletions msdsl/expr/extras.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Union, List
from numbers import Number, Integral
from msdsl.expr.expr import ModelExpr, concatenate, BitwiseAnd, array
from msdsl.expr.expr import ModelExpr, concatenate, BitwiseAnd, array, LessThan, GreaterThan, Product, Sum, EqualTo, Array, Constant
from msdsl.expr.signals import AnalogSignal, DigitalSignal

import sympy as sp

def all_between(x: List[ModelExpr], lo: Union[Number, ModelExpr], hi: Union[Number, ModelExpr]) -> ModelExpr:
"""
Expand Down Expand Up @@ -37,4 +40,35 @@ def if_(condition, then, else_):
:param else_: Action to be executed for False case
:return: Boolean
"""
return array([else_, then], condition)
return array([else_, then], condition)

def msdsl_ast_to_sympy(ast):
"""
Convert an AST from msdsl to a sympy expression.
"""
if isinstance(ast, LessThan):
return sp.Lt(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs))
elif isinstance(ast, GreaterThan):
return sp.Gt(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs))
elif isinstance(ast, Product):
accum = 1 # Corrected from 0 to 1 to properly accumulate products
for operand in ast.operands:
accum *= msdsl_ast_to_sympy(operand)
return accum
elif isinstance(ast, Sum):
accum = 0
for operand in ast.operands:
accum += msdsl_ast_to_sympy(operand)
return accum
elif isinstance(ast, EqualTo):
return sp.Eq(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs))
elif isinstance(ast, Array):
elements = ast.operands[:-1]
address = msdsl_ast_to_sympy(ast.operands[-1])
return sp.Piecewise(*[(msdsl_ast_to_sympy(elem), address if i == 1 else sp.Not(address)) for i, elem in enumerate(elements)])
elif isinstance(ast, Constant):
return ast.value
elif isinstance(ast, AnalogSignal) or isinstance(ast, DigitalSignal):
return sp.Symbol(str(ast))
else:
raise Exception(f"Unsupported AST node: {type(ast)}")
76 changes: 72 additions & 4 deletions msdsl/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import OrderedDict, Iterable
from collections import OrderedDict
from collections.abc import Iterable
from itertools import chain
from numbers import Integral, Number
from typing import List, Set, Union
Expand Down Expand Up @@ -30,8 +31,12 @@
from msdsl.function import GeneralFunction, Function, PlaceholderFunction, MultiFunction
from msdsl.lfsr import LFSR

from msdsl.expr.extras import msdsl_ast_to_sympy

from scipy.signal import cont2discrete

import sympy as sp

class Bus:
def __init__(self, signal: Signal, n: Integral):
self.signal = signal
Expand Down Expand Up @@ -778,7 +783,7 @@ def get_equation_io(self, eqn_sys: EqnSys):
# determine sel_bits
sel_bit_names = set(signal_names(eqn_sys.get_sel_bits()))
sel_bits = self.get_signals(sel_bit_names)

# return result
return inputs, states, outputs, sel_bits

Expand All @@ -794,8 +799,12 @@ def add_eqn_sys(self, eqns: List[ModelExpr], extra_outputs=None, clk=None, rst=N
:param extra_outputs: List of internal variables in the system of equations that should be bound to analog signals.
:param clk: Name of clock signal to use (None will default to `CLK_MSDSL)
:param rst: Name of the reset signal to use (None will default to `RST_MSDSL)

Returns an LDSCollection Object
"""



# set defaults
extra_outputs = extra_outputs if extra_outputs is not None else []

Expand All @@ -804,7 +813,7 @@ def add_eqn_sys(self, eqns: List[ModelExpr], extra_outputs=None, clk=None, rst=N

# analyze equation to find out knowns and unknowns
inputs, states, outputs, sel_bits = self.get_equation_io(eqn_sys)

# add the extra outputs as needed
for extra_output in extra_outputs:
if not isinstance(extra_output, Signal):
Expand Down Expand Up @@ -832,6 +841,8 @@ def add_eqn_sys(self, eqns: List[ModelExpr], extra_outputs=None, clk=None, rst=N

# add to collection of LDS systems
collection.append(lds)



# construct address for selection
if len(sel_bits) > 0:
Expand All @@ -843,6 +854,10 @@ def add_eqn_sys(self, eqns: List[ModelExpr], extra_outputs=None, clk=None, rst=N
self.add_discrete_time_lds(collection=collection, inputs=inputs,
states=states, outputs=outputs, sel=sel,
clk=clk, rst=rst)

return collection, inputs, states, outputs, sel_bits



def add_discrete_time_lds(self, collection, inputs=None, states=None, outputs=None, sel=None,
clk=None, rst=None):
Expand Down Expand Up @@ -870,6 +885,45 @@ def add_discrete_time_lds(self, collection, inputs=None, states=None, outputs=No
else:
self.bind_name(outputs[row].name, expr)

def get_sympy_lds(self):
"""
Must be run after add_eqn_sys.
Returns an array of piecewise LDSs in the form of sympy equations.
Appends all assignments.
"""

sympy_lds = []

#self.assignments

for circuit in self.circuits:
filtered_states = list(filter(lambda x: not isinstance(x, DigitalInput), circuit.sel_bits))
state_str = list(map(lambda x: str(x), filtered_states))

sel_eqns = self.get_assignments(state_str)


circuit_lds = circuit.collection.convert_to_sympy_piecewise(circuit.states, circuit.inputs, circuit.outputs, circuit.sel_bits, sel_eqns)
circuit.sympy_eqs = circuit_lds
sympy_lds += circuit_lds

for symbol_name, assignment_obj in self.unmodified_assignments.items():
sympy_lds.append(sp.Eq(sp.Symbol(symbol_name), msdsl_ast_to_sympy(assignment_obj.expr)))

return sympy_lds

def write_sympy_lds_to_file(self, filename):
"""
Must be run after compile().
Writes the sympy equations to json format.
"""

sympy_lds = self.diffeqs

with open(filename, 'w') as f:
for eq in sympy_lds:
f.write(str(eq) + '\n')

def set_tf(self, input_: Signal, output: Signal, tf, clk=None, rst=None):
"""
Method to assign an output signal as a function of the input signal by applying a given transfer function.
Expand Down Expand Up @@ -1012,10 +1066,24 @@ def make_circuit(self, clk=None, rst=None):

def compile(self, gen: CodeGenerator):
# compile circuits
self.unmodified_assignments = deepcopy(self.assignments)

for circuit in self.circuits:
eqns = circuit.compile_to_eqn_list()
self.add_eqn_sys(eqns, circuit.extra_outputs, clk=circuit.clk, rst=circuit.rst)

ldscollection, inputs, states, outputs, sel_bits = self.add_eqn_sys(eqns, circuit.extra_outputs, clk=circuit.clk, rst=circuit.rst)

circuit.collection = ldscollection #Assign the circuit.collection object so we may use it later.
circuit.inputs = inputs
circuit.states = states
circuit.outputs = outputs
circuit.sel_bits = sel_bits



self.diffeqs = self.get_sympy_lds()


# determine the I/Os and internal variables
ios = []
internals = []
Expand Down
35 changes: 34 additions & 1 deletion msdsl/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,42 @@ def warn(s):
def list2dict(l):
return {elem: k for k, elem in enumerate(l)}


def msdsl_ast_to_sympy(ast):
"""
Convert an AST from msdsl to a sympy expression.
"""
if isinstance(ast, LessThan):
return sp.Lt(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs))
elif isinstance(ast, GreaterThan):
return sp.Gt(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs))
elif isinstance(ast, Product):
accum = 1 # Corrected from 0 to 1 to properly accumulate products
for operand in ast.operands:
accum *= msdsl_ast_to_sympy(operand)
return accum
elif isinstance(ast, Sum):
accum = 0
for operand in ast.operands:
accum += msdsl_ast_to_sympy(operand)
return accum
elif isinstance(ast, EqualTo):
return sp.Eq(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs))
elif isinstance(ast, Array):
elements = ast.operands[:-1]
address = msdsl_ast_to_sympy(ast.operands[-1])
return sp.Piecewise(*[(msdsl_ast_to_sympy(elem), address if i == 1 else sp.Not(address)) for i, elem in enumerate(elements)])
elif isinstance(ast, Constant):
return ast.value
elif isinstance(ast, AnalogSignal) or isinstance(ast, DigitalSignal):
return sp.Symbol(str(ast))
else:
raise Exception(f"Unsupported AST node: {type(ast)}")

def main():
# list2dict tests
print(list2dict(['a', 'b', 'c']))

if __name__ == '__main__':
main()
main()