Skip to content

Commit

Permalink
passing number of columns as an argument
Browse files Browse the repository at this point in the history
  • Loading branch information
vahidrezanezhad committed Sep 12, 2024
1 parent 2c93904 commit 1b18ae8
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 22 deletions.
14 changes: 13 additions & 1 deletion qurator/eynollah/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,24 @@ def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, i
is_flag=True,
help="if this parameter set to true, this tool will try to do ocr",
)
@click.option(
"--num_col_upper",
"-ncu",
help="upper limit of columns in document image",
)
@click.option(
"--num_col_lower",
"-ncl",
help="lower limit of columns in document image",
)
@click.option(
"--log_level",
"-l",
type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']),
help="Override log level globally to this",
)

def layout(image, out, dir_in, model, save_images, save_layout, save_deskewed, save_all, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, ignore_page_extraction, log_level):
def layout(image, out, dir_in, model, save_images, save_layout, save_deskewed, save_all, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, num_col_upper, num_col_lower, ignore_page_extraction, log_level):
if log_level:
setOverrideLogLevel(log_level)
initLogging()
Expand Down Expand Up @@ -235,6 +245,8 @@ def layout(image, out, dir_in, model, save_images, save_layout, save_deskewed, s
ignore_page_extraction=ignore_page_extraction,
reading_order_machine_based=reading_order_machine_based,
do_ocr=do_ocr,
num_col_upper=num_col_upper,
num_col_lower=num_col_lower,
)
if dir_in:
eynollah.run()
Expand Down
96 changes: 75 additions & 21 deletions qurator/eynollah/eynollah.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def __init__(
ignore_page_extraction=False,
reading_order_machine_based=False,
do_ocr=False,
num_col_upper=None,
num_col_lower=None,
override_dpi=None,
logger=None,
pcgts=None,
Expand Down Expand Up @@ -212,6 +214,14 @@ def __init__(
self.headers_off = headers_off
self.ignore_page_extraction = ignore_page_extraction
self.ocr = do_ocr
if num_col_upper:
self.num_col_upper = int(num_col_upper)
else:
self.num_col_upper = num_col_upper
if num_col_lower:
self.num_col_lower = int(num_col_lower)
else:
self.num_col_lower = num_col_lower
self.pcgts = pcgts
if not dir_in:
self.plotter = None if not enable_plotting else EynollahPlotter(
Expand Down Expand Up @@ -597,36 +607,80 @@ def resize_and_enhance_image_with_column_classifier(self,light_version):
else:
img = self.imread()
img_bin = None


width_early = img.shape[1]
t1 = time.time()
_, page_coord = self.early_page_for_num_of_column_classification(img_bin)
if not self.dir_in:
model_num_classifier, session_col_classifier = self.start_new_session_and_model(self.model_dir_of_col_classifier)

if self.input_binary:
img_in = np.copy(img)
width_early = img_in.shape[1]
img_in = img_in / 255.0
img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST)
img_in = img_in.reshape(1, 448, 448, 3)
else:
img_1ch = self.imread(grayscale=True)
width_early = img_1ch.shape[1]
img_1ch = img_1ch[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]]
if self.num_col_upper and not self.num_col_lower:
num_col = self.num_col_upper
label_p_pred = [np.ones(6)]
elif self.num_col_lower and not self.num_col_upper:
num_col = self.num_col_lower
label_p_pred = [np.ones(6)]

elif (not self.num_col_upper and not self.num_col_lower):
if self.input_binary:
img_in = np.copy(img)
img_in = img_in / 255.0
img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST)
img_in = img_in.reshape(1, 448, 448, 3)
else:
img_1ch = self.imread(grayscale=True)
width_early = img_1ch.shape[1]
img_1ch = img_1ch[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]]

img_1ch = img_1ch / 255.0
img_1ch = cv2.resize(img_1ch, (448, 448), interpolation=cv2.INTER_NEAREST)
img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3))
img_in[0, :, :, 0] = img_1ch[:, :]
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]
img_1ch = img_1ch / 255.0
img_1ch = cv2.resize(img_1ch, (448, 448), interpolation=cv2.INTER_NEAREST)
img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3))
img_in[0, :, :, 0] = img_1ch[:, :]
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]


if self.dir_in:
label_p_pred = self.model_classifier.predict(img_in, verbose=0)
if self.dir_in:
label_p_pred = self.model_classifier.predict(img_in, verbose=0)
else:
label_p_pred = model_num_classifier.predict(img_in, verbose=0)
num_col = np.argmax(label_p_pred[0]) + 1
elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower):
if self.input_binary:
img_in = np.copy(img)
img_in = img_in / 255.0
img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST)
img_in = img_in.reshape(1, 448, 448, 3)
else:
img_1ch = self.imread(grayscale=True)
width_early = img_1ch.shape[1]
img_1ch = img_1ch[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]]

img_1ch = img_1ch / 255.0
img_1ch = cv2.resize(img_1ch, (448, 448), interpolation=cv2.INTER_NEAREST)
img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3))
img_in[0, :, :, 0] = img_1ch[:, :]
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]


if self.dir_in:
label_p_pred = self.model_classifier.predict(img_in, verbose=0)
else:
label_p_pred = model_num_classifier.predict(img_in, verbose=0)
num_col = np.argmax(label_p_pred[0]) + 1

if num_col > self.num_col_upper:
num_col = self.num_col_upper
label_p_pred = [np.ones(6)]
if num_col < self.num_col_lower:
num_col = self.num_col_lower
label_p_pred = [np.ones(6)]

else:
label_p_pred = model_num_classifier.predict(img_in, verbose=0)
num_col = np.argmax(label_p_pred[0]) + 1
num_col = self.num_col_upper
label_p_pred = [np.ones(6)]


self.logger.info("Found %d columns (%s)", num_col, np.around(label_p_pred, decimals=5))

Expand Down

0 comments on commit 1b18ae8

Please sign in to comment.