Skip to content

Commit

Permalink
Fix #180 Database query built from user-controlled sources
Browse files Browse the repository at this point in the history
  • Loading branch information
albinpa committed Sep 26, 2023
1 parent 44bbb4e commit b5d567a
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 35 deletions.
8 changes: 8 additions & 0 deletions src/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"strconv"
"strings"
"time"

"github.com/microcosm-cc/bluemonday"
)

const (
Expand Down Expand Up @@ -142,3 +144,9 @@ func GetRandomString(length int) string {
}
return b.String()
}

// Sanitize sanitizes the string
func Sanitize(s string) string {
p := bluemonday.UGCPolicy()
return p.Sanitize(s)
}
4 changes: 3 additions & 1 deletion src/handler/actionlog_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ func GetOrgLogs(w http.ResponseWriter, r *http.Request) {
limit = 50
}

logs, lastID, err := actionlog.GetAccessLogByOrgID(orgID, startID, limit)
sanitizedOrgId := common.Sanitize(orgID)

logs, lastID, err := actionlog.GetAccessLogByOrgID(sanitizedOrgId, startID, limit)
if err != nil {
m := fmt.Sprintf("Failed to get logs for organization: %v", orgID)
common.HandleError(w, http.StatusInternalServerError, m, err)
Expand Down
35 changes: 29 additions & 6 deletions src/handler/consent_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ func GetConsentResponse(w http.ResponseWriter, userID string, orgID string) (Con
return c, err
}

consents, err := consent.GetByUserOrg(userID, orgID)
sanitizedOrgId := common.Sanitize(orgID)
sanitizedUserId := common.Sanitize(userID)

consents, err := consent.GetByUserOrg(sanitizedUserId, sanitizedOrgId)
if err != nil {
if err.Error() == "not found" {
var con consent.Consents
Expand Down Expand Up @@ -124,7 +127,10 @@ func GetConsents(w http.ResponseWriter, r *http.Request) {
return
}

c, err := GetConsentResponse(w, userID, orgID)
sanitizedOrgId := common.Sanitize(orgID)
sanitizedUserId := common.Sanitize(userID)

c, err := GetConsentResponse(w, sanitizedUserId, sanitizedOrgId)
if err != nil {
log.Printf("Failed to get consents for user: %v org: %v err: %v", userID, orgID, err)
return
Expand Down Expand Up @@ -273,6 +279,10 @@ func GetConsentPurposeByID(w http.ResponseWriter, r *http.Request) {
cpResp.UserID = userID
cpResp.Consents = cp

sanitizedOrgId := common.Sanitize(orgID)
sanitizedUserId := common.Sanitize(userID)
sanitizedPurposeId := common.Sanitize(purposeID)

// Data retention expiry
if o.DataRetention.Enabled {

Expand All @@ -287,7 +297,7 @@ func GetConsentPurposeByID(w http.ResponseWriter, r *http.Request) {

if isPurposeAllowed {

latestConsentHistory, err := consenthistory.GetLatestByUserOrgPurposeID(userID, orgID, purposeID)
latestConsentHistory, err := consenthistory.GetLatestByUserOrgPurposeID(sanitizedUserId, sanitizedOrgId, sanitizedPurposeId)
if err != nil {
response, _ := json.Marshal(cpResp)
w.Header().Set(config.ContentTypeHeader, config.ContentTypeJSON)
Expand Down Expand Up @@ -345,7 +355,13 @@ func GetAllUsersConsentedToAttribute(w http.ResponseWriter, r *http.Request) {
w.Write(response)
return
}
userIDs, nextID, err := consent.GetConsentedUsers(orgID, purposeID, attributeID, startID, limit)

sanitizedOrgId := common.Sanitize(orgID)
sanitizedPurposeId := common.Sanitize(purposeID)
sanitizedAttributeId := common.Sanitize(attributeID)
sanitizedStartId := common.Sanitize(startID)

userIDs, nextID, err := consent.GetConsentedUsers(sanitizedOrgId, sanitizedPurposeId, sanitizedAttributeId, sanitizedStartId, limit)

if err != nil {
m := fmt.Sprintf("Failed to fetch users constented orgID: %v purposeID: %v attributeID: %v", orgID, purposeID, attributeID)
Expand Down Expand Up @@ -454,7 +470,11 @@ func GetAllUsersConsentedToPurpose(w http.ResponseWriter, r *http.Request) {
w.Write(response)
return
}
userIDs, nextID, err := consent.GetPurposeConsentedAllUsers(orgID, purposeID, startID, limit)
sanitizedOrgId := common.Sanitize(orgID)
sanitizedPurposeId := common.Sanitize(purposeID)
sanitizedStartId := common.Sanitize(startID)

userIDs, nextID, err := consent.GetPurposeConsentedAllUsers(sanitizedOrgId, sanitizedPurposeId, sanitizedStartId, limit)

if err != nil {
m := fmt.Sprintf("Failed to fetch users constented orgID: %v purposeID: %v ", orgID, purposeID)
Expand Down Expand Up @@ -622,6 +642,9 @@ func UpdatePurposeAllConsentsv2(w http.ResponseWriter, r *http.Request) {
cRespWithDataRetention.UserID = cResp.UserID
cRespWithDataRetention.ID = cResp.ID

sanitizedOrgId := common.Sanitize(orgID)
sanitizedUserId := common.Sanitize(userID)

for i, _ := range cResp.ConsentsAndPurposes {
var tempConsentsAndPurposeWithDataRetention consentsAndPurposeWithDataRetention
tempConsentsAndPurposeWithDataRetention.Consents = cResp.ConsentsAndPurposes[i].Consents
Expand All @@ -632,7 +655,7 @@ func UpdatePurposeAllConsentsv2(w http.ResponseWriter, r *http.Request) {

// Check if purpose is allowed
if cResp.ConsentsAndPurposes[i].Count.Consented > 0 {
latestConsentHistory, err := consenthistory.GetLatestByUserOrgPurposeID(userID, orgID, cResp.ConsentsAndPurposes[i].Purpose.ID)
latestConsentHistory, err := consenthistory.GetLatestByUserOrgPurposeID(sanitizedUserId, sanitizedOrgId, cResp.ConsentsAndPurposes[i].Purpose.ID)
if err != nil {
cRespWithDataRetention.ConsentsAndPurposes = append(cRespWithDataRetention.ConsentsAndPurposes, tempConsentsAndPurposeWithDataRetention)
continue
Expand Down
9 changes: 7 additions & 2 deletions src/handler/consenthistory_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,19 @@ func GetUserConsentHistory(w http.ResponseWriter, r *http.Request) {

log.Printf("start: %v orgId: %v purposeid: %v limit: %v start:%v end:%v", startID, orgID, purposeID, limit, startDate, endDate)
if orgID != "" && purposeID != "" {
chs, lastID, err = consenthistory.GetByUserOrgPurposeID(userID, orgID, purposeID, startID, limit)
sanitizedOrgId := common.Sanitize(orgID)
sanitizedPurposeId := common.Sanitize(purposeID)

chs, lastID, err = consenthistory.GetByUserOrgPurposeID(userID, sanitizedOrgId, sanitizedPurposeId, startID, limit)
if err != nil {
m := fmt.Sprintf("Failed to get consent history for user id:%v orgID: %v purposeID : %v", userID, orgID, purposeID)
common.HandleError(w, http.StatusNotFound, m, err)
return
}
} else if orgID != "" {
chs, lastID, err = consenthistory.GetByUserOrgID(userID, orgID, startID, limit)
sanitizedOrgId := common.Sanitize(orgID)

chs, lastID, err = consenthistory.GetByUserOrgID(userID, sanitizedOrgId, startID, limit)
if err != nil {
m := fmt.Sprintf("Failed to get consent history for user id:%v orgID: %v", userID, orgID)
common.HandleError(w, http.StatusNotFound, m, err)
Expand Down
28 changes: 20 additions & 8 deletions src/handler/datarequest_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ func GetDeleteMyData(w http.ResponseWriter, r *http.Request) {
orgID := mux.Vars(r)["orgID"]
userID := token.GetUserID(r)

drs, err := getDataReqWithUserOrgTypeID(userID, orgID, dr.DataRequestTypeDelete)
sanitizedOrgId := common.Sanitize(orgID)

drs, err := getDataReqWithUserOrgTypeID(userID, sanitizedOrgId, dr.DataRequestTypeDelete)

if err != nil {
m := fmt.Sprintf("Failed to get user: %v data request for organization: %v", userID, orgID)
Expand Down Expand Up @@ -58,12 +60,14 @@ func GetMyOrgDataRequestStatus(w http.ResponseWriter, r *http.Request) {
var dReqs []dr.DataRequest
var lastID string

sanitizedOrgId := common.Sanitize(orgID)

if requestStatus == "open" {
dReqs, lastID, err = dr.GetOpenDataRequestsByOrgUserID(orgID, userID, startID, limit)
dReqs, lastID, err = dr.GetOpenDataRequestsByOrgUserID(sanitizedOrgId, userID, startID, limit)
} else if requestStatus == "closed" {
dReqs, lastID, err = dr.GetClosedDataRequestsByOrgUserID(orgID, userID, startID, limit)
dReqs, lastID, err = dr.GetClosedDataRequestsByOrgUserID(sanitizedOrgId, userID, startID, limit)
} else {
dReqs, lastID, err = dr.GetDataRequestsByOrgUserID(orgID, userID, startID, limit)
dReqs, lastID, err = dr.GetDataRequestsByOrgUserID(sanitizedOrgId, userID, startID, limit)
}

if err != nil {
Expand Down Expand Up @@ -151,7 +155,9 @@ func DeleteMyData(w http.ResponseWriter, r *http.Request) {
orgID := mux.Vars(r)["orgID"]
userID := token.GetUserID(r)

resp, err := getOngoingDataRequest(userID, orgID, dr.DataRequestTypeDelete)
sanitizedOrgId := common.Sanitize(orgID)

resp, err := getOngoingDataRequest(userID, sanitizedOrgId, dr.DataRequestTypeDelete)

if err == nil && resp.RequestOngoing == true {
m := fmt.Sprintf("Request (%v) ongoing for user: %v organization: %v", dr.GetRequestTypeStr(dr.DataRequestTypeDelete), userID, orgID)
Expand Down Expand Up @@ -184,7 +190,9 @@ func GetDeleteMyDataStatus(w http.ResponseWriter, r *http.Request) {
orgID := mux.Vars(r)["orgID"]
userID := token.GetUserID(r)

resp, err := getOngoingDataRequest(userID, orgID, dr.DataRequestTypeDelete)
sanitizedOrgId := common.Sanitize(orgID)

resp, err := getOngoingDataRequest(userID, sanitizedOrgId, dr.DataRequestTypeDelete)

if err != nil {
m := fmt.Sprintf("Failed to get user: %v data request for organization: %v", userID, orgID)
Expand Down Expand Up @@ -250,7 +258,9 @@ func GetDownloadMyData(w http.ResponseWriter, r *http.Request) {
orgID := mux.Vars(r)["orgID"]
userID := token.GetUserID(r)

drs, err := getDataReqWithUserOrgTypeID(userID, orgID, dr.DataRequestTypeDownload)
sanitizedOrgId := common.Sanitize(orgID)

drs, err := getDataReqWithUserOrgTypeID(userID, sanitizedOrgId, dr.DataRequestTypeDownload)
if err != nil {
m := fmt.Sprintf("Failed to get user: %v data request for organization: %v", userID, orgID)
common.HandleError(w, http.StatusInternalServerError, m, err)
Expand All @@ -267,7 +277,9 @@ func DownloadMyData(w http.ResponseWriter, r *http.Request) {
orgID := mux.Vars(r)["orgID"]
userID := token.GetUserID(r)

resp, err := getOngoingDataRequest(userID, orgID, dr.DataRequestTypeDownload)
sanitizedOrgId := common.Sanitize(orgID)

resp, err := getOngoingDataRequest(userID, sanitizedOrgId, dr.DataRequestTypeDownload)

if err == nil && resp.RequestOngoing {
m := fmt.Sprintf("Request (%v) ongoing for user: %v organization: %v", dr.GetRequestTypeStr(dr.DataRequestTypeDownload), userID, orgID)
Expand Down
26 changes: 19 additions & 7 deletions src/handler/iam_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,10 @@ func LoginUser(w http.ResponseWriter, r *http.Request) {
common.HandleError(w, status, m, err)
return
}
sanitizedUserName := common.Sanitize(lReq.Username)

//TODO: Remove me when the auth server is per dev environment
u, err := user.GetByEmail(lReq.Username)
u, err := user.GetByEmail(sanitizedUserName)
if err != nil {
m := fmt.Sprintf("Login failed for non existant user:%v", lReq.Username)
common.HandleError(w, http.StatusUnauthorized, m, err)
Expand Down Expand Up @@ -689,8 +691,10 @@ func ValidateUserEmail(w http.ResponseWriter, r *http.Request) {
valResp.Result = true
valResp.Message = "Email address is valid and not in use in our system"

sanitizedEmail := common.Sanitize(validateReq.Email)

//Check whether the email is unique
exist, err := user.EmailExist(validateReq.Email)
exist, err := user.EmailExist(sanitizedEmail)
if err != nil {
m := fmt.Sprintf("Failed to validate user email: %v", validateReq.Email)
common.HandleError(w, http.StatusInternalServerError, m, err)
Expand Down Expand Up @@ -731,8 +735,10 @@ func ValidatePhoneNumber(w http.ResponseWriter, r *http.Request) {
valResp.Result = true
valResp.Message = "Phone number is not in use"

sanitizedPhoneNumber := common.Sanitize(validateReq.Phone)

//Check whether the phone number is unique
exist, err := user.PhoneNumberExist(validateReq.Phone)
exist, err := user.PhoneNumberExist(sanitizedPhoneNumber)
if err != nil {
m := fmt.Sprintf("Failed to validate user phone number: %v", validateReq.Phone)
common.HandleError(w, http.StatusInternalServerError, m, err)
Expand All @@ -750,7 +756,7 @@ func ValidatePhoneNumber(w http.ResponseWriter, r *http.Request) {
}

//Check whether the phone number is in otp colleciton
o, err := otp.PhoneNumberExist(validateReq.Phone)
o, err := otp.PhoneNumberExist(sanitizedPhoneNumber)
if err != nil {
m := fmt.Sprintf("Failed to validate user phone number: %v", validateReq.Phone)
common.HandleError(w, http.StatusInternalServerError, m, err)
Expand Down Expand Up @@ -829,7 +835,9 @@ func verifyPhoneNumber(w http.ResponseWriter, r *http.Request, clientType int) {
o.Phone = verifyReq.Phone
o.Otp = vCode

oldOtp, err := otp.SearchPhone(o.Phone)
sanitizedPhoneNumber := common.Sanitize(o.Phone)

oldOtp, err := otp.SearchPhone(sanitizedPhoneNumber)
if err == nil {
otp.Delete(oldOtp.ID.Hex())
}
Expand Down Expand Up @@ -863,7 +871,9 @@ func VerifyOtp(w http.ResponseWriter, r *http.Request) {
return
}

o, err := otp.SearchPhone(otpReq.Phone)
sanitizedPhoneNumber := common.Sanitize(otpReq.Phone)

o, err := otp.SearchPhone(sanitizedPhoneNumber)
if err != nil {
valResp.Result = false
valResp.Message = "Unregistered phone number: " + otpReq.Phone
Expand Down Expand Up @@ -1007,8 +1017,10 @@ func ForgotPassword(w http.ResponseWriter, r *http.Request) {

log.Printf("User: %v forgot password", fp.Username)

sanitizedUserName := common.Sanitize(fp.Username)

//Get user details from DB
u, err := user.GetByEmail(fp.Username)
u, err := user.GetByEmail(sanitizedUserName)
if err != nil {
log.Printf("User with %v doesnt exist", fp.Username)
handleError(w, fp.Username, http.StatusNotFound, iamError{}, err)
Expand Down
11 changes: 8 additions & 3 deletions src/handler/organization_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -1445,7 +1445,10 @@ func DeleteUserFromOrganization(w http.ResponseWriter, r *http.Request) {
return
}

err = consent.DeleteByUserOrg(userID, organizationID)
sanitizedOrgId := common.Sanitize(organizationID)
sanitizedUserId := common.Sanitize(userID)

err = consent.DeleteByUserOrg(sanitizedUserId, sanitizedOrgId)
if err != nil {
m := fmt.Sprintf("Failed to remove user :%v consents from organization:%v", userID, organizationID)
common.HandleError(w, http.StatusInternalServerError, m, err)
Expand Down Expand Up @@ -1987,10 +1990,12 @@ func GetDataRequests(w http.ResponseWriter, r *http.Request) {
var dReqs []dr.DataRequest
var lastID string

sanitizedOrgId := common.Sanitize(orgID)

if requestStatus == "open" {
dReqs, lastID, err = dr.GetOpenDataRequestsByOrgID(orgID, startID, limit)
dReqs, lastID, err = dr.GetOpenDataRequestsByOrgID(sanitizedOrgId, startID, limit)
} else if requestStatus == "closed" {
dReqs, lastID, err = dr.GetClosedDataRequestsByOrgID(orgID, startID, limit)
dReqs, lastID, err = dr.GetClosedDataRequestsByOrgID(sanitizedOrgId, startID, limit)
} else {
m := fmt.Sprintf("Incorrect query parameter: %v to get data requests for organization: %v", requestStatus, orgID)
common.HandleError(w, http.StatusBadRequest, m, nil)
Expand Down
Loading

0 comments on commit b5d567a

Please sign in to comment.