From c264bce8e8ee26df21d9c5a6c8b08a8ad43aa0e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michal=20Bel=C3=A1k?= Date: Wed, 6 Mar 2024 11:29:34 +0100 Subject: [PATCH] refactor: use mypy type checking * add proper type hints * add mypy type checking to CI --- .github/workflows/test.yaml | 1 + edvart/export_utils.py | 6 ++++- edvart/plots.py | 2 +- edvart/report.py | 14 ++++++------ edvart/report_sections/bivariate_analysis.py | 8 ++++++- edvart/report_sections/dataset_overview.py | 8 +++---- edvart/report_sections/group_analysis.py | 20 ++++++++++------- .../report_sections/multivariate_analysis.py | 20 +++++++++-------- edvart/report_sections/table_of_contents.py | 2 +- .../timeseries_analysis/boxplots_over_time.py | 10 ++++----- .../timeseries_analysis/fourier_transform.py | 2 +- .../timeseries_analysis/rolling_statistics.py | 4 ++-- .../seasonal_decomposition.py | 2 +- .../timeseries_analysis/short_time_ft.py | 4 ++-- .../timeseries_analysis/stationarity_tests.py | 2 +- edvart/report_sections/univariate_analysis.py | 22 ++++++++++--------- edvart/utils.py | 11 +++++----- pyproject.toml | 6 +++++ 18 files changed, 85 insertions(+), 59 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7d86942..0897bfb 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -48,6 +48,7 @@ jobs: poetry run pylint --rcfile=.pylintrc edvart/ poetry run black --check --line-length 100 edvart/ tests/ poetry run isort --check --line-length 100 --profile black edvart/ tests/ + poetry run mypy edvart/ dismiss-stale-reviews: runs-on: ubuntu-22.04 diff --git a/edvart/export_utils.py b/edvart/export_utils.py index 97a087f..d5283c5 100644 --- a/edvart/export_utils.py +++ b/edvart/export_utils.py @@ -29,6 +29,10 @@ def embed_image_base64(image_path: str, mime: str = "image/png") -> str: # Look up directory where currently executed template is located # Jinja's @environmentfilter or @contextfilter does not seem to provide # any information about the path of the template. - template_dir = os.path.dirname(inspect.getfile(inspect.currentframe().f_back)) + current_frame = inspect.currentframe() + assert current_frame is not None + frame_back = current_frame.f_back + assert frame_back is not None + template_dir = os.path.dirname(inspect.getfile(frame_back)) with open(os.path.join(template_dir, image_path), "rb") as img: return f"data:{mime};base64," + str(base64.b64encode(img.read()).decode("utf-8")) diff --git a/edvart/plots.py b/edvart/plots.py index 4bca094..2013968 100644 --- a/edvart/plots.py +++ b/edvart/plots.py @@ -107,7 +107,7 @@ def _scatter_plot_2d_noninteractive( color_categorical = pd.Categorical(df[color_col]) color_codes = color_categorical.codes else: - color_codes = df[color_col] + color_codes = df[color_col].values.astype(np.signedinteger) scatter = ax.scatter(x, y, c=color_codes, alpha=opacity) if is_color_categorical: diff --git a/edvart/report.py b/edvart/report.py index 0634ce2..9fb6eef 100755 --- a/edvart/report.py +++ b/edvart/report.py @@ -4,7 +4,7 @@ import warnings from abc import ABC from copy import copy -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Sized, Tuple, Union import isort import nbconvert @@ -64,7 +64,7 @@ def __init__( self.df = dataframe self.sections: list[Section] = [] self.verbosity = Verbosity(verbosity) - self._table_of_contents = None + self._table_of_contents: Optional[TableOfContents] = None def _warn_if_empty(self) -> None: """Warns if the report contains no sections.""" @@ -207,7 +207,7 @@ def _export_html( Maximum number of seconds to wait for a cell to finish execution. """ # Execute notebook to produce output of cells - html_exp_kwargs = dict( + html_exp_kwargs: Dict[str, Any] = dict( preprocessors=[nbconvert.preprocessors.ExecutePreprocessor(timeout=timeout)] ) if template_name is not None: @@ -275,7 +275,7 @@ def export_html( # and unpickles the the whole report object from the decoded binary data unpickle_report = code_dedent( f""" - data = {buffer_base64} + data = {buffer_base64!r} report = pickle.loads(base64.b85decode(data), fix_imports=False) """ ) @@ -676,7 +676,7 @@ def __init__( columns_bivariate_analysis: Optional[List[str]] = None, columns_multivariate_analysis: Optional[List[str]] = None, columns_group_analysis: Optional[List[str]] = None, - groupby: Union[str, List[str]] = None, + groupby: Optional[Union[str, List[str]]] = None, ): super().__init__(dataframe, verbosity) @@ -699,7 +699,7 @@ def __init__( ) if isinstance(groupby, str): color_col = groupby - elif hasattr(groupby, "__len__") and len(groupby) == 1: + elif isinstance(groupby, Sized) and len(groupby) == 1: color_col = groupby[0] else: color_col = None @@ -740,7 +740,7 @@ def __init__( verbosity: Verbosity = Verbosity.LOW, ): super().__init__(dataframe, verbosity) - if not is_date(dataframe.index): + if not is_date(dataframe.index.to_series()): raise ValueError( "Input dataframe needs to be indexed by time." "Please reindex your data to be indexed by either a DatetimeIndex or a PeriodIndex." diff --git a/edvart/report_sections/bivariate_analysis.py b/edvart/report_sections/bivariate_analysis.py index cad75eb..23e6473 100644 --- a/edvart/report_sections/bivariate_analysis.py +++ b/edvart/report_sections/bivariate_analysis.py @@ -120,6 +120,8 @@ def __init__( raise ValueError("Either both or neither of columns_x, columns_y must be specified.") # For analyses which do not take columns_pairs, prepare columns_x and columns_y in case # columns_pairs is the only parameter specified + columns_x_no_pairs: Optional[List[str]] + columns_y_no_pairs: Optional[List[str]] if columns is None and columns_x is None and columns_pairs is not None: columns_x_no_pairs = [pair[0] for pair in columns_pairs] columns_y_no_pairs = [pair[1] for pair in columns_pairs] @@ -456,6 +458,7 @@ def _get_columns_x_y( if columns is None: columns = list(df.columns) columns_x = columns_y = columns + assert columns_y is not None columns_x = [col for col in columns_x if is_numeric(df[col])] columns_y = [col for col in columns_y if is_numeric(df[col])] @@ -722,6 +725,7 @@ def include_column(col: str) -> bool: columns_x = columns columns_y = columns if not allow_categorical: + assert columns_y is not None columns_x = list(filter(include_column, columns_x)) columns_y = list(filter(include_column, columns_y)) sns.pairplot(df, x_vars=columns_x, y_vars=columns_y, hue=color_col) @@ -908,6 +912,8 @@ def include_column(col: str) -> bool: if columns_x is None: columns_pairs = list(itertools.combinations(columns, 2)) else: + assert columns_x is not None + assert columns_y is not None columns_pairs = [ (col_x, col_y) for (col_x, col_y) in itertools.product(columns_x, columns_y) @@ -971,7 +977,7 @@ def contingency_table( annot = table.replace(0, "") if hide_zeros else table ax = sns.heatmap( - scaling_func(table), + scaling_func(table.values), annot=annot, fmt="", cbar=False, diff --git a/edvart/report_sections/dataset_overview.py b/edvart/report_sections/dataset_overview.py index 3cdcd10..9704bd7 100644 --- a/edvart/report_sections/dataset_overview.py +++ b/edvart/report_sections/dataset_overview.py @@ -449,7 +449,7 @@ def data_types(df: pd.DataFrame, columns: Optional[List[str]] = None) -> None: """ if columns is not None: df = df[columns] - dtypes = df.apply( + dtypes = df.apply( # type: ignore func=lambda x_: str(infer_data_type(x_)), axis=0, result_type="expand", @@ -652,7 +652,7 @@ def missing_values( bar_plot_title: str = "Missing Values Percentage of Each Column", bar_plot_ylim: float = 0, bar_plot_color: str = "#FFA07A", - **bar_plot_args: Dict[str, Any], + **bar_plot_args: Any, ) -> None: """Displays a table of missing values percentages for each column of df and a bar plot of the percentages. @@ -675,7 +675,7 @@ def missing_values( Bar plot y axis bottom limit. bar_plot_color : str Color of bars in the bar plot in hex format. - bar_plot_args : Dict[str, Any] + bar_plot_args : Any Additional kwargs passed to pandas.Series.bar. """ if columns is not None: @@ -717,7 +717,7 @@ def missing_values( title=bar_plot_title, ylim=bar_plot_ylim, color=bar_plot_color, - **bar_plot_args, + **bar_plot_args, # type: ignore ) .set_ylabel("Missing Values [%]") ) diff --git a/edvart/report_sections/group_analysis.py b/edvart/report_sections/group_analysis.py index 7ad113e..8ee9854 100644 --- a/edvart/report_sections/group_analysis.py +++ b/edvart/report_sections/group_analysis.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Union import colorlover as cl import nbformat.v4 as nbfv4 @@ -102,7 +102,7 @@ def required_imports(self) -> List[str]: "import plotly.graph_objects as go", "from edvart.data_types import infer_data_type, DataType", "from edvart import utils", - "from typing import List, Dict, Optional, Callable", + "from typing import List, Dict, Optional, Callable, Iterable", "from plotly.subplots import make_subplots", ] @@ -218,7 +218,7 @@ def add_cells(self, cells: List[Dict[str, Any]], df: pd.DataFrame) -> None: ) cells.append(nbfv4.new_code_cell(code)) - columns = self.columns if self.columns is not None else df.columns + columns = self.columns if self.columns is not None else df.columns.to_list() if not self.show_statistics and not self.show_dist: return @@ -362,7 +362,7 @@ def within_group_stats( df: pd.DataFrame, groupby: List[str], column: str, - stats: Dict[str, Callable[[pd.Series], float]] = None, + stats: Optional[Dict[str, Callable[[pd.Series], float]]] = None, round_decimals: int = 2, ) -> None: """Display withing group statistics for a column of df grouped by one or other more columns. @@ -448,7 +448,9 @@ def group_missing_values( df_grouped = df.groupby(groupby)[columns] # Calculate number of samples in each group - sizes = df_grouped.size().rename("Group Size") + sizes = df_grouped.size() + assert isinstance(sizes, pd.Series) + sizes = sizes.rename("Group Size") # Calculate missing values percentage of each column for each group missing = df_grouped.apply(lambda g: g.isna().sum(axis=0)) @@ -490,7 +492,7 @@ def color_cell(value): background-color: {bg_hex}; """ - render = final_table.style.applymap( + render = final_table.style.map( func=color_cell, subset=pd.IndexSlice[:, colored_columns] ).format(formatter="{0:.2f} %", subset=pd.IndexSlice[:, colored_columns]) else: @@ -553,7 +555,8 @@ def group_barplot( fig = go.Figure() for color_idx, (idx, row) in enumerate(pivot.iterrows()): - if hasattr(idx, "__len__") and not isinstance(idx, str): + group_name: Hashable + if isinstance(idx, Iterable) and not isinstance(idx, str): group_name = "_".join([str(i) for i in idx]) else: group_name = idx @@ -641,7 +644,8 @@ def overlaid_histograms( ) for color_idx, (name, group) in enumerate(df.groupby(groupby)): - if hasattr(name, "__len__") and not isinstance(name, str): + group_name: Hashable + if isinstance(name, Iterable) and not isinstance(name, str): group_name = "_".join([str(i) for i in name]) else: group_name = name diff --git a/edvart/report_sections/multivariate_analysis.py b/edvart/report_sections/multivariate_analysis.py index 63d194b..c94fbdc 100644 --- a/edvart/report_sections/multivariate_analysis.py +++ b/edvart/report_sections/multivariate_analysis.py @@ -1,5 +1,5 @@ from enum import IntEnum -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import matplotlib.pyplot as plt import nbformat.v4 as nbfv4 @@ -487,7 +487,7 @@ def pca_explained_variance( plt.ylabel("Explained variance ratio") plt.xticks( ticks=range(len(pca.explained_variance_ratio_)), - labels=range(1, (len(pca.explained_variance_ratio_) + 1)), + labels=[str(label) for label in range(1, (len(pca.explained_variance_ratio_) + 1))], ) if show_grid: plt.grid() @@ -630,13 +630,15 @@ def parallel_coordinates( columns = [col for col in columns if col not in hide_columns] if drop_na: df = df.dropna() + + line: Optional[Dict[str, Any]] = None if color_col is not None: is_categorical_color = infer_data_type(df[color_col]) in ( DataType.CATEGORICAL, DataType.UNIQUE, DataType.BOOLEAN, ) - + colorscale: Union[List[Tuple[float, str]], str] if is_categorical_color: categories = df[color_col].unique() colorscale = get_default_discrete_colorscale(n_colors=len(categories)) @@ -669,8 +671,6 @@ def parallel_coordinates( "cmax": len(categories) - 0.5, } ) - else: - line = None # Add numeric columns to dimensions dimensions = [{"label": col_name, "values": df[col_name]} for col_name in numeric_columns] # Add categorical columns to dimensions @@ -818,12 +818,15 @@ def parallel_categories( columns = [col for col in columns if col not in hide_columns] if drop_na: df = df.dropna() + + line: Optional[Dict[str, Any]] = None if color_col is not None: categorical_color = infer_data_type(df[color_col]) in ( DataType.CATEGORICAL, DataType.UNIQUE, DataType.BOOLEAN, ) + colorscale: Union[List[Tuple[float, str]], str] if categorical_color: categories = df[color_col].unique() colorscale = get_default_discrete_colorscale(n_colors=len(categories)) @@ -833,14 +836,15 @@ def parallel_categories( color_series = df[color_col] colorscale = "Bluered_r" + colorbar: Dict[str, Any] = {"title": color_col} line = { "color": color_series, "colorscale": colorscale, - "colorbar": {"title": color_col}, + "colorbar": colorbar, } if categorical_color: - line["colorbar"].update( + colorbar.update( { "tickvals": color_series.unique(), "ticktext": categories, @@ -855,8 +859,6 @@ def parallel_categories( "cmax": len(categories) - 0.5, } ) - else: - line = None dimensions = [go.parcats.Dimension(values=df[col_name], label=col_name) for col_name in columns] diff --git a/edvart/report_sections/table_of_contents.py b/edvart/report_sections/table_of_contents.py index 198dcd5..b41117d 100644 --- a/edvart/report_sections/table_of_contents.py +++ b/edvart/report_sections/table_of_contents.py @@ -94,7 +94,7 @@ def show(self, sections: List[Section]) -> None: """ display(Markdown(self._title)) - lines = [] + lines: List[str] = [] for section in sections: self._add_section_lines(section, 1, lines, self._include_subsections) display(Markdown("\n".join(lines))) diff --git a/edvart/report_sections/timeseries_analysis/boxplots_over_time.py b/edvart/report_sections/timeseries_analysis/boxplots_over_time.py index 61f756d..5b41d38 100644 --- a/edvart/report_sections/timeseries_analysis/boxplots_over_time.py +++ b/edvart/report_sections/timeseries_analysis/boxplots_over_time.py @@ -48,7 +48,7 @@ def __init__( self, verbosity: Verbosity = Verbosity.LOW, columns: Optional[List[str]] = None, - grouping_function: Callable[[Any], str] = None, + grouping_function: Optional[Callable[[Any], str]] = None, grouping_function_imports: Optional[List[str]] = None, grouping_name: Optional[str] = None, default_nunique_max: int = 80, @@ -161,7 +161,7 @@ def show(self, df: pd.DataFrame) -> None: ) -def default_grouping_functions() -> Dict[str, Callable[[datetime], str]]: +def default_grouping_functions() -> Dict[str, Callable[[pd.Timestamp], str]]: """Return a dictionary of function names and functions. The function takes a pandas datetime and represents it as a rougher (in terms of time) @@ -170,7 +170,7 @@ def default_grouping_functions() -> Dict[str, Callable[[datetime], str]]: Returns ------- - Dict[str, Callable[[datetime], str]] + Dict[str, Callable[[pandas.Timestamp], str]] Dictionary from grouping function names to grouping functions. """ return { @@ -217,7 +217,7 @@ def get_default_grouping_func(df: pd.DataFrame, nunique_max: int = 80) -> Tuple[ def show_boxplots_over_time( df: pd.DataFrame, columns: Optional[List[str]] = None, - grouping_function: Callable[[Any], str] = None, + grouping_function: Optional[Callable[[Any], str]] = None, grouping_name: Optional[str] = None, default_nunique_max: int = 80, figsize: Tuple[float, float] = (20, 7), @@ -264,7 +264,7 @@ def show_boxplots_over_time( grouping_name, grouping_function = get_default_grouping_func( df, nunique_max=default_nunique_max ) - elif default_grouping_funcs.get(grouping_name) is not None: + elif grouping_name is not None and default_grouping_funcs.get(grouping_name) is not None: grouping_function = default_grouping_funcs[grouping_name] if columns is None: diff --git a/edvart/report_sections/timeseries_analysis/fourier_transform.py b/edvart/report_sections/timeseries_analysis/fourier_transform.py index f28df76..f625ec7 100644 --- a/edvart/report_sections/timeseries_analysis/fourier_transform.py +++ b/edvart/report_sections/timeseries_analysis/fourier_transform.py @@ -145,7 +145,7 @@ def show_fourier_transform( for col in columns: if not is_numeric(df[col]): raise ValueError(f"Cannot perform Fourier transform for non-numeric column `{col}`") - index_freq = pd.infer_freq(df.index) or "" + index_freq = pd.infer_freq(df.index.to_series()) or "" for col in columns: # FFT requires samples at regular intervals df_col = df[col].interpolate(method="time") diff --git a/edvart/report_sections/timeseries_analysis/rolling_statistics.py b/edvart/report_sections/timeseries_analysis/rolling_statistics.py index 32cb65a..5896ace 100644 --- a/edvart/report_sections/timeseries_analysis/rolling_statistics.py +++ b/edvart/report_sections/timeseries_analysis/rolling_statistics.py @@ -161,8 +161,8 @@ def show_rolling_statistics( index = df.index[window_size - 1 :] layout = dict(xaxis_rangeslider_visible=True) - - data = [] + import plotly + data: List[List[plotly.basedatatypes.BaseTraceType]] = [] for col in columns: data.append([]) if show_std_dev: diff --git a/edvart/report_sections/timeseries_analysis/seasonal_decomposition.py b/edvart/report_sections/timeseries_analysis/seasonal_decomposition.py index 40a25e8..991939f 100644 --- a/edvart/report_sections/timeseries_analysis/seasonal_decomposition.py +++ b/edvart/report_sections/timeseries_analysis/seasonal_decomposition.py @@ -149,7 +149,7 @@ def show_seasonal_decomposition( If the input data is not indexed by time in ascending order. """ df = df.interpolate(method="time") - if pd.infer_freq(df.index) is None and period is None: + if pd.infer_freq(df.index.to_series()) is None and period is None: display( Markdown( "
" diff --git a/edvart/report_sections/timeseries_analysis/short_time_ft.py b/edvart/report_sections/timeseries_analysis/short_time_ft.py index 602c5d2..4dbca4d 100644 --- a/edvart/report_sections/timeseries_analysis/short_time_ft.py +++ b/edvart/report_sections/timeseries_analysis/short_time_ft.py @@ -128,7 +128,7 @@ def show_short_time_ft( columns: Optional[List[str]] = None, overlap: Optional[int] = None, log: bool = True, - window: Union[str, Tuple, "array-like"] = "hamming", + window: Union[str, Tuple, np.typing.ArrayLike] = "hamming", scaling: str = "spectrum", figsize: Tuple[float, float] = (20, 7), colormap: Any = "viridis", @@ -185,7 +185,7 @@ def show_short_time_ft( for col in columns: if not is_numeric(df[col]): raise ValueError(f"Cannot perform STFT for non-numeric column {col}") - index_freq = pd.infer_freq(df.index) or "" + index_freq = pd.infer_freq(df.index.to_series()) or "" for col in columns: display(Markdown(f"---\n### {col}")) freqs, times, sx = signal.spectrogram( # pylint: disable=invalid-name diff --git a/edvart/report_sections/timeseries_analysis/stationarity_tests.py b/edvart/report_sections/timeseries_analysis/stationarity_tests.py index a073044..aef514e 100644 --- a/edvart/report_sections/timeseries_analysis/stationarity_tests.py +++ b/edvart/report_sections/timeseries_analysis/stationarity_tests.py @@ -96,7 +96,7 @@ def show(self, df: pd.DataFrame) -> None: show_stationarity_tests(df=df, columns=self.columns) -def default_stationarity_tests() -> Dict[pd.Series, Callable[[pd.Series], "test_result"]]: +def default_stationarity_tests() -> Dict[str, Callable[[pd.Series], Any]]: """Return a dictionary of stationarity test and functions. Stationarity tests are: diff --git a/edvart/report_sections/univariate_analysis.py b/edvart/report_sections/univariate_analysis.py index 39cfe82..fec7746 100644 --- a/edvart/report_sections/univariate_analysis.py +++ b/edvart/report_sections/univariate_analysis.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import matplotlib.pyplot as plt import nbformat.v4 as nbfv4 @@ -243,7 +243,7 @@ def default_quantile_statistics(): def histogram( series: pd.Series, - bins: Optional[int] = None, + bins: Optional[Union[int, str, np.ndarray]] = None, density: bool = False, box_plot: bool = True, figsize: Tuple[float, float] = (20, 7), @@ -256,8 +256,10 @@ def histogram( ---------- series : pd.Series Numerical series. - bins : int, optional - Number of bins of the histogram. If None, the number of bins is inferred. + bins : int or str or array_like, optional + If bins is an int, it defines the number of equal-width bins in the range of the series. + If bins is a string, it defines the method used to calculate the optimal bin width. + If bins is an array, it defines the bin edges. density : bool (default = False) If True, the area of the histogram bars will sum up to 1. box_plot : bool (default = True) @@ -285,7 +287,7 @@ def histogram( bins = "sturges" else: bins = bin_edges - + assert bins is not None if box_plot: _fig, (ax_box, ax_hist) = plt.subplots( nrows=2, @@ -295,7 +297,7 @@ def histogram( ) sns.boxplot(x=series, ax=ax_box, **boxplot_kwargs) sns.histplot( - data=series, + data=series.to_frame(), bins=bins, stat="density" if density else "count", ax=ax_hist, @@ -306,7 +308,7 @@ def histogram( else: plt.figure(figsize=figsize) sns.histplot( - data=series, + data=series.to_frame(), bins=bins, stat="density" if density else "count", kde=False, @@ -320,7 +322,7 @@ def bar_plot( relative_count: bool = False, figsize: Tuple[float, float] = (20, 7), plotting_threshold: int = 50, - **bar_plot_args: Dict[str, Any], + **bar_plot_args: Any, ) -> None: """Plots a bar plot visualizing frequencies of series elements. @@ -335,7 +337,7 @@ def bar_plot( plotting_threshold : int If the number of unique values in the series is greater than this, no plot is created instead a warning is issued. - bar_plot_args : Dict[str, Any] + bar_plot_args : Any Additional kwargs passed to pandas.Series.bar. """ if series.nunique() > plotting_threshold: @@ -406,7 +408,7 @@ def top_most_frequent(series: pd.Series, n_top: int = 5) -> None: n_top : int The number of most frequent values to include in the table. """ - frequent_values = dict_to_html(utils.top_frequent_values(series, n_top=n_top)) + frequent_values = dict_to_html(dict(utils.top_frequent_values(series, n_top=n_top))) frequent_values_html = add_html_heading(frequent_values, "Most frequent values", 3) display(HTML(frequent_values_html)) diff --git a/edvart/utils.py b/edvart/utils.py index 63fdab8..11fd9b2 100755 --- a/edvart/utils.py +++ b/edvart/utils.py @@ -1,6 +1,6 @@ import os from contextlib import contextmanager -from typing import Any, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, Iterable, Iterator, List, Literal, Mapping, Optional, Tuple, Union import pandas as pd import plotly @@ -10,7 +10,7 @@ from edvart.data_types import is_numeric -def top_frequent_values(series: pd.Series, n_top: int = 10) -> Dict[Any, float]: +def top_frequent_values(series: pd.Series, n_top: int = 10) -> Mapping[str, Any]: """ Counts top n most frequent values in series along with other value counts and NULL value counts. @@ -30,7 +30,7 @@ def top_frequent_values(series: pd.Series, n_top: int = 10) -> Dict[Any, float]: # Calculate frequencies counts = series.value_counts() nan_count = series.isna().sum() - result_dict = { + result_dict: Dict[str, Any] = { **(counts[:n_top].to_dict()), "Other values count": counts[n_top:].sum(), "Null": nan_count, @@ -446,7 +446,7 @@ def mode(series: pd.Series) -> float: most_frequent = series.mode(dropna=True) if len(most_frequent) == 0: return float("nan") - return most_frequent[0] + return float(most_frequent[0]) def std(series: pd.Series) -> float: @@ -622,4 +622,5 @@ def env_var(name: str, value: str) -> Iterator[None]: try: yield finally: - os.environ = original_env + os.environ.clear() + os.environ.update(original_env) diff --git a/pyproject.toml b/pyproject.toml index bdb9d92..bdb8f90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,13 @@ black = "^22.3.0" pylint = "~3.1" sphinx-copybutton = "^0.5.2" pytest-xdist = "^3.3.1" +pandas-stubs = "^2.2.0" +mypy = "^1.8.0" [build-system] requires = ["poetry_core>=1.0.0"] build-backend = "poetry.core.masonry.api" + +[tool.mypy] +ignore_missing_imports = true +python_version = "3.9"