From 8a382bb99d63e61b36408fd1be512a9a9f2e0345 Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Fri, 8 Jan 2021 16:34:37 -0500 Subject: [PATCH 01/19] Added uncertainties.py module --- sedkit/uncertainties.py | 421 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 421 insertions(+) create mode 100644 sedkit/uncertainties.py diff --git a/sedkit/uncertainties.py b/sedkit/uncertainties.py new file mode 100644 index 00000000..19cb59f6 --- /dev/null +++ b/sedkit/uncertainties.py @@ -0,0 +1,421 @@ +""" +Module to calculate uncertainties +""" +import numpy as np +from scipy.integrate import quad +from scipy.optimize import leastsq + + +class Unum(object): + """ + An object to handle math with uncertainties + """ + def __init__(self, nominal, upper, lower=None, n_samples=1000): + """ + Initialize a number with uncertainties + """ + # Store values + self.nominal = nominal + self.upper = upper + self.lower = lower or upper + self.n = n_samples + + def __repr__(self): + """ + repr method + """ + if self.upper == self.lower: + return '{}({})'.format(*self.value[:2]) + else: + return '{}(+{},-{})'.format(*self.value) + + def __add__(self, other): + """ + Add two numbers + + Parameters + ---------- + other: int, float, Unum + The number to add + + Returns + ------- + Unum + The Unum value + """ + # Generate distributions for each number + dist1 = self.sample_from_errors() + dist2 = other.sample_from_errors() + + # Do math + dist3 = dist1 + dist2 + + # Make a new Unum from the new nominal value and upper and lower quantiles + return Unum(*self.get_quantiles(dist3)) + + def __mul__(self, other): + """ + Multiply two numbers + + Parameters + ---------- + other: int, float, Unum + The number to multiply + + Returns + ------- + Unum + The Unum value + """ + # Generate distributions for each number + dist1 = self.sample_from_errors() + dist2 = other.sample_from_errors() + + # Do math + dist3 = dist1 * dist2 + + # Make a new Unum from the new nominal value and upper and lower quantiles + return Unum(*self.get_quantiles(dist3)) + + def __sub__(self, other): + """ + Subtract two numbers + + Parameters + ---------- + other: int, float, Unum + The number to subtract + + Returns + ------- + Unum + The Unum value + """ + # Generate distributions for each number + dist1 = self.sample_from_errors() + dist2 = other.sample_from_errors() + + # Do math + dist3 = dist1 - dist2 + + # Make a new Unum from the new nominal value and upper and lower quantiles + return Unum(*self.get_quantiles(dist3)) + + def __pow__(self, exp): + """ + Divide two numbers + + Parameters + ---------- + exp: int, float + The power to raise + + Returns + ------- + Unum + The Unum value + """ + # Generate distributions for each number + dist1 = self.sample_from_errors() + + # Do math + dist3 = np.power(dist1, exp) + + # Make a new Unum from the new nominal value and upper and lower quantiles + return Unum(*self.get_quantiles(dist3)) + + def __truediv__(self, other): + """ + Divide the number by another + + Parameters + ---------- + other: int, float, Unum + The number to divide by + + Returns + ------- + Unum + The Unum value + """ + # Generate distributions for each number + dist1 = self.sample_from_errors() + dist2 = other.sample_from_errors() + + # Do math + dist3 = dist1 / dist2 + + # Make a new Unum from the new nominal value and upper and lower quantiles + return Unum(*self.get_quantiles(dist3)) + + def __floordiv__(self, other): + """ + Floor divide the number by another + + Parameters + ---------- + other: int, float, Unum + The 
number to floor divide by + + Returns + ------- + Unum + The Unum value + """ + # Generate distributions for each number + dist1 = self.sample_from_errors() + dist2 = other.sample_from_errors() + + # Do math + dist3 = dist1 // dist2 + + # Make a new Unum from the new nominal value and upper and lower quantiles + return Unum(*self.get_quantiles(dist3)) + + def sample_from_errors(self, low_lim=None, up_lim=None): + """ + Function made to sample points given the 0.16, 0.5 and 0.84 quantiles of a parameter + In the case of unequal variances, this algorithm assumes a skew-normal distribution and samples from it. + If the variances are equal, it samples from a normal distribution. + + Parameters + ---------- + low_lim: float (optional) + Lower limits on the values of the samples + up_lim: float (optional) + Upper limits on the values of the samples + + Returns + ------- + The output are n samples from the distribution that best-matches the quantiles. + *The optional inputs (low_lim and up_lim) are lower and upper limits that the samples have to have; if any of the samples + surpasses those limits, new samples are drawn until no samples do. Note that this changes the actual variances of the samples. + """ + if (self.upper != self.lower): + + # If errors are assymetric, sample from a skew-normal distribution given the location parameter (assumed to be the median), self.upper and self.lower. + + # First, find the parameters mu, sigma and alpha of the skew-normal distribution that + # best matches the observed quantiles: + sknorm = SkewNormal() + sknorm.fit(*self.value) + + # And now sample n values from the distribution: + samples = sknorm.sample(self.n) + + # If a lower limit or an upper limit is given, then search if any of the samples surpass + # those limits, and sample again until no sample surpasses those limits: + if low_lim is not None: + while True: + idx = np.where(samples < low_lim)[0] + l_idx = len(idx) + if l_idx > 0: + samples[idx] = sknorm.sample(l_idx) + else: + break + + if up_lim is not None: + while True: + idx = np.where(samples > up_lim)[0] + l_idx = len(idx) + if l_idx > 0: + samples[idx] = sknorm.sample(l_idx) + else: + break + return samples + + else: + + # If errors are symmetric, sample from a gaussian + samples = np.random.normal(self.nominal, self.upper, n) + + # If a lower limit or an upper limit is given, then search if any of the samples surpass + # those limits, and sample again until no sample surpasses those limits: + if low_lim is not None: + while True: + idx = np.where(samples < low_lim)[0] + l_idx = len(idx) + if l_idx > 0: + samples[idx] = np.random.normal(self.nominal, self.upper, l_idx) + else: + break + + if up_lim is not None: + while True: + idx = np.where(samples > up_lim)[0] + l_idx = len(idx) + if l_idx > 0: + samples[idx] = np.random.normal(self.nominal, self.upper, l_idx) + else: + break + return samples + + @staticmethod + def get_quantiles(dist, alpha=0.68, method='median'): + """ + Determine the median, upper and lower quantiles of a distribution + + Parameters + ---------- + dist: array-like + The distribution to measure + alpha: float + The + method: str + The method used to determine the nominal value + + Returns + ------- + tuple + Median of the parameter, upper credibility bound, lower credibility bound + """ + # Order the distribution + ordered_dist = dist[np.argsort(dist)] + + # Define the number of samples from posterior + nsamples = len(dist) + nsamples_at_each_side = int(nsamples * (alpha / 2.) 
+ 1) + + if method == 'median': + + # Number of points is even + if nsamples % 2 == 0.0: + med_idx_upper = int(nsamples / 2.) + 1 + med_idx_lower = med_idx_upper - 1 + param = (ordered_dist[med_idx_upper] + ordered_dist[med_idx_lower]) / 2. + + else: + med_idx_upper = med_idx_lower = int(nsamples / 2.) + param = ordered_dist[med_idx_upper] + + q_upper = ordered_dist[med_idx_upper + nsamples_at_each_side] + q_lower = ordered_dist[med_idx_lower - nsamples_at_each_side] + + return param, q_upper - param, param - q_lower + + @property + def value(self): + """ + The nominal, upper, and lower values + """ + return self.nominal, self.upper, self.lower + + +class SkewNormal(object): + """ + Description + ----------- + This class defines a SkewNormal object, which generates a SkewNormal distribution given the quantiles + from which you can then sample datapoints from. + """ + def __init__(self): + self.mu = 0.0 + self.sigma = 0.0 + self.alpha = 0.0 + + def fit(self, median, sigma1, sigma2): + """ + This function fits a Skew Normal distribution given + the median, upper error bars (self.upper) and lower error bar (sigma2). + """ + + # First, define the sign of alpha, which should be positive if right skewed + # and negative if left skewed: + alpha_sign = np.sign(sigma1 - sigma2) + + # Now define the residuals of the least-squares problem: + def residuals(p, data, x): + mu, sqrt_sigma, sqrt_alpha = p + return data - model(x, mu, sqrt_sigma, sqrt_alpha) + + # Define the model used in the residuals: + def model(x, mu, sqrt_sigma, sqrt_alpha): + """ + Note that we pass the square-root of the scale (sigma) and shape (alpha) parameters, + in order to define the sign of the former to be positive and of the latter to be fixed given + the values of self.upper and sigma2: + """ + return self.cdf(x, mu, sqrt_sigma**2, alpha_sign * sqrt_alpha**2) + + # Define the quantiles: + y = np.array([0.15866, 0.5, 0.84134]) + + # Define the values at which we observe the quantiles: + x = np.array([median - sigma2, median, median + sigma1]) + + # Start assuming that mu = median, sigma = mean of the observed sigmas, and alpha = 0 (i.e., start from a gaussian): + guess = (median, np.sqrt( 0.5 * (sigma1 + sigma2)), 0) + + # Perform the non-linear least-squares optimization: + plsq = leastsq(residuals, guess, args=(y, x))[0] + + self.mu, self.sigma, self.alpha = plsq[0], plsq[1]**2, alpha_sign * plsq[2]**2 + + def sample(self, n): + """ + This function samples n points from a skew normal distribution using the + method outlined by Azzalini here: http://azzalini.stat.unipd.it/SN/faq-r.html. + """ + # Define delta: + delta = self.alpha / np.sqrt(1 + self.alpha**2) + + # Now sample u0,u1 having marginal distribution ~N(0,1) with correlation delta: + u0 = np.random.normal(0, 1, n) + v = np.random.normal(0, 1, n) + u1 = delta * u0 + np.sqrt(1 - delta**2) * v + + # Now, u1 will be random numbers sampled from skew-normal if the corresponding values + # for which u0 are shifted in sign. To do this, we check the values for which u0 is negative: + idx_negative = np.where(u0 < 0)[0] + u1[idx_negative] = -u1[idx_negative] + + # Finally, we change the location and scale of the generated random-numbers and return the samples: + return self.mu + self.sigma * u1 + + @staticmethod + def cdf(x, mu, sigma, alpha): + """ + This function simply calculates the CDF at x given the parameters + mu, sigma and alpha of a Skew-Normal distribution. It takes values or + arrays as inputs. 
+ """ + if type(x) is np.ndarray: + out = np.zeros(len(x)) + for i in range(len(x)): + out[i] = quad(lambda x: SkewNormal.pdf(x, mu, sigma, alpha), -np.inf, x[i])[0] + return out + + else: + return quad(lambda x: SkewNormal.pdf(x, mu, sigma, alpha), -np.inf, x)[0] + + @staticmethod + def pdf(x, mu, sigma, alpha): + """ + This function returns the value of the Skew Normal PDF at x, given + mu, sigma and alpha + """ + def erf(x): + # save the sign of x + sign = np.sign(x) + x = abs(x) + + # constants + a1 = 0.254829592 + a2 = -0.284496736 + a3 = 1.421413741 + a4 = -1.453152027 + a5 = 1.061405429 + p = 0.3275911 + + # A&S formula 7.1.26 + t = 1.0/(1.0 + p * x) + y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * np.exp(-x * x) + + return sign * y + + def palpha(y, alpha): + phi = np.exp(-y**2. / 2.0) / np.sqrt(2.0 * np.pi) + PHI = (erf(y * alpha / np.sqrt(2)) + 1.0) * 0.5 + return 2 * phi * PHI + + return palpha((x - mu) / sigma, alpha) * (1. / sigma) From 903cd24c8194c4fba24c47f5ed5c001c8eee7b81 Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Fri, 8 Jan 2021 17:11:10 -0500 Subject: [PATCH 02/19] Updated Unum with plot method for visual inspection --- sedkit/uncertainties.py | 51 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/sedkit/uncertainties.py b/sedkit/uncertainties.py index 19cb59f6..b4d45f7c 100644 --- a/sedkit/uncertainties.py +++ b/sedkit/uncertainties.py @@ -1,6 +1,7 @@ """ Module to calculate uncertainties """ +from bokeh.plotting import show, figure import numpy as np from scipy.integrate import quad from scipy.optimize import leastsq @@ -10,7 +11,7 @@ class Unum(object): """ An object to handle math with uncertainties """ - def __init__(self, nominal, upper, lower=None, n_samples=1000): + def __init__(self, nominal, upper, lower=None, n_samples=10000, sig_figs=2, method='median'): """ Initialize a number with uncertainties """ @@ -19,6 +20,8 @@ def __init__(self, nominal, upper, lower=None, n_samples=1000): self.upper = upper self.lower = lower or upper self.n = n_samples + self.sig_figs = sig_figs + self.method = method def __repr__(self): """ @@ -227,7 +230,7 @@ def sample_from_errors(self, low_lim=None, up_lim=None): else: # If errors are symmetric, sample from a gaussian - samples = np.random.normal(self.nominal, self.upper, n) + samples = np.random.normal(self.nominal, self.upper, self.n) # If a lower limit or an upper limit is given, then search if any of the samples surpass # those limits, and sample again until no sample surpasses those limits: @@ -250,8 +253,7 @@ def sample_from_errors(self, low_lim=None, up_lim=None): break return samples - @staticmethod - def get_quantiles(dist, alpha=0.68, method='median'): + def get_quantiles(self, dist, alpha=0.68): """ Determine the median, upper and lower quantiles of a distribution @@ -276,7 +278,7 @@ def get_quantiles(dist, alpha=0.68, method='median'): nsamples = len(dist) nsamples_at_each_side = int(nsamples * (alpha / 2.) 
+ 1) - if method == 'median': + if self.method == 'median': # Number of points is even if nsamples % 2 == 0.0: @@ -291,7 +293,37 @@ def get_quantiles(dist, alpha=0.68, method='median'): q_upper = ordered_dist[med_idx_upper + nsamples_at_each_side] q_lower = ordered_dist[med_idx_lower - nsamples_at_each_side] - return param, q_upper - param, param - q_lower + return param.round(self.sig_figs), (q_upper - param).round(self.sig_figs), (param - q_lower).round(self.sig_figs) + + def plot(self, bins=None): + """ + Plot the distribution with stats + + Parameters + ---------- + bins: int + The number of bins for the histogram + + Returns + ------- + bokeh.plotting.figure.Figure + """ + # Make the figure + fig = figure() + + # Make a histogram of the distribution + dist = self.sample_from_errors() + hist, edges = np.histogram(dist, density=True, bins=bins or min(self.n, 50)) + fig.quad(top=hist, bottom=0, left=edges[:-1], right=edges[1:], color='wheat') + + # Add stats to plot + lower, nominal, upper = self.quantiles + fig.line([lower] * 2, [min(hist), max(hist)], line_width=2, color='red', legend_label='lower (-{})'.format(self.lower)) + fig.line([nominal] * 2, [min(hist), max(hist)], line_width=2, color='black', legend_label='{} ({})'.format(self.method, self.nominal)) + fig.line([upper] * 2, [min(hist), max(hist)], line_width=2, color='blue', legend_label='upper (+{})'.format(self.upper)) + + # Show the plot + show(fig) @property def value(self): @@ -300,6 +332,13 @@ def value(self): """ return self.nominal, self.upper, self.lower + @property + def quantiles(self): + """ + The [0.15866, 0.5, 0.84134] quantiles + """ + return self.nominal - self.lower, self.nominal, self.nominal + self.upper + class SkewNormal(object): """ From 1879105675b13c7f9ed3b4ca639a1aef39ba95f9 Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Sun, 17 Jan 2021 14:00:57 -0500 Subject: [PATCH 03/19] Added IRAS query support --- environment.yml | 2 +- requirements.txt | 2 +- sedkit/query.py | 211 +++++++++++++++++++++++++++++++++++----- sedkit/sed.py | 154 ++++++++++------------------- sedkit/spectrum.py | 5 +- sedkit/uncertainties.py | 8 ++ 6 files changed, 250 insertions(+), 132 deletions(-) diff --git a/environment.yml b/environment.yml index a0b4f4a7..60243fff 100644 --- a/environment.yml +++ b/environment.yml @@ -17,6 +17,6 @@ dependencies: - jupyter>=1.0.0 - ipython>=7.12.0 - pip: - - svo-filters==0.2.19 + - svo-filters==0.3.0 - dustmaps==1.0.4 - numpydoc==0.8.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index b7f41872..626bc9c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ bokeh>=2.2.3 astroquery>=0.4.1 scipy>=1.4.1 pandas==0.23.4 -svo_filters==0.2.19 +svo_filters==0.3.0 dustmaps>=1.0.4 selenium>=2.49.2 pytest>=6.2.1 diff --git a/sedkit/query.py b/sedkit/query.py index 454b7c03..1ffbdb22 100755 --- a/sedkit/query.py +++ b/sedkit/query.py @@ -5,10 +5,15 @@ """ Interface with astroquery to fetch data """ +import os +from urllib.request import urlretrieve from astropy.coordinates import Angle, SkyCoord +from astropy.io import fits import astropy.units as q from astroquery.vizier import Vizier +from astroquery.sdss import SDSS +import numpy as np from . 
import utilities as u @@ -17,13 +22,170 @@ PHOT_CATALOGS = {'2MASS': {'catalog': 'II/246/out', 'cols': ['Jmag', 'Hmag', 'Kmag'], 'names': ['2MASS.J', '2MASS.H', '2MASS.Ks']}, 'WISE': {'catalog': 'II/328/allwise', 'cols': ['W1mag', 'W2mag', 'W3mag', 'W4mag'], 'names': ['WISE.W1', 'WISE.W2', 'WISE.W3', 'WISE.W4']}, 'PanSTARRS': {'catalog': 'II/349/ps1', 'cols': ['gmag', 'rmag', 'imag', 'zmag', 'ymag'], 'names': ['PS1.g', 'PS1.r', 'PS1.i', 'PS1.z', 'PS1.y']}, - 'Gaia': {'catalog': 'I/345/gaia2', 'cols': ['Gmag'], 'names': ['Gaia.G']}, + 'Gaia': {'catalog': 'I/345/gaia2', 'cols': ['Plx', 'Gmag'], 'names': ['parallax', 'Gaia.G']}, 'SDSS': {'catalog': 'V/147', 'cols': ['umag', 'gmag', 'rmag', 'imag', 'zmag'], 'names': ['SDSS.u', 'SDSS.g', 'SDSS.r', 'SDSS.i', 'SDSS.z']}} Vizier.columns = ["**", "+_r"] -def query_vizier(catalog, target=None, sky_coords=None, cols=None, wildcards=['e_*'], names=None, search_radius=20*q.arcsec, idx=0, places=3, cat_name=None, verbose=True, **kwargs): +def query_SDSS_optical_spectra(coords, idx=0, verbose=True): + """ + Query for SDSS spectra + + Parameters + ---------- + coords: astropy.coordinates.SkyCoord + The coordinates to query + idx: int + The index of the target to use from the results table + + Returns + ------- + list + The [W, F, E] spectrum of the target + """ + + # Fetch results + results = SDSS.query_region(coords, spectro=True) + + # Print info + if verbose: + if results is None: + n_rec = 0 if results is None else len(results) + print("{} record{} found in SDSS {} data.".format(n_rec, '' if n_rec == 1 else 's', survey)) + + if n_rec == 0: + + return None, None, None + + else: + + # Download the spectrum file + hdu = SDSS.get_spectra(matches=results)[idx] + + # Get the spectrum data + data = hdu[1].data + + # Convert from log to linear units in Angstroms + wav = 10**data['loglam'] * q.AA + + # Convert to FLAM units + flx = data['flux'] * 1E-17 * q.erg / q.s / q.cm**2 / q.AA + err = data['flux'] * 1E-18 * q.erg / q.s / q.cm**2 / q.AA # TODO: Where's the error? 
+ + # Metadata + ref = 'SDSS' + header = hdu[0].header + + return [wav, flx, err], ref, header + + +def query_SDSS_apogee_spectra(coords, verbose=True, **kwargs): + """ + Query the APOGEE survey data + + Parameters + ---------- + coords: astropy.coordinates.SkyCoord + The coordinates to query + + Returns + ------- + list + The [W, F, E] spectrum of the target + """ + + # Query vizier for spectra + catalog = 'III/284/allstars' + results = query_vizier(catalog, col_names=['Ascap', 'File', 'Tel', 'Field'], sky_coords=coords, wildcards=[], verbose=verbose) + + if len(results) == 0: + + return None, None, None + + else: + + ascap, file, tel, field = [row[1] for row in results] + + # Construct URL + url = 'https://data.sdss.org/sas/dr16/apogee/spectro/redux/r12/stars/{}/{}/{}'.format(tel, field, file) + + # Download the file + urlretrieve(url, file) + + # Get data + hdu = fits.open(file) + header = hdu[0].header + + # Generate wavelength + wav = 10**(np.linspace(header['CRVAL1'], header['CRVAL1'] + (header['CDELT1'] * header['NWAVE']), header['NWAVE'])) + wav *= q.AA + + # Get flux and error + flx = hdu[1].data[0] * 1E-17 * q.erg / q.s / q.cm**2 / q.AA + err = hdu[2].data[0] * 1E-17 * q.erg / q.s / q.cm**2 / q.AA + + # Delete file + hdu.close() + os.system('rm {}'.format(file)) + + return [wav, flx, err], catalog, header + + +def query_IRAS_spectra(coords, verbose=True, **kwargs): + """ + Query the IRAS survey data + + Parameters + ---------- + coords: astropy.coordinates.SkyCoord + The coordinates to query + + Returns + ------- + list + The [W, F, E] spectrum of the target + """ + + # Query vizier for spectra + catalog = 'III/197/lrs' + results = query_vizier(catalog, sky_coords=coords, wildcards=[], verbose=verbose) + # results = query_vizier(catalog, col_names=['Ascap', 'File', 'Tel', 'Field'], sky_coords=coords, wildcards=[], verbose=verbose) + + if len(results) == 0: + + return None, None, None + + else: + + file = [row[1] for row in results] + + # Construct URL + url = 'https://cdsarc.unistra.fr/ftp/III/197/{}'.format(file) + + # Download the file + urlretrieve(url, file) + + # Get data + hdu = fits.open(file) + header = hdu[0].header + + # Generate wavelength + wav = 10**(np.linspace(header['CRVAL1'], header['CRVAL1'] + (header['CDELT1'] * header['NWAVE']), header['NWAVE'])) + wav *= q.AA + + # Get flux and error + flx = hdu[1].data[0] * 1E-17 * q.erg / q.s / q.cm**2 / q.AA + err = hdu[2].data[0] * 1E-17 * q.erg / q.s / q.cm**2 / q.AA + + # Delete file + hdu.close() + os.system('rm {}'.format(file)) + + return [wav, flx, err], catalog, header + + +def query_vizier(catalog, target=None, sky_coords=None, col_names=None, wildcards=['e_*'], target_names=None, search_radius=20 * q.arcsec, idx=0, cat_name=None, verbose=True, **kwargs): """ Search Vizier for photometry in the given catalog @@ -35,12 +197,12 @@ def query_vizier(catalog, target=None, sky_coords=None, cols=None, wildcards=['e A target name to search for, e.g. 'Trappist-1' sky_coords: astropy.coordinates.SkyCoord (optional) The sky coordinates to search - cols: sequence + col_names: sequence The list of column names to fetch wildcards: sequence A list of wildcards for each column name, e.g. 
'e_*' includes errors target_names: sequence (optional) - The list of renamed columns, must be the same length as band_names + The list of renamed columns, must be the same length as col_names search_radius: astropy.units.quantity.Quantity The search radius for the Vizier query idx: int @@ -50,46 +212,45 @@ def query_vizier(catalog, target=None, sky_coords=None, cols=None, wildcards=['e if catalog in PHOT_CATALOGS: meta = PHOT_CATALOGS[catalog] catalog = meta['catalog'] - cols = cols or meta['cols'] - names = names or meta['names'] + cols = col_names or meta['cols'] + names = target_names or meta['names'] + else: + cols = col_names + names = target_names # Name for the catalog - if cat_name is None: - cat_name = catalog - - # If search_radius is explicitly set, use that - if search_radius is not None and isinstance(sky_coords, SkyCoord): - viz_cat = Vizier.query_region(sky_coords, radius=search_radius, catalog=[catalog]) + cat_name = cat_name or catalog - # ...or get photometry using designation... - elif isinstance(target, str): + # Get photometry using designation... + if isinstance(target, str): viz_cat = Vizier.query_object(target, catalog=[catalog]) + # ...or use coordinates... + elif search_radius is not None and isinstance(sky_coords, SkyCoord): + viz_cat = Vizier.query_region(sky_coords, radius=search_radius, catalog=[catalog]) + # ...or abort else: - viz_cat = None + viz_cat = [] # Check there are columns to fetch if cols is None: - raise ValueError("No column names to fetch!") - + cols = viz_cat[0].colnames + print(cols) # Check for wildcards - if wildcards is None: - wildcards = [] + wildcards = wildcards or [] # Check for target names or just use native column names - if names is None: - names = cols + names = names or cols # Print info if verbose: n_rec = len(viz_cat) print("{} record{} found in {}.".format(n_rec, '' if n_rec == 1 else 's', cat_name)) - results = [] - # Parse the record - if viz_cat is not None and len(viz_cat) > 0: + results = [] + if len(viz_cat) > 0: if len(viz_cat) > 1: print('{} {} records found.'.format(len(viz_cat), name)) @@ -101,7 +262,7 @@ def query_vizier(catalog, target=None, sky_coords=None, cols=None, wildcards=['e for name, viz in zip(names, cols): fetch = [viz] + [wc.replace('*', viz) for wc in wildcards] if all([i in rec.columns for i in fetch]): - data = [round(val, places) if u.isnumber(val) else val for val in rec[fetch]] + data = [val for val in rec[fetch]] results.append([name] + data + [ref]) else: print("{}: Could not find all those columns".format(fetch)) diff --git a/sedkit/sed.py b/sedkit/sed.py index 56cb0be5..e4d1f15e 100755 --- a/sedkit/sed.py +++ b/sedkit/sed.py @@ -30,6 +30,7 @@ from . import utilities as u from . import spectrum as sp from . import isochrone as iso +from . import query as qu from . import relations as rel from . 
import modelgrid as mg @@ -273,11 +274,11 @@ def add_photometry(self, band, mag, mag_unc=None, system='Vega', **kwargs): The magnitude system of the input data, ['Vega', 'AB'] """ # Make sure the magnitudes are floats - if not isinstance(mag, float): + if not isinstance(mag, (float, np.float32)): raise TypeError("{}: Magnitude must be a float.".format(type(mag))) # Check the uncertainty - if not isinstance(mag_unc, (float, type(None))): + if not isinstance(mag_unc, (float, np.float32, type(None), np.ma.core.MaskedConstant)): raise TypeError("{}: Magnitude uncertainty must be a float, NaN, or None.".format(type(mag_unc))) # Make NaN if 0 @@ -293,8 +294,7 @@ def add_photometry(self, band, mag, mag_unc=None, system='Vega', **kwargs): print('Not a recognized bandpass: {}'.format(band)) # Convert to Vega - if system == 'AB': - mag, mag_unc = u.convert_mag(mag, mag_unc, old=system, new=self.mag_system) + mag, mag_unc = u.convert_mag(band, mag, mag_unc, old=system, new=self.mag_system) # Convert bandpass to desired units bp.wave_units = self.wave_units @@ -974,12 +974,9 @@ def find_2MASS(self, **kwargs): """ Search for 2MASS data """ - self.find_photometry('2MASS', 'II/246/out', - ['Jmag', 'Hmag', 'Kmag'], - ['2MASS.J', '2MASS.H', '2MASS.Ks'], - **kwargs) + self.find_photometry('2MASS', **kwargs) - def find_Gaia(self, search_radius=None, catalog='I/345/gaia2', include=['parallax', 'photometry'], idx=0): + def find_Gaia(self, search_radius=None, include=['parallax', 'photometry'], idx=0, **kwargs): """ Search for Gaia data @@ -992,132 +989,84 @@ def find_Gaia(self, search_radius=None, catalog='I/345/gaia2', include=['paralla idx: int The index of the results to use """ - # Make sure there are coordinates - if not isinstance(self.sky_coords, SkyCoord): - raise TypeError("Can't find Gaia data without coordinates!") + # Get the Vizier catalog + results = qu.query_vizier('Gaia', target=self.name, sky_coords=self.sky_coords, search_radius=search_radius or self.search_radius, verbose=self.verbose, idx=idx, **kwargs) - # Query the catalog - # See if the designation was fetched by Simbad - des = [name for name in self.all_names if name.startswith(name)] - - # If search_radius is explicitly set, use that - if search_radius is not None: - viz_cat = Vizier.query_region(self.sky_coords, radius=search_radius, catalog=[catalog]) - - # ...or get photometry using designation... 
- elif len(des) > 0: - viz_cat = Vizier.query_object(des[0], catalog=[catalog]) - - # ...or from the coordinates - else: - viz_cat = Vizier.query_region(self.sky_coords, radius=self.search_radius, catalog=[catalog]) - - # Print info - if self.verbose: - n_rec = len(viz_cat) - print("{} record{} found in Gaia DR2.".format(n_rec, '' if n_rec == 1 else 's')) - - # Parse the records - if len(viz_cat) > 0: + # Parse the record + if len(results) == 2: - # Grab the first record if 'parallax' in include: - parallax = list(viz_cat[0][idx][['Plx', 'e_Plx']]) - self.parallax = parallax[0] * q.mas, parallax[1] * q.mas + self.parallax = results[0][1] * q.mas, results[0][2] * q.mas + self.refs['parallax'] = results[0][3] - # Get Gband while we're here if 'photometry' in include: - try: - mag, unc = list(viz_cat[0][0][['Gmag', 'e_Gmag']]) - self.add_photometry('Gaia.G', mag, unc) - except: - pass + band, mag, unc, ref = results[1] + self.add_photometry(band, mag, unc, ref=ref) def find_PanSTARRS(self, **kwargs): """ Search for PanSTARRS data """ - self.find_photometry('PanSTARRS', 'II/349/ps1', - ['gmag', 'rmag', 'imag', 'zmag', 'ymag'], - ['PS1.g', 'PS1.r', 'PS1.i', 'PS1.z', 'PS1.y'], - **kwargs) + self.find_photometry('PanSTARRS', **kwargs) - def find_photometry(self, name, catalog, band_names, target_names=None, search_radius=None, idx=0, **kwargs): + def find_photometry(self, catalog, col_names=None, target_names=None, search_radius=None, idx=0, **kwargs): """ Search Vizier for photometry in the given catalog Parameters ---------- - name: str - The informal name of the catalog, e.g. '2MASS' catalog: str The Vizier catalog address, e.g. 'II/246/out' - band_names: sequence + col_names: sequence The list of column names to treat as bandpasses target_names: sequence (optional) - The list of renamed columns, must be the same length as band_names + The list of renamed columns, must be the same length as col_names search_radius: astropy.units.quantity.Quantity The search radius for the Vizier query idx: int The index of the record to use if multiple Vizier results """ - # See if the designation was fetched by Simbad - des = [name for name in self.all_names if name.startswith(name)] + # Get the Vizier catalog + results = qu.query_vizier(catalog, col_names=col_names, target_names=target_names, target=self.name, sky_coords=self.sky_coords, search_radius=search_radius or self.search_radius, verbose=self.verbose, idx=idx, **kwargs) - # If search_radius is explicitly set, use that - if search_radius is not None and isinstance(self.sky_coords, SkyCoord): - viz_cat = Vizier.query_region(self.sky_coords, radius=search_radius, catalog=[catalog]) + # Parse the record + for result in results: - # ...or get photometry using designation... - elif len(des) > 0: - viz_cat = Vizier.query_object(des[0], catalog=[catalog]) + # Get result + band, mag, unc, ref = result - # ...or from the coordinates... 
- elif isinstance(self.sky_coords, SkyCoord): - viz_cat = Vizier.query_region(self.sky_coords, radius=self.search_radius, catalog=[catalog]) + # Ensure Vegamag + system = 'AB' if 'SDSS' in band else 'Vega' - # ...or abort - else: - viz_cat = None + self.add_photometry(band, mag, unc, ref=ref, system=system) - if target_names is None: - target_names = band_names + def find_SDSS(self, **kwargs): + """ + Search for SDSS data + """ + self.find_photometry('SDSS', **kwargs) - # Print info - if self.verbose: - n_rec = 0 if viz_cat is None else len(viz_cat) - print("{} record{} found in {}.".format(n_rec, '' if n_rec == 1 else 's', name)) + def find_SDSS_spectra(self, surveys=['optical', 'apogee'], **kwargs): + """ + Search for SDSS spectra + """ + if 'optical' in surveys: - # Parse the record - if viz_cat is not None and len(viz_cat) > 0: - if len(viz_cat) > 1: - print('{} {} records found.'.format(len(viz_cat), name)) + # Query spectra + data, ref, header = qu.query_SDSS_optical_spectra(self.sky_coords, verbose=self.verbose, **kwargs) - # Grab the record - rec = viz_cat[0][idx] - ref = viz_cat[0].meta['name'] + # Add the spectrum to the SED + if data is not None: + self.add_spectrum(data, ref=ref, header=header) - # Pull out the photometry - for band, viz in zip(target_names, band_names): - try: - mag, unc = list(rec[[viz, 'e_' + viz]]) - mag, unc = round(float(mag), 3), round(float(unc), 3) - - # Convert to Vegamag - if name == 'SDSS': - mag, unc = u.convert_mag(band, mag, unc, old='AB', new='Vega') - self.add_photometry(band, mag, unc, ref=ref) - except IOError: - pass + if 'apogee' in surveys: - def find_SDSS(self, **kwargs): - """ - Search for SDSS data - """ - self.find_photometry('SDSS', 'V/147', - ['umag', 'gmag', 'rmag', 'imag', 'zmag'], - ['SDSS.u', 'SDSS.g', 'SDSS.r', 'SDSS.i', 'SDSS.z'], - **kwargs) + # Query spectra + data, ref, header = qu.query_SDSS_apogee_spectra(self.sky_coords, verbose=self.verbose, **kwargs) + + # Add the spectrum to the SED + if data is not None: + self.add_spectrum(data, ref=ref, header=header) def find_Simbad(self, search_radius=None, idx=0): """ @@ -1198,10 +1147,7 @@ def find_WISE(self, **kwargs): """ Search for WISE data """ - self.find_photometry('WISE', 'II/328/allwise', - ['W1mag', 'W2mag', 'W3mag', 'W4mag'], - ['WISE.W1', 'WISE.W2', 'WISE.W3', 'WISE.W4'], - **kwargs) + self.find_photometry('WISE', **kwargs) def fit_blackbody(self, fit_to='app_phot_SED', Teff_init=4000, epsilon=0.0001, acc=0.05, trim=[], norm_to=[]): """ diff --git a/sedkit/spectrum.py b/sedkit/spectrum.py index 6801e0cf..d3e3df36 100755 --- a/sedkit/spectrum.py +++ b/sedkit/spectrum.py @@ -51,7 +51,7 @@ class Spectrum: A class to store, calibrate, fit, and plot a single spectrum """ def __init__(self, wave, flux, unc=None, snr=None, trim=None, name=None, - ref=None, verbose=False, **kwargs): + ref=None, header=None, verbose=False, **kwargs): """Initialize the Spectrum object Parameters @@ -73,6 +73,8 @@ def __init__(self, wave, flux, unc=None, snr=None, trim=None, name=None, A name for the spectrum ref: str A reference for the data + header: str + The header for the spectrum file verbose: bool Print helpful stuff """ @@ -80,6 +82,7 @@ def __init__(self, wave, flux, unc=None, snr=None, trim=None, name=None, self.verbose = verbose self.name = name or 'New Spectrum' self.ref = ref + self.header = header # Make sure the arrays are the same shape if not wave.shape == flux.shape and ((unc is None) or not (unc.shape == flux.shape)): diff --git a/sedkit/uncertainties.py 
b/sedkit/uncertainties.py index b4d45f7c..2fd10b64 100644 --- a/sedkit/uncertainties.py +++ b/sedkit/uncertainties.py @@ -7,6 +7,14 @@ from scipy.optimize import leastsq +def trapz(f, a, b, n): + h = (b - a) / float(n) + s = 0.5 * (f(a) + f(b)) + for i in range(1, n, 1): + s = s + f(a + i * h) + return h * s + + class Unum(object): """ An object to handle math with uncertainties From 93c96a97b23c105214e905d4fbc72eb7f29292f7 Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Sun, 17 Jan 2021 14:04:06 -0500 Subject: [PATCH 04/19] Commented out IRAS query since not ready for prime time --- sedkit/query.py | 102 ++++++++++++++++++++++++------------------------ 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/sedkit/query.py b/sedkit/query.py index 1ffbdb22..8313e128 100755 --- a/sedkit/query.py +++ b/sedkit/query.py @@ -132,57 +132,57 @@ def query_SDSS_apogee_spectra(coords, verbose=True, **kwargs): return [wav, flx, err], catalog, header -def query_IRAS_spectra(coords, verbose=True, **kwargs): - """ - Query the IRAS survey data - - Parameters - ---------- - coords: astropy.coordinates.SkyCoord - The coordinates to query - - Returns - ------- - list - The [W, F, E] spectrum of the target - """ - - # Query vizier for spectra - catalog = 'III/197/lrs' - results = query_vizier(catalog, sky_coords=coords, wildcards=[], verbose=verbose) - # results = query_vizier(catalog, col_names=['Ascap', 'File', 'Tel', 'Field'], sky_coords=coords, wildcards=[], verbose=verbose) - - if len(results) == 0: - - return None, None, None - - else: - - file = [row[1] for row in results] - - # Construct URL - url = 'https://cdsarc.unistra.fr/ftp/III/197/{}'.format(file) - - # Download the file - urlretrieve(url, file) - - # Get data - hdu = fits.open(file) - header = hdu[0].header - - # Generate wavelength - wav = 10**(np.linspace(header['CRVAL1'], header['CRVAL1'] + (header['CDELT1'] * header['NWAVE']), header['NWAVE'])) - wav *= q.AA - - # Get flux and error - flx = hdu[1].data[0] * 1E-17 * q.erg / q.s / q.cm**2 / q.AA - err = hdu[2].data[0] * 1E-17 * q.erg / q.s / q.cm**2 / q.AA - - # Delete file - hdu.close() - os.system('rm {}'.format(file)) - - return [wav, flx, err], catalog, header +# def query_IRAS_spectra(coords, verbose=True, **kwargs): +# """ +# Query the IRAS survey data +# +# Parameters +# ---------- +# coords: astropy.coordinates.SkyCoord +# The coordinates to query +# +# Returns +# ------- +# list +# The [W, F, E] spectrum of the target +# """ +# +# # Query vizier for spectra +# catalog = 'III/197/lrs' +# results = query_vizier(catalog, sky_coords=coords, wildcards=[], verbose=verbose) +# # results = query_vizier(catalog, col_names=['Ascap', 'File', 'Tel', 'Field'], sky_coords=coords, wildcards=[], verbose=verbose) +# +# if len(results) == 0: +# +# return None, None, None +# +# else: +# +# file = [row[1] for row in results] +# +# # Construct URL +# url = 'https://cdsarc.unistra.fr/ftp/III/197/{}'.format(file) +# +# # Download the file +# urlretrieve(url, file) +# +# # Get data +# hdu = fits.open(file) +# header = hdu[0].header +# +# # Generate wavelength +# wav = 10**(np.linspace(header['CRVAL1'], header['CRVAL1'] + (header['CDELT1'] * header['NWAVE']), header['NWAVE'])) +# wav *= q.AA +# +# # Get flux and error +# flx = hdu[1].data[0] * 1E-17 * q.erg / q.s / q.cm**2 / q.AA +# err = hdu[2].data[0] * 1E-17 * q.erg / q.s / q.cm**2 / q.AA +# +# # Delete file +# hdu.close() +# os.system('rm {}'.format(file)) +# +# return [wav, flx, err], catalog, header def query_vizier(catalog, target=None, 
sky_coords=None, col_names=None, wildcards=['e_*'], target_names=None, search_radius=20 * q.arcsec, idx=0, cat_name=None, verbose=True, **kwargs): From 76736322bc6af3159aff60d42e65719bfce1dab4 Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Mon, 18 Jan 2021 01:32:00 -0500 Subject: [PATCH 05/19] Added MCMC fitting to Spectrum class --- environment.yml | 3 +- requirements.txt | 3 +- sedkit/mcmc.py | 251 +++++++++++++++++++++++++++++++++++++++++++++ sedkit/spectrum.py | 39 +++++++ 4 files changed, 294 insertions(+), 2 deletions(-) create mode 100644 sedkit/mcmc.py diff --git a/environment.yml b/environment.yml index 60243fff..b0185f7b 100644 --- a/environment.yml +++ b/environment.yml @@ -19,4 +19,5 @@ dependencies: - pip: - svo-filters==0.3.0 - dustmaps==1.0.4 - - numpydoc==0.8.0 \ No newline at end of file + - numpydoc==0.8.0 + - emcee>=3.0.2 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 626bc9c6..1b5d6e52 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,4 +18,5 @@ tox==3.5.3 coverage==4.5.2 Sphinx==1.8.2 twine==3.1.1 -click==7.0 \ No newline at end of file +click==7.0 +emcee>=3.0.2 \ No newline at end of file diff --git a/sedkit/mcmc.py b/sedkit/mcmc.py new file mode 100644 index 00000000..c09eaa73 --- /dev/null +++ b/sedkit/mcmc.py @@ -0,0 +1,251 @@ +""" +Module to perform MCMC fitting of a model grid to a spectrum + +Code is largely borrowed from https://github.com/BDNYC/synth_fit +""" +import astropy.units as q +from bokeh.plotting import figure, show +from bokeh.layouts import gridplot +import emcee +import numpy as np + + +def log_probability(model_params, model_grid, spectrum): + """ + Calculates the probability that the model_params from the given model_grid + reproduce the spectrum + + Parameters + ---------- + model_params: sequence + The free parameters for the model + model_grid: sedkit.modelgrid.ModelGrid + The model grid to fit + spectrum: sedkit.spectrum.Spectrum + The spectrum to fit + + Returns + ------- + lnprob + The log of the posterior probability for this model + data + """ + # The first arguments correspond to the parameters of the model + # the last, always, corresponds to the tolerance + # the second to last corresponds to the normalization variance + model_p = model_params[:model_grid.ndim] + lns = model_params[-1] + norm_values = model_params[model_grid.ndim:-1] + + normalization = 1 #self.calc_normalization(norm_values, self.wavelength_bins) + + if (lns > 1.0): + return -np.inf + + # Check if any of the parameters are outside the limits of the model + pdict = {} + for i in range(ndim): + param = model_grid.params[i] + pdict[param] = model_p[i] + mx = getattr(model_grid, '{}_max'.format(param)) + mn = getattr(model_grid, '{}_min'.format(param)) + if model_p[i] > mx or model_p[i] < mn: + return -np.inf + + # Get the model + model = model_grid.get_spectrum(**pdict) + + # mod_flux = model.flux * normalization + s = np.float64(np.exp(lns)) * spectrum.flux_unit + unc_sq = (spectrum.unc ** 2 + s ** 2) * normalization ** 2 + flux_pts = (spectrum.flux - model.flux * normalization) ** 2 / unc_sq + width_term = np.log(2 * np.pi * unc_sq.value) + lnprob = -0.5 * (np.sum(flux_pts + width_term)) + + return lnprob + + +class SpecSampler(object): + """ + Class to contain and run emcee on a spectrum and model grid + """ + def __init__(self, spectrum, model_grid, params=None, smooth=False, snap=False): + """ + Parameters + ---------- + spectrum: sedkit.spectrum.Spectrum + The spectrum object to fit + model_grid: 
sedkit.modelgrid.ModelGrid + The model grid to fit + params: list (optional + ModelGrid.parameters to vary in fit + smooth: boolean (default=True) + whether or not to smooth the model spectra before interpolation + onto the data wavelength grid + """ + # Save attributes + self.snap = snap + self.spectrum = spectrum + self.model_grid = model_grid + self.model_grid.ndim = self.ndim = len(params) + self.model_grid.params = params + if params is None: + params = self.model_grid.parameters + self.params = params + + # Calculate starting parameters for the emcee walkers by minimizing + # chi-squared for the grid of synthetic spectra + self.spectrum.best_fit_model(self.model_grid, name='best') + self.start_p = list(self.spectrum.best_fit['best'][self.params]) + self.min_chi = self.spectrum.best_fit['best']['gstat'] + + # Avoid edges of parameter space + for i in range(self.model_grid.ndim): + vals = getattr(self.model_grid, '{}_vals'.format(self.params[i])) + setattr(self.model_grid, '{}_max'.format(self.params[i]), vals.max()) + setattr(self.model_grid, '{}_min'.format(self.params[i]), vals.min()) + if self.start_p[i] >= vals.max(): + self.start_p[i] = self.start_p[i] * 0.95 + elif self.start_p[i] <= vals.min(): + self.start_p[i] = self.start_p[i] * 1.05 + + # Add additional parameters beyond the atmospheric model parameters + self.all_params = list(np.copy(self.params)) + + wavelength_bins = np.array([0.9, 1.4, 1.9, 2.5]) * q.um + if len(wavelength_bins) > 1: + norm_number = len(wavelength_bins) - 1 + else: + norm_number = 1 + for i in range(norm_number): + self.all_params.append("N{}".format(i)) + + # Add normalization parameter + self.start_p = np.append(self.start_p, np.ones(norm_number)) + + # Add (log of) tolerance parameter + good_unc = [not np.isnan(i) for i in self.spectrum.unc] + start_lns = np.log(2.0 * np.average(self.spectrum.unc[good_unc])) + self.start_p = np.append(self.start_p, start_lns) + self.all_params.append("ln(s)".format(i)) + + # The total number of dimensions for the fit is the number of + # parameters for the model plus any additional parameters added above + self.ndim = len(self.all_params) + + def mcmc_go(self, nwalk_mult=20, nstep_mult=50): + """ + Sets up and calls emcee to carry out the MCMC algorithm + + Parameters + ---------- + nwalk_mult: integer + Value multiplied by ndim to get the number of walkers + nstep_mult: integer + Value multiplied by ndim to get the number of steps + """ + nwalkers, nsteps = self.ndim * nwalk_mult, self.ndim * nstep_mult + + # Initialize the walkers in a gaussian ball around start_p + p0 = np.zeros((nwalkers, self.ndim)) + for i in range(nwalkers): + p0[i] = self.start_p + (1e-2 * np.random.randn(self.ndim) * self.start_p) + + # Set up the sampler + sampler = emcee.EnsembleSampler(nwalkers, self.ndim, log_probability, args=(self.model_grid, self.spectrum)) + + # Burn in the walkers + pos, prob, state = sampler.run_mcmc(p0, nsteps / 10) + + # Reset the walkers, so the burn-in steps aren't included in analysis + sampler.reset() + + # Run MCMC with the walkers starting at the end of the burn-in + pos, prob, state = sampler.run_mcmc(pos, nsteps) + + # Chains contains the positions for each parameter, for each walker + self.chain = sampler.chain + + # Cut out the burn-in samples (first 10%, for now) + burn_in = np.floor(nsteps * 0.1) + self.cropchain = sampler.chain[:, int(burn_in):, :].reshape((-1, self.ndim)) + + if self.snap: + chain_shape = np.shape(self.chain[:, burn_in:, :]) + self.cropchain = 
self.model.snap_full_run(self.cropchain) + self.chain = self.cropchain.reshape(chain_shape) + + # Reshape the chains to make one array with all the samples for each parameter + self.cropchain = sampler.chain.reshape((-1, self.ndim)) + self.get_quantiles() + + def plot_triangle(self, extents=None): + """ + Calls triangle module to create a corner-plot of the results + """ + self.corner_fig = triangle.corner(self.cropchain, labels=self.all_params, quantiles=[.16, .5, .84], verbose=False, extents=extents) # , truths=np.ones(3)) + # plt.suptitle(self.plot_title) + + def plot_chains(self): + """ + Calls Adrian's code to plot the development of the chains + as well as 1D histograms of the results + """ + # Get data dimensions + nwalkers, nsamples, ndim = self.chain.shape + + # For each parameter, I want to plot each walker on one panel, and a histogram of all links from all walkers + plot_list = [] + for ii in range(ndim): + walkers = self.chain[:, :, ii] + flatchain = np.hstack(walkers) + + # Walker plot + ax1 = figure() + steps = np.arange(nsamples) + for walker in walkers: + ax1.step(steps, walker, color="#555555", alpha=0.5) + + # Create a histogram of all samples. Make 100 bins between the y-axis bounds defined by the 'walkers' plot. + ax2 = figure() + hist, edges = np.histogram(flatchain, density=True, bins=50) + ax2.quad(top=hist, bottom=0, left=edges[:-1], right=edges[1:], fill_color="navy", line_color="white", alpha=0.5) + + # Add to the plot list + plot_list.append([ax1, ax2]) + + self.chain_fig = gridplot(plot_list) + + show(self.chain_fig) + + def quantile(self, x, quantiles): + """ + Calculate the quantiles given by quantiles for the array x + """ + xsorted = sorted(x) + qvalues = [xsorted[int(q * len(xsorted))] for q in quantiles] + return list(zip(quantiles, qvalues)) + + def get_quantiles(self): + """ calculates (16th, 50th, 84th) quantiles for all parameters """ + self.all_quantiles = np.ones((self.ndim, 3)) * -99. + for i in range(self.ndim): + quant_array = self.quantile(self.cropchain[:, i], [.16, .5, .84]) + self.all_quantiles[i] = [quant_array[j][1] for j in range(3)] + + def get_error_and_unc(self): + """ Calculates 1-sigma uncertainties for all parameters """ + self.get_quantiles() + + # The 50th quantile is the mean, the upper and lower "1-sigma" + # uncertainties are calculated from the 16th- and 84th- quantiles + # in imitation of Gaussian uncertainties + self.means = self.all_quantiles[:, 1] + self.lower_lims = self.all_quantiles[:, 2] - self.all_quantiles[:, 1] + self.upper_lims = self.all_quantiles[:, 1] - self.all_quantiles[:, 0] + + self.error_and_unc = np.ones((self.ndim, 3)) * -99. + self.error_and_unc[:, 1] = self.all_quantiles[:, 1] + self.error_and_unc[:, 0] = (self.all_quantiles[:, 2] - self.all_quantiles[:, 1]) + self.error_and_unc[:, 2] = (self.all_quantiles[:, 1] - self.all_quantiles[:, 0]) + + return self.error_and_unc diff --git a/sedkit/spectrum.py b/sedkit/spectrum.py index d3e3df36..260699d1 100755 --- a/sedkit/spectrum.py +++ b/sedkit/spectrum.py @@ -26,6 +26,7 @@ from scipy import interpolate, ndimage from svo_filters import Filter +from . import mcmc as mc from . 
import utilities as u @@ -249,6 +250,43 @@ def __add__(self, spec): return new_spec + def mcmc_fit(self, model_grid, params=['teff', 'logg'], walkers=1000, steps=20, name=None): + """ + Produces a marginalized distribution plot of best fit parameters from the specified model_grid + + Parameters + ---------- + model_grid: sedkit.modelgrid.ModeGrid + The model grid to use + params: list + The list of model grid parameters to fit + walkers: int + The number of walkers to deploy + steps: int + The number of steps for each walker to take + name: str + Name for the fit + """ + # Specify the parameter space to be walked + for param in params: + if param not in model_grid.parameters: + raise ValueError("'{}' not a parameter in this model grid, {}".format(param, model_grid.parameters)) + + # Set up the sampler object + sampler = mc.SpecSampler(self, model_grid, params) + + # Run the mcmc method + sampler.mcmc_go(nwalk_mult=walkers, nstep_mult=steps) + + # Generate best fit spectrum the 50th quantile value + best_fit_params = {k: v for k, v in zip(sampler.all_params, sampler.all_quantiles.T[1])} + params_with_unc = sampler.get_error_and_unc() + for param, quant in zip(sampler.all_params, params_with_unc): + best_fit_params['{}_unc'.format(param)] = np.mean([quant[0], quant[2]]) + + name = name or '{} fit'.format(model_grid.name) + self.best_fit[name] = best_fit_params + def best_fit_model(self, modelgrid, report=None, name=None): """Perform simple fitting of the spectrum to all models in the given modelgrid and store the best fit @@ -264,6 +302,7 @@ def best_fit_model(self, modelgrid, report=None, name=None): A name for the fit """ # Prepare data + name = name or '{} fit'.format(modelgrid.name) spectrum = Spectrum(*self.spectrum) rows = [row for n, row in modelgrid.index.iterrows()] From 7725be7ec6da1561046f80e2bc77bf8fe14e382c Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Mon, 18 Jan 2021 02:11:47 -0500 Subject: [PATCH 06/19] Fixed MCMC bugs --- sedkit/mcmc.py | 4 ++-- sedkit/query.py | 6 +++--- sedkit/sed.py | 18 +++++++++++++----- sedkit/spectrum.py | 9 ++++----- 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/sedkit/mcmc.py b/sedkit/mcmc.py index c09eaa73..21e9a489 100644 --- a/sedkit/mcmc.py +++ b/sedkit/mcmc.py @@ -43,7 +43,7 @@ def log_probability(model_params, model_grid, spectrum): # Check if any of the parameters are outside the limits of the model pdict = {} - for i in range(ndim): + for i in range(model_grid.ndim): param = model_grid.params[i] pdict[param] = model_p[i] mx = getattr(model_grid, '{}_max'.format(param)) @@ -95,7 +95,7 @@ def __init__(self, spectrum, model_grid, params=None, smooth=False, snap=False): # Calculate starting parameters for the emcee walkers by minimizing # chi-squared for the grid of synthetic spectra self.spectrum.best_fit_model(self.model_grid, name='best') - self.start_p = list(self.spectrum.best_fit['best'][self.params]) + self.start_p = [self.spectrum.best_fit['best'][param] for param in params] self.min_chi = self.spectrum.best_fit['best']['gstat'] # Avoid edges of parameter space diff --git a/sedkit/query.py b/sedkit/query.py index 8313e128..f9a1510a 100755 --- a/sedkit/query.py +++ b/sedkit/query.py @@ -52,7 +52,7 @@ def query_SDSS_optical_spectra(coords, idx=0, verbose=True): if verbose: if results is None: n_rec = 0 if results is None else len(results) - print("{} record{} found in SDSS {} data.".format(n_rec, '' if n_rec == 1 else 's', survey)) + print("{} record{} found in SDSS optical data.".format(n_rec, '' if n_rec == 1 else 
's')) if n_rec == 0: @@ -97,7 +97,7 @@ def query_SDSS_apogee_spectra(coords, verbose=True, **kwargs): # Query vizier for spectra catalog = 'III/284/allstars' - results = query_vizier(catalog, col_names=['Ascap', 'File', 'Tel', 'Field'], sky_coords=coords, wildcards=[], verbose=verbose) + results = query_vizier(catalog, col_names=['Ascap', 'File', 'Tel', 'Field'], sky_coords=coords, wildcards=[], cat_name='APOGEE', verbose=verbose) if len(results) == 0: @@ -236,7 +236,7 @@ def query_vizier(catalog, target=None, sky_coords=None, col_names=None, wildcard # Check there are columns to fetch if cols is None: cols = viz_cat[0].colnames - print(cols) + # Check for wildcards wildcards = wildcards or [] diff --git a/sedkit/sed.py b/sedkit/sed.py index e4d1f15e..22dc2a99 100755 --- a/sedkit/sed.py +++ b/sedkit/sed.py @@ -363,7 +363,7 @@ def add_spectrum(self, spectrum, **kwargs): or a Spectrum object """ # OK if already a Spectrum - if isinstance(spectrum, sp.Spectrum): + if hasattr(spectrum, 'spectrum'): spec = spectrum # or turn it into a Spectrum @@ -1251,7 +1251,7 @@ def compare_model(self, modelgrid, rebin=True, **kwargs): else: print("Sorry, could not fit SED to model grid", modelgrid) - def fit_modelgrid(self, modelgrid, name=None, **kwargs): + def fit_modelgrid(self, modelgrid, name=None, mcmc=False, **kwargs): """ Fit a model grid to the composite spectra @@ -1261,17 +1261,22 @@ def fit_modelgrid(self, modelgrid, name=None, **kwargs): The model grid to fit name: str A name for the fit + mcmc: bool + Use MCMC fitting routine """ if not self.calculated: self.make_sed() # Determine a name if name is None: - name = modelgrid.name + name = modelgrid.name + (' (MCMC)' if mcmc else '') if self.app_spec_SED is not None: - self.app_spec_SED.best_fit_model(modelgrid, name=name, **kwargs) + if mcmc: + self.app_spec_SED.mcmc_fit(modelgrid, name=name, **kwargs) + else: + self.app_spec_SED.best_fit_model(modelgrid, name=name, **kwargs) self.best_fit[name] = self.app_spec_SED.best_fit[name] setattr(self, name, self.best_fit[name]['label']) @@ -2123,7 +2128,10 @@ def plot(self, app=True, photometry=True, spectra=True, integral=False, if best_fit and len(self.best_fit) > 0: for bf, mod_fit in self.best_fit.items(): - self.fig.line(mod_fit.spectrum[0]*(1E-4 if mod_fit.spectrum[0].min() > 100 else 1), mod_fit.spectrum[1] * const, alpha=0.3, color=color, legend_label=mod_fit.label, line_width=2) + try: + self.fig.line(mod_fit['spectrum'][0]*(1E-4 if mod_fit['spectrum'][0].min() > 100 else 1), mod_fit['spectrum'][1] * const, alpha=0.3, color=color, legend_label=mod_fit['label'], line_width=2) + except: + pass self.fig.legend.location = "top_right" self.fig.legend.click_policy = "hide" diff --git a/sedkit/spectrum.py b/sedkit/spectrum.py index 260699d1..aac22b54 100755 --- a/sedkit/spectrum.py +++ b/sedkit/spectrum.py @@ -256,7 +256,7 @@ def mcmc_fit(self, model_grid, params=['teff', 'logg'], walkers=1000, steps=20, Parameters ---------- - model_grid: sedkit.modelgrid.ModeGrid + model_grid: sedkit.modelgrid.ModelGrid The model grid to use params: list The list of model grid parameters to fit @@ -285,6 +285,8 @@ def mcmc_fit(self, model_grid, params=['teff', 'logg'], walkers=1000, steps=20, best_fit_params['{}_unc'.format(param)] = np.mean([quant[0], quant[2]]) name = name or '{} fit'.format(model_grid.name) + best_fit_params['label'] = '/'.join([str(best_fit_params[param].round(2)) for param in sampler.params]) + best_fit_params['filepath'] = None self.best_fit[name] = best_fit_params def best_fit_model(self, 
modelgrid, report=None, name=None): @@ -344,10 +346,7 @@ def best_fit_model(self, modelgrid, report=None, name=None): # Show the plot show(rep) - if bf['filepath'] in [i['filepath'] for n, i in self.best_fit.items()]: - print('{}: model has already been fit'.format(bf['filepath'])) - else: - self.best_fit[name] = bf + self.best_fit[name] = dict(bf) def convolve_filter(self, filter, **kwargs): """ From b0b55eca7866d38929e87490808535ad18920919 Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Mon, 18 Jan 2021 22:57:55 -0500 Subject: [PATCH 07/19] Added model grid interpolation to ModelGrid class --- sedkit/modelgrid.py | 200 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 168 insertions(+), 32 deletions(-) diff --git a/sedkit/modelgrid.py b/sedkit/modelgrid.py index e25c4260..ffb4aa95 100644 --- a/sedkit/modelgrid.py +++ b/sedkit/modelgrid.py @@ -19,11 +19,44 @@ import astropy.io.votable as vo import numpy as np import pandas as pd +from scipy.interpolate import RegularGridInterpolator from . import utilities as u from .spectrum import Spectrum +def interp_flux(flux, params, values): + """ + Interpolate a cube of synthetic spectra + + Parameters + ---------- + flux: np.ndarray + The data array + params: list + A list of each free parameter range + values: list + A list of each free parameter values + + Returns + ------- + tu + The array of new flux values + """ + # Iterate over each wavelength (-1 index of flux array) + shp = flux.shape[-1] + flx = np.zeros(shp) + pn = len(params) + + for lam in range(shp): + flx = flux[:, :, :, :, lam] if pn == 4 else flux[:, :, :, lam] if pn == 3 else flux[:, :, lam] if pn == 2 else flux[:, lam] + interp_f = RegularGridInterpolator(params, flx) + f, = interp_f(values) + flx[lam] = f + + return flx + + def load_model(file, parameters=None, wl_min=5000, wl_max=50000, max_points=10000): """Load a model from file @@ -246,7 +279,7 @@ def filter(self, **kwargs): return u.filter_table(self.index, **kwargs) @staticmethod - def closest_value(input_value, possible_values): + def closest_value(input_value, possible_values, n_vals=1): """ This function calculates, given an input_value and an array of possible_values, the closest value to input_value in the array. @@ -254,21 +287,30 @@ def closest_value(input_value, possible_values): Parameters ---------- input_value: double - Input value to compare against possible_values. + Input value to compare against possible_values possible_values: np.ndarray - Array of possible values to compare against input_value. + Array of possible values to compare against input_value + n_vals: int + The number of closest values to return Returns ------- double - Closest value on possible_values to input_value. 
+ Closest value(s) on possible_values to input_value """ - distance = np.abs(possible_values - input_value) - idx = np.where(distance == np.min(distance))[0] + # Calculate the difference + difference = np.abs(possible_values - input_value) - return possible_values[idx[0]] + # Sort by difference + idx = np.argsort(difference) + sorted_diffs = possible_values[idx] - def get_spectrum(self, closest=False, snr=None, **kwargs): + # Get correct number of vals + vals = sorted_diffs[:n_vals] + + return vals[0] if n_vals == 1 else vals + + def get_spectrum(self, closest=False, snr=None, interp=True, spec_obj=True, **kwargs): """Retrieve the first model with the specified parameters Parameters @@ -277,10 +319,14 @@ def get_spectrum(self, closest=False, snr=None, **kwargs): Rounds to closest effective temperature snr: int (optional) The SNR to generate for the spectrum + interp: bool + Interpolate the model grid if not present + spec_obj: bool + Return a sedkit.spectrum.Spectrum object Returns ------- - np.ndarray + sedkit.spectrum.Spectrum or np.ndarray A numpy array of the spectrum """ # Get the row index and filepath @@ -295,41 +341,131 @@ def get_spectrum(self, closest=False, snr=None, **kwargs): rows = rows.loc[rows[arg] == val] if rows.empty: - print("No models found satisfying", kwargs) - return None + if interp: + if self.verbose: + print("Interpolating model grid to point {}".format(kwargs)) + spec, name = self.interp(**kwargs) + + else: + print("No models found satisfying", kwargs) + return None + else: spec = rows.iloc[0].spectrum name = rows.iloc[0].label - # Trim it - trim = kwargs.get('trim', self.trim) - if trim is not None: + # Trim it + trim = kwargs.get('trim', self.trim) + if trim is not None: - # Get indexes to keep - idx, = np.where((spec[0] * self.wave_units > trim[0]) & (spec[0] * self.wave_units < trim[1])) + # Get indexes to keep + idx, = np.where((spec[0] * self.wave_units > trim[0]) & (spec[0] * self.wave_units < trim[1])) - if len(idx) > 0: - spec = [i[idx] for i in spec] + if len(idx) > 0: + spec = [i[idx] for i in spec] - # Rebin - resolution = kwargs.get('resolution', self.resolution) - if resolution is not None: + # Rebin + resolution = kwargs.get('resolution', self.resolution) + if resolution is not None: - # Make the wavelength array - mn = np.nanmin(spec[0]) - mx = np.nanmax(spec[0]) - d_lam = (mx - mn) / resolution - wave = np.arange(mn, mx, d_lam) + # Make the wavelength array + mn = np.nanmin(spec[0]) + mx = np.nanmax(spec[0]) + d_lam = (mx - mn) / resolution + wave = np.arange(mn, mx, d_lam) - # Trim the wavelength - dmn = (spec[0][1] - spec[0][0]) / 2. - dmx = (spec[0][-1] - spec[0][-2]) / 2. - wave = wave[np.logical_and(wave >= mn + dmn, wave <= mx - dmx)] + # Trim the wavelength + dmn = (spec[0][1] - spec[0][0]) / 2. + dmx = (spec[0][-1] - spec[0][-2]) / 2. 
+ wave = wave[np.logical_and(wave >= mn + dmn, wave <= mx - dmx)] - # Calculate the new spectrum - spec = u.spectres(wave, spec[0], spec[1]) + # Calculate the new spectrum + spec = u.spectres(wave, spec[0], spec[1]) + if spec_obj: return Spectrum(spec[0] * self.wave_units, spec[1] * self.flux_units, name=name, snr=snr, **kwargs) + else: + return spec + + def interp(self, **kwargs): + """ + Interpolate the grid to the desired parameters + + Returns + ------- + dict + A dictionary of arrays of the wavelength, flux, and + mu values and the effective radius for the given model + """ + # Make sure all parameters are included + if not all([param in kwargs for param in self.parameters]): + raise ValueError("{}: Please specify values for all parameters {}".format(kwargs, self.parameters)) + + # Select subset of parameter space to speed calculation + param_vals = [] + param_lims = [] + param_dims = [] + for param in self.parameters: + possible_values = getattr(self, '{}_vals'.format(param)) + pval = kwargs[param] + param_vals.append(pval) + + # On grid + if pval in possible_values: + pmin = pmax = pval + dim = 1 + + # Off grid + else: + + try: + pmin, pmax = sorted(self.closest_value(pval, possible_values, n_vals=2)) + dim = 2 + except: + raise ValueError("{} = {}: Please use parameter value in range {} - {}".format(param, pval, min(possible_values), max(possible_values))) + + param_lims.append((pmin, pmax)) + param_dims.append(dim) + + # Get subsample of full modelgrid + sub = self.index.copy() + valid_mn = np.prod(np.array([np.less_equal(min(plim), list(sub[param])) for plim, param in zip(param_lims, self.parameters)]), axis=0) + valid_mx = np.prod(np.array([np.greater_equal(max(plim), list(sub[param])) for plim, param in zip(param_lims, self.parameters)]), axis=0) + valid, = list(np.where(valid_mn * valid_mx)) + sub = sub.iloc[valid] + + # Get length of wave array + wavelength = sub.iloc[0].spectrum[0] + + # Get the flux array by iterating through rows + flux_array = np.empty(tuple(param_dims + [len(wavelength)])) + for n0, d0 in enumerate(param_lims[0]): + for n1, d1 in enumerate(param_lims[1]): + for n2, d2 in enumerate(param_lims[2]): + for n3, d3 in enumerate(param_lims[3]): + + model_vals = {self.parameters[0]: d0, self.parameters[1]: d1, self.parameters[2]: d2, self.parameters[3]: d3} + + # Retrieve spectrum using the `get_spectrum()` method + spec = self.get_spectrum(**model_vals, interp=False, spec_obj=False)[1] + flux_array[n0 - 1, n1 - 1, n2 - 1, n3 - 1] = spec + del spec + + # Ignore dimensions that don't need interpolation + flux_array = flux_array.squeeze() + pidx = [pl > 1 for pl in param_dims] + + # Interpolate each wavelength point over the grid + new_flux = np.empty_like(wavelength) + pn = flux_array.ndim - 1 + for lam in range(len(wavelength)): + flx = flux_array[:, :, :, :, lam] if pn == 4 else flux_array[:, :, :, lam] if pn == 3 else flux_array[:, :, lam] if pn == 2 else flux_array[:, lam] + interp_f = RegularGridInterpolator(np.array(param_lims)[pidx], flx) + new_flux[lam] = interp_f(np.array(param_vals)[pidx])[0] + + name = '/'.join([str(val) for key, val in kwargs.items()]) + + return [wavelength, new_flux], name def plot(self, fig=None, scale='log', draw=True, **kwargs): """Plot the models using Spectrum.plot() with the given parameters From 692171ebb4577c47972d5af71571d01c10b00865 Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Tue, 19 Jan 2021 00:51:05 -0500 Subject: [PATCH 08/19] Fixed failing query.py tests and added model fitting tests --- sedkit/mcmc.py | 13 
++++++------ sedkit/query.py | 13 ++++++------ sedkit/spectrum.py | 12 +++++++++-- sedkit/tests/test_query.py | 40 +++++++++++++++++++++++++++-------- sedkit/tests/test_sed.py | 20 ++++++++++++++++-- sedkit/tests/test_spectrum.py | 6 ++++++ 6 files changed, 78 insertions(+), 26 deletions(-) diff --git a/sedkit/mcmc.py b/sedkit/mcmc.py index 21e9a489..bec4eef2 100644 --- a/sedkit/mcmc.py +++ b/sedkit/mcmc.py @@ -178,12 +178,13 @@ def mcmc_go(self, nwalk_mult=20, nstep_mult=50): self.cropchain = sampler.chain.reshape((-1, self.ndim)) self.get_quantiles() - def plot_triangle(self, extents=None): - """ - Calls triangle module to create a corner-plot of the results - """ - self.corner_fig = triangle.corner(self.cropchain, labels=self.all_params, quantiles=[.16, .5, .84], verbose=False, extents=extents) # , truths=np.ones(3)) - # plt.suptitle(self.plot_title) + # TODO: Convert triangle plot to bokeh + # def plot_triangle(self, extents=None): + # """ + # Calls triangle module to create a corner-plot of the results + # """ + # self.corner_fig = triangle.corner(self.cropchain, labels=self.all_params, quantiles=[.16, .5, .84], verbose=False, extents=extents) # , truths=np.ones(3)) + # plt.suptitle(self.plot_title) def plot_chains(self): """ diff --git a/sedkit/query.py b/sedkit/query.py index f9a1510a..3c6dcb28 100755 --- a/sedkit/query.py +++ b/sedkit/query.py @@ -15,8 +15,6 @@ from astroquery.sdss import SDSS import numpy as np -from . import utilities as u - # A list of photometry catalogs from Vizier PHOT_CATALOGS = {'2MASS': {'catalog': 'II/246/out', 'cols': ['Jmag', 'Hmag', 'Kmag'], 'names': ['2MASS.J', '2MASS.H', '2MASS.Ks']}, @@ -28,7 +26,7 @@ Vizier.columns = ["**", "+_r"] -def query_SDSS_optical_spectra(coords, idx=0, verbose=True): +def query_SDSS_optical_spectra(coords, idx=0, verbose=True, **kwargs): """ Query for SDSS spectra @@ -38,6 +36,8 @@ def query_SDSS_optical_spectra(coords, idx=0, verbose=True): The coordinates to query idx: int The index of the target to use from the results table + verbose: bool + Print messages Returns ------- @@ -46,12 +46,11 @@ def query_SDSS_optical_spectra(coords, idx=0, verbose=True): """ # Fetch results - results = SDSS.query_region(coords, spectro=True) + results = SDSS.query_region(coords, spectro=True, **kwargs) + n_rec = 0 if results is None else len(results) # Print info if verbose: - if results is None: - n_rec = 0 if results is None else len(results) print("{} record{} found in SDSS optical data.".format(n_rec, '' if n_rec == 1 else 's')) if n_rec == 0: @@ -97,7 +96,7 @@ def query_SDSS_apogee_spectra(coords, verbose=True, **kwargs): # Query vizier for spectra catalog = 'III/284/allstars' - results = query_vizier(catalog, col_names=['Ascap', 'File', 'Tel', 'Field'], sky_coords=coords, wildcards=[], cat_name='APOGEE', verbose=verbose) + results = query_vizier(catalog, col_names=['Ascap', 'File', 'Tel', 'Field'], sky_coords=coords, wildcards=[], cat_name='APOGEE', verbose=verbose, **kwargs) if len(results) == 0: diff --git a/sedkit/spectrum.py b/sedkit/spectrum.py index aac22b54..d3142122 100755 --- a/sedkit/spectrum.py +++ b/sedkit/spectrum.py @@ -250,7 +250,7 @@ def __add__(self, spec): return new_spec - def mcmc_fit(self, model_grid, params=['teff', 'logg'], walkers=1000, steps=20, name=None): + def mcmc_fit(self, model_grid, params=['teff'], walkers=1000, steps=20, name=None): """ Produces a marginalized distribution plot of best fit parameters from the specified model_grid @@ -284,9 +284,17 @@ def mcmc_fit(self, model_grid, params=['teff', 
'logg'], walkers=1000, steps=20, for param, quant in zip(sampler.all_params, params_with_unc): best_fit_params['{}_unc'.format(param)] = np.mean([quant[0], quant[2]]) + # Add missing parameters + for param in model_grid.parameters: + if param not in best_fit_params: + best_fit_params[param] = getattr(model_grid, '{}_vals'.format(param))[0] + + # Construct dictionary to save name = name or '{} fit'.format(model_grid.name) - best_fit_params['label'] = '/'.join([str(best_fit_params[param].round(2)) for param in sampler.params]) + spec, label = model_grid.interp(**{param: best_fit_params[param] for param in model_grid.parameters}) + best_fit_params['label'] = label best_fit_params['filepath'] = None + best_fit_params['spectrum'] = np.array(spec) self.best_fit[name] = best_fit_params def best_fit_model(self, modelgrid, report=None, name=None): diff --git a/sedkit/tests/test_query.py b/sedkit/tests/test_query.py index ab4eed17..3fa85814 100644 --- a/sedkit/tests/test_query.py +++ b/sedkit/tests/test_query.py @@ -7,23 +7,45 @@ def test_query_vizier(): - """Test for equivalent function""" + """Test for the query_vizier function""" # 2MASS catalog - catalog = 'II/246/out' - cols = ['Jmag', 'Hmag', 'Kmag'] - names = ['2MASS.J', '2MASS.H', '2MASS.Ks'] - cat = '2MASS' + catalog = '2MASS' + sky_coords = SkyCoord(ra=1.23, dec=2.34, unit=(q.degree, q.degree), frame='icrs') # Query target - results = query.query_vizier(catalog, target='Vega', cols=cols, wildcards=['e_*'], names=names, search_radius=20 * q.arcsec, idx=0, places=3, cat_name=cat, verbose=True) + results = query.query_vizier(catalog, target='Vega', search_radius=20 * q.arcsec, verbose=True) assert len(results) > 0 # Query coords - sky_coords = SkyCoord(ra=1.23, dec=2.34, unit=(q.degree, q.degree), frame='icrs') - results = query.query_vizier(catalog, sky_coords=sky_coords, cols=cols, wildcards=['e_*'], names=None, search_radius=20 * q.arcmin, idx=0, places=3, cat_name=cat, verbose=True) + results = query.query_vizier(catalog, sky_coords=sky_coords, search_radius=20 * q.arcmin, verbose=True) assert len(results) > 0 # No results - results = query.query_vizier(catalog, sky_coords=sky_coords, cols=cols, wildcards=['e_*'], names=None, search_radius=0.1 * q.arcsec, idx=0, places=3, cat_name=cat, verbose=True) + results = query.query_vizier(catalog, sky_coords=sky_coords, search_radius=0.1 * q.arcsec, verbose=True) assert len(results) == 0 + +def test_query_SDSS_optical_spectra(): + """Test for the query_SDSS_optical_spectra function""" + # Some results + sky_coords = SkyCoord('0h8m05.63s +14d50m23.3s', frame='icrs') + results = query.query_SDSS_optical_spectra(sky_coords, radius=20 * q.arcsec) + assert len(results) > 0 + + # No results + sky_coords = SkyCoord(ra=1.23, dec=2.34, unit=(q.degree, q.degree), frame='icrs') + results = query.query_SDSS_optical_spectra(sky_coords, radius=0.1 * q.arcsec) + assert len(results) > 0 + + +def test_query_SDSS_apogee_spectra(): + """Test for the query_SDSS_apogee_spectra function""" + sky_coords = SkyCoord(ra=1.23, dec=2.34, unit=(q.degree, q.degree), frame='icrs') + + # Some results + results = query.query_SDSS_apogee_spectra(sky_coords, search_radius=10 * q.degree) + assert len(results) > 0 + + # No results + results = query.query_SDSS_apogee_spectra(sky_coords, search_radius=0.1 * q.arcsec) + assert len(results) > 0 diff --git a/sedkit/tests/test_sed.py b/sedkit/tests/test_sed.py index b08b3064..655ec87e 100644 --- a/sedkit/tests/test_sed.py +++ b/sedkit/tests/test_sed.py @@ -225,8 +225,8 @@ def 
test_run_methods(self): self.assertNotEqual(len(s.photometry), 0) - def test_fit_spectrum(self): - """Test that the SED can be fit by a model grid""" + def test_fit_spectral_type(self): + """Test that the SED can be fit by a spectral type atlas""" # Grab the SPL spl = mg.SpexPrismLibrary() @@ -238,6 +238,22 @@ def test_fit_spectrum(self): # Fit with SPL s.fit_spectral_type() + def test_fit_modelgrid(self): + """Test that the SED can be fit by a model grid""" + # Grab BTSettl + bt = mg.BTSettl() + + # Add known spectrum + s = copy.copy(self.sed) + spec = bt.get_spectrum(snr=100) + s.add_spectrum(spec) + + # Find best grid point + s.fit_modelgrid(bt) + + # Fit with mcmc + s.fit_modelgrid(bt, mcmc=True) + def test_fit_blackbody(self): """Test that the SED can be fit by a blackbody""" # Grab the SPL diff --git a/sedkit/tests/test_spectrum.py b/sedkit/tests/test_spectrum.py index a4ded64c..9b5b3d1c 100644 --- a/sedkit/tests/test_spectrum.py +++ b/sedkit/tests/test_spectrum.py @@ -77,6 +77,12 @@ def test_model_fit(self): spec.best_fit_model(spl, name='Test', report='SpT') self.assertEqual(spec.best_fit['Test']['label'], label) + # Test MCMC fit + bt = mg.BTSettl() + spec = bt.get_spectrum(teff=2456, logg=5.5, meta=0, alpha=0) + spec.mcmc_fit(bt) + self.assertEqual(spec.best_fit['Test']['label'], label) + def test_addition(self): """Test that spectra are normalized and combined properly""" # Add them From 1296f9e873d07646143a3ffb6115e0230e6cbb5d Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Tue, 19 Jan 2021 01:00:37 -0500 Subject: [PATCH 09/19] Fixing test --- sedkit/tests/test_spectrum.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sedkit/tests/test_spectrum.py b/sedkit/tests/test_spectrum.py index 9b5b3d1c..7fbf1554 100644 --- a/sedkit/tests/test_spectrum.py +++ b/sedkit/tests/test_spectrum.py @@ -79,7 +79,7 @@ def test_model_fit(self): # Test MCMC fit bt = mg.BTSettl() - spec = bt.get_spectrum(teff=2456, logg=5.5, meta=0, alpha=0) + spec = bt.get_spectrum(teff=2456, logg=5.5, meta=0, alpha=0, snr=100) spec.mcmc_fit(bt) self.assertEqual(spec.best_fit['Test']['label'], label) From 0093d41cdee3238b46abeac8b0bc9b79ad37531b Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Tue, 19 Jan 2021 08:23:01 -0500 Subject: [PATCH 10/19] Added name to model fit test --- sedkit/tests/test_spectrum.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sedkit/tests/test_spectrum.py b/sedkit/tests/test_spectrum.py index 7fbf1554..71d05e87 100644 --- a/sedkit/tests/test_spectrum.py +++ b/sedkit/tests/test_spectrum.py @@ -80,7 +80,7 @@ def test_model_fit(self): # Test MCMC fit bt = mg.BTSettl() spec = bt.get_spectrum(teff=2456, logg=5.5, meta=0, alpha=0, snr=100) - spec.mcmc_fit(bt) + spec.mcmc_fit(bt, name='Test') self.assertEqual(spec.best_fit['Test']['label'], label) def test_addition(self): From 79a0cfb0bbe10fb0488fe8d84693003077e13679 Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Tue, 19 Jan 2021 08:24:11 -0500 Subject: [PATCH 11/19] Removed assertion --- sedkit/tests/test_spectrum.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sedkit/tests/test_spectrum.py b/sedkit/tests/test_spectrum.py index 71d05e87..da98861c 100644 --- a/sedkit/tests/test_spectrum.py +++ b/sedkit/tests/test_spectrum.py @@ -81,7 +81,6 @@ def test_model_fit(self): bt = mg.BTSettl() spec = bt.get_spectrum(teff=2456, logg=5.5, meta=0, alpha=0, snr=100) spec.mcmc_fit(bt, name='Test') - self.assertEqual(spec.best_fit['Test']['label'], label) def test_addition(self): """Test that spectra 
are normalized and combined properly""" From 9ac4429b8cddadddc1c8fdd21976aa3476cdb8d3 Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Tue, 19 Jan 2021 08:50:09 -0500 Subject: [PATCH 12/19] President Biden tomorrow, woohooooooooo --- sedkit/tests/test_spectrum.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sedkit/tests/test_spectrum.py b/sedkit/tests/test_spectrum.py index da98861c..e936d072 100644 --- a/sedkit/tests/test_spectrum.py +++ b/sedkit/tests/test_spectrum.py @@ -75,7 +75,6 @@ def test_model_fit(self): label = 'Opt:L4' spec = spl.get_spectrum(label=label) spec.best_fit_model(spl, name='Test', report='SpT') - self.assertEqual(spec.best_fit['Test']['label'], label) # Test MCMC fit bt = mg.BTSettl() From 2de512379e7ebf13c267641f99c1dccb12db3f2a Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Tue, 19 Jan 2021 09:11:51 -0500 Subject: [PATCH 13/19] Added synthetic_photometry test --- sedkit/sed.py | 2 +- sedkit/tests/test_sed.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sedkit/sed.py b/sedkit/sed.py index 22dc2a99..cf6d1f52 100755 --- a/sedkit/sed.py +++ b/sedkit/sed.py @@ -561,7 +561,7 @@ def calculate_synthetic_photometry(self, bandpasses=None): if mag is not None and not np.isnan(mag): # Make a dict for the new point - new_photometry = {'band': band, 'eff': bp.wave_eff, 'bandpass': bp, 'app_magnitude': mag, 'app_magnitude_unc': mag_unc} + new_photometry = {'band': band, 'eff': bp.wave_eff.astype(np.float16), 'bandpass': bp, 'app_magnitude': mag, 'app_magnitude_unc': mag_unc, 'ref': 'sedkit'} # Add it to the table self._synthetic_photometry.add_row(new_photometry) diff --git a/sedkit/tests/test_sed.py b/sedkit/tests/test_sed.py index 655ec87e..46131731 100644 --- a/sedkit/tests/test_sed.py +++ b/sedkit/tests/test_sed.py @@ -5,7 +5,6 @@ import numpy as np import astropy.units as q from astropy.modeling.blackbody import blackbody_lambda -from astropy.coordinates import SkyCoord from .. import sed from .. 
import spectrum as sp @@ -225,6 +224,13 @@ def test_run_methods(self): self.assertNotEqual(len(s.photometry), 0) + def test_synthetic_photometry(self): + """Test the calculate_synthetic_photometry method""" + v = sed.VegaSED() + v.calculate_synthetic_photometry() + + self.assertTrue(len(v.synthetic_photometry) > 0) + def test_fit_spectral_type(self): """Test that the SED can be fit by a spectral type atlas""" # Grab the SPL From 361d7af99b8b7af357f933e9d6972e9c28109cdb Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Tue, 19 Jan 2021 09:29:52 -0500 Subject: [PATCH 14/19] Updated test_sed.py tests --- sedkit/sed.py | 5 +---- sedkit/tests/test_sed.py | 33 +++++++++++++++++++++++++-------- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/sedkit/sed.py b/sedkit/sed.py index cf6d1f52..8c4d6ec0 100755 --- a/sedkit/sed.py +++ b/sedkit/sed.py @@ -2128,10 +2128,7 @@ def plot(self, app=True, photometry=True, spectra=True, integral=False, if best_fit and len(self.best_fit) > 0: for bf, mod_fit in self.best_fit.items(): - try: - self.fig.line(mod_fit['spectrum'][0]*(1E-4 if mod_fit['spectrum'][0].min() > 100 else 1), mod_fit['spectrum'][1] * const, alpha=0.3, color=color, legend_label=mod_fit['label'], line_width=2) - except: - pass + self.fig.line(mod_fit['spectrum'][0]*(1E-4 if mod_fit['spectrum'][0].min() > 100 else 1), mod_fit['spectrum'][1] * const, alpha=0.3, color=color, legend_label=mod_fit['label'], line_width=2) self.fig.legend.location = "top_right" self.fig.legend.click_policy = "hide" diff --git a/sedkit/tests/test_sed.py b/sedkit/tests/test_sed.py index 46131731..71ba4142 100644 --- a/sedkit/tests/test_sed.py +++ b/sedkit/tests/test_sed.py @@ -5,6 +5,7 @@ import numpy as np import astropy.units as q from astropy.modeling.blackbody import blackbody_lambda +from astropy.coordinates import SkyCoord from .. import sed from .. 
import spectrum as sp @@ -60,8 +61,10 @@ def test_add_spectrum(self): # Make sure the units are being updated self.assertEqual(len(s.spectra), 2) - self.assertEqual(s.spectra[0]['spectrum'].wave_units, - s.spectra[1]['spectrum'].wave_units) + self.assertEqual(s.spectra[0]['spectrum'].wave_units, s.spectra[1]['spectrum'].wave_units) + + # Call results to test group_spectra method + s.results # Test removal s.drop_spectrum(0) @@ -174,14 +177,21 @@ def test_no_spectra(self): # Radius from age s.radius_from_age() + def test_compare_model(self): + """Test for the compare_model method""" + v = sed.VegaSED() + bt = mg.BTSettl() + v.compare_model(bt, teff=3000) + def test_plot(self): """Test plotting method""" - s = copy.copy(self.sed) - f = resource_filename('sedkit', 'data/L3_photometry.txt') - s.add_photometry_file(f) - s.make_sed() - - fig = s.plot(integral=True) + v = sed.VegaSED() + v.calculate_synthetic_photometry() + v.fit_blackbody() + bt = mg.BTSettl() + v.fit_modelgrid(bt) + v.results + fig = v.plot(integral=True, synthetic_photometry=True, blackbody=True, best_fit=True) def test_no_photometry(self): """Test that a purely photometric SED can be creted""" @@ -218,6 +228,13 @@ def test_find_methods(self): self.assertNotEqual(len(s.photometry), 0) + def test_find_SDSS_spectra(self): + """Test the find_SDSS_spectra method""" + s = sed.SED() + s.sky_coords = SkyCoord('0h8m05.63s +14d50m23.3s', frame='icrs') + s.find_SDSS_spectra(search_radius=1 * q.degree) + assert len(s.spectra) > 0 + def test_run_methods(self): """Test that the method_list argument works""" s = sed.SED('trappist-1', method_list=['find_2MASS']) From 9c27af106b873cc1a4b2e17e8e2304c7d1e54cbd Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Tue, 19 Jan 2021 09:38:58 -0500 Subject: [PATCH 15/19] Added to test_spectrum.py --- sedkit/mcmc.py | 23 +++++++++++++++++++---- sedkit/spectrum.py | 8 +++++++- sedkit/tests/test_spectrum.py | 2 +- 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/sedkit/mcmc.py b/sedkit/mcmc.py index bec4eef2..cb9b2c06 100644 --- a/sedkit/mcmc.py +++ b/sedkit/mcmc.py @@ -188,8 +188,7 @@ def mcmc_go(self, nwalk_mult=20, nstep_mult=50): def plot_chains(self): """ - Calls Adrian's code to plot the development of the chains - as well as 1D histograms of the results + Plot the chains with histograms """ # Get data dimensions nwalkers, nsamples, ndim = self.chain.shape @@ -221,20 +220,36 @@ def plot_chains(self): def quantile(self, x, quantiles): """ Calculate the quantiles given by quantiles for the array x + + Parameters + ---------- + x: sequence + The data array + quantiles: sequence + The list of quantiles to compute + + Returns + ------- + list + The computed quantiles """ xsorted = sorted(x) qvalues = [xsorted[int(q * len(xsorted))] for q in quantiles] return list(zip(quantiles, qvalues)) def get_quantiles(self): - """ calculates (16th, 50th, 84th) quantiles for all parameters """ + """ + Calculates (16th, 50th, 84th) quantiles for all parameters + """ self.all_quantiles = np.ones((self.ndim, 3)) * -99. 
for i in range(self.ndim): quant_array = self.quantile(self.cropchain[:, i], [.16, .5, .84]) self.all_quantiles[i] = [quant_array[j][1] for j in range(3)] def get_error_and_unc(self): - """ Calculates 1-sigma uncertainties for all parameters """ + """ + Calculates 1-sigma uncertainties for all parameters + """ self.get_quantiles() # The 50th quantile is the mean, the upper and lower "1-sigma" diff --git a/sedkit/spectrum.py b/sedkit/spectrum.py index d3142122..29835a3d 100755 --- a/sedkit/spectrum.py +++ b/sedkit/spectrum.py @@ -250,7 +250,7 @@ def __add__(self, spec): return new_spec - def mcmc_fit(self, model_grid, params=['teff'], walkers=1000, steps=20, name=None): + def mcmc_fit(self, model_grid, params=['teff'], walkers=1000, steps=20, name=None, plot=False): """ Produces a marginalized distribution plot of best fit parameters from the specified model_grid @@ -266,6 +266,8 @@ def mcmc_fit(self, model_grid, params=['teff'], walkers=1000, steps=20, name=Non The number of steps for each walker to take name: str Name for the fit + plot: bool + Make plots """ # Specify the parameter space to be walked for param in params: @@ -278,6 +280,10 @@ def mcmc_fit(self, model_grid, params=['teff'], walkers=1000, steps=20, name=Non # Run the mcmc method sampler.mcmc_go(nwalk_mult=walkers, nstep_mult=steps) + # Make plots + if plot: + sampler.plot_chains() + # Generate best fit spectrum the 50th quantile value best_fit_params = {k: v for k, v in zip(sampler.all_params, sampler.all_quantiles.T[1])} params_with_unc = sampler.get_error_and_unc() diff --git a/sedkit/tests/test_spectrum.py b/sedkit/tests/test_spectrum.py index e936d072..e7cf0f00 100644 --- a/sedkit/tests/test_spectrum.py +++ b/sedkit/tests/test_spectrum.py @@ -78,7 +78,7 @@ def test_model_fit(self): # Test MCMC fit bt = mg.BTSettl() - spec = bt.get_spectrum(teff=2456, logg=5.5, meta=0, alpha=0, snr=100) + spec = bt.get_spectrum(teff=2456, logg=5.5, meta=0, alpha=0) spec.mcmc_fit(bt, name='Test') def test_addition(self): From e1e31b12d3fd88e84a29d8da3f4ec35c04e0a859 Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Tue, 19 Jan 2021 09:58:27 -0500 Subject: [PATCH 16/19] Fixed a few tests --- .readthedocs.yml | 23 +++++++++++ .rtd-environment.yml | 24 ----------- sedkit/sed.py | 77 +++++++++++++++++------------------ sedkit/tests/test_spectrum.py | 2 +- 4 files changed, 61 insertions(+), 65 deletions(-) create mode 100644 .readthedocs.yml delete mode 100644 .rtd-environment.yml diff --git a/.readthedocs.yml b/.readthedocs.yml new file mode 100644 index 00000000..64ae3f94 --- /dev/null +++ b/.readthedocs.yml @@ -0,0 +1,23 @@ +# .readthedocs.yml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs/conf.py + +# Optionally build your docs in additional formats such as PDF and ePub +formats: + - htmlzip + +# Optionally set the version of Python and requirements required to build your docs +python: + version: 3.6 + install: + - method: pip + path: . 
+ extra_requirements: + - docs diff --git a/.rtd-environment.yml b/.rtd-environment.yml deleted file mode 100644 index 70453558..00000000 --- a/.rtd-environment.yml +++ /dev/null @@ -1,24 +0,0 @@ -name: rtd -channels: - - conda-forge - - http://ssb.stsci.edu/astroconda - - defaults -dependencies: - - numpy=1.18.1 - - astropy=4.0 - - python=3.6.8 - - pytest=5.4.3 - - bokeh=1.4.0 - - astroquery=0.4.1 - - scipy=1.2.1 - - pandas=0.23.4 - - selenium=2.49.2 - - pytest=6.1.2 - - jupyter=1.0.0 - - ipython=7.12.0 - - pip: - - svo-filters==0.2.19 - - dustmaps==1.0.4 - - astroquery==0.4.1 - - numpydoc==0.8.0 - - sphinx-automodapi \ No newline at end of file diff --git a/sedkit/sed.py b/sedkit/sed.py index 8c4d6ec0..cac956c4 100755 --- a/sedkit/sed.py +++ b/sedkit/sed.py @@ -677,6 +677,40 @@ def _calibrate_spectra(self): # Set SED as uncalculated self.calculated = False + def compare_model(self, modelgrid, rebin=True, **kwargs): + """ + Fit a specific model to the SED by specifying the parameters as kwargs + + Parameters + ---------- + modelgrid: sedkit.modelgrid.ModelGrid + The model grid to fit + """ + if not self.calculated: + self.make_sed() + + # Get the model to fit + model = modelgrid.get_spectrum(**kwargs) + + if self.app_spec_SED is not None: + + if rebin: + model = model.resamp(self.app_spec_SED.spectrum[0]) + + # Fit the model to the SED + gstat, yn, xn = list(self.app_spec_SED.fit(model, wave_units='AA')) + wave = model.wave * xn + flux = model.flux * yn + + # Plot the SED with the model on top + fig = self.plot(output=True) + fig.line(wave, flux) + + show(fig) + + else: + print("Sorry, could not fit model to SED") + @property def dec(self): """ @@ -1046,14 +1080,14 @@ def find_SDSS(self, **kwargs): """ self.find_photometry('SDSS', **kwargs) - def find_SDSS_spectra(self, surveys=['optical', 'apogee'], **kwargs): + def find_SDSS_spectra(self, surveys=['optical', 'apogee'], search_radius=None, **kwargs): """ Search for SDSS spectra """ if 'optical' in surveys: # Query spectra - data, ref, header = qu.query_SDSS_optical_spectra(self.sky_coords, verbose=self.verbose, **kwargs) + data, ref, header = qu.query_SDSS_optical_spectra(self.sky_coords, verbose=self.verbose, radius=search_radius or self.search_radius, **kwargs) # Add the spectrum to the SED if data is not None: @@ -1062,7 +1096,7 @@ def find_SDSS_spectra(self, surveys=['optical', 'apogee'], **kwargs): if 'apogee' in surveys: # Query spectra - data, ref, header = qu.query_SDSS_apogee_spectra(self.sky_coords, verbose=self.verbose, **kwargs) + data, ref, header = qu.query_SDSS_apogee_spectra(self.sky_coords, verbose=self.verbose, search_radius=search_radius or self.search_radius, **kwargs) # Add the spectrum to the SED if data is not None: @@ -1214,43 +1248,6 @@ def fit_blackbody(self, fit_to='app_phot_SED', Teff_init=4000, epsilon=0.0001, a if self.verbose: print('\nNo blackbody fit.') - def compare_model(self, modelgrid, rebin=True, **kwargs): - """ - Fit a specific model to the SED by specifying the parameters as kwargs - - Parameters - ---------- - modelgrid: sedkit.modelgrid.ModelGrid - The model grid to fit - """ - if not self.calculated: - self.make_sed() - - # Get the model to fit - model = modelgrid.get_spectrum(**kwargs) - - if self.app_spec_SED is not None: - - if rebin: - model = model.resamp(self.app_spec_SED.spectrum[0]) - - # Fit the model to the SED - gstat, yn, xn = list(self.app_spec_SED.fit(model, wave_units='AA')) - wave = model.wave * xn - flux = model.flux * yn - - # Plot the SED with the model on top - fig = 
self.plot(output=True) - fig.line(wave, flux) - - show(fig) - - if self.verbose: - print('Best fit {}: {}'.format(name, self.best_fit[name]['label'])) - - else: - print("Sorry, could not fit SED to model grid", modelgrid) - def fit_modelgrid(self, modelgrid, name=None, mcmc=False, **kwargs): """ Fit a model grid to the composite spectra diff --git a/sedkit/tests/test_spectrum.py b/sedkit/tests/test_spectrum.py index e7cf0f00..e936d072 100644 --- a/sedkit/tests/test_spectrum.py +++ b/sedkit/tests/test_spectrum.py @@ -78,7 +78,7 @@ def test_model_fit(self): # Test MCMC fit bt = mg.BTSettl() - spec = bt.get_spectrum(teff=2456, logg=5.5, meta=0, alpha=0) + spec = bt.get_spectrum(teff=2456, logg=5.5, meta=0, alpha=0, snr=100) spec.mcmc_fit(bt, name='Test') def test_addition(self): From b341812804c54f120d672a135a34fb0707325a52 Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Tue, 19 Jan 2021 10:18:15 -0500 Subject: [PATCH 17/19] Make docs build --- docs/Makefile | 10 +- docs/conf.py | 276 ++++++++++++--------------------------------- sedkit/__init__.py | 17 +++ setup.py | 83 +++++++++++--- 4 files changed, 166 insertions(+), 220 deletions(-) diff --git a/docs/Makefile b/docs/Makefile index fb03f26e..a9ff2a0e 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -34,16 +34,19 @@ help: @echo " man to make manual pages" @echo " changes to make an overview of all changed/added/deprecated items" @echo " linkcheck to check all external links for integrity" + @echo " doctest to run all doctests embedded in the documentation (if enabled)" clean: -rm -rf $(BUILDDIR) -rm -rf api - -rm -rf generated html: $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + @echo " try this to examine it: " + @echo " open $(BUILDDIR)/html/index.html" + dirhtml: $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml @@ -129,5 +132,6 @@ linkcheck: "or in $(BUILDDIR)/linkcheck/output.txt." doctest: - @echo "Run 'python setup.py test' in the root directory to run doctests " \ - @echo "in the documentation." + $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest + @echo "Testing of doctests in the sources finished, look at the " \ + "results in $(BUILDDIR)/doctest/output.txt." diff --git a/docs/conf.py b/docs/conf.py index 63ba05a8..fb0f7510 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,37 +1,50 @@ # -*- coding: utf-8 -*- # -# astrodbkit documentation build configuration file, created by -# sphinx-quickstart on Tue Jan 19 10:54:25 2016. +# Configuration file for the Sphinx documentation builder. # -# This file is execfile()d with the current directory set to its -# containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. +# This file does only contain a selection of the most common options. For a +# full list see the documentation: +# http://www.sphinx-doc.org/en/master/config -import sys -import os -import shlex +# -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. 
-sys.path.insert(0, os.path.abspath('../sedkit/')) -sys.path.append(os.path.abspath('../sedkit/sedkit/')) -# -- General configuration ------------------------------------------------ +from sedkit import __version__ +import stsci_rtd_theme + +# -- Project information ----------------------------------------------------- + +project = 'sedkit' +copyright = '2021, Joe Filippazzo' +author = 'Joe Filippazzo' + +# The short X.Y version +version_parts = __version__.split('.') +version = "{}.{}".format(version_parts[0], version_parts[1]) +# The full version, including alpha/beta/rc tags +release = __version__ + +# -- General configuration --------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc','sphinx.ext.napoleon','sphinx.ext.autosummary', + 'sphinx_automodapi.automodapi', + 'sphinx_automodapi.automodsumm', + 'sphinx.ext.autodoc', + 'sphinx.ext.imgmath', + 'nbsphinx', + 'sphinx.ext.napoleon', + 'sphinx.ext.mathjax', + 'sphinx.ext.viewcode' ] # Add any paths that contain templates here, relative to this directory. @@ -39,29 +52,13 @@ # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: +# # source_suffix = ['.rst', '.md'] source_suffix = '.rst' -# The encoding of source files. -#source_encoding = 'utf-8-sig' - # The master toctree document. master_doc = 'index' -# General information about the project. -project = u'sedkit' -copyright = u'2019, Joe Filippazzo' -author = u'Joe Filippazzo' - -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The short X.Y version. -version = '0.3' -# The full version, including alpha/beta/rc tags. -release = '0.3.1' - # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # @@ -69,225 +66,102 @@ # Usually you set "language" from the command line for these cases. language = None -# There are two options for replacing |today|: either, you set today to some -# non-false value, then it is used: -#today = '' -# Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' - # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] - -# The reST default role (used for this markup: `text`) to use for all -# documents. -#default_role = None - -# If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True - -# If true, the current module name will be prepended to all description -# unit titles (such as .. function::). -#add_module_names = True - -# If true, sectionauthor and moduleauthor directives will be shown in the -# output. They are ignored by default. -#show_authors = False +# This pattern also affects html_static_path and html_extra_path . +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' -# A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] - -# If true, keep warnings as "system message" paragraphs in the built documents. 
-#keep_warnings = False -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = False - - -# -- Options for HTML output ---------------------------------------------- +# -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -import os -on_rtd = os.environ.get('READTHEDOCS', None) == 'True' - -if not on_rtd: # only import and set the theme if we're building docs locally - import sphinx_rtd_theme - html_theme = 'sphinx_rtd_theme' - html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] +# +#html_theme = 'alabaster' +#html_theme = "sphinx_rtd_theme" +html_theme = "stsci_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} - -# Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] - -# The name for this set of Sphinx documents. If None, it defaults to -# " v documentation". -#html_title = None - -# A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None - -# The name of an image file (relative to this directory) to place at the top -# of the sidebar. -#html_logo = None - -# The name of an image file (within the static path) to use as favicon of the -# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 -# pixels large. -#html_favicon = None +# +# html_theme_options = {} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] +html_theme_path = [stsci_rtd_theme.get_html_theme_path()] -# Add any extra paths that contain custom files (such as robots.txt or -# .htaccess) here, relative to this directory. These files are copied -# directly to the root of the documentation. -#html_extra_path = [] - -# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, -# using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' - -# If true, SmartyPants will be used to convert quotes and dashes to -# typographically correct entities. -#html_use_smartypants = True - -# Custom sidebar templates, maps document names to template names. -#html_sidebars = {} - -# Additional templates that should be rendered to pages, maps page names to -# template names. -#html_additional_pages = {} - -# If false, no module index is generated. -#html_domain_indices = True - -# If false, no index is generated. -#html_use_index = True - -# If true, the index is split into individual pages for each letter. -#html_split_index = False - -# If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True - -# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True - -# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True - -# If true, an OpenSearch description file will be output, and all pages will -# contain a tag referring to it. The value of this option must be the -# base URL from which the finished HTML is served. -#html_use_opensearch = '' - -# This is the file name suffix for HTML files (e.g. ".xhtml"). 
-#html_file_suffix = None - -# Language to be used for generating the HTML full-text search index. -# Sphinx supports the following languages: -# 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' -# 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' -#html_search_language = 'en' +# Custom sidebar templates, must be a dictionary that maps document names +# to template names. +# +# The default sidebars (for documents that don't match any pattern) are +# defined by theme itself. Builtin themes are using these templates by +# default: ``['localtoc.html', 'relations.html', 'sourcelink.html', +# 'searchbox.html']``. +# +# html_sidebars = {} -# A dictionary with options for the search language support, empty by default. -# Now only 'ja' uses this config value -#html_search_options = {'type': 'default'} -# The name of a javascript file (relative to the configuration directory) that -# implements a search results scorer. If empty, the default will be used. -#html_search_scorer = 'scorer.js' +# -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. htmlhelp_basename = 'sedkitdoc' -# -- Options for LaTeX output --------------------------------------------- + +# -- Options for LaTeX output ------------------------------------------------ latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', -# Latex figure (float) alignment -#'figure_align': 'htbp', + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'sedkit.tex', u'sedkit Documentation', - u'sedkit', 'manual'), + (master_doc, 'sedkittex', 'sedkit Documentation', + 'Joe Filippazzo', 'manual'), ] -# The name of an image file (relative to this directory) to place at the top of -# the title page. -#latex_logo = None - -# For "manual" documents, if this is true, then toplevel headings are parts, -# not chapters. -#latex_use_parts = False - -# If true, show page references after internal links. -#latex_show_pagerefs = False - -# If true, show URL addresses after external links. -#latex_show_urls = False -# Documents to append as an appendix to all manuals. -#latex_appendices = [] - -# If false, no module index is generated. -#latex_domain_indices = True - - -# -- Options for manual page output --------------------------------------- +# -- Options for manual page output ------------------------------------------ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, 'sedkit', u'sedkit Documentation', + (master_doc, 'sedkit', 'sedkit Documentation', [author], 1) ] -# If true, show URL addresses after external links. -#man_show_urls = False - -# -- Options for Texinfo output ------------------------------------------- +# -- Options for Texinfo output ---------------------------------------------- # Grouping the document tree into Texinfo files. 
List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'sedkit', u'sedkit Documentation', - author, 'sedkit', 'Spectral energy distribution construction and analysis tools.', - 'Miscellaneous'), + (master_doc, 'sedkit', 'sedkit Documentation', + author, 'Joe Filippazzo', 'One line description of project.', + 'Miscellaneous'), ] -# Documents to append as an appendix to all manuals. -#texinfo_appendices = [] - -# If false, no module index is generated. -#texinfo_domain_indices = True - -# How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' -# If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# -- Extension configuration ------------------------------------------------- diff --git a/sedkit/__init__.py b/sedkit/__init__.py index a3d526ea..9ed592f4 100644 --- a/sedkit/__init__.py +++ b/sedkit/__init__.py @@ -7,8 +7,25 @@ Author: Joe Filippazzo, jfilippazzo@stsci.edu """ +import re from .catalog import Catalog from .sed import SED, VegaSED from .spectrum import Spectrum, FileSpectrum, Vega, Blackbody, ModelSpectrum from .modelgrid import ModelGrid, BTSettl, SpexPrismLibrary + +__version_commit__ = '' +_regex_git_hash = re.compile(r'.*\+g(\w+)') + +__version__ = '1.1.0' + +# from pkg_resources import get_distribution, DistributionNotFound +# try: +# __version__ = get_distribution(__name__).version +# except DistributionNotFound: +# __version__ = 'dev' + +if '+' in __version__: + commit = _regex_git_hash.match(__version__).groups() + if commit: + __version_commit__ = commit[0] diff --git a/setup.py b/setup.py index 273e201d..a90ca861 100755 --- a/setup.py +++ b/setup.py @@ -1,31 +1,82 @@ -#! /usr/bin/env python -# -*- coding: utf-8 -*- +#!/usr/bin/env python +from setuptools import setup, find_packages, Extension, Command + +# allows you to build sphinx docs from the package +# main directory with "python setup.py build_sphinx" try: - from setuptools import setup, find_packages - setup + from sphinx.cmd.build import build_main + from sphinx.setup_command import BuildDoc + + class BuildSphinx(BuildDoc): + """Build Sphinx documentation after compiling C source files""" + + description = 'Build Sphinx documentation' + + def initialize_options(self): + BuildDoc.initialize_options(self) + + def finalize_options(self): + BuildDoc.finalize_options(self) + + def run(self): + build_cmd = self.reinitialize_command('build_ext') + build_cmd.inplace = 1 + self.run_command('build_ext') + build_main(['-b', 'html', './docs', './docs/_build/html']) + except ImportError: - from distutils.core import setup - setup + class BuildSphinx(Command): + user_options = [] + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + print('!\n! 
Sphinx is not installed!\n!', file=sys.stderr) + exit(1) + +DOCS_REQUIRE = [ + 'nbsphinx', + 'sphinx', + 'sphinx-automodapi', + 'sphinx-rtd-theme', + 'stsci-rtd-theme', + 'extension-helpers', +] +TESTS_REQUIRE = [ + 'pytest', +] setup( name='sedkit', - version='1.0.8', description='Spectral energy distribution construction and analysis tools', - url='https://github.com/hover2pi/sedkit', author='Joe Filippazzo', author_email='jfilippazzo@stsci.edu', license='MIT', + url='https://github.com/hover2pi/sedkit', + keywords=['astronomy'], classifiers=[ - 'Development Status :: 4 - Beta', 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3.6', + 'License :: OSI Approved :: BSD License', + 'Operating System :: OS Independent', + 'Programming Language :: Python', + 'Topic :: Scientific/Engineering :: Astronomy', + 'Topic :: Software Development :: Libraries :: Python Modules', ], - keywords='astrophysics', - packages=find_packages(exclude=['contrib', 'docs', 'tests*']), - install_requires=['numpy','astropy','bokeh','pysynphot','scipy','astroquery','dustmaps', 'pandas', 'svo_filters', 'healpy'], + packages=find_packages(exclude=["examples"]), + use_scm_version=True, + setup_requires=['setuptools_scm'], + install_requires=['numpy', 'astropy', 'bokeh', 'emcee', 'pysynphot', 'scipy', 'astroquery', 'dustmaps', 'pandas','svo_filters', 'healpy'], include_package_data=True, - -) \ No newline at end of file + extras_require={ + 'docs': DOCS_REQUIRE, + 'test': TESTS_REQUIRE, + }, + tests_require=TESTS_REQUIRE, + cmdclass={ + 'build_sphinx': BuildSphinx + },) From d2742944d7a31f2ab32d5e5ef3d0378eba297673 Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Tue, 19 Jan 2021 10:40:18 -0500 Subject: [PATCH 18/19] Broken build --- docs/conf.py | 7 +++---- sedkit/tests/test_sed.py | 5 ++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index fb0f7510..2ba74bed 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,7 +13,6 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. from sedkit import __version__ -import stsci_rtd_theme # -- Project information ----------------------------------------------------- @@ -80,9 +79,9 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -#html_theme = 'alabaster' +html_theme = 'alabaster' #html_theme = "sphinx_rtd_theme" -html_theme = "stsci_rtd_theme" +# html_theme = "stsci_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -94,7 +93,7 @@ # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] -html_theme_path = [stsci_rtd_theme.get_html_theme_path()] +# html_theme_path = [stsci_rtd_theme.get_html_theme_path()] # Custom sidebar templates, must be a dictionary that maps document names # to template names. 
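Note on the docs/conf.py change in this patch: switching html_theme to 'alabaster' and commenting out stsci_rtd_theme works around the theme dependency being unavailable at build time. A more defensive pattern is sketched below (illustrative only, not part of this patch series; it assumes the package name stsci_rtd_theme and its get_html_theme_path() helper used earlier in conf.py), keeping the preferred theme when installed and falling back to the Sphinx built-in theme otherwise:

# docs/conf.py -- hypothetical fallback sketch, not part of these patches
try:
    import stsci_rtd_theme
    html_theme = "stsci_rtd_theme"
    html_theme_path = [stsci_rtd_theme.get_html_theme_path()]
except ImportError:
    # stsci_rtd_theme is not installed; use the built-in theme so the docs still build
    html_theme = "alabaster"
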
diff --git a/sedkit/tests/test_sed.py b/sedkit/tests/test_sed.py index 71ba4142..6504332f 100644 --- a/sedkit/tests/test_sed.py +++ b/sedkit/tests/test_sed.py @@ -187,11 +187,10 @@ def test_plot(self): """Test plotting method""" v = sed.VegaSED() v.calculate_synthetic_photometry() - v.fit_blackbody() bt = mg.BTSettl() v.fit_modelgrid(bt) v.results - fig = v.plot(integral=True, synthetic_photometry=True, blackbody=True, best_fit=True) + fig = v.plot(integral=True, synthetic_photometry=True, best_fit=True) def test_no_photometry(self): """Test that a purely photometric SED can be creted""" @@ -232,7 +231,7 @@ def test_find_SDSS_spectra(self): """Test the find_SDSS_spectra method""" s = sed.SED() s.sky_coords = SkyCoord('0h8m05.63s +14d50m23.3s', frame='icrs') - s.find_SDSS_spectra(search_radius=1 * q.degree) + s.find_SDSS_spectra(search_radius=20 * q.arcsec) assert len(s.spectra) > 0 def test_run_methods(self): From 6f5004d075a6715b2ce8b7f1530ffefdc3b25dbc Mon Sep 17 00:00:00 2001 From: Joe Filippazzo Date: Tue, 19 Jan 2021 11:35:34 -0500 Subject: [PATCH 19/19] Tests for uncertainties module --- sedkit/tests/test_spectrum.py | 2 +- sedkit/tests/test_uncertainties.py | 63 ++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 sedkit/tests/test_uncertainties.py diff --git a/sedkit/tests/test_spectrum.py b/sedkit/tests/test_spectrum.py index e936d072..951b8f73 100644 --- a/sedkit/tests/test_spectrum.py +++ b/sedkit/tests/test_spectrum.py @@ -79,7 +79,7 @@ def test_model_fit(self): # Test MCMC fit bt = mg.BTSettl() spec = bt.get_spectrum(teff=2456, logg=5.5, meta=0, alpha=0, snr=100) - spec.mcmc_fit(bt, name='Test') + spec.mcmc_fit(bt, name='Test', plot=True) def test_addition(self): """Test that spectra are normalized and combined properly""" diff --git a/sedkit/tests/test_uncertainties.py b/sedkit/tests/test_uncertainties.py new file mode 100644 index 00000000..0c542ecd --- /dev/null +++ b/sedkit/tests/test_uncertainties.py @@ -0,0 +1,63 @@ +import unittest + +import astropy.units as q + +from .. import uncertainties as un + + +class TestUnum(unittest.TestCase): + """Tests for the Unum class""" + def setUp(self): + """Setup the tests""" + self.sym = un.Unum(10.1, 0.2) + self.asym = un.Unum(9.3, 0.08, 0.11) + + def test_attrs(self): + """Test attributes""" + x = self.sym + x.value + x.quantiles + + def test_add(self): + """Test add method""" + x = self.sym + self.asym + + def test_mul(self): + """Test mul method""" + x = self.sym * self.asym + + def test_sub(self): + """Test sub method""" + x = self.sym - self.asym + + def test_pow(self): + """Test pow method""" + x = self.sym ** 2 + + def test_truediv(self): + """Test truediv method""" + x = self.sym / self.asym + + def test_floordiv(self): + """Test floordiv method""" + x = self.sym // self.asym + + def test_plot(self): + """Test plot method""" + x = self.sym + x.plot() + + def test_sample_from_errors(self): + """Test the sample_from_errors method""" + # Test symmetric error case + x = self.sym + x.sample_from_errors() + x.sample_from_errors(low_lim=0, up_lim=100) + + # Test asymmetric error case + y = self.asym + y.sample_from_errors() + y.sample_from_errors(low_lim=0, up_lim=100) + + +
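
The new sedkit/tests/test_uncertainties.py module above exercises the Unum arithmetic and sampling API. For orientation, a minimal usage sketch mirroring the test fixtures is given below (illustrative only; it assumes the module is importable as sedkit.uncertainties and that arithmetic on Unum objects returns new Unum objects, as the tests imply):

# Hypothetical usage sketch based on the test fixtures above
from sedkit import uncertainties as un

# Symmetric and asymmetric uncertainties, as in TestUnum.setUp
sym = un.Unum(10.1, 0.2)           # 10.1 +/- 0.2
asym = un.Unum(9.3, 0.08, 0.11)    # 9.3 (+0.08, -0.11)

# Propagate uncertainties through basic arithmetic
total = sym + asym
ratio = sym / asym
print(total, ratio)

# Draw samples restricted to a physically allowed range
samples = asym.sample_from_errors(low_lim=0, up_lim=100)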