Skip to content

Commit

Permalink
update text splitter
Browse files Browse the repository at this point in the history
  • Loading branch information
Alleria1809 committed Jul 10, 2024
1 parent ee314b1 commit fbdcd1c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 28 deletions.
20 changes: 8 additions & 12 deletions lightrag/lightrag/components/data_process/text_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,22 +181,18 @@ def __init__(
super().__init__()

self.split_by = split_by
assert (
split_by in SEPARATORS
), f"Invalid options for split_by. You must select from {list(SEPARATORS.keys())}."
if split_by not in SEPARATORS:
raise ValueError(f"Invalid options for split_by. You must select from {list(SEPARATORS.keys())}.")

assert (
chunk_overlap < chunk_size
), f"chunk_overlap can't be larger than or equal to chunk_size. Received chunk_size: {chunk_size}, chunk_overlap: {chunk_overlap}"
if chunk_overlap >= chunk_size:
raise ValueError(f"chunk_overlap can't be larger than or equal to chunk_size. Received chunk_size: {chunk_size}, chunk_overlap: {chunk_overlap}")

assert (
chunk_size > 0
), f"chunk_size must be greater than 0. Received value: {chunk_size}"
if chunk_size <= 0:
raise ValueError(f"chunk_size must be greater than 0. Received value: {chunk_size}")
self.chunk_size = chunk_size

assert (
chunk_overlap >= 0
), f"chunk_overlap must be non-negative. Received value: {chunk_overlap}"
if chunk_overlap < 0:
raise ValueError(f"chunk_overlap must be non-negative. Received value: {chunk_overlap}")
self.chunk_overlap = chunk_overlap

self.batch_size = batch_size
Expand Down
37 changes: 21 additions & 16 deletions lightrag/tests/test_text_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,25 @@ def setUp(self):
# Set up a TextSplitter instance before each test
self.splitter = TextSplitter(split_by="word", chunk_size=5, chunk_overlap=2)

# def test_invalid_split_by(self):
# # Test initialization with invalid split_by value
# with self.assertRaises(ValueError):
# TextSplitter(split_by="invalid", chunk_size=5, chunk_overlap=0)
def test_invalid_split_by(self):
# Test initialization with invalid split_by value
with self.assertRaises(ValueError):
TextSplitter(split_by="invalid", chunk_size=5, chunk_overlap=0)

# def test_negative_chunk_size(self):
# # Test initialization with negative chunk_size
# with self.assertRaises(ValueError):
# TextSplitter(split_by="word", chunk_size=-1, chunk_overlap=0)
def test_negative_chunk_size(self):
# Test initialization with negative chunk_size
with self.assertRaises(ValueError):
TextSplitter(split_by="word", chunk_size=-1, chunk_overlap=0)

# def test_negative_chunk_overlap(self):
# # Test initialization with negative chunk_overlap
# with self.assertRaises(ValueError):
# TextSplitter(split_by="word", chunk_size=5, chunk_overlap=-1)
def test_negative_chunk_overlap(self):
# Test initialization with negative chunk_overlap
with self.assertRaises(ValueError):
TextSplitter(split_by="word", chunk_size=5, chunk_overlap=-1)

def test_equal_chunk_overlap_size(self):
# Test initialization with equal chunk overlap and chunk size
with self.assertRaises(ValueError):
TextSplitter(split_by="word", chunk_size=5, chunk_overlap=5)

def test_split_by_word(self):
# Test the basic functionality of splitting by word
Expand Down Expand Up @@ -54,10 +59,10 @@ def test_document_splitting(self):
result_texts = [doc.text for doc in result]
self.assertEqual(result_texts, expected_texts)

# def test_empty_text_handling(self):
# # Test handling of empty text
# with self.assertRaises(ValueError):
# self.splitter.call([Document(text=None, id="1")])
def test_empty_text_handling(self):
# Test handling of empty text
with self.assertRaises(ValueError):
self.splitter.call([Document(text=None, id="1")])


if __name__ == "__main__":
Expand Down

0 comments on commit fbdcd1c

Please sign in to comment.