From 6df3d0e9e49363626189962d7638bef3a44ca4f6 Mon Sep 17 00:00:00 2001 From: David Rodriguez Date: Tue, 16 Jul 2024 15:38:01 -0400 Subject: [PATCH] Reworking how ingest_spectrum works using the ORM (#501) * First pass at reworking how ingest_spectrum works using the ORM * Fixing some issues in unit tests * Partial implementation of new flags response * Iterating on error message handling for ingest_spectrum * Reworking tests to be more compact * Expanding validation to check for missing telescope, instrument, and mode * Minor cleanup * Apply suggestions from code review Co-authored-by: Kelle Cruz * Update simple/utils/spectra.py Co-authored-by: Kelle Cruz * Adding warning when converting date into ISO format --------- Co-authored-by: Kelle Cruz --- simple/schema.py | 38 ++++-- simple/utils/spectra.py | 246 ++++++------------------------------ tests/test_schema.py | 25 +++- tests/test_spectra_utils.py | 222 ++++++++++---------------------- 4 files changed, 160 insertions(+), 371 deletions(-) diff --git a/simple/schema.py b/simple/schema.py index 30c1a8ca8..48fb29a1a 100644 --- a/simple/schema.py +++ b/simple/schema.py @@ -3,6 +3,7 @@ """ import enum +from datetime import datetime import sqlalchemy as sa from astrodbkit2.astrodb import Base @@ -344,28 +345,30 @@ class Gravities(Base): class Spectra(Base): # Table to store references to spectra __tablename__ = 'Spectra' + source = Column( String(100), ForeignKey("Sources.source", ondelete="cascade", onupdate="cascade"), nullable=False, primary_key=True, ) + # Data access_url = Column(String(1000), nullable=False) # URL of spectrum location - original_spectrum = Column( - String(1000) - ) # URL of original spectrum location, if applicable - local_spectrum = Column( - String(1000) - ) # local directory (via environment variable) of spectrum location + + # URL of original spectrum location, if applicable + original_spectrum = Column(String(1000)) + # local directory (via environment variable) of spectrum location + local_spectrum = Column(String(1000)) + regime = Column( String(30), ForeignKey("Regimes.regime", ondelete="cascade", onupdate="cascade"), primary_key=True, ) - telescope = Column(String(30)) - instrument = Column(String(30)) - mode = Column(String(30)) # eg, Prism, Echelle, etc + telescope = Column(String(30), nullable=False) + instrument = Column(String(30), nullable=False) + mode = Column(String(30), nullable=False) # eg, Prism, Echelle, etc observation_date = Column(DateTime, primary_key=True) # Common metadata @@ -387,6 +390,23 @@ class Spectra(Base): {}, ) + @validates("access_url", "regime", "source", "telescope", "instrument", "mode") + def validate_required(self, key, value): + if value is None: + raise ValueError(f"Value required for {key}") + return value + + @validates("observation_date") + def validate_date(self, key, value): + if value is None: + raise ValueError(f"Invalid date received: {value}") + elif not isinstance(value, datetime): + # Convert to datetime for storing in the database + # Will throw error if unable to convert + print("WARNING: Value will be converted to ISO format.") + value = datetime.fromisoformat(value) + return value + class ModeledParameters(Base): # Table to store derived/inferred paramaters from models diff --git a/simple/utils/spectra.py b/simple/utils/spectra.py index 8090bdb1a..78018af66 100644 --- a/simple/utils/spectra.py +++ b/simple/utils/spectra.py @@ -1,25 +1,22 @@ import logging -import requests -import numpy.ma as ma -import pandas as pd # used for to_datetime conversion -import dateutil # used to convert obs date to datetime object -import sqlalchemy.exc import sqlite3 -import numpy as np from typing import Optional -import matplotlib.pyplot as plt -from astropy.io import fits import astropy.units as u -from specutils import Spectrum1D - -from astrodbkit2.astrodb import Database +import matplotlib.pyplot as plt +import numpy as np +import requests +import sqlalchemy.exc from astrodb_utils import ( AstroDBError, find_source_in_db, internet_connection, - find_publication, ) +from astrodbkit2.astrodb import Database +from astropy.io import fits +from specutils import Spectrum1D + +from simple.schema import Spectra __all__ = [ "ingest_spectrum", @@ -76,6 +73,10 @@ def ingest_spectrum( Returns ------- flags: dict + Status response with the following keys: + - "added": True if it's added and False if it's skipped. + - "content": the data that was attempted to add + - "message": string which includes information about why skipped Raises ------ @@ -83,133 +84,17 @@ def ingest_spectrum( """ flags = { - "skipped": False, - "dupe": False, - "missing_instrument": False, - "no_obs_date": False, "added": False, - "plottable": False, + "content": {}, + "message": "" } - # Check input values - if regime is None: - msg = "Regime is required" - logger.error(msg) - flags["skipped"] = True - if raise_error: - raise AstroDBError(msg) - else: - return flags - else: - good_regime = db.query(db.Regimes).filter(db.Regimes.c.regime == regime).table() - if len(good_regime) == 0: - msg = f"Regime {regime} is not in Regimes table" - logger.error(msg) - flags["skipped"] = True - if raise_error: - raise AstroDBError(msg) - else: - return flags - - if telescope is None: - msg = "Telescope is required" - logger.error(msg) - flags["skipped"] = True - if raise_error: - raise AstroDBError(msg) - else: - return flags - else: - good_telescope = ( - db.query(db.Telescopes) - .filter(db.Telescopes.c.telescope == telescope) - .table() - ) - if len(good_telescope) == 0: - msg = f"Telescope {telescope} is not in Telescopes table" - logger.error(msg) - flags["skipped"] = True - if raise_error: - raise AstroDBError(msg) - else: - return flags - - if instrument is None: - msg = "Instrument is required" - logger.error(msg) - flags["skipped"] = True - if raise_error: - raise AstroDBError(msg) - else: - return flags - else: - good_instrument = ( - db.query(db.Instruments) - .filter(db.Instruments.c.instrument == instrument) - .table() - ) - if len(good_instrument) == 0: - msg = f"Instrument {instrument} is not in Instruments table" - logger.error(msg) - flags["skipped"] = True - if raise_error: - raise AstroDBError(msg) - else: - return flags - - if mode is None: - msg = "Mode is required" - logger.error(msg) - flags["skipped"] = True - if raise_error: - raise AstroDBError(msg) - else: - return flags - else: - good_mode = ( - db.query(db.Instruments) - .filter(db.Instruments.c.instrument == instrument) - .filter(db.Instruments.c.mode == mode) - .table() - ) - if len(good_mode) == 0: - msg = f"Mode {mode} is not in Instruments table for {instrument}" - logger.error(msg) - flags["skipped"] = True - if raise_error: - raise AstroDBError(msg) - else: - return flags - - if reference is None: - msg = "Reference is required" - logger.error(msg) - flags["skipped"] = True - if raise_error: - raise AstroDBError(msg) - else: - return flags - else: - good_reference = find_publication(db, reference=reference) - if good_reference[0] is False: - msg = ( - f"Spectrum for {source} could not be added to the database because the " - f"reference {reference} is not in Publications table. \n" - f"(Add it with ingest_publication function.) \n " - ) - logger.error(msg) - flags["skipped"] = True - if raise_error: - raise AstroDBError(msg) - else: - return flags - # Get source name as it appears in the database db_name = find_source_in_db(db, source) if len(db_name) != 1: msg = f"No unique source match for {source} in the database" - flags["skipped"] = True + flags["message"] = msg if raise_error: raise AstroDBError(msg) else: @@ -226,7 +111,6 @@ def ingest_spectrum( request_response.status_code ) # The website is up if the status code is 200 if status_code != 200: - flags["skipped"] = True msg = ( "The spectrum location does not appear to be valid: \n" f"spectrum: {spectrum} \n" @@ -242,7 +126,6 @@ def ingest_spectrum( request_response1 = requests.head(original_spectrum) status_code1 = request_response1.status_code if status_code1 != 200: - flags["skipped"] = True msg = ( "The spectrum location does not appear to be valid: \n" f"spectrum: {original_spectrum} \n" @@ -258,52 +141,6 @@ def ingest_spectrum( msg = "No internet connection. Internet is needed to check spectrum files." raise AstroDBError(msg) - # SKIP if observation date is blank - if ma.is_masked(obs_date) or obs_date == "" or obs_date is None: - obs_date = None - missing_obs_msg = ( - f"Skipping spectrum with missing observation date: {source} \n" - ) - missing_row_spe = f"{source, obs_date, reference} \n" - flags["no_obs_date"] = True - logger.debug(missing_row_spe) - if raise_error: - logger.error(missing_obs_msg) - raise AstroDBError(missing_obs_msg) - else: - logger.warning(missing_obs_msg) - return flags - else: - try: - obs_date = pd.to_datetime( - obs_date - ) # TODO: Another method that doesn't require pandas? - except ValueError: - flags["no_obs_date"] = True - if raise_error: - msg = ( - f"{source}: Can't convert obs date to Date Time object: {obs_date}" - ) - logger.error(msg) - raise AstroDBError(msg) - else: - return flags - except dateutil.parser._parser.ParserError: - flags["no_obs_date"] = True - if raise_error: - msg = ( - f"{source}: Can't convert obs date to Date Time object: {obs_date}" - ) - logger.error(msg) - raise AstroDBError(msg) - else: - msg = ( - f"Skipping {source} Can't convert obs date to Date Time object: " - f"{obs_date}" - ) - logger.warning(msg) - return flags - matches = find_spectra( db, source, @@ -314,14 +151,14 @@ def ingest_spectrum( mode=mode, ) if len(matches) > 0: - msg = f"Skipping suspected duplicate measurement\n{source}\n" - msg2 = f"{matches}" f"{instrument, mode, obs_date, reference, spectrum} \n" - logger.warning(msg) + msg = f"Skipping suspected duplicate measurement: {source}" + msg2 = f"{matches} {instrument, mode, obs_date, reference, spectrum}" logger.debug(msg2) - flags["dupe"] = True + flags["message"] = msg if raise_error: - raise AstroDBError + raise AstroDBError(msg) else: + logger.warning(msg) return flags # Check if spectrum is plottable @@ -343,40 +180,37 @@ def ingest_spectrum( "other_references": other_references, } logger.debug(row_data) + flags["content"] = row_data - # Attempt to add spectrum to database try: - with db.engine.connect() as conn: - conn.execute(db.Spectra.insert().values(row_data)) - conn.commit() + # Attempt to add spectrum to database + # This will throw errors based on validation in schema.py + # and any database checks (as for example IntegrityError) + obj = Spectra(**row_data) + with db.session as session: + session.add(obj) + session.commit() + flags["added"] = True logger.info(f"Added {source} : \n" f"{row_data}") - except sqlalchemy.exc.IntegrityError as e: - msg = "Integrity Error:" f"{source} \n" f"{row_data}" - logger.error(msg + str(e) + f" \n {row_data}") - flags["skipped"] = True - if raise_error: - raise AstroDBError(msg) - else: - return flags - except sqlite3.IntegrityError as e: - msg = "Integrity Error: " f"{source} \n" f"{row_data}" - logger.error(msg + str(e)) - flags["skipped"] = True + except (sqlite3.IntegrityError, sqlalchemy.exc.IntegrityError) as e: + msg = f"Integrity Error: {source} \n {e}" + flags["message"] = msg if raise_error: raise AstroDBError(msg) else: + logger.error(msg) return flags except Exception as e: msg = ( - f"Spectrum for {source} could not be added to the database" - f"for unexpected reason: \n {row_data} \n error: {str(e)}" + f"Spectrum for {source} could not be added to the database " + f"for unexpected reason: {e}" ) - logger.error(msg) - flags["skipped"] = True + flags["message"] = msg if raise_error: raise AstroDBError(msg) else: + logger.warning(msg) return flags return flags @@ -466,7 +300,7 @@ def spectrum_plottable(spectrum_path, raise_error=True, show_plot=False): return False except u.UnitConversionError as e: msg = ( - f"{str(e)} \n" + f"{e} \n" f"Skipping {spectrum_path}: unable to convert spectral axis to microns" ) if raise_error: @@ -476,7 +310,7 @@ def spectrum_plottable(spectrum_path, raise_error=True, show_plot=False): logger.warning(msg) return False except ValueError as e: - msg = f"{str(e)} \nSkipping {spectrum_path}: Value error" + msg = f"{e} \nSkipping {spectrum_path}: Value error" if raise_error: logger.error(msg) raise AstroDBError(msg) diff --git a/tests/test_schema.py b/tests/test_schema.py index b2824566a..1966300df 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,8 +1,10 @@ """Tests for the schema itself and any validating functions""" +from datetime import datetime + import pytest -from simple.schema import Parallaxes, PhotometryFilters, Publications, Sources +from simple.schema import Parallaxes, PhotometryFilters, Publications, Sources, Spectra def schema_tester(table, values, error_state): @@ -57,4 +59,23 @@ def test_publications(values, error_state): ]) def test_parallaxes(values, error_state): """Validating Parallaxes""" - schema_tester(Parallaxes, values, error_state) \ No newline at end of file + schema_tester(Parallaxes, values, error_state) + + +@pytest.mark.parametrize("values, error_state", + [ + ({"access_url": None}, ValueError), + ({"source": None}, ValueError), + ({"regime": None}, ValueError), + ({"telescope": None}, ValueError), + ({"instrument": None}, ValueError), + ({"mode": None}, ValueError), + ({"observation_date": "2024-01-01"}, None), + ({"observation_date": datetime(2024,1,1)}, None), + ({"observation_date": None}, ValueError), + ({"observation_date": "fake"}, ValueError), + ]) +def test_spectra(values, error_state): + """Validating Spectra""" + # Note: due to how this works, only the columns with values provided get tested + schema_tester(Spectra, values, error_state) \ No newline at end of file diff --git a/tests/test_spectra_utils.py b/tests/test_spectra_utils.py index a8eead2eb..f86d095bb 100644 --- a/tests/test_spectra_utils.py +++ b/tests/test_spectra_utils.py @@ -3,6 +3,7 @@ from astrodb_utils.utils import ( AstroDBError, ) + from simple.utils.spectra import ( ingest_spectrum, # ingest_spectrum_from_fits, @@ -10,8 +11,6 @@ ) -@pytest.mark.filterwarnings("ignore::UserWarning") -@pytest.mark.filterwarnings("ignore", message=".*not FITS standard.*") @pytest.mark.filterwarnings( "ignore", message=".*Note: astropy.io.fits uses zero-based indexing.*" ) @@ -25,163 +24,78 @@ "which is discouraged by the FITS standard.*", ), ) -def test_ingest_spectrum_errors(temp_db): - spectrum = "https://bdnyc.s3.amazonaws.com/tests/U10176.fits" - with pytest.raises(AstroDBError) as error_message: - ingest_spectrum(temp_db, source="apple", spectrum=spectrum) - assert "Regime is required" in str(error_message.value) - result = ingest_spectrum( - temp_db, source="apple", spectrum=spectrum, raise_error=False - ) - assert result["added"] is False - assert result["skipped"] is True +@pytest.mark.parametrize("test_input, message", [ + ({"source": "apple", + "telescope": "IRTF", + "instrument": "SpeX", + "mode": "Prism", + }, "Value required for regime"), # missing regime + ({"source": "apple", + "regime": "nir", + "instrument": "SpeX", + "obs_date": "2020-01-01", + }, "Value required for telescope"), # missing telescope + ({"source": "apple", + "regime": "nir", + "telescope": "IRTF", + "obs_date": "2020-01-01", + }, "Value required for instrument"), # missing instrument + ({"source": "apple", + "telescope": "IRTF", + "instrument": "SpeX", + "mode": "Prism", + "regime": "nir", + "obs_date": "2020-01-01", + }, "NOT NULL constraint failed: Spectra.reference"), # missing reference + ({"source": "apple", + "telescope": "IRTF", + "instrument": "SpeX", + "mode": "Prism", + "regime": "nir", + "obs_date": "2020-01-01", + "reference": "Ref 5", + }, "FOREIGN KEY constraint failed"), # invalid reference + ({"source": "kiwi", + "telescope": "IRTF", + "instrument": "SpeX", + "mode": "Prism", + "regime": "nir", + "obs_date": "2020-01-01", + "reference": "Ref 1", + }, "No unique source match for kiwi in the database"), # invalid source + ({"source": "apple", + "telescope": "IRTF", + "instrument": "SpeX", + "mode": "Prism", + "regime": "nir", + "reference": "Ref 1", + }, "Invalid date received: None"), # missing date + ({"source": "apple", + "telescope": "IRTF", + "instrument": "SpeX", + "mode": "Prism", + "regime": "fake regime", + "obs_date": "2020-01-01", + "reference": "Ref 1", + }, "FOREIGN KEY constraint failed"), # invalid regime +]) +def test_ingest_spectrum_errors(temp_db, test_input, message): + # Test for ingest_spectrum that is expected to return errors - with pytest.raises(AstroDBError) as error_message: - ingest_spectrum( - temp_db, - source="apple", - telescope="IRTF", - instrument="SpeX", - mode="Prism", - regime="nir", - spectrum=spectrum, - ) - assert "Reference is required" in str(error_message.value) - ingest_spectrum( - temp_db, source="apple", regime="nir", spectrum=spectrum, raise_error=False - ) - assert result["added"] is False - assert result["skipped"] is True - - with pytest.raises(AstroDBError) as error_message: - ingest_spectrum( - temp_db, - source="apple", - regime="nir", - spectrum=spectrum, - telescope="IRTF", - instrument="SpeX", - mode="Prism", - reference="Ref 5", - ) - assert "not in Publications table" in str(error_message.value) - ingest_spectrum( - temp_db, - source="apple", - regime="nir", - spectrum=spectrum, - telescope="IRTF", - instrument="SpeX", - mode="Prism", - reference="Ref 5", - raise_error=False, - ) - assert result["added"] is False - assert result["skipped"] is True - - with pytest.raises(AstroDBError) as error_message: - ingest_spectrum( - temp_db, - source="kiwi", - regime="nir", - spectrum=spectrum, - reference="Ref 1", - telescope="IRTF", - instrument="SpeX", - mode="Prism", - ) - assert "No unique source match for kiwi in the database" in str(error_message.value) - result = ingest_spectrum( - temp_db, - source="kiwi", - regime="nir", - spectrum=spectrum, - reference="Ref 1", - raise_error=False, - telescope="IRTF", - instrument="SpeX", - mode="Prism", - ) - assert result["added"] is False - assert result["skipped"] is True + # Prepare parameters to send to ingest_spectrum + spectrum = "https://bdnyc.s3.amazonaws.com/tests/U10176.fits" + parameters = {"db": temp_db, "spectrum": spectrum} + parameters.update(test_input) + # Check that error was raised with pytest.raises(AstroDBError) as error_message: - ingest_spectrum( - temp_db, - source="apple", - regime="nir", - spectrum=spectrum, - reference="Ref 1", - telescope="IRTF", - instrument="SpeX", - mode="Prism", - ) - assert "missing observation date" in str(error_message.value) - result = ingest_spectrum( - temp_db, - source="apple", - regime="nir", - spectrum=spectrum, - reference="Ref 1", - telescope="IRTF", - instrument="SpeX", - mode="Prism", - raise_error=False, - ) - assert result["added"] is False - assert result["skipped"] is False - assert result["no_obs_date"] is True - - # with pytest.raises(AstroDBError) as error_message: - # ingest_spectrum( - # db, - # source="orange", - # regime="nir", - # spectrum=spectrum, - # reference="Ref 1", - # obs_date="Jan20", - # ) - # assert "Can't convert obs date to Date Time object" in str(error_message.value) - # result = ingest_spectrum( - # db, - # source="orange", - # regime="nir", - # spectrum=spectrum, - # reference="Ref 1", - # obs_date="Jan20", - # raise_error=False, - # ) - # assert result["added"] is False - # assert result["skipped"] is False - # assert result["no_obs_date"] is True + _ = ingest_spectrum(**parameters) + assert message in str(error_message.value) - with pytest.raises(AstroDBError) as error_message: - result = ingest_spectrum( - temp_db, - source="orange", - regime="far-uv", - spectrum=spectrum, - reference="Ref 1", - obs_date="1/1/2024", - telescope="Keck I", - instrument="LRIS", - mode="OG570", - ) - assert "not in Regimes table" in str(error_message.value) - result = ingest_spectrum( - temp_db, - source="orange", - regime="far-uv", - spectrum=spectrum, - reference="Ref 1", - obs_date="1/1/2024", - telescope="Keck I", - instrument="LRIS", - mode="OG570", - raise_error=False, - ) + # Suppress error but check that it was still captured + result = ingest_spectrum(**parameters, raise_error=False) assert result["added"] is False - assert result["skipped"] is True + assert message in result["message"] @pytest.mark.filterwarnings("ignore:Verification")