Skip to content

Commit

Permalink
test: πŸ’ Speed up unintest (#55)
Browse files Browse the repository at this point in the history
* test: πŸ’ speed up test_loader

* test: πŸ’ typo
  • Loading branch information
ZhaoQi99 authored Dec 3, 2024
1 parent aa52595 commit c82108a
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 37 deletions.
105 changes: 71 additions & 34 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import shutil
from pathlib import Path

import pytest

from pyencrypt.encrypt import encrypt_file, encrypt_key, generate_so_file
Expand All @@ -19,21 +22,59 @@ def pytest_configure(config):
)


@pytest.fixture(scope="function")
def file_and_loader(request, tmp_path_factory):
tmp_path = tmp_path_factory.mktemp("file")
def _common_loader(path, license=False):
open("tmp", "a+").write(f"common:path: {path}" + "\n")

key = generate_aes_key()
cipher_key, d, n = encrypt_key(key)
loader_path = generate_so_file(cipher_key, d, n, path, license=license)

work_dir = loader_path.parent
work_dir.joinpath("loader.py").unlink()
work_dir.joinpath("loader.c").unlink()
work_dir.joinpath("loader_origin.py").unlink()
return key, loader_path.absolute()

file_marker = request.node.get_closest_marker("file")
file_name = file_marker.kwargs.get("name")
function_name = file_marker.kwargs.get("function")
code = file_marker.kwargs.get("code")

@pytest.fixture(scope="session")
def common_loader(tmp_path_factory):
open("tmp", "a+").write("make loader" + "\n")
tmp_path = tmp_path_factory.mktemp("loader")
return _common_loader(tmp_path, False)


@pytest.fixture(scope="session")
def common_loader_with_license(tmp_path_factory):
tmp_path = tmp_path_factory.mktemp("loader_with_license")
return _common_loader(tmp_path, True)


@pytest.fixture(scope="function")
def file_and_loader(request, common_loader, common_loader_with_license, tmp_path):
license_marker = request.node.get_closest_marker("license")
license, kwargs = False, {}
if license_marker is not None:
kwargs = license_marker.kwargs
license = kwargs.pop("enable", True)

if license:
key, loader_path = common_loader_with_license
else:
key, loader_path = common_loader

# copy loader -> tmp_path
loader_path = (
Path(shutil.copytree(loader_path.parent, tmp_path / "encrypted"))
/ loader_path.name
)
if license:
generate_license_file(key.decode(), loader_path.parent, **kwargs)

file_marker = request.node.get_closest_marker("file")
file_name = file_marker.kwargs.get("name")
function_name = file_marker.kwargs.get("function")
code = file_marker.kwargs.get("code")

file_path = tmp_path / f"{file_name}.py"
file_path.touch()
file_path.write_text(
Expand All @@ -45,38 +86,43 @@ def {function_name}():
),
encoding="utf-8",
)
# generate loader.so
key = generate_aes_key()

new_path = file_path.with_suffix(".pye")
encrypt_file(file_path, key.decode(), new_path=new_path)
file_path.unlink()
cipher_key, d, n = encrypt_key(key)
loader_path = generate_so_file(cipher_key, d, n, file_path.parent, license=license)
work_dir = loader_path.parent
work_dir.joinpath("loader.py").unlink()
work_dir.joinpath("loader.c").unlink()
work_dir.joinpath("loader_origin.py").unlink()

# License
license and generate_license_file(key.decode(), work_dir, **kwargs)
return (new_path, loader_path)


@pytest.fixture(scope="function")
def package_and_loader(request, tmp_path_factory):
pkg_path = tmp_path_factory.mktemp("package")

file_marker = request.node.get_closest_marker("package")
package_name = file_marker.kwargs.get("name")
function_name = file_marker.kwargs.get("function")
code = file_marker.kwargs.get("code")
def package_and_loader(request, common_loader, common_loader_with_license, tmp_path):
pkg_path = tmp_path

license_marker = request.node.get_closest_marker("license")
license, kwargs = False, {}
if license_marker is not None:
kwargs = license_marker.kwargs
license = kwargs.pop("enable", True)

if license:
key, loader_path = common_loader_with_license
else:
key, loader_path = common_loader

# copy loader -> tmp_path
loader_path = (
Path(shutil.copytree(loader_path.parent, tmp_path / "encrypted"))
/ loader_path.name
)

if license:
generate_license_file(key.decode(), loader_path.parent, **kwargs)

file_marker = request.node.get_closest_marker("package")
package_name = file_marker.kwargs.get("name")
function_name = file_marker.kwargs.get("function")
code = file_marker.kwargs.get("code")

current = pkg_path
for dir_name in package_name.split(".")[:-1]:
current = current.joinpath(dir_name)
Expand All @@ -95,16 +141,7 @@ def {function_name}():
)

new_path = file_path.with_suffix(".pye")
key = generate_aes_key()
encrypt_file(file_path, key, new_path=new_path)
encrypt_file(file_path, key.decode(), new_path=new_path)
file_path.unlink()

cipher_key, d, n = encrypt_key(key)
loader_path = generate_so_file(cipher_key, d, n, pkg_path, license)
work_dir = loader_path.parent
work_dir.joinpath("loader.py").unlink()
work_dir.joinpath("loader.c").unlink()
work_dir.joinpath("loader_origin.py").unlink()
# License
license and generate_license_file(key.decode(), work_dir, **kwargs)
return pkg_path, loader_path
6 changes: 3 additions & 3 deletions tests/test_encrypt.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_can_encrypt(path, expected):
assert can_encrypt(path) == expected


class TestGenarateSoFile:
class TestGenerateSoFile:
def setup_method(self, method):
if method.__name__ == "test_generate_so_file_default_path":
shutil.rmtree(
Expand All @@ -53,7 +53,7 @@ def setup_method(self, method):
)
def test_generate_so_file(self, key, tmp_path):
cipher_key, d, n = encrypt_key(key)
assert generate_so_file(cipher_key, d, n, tmp_path)
assert generate_so_file(cipher_key, d, n, tmp_path).exists()
assert (tmp_path / "encrypted" / "loader.py").exists() is True
assert (tmp_path / "encrypted" / "loader_origin.py").exists() is True
if sys.platform.startswith("win"):
Expand All @@ -76,7 +76,7 @@ def test_generate_so_file(self, key, tmp_path):
)
def test_generate_so_file_default_path(self, key):
cipher_key, d, n = encrypt_key(key)
assert generate_so_file(cipher_key, d, n)
assert generate_so_file(cipher_key, d, n).exists()
assert (Path(os.getcwd()) / "encrypted" / "loader.py").exists() is True
assert (Path(os.getcwd()) / "encrypted" / "loader_origin.py").exists() is True
if sys.platform.startswith("win"):
Expand Down

0 comments on commit c82108a

Please sign in to comment.