-
Notifications
You must be signed in to change notification settings - Fork 0
/
LinearExpression.py
149 lines (115 loc) · 3.83 KB
/
LinearExpression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from decimal import Decimal, getcontext
from constants import ALMOST_ZERO, DECIMAL_PRECISION
import copy
# Set the working precision (number of significant digits) of the current
# decimal context to the project-wide DECIMAL_PRECISION constant. This runs
# at import time, so it affects Decimal arithmetic performed after this
# module has been imported.
getcontext().prec = DECIMAL_PRECISION
def formatDecimal(d):
    """Return ``d`` as a signed, fixed-point string without padding.

    The value is first rendered with an explicit sign and six decimal
    places; trailing zeros are then removed, followed by a trailing
    decimal point if nothing is left after it (e.g. ``+2.500000`` becomes
    ``+2.5`` and ``-3.000000`` becomes ``-3``).
    """
    fixed_point = '{0:+.6f}'.format(d)
    without_zeros = fixed_point.rstrip('0')
    return without_zeros.rstrip('.')
def renderTerm(v, coe):
    """Render one term of an expression as a signed string fragment.

    ``v`` is the variable name, with ``'1'`` standing for the constant
    term; ``coe`` is its coefficient.

    * A coefficient whose magnitude is below ``ALMOST_ZERO`` renders as
      the empty string.
    * The constant term renders as just the formatted coefficient.
    * A coefficient of exactly +1 or -1 renders as the bare sign followed
      by the variable (``+x`` / ``-x``).
    * Anything else renders as the formatted coefficient followed by the
      variable.
    """
    magnitude = abs(coe)
    if magnitude < ALMOST_ZERO:
        return ''
    if v == '1':
        return formatDecimal(coe)
    if magnitude != 1:
        return formatDecimal(coe) + v
    # Unit coefficient: show only the sign, not the digit 1.
    sign = '+' if coe > 0 else '-'
    return sign + v
class LinearExpression:
    """A linear expression stored as a ``{variable: coefficient}`` mapping.

    The special key ``'1'`` holds the constant term. Expressions built by
    ``parse`` use single-letter variable names and ``Decimal`` coefficients.
    Arithmetic operators return new expressions and leave their operands
    untouched.
    """

    def __init__(self, terms):
        """Wrap ``terms``; the mapping is stored by reference, not copied."""
        self.terms = terms

    @classmethod
    def from_string(cls, s):
        """Build an expression from text such as ``'2x - y + 3'``."""
        return cls(terms=cls.parse(s))

    @staticmethod
    def parse(s):
        """Parse ``s`` into a ``{variable: Decimal}`` mapping.

        Spaces are ignored. A term ending in a letter is a variable term
        (the letter is the variable, everything before it the
        coefficient); any other term is a constant stored under ``'1'``.
        A missing coefficient means 1 and a bare ``-`` means -1. Terms
        whose coefficient is numerically zero are skipped, and repeated
        variables are summed.

        ``None``, the empty string and blank strings yield an empty
        mapping. A coefficient that is not a valid decimal number makes
        ``Decimal`` raise ``decimal.InvalidOperation``.
        """
        terms = {}
        if s is None:
            return terms
        # Turn every subtraction into "add a negative term" so that a
        # single split on '+' separates all terms.
        s = s.replace(' ', '').replace('-', '+-')
        for term in s.split('+'):
            if not term:
                # Empty pieces come from an empty input, a leading sign
                # (e.g. '-x' became '+-x') or a stray/trailing '+'.
                continue
            if term[-1].isalpha():
                var, coe = term[-1], term[:-1]
            else:
                var, coe = '1', term
            if coe == '':
                coe = '1'
            elif coe == '-':
                coe = '-1'
            value = Decimal(coe)
            if value == 0:
                # Compare numerically so '0', '0.0', '00' and '-0' are
                # all treated as zero terms.
                continue
            if var in terms:
                terms[var] += value
            else:
                terms[var] = value
        return terms

    def __str__(self):
        """Render the expression, variables first (sorted), constant last.

        An expression with nothing to render is shown as ``'0'``; a
        leading ``+`` sign is dropped.
        """
        result = ''.join(
            renderTerm(var, self.terms[var]) for var in self.sorted_keys()
        )
        if result == '':
            return '0'
        if result.startswith('+'):
            result = result[1:]
        return result

    def __repr__(self):
        return "LinearExpression('{0}')".format(self.__str__())

    def __eq__(self, exp2):
        """Compare two expressions term by term within ``ALMOST_ZERO``.

        Both expressions must have exactly the same keys, and every pair
        of coefficients must differ by no more than ``ALMOST_ZERO``.
        Comparison with anything that is not a ``LinearExpression``
        returns ``NotImplemented`` so Python can fall back to its default
        handling instead of failing with ``AttributeError``.
        """
        if not isinstance(exp2, LinearExpression):
            return NotImplemented
        if self.sorted_keys() != exp2.sorted_keys():
            return False
        return all(
            abs(coe - exp2.terms[v]) <= ALMOST_ZERO
            for v, coe in self.terms.items()
        )

    # Instances are mutable and define __eq__, so they are deliberately
    # unhashable (this is also what Python does implicitly).
    __hash__ = None

    def vars(self):
        """Return the set of variable names, excluding the constant key."""
        return {var for var in self.terms if var != '1'}

    def sorted_keys(self):
        """Return the keys sorted, with the constant key ``'1'`` moved last."""
        keys = sorted(self.terms)
        if keys and keys[0] == '1':
            # '1' sorts before any letter; rotate it to the end.
            del keys[0]
            keys.append('1')
        return keys

    def clone(self):
        """Return a new expression holding a deep copy of ``terms``."""
        return LinearExpression(terms=copy.deepcopy(self.terms))

    def __add__(self, e2):
        """Return the sum of two expressions with near-zero terms removed."""
        result = self.clone()
        for var, coe in e2.terms.items():
            if var in result.terms:
                result.terms[var] += coe
            else:
                result.terms[var] = coe
        result.remove_zero_terms()
        return result

    def __sub__(self, e2):
        """Return ``self - e2``, computed as ``self + (e2 * -1)``."""
        return self + (e2 * -1)

    def __mul__(self, scalar):
        """Return a copy with every coefficient multiplied by ``scalar``.

        Zero-valued results are kept in ``terms``; they are not pruned.
        """
        result = self.clone()
        for var in result.terms:
            result.terms[var] *= scalar
        return result

    def __truediv__(self, scalar):
        """Return a copy with every coefficient divided by ``scalar``."""
        result = self.clone()
        for var in result.terms:
            result.terms[var] /= scalar
        return result

    def remove_zero_terms(self):
        """Drop, in place, every term whose magnitude is not above ``ALMOST_ZERO``."""
        self.terms = {
            var: coe
            for var, coe in self.terms.items()
            if abs(coe) > ALMOST_ZERO
        }

    def substitute(self, var, expression):
        """Return a copy of this expression with ``var`` replaced by ``expression``.

        If ``var`` is absent the copy is returned unchanged. If its
        coefficient is below ``ALMOST_ZERO`` the term is simply dropped.
        Otherwise each term of ``expression`` is scaled by the
        coefficient of ``var`` and merged in, after which near-zero terms
        are removed. ``self`` and ``expression`` are not modified.
        """
        eq = self.clone()
        if var not in eq.terms:
            return eq
        coe = eq.terms.pop(var)
        if abs(coe) < ALMOST_ZERO:
            return eq
        for v, c in expression.terms.items():
            if v in eq.terms:
                eq.terms[v] += c * coe
            else:
                eq.terms[v] = c * coe
        eq.remove_zero_terms()
        # An empty ``terms`` mapping already represents the zero expression.
        return eq