Skip to content

Commit

Permalink
feat: allow to use pyproject name (by default) when packaging
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Sep 7, 2023
1 parent 5d93d09 commit 0e381af
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 32 deletions.
2 changes: 1 addition & 1 deletion docs/pipeline.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ To share the pipeline and turn it into a pip installable package, you can use th

```python
model.package(
name="path/to/your/package",
name="your-package-name", # leave None to reuse name in pyproject.toml
version="0.0.1",
root_dir="path/to/project/root", # optional, to retrieve an existing pyproject.toml file
# if you don't have a pyproject.toml, you can provide the metadata here instead
Expand Down
62 changes: 40 additions & 22 deletions edspdf/utils/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,18 @@ def validate(cls, value, config=None):
print(file.path)
"""

INIT_PY = """\
INIT_PY = """
# -----------------------------------------
# This section was autogenerated by edspdf
# -----------------------------------------
import edspdf
from pathlib import Path
__version__ = {__version__}
def load(device: "torch.device" = "cpu") -> edspdf.Pipeline:
artifacts_path = Path(__file__).parent / "{artifacts_path}"
artifacts_path = Path(__file__).parent / "{artifacts_dir}"
model = edspdf.load(artifacts_path, device=device)
return model
"""
Expand Down Expand Up @@ -189,9 +195,9 @@ def __init__(
pyproject: Optional[Dict[str, Any]],
pipeline: Union[Path, "edspdf.Pipeline"],
version: str,
name: ModuleName,
name: Optional[ModuleName],
root_dir: Path = ".",
build_dir: Path = "build",
build_name: Path = "build",
out_dir: Path = "dist",
artifacts_name: ModuleName = "artifacts",
dependencies: Optional[Sequence[Tuple[str, str]]] = None,
Expand All @@ -203,20 +209,15 @@ def __init__(
.strip()
)
self.version = version
self.name = name
self.pyproject = pyproject
self.root_dir = root_dir.resolve()
self.build_dir = build_dir
self.out_dir = self.root_dir / out_dir
self.artifacts_name = artifacts_name
self.name = name
self.pipeline = pipeline
self.dependencies = dependencies
self.pipeline = pipeline
self.artifacts_name = artifacts_name
self.out_dir = self.root_dir / out_dir

with self.ensure_pyproject(metadata):
logger.info(f"root_dir: {self.root_dir}")
logger.info(f"build_dir: {self.build_dir}")
logger.info(f"artifacts_name: {self.artifacts_name}")
logger.info(f"name: {self.name}")

python_executable = (
Path(self.poetry_bin_path).read_text().split("\n")[0][2:]
Expand All @@ -233,9 +234,16 @@ def __init__(
if result.returncode != 0:
raise Exception()
out = result.stdout.decode().strip().split("\n")

self.poetry_packages = eval(out[0])
self.build_dir = root_dir / build_name / self.name
self.file_paths = [self.root_dir / file_path for file_path in out[1:]]

logger.info(f"root_dir: {self.root_dir}")
logger.info(f"build_dir: {self.build_dir}")
logger.info(f"artifacts_name: {self.artifacts_name}")
logger.info(f"name: {self.name}")

@contextmanager
def ensure_pyproject(self, metadata):
"""Generates a Poetry based pyproject.toml"""
Expand Down Expand Up @@ -269,6 +277,11 @@ def ensure_pyproject(self, metadata):
toml.dumps(self.pyproject)
)
else:
self.name = (
self.pyproject["tool"]["poetry"]["name"]
if self.name is None
else self.name
)
for key, value in metadata.items():
pyproject_value = self.pyproject["tool"]["poetry"].get(key)
if pyproject_value != metadata[key]:
Expand Down Expand Up @@ -341,6 +354,8 @@ def update_pyproject(self):
def make_src_dir(self):
snake_name = snake_case(self.name.lower())
package_dir = self.build_dir / snake_name
shutil.rmtree(package_dir, ignore_errors=True)
os.makedirs(package_dir, exist_ok=True)
build_artifacts_dir = package_dir / self.artifacts_name
for file_path in self.list_files_to_add():
new_file_path = self.build_dir / Path(file_path).relative_to(self.root_dir)
Expand Down Expand Up @@ -368,17 +383,19 @@ def make_src_dir(self):
else:
self.pipeline.save(build_artifacts_dir)
os.makedirs(package_dir, exist_ok=True)
(package_dir / "__init__.py").write_text(
INIT_PY.format(
artifacts_path=os.path.relpath(build_artifacts_dir, package_dir)
with open(package_dir / "__init__.py", mode="a") as f:
f.write(
INIT_PY.format(
__version__=repr(self.version),
artifacts_dir=os.path.relpath(build_artifacts_dir, package_dir),
)
)
)


@app.command(name="package")
def package(
pipeline: Union[Path, "edspdf.Pipeline"],
name: ModuleName,
name: Optional[ModuleName] = None,
root_dir: Path = ".",
artifacts_name: ModuleName = "artifacts",
check_dependencies: bool = False,
Expand All @@ -395,6 +412,11 @@ def package(

if not pyproject_path.exists():
check_dependencies = True
if name is None:
raise ValueError(
f"No pyproject.toml could be found in the root directory {root_dir}, "
f"you need to create one, or fill the name parameter."
)

dependencies = None
if check_dependencies:
Expand All @@ -405,9 +427,6 @@ def package(
print("DEPENDENCY", dep[0].ljust(30), dep[1])

root_dir = root_dir.resolve()
build_dir = root_dir / "build" / name
shutil.rmtree(build_dir, ignore_errors=True)
os.makedirs(build_dir)

pyproject = None
if pyproject_path.exists():
Expand All @@ -423,7 +442,6 @@ def package(
name=name,
version=version,
root_dir=root_dir,
build_dir=build_dir,
artifacts_name=artifacts_name,
dependencies=dependencies,
metadata=metadata,
Expand Down
71 changes: 62 additions & 9 deletions tests/utils/test_package.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import importlib

import pip
import pytest

from edspdf.utils.package import package


def test_blank_package(frozen_pipeline, tmp_path):

# Missing metadata makes poetry fail due to missing author / description
with pytest.raises(Exception):
package(
Expand Down Expand Up @@ -32,11 +34,12 @@ def test_blank_package(frozen_pipeline, tmp_path):
assert (tmp_path / "build" / "test-model").is_dir()


def test_package_with_files(frozen_pipeline, tmp_path):
@pytest.mark.parametrize("package_name", ["my-test-model", None])
def test_package_with_files(frozen_pipeline, tmp_path, package_name):
frozen_pipeline.save(tmp_path / "model")

((tmp_path / "test_model_trainer").mkdir(parents=True))
(tmp_path / "test_model_trainer" / "__init__.py").write_text(
((tmp_path / "test_model").mkdir(parents=True))
(tmp_path / "test_model" / "__init__.py").write_text(
"""\
print("Hello World!")
"""
Expand All @@ -48,7 +51,7 @@ def test_package_with_files(frozen_pipeline, tmp_path):
build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "test-model-trainer"
name = "test-model"
version = "0.0.0"
description = "A test model"
authors = ["Test Author <[email protected]>"]
Expand All @@ -64,26 +67,76 @@ def test_package_with_files(frozen_pipeline, tmp_path):
pipeline=frozen_pipeline,
root_dir=tmp_path,
version="0.1.0",
name="test-model",
name=package_name,
metadata={
"description": "Wrong description",
"authors": "Test Author <[email protected]>",
},
)

package(
name=package_name,
pipeline=tmp_path / "model",
root_dir=tmp_path,
check_dependencies=True,
version="0.1.0",
name="test-model",
distributions=None,
metadata={
"description": "A test model",
"authors": "Test Author <[email protected]>",
},
)

module_name = "test_model" if package_name is None else "my_test_model"

assert (tmp_path / "dist").is_dir()
assert (tmp_path / "dist" / "test_model-0.1.0.tar.gz").is_file()
assert (tmp_path / "dist" / "test_model-0.1.0-py3-none-any.whl").is_file()
assert (tmp_path / "dist" / f"{module_name}-0.1.0.tar.gz").is_file()
assert (tmp_path / "dist" / f"{module_name}-0.1.0-py3-none-any.whl").is_file()
assert (tmp_path / "pyproject.toml").is_file()

# pip install the whl file
pip.main(
[
"install",
str(tmp_path / "dist" / f"{module_name}-0.1.0-py3-none-any.whl"),
"--force-reinstall",
]
)

module = importlib.import_module(module_name)

assert module.__version__ == "0.1.0"

with open(module.__file__) as f:
assert f.read() == (
(
"""\
print("Hello World!")
"""
if package_name is None
else ""
)
+ """
# -----------------------------------------
# This section was autogenerated by edspdf
# -----------------------------------------
import edspdf
from pathlib import Path
__version__ = '0.1.0'
def load(device: "torch.device" = "cpu") -> edspdf.Pipeline:
artifacts_path = Path(__file__).parent / "artifacts"
model = edspdf.load(artifacts_path, device=device)
return model
"""
)


@pytest.fixture(scope="session", autouse=True)
def clean_after():
yield

pip.main(["uninstall", "-y", "test-model"])
pip.main(["uninstall", "-y", "my-test-model"])

0 comments on commit 0e381af

Please sign in to comment.