-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Document Loader: Add adapter for loading data from database (CrateDB)
Building on previous contributions, `CrateDBLoader` has effectively become a thin wrapper around LangChain Community's `SQLDatabaseLoader` that, quite literally, only gives it a new name.
- Loading branch information
Showing
5 changed files
with
273 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,72 +1,5 @@ | ||
"""CrateDB document loader.""" | ||
from langchain_community.document_loaders import SQLDatabaseLoader | ||
|
||
from typing import Iterator | ||
|
||
from langchain_core.document_loaders.base import BaseLoader | ||
from langchain_core.documents import Document | ||
|
||
|
||
class CrateDBLoader(BaseLoader):
    """
    CrateDB document loader integration (placeholder).

    This class is an unfinished scaffold: ``lazy_load`` is not implemented
    and raises ``NotImplementedError``, so no documents can be loaded with
    it yet. Nothing in this class reads environment variables.

    Setup:
        Install ``langchain-cratedb``.

        .. code-block:: bash

            pip install -U langchain-cratedb

    Import:
        .. code-block:: python

            from langchain_cratedb import CrateDBLoader

    Lazy load (intended usage once ``lazy_load`` is implemented):
        .. code-block:: python

            docs = []
            for doc in loader.lazy_load():
                docs.append(doc)
            print(docs[0].page_content[:100])
            print(docs[0].metadata)
    """

    # TODO: This method must be implemented to load documents.
    # Do not implement load(), a default implementation is already available.
    def lazy_load(self) -> Iterator[Document]:
        """Yield documents one at a time.

        Raises:
            NotImplementedError: Always; loading is not implemented yet.
        """
        raise NotImplementedError("CrateDBLoader.lazy_load is not implemented")

    # TODO: Implement if you would like to change default BaseLoader implementation
    # async def alazy_load(self) -> AsyncIterator[Document]:
class CrateDBLoader(SQLDatabaseLoader):
    """CrateDB document loader.

    Adds no attributes or methods of its own; all behavior is inherited
    unchanged from ``SQLDatabaseLoader``.
    """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
"""Module defines common test data.""" | ||
|
||
from pathlib import Path | ||
|
||
# Directory that contains this module.
_THIS_DIR = Path(__file__).parent

# Directory holding the data files, located next to this module.
_DATA_DIR = _THIS_DIR.joinpath("data")

# Paths to data files
MLB_TEAMS_2012_CSV = _DATA_DIR.joinpath("mlb_teams_2012.csv")
MLB_TEAMS_2012_SQL = _DATA_DIR.joinpath("mlb_teams_2012.sql")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
"Team", "Payroll (millions)", "Wins" | ||
"Nationals", 81.34, 98 | ||
"Reds", 82.20, 97 | ||
"Yankees", 197.96, 95 | ||
"Giants", 117.62, 94 | ||
"Braves", 83.31, 94 | ||
"Athletics", 55.37, 94 | ||
"Rangers", 120.51, 93 | ||
"Orioles", 81.43, 93 | ||
"Rays", 64.17, 90 | ||
"Angels", 154.49, 89 | ||
"Tigers", 132.30, 88 | ||
"Cardinals", 110.30, 88 | ||
"Dodgers", 95.14, 86 | ||
"White Sox", 96.92, 85 | ||
"Brewers", 97.65, 83 | ||
"Phillies", 174.54, 81 | ||
"Diamondbacks", 74.28, 81 | ||
"Pirates", 63.43, 79 | ||
"Padres", 55.24, 76 | ||
"Mariners", 81.97, 75 | ||
"Mets", 93.35, 74 | ||
"Blue Jays", 75.48, 73 | ||
"Royals", 60.91, 72 | ||
"Marlins", 118.07, 69 | ||
"Red Sox", 173.18, 69 | ||
"Indians", 78.43, 68 | ||
"Twins", 94.08, 66 | ||
"Rockies", 78.06, 64 | ||
"Cubs", 88.19, 61 | ||
"Astros", 60.65, 55 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
-- Provisioning table "mlb_teams_2012". | ||
-- | ||
-- psql postgresql://postgres@localhost < mlb_teams_2012.sql | ||
-- crash < mlb_teams_2012.sql | ||
|
||
DROP TABLE IF EXISTS mlb_teams_2012; | ||
CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT); | ||
INSERT INTO mlb_teams_2012 | ||
("Team", "Payroll (millions)", "Wins") | ||
VALUES | ||
('Nationals', 81.34, 98), | ||
('Reds', 82.20, 97), | ||
('Yankees', 197.96, 95), | ||
('Giants', 117.62, 94), | ||
('Braves', 83.31, 94), | ||
('Athletics', 55.37, 94), | ||
('Rangers', 120.51, 93), | ||
('Orioles', 81.43, 93), | ||
('Rays', 64.17, 90), | ||
('Angels', 154.49, 89), | ||
('Tigers', 132.30, 88), | ||
('Cardinals', 110.30, 88), | ||
('Dodgers', 95.14, 86), | ||
('White Sox', 96.92, 85), | ||
('Brewers', 97.65, 83), | ||
('Phillies', 174.54, 81), | ||
('Diamondbacks', 74.28, 81), | ||
('Pirates', 63.43, 79), | ||
('Padres', 55.24, 76), | ||
('Mariners', 81.97, 75), | ||
('Mets', 93.35, 74), | ||
('Blue Jays', 75.48, 73), | ||
('Royals', 60.91, 72), | ||
('Marlins', 118.07, 69), | ||
('Red Sox', 173.18, 69), | ||
('Indians', 78.43, 68), | ||
('Twins', 94.08, 66), | ||
('Rockies', 78.06, 64), | ||
('Cubs', 88.19, 61), | ||
('Astros', 60.65, 55) | ||
; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
""" | ||
Test SQLAlchemy document loader functionality on behalf of CrateDB. | ||
""" | ||
|
||
import functools | ||
import logging | ||
|
||
import pytest | ||
import sqlalchemy as sa | ||
from langchain_community.document_loaders.sql_database import SQLDatabaseLoader | ||
from langchain_community.utilities.sql_database import SQLDatabase | ||
|
||
from langchain_cratedb import CrateDBLoader | ||
from tests.data import MLB_TEAMS_2012_SQL | ||
|
||
# Configure the root logger at DEBUG level when this module is imported.
logging.basicConfig(level=logging.DEBUG)
|
||
|
||
@pytest.fixture()
def db(engine: sa.Engine) -> SQLDatabase:
    """Return an ``SQLDatabase`` wrapper built on the ``engine`` fixture."""
    database = SQLDatabase(engine=engine)
    return database
|
||
|
||
@pytest.fixture()
def provision_database(engine: sa.Engine) -> None:
    """
    Create and populate the ``mlb_teams_2012`` table.

    Reads the SQL script referenced by ``MLB_TEAMS_2012_SQL``, splits it on
    ``;`` and executes every non-empty statement, committing after each one.
    """
    script = MLB_TEAMS_2012_SQL.read_text()
    # Whitespace-only fragments (e.g. after the final ";") become empty strings.
    fragments = [fragment.strip() for fragment in script.split(";")]

    with engine.connect() as connection:
        connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;"))
        for statement in filter(None, fragments):
            connection.execute(sa.text(statement))
            connection.commit()

        # Only issued on dialects whose name starts with "crate"; presumably
        # needed so the inserted rows are visible to later reads - TODO confirm.
        is_cratedb = engine.dialect.name.startswith("crate")
        if is_cratedb:
            connection.execute(sa.text("REFRESH TABLE mlb_teams_2012;"))
            connection.commit()
|
||
|
||
def test_cratedb_loader_no_options(db: SQLDatabase) -> None:
    """Load a single constant row using the loader's default options."""

    documents = CrateDBLoader("SELECT 1 AS a, 2 AS b", db=db).load()

    assert len(documents) == 1
    first = documents[0]
    assert first.page_content == "a: 1\nb: 2"
    assert first.metadata == {}
|
||
|
||
def test_cratedb_loader_include_rownum_into_metadata(db: SQLDatabase) -> None:
    """Check that the row number ends up in the document metadata."""

    options = {"db": db, "include_rownum_into_metadata": True}
    documents = CrateDBLoader("SELECT 1 AS a, 2 AS b", **options).load()

    assert len(documents) == 1
    first = documents[0]
    assert first.page_content == "a: 1\nb: 2"
    assert first.metadata == {"row": 0}
|
||
|
||
def test_cratedb_loader_include_query_into_metadata(db: SQLDatabase) -> None:
    """Check that the executed query ends up in the document metadata."""

    loader = CrateDBLoader(
        "SELECT 1 AS a, 2 AS b",
        db=db,
        include_query_into_metadata=True,
    )
    documents = loader.load()

    assert len(documents) == 1
    first = documents[0]
    assert first.page_content == "a: 1\nb: 2"
    assert first.metadata == {"query": "SELECT 1 AS a, 2 AS b"}
|
||
|
||
def test_cratedb_loader_page_content_columns(db: SQLDatabase) -> None:
    """Test SQLAlchemy loader with defined page content columns."""

    # Define a custom callback function to convert a row into a "page content" string.
    row_to_content = functools.partial(
        SQLDatabaseLoader.page_content_default_mapper, column_names=["a"]
    )

    # The assertions below index the documents by position, so the result
    # order has to be pinned: a bare UNION gives no ordering guarantee.
    loader = CrateDBLoader(
        "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b ORDER BY a",
        db=db,
        page_content_mapper=row_to_content,
    )
    docs = loader.load()

    assert len(docs) == 2
    assert docs[0].page_content == "a: 1"
    assert docs[0].metadata == {}

    assert docs[1].page_content == "a: 3"
    assert docs[1].metadata == {}
|
||
|
||
def test_cratedb_loader_metadata_columns(db: SQLDatabase) -> None:
    """Check that a custom metadata mapper restricts metadata to column ``b``."""

    # Callback converting a row into a "metadata" dictionary, limited to "b".
    metadata_from_b = functools.partial(
        SQLDatabaseLoader.metadata_default_mapper,
        column_names=["b"],
    )

    documents = CrateDBLoader(
        "SELECT 1 AS a, 2 AS b", db=db, metadata_mapper=metadata_from_b
    ).load()

    assert len(documents) == 1
    assert documents[0].metadata == {"b": 2}
|
||
|
||
def test_cratedb_loader_real_data_with_sql_no_parameters(
    db: SQLDatabase, provision_database: None
) -> None:
    """Load the provisioned table with a plain SQL statement."""

    sql = 'SELECT * FROM mlb_teams_2012 ORDER BY "Team";'
    documents = CrateDBLoader(query=sql, db=db).load()

    assert len(documents) == 30
    first = documents[0]
    assert first.page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89"
    assert first.metadata == {}
|
||
|
||
def test_cratedb_loader_real_data_with_sql_and_parameters(
    db: SQLDatabase, provision_database: None
) -> None:
    """Load the provisioned table with an SQL statement and bound parameters."""

    sql = 'SELECT * FROM mlb_teams_2012 WHERE "Team" LIKE :search ORDER BY "Team";'
    bound_values = {"search": "R%"}
    documents = CrateDBLoader(query=sql, parameters=bound_values, db=db).load()

    assert len(documents) == 6
    first = documents[0]
    assert first.page_content == "Team: Rangers\nPayroll (millions): 120.51\nWins: 93"
    assert first.metadata == {}
|
||
|
||
def test_cratedb_loader_real_data_with_selectable(
    db: SQLDatabase, provision_database: None
) -> None:
    """Load the provisioned table through an SQLAlchemy selectable."""

    # SQLAlchemy description of the provisioned table.
    columns = [
        sa.Column("Team", sa.VARCHAR),
        sa.Column("Payroll (millions)", sa.FLOAT),
        sa.Column("Wins", sa.BIGINT),
    ]
    teams = sa.Table("mlb_teams_2012", sa.MetaData(), *columns)

    # Select every row, ordered by team name, and record the query in metadata.
    statement = sa.select(teams).order_by(teams.c.Team)
    documents = CrateDBLoader(
        query=statement, db=db, include_query_into_metadata=True
    ).load()

    expected_query = (
        'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", '
        'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 '
        'ORDER BY mlb_teams_2012."Team"'
    )

    assert len(documents) == 30
    first = documents[0]
    assert first.page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89"
    assert first.metadata == {"query": expected_query}