Skip to content

Commit

Permalink
Document Loader: Add adapter for loading data from database (CrateDB)
Browse files Browse the repository at this point in the history
Based on previous contributions, this has effectively become a thin,
name-only wrapper around LangChain Community's `SQLDatabaseLoader`:
it adds no behavior of its own.
  • Loading branch information
amotl committed Dec 15, 2024
1 parent dbb8f40 commit 8dade36
Show file tree
Hide file tree
Showing 5 changed files with 273 additions and 70 deletions.
73 changes: 3 additions & 70 deletions langchain_cratedb/document_loaders.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,5 @@
"""CrateDB document loader."""
from langchain_community.document_loaders import SQLDatabaseLoader

from typing import Iterator

from langchain_core.document_loaders.base import BaseLoader
from langchain_core.documents import Document


class CrateDBLoader(BaseLoader):
    # TODO: Replace all TODOs in docstring. See example docstring:
    # https://github.com/langchain-ai/langchain/blob/869523ad728e6b76d77f170cce13925b4ebc3c1e/libs/community/langchain_community/document_loaders/recursive_url_loader.py#L54
    """
    CrateDB document loader integration (scaffold).

    NOTE: this class is an unfinished template. ``lazy_load`` raises
    ``NotImplementedError``, so the examples below describe the intended
    usage and cannot run against this implementation.

    # TODO: Replace with relevant packages, env vars.
    Setup:
        Install ``langchain-cratedb`` and set environment variable ``CRATEDB_API_KEY``.

        .. code-block:: bash

            pip install -U langchain-cratedb
            export CRATEDB_API_KEY="your-api-key"

    # TODO: Replace with relevant init params.
    Instantiate:
        .. code-block:: python

            from langchain_cratedb import CrateDBLoader

            loader = CrateDBLoader(
                # required params = ...
                # other params = ...
            )

    Lazy load:
        .. code-block:: python

            docs = []
            docs_lazy = loader.lazy_load()

            # async variant:
            # docs_lazy = await loader.alazy_load()

            for doc in docs_lazy:
                docs.append(doc)
            print(docs[0].page_content[:100])
            print(docs[0].metadata)

        .. code-block:: python

            TODO: Example output

    # TODO: Delete if async load is not implemented
    Async load:
        .. code-block:: python

            docs = await loader.aload()
            print(docs[0].page_content[:100])
            print(docs[0].metadata)

        .. code-block:: python

            TODO: Example output
    """  # noqa: E501

    # TODO: This method must be implemented to load documents.
    # Do not implement load(), a default implementation is already available.
    def lazy_load(self) -> Iterator[Document]:
        """Yield documents one at a time; not implemented in this scaffold.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    # TODO: Implement if you would like to change default BaseLoader implementation
    # async def alazy_load(self) -> AsyncIterator[Document]:
class CrateDBLoader(SQLDatabaseLoader):
    """
    Document loader for CrateDB.

    This subclass defines nothing of its own: it exposes LangChain
    Community's ``SQLDatabaseLoader`` under a CrateDB-specific name, so
    construction and loading behave exactly as on the base class.

    Example:
        .. code-block:: python

            from langchain_cratedb import CrateDBLoader

            loader = CrateDBLoader("SELECT 1 AS a, 2 AS b", db=db)
            docs = loader.load()
    """
11 changes: 11 additions & 0 deletions tests/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Module defines common test data."""

from pathlib import Path

# Directory holding this module; the data files live in its "data" subfolder.
_THIS_DIR = Path(__file__).parent

_DATA_DIR = _THIS_DIR.joinpath("data")

# Paths to data files
MLB_TEAMS_2012_CSV = _DATA_DIR.joinpath("mlb_teams_2012.csv")
MLB_TEAMS_2012_SQL = _DATA_DIR.joinpath("mlb_teams_2012.sql")
32 changes: 32 additions & 0 deletions tests/data/mlb_teams_2012.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"Team", "Payroll (millions)", "Wins"
"Nationals", 81.34, 98
"Reds", 82.20, 97
"Yankees", 197.96, 95
"Giants", 117.62, 94
"Braves", 83.31, 94
"Athletics", 55.37, 94
"Rangers", 120.51, 93
"Orioles", 81.43, 93
"Rays", 64.17, 90
"Angels", 154.49, 89
"Tigers", 132.30, 88
"Cardinals", 110.30, 88
"Dodgers", 95.14, 86
"White Sox", 96.92, 85
"Brewers", 97.65, 83
"Phillies", 174.54, 81
"Diamondbacks", 74.28, 81
"Pirates", 63.43, 79
"Padres", 55.24, 76
"Mariners", 81.97, 75
"Mets", 93.35, 74
"Blue Jays", 75.48, 73
"Royals", 60.91, 72
"Marlins", 118.07, 69
"Red Sox", 173.18, 69
"Indians", 78.43, 68
"Twins", 94.08, 66
"Rockies", 78.06, 64
"Cubs", 88.19, 61
"Astros", 60.65, 55

41 changes: 41 additions & 0 deletions tests/data/mlb_teams_2012.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
-- Provisioning table "mlb_teams_2012".
--
-- The script is re-runnable: it drops any existing table, recreates it,
-- and inserts 30 rows (one per team).
--
-- Usage, with either client:
--
-- psql postgresql://postgres@localhost < mlb_teams_2012.sql
-- crash < mlb_teams_2012.sql

-- Start from a clean slate so repeated runs do not accumulate rows.
DROP TABLE IF EXISTS mlb_teams_2012;
-- Column names are double-quoted identifiers; queries against this table
-- quote them the same way (e.g. ORDER BY "Team").
CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT);
-- Rows are listed by descending "Wins".
INSERT INTO mlb_teams_2012
("Team", "Payroll (millions)", "Wins")
VALUES
('Nationals', 81.34, 98),
('Reds', 82.20, 97),
('Yankees', 197.96, 95),
('Giants', 117.62, 94),
('Braves', 83.31, 94),
('Athletics', 55.37, 94),
('Rangers', 120.51, 93),
('Orioles', 81.43, 93),
('Rays', 64.17, 90),
('Angels', 154.49, 89),
('Tigers', 132.30, 88),
('Cardinals', 110.30, 88),
('Dodgers', 95.14, 86),
('White Sox', 96.92, 85),
('Brewers', 97.65, 83),
('Phillies', 174.54, 81),
('Diamondbacks', 74.28, 81),
('Pirates', 63.43, 79),
('Padres', 55.24, 76),
('Mariners', 81.97, 75),
('Mets', 93.35, 74),
('Blue Jays', 75.48, 73),
('Royals', 60.91, 72),
('Marlins', 118.07, 69),
('Red Sox', 173.18, 69),
('Indians', 78.43, 68),
('Twins', 94.08, 66),
('Rockies', 78.06, 64),
('Cubs', 88.19, 61),
('Astros', 60.65, 55)
;
186 changes: 186 additions & 0 deletions tests/integration_tests/test_document_loaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
"""
Test SQLAlchemy document loader functionality on behalf of CrateDB.
"""

import functools
import logging

import pytest
import sqlalchemy as sa
from langchain_community.document_loaders.sql_database import SQLDatabaseLoader
from langchain_community.utilities.sql_database import SQLDatabase

from langchain_cratedb import CrateDBLoader
from tests.data import MLB_TEAMS_2012_SQL

# Configure root logging at DEBUG level for this test module.
# NOTE(review): this runs at import time and affects the whole test process.
logging.basicConfig(level=logging.DEBUG)


@pytest.fixture()
def db(engine: sa.Engine) -> SQLDatabase:
    """Wrap the ``engine`` fixture (defined outside this module) in a ``SQLDatabase``."""
    database = SQLDatabase(engine=engine)
    return database


@pytest.fixture()
def provision_database(engine: sa.Engine) -> None:
    """
    Create and populate the ``mlb_teams_2012`` table.

    The SQL script is split on ``;`` and each non-empty statement is
    executed and committed individually.
    """
    script = MLB_TEAMS_2012_SQL.read_text()
    # Lazily strip each fragment; empty fragments are filtered out below.
    fragments = (fragment.strip() for fragment in script.split(";"))

    with engine.connect() as connection:
        connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;"))
        for sql in filter(None, fragments):
            connection.execute(sa.text(sql))
            connection.commit()

        # Only CrateDB dialects get the explicit refresh
        # (presumably so the inserted rows are visible to the queries
        # that follow; confirm against the CrateDB documentation).
        is_cratedb = engine.dialect.name.startswith("crate")
        if is_cratedb:
            connection.execute(sa.text("REFRESH TABLE mlb_teams_2012;"))
            connection.commit()


def test_cratedb_loader_no_options(db: SQLDatabase) -> None:
    """Load a constant row using only the loader's default options."""

    documents = CrateDBLoader("SELECT 1 AS a, 2 AS b", db=db).load()

    assert len(documents) == 1
    first = documents[0]
    assert first.page_content == "a: 1\nb: 2"
    assert first.metadata == {}


def test_cratedb_loader_include_rownum_into_metadata(db: SQLDatabase) -> None:
    """Check that the row number is added to the document metadata on request."""

    options = {"db": db, "include_rownum_into_metadata": True}
    documents = CrateDBLoader("SELECT 1 AS a, 2 AS b", **options).load()

    assert len(documents) == 1
    first = documents[0]
    assert first.page_content == "a: 1\nb: 2"
    assert first.metadata == {"row": 0}


def test_cratedb_loader_include_query_into_metadata(db: SQLDatabase) -> None:
    """Check that the SQL query is added to the document metadata on request."""

    documents = CrateDBLoader(
        "SELECT 1 AS a, 2 AS b",
        db=db,
        include_query_into_metadata=True,
    ).load()

    assert len(documents) == 1
    first = documents[0]
    assert first.page_content == "a: 1\nb: 2"
    assert first.metadata == {"query": "SELECT 1 AS a, 2 AS b"}


def test_cratedb_loader_page_content_columns(db: SQLDatabase) -> None:
    """Test SQLAlchemy loader with defined page content columns."""

    # Define a custom callback function to convert a row into a "page content" string.
    row_to_content = functools.partial(
        SQLDatabaseLoader.page_content_default_mapper, column_names=["a"]
    )

    loader = CrateDBLoader(
        "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b",
        db=db,
        page_content_mapper=row_to_content,
    )
    docs = loader.load()

    assert len(docs) == 2

    # The UNION query has no ORDER BY, so the database may return the two
    # rows in either order. Compare the contents independent of position
    # instead of asserting on docs[0] / docs[1] directly.
    assert sorted(doc.page_content for doc in docs) == ["a: 1", "a: 3"]
    assert all(doc.metadata == {} for doc in docs)


def test_cratedb_loader_metadata_columns(db: SQLDatabase) -> None:
    """Check that a custom metadata mapper restricts metadata to column ``b``."""

    # Callback turning a row into a "metadata" dictionary, limited to column "b".
    only_b_as_metadata = functools.partial(
        SQLDatabaseLoader.metadata_default_mapper,
        column_names=["b"],
    )

    documents = CrateDBLoader(
        "SELECT 1 AS a, 2 AS b",
        db=db,
        metadata_mapper=only_b_as_metadata,
    ).load()

    assert len(documents) == 1
    assert documents[0].metadata == {"b": 2}


def test_cratedb_loader_real_data_with_sql_no_parameters(
    db: SQLDatabase, provision_database: None
) -> None:
    """Load the provisioned table with a plain SQL statement, sorted by team."""

    sql = 'SELECT * FROM mlb_teams_2012 ORDER BY "Team";'
    documents = CrateDBLoader(query=sql, db=db).load()

    assert len(documents) == 30
    first = documents[0]
    assert first.page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89"
    assert first.metadata == {}


def test_cratedb_loader_real_data_with_sql_and_parameters(
    db: SQLDatabase, provision_database: None
) -> None:
    """Load the provisioned table with a SQL statement and bound parameters."""

    sql = 'SELECT * FROM mlb_teams_2012 WHERE "Team" LIKE :search ORDER BY "Team";'
    bound = {"search": "R%"}
    documents = CrateDBLoader(query=sql, parameters=bound, db=db).load()

    assert len(documents) == 6
    first = documents[0]
    assert first.page_content == "Team: Rangers\nPayroll (millions): 120.51\nWins: 93"
    assert first.metadata == {}


def test_cratedb_loader_real_data_with_selectable(
    db: SQLDatabase, provision_database: None
) -> None:
    """Load the provisioned table using an SQLAlchemy selectable as the query."""

    # Describe the table to SQLAlchemy.
    columns = [
        sa.Column("Team", sa.VARCHAR),
        sa.Column("Payroll (millions)", sa.FLOAT),
        sa.Column("Wins", sa.BIGINT),
    ]
    teams = sa.Table("mlb_teams_2012", sa.MetaData(), *columns)

    # Build the selectable and hand it to the loader in place of a SQL string.
    statement = sa.select(teams).order_by(teams.c["Team"])
    documents = CrateDBLoader(
        query=statement,
        db=db,
        include_query_into_metadata=True,
    ).load()

    # The query text the loader is expected to record in the metadata.
    expected_query = (
        'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", '
        'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 '
        'ORDER BY mlb_teams_2012."Team"'
    )

    assert len(documents) == 30
    first = documents[0]
    assert first.page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89"
    assert first.metadata == {"query": expected_query}

0 comments on commit 8dade36

Please sign in to comment.