From e614ed7ad951036c6e342867b0ca4ef7fa437bb9 Mon Sep 17 00:00:00 2001
From: Jarrett Tierney <jmt@amazon.com>
Date: Tue, 23 Jul 2024 14:55:07 -0700
Subject: [PATCH] twoliter: add kit overrides for test interactions

---
 twoliter/src/cmd/fetch.rs |  43 +++++++++++-
 twoliter/src/cmd/mod.rs   |   1 +
 twoliter/src/lock.rs      | 133 ++++++++++++++++++++++++++++++--------
 3 files changed, 149 insertions(+), 28 deletions(-)

diff --git a/twoliter/src/cmd/fetch.rs b/twoliter/src/cmd/fetch.rs
index 92ad46478..89a9b8064 100644
--- a/twoliter/src/cmd/fetch.rs
+++ b/twoliter/src/cmd/fetch.rs
@@ -2,6 +2,9 @@ use crate::lock::Lock;
 use crate::project;
 use anyhow::Result;
 use clap::Parser;
+use log::warn;
+use std::collections::HashMap;
+use std::error::Error;
 use std::path::PathBuf;
 
 #[derive(Debug, Parser)]
@@ -12,13 +15,51 @@ pub(crate) struct Fetch {
 
     #[clap(long = "arch", default_value = "x86_64")]
     pub(crate) arch: String,
+
+    #[clap(long = "kit-override", short = 'K', value_parser = parse_key_val::<String, PathBuf>)]
+    pub(crate) kit_override: Option<Vec<(String, PathBuf)>>,
+}
+
+/// Parse a single key-value pair
+fn parse_key_val<T, U>(s: &str) -> Result<(T, U), Box<dyn Error + Send + Sync + 'static>>
+where
+    T: std::str::FromStr,
+    T::Err: Error + Send + Sync + 'static,
+    U: std::str::FromStr,
+    U::Err: Error + Send + Sync + 'static,
+{
+    let pos = s
+        .find('=')
+        .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?;
+    Ok((s[..pos].parse()?, s[pos + 1..].parse()?))
 }
 
 impl Fetch {
     pub(super) async fn run(&self) -> Result<()> {
         let project = project::load_or_find_project(self.project_path.clone()).await?;
         let lock_file = Lock::load(&project).await?;
-        lock_file.fetch(&project, self.arch.as_str()).await?;
+        if self.kit_override.is_some() {
+            warn!(
+                r#"
+!!!
+Bottlerocket is being built with an overwritten kit.
+This means that the resulting variant images are not based on a remotely
+hosted and officially tagged version of kits.         
+!!!
+"#
+            );
+        }
+        lock_file
+            .fetch(
+                &project,
+                self.arch.as_str(),
+                self.kit_override
+                    .clone()
+                    .map(|x| crate::lock::LockOverrides {
+                        kit: HashMap::from_iter(x),
+                    }),
+            )
+            .await?;
         Ok(())
     }
 }
diff --git a/twoliter/src/cmd/mod.rs b/twoliter/src/cmd/mod.rs
index 6808677a1..8cf016119 100644
--- a/twoliter/src/cmd/mod.rs
+++ b/twoliter/src/cmd/mod.rs
@@ -140,6 +140,7 @@ mod test {
         let command = Fetch {
             project_path: Some(project_path.to_path_buf()),
             arch: arch.into(),
+            kit_override: None,
         };
         command.run().await.unwrap()
     }
diff --git a/twoliter/src/lock.rs b/twoliter/src/lock.rs
index 782fea616..fad8bd28c 100644
--- a/twoliter/src/lock.rs
+++ b/twoliter/src/lock.rs
@@ -1,8 +1,10 @@
 use crate::common::fs::{create_dir_all, read, remove_dir_all, remove_file, write};
 use crate::project::{Image, Project, ValidIdentifier, Vendor};
 use crate::schema_version::SchemaVersion;
-use anyhow::{ensure, Context, Result};
+use anyhow::{bail, ensure, Context, Result};
+use async_walkdir::WalkDir;
 use base64::Engine;
+use futures::StreamExt;
 use oci_cli_wrapper::{DockerArchitecture, ImageTool};
 use olpc_cjson::CanonicalFormatter as CanonicalJsonFormatter;
 use semver::Version;
@@ -163,8 +165,14 @@ struct ExternalKitMetadata {
 #[derive(Debug)]
 struct OCIArchive {
     image: LockedImage,
-    digest: String,
     cache_dir: PathBuf,
+    source: OCISource,
+}
+
+#[derive(Debug)]
+enum OCISource {
+    Registry { digest: String },
+    Local { path: PathBuf },
 }
 
 impl OCIArchive {
@@ -174,23 +182,71 @@ impl OCIArchive {
     {
         Ok(Self {
             image: image.clone(),
-            digest: digest.into(),
             cache_dir: cache_dir.as_ref().to_path_buf(),
+            source: OCISource::Registry {
+                digest: digest.into(),
+            },
+        })
+    }
+
+    fn from_path<P>(image: &LockedImage, path: P, cache_dir: P) -> Result<Self>
+    where
+        P: AsRef<Path>,
+    {
+        Ok(Self {
+            image: image.clone(),
+            cache_dir: cache_dir.as_ref().to_path_buf(),
+            source: OCISource::Local {
+                path: path.as_ref().to_path_buf(),
+            },
         })
     }
 
     fn archive_path(&self) -> PathBuf {
-        self.cache_dir.join(self.digest.replace(':', "-"))
+        match &self.source {
+            OCISource::Registry { digest } => self.cache_dir.join(digest.replace(':', "-")),
+            OCISource::Local { .. } => self.cache_dir.join(format!(
+                "{}-{}-{}-override",
+                self.image.name, self.image.version, self.image.vendor
+            )),
+        }
     }
 
-    async fn pull_image(&self, image_tool: &ImageTool) -> Result<()> {
-        let digest_uri = self.image.digest_uri(self.digest.as_str());
-        let oci_archive_path = self.archive_path();
-        if !oci_archive_path.exists() {
-            create_dir_all(&oci_archive_path).await?;
-            image_tool
-                .pull_oci_image(oci_archive_path.as_path(), digest_uri.as_str())
-                .await?;
+    async fn pull_image(&self, image_tool: &ImageTool, arch: &str) -> Result<()> {
+        match &self.source {
+            OCISource::Registry { digest } => {
+                let digest_uri = self.image.digest_uri(digest.as_str());
+                let oci_archive_path = self.archive_path();
+                if !oci_archive_path.exists() {
+                    create_dir_all(&oci_archive_path).await?;
+                    image_tool
+                        .pull_oci_image(oci_archive_path.as_path(), digest_uri.as_str())
+                        .await?;
+                }
+            }
+            OCISource::Local { path } => {
+                let oci_archive_path = self.archive_path();
+                // We need to look for an archive matching the architecture
+                let name = self.image.name.clone();
+                let build_dir = path.join(format!("build/kits/{name}"));
+                let mut walker = WalkDir::new(build_dir);
+                let suffix = format!("{}.tar", arch);
+                while let Some(Ok(entry)) = walker.next().await {
+                    if entry.path().is_file() && entry.path().to_string_lossy().ends_with(&suffix) {
+                        let archive_fp = File::open(entry.path())
+                            .context("failed to open oci archive from disk")?;
+                        let mut archive = TarArchive::new(archive_fp);
+                        archive
+                            .unpack(oci_archive_path.clone())
+                            .context("failed to extract oci archive from file")?;
+                        return Ok(());
+                    }
+                }
+                bail!(
+                    "No oci image archive was found in {}. Have you built the kit?",
+                    path.display()
+                );
+            }
         }
         Ok(())
     }
@@ -201,13 +257,15 @@ impl OCIArchive {
     {
         let path = out_dir.as_ref();
         let digest_file = path.join("digest");
-        if digest_file.exists() {
-            let digest = read_to_string(&digest_file).await.context(format!(
-                "failed to read digest file at {}",
-                digest_file.display()
-            ))?;
-            if digest == self.digest {
-                return Ok(());
+        if let OCISource::Registry { digest } = &self.source {
+            if digest_file.exists() {
+                let on_disk = read_to_string(&digest_file).await.context(format!(
+                    "failed to read digest file at {}",
+                    digest_file.display()
+                ))?;
+                if on_disk == *digest {
+                    return Ok(());
+                }
             }
         }
 
@@ -240,17 +298,22 @@ impl OCIArchive {
                 .unpack(path)
                 .context("failed to unpack layer to disk")?;
         }
-        write(&digest_file, self.digest.as_str())
-            .await
-            .context(format!(
+        if let OCISource::Registry { digest } = &self.source {
+            write(&digest_file, digest.as_str()).await.context(format!(
                 "failed to record digest to {}",
                 digest_file.display()
             ))?;
+        }
 
         Ok(())
     }
 }
 
+#[derive(Debug, Clone)]
+pub(crate) struct LockOverrides {
+    pub kit: HashMap<String, PathBuf>,
+}
+
 /// Represents the structure of a `Twoliter.lock` lock file.
 #[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
 #[serde(rename_all = "kebab-case")]
@@ -305,7 +368,12 @@ impl Lock {
     }
 
     /// Fetches all external kits defined in a Twoliter.lock to the build directory
-    pub(crate) async fn fetch(&self, project: &Project, arch: &str) -> Result<()> {
+    pub(crate) async fn fetch(
+        &self,
+        project: &Project,
+        arch: &str,
+        overrides: Option<LockOverrides>,
+    ) -> Result<()> {
         let image_tool = ImageTool::from_environment()?;
         let target_dir = project.external_kits_dir();
         create_dir_all(&target_dir).await.context(format!(
@@ -313,8 +381,14 @@ impl Lock {
             target_dir.display()
         ))?;
         for image in self.kit.iter() {
-            self.extract_kit(&image_tool, &project.external_kits_dir(), image, arch)
-                .await?;
+            self.extract_kit(
+                &image_tool,
+                &project.external_kits_dir(),
+                image,
+                arch,
+                overrides.clone(),
+            )
+            .await?;
         }
         let mut kit_list = Vec::new();
         let mut ser =
@@ -371,6 +445,7 @@ impl Lock {
         path: P,
         image: &LockedImage,
         arch: &str,
+        overrides: Option<LockOverrides>,
     ) -> Result<()>
     where
         P: AsRef<Path>,
@@ -384,10 +459,14 @@ impl Lock {
 
         // First get the manifest for the specific requested architecture
         let manifest = self.get_manifest(image_tool, image, arch).await?;
-        let oci_archive = OCIArchive::new(image, manifest.digest.as_str(), &cache_path)?;
+        let oci_archive = if let Some(path) = overrides.as_ref().and_then(|x| x.kit.get(&name)) {
+            OCIArchive::from_path(image, path, &cache_path)
+        } else {
+            OCIArchive::new(image, manifest.digest.as_str(), &cache_path)
+        }?;
 
         // Checks for the saved image locally, or else pulls and saves it
-        oci_archive.pull_image(image_tool).await?;
+        oci_archive.pull_image(image_tool, arch).await?;
 
         // Checks if this archive has already been extracted by checking a digest file
         // otherwise cleans up the path and unpacks the archive