Skip to content

Commit

Permalink
better tests (#5)
Browse files Browse the repository at this point in the history
* added tests for plot_learning_curve

* added tests for plot_confusion_matrix

* added tests for plot_roc_curve

* added tests for plot_ks_statistic

* added tests for plot_precision_recall_curve

* added tests for plot_feature_importances

* added tests for plot_silhouette

* added tests for plot_elbow_curve
  • Loading branch information
reiinakano authored Feb 17, 2017
1 parent 6bb4b3b commit 5ce205a
Show file tree
Hide file tree
Showing 4 changed files with 390 additions and 2 deletions.
4 changes: 4 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ python:
- "3.5"
- "3.6"
# command to install dependencies
before_script: # configure a headless display to test plot generation
- "export DISPLAY=:99.0"
- "sh -e /etc/init.d/xvfb start"
- sleep 3 # give xvfb some time to start
install:
- pip install --upgrade pip setuptools wheel
- pip install --only-binary=numpy,scipy numpy scipy
Expand Down
4 changes: 2 additions & 2 deletions scikitplot/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def clustering_factory(clf):
return clf


def plot_silhouette(clf, X, title='Silhouette Analysis', metric='euclidean', copy=False, ax=None):
def plot_silhouette(clf, X, title='Silhouette Analysis', metric='euclidean', copy=True, ax=None):
"""Plots silhouette analysis of clusters using fit_predict.
Args:
Expand Down Expand Up @@ -147,7 +147,7 @@ def plot_elbow_curve(clf, X, title='Elbow Plot', cluster_ranges=None, ax=None):
title (string, optional): Title of the generated plot. Defaults to "Elbow Plot"
cluster_ranges (None or :obj:`list` of int, optional): List of n_clusters for which
to plot the explained variances. Defaults to ``range(0, 11, 2)``.
to plot the explained variances. Defaults to ``range(1, 12, 2)``.
copy (boolean, optional): Determines whether ``fit`` is used on **clf** or on a
copy of **clf**.
Expand Down
305 changes: 305 additions & 0 deletions scikitplot/tests/test_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
import unittest
import scikitplot
import warnings
from sklearn.datasets import load_iris as load_data
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import NotFittedError
import numpy as np
import matplotlib.pyplot as plt


class TestClassifierFactory(unittest.TestCase):
Expand Down Expand Up @@ -82,5 +89,303 @@ def test_method_insertion(self):
'Overriding anyway. This may ' \
'result in unintended behavior.' in str(warning.message)


class TestPlotLearningCurve(unittest.TestCase):
    """Smoke tests for the ``plot_learning_curve`` method that
    ``scikitplot.classifier_factory`` attaches to a classifier.

    Most tests only check that the call completes without raising;
    ``test_ax`` additionally checks which Axes object is returned.
    """

    def setUp(self):
        """Load the dataset and shuffle it with a fixed seed."""
        np.random.seed(0)
        self.X, self.y = load_data(return_X_y=True)
        # Apply one shared permutation so X and y stay aligned.
        order = np.random.permutation(len(self.X))
        self.X, self.y = self.X[order], self.y[order]

    def tearDown(self):
        """Close every figure opened by the test to avoid leaking them."""
        plt.close("all")

    def test_cv(self):
        """Plot with the default ``cv`` and with an explicit ``cv=5``."""
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        clf.plot_learning_curve(self.X, self.y)
        clf.plot_learning_curve(self.X, self.y, cv=5)

    def test_train_sizes(self):
        """Plot with a custom ``train_sizes`` array."""
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        clf.plot_learning_curve(self.X, self.y,
                                train_sizes=np.linspace(0.1, 1.0, 8))

    def test_n_jobs(self):
        """Plot with ``n_jobs=-1``."""
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        clf.plot_learning_curve(self.X, self.y, n_jobs=-1)

    def test_ax(self):
        """A passed-in ``ax`` is returned as-is; otherwise a different
        Axes object is returned."""
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        _, ax = plt.subplots(1, 1)
        # unittest assertions are not stripped under ``python -O`` and
        # report both operands on failure, unlike a bare ``assert``.
        out_ax = clf.plot_learning_curve(self.X, self.y)
        self.assertIsNot(ax, out_ax)
        out_ax = clf.plot_learning_curve(self.X, self.y, ax=ax)
        self.assertIs(ax, out_ax)


class TestPlotConfusionMatrix(unittest.TestCase):
    """Smoke tests for the ``plot_confusion_matrix`` method that
    ``scikitplot.classifier_factory`` attaches to a classifier."""

    def setUp(self):
        """Load the dataset and shuffle it with a fixed seed."""
        np.random.seed(0)
        self.X, self.y = load_data(return_X_y=True)
        # Apply one shared permutation so X and y stay aligned.
        order = np.random.permutation(len(self.X))
        self.X, self.y = self.X[order], self.y[order]

    def tearDown(self):
        """Close every figure opened by the test to avoid leaking them."""
        plt.close("all")

    def test_cv(self):
        """Plot with the default ``cv`` and with an explicit ``cv=5``."""
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        clf.plot_confusion_matrix(self.X, self.y)
        clf.plot_confusion_matrix(self.X, self.y, cv=5)

    def test_normalize(self):
        """Plot with ``normalize`` switched on and off."""
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        clf.plot_confusion_matrix(self.X, self.y, normalize=True)
        clf.plot_confusion_matrix(self.X, self.y, normalize=False)

    def test_do_cv(self):
        """With ``do_cv=False`` a classifier that was never fitted
        directly must raise ``NotFittedError``."""
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        # The default call must succeed; the test expects that it leaves
        # ``clf`` itself unfitted, so the second call raises.
        clf.plot_confusion_matrix(self.X, self.y)
        self.assertRaises(NotFittedError, clf.plot_confusion_matrix,
                          self.X, self.y, do_cv=False)

    def test_shuffle(self):
        """Plot with ``shuffle`` switched on and off."""
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        clf.plot_confusion_matrix(self.X, self.y, shuffle=True)
        clf.plot_confusion_matrix(self.X, self.y, shuffle=False)

    def test_ax(self):
        """A passed-in ``ax`` is returned as-is; otherwise a different
        Axes object is returned."""
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        _, ax = plt.subplots(1, 1)
        # unittest assertions are not stripped under ``python -O`` and
        # report both operands on failure, unlike a bare ``assert``.
        out_ax = clf.plot_confusion_matrix(self.X, self.y)
        self.assertIsNot(ax, out_ax)
        out_ax = clf.plot_confusion_matrix(self.X, self.y, ax=ax)
        self.assertIs(ax, out_ax)


class TestPlotROCCurve(unittest.TestCase):
    """Smoke tests for the ``plot_roc_curve`` method that
    ``scikitplot.classifier_factory`` attaches to a classifier."""

    def setUp(self):
        """Load the dataset and shuffle it with a fixed seed."""
        np.random.seed(0)
        self.X, self.y = load_data(return_X_y=True)
        # Apply one shared permutation so X and y stay aligned.
        order = np.random.permutation(len(self.X))
        self.X, self.y = self.X[order], self.y[order]

    def tearDown(self):
        """Close every figure opened by the test to avoid leaking them."""
        plt.close("all")

    def test_predict_proba(self):
        """A classifier without ``predict_proba`` must raise
        ``TypeError`` when plotting."""
        np.random.seed(0)

        class DummyClassifier:
            """Stub with ``fit``/``predict``/``score`` but no
            ``predict_proba``."""

            def __init__(self):
                pass

            def fit(self):
                pass

            def predict(self):
                pass

            def score(self):
                pass

        clf = DummyClassifier()
        scikitplot.classifier_factory(clf)
        self.assertRaises(TypeError, clf.plot_roc_curve, self.X, self.y)

    def test_do_split(self):
        """With ``do_split=False`` the second call is expected to raise
        ``AttributeError``."""
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        # The default call must succeed first.
        # NOTE(review): presumably it leaves ``clf`` itself unfitted,
        # which is why the second call raises - confirm against
        # plot_roc_curve.
        clf.plot_roc_curve(self.X, self.y)
        self.assertRaises(AttributeError, clf.plot_roc_curve,
                          self.X, self.y, do_split=False)

    def test_ax(self):
        """A passed-in ``ax`` is returned as-is; otherwise a different
        Axes object is returned."""
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        _, ax = plt.subplots(1, 1)
        # unittest assertions are not stripped under ``python -O`` and
        # report both operands on failure, unlike a bare ``assert``.
        out_ax = clf.plot_roc_curve(self.X, self.y)
        self.assertIsNot(ax, out_ax)
        out_ax = clf.plot_roc_curve(self.X, self.y, ax=ax)
        self.assertIs(ax, out_ax)


class TestPlotKSStatistic(unittest.TestCase):
    """Smoke tests for the ``plot_ks_statistic`` method that
    ``scikitplot.classifier_factory`` attaches to a classifier.

    Unlike the sibling test classes, the fixture here is the breast
    cancer dataset rather than ``load_data``.
    """

    def setUp(self):
        """Load the breast cancer data and shuffle it with a fixed seed."""
        np.random.seed(0)
        self.X, self.y = load_breast_cancer(return_X_y=True)
        # Apply one shared permutation so X and y stay aligned.
        order = np.random.permutation(len(self.X))
        self.X, self.y = self.X[order], self.y[order]

    def tearDown(self):
        """Close every figure opened by the test to avoid leaking them."""
        plt.close("all")

    def test_predict_proba(self):
        """A classifier without ``predict_proba`` must raise
        ``TypeError`` when plotting."""
        np.random.seed(0)

        class DummyClassifier:
            """Stub with ``fit``/``predict``/``score`` but no
            ``predict_proba``."""

            def __init__(self):
                pass

            def fit(self):
                pass

            def predict(self):
                pass

            def score(self):
                pass

        clf = DummyClassifier()
        scikitplot.classifier_factory(clf)
        self.assertRaises(TypeError, clf.plot_ks_statistic, self.X, self.y)

    def test_two_classes(self):
        """Data loaded via ``load_data`` must be rejected with
        ``ValueError``."""
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        # NOTE(review): presumably rejected because this dataset has
        # more than two classes - confirm against plot_ks_statistic.
        X, y = load_data(return_X_y=True)
        self.assertRaises(ValueError, clf.plot_ks_statistic, X, y)

    def test_do_split(self):
        """With ``do_split=False`` the second call is expected to raise
        ``AttributeError``."""
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        # The default call must succeed first.
        # NOTE(review): presumably it leaves ``clf`` itself unfitted,
        # which is why the second call raises - confirm against
        # plot_ks_statistic.
        clf.plot_ks_statistic(self.X, self.y)
        self.assertRaises(AttributeError, clf.plot_ks_statistic,
                          self.X, self.y, do_split=False)

    def test_ax(self):
        """A passed-in ``ax`` is returned as-is; otherwise a different
        Axes object is returned."""
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        _, ax = plt.subplots(1, 1)
        # unittest assertions are not stripped under ``python -O`` and
        # report both operands on failure, unlike a bare ``assert``.
        out_ax = clf.plot_ks_statistic(self.X, self.y)
        self.assertIsNot(ax, out_ax)
        out_ax = clf.plot_ks_statistic(self.X, self.y, ax=ax)
        self.assertIs(ax, out_ax)


class TestPlotPrecisionRecall(unittest.TestCase):
    """Smoke tests for the ``plot_precision_recall_curve`` method that
    ``scikitplot.classifier_factory`` attaches to a classifier."""

    def setUp(self):
        """Load the dataset and shuffle it with a fixed seed."""
        np.random.seed(0)
        self.X, self.y = load_data(return_X_y=True)
        # Apply one shared permutation so X and y stay aligned.
        order = np.random.permutation(len(self.X))
        self.X, self.y = self.X[order], self.y[order]

    def tearDown(self):
        """Close every figure opened by the test to avoid leaking them."""
        plt.close("all")

    def test_predict_proba(self):
        """A classifier without ``predict_proba`` must raise
        ``TypeError`` when plotting."""
        np.random.seed(0)

        class DummyClassifier:
            """Stub with ``fit``/``predict``/``score`` but no
            ``predict_proba``."""

            def __init__(self):
                pass

            def fit(self):
                pass

            def predict(self):
                pass

            def score(self):
                pass

        clf = DummyClassifier()
        scikitplot.classifier_factory(clf)
        self.assertRaises(TypeError, clf.plot_precision_recall_curve,
                          self.X, self.y)

    def test_do_split(self):
        """With ``do_split=False`` the second call is expected to raise
        ``AttributeError``."""
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        # The default call must succeed first.
        # NOTE(review): presumably it leaves ``clf`` itself unfitted,
        # which is why the second call raises - confirm against
        # plot_precision_recall_curve.
        clf.plot_precision_recall_curve(self.X, self.y)
        self.assertRaises(AttributeError, clf.plot_precision_recall_curve,
                          self.X, self.y, do_split=False)

    def test_ax(self):
        """A passed-in ``ax`` is returned as-is; otherwise a different
        Axes object is returned."""
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        _, ax = plt.subplots(1, 1)
        # unittest assertions are not stripped under ``python -O`` and
        # report both operands on failure, unlike a bare ``assert``.
        out_ax = clf.plot_precision_recall_curve(self.X, self.y)
        self.assertIsNot(ax, out_ax)
        out_ax = clf.plot_precision_recall_curve(self.X, self.y, ax=ax)
        self.assertIs(ax, out_ax)


class TestFeatureImportances(unittest.TestCase):
    """Smoke tests for the ``plot_feature_importances`` method that
    ``scikitplot.classifier_factory`` attaches to a classifier.

    Each test fits the classifier itself before plotting.
    """

    def setUp(self):
        """Load the dataset and shuffle it with a fixed seed."""
        np.random.seed(0)
        self.X, self.y = load_data(return_X_y=True)
        # Apply one shared permutation so X and y stay aligned.
        order = np.random.permutation(len(self.X))
        self.X, self.y = self.X[order], self.y[order]

    def tearDown(self):
        """Close every figure opened by the test to avoid leaking them."""
        plt.close("all")

    def test_feature_importances_in_clf(self):
        """Plotting on a fitted ``LogisticRegression`` must raise
        ``TypeError``."""
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        clf.fit(self.X, self.y)
        # NOTE(review): presumably raised because this estimator has no
        # ``feature_importances_`` attribute - confirm against
        # plot_feature_importances.
        self.assertRaises(TypeError, clf.plot_feature_importances)

    def test_feature_names(self):
        """Plot with explicit ``feature_names``."""
        np.random.seed(0)
        clf = RandomForestClassifier()
        scikitplot.classifier_factory(clf)
        clf.fit(self.X, self.y)
        clf.plot_feature_importances(feature_names=["a", "b", "c", "d"])

    def test_max_num_features(self):
        """Plot with several ``max_num_features`` values."""
        np.random.seed(0)
        clf = RandomForestClassifier()
        scikitplot.classifier_factory(clf)
        clf.fit(self.X, self.y)
        # NOTE(review): 6 presumably exceeds the number of features in
        # the dataset, exercising the "more than available" case.
        for limit in (2, 4, 6):
            clf.plot_feature_importances(max_num_features=limit)

    def test_order(self):
        """Plot with each supported ``order`` value, including ``None``."""
        np.random.seed(0)
        clf = RandomForestClassifier()
        scikitplot.classifier_factory(clf)
        clf.fit(self.X, self.y)
        for order in ('ascending', 'descending', None):
            clf.plot_feature_importances(order=order)

    def test_ax(self):
        """A passed-in ``ax`` is returned as-is; otherwise a different
        Axes object is returned."""
        np.random.seed(0)
        clf = RandomForestClassifier()
        scikitplot.classifier_factory(clf)
        clf.fit(self.X, self.y)
        _, ax = plt.subplots(1, 1)
        # unittest assertions are not stripped under ``python -O`` and
        # report both operands on failure, unlike a bare ``assert``.
        out_ax = clf.plot_feature_importances()
        self.assertIsNot(ax, out_ax)
        out_ax = clf.plot_feature_importances(ax=ax)
        self.assertIs(ax, out_ax)


# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
Loading

0 comments on commit 5ce205a

Please sign in to comment.