Skip to content

Commit

Permalink
refactor: move init stats to own module
Browse files Browse the repository at this point in the history
  • Loading branch information
dynobo committed Jun 24, 2024
1 parent e1b4c01 commit f4e5484
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 91 deletions.
6 changes: 0 additions & 6 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,6 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install matplotlib default font
run: |
wget -O dejavu.zip http://sourceforge.net/projects/dejavu/files/dejavu/2.37/dejavu-fonts-ttf-2.37.zip
unzip -d dejavu/ dejavu.zip
mv dejavu /usr/share/fonts/
fc-cache -fv
- name: Install dependencies
run: pip install '.[dev]'
- name: Format
Expand Down
8 changes: 5 additions & 3 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Changelog

## 0.4.0 (upcoming)
## 0.4.0 (2024-06-24)

- Breaking change: Drop support for Python < 3.9.
- Improve performance
- Breaking changes:
- Drop support for Python < 3.9.
- Rename `lmdiag.info()` to `lmdiag.help()`
- Fix crash for `linearmodels` with 2+ degrees of freedom.
- Add support for scikit-learn's `LinearRegression`
- Improve performance
6 changes: 3 additions & 3 deletions benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import statsmodels.api as sma

import lmdiag
import lmdiag.statistics
import lmdiag.statistics.select

df = sma.datasets.get_rdataset("ames", "openintro").data
lm = sm.formula.api.ols("np.log10(price) ~ Q('Overall.Qual') + np.log(area)", df).fit()

lm_stats = lmdiag.statistics.init_stats(lm)
lm_stats = lmdiag.statistics.select.get_stats(lm)


if __name__ == "__main__":
Expand All @@ -22,7 +22,7 @@
"lm_stats.standard_residuals",
"lm_stats.cooks_d",
"lm_stats.leverage",
"lm_stats.params_count",
"lm_stats.parameter_count",
"lm_stats.sqrt_abs_residuals",
"lm_stats.normalized_quantiles",
"lmdiag.plot(lm)",
Expand Down
12 changes: 6 additions & 6 deletions lmdiag/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from statsmodels.nonparametric.smoothers_lowess import lowess

from lmdiag import style
from lmdiag.statistics import init_stats
from lmdiag.statistics.base import StatsBase
from lmdiag.statistics.select import get_stats

LOWESS_DELTA = 0.005
LOWESS_IT = 2
Expand Down Expand Up @@ -54,7 +54,7 @@ def resid_fit(
Returns:
Figure of the plot.
"""
lm_stats = lm if isinstance(lm, StatsBase) else init_stats(lm, x=x, y=y)
lm_stats = lm if isinstance(lm, StatsBase) else get_stats(lm, x=x, y=y)

fitted = lm_stats.fitted_values
residuals = lm_stats.residuals
Expand Down Expand Up @@ -92,7 +92,7 @@ def q_q(
Returns:
Figure of the plot.
"""
lm_stats = lm if isinstance(lm, StatsBase) else init_stats(lm, x=x, y=y)
lm_stats = lm if isinstance(lm, StatsBase) else get_stats(lm, x=x, y=y)

std_resid = lm_stats.standard_residuals
quantiles = lm_stats.normalized_quantiles
Expand Down Expand Up @@ -149,7 +149,7 @@ def scale_loc(
Returns:
Figure of the plot.
"""
lm_stats = lm if isinstance(lm, StatsBase) else init_stats(lm, x=x, y=y)
lm_stats = lm if isinstance(lm, StatsBase) else get_stats(lm, x=x, y=y)

fitted_vals = lm_stats.fitted_values
sqrt_abs_res = lm_stats.sqrt_abs_residuals
Expand Down Expand Up @@ -199,7 +199,7 @@ def resid_lev(
Returns:
Figure of the plot.
"""
lm_stats = lm if isinstance(lm, StatsBase) else init_stats(lm, x=x, y=y)
lm_stats = lm if isinstance(lm, StatsBase) else get_stats(lm, x=x, y=y)

std_resid = lm_stats.standard_residuals
cooks_d = lm_stats.cooks_d
Expand Down Expand Up @@ -265,7 +265,7 @@ def plot(
Returns:
Figure of the plot.
"""
lm_stats = init_stats(lm=lm, x=x, y=y)
lm_stats = get_stats(lm=lm, x=x, y=y)

fig, axs = plt.subplots(2, 2, **style.subplots)

Expand Down
72 changes: 0 additions & 72 deletions lmdiag/statistics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,72 +0,0 @@
import warnings
from typing import Any, Optional

import numpy as np
from statsmodels.genmod.generalized_linear_model import GLMResults
from statsmodels.regression.linear_model import RegressionResultsWrapper
from statsmodels.robust.robust_linear_model import RLMResults

from lmdiag.statistics.base import StatsBase

try:
import sklearn
except ImportError:
sklearn = None

try:
import linearmodels
except ImportError:
linearmodels = None


def _warn_x_y() -> None:
warnings.warn(
"`x` and `y` arguments are ignored for this model type. Do not pass them.",
stacklevel=3,
)


def _init_linearmodels_stats(lm: Any) -> StatsBase:
from lmdiag.statistics.linearmodels_stats import LinearmodelsStats

return LinearmodelsStats(lm)


def _init_sklearn_stats(lm: Any, x: np.ndarray, y: np.ndarray) -> StatsBase:
from lmdiag.statistics.sklearn_stats import SklearnStats

return SklearnStats(lm, x=x, y=y)


def _init_statsmodels_stats(lm: Any) -> StatsBase:
from lmdiag.statistics.statsmodels_stats import StatsmodelsStats

return StatsmodelsStats(lm)


def init_stats(
lm: Any, x: Optional[np.ndarray] = None, y: Optional[np.ndarray] = None
) -> StatsBase:
"""Gather statistics depending on linear model type."""
if isinstance(lm, (RegressionResultsWrapper, GLMResults, RLMResults)):
if x or y:
_warn_x_y()
model_stats = _init_statsmodels_stats(lm)

elif linearmodels and isinstance(lm, linearmodels.iv.results.OLSResults):
if x or y:
_warn_x_y()
model_stats = _init_linearmodels_stats(lm)

elif sklearn and isinstance(lm, sklearn.linear_model.LinearRegression):
if x is None or y is None:
raise ValueError("x and y args must be provided this model type!")
model_stats = _init_sklearn_stats(lm, x, y)

else:
raise TypeError(
"Model type not (yet) supported. Currently supported are linear "
"models from `statsmodels`, `linearmodels` and `sklearn` packages."
)

return model_stats
74 changes: 74 additions & 0 deletions lmdiag/statistics/select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import warnings
from typing import Any, Optional

import numpy as np
from statsmodels.genmod.generalized_linear_model import GLMResults
from statsmodels.regression.linear_model import RegressionResultsWrapper
from statsmodels.robust.robust_linear_model import RLMResults

from lmdiag.statistics.base import StatsBase

try:
import sklearn
except ImportError:
sklearn = None

try:
import linearmodels
except ImportError:
linearmodels = None


def _warn_x_y() -> None:
warnings.warn(
"`x` and `y` arguments are ignored for this model type. Do not pass them.",
stacklevel=3,
)


def _init_linearmodels_stats(lm: Any) -> StatsBase:
from lmdiag.statistics.linearmodels_stats import LinearmodelsStats

return LinearmodelsStats(lm)


def _init_sklearn_stats(lm: Any, x: np.ndarray, y: np.ndarray) -> StatsBase:
from lmdiag.statistics.sklearn_stats import SklearnStats

return SklearnStats(lm, x=x, y=y)


def _init_statsmodels_stats(lm: Any) -> StatsBase:
from lmdiag.statistics.statsmodels_stats import StatsmodelsStats

return StatsmodelsStats(lm)


def get_stats(
lm: Any, x: Optional[np.ndarray] = None, y: Optional[np.ndarray] = None
) -> StatsBase:
"""Gather statistics depending on linear model type."""
if isinstance(lm, (RegressionResultsWrapper, GLMResults, RLMResults)):
if x or y:
_warn_x_y()
model_stats = _init_statsmodels_stats(lm)

elif linearmodels and isinstance(
lm, (linearmodels.iv.results.OLSResults, linearmodels.iv.results.IVResults)
):
if x or y:
_warn_x_y()
model_stats = _init_linearmodels_stats(lm)

elif sklearn and isinstance(lm, sklearn.linear_model.LinearRegression):
if x is None or y is None:
raise ValueError("x and y args must be provided this model type!")
model_stats = _init_sklearn_stats(lm, x, y)

else:
raise TypeError(
"Model type not (yet) supported. Currently supported are linear "
"models from `statsmodels`, `linearmodels` and `sklearn` packages."
)

return model_stats
11 changes: 11 additions & 0 deletions lmdiag/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ class MplKwargs(TypedDict, total=False):


def use(style: str) -> None:
"""Set predefined style for plots.
Available styles:
- 'black_and_red' (mimics style of R's lm.diag)
Args:
style: Name of the preset style.
Raises:
ValueError: If style is unknown.
"""
if style == "black_and_red":
scatter.update(
{"marker": "o", "color": "none", "edgecolors": "black", "linewidth": 1}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "lmdiag"
version = "0.3.8"
version = "0.4.0"
description = "Diagnostic Plots for Lineare Regression Models. Similar to plot.lm in R."
keywords = [
"lm",
Expand Down

0 comments on commit f4e5484

Please sign in to comment.