Skip to content

Commit

Permalink
* changed highs_linear_expression to be immutable by default
Browse files Browse the repository at this point in the history
* added __repr__ for easier debugging and pretty print linear expressions
* added __iadd__ for mutable operations
* added qsum for faster aggregation
* updated chained comparison support to work in immutable setting
  • Loading branch information
mathgeekcoder committed Aug 23, 2024
1 parent ddd3778 commit b4a0942
Showing 1 changed file with 161 additions and 28 deletions.
189 changes: 161 additions & 28 deletions src/highspy/highs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from itertools import groupby, product
from operator import itemgetter
from decimal import Decimal
from threading import local

class Highs(_Highs):
"""HiGHS solver interface"""
Expand Down Expand Up @@ -521,7 +522,7 @@ def addConstr(self, cons, name=None):
super().passRowName(con.index, name)

return con

def addConstrs(self, *args, **kwargs):
"""Adds multiple constraints to the model.
Expand Down Expand Up @@ -648,6 +649,21 @@ def setContinuous(self, var):
"""
super().changeColIntegrality(var.index, HighsVarType.kContinuous)

@staticmethod
def qsum(items):
"""Performs a faster sum for highs_linear_expressions.
Args:
items: A collection of highs_linear_expressions or highs_vars to be summed.
"""
expr = highs_linear_expression()

for v in items:
expr += v

return expr


## The following classes keep track of variables
## It is currently quite basic and may fail in complex scenarios

Expand Down Expand Up @@ -760,25 +776,39 @@ def __init__(self, other=None):
def __neg__(self):
return -1.0 * self

def __repr__(self):
# if we have duplicate variables, add the vals together
agg = [(var, sum(v[1] for v in Vals)) for var, Vals in groupby(sorted(zip(self.vars, self.vals)), key=itemgetter(0))]

v = str.join(" ", [f"{c}_v{x}" for x,c in agg])
c = f" {self.constant}" if abs(self.constant) > 1e-6 else ''

if self.LHS == self.RHS:
return f"{v}{c} == {self.LHS}"
else:
return f"{self.LHS} <= {v}{c} <= {self.RHS}"

# (LHS <= self <= RHS) <= (other.LHS <= other <= other.RHS)
def __le__(self, other):
if isinstance(other, highs_linear_expression):
if self.LHS != -kHighsInf and self.RHS != kHighsInf and len(other.vars) > 0 or other.LHS != -kHighsInf:
raise Exception('Cannot construct constraint with variables as bounds.')

# move variables from other to self
self.vars.extend(other.vars)
self.vals.extend([-1.0 * v for v in other.vals])
self.constant -= other.constant
self.RHS = 0
return self
# move variables from other to copy
copy = self.__clone_for_inequality_chain()
copy.vars.extend(other.vars)
copy.vals.extend([-1.0 * v for v in other.vals])
copy.constant -= other.constant
copy.RHS = 0
return copy

elif isinstance(other, highs_var):
return NotImplemented

elif isinstance(other, (int, float, Decimal)):
self.RHS = min(self.RHS, other)
return self
copy = self.__clone_for_inequality_chain()
copy.RHS = min(copy.RHS, other)
return copy

else:
return NotImplemented
Expand All @@ -789,13 +819,14 @@ def __eq__(self, other):
if self.LHS != -kHighsInf and len(other.vars) > 0 or other.LHS != -kHighsInf:
raise Exception('Cannot construct constraint with variables as bounds.')

# move variables from other to self
self.vars.extend(other.vars)
self.vals.extend([-1.0 * v for v in other.vals])
self.constant -= other.constant
self.LHS = 0
self.RHS = 0
return self
# move variables from other to copy
copy = self.__clone_for_inequality_chain()
copy.vars.extend(other.vars)
copy.vals.extend([-1.0 * v for v in other.vals])
copy.constant -= other.constant
copy.LHS = 0
copy.RHS = 0
return copy

elif isinstance(other, highs_var):
return NotImplemented
Expand All @@ -804,9 +835,20 @@ def __eq__(self, other):
if self.LHS != -kHighsInf or self.RHS != kHighsInf:
raise Exception('Logic error in constraint equality.')

self.LHS = other
self.RHS = other
return self
copy = self.__clone_for_inequality_chain()
copy.LHS = other
copy.RHS = other
return copy

# support expr == [lb, ub] --> lb <= expr <= ub
elif hasattr(other, "__getitem__") and hasattr(other, "__len__") and len(other) == 2:
if not (isinstance(other[0], (int, float, Decimal)) and isinstance(other[1], (int, float, Decimal))):
raise Exception('Provided bounds were not valid numbers.')

copy = self.__clone_for_inequality_chain()
copy.LHS = other[0]
copy.RHS = other[1]
return copy

else:
return NotImplemented
Expand All @@ -820,17 +862,18 @@ def __ge__(self, other):
return NotImplemented

elif isinstance(other, (int, float, Decimal)):
self.LHS = max(self.LHS, other)
return self
copy = self.__clone_for_inequality_chain()
copy.LHS = max(copy.LHS, other)
return copy

else:
return NotImplemented

def __radd__(self, other):
return self + other

# (LHS <= self <= RHS) + (LHS <= other <= RHS)
def __add__(self, other):
# (LHS <= self <= RHS) += (LHS <= other <= RHS)
def __iadd__(self, other):
if isinstance(other, highs_linear_expression):
self.vars.extend(other.vars)
self.vals.extend(other.vals)
Expand All @@ -851,16 +894,41 @@ def __add__(self, other):
else:
return NotImplemented

# (LHS <= self <= RHS) + (LHS <= other <= RHS)
def __add__(self, other):
if isinstance(other, highs_linear_expression):
copy = highs_linear_expression(self)
copy.vars.extend(other.vars)
copy.vals.extend(other.vals)
copy.constant += other.constant
copy.LHS = max(copy.LHS, other.LHS)
copy.RHS = min(copy.RHS, other.RHS)
return copy

elif isinstance(other, highs_var):
copy = highs_linear_expression(self)
copy.vars.append(other.index)
copy.vals.append(1.0)
return copy

elif isinstance(other, (int, float, Decimal)):
copy = highs_linear_expression(self)
copy.constant += other
return copy

else:
return NotImplemented

def __rmul__(self, other):
return self * other

def __mul__(self, other):
result = highs_linear_expression(self)

if isinstance(other, (int, float, Decimal)):
result.vals = [float(other) * v for v in self.vals]
result.constant *= float(other)
return result
copy = highs_linear_expression(self)
copy.vals = [float(other) * v for v in self.vals]
copy.constant *= float(other)
return copy

elif isinstance(other, highs_var):
raise Exception('Only linear expressions are allowed.')
else:
Expand All @@ -878,3 +946,68 @@ def __sub__(self, other):
return self + (-1.0 * other)
else:
return NotImplemented


# The following is needed to support chained comparison, i.e., lb <= expr <= ub. This is interpreted
# as '__bool__(lb <= expr) and (expr <= ub)'; returning (expr <= ub), since __bool__(lb <= expr) == True.
#
# We essentially want to "rewrite" this as '(lb <= expr) <= ub', while keeping the expr instance immutable.
# As a slight hack, we can use a shared (thread local) object to keep track of the chain.
#
# Whenever we perform an inequality, we first check if the current expression ('self') is part of a chain.
# If it is, we copy the inner '__chain' expression rather than 'self'.
#
# This inner '__chain' is set by __bool__(lb <= expr) and is reset after the inequality is evaluated.
#
# Two potential issues:
# 1. It is possible to manually construct this sequence, e.g.,
# tmp = x0 + x1
# bool(tmp <= 10)
# print(5 <= tmp) # outputs: 5 <= 1.0_v0 + 1.0_v1 <= 10
# print(5 <= tmp) # outputs: 5 <= 1.0_v0 + 1.0_v1 <= inf : chain is broken
#
# Note that:
# bool(tmp <= 10)
# tmp += x0 + x1 # changes tmp, so the chain is broken
# print(5 <= tmp) # outputs: 5 <= 2.0_v0 + 2.0_v1 <= inf
#
# 2. The chain might "break" if run within a debugger (on same thread), i.e., "watched debugger expressions"
# that evaluate any variant of highs_linear_expression inequalities.
#
# I believe these issues are low risk, the approach thread safe, and the performance/overhead is minimal.
#
__chain = local()

# capture the chain
def __bool__(self):
highs_linear_expression.__chain.inner = self
return True

# double check if original expr hasn't been modified
def __is_equal_except_bounds(self, other):
if isinstance(other, highs_linear_expression):
return self.vars == other.vars and self.vals == other.vals and self.constant == other.constant
else:
return False

# clone the current expression, except if we are within an active comparison chain
def __clone_for_inequality_chain(self):
# get thread local chain status
check = getattr(highs_linear_expression.__chain, 'check', None)
inner = getattr(highs_linear_expression.__chain, 'inner', None)
is_active_chain = check is self and self.__is_equal_except_bounds(inner)

# reset the chain
highs_linear_expression.__chain.inner = None
highs_linear_expression.__chain.check = self

return highs_linear_expression(inner if is_active_chain else self)

@staticmethod
def qsum(items):
"""Performs a faster sum for highs_linear_expressions.
Args:
items: A collection of highs_linear_expressions or highs_vars to be summed.
"""
return Highs.qsum(items)

0 comments on commit b4a0942

Please sign in to comment.