Skip to content

Commit

Permalink
feat: pull recursively to include subject references
Browse files Browse the repository at this point in the history
Signed-off-by: Tomas Coufal <[email protected]>
  • Loading branch information
tumido committed May 2, 2024
1 parent ccc28b3 commit 01732df
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 2 deletions.
14 changes: 13 additions & 1 deletion oras/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,11 +862,13 @@ def pull(self, *args, **kwargs) -> List[str]:
refresh_headers = kwargs.get("refresh_headers")
if refresh_headers is None:
refresh_headers = True
container = self.get_container(kwargs["target"])
target: str = kwargs["target"]
container = self.get_container(target)
self.load_configs(container, configs=kwargs.get("config_path"))
manifest = self.get_manifest(container, allowed_media_type, refresh_headers)
outdir = kwargs.get("outdir") or oras.utils.get_tmpdir()
overwrite = kwargs.get("overwrite", True)
include_subject = kwargs.get("include_subject", False)

files = []
for layer in manifest.get("layers", []):
Expand Down Expand Up @@ -900,6 +902,16 @@ def pull(self, *args, **kwargs) -> List[str]:
self.download_blob(container, layer["digest"], outfile)
logger.info(f"Successfully pulled {outfile}.")
files.append(outfile)

if include_subject and manifest.get('subject', False):
separator = "@" if "@" in target else ":"
repo, _tag = target.rsplit(separator, 1)
subject_digest = manifest['subject']['digest']
new_kwargs = kwargs
new_kwargs['target'] = f'{repo}@{subject_digest}'

files += self.pull(*args, **kwargs)

return files

@decorator.ensure_container
Expand Down
5 changes: 5 additions & 0 deletions oras/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def target(registry):
return f"{registry}/dinosaur/artifact:v1"


@pytest.fixture
def derived_target(registry):
return f"{registry}/dinosaur/artifact:v1-derived"


@pytest.fixture
def target_dir(registry):
return f"{registry}/dinosaur/directory:v1"
1 change: 1 addition & 0 deletions oras/tests/derived-artifact.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
referred artifact is greeting extinct creatures
37 changes: 36 additions & 1 deletion oras/tests/test_oras.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest

import oras.client
import oras.provider

here = os.path.abspath(os.path.dirname(__file__))

Expand Down Expand Up @@ -67,6 +68,40 @@ def test_basic_push_pull(tmp_path, registry, credentials, target):
assert res.status_code == 201



@pytest.mark.with_auth(False)
def test_push_pull_attached_artifacts(tmp_path, registry, credentials, target, derived_target):
"""
Basic tests for oras (without authentication)
"""
client = oras.client.OrasClient(hostname=registry, insecure=True)

artifact = os.path.join(here, "artifact.txt")
assert os.path.exists(artifact)

res = client.push(files=[artifact], target=target)
assert res.status_code in [200, 201]

derived_artifact = os.path.join(here, "derived-artifact.txt")
assert os.path.exists(derived_artifact)

manifest = client.remote.get_manifest(target)
subject = oras.provider.Subject.from_manifest(manifest)
res = client.push(files=[derived_artifact], target=derived_target, subject=subject)
assert res.status_code in [200, 201]

# Test pulling elsewhere
files = sorted(client.pull(target=derived_target, outdir=tmp_path, include_subject=True))
assert len(files) == 2
assert os.path.basename(files[0]) == "artifact.txt"
assert os.path.basename(files[1]) == "derived-artifact.txt"
assert str(tmp_path) in files[0]
assert str(tmp_path) in files[1]
assert os.path.exists(files[0])
assert os.path.exists(files[1])



@pytest.mark.with_auth(False)
def test_get_delete_tags(tmp_path, registry, credentials, target):
"""
Expand All @@ -87,7 +122,7 @@ def test_get_delete_tags(tmp_path, registry, credentials, target):
assert not client.delete_tags(target, "v1-boop-boop")
assert "v1" in client.delete_tags(target, "v1")
tags = client.get_tags(target)
assert not tags
assert "v1" not in tags


def test_get_many_tags():
Expand Down

0 comments on commit 01732df

Please sign in to comment.