Skip to content

Commit

Permalink
Drastically improved usage of the MSSP multi-tenant workflows
Browse files — browse the repository at this point in the history
  • Loading branch information
frikky committed Sep 26, 2024
1 parent 16ad616 commit a55f6e8
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 21 deletions.
7 changes: 6 additions & 1 deletion db-connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -3726,7 +3726,7 @@ func GetOrg(ctx context.Context, id string) (*Org, error) {
if err := project.Dbclient.Get(ctx, key, curOrg); err != nil {
log.Printf("[ERROR] Error in org loading (2) for %s: %s", key, err)
//log.Printf("Users: %s", curOrg.Users)
if strings.Contains(err.Error(), `cannot load field`) && strings.Contains(err.Error(), `users`) {
if strings.Contains(err.Error(), `cannot load field`) && strings.Contains(err.Error(), `users`) && !strings.Contains(err.Error(), `users_last_session`) {
//Self correcting Org handler for user migration. This may come in handy if we change the structure of private apps later too.
log.Printf("[INFO] Error in org loading (3). Migrating org to new org and user handler (2): %s", err)
err = nil
Expand Down Expand Up @@ -8000,6 +8000,11 @@ func SetWorkflow(ctx context.Context, workflow Workflow, id string, optionalEdit
// Handles parent/child workflow relationships
if len(workflow.ParentWorkflowId) > 0 {
DeleteCache(ctx, fmt.Sprintf("workflow_%s_childworkflows", workflow.ID))
DeleteCache(ctx, fmt.Sprintf("workflow_%s_childworkflows", workflow.ParentWorkflowId))
}

if len(workflow.ChildWorkflowIds) > 0 {
DeleteCache(ctx, fmt.Sprintf("workflow_%s_childworkflows", workflow.ID))
}

if project.CacheDb {
Expand Down
2 changes: 1 addition & 1 deletion files.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func fileAuthentication(request *http.Request) (string, error) {
ctx := GetContext(request)
workflowExecution, err := GetWorkflowExecution(ctx, executionId[0])
if err != nil {
log.Printf("[ERROR] Couldn't find execution ID %s", executionId[0])
log.Printf("[ERROR] Couldn't find execution ID from '%s'", executionId)
return "", err
}

Expand Down
73 changes: 70 additions & 3 deletions shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -6687,7 +6687,7 @@ func diffWorkflows(oldWorkflow Workflow, newWorkflow Workflow, update bool) {
if err != nil {
log.Printf("[WARNING] Failed updating child workflow %s: %s", childWorkflow.ID, err)
} else {
log.Printf("[INFO] Updated child workflow '%s' based on parent %s", childWorkflow.ID, oldWorkflow.ID)
log.Printf("\n\n[INFO] Updated child workflow '%s' based on parent %s\n\n", childWorkflow.ID, oldWorkflow.ID)

SetWorkflowRevision(ctx, childWorkflow)
passedOrg := Org{
Expand Down Expand Up @@ -6801,7 +6801,9 @@ func SaveWorkflow(resp http.ResponseWriter, request *http.Request) {
}
*/

log.Printf("[DEBUG] Making ALL input questions required for %s", workflow.ID)
if len(workflow.InputQuestions) > 0 {
log.Printf("[DEBUG] Making ALL '%d' input questions required for %s", len(workflow.InputQuestions), workflow.ID)
}
for qIndex, _ := range workflow.InputQuestions {
workflow.InputQuestions[qIndex].Required = true
}
Expand Down Expand Up @@ -17169,6 +17171,20 @@ func HandleGetCacheKey(resp http.ResponseWriter, request *http.Request) {
resp.Write([]byte(`{"success": false, "reason": "Organization ID's don't match"}`))
return
}

user, err := HandleApiAuthentication(resp, request)
if err == nil {
user.ActiveOrg.Id = fileId
skipExecutionAuth = true

if user.ActiveOrg.Id != fileId {
log.Printf("[INFO] OrgId %s and %s don't match in get cache key list. Checking cache auth", user.ActiveOrg.Id, fileId)

requireCacheAuth = true
skipExecutionAuth = false
user.ActiveOrg.Id = fileId
}
}
} else {
if len(location) <= 6 {
log.Printf("[ERROR] Cache Path too short: %d", len(location))
Expand Down Expand Up @@ -28107,6 +28123,10 @@ func HandleExecutionCacheIncrement(ctx context.Context, execution WorkflowExecut
}
}


// FIXME(review): the handler below reportedly always fails — presumably the parent/child
// org-permission check; confirm the failure mode and fix before relying on this endpoint.


func GetChildWorkflows(resp http.ResponseWriter, request *http.Request) {
cors := HandleCors(resp, request)
if cors {
Expand Down Expand Up @@ -28148,11 +28168,56 @@ func GetChildWorkflows(resp http.ResponseWriter, request *http.Request) {
workflow, err := GetWorkflow(ctx, fileId)
if err != nil {
log.Printf("[WARNING] Workflow %s doesn't exist.", fileId)
resp.WriteHeader(401)
resp.WriteHeader(403)
resp.Write([]byte(`{"success": false, "reason": "Failed finding workflow"}`))
return
}

// FIXME: Check if this workflow has a parent workflow
//log.Printf("[DEBUG] Parent workflow: %#v", workflow.ParentWorkflowId)
if len(workflow.ParentWorkflowId) > 0 && workflow.ParentWorkflowId != fileId {
workflow, err = GetWorkflow(ctx, workflow.ParentWorkflowId)
if err != nil {
log.Printf("[WARNING] Parent workflow %s doesn't exist.", workflow.ParentWorkflowId)
resp.WriteHeader(403)
resp.Write([]byte(`{"success": false, "reason": "Failed finding parent workflow"}`))
return
}

// Updating role
orgUserFound := false
for _, orgId := range user.Orgs {
if orgId != workflow.OrgId {
continue
}

org, err := GetOrg(ctx, orgId)
if err != nil {
log.Printf("[WARNING] Failed getting org during parent org loading %s: %s", org.Id, err)
resp.WriteHeader(500)
resp.Write([]byte(`{"success": false}`))
return
}

for _, orgUser := range org.Users {
if user.Id == orgUser.Id {
user.Role = orgUser.Role
user.ActiveOrg.Id = org.Id
orgUserFound = true
}
}

break
}

if !orgUserFound {
log.Printf("[WARNING] User %s not found in parent org %s", user.Username, workflow.OrgId)
resp.WriteHeader(403)
resp.Write([]byte(`{"success": false, "reason": "User not found in parent org"}`))
return
}
}

// Check workflow.Sharing == private / public / org too
if user.Id != workflow.Owner || len(user.Id) == 0 {
// Added org-reader as the user should be able to read everything in an org
Expand All @@ -28172,6 +28237,8 @@ func GetChildWorkflows(resp http.ResponseWriter, request *http.Request) {
}
}



// Access is granted -> get revisions
childWorkflows, err := ListChildWorkflows(ctx, workflow.ID)
if err != nil {
Expand Down
115 changes: 99 additions & 16 deletions stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"
"sort"
"strings"
"strconv"

"encoding/json"
"io/ioutil"
Expand Down Expand Up @@ -44,7 +45,7 @@ func HandleGetWidget(resp http.ResponseWriter, request *http.Request) {
widget = location[6]
}

log.Printf("Should get widget %s in dashboard %s", widget, dashboard)
//log.Printf("Should get widget %s in dashboard %s", widget, dashboard)
id := uuid.NewV4().String()

// Returning some static info for now
Expand Down Expand Up @@ -304,6 +305,10 @@ func GetSpecificStats(resp http.ResponseWriter, request *http.Request) {
}
}

// Remove ? from orgId or statsKey
orgId = strings.Split(orgId, "?")[0]
statsKey = strings.Split(statsKey, "?")[0]

if len(statsKey) <= 1 {
log.Printf("[WARNING] Invalid stats key: %s", statsKey)
resp.WriteHeader(400)
Expand Down Expand Up @@ -331,6 +336,23 @@ func GetSpecificStats(resp http.ResponseWriter, request *http.Request) {
return
}

// Default
statDays := 30
// Check for if the query parameter exists
if len(request.URL.Query().Get("days")) > 0 {
amountQuery := request.URL.Query().Get("days")
statDays, err = strconv.Atoi(amountQuery)
if err != nil {
log.Printf("[WARNING] Failed parsing days query parameter: %s", err)
} else {
if statDays > 365 {
statDays = 365
}
}
}

log.Printf("[INFO] Should get stats for key %s for the last %d days", statsKey, statDays)

totalEntires := 0
totalValue := 0
statEntries := []AdditionalUseConfig{}
Expand All @@ -341,37 +363,89 @@ func GetSpecificStats(resp http.ResponseWriter, request *http.Request) {

allStats := []string{}
for _, daily := range info.DailyStatistics {
// Check if the date is more than statDays ago
shouldAppend := true
if daily.Date.Before(time.Now().AddDate(0, 0, -statDays)) {
shouldAppend = false
}

for _, addition := range daily.Additions {
if strings.ToLower(strings.ReplaceAll(addition.Key, " ", "_")) == statsKey {
newKey := strings.ToLower(strings.ReplaceAll(addition.Key, " ", "_"))
if shouldAppend && newKey == statsKey {
totalEntires++
totalValue += int(addition.Value)

addition.Key = statsKey
addition.Date = daily.Date
statEntries = append(statEntries, addition)
}

break
if !ArrayContains(allStats, newKey) {
allStats = append(allStats, newKey)
}
}
}

if !ArrayContains(allStats, addition.Key) {
allStats = append(allStats, addition.Key)
// Deduplicate and merge same days
mergedEntries := []AdditionalUseConfig{}
for _, entry := range statEntries {
found := false
for mergedEntryIndex, mergedEntry := range mergedEntries {
if mergedEntry.Date.Day() == entry.Date.Day() && mergedEntry.Date.Month() == entry.Date.Month() && mergedEntry.Date.Year() == entry.Date.Year() {
mergedEntries[mergedEntryIndex].Value += entry.Value
found = true
break
}
}

if !found {
mergedEntries = append(mergedEntries, entry)
}
}

if len(statEntries) == 0 {
marshalledEntries, err := json.Marshal(allStats)
if err != nil {
log.Printf("[ERROR] Failed marshal in get org stats: %s", err)
resp.WriteHeader(500)
resp.Write([]byte(fmt.Sprintf(`{"success": false, "reason": "Failed unpacking data for org stats"}`)))
return
statEntries = mergedEntries

// Check if entries exist for the last X statDays
// Backfill any missing ones
if len(statEntries) < statDays {
// Find the missing days
missingDays := []time.Time{}
for i := 0; i < statDays; i++ {
missingDays = append(missingDays, time.Now().AddDate(0, 0, -i))
}

resp.WriteHeader(200)
resp.Write([]byte(fmt.Sprintf(`{"success": false, "key": "%s", "total": %d, "available_entries": %s, "entries": []}`, statsKey, totalValue, string(marshalledEntries))))
return
// Find the missing entries
appended := 0
foundAmount := 0
toAppend := []AdditionalUseConfig{}
for _, missingDay := range missingDays {
found := false
for _, entry := range statEntries {
if entry.Date.Day() == missingDay.Day() && entry.Date.Month() == missingDay.Month() && entry.Date.Year() == missingDay.Year() {
foundAmount += 1
found = true
break
}
}

if !found {
appended += 1
toAppend = append(toAppend, AdditionalUseConfig{
Key: statsKey,
Value: 0,
Date: missingDay,
})
}
}

statEntries = append(statEntries, toAppend...)
}

// Sort statentries by date
sort.Slice(statEntries, func(i, j int) bool {
return statEntries[i].Date.Before(statEntries[j].Date)
})

marshalledEntries, err := json.Marshal(statEntries)
if err != nil {
log.Printf("[ERROR] Failed marshal in get org stats: %s", err)
Expand All @@ -380,9 +454,18 @@ func GetSpecificStats(resp http.ResponseWriter, request *http.Request) {
return
}

availableStats, err := json.Marshal(allStats)
if err != nil {
log.Printf("[ERROR] Failed marshal in get org stats: %s", err)
resp.WriteHeader(500)
resp.Write([]byte(fmt.Sprintf(`{"success": false, "reason": "Failed unpacking data for org stats"}`)))
return
}

successful := totalValue != 0

resp.WriteHeader(200)
resp.Write([]byte(fmt.Sprintf(`{"success": true, "key": "%s", "total": %d, "entries": %s}`, statsKey, totalValue, string(marshalledEntries))))
resp.Write([]byte(fmt.Sprintf(`{"success": %v, "key": "%s", "total": %d, "available_keys": %s, "entries": %s}`, successful, statsKey, totalValue, string(availableStats), string(marshalledEntries))))
}

func HandleGetStatistics(resp http.ResponseWriter, request *http.Request) {
Expand Down

0 comments on commit a55f6e8

Please sign in to comment.