
Commit

add gotm transport; bump pygetm dependency to 0.9.3
jornbr committed Sep 30, 2024
1 parent 91871ec commit a75af45
Showing 8 changed files with 749 additions and 43 deletions.
2 changes: 1 addition & 1 deletion environment-win.yml
@@ -8,6 +8,6 @@ dependencies:
- cmake
- m2w64-toolchain
- pip
- pygetm>=0.9.2
- pygetm>=0.9.3
- scipy
- h5py
2 changes: 1 addition & 1 deletion environment.yml
@@ -7,6 +7,6 @@ dependencies:
- cmake
- fortran-compiler
- pip
- pygetm>=0.9.2
- pygetm>=0.9.3
- scipy
- h5py
1 change: 1 addition & 0 deletions src/fabmos/__init__.py
@@ -1,6 +1,7 @@
from ._version import __version__

from pygetm import TimeUnit
from pygetm import vertical_coordinates
from pygetm.core import Array

from . import domain
218 changes: 205 additions & 13 deletions src/fabmos/domain.py
@@ -98,10 +98,10 @@ def get(
out[slice_spec] = self._global_values
return out[slice_spec]
if self.values is None:
raise Exception(f'self.values is None')
raise Exception(f"self.values is None")
if compressed_values is None:
raise Exception(f'compressed_values is None {self._source.name}')
#raise Exception(f'{self.values.shape}, {self._slice[0]!r}, {self._slice[1].shape!r} {self._source!r}')
raise Exception(f"compressed_values is None {self._source.name}")
# raise Exception(f'{self.values.shape}, {self._slice[0]!r}, {self._slice[1].shape!r} {self._source!r}')
self.values[self._slice] = compressed_values[..., 0, :]
return super().get(out, slice_spec)

@@ -119,21 +119,87 @@ def coords(self):
yield CompressedToFullGrid(c, self._grid, self._slice[-1], g)


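# Output transform that scatters per-cluster (compressed) values back onto the
# full grid: every cell belonging to a cluster receives that cluster's value.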
class ClustersToFullGrid(pygetm.output.operators.UnivariateTransformWithData):
def __init__(
self,
source: pygetm.output.operators.Base,
grid: Optional[pygetm.domain.Grid],
clusters: Iterable[np.ndarray],
global_array: Optional[Array] = None,
):
self._grid = grid
self._clusters = clusters
self._slices = [(c,) for c in clusters]
shape = source.shape[:-2] + (grid.ny, grid.nx)
self._z = None
if source.ndim > 2:
self._z = CENTERS if source.shape[0] == grid.nz else INTERFACES
self._slices = [(slice(None),) + s for s in self._slices]
dims = pygetm.output.operators.grid2dims(grid, self._z)
expression = f"{self.__class__.__name__}({source.expression})"
super().__init__(source, shape=shape, dims=dims, expression=expression)
self.values.fill(self.fill_value)
self._global_values = None if global_array is None else global_array.values

def get(
self,
out: Optional[npt.ArrayLike] = None,
slice_spec: Tuple[int, ...] = (),
) -> npt.ArrayLike:
compressed_values = self._source.get()
if self._global_values is not None:
if out is None:
return self._global_values
else:
out[slice_spec] = self._global_values
return out[slice_spec]
if self.values is None:
raise Exception(f"self.values is None")
if compressed_values is None:
raise Exception(f"compressed_values is None {self._source.name}")
# raise Exception(f'{self.values.shape}, {self._slice[0]!r}, {self._slice[1].shape!r} {self._source!r}')
for i, s in enumerate(self._slices):
self.values[s] = compressed_values[..., 0, i, np.newaxis]
return super().get(out, slice_spec)

def _update_coordinates(grid: pygetm.domain.Grid, area: np.ndarray, h: Optional[np.ndarray]=None):
@property
def grid(self) -> pygetm.domain.Grid:
return self._grid

@property
def coords(self):
global_x = self._grid.lon if self._grid.domain.spherical else self._grid.x
global_y = self._grid.lat if self._grid.domain.spherical else self._grid.y
global_z = self._grid.zc if self._z == CENTERS else self._grid.zf
globals_arrays = (global_x, global_y, global_z)
for c, g in zip(self._source.coords, globals_arrays):
yield ClustersToFullGrid(c, self._grid, self._clusters, g)


def _update_coordinates(
grid: pygetm.domain.Grid,
area: np.ndarray,
h: Optional[np.ndarray] = None,
bottom_to_surface: bool = False,
):
slc_loc, slc_glob, _, _ = grid.domain.tiling.subdomain2slices()
grid.D.values[slc_loc] = grid.H.values[slc_loc]
if h is None:
grid.ho.values[slc_loc] = grid.H.values[slc_loc]
grid.hn.values[slc_loc] = grid.H.values[slc_loc]
grid.domain.vertical_coordinates.initialize(grid)
grid.domain.vertical_coordinates.update(0.0)
grid.ho.values[slc_loc] = grid.hn.values[slc_loc]
else:
grid.ho.values[slc_loc] = h[slc_glob]
grid.hn.values[slc_loc] = h[slc_glob]
grid.zf.all_values.fill(0.0)
grid.zf.all_values[1:, ...] = -grid.hn.all_values.cumsum(axis=0)
grid.zc.all_values[...] = 0.5 * (
grid.zf.all_values[:-1, :, :] + grid.zf.all_values[1:, :, :]
)
grid.zf.all_values[1:] = grid.hn.all_values.cumsum(axis=0)
if bottom_to_surface:
# First interface = -sum(hn), then increasing to 0
grid.zf.all_values -= grid.zf.all_values[-1]
else:
# First interface = 0, then decreasing to -sum(hn)
grid.zf.all_values *= -1.0
grid.zc.all_values[...] = 0.5 * (grid.zf.all_values[:-1] + grid.zf.all_values[1:])
grid.zc.all_values[:, grid._land] = 0.0
grid.zf.all_values[:, grid._land] = 0.0
grid.domain.depth.all_values[...] = -grid.zc.all_values
@@ -147,7 +213,6 @@ def _update_coordinates(grid: pygetm.domain.Grid, area: np.ndarray, h: Optional[

grid.area.values[slc_loc] = area[slc_glob]
grid.iarea.values[slc_loc] = 1.0 / grid.area.values[slc_loc]
grid.D.values[slc_loc] = grid.H.values[slc_loc]


def compress(full_domain: Optional[Domain], comm: Optional[MPI.Comm] = None) -> Domain:
@@ -191,7 +256,8 @@ def compress(full_domain: Optional[Domain], comm: Optional[MPI.Comm] = None) ->
tiling=tiling,
logger=full_domain.root_logger,
halox=0,
haloy=0
haloy=0,
vertical_coordinates=full_domain.vertical_coordinates,
)

slc_loc, slc_glob, _, _ = domain.tiling.subdomain2slices()
@@ -211,11 +277,137 @@ def compress(full_domain: Optional[Domain], comm: Optional[MPI.Comm] = None) ->
)
domain.default_output_transforms.append(tf)

domain.uncompressed_area = area
domain.global_area = area

return domain


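# Compress a full domain into one water column per cluster: coordinates and depth
# are cluster means, area is the cluster sum.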
def compress_clusters(
full_domain: Optional[Domain],
clusters: npt.ArrayLike,
comm: Optional[MPI.Comm] = None,
decompress_output: bool = False,
) -> Domain:
full_domain.initialize(pygetm.BAROCLINIC)

# More compressed domain: simple subdomain division along x dimension
tiling = pygetm.parallel.Tiling(nrow=1, comm=comm)

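# Collect the required global fields on the rank that holds the global domain,
# subsampled to T points ([1::2, 1::2] of the supergrid), and broadcast them.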
to_compress = ("mask", "lon", "lat", "x", "y", "H", "area")
global_fields = {}
values = None
for name in to_compress:
if full_domain.glob is not None and hasattr(full_domain.glob, name):
all_values = getattr(full_domain.glob, name)
if all_values is not None:
values = all_values[1::2, 1::2]
global_values = tiling.comm.bcast(values)
if global_values is not None:
global_fields[name] = global_values

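# Restrict to cells that are wet and carry a (non-masked) cluster label.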
unmasked = global_fields["mask"] != 0
clusters = np.ma.asarray(clusters)
assert clusters.shape == unmasked.shape
unmasked &= ~np.ma.getmaskarray(clusters)
clusters = np.asarray(clusters)
unique_clusters = np.unique(clusters[unmasked])

compressed_fields = {}
for name, values in global_fields.items():
compressed_fields[name] = np.empty((unique_clusters.size,), dtype=values.dtype)

logger = full_domain.logger
logger.info(f"Found {unique_clusters.size} unique clusters:")
for i, c in enumerate(unique_clusters):
sel = clusters == c
cluster_values = {}
for name, values in global_fields.items():
cluster_values[name] = values[sel]
compressed_fields[name][i] = cluster_values[name].mean()
compressed_fields["area"][i] = cluster_values["area"].sum()

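# For the log message below: find the cluster cell nearest to the cluster-mean
# position, wrapping longitude differences to at most 180 degrees.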
mean_lon = compressed_fields["lon"][i]
mean_lat = compressed_fields["lat"][i]
dlon = np.abs(global_fields["lon"] - mean_lon)
dlon[dlon > 180] = 360 - dlon[dlon > 180]
dist = dlon**2 + (global_fields["lat"] - mean_lat) ** 2
inear = np.ma.array(dist, mask=~sel).argmin()
near_lon = global_fields["lon"].flat[inear]
near_lat = global_fields["lat"].flat[inear]

logger.info(
f" {c} mean: {mean_lon:.6f} degrees East, {mean_lat:.6f} degrees North, {compressed_fields['H'][i]:.1f} m"
)
logger.info(f" {c}: {near_lon:.6f} degrees East, {near_lat:.6f} degrees North")

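# Create a domain with a single unmasked water column per cluster.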
domain = pygetm.domain.create(
unique_clusters.size,
1,
full_domain.nz,
x=compressed_fields["x"],
y=compressed_fields["y"],
lon=compressed_fields["lon"],
lat=compressed_fields["lat"],
mask=1,
H=compressed_fields["H"],
spherical=full_domain.spherical,
tiling=tiling,
logger=full_domain.root_logger,
halox=0,
haloy=0,
)

domain.global_area = compressed_fields["area"][np.newaxis, :]

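# Either transform compressed output back onto the full grid (root rank only), or
# attach a cluster_index map and the original horizontal coordinates so the
# cluster-to-cell mapping can be reconstructed from the output.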
if decompress_output:
if domain.tiling.rank == 0:
tf = functools.partial(
ClustersToFullGrid,
grid=full_domain.T,
clusters=[clusters == c for c in unique_clusters],
)
domain.default_output_transforms.append(tf)
else:
dims = ("lat", "lon") if full_domain.spherical else ("y", "x")
dims_ = (dims[0] + "_", dims[1] + "_")
cluster_index = np.full(clusters.shape, -1, dtype=np.int16)
for i, v in enumerate(unique_clusters):
cluster_index[clusters == v] = i
domain.T.extra_output_coordinates.append(
pygetm.output.operators.WrappedArray(
cluster_index, "cluster_index", dims_, fill_value=-1
)
)
for name in dims:
array = getattr(domain.T, name)
attrs = {"units": array.units, "long_name": array.long_name}
domain.T.extra_output_coordinates.append(
pygetm.output.operators.WrappedArray(
global_fields[name], name + "_", dims_, attrs=attrs
)
)

return domain


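# Give every spatially connected region (orthogonal connectivity) its own cluster
# label, so that disconnected parts of an input cluster become separate clusters.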
def split_clusters(clusters: npt.ArrayLike):
from skimage.segmentation import flood_fill

masked = np.ma.getmaskarray(clusters)
clusters = np.asarray(clusters)
unique_clusters = np.unique(clusters[~masked])
n = -1
for c in unique_clusters:
while True:
indices = (clusters == c).nonzero()
if indices[0].size == 0:
break
n += 1
seed_point = (indices[0][0], indices[1][0])
flood_fill(clusters, seed_point, -n, connectivity=1, in_place=True)
return np.ma.array(-clusters, mask=masked)


def drop_grids(domain: Domain, *grids: Grid):
for name in list(domain.fields):
if domain.fields[name].grid in grids:
