-
Notifications
You must be signed in to change notification settings - Fork 2
/
openvpnzone.py
293 lines (262 loc) · 11.8 KB
/
openvpnzone.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
import os.path
import signal
import collections
from twisted.names import dns
from twisted.names.authority import FileAuthority
from twisted.names.client import Resolver
from twisted.internet import inotify
from twisted.internet import defer
from twisted.internet.task import deferLater
from twisted.python import filepath
from IPy import IP
def extract_zones_from_status_file(status_path):
    """Parse an OpenVPN status file and return the connected clients together
    with their IP addresses.

    The file is processed section by section. Lines of the
    ``OpenVPN CLIENT LIST`` section register clients, and lines of the
    ``ROUTING TABLE`` section attach single-host addresses to those clients.
    Lines after ``GLOBAL STATS`` are ignored. The header lines directly
    following each section title are skipped, and blank lines are ignored.

    :param str status_path: path of the OpenVPN status file.
    :returns: dict mapping every client name (first column of the client
        list) to a list of ``IPy.IP`` single-host addresses taken from the
        routing table.
    :raises ValueError: if a routing table entry references a client that is
        not part of the client list.
    """
    # section title -> parser mode (None: ignore the lines of that section)
    mode_changes = {
        'OpenVPN CLIENT LIST': 'clients',
        'ROUTING TABLE': 'routes',
        'GLOBAL STATS': None,
    }
    clients = {}
    with open(status_path, 'r') as status_file:
        mode = None
        skip_next_lines = 0
        for status_line in status_file:
            if skip_next_lines > 0:
                skip_next_lines -= 1
                continue
            status_line = status_line.strip()
            if not status_line:
                # Blank lines carry no data. Without this check they would
                # register an empty client name in the client section and
                # fail to unpack in the routing section.
                continue
            if status_line in mode_changes:
                mode = mode_changes[status_line]
                # Every section title is followed by one header line. The
                # client list has one more line to skip (presumably the
                # "Updated,<timestamp>" line of the status format - confirm).
                skip_next_lines = 2 if mode == 'clients' else 1
                continue
            if mode == 'clients':
                clients[status_line.split(',')[0]] = []
            elif mode == 'routes':
                address, client = status_line.split(',')[0:2]
                if '/' in address:  # subnet route, not a single host
                    continue
                try:
                    address = IP(address)
                except ValueError:  # unparsable entry, e.g. a cached route
                    continue
                if address.len() > 1:  # subnet, not a single host
                    continue
                if client not in clients:
                    raise ValueError('Error in status file')
                clients[client].append(address)
    return clients
class InMemoryAuthority(FileAuthority):
    """Authority for a single zone whose data is held in memory instead of
    being read from a zone file."""

    def __init__(self, data=None):
        # presumably FileAuthority.__init__ hands ``data`` to loadFile()
        # below - verify against the twisted version in use
        FileAuthority.__init__(self, data)

    def loadFile(self, data):
        """Install ``data`` if it is a ``(soa, records)`` pair; ignore
        anything else (for example None)."""
        if type(data) is not tuple or len(data) != 2:
            return
        soa, records = data
        self.setData(soa, records)

    def setData(self, soa, records):
        """Replace the zone data if it differs from the current data.

        :param soa: ``(zone name, twisted.names.dns.Record_SOA)`` pair for
            this zone. The SOA record must also be added to ``records`` by
            the caller!
        :param dict records: mapping of owner name to the list of records
            for that name.
        :returns: True if the data was replaced, False if it was unchanged.
        """
        unchanged = soa == self.soa or not self.changed(soa, records)
        if unchanged:
            return False
        if type(soa) is tuple:
            print('updated zone {0} to serial {1}'.format(soa[0], soa[1].serial))
        self.soa = soa
        self.records = records
        return True

    def changed(self, soa, records):
        """Return True when ``records`` differs from the stored records.

        SOA records count as equal when everything except the serial
        matches, so a bumped serial on its own is not a change. The ``soa``
        argument is accepted but not consulted.
        """
        current = self.records
        if current is None:  # no data has been stored yet
            return True
        if len(current) != len(records):
            return True
        for owner, old_list in current.items():
            if owner not in records:
                return True
            new_list = records[owner]
            if len(old_list) != len(new_list):
                return True
            for old_record in old_list:
                if not any(self._same_record(old_record, candidate)
                           for candidate in new_list):
                    return True
        return False

    @staticmethod
    def _same_record(old_record, new_record):
        """Return True if ``new_record`` equals ``old_record``, treating two
        SOA records as equal when all fields but the serial match."""
        if new_record == old_record:
            return True
        both_soa = (new_record.__class__ is dns.Record_SOA and
                    old_record.__class__ is dns.Record_SOA)
        if not both_soa:
            return False
        return all(getattr(new_record, field) == getattr(old_record, field)
                   for field in ('mname', 'rname', 'refresh', 'retry',
                                 'expire', 'minimum'))
# The three authorities belonging to one OpenVPN instance: the forward zone,
# the IPv4 reverse zone and the IPv6 reverse zone.
AuthorityTuple = collections.namedtuple(
    'AuthorityTuple', 'forward backward4 backward6')
class OpenVpnAuthorityHandler(list):
    """List of the DNS authorities served for all configured OpenVPN instances.

    For every instance the list contains its forward zone. It also contains
    the IPv4 and IPv6 reverse zones when ``subnet4`` / ``subnet6`` is
    configured. Zone data is built from each instance's status file on
    construction. It is rebuilt when a status file changes (inotify) or
    when the process receives SIGUSR1.
    """

    def __init__(self, config):
        """
        :param config: configuration object whose ``instances`` mapping holds
            the OpenVPN instance settings (presumably keyed by instance name,
            since build_zone_from_clients looks authorities up by
            ``instance.name`` - confirm).
        """
        self.config = config
        # NOTIFY messages are suppressed until start_notify() is called
        self.send_notify = False
        # authorities for the data itself, one tuple per configured instance:
        self.authorities = {}
        for key, instance in self.config.instances.items():
            authority = AuthorityTuple(
                forward=InMemoryAuthority(),
                backward4=InMemoryAuthority(),
                backward6=InMemoryAuthority(),
            )
            self.authorities[key] = authority
            self.append(authority.forward)
            if instance.subnet4:
                self.append(authority.backward4)
            if instance.subnet6:
                self.append(authority.backward6)
        # load data:
        self.loadInstances()
        # watch for file changes:
        signal.signal(signal.SIGUSR1, self.handle_signal)
        notifier = inotify.INotify()
        notifier.startReading()
        for instance in self.config.instances.values():
            notifier.watch(filepath.FilePath(instance.status_file),
                           callbacks=[self.status_file_changed])
        print('Serving {0} zones: {1}'.format(
            len(self), ', '.join([z.soa[0].decode('utf-8') for z in self])))

    def loadInstances(self):
        """(Re)load the zone data of all configured instances."""
        for instance in self.config.instances.values():
            self.loadInstance(instance)

    def loadInstance(self, instance):
        """(Re)load the zone data of one instance from its status file."""
        clients = extract_zones_from_status_file(instance.status_file)
        self.build_zone_from_clients(instance, clients)

    @staticmethod
    def create_record_base(zone_name, soa, initial_data):
        """Create the initial record mapping of a zone.

        :param zone_name: zone name (str), or None if the zone is not served.
        :param soa: SOA record stored under the zone name.
        :param initial_data: iterable of ``(name, record)`` pairs. ``'@'``
            stands for the zone itself; names ending in ``'.'`` are used as
            they are; all other names are made relative to the zone.
        :returns: ``collections.defaultdict(list)`` keyed by UTF-8 encoded
            names. It is empty when ``zone_name`` is None.
        """
        records = collections.defaultdict(list)
        if zone_name is None:
            return records
        records[zone_name.encode('utf-8')].append(soa)
        for name, record in initial_data:
            if name == '@':
                name = zone_name
            elif not name.endswith('.'):
                name = name + '.' + zone_name
            records[name.encode('utf-8')].append(record)
        return records

    def build_zone_from_clients(self, instance, clients):
        """Build the zones of ``instance`` from ``clients`` and push them into
        the instance's authorities.

        The SOA fields come from the instance configuration; the serial is
        the modification time of the status file. For every client address,
        an A/AAAA record is added to the forward zone and a PTR record to the
        matching reverse zone. A NOTIFY is sent for every zone whose data
        changed.

        :param instance: OpenVPN instance configuration.
        :param dict clients: client name -> list of ``IPy.IP`` addresses, as
            returned by extract_zones_from_status_file().
        """
        soa = dns.Record_SOA(
            mname=instance.mname,
            rname=instance.rname,
            serial=int(os.path.getmtime(instance.status_file)),
            refresh=instance.refresh,
            retry=instance.retry,
            expire=instance.expire,
            minimum=instance.minimum,
        )
        forward_records = self.create_record_base(
            instance.name, soa, instance.forward_records)
        backward4_records = self.create_record_base(
            instance.subnet4, soa, instance.backward4_records)
        backward6_records = self.create_record_base(
            instance.subnet6, soa, instance.backward6_records)
        for client, addresses in clients.items():
            if instance.suffix is not None:
                if instance.suffix == '@':
                    client += '.' + instance.name
                else:
                    client += '.' + instance.suffix
            client = client.lower().encode('utf-8')
            for address in addresses:
                # reverse name without the trailing dot
                reverse = IP(address).reverseName()[:-1].encode('utf-8')
                if address.version() == 4:
                    forward_records[client].append(
                        dns.Record_A(str(address)))
                    backward4_records[reverse].append(dns.Record_PTR(client))
                elif address.version() == 6:
                    forward_records[client].append(
                        dns.Record_AAAA(str(address)))
                    backward6_records[reverse].append(dns.Record_PTR(client))
        # push data to authorities:
        authority = self.authorities[instance.name]
        if authority.forward.setData((instance.name.encode('utf-8'), soa),
                                     forward_records):
            self.notify(instance, instance.name.encode('utf-8'))
        if instance.subnet4:
            if authority.backward4.setData(
                    (instance.subnet4.encode('utf-8'), soa), backward4_records):
                self.notify(instance, instance.subnet4.encode('utf-8'))
        if instance.subnet6:
            if authority.backward6.setData(
                    (instance.subnet6.encode('utf-8'), soa), backward6_records):
                self.notify(instance, instance.subnet6.encode('utf-8'))

    def handle_signal(self, a, b):
        """SIGUSR1 handler: schedule a reload of all instances.

        The reload is handed to the reactor instead of running inside the
        signal handler. That way it cannot interleave with whatever the
        reactor was executing when the signal arrived.

        :param a: signal number (unused).
        :param b: interrupted stack frame (unused).
        """
        from twisted.internet import reactor
        reactor.callFromThread(self.loadInstances)

    def status_file_changed(self, ignored, filepath, mask):
        """Callback of the twisted INotify module for status file changes.

        Looks up the instance that owns the changed file and schedules
        status_file_change_done one second later. Every change bumps
        ``instance.version``, so a burst of changes triggers only one reload.
        """
        instance = self._find_instance(filepath.path)
        if instance is None:
            print('unknown status file: {0}'.format(filepath.path))
            return
        from twisted.internet import reactor
        instance.version += 1
        deferLater(reactor, 1, self.status_file_change_done,
                   instance, instance.version,
                   ','.join(inotify.humanReadableMask(mask)))

    def _find_instance(self, path):
        """Return the configured instance whose status file is ``path``, or None.

        Twisted's FilePath stores an absolute path, so each configured status
        file is absolutized before comparing. Otherwise instances configured
        with a relative path would never match.
        """
        for instance in self.config.instances.values():
            if os.path.abspath(instance.status_file) == path:
                return instance
        return None

    def status_file_change_done(self, instance, version, reason=None):
        """Callback for twisted deferLater, scheduled by status_file_changed.

        If the instance's version still equals ``version``, the file has not
        changed again within the last second, and the instance is reloaded.

        :param config.OpenVpnInstance instance: instance
        :param int version: version at the time the reload was scheduled
        :param str reason: textual reason for the reload
        """
        if instance.version > version:  # file was modified again
            return
        print('rereading instance {2}: {0} changed ({1}), '.format(
            instance.status_file, reason, instance.name))
        self.loadInstance(instance)

    def notify(self, instance, name):
        """Send a DNS NOTIFY for zone ``name`` to all notify servers of
        ``instance`` (no-op until start_notify() has been called)."""
        if self.send_notify is not True:
            return
        for server in instance.notify:
            print('Notify {0} new data for zone {1}'.format(server[0], name))
            resolver = NotifyResolver(servers=[server])
            result = resolver.sendNotify(name)
            # handle failures explicitly instead of leaving them as
            # unhandled errors in the Deferred
            result.addErrback(self._notify_failed, server, name)

    @staticmethod
    def _notify_failed(failure, server, name):
        """Report a NOTIFY that could not be sent or was not answered."""
        print('Notify {0} for zone {1} failed: {2}'.format(
            server[0], name, failure.getErrorMessage()))

    def start_notify(self):
        """Enable NOTIFY messages and announce every served zone once."""
        self.send_notify = True
        for instance in self.config.instances.values():
            self.notify(instance, instance.name.encode('utf-8'))
            if instance.subnet4:
                self.notify(instance, instance.subnet4.encode('utf-8'))
            if instance.subnet6:
                self.notify(instance, instance.subnet6.encode('utf-8'))
class NotifyResolver(Resolver):
    """Resolver that sends DNS NOTIFY messages (opcode NOTIFY) for a zone
    instead of ordinary queries."""

    def sendNotify(self, zone, timeout=10):
        """Send a NOTIFY for ``zone`` to the first configured server.

        :param bytes zone: name of the zone whose data changed.
        :param timeout: seconds to wait for an answer before giving up
            (defaults to 10, the previously hard-coded value).
        :returns: Deferred firing with the server's answer. It fails if the
            message could not be written or no answer arrived in time. Once
            it fires, the transport of the protocol used for this message
            stops listening, on success and failure alike.
        """
        protocol = self._connectedProtocol()
        message_id = protocol.pickID()
        message = dns.Message(message_id, opCode=dns.OP_NOTIFY)
        message.queries = [dns.Query(zone, dns.SOA, dns.IN)]
        try:
            protocol.writeMessage(message, self.servers[0])
        except Exception:
            # defer.fail() without arguments wraps the exception being handled
            result = defer.fail()
        else:
            result = defer.Deferred()
            # protocol._clearFailed is a private twisted helper; presumably it
            # fails ``result`` once ``timeout`` expires without an answer for
            # message_id - re-check on twisted upgrades
            cancel_call = protocol.callLater(
                timeout, protocol._clearFailed, result, message_id)
            protocol.liveMessages[message_id] = (result, cancel_call)

        def stop_listening(outcome):
            # also runs on the failure path, so the port is never leaked
            protocol.transport.stopListening()
            return outcome

        result.addBoth(stop_listening)
        return result