diff --git a/csep/utils/plots.py b/csep/utils/plots.py index ff0f759d..ea612fb3 100644 --- a/csep/utils/plots.py +++ b/csep/utils/plots.py @@ -2,7 +2,7 @@ import shutil import string import warnings -from typing import TYPE_CHECKING, Optional, Any, List, Union, Tuple +from typing import TYPE_CHECKING, Optional, Any, List, Union, Tuple, Sequence import cartopy import cartopy.crs as ccrs @@ -76,7 +76,7 @@ # Consistency and Comparison tests "capsize": 2, "hbars": True, - # Specific to spatial plotting + # Spatial plotting "grid_labels": True, "grid_fontsize": 8, "region_color": "black", @@ -100,7 +100,8 @@ def plot_magnitude_vs_time( ax: Optional[Axes] = None, color: Optional[str] = "steelblue", size: Optional[int] = 4, - mag_scale: Optional[int] = 6, + max_size: Optional[int] = 300, + power: Optional[int] = 4, alpha: Optional[float] = 0.5, show: bool = False, **kwargs: Any, @@ -118,9 +119,11 @@ def plot_magnitude_vs_time( Color of the scatter plot points. If not provided, defaults to value in `DEFAULT_PLOT_ARGS`. size (int): - Size of the scatter plot markers. - mag_scale (int): - Scaling factor for the magnitudes. + Size of the event with the minimum magnitude + max_size (int): + Size of the event with the maximum magnitude + power (int): + Power scaling of the scatter sizing. alpha (float): Transparency level for the scatter plot points. If not provided, defaults to value in `DEFAULT_PLOT_ARGS`. @@ -149,7 +152,7 @@ def plot_magnitude_vs_time( mag, marker="o", c=color, - s=_autosize_scatter(size, mag, mag_scale), + s=_autosize_scatter(mag, min_size=size, max_size=max_size, power=power), alpha=alpha, ) @@ -1419,10 +1422,14 @@ def plot_catalog( ax: Optional[matplotlib.axes.Axes] = None, projection: Optional[Union[ccrs.Projection, str]] = ccrs.PlateCarree(), show: bool = False, - extent: Optional[List[float]] = None, + extent: Optional[Sequence[float]] = None, set_global: bool = False, - mag_scale: float = 1, - mag_ticks: Optional[List[float]] = None, + mag_ticks: Optional[Union[Sequence[float], np.ndarray, int]] = None, + size: float = 15, + max_size: float = 300, + power: float = 3, + min_val: Optional[float] = None, + max_val: Optional[float] = None, plot_region: bool = False, **kwargs, ) -> matplotlib.axes.Axes: @@ -1436,10 +1443,14 @@ def plot_catalog( extent (list): Default 1.05 * :func:`catalog.region.get_bbox()`. projection (cartopy.crs.Projection): Projection to be used in the underlying basemap set_global (bool): Display the complete globe as basemap. - mag_scale (float): Scaling of the scatter. - mag_ticks (list): Ticks to display in the legend. + size (float): Size of the event with the minimum magnitude + max_size (float): Size of the catalog's maximum magnitude + power (float, list): Power scaling of the scatter sizing. + min_val (float): Override minimum magnitude of the catalog for scatter sizing + max_val (float): Override maximum magnitude of the catalog for scatter sizing + mag_ticks (list, int): Ticks to display in the legend. plot_region (bool): Flag to plot the catalog region border. - kwargs: size, alpha, markercolor, markeredgecolor, figsize, legend, + kwargs: alpha, markercolor, markeredgecolor, figsize, legend, legend_title, legend_labelspacing, legend_borderpad, legend_framealpha Returns: @@ -1455,10 +1466,11 @@ def plot_catalog( ax = plot_basemap(basemap, extent, ax=ax, set_global=set_global, show=False, **plot_args) # Plot catalog - scatter = ax.scatter( + ax.scatter( catalog.get_longitudes(), catalog.get_latitudes(), - s=_size_map(plot_args["size"], catalog.get_magnitudes(), mag_scale), + s=_autosize_scatter(values=catalog.get_magnitudes(), min_size=size, max_size=max_size, + power=power, min_val=min_val, max_val=max_val), transform=ccrs.PlateCarree(), color=plot_args["markercolor"], edgecolors=plot_args["markeredgecolor"], @@ -1467,22 +1479,32 @@ def plot_catalog( # Legend if plot_args["legend"]: - mw_range = [min(catalog.get_magnitudes()), max(catalog.get_magnitudes())] - - if isinstance(mag_ticks, (tuple, list, numpy.ndarray)): - if not numpy.all([mw_range[0] <= i <= mw_range[1] for i in mag_ticks]): - print("Magnitude ticks do not lie within the catalog magnitude range") - elif mag_ticks is None: - mag_ticks = numpy.linspace(mw_range[0], mw_range[1], 4) - - handles, labels = scatter.legend_elements( - prop="sizes", - num=list(_size_map(plot_args["size"], mag_ticks, mag_scale)), - alpha=0.3, + if isinstance(mag_ticks, (list, np.ndarray)): + mag_ticks = np.array(mag_ticks) + else: + mw_range = [min(catalog.get_magnitudes()), max(catalog.get_magnitudes())] + mag_ticks = np.linspace(mw_range[0], mw_range[1], mag_ticks or 4, endpoint=True) + + # Map mag_ticks to marker sizes using the custom size mapping function + legend_sizes = _autosize_scatter( + values=mag_ticks, + min_size=size, + max_size=max_size, + power=power, + min_val=min_val or np.min(catalog.get_magnitudes()), + max_val=max_val or np.max(catalog.get_magnitudes()) ) + + # Create custom legend handles + handles = [pyplot.Line2D([0], [0], marker='o', lw=0, label=str(m), + markersize=np.sqrt(s), markerfacecolor='gray', alpha=0.5, + markeredgewidth=0.8, + markeredgecolor='black') + for m, s in zip(mag_ticks, legend_sizes)] + ax.legend( handles, - numpy.round(mag_ticks, 1), + np.round(mag_ticks, 1), loc=plot_args["legend_loc"], handletextpad=5, title=plot_args.get("legend_title") or "Magnitudes", @@ -2127,25 +2149,14 @@ def _plot_pvalues_and_intervals(test_results, ax, var=None): return ax -def _autosize_scatter(markersize, values, scale): - if isinstance(scale, (int, float)): - # return (values - min(values) + markersize) ** scale # Adjust this formula as needed for better visualization - # return mark0ersize * (1 + (values - numpy.min(values)) / (numpy.max(values) - numpy.min(values)) ** scale) - return markersize / (scale ** min(values)) * numpy.power(values, scale) - - elif isinstance(scale, (numpy.ndarray, list)): - return scale - else: - raise ValueError("scale data type not supported") - +def _autosize_scatter(values, min_size=50., max_size=400., power=3.0, min_val=None, + max_val=None): -def _size_map(markersize, values, scale): - if isinstance(scale, (int, float)): - return markersize / (scale ** min(values)) * numpy.power(values, scale) - elif isinstance(scale, (numpy.ndarray, list)): - return scale - else: - raise ValueError("Scale data type not supported") + min_val = min_val or np.min(values) + max_val = max_val or np.max(values) + normalized_values = ((values - min_val) / (max_val - min_val)) ** power + marker_sizes = min_size + normalized_values * (max_size - min_size) * bool(power) + return marker_sizes def _autoscale_histogram(ax: pyplot.Axes, bin_edges, simulated, observation, mass=99.5): @@ -2286,17 +2297,6 @@ def _create_geo_axes(figsize, extent, projection, set_global): return ax -def _calculate_marker_size(markersize, magnitudes, scale): - mw_range = [min(magnitudes), max(magnitudes)] - if isinstance(scale, (int, float)): - return (markersize / (scale ** mw_range[0])) * numpy.power(magnitudes, scale) - elif isinstance(scale, (numpy.ndarray, list)): - return scale - else: - raise ValueError("Scale data type not supported") - - -# Helper function to add gridlines def _add_gridlines(ax, grid_labels, grid_fontsize): gl = ax.gridlines(draw_labels=grid_labels, alpha=0.5) gl.right_labels = False diff --git a/tests/test_plots.py b/tests/test_plots.py index 0190871a..e5de65fa 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -42,7 +42,6 @@ _get_basemap, # noqa _calculate_spatial_extent, # noqa _create_geo_axes, # noqa - _calculate_marker_size, # noqa _add_gridlines, # noqa _get_marker_style, # noqa _get_marker_t_color, # noqa @@ -50,7 +49,6 @@ _get_axis_limits, # noqa _add_labels_for_publication, # noqa _autosize_scatter, # noqa - _size_map, # noqa _autoscale_histogram, # noqa _annotate_distribution_plot, # noqa _define_colormap_and_alpha, # noqa @@ -71,7 +69,7 @@ def is_internet_available(): is_github_actions = os.getenv("GITHUB_ACTIONS") == "true" -show_plots = False +show_plots = True class TestPlots(unittest.TestCase): @@ -125,10 +123,10 @@ def test_plot_magnitude_vs_time(self): self.assertTrue(all(scatter_color[:3] == (1.0, 0.0, 0.0))) # Check if color is red # Test with custom marker size - ax = plot_magnitude_vs_time(catalog=self.observation_m2, size=10, mag_scale=1, + ax = plot_magnitude_vs_time(catalog=self.observation_m2, size=25, max_size=600, show=show_plots) scatter_sizes = ax.collections[0].get_sizes() - func_sizes = _autosize_scatter(10, self.observation_m2.data["magnitude"], 1) + func_sizes = _autosize_scatter(self.observation_m2.data["magnitude"], 25, 600, 4) numpy.testing.assert_array_almost_equal(scatter_sizes, func_sizes) # Test with custom alpha @@ -136,14 +134,14 @@ def test_plot_magnitude_vs_time(self): scatter_alpha = ax.collections[0].get_alpha() self.assertEqual(scatter_alpha, 0.5) - # Test with custom mag_scale - ax = plot_magnitude_vs_time(catalog=self.observation_m2, mag_scale=8, show=show_plots) + # Test with custom marker size power + ax = plot_magnitude_vs_time(catalog=self.observation_m2, power=6, show=show_plots) scatter_sizes = ax.collections[0].get_sizes() - func_sizes = _autosize_scatter(4, self.observation_m2.data["magnitude"], 8) + func_sizes = _autosize_scatter(self.observation_m2.data["magnitude"], 4, 300, 6) numpy.testing.assert_array_almost_equal(scatter_sizes, func_sizes) - - # Test with show=show_plots (just to ensure no errors occur) - plot_magnitude_vs_time(catalog=self.observation_m2, show=show_plots) + # + # # Test with show=True (just to ensure no errors occur) + plot_magnitude_vs_time(catalog=self.observation_m2, show=True) def test_plot_cumulative_events_default(self): # Test with default arguments to ensure basic functionality @@ -832,8 +830,14 @@ def setUp(self): [[-125, 25], [-85, 25], [-85, 65], [-125, 65], [-125, 25]] ) + self.mock_fix = MagicMock() + self.mock_fix.get_magnitudes.return_value = numpy.array([4, 5, 6, 7, 8]) + self.mock_fix.get_latitudes.return_value = numpy.array([36, 35, 34, 33, 32]) + self.mock_fix.get_longitudes.return_value = numpy.array([-110, -110, -110, -110, -110]) + self.mock_fix.get_bbox.return_value = [-114, -104, 31.5, 37.5] + def test_plot_catalog_default(self): - # Test plot with default settings + # Test plot with default settings4 ax = plot_catalog(self.mock_catalog, show=show_plots) self.assertIsInstance(ax, plt.Axes) self.assertEqual(ax.get_title(), '') @@ -850,6 +854,42 @@ def test_plot_catalog_without_legend(self): legend = ax.get_legend() self.assertIsNone(legend) + def test_plot_catalog_custom_legend(self): + + ax = plot_catalog(self.mock_catalog, mag_ticks=5, + show=show_plots) + legend = ax.get_legend() + self.assertIsNotNone(legend) + + mags = self.mock_catalog.get_magnitudes() + mag_bins = numpy.linspace(min(mags), max(mags), 3, endpoint=True) + ax = plot_catalog(self.mock_catalog, mag_ticks=mag_bins, show=show_plots) + legend = ax.get_legend() + self.assertIsNotNone(legend) + + def test_plot_catalog_correct_sizing(self): + + ax = plot_catalog(self.mock_fix, + figsize=(4,6), + mag_ticks=[4, 5, 6, 7, 8], + legend_loc='right', + show=show_plots) + legend = ax.get_legend() + self.assertIsNotNone(legend) + + def test_plot_catalog_custom_sizes(self): + + ax = plot_catalog(self.mock_catalog, size=5, max_size=800, power=6, + show=show_plots) + legend = ax.get_legend() + self.assertIsNotNone(legend) + + def test_plot_catalog_same_size(self): + + ax = plot_catalog(self.mock_catalog, size=30, power=0, show=show_plots) + legend = ax.get_legend() + self.assertIsNotNone(legend) + def test_plot_catalog_with_custom_extent(self): # Test plot with custom extent custom_extent = (-130, 20, 10, 80) @@ -1080,19 +1120,43 @@ def test_add_labels_for_publication(self): self.assertEqual(len(annotations), 1) self.assertEqual(annotations[0].get_text(), "(a)") - def test_autosize_scatter(self): - values = numpy.array([1, 2, 3]) - scale = 1.5 - expected_sizes = (2 / (scale**1)) * numpy.power(values, scale) - numpy.testing.assert_array_almost_equal( - _autosize_scatter(2, values, scale), expected_sizes - ) - def test_size_map(self): - values = numpy.array([1, 2, 3]) - scale = 1.5 - expected_sizes = (2 / (scale**1)) * numpy.power(values, scale) - numpy.testing.assert_array_almost_equal(_size_map(2, values, scale), expected_sizes) + def test_autosize_scatter(self): + values = numpy.array([1, 2, 3, 4, 5]) + expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=3.0) + result = _autosize_scatter(values, min_size=50., max_size=400., power=3.0) + numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2) + + values = numpy.array([1, 2, 3, 4, 5]) + min_val = 0 + max_val = 10 + expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=3.0, + min_val=min_val, max_val=max_val) + result = _autosize_scatter(values, min_size=50., max_size=400., power=3.0, min_val=min_val, + max_val=max_val) + numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2) + + values = numpy.array([1, 2, 3, 4, 5]) + power = 2.0 + expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=power) + result = _autosize_scatter(values, min_size=50., max_size=400., power=power) + numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2) + + values = numpy.array([1, 2, 3, 4, 5]) + power = 0.0 + expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=power) + result = _autosize_scatter(values, min_size=50., max_size=400., power=power) + numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2) + + values = numpy.array([5, 5, 5, 5, 5]) + expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=3.0) + result = _autosize_scatter(values, min_size=50., max_size=400., power=3.0) + numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2) + + values = numpy.array([10, 100, 1000, 10000, 100000]) + expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=3.0) + result = _autosize_scatter(values, min_size=50., max_size=400., power=3.0) + numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2) def test_autoscale_histogram(self): fig, ax = plt.subplots() @@ -1177,26 +1241,6 @@ def test_create_geo_axes(self): self.assertIsInstance(ax, plt.Axes) self.assertAlmostEqual(ax.get_extent(), extent) - def test_calculate_marker_size(self): - # Test marker size calculation with a scale factor - magnitudes = numpy.array([4.0, 5.0, 6.0]) - markersize = 5 - scale = 1.2 - sizes = _calculate_marker_size(markersize, magnitudes, scale) - expected_sizes = (markersize / (scale ** min(magnitudes))) * numpy.power( - magnitudes, scale - ) - numpy.testing.assert_array_almost_equal(sizes, expected_sizes) - - # Test marker size calculation with a fixed scale array - scale_array = numpy.array([10, 20, 30]) - sizes = _calculate_marker_size(markersize, magnitudes, scale_array) - numpy.testing.assert_array_almost_equal(sizes, scale_array) - - # Test invalid scale type - with self.assertRaises(ValueError): - _calculate_marker_size(markersize, magnitudes, "invalid_scale") - def test_add_gridlines(self): # Test adding gridlines to an axis fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()})