Skip to content

Commit

Permalink
Fix modeled_types tests
Browse files Browse the repository at this point in the history
Tests were failing because the model changed to using SingleLineString but
tests weren't updated and were still using String.

This updates modeled_types to use TryFrom as the primary source of conversion
logic, delegating to it in Deserialize.  This simplifies using the new types in
tests, and tests our logic rather than serde.  Controller tests can then use
try_into to get the expected type.

Because we're no longer diverting through serde, we no longer need to awkwardly
quote data strings in the tests.
  • Loading branch information
tjkirch committed Nov 6, 2019
1 parent 043fe5a commit 488a236
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 69 deletions.
195 changes: 136 additions & 59 deletions workspaces/api/apiserver/src/modeled_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,32 @@
use serde::{Deserialize, Deserializer, Serialize, Serializer};
// Just need serde's Error in scope to get its trait methods
use serde::de::Error as _;
use snafu::{ensure, ResultExt};
use std::borrow::Borrow;
use std::convert::TryFrom;
use std::fmt;
use std::ops::Deref;

pub mod error {
use snafu::Snafu;

#[derive(Debug, Snafu)]
#[snafu(visibility = "pub(super)")]
pub enum Error {
#[snafu(display("Can't create SingleLineString containing line terminator"))]
StringContainsLineTerminator,

#[snafu(display("Invalid base64 input: {}", source))]
InvalidBase64 { source: base64::DecodeError },

#[snafu(display(
"Identifiers may only contain ASCII alphanumerics plus hyphens, received '{}'",
input
))]
InvalidIdentifier { input: String },
}
}

/// ValidBase64 can only be created by deserializing from valid base64 text. It stores the
/// original text, not the decoded form. Its purpose is input validation, namely being used as a
/// field in a model structure so that you don't even accept a request with a field that has
Expand All @@ -20,15 +42,33 @@ pub struct ValidBase64 {
}

/// Validate base64 format before we accept the input.
impl TryFrom<&str> for ValidBase64 {
type Error = error::Error;

fn try_from(input: &str) -> Result<Self, Self::Error> {
base64::decode(&input).context(error::InvalidBase64)?;
Ok(ValidBase64 {
inner: input.to_string(),
})
}
}

impl TryFrom<String> for ValidBase64 {
type Error = error::Error;

fn try_from(input: String) -> Result<Self, Self::Error> {
Self::try_from(input.as_ref())
}
}

impl<'de> Deserialize<'de> for ValidBase64 {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let original = String::deserialize(deserializer)?;
base64::decode(&original)
.map_err(|e| D::Error::custom(format!("Invalid base64: {}", e)))?;
Ok(ValidBase64 { inner: original })
Self::try_from(original)
.map_err(|e| D::Error::custom(format!("Unable to deserialize into ValidBase64: {}", e)))
}
}

Expand Down Expand Up @@ -83,19 +123,19 @@ impl From<ValidBase64> for String {
#[cfg(test)]
mod test_valid_base64 {
use super::ValidBase64;
use std::convert::TryFrom;

#[test]
fn valid_base64() {
let v: ValidBase64 = serde_json::from_str("\"aGk=\"").unwrap();
let v = ValidBase64::try_from("aGk=").unwrap();
let decoded_bytes = base64::decode(v.as_ref()).unwrap();
let decoded = std::str::from_utf8(&decoded_bytes).unwrap();
assert_eq!(decoded, "hi");
}

#[test]
fn invalid_base64() {
assert!(serde_json::from_str::<ValidBase64>("\"invalid base64\"").is_err());
assert!(serde_json::from_str::<ValidBase64>("").is_err());
assert!(ValidBase64::try_from("invalid base64").is_err());
}
}

Expand All @@ -110,14 +150,10 @@ pub struct SingleLineString {
inner: String,
}

/// Validate line count before we accept a deserialization.
impl<'de> Deserialize<'de> for SingleLineString {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let original = String::deserialize(deserializer)?;
impl TryFrom<&str> for SingleLineString {
type Error = error::Error;

fn try_from(input: &str) -> Result<Self, Self::Error> {
// Rust does not treat all Unicode line terminators as starting a new line, so we check for
// specific characters here, rather than just counting from lines().
// https://en.wikipedia.org/wiki/Newline#Unicode
Expand All @@ -130,14 +166,39 @@ impl<'de> Deserialize<'de> for SingleLineString {
'\u{2028}', // line separator
'\u{2029}', // paragraph separator
];
if let Some(term) = original.find(&line_terminators[..]) {
Err(D::Error::custom(format!(
"Can't create SingleLineString with line terminator '{}'",
term,
)))
} else {
Ok(SingleLineString { inner: original })
}

ensure!(
!input.contains(&line_terminators[..]),
error::StringContainsLineTerminator
);

Ok(Self {
inner: input.to_string(),
})
}
}

impl TryFrom<String> for SingleLineString {
type Error = error::Error;

fn try_from(input: String) -> Result<Self, Self::Error> {
Self::try_from(input.as_ref())
}
}

/// Validate line count before we accept a deserialization.
impl<'de> Deserialize<'de> for SingleLineString {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let original = String::deserialize(deserializer)?;
Self::try_from(original).map_err(|e| {
D::Error::custom(format!(
"Unable to deserialize into SingleLineString: {}",
e
))
})
}
}

Expand Down Expand Up @@ -191,29 +252,30 @@ impl From<SingleLineString> for String {
#[cfg(test)]
mod test_single_line_string {
use super::SingleLineString;
use std::convert::TryFrom;

#[test]
fn valid_single_line_string() {
assert!(serde_json::from_str::<SingleLineString>("\"\"").is_ok());
assert!(serde_json::from_str::<SingleLineString>("\"hi\"").is_ok());
assert!(SingleLineString::try_from("").is_ok());
assert!(SingleLineString::try_from("hi").is_ok());
let long_string = std::iter::repeat(" ").take(9999).collect::<String>();
let json_long_string = format!("\"{}\"", &long_string);
assert!(serde_json::from_str::<SingleLineString>(&json_long_string).is_ok());
let json_long_string = format!("{}", &long_string);
assert!(SingleLineString::try_from(json_long_string).is_ok());
}

#[test]
fn invalid_single_line_string() {
assert!(serde_json::from_str::<SingleLineString>("\"Hello\nWorld\"").is_err());
assert!(SingleLineString::try_from("Hello\nWorld").is_err());

assert!(serde_json::from_str::<SingleLineString>("\"\n\"").is_err());
assert!(serde_json::from_str::<SingleLineString>("\"\r\"").is_err());
assert!(serde_json::from_str::<SingleLineString>("\"\r\n\"").is_err());
assert!(SingleLineString::try_from("\n").is_err());
assert!(SingleLineString::try_from("\r").is_err());
assert!(SingleLineString::try_from("\r\n").is_err());

assert!(serde_json::from_str::<SingleLineString>("\"\u{000B}\"").is_err()); // vertical tab
assert!(serde_json::from_str::<SingleLineString>("\"\u{000C}\"").is_err()); // form feed
assert!(serde_json::from_str::<SingleLineString>("\"\u{0085}\"").is_err()); // next line
assert!(serde_json::from_str::<SingleLineString>("\"\u{2028}\"").is_err()); // line separator
assert!(serde_json::from_str::<SingleLineString>("\"\u{2029}\"").is_err());
assert!(SingleLineString::try_from("\u{000B}").is_err()); // vertical tab
assert!(SingleLineString::try_from("\u{000C}").is_err()); // form feed
assert!(SingleLineString::try_from("\u{0085}").is_err()); // next line
assert!(SingleLineString::try_from("\u{2028}").is_err()); // line separator
assert!(SingleLineString::try_from("\u{2029}").is_err());
// paragraph separator
}
}
Expand All @@ -230,24 +292,38 @@ pub struct Identifier {
inner: String,
}

impl TryFrom<&str> for Identifier {
type Error = error::Error;

fn try_from(input: &str) -> Result<Self, Self::Error> {
ensure!(
input
.chars()
.all(|c| (c.is_ascii() && c.is_alphanumeric()) || c == '-'),
error::InvalidIdentifier { input }
);
Ok(Identifier {
inner: input.to_string(),
})
}
}

impl TryFrom<String> for Identifier {
type Error = error::Error;

fn try_from(input: String) -> Result<Self, Self::Error> {
Self::try_from(input.as_ref())
}
}

impl<'de> Deserialize<'de> for Identifier {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let original = String::deserialize(deserializer)?;

if !original
.chars()
.all(|c| (c.is_ascii() && c.is_alphanumeric()) || c == '-')
{
Err(D::Error::custom(format!(
"Identifiers may only contain ASCII alphanumerics plus hyphens; received '{}'",
original,
)))
} else {
Ok(Identifier { inner: original })
}
Self::try_from(original)
.map_err(|e| D::Error::custom(format!("Unable to deserialize into Identifier: {}", e)))
}
}

Expand Down Expand Up @@ -295,25 +371,26 @@ impl fmt::Display for Identifier {
#[cfg(test)]
mod test_valid_identifier {
use super::Identifier;
use std::convert::TryFrom;

#[test]
fn valid_identifier() {
assert!(serde_json::from_str::<Identifier>("\"hello-world\"").is_ok());
assert!(serde_json::from_str::<Identifier>("\"helloworld\"").is_ok());
assert!(serde_json::from_str::<Identifier>("\"123321hello\"").is_ok());
assert!(serde_json::from_str::<Identifier>("\"hello-1234\"").is_ok());
assert!(serde_json::from_str::<Identifier>("\"--------\"").is_ok());
assert!(serde_json::from_str::<Identifier>("\"11111111\"").is_ok());
assert!(Identifier::try_from("hello-world").is_ok());
assert!(Identifier::try_from("helloworld").is_ok());
assert!(Identifier::try_from("123321hello").is_ok());
assert!(Identifier::try_from("hello-1234").is_ok());
assert!(Identifier::try_from("--------").is_ok());
assert!(Identifier::try_from("11111111").is_ok());
}

#[test]
fn invalid_identifier() {
assert!(serde_json::from_str::<Identifier>("\"../\"").is_err());
assert!(serde_json::from_str::<Identifier>("\"{}\"").is_err());
assert!(serde_json::from_str::<Identifier>("\"hello|World\"").is_err());
assert!(serde_json::from_str::<Identifier>("\"hello\nWorld\"").is_err());
assert!(serde_json::from_str::<Identifier>("\"hello_world\"").is_err());
assert!(serde_json::from_str::<Identifier>("\"タール\"").is_err());
assert!(serde_json::from_str::<Identifier>("\"💝\"").is_err());
assert!(Identifier::try_from("../").is_err());
assert!(Identifier::try_from("{}").is_err());
assert!(Identifier::try_from("hello|World").is_err());
assert!(Identifier::try_from("hello\nWorld").is_err());
assert!(Identifier::try_from("hello_world").is_err());
assert!(Identifier::try_from("タール").is_err());
assert!(Identifier::try_from("💝").is_err());
}
}
17 changes: 9 additions & 8 deletions workspaces/api/apiserver/src/server/controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ mod test {
use crate::datastore::{Committed, DataStore, Key, KeyType};
use crate::model::Service;
use maplit::{hashmap, hashset};
use std::convert::TryInto;

#[test]
fn get_settings_works() {
Expand All @@ -391,7 +392,7 @@ mod test {

// Retrieve with helper
let settings = get_settings(&ds, Committed::Live).unwrap();
assert_eq!(settings.hostname, Some("json string".to_string()));
assert_eq!(settings.hostname, Some("json string".try_into().unwrap()));
}

#[test]
Expand All @@ -407,10 +408,10 @@ mod test {

// Retrieve with helper
let settings = get_settings_prefix(&ds, "", Committed::Live).unwrap();
assert_eq!(settings.timezone, Some("json string".to_string()));
assert_eq!(settings.timezone, Some("json string".try_into().unwrap()));

let settings = get_settings_prefix(&ds, "tim", Committed::Live).unwrap();
assert_eq!(settings.timezone, Some("json string".to_string()));
assert_eq!(settings.timezone, Some("json string".try_into().unwrap()));

let settings = get_settings_prefix(&ds, "timbits", Committed::Live).unwrap();
assert_eq!(settings.timezone, None);
Expand All @@ -437,7 +438,7 @@ mod test {
// Retrieve with helper
let settings =
get_settings_keys(&ds, &hashset!("settings.timezone"), Committed::Live).unwrap();
assert_eq!(settings.timezone, Some("json string 1".to_string()));
assert_eq!(settings.timezone, Some("json string 1".try_into().unwrap()));
assert_eq!(settings.hostname, None);
}

Expand All @@ -464,7 +465,7 @@ mod test {
assert_eq!(
services,
hashmap!("foo".to_string() => Service {
configuration_files: vec!["file1".to_string()],
configuration_files: vec!["file1".try_into().unwrap()],
restart_commands: vec!["echo hi".to_string()]
})
);
Expand All @@ -473,7 +474,7 @@ mod test {
#[test]
fn set_settings_works() {
let mut settings = Settings::default();
settings.timezone = Some("tz".to_string());
settings.timezone = Some("tz".try_into().unwrap());

// Set with helper
let mut ds = MemoryDataStore::new();
Expand Down Expand Up @@ -547,7 +548,7 @@ mod test {

// Confirm pending
let settings = get_settings(&ds, Committed::Pending).unwrap();
assert_eq!(settings.hostname, Some("json string".to_string()));
assert_eq!(settings.hostname, Some("json string".try_into().unwrap()));
// No live settings yet
get_settings(&ds, Committed::Live).unwrap_err();

Expand All @@ -558,6 +559,6 @@ mod test {
get_settings(&ds, Committed::Pending).unwrap_err();
// Confirm live
let settings = get_settings(&ds, Committed::Live).unwrap();
assert_eq!(settings.hostname, Some("json string".to_string()));
assert_eq!(settings.hostname, Some("json string".try_into().unwrap()));
}
}
5 changes: 3 additions & 2 deletions workspaces/api/thar-be-settings/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,17 @@ impl RenderedConfigFile {
mod test {
use super::*;
use maplit::{hashmap, hashset};
use std::convert::TryInto;

#[test]
fn test_get_config_file_names() {
let input_map = hashmap!(
"foo".to_string() => model::Service {
configuration_files: vec!["file1".to_string()],
configuration_files: vec!["file1".try_into().unwrap()],
restart_commands: vec!["echo hi".to_string()]
},
"bar".to_string() => model::Service {
configuration_files: vec!["file1".to_string(), "file2".to_string()],
configuration_files: vec!["file1".try_into().unwrap(), "file2".try_into().unwrap()],
restart_commands: vec!["echo hi".to_string()]
},
);
Expand Down

0 comments on commit 488a236

Please sign in to comment.