-
Notifications
You must be signed in to change notification settings - Fork 5
/
plot_utils.py
134 lines (118 loc) · 3.9 KB
/
plot_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import inspect
from PIL import Image
import matplotlib.pyplot as plt
import dill
from functools import wraps
import numpy as np
def packplot(plotf):
    """Decorate a plotting function so its call is embedded in the saved image.

    The wrapped function must accept a ``filename`` argument and write an
    image to that path. After it returns, the image is re-saved with a
    dill-serialized tuple ``(kwargs, source_code, function_name)`` as its
    ``exif`` payload. ``retrieve_data``, ``retrieve_code`` and
    ``retrieve_plot`` read that payload back.

    Positional arguments are mapped to parameter names (in signature order)
    and the function is then called with keywords only. The stored
    ``kwargs`` therefore describe the whole call.

    Raises:
        TypeError: too many positional arguments, an argument given both
            positionally and by keyword, or no ``filename`` argument.
    """
    @wraps(plotf)
    def decorated_plotf(*args, **kwargs):
        param_names = list(inspect.signature(plotf).parameters)
        if len(args) > len(param_names):
            raise TypeError(
                f"{plotf.__name__}() takes {len(param_names)} arguments "
                f"but {len(args)} positional arguments were given"
            )
        for name, value in zip(param_names, args):
            # Mirror Python's own call semantics instead of silently
            # overwriting a keyword with a positional value.
            if name in kwargs:
                raise TypeError(
                    f"{plotf.__name__}() got multiple values for argument '{name}'"
                )
            kwargs[name] = value
        # Explicit check rather than `assert`, which is stripped under -O.
        if 'filename' not in kwargs:
            raise TypeError(f"{plotf.__name__}() requires a 'filename' argument")
        filename = kwargs['filename']

        ret = plotf(**kwargs)

        plot_function_name = plotf.__name__
        try:
            plot_function_code = inspect.getsource(plotf)
        except OSError:
            # Best effort: leave the plain image in place if no source exists.
            print(f"Cannot find the code of function {plot_function_name}")
            return ret
        plot_function_code = clean_code(plot_function_code)

        # Serialize BEFORE touching the image file. An unpicklable argument
        # then raises here and the plot on disk is left intact.
        packed_bytes = dill.dumps((kwargs, plot_function_code, plot_function_name))

        # Copy the pixels and release the file handle, then overwrite in place.
        with Image.open(filename) as img:
            img_packed = img.copy()
        img_packed.save(filename, exif=packed_bytes)
        return ret
    return decorated_plotf
#def plot_random(x, y, filename, **kwargs):
# plot_args = locals()
# import dill
# import inspect
# from PIL import Image
#
# import matplotlib.pyplot as plt
# plt.plot(x, y)
# plt.show()
# plt.savefig(filename)
#
# try:
# plot_function_code = inspect.getsource(plot_random)
# plot_function_name = 'plot_random'
# img = Image.open(filename)
# packed_info = (plot_args, plot_function_code, plot_function_name)
# packed_bytes = dill.dumps(packed_info)
# img.save(filename, exif=packed_bytes)
# except OSError:
# return
# return
def clean_code(code_str):
    """Remove the indentation shared by every line of ``code_str``.

    ``packplot`` stores the output of ``inspect.getsource``, which keeps the
    original indentation of a nested function (for example one defined under
    ``if __name__ == '__main__':``). Dedenting makes the source exec-able at
    top level.

    Whitespace-only lines are ignored when computing the common margin and
    come back as empty lines. Blank lines inside a function body therefore no
    longer prevent dedenting. Tabs and spaces are not treated as equivalent.

    Args:
        code_str: source text, possibly indented.

    Returns:
        The dedented lines joined with ``'\\n'``, with no trailing newline.
        An empty input yields ``''``.
    """
    import textwrap

    return '\n'.join(textwrap.dedent(code_str).splitlines())
def retrieve_plot(filename):
    """Re-run the plotting function packed into ``filename`` by ``@packplot``.

    The stored source is exec'd and the function is called again with the
    stored keyword arguments. Its ``filename`` argument is redirected to a
    ``tmp-<hash>`` sibling of the original path, so the original image is
    not overwritten.

    SECURITY: ``dill.loads`` and ``exec`` run arbitrary code taken from the
    image. Only call this on files you trust.

    Returns:
        Whatever the regenerated plotting function returns.
    """
    import os

    with Image.open(filename) as img:
        img.load()
        packed_bytes = img.info['exif']
    # The original code always sliced off 6 bytes, assuming Pillow prefixes the
    # payload with the b'Exif\x00\x00' header. Strip it only when present.
    if packed_bytes.startswith(b'Exif\x00\x00'):
        packed_bytes = packed_bytes[6:]
    plot_args, plot_function_code, plot_function_name = dill.loads(packed_bytes)

    # Drop the leading decorator line (expected to be `@packplot`), as before.
    # Source that does not start with a decorator is kept intact instead of
    # losing its `def` line.
    first_line, sep, rest = plot_function_code.partition("\n")
    if sep and first_line.lstrip().startswith("@"):
        plot_function_code = rest

    # Prefix only the base name so paths with a directory stay valid.
    original_name = plot_args['filename']
    directory, basename = os.path.split(original_name)
    plot_args['filename'] = os.path.join(
        directory, f'tmp-{str(hash(original_name))}' + basename
    )

    # Exec into an explicit namespace. Relying on implicit function locals()
    # breaks on Python 3.13+ (PEP 667), where each exec sees a fresh snapshot.
    # Module globals are still used so the code can reach plt, np, etc.
    namespace = {}
    exec(plot_function_code, globals(), namespace)
    return namespace[plot_function_name](**plot_args)
def retrieve_data(filename):
    """Return the keyword arguments packed into ``filename`` by ``@packplot``.

    SECURITY: ``dill.loads`` can execute arbitrary code embedded in the image.
    Only call this on files you trust.

    Raises:
        KeyError: the image carries no ``exif`` payload.
    """
    # Context manager guarantees the file is closed even if load() fails.
    with Image.open(filename) as img:
        img.load()
        packed_bytes = img.info['exif']
    # Strip the 6-byte b'Exif\x00\x00' header only when it is present.
    if packed_bytes.startswith(b'Exif\x00\x00'):
        packed_bytes = packed_bytes[6:]
    plot_args, _code, _name = dill.loads(packed_bytes)
    return plot_args
def retrieve_code(filename):
    """Return the plotting-function source packed into ``filename`` by ``@packplot``.

    The returned text is the dedented source stored at pack time, including
    any leading decorator line.

    SECURITY: ``dill.loads`` can execute arbitrary code embedded in the image.
    Only call this on files you trust.

    Raises:
        KeyError: the image carries no ``exif`` payload.
    """
    # Context manager guarantees the file is closed even if load() fails.
    with Image.open(filename) as img:
        img.load()
        packed_bytes = img.info['exif']
    # Strip the 6-byte b'Exif\x00\x00' header only when it is present.
    if packed_bytes.startswith(b'Exif\x00\x00'):
        packed_bytes = packed_bytes[6:]
    _args, plot_function_code, _name = dill.loads(packed_bytes)
    return plot_function_code
if __name__ == '__main__':
    # Smoke test: draw a random line plot, pack it into test.png, then
    # regenerate it from the packed data and source.
    @packplot
    def plot_test_dec(x, y, filename):
        plt.plot(x, y)
        # Save BEFORE showing. A blocking plt.show() closes the figure when it
        # returns, so a later savefig would write (and pack) a blank image.
        plt.savefig(filename)
        plt.show()
        return 0

    x = np.random.rand(10)
    y = np.random.rand(10)
    plot_test_dec(x, y, 'test.png')
    retrieve_plot('test.png')