Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: partition_pdf() support language specification for PaddleOCR #3400

Merged
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
## 0.15.0-dev12
## 0.15.0-dev13

### Enhancements

Expand All @@ -9,6 +9,7 @@

### Features

* **Add support for specifying OCR language to `partition_pdf()`.** Extend language specification capability to `PaddleOCR` in addition to `TesseractOCR`. Users can now specify OCR languages for both OCR engines when using `partition_pdf()`.
* **Add AstraDB source connector** Adds support for ingesting documents from AstraDB.

### Fixes
Expand Down
28 changes: 9 additions & 19 deletions test_unstructured/partition/pdf_image/test_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Source,
)
from unstructured.partition.utils.ocr_models.google_vision_ocr import OCRAgentGoogleVision
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
from unstructured.partition.utils.ocr_models.paddle_ocr import OCRAgentPaddle
from unstructured.partition.utils.ocr_models.tesseract_ocr import (
OCRAgentTesseract,
Expand Down Expand Up @@ -85,10 +86,7 @@ def test_get_ocr_layout_from_image_tesseract(monkeypatch):
image = Image.new("RGB", (100, 100))

ocr_agent = OCRAgentTesseract()
ocr_layout = ocr_agent.get_layout_from_image(
image,
ocr_languages="eng",
)
ocr_layout = ocr_agent.get_layout_from_image(image)

expected_layout = [
TextRegion.from_coords(10, 5, 25, 15, "Hello", source=Source.OCR_TESSERACT),
Expand Down Expand Up @@ -128,7 +126,7 @@ def mock_ocr(*args, **kwargs):
]


def monkeypatch_load_agent(language: str):
def monkeypatch_load_agent(*args):
class MockAgent:
def __init__(self):
self.ocr = mock_ocr
Expand All @@ -145,10 +143,7 @@ def test_get_ocr_layout_from_image_paddle(monkeypatch):

image = Image.new("RGB", (100, 100))

ocr_layout = OCRAgentPaddle().get_layout_from_image(
image,
ocr_languages="eng",
)
ocr_layout = OCRAgentPaddle().get_layout_from_image(image)

expected_layout = [
TextRegion.from_coords(10, 5, 25, 15, "Hello", source=Source.OCR_PADDLE),
Expand All @@ -168,10 +163,7 @@ def test_get_ocr_text_from_image_tesseract(monkeypatch):
image = Image.new("RGB", (100, 100))

ocr_agent = OCRAgentTesseract()
ocr_text = ocr_agent.get_text_from_image(
image,
ocr_languages="eng",
)
ocr_text = ocr_agent.get_text_from_image(image)

assert ocr_text == "Hello World"

Expand All @@ -186,10 +178,7 @@ def test_get_ocr_text_from_image_paddle(monkeypatch):
image = Image.new("RGB", (100, 100))

ocr_agent = OCRAgentPaddle()
ocr_text = ocr_agent.get_text_from_image(
image,
ocr_languages="eng",
)
ocr_text = ocr_agent.get_text_from_image(image)

assert ocr_text == "Hello\n\nWorld\n\n!"

Expand Down Expand Up @@ -251,7 +240,7 @@ def test_get_ocr_from_image_google_vision(google_vision_client):
image = Image.new("RGB", (100, 100))

ocr_agent = google_vision_client
ocr_text = ocr_agent.get_text_from_image(image, ocr_languages="eng")
ocr_text = ocr_agent.get_text_from_image(image)

assert ocr_text == "Hello World!"

Expand Down Expand Up @@ -428,7 +417,8 @@ def mock_ocr_layout():

def test_get_table_tokens(mock_ocr_layout):
with patch.object(OCRAgentTesseract, "get_layout_from_image", return_value=mock_ocr_layout):
table_tokens = ocr.get_table_tokens(table_element_image=None)
ocr_agent = OCRAgent.get_agent(language="eng")
table_tokens = ocr.get_table_tokens(table_element_image=None, ocr_agent=ocr_agent)
expected_tokens = [
{
"bbox": [15, 25, 35, 45],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,18 @@ def it_provides_access_to_the_configured_OCR_agent(
_get_ocr_agent_cls_qname_.return_value = OCR_AGENT_TESSERACT
get_instance_.return_value = ocr_agent_

ocr_agent = OCRAgent.get_agent()
ocr_agent = OCRAgent.get_agent(language="eng")

_get_ocr_agent_cls_qname_.assert_called_once_with()
get_instance_.assert_called_once_with(OCR_AGENT_TESSERACT)
get_instance_.assert_called_once_with(OCR_AGENT_TESSERACT, "eng")
assert ocr_agent is ocr_agent_

def but_it_raises_when_the_requested_agent_is_not_whitelisted(
self, _get_ocr_agent_cls_qname_: Mock
):
_get_ocr_agent_cls_qname_.return_value = "Invalid.Ocr.Agent.Qname"
with pytest.raises(ValueError, match="must be set to a whitelisted module"):
OCRAgent.get_agent()
OCRAgent.get_agent(language="eng")

@pytest.mark.parametrize("exception_cls", [ImportError, AttributeError])
def and_it_raises_when_the_requested_agent_cannot_be_loaded(
Expand All @@ -57,7 +57,7 @@ def and_it_raises_when_the_requested_agent_cannot_be_loaded(
"unstructured.partition.utils.ocr_models.ocr_interface.importlib.import_module",
side_effect=exception_cls,
), pytest.raises(RuntimeError, match="Could not get the OCRAgent instance"):
OCRAgent.get_agent()
OCRAgent.get_agent(language="eng")

@pytest.mark.parametrize(
("OCR_AGENT", "expected_value"),
Expand Down
2 changes: 1 addition & 1 deletion unstructured/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.15.0-dev12" # pragma: no cover
__version__ = "0.15.0-dev13" # pragma: no cover
5 changes: 4 additions & 1 deletion unstructured/metrics/table_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from unstructured.partition.pdf import convert_pdf_to_images
from unstructured.partition.pdf_image.ocr import get_table_tokens
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
from unstructured.utils import requires_dependencies


Expand All @@ -21,8 +22,10 @@ def image_or_pdf_to_dataframe(filename: str) -> pd.DataFrame:
else:
image = Image.open(filename).convert("RGB")

ocr_agent = OCRAgent.get_agent(language="eng")

return tables_agent.run_prediction(
image, ocr_tokens=get_table_tokens(image), result_format="dataframe"
image, ocr_tokens=get_table_tokens(image, ocr_agent), result_format="dataframe"
)


Expand Down
75 changes: 75 additions & 0 deletions unstructured/partition/lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,63 @@
"yor",
]

PYTESSERACT_TO_PADDLE_LANG_CODE_MAP = {
"afr": "af", # Afrikaans
"ara": "ar", # Arabic
"aze": "az", # Azerbaijani
"bel": "be", # Belarusian
"bos": "bs", # Bosnian
"bul": "bg", # Bulgarian
"ces": "cs", # Czech
"chi_sim": "ch", # Simplified Chinese
"chi_tra": "chinese_cht", # Traditional Chinese
"cym": "cy", # Welsh
"dan": "da", # Danish
"deu": "german", # German
"eng": "en", # English
"est": "et", # Estonian
"fas": "fa", # Persian
"fra": "fr", # French
"gle": "ga", # Irish
"hin": "hi", # Hindi
"hrv": "hr", # Croatian
"hun": "hu", # Hungarian
"ind": "id", # Indonesian
"isl": "is", # Icelandic
"ita": "it", # Italian
"jpn": "japan", # Japanese
"kor": "korean", # Korean
"kmr": "ku", # Kurdish
"lat": "rs_latin", # Latin
"lav": "lv", # Latvian
"lit": "lt", # Lithuanian
"mar": "mr", # Marathi
"mlt": "mt", # Maltese
"msa": "ms", # Malay
"nep": "ne", # Nepali
"nld": "nl", # Dutch
"nor": "no", # Norwegian
"pol": "pl", # Polish
"por": "pt", # Portuguese
"ron": "ro", # Romanian
"rus": "ru", # Russian
"slk": "sk", # Slovak
"slv": "sl", # Slovenian
"spa": "es", # Spanish
"sqi": "sq", # Albanian
"srp": "rs_cyrillic", # Serbian
"swa": "sw", # Swahili
"swe": "sv", # Swedish
"tam": "ta", # Tamil
"tel": "te", # Telugu
"tur": "tr", # Turkish
"uig": "ug", # Uyghur
"ukr": "uk", # Ukrainian
"urd": "ur", # Urdu
"uzb": "uz", # Uzbek
"vie": "vi", # Vietnamese
}


def prepare_languages_for_tesseract(languages: Optional[list[str]] = ["eng"]) -> str:
"""
Expand All @@ -169,6 +226,24 @@ def prepare_languages_for_tesseract(languages: Optional[list[str]] = ["eng"]) ->
return TESSERACT_LANGUAGES_SPLITTER.join(converted_languages)


def tesseract_to_paddle_language(tesseract_language: str = "eng") -> str:
"""
Convert TesseractOCR language code to PaddleOCR language code.

:param tesseract_language: str, language code used in TesseractOCR
:return: str, corresponding language code for PaddleOCR or None if not found
"""

lang = PYTESSERACT_TO_PADDLE_LANG_CODE_MAP.get(tesseract_language)
if not lang:
logger.warning(
f"{lang} is not a language supported by PaddleOCR, proceed with `en` instead."
christinestraub marked this conversation as resolved.
Show resolved Hide resolved
)
return "en"

return lang


def check_language_args(languages: list[str], ocr_languages: Optional[str]) -> Optional[list[str]]:
"""Handle users defining both `ocr_languages` and `languages`, giving preference to `languages`
and converting `ocr_languages` if needed, but defaulting to `None.
Expand Down
31 changes: 20 additions & 11 deletions unstructured/partition/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@
ocr_data_to_elements,
spooled_to_bytes_io_if_needed,
)
from unstructured.partition.lang import check_language_args, prepare_languages_for_tesseract
from unstructured.partition.lang import (
check_language_args,
prepare_languages_for_tesseract,
tesseract_to_paddle_language,
)
from unstructured.partition.pdf_image.analysis.bbox_visualisation import (
AnalysisDrawer,
FinalLayoutDrawer,
Expand Down Expand Up @@ -77,6 +81,7 @@
from unstructured.partition.text import element_from_text
from unstructured.partition.utils.config import env_config
from unstructured.partition.utils.constants import (
OCR_AGENT_PADDLE,
SORT_MODE_BASIC,
SORT_MODE_DONT,
SORT_MODE_XY_CUT,
Expand Down Expand Up @@ -227,7 +232,6 @@ def partition_pdf_or_image(
include_page_breaks: bool = False,
strategy: str = PartitionStrategy.AUTO,
infer_table_structure: bool = False,
ocr_languages: Optional[str] = None,
languages: Optional[list[str]] = None,
metadata_last_modified: Optional[str] = None,
hi_res_model_name: Optional[str] = None,
Expand Down Expand Up @@ -291,6 +295,10 @@ def partition_pdf_or_image(
if file is not None:
file.seek(0)

ocr_languages = prepare_languages_for_tesseract(languages)
if env_config.OCR_AGENT == OCR_AGENT_PADDLE:
ocr_languages = tesseract_to_paddle_language(ocr_languages)

if strategy == PartitionStrategy.HI_RES:
# NOTE(robinson): Catches a UserWarning that occurs when detection is called
with warnings.catch_warnings():
Expand All @@ -302,6 +310,7 @@ def partition_pdf_or_image(
infer_table_structure=infer_table_structure,
include_page_breaks=include_page_breaks,
languages=languages,
ocr_languages=ocr_languages,
metadata_last_modified=metadata_last_modified or last_modification_date,
hi_res_model_name=hi_res_model_name,
pdf_text_extractable=pdf_text_extractable,
Expand Down Expand Up @@ -333,6 +342,7 @@ def partition_pdf_or_image(
file=file,
include_page_breaks=include_page_breaks,
languages=languages,
ocr_languages=ocr_languages,
is_image=is_image,
metadata_last_modified=metadata_last_modified or last_modification_date,
starting_page_number=starting_page_number,
Expand Down Expand Up @@ -500,6 +510,7 @@ def _partition_pdf_or_image_local(
infer_table_structure: bool = False,
include_page_breaks: bool = False,
languages: Optional[list[str]] = None,
ocr_languages: Optional[str] = None,
ocr_mode: str = OCRMode.FULL_PAGE.value,
model_name: Optional[str] = None, # to be deprecated in favor of `hi_res_model_name`
hi_res_model_name: Optional[str] = None,
Expand Down Expand Up @@ -532,8 +543,6 @@ def _partition_pdf_or_image_local(
if languages is None:
languages = ["eng"]

ocr_languages = prepare_languages_for_tesseract(languages)

hi_res_model_name = hi_res_model_name or model_name or default_hi_res_model()
if pdf_image_dpi is None:
pdf_image_dpi = 300 if hi_res_model_name.startswith("chipper") else 200
Expand Down Expand Up @@ -819,7 +828,8 @@ def _partition_pdf_or_image_with_ocr(
filename: str = "",
file: Optional[bytes | IO[bytes]] = None,
include_page_breaks: bool = False,
languages: Optional[list[str]] = ["eng"],
languages: Optional[list[str]] = None,
ocr_languages: Optional[str] = None,
is_image: bool = False,
metadata_last_modified: Optional[str] = None,
starting_page_number: int = 1,
Expand All @@ -838,6 +848,7 @@ def _partition_pdf_or_image_with_ocr(
page_elements = _partition_pdf_or_image_with_ocr_from_image(
image=image,
languages=languages,
ocr_languages=ocr_languages,
page_number=page_number,
include_page_breaks=include_page_breaks,
metadata_last_modified=metadata_last_modified,
Expand All @@ -851,6 +862,7 @@ def _partition_pdf_or_image_with_ocr(
page_elements = _partition_pdf_or_image_with_ocr_from_image(
image=image,
languages=languages,
ocr_languages=ocr_languages,
page_number=page_number,
include_page_breaks=include_page_breaks,
metadata_last_modified=metadata_last_modified,
Expand All @@ -864,6 +876,7 @@ def _partition_pdf_or_image_with_ocr(
def _partition_pdf_or_image_with_ocr_from_image(
image: PILImage.Image,
languages: Optional[list[str]] = None,
ocr_languages: Optional[str] = None,
page_number: int = 1,
include_page_breaks: bool = False,
metadata_last_modified: Optional[str] = None,
Expand All @@ -874,17 +887,13 @@ def _partition_pdf_or_image_with_ocr_from_image(

from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent

ocr_agent = OCRAgent.get_agent()
ocr_languages = prepare_languages_for_tesseract(languages)
ocr_agent = OCRAgent.get_agent(language=ocr_languages)

# NOTE(christine): `unstructured_pytesseract.image_to_string()` returns sorted text
if ocr_agent.is_text_sorted():
sort_mode = SORT_MODE_DONT

ocr_data = ocr_agent.get_layout_elements_from_image(
image=image,
ocr_languages=ocr_languages,
)
ocr_data = ocr_agent.get_layout_elements_from_image(image=image)

metadata = ElementMetadata(
last_modified=metadata_last_modified,
Expand Down
Loading
Loading