diff --git a/tests/conftest.py b/tests/conftest.py index 34d2a04..8457976 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,6 @@ +import shutil +from pathlib import Path + import pytest from pyencrypt.encrypt import encrypt_file, encrypt_key, generate_so_file @@ -19,21 +22,59 @@ def pytest_configure(config): ) -@pytest.fixture(scope="function") -def file_and_loader(request, tmp_path_factory): - tmp_path = tmp_path_factory.mktemp("file") +def _common_loader(path, license=False): + open("tmp", "a+").write(f"common:path: {path}" + "\n") + + key = generate_aes_key() + cipher_key, d, n = encrypt_key(key) + loader_path = generate_so_file(cipher_key, d, n, path, license=license) + + work_dir = loader_path.parent + work_dir.joinpath("loader.py").unlink() + work_dir.joinpath("loader.c").unlink() + work_dir.joinpath("loader_origin.py").unlink() + return key, loader_path.absolute() - file_marker = request.node.get_closest_marker("file") - file_name = file_marker.kwargs.get("name") - function_name = file_marker.kwargs.get("function") - code = file_marker.kwargs.get("code") +@pytest.fixture(scope="session") +def common_loader(tmp_path_factory): + open("tmp", "a+").write("make loader" + "\n") + tmp_path = tmp_path_factory.mktemp("loader") + return _common_loader(tmp_path, False) + + +@pytest.fixture(scope="session") +def common_loader_with_license(tmp_path_factory): + tmp_path = tmp_path_factory.mktemp("loader_with_license") + return _common_loader(tmp_path, True) + + +@pytest.fixture(scope="function") +def file_and_loader(request, common_loader, common_loader_with_license, tmp_path): license_marker = request.node.get_closest_marker("license") license, kwargs = False, {} if license_marker is not None: kwargs = license_marker.kwargs license = kwargs.pop("enable", True) + if license: + key, loader_path = common_loader_with_license + else: + key, loader_path = common_loader + + # copy loader -> tmp_path + loader_path = ( + Path(shutil.copytree(loader_path.parent, tmp_path / "encrypted")) + / loader_path.name + ) + if license: + generate_license_file(key.decode(), loader_path.parent, **kwargs) + + file_marker = request.node.get_closest_marker("file") + file_name = file_marker.kwargs.get("name") + function_name = file_marker.kwargs.get("function") + code = file_marker.kwargs.get("code") + file_path = tmp_path / f"{file_name}.py" file_path.touch() file_path.write_text( @@ -45,31 +86,17 @@ def {function_name}(): ), encoding="utf-8", ) - # generate loader.so - key = generate_aes_key() + new_path = file_path.with_suffix(".pye") encrypt_file(file_path, key.decode(), new_path=new_path) file_path.unlink() - cipher_key, d, n = encrypt_key(key) - loader_path = generate_so_file(cipher_key, d, n, file_path.parent, license=license) - work_dir = loader_path.parent - work_dir.joinpath("loader.py").unlink() - work_dir.joinpath("loader.c").unlink() - work_dir.joinpath("loader_origin.py").unlink() - # License - license and generate_license_file(key.decode(), work_dir, **kwargs) return (new_path, loader_path) @pytest.fixture(scope="function") -def package_and_loader(request, tmp_path_factory): - pkg_path = tmp_path_factory.mktemp("package") - - file_marker = request.node.get_closest_marker("package") - package_name = file_marker.kwargs.get("name") - function_name = file_marker.kwargs.get("function") - code = file_marker.kwargs.get("code") +def package_and_loader(request, common_loader, common_loader_with_license, tmp_path): + pkg_path = tmp_path license_marker = request.node.get_closest_marker("license") license, kwargs = False, {} @@ -77,6 +104,25 @@ def package_and_loader(request, tmp_path_factory): kwargs = license_marker.kwargs license = kwargs.pop("enable", True) + if license: + key, loader_path = common_loader_with_license + else: + key, loader_path = common_loader + + # copy loader -> tmp_path + loader_path = ( + Path(shutil.copytree(loader_path.parent, tmp_path / "encrypted")) + / loader_path.name + ) + + if license: + generate_license_file(key.decode(), loader_path.parent, **kwargs) + + file_marker = request.node.get_closest_marker("package") + package_name = file_marker.kwargs.get("name") + function_name = file_marker.kwargs.get("function") + code = file_marker.kwargs.get("code") + current = pkg_path for dir_name in package_name.split(".")[:-1]: current = current.joinpath(dir_name) @@ -95,16 +141,7 @@ def {function_name}(): ) new_path = file_path.with_suffix(".pye") - key = generate_aes_key() - encrypt_file(file_path, key, new_path=new_path) + encrypt_file(file_path, key.decode(), new_path=new_path) file_path.unlink() - cipher_key, d, n = encrypt_key(key) - loader_path = generate_so_file(cipher_key, d, n, pkg_path, license) - work_dir = loader_path.parent - work_dir.joinpath("loader.py").unlink() - work_dir.joinpath("loader.c").unlink() - work_dir.joinpath("loader_origin.py").unlink() - # License - license and generate_license_file(key.decode(), work_dir, **kwargs) return pkg_path, loader_path diff --git a/tests/test_encrypt.py b/tests/test_encrypt.py index f2dd9ed..6aef4c0 100644 --- a/tests/test_encrypt.py +++ b/tests/test_encrypt.py @@ -37,7 +37,7 @@ def test_can_encrypt(path, expected): assert can_encrypt(path) == expected -class TestGenarateSoFile: +class TestGenerateSoFile: def setup_method(self, method): if method.__name__ == "test_generate_so_file_default_path": shutil.rmtree( @@ -53,7 +53,7 @@ def setup_method(self, method): ) def test_generate_so_file(self, key, tmp_path): cipher_key, d, n = encrypt_key(key) - assert generate_so_file(cipher_key, d, n, tmp_path) + assert generate_so_file(cipher_key, d, n, tmp_path).exists() assert (tmp_path / "encrypted" / "loader.py").exists() is True assert (tmp_path / "encrypted" / "loader_origin.py").exists() is True if sys.platform.startswith("win"): @@ -76,7 +76,7 @@ def test_generate_so_file(self, key, tmp_path): ) def test_generate_so_file_default_path(self, key): cipher_key, d, n = encrypt_key(key) - assert generate_so_file(cipher_key, d, n) + assert generate_so_file(cipher_key, d, n).exists() assert (Path(os.getcwd()) / "encrypted" / "loader.py").exists() is True assert (Path(os.getcwd()) / "encrypted" / "loader_origin.py").exists() is True if sys.platform.startswith("win"):