diff --git a/go.mod b/go.mod index 89f5d57..f42a3ad 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.3 // indirect + github.com/golang/snappy v0.0.4 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/google/s2a-go v0.1.5 // indirect github.com/google/uuid v1.3.0 // indirect @@ -29,12 +30,19 @@ require ( github.com/gorilla/css v1.0.0 // indirect github.com/gorilla/mux v1.8.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/klauspost/compress v1.13.6 // indirect github.com/microcosm-cc/bluemonday v1.0.25 // indirect + github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect github.com/spf13/cobra v1.7.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/tidwall/gjson v1.14.4 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.1.2 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect + github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect + go.mongodb.org/mongo-driver v1.12.1 // indirect go.opencensus.io v0.24.0 // indirect golang.org/x/crypto v0.12.0 // indirect golang.org/x/net v0.14.0 // indirect diff --git a/go.sum b/go.sum index b4bbc9b..b6a401f 100644 --- a/go.sum +++ b/go.sum @@ -73,11 +73,15 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= @@ -100,8 +104,12 @@ github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB7 github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= +github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/microcosm-cc/bluemonday v1.0.25 h1:4NEwSfiJ+Wva0VxN5B8OwMicaJvD8r9tlJWm9rtloEg= github.com/microcosm-cc/bluemonday v1.0.25/go.mod h1:ZIOjCQp1OrzBBPIJmfX4qDYFuhU02nx4bn030ixfHLE= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe h1:iruDEfMl2E6fbMZ9s0scYfZQ84/6SPL6zC8ACM2oIL0= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= @@ -124,7 +132,17 @@ github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= +github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.mongodb.org/mongo-driver v1.12.1 h1:nLkghSU8fQNaK7oUmDhQFsnrtcoNy7Z6LVFKsEecqgE= +go.mongodb.org/mongo-driver v1.12.1/go.mod h1:/rGBTebI3XYboVmgz+Wv3Bcbl3aD0QF9zl6kDDw18rQ= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= @@ -132,6 +150,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -185,6 +204,7 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= diff --git a/src/actionlog/actionlog.go b/src/actionlog/actionlog.go index fdaeab9..ba7b6e5 100644 --- a/src/actionlog/actionlog.go +++ b/src/actionlog/actionlog.go @@ -1,9 +1,13 @@ package actionlog import ( + "context" + "github.com/bb-consent/api/src/database" - "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) // Log type const @@ -38,7 +42,7 @@ func GetTypeStr(logType int) string { // ActionLog All access logs type ActionLog struct { - ID bson.ObjectId `bson:"_id,omitempty"` + ID primitive.ObjectID `bson:"_id,omitempty"` Type int TypeStr string OrgID string @@ -47,20 +51,13 @@ type ActionLog struct { Action string //Free string storing the real log } -func session() *mgo.Session { - return database.DB.Session.Copy() -} - -func collection(s *mgo.Session) *mgo.Collection { - return s.DB(database.DB.Name).C("actionLogs") +func collection() *mongo.Collection { + return database.DB.Client.Database(database.DB.Name).Collection("actionLogs") } // Add Adds access log func Add(log ActionLog) error { - s := session() - defer s.Close() - - err := collection(s).Insert(log) + _, err := collection().InsertOne(context.TODO(), log) if err != nil { return err } @@ -69,13 +66,29 @@ func Add(log ActionLog) error { // GetAccessLogByOrgID gets all notifications of a given user func GetAccessLogByOrgID(orgID string, startID string, limit int) (results []ActionLog, lastID string, err error) { - s := session() - defer s.Close() - if startID == "" { - err = collection(s).Find(bson.M{"orgid": orgID}).Sort("-_id").Limit(limit).All(&results) - } else { - err = collection(s).Find(bson.M{"orgid": orgID, "_id": bson.M{"$lt": bson.ObjectIdHex(startID)}}).Sort("-_id").Limit(limit).All(&results) + findOptions := options.Find() + findOptions.SetSort(bson.D{{Key: "_id", Value: -1}}) + findOptions.SetLimit(int64(limit)) + + filter := bson.M{"orgid": orgID} + if startID != "" { + startId, err := primitive.ObjectIDFromHex(startID) + if err != nil { + return nil, "", err + } + + filter["_id"] = bson.M{"$lt": startId} + } + + cursor, err := collection().Find(context.TODO(), filter, findOptions) + if err != nil { + return nil, "", err + } + defer cursor.Close(context.TODO()) + + if err := cursor.All(context.TODO(), &results); err != nil { + return nil, "", err } lastID = "" diff --git a/src/consent/consents.go b/src/consent/consents.go index e9e8870..aef6d2f 100644 --- a/src/consent/consents.go +++ b/src/consent/consents.go @@ -1,11 +1,13 @@ package consent import ( + "context" "time" "github.com/bb-consent/api/src/database" - "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" ) type consentStatus struct { @@ -30,67 +32,61 @@ type Purpose struct { // Consents data type type Consents struct { - ID bson.ObjectId `bson:"_id,omitempty"` + ID primitive.ObjectID `bson:"_id,omitempty"` OrgID string UserID string Purposes []Purpose } -func session() *mgo.Session { - return database.DB.Session.Copy() -} - -func collection(s *mgo.Session) *mgo.Collection { - return s.DB(database.DB.Name).C("consents") +func collection() *mongo.Collection { + return database.DB.Client.Database(database.DB.Name).Collection("consents") } // Add Adds an consent to the collection func Add(consent Consents) (Consents, error) { - s := session() - defer s.Close() - consent.ID = bson.NewObjectId() - return consent, collection(s).Insert(&consent) + consent.ID = primitive.NewObjectID() + _, err := collection().InsertOne(context.TODO(), &consent) + return consent, err } // DeleteByUserOrg Deletes the consent by userID, orgID func DeleteByUserOrg(userID string, orgID string) error { - s := session() - defer s.Close() - return collection(s).Remove(bson.M{"userid": userID, "orgid": orgID}) + _, err := collection().DeleteMany(context.TODO(), bson.M{"userid": userID, "orgid": orgID}) + return err } // GetByUserOrg Get all consents of a user in organization func GetByUserOrg(userID string, orgID string) (Consents, error) { - s := session() - defer s.Close() var consents Consents - err := collection(s).Find(bson.M{"userid": userID, "orgid": orgID}).One(&consents) + err := collection().FindOne(context.TODO(), bson.M{"userid": userID, "orgid": orgID}).Decode(&consents) return consents, err } // Get Get consent by consentID func Get(consentID string) (Consents, error) { - s := session() - defer s.Close() - var result Consents - err := collection(s).FindId(bson.ObjectIdHex(consentID)).One(&result) + + consentId, err := primitive.ObjectIDFromHex(consentID) + if err != nil { + return result, err + } + err = collection().FindOne(context.TODO(), bson.M{"_id": consentId}).Decode(&result) return result, err } // GetConsentedUsers Get list of users who are consented to an attribute func GetConsentedUsers(orgID string, purposeID string, attributeID string, startID string, limit int) (userIDs []string, lastID string, err error) { - s := session() - defer s.Close() - c := collection(s) + c := collection() limit = 10000 var results []Consents + var cur *mongo.Cursor + if startID == "" { pipeline := []bson.M{ {"$match": bson.M{"orgid": orgID}}, @@ -101,9 +97,9 @@ func GetConsentedUsers(orgID string, purposeID string, attributeID string, start "purposes.consents.templateid": attributeID, "purposes.consents.status.consented": bson.M{"$regex": "^A"}}, }, - {"$limit": limit}, + {"$limit": int64(limit)}, } - err = c.Pipe(pipeline).All(&results) + cur, err = c.Aggregate(context.TODO(), pipeline) } else { pipeline := []bson.M{ {"$match": bson.M{"orgid": orgID}}, @@ -114,15 +110,21 @@ func GetConsentedUsers(orgID string, purposeID string, attributeID string, start "purposes.consents.templateid": attributeID, "purposes.consents.status.consented": bson.M{"$regex": "^A"}}, }, - {"$limit": limit}, + {"$limit": int64(limit)}, {"$gt": startID}, } - err = c.Pipe(pipeline).All(&results) + cur, err = c.Aggregate(context.TODO(), pipeline) } if err != nil { return } + defer cur.Close(context.TODO()) + + if err = cur.All(context.TODO(), &results); err != nil { + return + } + for _, item := range results { userIDs = append(userIDs, item.UserID) } @@ -136,12 +138,12 @@ func GetConsentedUsers(orgID string, purposeID string, attributeID string, start // GetPurposeConsentedAllUsers Get all users with at-least one attribute consented in purpose. func GetPurposeConsentedAllUsers(orgID string, purposeID string, startID string, limit int) (userIDs []string, lastID string, err error) { - s := session() - defer s.Close() - c := collection(s) + c := collection() limit = 10000 var results []Consents + var cur *mongo.Cursor + if startID == "" { pipeline := []bson.M{ {"$match": bson.M{"orgid": orgID}}, @@ -151,9 +153,9 @@ func GetPurposeConsentedAllUsers(orgID string, purposeID string, startID string, "purposes.id": purposeID, "purposes.consents.status.consented": bson.M{"$regex": "^A"}}, }, - {"$limit": limit}, + {"$limit": int64(limit)}, } - err = c.Pipe(pipeline).All(&results) + cur, err = c.Aggregate(context.TODO(), pipeline) } else { pipeline := []bson.M{ {"$match": bson.M{"orgid": orgID}}, @@ -163,15 +165,21 @@ func GetPurposeConsentedAllUsers(orgID string, purposeID string, startID string, "purposes.id": purposeID, "purposes.consents.status.consented": bson.M{"$regex": "^A"}}, }, - {"$limit": limit}, + {"$limit": int64(limit)}, {"$gt": startID}, } - err = c.Pipe(pipeline).All(&results) + cur, err = c.Aggregate(context.TODO(), pipeline) } if err != nil { return } + defer cur.Close(context.TODO()) + + if err = cur.All(context.TODO(), &results); err != nil { + return + } + keys := make(map[string]bool) for _, item := range results { if _, value := keys[item.UserID]; !value { @@ -189,9 +197,9 @@ func GetPurposeConsentedAllUsers(orgID string, purposeID string, startID string, // UpdatePurposes Update consents purposes func UpdatePurposes(consents Consents) (Consents, error) { - s := session() - defer s.Close() - c := collection(s) + c := collection() - return consents, c.Update(bson.M{"_id": consents.ID}, bson.M{"$set": bson.M{"purposes": consents.Purposes}}) + _, err := c.UpdateOne(context.TODO(), bson.M{"_id": consents.ID}, bson.M{"$set": bson.M{"purposes": consents.Purposes}}) + + return consents, err } diff --git a/src/consenthistory/consenthistory.go b/src/consenthistory/consenthistory.go index 5f66ea8..3b66805 100644 --- a/src/consenthistory/consenthistory.go +++ b/src/consenthistory/consenthistory.go @@ -1,16 +1,19 @@ package consenthistory import ( + "context" "time" "github.com/bb-consent/api/src/database" - "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) // ConsentHistory HOlds the consent logs type ConsentHistory struct { - ID bson.ObjectId `bson:"_id,omitempty"` + ID primitive.ObjectID `bson:"_id,omitempty"` UserID string OrgID string PurposeID string @@ -18,35 +21,50 @@ type ConsentHistory struct { Log string } -func session() *mgo.Session { - return database.DB.Session.Copy() -} - -func collection(s *mgo.Session) *mgo.Collection { - return s.DB(database.DB.Name).C("consentHistory") +func collection() *mongo.Collection { + return database.DB.Client.Database(database.DB.Name).Collection("consentHistory") } // Add Adds a consent history to the collection func Add(ch ConsentHistory) (ConsentHistory, error) { - s := session() - defer s.Close() - ch.ID = bson.NewObjectId() + ch.ID = primitive.NewObjectID() - return ch, collection(s).Insert(&ch) + _, err := collection().InsertOne(context.TODO(), &ch) + + return ch, err } // GetByUserID Gets all history of a given userID func GetByUserID(userID string, startID string, limit int) ([]ConsentHistory, string, error) { - s := session() - defer s.Close() + filter := bson.M{ + "userid": userID, + } + + findOptions := options.Find() + findOptions.SetSort(bson.D{{Key: "_id", Value: -1}}) + findOptions.SetLimit(int64(limit)) + + if startID != "" { + startId, err := primitive.ObjectIDFromHex(startID) + if err != nil { + return nil, "", err + } + + filter["_id"] = bson.M{"$lt": startId} + } var results []ConsentHistory - var err error - if startID == "" { - err = collection(s).Find(bson.M{"userid": userID}).Sort("-_id").Limit(limit).All(&results) - } else { - err = collection(s).Find(bson.M{"userid": userID, "_id": bson.M{"$lt": bson.ObjectIdHex(startID)}}).Sort("-_id").Limit(limit).All(&results) + + cur, err := collection().Find(context.TODO(), filter, findOptions) + if err != nil { + return nil, "", err + } + + defer cur.Close(context.TODO()) + + if err := cur.All(context.TODO(), &results); err != nil { + return nil, "", err } var lastID = "" @@ -61,28 +79,49 @@ func GetByUserID(userID string, startID string, limit int) ([]ConsentHistory, st // GetLatestByUserOrgPurposeID Gets latest consent history of a given userID in an organization with purposeID func GetLatestByUserOrgPurposeID(userID string, orgID string, purposeID string) (ConsentHistory, error) { - s := session() - defer s.Close() - var result ConsentHistory - var err error + filter := bson.M{"userid": userID, "orgid": orgID, "purposeid": purposeID} + options := options.FindOne().SetSort(bson.D{{Key: "_id", Value: -1}}) - err = collection(s).Find(bson.M{"userid": userID, "orgid": orgID, "purposeid": purposeID}).Sort("-_id").One(&result) + var result ConsentHistory + err := collection().FindOne(context.TODO(), filter, options).Decode(&result) return result, err } // GetByUserOrgPurposeID Gets all history of a given userID in an organization with purposeID func GetByUserOrgPurposeID(userID string, orgID string, purposeID string, startID string, limit int) ([]ConsentHistory, string, error) { - s := session() - defer s.Close() + + filter := bson.M{ + "userid": userID, + "orgid": orgID, + "purposeid": purposeID, + } + + findOptions := options.Find() + findOptions.SetSort(bson.D{{Key: "_id", Value: -1}}) + findOptions.SetLimit(int64(limit)) + + if startID != "" { + startId, err := primitive.ObjectIDFromHex(startID) + if err != nil { + return nil, "", err + } + + filter["_id"] = bson.M{"$lt": startId} + } var results []ConsentHistory - var err error - if startID == "" { - err = collection(s).Find(bson.M{"userid": userID, "orgid": orgID, "purposeid": purposeID}).Sort("-_id").Limit(limit).All(&results) - } else { - err = collection(s).Find(bson.M{"userid": userID, "orgid": orgID, "purposeid": purposeID, "_id": bson.M{"$lt": bson.ObjectIdHex(startID)}}).Sort("-_id").Limit(limit).All(&results) + + cur, err := collection().Find(context.TODO(), filter, findOptions) + if err != nil { + return nil, "", err + } + + defer cur.Close(context.TODO()) + + if err := cur.All(context.TODO(), &results); err != nil { + return nil, "", err } var lastID = "" @@ -97,15 +136,35 @@ func GetByUserOrgPurposeID(userID string, orgID string, purposeID string, startI // GetByUserOrgID Gets all history of a given userID in an organization func GetByUserOrgID(userID string, orgID string, startID string, limit int) ([]ConsentHistory, string, error) { - s := session() - defer s.Close() + filter := bson.M{ + "userid": userID, + "orgid": orgID, + } + + findOptions := options.Find() + findOptions.SetSort(bson.D{{Key: "_id", Value: -1}}) + findOptions.SetLimit(int64(limit)) + + if startID != "" { + startId, err := primitive.ObjectIDFromHex(startID) + if err != nil { + return nil, "", err + } + + filter["_id"] = bson.M{"$lt": startId} + } var results []ConsentHistory - var err error - if startID == "" { - err = collection(s).Find(bson.M{"userid": userID, "orgid": orgID}).Sort("-_id").Limit(limit).All(&results) - } else { - err = collection(s).Find(bson.M{"userid": userID, "orgid": orgID, "_id": bson.M{"$lt": bson.ObjectIdHex(startID)}}).Sort("-_id").Limit(limit).All(&results) + + cur, err := collection().Find(context.TODO(), filter, findOptions) + if err != nil { + return nil, "", err + } + + defer cur.Close(context.TODO()) + + if err := cur.All(context.TODO(), &results); err != nil { + return nil, "", err } var lastID = "" @@ -120,8 +179,6 @@ func GetByUserOrgID(userID string, orgID string, startID string, limit int) ([]C // GetByDateRange Gets all history of a given userID with date range func GetByDateRange(userID string, startDate string, endDate string, startID string, limit int) ([]ConsentHistory, string, error) { - s := session() - defer s.Close() var results []ConsentHistory var err error @@ -138,13 +195,37 @@ func GetByDateRange(userID string, startDate string, endDate string, startID str if err != nil { return results, "", err } - sID := bson.NewObjectIdWithTime(sDate) - eID := bson.NewObjectIdWithTime(eDate) + + sID := primitive.NewObjectIDFromTimestamp(sDate) + eID := primitive.NewObjectIDFromTimestamp(eDate) + + findOptions := options.Find() + findOptions.SetSort(bson.D{{Key: "_id", Value: -1}}) + findOptions.SetLimit(int64(limit)) + + var cur *mongo.Cursor if startID == "" { - err = collection(s).Find(bson.M{"userid": userID, "_id": bson.M{"$gte": sID, "$lt": eID}}).Sort("-_id").Limit(limit).All(&results) + cur, err = collection().Find(context.TODO(), bson.M{"userid": userID, "_id": bson.M{"$gte": sID, "$lt": eID}}, findOptions) + if err != nil { + return nil, "", err + } } else { - err = collection(s).Find(bson.M{"userid": userID, "_id": bson.M{"$lt": bson.ObjectIdHex(startID), "$gte": sID}}).Sort("-_id").Limit(limit).All(&results) + startId, err := primitive.ObjectIDFromHex(startID) + if err != nil { + return nil, "", err + } + + cur, err = collection().Find(context.TODO(), bson.M{"userid": userID, "_id": bson.M{"$lt": startId, "$gte": sID}}, findOptions) + if err != nil { + return nil, "", err + } + } + + defer cur.Close(context.TODO()) + + if err := cur.All(context.TODO(), &results); err != nil { + return nil, "", err } var lastID = "" diff --git a/src/database/db.go b/src/database/db.go index c8263f3..f1a49f3 100644 --- a/src/database/db.go +++ b/src/database/db.go @@ -1,16 +1,19 @@ package database import ( + "context" "log" "time" "github.com/bb-consent/api/src/config" - mgo "github.com/globalsign/mgo" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) type db struct { - Session *mgo.Session - Name string + Client *mongo.Client + Name string } // DB Database session pointer @@ -18,24 +21,34 @@ var DB db // Init Connects to the DB, initializes the collection func Init(config *config.Configuration) error { - mongoDBDialInfo := &mgo.DialInfo{ - Addrs: config.DataBase.Hosts, - Timeout: 60 * time.Second, - Database: config.DataBase.Name, - Username: config.DataBase.UserName, - Password: config.DataBase.Password, + MongoDBURL := "mongodb://" + config.DataBase.UserName + ":" + config.DataBase.Password + "@" + config.DataBase.Hosts[0] + "/" + config.DataBase.Name + + clientOptions := options.Client().ApplyURI(MongoDBURL) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Create a new MongoDB client + client, err := mongo.Connect(context.Background(), clientOptions) + if err != nil { + log.Printf("Error connecting to MongoDB: %v", err) + return err } - // Create a session which maintains a pool of socket connections to our MongoDB. - session, err := mgo.DialWithInfo(mongoDBDialInfo) + // Ping the MongoDB server + err = client.Ping(ctx, nil) if err != nil { return err } - DB = db{session, config.DataBase.Name} + DB = db{ + Client: client, + Name: config.DataBase.Name, + } err = initCollection("organizations", []string{"name"}, true) if err != nil { + log.Printf("initialising collection: %v", err) return err } @@ -113,18 +126,24 @@ func Init(config *config.Configuration) error { } func initCollection(collectionName string, keys []string, unique bool) error { - c := DB.Session.DB(DB.Name).C(collectionName) - index := mgo.Index{ - Key: keys, - Unique: unique, - DropDups: true, - Background: true, - Sparse: true, + c := DB.Client.Database(DB.Name).Collection(collectionName) + + indexOptions := options.Index() + + keysDoc := bson.D{} + for _, key := range keys { + keysDoc = append(keysDoc, bson.E{Key: key, Value: 1}) + } + + indexModel := mongo.IndexModel{ + Keys: keysDoc, + Options: indexOptions.SetSparse(true).SetUnique(unique), } - err := c.EnsureIndex(index) + _, err := c.Indexes().CreateOne(context.TODO(), indexModel) if err != nil { + log.Printf("error creating index on the specified keys: %v", err) return err } diff --git a/src/datarequests/datarequests.go b/src/datarequests/datarequests.go index 8ce5777..ede4dc2 100644 --- a/src/datarequests/datarequests.go +++ b/src/datarequests/datarequests.go @@ -1,11 +1,14 @@ package datarequests import ( + "context" "time" "github.com/bb-consent/api/src/database" - "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) // Data Request type and status const @@ -48,7 +51,7 @@ var RequestTypes = []iDString{ // DataRequest Data request information type DataRequest struct { - ID bson.ObjectId `bson:"_id,omitempty"` + ID primitive.ObjectID `bson:"_id,omitempty"` UserID string OrgID string UserName string @@ -61,12 +64,8 @@ type DataRequest struct { Comments [DataRequestMaxComments]string } -func session() *mgo.Session { - return database.DB.Session.Copy() -} - -func collection(s *mgo.Session) *mgo.Collection { - return s.DB(database.DB.Name).C("userDataRequests") +func collection() *mongo.Collection { + return database.DB.Client.Database(database.DB.Name).Collection("userDataRequests") } // GetStatusTypeStr Get status type string from ID @@ -86,23 +85,21 @@ func GetRequestTypeStr(requestType int) string { // Add Adds access log func Add(req DataRequest) (DataRequest, error) { - s := session() - defer s.Close() - req.ID = bson.NewObjectId() + req.ID = primitive.NewObjectID() + + _, err := collection().InsertOne(context.TODO(), req) - return req, collection(s).Insert(req) + return req, err } // Update Update the req entry -func Update(reqID bson.ObjectId, state int, comments [DataRequestMaxComments]string) (err error) { - s := session() - defer s.Close() +func Update(reqID primitive.ObjectID, state int, comments [DataRequestMaxComments]string) (err error) { if state >= DataRequestStatusProcessedWithoutAction { - err = collection(s).Update(bson.M{"_id": reqID}, bson.M{"$set": bson.M{"comments": comments, "state": state, "closeddate": time.Now()}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": reqID}, bson.M{"$set": bson.M{"comments": comments, "state": state, "closeddate": time.Now()}}) } else { - err = collection(s).Update(bson.M{"_id": reqID}, bson.M{"$set": bson.M{"comments": comments, "state": state}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": reqID}, bson.M{"$set": bson.M{"comments": comments, "state": state}}) } if err != nil { return err @@ -112,25 +109,48 @@ func Update(reqID bson.ObjectId, state int, comments [DataRequestMaxComments]str // GetDataRequestByID Returns the data requests record by ID func GetDataRequestByID(reqID string) (DataRequest, error) { - s := session() - defer s.Close() - var dataReqest DataRequest - err := collection(s).FindId(bson.ObjectIdHex(reqID)).One(&dataReqest) + + reqId, err := primitive.ObjectIDFromHex(reqID) + if err != nil { + return dataReqest, err + } + + err = collection().FindOne(context.TODO(), bson.M{"_id": reqId}).Decode(&dataReqest) return dataReqest, err } // GetOpenDataRequestsByOrgID Get data requests against orgID func GetOpenDataRequestsByOrgID(orgID string, startID string, limit int) (results []DataRequest, lastID string, err error) { - s := session() - defer s.Close() - if startID == "" { - err = collection(s).Find(bson.M{"orgid": orgID, "state": bson.M{"$lt": DataRequestStatusProcessedWithoutAction}}).Sort("-_id").Limit(limit).All(&results) - } else { - err = collection(s).Find(bson.M{"orgid": orgID, "state": bson.M{"$lt": DataRequestStatusProcessedWithoutAction}, - "_id": bson.M{"$lt": bson.ObjectIdHex(startID)}}).Sort("-_id").Limit(limit).All(&results) + filter := bson.M{ + "orgid": orgID, + "state": bson.M{"$lt": DataRequestStatusProcessedWithoutAction}, + } + + findOptions := options.Find() + findOptions.SetSort(bson.D{{Key: "_id", Value: -1}}) + findOptions.SetLimit(int64(limit)) + + if startID != "" { + startId, err := primitive.ObjectIDFromHex(startID) + if err != nil { + return nil, "", err + } + + filter["_id"] = bson.M{"$lt": startId} + } + + cur, err := collection().Find(context.TODO(), filter, findOptions) + if err != nil { + return nil, "", err + } + + defer cur.Close(context.TODO()) + + if err := cur.All(context.TODO(), &results); err != nil { + return nil, "", err } lastID = "" @@ -145,14 +165,33 @@ func GetOpenDataRequestsByOrgID(orgID string, startID string, limit int) (result // GetClosedDataRequestsByOrgID Get data requests against orgID func GetClosedDataRequestsByOrgID(orgID string, startID string, limit int) (results []DataRequest, lastID string, err error) { - s := session() - defer s.Close() + filter := bson.M{ + "orgid": orgID, + "state": bson.M{"$gte": DataRequestStatusProcessedWithoutAction}, + } - if startID == "" { - err = collection(s).Find(bson.M{"orgid": orgID, "state": bson.M{"$gte": DataRequestStatusProcessedWithoutAction}}).Sort("-_id").Limit(limit).All(&results) - } else { - err = collection(s).Find(bson.M{"orgid": orgID, "state": bson.M{"$gte": DataRequestStatusProcessedWithoutAction}, - "_id": bson.M{"$lt": bson.ObjectIdHex(startID)}}).Sort("-_id").Limit(limit).All(&results) + findOptions := options.Find() + findOptions.SetSort(bson.D{{Key: "_id", Value: -1}}) + findOptions.SetLimit(int64(limit)) + + if startID != "" { + startId, err := primitive.ObjectIDFromHex(startID) + if err != nil { + return nil, "", err + } + + filter["_id"] = bson.M{"$lt": startId} + } + + cur, err := collection().Find(context.TODO(), filter, findOptions) + if err != nil { + return nil, "", err + } + + defer cur.Close(context.TODO()) + + if err := cur.All(context.TODO(), &results); err != nil { + return nil, "", err } lastID = "" @@ -167,14 +206,34 @@ func GetClosedDataRequestsByOrgID(orgID string, startID string, limit int) (resu // GetOpenDataRequestsByOrgUserID Get data requests against orgID func GetOpenDataRequestsByOrgUserID(orgID string, userID string, startID string, limit int) (results []DataRequest, lastID string, err error) { - s := session() - defer s.Close() + filter := bson.M{ + "orgid": orgID, + "userid": userID, + "state": bson.M{"$lt": DataRequestStatusProcessedWithoutAction}, + } - if startID == "" { - err = collection(s).Find(bson.M{"orgid": orgID, "userid": userID, "state": bson.M{"$lt": DataRequestStatusProcessedWithoutAction}}).Sort("-_id").Limit(limit).All(&results) - } else { - err = collection(s).Find(bson.M{"orgid": orgID, "userid": userID, "state": bson.M{"$lt": DataRequestStatusProcessedWithoutAction}, - "_id": bson.M{"$lt": bson.ObjectIdHex(startID)}}).Sort("-_id").Limit(limit).All(&results) + findOptions := options.Find() + findOptions.SetSort(bson.D{{Key: "_id", Value: -1}}) + findOptions.SetLimit(int64(limit)) + + if startID != "" { + startId, err := primitive.ObjectIDFromHex(startID) + if err != nil { + return nil, "", err + } + + filter["_id"] = bson.M{"$lt": startId} + } + + cur, err := collection().Find(context.TODO(), filter, findOptions) + if err != nil { + return nil, "", err + } + + defer cur.Close(context.TODO()) + + if err := cur.All(context.TODO(), &results); err != nil { + return nil, "", err } lastID = "" @@ -189,14 +248,34 @@ func GetOpenDataRequestsByOrgUserID(orgID string, userID string, startID string, // GetClosedDataRequestsByOrgUserID Get data requests against orgID func GetClosedDataRequestsByOrgUserID(orgID string, userID string, startID string, limit int) (results []DataRequest, lastID string, err error) { - s := session() - defer s.Close() + filter := bson.M{ + "orgid": orgID, + "userid": userID, + "state": bson.M{"$gte": DataRequestStatusProcessedWithoutAction}, + } - if startID == "" { - err = collection(s).Find(bson.M{"orgid": orgID, "userid": userID, "state": bson.M{"$gte": DataRequestStatusProcessedWithoutAction}}).Sort("-_id").Limit(limit).All(&results) - } else { - err = collection(s).Find(bson.M{"orgid": orgID, "userid": userID, "state": bson.M{"$gte": DataRequestStatusProcessedWithoutAction}, - "_id": bson.M{"$lt": bson.ObjectIdHex(startID)}}).Sort("-_id").Limit(limit).All(&results) + findOptions := options.Find() + findOptions.SetSort(bson.D{{Key: "_id", Value: -1}}) + findOptions.SetLimit(int64(limit)) + + if startID != "" { + startId, err := primitive.ObjectIDFromHex(startID) + if err != nil { + return nil, "", err + } + + filter["_id"] = bson.M{"$lt": startId} + } + + cur, err := collection().Find(context.TODO(), filter, findOptions) + if err != nil { + return nil, "", err + } + + defer cur.Close(context.TODO()) + + if err := cur.All(context.TODO(), &results); err != nil { + return nil, "", err } lastID = "" @@ -211,13 +290,33 @@ func GetClosedDataRequestsByOrgUserID(orgID string, userID string, startID strin // GetDataRequestsByOrgUserID Get data requests against userID func GetDataRequestsByOrgUserID(orgID string, userID string, startID string, limit int) (results []DataRequest, lastID string, err error) { - s := session() - defer s.Close() + filter := bson.M{ + "orgid": orgID, + "userid": userID, + } - if startID == "" { - err = collection(s).Find(bson.M{"orgid": orgID, "userid": userID}).Sort("-_id").Limit(limit).All(&results) - } else { - err = collection(s).Find(bson.M{"orgid": orgID, "userid": userID, "_id": bson.M{"$lt": bson.ObjectIdHex(startID)}}).Sort("-_id").Limit(limit).All(&results) + findOptions := options.Find() + findOptions.SetSort(bson.D{{Key: "_id", Value: -1}}) + findOptions.SetLimit(int64(limit)) + + if startID != "" { + startId, err := primitive.ObjectIDFromHex(startID) + if err != nil { + return nil, "", err + } + + filter["_id"] = bson.M{"$lt": startId} + } + + cur, err := collection().Find(context.TODO(), filter, findOptions) + if err != nil { + return nil, "", err + } + + defer cur.Close(context.TODO()) + + if err := cur.All(context.TODO(), &results); err != nil { + return nil, "", err } lastID = "" @@ -232,10 +331,16 @@ func GetDataRequestsByOrgUserID(orgID string, userID string, startID string, lim // GetDataRequestsByUserOrgTypeID Get data requests against orgID func GetDataRequestsByUserOrgTypeID(orgID string, userID string, drType int) (results []DataRequest, err error) { - s := session() - defer s.Close() - err = collection(s).Find(bson.M{"orgid": orgID, "userid": userID, "type": drType}).All(&results) + cur, err := collection().Find(context.TODO(), bson.M{"orgid": orgID, "userid": userID, "type": drType}) + if err != nil { + return nil, err + } + defer cur.Close(context.TODO()) + + if err := cur.All(context.TODO(), &results); err != nil { + return nil, err + } - return results, err + return results, nil } diff --git a/src/handler/actionlog_handler.go b/src/handler/actionlog_handler.go index a042ed6..0a090f6 100644 --- a/src/handler/actionlog_handler.go +++ b/src/handler/actionlog_handler.go @@ -44,7 +44,7 @@ func GetOrgLogs(w http.ResponseWriter, r *http.Request) { var ls orgLogsResp for _, l := range logs { ls.Logs = append(ls.Logs, orgLog{ID: l.ID.Hex(), Type: l.Type, TypeStr: l.TypeStr, - UserID: l.UserID, UserName: l.UserName, TimeStamp: l.ID.Time().String(), Log: l.Action}) + UserID: l.UserID, UserName: l.UserName, TimeStamp: l.ID.Timestamp().String(), Log: l.Action}) } ls.Links = common.CreatePaginationLinks(r, startID, lastID, limit) diff --git a/src/handler/consent_handler.go b/src/handler/consent_handler.go index 18b6c2f..387ae20 100644 --- a/src/handler/consent_handler.go +++ b/src/handler/consent_handler.go @@ -178,7 +178,7 @@ func GetConsents(w http.ResponseWriter, r *http.Request) { continue } - RespData.ConsentsAndPurposes[i].DataRetention.Expiry = latestConsentHistory.ID.Time().Add(time.Second * time.Duration(o.DataRetention.RetentionPeriod)).UTC().String() + RespData.ConsentsAndPurposes[i].DataRetention.Expiry = latestConsentHistory.ID.Timestamp().Add(time.Second * time.Duration(o.DataRetention.RetentionPeriod)).UTC().String() log.Printf("Expiry for purpose:%v is %v", RespData.ConsentsAndPurposes[i].Purpose.ID, RespData.ConsentsAndPurposes[i].DataRetention.Expiry) } @@ -295,7 +295,7 @@ func GetConsentPurposeByID(w http.ResponseWriter, r *http.Request) { return } - cpResp.DataRetention.Expiry = latestConsentHistory.ID.Time().Add(time.Second * time.Duration(o.DataRetention.RetentionPeriod)).UTC().String() + cpResp.DataRetention.Expiry = latestConsentHistory.ID.Timestamp().Add(time.Second * time.Duration(o.DataRetention.RetentionPeriod)).UTC().String() } } @@ -638,7 +638,7 @@ func UpdatePurposeAllConsentsv2(w http.ResponseWriter, r *http.Request) { continue } - tempConsentsAndPurposeWithDataRetention.DataRetention.Expiry = latestConsentHistory.ID.Time().Add(time.Second * time.Duration(o.DataRetention.RetentionPeriod)).UTC().String() + tempConsentsAndPurposeWithDataRetention.DataRetention.Expiry = latestConsentHistory.ID.Timestamp().Add(time.Second * time.Duration(o.DataRetention.RetentionPeriod)).UTC().String() } } diff --git a/src/handler/consenthistory_handler.go b/src/handler/consenthistory_handler.go index 86bdd77..de02b85 100644 --- a/src/handler/consenthistory_handler.go +++ b/src/handler/consenthistory_handler.go @@ -126,7 +126,7 @@ func GetUserConsentHistory(w http.ResponseWriter, r *http.Request) { var chsResp consentHistoryResp for _, ch := range chs { - chsResp.ConsentHistory = append(chsResp.ConsentHistory, consentHistoryShort{ID: ch.ID.Hex(), OrgID: ch.OrgID, PurposeID: ch.PurposeID, Log: ch.Log, TimeStamp: ch.ID.Time().Format(time.RFC3339)}) + chsResp.ConsentHistory = append(chsResp.ConsentHistory, consentHistoryShort{ID: ch.ID.Hex(), OrgID: ch.OrgID, PurposeID: ch.PurposeID, Log: ch.Log, TimeStamp: ch.ID.Timestamp().Format(time.RFC3339)}) } //chsResp.Links = common.CreatePaginationLinks(r, startID, lastID, limit) diff --git a/src/handler/datarequest_handler.go b/src/handler/datarequest_handler.go index 8d36164..b14ade2 100644 --- a/src/handler/datarequest_handler.go +++ b/src/handler/datarequest_handler.go @@ -34,7 +34,7 @@ func GetDeleteMyData(w http.ResponseWriter, r *http.Request) { func transformDataReqToResp(dReq dr.DataRequest) dataReqResp { return dataReqResp{ID: dReq.ID, UserID: dReq.UserID, UserName: dReq.UserName, OrgID: dReq.OrgID, Type: dReq.Type, State: dReq.State, StateStr: dr.GetStatusTypeStr(dReq.State), Comment: dReq.Comments[dReq.State], TypeStr: dr.GetRequestTypeStr(dReq.Type), - ClosedDate: dReq.ClosedDate.String(), RequestedDate: dReq.ID.Time().String()} + ClosedDate: dReq.ClosedDate.String(), RequestedDate: dReq.ID.Timestamp().String()} } // GetMyOrgDataRequestStatus Get data request status @@ -139,7 +139,7 @@ func getOngoingDataRequest(userID string, orgID string, drType int) (resp myData resp.ID = d.ID.Hex() resp.State = d.State resp.StateStr = d.StateStr - resp.RequestedDate = d.ID.Time().String() + resp.RequestedDate = d.ID.Timestamp().String() } } diff --git a/src/handler/iam_handler.go b/src/handler/iam_handler.go index 298b35c..5463371 100644 --- a/src/handler/iam_handler.go +++ b/src/handler/iam_handler.go @@ -24,8 +24,8 @@ import ( "github.com/bb-consent/api/src/otp" "github.com/bb-consent/api/src/token" "github.com/bb-consent/api/src/user" - "github.com/globalsign/mgo/bson" "github.com/gorilla/mux" + "go.mongodb.org/mongo-driver/bson/primitive" ) type registerReq struct { @@ -145,6 +145,8 @@ func RegisterUser(w http.ResponseWriter, r *http.Request) { u.IamID = userIamID u.Email = regReq.Username u.Phone = regReq.Phone + u.Orgs = []user.Org{} + u.Roles = []user.Role{} u, err = user.Add(u) if err != nil { @@ -756,7 +758,7 @@ func ValidatePhoneNumber(w http.ResponseWriter, r *http.Request) { } if o != (otp.Otp{}) { - if bson.NewObjectId().Time().Sub(o.ID.Time()) > 2*time.Minute { + if primitive.NewObjectID().Timestamp().Sub(o.ID.Timestamp()) > 2*time.Minute { err = otp.Delete(o.ID.Hex()) if err != nil { m := fmt.Sprintf("Failed to clear expired otp") diff --git a/src/handler/organization_handler.go b/src/handler/organization_handler.go index f9d40fb..eff2d6f 100644 --- a/src/handler/organization_handler.go +++ b/src/handler/organization_handler.go @@ -2,6 +2,7 @@ package handler import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -25,8 +26,8 @@ import ( "github.com/bb-consent/api/src/token" "github.com/bb-consent/api/src/user" "github.com/bb-consent/api/src/webhooks" - "github.com/globalsign/mgo/bson" "github.com/gorilla/mux" + "go.mongodb.org/mongo-driver/bson/primitive" ) type organization struct { @@ -383,25 +384,36 @@ func handleEulaUpdateNotification(o org.Organization) { // Get all users subscribed to this organization. orgID := o.ID.Hex() - iter := user.GetOrgSubscribeIter(orgID) + iter, err := user.GetOrgSubscribeIter(orgID) + if err != nil { + log.Printf("Failed to find users: %v", err) + return + } - var u user.User + for iter.Next(context.TODO()) { + var u user.User + err := iter.Decode(&u) + if err != nil { + log.Printf("Failed to decode user: %v", err) + continue + } - for iter.Next(&u) { if u.Client.Token == "" { continue } - err := notifications.SendEulaUpdateNotification(u, o) + + err = notifications.SendEulaUpdateNotification(u, o) if err != nil { notificationErrCount++ continue } + notificationSent++ } log.Printf("notification sending for EULA update orgID: %v with err: %v sent: %v", orgID, notificationErrCount, notificationSent) - err := iter.Close() + err = iter.Close(context.TODO()) if err != nil { log.Printf("Failed to close the iterator: %v", iter) } @@ -604,7 +616,7 @@ func AddConsentPurposes(w http.ResponseWriter, r *http.Request) { tempLawfulUsage := getLawfulUsageByLawfulBasis(p.LawfulBasisOfProcessing) tempPurpose := org.Purpose{ - ID: bson.NewObjectId().Hex(), + ID: primitive.NewObjectID().Hex(), Name: p.Name, Description: p.Description, LawfulUsage: tempLawfulUsage, @@ -923,7 +935,7 @@ func AddConsentTemplates(w http.ResponseWriter, r *http.Request) { // Appending the new template to existing org templates o.Templates = append(o.Templates, org.Template{ - ID: bson.NewObjectId().Hex(), + ID: primitive.NewObjectID().Hex(), Consent: t.Consent, PurposeIDs: t.PurposeIDs, }) @@ -1291,7 +1303,7 @@ func UpdateGlobalPolicyConfiguration(w http.ResponseWriter, r *http.Request) { } // Check if type id is valid bson objectid hex - if !bson.IsObjectIdHex(policyReq.TypeID) { + if !primitive.IsValidObjectID(policyReq.TypeID) { m := fmt.Sprintf("Invalid organization type ID: %v", policyReq.TypeID) common.HandleError(w, http.StatusBadRequest, m, err) return @@ -1726,7 +1738,7 @@ func GetOrganizationSubscriptionStatus(w http.ResponseWriter, r *http.Request) { } type orgUserCount struct { - SubscribeUserCount int + SubscribeUserCount int64 } // GetOrganizationUsersCount Gets count of organization users @@ -1752,15 +1764,25 @@ func GetOrganizationUsersCount(w http.ResponseWriter, r *http.Request) { func handleDataBreachNotification(dataBreachID string, orgID string, orgName string) { // Get all users subscribed to this organization. - iter := user.GetOrgSubscribeIter(orgID) + iter, err := user.GetOrgSubscribeIter(orgID) + if err != nil { + log.Printf("Failed to find users: %v", err) + return + } - var u user.User + for iter.Next(context.TODO()) { + var u user.User + err := iter.Decode(&u) + if err != nil { + log.Printf("Failed to decode user: %v", err) + continue + } - for iter.Next(&u) { if u.Client.Token == "" { continue } - err := notifications.SendDataBreachNotification(dataBreachID, u, orgID, orgName) + + err = notifications.SendDataBreachNotification(dataBreachID, u, orgID, orgName) if err != nil { notificationErrCount++ continue @@ -1770,7 +1792,7 @@ func handleDataBreachNotification(dataBreachID string, orgID string, orgName str log.Printf("notification sending for DataBreach orgID: %v with err: %v sent: %v", orgID, notificationErrCount, notificationSent) - err := iter.Close() + err = iter.Close(context.TODO()) if err != nil { log.Printf("Failed to close the iterator: %v", iter) } @@ -1779,15 +1801,25 @@ func handleDataBreachNotification(dataBreachID string, orgID string, orgName str // TODO: Refactor and use common iterator and pass the function func handleEventNotification(eventID string, orgID string, orgName string) { // Get all users subscribed to this organization. - iter := user.GetOrgSubscribeIter(orgID) + iter, err := user.GetOrgSubscribeIter(orgID) + if err != nil { + log.Printf("Failed to find users: %v", err) + return + } - var u user.User + for iter.Next(context.TODO()) { + var u user.User + err := iter.Decode(&u) + if err != nil { + log.Printf("Failed to decode user: %v", err) + continue + } - for iter.Next(&u) { if u.Client.Token == "" { continue } - err := notifications.SendEventNotification(eventID, u, orgID, orgName) + + err = notifications.SendEventNotification(eventID, u, orgID, orgName) if err != nil { notificationErrCount++ continue @@ -1797,7 +1829,7 @@ func handleEventNotification(eventID string, orgID string, orgName string) { log.Printf("notification sending for event orgID: %v with err: %v sent: %v", orgID, notificationErrCount, notificationSent) - err := iter.Close() + err = iter.Close(context.TODO()) if err != nil { log.Printf("Failed to close the iterator: %v", iter) } @@ -1836,7 +1868,7 @@ func NotifyDataBreach(w http.ResponseWriter, r *http.Request) { } dataBreachEntry := misc.DataBreach{} - dataBreachEntry.ID = bson.NewObjectId() + dataBreachEntry.ID = primitive.NewObjectID() dataBreachEntry.HeadLine = dBNotificationReq.HeadLine dataBreachEntry.UsersCount = dBNotificationReq.UsersCount dataBreachEntry.DpoEmail = dBNotificationReq.DpoEmail @@ -1887,7 +1919,7 @@ func NotifyEvents(w http.ResponseWriter, r *http.Request) { } eventEntry := misc.Event{} - eventEntry.ID = bson.NewObjectId() + eventEntry.ID = primitive.NewObjectID() eventEntry.OrgID = orgID eventEntry.Details = eventNotificationReq.Details @@ -1906,7 +1938,7 @@ func NotifyEvents(w http.ResponseWriter, r *http.Request) { } type dataReqResp struct { - ID bson.ObjectId `bson:"_id,omitempty"` + ID primitive.ObjectID `bson:"_id,omitempty"` UserID string UserName string OrgID string diff --git a/src/handler/webhooks_handler.go b/src/handler/webhooks_handler.go index 692509b..f278bad 100644 --- a/src/handler/webhooks_handler.go +++ b/src/handler/webhooks_handler.go @@ -17,8 +17,8 @@ import ( "github.com/bb-consent/api/src/org" "github.com/bb-consent/api/src/user" wh "github.com/bb-consent/api/src/webhooks" - "github.com/globalsign/mgo/bson" "github.com/gorilla/mux" + "go.mongodb.org/mongo-driver/bson/primitive" ) // WebhookEventTypesResp Define response structure for webhook event types @@ -206,11 +206,11 @@ func CreateWebhook(w http.ResponseWriter, r *http.Request) { // WebhookWithLastDeliveryStatus Defines webhook structure along with last delivery status type WebhookWithLastDeliveryStatus struct { - ID bson.ObjectId `bson:"_id,omitempty"` // Webhook ID - PayloadURL string // Webhook payload URL - Disabled bool // Disabled or not - TimeStamp string // UTC timestamp - IsLastDeliverySuccess bool // Indicates whether last payload delivery to webhook was success or not + ID primitive.ObjectID `bson:"_id,omitempty"` // Webhook ID + PayloadURL string // Webhook payload URL + Disabled bool // Disabled or not + TimeStamp string // UTC timestamp + IsLastDeliverySuccess bool // Indicates whether last payload delivery to webhook was success or not } // GetAllWebhooks Gets all webhooks for an organisation @@ -545,13 +545,13 @@ func PingWebhook(w http.ResponseWriter, r *http.Request) { // recentWebhookDelivery Defines the structure for recent webhook delivery type recentWebhookDelivery struct { - ID bson.ObjectId `bson:"_id,omitempty"` // Webhook delivery ID - WebhookID string // Webhook ID - ResponseStatusCode int // HTTP response status code - ResponseStatusStr string // HTTP response status string - TimeStamp string // UTC timestamp when webhook execution started - Status string // Status of webhook delivery for e.g. failed or completed - StatusDescription string // Describe the status for e.g. Reason for failure + ID primitive.ObjectID `bson:"_id,omitempty"` // Webhook delivery ID + WebhookID string // Webhook ID + ResponseStatusCode int // HTTP response status code + ResponseStatusStr string // HTTP response status string + TimeStamp string // UTC timestamp when webhook execution started + Status string // Status of webhook delivery for e.g. failed or completed + StatusDescription string // Describe the status for e.g. Reason for failure } type recentWebhookDeliveryResp struct { @@ -624,7 +624,7 @@ func GetRecentWebhookDeliveries(w http.ResponseWriter, r *http.Request) { } type webhookDeliveryResp struct { - ID bson.ObjectId `bson:"_id,omitempty"` + ID primitive.ObjectID `bson:"_id,omitempty"` RequestHeaders map[string][]string // HTTP headers posted to webhook endpoint RequestPayload interface{} // JSON payload posted to webhook endpoint ResponseHeaders map[string][]string // HTTP response headers received from webhook endpoint diff --git a/src/image/images.go b/src/image/images.go index d084494..5cdd550 100644 --- a/src/image/images.go +++ b/src/image/images.go @@ -1,31 +1,29 @@ package image import ( + "context" + "github.com/bb-consent/api/src/database" - mgo "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" ) // Image data type type Image struct { - ID bson.ObjectId `bson:"_id,omitempty"` + ID primitive.ObjectID `bson:"_id,omitempty"` Data []byte } -func session() *mgo.Session { - return database.DB.Session.Copy() -} -func collection(s *mgo.Session) *mgo.Collection { - return s.DB(database.DB.Name).C("images") +func collection() *mongo.Collection { + return database.DB.Client.Database(database.DB.Name).Collection("images") } // Add Adds an image to image store func Add(image []byte) (imageID string, err error) { - s := session() - defer s.Close() - i := Image{bson.NewObjectId(), image} - err = collection(s).Insert(&i) + i := Image{primitive.NewObjectID(), image} + _, err = collection().InsertOne(context.TODO(), &i) if err != nil { return "", err } @@ -35,10 +33,13 @@ func Add(image []byte) (imageID string, err error) { // Get Fetches the image by ID func Get(imageID string) (Image, error) { - s := session() - defer s.Close() - var image Image - err := collection(s).FindId(bson.ObjectIdHex(imageID)).One(&image) + + imageId, err := primitive.ObjectIDFromHex(imageID) + if err != nil { + return image, err + } + + err = collection().FindOne(context.TODO(), bson.M{"_id": imageId}).Decode(&image) return image, err } diff --git a/src/misc/collections.go b/src/misc/collections.go index 75c329c..c569bff 100644 --- a/src/misc/collections.go +++ b/src/misc/collections.go @@ -1,9 +1,11 @@ package misc import ( + "context" + "github.com/bb-consent/api/src/database" - "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" ) const ( @@ -16,7 +18,7 @@ const ( // DataBreach stores the Data breach informations type DataBreach struct { - ID bson.ObjectId `bson:"_id,omitempty"` + ID primitive.ObjectID `bson:"_id,omitempty"` Type int OrgID string HeadLine string @@ -28,27 +30,21 @@ type DataBreach struct { // Event stores event related information. type Event struct { - ID bson.ObjectId `bson:"_id,omitempty"` + ID primitive.ObjectID `bson:"_id,omitempty"` Type int OrgID string Details string } -func session() *mgo.Session { - return database.DB.Session.Copy() -} - -func collection(s *mgo.Session) *mgo.Collection { - return s.DB(database.DB.Name).C("misc") +func collection() *mongo.Collection { + return database.DB.Client.Database(database.DB.Name).Collection("misc") } // AddDataBreachNotifications Update the data breach info to organization func AddDataBreachNotifications(dataBreach DataBreach) error { - s := session() - defer s.Close() dataBreach.Type = DocTypeOrgDataBreach - err := collection(s).Insert(dataBreach) + _, err := collection().InsertOne(context.TODO(), dataBreach) if err != nil { return err } @@ -57,11 +53,9 @@ func AddDataBreachNotifications(dataBreach DataBreach) error { // AddEventNotifications Update the data breach info to organization func AddEventNotifications(event Event) error { - s := session() - defer s.Close() event.Type = DocTypeOrgEvent - err := collection(s).Insert(event) + _, err := collection().InsertOne(context.TODO(), event) if err != nil { return err } diff --git a/src/notifications/notification_db.go b/src/notifications/notification_db.go index 0ce4c45..1e7f965 100644 --- a/src/notifications/notification_db.go +++ b/src/notifications/notification_db.go @@ -1,11 +1,13 @@ package notifications import ( + "context" "time" "github.com/bb-consent/api/src/database" - mgo "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" ) // Notification Types @@ -19,7 +21,7 @@ const ( // Notification data type type Notification struct { - ID bson.ObjectId `bson:"_id,omitempty"` + ID primitive.ObjectID `bson:"_id,omitempty"` Type int Title string UserID string @@ -33,31 +35,25 @@ type Notification struct { AttributeIDs []string } -func session() *mgo.Session { - return database.DB.Session.Copy() -} - -func collection(s *mgo.Session) *mgo.Collection { - return s.DB(database.DB.Name).C("notifications") +func collection() *mongo.Collection { + return database.DB.Client.Database(database.DB.Name).Collection("notifications") } // Add Adds a notification to the collection func Add(notification Notification) (Notification, error) { - s := session() - defer s.Close() - notification.ID = bson.NewObjectId() + notification.ID = primitive.NewObjectID() notification.Timestamp = time.Now().Format(time.RFC3339) - return notification, collection(s).Insert(¬ification) + _, err := collection().InsertOne(context.TODO(), ¬ification) + + return notification, err } // GetUnReadCountByUserID gets count of un-read notifications of a given user -func GetUnReadCountByUserID(userID string) (count int, err error) { - s := session() - defer s.Close() +func GetUnReadCountByUserID(userID string) (int, error) { - count, err = collection(s).Find(bson.M{"userid": userID, "readstatus": false}).Count() + count, err := collection().CountDocuments(context.TODO(), bson.M{"userid": userID, "readstatus": false}) - return count, err + return int(count), err } diff --git a/src/org/organizations.go b/src/org/organizations.go index 9ee6dec..2a88f62 100644 --- a/src/org/organizations.go +++ b/src/org/organizations.go @@ -1,13 +1,16 @@ package org import ( + "context" "errors" "log" "github.com/bb-consent/api/src/database" "github.com/bb-consent/api/src/orgtype" - "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) // Admin Users @@ -18,7 +21,7 @@ type Admin struct { // Organization organization data type type Organization struct { - ID bson.ObjectId `bson:"_id,omitempty"` + ID primitive.ObjectID `bson:"_id,omitempty"` Name string CoverImageID string CoverImageURL string @@ -228,49 +231,55 @@ var LawfulBasisOfProcessingMappings = []LawfulBasisOfProcessingMapping{ }, } -func session() *mgo.Session { - return database.DB.Session.Copy() -} - -func collection(s *mgo.Session) *mgo.Collection { - return s.DB(database.DB.Name).C("organizations") +func collection() *mongo.Collection { + return database.DB.Client.Database(database.DB.Name).Collection("organizations") } // Add Adds an organization func Add(org Organization) (Organization, error) { - s := session() - defer s.Close() - org.ID = bson.NewObjectId() - return org, collection(s).Insert(&org) + org.ID = primitive.NewObjectID() + _, err := collection().InsertOne(context.TODO(), &org) + if err != nil { + return org, err + } + return org, nil } // Get Gets a single organization by given id func Get(organizationID string) (Organization, error) { - s := session() - defer s.Close() + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return Organization{}, err + } var result Organization - err := collection(s).FindId(bson.ObjectIdHex(organizationID)).One(&result) + err = collection().FindOne(context.TODO(), bson.M{"_id": orgID}).Decode(&result) return result, err } // Update Updates the organization func Update(org Organization) (Organization, error) { - s := session() - defer s.Close() - err := collection(s).UpdateId(org.ID, org) + filter := bson.M{"_id": org.ID} + update := bson.M{"$set": org} + + _, err := collection().UpdateOne(context.TODO(), filter, update) + if err != nil { + return org, err + } return org, err } // UpdateCoverImage Update the organization image func UpdateCoverImage(organizationID string, imageID string, imageURL string) (Organization, error) { - s := session() - defer s.Close() + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return Organization{}, err + } - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(organizationID)}, bson.M{"$set": bson.M{"coverimageid": imageID, "coverimageurl": imageURL}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": orgID}, bson.M{"$set": bson.M{"coverimageid": imageID, "coverimageurl": imageURL}}) if err != nil { return Organization{}, err } @@ -280,10 +289,12 @@ func UpdateCoverImage(organizationID string, imageID string, imageURL string) (O // UpdateLogoImage Update the organization image func UpdateLogoImage(organizationID string, imageID string, imageURL string) (Organization, error) { - s := session() - defer s.Close() + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return Organization{}, err + } - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(organizationID)}, bson.M{"$set": bson.M{"logoimageid": imageID, "logoimageurl": imageURL}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": orgID}, bson.M{"$set": bson.M{"logoimageid": imageID, "logoimageurl": imageURL}}) if err != nil { return Organization{}, err } @@ -293,10 +304,12 @@ func UpdateLogoImage(organizationID string, imageID string, imageURL string) (Or // AddAdminUsers Add admin users to organization func AddAdminUsers(organizationID string, admin Admin) (Organization, error) { - s := session() - defer s.Close() + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return Organization{}, err + } - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(organizationID)}, bson.M{"$push": bson.M{"admins": admin}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": orgID}, bson.M{"$push": bson.M{"admins": admin}}) if err != nil { return Organization{}, err } @@ -306,21 +319,30 @@ func AddAdminUsers(organizationID string, admin Admin) (Organization, error) { // GetAdminUsers Get admin users of organization func GetAdminUsers(organizationID string) (Organization, error) { - s := session() - defer s.Close() + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return Organization{}, err + } + + filter := bson.M{"_id": orgID} + projection := bson.M{"admins": 1} + + findOptions := options.FindOne().SetProjection(projection) var result Organization - err := collection(s).FindId(bson.ObjectIdHex(organizationID)).Select(bson.M{"admins": 1}).One(&result) + err = collection().FindOne(context.TODO(), filter, findOptions).Decode(&result) return result, err } // DeleteAdminUsers Delete admin users from organization func DeleteAdminUsers(organizationID string, admin Admin) (Organization, error) { - s := session() - defer s.Close() + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return Organization{}, err + } - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(organizationID)}, bson.M{"$pull": bson.M{"admins": admin}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": orgID}, bson.M{"$pull": bson.M{"admins": admin}}) if err != nil { return Organization{}, err } @@ -330,34 +352,27 @@ func DeleteAdminUsers(organizationID string, admin Admin) (Organization, error) // UpdateOrganizationsOrgType Updates the embedded organization type snippet of all Organization func UpdateOrganizationsOrgType(oType orgtype.OrgType) error { - s := session() - defer s.Close() - c := collection(s) - - var org Organization - iter := c.Find(bson.M{"type._id": oType.ID}).Iter() - for iter.Next(&org) { - if org.Type.ID == oType.ID { - org.Type = oType - } - err := c.UpdateId(org.ID, org) - if err != nil { - return err - } - } - if err := iter.Close(); err != nil { + + filter := bson.M{"type._id": oType.ID} + update := bson.M{"$set": bson.M{"type": oType}} + + _, err := collection().UpdateMany(context.TODO(), filter, update) + if err != nil { return err } + log.Println("successfully updated organiztions for type name change") return nil } // UpdatePurposes Update the organization purposes func UpdatePurposes(organizationID string, purposes []Purpose) (Organization, error) { - s := session() - defer s.Close() + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return Organization{}, err + } - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(organizationID)}, bson.M{"$set": bson.M{"purposes": purposes}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": orgID}, bson.M{"$set": bson.M{"purposes": purposes}}) if err != nil { return Organization{}, err } @@ -367,10 +382,12 @@ func UpdatePurposes(organizationID string, purposes []Purpose) (Organization, er // DeletePurposes Delete the given purpose func DeletePurposes(organizationID string, purposes Purpose) (Organization, error) { - s := session() - defer s.Close() + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return Organization{}, err + } - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(organizationID)}, bson.M{"$pull": bson.M{"purposes": purposes}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": orgID}, bson.M{"$pull": bson.M{"purposes": purposes}}) if err != nil { return Organization{}, err } @@ -380,8 +397,6 @@ func DeletePurposes(organizationID string, purposes Purpose) (Organization, erro // GetPurpose Get the organization purpose by ID func GetPurpose(organizationID string, purposeID string) (Purpose, error) { - s := session() - defer s.Close() o, err := Get(organizationID) if err != nil { @@ -398,10 +413,12 @@ func GetPurpose(organizationID string, purposeID string) (Purpose, error) { // AddTemplates Add the organization templates func AddTemplates(organizationID string, template Template) error { - s := session() - defer s.Close() + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return err + } - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(organizationID)}, bson.M{"$push": bson.M{"templates": template}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": orgID}, bson.M{"$push": bson.M{"templates": template}}) if err != nil { return err } @@ -410,10 +427,12 @@ func AddTemplates(organizationID string, template Template) error { // DeleteTemplates Delete the organization templates func DeleteTemplates(organizationID string, templates Template) (Organization, error) { - s := session() - defer s.Close() + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return Organization{}, err + } - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(organizationID)}, bson.M{"$pull": bson.M{"templates": templates}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": orgID}, bson.M{"$pull": bson.M{"templates": templates}}) if err != nil { return Organization{}, err } @@ -423,10 +442,12 @@ func DeleteTemplates(organizationID string, templates Template) (Organization, e // UpdateTemplates Update the organization templates func UpdateTemplates(organizationID string, templates []Template) (Organization, error) { - s := session() - defer s.Close() + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return Organization{}, err + } - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(organizationID)}, bson.M{"$set": bson.M{"templates": templates}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": orgID}, bson.M{"$set": bson.M{"templates": templates}}) if err != nil { return Organization{}, err } @@ -436,8 +457,6 @@ func UpdateTemplates(organizationID string, templates []Template) (Organization, // GetTemplate Get the organization template by ID func GetTemplate(organizationID string, templateID string) (Template, error) { - s := session() - defer s.Close() o, err := Get(organizationID) if err != nil { @@ -454,10 +473,12 @@ func GetTemplate(organizationID string, templateID string) (Template, error) { // SetEnabled Sets the enabled status to true/false func SetEnabled(organizationID string, enabled bool) (Organization, error) { - s := session() - defer s.Close() + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return Organization{}, err + } - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(organizationID)}, bson.M{"$set": bson.M{"enabled": enabled}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": orgID}, bson.M{"$set": bson.M{"enabled": enabled}}) if err != nil { return Organization{}, err } @@ -467,52 +488,85 @@ func SetEnabled(organizationID string, enabled bool) (Organization, error) { // GetSubscribeMethod Get org subscribe method func GetSubscribeMethod(orgID string) (int, error) { - s := session() - defer s.Close() - c := collection(s) - var result Organization - err := c.FindId(bson.ObjectIdHex(orgID)).Select(bson.M{"subs.method": 1}).One(&result) + + orgId, err := primitive.ObjectIDFromHex(orgID) + if err != nil { + return result.Subs.Method, err + } + + filter := bson.M{"_id": orgId} + projection := bson.M{"subs.method": 1} + + findOptions := options.FindOne().SetProjection(projection) + + err = collection().FindOne(context.TODO(), filter, findOptions).Decode(&result) return result.Subs.Method, err } // UpdateSubscribeMethod Update subscription method func UpdateSubscribeMethod(orgID string, method int) error { - s := session() - defer s.Close() - c := collection(s) + orgId, err := primitive.ObjectIDFromHex(orgID) + if err != nil { + return err + } + + filter := bson.M{"_id": orgId} + update := bson.M{"$set": bson.M{"subs.method": method}} + + _, err = collection().UpdateOne(context.TODO(), filter, update) + if err != nil { + return err + } - return c.UpdateId(bson.ObjectIdHex(orgID), bson.M{"$set": bson.M{"subs.method": method}}) + return nil } // UpdateSubscribeKey Update subscription key func UpdateSubscribeKey(orgID string, key string) error { - s := session() - defer s.Close() - c := collection(s) + orgId, err := primitive.ObjectIDFromHex(orgID) + if err != nil { + return err + } + + filter := bson.M{"_id": orgId} + update := bson.M{"$set": bson.M{"subs.key": key}} - return c.UpdateId(bson.ObjectIdHex(orgID), bson.M{"$set": bson.M{"subs.key": key}}) + _, err = collection().UpdateOne(context.TODO(), filter, update) + if err != nil { + return err + } + + return nil } // GetSubscribeKey Update subscription token func GetSubscribeKey(orgID string) (string, error) { - s := session() - defer s.Close() - c := collection(s) - var result Organization - err := c.FindId(bson.ObjectIdHex(orgID)).Select(bson.M{"subs.key": 1}).One(&result) + + orgId, err := primitive.ObjectIDFromHex(orgID) + if err != nil { + return result.Subs.Key, err + } + + filter := bson.M{"_id": orgId} + projection := bson.M{"subs.key": 1} + findOptions := options.FindOne().SetProjection(projection) + + err = collection().FindOne(context.TODO(), filter, findOptions).Decode(&result) return result.Subs.Key, err } // UpdateIdentityProviderByOrgID Update the identity provider config for org func UpdateIdentityProviderByOrgID(organizationID string, identityProviderRepresentation IdentityProviderRepresentation) (Organization, error) { - s := session() - defer s.Close() + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return Organization{}, err + } - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(organizationID)}, bson.M{"$set": bson.M{"identityproviderrepresentation": identityProviderRepresentation}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": orgID}, bson.M{"$set": bson.M{"identityproviderrepresentation": identityProviderRepresentation}}) if err != nil { return Organization{}, err } @@ -522,10 +576,12 @@ func UpdateIdentityProviderByOrgID(organizationID string, identityProviderRepres // DeleteIdentityProviderByOrgID Delete the identity provider config for org func DeleteIdentityProviderByOrgID(organizationID string) (Organization, error) { - s := session() - defer s.Close() + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return Organization{}, err + } - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(organizationID)}, bson.M{"$set": bson.M{"identityproviderrepresentation": nil}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": orgID}, bson.M{"$set": bson.M{"identityproviderrepresentation": nil}}) if err != nil { return Organization{}, err } @@ -535,10 +591,12 @@ func DeleteIdentityProviderByOrgID(organizationID string) (Organization, error) // UpdateExternalIdentityProviderAvailableStatus Update the external identity provider available status for org func UpdateExternalIdentityProviderAvailableStatus(organizationID string, availableStatus bool) (Organization, error) { - s := session() - defer s.Close() + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return Organization{}, err + } - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(organizationID)}, bson.M{"$set": bson.M{"externalidentityprovideravailable": availableStatus}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": orgID}, bson.M{"$set": bson.M{"externalidentityprovideravailable": availableStatus}}) if err != nil { return Organization{}, err } @@ -548,10 +606,12 @@ func UpdateExternalIdentityProviderAvailableStatus(organizationID string, availa // UpdateOpenIDClientByOrgID Update OpenID config for org func UpdateOpenIDClientByOrgID(organizationID string, openIDConfig KeycloakOpenIDClient) (Organization, error) { - s := session() - defer s.Close() + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return Organization{}, err + } - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(organizationID)}, bson.M{"$set": bson.M{"keycloakopenidclient": openIDConfig}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": orgID}, bson.M{"$set": bson.M{"keycloakopenidclient": openIDConfig}}) if err != nil { return Organization{}, err } @@ -561,10 +621,12 @@ func UpdateOpenIDClientByOrgID(organizationID string, openIDConfig KeycloakOpenI // DeleteOpenIDClientByOrgID Delete OpenID config for org func DeleteOpenIDClientByOrgID(organizationID string) (Organization, error) { - s := session() - defer s.Close() + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return Organization{}, err + } - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(organizationID)}, bson.M{"$set": bson.M{"keycloakopenidclient": nil}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": orgID}, bson.M{"$set": bson.M{"keycloakopenidclient": nil}}) if err != nil { return Organization{}, err } @@ -574,11 +636,18 @@ func DeleteOpenIDClientByOrgID(organizationID string) (Organization, error) { // GetName Get organization name by given id func GetName(organizationID string) (string, error) { - s := session() - defer s.Close() - var result Organization - err := collection(s).FindId(bson.ObjectIdHex(organizationID)).Select(bson.M{"name": 1}).One(&result) + + orgID, err := primitive.ObjectIDFromHex(organizationID) + if err != nil { + return result.Name, err + } + + filter := bson.M{"_id": orgID} + projection := bson.M{"name": 1} + findOptions := options.FindOne().SetProjection(projection) + + err = collection().FindOne(context.TODO(), filter, findOptions).Decode(&result) return result.Name, err } diff --git a/src/orgtype/orgType.go b/src/orgtype/orgType.go index 5f01349..fd4e612 100644 --- a/src/orgtype/orgType.go +++ b/src/orgtype/orgType.go @@ -1,65 +1,75 @@ package orgtype import ( + "context" + "github.com/bb-consent/api/src/database" - "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" ) // OrgType Type related information type OrgType struct { - ID bson.ObjectId `bson:"_id,omitempty"` + ID primitive.ObjectID `bson:"_id,omitempty"` Type string ImageID string ImageURL string } -func session() *mgo.Session { - return database.DB.Session.Copy() -} - -func collection(s *mgo.Session) *mgo.Collection { - return s.DB(database.DB.Name).C("orgTypes") +func collection() *mongo.Collection { + return database.DB.Client.Database(database.DB.Name).Collection("orgTypes") } // Add Adds an organization func Add(ot OrgType) (OrgType, error) { - s := session() - defer s.Close() - ot.ID = bson.NewObjectId() - return ot, collection(s).Insert(&ot) + ot.ID = primitive.NewObjectID() + _, err := collection().InsertOne(context.TODO(), &ot) + + return ot, err } // Get Gets organization type by given id func Get(organizationTypeID string) (OrgType, error) { - s := session() - defer s.Close() - var result OrgType - err := collection(s).FindId(bson.ObjectIdHex(organizationTypeID)).One(&result) + + orgTypeID, err := primitive.ObjectIDFromHex(organizationTypeID) + if err != nil { + return result, err + } + + err = collection().FindOne(context.Background(), bson.M{"_id": orgTypeID}).Decode(&result) return result, err } // GetAll Gets all organization types func GetAll() ([]OrgType, error) { - s := session() - defer s.Close() var results []OrgType - err := collection(s).Find(nil).All(&results) + + cursor, err := collection().Find(context.TODO(), bson.M{}) + if err != nil { + return nil, err + } + defer cursor.Close(context.TODO()) + + if err := cursor.All(context.TODO(), &results); err != nil { + return nil, err + } return results, err } // Update Update the organization type func Update(organizationTypeID string, typeName string) (OrgType, error) { - s := session() - defer s.Close() + orgTypeID, err := primitive.ObjectIDFromHex(organizationTypeID) + if err != nil { + return OrgType{}, err + } - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(organizationTypeID)}, - bson.M{"$set": bson.M{"type": typeName}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": orgTypeID}, bson.M{"$set": bson.M{"type": typeName}}) if err == nil { return Get(organizationTypeID) } @@ -68,17 +78,24 @@ func Update(organizationTypeID string, typeName string) (OrgType, error) { // Delete Deletes an organization func Delete(organizationTypeID string) error { - s := session() - defer s.Close() + orgTypeID, err := primitive.ObjectIDFromHex(organizationTypeID) + if err != nil { + return err + } - return collection(s).Remove(bson.M{"_id": bson.ObjectIdHex(organizationTypeID)}) + _, err = collection().DeleteOne(context.TODO(), bson.M{"_id": orgTypeID}) + + return err } // UpdateImage Update the org type image func UpdateImage(organizationTypeID string, imageID string, imageURL string) error { - s := session() - defer s.Close() + orgTypeID, err := primitive.ObjectIDFromHex(organizationTypeID) + if err != nil { + return err + } - return collection(s).Update(bson.M{"_id": bson.ObjectIdHex(organizationTypeID)}, + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": orgTypeID}, bson.M{"$set": bson.M{"imageid": imageID, "imageurl": imageURL}}) + return err } diff --git a/src/otp/otps.go b/src/otp/otps.go index f54493a..130fc85 100644 --- a/src/otp/otps.go +++ b/src/otp/otps.go @@ -1,14 +1,17 @@ package otp import ( + "context" + "github.com/bb-consent/api/src/database" - "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" ) // Otp Otp holds the generated OTP info type Otp struct { - ID bson.ObjectId `bson:"_id,omitempty"` + ID primitive.ObjectID `bson:"_id,omitempty"` Name string Email string Phone string @@ -16,70 +19,68 @@ type Otp struct { Verified bool } -func session() *mgo.Session { - return database.DB.Session.Copy() -} - -func collection(s *mgo.Session) *mgo.Collection { - return s.DB(database.DB.Name).C("otps") +func collection() *mongo.Collection { + return database.DB.Client.Database(database.DB.Name).Collection("otps") } // Add Adds the otp to the db func Add(otp Otp) (Otp, error) { - s := session() - defer s.Close() - otp.ID = bson.NewObjectId() + otp.ID = primitive.NewObjectID() + + _, err := collection().InsertOne(context.TODO(), otp) + if err != nil { + return Otp{}, err + } - return otp, collection(s).Insert(&otp) + return otp, nil } // Delete Deletes the otp entry by ID func Delete(otpID string) error { - s := session() - defer s.Close() + otpId, err := primitive.ObjectIDFromHex(otpID) + if err != nil { + return err + } - return collection(s).RemoveId(bson.ObjectIdHex(otpID)) + _, err = collection().DeleteOne(context.TODO(), bson.M{"_id": otpId}) + if err != nil { + return err + } + + return nil } // UpdateVerified Updates the verified filed func UpdateVerified(o Otp) error { - s := session() - defer s.Close() - c := collection(s) + filter := bson.M{"_id": o.ID} + update := bson.M{"$set": bson.M{"verified": o.Verified}} - err := c.Update(bson.M{"_id": o.ID}, bson.M{"$set": bson.M{"verified": o.Verified}}) + _, err := collection().UpdateOne(context.TODO(), filter, update) return err } // PhoneNumberExist Check if phone number is already in the colleciton func PhoneNumberExist(phone string) (o Otp, err error) { - s := session() - defer s.Close() - - q := collection(s).Find(bson.M{"phone": phone}).Limit(1) + filter := bson.M{"phone": phone} - c, err := q.Count() - if err != nil { + err = collection().FindOne(context.TODO(), filter).Decode(&o) + if err == mongo.ErrNoDocuments { return o, err - } - - if c == 0 { + } else if err != nil { return o, err } - q.One(&o) return o, err } // SearchPhone Search phone number in otp db func SearchPhone(phone string) (Otp, error) { - s := session() - defer s.Close() + filter := bson.M{"phone": phone} var result Otp - err := collection(s).Find(bson.M{"phone": phone}).One(&result) + err := collection().FindOne(context.TODO(), filter).Decode(&result) if err != nil { return result, err } diff --git a/src/user/users.go b/src/user/users.go index e83d968..cb69d54 100644 --- a/src/user/users.go +++ b/src/user/users.go @@ -1,23 +1,26 @@ package user import ( + "context" "log" "time" "github.com/bb-consent/api/src/database" "github.com/bb-consent/api/src/org" "github.com/bb-consent/api/src/orgtype" - mgo "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) // Org Organization snippet stored as part of user type Org struct { - OrgID bson.ObjectId `bson:"orgid,omitempty"` + OrgID primitive.ObjectID `bson:"orgid,omitempty"` Name string Location string Type string - TypeID bson.ObjectId `bson:"typeid,omitempty"` + TypeID primitive.ObjectID `bson:"typeid,omitempty"` EulaAccepted bool } @@ -35,7 +38,7 @@ type Role struct { // User data type type User struct { - ID bson.ObjectId `bson:"_id,omitempty"` + ID primitive.ObjectID `bson:"_id,omitempty"` Name string IamID string Email string @@ -50,54 +53,55 @@ type User struct { IncompleteProfile bool } -func session() *mgo.Session { - return database.DB.Session.Copy() -} - -func collection(s *mgo.Session) *mgo.Collection { - return s.DB(database.DB.Name).C("users") +func collection() *mongo.Collection { + return database.DB.Client.Database(database.DB.Name).Collection("users") } // Add Adds an user to the collection func Add(user User) (User, error) { - s := session() - defer s.Close() - user.ID = bson.NewObjectId() + user.ID = primitive.NewObjectID() user.LastVisit = time.Now().Format(time.RFC3339) - return user, collection(s).Insert(&user) + _, err := collection().InsertOne(context.TODO(), &user) + + return user, err } // Update Update the user details func Update(userID string, u User) (User, error) { - s := session() - defer s.Close() + userId, err := primitive.ObjectIDFromHex(userID) + if err != nil { + return User{}, err + } - err := collection(s).UpdateId(bson.ObjectIdHex(userID), u) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": userId}, bson.M{"$set": u}) if err != nil { return User{}, err } + u, err = Get(userID) return u, err } // Delete Deletes the user by ID func Delete(userID string) error { - s := session() - defer s.Close() + userId, err := primitive.ObjectIDFromHex(userID) + if err != nil { + return err + } + filter := bson.M{"_id": userId} + + _, err = collection().DeleteOne(context.TODO(), filter) - return collection(s).RemoveId(bson.ObjectIdHex(userID)) + return err } // GetByIamID Get the user by IamID func GetByIamID(iamID string) (User, error) { var result User - s := session() - defer s.Close() - - err := collection(s).Find(bson.M{"iamid": iamID}).One(&result) + err := collection().FindOne(context.TODO(), bson.M{"iamid": iamID}).Decode(&result) if err != nil { log.Printf("Failed to find user id:%v err:%v", iamID, err) return result, err @@ -108,21 +112,28 @@ func GetByIamID(iamID string) (User, error) { // Get Gets a single user by given id func Get(userID string) (User, error) { - s := session() - defer s.Close() - c := collection(s) + c := collection() + + userId, err := primitive.ObjectIDFromHex(userID) + if err != nil { + return User{}, err + } var result User - err := c.FindId(bson.ObjectIdHex(userID)).One(&result) + // Find the user by ID + filter := bson.M{"_id": userId} + err = c.FindOne(context.TODO(), filter).Decode(&result) if err != nil { - log.Printf("Failed to find user id:%v err:%v", userID, err) + log.Printf("Failed to find user ID: %v, error: %v", userID, err) return result, err } - //Update the last visited field + // Update the last visited field t := time.Now().Format(time.RFC3339) - err = c.Update(bson.M{"_id": bson.ObjectIdHex(userID)}, bson.M{"$set": bson.M{"lastvisit": t}}) + update := bson.M{"$set": bson.M{"lastvisit": t}} + updateOptions := options.Update().SetUpsert(false) + _, err = c.UpdateOne(context.TODO(), filter, update, updateOptions) if err != nil { log.Printf("Failed to update LastVisit field for id:%v \n", userID) } @@ -132,64 +143,62 @@ func Get(userID string) (User, error) { // GetByEmail Get user details by email func GetByEmail(email string) (User, error) { - s := session() - defer s.Close() - var u User - err := collection(s).Find(bson.M{"email": email}).Select(bson.M{"iamid": 1, "name": 1, "roles": 1}).One(&u) + filter := bson.M{"email": email} + + projection := bson.M{"iamid": 1, "name": 1, "roles": 1} + + findOptions := options.FindOne().SetProjection(projection) + + err := collection().FindOne(context.TODO(), filter, findOptions).Decode(&u) return u, err } // EmailExist Check if email id is already in the collection func EmailExist(email string) (bool, error) { - s := session() - defer s.Close() + filter := bson.M{"email": email} - q := collection(s).Find(bson.M{"email": email}).Limit(1) + countOptions := options.Count().SetLimit(1) - c, err := q.Count() + count, err := collection().CountDocuments(context.TODO(), filter, countOptions) if err != nil { return false, err } - if c == 0 { - return false, err - } - - return true, err + return count > 0, nil } // PhoneNumberExist Check if phone number is already in the collection func PhoneNumberExist(phone string) (bool, error) { - s := session() - defer s.Close() + filter := bson.M{"phone": phone} - q := collection(s).Find(bson.M{"phone": phone}).Limit(1) + countOptions := options.Count().SetLimit(1) - c, err := q.Count() + count, err := collection().CountDocuments(context.TODO(), filter, countOptions) if err != nil { return false, err } - if c == 0 { - return false, err - } - - return true, err + return count > 0, nil } // UpdateClientDeviceInfo Update the client device info func UpdateClientDeviceInfo(userID string, client ClientInfo) (User, error) { - s := session() - defer s.Close() - c := collection(s) + userId, err := primitive.ObjectIDFromHex(userID) + if err != nil { + return User{}, err + } - err := c.Update(bson.M{"_id": bson.ObjectIdHex(userID)}, bson.M{"$set": bson.M{"client": client}}) + filter := bson.M{"_id": userId} + update := bson.M{"$set": bson.M{"client": client}} + + _, err = collection().UpdateOne(context.TODO(), filter, update) if err != nil { return User{}, err } + //TODO: Is this DB get necessary? u, err := Get(userID) return u, err @@ -197,10 +206,12 @@ func UpdateClientDeviceInfo(userID string, client ClientInfo) (User, error) { // AddRole Add roles to users func AddRole(userID string, role Role) (User, error) { - s := session() - defer s.Close() + userId, err := primitive.ObjectIDFromHex(userID) + if err != nil { + return User{}, err + } - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(userID)}, bson.M{"$push": bson.M{"roles": role}}) + _, err = collection().UpdateOne(context.TODO(), bson.M{"_id": userId}, bson.M{"$push": bson.M{"roles": role}}) if err != nil { return User{}, err } @@ -210,45 +221,73 @@ func AddRole(userID string, role Role) (User, error) { // GetOrgSubscribeUsers Get list of users subscribed to an organizations func GetOrgSubscribeUsers(orgID string, startID string, limit int) ([]User, string, error) { - s := session() - defer s.Close() - var results []User var err error limit = 10000 - if startID == "" { - err = collection(s).Find(bson.M{"orgs.orgid": bson.ObjectIdHex(orgID)}).Select(bson.M{"name": 1, "phone": 1, "email": 1}).Sort("-_id").Limit(limit).All(&results) - } else { - err = collection(s).Find(bson.M{"orgs.orgid": bson.ObjectIdHex(orgID), "_id": bson.M{"$lt": bson.ObjectIdHex(startID)}}).Select(bson.M{"name": 1, "phone": 1, "email": 1}).Sort("-_id").Limit(limit).All(&results) + + findOptions := options.Find() + findOptions.SetSort(bson.D{{Key: "_id", Value: -1}}) + findOptions.SetLimit(int64(limit)) + + orgId, err := primitive.ObjectIDFromHex(orgID) + if err != nil { + return results, "", err + } + + filter := bson.M{"orgs.orgid": orgId} + if startID != "" { + startId, err := primitive.ObjectIDFromHex(startID) + if err != nil { + return results, "", err + } + + filter["_id"] = bson.M{"$lt": startId} + } + + cursor, err := collection().Find(context.TODO(), filter, findOptions) + if err != nil { + return results, "", err } + defer cursor.Close(context.TODO()) - var lastID = "" + err = cursor.All(context.TODO(), &results) + + lastID := "" if err == nil { if len(results) != 0 && len(results) == (limit) { lastID = results[len(results)-1].ID.Hex() } } - return results, lastID, err + return results, lastID, nil + } // GetOrgSubscribeIter Get Iterator to users subscribed to an organizations -func GetOrgSubscribeIter(orgID string) *mgo.Iter { - s := session() - defer s.Close() +func GetOrgSubscribeIter(orgID string) (*mongo.Cursor, error) { + orgId, err := primitive.ObjectIDFromHex(orgID) + if err != nil { + return nil, err + } - iter := collection(s).Find(bson.M{"orgs.orgid": bson.ObjectIdHex(orgID)}).Iter() + filter := bson.M{"orgs.orgid": orgId} + cursor, err := collection().Find(context.TODO(), filter) + if err != nil { + return nil, err + } - return iter + return cursor, nil } // GetOrgSubscribeCount Get count of users subscribed to an organizations -func GetOrgSubscribeCount(orgID string) (int, error) { - s := session() - defer s.Close() - - count, err := collection(s).Find(bson.M{"orgs.orgid": bson.ObjectIdHex(orgID)}).Count() +func GetOrgSubscribeCount(orgID string) (int64, error) { + orgId, err := primitive.ObjectIDFromHex(orgID) + if err != nil { + return 0, err + } + filter := bson.M{"orgs.orgid": orgId} + count, err := collection().CountDocuments(context.TODO(), filter) if err != nil { log.Printf("Failed to find user count by org id:%v err:%v", orgID, err) return 0, err @@ -259,69 +298,92 @@ func GetOrgSubscribeCount(orgID string) (int, error) { // UpdateOrgTypeOfSubscribedUsers Updates the embedded organization type snippet for all users func UpdateOrgTypeOfSubscribedUsers(orgType orgtype.OrgType) error { - s := session() - defer s.Close() - c := collection(s) + filter := bson.M{"orgs.typeid": orgType.ID} + + cursor, err := collection().Find(context.TODO(), filter) + if err != nil { + return err + } + defer cursor.Close(context.TODO()) + + for cursor.Next(context.TODO()) { + var u User + err := cursor.Decode(&u) + if err != nil { + return err + } - var u User - iter := c.Find(bson.M{"orgs.typeid": orgType.ID}).Iter() - for iter.Next(&u) { for i := range u.Orgs { if u.Orgs[i].TypeID == orgType.ID { u.Orgs[i].Type = orgType.Type } - err := c.UpdateId(u.ID, u) - if err != nil { - return err - } + } + + _, err = collection().ReplaceOne(context.TODO(), bson.M{"_id": u.ID}, u) + if err != nil { + return err } } - if err := iter.Close(); err != nil { - return err - } - log.Println("successfully updated users for organization type name change") + + log.Println("Successfully updated users for organization type name change") return nil } // UpdateOrganizationsSubscribedUsers Updates the embedded organization snippet for all users func UpdateOrganizationsSubscribedUsers(org org.Organization) error { - s := session() - defer s.Close() - c := collection(s) + filter := bson.M{"orgs.orgid": org.ID} + cursor, err := collection().Find(context.TODO(), filter) + if err != nil { + return err + } + defer cursor.Close(context.TODO()) + + for cursor.Next(context.TODO()) { + var result User + err := cursor.Decode(&result) + if err != nil { + return err + } - var result User - iter := c.Find(bson.M{"orgs.orgid": org.ID}).Iter() - for iter.Next(&result) { for i := range result.Orgs { if result.Orgs[i].OrgID == org.ID { result.Orgs[i].Name = org.Name result.Orgs[i].Location = org.Location } - err := c.UpdateId(result.ID, result) - if err != nil { - return err - } + } + + _, err = collection().ReplaceOne(context.TODO(), bson.M{"_id": result.ID}, result) + if err != nil { + return err } } - if err := iter.Close(); err != nil { - return err - } + return nil } // UpdateOrganization Updates organization to user collection func UpdateOrganization(userID string, org Org) (User, error) { - s := session() - defer s.Close() - c := collection(s) - var result User - err := c.Update(bson.M{"_id": bson.ObjectIdHex(userID)}, bson.M{"$push": bson.M{"orgs": org}}) + + userId, err := primitive.ObjectIDFromHex(userID) + if err != nil { + return User{}, err + } + + filter := bson.M{"_id": userId} + + update := bson.M{"$push": bson.M{"orgs": org}} + + _, err = collection().UpdateOne(context.TODO(), filter, update) + if err != nil { + return result, err + } + + err = collection().FindOne(context.TODO(), filter).Decode(&result) if err != nil { return result, err } - err = c.FindId(bson.ObjectIdHex(userID)).One(&result) return result, err } @@ -336,9 +398,6 @@ func GetUserOrgDetails(u User, oID string) (org Org, found bool) { // DeleteOrganization Remove user from an organization func DeleteOrganization(userID string, orgID string) (User, error) { - s := session() - defer s.Close() - c := collection(s) u, err := Get(userID) if err != nil { @@ -347,47 +406,69 @@ func DeleteOrganization(userID string, orgID string) (User, error) { org, _ := GetUserOrgDetails(u, orgID) //Check found == true + userId, err := primitive.ObjectIDFromHex(userID) + if err != nil { + return User{}, err + } + + filter := bson.M{"_id": userId} + update := bson.M{"$pull": bson.M{"orgs": org}} + var result User - err = c.Update(bson.M{"_id": bson.ObjectIdHex(userID)}, bson.M{"$pull": bson.M{"orgs": org}}) + + _, err = collection().UpdateOne(context.TODO(), filter, update) if err != nil { return result, err } - err = c.FindId(bson.ObjectIdHex(userID)).One(&result) + err = collection().FindOne(context.TODO(), filter).Decode(&result) + return result, err } // RemoveRole Remove role of an user func RemoveRole(userID string, role Role) (User, error) { - s := session() - defer s.Close() + userId, err := primitive.ObjectIDFromHex(userID) + if err != nil { + return User{}, err + } + filter := bson.M{"_id": userId} + update := bson.M{"$pull": bson.M{"roles": role}} - err := collection(s).Update(bson.M{"_id": bson.ObjectIdHex(userID)}, bson.M{"$pull": bson.M{"roles": role}}) + _, err = collection().UpdateOne(context.TODO(), filter, update) if err != nil { return User{}, err } + u, err := Get(userID) return u, err } // UpdateAPIKey update apikey to user func UpdateAPIKey(userID string, apiKey string) error { - s := session() - defer s.Close() - c := collection(s) + userId, err := primitive.ObjectIDFromHex(userID) + if err != nil { + return err + } + filter := bson.M{"_id": userId} + update := bson.M{"$set": bson.M{"apikey": apiKey}} - err := c.Update(bson.M{"_id": bson.ObjectIdHex(userID)}, bson.M{"$set": bson.M{"apikey": apiKey}}) + _, err = collection().UpdateOne(context.TODO(), filter, update) return err } // GetAPIKey Gets the API key of the user func GetAPIKey(userID string) (string, error) { - s := session() - defer s.Close() + userId, err := primitive.ObjectIDFromHex(userID) + if err != nil { + return "", err + } - var result User - err := collection(s).FindId(bson.ObjectIdHex(userID)).Select(bson.M{"apikey": 1}).One(&result) + projection := bson.M{"apikey": 1} + opts := options.FindOne().SetProjection(projection) + var result User + err = collection().FindOne(context.TODO(), bson.M{"_id": userId}, opts).Decode(&result) if err != nil { log.Printf("Failed to find user by id:%v err:%v", userID, err) return "", err diff --git a/src/webhookdispatcher/webhookdispatcher.go b/src/webhookdispatcher/webhookdispatcher.go index bff3403..49ec355 100644 --- a/src/webhookdispatcher/webhookdispatcher.go +++ b/src/webhookdispatcher/webhookdispatcher.go @@ -18,7 +18,7 @@ import ( "github.com/bb-consent/api/src/config" "github.com/confluentinc/confluent-kafka-go/kafka" - "github.com/globalsign/mgo/bson" + "go.mongodb.org/mongo-driver/bson/primitive" ) // WebhookEvent Webhook event wrapper @@ -59,20 +59,20 @@ var DeliveryStatus = map[int]string{ // Webhook Defines the structure for an organisation webhook type Webhook struct { - ID bson.ObjectId `bson:"_id,omitempty"` // Webhook ID - OrgID string // Organisation ID - PayloadURL string // Webhook payload URL - ContentType string // Webhook payload content type for e.g application/json - SubscribedEvents []string // Events subscribed for e.g. user.data.delete - Disabled bool // Disabled or not - SecretKey string // For calculating SHA256 HMAC to verify data integrity and authenticity - SkipSSLVerification bool // Skip SSL certificate verification or not (expiry is checked) - TimeStamp string // UTC timestamp + ID primitive.ObjectID `bson:"_id,omitempty"` // Webhook ID + OrgID string // Organisation ID + PayloadURL string // Webhook payload URL + ContentType string // Webhook payload content type for e.g application/json + SubscribedEvents []string // Events subscribed for e.g. user.data.delete + Disabled bool // Disabled or not + SecretKey string // For calculating SHA256 HMAC to verify data integrity and authenticity + SkipSSLVerification bool // Skip SSL certificate verification or not (expiry is checked) + TimeStamp string // UTC timestamp } // WebhookDelivery Details of payload delivery to webhook endpoint type WebhookDelivery struct { - ID bson.ObjectId `bson:"_id,omitempty"` // Webhook delivery ID + ID primitive.ObjectID `bson:"_id,omitempty"` // Webhook delivery ID WebhookID string // Webhook ID UserID string // ID of user who triggered the webhook event WebhookEventType string // Webhook event type for e.g. data.delete.initiated @@ -162,7 +162,7 @@ func WebhookDispatcherInit(webhookConfig *config.Configuration) { // Instantiating webhook delivery webhookDelivery = WebhookDelivery{ - ID: bson.NewObjectId(), + ID: primitive.NewObjectID(), WebhookID: webhookEvent.WebhookID, UserID: userID, WebhookEventType: webhookEventType, diff --git a/src/webhookdispatcher/webhookdispatcher_db.go b/src/webhookdispatcher/webhookdispatcher_db.go index ca9ec0a..e38ea18 100644 --- a/src/webhookdispatcher/webhookdispatcher_db.go +++ b/src/webhookdispatcher/webhookdispatcher_db.go @@ -1,41 +1,42 @@ package webhookdispatcher import ( + "context" + "github.com/bb-consent/api/src/database" - "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" ) -func session() *mgo.Session { - return database.DB.Session.Copy() -} - -func webhookCollection(s *mgo.Session) *mgo.Collection { - return s.DB(database.DB.Name).C("webhooks") +func webhookCollection() *mongo.Collection { + return database.DB.Client.Database(database.DB.Name).Collection("webhooks") } -func webhookDeliveryCollection(s *mgo.Session) *mgo.Collection { - return s.DB(database.DB.Name).C("webhookDeliveries") +func webhookDeliveryCollection() *mongo.Collection { + return database.DB.Client.Database(database.DB.Name).Collection("webhookDeliveries") } // GetWebhookByOrgID Gets a webhook by organisation ID and webhook ID func GetWebhookByOrgID(webhookID, orgID string) (result Webhook, err error) { - s := session() - defer s.Close() + webhookId, err := primitive.ObjectIDFromHex(webhookID) + if err != nil { + return result, err + } - err = webhookCollection(s).Find(bson.M{"_id": bson.ObjectIdHex(webhookID), "orgid": orgID}).One(&result) + err = webhookCollection().FindOne(context.TODO(), bson.M{"_id": webhookId, "orgid": orgID}).Decode(&result) return result, err } // AddWebhookDelivery Adds payload delivery details to database for a webhook event func AddWebhookDelivery(webhookDelivery WebhookDelivery) (WebhookDelivery, error) { - s := session() - defer s.Close() - if webhookDelivery.ID == "" { - webhookDelivery.ID = bson.NewObjectId() + if webhookDelivery.ID == primitive.NilObjectID { + webhookDelivery.ID = primitive.NewObjectID() } - return webhookDelivery, webhookDeliveryCollection(s).Insert(&webhookDelivery) + _, err := webhookDeliveryCollection().InsertOne(context.TODO(), &webhookDelivery) + + return webhookDelivery, err } diff --git a/src/webhooks/webhooks_db.go b/src/webhooks/webhooks_db.go index 41f8d0f..9cae2e2 100644 --- a/src/webhooks/webhooks_db.go +++ b/src/webhooks/webhooks_db.go @@ -1,27 +1,31 @@ package webhooks import ( + "context" + "github.com/bb-consent/api/src/database" - "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) // Webhook Defines the structure for an organisation webhook type Webhook struct { - ID bson.ObjectId `bson:"_id,omitempty"` // Webhook ID - OrgID string // Organisation ID - PayloadURL string // Webhook payload URL - ContentType string // Webhook payload content type for e.g application/json - SubscribedEvents []string // Events subscribed for e.g. user.data.delete - Disabled bool // Disabled or not - SecretKey string // For calculating SHA256 HMAC to verify data integrity and authenticity - SkipSSLVerification bool // Skip SSL certificate verification or not (expiry is checked) - TimeStamp string // UTC timestamp + ID primitive.ObjectID `bson:"_id,omitempty"` // Webhook ID + OrgID string // Organisation ID + PayloadURL string // Webhook payload URL + ContentType string // Webhook payload content type for e.g application/json + SubscribedEvents []string // Events subscribed for e.g. user.data.delete + Disabled bool // Disabled or not + SecretKey string // For calculating SHA256 HMAC to verify data integrity and authenticity + SkipSSLVerification bool // Skip SSL certificate verification or not (expiry is checked) + TimeStamp string // UTC timestamp } // WebhookDelivery Details of payload delivery to webhook endpoint type WebhookDelivery struct { - ID bson.ObjectId `bson:"_id,omitempty"` // Webhook delivery ID + ID primitive.ObjectID `bson:"_id,omitempty"` // Webhook delivery ID WebhookID string // Webhook ID UserID string // ID of user who triggered the webhook event WebhookEventType string // Webhook event type for e.g. data.delete.initiated @@ -37,132 +41,175 @@ type WebhookDelivery struct { StatusDescription string // Describe the status for e.g. Reason for failure } -func session() *mgo.Session { - return database.DB.Session.Copy() -} - -func webhookCollection(s *mgo.Session) *mgo.Collection { - return s.DB(database.DB.Name).C("webhooks") +func webhookCollection() *mongo.Collection { + return database.DB.Client.Database(database.DB.Name).Collection("webhooks") } -func webhookDeliveryCollection(s *mgo.Session) *mgo.Collection { - return s.DB(database.DB.Name).C("webhookDeliveries") +func webhookDeliveryCollection() *mongo.Collection { + return database.DB.Client.Database(database.DB.Name).Collection("webhookDeliveries") } // CreateWebhook Adds a webhook for an organisation func CreateWebhook(webhook Webhook) (Webhook, error) { - s := session() - defer s.Close() + webhook.ID = primitive.NewObjectID() - webhook.ID = bson.NewObjectId() - - return webhook, webhookCollection(s).Insert(&webhook) + _, err := webhookCollection().InsertOne(context.TODO(), &webhook) + if err != nil { + return webhook, err + } + return webhook, nil } // GetByOrgID Gets a webhook by organisation ID and webhook ID func GetByOrgID(webhookID, orgID string) (result Webhook, err error) { - s := session() - defer s.Close() + webhookId, err := primitive.ObjectIDFromHex(webhookID) + if err != nil { + return result, err + } - err = webhookCollection(s).Find(bson.M{"_id": bson.ObjectIdHex(webhookID), "orgid": orgID}).One(&result) + err = webhookCollection().FindOne(context.TODO(), bson.M{"_id": webhookId, "orgid": orgID}).Decode(&result) return result, err } // DeleteWebhook Deletes a webhook for an organisation func DeleteWebhook(webhookID string) error { - s := session() - defer s.Close() + webhookId, err := primitive.ObjectIDFromHex(webhookID) + if err != nil { + return err + } - return webhookCollection(s).RemoveId(bson.ObjectIdHex(webhookID)) + filter := bson.M{"_id": webhookId} + + _, err = webhookCollection().DeleteOne(context.TODO(), filter) + if err != nil { + return err + } + + return nil } // UpdateWebhook Updates a webhook for an organization func UpdateWebhook(webhook Webhook) (Webhook, error) { - s := session() - defer s.Close() - err := webhookCollection(s).UpdateId(webhook.ID, webhook) + filter := bson.M{"_id": webhook.ID} + update := bson.M{"$set": webhook} + + _, err := webhookCollection().UpdateOne(context.TODO(), filter, update) return webhook, err } // GetActiveWebhooksByOrgID Gets all active webhooks for a particular organisation func GetActiveWebhooksByOrgID(orgID string) (results []Webhook, err error) { - s := session() - defer s.Close() + filter := bson.M{"orgid": orgID, "disabled": false} + + cursor, err := webhookCollection().Find(context.TODO(), filter) + if err != nil { + return nil, err + } + defer cursor.Close(context.TODO()) - err = webhookCollection(s).Find(bson.M{"orgid": orgID, "disabled": false}).All(&results) + if err := cursor.All(context.TODO(), &results); err != nil { + return nil, err + } return results, err } // GetWebhookCountByPayloadURL Gets the count of webhooks with same payload URL for an organisation -func GetWebhookCountByPayloadURL(orgID string, payloadURL string) (count int, err error) { - s := session() - defer s.Close() +func GetWebhookCountByPayloadURL(orgID string, payloadURL string) (count int64, err error) { - count, err = webhookCollection(s).Find(bson.M{"orgid": orgID, "payloadurl": payloadURL}).Count() + count, err = webhookCollection().CountDocuments(context.TODO(), bson.M{"orgid": orgID, "payloadurl": payloadURL}) return count, err } // GetAllWebhooksByOrgID Gets all webhooks for a given organisation func GetAllWebhooksByOrgID(orgID string) (results []Webhook, err error) { - s := session() - defer s.Close() + filter := bson.M{"orgid": orgID} - err = webhookCollection(s).Find(bson.M{"orgid": orgID}).Sort("-timestamp").All(&results) + options := options.Find().SetSort(bson.D{{Key: "timestamp", Value: -1}}) - return results, err + cursor, err := webhookCollection().Find(context.TODO(), filter, options) + if err != nil { + return nil, err + } + defer cursor.Close(context.TODO()) + + if err := cursor.All(context.TODO(), &results); err != nil { + return nil, err + } + + return results, nil } // GetLastWebhookDelivery Gets the last delivery for a webhook func GetLastWebhookDelivery(webhookID string) (result WebhookDelivery, err error) { - s := session() - defer s.Close() + filter := bson.M{"webhookid": webhookID} - err = webhookDeliveryCollection(s).Find(bson.M{"webhookid": webhookID}).Sort("-executionstarttimestamp").One(&result) + options := options.FindOne().SetSort(bson.D{{Key: "executionstarttimestamp", Value: -1}}) - return result, err + err = webhookDeliveryCollection().FindOne(context.TODO(), filter, options).Decode(&result) + if err != nil { + return WebhookDelivery{}, err + } + + return result, nil } // GetWebhookByPayloadURL Get the webhook for an organisation by payload URL func GetWebhookByPayloadURL(orgID string, payloadURL string) (result Webhook, err error) { - s := session() - defer s.Close() - err = webhookCollection(s).Find(bson.M{"orgid": orgID, "payloadurl": payloadURL}).One(&result) + err = webhookCollection().FindOne(context.TODO(), bson.M{"orgid": orgID, "payloadurl": payloadURL}).Decode(&result) return result, err } // GetWebhookDeliveryByID Gets payload delivery details by ID func GetWebhookDeliveryByID(webhookID string, webhookDeliveryID string) (result WebhookDelivery, err error) { - s := session() - defer s.Close() + webhookDeliveryId, err := primitive.ObjectIDFromHex(webhookDeliveryID) + if err != nil { + return result, err + } - err = webhookDeliveryCollection(s).Find(bson.M{"webhookid": webhookID, "_id": bson.ObjectIdHex(webhookDeliveryID)}).One(&result) + err = webhookDeliveryCollection().FindOne(context.TODO(), bson.M{"webhookid": webhookID, "_id": webhookDeliveryId}).Decode(&result) return result, err } // GetAllDeliveryByWebhookID Gets all webhook deliveries for a webhook func GetAllDeliveryByWebhookID(webhookID string, startID string, limit int) (results []WebhookDelivery, lastID string, err error) { - s := session() - defer s.Close() + filter := bson.M{"webhookid": webhookID} + + options := options.Find() + + if startID != "" { + startId, err := primitive.ObjectIDFromHex(startID) + if err != nil { + return nil, "", err + } + + filter["_id"] = bson.M{"$lt": startId} + } + + options.SetSort(bson.D{{Key: "executionstarttimestamp", Value: -1}}) + options.SetLimit(int64(limit)) - if startID == "" { - err = webhookDeliveryCollection(s).Find(bson.M{"webhookid": webhookID}).Sort("-executionstarttimestamp").Limit(limit).All(&results) - } else { - err = webhookDeliveryCollection(s).Find(bson.M{"webhookid": webhookID, "_id": bson.M{"$lt": bson.ObjectIdHex(startID)}}).Sort("-executionstarttimestamp").Limit(limit).All(&results) + cursor, err := webhookDeliveryCollection().Find(context.TODO(), filter, options) + if err != nil { + return nil, "", err + } + defer cursor.Close(context.TODO()) + + if err := cursor.All(context.TODO(), &results); err != nil { + return nil, "", err } lastID = "" - if err == nil { - if len(results) != 0 && len(results) == (limit) { - lastID = results[len(results)-1].ID.Hex() - } + + if len(results) != 0 && len(results) == (limit) { + lastID = results[len(results)-1].ID.Hex() } - return results, lastID, err + return results, lastID, nil }