Skip to content

Commit

Permalink
Reinstate brace checks implemented in huggingface#306
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Dec 18, 2023
1 parent a2afb0a commit 39e7aba
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions safetensors/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,9 @@ impl<'data> SafeTensors<'data> {
let string =
std::str::from_utf8(&buffer[8..stop]).map_err(|_| SafeTensorError::InvalidHeader)?;
// Assert the string starts with {
// NOTE: Add when we move to 0.4.0
// if !string.starts_with('{') {
// return Err(SafeTensorError::InvalidHeaderStart);
// }
if !string.starts_with('{') {
return Err(SafeTensorError::InvalidHeaderStart);
}
let metadata: Metadata = serde_json::from_str(string)
.map_err(|_| SafeTensorError::InvalidHeaderDeserialization)?;
let buffer_end = metadata.validate()?;
Expand Down Expand Up @@ -1134,18 +1133,17 @@ mod tests {
assert_eq!(loaded.len(), 0);
}

// Reserver for 0.4.0
// #[test]
// /// Test that the JSON header must begin with a `{` character.
// fn test_whitespace_start_padded_header_is_not_allowed() {
// let serialized = b"\x06\x00\x00\x00\x00\x00\x00\x00\x09\x0A{}\x0D\x20";
// match SafeTensors::deserialize(serialized) {
// Err(SafeTensorError::InvalidHeaderStart) => {
// // Correct error
// }
// _ => panic!("This should not be able to be deserialized"),
// }
// }
#[test]
/// Test that the JSON header must begin with a `{` character.
fn test_whitespace_start_padded_header_is_not_allowed() {
let serialized = b"\x06\x00\x00\x00\x00\x00\x00\x00\x09\x0A{}\x0D\x20";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::InvalidHeaderStart) => {
// Correct error
}
_ => panic!("This should not be able to be deserialized"),
}
}

#[test]
fn test_zero_sized_tensor() {
Expand Down

0 comments on commit 39e7aba

Please sign in to comment.