Skip to content

Commit

Permalink
Get storage class from PVC for snapshot matching (#105)
Browse files Browse the repository at this point in the history
Instead of using the PV to get the storage class
use the PVC storage class instead. If it is nil
then k8s version is < 1.28 and we return an error
indicating we cannot determine the storage class
name.

Signed-off-by: Alexander Wels <[email protected]>
  • Loading branch information
awels authored Mar 27, 2024
1 parent 7fc750b commit 701c59c
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 44 deletions.
3 changes: 0 additions & 3 deletions hack/common.sh
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,6 @@ kind: ClusterRole
metadata:
name: kubevirt-csi-snapshot
rules:
- apiGroups: [""]
resources: ["persistentvolumes"]
verbs: ["get"]
- apiGroups: ["storage.k8s.io"]
resources: ["storageclasses"]
verbs: ["get"]
Expand Down
27 changes: 10 additions & 17 deletions pkg/kubevirt/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,14 +280,13 @@ func (c *client) CreateVolumeSnapshot(ctx context.Context, namespace, name, clai
}

func (c *client) getSnapshotClassNameFromVolumeClaimName(ctx context.Context, namespace, claimName, snapshotClassName string) (string, error) {
volumeName, err := c.getVolumeNameFromClaimName(ctx, namespace, claimName)
if err != nil || volumeName == "" {
klog.V(2).Infof("Error getting volume name for claim %s in namespace %s: %v", claimName, namespace, err)
storageClassName, err := c.getStorageClassNameFromClaimName(ctx, namespace, claimName)
if err != nil {
klog.V(2).Infof("Error getting storage class name for claim %s in namespace %s: %v", claimName, namespace, err)
return "", fmt.Errorf("unable to determine snapshot class name for infra source volume")
}
storageClassName, err := c.getStorageClassFromVolume(ctx, volumeName)
if err != nil {
return "", err
if storageClassName == "" {
return "", fmt.Errorf("unable to determine storage class name for snapshot creation")
}
allowed, err := c.isStorageClassAllowed(ctx, storageClassName)
if err != nil {
Expand Down Expand Up @@ -327,24 +326,18 @@ func (c *client) isStorageClassAllowed(ctx context.Context, storageClassName str
}

// Determine the name of the volume associated with the passed in claim name
func (c *client) getVolumeNameFromClaimName(ctx context.Context, namespace, claimName string) (string, error) {
func (c *client) getStorageClassNameFromClaimName(ctx context.Context, namespace, claimName string) (string, error) {
volumeClaim, err := c.kubernetesClient.CoreV1().PersistentVolumeClaims(namespace).Get(ctx, claimName, metav1.GetOptions{})
if err != nil {
klog.Errorf("Error getting volume claim %s in namespace %s: %v", claimName, namespace, err)
return "", err
}
klog.V(5).Infof("found volumeClaim %#v", volumeClaim)
return volumeClaim.Spec.VolumeName, nil
}

// Determine the storage class from the volume
func (c *client) getStorageClassFromVolume(ctx context.Context, volumeName string) (string, error) {
volume, err := c.kubernetesClient.CoreV1().PersistentVolumes().Get(ctx, volumeName, metav1.GetOptions{})
if err != nil {
klog.V(2).Infof("Error getting volume %s: %v", volumeName, err)
return "", err
storageClassName := ""
if volumeClaim.Spec.StorageClassName != nil {
storageClassName = *volumeClaim.Spec.StorageClassName
}
return volume.Spec.StorageClassName, nil
return storageClassName, nil
}

// Get the associated snapshot class based on the storage class the following logic is used:
Expand Down
45 changes: 21 additions & 24 deletions pkg/kubevirt/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const (
testClaimName = "pvc-valid-data-volume"
testClaimName2 = "pvc-valid-data-volume2"
testClaimName3 = "pvc-valid-data-volume3"
testClaimName4 = "pvc-default-storage-class"
testNamespace = "test-namespace"
unboundTestClaimName = "unbound-test-claim"
)
Expand Down Expand Up @@ -122,26 +123,20 @@ var _ = Describe("Client", func() {
Entry("should return error when provider doesn't match", storageClassName, nonMatchingProvisioner, "", true),
)

It("Storage class from volume should return a storage class", func() {
storageClass, err := c.getStorageClassFromVolume(context.TODO(), testVolumeName)
It("storage class from claim should return a storage class name", func() {
storageClassName, err := c.getStorageClassNameFromClaimName(context.TODO(), testNamespace, testClaimName)
Expect(err).ToNot(HaveOccurred())
Expect(storageClass).To(Equal(storageClassName))
Expect(storageClassName).To(Equal(storageClassName))
})

It("Storage class from volume should return error if getting volume returns an error", func() {
storageClass, err := c.getStorageClassFromVolume(context.TODO(), "invalid")
It("storage class from claim should return error if getting claim name returns an error", func() {
volumeName, err := c.getStorageClassNameFromClaimName(context.TODO(), testNamespace, "invalid")
Expect(err).To(HaveOccurred())
Expect(storageClass).To(Equal(""))
})

It("volume from claim should return a volume name", func() {
volumeName, err := c.getVolumeNameFromClaimName(context.TODO(), testNamespace, testClaimName)
Expect(err).ToNot(HaveOccurred())
Expect(volumeName).To(Equal(testVolumeName))
Expect(volumeName).To(Equal(""))
})

It("volume from claim should return error if getting claim name returns an error", func() {
volumeName, err := c.getVolumeNameFromClaimName(context.TODO(), testNamespace, "invalid")
It("snapshot class from claim name should return error if claim has nil storage class", func() {
volumeName, err := c.getSnapshotClassNameFromVolumeClaimName(context.TODO(), testNamespace, testClaimName4, volumeSnapshotClassName)
Expect(err).To(HaveOccurred())
Expect(volumeName).To(Equal(""))
})
Expand All @@ -158,8 +153,6 @@ var _ = Describe("Client", func() {
},
Entry("should return snapshot class", testClaimName, testNamespace, volumeSnapshotClassName, volumeSnapshotClassName, false),
Entry("should return error when claim is invalid", "invalid", testNamespace, volumeSnapshotClassName, "", true),
Entry("should return error when claim is unbound", unboundTestClaimName, testNamespace, volumeSnapshotClassName, "", true),
Entry("should return error when volume cannot be found", testClaimName2, testNamespace, volumeSnapshotClassName, "", true),
)

It("should return error if the storage class is not allowed", func() {
Expand Down Expand Up @@ -270,9 +263,10 @@ func NewFakeClient() *client {
defaultStorageClass := createStorageClass(defaultStorageClassName, provisioner, true)
testVolume := createPersistentVolume(testVolumeName, storageClassName)
testVolumeNotAllowed := createPersistentVolume(testVolumeNameNotAllowed, "not-allowed-storage-class")
testClaim := createPersistentVolumeClaim(testClaimName, testVolumeName, storageClassName)
testClaim2 := createPersistentVolumeClaim(testClaimName2, "testVolumeName2", storageClassName)
testClaim3 := createPersistentVolumeClaim(testClaimName3, testVolumeNameNotAllowed, "not-allowed-storage-class")
testClaim := createPersistentVolumeClaim(testClaimName, testVolumeName, ptr.To[string](storageClassName))
testClaim2 := createPersistentVolumeClaim(testClaimName2, "testVolumeName2", ptr.To[string](storageClassName))
testClaim3 := createPersistentVolumeClaim(testClaimName3, testVolumeNameNotAllowed, ptr.To[string]("not-allowed-storage-class"))
testClaimDefault := createPersistentVolumeClaim(testClaimName4, testVolumeName, nil)
unboundClaim := &k8sv1.PersistentVolumeClaim{
ObjectMeta: metav1.ObjectMeta{
Name: unboundTestClaimName,
Expand All @@ -283,7 +277,7 @@ func NewFakeClient() *client {
},
}
fakeK8sClient := k8sfake.NewSimpleClientset(storageClass, defaultStorageClass, testVolume,
testVolumeNotAllowed, testClaim, testClaim2, testClaim3, unboundClaim)
testVolumeNotAllowed, testClaim, testClaim2, testClaim3, unboundClaim, testClaimDefault)

fakeSnapClient := snapfake.NewSimpleClientset(
createVolumeSnapshotClass(volumeSnapshotClassName, provisioner, false),
Expand Down Expand Up @@ -330,18 +324,21 @@ func createPersistentVolume(name, storageClassName string) *k8sv1.PersistentVolu
}
}

func createPersistentVolumeClaim(name, volumeName, storageClassName string) *k8sv1.PersistentVolumeClaim {
return &k8sv1.PersistentVolumeClaim{
func createPersistentVolumeClaim(name, volumeName string, storageClassName *string) *k8sv1.PersistentVolumeClaim {
pvc := &k8sv1.PersistentVolumeClaim{
ObjectMeta: metav1.ObjectMeta{
Name: name,
Namespace: testNamespace,
Labels: map[string]string{"test": "test"},
},
Spec: k8sv1.PersistentVolumeClaimSpec{
StorageClassName: ptr.To[string](storageClassName),
VolumeName: volumeName,
VolumeName: volumeName,
},
}
if storageClassName != nil {
pvc.Spec.StorageClassName = storageClassName
}
return pvc
}

func createStorageClass(name, provisioner string, isDefault bool) *storagev1.StorageClass {
Expand Down

0 comments on commit 701c59c

Please sign in to comment.