From eec294e76a3bf8c854f2a10bd1a7aa40704d1277 Mon Sep 17 00:00:00 2001 From: yiscah Date: Wed, 24 Nov 2021 11:35:53 +0200 Subject: [PATCH] deep copy subject before changing + error handling --- reporthandling/regoresourcesaggregator.go | 67 +++++++++++++------ .../regoresourcesaggregator_test.go | 22 +++++- 2 files changed, 65 insertions(+), 24 deletions(-) diff --git a/reporthandling/regoresourcesaggregator.go b/reporthandling/regoresourcesaggregator.go index caef6713..084e2f97 100644 --- a/reporthandling/regoresourcesaggregator.go +++ b/reporthandling/regoresourcesaggregator.go @@ -1,6 +1,8 @@ package reporthandling import ( + "bytes" + "encoding/gob" "strings" "github.com/armosec/k8s-interface/workloadinterface" @@ -8,7 +10,7 @@ import ( var aggregatorAttribute = "resourcesAggregator" -func RegoResourcesAggregator(rule *PolicyRule, k8sObjects []map[string]interface{}) []map[string]interface{} { +func RegoResourcesAggregator(rule *PolicyRule, k8sObjects []map[string]interface{}) ([]map[string]interface{}, error) { if aggregateBy, ok := rule.Attributes[aggregatorAttribute]; ok { switch aggregateBy { case "subject-role-rolebinding": @@ -16,13 +18,13 @@ func RegoResourcesAggregator(rule *PolicyRule, k8sObjects []map[string]interface case "apiserver-pod": return AggregateResourcesByAPIServerPod(k8sObjects) default: - return k8sObjects + return k8sObjects, nil } } - return k8sObjects + return k8sObjects, nil } -func AggregateResourcesBySubjects(k8sObjects []map[string]interface{}) []map[string]interface{} { +func AggregateResourcesBySubjects(k8sObjects []map[string]interface{}) ([]map[string]interface{}, error) { var aggregatedK8sObjects []map[string]interface{} for _, firstk8sObject := range k8sObjects { bindingWorkload := workloadinterface.NewWorkloadObj(firstk8sObject) @@ -37,7 +39,10 @@ func AggregateResourcesBySubjects(k8sObjects []map[string]interface{}) []map[str if subjects, ok := workloadinterface.InspectMap(bindingWorkloadObj, "subjects"); ok { if data, ok := subjects.([]interface{}); ok { for _, subject := range data { - subjectAllFields := setSubjectFields(subject.(map[string]interface{})) + subjectAllFields, err := setSubjectFields(subject.(map[string]interface{})) + if err != nil { + return aggregatedK8sObjects, err + } subjectAllFields[workloadinterface.RelatedObjectsKey] = []map[string]interface{}{bindingWorkload.GetObject(), roleWorkload.GetObject()} newObj := workloadinterface.NewRegoResponseVectorObject(subjectAllFields) aggregatedK8sObjects = append(aggregatedK8sObjects, newObj.GetObject()) @@ -51,11 +56,11 @@ func AggregateResourcesBySubjects(k8sObjects []map[string]interface{}) []map[str } } } - return aggregatedK8sObjects + return aggregatedK8sObjects, nil } // Create custom object of apiserver pod. Has required fields + cmdline -func AggregateResourcesByAPIServerPod(k8sObjects []map[string]interface{}) []map[string]interface{} { +func AggregateResourcesByAPIServerPod(k8sObjects []map[string]interface{}) ([]map[string]interface{}, error) { apiServerPod := map[string]interface{}{} for _, obj := range k8sObjects { workload := workloadinterface.NewWorkloadObj(obj) @@ -67,30 +72,50 @@ func AggregateResourcesByAPIServerPod(k8sObjects []map[string]interface{}) []map apiServerPod["apiVersion"] = workload.GetApiVersion() containers, err := workload.GetContainers() if err != nil || len(containers) == 0 { - return nil + return nil, err } // apiServer has only one container apiServerPod["cmdline"] = containers[0].Command - return []map[string]interface{}{apiServerPod} + return []map[string]interface{}{apiServerPod}, nil } } } - return nil + return nil, nil } -func setSubjectFields(subject map[string]interface{}) map[string]interface{} { - - if _, ok := workloadinterface.InspectMap(subject, "name"); !ok { - subject["name"] = "" +func setSubjectFields(subject map[string]interface{}) (map[string]interface{}, error) { + newSubject, err := DeepCopyMap(subject) + if err != nil { + return nil, err + } + if _, ok := workloadinterface.InspectMap(newSubject, "name"); !ok { + newSubject["name"] = "" + } + if _, ok := workloadinterface.InspectMap(newSubject, "namespace"); !ok { + newSubject["namespace"] = "" } - if _, ok := workloadinterface.InspectMap(subject, "namespace"); !ok { - subject["namespace"] = "" + if _, ok := workloadinterface.InspectMap(newSubject, "kind"); !ok { + newSubject["kind"] = "" } - if _, ok := workloadinterface.InspectMap(subject, "kind"); !ok { - subject["kind"] = "" + if _, ok := workloadinterface.InspectMap(newSubject, "apiVersion"); !ok { + newSubject["apiVersion"] = "" + } + return newSubject, nil +} + +// DeepCopyMap performs a deep copy of the given map m. +func DeepCopyMap(m map[string]interface{}) (map[string]interface{}, error) { + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + dec := gob.NewDecoder(&buf) + err := enc.Encode(m) + if err != nil { + return nil, err } - if _, ok := workloadinterface.InspectMap(subject, "apiVersion"); !ok { - subject["apiVersion"] = "" + var copy map[string]interface{} + err = dec.Decode(©) + if err != nil { + return nil, err } - return subject + return copy, nil } diff --git a/reporthandling/regoresourcesaggregator_test.go b/reporthandling/regoresourcesaggregator_test.go index a9aada41..fb34afdc 100644 --- a/reporthandling/regoresourcesaggregator_test.go +++ b/reporthandling/regoresourcesaggregator_test.go @@ -2,6 +2,7 @@ package reporthandling import ( "encoding/json" + "fmt" "testing" ) @@ -20,7 +21,10 @@ func TestAggregateResourcesAPIServerPod(t *testing.T) { t.Errorf("error unmarshaling %s", err) } inputList := []map[string]interface{}{pod} - outputList := AggregateResourcesByAPIServerPod(inputList) + outputList, err := AggregateResourcesByAPIServerPod(inputList) + if err != nil { + t.Errorf(err.Error()) + } if len(outputList) != 1 { t.Errorf("error in AggregateResourcesAPIServerPod, len should be 1, got len = %d", len(outputList)) } @@ -43,7 +47,10 @@ func TestAggregateResourcesBySubjects(t *testing.T) { } // r := make(map[string]interface{}, []byte(role)) inputList := []map[string]interface{}{r, rb} - outputList := AggregateResourcesBySubjects(inputList) + outputList, err := AggregateResourcesBySubjects(inputList) + if err != nil { + t.Errorf(err.Error()) + } if len(outputList) != 1 { t.Errorf("error in AggregateResourcesBySubjects, len should be 1, got len = %d", len(outputList)) } @@ -65,7 +72,16 @@ func TestAggregateResourcesBySubjects2(t *testing.T) { } // r := make(map[string]interface{}, []byte(role)) inputList := []map[string]interface{}{r, rb} - outputList := AggregateResourcesBySubjects(inputList) + outputList, err := AggregateResourcesBySubjects(inputList) + if err != nil { + t.Errorf(err.Error()) + } + val, err := json.MarshalIndent(outputList, "", " ") + if err != nil { + t.Errorf(err.Error()) + } + a := string(val) + fmt.Println(a) if len(outputList) != 2 { t.Errorf("error in AggregateResourcesBySubjects, len should be 2, got len = %d", len(outputList)) }