Skip to content

Commit

Permalink
just have it as a string
Browse files Browse the repository at this point in the history
  • Loading branch information
preyasshah committed Jul 17, 2024
1 parent 53d6b20 commit 0a99f23
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 60 deletions.
6 changes: 2 additions & 4 deletions martian-lab/examples/sum_sq/src/sum_squares.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
//! SumSquares stage code
use serde::{Deserialize, Serialize};

// The prelude brings the following items in scope:
// - Traits: MartianMain, MartianStage, RawMartianStage, MartianFileType, MartianMakePath
// - Struct/Enum: MartianRover, Resource, StageDef, MartianVoid,
// Error (from anyhow crate), LevelFilter (from log crate)
// - Macros: martian_stages!
// - Functions: martian_main, martian_main_with_log_level, martian_make_mro
use martian::prelude::*;

// Bring the procedural macros in scope:
// #[derive(MartianStruct)], #[derive(MartianType)], #[make_mro], martian_filetype!
use martian_derive::{make_mro, MartianStruct};
use serde::{Deserialize, Serialize};

// NOTE: The following four structs will serve as the associated type for the
// trait. The struct fields need to be owned and are limited to
Expand Down Expand Up @@ -73,7 +71,7 @@ impl MartianStage for SumSquares {
for value in args.values {
let chunk_inputs = SumSquaresChunkInputs { value };
// It is optional to create a chunk with resource. If not specified, default resource will be used
stage_def.add_chunk_with_resource(chunk_inputs, chunk_resource);
stage_def.add_chunk_with_resource(chunk_inputs, chunk_resource.clone());
}
// Return the stage definition
Ok(stage_def)
Expand Down
63 changes: 7 additions & 56 deletions martian/src/stage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use log::warn;
#[cfg(feature = "rayon")]
use rayon::prelude::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::{Path, PathBuf};
Expand Down Expand Up @@ -117,7 +117,7 @@ impl<T: MartianFileType> MartianMakePath for T {
///
/// Memory/ thread request can be negative in matrian. See
/// [http://martian-lang.org/advanced-features/#resource-consumption](http://martian-lang.org/advanced-features/#resource-consumption)
#[derive(Debug, Serialize, Deserialize, Copy, Clone, Default)]
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
pub struct Resource {
#[serde(rename = "__mem_gb")]
mem_gb: Option<isize>,
Expand All @@ -126,56 +126,7 @@ pub struct Resource {
#[serde(rename = "__vmem_gb")]
vmem_gb: Option<isize>,
#[serde(rename = "__special")]
special: Option<GpuResource>,
}

#[derive(Debug, Copy, Clone, Default)]
struct GpuResource {
count: isize,
mem: isize,
}

impl std::fmt::Display for GpuResource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "gpu_count{}_mem{}", self.count, self.mem)
}
}

impl Serialize for GpuResource {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&self.to_string())
}
}

impl<'de> Deserialize<'de> for GpuResource {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;

if !s.starts_with("gpu_count") || !s.contains("_mem") {
return Err(serde::de::Error::custom("format error"));
}

let count_start = 9; // position after "gpu_count"
let count_end = s
.find("_mem")
.ok_or_else(|| serde::de::Error::custom("format error"))?;
let mem_start = count_end + 4; // position after "_mem"

let count = s[count_start..count_end]
.parse::<isize>()
.map_err(serde::de::Error::custom)?;
let mem = s[mem_start..]
.parse::<isize>()
.map_err(serde::de::Error::custom)?;

Ok(GpuResource { count, mem })
}
special: Option<String>,
}

impl Resource {
Expand Down Expand Up @@ -217,7 +168,7 @@ impl Resource {

/// Get the special resource request
pub fn get_special(&self) -> Option<String> {
self.special.map(|s| s.to_string())
self.special.clone()
}

/// Set the mem_gb
Expand Down Expand Up @@ -272,8 +223,8 @@ impl Resource {
/// assert_eq!(resource.get_threads(), None);
/// assert_eq!(resource.get_special(), "gpu_count1_mem8".to_owned());
/// ```
pub fn special(mut self, count: isize, mem: isize) -> Self {
self.special = Some(GpuResource { count, mem });
pub fn special(mut self, special: String) -> Self {
self.special = Some(special);
self
}

Expand Down Expand Up @@ -682,7 +633,7 @@ pub trait MartianStage: MroMaker {
fill_defaults(resource),
))
}
let rover = _chunk_prelude(chunk_idx, run_directory, chunk.resource)?;
let rover = _chunk_prelude(chunk_idx, run_directory, chunk.resource.clone())?;
self.main(args.clone(), chunk.inputs.clone(), rover)
};

Expand Down

0 comments on commit 0a99f23

Please sign in to comment.