Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KeplerCBVCorrector can now take a KeplerLightCurve Object #168

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
29 changes: 19 additions & 10 deletions pyke/lightcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,8 @@ class KeplerLightCurve(LightCurve):
Array indicating the quality of each data point
quality_bitmask : int
Bitmask specifying quality flags of cadences that should be ignored
quality_mask : array-like
Mask applied to tpf
channel : int
Channel number
campaign : int
Expand All @@ -322,14 +324,20 @@ class KeplerLightCurve(LightCurve):
"""

def __init__(self, time, flux, flux_err=None, centroid_col=None,
centroid_row=None, quality=None, quality_bitmask=None,
centroid_row=None, quality=None,
quality_bitmask=KeplerQualityFlags.DEFAULT_BITMASK,
quality_mask=None,
channel=None, campaign=None, quarter=None, mission=None,
cadenceno=None, keplerid=None):
super(KeplerLightCurve, self).__init__(time, flux, flux_err)
self.centroid_col = centroid_col
self.centroid_row = centroid_row
self.quality = quality
self.quality_bitmask = quality_bitmask
if quality_mask is None:
self.quality_mask = np.asarray(self.quality & self.quality_bitmask, dtype=bool)
else:
self.quality_mask = quality_mask
self.channel = channel
self.campaign = campaign
self.quarter = quarter
Expand Down Expand Up @@ -562,8 +570,10 @@ def lc_file(self, value):
# this enables `lc_file` to be either a string
# or an object from KeplerLightCurveFile
if isinstance(value, str):
self._lc_file = KeplerLightCurveFile(value)
self._lc_file = KeplerLightCurveFile(value).SAP_FLUX
elif isinstance(value, KeplerLightCurveFile):
self._lc_file = value.SAP_FLUX
elif isinstance(value, KeplerLightCurve):
self._lc_file = value
else:
raise ValueError("lc_file must be either a string or a"
Expand Down Expand Up @@ -603,22 +613,21 @@ def correct(self, cbvs=[1, 2]):
cbv_array.append(cbv_data.field('VECTOR_{}'.format(i))[self.lc_file.quality_mask])
cbv_array = np.asarray(cbv_array)

sap_lc = self.lc_file.SAP_FLUX
median_sap_flux = np.nanmedian(sap_lc.flux)
norm_sap_flux = sap_lc.flux / median_sap_flux - 1
norm_err_sap_flux = sap_lc.flux_err / median_sap_flux
median_flux = np.nanmedian(self.lc_file.flux)
norm_flux = self.lc_file.flux / median_flux - 1
norm_err_flux = self.lc_file.flux_err / median_flux

def mean_model(*theta):
coeffs = np.asarray(theta)
return np.dot(coeffs, cbv_array)

loss = self.loss_function(data=norm_sap_flux, mean=mean_model,
var=norm_err_sap_flux)
loss = self.loss_function(data=norm_flux, mean=mean_model,
var=norm_err_flux)
self._opt_result = loss.fit(x0=np.zeros(len(cbvs)), method='L-BFGS-B')
self._coeffs = self._opt_result.x
flux_hat = sap_lc.flux - median_sap_flux * mean_model(self._coeffs)
flux_hat = self.lc_file.flux - median_flux * mean_model(self._coeffs)

return LightCurve(time=sap_lc.time, flux=flux_hat.reshape(-1))
return LightCurve(time=self.lc_file.time, flux=flux_hat.reshape(-1))

def get_cbv_url(self):
# gets the html page and finds all references to 'a' tag
Expand Down
9 changes: 6 additions & 3 deletions pyke/targetpixelfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,14 @@ def to_lightcurve(self, aperture_mask='pipeline'):
centroid_col=centroid_col,
centroid_row=centroid_row,
quality=self.quality,
quality_bitmask=self.quality_bitmask,
quality_mask=self._quality_mask(self.quality_bitmask),
channel=self.channel,
campaign=self.campaign,
quarter=self.quarter,
mission=self.mission,
cadenceno=self.cadenceno)
cadenceno=self.cadenceno,
keplerid=self.keplerid)

def centroids(self, aperture_mask='pipeline'):
"""Returns centroids based on sample moments.
Expand All @@ -250,9 +253,9 @@ def centroids(self, aperture_mask='pipeline'):
col_centr, row_centr : tuple
Arrays containing centroids for column and row at each cadence
"""
if aperture_mask == 'pipeline':
if aperture_mask is 'pipeline':
Copy link
Member

@mirca mirca Jan 4, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not correct, should be ==. is checks if aperture_mask and 'pipeline' are the same object, whereas == checks for the value.

aperture_mask = self.pipeline_mask
elif aperture_mask == 'all':
elif aperture_mask is 'all':
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above

mask = ~np.isnan(self.hdu[1].data['FLUX'][100])
aperture_mask = np.ones((self.shape[1], self.shape[2]),
dtype=bool) * mask
Expand Down