diff --git a/pyserini/index/pyutils.py b/pyserini/index/pyutils.py index 0805adbb8..18dc22dfc 100644 --- a/pyserini/index/pyutils.py +++ b/pyserini/index/pyutils.py @@ -66,7 +66,7 @@ def __init__(self, term, doc_freq, total_term_freq): self.doc_freq = doc_freq self.total_term_freq = total_term_freq - def analyze_term(self, term): + def analyze(self, text): ''' Parameters ---------- @@ -74,9 +74,13 @@ def analyze_term(self, term): Returns ------- result : str - Stemmed term + List of stemmed tokens ''' - return self.object.analyzeTerm(JString(term)) + stemmed = self.object.analyze(JString(text)) + token_list = [] + for token in stemmed.toArray(): + token_list.append(token) + return token_list def terms(self): ''' diff --git a/tests/test_indexutils.py b/tests/test_indexutils.py index 34a8f40c7..8160a07ba 100644 --- a/tests/test_indexutils.py +++ b/tests/test_indexutils.py @@ -27,11 +27,12 @@ def setUp(self): def test_terms(self): self.assertEqual(sum(1 for x in self.index_utils.terms()), 14363) - def test_term_stats(self): - term = 'retrieval' - self.assertEqual(self.index_utils.analyze_term(term), 'retriev') + def test_analyze(self): + self.assertEqual(' '.join(self.index_utils.analyze('retrieval')), 'retriev') + self.assertEqual(' '.join(self.index_utils.analyze('rapid retrieval, space economy')), 'rapid retriev space economi') - collection_freq, doc_freq = self.index_utils.get_term_counts(term) + def test_term_stats(self): + collection_freq, doc_freq = self.index_utils.get_term_counts('retrieval') self.assertEqual(collection_freq, 275) self.assertEqual(doc_freq, 138)