From 66bdca3aa68d268382cc544c08e092bf9ff37526 Mon Sep 17 00:00:00 2001 From: Albin Antony Date: Tue, 26 Sep 2023 13:56:37 +0530 Subject: [PATCH] Fix #180 Database query built from user-controlled sources --- src/common/utils.go | 8 +++++++ src/handler/actionlog_handler.go | 4 +++- src/handler/consent_handler.go | 30 ++++++++++++++++++++++----- src/handler/consenthistory_handler.go | 9 ++++++-- src/handler/datarequest_handler.go | 16 +++++++++----- src/handler/iam_handler.go | 26 ++++++++++++++++------- src/handler/organization_handler.go | 11 +++++++--- src/handler/webhooks_handler.go | 22 +++++++++++++++----- 8 files changed, 98 insertions(+), 28 deletions(-) diff --git a/src/common/utils.go b/src/common/utils.go index ee75939..515e715 100644 --- a/src/common/utils.go +++ b/src/common/utils.go @@ -10,6 +10,8 @@ import ( "strconv" "strings" "time" + + "github.com/microcosm-cc/bluemonday" ) const ( @@ -142,3 +144,9 @@ func GetRandomString(length int) string { } return b.String() } + +// Sanitize sanitizes the string +func Sanitize(s string) string { + p := bluemonday.UGCPolicy() + return p.Sanitize(s) +} diff --git a/src/handler/actionlog_handler.go b/src/handler/actionlog_handler.go index 0a090f6..74e8123 100644 --- a/src/handler/actionlog_handler.go +++ b/src/handler/actionlog_handler.go @@ -34,7 +34,9 @@ func GetOrgLogs(w http.ResponseWriter, r *http.Request) { limit = 50 } - logs, lastID, err := actionlog.GetAccessLogByOrgID(orgID, startID, limit) + sanitizedOrgId := common.Sanitize(orgID) + + logs, lastID, err := actionlog.GetAccessLogByOrgID(sanitizedOrgId, startID, limit) if err != nil { m := fmt.Sprintf("Failed to get logs for organization: %v", orgID) common.HandleError(w, http.StatusInternalServerError, m, err) diff --git a/src/handler/consent_handler.go b/src/handler/consent_handler.go index 387ae20..c2f1412 100644 --- a/src/handler/consent_handler.go +++ b/src/handler/consent_handler.go @@ -85,7 +85,10 @@ func GetConsentResponse(w http.ResponseWriter, userID string, orgID string) (Con return c, err } - consents, err := consent.GetByUserOrg(userID, orgID) + sanitizedOrgId := common.Sanitize(orgID) + sanitizedUserId := common.Sanitize(userID) + + consents, err := consent.GetByUserOrg(sanitizedUserId, sanitizedOrgId) if err != nil { if err.Error() == "not found" { var con consent.Consents @@ -124,7 +127,10 @@ func GetConsents(w http.ResponseWriter, r *http.Request) { return } - c, err := GetConsentResponse(w, userID, orgID) + sanitizedOrgId := common.Sanitize(orgID) + sanitizedUserId := common.Sanitize(userID) + + c, err := GetConsentResponse(w, sanitizedUserId, sanitizedOrgId) if err != nil { log.Printf("Failed to get consents for user: %v org: %v err: %v", userID, orgID, err) return @@ -273,6 +279,10 @@ func GetConsentPurposeByID(w http.ResponseWriter, r *http.Request) { cpResp.UserID = userID cpResp.Consents = cp + sanitizedOrgId := common.Sanitize(orgID) + sanitizedUserId := common.Sanitize(userID) + sanitizedPurposeId := common.Sanitize(purposeID) + // Data retention expiry if o.DataRetention.Enabled { @@ -287,7 +297,7 @@ func GetConsentPurposeByID(w http.ResponseWriter, r *http.Request) { if isPurposeAllowed { - latestConsentHistory, err := consenthistory.GetLatestByUserOrgPurposeID(userID, orgID, purposeID) + latestConsentHistory, err := consenthistory.GetLatestByUserOrgPurposeID(sanitizedUserId, sanitizedOrgId, sanitizedPurposeId) if err != nil { response, _ := json.Marshal(cpResp) w.Header().Set(config.ContentTypeHeader, config.ContentTypeJSON) @@ -345,7 +355,13 @@ func GetAllUsersConsentedToAttribute(w http.ResponseWriter, r *http.Request) { w.Write(response) return } - userIDs, nextID, err := consent.GetConsentedUsers(orgID, purposeID, attributeID, startID, limit) + + sanitizedOrgId := common.Sanitize(orgID) + sanitizedPurposeId := common.Sanitize(purposeID) + sanitizedAttributeId := common.Sanitize(attributeID) + sanitizedStartId := common.Sanitize(startID) + + userIDs, nextID, err := consent.GetConsentedUsers(sanitizedOrgId, sanitizedPurposeId, sanitizedAttributeId, sanitizedStartId, limit) if err != nil { m := fmt.Sprintf("Failed to fetch users constented orgID: %v purposeID: %v attributeID: %v", orgID, purposeID, attributeID) @@ -454,7 +470,11 @@ func GetAllUsersConsentedToPurpose(w http.ResponseWriter, r *http.Request) { w.Write(response) return } - userIDs, nextID, err := consent.GetPurposeConsentedAllUsers(orgID, purposeID, startID, limit) + sanitizedOrgId := common.Sanitize(orgID) + sanitizedPurposeId := common.Sanitize(purposeID) + sanitizedStartId := common.Sanitize(startID) + + userIDs, nextID, err := consent.GetPurposeConsentedAllUsers(sanitizedOrgId, sanitizedPurposeId, sanitizedStartId, limit) if err != nil { m := fmt.Sprintf("Failed to fetch users constented orgID: %v purposeID: %v ", orgID, purposeID) diff --git a/src/handler/consenthistory_handler.go b/src/handler/consenthistory_handler.go index de02b85..7239c5d 100644 --- a/src/handler/consenthistory_handler.go +++ b/src/handler/consenthistory_handler.go @@ -95,14 +95,19 @@ func GetUserConsentHistory(w http.ResponseWriter, r *http.Request) { log.Printf("start: %v orgId: %v purposeid: %v limit: %v start:%v end:%v", startID, orgID, purposeID, limit, startDate, endDate) if orgID != "" && purposeID != "" { - chs, lastID, err = consenthistory.GetByUserOrgPurposeID(userID, orgID, purposeID, startID, limit) + sanitizedOrgId := common.Sanitize(orgID) + sanitizedPurposeId := common.Sanitize(purposeID) + + chs, lastID, err = consenthistory.GetByUserOrgPurposeID(userID, sanitizedOrgId, sanitizedPurposeId, startID, limit) if err != nil { m := fmt.Sprintf("Failed to get consent history for user id:%v orgID: %v purposeID : %v", userID, orgID, purposeID) common.HandleError(w, http.StatusNotFound, m, err) return } } else if orgID != "" { - chs, lastID, err = consenthistory.GetByUserOrgID(userID, orgID, startID, limit) + sanitizedOrgId := common.Sanitize(orgID) + + chs, lastID, err = consenthistory.GetByUserOrgID(userID, sanitizedOrgId, startID, limit) if err != nil { m := fmt.Sprintf("Failed to get consent history for user id:%v orgID: %v", userID, orgID) common.HandleError(w, http.StatusNotFound, m, err) diff --git a/src/handler/datarequest_handler.go b/src/handler/datarequest_handler.go index b14ade2..674d174 100644 --- a/src/handler/datarequest_handler.go +++ b/src/handler/datarequest_handler.go @@ -18,7 +18,9 @@ func GetDeleteMyData(w http.ResponseWriter, r *http.Request) { orgID := mux.Vars(r)["orgID"] userID := token.GetUserID(r) - drs, err := getDataReqWithUserOrgTypeID(userID, orgID, dr.DataRequestTypeDelete) + sanitizedOrgId := common.Sanitize(orgID) + + drs, err := getDataReqWithUserOrgTypeID(userID, sanitizedOrgId, dr.DataRequestTypeDelete) if err != nil { m := fmt.Sprintf("Failed to get user: %v data request for organization: %v", userID, orgID) @@ -58,12 +60,14 @@ func GetMyOrgDataRequestStatus(w http.ResponseWriter, r *http.Request) { var dReqs []dr.DataRequest var lastID string + sanitizedOrgId := common.Sanitize(orgID) + if requestStatus == "open" { - dReqs, lastID, err = dr.GetOpenDataRequestsByOrgUserID(orgID, userID, startID, limit) + dReqs, lastID, err = dr.GetOpenDataRequestsByOrgUserID(sanitizedOrgId, userID, startID, limit) } else if requestStatus == "closed" { - dReqs, lastID, err = dr.GetClosedDataRequestsByOrgUserID(orgID, userID, startID, limit) + dReqs, lastID, err = dr.GetClosedDataRequestsByOrgUserID(sanitizedOrgId, userID, startID, limit) } else { - dReqs, lastID, err = dr.GetDataRequestsByOrgUserID(orgID, userID, startID, limit) + dReqs, lastID, err = dr.GetDataRequestsByOrgUserID(sanitizedOrgId, userID, startID, limit) } if err != nil { @@ -151,7 +155,9 @@ func DeleteMyData(w http.ResponseWriter, r *http.Request) { orgID := mux.Vars(r)["orgID"] userID := token.GetUserID(r) - resp, err := getOngoingDataRequest(userID, orgID, dr.DataRequestTypeDelete) + sanitizedOrgId := common.Sanitize(orgID) + + resp, err := getOngoingDataRequest(userID, sanitizedOrgId, dr.DataRequestTypeDelete) if err == nil && resp.RequestOngoing == true { m := fmt.Sprintf("Request (%v) ongoing for user: %v organization: %v", dr.GetRequestTypeStr(dr.DataRequestTypeDelete), userID, orgID) diff --git a/src/handler/iam_handler.go b/src/handler/iam_handler.go index 5463371..4a0d1bf 100644 --- a/src/handler/iam_handler.go +++ b/src/handler/iam_handler.go @@ -520,8 +520,10 @@ func LoginUser(w http.ResponseWriter, r *http.Request) { common.HandleError(w, status, m, err) return } + sanitizedUserName := common.Sanitize(lReq.Username) + //TODO: Remove me when the auth server is per dev environment - u, err := user.GetByEmail(lReq.Username) + u, err := user.GetByEmail(sanitizedUserName) if err != nil { m := fmt.Sprintf("Login failed for non existant user:%v", lReq.Username) common.HandleError(w, http.StatusUnauthorized, m, err) @@ -689,8 +691,10 @@ func ValidateUserEmail(w http.ResponseWriter, r *http.Request) { valResp.Result = true valResp.Message = "Email address is valid and not in use in our system" + sanitizedEmail := common.Sanitize(validateReq.Email) + //Check whether the email is unique - exist, err := user.EmailExist(validateReq.Email) + exist, err := user.EmailExist(sanitizedEmail) if err != nil { m := fmt.Sprintf("Failed to validate user email: %v", validateReq.Email) common.HandleError(w, http.StatusInternalServerError, m, err) @@ -731,8 +735,10 @@ func ValidatePhoneNumber(w http.ResponseWriter, r *http.Request) { valResp.Result = true valResp.Message = "Phone number is not in use" + sanitizedPhoneNumber := common.Sanitize(validateReq.Phone) + //Check whether the phone number is unique - exist, err := user.PhoneNumberExist(validateReq.Phone) + exist, err := user.PhoneNumberExist(sanitizedPhoneNumber) if err != nil { m := fmt.Sprintf("Failed to validate user phone number: %v", validateReq.Phone) common.HandleError(w, http.StatusInternalServerError, m, err) @@ -750,7 +756,7 @@ func ValidatePhoneNumber(w http.ResponseWriter, r *http.Request) { } //Check whether the phone number is in otp colleciton - o, err := otp.PhoneNumberExist(validateReq.Phone) + o, err := otp.PhoneNumberExist(sanitizedPhoneNumber) if err != nil { m := fmt.Sprintf("Failed to validate user phone number: %v", validateReq.Phone) common.HandleError(w, http.StatusInternalServerError, m, err) @@ -829,7 +835,9 @@ func verifyPhoneNumber(w http.ResponseWriter, r *http.Request, clientType int) { o.Phone = verifyReq.Phone o.Otp = vCode - oldOtp, err := otp.SearchPhone(o.Phone) + sanitizedPhoneNumber := common.Sanitize(o.Phone) + + oldOtp, err := otp.SearchPhone(sanitizedPhoneNumber) if err == nil { otp.Delete(oldOtp.ID.Hex()) } @@ -863,7 +871,9 @@ func VerifyOtp(w http.ResponseWriter, r *http.Request) { return } - o, err := otp.SearchPhone(otpReq.Phone) + sanitizedPhoneNumber := common.Sanitize(otpReq.Phone) + + o, err := otp.SearchPhone(sanitizedPhoneNumber) if err != nil { valResp.Result = false valResp.Message = "Unregistered phone number: " + otpReq.Phone @@ -1007,8 +1017,10 @@ func ForgotPassword(w http.ResponseWriter, r *http.Request) { log.Printf("User: %v forgot password", fp.Username) + sanitizedUserName := common.Sanitize(fp.Username) + //Get user details from DB - u, err := user.GetByEmail(fp.Username) + u, err := user.GetByEmail(sanitizedUserName) if err != nil { log.Printf("User with %v doesnt exist", fp.Username) handleError(w, fp.Username, http.StatusNotFound, iamError{}, err) diff --git a/src/handler/organization_handler.go b/src/handler/organization_handler.go index eff2d6f..190d89b 100644 --- a/src/handler/organization_handler.go +++ b/src/handler/organization_handler.go @@ -1445,7 +1445,10 @@ func DeleteUserFromOrganization(w http.ResponseWriter, r *http.Request) { return } - err = consent.DeleteByUserOrg(userID, organizationID) + sanitizedOrgId := common.Sanitize(organizationID) + sanitizedUserId := common.Sanitize(userID) + + err = consent.DeleteByUserOrg(sanitizedUserId, sanitizedOrgId) if err != nil { m := fmt.Sprintf("Failed to remove user :%v consents from organization:%v", userID, organizationID) common.HandleError(w, http.StatusInternalServerError, m, err) @@ -1987,10 +1990,12 @@ func GetDataRequests(w http.ResponseWriter, r *http.Request) { var dReqs []dr.DataRequest var lastID string + sanitizedOrgId := common.Sanitize(orgID) + if requestStatus == "open" { - dReqs, lastID, err = dr.GetOpenDataRequestsByOrgID(orgID, startID, limit) + dReqs, lastID, err = dr.GetOpenDataRequestsByOrgID(sanitizedOrgId, startID, limit) } else if requestStatus == "closed" { - dReqs, lastID, err = dr.GetClosedDataRequestsByOrgID(orgID, startID, limit) + dReqs, lastID, err = dr.GetClosedDataRequestsByOrgID(sanitizedOrgId, startID, limit) } else { m := fmt.Sprintf("Incorrect query parameter: %v to get data requests for organization: %v", requestStatus, orgID) common.HandleError(w, http.StatusBadRequest, m, nil) diff --git a/src/handler/webhooks_handler.go b/src/handler/webhooks_handler.go index f278bad..63cda3d 100644 --- a/src/handler/webhooks_handler.go +++ b/src/handler/webhooks_handler.go @@ -116,9 +116,11 @@ func CreateWebhook(w http.ResponseWriter, r *http.Request) { common.HandleError(w, http.StatusBadRequest, m, err) return } + sanitizedOrgId := common.Sanitize(organizationID) + sanitizedPayloadURL := common.Sanitize(requestPayload.PayloadURL) // Check if webhook with provided payload URL already exists - count, err := wh.GetWebhookCountByPayloadURL(organizationID, requestPayload.PayloadURL) + count, err := wh.GetWebhookCountByPayloadURL(sanitizedOrgId, sanitizedPayloadURL) if err != nil { m := fmt.Sprintf("Failed to create webhook for organisation:%v", organizationID) common.HandleError(w, http.StatusInternalServerError, m, err) @@ -226,8 +228,10 @@ func GetAllWebhooks(w http.ResponseWriter, r *http.Request) { return } + sanitizedOrgId := common.Sanitize(organizationID) + // Fetching all the webhooks for an organisation - webhooks, err := wh.GetAllWebhooksByOrgID(organizationID) + webhooks, err := wh.GetAllWebhooksByOrgID(sanitizedOrgId) if err != nil { m := fmt.Sprintf("Failed to fetch webhooks for organization: %v", organizationID) common.HandleError(w, http.StatusInternalServerError, m, err) @@ -287,8 +291,10 @@ func GetWebhook(w http.ResponseWriter, r *http.Request) { return } + sanitizedOrgId := common.Sanitize(organizationID) + // Fetching webhook by ID for an organisation - webhook, err := wh.GetByOrgID(webhookID, organizationID) + webhook, err := wh.GetByOrgID(webhookID, sanitizedOrgId) if err != nil { m := fmt.Sprintf("Failed to get webhook:%v for organisation: %v", webhookID, organizationID) common.HandleError(w, http.StatusNotFound, m, err) @@ -315,8 +321,11 @@ func DeleteWebhook(w http.ResponseWriter, r *http.Request) { return } + sanitizedOrgId := common.Sanitize(organizationID) + sanitizedWebhookId := common.Sanitize(webhookID) + // Validating the given webhook ID for an organisation - _, err = wh.GetByOrgID(webhookID, organizationID) + _, err = wh.GetByOrgID(sanitizedWebhookId, sanitizedOrgId) if err != nil { m := fmt.Sprintf("Failed to get webhook:%v for organisation: %v", webhookID, organizationID) common.HandleError(w, http.StatusBadRequest, m, err) @@ -392,8 +401,11 @@ func UpdateWebhook(w http.ResponseWriter, r *http.Request) { return } + sanitizedOrgId := common.Sanitize(organizationID) + sanitizedPayloadURL := common.Sanitize(requestPayload.PayloadURL) + // Check if webhook with provided payload URL already exists - tempWebhook, err := wh.GetWebhookByPayloadURL(organizationID, requestPayload.PayloadURL) + tempWebhook, err := wh.GetWebhookByPayloadURL(sanitizedOrgId, sanitizedPayloadURL) if err == nil { if tempWebhook.ID.Hex() != webhookID { m := fmt.Sprintf("Webhook with provided payload URL already exists; Failed to update webhook for organisation:%v", organizationID)