Skip to content

Commit

Permalink
fix init function bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiTao-Li committed May 16, 2024
1 parent 4f024b1 commit 0c25ada
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"emb_model_config_name": "qwen_emb_config",
"rag_config": {
"knowledge_id": ["agentscope_api_rag"],
"similarity_top_k": 3,
"similarity_top_k": 2,
"log_retrieval": true,
"recent_n_mem": 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
{
"knowledge_id": "agentscope_api_rag",
"emb_model_config_name": "qwen_emb_config",
"chunk_size": 2048,
"chunk_size": 1024,
"chunk_overlap": 40,
"data_processing": [
{
Expand Down
20 changes: 15 additions & 5 deletions src/agentscope/rag/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,27 @@ class Knowledge(ABC):

def __init__(
self,
model: Optional[ModelWrapperBase] = None,
knowledge_id: str,
emb_model: Any = None,
knowledge_config: Optional[dict] = None,
rag_config: Optional[dict] = None,
model: Optional[ModelWrapperBase] = None,
**kwargs: Any,
) -> None:
# pylint: disable=unused-argument
self.postprocessing_model = model
"""
initialize the knowledge component
Args:
knowledge_id (str):
The id of the knowledge unit.
emb_model (ModelWrapperBase):
                The embedding model used to generate embeddings
knowledge_config (dict):
The configuration to generate or load the index.
"""
self.knowledge_id = knowledge_id
self.emb_model = emb_model
self.knowledge_config = knowledge_config or {}
self.rag_config = rag_config or {}
self.postprocessing_model = model

@abstractmethod
def _init_rag(
Expand Down Expand Up @@ -68,7 +78,7 @@ def retrieve(
"""

@abstractmethod
def _set_retriever(self, **kwargs: Any) -> None:
def set_retriever(self, **kwargs: Any) -> None:
"""update retriever of RAG module"""

def post_processing(
Expand Down
4 changes: 3 additions & 1 deletion src/agentscope/rag/knowledge_bank.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
def _init_knowledge(self) -> None:
"""initialize the knowledge bank"""
for config in self.configs:
print("bank", config)
self.add_data_as_knowledge(
knowledge_id=config["knowledge_id"],
emb_model_name=config["emb_model_config_name"],
Expand Down Expand Up @@ -116,10 +117,10 @@ def add_data_as_knowledge(
self.stored_knowledge[knowledge_id] = LlamaIndexKnowledge(
knowledge_id=knowledge_id,
emb_model=load_model_by_config_name(emb_model_name),
knowledge_config=knowledge_config,
model=load_model_by_config_name(model_name)
if model_name
else None,
index_config=knowledge_config,
)
logger.info(f"data loaded for knowledge_id = {knowledge_id}.")

Expand Down Expand Up @@ -167,5 +168,6 @@ def equip(self, agent: AgentBase, duplicate: bool = False) -> None:
knowledge_id=rid,
duplicate=duplicate,
)
knowledge.set_retriever(agent.rag_config)
agent.knowledge_list.append(knowledge)
agent.retriever_list.append(knowledge.retriever)
37 changes: 23 additions & 14 deletions src/agentscope/rag/llama_index_knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,16 @@ class LlamaIndexKnowledge(Knowledge):
def __init__(
self,
knowledge_id: str,
persist_root: Optional[str] = None,
model: Optional[ModelWrapperBase] = None,
emb_model: Union[ModelWrapperBase, BaseEmbedding, None] = None,
knowledge_config: dict = None,
knowledge_config: Optional[dict] = None,
model: Optional[ModelWrapperBase] = None,
persist_root: Optional[str] = None,
overwrite_index: Optional[bool] = False,
showprogress: Optional[bool] = True,
**kwargs: Any,
) -> None:
"""
initialize the RAG component based on the
initialize the knowledge component based on the
llama-index framework: https://github.com/run-llama/llama_index
Notes:
Expand All @@ -169,22 +169,27 @@ def __init__(
Args:
knowledge_id (str):
The id of the RAG knowledge unit.
persist_root (str):
The root directory for index persisting
model (ModelWrapperBase):
The language model used for final synthesis
emb_model (ModelWrapperBase):
                The embedding model used to generate embeddings
knowledge_config (dict):
The configuration for llama-index to
generate or load the index.
model (ModelWrapperBase):
The language model used for final synthesis
persist_root (str):
                The root directory for persisting the index
overwrite_index (Optional[bool]):
Whether to overwrite the index while refreshing
showprogress (Optional[bool]):
Whether to show the indexing progress
"""
super().__init__(model, emb_model, knowledge_config, **kwargs)
self.knowledge_id = knowledge_id
super().__init__(
knowledge_id=knowledge_id,
emb_model=emb_model,
knowledge_config=knowledge_config,
model=model,
**kwargs,
)
if persist_root is None:
persist_root = file_manager.dir
self.persist_dir = os.path.join(persist_root, knowledge_id)
Expand All @@ -203,6 +208,7 @@ def __init__(
f"Embedding model does not support {type(self.emb_model)}.",
)
# then we can initialize the RAG
print("init", self.knowledge_config)
self._init_rag()

def _init_rag(self, **kwargs: Any) -> None:
Expand All @@ -225,7 +231,7 @@ def _init_rag(self, **kwargs: Any) -> None:
# self.refresh_index()
else:
self._data_to_index()
self._set_retriever()
self.set_retriever()
logger.info(
f"RAG with knowledge ids: {self.knowledge_id} "
f"initialization completed!\n",
Expand Down Expand Up @@ -411,8 +417,9 @@ def _set_transformations(self, config: dict) -> Any:
transformations = {"transformations": transformations}
return transformations

def _set_retriever(
def set_retriever(
self,
rag_config: Optional[dict] = None,
retriever: Optional[BaseRetriever] = None,
**kwargs: Any,
) -> None:
Expand All @@ -421,17 +428,19 @@ def _set_retriever(
Args:
retriever (Optional[BaseRetriever]): passing a retriever in llama
rag_config (dict): rag configuration, including similarity top k
index.
"""
# set the retriever
rag_config = rag_config or {}
if retriever is None:
logger.info(
f"similarity_top_k"
f'={self.rag_config.get("similarity_top_k", DEFAULT_TOP_K)}',
f'={rag_config.get("similarity_top_k", DEFAULT_TOP_K)}',
)
self.retriever = self.index.as_retriever(
embed_model=self.emb_model,
similarity_top_k=self.rag_config.get(
similarity_top_k=rag_config.get(
"similarity_top_k",
DEFAULT_TOP_K,
),
Expand Down

0 comments on commit 0c25ada

Please sign in to comment.