Skip to content

Commit

Permalink
feat: partition_pdf() support language specification for PaddleOCR (#…
Browse files Browse the repository at this point in the history
…3400)

Closes #3159.

This PR extends language specification capability to `PaddleOCR` in
addition to `TesseractOCR`. Users can now specify OCR languages for both
OCR engines when using `partition_pdf()`.

### Testing

```
os.environ["OCR_AGENT"] = "unstructured.partition.utils.ocr_models.paddle_ocr.OCRAgentPaddle"

elements = partition_pdf(
    filename=<file_path>,
    strategy=strategy,
    languages=["chi_sim"], # chinese - simplified
    infer_table_structure=True,
)
```
  • Loading branch information
christinestraub authored Jul 16, 2024
1 parent 6b1d5f2 commit 48bdf94
Show file tree
Hide file tree
Showing 13 changed files with 186 additions and 108 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

### Features

* **Add support for specifying OCR language to `partition_pdf()`.** Extend language specification capability to `PaddleOCR` in addition to `TesseractOCR`. Users can now specify OCR languages for both OCR engines when using `partition_pdf()`.
* **Add AstraDB source connector** Adds support for ingesting documents from AstraDB.

### Fixes
Expand Down
28 changes: 9 additions & 19 deletions test_unstructured/partition/pdf_image/test_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Source,
)
from unstructured.partition.utils.ocr_models.google_vision_ocr import OCRAgentGoogleVision
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
from unstructured.partition.utils.ocr_models.paddle_ocr import OCRAgentPaddle
from unstructured.partition.utils.ocr_models.tesseract_ocr import (
OCRAgentTesseract,
Expand Down Expand Up @@ -85,10 +86,7 @@ def test_get_ocr_layout_from_image_tesseract(monkeypatch):
image = Image.new("RGB", (100, 100))

ocr_agent = OCRAgentTesseract()
ocr_layout = ocr_agent.get_layout_from_image(
image,
ocr_languages="eng",
)
ocr_layout = ocr_agent.get_layout_from_image(image)

expected_layout = [
TextRegion.from_coords(10, 5, 25, 15, "Hello", source=Source.OCR_TESSERACT),
Expand Down Expand Up @@ -128,7 +126,7 @@ def mock_ocr(*args, **kwargs):
]


def monkeypatch_load_agent(language: str):
def monkeypatch_load_agent(*args):
class MockAgent:
def __init__(self):
self.ocr = mock_ocr
Expand All @@ -145,10 +143,7 @@ def test_get_ocr_layout_from_image_paddle(monkeypatch):

image = Image.new("RGB", (100, 100))

ocr_layout = OCRAgentPaddle().get_layout_from_image(
image,
ocr_languages="eng",
)
ocr_layout = OCRAgentPaddle().get_layout_from_image(image)

expected_layout = [
TextRegion.from_coords(10, 5, 25, 15, "Hello", source=Source.OCR_PADDLE),
Expand All @@ -168,10 +163,7 @@ def test_get_ocr_text_from_image_tesseract(monkeypatch):
image = Image.new("RGB", (100, 100))

ocr_agent = OCRAgentTesseract()
ocr_text = ocr_agent.get_text_from_image(
image,
ocr_languages="eng",
)
ocr_text = ocr_agent.get_text_from_image(image)

assert ocr_text == "Hello World"

Expand All @@ -186,10 +178,7 @@ def test_get_ocr_text_from_image_paddle(monkeypatch):
image = Image.new("RGB", (100, 100))

ocr_agent = OCRAgentPaddle()
ocr_text = ocr_agent.get_text_from_image(
image,
ocr_languages="eng",
)
ocr_text = ocr_agent.get_text_from_image(image)

assert ocr_text == "Hello\n\nWorld\n\n!"

Expand Down Expand Up @@ -251,7 +240,7 @@ def test_get_ocr_from_image_google_vision(google_vision_client):
image = Image.new("RGB", (100, 100))

ocr_agent = google_vision_client
ocr_text = ocr_agent.get_text_from_image(image, ocr_languages="eng")
ocr_text = ocr_agent.get_text_from_image(image)

assert ocr_text == "Hello World!"

Expand Down Expand Up @@ -428,7 +417,8 @@ def mock_ocr_layout():

def test_get_table_tokens(mock_ocr_layout):
with patch.object(OCRAgentTesseract, "get_layout_from_image", return_value=mock_ocr_layout):
table_tokens = ocr.get_table_tokens(table_element_image=None)
ocr_agent = OCRAgent.get_agent(language="eng")
table_tokens = ocr.get_table_tokens(table_element_image=None, ocr_agent=ocr_agent)
expected_tokens = [
{
"bbox": [15, 25, 35, 45],
Expand Down
34 changes: 34 additions & 0 deletions test_unstructured/partition/test_lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
check_language_args,
detect_languages,
prepare_languages_for_tesseract,
tesseract_to_paddle_language,
)

DIRECTORY = pathlib.Path(__file__).parent.resolve()
Expand Down Expand Up @@ -84,6 +85,39 @@ def test_prepare_languages_for_tesseract_no_valid_languages(caplog):
assert "Failed to find any valid standard language code from languages" in caplog.text


@pytest.mark.parametrize(
    ("tesseract_lang", "expected_lang"),
    [
        ("eng", "en"),
        ("chi_sim", "ch"),
        ("chi_tra", "chinese_cht"),
        ("deu", "german"),
        ("jpn", "japan"),
        ("kor", "korean"),
    ],
)
def test_tesseract_to_paddle_language_valid_codes(tesseract_lang, expected_lang):
    """Each known Tesseract code is translated to its PaddleOCR code."""
    paddle_lang = tesseract_to_paddle_language(tesseract_lang)

    assert paddle_lang == expected_lang


def test_tesseract_to_paddle_language_invalid_codes(caplog):
    """An unmapped code yields the `en` fallback and a logged warning."""
    paddle_lang = tesseract_to_paddle_language("unsupported_lang")

    assert paddle_lang == "en"
    assert "unsupported_lang is not a language code supported by PaddleOCR," in caplog.text


@pytest.mark.parametrize(
    ("tesseract_lang", "expected_lang"),
    [
        ("ENG", "en"),
        ("Fra", "fr"),
        ("DEU", "german"),
    ],
)
def test_tesseract_to_paddle_language_case_sensitivity(tesseract_lang, expected_lang):
    """Upper- and mixed-case Tesseract codes resolve like their lower-case form."""
    paddle_lang = tesseract_to_paddle_language(tesseract_lang)

    assert paddle_lang == expected_lang


def test_detect_languages_english_auto():
text = "This is a short sentence."
assert detect_languages(text) == ["eng"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,18 @@ def it_provides_access_to_the_configured_OCR_agent(
_get_ocr_agent_cls_qname_.return_value = OCR_AGENT_TESSERACT
get_instance_.return_value = ocr_agent_

ocr_agent = OCRAgent.get_agent()
ocr_agent = OCRAgent.get_agent(language="eng")

_get_ocr_agent_cls_qname_.assert_called_once_with()
get_instance_.assert_called_once_with(OCR_AGENT_TESSERACT)
get_instance_.assert_called_once_with(OCR_AGENT_TESSERACT, "eng")
assert ocr_agent is ocr_agent_

def but_it_raises_when_the_requested_agent_is_not_whitelisted(
self, _get_ocr_agent_cls_qname_: Mock
):
_get_ocr_agent_cls_qname_.return_value = "Invalid.Ocr.Agent.Qname"
with pytest.raises(ValueError, match="must be set to a whitelisted module"):
OCRAgent.get_agent()
OCRAgent.get_agent(language="eng")

@pytest.mark.parametrize("exception_cls", [ImportError, AttributeError])
def and_it_raises_when_the_requested_agent_cannot_be_loaded(
Expand All @@ -57,7 +57,7 @@ def and_it_raises_when_the_requested_agent_cannot_be_loaded(
"unstructured.partition.utils.ocr_models.ocr_interface.importlib.import_module",
side_effect=exception_cls,
), pytest.raises(RuntimeError, match="Could not get the OCRAgent instance"):
OCRAgent.get_agent()
OCRAgent.get_agent(language="eng")

@pytest.mark.parametrize(
("OCR_AGENT", "expected_value"),
Expand Down
5 changes: 4 additions & 1 deletion unstructured/metrics/table_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from unstructured.partition.pdf import convert_pdf_to_images
from unstructured.partition.pdf_image.ocr import get_table_tokens
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
from unstructured.utils import requires_dependencies


Expand All @@ -21,8 +22,10 @@ def image_or_pdf_to_dataframe(filename: str) -> pd.DataFrame:
else:
image = Image.open(filename).convert("RGB")

ocr_agent = OCRAgent.get_agent(language="eng")

return tables_agent.run_prediction(
image, ocr_tokens=get_table_tokens(image), result_format="dataframe"
image, ocr_tokens=get_table_tokens(image, ocr_agent), result_format="dataframe"
)


Expand Down
2 changes: 1 addition & 1 deletion unstructured/partition/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def partition_image(
"""
exactly_one(filename=filename, file=file)

languages = check_language_args(languages or [], ocr_languages) or ["eng"]
languages = check_language_args(languages or [], ocr_languages)

return partition_pdf_or_image(
filename=filename,
Expand Down
76 changes: 76 additions & 0 deletions unstructured/partition/lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,63 @@
"yor",
]

# Maps lower-case Tesseract language codes to the language codes PaddleOCR expects.
# NOTE: PaddleOCR uses `la` for Latin; `rs_latin` / `rs_cyrillic` are its codes for
# Serbian written in Latin / Cyrillic script, so Tesseract's `srp_latn` and `srp` map there.
PYTESSERACT_TO_PADDLE_LANG_CODE_MAP = {
    "afr": "af",  # Afrikaans
    "ara": "ar",  # Arabic
    "aze": "az",  # Azerbaijani
    "bel": "be",  # Belarusian
    "bos": "bs",  # Bosnian
    "bul": "bg",  # Bulgarian
    "ces": "cs",  # Czech
    "chi_sim": "ch",  # Simplified Chinese
    "chi_tra": "chinese_cht",  # Traditional Chinese
    "cym": "cy",  # Welsh
    "dan": "da",  # Danish
    "deu": "german",  # German
    "eng": "en",  # English
    "est": "et",  # Estonian
    "fas": "fa",  # Persian
    "fra": "fr",  # French
    "gle": "ga",  # Irish
    "hin": "hi",  # Hindi
    "hrv": "hr",  # Croatian
    "hun": "hu",  # Hungarian
    "ind": "id",  # Indonesian
    "isl": "is",  # Icelandic
    "ita": "it",  # Italian
    "jpn": "japan",  # Japanese
    "kor": "korean",  # Korean
    "kmr": "ku",  # Kurdish
    "lat": "la",  # Latin
    "lav": "lv",  # Latvian
    "lit": "lt",  # Lithuanian
    "mar": "mr",  # Marathi
    "mlt": "mt",  # Maltese
    "msa": "ms",  # Malay
    "nep": "ne",  # Nepali
    "nld": "nl",  # Dutch
    "nor": "no",  # Norwegian
    "pol": "pl",  # Polish
    "por": "pt",  # Portuguese
    "ron": "ro",  # Romanian
    "rus": "ru",  # Russian
    "slk": "sk",  # Slovak
    "slv": "sl",  # Slovenian
    "spa": "es",  # Spanish
    "sqi": "sq",  # Albanian
    "srp": "rs_cyrillic",  # Serbian (Cyrillic script)
    "srp_latn": "rs_latin",  # Serbian (Latin script)
    "swa": "sw",  # Swahili
    "swe": "sv",  # Swedish
    "tam": "ta",  # Tamil
    "tel": "te",  # Telugu
    "tur": "tr",  # Turkish
    "uig": "ug",  # Uyghur
    "ukr": "uk",  # Ukrainian
    "urd": "ur",  # Urdu
    "uzb": "uz",  # Uzbek
    "vie": "vi",  # Vietnamese
}


def prepare_languages_for_tesseract(languages: Optional[list[str]] = ["eng"]) -> str:
"""
Expand All @@ -169,6 +226,25 @@ def prepare_languages_for_tesseract(languages: Optional[list[str]] = ["eng"]) ->
return TESSERACT_LANGUAGES_SPLITTER.join(converted_languages)


def tesseract_to_paddle_language(tesseract_language: str) -> str:
    """Translate a TesseractOCR language code into the matching PaddleOCR code.

    The lookup is case-insensitive: the input is lower-cased before it is looked up
    in `PYTESSERACT_TO_PADDLE_LANG_CODE_MAP`.

    :param tesseract_language: str, language code used in TesseractOCR
    :return: str, corresponding language code for PaddleOCR; when no mapping is found
        a warning is logged and `en` is returned instead
    """
    normalized_code = tesseract_language.lower()
    paddle_code = PYTESSERACT_TO_PADDLE_LANG_CODE_MAP.get(normalized_code)

    # Happy path first: a known code is returned as-is.
    if paddle_code:
        return paddle_code

    # Unknown code: warn using the caller's original spelling and fall back to English.
    logger.warning(
        f"{tesseract_language} is not a language code supported by PaddleOCR, "
        f"proceeding with `en` instead."
    )
    return "en"


def check_language_args(languages: list[str], ocr_languages: Optional[str]) -> Optional[list[str]]:
"""Handle users defining both `ocr_languages` and `languages`, giving preference to `languages`
and converting `ocr_languages` if needed, but defaulting to `None.
Expand Down
Loading

0 comments on commit 48bdf94

Please sign in to comment.