
Adding website ingestion! #325

Merged: 14 commits, Nov 19, 2024
Empty file.
117 changes: 117 additions & 0 deletions dags/hivemind_etl_helpers/src/db/website/crawlee_client.py
@@ -0,0 +1,117 @@
import asyncio
from typing import Any

from crawlee.playwright_crawler import PlaywrightCrawler, PlaywrightCrawlingContext
from defusedxml import ElementTree as ET


class CrawleeClient:
def __init__(
self,
max_requests: int = 20,
headless: bool = True,
browser_type: str = "chromium",
) -> None:
self.crawler = PlaywrightCrawler(
max_requests_per_crawl=max_requests,
headless=headless,
browser_type=browser_type,
)

# do not persist crawled data to local storage
self.crawler._configuration.persist_storage = False
self.crawler._configuration.write_metadata = False

@self.crawler.router.default_handler
async def request_handler(context: PlaywrightCrawlingContext) -> None:
context.log.info(f"Processing {context.request.url} ...")

inner_text = await context.page.inner_text(selector="body")

if "sitemap.xml" in context.request.url:
links = self._extract_links_from_sitemap(inner_text)
await context.add_requests(requests=list(set(links)))
else:
await context.enqueue_links()

data = {
"url": context.request.url,
"title": await context.page.title(),
"inner_text": inner_text,
}

await context.push_data(data)

def _extract_links_from_sitemap(self, sitemap_content: str) -> list[str]:
"""

Extract URLs from a sitemap XML content.

Parameters
----------
sitemap_content : str
The XML content of the sitemap

Raises
------
ValueError
If the sitemap XML content is malformed

Returns
-------
links : list[str]
list of valid URLs extracted from the sitemap
"""
try:
root = ET.fromstring(sitemap_content)
namespace = {"ns": "http://www.sitemaps.org/schemas/sitemap/0.9"}
links = []
for element in root.findall("ns:url/ns:loc", namespace):
url = element.text.strip() if element.text else None
if url and url.startswith(("http://", "https://")):
links.append(url)
return links
except ET.ParseError as e:
raise ValueError(f"Invalid sitemap XML: {str(e)}") from e

async def crawl(self, links: list[str]) -> list[dict[str, Any]]:
"""
Crawl websites and extract data from all inner links under the domain routes.

Parameters
----------
links : list[str]
List of valid URLs to crawl

Returns
-------
crawled_data : list[dict[str, Any]]
List of dictionaries containing crawled data with keys:
- url: str
- title: str
- inner_text: str

Raises
------
ValueError
If any of the input URLs is invalid (not starting with http or https)
TimeoutError
If the crawl operation times out
"""
# Validate input URLs
valid_links = []
for url in links:
if url and isinstance(url, str) and url.startswith(("http://", "https://")):
valid_links.append(url)
else:
raise ValueError(f"Invalid URL: {url}")

try:
await self.crawler.add_requests(requests=valid_links)
await asyncio.wait_for(self.crawler.run(), timeout=3600) # 1 hour timeout
crawled_data = await self.crawler.get_data()
return crawled_data.items
except asyncio.TimeoutError:
raise TimeoutError("Crawl operation timed out")
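A minimal usage sketch of the client above. The target URL and `max_requests` value are illustrative, and running it assumes `crawlee` with the Playwright browsers installed:

```python
import asyncio

from hivemind_etl_helpers.src.db.website.crawlee_client import CrawleeClient


async def main() -> None:
    # small crawl budget for a quick smoke test (illustrative value)
    client = CrawleeClient(max_requests=10, headless=True)
    pages = await client.crawl(["https://example.com"])
    for page in pages:
        # each item carries the url, page title, and extracted body text
        print(page["url"], "|", page["title"])


if __name__ == "__main__":
    asyncio.run(main())
```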
1 change: 1 addition & 0 deletions dags/hivemind_etl_helpers/src/utils/modules/__init__.py
@@ -5,3 +5,4 @@
from .github import ModulesGitHub
from .mediawiki import ModulesMediaWiki
from .notion import ModulesNotion
from .website import ModulesWebsite
6 changes: 3 additions & 3 deletions dags/hivemind_etl_helpers/src/utils/modules/modules_base.py
@@ -98,7 +98,7 @@ def get_token(self, platform_id: ObjectId, token_type: str) -> str:

def get_platform_metadata(
self, platform_id: ObjectId, metadata_name: str
) -> str | dict:
) -> str | dict | list:
"""
get a metadata value that belongs to a platform

@@ -111,8 +111,8 @@ def get_platform_metadata(

Returns
---------
user_id : str
the user id that the platform belongs to
metadata_value : str | dict | list
the metadata value stored for the given platform
"""
client = MongoSingleton.get_instance().get_client()

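A hedged sketch of why the return annotation was widened: for a website platform the `resources` metadata is a list of URLs rather than a string or dict. The ObjectId below is the illustrative id from the docstring example, and the call assumes a reachable MongoDB behind `MongoSingleton`:

```python
from bson import ObjectId

from hivemind_etl_helpers.src.utils.modules.modules_base import ModulesBase

modules = ModulesBase()
# for "website" platforms this is expected to return a list[str] of URLs,
# hence the broadened `str | dict | list` annotation
resources = modules.get_platform_metadata(
    platform_id=ObjectId("6579c364f1120850414e0dc6"),  # illustrative id
    metadata_name="resources",
)
```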
63 changes: 63 additions & 0 deletions dags/hivemind_etl_helpers/src/utils/modules/website.py
@@ -0,0 +1,63 @@
import logging

from .modules_base import ModulesBase


class ModulesWebsite(ModulesBase):
def __init__(self) -> None:
self.platform_name = "website"
super().__init__()

def get_learning_platforms(
self,
) -> list[dict[str, str | list[str]]]:
"""
Get all communities that have the website platform configured, together with their website URLs.

Returns
---------
communities_data : list[dict[str, str | list[str]]]
a list of website platform data, one entry per community website platform

example data output:
```
[{
"community_id": "6579c364f1120850414e0dc5",
"platform_id": "6579c364f1120850414e0dc6",
"urls": ["link1", "link2"],
}]
```
"""
modules = self.query(platform=self.platform_name, projection={"name": 0})
communities_data: list[dict[str, str | list[str]]] = []

for module in modules:
community = module["community"]

# each platform of the community
for platform in module["options"]["platforms"]:
if platform["name"] != self.platform_name:
continue

platform_id = platform["platform"]

try:
website_links = self.get_platform_metadata(
platform_id=platform_id,
metadata_name="resources",
)

communities_data.append(
{
"community_id": str(community),
"platform_id": platform_id,
"urls": website_links,
}
)
except Exception as exp:
logging.error(
"Exception while fetching website modules "
f"for platform: {platform_id} | exception: {exp}"
)

return communities_data
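A short sketch of how the module query above might be consumed. It assumes MongoDB connectivity via `MongoSingleton`, and the dictionary keys follow the docstring example:

```python
from hivemind_etl_helpers.src.utils.modules import ModulesWebsite

modules = ModulesWebsite()
for community in modules.get_learning_platforms():
    # each entry holds the community id, platform id, and configured URLs
    print(
        community["community_id"],
        community["platform_id"],
        community["urls"],
    )
```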
81 changes: 81 additions & 0 deletions dags/hivemind_etl_helpers/tests/unit/test_website_etl.py
@@ -0,0 +1,81 @@
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock

from dotenv import load_dotenv
from hivemind_etl_helpers.website_etl import WebsiteETL
from llama_index.core import Document


class TestWebsiteETL(IsolatedAsyncioTestCase):
def setUp(self):
"""
Setup for the test cases. Initializes a WebsiteETL instance with mocked dependencies.
"""
load_dotenv()
self.community_id = "test_community"
self.website_etl = WebsiteETL(self.community_id)
self.website_etl.crawlee_client = AsyncMock()
self.website_etl.ingestion_pipeline = MagicMock()

async def test_extract(self):
"""
Test the extract method.
"""
urls = ["https://example.com"]
mocked_data = [
{
"url": "https://example.com",
"inner_text": "Example text",
"title": "Example",
}
]
self.website_etl.crawlee_client.crawl.return_value = mocked_data

extracted_data = await self.website_etl.extract(urls)

self.assertEqual(extracted_data, mocked_data)
self.website_etl.crawlee_client.crawl.assert_awaited_once_with(urls)

def test_transform(self):
"""
Test the transform method.
"""
raw_data = [
{
"url": "https://example.com",
"inner_text": "Example text",
"title": "Example",
}
]
expected_documents = [
Document(
doc_id="https://example.com",
text="Example text",
metadata={"title": "Example", "url": "https://example.com"},
)
]

documents = self.website_etl.transform(raw_data)

self.assertEqual(len(documents), len(expected_documents))
self.assertEqual(documents[0].doc_id, expected_documents[0].doc_id)
self.assertEqual(documents[0].text, expected_documents[0].text)
self.assertEqual(documents[0].metadata, expected_documents[0].metadata)

def test_load(self):
"""
Test the load method.
"""
documents = [
Document(
doc_id="https://example.com",
text="Example text",
metadata={"title": "Example", "url": "https://example.com"},
)
]

self.website_etl.load(documents)

self.website_etl.ingestion_pipeline.run_pipeline.assert_called_once_with(
docs=documents
)
89 changes: 89 additions & 0 deletions dags/hivemind_etl_helpers/website_etl.py
@@ -0,0 +1,89 @@
from typing import Any

from hivemind_etl_helpers.ingestion_pipeline import CustomIngestionPipeline
from hivemind_etl_helpers.src.db.website.crawlee_client import CrawleeClient
from llama_index.core import Document


class WebsiteETL:
def __init__(
self,
community_id: str,
) -> None:
"""
Parameters
-----------
community_id : str
the id of the community to save the data for
"""
self.community_id = community_id
collection_name = "website"

# preparing the data extractor and ingestion pipelines
self.crawlee_client = CrawleeClient()
self.ingestion_pipeline = CustomIngestionPipeline(
self.community_id, collection_name=collection_name
)

async def extract(
self,
urls: list[str],
) -> list[dict[str, Any]]:
"""
Extract raw page data from the given urls

Parameters
-----------
urls : list[str]
a list of urls

Returns
---------
extracted_data : list[dict[str, Any]]
The crawled data from urls
"""
extracted_data = await self.crawlee_client.crawl(urls)

return extracted_data

def transform(self, raw_data: list[dict[str, Any]]) -> list[Document]:
"""
transform raw data to llama-index documents

Parameters
------------
raw_data : list[dict[str, Any]]
crawled data

Returns
---------
documents : list[llama_index.Document]
list of llama-index documents
"""
documents: list[Document] = []

for data in raw_data:
doc_id = data["url"]
doc = Document(
doc_id=doc_id,
text=data["inner_text"],
metadata={
"title": data["title"],
"url": data["url"],
},
)
documents.append(doc)

return documents

def load(self, documents: list[Document]) -> None:
"""
load the documents into the vector db

Parameters
-------------
documents: list[llama_index.Document]
the llama-index documents to be ingested
"""
# loading data into db
self.ingestion_pipeline.run_pipeline(docs=documents)
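Putting the three steps together, a hedged end-to-end sketch. The community id and URL are illustrative, and `load` assumes the vector-db backend used by `CustomIngestionPipeline` is configured:

```python
import asyncio

from hivemind_etl_helpers.website_etl import WebsiteETL


async def run_website_etl(community_id: str, urls: list[str]) -> None:
    etl = WebsiteETL(community_id)
    raw_data = await etl.extract(urls)   # crawl the given URLs
    documents = etl.transform(raw_data)  # convert to llama-index Documents
    etl.load(documents)                  # ingest into the vector DB


if __name__ == "__main__":
    asyncio.run(
        run_website_etl("6579c364f1120850414e0dc5", ["https://example.com"])
    )
```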