diff --git a/internal/modifier/list.go b/internal/modifier/list.go index b6d040e6a..d1ce9d642 100644 --- a/internal/modifier/list.go +++ b/internal/modifier/list.go @@ -22,14 +22,12 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" ) -type list struct { - modifiers []oci.SpecModifier -} +type List []oci.SpecModifier // Merge merges a set of OCI specification modifiers as a list. // This can be used to compose modifiers. func Merge(modifiers ...oci.SpecModifier) oci.SpecModifier { - var filteredModifiers []oci.SpecModifier + var filteredModifiers List for _, m := range modifiers { if m == nil { continue @@ -37,19 +35,19 @@ func Merge(modifiers ...oci.SpecModifier) oci.SpecModifier { filteredModifiers = append(filteredModifiers, m) } - return list{ - modifiers: filteredModifiers, - } + return filteredModifiers } // Modify applies a list of modifiers in sequence and returns on any errors encountered. -func (m list) Modify(spec *specs.Spec) error { - for _, mm := range m.modifiers { +func (m List) Modify(spec *specs.Spec) error { + for _, mm := range m { + if mm == nil { + continue + } err := mm.Modify(spec) if err != nil { return err } } - return nil } diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index 50c19a4f9..a44264598 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -79,26 +79,27 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp if err != nil { return nil, err } - // For CDI mode we make no additional modifications. - if mode == "cdi" { - return modeModifier, nil - } - graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver) - if err != nil { - return nil, err - } - - featureModifier, err := modifier.NewFeatureGatedModifier(logger, cfg, image) - if err != nil { - return nil, err + var modifiers modifier.List + for _, modifierType := range supportedModifierTypes(mode) { + switch modifierType { + case "mode": + modifiers = append(modifiers, modeModifier) + case "graphics": + graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver) + if err != nil { + return nil, err + } + modifiers = append(modifiers, graphicsModifier) + case "feature-gated": + featureGatedModifier, err := modifier.NewFeatureGatedModifier(logger, cfg, image) + if err != nil { + return nil, err + } + modifiers = append(modifiers, featureGatedModifier) + } } - modifiers := modifier.Merge( - modeModifier, - graphicsModifier, - featureModifier, - ) return modifiers, nil } @@ -114,3 +115,17 @@ func newModeModifier(logger logger.Interface, mode string, cfg *config.Config, o return nil, fmt.Errorf("invalid runtime mode: %v", cfg.NVIDIAContainerRuntimeConfig.Mode) } + +// supportedModifierTypes returns the modifiers supported for a specific runtime mode. +func supportedModifierTypes(mode string) []string { + switch mode { + case "cdi": + // For CDI mode we make no additional modifications. + return []string{"mode"} + case "csv": + // For CSV mode we support mode and feature-gated modification. + return []string{"mode", "feature-gated"} + default: + return []string{"mode", "graphics", "feature-gated"} + } +}