Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added type hints & missing documentation to python functions #31

Merged
merged 1 commit into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 65 additions & 10 deletions analysis/Python_scripts/plotting.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import textwrap
import pandas as pd
from matplotlib import pyplot as plt
from typing import List
import seaborn as sns
import plotly.graph_objs as go
import numpy as np


def make_boxplot(grouped_df: pd.DataFrame, colname: str, legend: str):
def make_boxplot(grouped_df: pd.DataFrame, colname: str, legend: str) -> None:
"""Create a boxplot for each group in a DataFrame.

Args:
grouped_df (pd.DataFrame): The DataFrame grouped by some criteria.
colname (str): The name of the column to create boxplots for.
legend (str): The legend for the plot.
"""
# Create the plot with a width of 10 inches
fig, ax = plt.subplots(figsize=(17, 5))

Expand All @@ -31,7 +39,28 @@ def make_boxplot(grouped_df: pd.DataFrame, colname: str, legend: str):
plt.show()


def plot_histograms_sidebyside(same_query_ref1, same_query_ref2, same_query_ref3, same_query_ref4, same_query_ref5, column_name, xaxis_title='', title=''):
def plot_histograms_sidebyside(
same_query_ref1: pd.DataFrame,
same_query_ref2: pd.DataFrame,
same_query_ref3: pd.DataFrame,
same_query_ref4: pd.DataFrame,
same_query_ref5: pd.DataFrame,
column_name: str,
xaxis_title: str = '',
title: str = ''
) -> None:
"""Create and display a histogram for each of five input DataFrames.

Args:
same_query_ref1 (pd.DataFrame): The first input DataFrame.
same_query_ref2 (pd.DataFrame): The second input DataFrame.
same_query_ref3 (pd.DataFrame): The third input DataFrame.
same_query_ref4 (pd.DataFrame): The fourth input DataFrame.
same_query_ref5 (pd.DataFrame): The fifth input DataFrame.
column_name (str): The name of the column to create histograms for.
xaxis_title (str): The title for the x-axis.
title (str): The title for the plot.
"""
# Define number of bins
n_bins = 10

Expand Down Expand Up @@ -71,7 +100,14 @@ def plot_histograms_sidebyside(same_query_ref1, same_query_ref2, same_query_ref3
fig.show()


def plot_histogram(x, xaxis_title='', title=''):
def plot_histogram(x: List[float], xaxis_title: str = '', title: str = '') -> None:
"""Create and display a histogram for the input data.

Args:
x (List[float]): The input data.
xaxis_title (str): The title for the x-axis.
title (str): The title for the plot.
"""
# Define number of bins
n_bins = 20

Expand Down Expand Up @@ -101,7 +137,21 @@ def create_plot(df: pd.DataFrame,
showlegend: bool = True,
normalized_matches: bool = True,
nist_scale: bool = True,
hide_labels: bool = False):
hide_labels: bool = False) -> plt.Figure:
""" Create a boxplot with two y-axes for the input DataFrame.

Args:
df (pd.DataFrame): The input DataFrame.
grouping_column (str): The name of the column to group by.
xlabel (str): The label for the x-axis.
showlegend (bool): Whether to show the legend.
normalized_matches (bool): Whether to normalize the matches.
nist_scale (bool): Whether to use the NIST scale.
hide_labels (bool): Whether to hide the labels.

Returns:
fig (plt.Figure): The plot.
"""
matches_col = 'matches'
scores_col = 'scores'

Expand All @@ -120,7 +170,6 @@ def create_plot(df: pd.DataFrame,
plot_width = n_bars * bar_width
fig = plt.figure(figsize=(plot_width, 5))


ax = sns.boxplot(x=grouping_column, y="value", hue="Number",
data=df, hue_order=[matches_col, np.nan],
medianprops={'color': 'darkgreen', 'linewidth': 4.0},
Expand All @@ -137,8 +186,7 @@ def create_plot(df: pd.DataFrame,
if top <= 6:
ax.set_ylim(-0.5, 5.5)
else:
ax.set_ylim(0 - 0.1 * top, top= 1.1* top) # Set y-axis limits

ax.set_ylim(0 - 0.1 * top, top=1.1 * top) # Set y-axis limits

# Set font size of x-axis tick labels
ax.tick_params(axis='x', labelsize=tick_fontsize)
Expand Down Expand Up @@ -169,7 +217,6 @@ def create_plot(df: pd.DataFrame,
else:
ax.set_xlabel("", fontsize=0)
ax2.set_xlabel("", fontsize=0)


# Create a count for each x-axis label
count_data = df[grouping_column].value_counts()
Expand All @@ -186,15 +233,23 @@ def create_plot(df: pd.DataFrame,
labels[0] = ax.get_ylabel()
labels[1] = ax2.get_ylabel()
ax.legend(handles, labels, loc='upper right', fontsize=14)

if hide_labels:
ax.set_ylabel("", fontsize=0)
ax2.set_ylabel("", fontsize=0)

return fig


def scatterplot_matplotlib(df):
def scatterplot_matplotlib(df: pd.DataFrame) -> plt.Figure:
""" Create a scatterplot with the input DataFrame.

Args:
df (pd.DataFrame): The input DataFrame.

Returns:
fig (plt.Figure): The plot.
"""
fig = plt.figure(figsize=(18, 6))
scatter = plt.scatter(
df['scores'],
Expand Down
Loading