diff --git a/sdl/gpu.go b/sdl/gpu.go index 0fa466011..54833b347 100644 --- a/sdl/gpu.go +++ b/sdl/gpu.go @@ -9,15 +9,22 @@ import ( types "github.com/akash-network/akash-api/go/node/types/v1beta3" ) +type gpuInterface string + type v2GPUNvidia struct { - Model string `yaml:"model"` - RAM *memoryQuantity `yaml:"ram,omitempty"` + Model string `yaml:"model"` + RAM *memoryQuantity `yaml:"ram,omitempty"` + Interface *gpuInterface `yaml:"interface,omitempty"` } func (sdl *v2GPUNvidia) String() string { key := sdl.Model if sdl.RAM != nil { - key += "/ram/" + sdl.RAM.StringWithSuffix("Gi") + key = fmt.Sprintf("%s/ram/%s", key, sdl.RAM.StringWithSuffix("Gi")) + } + + if sdl.Interface != nil { + key = fmt.Sprintf("%s/interface/%s", key, *sdl.Interface) } return key @@ -109,3 +116,16 @@ func (sdl *v2GPUAttributes) UnmarshalYAML(node *yaml.Node) error { return nil } + +func (sdl *gpuInterface) UnmarshalYAML(node *yaml.Node) error { + switch node.Value { + case "pcie": + case "sxm": + default: + return fmt.Errorf("sdl: invalid GPU interface %s. expected \"pcie|sxm\"", node.Value) + } + + *sdl = gpuInterface(node.Value) + + return nil +} diff --git a/sdl/gpu_test.go b/sdl/gpu_test.go index 282239296..8277fbc76 100644 --- a/sdl/gpu_test.go +++ b/sdl/gpu_test.go @@ -88,6 +88,41 @@ attributes: require.Error(t, err) } +func TestV2ResourceGPU_InterfaceInvalid(t *testing.T) { + var stream = ` +units: 1 +attributes: + vendor: + nvidia: + - model: a100 + interface: pciex +` + var p v2ResourceGPU + + err := yaml.Unmarshal([]byte(stream), &p) + require.Error(t, err) +} + +func TestV2ResourceGPU_RamWithInterface(t *testing.T) { + var stream = ` +units: 1 +attributes: + vendor: + nvidia: + - model: a100 + ram: 80Gi + interface: pcie +` + var p v2ResourceGPU + + err := yaml.Unmarshal([]byte(stream), &p) + require.NoError(t, err) + require.Equal(t, gpuQuantity(1), p.Units) + require.Equal(t, 1, len(p.Attributes)) + require.Equal(t, "vendor/nvidia/model/a100/ram/80Gi/interface/pcie", p.Attributes[0].Key) + require.Equal(t, "true", p.Attributes[0].Value) +} + func TestV2ResourceGPU_MultipleModels(t *testing.T) { var stream = ` units: 1