-
Notifications
You must be signed in to change notification settings - Fork 108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TST fix tests for JAX-Galsim #1252
base: main
Are you sure you want to change the base?
Changes from 250 commits
a01cc05
41482af
1561a15
1bf8ca9
4bb3fd9
e7778a0
caa88d9
a4cb713
60424b8
efd15fc
011944c
a7503db
7131c51
da8ae43
3d32477
8b25443
0515ad8
19471ac
6bcd68f
e3c5280
5f55582
0d3ca1a
c1aa753
ac7d4d9
51151af
b7cad5b
030bda0
c07af19
3940467
0a92bc4
86682d2
163687f
4cd1536
0952f29
f3a65d3
7786e37
be2e75f
2169160
bcee242
cb75725
6214b62
5e5b57d
1d988d1
7191fa0
90791a6
0505fc7
6b5fa52
5c838bb
b2700bb
ff23b6d
6539ce1
46a0e10
56dc788
0035409
b2c451f
ed7dae2
27bfe2b
cc94ba8
ad4395d
f3aefd2
5158546
54cffec
ebf552c
da90aac
51a8583
3fbe8d4
e2cfddd
93632df
4ba61f6
f6cfcd6
53cd860
f50323f
2100ade
9e69cc1
b2bcdc9
2ef3457
0994337
458cabb
d2f41ff
52b9910
8427e0b
5e4036a
b2a33fb
a534b6f
eb01912
88d7404
a6c0b3d
913aa52
213823c
cbcf0ec
e222d5e
5dd48a0
2cbb070
0d1dad3
ffe9e8a
caecb71
92a0f0c
7e24347
ac7876a
ad828a4
98346bf
bdeeb71
b4ad42e
13b55be
856169c
3524b8c
6c497dc
3d402a9
db25fdb
0c97449
d6982b9
322b3db
b030c3b
16402d3
34b37d1
53d54ee
f598233
944505b
7641057
f2cab46
dbe4829
e675f72
a4a183b
a68288d
6c2bac1
d293d88
618a33b
5591ded
a093624
8a3440d
7d6923e
6fac1a7
552aef8
ba43346
3a225bd
b67d4db
38d06d5
58b7f77
633d0f5
d5b6d76
8cc107a
ed0df69
a320a40
49c4416
fdfda0c
9c660b5
746ee19
481cdbb
728a18f
3ff8c02
5263be7
aacd9c8
46cbdfc
9a2d102
dfe3d11
7192fa6
8592754
71400ad
eecef97
fc4e74f
75dfbcb
e003b5a
ef5d07a
f1dd4d2
73851f4
7f4fd66
73b4e75
01d0ab6
e01492d
5775587
a5c6415
185bb5a
4ae8a78
544cc29
1f1699e
496ce4d
af11c61
bcfc251
3b5c1ca
70a875f
48edcbb
dff8cba
36064f5
e4c6a49
f9a318d
81d788f
e2a9c7e
5c716c8
88103b5
bb64924
783f757
ab72b99
80db992
b1c10d4
0f0feeb
e61fc38
7e3cded
011fde6
68e4b7a
d6886c9
1b44daa
ff175e1
62634db
4161bad
0e379d2
9696fbb
6a326a7
0d55c54
13a3042
d80d1d8
cc57eca
36da37b
414c8ab
9cd8ab2
aae4649
1b627dc
396084f
ae74cc7
ea082f5
1a3b3cb
4173070
04c247a
79536e1
a2f674a
ce25e31
3c4841c
c987110
5337334
8454028
101a36e
682f592
63855f9
15943e2
3cd63b6
9a5ab87
b1632e9
108e119
4f93a4e
b163e6f
61ae3d6
501c0fd
fb890ba
3fc5a90
5e088ec
b7ef1e2
d8a549e
1513366
ea74f8c
178f1b6
9204928
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,31 @@ | |
from numpy.testing import assert_raises | ||
from numpy.testing import assert_warns | ||
|
||
__all__ = [ | ||
"default_params", | ||
"gsobject_compare", | ||
"printval", | ||
"convertToShear", | ||
"check_basic_x", | ||
"check_basic_k", | ||
"assert_floatlike", | ||
"assert_intlike", | ||
"check_basic", | ||
"do_shoot", | ||
"do_kvalue", | ||
"radial_integrate", | ||
"drawNoise", | ||
"check_pickle", | ||
"check_all_diff", | ||
"timer", | ||
"CaptureLog", | ||
"assert_raises", | ||
"assert_warns", | ||
"Profile", | ||
"galsim_backend", | ||
"is_jax_galsim", | ||
] | ||
|
||
# This file has some helper functions that are used by tests from multiple files to help | ||
# avoid code duplication. | ||
|
||
|
@@ -43,6 +68,18 @@ | |
integration_relerr = 1.e-6, | ||
integration_abserr = 1.e-8) | ||
|
||
|
||
def galsim_backend(): | ||
if "jax_galsim/__init__.py" in galsim.__file__: | ||
return "jax_galsim" | ||
else: | ||
return "galsim" | ||
|
||
|
||
def is_jax_galsim(): | ||
return galsim_backend() == "jax_galsim" | ||
|
||
|
||
def gsobject_compare(obj1, obj2, conv=None, decimal=10): | ||
"""Helper function to check that two GSObjects are equivalent | ||
""" | ||
|
@@ -133,20 +170,32 @@ def check_basic_x(prof, name, approx_maxsb=False, scale=None): | |
np.testing.assert_allclose( | ||
image(i,j), prof._xValue(galsim.PositionD(x,y)), rtol=1.e-5, | ||
err_msg="%s profile sb image does not match _xValue at %d,%d"%(name,i,j)) | ||
assert prof.withFlux.__doc__ == galsim.GSObject.withFlux.__doc__ | ||
assert prof.__class__.withFlux.__doc__ == galsim.GSObject.withFlux.__doc__ | ||
if is_jax_galsim(): | ||
for line in galsim.GSObject.withFlux.__doc__.splitlines(): | ||
if line.strip() and "LAX" not in line: | ||
assert line.strip() in prof.withFlux.__doc__, ( | ||
prof.withFlux.__doc__, galsim.GSObject.withFlux.__doc__, | ||
) | ||
for line in galsim.GSObject.withFlux.__doc__.splitlines(): | ||
if line.strip() and "LAX" not in line: | ||
assert line.strip() in prof.__class__.withFlux.__doc__, ( | ||
prof.__class__.withFlux.__doc__, galsim.GSObject.withFlux.__doc__, | ||
) | ||
else: | ||
assert prof.withFlux.__doc__ == galsim.GSObject.withFlux.__doc__ | ||
assert prof.__class__.withFlux.__doc__ == galsim.GSObject.withFlux.__doc__ | ||
|
||
# Check negative flux: | ||
neg_image = prof.withFlux(-prof.flux).drawImage(method='sb', scale=scale, use_true_center=False) | ||
np.testing.assert_almost_equal(neg_image.array/prof.flux, -image.array/prof.flux, 7, | ||
'%s negative flux drawReal is not negative of +flux image'%name) | ||
np.testing.assert_array_almost_equal(neg_image.array/prof.flux, -image.array/prof.flux, 7, | ||
'%s negative flux drawReal is not negative of +flux image'%name) | ||
|
||
# Direct call to drawReal should also work and be equivalent to the above with scale = 1. | ||
prof.drawImage(image, method='sb', scale=1., use_true_center=False) | ||
image2 = image.copy() | ||
prof.drawReal(image2) | ||
np.testing.assert_equal(image2.array, image.array, | ||
err_msg="%s drawReal not equivalent to drawImage"%name) | ||
np.testing.assert_array_equal(image2.array, image.array, | ||
err_msg="%s drawReal not equivalent to drawImage"%name) | ||
|
||
# If supposed to be axisymmetric, make sure it is. | ||
if prof.is_axisymmetric: | ||
|
@@ -194,7 +243,7 @@ def check_basic_k(prof, name): | |
|
||
# Check negative flux: | ||
neg_image = prof.withFlux(-prof.flux).drawKImage(kimage.copy()) | ||
np.testing.assert_almost_equal(neg_image.array/prof.flux, -kimage.array/prof.flux, 7, | ||
np.testing.assert_array_almost_equal(neg_image.array/prof.flux, -kimage.array/prof.flux, 7, | ||
'%s negative flux drawK is not negative of +flux image'%name) | ||
|
||
# If supposed to be axisymmetric, make sure it is in the kValues. | ||
|
@@ -206,6 +255,30 @@ def check_basic_k(prof, name): | |
np.testing.assert_allclose(test_values, ref_value, rtol=1.e-5, | ||
err_msg="%s profile not axisymmetric in kValues"%name) | ||
|
||
def assert_floatlike(val): | ||
assert ( | ||
isinstance(val, float) | ||
or ( | ||
is_jax_galsim() | ||
and hasattr(val, "shape") | ||
and val.shape == () | ||
and hasattr(val, "dtype") | ||
and val.dtype.name in ["float", "float32", "float64"] | ||
) | ||
), "Value is not float-like: type(%r) = %r" % (val, type(val)) | ||
|
||
def assert_intlike(val): | ||
assert ( | ||
isinstance(val, int) | ||
or ( | ||
is_jax_galsim() | ||
and hasattr(val, "shape") | ||
and val.shape == () | ||
and hasattr(val, "dtype") | ||
and val.dtype.name in ["int", "int32", "int64"] | ||
) | ||
), "Value is not int-like: type(%r) = %r" % (val, type(val)) | ||
|
||
def check_basic(prof, name, approx_maxsb=False, scale=None, do_x=True, do_k=True): | ||
"""Do some basic sanity checks that should work for all profiles. | ||
""" | ||
|
@@ -220,12 +293,12 @@ def check_basic(prof, name, approx_maxsb=False, scale=None, do_x=True, do_k=True | |
prof.positive_flux - prof.negative_flux, prof.flux, | ||
err_msg="%s profile flux not equal to posflux + negflux"%name) | ||
assert isinstance(prof.centroid, galsim.PositionD) | ||
assert isinstance(prof.flux, float) | ||
assert isinstance(prof.positive_flux, float) | ||
assert isinstance(prof.negative_flux, float) | ||
assert isinstance(prof.max_sb, float) | ||
assert isinstance(prof.stepk, float) | ||
assert isinstance(prof.maxk, float) | ||
assert_floatlike(prof.flux) | ||
assert_floatlike(prof.positive_flux) | ||
assert_floatlike(prof.negative_flux) | ||
assert_floatlike(prof.max_sb) | ||
assert_floatlike(prof.stepk) | ||
assert_floatlike(prof.maxk) | ||
assert isinstance(prof.has_hard_edges, bool) | ||
assert isinstance(prof.is_axisymmetric, bool) | ||
assert isinstance(prof.is_analytic_x, bool) | ||
|
@@ -298,6 +371,9 @@ def do_shoot(prof, img, name): | |
print('nphot = ',nphot) | ||
img2 = img.copy() | ||
|
||
if is_jax_galsim(): | ||
rtol *= 3 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this a sign of a problem with the jax-galsim implementation? |
||
|
||
# Use a deterministic random number generator so we don't fail tests because of rare flukes | ||
# in the random numbers. | ||
rng = galsim.UniformDeviate(12345) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
# | ||
|
||
import numpy | ||
import numpy as np | ||
import os | ||
import math | ||
|
||
|
@@ -122,11 +123,11 @@ def test_angle(): | |
|
||
# Check invalid constructors | ||
assert_raises(TypeError,galsim.AngleUnit, galsim.degrees) | ||
assert_raises(ValueError,galsim.AngleUnit, 'spam') | ||
assert_raises((ValueError, TypeError), galsim.AngleUnit, 'spam') | ||
assert_raises(TypeError,galsim.AngleUnit, 1, 3) | ||
assert_raises(TypeError,galsim.Angle, 3.4) | ||
assert_raises(TypeError,galsim.Angle, theta1, galsim.degrees) | ||
assert_raises(ValueError,galsim.Angle, 'spam', galsim.degrees) | ||
assert_raises((ValueError, TypeError), galsim.Angle, 'spam') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this should remove the last arg, |
||
assert_raises(TypeError,galsim.Angle, 1, 3) | ||
|
||
|
||
|
@@ -155,7 +156,7 @@ def test_celestialcoord_basic(): | |
|
||
x, y, z = c1.get_xyz() | ||
print('c1 is at x,y,z = ',x,y,z) | ||
np.testing.assert_equal((x,y,z), (1,0,0)) | ||
np.testing.assert_array_equal((x,y,z), (1,0,0)) | ||
assert c1 == galsim.CelestialCoord.from_xyz(x,y,z) | ||
|
||
x, y, z = c2.get_xyz() | ||
|
@@ -343,9 +344,19 @@ def test_projection(): | |
|
||
# First the trivial case | ||
p0 = center.project(center, projection='lambert') | ||
assert p0 == (0.0 * galsim.arcsec, 0.0 * galsim.arcsec) | ||
np.testing.assert_allclose( | ||
(p0[0].rad, p0[1].rad), | ||
(0.0, 0.0), | ||
rtol=0, | ||
atol=1e-16, | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What computation are you doing in Jax-Galsim that makes this not be exactly zero. This should have been trivially true I would think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, the line bloat here seems rather gratuitous. Can we put each of these in a single line? |
||
c0 = center.deproject(*p0, projection='lambert') | ||
assert c0 == center | ||
np.testing.assert_allclose( | ||
c0.rad, | ||
center.rad, | ||
beckermr marked this conversation as resolved.
Show resolved
Hide resolved
|
||
rtol=0, | ||
atol=1e-16, | ||
) | ||
np.testing.assert_almost_equal(center.jac_deproject(*p0, projection='lambert').ravel(), | ||
(1,0,0,1)) | ||
|
||
|
@@ -398,9 +409,19 @@ def test_projection(): | |
|
||
# First the trivial case | ||
p0 = center.project(center, projection='stereographic') | ||
assert p0 == (0.0 * galsim.arcsec, 0.0 * galsim.arcsec) | ||
np.testing.assert_allclose( | ||
(p0[0].rad, p0[1].rad), | ||
(0.0, 0.0), | ||
rtol=0, | ||
atol=1e-16, | ||
) | ||
c0 = center.deproject(*p0, projection='stereographic') | ||
assert c0 == center | ||
np.testing.assert_allclose( | ||
c0.rad, | ||
center.rad, | ||
beckermr marked this conversation as resolved.
Show resolved
Hide resolved
|
||
rtol=0, | ||
atol=1e-16, | ||
) | ||
np.testing.assert_almost_equal(center.jac_deproject(*p0, projection='stereographic').ravel(), | ||
(1,0,0,1)) | ||
|
||
|
@@ -456,9 +477,19 @@ def test_projection(): | |
|
||
# First the trivial case | ||
p0 = center.project(center, projection='gnomonic') | ||
assert p0 == (0.0 * galsim.arcsec, 0.0 * galsim.arcsec) | ||
np.testing.assert_allclose( | ||
(p0[0].rad, p0[1].rad), | ||
(0.0, 0.0), | ||
rtol=0, | ||
atol=1e-16, | ||
) | ||
c0 = center.deproject(*p0, projection='gnomonic') | ||
assert c0 == center | ||
np.testing.assert_allclose( | ||
c0.rad, | ||
center.rad, | ||
beckermr marked this conversation as resolved.
Show resolved
Hide resolved
|
||
rtol=0, | ||
atol=1e-16, | ||
) | ||
np.testing.assert_almost_equal(center.jac_deproject(*p0, projection='gnomonic').ravel(), | ||
(1,0,0,1)) | ||
|
||
|
@@ -510,9 +541,19 @@ def test_projection(): | |
|
||
# First the trivial case | ||
p0 = center.project(center, projection='postel') | ||
assert p0 == (0.0 * galsim.arcsec, 0.0 * galsim.arcsec) | ||
np.testing.assert_allclose( | ||
(p0[0].rad, p0[1].rad), | ||
(0.0, 0.0), | ||
rtol=0, | ||
atol=1e-16, | ||
) | ||
c0 = center.deproject(*p0, projection='postel') | ||
assert c0 == center | ||
np.testing.assert_allclose( | ||
c0.rad, | ||
center.rad, | ||
beckermr marked this conversation as resolved.
Show resolved
Hide resolved
|
||
rtol=0, | ||
atol=1e-16, | ||
) | ||
np.testing.assert_almost_equal(center.jac_deproject(*p0, projection='postel').ravel(), | ||
(1,0,0,1)) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need this? I thought they were equivalent. I'd rather not use the more verbose one if we can avoid it.