-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
87 lines (68 loc) · 2.53 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
from torch_geometric.data import InMemoryDataset, download_url, Data
import re
import numpy
import os
class MANETDataset(InMemoryDataset):
    """Single-graph PyG dataset built from the ``p2p-Gnutella05.txt`` edge list.

    ``process`` parses the edge list into ``edge_index``, attaches placeholder
    all-zero node features ``x``, and saves the resulting ``Data`` object
    as-is (not collated) to ``processed/data.pt``. ``__init__`` loads that
    object back, so ``dataset[0]`` is the whole graph.
    """

    # Node-feature rows created when the edge list references no higher id.
    DEFAULT_NUM_NODES = 8844

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data = self._load_processed(self.processed_paths[0])

    @staticmethod
    def _load_processed(path):
        """Load the ``Data`` object that ``process`` saved at *path*.

        Since PyTorch 2.6, ``torch.load`` defaults to ``weights_only=True``,
        which refuses to unpickle a ``Data`` object, so full unpickling is
        requested explicitly. Older PyTorch versions do not accept that
        keyword and raise ``TypeError``; in that case, retry without it.

        NOTE(security): full unpickling can execute code embedded in the file.
        Only load ``data.pt`` files written by this class.
        """
        try:
            return torch.load(path, weights_only=False)
        except TypeError:
            return torch.load(path)

    @property
    def raw_file_names(self):
        return 'p2p-Gnutella05.txt'

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        # No automatic download: the edge list must be placed on disk manually.
        pass

    def _raw_edge_list_path(self):
        """Return the edge-list path to parse.

        Prefers the PyG-conventional ``<root>/raw/<raw_file_names>``. Falls back
        to ``<root>/<raw_file_names>``, which is ``data/p2p-Gnutella05.txt``
        for ``root="data/"``.
        """
        conventional = self.raw_paths[0]
        if os.path.exists(conventional):
            return conventional
        return os.path.join(self.root, self.raw_file_names)

    def getEdgeIndex(self, path=None):
        """Parse the edge list into a ``(2, num_edges)`` ``torch.long`` tensor.

        Each data line starts with a source id and a target id separated by
        whitespace (e.g. ``"0\\t1"``). Blank lines and lines starting with
        ``#`` are skipped.

        Args:
            path: Edge-list file to read. Defaults to ``_raw_edge_list_path()``.

        Returns:
            A ``torch.long`` tensor. Row 0 holds source ids and row 1 holds
            target ids, in file order. An empty file gives shape ``(2, 0)``.

        Raises:
            ValueError: if a non-comment line does not start with two integers.
        """
        if path is None:
            path = self._raw_edge_list_path()
        sources, targets = [], []
        # Accumulate in Python lists and build one tensor at the end. Growing
        # an array per line copies the whole array every time (O(n^2)).
        with open(path, encoding='utf-8') as file:
            for line_no, line in enumerate(file, start=1):
                stripped = line.strip()
                if not stripped or stripped.startswith('#'):
                    continue
                fields = stripped.split()
                try:
                    src, dst = int(fields[0]), int(fields[1])
                except (IndexError, ValueError) as err:
                    raise ValueError(
                        f'{path}:{line_no}: expected "<src> <dst>", got {line!r}'
                    ) from err
                sources.append(src)
                targets.append(dst)
        return torch.tensor([sources, targets], dtype=torch.long)

    def getNodes(self, num_nodes=None):
        """Return a ``(num_nodes, 1)`` float64 tensor of zero node features.

        Args:
            num_nodes: Number of rows. Defaults to ``DEFAULT_NUM_NODES``.
        """
        if num_nodes is None:
            num_nodes = self.DEFAULT_NUM_NODES
        return torch.zeros((num_nodes, 1), dtype=torch.float64)

    def process(self):
        """Build the graph, apply ``pre_filter``/``pre_transform``, and save it.

        Raises:
            ValueError: if ``pre_filter`` rejects the graph. The dataset holds
                exactly one graph, so nothing would remain to save.
        """
        edge_index = self.getEdgeIndex()
        # Every id referenced by edge_index must index a row of x, so grow the
        # feature matrix if the file mentions ids beyond the default count.
        num_nodes = self.DEFAULT_NUM_NODES
        if edge_index.numel() > 0:
            num_nodes = max(num_nodes, int(edge_index.max()) + 1)
        data = Data(x=self.getNodes(num_nodes), edge_index=edge_index)

        # Apply the user hooks to the one real graph so they take effect.
        if self.pre_filter is not None and not self.pre_filter(data):
            raise ValueError('pre_filter rejected the only graph in MANETDataset')
        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.data = data
        torch.save(data, self.processed_paths[0])
if __name__ == '__main__':
    # Build (or load the cached) dataset, then print a short summary of the
    # graph. Guarded so that importing this module does no file I/O.
    dataset = MANETDataset(root="data/")
    data = dataset[0]
    for key, _ in data:
        print(f'{key} found in data')
    print(data.num_nodes)
    print(data.num_edges)