Skip to content

Commit

Permalink
review: harmonise on main() -> int
Browse files Browse the repository at this point in the history
make all the main's return ints and use the sys.exit(main()) idiom

note that test need to be posix status aware as a consequence
  • Loading branch information
Robin Bryce committed Dec 12, 2024
1 parent 10878ed commit c1cfe1c
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 39 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ jobs:
python-version: ["3.11", "3.12", "3.13" ]
# reduced matrix for ci
os: [ubuntu-latest, windows-latest]
# this limit mitigates against rate limiting making tests flaky
max-parallel: 2
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
Expand Down
8 changes: 5 additions & 3 deletions datatrails_scitt_samples/scripts/check_operation_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datatrails_scitt_samples.statement_registration import wait_for_entry_id


def main():
def main() -> int:
"""Polls for the signed statement to be registered"""

parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -46,8 +46,10 @@ def main():
print(entry_id)
except TimeoutError as e:
print(e, file=sys.stderr)
sys.exit(1)
return 1

return 0


if __name__ == "__main__":
main()
sys.exit(main())
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from hashlib import sha256


def main(args=None):
def main(args=None) -> int:
"""Creates a signed statement"""

parser = argparse.ArgumentParser(description="Create a signed statement.")
Expand Down Expand Up @@ -120,6 +120,8 @@ def main(args=None):
with open(args.output_file, "wb") as output_file:
output_file.write(signed_statement)

return 0


if __name__ == "__main__":
main()
sys.exit(main())
6 changes: 4 additions & 2 deletions datatrails_scitt_samples/scripts/create_signed_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from datatrails_scitt_samples.statement_creation import OPTION_USE_DRAFT_04_LABELS


def main(args=None):
def main(args=None) -> int:
"""Creates a signed statement"""

parser = argparse.ArgumentParser(description="Create a signed statement.")
Expand Down Expand Up @@ -124,6 +124,8 @@ def main(args=None):
with open(args.output_file, "wb") as output_file:
output_file.write(signed_statement)

return 0


if __name__ == "__main__":
main()
sys.exit(main())
11 changes: 6 additions & 5 deletions datatrails_scitt_samples/scripts/register_signed_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def attach_receipt(
file.write(ts)


def main(args=None):
def main(args=None) -> int:
"""Creates a Transparent Statement"""

parser = argparse.ArgumentParser(description="Register a signed statement.")
Expand Down Expand Up @@ -111,7 +111,7 @@ def main(args=None):
entry_id = wait_for_entry_id(ctx, op_id)
except TimeoutError as e:
ctx.error(e)
sys.exit(1)
return 1
ctx.info("Fully Registered with Entry ID %s", entry_id)

result = {"entryid": entry_id}
Expand All @@ -131,12 +131,12 @@ def main(args=None):
receipt = get_receipt(ctx, entry_id)
if not verify_receipt_mmriver(receipt, leaf):
ctx.info("Receipt verification failed")
sys.exit(1)
return 1
result["leaf"] = leaf.hex()

if args.output_file == "":
print(json.dumps(result))
return
return 0

if args.output_receipt_file != "":
with open(args.output_receipt_file, "wb") as file:
Expand All @@ -148,7 +148,8 @@ def main(args=None):
ctx.info(f"File saved successfully {args.output_file}")

print(json.dumps(result))
return 0


if __name__ == "__main__":
main()
sys.exit(main())
25 changes: 12 additions & 13 deletions datatrails_scitt_samples/scripts/verify_receipt.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def verify_transparent_statement(
return verify_receipt_mmriver(receipt_bytes, leaf)


def main(args=None) -> bool:
def main(args=None) -> int:
"""Verifies a counter signed receipt signature"""

parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -104,30 +104,30 @@ def main(args=None) -> bool:

if not (args.leaf or args.event_json_file or args.entryid):
ctx.error("either --leaf or --event-json-file is required")
return False
return 1

leaf = None
if args.leaf:
try:
leaf = bytes.fromhex(args.leaf)
except ValueError:
ctx.error("failed to parse leaf hash")
return False
return 1

elif args.event_json_file:
try:
event = json.loads(open_event_json(args.event_json_file))
except ValueError:
ctx.error("failed to parse event json")
return False
return 1
leaf = v3leaf_hash(event)
elif args.entryid:
identity = entryid_to_identity(args.entryid)
try:
event = get_event(ctx, identity, True)
except HTTPError as e:
ctx.error("failed to obtain event: %s", e)
return False
return 1
leaf = v3leaf_hash(event)

if leaf is None:
Expand All @@ -144,14 +144,13 @@ def main(args=None) -> bool:
transparent_statement = read_cbor_file(args.transparent_statement_file)
verified = verify_transparent_statement(transparent_statement, leaf)

if verified:
print("verification succeeded")
return True
print("verification failed")
return False
if not verified:
print("verification failed")
return 1

print("verification succeeded")
return 0


if __name__ == "__main__":
if not main():
sys.exit(1)
sys.exit(0)
sys.exit(main())
23 changes: 12 additions & 11 deletions tests/test_verify_receipt.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def test_verify_failed_for_tampered_event(self):
# First verify the event as is

# Verify the leaf value directly
verified = False

verified = 1
output = io.StringIO()
with redirect_stdout(output):
verified = verify_receipt(
Expand All @@ -124,7 +125,7 @@ def test_verify_failed_for_tampered_event(self):
]
)
self.assertEqual(output.getvalue().strip(), "verification succeeded")
self.assertTrue(verified)
self.assertEqual(verified, 0)

event["event_attributes"]["test_verify_failed_for_tampered_event"] = "tampered"
with open(event_json_file, "w") as file:
Expand All @@ -141,7 +142,7 @@ def test_verify_failed_for_tampered_event(self):
]
)
self.assertEqual(output.getvalue().strip(), "verification failed")
self.assertFalse(verified)
self.assertEqual(verified, 1)

@unittest.skipUnless(
os.getenv("DATATRAILS_CLIENT_SECRET") != "",
Expand Down Expand Up @@ -199,7 +200,7 @@ def test_verify_transparent_statement_by_leaf(self):
self.assertTrue(os.path.exists(f"{self.test_dir}/transparent-statement.cbor"))

# Verify the leaf value directly
verified = False
verified = 1
output = io.StringIO()
with redirect_stdout(output):
verified = verify_receipt(
Expand All @@ -211,7 +212,7 @@ def test_verify_transparent_statement_by_leaf(self):
]
)
self.assertEqual(output.getvalue().strip(), "verification succeeded")
self.assertTrue(verified)
self.assertEqual(verified, 0)

@unittest.skipUnless(
os.getenv("DATATRAILS_CLIENT_SECRET") != "",
Expand Down Expand Up @@ -269,7 +270,7 @@ def test_verify_transparent_statement_by_entryid(self):
self.assertTrue(os.path.exists(f"{self.test_dir}/transparent-statement.cbor"))

# Verify the leaf value directly
verified = False
verified = 1
output = io.StringIO()
with redirect_stdout(output):
verified = verify_receipt(
Expand All @@ -281,7 +282,7 @@ def test_verify_transparent_statement_by_entryid(self):
]
)
self.assertEqual(output.getvalue().strip(), "verification succeeded")
self.assertTrue(verified)
self.assertEqual(verified, 0)

@unittest.skipUnless(
os.getenv("DATATRAILS_CLIENT_SECRET") != "",
Expand Down Expand Up @@ -339,7 +340,7 @@ def test_verify_receipt_by_leaf(self):
self.assertTrue(os.path.exists(f"{self.test_dir}/transparent-statement.cbor"))

# Verify the leaf value directly
verified = False
verified = 1
output = io.StringIO()
with redirect_stdout(output):
verified = verify_receipt(
Expand All @@ -351,7 +352,7 @@ def test_verify_receipt_by_leaf(self):
]
)
self.assertEqual(output.getvalue().strip(), "verification succeeded")
self.assertTrue(verified)
self.assertEqual(verified, 0)

@unittest.skipUnless(
os.getenv("DATATRAILS_CLIENT_SECRET") != "",
Expand Down Expand Up @@ -409,7 +410,7 @@ def test_verify_receipt_by_entryid(self):
self.assertTrue(os.path.exists(f"{self.test_dir}/transparent-statement.cbor"))

# Verify the leaf value directly
verified = False
verified = 1
output = io.StringIO()
with redirect_stdout(output):
verified = verify_receipt(
Expand All @@ -421,7 +422,7 @@ def test_verify_receipt_by_entryid(self):
]
)
self.assertEqual(output.getvalue().strip(), "verification succeeded")
self.assertTrue(verified)
self.assertEqual(verified, 0)


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions tests/test_verify_receipt_bad_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_verify_receipt_leaf_not_hex(self):
"this is not hex",
]
)
self.assertFalse(verified)
self.assertEqual(verified, 1)

def test_verify_receipt_event_file_not_json(self):
"""Cover various bad input cases"""
Expand All @@ -53,7 +53,7 @@ def test_verify_receipt_event_file_not_json(self):
self.bad_json_file,
]
)
self.assertFalse(verified)
self.assertEqual(verified, 1)

def test_verify_receipt_bad_entryid(self):
"""Cover various bad input cases"""
Expand All @@ -67,7 +67,7 @@ def test_verify_receipt_bad_entryid(self):
"this is not found",
]
)
self.assertFalse(verified)
self.assertEqual(verified, 1)


if __name__ == "__main__":
Expand Down

0 comments on commit c1cfe1c

Please sign in to comment.