Commit

removed all unnecessary calls to compute()
ParticularMiner committed Oct 25, 2021
1 parent 1ac8e38 commit 1fa55ca
Showing 2 changed files with 47 additions and 45 deletions.
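A hedged usage sketch of what the change means for callers: fit_transform now builds a fully lazy graph, so row-chunk sizes stay unknown until explicitly computed. The toy corpus, partition count, and the trailing compute_chunk_sizes() call are illustrative assumptions based on the diff below, not taken from dask-ml documentation.

import dask.bag as db
from dask_ml.feature_extraction.text import CountVectorizer

corpus = [
    "This is the first document.",
    "This document is the second document.",
    "And this is the third one.",
]
documents = db.from_sequence(corpus, npartitions=2)

vectorizer = CountVectorizer()
X = vectorizer.fit_transform(documents)  # lazy dask.array.Array with scipy.sparse CSR chunks

# The eager call to result.compute_chunk_sizes() was removed in this commit, so
# row-chunk sizes remain unknown (NaN) until the caller asks for them.
X.compute_chunk_sizes()
term_counts = X.compute()  # concatenated scipy.sparse matrix of term counts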
2 changes: 0 additions & 2 deletions .gitignore
@@ -122,5 +122,3 @@ docs/source/auto_examples/
docs/source/examples/mydask.png

dask-worker-space
/.project
/.pydevproject
90 changes: 47 additions & 43 deletions dask_ml/feature_extraction/text.py
@@ -120,6 +120,18 @@ def _hasher(self):
return sklearn.feature_extraction.text.FeatureHasher


def _n_samples(X):
"""Count the number of samples in dask.array.Array X."""
def chunk_n_samples(chunk, axis, keepdims):
return np.array([chunk.shape[0]], dtype=np.int64)

return da.reduction(X,
chunk=chunk_n_samples,
aggregate=np.sum,
concatenate=False,
dtype=np.int64)


def _document_frequency(X, dtype):
"""Count the number of non-zero values for each feature in dask array X."""
def chunk_doc_freq(chunk, axis, keepdims):
@@ -133,7 +145,7 @@ def chunk_doc_freq(chunk, axis, keepdims):
aggregate=np.sum,
axis=0,
concatenate=False,
dtype=dtype).compute().astype(dtype)
dtype=dtype)


class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer):
@@ -203,17 +215,19 @@ class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer):
['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this']
"""

def fit_transform(self, raw_documents, y=None):
def get_params(self):
# Note that in general 'self' could refer to an instance of either this
# class or a subclass of this class. Hence it is possible that
# self.get_params() could get unexpected parameters of an instance of a
# subclass. Such parameters need to be excluded here:
subclass_instance_params = self.get_params()
subclass_instance_params = super().get_params()
excluded_keys = getattr(self, '_non_CountVectorizer_params', [])
params = {key: subclass_instance_params[key]
for key in subclass_instance_params
if key not in excluded_keys}
return {key: subclass_instance_params[key]
for key in subclass_instance_params
if key not in excluded_keys}

def fit_transform(self, raw_documents, y=None):
params = self.get_params()
vocabulary = params.pop("vocabulary")
vocabulary_for_transform = vocabulary

@@ -227,12 +241,12 @@ def fit_transform(self, raw_documents, y=None):
# Case 2: learn vocabulary from the data.
vocabularies = raw_documents.map_partitions(_build_vocabulary, params)
vocabulary = vocabulary_for_transform = (
_merge_vocabulary( *vocabularies.to_delayed() ))
_merge_vocabulary(*vocabularies.to_delayed()))
vocabulary_for_transform = vocabulary_for_transform.persist()
vocabulary_ = vocabulary.compute()
n_features = len(vocabulary_)

meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
meta = scipy.sparse.csr_matrix((0, n_features), dtype=self.dtype)
if isinstance(raw_documents, dd.Series):
result = raw_documents.map_partitions(
_count_vectorizer_transform, vocabulary_for_transform,
Expand All @@ -241,23 +255,14 @@ def fit_transform(self, raw_documents, y=None):
result = raw_documents.map_partitions(
_count_vectorizer_transform, vocabulary_for_transform, params)
result = build_array(result, n_features, meta)
result.compute_chunk_sizes()

self.vocabulary_ = vocabulary_
self.fixed_vocabulary_ = fixed_vocabulary

return result

def transform(self, raw_documents):
# Note that in general 'self' could refer to an instance of either this
# class or a subclass of this class. Hence it is possible that
# self.get_params() could get unexpected parameters of an instance of a
# subclass. Such parameters need to be excluded here:
subclass_instance_params = self.get_params()
excluded_keys = getattr(self, '_non_CountVectorizer_params', [])
params = {key: subclass_instance_params[key]
for key in subclass_instance_params
if key not in excluded_keys}
params = self.get_params()
vocabulary = params.pop("vocabulary")

if vocabulary is None:
@@ -271,14 +276,13 @@ def transform(self, raw_documents):
except ValueError:
vocabulary_for_transform = dask.delayed(vocabulary)
else:
(vocabulary_for_transform,) = client.scatter(
(vocabulary,), broadcast=True
)
(vocabulary_for_transform,) = client.scatter((vocabulary,),
broadcast=True)
else:
vocabulary_for_transform = vocabulary

n_features = vocabulary_length(vocabulary_for_transform)
meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
meta = scipy.sparse.csr_matrix((0, n_features), dtype=self.dtype)
if isinstance(raw_documents, dd.Series):
result = raw_documents.map_partitions(
_count_vectorizer_transform, vocabulary_for_transform,
Expand All @@ -287,7 +291,6 @@ def transform(self, raw_documents):
transformed = raw_documents.map_partitions(
_count_vectorizer_transform, vocabulary_for_transform, params)
result = build_array(transformed, n_features, meta)
result.compute_chunk_sizes()
return result

class TfidfTransformer(sklearn.feature_extraction.text.TfidfTransformer):
@@ -331,30 +334,23 @@ def fit(self, X, y=None):
X : sparse matrix of shape (n_samples, n_features)
A matrix of term/token counts.
"""
# X = check_array(X, accept_sparse=('csr', 'csc'))
# if not sp.issparse(X):
# X = sp.csr_matrix(X)
dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64

if self.use_idf:
n_samples, n_features = X.shape
def get_idf_diag(X, dtype):
n_samples = _n_samples(X) # X.shape[0] is not yet known
n_features = X.shape[1]
df = _document_frequency(X, dtype)
# df = df.astype(dtype, **_astype_copy_false(df))

# perform idf smoothing if required
df += int(self.smooth_idf)
n_samples += int(self.smooth_idf)

# log+1 instead of log makes sure terms with zero idf don't get
# suppressed entirely.
idf = np.log(n_samples / df) + 1
self._idf_diag = scipy.sparse.diags(
idf,
offsets=0,
shape=(n_features, n_features),
format="csr",
dtype=dtype,
)
return np.log(n_samples / df) + 1

dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64

if self.use_idf:
self._idf_diag = get_idf_diag(X, dtype)

return self

@@ -404,8 +400,17 @@ def _dot_idf_diag(chunk):
# idf_ being a property, the automatic attributes detection
# does not work as usual and we need to specify the attribute
# name:
check_is_fitted(self, attributes=["idf_"], msg="idf vector is not fitted")

check_is_fitted(self, attributes=["idf_"],
msg="idf vector is not fitted")
if dask.is_dask_collection(self._idf_diag):
_idf_diag = self._idf_diag.compute()
n_features = len(_idf_diag)
self._idf_diag = scipy.sparse.diags(
_idf_diag,
offsets=0,
shape=(n_features, n_features),
format="csr",
dtype=_idf_diag.dtype)
X = X.map_blocks(_dot_idf_diag, dtype=np.float64, meta=meta)

if self.norm:
@@ -619,8 +624,7 @@ def fit(self, raw_documents, y=None):
"""
self._check_params()
self._warn_for_unused_params()
X = super().fit_transform(raw_documents,
y=self._non_CountVectorizer_params)
X = super().fit_transform(raw_documents)
self._tfidf.fit(X)
return self
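
The recurring pattern behind these edits, pulled out as a self-contained sketch: per-chunk statistics are expressed with da.reduction so that nothing is materialized during fit, and the one remaining compute happens wherever the result is genuinely needed (here, when the idf diagonal is first built). The helper names and the dense-chunk fallback below are illustrative assumptions, not the module's exact code.

import numpy as np
import dask.array as da
import scipy.sparse


def n_samples_lazy(X):
    # Lazily count the rows of X; X.shape[0] may be unknown (NaN) for a dask array.
    def chunk_n_samples(chunk, axis, keepdims):
        return np.array([chunk.shape[0]], dtype=np.int64)

    return da.reduction(
        X,
        chunk=chunk_n_samples,
        aggregate=np.sum,
        concatenate=False,
        dtype=np.int64,
    )


def document_frequency_lazy(X, dtype=np.float64):
    # Lazily count non-zero entries per column (document frequency).
    def chunk_doc_freq(chunk, axis, keepdims):
        if scipy.sparse.issparse(chunk):
            return np.bincount(chunk.tocsr().indices, minlength=chunk.shape[1])
        return np.count_nonzero(chunk, axis=0)

    return da.reduction(
        X,
        chunk=chunk_doc_freq,
        aggregate=np.sum,
        axis=0,
        concatenate=False,
        dtype=dtype,
    )


# Neither helper triggers any work on its own. As in the new TfidfTransformer.transform,
# a single .compute() at the point of use (building the sparse idf diagonal) is the only
# materialization step left.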

Expand Down
