forked from worldveil/invisible-watermark
-
Notifications
You must be signed in to change notification settings - Fork 12
/
export_onnx.py
59 lines (49 loc) · 2.3 KB
/
export_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
import torch
import cv2
from rivagan import RivaGAN
# NOTE(review): Module.dump_patches is a legacy flag that torch.load consults
# when it unpickles a whole nn.Module whose class source differs from the
# source stored in the checkpoint (True -> write .patch files instead of only
# warning). Presumably set because rivaGan.pt stores the full RivaGAN object
# rather than a state_dict -- confirm it is still needed.
torch.nn.Module.dump_patches = True

IMAGE_PATH = 'test_vectors/original.jpg'
CHECKPOINT_PATH = 'rivaGan.pt'
ENCODER_ONNX_PATH = 'rivagan_encoder.onnx'
DECODER_ONNX_PATH = 'rivagan_decoder.onnx'
WATERMARK_BITS = 32
ONNX_OPSET = 10


def bgr_to_frame_tensor(bgr):
    """Convert one (H, W, C) BGR image array into the RivaGAN frame tensor.

    Pixel values are mapped with ``x / 127.5 - 1.0`` (so 8-bit input lands in
    [-1, 1]). The layout goes from (H, W, C) to (1, C, 1, H, W). The
    length-1 axis at index 2 comes from wrapping the single image in a list.
    It is presumably RivaGAN's video-frame axis -- confirm.

    Args:
        bgr: image array as returned by ``cv2.imread``.

    Returns:
        A float32 tensor of shape (1, C, 1, H, W).
    """
    stacked = np.array([bgr], dtype=np.float32)        # (1, H, W, C)
    scaled = torch.from_numpy(stacked) / 127.5 - 1.0
    return scaled.permute(3, 0, 1, 2).unsqueeze(0)     # (1, C, 1, H, W)


def random_watermark(bits=WATERMARK_BITS):
    """Return a (1, bits) float32 tensor of random 0/1 watermark bits.

    Uses NumPy's global RNG, so the bits differ between runs. In this script
    they are only a sample input for the ONNX export.
    """
    payload = np.random.randint(0, 2, bits)
    return torch.from_numpy(np.array([payload], dtype=np.float32))


def export_encoder(encoder, frame, data, path=ENCODER_ONNX_PATH):
    """Export the RivaGAN encoder to ONNX using (frame, data) as sample inputs.

    Batch, height and width of ``frame``/``output`` and the bit count of
    ``data`` are declared dynamic, so the graph is not pinned to the sample
    shapes.
    """
    torch.onnx.export(
        encoder,
        args=(frame, data),
        f=path,
        export_params=True,
        opset_version=ONNX_OPSET,
        do_constant_folding=True,
        input_names=['frame', 'data'],
        output_names=['output'],
        dynamic_axes={
            'frame': {0: 'batch_size', 3: 'height', 4: 'width'},
            'data': {0: 'batch_size', 1: 'wmBits'},
            'output': {0: 'batch_size', 3: 'height', 4: 'width'},
        },
    )


def export_decoder(decoder, frame, path=DECODER_ONNX_PATH):
    """Export the RivaGAN decoder to ONNX using ``frame`` as the sample input.

    Batch, height and width of ``frame``, and the batch and bit-count axes of
    ``output``, are declared dynamic.
    """
    torch.onnx.export(
        decoder,
        # Explicit one-element tuple: the bare ``(frame)`` is just ``frame``.
        args=(frame,),
        f=path,
        export_params=True,
        opset_version=ONNX_OPSET,
        do_constant_folding=True,
        input_names=['frame'],
        output_names=['output'],
        dynamic_axes={
            'frame': {0: 'batch_size', 3: 'height', 4: 'width'},
            'output': {0: 'batch_size', 1: 'wmBits'},
        },
    )


def main():
    """Load the RivaGAN checkpoint and write encoder/decoder ONNX files."""
    bgr = cv2.imread(IMAGE_PATH)
    if bgr is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError(f"could not read sample image {IMAGE_PATH!r}")
    data = random_watermark()
    frame = bgr_to_frame_tensor(bgr)
    # torch.load unpickles a full object here (it exposes .encoder/.decoder),
    # which can execute arbitrary code -- only load trusted checkpoints.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects
    # pickled objects; pass weights_only=False there for this trusted file.
    rivagan = torch.load(CHECKPOINT_PATH, map_location=torch.device('cpu'))
    export_encoder(rivagan.encoder, frame, data)
    # The decoder is exported using the same (unwatermarked) sample frame.
    export_decoder(rivagan.decoder, frame)


if __name__ == '__main__':
    main()