From 89eb4c98bd1da0b712d280bbeaa78bebeeb8e8d9 Mon Sep 17 00:00:00 2001 From: conoromand Date: Tue, 4 Jul 2023 14:56:15 +0200 Subject: [PATCH 1/4] Added band scaling, tested with data, not with lightcurves --- redback/plotting.py | 58 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/redback/plotting.py b/redback/plotting.py index d78b646a2..a7fbb1ef0 100644 --- a/redback/plotting.py +++ b/redback/plotting.py @@ -11,6 +11,7 @@ import redback from redback.utils import KwargsAccessorWithDefault +import ipdb class _FilenameGetter(object): def __init__(self, suffix: str) -> None: @@ -40,6 +41,7 @@ class Plotter(object): legend_cols = KwargsAccessorWithDefault("legend_cols", 2) color = KwargsAccessorWithDefault("color", "k") band_labels = KwargsAccessorWithDefault("band_labels", None) + band_scaling = KwargsAccessorWithDefault("band_scaling", {}) dpi = KwargsAccessorWithDefault("dpi", 300) elinewidth = KwargsAccessorWithDefault("elinewidth", 2) errorbar_fmt = KwargsAccessorWithDefault("errorbar_fmt", "x") @@ -72,10 +74,10 @@ class Plotter(object): uncertainty_mode = KwargsAccessorWithDefault("uncertainty_mode", "random_models") plot_max_likelihood = KwargsAccessorWithDefault("plot_max_likelihood", True) - xlim_high_multiplier = 2.0 - xlim_low_multiplier = 0.5 - ylim_high_multiplier = 2.0 - ylim_low_multiplier = 0.5 + xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 2.0) + xlim_low_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.5) + ylim_high_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 2.0) + ylim_low_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.5) def __init__(self, transient: Union[redback.transient.Transient, None], **kwargs) -> None: """ @@ -87,6 +89,7 @@ def __init__(self, transient: Union[redback.transient.Transient, None], **kwargs :keyword legend_cols: Same as matplotlib legend columns. :keyword color: Color of the data points. :keyword band_labels: List with the names of the bands. + :keyword band_scaling: Dict with the scaling for each band. First entry should be {type: '+' or '*'} for different types. :keyword dpi: Same as matplotlib dpi. :keyword elinewidth: same as matplotlib elinewidth :keyword errorbar_fmt: 'fmt' argument of `ax.errorbar`. @@ -519,11 +522,25 @@ def plot_data( continue if isinstance(label, float): label = f"{label:.2e}" - ax.errorbar( - self.transient.x[indices] - self._reference_mjd_date, self.transient.y[indices], - xerr=self._get_x_err(indices), yerr=self.transient.y_err[indices], - fmt=self.errorbar_fmt, ms=self.ms, color=color, - elinewidth=self.elinewidth, capsize=self.capsize, label=label) + if band in self.band_scaling: + if self.band_scaling.get("type") == '*': + ax.errorbar( + self.transient.x[indices] - self._reference_mjd_date, self.transient.y[indices] * self.band_scaling.get(band), + xerr=self._get_x_err(indices), yerr=self.transient.y_err[indices] * self.band_scaling.get(band), + fmt=self.errorbar_fmt, ms=self.ms, color=color, + elinewidth=self.elinewidth, capsize=self.capsize, label=label) + elif self.band_scaling.get("type") == '+': + ax.errorbar( + self.transient.x[indices] - self._reference_mjd_date, self.transient.y[indices] + self.band_scaling.get(band), + xerr=self._get_x_err(indices), yerr=self.transient.y_err[indices] + self.band_scaling.get(band), + fmt=self.errorbar_fmt, ms=self.ms, color=color, + elinewidth=self.elinewidth, capsize=self.capsize, label=label) + else: + ax.errorbar( + self.transient.x[indices] - self._reference_mjd_date, self.transient.y[indices], + xerr=self._get_x_err(indices), yerr=self.transient.y_err[indices], + fmt=self.errorbar_fmt, ms=self.ms, color=color, + elinewidth=self.elinewidth, capsize=self.capsize, label=label) self._set_x_axis(axes=ax) self._set_y_axis_data(ax) @@ -566,16 +583,33 @@ def plot_lightcurve( frequency = redback.utils.bands_to_frequency([band]) self._model_kwargs['frequency'] = np.ones(len(times)) * frequency if self.plot_max_likelihood: + ipdb.set_trace() ys = self.model(times, **self._max_like_params, **self._model_kwargs) - axes.plot(times - self._reference_mjd_date, ys, color=color, alpha=0.65, lw=2) + if band in self.band_scaling: + if self.band_scaling.get("type") == '*': + axes.plot(times - self._reference_mjd_date, ys * self.band_scaling.get(band), color=color, alpha=0.65, lw=2) + elif self.band_scaling.get("type") == '+': + axes.plot(times - self._reference_mjd_date, ys + self.band_scaling.get(band), color=color, alpha=0.65, lw=2) + else: + axes.plot(times - self._reference_mjd_date, ys, color=color, alpha=0.65, lw=2) random_ys_list = [self.model(times, **random_params, **self._model_kwargs) for random_params in self._get_random_parameters()] if self.uncertainty_mode == "random_models": for ys in random_ys_list: - axes.plot(times - self._reference_mjd_date, ys, color='red', alpha=0.05, lw=2, zorder=-1) + if self.band_scaling.get("type") == '*': + axes.plot(times - self._reference_mjd_date, ys * self.band_scaling.get(band), color='red', alpha=0.05, lw=2, zorder=-1) + elif self.band_scaling.get("type") == '+': + axes.plot(times - self._reference_mjd_date, ys + self.band_scaling.get(band), color='red', alpha=0.05, lw=2, zorder=-1) + else: + axes.plot(times - self._reference_mjd_date, ys, color='red', alpha=0.05, lw=2, zorder=-1) elif self.uncertainty_mode == "credible_intervals": - lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=random_ys_list) + if self.band_scaling.get("type") == '*': + lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list) * self.band_scaling.get(band)) + elif self.band_scaling.get("type") == '+': + lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list) + self.band_scaling.get(band)) + else: + lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list)) axes.fill_between( times - self._reference_mjd_date, lower_bound, upper_bound, alpha=self.uncertainty_band_alpha, color=color) From 6c996e2448b2ada4e95e004e53d2441656b49a25 Mon Sep 17 00:00:00 2001 From: conoromand Date: Tue, 4 Jul 2023 15:47:36 +0200 Subject: [PATCH 2/4] tested band_scaling, seems to work --- redback/plotting.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/redback/plotting.py b/redback/plotting.py index a7fbb1ef0..f817e4a33 100644 --- a/redback/plotting.py +++ b/redback/plotting.py @@ -11,8 +11,6 @@ import redback from redback.utils import KwargsAccessorWithDefault -import ipdb - class _FilenameGetter(object): def __init__(self, suffix: str) -> None: self.suffix = suffix @@ -583,7 +581,6 @@ def plot_lightcurve( frequency = redback.utils.bands_to_frequency([band]) self._model_kwargs['frequency'] = np.ones(len(times)) * frequency if self.plot_max_likelihood: - ipdb.set_trace() ys = self.model(times, **self._max_like_params, **self._model_kwargs) if band in self.band_scaling: if self.band_scaling.get("type") == '*': @@ -597,17 +594,19 @@ def plot_lightcurve( for random_params in self._get_random_parameters()] if self.uncertainty_mode == "random_models": for ys in random_ys_list: - if self.band_scaling.get("type") == '*': - axes.plot(times - self._reference_mjd_date, ys * self.band_scaling.get(band), color='red', alpha=0.05, lw=2, zorder=-1) - elif self.band_scaling.get("type") == '+': - axes.plot(times - self._reference_mjd_date, ys + self.band_scaling.get(band), color='red', alpha=0.05, lw=2, zorder=-1) + if band in self.band_scaling: + if self.band_scaling.get("type") == '*': + axes.plot(times - self._reference_mjd_date, ys * self.band_scaling.get(band), color='red', alpha=0.05, lw=2, zorder=-1) + elif self.band_scaling.get("type") == '+': + axes.plot(times - self._reference_mjd_date, ys + self.band_scaling.get(band), color='red', alpha=0.05, lw=2, zorder=-1) else: axes.plot(times - self._reference_mjd_date, ys, color='red', alpha=0.05, lw=2, zorder=-1) elif self.uncertainty_mode == "credible_intervals": - if self.band_scaling.get("type") == '*': - lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list) * self.band_scaling.get(band)) - elif self.band_scaling.get("type") == '+': - lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list) + self.band_scaling.get(band)) + if band in self.band_scaling: + if self.band_scaling.get("type") == '*': + lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list) * self.band_scaling.get(band)) + elif self.band_scaling.get("type") == '+': + lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list) + self.band_scaling.get(band)) else: lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list)) axes.fill_between( From a9fa5d7afc1c58c5f527606ff8c1c4613a56f425 Mon Sep 17 00:00:00 2001 From: conoromand Date: Wed, 2 Aug 2023 13:35:07 +0200 Subject: [PATCH 3/4] added label changes for band scaling --- redback/plotting.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/redback/plotting.py b/redback/plotting.py index f817e4a33..0dfde99ca 100644 --- a/redback/plotting.py +++ b/redback/plotting.py @@ -510,7 +510,10 @@ def plot_data( if band in self._filters: color = self._colors[list(self._filters).index(band)] if band_label_generator is None: - label = band + if band in self.band_scaling: + label = band + ' ' + self.band_scaling.get("type") + ' ' + str(self.band_scaling.get(band)) + else: + label = band else: label = next(band_label_generator) elif self.plot_others: From 450c3306dec412bfcefdc1894c9d01ef752329d2 Mon Sep 17 00:00:00 2001 From: conoromand Date: Wed, 2 Aug 2023 14:32:28 +0200 Subject: [PATCH 4/4] Changed append to prepand and * to x --- redback/plotting.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/redback/plotting.py b/redback/plotting.py index 0dfde99ca..78abfeb67 100644 --- a/redback/plotting.py +++ b/redback/plotting.py @@ -87,7 +87,7 @@ def __init__(self, transient: Union[redback.transient.Transient, None], **kwargs :keyword legend_cols: Same as matplotlib legend columns. :keyword color: Color of the data points. :keyword band_labels: List with the names of the bands. - :keyword band_scaling: Dict with the scaling for each band. First entry should be {type: '+' or '*'} for different types. + :keyword band_scaling: Dict with the scaling for each band. First entry should be {type: '+' or 'x'} for different types. :keyword dpi: Same as matplotlib dpi. :keyword elinewidth: same as matplotlib elinewidth :keyword errorbar_fmt: 'fmt' argument of `ax.errorbar`. @@ -511,7 +511,7 @@ def plot_data( color = self._colors[list(self._filters).index(band)] if band_label_generator is None: if band in self.band_scaling: - label = band + ' ' + self.band_scaling.get("type") + ' ' + str(self.band_scaling.get(band)) + label = str(self.band_scaling.get(band)) + ' ' + self.band_scaling.get("type") + ' ' + band else: label = band else: @@ -524,7 +524,7 @@ def plot_data( if isinstance(label, float): label = f"{label:.2e}" if band in self.band_scaling: - if self.band_scaling.get("type") == '*': + if self.band_scaling.get("type") == 'x': ax.errorbar( self.transient.x[indices] - self._reference_mjd_date, self.transient.y[indices] * self.band_scaling.get(band), xerr=self._get_x_err(indices), yerr=self.transient.y_err[indices] * self.band_scaling.get(band), @@ -586,7 +586,7 @@ def plot_lightcurve( if self.plot_max_likelihood: ys = self.model(times, **self._max_like_params, **self._model_kwargs) if band in self.band_scaling: - if self.band_scaling.get("type") == '*': + if self.band_scaling.get("type") == 'x': axes.plot(times - self._reference_mjd_date, ys * self.band_scaling.get(band), color=color, alpha=0.65, lw=2) elif self.band_scaling.get("type") == '+': axes.plot(times - self._reference_mjd_date, ys + self.band_scaling.get(band), color=color, alpha=0.65, lw=2) @@ -598,7 +598,7 @@ def plot_lightcurve( if self.uncertainty_mode == "random_models": for ys in random_ys_list: if band in self.band_scaling: - if self.band_scaling.get("type") == '*': + if self.band_scaling.get("type") == 'x': axes.plot(times - self._reference_mjd_date, ys * self.band_scaling.get(band), color='red', alpha=0.05, lw=2, zorder=-1) elif self.band_scaling.get("type") == '+': axes.plot(times - self._reference_mjd_date, ys + self.band_scaling.get(band), color='red', alpha=0.05, lw=2, zorder=-1) @@ -606,7 +606,7 @@ def plot_lightcurve( axes.plot(times - self._reference_mjd_date, ys, color='red', alpha=0.05, lw=2, zorder=-1) elif self.uncertainty_mode == "credible_intervals": if band in self.band_scaling: - if self.band_scaling.get("type") == '*': + if self.band_scaling.get("type") == 'x': lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list) * self.band_scaling.get(band)) elif self.band_scaling.get("type") == '+': lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list) + self.band_scaling.get(band))