Skip to content

Commit

Permalink
Implement borsh in guest code
Browse files Browse the repository at this point in the history
  • Loading branch information
kpp committed Jul 9, 2024
1 parent 61757a1 commit fc22673
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 132 deletions.
67 changes: 32 additions & 35 deletions crates/sovereign-sdk/adapters/mock-zkvm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use std::sync::{Arc, Condvar, Mutex};
use anyhow::ensure;
use borsh::{BorshDeserialize, BorshSerialize};
use serde::{Deserialize, Serialize};
use sov_rollup_interface::zk::{Matches, ValidityCondition};
use sov_rollup_interface::da::BlockHeaderTrait;
use sov_rollup_interface::zk::{Matches, StateTransitionData, ValidityCondition};

/// A mock commitment to a particular zkVM program.
#[derive(Debug, Clone, PartialEq, Eq, BorshDeserialize, BorshSerialize, Serialize, Deserialize)]
Expand Down Expand Up @@ -152,15 +153,14 @@ impl<ValidityCond: ValidityCondition> sov_rollup_interface::zk::ZkvmHost
type Guest = MockZkGuest;

fn add_hint<T: BorshSerialize>(&mut self, item: T) {
unimplemented!()
// let hint = bincode::serialize(&item).unwrap();
// let proof_info = ProofInfo {
// hint,
// validity_condition: self.validity_condition,
// };

// let data = bincode::serialize(&proof_info).unwrap();
// self.committed_data.push_back(data)
let hint = borsh::to_vec(&item).unwrap();
let proof_info = ProofInfo {
hint,
validity_condition: self.validity_condition,
};

let data = borsh::to_vec(&proof_info).unwrap();
self.committed_data.push_back(data)
}

fn simulate_with_hints(&mut self) -> Self::Guest {
Expand All @@ -173,32 +173,29 @@ impl<ValidityCond: ValidityCondition> sov_rollup_interface::zk::ZkvmHost
Ok(sov_rollup_interface::zk::Proof::PublicInput(data))
}

fn extract_output<
Da: sov_rollup_interface::da::DaSpec,
Root: Serialize + serde::de::DeserializeOwned,
>(
fn extract_output<Da: sov_rollup_interface::da::DaSpec, Root: BorshDeserialize>(
proof: &sov_rollup_interface::zk::Proof,
) -> Result<sov_rollup_interface::zk::StateTransition<Da, Root>, Self::Error> {
unimplemented!()
// match proof {
// sov_rollup_interface::zk::Proof::PublicInput(pub_input) => {
// let data: ProofInfo<Da::ValidityCondition> = bincode::deserialize(pub_input)?;
// let st: StateTransitionData<Root, (), Da> = bincode::deserialize(&data.hint)?;

// Ok(sov_rollup_interface::zk::StateTransition {
// initial_state_root: st.initial_state_root,
// final_state_root: st.final_state_root,
// validity_condition: data.validity_condition,
// state_diff: Default::default(),
// da_slot_hash: st.da_block_header_of_commitments.hash(),
// sequencer_public_key: vec![],
// sequencer_da_public_key: vec![],
// })
// }
// sov_rollup_interface::zk::Proof::Full(_) => {
// panic!("Mock DA doesn't generate real proofs")
// }
// }
match proof {
sov_rollup_interface::zk::Proof::PublicInput(pub_input) => {
let data: ProofInfo<Da::ValidityCondition> = bincode::deserialize(pub_input)?;
let st: StateTransitionData<Root, (), Da> =
BorshDeserialize::deserialize(&mut &*data.hint)?;

Ok(sov_rollup_interface::zk::StateTransition {
initial_state_root: st.initial_state_root,
final_state_root: st.final_state_root,
validity_condition: data.validity_condition,
state_diff: Default::default(),
da_slot_hash: st.da_block_header_of_commitments.hash(),
sequencer_public_key: vec![],
sequencer_da_public_key: vec![],
})
}
sov_rollup_interface::zk::Proof::Full(_) => {
panic!("Mock DA doesn't generate real proofs")
}
}
}
}

Expand Down Expand Up @@ -238,7 +235,7 @@ impl sov_rollup_interface::zk::ZkvmGuest for MockZkGuest {
}
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, BorshDeserialize, BorshSerialize, Serialize, Deserialize)]
struct ProofInfo<ValidityCond> {
hint: Vec<u8>,
validity_condition: ValidityCond,
Expand Down
89 changes: 51 additions & 38 deletions crates/sovereign-sdk/adapters/risc0-bonsai/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::time::Duration;

use anyhow::anyhow;
use bonsai_sdk::alpha as bonsai_sdk;
use borsh::BorshSerialize;
use borsh::{BorshDeserialize, BorshSerialize};
use risc0_zkvm::sha::Digest;
use risc0_zkvm::{
compute_image_id, ExecutorEnvBuilder, ExecutorImpl, Groth16Receipt, InnerReceipt, Journal,
Expand Down Expand Up @@ -142,7 +142,7 @@ impl BonsaiClient {
let _ = notify.send(res);
}
BonsaiRequest::Download { url, notify } => {
debug!(%url, "Bonsai:upload_input");
debug!(%url, "Bonsai:download");
let res = client.download(&url);
let res = unwrap_bonsai_response!(res, 'client, 'queue);
let _ = notify.send(res);
Expand Down Expand Up @@ -326,44 +326,56 @@ impl<'a> Risc0BonsaiHost<'a> {
}

fn add_hint_bonsai<T: BorshSerialize>(&mut self, item: T) {
unimplemented!()
// // For running in "prove" mode.

// // Prepare input data and upload it.
// let client = self.client.as_ref().unwrap();

// let input_data = to_vec(&item).unwrap();
// let input_data = bytemuck::cast_slice(&input_data).to_vec();
// // handle error
// let input_id = client.upload_input(input_data);
// tracing::info!("Uploaded input with id: {}", input_id);
// self.last_input_id = Some(input_id);
// For running in "prove" mode.

// Prepare input data and upload it.
let client = self.client.as_ref().unwrap();

let mut input_data = vec![];
let mut buf = borsh::to_vec(&item).unwrap();
// append [0..] alignment
let rem = buf.len() % 4;
if rem > 0 {
buf.extend(vec![0; 4 - rem]);
}
let buf_u32: &[u32] = bytemuck::cast_slice(&buf);
// write len(u64) in LE
let len = buf_u32.len() as u64;
input_data.extend(len.to_le_bytes());
// write buf
input_data.extend(buf);

// handle error
let input_id = client.upload_input(input_data);
tracing::info!("Uploaded input with id: {}", input_id);
self.last_input_id = Some(input_id);
}
}

impl<'a> ZkvmHost for Risc0BonsaiHost<'a> {
type Guest = Risc0Guest;

fn add_hint<T: BorshSerialize>(&mut self, item: T) {
unimplemented!()
// // For running in "execute" mode.

// // We use the in-memory size of `item` as an indication of how much
// // space to reserve. This is in no way guaranteed to be exact, but
// // usually the in-memory size and serialized data size are quite close.
// //
// // Note: this is just an optimization to avoid frequent reallocations,
// // it's not actually required.
// self.env
// .reserve(std::mem::size_of::<T>() / std::mem::size_of::<u32>());

// let mut serializer = risc0_zkvm::serde::Serializer::new(&mut self.env);
// item.serialize(&mut serializer)
// .expect("Risc0 hint serialization is infallible");

// if self.client.is_some() {
// self.add_hint_bonsai(item)
// }
// For running in "execute" mode.

let mut buf = borsh::to_vec(&item).expect("Risc0 hint serialization is infallible");
// append [0..] alignment to cast &[u8] to &[u32]
let rem = buf.len() % 4;
if rem > 0 {
buf.extend(vec![0; 4 - rem]);
}
let buf: &[u32] = bytemuck::cast_slice(&buf);
// write len(u64) in LE
let len = buf.len() as u64;
let len_buf = &len.to_le_bytes()[..];
let len_buf: &[u32] = bytemuck::cast_slice(len_buf);
self.env.extend_from_slice(len_buf);
// write buf
self.env.extend_from_slice(buf);

if self.client.is_some() {
self.add_hint_bonsai(item)
}
}

/// Guest simulation (execute mode) is run inside the Risc0 VM locally
Expand Down Expand Up @@ -496,19 +508,20 @@ impl<'a> ZkvmHost for Risc0BonsaiHost<'a> {
}
}

fn extract_output<Da: sov_rollup_interface::da::DaSpec, Root: Serialize + DeserializeOwned>(
fn extract_output<Da: sov_rollup_interface::da::DaSpec, Root: BorshDeserialize>(
proof: &Proof,
) -> Result<sov_rollup_interface::zk::StateTransition<Da, Root>, Self::Error> {
match proof {
let journal = match proof {
Proof::PublicInput(journal) => {
let journal: Journal = bincode::deserialize(journal)?;
Ok(journal.decode()?)
journal
}
Proof::Full(data) => {
let receipt: Receipt = bincode::deserialize(data)?;
Ok(receipt.journal.decode()?)
receipt.journal
}
}
};
Ok(BorshDeserialize::try_from_slice(&journal.bytes)?)
}
}

Expand Down
54 changes: 20 additions & 34 deletions crates/sovereign-sdk/adapters/risc0/src/guest/native.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
//! This module implements the `ZkvmGuest` trait for the RISC0 VM.
use borsh::{BorshDeserialize, BorshSerialize};
use risc0_zkvm::serde::WordRead;
use sov_rollup_interface::zk::ZkvmGuest;

#[derive(Default)]
Expand All @@ -19,37 +18,12 @@ impl Hints {
}
}

impl WordRead for Hints {
fn read_words(&mut self, words: &mut [u32]) -> risc0_zkvm::serde::Result<()> {
if let Some(slice) = self.values.get(self.position..self.position + words.len()) {
words.copy_from_slice(slice);
self.position += words.len();
Ok(())
} else {
Err(risc0_zkvm::serde::Error::DeserializeUnexpectedEnd)
}
}

fn read_padded_bytes(&mut self, bytes: &mut [u8]) -> risc0_zkvm::serde::Result<()> {
use risc0_zkvm::align_up;
use risc0_zkvm_platform::WORD_SIZE;

let remaining_bytes: &[u8] = bytemuck::cast_slice(&self.values[self.position..]);
if bytes.len() > remaining_bytes.len() {
return Err(risc0_zkvm::serde::Error::DeserializeUnexpectedEnd);
}
bytes.copy_from_slice(&remaining_bytes[..bytes.len()]);
self.position += align_up(bytes.len(), WORD_SIZE) / WORD_SIZE;
Ok(())
}
}

/// A guest for the RISC0 VM. Implements the `ZkvmGuest` trait
/// using interior mutability to test the functionality.
#[derive(Default)]
pub struct Risc0Guest {
hints: std::sync::Mutex<Hints>,
commits: std::sync::Mutex<Vec<u32>>,
// commits: std::sync::Mutex<Vec<u32>>,
}

impl Risc0Guest {
Expand All @@ -62,21 +36,33 @@ impl Risc0Guest {
pub fn with_hints(hints: Vec<u32>) -> Self {
Self {
hints: std::sync::Mutex::new(Hints::with_hints(hints)),
commits: Default::default(),
// commits: Default::default(),
}
}
}

impl ZkvmGuest for Risc0Guest {
fn read_from_host<T: BorshDeserialize>(&self) -> T {
unimplemented!("read_from_host")
// let mut hints = self.hints.lock().unwrap();
// let mut hints = hints.deref_mut();
// T::deserialize(&mut Deserializer::new(&mut hints)).unwrap()
let mut hints = self.hints.lock().unwrap();
let hints = &mut *hints;
let pos = &mut hints.position;
let env = &hints.values;
// read len(u64) in LE
let len_buf = &env[*pos..*pos + 2];
let len_bytes = bytemuck::cast_slice(len_buf);
let len_bytes: [u8; 8] = len_bytes.try_into().expect("Exactly 4 bytes");
let len = u64::from_le_bytes(len_bytes) as usize;
*pos += 2;
// read buf
let buf = &env[*pos..*pos + len];
let buf: &[u8] = bytemuck::cast_slice(buf);
*pos += len;
// deserialize
BorshDeserialize::deserialize(&mut &*buf).unwrap()
}

fn commit<T: BorshSerialize>(&self, item: &T) {
unimplemented!("commit")
fn commit<T: BorshSerialize>(&self, _item: &T) {
unimplemented!("commitment never used in a test code")
// self.commits.lock().unwrap().extend_from_slice(
// &risc0_zkvm::serde::to_vec(item).expect("Serialization to vec is infallible"),
// );
Expand Down
19 changes: 15 additions & 4 deletions crates/sovereign-sdk/adapters/risc0/src/guest/zkvm.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! This module implements the `ZkvmGuest` trait for the RISC0 VM.
use borsh::{BorshDeserialize, BorshSerialize};
use risc0_zkvm::guest::env;
use risc0_zkvm::guest::env::Write;
use sov_rollup_interface::zk::ZkvmGuest;

/// A guest for the RISC0 VM. Implements the `ZkvmGuest` trait
Expand All @@ -17,12 +18,22 @@ impl Risc0Guest {

impl ZkvmGuest for Risc0Guest {
fn read_from_host<T: BorshDeserialize>(&self) -> T {
unimplemented!()
// FIXME: env::read()
// read len(u64) in LE
let mut len_buf = [0u8; 8];
env::read_slice(&mut len_buf);
let len = u64::from_le_bytes(len_buf);
// read buf
let mut buf: Vec<u32> = vec![0; len as usize];
env::read_slice(&mut buf);
let slice: &[u8] = bytemuck::cast_slice(&buf);
// deserialize
BorshDeserialize::deserialize(&mut &*slice).expect("Failed to deserialize input from host")
}

fn commit<T: BorshSerialize>(&self, item: &T) {
unimplemented!()
// FIXME: env::commit(item);
// use risc0_zkvm::guest::env::Write as _;
let buf = borsh::to_vec(item).expect("Serialization to vec is infallible");
let mut journal = env::journal();
journal.write_slice(&buf);
}
}
Loading

0 comments on commit fc22673

Please sign in to comment.