From 3247b15ed51b5e1107ae3c87656f91dcdf532a1f Mon Sep 17 00:00:00 2001 From: David Hoese Date: Tue, 25 Jun 2024 20:50:19 -0500 Subject: [PATCH] Update astronomy to preserve dtype with numpy 2 --- pyorbital/astronomy.py | 90 ++++++++++++++++++++++++------- pyorbital/tests/test_astronomy.py | 47 ++++++++++++++-- 2 files changed, 113 insertions(+), 24 deletions(-) diff --git a/pyorbital/astronomy.py b/pyorbital/astronomy.py index 3c212d49..881e97d2 100644 --- a/pyorbital/astronomy.py +++ b/pyorbital/astronomy.py @@ -1,28 +1,42 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- - +# # Copyright (c) 2011, 2013 - +# # Author(s): - +# # Martin Raspaud - +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. - +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. - +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . +"""Angle and time-based astronomy functions. -"""Astronomy module. Parts taken from http://www.geoastro.de/elevaz/basics/index.htm + +Note on argument types +---------------------- + +Many of these functions accept Python datetime objects, +numpy datetime64 objects, or anything that can be turned +into a numpy array of datetime64 objects. These objects are inherently +64-bit so if other arguments (ex. longitude and latitude arrays) are +32-bit floats internal operations will be automatically promoted to +64-bit floating point numbers. Where possible these are then converted +back to 32-bit before being returned. In general scalar inputs will also +produce scalar outputs. + """ +import datetime import numpy as np @@ -42,12 +56,14 @@ def jdays2000(utc_time): def jdays(utc_time): """Get the julian day of *utc_time*. """ - return jdays2000(utc_time) + 2451545 + return jdays2000(utc_time) + 2451545.0 def _days(dt): """Get the days (floating point) from *d_t*. """ + if hasattr(dt, "shape"): + dt = np.asanyarray(dt, dtype=np.timedelta64) return dt / np.timedelta64(1, 'D') @@ -117,6 +133,7 @@ def _local_hour_angle(utc_time, longitude, right_ascension): def get_alt_az(utc_time, lon, lat): """Return sun altitude and azimuth from *utc_time*, *lon*, and *lat*. + lon,lat in degrees The returned angles are given in radians. """ @@ -125,10 +142,13 @@ def get_alt_az(utc_time, lon, lat): ra_, dec = sun_ra_dec(utc_time) h__ = _local_hour_angle(utc_time, lon, ra_) - return (np.arcsin(np.sin(lat) * np.sin(dec) + - np.cos(lat) * np.cos(dec) * np.cos(h__)), - np.arctan2(-np.sin(h__), (np.cos(lat) * np.tan(dec) - - np.sin(lat) * np.cos(h__)))) + alt_az = (np.arcsin(np.sin(lat) * np.sin(dec) + + np.cos(lat) * np.cos(dec) * np.cos(h__)), + np.arctan2(-np.sin(h__), (np.cos(lat) * np.tan(dec) - + np.sin(lat) * np.cos(h__)))) + if not isinstance(lon, float): + alt_az = (alt_az[0].astype(lon.dtype), alt_az[1].astype(lon.dtype)) + return alt_az def cos_zen(utc_time, lon, lat): @@ -141,7 +161,10 @@ def cos_zen(utc_time, lon, lat): r_a, dec = sun_ra_dec(utc_time) h__ = _local_hour_angle(utc_time, lon, r_a) - return (np.sin(lat) * np.sin(dec) + np.cos(lat) * np.cos(dec) * np.cos(h__)) + csza = (np.sin(lat) * np.sin(dec) + np.cos(lat) * np.cos(dec) * np.cos(h__)) + if not isinstance(lon, float): + csza = csza.astype(lon.dtype) + return csza def sun_zenith_angle(utc_time, lon, lat): @@ -149,13 +172,15 @@ def sun_zenith_angle(utc_time, lon, lat): lon,lat in degrees. The angle returned is given in degrees """ - return np.rad2deg(np.arccos(cos_zen(utc_time, lon, lat))) + sza = np.rad2deg(np.arccos(cos_zen(utc_time, lon, lat))) + if not isinstance(lon, float): + sza = sza.astype(lon.dtype) + return sza def sun_earth_distance_correction(utc_time): """Calculate the sun earth distance correction, relative to 1 AU. """ - # Computation according to # https://web.archive.org/web/20150117190838/http://curious.astro.cornell.edu/question.php?number=582 # with @@ -175,11 +200,10 @@ def sun_earth_distance_correction(utc_time): # "=" 1 - 0.0167 * np.cos(theta) corr = 1 - 0.0167 * np.cos(2 * np.pi * (jdays2000(utc_time) - 3) / 365.25636) - return corr -def observer_position(time, lon, lat, alt): +def observer_position(utc_time, lon, lat, alt): """Calculate observer ECI position. http://celestrak.com/columns/v02n03/ @@ -188,7 +212,7 @@ def observer_position(time, lon, lat, alt): lon = np.deg2rad(lon) lat = np.deg2rad(lat) - theta = (gmst(time) + lon) % (2 * np.pi) + theta = (gmst(utc_time) + lon) % (2 * np.pi) c = 1 / np.sqrt(1 + F * (F - 2) * np.sin(lat)**2) sq = c * (1 - F)**2 @@ -199,6 +223,32 @@ def observer_position(time, lon, lat, alt): vx = -MFACTOR * y # kilometers/second vy = MFACTOR * x - vz = 0 - + vz = _float_to_sibling_result(0.0, vx) + + if not isinstance(lon, float): + x = x.astype(lon.dtype, copy=False) + y = y.astype(lon.dtype, copy=False) + z = z.astype(lon.dtype, copy=False) + vx = vx.astype(lon.dtype, copy=False) + vy = vy.astype(lon.dtype, copy=False) + vz = vz.astype(lon.dtype, copy=False) # type: ignore[union-attr] return (x, y, z), (vx, vy, vz) + + +def _float_to_sibling_result(result_to_convert, template_result): + """Convert a scalar to the same type as another return type. + + This is mostly used to make a static value consistent with the types of + other returned values. + + """ + if isinstance(template_result, float): + return result_to_convert + # get any array like object that might be wrapped by our template (ex. xarray DataArray) + array_like = template_result if hasattr(template_result, "__array_function__") else template_result.data + array_convert = np.asarray(result_to_convert, like=array_like) + if not hasattr(template_result, "__array_function__"): + # the template result has some wrapper class (likely xarray DataArray) + # recreate the wrapper object + array_convert = template_result.__class__(array_convert) + return array_convert diff --git a/pyorbital/tests/test_astronomy.py b/pyorbital/tests/test_astronomy.py index 17e88475..6eba33a8 100644 --- a/pyorbital/tests/test_astronomy.py +++ b/pyorbital/tests/test_astronomy.py @@ -22,11 +22,33 @@ from datetime import datetime +import dask.array as da import numpy as np +import numpy.typing as npt import pytest import pyorbital.astronomy as astr +try: + from xarray import DataArray +except ImportError: + DataArray = None + + +def _create_dask_array(input_list: list, dtype: npt.DTypeLike) -> da.Array: + np_arr = np.array(input_list, dtype=dtype) + return da.from_array(np_arr) + + +def _create_xarray_numpy(input_list: list, dtype: npt.DTypeLike) -> DataArray: + np_arr = np.array(input_list, dtype=dtype) + return DataArray(np_arr) + + +def _create_xarray_dask(input_list: list, dtype: npt.DTypeLike) -> DataArray: + dask_arr = _create_dask_array(input_list, dtype) + return DataArray(dask_arr) + class TestAstronomy: @@ -50,14 +72,30 @@ def test_jdays(self, dt, exp_jdays, exp_j2000): (0.0, 0.0, 1.8751916863323426), ] ) - @pytest.mark.parametrize("dtype", [None, np.float32, np.float64]) - def test_sunangles(self, lon, lat, exp_theta, dtype): + @pytest.mark.parametrize( + ("dtype", "array_construct"), + [ + (None, None), + (np.float32, np.array), + (np.float64, np.array), + (np.float32, _create_dask_array), + (np.float64, _create_dask_array), + (np.float32, _create_xarray_numpy), + (np.float64, _create_xarray_numpy), + (np.float32, _create_xarray_dask), + (np.float64, _create_xarray_dask), + ] + ) + def test_sunangles(self, lon, lat, exp_theta, dtype, array_construct): """Test the sun-angle calculations.""" + if array_construct is None and dtype is not None: + pytest.skip(reason="Xarray dependency unavailable") + time_slot = datetime(2011, 9, 23, 12, 0) abs_tolerance = 1e-8 if dtype is not None: - lon = np.array([lon], dtype=dtype) - lat = np.array([lat], dtype=dtype) + lon = array_construct([lon], dtype=dtype) + lat = array_construct([lat], dtype=dtype) if np.dtype(dtype).itemsize < 8: abs_tolerance = 1e-4 @@ -68,6 +106,7 @@ def test_sunangles(self, lon, lat, exp_theta, dtype): else: assert sun_theta.dtype == dtype np.testing.assert_allclose(sun_theta, exp_theta, atol=abs_tolerance) + assert isinstance(sun_theta, type(lon)) def test_sun_earth_distance_correction(self): """Test the sun-earth distance correction."""