Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: partition_pdf() support language specification for PaddleOCR #3400

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

### Features

* **Add support for specifying OCR language to `partition_pdf()`.** Extend language specification capability to `PaddleOCR` in addition to `TesseractOCR`. Users can now specify OCR languages for both OCR engines when using `partition_pdf()`.
* **Add AstraDB source connector** Adds support for ingesting documents from AstraDB.

### Fixes
Expand Down
28 changes: 9 additions & 19 deletions test_unstructured/partition/pdf_image/test_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Source,
)
from unstructured.partition.utils.ocr_models.google_vision_ocr import OCRAgentGoogleVision
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
from unstructured.partition.utils.ocr_models.paddle_ocr import OCRAgentPaddle
from unstructured.partition.utils.ocr_models.tesseract_ocr import (
OCRAgentTesseract,
Expand Down Expand Up @@ -85,10 +86,7 @@ def test_get_ocr_layout_from_image_tesseract(monkeypatch):
image = Image.new("RGB", (100, 100))

ocr_agent = OCRAgentTesseract()
ocr_layout = ocr_agent.get_layout_from_image(
image,
ocr_languages="eng",
)
ocr_layout = ocr_agent.get_layout_from_image(image)

expected_layout = [
TextRegion.from_coords(10, 5, 25, 15, "Hello", source=Source.OCR_TESSERACT),
Expand Down Expand Up @@ -128,7 +126,7 @@ def mock_ocr(*args, **kwargs):
]


def monkeypatch_load_agent(language: str):
def monkeypatch_load_agent(*args):
class MockAgent:
def __init__(self):
self.ocr = mock_ocr
Expand All @@ -145,10 +143,7 @@ def test_get_ocr_layout_from_image_paddle(monkeypatch):

image = Image.new("RGB", (100, 100))

ocr_layout = OCRAgentPaddle().get_layout_from_image(
image,
ocr_languages="eng",
)
ocr_layout = OCRAgentPaddle().get_layout_from_image(image)

expected_layout = [
TextRegion.from_coords(10, 5, 25, 15, "Hello", source=Source.OCR_PADDLE),
Expand All @@ -168,10 +163,7 @@ def test_get_ocr_text_from_image_tesseract(monkeypatch):
image = Image.new("RGB", (100, 100))

ocr_agent = OCRAgentTesseract()
ocr_text = ocr_agent.get_text_from_image(
image,
ocr_languages="eng",
)
ocr_text = ocr_agent.get_text_from_image(image)

assert ocr_text == "Hello World"

Expand All @@ -186,10 +178,7 @@ def test_get_ocr_text_from_image_paddle(monkeypatch):
image = Image.new("RGB", (100, 100))

ocr_agent = OCRAgentPaddle()
ocr_text = ocr_agent.get_text_from_image(
image,
ocr_languages="eng",
)
ocr_text = ocr_agent.get_text_from_image(image)

assert ocr_text == "Hello\n\nWorld\n\n!"

Expand Down Expand Up @@ -251,7 +240,7 @@ def test_get_ocr_from_image_google_vision(google_vision_client):
image = Image.new("RGB", (100, 100))

ocr_agent = google_vision_client
ocr_text = ocr_agent.get_text_from_image(image, ocr_languages="eng")
ocr_text = ocr_agent.get_text_from_image(image)

assert ocr_text == "Hello World!"

Expand Down Expand Up @@ -428,7 +417,8 @@ def mock_ocr_layout():

def test_get_table_tokens(mock_ocr_layout):
with patch.object(OCRAgentTesseract, "get_layout_from_image", return_value=mock_ocr_layout):
table_tokens = ocr.get_table_tokens(table_element_image=None)
ocr_agent = OCRAgent.get_agent(language="eng")
table_tokens = ocr.get_table_tokens(table_element_image=None, ocr_agent=ocr_agent)
expected_tokens = [
{
"bbox": [15, 25, 35, 45],
Expand Down
34 changes: 34 additions & 0 deletions test_unstructured/partition/test_lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
check_language_args,
detect_languages,
prepare_languages_for_tesseract,
tesseract_to_paddle_language,
)

DIRECTORY = pathlib.Path(__file__).parent.resolve()
Expand Down Expand Up @@ -84,6 +85,39 @@ def test_prepare_languages_for_tesseract_no_valid_languages(caplog):
assert "Failed to find any valid standard language code from languages" in caplog.text


@pytest.mark.parametrize(
    ("tesseract_lang", "expected_lang"),
    [
        ("eng", "en"),
        ("chi_sim", "ch"),
        ("chi_tra", "chinese_cht"),
        ("deu", "german"),
        ("jpn", "japan"),
        ("kor", "korean"),
    ],
)
def test_tesseract_to_paddle_language_valid_codes(tesseract_lang, expected_lang):
    """Each known Tesseract code converts to the expected PaddleOCR code."""
    paddle_lang = tesseract_to_paddle_language(tesseract_lang)
    assert paddle_lang == expected_lang


def test_tesseract_to_paddle_language_invalid_codes(caplog):
    """An unmapped code yields `en` and the unsupported-language message shows up in the logs."""
    fallback_lang = tesseract_to_paddle_language("unsupported_lang")

    assert fallback_lang == "en"
    assert "unsupported_lang is not a language code supported by PaddleOCR," in caplog.text


@pytest.mark.parametrize(
    ("tesseract_lang", "expected_lang"),
    [
        ("ENG", "en"),
        ("Fra", "fr"),
        ("DEU", "german"),
    ],
)
def test_tesseract_to_paddle_language_case_sensitivity(tesseract_lang, expected_lang):
    """Upper- and mixed-case Tesseract codes convert the same as their lower-case form."""
    paddle_lang = tesseract_to_paddle_language(tesseract_lang)
    assert paddle_lang == expected_lang


def test_detect_languages_english_auto():
text = "This is a short sentence."
assert detect_languages(text) == ["eng"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,18 @@ def it_provides_access_to_the_configured_OCR_agent(
_get_ocr_agent_cls_qname_.return_value = OCR_AGENT_TESSERACT
get_instance_.return_value = ocr_agent_

ocr_agent = OCRAgent.get_agent()
ocr_agent = OCRAgent.get_agent(language="eng")

_get_ocr_agent_cls_qname_.assert_called_once_with()
get_instance_.assert_called_once_with(OCR_AGENT_TESSERACT)
get_instance_.assert_called_once_with(OCR_AGENT_TESSERACT, "eng")
assert ocr_agent is ocr_agent_

def but_it_raises_when_the_requested_agent_is_not_whitelisted(
self, _get_ocr_agent_cls_qname_: Mock
):
_get_ocr_agent_cls_qname_.return_value = "Invalid.Ocr.Agent.Qname"
with pytest.raises(ValueError, match="must be set to a whitelisted module"):
OCRAgent.get_agent()
OCRAgent.get_agent(language="eng")

@pytest.mark.parametrize("exception_cls", [ImportError, AttributeError])
def and_it_raises_when_the_requested_agent_cannot_be_loaded(
Expand All @@ -57,7 +57,7 @@ def and_it_raises_when_the_requested_agent_cannot_be_loaded(
"unstructured.partition.utils.ocr_models.ocr_interface.importlib.import_module",
side_effect=exception_cls,
), pytest.raises(RuntimeError, match="Could not get the OCRAgent instance"):
OCRAgent.get_agent()
OCRAgent.get_agent(language="eng")

@pytest.mark.parametrize(
("OCR_AGENT", "expected_value"),
Expand Down
5 changes: 4 additions & 1 deletion unstructured/metrics/table_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from unstructured.partition.pdf import convert_pdf_to_images
from unstructured.partition.pdf_image.ocr import get_table_tokens
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
from unstructured.utils import requires_dependencies


Expand All @@ -21,8 +22,10 @@ def image_or_pdf_to_dataframe(filename: str) -> pd.DataFrame:
else:
image = Image.open(filename).convert("RGB")

ocr_agent = OCRAgent.get_agent(language="eng")

return tables_agent.run_prediction(
image, ocr_tokens=get_table_tokens(image), result_format="dataframe"
image, ocr_tokens=get_table_tokens(image, ocr_agent), result_format="dataframe"
)


Expand Down
2 changes: 1 addition & 1 deletion unstructured/partition/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def partition_image(
"""
exactly_one(filename=filename, file=file)

languages = check_language_args(languages or [], ocr_languages) or ["eng"]
languages = check_language_args(languages or [], ocr_languages)

return partition_pdf_or_image(
filename=filename,
Expand Down
76 changes: 76 additions & 0 deletions unstructured/partition/lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,63 @@
"yor",
]

# Maps Tesseract language codes (keys, lower-case) to the language codes PaddleOCR expects
# (values). Codes missing from this map are handled by the caller's fallback.
PYTESSERACT_TO_PADDLE_LANG_CODE_MAP = {
    "afr": "af",  # Afrikaans
    "ara": "ar",  # Arabic
    "aze": "az",  # Azerbaijani
    "bel": "be",  # Belarusian
    "bos": "bs",  # Bosnian
    "bul": "bg",  # Bulgarian
    "ces": "cs",  # Czech
    "chi_sim": "ch",  # Simplified Chinese
    "chi_tra": "chinese_cht",  # Traditional Chinese
    "cym": "cy",  # Welsh
    "dan": "da",  # Danish
    "deu": "german",  # German
    "eng": "en",  # English
    "est": "et",  # Estonian
    "fas": "fa",  # Persian
    "fra": "fr",  # French
    "gle": "ga",  # Irish
    "hin": "hi",  # Hindi
    "hrv": "hr",  # Croatian
    "hun": "hu",  # Hungarian
    "ind": "id",  # Indonesian
    "isl": "is",  # Icelandic
    "ita": "it",  # Italian
    "jpn": "japan",  # Japanese
    "kor": "korean",  # Korean
    "kmr": "ku",  # Kurdish
    # NOTE: PaddleOCR's `rs_latin` is Serbian written in Latin script, not the Latin language;
    # the Latin language uses `la`.
    "lat": "la",  # Latin
    "lav": "lv",  # Latvian
    "lit": "lt",  # Lithuanian
    "mar": "mr",  # Marathi
    "mlt": "mt",  # Maltese
    "msa": "ms",  # Malay
    "nep": "ne",  # Nepali
    "nld": "nl",  # Dutch
    "nor": "no",  # Norwegian
    "pol": "pl",  # Polish
    "por": "pt",  # Portuguese
    "ron": "ro",  # Romanian
    "rus": "ru",  # Russian
    "slk": "sk",  # Slovak
    "slv": "sl",  # Slovenian
    "spa": "es",  # Spanish
    "sqi": "sq",  # Albanian
    "srp": "rs_cyrillic",  # Serbian (Cyrillic script)
    "srp_latn": "rs_latin",  # Serbian (Latin script)
    "swa": "sw",  # Swahili
    "swe": "sv",  # Swedish
    "tam": "ta",  # Tamil
    "tel": "te",  # Telugu
    "tur": "tr",  # Turkish
    "uig": "ug",  # Uyghur
    "ukr": "uk",  # Ukrainian
    "urd": "ur",  # Urdu
    "uzb": "uz",  # Uzbek
    "vie": "vi",  # Vietnamese
}


def prepare_languages_for_tesseract(languages: Optional[list[str]] = ["eng"]) -> str:
"""
Expand All @@ -169,6 +226,25 @@ def prepare_languages_for_tesseract(languages: Optional[list[str]] = ["eng"]) ->
return TESSERACT_LANGUAGES_SPLITTER.join(converted_languages)


def tesseract_to_paddle_language(tesseract_language: str) -> str:
    """
    Convert TesseractOCR language code to PaddleOCR language code.

    The lookup is case-insensitive: the code is lower-cased before it is matched against
    `PYTESSERACT_TO_PADDLE_LANG_CODE_MAP`.

    :param tesseract_language: str, language code used in TesseractOCR (e.g. "eng")
    :return: str, corresponding language code for PaddleOCR; when the code has no mapping a
        warning is logged and "en" is returned instead (this function never returns None)
    """

    paddle_language = PYTESSERACT_TO_PADDLE_LANG_CODE_MAP.get(tesseract_language.lower())
    if paddle_language:
        return paddle_language

    # Unmapped code: warn using the caller's original spelling and fall back to English
    # rather than failing.
    logger.warning(
        f"{tesseract_language} is not a language code supported by PaddleOCR, "
        f"proceeding with `en` instead."
    )
    return "en"


def check_language_args(languages: list[str], ocr_languages: Optional[str]) -> Optional[list[str]]:
"""Handle users defining both `ocr_languages` and `languages`, giving preference to `languages`
and converting `ocr_languages` if needed, but defaulting to `None.
Expand Down
Loading
Loading