diff --git a/internal/collections/infra/mysql_repo.go b/internal/collections/infra/mysql_repo.go index d00846704..7ec02415a 100644 --- a/internal/collections/infra/mysql_repo.go +++ b/internal/collections/infra/mysql_repo.go @@ -488,18 +488,37 @@ func (r mysqlRepo) AddPersonCollection( ctx context.Context, userID model.UserID, cat collection.PersonCollectCategory, targetID model.PersonID, ) error { - var table = r.q.PersonCollect - err := table.WithContext(ctx).Create(&dao.PersonCollect{ + collect := &dao.PersonCollect{ UserID: userID, Category: string(cat), TargetID: targetID, CreatedTime: uint32(time.Now().Unix()), + } + err := r.q.Transaction(func(tx *query.Query) error { + switch cat { + case collection.PersonCollectCategoryCharacter: + character, err := tx.Character.WithContext(ctx).Where(tx.Character.ID.Eq(targetID)).Take() + if err != nil { + r.log.Error("failed to get character", zap.Error(err)) + return err + } + character.Collects += 1 + tx.Character.WithContext(ctx).Save(character) + case collection.PersonCollectCategoryPerson: + person, err := tx.Person.WithContext(ctx).Where(tx.Person.ID.Eq(targetID)).Take() + if err != nil { + r.log.Error("failed to get person", zap.Error(err)) + return err + } + person.Collects += 1 + tx.Person.WithContext(ctx).Save(person) + } + return tx.PersonCollect.WithContext(ctx).Create(collect) }) if err != nil { r.log.Error("failed to create person collection record", zap.Error(err)) return errgo.Wrap(err, "dal") } - return nil } @@ -507,9 +526,32 @@ func (r mysqlRepo) RemovePersonCollection( ctx context.Context, userID model.UserID, cat collection.PersonCollectCategory, targetID model.PersonID, ) error { - _, err := r.q.PersonCollect.WithContext(ctx). - Where(r.q.PersonCollect.UserID.Eq(userID), r.q.PersonCollect.Category.Eq(string(cat)), - r.q.PersonCollect.TargetID.Eq(targetID)).Delete() + err := r.q.Transaction(func(tx *query.Query) error { + switch cat { + case collection.PersonCollectCategoryCharacter: + character, err := tx.Character.WithContext(ctx).Where(tx.Character.ID.Eq(targetID)).Take() + if err != nil { + r.log.Error("failed to get character", zap.Error(err)) + return err + } + character.Collects -= 1 + tx.Character.WithContext(ctx).Save(character) + case collection.PersonCollectCategoryPerson: + person, err := tx.Person.WithContext(ctx).Where(tx.Person.ID.Eq(targetID)).Take() + if err != nil { + r.log.Error("failed to get person", zap.Error(err)) + return err + } + person.Collects -= 1 + tx.Person.WithContext(ctx).Save(person) + } + _, err := tx.PersonCollect.WithContext(ctx).Where( + tx.PersonCollect.UserID.Eq(userID), + tx.PersonCollect.Category.Eq(string(cat)), + tx.PersonCollect.TargetID.Eq(targetID), + ).Delete() + return err + }) if err != nil { r.log.Error("failed to delete person collection record", zap.Error(err)) return errgo.Wrap(err, "dal") diff --git a/internal/collections/infra/mysql_repo_test.go b/internal/collections/infra/mysql_repo_test.go index df4b89bd8..c7c83fb41 100644 --- a/internal/collections/infra/mysql_repo_test.go +++ b/internal/collections/infra/mysql_repo_test.go @@ -519,13 +519,12 @@ func TestMysqlRepo_GetPersonCollect(t *testing.T) { const mid model.PersonID = 12000 repo, q := getRepo(t) - table := q.PersonCollect test.RunAndCleanup(t, func() { - _, err := table.WithContext(context.TODO()).Where(table.UserID.Eq(uid)).Delete() + _, err := q.PersonCollect.WithContext(context.TODO()).Where(q.PersonCollect.UserID.Eq(uid)).Delete() require.NoError(t, err) }) - err := table.WithContext(context.Background()).Create(&dao.PersonCollect{ + err := q.PersonCollect.WithContext(context.Background()).Create(&dao.PersonCollect{ UserID: uid, Category: cat, TargetID: mid, @@ -547,20 +546,33 @@ func TestMysqlRepo_AddPersonCollect(t *testing.T) { const uid model.UserID = 40000 const cat = "prsn" const mid model.PersonID = 13000 + const collects uint32 = 10 repo, q := getRepo(t) table := q.PersonCollect test.RunAndCleanup(t, func() { _, err := table.WithContext(context.TODO()).Where(table.UserID.Eq(uid)).Delete() require.NoError(t, err) + _, err = q.Person.WithContext(context.TODO()).Where(q.Person.ID.Eq(mid)).Delete() + require.NoError(t, err) + }) + + err := q.Person.WithContext(context.Background()).Create(&dao.Person{ + ID: mid, + Collects: collects, }) + require.NoError(t, err) - err := repo.AddPersonCollection(context.Background(), uid, cat, mid) + err = repo.AddPersonCollection(context.Background(), uid, cat, mid) require.NoError(t, err) r, err := table.WithContext(context.TODO()).Where(table.UserID.Eq(uid)).Take() require.NoError(t, err) require.NotZero(t, r.ID) + + p, err := q.Person.WithContext(context.Background()).Where(q.Person.ID.Eq(mid)).Take() + require.NoError(t, err) + require.Equal(t, collects+1, p.Collects) } func TestMysqlRepo_RemovePersonCollect(t *testing.T) { @@ -570,15 +582,22 @@ func TestMysqlRepo_RemovePersonCollect(t *testing.T) { const uid model.UserID = 41000 const cat = "prsn" const mid model.PersonID = 14000 + const collects uint32 = 10 repo, q := getRepo(t) - table := q.PersonCollect test.RunAndCleanup(t, func() { - _, err := table.WithContext(context.TODO()).Where(table.UserID.Eq(uid)).Delete() + _, err := q.PersonCollect.WithContext(context.TODO()).Where(q.PersonCollect.UserID.Eq(uid)).Delete() + require.NoError(t, err) + _, err = q.Person.WithContext(context.TODO()).Where(q.Person.ID.Eq(mid)).Delete() require.NoError(t, err) }) - err := table.WithContext(context.Background()).Create(&dao.PersonCollect{ + err := q.Person.WithContext(context.Background()).Create(&dao.Person{ + ID: mid, + Collects: collects, + }) + require.NoError(t, err) + err = q.PersonCollect.WithContext(context.Background()).Create(&dao.PersonCollect{ UserID: uid, Category: cat, TargetID: mid, @@ -586,15 +605,19 @@ func TestMysqlRepo_RemovePersonCollect(t *testing.T) { }) require.NoError(t, err) - r, err := table.WithContext(context.TODO()).Where(table.UserID.Eq(uid)).Take() + r, err := q.PersonCollect.WithContext(context.TODO()).Where(q.PersonCollect.UserID.Eq(uid)).Take() require.NoError(t, err) require.NotZero(t, r.ID) err = repo.RemovePersonCollection(context.Background(), uid, cat, mid) require.NoError(t, err) - _, err = table.WithContext(context.TODO()).Where(table.UserID.Eq(uid)).Take() + _, err = q.PersonCollect.WithContext(context.TODO()).Where(q.PersonCollect.UserID.Eq(uid)).Take() require.ErrorIs(t, err, gorm.ErrRecordNotFound) + + p, err := q.Person.WithContext(context.Background()).Where(q.Person.ID.Eq(mid)).Take() + require.NoError(t, err) + require.Equal(t, collects-1, p.Collects) } func TestMysqlRepo_CountPersonCollections(t *testing.T) {