diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..8b98f32 --- /dev/null +++ b/errors.go @@ -0,0 +1,7 @@ +package mgm + +import ( + "errors" +) + +var ErrVersionning = errors.New("versionning error") diff --git a/field.go b/field.go index 378e59d..dd297a8 100644 --- a/field.go +++ b/field.go @@ -18,6 +18,10 @@ type DateFields struct { UpdatedAt time.Time `json:"updated_at" bson:"updated_at"` } +type VersionField struct { + Version_ int `json:"_v" bson:"_v"` +} + // PrepareID method prepares the ID value to be used for filtering // e.g convert hex-string ID value to bson.ObjectId func (f *IDField) PrepareID(id interface{}) (interface{}, error) { @@ -39,6 +43,26 @@ func (f *IDField) SetID(id interface{}) { f.ID = id.(primitive.ObjectID) } +// GetVersion returns the model version field +func (f *VersionField) GetVersion() interface{} { + return f.Version_ +} + +// GetVersionFieldName returns the field name holding the version field (has to match the bson tag) +func (f *VersionField) GetVersionFieldName() string { + return "_v" +} + +// SetVersion returns the model version field +func (f *VersionField) IncrementVersion() { + f.Version_++ +} + +// Determines whether the version field is in its zero value +func (f *VersionField) IsVersionZero() bool { + return f.Version_ == 0 +} + //-------------------------------- // DateField methods //-------------------------------- @@ -51,7 +75,7 @@ func (f *DateFields) Creating() error { return nil } -// Saving hook is used here to set the `updated_at` field +// Saving hook is used here to set the `updated_at` field // value when creating or updateing a model. // TODO: get context as param the next version(4). func (f *DateFields) Saving() error { diff --git a/model.go b/model.go index 8a85dfa..c84345b 100644 --- a/model.go +++ b/model.go @@ -26,6 +26,13 @@ type Model interface { SetID(id interface{}) } +type Versionable interface { + Version() interface{} + GetVersionBsonFieldName() string + IncrementVersion() + IsVersionZero() bool +} + // DefaultModel struct contains a model's default fields. type DefaultModel struct { IDField `bson:",inline"` diff --git a/model_test.go b/model_test.go index 835ec4b..ee24bff 100644 --- a/model_test.go +++ b/model_test.go @@ -1,10 +1,11 @@ package mgm_test import ( + "testing" + "github.com/kamva/mgm/v3/internal/util" "github.com/stretchr/testify/require" "go.mongodb.org/mongo-driver/bson/primitive" - "testing" ) func TestPrepareInvalidId(t *testing.T) { @@ -23,3 +24,11 @@ func TestPrepareId(t *testing.T) { require.Equal(t, val.(primitive.ObjectID), id) util.AssertErrIsNil(t, err) } + +func TestVersion(t *testing.T) { + d := &Doc{} + require.Equal(t, 0, d.GetVersion()) + require.Equal(t, "_v", d.GetVersionFieldName()) + d.IncrementVersion() + require.Equal(t, 1, d.GetVersion()) +} diff --git a/operation.go b/operation.go index f12a75c..2a14354 100644 --- a/operation.go +++ b/operation.go @@ -2,6 +2,8 @@ package mgm import ( "context" + "fmt" + "github.com/kamva/mgm/v3/field" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/options" @@ -13,6 +15,12 @@ func create(ctx context.Context, c *Collection, model Model, opts ...*options.In return err } + //If this model is versionable and its version value is zero, increment it to initialize it in the db + vmodel, isVersionable := model.(Versionable) + if isVersionable && vmodel.IsVersionZero() { + vmodel.IncrementVersion() + } + res, err := c.InsertOne(ctx, model, opts...) if err != nil { @@ -30,17 +38,44 @@ func first(ctx context.Context, c *Collection, filter interface{}, model Model, } func update(ctx context.Context, c *Collection, model Model, opts ...*options.UpdateOptions) error { + + //Get current version before calling hooks that could alter it + var v interface{} + vmodel, isVersionable := model.(Versionable) + if isVersionable { + v = vmodel.Version() + } + // Call to saving hook if err := callToBeforeUpdateHooks(ctx, model); err != nil { return err } - res, err := c.UpdateOne(ctx, bson.M{field.ID: model.GetID()}, bson.M{"$set": model}, opts...) + query := bson.M{field.ID: model.GetID()} + + if isVersionable { + if vmodel.IsVersionZero() { + query["$or"] = bson.A{ + bson.M{vmodel.GetVersionBsonFieldName(): v}, + bson.M{vmodel.GetVersionBsonFieldName(): bson.M{"$exists": false}}, + } + } else { + query[vmodel.GetVersionBsonFieldName()] = v + } + + vmodel.IncrementVersion() + } + + res, err := c.UpdateOne(ctx, query, bson.M{"$set": model}, opts...) if err != nil { return err } + if isVersionable && res.MatchedCount == 0 { + return fmt.Errorf("document %v %v with version %v could not be found %w", c.Name(), model.GetID(), v, ErrVersionning) + } + return callToAfterUpdateHooks(ctx, res, model) } diff --git a/testhelpers_test.go b/testhelpers_test.go index 1243558..bb2616a 100644 --- a/testhelpers_test.go +++ b/testhelpers_test.go @@ -1,11 +1,12 @@ package mgm_test import ( + "testing" + "github.com/kamva/mgm/v3" "github.com/kamva/mgm/v3/internal/util" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/options" - "testing" ) func setupDefConnection() { @@ -43,6 +44,7 @@ func findDoc(t *testing.T) *Doc { type Doc struct { mgm.DefaultModel `bson:",inline"` + mgm.VersionField `bson:",inline"` Name string `bson:"name"` Age int `bson:"age"`