Skip to content

Commit

Permalink
ft: Improved autosize function for scatter plots.
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloitu committed Aug 22, 2024
1 parent f9ade59 commit 694f6f3
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 101 deletions.
114 changes: 57 additions & 57 deletions csep/utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import shutil
import string
import warnings
from typing import TYPE_CHECKING, Optional, Any, List, Union, Tuple
from typing import TYPE_CHECKING, Optional, Any, List, Union, Tuple, Sequence

import cartopy
import cartopy.crs as ccrs
Expand Down Expand Up @@ -76,7 +76,7 @@
# Consistency and Comparison tests
"capsize": 2,
"hbars": True,
# Specific to spatial plotting
# Spatial plotting
"grid_labels": True,
"grid_fontsize": 8,
"region_color": "black",
Expand All @@ -100,7 +100,8 @@ def plot_magnitude_vs_time(
ax: Optional[Axes] = None,
color: Optional[str] = "steelblue",
size: Optional[int] = 4,
mag_scale: Optional[int] = 6,
max_size: Optional[int] = 300,
power: Optional[int] = 4,
alpha: Optional[float] = 0.5,
show: bool = False,
**kwargs: Any,
Expand All @@ -118,9 +119,11 @@ def plot_magnitude_vs_time(
Color of the scatter plot points. If not provided, defaults to value in
`DEFAULT_PLOT_ARGS`.
size (int):
Size of the scatter plot markers.
mag_scale (int):
Scaling factor for the magnitudes.
Size of the event with the minimum magnitude
max_size (int):
Size of the event with the maximum magnitude
power (int):
Power scaling of the scatter sizing.
alpha (float):
Transparency level for the scatter plot points. If not provided, defaults to value
in `DEFAULT_PLOT_ARGS`.
Expand Down Expand Up @@ -149,7 +152,7 @@ def plot_magnitude_vs_time(
mag,
marker="o",
c=color,
s=_autosize_scatter(size, mag, mag_scale),
s=_autosize_scatter(mag, min_size=size, max_size=max_size, power=power),
alpha=alpha,
)

Expand Down Expand Up @@ -1419,10 +1422,14 @@ def plot_catalog(
ax: Optional[matplotlib.axes.Axes] = None,
projection: Optional[Union[ccrs.Projection, str]] = ccrs.PlateCarree(),
show: bool = False,
extent: Optional[List[float]] = None,
extent: Optional[Sequence[float]] = None,
set_global: bool = False,
mag_scale: float = 1,
mag_ticks: Optional[List[float]] = None,
mag_ticks: Optional[Union[Sequence[float], np.ndarray, int]] = None,
size: float = 15,
max_size: float = 300,
power: float = 3,
min_val: Optional[float] = None,
max_val: Optional[float] = None,
plot_region: bool = False,
**kwargs,
) -> matplotlib.axes.Axes:
Expand All @@ -1436,10 +1443,14 @@ def plot_catalog(
extent (list): Default 1.05 * :func:`catalog.region.get_bbox()`.
projection (cartopy.crs.Projection): Projection to be used in the underlying basemap
set_global (bool): Display the complete globe as basemap.
mag_scale (float): Scaling of the scatter.
mag_ticks (list): Ticks to display in the legend.
size (float): Size of the event with the minimum magnitude
max_size (float): Size of the catalog's maximum magnitude
power (float, list): Power scaling of the scatter sizing.
min_val (float): Override minimum magnitude of the catalog for scatter sizing
max_val (float): Override maximum magnitude of the catalog for scatter sizing
mag_ticks (list, int): Ticks to display in the legend.
plot_region (bool): Flag to plot the catalog region border.
kwargs: size, alpha, markercolor, markeredgecolor, figsize, legend,
kwargs: alpha, markercolor, markeredgecolor, figsize, legend,
legend_title, legend_labelspacing, legend_borderpad, legend_framealpha
Returns:
Expand All @@ -1455,10 +1466,11 @@ def plot_catalog(
ax = plot_basemap(basemap, extent, ax=ax, set_global=set_global, show=False, **plot_args)

# Plot catalog
scatter = ax.scatter(
ax.scatter(
catalog.get_longitudes(),
catalog.get_latitudes(),
s=_size_map(plot_args["size"], catalog.get_magnitudes(), mag_scale),
s=_autosize_scatter(values=catalog.get_magnitudes(), min_size=size, max_size=max_size,
power=power, min_val=min_val, max_val=max_val),
transform=ccrs.PlateCarree(),
color=plot_args["markercolor"],
edgecolors=plot_args["markeredgecolor"],
Expand All @@ -1467,22 +1479,32 @@ def plot_catalog(

# Legend
if plot_args["legend"]:
mw_range = [min(catalog.get_magnitudes()), max(catalog.get_magnitudes())]

if isinstance(mag_ticks, (tuple, list, numpy.ndarray)):
if not numpy.all([mw_range[0] <= i <= mw_range[1] for i in mag_ticks]):
print("Magnitude ticks do not lie within the catalog magnitude range")
elif mag_ticks is None:
mag_ticks = numpy.linspace(mw_range[0], mw_range[1], 4)

handles, labels = scatter.legend_elements(
prop="sizes",
num=list(_size_map(plot_args["size"], mag_ticks, mag_scale)),
alpha=0.3,
if isinstance(mag_ticks, (list, np.ndarray)):
mag_ticks = np.array(mag_ticks)
else:
mw_range = [min(catalog.get_magnitudes()), max(catalog.get_magnitudes())]
mag_ticks = np.linspace(mw_range[0], mw_range[1], mag_ticks or 4, endpoint=True)

# Map mag_ticks to marker sizes using the custom size mapping function
legend_sizes = _autosize_scatter(
values=mag_ticks,
min_size=size,
max_size=max_size,
power=power,
min_val=min_val or np.min(catalog.get_magnitudes()),
max_val=max_val or np.max(catalog.get_magnitudes())
)

# Create custom legend handles
handles = [pyplot.Line2D([0], [0], marker='o', lw=0, label=str(m),
markersize=np.sqrt(s), markerfacecolor='gray', alpha=0.5,
markeredgewidth=0.8,
markeredgecolor='black')
for m, s in zip(mag_ticks, legend_sizes)]

ax.legend(
handles,
numpy.round(mag_ticks, 1),
np.round(mag_ticks, 1),
loc=plot_args["legend_loc"],
handletextpad=5,
title=plot_args.get("legend_title") or "Magnitudes",
Expand Down Expand Up @@ -2127,25 +2149,14 @@ def _plot_pvalues_and_intervals(test_results, ax, var=None):
return ax


def _autosize_scatter(markersize, values, scale):
if isinstance(scale, (int, float)):
# return (values - min(values) + markersize) ** scale # Adjust this formula as needed for better visualization
# return mark0ersize * (1 + (values - numpy.min(values)) / (numpy.max(values) - numpy.min(values)) ** scale)
return markersize / (scale ** min(values)) * numpy.power(values, scale)

elif isinstance(scale, (numpy.ndarray, list)):
return scale
else:
raise ValueError("scale data type not supported")

def _autosize_scatter(values, min_size=50., max_size=400., power=3.0, min_val=None,
max_val=None):

def _size_map(markersize, values, scale):
if isinstance(scale, (int, float)):
return markersize / (scale ** min(values)) * numpy.power(values, scale)
elif isinstance(scale, (numpy.ndarray, list)):
return scale
else:
raise ValueError("Scale data type not supported")
min_val = min_val or np.min(values)
max_val = max_val or np.max(values)
normalized_values = ((values - min_val) / (max_val - min_val)) ** power
marker_sizes = min_size + normalized_values * (max_size - min_size) * bool(power)
return marker_sizes


def _autoscale_histogram(ax: pyplot.Axes, bin_edges, simulated, observation, mass=99.5):
Expand Down Expand Up @@ -2286,17 +2297,6 @@ def _create_geo_axes(figsize, extent, projection, set_global):
return ax


def _calculate_marker_size(markersize, magnitudes, scale):
mw_range = [min(magnitudes), max(magnitudes)]
if isinstance(scale, (int, float)):
return (markersize / (scale ** mw_range[0])) * numpy.power(magnitudes, scale)
elif isinstance(scale, (numpy.ndarray, list)):
return scale
else:
raise ValueError("Scale data type not supported")


# Helper function to add gridlines
def _add_gridlines(ax, grid_labels, grid_fontsize):
gl = ax.gridlines(draw_labels=grid_labels, alpha=0.5)
gl.right_labels = False
Expand Down
132 changes: 88 additions & 44 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,13 @@
_get_basemap, # noqa
_calculate_spatial_extent, # noqa
_create_geo_axes, # noqa
_calculate_marker_size, # noqa
_add_gridlines, # noqa
_get_marker_style, # noqa
_get_marker_t_color, # noqa
_get_marker_w_color, # noqa
_get_axis_limits, # noqa
_add_labels_for_publication, # noqa
_autosize_scatter, # noqa
_size_map, # noqa
_autoscale_histogram, # noqa
_annotate_distribution_plot, # noqa
_define_colormap_and_alpha, # noqa
Expand All @@ -71,7 +69,7 @@ def is_internet_available():

is_github_actions = os.getenv("GITHUB_ACTIONS") == "true"

show_plots = False
show_plots = True


class TestPlots(unittest.TestCase):
Expand Down Expand Up @@ -125,25 +123,25 @@ def test_plot_magnitude_vs_time(self):
self.assertTrue(all(scatter_color[:3] == (1.0, 0.0, 0.0))) # Check if color is red

# Test with custom marker size
ax = plot_magnitude_vs_time(catalog=self.observation_m2, size=10, mag_scale=1,
ax = plot_magnitude_vs_time(catalog=self.observation_m2, size=25, max_size=600,
show=show_plots)
scatter_sizes = ax.collections[0].get_sizes()
func_sizes = _autosize_scatter(10, self.observation_m2.data["magnitude"], 1)
func_sizes = _autosize_scatter(self.observation_m2.data["magnitude"], 25, 600, 4)
numpy.testing.assert_array_almost_equal(scatter_sizes, func_sizes)

# Test with custom alpha
ax = plot_magnitude_vs_time(catalog=self.observation_m2, alpha=0.5, show=show_plots)
scatter_alpha = ax.collections[0].get_alpha()
self.assertEqual(scatter_alpha, 0.5)

# Test with custom mag_scale
ax = plot_magnitude_vs_time(catalog=self.observation_m2, mag_scale=8, show=show_plots)
# Test with custom marker size power
ax = plot_magnitude_vs_time(catalog=self.observation_m2, power=6, show=show_plots)
scatter_sizes = ax.collections[0].get_sizes()
func_sizes = _autosize_scatter(4, self.observation_m2.data["magnitude"], 8)
func_sizes = _autosize_scatter(self.observation_m2.data["magnitude"], 4, 300, 6)
numpy.testing.assert_array_almost_equal(scatter_sizes, func_sizes)

# Test with show=show_plots (just to ensure no errors occur)
plot_magnitude_vs_time(catalog=self.observation_m2, show=show_plots)
#
# # Test with show=True (just to ensure no errors occur)
plot_magnitude_vs_time(catalog=self.observation_m2, show=True)

def test_plot_cumulative_events_default(self):
# Test with default arguments to ensure basic functionality
Expand Down Expand Up @@ -832,8 +830,14 @@ def setUp(self):
[[-125, 25], [-85, 25], [-85, 65], [-125, 65], [-125, 25]]
)

self.mock_fix = MagicMock()
self.mock_fix.get_magnitudes.return_value = numpy.array([4, 5, 6, 7, 8])
self.mock_fix.get_latitudes.return_value = numpy.array([36, 35, 34, 33, 32])
self.mock_fix.get_longitudes.return_value = numpy.array([-110, -110, -110, -110, -110])
self.mock_fix.get_bbox.return_value = [-114, -104, 31.5, 37.5]

def test_plot_catalog_default(self):
# Test plot with default settings
# Test plot with default settings4
ax = plot_catalog(self.mock_catalog, show=show_plots)
self.assertIsInstance(ax, plt.Axes)
self.assertEqual(ax.get_title(), '')
Expand All @@ -850,6 +854,42 @@ def test_plot_catalog_without_legend(self):
legend = ax.get_legend()
self.assertIsNone(legend)

def test_plot_catalog_custom_legend(self):

ax = plot_catalog(self.mock_catalog, mag_ticks=5,
show=show_plots)
legend = ax.get_legend()
self.assertIsNotNone(legend)

mags = self.mock_catalog.get_magnitudes()
mag_bins = numpy.linspace(min(mags), max(mags), 3, endpoint=True)
ax = plot_catalog(self.mock_catalog, mag_ticks=mag_bins, show=show_plots)
legend = ax.get_legend()
self.assertIsNotNone(legend)

def test_plot_catalog_correct_sizing(self):

ax = plot_catalog(self.mock_fix,
figsize=(4,6),
mag_ticks=[4, 5, 6, 7, 8],
legend_loc='right',
show=show_plots)
legend = ax.get_legend()
self.assertIsNotNone(legend)

def test_plot_catalog_custom_sizes(self):

ax = plot_catalog(self.mock_catalog, size=5, max_size=800, power=6,
show=show_plots)
legend = ax.get_legend()
self.assertIsNotNone(legend)

def test_plot_catalog_same_size(self):

ax = plot_catalog(self.mock_catalog, size=30, power=0, show=show_plots)
legend = ax.get_legend()
self.assertIsNotNone(legend)

def test_plot_catalog_with_custom_extent(self):
# Test plot with custom extent
custom_extent = (-130, 20, 10, 80)
Expand Down Expand Up @@ -1080,19 +1120,43 @@ def test_add_labels_for_publication(self):
self.assertEqual(len(annotations), 1)
self.assertEqual(annotations[0].get_text(), "(a)")

def test_autosize_scatter(self):
values = numpy.array([1, 2, 3])
scale = 1.5
expected_sizes = (2 / (scale**1)) * numpy.power(values, scale)
numpy.testing.assert_array_almost_equal(
_autosize_scatter(2, values, scale), expected_sizes
)

def test_size_map(self):
values = numpy.array([1, 2, 3])
scale = 1.5
expected_sizes = (2 / (scale**1)) * numpy.power(values, scale)
numpy.testing.assert_array_almost_equal(_size_map(2, values, scale), expected_sizes)
def test_autosize_scatter(self):
values = numpy.array([1, 2, 3, 4, 5])
expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=3.0)
result = _autosize_scatter(values, min_size=50., max_size=400., power=3.0)
numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2)

values = numpy.array([1, 2, 3, 4, 5])
min_val = 0
max_val = 10
expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=3.0,
min_val=min_val, max_val=max_val)
result = _autosize_scatter(values, min_size=50., max_size=400., power=3.0, min_val=min_val,
max_val=max_val)
numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2)

values = numpy.array([1, 2, 3, 4, 5])
power = 2.0
expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=power)
result = _autosize_scatter(values, min_size=50., max_size=400., power=power)
numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2)

values = numpy.array([1, 2, 3, 4, 5])
power = 0.0
expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=power)
result = _autosize_scatter(values, min_size=50., max_size=400., power=power)
numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2)

values = numpy.array([5, 5, 5, 5, 5])
expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=3.0)
result = _autosize_scatter(values, min_size=50., max_size=400., power=3.0)
numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2)

values = numpy.array([10, 100, 1000, 10000, 100000])
expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=3.0)
result = _autosize_scatter(values, min_size=50., max_size=400., power=3.0)
numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2)

def test_autoscale_histogram(self):
fig, ax = plt.subplots()
Expand Down Expand Up @@ -1177,26 +1241,6 @@ def test_create_geo_axes(self):
self.assertIsInstance(ax, plt.Axes)
self.assertAlmostEqual(ax.get_extent(), extent)

def test_calculate_marker_size(self):
# Test marker size calculation with a scale factor
magnitudes = numpy.array([4.0, 5.0, 6.0])
markersize = 5
scale = 1.2
sizes = _calculate_marker_size(markersize, magnitudes, scale)
expected_sizes = (markersize / (scale ** min(magnitudes))) * numpy.power(
magnitudes, scale
)
numpy.testing.assert_array_almost_equal(sizes, expected_sizes)

# Test marker size calculation with a fixed scale array
scale_array = numpy.array([10, 20, 30])
sizes = _calculate_marker_size(markersize, magnitudes, scale_array)
numpy.testing.assert_array_almost_equal(sizes, scale_array)

# Test invalid scale type
with self.assertRaises(ValueError):
_calculate_marker_size(markersize, magnitudes, "invalid_scale")

def test_add_gridlines(self):
# Test adding gridlines to an axis
fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()})
Expand Down

0 comments on commit 694f6f3

Please sign in to comment.