-
Notifications
You must be signed in to change notification settings - Fork 2
/
make_tlsa
executable file
·257 lines (233 loc) · 7.6 KB
/
make_tlsa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
#!/usr/bin/python
# -*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2016 Valentin Samir
# stdlib modules
import os
import sys
import ssl
import json
import base64
import hashlib
import urllib2
import binascii
import contextlib
import subprocess
# local modules
import config
# Overwritten by the command-line entry point when "--force" is given.
# While False, generate_domain() aborts as soon as one server is unreachable.
force = False

# SSHFP "algorithm" field value for each SSH key family, keyed by the
# normalised key-type name that genSSHFP() derives from the public key line.
sshfp_algo = {
    "rsa": 1,
    "dss": 2,
    "ecdsa": 3,
    "ed25519": 4,
}
def genSSHFP(hostname, cert):
    """Build the SSHFP resource record line for one SSH public key.

    :param hostname: owner name of the record; an empty string is written
        as ``@``.
    :param cert: a public key line made of whitespace separated fields,
        ``<key-type> <base64-blob> [comment]``.
    :returns: the record as a string, or ``None`` when the line is blank or
        truncated, or when the key type has no entry in ``sshfp_algo``.
    """
    fields = cert.split()
    if len(fields) < 2:
        # Blank or truncated key line: there is nothing to fingerprint.
        return None
    kind, blob = fields[0], fields[1]
    if hostname == "":
        hostname = "@"
    # Reduce the key type to the family name used as key in sshfp_algo:
    # "ssh-<family>" loses its prefix, every "ecdsa-*" variant maps to "ecdsa".
    if kind.startswith("ssh-"):
        kind = kind[4:]
    elif kind.startswith("ecdsa-"):
        kind = "ecdsa"
    if kind not in sshfp_algo:
        return None
    # The fingerprint type field is always 2 because the fingerprint below is
    # a SHA-256 digest of the decoded key blob.
    return '%s\tIN\tSSHFP %s %s %s' % (
        hostname,
        sshfp_algo[kind],
        2,
        hashlib.sha256(base64.b64decode(blob)).hexdigest()
    )
# reftype (TLSA matching type): 0 = plain cert, 1 = sha256, 2 = sha512
# certtype (TLSA certificate usage):
# 0 = CA pinning,
# 1 = cert pinning,
# 2 = self trusted CA,
# 3 = self trusted cert
# selector: 0 = full cert, 1 = SubjectPublicKeyInfo
def genTLSA(
    hostname, pemcert, port, proto, certtype=1,
    selector=0, reftype=1, compat=True
):
    """Build the TLSA resource record line for one certificate.

    :param hostname: name appended after ``_<port>._<proto>``; when empty
        the owner name is just ``_<port>._<proto>``.
    :param pemcert: the certificate, PEM encoded.
    :param port: port number of the service, must be an ``int``.
    :param proto: transport of the service, ``"tcp"`` or ``"udp"``.
    :param certtype: TLSA certificate usage; only 1 and 3 are supported.
    :param selector: TLSA selector; only 0 (full certificate) is supported.
    :param reftype: TLSA matching type, 0 (full data), 1 (SHA-256) or
        2 (SHA-512).
    :param compat: when true, emit the record in the generic ``TYPE52``
        syntax (hex RDATA prefixed by its length) instead of the ``TLSA``
        presentation syntax.
    :returns: the record as a string.
    :raises ValueError: on a bad ``proto``, ``port`` or ``reftype``, or
        when the PEM certificate cannot be decoded.
    :raises NotImplementedError: on an unsupported ``selector`` or
        ``certtype``.
    """
    if proto not in ['tcp', 'udp']:
        raise ValueError("proto should be tcp or udp, not %s" % proto)
    if not isinstance(port, int):
        raise ValueError("port should be int")
    if selector != 0:
        raise NotImplementedError("Selection must be 0")
    if certtype not in [1, 3]:
        raise NotImplementedError("Only EE-cert supported right now")
    if reftype not in [0, 1, 2]:
        raise ValueError("reftype should be 0, 1 or 2, not %r" % (reftype,))
    # PEM_cert_to_DER_cert raises ValueError itself on malformed PEM input;
    # the emptiness check below is only a safety net.
    dercert = ssl.PEM_cert_to_DER_cert(pemcert)
    if not dercert:
        raise ValueError("Fail to decode PEM cert to DER")
    if hostname:
        hostname = ".%s" % hostname
    certhex = hashCert(reftype, dercert)
    if compat:
        # The RDATA length is in octets: half of the hex string length, plus
        # 3 for the certtype, selector and reftype octets that precede it.
        return "_%s._%s%s\tIN\tTYPE52 \\# %s %02X%02X%02X%s" % (
            port,
            proto,
            hostname,
            len(certhex) // 2 + 3,
            certtype,
            selector,
            reftype,
            certhex
        )
    return "_%s._%s%s\tIN\tTLSA %s %s %s %s" % (
        port,
        proto,
        hostname,
        certtype,
        selector,
        reftype,
        certhex
    )
def hashCert(reftype, certblob):
    """Return the upper-case hex string to publish for a certificate blob.

    :param reftype: 0 to publish the blob itself, 1 for its SHA-256 digest,
        2 for its SHA-512 digest.
    :param certblob: the raw (binary) certificate data.
    :returns: an upper-case hexadecimal text string.
    :raises ValueError: when ``reftype`` is not 0, 1 or 2.
    """
    if reftype == 0:
        hexdata = binascii.b2a_hex(certblob)
        if not isinstance(hexdata, str):
            # b2a_hex returns bytes on Python 3; callers interpolate the
            # result into a text record, so hand back text everywhere.
            hexdata = hexdata.decode("ascii")
        return hexdata.upper()
    digests = {1: hashlib.sha256, 2: hashlib.sha512}
    if reftype not in digests:
        # Fail loudly instead of returning a sentinel that would end up
        # inside a DNS record.
        raise ValueError("reftype should be 0, 1 or 2, not %r" % (reftype,))
    return digests[reftype](certblob).hexdigest().upper()
def validate_services_definition(data, server):
    """Check the payload published by one server before using it.

    :param data: sequence whose item 0 maps each service name to a dict
        with ``"transport"`` and ``"ports"`` keys, and whose item 1 maps
        domain -> name -> dict holding a ``"services"`` list.
    :param server: server the payload came from; only used in messages.
    :raises ValueError: when no service is defined, when a service lacks a
        transport or ports, has a transport other than tcp/udp, lists a
        port that is not an integer, or when a certificate refers to a
        service that is not defined.
    """
    conf = data[0]
    certs = data[1]
    if not conf:
        raise ValueError("No services definition on %s" % server)
    for service, params in conf.items():
        if "transport" not in params or "ports" not in params:
            raise ValueError(
                "%s section has no transport or no ports defined on %s" % (
                    service,
                    server
                )
            )
        if params["transport"] not in ["tcp", "udp"]:
            raise ValueError(
                "transport must be 'tcp' or 'udp' in %s section on %s" % (
                    service,
                    server
                )
            )
        for port in params["ports"]:
            # bool is a subclass of int, so a JSON true/false would
            # otherwise be accepted as a port number.
            if isinstance(port, bool) or not isinstance(port, int):
                raise ValueError(
                    "ports should be integer in %s section on %s" % (
                        service,
                        server
                    )
                )
    certs_services = set()
    for names in certs.values():
        for params in names.values():
            certs_services.update(params["services"])
    services_not_defined = certs_services.difference(conf)
    if services_not_defined:
        # Sorted so the message is the same from one run to the next.
        raise ValueError(
            "Services %s not defined on %s" % (
                ", ".join(sorted(services_not_defined)),
                server
            )
        )
def _generate_domain(data, domain, done=None):
    """Return the new TLSA and SSHFP records one payload yields for a domain.

    :param data: payload of one server: item 0 holds the service
        definitions, item 1 the certificates per domain and item 2 the SSH
        keys per domain.
    :param domain: the domain whose records are wanted.
    :param done: set of records already produced; it is updated in place so
        that several payloads never contribute the same record twice. A
        fresh set is used when omitted.
    :returns: list of record strings that were not yet in ``done``, in the
        order they were generated.
    """
    seen = set() if done is None else done
    records = []

    def remember(record):
        # Keep a record only the first time it shows up.
        if record not in seen:
            records.append(record)
            seen.add(record)

    service_defs = data[0]
    for name, entry in data[1].get(domain, {}).items():
        for service_name in entry["services"]:
            service_def = service_defs[service_name]
            for port_number in service_def["ports"]:
                for pem in entry["certs"]:
                    for matching_type in config.DNS_TLSA_REFTYPE:
                        remember(genTLSA(
                            name + '.',
                            pem,
                            port_number,
                            service_def["transport"],
                            certtype=3,
                            selector=0,
                            reftype=matching_type,
                            compat=config.DNS_TLSA_COMPAT
                        ))

    for name, keys in data[2].get(domain, {}).items():
        for key_line in keys:
            record = genSSHFP(name + '.', key_line)
            # genSSHFP yields nothing for key types it does not know.
            if record:
                remember(record)
    return records
def generate_domain(domain):
    """Collect the records of a domain from every configured server.

    Each server of ``config.SERVERS`` is queried over HTTP on
    ``config.PORT``; its JSON payload is validated and turned into records.
    A server that cannot be reached ends the program with status 1, unless
    the module-level ``force`` flag is set, in which case it is skipped.

    :param domain: the domain whose records are wanted.
    :returns: the sorted records, one per line, newline terminated.
    """
    done = set()
    RRs = []
    for server in config.SERVERS:
        url = "http://%s:%s" % (server, config.PORT)
        try:
            with contextlib.closing(urllib2.urlopen(url)) as f:
                data = json.loads(f.read())
        except urllib2.URLError:
            # Every message ends with a newline so that consecutive failures
            # do not run together, and the --force hint is only shown when
            # the flag is not already in effect.
            sys.stderr.write("Failed to connect to %s\n" % url)
            if not force:
                sys.stderr.write(" use --force to force\n")
                sys.exit(1)
            continue
        validate_services_definition(data, server)
        RRs.extend(_generate_domain(data, domain, done))
    RRs.sort()
    return "\n".join(RRs) + "\n"
def update_serial(domain):
    """Run the configured serial-update script for a zone and report.

    When ``config.UPDATE_SERIAL_SCRIPT`` is set and is an existing file, it
    is run with the domain as its only argument. A message on stdout tells
    whether the zone was updated or still needs its serial updated.

    :param domain: the zone whose serial has to be bumped.
    """
    script = config.UPDATE_SERIAL_SCRIPT
    if script and os.path.isfile(script):
        code = subprocess.Popen([script, domain]).wait()
        # Any non-zero status is a failure: a negative value means the
        # script was killed by a signal, which is not a success either.
        if code != 0:
            print("Serial of zone %s need to be updated" % domain)
        else:
            print("Zone %s updated" % domain)
    else:
        print("Serial of zone %s need to be updated" % domain)
# Command-line entry point: regenerate the record file of every configured
# domain and bump the zone serial whenever its content changed.
if __name__ == '__main__':
    servers = config.SERVERS
    bind_base_path = config.BIND_BASE_PATH
    # Read by generate_domain(): keep going when a server is unreachable.
    force = "--force" in sys.argv
    for domain in config.DOMAINS:
        data = generate_domain(domain)
        zone_path = "%s%s" % (bind_base_path, domain)
        # A file that does not exist yet (first run) counts as "changed"
        # instead of aborting the whole run.
        if os.path.exists(zone_path):
            with open(zone_path) as f1:
                current = f1.read()
        else:
            current = None
        if current != data:
            # Write next to the target, then rename over it, so a partially
            # written file is never left under the final name.
            new_path = zone_path + ".new"
            with open(new_path, 'w') as f2:
                f2.write(data)
            os.rename(new_path, zone_path)
            update_serial(domain)