Skip to content

Commit

Permalink
Merge pull request #215 from ecmwf/develop
Browse files Browse the repository at this point in the history
Version 1.0.7
  • Loading branch information
mathleur authored Sep 18, 2024
2 parents 2f57082 + 103d50d commit 9f63f40
Show file tree
Hide file tree
Showing 24 changed files with 168 additions and 135 deletions.
8 changes: 5 additions & 3 deletions polytope/datacube/backends/datacube.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Dict

from ...utility.combinatorics import validate_axes
from ..datacube_axis import DatacubeAxis
Expand Down Expand Up @@ -31,9 +31,10 @@ def __init__(self, axis_options=None, compressed_axes_options=[]):
self.merged_axes = []
self.unwanted_path = {}
self.compressed_axes = compressed_axes_options
self.grid_md5_hash = None

@abstractmethod
def get(self, requests: TensorIndexTree) -> Any:
def get(self, requests: TensorIndexTree, context: Dict) -> Any:
"""Return data given a set of request trees"""

@property
Expand Down Expand Up @@ -69,6 +70,7 @@ def _create_axes(self, name, values, transformation_type_key, transformation_opt
# TODO: do we use this?? This shouldn't work for a disk in lat/lon on a octahedral or other grid??
for compressed_grid_axis in transformation.compressed_grid_axes:
self.compressed_grid_axes.append(compressed_grid_axis)
self.grid_md5_hash = transformation.md5_hash
if len(final_axis_names) > 1:
self.coupled_axes.append(final_axis_names)
for axis in final_axis_names:
Expand Down Expand Up @@ -128,7 +130,7 @@ def get_indices(self, path: DatacubePath, axis, lower, upper, method=None):
indexes = axis.find_indexes(path, self)
idx_between = axis.find_indices_between(indexes, lower, upper, self, method)

logging.info(f"For axis {axis.name} between {lower} and {upper}, found indices {idx_between}")
logging.debug(f"For axis {axis.name} between {lower} and {upper}, found indices {idx_between}")

return idx_between

Expand Down
18 changes: 8 additions & 10 deletions polytope/datacube/backends/fdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def check_branching_axes(self, request):
for axis_name in axes_to_remove:
self._axes.pop(axis_name, None)

def get(self, requests: TensorIndexTree):
def get(self, requests: TensorIndexTree, context=None):
if context is None:
context = {}
requests.pprint()
if len(requests.children) == 0:
return requests
Expand All @@ -104,11 +106,11 @@ def get(self, requests: TensorIndexTree):
uncompressed_request = {}
for i, key in enumerate(compressed_request[0].keys()):
uncompressed_request[key] = combi[i]
complete_uncompressed_request = (uncompressed_request, compressed_request[1])
complete_uncompressed_request = (uncompressed_request, compressed_request[1], self.grid_md5_hash)
complete_list_complete_uncompressed_requests.append(complete_uncompressed_request)
complete_fdb_decoding_info.append(fdb_requests_decoding_info[j])
logging.debug("The requests we give GribJump are: %s", complete_list_complete_uncompressed_requests)
output_values = self.gj.extract(complete_list_complete_uncompressed_requests)
output_values = self.gj.extract(complete_list_complete_uncompressed_requests, context)
logging.debug("GribJump outputs: %s", output_values)
self.assign_fdb_output_to_nodes(output_values, complete_fdb_decoding_info)

Expand All @@ -124,7 +126,7 @@ def get_fdb_requests(

# First when request node is root, go to its children
if requests.axis.name == "root":
logging.info("Looking for data for the tree: %s", [leaf.flatten() for leaf in requests.leaves])
logging.debug("Looking for data for the tree: %s", [leaf.flatten() for leaf in requests.leaves])

for c in requests.children:
self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info)
Expand Down Expand Up @@ -161,8 +163,8 @@ def remove_duplicates_in_request_ranges(self, fdb_node_ranges, current_start_idx
new_current_start_idx = []
for j, idx in enumerate(sub_lat_idxs):
if idx not in seen_indices:
# TODO: need to remove it from the values in the corresponding tree node
# TODO: need to read just the range we give to gj ... DONE?
# NOTE: need to remove it from the values in the corresponding tree node
# NOTE: need to read just the range we give to gj
original_fdb_node_range_vals.append(actual_fdb_node[0].values[j])
seen_indices.add(idx)
new_current_start_idx.append(idx)
Expand All @@ -187,8 +189,6 @@ def nearest_lat_lon_search(self, requests):

second_ax = requests.children[0].children[0].axis

# TODO: actually, here we should not remap the nearest_pts, we should instead unmap the
# found_latlon_pts and then remap them later once we have compared found_latlon_pts and nearest_pts
nearest_pts = [
[lat_val, second_ax._remap_val_to_axis_range(lon_val)]
for (lat_val, lon_val) in zip(
Expand Down Expand Up @@ -325,8 +325,6 @@ def sort_fdb_request_ranges(self, current_start_idx, lat_length, fdb_node_ranges
request_ranges_with_idx = list(enumerate(interm_request_ranges))
sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0])
original_indices, sorted_request_ranges = zip(*sorted_list)
logging.debug("We sorted the request ranges into: %s", sorted_request_ranges)
logging.debug("The sorted and unique leaf node ranges are: %s", new_fdb_node_ranges)
return (original_indices, sorted_request_ranges, new_fdb_node_ranges)

def datacube_natural_indexes(self, axis, subarray):
Expand Down
4 changes: 3 additions & 1 deletion polytope/datacube/backends/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ def __init__(self, dimensions, compressed_axes_options=[]):
self.stride[k] = stride_cumulative
stride_cumulative *= self.dimensions[k]

def get(self, requests: TensorIndexTree):
def get(self, requests: TensorIndexTree, context=None):
# Takes in a datacube and verifies the leaves of the tree are complete
# (ie it found values for all datacube axis)

if context is None:
context = {}
for r in requests.leaves:
path = r.flatten()
if len(path.items()) == len(self.dimensions.items()):
Expand Down
8 changes: 5 additions & 3 deletions polytope/datacube/backends/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@ def __init__(self, dataarray: xr.DataArray, axis_options=None, compressed_axes_o
val = self._axes[name].type
self._check_and_add_axes(options, name, val)

def get(self, requests, leaf_path=None, axis_counter=0):
def get(self, requests, context=None, leaf_path=None, axis_counter=0):
if context is None:
context = {}
if leaf_path is None:
leaf_path = {}
if requests.axis.name == "root":
for c in requests.children:
self.get(c, leaf_path, axis_counter + 1)
self.get(c, context, leaf_path, axis_counter + 1)
else:
key_value_path = {requests.axis.name: requests.values}
ax = requests.axis
Expand All @@ -66,7 +68,7 @@ def get(self, requests, leaf_path=None, axis_counter=0):
if len(requests.children) != 0:
# We are not a leaf and we loop over
for c in requests.children:
self.get(c, leaf_path, axis_counter + 1)
self.get(c, context, leaf_path, axis_counter + 1)
else:
if self.axis_counter != axis_counter:
requests.remove_branch()
Expand Down
1 change: 1 addition & 0 deletions polytope/datacube/tensor_index_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def add_child(self, node):
def add_value(self, value):
    """Insert *value* into this node's values, stored as a sorted tuple."""
    # Rebuild the tuple with the new entry included, keeping order sorted
    # so downstream consumers can rely on deterministic value ordering.
    self.values = tuple(sorted([*self.values, value]))

def create_child(self, axis, value, next_nodes):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self, name, mapper_options):
self._final_mapped_axes = self._final_transformation._mapped_axes
self._axis_reversed = self._final_transformation._axis_reversed
self.compressed_grid_axes = self._final_transformation.compressed_grid_axes
self.md5_hash = self._final_transformation.md5_hash

def generate_final_transformation(self):
map_type = _type_to_datacube_mapper_lookup[self.grid_type]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(self, base_axis, mapped_axes, resolution, local_area=[]):
self._axis_reversed = {mapped_axes[0]: True, mapped_axes[1]: False}
self._first_axis_vals = self.first_axis_vals()
self.compressed_grid_axes = [self._mapped_axes[1]]
self.md5_hash = md5_hash.get(resolution, None)

def first_axis_vals(self):
rad2deg = 180 / math.pi
Expand Down Expand Up @@ -133,3 +134,7 @@ def unmap(self, first_val, second_val):
second_idx = self.second_axis_vals(first_val).index(second_val)
healpix_index = self.axes_idx_to_healpix_idx(first_idx, second_idx)
return healpix_index


# MD5 hash of the grid geometry, keyed by resolution ({resolution: hash}).
# Empty: no hashes recorded for this grid type yet, so lookups fall back to None.
md5_hash = {}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, base_axis, mapped_axes, resolution, local_area=[]):
self.k = int(math.log2(self.Nside))
self.Npix = 12 * self.Nside * self.Nside
self.Ncap = (self.Nside * (self.Nside - 1)) << 1
self.md5_hash = md5_hash.get(resolution, None)

def first_axis_vals(self):
rad2deg = 180 / math.pi
Expand Down Expand Up @@ -211,3 +212,7 @@ def ring_to_nested(self, idx):

def int_sqrt(self, i):
return int(math.sqrt(i + 0.5))


# MD5 hash of the grid geometry, keyed by resolution ({resolution: hash}).
# Empty: no hashes recorded for this grid type yet, so lookups fall back to None.
md5_hash = {}
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ def __init__(self, base_axis, mapped_axes, resolution, local_area=[]):
if not isinstance(resolution, list):
self.first_resolution = resolution
self.second_resolution = resolution
self.md5_hash = md5_hash.get(resolution, None)
else:
self.first_resolution = resolution[0]
self.second_resolution = resolution[1]
self.md5_hash = md5_hash.get(tuple(resolution), None)
self._first_deg_increment = (local_area[1] - local_area[0]) / self.first_resolution
self._second_deg_increment = (local_area[3] - local_area[2]) / self.second_resolution
self._axis_reversed = {mapped_axes[0]: True, mapped_axes[1]: False}
Expand Down Expand Up @@ -68,3 +70,7 @@ def unmap(self, first_val, second_val):
second_idx = self.second_axis_vals(first_val).index(second_val)
final_index = self.axes_idx_to_regular_idx(first_idx, second_idx)
return final_index


# MD5 hash of the grid geometry, keyed by resolution ({resolution: hash}).
# For local grids the key may be a (first, second) resolution tuple.
# Empty: no hashes recorded for this grid type yet, so lookups fall back to None.
md5_hash = {}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(self, base_axis, mapped_axes, resolution, local_area=[]):
self._second_axis_spacing = {}
self._axis_reversed = {mapped_axes[0]: True, mapped_axes[1]: False}
self.compressed_grid_axes = [self._mapped_axes[1]]
self.md5_hash = md5_hash.get(resolution, None)

def gauss_first_guess(self):
i = 0
Expand Down Expand Up @@ -2750,3 +2751,9 @@ def unmap(self, first_val, second_val):
(first_idx, second_idx) = self.find_second_axis_idx(first_val, second_val)
octahedral_index = self.axes_idx_to_octahedral_idx(first_idx, second_idx)
return octahedral_index


# MD5 hash of the grid geometry, keyed by resolution ({resolution: hash}).
# Looked up in __init__ and attached to the datacube as grid_md5_hash,
# presumably so the backend can verify the grid it reads — confirm with callers.
md5_hash = {
    1280: "158db321ae8e773681eeb40e0a3d350f",
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def __init__(self, base_axis, mapped_axes, resolution, local_area=[]):
self._axis_reversed = {mapped_axes[0]: False, mapped_axes[1]: False}
self._first_axis_vals = self.first_axis_vals()
self.compressed_grid_axes = [self._mapped_axes[1]]
self.md5_hash = md5_hash.get(resolution, None)

def first_axis_vals(self):
resolution = 180 / (self._resolution - 1)
Expand Down Expand Up @@ -1504,3 +1505,7 @@ def unmap(self, first_val, second_val):
second_idx = self.second_axis_vals(first_val).index(second_val)
reduced_ll_index = self.axes_idx_to_reduced_ll_idx(first_idx, second_idx)
return reduced_ll_index


# MD5 hash of the grid geometry, keyed by resolution ({resolution: hash}).
# Empty: no hashes recorded for this grid type yet, so lookups fall back to None.
md5_hash = {}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(self, base_axis, mapped_axes, resolution, local_area=[]):
self._axis_reversed = {mapped_axes[0]: True, mapped_axes[1]: False}
self._first_axis_vals = self.first_axis_vals()
self.compressed_grid_axes = [self._mapped_axes[1]]
self.md5_hash = md5_hash.get(resolution, None)

def first_axis_vals(self):
first_ax_vals = [90 - i * self.deg_increment for i in range(2 * self._resolution)]
Expand Down Expand Up @@ -56,3 +57,7 @@ def unmap(self, first_val, second_val):
second_idx = self.second_axis_vals(first_val).index(second_val)
final_index = self.axes_idx_to_regular_idx(first_idx, second_idx)
return final_index


# MD5 hash of the grid geometry, keyed by resolution ({resolution: hash}).
# Empty: no hashes recorded for this grid type yet, so lookups fall back to None.
md5_hash = {}
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,5 @@ def unmap_tree_node(self, node, unwanted_path):
if node.axis.name == self._first_axis:
(new_first_vals, new_second_vals) = self.unmerge(node.values)
node.values = new_first_vals
# TODO: actually need to give the second axis of the transformation to get the interm axis
interm_node = node.add_node_layer_after(self._second_axis, new_second_vals)
return (interm_node, unwanted_path)
50 changes: 18 additions & 32 deletions polytope/engine/hullslicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def remap_values(self, ax, value):

def _build_sliceable_child(self, polytope, ax, node, datacube, values, next_nodes, slice_axis_idx):
for i, value in enumerate(values):
if i == 0:
if i == 0 or ax.name not in self.compressed_axes:
fvalue = ax.to_float(value)
new_polytope = slice(polytope, ax.name, fvalue, slice_axis_idx)
remapped_val = self.remap_values(ax, value)
Expand All @@ -121,19 +121,8 @@ def _build_sliceable_child(self, polytope, ax, node, datacube, values, next_node
child["unsliced_polytopes"].add(new_polytope)
next_nodes.append(child)
else:
if ax.name not in self.compressed_axes:
fvalue = ax.to_float(value)
new_polytope = slice(polytope, ax.name, fvalue, slice_axis_idx)
remapped_val = self.remap_values(ax, value)
(child, next_nodes) = node.create_child(ax, remapped_val, next_nodes)
child["unsliced_polytopes"] = copy(node["unsliced_polytopes"])
child["unsliced_polytopes"].remove(polytope)
if new_polytope is not None:
child["unsliced_polytopes"].add(new_polytope)
next_nodes.append(child)
else:
remapped_val = self.remap_values(ax, value)
child.add_value(remapped_val)
remapped_val = self.remap_values(ax, value)
child.add_value(remapped_val)

def _build_branch(self, ax, node, datacube, next_nodes):
if ax.name not in self.compressed_axes:
Expand All @@ -142,26 +131,23 @@ def _build_branch(self, ax, node, datacube, next_nodes):
for polytope in node["unsliced_polytopes"]:
if ax.name in polytope._axes:
right_unsliced_polytopes.append(polytope)
# for polytope in node["unsliced_polytopes"]:
for i, polytope in enumerate(right_unsliced_polytopes):
node._parent = parent_node
# if ax.name in polytope._axes:
if True:
lower, upper, slice_axis_idx = polytope.extents(ax.name)
# here, first check if the axis is an unsliceable axis and directly build node if it is
# NOTE: we should have already created the ax_is_unsliceable cache before
if self.ax_is_unsliceable[ax.name]:
self._build_unsliceable_child(polytope, ax, node, datacube, [lower], next_nodes, slice_axis_idx)
else:
values = self.find_values_between(polytope, ax, node, datacube, lower, upper)
# NOTE: need to only remove the branches if the values are empty,
# but only if there are no other possible children left in the tree that
# we can append and if somehow this happens before and we need to remove, then what do we do??
if i == len(right_unsliced_polytopes) - 1:
# we have iterated all polytopes and we can now remove the node if we need to
if len(values) == 0 and len(node.children) == 0:
node.remove_branch()
self._build_sliceable_child(polytope, ax, node, datacube, values, next_nodes, slice_axis_idx)
lower, upper, slice_axis_idx = polytope.extents(ax.name)
# here, first check if the axis is an unsliceable axis and directly build node if it is
# NOTE: we should have already created the ax_is_unsliceable cache before
if self.ax_is_unsliceable[ax.name]:
self._build_unsliceable_child(polytope, ax, node, datacube, [lower], next_nodes, slice_axis_idx)
else:
values = self.find_values_between(polytope, ax, node, datacube, lower, upper)
# NOTE: need to only remove the branches if the values are empty,
# but only if there are no other possible children left in the tree that
# we can append and if somehow this happens before and we need to remove, then what do we do??
if i == len(right_unsliced_polytopes) - 1:
# we have iterated all polytopes and we can now remove the node if we need to
if len(values) == 0 and len(node.children) == 0:
node.remove_branch()
self._build_sliceable_child(polytope, ax, node, datacube, values, next_nodes, slice_axis_idx)
else:
all_values = []
all_lowers = []
Expand Down
1 change: 0 additions & 1 deletion polytope/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ class Config(ConfigModel):
class PolytopeOptions(ABC):
@staticmethod
def get_polytope_options(options):

parser = argparse.ArgumentParser(allow_abbrev=False)
conflator = Conflator(app_name="polytope", model=Config, cli=False, argparser=parser, **options)
config_options = conflator.load()
Expand Down
10 changes: 8 additions & 2 deletions polytope/polytope.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import List

from .options import PolytopeOptions
Expand Down Expand Up @@ -55,9 +56,14 @@ def slice(self, polytopes: List[ConvexPolytope]):
"""Low-level API which takes a polytope geometry object and uses it to slice the datacube"""
return self.engine.extract(self.datacube, polytopes)

def retrieve(self, request: Request, method="standard"):
def retrieve(self, request: Request, method="standard", context=None):
    """Higher-level API which takes a request and uses it to slice the datacube.

    Parameters
    ----------
    request : Request
        High-level description of the data to extract.
    method : str
        Slicing method selector; not referenced in this implementation.
    context : dict, optional
        Free-form metadata passed through to the datacube backend and
        echoed in the log messages; defaults to an empty dict.

    Returns
    -------
    The request tree produced by the engine, populated with data.
    """
    if context is None:
        context = {}
    logging.info("Starting request for %s ", context)
    self.datacube.check_branching_axes(request)
    request_tree = self.engine.extract(self.datacube, request.polytopes())
    logging.info("Created request tree for %s ", context)
    # Single datacube access: the stale parameterless get() call that
    # appeared alongside this one would have retrieved the data twice.
    self.datacube.get(request_tree, context)
    logging.info("Retrieved data for %s ", context)
    return request_tree
Loading

0 comments on commit 9f63f40

Please sign in to comment.