ocr engine first integration
vahidrezanezhad committed Jul 17, 2024
1 parent eac18c5 commit 5144668
Showing 3 changed files with 313 additions and 5 deletions.
8 changes: 8 additions & 0 deletions qurator/eynollah/cli.py
@@ -139,6 +139,12 @@
is_flag=True,
help="if this parameter set to true, this tool would apply machine based reading order detection",
)
@click.option(
"--do_ocr",
"-ocr/-noocr",
is_flag=True,
help="if this parameter set to true, this tool will try to do ocr",
)
@click.option(
"--log-level",
"-l",
@@ -167,6 +173,7 @@ def main(
headers_off,
light_version,
reading_order_machine_based,
do_ocr,
ignore_page_extraction,
log_level
):
@@ -205,6 +212,7 @@ def main(
light_version=light_version,
ignore_page_extraction=ignore_page_extraction,
reading_order_machine_based=reading_order_machine_based,
do_ocr=do_ocr,
)
eynollah.run()
#pcgts = eynollah.run()
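For orientation, here is a minimal sketch (not part of this commit) of what the new flag amounts to at the library level. Only do_ocr comes from this diff; the constructor arguments and paths around it are illustrative assumptions about the existing Eynollah API:

# Rough Python equivalent of running the CLI with -ocr (illustrative arguments):
from qurator.eynollah.eynollah import Eynollah

eynollah = Eynollah(
    "/path/to/models",            # assumed: models directory
    image_filename="page.tif",    # assumed: input image argument
    do_ocr=True,                  # the flag added by this commit
)
eynollah.run()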
295 changes: 294 additions & 1 deletion qurator/eynollah/eynollah.py
@@ -17,6 +17,16 @@
from ocrd_utils import getLogger
import cv2
import numpy as np
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch
from difflib import SequenceMatcher as sq
from numba import cuda
import copy
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
stderr = sys.stderr
sys.stderr = open(os.devnull, "w")
@@ -166,6 +176,7 @@ def __init__(
light_version=False,
ignore_page_extraction=False,
reading_order_machine_based=False,
do_ocr=False,
override_dpi=None,
logger=None,
pcgts=None,
@@ -199,6 +210,7 @@ def __init__(
self.headers_off = headers_off
self.light_version = light_version
self.ignore_page_extraction = ignore_page_extraction
self.ocr = do_ocr
self.pcgts = pcgts
if not dir_in:
self.plotter = None if not enable_plotting else EynollahPlotter(
@@ -233,6 +245,9 @@ def __init__(
self.model_textline_dir = dir_models + "/eynollah-textline_light_20210425"
else:
self.model_textline_dir = dir_models + "/eynollah-textline_20210425"
if self.ocr:
    self.model_ocr_dir = dir_models + "/checkpoint-166692_printed_trocr"

self.model_tables = dir_models + "/eynollah-tables_20210319"

self.models = {}
@@ -251,6 +266,10 @@ def __init__(
self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
self.model_reading_order_machine = self.our_load_model(self.model_reading_order_machine_dir)
if self.ocr:
    self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")  # alternative checkpoint: "microsoft/trocr-base-printed"

self.ls_imgs = os.listdir(self.dir_in)
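
The objects loaded above follow the standard transformers TrOCR inference pattern. As a minimal standalone sketch (not part of this commit; the image path is a placeholder):

from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

image = Image.open("textline.png").convert("RGB")   # one cropped textline
pixel_values = processor(image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values.to(device))
text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]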

@@ -3135,6 +3154,223 @@ def do_order_of_regions_with_machine_optimized_algorithm(self,contours_only_text


return order_of_texts, id_of_texts

def return_start_and_end_of_common_text_of_textline_ocr(self, textline_image, ind_tot):
    # Find two split points inside a central window of the textline using the
    # column-wise intensity profile: on a bright background, smoothed maxima
    # mark ink-free columns, i.e. good places to cut.
    width = np.shape(textline_image)[1]
    common_window = int(0.2 * width)

    width1 = int(width / 2. - common_window)
    width2 = int(width / 2. + common_window)

    img_sum = np.sum(textline_image[:, :, 0], axis=0)
    sum_smoothed = gaussian_filter1d(img_sum, 3)

    peaks_real, _ = find_peaks(sum_smoothed, height=0)

    if len(peaks_real) > 70:
        # Keep only the peaks inside the central window.
        peaks_real = peaks_real[(peaks_real < width2) & (peaks_real > width1)]

        arg_sort = np.argsort(sum_smoothed[peaks_real])
        arg_sort4 = arg_sort[::-1][:4]
        peaks_sort_4 = peaks_real[arg_sort][::-1][:4]
        argsort_sorted = np.argsort(peaks_sort_4)

        first_4_sorted = peaks_sort_4[argsort_sorted]
        y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]]

        # Of the four highest peaks, keep the two strongest, in x order.
        arg_sortnew = np.argsort(y_4_sorted)
        peaks_final = np.sort(first_4_sorted[arg_sortnew][2:])

        return peaks_final[0], peaks_final[1]
    else:
        return None

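# Illustrative sketch (not part of this commit): the projection-profile idea
# behind the method above, reduced to a standalone helper. Column sums of a
# bright-background line image peak where there is no ink, so the strongest
# smoothed peaks are candidate split columns. The helper name is ours; it
# relies on the numpy/scipy imports added at the top of this file.
def whitespace_peaks(textline_image):
    profile = np.sum(textline_image[:, :, 0], axis=0)  # column-wise intensity
    smoothed = gaussian_filter1d(profile, 3)           # suppress pixel noise
    peaks, _ = find_peaks(smoothed, height=0)          # all local maxima
    return peaks[np.argsort(smoothed[peaks])[::-1]]    # strongest gaps first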

def return_start_and_end_of_common_text_of_textline_ocr_without_common_section(self, textline_image, ind_tot):
    # Like the method above, but with a narrower central window, returning a
    # single split point: the strongest whitespace peak near the middle.
    width = np.shape(textline_image)[1]
    common_window = int(0.06 * width)

    width1 = int(width / 2. - common_window)
    width2 = int(width / 2. + common_window)

    img_sum = np.sum(textline_image[:, :, 0], axis=0)
    sum_smoothed = gaussian_filter1d(img_sum, 3)

    peaks_real, _ = find_peaks(sum_smoothed, height=0)

    if len(peaks_real) > 70:
        peaks_real = peaks_real[(peaks_real < width2) & (peaks_real > width1)]

        arg_max = np.argmax(sum_smoothed[peaks_real])
        peaks_final = peaks_real[arg_max]

        return peaks_final
    else:
        return None

def return_start_and_end_of_common_text_of_textline_ocr_new_splitted(self, peaks_real, sum_smoothed, start_split, end_split):
    # Restrict the peaks to (start_split, end_split) and return the position
    # of the single strongest of the four highest peaks in that range.
    peaks_real = peaks_real[(peaks_real < end_split) & (peaks_real > start_split)]

    arg_sort = np.argsort(sum_smoothed[peaks_real])
    arg_sort4 = arg_sort[::-1][:4]
    peaks_sort_4 = peaks_real[arg_sort][::-1][:4]
    argsort_sorted = np.argsort(peaks_sort_4)

    first_4_sorted = peaks_sort_4[argsort_sorted]
    y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]]

    arg_sortnew = np.argsort(y_4_sorted)
    peaks_final = np.sort(first_4_sorted[arg_sortnew][3:])
    return peaks_final[0]

def return_start_and_end_of_common_text_of_textline_ocr_new(self, textline_image, ind_tot):
    # Find one split point left of the line's middle and one right of it, so
    # that the two resulting halves share a small overlap around the center.
    width = np.shape(textline_image)[1]
    common_window = int(0.15 * width)

    width1 = int(width / 2. - common_window)
    width2 = int(width / 2. + common_window)
    mid = int(width / 2.)

    img_sum = np.sum(textline_image[:, :, 0], axis=0)
    sum_smoothed = gaussian_filter1d(img_sum, 3)

    peaks_real, _ = find_peaks(sum_smoothed, height=0)

    if len(peaks_real) > 70:
        peak_start = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted(peaks_real, sum_smoothed, width1, mid + 2)
        peak_end = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted(peaks_real, sum_smoothed, mid - 2, width2)

        return peak_start, peak_end
    else:
        return None

def return_ocr_of_textline_without_common_section(self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio, ind_tot):
    if h2w_ratio > 0.05:
        # The line is not extremely wide relative to its height: decode it in
        # a single pass.
        pixel_values = processor(textline_image, return_tensors="pt").pixel_values
        generated_ids = model_ocr.generate(pixel_values.to(device))
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        # Very wide line: cut it at a whitespace peak near the middle and
        # decode both halves in one batch.
        split_point = self.return_start_and_end_of_common_text_of_textline_ocr_without_common_section(textline_image, ind_tot)
        if split_point:
            image1 = textline_image[:, :split_point, :]
            image2 = textline_image[:, split_point:, :]

            pixel_values_merged = processor([image1, image2], return_tensors="pt").pixel_values
            generated_ids_merged = model_ocr.generate(pixel_values_merged.to(device))
            generated_text_merged = processor.batch_decode(generated_ids_merged, skip_special_tokens=True)

            generated_text = generated_text_merged[0] + ' ' + generated_text_merged[1]
        else:
            # No usable split point: fall back to a single pass.
            pixel_values = processor(textline_image, return_tensors="pt").pixel_values
            generated_ids = model_ocr.generate(pixel_values.to(device))
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return generated_text

def return_ocr_of_textline(self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio, ind_tot):
    if h2w_ratio > 0.05:
        pixel_values = processor(textline_image, return_tensors="pt").pixel_values
        generated_ids = model_ocr.generate(pixel_values.to(device))
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        try:
            # Split the wide line into two overlapping halves, decode each
            # half separately, then merge the transcriptions on the longest
            # common substring so the overlap is not duplicated.
            width1, width2 = self.return_start_and_end_of_common_text_of_textline_ocr_new(textline_image, ind_tot)

            image1 = textline_image[:, :width2, :]
            image2 = textline_image[:, width1:, :]

            pixel_values1 = processor(image1, return_tensors="pt").pixel_values
            pixel_values2 = processor(image2, return_tensors="pt").pixel_values

            generated_ids1 = model_ocr.generate(pixel_values1.to(device))
            generated_ids2 = model_ocr.generate(pixel_values2.to(device))

            generated_text1 = processor.batch_decode(generated_ids1, skip_special_tokens=True)[0]
            generated_text2 = processor.batch_decode(generated_ids2, skip_special_tokens=True)[0]

            match = sq(None, generated_text1, generated_text2).find_longest_match(0, len(generated_text1), 0, len(generated_text2))
            generated_text = generated_text1 + generated_text2[match.b + match.size:]
        except Exception:
            # Splitting failed (e.g. no split points were found): decode the
            # whole line in one pass instead.
            pixel_values = processor(textline_image, return_tensors="pt").pixel_values
            generated_ids = model_ocr.generate(pixel_values.to(device))
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return generated_text

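# Illustrative sketch (not part of this commit): the overlap merge performed
# in return_ocr_of_textline. The two halves share a region around the line's
# middle, so the longest common substring of the two transcriptions locates
# the duplicate text, which is then emitted only once.
left = "the quick brown fox jumps"
right = "fox jumps over the lazy dog"
m = sq(None, left, right).find_longest_match(0, len(left), 0, len(right))
assert left + right[m.b + m.size:] == "the quick brown fox jumps over the lazy dog"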

def return_textline_contour_with_added_box_coordinate(self, textline_contour, box_ind):
    # Shift the contour from box-local coordinates back to page coordinates.
    textline_contour[:, 0] = textline_contour[:, 0] + box_ind[2]
    textline_contour[:, 1] = textline_contour[:, 1] + box_ind[0]
    return textline_contour

def run(self):
"""
@@ -3398,6 +3634,7 @@ def run(self):
if self.plotter:
self.plotter.write_images_into_directory(polygons_of_images, image_page)
t_order = time.time()

if self.full_layout:

if self.reading_order_machine_based:
@@ -3425,11 +3662,67 @@
contours_only_text_parent_d_ordered = list(np.array(contours_only_text_parent_d_ordered, dtype=object)[index_by_text_par_con])
order_text_new, id_of_texts_tot = self.do_order_of_regions(contours_only_text_parent_d_ordered, contours_only_text_parent_h, boxes_d, textline_mask_tot_d)


if self.ocr:
    # Release the GPU from the TensorFlow models before loading the torch
    # OCR model, then load TrOCR onto whatever device is available.
    device = cuda.get_current_device()
    device.reset()
    gc.collect()
    model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
    torch.cuda.empty_cache()
    model_ocr.to(device)

    ind_tot = 0
    ocr_all_textlines = []
    for indexing, ind_poly_first in enumerate(all_found_textline_polygons):
        ocr_textline_in_textregion = []
        for indexing2, ind_poly in enumerate(ind_poly_first):
            if not (self.textline_light or self.curved_line):
                # These contours are box-local; shift them back to page
                # coordinates before cropping.
                ind_poly = copy.deepcopy(ind_poly)
                box_ind = all_box_coord[indexing]
                ind_poly = self.return_textline_contour_with_added_box_coordinate(ind_poly, box_ind)
                ind_poly[ind_poly < 0] = 0
            x, y, w, h = cv2.boundingRect(ind_poly)
            h2w_ratio = h / float(w)

            # White out everything outside the textline polygon, then crop
            # its bounding box for recognition.
            mask_poly = np.zeros(image_page.shape)
            img_poly_on_img = np.copy(image_page)
            mask_poly = cv2.fillPoly(mask_poly, pts=[ind_poly], color=(1, 1, 1))
            if self.textline_light:
                mask_poly = cv2.dilate(mask_poly, KERNEL, iterations=1)
            img_poly_on_img[mask_poly[:, :, 0] == 0] = 255

            img_cropped = img_poly_on_img[y:y + h, x:x + w, :]
            text_ocr = self.return_ocr_of_textline_without_common_section(img_cropped, model_ocr, processor, device, w, h2w_ratio, ind_tot)

            ocr_textline_in_textregion.append(text_ocr)
            ind_tot += 1
        ocr_all_textlines.append(ocr_textline_in_textregion)
else:
    ocr_all_textlines = None
self.logger.info("detection of reading order took %.1fs", time.time() - t_order)
pcgts = self.writer.build_pagexml_no_full_layout(txt_con_org, page_coord, order_text_new, id_of_texts_tot, all_found_textline_polygons, all_box_coord, polygons_of_images, polygons_of_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_xml, contours_tables)
pcgts = self.writer.build_pagexml_no_full_layout(txt_con_org, page_coord, order_text_new, id_of_texts_tot, all_found_textline_polygons, all_box_coord, polygons_of_images, polygons_of_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_xml, contours_tables, ocr_all_textlines)
self.logger.info("Job done in %.1fs", time.time() - t0)
##return pcgts
self.writer.write_pagexml(pcgts)
#self.logger.info("Job done in %.1fs", time.time() - t0)

if self.dir_in:
self.logger.info("All jobs done in %.1fs", time.time() - t0_tot)
