Skip to content

Commit

Permalink
update workflow + update code with feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Alleria1809 committed Jun 29, 2024
1 parent b7fa9ec commit aedefb5
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 93 deletions.
9 changes: 0 additions & 9 deletions lightrag/components/data_process/text_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,25 +172,16 @@ def __init__(
"""
super().__init__()

# variable value checks
self.split_by = split_by
# Validate split_by is in SEPARATORS
options = ", ".join(f"'{key}'" for key in SEPARATORS.keys())
assert split_by in SEPARATORS, f"Invalid options for split_by. You must select from {options}."
# log.error(f"Invalid options for split_by. You must select from {options}.")

# Validate chunk_overlap is less than chunk_size
assert chunk_overlap < chunk_size, f"chunk_overlap can't be larger than or equal to chunk_size. Received chunk_size: {chunk_size}, chunk_overlap: {chunk_overlap}"
# log.error(f"chunk_overlap can't be larger than or equal to chunk_size. Received chunk_size: {chunk_size}, chunk_overlap: {chunk_overlap}")

# Validate chunk_size is greater than 0
assert chunk_size > 0, f"chunk_size must be greater than 0. Received value: {chunk_size}"
# log.error(f"chunk_size must be greater than 0. Received value: {chunk_size}")
self.chunk_size = chunk_size

# Validate chunk_overlap is non-negative
assert chunk_overlap >= 0, f"chunk_overlap must be non-negative. Received value: {chunk_overlap}"
# log.error(f"chunk_overlap must be non-negative. Received value: {chunk_overlap}")
self.chunk_overlap = chunk_overlap

self.batch_size = batch_size
Expand Down
6 changes: 0 additions & 6 deletions lightrag/tests/test_gt_text_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,6 @@ def test_overlap_zero_end(self):
text = "one two three four five six seven eight nine ten"
self.compare_splits(text)

# def test_invalid_parameters(self):
# with self.assertRaises(ValueError):
# TextSplitter(split_by="word", chunk_size=-1, chunk_overlap=2)
# with self.assertRaises(ValueError):
# TextSplitter(split_by="word", chunk_size=5, chunk_overlap=6)


# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
156 changes: 78 additions & 78 deletions lightrag/tests/test_transformer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,84 +23,84 @@ def setUp(self) -> None:
"The red panda (Ailurus fulgens), also called the lesser panda, the red bear-cat, and the red cat-bear, is a mammal native to the eastern Himalayas and southwestern China.",
]

# def test_transformer_embedder(self):
# transformer_embedder_model = "thenlper/gte-base"
# transformer_embedder_model_component = TransformerEmbedder(
# model_name=transformer_embedder_model
# )
# print(
# f"Testing transformer embedder with model {transformer_embedder_model_component}"
# )
# print("Testing transformer embedder")
# output = transformer_embedder_model_component(
# model=transformer_embedder_model, input="Hello world"
# )
# print(output)

# def test_transformer_client(self):
# transformer_client = TransformersClient()
# print("Testing transformer client")
# # run the model
# kwargs = {
# "model": "thenlper/gte-base",
# # "mock": False,
# }
# api_kwargs = transformer_client.convert_inputs_to_api_kwargs(
# input="Hello world",
# model_kwargs=kwargs,
# model_type=ModelType.EMBEDDER,
# )
# # print(api_kwargs)
# output = transformer_client.call(
# api_kwargs=api_kwargs, model_type=ModelType.EMBEDDER
# )

# # print(transformer_client)
# # print(output)

# def test_transformer_reranker(self):
# transformer_reranker_model = "BAAI/bge-reranker-base"
# transformer_reranker_model_component = TransformerReranker()
# # print(
# # f"Testing transformer reranker with model {transformer_reranker_model_component}"
# # )

# model_kwargs = {
# "model": transformer_reranker_model,
# "documents": self.documents,
# "query": self.query,
# "top_k": 2,
# }

# output = transformer_reranker_model_component(
# **model_kwargs,
# )
# # assert output is a list of float with length 2
# self.assertEqual(len(output), 2)
# self.assertEqual(type(output[0]), float)

# def test_transformer_reranker_client(self):
# transformer_reranker_client = TransformersClient(
# model_name="BAAI/bge-reranker-base"
# )
# print("Testing transformer reranker client")
# # run the model
# kwargs = {
# "model": "BAAI/bge-reranker-base",
# "documents": self.documents,
# "top_k": 2,
# }
# api_kwargs = transformer_reranker_client.convert_inputs_to_api_kwargs(
# input=self.query,
# model_kwargs=kwargs,
# model_type=ModelType.RERANKER,
# )
# print(api_kwargs)
# self.assertEqual(api_kwargs["model"], "BAAI/bge-reranker-base")
# output = transformer_reranker_client.call(
# api_kwargs=api_kwargs, model_type=ModelType.RERANKER
# )
# self.assertEqual(type(output), tuple)
def test_transformer_embedder(self):
    """Smoke-test the embedder component on a single input string.

    Builds a ``TransformerEmbedder`` for ``thenlper/gte-base``, calls it with
    "Hello world", and prints the component and its output. There are no
    assertions: the test passes as long as none of the calls raises.
    """
    model_id = "thenlper/gte-base"
    embedder = TransformerEmbedder(model_name=model_id)

    print(f"Testing transformer embedder with model {embedder}")
    print("Testing transformer embedder")

    embedding = embedder(model=model_id, input="Hello world")
    print(embedding)

def test_transformer_client(self):
    """Smoke-test the embedder path of ``TransformersClient``.

    Converts a plain-text input into API kwargs for the embedder model type and
    passes them to ``call``. The result is not inspected, so the test only
    verifies that both steps complete without raising.
    """
    client = TransformersClient()
    print("Testing transformer client")

    # Translate the user-facing input and model kwargs into the client's call format.
    request = client.convert_inputs_to_api_kwargs(
        input="Hello world",
        model_kwargs={"model": "thenlper/gte-base"},
        model_type=ModelType.EMBEDDER,
    )

    # Run the model; the returned value is intentionally left unchecked here.
    response = client.call(api_kwargs=request, model_type=ModelType.EMBEDDER)

# print(transformer_client)
# print(output)

def test_transformer_reranker(self):
    """Check that the reranker component returns two float-valued results.

    Calls ``TransformerReranker`` with ``self.documents``, ``self.query`` and
    ``top_k=2`` for the ``BAAI/bge-reranker-base`` model, then asserts that the
    output has length 2 and that its first element is a ``float``.
    """
    transformer_reranker_model = "BAAI/bge-reranker-base"
    transformer_reranker_model_component = TransformerReranker()

    model_kwargs = {
        "model": transformer_reranker_model,
        "documents": self.documents,
        "query": self.query,
        "top_k": 2,
    }

    output = transformer_reranker_model_component(**model_kwargs)

    # Presumably the length tracks the requested top_k -- confirm against the component.
    self.assertEqual(len(output), 2)
    # isinstance-based check instead of exact type equality: idiomatic, and the
    # failure message shows the offending value rather than just two type objects.
    self.assertIsInstance(output[0], float)

def test_transformer_reranker_client(self):
    """Check the reranker path of ``TransformersClient`` end to end.

    Converts ``self.query`` plus reranker kwargs into API kwargs, asserts the
    model name is carried through unchanged, calls the client with the
    reranker model type, and asserts that the result is a tuple.
    """
    reranker_model = "BAAI/bge-reranker-base"
    transformer_reranker_client = TransformersClient(model_name=reranker_model)
    print("Testing transformer reranker client")

    kwargs = {
        "model": reranker_model,
        "documents": self.documents,
        "top_k": 2,
    }
    api_kwargs = transformer_reranker_client.convert_inputs_to_api_kwargs(
        input=self.query,
        model_kwargs=kwargs,
        model_type=ModelType.RERANKER,
    )
    print(api_kwargs)
    # The conversion step must not rewrite the requested model name.
    self.assertEqual(api_kwargs["model"], reranker_model)

    output = transformer_reranker_client.call(
        api_kwargs=api_kwargs, model_type=ModelType.RERANKER
    )
    # isinstance-based check instead of exact type equality, so tuple subclasses
    # (e.g. namedtuples) are accepted and failures report the actual value.
    self.assertIsInstance(output, tuple)


def test_transformer_llm_response(self):
Expand Down

0 comments on commit aedefb5

Please sign in to comment.