Skip to content

Commit

Permalink
Merge pull request #343 from yashksaini-coder/yash/fix-342
Browse files Browse the repository at this point in the history
✅ Enhance the Plotting visualization functions & update code snippets
  • Loading branch information
UTSAVS26 authored Nov 6, 2024
2 parents afd48f9 + 413138e commit c79112e
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 135 deletions.
Empty file removed Tests/graphing/__init__.py
Empty file.
75 changes: 0 additions & 75 deletions Tests/graphing/test_advanced_graphing.py

This file was deleted.

1 change: 1 addition & 0 deletions pysnippets/graphing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__all__ = []
119 changes: 65 additions & 54 deletions pysnippets/graphing/advanced_graphing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,65 +2,73 @@
import numpy as np
import seaborn as sns
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import logging
import scipy.stats as stats
from dataclasses import dataclass
from typing import Callable, Any

# Set up logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global configuration dictionary
config = {
'default_style': 'default',
'default_figsize': (10, 6),
'default_color_palette': 'viridis'
}
@dataclass
class GraphConfig:
style: str = 'default'
figsize: tuple = (10, 6)
color_palette: str = 'viridis'

def set_config(style='default', figsize=(10, 6), color_palette='viridis'):
# Global configuration instance
config = GraphConfig()

def set_config(style: str = 'default', figsize: tuple = (10, 6), color_palette: str = 'viridis') -> None:
"""Set global configuration for plotting."""
global config
config['default_style'] = style
config['default_figsize'] = figsize
config['default_color_palette'] = color_palette
if style not in plt.style.available:
logger.warning(f"Invalid style '{style}' provided. Defaulting to 'default'.")
style = 'default'

config.style = style
config.figsize = figsize
config.color_palette = color_palette
plt.style.use(style)

# Validate color palette
valid_palettes = sns.color_palette() # Get list of valid palettes
if color_palette in sns.palettes.SEABORN_PALETTES.keys(): # Check against valid Seaborn palettes
if color_palette in sns.palettes.SEABORN_PALETTES.keys():
sns.set_palette(color_palette)
else:
# Use a default palette silently without printing
sns.set_palette('viridis')
logger.warning(f"Invalid color palette '{color_palette}' provided. Defaulting to 'viridis'.")
sns.set_palette('viridis')

def reset_config():
"""Resets the configuration to default values."""
global config
config = {
'default_style': 'default',
'default_figsize': (10, 6),
'default_color_palette': 'viridis'
}
def get_available_styles() -> list:
"""Returns a list of available matplotlib styles."""
return plt.style.available

def apply_config():
def apply_config() -> None:
"""Applies the current configuration settings."""
try:
if config['default_style'] != 'default':
plt.style.use(config['default_style'])
plt.rcParams['figure.figsize'] = config['default_figsize']
sns.set_palette(config['default_color_palette'])
plt.style.use(config.style if config.style != 'default' else 'default')
plt.rcParams['figure.figsize'] = config.figsize
sns.set_palette(config.color_palette)
except Exception as e:
# logger.error(f"Error applying configuration: {str(e)}")
# logger.info("Reverting to default matplotlib style")
logger.error(f"Error applying configuration: {str(e)}")
logger.info("Reverting to default matplotlib style")
plt.style.use('default')

def get_available_styles():
return plt.style.available
def safe_plot(plot_func: Callable[..., None], *args: Any, **kwargs: Any) -> None:
"""Safely executes a plotting function with error handling."""
try:
plot_func(*args, **kwargs)
except Exception as e:
logger.error(f"Error in plotting: {str(e)}")
raise

def line_plot(x, y, title="Line Plot", xlabel="X-axis", ylabel="Y-axis", interactive=False):
def line_plot(x: np.ndarray, y: np.ndarray, title: str = "Line Plot", xlabel: str = "X-axis", ylabel: str = "Y-axis", interactive: bool = False) -> None:
"""Creates a line plot."""
apply_config()
if interactive:
fig = go.Figure(data=go.Scatter(x=x, y=y, mode='lines'))
fig.update_layout(title=title, xaxis_title=xlabel, yaxis_title=ylabel)
return fig
fig.show()
else:
plt.figure()
plt.plot(x, y)
Expand All @@ -70,7 +78,8 @@ def line_plot(x, y, title="Line Plot", xlabel="X-axis", ylabel="Y-axis", interac
plt.grid(True)
plt.show()

def bar_chart(categories, values, title="Bar Chart", xlabel="Categories", ylabel="Values"):
def bar_chart(categories: list, values: list, title: str = "Bar Chart", xlabel: str = "Categories", ylabel: str = "Values") -> None:
"""Creates a bar chart."""
apply_config()
plt.figure()
plt.bar(categories, values)
Expand All @@ -80,7 +89,8 @@ def bar_chart(categories, values, title="Bar Chart", xlabel="Categories", ylabel
plt.xticks(rotation=45)
plt.show()

def scatter_plot(x, y, title="Scatter Plot", xlabel="X-axis", ylabel="Y-axis"):
def scatter_plot(x: np.ndarray, y: np.ndarray, title: str = "Scatter Plot", xlabel: str = "X-axis", ylabel: str = "Y-axis") -> None:
"""Creates a scatter plot."""
apply_config()
plt.figure()
plt.scatter(x, y)
Expand All @@ -90,15 +100,17 @@ def scatter_plot(x, y, title="Scatter Plot", xlabel="X-axis", ylabel="Y-axis"):
plt.grid(True)
plt.show()

def pie_chart(labels, sizes, title="Pie Chart"):
def pie_chart(labels: list, sizes: list, title: str = "Pie Chart") -> None:
"""Creates a pie chart."""
apply_config()
plt.figure()
plt.pie(sizes, labels=labels, autopct='%1.1f%%', startangle=90)
plt.axis('equal')
plt.title(title)
plt.show()

def subplot(plot_funcs, nrows, ncols, titles):
def subplot(plot_funcs: list, nrows: int, ncols: int, titles: list) -> None:
"""Creates subplots from a list of plotting functions."""
apply_config()
fig, axs = plt.subplots(nrows, ncols, figsize=(5*ncols, 5*nrows))
axs = axs.flatten()
Expand All @@ -109,35 +121,34 @@ def subplot(plot_funcs, nrows, ncols, titles):
plt.tight_layout()
plt.show()

def heatmap(data, title="Heatmap"):
def heatmap(data: np.ndarray, title: str = "Heatmap") -> None:
"""Creates a heatmap."""
apply_config()
plt.figure()
sns.heatmap(data, annot=True, cmap="YlGnBu")
plt.title(title)
plt.show()

def normalize_data(data):
def normalize_data(data: np.ndarray) -> np.ndarray:
"""Normalizes the data to a range of [0, 1]."""
return (data - np.min(data)) / (np.max(data) - np.min(data))

def moving_average(data, window):
def moving_average(data: np.ndarray, window: int) -> np.ndarray:
"""Calculates the moving average of the data."""
return np.convolve(data, np.ones(window), 'valid') / window

def set_color_palette(palette):
def set_color_palette(palette: str) -> None:
"""Sets the color palette for plots."""
sns.set_palette(palette)

def annotate_point(x, y, text):
def annotate_point(x: float, y: float, text: str) -> None:
"""Annotates a point on the plot."""
plt.annotate(text, (x, y), xytext=(5, 5), textcoords='offset points')

def qq_plot(data, title="Q-Q Plot"):
def qq_plot(data: np.ndarray, title: str = "Q-Q Plot") -> None:
"""Creates a Q-Q plot."""
apply_config()
plt.figure()
stats.probplot(data, dist="norm", plot=plt) # Use scipy.stats.probplot
stats.probplot(data, dist="norm", plot=plt)
plt.title(title)
plt.show()

def safe_plot(plot_func, *args, **kwargs):
try:
plot_func(*args, **kwargs)
except Exception as e:
# logger.error(f"Error in plotting: {str(e)}")
raise
14 changes: 8 additions & 6 deletions pysnippets/graphing/graphing-example.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import numpy as np
from advanced_graphing import (
set_config, line_plot, bar_chart, scatter_plot, pie_chart,
Expand All @@ -8,13 +10,13 @@

# Print available styles
available_styles = get_available_styles()
print("Available styles:", available_styles)
logging.info("Available styles: %s", available_styles)

# Set global configuration
preferred_style = 'seaborn-v0_8-darkgrid' if 'seaborn-v0_8-darkgrid' in available_styles else 'ggplot'
set_config(style=preferred_style, figsize=(12, 8), color_palette='Set2')

# Generate some example data
# Generate example data
x = np.linspace(0, 10, 100)
y = np.sin(x)

Expand All @@ -41,13 +43,13 @@
safe_plot(heatmap, data, title="Random Heatmap")

# Subplots
def plot1():
def plot_sin():
line_plot(x, np.sin(x), title="Sin(x)")

def plot2():
def plot_cos():
line_plot(x, np.cos(x), title="Cos(x)")

safe_plot(subplot, [plot1, plot2], 1, 2, ["Sin(x)", "Cos(x)"])
safe_plot(subplot, [plot_sin, plot_cos], 1, 2, ["Sin(x)", "Cos(x)"])

# Data preprocessing
normalized_data = normalize_data(y)
Expand All @@ -71,4 +73,4 @@ def annotated_plot():
data = np.random.normal(0, 1, 1000)
safe_plot(qq_plot, data, title="Q-Q Plot of Normal Distribution")

print("All plots have been generated and displayed.")
logging.info("All plots have been generated and displayed.")
47 changes: 47 additions & 0 deletions pysnippets/graphing/test_advanced_graphing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import unittest
import numpy as np
from advanced_graphing import (
set_config, normalize_data, moving_average, safe_plot, line_plot, bar_chart, get_available_styles, config
)

class TestAdvancedGraphing(unittest.TestCase):

def setUp(self):
"""Set up the test environment."""
set_config(style='ggplot', figsize=(10, 6), color_palette='viridis')

def test_normalize_data(self):
"""Test the normalization of data."""
data = np.array([1, 2, 3, 4, 5])
normalized = normalize_data(data)
expected = np.array([0, 0.25, 0.5, 0.75, 1])
np.testing.assert_array_almost_equal(normalized, expected)

def test_moving_average(self):
"""Test the moving average calculation."""
data = np.array([1, 2, 3, 4, 5])
moving_avg = moving_average(data, window=3)
expected = np.array([2, 3, 4])
np.testing.assert_array_almost_equal(moving_avg, expected)

def test_set_config(self):
"""Test setting configuration."""
available_styles = get_available_styles()
valid_style = available_styles[0] # Use the first available style
set_config(style=valid_style, figsize=(12, 8), color_palette='Set1')

# Access the global config directly
self.assertEqual(config.style, valid_style)
self.assertEqual(config.figsize, (12, 8))
self.assertEqual(config.color_palette, 'Set1')

def test_safe_plot(self):
"""Test that safe_plot does not raise an error for valid functions."""
try:
safe_plot(line_plot, np.array([1, 2, 3]), np.array([1, 4, 9]), title="Test Line Plot")
safe_plot(bar_chart, ['A', 'B', 'C'], [1, 2, 3], title="Test Bar Chart")
except Exception as e:
self.fail(f"safe_plot raised an exception: {str(e)}")

if __name__ == '__main__':
unittest.main()

0 comments on commit c79112e

Please sign in to comment.