1 - Preparing Data for RAG.py
# Databricks notebook source
# MAGIC %run ./Includes/Setup
# COMMAND ----------
# MAGIC %md
# MAGIC # Data Exploration
# COMMAND ----------
# Copy the source PDFs from S3 into DBFS (download_dataset is defined in ./Includes/Setup)
source_path = "s3://dalhussein-odsc/papers/"
articles_path = "dbfs:/mnt/odsc/papers"
download_dataset(source_path, articles_path)
# COMMAND ----------
files = dbutils.fs.ls(articles_path)
display(files)
# COMMAND ----------
# MAGIC %md
# MAGIC # Bronze
# COMMAND ----------
# Ingest the PDF files as binary content; keep the file path as the document URI
df_raw = (spark.read.format("binaryFile")
          .load(articles_path)
          .withColumnRenamed("path", "doc_uri")
)
display(df_raw)
# COMMAND ----------
bronze_table_name = f"{catalog_name}.{schema_name}.bronze_articles_raw"
df_raw.write.mode("overwrite").saveAsTable(bronze_table_name)
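# COMMAND ----------
# MAGIC %md
# MAGIC Optional sanity check (an addition, not part of the original flow): confirm the bronze table landed and how many PDFs were ingested. It only uses the `bronze_table_name` defined above and the columns produced by the `binaryFile` reader.
# COMMAND ----------
# Quick check: row count and a peek at the ingested files
print(spark.table(bronze_table_name).count(), "PDF files ingested into", bronze_table_name)
display(spark.table(bronze_table_name).select("doc_uri", "length", "modificationTime"))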
# COMMAND ----------
# MAGIC %md
# MAGIC # Silver
# COMMAND ----------
import fitz  # PyMuPDF


def pdf_to_text(pdf_content):
    """
    Convert PDF binary content to plain text using PyMuPDF.
    """
    doc = fitz.open(stream=pdf_content, filetype="pdf")
    text = ""
    for page in doc:
        text += page.get_text()
    return text
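# COMMAND ----------
# MAGIC %md
# MAGIC A small sketch (not in the original notebook) to sanity-check `pdf_to_text` on a single document before applying it at scale. It assumes `df_raw` from the Bronze step is still in scope.
# COMMAND ----------
# Sanity check: parse one PDF locally on the driver and preview the extracted text
sample_pdf = df_raw.select("content").limit(1).collect()[0]["content"]
sample_text = pdf_to_text(sample_pdf)
print(sample_text[:500])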
# COMMAND ----------
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# Wrap the parser in a UDF and extract the text from each PDF's binary content
pdf_to_text_udf = udf(pdf_to_text, StringType())

df_parsed = (df_raw.withColumn("text", pdf_to_text_udf("content"))
             .select("doc_uri", "text")
)
display(df_parsed)
# COMMAND ----------
silver_table_name = f"{catalog_name}.{schema_name}.silver_articles_parsed"
df_parsed.write.mode("overwrite").saveAsTable(silver_table_name)
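# COMMAND ----------
# MAGIC %md
# MAGIC Optional check (an addition, not part of the original notebook): look at the extracted text length per document to spot PDFs that parsed to empty or near-empty text.
# COMMAND ----------
from pyspark.sql.functions import length

# Character count of the parsed text for each document
display(spark.table(silver_table_name).select("doc_uri", length("text").alias("num_chars")))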
# COMMAND ----------
# MAGIC %md
# MAGIC # Gold
# COMMAND ----------
import pandas as pd
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import Document
from llama_index.core.utils import set_global_tokenizer
from transformers import AutoTokenizer
from typing import Iterator
from pyspark.sql.functions import pandas_udf, explode
@pandas_udf("array<string>")
def read_as_chunk(batch_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # Set the llama tokenizer so chunk sizes are measured in llama tokens
    set_global_tokenizer(
        AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    )
    # Sentence splitter from llama_index to split on sentence boundaries
    splitter = SentenceSplitter(chunk_size=500, chunk_overlap=50)

    def extract_and_split(col):
        nodes = splitter.get_nodes_from_documents([Document(text=col)])
        return [n.text for n in nodes]

    for x in batch_iter:
        yield x.apply(extract_and_split)
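# COMMAND ----------
# MAGIC %md
# MAGIC A minimal sketch (not in the original notebook) of the same chunking logic run locally on one document, useful for eyeballing chunk boundaries before running the pandas UDF over the whole Silver table. It reuses `df_parsed` and the splitter settings above.
# COMMAND ----------
# Chunk a single parsed document on the driver with the same splitter settings
set_global_tokenizer(AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer"))
local_splitter = SentenceSplitter(chunk_size=500, chunk_overlap=50)

sample_doc = df_parsed.select("text").limit(1).collect()[0]["text"]
sample_chunks = [n.text for n in local_splitter.get_nodes_from_documents([Document(text=sample_doc)])]

print(f"{len(sample_chunks)} chunks from one document; first chunk:\n")
print(sample_chunks[0][:300])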
# COMMAND ----------
gold_table_name = f"{catalog_name}.{schema_name}.gold_articles_chunks"
spark.sql(f"""
  CREATE TABLE IF NOT EXISTS {gold_table_name} (
    chunk_id BIGINT GENERATED BY DEFAULT AS IDENTITY,
    chunked_text STRING,
    doc_uri STRING
    -- NOTE: Change Data Feed must be enabled on this table because the Vector Search
    -- sync pipeline requires CDC state on its source table
  ) TBLPROPERTIES (delta.enableChangeDataFeed = true);
""")
# COMMAND ----------
# Explode each document's chunk array into one row per chunk
df_chunks = (df_parsed.withColumn("chunked_text", explode(read_as_chunk("text")))
             .select("chunked_text", "doc_uri")
)
display(df_chunks)
# COMMAND ----------
df_chunks.write.mode("overwrite").saveAsTable(gold_table_name)
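# COMMAND ----------
# MAGIC %md
# MAGIC Optional check (an addition, not in the original notebook): count chunks per source document to verify the explode produced a reasonable number of rows per paper.
# COMMAND ----------
# Chunks per document in the gold table
display(spark.table(gold_table_name).groupBy("doc_uri").count().orderBy("count", ascending=False))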