diff --git a/integration_tests/handler/contest_test.go b/integration_tests/handler/contest_test.go index fb3a134c..6554c3b7 100644 --- a/integration_tests/handler/contest_test.go +++ b/integration_tests/handler/contest_test.go @@ -291,7 +291,13 @@ func TestEditContest(t *testing.T) { "204 without change": { http.StatusNoContent, mockdata.ContestID3(), - schema.EditContestRequest{}, + schema.EditContestRequest{ + Duration: &schema.Duration{ + // DurationのValidationで落とされるのでSinceも埋める + Since: mockdata.CloneMockContests()[2].Since, + Until: &until, // Untilはnilにすると「未定」に変更される + }, + }, nil, }, "400 invalid contestID": { diff --git a/integration_tests/handler/project_test.go b/integration_tests/handler/project_test.go index b4f47645..26f5f092 100644 --- a/integration_tests/handler/project_test.go +++ b/integration_tests/handler/project_test.go @@ -253,7 +253,16 @@ func TestEditProject(t *testing.T) { "204 without changes": { http.StatusNoContent, mockdata.ProjectID2(), - schema.EditProjectRequest{}, + schema.EditProjectRequest{ + Duration: &schema.YearWithSemesterDuration{ + // DurationのValidationで落とされるのでSinceも埋める + Since: schema.YearWithSemester{ + Year: mockdata.CloneMockProjects()[1].SinceYear, + Semester: schema.Semester(mockdata.CloneMockProjects()[1].SinceSemester), + }, + Until: duration.Until, // Untilはnilにすると「未定」に変更される + }, + }, nil, }, "400 invalid projectID": { diff --git a/internal/infrastructure/repository/contest_impl.go b/internal/infrastructure/repository/contest_impl.go index 6971ce4f..82c90c34 100644 --- a/internal/infrastructure/repository/contest_impl.go +++ b/internal/infrastructure/repository/contest_impl.go @@ -3,6 +3,7 @@ package repository import ( "context" "errors" + "time" "github.com/gofrs/uuid" "github.com/traPtitech/traPortfolio/internal/domain" @@ -107,6 +108,16 @@ func (r *ContestRepository) CreateContest(ctx context.Context, args *repository. } func (r *ContestRepository) UpdateContest(ctx context.Context, contestID uuid.UUID, args *repository.UpdateContestArgs) error { + origin := &model.Contest{} + if err := r.h. + WithContext(ctx). + Where(&model.Contest{ID: contestID}). + First(origin). + Error; err != nil { + return err + } + untilEmpty := origin.Until.Equal(time.Time{}) + changes := map[string]interface{}{} if v, ok := args.Name.V(); ok { changes["name"] = v @@ -120,7 +131,7 @@ func (r *ContestRepository) UpdateContest(ctx context.Context, contestID uuid.UU if v, ok := args.Since.V(); ok { changes["since"] = v } - if v, ok := args.Until.V(); ok { + if v, ok := args.Until.V(); ok == untilEmpty || v != origin.Until { changes["until"] = v } @@ -130,14 +141,6 @@ func (r *ContestRepository) UpdateContest(ctx context.Context, contestID uuid.UU var c model.Contest err := r.h.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - if err := tx. - WithContext(ctx). - Where(&model.Contest{ID: contestID}). - First(&model.Contest{}). - Error; err != nil { - return err - } - if err := tx. WithContext(ctx). Model(&model.Contest{ID: contestID}). diff --git a/internal/infrastructure/repository/contest_test.go b/internal/infrastructure/repository/contest_test.go index d1285e65..d1532de0 100644 --- a/internal/infrastructure/repository/contest_test.go +++ b/internal/infrastructure/repository/contest_test.go @@ -3,6 +3,7 @@ package repository import ( "context" "testing" + "time" "github.com/gofrs/uuid" "github.com/samber/lo" @@ -12,6 +13,7 @@ import ( "github.com/traPtitech/traPortfolio/internal/infrastructure/external/mock_external" "go.uber.org/mock/gomock" + "github.com/traPtitech/traPortfolio/internal/pkgs/optional" "github.com/traPtitech/traPortfolio/internal/pkgs/random" "github.com/traPtitech/traPortfolio/internal/usecases/repository" ) @@ -122,6 +124,7 @@ func Test_UpdateContest(t *testing.T) { t.Run("update no fields", func(t *testing.T) { args := &repository.UpdateContestArgs{} + args.Until = optional.New(contest.TimeEnd, contest.TimeEnd == time.Time{}) err := repo.UpdateContest(context.Background(), contest.ID, args) assert.NoError(t, err) @@ -130,6 +133,23 @@ func Test_UpdateContest(t *testing.T) { assert.Equal(t, contest, gotContest) }) + + t.Run("update until to nil", func(t *testing.T) { + argWithUntil := random.CreateContestArgs() + argWithUntil.Until = random.UpdateContestArgs().Since + contest, err := repo.CreateContest(context.Background(), argWithUntil) + assert.NoError(t, err) + + argWithoutUntil := random.UpdateContestArgs() + argWithoutUntil.Until = optional.Of[time.Time]{} + err = repo.UpdateContest(context.Background(), contest.ID, argWithoutUntil) + assert.NoError(t, err) + }) + + t.Run("update failed: not contest id", func(t *testing.T) { + err = repo.UpdateContest(context.Background(), random.UUID(), &repository.UpdateContestArgs{}) + assert.Error(t, err) + }) } func Test_DeleteContest(t *testing.T) { diff --git a/internal/infrastructure/repository/project_impl.go b/internal/infrastructure/repository/project_impl.go index abd1ae8f..81938a71 100644 --- a/internal/infrastructure/repository/project_impl.go +++ b/internal/infrastructure/repository/project_impl.go @@ -138,6 +138,15 @@ func (r *ProjectRepository) CreateProject(ctx context.Context, args *repository. } func (r *ProjectRepository) UpdateProject(ctx context.Context, projectID uuid.UUID, args *repository.UpdateProjectArgs) error { + origin := &model.Project{} + if err := r.h. + WithContext(ctx). + Where(&model.Project{ID: projectID}). + First(origin). + Error; err != nil { + return err + } + changes := map[string]interface{}{} if v, ok := args.Name.V(); ok { changes["name"] = v @@ -154,10 +163,16 @@ func (r *ProjectRepository) UpdateProject(ctx context.Context, projectID uuid.UU changes["since_semester"] = ss } } - if uy, ok := args.UntilYear.V(); ok { - if us, ok := args.UntilSemester.V(); ok { - changes["until_year"] = uy - changes["until_semester"] = us + untilYear, validYear := args.UntilYear.V() + untilSemester, validSemester := args.UntilSemester.V() + if validYear == validSemester { + originUntil := domain.YearWithSemester{Year: origin.UntilYear, Semester: origin.UntilSemester} + argUntil := domain.YearWithSemester{Year: int(args.UntilYear.ValueOrZero()), Semester: int(args.UntilSemester.ValueOrZero())} + originValid := originUntil.IsValid() + // Untilが未定かどうかの状態が異なるか、Untilが異なる場合に更新 + if validYear != originValid || (validYear && argUntil != originUntil) { + changes["until_year"] = untilYear + changes["until_semester"] = untilSemester } } diff --git a/internal/infrastructure/repository/project_test.go b/internal/infrastructure/repository/project_test.go index 64e95179..144fbfe7 100644 --- a/internal/infrastructure/repository/project_test.go +++ b/internal/infrastructure/repository/project_test.go @@ -146,8 +146,12 @@ func TestProjectRepository_UpdateProject(t *testing.T) { project1.Duration.Since.Semester = int(ss) } } - if uy, ok := arg1.UntilYear.V(); ok { - if us, ok := arg1.UntilSemester.V(); ok { + uy, validYear := arg1.UntilYear.V() + us, validSemester := arg1.UntilSemester.V() + if validYear == validSemester { + originUntil, originValid := project1.Duration.Until.V() + argUntil := domain.YearWithSemester{Year: int(uy), Semester: int(us)} + if validYear != originValid || (validYear && argUntil != originUntil) { project1.Duration.Until = optional.From(domain.YearWithSemester{ Year: int(uy), Semester: int(us),