Skip to content

Commit

Permalink
Bug fix and several usability improvements.
Browse files Browse the repository at this point in the history
- Fix a bug in auto_simplify method.
- Remove a warning.
- Add an option in plot to not display sub-regressions.
  • Loading branch information
Ezibenroc committed Jun 29, 2018
1 parent 9166bbc commit c670a74
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 13 deletions.
10 changes: 5 additions & 5 deletions advanced_features.ipynb

Large diffs are not rendered by default.

11 changes: 4 additions & 7 deletions pytree/reg.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from collections import namedtuple, Counter
import itertools
import math
import sys
from abc import ABC, abstractmethod
from copy import deepcopy
from decimal import Decimal, InvalidOperation
Expand All @@ -15,8 +14,6 @@
import statsmodels.formula.api as statsmodels
except ImportError:
statsmodels = None
sys.stderr.write(
'WARNING: no module statsmodels, the tree will not be simplified.')
try:
import graphviz # https://github.com/xflr6/graphviz
except ImportError:
Expand Down Expand Up @@ -262,16 +259,16 @@ def __show_plot(self, log, log_x, log_y):
plt.yscale('log')
plt.show()

def plot_dataset(self, log=False, log_x=False, log_y=False, alpha=0.5):
def plot_dataset(self, log=False, log_x=False, log_y=False, alpha=0.5, plot_merged_reg=False):
data = list(self)
x = [d[0] for d in data]
y = [d[1] for d in data]
plt.figure(figsize=(20, 20))
plt.subplot(2, 1, 1)
plt.plot(x, y, 'o', color='blue', alpha=alpha)
if len(self.breakpoints) > 0:
if len(self.breakpoints) > 0 and plot_merged_reg:
self.merge().__plot_reg('black', log=log or log_x)
if isinstance(self, Node):
if isinstance(self, Node) and plot_merged_reg:
xl, yl = zip(*self.left)
xr, yr = zip(*reversed(list(self.right)))
Node(Leaf(xl, yl, self.config), Leaf(xr, yr, self.config),
Expand Down Expand Up @@ -829,7 +826,7 @@ def auto_simplify(self):
min_reg = None
for res in result:
reg = res['regression']
if min_error > reg.error or reg.error_equal(min_error, self.error):
if min_error > reg.error or reg.error_equal(min_error, reg.error):
min_error = reg.error
min_reg = reg
return min_reg
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from setuptools import setup
import subprocess

VERSION = '0.0.3'
VERSION = '0.0.4'


class CommandError(Exception):
Expand Down
5 changes: 5 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ def test_plot_dataset(self, mock_show):
reg.plot_dataset(log=True)
reg.plot_dataset(log_x=True)
reg.plot_dataset(log_y=True)
reg.plot_dataset(plot_merged_reg=True)

@mock.patch("matplotlib.pyplot.show")
def test_plot_error(self, mock_show):
Expand Down Expand Up @@ -405,6 +406,7 @@ def test_plot_dataset(self, mock_show):
reg.plot_dataset(log=True)
reg.plot_dataset(log_x=True)
reg.plot_dataset(log_y=True)
reg.plot_dataset(plot_merged_reg=True)

def generic_multiplesplits_simplify(self, cls, repeat):
self.maxDiff = None
Expand Down Expand Up @@ -432,6 +434,9 @@ def generic_multiplesplits_simplify(self, cls, repeat):
self.assertEqual(simple_reg.breakpoints, expected_reg.breakpoints)
self.assertEqual(simple_reg.RSS, expected_reg.RSS)
self.assertEqual(simple_reg.BIC, expected_reg.BIC)
# Checking that the auto_simplify() is a fix-point
new_reg = simple_reg.auto_simplify()
self.assertEqual(simple_reg.breakpoints, new_reg.breakpoints)

def test_multiple_splits_simplify(self):
self.generic_multiplesplits_simplify(float, 1)
Expand Down

0 comments on commit c670a74

Please sign in to comment.