Skip to content

Commit

Permalink
Make sure variable types are hashable + sortable
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Aug 14, 2024
1 parent 548ff1f commit 0231cff
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions src/jaxls/_variables.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from functools import total_ordering
from typing import Any, Callable, ClassVar, Iterable, Literal, cast, overload

import jax
Expand All @@ -10,6 +11,24 @@
from jax import numpy as jnp


@total_ordering
class _HashableSortableMeta(type):
"""We use variable types as dictionary keys. This metaclass makes sure that
the types themselves can be hashed and ordered.
Relevant: https://github.com/google/jax/issues/15358
"""

def __hash__(cls):
return object.__hash__(cls)

def __lt__(cls, other):
if cls.__name__ == other.__name__:
return id(cls) < id(other)
else:
return cls.__name__ < other.__name__


@dataclass(frozen=True)
class VarTypeOrdering:
"""Object describing how variables are ordered within a `VarValues` object
Expand All @@ -28,13 +47,11 @@ def ordered_dict_items[T](
self,
var_type_mapping: dict[type[Var[Any]], T],
) -> list[tuple[type[Var[Any]], T]]:
return sorted(
var_type_mapping.items(), key=lambda x: self.order_from_type[x[0]]
)
return sorted(var_type_mapping.items(), key=lambda x: x[0])


@jdc.pytree_dataclass
class Var[T]:
class Var[T](metaclass=_HashableSortableMeta):
"""A symbolic representation of an optimization variable."""

id: int | jax.Array
Expand Down

0 comments on commit 0231cff

Please sign in to comment.