Skip to content

Commit

Permalink
Implement Adaptive RLSA (#154)
Browse files Browse the repository at this point in the history
* Update metrics and use RLSA to identify text

* Update column detection

* Add numba parallelization

* Simplify RLSA

* Update tests
  • Loading branch information
xavctn authored Dec 31, 2023
1 parent a05dc9e commit a688b8d
Show file tree
Hide file tree
Showing 16 changed files with 421 additions and 80 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ pyarrow>=7
numpy
pymupdf>=1.19.1
opencv-python
numba
beautifulsoup4
xlsxwriter>=3.0.6
1 change: 1 addition & 0 deletions src/img2table/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# coding: utf-8


class Validations:
def __post_init__(self):
"""Run validation methods if declared.
Expand Down
78 changes: 65 additions & 13 deletions src/img2table/tables/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,40 +4,92 @@
import cv2
import numpy as np
import polars as pl
from numba import njit, prange

from img2table.tables.objects.cell import Cell


@njit("List(int64)(int32[:,:],int32[:,:])", fastmath=True, cache=True, parallel=False)
def remove_dots(cc_labels: np.ndarray, stats: np.ndarray) -> List[int]:
    """
    Select the connected components that do not look like dots
    :param cc_labels: connected components' label array
    :param stats: connected components' stats array, one row per component unpacked as (x, y, w, h, area)
    :return: list of non-dot connected components' indexes
    """
    kept_indexes = list()

    # Component 0 is never examined, so iteration starts at 1
    # NOTE(review): the decorator sets parallel=False, so prange is presumably run as a plain
    # sequential range here - confirm against the numba documentation before enabling parallel mode
    for idx in prange(1, len(stats)):
        x, y, w, h, area = stats[idx][:]

        # A bounding box that is clearly not square is kept without further checks
        if max(w, h) / min(w, h) >= 1.5:
            kept_indexes.append(idx)
            continue

        # Count pixels lying between two consecutive pixels of the component,
        # scanning each row and then each column of its bounding box
        gap_pixels = 0

        # Row-wise scan
        for row in prange(y, y + h):
            last_hit = -1
            for col in range(x, x + w):
                if cc_labels[row][col] != idx:
                    continue
                if last_hit >= 0:
                    gap_pixels += col - last_hit - 1
                last_hit = col

        # Column-wise scan
        for col in prange(x, x + w):
            last_hit = -1
            for row in range(y, y + h):
                if cc_labels[row][col] != idx:
                    continue
                if last_hit >= 0:
                    gap_pixels += row - last_hit - 1
                last_hit = row

        # Square components with at most 5% of gap pixels (relative to twice the area,
        # as both scans contribute) are considered dots and dropped
        if gap_pixels / (2 * area) > 0.05:
            kept_indexes.append(idx)

    return kept_indexes


def compute_char_length(img: np.ndarray) -> Tuple[Optional[float], Optional[np.ndarray]]:
"""
Compute average character length based on connected components analysis
Compute average character length based on connected components' analysis
:param img: image array
:return: tuple with average character length and connected components array
"""
# Thresholding
_, thresh = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

# Connected components
_, _, stats, _ = cv2.connectedComponentsWithStats(thresh, 8, cv2.CV_32S)
_, cc_labels, stats, _ = cv2.connectedComponentsWithStats(thresh, 8, cv2.CV_32S)

# Remove connected components with less than 5 pixels
mask_pixels = stats[:, cv2.CC_STAT_AREA] > 5
# Remove dots
cc_to_keep = remove_dots(cc_labels=cc_labels, stats=stats)
stats = stats[cc_to_keep, :]

# Remove connected components with less than 10 pixels
mask_pixels = stats[:, cv2.CC_STAT_AREA] > 10
stats = stats[mask_pixels]

if len(stats) == 0:
return None, None

# Create mask to remove connected components corresponding to the complete image
mask_height = img.shape[0] > stats[:, cv2.CC_STAT_HEIGHT]
mask_width = img.shape[1] > stats[:, cv2.CC_STAT_WIDTH]
mask_img = mask_width & mask_height

# Filter components based on aspect ratio
mask_lower_ar = 0.5 < stats[:, cv2.CC_STAT_WIDTH] / stats[:, cv2.CC_STAT_HEIGHT]
mask_upper_ar = 2 > stats[:, cv2.CC_STAT_WIDTH] / stats[:, cv2.CC_STAT_HEIGHT]
mask_ar = mask_lower_ar & mask_upper_ar
mask_ar = (np.maximum(stats[:, cv2.CC_STAT_WIDTH], stats[:, cv2.CC_STAT_HEIGHT])
/ np.minimum(stats[:, cv2.CC_STAT_WIDTH], stats[:, cv2.CC_STAT_HEIGHT])) <= 2

stats = stats[mask_img & mask_ar]
# Filter components based on fill ratio
mask_fill = stats[:, cv2.CC_STAT_AREA] / (stats[:, cv2.CC_STAT_WIDTH] * stats[:, cv2.CC_STAT_HEIGHT]) > 0.08

stats = stats[mask_ar & mask_fill]

if len(stats) == 0:
return None, None

# Compute median width and height
median_width = np.median(stats[:, cv2.CC_STAT_WIDTH])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,12 @@ def get_coherent_whitespace_position(ws: Cell, elements: List[Cell]) -> Cell:


def filter_coherent_delimiters(delimiters: List[Cell], elements: List[Cell]) -> List[Cell]:
"""
Keep only delimiters that add value to the group
:param delimiters: list of delimiters
:param elements: list of elements contained within the delimiter group
:return: filtered list of delimiters
"""
# Check delimiter coherency, i.e. whether each delimiter adds value
filtered_delims = list()
for delim in delimiters:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,21 @@
from img2table.tables.processing.borderless_tables.whitespaces import adjacent_whitespaces


@dataclass
class MatchingWS:
    """
    Whitespace matched within a table area, together with its placement in that area
    """
    # Whitespace cell that was matched
    ws: Cell
    # Index of the table area holding the whitespace, in the position-sorted list of table areas
    position: int
    # True when the whitespace starts at the top edge of its table area (ws.y1 == area.y1)
    top: bool
    # True when the whitespace ends at the bottom edge of its table area (ws.y2 == area.y2)
    bottom: bool


@dataclass
class VertWS:
x1: int
x2: int
y1: int
y2: int
whitespaces: List[Cell] = field(default_factory=lambda: [])
positions: List[int] = field(default_factory=lambda: [])

@property
def y1(self):
return min([ws.y1 for ws in self.whitespaces]) if self.whitespaces else 0

@property
def y2(self):
return max([ws.y2 for ws in self.whitespaces]) if self.whitespaces else 0

@property
def height(self):
Expand All @@ -36,19 +37,6 @@ def width(self):
def cell(self) -> Cell:
return Cell(x1=self.x1, y1=self.y1, x2=self.x2, y2=self.y2)

@property
def continuous(self):
if self.positions:
positions = sorted(self.positions)
return all([p2 - p1 <= 1 for p1, p2 in zip(positions, positions[1:])])
return False

def add_ws(self, whitespaces: List[Cell]):
self.whitespaces += whitespaces

def add_position(self, position: int):
self.positions.append(position)


def deduplicate_whitespaces(vertical_whitespaces: List[VertWS], elements: List[Cell]) -> List[VertWS]:
"""
Expand Down Expand Up @@ -108,26 +96,48 @@ def get_vertical_whitespaces(table_segment: TableSegment) -> Tuple[List[Cell], L
:param table_segment: TableSegment object
:return: tuple containing list of vertical whitespaces and list of unused whitespaces
"""
table_areas = sorted(table_segment.table_areas, key=lambda x: x.position)

# Identify all whitespaces x values
x_ws = sorted(set([ws.x1 for ws in table_segment.whitespaces] + [ws.x2 for ws in table_segment.whitespaces]))

# Get vertical whitespaces
vertical_ws = list()
for x_left, x_right in zip(x_ws, x_ws[1:]):
# Create a whitespace object
vert_ws = VertWS(x1=x_left, x2=x_right)

for tb_area in table_segment.table_areas:
rng_ws = list()
for id_area, tb_area in enumerate(table_areas):
# Get matching whitespaces
matching_ws = [ws for ws in tb_area.whitespaces if min(vert_ws.x2, ws.x2) - max(vert_ws.x1, ws.x1) > 0]
matching_ws = sorted([ws for ws in tb_area.whitespaces if min(x_right, ws.x2) - max(x_left, ws.x1) > 0],
key=lambda ws: ws.y1)

if matching_ws:
vert_ws.add_position(tb_area.position)
vert_ws.add_ws(matching_ws)

# If it is composed of continuous whitespaces, use them
if vert_ws.continuous:
vertical_ws.append(vert_ws)
for ws in matching_ws:
m_ws = MatchingWS(ws=ws,
position=id_area,
top=ws.y1 == tb_area.y1,
bottom=ws.y2 == tb_area.y2)
rng_ws.append(m_ws)

if rng_ws:
# Create cluster of coherent ws
seq = iter(rng_ws)
ws_clusters = [[next(seq)]]
for m_ws in seq:
prev_ws = ws_clusters[-1][-1]

# If consecutive ws do not match, create a new cluster
if m_ws.position - prev_ws.position > 1 or not (prev_ws.bottom and m_ws.top):
ws_clusters.append([])
ws_clusters[-1].append(m_ws)

for cl in ws_clusters:
# Compute vertical boundaries
y1 = table_areas[cl[0].position - 1].y2 if cl[0].top and cl[0].position > 0 else cl[0].ws.y1
y2 = table_areas[cl[-1].position + 1].y1 if cl[-1].bottom and cl[-1].position < len(table_areas) - 1 else cl[-1].ws.y2

vert_ws = VertWS(x1=x_left, x2=x_right, y1=y1, y2=y2,
whitespaces=[m_ws.ws for m_ws in cl])
vertical_ws.append(vert_ws)

# Filter whitespaces by height
max_height = max([ws.height for ws in vertical_ws])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from img2table.tables.objects.line import Line
from img2table.tables.processing.borderless_tables.layout.column_segments import segment_image_columns
from img2table.tables.processing.borderless_tables.layout.image_elements import get_image_elements
from img2table.tables.processing.borderless_tables.layout.rlsa import identify_text_mask
from img2table.tables.processing.borderless_tables.layout.table_segments import get_table_segments
from img2table.tables.processing.borderless_tables.model import TableSegment, ImageSegment


def segment_image(thresh: np.ndarray, lines: List[Line], char_length: float, median_line_sep: float) -> List[TableSegment]:
def segment_image(thresh: np.ndarray, lines: List[Line], char_length: float,
median_line_sep: float) -> List[TableSegment]:
"""
Segment image and its elements
:param thresh: thresholded image array
Expand All @@ -19,9 +21,13 @@ def segment_image(thresh: np.ndarray, lines: List[Line], char_length: float, med
:param median_line_sep: median line separation
:return: list of ImageSegment objects with corresponding elements
"""
# Identify text mask
text_thresh = identify_text_mask(thresh=thresh,
lines=lines,
char_length=char_length)

# Identify image elements
img_elements = get_image_elements(thresh=thresh,
lines=lines,
img_elements = get_image_elements(thresh=text_thresh,
char_length=char_length,
median_line_sep=median_line_sep)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,16 @@
import numpy as np

from img2table.tables.objects.cell import Cell
from img2table.tables.objects.line import Line


def get_image_elements(thresh: np.ndarray, lines: List[Line], char_length: float, median_line_sep: float) -> List[Cell]:
def get_image_elements(thresh: np.ndarray, char_length: float, median_line_sep: float) -> List[Cell]:
"""
Identify image elements
:param thresh: thresholded image array
:param lines: list of image rows
:param char_length: average character length
:param median_line_sep: median line separation
:return: list of image elements
"""
# Mask rows
for l in lines:
if l.horizontal and l.length >= 3 * char_length:
cv2.rectangle(thresh, (l.x1 - l.thickness, l.y1), (l.x2 + l.thickness, l.y2), (0, 0, 0), 3 * l.thickness)
elif l.vertical and l.length >= 2 * char_length:
cv2.rectangle(thresh, (l.x1, l.y1 - l.thickness), (l.x2, l.y2 + l.thickness), (0, 0, 0), 3 * l.thickness)

# Dilate to combine adjacent text contours
kernel = cv2.getStructuringElement(cv2.MORPH_RECT,
(max(int(char_length), 1), max(int(median_line_sep // 6), 1)))
Expand Down
Loading

0 comments on commit a688b8d

Please sign in to comment.