forked from NVlabs/ocropus3-ocrorot
-
Notifications
You must be signed in to change notification settings - Fork 1
/
ocrorot-pred
executable file
·114 lines (88 loc) · 2.73 KB
/
ocrorot-pred
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/python
import os
import os.path
import argparse
import ocrorot
import scipy.ndimage as ndi
from pylab import *
from torch import nn
from dlinputs import gopen, paths, utils, filters
# Search path for the model file. It uses the same colon-separated form as
# the default value and can be overridden with the MODELS environment variable.
model_path = os.environ.get(
    "MODELS", ".:/usr/local/share/ocrorot:/usr/share/ocrorot")
default_model = "rot-000003456-020897.pt"

# The first positional argument of ArgumentParser is `prog` (the program name
# shown in the usage line), not a description. The original passed a
# copy-pasted description there, so pass it via `description=` instead.
parser = argparse.ArgumentParser(
    description="Estimate page rotation and optionally write rotation-corrected images.")
parser.add_argument("-m", "--model", default=default_model,
                    help="model file to load, looked up along $MODELS (default: %(default)s)")
parser.add_argument("-b", "--batchsize", type=int, default=1,
                    help="batch size (not used by the default source/pipeline)")
parser.add_argument("-D", "--makesource", default=None,
                    help="Python script executed before processing; may redefine make_source()")
parser.add_argument("-P", "--makepipeline", default=None,
                    help="Python script executed before processing; may redefine make_pipeline()")
parser.add_argument("-i", "--invert", action="store_true",
                    help="invert normalized image intensities (1 - image)")
parser.add_argument("--display", type=int, default=0,
                    help="if > 0, enable interactive plotting mode")
parser.add_argument("input",
                    help="input shard specification (passed to gopen.sharditerator_once)")
parser.add_argument("output", nargs="?",
                    help="optional output sink for rotation-corrected images")
args = parser.parse_args()

# Plain-dict copy of the parsed options. It is not used below; presumably it is
# kept for the -D/-P hook scripts (NOTE(review): confirm before removing).
ARGS = dict(vars(args))

if args.display > 0:
    rc("image", cmap="gray")
    ion()
def make_source():
    """Open the input given on the command line and return its sample stream.

    The stream comes from dlinputs' ``gopen.sharditerator_once``. A script
    passed with -D/--makesource may replace this function.
    """
    input_spec = args.input
    stream = gopen.sharditerator_once(input_spec)
    return stream
def make_pipeline():
    """Build the default preprocessing applied to the sample stream.

    The pipeline is composed of two dlinputs filters:
    a rename that produces the ``input`` field from ``"bin.png png"``
    (NOTE(review): presumably the first of those keys present; confirm
    against dlinputs ``filters.rename``), followed by a map that
    normalizes ``input`` with ``fixdepth``. A script passed with
    -P/--makepipeline may replace this function.

    Returns:
        A callable that is applied to the sample stream as ``pipeline(source)``.
    """
    def fixdepth(image):
        """Return a new float grayscale image scaled to [0, 1].

        3-D images are averaged over their last axis. The caller's array is
        never modified. A constant image maps to all zeros rather than NaN.
        With --invert the result is ``1 - image``.

        Raises:
            ValueError: if *image* is neither 2-D nor 3-D.
        """
        image = np.asarray(image)
        if image.ndim not in (2, 3):
            raise ValueError(
                "expected a 2-D or 3-D image, got shape %s" % (image.shape,))
        # Work on a private float copy. In-place arithmetic on an integer decode
        # would truncate (or raise under numpy's casting rules), and on a float
        # input it would silently modify the sample held by the caller.
        image = image.astype(np.float64)
        if image.ndim == 3:
            image = np.mean(image, 2)
        lo = np.amin(image)
        span = np.amax(image) - lo
        image -= lo
        # Guard against 0/0 (NaN) on blank/constant pages.
        if span > 0:
            image /= span
        if args.invert:
            image = 1 - image
        return image

    return filters.compose(
        filters.rename(input="bin.png png"),
        filters.map(input=fixdepth))
def _run_hook(script):
    """Execute the user-supplied Python file *script* in this module's globals.

    Does nothing when *script* is empty or None. Because the file shares
    the module namespace, it can redefine make_source / make_pipeline
    before they are called later in this script. The file runs with full
    privileges, so only pass trusted paths.
    """
    if not script:
        return
    # Pass the module globals explicitly. Inside a function, a bare execfile()
    # would otherwise execute against this function's local scope. Omitting
    # `locals` makes it default to the same globals dict, matching a
    # top-level execfile().
    execfile(script, globals())


# -D/--makesource runs first, then -P/--makepipeline.
_run_hook(args.makesource)
_run_hook(args.makepipeline)
def pixels_to_batch(x):
    """Flatten a 4-D tensor of size (b, d, h, w) into (b*h*w, d): one row per pixel.

    The depth axis is moved last before flattening, so each output row
    holds the d values of a single (batch, row, column) position.
    NOTE(review): not called in this script; presumably it is kept so a
    saved model referencing it can be loaded — confirm.
    """
    batch, depth, height, width = x.size()
    channels_last = x.permute(0, 2, 3, 1).contiguous()
    return channels_last.view(batch * height * width, depth)
class PixelsToBatch(nn.Module):
    """Parameter-free nn.Module that applies ``pixels_to_batch`` to its input.

    NOTE(review): not instantiated in this script; presumably it must exist
    at module level so a pickled model referencing it can be loaded — confirm.
    """

    def forward(self, x):
        flattened = pixels_to_batch(x)
        return flattened
source = make_source()
pipeline = make_pipeline()
source = pipeline(source)
if args.output:
sink = gopen.open_sink(args.output)
mname = paths.find_file(model_path, args.model)
assert mname is not None, "model not found"
print "loading", mname
rot = ocrorot.RotationEstimator(mname)
print rot.model
def display_batch(image, output):
    """Plot the first item of *image* (left) and *output* (right) in the current figure.

    Either argument may be None, in which case its panel is skipped.
    Indexing ``[0, :, :, 0]`` requires 4-D arrays; presumably
    (batch, height, width, channel) — TODO confirm. Values are shown with
    a fixed [0, 1] range. NOTE(review): not called anywhere in this script.
    """
    clf()
    panels = ((121, image), (122, output))
    for position, batch in panels:
        if batch is None:
            continue
        subplot(position)
        imshow(batch[0, :, :, 0], vmin=0, vmax=1)
    draw()
    # Wait at most 1 ms for a click; presumably just to let the GUI event
    # loop process the redraw.
    ginput(1, 1e-3)
for i, sample in enumerate(source):
fname = sample["__key__"]
image = sample["input"]
output = rot.rotation(image)
print i, fname, output
if args.output:
if angle != 0.0:
corrected = ndi.rotate(image, -output, order=1)
else:
corrected = image
result = utils.metadict(sample, {
"__key__": fname,
"bin.png": corrected
})
sink.write(result)
if args.output:
sink.close()