diff --git a/qurator/eynollah/cli.py b/qurator/eynollah/cli.py
index a422df9..833e904 100644
--- a/qurator/eynollah/cli.py
+++ b/qurator/eynollah/cli.py
@@ -139,6 +139,12 @@
     is_flag=True,
     help="if this parameter set to true, this tool would apply machine based reading order detection",
 )
+@click.option(
+    "--do_ocr",
+    "-ocr/-noocr",
+    is_flag=True,
+    help="if this parameter is set to true, this tool will try to do OCR",
+)
 @click.option(
     "--log-level",
     "-l",
@@ -167,6 +173,7 @@ def main(
     headers_off,
     light_version,
     reading_order_machine_based,
+    do_ocr,
     ignore_page_extraction,
     log_level
 ):
@@ -205,6 +212,7 @@ def main(
         light_version=light_version,
         ignore_page_extraction=ignore_page_extraction,
         reading_order_machine_based=reading_order_machine_based,
+        do_ocr=do_ocr,
     )
     eynollah.run()
     #pcgts = eynollah.run()
diff --git a/qurator/eynollah/eynollah.py b/qurator/eynollah/eynollah.py
index 5e06734..a505b0e 100644
--- a/qurator/eynollah/eynollah.py
+++ b/qurator/eynollah/eynollah.py
@@ -17,6 +17,16 @@ from ocrd_utils import getLogger
 import cv2
 import numpy as np
+from transformers import TrOCRProcessor
+from PIL import Image
+import torch
+from difflib import SequenceMatcher as sq
+from transformers import VisionEncoderDecoderModel
+from numba import cuda
+import copy
+from scipy.signal import find_peaks
+from scipy.ndimage import gaussian_filter1d
+
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 stderr = sys.stderr
 sys.stderr = open(os.devnull, "w")
@@ -166,6 +176,7 @@ def __init__(
         light_version=False,
         ignore_page_extraction=False,
         reading_order_machine_based=False,
+        do_ocr=False,
         override_dpi=None,
         logger=None,
         pcgts=None,
@@ -199,6 +210,7 @@
         self.headers_off = headers_off
         self.light_version = light_version
         self.ignore_page_extraction = ignore_page_extraction
+        self.ocr = do_ocr
         self.pcgts = pcgts
         if not dir_in:
             self.plotter = None if not enable_plotting else EynollahPlotter(
@@ -233,6 +245,9 @@
             self.model_textline_dir = dir_models + "/eynollah-textline_light_20210425"
         else:
             self.model_textline_dir = dir_models + "/eynollah-textline_20210425"
+        if self.ocr:
+            self.model_ocr_dir = dir_models + "/checkpoint-166692_printed_trocr"
+
         self.model_tables = dir_models + "/eynollah-tables_20210319"
 
         self.models = {}
@@ -251,6 +266,10 @@
             self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
             self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
             self.model_reading_order_machine = self.our_load_model(self.model_reading_order_machine_dir)
+            if self.ocr:
+                self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
+                self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+                self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")  # alternative: "microsoft/trocr-base-printed"
 
             self.ls_imgs = os.listdir(self.dir_in)
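Review note: the objects loaded in `__init__` above follow the standard transformers TrOCR pattern (processor → `generate` → `batch_decode`). A minimal standalone sketch of that pattern; the model name is the one the patch uses for the processor, and the image path is a placeholder:

```python
# Minimal sketch of the TrOCR inference pattern this patch wires into __init__
# ("textline.png" is a placeholder, not part of the patch).
import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

line = Image.open("textline.png").convert("RGB")   # one cropped text line
pixel_values = processor(line, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values.to(device))
text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(text)
```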
@@ -3135,6 +3154,223 @@ def do_order_of_regions_with_machine_optimized_algorithm(self,contours_only_text
 
         return order_of_texts, id_of_texts
 
+    def return_start_and_end_of_common_text_of_textline_ocr(self, textline_image, ind_tot):
+        # find two split points inside a 20% window around the line centre,
+        # using peaks of the smoothed vertical projection profile
+        width = np.shape(textline_image)[1]
+        height = np.shape(textline_image)[0]
+        common_window = int(0.2 * width)
+
+        width1 = int(width / 2. - common_window)
+        width2 = int(width / 2. + common_window)
+
+        img_sum = np.sum(textline_image[:, :, 0], axis=0)
+        sum_smoothed = gaussian_filter1d(img_sum, 3)
+
+        peaks_real, _ = find_peaks(sum_smoothed, height=0)
+        if len(peaks_real) > 70:
+            peaks_real = peaks_real[(peaks_real < width2) & (peaks_real > width1)]
+
+            arg_sort = np.argsort(sum_smoothed[peaks_real])
+            arg_sort4 = arg_sort[::-1][:4]
+            peaks_sort_4 = peaks_real[arg_sort][::-1][:4]
+            argsort_sorted = np.argsort(peaks_sort_4)
+
+            first_4_sorted = peaks_sort_4[argsort_sorted]
+            y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]]
+
+            arg_sortnew = np.argsort(y_4_sorted)
+            peaks_final = np.sort(first_4_sorted[arg_sortnew][2:])
+            return peaks_final[0], peaks_final[1]
+        else:
+            return None
+
+    def return_start_and_end_of_common_text_of_textline_ocr_without_common_section(self, textline_image, ind_tot):
+        # find a single split point: the strongest projection peak inside a
+        # narrow 6% window around the line centre
+        width = np.shape(textline_image)[1]
+        height = np.shape(textline_image)[0]
+        common_window = int(0.06 * width)
+
+        width1 = int(width / 2. - common_window)
+        width2 = int(width / 2. + common_window)
+
+        img_sum = np.sum(textline_image[:, :, 0], axis=0)
+        sum_smoothed = gaussian_filter1d(img_sum, 3)
+
+        peaks_real, _ = find_peaks(sum_smoothed, height=0)
+        if len(peaks_real) > 70:
+            peaks_real = peaks_real[(peaks_real < width2) & (peaks_real > width1)]
+            arg_max = np.argmax(sum_smoothed[peaks_real])
+            peaks_final = peaks_real[arg_max]
+            return peaks_final
+        else:
+            return None
+
+    def return_start_and_end_of_common_text_of_textline_ocr_new_splitted(self, peaks_real, sum_smoothed, start_split, end_split):
+        peaks_real = peaks_real[(peaks_real < end_split) & (peaks_real > start_split)]
+
+        arg_sort = np.argsort(sum_smoothed[peaks_real])
+        arg_sort4 = arg_sort[::-1][:4]
+        peaks_sort_4 = peaks_real[arg_sort][::-1][:4]
+        argsort_sorted = np.argsort(peaks_sort_4)
+
+        first_4_sorted = peaks_sort_4[argsort_sorted]
+        y_4_sorted = sum_smoothed[peaks_real][arg_sort4[argsort_sorted]]
+
+        arg_sortnew = np.argsort(y_4_sorted)
+        peaks_final = np.sort(first_4_sorted[arg_sortnew][3:])
+        return peaks_final[0]
+
+    def return_start_and_end_of_common_text_of_textline_ocr_new(self, textline_image, ind_tot):
+        # find the start and end of the overlap region: one split point on
+        # each side of the line centre, inside a 15% window
+        width = np.shape(textline_image)[1]
+        height = np.shape(textline_image)[0]
+        common_window = int(0.15 * width)
+
+        width1 = int(width / 2. - common_window)
+        width2 = int(width / 2. + common_window)
+        mid = int(width / 2.)
+
+        img_sum = np.sum(textline_image[:, :, 0], axis=0)
+        sum_smoothed = gaussian_filter1d(img_sum, 3)
+
+        peaks_real, _ = find_peaks(sum_smoothed, height=0)
+        if len(peaks_real) > 70:
+            peak_start = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted(peaks_real, sum_smoothed, width1, mid + 2)
+            peak_end = self.return_start_and_end_of_common_text_of_textline_ocr_new_splitted(peaks_real, sum_smoothed, mid - 2, width2)
+            return peak_start, peak_end
+        else:
+            return None
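Review note: all four helpers above implement one idea. The column-wise sum of a line image (a vertical projection profile) peaks in the light gaps between glyphs, so a strong smoothed peak near the centre marks a safe place to cut the line. A self-contained sketch on synthetic data, reusing the 6% centre window of `..._without_common_section`:

```python
# Sketch of the projection-profile splitting idea; the input is synthetic.
import numpy as np
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d

textline = np.full((40, 600), 255, dtype=np.uint8)  # white line image
textline[10:30, 50:280] = 0                         # fake word 1
textline[10:30, 320:550] = 0                        # fake word 2

img_sum = np.sum(textline, axis=0)                  # vertical projection
sum_smoothed = gaussian_filter1d(img_sum, 3)
peaks, _ = find_peaks(sum_smoothed, height=0)

width = textline.shape[1]
window = int(0.06 * width)                          # same 6% window as the patch
mid_peaks = peaks[(peaks > width // 2 - window) & (peaks < width // 2 + window)]
split = mid_peaks[np.argmax(sum_smoothed[mid_peaks])]  # brightest gap near centre
print(split)  # ~300, the inter-word gap
```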
+    def return_ocr_of_textline_without_common_section(self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio, ind_tot):
+        if h2w_ratio > 0.05:
+            # tall enough relative to its width: recognize the line in one pass
+            pixel_values = processor(textline_image, return_tensors="pt").pixel_values
+            generated_ids = model_ocr.generate(pixel_values.to(device))
+            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        else:
+            # very wide line: split at a gap near the middle and recognize
+            # both halves in one batch
+            split_point = self.return_start_and_end_of_common_text_of_textline_ocr_without_common_section(textline_image, ind_tot)
+            if split_point:
+                image1 = textline_image[:, :split_point, :]
+                image2 = textline_image[:, split_point:, :]
+
+                pixel_values_merged = processor([image1, image2], return_tensors="pt").pixel_values
+                generated_ids_merged = model_ocr.generate(pixel_values_merged.to(device))
+                generated_text_merged = processor.batch_decode(generated_ids_merged, skip_special_tokens=True)
+
+                generated_text = generated_text_merged[0] + ' ' + generated_text_merged[1]
+            else:
+                pixel_values = processor(textline_image, return_tensors="pt").pixel_values
+                generated_ids = model_ocr.generate(pixel_values.to(device))
+                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        return generated_text
+
+    def return_ocr_of_textline(self, textline_image, model_ocr, processor, device, width_textline, h2w_ratio, ind_tot):
+        if h2w_ratio > 0.05:
+            pixel_values = processor(textline_image, return_tensors="pt").pixel_values
+            generated_ids = model_ocr.generate(pixel_values.to(device))
+            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        else:
+            # very wide line: recognize two overlapping halves and stitch them
+            # at the longest common substring of the two transcriptions
+            try:
+                width1, width2 = self.return_start_and_end_of_common_text_of_textline_ocr_new(textline_image, ind_tot)
+
+                image1 = textline_image[:, :width2, :]
+                image2 = textline_image[:, width1:, :]
+
+                pixel_values1 = processor(image1, return_tensors="pt").pixel_values
+                pixel_values2 = processor(image2, return_tensors="pt").pixel_values
+
+                generated_ids1 = model_ocr.generate(pixel_values1.to(device))
+                generated_ids2 = model_ocr.generate(pixel_values2.to(device))
+
+                generated_text1 = processor.batch_decode(generated_ids1, skip_special_tokens=True)[0]
+                generated_text2 = processor.batch_decode(generated_ids2, skip_special_tokens=True)[0]
+
+                match = sq(None, generated_text1, generated_text2).find_longest_match(0, len(generated_text1), 0, len(generated_text2))
+                generated_text = generated_text1 + generated_text2[match.b + match.size:]
+            except:
+                pixel_values = processor(textline_image, return_tensors="pt").pixel_values
+                generated_ids = model_ocr.generate(pixel_values.to(device))
+                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        return generated_text
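Review note: the stitching step in `return_ocr_of_textline` relies on `difflib.SequenceMatcher`. Because the two halves share an overlap window, the duplicated middle shows up as the longest common match and is dropped from the second transcription. A toy illustration with made-up strings:

```python
# Overlap stitching as done in return_ocr_of_textline; strings are made up.
from difflib import SequenceMatcher

left = "the quick brown fox jum"
right = "fox jumps over the lazy dog"

match = SequenceMatcher(None, left, right).find_longest_match(
    0, len(left), 0, len(right))          # longest match: "fox jum"
merged = left + right[match.b + match.size:]
print(merged)  # "the quick brown fox jumps over the lazy dog"
```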
+
+    def return_textline_contour_with_added_box_coordinate(self, textline_contour, box_ind):
+        # shift a textline contour from box-local into page coordinates
+        textline_contour[:, 0] = textline_contour[:, 0] + box_ind[2]
+        textline_contour[:, 1] = textline_contour[:, 1] + box_ind[0]
+        return textline_contour
 
     def run(self):
         """
@@ -3398,6 +3634,7 @@
             if self.plotter:
                 self.plotter.write_images_into_directory(polygons_of_images, image_page)
             t_order = time.time()
+
             if self.full_layout:
 
                 if self.reading_order_machine_based:
@@ -3425,11 +3662,67 @@
                         contours_only_text_parent_d_ordered = list(np.array(contours_only_text_parent_d_ordered, dtype=object)[index_by_text_par_con])
                         order_text_new, id_of_texts_tot = self.do_order_of_regions(contours_only_text_parent_d_ordered, contours_only_text_parent_h, boxes_d, textline_mask_tot_d)
+
+                if self.ocr:
+                    # free the GPU from the TF session before loading the torch model
+                    device = cuda.get_current_device()
+                    device.reset()
+                    gc.collect()
+                    model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
+                    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+                    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
+                    torch.cuda.empty_cache()
+                    model_ocr.to(device)
+
+                    ind_tot = 0
+                    ocr_all_textlines = []
+                    for indexing, ind_poly_first in enumerate(all_found_textline_polygons):
+                        ocr_textline_in_textregion = []
+                        for indexing2, ind_poly in enumerate(ind_poly_first):
+                            if not (self.textline_light or self.curved_line):
+                                ind_poly = copy.deepcopy(ind_poly)
+                                box_ind = all_box_coord[indexing]
+                                ind_poly = self.return_textline_contour_with_added_box_coordinate(ind_poly, box_ind)
+                                ind_poly[ind_poly < 0] = 0
+                            x, y, w, h = cv2.boundingRect(ind_poly)
+                            h2w_ratio = h / float(w)
+                            mask_poly = np.zeros(image_page.shape)
+                            img_poly_on_img = np.copy(image_page)
+
+                            mask_poly = cv2.fillPoly(mask_poly, pts=[ind_poly], color=(1, 1, 1))
+                            if self.textline_light:
+                                mask_poly = cv2.dilate(mask_poly, KERNEL, iterations=1)
+
+                            # white out everything outside the textline polygon,
+                            # then crop its bounding box for recognition
+                            img_poly_on_img[:, :, 0][mask_poly[:, :, 0] == 0] = 255
+                            img_poly_on_img[:, :, 1][mask_poly[:, :, 0] == 0] = 255
+                            img_poly_on_img[:, :, 2][mask_poly[:, :, 0] == 0] = 255
+
+                            img_cropped = img_poly_on_img[y:y+h, x:x+w, :]
+                            text_ocr = self.return_ocr_of_textline_without_common_section(img_cropped, model_ocr, processor, device, w, h2w_ratio, ind_tot)
+                            ocr_textline_in_textregion.append(text_ocr)
+                            ind_tot = ind_tot + 1
+                        ocr_all_textlines.append(ocr_textline_in_textregion)
+                else:
+                    ocr_all_textlines = None
+
                 self.logger.info("detection of reading order took %.1fs", time.time() - t_order)
-                pcgts = self.writer.build_pagexml_no_full_layout(txt_con_org, page_coord, order_text_new, id_of_texts_tot, all_found_textline_polygons, all_box_coord, polygons_of_images, polygons_of_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_xml, contours_tables)
+                pcgts = self.writer.build_pagexml_no_full_layout(txt_con_org, page_coord, order_text_new, id_of_texts_tot, all_found_textline_polygons, all_box_coord, polygons_of_images, polygons_of_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_xml, contours_tables, ocr_all_textlines)
                 self.logger.info("Job done in %.1fs", time.time() - t0)
                 ##return pcgts
                 self.writer.write_pagexml(pcgts)
                 #self.logger.info("Job done in %.1fs", time.time() - t0)
+
         if self.dir_in:
             self.logger.info("All jobs done in %.1fs", time.time() - t0_tot)
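Review note: the `if self.ocr:` loop above isolates each detected line by rasterizing its polygon into a mask, whitening everything outside it, and cropping the bounding box before recognition. A condensed sketch with a stand-in polygon; note the three per-channel assignments in the patch could be collapsed into the single boolean-mask assignment used here:

```python
# Sketch of the per-line mask-and-crop step; "page.png" and the polygon
# are placeholders for a page image and one detected contour.
import cv2
import numpy as np

image_page = cv2.imread("page.png")
ind_poly = np.array([[10, 20], [200, 20], [200, 60], [10, 60]], dtype=np.int32)

ind_poly[ind_poly < 0] = 0
x, y, w, h = cv2.boundingRect(ind_poly)

mask_poly = np.zeros(image_page.shape)                # 1 inside polygon, 0 outside
mask_poly = cv2.fillPoly(mask_poly, pts=[ind_poly], color=(1, 1, 1))

img_poly_on_img = np.copy(image_page)
img_poly_on_img[mask_poly[:, :, 0] == 0] = 255        # whiten all channels at once

img_cropped = img_poly_on_img[y:y + h, x:x + w, :]    # input for the OCR model
```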
diff --git a/qurator/eynollah/writer.py b/qurator/eynollah/writer.py
index f537f65..c69be9b 100644
--- a/qurator/eynollah/writer.py
+++ b/qurator/eynollah/writer.py
@@ -2,7 +2,7 @@
 # pylint: disable=import-error
 from pathlib import Path
 import os.path
-
+import xml.etree.ElementTree as ET
 from .utils.xml import create_page_xml, xml_reading_order
 from .utils.counter import EynollahIdCounter
@@ -12,6 +12,7 @@
     CoordsType,
     PcGtsType,
     TextLineType,
+    TextEquivType,
     TextRegionType,
     ImageRegionType,
     TableRegionType,
@@ -93,11 +94,13 @@ def serialize_lines_in_marginal(self, marginal_region, all_found_textline_polygo
             points_co += ' '
         coords.set_points(points_co[:-1])
 
-    def serialize_lines_in_region(self, text_region, all_found_textline_polygons, region_idx, page_coord, all_box_coord, slopes, counter):
+    def serialize_lines_in_region(self, text_region, all_found_textline_polygons, region_idx, page_coord, all_box_coord, slopes, counter, ocr_all_textlines_textregion):
         self.logger.debug('enter serialize_lines_in_region')
         for j in range(len(all_found_textline_polygons[region_idx])):
             coords = CoordsType()
             textline = TextLineType(id=counter.next_line_id, Coords=coords)
+            if ocr_all_textlines_textregion:
+                textline.set_TextEquiv([TextEquivType(Unicode=ocr_all_textlines_textregion[j])])
             text_region.add_TextLine(textline)
             region_bboxes = all_box_coord[region_idx]
             points_co = ''
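Review note: on the writer side, the change boils down to attaching a `TextEquiv/Unicode` element to every serialized `TextLine`. A minimal sketch using the same generateDS types that `writer.py` imports (the module path `ocrd_models.ocrd_page` is an assumption based on the import hunk above):

```python
# Sketch of the PAGE-XML structure the writer change produces per line.
from ocrd_models.ocrd_page import TextLineType, TextEquivType, CoordsType

textline = TextLineType(id="l1", Coords=CoordsType(points="0,0 100,0 100,20 0,20"))
textline.set_TextEquiv([TextEquivType(Unicode="recognized line text")])
```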
@@ -140,7 +143,7 @@
         with open(out_fname, 'w') as f:
             f.write(to_xml(pcgts))
 
-    def build_pagexml_no_full_layout(self, found_polygons_text_region, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_box_coord, found_polygons_text_region_img, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, found_polygons_tables):
+    def build_pagexml_no_full_layout(self, found_polygons_text_region, page_coord, order_of_texts, id_of_texts, all_found_textline_polygons, all_box_coord, found_polygons_text_region_img, found_polygons_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_to_be_written_in_xml, found_polygons_tables, ocr_all_textlines):
         self.logger.debug('enter build_pagexml_no_full_layout')
 
         # create the file structure
@@ -159,7 +162,11 @@
                 Coords=CoordsType(points=self.calculate_polygon_coords(found_polygons_text_region[mm], page_coord)),
             )
             page.add_TextRegion(textregion)
-            self.serialize_lines_in_region(textregion, all_found_textline_polygons, mm, page_coord, all_box_coord, slopes, counter)
+            if ocr_all_textlines:
+                ocr_textlines = ocr_all_textlines[mm]
+            else:
+                ocr_textlines = None
+            self.serialize_lines_in_region(textregion, all_found_textline_polygons, mm, page_coord, all_box_coord, slopes, counter, ocr_textlines)
 
         for mm in range(len(found_polygons_marginals)):
             marginal = TextRegionType(id=counter.next_region_id, type_='marginalia',
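Review note: taken together, OCR is now opt-in end to end (CLI flag → constructor keyword → writer). A hypothetical smoke test driving the class directly; every constructor argument except `do_ocr` is an assumption about the existing `Eynollah` signature, not something this patch adds:

```python
# Hypothetical driver (argument names other than do_ocr are assumptions
# about the existing Eynollah constructor, not part of this patch).
from qurator.eynollah.eynollah import Eynollah

eynollah = Eynollah(
    "models_dir",               # dir_models: where the TrOCR checkpoint also lives
    image_filename="page.png",
    dir_out="output_dir",
    do_ocr=True,                # new keyword introduced by this patch
)
eynollah.run()
```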