Skip to content

Commit

Permalink
now can handle mixed two-column and single column papers and single c…
Browse files Browse the repository at this point in the history
…olumn papers
  • Loading branch information
bozyurt committed Nov 15, 2023
1 parent ca9e6cc commit d342f83
Showing 1 changed file with 45 additions and 35 deletions.
80 changes: 45 additions & 35 deletions hocr2pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import spacy
import utils


class BBox(object):
def __init__(self, y0, x0, y1, x1, node):
self.y0 = y0
Expand All @@ -24,7 +25,7 @@ def __repr__(self):
def from_node(cls, node):
bbox_str = node.attrib['title']
tokens = bbox_str.split()
assert len(tokens) == 5 and tokens[0] == 'bbox'
assert len(tokens) == 5 and tokens[0] == 'bbox'
return cls(int(tokens[1]), int(tokens[2]), int(tokens[3]),
int(tokens[4]), node)

Expand All @@ -39,7 +40,7 @@ def add_member(self, member: BBox):
self.members.append(member)
for m in self.members:
if self.y0min > m.y0:
self.y0min = m.y0
self.y0min = m.y0
if self.x0min > m.x0:
self.x0min = m.x0
if self.x1max < m.x1:
Expand All @@ -48,7 +49,7 @@ def add_member(self, member: BBox):
self.y1max = m.y1

def __str__(self):
return '[%s %s %s %s]' % (self.y0min, self.x0min, self.y1max, self.x1max )
return '[%s %s %s %s]' % (self.y0min, self.x0min, self.y1max, self.x1max)

def __repr__(self):
return self.__str__()
Expand All @@ -57,7 +58,7 @@ def belongs(self, candidate: BBox):
if not self.members:
return True
# print('{} =? {}'.format( self.members[0].y0, candidate.y0))
adiff = abs(self.members[0].y0 - candidate.y0)
adiff = abs(self.members[0].y0 - candidate.y0)
return adiff <= 1

def get_y0(self):
Expand All @@ -69,7 +70,6 @@ def size(self):
def area(self):
return (self.y1max - self.y0min) * (self.x1max - self.x0min)


def grow2(self, clusters, ymin, ymax):
grown = False
for c in clusters:
Expand All @@ -78,8 +78,7 @@ def grow2(self, clusters, ymin, ymax):
self.add_member(m)
grown = True
if grown:
self.members.sort(key= lambda x: x.x0)

self.members.sort(key=lambda x: x.x0)

def grow(self, clusters):
grown = False
Expand All @@ -95,12 +94,11 @@ def grow(self, clusters):
def is_inside(self, candidate: BBox):
c = candidate
x0, y0, x1, y1 = c.x0, c.y0, c.x1, c.y1
xmid = x0 + (x1 - x0)/2
ymid = y0 + (y1 - y0)/2
#return self.y0min <= y0 and self.y1max >= y1 and self.x0min <= x0 and self.x1max >= x1
xmid = x0 + (x1 - x0) / 2
ymid = y0 + (y1 - y0) / 2
# return self.y0min <= y0 and self.y1max >= y1 and self.x0min <= x0 and self.x1max >= x1
return self.y0min <= ymid and self.y1max >= ymid and self.x0min <= xmid and self.x1max >= xmid


def get_text(self, nlp=None):
lines = []
for m in self.members:
Expand All @@ -110,10 +108,10 @@ def get_text(self, nlp=None):
for i, line in enumerate(lines):
if len(line) == 0:
continue
next_tok = lines[i+1][0] if i+1 < num_lines else None
next_tok = lines[i + 1][0] if i + 1 < num_lines else None
if line[-1].endswith('-') and next_tok and str(next_tok[0]).islower():
line[-1] = line[-1][:-1] + lines[i+1][0]
del lines[i+1][0]
line[-1] = line[-1][:-1] + lines[i + 1][0]
del lines[i + 1][0]

if nlp:
lines = clean_figure_text(lines, nlp, from_top=True)
Expand All @@ -127,14 +125,14 @@ def get_text(self, nlp=None):


def sanitize(content):
content = content.replace("\uFB02 ",'fl')
content = content.replace("\uFB01 ",'fi')
content = content.replace("\uFB02",'fl')
content = content.replace("\uFB01 ",'fi')
content = content.replace("\uFB02 ", 'fl')
content = content.replace("\uFB01 ", 'fi')
content = content.replace("\uFB02", 'fl')
content = content.replace("\uFB01 ", 'fi')
return content


def clean_figure_text(lines, nlp, from_top=True):
def clean_figure_text(lines, nlp, from_top=True):
if len(lines) == 0:
return lines
removed = []
Expand All @@ -147,7 +145,7 @@ def clean_figure_text(lines, nlp, from_top=True):
break
i += 1
else:
i = len(lines) -1
i = len(lines) - 1
while i >= 0:
if utils.is_figure_text(lines[i], nlp):
removed.append(i)
Expand Down Expand Up @@ -199,7 +197,7 @@ def is_figure_caption(node):


def collect_all_text(node, lines):
if node.tag == 'div':
if node.tag == 'div':
for child in node:
collect_text(child, lines)
elif node.tag == 'span':
Expand All @@ -212,11 +210,11 @@ def collect_all_text(node, lines):


def collect_text(node, lines):
if node.tag == 'div':
if is_eligible(node):
for child in node:
collect_text(child, lines)
elif is_figure_caption(node):
if node.tag == 'div':
if is_eligible(node):
for child in node:
collect_text(child, lines)
elif is_figure_caption(node):
lines.append(['FIGURE_CAPTION'])
elif node.tag == 'span':
if node.attrib['class'] == 'ocrx_line':
Expand All @@ -231,7 +229,7 @@ def cluster_bboxes(bbox_list, top_el=None, nlp=None):
clusters = []
for bbox in bbox_list:
closest = None
for c in clusters:
for c in clusters:
if c.belongs(bbox):
closest = c
break
Expand All @@ -243,10 +241,10 @@ def cluster_bboxes(bbox_list, top_el=None, nlp=None):
clusters.append(c)
for c in clusters:
print("Cluster {} - members:{} {}".format(c.label, len(c.members), c))
print('-'*80)
print('-' * 80)
ct = find_columns(clusters)
if ct:
ymid = ct[0].y1max + (ct[1].y0min - ct[0].y1max)/2
ymid = ct[0].y1max + (ct[1].y0min - ct[0].y1max) / 2
ymax = ct[1].y1max + 50
clist = list(clusters)
clist.remove(ct[0])
Expand All @@ -268,8 +266,22 @@ def cluster_bboxes(bbox_list, top_el=None, nlp=None):
if top_el is not None:
page_el = SubElement(top_el, 'page')
page_el.text = content
else:
# assumption: a single column
clusters = []
for bbox in bbox_list:
c = Cluster(bbox.y0)
c.add_member(bbox)
clusters.append(c)

print('-'*80)
content = ''
for c in clusters:
content += c.get_text(nlp=nlp)
print(content)
if top_el is not None:
page_el = SubElement(top_el, 'page')
page_el.text = content
print('-' * 80)


def find_columns(clusters):
Expand All @@ -280,12 +292,11 @@ def find_columns(clusters):
clist.sort(reverse=True, key=lambda x: x.area())
# two column assumption
col1, col2 = clist[0], clist[1]
if col1.size() == 1 or col2.size() ==1 :
if col1.size() == 1 or col2.size() == 1:
return None
return (col1, col2) if col1.get_y0() < col2.get_y0() else (col2, col1)



def handle_page(node, num_cols=2, top_el=None, nlp=None):
bbox_list = []
for child in node:
Expand All @@ -302,7 +313,7 @@ def main():
parser.add_argument('-i', action='store', help="input HOCR html file", required=True)
parser.add_argument('-o', action='store', help="output text XML file", required=True)

args = parser.parse_args()
args = parser.parse_args()

hocr_file = args.i
out_xml_file = args.o
Expand All @@ -318,11 +329,10 @@ def main():
print("wrote file:", out_xml_file)



def test_driver():
tree = ET.parse('x.html')
for node in tree.findall('.//body/div'):
print (node.attrib)
print(node.attrib)
handle_page(node)


Expand Down

0 comments on commit d342f83

Please sign in to comment.