-
Notifications
You must be signed in to change notification settings - Fork 1
/
embedding.py
40 lines (32 loc) · 1.27 KB
/
embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
import openai
import pandas as pd
from dotenv import load_dotenv
# Embedding model sent with every request (OpenAI's best embedding model as of Apr 2023).
EMBEDDING_MODEL = "text-embedding-ada-002"
# Number of inputs packed into one API request; the endpoint accepts up to 2048.
BATCH_SIZE = 1_000

# Pull variables from a .env file (when one is found) into the environment first,
# so the credential lookups below can see them.
load_dotenv()
openai.api_key = os.environ.get("OPENAI_API_KEY")
openai.organization = os.environ.get("OPENAI_ORGANIZATION")
def embed(text, store_path):
    """Embed every string in *text* with the OpenAI API and save the result as CSV.

    Inputs are sent in batches of ``BATCH_SIZE`` to ``openai.Embedding.create``
    (the pre-1.0 ``openai`` SDK interface) using ``EMBEDDING_MODEL``.

    The output CSV has two columns:

    - ``text``: the input string.
    - ``embedding``: the embedding returned for that string, stringified by
      pandas when writing.

    Rows keep the input order.

    Args:
        text: The strings to embed. A pandas Series or any other iterable of
            strings, such as a list, is accepted.
        store_path: Destination CSV path. Its parent directory must already
            exist.

    Raises:
        RuntimeError: If an API response does not contain exactly one embedding
            per input of the batch, with indices ``0..len(batch)-1``.
    """
    # Materialise once. Slicing is then always positional, even for a Series
    # with a non-default index, and plain lists work as well.
    texts = list(text)
    embeddings = []
    for batch_start in range(0, len(texts), BATCH_SIZE):
        batch = texts[batch_start:batch_start + BATCH_SIZE]
        response = openai.Embedding.create(
            model=EMBEDDING_MODEL,
            input=batch,
        )
        # Each result carries the position of its input in "index". Sort on it
        # rather than trusting the response order, then check that every input
        # got exactly one embedding. This is an explicit raise, not an assert,
        # so the check still runs under `python -O` and fails on the bad batch
        # instead of at the final DataFrame build.
        data = sorted(response["data"], key=lambda item: item["index"])
        if [item["index"] for item in data] != list(range(len(batch))):
            raise RuntimeError(
                "Embedding response does not match request for inputs "
                f"{batch_start}-{batch_start + len(batch) - 1}"
            )
        embeddings.extend(item["embedding"] for item in data)
    df = pd.DataFrame({"text": texts, "embedding": embeddings})
    df.to_csv(store_path, index=False)
if __name__ == "__main__":
    # Embed every CSV in csv_to_embed/ into a CSV of the same name in embedding/.
    input_dir = "csv_to_embed"
    output_dir = "embedding"
    # DataFrame.to_csv does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)
    # Sort for a deterministic processing order. Only regular *.csv files are
    # read, so stray entries such as sub-directories or .DS_Store don't crash
    # read_csv.
    for name in sorted(os.listdir(input_dir)):
        src_path = os.path.join(input_dir, name)
        if not (os.path.isfile(src_path) and name.lower().endswith(".csv")):
            continue
        # Collapse each row into one string: drop empty (NaN) cells, stringify
        # the rest and join them with single spaces.
        text = pd.read_csv(src_path).apply(
            lambda row: " ".join(row.dropna().astype(str)), axis=1
        )
        embed(text, os.path.join(output_dir, name))