Skip to content

Commit

Permalink
weird nan bug in Reg/Zouhe fixed. Python pointer issue!
Browse files Browse the repository at this point in the history
  • Loading branch information
hsalehipour committed Aug 27, 2024
1 parent 2cd61fd commit fbcf525
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 37 deletions.
40 changes: 3 additions & 37 deletions xlb/operator/boundary_condition/bc_regularized.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,25 +63,6 @@ def __init__(
# The operator to compute the momentum flux
self.momentum_flux = MomentumFlux()

# helper function
def compute_qi(self):
# Qi = cc - cs^2*I
dim = self.velocity_set.d
Qi = self.velocity_set.cc
if dim == 3:
diagonal = (0, 3, 5)
offdiagonal = (1, 2, 4)
elif dim == 2:
diagonal = (0, 2)
offdiagonal = (1,)
else:
raise ValueError(f"dim = {dim} not supported")

# multiply off-diagonal elements by 2 because the Q tensor is symmetric
Qi[:, diagonal] += -1.0 / 3.0
Qi[:, offdiagonal] *= 2.0
return Qi

@partial(jit, static_argnums=(0,), inline=True)
def regularize_fpop(self, fpop, feq):
"""
Expand All @@ -102,22 +83,7 @@ def regularize_fpop(self, fpop, feq):
# Qi = cc - cs^2*I
dim = self.velocity_set.d
weights = self.velocity_set.w[(slice(None),) + (None,) * dim]
# TODO: if I use the following I get NaN ! figure out why!
# Qi = jnp.array(self.compute_qi(), dtype=self.compute_dtype)
Qi = jnp.array(self.velocity_set.cc, dtype=self.compute_dtype)
if dim == 3:
diagonal = (0, 3, 5)
offdiagonal = (1, 2, 4)
elif dim == 2:
diagonal = (0, 2)
offdiagonal = (1,)
else:
raise ValueError(f"dim = {dim} not supported")

# Qi = cc - cs^2*I
# multiply off-diagonal elements by 2 because the Q tensor is symmetric
Qi = Qi.at[:, diagonal].add(-1.0 / 3.0)
Qi = Qi.at[:, offdiagonal].multiply(2.0)
Qi = jnp.array(self.velocity_set.qi, dtype=self.compute_dtype)

# Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq}
f_neq = fpop - feq
Expand Down Expand Up @@ -166,7 +132,6 @@ def _construct_warp(self):
# Set local constants TODO: This is a hack and should be fixed with warp update
# _u_vec = wp.vec(_d, dtype=self.compute_dtype)
# compute Qi tensor and store it in self
_qi = wp.constant(wp.mat((_q, _d * (_d + 1) // 2), dtype=wp.float32)(self.compute_qi()))
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_rho = wp.float32(rho)
Expand All @@ -175,7 +140,8 @@ def _construct_warp(self):
_w = self.velocity_set.wp_w
_c = self.velocity_set.wp_c
_c32 = self.velocity_set.wp_c32
# TODO: this is way less than ideal. we should not be making new types
_qi = self.velocity_set.wp_qi
# TODO: related to _c32: this is way less than ideal. we should not be making new types

@wp.func
def _get_fsum(
Expand Down
20 changes: 20 additions & 0 deletions xlb/velocity_set/velocity_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(self, d, q, c, w):
self.main_indices = self._construct_main_indices()
self.right_indices = self._construct_right_indices()
self.left_indices = self._construct_left_indices()
self.qi = self._construct_qi()

# Make warp constants for these vectors
# TODO: Following warp updates these may not be necessary
Expand All @@ -49,6 +50,7 @@ def __init__(self, d, q, c, w):
self.wp_opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self.opp_indices))
self.wp_cc = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.cc))
self.wp_c32 = wp.constant(wp.mat((self.d, self.q), dtype=wp.float32)(self.c))
self.wp_qi = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.qi))

def warp_lattice_vec(self, dtype):
return wp.vec(len(self.c), dtype=dtype)
Expand All @@ -59,6 +61,24 @@ def warp_u_vec(self, dtype):
def warp_stream_mat(self, dtype):
return wp.mat((self.q, self.d), dtype=dtype)

def _construct_qi(self):
# Qi = cc - cs^2*I
dim = self.d
Qi = self.cc.copy()
if dim == 3:
diagonal = (0, 3, 5)
offdiagonal = (1, 2, 4)
elif dim == 2:
diagonal = (0, 2)
offdiagonal = (1,)
else:
raise ValueError(f"dim = {dim} not supported")

# multiply off-diagonal elements by 2 because the Q tensor is symmetric
Qi[:, diagonal] += -1.0 / 3.0
Qi[:, offdiagonal] *= 2.0
return Qi

def _construct_lattice_moment(self):
"""
This function constructs the moments of the lattice.
Expand Down

0 comments on commit fbcf525

Please sign in to comment.