From fbcf525097039bb1a9280533381f73d09d307531 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Tue, 27 Aug 2024 09:15:34 -0400 Subject: [PATCH] weird nan bug in Reg/Zouhe fixed. Python pointer issue! --- .../boundary_condition/bc_regularized.py | 40 ++----------------- xlb/velocity_set/velocity_set.py | 20 ++++++++++ 2 files changed, 23 insertions(+), 37 deletions(-) diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 6958d1b..84dbbf9 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -63,25 +63,6 @@ def __init__( # The operator to compute the momentum flux self.momentum_flux = MomentumFlux() - # helper function - def compute_qi(self): - # Qi = cc - cs^2*I - dim = self.velocity_set.d - Qi = self.velocity_set.cc - if dim == 3: - diagonal = (0, 3, 5) - offdiagonal = (1, 2, 4) - elif dim == 2: - diagonal = (0, 2) - offdiagonal = (1,) - else: - raise ValueError(f"dim = {dim} not supported") - - # multiply off-diagonal elements by 2 because the Q tensor is symmetric - Qi[:, diagonal] += -1.0 / 3.0 - Qi[:, offdiagonal] *= 2.0 - return Qi - @partial(jit, static_argnums=(0,), inline=True) def regularize_fpop(self, fpop, feq): """ @@ -102,22 +83,7 @@ def regularize_fpop(self, fpop, feq): # Qi = cc - cs^2*I dim = self.velocity_set.d weights = self.velocity_set.w[(slice(None),) + (None,) * dim] - # TODO: if I use the following I get NaN ! figure out why! - # Qi = jnp.array(self.compute_qi(), dtype=self.compute_dtype) - Qi = jnp.array(self.velocity_set.cc, dtype=self.compute_dtype) - if dim == 3: - diagonal = (0, 3, 5) - offdiagonal = (1, 2, 4) - elif dim == 2: - diagonal = (0, 2) - offdiagonal = (1,) - else: - raise ValueError(f"dim = {dim} not supported") - - # Qi = cc - cs^2*I - # multiply off-diagonal elements by 2 because the Q tensor is symmetric - Qi = Qi.at[:, diagonal].add(-1.0 / 3.0) - Qi = Qi.at[:, offdiagonal].multiply(2.0) + Qi = jnp.array(self.velocity_set.qi, dtype=self.compute_dtype) # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq} f_neq = fpop - feq @@ -166,7 +132,6 @@ def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update # _u_vec = wp.vec(_d, dtype=self.compute_dtype) # compute Qi tensor and store it in self - _qi = wp.constant(wp.mat((_q, _d * (_d + 1) // 2), dtype=wp.float32)(self.compute_qi())) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _rho = wp.float32(rho) @@ -175,7 +140,8 @@ def _construct_warp(self): _w = self.velocity_set.wp_w _c = self.velocity_set.wp_c _c32 = self.velocity_set.wp_c32 - # TODO: this is way less than ideal. we should not be making new types + _qi = self.velocity_set.wp_qi + # TODO: related to _c32: this is way less than ideal. we should not be making new types @wp.func def _get_fsum( diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index cd63b36..a93d039 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -41,6 +41,7 @@ def __init__(self, d, q, c, w): self.main_indices = self._construct_main_indices() self.right_indices = self._construct_right_indices() self.left_indices = self._construct_left_indices() + self.qi = self._construct_qi() # Make warp constants for these vectors # TODO: Following warp updates these may not be necessary @@ -49,6 +50,7 @@ def __init__(self, d, q, c, w): self.wp_opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self.opp_indices)) self.wp_cc = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.cc)) self.wp_c32 = wp.constant(wp.mat((self.d, self.q), dtype=wp.float32)(self.c)) + self.wp_qi = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.qi)) def warp_lattice_vec(self, dtype): return wp.vec(len(self.c), dtype=dtype) @@ -59,6 +61,24 @@ def warp_u_vec(self, dtype): def warp_stream_mat(self, dtype): return wp.mat((self.q, self.d), dtype=dtype) + def _construct_qi(self): + # Qi = cc - cs^2*I + dim = self.d + Qi = self.cc.copy() + if dim == 3: + diagonal = (0, 3, 5) + offdiagonal = (1, 2, 4) + elif dim == 2: + diagonal = (0, 2) + offdiagonal = (1,) + else: + raise ValueError(f"dim = {dim} not supported") + + # multiply off-diagonal elements by 2 because the Q tensor is symmetric + Qi[:, diagonal] += -1.0 / 3.0 + Qi[:, offdiagonal] *= 2.0 + return Qi + def _construct_lattice_moment(self): """ This function constructs the moments of the lattice.