diff --git a/astromodels/functions/functions_2D.py b/astromodels/functions/functions_2D.py index 69c055ea..7e3b395e 100644 --- a/astromodels/functions/functions_2D.py +++ b/astromodels/functions/functions_2D.py @@ -8,6 +8,8 @@ from astromodels.utils.angular_distance import angular_distance from astromodels.utils.vincenty import vincenty +import hashlib + class Latitude_galactic_diffuse(Function2D): r""" @@ -534,6 +536,12 @@ class SpatialTemplate_2D(Function2D): desc : normalization initial value : 1 fix : yes + + hash : + + desc: hash of model map [needed for memoization] + initial value: 1 + fix: yes """ @@ -556,14 +564,21 @@ def load_file(self,fitsfile,ihdu=0): self._wcs = wcs.WCS( header = f[ihdu].header ) self._map = f[ihdu].data - + self._nX = f[ihdu].header['NAXIS1'] self._nY = f[ihdu].header['NAXIS2'] + + #note: map coordinates are switched compared to header. NAXIS1 is coordinate 1, not 0. + #see http://docs.astropy.org/en/stable/io/fits/#working-with-image-data assert self._map.shape[1] == self._nX, "NAXIS1 = %d in fits header, but %d in map" % (self._nX, self._map.shape[1]) assert self._map.shape[0] == self._nY, "NAXIS2 = %d in fits header, but %d in map" % (self._nY, self._map.shape[0]) - #note: map coordinates are switched compared to header. NAXIS1 is coordinate 1, not 0. - #see http://docs.astropy.org/en/stable/io/fits/#working-with-image-data + #hash sum uniquely identifying the template function (defined by its 2D map array and coordinate system) + #this is needed so that the memoization won't confuse different SpatialTemplate_2D objects. + h = hashlib.sha224() + h.update( self._map) + h.update( repr(self._wcs) ) + self.hash = int(h.hexdigest(), 16) def set_frame(self, new_frame): @@ -577,7 +592,7 @@ def set_frame(self, new_frame): self._frame = new_frame - def evaluate(self, x, y, K): + def evaluate(self, x, y, K, hash): # We assume x and y are R.A. and Dec coord = SkyCoord(ra=x, dec=y, frame=self._frame, unit="deg") diff --git a/astromodels/tests/test_functions.py b/astromodels/tests/test_functions.py index d6b5bc67..d27ff32c 100644 --- a/astromodels/tests/test_functions.py +++ b/astromodels/tests/test_functions.py @@ -1,4 +1,5 @@ import pytest +import os import astropy.units as u import numpy as np @@ -7,10 +8,12 @@ from astromodels.functions.function import FunctionMeta, Function1D, Function2D, FunctionDefinitionError, \ UnknownParameter, DesignViolation, get_function, get_function_class, UnknownFunction, list_functions from astromodels.functions.functions import Powerlaw, Line -from astromodels.functions.functions_2D import Gaussian_on_sphere +from astromodels.functions.functions_2D import Gaussian_on_sphere, SpatialTemplate_2D from astromodels.functions.functions_3D import Continuous_injection_diffusion from astromodels.functions import function as function_module +from astropy.io import fits + __author__ = 'giacomov' @@ -849,3 +852,67 @@ def test_function3D(): with pytest.raises(TypeError): c.set_units("not existent", u.deg, u.keV, 1.0 / (u.keV * u.s * u.deg**2 * u.cm**2)) + +def test_spatial_template_2D(): + + #make the fits files with templates to test. + cards = { + "SIMPLE": "T", + "BITPIX": -32, + "NAXIS" : 2, + "NAXIS1": 360, + "NAXIS2": 360, + "DATE": '2018-06-15', + "CUNIT1": 'deg', + "CRVAL1": 83, + "CRPIX1": 0, + "CDELT1": -0.0166667, + "CUNIT2": 'deg', + "CRVAL2": -2.0, + "CRPIX2": 0, + "CDELT2": 0.0166667, + "CTYPE1": 'GLON-CAR', + "CTYPE2": 'GLAT-CAR' } + + data = np.zeros([400,400]) + data[0:100,0:100] = 1 + hdu = fits.PrimaryHDU(data=data, header=fits.Header(cards)) + hdu.writeto("test1.fits", overwrite=True) + + data[:,:]=0 + data[200:300,200:300] = 1 + hdu = fits.PrimaryHDU(data=data, header=fits.Header(cards)) + hdu.writeto("test2.fits", overwrite=True) + + + #Now load template files and test their evaluation + shape1=SpatialTemplate_2D() + shape1.load_file("test1.fits") + shape1.K = 1 + + shape2=SpatialTemplate_2D() + shape2.load_file("test2.fits") + shape2.K = 1 + + assert shape1.hash != shape2.hash + + assert np.all ( shape1.evaluate( [312, 306], [41, 41], [1,1], [40, 2]) == [1., 0.] ) + assert np.all ( shape2.evaluate( [312, 306], [41, 41], [1,1], [40, 2]) == [0., 1.] ) + assert np.all ( shape1.evaluate( [312, 306], [41, 41], [1,10], [40, 2]) == [1., 0.] ) + assert np.all ( shape2.evaluate( [312, 306], [41, 41], [1,10], [40, 2]) == [0., 10.] ) + + + shape1.K = 1 + shape2.K = 1 + assert np.all ( shape1( [312, 306], [41, 41]) == [1., 0.] ) + assert np.all ( shape2( [312, 306], [41, 41]) == [0., 1.] ) + + shape1.K = 1 + shape2.K = 10 + assert np.all ( shape1( [312, 306], [41, 41]) == [1., 0.] ) + assert np.all ( shape2( [312, 306], [41, 41]) == [0., 10.] ) + + os.remove("test1.fits") + os.remove("test2.fits") + +