# -*- coding: utf-8 -*-
# @Time : 2024/12/22
# @Author : liuboyuan
# @Description : Build and query a local FAISS knowledge base for food-additive Q&A.
import os
from api.conf import azure_model_config, gpt_model_config
from langchain_community.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_community.vectorstores import FAISS, Pinecone
from langchain.text_splitter import CharacterTextSplitter


def read_file(file_name):
    """Read the full contents of a UTF-8 text file."""
    with open(file_name, encoding='utf-8') as f:
        file_data = f.read()
    return file_data


def split_data(file_data, split_str=r'\n\n---\n\n'):
    """Split the text into chunks on the given separator (treated as a regex)."""
    text_splitter = CharacterTextSplitter(
        separator=split_str,
        chunk_size=150,
        chunk_overlap=0,
        length_function=len,
        is_separator_regex=True,
    )
    docs = text_splitter.create_documents([file_data])
    return docs
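

def _demo_split():
    """An illustrative sketch, not part of the original pipeline: shows the
    Document chunks produced for a small delimiter-separated sample. Note that
    CharacterTextSplitter may merge short neighbouring entries back into one
    chunk until chunk_size (150 characters here) is reached."""
    sample = "first entry\n\n---\n\nsecond entry"
    for doc in split_data(sample):
        print(doc.page_content)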


def embedding_text(docs, db_name):
    """Embed the documents and persist them to a local FAISS index."""
    # Alternative: Azure OpenAI embeddings.
    # AZURE_OPENAI_ENDPOINT = azure_model_config.get("api_endpoint")
    # AZURE_OPENAI_API_KEY = azure_model_config.get("api_key")
    # AZURE_OPENAI_API_VERSION = "2023-07-01-preview"
    # embedding_client = AzureOpenAIEmbeddings(
    #     azure_endpoint=AZURE_OPENAI_ENDPOINT,
    #     openai_api_key=AZURE_OPENAI_API_KEY,
    #     openai_api_version=AZURE_OPENAI_API_VERSION,
    #     chunk_size=150,
    # )
    os.environ['OPENAI_API_KEY'] = gpt_model_config.get("api_key")
    embedding_client = OpenAIEmbeddings()
    db = FAISS.from_documents(docs, embedding_client)
    if not os.path.exists(db_name):
        # The vector store does not exist yet: create and save it.
        db.save_local(db_name)
    else:
        # The vector store already exists: merge the new documents in and save.
        old_db = FAISS.load_local(db_name, embedding_client)
        old_db.merge_from(db)
        old_db.save_local(db_name)
    return True
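

# Note: recent langchain-community releases require an explicit opt-in to
# pickle deserialization when loading a local FAISS index. If FAISS.load_local
# raises an error about dangerous deserialization, pass
# allow_dangerous_deserialization=True (verify the flag against your installed
# version; it is not used in this file).


def _demo_embedding():
    """An illustrative sketch, not part of the original pipeline: the file and
    index names below are hypothetical placeholders."""
    docs = split_data(read_file("raw_data/example.txt"))
    embedding_text(docs, "vector_database_example")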


def query_db(db_name, query):
    """Use FAISS as the vector store and run a content search against it."""
    # Use the same embedding client that embedding_text used to build the index,
    # and make sure the API key is set when querying without indexing first.
    os.environ['OPENAI_API_KEY'] = gpt_model_config.get("api_key")
    db = FAISS.load_local(db_name, OpenAIEmbeddings())
    retriever = db.as_retriever(
        search_type="similarity_score_threshold",
        # Filter hits by relevance score; k sets how many results to return.
        search_kwargs={"score_threshold": 0.8, 'k': 1, "fetch_k": 2},
    )
    docs = retriever.get_relevant_documents(query)
    if docs:
        # Keep only the answer part after the ":**" delimiter used in the raw data.
        return docs[0].page_content.split(":**")[-1]
    return "No matching answer found"
    # Pinecone variant of the same lookup (kept commented out):
    # db = Pinecone.load_local(db_name, OpenAIEmbeddings())
    # retriever = db.as_retriever(
    #     search_type="similarity_score_threshold",
    #     search_kwargs={"score_threshold": 0.8, 'k': 1, "fetch_k": 2},
    # )
    # docs = retriever.get_relevant_documents(query)
    # if docs:
    #     for d in docs:
    #         return d.page_content.split(":**")[-1]
    # else:
    #     return "No matching answer found"
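

def _demo_query():
    """An illustrative sketch, not part of the original pipeline: the question
    is a hypothetical example against the food-additive index that __main__
    below writes."""
    answer = query_db("vector_database_food_additive", "苯甲酸钠有什么用途?")  # "What is sodium benzoate used for?"
    print(answer)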


def t(embeddings):
    """Smoke-test an embeddings client with an in-memory vector store."""
    from langchain_core.vectorstores import InMemoryVectorStore

    # Create a vector store with a sample text.
    text = "LangChain is the framework for building context-aware reasoning applications"
    vectorstore = InMemoryVectorStore.from_texts(
        [text],
        embedding=embeddings,
    )
    # Use the vector store as a retriever.
    retriever = vectorstore.as_retriever()
    # Retrieve the most similar text and return its content.
    retrieved_documents = retriever.invoke("What is LangChain?")
    return retrieved_documents[0].page_content


def save_to_db(file_name, db_name, split_str):
    """Main routine: read the text, split it, embed it, and store it in the vector database."""
    file_data = read_file(file_name)
    docs = split_data(file_data, split_str)
    res = embedding_text(docs, db_name)
    return res


def format_data(input_filename, output_filename):
    try:
        # Read the input file.
        with open(input_filename, 'r', encoding='utf-8') as infile:
            lines = infile.readlines()
        # Write the output file, numbering every line that contains the marker.
        with open(output_filename, 'w', encoding='utf-8') as outfile:
            for index, line in enumerate(lines, start=1):
                if "健康专家" in line.strip():
                    # Prefix the matching line with its line number.
                    outfile.write(f"{index}.\n{line.strip()}\n")
                else:
                    outfile.write(line)
        print("File processing complete.")
    except Exception as e:
        print(f"Error while processing file: {e}")


if __name__ == "__main__":
    file_name = "C:\\Project\\FischlAgent\\fischlApi\\framework\\amane_knowledge\\raw_data\\food_additive_2.txt"
    db_name = "vector_database_food_additive"
    split_str = r'\d+\.'
    save_to_db(file_name, db_name, split_str)
    # One-off data preparation: number the marker lines in the raw file.
    # format_data(file_name, "output.txt")
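    # A follow-up retrieval sketch (hypothetical question; uncomment to verify
    # that the freshly written index answers queries):
    # print(query_db(db_name, "焦糖色素有什么用途?"))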
    # Azure embeddings smoke test (commented out):
    # AZURE_OPENAI_ENDPOINT = azure_model_config.get("api_endpoint")
    # AZURE_OPENAI_API_KEY = azure_model_config.get("api_key")
    # AZURE_OPENAI_API_VERSION = "2023-07-01-preview"
    # embedding_client = AzureOpenAIEmbeddings(
    #     azure_endpoint=AZURE_OPENAI_ENDPOINT,
    #     openai_api_key=AZURE_OPENAI_API_KEY,
    #     openai_api_version=AZURE_OPENAI_API_VERSION,
    #     chunk_size=150,
    # )
    # t(embedding_client)