diff --git a/pyproject.toml b/pyproject.toml index 13bdbe4..4726636 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "skycalc_ipy" -version = "0.2.0" +version = "0.3.0a0" # When updating the version, also # - update the date in anisocado/version.py # - update the release notese in docs/source/index.rst @@ -28,7 +28,7 @@ dependencies = [ # package on PyPI, for minimumdependencies.yml "numpy>=1.18.0", "astropy>=4.0", - "requests>=2.20", + "httpx>=0.23", "pyyaml>=5.3" ] diff --git a/skycalc_ipy/core.py b/skycalc_ipy/core.py index 0cb7fa8..a93a3c7 100644 --- a/skycalc_ipy/core.py +++ b/skycalc_ipy/core.py @@ -1,57 +1,94 @@ +# -*- coding: utf-8 -*- """ -Based on the skycalc_cli package. +Based on the skycalc_cli package, but heavily modified. -This modele was taken (mostly unmodified) from ``skycalc_cli`` version 1.1. +The original code was taken from ``skycalc_cli`` version 1.1. Credit for ``skycalc_cli`` goes to ESO """ -from __future__ import print_function +import logging +import warnings import hashlib import json -import os +from os import environ from datetime import datetime from pathlib import Path -from typing import Dict -from types import ModuleType +from collections.abc import Mapping +from importlib import import_module -import requests +import httpx from astropy.io import fits -try: - import scopesim_data -except: - scopesim_data = None +CACHE_DIR_FALLBACK = ".astar/skycalc_ipy" -def get_cache_filenames(params: Dict, prefix: str, suffix: str): - """Get filenames to cache the data. + +def get_cache_dir() -> Path: + """Establish the cache directory. There are three possible locations for the cache directory: 1. As set in `os.environ["SKYCALC_IPY_CACHE_DIR"]` 2. As set in the `scopesim_data` package. 3. The `data` directory in this package. """ + try: + dir_cache = Path(environ["SKYCALC_IPY_CACHE_DIR"]) + except KeyError: + try: + sim_data = import_module("scopesim_data") + dir_cache = Path(getattr(sim_data, "dir_cache_skycalc")) + except (ImportError, AttributeError): + dir_cache = Path.home() / CACHE_DIR_FALLBACK + + if not dir_cache.is_dir(): + dir_cache.mkdir(parents=True) + + return dir_cache + + +class ESOQueryBase: + """Base class for queries to ESO skycalc server.""" + + REQUEST_TIMEOUT = 2 # Time limit (in seconds) for server response + BASE_URL = "https://etimecalret-002.eso.org/observing/etc" - if "SKYCALC_IPY_CACHE_DIR" in os.environ: - dir_cache = Path(os.environ["SKYCALC_IPY_CACHE_DIR"]) - elif isinstance(scopesim_data, ModuleType): - dir_cache = scopesim_data.dir_cache_skycalc - else: - dir_cache = Path(__file__).parent / "data" - # Three underscores between the key-value pairs, two underscores - # between the key and the value. - akey = "___".join(f"{k}__{v}" for k, v in params.items()) - ahash = hashlib.sha256(akey.encode("utf-8")).hexdigest() - fn_data = dir_cache / f"{prefix}_{ahash}.{suffix}" - fn_params = fn_data.with_suffix(".params.json") - return fn_data, fn_params - - -class AlmanacQuery: + def __init__(self, url, params): + self.url = url + self.params = params + + def _send_request(self) -> httpx.Response: + try: + with httpx.Client(base_url=self.BASE_URL, + timeout=self.REQUEST_TIMEOUT) as client: + response = client.post(self.url, json=self.params) + response.raise_for_status() + except httpx.RequestError as err: + logging.exception("An error occurred while requesting %s.", + err.request.url) + raise err + except httpx.HTTPStatusError as err: + logging.error("Error response %s while requesting %s.", + err.response.status_code, err.request.url) + raise err + return response + + def get_cache_filenames(self, prefix: str, suffix: str) -> str: + """Produce filename from hass of parameters. + + Using three underscores between the key-value pairs and two underscores + between the key and the value. + """ + akey = "___".join(f"{k}__{v}" for k, v in self.params.items()) + ahash = hashlib.sha256(akey.encode("utf-8")).hexdigest() + fname = f"{prefix}_{ahash}.{suffix}" + return fname + + +class AlmanacQuery(ESOQueryBase): """ - A class for querying the SkyCalc Almanac + A class for querying the SkyCalc Almanac. Parameters ---------- @@ -70,15 +107,14 @@ class AlmanacQuery: """ - REQUEST_TIMEOUT = 2 # Time limit (in seconds) for server response - def __init__(self, indic): + # FIXME: This basically checks isinstance(indic, ui.SkyCalc), but we + # can't import that because it would create a circual import. + # TODO: Find a better way to do this!! if hasattr(indic, "defaults"): indic = indic.values - self.almdata = None - self.almserver = "https://etimecalret-002.eso.org" - self.almurl = "/observing/etc/api/skycalc_almanac" + super().__init__("/api//skycalc_almanac", {}) # Left: users keyword (skycalc_cli), # Right: skycalc Almanac output keywords @@ -94,7 +130,6 @@ def __init__(self, indic): "observatory": "observatory", } - self.almindic = {} # The Almanac needs: # coord_ra : float [deg] # coord_dec : float [deg] @@ -107,50 +142,77 @@ def __init__(self, indic): # coord_ut_min : int # coord_ut_sec : float + self._set_date(indic) + self._set_radec(indic, "ra") + self._set_radec(indic, "dec") + + if "observatory" in indic: + self.params["observatory"] = indic["observatory"] + + def _set_date(self, indic): if "date" in indic and indic["date"] is not None: if isinstance(indic["date"], str): isotime = datetime.strptime(indic["date"], "%Y-%m-%dT%H:%M:%S") else: isotime = indic["date"] - self.almindic["input_type"] = "ut_time" - self.almindic["coord_year"] = isotime.year - self.almindic["coord_month"] = isotime.month - self.almindic["coord_day"] = isotime.day - self.almindic["coord_ut_hour"] = isotime.hour - self.almindic["coord_ut_min"] = isotime.minute - self.almindic["coord_ut_sec"] = isotime.second + updated = { + "input_type": "ut_time", + "coord_year": isotime.year, + "coord_month": isotime.month, + "coord_day": isotime.day, + "coord_ut_hour": isotime.hour, + "coord_ut_min": isotime.minute, + "coord_ut_sec": isotime.second, + } elif "mjd" in indic and indic["mjd"] is not None: - self.almindic["input_type"] = "mjd" - self.almindic["mjd"] = float(indic["mjd"]) + updated = { + "input_type": "mjd", + "mjd": float(indic["mjd"]), + } else: raise ValueError("No valid date or mjd given for the Almanac") - if "ra" not in indic or "dec" not in indic: - raise ValueError("ra/dec coordinate not given for the Almanac.") + self.params.update(updated) + def _set_radec(self, indic, which): try: - ra = float(indic["ra"]) - except ValueError: - print("Error: wrong ra format for the Almanac.") - raise - self.almindic["coord_ra"] = ra + self.params[f"coord_{which}"] = float(indic[which]) + except KeyError as err: + logging.exception("%s coordinate not given for the Almanac.", + which) + raise err + except ValueError as err: + logging.exception("Wrong %s format for the Almanac.", which) + raise err + + def _get_jsondata(self, file_path: Path): + if file_path.exists(): + return json.load(file_path.open(encoding="utf-8")) + + response = self._send_request() + if not response.text: + raise ValueError("Empty response.") + + jsondata = response.json()["output"] + # Use a fixed date so the stored files are always identical for + # identical requests. + jsondata["execution_datetime"] = "2017-01-07T00:00:00 UTC" try: - dec = float(indic["dec"]) - except ValueError: - print("Error: wrong dec format for the Almanac.") - raise - self.almindic["coord_dec"] = dec + json.dump(jsondata, file_path.open("w", encoding="utf-8")) + # json.dump(self.params, open(fn_params, 'w')) + except (PermissionError, FileNotFoundError) as err: + # Apparently it is not possible to save here. + raise err - if "observatory" in indic: - self.almindic["observatory"] = indic["observatory"] + return jsondata - def query(self): + def __call__(self): """ - Queries the ESO Skycalc server with the parameters in self.almindic + Query the ESO Skycalc server with the parameters in self.params. Returns ------- @@ -158,29 +220,10 @@ def query(self): Dictionary with the relevant parameters for the date given """ - fn_data, fn_params = get_cache_filenames(self.almindic, "almanacquery", "json") - if fn_data.exists(): - jsondata = json.load(open(fn_data)) - else: - url = self.almserver + self.almurl - - response = requests.post( - url, json.dumps(self.almindic), timeout=self.REQUEST_TIMEOUT - ) - rawdata = response.text - - jsondata1 = json.loads(rawdata) - jsondata = jsondata1["output"] - # Use a fixed date so the stored files are always identical for - # identical requests. - jsondata["execution_datetime"] = "2017-01-07T00:00:00 UTC" - - try: - json.dump(jsondata, open(fn_data, 'w')) - json.dump(self.almindic, open(fn_params, 'w')) - except (PermissionError, FileNotFoundError): - # Apparently it is not possible to save here. - pass + cache_dir = get_cache_dir() + cache_name = self.get_cache_filenames("almanacquery", "json") + cache_path = cache_dir / cache_name + jsondata = self._get_jsondata(cache_path) # Find the relevant (key, value) almdata = {} @@ -195,15 +238,22 @@ def query(self): try: almdata[key] = jsondata[subsection][value] except (KeyError, ValueError): - print(f"Warning: key \"{subsection}/{value}\" not found in the" - " Almanac response.") + logging.warning("Key '%s/%s' not found in Almanac response.", + subsection, value) return almdata + def query(self): + """Deprecated feature. Class is now callable, use that instead.""" + warnings.warn("The .query() method is deprecated and will be removed " + "in a future release. Please simply call the instance.", + DeprecationWarning, stacklevel=2) + return self() + -class SkyModel: +class SkyModel(ESOQueryBase): """ - Class for querying the Advanced SkyModel at ESO + Class for querying the Advanced SkyModel at ESO. Contains all the parameters needed for querying the ESO SkyCalc server. The parameters are contained in :attr:`.params` and the returned FITS file @@ -215,17 +265,13 @@ class SkyModel: """ - REQUEST_TIMEOUT = 2 # Time limit (in seconds) for server response - def __init__(self): - self.stop_on_errors_and_exceptions = True self.data = None - self.server = "https://etimecalret-002.eso.org" - self.url = self.server + "/observing/etc/api/skycalc" - self.deleter_script_url = self.server + "/observing/etc/api/rmtmp" - self.bugreport_text = "" + self.data_url = "/tmp/" + self.deleter_script_url = "/api/rmtmp" + self._last_status = "" self.tmpdir = "" - self.params = { + params = { # Airmass. Alt and airmass are coupled through the plane parallel # approximation airmass=sec(z), z being the zenith distance # z=90-Alt @@ -312,14 +358,25 @@ def __init__(self): "observatory": "paranal", # paranal } + super().__init__("/api/skycalc", params) + def fix_observatory(self): """ - Converts the human readable observatory name into its ESO ID number + Convert the human readable observatory name into its ESO ID number. The following observatory names are accepted: ``lasilla``, ``paranal``, ``armazones`` or ``3060m``, ``highanddry`` or ``5000m`` """ + # FIXME: DO WE ALWAYS WANT TO RAISE WHEN IT'S NOT ONE OF THOSE??? + if self.params["observatory"] not in { + "paranal", + "lasilla", + "armazones", + "3060m", + "5000m", + }: + return # nothing to do if self.params["observatory"] == "lasilla": self.params["observatory"] = "2400" @@ -339,6 +396,7 @@ def fix_observatory(self): raise ValueError( "Wrong Observatory name, please refer to the documentation." ) + return # for consistency def __getitem__(self, item): return self.params[item] @@ -348,135 +406,120 @@ def __setitem__(self, key, value): if key == "observatory": self.fix_observatory() - def handle_exception(self, err, msg): - print(msg) - print(err) - print(self.bugreport_text) - if self.stop_on_errors_and_exceptions: - # There used to be a sys.exit here. That was probably there to - # provide a clean exit when using skycalc_ipy as a command-line - # tool or something like that. However, skycalc_ipy is also used as - # a library, and libraries should never just exit and this - # command-line functionality does not seem to exist. So instead, - # just raise here. See also handle_error() below. - raise - - # handle the kind of errors we issue ourselves. - def handle_error(self, msg): - print(msg) - print(self.bugreport_text) - if self.stop_on_errors_and_exceptions: - # See handle_exception above. - raise - - def retrieve_data(self, url): + def _retrieve_data(self, url): try: self.data = fits.open(url) # Use a fixed date so the stored files are always identical for # identical requests. self.data[0].header["DATE"] = "2017-01-07T00:00:00" - except requests.exceptions.RequestException as err: - self.handle_exception( - err, "Exception raised trying to get FITS data from " + url - ) + except Exception as err: + logging.exception( + "Exception raised trying to get FITS data from %s", url) + raise err def write(self, local_filename, **kwargs): + """Write data to file.""" try: self.data.writeto(local_filename, **kwargs) - except (IOError, FileNotFoundError) as err: - self.handle_exception(err, "Exception raised trying to write fits file ") + except (IOError, FileNotFoundError): + logging.exception("Exception raised trying to write fits file.") def getdata(self): + """Deprecated feature, just use the .data attribute.""" + warnings.warn("The .getdata method is deprecated and will be removed " + "in a future release. Use the identical .data attribute " + "instead.", DeprecationWarning, stacklevel=2) return self.data - def delete_server_tmpdir(self, tmpdir): + def _delete_server_tmpdir(self, tmpdir): try: - response = requests.get( - self.deleter_script_url + "?d=" + tmpdir, timeout=self.REQUEST_TIMEOUT - ) - deleter_response = response.text.strip() + with httpx.Client(base_url=self.BASE_URL, + timeout=self.REQUEST_TIMEOUT) as client: + response = client.get(self.deleter_script_url, + params={"d": tmpdir}) + deleter_response = response.text.strip().lower() if deleter_response != "ok": - self.handle_error("Could not delete server tmpdir " + tmpdir) - except requests.exceptions.RequestException as err: - self.handle_exception( - err, "Exception raised trying to delete tmp dir " + tmpdir - ) - - def call(self, test=False): - # print 'self.url=',self.url - # print 'self.params=',self.params - - if self.params["observatory"] in { - "paranal", - "lasilla", - "armazones", - "3060m", - "5000m", - }: - self.fix_observatory() - - fn_data, fn_params = get_cache_filenames(self.params, "skymodel", "fits") - if fn_data.exists(): - self.data = fits.open(fn_data) + logging.error("Could not delete server tmpdir %s: %s", + tmpdir, deleter_response) + except httpx.HTTPError: + logging.exception("Exception raised trying to delete tmp dir %s", + tmpdir) + + def _update_params(self, updated: Mapping) -> None: + par_keys = self.params.keys() + new_keys = updated.keys() + self.params.update((key, updated[key]) for key in par_keys & new_keys) + logging.debug("Ignoring invalid keywords: %s", new_keys - par_keys) + + def __call__(self, **kwargs): + """Send server request.""" + if kwargs: + logging.info("Setting new parameters: %s", kwargs) + + self._update_params(kwargs) + self.fix_observatory() + + cache_dir = get_cache_dir() + cache_name = self.get_cache_filenames("skymodel", "fits") + cache_path = cache_dir / cache_name + + if cache_path.exists(): + self.data = fits.open(cache_path) return - try: - response = requests.post( - self.url, data=json.dumps(self.params), timeout=self.REQUEST_TIMEOUT - ) - except requests.exceptions.RequestException as err: - self.handle_exception( - err, "Exception raised trying to POST request " + self.url - ) - return + response = self._send_request() try: - res = json.loads(response.text) + res = response.json() status = res["status"] tmpdir = res["tmpdir"] except (KeyError, ValueError) as err: - self.handle_exception( - err, "Exception raised trying to decode server response " - ) - return + logging.exception( + "Exception raised trying to decode server response.") + raise err - tmpurl = self.server + "/observing/etc/tmp/" + tmpdir + "/skytable.fits" + self._last_status = status if status == "success": try: # retrive and save FITS data (in memory) - self.retrieve_data(tmpurl) - except requests.exceptions.RequestException as err: - self.handle_exception(err, "could not retrieve FITS data from server") + self._retrieve_data( + self.BASE_URL + self.data_url + tmpdir + "/skytable.fits") + except httpx.HTTPError as err: + logging.exception("Could not retrieve FITS data from server.") + raise err try: - self.data.writeto(fn_data) - json.dump(self.params, open(fn_params, 'w')) + self.data.writeto(cache_path) + # with fn_params.open("w", encoding="utf-8") as file: + # json.dump(self.params, file) except (PermissionError, FileNotFoundError): # Apparently it is not possible to save here. pass - self.delete_server_tmpdir(tmpdir) + self._delete_server_tmpdir(tmpdir) else: # print why validation failed - self.handle_error("parameter validation error: " + res["error"]) + logging.error("Parameter validation error: %s", res["error"]) - if test: - # print 'call() returning status:',status - return status + def call(self): + """Deprecated feature, just call the instance.""" + warnings.warn("The .call() method is deprecated and will be removed " + "in a future release. Please simply call the instance.", + DeprecationWarning, stacklevel=2) + self() def callwith(self, newparams): - for key, val in newparams.items(): - if key in self.params: # valid - self.params[key] = val - else: - pass - # print('callwith() ignoring invalid keyword: ', key) - self.call() + """Deprecated feature, just call the instance.""" + warnings.warn("The .callwith(args) method is deprecated and will be " + "removed in a future release. Please simply call the " + "instance with optional kwargs instead.", + DeprecationWarning, stacklevel=2) + self(**newparams) def printparams(self, keys=None): """ - List the values of all, or a subset, of parameters + List the values of all, or a subset, of parameters. Parameters ---------- @@ -485,7 +528,4 @@ def printparams(self, keys=None): """ for key in keys or self.params.keys(): - print(key, self.params[key]) - - def reset(self): - self.__init__() + print(f" {key}: {self.params[key]}") diff --git a/skycalc_ipy/tests/test_core.py b/skycalc_ipy/tests/test_core.py index e7e6c32..4c47b70 100644 --- a/skycalc_ipy/tests/test_core.py +++ b/skycalc_ipy/tests/test_core.py @@ -1,53 +1,61 @@ +import pytest from pytest import raises from skycalc_ipy import ui from skycalc_ipy import core from datetime import datetime as dt -# Mocks -skp = ui.SkyCalc() + +@pytest.fixture +def skp(): + return ui.SkyCalc() + + +@pytest.fixture +def skp_basic(skp): + skp.update({"ra": 0, "dec": 0}) + return skp class TestAlmanacInit: - def test_throws_exception_when_passed_virgin_SkyCalcParams(self): + def test_throws_exception_when_passed_virgin_SkyCalcParams(self, skp): with raises(ValueError): core.AlmanacQuery(skp) - def test_passes_for_valid_SkyCalcParams_with_date(self): + def test_passes_for_valid_SkyCalcParams_with_date(self, skp): skp.update({"ra": 0, "dec": 0, "date": "2000-1-1T0:0:0", "mjd": None}) alm = core.AlmanacQuery(skp) - print(alm.almindic) - assert alm.almindic["coord_year"] == 2000 - assert alm.almindic["coord_ut_sec"] == 0 - assert alm.almindic["input_type"] == "ut_time" + assert alm.params["coord_year"] == 2000 + assert alm.params["coord_ut_sec"] == 0 + assert alm.params["input_type"] == "ut_time" - def test_passes_for_valid_SkyCalcParams_with_mjd(self): - skp.update({"mjd": 0, "date": None}) - alm = core.AlmanacQuery(skp) - print(alm.almindic) - assert alm.almindic["mjd"] == 0 - assert alm.almindic["input_type"] == "mjd" + def test_passes_for_valid_SkyCalcParams_with_mjd(self, skp_basic): + skp_basic.update({"mjd": 0, "date": None}) + alm = core.AlmanacQuery(skp_basic) + assert alm.params["mjd"] == 0 + assert alm.params["input_type"] == "mjd" - def test_throws_exception_when_date_and_mjd_are_empty(self): - skp.update({"mjd": None, "date": None}) + def test_throws_exception_when_date_and_mjd_are_empty(self, skp_basic): + skp_basic.update({"mjd": None, "date": None}) with raises(ValueError): - core.AlmanacQuery(skp) + core.AlmanacQuery(skp_basic) - def test_passes_for_date_as_datetime_object(self): - skp.update({"date": dt(1986, 4, 26, 1, 24)}) - alm = core.AlmanacQuery(skp) - assert alm.almindic["coord_ut_min"] == 24 - assert alm.almindic["input_type"] == "ut_time" + def test_passes_for_date_as_datetime_object(self, skp_basic): + skp_basic.update({"date": dt(1986, 4, 26, 1, 24)}) + alm = core.AlmanacQuery(skp_basic) + assert alm.params["coord_ut_min"] == 24 + assert alm.params["input_type"] == "ut_time" - def test_throws_exception_when_date_is_unintelligible(self): - skp.update({"date": "bogus"}) + def test_throws_exception_when_date_is_unintelligible(self, skp_basic): + skp_basic.update({"date": "bogus"}) with raises(ValueError): - core.AlmanacQuery(skp) + core.AlmanacQuery(skp_basic) - def test_throws_exception_when_mjd_is_unintelligible(self): - skp.update({"mjd": "bogus"}) + def test_throws_exception_when_mjd_is_unintelligible(self, skp_basic): + skp_basic.update({"mjd": "bogus"}) with raises(ValueError): - core.AlmanacQuery(skp) + core.AlmanacQuery(skp_basic) + class TestLoadDataFromCache: def test_load_skymodel_from_cache(self): @@ -60,9 +68,7 @@ def test_load_skymodel_from_cache(self): } skymodel = core.SkyModel() - skymodel.callwith(params) + skymodel(**params) alm = core.AlmanacQuery(params) - alm.query() - - + _ = alm() diff --git a/skycalc_ipy/tests/test_ui.py b/skycalc_ipy/tests/test_ui.py index ccbee4a..44490d5 100644 --- a/skycalc_ipy/tests/test_ui.py +++ b/skycalc_ipy/tests/test_ui.py @@ -1,5 +1,5 @@ -import os -import inspect +from pathlib import Path +from collections.abc import Sequence, Mapping import pytest from pytest import raises @@ -8,26 +8,40 @@ from astropy.io import fits import synphot as sp -# Mocks -skp = ui.SkyCalc() -skp2 = ui.SkyCalc() -skp_small = ui.SkyCalc() -skp_small["wdelta"] = 100 + +PATH_HERE = Path(__file__).parent + + +@pytest.fixture +def skp(): + return ui.SkyCalc() + + +@pytest.fixture +def skp_small(): + skps = ui.SkyCalc() + skps["wdelta"] = 100 + return skps + + +@pytest.fixture +def basic_almanac_no_update(skp): + return skp.get_almanac_data(ra=180, dec=0, mjd=50000, + observatory="lasilla", update_values=False) class TestLoadYaml: def test_finds_file_for_specified_path(self): - dirname = os.path.dirname(inspect.getfile(inspect.currentframe())) - yaml_dict = ui.load_yaml(os.path.join(dirname, "../params.yaml")) + yaml_dict = ui.load_yaml(PATH_HERE.parent / "params.yaml") assert yaml_dict["season"][0] == 0 def test_throws_exception_for_nonexistent_file(self): with raises(ValueError): - ui.load_yaml("bogus.yaml") + ui.load_yaml(Path("bogus.yaml")) def test_accepts_string_block_input(self): str_yaml = """ - params : + params : - hello - world """ @@ -36,83 +50,81 @@ def test_accepts_string_block_input(self): class TestSkycalcParamsInit: - def test_loads_default_when_no_file_given(self): - assert type(skp.defaults) == dict + def test_loads_default_when_no_file_given(self, skp): + assert isinstance(skp.defaults, Mapping) assert skp.defaults["observatory"] == "paranal" assert skp.allowed["therm_t2"] == 0 - def test_print_comments_single_keywords(self, capsys): + def test_print_comments_single_keywords(self, skp, capsys): skp.print_comments("airmass") output = capsys.readouterr()[0].strip() assert output == "airmass : airmass in range [1.0, 3.0]" - def test_print_comments_mutliple_keywords(self, capsys): - skp.print_comments(["airmass", "season"]) + def test_print_comments_mutliple_keywords(self, skp, capsys): + skp.print_comments("airmass", "season") output = capsys.readouterr()[0].strip() - assert ( - output - == "airmass : airmass in range [1.0, 3.0]\n" - + "season : 0=all year, 1=dec/jan,2=feb/mar..." - ) + expected = ("airmass : airmass in range [1.0, 3.0]\n" + " season : 0=all year, 1=dec/jan,2=feb/mar...") + assert output == expected - def test_print_comments_misspelled_keyword(self, capsys): - skp.print_comments(["iarmass"]) + def test_print_comments_misspelled_keyword(self, skp, capsys): + skp.print_comments("iarmass") sys_out = capsys.readouterr() output = sys_out[0].strip() - assert output == "iarmass not found" + assert output == "iarmass : " - def test_keys_returns_list_of_keys(self): - assert type(skp.keys) == list + def test_keys_returns_list_of_keys(self, skp): + assert isinstance(skp.keys, Sequence) assert "observatory" in skp.keys -class TestSkycalcParamsValidateMethod(object): - def test_returns_true_for_all_good(self): - assert skp.validate_params() is True +class TestSkycalcParamsValidateMethod: + def test_returns_true_for_all_good(self, skp): + assert skp.validate_params() - def test_returns_false_for_bung_YN_flag(self): + def test_returns_false_for_bung_YN_flag(self, skp): skp["incl_starlight"] = "Bogus" - assert skp.validate_params() is False + assert not skp.validate_params() - def test_returns_false_for_bung_string_in_array(self): + def test_returns_false_for_bung_string_in_array(self, skp): skp["lsf_type"] = "Bogus" - assert skp.validate_params() is False + assert not skp.validate_params() - def test_returns_false_for_value_outside_range(self): + def test_returns_false_for_value_outside_range(self, skp): skp["airmass"] = 0.5 - assert skp.validate_params() is False + assert not skp.validate_params() - def test_returns_false_for_value_below_zero(self): + def test_returns_false_for_value_below_zero(self, skp): skp["lsf_boxcar_fwhm"] = -5.0 - assert skp.validate_params() is False + assert not skp.validate_params() class TestSkyCalcParamsGetSkySpectrum: @pytest.mark.webtest - def test_returns_data_with_valid_parameters(self): + def test_returns_data_with_valid_parameters(self, skp_small): tbl = skp_small.get_sky_spectrum() assert "lam" in tbl.colnames assert "flux" in tbl.colnames assert "trans" in tbl.colnames assert len(tbl) == 4606 - def test_throws_exception_for_invalid_parameters(self): + def test_throws_exception_for_invalid_parameters(self, skp): skp["airmass"] = 9001 with raises(ValueError): skp.get_sky_spectrum() @pytest.mark.webtest - def test_returns_table_for_return_type_table(self): + def test_returns_table_for_return_type_table(self, skp_small): tbl = skp_small.get_sky_spectrum(return_type="table") assert isinstance(tbl, table.Table) @pytest.mark.webtest - def test_returns_fits_for_return_type_fits(self): + def test_returns_fits_for_return_type_fits(self, skp_small): hdu = skp_small.get_sky_spectrum(return_type="fits") assert isinstance(hdu, fits.HDUList) @pytest.mark.webtest - def test_returned_fits_has_proper_meta_data(self): + def test_returned_fits_has_proper_meta_data(self, skp_small): hdu = skp_small.get_sky_spectrum(return_type="fits") assert "DATE_CRE" in hdu[0].header assert "SOURCE" in hdu[0].header @@ -120,12 +132,12 @@ def test_returned_fits_has_proper_meta_data(self): assert hdu[0].header["ETYPE"] == "TERCurve" @pytest.mark.webtest - def test_returns_three_arrays_for_return_type_array(self): + def test_returns_three_arrays_for_return_type_array(self, skp_small): arrs = skp_small.get_sky_spectrum(return_type="array") assert len(arrs) == 3 @pytest.mark.webtest - def test_returns_two_synphot_objects_for_return_type_synphot(self): + def test_returns_two_synphot_objects_for_return_type_synphot(self, skp_small): trans, flux = skp_small.get_sky_spectrum(return_type="synphot") assert isinstance(trans, sp.SpectralElement) assert isinstance(flux, sp.SourceSpectrum) @@ -136,7 +148,7 @@ def test_returns_nothing_if_return_type_is_invalid(self): class TestSkyCalcParamsGetAlmanacData: @pytest.mark.webtest - def test_return_updated_SkyCalcParams_values_dict_when_flag_true(self): + def test_return_updated_SkyCalcParams_values_dict_when_flag_true(self, skp): out_dict = skp.get_almanac_data( ra=180, dec=0, mjd=50000, observatory="lasilla", update_values=True ) @@ -144,15 +156,14 @@ def test_return_updated_SkyCalcParams_values_dict_when_flag_true(self): assert skp["observatory"] == "lasilla" @pytest.mark.webtest - def test_return_only_almanac_data_when_update_flag_false(self): - skp2["observatory"] == "paranal" - out_dict = skp.get_almanac_data( - ra=180, dec=0, mjd=50000, observatory="lasilla", update_values=False - ) + def test_return_only_almanac_data_when_update_flag_false( + self, skp, basic_almanac_no_update): + skp["observatory"] = "paranal" + out_dict = basic_almanac_no_update assert out_dict["observatory"] == "lasilla" - assert skp2["observatory"] == "paranal" + assert skp["observatory"] == "paranal" - def raise_error_if_both_date_and_mjd_are_empty(self): + def raise_error_if_both_date_and_mjd_are_empty(self, skp): with raises(ValueError): skp.get_almanac_data(180, 0) @@ -181,7 +192,7 @@ def test_return_data_for_valid_parameters(self): out_dict = ui.get_almanac_data( ra=180, dec=0, date="2000-1-1T0:0:0", observatory="lasilla" ) - assert type(out_dict) == dict + assert isinstance(out_dict, Mapping) assert len(out_dict) == 9 assert out_dict["observatory"] == "lasilla" @@ -190,7 +201,7 @@ def test_return_full_dict_when_flag_true(self): out_dict = ui.get_almanac_data( ra=180, dec=0, date="2000-1-1T0:0:0", return_full_dict=True ) - assert type(out_dict) == dict + assert isinstance(out_dict, Mapping) assert len(out_dict) == 39 assert type(out_dict["moon_sun_sep"]) == float @@ -199,17 +210,18 @@ def test_return_only_almanac_dict_when_flag_false(self): out_dict = ui.get_almanac_data( ra=180, dec=0, date="2000-1-1T0:0:0", return_full_dict=False ) - assert type(out_dict) == dict + assert isinstance(out_dict, Mapping) assert len(out_dict) == 9 @pytest.mark.webtest - def test_take_date_only_if_date_and_mjd_are_valid(self, capsys): - out_dict_date = ui.get_almanac_data(ra=180, dec=0, date="2000-1-1T0:0:0") - out_dict_both = ui.get_almanac_data( - ra=180, dec=0, mjd=50000, date="2000-1-1T0:0:0" - ) - output = capsys.readouterr()[0].strip() - assert output == "Warning: Both date and mjd are set. Using date" + def test_take_date_only_if_date_and_mjd_are_valid(self): + out_dict_date = ui.get_almanac_data( + ra=180, dec=0, date="2000-1-1T0:0:0") + with pytest.warns(UserWarning) as record: + out_dict_both = ui.get_almanac_data( + ra=180, dec=0, mjd=50000, date="2000-1-1T0:0:0" + ) + assert record[0].message.args[0] == "Both date and mjd are set. Using date" assert out_dict_both == out_dict_date @@ -218,7 +230,9 @@ class TestDocExamples: def test_example(self): skycalc = ui.SkyCalc() skycalc.get_almanac_data( - ra=83.8221, dec=-5.3911, date="2017-12-24T04:00:00", update_values=True + ra=83.8221, dec=-5.3911, + date="2017-12-24T04:00:00", + update_values=True ) tbl = skycalc.get_sky_spectrum() assert len(tbl) == 4606 @@ -233,7 +247,7 @@ def test_example(self): # assert out_dict["observatory"] == "2640" # assert out_dict["observatory_orig"] == "paranal" # -# def test_returns_corrected_SkyCalcParams_for_valid_observatory(self): +# def test_returns_corrected_SkyCalcParams_for_valid_observatory(self, skp): # skp["observatory"] = "paranal" # out_dict = ui.fix_observatory(skp) # assert out_dict["observatory"] == "2640" diff --git a/skycalc_ipy/ui.py b/skycalc_ipy/ui.py index f205747..31b7caf 100644 --- a/skycalc_ipy/ui.py +++ b/skycalc_ipy/ui.py @@ -1,5 +1,9 @@ -import os -import inspect +# -*- coding: utf-8 -*- +"""Skyclc IPY user interface.""" + +import logging +import warnings +from pathlib import Path from datetime import datetime as dt import yaml @@ -12,6 +16,7 @@ __all__ = ["SkyCalc"] +# TODO: this isn't used, but something VERY similar is done in core...... observatory_dict = { "lasilla": "2400", "paranal": "2640", @@ -21,12 +26,15 @@ "highanddry": "5000", } +PATH_HERE = Path(__file__).parent + class SkyCalc: + """Main UI class.""" + def __init__(self, ipt_str=None): if ipt_str is None: - dirname = os.path.dirname(inspect.getfile(inspect.currentframe())) - ipt_str = os.path.join(dirname, "params.yaml") + ipt_str = PATH_HERE / "params.yaml" params = load_yaml(ipt_str) @@ -39,20 +47,17 @@ def __init__(self, ipt_str=None): self.last_skycalc_response = None - def print_comments(self, param_names=None): - if param_names is None: - param_names = list(self.comments.keys()) - - if isinstance(param_names, str): - param_names = [param_names] + def print_comments(self, *param_names): + """Print descriptions of parameters. Print all if no names given.""" + param_names = param_names or list(self.comments.keys()) + maxkeylen = len(max(param_names, key=len)) for pname in param_names: - if pname not in self.comments: - print(f"{pname} not found") - else: - print(f"{pname} : {self.comments[pname]}") + comment = self.comments.get(pname, "") + print(f"{pname:>{maxkeylen}} : {comment}") def validate_params(self): + """Check allowed range for parameters.""" invalid_keys = [] for key in self.values: if self.check_type[key] == "no_check" or self.defaults[key] is None: @@ -78,14 +83,11 @@ def validate_params(self): if self.values[key] < self.allowed[key]: invalid_keys.append(key) - else: - pass - if invalid_keys: - print("See .comments[] for help") - print("The following entries are invalid:") + logging.warning("The following entries are invalid:") for key in invalid_keys: - print(f"'{key}' : {self.values[key]} : {self.comments[key]}") + logging.warning("'%s' : %s : %s", key, + self.values[key], self.comments[key]) return not invalid_keys @@ -98,6 +100,7 @@ def get_almanac_data( observatory=None, update_values=False, ): + """Query ESO Almanac with given parameters.""" if date is None and mjd is None: raise ValueError("Either date or mjd must be set") @@ -111,7 +114,7 @@ def get_almanac_data( def get_sky_spectrum(self, return_type="table", filename=None): """ - Retrieve a fits.HDUList object from the SkyCalc server + Retrieve a fits.HDUList object from the SkyCalc server. The HDUList can be returned in a variety of formats. @@ -135,14 +138,14 @@ def get_sky_spectrum(self, return_type="table", filename=None): - "fits": hdu (HDUList) """ - from astropy import table if not self.validate_params(): - raise ValueError("Object contains invalid parameters. Not calling ESO") + raise ValueError( + "Object contains invalid parameters. Not calling ESO") skm = SkyModel() - skm.callwith(self.values) + skm(**self.values) self.last_skycalc_response = skm.data if filename is not None: skm.write(filename) @@ -156,7 +159,7 @@ def get_sky_spectrum(self, return_type="table", filename=None): # Somehow, astropy doesn't quite parse the unit correctly. # Still, we shouldn't blindly overwrite it, so at least check. funit = tbl[colname].unit - if not str(funit) in ("ph/s/m2/micron/arcsec2", "None"): + if str(funit) not in ("ph/s/m2/micron/arcsec2", "None"): raise ValueError(f"Unexpected flux unit: {funit}") tbl[colname].unit = u.Unit("ph s-1 m-2 um-1 arcsec-2") @@ -208,8 +211,8 @@ def get_sky_spectrum(self, return_type="table", filename=None): points=tbl["lam"].data * tbl["lam"].unit, lookup_table=tbl["flux"].data * funit, ) - print( - "Warning: synphot doesn't accept surface brightnesses \n" + logging.warning( + "Synphot doesn't accept surface brightnesses \n" "The resulting spectrum should be multiplied by arcsec-2" ) @@ -229,7 +232,8 @@ def update(self, kwargs): def __setitem__(self, key, value): if key not in self.keys: - raise ValueError(key + " not in self.defaults") + raise KeyError(f"Key {key} is not defined. Only predefined keys " + "can be set. See SkyCalc.keys for a list of those.") self.values[key] = value def __getitem__(self, item): @@ -241,11 +245,12 @@ def keys(self): def load_yaml(ipt_str): - if ".yaml" in ipt_str.lower(): - if not os.path.exists(ipt_str): - raise ValueError(ipt_str + " not found") + # TODO: why not just load, what's all of this? + if ".yaml" in str(ipt_str).lower(): + if not ipt_str.exists(): + raise ValueError(f"{ipt_str} not found") - with open(ipt_str, "r") as fd: + with ipt_str.open("r", encoding="utf-8") as fd: fd = "\n".join(fd.readlines()) opts_dict = yaml.load(fd, Loader=yaml.FullLoader) else: @@ -263,15 +268,16 @@ def get_almanac_data( observatory=None, ): if date is not None and mjd is not None: - print("Warning: Both date and mjd are set. Using date") + warnings.warn("Both date and mjd are set. Using date", + UserWarning, stacklevel=2) skycalc_params = SkyCalc() - skycalc_params.values.update({"ra": ra, "dec": dec, "date": date, "mjd": mjd}) + skycalc_params.values.update( + {"ra": ra, "dec": dec, "date": date, "mjd": mjd}) if observatory is not None: skycalc_params.values["observatory"] = observatory skycalc_params.validate_params() - alm = AlmanacQuery(skycalc_params.values) - response = alm.query() + response = AlmanacQuery(skycalc_params.values)() if return_full_dict: skycalc_params.values.update(response)