-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
124 lines (100 loc) · 3.68 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import base64
import pem
import requests
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers, RSAPublicKey
from cryptography.hazmat.primitives.serialization import load_pem_public_key
def is_dict(text):
    """
    Tell whether the given value is a dict (or a dict subclass).

    :param text: any value
    :return: True for dict instances, False otherwise
    """
    return isinstance(text, dict)
def base64url_encode(input_bytes: bytes):
    """
    Encode bytes as URL-safe base64 with the '=' padding removed.

    :param input_bytes: raw bytes to encode
    :return: the unpadded base64url encoding, as bytes
    """
    encoded = base64.urlsafe_b64encode(input_bytes)
    # '=' only ever appears as trailing padding in base64 output, so trimming
    # the right-hand side removes every occurrence.
    return encoded.rstrip(b'=')
def base64url_decode(input_str):
    """
    Decode URL-safe base64 whose '=' padding may have been stripped.

    :param input_str: str (encoded as UTF-8 first) or bytes
    :return: the decoded bytes
    """
    data = input_str.encode('utf-8') if isinstance(input_str, str) else input_str
    # Restore the '=' padding (as stripped by base64url_encode) so the length
    # is a multiple of 4, which urlsafe_b64decode requires.
    missing = -len(data) % 4
    if missing:
        data += b'=' * missing
    return base64.urlsafe_b64decode(data)
def force_unicode(value):
    """
    Return *value* as a text (str) object.

    :param value: bytes (decoded as UTF-8) or str (returned unchanged)
    :return: str
    :raises TypeError: if value is neither bytes nor str
    """
    if isinstance(value, bytes):
        return value.decode('utf-8')
    if isinstance(value, str):
        # Already text. Encoding it here would return bytes, the opposite of
        # what this helper promises (and the mirror image of force_bytes).
        return value
    raise TypeError('Unexpected type {}'.format(type(value)))
def force_bytes(value):
    """
    Return *value* as bytes.

    :param value: str (encoded as UTF-8) or bytes (returned unchanged)
    :return: bytes
    :raises TypeError: if value is neither str nor bytes
    """
    if isinstance(value, bytes):
        return value
    if not isinstance(value, str):
        raise TypeError('Unexpected type {}'.format(type(value)))
    return value.encode('utf-8')
def read_pem_file(file, debug=True):
    """
    Parse a PEM file and return the first PEM object it contains, as bytes.

    :param file: path of the PEM file, handed to pem.parse_file
    :param debug: when True (the default), print the parsed objects to stdout
    :return: bytes of the first parsed PEM object
    NOTE(review): presumably raises IndexError when the file holds no PEM
    objects (assuming pem.parse_file returns a list) - confirm.
    """
    parsed = pem.parse_file(file)
    if debug:
        print('Read PEM:\n{}\n'.format(str(parsed)))
    first_object = parsed[0]
    return first_object.as_bytes()
def read_file(file: str, strip_newline=True, join_lines=False, *, timeout: float = 30.0):
    """
    Read a text file from the local filesystem or over HTTPS.

    :param file: Full path of the file to read. If it starts with https://
        (case-insensitive), the file is retrieved over HTTP instead.
    :param strip_newline: If True, newline characters are removed and each line
        is stripped of leading/trailing whitespace.
    :param join_lines: If True, the lines read are concatenated into a single string.
    :param timeout: Seconds to wait for the HTTPS server before giving up
        (not used for local files).
    :return: List of lines read, or a single str when join_lines is True.
    :raises requests.HTTPError: if the HTTPS server answers with an error status.
    """
    if file.lower().startswith("https://"):
        # A timeout prevents an unresponsive server from hanging the caller forever.
        response = requests.get(file, verify=True, timeout=timeout)
        # Fail loudly rather than returning an HTML error page as if it were the file.
        response.raise_for_status()
        # keepends=True mirrors f.readlines(), so both sources yield identical lines
        # when strip_newline is False.
        key_list = response.text.splitlines(keepends=True)
    else:
        with open(file, "r", encoding="utf-8") as f:
            key_list = f.readlines()
    if strip_newline:
        key_list = [x.replace('\n', '').strip() for x in key_list]
    if join_lines:
        return "".join(key_list)
    return key_list
def rsa_jwk_to_pubkey(jwk):
    """
    Build an RSA public key object from the 'e' and 'n' members of a JWK.

    Only RS256 supported so far; the 'kty'/'alg' members are not checked.

    :param jwk: dict with base64url-encoded 'e' (exponent) and 'n' (modulus)
    :return: the public key produced by RSAPublicNumbers.public_key
    """
    def _b64url_to_int(member):
        # Each member is a base64url string of a big-endian unsigned integer.
        return int.from_bytes(base64url_decode(jwk[member]), byteorder='big')

    exponent = _b64url_to_int('e')
    modulus = _b64url_to_int('n')
    numbers = RSAPublicNumbers(exponent, modulus)
    return numbers.public_key(backend=default_backend())
def rsa_jwks_to_pubkey(jwks: dict, keyid=None):
    """
    Pick a key from a JWK Set and convert it into an RSA public key object.

    :param jwks: JWK Set dict holding a 'keys' list
    :param keyid: 'kid' of the key to use; when None, the first key in the set is used
    :return: the public key built by rsa_jwk_to_pubkey from the selected JWK
    :raises KeyError: if no key in the set carries the requested 'kid'
    """
    keys = jwks['keys']
    if keyid is None:
        return rsa_jwk_to_pubkey(keys[0])
    # .get(): keys lacking a 'kid' member simply don't match instead of raising.
    jwk = next((key for key in keys if key.get('kid') == keyid), None)
    if jwk is None:
        # Without this check an empty dict would be converted and fail later with
        # a misleading KeyError('e').
        raise KeyError("No JWK with kid {!r} found in JWKS".format(keyid))
    return rsa_jwk_to_pubkey(jwk)
def rsa_pubkey_to_jwk(pem_file=None, key_id=None, pubkey: RSAPublicKey = None):
    """
    Convert an RSA public key into a JWK dict. Only RS256 supported so far.

    :param pem_file: path of PEM file containing an RSA public key; only read when
        pubkey is None
    :param key_id: (optional) keyid to be included in the JWK as 'kid'
    :param pubkey: if provided, it is used and pem_file is ignored
    :return: dict with 'kid' (only when key_id is given), 'kty', 'alg', 'n' and 'e';
        'n' and 'e' are unpadded base64url strings of the big-endian integers
    :raises ValueError: if neither pem_file nor pubkey is provided
    """
    if pubkey is not None:
        rsa_pk = pubkey
    elif pem_file is not None:
        rsa_pk = load_pem_public_key(read_pem_file(pem_file), default_backend())
    else:
        raise ValueError("Either pem_file or pubkey must be provided")
    numbers = rsa_pk.public_numbers()  # fetched once; carries both n and e

    def _b64url_uint(value: int) -> str:
        # Minimal-length big-endian octets, then unpadded base64url text.
        raw = value.to_bytes((value.bit_length() + 7) // 8, byteorder='big')
        return force_unicode(base64url_encode(raw))

    jwk = {}
    if key_id is not None:
        jwk['kid'] = key_id
    jwk['kty'] = "RSA"
    jwk["alg"] = "RS256"
    jwk['n'] = _b64url_uint(numbers.n)
    jwk['e'] = _b64url_uint(numbers.e)
    return jwk