Skip to content

Commit

Permalink
dask: naive bayes
Browse files Browse the repository at this point in the history
  • Loading branch information
noahnovsak committed Aug 22, 2023
1 parent 51beee7 commit 8eaab15
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 1 deletion.
18 changes: 17 additions & 1 deletion Orange/classification/naive_bayes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import numpy as np
import scipy.sparse as sp
import dask.array as da
from dask import delayed, compute

from Orange.classification import Learner, Model
from Orange.data import Instance, Storage, Table
from Orange.data.dask import DaskTable
from Orange.statistics import contingency
from Orange.preprocess import Discretize, RemoveNaNColumns

Expand Down Expand Up @@ -63,7 +66,7 @@ def __init__(self, log_cont_prob, class_prob, domain):
def predict_storage(self, data):
if isinstance(data, Instance):
data = Table.from_numpy(None, np.atleast_2d(data.x))
if type(data) is Table: # pylint: disable=unidiomatic-typecheck
if type(data) in (Table, DaskTable): # pylint: disable=unidiomatic-typecheck
return self.predict(data.X)

if not len(data) or not len(data[0]):
Expand All @@ -86,6 +89,8 @@ def predict(self, X):
if self.log_cont_prob is not None:
if sp.issparse(X):
self._sparse_probs(X, probs)
elif isinstance(X, da.Array):
probs = self._dask_probs(X, probs)

Check warning on line 93 in Orange/classification/naive_bayes.py

View check run for this annotation

Codecov / codecov/patch

Orange/classification/naive_bayes.py#L93

Added line #L93 was not covered by tests
else:
self._dense_probs(X, probs)
np.exp(probs, probs)
Expand All @@ -104,6 +109,17 @@ def _dense_probs(self, data, probs):
probs += probs0[col]
return probs

def _dask_probs(self, data, probs):
    """Sum the per-attribute log conditional probabilities for dask data.

    For every column of ``data`` paired with the matching entry of
    ``self.log_cont_prob``, the column's values are used as row indices
    into the transposed probability table, and the selected rows are
    summed over all columns.

    Parameters
    ----------
    data : dask array
        Its transpose is iterated column by column; values are cast to
        int and used as indices, NaN marks a missing value.
    probs : array
        Only ``probs.shape[1]`` is read (the width of the zero row used
        for missing values); it is returned unchanged when there are no
        attribute tables to apply.

    Returns
    -------
    The computed (non-lazy) sum of the selected rows, or ``probs`` itself
    when ``self.log_cont_prob`` is empty.
    """
    # All-zero row appended below each attribute's table; missing values
    # select it, so they contribute nothing to the sum.
    zeros = np.zeros((1, probs.shape[1]))

    @delayed
    def map_probs(col, attr_prob):
        # After stacking, row index ``attr_prob.shape[1]`` is the zero row.
        # Build the index with np.where rather than assigning into ``col``:
        # a delayed task should not mutate the array it was handed.
        missing = np.isnan(col)
        index = np.where(missing, attr_prob.shape[1], col).astype(int)
        return np.vstack((attr_prob.T, zeros))[index]

    tasks = [map_probs(col, attr_prob)
             for col, attr_prob in zip(data.T, self.log_cont_prob)]
    if not tasks:
        # Nothing to add; avoid returning the bare int 0 that sum() gives.
        return probs
    return compute(sum(tasks))[0]

Check warning on line 121 in Orange/classification/naive_bayes.py

View check run for this annotation

Codecov / codecov/patch

Orange/classification/naive_bayes.py#L121

Added line #L121 was not covered by tests

def _sparse_probs(self, data, probs):
n_vals = max(p.shape[1] for p in self.log_cont_prob) + 1
log_prob = np.zeros((len(self.log_cont_prob),
Expand Down
46 changes: 46 additions & 0 deletions Orange/data/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from Orange.data import Table, RowInstance
from Orange.data.table import _FromTableConversion, _ArrayConversion
from Orange.statistics.util import contingency


class DaskRowInstance(RowInstance):
Expand Down Expand Up @@ -264,6 +265,51 @@ def _filter_has_class(self, negate=False):
retain = np.logical_not(retain)
return self.from_table_rows(self, np.asarray(retain))

def _compute_contingency(self, col_vars=None, row_var=None):
    """Compute contingencies of discrete columns against a discrete row variable.

    Parameters
    ----------
    col_vars : iterable, optional
        Variables (anything ``self.domain.index`` accepts) to compute
        contingencies for; defaults to every entry of
        ``self.domain.variables``.
    row_var : optional
        The row variable; defaults to ``self.domain.class_var``.

    Returns
    -------
    list
        One computed contingency per column variable, in the order of
        ``col_vars``.

    Raises
    ------
    ValueError
        If no row variable is given and the domain has no class variable.
    TypeError
        If the row variable is not discrete.
    NotImplementedError
        If any column variable is not discrete.
    """
    domain = self.domain

    if row_var is None:
        row_var = domain.class_var
        if row_var is None:
            raise ValueError("No row variable")

    row_indi = domain.index(row_var)
    row_var = domain[row_indi]
    if not row_var.is_discrete:
        raise TypeError("Row variable must be discrete")

    if col_vars is None:
        col_indi = range(len(domain.variables))
    else:
        col_indi = [domain.index(var) for var in col_vars]
    col_vars = [domain[ind] for ind in col_indi]

    if not all(var.is_discrete for var in col_vars):
        raise NotImplementedError("Contingency can only be computed for categorical values.")

    # Lazy version of the contingency routine; everything is evaluated
    # together in the single dask.compute call at the end.
    delayed_contingency = dask.delayed(contingency)

    attributes, classes, metas = self.X, self._Y, self.metas
    n_atts = attributes.shape[1]
    max_y = len(row_var.values) - 1

    def locate(index):
        # Domain index convention used here: negative -> metas,
        # [0, n_atts) -> X, n_atts and above -> _Y.
        if index < 0:
            return metas, -1 - index
        if index < n_atts:
            return attributes, index
        return classes, index - n_atts

    contingencies = []
    for index, var in zip(col_indi, col_vars):
        source, position = locate(index)
        # A one-dimensional source is the column itself.
        column = source if source.ndim == 1 else source[:, position]
        contingencies.append(delayed_contingency(
            column.astype(float),
            self._get_column_view(row_indi),
            max_X=len(var.values) - 1,
            max_y=max_y,
            weights=self.W if self.has_weights() else None))

    return dask.compute(contingencies)[0]


def dask_stats(X, compute_variance=False):
is_numeric = np.issubdtype(X.dtype, np.number)
Expand Down

0 comments on commit 8eaab15

Please sign in to comment.