From 4a0bc4b1d650af88c15eb7293550227d93bf1b17 Mon Sep 17 00:00:00 2001 From: Qi Zhao Date: Tue, 26 Mar 2024 17:17:18 +0800 Subject: [PATCH] =?UTF-8?q?style:=20=F0=9F=92=84=20format=20code=20using?= =?UTF-8?q?=20black?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyencrypt/aes.py | 145 ++++++++++++++--------------- pyencrypt/cli.py | 201 ++++++++++++++++++++++++++++------------- pyencrypt/decrypt.py | 12 ++- pyencrypt/encrypt.py | 53 ++++++----- pyencrypt/generate.py | 2 +- pyencrypt/license.py | 40 ++++---- pyencrypt/loader.py | 36 +++++--- tests/conftest.py | 82 ++++++++++------- tests/constants.py | 2 +- tests/test_aes.py | 63 +++++++------ tests/test_decrypt.py | 30 +++--- tests/test_encrypt.py | 111 ++++++++++++++--------- tests/test_generate.py | 8 +- tests/test_license.py | 108 ++++++++++++++-------- tests/test_loader.py | 82 +++++++++++------ tests/test_ntt.py | 90 ++++++++++++------ 16 files changed, 659 insertions(+), 406 deletions(-) diff --git a/pyencrypt/aes.py b/pyencrypt/aes.py index 88363d6..700e7f5 100644 --- a/pyencrypt/aes.py +++ b/pyencrypt/aes.py @@ -1,3 +1,4 @@ +# fmt: off import base64 import copy import struct @@ -32,9 +33,9 @@ def strip_padding(data: bytes) -> bytes: # Based *largely* on the Rijndael implementation # See: http://csrc.nist.gov/publications/fips/fips197/fips-197.pdf class AES(object): - '''Encapsulates the AES block cipher. + """Encapsulates the AES block cipher. You generally should not need this. Use the AESModeOfOperation classes - below instead.''' + below instead.""" # Number of rounds by keysize number_of_rounds = {16: 10, 24: 12, 32: 14} @@ -43,7 +44,7 @@ class AES(object): rcon = [ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, - 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91 + 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91, ] # S-box and Inverse S-box (S is for Substitution) @@ -69,7 +70,7 @@ class AES(object): 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, - 0xb0, 0x54, 0xbb, 0x16 + 0xb0, 0x54, 0xbb, 0x16, ] Si = [ 0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, @@ -93,7 +94,7 @@ class AES(object): 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef, 0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61, 0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, - 0x55, 0x21, 0x0c, 0x7d + 0x55, 0x21, 0x0c, 0x7d, ] # Transformations for encryption @@ -140,7 +141,7 @@ class AES(object): 0x15878792, 0xc9e9e920, 0x87cece49, 0xaa5555ff, 0x50282878, 0xa5dfdf7a, 0x038c8c8f, 0x59a1a1f8, 0x09898980, 0x1a0d0d17, 0x65bfbfda, 0xd7e6e631, 0x844242c6, 0xd06868b8, 0x824141c3, 0x299999b0, 0x5a2d2d77, 0x1e0f0f11, - 0x7bb0b0cb, 0xa85454fc, 0x6dbbbbd6, 0x2c16163a + 0x7bb0b0cb, 0xa85454fc, 0x6dbbbbd6, 0x2c16163a, ] T2 = [ 0xa5c66363, 0x84f87c7c, 0x99ee7777, 0x8df67b7b, 0x0dfff2f2, 0xbdd66b6b, @@ -185,7 +186,7 @@ class AES(object): 0x92158787, 0x20c9e9e9, 0x4987cece, 0xffaa5555, 0x78502828, 0x7aa5dfdf, 0x8f038c8c, 0xf859a1a1, 0x80098989, 0x171a0d0d, 0xda65bfbf, 0x31d7e6e6, 0xc6844242, 0xb8d06868, 0xc3824141, 0xb0299999, 0x775a2d2d, 0x111e0f0f, - 0xcb7bb0b0, 0xfca85454, 0xd66dbbbb, 0x3a2c1616 + 0xcb7bb0b0, 0xfca85454, 0xd66dbbbb, 0x3a2c1616, ] T3 = [ 0x63a5c663, 0x7c84f87c, 0x7799ee77, 0x7b8df67b, 0xf20dfff2, 0x6bbdd66b, @@ -230,7 +231,7 @@ class AES(object): 0x87921587, 0xe920c9e9, 0xce4987ce, 0x55ffaa55, 0x28785028, 0xdf7aa5df, 0x8c8f038c, 0xa1f859a1, 0x89800989, 0x0d171a0d, 0xbfda65bf, 0xe631d7e6, 0x42c68442, 0x68b8d068, 0x41c38241, 0x99b02999, 0x2d775a2d, 0x0f111e0f, - 0xb0cb7bb0, 0x54fca854, 0xbbd66dbb, 0x163a2c16 + 0xb0cb7bb0, 0x54fca854, 0xbbd66dbb, 0x163a2c16, ] T4 = [ 0x6363a5c6, 0x7c7c84f8, 0x777799ee, 0x7b7b8df6, 0xf2f20dff, 0x6b6bbdd6, @@ -275,7 +276,7 @@ class AES(object): 0x87879215, 0xe9e920c9, 0xcece4987, 0x5555ffaa, 0x28287850, 0xdfdf7aa5, 0x8c8c8f03, 0xa1a1f859, 0x89898009, 0x0d0d171a, 0xbfbfda65, 0xe6e631d7, 0x4242c684, 0x6868b8d0, 0x4141c382, 0x9999b029, 0x2d2d775a, 0x0f0f111e, - 0xb0b0cb7b, 0x5454fca8, 0xbbbbd66d, 0x16163a2c + 0xb0b0cb7b, 0x5454fca8, 0xbbbbd66d, 0x16163a2c, ] # Transformations for decryption @@ -322,7 +323,7 @@ class AES(object): 0x1814ce79, 0x73c737bf, 0x53f7cdea, 0x5ffdaa5b, 0xdf3d6f14, 0x7844db86, 0xcaaff381, 0xb968c43e, 0x3824342c, 0xc2a3405f, 0x161dc372, 0xbce2250c, 0x283c498b, 0xff0d9541, 0x39a80171, 0x080cb3de, 0xd8b4e49c, 0x6456c190, - 0x7bcb8461, 0xd532b670, 0x486c5c74, 0xd0b85742 + 0x7bcb8461, 0xd532b670, 0x486c5c74, 0xd0b85742, ] T6 = [ 0x5051f4a7, 0x537e4165, 0xc31a17a4, 0x963a275e, 0xcb3bab6b, 0xf11f9d45, @@ -367,7 +368,7 @@ class AES(object): 0x791814ce, 0xbf73c737, 0xea53f7cd, 0x5b5ffdaa, 0x14df3d6f, 0x867844db, 0x81caaff3, 0x3eb968c4, 0x2c382434, 0x5fc2a340, 0x72161dc3, 0x0cbce225, 0x8b283c49, 0x41ff0d95, 0x7139a801, 0xde080cb3, 0x9cd8b4e4, 0x906456c1, - 0x617bcb84, 0x70d532b6, 0x74486c5c, 0x42d0b857 + 0x617bcb84, 0x70d532b6, 0x74486c5c, 0x42d0b857, ] T7 = [ 0xa75051f4, 0x65537e41, 0xa4c31a17, 0x5e963a27, 0x6bcb3bab, 0x45f11f9d, @@ -412,7 +413,7 @@ class AES(object): 0xce791814, 0x37bf73c7, 0xcdea53f7, 0xaa5b5ffd, 0x6f14df3d, 0xdb867844, 0xf381caaf, 0xc43eb968, 0x342c3824, 0x405fc2a3, 0xc372161d, 0x250cbce2, 0x498b283c, 0x9541ff0d, 0x017139a8, 0xb3de080c, 0xe49cd8b4, 0xc1906456, - 0x84617bcb, 0xb670d532, 0x5c74486c, 0x5742d0b8 + 0x84617bcb, 0xb670d532, 0x5c74486c, 0x5742d0b8, ] T8 = [ 0xf4a75051, 0x4165537e, 0x17a4c31a, 0x275e963a, 0xab6bcb3b, 0x9d45f11f, @@ -457,7 +458,7 @@ class AES(object): 0x14ce7918, 0xc737bf73, 0xf7cdea53, 0xfdaa5b5f, 0x3d6f14df, 0x44db8678, 0xaff381ca, 0x68c43eb9, 0x24342c38, 0xa3405fc2, 0x1dc37216, 0xe2250cbc, 0x3c498b28, 0x0d9541ff, 0xa8017139, 0x0cb3de08, 0xb4e49cd8, 0x56c19064, - 0xcb84617b, 0x32b670d5, 0x6c5c7448, 0xb85742d0 + 0xcb84617b, 0x32b670d5, 0x6c5c7448, 0xb85742d0, ] # Transformations for decryption key expansion @@ -504,7 +505,7 @@ class AES(object): 0x5bfb7e34, 0x55f2733f, 0x7fcd500e, 0x71c45d05, 0x63df4a18, 0x6dd64713, 0xd731dcca, 0xd938d1c1, 0xcb23c6dc, 0xc52acbd7, 0xef15e8e6, 0xe11ce5ed, 0xf307f2f0, 0xfd0efffb, 0xa779b492, 0xa970b999, 0xbb6bae84, 0xb562a38f, - 0x9f5d80be, 0x91548db5, 0x834f9aa8, 0x8d4697a3 + 0x9f5d80be, 0x91548db5, 0x834f9aa8, 0x8d4697a3, ] U2 = [ 0x00000000, 0x0b0e090d, 0x161c121a, 0x1d121b17, 0x2c382434, 0x27362d39, @@ -549,7 +550,7 @@ class AES(object): 0x345bfb7e, 0x3f55f273, 0x0e7fcd50, 0x0571c45d, 0x1863df4a, 0x136dd647, 0xcad731dc, 0xc1d938d1, 0xdccb23c6, 0xd7c52acb, 0xe6ef15e8, 0xede11ce5, 0xf0f307f2, 0xfbfd0eff, 0x92a779b4, 0x99a970b9, 0x84bb6bae, 0x8fb562a3, - 0xbe9f5d80, 0xb591548d, 0xa8834f9a, 0xa38d4697 + 0xbe9f5d80, 0xb591548d, 0xa8834f9a, 0xa38d4697, ] U3 = [ 0x00000000, 0x0d0b0e09, 0x1a161c12, 0x171d121b, 0x342c3824, 0x3927362d, @@ -594,7 +595,7 @@ class AES(object): 0x7e345bfb, 0x733f55f2, 0x500e7fcd, 0x5d0571c4, 0x4a1863df, 0x47136dd6, 0xdccad731, 0xd1c1d938, 0xc6dccb23, 0xcbd7c52a, 0xe8e6ef15, 0xe5ede11c, 0xf2f0f307, 0xfffbfd0e, 0xb492a779, 0xb999a970, 0xae84bb6b, 0xa38fb562, - 0x80be9f5d, 0x8db59154, 0x9aa8834f, 0x97a38d46 + 0x80be9f5d, 0x8db59154, 0x9aa8834f, 0x97a38d46, ] U4 = [ 0x00000000, 0x090d0b0e, 0x121a161c, 0x1b171d12, 0x24342c38, 0x2d392736, @@ -639,13 +640,13 @@ class AES(object): 0xfb7e345b, 0xf2733f55, 0xcd500e7f, 0xc45d0571, 0xdf4a1863, 0xd647136d, 0x31dccad7, 0x38d1c1d9, 0x23c6dccb, 0x2acbd7c5, 0x15e8e6ef, 0x1ce5ede1, 0x07f2f0f3, 0x0efffbfd, 0x79b492a7, 0x70b999a9, 0x6bae84bb, 0x62a38fb5, - 0x5d80be9f, 0x548db591, 0x4f9aa883, 0x4697a38d + 0x5d80be9f, 0x548db591, 0x4f9aa883, 0x4697a38d, ] def __init__(self, key): # noqa: C901 if len(key) not in (16, 24, 32): - raise ValueError('Invalid key size') + raise ValueError("Invalid key size") rounds = self.number_of_rounds[len(key)] @@ -659,10 +660,7 @@ def __init__(self, key): # noqa: C901 KC = len(key) // 4 # Convert the key into ints - tk = [ - struct.unpack('>i', key[i:i + 4])[0] - for i in range(0, len(key), 4) - ] + tk = [struct.unpack(">i", key[i : i + 4])[0] for i in range(0, len(key), 4)] # Copy values into round key arrays for i in range(0, KC): @@ -677,11 +675,11 @@ def __init__(self, key): # noqa: C901 tt = tk[KC - 1] # noqa: W504 tk[0] ^= ( - (self.S[(tt >> 16) & 0xFF] << 24) ^ - (self.S[(tt >> 8) & 0xFF] << 16) ^ - (self.S[tt & 0xFF] << 8) ^ - self.S[(tt >> 24) & 0xFF] ^ - (self.rcon[rconpointer] << 24) + (self.S[(tt >> 16) & 0xFF] << 24) + ^ (self.S[(tt >> 8) & 0xFF] << 16) + ^ (self.S[tt & 0xFF] << 8) + ^ self.S[(tt >> 24) & 0xFF] + ^ (self.rcon[rconpointer] << 24) ) rconpointer += 1 @@ -696,10 +694,10 @@ def __init__(self, key): # noqa: C901 tt = tk[KC // 2 - 1] tk[KC // 2] ^= ( - self.S[tt & 0xFF] ^ - (self.S[(tt >> 8) & 0xFF] << 8) ^ - (self.S[(tt >> 16) & 0xFF] << 16) ^ - (self.S[(tt >> 24) & 0xFF] << 24) + self.S[tt & 0xFF] + ^ (self.S[(tt >> 8) & 0xFF] << 8) + ^ (self.S[(tt >> 16) & 0xFF] << 16) + ^ (self.S[(tt >> 24) & 0xFF] << 24) ) for i in range(KC // 2 + 1, KC): @@ -718,35 +716,37 @@ def __init__(self, key): # noqa: C901 for j in range(0, 4): tt = self._Kd[r][j] self._Kd[r][j] = ( - self.U1[(tt >> 24) & 0xFF] ^ - self.U2[(tt >> 16) & 0xFF] ^ - self.U3[(tt >> 8) & 0xFF] ^ - self.U4[tt & 0xFF] + self.U1[(tt >> 24) & 0xFF] + ^ self.U2[(tt >> 16) & 0xFF] + ^ self.U3[(tt >> 8) & 0xFF] + ^ self.U4[tt & 0xFF] ) def encrypt(self, plaintext): - 'Encrypt a block of plain text using the AES block cipher.' + "Encrypt a block of plain text using the AES block cipher." if len(plaintext) != 16: - raise ValueError('wrong block length') + raise ValueError("wrong block length") rounds = len(self._Ke) - 1 (s1, s2, s3) = [1, 2, 3] a = [0, 0, 0, 0] # Convert plaintext to (ints ^ key) - t = [(_compact_word(plaintext[4 * i:4 * i + 4]) ^ self._Ke[0][i]) - for i in range(0, 4)] + t = [ + (_compact_word(plaintext[4 * i : 4 * i + 4]) ^ self._Ke[0][i]) + for i in range(0, 4) + ] # Apply round transforms for r in range(1, rounds): for i in range(0, 4): a[i] = ( - self.T1[(t[i] >> 24) & 0xFF] ^ - self.T2[(t[(i + s1) % 4] >> 16) & 0xFF] ^ - self.T3[(t[(i + s2) % 4] >> 8) & 0xFF] ^ - self.T4[t[(i + s3) % 4] & 0xFF] ^ - self._Ke[r][i] + self.T1[(t[i] >> 24) & 0xFF] + ^ self.T2[(t[(i + s1) % 4] >> 16) & 0xFF] + ^ self.T3[(t[(i + s2) % 4] >> 8) & 0xFF] + ^ self.T4[t[(i + s3) % 4] & 0xFF] + ^ self._Ke[r][i] ) t = copy.copy(a) @@ -762,27 +762,30 @@ def encrypt(self, plaintext): return result def decrypt(self, ciphertext): - 'Decrypt a block of cipher text using the AES block cipher.' + "Decrypt a block of cipher text using the AES block cipher." if len(ciphertext) != 16: - raise ValueError('wrong block length') + raise ValueError("wrong block length") rounds = len(self._Kd) - 1 (s1, s2, s3) = [3, 2, 1] a = [0, 0, 0, 0] # Convert ciphertext to (ints ^ key) - t = [(_compact_word(ciphertext[4 * i:4 * i + 4]) ^ self._Kd[0][i]) - for i in range(0, 4)] + t = [ + (_compact_word(ciphertext[4 * i : 4 * i + 4]) ^ self._Kd[0][i]) + for i in range(0, 4) + ] # Apply round transforms for r in range(1, rounds): for i in range(0, 4): a[i] = ( - self.T5[(t[i] >> 24) & 0xFF] ^ - self.T6[(t[(i + s1) % 4] >> 16) & 0xFF] ^ - self.T7[(t[(i + s2) % 4] >> 8) & 0xFF] ^ - self.T8[t[(i + s3) % 4] & 0xFF] ^ self._Kd[r][i] + self.T5[(t[i] >> 24) & 0xFF] + ^ self.T6[(t[(i + s1) % 4] >> 16) & 0xFF] + ^ self.T7[(t[(i + s2) % 4] >> 8) & 0xFF] + ^ self.T8[t[(i + s3) % 4] & 0xFF] + ^ self._Kd[r][i] ) t = copy.copy(a) @@ -799,28 +802,28 @@ def decrypt(self, ciphertext): class AESBlockModeOfOperation(object): - '''Super-class for AES modes of operation that require blocks.''' + """Super-class for AES modes of operation that require blocks.""" def __init__(self, key): self._aes = AES(key) def decrypt(self, ciphertext): - raise Exception('not implemented') + raise Exception("not implemented") def encrypt(self, plaintext): - raise Exception('not implemented') + raise Exception("not implemented") class AESModeOfOperationECB(AESBlockModeOfOperation): - '''AES Electronic Codebook Mode of Operation. - o Block-cipher, so data must be padded to 16 byte boundaries - Security Notes: - o This mode is not recommended - o Any two identical blocks produce identical encrypted values, - exposing data patterns. (See the image of Tux on wikipedia) - Also see: - o https://en.wikipedia.org/wiki/Block_cipher_mode_of_operation#Electronic_codebook_.28ECB.29 - o See NIST SP800-38A (http://csrc.nist.gov/publications/nistpubs/800-38a/sp800-38a.pdf); section 6.1''' + """AES Electronic Codebook Mode of Operation. + o Block-cipher, so data must be padded to 16 byte boundaries + Security Notes: + o This mode is not recommended + o Any two identical blocks produce identical encrypted values, + exposing data patterns. (See the image of Tux on wikipedia) + Also see: + o https://en.wikipedia.org/wiki/Block_cipher_mode_of_operation#Electronic_codebook_.28ECB.29 + o See NIST SP800-38A (http://csrc.nist.gov/publications/nistpubs/800-38a/sp800-38a.pdf); section 6.1""" name = "Electronic Codebook (ECB)" @@ -829,14 +832,14 @@ def __init__(self, key: str): def encrypt(self, data: bytes) -> bytes: if len(data) != 16: - raise ValueError('plain block must be 16 bytes') + raise ValueError("plain block must be 16 bytes") plain = _bytes_to_int(data) return _int_to_bytes(self._aes.encrypt(plain)) def decrypt(self, data: bytes) -> bytes: if len(data) != 16: - raise ValueError('cipher block must be 16 bytes') + raise ValueError("cipher block must be 16 bytes") cipher = _bytes_to_int(data) return _int_to_bytes(self._aes.decrypt(cipher)) @@ -845,15 +848,15 @@ def decrypt(self, data: bytes) -> bytes: def aes_encrypt(data: bytes, key: str) -> bytes: data = add_padding(data) cipher = list() - for x in [data[i:i + 16] for i in range(0, len(data), 16)]: + for x in [data[i : i + 16] for i in range(0, len(data), 16)]: cipher.append(AESModeOfOperationECB(key).encrypt(x)) - return b''.join(cipher) + return b"".join(cipher) def aes_decrypt(data: bytes, key: str) -> bytes: plain = list() if len(data) % 16 != 0: - raise Exception('invalid length') - for x in [data[i:i + 16] for i in range(0, len(data), 16)]: + raise Exception("invalid length") + for x in [data[i : i + 16] for i in range(0, len(data), 16)]: plain.append(AESModeOfOperationECB(key).decrypt(x)) - return strip_padding(b''.join(plain)) + return strip_padding(b"".join(plain)) diff --git a/pyencrypt/cli.py b/pyencrypt/cli.py index 25cb125..cdfff61 100644 --- a/pyencrypt/cli.py +++ b/pyencrypt/cli.py @@ -10,11 +10,11 @@ from pyencrypt import __description__, __version__ from pyencrypt.decrypt import decrypt_file -from pyencrypt.encrypt import (can_encrypt, encrypt_file, encrypt_key, generate_so_file) +from pyencrypt.encrypt import can_encrypt, encrypt_file, encrypt_key, generate_so_file from pyencrypt.generate import generate_aes_key from pyencrypt.license import MAX_DATETIME, MIN_DATETIME, generate_license_file -VERSION = fr""" +VERSION = rf""" _ _ __ _ _ ___ _ __ ___ _ __ _ _ _ __ | |_ | '_ \| | | |/ _ \ '_ \ / __| '__| | | | '_ \| __| @@ -33,18 +33,20 @@ """ PYTHON_MAJOR, PYTHON_MINOR = sys.version_info[:2] -LOADER_FILE_NAME = click.style("encrypted/{}", blink=True, fg='blue') -LICENSE_FILE_NAME = click.style("license.lic", blink=True, fg='blue') +LOADER_FILE_NAME = click.style("encrypted/{}", blink=True, fg="blue") +LICENSE_FILE_NAME = click.style("license.lic", blink=True, fg="blue") -SUCCESS_ANSI = click.style('successfully', fg='green') +SUCCESS_ANSI = click.style("successfully", fg="green") -INVALID_KEY_MSG = click.style('Your encryption 🔑 is invalid.', fg='red') +INVALID_KEY_MSG = click.style("Your encryption 🔑 is invalid.", fg="red") -INVALID_MAC_MSG = click.style('{} is not a valid mac address.', fg='red') +INVALID_MAC_MSG = click.style("{} is not a valid mac address.", fg="red") -INVALID_IPV4_MSG = click.style('{} is not a valid ipv4 address.', fg='red') +INVALID_IPV4_MSG = click.style("{} is not a valid ipv4 address.", fg="red") -INVALID_DATETIME_MSG = click.style('Before date must be less than after date.', fg='red') +INVALID_DATETIME_MSG = click.style( + "Before date must be less than after date.", fg="red" +) FINISH_ENCRYPT_MSG = f""" Encryption completed {SUCCESS_ANSI}. @@ -65,11 +67,11 @@ Generate license file {SUCCESS_ANSI}. Your license file is located in {LICENSE_FILE_NAME} """ -DATETIME_FORMATS = ['%Y-%m-%dT%H:%M:%S %z', '%Y-%m-%d %H:%M:%S', '%Y-%m-%d'] +DATETIME_FORMATS = ["%Y-%m-%dT%H:%M:%S %z", "%Y-%m-%d %H:%M:%S", "%Y-%m-%d"] class KeyParamType(click.ParamType): - name = 'key' + name = "key" def _check_key(self, key: str) -> bool: return not (len(key) % 4 or len(base64.b64decode(key)) % 16) @@ -81,15 +83,15 @@ def convert(self, value, param, ctx) -> str: return value def get_metavar(self, param): - return '🔑' + return "🔑" def __repr__(self) -> str: return "KEY" class MacAddressParamType(click.ParamType): - name = 'mac_address' - pattern = re.compile(r'^([0-9a-fA-F]{2}[:-]){5}([0-9a-fA-F]{2})$') + name = "mac_address" + pattern = re.compile(r"^([0-9a-fA-F]{2}[:-]){5}([0-9a-fA-F]{2})$") def convert(self, value, param, ctx) -> str: value = click.STRING.convert(value, param, ctx) @@ -98,14 +100,14 @@ def convert(self, value, param, ctx) -> str: return value def get_metavar(self, param): - return '01:23:45:67:89:AB' + return "01:23:45:67:89:AB" def __repr__(self) -> str: return "MacAddress" class IPv4AddressParamType(click.ParamType): - name = 'ipv4_address' + name = "ipv4_address" def convert(self, value, param, ctx) -> str: value = click.STRING.convert(value, param, ctx) @@ -115,7 +117,7 @@ def convert(self, value, param, ctx) -> str: self.fail(INVALID_IPV4_MSG.format(value), param, ctx) def get_metavar(self, param): - return '192.168.0.1' + return "192.168.0.1" def __repr__(self) -> str: return "Ipv4Address" @@ -128,29 +130,75 @@ class CustomParamType: @click.group() -@click.version_option(__version__, '--version', message=VERSION) -@click.help_option('-h', '--help') +@click.version_option(__version__, "--version", message=VERSION) +@click.help_option("-h", "--help") def cli(): pass -@cli.command(name='encrypt') -@click.argument('pathname', type=click.Path(exists=True, resolve_path=True)) -@click.option('-i', '--in-place', 'replace', default=False, help='make changes to files in place', is_flag=True) -@click.option('-k', '--key', default=None, help=KEY_OPTION_HELP, type=CustomParamType.KEY) -@click.option('--with-license', default=False, help='Add license to encrypted file', is_flag=True) -@click.option('-m', '--bind-mac', 'mac', default=None, help='Bind mac address to encrypted file', type=CustomParamType.MAC_ADDR) -@click.option('-4', '--bind-ipv4', 'ipv4', default=None, help='Bind ipv4 address to encrypted file', type=CustomParamType.IPV4_ADDR) -@click.option('-b', '--before', default=MIN_DATETIME, help='License is invalid before this date.', type=click.DateTime(formats=DATETIME_FORMATS)) -@click.option('-a', '--after', default=MAX_DATETIME, help='License is invalid after this date.', type=click.DateTime(formats=DATETIME_FORMATS)) -@click.confirmation_option('-y', '--yes', prompt='Are you sure you want to encrypt your python file?', help='Automatically answer yes for confirm questions.') -@click.help_option('-h', '--help') +@cli.command(name="encrypt") +@click.argument("pathname", type=click.Path(exists=True, resolve_path=True)) +@click.option( + "-i", + "--in-place", + "replace", + default=False, + help="make changes to files in place", + is_flag=True, +) +@click.option( + "-k", "--key", default=None, help=KEY_OPTION_HELP, type=CustomParamType.KEY +) +@click.option( + "--with-license", default=False, help="Add license to encrypted file", is_flag=True +) +@click.option( + "-m", + "--bind-mac", + "mac", + default=None, + help="Bind mac address to encrypted file", + type=CustomParamType.MAC_ADDR, +) +@click.option( + "-4", + "--bind-ipv4", + "ipv4", + default=None, + help="Bind ipv4 address to encrypted file", + type=CustomParamType.IPV4_ADDR, +) +@click.option( + "-b", + "--before", + default=MIN_DATETIME, + help="License is invalid before this date.", + type=click.DateTime(formats=DATETIME_FORMATS), +) +@click.option( + "-a", + "--after", + default=MAX_DATETIME, + help="License is invalid after this date.", + type=click.DateTime(formats=DATETIME_FORMATS), +) +@click.confirmation_option( + "-y", + "--yes", + prompt="Are you sure you want to encrypt your python file?", + help="Automatically answer yes for confirm questions.", +) +@click.help_option("-h", "--help") @click.pass_context -def encrypt_command(ctx, pathname, replace, key, with_license, mac, ipv4, before, after): +def encrypt_command( + ctx, pathname, replace, key, with_license, mac, ipv4, before, after +): """Encrypt your python code""" if key is None: key = generate_aes_key().decode() - click.echo(f'Your randomly encryption 🔑 is {click.style(key,underline=True, fg="yellow")}') + click.echo( + f'Your randomly encryption 🔑 is {click.style(key,underline=True, fg="yellow")}' + ) if before > after: ctx.fail(INVALID_DATETIME_MSG) @@ -159,25 +207,25 @@ def encrypt_command(ctx, pathname, replace, key, with_license, mac, ipv4, before if path.is_file(): if replace: - new_path = path.with_suffix('.pye') + new_path = path.with_suffix(".pye") else: - new_path = Path(os.getcwd()) / path.with_suffix('.pye').name + new_path = Path(os.getcwd()) / path.with_suffix(".pye").name encrypt_file(path, key, replace, new_path) elif path.is_dir(): if replace: work_dir = path else: - work_dir = Path(os.getcwd()) / 'encrypted' / path.name + work_dir = Path(os.getcwd()) / "encrypted" / path.name work_dir.exists() and shutil.rmtree(work_dir) shutil.copytree(path, work_dir) - files = set(work_dir.glob('**/*.py')) - with click.progressbar(files, label='🔐 Encrypting') as bar: + files = set(work_dir.glob("**/*.py")) + with click.progressbar(files, label="🔐 Encrypting") as bar: for file in bar: - new_path = file.with_suffix('.pye') + new_path = file.with_suffix(".pye") if can_encrypt(file): encrypt_file(file, key, True, new_path) else: - raise Exception(f'{path} is not a valid path.') + raise Exception(f"{path} is not a valid path.") cipher_key, d, n = encrypt_key(key.encode()) # 需要放进导入器中 loader_extension = generate_so_file(cipher_key, d, n, license=with_license) @@ -187,11 +235,20 @@ def encrypt_command(ctx, pathname, replace, key, with_license, mac, ipv4, before click.echo(FINISH_ENCRYPT_MSG.format(loader_extension.name)) -@cli.command(name='decrypt') -@click.argument('pathname', type=click.Path(exists=True, resolve_path=True)) -@click.option('-i', '--in-place', 'replace', default=False, help='make changes to files in place', is_flag=True) -@click.option('-k', '--key', required=True, help='Your encryption key.', type=CustomParamType.KEY) -@click.help_option('-h', '--help') +@cli.command(name="decrypt") +@click.argument("pathname", type=click.Path(exists=True, resolve_path=True)) +@click.option( + "-i", + "--in-place", + "replace", + default=False, + help="make changes to files in place", + is_flag=True, +) +@click.option( + "-k", "--key", required=True, help="Your encryption key.", type=CustomParamType.KEY +) +@click.help_option("-h", "--help") @click.pass_context def decrypt_command(ctx, pathname, replace, key): """Decrypt encrypted pye file""" @@ -199,9 +256,9 @@ def decrypt_command(ctx, pathname, replace, key): if path.is_file(): if replace: - new_path = path.with_suffix('.py') + new_path = path.with_suffix(".py") else: - new_path = Path(os.getcwd()) / path.with_suffix('.py').name + new_path = Path(os.getcwd()) / path.with_suffix(".py").name work_dir = new_path.parent origin_data = decrypt_file(path, key, replace, new_path) print(origin_data.decode()) @@ -209,23 +266,25 @@ def decrypt_command(ctx, pathname, replace, key): if replace: work_dir = path else: - work_dir = Path(os.getcwd()) / 'decrypted' / path.name + work_dir = Path(os.getcwd()) / "decrypted" / path.name work_dir.exists() and shutil.rmtree(work_dir) shutil.copytree(path, work_dir) - files = list(work_dir.glob('**/*.pye')) - with click.progressbar(files, label='🔓 Decrypting') as bar: + files = list(work_dir.glob("**/*.pye")) + with click.progressbar(files, label="🔓 Decrypting") as bar: for file in bar: - new_path = file.with_suffix('.py') + new_path = file.with_suffix(".py") decrypt_file(file, key, True, new_path) else: - raise Exception(f'{path} is not a valid path.') + raise Exception(f"{path} is not a valid path.") click.echo(FINISH_DECRYPT_MSG.format(work_dir=work_dir)) -@cli.command(name='generate') -@click.option('-k', '--key', required=True, help='Your encryption key.', type=CustomParamType.KEY) -@click.help_option('-h', '--help') +@cli.command(name="generate") +@click.option( + "-k", "--key", required=True, help="Your encryption key.", type=CustomParamType.KEY +) +@click.help_option("-h", "--help") @click.pass_context def generate_loader(ctx, key): """Generate loader file using specified key""" @@ -234,13 +293,31 @@ def generate_loader(ctx, key): click.echo(FINISH_GENERATE_LOADER_MSG.format(loader_extension.name)) -@cli.command(name='license') -@click.help_option('-h', '--help') -@click.option('-k', '--key', required=True, help='Your encryption key.', type=CustomParamType.KEY) -@click.option('-m', '--bind-mac', help='Your mac address.', type=CustomParamType.MAC_ADDR) -@click.option('-4', '--bind-ipv4', help='Your ipv4 address.', type=CustomParamType.IPV4_ADDR) -@click.option('-b', '--before', default=MIN_DATETIME, help='License is invalid before this date.', type=click.DateTime(formats=DATETIME_FORMATS)) -@click.option('-a', '--after', default=MAX_DATETIME, help='License is invalid after this date.', type=click.DateTime(formats=DATETIME_FORMATS)) +@cli.command(name="license") +@click.help_option("-h", "--help") +@click.option( + "-k", "--key", required=True, help="Your encryption key.", type=CustomParamType.KEY +) +@click.option( + "-m", "--bind-mac", help="Your mac address.", type=CustomParamType.MAC_ADDR +) +@click.option( + "-4", "--bind-ipv4", help="Your ipv4 address.", type=CustomParamType.IPV4_ADDR +) +@click.option( + "-b", + "--before", + default=MIN_DATETIME, + help="License is invalid before this date.", + type=click.DateTime(formats=DATETIME_FORMATS), +) +@click.option( + "-a", + "--after", + default=MAX_DATETIME, + help="License is invalid after this date.", + type=click.DateTime(formats=DATETIME_FORMATS), +) @click.pass_context def generate_license(ctx, key, mac, ipv4, before, after): """Generate license file using specified key""" @@ -251,5 +328,5 @@ def generate_license(ctx, key, mac, ipv4, before, after): click.echo(FINISH_GENERATE_LICENSE_MSG) -if __name__ == '__main__': +if __name__ == "__main__": cli() diff --git a/pyencrypt/decrypt.py b/pyencrypt/decrypt.py index ee876ef..140fd77 100644 --- a/pyencrypt/decrypt.py +++ b/pyencrypt/decrypt.py @@ -6,22 +6,24 @@ def decrypt_key(cipher_key: str, d: int, n: int) -> str: plain_ls = list() - for num in map(int, cipher_key.split('O')): + for num in map(int, cipher_key.split("O")): plain_ls.append(pow(num, d, n)) # 去掉intt后末尾多余的0 - return ''.join(map(chr, filter(lambda x: x != 0, intt(plain_ls)))) + return "".join(map(chr, filter(lambda x: x != 0, intt(plain_ls)))) def _decrypt_file(data: bytes, key: str) -> bytes: return aes_decrypt(data, key) -def decrypt_file(path: Path, key: str, delete_origin: bool = False, new_path: Path = None) -> bytes: - if path.suffix != '.pye': +def decrypt_file( + path: Path, key: str, delete_origin: bool = False, new_path: Path = None +) -> bytes: + if path.suffix != ".pye": raise Exception(f"{path.name} can't be decrypted.") data = _decrypt_file(path.read_bytes(), key) if new_path: - if new_path.suffix != '.py': + if new_path.suffix != ".py": raise Exception("Origin file path must be py suffix.") new_path.touch(exist_ok=True) new_path.write_bytes(data) diff --git a/pyencrypt/encrypt.py b/pyencrypt/encrypt.py index 1c328d3..262c3ac 100644 --- a/pyencrypt/encrypt.py +++ b/pyencrypt/encrypt.py @@ -11,10 +11,10 @@ from pyencrypt.ntt import ntt NOT_ALLOWED_ENCRYPT_FILES = [ - '__init__.py', + "__init__.py", ] -REMOVE_SELF_IMPORT = re.compile(r'^from pyencrypt\.[\s\S]*?$', re.MULTILINE) +REMOVE_SELF_IMPORT = re.compile(r"^from pyencrypt\.[\s\S]*?$", re.MULTILINE) def _encrypt_file( @@ -27,9 +27,9 @@ def _encrypt_file( def can_encrypt(path: Path) -> bool: if path.name in NOT_ALLOWED_ENCRYPT_FILES: return False - if 'management/commands/' in path.as_posix(): + if "management/commands/" in path.as_posix(): return False - if path.suffix != '.py': + if path.suffix != ".py": return False return True @@ -37,7 +37,7 @@ def can_encrypt(path: Path) -> bool: def encrypt_key(key: bytes): ascii_ls = [ord(x) for x in key.decode()] numbers = generate_rsa_number(2048) - e, n = numbers['e'], numbers['n'] + e, n = numbers["e"], numbers["n"] # fill length to be a power of 2 length = len(ascii_ls) if length & (length - 1) != 0: @@ -47,20 +47,28 @@ def encrypt_key(key: bytes): # ntt后再用RSA加密 for num in ntt(ascii_ls): cipher_ls.append(pow(num, e, n)) - return 'O'.join(map(str, cipher_ls)), numbers['d'], numbers['n'] + return "O".join(map(str, cipher_ls)), numbers["d"], numbers["n"] -def generate_so_file(cipher_key: str, d: int, n: int, base_dir: Optional[Path] = None, license: bool = False) -> Path: - private_key = f'{n}O{d}' +def generate_so_file( + cipher_key: str, + d: int, + n: int, + base_dir: Optional[Path] = None, + license: bool = False, +) -> Path: + private_key = f"{n}O{d}" path = Path(os.path.abspath(__file__)).parent decrypt_source_ls = list() - need_import_files = ['ntt.py', 'aes.py', 'decrypt.py', 'license.py'] + need_import_files = ["ntt.py", "aes.py", "decrypt.py", "license.py"] for file in need_import_files: file_path = path / file - decrypt_source_ls.append(REMOVE_SELF_IMPORT.sub('', file_path.read_text(encoding="utf-8"))) + decrypt_source_ls.append( + REMOVE_SELF_IMPORT.sub("", file_path.read_text(encoding="utf-8")) + ) - loader_source_path = path / 'loader.py' + loader_source_path = path / "loader.py" loader_source = ( REMOVE_SELF_IMPORT.sub("", loader_source_path.read_text(encoding="utf-8")) .replace("__private_key = ''", f"__private_key = '{private_key}'", 1) @@ -71,15 +79,15 @@ def generate_so_file(cipher_key: str, d: int, n: int, base_dir: Optional[Path] = if base_dir is None: base_dir = Path(os.getcwd()) - temp_dir = base_dir / 'encrypted' + temp_dir = base_dir / "encrypted" temp_dir.mkdir(exist_ok=True) - loader_file_path = temp_dir / 'loader.py' + loader_file_path = temp_dir / "loader.py" loader_file_path.touch(exist_ok=True) - decrypt_source = '\n'.join(decrypt_source_ls) + decrypt_source = "\n".join(decrypt_source_ls) # Origin file - loader_origin_file_path = temp_dir / 'loader_origin.py' + loader_origin_file_path = temp_dir / "loader_origin.py" loader_origin_file_path.touch(exist_ok=True) loader_origin_file_path.write_text( f"{decrypt_source}\n{loader_source}", encoding="utf-8" @@ -93,14 +101,15 @@ def generate_so_file(cipher_key: str, d: int, n: int, base_dir: Optional[Path] = from setuptools import setup # isort:skip from Cython.Build import cythonize from Cython.Distutils import build_ext + setup( ext_modules=cythonize(loader_file_path.as_posix(), language_level="3"), - script_args=['build_ext', '--build-lib', temp_dir.as_posix()], - cmdclass={'build_ext': build_ext}, + script_args=["build_ext", "--build-lib", temp_dir.as_posix()], + cmdclass={"build_ext": build_ext}, ) - if sys.platform.startswith('win'): + if sys.platform.startswith("win"): # loader.cp36-win_amd64.pyd - pattern = 'loader.cp*-*.pyd' + pattern = "loader.cp*-*.pyd" else: # loader.cpython-36m-x86_64-linux-gnu.so # loader.cpython-36m-darwin.so @@ -112,12 +121,14 @@ def generate_so_file(cipher_key: str, d: int, n: int, base_dir: Optional[Path] = return loader_extension.absolute() -def encrypt_file(path: Path, key: str, delete_origin: bool = False, new_path: Optional[Path] = None): +def encrypt_file( + path: Path, key: str, delete_origin: bool = False, new_path: Optional[Path] = None +): if not can_encrypt(path): raise Exception(f"{path.name} can't be encrypted.") encrypted_data = _encrypt_file(path.read_bytes(), key) if new_path: - if new_path.suffix != '.pye': + if new_path.suffix != ".pye": raise Exception("Encrypted file path must be pye suffix.") new_path.touch(exist_ok=True) new_path.write_bytes(encrypted_data) diff --git a/pyencrypt/generate.py b/pyencrypt/generate.py index d474abe..9cbeeca 100644 --- a/pyencrypt/generate.py +++ b/pyencrypt/generate.py @@ -10,4 +10,4 @@ def generate_aes_key(size: int = 32) -> bytes: def generate_rsa_number(bits: int) -> dict: r = RSA.generate(bits) - return {'p': r.p, 'q': r.q, 'n': r.n, 'e': r.e, 'd': r.d} + return {"p": r.p, "q": r.q, "n": r.n, "e": r.e, "d": r.d} diff --git a/pyencrypt/license.py b/pyencrypt/license.py index 2d3abc4..9a31aed 100644 --- a/pyencrypt/license.py +++ b/pyencrypt/license.py @@ -10,17 +10,17 @@ from pyencrypt.aes import aes_encrypt -DATE_FORMAT = '%Y-%m-%dT%H:%M:%S%z' +DATE_FORMAT = "%Y-%m-%dT%H:%M:%S%z" MIN_DATETIME = datetime.now().astimezone() MAX_DATETIME = datetime(year=2999, month=12, day=31).astimezone() def get_mac_address() -> str: - return ':'.join(("%012X" % uuid.getnode())[i:i + 2] for i in range(0, 12, 2)) + return ":".join(("%012X" % uuid.getnode())[i : i + 2] for i in range(0, 12, 2)) def get_host_ipv4() -> str: - if sys.platform == 'darwin': + if sys.platform == "darwin": command = "ifconfig | grep 'inet ' | grep -Fv 127.0.0.1 | awk '{print $2}'" return subprocess.check_output(command, shell=True).decode().split()[-1] else: @@ -31,11 +31,11 @@ def get_signature(data: bytes) -> str: return hashlib.sha256(data).hexdigest() -FIELDS = ['invalid_before', 'invalid_after', 'ipv4', 'mac'] +FIELDS = ["invalid_before", "invalid_after", "ipv4", "mac"] def _combine_data(data: dict) -> bytes: - return '*'.join(map(lambda x: f'{x}:{data[x]}', FIELDS)).encode() + return "*".join(map(lambda x: f"{x}:{data[x]}", FIELDS)).encode() def generate_license_file( @@ -45,7 +45,7 @@ def generate_license_file( before: datetime = None, mac_addr: str = None, ipv4: str = None, - **extra + **extra, ): if after is None: after = MAX_DATETIME @@ -58,20 +58,20 @@ def generate_license_file( before = before.astimezone() data = { - 'invalid_before': before.strftime(DATE_FORMAT), - 'invalid_after': after.strftime(DATE_FORMAT), - 'mac': mac_addr, - 'ipv4': ipv4, + "invalid_before": before.strftime(DATE_FORMAT), + "invalid_after": after.strftime(DATE_FORMAT), + "mac": mac_addr, + "ipv4": ipv4, } encrypted_data = aes_encrypt(_combine_data(data), aes_key) signature = get_signature(encrypted_data) - data.update({'signature': signature, **extra}) + data.update({"signature": signature, **extra}) if path is None: path = Path(os.getcwd()) - license_dir = path / 'licenses' + license_dir = path / "licenses" license_dir.mkdir(exist_ok=True) - license_path = license_dir / 'license.lic' + license_path = license_dir / "license.lic" license_path.touch(exist_ok=True) license_path.write_bytes(json.dumps(data, indent=4).encode()) return license_path.absolute() @@ -79,13 +79,15 @@ def generate_license_file( def check_license(license_path: Path, aes_key: str): if not license_path.exists(): - raise FileNotFoundError(f"License file {license_path.absolute().as_posix()} not found.") + raise FileNotFoundError( + f"License file {license_path.absolute().as_posix()} not found." + ) data = json.loads(license_path.read_text()) - signature = data.pop('signature') - before = datetime.strptime(data['invalid_before'], DATE_FORMAT).astimezone() - after = datetime.strptime(data['invalid_after'], DATE_FORMAT).astimezone() - mac_address = data.get('mac') - ipv4 = data.get('ipv4') + signature = data.pop("signature") + before = datetime.strptime(data["invalid_before"], DATE_FORMAT).astimezone() + after = datetime.strptime(data["invalid_after"], DATE_FORMAT).astimezone() + mac_address = data.get("mac") + ipv4 = data.get("ipv4") now = datetime.now().astimezone() if signature != get_signature(aes_encrypt(_combine_data(data), aes_key)): raise Exception("License signature is invalid.") diff --git a/pyencrypt/loader.py b/pyencrypt/loader.py index e60e7d5..fed1a6d 100644 --- a/pyencrypt/loader.py +++ b/pyencrypt/loader.py @@ -23,15 +23,15 @@ def __dir__(self) -> Iterable[str]: class EncryptFileLoader(abc.SourceLoader, Base): POSSIBLE_PATH = [ - Path(os.path.expanduser('~')) / '.licenses' / 'license.lic', - Path(os.path.abspath(__file__)).parent / 'licenses' / 'license.lic', - Path(os.getcwd()) / 'licenses' / 'license.lic', + Path(os.path.expanduser("~")) / ".licenses" / "license.lic", + Path(os.path.abspath(__file__)).parent / "licenses" / "license.lic", + Path(os.getcwd()) / "licenses" / "license.lic", ] def __init__(self, path) -> None: - self.path = path or '' - self.__private_key = '' - self.__cipher_key = '' + self.path = path or "" + self.__private_key = "" + self.__cipher_key = "" self.license = None self.license_path = None self._init_license_path() @@ -50,10 +50,12 @@ def check(self) -> bool: return False if self.license_path is None: - raise Exception('Could not find license file.') + raise Exception("Could not find license file.") - __n, __d = self.__private_key.split('O', 1) - check_license(self.license_path, decrypt_key(self.__cipher_key, int(__d), int(__n))) + __n, __d = self.__private_key.split("O", 1) + check_license( + self.license_path, decrypt_key(self.__cipher_key, int(__d), int(__n)) + ) return True def get_filename(self, fullname: str) -> str: @@ -61,15 +63,19 @@ def get_filename(self, fullname: str) -> str: def get_data(self, path: _Path) -> bytes: try: - __n, __d = self.__private_key.split('O', 1) - return decrypt_file(Path(path), decrypt_key(self.__cipher_key, int(__d), int(__n))) + __n, __d = self.__private_key.split("O", 1) + return decrypt_file( + Path(path), decrypt_key(self.__cipher_key, int(__d), int(__n)) + ) except Exception: traceback.print_exc() - return b'' + return b"" class EncryptFileFinder(abc.MetaPathFinder, Base): - def find_spec(self, fullname: str, path: Sequence[_Path], target: types.ModuleType = None) -> ModuleSpec: + def find_spec( + self, fullname: str, path: Sequence[_Path], target: types.ModuleType = None + ) -> ModuleSpec: if path: if isinstance(path, _NamespacePath): file_path = Path(path._path[0]) / f'{fullname.rsplit(".",1)[-1]}.pye' @@ -77,14 +83,14 @@ def find_spec(self, fullname: str, path: Sequence[_Path], target: types.ModuleTy file_path = Path(path[0]) / f'{fullname.rsplit(".",1)[-1]}.pye' else: for p in sys.path: - file_path = Path(p) / f'{fullname}.pye' + file_path = Path(p) / f"{fullname}.pye" if file_path.exists(): break file_path = file_path.absolute().as_posix() if not os.path.exists(file_path): return None loader = EncryptFileLoader(file_path) - return spec_from_loader(name=fullname, loader=loader, origin='origin-encrypt') + return spec_from_loader(name=fullname, loader=loader, origin="origin-encrypt") # TODO: generate randomly AES Class diff --git a/tests/conftest.py b/tests/conftest.py index 87ff186..260ab0e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,43 +5,56 @@ def pytest_configure(config): - config.addinivalue_line("markers", "file(name,function,code): mark test to run only on named environment") - config.addinivalue_line("markers", "license(enable,kwargs): mark test to run only on named environment") - config.addinivalue_line("markers", "package(name,function,code): mark test to run only on named environment") - - -@pytest.fixture(scope='function') + config.addinivalue_line( + "markers", + "file(name,function,code): mark test to run only on named environment", + ) + config.addinivalue_line( + "markers", "license(enable,kwargs): mark test to run only on named environment" + ) + config.addinivalue_line( + "markers", + "package(name,function,code): mark test to run only on named environment", + ) + + +@pytest.fixture(scope="function") def file_and_loader(request, tmp_path_factory): - tmp_path = tmp_path_factory.mktemp('file') + tmp_path = tmp_path_factory.mktemp("file") file_marker = request.node.get_closest_marker("file") - file_name = file_marker.kwargs.get('name') - function_name = file_marker.kwargs.get('function') - code = file_marker.kwargs.get('code') + file_name = file_marker.kwargs.get("name") + function_name = file_marker.kwargs.get("function") + code = file_marker.kwargs.get("code") license_marker = request.node.get_closest_marker("license") license, kwargs = False, {} if license_marker is not None: kwargs = license_marker.kwargs - license = kwargs.pop('enable', True) + license = kwargs.pop("enable", True) - file_path = tmp_path / f'{file_name}.py' + file_path = tmp_path / f"{file_name}.py" file_path.touch() - file_path.write_text("""\ + file_path.write_text( + """\ def {function_name}(): {code} - """.format(function_name=function_name, code=code), encoding='utf-8') + """.format( + function_name=function_name, code=code + ), + encoding="utf-8", + ) # generate loader.so key = generate_aes_key() - new_path = file_path.with_suffix('.pye') + new_path = file_path.with_suffix(".pye") encrypt_file(file_path, key, new_path=new_path) file_path.unlink() cipher_key, d, n = encrypt_key(key) loader_path = generate_so_file(cipher_key, d, n, file_path.parent, license=license) work_dir = loader_path.parent - work_dir.joinpath('loader.py').unlink() - work_dir.joinpath('loader.c').unlink() - work_dir.joinpath('loader_origin.py').unlink() + work_dir.joinpath("loader.py").unlink() + work_dir.joinpath("loader.c").unlink() + work_dir.joinpath("loader_origin.py").unlink() # License license and generate_license_file(key.decode(), work_dir, **kwargs) @@ -49,34 +62,39 @@ def {function_name}(): return (new_path, loader_path) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def package_and_loader(request, tmp_path_factory): - pkg_path = tmp_path_factory.mktemp('package') + pkg_path = tmp_path_factory.mktemp("package") file_marker = request.node.get_closest_marker("package") - package_name = file_marker.kwargs.get('name') - function_name = file_marker.kwargs.get('function') - code = file_marker.kwargs.get('code') + package_name = file_marker.kwargs.get("name") + function_name = file_marker.kwargs.get("function") + code = file_marker.kwargs.get("code") license_marker = request.node.get_closest_marker("license") license, kwargs = False, {} if license_marker is not None: kwargs = license_marker.kwargs - license = kwargs.pop('enable', True) + license = kwargs.pop("enable", True) current = pkg_path - for dir_name in package_name.split('.')[:-1]: + for dir_name in package_name.split(".")[:-1]: current = current.joinpath(dir_name) current.mkdir() - current.joinpath('__init__.py').touch() + current.joinpath("__init__.py").touch() file_path = current.joinpath(f'{package_name.split(".")[-1]}.py') - file_path.write_text("""\ + file_path.write_text( + """\ def {function_name}(): {code} - """.format(function_name=function_name, code=code), encoding='utf-8') + """.format( + function_name=function_name, code=code + ), + encoding="utf-8", + ) - new_path = file_path.with_suffix('.pye') + new_path = file_path.with_suffix(".pye") key = generate_aes_key() encrypt_file(file_path, key, new_path=new_path) file_path.unlink() @@ -84,9 +102,9 @@ def {function_name}(): cipher_key, d, n = encrypt_key(key) loader_path = generate_so_file(cipher_key, d, n, pkg_path, license) work_dir = loader_path.parent - work_dir.joinpath('loader.py').unlink() - work_dir.joinpath('loader.c').unlink() - work_dir.joinpath('loader_origin.py').unlink() + work_dir.joinpath("loader.py").unlink() + work_dir.joinpath("loader.c").unlink() + work_dir.joinpath("loader_origin.py").unlink() # License license and generate_license_file(key.decode(), work_dir, **kwargs) return pkg_path, loader_path diff --git a/tests/constants.py b/tests/constants.py index 2f2678a..a9b5037 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -1 +1 @@ -AES_KEY = b'tiovaZzCF/Hlx/8uw0PXCkiqxCXGosP/AKo6Kn1gECw=' +AES_KEY = b"tiovaZzCF/Hlx/8uw0PXCkiqxCXGosP/AKo6Kn1gECw=" diff --git a/tests/test_aes.py b/tests/test_aes.py index 399e961..2598fdc 100644 --- a/tests/test_aes.py +++ b/tests/test_aes.py @@ -3,35 +3,40 @@ from constants import AES_KEY -PLAIN_1 = b'hello world' -CIPHER_1 = b'\xc5\xa1\xf8\xed\xf7\xa0\x03\xd8\xffu\x01\xac\x93\xcd+\xe1' -PLAIN_PADDING_1 = PLAIN_1 + b'\x05' * 5 +PLAIN_1 = b"hello world" +CIPHER_1 = b"\xc5\xa1\xf8\xed\xf7\xa0\x03\xd8\xffu\x01\xac\x93\xcd+\xe1" +PLAIN_PADDING_1 = PLAIN_1 + b"\x05" * 5 -PLAIN_2 = '你好 世界!'.encode() -CIPHER_2 = b'\t\xb6R0B\x1fgz\x06x\x9d\xaf\xb4\xe7_\x9f' -PLAIN_PADDING_2 = PLAIN_2 + b'\x02' * 2 +PLAIN_2 = "你好 世界!".encode() +CIPHER_2 = b"\t\xb6R0B\x1fgz\x06x\x9d\xaf\xb4\xe7_\x9f" +PLAIN_PADDING_2 = PLAIN_2 + b"\x02" * 2 -PLAIN_3 = b'abcdefghijklmnop' -CIPHER_3 = b'r\x14\xa7\x92\xd6\x1f\x0c\xf4\x10g\x99\t/\xf0z\xfc' + b'&\x80\xdb\x94\xd1\xf7\x9f\xe0Qo\x05\x98\x7f\xe6j\x8c' +PLAIN_3 = b"abcdefghijklmnop" +CIPHER_3 = ( + b"r\x14\xa7\x92\xd6\x1f\x0c\xf4\x10g\x99\t/\xf0z\xfc" + + b"&\x80\xdb\x94\xd1\xf7\x9f\xe0Qo\x05\x98\x7f\xe6j\x8c" +) class TestAES: @pytest.mark.parametrize( - 'key, plain, cipher', [ + "key, plain, cipher", + [ (AES_KEY, PLAIN_1, CIPHER_1), (AES_KEY, PLAIN_2, CIPHER_2), (AES_KEY, PLAIN_3, CIPHER_3), - ] + ], ) def test_aes_encrypt(self, plain, cipher, key): assert aes_encrypt(plain, key) == cipher @pytest.mark.parametrize( - 'key, plain, cipher', [ + "key, plain, cipher", + [ (AES_KEY, PLAIN_1, CIPHER_1), (AES_KEY, PLAIN_2, CIPHER_2), (AES_KEY, PLAIN_3, CIPHER_3), - ] + ], ) def test_aes_decrypt(self, plain, cipher, key): assert aes_decrypt(cipher, key) == plain @@ -41,28 +46,34 @@ class TestAESModeOfOperationECB: def setup_class(self): self.cipher = AESModeOfOperationECB(AES_KEY) - @pytest.mark.parametrize('length', [17, 18, 19, 20]) + @pytest.mark.parametrize("length", [17, 18, 19, 20]) def test_encrypt_exception(self, length): with pytest.raises(ValueError) as excinfo: - self.cipher.encrypt(b'a' * length) - assert str(excinfo.value) == 'plain block must be 16 bytes' + self.cipher.encrypt(b"a" * length) + assert str(excinfo.value) == "plain block must be 16 bytes" - @pytest.mark.parametrize('length', [17, 18, 19, 20]) + @pytest.mark.parametrize("length", [17, 18, 19, 20]) def test_decrypt_exception(self, length): with pytest.raises(ValueError) as excinfo: - self.cipher.decrypt(b'a' * length) - assert str(excinfo.value) == 'cipher block must be 16 bytes' + self.cipher.decrypt(b"a" * length) + assert str(excinfo.value) == "cipher block must be 16 bytes" - @pytest.mark.parametrize('plain, cipher', [ - (PLAIN_PADDING_1, CIPHER_1), - (PLAIN_PADDING_2, CIPHER_2), - ]) + @pytest.mark.parametrize( + "plain, cipher", + [ + (PLAIN_PADDING_1, CIPHER_1), + (PLAIN_PADDING_2, CIPHER_2), + ], + ) def test_encrypt(self, plain, cipher): assert self.cipher.encrypt(plain) == cipher - @pytest.mark.parametrize('plain, cipher', [ - (PLAIN_PADDING_1, CIPHER_1), - (PLAIN_PADDING_2, CIPHER_2), - ]) + @pytest.mark.parametrize( + "plain, cipher", + [ + (PLAIN_PADDING_1, CIPHER_1), + (PLAIN_PADDING_2, CIPHER_2), + ], + ) def test_decrypt(self, plain, cipher): assert self.cipher.decrypt(cipher) == plain diff --git a/tests/test_decrypt.py b/tests/test_decrypt.py index e5b9739..6a1756b 100644 --- a/tests/test_decrypt.py +++ b/tests/test_decrypt.py @@ -8,10 +8,13 @@ from constants import AES_KEY -@pytest.mark.parametrize('key', [ - AES_KEY, - generate_aes_key(), -]) +@pytest.mark.parametrize( + "key", + [ + AES_KEY, + generate_aes_key(), + ], +) def test_decrypt_key(key): cipher_key, d, n = encrypt_key(key) assert decrypt_key(cipher_key, d, n) == key.decode() @@ -19,20 +22,21 @@ def test_decrypt_key(key): @pytest.fixture def encrypted_python_file_path(tmp_path): - path = tmp_path / 'test.py' + path = tmp_path / "test.py" path.touch() path.write_text('print("hello world")') - new_path = tmp_path / 'test.pye' + new_path = tmp_path / "test.pye" encrypt_file(path, AES_KEY, new_path=new_path) path.unlink() return new_path @pytest.mark.parametrize( - 'path,key,exception', [ - (Path('tests/test.py'), AES_KEY, Exception), - (Path('tests/__init__.pye'), AES_KEY, FileNotFoundError), - ] + "path,key,exception", + [ + (Path("tests/test.py"), AES_KEY, Exception), + (Path("tests/__init__.pye"), AES_KEY, FileNotFoundError), + ], ) def test_decrypt_file_exception(path, key, exception): with pytest.raises(exception) as excinfo: @@ -51,14 +55,14 @@ def test_decrypt_file_delete_origin(encrypted_python_file_path): def test_decrypt_file_new_path(encrypted_python_file_path): - new_path = encrypted_python_file_path.parent / 'test.py' + new_path = encrypted_python_file_path.parent / "test.py" decrypt_file(encrypted_python_file_path, AES_KEY, new_path=new_path) assert new_path.exists() is True assert encrypted_python_file_path.exists() is True def test_decrypt_file_new_path_exception(encrypted_python_file_path): - new_path = encrypted_python_file_path.parent / 'test.pye' + new_path = encrypted_python_file_path.parent / "test.pye" with pytest.raises(Exception) as excinfo: decrypt_file(encrypted_python_file_path, AES_KEY, new_path=new_path) - assert str(excinfo.value) == 'Origin file path must be py suffix.' + assert str(excinfo.value) == "Origin file path must be py suffix." diff --git a/tests/test_encrypt.py b/tests/test_encrypt.py index fe20324..de1fbeb 100644 --- a/tests/test_encrypt.py +++ b/tests/test_encrypt.py @@ -10,10 +10,13 @@ from constants import AES_KEY -@pytest.mark.parametrize('key', [ - AES_KEY, - generate_aes_key(), -]) +@pytest.mark.parametrize( + "key", + [ + AES_KEY, + generate_aes_key(), + ], +) def test_encrypt_key(key): cipher, d, n = encrypt_key(key) assert isinstance(cipher, str) @@ -22,13 +25,14 @@ def test_encrypt_key(key): @pytest.mark.parametrize( - 'path,expected', [ - (Path('__init__.py'), False), - (Path('pyencrypt/__init__.py'), False), - (Path('management/commands/user.py'), False), - (Path('tests/test.pye'), False), - (Path('tests/test_encrypt.py'), True), - ] + "path,expected", + [ + (Path("__init__.py"), False), + (Path("pyencrypt/__init__.py"), False), + (Path("management/commands/user.py"), False), + (Path("tests/test.pye"), False), + (Path("tests/test_encrypt.py"), True), + ], ) def test_can_encrypt(path, expected): assert can_encrypt(path) == expected @@ -36,45 +40,68 @@ def test_can_encrypt(path, expected): class TestGenarateSoFile: def setup_method(self, method): - if method.__name__ == 'test_generate_so_file_default_path': - shutil.rmtree((Path(os.getcwd()) / 'encrypted').as_posix(), ignore_errors=True) - - @pytest.mark.parametrize('key', [ - AES_KEY, - generate_aes_key(), - ]) + if method.__name__ == "test_generate_so_file_default_path": + shutil.rmtree( + (Path(os.getcwd()) / "encrypted").as_posix(), ignore_errors=True + ) + + @pytest.mark.parametrize( + "key", + [ + AES_KEY, + generate_aes_key(), + ], + ) def test_generate_so_file(self, key, tmp_path): cipher_key, d, n = encrypt_key(key) assert generate_so_file(cipher_key, d, n, tmp_path) - assert (tmp_path / 'encrypted' / 'loader.py').exists() is True - assert (tmp_path / 'encrypted' / 'loader_origin.py').exists() is True - if sys.platform.startswith('win'): - assert next((tmp_path / 'encrypted').glob('loader.cp*-*.pyd'), None) is not None + assert (tmp_path / "encrypted" / "loader.py").exists() is True + assert (tmp_path / "encrypted" / "loader_origin.py").exists() is True + if sys.platform.startswith("win"): + assert ( + next((tmp_path / "encrypted").glob("loader.cp*-*.pyd"), None) + is not None + ) else: - assert next((tmp_path / 'encrypted').glob('loader.cpython-*-*.so'), None) is not None - - @pytest.mark.parametrize('key', [ - AES_KEY, - generate_aes_key(), - ]) + assert ( + next((tmp_path / "encrypted").glob("loader.cpython-*-*.so"), None) + is not None + ) + + @pytest.mark.parametrize( + "key", + [ + AES_KEY, + generate_aes_key(), + ], + ) def test_generate_so_file_default_path(self, key): cipher_key, d, n = encrypt_key(key) assert generate_so_file(cipher_key, d, n) - assert (Path(os.getcwd()) / 'encrypted' / 'loader.py').exists() is True - assert (Path(os.getcwd()) / 'encrypted' / 'loader_origin.py').exists() is True - if sys.platform.startswith('win'): - assert next((Path(os.getcwd()) / 'encrypted').glob('loader.cp*-*.pyd'), None) is not None + assert (Path(os.getcwd()) / "encrypted" / "loader.py").exists() is True + assert (Path(os.getcwd()) / "encrypted" / "loader_origin.py").exists() is True + if sys.platform.startswith("win"): + assert ( + next((Path(os.getcwd()) / "encrypted").glob("loader.cp*-*.pyd"), None) + is not None + ) else: - assert next((Path(os.getcwd()) / 'encrypted').glob('loader.cpython-*-*.so'), None) is not None + assert ( + next( + (Path(os.getcwd()) / "encrypted").glob("loader.cpython-*-*.so"), + None, + ) + is not None + ) @pytest.mark.parametrize( - 'path,key,exception', + "path,key,exception", [ - (Path('tests/test.py'), AES_KEY, FileNotFoundError), - (Path('tests/test.pye'), AES_KEY, Exception), # TODO: 封装Exception - (Path('tests/__init__.py'), AES_KEY, Exception), - ] + (Path("tests/test.py"), AES_KEY, FileNotFoundError), + (Path("tests/test.pye"), AES_KEY, Exception), # TODO: 封装Exception + (Path("tests/__init__.py"), AES_KEY, Exception), + ], ) def test_encrypt_file_path_exception(path, key, exception): with pytest.raises(exception) as excinfo: @@ -84,7 +111,7 @@ def test_encrypt_file_path_exception(path, key, exception): @pytest.fixture def python_file_path(tmp_path): - fn = tmp_path / 'test.py' + fn = tmp_path / "test.py" fn.touch() fn.write_text('print("hello world")') return fn @@ -101,14 +128,14 @@ def test_encrypt_file_delete_origin(python_file_path): def test_encrypt_file_new_path(python_file_path): - new_path = python_file_path.parent / 'test.pye' + new_path = python_file_path.parent / "test.pye" encrypt_file(python_file_path, AES_KEY, new_path=new_path) assert new_path.exists() is True assert python_file_path.exists() is True def test_encrypt_file_new_path_exception(python_file_path): - new_path = python_file_path.parent / 'test.py' + new_path = python_file_path.parent / "test.py" with pytest.raises(Exception) as excinfo: encrypt_file(python_file_path, AES_KEY, new_path=new_path) - assert str(excinfo.value) == 'Encrypted file path must be pye suffix.' + assert str(excinfo.value) == "Encrypted file path must be pye suffix." diff --git a/tests/test_generate.py b/tests/test_generate.py index 13ea11e..ceff09c 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -7,16 +7,16 @@ def test_generate_aes_key_default(): assert isinstance(generate_aes_key(), bytes) -@pytest.mark.parametrize('size', [32, 64, 1024, 4096]) +@pytest.mark.parametrize("size", [32, 64, 1024, 4096]) def test_generate_aes_key(size): assert isinstance(generate_aes_key(size), bytes) -@pytest.mark.parametrize('bits', [1024, 1025, 2045, 2048, 4096]) +@pytest.mark.parametrize("bits", [1024, 1025, 2045, 2048, 4096]) def test_generate_rsa_number(bits): numbers = generate_rsa_number(bits) assert len(numbers) == 5 - p, q, n, e, d = numbers['p'], numbers['q'], numbers['n'], numbers['e'], numbers['d'] + p, q, n, e, d = numbers["p"], numbers["q"], numbers["n"], numbers["e"], numbers["d"] assert p * q == n assert e * d % (p - 1) == 1 assert e * d % (q - 1) == 1 @@ -24,7 +24,7 @@ def test_generate_rsa_number(bits): assert pow(pow(plain, e, n), d, n) == plain -@pytest.mark.parametrize('bits', [-1, 123]) +@pytest.mark.parametrize("bits", [-1, 123]) def test_generate_rsa_number_exception(bits): with pytest.raises(ValueError) as excinfo: generate_rsa_number(bits) diff --git a/tests/test_license.py b/tests/test_license.py index df5db90..c137778 100644 --- a/tests/test_license.py +++ b/tests/test_license.py @@ -7,7 +7,13 @@ import pytest from pyencrypt.generate import generate_aes_key -from pyencrypt.license import (FIELDS, check_license, generate_license_file, get_host_ipv4, get_mac_address) +from pyencrypt.license import ( + FIELDS, + check_license, + generate_license_file, + get_host_ipv4, + get_mac_address, +) from constants import AES_KEY @@ -15,7 +21,10 @@ def test_get_mac_address(): mac_address = get_mac_address() assert mac_address is not None - assert re.match(r"^\s*([0-9a-fA-F]{2,2}:){5,5}[0-9a-fA-F]{2,2}\s*$", mac_address) is not None + assert ( + re.match(r"^\s*([0-9a-fA-F]{2,2}:){5,5}[0-9a-fA-F]{2,2}\s*$", mac_address) + is not None + ) def test_get_host_ipv4(): @@ -26,38 +35,45 @@ def test_get_host_ipv4(): class TestGenerateLicense: def setup_method(self, method): - self.fields = FIELDS + ['signature'] - shutil.rmtree((Path(os.getcwd()) / 'licenses').as_posix(), ignore_errors=True) + self.fields = FIELDS + ["signature"] + shutil.rmtree((Path(os.getcwd()) / "licenses").as_posix(), ignore_errors=True) def teardown_method(self, method): - shutil.rmtree((Path(os.getcwd()) / 'licenses').as_posix(), ignore_errors=True) + shutil.rmtree((Path(os.getcwd()) / "licenses").as_posix(), ignore_errors=True) - @pytest.mark.parametrize('key', [ - AES_KEY, - generate_aes_key(), - ]) + @pytest.mark.parametrize( + "key", + [ + AES_KEY, + generate_aes_key(), + ], + ) def test_generate_license_file_default_path(self, key): license_file_path = generate_license_file(key) assert license_file_path.exists() is True license_data = json.loads(license_file_path.read_text()) assert set(self.fields) - set(license_data.keys()) == set() - assert license_data['mac'] is None - assert license_data['ipv4'] is None + assert license_data["mac"] is None + assert license_data["ipv4"] is None - @pytest.mark.parametrize('key', [ - AES_KEY, - generate_aes_key(), - ]) + @pytest.mark.parametrize( + "key", + [ + AES_KEY, + generate_aes_key(), + ], + ) def test_generate_license_file(self, key, tmp_path): license_file_path = generate_license_file(key, path=tmp_path) assert license_file_path.exists() is True license_data = json.loads(license_file_path.read_text()) assert set(self.fields) - set(license_data.keys()) == set() - assert license_data['mac'] is None - assert license_data['ipv4'] is None + assert license_data["mac"] is None + assert license_data["ipv4"] is None @pytest.mark.parametrize( - 'key,after,before,mac_addr,ipv4', [ + "key,after,before,mac_addr,ipv4", + [ (AES_KEY, None, None, None, None), (AES_KEY, None, datetime(2022, 1, 1), None, None), (AES_KEY, None, datetime(2022, 1, 1).astimezone(), None, None), @@ -65,7 +81,7 @@ def test_generate_license_file(self, key, tmp_path): (AES_KEY, datetime(2222, 1, 1).astimezone(), None, None, None), (AES_KEY, None, None, get_mac_address(), None), (AES_KEY, None, None, None, get_host_ipv4()), - ] + ], ) def test_check_license(self, key, after, before, mac_addr, ipv4, tmp_path): license_file_path = generate_license_file( @@ -73,49 +89,61 @@ def test_check_license(self, key, after, before, mac_addr, ipv4, tmp_path): ) assert check_license(license_file_path, key) is True - @pytest.mark.parametrize('key', [ - AES_KEY, - generate_aes_key(), - ]) + @pytest.mark.parametrize( + "key", + [ + AES_KEY, + generate_aes_key(), + ], + ) def test_check_license_invalid(self, key, tmp_path): license_file_path = generate_license_file(key, path=tmp_path) license_data = json.loads(license_file_path.read_text()) - license_data['signature'] = 'invalid' - license_file_path.write_text(json.dumps(license_data), encoding='utf-8') + license_data["signature"] = "invalid" + license_file_path.write_text(json.dumps(license_data), encoding="utf-8") with pytest.raises(Exception) as excinfo: check_license(license_file_path, key) - assert str(excinfo.value) == 'License signature is invalid.' + assert str(excinfo.value) == "License signature is invalid." @pytest.mark.parametrize( - 'key,after,before', [ + "key,after,before", + [ (AES_KEY, datetime(2000, 1, 1), None), (generate_aes_key(), datetime(2000, 1, 1), None), (AES_KEY, None, datetime(2222, 1, 1)), (generate_aes_key(), None, datetime(2222, 1, 1)), - ] + ], ) def test_check_license_expired(self, key, after, before, tmp_path): - license_file_path = generate_license_file(key, path=tmp_path, after=after, before=before) + license_file_path = generate_license_file( + key, path=tmp_path, after=after, before=before + ) with pytest.raises(Exception) as excinfo: check_license(license_file_path, key) - assert str(excinfo.value) == 'License expired.' + assert str(excinfo.value) == "License expired." - @pytest.mark.parametrize('key,mac_addr', [ - (AES_KEY, 'invalid mac address'), - (generate_aes_key(), 'invalid mac address'), - ]) + @pytest.mark.parametrize( + "key,mac_addr", + [ + (AES_KEY, "invalid mac address"), + (generate_aes_key(), "invalid mac address"), + ], + ) def test_check_license_mac_addr(self, key, mac_addr, tmp_path): license_file_path = generate_license_file(key, path=tmp_path, mac_addr=mac_addr) with pytest.raises(Exception) as excinfo: check_license(license_file_path, key) - assert str(excinfo.value) == 'Machine mac address is invalid.' + assert str(excinfo.value) == "Machine mac address is invalid." - @pytest.mark.parametrize('key,ipv4', [ - (AES_KEY, 'invalid ipv4 address'), - (generate_aes_key(), 'invalid ipv4 address'), - ]) + @pytest.mark.parametrize( + "key,ipv4", + [ + (AES_KEY, "invalid ipv4 address"), + (generate_aes_key(), "invalid ipv4 address"), + ], + ) def test_check_license_ipv4(self, key, ipv4, tmp_path): license_file_path = generate_license_file(key, path=tmp_path, ipv4=ipv4) with pytest.raises(Exception) as excinfo: check_license(license_file_path, key) - assert str(excinfo.value) == 'Machine ipv4 address is invalid.' + assert str(excinfo.value) == "Machine ipv4 address is invalid." diff --git a/tests/test_loader.py b/tests/test_loader.py index 635af40..bc97a61 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -8,121 +8,147 @@ DEAFULT_META_PATH = sys.meta_path[::] -@pytest.mark.file(name='file1', function='test_file_1', code='\treturn "This is file test1"') +@pytest.mark.file( + name="file1", function="test_file_1", code='\treturn "This is file test1"' +) def test_python_file_sys_path(file_and_loader: Tuple[Path], monkeypatch): file_path, loader_path = file_and_loader monkeypatch.syspath_prepend(file_path.parent.as_posix()) monkeypatch.syspath_prepend(loader_path.parent.as_posix()) - sys.modules.pop('loader', None) + sys.modules.pop("loader", None) sys.meta_path = DEAFULT_META_PATH.copy() import loader from file1 import test_file_1 - assert test_file_1() == 'This is file test1' + assert test_file_1() == "This is file test1" @pytest.mark.license(enable=True) -@pytest.mark.file(name='file2', function='test_file_2', code='\treturn "This is file test2"') +@pytest.mark.file( + name="file2", function="test_file_2", code='\treturn "This is file test2"' +) def test_python_file_sys_path_with_license(file_and_loader: Tuple[Path], monkeypatch): file_path, loader_path = file_and_loader monkeypatch.syspath_prepend(file_path.parent.as_posix()) monkeypatch.syspath_prepend(loader_path.parent.as_posix()) - sys.modules.pop('loader', None) + sys.modules.pop("loader", None) sys.meta_path = DEAFULT_META_PATH.copy() import loader from file2 import test_file_2 - assert test_file_2() == 'This is file test2' + assert test_file_2() == "This is file test2" @pytest.mark.license(enable=True) -@pytest.mark.file(name='file3', function='test_file_3', code='\treturn "This is file test3"') -def test_python_file_sys_path_with_license_not_found(file_and_loader: Tuple[Path], monkeypatch): +@pytest.mark.file( + name="file3", function="test_file_3", code='\treturn "This is file test3"' +) +def test_python_file_sys_path_with_license_not_found( + file_and_loader: Tuple[Path], monkeypatch +): file_path, loader_path = file_and_loader monkeypatch.syspath_prepend(file_path.parent.as_posix()) monkeypatch.syspath_prepend(loader_path.parent.as_posix()) - shutil.rmtree(loader_path.parent / 'licenses') + shutil.rmtree(loader_path.parent / "licenses") with pytest.raises(Exception) as excinfo: - sys.modules.pop('loader', None) + sys.modules.pop("loader", None) sys.meta_path = DEAFULT_META_PATH.copy() import loader from file3 import test_file_3 - assert test_file_3() == 'This is file test3' + assert test_file_3() == "This is file test3" - assert str(excinfo.value) == 'Could not find license file.' + assert str(excinfo.value) == "Could not find license file." # Package -@pytest.mark.package(name='pkg1.a.b.c', function='test_package_1', code='\treturn "This is package test1"') +@pytest.mark.package( + name="pkg1.a.b.c", + function="test_package_1", + code='\treturn "This is package test1"', +) def test_python_package(package_and_loader: Tuple[Path], monkeypatch): package_path, loader_path = package_and_loader monkeypatch.syspath_prepend(package_path.as_posix()) monkeypatch.syspath_prepend(loader_path.parent.as_posix()) - sys.modules.pop('loader', None) + sys.modules.pop("loader", None) sys.meta_path = DEAFULT_META_PATH.copy() import loader from pkg1.a.b.c import test_package_1 - assert test_package_1() == 'This is package test1' + assert test_package_1() == "This is package test1" -@pytest.mark.package(name='pkg2.a.b.c', function='test_package_2', code='\treturn "This is package test2"') +@pytest.mark.package( + name="pkg2.a.b.c", + function="test_package_2", + code='\treturn "This is package test2"', +) def test_python_package_without_init_file(package_and_loader: Tuple[Path], monkeypatch): package_path, loader_path = package_and_loader monkeypatch.syspath_prepend(package_path.as_posix()) monkeypatch.syspath_prepend(loader_path.parent.as_posix()) - for file in package_path.glob('**/__init__.py'): + for file in package_path.glob("**/__init__.py"): file.unlink() - sys.modules.pop('loader', None) + sys.modules.pop("loader", None) sys.meta_path = DEAFULT_META_PATH.copy() import loader from pkg2.a.b.c import test_package_2 - assert test_package_2() == 'This is package test2' + assert test_package_2() == "This is package test2" @pytest.mark.license(enable=True) -@pytest.mark.package(name='pkg3.a.b.c', function='test_package_3', code='\treturn "This is package test3"') +@pytest.mark.package( + name="pkg3.a.b.c", + function="test_package_3", + code='\treturn "This is package test3"', +) def test_python_package_with_license(package_and_loader: Tuple[Path], monkeypatch): package_path, loader_path = package_and_loader monkeypatch.syspath_prepend(package_path.as_posix()) monkeypatch.syspath_prepend(loader_path.parent.as_posix()) - sys.modules.pop('loader', None) + sys.modules.pop("loader", None) sys.meta_path = DEAFULT_META_PATH.copy() import loader from pkg3.a.b.c import test_package_3 - assert test_package_3() == 'This is package test3' + assert test_package_3() == "This is package test3" @pytest.mark.license(enable=True) -@pytest.mark.package(name='pkg4.a.b.c', function='test_package_4', code='\treturn "This is package test4"') -def test_python_package_with_license_not_found(package_and_loader: Tuple[Path], monkeypatch): +@pytest.mark.package( + name="pkg4.a.b.c", + function="test_package_4", + code='\treturn "This is package test4"', +) +def test_python_package_with_license_not_found( + package_and_loader: Tuple[Path], monkeypatch +): package_path, loader_path = package_and_loader monkeypatch.syspath_prepend(package_path.as_posix()) monkeypatch.syspath_prepend(loader_path.parent.as_posix()) - shutil.rmtree(loader_path.parent.joinpath('licenses')) + shutil.rmtree(loader_path.parent.joinpath("licenses")) with pytest.raises(Exception) as excinfo: - sys.modules.pop('loader', None) + sys.modules.pop("loader", None) sys.meta_path = DEAFULT_META_PATH.copy() import loader from pkg4.a.b.c import test_package_4 - assert test_package_4() == 'This is package test4' - assert str(excinfo.value) == 'Could not find license file.' + assert test_package_4() == "This is package test4" + assert str(excinfo.value) == "Could not find license file." diff --git a/tests/test_ntt.py b/tests/test_ntt.py index f2cd880..df35d8a 100644 --- a/tests/test_ntt.py +++ b/tests/test_ntt.py @@ -6,55 +6,93 @@ class TestNtt: @pytest.mark.parametrize( - "input,expected", [ + "input,expected", + [ ([1, 2, 3, 4], [10, 173167434, 998244351, 825076915]), - ([1, 2, 3, 4, 5, 0, 0, 0], [15, 443713764, 173167439, 730825730, 3, 35028273, 825076920, 786920923]), - ] + ( + [1, 2, 3, 4, 5, 0, 0, 0], + [ + 15, + 443713764, + 173167439, + 730825730, + 3, + 35028273, + 825076920, + 786920923, + ], + ), + ], ) def test_ntt(self, input, expected): assert ntt(input) == expected @pytest.mark.parametrize( - "input,expected", [ + "input,expected", + [ ([10, 173167434, 998244351, 825076915], [1, 2, 3, 4]), - ([15, 443713764, 173167439, 730825730, 3, 35028273, 825076920, 786920923], [1, 2, 3, 4, 5, 0, 0, 0]), - ] + ( + [ + 15, + 443713764, + 173167439, + 730825730, + 3, + 35028273, + 825076920, + 786920923, + ], + [1, 2, 3, 4, 5, 0, 0, 0], + ), + ], ) def test_ntt_inverse(self, input, expected): assert intt(input) == expected - @pytest.mark.parametrize("input", [ - [1, 2, 3, 4, 5], - random_list(6), - random_list(7), - ]) + @pytest.mark.parametrize( + "input", + [ + [1, 2, 3, 4, 5], + random_list(6), + random_list(7), + ], + ) def test_ntt_exception(self, input): with pytest.raises(ValueError) as excinfo: ntt(input) assert str(excinfo.value) == "The length of input must be a power of 2." - @pytest.mark.parametrize("input", [ - [1, 2, 3, 4, 5], - random_list(6), - random_list(7), - ]) + @pytest.mark.parametrize( + "input", + [ + [1, 2, 3, 4, 5], + random_list(6), + random_list(7), + ], + ) def test_intt_exception(self, input): with pytest.raises(ValueError) as excinfo: intt(input) assert str(excinfo.value) == "The length of input must be a power of 2." - @pytest.mark.parametrize("input,expected", [ - (random_list(4), 4), - (random_list(8), 8), - (random_list(16), 16), - ]) + @pytest.mark.parametrize( + "input,expected", + [ + (random_list(4), 4), + (random_list(8), 8), + (random_list(16), 16), + ], + ) def test_ntt_result_length(self, input, expected): assert len(ntt(input)) == expected - @pytest.mark.parametrize("input,expected", [ - (random_list(4), 4), - (random_list(8), 8), - (random_list(16), 16), - ]) + @pytest.mark.parametrize( + "input,expected", + [ + (random_list(4), 4), + (random_list(8), 8), + (random_list(16), 16), + ], + ) def test_intt_result_length(self, input, expected): assert len(intt(input)) == expected