From 4390d9d200c312cc1f542b5ed38eece40345cf85 Mon Sep 17 00:00:00 2001
From: Shreyas Shirish Agrawal <48771895+splendidbug@users.noreply.github.com>
Date: Wed, 10 Jul 2024 10:13:39 -0700
Subject: [PATCH] Knowledge packs (#3)

* knowledge pack creation is working - need to clean up code
* added joinpath
* added tests and code improvements
* refactored code - docstrings, path changes
---
 Project.toml           |   4 +
 src/RAGKit.jl          |  20 ++++-
 src/crawl.jl           | 136 ++++++++++++++++++++++++----------
 src/extract_urls.jl    |  22 +++---
 src/make_embeddings.jl | 163 +++++++++++++++++++++++++++++++++++++++++
 src/preparation.jl     | 115 +++++++++++++++++++++++++++++
 src/utils.jl           |  68 +++++++++++++++++
 test/runtests.jl       |  32 +++++++-
 8 files changed, 506 insertions(+), 54 deletions(-)
 create mode 100644 src/make_embeddings.jl
 create mode 100644 src/preparation.jl
 create mode 100644 src/utils.jl

diff --git a/Project.toml b/Project.toml
index 63cd0f8..964d069 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,9 +5,13 @@ version = "0.1.0"

 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
+DotEnv = "4dc1fcf4-5e3b-5448-94ab-0c38ec0385c1"
 EzXML = "8f5d6c58-4d21-5cfd-889c-e3ad7ee6a615"
 Gumbo = "708ec375-b3d6-5a57-a7ce-8257bf98657a"
+HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
 HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
+Inflate = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9"
+PromptingTools = "670122d1-24a8-4d70-bfce-740807c42192"
 URIParser = "30578b45-9adc-5946-b283-645ec420af67"
 URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4"
diff --git a/src/RAGKit.jl b/src/RAGKit.jl
index 3d27e95..b895363 100644
--- a/src/RAGKit.jl
+++ b/src/RAGKit.jl
@@ -1,10 +1,28 @@
+module RAGKit
 using HTTP, Gumbo, AbstractTrees, URIs
 using Gumbo: HTMLDocument, HTMLElement
 using EzXML
+using PromptingTools
+const PT = PromptingTools
+const RT = PromptingTools.Experimental.RAGTools
+using LinearAlgebra, Unicode, SparseArrays
+using HDF5
+using Tar
+using Inflate
+
+using SHA
+using Serialization, URIs
 # using Regex
 # using Robots

 include("parser.jl")
 include("crawl.jl")
-include("extract_urls.jl")
\ No newline at end of file
+include("extract_urls.jl")
+include("preparation.jl")
+
+include("make_embeddings.jl")
+export make_embeddings
+
+
+end
\ No newline at end of file
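With this change the module exposes a single exported entry point, `make_embeddings`. A minimal usage sketch (the URL is illustrative and the API-key check is an assumption, not part of this patch):

```julia
using RAGKit

# The embedding step reads OpenAI credentials from the environment
@assert haskey(ENV, "OPENAI_API_KEY")

# Crawl the site, chunk the pages, and write packs under knowledge_packs/packs/
make_embeddings(["https://docs.julialang.org/en/v1/"])
```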
diff --git a/src/crawl.jl b/src/crawl.jl
index b1dd60b..b147511 100644
--- a/src/crawl.jl
+++ b/src/crawl.jl
@@ -1,10 +1,16 @@
-include("parser.jl")
+## TODO: Make multiple dispatch for the following function to remove if-else
+"""
+    parse_robots_txt!(robots_txt::String)

-## TODO: Make multiple dispatch for the following function
-function parse_robots_txt!(robots_txt::String, url_queue::Vector{<:AbstractString})
-    ## TODO: Make a cache of rules for a quick lookup
+Parses the robots.txt string and returns the crawling rules along with any sitemap URLs
+
+# Arguments
+- `robots_txt`: robots.txt as a string
+"""
+function parse_robots_txt!(robots_txt::String)
     rules = Dict{String,Dict{String,Vector{String}}}()
     current_user_agent = ""
+    sitemap_urls = Vector{AbstractString}()

     for line in split(robots_txt, '\n')
         line = strip(line)
@@ -25,35 +31,46 @@ function parse_robots_txt!(robots_txt::String, url_queue::Vector{<:AbstractStrin
             end
         elseif startswith(line, "Sitemap:")
             url = strip(split(line, ":")[2])
-            push!(url_queue, url)
+            push!(sitemap_urls, url)
         end
     end
-    return rules
+    return rules, sitemap_urls
 end

+"""
+    check_robots_txt(user_agent::AbstractString,
+        url::AbstractString)
+
+Checks the robots.txt of a URL and returns a boolean indicating whether `user_agent` is allowed to crawl it, along with any sitemap URLs found in robots.txt
+
+# Arguments
+- `user_agent`: user agent attempting to crawl the webpage
+- `url`: input URL string
+"""
 function check_robots_txt(user_agent::AbstractString,
-    url::AbstractString,
-    restricted_urls::Dict{String,Set{AbstractString}},
-    url_queue::Vector{<:AbstractString})
+    url::AbstractString)
+
+    ## TODO: Make a cache of rules for a quick lookup
+    # if (haskey(restricted_urls, url))
+    #     if (in(path, restricted_urls[url]))
+    #         println("Not allowed to crawl $url")
+    #         return false
+    #     else
+    #         return true
+    #     end
+    # end

     URI = URIs.URI(url)
     path = URI.path
-    if (haskey(restricted_urls, url))
-        if (in(path, restricted_urls[url]))
-            println("Not allowed to crawl $url")
-            return false
-        else
-            return true
-        end
-    end
     robots_URL = string(URI.scheme, "://", URI.host, "/robots.txt")
+    sitemap_urls = Vector{AbstractString}()

     try
         response = HTTP.get(robots_URL)
         robots_txt = String(response.body)
-        rules = parse_robots_txt!(robots_txt, url_queue)
+        rules, sitemap_urls = parse_robots_txt!(robots_txt)
         user_agents = [user_agent, "*"]
         for ua in user_agents
             if haskey(rules, ua)
@@ -62,26 +79,25 @@ function check_robots_txt(user_agent::AbstractString,

                 for allow_rule in allow_rules
                     if startswith(path, allow_rule)
-                        return true
+                        return true, sitemap_urls
                     end
                 end

                 for disallow_rule in disallow_rules
                     if startswith(path, disallow_rule)
-                        println("Not allowed to crawl $url")
-                        return false
+                        @warn "Not allowed to crawl $url"
+                        return false, sitemap_urls
                     end
                 end
             end
         end
-        return true
+        return true, sitemap_urls
     catch
-        println("robots.txt unavailable for $url")
-        return true
+        @info "robots.txt unavailable for $url"
+        return true, sitemap_urls
     end
 end

-
 """
     get_base_url(url::AbstractString)
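`check_robots_txt` now returns a tuple of the crawl permission and any sitemap URLs it found, so callers have to destructure the result. A small sketch of the new calling convention (the URL is illustrative):

```julia
allowed, sitemap_urls = check_robots_txt("*", "https://docs.julialang.org/en/v1/")
if allowed
    # the caller (crawl) appends the sitemap URLs to its crawl queue
    @info "Crawl permitted ($(length(sitemap_urls)) sitemap URLs found)"
end
```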
@@ -100,35 +116,77 @@ end

 """
-    makeRAG(input_urls::Vector{<:AbstractString})
+    process_hostname(url::AbstractString)

-Extracts the base url.
+Returns the hostname of an input URL
+
+# Arguments
+- `url`: URL string
+"""
+function process_hostname(url::AbstractString)
+    URI = URIs.URI(url)
+    hostname = String(URI.host)
+    return hostname
+end
+
+
+"""
+    process_hostname!(url::AbstractString, hostname_dict::Dict{AbstractString,Vector{AbstractString}})
+
+Adds the `url` under its hostname in `hostname_dict`

 # Arguments
-- `input_urls`: vector containing URL strings to parse
+- `url`: URL string
+- `hostname_dict`: Dict with key being hostname and value being a vector of URLs
 """
-function makeRAG(input_urls::Vector{<:AbstractString})
+function process_hostname!(url::AbstractString, hostname_dict::Dict{AbstractString,Vector{AbstractString}})
+    hostname = process_hostname(url)
+
+    # Add the URL to the dictionary under its hostname
+    if haskey(hostname_dict, hostname)
+        push!(hostname_dict[hostname], url)
+    else
+        hostname_dict[hostname] = [url]
+    end
+end
+
+
+"""
+    crawl(input_urls::Vector{<:AbstractString})
+
+Crawls the input URLs and returns a `hostname_url_dict`, a dictionary mapping hostnames to the URLs discovered under them
+
+# Arguments
+- `input_urls`: A vector of input URLs
+"""
+function crawl(input_urls::Vector{<:AbstractString})
     url_queue = Vector{AbstractString}(input_urls)
     visited_url_set = Set{AbstractString}()
-    restricted_urls = Dict{String,Set{AbstractString}}()
-    parsed_blocks = []
-    ## TODO: Add parallel processing for URLs
+    hostname_url_dict = Dict{AbstractString,Vector{AbstractString}}()
+    sitemap_urls = Vector{AbstractString}()
+    # TODO: Add parallel processing for URLs

     while !isempty(url_queue)
         url = url_queue[1]
         popfirst!(url_queue)
         base_url = get_base_url(url)
-        ## TODO: Show some respect to robots.txt

         if !in(base_url, visited_url_set)
             push!(visited_url_set, base_url)
-            if !check_robots_txt("*", base_url, restricted_urls, url_queue)
-                break
+            crawlable, sitemap_urls = check_robots_txt("*", base_url)
+            append!(url_queue, sitemap_urls)
+            if crawlable
+                try
+                    get_urls!(base_url, url_queue)
+                    process_hostname!(url, hostname_url_dict)
+                catch
+                    @error "Bad URL: $base_url"
+                end
             end
-            get_urls!(base_url, url_queue)
-            push!(parsed_blocks, parse_url_to_blocks(base_url))
         end
     end
-    return parsed_blocks
-end
\ No newline at end of file
+
+    return hostname_url_dict
+
+end
diff --git a/src/extract_urls.jl b/src/extract_urls.jl
index a1d99bf..b9ea364 100644
--- a/src/extract_urls.jl
+++ b/src/extract_urls.jl
@@ -124,16 +124,16 @@ function get_urls!(url::AbstractString, url_queue::Vector{<:AbstractString})

     @info "Scraping link: $url"
     # println(url)
-    try
-        fetched_content = HTTP.get(url)
-        parsed = Gumbo.parsehtml(String(fetched_content.body))
-        if (url[end-3:end] == ".xml")
-            find_urls_xml!(url_xml, url_queue)
-        else
-            find_urls_html!(url, parsed.root, url_queue)
-        end
-        # print("-------------")
-    catch e
-        println("Bad URL: $url")
+    # try
+    fetched_content = HTTP.get(url)
+    parsed = Gumbo.parsehtml(String(fetched_content.body))
+    if (url[end-3:end] == ".xml")
+        find_urls_xml!(url_xml, url_queue)
+    else
+        find_urls_html!(url, parsed.root, url_queue)
     end
+    # print("-------------")
+    # catch e
+    #     println("Bad URL: $url")
+    # end
 end
\ No newline at end of file
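`crawl` no longer returns parsed blocks; it returns a hostname-to-URLs map that the embedding pipeline consumes. Roughly (the URL and output shape are illustrative):

```julia
hostname_url_dict = crawl(["https://docs.julialang.org/en/v1/"])
# e.g. Dict("docs.julialang.org" => ["https://docs.julialang.org/en/v1/", ...])
for (host, urls) in hostname_url_dict
    @info "Discovered $(length(urls)) URLs under $host"
end
```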
diff --git a/src/make_embeddings.jl b/src/make_embeddings.jl
new file mode 100644
index 0000000..ba079aa
--- /dev/null
+++ b/src/make_embeddings.jl
@@ -0,0 +1,163 @@
+## TODO: Make a function to check for version number
+
+"""
+    report_artifact(fn_output)
+
+Prints artifact information (name, sha256, and git-tree-sha1)
+"""
+function report_artifact(fn_output)
+    @info("ARTIFACT: $(basename(fn_output))")
+    @info("sha256: ", bytes2hex(open(sha256, fn_output)))
+    @info("git-tree-sha1: ", Tar.tree_hash(IOBuffer(inflate_gzip(fn_output))))
+end
+
+
+
+
+"""
+    create_output_folders(knowledge_pack_path::String)
+
+Creates the output folder for the knowledge packs
+"""
+function create_output_folders(knowledge_pack_path::String)
+    # Define the folder path
+    folder_path = joinpath(knowledge_pack_path, "packs")
+    println("folder_path:", folder_path)
+    # Check if the folder exists
+    if !isdir(folder_path)
+        mkpath(folder_path)
+        @info "Folder created: $folder_path"
+    else
+        @info "Folder already exists: $folder_path"
+    end
+
+end
+
+"""
+    make_chunks(hostname_url_dict::Dict{AbstractString,Vector{AbstractString}}, knowledge_pack_path::String)
+
+Parses the URLs in `hostname_url_dict` and serializes the resulting chunks and sources per hostname
+
+# Arguments
+- hostname_url_dict: Dict with key being hostname and value being a vector of URLs
+"""
+function make_chunks(hostname_url_dict::Dict{AbstractString,Vector{AbstractString}}, knowledge_pack_path::String)
+    output_chunks = Vector{SubString{String}}()
+    output_sources = Vector{String}()
+    SAVE_CHUNKS = true
+    CHUNK_SIZE = 512
+    for (hostname, urls) in hostname_url_dict
+        for url in urls
+            try
+                chunks, sources = process_paths(url)
+                append!(output_chunks, chunks)
+                append!(output_sources, sources)
+            catch
+                @error "Error processing URL: $url"
+            end
+        end
+        if SAVE_CHUNKS
+            serialize(joinpath(knowledge_pack_path, "$(hostname)-chunks-$(CHUNK_SIZE).jls"), output_chunks)
+            serialize(joinpath(knowledge_pack_path, "$(hostname)-sources-$(CHUNK_SIZE).jls"), output_sources)
+        end
+
+    end
+
+
+end
+
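`make_chunks` writes one `*-chunks-*.jls` / `*-sources-*.jls` pair per hostname; `generate_embeddings` below pairs them back up by hostname and chunk size using the regexes shown there. A sketch of that matching (the file name is illustrative):

```julia
m = match(r"^(.*)-chunks-(\d+)\.jls$", "docs.julialang.org-chunks-512.jls")
hostname   = m.captures[1]             # "docs.julialang.org"
chunk_size = parse(Int, m.captures[2]) # 512
```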
+"""
+    generate_embeddings(knowledge_pack_path::String)
+
+Deserializes the saved chunks and sources and generates embeddings for them
+"""
+function generate_embeddings(knowledge_pack_path::String)
+    embedder = RT.BatchEmbedder()
+    entries = readdir(knowledge_pack_path)
+
+    # Initialize a dictionary to group files by hostname and chunk size
+    hostname_files = Dict{String,Dict{Int,Dict{String,String}}}()
+
+    # Regular expressions to match the file patterns
+    chunks_pattern = r"^(.*)-chunks-(\d+)\.jls$"
+    sources_pattern = r"^(.*)-sources-(\d+)\.jls$"
+
+    # Group files by hostname and chunk size
+    for file in entries
+        match_chunks = match(chunks_pattern, file)
+        match_sources = match(sources_pattern, file)
+
+        if match_chunks !== nothing
+            hostname = match_chunks.captures[1]
+            chunk_size = parse(Int, match_chunks.captures[2])
+            if !haskey(hostname_files, hostname)
+                hostname_files[hostname] = Dict{Int,Dict{String,String}}()
+            end
+            if !haskey(hostname_files[hostname], chunk_size)
+                hostname_files[hostname][chunk_size] = Dict{String,String}()
+            end
+            hostname_files[hostname][chunk_size]["chunks"] = joinpath(knowledge_pack_path, file)
+        elseif match_sources !== nothing
+            hostname = match_sources.captures[1]
+            chunk_size = parse(Int, match_sources.captures[2])
+            if !haskey(hostname_files, hostname)
+                hostname_files[hostname] = Dict{Int,Dict{String,String}}()
+            end
+            if !haskey(hostname_files[hostname], chunk_size)
+                hostname_files[hostname][chunk_size] = Dict{String,String}()
+            end
+            hostname_files[hostname][chunk_size]["sources"] = joinpath(knowledge_pack_path, file)
+        end
+    end
+
+
+    # Process each pair of files
+    for (hostname, chunk_files) in hostname_files
+        for (chunk_size, files) in chunk_files
+            if haskey(files, "chunks") && haskey(files, "sources")
+                chunks_file = files["chunks"]
+                sources_file = files["sources"]
+                chunks = deserialize(chunks_file)
+                sources = deserialize(sources_file)
+                cost_tracker = Threads.Atomic{Float64}(0.0)
+                full_embeddings = RT.get_embeddings(embedder, chunks; model="text-embedding-3-large", verbose=false, cost_tracker, api_key=ENV["OPENAI_API_KEY"])
+
+                # Float32
+                fn_output = joinpath(knowledge_pack_path, "packs", "$hostname-textembedding3large-0-Float32__v1.0.tar.gz")
+                fn_temp = joinpath(knowledge_pack_path, "packs", "pack.hdf5")
+                h5open(fn_temp, "w") do file
+                    file["chunks"] = chunks
+                    file["sources"] = sources
+                    file["embeddings"] = full_embeddings
+                    file["type"] = "ChunkIndex"
+                    # file["metadata"] = "$hostname ecosystem docstrings, chunk size $chunk_size, downloaded on 20240330, contains: Makie.jl, AlgebraOfGraphics.jl, GeoMakie.jl, GraphMakie.jl, MakieThemes.jl, TopoPlots.jl, Tyler.jl"
+                end
+                run(`tar -cvzf $fn_output -C $(dirname(fn_temp)) $(basename(fn_temp))`)
+                report_artifact(fn_output)
+
+            else
+                @warn "Missing pair for hostname: $hostname, chunk size: $chunk_size"
+            end
+        end
+    end
+
+end
+
+
+
+"""
+    make_embeddings(input_urls::Vector{<:AbstractString})
+
+Entry point to crawl, parse, and create embeddings
+
+# Arguments
+- input_urls: vector containing URL strings to parse
+"""
+function make_embeddings(input_urls::Vector{<:AbstractString})
+    hostname_url_dict = Dict{AbstractString,Vector{AbstractString}}()
+    hostname_url_dict = crawl(input_urls)
+    knowledge_pack_path = joinpath(@__DIR__, "..", "knowledge_packs")
+    create_output_folders(knowledge_pack_path)
+    make_chunks(hostname_url_dict, knowledge_pack_path)
+    generate_embeddings(knowledge_pack_path)
+end
\ No newline at end of file
diff --git a/src/preparation.jl b/src/preparation.jl
new file mode 100644
index 0000000..ab8d7b5
--- /dev/null
+++ b/src/preparation.jl
@@ -0,0 +1,115 @@
+# include("recursive_splitter.jl")
+include("utils.jl")
+"""
+    get_header_path(d::Dict)
+
+Concatenates the h1, h2, h3 keys from the metadata of a Dict
+
+# Examples
+```julia
+d = Dict("metadata" => Dict{Symbol,Any}(:h1 => "Axis", :h2 => "Attributes", :h3 => "yzoomkey"), "heading" => "yzoomkey")
+get_header_path(d)
+# Output: "Axis/Attributes/yzoomkey"
+```
+"""
+function get_header_path(d::Dict)
+    metadata = get(d, "metadata", Dict{Any,Any}())
+    isempty(metadata) && return nothing
+    keys_ = [:h1, :h2, :h3]
+    vals = get.(Ref(metadata), keys_, "") |> x -> filter(!isempty, x) |> x -> join(x, "/")
+    isempty(vals) ? nothing : vals
+end
+
+
+"Rolls up chunks that share the same header, so we can split them later to get the desired length"
+function roll_up_chunks(parsed_blocks, url::AbstractString; separator::String="<SEP>")
+    docs = String[]
+    io = IOBuffer()
+    last_header = nothing
+    sources = String[]
+
+    for chunk in parsed_blocks
+        header = get_header_path(chunk)
+        if isnothing(header) || header != last_header
+            # New content block, commit work thus far
+            str = String(take!(io))
+            if !isempty(str)
+                push!(docs, str)
+                src = url * (isnothing(last_header) ? "" : "::$last_header")
+                push!(sources, src)
+            end
+            last_header = header
+        end
+        # Append the new chunk together with a separator
+        haskey(chunk, "code") && print(io, chunk["code"], separator)
+        haskey(chunk, "text") && print(io, chunk["text"], separator)
+    end
+    ## commit remaining docs
+    str = String(take!(io))
+    if !isempty(str)
+        push!(docs, str)
+        src = url * (isnothing(last_header) ? "" : "::$last_header")
+        push!(sources, src)
+    end
+    return docs, sources
+end
+
+
+struct DocParserChunker <: RT.AbstractChunker end
+"""
+    RT.get_chunks(chunker::DocParserChunker, url::AbstractString;
+        verbose::Bool=true,
+        separators=["\n\n", ". ", "\n", " "], max_length::Int=256)
+
+Extracts chunks from the HTML page at `url` by parsing its content, rolling up chunks by headers, and splitting them by `separators` to get the desired length.
+"""
+function RT.get_chunks(chunker::DocParserChunker, url::AbstractString;
+    verbose::Bool=true,
+    separators=["\n\n", ". ", "\n", " "], max_length::Int=256)
+
+
+    SEP = "<SEP>"
+    output_chunks = Vector{SubString{String}}()
+    output_sources = Vector{String}()
+
+    verbose && @info "Processing $(url)..."
+
+    parsed_blocks = parse_url_to_blocks(url)
+    ## Roll up to the same header
+    docs_, sources_ = roll_up_chunks(parsed_blocks, url; separator=SEP)
+
+    ## Split the rolled-up docs on SEP (and the other separators), then remove SEP
+    for (doc, src) in zip(docs_, sources_)
+        doc_chunks = PT.recursive_splitter(doc, [SEP, separators...]; max_length) .|>
+                     x -> replace(x, SEP => " ") .|> strip |> x -> filter(!isempty, x)
+        # skip if no chunks found
+        isempty(doc_chunks) && continue
+        append!(output_chunks, doc_chunks)
+        append!(output_sources, fill(src, length(doc_chunks)))
+    end
+    return output_chunks, output_sources
+end
+
+
+
+"Processes the given `url`: scrapes it, chunks the content, and postprocesses the chunks."
+function process_paths(url::AbstractString, max_length::Int=512)
+
+    output_chunks = Vector{SubString{String}}()
+    output_sources = Vector{String}()
+
+    chunks, sources = RT.get_chunks(DocParserChunker(), url; max_length)
+    append!(output_chunks, chunks)
+    append!(output_sources, sources)
+
+    @info "Scraping done: $(length(output_chunks)) chunks"
+    output_chunks, output_sources = postprocess_chunks(output_chunks, output_sources; min_length=40, skip_code=true)
+
+    return output_chunks, output_sources
+end
diff --git a/src/utils.jl b/src/utils.jl
new file mode 100644
index 0000000..4bf1e07
--- /dev/null
+++ b/src/utils.jl
@@ -0,0 +1,68 @@
+"Finds duplicates in a list of chunks using SHA-256 hash. Returns a bit vector of the same length as the input list, where `true` indicates a duplicate (second instance of the same text)."
+function find_duplicates(chunks::AbstractVector{<:AbstractString})
+    # hash the chunks for easier search
+    hashed_chunks = bytes2hex.(sha256.(chunks))
+    sorted_indices = sortperm(hashed_chunks) # Sort indices based on hashed values
+
+    duplicates = falses(length(chunks))
+    prev_hash = "" # Initialize with an empty string to ensure the first comparison fails
+
+    for idx in sorted_indices
+        current_hash = hashed_chunks[idx]
+        # Check if current hash matches the previous one, indicating a duplicate
+        if current_hash == prev_hash
+            duplicates[idx] = true # Mark as duplicate
+        else
+            prev_hash = current_hash # Update previous hash for the next iteration
+        end
+    end
+
+    return duplicates
+end
+
+"Removes chunks that are duplicated in the input list of chunks and their corresponding sources."
+function remove_duplicates(chunks::AbstractVector{<:AbstractString}, sources::AbstractVector{<:AbstractString})
+    idxs = find_duplicates(chunks)
+    return chunks[.!idxs], sources[.!idxs]
+end
+
+"Removes chunks that are shorter than a specified length (`min_length`) from the input list of chunks and their corresponding sources."
+function remove_short_chunks(chunks::AbstractVector{<:AbstractString}, sources::AbstractVector{<:AbstractString}; min_length::Int=40, skip_code::Bool=true)
+    idx = if skip_code
+        ## Keep short chunks if they contain code (might be combined with some preceding/succeeding text)
+        findall(x -> length(x) >= min_length || occursin("```", x), chunks)
+    else
+        findall(x -> length(x) >= min_length, chunks)
+    end
+    return chunks[idx], sources[idx]
+end
+
+
+function replace_local_paths(sources::AbstractVector{<:AbstractString}, paths::AbstractVector{<:AbstractString}, websites::AbstractVector{<:AbstractString})
+    @assert length(paths) == length(websites) "Length of `paths` must match length of `websites`"
+    replacement_pairs = paths .=> websites
+    output = map(x -> replace(x, replacement_pairs...), sources)
+end
+
+
+"Post-processes the input list of chunks and their corresponding sources by removing short chunks and duplicates."
+function postprocess_chunks(chunks::AbstractVector{<:AbstractString}, sources::AbstractVector{<:AbstractString}; min_length::Int=40, skip_code::Bool=true,
+    paths::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, websites::Union{Nothing,AbstractVector{<:AbstractString}}=nothing)
+    len_ = length(chunks)
+    chunks, sources = remove_short_chunks(chunks, sources; min_length, skip_code)
+    @info "Removed $(len_ - length(chunks)) short chunks"
+
+    len_ = length(chunks)
+    chunks, sources = remove_duplicates(chunks, sources)
+    @info "Removed $(len_ - length(chunks)) duplicate chunks"
+
+    ## Renaming sources
+    if !isnothing(paths) && !isnothing(websites)
+        sources = replace_local_paths(sources, paths, websites)
+        @info "Replaced local paths with websites"
+    end
+
+    return chunks, sources
+
+
+end
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index aee7e9c..fdde81f 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,7 +1,33 @@
+
 using Test
+using HTTP
+using RAGKit: check_robots_txt, get_urls!, parse_url_to_blocks, roll_up_chunks, process_paths

+urls = Vector{AbstractString}(["https://docs.julialang.org/en/v1/"])
+url = urls[1]
+queue = Vector{AbstractString}()
+
+@testset "check robots.txt" begin
+    result, sitemap_queue = check_robots_txt("*", url)
+    @test result == true
+end
+
+@testset "HTTP get" begin
+    @test HTTP.get(url) != nothing
+end
+
+@testset "get_urls!" begin
+    get_urls!(url, queue)
+    @test length(queue) > 1
+end
+
+@testset "parse & roll_up" begin
+    parsed_blocks = parse_url_to_blocks(url)
+    @test length(parsed_blocks) > 0
+    SEP = "<SEP>"
+    docs_, sources_ = roll_up_chunks(parsed_blocks, url; separator=SEP)
+    @test length(docs_) > 0 && length(sources_) > 0 && docs_[1] != nothing && sources_[1] != nothing
+end

-include("..\\src\\RAGKit.jl")
+@testset "overall test" begin
+    chunks, sources = process_paths(url)
+    @test length(chunks) > 0 && length(sources) > 0 && chunks[1] != nothing && sources[1] != nothing

-@testset "RAGKit Tests" begin
-    # Your test cases go here
 end
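Once a pack has been generated, the intermediate HDF5 file written by `generate_embeddings` can be inspected directly. A sketch, assuming the default `knowledge_packs/packs/pack.hdf5` location used in this patch and that the temporary file is still on disk:

```julia
using HDF5

h5open(joinpath("knowledge_packs", "packs", "pack.hdf5"), "r") do file
    chunks     = read(file["chunks"])
    embeddings = read(file["embeddings"])
    @info "Loaded $(length(chunks)) chunks with embeddings of size $(size(embeddings))"
end
```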