Skip to content

Commit

Permalink
Remove separate validation array
Browse files Browse the repository at this point in the history
  • Loading branch information
mkeeter committed Oct 30, 2024
1 parent 134c114 commit ff1d036
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 125 deletions.
36 changes: 16 additions & 20 deletions upstairs/src/buffer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// Copyright 2023 Oxide Computer Company
use crate::RawReadResponse;
use bytes::{Bytes, BytesMut};
use crucible_protocol::ReadBlockContext;
use itertools::Itertools;
Expand Down Expand Up @@ -141,14 +140,17 @@ impl Buffer {
/// # Panics
/// The response data length must be the same as our buffer length (which
/// must be an even multiple of block size, ensured at construction)
pub(crate) fn write_read_response(&mut self, response: RawReadResponse) {
assert!(response.data.len() == self.data.len());
assert_eq!(response.data.len() % self.block_size, 0);
pub(crate) fn write_read_response(
&mut self,
blocks: &[ReadBlockContext],
data: &mut BytesMut,
) {
assert!(data.len() == self.data.len());
assert_eq!(data.len() % self.block_size, 0);
let bs = self.block_size;

// Build contiguous chunks which are all owned, to copy in bulk
for (empty, mut group) in &response
.blocks
for (empty, mut group) in &blocks
.iter()
.enumerate()
.chunk_by(|(_i, b)| matches!(b, ReadBlockContext::Empty))
Expand All @@ -164,16 +166,13 @@ impl Buffer {

// Special case: if the entire buffer is owned, then we swap it
// instead of copying element-by-element.
if count == response.blocks.len()
&& self.data.len() == response.data.len()
{
self.data = response.data;
if count == blocks.len() && self.data.len() == data.len() {
self.data = std::mem::take(data);
break;
} else {
// Otherwise, just copy the sub-region
self.data[(block * bs)..][..(count * bs)].copy_from_slice(
&response.data[(block * bs)..][..(count * bs)],
);
self.data[(block * bs)..][..(count * bs)]
.copy_from_slice(&data[(block * bs)..][..(count * bs)]);
}
}
}
Expand Down Expand Up @@ -493,7 +492,7 @@ mod test {
let mut rng = rand::thread_rng();
rng.fill_bytes(&mut data);

let blocks = (0..10)
let blocks: Vec<_> = (0..10)
.map(|i| {
if f(i) {
ReadBlockContext::Unencrypted { hash: 123 }
Expand All @@ -503,10 +502,7 @@ mod test {
})
.collect();

buf.write_read_response(RawReadResponse {
blocks,
data: data.clone(),
});
buf.write_read_response(&blocks, &mut data.clone());

for i in 0..10 {
let buf_chunk = &buf[i * 512..][..512];
Expand Down Expand Up @@ -564,12 +560,12 @@ mod test {
let mut rng = rand::thread_rng();
rng.fill_bytes(&mut data);

let blocks = (0..10)
let blocks: Vec<_> = (0..10)
.map(|_| ReadBlockContext::Unencrypted { hash: 123 })
.collect();

let prev_data_ptr = data.as_ptr();
buf.write_read_response(RawReadResponse { blocks, data });
buf.write_read_response(&blocks, &mut data);

assert_eq!(buf.data.as_ptr(), prev_data_ptr);
}
Expand Down
41 changes: 17 additions & 24 deletions upstairs/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use crate::{
ClientIOStateCount, ClientId, CrucibleDecoder, CrucibleError, DownstairsIO,
DsState, EncryptionContext, IOState, IOop, JobId, Message, RawReadResponse,
ReconcileIO, ReconcileIOState, RegionDefinitionStatus, RegionMetadata,
Validation,
};
use crucible_common::{x509::TLSContext, ExtentId, VerboseTimeout};
use crucible_protocol::{
Expand Down Expand Up @@ -1268,7 +1267,6 @@ impl DownstairsClient {
ds_id: JobId,
job: &mut DownstairsIO,
responses: Result<RawReadResponse, CrucibleError>,
read_validations: Vec<Validation>,
deactivate: bool,
extent_info: Option<ExtentInfo>,
) -> bool {
Expand Down Expand Up @@ -1399,7 +1397,7 @@ impl DownstairsClient {
*/
let read_data = responses.unwrap();
assert!(!read_data.blocks.is_empty());
if job.read_validations != read_validations {
if job.data.as_ref().unwrap().blocks != read_data.blocks {
// XXX This error needs to go to Nexus
// XXX This will become the "force all downstairs
// to stop and refuse to restart" mode.
Expand All @@ -1413,8 +1411,8 @@ impl DownstairsClient {
self.client_id,
ds_id,
self.cfg.session_id,
job.read_validations,
read_validations,
job.data.as_ref().unwrap().blocks,
read_data.blocks,
start_eid,
start_offset,
job.state,
Expand Down Expand Up @@ -1461,9 +1459,7 @@ impl DownstairsClient {
assert!(extent_info.is_none());
if jobs_completed_ok == 1 {
assert!(job.data.is_none());
assert!(job.read_validations.is_empty());
job.data = Some(read_data);
job.read_validations = read_validations;
assert!(!job.acked);
ackable = true;
debug!(self.log, "Read AckReady {}", ds_id.0);
Expand All @@ -1475,7 +1471,8 @@ impl DownstairsClient {
* that and verify they are the same.
*/
debug!(self.log, "Read already AckReady {ds_id}");
if job.read_validations != read_validations {
let job_blocks = &job.data.as_ref().unwrap().blocks;
if job_blocks != &read_data.blocks {
// XXX This error needs to go to Nexus
// XXX This will become the "force all downstairs
// to stop and refuse to restart" mode.
Expand All @@ -1486,8 +1483,8 @@ impl DownstairsClient {
job: {:?}",
self.client_id,
ds_id,
job.read_validations,
read_validations,
job_blocks,
read_data.blocks,
job,
);
}
Expand Down Expand Up @@ -2995,18 +2992,15 @@ fn update_net_done_probes(m: &Message, cid: ClientId) {
}

/// Returns:
/// - `Ok(Some(ctx))` for successfully decrypted data
/// - `Ok(None)` if there is no block context and the block is all 0
/// - `Ok(())` for successfully decrypted data, or if there is no block context
/// and the block is all 0s (i.e. a valid empty block)
/// - `Err(..)` otherwise
///
/// The return value of this will be stored with the job, and compared
/// between each read.
pub(crate) fn validate_encrypted_read_response(
block_context: Option<crucible_protocol::EncryptionContext>,
data: &mut [u8],
encryption_context: &EncryptionContext,
log: &Logger,
) -> Result<Validation, CrucibleError> {
) -> Result<(), CrucibleError> {
// XXX because we don't have block generation numbers, an attacker
// downstairs could:
//
Expand All @@ -3028,7 +3022,7 @@ pub(crate) fn validate_encrypted_read_response(
//
// XXX if it's not a blank block, we may be under attack?
if data.iter().all(|&x| x == 0) {
return Ok(Validation::Empty);
return Ok(());
} else {
error!(log, "got empty block context with non-blank block");
return Err(CrucibleError::MissingBlockContext);
Expand All @@ -3050,29 +3044,28 @@ pub(crate) fn validate_encrypted_read_response(
Tag::from_slice(&ctx.tag[..]),
);
if decryption_result.is_ok() {
Ok(Validation::Encrypted(ctx))
Ok(())
} else {
error!(log, "Decryption failed!");
Err(CrucibleError::DecryptionError)
}
}

/// Returns:
/// - Ok(Some(valid_hash)) where the integrity hash matches
/// - Ok(None) where there is no integrity hash in the response and the
/// block is all 0
/// - Ok(()) where the integrity hash matches (or the integrity hash is missing
/// and the block is all 0s, indicating an empty block)
/// - Err otherwise
pub(crate) fn validate_unencrypted_read_response(
block_hash: Option<u64>,
data: &mut [u8],
log: &Logger,
) -> Result<Validation, CrucibleError> {
) -> Result<(), CrucibleError> {
if let Some(hash) = block_hash {
// check integrity hashes - make sure it is correct
let computed_hash = integrity_hash(&[data]);

if computed_hash == hash {
Ok(Validation::Unencrypted(computed_hash))
Ok(())
} else {
// No integrity hash was correct for this response
error!(log, "No match computed hash:0x{:x}", computed_hash,);
Expand All @@ -3096,7 +3089,7 @@ pub(crate) fn validate_unencrypted_read_response(
//
// XXX if it's not a blank block, we may be under attack?
if data[..].iter().all(|&x| x == 0) {
Ok(Validation::Empty)
Ok(())
} else {
error!(log, "got empty block context with non-blank block");
Err(CrucibleError::MissingBlockContext)
Expand Down
21 changes: 6 additions & 15 deletions upstairs/src/deferred.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::sync::Arc;
use crate::{
backpressure::BackpressureGuard, client::ConnectionId,
upstairs::UpstairsConfig, BlockContext, BlockOp, ClientData, ClientId,
ImpactedBlocks, Message, RawWrite, Validation,
ImpactedBlocks, Message, RawWrite,
};
use bytes::BytesMut;
use crucible_common::{integrity_hash, CrucibleError, RegionDefinition};
Expand Down Expand Up @@ -194,9 +194,6 @@ impl DeferredWrite {
pub(crate) struct DeferredMessage {
pub message: Message,

/// If this was a `ReadResponse`, then the validation result is stored here
pub hashes: Vec<Validation>,

pub client_id: ClientId,

/// See `DeferredRead::connection_id`
Expand Down Expand Up @@ -225,7 +222,7 @@ impl DeferredRead {
/// Consume the `DeferredRead` and perform decryption
///
/// If decryption fails, then the resulting `Message` has an error in the
/// `responses` field, and `hashes` is empty.
/// `responses` field.
pub fn run(mut self) -> DeferredMessage {
use crate::client::{
validate_encrypted_read_response,
Expand All @@ -234,7 +231,6 @@ impl DeferredRead {
let Message::ReadResponse { header, data } = &mut self.message else {
panic!("invalid DeferredRead");
};
let mut hashes = vec![];

if let Ok(rs) = header.blocks.as_mut() {
assert_eq!(data.len() % rs.len(), 0);
Expand Down Expand Up @@ -284,14 +280,10 @@ impl DeferredRead {
)
})
};
match v {
Ok(hash) => hashes.push(hash),
Err(e) => {
error!(self.log, "decryption failure: {e:?}");
header.blocks = Err(e);
hashes.clear();
break;
}
if let Err(e) = v {
error!(self.log, "decryption failure: {e:?}");
header.blocks = Err(e);
break;
}
}
}
Expand All @@ -300,7 +292,6 @@ impl DeferredRead {
client_id: self.client_id,
message: self.message,
connection_id: self.connection_id,
hashes,
}
}
}
27 changes: 5 additions & 22 deletions upstairs/src/downstairs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::{
DownstairsMend, DsState, ExtentFix, ExtentRepairIDs, IOState, IOStateCount,
IOop, ImpactedBlocks, JobId, Message, RawReadResponse, RawWrite,
ReconcileIO, ReconciliationId, RegionDefinition, ReplaceResult,
SnapshotDetails, Validation, WorkSummary,
SnapshotDetails, WorkSummary,
};
use crucible_common::{
impacted_blocks::ImpactedAddr, BlockIndex, BlockOffset, ExtentId,
Expand Down Expand Up @@ -425,8 +425,6 @@ impl Downstairs {
let done = self.ds_active.get_mut(&ds_id).unwrap();
assert!(!done.acked);

let data = done.data.take();

done.acked = true;
let r = done.result();

Expand Down Expand Up @@ -476,6 +474,10 @@ impl Downstairs {
// Copy (if present) read data back to the guest buffer they
// provided to us, and notify any waiters.
if let Some(res) = done.res.take() {
let data = done
.data
.as_mut()
.map(|v| (v.blocks.as_slice(), &mut v.data));
res.transfer_and_notify(data, r);
}

Expand Down Expand Up @@ -2141,7 +2143,6 @@ impl Downstairs {
res,
replay: false,
data: None,
read_validations: vec![],
backpressure_guard: bp_guard,
},
);
Expand Down Expand Up @@ -2817,7 +2818,6 @@ impl Downstairs {
&mut self,
client_id: ClientId,
m: Message,
read_validations: Vec<Validation>,
up_state: &UpstairsState,
) -> Result<(), CrucibleError> {
let (upstairs_id, session_id, ds_id, read_data, extent_info) = match m {
Expand Down Expand Up @@ -3068,7 +3068,6 @@ impl Downstairs {
ds_id,
client_id,
read_data,
read_validations,
up_state,
extent_info,
);
Expand Down Expand Up @@ -3115,17 +3114,10 @@ impl Downstairs {
) -> bool {
let was_ackable = self.ackable_work.contains(&ds_id);

// Make up dummy values for hashes, since they're not actually checked
// here (besides confirming that we have the correct number).
let hashes = match &responses {
Ok(r) => vec![Validation::Unencrypted(0); r.blocks.len()],
Err(..) => vec![],
};
self.process_io_completion_inner(
ds_id,
client_id,
responses,
hashes,
up_state,
extent_info,
);
Expand All @@ -3139,7 +3131,6 @@ impl Downstairs {
ds_id: JobId,
client_id: ClientId,
responses: Result<RawReadResponse, CrucibleError>,
read_validations: Vec<Validation>,
up_state: &UpstairsState,
extent_info: Option<ExtentInfo>,
) {
Expand Down Expand Up @@ -3168,18 +3159,10 @@ impl Downstairs {
return;
};

// Sanity-checking for a programmer error during offloaded decryption.
// If we didn't get one hash per read block, then `responses` must
// have been converted into `Err(..)`.
if let Ok(reads) = &responses {
assert_eq!(reads.blocks.len(), read_validations.len());
}

if self.clients[client_id].process_io_completion(
ds_id,
job,
responses,
read_validations,
deactivate,
extent_info,
) {
Expand Down
Loading

0 comments on commit ff1d036

Please sign in to comment.