-
Notifications
You must be signed in to change notification settings - Fork 37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Cosmology, Cache, and Configuration data model #86
base: master
Are you sure you want to change the base?
Cosmology, Cache, and Configuration data model #86
Conversation
This makes Cosmology semi-immutable, and allows cached results to survive through unflattening of JAX transformations.
Dataclass is introduced in python 3.7 though, maybe most people have moved on from 3.6? Edit: 3.6 dropped and 3.9 & 3.10 added. CI are made faster. |
With caches, many functions in @EiffL What do you think about this? Right now the cache is a dictionary, so there can still be side effects. Relevant discussion: jax-ml/jax#5344 (comment). |
6e099e3
to
7c848b2
Compare
0dd2f76
to
0056535
Compare
756422f
to
b956647
Compare
Thanks @eelregit there is a lot of great things in there ^^! The cache and dataclass looks nice. And so, yeah the way I see it there is a tradeoff between making pure functions or having a simple API.... The only drawback of the current implementation is in the following case: cosmo = jc.Planck15()
x = jitted_function1(cosmo, ...)
y = jitted_function2(cosmo, ...) in that case the cache computed by the first function is not communicated to the second one, so you do some of the cosmology computation twice, but it doesnt lead to any wrong results. To avoid this and be able to reuse the cache I would just then recommend to write that same code this way: cosmo = jc.Planck15()
@jax.jit
def my_fun(cosmo):
x = function1(cosmo, ...)
y = function2(cosmo, ...)
return x,y
my_fun(cosmo) In practice in many cases you would just jit the likelihood or the simulation code itself and then you have no problem. So the question is whether allowing for using the cache over jitted functions is worth changing the API to have functions return the cosmology object... I'm leaning towards keeping a simple interface: chi = bkgrd.radial_comoving_distance(cosmo, a) instead of cosmo, chi = bkgrd.radial_comoving_distance(cosmo, a) just because it appears very suprising to a typical user. |
Unless you have a compeling use case that really would benefit from the more optimiized implementation. I'm also thinking it could be an option/config to have by default the non-pure API, but if an advanced user wants it, they could retrieved the cosmology and associated cache. What do you think? |
Thanks @EiffL ! The previous non-pure API does not allow functional cache in jitted inner functions like What do you think about the second pattern in jax-ml/jax#5344 (comment) ? |
Hummmm we could precompute everything at the instantiation of the cosmology object... We could imagine a mechanism that "registers" all functions that use cached values and computes the cache before anything else happens... Then the user API would stay the same, the functions would be pure. But.... It would mean that creating a cosmology would be slow for Interactive users.... Hummmm |
And we could have an option to decide which type of execution you want, one that plays nicely with jitted functions, and one that sticks to the current behavior for easy interactive use. |
Something like the following? def compute_y(cosmo, x):
# initialize cache and output cosmo with cache if input is None
if x is None:
if cosmo.is_cached(key):
return cosmo
value = ...
return cosmo.cache_set(key, value)
if not cosmo.is_cached(key):
cosmo = comput_y.init(cosmo) # or more strictly just raise runtimeerror?
value = cosmo.cache_get(key)
y = ...
return y
# and/or something more explicit like
compute_y.init = partial(compute_y, x=None) with some global Cosmology cache initialization like class Cosmology:
...
def cache_init(self, *args):
cosmo = self
for compute_y in args:
cosmo = compute_y.init(cosmo)
return cosmo Contributor should use cosmo = Planck15()
cosmo = cosmo.cache_init(compute_y, compute_z) and are encouraged to think functionally. Maybe we can iterate on this to find convergence ^^ |
If @lru_cache
def precompute_y(cosmo):
table = ...
return table
def compute_y(cosmo, x):
table = precompute_y(cosmo)
y = ...
return y With this it seems like everything can be pure and one doesn't need to touch |
Unfortunately, from functools import lru_cache
from typing import NamedTuple
class C(NamedTuple):
min: float = 0.
max: float = 1.
@lru_cache()
def f(c):
return jnp.linspace(c.min, c.max, 6)
@jit
def g(c, w, b):
return b + w * f(c)
g(C(), 1., 0.) results in TypeError: unhashable type: 'DynamicJaxprTracer' A similar issue in numba: numba/numba#4062 |
frozen
dataclass is semi-immutableaux_fields
can be specified to be the pytreeaux_data
Cosmology.config
using thisContainer
, worth switching?Are there cases where this is not desirable?