Skip to content

Commit

Permalink
Merge pull request #1 from antithesishq/code-review
Browse files Browse the repository at this point in the history
Initial code review
  • Loading branch information
herzogp authored Apr 25, 2024
2 parents 4b744c8 + f36e994 commit 6217f9d
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 138 deletions.
14 changes: 7 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ libloading = "0.8.3"
libc = "0.2.153"
rand = "0.8.5"
rustc_version_runtime = "0.3.0"
lazy_static = "1.4.0"
once_cell = "1"
linkme = "0.3"
paste = "1.0.14"

Expand Down
56 changes: 19 additions & 37 deletions lib/src/assert/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use once_cell::sync::Lazy;
use serde_json::{Value, json};
use std::collections::HashMap;
use crate::internal;
Expand Down Expand Up @@ -171,46 +172,29 @@ impl AssertionInfo {

// Verify that the TrackingInfo for self in
// ASSERT_TRACKER has been updated according to self.condition
fn track_entry(&self, tracker: &mut HashMap<String, TrackingInfo>) {

fn track_entry(&self) {
// Requirement: Catalog entries must always will emit()
if !self.hit {
self.emit();
return
}

let tracking_key = self.id.clone();

// Establish TrackingInfo for this trackingKey when needed
let maybe_info = tracker.get(&tracking_key);
if maybe_info.is_none() {
tracker.insert(self.id.clone(), TrackingInfo::new());
}

// Record the condition in the associated TrackingInfo entry
let condition = self.condition;
let tracked_entry = tracker.entry(self.id.clone());
tracked_entry.and_modify(|e: &mut TrackingInfo| {
if condition {
e.pass_count += 1;
} else {
e.fail_count += 1;
}
});

// Really emit the assertion when first seeing a condition
if let Some(tracking_info) = tracker.get(&tracking_key) {
let pass_count = tracking_info.pass_count;
let fail_count = tracking_info.fail_count;
if condition {
if pass_count == 1 {
self.emit()
}
} else if fail_count == 1 {
self.emit()
}
let mut tracker = ASSERT_TRACKER.lock().unwrap();
let info = tracker.entry(self.id.clone()).or_default();
// Record the condition in the associated TrackingInfo entry,
// and emit the assertion when first seeing a condition
let emitting = if self.condition {
info.pass_count += 1;
info.pass_count == 1
} else {
info.fail_count += 1;
info.fail_count == 1
};

drop(tracker); // release the lock asap
if emitting {
self.emit();
}
}

fn emit(&self) {
Expand Down Expand Up @@ -238,9 +222,7 @@ impl AssertionInfo {
}
}

lazy_static!{
static ref ASSERT_TRACKER: Mutex<HashMap<String, TrackingInfo>> = Mutex::new(HashMap::new());
}
static ASSERT_TRACKER: Lazy<Mutex<HashMap<String, TrackingInfo>>> = Lazy::new(|| Mutex::new(HashMap::new()));

#[allow(clippy::too_many_arguments)]
pub fn assert_impl(
Expand All @@ -260,7 +242,7 @@ pub fn assert_impl(


let assertion = AssertionInfo::new(assert_type, display_type, condition, message, class, function, file, begin_line, begin_column, hit, must_hit, id, details);
let _ = &assertion.track_entry(ASSERT_TRACKER.lock().as_deref_mut().unwrap());
let _ = &assertion.track_entry();
}

#[cfg(test)]
Expand Down Expand Up @@ -617,7 +599,7 @@ mod tests {
let mut tracking_data = TrackingInfo::new();

let tracking_key: String = key.to_owned();
match ASSERT_TRACKER.lock().as_deref().unwrap().get(&tracking_key) {
match ASSERT_TRACKER.lock().unwrap().get(&tracking_key) {
None => tracking_data,
Some(ti) => {
tracking_data.pass_count = ti.pass_count;
Expand Down
23 changes: 12 additions & 11 deletions lib/src/internal/local_handler.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use serde_json::{Value};
use std::env;
use std::fs::File;
use std::io::{Write, BufWriter, Error};
use std::io::{Write, Error};

use crate::internal::{LibHandler};

static LOCAL_OUTPUT: &str = "ANTITHESIS_SDK_LOCAL_OUTPUT";
const LOCAL_OUTPUT: &str = "ANTITHESIS_SDK_LOCAL_OUTPUT";

// #[allow(dead_code)]
pub struct LocalHandler {
maybe_writer: Option<BufWriter<Box<dyn Write + Send>>>
maybe_writer: Option<File>
}

impl LocalHandler {
Expand All @@ -27,7 +26,7 @@ impl LocalHandler {
// Seems like LocalHandler gets bound to a reference with
// a 'static lifetime.
LocalHandler{
maybe_writer: Some(BufWriter::with_capacity(0, Box::new(f)))
maybe_writer: Some(f)
}
} else {
eprintln!("Unable to write to '{}' - {}", filename.as_str(), create_result.unwrap_err());
Expand All @@ -39,13 +38,15 @@ impl LocalHandler {
}

impl LibHandler for LocalHandler {
fn output(&mut self, value: &Value) -> Result<(), Error> {
let maybe_writer = self.maybe_writer.as_mut();
match maybe_writer {
fn output(&self, value: &Value) -> Result<(), Error> {
match &self.maybe_writer {
Some(b2w) => {
let mut text_line = value.to_string();
text_line.push('\n');
b2w.write_all(text_line.as_bytes())?;
let mut b2w = b2w;
// The compact Display impl (selected using `{}`) of `serde_json::Value` contains no newlines,
// hence we are outputing valid JSONL format here.
// Using the `{:#}` format specifier may results in extra newlines and indentation.
// See https://docs.rs/serde_json/latest/serde_json/enum.Value.html#impl-Display-for-Value.
writeln!(b2w, "{}", value)?;
b2w.flush()?;
Ok(())
},
Expand Down
45 changes: 13 additions & 32 deletions lib/src/internal/mod.rs
Original file line number Diff line number Diff line change
@@ -1,52 +1,34 @@
use once_cell::sync::Lazy;
use rustc_version_runtime::version;
use serde_json::{Value, json};
use std::io::{Error};
use std::sync::Mutex;
use local_handler::LocalHandler;
use voidstar_handler::{VoidstarHandler, has_voidstar};
use voidstar_handler::{VoidstarHandler};

mod local_handler;
mod voidstar_handler;

// Hardly ever changes, refers to the underlying JSON representation
static PROTOCOL_VERSION: &str = "1.0.0";
const PROTOCOL_VERSION: &str = "1.0.0";

// Tracks SDK releases
static SDK_VERSION: &str = "0.1.1";
const SDK_VERSION: &str = "0.1.1";

// static mut LIB_HANDLER: Option<Box<dyn LibHandler>> = None;
static LIB_HANDLER: Mutex<Option<Box<dyn LibHandler + Send>>> = Mutex::new(None);
static LIB_HANDLER: Lazy<Box<dyn LibHandler + Sync + Send>> = Lazy::new(|| {
match VoidstarHandler::try_load() {
Ok(handler) => Box::new(handler),
Err(_) => Box::new(LocalHandler::new()),
}
});

trait LibHandler {
fn output(&mut self, value: &Value) -> Result<(), Error>;
fn output(&self, value: &Value) -> Result<(), Error>;
fn random(&self) -> u64;
}

fn instantiate_handler() {
if LIB_HANDLER.lock().unwrap().is_some() {
return
}

let lh : Box<dyn LibHandler + Send> = if has_voidstar() {
Box::new(VoidstarHandler::new())
} else {
Box::new(LocalHandler::new())
};

{
let mut x = LIB_HANDLER.lock().unwrap();
*x = Some(lh);
}

let sdk_value: Value = sdk_info();
dispatch_output(&sdk_value)
}

// Made public so it can be invoked from the antithesis_sdk_rust::random module
pub fn dispatch_random() -> u64 {
instantiate_handler();
LIB_HANDLER.lock().unwrap().as_ref().unwrap().random()

LIB_HANDLER.random()
}

// Ignore any and all errors - either the output is completed,
Expand All @@ -66,8 +48,7 @@ pub fn dispatch_random() -> u64 {
// Made public so it can be invoked from the antithesis_sdk_rust::lifecycle
// and antithesis_sdk_rust::assert module
pub fn dispatch_output(json_data: &Value) {
instantiate_handler();
let _ = LIB_HANDLER.lock().unwrap().as_mut().unwrap().output(json_data);
let _ = LIB_HANDLER.output(json_data);
}

fn sdk_info() -> Value {
Expand Down
68 changes: 30 additions & 38 deletions lib/src/internal/voidstar_handler.rs
Original file line number Diff line number Diff line change
@@ -1,62 +1,54 @@
use libc::c_char;
use libloading::{Library, Symbol};
use libc::{c_char, size_t};
use libloading::{Library};
use serde_json::{Value};
use std::ffi::{CString};
use std::io::{Error};
use std::sync::{Once, Mutex, Arc};

use crate::internal::{LibHandler};

static LIB_NAME: &str = "/usr/lib/libmockstar.so";
const LIB_NAME: &str = "/usr/lib/libmockstar.so";

pub fn has_voidstar() -> bool {
load_voidstar();
LIB_VOIDSTAR.lock().unwrap().is_some()
}

static LIB_VOIDSTAR: Mutex<Option<Arc<Library>>> = Mutex::new(None);

fn load_voidstar() {
static LOAD_VOIDSTAR: Once = Once::new();
LOAD_VOIDSTAR.call_once(|| {
let result = unsafe {
Library::new(LIB_NAME)
};
let mut lib_voidstar = LIB_VOIDSTAR.lock().unwrap();
*lib_voidstar = result.ok().map(Arc::new);
});
}

#[derive(Debug)]
pub struct VoidstarHandler {
voidstar_lib: Arc<Library>,
// Not used directly but exists to ensure the library is loaded
// and all the following function pointers points to valid memory.
_lib: Library,
// SAFETY: The memory pointed by `s` must be valid up to `l` bytes.
fuzz_json_data: unsafe fn(s: *const c_char, l: size_t),
fuzz_get_random: fn() -> u64,
}

impl VoidstarHandler {
pub fn new() -> Self {
load_voidstar();
let lib = LIB_VOIDSTAR.lock().unwrap().as_ref().unwrap().clone();
VoidstarHandler{
voidstar_lib: lib,
pub fn try_load() -> Result<Self, libloading::Error> {
// SAFETY:
// - The `libvoidstar`/`libmockstar `libraries that we intended to load
// should not have initalization procedures that requires special arrangments at loading time.
// Otherwise, loading an arbitrary library that happens to be at `LIB_NAME` is an unsupported case.
// - Similarly, we load symbols by names and assume they have the expected signatures,
// and loading arbitrary symbols that happen to take those names are unsupported.
// - `fuzz_json_data` and `fuzz_get_random` copy the function pointers,
// but they would be valid as we bind their lifetime to the library they are from
// by storing all of them in the `VoidstarHandler` struct.
unsafe {
let lib = Library::new(LIB_NAME)?;
let fuzz_json_data = *lib.get(b"fuzz_json_data\0")?;
let fuzz_get_random = *lib.get(b"fuzz_get_random\0")?;
Ok(VoidstarHandler { _lib: lib, fuzz_json_data, fuzz_get_random })
}
}
}

impl LibHandler for VoidstarHandler {
fn output(&mut self, value: &Value) -> Result<(), Error> {
fn output(&self, value: &Value) -> Result<(), Error> {
let payload = value.to_string();
// SAFETY: The data pointer and length passed into `fuzz_json_data` points to valid memory
// that we just initialized above.
unsafe {
let json_data_func: Symbol<unsafe extern fn(s: *const c_char)> = self.voidstar_lib.get(b"fuzz_json_data").unwrap();
let payload = CString::new(value.to_string())?;
json_data_func(payload.as_ptr());
(self.fuzz_json_data)(payload.as_bytes().as_ptr() as *const c_char, payload.len());
}
Ok(())
}

fn random(&self) -> u64 {
unsafe {
let get_random_func: Symbol<unsafe extern fn() -> u64> = self.voidstar_lib.get(b"fuzz_get_random").unwrap();
get_random_func()
}
(self.fuzz_get_random)()
}
}

3 changes: 0 additions & 3 deletions lib/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#[macro_use]
extern crate lazy_static;

pub mod assert;
pub mod lifecycle;
pub mod random;
Expand Down
Loading

0 comments on commit 6217f9d

Please sign in to comment.