forked from univai-summerschool-2019/LogisticText
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path20ng.py
24 lines (22 loc) · 1012 Bytes
/
20ng.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import fetch_20newsgroups
# Only these two newsgroups are loaded, which makes the task a two-class problem.
categories = ['alt.atheism', 'talk.religion.misc']

# Training split of the 20 Newsgroups corpus, restricted to `categories`.
data = fetch_20newsgroups(categories=categories, subset='train')

# Text classification pipeline: raw token counts -> TF-IDF weighting ->
# logistic regression. The step names are the prefixes used in `grid`.
pipeline = Pipeline(
    steps=[
        ('vect', CountVectorizer()),
        ('tfidf', TfidfTransformer()),
        ('clf', LogisticRegression()),
    ]
)

# Hyperparameter grid, keyed as "<step name>__<parameter name>".
grid = {
    # Single candidate: n-gram range (1, 2).
    'vect__ngram_range': [(1, 2)],
    # Normalisation applied by the TF-IDF step.
    'tfidf__norm': ['l1', 'l2'],
    # Candidate values for LogisticRegression's C, from 100 down to 1e-5.
    'clf__C': [100, 10, 1.0, 0.1, 0.01, 0.001, 0.0001, 0.00001],
}
# Run the search only when this file is executed as a script, not on import.
if __name__ == '__main__':
    # 5-fold cross-validated search over `grid`; n_jobs=-1 asks scikit-learn
    # to use all available processors.
    grid_search = GridSearchCV(pipeline, grid, cv=5, n_jobs=-1)
    grid_search.fit(data.data, data.target)

    print(f"Best score: {grid_search.best_score_:.3f}")
    print("-------------------------------------------")
    # best_params_ holds just the winning combination from `grid`;
    # best_estimator_.get_params() would dump every parameter of the whole
    # pipeline, including untuned defaults and nested estimator objects.
    print("Best parameters set:", grid_search.best_params_)