Skip to content

Commit

Permalink
Re-add fh streaming possibility (#77)
Browse files Browse the repository at this point in the history
Co-authored-by: William Woodruff <[email protected]>
  • Loading branch information
julianhille and woodruffw authored Sep 26, 2024
1 parent ced975f commit dc191c2
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 0 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ pyo3 = { version = "0.22.3", features = [
"abi3-py38",
"py-clone",
] }
pyo3-file = "0.9.0"
3 changes: 3 additions & 0 deletions pyrage-stubs/pyrage-stubs/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from io import BufferedIOBase
from typing import Sequence, Union

from pyrage import ssh, x25519, passphrase, plugin
Expand All @@ -15,5 +16,7 @@ class IdentityError(Exception):

def encrypt(plaintext: bytes, recipients: Sequence[Recipient]) -> bytes: ...
def encrypt_file(infile: str, outfile: str, recipients: Sequence[Recipient]) -> None: ...
def encrypt_io(in_io: BufferedIOBase, recipients: Sequence[Recipient]) -> bytes: ...
def decrypt(ciphertext: bytes, identities: Sequence[Identity]) -> bytes: ...
def decrypt_file(infile: str, outfile: str, identities: Sequence[Identity]) -> None: ...
def decrypt_io(in_io: BufferedIOBase, out_io: BufferedIOBase, recipient: Sequence[Recipient]) -> None: ...
61 changes: 61 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use pyo3::{
py_run,
types::PyBytes,
};
use pyo3_file::PyFileLikeObject;

mod passphrase;
mod plugin;
Expand Down Expand Up @@ -259,6 +260,64 @@ fn decrypt_file(
Ok(())
}

fn from_pyobject(file: PyObject, read_only: bool) -> PyResult<PyFileLikeObject> {
// is a file-like
PyFileLikeObject::with_requirements(file, read_only, !read_only, false, false)
}

#[pyfunction]
fn encrypt_io(
reader: PyObject,
writer: PyObject,
recipients: Vec<Box<dyn PyrageRecipient>>,
) -> PyResult<()> {
// This turns each `dyn PyrageRecipient` into a `dyn Recipient`, which
// is what the underlying `age` API expects.
let recipients = recipients.into_iter().map(|pr| pr.as_recipient()).collect();
let reader = from_pyobject(reader, true)?;
let writer = from_pyobject(writer, false)?;
let mut reader = std::io::BufReader::new(reader);
let mut writer = std::io::BufWriter::new(writer);
let encryptor = Encryptor::with_recipients(recipients)
.ok_or_else(|| EncryptError::new_err("expected at least one recipient"))?;
let mut writer = encryptor
.wrap_output(&mut writer)
.map_err(|e| EncryptError::new_err(e.to_string()))?;
std::io::copy(&mut reader, &mut writer).map_err(|e| EncryptError::new_err(e.to_string()))?;
writer
.finish()
.map_err(|e| EncryptError::new_err(e.to_string()))?;
Ok(())
}

#[pyfunction]
fn decrypt_io(
reader: PyObject,
writer: PyObject,
identities: Vec<Box<dyn PyrageIdentity>>,
) -> PyResult<()> {
let identities = identities.iter().map(|pi| pi.as_ref().as_identity());
let reader = from_pyobject(reader, true)?;
let writer = from_pyobject(writer, false)?;
let reader = std::io::BufReader::new(reader);
let mut writer = std::io::BufWriter::new(writer);
let decryptor = match age::Decryptor::new_buffered(reader)
.map_err(|e| DecryptError::new_err(e.to_string()))?
{
age::Decryptor::Recipients(d) => d,
age::Decryptor::Passphrase(_) => {
return Err(DecryptError::new_err(
"invalid ciphertext (encrypted with passphrase, not identities)",
))
}
};
let mut reader = decryptor
.decrypt(identities)
.map_err(|e| DecryptError::new_err(e.to_string()))?;
std::io::copy(&mut reader, &mut writer)?;
Ok(())
}

#[pymodule]
fn pyrage(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
// HACK(ww): pyO3 modules are not packages, so we need this nasty
Expand Down Expand Up @@ -298,9 +357,11 @@ fn pyrage(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add("EncryptError", py.get_type_bound::<EncryptError>())?;
m.add_wrapped(wrap_pyfunction!(encrypt))?;
m.add_wrapped(wrap_pyfunction!(encrypt_file))?;
m.add_wrapped(wrap_pyfunction!(encrypt_io))?;
m.add("DecryptError", py.get_type_bound::<DecryptError>())?;
m.add_wrapped(wrap_pyfunction!(decrypt))?;
m.add_wrapped(wrap_pyfunction!(decrypt_file))?;
m.add_wrapped(wrap_pyfunction!(decrypt_io))?;

Ok(())
}
54 changes: 54 additions & 0 deletions test/test_pyrage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import tempfile
import unittest
from io import BytesIO

import pyrage

Expand All @@ -23,6 +24,59 @@ def test_roundtrip(self):

self.assertEqual(b"test", decrypted)

def test_roundtrip_io_fh(self):
identity = pyrage.x25519.Identity.generate()
recipient = identity.to_public()
with tempfile.TemporaryFile() as unencrypted:
unencrypted.write(b"test")
unencrypted.seek(0)
with tempfile.TemporaryFile() as encrypted:
pyrage.encrypt_io(unencrypted, encrypted, [recipient])
encrypted.seek(0)
with tempfile.TemporaryFile() as decrypted:
pyrage.decrypt_io(encrypted, decrypted, [identity])
decrypted.seek(0)
unencrypted.seek(0)
self.assertEqual(unencrypted.read(), decrypted.read())

def test_roundtrip_io_bytesio(self):
identity = pyrage.x25519.Identity.generate()
recipient = identity.to_public()
unencrypted = BytesIO(b'test')
encrypted = BytesIO()
decrypted = BytesIO()
pyrage.encrypt_io(unencrypted, encrypted, [recipient])
encrypted.seek(0)
pyrage.decrypt_io(encrypted, decrypted, [identity])
decrypted.seek(0)
unencrypted.seek(0)
self.assertEqual(unencrypted.read(), decrypted.read())

def test_roundtrip_io_fail(self):
identity = pyrage.x25519.Identity.generate()
recipient = identity.to_public()

with self.assertRaises(TypeError):
input = 'test'
output = BytesIO()
pyrage.encrypt_io(input, output, [recipient])

with self.assertRaises(TypeError):
input = BytesIO()
output = 'test'
pyrage.encrypt_io(input, output, [recipient])

with self.assertRaises(TypeError):
input = 'test'
output = BytesIO()
pyrage.decrypt_io(input, output, [recipient])

with self.assertRaises(TypeError):
input = BytesIO()
output = 'test'
pyrage.decrypt_io(input, output, [recipient])


def test_roundtrip_file(self):
identity = pyrage.x25519.Identity.generate()
recipient = identity.to_public()
Expand Down

0 comments on commit dc191c2

Please sign in to comment.