Skip to content

Commit

Permalink
Support multiple spacings
Browse files Browse the repository at this point in the history
  • Loading branch information
diptodip committed Feb 25, 2023
1 parent 11d3ae7 commit 496b3ec
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/chromatix/field.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import jax.numpy as jnp
from chex import Array, assert_rank
from chex import Array, assert_rank, assert_equal_shape
from flax import struct
from einops import rearrange
from typing import Union, Optional, Tuple, Any
Expand Down Expand Up @@ -96,7 +96,7 @@ def create(
must be provided.
"""
# Getting everything into right shape
field_dx: jnp.ndarray = rearrange(jnp.atleast_1d(dx), "1 -> 1 1 1 1")
field_dx: jnp.ndarray = rearrange(jnp.atleast_1d(dx), "c -> 1 1 1 c")
field_spectrum: jnp.ndarray = rearrange(
jnp.atleast_1d(spectrum), "c -> 1 1 1 c"
)
Expand All @@ -106,6 +106,7 @@ def create(
field_spectral_density = field_spectral_density / jnp.sum(
field_spectral_density
) # Must sum to 1
assert_equal_shape([field_dx, field_spectrum, field_spectral_density])
if u is None:
# NOTE(dd): when jitting this function, shape must be a
# static argument --- possibly requiring multiple traces
Expand Down

0 comments on commit 496b3ec

Please sign in to comment.