Skip to content

Commit

Permalink
feat(mro): Improve martian_make_mro (#498)
Browse files Browse the repository at this point in the history
Include the filename in the error message of martian_make_mro.
Use current_executable in get_generator_name.
Deref symlinks in current_executable.
  • Loading branch information
sjackman authored Jun 12, 2024
1 parent ac52194 commit 13e107e
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 48 deletions.
78 changes: 35 additions & 43 deletions martian/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,23 @@
//!
//! ## Documentation
//! For a guide style documentation and examples, visit: [https://martian-lang.github.io/martian-rust/](https://martian-lang.github.io/martian-rust/#/)
//!
pub use anyhow::Error;
use anyhow::{format_err, Context};
use anyhow::{ensure, Context, Result};
use backtrace::Backtrace;
use log::{error, info};

use std::collections::HashMap;
use std::fmt::Write as FmtWrite;
use std::fs::File;
use std::io::Write as IoWrite;
use std::os::unix::io::{FromRawFd, IntoRawFd};
use std::path::Path;

use std::{io, panic};
use time::format_description::modifier::{Day, Hour, Minute, Month, Second, Year};
use time::format_description::FormatItem::Literal;
use time::format_description::{Component, FormatItem};
use time::OffsetDateTime;
use utils::current_executable;

mod metadata;
pub use metadata::*;
Expand All @@ -42,15 +40,15 @@ pub use log::LevelFilter;
pub use mro::*;
pub mod prelude;

pub fn initialize(args: Vec<String>) -> Result<Metadata, Error> {
pub fn initialize(args: Vec<String>) -> Result<Metadata> {
let mut md = Metadata::new(args);
md.update_jobinfo()?;

Ok(md)
}

#[cold]
fn write_errors(msg: &str, is_assert: bool) -> Result<(), Error> {
fn write_errors(msg: &str, is_assert: bool) -> Result<()> {
let mut err_file: File = unsafe { File::from_raw_fd(4) };

// We want to aggressively avoid allocations here if we can, since one
Expand Down Expand Up @@ -79,7 +77,7 @@ fn write_errors(msg: &str, is_assert: bool) -> Result<(), Error> {
// We could use the proc macro, but then we'd need
// to compile the proc macro crate, which would slow down build times
// significantly for very little benefit in readability.
pub(crate) const DATE_FORMAT: &[FormatItem] = &[
pub(crate) const DATE_FORMAT: &[FormatItem<'_>] = &[
FormatItem::Component(Component::Year(Year::default())),
Literal(b"-"),
FormatItem::Component(Component::Month(Month::default())),
Expand Down Expand Up @@ -195,13 +193,13 @@ fn martian_entry_point<S: std::hash::BuildHasher>(
};

// Get the stage implementation
let _stage = stage_map.get(&md.stage_name).ok_or_else(
let stage = stage_map.get(&md.stage_name).with_context(
#[cold]
|| format_err!("Couldn't find requested Martian stage: {}", md.stage_name),
|| format!("Couldn't find requested Martian stage: {}", md.stage_name),
);

// special handler for non-existent stage
let stage = match _stage {
let stage = match stage {
Ok(s) => s,
Err(e) => {
let _ = write_errors(&format!("{e:?}"), false);
Expand Down Expand Up @@ -286,48 +284,43 @@ fn report_error(md: &mut Metadata, e: &Error, is_assert: bool) {
let _ = write_errors(&format!("{e:#}"), is_assert);
}

/// Return the environment variable CARGO_PKG_NAME or the current executable name.
fn get_generator_name() -> String {
std::env::var("CARGO_BIN_NAME")
.or_else(|_| std::env::var("CARGO_CRATE_NAME"))
.or_else(|_| std::env::var("CARGO_PKG_NAME"))
.unwrap_or_else(|_| {
option_env!("CARGO_BIN_NAME")
.or(option_env!("CARGO_CRATE_NAME"))
.unwrap_or("martian-rust")
.into()
})
std::env::var("CARGO_PKG_NAME").unwrap_or_else(|_| current_executable())
}

/// Write MRO to filename or stdout.
pub fn martian_make_mro(
header_comment: &str,
file_name: Option<impl AsRef<Path>>,
filename: Option<impl AsRef<Path>>,
rewrite: bool,
mro_registry: Vec<StageMro>,
) -> Result<(), Error> {
if let Some(ref f) = file_name {
let file_path = f.as_ref();
if file_path.is_dir() {
return Err(format_err!(
"Error! Path {} is a directory!",
file_path.display()
));
}
if file_path.exists() && !rewrite {
return Err(format_err!(
"File {} exists. You need to explicitly mention if it is okay to rewrite.",
file_path.display()
));
}
) -> Result<()> {
if let Some(filename) = &filename {
let filename = filename.as_ref();
ensure!(
!filename.is_dir(),
"Path {} is a directory",
filename.display()
);
ensure!(
rewrite || !filename.exists(),
"File {} exists. Use --rewrite to overwrite it.",
filename.display()
);
}

let final_mro_string = make_mro_string(header_comment, &mro_registry);
match file_name {
Some(f) => {
let mut output = File::create(f)?;
output.write_all(final_mro_string.as_bytes())?;
let mro = make_mro_string(header_comment, &mro_registry);
match filename {
Some(filename) => {
let filename = filename.as_ref();
File::create(filename)
.with_context(|| filename.display().to_string())?
.write_all(mro.as_bytes())
.with_context(|| filename.display().to_string())?;
}
None => {
print!("{final_mro_string}");
print!("{mro}");
}
}
Ok(())
Expand Down Expand Up @@ -361,8 +354,7 @@ pub fn make_mro_string(header_comment: &str, mro_registry: &[StageMro]) -> Strin
header_comment
.lines()
.all(|line| line.trim_end().is_empty() || line.starts_with('#')),
"All non-empty header lines must start with '#', but got\n{}",
header_comment
"All non-empty header lines must start with '#', but got\n{header_comment}"
);
format!(
"{}
Expand Down
2 changes: 1 addition & 1 deletion martian/src/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ impl Metadata {

/// Update the Martian journal -- so that Martian knows what we've updated
fn update_journal(&self, name: &str) -> Result<()> {
let journal_name: Cow<str> = if self.stage_type != "main" {
let journal_name: Cow<'_, str> = if self.stage_type != "main" {
format!("{}_{name}", self.stage_type).into()
} else {
name.into()
Expand Down
10 changes: 6 additions & 4 deletions martian/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,14 @@ pub fn to_camel_case(stage_name: &str) -> String {
/// Parse the `env::args()` and return the name of the
/// current executable as a String
pub fn current_executable() -> String {
let args: Vec<_> = std::env::args().collect();
std::path::Path::new(&args[0])
Path::new(&std::env::args().next().unwrap())
.canonicalize()
.unwrap()
.file_name()
.unwrap()
.to_string_lossy()
.into_owned()
.to_str()
.unwrap()
.to_string()
}

/// Given a filename and an extension, return the filename with the correct extension.
Expand Down

0 comments on commit 13e107e

Please sign in to comment.