Skip to content

Commit

Permalink
begin refactor of continuum-fitting to enable future dev
Browse files Browse the repository at this point in the history
  • Loading branch information
moustakas committed Aug 12, 2024
1 parent 0a16586 commit 102b990
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 135 deletions.
255 changes: 139 additions & 116 deletions py/fastspecfit/continuum.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,8 +1014,106 @@ def _get_rchi2(chi2, ndof, nfree):
return rchi2_spec, rchi2_phot, rchi2_tot


def _younger_than_universe(age, tuniv, agepad=0.5):
"""Return the indices of the templates younger than the age of the universe
(plus an agepadding amount) at the given redshift. age in yr, agepad and
tuniv in Gyr
"""
return np.where(age <= 1e9 * (agepad + tuniv))[0]


def _compute_vdisp(redshift, specwave, specivar):
"""Determine if we can solve for the velocity dispersion."""
restwave = specwave / (1. + redshift)
I = np.where((specivar > 0) & (restwave > 3500.) & (restwave < 5500.))[0]
compute_vdisp = (len(I) > 0) and (np.ptp(restwave[I]) > 500.)
return compute_vdisp


def continuum_fastphot(redshift, objflam, objflamivar, templates,
CTools, ebv_guess=0.05, tuniv=None,
constrain_age=False):
"""Model the broadband photometry.
"""
vdisp = templates.vdisp_nominal
#log.info(f'Adopting nominal vdisp={vdisp:.0f} km/s.')

# Optionally ignore templates which are older than the age of the
# universe at the redshift of the object.
if constrain_age:
agekeep = _younger_than_universe(templates.info['age'].value, tuniv)
else:
agekeep = np.arange(templates.ntemplates)
nage = len(agekeep)

if np.all(objflamivar == 0.):
log.info('All photometry is masked.')
coeff = np.zeros(nage) # nage not nsed
rchi2_cont, rchi2_phot = 0., 0.
dn4000_model = 0.
sedmodel = np.zeros(len(templates.wave))
else:
# Get the coefficients and chi2 at the nominal velocity dispersion.
t0 = time.time()

objflamistd = np.sqrt(objflamivar)

# maintain backwards-compatibility
if templates.use_legacy_fitting:
sedtemplates, sedphot_flam = CTools.templates2data(
templates.flux_nomvdisp[:, agekeep],
templates.wave, flamphot=True,
redshift=redshift, dluminosity=data['dluminosity'],
vdisp=None, synthphot=True, photsys=data['photsys'])
sedflam = sedphot_flam * CTools.massnorm * FLUXNORM
coeff, rchi2_phot = CTools.call_nnls(sedflam, objflam, objflamivar)
rchi2_phot /= np.sum(objflamivar > 0.) # dof???
else:
ebv, _, coeff, resid = CTools.fit_stellar_continuum(
templates.flux_nomvdisp[:, agekeep], # [npix,nsed]
fit_vdisp=False, vdisp_guess=vdisp, ebv_guess=ebv_guess,
objflam=objflam, objflamistd=objflamistd,
synthphot=True, synthspec=False)
log.info(f'Best-fitting E(B-V)={ebv:.3f} mag.')

_, rchi2_phot, rchi2_cont = CTools.stellar_continuum_chi2(
resid, ncoeff=len(coeff), vdisp_fitted=False,
ndof_phot=np.sum(objflamivar > 0.))

sedmodel = CTools.optimizer_saved_contmodel

log.info(f'Fitting {nage} models took {time.time()-t0:.2f} seconds.')

if np.all(coeff == 0.):
log.warning('Continuum coefficients are all zero.')
sedmodel = np.zeros(len(templates.wave))
dn4000_model = 0.
else:
# Measure Dn(4000) from the line-free model.
if templates.use_legacy_fitting:
sedmodel = sedtemplates.dot(coeff)
sedtemplates_nolines, _ = CTools.templates2data(
templates.flux_nolines_nomvdisp[:, agekeep],
templates.wave,
redshift=redshift, dluminosity=data['dluminosity'],
vdisp=None, synthphot=False)
sedmodel_nolines = sedtemplates_nolines.dot(coeff)
else:
sedmodel_nolines = CTools.build_stellar_continuum(
templates.flux_nolines_nomvdisp[:, agekeep], coeff,
ebv=ebv, vdisp=None, dust_emission=False)

dn4000_model, _ = Photometry.get_dn4000(
templates.wave, sedmodel_nolines, rest=True)
log.info(f'Model Dn(4000)={dn4000_model:.3f}.')

return coeff, rchi2_cont, rchi2_phot, ebv, vdisp, dn4000_model, sedmodel


def continuum_specfit(data, result, templates, igm, phot,
constrain_age=False, no_smooth_continuum=False,
ebv_guess=0.05, constrain_age=False,
no_smooth_continuum=False,
fastphot=False, debug_plots=False):
"""Fit the non-negative stellar continuum of a single spectrum.
Expand All @@ -1037,14 +1135,6 @@ def continuum_specfit(data, result, templates, igm, phot,
- We solve for velocity dispersion if ...
"""
def younger_than_universe(age, tuniv, agepad=0.5):
"""Return the indices of the templates younger than the age of the universe
(plus an agepadding amount) at the given redshift. age in yr, agepad and
tuniv in Gyr
"""
return np.where(age <= 1e9 * (agepad + tuniv))[0]


tall = time.time()

CTools = ContinuumTools(igm, phot, templates, data, fastphot)
Expand All @@ -1056,7 +1146,6 @@ def younger_than_universe(age, tuniv, agepad=0.5):
photometry = data['phot']
objflam = photometry['flam'].value * FLUXNORM
objflamivar = (photometry['flam_ivar'].value / FLUXNORM**2) * phot.bands_to_fit
assert(np.all(objflamivar >= 0.))

if np.any(phot.bands_to_fit):
# Require at least one photometric optical band; do not just fit the IR
Expand All @@ -1065,129 +1154,62 @@ def younger_than_universe(age, tuniv, agepad=0.5):
opt = ((lambda_eff > 3e3) & (lambda_eff < 1e4))
if np.all(objflamivar[opt] == 0.):
log.warning('All optical bands are masked; masking all photometry.')
objflamivar[:] = 0.0

objflamistd = np.sqrt(objflamivar)
objflamivar[:] = 0.

# Optionally ignore templates which are older than the age of the
# universe at the redshift of the object.
if constrain_age:
agekeep = younger_than_universe(templates.info['age'].value, data['tuniv'])
else:
agekeep = np.arange(templates.ntemplates)
nage = len(agekeep)

vdisp_nominal = templates.vdisp_nominal
ebv_guess = 0.05
import pdb ; pdb.set_trace()

# Photometry-only fitting.
if fastphot:
log.info(f'Adopting nominal vdisp={vdisp_nominal:.0f} km/s.')
vdisp = vdisp_nominal

if np.all(objflamivar == 0.):
log.info('All photometry is masked.')
coeff = np.zeros(nage) # nage not nsed
rchi2_cont, rchi2_phot = 0., 0.
dn4000_model = 0.
sedmodel = np.zeros(len(templates.wave))
else:
# Get the coefficients and chi2 at the nominal velocity dispersion.
t0 = time.time()

# maintain backwards-compatibility
if templates.use_legacy_fitting:
sedtemplates, sedphot_flam = CTools.templates2data(
templates.flux_nomvdisp[:, agekeep],
templates.wave, flamphot=True,
redshift=redshift, dluminosity=data['dluminosity'],
vdisp=None, synthphot=True, photsys=data['photsys'])
sedflam = sedphot_flam * CTools.massnorm * FLUXNORM
coeff, rchi2_phot = CTools.call_nnls(sedflam, objflam, objflamivar)
rchi2_phot /= np.sum(objflamivar > 0.) # dof???
else:
ebv, _, coeff, resid = CTools.fit_stellar_continuum(
templates.flux_nomvdisp[:, agekeep], # [npix,nsed]
fit_vdisp=False,
vdisp_guess=vdisp_nominal, ebv_guess=ebv_guess,
objflam=objflam, objflamistd=objflamistd,
synthphot=True, synthspec=False
)

_, rchi2_phot, rchi2_cont = CTools.stellar_continuum_chi2(
resid, ncoeff=len(coeff), vdisp_fitted=False,
ndof_phot=np.sum(objflamivar > 0.)
)

sedmodel = CTools.optimizer_saved_contmodel

log.info(f'Fitting {nage} models took {time.time()-t0:.2f} seconds.')

if np.all(coeff == 0.):
log.warning('Continuum coefficients are all zero.')
sedmodel = np.zeros(len(templates.wave))
dn4000_model = 0.
else:
# Measure Dn(4000) from the line-free model.
if templates.use_legacy_fitting:
sedmodel = sedtemplates.dot(coeff)
sedtemplates_nolines, _ = CTools.templates2data(
templates.flux_nolines_nomvdisp[:, agekeep],
templates.wave,
redshift=redshift, dluminosity=data['dluminosity'],
vdisp=None, synthphot=False)
sedmodel_nolines = sedtemplates_nolines.dot(coeff)
else:
sedmodel_nolines = CTools.build_stellar_continuum(
templates.flux_nolines_nomvdisp[:, agekeep], coeff,
ebv=ebv, vdisp=None,
dust_emission=False,
)

log.info(f'Best-fitting E(B-V)={ebv:.3f} mag.')

dn4000_model, _ = Photometry.get_dn4000(templates.wave,
sedmodel_nolines, rest=True)
log.info(f'Model Dn(4000)={dn4000_model:.3f}.')
coeff, rchi2_cont, rchi2_phot, ebv, vdisp, dn4000_model, sedmodel = \
continuum_fastphot(redshift, objflam, objflamivar, templates, CTools,
ebv_guess=ebv_guess, tuniv=data['tuniv'],
constrain_age=constrain_age)
else:
# Combine all three cameras; we will unpack them to build the
# best-fitting model (per-camera) below.
specwave = np.hstack(data['wave'])
specflux = np.hstack(data['flux'])
flamivar = np.hstack(data['ivar'])
specivar = flamivar * ~np.hstack(data['linemask']) # mask emission lines
specivar_nolinemask = np.hstack(data['ivar'])
specivar = specivar_nolinemask * np.logical_not(np.hstack(data['linemask'])) # mask emission lines

if np.all(specivar == 0.) or np.any(specivar < 0.):
specivar = flamivar # not great...
if np.all(specivar == 0.) or np.any(specivar < 0.):
errmsg = 'All pixels are masked or some inverse variances are negative!'
log.critical(errmsg)
raise ValueError(errmsg)

specistd = np.sqrt(specivar)

npix = len(specwave)

objflamistd = np.sqrt(objflamivar)

# Optionally ignore templates which are older than the age of the
# universe at the redshift of the object.
if constrain_age:
agekeep = _younger_than_universe(templates.info['age'].value, data['tuniv'])
else:
agekeep = np.arange(templates.ntemplates)
nage = len(agekeep)

vdisp_nominal = templates.vdisp_nominal

# We'll need the filters for the aperture correction, below.
filters_in = phot.synth_filters[data['photsys']]

# Solve for the velocity dispersion if the wavelength coverage is
# sufficient.
restwave = specwave / (1. + redshift)
Ivdisp = np.where((specivar > 0) & (restwave > 3500.) & (restwave < 5500.))[0]
compute_vdisp = (len(Ivdisp) > 0) and (np.ptp(restwave[Ivdisp]) > 500.)
# Solve for the velocity dispersion?
compute_vdisp = _compute_vdisp(redshift, specwave, specivar)

if len(data['cameras']) == 3:
log.info('S/N_{}={:.2f}, S/N_{}={:.2f}, S/N_{}={:.2f}, rest wavelength coverage={:.0f}-{:.0f} A.'.format(
data['cameras'][0], data['snr'][0],
data['cameras'][1], data['snr'][1],
data['cameras'][2], data['snr'][2],
restwave[0], restwave[-1]))

# Maintain backwards compatibility. With the old templates, the velocity
# dispersion and aperture corrections are determined separately, so we
# separate that code out from the new templates, where they are
# determined simultatneously.
#if len(data['cameras']) == 3:
# log.info('S/N_{}={:.2f}, S/N_{}={:.2f}, S/N_{}={:.2f}, rest wavelength coverage={:.0f}-{:.0f} A.'.format(
# data['cameras'][0], data['snr'][0],
# data['cameras'][1], data['snr'][1],
# data['cameras'][2], data['snr'][2],
# restwave[0], restwave[-1]))

# Maintain backwards compatibility. With the old templates, the
# velocity dispersion and aperture corrections are determined
# separately, so we separate that code out from the new templates,
# where they are determined simultaneously.
if templates.use_legacy_fitting:
if compute_vdisp:
t0 = time.time()
Expand All @@ -1212,7 +1234,7 @@ def younger_than_universe(age, tuniv, agepad=0.5):
f'one; adopting vdisp={vdisp_nominal:.0f} km/s.')
vdispbest, vdispivar = vdisp_nominal, 0.
else:
log.info(f'Best-fitting vdisp={vdispbest:.1f}+/-{1./np.sqrt(vdispivar):.1f} km/s.')
log.info(f'Best-fitting vdisp={vdispbest:.0f}+/-{1./np.sqrt(vdispivar):.0f} km/s.')
else:
vdispbest = vdisp_nominal
log.info(f'Finding vdisp failed; adopting vdisp={vdisp_nominal:.0f} km/s.')
Expand Down Expand Up @@ -1281,7 +1303,7 @@ def younger_than_universe(age, tuniv, agepad=0.5):
if np.any(I):
apercorr = median(apercorrs[I])

log.info(f'Median aperture correction = {apercorr:.3f} [{np.min(apercorrs):.3f}-{np.max(apercorrs):.3f}].')
log.info(f'Median aperture correction {apercorr:.3f} [{np.min(apercorrs):.3f}-{np.max(apercorrs):.3f}].')

if apercorr <= 0.:
log.warning('Aperture correction not well-defined; adopting 1.0.')
Expand Down Expand Up @@ -1374,7 +1396,7 @@ def younger_than_universe(age, tuniv, agepad=0.5):
if np.any(I):
apercorr = median(apercorrs[I])

log.info(f'Median aperture correction = {apercorr:.3f} ' + \
log.info(f'Median aperture correction {apercorr:.3f} ' + \
f'[{np.min(apercorrs):.3f}-{np.max(apercorrs):.3f}].')
if apercorr <= 0.:
log.warning('Aperture correction not well-defined; adopting 1.0.')
Expand Down Expand Up @@ -1426,9 +1448,9 @@ def younger_than_universe(age, tuniv, agepad=0.5):

if compute_vdisp:
#log.info(f'Best-fitting vdisp={vdisp:.1f}+/-{1./np.sqrt(vdispivar):.1f} km/s.')
log.info(f'Best-fitting vdisp={vdisp:.1f} km/s.')
log.info(f'Best-fitting vdisp={vdisp:.0f} km/s.')
else:
log.info(f'Insufficient wavelength coverage to compute vdisp; adopting nominal vdisp={vdisp:.1f} km/s')
log.info(f'Insufficient wavelength coverage to compute vdisp; adopting nominal vdisp={vdisp:.0f} km/s')

_, rchi2_phot, rchi2_cont = CTools.stellar_continuum_chi2(
resid, ncoeff=len(coeff), vdisp_fitted=compute_vdisp,
Expand All @@ -1450,7 +1472,8 @@ def younger_than_universe(age, tuniv, agepad=0.5):
sedmodel_nolines, rest=True)

# Get DN(4000). Specivar is line-masked so we can't use it!
dn4000, dn4000_ivar = Photometry.get_dn4000(specwave, specflux, flam_ivar=flamivar,
dn4000, dn4000_ivar = Photometry.get_dn4000(specwave, specflux,
flam_ivar=specivar_nolinemask,
redshift=redshift, rest=False)

if dn4000_ivar > 0:
Expand Down
13 changes: 8 additions & 5 deletions py/fastspecfit/fastspecfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ def fastspec_one(iobj, data, out_dtype,
emline_table = sc_data.emlines.table
templates = sc_data.templates

log.info(f'Continuum- and emission-line fitting object {iobj} [{phot.uniqueid_col.lower()} ' + \
f'{data["uniqueid"]}, z={data["redshift"]:.6f}].')
if fastphot:
log.info(f'Continuum fitting object {iobj} [{phot.uniqueid_col.lower()} ' + \
f'{data["uniqueid"]}, z={data["redshift"]:.6f}].')
else:
log.info(f'Continuum- and emission-line fitting object {iobj} [{phot.uniqueid_col.lower()} ' + \
f'{data["uniqueid"]}, z={data["redshift"]:.6f}].')

# output structure
out = BoxedScalar(out_dtype)
Expand Down Expand Up @@ -159,8 +163,6 @@ def fastspec(fastphot=False, stackfit=False, args=None, comm=None, verbose=False
data = Spec.read(mp_pool, fastphot=fastphot, debug_plots=args.debug_plots,
constrain_age=args.constrain_age)

import pdb ; pdb.set_trace()

ncoeff = sc_data.templates.ntemplates
out_dtype, out_units = get_output_dtype(Spec.specprod,
phot=sc_data.photometry,
Expand All @@ -184,9 +186,10 @@ def fastspec(fastphot=False, stackfit=False, args=None, comm=None, verbose=False
} for iobj in range(Spec.ntargets)]

_out = mp_pool.starmap(fastspec_one, fitargs)

out = list(zip(*_out))

import pdb ; pdb.set_trace()

meta = create_output_meta(Spec.meta, data,
phot=sc_data.photometry,
fastphot=fastphot, stackfit=stackfit)
Expand Down
Loading

0 comments on commit 102b990

Please sign in to comment.