diff --git a/examples/conversation_with_RAG_agents/configs/agent_config.json b/examples/conversation_with_RAG_agents/configs/agent_config.json index 596c0a2b0..5061d536b 100644 --- a/examples/conversation_with_RAG_agents/configs/agent_config.json +++ b/examples/conversation_with_RAG_agents/configs/agent_config.json @@ -41,7 +41,7 @@ "emb_model_config_name": "qwen_emb_config", "rag_config": { "knowledge_id": ["agentscope_api_rag"], - "similarity_top_k": 3, + "similarity_top_k": 2, "log_retrieval": true, "recent_n_mem": 1 } diff --git a/examples/conversation_with_RAG_agents/configs/knowledge_config.json b/examples/conversation_with_RAG_agents/configs/knowledge_config.json index 0935c3dc5..d7ef45542 100644 --- a/examples/conversation_with_RAG_agents/configs/knowledge_config.json +++ b/examples/conversation_with_RAG_agents/configs/knowledge_config.json @@ -39,7 +39,7 @@ { "knowledge_id": "agentscope_api_rag", "emb_model_config_name": "qwen_emb_config", - "chunk_size": 2048, + "chunk_size": 1024, "chunk_overlap": 40, "data_processing": [ { diff --git a/src/agentscope/rag/knowledge.py b/src/agentscope/rag/knowledge.py index 4bd95e755..285459cdc 100644 --- a/src/agentscope/rag/knowledge.py +++ b/src/agentscope/rag/knowledge.py @@ -30,17 +30,27 @@ class Knowledge(ABC): def __init__( self, - model: Optional[ModelWrapperBase] = None, + knowledge_id: str, emb_model: Any = None, knowledge_config: Optional[dict] = None, - rag_config: Optional[dict] = None, + model: Optional[ModelWrapperBase] = None, **kwargs: Any, ) -> None: # pylint: disable=unused-argument - self.postprocessing_model = model + """ + initialize the knowledge component + Args: + knowledge_id (str): + The id of the knowledge unit. + emb_model (ModelWrapperBase): + The embedding model used for generate embeddings + knowledge_config (dict): + The configuration to generate or load the index. + """ + self.knowledge_id = knowledge_id self.emb_model = emb_model self.knowledge_config = knowledge_config or {} - self.rag_config = rag_config or {} + self.postprocessing_model = model @abstractmethod def _init_rag( @@ -68,7 +78,7 @@ def retrieve( """ @abstractmethod - def _set_retriever(self, **kwargs: Any) -> None: + def set_retriever(self, **kwargs: Any) -> None: """update retriever of RAG module""" def post_processing( diff --git a/src/agentscope/rag/knowledge_bank.py b/src/agentscope/rag/knowledge_bank.py index 3f397c9e6..3a5a30ff0 100644 --- a/src/agentscope/rag/knowledge_bank.py +++ b/src/agentscope/rag/knowledge_bank.py @@ -50,6 +50,7 @@ def __init__( def _init_knowledge(self) -> None: """initialize the knowledge bank""" for config in self.configs: + print("bank", config) self.add_data_as_knowledge( knowledge_id=config["knowledge_id"], emb_model_name=config["emb_model_config_name"], @@ -116,10 +117,10 @@ def add_data_as_knowledge( self.stored_knowledge[knowledge_id] = LlamaIndexKnowledge( knowledge_id=knowledge_id, emb_model=load_model_by_config_name(emb_model_name), + knowledge_config=knowledge_config, model=load_model_by_config_name(model_name) if model_name else None, - index_config=knowledge_config, ) logger.info(f"data loaded for knowledge_id = {knowledge_id}.") @@ -167,5 +168,6 @@ def equip(self, agent: AgentBase, duplicate: bool = False) -> None: knowledge_id=rid, duplicate=duplicate, ) + knowledge.set_retriever(agent.rag_config) agent.knowledge_list.append(knowledge) agent.retriever_list.append(knowledge.retriever) diff --git a/src/agentscope/rag/llama_index_knowledge.py b/src/agentscope/rag/llama_index_knowledge.py index 721cf4d40..49439321f 100644 --- a/src/agentscope/rag/llama_index_knowledge.py +++ b/src/agentscope/rag/llama_index_knowledge.py @@ -143,16 +143,16 @@ class LlamaIndexKnowledge(Knowledge): def __init__( self, knowledge_id: str, - persist_root: Optional[str] = None, - model: Optional[ModelWrapperBase] = None, emb_model: Union[ModelWrapperBase, BaseEmbedding, None] = None, - knowledge_config: dict = None, + knowledge_config: Optional[dict] = None, + model: Optional[ModelWrapperBase] = None, + persist_root: Optional[str] = None, overwrite_index: Optional[bool] = False, showprogress: Optional[bool] = True, **kwargs: Any, ) -> None: """ - initialize the RAG component based on the + initialize the knowledge component based on the llama-index framework: https://github.com/run-llama/llama_index Notes: @@ -169,22 +169,27 @@ def __init__( Args: knowledge_id (str): The id of the RAG knowledge unit. - persist_root (str): - The root directory for index persisting - model (ModelWrapperBase): - The language model used for final synthesis emb_model (ModelWrapperBase): The embedding model used for generate embeddings knowledge_config (dict): The configuration for llama-index to generate or load the index. + model (ModelWrapperBase): + The language model used for final synthesis + persist_root (str): + The root directory for index persisting overwrite_index (Optional[bool]): Whether to overwrite the index while refreshing showprogress (Optional[bool]): Whether to show the indexing progress """ - super().__init__(model, emb_model, knowledge_config, **kwargs) - self.knowledge_id = knowledge_id + super().__init__( + knowledge_id=knowledge_id, + emb_model=emb_model, + knowledge_config=knowledge_config, + model=model, + **kwargs, + ) if persist_root is None: persist_root = file_manager.dir self.persist_dir = os.path.join(persist_root, knowledge_id) @@ -203,6 +208,7 @@ def __init__( f"Embedding model does not support {type(self.emb_model)}.", ) # then we can initialize the RAG + print("init", self.knowledge_config) self._init_rag() def _init_rag(self, **kwargs: Any) -> None: @@ -225,7 +231,7 @@ def _init_rag(self, **kwargs: Any) -> None: # self.refresh_index() else: self._data_to_index() - self._set_retriever() + self.set_retriever() logger.info( f"RAG with knowledge ids: {self.knowledge_id} " f"initialization completed!\n", @@ -411,8 +417,9 @@ def _set_transformations(self, config: dict) -> Any: transformations = {"transformations": transformations} return transformations - def _set_retriever( + def set_retriever( self, + rag_config: Optional[dict] = None, retriever: Optional[BaseRetriever] = None, **kwargs: Any, ) -> None: @@ -421,17 +428,19 @@ def _set_retriever( Args: retriever (Optional[BaseRetriever]): passing a retriever in llama + rag_config (dict): rag configuration, including similarity top k index. """ # set the retriever + rag_config = rag_config or {} if retriever is None: logger.info( f"similarity_top_k" - f'={self.rag_config.get("similarity_top_k", DEFAULT_TOP_K)}', + f'={rag_config.get("similarity_top_k", DEFAULT_TOP_K)}', ) self.retriever = self.index.as_retriever( embed_model=self.emb_model, - similarity_top_k=self.rag_config.get( + similarity_top_k=rag_config.get( "similarity_top_k", DEFAULT_TOP_K, ),