Skip to content

Commit

Permalink
Merge pull request #111 from BDNYC/joe_tests
Browse files Browse the repository at this point in the history
Tests Galore
  • Loading branch information
hover2pi authored Sep 20, 2024
2 parents 590f23c + 2a5d046 commit 4d43733
Show file tree
Hide file tree
Showing 12 changed files with 66 additions and 47 deletions.
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@ or via `conda` with
```
git clone https://github.com/hover2pi/sedkit.git
cd sedkit
conda env create -f env/environment-<python_version>.yml --force
conda env create -f env/environment-3.11.yml --force
conda activate sedkit
python setup.py install
```

where `<python_version>` is `3.7` or `3.8`.

## Demo

An SED can be constructed by importing and initializing an `SED` object.
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ dependencies = [
"numpy>=1.25.1",
"pandas>=1.3.5",
"scipy>=1.8.0",
"svo-filters>=0.4.4"
"svo-filters>=0.4.4",
"importlib-resources"
]
dynamic = ["version"]

Expand Down
Binary file modified sedkit/data/models/atmospheric/btsettl/index.p
Binary file not shown.
Binary file modified sedkit/data/models/atmospheric/spexprismlibrary/index.p
Binary file not shown.
16 changes: 7 additions & 9 deletions sedkit/modelgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@ def __init__(self, name, parameters, wave_units=None, flux_units=None,
setattr(self, key, val)

# Make the empty table
columns = self.parameters+['filepath', 'spectrum', 'label']
columns = self.parameters+['filepath', 'spectrum', 'label', 'weights']
self.index = pd.DataFrame(columns=columns)

def add_model(self, spectrum, label=None, filepath=None, **kwargs):
def add_model(self, spectrum, label=None, filepath=None, weights=1, **kwargs):
"""Add the given model with the specified parameter values as kwargs
Parameters
Expand All @@ -214,11 +214,10 @@ def add_model(self, spectrum, label=None, filepath=None, **kwargs):
raise ValueError("Must have kwargs for", self.parameters)

# Make the dictionary of new data
kwargs.update({'spectrum': spectrum, 'filepath': filepath, 'label': label})
new_rec = pd.DataFrame({k: [v] for k, v in kwargs.items()})
kwargs.update({'spectrum': spectrum, 'filepath': filepath, 'label': label, 'weights': weights})
new_rec = {k: kwargs[k] for k in kwargs.keys()}

# Add it to the index
print(self.index.columns, new_rec.keys())
self.index.loc[len(self.index)] = new_rec

@staticmethod
Expand Down Expand Up @@ -365,7 +364,7 @@ def index_models(self, parameters=None, wl_min=0.3*q.um, wl_max=25*q.um):

# Update attributes
if parameters is None:
parameters = [col for col in self.index.columns if col not in ['filepath', 'spectrum', 'label']]
parameters = [col for col in self.index.columns if col not in ['filepath', 'spectrum', 'label', 'weights']]
self.parameters = parameters

def _in_range(self, **kwargs):
Expand Down Expand Up @@ -575,9 +574,7 @@ def photometry(self, bandpasses, weight=False):
"""
# Copy the ModelGrid and empty the index
phot = ModelGrid(name=self.name, parameters=self.parameters)
dic = copy(self.__dict__)
dic['index'] = dic['index'][:0]
phot.__dict__ = dic
phot.__dict__.update({k:v for k, v in self.__dict__.items() if k != 'index'})

# Iterate over the rows
for n, row in self.index.iterrows():
Expand Down Expand Up @@ -698,6 +695,7 @@ def resample_grid(self, name=None, **kwargs):
pars['label'] = spec.name
pars['spectrum'] = spec.data
pars['filepath'] = 'interp'
pars['weights'] = 1

# Add line to the new dataframe
new_index = pd.concat([new_index, pd.DataFrame([pars])], ignore_index=True)
Expand Down
21 changes: 13 additions & 8 deletions sedkit/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,14 @@ def evaluate(self, rel_name, x_val, xunits=None, yunits=None, fit_local=False, p

if plot:
plt = self.plot(rel_name, xunits=xunits, yunits=yunits)
plt.circle([x_val.value if hasattr(x_val, 'unit') else x_val], [y_val.value if hasattr(y_val, 'unit') else y_val], color='red', size=10, legend_label='{}({})'.format(rel['yparam'], x_val))
if y_upper:
plt.line([x_val, x_val], [y_val - y_lower, y_val + y_upper], color='red')
xv = x_val.value if hasattr(x_val, 'unit') else x_val
yv = y_val.value if hasattr(y_val, 'unit') else y_val
plt.scatter([xv], [yv], color='red', size=10, legend_label='{}({})'.format(rel['yparam'], x_val))
print(y_val, y_upper, y_lower)
if y_upper is not None:
yvl = y_lower.value if hasattr(y_lower, 'unit') else y_lower
yvu = y_upper.value if hasattr(y_upper, 'unit') else y_upper
plt.line([xv, xv], [yv - yvl, yv + yvu], color='red')
show(plt)

# Restore full relation
Expand Down Expand Up @@ -362,7 +367,7 @@ def plot(self, rel_name, xunits=None, yunits=None, **kwargs):
fig.yaxis.axis_label = '{}{}'.format(yparam, '[{}]'.format(yunits or rel['yunit']))

# Draw points
fig.circle(x * xu, y * yu, legend_label='Data', **kwargs)
fig.scatter(x * xu, y * yu, legend_label='Data', **kwargs)

return fig

Expand Down Expand Up @@ -502,10 +507,10 @@ def generate(self, orders):
# ====================================================================

# Get the data
cat1 = V.query_constraints('J/ApJ/810/158/table1')[0]
cat2 = V.query_constraints('J/ApJ/810/158/table9')[0]
cat1 = V.query_constraints(catalog='J/ApJ/810/158/table1')[0]
cat2 = V.query_constraints(catalog='J/ApJ/810/158/table9')[0]

# Join the tables to getthe spectral types and radii in one table
# Join the tables to get the spectral types and radii in one table
mlty_data = at.join(cat1, cat2, keys='ID', join_type='outer')

# Only keep field age
Expand Down Expand Up @@ -619,7 +624,7 @@ def plot(self, draw=False):

# Add the data
if n == 0:
fig.circle(data['data']['spt'], data['data']['radius'], size=8,
fig.scatter(data['data']['spt'], data['data']['radius'], size=8,
color=color, legend_label=data['ref'])
else:
fig.square(data['data']['spt'], data['data']['radius'], size=8,
Expand Down
28 changes: 18 additions & 10 deletions sedkit/sed.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,17 +856,20 @@ def compare_model(self, modelgrid, fit_to='spec', rebin=True, **kwargs):

if spec is not None:

if rebin and fit_to == 'spec':
model = model.resamp(spec.spectrum[0])
# If you want to resample, you need to trim here
# if rebin and fit_to == 'spec':
# model = model.resamp(spec.spectrum[0])
#
# # Fit the model to the SED
# gstat, yn, xn = list(spec.fit(model, wave_units='AA'))
# wave = model.wave * xn
# flux = model.flux * yn

# Fit the model to the SED
gstat, yn, xn = list(spec.fit(model, wave_units='AA'))
wave = model.wave * xn
flux = model.flux * yn
normed = model.norm_to_spec(spec)

# Plot the SED with the model on top
fig = self.plot(output=True)
fig.line(wave, flux)
fig.line(normed.wave, normed.flux, color='red')

show(fig)

Expand Down Expand Up @@ -1262,10 +1265,14 @@ def find_SDSS_spectra(self, surveys=['optical', 'apogee'], search_radius=None, *
"""
Search for SDSS spectra
"""
# Manual or parent radius
if search_radius is None:
search_radius = self.search_radius

if 'optical' in surveys:

# Query spectra
data, ref, header = qu.query_SDSS_optical_spectra(target=self.name, sky_coords=self.sky_coords, verbose=self.verbose, radius=search_radius or self.search_radius, **kwargs)
data, ref, header = qu.query_SDSS_optical_spectra(target=self.name, sky_coords=self.sky_coords, verbose=self.verbose, radius=search_radius, **kwargs)

# Add the spectrum to the SED
if data is not None:
Expand All @@ -1277,7 +1284,7 @@ def find_SDSS_spectra(self, surveys=['optical', 'apogee'], search_radius=None, *
if 'apogee' in surveys:

# Query spectra
data, ref, header = qu.query_SDSS_apogee_spectra(target=self.name, sky_coords=self.sky_coords, verbose=self.verbose, search_radius=search_radius or self.search_radius, **kwargs)
data, ref, header = qu.query_SDSS_apogee_spectra(target=self.name, sky_coords=self.sky_coords, verbose=self.verbose, search_radius=search_radius, **kwargs)

# Add the spectrum to the SED
if data is not None:
Expand Down Expand Up @@ -1916,7 +1923,8 @@ def infer_radius(self, radius_units=q.Rsun, infer_from=None, plot=False):
if len(infer_froms) > 0:
infer_from = infer_froms[0]
if infer_from not in infer_froms or infer_from is None:
raise ValueError("{}: Please choose valid relation to infer the radius. Try {}".format('None' if infer_from is None else infer_from, infer_froms))
self.message('Could not calculate radius without spectral_type, Lbol, M_2MASS.J, or M_2MASS.Ks')
# raise ValueError("{}: Please choose valid relation to infer the radius. Try {}".format('None' if infer_from is None else infer_from, infer_froms))

# Try model isochrones
if infer_from == 'evo_model':
Expand Down
6 changes: 6 additions & 0 deletions sedkit/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,10 @@ def norm_to_spec(self, spec, add=False, plot=False, **kwargs):
spec0 = slf.data
spec1 = spec.data[:, idx]

# Fix shape mismatch
if len(spec0[0]) < len(spec1[0]):
spec1 = spec1[:, 1:]

# Find the normalization factor
norm = u.minimize_norm(spec1[1], spec0[1], **kwargs)

Expand Down Expand Up @@ -1053,7 +1057,9 @@ def resamp(self, wave=None, resolution=None, name=None):
wave = wave.value

# Bin the spectrum
print(wave, self.wave)
binned = u.spectres(wave, self.wave, self.flux, self.unc)
print(binned[0])

# Update the spectrum
spectrum = [i * Q for i, Q in zip(binned, self.units)]
Expand Down
8 changes: 5 additions & 3 deletions sedkit/tests/test_modelgrid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
import os
from pkg_resources import resource_filename
import importlib_resources

import astropy.units as q

Expand All @@ -15,7 +15,9 @@ def setUp(self):
# Make Model class for testing
params = ['spty']
grid = mg.ModelGrid('Test', params, q.AA, q.erg/q.s/q.cm**2/q.AA)
path = resource_filename('sedkit', 'data/models/atmospheric/spexprismlibrary')

model_path = 'data/models/atmospheric/spexprismlibrary'
path = importlib_resources.files('sedkit')/ model_path

# Delete the pickle so the models need to be indexed
os.remove(os.path.join(path, 'index.p'))
Expand Down Expand Up @@ -72,7 +74,7 @@ def test_load_model():
"""Test the load_model function"""
# Get the XML file
path = 'data/models/atmospheric/spexprismlibrary/spex-prism_2MASPJ0345432+254023_20030905_BUR06B.txt.xml'
filepath = resource_filename('sedkit', path)
filepath = importlib_resources.files('sedkit') / path

# Load the model
meta = mg.load_model(filepath)
Expand Down
4 changes: 2 additions & 2 deletions sedkit/tests/test_relations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from pkg_resources import resource_filename
import importlib_resources
import unittest

import astropy.units as q
Expand Down Expand Up @@ -50,7 +50,7 @@ class TestRelation(unittest.TestCase):
"""Tests for the Relation base class"""
def setUp(self):
# Set the file
self.file = resource_filename('sedkit', 'data/dwarf_sequence.txt')
self.file = str(importlib_resources.files('sedkit')/ 'data/dwarf_sequence.txt')

def test_init(self):
"""Test class initialization"""
Expand Down
10 changes: 6 additions & 4 deletions sedkit/tests/test_sed.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
import copy
from pkg_resources import resource_filename
import importlib_resources

import numpy as np
import astropy.units as q
Expand Down Expand Up @@ -48,7 +48,7 @@ def test_add_photometry_table(self):
s = copy.copy(self.sed)

# Add the photometry
f = resource_filename('sedkit', 'data/L3_photometry.txt')
f = str(importlib_resources.files('sedkit')/ 'data/L3_photometry.txt')
s.add_photometry_table(f)
self.assertEqual(len(s.photometry), 8)

Expand Down Expand Up @@ -173,7 +173,7 @@ def test_compare_model(self):
"""Test for the compare_model method"""
v = sed.VegaSED()
bt = mg.BTSettl()
v.compare_model(bt, teff=3000)
v.compare_model(bt, teff=10000)

def test_plot(self):
"""Test plotting method"""
Expand Down Expand Up @@ -225,6 +225,8 @@ def test_find_SDSS_spectra(self):
s = sed.SED()
s.sky_coords = SkyCoord('0h8m05.63s +14d50m23.3s', frame='icrs')
s.find_SDSS_spectra(search_radius=20 * q.arcsec)
s.find_SDSS()
s.plot()

def test_run_methods(self):
"""Test that the method_list argument works"""
Expand Down Expand Up @@ -279,7 +281,7 @@ def test_fit_modelgrid(self):
# s.add_spectrum(self.spec1)
#
# # Add photometry
# f = resource_filename('sedkit', 'data/L3_photometry.txt')
# f = str(importlib_resources.files('sedkit')/ 'data/L3_photometry.txt')
# s.add_photometry_file(f)
#
# # Fit with SPL
Expand Down
13 changes: 6 additions & 7 deletions sedkit/tests/test_spectrum.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
"""A suite of tests for the spectrum.py module"""
import unittest
import copy
from pkg_resources import resource_filename
import importlib_resources

import numpy as np
import astropy.units as q
from svo_filters import Filter

from sedkit import modelgrid as mg
from sedkit import spectrum as sp
from sedkit import utilities as u


class TestSpectrum(unittest.TestCase):
Expand Down Expand Up @@ -89,7 +88,7 @@ def test_model_fit(self):

# Test MCMC fit
bt = mg.BTSettl()
spec = bt.get_spectrum(teff=2456, logg=5.5, meta=0, alpha=0, snr=100)
spec = bt.get_spectrum(teff=2456, logg=5., meta=0, alpha=0, snr=100)
spec.mcmc_fit(bt, name='Test', report=True)

def test_addition(self):
Expand Down Expand Up @@ -152,7 +151,7 @@ def test_interpolate(self):
spec2 = self.flat2.interpolate(spec1)

# Check wavelength is updated
self.assertTrue(np.all(spec1.wave == spec2.wave))
self.assertTrue(set(spec2.wave).issubset(set(spec1.wave)))

def test_renormalize(self):
"""Test that a spectrum is properly normalized to a given magnitude"""
Expand Down Expand Up @@ -206,7 +205,7 @@ def test_synthetic_mag(self):
self.assertIsInstance(mag_unc, float)

# Test flux
flx, flx_unc = s1.synthetic_flux(filt, plot=True)
flx, flx_unc = s1.synthetic_flux(filt, plot=False)
self.assertIsInstance(flx, q.quantity.Quantity)
self.assertIsInstance(flx_unc, q.quantity.Quantity)

Expand Down Expand Up @@ -258,8 +257,8 @@ class TestFileSpectrum(unittest.TestCase):
def setUp(self):
"""Setup the tests"""
# Files for testing
self.fitsfile = resource_filename('sedkit', 'data/Trappist-1_NIR.fits')
self.txtfile = resource_filename('sedkit', 'data/STScI_Vega.txt')
self.fitsfile = str(importlib_resources.files('sedkit')/ 'data/Trappist-1_NIR.fits')
self.txtfile = str(importlib_resources.files('sedkit')/ 'data/STScI_Vega.txt')

def test_fits(self):
"""Test that a fits file can be loaded"""
Expand Down

0 comments on commit 4d43733

Please sign in to comment.