From b4a0942ef0f2906e8e90a3979dc0b552f5dda23e Mon Sep 17 00:00:00 2001 From: Luke Marshall <52978038+mathgeekcoder@users.noreply.github.com> Date: Thu, 22 Aug 2024 17:38:11 -0700 Subject: [PATCH] * changed highs_linear_expression to be immutable by default * added __repr__ for easier debugging and pretty print linear expressions * added __iadd__ for mutable operations * added qsum for faster aggregation * updated chained comparison support to work in immutable setting --- src/highspy/highs.py | 189 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 161 insertions(+), 28 deletions(-) diff --git a/src/highspy/highs.py b/src/highspy/highs.py index 4d4f4e05c2..5a421cf1e9 100644 --- a/src/highspy/highs.py +++ b/src/highspy/highs.py @@ -36,6 +36,7 @@ from itertools import groupby, product from operator import itemgetter from decimal import Decimal +from threading import local class Highs(_Highs): """HiGHS solver interface""" @@ -521,7 +522,7 @@ def addConstr(self, cons, name=None): super().passRowName(con.index, name) return con - + def addConstrs(self, *args, **kwargs): """Adds multiple constraints to the model. @@ -648,6 +649,21 @@ def setContinuous(self, var): """ super().changeColIntegrality(var.index, HighsVarType.kContinuous) + @staticmethod + def qsum(items): + """Performs a faster sum for highs_linear_expressions. + + Args: + items: A collection of highs_linear_expressions or highs_vars to be summed. + """ + expr = highs_linear_expression() + + for v in items: + expr += v + + return expr + + ## The following classes keep track of variables ## It is currently quite basic and may fail in complex scenarios @@ -760,25 +776,39 @@ def __init__(self, other=None): def __neg__(self): return -1.0 * self + def __repr__(self): + # if we have duplicate variables, add the vals together + agg = [(var, sum(v[1] for v in Vals)) for var, Vals in groupby(sorted(zip(self.vars, self.vals)), key=itemgetter(0))] + + v = str.join(" ", [f"{c}_v{x}" for x,c in agg]) + c = f" {self.constant}" if abs(self.constant) > 1e-6 else '' + + if self.LHS == self.RHS: + return f"{v}{c} == {self.LHS}" + else: + return f"{self.LHS} <= {v}{c} <= {self.RHS}" + # (LHS <= self <= RHS) <= (other.LHS <= other <= other.RHS) def __le__(self, other): if isinstance(other, highs_linear_expression): if self.LHS != -kHighsInf and self.RHS != kHighsInf and len(other.vars) > 0 or other.LHS != -kHighsInf: raise Exception('Cannot construct constraint with variables as bounds.') - # move variables from other to self - self.vars.extend(other.vars) - self.vals.extend([-1.0 * v for v in other.vals]) - self.constant -= other.constant - self.RHS = 0 - return self + # move variables from other to copy + copy = self.__clone_for_inequality_chain() + copy.vars.extend(other.vars) + copy.vals.extend([-1.0 * v for v in other.vals]) + copy.constant -= other.constant + copy.RHS = 0 + return copy elif isinstance(other, highs_var): return NotImplemented elif isinstance(other, (int, float, Decimal)): - self.RHS = min(self.RHS, other) - return self + copy = self.__clone_for_inequality_chain() + copy.RHS = min(copy.RHS, other) + return copy else: return NotImplemented @@ -789,13 +819,14 @@ def __eq__(self, other): if self.LHS != -kHighsInf and len(other.vars) > 0 or other.LHS != -kHighsInf: raise Exception('Cannot construct constraint with variables as bounds.') - # move variables from other to self - self.vars.extend(other.vars) - self.vals.extend([-1.0 * v for v in other.vals]) - self.constant -= other.constant - self.LHS = 0 - self.RHS = 0 - return self + # move variables from other to copy + copy = self.__clone_for_inequality_chain() + copy.vars.extend(other.vars) + copy.vals.extend([-1.0 * v for v in other.vals]) + copy.constant -= other.constant + copy.LHS = 0 + copy.RHS = 0 + return copy elif isinstance(other, highs_var): return NotImplemented @@ -804,9 +835,20 @@ def __eq__(self, other): if self.LHS != -kHighsInf or self.RHS != kHighsInf: raise Exception('Logic error in constraint equality.') - self.LHS = other - self.RHS = other - return self + copy = self.__clone_for_inequality_chain() + copy.LHS = other + copy.RHS = other + return copy + + # support expr == [lb, ub] --> lb <= expr <= ub + elif hasattr(other, "__getitem__") and hasattr(other, "__len__") and len(other) == 2: + if not (isinstance(other[0], (int, float, Decimal)) and isinstance(other[1], (int, float, Decimal))): + raise Exception('Provided bounds were not valid numbers.') + + copy = self.__clone_for_inequality_chain() + copy.LHS = other[0] + copy.RHS = other[1] + return copy else: return NotImplemented @@ -820,8 +862,9 @@ def __ge__(self, other): return NotImplemented elif isinstance(other, (int, float, Decimal)): - self.LHS = max(self.LHS, other) - return self + copy = self.__clone_for_inequality_chain() + copy.LHS = max(copy.LHS, other) + return copy else: return NotImplemented @@ -829,8 +872,8 @@ def __ge__(self, other): def __radd__(self, other): return self + other - # (LHS <= self <= RHS) + (LHS <= other <= RHS) - def __add__(self, other): + # (LHS <= self <= RHS) += (LHS <= other <= RHS) + def __iadd__(self, other): if isinstance(other, highs_linear_expression): self.vars.extend(other.vars) self.vals.extend(other.vals) @@ -851,16 +894,41 @@ def __add__(self, other): else: return NotImplemented + # (LHS <= self <= RHS) + (LHS <= other <= RHS) + def __add__(self, other): + if isinstance(other, highs_linear_expression): + copy = highs_linear_expression(self) + copy.vars.extend(other.vars) + copy.vals.extend(other.vals) + copy.constant += other.constant + copy.LHS = max(copy.LHS, other.LHS) + copy.RHS = min(copy.RHS, other.RHS) + return copy + + elif isinstance(other, highs_var): + copy = highs_linear_expression(self) + copy.vars.append(other.index) + copy.vals.append(1.0) + return copy + + elif isinstance(other, (int, float, Decimal)): + copy = highs_linear_expression(self) + copy.constant += other + return copy + + else: + return NotImplemented + def __rmul__(self, other): return self * other def __mul__(self, other): - result = highs_linear_expression(self) - if isinstance(other, (int, float, Decimal)): - result.vals = [float(other) * v for v in self.vals] - result.constant *= float(other) - return result + copy = highs_linear_expression(self) + copy.vals = [float(other) * v for v in self.vals] + copy.constant *= float(other) + return copy + elif isinstance(other, highs_var): raise Exception('Only linear expressions are allowed.') else: @@ -878,3 +946,68 @@ def __sub__(self, other): return self + (-1.0 * other) else: return NotImplemented + + + # The following is needed to support chained comparison, i.e., lb <= expr <= ub. This is interpreted + # as '__bool__(lb <= expr) and (expr <= ub)'; returning (expr <= ub), since __bool__(lb <= expr) == True. + # + # We essentially want to "rewrite" this as '(lb <= expr) <= ub', while keeping the expr instance immutable. + # As a slight hack, we can use a shared (thread local) object to keep track of the chain. + # + # Whenever we perform an inequality, we first check if the current expression ('self') is part of a chain. + # If it is, we copy the inner '__chain' expression rather than 'self'. + # + # This inner '__chain' is set by __bool__(lb <= expr) and is reset after the inequality is evaluated. + # + # Two potential issues: + # 1. It is possible to manually construct this sequence, e.g., + # tmp = x0 + x1 + # bool(tmp <= 10) + # print(5 <= tmp) # outputs: 5 <= 1.0_v0 + 1.0_v1 <= 10 + # print(5 <= tmp) # outputs: 5 <= 1.0_v0 + 1.0_v1 <= inf : chain is broken + # + # Note that: + # bool(tmp <= 10) + # tmp += x0 + x1 # changes tmp, so the chain is broken + # print(5 <= tmp) # outputs: 5 <= 2.0_v0 + 2.0_v1 <= inf + # + # 2. The chain might "break" if run within a debugger (on same thread), i.e., "watched debugger expressions" + # that evaluate any variant of highs_linear_expression inequalities. + # + # I believe these issues are low risk, the approach thread safe, and the performance/overhead is minimal. + # + __chain = local() + + # capture the chain + def __bool__(self): + highs_linear_expression.__chain.inner = self + return True + + # double check if original expr hasn't been modified + def __is_equal_except_bounds(self, other): + if isinstance(other, highs_linear_expression): + return self.vars == other.vars and self.vals == other.vals and self.constant == other.constant + else: + return False + + # clone the current expression, except if we are within an active comparison chain + def __clone_for_inequality_chain(self): + # get thread local chain status + check = getattr(highs_linear_expression.__chain, 'check', None) + inner = getattr(highs_linear_expression.__chain, 'inner', None) + is_active_chain = check is self and self.__is_equal_except_bounds(inner) + + # reset the chain + highs_linear_expression.__chain.inner = None + highs_linear_expression.__chain.check = self + + return highs_linear_expression(inner if is_active_chain else self) + +@staticmethod +def qsum(items): + """Performs a faster sum for highs_linear_expressions. + + Args: + items: A collection of highs_linear_expressions or highs_vars to be summed. + """ + return Highs.qsum(items)