From 34a2c4914aab46865f2d0191bbd12235b10ad3f9 Mon Sep 17 00:00:00 2001 From: vichayturen <1073931273@qq.com> Date: Fri, 1 Nov 2024 21:48:13 +0800 Subject: [PATCH 1/8] feat(llm): added the process of intelligent generated gremlin for retrieval before subgraph retrieval --- hugegraph-llm/README.md | 2 +- .../hugegraph_llm/demo/rag_demo/rag_block.py | 6 +- .../hugegraph_llm/operators/graph_rag_task.py | 5 + .../operators/hugegraph_op/graph_rag_query.py | 120 ++++++++++++++++-- .../index_op/gremlin_example_index_query.py | 2 +- 5 files changed, 116 insertions(+), 19 deletions(-) diff --git a/hugegraph-llm/README.md b/hugegraph-llm/README.md index eeb7f430..15d9fc43 100644 --- a/hugegraph-llm/README.md +++ b/hugegraph-llm/README.md @@ -150,7 +150,7 @@ Here is a brief usage guide: graph_rag.extract_keywords(text="Tell me about Al Pacino.").print_result() ``` -2. **Match Vid from Keywords*: Match the nodes with the keywords in the graph. +2. **Match Vid from Keywords**: Match the nodes with the keywords in the graph. 
```python graph_rag.keywords_to_vid().print_result() diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py index 66c9b19f..842f0049 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py @@ -24,7 +24,7 @@ import pandas as pd from gradio.utils import NamedString -from hugegraph_llm.config import resource_path, prompt +from hugegraph_llm.config import resource_path, prompt, settings from hugegraph_llm.operators.graph_rag_task import RAGPipeline from hugegraph_llm.utils.log import log @@ -66,7 +66,7 @@ def rag_answer( if vector_search: rag.query_vector_index() if graph_search: - rag.extract_keywords().keywords_to_vid().query_graphdb() + rag.extract_keywords().keywords_to_vid().import_schema(settings.graph_name).query_graphdb() # TODO: add more user-defined search strategies rag.merge_dedup_rerank(graph_ratio, rerank_method, near_neighbor_first, custom_related_information) rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt) @@ -265,4 +265,4 @@ def several_rag_answer( ) questions_file.change(read_file_to_excel, questions_file, [qa_dataframe, answer_max_line_count]) answer_max_line_count.change(change_showing_excel, answer_max_line_count, qa_dataframe) - return inp, answer_prompt_input \ No newline at end of file + return inp, answer_prompt_input diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py index e6da8e09..c9160e17 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py @@ -26,6 +26,7 @@ from hugegraph_llm.operators.common_op.print_result import PrintResult from hugegraph_llm.operators.document_op.word_extract import WordExtract from hugegraph_llm.operators.hugegraph_op.graph_rag_query 
import GraphRAGQuery +from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager from hugegraph_llm.operators.index_op.semantic_id_query import SemanticIdQuery from hugegraph_llm.operators.index_op.vector_index_query import VectorIndexQuery from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize @@ -90,6 +91,10 @@ def extract_keywords( ) return self + def import_schema(self, graph_name: str): + self._operators.append(SchemaManager(graph_name)) + return self + def keywords_to_vid( self, by: Literal["query", "keywords"] = "keywords", diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py index fcc15530..16b29364 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py @@ -14,9 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+ +import json +import os +import re from typing import Any, Dict, Optional, List, Set, Tuple -from hugegraph_llm.config import settings +from hugegraph_llm.config import settings, resource_path +from hugegraph_llm.indices.vector_index import VectorIndex +from hugegraph_llm.models.embeddings.base import BaseEmbedding +from hugegraph_llm.models.embeddings.init_embedding import Embeddings +from hugegraph_llm.models.llms.base import BaseLLM +from hugegraph_llm.models.llms.init_llm import LLMs from hugegraph_llm.utils.log import log from pyhugegraph.client import PyHugeClient @@ -24,7 +33,7 @@ # TODO: we could use a simpler query (like kneighbor-api to get the edges) # TODO: test with profile()/explain() to speed up the query -VID_QUERY_NEIGHBOR_TPL = """ +VID_QUERY_NEIGHBOR_TPL = """\ g.V({keywords}) .repeat( bothE({edge_labels}).limit({edge_limit}).otherV().dedup() @@ -46,7 +55,7 @@ .toList() """ -PROPERTY_QUERY_NEIGHBOR_TPL = """ +PROPERTY_QUERY_NEIGHBOR_TPL = """\ g.V().has('{prop}', within({keywords})) .repeat( bothE({edge_labels}).limit({edge_limit}).otherV().dedup() @@ -67,10 +76,35 @@ .toList() """ +GREMLIN_GENERATE_EXAMPLE_OPTION_TPL = """\ +# Example +Generate gremlin from the following user input. +{example_query} +The generated gremlin is: +```gremlin +{example_gremlin} +``` + +""" + +GREMLIN_GENERATE_TPL = """\ +Given the graph schema: +{schema} +Generate gremlin from the following user input. 
+{query} +The generated gremlin is:""" + class GraphRAGQuery: - def __init__(self, max_deep: int = 2, max_items: int = 20, prop_to_match: Optional[str] = None): + def __init__( + self, + max_deep: int = 2, + max_items: int = 20, + prop_to_match: Optional[str] = None, + llm: Optional[BaseLLM] = None, + embedding: Optional[BaseEmbedding] = None, + ): self._client = PyHugeClient( settings.graph_ip, settings.graph_port, @@ -83,9 +117,17 @@ def __init__(self, max_deep: int = 2, max_items: int = 20, prop_to_match: Option self._max_items = max_items self._prop_to_match = prop_to_match self._schema = "" + self._index_dir = os.path.join(resource_path, "gremlin_examples") + self._vector_index = VectorIndex.from_index_file(self._index_dir) + self._llm = llm + self._embedding = embedding def run(self, context: Dict[str, Any]) -> Dict[str, Any]: # pylint: disable=R0915 (too-many-statements) + if self._llm is None: + self._llm = LLMs().get_llm() + if self._embedding is None: + self._embedding = Embeddings().get_embedding() if self._client is None: if isinstance(context.get("graph_client"), PyHugeClient): self._client = context["graph_client"] @@ -99,7 +141,61 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: self._client = PyHugeClient(ip, port, graph, user, pwd, gs) assert self._client is not None, "No valid graph to search." - # 2. Extract params from context + # 1. Try to perform a query based on the generated gremlin + context = self._gremlin_generate_query(context) + # 2. 
Try to perform a query based on subgraph-search if the previous query failed + if not context.get("graph_result"): + context = self._subgraph_query(context) + + # TODO: replace print to log + verbose = context.get("verbose") or False + if verbose: + print("\033[93mKnowledge from Graph:") + print("\n".join(context["graph_result"]) + "\033[0m") + + return context + + def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]: + query = context["query"] + query_embedding = context.get("query_embedding") + if query_embedding is None: + query_embedding = self._embedding.get_text_embedding(query) + + match_result = self._vector_index.search(query_embedding, top_k=1, dis_threshold=2) + prompt = "" + if match_result: + prompt += GREMLIN_GENERATE_EXAMPLE_OPTION_TPL.format( + example_query=match_result[0]["query"], + example_gremlin=match_result[0]["gremlin"] + ) + prompt += GREMLIN_GENERATE_TPL.format( + schema=json.dumps(context["schema"], ensure_ascii=False), + query=query + ) + + response = self._llm.generate(prompt=prompt) + match = re.search("```gremlin.*```", response, re.DOTALL) + if match: + gremlin = match.group()[len("```gremlin"):-len("```")] + log.info("Generated gremlin: %s", gremlin) + context["gremlin"] = gremlin + try: + result = self._client.gremlin().exec(gremlin=gremlin)["data"] + if result == [None]: + result = [] + context["graph_result"] = [json.dumps(item, ensure_ascii=False) for item in result] + context["graph_context_head"] = ( + f"The following are graph query result " + f"from gremlin query `{gremlin}`.\n" + ) + except Exception as e: + log.error(e) + else: + log.error("Failed to generate gremlin from the query.") + return context + + def _subgraph_query(self, context: Dict[str, Any]) -> Dict[str, Any]: + # 1. 
Extract params from context matched_vids = context.get("match_vids") if isinstance(context.get("max_deep"), int): self._max_deep = context["max_deep"] @@ -108,7 +204,7 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: if isinstance(context.get("prop_to_match"), str): self._prop_to_match = context["prop_to_match"] - # 3. Extract edge_labels from graph schema + # 2. Extract edge_labels from graph schema _, edge_labels = self._extract_labels_from_schema() edge_labels_str = ",".join("'" + label + "'" for label in edge_labels) # TODO: enhance the limit logic later @@ -170,13 +266,6 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: "`vertexA --[links]--> vertexB <--[links]-- vertexC ...`" "extracted based on key entities as subject:\n" ) - - # TODO: replace print to log - verbose = context.get("verbose") or False - if verbose: - print("\033[93mKnowledge from Graph:") - print("\n".join(context["graph_result"]) + "\033[0m") - return context def _format_graph_from_vertex(self, query_result: List[Any]) -> Set[str]: @@ -248,7 +337,10 @@ def _process_edge(self, item: Any, flat_rel: str, prior_edge_str_len: int, raw_flat_rel: List[Any], i: int, use_id_to_match: bool) -> Tuple[str, int]: props_str = ", ".join(f"{k}: {v}" for k, v in item["props"].items()) props_str = f"{{{props_str}}}" if len(props_str) > 0 else "" - prev_matched_str = raw_flat_rel[i - 1]["id"] if use_id_to_match else raw_flat_rel[i - 1]["props"][self._prop_to_match] + prev_matched_str = ( + raw_flat_rel[i - 1]["id"] if use_id_to_match + else raw_flat_rel[i - 1]["props"][self._prop_to_match] + ) if item["outV"] == prev_matched_str: edge_str = f" --[{item['label']}{props_str}]--> " diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py index ddcf589d..a04be6f4 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py +++ 
b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py @@ -35,5 +35,5 @@ def __init__(self, query: str, embedding: BaseEmbedding, num_examples: int = 1): def run(self, context: Dict[str, Any]) -> Dict[str, Any]: context["query"] = self.query query_embedding = self.embedding.get_text_embedding(self.query) - context["match_result"] = self.vector_index.search(query_embedding, self.num_examples) + context["match_result"] = self.vector_index.search(query_embedding, self.num_examples, dis_threshold=2) return context From add83da49f233ccd2c5109a7ed7c801faf7e3e2c Mon Sep 17 00:00:00 2001 From: vichayturen <1073931273@qq.com> Date: Sat, 2 Nov 2024 12:06:59 +0800 Subject: [PATCH 2/8] feat(llm): add text2gremlin block in rag app --- .../src/hugegraph_llm/demo/rag_demo/app.py | 5 +- .../demo/rag_demo/other_block.py | 2 +- .../demo/rag_demo/text2gremlin_block.py | 103 ++++++++++++++++++ 3 files changed, 108 insertions(+), 2 deletions(-) create mode 100644 hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py index 912b064f..7eaf01d3 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py @@ -34,6 +34,7 @@ apply_graph_config, ) from hugegraph_llm.demo.rag_demo.other_block import create_other_block +from hugegraph_llm.demo.rag_demo.text2gremlin_block import create_text2gremlin_block from hugegraph_llm.demo.rag_demo.rag_block import create_rag_block, rag_answer from hugegraph_llm.demo.rag_demo.vector_graph_block import create_vector_graph_block from hugegraph_llm.resources.demo.css import CSS @@ -92,7 +93,9 @@ def init_rag_ui() -> gr.Interface: textbox_input_schema, textbox_info_extract_template = create_vector_graph_block() with gr.Tab(label="2. 
(Graph)RAG & User Functions 📖"): textbox_inp, textbox_answer_prompt_input = create_rag_block() - with gr.Tab(label="3. Others Tools 🚧"): + with gr.Tab(label="3. Text2gremlin ⚙️"): + create_text2gremlin_block() + with gr.Tab(label="4. Others Tools 🚧"): create_other_block() diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/other_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/other_block.py index 5297f8a0..d72975ce 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/other_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/other_block.py @@ -21,7 +21,7 @@ def create_other_block(): - gr.Markdown("""## 4. Other Tools """) + gr.Markdown("""## 5. Other Tools """) with gr.Row(): inp = gr.Textbox(value="g.V().limit(10)", label="Gremlin query", show_copy_button=True, lines=8) out = gr.Code(label="Output", language="json", elem_classes="code-container-show") diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py new file mode 100644 index 00000000..8a83d4ae --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import json + +import gradio as gr + +from hugegraph_llm.models.embeddings.init_embedding import Embeddings +from hugegraph_llm.models.llms.init_llm import LLMs +from hugegraph_llm.operators.gremlin_generate_task import GremlinGenerator + + +def build_example_vector_index(temp_file): + full_path = temp_file.name + if full_path.endswith(".json"): + with open(full_path, "r", encoding="utf-8") as f: + examples = json.load(f) + else: + return "ERROR: please input json file." + builder = GremlinGenerator( + llm=LLMs().get_llm(), + embedding=Embeddings().get_embedding(), + ) + return builder.example_index_build(examples).run() + + +def gremlin_generate(inp, use_schema, use_example, example_num, schema): + generator = GremlinGenerator( + llm=LLMs().get_llm(), + embedding=Embeddings().get_embedding(), + ) + if use_example == "true": + generator.example_index_query(inp, example_num) + context = generator.gremlin_generate(use_schema, use_example, schema).run() + return context.get("match_result", "No Results"), context["result"] + + +def create_text2gremlin_block() -> list: + gr.Markdown("""## 4. 
Text2gremlin Tools """) + + gr.Markdown("## Build Example Vector Index") + gr.Markdown("Uploaded json file should be in format below:\n\n" + "[{\"query\":\"who is peter\", \"gremlin\":\"g.V().has('name', 'peter')\"}]") + with gr.Row(): + file = gr.File(label="Upload Example Query-Gremlin Pairs Json") + out = gr.Textbox(label="Result Message") + with gr.Row(): + btn = gr.Button("Build Example Vector Index") + btn.click(build_example_vector_index, inputs=[file], outputs=[out]) # pylint: disable=no-member + gr.Markdown("## Nature Language To Gremlin") + SCHEMA = """{ + "vertices": [ + {"vertex_label": "entity", "properties": []} + ], + "edges": [ + { + "edge_label": "relation", + "source_vertex_label": "entity", + "target_vertex_label": "entity", + "properties": {} + } + ] + }""" + with gr.Row(): + with gr.Column(scale=1): + schema_box = gr.Textbox(value=SCHEMA, label="Schema") + with gr.Column(scale=1): + input_box = gr.Textbox(value="Tell me about Al Pacino.", + label="Nature Language Query") + match = gr.Textbox(label="Best-Matched Examples") + out = gr.Textbox(label="Structured Query Language: Gremlin") + with gr.Column(scale=1): + use_example_radio = gr.Radio(choices=["true", "false"], value="false", + label="Use example") + use_schema_radio = gr.Radio(choices=["true", "false"], value="false", + label="Use schema") + example_num_slider = gr.Slider( + minimum=1, + maximum=10, + step=1, + value=5, + label="Number of examples" + ) + btn = gr.Button("Text2Gremlin") + btn.click( # pylint: disable=no-member + fn=gremlin_generate, + inputs=[input_box, use_schema_radio, use_example_radio, example_num_slider, schema_box], + outputs=[match, out] + ) From 41fa0497465f6da61953964a03a383aa0f1ec1c5 Mon Sep 17 00:00:00 2001 From: vichayturen <1073931273@qq.com> Date: Wed, 6 Nov 2024 02:03:22 +0800 Subject: [PATCH 3/8] feat(llm): supported uploading csv file to build text2gremlin example index --- .../hugegraph_llm/demo/rag_demo/text2gremlin_block.py | 9 +++++++-- 1 file 
changed, 7 insertions(+), 2 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py index 8a83d4ae..c4e52497 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py @@ -17,6 +17,7 @@ import json +import pandas as pd import gradio as gr from hugegraph_llm.models.embeddings.init_embedding import Embeddings @@ -29,6 +30,8 @@ def build_example_vector_index(temp_file): if full_path.endswith(".json"): with open(full_path, "r", encoding="utf-8") as f: examples = json.load(f) + elif full_path.endswith(".csv"): + examples = pd.read_csv(full_path).to_dict('records') else: return "ERROR: please input json file." builder = GremlinGenerator( @@ -53,8 +56,10 @@ def create_text2gremlin_block() -> list: gr.Markdown("""## 4. Text2gremlin Tools """) gr.Markdown("## Build Example Vector Index") - gr.Markdown("Uploaded json file should be in format below:\n\n" - "[{\"query\":\"who is peter\", \"gremlin\":\"g.V().has('name', 'peter')\"}]") + gr.Markdown("Uploaded json file should be in format below:\n" + "[{\"query\":\"who is peter\", \"gremlin\":\"g.V().has('name', 'peter')\"}]\n" + "Uploaded csv file should be in format below:\n" + "query,gremlin\n\"who is peter\",\"g.V().has('name', 'peter')\"") with gr.Row(): file = gr.File(label="Upload Example Query-Gremlin Pairs Json") out = gr.Textbox(label="Result Message") From 8a6b80e961d2000d5b8acd08b15c05a15baa1b32 Mon Sep 17 00:00:00 2001 From: vichayturen <1073931273@qq.com> Date: Thu, 14 Nov 2024 17:54:14 +0800 Subject: [PATCH 4/8] feat(llm): remove text2gremlin web demo --- hugegraph-llm/README.md | 11 +- .../demo/gremlin_generate_web_demo.py | 208 ------------------ .../hugegraph_llm/utils/graph_index_utils.py | 2 +- 3 files changed, 5 insertions(+), 216 deletions(-) delete mode 100644 
hugegraph-llm/src/hugegraph_llm/demo/gremlin_generate_web_demo.py diff --git a/hugegraph-llm/README.md b/hugegraph-llm/README.md index e9d6f90a..e0e91e3e 100644 --- a/hugegraph-llm/README.md +++ b/hugegraph-llm/README.md @@ -46,21 +46,18 @@ graph systems and large language models. ```bash python3 -m hugegraph_llm.demo.rag_demo.app --host 127.0.0.1 --port 18001 ``` -6. Or start the gradio interactive demo of **Text2Gremlin**, you can run with the following command, and open http://127.0.0.1:8002 after starting. You can also change the default host `0.0.0.0` and port `8002` as above. (🚧ing) - ```bash - python3 -m hugegraph_llm.demo.gremlin_generate_web_demo - ``` -7. After running the web demo, the config file `.env` will be automatically generated at the path `hugegraph-llm/.env`. Additionally, a prompt-related configuration file `config_prompt.yaml` will also be generated at the path `hugegraph-llm/src/hugegraph_llm/resources/demo/config_prompt.yaml`. + +6. After running the web demo, the config file `.env` will be automatically generated at the path `hugegraph-llm/.env`. Additionally, a prompt-related configuration file `config_prompt.yaml` will also be generated at the path `hugegraph-llm/src/hugegraph_llm/resources/demo/config_prompt.yaml`. You can modify the content on the web page, and it will be automatically saved to the configuration file after the corresponding feature is triggered. You can also modify the file directly without restarting the web application; simply refresh the page to load your latest changes. (Optional)To regenerate the config file, you can use `config.generate` with `-u` or `--update`. ```bash python3 -m hugegraph_llm.config.generate --update ``` -8. (__Optional__) You could use +7. 
(__Optional__) You could use [hugegraph-hubble](https://hugegraph.apache.org/docs/quickstart/hugegraph-hubble/#21-use-docker-convenient-for-testdev) to visit the graph data, could run it via [Docker/Docker-Compose](https://hub.docker.com/r/hugegraph/hubble) for guidance. (Hubble is a graph-analysis dashboard include data loading/schema management/graph traverser/display). -9. (__Optional__) offline download NLTK stopwords +8. (__Optional__) offline download NLTK stopwords ```bash python ./hugegraph_llm/operators/common_op/nltk_helper.py ``` diff --git a/hugegraph-llm/src/hugegraph_llm/demo/gremlin_generate_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/gremlin_generate_web_demo.py deleted file mode 100644 index 61663210..00000000 --- a/hugegraph-llm/src/hugegraph_llm/demo/gremlin_generate_web_demo.py +++ /dev/null @@ -1,208 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- - -import json -import uvicorn -import gradio as gr -from fastapi import FastAPI -from hugegraph_llm.config import settings -from hugegraph_llm.models.llms.init_llm import LLMs -from hugegraph_llm.models.embeddings.init_embedding import Embeddings -from hugegraph_llm.operators.gremlin_generate_task import GremlinGenerator - - -def build_example_vector_index(temp_file): - full_path = temp_file.name - if full_path.endswith(".json"): - with open(full_path, "r", encoding="utf-8") as f: - examples = json.load(f) - else: - return "ERROR: please input json file." - builder = GremlinGenerator( - llm=LLMs().get_llm(), - embedding=Embeddings().get_embedding(), - ) - return builder.example_index_build(examples).run() - - -def gremlin_generate(inp, use_schema, use_example, example_num, schema): - generator = GremlinGenerator( - llm=LLMs().get_llm(), - embedding=Embeddings().get_embedding(), - ) - if use_example == "true": - generator.example_index_query(inp, example_num) - context = generator.gremlin_generate(use_schema, use_example, schema).run() - return context.get("match_result", "No Results"), context["result"] - - -if __name__ == '__main__': - app = FastAPI() - with gr.Blocks() as demo: - gr.Markdown( - """# HugeGraph LLM Text2Gremlin Demo""" - ) - gr.Markdown("## Set up the LLM") - llm_dropdown = gr.Dropdown(["openai", "qianfan_wenxin", "ollama"], value=settings.llm_type, - label="LLM") - - - @gr.render(inputs=[llm_dropdown]) - def llm_settings(llm_type): - settings.llm_type = llm_type - if llm_type == "openai": - with gr.Row(): - llm_config_input = [ - gr.Textbox(value=settings.openai_api_key, label="api_key"), - gr.Textbox(value=settings.openai_api_base, label="api_base"), - gr.Textbox(value=settings.openai_language_model, label="model_name"), - gr.Textbox(value=str(settings.openai_max_tokens), label="max_token"), - ] - elif llm_type == "qianfan_wenxin": - with gr.Row(): - llm_config_input = [ - gr.Textbox(value=settings.qianfan_api_key, label="api_key"), - 
gr.Textbox(value=settings.qianfan_secret_key, label="secret_key"), - gr.Textbox(value=settings.qianfan_chat_url, label="chat_url"), - gr.Textbox(value=settings.qianfan_language_model, label="model_name") - ] - elif llm_type == "ollama": - with gr.Row(): - llm_config_input = [ - gr.Textbox(value=settings.ollama_host, label="host"), - gr.Textbox(value=str(settings.ollama_port), label="port"), - gr.Textbox(value=settings.ollama_language_model, label="model_name"), - gr.Textbox(value="", visible=False) - ] - else: - llm_config_input = [] - llm_config_button = gr.Button("Apply Configuration") - - def apply_configuration(arg1, arg2, arg3, arg4): - llm_option = settings.llm_type - if llm_option == "openai": - settings.openai_api_key = arg1 - settings.openai_api_base = arg2 - settings.openai_language_model = arg3 - settings.openai_max_tokens = int(arg4) - elif llm_option == "qianfan_wenxin": - settings.qianfan_api_key = arg1 - settings.qianfan_secret_key = arg2 - settings.qianfan_chat_url = arg3 - settings.qianfan_language_model = arg4 - elif llm_option == "ollama": - settings.ollama_host = arg1 - settings.ollama_port = int(arg2) - settings.ollama_language_model = arg3 - gr.Info("configured!") - - llm_config_button.click(apply_configuration, inputs=llm_config_input) # pylint: disable=no-member - - gr.Markdown("## Set up the Embedding") - embedding_dropdown = gr.Dropdown( - choices=["openai", "ollama"], - value=settings.embedding_type, - label="Embedding" - ) - - @gr.render(inputs=[embedding_dropdown]) - def embedding_settings(embedding_type): - settings.embedding_type = embedding_type - if embedding_type == "openai": - with gr.Row(): - embedding_config_input = [ - gr.Textbox(value=settings.openai_api_key, label="api_key"), - gr.Textbox(value=settings.openai_api_base, label="api_base"), - gr.Textbox(value=settings.openai_embedding_model, label="model_name") - ] - elif embedding_type == "ollama": - with gr.Row(): - embedding_config_input = [ - 
gr.Textbox(value=settings.ollama_host, label="host"), - gr.Textbox(value=str(settings.ollama_port), label="port"), - gr.Textbox(value=settings.ollama_embedding_model, label="model_name"), - ] - else: - embedding_config_input = [] - embedding_config_button = gr.Button("Apply Configuration") - - def apply_configuration(arg1, arg2, arg3): - embedding_option = settings.embedding_type - if embedding_option == "openai": - settings.openai_api_key = arg1 - settings.openai_api_base = arg2 - settings.openai_embedding_model = arg3 - elif embedding_option == "ollama": - settings.ollama_host = arg1 - settings.ollama_port = int(arg2) - settings.ollama_embedding_model = arg3 - gr.Info("configured!") - # pylint: disable=no-member - embedding_config_button.click(apply_configuration, inputs=embedding_config_input) - - gr.Markdown("## Build Example Vector Index") - gr.Markdown("Uploaded json file should be in format below:\n\n" - "[{\"query\":\"who is peter\", \"gremlin\":\"g.V().has('name', 'peter')\"}]") - with gr.Row(): - file = gr.File(label="Upload Example Query-Gremlin Pairs Json") - out = gr.Textbox(label="Result Message") - with gr.Row(): - btn = gr.Button("Build Example Vector Index") - btn.click(build_example_vector_index, inputs=[file], outputs=[out]) # pylint: disable=no-member - gr.Markdown("## Nature Language To Gremlin") - SCHEMA = """{ - "vertices": [ - {"vertex_label": "entity", "properties": []} - ], - "edges": [ - { - "edge_label": "relation", - "source_vertex_label": "entity", - "target_vertex_label": "entity", - "properties": {} - } - ] -}""" - with gr.Row(): - with gr.Column(scale=1): - schema_box = gr.Textbox(value=SCHEMA, label="Schema") - with gr.Column(scale=1): - input_box = gr.Textbox(value="Tell me about Al Pacino.", - label="Nature Language Query") - match = gr.Textbox(label="Best-Matched Examples") - out = gr.Textbox(label="Structured Query Language: Gremlin") - with gr.Column(scale=1): - use_example_radio = gr.Radio(choices=["true", "false"], 
value="false", - label="Use example") - use_schema_radio = gr.Radio(choices=["true", "false"], value="false", - label="Use schema") - example_num_slider = gr.Slider( - minimum=1, - maximum=10, - step=1, - value=5, - label="Number of examples" - ) - btn = gr.Button("Text2Gremlin") - btn.click( # pylint: disable=no-member - fn=gremlin_generate, - inputs=[input_box, use_schema_radio, use_example_radio, example_num_slider, schema_box], - outputs=[match, out] - ) - app = gr.mount_gradio_app(app, demo, path="/") - uvicorn.run(app, host="0.0.0.0", port=8002) diff --git a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py index a8ea1156..b70b03a9 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py @@ -65,7 +65,7 @@ def extract_graph(input_file, input_text, schema, example_prompt) -> str: builder.import_schema(from_hugegraph=schema) else: return "ERROR: please input with correct schema/format." 
- builder.chunk_split(texts, "document", "zh").extract_info(example_prompt, "property_graph") + builder.chunk_split(texts, "document", "zh").extract_info(example_prompt, "triples") try: context = builder.run() From 8ae36f5bc4cb1845c167b607d87641dd64d8bfc0 Mon Sep 17 00:00:00 2001 From: vichayturen <1073931273@qq.com> Date: Fri, 15 Nov 2024 11:08:03 +0800 Subject: [PATCH 5/8] fix(llm): modified return type of VectorIndex.search function --- hugegraph-llm/src/hugegraph_llm/indices/vector_index.py | 4 ++-- .../src/hugegraph_llm/operators/index_op/semantic_id_query.py | 1 - .../hugegraph_llm/operators/index_op/vector_index_query.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py b/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py index 3732a9f0..4ba19832 100644 --- a/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py +++ b/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py @@ -18,7 +18,7 @@ import os import pickle as pkl from copy import deepcopy -from typing import List, Dict, Any, Set, Union +from typing import List, Any, Set, Union import faiss import numpy as np @@ -85,7 +85,7 @@ def remove(self, props: Union[Set[Any], List[Any]]) -> int: self.properties = [p for i, p in enumerate(self.properties) if i not in indices] return remove_num - def search(self, query_vector: List[float], top_k: int, dis_threshold: float = 0.9) -> List[Dict[str, Any]]: + def search(self, query_vector: List[float], top_k: int, dis_threshold: float = 0.9) -> List[Any]: if self.index.ntotal == 0: return [] diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py index d7b5b89e..3253608e 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py @@ -99,6 +99,5 @@ def _fuzzy_match_vids(self, keywords: 
List[str]) -> List[str]: keyword_vector = self.embedding.get_text_embedding(keyword) results = self.vector_index.search(keyword_vector, top_k=self.topk_per_keyword) if results: - # FIXME: type mismatch, got 'list[dict[str, Any]]' instead fuzzy_match_result.extend(results[:self.topk_per_keyword]) return fuzzy_match_result diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py index e845b615..2fbe82a7 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py @@ -34,7 +34,7 @@ def __init__(self, embedding: BaseEmbedding, topk: int = 3): def run(self, context: Dict[str, Any]) -> Dict[str, Any]: query = context.get("query") query_embedding = self.embedding.get_text_embedding(query) - results = self.vector_index.search(query_embedding, self.topk) + results = self.vector_index.search(query_embedding, self.topk, dis_threshold=2) # TODO: check format results context["vector_result"] = results From 0d45f2a1b67e4b885aa4b5b3d6bb77805514e9b4 Mon Sep 17 00:00:00 2001 From: vichayturen <1073931273@qq.com> Date: Fri, 15 Nov 2024 11:38:15 +0800 Subject: [PATCH 6/8] fix(llm): added information of extracted vertices in text2gremlin --- .../hugegraph_llm/operators/hugegraph_op/graph_rag_query.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py index 3506203f..37e9e4c0 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py @@ -91,6 +91,8 @@ GREMLIN_GENERATE_TPL = """\ Given the graph schema: {schema} +Given the extracted vertex vid: +{vertices} Generate gremlin from the following user input. 
{query} The generated gremlin is:""" @@ -158,6 +160,7 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]: query = context["query"] + vertices = context.get("match_vids") query_embedding = context.get("query_embedding") if query_embedding is None: query_embedding = self._embedding.get_text_embedding(query) @@ -171,6 +174,7 @@ def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]: ) prompt += GREMLIN_GENERATE_TPL.format( schema=json.dumps(context["schema"], ensure_ascii=False), + vertices="\n".join([f"- {vid}" for vid in vertices]), query=query ) From d4be8788516078915e909d206099bc74375fe3f0 Mon Sep 17 00:00:00 2001 From: vichayturen <1073931273@qq.com> Date: Sun, 17 Nov 2024 19:34:29 +0800 Subject: [PATCH 7/8] fix(llm): fix the text2gremlin logic --- .../demo/rag_demo/text2gremlin_block.py | 12 +-- .../operators/gremlin_generate_task.py | 8 +- .../operators/hugegraph_op/graph_rag_query.py | 4 + .../index_op/gremlin_example_index_query.py | 8 +- .../operators/llm_op/gremlin_generate.py | 76 ++++++++++++------- 5 files changed, 67 insertions(+), 41 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py index c4e52497..e75b16f8 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py @@ -46,14 +46,14 @@ def gremlin_generate(inp, use_schema, use_example, example_num, schema): llm=LLMs().get_llm(), embedding=Embeddings().get_embedding(), ) - if use_example == "true": - generator.example_index_query(inp, example_num) - context = generator.gremlin_generate(use_schema, use_example, schema).run() + if use_example: + generator.example_index_query(example_num) + context = generator.gremlin_generate(use_schema, use_example, schema).run(query=inp) return 
context.get("match_result", "No Results"), context["result"] def create_text2gremlin_block() -> list: - gr.Markdown("""## 4. Text2gremlin Tools """) + gr.Markdown("""## Text2gremlin Tools """) gr.Markdown("## Build Example Vector Index") gr.Markdown("Uploaded json file should be in format below:\n" @@ -89,9 +89,9 @@ def create_text2gremlin_block() -> list: match = gr.Textbox(label="Best-Matched Examples") out = gr.Textbox(label="Structured Query Language: Gremlin") with gr.Column(scale=1): - use_example_radio = gr.Radio(choices=["true", "false"], value="false", + use_example_radio = gr.Radio(choices=[True, False], value=False, label="Use example") - use_schema_radio = gr.Radio(choices=["true", "false"], value="false", + use_schema_radio = gr.Radio(choices=[True, False], value=False, label="Use schema") example_num_slider = gr.Slider( minimum=1, diff --git a/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py b/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py index 772dcd22..4b349130 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py @@ -37,8 +37,8 @@ def example_index_build(self, examples): self.operators.append(BuildGremlinExampleIndex(self.embedding, examples)) return self - def example_index_query(self, query, num_examples): - self.operators.append(GremlinExampleIndexQuery(query, self.embedding, num_examples)) + def example_index_query(self, num_examples): + self.operators.append(GremlinExampleIndexQuery(self.embedding, num_examples)) return self def gremlin_generate(self, use_schema, use_example, schema): @@ -51,8 +51,8 @@ def print_result(self): @log_time("total time") @record_qps - def run(self): - context = {} + def run(self, **kwargs): + context = kwargs for operator in self.operators: context = self._run_operator(operator, context) return context diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py index 37e9e4c0..ec5a6742 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py @@ -90,7 +90,9 @@ GREMLIN_GENERATE_TPL = """\ Given the graph schema: +```json {schema} +``` Given the extracted vertex vid: {vertices} Generate gremlin from the following user input. @@ -172,6 +174,8 @@ def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]: example_query=match_result[0]["query"], example_gremlin=match_result[0]["gremlin"] ) + else: + log.warning("No matching example found, generate gremlin with no example.") prompt += GREMLIN_GENERATE_TPL.format( schema=json.dumps(context["schema"], ensure_ascii=False), vertices="\n".join([f"- {vid}" for vid in vertices]), diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py index a04be6f4..6c2015de 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py @@ -25,15 +25,15 @@ class GremlinExampleIndexQuery: - def __init__(self, query: str, embedding: BaseEmbedding, num_examples: int = 1): - self.query = query + def __init__(self, embedding: BaseEmbedding, num_examples: int = 1): self.embedding = embedding self.num_examples = num_examples self.index_dir = os.path.join(resource_path, "gremlin_examples") self.vector_index = VectorIndex.from_index_file(self.index_dir) def run(self, context: Dict[str, Any]) -> Dict[str, Any]: - context["query"] = self.query - query_embedding = self.embedding.get_text_embedding(self.query) + query = context.get("query") + assert query, "query is required" + query_embedding = 
self.embedding.get_text_embedding(query) context["match_result"] = self.vector_index.search(query_embedding, self.num_examples, dis_threshold=2) return context diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py index f117a5ec..2eb0ba46 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py @@ -18,7 +18,7 @@ import re import json -from typing import Optional, List, Dict, Any +from typing import Optional, List, Dict, Any, Union from hugegraph_llm.models.llms.base import BaseLLM @@ -28,46 +28,65 @@ def gremlin_examples(examples: List[Dict[str, str]]) -> str: for example in examples: example_strings.append( f"- query: {example['query']}\n" - f"- gremlin: {example['gremlin']}") + f"- gremlin:\n```gremlin\n{example['gremlin']}\n```") return "\n\n".join(example_strings) -def gremlin_generate_prompt(inp: str) -> str: - return f"""Generate gremlin from the following user input. -The output format must be: "gremlin: generated gremlin". +def gremlin_generate_prompt(query: str) -> str: + return f"""\ +Generate gremlin from the following user query. +The output format must be like: +```gremlin +g.V().limit(10) +``` -The query is: {inp}""" +Generate gremlin from the following user query. +{query} +The generated gremlin is:""" -def gremlin_generate_with_schema_prompt(schema: str, inp: str) -> str: - return f"""Given the graph schema: +def gremlin_generate_with_schema_prompt(schema: str, query: str) -> str: + return f"""\ +Given the graph schema: {schema} -Generate gremlin from the following user input. -The output format must be: "gremlin: generated gremlin". +Generate gremlin from the following user query. +The output format must be like: +```gremlin +g.V().limit(10) +``` -The query is: {inp}""" +Generate gremlin from the following user query. 
+{query} +The generated gremlin is:""" -def gremlin_generate_with_example_prompt(example: str, inp: str) -> str: +def gremlin_generate_with_example_prompt(example: str, query: str) -> str: return f"""Given the example query-gremlin pairs: {example} -Generate gremlin from the following user input. -The output format must be: "gremlin: generated gremlin". +Generate gremlin from the following user query. +The output format must be like: +```gremlin +g.V().limit(10) +``` -The query is: {inp}""" +Generate gremlin from the following user query. +{query} +The generated gremlin is:""" -def gremlin_generate_with_schema_and_example_prompt(schema: str, example: str, inp: str) -> str: - return f"""Given the graph schema: -{schema} +def gremlin_generate_with_schema_and_example_prompt(schema: str, example: str, query: str) -> str: + return f"""\ Given the example query-gremlin pairs: {example} -Generate gremlin from the following user input. -The output format must be: "gremlin: generated gremlin". - -The query is: {inp}""" +Given the graph schema: +```json +{schema} +``` +Generate gremlin from the following user query. 
+{query} +The generated gremlin is:""" class GremlinGenerate: @@ -76,25 +95,28 @@ def __init__( llm: BaseLLM, use_schema: bool = False, use_example: bool = False, - schema: Optional[dict] = None + schema: Optional[Union[dict, str]] = None ) -> None: self.llm = llm self.use_schema = use_schema self.use_example = use_example + if isinstance(schema, dict): + schema = json.dumps(schema, ensure_ascii=False) self.schema = schema def run(self, context: Dict[str, Any]) -> Dict[str, Any]: query = context.get("query", "") + assert query, "query is required" examples = context.get("match_result", []) if not self.use_schema and not self.use_example: prompt = gremlin_generate_prompt(query) elif not self.use_schema and self.use_example: prompt = gremlin_generate_with_example_prompt(gremlin_examples(examples), query) elif self.use_schema and not self.use_example: - prompt = gremlin_generate_with_schema_prompt(json.dumps(self.schema), query) + prompt = gremlin_generate_with_schema_prompt(self.schema, query) else: prompt = gremlin_generate_with_schema_and_example_prompt( - json.dumps(self.schema), + self.schema, gremlin_examples(examples), query ) @@ -105,10 +127,10 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: return context def _extract_gremlin(self, response: str) -> str: - match = re.search(r'gremlin[::][^\n]+\n?', response) + match = re.search("```gremlin.*```", response, re.DOTALL) if match is None: return "Unable to generate gremlin from your query." 
- return match.group()[len("gremlin:"):].strip() + return match.group()[len("```gremlin"):-len("```")].strip() if __name__ == '__main__': From 0f129678489807f854c8d6ca1070f2e3e649a4ba Mon Sep 17 00:00:00 2001 From: vichayturen <1073931273@qq.com> Date: Wed, 20 Nov 2024 17:35:19 +0800 Subject: [PATCH 8/8] fix(llm): added logic to automatically build text2gremlin example index --- .../operators/hugegraph_op/graph_rag_query.py | 21 +++- .../resources/demo/text2gremlin.csv | 99 +++++++++++++++++++ 2 files changed, 117 insertions(+), 3 deletions(-) create mode 100644 hugegraph-llm/src/hugegraph_llm/resources/demo/text2gremlin.csv diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py index ae179a7d..ef68f67c 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py @@ -20,6 +20,7 @@ import re from typing import Any, Dict, Optional, List, Set, Tuple +import pandas as pd from hugegraph_llm.config import settings, resource_path from hugegraph_llm.indices.vector_index import VectorIndex from hugegraph_llm.models.embeddings.base import BaseEmbedding @@ -122,8 +123,6 @@ def __init__( self._max_items = max_items self._prop_to_match = prop_to_match self._schema = "" - self._index_dir = os.path.join(resource_path, "gremlin_examples") - self._vector_index = VectorIndex.from_index_file(self._index_dir) self._llm = llm self._embedding = embedding @@ -167,7 +166,8 @@ def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]: if query_embedding is None: query_embedding = self._embedding.get_text_embedding(query) - match_result = self._vector_index.search(query_embedding, top_k=1, dis_threshold=2) + vector_index = self._get_gremlin_example_index() + match_result = vector_index.search(query_embedding, top_k=1, dis_threshold=2) prompt = "" if 
match_result: prompt += GREMLIN_GENERATE_EXAMPLE_OPTION_TPL.format( @@ -279,6 +279,21 @@ def _subgraph_query(self, context: Dict[str, Any]) -> Dict[str, Any]: log.debug("\n".join(context["graph_result"])) return context + def _get_gremlin_example_index(self) -> VectorIndex: + index_dir = os.path.join(resource_path, "gremlin_examples") + if not (os.path.exists(os.path.join(index_dir, "index.faiss")) + and os.path.exists(os.path.join(index_dir, "properties.pkl"))): + log.warning("No gremlin example index found, will generate one.") + properties = (pd.read_csv(os.path.join(resource_path, "demo", "text2gremlin.csv")) + .to_dict(orient="records")) + embeddings = [self._embedding.get_text_embedding(row["query"]) for row in properties] + vector_index = VectorIndex(len(embeddings[0])) + vector_index.add(embeddings, properties) + vector_index.to_index_file(index_dir) + else: + vector_index = VectorIndex.from_index_file(index_dir) + return vector_index + def _format_graph_from_vertex(self, query_result: List[Any]) -> Set[str]: knowledge = set() for item in query_result: diff --git a/hugegraph-llm/src/hugegraph_llm/resources/demo/text2gremlin.csv b/hugegraph-llm/src/hugegraph_llm/resources/demo/text2gremlin.csv new file mode 100644 index 00000000..9b7f07ea --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/resources/demo/text2gremlin.csv @@ -0,0 +1,99 @@ +query,gremlin +腾讯适合合作吗,"g.V().has('company','name','腾讯').as('a').project('公司信息','法人','对外投资企业数量','投资人-自然人','高管','投资人-公司','最终受益人-自然人','最终受益人-公司').by(valueMap('description', 'email', 'phone', 'operatingStatus', 'registrationAddress', 'salaryTreatment', 'registeredCapital', 'registeredCapitalCurrency', 
'financingInformation')).by(select('a').in('legalPerson').values('name')).by(select('a').out('companyInvest').values('name').count()).by(select('a').in('personInvest').values('name').fold()).by(select('a').in('companyInvest').values('name').fold()).by(select('a').in('serve').limit(3).values('name').fold()).by(select('a').in('finalBeneficiaryPerson').values('name').limit(3).fold()).by(select('a').in('finalBeneficiaryCompany').limit(3).values('name').fold())" +四川省有哪些企业?,"g.V().has('company','province','四川').limit(20).values('name')" +腾讯是在哪年成立的,"g.V().has('company','name','腾讯').values('establishmentYear')" +给我一份2011年成立的公司名单,"g.V().has('company','establishmentYear',2011).limit(20).values('name')" +5278600是哪个企业的电话?经常骚扰,"g.V().has('company','phone', containing('5278600')).values('name')" +北京有哪些通用航空生产服务行业的公司,"g.V().has('company','city','北京').has('industry',containing('通用航空生产服务业')).values('name')" +沈阳1999年成立的计算机外围设备制造企业有哪些,"g.V().has('company', 'city','沈阳').has('establishmentYear', 1999).has('industry',containing('计算机外围设备制造')).limit(20).values('name')" +沈阳2000年以前成立的计算机外围设备制造企业有哪些,"g.V().has('company', 'city','沈阳').has('establishmentYear', lt(2000)).has('industry',containing('计算机外围设备制造')).limit(20).values('name')" +马化腾是哪家公司法人,"g.V().has('person', 'name', '马化腾').out('legalPerson').limit(20).values('name')" +马化腾是哪家公司股东,"g.V().has('person', 'name', '马化腾').out('personInvest').limit(20).values('name')" +腾讯的投资人都有谁,"g.V().has('company', 'name', '腾讯').inE('personInvest', 'companyInvest').as('s').otherV().as('p').project('shareholdingRatio', 'name').by(select('s').values('shareholdingRatio')).by(select('p').values('name'))" +腾讯的注册资本是多少,"g.V().has('company', 'name', '腾讯').valueMap('registeredCapital', 'registeredCapitalCurrency')" +腾讯的工资待遇怎么样?,"g.V().has('company', 'name', '腾讯').values('salaryTreatment')" +腾讯最近在招聘吗?,"g.V().has('company', 'name', '腾讯').values('recruitmentInfo')" +腾讯的党委书记是谁,"g.V().has('company', 'name', '腾讯').inE('serve').as('po').outV().as('pe').project('name', 
'position').by(select('po').values('position')).by(select('pe').values('name'))" +腾讯中马化腾的职务,"g.V().has('company', 'name', '腾讯').inE('serve').where(outV().has('name', '马化腾')).values('position')" +腾讯股权构成和股东信息,"g.V().has('company', 'name', '腾讯').inE('personInvest', 'companyInvest').as('a').outV().project('name', 'info').by(values('name')).by(select('a').valueMap())" +腾讯的地址和法人信息,"g.V().has('company','name','腾讯').as('a').in('legalPerson').project('registrationAddress', 'legalPerson').by(select('a').values('registrationAddress')).by(values('name'))" +腾讯和美团的关系是什么,"g.V().has('company','name','腾讯').bothE().where(otherV().has('company','name', '美团')).label()" +腾讯的马化腾,"g.V().has('company', 'name', '腾讯').in().has('person', 'name', '马化腾').outE('legalPerson','actualControllerPerson','personInvest','serve').as('a').limit(20).inV().as('b').project('edge', 'name').by(select('a').label()).by(select('b').values('name'))" +腾讯的统一社会信用代码,"g.V().has('company', 'name', '腾讯').values('unifiedCreditCode')" +腾讯的员工工资是多少,"g.V().hasLabel('company').has('name', '腾讯').values('salaryTreatment')" +腾讯详细工商信息,"g.V().has('company', 'name', '腾讯').match( __.as('a').valueMap('name', 'registrationAddress', 'registeredCapital', 'industry', 'businessScope').fold().as('company'), __.as('a').in('legalPerson').values('name').fold().as('legalPerson'),__.as('a').out('branch').values('name').fold().as('branch'),__.as('a').in('personInvest').values('name').fold().as('personInvest'),__.as('a').in('companyInvest').values('name').fold().as('companyInvest')).select('company','legalPerson', 'branch', 'personInvest', 'companyInvest')" +腾讯的马化腾和张志东是合作伙伴关系吗?,"g.V().has('person', 'name','马化腾').out('partners').has('person','name','张志东').hasNext()" +腾讯的马化腾的合作伙伴都有谁?,"g.V().has('person', 'name','马化腾').out('partners').limit(20).values('name')" +马化腾都投资了哪些公司?,"g.V().has('person', 'name','马化腾').out('personInvest').limit(20).values('name')" +腾讯都控股了哪些公司,且控股比例是多少?,"g.V().has('company', 'name', 
'腾讯').outE('controllingShareholderCompany').as('a').inV().project('name', 'info').by(values('name')).by(select('a').valueMap())" +马化腾在腾讯的持股比例是多少?,"g.V().has('company', 'name', '腾讯').inE('personInvest').where(outV().has('person','name','马化腾')).values('shareholdingRatio')" +马化腾在腾讯的认缴出资额是多少?,"g.V().has('company', 'name', '腾讯').inE('personInvest').where(outV().has('person','name','马化腾')).valueMap('capitalContribution','unitOfContribution')" +马化腾在腾讯的认缴出资日期是什么时候?,"g.V().has('company', 'name', '腾讯').inE('personInvest').where(outV().has('person','name','马化腾')).values('contributionDate')" +腾讯实际人数与缴纳社保人数多少,"g.V().has('company', 'name', '腾讯').valueMap('insuredNumberOfPeople','numberOfEmployees')" +腾讯已实缴的资金有多少,"g.V().has('company', 'name', '腾讯').valueMap('paidInCapital','paidInCapitalCurrency')" +腾讯有多少家子公司,"g.V().has('company', 'name', '腾讯').out('branch').count()" +"腾讯有多少员工,主要做什么的","g.V().has('company', 'name', '腾讯').valueMap('numberOfEmployees', 'businessScope')" +腾讯的实际控制人和法人是谁?,"g.V().has('company', 'name', '腾讯').inE().hasLabel('legalPerson', 'actualControllerPerson', 'actualControllerCompany').as('a').outV().as('b').project('name', 'label').by(select('b').values('name')).by(select('a').label())" +制造业有哪些公司上海,"g.V().has('company', 'city','上海').has('industry',containing('制造业')).limit(20).values('name')" +2000年之前成立的企业有哪些,"g.V().has('company', 'establishmentYear', lt(2000)).values('name')" +和腾讯有关系的公司和人有哪些,"g.V().has('company', 'name', '腾讯').bothE().limit(20).as('a').otherV().as('b').project('name', 'label').by(select('b').values('name')).by(select('a').label())" +2008年到2015年之间成立的企业有哪些,"g.V().has('company','establishmentYear', between(2008, 2015)).limit(20).values('name')" +腾讯有官方的微信公众号吗?,"g.V().has('company', 'name', '腾讯').values('wechatPublicNumber')" +腾讯旗下有几家公司,"g.V().has('company','name','腾讯').as('a').match(__.as('a').out('branch').values('name').count().fold().as('num'),__.as('a').out('branch').values('name').fold().as('company_name')).select('num','company_name')" 
+腾讯的子公司或分公司有哪些,"g.V().has('company','name','腾讯').as('a').match(__.as('a').out('branch').values('name').count().fold().as('num'),__.as('a').out('branch').limit(10).values('name').fold().as('company_name')).select('num','company_name')" +腾讯的法人代表是谁,该公司注册资本是多少?,"g.V().has('company', 'name', '腾讯').as('a').project('legalPerson','registeredCapital').by(select('a').in('legalPerson').values('name')).by(valueMap('registeredCapital','registeredCapitalCurrency'))" +腾讯各股东占比,"g.V().has('company', 'name', '腾讯').inE('personInvest', 'companyInvest').as('a').outV().as('b').project('name', 'info').by(select('b').values('name')).by(select('a').valueMap())" +腾讯的主要负责人,"g.V().has('company', 'name', '腾讯').inE('legalPerson','actualControllerPerson','serve').as('a').outV().project('label','name').by(select('a').label()).by(values('name'))" +腾讯的实际控制人对其股权比例为多少,"g.V().has('company', 'name', '腾讯').match(__.as('a').in('actualControllerPerson').as('b'),__.as('a').inE('personInvest').as('c').outV().as('d'),where('b', eq('d')).by('name')).project('name', 'info').by(select('d').values('name')).by(select('c').valueMap())" +"腾讯都控股了哪些公司,其认缴出资额分别是多少?","g.V().has('company', 'name', '腾讯').outE('controllingShareholderCompany').as('e').inV().as('v').project('name','info').by(select('v').values('name')).by(select('e').valueMap())" +腾讯的马化腾还在其他哪些公司任职?,"g.V().has('company', 'name', '腾讯').as('a').in().has('person','name', '马化腾').out('serve').where(neq('a')).values('name')" +腾讯的老板信息,"g.V().has('company', 'name', '腾讯').inE().hasLabel('legalPerson','actualControllerPerson', 'actualControllerCompany').as('a').outV().as('b').project('label', 'name').by(select('a').label()).by(select('b').values('name'))" +腾讯和美团的关系,"g.V().has('company', 'name', '腾讯').bothE().where(otherV().has('company', 'name', '美团')).label()" +腾讯的法定代表人马化腾有几家公司,"g.V().has('company', 'name', '腾讯').in('actualControllerPerson', 'legalPerson').has('person', 'name', 
'马化腾').outE('legalPerson','actualControllerPerson','personInvest','serve').as('a').limit(20).inV().as('b').project('edge', 'name').by(select('a').label()).by(select('b').values('name'))" +腾讯的马化腾和王兴的股份比例,"g.V().has('company', 'name', '腾讯').inE('personInvest').as('a').outV().has('name',within('马化腾', '王兴')).as('b').project('name','shareholdingRatio').by(select('b').values('name')).by(select('a').values('shareholdingRatio'))" +腾讯和美团的关系,"g.V().has('name', '腾讯').bothE().where(otherV().has('name', '美团')).label()" +腾讯的股东马化腾的相关信息,"g.V().has('company', 'name', '腾讯').in('personInvest').has('person', 'name', '马化腾').outE('legalPerson','actualControllerPerson','personInvest','serve').as('a').limit(20).inV().as('b').project('edge', 'name').by(select('a').label()).by(select('b').values('name'))" +腾讯在招战略分析岗位,"g.V().has('company','name','腾讯').valueMap('recruitmentInfo')" +腾讯和中科院有关系吗,"g.V().has('name', '腾讯').bothE().where(otherV().has('name', '中科院')).label()" +腾讯的社保信息,"g.V().has('company', 'name', '腾讯').valueMap('insuredNumberOfPeople', 'unifiedCreditCode', 'industry', 'taxpayerIdentificationNumber','administrativeDivision', 'province')" +腾讯的知识产权情况,"g.V().has('company', 'name', '腾讯').valueMap('copyrightForWorks','websiteRegistrationRecord','patentInformation')" +腾讯的法定代表人马化腾的背景,"g.V().has('company', 'name', '腾讯').in('legalPerson').has('person', 'name', '马化腾').outE('legalPerson','actualControllerPerson','personInvest','serve').as('a').limit(20).inV().as('b').project('edge', 'name').by(select('a').label()).by(select('b').values('name'))" +腾讯马化腾共持股多少,"g.V().has('company', 'name', '腾讯').inE('personInvest').where(outV().has('name','马化腾')).valueMap()" +腾讯的老板成立过哪些公司,"g.V().has('company', 'name', '腾讯').in('legalPerson').outE('legalPerson','actualControllerPerson','personInvest','serve').as('a').limit(20).inV().as('b').project('edge', 'name').by(select('a').label()).by(select('b').values('name'))" +腾讯的23条对外投资信息,"g.V().has('company', 'name', 
'腾讯').outE('companyInvest').as('a').inV().as('b').project('name', 'info').by(select('b').values('name')).by(select('a').valueMap())" +腾讯领导层,"g.V().has('company', 'name', '腾讯').inE('serve').as('a').outV().as('b').project('name','position').by(select('b').values('name')).by(select('a').values('position'))" +腾讯的马化腾还有哪些企业任职,"g.V().has('company', 'name', '腾讯').as('a').in().has('person', 'name', '马化腾').out('serve').where(neq('a')).values('name')" +腾讯黄永刚注册资本,"g.V().has('company', 'name', '腾讯').inE('personInvest').where(outV().has('person', 'name','黄永刚')).valueMap('capitalContribution','unitOfContribution')" +腾讯的16家控股企业都是哪16家,"g.V().has('company', 'name', '腾讯').out('controllingShareholderCompany').values('name').limit(16)" +腾讯历任董事会秘书,"g.V().has('company', 'name', '腾讯').inE('serve').has('position', containing('董事会秘书')).outV().values('name')" +腾讯的十大股东是谁,"g.V().has('company', 'name', '腾讯').inE('personInvest', 'companyInvest').order().by('shareholdingRatio',desc).limit(10).as('a').outV().as('b').project('name','shareholdingRatio').by(select('b').values('name')).by(select('a').values('shareholdingRatio'))" +腾讯的法人和实际控制人都是谁,"g.V().has('company', 'name', '腾讯').as('a') +.project('法人', '实际控制人') +.by(__.in('legalPerson').values('name').fold()) +.by(__.in('actualControllerPerson','actualControllerCompany').values('name').fold()) +.select('法人','实际控制人').by(__.coalesce(identity(), __.constant('未知')))" +腾讯的法人和董事长分别是谁,"g.V().has('company', 'name', '腾讯').as('a') +.project('法人', '董事长') +.by(__.in('legalPerson').values('name').fold()) +.by(__.inE('serve').has('position', containing('董事长')).outV().values('name').fold()) +.select('法人','董事长').by(__.coalesce(identity(), __.constant('未知')))" +腾讯的CEO投资了哪些企业,"g.V().has('company', 'name', '腾讯').inE('serve').has('position', containing('CEO')).outV().outE('legalPerson','actualControllerPerson','personInvest','serve').as('a').limit(20).inV().as('b').project('edge', 'name').by(select('a').label()).by(select('b').values('name'))" 
+腾讯的总经理有控股公司吗?,"g.V().has('company', 'name', '腾讯').inE('serve').has('position', containing('总经理')).outV().out('controllingShareholderPerson').values('name')" +腾讯老板的合作伙伴有哪些,"g.V().has('company', 'name', '腾讯').in('legalPerson','actualControllerPerson').dedup().project('name','partners').by(values('name')).by(out('partners').values('name').fold())" +腾讯的最终受益人和董事长是同一个人吗?,"g.V().has('company', 'name', '腾讯').as('c') +.project('最终受益人', '董事长') +.by(__.in('finalBeneficiaryPerson').values('name').fold().coalesce(identity(), __.constant('未知'))) +.by(__.inE('serve').has('position', containing('董事长')).outV().values('name').fold().coalesce(identity(), __.constant('未知'))) +.select('最终受益人', '董事长')" +腾讯的股东、法人、最终受益人、董事、董事长分别有哪些,"g.V().has('company', 'name', '腾讯').as('a') +.project('股东','法人','最终受益人','董事','董事长') +.by(__.in('personInvest').values('name').fold()) +.by(__.in('legalPerson').values('name').fold()) +.by(__.in('actualControllerPerson').values('name').fold()) +.by(__.inE('serve').has('position', containing('董事')).outV().values('name').fold()) +.by(__.inE('serve').has('position', containing('董事长')).outV().values('name').fold())" +腾讯的关联公司,"g.V().has('company','name','腾讯').project('branch','companyInvest').by(out('branch').values('name').fold()).by(out('companyInvest').values('name').fold())" +腾讯的实缴资本和注册资本分别是多少,"g.V().has('company', 'name', '腾讯').valueMap('registeredCapital', 'registeredCapitalCurrency','paidInCapital','paidInCapitalCurrency')" +腾讯有哪些知识产权,"g.V().has('company','name','腾讯').valueMap('copyrightForWorks', 'patentInformation', 'websiteRegistrationRecord')" +腾讯的作品著作权有哪些,"g.V().has('company', 'name', '腾讯').values('copyrightForWorks')" +腾讯有哪些岗位,"g.V().has('company','name','腾讯').values('recruitmentInfo')"