diff --git a/Cargo.lock b/Cargo.lock index 0167c63..9ce2c31 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -940,6 +940,15 @@ dependencies = [ "pyo3-build-config", ] +[[package]] +name = "pyo3-file" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488563e2317157edd6e12c3ef23e10363bd079bf8630e3de719e368b4eb02a21" +dependencies = [ + "pyo3", +] + [[package]] name = "pyo3-macros" version = "0.22.3" @@ -972,6 +981,7 @@ dependencies = [ "age", "age-core", "pyo3", + "pyo3-file", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 15a8106..09296d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,4 @@ pyo3 = { version = "0.22.3", features = [ "abi3-py38", "py-clone", ] } +pyo3-file = "0.9.0" diff --git a/pyrage-stubs/pyrage-stubs/__init__.pyi b/pyrage-stubs/pyrage-stubs/__init__.pyi index 8119866..1282a58 100644 --- a/pyrage-stubs/pyrage-stubs/__init__.pyi +++ b/pyrage-stubs/pyrage-stubs/__init__.pyi @@ -1,3 +1,4 @@ +from io import BufferedIOBase from typing import Sequence, Union from pyrage import ssh, x25519, passphrase, plugin @@ -15,5 +16,7 @@ class IdentityError(Exception): def encrypt(plaintext: bytes, recipients: Sequence[Recipient]) -> bytes: ... def encrypt_file(infile: str, outfile: str, recipients: Sequence[Recipient]) -> None: ... +def encrypt_io(in_io: BufferedIOBase, recipients: Sequence[Recipient]) -> bytes: ... def decrypt(ciphertext: bytes, identities: Sequence[Identity]) -> bytes: ... def decrypt_file(infile: str, outfile: str, identities: Sequence[Identity]) -> None: ... +def decrypt_io(in_io: BufferedIOBase, out_io: BufferedIOBase, recipient: Sequence[Recipient]) -> None: ... diff --git a/src/lib.rs b/src/lib.rs index 179316b..7c1198f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ use pyo3::{ py_run, types::PyBytes, }; +use pyo3_file::PyFileLikeObject; mod passphrase; mod plugin; @@ -259,6 +260,64 @@ fn decrypt_file( Ok(()) } +fn from_pyobject(file: PyObject, read_only: bool) -> PyResult { + // is a file-like + PyFileLikeObject::with_requirements(file, read_only, !read_only, false, false) +} + +#[pyfunction] +fn encrypt_io( + reader: PyObject, + writer: PyObject, + recipients: Vec>, +) -> PyResult<()> { + // This turns each `dyn PyrageRecipient` into a `dyn Recipient`, which + // is what the underlying `age` API expects. + let recipients = recipients.into_iter().map(|pr| pr.as_recipient()).collect(); + let reader = from_pyobject(reader, true)?; + let writer = from_pyobject(writer, false)?; + let mut reader = std::io::BufReader::new(reader); + let mut writer = std::io::BufWriter::new(writer); + let encryptor = Encryptor::with_recipients(recipients) + .ok_or_else(|| EncryptError::new_err("expected at least one recipient"))?; + let mut writer = encryptor + .wrap_output(&mut writer) + .map_err(|e| EncryptError::new_err(e.to_string()))?; + std::io::copy(&mut reader, &mut writer).map_err(|e| EncryptError::new_err(e.to_string()))?; + writer + .finish() + .map_err(|e| EncryptError::new_err(e.to_string()))?; + Ok(()) +} + +#[pyfunction] +fn decrypt_io( + reader: PyObject, + writer: PyObject, + identities: Vec>, +) -> PyResult<()> { + let identities = identities.iter().map(|pi| pi.as_ref().as_identity()); + let reader = from_pyobject(reader, true)?; + let writer = from_pyobject(writer, false)?; + let reader = std::io::BufReader::new(reader); + let mut writer = std::io::BufWriter::new(writer); + let decryptor = match age::Decryptor::new_buffered(reader) + .map_err(|e| DecryptError::new_err(e.to_string()))? + { + age::Decryptor::Recipients(d) => d, + age::Decryptor::Passphrase(_) => { + return Err(DecryptError::new_err( + "invalid ciphertext (encrypted with passphrase, not identities)", + )) + } + }; + let mut reader = decryptor + .decrypt(identities) + .map_err(|e| DecryptError::new_err(e.to_string()))?; + std::io::copy(&mut reader, &mut writer)?; + Ok(()) +} + #[pymodule] fn pyrage(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { // HACK(ww): pyO3 modules are not packages, so we need this nasty @@ -298,9 +357,11 @@ fn pyrage(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add("EncryptError", py.get_type_bound::())?; m.add_wrapped(wrap_pyfunction!(encrypt))?; m.add_wrapped(wrap_pyfunction!(encrypt_file))?; + m.add_wrapped(wrap_pyfunction!(encrypt_io))?; m.add("DecryptError", py.get_type_bound::())?; m.add_wrapped(wrap_pyfunction!(decrypt))?; m.add_wrapped(wrap_pyfunction!(decrypt_file))?; + m.add_wrapped(wrap_pyfunction!(decrypt_io))?; Ok(()) } diff --git a/test/test_pyrage.py b/test/test_pyrage.py index df4ddcf..4d1211e 100644 --- a/test/test_pyrage.py +++ b/test/test_pyrage.py @@ -1,6 +1,7 @@ import os import tempfile import unittest +from io import BytesIO import pyrage @@ -23,6 +24,59 @@ def test_roundtrip(self): self.assertEqual(b"test", decrypted) + def test_roundtrip_io_fh(self): + identity = pyrage.x25519.Identity.generate() + recipient = identity.to_public() + with tempfile.TemporaryFile() as unencrypted: + unencrypted.write(b"test") + unencrypted.seek(0) + with tempfile.TemporaryFile() as encrypted: + pyrage.encrypt_io(unencrypted, encrypted, [recipient]) + encrypted.seek(0) + with tempfile.TemporaryFile() as decrypted: + pyrage.decrypt_io(encrypted, decrypted, [identity]) + decrypted.seek(0) + unencrypted.seek(0) + self.assertEqual(unencrypted.read(), decrypted.read()) + + def test_roundtrip_io_bytesio(self): + identity = pyrage.x25519.Identity.generate() + recipient = identity.to_public() + unencrypted = BytesIO(b'test') + encrypted = BytesIO() + decrypted = BytesIO() + pyrage.encrypt_io(unencrypted, encrypted, [recipient]) + encrypted.seek(0) + pyrage.decrypt_io(encrypted, decrypted, [identity]) + decrypted.seek(0) + unencrypted.seek(0) + self.assertEqual(unencrypted.read(), decrypted.read()) + + def test_roundtrip_io_fail(self): + identity = pyrage.x25519.Identity.generate() + recipient = identity.to_public() + + with self.assertRaises(TypeError): + input = 'test' + output = BytesIO() + pyrage.encrypt_io(input, output, [recipient]) + + with self.assertRaises(TypeError): + input = BytesIO() + output = 'test' + pyrage.encrypt_io(input, output, [recipient]) + + with self.assertRaises(TypeError): + input = 'test' + output = BytesIO() + pyrage.decrypt_io(input, output, [recipient]) + + with self.assertRaises(TypeError): + input = BytesIO() + output = 'test' + pyrage.decrypt_io(input, output, [recipient]) + + def test_roundtrip_file(self): identity = pyrage.x25519.Identity.generate() recipient = identity.to_public()