Skip to content

Commit

Permalink
Use check_spectrum_plottable from astrodb_utils (#548)
Browse files Browse the repository at this point in the history
* remove `spectrum_plottable` function
* modify tests for new check_spectrum_plottable
  • Loading branch information
kelle authored Aug 2, 2024
1 parent 117aff1 commit 3e70973
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 203 deletions.
83 changes: 3 additions & 80 deletions simple/utils/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,22 @@
import sqlite3
from typing import Optional

import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
import requests
import sqlalchemy.exc
from astrodb_utils import (
AstroDBError,
find_source_in_db,
internet_connection,
)
from astrodb_utils.spectra import check_spectrum_plottable
from astrodbkit2.astrodb import Database
from astropy.io import fits
from specutils import Spectrum1D

from simple.schema import Spectra

__all__ = [
"ingest_spectrum",
"ingest_spectrum_from_fits",
"spectrum_plottable",
"find_spectra",
]

Expand Down Expand Up @@ -162,7 +158,7 @@ def ingest_spectrum(
return flags

# Check if spectrum is plottable
flags["plottable"] = spectrum_plottable(spectrum, raise_error=raise_error)
flags["plottable"] = check_spectrum_plottable(spectrum, raise_error=raise_error)

# Compile fields into a dictionary
row_data = {
Expand All @@ -184,7 +180,7 @@ def ingest_spectrum(

try:
# Attempt to add spectrum to database
# This will throw errors based on validation in schema.py
# This will throw errors based on validation in schema.py
# and any database checks (as for example IntegrityError)
obj = Spectra(**row_data)
with db.session as session:
Expand Down Expand Up @@ -265,79 +261,6 @@ def ingest_spectrum_from_fits(db, source, spectrum_fits_file):
)


def spectrum_plottable(spectrum_path, raise_error=True, show_plot=False):
"""
Check if spectrum is plottable
"""
# load the spectrum and make sure it's a Spectrum1D object

try:
# spectrum: Spectrum1D = load_spectrum(spectrum_path) #astrodbkit2 method
spectrum = Spectrum1D.read(spectrum_path)
except Exception as e:
msg = (
str(e) + f"\nSkipping {spectrum_path}: \n"
"unable to load file as Spectrum1D object"
)
if raise_error:
logger.error(msg)
raise AstroDBError(msg)
else:
logger.warning(msg)
return False

# checking spectrum has good units and not only NaNs
try:
wave: np.ndarray = spectrum.spectral_axis.to(u.micron).value
flux: np.ndarray = spectrum.flux.value
except AttributeError as e:
msg = str(e) + f"Skipping {spectrum_path}: unable to parse spectral axis"
if raise_error:
logger.error(msg)
raise AstroDBError(msg)
else:
logger.warning(msg)
return False
except u.UnitConversionError as e:
msg = (
f"{e} \n"
f"Skipping {spectrum_path}: unable to convert spectral axis to microns"
)
if raise_error:
logger.error(msg)
raise AstroDBError(msg)
else:
logger.warning(msg)
return False
except ValueError as e:
msg = f"{e} \nSkipping {spectrum_path}: Value error"
if raise_error:
logger.error(msg)
raise AstroDBError(msg)
else:
logger.warning(msg)
return False

# check for NaNs
nan_check: np.ndarray = ~np.isnan(flux) & ~np.isnan(wave)
wave = wave[nan_check]
flux = flux[nan_check]
if not len(wave):
msg = f"Skipping {spectrum_path}: spectrum is all NaNs"
if raise_error:
logger.error(msg)
raise AstroDBError(msg)
else:
logger.warning(msg)
return False

if show_plot:
plt.plot(wave, flux)
plt.show()

return True


def find_spectra(
db: Database,
source: str,
Expand Down
217 changes: 94 additions & 123 deletions tests/test_spectra_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,83 +7,111 @@
from simple.utils.spectra import (
ingest_spectrum,
# ingest_spectrum_from_fits,
spectrum_plottable,
)


@pytest.mark.filterwarnings(
"ignore", message=".*Note: astropy.io.fits uses zero-based indexing.*"
"ignore",
message=".*SAWarning: Column 'Spectra.reference' is marked as a member of the primary key for table 'Spectra'.*",
)
@pytest.mark.filterwarnings(
"ignore", message=".*'datfix' made the change 'Set MJD-OBS to.*"
"ignore", message=".*'kiwi': No known catalog could be found.*"
)
@pytest.mark.filterwarnings(
"ignore",
message=(
".*'erg/cm2/s/A' contains multiple slashes, "
"which is discouraged by the FITS standard.*",
),
@pytest.mark.parametrize(
"test_input, message",
[
(
{
"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
},
"Value required for regime",
), # missing regime
(
{
"source": "apple",
"regime": "nir",
"instrument": "SpeX",
"obs_date": "2020-01-01",
},
"Value required for telescope",
), # missing telescope
(
{
"source": "apple",
"regime": "nir",
"telescope": "IRTF",
"obs_date": "2020-01-01",
},
"Value required for instrument",
), # missing instrument
(
{
"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "nir",
"obs_date": "2020-01-01",
},
"NOT NULL constraint failed: Spectra.reference",
), # missing reference
(
{
"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "nir",
"obs_date": "2020-01-01",
"reference": "Ref 5",
},
"FOREIGN KEY constraint failed",
), # invalid reference
(
{
"source": "kiwi",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "nir",
"obs_date": "2020-01-01",
"reference": "Ref 1",
},
"No unique source match for kiwi in the database",
), # invalid source
(
{
"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "nir",
"reference": "Ref 1",
},
"Invalid date received: None",
), # missing date
(
{
"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "fake regime",
"obs_date": "2020-01-01",
"reference": "Ref 1",
},
"FOREIGN KEY constraint failed",
), # invalid regime
],
)
@pytest.mark.parametrize("test_input, message", [
({"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
}, "Value required for regime"), # missing regime
({"source": "apple",
"regime": "nir",
"instrument": "SpeX",
"obs_date": "2020-01-01",
}, "Value required for telescope"), # missing telescope
({"source": "apple",
"regime": "nir",
"telescope": "IRTF",
"obs_date": "2020-01-01",
}, "Value required for instrument"), # missing instrument
({"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "nir",
"obs_date": "2020-01-01",
}, "NOT NULL constraint failed: Spectra.reference"), # missing reference
({"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "nir",
"obs_date": "2020-01-01",
"reference": "Ref 5",
}, "FOREIGN KEY constraint failed"), # invalid reference
({"source": "kiwi",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "nir",
"obs_date": "2020-01-01",
"reference": "Ref 1",
}, "No unique source match for kiwi in the database"), # invalid source
({"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "nir",
"reference": "Ref 1",
}, "Invalid date received: None"), # missing date
({"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "fake regime",
"obs_date": "2020-01-01",
"reference": "Ref 1",
}, "FOREIGN KEY constraint failed"), # invalid regime
])
def test_ingest_spectrum_errors(temp_db, test_input, message):
# Test for ingest_spectrum that is expected to return errors

# Prepare parameters to send to ingest_spectrum
spectrum = "https://bdnyc.s3.amazonaws.com/tests/U10176.fits"
spectrum = "https://bdnyc.s3.amazonaws.com/IRS/2MASS+J03552337%2B1133437.fits"
parameters = {"db": temp_db, "spectrum": spectrum}
parameters.update(test_input)

Expand All @@ -98,19 +126,8 @@ def test_ingest_spectrum_errors(temp_db, test_input, message):
assert message in result["message"]


@pytest.mark.filterwarnings("ignore:Verification")
@pytest.mark.filterwarnings("ignore", message=".*Card 'AIRMASS' is not FITS standard.*")
@pytest.mark.filterwarnings(
"ignore:Note"
) # : astropy.io.fits uses zero-based indexing.
@pytest.mark.filterwarnings("ignore:'datfix' made the change 'Set MJD-OBS to")
@pytest.mark.filterwarnings(
"ignore:'erg/cm2/s/A' contains multiple slashes,"
" which is discouraged by the FITS standard"
)
@pytest.mark.filterwarnings("ignore")
def test_ingest_spectrum_works(temp_db):
spectrum = "https://bdnyc.s3.amazonaws.com/tests/U10176.fits"
spectrum = "https://bdnyc.s3.amazonaws.com/IRS/2MASS+J03552337%2B1133437.fits"
result = ingest_spectrum(
temp_db,
source="banana",
Expand All @@ -123,49 +140,3 @@ def test_ingest_spectrum_works(temp_db):
mode="Prism",
)
assert result["added"] is True


@pytest.mark.filterwarnings("ignore:Invalid 'BLANK' keyword in header.")
@pytest.mark.filterwarnings("ignore:'datfix' made the change 'Set MJD-OBS to")
@pytest.mark.filterwarnings("ignore:The WCS transformation has more axes")
@pytest.mark.filterwarnings("ignore:'cdfix' made the change 'Success'")
@pytest.mark.filterwarnings("ignore:MJD-OBS =")
@pytest.mark.filterwarnings(
"ignore",
message=(
"'erg/cm2/s/A' contains multiple slashes, "
"which is discouraged by the FITS standard.*",
),
)
@pytest.mark.filterwarnings("ignore")
@pytest.mark.parametrize(
"file",
[
"https://s3.amazonaws.com/bdnyc/optical_spectra/2MASS1538-1953_tell.fits",
"https://s3.amazonaws.com/bdnyc/spex_prism_lhs3003_080729.txt",
"https://bdnyc.s3.amazonaws.com/IRS/2351-2537_IRS_spectrum.dat",
],
)
def test_spectrum_plottable_false(file):
with pytest.raises(AstroDBError) as error_message:
spectrum_plottable(file)
assert "unable to load file as Spectrum1D object" in str(error_message.value)

result = spectrum_plottable(file, raise_error=False)
assert result is False


@pytest.mark.parametrize(
"file",
[
(
"https://bdnyc.s3.amazonaws.com/SpeX/Prism/"
"2MASS+J04510093-3402150_2012-09-27.fits"
),
"https://bdnyc.s3.amazonaws.com/IRS/2MASS+J23515044-2537367.fits",
"https://bdnyc.s3.amazonaws.com/optical_spectra/vhs1256b_opt_Osiris.fits",
],
)
def test_spectrum_plottable_true(file):
result = spectrum_plottable(file)
assert result is True

0 comments on commit 3e70973

Please sign in to comment.