From 0a99f235d14e493ffae4d0dc86ed3d5dcc35d688 Mon Sep 17 00:00:00 2001 From: Preyas Shah Date: Tue, 16 Jul 2024 23:06:56 -0700 Subject: [PATCH] just have it as a string --- .../examples/sum_sq/src/sum_squares.rs | 6 +- martian/src/stage.rs | 63 +++---------------- 2 files changed, 9 insertions(+), 60 deletions(-) diff --git a/martian-lab/examples/sum_sq/src/sum_squares.rs b/martian-lab/examples/sum_sq/src/sum_squares.rs index 87f7999241..5a046c9404 100644 --- a/martian-lab/examples/sum_sq/src/sum_squares.rs +++ b/martian-lab/examples/sum_sq/src/sum_squares.rs @@ -1,7 +1,5 @@ //! SumSquares stage code -use serde::{Deserialize, Serialize}; - // The prelude brings the following items in scope: // - Traits: MartianMain, MartianStage, RawMartianStage, MartianFileType, MartianMakePath // - Struct/Enum: MartianRover, Resource, StageDef, MartianVoid, @@ -9,10 +7,10 @@ use serde::{Deserialize, Serialize}; // - Macros: martian_stages! // - Functions: martian_main, martian_main_with_log_level, martian_make_mro use martian::prelude::*; - // Bring the procedural macros in scope: // #[derive(MartianStruct)], #[derive(MartianType)], #[make_mro], martian_filetype! use martian_derive::{make_mro, MartianStruct}; +use serde::{Deserialize, Serialize}; // NOTE: The following four structs will serve as the associated type for the // trait. The struct fields need to be owned and are limited to @@ -73,7 +71,7 @@ impl MartianStage for SumSquares { for value in args.values { let chunk_inputs = SumSquaresChunkInputs { value }; // It is optional to create a chunk with resource. If not specified, default resource will be used - stage_def.add_chunk_with_resource(chunk_inputs, chunk_resource); + stage_def.add_chunk_with_resource(chunk_inputs, chunk_resource.clone()); } // Return the stage definition Ok(stage_def) diff --git a/martian/src/stage.rs b/martian/src/stage.rs index 81ba68edd3..d7bc25a515 100644 --- a/martian/src/stage.rs +++ b/martian/src/stage.rs @@ -6,7 +6,7 @@ use log::warn; #[cfg(feature = "rayon")] use rayon::prelude::*; use serde::de::DeserializeOwned; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde::{Deserialize, Serialize}; use std::fs::File; use std::io::{BufReader, BufWriter}; use std::path::{Path, PathBuf}; @@ -117,7 +117,7 @@ impl MartianMakePath for T { /// /// Memory/ thread request can be negative in matrian. See /// [http://martian-lang.org/advanced-features/#resource-consumption](http://martian-lang.org/advanced-features/#resource-consumption) -#[derive(Debug, Serialize, Deserialize, Copy, Clone, Default)] +#[derive(Debug, Serialize, Deserialize, Clone, Default)] pub struct Resource { #[serde(rename = "__mem_gb")] mem_gb: Option, @@ -126,56 +126,7 @@ pub struct Resource { #[serde(rename = "__vmem_gb")] vmem_gb: Option, #[serde(rename = "__special")] - special: Option, -} - -#[derive(Debug, Copy, Clone, Default)] -struct GpuResource { - count: isize, - mem: isize, -} - -impl std::fmt::Display for GpuResource { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "gpu_count{}_mem{}", self.count, self.mem) - } -} - -impl Serialize for GpuResource { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_str(&self.to_string()) - } -} - -impl<'de> Deserialize<'de> for GpuResource { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let s = String::deserialize(deserializer)?; - - if !s.starts_with("gpu_count") || !s.contains("_mem") { - return Err(serde::de::Error::custom("format error")); - } - - let count_start = 9; // position after "gpu_count" - let count_end = s - .find("_mem") - .ok_or_else(|| serde::de::Error::custom("format error"))?; - let mem_start = count_end + 4; // position after "_mem" - - let count = s[count_start..count_end] - .parse::() - .map_err(serde::de::Error::custom)?; - let mem = s[mem_start..] - .parse::() - .map_err(serde::de::Error::custom)?; - - Ok(GpuResource { count, mem }) - } + special: Option, } impl Resource { @@ -217,7 +168,7 @@ impl Resource { /// Get the special resource request pub fn get_special(&self) -> Option { - self.special.map(|s| s.to_string()) + self.special.clone() } /// Set the mem_gb @@ -272,8 +223,8 @@ impl Resource { /// assert_eq!(resource.get_threads(), None); /// assert_eq!(resource.get_special(), "gpu_count1_mem8".to_owned()); /// ``` - pub fn special(mut self, count: isize, mem: isize) -> Self { - self.special = Some(GpuResource { count, mem }); + pub fn special(mut self, special: String) -> Self { + self.special = Some(special); self } @@ -682,7 +633,7 @@ pub trait MartianStage: MroMaker { fill_defaults(resource), )) } - let rover = _chunk_prelude(chunk_idx, run_directory, chunk.resource)?; + let rover = _chunk_prelude(chunk_idx, run_directory, chunk.resource.clone())?; self.main(args.clone(), chunk.inputs.clone(), rover) };