-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_data.py
130 lines (107 loc) · 4.19 KB
/
generate_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import numpy as np
from random import choices, sample, choice
import torch
import torch.nn as nn
import torch.optim as optim
import pickle
import matplotlib.pyplot as plt
from mlxtend.plotting import plot_confusion_matrix
from PIL import Image
# Pool of noise variances used by the generators in this file: each noisy
# sample picks one entry with random.sample and uses it as the per-pixel
# variance of the zero-mean Gaussian noise (diagonal covariance).
variances = [0.25]
def gen_square(size, edge_len, add_noise=False):
    """Return a ``size`` x ``size`` float tensor containing one filled square.

    The square is ``edge_len`` pixels on a side. Its top-left corner is chosen
    with ``random.choices`` from positions that keep it fully inside the
    image. Covered pixels are 1.0; all others are 0.0.

    :param size: side length of the (square) output image.
    :param edge_len: side length of the square; must not exceed ``size``.
    :param add_noise: if True, add zero-mean Gaussian noise whose variance is
        drawn from the module-level ``variances`` list.
    :return: ``torch.float`` tensor of shape ``(size, size)``.
    """
    res = np.zeros((size, size))
    # (row, col) of the top-left corner; the range bound keeps it in-bounds.
    dist_from = choices(range(size - edge_len + 1), k=2)
    row, col = dist_from[0], dist_from[1]
    res[row:row + edge_len, col:col + edge_len] = 1
    if add_noise:
        # Sample noise only when it is used, so the clean path neither wastes
        # the work nor advances NumPy's global RNG.
        noise = np.random.multivariate_normal(
            np.full(size, 0),
            np.diag(np.full(size, sample(variances, 1))),
            size,
        )
        return torch.tensor(res + noise, dtype=torch.float)
    return torch.tensor(res, dtype=torch.float)
# a = np.array(gen_square(32,17, add_noise=True))
# print(a)
# im = Image.fromarray(a*255)
# im.show()
def gen_circle(size, radius, add_noise=False):
    """Return a ``size`` x ``size`` float tensor containing one filled disc.

    The centre (row, col) is drawn with ``random.choices`` from
    ``range(round(radius), round(size - radius + 1))`` on each axis. A pixel
    is set to 1.0 when its Euclidean distance to the centre is strictly less
    than ``radius``; all other pixels are 0.0.

    :param size: side length of the (square) output image.
    :param radius: disc radius in pixels (may be fractional).
    :param add_noise: if True, add zero-mean Gaussian noise whose variance is
        drawn from the module-level ``variances`` list.
    :return: ``torch.float`` tensor of shape ``(size, size)``.
    """
    res = np.zeros((size, size))
    center = choices(range(round(radius), round(size - radius + 1)), k=2)
    # Broadcastable row (size, 1) and column (1, size) index grids.
    rows, cols = np.ogrid[:size, :size]
    dist = np.sqrt((rows - center[0]) ** 2 + (cols - center[1]) ** 2)
    res[dist < radius] = 1
    if add_noise:
        # Sample noise only when it is used, so the clean path neither wastes
        # the work nor advances NumPy's global RNG.
        noise = np.random.multivariate_normal(
            np.full(size, 0),
            np.diag(np.full(size, sample(variances, 1))),
            size,
        )
        return torch.tensor(res + noise, dtype=torch.float)
    return torch.tensor(res, dtype=torch.float)
# a = np.array(gen_circle(32, 14, add_noise=True))
# print(a)
# im = Image.fromarray(a*255)
# im.show()
def gen_rectangle(size, edge_len1, edge_len2, add_noise=False):
    """Return a ``size`` x ``size`` float tensor containing one filled rectangle.

    The rectangle spans ``edge_len1`` rows and ``edge_len2`` columns. Its
    top-left corner is chosen with ``random.sample`` so the rectangle lies
    fully inside the image. Covered pixels are 1.0; all others are 0.0.
    Equal edge lengths only print a message; a square is still produced.

    :param size: side length of the (square) output image.
    :param edge_len1: rectangle height in rows; must not exceed ``size``.
    :param edge_len2: rectangle width in columns; must not exceed ``size``.
    :param add_noise: if True, add zero-mean Gaussian noise whose variance is
        drawn from the module-level ``variances`` list.
    :return: ``torch.float`` tensor of shape ``(size, size)``.
    """
    if edge_len1 == edge_len2:
        print("error, this is a square")
    res = np.zeros((size, size))
    dist_from1 = sample(range(size - edge_len1 + 1), 1)
    dist_from2 = sample(range(size - edge_len2 + 1), 1)
    row, col = dist_from1[0], dist_from2[0]
    res[row:row + edge_len1, col:col + edge_len2] = 1
    if add_noise:
        # Sample noise only when it is used, so the clean path neither wastes
        # the work nor advances NumPy's global RNG.
        noise = np.random.multivariate_normal(
            np.full(size, 0),
            np.diag(np.full(size, sample(variances, 1))),
            size,
        )
        return torch.tensor(res + noise, dtype=torch.float)
    return torch.tensor(res, dtype=torch.float)
def gen_triangle(size, base_len, add_noise=False):
    """Return a ``size`` x ``size`` float tensor containing one filled triangle.

    The base is ``base_len`` pixels wide and sits on row ``height``. Each row
    above it (toward row 0) is one pixel narrower on each side. The triangle
    has ``round(base_len / 2)`` rows in total. Position is chosen with
    ``random.choice`` so that every row, including the base, lies inside the
    image. Covered pixels are 1.0; all others are 0.0.

    :param size: side length of the (square) output image.
    :param base_len: width of the base row; must not exceed ``size``.
    :param add_noise: if True, add zero-mean Gaussian noise whose variance is
        drawn from the module-level ``variances`` list.
    :return: ``torch.float`` tensor of shape ``(size, size)``.
    """
    res = np.zeros((size, size))
    # Number of rows in the triangle. NOTE(review): round() uses banker's
    # rounding, so base_len == 1 gives 0 rows (an empty image).
    height_len = round(base_len / 2)
    # Row of the base. The apex row is height - height_len + 1, so height in
    # [height_len - 1, size - 1] keeps every row on the grid. An upper bound
    # of size + 1 would allow height == size and drop the base row.
    height = choice(range(height_len - 1, size))
    dist_from_left = choice(range(size - base_len + 1))
    for k in range(height_len):
        # k rows above the base, the row loses k pixels on each side.
        row = height - k
        lo = dist_from_left + k
        hi = dist_from_left + base_len - k
        if 0 <= row < size and lo < hi:
            res[row, lo:hi] = 1
    if add_noise:
        # Sample noise only when it is used, so the clean path neither wastes
        # the work nor advances NumPy's global RNG.
        noise = np.random.multivariate_normal(
            np.full(size, 0),
            np.diag(np.full(size, sample(variances, 1))),
            size,
        )
        return torch.tensor(res + noise, dtype=torch.float)
    return torch.tensor(res, dtype=torch.float)
#
# a = np.array(gen_triangle(10, 7, add_noise=True))
# print(a)
# im = Image.fromarray(a*255)
# im.show()
def gen_noise(size):
    """Return a ``size`` x ``size`` float tensor of pure Gaussian noise.

    One variance is picked from the module-level ``variances`` list with
    ``random.sample``. Every row is then an independent draw from a zero-mean
    multivariate normal whose covariance is that variance times the identity.

    :param size: side length of the square output.
    :return: ``torch.float`` tensor of shape ``(size, size)``.
    """
    picked = sample(variances, 1)
    covariance = np.diag(np.full(size, picked))
    mean = np.full(size, 0)
    draws = np.random.multivariate_normal(mean, covariance, size)
    return torch.tensor(draws, dtype=torch.float)
def dict_append(key, val, dictionary):
    """
    Append ``val`` to the list stored under ``key`` in ``dictionary``.

    The first time ``key`` is seen, a new single-element list is created.
    The dictionary is modified in place and nothing is returned.

    :param key: dictionary key to group under
    :param val: value to append to that key's list
    :param dictionary: mapping of keys to lists, mutated in place
    """
    dictionary.setdefault(key, []).append(val)
def multiplyList(myList):
    """Return the product of the items in ``myList``.

    Items are multiplied left to right starting from ``1``, so an empty
    iterable yields ``1``.
    """
    product = 1
    for factor in myList:
        product = product * factor
    return product
def sum_params(model):
    """Print and return the total number of elements in ``model.state_dict()``.

    Every entry with more than one element is printed as ``name count``.
    Single-element entries are counted but not printed. The grand total is
    printed last as ``sum :<total>``.

    NOTE: ``state_dict()`` also contains registered buffers (for example
    BatchNorm running statistics), so the total includes them as well as
    the trainable parameters.

    :param model: object exposing ``state_dict()`` whose values are tensors
        (typically a ``torch.nn.Module``)
    :return: the total element count
    """
    total = 0
    for name, tensor in model.state_dict().items():
        # numel() equals the product of the shape (1 for 0-d tensors).
        count = tensor.numel()
        if count != 1:
            print(name, count)
        total += count
    print(f"sum :{total}")
    return total