diff --git a/graphql/schema/schema.graphql b/graphql/schema/schema.graphql index f11edb46f36..251c2af838c 100644 --- a/graphql/schema/schema.graphql +++ b/graphql/schema/schema.graphql @@ -359,6 +359,12 @@ type Mutation { groupsDestroy(ids: [ID!]!): Boolean! bulkGroupUpdate(input: BulkGroupUpdateInput!): [Group!] + addGroupSubGroups(input: GroupSubGroupAddInput!): Boolean! + removeGroupSubGroups(input: GroupSubGroupRemoveInput!): Boolean! + + "Reorder sub groups within a group. Returns true if successful." + reorderSubGroups(input: ReorderSubGroupsInput!): Boolean! + tagCreate(input: TagCreateInput!): Tag tagUpdate(input: TagUpdateInput!): Tag tagDestroy(input: TagDestroyInput!): Boolean! diff --git a/graphql/schema/types/filters.graphql b/graphql/schema/types/filters.graphql index 1ca8c1fb08d..f0f84efda8c 100644 --- a/graphql/schema/types/filters.graphql +++ b/graphql/schema/types/filters.graphql @@ -261,7 +261,7 @@ input SceneFilterType { "Filter to only include scenes with this movie" movies: MultiCriterionInput @deprecated(reason: "use groups instead") "Filter to only include scenes with this group" - groups: MultiCriterionInput + groups: HierarchicalMultiCriterionInput "Filter to only include scenes with this gallery" galleries: MultiCriterionInput "Filter to only include scenes with these tags" @@ -390,6 +390,15 @@ input GroupFilterType { "Filter by last update time" updated_at: TimestampCriterionInput + "Filter by containing groups" + containing_groups: HierarchicalMultiCriterionInput + "Filter by sub groups" + sub_groups: HierarchicalMultiCriterionInput + "Filter by number of containing groups the group has" + containing_group_count: IntCriterionInput + "Filter by number of sub-groups the group has" + sub_group_count: IntCriterionInput + "Filter by related scenes that meet this criteria" scenes_filter: SceneFilterType "Filter by related studios that meet this criteria" diff --git a/graphql/schema/types/group.graphql b/graphql/schema/types/group.graphql index 15bb3556ca6..b42e4fd1fef 100644 --- a/graphql/schema/types/group.graphql +++ b/graphql/schema/types/group.graphql @@ -1,3 +1,9 @@ +"GroupDescription represents a relationship to a group with a description of the relationship" +type GroupDescription { + group: Group! + description: String +} + type Group { id: ID! name: String! @@ -15,12 +21,21 @@ type Group { created_at: Time! updated_at: Time! + containing_groups: [GroupDescription!]! + sub_groups: [GroupDescription!]! + front_image_path: String # Resolver back_image_path: String # Resolver - scene_count: Int! # Resolver + scene_count(depth: Int): Int! # Resolver + sub_group_count(depth: Int): Int! # Resolver scenes: [Scene!]! } +input GroupDescriptionInput { + group_id: ID! + description: String +} + input GroupCreateInput { name: String! aliases: String @@ -34,6 +49,10 @@ input GroupCreateInput { synopsis: String urls: [String!] tag_ids: [ID!] + + containing_groups: [GroupDescriptionInput!] + sub_groups: [GroupDescriptionInput!] + "This should be a URL or a base64 encoded data URL" front_image: String "This should be a URL or a base64 encoded data URL" @@ -53,12 +72,21 @@ input GroupUpdateInput { synopsis: String urls: [String!] tag_ids: [ID!] + + containing_groups: [GroupDescriptionInput!] + sub_groups: [GroupDescriptionInput!] + "This should be a URL or a base64 encoded data URL" front_image: String "This should be a URL or a base64 encoded data URL" back_image: String } +input BulkUpdateGroupDescriptionsInput { + groups: [GroupDescriptionInput!]! + mode: BulkUpdateIdMode! +} + input BulkGroupUpdateInput { clientMutationId: String ids: [ID!] @@ -68,13 +96,42 @@ input BulkGroupUpdateInput { director: String urls: BulkUpdateStrings tag_ids: BulkUpdateIds + + containing_groups: BulkUpdateGroupDescriptionsInput + sub_groups: BulkUpdateGroupDescriptionsInput } input GroupDestroyInput { id: ID! } +input ReorderSubGroupsInput { + "ID of the group to reorder sub groups for" + group_id: ID! + """ + IDs of the sub groups to reorder. These must be a subset of the current sub groups. + Sub groups will be inserted in this order at the insert_index + """ + sub_group_ids: [ID!]! + "The sub-group ID at which to insert the sub groups" + insert_at_id: ID! + "If true, the sub groups will be inserted after the insert_index, otherwise they will be inserted before" + insert_after: Boolean +} + type FindGroupsResultType { count: Int! groups: [Group!]! } + +input GroupSubGroupAddInput { + containing_group_id: ID! + sub_groups: [GroupDescriptionInput!]! + "The index at which to insert the sub groups. If not provided, the sub groups will be appended to the end" + insert_index: Int +} + +input GroupSubGroupRemoveInput { + containing_group_id: ID! + sub_group_ids: [ID!]! +} diff --git a/graphql/schema/types/movie.graphql b/graphql/schema/types/movie.graphql index 0723bcc4f28..845827b3f17 100644 --- a/graphql/schema/types/movie.graphql +++ b/graphql/schema/types/movie.graphql @@ -18,7 +18,7 @@ type Movie { front_image_path: String # Resolver back_image_path: String # Resolver - scene_count: Int! # Resolver + scene_count(depth: Int): Int! # Resolver scenes: [Scene!]! } diff --git a/graphql/schema/types/performer.graphql b/graphql/schema/types/performer.graphql index 8ac6c6579ad..d6f3dd832c4 100644 --- a/graphql/schema/types/performer.graphql +++ b/graphql/schema/types/performer.graphql @@ -56,7 +56,7 @@ type Performer { weight: Int created_at: Time! updated_at: Time! - groups: [Group!]! @deprecated(reason: "use groups instead") + groups: [Group!]! movies: [Movie!]! @deprecated(reason: "use groups instead") } diff --git a/internal/api/changeset_translator.go b/internal/api/changeset_translator.go index efac25087d6..1170088aac9 100644 --- a/internal/api/changeset_translator.go +++ b/internal/api/changeset_translator.go @@ -434,3 +434,64 @@ func (t changesetTranslator) updateGroupIDsBulk(value *BulkUpdateIds, field stri Mode: value.Mode, }, nil } + +func groupsDescriptionsFromGroupInput(input []*GroupDescriptionInput) ([]models.GroupIDDescription, error) { + ret := make([]models.GroupIDDescription, len(input)) + + for i, v := range input { + gID, err := strconv.Atoi(v.GroupID) + if err != nil { + return nil, fmt.Errorf("invalid group ID: %s", v.GroupID) + } + + ret[i] = models.GroupIDDescription{ + GroupID: gID, + } + if v.Description != nil { + ret[i].Description = *v.Description + } + } + + return ret, nil +} + +func (t changesetTranslator) groupIDDescriptions(value []*GroupDescriptionInput) (models.RelatedGroupDescriptions, error) { + groupsScenes, err := groupsDescriptionsFromGroupInput(value) + if err != nil { + return models.RelatedGroupDescriptions{}, err + } + + return models.NewRelatedGroupDescriptions(groupsScenes), nil +} + +func (t changesetTranslator) updateGroupIDDescriptions(value []*GroupDescriptionInput, field string) (*models.UpdateGroupDescriptions, error) { + if !t.hasField(field) { + return nil, nil + } + + groupsScenes, err := groupsDescriptionsFromGroupInput(value) + if err != nil { + return nil, err + } + + return &models.UpdateGroupDescriptions{ + Groups: groupsScenes, + Mode: models.RelationshipUpdateModeSet, + }, nil +} + +func (t changesetTranslator) updateGroupIDDescriptionsBulk(value *BulkUpdateGroupDescriptionsInput, field string) (*models.UpdateGroupDescriptions, error) { + if !t.hasField(field) || value == nil { + return nil, nil + } + + groups, err := groupsDescriptionsFromGroupInput(value.Groups) + if err != nil { + return nil, err + } + + return &models.UpdateGroupDescriptions{ + Groups: groups, + Mode: value.Mode, + }, nil +} diff --git a/internal/api/resolver.go b/internal/api/resolver.go index e5c635b9a7d..ab6eead7e5e 100644 --- a/internal/api/resolver.go +++ b/internal/api/resolver.go @@ -37,6 +37,7 @@ type Resolver struct { sceneService manager.SceneService imageService manager.ImageService galleryService manager.GalleryService + groupService manager.GroupService hookExecutor hookExecutor } diff --git a/internal/api/resolver_model_movie.go b/internal/api/resolver_model_movie.go index abbbccaf10a..04018d81fbb 100644 --- a/internal/api/resolver_model_movie.go +++ b/internal/api/resolver_model_movie.go @@ -5,7 +5,9 @@ import ( "github.com/stashapp/stash/internal/api/loaders" "github.com/stashapp/stash/internal/api/urlbuilders" + "github.com/stashapp/stash/pkg/group" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/scene" ) func (r *groupResolver) Date(ctx context.Context, obj *models.Group) (*string, error) { @@ -71,6 +73,68 @@ func (r groupResolver) Tags(ctx context.Context, obj *models.Group) (ret []*mode return ret, firstError(errs) } +func (r groupResolver) relatedGroups(ctx context.Context, rgd models.RelatedGroupDescriptions) (ret []*GroupDescription, err error) { + // rgd must be loaded + gds := rgd.List() + ids := make([]int, len(gds)) + for i, gd := range gds { + ids[i] = gd.GroupID + } + + groups, errs := loaders.From(ctx).GroupByID.LoadAll(ids) + + err = firstError(errs) + if err != nil { + return + } + + ret = make([]*GroupDescription, len(groups)) + for i, group := range groups { + ret[i] = &GroupDescription{Group: group} + d := gds[i].Description + if d != "" { + ret[i].Description = &d + } + } + + return ret, firstError(errs) +} + +func (r groupResolver) ContainingGroups(ctx context.Context, obj *models.Group) (ret []*GroupDescription, err error) { + if !obj.ContainingGroups.Loaded() { + if err := r.withReadTxn(ctx, func(ctx context.Context) error { + return obj.LoadContainingGroupIDs(ctx, r.repository.Group) + }); err != nil { + return nil, err + } + } + + return r.relatedGroups(ctx, obj.ContainingGroups) +} + +func (r groupResolver) SubGroups(ctx context.Context, obj *models.Group) (ret []*GroupDescription, err error) { + if !obj.SubGroups.Loaded() { + if err := r.withReadTxn(ctx, func(ctx context.Context) error { + return obj.LoadSubGroupIDs(ctx, r.repository.Group) + }); err != nil { + return nil, err + } + } + + return r.relatedGroups(ctx, obj.SubGroups) +} + +func (r *groupResolver) SubGroupCount(ctx context.Context, obj *models.Group, depth *int) (ret int, err error) { + if err := r.withReadTxn(ctx, func(ctx context.Context) error { + ret, err = group.CountByContainingGroupID(ctx, r.repository.Group, obj.ID, depth) + return err + }); err != nil { + return 0, err + } + + return ret, nil +} + func (r *groupResolver) FrontImagePath(ctx context.Context, obj *models.Group) (*string, error) { var hasImage bool if err := r.withReadTxn(ctx, func(ctx context.Context) error { @@ -106,9 +170,9 @@ func (r *groupResolver) BackImagePath(ctx context.Context, obj *models.Group) (* return &imagePath, nil } -func (r *groupResolver) SceneCount(ctx context.Context, obj *models.Group) (ret int, err error) { +func (r *groupResolver) SceneCount(ctx context.Context, obj *models.Group, depth *int) (ret int, err error) { if err := r.withReadTxn(ctx, func(ctx context.Context) error { - ret, err = r.repository.Scene.CountByGroupID(ctx, obj.ID) + ret, err = scene.CountByGroupID(ctx, r.repository.Scene, obj.ID, depth) return err }); err != nil { return 0, err diff --git a/internal/api/resolver_mutation_group.go b/internal/api/resolver_mutation_group.go index d455dd1058c..d75994d1497 100644 --- a/internal/api/resolver_mutation_group.go +++ b/internal/api/resolver_mutation_group.go @@ -6,6 +6,7 @@ import ( "strconv" "github.com/stashapp/stash/internal/static" + "github.com/stashapp/stash/pkg/group" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/plugin/hook" "github.com/stashapp/stash/pkg/sliceutil/stringslice" @@ -43,6 +44,16 @@ func groupFromGroupCreateInput(ctx context.Context, input GroupCreateInput) (*mo return nil, fmt.Errorf("converting tag ids: %w", err) } + newGroup.ContainingGroups, err = translator.groupIDDescriptions(input.ContainingGroups) + if err != nil { + return nil, fmt.Errorf("converting containing group ids: %w", err) + } + + newGroup.SubGroups, err = translator.groupIDDescriptions(input.SubGroups) + if err != nil { + return nil, fmt.Errorf("converting containing group ids: %w", err) + } + if input.Urls != nil { newGroup.URLs = models.NewRelatedStrings(input.Urls) } @@ -82,26 +93,10 @@ func (r *mutationResolver) GroupCreate(ctx context.Context, input GroupCreateInp // Start the transaction and save the group if err := r.withTxn(ctx, func(ctx context.Context) error { - qb := r.repository.Group - - err = qb.Create(ctx, newGroup) - if err != nil { + if err = r.groupService.Create(ctx, newGroup, frontimageData, backimageData); err != nil { return err } - // update image table - if len(frontimageData) > 0 { - if err := qb.UpdateFrontImage(ctx, newGroup.ID, frontimageData); err != nil { - return err - } - } - - if len(backimageData) > 0 { - if err := qb.UpdateBackImage(ctx, newGroup.ID, backimageData); err != nil { - return err - } - } - return nil }); err != nil { return nil, err @@ -141,6 +136,18 @@ func groupPartialFromGroupUpdateInput(translator changesetTranslator, input Grou return } + updatedGroup.ContainingGroups, err = translator.updateGroupIDDescriptions(input.ContainingGroups, "containing_groups") + if err != nil { + err = fmt.Errorf("converting containing group ids: %w", err) + return + } + + updatedGroup.SubGroups, err = translator.updateGroupIDDescriptions(input.SubGroups, "sub_groups") + if err != nil { + err = fmt.Errorf("converting containing group ids: %w", err) + return + } + updatedGroup.URLs = translator.updateStrings(input.Urls, "urls") return updatedGroup, nil @@ -179,26 +186,20 @@ func (r *mutationResolver) GroupUpdate(ctx context.Context, input GroupUpdateInp } } - // Start the transaction and save the group - var group *models.Group if err := r.withTxn(ctx, func(ctx context.Context) error { - qb := r.repository.Group - group, err = qb.UpdatePartial(ctx, groupID, updatedGroup) - if err != nil { - return err + frontImage := group.ImageInput{ + Image: frontimageData, + Set: frontImageIncluded, } - // update image table - if frontImageIncluded { - if err := qb.UpdateFrontImage(ctx, group.ID, frontimageData); err != nil { - return err - } + backImage := group.ImageInput{ + Image: backimageData, + Set: backImageIncluded, } - if backImageIncluded { - if err := qb.UpdateBackImage(ctx, group.ID, backimageData); err != nil { - return err - } + _, err = r.groupService.UpdatePartial(ctx, groupID, updatedGroup, frontImage, backImage) + if err != nil { + return err } return nil @@ -207,9 +208,9 @@ func (r *mutationResolver) GroupUpdate(ctx context.Context, input GroupUpdateInp } // for backwards compatibility - run both movie and group hooks - r.hookExecutor.ExecutePostHooks(ctx, group.ID, hook.GroupUpdatePost, input, translator.getFields()) - r.hookExecutor.ExecutePostHooks(ctx, group.ID, hook.MovieUpdatePost, input, translator.getFields()) - return r.getGroup(ctx, group.ID) + r.hookExecutor.ExecutePostHooks(ctx, groupID, hook.GroupUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, groupID, hook.MovieUpdatePost, input, translator.getFields()) + return r.getGroup(ctx, groupID) } func groupPartialFromBulkGroupUpdateInput(translator changesetTranslator, input BulkGroupUpdateInput) (ret models.GroupPartial, err error) { @@ -230,6 +231,18 @@ func groupPartialFromBulkGroupUpdateInput(translator changesetTranslator, input return } + updatedGroup.ContainingGroups, err = translator.updateGroupIDDescriptionsBulk(input.ContainingGroups, "containing_groups") + if err != nil { + err = fmt.Errorf("converting containing group ids: %w", err) + return + } + + updatedGroup.SubGroups, err = translator.updateGroupIDDescriptionsBulk(input.SubGroups, "sub_groups") + if err != nil { + err = fmt.Errorf("converting containing group ids: %w", err) + return + } + updatedGroup.URLs = translator.optionalURLsBulk(input.Urls, nil) return updatedGroup, nil @@ -254,10 +267,8 @@ func (r *mutationResolver) BulkGroupUpdate(ctx context.Context, input BulkGroupU ret := []*models.Group{} if err := r.withTxn(ctx, func(ctx context.Context) error { - qb := r.repository.Group - for _, groupID := range groupIDs { - group, err := qb.UpdatePartial(ctx, groupID, updatedGroup) + group, err := r.groupService.UpdatePartial(ctx, groupID, updatedGroup, group.ImageInput{}, group.ImageInput{}) if err != nil { return err } @@ -333,3 +344,70 @@ func (r *mutationResolver) GroupsDestroy(ctx context.Context, groupIDs []string) return true, nil } + +func (r *mutationResolver) AddGroupSubGroups(ctx context.Context, input GroupSubGroupAddInput) (bool, error) { + groupID, err := strconv.Atoi(input.ContainingGroupID) + if err != nil { + return false, fmt.Errorf("converting group id: %w", err) + } + + subGroups, err := groupsDescriptionsFromGroupInput(input.SubGroups) + if err != nil { + return false, fmt.Errorf("converting sub group ids: %w", err) + } + + if err := r.withTxn(ctx, func(ctx context.Context) error { + return r.groupService.AddSubGroups(ctx, groupID, subGroups, input.InsertIndex) + }); err != nil { + return false, err + } + + return true, nil +} + +func (r *mutationResolver) RemoveGroupSubGroups(ctx context.Context, input GroupSubGroupRemoveInput) (bool, error) { + groupID, err := strconv.Atoi(input.ContainingGroupID) + if err != nil { + return false, fmt.Errorf("converting group id: %w", err) + } + + subGroupIDs, err := stringslice.StringSliceToIntSlice(input.SubGroupIds) + if err != nil { + return false, fmt.Errorf("converting sub group ids: %w", err) + } + + if err := r.withTxn(ctx, func(ctx context.Context) error { + return r.groupService.RemoveSubGroups(ctx, groupID, subGroupIDs) + }); err != nil { + return false, err + } + + return true, nil +} + +func (r *mutationResolver) ReorderSubGroups(ctx context.Context, input ReorderSubGroupsInput) (bool, error) { + groupID, err := strconv.Atoi(input.GroupID) + if err != nil { + return false, fmt.Errorf("converting group id: %w", err) + } + + subGroupIDs, err := stringslice.StringSliceToIntSlice(input.SubGroupIds) + if err != nil { + return false, fmt.Errorf("converting sub group ids: %w", err) + } + + insertPointID, err := strconv.Atoi(input.InsertAtID) + if err != nil { + return false, fmt.Errorf("converting insert at id: %w", err) + } + + insertAfter := utils.IsTrue(input.InsertAfter) + + if err := r.withTxn(ctx, func(ctx context.Context) error { + return r.groupService.ReorderSubGroups(ctx, groupID, subGroupIDs, insertPointID, insertAfter) + }); err != nil { + return false, err + } + + return true, nil +} diff --git a/internal/api/server.go b/internal/api/server.go index b32ee04a027..63a81da7c2e 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -158,11 +158,13 @@ func Initialize() (*Server, error) { sceneService := mgr.SceneService imageService := mgr.ImageService galleryService := mgr.GalleryService + groupService := mgr.GroupService resolver := &Resolver{ repository: repo, sceneService: sceneService, imageService: imageService, galleryService: galleryService, + groupService: groupService, hookExecutor: pluginCache, } diff --git a/internal/dlna/cds.go b/internal/dlna/cds.go index 531fc1cb55c..a38e0e55bed 100644 --- a/internal/dlna/cds.go +++ b/internal/dlna/cds.go @@ -682,7 +682,7 @@ func (me *contentDirectoryService) getGroups() []interface{} { func (me *contentDirectoryService) getGroupScenes(paths []string, host string) []interface{} { sceneFilter := &models.SceneFilterType{ - Groups: &models.MultiCriterionInput{ + Groups: &models.HierarchicalMultiCriterionInput{ Modifier: models.CriterionModifierIncludes, Value: []string{paths[0]}, }, diff --git a/internal/manager/init.go b/internal/manager/init.go index 347d08a153e..020ba944d40 100644 --- a/internal/manager/init.go +++ b/internal/manager/init.go @@ -17,6 +17,7 @@ import ( "github.com/stashapp/stash/pkg/ffmpeg" "github.com/stashapp/stash/pkg/fsutil" "github.com/stashapp/stash/pkg/gallery" + "github.com/stashapp/stash/pkg/group" "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/job" "github.com/stashapp/stash/pkg/logger" @@ -67,6 +68,10 @@ func Initialize(cfg *config.Config, l *log.Logger) (*Manager, error) { Folder: db.Folder, } + groupService := &group.Service{ + Repository: db.Group, + } + sceneServer := &SceneServer{ TxnManager: repo.TxnManager, SceneCoverGetter: repo.Scene, @@ -99,6 +104,7 @@ func Initialize(cfg *config.Config, l *log.Logger) (*Manager, error) { SceneService: sceneService, ImageService: imageService, GalleryService: galleryService, + GroupService: groupService, scanSubs: &subscriptionManager{}, } diff --git a/internal/manager/manager.go b/internal/manager/manager.go index 397503930dc..ffba184d2bd 100644 --- a/internal/manager/manager.go +++ b/internal/manager/manager.go @@ -66,6 +66,7 @@ type Manager struct { SceneService SceneService ImageService ImageService GalleryService GalleryService + GroupService GroupService scanSubs *subscriptionManager } diff --git a/internal/manager/repository.go b/internal/manager/repository.go index 766f8039f85..13e1e8ae81b 100644 --- a/internal/manager/repository.go +++ b/internal/manager/repository.go @@ -3,6 +3,7 @@ package manager import ( "context" + "github.com/stashapp/stash/pkg/group" "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scene" @@ -33,3 +34,12 @@ type GalleryService interface { Updated(ctx context.Context, galleryID int) error } + +type GroupService interface { + Create(ctx context.Context, group *models.Group, frontimageData []byte, backimageData []byte) error + UpdatePartial(ctx context.Context, id int, updatedGroup models.GroupPartial, frontImage group.ImageInput, backImage group.ImageInput) (*models.Group, error) + + AddSubGroups(ctx context.Context, groupID int, subGroups []models.GroupIDDescription, insertIndex *int) error + RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error + ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertPointID int, insertAfter bool) error +} diff --git a/internal/manager/task_export.go b/internal/manager/task_export.go index ecbcf593af5..19abba2158d 100644 --- a/internal/manager/task_export.go +++ b/internal/manager/task_export.go @@ -1134,6 +1134,10 @@ func (t *ExportTask) exportGroup(ctx context.Context, wg *sync.WaitGroup, jobCha logger.Errorf("[groups] <%s> error getting group urls: %v", m.Name, err) continue } + if err := m.LoadSubGroupIDs(ctx, r.Group); err != nil { + logger.Errorf("[groups] <%s> error getting group sub-groups: %v", m.Name, err) + continue + } newGroupJSON, err := group.ToJSON(ctx, groupReader, studioReader, m) @@ -1150,6 +1154,25 @@ func (t *ExportTask) exportGroup(ctx context.Context, wg *sync.WaitGroup, jobCha newGroupJSON.Tags = tag.GetNames(tags) + subGroups := m.SubGroups.List() + if err := func() error { + for _, sg := range subGroups { + subGroup, err := groupReader.Find(ctx, sg.GroupID) + if err != nil { + return fmt.Errorf("error getting sub group: %v", err) + } + + newGroupJSON.SubGroups = append(newGroupJSON.SubGroups, jsonschema.SubGroupDescription{ + // TODO - this won't be unique + Group: subGroup.Name, + Description: sg.Description, + }) + } + return nil + }(); err != nil { + logger.Errorf("[groups] <%s> %v", m.Name, err) + } + if t.includeDependencies { if m.StudioID != nil { t.studios.IDs = sliceutil.AppendUnique(t.studios.IDs, *m.StudioID) diff --git a/internal/manager/task_import.go b/internal/manager/task_import.go index ae9a5865765..87185c66183 100644 --- a/internal/manager/task_import.go +++ b/internal/manager/task_import.go @@ -327,6 +327,7 @@ func (t *ImportTask) importStudio(ctx context.Context, studioJSON *jsonschema.St func (t *ImportTask) ImportGroups(ctx context.Context) { logger.Info("[groups] importing") + pendingSubs := make(map[string][]*jsonschema.Group) path := t.json.json.Groups files, err := os.ReadDir(path) @@ -351,24 +352,72 @@ func (t *ImportTask) ImportGroups(ctx context.Context) { logger.Progressf("[groups] %d of %d", index, len(files)) if err := r.WithTxn(ctx, func(ctx context.Context) error { - groupImporter := &group.Importer{ - ReaderWriter: r.Group, - StudioWriter: r.Studio, - TagWriter: r.Tag, - Input: *groupJSON, - MissingRefBehaviour: t.MissingRefBehaviour, + return t.importGroup(ctx, groupJSON, pendingSubs, false) + }); err != nil { + var subError group.SubGroupNotExistError + if errors.As(err, &subError) { + missingSub := subError.MissingSubGroup() + pendingSubs[missingSub] = append(pendingSubs[missingSub], groupJSON) + continue } - return performImport(ctx, groupImporter, t.DuplicateBehaviour) - }); err != nil { - logger.Errorf("[groups] <%s> import failed: %v", fi.Name(), err) + logger.Errorf("[groups] <%s> failed to import: %v", fi.Name(), err) continue } } + for _, s := range pendingSubs { + for _, orphanGroupJSON := range s { + if err := r.WithTxn(ctx, func(ctx context.Context) error { + return t.importGroup(ctx, orphanGroupJSON, nil, true) + }); err != nil { + logger.Errorf("[groups] <%s> failed to create: %v", orphanGroupJSON.Name, err) + continue + } + } + } + logger.Info("[groups] import complete") } +func (t *ImportTask) importGroup(ctx context.Context, groupJSON *jsonschema.Group, pendingSub map[string][]*jsonschema.Group, fail bool) error { + r := t.repository + + importer := &group.Importer{ + ReaderWriter: r.Group, + StudioWriter: r.Studio, + TagWriter: r.Tag, + Input: *groupJSON, + MissingRefBehaviour: t.MissingRefBehaviour, + } + + // first phase: return error if parent does not exist + if !fail { + importer.MissingRefBehaviour = models.ImportMissingRefEnumFail + } + + if err := performImport(ctx, importer, t.DuplicateBehaviour); err != nil { + return err + } + + for _, containingGroupJSON := range pendingSub[groupJSON.Name] { + if err := t.importGroup(ctx, containingGroupJSON, pendingSub, fail); err != nil { + var subError group.SubGroupNotExistError + if errors.As(err, &subError) { + missingSub := subError.MissingSubGroup() + pendingSub[missingSub] = append(pendingSub[missingSub], containingGroupJSON) + continue + } + + return fmt.Errorf("failed to create containing group <%s>: %v", containingGroupJSON.Name, err) + } + } + + delete(pendingSub, groupJSON.Name) + + return nil +} + func (t *ImportTask) ImportFiles(ctx context.Context) { logger.Info("[files] importing") diff --git a/pkg/group/create.go b/pkg/group/create.go new file mode 100644 index 00000000000..56d6b7a4ed4 --- /dev/null +++ b/pkg/group/create.go @@ -0,0 +1,41 @@ +package group + +import ( + "context" + "errors" + + "github.com/stashapp/stash/pkg/models" +) + +var ( + ErrEmptyName = errors.New("name cannot be empty") + ErrHierarchyLoop = errors.New("a group cannot be contained by one of its subgroups") +) + +func (s *Service) Create(ctx context.Context, group *models.Group, frontimageData []byte, backimageData []byte) error { + r := s.Repository + + if err := s.validateCreate(ctx, group); err != nil { + return err + } + + err := r.Create(ctx, group) + if err != nil { + return err + } + + // update image table + if len(frontimageData) > 0 { + if err := r.UpdateFrontImage(ctx, group.ID, frontimageData); err != nil { + return err + } + } + + if len(backimageData) > 0 { + if err := r.UpdateBackImage(ctx, group.ID, backimageData); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/group/import.go b/pkg/group/import.go index 4bf038c8776..589e75df30d 100644 --- a/pkg/group/import.go +++ b/pkg/group/import.go @@ -16,6 +16,18 @@ type ImporterReaderWriter interface { FindByName(ctx context.Context, name string, nocase bool) (*models.Group, error) } +type SubGroupNotExistError struct { + missingSubGroup string +} + +func (e SubGroupNotExistError) Error() string { + return fmt.Sprintf("sub group <%s> does not exist", e.missingSubGroup) +} + +func (e SubGroupNotExistError) MissingSubGroup() string { + return e.missingSubGroup +} + type Importer struct { ReaderWriter ImporterReaderWriter StudioWriter models.StudioFinderCreator @@ -202,6 +214,22 @@ func (i *Importer) createStudio(ctx context.Context, name string) (int, error) { } func (i *Importer) PostImport(ctx context.Context, id int) error { + subGroups, err := i.getSubGroups(ctx) + if err != nil { + return err + } + + if len(subGroups) > 0 { + if _, err := i.ReaderWriter.UpdatePartial(ctx, id, models.GroupPartial{ + SubGroups: &models.UpdateGroupDescriptions{ + Groups: subGroups, + Mode: models.RelationshipUpdateModeSet, + }, + }); err != nil { + return fmt.Errorf("error setting parents: %v", err) + } + } + if len(i.frontImageData) > 0 { if err := i.ReaderWriter.UpdateFrontImage(ctx, id, i.frontImageData); err != nil { return fmt.Errorf("error setting group front image: %v", err) @@ -256,3 +284,53 @@ func (i *Importer) Update(ctx context.Context, id int) error { return nil } + +func (i *Importer) getSubGroups(ctx context.Context) ([]models.GroupIDDescription, error) { + var subGroups []models.GroupIDDescription + for _, subGroup := range i.Input.SubGroups { + group, err := i.ReaderWriter.FindByName(ctx, subGroup.Group, false) + if err != nil { + return nil, fmt.Errorf("error finding parent by name: %v", err) + } + + if group == nil { + if i.MissingRefBehaviour == models.ImportMissingRefEnumFail { + return nil, SubGroupNotExistError{missingSubGroup: subGroup.Group} + } + + if i.MissingRefBehaviour == models.ImportMissingRefEnumIgnore { + continue + } + + if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { + parentID, err := i.createSubGroup(ctx, subGroup.Group) + if err != nil { + return nil, err + } + subGroups = append(subGroups, models.GroupIDDescription{ + GroupID: parentID, + Description: subGroup.Description, + }) + } + } else { + subGroups = append(subGroups, models.GroupIDDescription{ + GroupID: group.ID, + Description: subGroup.Description, + }) + } + } + + return subGroups, nil +} + +func (i *Importer) createSubGroup(ctx context.Context, name string) (int, error) { + newGroup := models.NewGroup() + newGroup.Name = name + + err := i.ReaderWriter.Create(ctx, &newGroup) + if err != nil { + return 0, err + } + + return newGroup.ID, nil +} diff --git a/pkg/group/query.go b/pkg/group/query.go index bc0753b0055..b3adafaf523 100644 --- a/pkg/group/query.go +++ b/pkg/group/query.go @@ -30,3 +30,15 @@ func CountByTagID(ctx context.Context, r models.GroupQueryer, id int, depth *int return r.QueryCount(ctx, filter, nil) } + +func CountByContainingGroupID(ctx context.Context, r models.GroupQueryer, id int, depth *int) (int, error) { + filter := &models.GroupFilterType{ + ContainingGroups: &models.HierarchicalMultiCriterionInput{ + Value: []string{strconv.Itoa(id)}, + Modifier: models.CriterionModifierIncludes, + Depth: depth, + }, + } + + return r.QueryCount(ctx, filter, nil) +} diff --git a/pkg/group/reorder.go b/pkg/group/reorder.go new file mode 100644 index 00000000000..b4afd1b0968 --- /dev/null +++ b/pkg/group/reorder.go @@ -0,0 +1,33 @@ +package group + +import ( + "context" + "errors" + + "github.com/stashapp/stash/pkg/models" +) + +var ErrInvalidInsertIndex = errors.New("invalid insert index") + +func (s *Service) ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertPointID int, insertAfter bool) error { + // get the group + existing, err := s.Repository.Find(ctx, groupID) + if err != nil { + return err + } + + // ensure it exists + if existing == nil { + return models.ErrNotFound + } + + // TODO - ensure the subgroups exist in the group + + // ensure the insert index is valid + if insertPointID < 0 { + return ErrInvalidInsertIndex + } + + // reorder the subgroups + return s.Repository.ReorderSubGroups(ctx, groupID, subGroupIDs, insertPointID, insertAfter) +} diff --git a/pkg/group/service.go b/pkg/group/service.go new file mode 100644 index 00000000000..ff6e0354184 --- /dev/null +++ b/pkg/group/service.go @@ -0,0 +1,46 @@ +package group + +import ( + "context" + + "github.com/stashapp/stash/pkg/models" +) + +type CreatorUpdater interface { + models.GroupGetter + models.GroupCreator + models.GroupUpdater + + models.ContainingGroupLoader + models.SubGroupLoader + + AnscestorFinder + SubGroupIDFinder + SubGroupAdder + SubGroupRemover + SubGroupReorderer +} + +type AnscestorFinder interface { + FindInAncestors(ctx context.Context, ascestorIDs []int, ids []int) ([]int, error) +} + +type SubGroupIDFinder interface { + FindSubGroupIDs(ctx context.Context, containingID int, ids []int) ([]int, error) +} + +type SubGroupAdder interface { + AddSubGroups(ctx context.Context, groupID int, subGroups []models.GroupIDDescription, insertIndex *int) error +} + +type SubGroupRemover interface { + RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error +} + +type SubGroupReorderer interface { + ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertID int, insertAfter bool) error +} + +type Service struct { + Repository CreatorUpdater +} diff --git a/pkg/group/update.go b/pkg/group/update.go new file mode 100644 index 00000000000..d0bc9602add --- /dev/null +++ b/pkg/group/update.go @@ -0,0 +1,112 @@ +package group + +import ( + "context" + "fmt" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sliceutil" +) + +type SubGroupAlreadyInGroupError struct { + GroupIDs []int +} + +func (e *SubGroupAlreadyInGroupError) Error() string { + return fmt.Sprintf("subgroups with IDs %v already in group", e.GroupIDs) +} + +type ImageInput struct { + Image []byte + Set bool +} + +func (s *Service) UpdatePartial(ctx context.Context, id int, updatedGroup models.GroupPartial, frontImage ImageInput, backImage ImageInput) (*models.Group, error) { + if err := s.validateUpdate(ctx, id, updatedGroup); err != nil { + return nil, err + } + + r := s.Repository + + group, err := r.UpdatePartial(ctx, id, updatedGroup) + if err != nil { + return nil, err + } + + // update image table + if frontImage.Set { + if err := r.UpdateFrontImage(ctx, id, frontImage.Image); err != nil { + return nil, err + } + } + + if backImage.Set { + if err := r.UpdateBackImage(ctx, id, backImage.Image); err != nil { + return nil, err + } + } + + return group, nil +} + +func (s *Service) AddSubGroups(ctx context.Context, groupID int, subGroups []models.GroupIDDescription, insertIndex *int) error { + // get the group + existing, err := s.Repository.Find(ctx, groupID) + if err != nil { + return err + } + + // ensure it exists + if existing == nil { + return models.ErrNotFound + } + + // ensure the subgroups aren't already sub-groups of the group + subGroupIDs := sliceutil.Map(subGroups, func(sg models.GroupIDDescription) int { + return sg.GroupID + }) + + existingSubGroupIDs, err := s.Repository.FindSubGroupIDs(ctx, groupID, subGroupIDs) + if err != nil { + return err + } + + if len(existingSubGroupIDs) > 0 { + return &SubGroupAlreadyInGroupError{ + GroupIDs: existingSubGroupIDs, + } + } + + // validate the hierarchy + d := &models.UpdateGroupDescriptions{ + Groups: subGroups, + Mode: models.RelationshipUpdateModeAdd, + } + if err := s.validateUpdateGroupHierarchy(ctx, existing, nil, d); err != nil { + return err + } + + // validate insert index + if insertIndex != nil && *insertIndex < 0 { + return ErrInvalidInsertIndex + } + + // add the subgroups + return s.Repository.AddSubGroups(ctx, groupID, subGroups, insertIndex) +} + +func (s *Service) RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error { + // get the group + existing, err := s.Repository.Find(ctx, groupID) + if err != nil { + return err + } + + // ensure it exists + if existing == nil { + return models.ErrNotFound + } + + // add the subgroups + return s.Repository.RemoveSubGroups(ctx, groupID, subGroupIDs) +} diff --git a/pkg/group/validate.go b/pkg/group/validate.go new file mode 100644 index 00000000000..723b9f6997a --- /dev/null +++ b/pkg/group/validate.go @@ -0,0 +1,117 @@ +package group + +import ( + "context" + "strings" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sliceutil" +) + +func (s *Service) validateCreate(ctx context.Context, group *models.Group) error { + if err := validateName(group.Name); err != nil { + return err + } + + containingIDs := group.ContainingGroups.IDs() + subIDs := group.SubGroups.IDs() + + if err := s.validateGroupHierarchy(ctx, containingIDs, subIDs); err != nil { + return err + } + + return nil +} + +func (s *Service) validateUpdate(ctx context.Context, id int, partial models.GroupPartial) error { + // get the existing group - ensure it exists + existing, err := s.Repository.Find(ctx, id) + if err != nil { + return err + } + + if existing == nil { + return models.ErrNotFound + } + + if partial.Name.Set { + if err := validateName(partial.Name.Value); err != nil { + return err + } + } + + if err := s.validateUpdateGroupHierarchy(ctx, existing, partial.ContainingGroups, partial.SubGroups); err != nil { + return err + } + + return nil +} + +func validateName(n string) error { + // ensure name is not empty + if strings.TrimSpace(n) == "" { + return ErrEmptyName + } + + return nil +} + +func (s *Service) validateGroupHierarchy(ctx context.Context, containingIDs []int, subIDs []int) error { + // only need to validate if both are non-empty + if len(containingIDs) == 0 || len(subIDs) == 0 { + return nil + } + + // ensure none of the containing groups are in the sub groups + found, err := s.Repository.FindInAncestors(ctx, containingIDs, subIDs) + if err != nil { + return err + } + + if len(found) > 0 { + return ErrHierarchyLoop + } + + return nil +} + +func (s *Service) validateUpdateGroupHierarchy(ctx context.Context, existing *models.Group, containingGroups *models.UpdateGroupDescriptions, subGroups *models.UpdateGroupDescriptions) error { + // no need to validate if there are no changes + if containingGroups == nil && subGroups == nil { + return nil + } + + if err := existing.LoadContainingGroupIDs(ctx, s.Repository); err != nil { + return err + } + existingContainingGroups := existing.ContainingGroups.List() + + if err := existing.LoadSubGroupIDs(ctx, s.Repository); err != nil { + return err + } + existingSubGroups := existing.SubGroups.List() + + effectiveContainingGroups := existingContainingGroups + if containingGroups != nil { + effectiveContainingGroups = containingGroups.Apply(existingContainingGroups) + } + + effectiveSubGroups := existingSubGroups + if subGroups != nil { + effectiveSubGroups = subGroups.Apply(existingSubGroups) + } + + containingIDs := idsFromGroupDescriptions(effectiveContainingGroups) + subIDs := idsFromGroupDescriptions(effectiveSubGroups) + + // ensure we haven't set the group as a subgroup of itself + if sliceutil.Contains(containingIDs, existing.ID) || sliceutil.Contains(subIDs, existing.ID) { + return ErrHierarchyLoop + } + + return s.validateGroupHierarchy(ctx, containingIDs, subIDs) +} + +func idsFromGroupDescriptions(v []models.GroupIDDescription) []int { + return sliceutil.Map(v, func(g models.GroupIDDescription) int { return g.GroupID }) +} diff --git a/pkg/models/group.go b/pkg/models/group.go index db7badccc90..6afda3f4890 100644 --- a/pkg/models/group.go +++ b/pkg/models/group.go @@ -23,6 +23,14 @@ type GroupFilterType struct { TagCount *IntCriterionInput `json:"tag_count"` // Filter by date Date *DateCriterionInput `json:"date"` + // Filter by containing groups + ContainingGroups *HierarchicalMultiCriterionInput `json:"containing_groups"` + // Filter by sub groups + SubGroups *HierarchicalMultiCriterionInput `json:"sub_groups"` + // Filter by number of containing groups the group has + ContainingGroupCount *IntCriterionInput `json:"containing_group_count"` + // Filter by number of sub-groups the group has + SubGroupCount *IntCriterionInput `json:"sub_group_count"` // Filter by related scenes that meet this criteria ScenesFilter *SceneFilterType `json:"scenes_filter"` // Filter by related studios that meet this criteria diff --git a/pkg/models/jsonschema/group.go b/pkg/models/jsonschema/group.go index fcf1ffe60a0..b284dab6e77 100644 --- a/pkg/models/jsonschema/group.go +++ b/pkg/models/jsonschema/group.go @@ -11,21 +11,27 @@ import ( "github.com/stashapp/stash/pkg/models/json" ) +type SubGroupDescription struct { + Group string `json:"name,omitempty"` + Description string `json:"description,omitempty"` +} + type Group struct { - Name string `json:"name,omitempty"` - Aliases string `json:"aliases,omitempty"` - Duration int `json:"duration,omitempty"` - Date string `json:"date,omitempty"` - Rating int `json:"rating,omitempty"` - Director string `json:"director,omitempty"` - Synopsis string `json:"synopsis,omitempty"` - FrontImage string `json:"front_image,omitempty"` - BackImage string `json:"back_image,omitempty"` - URLs []string `json:"urls,omitempty"` - Studio string `json:"studio,omitempty"` - Tags []string `json:"tags,omitempty"` - CreatedAt json.JSONTime `json:"created_at,omitempty"` - UpdatedAt json.JSONTime `json:"updated_at,omitempty"` + Name string `json:"name,omitempty"` + Aliases string `json:"aliases,omitempty"` + Duration int `json:"duration,omitempty"` + Date string `json:"date,omitempty"` + Rating int `json:"rating,omitempty"` + Director string `json:"director,omitempty"` + Synopsis string `json:"synopsis,omitempty"` + FrontImage string `json:"front_image,omitempty"` + BackImage string `json:"back_image,omitempty"` + URLs []string `json:"urls,omitempty"` + Studio string `json:"studio,omitempty"` + Tags []string `json:"tags,omitempty"` + SubGroups []SubGroupDescription `json:"sub_groups,omitempty"` + CreatedAt json.JSONTime `json:"created_at,omitempty"` + UpdatedAt json.JSONTime `json:"updated_at,omitempty"` // deprecated - for import only URL string `json:"url,omitempty"` diff --git a/pkg/models/mocks/GroupReaderWriter.go b/pkg/models/mocks/GroupReaderWriter.go index 5e3a2644ca8..dc745d09487 100644 --- a/pkg/models/mocks/GroupReaderWriter.go +++ b/pkg/models/mocks/GroupReaderWriter.go @@ -289,6 +289,29 @@ func (_m *GroupReaderWriter) GetBackImage(ctx context.Context, groupID int) ([]b return r0, r1 } +// GetContainingGroupDescriptions provides a mock function with given fields: ctx, id +func (_m *GroupReaderWriter) GetContainingGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) { + ret := _m.Called(ctx, id) + + var r0 []models.GroupIDDescription + if rf, ok := ret.Get(0).(func(context.Context, int) []models.GroupIDDescription); ok { + r0 = rf(ctx, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]models.GroupIDDescription) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetFrontImage provides a mock function with given fields: ctx, groupID func (_m *GroupReaderWriter) GetFrontImage(ctx context.Context, groupID int) ([]byte, error) { ret := _m.Called(ctx, groupID) @@ -312,6 +335,29 @@ func (_m *GroupReaderWriter) GetFrontImage(ctx context.Context, groupID int) ([] return r0, r1 } +// GetSubGroupDescriptions provides a mock function with given fields: ctx, id +func (_m *GroupReaderWriter) GetSubGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) { + ret := _m.Called(ctx, id) + + var r0 []models.GroupIDDescription + if rf, ok := ret.Get(0).(func(context.Context, int) []models.GroupIDDescription); ok { + r0 = rf(ctx, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]models.GroupIDDescription) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetTagIDs provides a mock function with given fields: ctx, relatedID func (_m *GroupReaderWriter) GetTagIDs(ctx context.Context, relatedID int) ([]int, error) { ret := _m.Called(ctx, relatedID) diff --git a/pkg/models/model_group.go b/pkg/models/model_group.go index af3ac56c68c..82c71996ae8 100644 --- a/pkg/models/model_group.go +++ b/pkg/models/model_group.go @@ -21,6 +21,9 @@ type Group struct { URLs RelatedStrings `json:"urls"` TagIDs RelatedIDs `json:"tag_ids"` + + ContainingGroups RelatedGroupDescriptions `json:"containing_groups"` + SubGroups RelatedGroupDescriptions `json:"sub_groups"` } func NewGroup() Group { @@ -43,20 +46,34 @@ func (m *Group) LoadTagIDs(ctx context.Context, l TagIDLoader) error { }) } +func (m *Group) LoadContainingGroupIDs(ctx context.Context, l ContainingGroupLoader) error { + return m.ContainingGroups.load(func() ([]GroupIDDescription, error) { + return l.GetContainingGroupDescriptions(ctx, m.ID) + }) +} + +func (m *Group) LoadSubGroupIDs(ctx context.Context, l SubGroupLoader) error { + return m.SubGroups.load(func() ([]GroupIDDescription, error) { + return l.GetSubGroupDescriptions(ctx, m.ID) + }) +} + type GroupPartial struct { Name OptionalString Aliases OptionalString Duration OptionalInt Date OptionalDate // Rating expressed in 1-100 scale - Rating OptionalInt - StudioID OptionalInt - Director OptionalString - Synopsis OptionalString - URLs *UpdateStrings - TagIDs *UpdateIDs - CreatedAt OptionalTime - UpdatedAt OptionalTime + Rating OptionalInt + StudioID OptionalInt + Director OptionalString + Synopsis OptionalString + URLs *UpdateStrings + TagIDs *UpdateIDs + ContainingGroups *UpdateGroupDescriptions + SubGroups *UpdateGroupDescriptions + CreatedAt OptionalTime + UpdatedAt OptionalTime } func NewGroupPartial() GroupPartial { diff --git a/pkg/models/model_joins.go b/pkg/models/model_joins.go index 189c2d7721f..7b7cae3e46a 100644 --- a/pkg/models/model_joins.go +++ b/pkg/models/model_joins.go @@ -68,3 +68,8 @@ func GroupsScenesFromInput(input []SceneMovieInput) ([]GroupsScenes, error) { return ret, nil } + +type GroupIDDescription struct { + GroupID int `json:"group_id"` + Description string `json:"description"` +} diff --git a/pkg/models/relationships.go b/pkg/models/relationships.go index 81528c26e95..5495f858b17 100644 --- a/pkg/models/relationships.go +++ b/pkg/models/relationships.go @@ -2,6 +2,8 @@ package models import ( "context" + + "github.com/stashapp/stash/pkg/sliceutil" ) type SceneIDLoader interface { @@ -37,6 +39,14 @@ type SceneGroupLoader interface { GetGroups(ctx context.Context, id int) ([]GroupsScenes, error) } +type ContainingGroupLoader interface { + GetContainingGroupDescriptions(ctx context.Context, id int) ([]GroupIDDescription, error) +} + +type SubGroupLoader interface { + GetSubGroupDescriptions(ctx context.Context, id int) ([]GroupIDDescription, error) +} + type StashIDLoader interface { GetStashIDs(ctx context.Context, relatedID int) ([]StashID, error) } @@ -185,6 +195,82 @@ func (r *RelatedGroups) load(fn func() ([]GroupsScenes, error)) error { return nil } +type RelatedGroupDescriptions struct { + list []GroupIDDescription +} + +// NewRelatedGroups returns a loaded RelateGroups object with the provided groups. +// Loaded will return true when called on the returned object if the provided slice is not nil. +func NewRelatedGroupDescriptions(list []GroupIDDescription) RelatedGroupDescriptions { + return RelatedGroupDescriptions{ + list: list, + } +} + +// Loaded returns true if the relationship has been loaded. +func (r RelatedGroupDescriptions) Loaded() bool { + return r.list != nil +} + +func (r RelatedGroupDescriptions) mustLoaded() { + if !r.Loaded() { + panic("list has not been loaded") + } +} + +// List returns the related Groups. Panics if the relationship has not been loaded. +func (r RelatedGroupDescriptions) List() []GroupIDDescription { + r.mustLoaded() + + return r.list +} + +// List returns the related Groups. Panics if the relationship has not been loaded. +func (r RelatedGroupDescriptions) IDs() []int { + r.mustLoaded() + + return sliceutil.Map(r.list, func(d GroupIDDescription) int { return d.GroupID }) +} + +// Add adds the provided ids to the list. Panics if the relationship has not been loaded. +func (r *RelatedGroupDescriptions) Add(groups ...GroupIDDescription) { + r.mustLoaded() + + r.list = append(r.list, groups...) +} + +// ForID returns the GroupsScenes object for the given group ID. Returns nil if not found. +func (r *RelatedGroupDescriptions) ForID(id int) *GroupIDDescription { + r.mustLoaded() + + for _, v := range r.list { + if v.GroupID == id { + return &v + } + } + + return nil +} + +func (r *RelatedGroupDescriptions) load(fn func() ([]GroupIDDescription, error)) error { + if r.Loaded() { + return nil + } + + ids, err := fn() + if err != nil { + return err + } + + if ids == nil { + ids = []GroupIDDescription{} + } + + r.list = ids + + return nil +} + type RelatedStashIDs struct { list []StashID } diff --git a/pkg/models/repository_group.go b/pkg/models/repository_group.go index 0396049b66e..704390d77b3 100644 --- a/pkg/models/repository_group.go +++ b/pkg/models/repository_group.go @@ -66,6 +66,8 @@ type GroupReader interface { GroupCounter URLLoader TagIDLoader + ContainingGroupLoader + SubGroupLoader All(ctx context.Context) ([]*Group, error) GetFrontImage(ctx context.Context, groupID int) ([]byte, error) diff --git a/pkg/models/repository_scene.go b/pkg/models/repository_scene.go index 60783fff5cd..e28347c5b82 100644 --- a/pkg/models/repository_scene.go +++ b/pkg/models/repository_scene.go @@ -37,10 +37,7 @@ type SceneQueryer interface { type SceneCounter interface { Count(ctx context.Context) (int, error) CountByPerformerID(ctx context.Context, performerID int) (int, error) - CountByGroupID(ctx context.Context, groupID int) (int, error) CountByFileID(ctx context.Context, fileID FileID) (int, error) - CountByStudioID(ctx context.Context, studioID int) (int, error) - CountByTagID(ctx context.Context, tagID int) (int, error) CountMissingChecksum(ctx context.Context) (int, error) CountMissingOSHash(ctx context.Context) (int, error) OCountByPerformerID(ctx context.Context, performerID int) (int, error) diff --git a/pkg/models/scene.go b/pkg/models/scene.go index 814c4a41d62..48317240276 100644 --- a/pkg/models/scene.go +++ b/pkg/models/scene.go @@ -56,7 +56,7 @@ type SceneFilterType struct { // Filter to only include scenes with this studio Studios *HierarchicalMultiCriterionInput `json:"studios"` // Filter to only include scenes with this group - Groups *MultiCriterionInput `json:"groups"` + Groups *HierarchicalMultiCriterionInput `json:"groups"` // Filter to only include scenes with this movie Movies *MultiCriterionInput `json:"movies"` // Filter to only include scenes with this gallery diff --git a/pkg/models/update.go b/pkg/models/update.go index 2302a2e699a..6aaff8c317f 100644 --- a/pkg/models/update.go +++ b/pkg/models/update.go @@ -133,3 +133,68 @@ func applyUpdate[T comparable](values []T, mode RelationshipUpdateMode, existing return nil } + +type UpdateGroupDescriptions struct { + Groups []GroupIDDescription `json:"groups"` + Mode RelationshipUpdateMode `json:"mode"` +} + +// Apply applies the update to a list of existing ids, returning the result. +func (u *UpdateGroupDescriptions) Apply(existing []GroupIDDescription) []GroupIDDescription { + if u == nil { + return existing + } + + switch u.Mode { + case RelationshipUpdateModeAdd: + return u.applyAdd(existing) + case RelationshipUpdateModeRemove: + return u.applyRemove(existing) + case RelationshipUpdateModeSet: + return u.Groups + } + + return nil +} + +func (u *UpdateGroupDescriptions) applyAdd(existing []GroupIDDescription) []GroupIDDescription { + // overwrite any existing values with the same id + ret := append([]GroupIDDescription{}, existing...) + for _, v := range u.Groups { + found := false + for i, vv := range ret { + if vv.GroupID == v.GroupID { + ret[i] = v + found = true + break + } + } + + if !found { + ret = append(ret, v) + } + } + + return ret +} + +func (u *UpdateGroupDescriptions) applyRemove(existing []GroupIDDescription) []GroupIDDescription { + // remove any existing values with the same id + var ret []GroupIDDescription + for _, v := range existing { + found := false + for _, vv := range u.Groups { + if vv.GroupID == v.GroupID { + found = true + break + } + } + + // if not found in the remove list, keep it + if !found { + ret = append(ret, v) + } + } + + return ret +} diff --git a/pkg/scene/query.go b/pkg/scene/query.go index a8b1993a6a0..c640266f9ef 100644 --- a/pkg/scene/query.go +++ b/pkg/scene/query.go @@ -144,3 +144,15 @@ func CountByTagID(ctx context.Context, r models.SceneQueryer, id int, depth *int return r.QueryCount(ctx, filter, nil) } + +func CountByGroupID(ctx context.Context, r models.SceneQueryer, id int, depth *int) (int, error) { + filter := &models.SceneFilterType{ + Groups: &models.HierarchicalMultiCriterionInput{ + Value: []string{strconv.Itoa(id)}, + Modifier: models.CriterionModifierIncludes, + Depth: depth, + }, + } + + return r.QueryCount(ctx, filter, nil) +} diff --git a/pkg/sliceutil/collections.go b/pkg/sliceutil/collections.go index bd4070cdc94..18930df259e 100644 --- a/pkg/sliceutil/collections.go +++ b/pkg/sliceutil/collections.go @@ -146,7 +146,7 @@ func Filter[T any](vs []T, f func(T) bool) []T { return ret } -// Filter returns the result of applying f to each element of the vs slice. +// Map returns the result of applying f to each element of the vs slice. func Map[T any, V any](vs []T, f func(T) V) []V { ret := make([]V, len(vs)) for i, v := range vs { diff --git a/pkg/sqlite/database.go b/pkg/sqlite/database.go index ee5e5399bf9..7dd4771d33f 100644 --- a/pkg/sqlite/database.go +++ b/pkg/sqlite/database.go @@ -30,7 +30,7 @@ const ( dbConnTimeout = 30 ) -var appSchemaVersion uint = 66 +var appSchemaVersion uint = 67 //go:embed migrations/*.sql var migrationsBox embed.FS diff --git a/pkg/sqlite/filter_hierarchical.go b/pkg/sqlite/filter_hierarchical.go new file mode 100644 index 00000000000..bc5ff9032b3 --- /dev/null +++ b/pkg/sqlite/filter_hierarchical.go @@ -0,0 +1,222 @@ +package sqlite + +import ( + "context" + "fmt" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/utils" +) + +// hierarchicalRelationshipHandler provides handlers for parent, children, parent count, and child count criteria. +type hierarchicalRelationshipHandler struct { + primaryTable string + relationTable string + aliasPrefix string + parentIDCol string + childIDCol string +} + +func (h hierarchicalRelationshipHandler) validateModifier(m models.CriterionModifier) error { + switch m { + case models.CriterionModifierIncludesAll, models.CriterionModifierIncludes, models.CriterionModifierExcludes, models.CriterionModifierIsNull, models.CriterionModifierNotNull: + // valid + return nil + default: + return fmt.Errorf("invalid modifier %s", m) + } +} + +func (h hierarchicalRelationshipHandler) handleNullNotNull(f *filterBuilder, m models.CriterionModifier, isParents bool) { + var notClause string + if m == models.CriterionModifierNotNull { + notClause = "NOT" + } + + as := h.aliasPrefix + "_parents" + col := h.childIDCol + if !isParents { + as = h.aliasPrefix + "_children" + col = h.parentIDCol + } + + // Based on: + // f.addLeftJoin("tags_relations", "parent_relations", "tags.id = parent_relations.child_id") + // f.addWhere(fmt.Sprintf("parent_relations.parent_id IS %s NULL", notClause)) + + f.addLeftJoin(h.relationTable, as, fmt.Sprintf("%s.id = %s.%s", h.primaryTable, as, col)) + f.addWhere(fmt.Sprintf("%s.%s IS %s NULL", as, col, notClause)) +} + +func (h hierarchicalRelationshipHandler) parentsAlias() string { + return h.aliasPrefix + "_parents" +} + +func (h hierarchicalRelationshipHandler) childrenAlias() string { + return h.aliasPrefix + "_children" +} + +func (h hierarchicalRelationshipHandler) valueQuery(value []string, depth int, alias string, isParents bool) string { + var depthCondition string + if depth != -1 { + depthCondition = fmt.Sprintf("WHERE depth < %d", depth) + } + + queryTempl := `{alias} AS ( +SELECT {root_id_col} AS root_id, {item_id_col} AS item_id, 0 AS depth FROM {relation_table} WHERE {root_id_col} IN` + getInBinding(len(value)) + ` +UNION +SELECT root_id, {item_id_col}, depth + 1 FROM {relation_table} INNER JOIN {alias} ON item_id = {root_id_col} ` + depthCondition + ` +)` + + var queryMap utils.StrFormatMap + if isParents { + queryMap = utils.StrFormatMap{ + "root_id_col": h.parentIDCol, + "item_id_col": h.childIDCol, + } + } else { + queryMap = utils.StrFormatMap{ + "root_id_col": h.childIDCol, + "item_id_col": h.parentIDCol, + } + } + + queryMap["alias"] = alias + queryMap["relation_table"] = h.relationTable + + return utils.StrFormat(queryTempl, queryMap) +} + +func (h hierarchicalRelationshipHandler) handleValues(f *filterBuilder, c models.HierarchicalMultiCriterionInput, isParents bool, aliasSuffix string) { + if len(c.Value) == 0 { + return + } + + var args []interface{} + for _, val := range c.Value { + args = append(args, val) + } + + depthVal := 0 + if c.Depth != nil { + depthVal = *c.Depth + } + + tableAlias := h.parentsAlias() + if !isParents { + tableAlias = h.childrenAlias() + } + tableAlias += aliasSuffix + + query := h.valueQuery(c.Value, depthVal, tableAlias, isParents) + f.addRecursiveWith(query, args...) + + f.addLeftJoin(tableAlias, "", fmt.Sprintf("%s.item_id = %s.id", tableAlias, h.primaryTable)) + addHierarchicalConditionClauses(f, c, tableAlias, "root_id") +} + +func (h hierarchicalRelationshipHandler) handleValuesSimple(f *filterBuilder, value string, isParents bool) { + joinCol := h.childIDCol + valueCol := h.parentIDCol + if !isParents { + joinCol = h.parentIDCol + valueCol = h.childIDCol + } + + tableAlias := h.parentsAlias() + if !isParents { + tableAlias = h.childrenAlias() + } + + f.addInnerJoin(h.relationTable, tableAlias, fmt.Sprintf("%s.%s = %s.id", tableAlias, joinCol, h.primaryTable)) + f.addWhere(fmt.Sprintf("%s.%s = ?", tableAlias, valueCol), value) +} + +func (h hierarchicalRelationshipHandler) hierarchicalCriterionHandler(criterion *models.HierarchicalMultiCriterionInput, isParents bool) criterionHandlerFunc { + return func(ctx context.Context, f *filterBuilder) { + if criterion != nil { + c := criterion.CombineExcludes() + + // validate the modifier + if err := h.validateModifier(c.Modifier); err != nil { + f.setError(err) + return + } + + if c.Modifier == models.CriterionModifierIsNull || c.Modifier == models.CriterionModifierNotNull { + h.handleNullNotNull(f, c.Modifier, isParents) + return + } + + if len(c.Value) == 0 && len(c.Excludes) == 0 { + return + } + + depth := 0 + if c.Depth != nil { + depth = *c.Depth + } + + // if we have a single include, no excludes, and no depth, we can use a simple join and where clause + if (c.Modifier == models.CriterionModifierIncludes || c.Modifier == models.CriterionModifierIncludesAll) && len(c.Value) == 1 && len(c.Excludes) == 0 && depth == 0 { + h.handleValuesSimple(f, c.Value[0], isParents) + return + } + + aliasSuffix := "" + h.handleValues(f, c, isParents, aliasSuffix) + + if len(c.Excludes) > 0 { + exCriterion := models.HierarchicalMultiCriterionInput{ + Value: c.Excludes, + Depth: c.Depth, + Modifier: models.CriterionModifierExcludes, + } + + aliasSuffix := "2" + h.handleValues(f, exCriterion, isParents, aliasSuffix) + } + } + } +} + +func (h hierarchicalRelationshipHandler) ParentsCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + const isParents = true + return h.hierarchicalCriterionHandler(criterion, isParents) +} + +func (h hierarchicalRelationshipHandler) ChildrenCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + const isParents = false + return h.hierarchicalCriterionHandler(criterion, isParents) +} + +func (h hierarchicalRelationshipHandler) countCriterionHandler(c *models.IntCriterionInput, isParents bool) criterionHandlerFunc { + tableAlias := h.parentsAlias() + col := h.childIDCol + otherCol := h.parentIDCol + if !isParents { + tableAlias = h.childrenAlias() + col = h.parentIDCol + otherCol = h.childIDCol + } + tableAlias += "_count" + + return func(ctx context.Context, f *filterBuilder) { + if c != nil { + f.addLeftJoin(h.relationTable, tableAlias, fmt.Sprintf("%s.%s = %s.id", tableAlias, col, h.primaryTable)) + clause, args := getIntCriterionWhereClause(fmt.Sprintf("count(distinct %s.%s)", tableAlias, otherCol), *c) + + f.addHaving(clause, args...) + } + } +} + +func (h hierarchicalRelationshipHandler) ParentCountCriterionHandler(parentCount *models.IntCriterionInput) criterionHandlerFunc { + const isParents = true + return h.countCriterionHandler(parentCount, isParents) +} + +func (h hierarchicalRelationshipHandler) ChildCountCriterionHandler(childCount *models.IntCriterionInput) criterionHandlerFunc { + const isParents = false + return h.countCriterionHandler(childCount, isParents) +} diff --git a/pkg/sqlite/group.go b/pkg/sqlite/group.go index 21c224242f8..603494fe71a 100644 --- a/pkg/sqlite/group.go +++ b/pkg/sqlite/group.go @@ -27,6 +27,8 @@ const ( groupURLsTable = "group_urls" groupURLColumn = "url" + + groupRelationsTable = "groups_relations" ) type groupRow struct { @@ -128,6 +130,7 @@ var ( type GroupStore struct { blobJoinQueryBuilder tagRelationshipStore + groupRelationshipStore tableMgr *table } @@ -143,6 +146,9 @@ func NewGroupStore(blobStore *BlobStore) *GroupStore { joinTable: groupsTagsTableMgr, }, }, + groupRelationshipStore: groupRelationshipStore{ + table: groupRelationshipTableMgr, + }, tableMgr: groupTableMgr, } @@ -176,6 +182,14 @@ func (qb *GroupStore) Create(ctx context.Context, newObject *models.Group) error return err } + if err := qb.groupRelationshipStore.createContainingRelationships(ctx, id, newObject.ContainingGroups); err != nil { + return err + } + + if err := qb.groupRelationshipStore.createSubRelationships(ctx, id, newObject.SubGroups); err != nil { + return err + } + updated, err := qb.find(ctx, id) if err != nil { return fmt.Errorf("finding after create: %w", err) @@ -211,6 +225,14 @@ func (qb *GroupStore) UpdatePartial(ctx context.Context, id int, partial models. return nil, err } + if err := qb.groupRelationshipStore.modifyContainingRelationships(ctx, id, partial.ContainingGroups); err != nil { + return nil, err + } + + if err := qb.groupRelationshipStore.modifySubRelationships(ctx, id, partial.SubGroups); err != nil { + return nil, err + } + return qb.find(ctx, id) } @@ -232,6 +254,14 @@ func (qb *GroupStore) Update(ctx context.Context, updatedObject *models.Group) e return err } + if err := qb.groupRelationshipStore.replaceContainingRelationships(ctx, updatedObject.ID, updatedObject.ContainingGroups); err != nil { + return err + } + + if err := qb.groupRelationshipStore.replaceSubRelationships(ctx, updatedObject.ID, updatedObject.SubGroups); err != nil { + return err + } + return nil } @@ -412,9 +442,7 @@ func (qb *GroupStore) makeQuery(ctx context.Context, groupFilter *models.GroupFi return nil, err } - var err error - query.sortAndPagination, err = qb.getGroupSort(findFilter) - if err != nil { + if err := qb.setGroupSort(&query, findFilter); err != nil { return nil, err } @@ -460,11 +488,12 @@ var groupSortOptions = sortOptions{ "random", "rating", "scenes_count", + "sub_group_order", "tag_count", "updated_at", } -func (qb *GroupStore) getGroupSort(findFilter *models.FindFilterType) (string, error) { +func (qb *GroupStore) setGroupSort(query *queryBuilder, findFilter *models.FindFilterType) error { var sort string var direction string if findFilter == nil { @@ -477,22 +506,31 @@ func (qb *GroupStore) getGroupSort(findFilter *models.FindFilterType) (string, e // CVE-2024-32231 - ensure sort is in the list of allowed sorts if err := groupSortOptions.validateSort(sort); err != nil { - return "", err + return err } - sortQuery := "" switch sort { + case "sub_group_order": + // sub_group_order is a special sort that sorts by the order_index of the subgroups + if query.hasJoin("groups_parents") { + query.sortAndPagination += getSort("order_index", direction, "groups_parents") + } else { + // this will give unexpected results if the query is not filtered by a parent group and + // the group has multiple parents and order indexes + query.join(groupRelationsTable, "", "groups.id = groups_relations.sub_id") + query.sortAndPagination += getSort("order_index", direction, groupRelationsTable) + } case "tag_count": - sortQuery += getCountSort(groupTable, groupsTagsTable, groupIDColumn, direction) + query.sortAndPagination += getCountSort(groupTable, groupsTagsTable, groupIDColumn, direction) case "scenes_count": // generic getSort won't work for this - sortQuery += getCountSort(groupTable, groupsScenesTable, groupIDColumn, direction) + query.sortAndPagination += getCountSort(groupTable, groupsScenesTable, groupIDColumn, direction) default: - sortQuery += getSort(sort, direction, "groups") + query.sortAndPagination += getSort(sort, direction, "groups") } // Whatever the sorting, always use name/id as a final sort - sortQuery += ", COALESCE(groups.name, groups.id) COLLATE NATURAL_CI ASC" - return sortQuery, nil + query.sortAndPagination += ", COALESCE(groups.name, groups.id) COLLATE NATURAL_CI ASC" + return nil } func (qb *GroupStore) queryGroups(ctx context.Context, query string, args []interface{}) ([]*models.Group, error) { @@ -592,3 +630,74 @@ WHERE groups.studio_id = ? func (qb *GroupStore) GetURLs(ctx context.Context, groupID int) ([]string, error) { return groupsURLsTableMgr.get(ctx, groupID) } + +// FindSubGroupIDs returns a list of group IDs where a group in the ids list is a sub-group of the parent group +func (qb *GroupStore) FindSubGroupIDs(ctx context.Context, containingID int, ids []int) ([]int, error) { + /* + SELECT gr.sub_id FROM groups_relations gr + WHERE gr.containing_id = :parentID AND gr.sub_id IN (:ids); + */ + table := groupRelationshipTableMgr.table + q := dialect.From(table).Prepared(true). + Select(table.Col("sub_id")).Where( + table.Col("containing_id").Eq(containingID), + table.Col("sub_id").In(ids), + ) + + const single = false + var ret []int + if err := queryFunc(ctx, q, single, func(r *sqlx.Rows) error { + var id int + if err := r.Scan(&id); err != nil { + return err + } + + ret = append(ret, id) + return nil + }); err != nil { + return nil, err + } + + return ret, nil +} + +// FindInAscestors returns a list of group IDs where a group in the ids list is an ascestor of the ancestor group IDs +func (qb *GroupStore) FindInAncestors(ctx context.Context, ascestorIDs []int, ids []int) ([]int, error) { + /* + WITH RECURSIVE ascestors AS ( + SELECT g.id AS parent_id FROM groups g WHERE g.id IN (:ascestorIDs) + UNION + SELECT gr.containing_id FROM groups_relations gr INNER JOIN ascestors a ON a.parent_id = gr.sub_id + ) + SELECT p.parent_id FROM ascestors p WHERE p.parent_id IN (:ids); + */ + table := qb.table() + const ascestors = "ancestors" + const parentID = "parent_id" + q := dialect.From(ascestors).Prepared(true). + WithRecursive(ascestors, + dialect.From(qb.table()).Select(table.Col(idColumn).As(parentID)). + Where(table.Col(idColumn).In(ascestorIDs)). + Union( + dialect.From(groupRelationsJoinTable).InnerJoin( + goqu.I(ascestors), goqu.On(goqu.I("parent_id").Eq(goqu.I("sub_id"))), + ).Select("containing_id"), + ), + ).Select(parentID).Where(goqu.I(parentID).In(ids)) + + const single = false + var ret []int + if err := queryFunc(ctx, q, single, func(r *sqlx.Rows) error { + var id int + if err := r.Scan(&id); err != nil { + return err + } + + ret = append(ret, id) + return nil + }); err != nil { + return nil, err + } + + return ret, nil +} diff --git a/pkg/sqlite/group_filter.go b/pkg/sqlite/group_filter.go index 97bde1f2474..dcb7bcdfc94 100644 --- a/pkg/sqlite/group_filter.go +++ b/pkg/sqlite/group_filter.go @@ -51,6 +51,14 @@ func (qb *groupFilterHandler) handle(ctx context.Context, f *filterBuilder) { f.handleCriterion(ctx, qb.criterionHandler()) } +var groupHierarchyHandler = hierarchicalRelationshipHandler{ + primaryTable: groupTable, + relationTable: groupRelationsTable, + aliasPrefix: groupTable, + parentIDCol: "containing_id", + childIDCol: "sub_id", +} + func (qb *groupFilterHandler) criterionHandler() criterionHandler { groupFilter := qb.groupFilter return compoundHandler{ @@ -66,6 +74,10 @@ func (qb *groupFilterHandler) criterionHandler() criterionHandler { qb.tagsCriterionHandler(groupFilter.Tags), qb.tagCountCriterionHandler(groupFilter.TagCount), &dateCriterionHandler{groupFilter.Date, "groups.date", nil}, + groupHierarchyHandler.ParentsCriterionHandler(groupFilter.ContainingGroups), + groupHierarchyHandler.ChildrenCriterionHandler(groupFilter.SubGroups), + groupHierarchyHandler.ParentCountCriterionHandler(groupFilter.ContainingGroupCount), + groupHierarchyHandler.ChildCountCriterionHandler(groupFilter.SubGroupCount), ×tampCriterionHandler{groupFilter.CreatedAt, "groups.created_at", nil}, ×tampCriterionHandler{groupFilter.UpdatedAt, "groups.updated_at", nil}, diff --git a/pkg/sqlite/group_relationships.go b/pkg/sqlite/group_relationships.go new file mode 100644 index 00000000000..fe94394f905 --- /dev/null +++ b/pkg/sqlite/group_relationships.go @@ -0,0 +1,457 @@ +package sqlite + +import ( + "context" + "fmt" + + "github.com/doug-martin/goqu/v9" + "github.com/doug-martin/goqu/v9/exp" + "github.com/jmoiron/sqlx" + "github.com/stashapp/stash/pkg/models" + "gopkg.in/guregu/null.v4" + "gopkg.in/guregu/null.v4/zero" +) + +type groupRelationshipRow struct { + ContainingID int `db:"containing_id"` + SubID int `db:"sub_id"` + OrderIndex int `db:"order_index"` + Description zero.String `db:"description"` +} + +func (r groupRelationshipRow) resolve(useContainingID bool) models.GroupIDDescription { + id := r.ContainingID + if !useContainingID { + id = r.SubID + } + + return models.GroupIDDescription{ + GroupID: id, + Description: r.Description.String, + } +} + +type groupRelationshipStore struct { + table *table +} + +func (s *groupRelationshipStore) GetContainingGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) { + const idIsContaining = false + return s.getGroupRelationships(ctx, id, idIsContaining) +} + +func (s *groupRelationshipStore) GetSubGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) { + const idIsContaining = true + return s.getGroupRelationships(ctx, id, idIsContaining) +} + +func (s *groupRelationshipStore) getGroupRelationships(ctx context.Context, id int, idIsContaining bool) ([]models.GroupIDDescription, error) { + col := "containing_id" + if !idIsContaining { + col = "sub_id" + } + + table := s.table.table + q := dialect.Select(table.All()). + From(table). + Where(table.Col(col).Eq(id)). + Order(table.Col("order_index").Asc()) + + const single = false + var ret []models.GroupIDDescription + if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error { + var row groupRelationshipRow + if err := rows.StructScan(&row); err != nil { + return err + } + + ret = append(ret, row.resolve(!idIsContaining)) + + return nil + }); err != nil { + return nil, fmt.Errorf("getting group relationships from %s: %w", table.GetTable(), err) + } + + return ret, nil +} + +// getMaxOrderIndex gets the maximum order index for the containing group with the given id +func (s *groupRelationshipStore) getMaxOrderIndex(ctx context.Context, containingID int) (int, error) { + idColumn := s.table.table.Col("containing_id") + + q := dialect.Select(goqu.MAX("order_index")). + From(s.table.table). + Where(idColumn.Eq(containingID)) + + var maxOrderIndex zero.Int + if err := querySimple(ctx, q, &maxOrderIndex); err != nil { + return 0, fmt.Errorf("getting max order index: %w", err) + } + + return int(maxOrderIndex.Int64), nil +} + +// createRelationships creates relationships between a group and other groups. +// If idIsContaining is true, the provided id is the containing group. +func (s *groupRelationshipStore) createRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions, idIsContaining bool) error { + if d.Loaded() { + for i, v := range d.List() { + orderIndex := i + 1 + + r := groupRelationshipRow{ + ContainingID: id, + SubID: v.GroupID, + OrderIndex: orderIndex, + Description: zero.StringFrom(v.Description), + } + + if !idIsContaining { + // get the max order index of the containing groups sub groups + containingID := v.GroupID + maxOrderIndex, err := s.getMaxOrderIndex(ctx, containingID) + if err != nil { + return err + } + + r.ContainingID = v.GroupID + r.SubID = id + r.OrderIndex = maxOrderIndex + 1 + } + + _, err := s.table.insert(ctx, r) + if err != nil { + return fmt.Errorf("inserting into %s: %w", s.table.table.GetTable(), err) + } + } + + return nil + } + + return nil +} + +// createRelationships creates relationships between a group and other groups. +// If idIsContaining is true, the provided id is the containing group. +func (s *groupRelationshipStore) createContainingRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error { + const idIsContaining = false + return s.createRelationships(ctx, id, d, idIsContaining) +} + +// createRelationships creates relationships between a group and other groups. +// If idIsContaining is true, the provided id is the containing group. +func (s *groupRelationshipStore) createSubRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error { + const idIsContaining = true + return s.createRelationships(ctx, id, d, idIsContaining) +} + +func (s *groupRelationshipStore) replaceRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions, idIsContaining bool) error { + // always destroy the existing relationships even if the new list is empty + if err := s.destroyAllJoins(ctx, id, idIsContaining); err != nil { + return err + } + + return s.createRelationships(ctx, id, d, idIsContaining) +} + +func (s *groupRelationshipStore) replaceContainingRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error { + const idIsContaining = false + return s.replaceRelationships(ctx, id, d, idIsContaining) +} + +func (s *groupRelationshipStore) replaceSubRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error { + const idIsContaining = true + return s.replaceRelationships(ctx, id, d, idIsContaining) +} + +func (s *groupRelationshipStore) modifyRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions, idIsContaining bool) error { + if v == nil { + return nil + } + + switch v.Mode { + case models.RelationshipUpdateModeSet: + return s.replaceJoins(ctx, id, *v, idIsContaining) + case models.RelationshipUpdateModeAdd: + return s.addJoins(ctx, id, v.Groups, idIsContaining) + case models.RelationshipUpdateModeRemove: + toRemove := make([]int, len(v.Groups)) + for i, vv := range v.Groups { + toRemove[i] = vv.GroupID + } + return s.destroyJoins(ctx, id, toRemove, idIsContaining) + } + + return nil +} + +func (s *groupRelationshipStore) modifyContainingRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions) error { + const idIsContaining = false + return s.modifyRelationships(ctx, id, v, idIsContaining) +} + +func (s *groupRelationshipStore) modifySubRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions) error { + const idIsContaining = true + return s.modifyRelationships(ctx, id, v, idIsContaining) +} + +func (s *groupRelationshipStore) addJoins(ctx context.Context, id int, groups []models.GroupIDDescription, idIsContaining bool) error { + // if we're adding to a containing group, get the max order index first + var maxOrderIndex int + if idIsContaining { + var err error + maxOrderIndex, err = s.getMaxOrderIndex(ctx, id) + if err != nil { + return err + } + } + + for i, vv := range groups { + r := groupRelationshipRow{ + Description: zero.StringFrom(vv.Description), + } + + if idIsContaining { + r.ContainingID = id + r.SubID = vv.GroupID + r.OrderIndex = maxOrderIndex + (i + 1) + } else { + // get the max order index of the containing groups sub groups + containingMaxOrderIndex, err := s.getMaxOrderIndex(ctx, vv.GroupID) + if err != nil { + return err + } + + r.ContainingID = vv.GroupID + r.SubID = id + r.OrderIndex = containingMaxOrderIndex + 1 + } + + _, err := s.table.insert(ctx, r) + if err != nil { + return fmt.Errorf("inserting into %s: %w", s.table.table.GetTable(), err) + } + } + + return nil +} + +func (s *groupRelationshipStore) destroyAllJoins(ctx context.Context, id int, idIsContaining bool) error { + table := s.table.table + idColumn := table.Col("containing_id") + if !idIsContaining { + idColumn = table.Col("sub_id") + } + + q := dialect.Delete(table).Where(idColumn.Eq(id)) + + if _, err := exec(ctx, q); err != nil { + return fmt.Errorf("destroying %s: %w", table.GetTable(), err) + } + + return nil +} + +func (s *groupRelationshipStore) replaceJoins(ctx context.Context, id int, v models.UpdateGroupDescriptions, idIsContaining bool) error { + if err := s.destroyAllJoins(ctx, id, idIsContaining); err != nil { + return err + } + + // convert to RelatedGroupDescriptions + rgd := models.NewRelatedGroupDescriptions(v.Groups) + return s.createRelationships(ctx, id, rgd, idIsContaining) +} + +func (s *groupRelationshipStore) destroyJoins(ctx context.Context, id int, toRemove []int, idIsContaining bool) error { + table := s.table.table + idColumn := table.Col("containing_id") + fkColumn := table.Col("sub_id") + if !idIsContaining { + idColumn = table.Col("sub_id") + fkColumn = table.Col("containing_id") + } + + q := dialect.Delete(table).Where(idColumn.Eq(id), fkColumn.In(toRemove)) + + if _, err := exec(ctx, q); err != nil { + return fmt.Errorf("destroying %s: %w", table.GetTable(), err) + } + + return nil +} + +func (s *groupRelationshipStore) getOrderIndexOfSubGroup(ctx context.Context, containingGroupID int, subGroupID int) (int, error) { + table := s.table.table + q := dialect.Select("order_index"). + From(table). + Where( + table.Col("containing_id").Eq(containingGroupID), + table.Col("sub_id").Eq(subGroupID), + ) + + var orderIndex null.Int + if err := querySimple(ctx, q, &orderIndex); err != nil { + return 0, fmt.Errorf("getting order index: %w", err) + } + + if !orderIndex.Valid { + return 0, fmt.Errorf("sub-group %d not found in containing group %d", subGroupID, containingGroupID) + } + + return int(orderIndex.Int64), nil +} + +func (s *groupRelationshipStore) getGroupIDAtOrderIndex(ctx context.Context, containingGroupID int, orderIndex int) (*int, error) { + table := s.table.table + q := dialect.Select(table.Col("sub_id")).From(table).Where( + table.Col("containing_id").Eq(containingGroupID), + table.Col("order_index").Eq(orderIndex), + ) + + var ret null.Int + if err := querySimple(ctx, q, &ret); err != nil { + return nil, fmt.Errorf("getting sub id for order index: %w", err) + } + + if !ret.Valid { + return nil, nil + } + + intRet := int(ret.Int64) + return &intRet, nil +} + +func (s *groupRelationshipStore) getOrderIndexAfterOrderIndex(ctx context.Context, containingGroupID int, orderIndex int) (int, error) { + table := s.table.table + q := dialect.Select(goqu.MIN("order_index")).From(table).Where( + table.Col("containing_id").Eq(containingGroupID), + table.Col("order_index").Gt(orderIndex), + ) + + var ret null.Int + if err := querySimple(ctx, q, &ret); err != nil { + return 0, fmt.Errorf("getting order index: %w", err) + } + + if !ret.Valid { + return orderIndex + 1, nil + } + + return int(ret.Int64), nil +} + +// incrementOrderIndexes increments the order_index value of all sub-groups in the containing group at or after the given index +func (s *groupRelationshipStore) incrementOrderIndexes(ctx context.Context, groupID int, indexBefore int) error { + table := s.table.table + + // WORKAROUND - sqlite won't allow incrementing the value directly since it causes a + // unique constraint violation. + // Instead, we first set the order index to a negative value temporarily + // see https://stackoverflow.com/a/7703239/695786 + q := dialect.Update(table).Set(exp.Record{ + "order_index": goqu.L("-order_index"), + }).Where( + table.Col("containing_id").Eq(groupID), + table.Col("order_index").Gte(indexBefore), + ) + + if _, err := exec(ctx, q); err != nil { + return fmt.Errorf("updating %s: %w", table.GetTable(), err) + } + + q = dialect.Update(table).Set(exp.Record{ + "order_index": goqu.L("1-order_index"), + }).Where( + table.Col("containing_id").Eq(groupID), + table.Col("order_index").Lt(0), + ) + + if _, err := exec(ctx, q); err != nil { + return fmt.Errorf("updating %s: %w", table.GetTable(), err) + } + + return nil +} + +func (s *groupRelationshipStore) reorderSubGroup(ctx context.Context, groupID int, subGroupID int, insertPointID int, insertAfter bool) error { + insertPointIndex, err := s.getOrderIndexOfSubGroup(ctx, groupID, insertPointID) + if err != nil { + return err + } + + // if we're setting before + if insertAfter { + insertPointIndex, err = s.getOrderIndexAfterOrderIndex(ctx, groupID, insertPointIndex) + if err != nil { + return err + } + } + + // increment the order index of all sub-groups after and including the insertion point + if err := s.incrementOrderIndexes(ctx, groupID, int(insertPointIndex)); err != nil { + return err + } + + // set the order index of the sub-group to the insertion point + table := s.table.table + q := dialect.Update(table).Set(exp.Record{ + "order_index": insertPointIndex, + }).Where( + table.Col("containing_id").Eq(groupID), + table.Col("sub_id").Eq(subGroupID), + ) + + if _, err := exec(ctx, q); err != nil { + return fmt.Errorf("updating %s: %w", table.GetTable(), err) + } + + return nil +} + +func (s *groupRelationshipStore) AddSubGroups(ctx context.Context, groupID int, subGroups []models.GroupIDDescription, insertIndex *int) error { + const idIsContaining = true + + if err := s.addJoins(ctx, groupID, subGroups, idIsContaining); err != nil { + return err + } + + ids := make([]int, len(subGroups)) + for i, v := range subGroups { + ids[i] = v.GroupID + } + + if insertIndex != nil { + // get the id of the sub-group at the insert index + insertPointID, err := s.getGroupIDAtOrderIndex(ctx, groupID, *insertIndex) + if err != nil { + return err + } + + if insertPointID == nil { + // if the insert index is out of bounds, just assume adding to the end + return nil + } + + // reorder the sub-groups + const insertAfter = false + if err := s.ReorderSubGroups(ctx, groupID, ids, *insertPointID, insertAfter); err != nil { + return err + } + } + + return nil +} + +func (s *groupRelationshipStore) RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error { + const idIsContaining = true + return s.destroyJoins(ctx, groupID, subGroupIDs, idIsContaining) +} + +func (s *groupRelationshipStore) ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertPointID int, insertAfter bool) error { + for _, id := range subGroupIDs { + if err := s.reorderSubGroup(ctx, groupID, id, insertPointID, insertAfter); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/sqlite/group_test.go b/pkg/sqlite/group_test.go index 45171337970..1d3637c8611 100644 --- a/pkg/sqlite/group_test.go +++ b/pkg/sqlite/group_test.go @@ -14,6 +14,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sliceutil" + "github.com/stashapp/stash/pkg/sliceutil/intslice" ) func loadGroupRelationships(ctx context.Context, expected models.Group, actual *models.Group) error { @@ -27,22 +29,34 @@ func loadGroupRelationships(ctx context.Context, expected models.Group, actual * return err } } + if expected.ContainingGroups.Loaded() { + if err := actual.LoadContainingGroupIDs(ctx, db.Group); err != nil { + return err + } + } + if expected.SubGroups.Loaded() { + if err := actual.LoadSubGroupIDs(ctx, db.Group); err != nil { + return err + } + } return nil } func Test_GroupStore_Create(t *testing.T) { var ( - name = "name" - url = "url" - aliases = "alias1, alias2" - director = "director" - rating = 60 - duration = 34 - synopsis = "synopsis" - date, _ = models.ParseDate("2003-02-01") - createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) - updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + name = "name" + url = "url" + aliases = "alias1, alias2" + director = "director" + rating = 60 + duration = 34 + synopsis = "synopsis" + date, _ = models.ParseDate("2003-02-01") + containingGroupDescription = "containingGroupDescription" + subGroupDescription = "subGroupDescription" + createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) ) tests := []struct { @@ -53,15 +67,21 @@ func Test_GroupStore_Create(t *testing.T) { { "full", models.Group{ - Name: name, - Duration: &duration, - Date: &date, - Rating: &rating, - StudioID: &studioIDs[studioIdxWithGroup], - Director: director, - Synopsis: synopsis, - URLs: models.NewRelatedStrings([]string{url}), - TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithGroup]}), + Name: name, + Duration: &duration, + Date: &date, + Rating: &rating, + StudioID: &studioIDs[studioIdxWithGroup], + Director: director, + Synopsis: synopsis, + URLs: models.NewRelatedStrings([]string{url}), + TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithGroup]}), + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithScene], Description: containingGroupDescription}, + }), + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithStudio], Description: subGroupDescription}, + }), Aliases: aliases, CreatedAt: createdAt, UpdatedAt: updatedAt, @@ -76,6 +96,22 @@ func Test_GroupStore_Create(t *testing.T) { }, true, }, + { + "invalid containing group id", + models.Group{ + Name: name, + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{{GroupID: invalidID}}), + }, + true, + }, + { + "invalid sub group id", + models.Group{ + Name: name, + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{{GroupID: invalidID}}), + }, + true, + }, } qb := db.Group @@ -131,36 +167,44 @@ func Test_GroupStore_Create(t *testing.T) { func Test_groupQueryBuilder_Update(t *testing.T) { var ( - name = "name" - url = "url" - aliases = "alias1, alias2" - director = "director" - rating = 60 - duration = 34 - synopsis = "synopsis" - date, _ = models.ParseDate("2003-02-01") - createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) - updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + name = "name" + url = "url" + aliases = "alias1, alias2" + director = "director" + rating = 60 + duration = 34 + synopsis = "synopsis" + date, _ = models.ParseDate("2003-02-01") + containingGroupDescription = "containingGroupDescription" + subGroupDescription = "subGroupDescription" + createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) ) tests := []struct { name string - updatedObject *models.Group + updatedObject models.Group wantErr bool }{ { "full", - &models.Group{ - ID: groupIDs[groupIdxWithTag], - Name: name, - Duration: &duration, - Date: &date, - Rating: &rating, - StudioID: &studioIDs[studioIdxWithGroup], - Director: director, - Synopsis: synopsis, - URLs: models.NewRelatedStrings([]string{url}), - TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithGroup]}), + models.Group{ + ID: groupIDs[groupIdxWithTag], + Name: name, + Duration: &duration, + Date: &date, + Rating: &rating, + StudioID: &studioIDs[studioIdxWithGroup], + Director: director, + Synopsis: synopsis, + URLs: models.NewRelatedStrings([]string{url}), + TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithGroup]}), + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithScene], Description: containingGroupDescription}, + }), + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithStudio], Description: subGroupDescription}, + }), Aliases: aliases, CreatedAt: createdAt, UpdatedAt: updatedAt, @@ -169,16 +213,34 @@ func Test_groupQueryBuilder_Update(t *testing.T) { }, { "clear tag ids", - &models.Group{ + models.Group{ ID: groupIDs[groupIdxWithTag], Name: name, TagIDs: models.NewRelatedIDs([]int{}), }, false, }, + { + "clear containing ids", + models.Group{ + ID: groupIDs[groupIdxWithParent], + Name: name, + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{}), + }, + false, + }, + { + "clear sub ids", + models.Group{ + ID: groupIDs[groupIdxWithChild], + Name: name, + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{}), + }, + false, + }, { "invalid studio id", - &models.Group{ + models.Group{ ID: groupIDs[groupIdxWithScene], Name: name, StudioID: &invalidID, @@ -187,13 +249,31 @@ func Test_groupQueryBuilder_Update(t *testing.T) { }, { "invalid tag id", - &models.Group{ + models.Group{ ID: groupIDs[groupIdxWithScene], Name: name, TagIDs: models.NewRelatedIDs([]int{invalidID}), }, true, }, + { + "invalid containing group id", + models.Group{ + ID: groupIDs[groupIdxWithScene], + Name: name, + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{{GroupID: invalidID}}), + }, + true, + }, + { + "invalid sub group id", + models.Group{ + ID: groupIDs[groupIdxWithScene], + Name: name, + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{{GroupID: invalidID}}), + }, + true, + }, } qb := db.Group @@ -201,9 +281,10 @@ func Test_groupQueryBuilder_Update(t *testing.T) { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) - copy := *tt.updatedObject + actual := tt.updatedObject + expected := tt.updatedObject - if err := qb.Update(ctx, tt.updatedObject); (err != nil) != tt.wantErr { + if err := qb.Update(ctx, &actual); (err != nil) != tt.wantErr { t.Errorf("groupQueryBuilder.Update() error = %v, wantErr %v", err, tt.wantErr) } @@ -211,49 +292,61 @@ func Test_groupQueryBuilder_Update(t *testing.T) { return } - s, err := qb.Find(ctx, tt.updatedObject.ID) + s, err := qb.Find(ctx, actual.ID) if err != nil { t.Errorf("groupQueryBuilder.Find() error = %v", err) } // load relationships - if err := loadGroupRelationships(ctx, copy, s); err != nil { + if err := loadGroupRelationships(ctx, expected, s); err != nil { t.Errorf("loadGroupRelationships() error = %v", err) return } - assert.Equal(copy, *s) + assert.Equal(expected, *s) }) } } -func clearGroupPartial() models.GroupPartial { +var clearGroupPartial = models.GroupPartial{ // leave mandatory fields - return models.GroupPartial{ - Aliases: models.OptionalString{Set: true, Null: true}, - Synopsis: models.OptionalString{Set: true, Null: true}, - Director: models.OptionalString{Set: true, Null: true}, - Duration: models.OptionalInt{Set: true, Null: true}, - URLs: &models.UpdateStrings{Mode: models.RelationshipUpdateModeSet}, - Date: models.OptionalDate{Set: true, Null: true}, - Rating: models.OptionalInt{Set: true, Null: true}, - StudioID: models.OptionalInt{Set: true, Null: true}, - TagIDs: &models.UpdateIDs{Mode: models.RelationshipUpdateModeSet}, + Aliases: models.OptionalString{Set: true, Null: true}, + Synopsis: models.OptionalString{Set: true, Null: true}, + Director: models.OptionalString{Set: true, Null: true}, + Duration: models.OptionalInt{Set: true, Null: true}, + URLs: &models.UpdateStrings{Mode: models.RelationshipUpdateModeSet}, + Date: models.OptionalDate{Set: true, Null: true}, + Rating: models.OptionalInt{Set: true, Null: true}, + StudioID: models.OptionalInt{Set: true, Null: true}, + TagIDs: &models.UpdateIDs{Mode: models.RelationshipUpdateModeSet}, + ContainingGroups: &models.UpdateGroupDescriptions{Mode: models.RelationshipUpdateModeSet}, + SubGroups: &models.UpdateGroupDescriptions{Mode: models.RelationshipUpdateModeSet}, +} + +func emptyGroup(idx int) models.Group { + return models.Group{ + ID: groupIDs[idx], + Name: groupNames[idx], + TagIDs: models.NewRelatedIDs([]int{}), + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{}), + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{}), } } func Test_groupQueryBuilder_UpdatePartial(t *testing.T) { var ( - name = "name" - url = "url" - aliases = "alias1, alias2" - director = "director" - rating = 60 - duration = 34 - synopsis = "synopsis" - date, _ = models.ParseDate("2003-02-01") - createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) - updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + name = "name" + url = "url" + aliases = "alias1, alias2" + director = "director" + rating = 60 + duration = 34 + synopsis = "synopsis" + date, _ = models.ParseDate("2003-02-01") + containingGroupDescription = "containingGroupDescription" + subGroupDescription = "subGroupDescription" + createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) ) tests := []struct { @@ -285,6 +378,20 @@ func Test_groupQueryBuilder_UpdatePartial(t *testing.T) { IDs: []int{tagIDs[tagIdx1WithGroup], tagIDs[tagIdx1WithDupName]}, Mode: models.RelationshipUpdateModeSet, }, + ContainingGroups: &models.UpdateGroupDescriptions{ + Groups: []models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithStudio], Description: containingGroupDescription}, + {GroupID: groupIDs[groupIdxWithThreeTags], Description: containingGroupDescription}, + }, + Mode: models.RelationshipUpdateModeSet, + }, + SubGroups: &models.UpdateGroupDescriptions{ + Groups: []models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithTag], Description: subGroupDescription}, + {GroupID: groupIDs[groupIdxWithDupName], Description: subGroupDescription}, + }, + Mode: models.RelationshipUpdateModeSet, + }, }, models.Group{ ID: groupIDs[groupIdxWithScene], @@ -300,17 +407,113 @@ func Test_groupQueryBuilder_UpdatePartial(t *testing.T) { CreatedAt: createdAt, UpdatedAt: updatedAt, TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithGroup]}), + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithStudio], Description: containingGroupDescription}, + {GroupID: groupIDs[groupIdxWithThreeTags], Description: containingGroupDescription}, + }), + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithTag], Description: subGroupDescription}, + {GroupID: groupIDs[groupIdxWithDupName], Description: subGroupDescription}, + }), }, false, }, { "clear all", groupIDs[groupIdxWithScene], - clearGroupPartial(), + clearGroupPartial, + emptyGroup(groupIdxWithScene), + false, + }, + { + "clear tag ids", + groupIDs[groupIdxWithTag], + clearGroupPartial, + emptyGroup(groupIdxWithTag), + false, + }, + { + "clear group relationships", + groupIDs[groupIdxWithParentAndChild], + clearGroupPartial, + emptyGroup(groupIdxWithParentAndChild), + false, + }, + { + "add containing group", + groupIDs[groupIdxWithParent], + models.GroupPartial{ + ContainingGroups: &models.UpdateGroupDescriptions{ + Groups: []models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithScene], Description: containingGroupDescription}, + }, + Mode: models.RelationshipUpdateModeAdd, + }, + }, models.Group{ - ID: groupIDs[groupIdxWithScene], - Name: groupNames[groupIdxWithScene], - TagIDs: models.NewRelatedIDs([]int{}), + ID: groupIDs[groupIdxWithParent], + Name: groupNames[groupIdxWithParent], + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithChild]}, + {GroupID: groupIDs[groupIdxWithScene], Description: containingGroupDescription}, + }), + }, + false, + }, + { + "add sub group", + groupIDs[groupIdxWithChild], + models.GroupPartial{ + SubGroups: &models.UpdateGroupDescriptions{ + Groups: []models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithScene], Description: subGroupDescription}, + }, + Mode: models.RelationshipUpdateModeAdd, + }, + }, + models.Group{ + ID: groupIDs[groupIdxWithChild], + Name: groupNames[groupIdxWithChild], + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithParent]}, + {GroupID: groupIDs[groupIdxWithScene], Description: subGroupDescription}, + }), + }, + false, + }, + { + "remove containing group", + groupIDs[groupIdxWithParent], + models.GroupPartial{ + ContainingGroups: &models.UpdateGroupDescriptions{ + Groups: []models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithChild]}, + }, + Mode: models.RelationshipUpdateModeRemove, + }, + }, + models.Group{ + ID: groupIDs[groupIdxWithParent], + Name: groupNames[groupIdxWithParent], + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{}), + }, + false, + }, + { + "remove sub group", + groupIDs[groupIdxWithChild], + models.GroupPartial{ + SubGroups: &models.UpdateGroupDescriptions{ + Groups: []models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithParent]}, + }, + Mode: models.RelationshipUpdateModeRemove, + }, + }, + models.Group{ + ID: groupIDs[groupIdxWithChild], + Name: groupNames[groupIdxWithChild], + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{}), }, false, }, @@ -784,7 +987,38 @@ func TestGroupQuerySorting(t *testing.T) { groups = queryGroups(ctx, t, nil, &findFilter) lastGroup := groups[len(groups)-1] - assert.Equal(t, groupIDs[groupIdxWithScene], lastGroup.ID) + assert.Equal(t, groupIDs[groupIdxWithParentAndScene], lastGroup.ID) + + return nil + }) +} + +func TestGroupQuerySortOrderIndex(t *testing.T) { + sort := "sub_group_order" + direction := models.SortDirectionEnumDesc + findFilter := models.FindFilterType{ + Sort: &sort, + Direction: &direction, + } + + groupFilter := models.GroupFilterType{ + ContainingGroups: &models.HierarchicalMultiCriterionInput{ + Value: intslice.IntSliceToStringSlice([]int{groupIdxWithChild}), + Modifier: models.CriterionModifierIncludes, + }, + } + + withTxn(func(ctx context.Context) error { + // just ensure there are no errors + _, _, err := db.Group.Query(ctx, &groupFilter, &findFilter) + if err != nil { + t.Errorf("Error querying group: %s", err.Error()) + } + + _, _, err = db.Group.Query(ctx, nil, &findFilter) + if err != nil { + t.Errorf("Error querying group: %s", err.Error()) + } return nil }) @@ -830,6 +1064,832 @@ func TestGroupUpdateBackImage(t *testing.T) { } } +func TestGroupQueryContainingGroups(t *testing.T) { + const nameField = "Name" + + type criterion struct { + valueIdxs []int + modifier models.CriterionModifier + depth int + } + + tests := []struct { + name string + c criterion + q string + includeIdxs []int + }{ + { + "includes", + criterion{ + []int{groupIdxWithChild}, + models.CriterionModifierIncludes, + 0, + }, + "", + []int{groupIdxWithParent}, + }, + { + "excludes", + criterion{ + []int{groupIdxWithChild}, + models.CriterionModifierExcludes, + 0, + }, + getGroupStringValue(groupIdxWithParent, nameField), + nil, + }, + { + "includes (all levels)", + criterion{ + []int{groupIdxWithGrandChild}, + models.CriterionModifierIncludes, + -1, + }, + "", + []int{groupIdxWithParentAndChild, groupIdxWithGrandParent}, + }, + { + "includes (1 level)", + criterion{ + []int{groupIdxWithGrandChild}, + models.CriterionModifierIncludes, + 1, + }, + "", + []int{groupIdxWithParentAndChild, groupIdxWithGrandParent}, + }, + { + "is null", + criterion{ + nil, + models.CriterionModifierIsNull, + 0, + }, + getGroupStringValue(groupIdxWithParent, nameField), + nil, + }, + { + "not null", + criterion{ + nil, + models.CriterionModifierNotNull, + 0, + }, + "", + []int{groupIdxWithParentAndChild, groupIdxWithParent, groupIdxWithGrandParent, groupIdxWithParentAndScene}, + }, + } + + qb := db.Group + + for _, tt := range tests { + valueIDs := indexesToIDs(groupIDs, tt.c.valueIdxs) + expectedIDs := indexesToIDs(groupIDs, tt.includeIdxs) + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + groupFilter := &models.GroupFilterType{ + ContainingGroups: &models.HierarchicalMultiCriterionInput{ + Value: intslice.IntSliceToStringSlice(valueIDs), + Modifier: tt.c.modifier, + }, + } + + if tt.c.depth != 0 { + groupFilter.ContainingGroups.Depth = &tt.c.depth + } + + findFilter := models.FindFilterType{} + if tt.q != "" { + findFilter.Q = &tt.q + } + + groups, _, err := qb.Query(ctx, groupFilter, &findFilter) + if err != nil { + t.Errorf("GroupStore.Query() error = %v", err) + return + } + + // get ids of groups + groupIDs := sliceutil.Map(groups, func(g *models.Group) int { return g.ID }) + assert.ElementsMatch(t, expectedIDs, groupIDs) + }) + } +} + +func TestGroupQuerySubGroups(t *testing.T) { + const nameField = "Name" + + type criterion struct { + valueIdxs []int + modifier models.CriterionModifier + depth int + } + + tests := []struct { + name string + c criterion + q string + expectedIdxs []int + }{ + { + "includes", + criterion{ + []int{groupIdxWithParent}, + models.CriterionModifierIncludes, + 0, + }, + "", + []int{groupIdxWithChild}, + }, + { + "excludes", + criterion{ + []int{groupIdxWithParent}, + models.CriterionModifierExcludes, + 0, + }, + getGroupStringValue(groupIdxWithChild, nameField), + nil, + }, + { + "includes (all levels)", + criterion{ + []int{groupIdxWithGrandParent}, + models.CriterionModifierIncludes, + -1, + }, + "", + []int{groupIdxWithGrandChild, groupIdxWithParentAndChild}, + }, + { + "includes (1 level)", + criterion{ + []int{groupIdxWithGrandParent}, + models.CriterionModifierIncludes, + 1, + }, + "", + []int{groupIdxWithGrandChild, groupIdxWithParentAndChild}, + }, + { + "is null", + criterion{ + nil, + models.CriterionModifierIsNull, + 0, + }, + getGroupStringValue(groupIdxWithChild, nameField), + nil, + }, + { + "not null", + criterion{ + nil, + models.CriterionModifierNotNull, + 0, + }, + "", + []int{groupIdxWithGrandChild, groupIdxWithChild, groupIdxWithParentAndChild, groupIdxWithChildWithScene}, + }, + } + + qb := db.Group + + for _, tt := range tests { + valueIDs := indexesToIDs(groupIDs, tt.c.valueIdxs) + expectedIDs := indexesToIDs(groupIDs, tt.expectedIdxs) + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + groupFilter := &models.GroupFilterType{ + SubGroups: &models.HierarchicalMultiCriterionInput{ + Value: intslice.IntSliceToStringSlice(valueIDs), + Modifier: tt.c.modifier, + }, + } + + if tt.c.depth != 0 { + groupFilter.SubGroups.Depth = &tt.c.depth + } + + findFilter := models.FindFilterType{} + if tt.q != "" { + findFilter.Q = &tt.q + } + + groups, _, err := qb.Query(ctx, groupFilter, &findFilter) + if err != nil { + t.Errorf("GroupStore.Query() error = %v", err) + return + } + + // get ids of groups + groupIDs := sliceutil.Map(groups, func(g *models.Group) int { return g.ID }) + assert.ElementsMatch(t, expectedIDs, groupIDs) + }) + } +} + +func TestGroupQueryContainingGroupCount(t *testing.T) { + const nameField = "Name" + + tests := []struct { + name string + value int + modifier models.CriterionModifier + q string + expectedIdxs []int + }{ + { + "equals", + 1, + models.CriterionModifierEquals, + "", + []int{groupIdxWithParent, groupIdxWithGrandParent, groupIdxWithParentAndChild, groupIdxWithParentAndScene}, + }, + { + "not equals", + 1, + models.CriterionModifierNotEquals, + getGroupStringValue(groupIdxWithParent, nameField), + nil, + }, + { + "less than", + 1, + models.CriterionModifierLessThan, + getGroupStringValue(groupIdxWithParent, nameField), + nil, + }, + { + "greater than", + 0, + models.CriterionModifierGreaterThan, + "", + []int{groupIdxWithParent, groupIdxWithGrandParent, groupIdxWithParentAndChild, groupIdxWithParentAndScene}, + }, + } + + qb := db.Group + + for _, tt := range tests { + expectedIDs := indexesToIDs(groupIDs, tt.expectedIdxs) + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + groupFilter := &models.GroupFilterType{ + ContainingGroupCount: &models.IntCriterionInput{ + Value: tt.value, + Modifier: tt.modifier, + }, + } + + findFilter := models.FindFilterType{} + if tt.q != "" { + findFilter.Q = &tt.q + } + + groups, _, err := qb.Query(ctx, groupFilter, &findFilter) + if err != nil { + t.Errorf("GroupStore.Query() error = %v", err) + return + } + + // get ids of groups + groupIDs := sliceutil.Map(groups, func(g *models.Group) int { return g.ID }) + assert.ElementsMatch(t, expectedIDs, groupIDs) + }) + } +} + +func TestGroupQuerySubGroupCount(t *testing.T) { + const nameField = "Name" + + tests := []struct { + name string + value int + modifier models.CriterionModifier + q string + expectedIdxs []int + }{ + { + "equals", + 1, + models.CriterionModifierEquals, + "", + []int{groupIdxWithChild, groupIdxWithGrandChild, groupIdxWithParentAndChild, groupIdxWithChildWithScene}, + }, + { + "not equals", + 1, + models.CriterionModifierNotEquals, + getGroupStringValue(groupIdxWithChild, nameField), + nil, + }, + { + "less than", + 1, + models.CriterionModifierLessThan, + getGroupStringValue(groupIdxWithChild, nameField), + nil, + }, + { + "greater than", + 0, + models.CriterionModifierGreaterThan, + "", + []int{groupIdxWithChild, groupIdxWithGrandChild, groupIdxWithParentAndChild, groupIdxWithChildWithScene}, + }, + } + + qb := db.Group + + for _, tt := range tests { + expectedIDs := indexesToIDs(groupIDs, tt.expectedIdxs) + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + groupFilter := &models.GroupFilterType{ + SubGroupCount: &models.IntCriterionInput{ + Value: tt.value, + Modifier: tt.modifier, + }, + } + + findFilter := models.FindFilterType{} + if tt.q != "" { + findFilter.Q = &tt.q + } + + groups, _, err := qb.Query(ctx, groupFilter, &findFilter) + if err != nil { + t.Errorf("GroupStore.Query() error = %v", err) + return + } + + // get ids of groups + groupIDs := sliceutil.Map(groups, func(g *models.Group) int { return g.ID }) + assert.ElementsMatch(t, expectedIDs, groupIDs) + }) + } +} + +func TestGroupFindInAncestors(t *testing.T) { + tests := []struct { + name string + ancestorIdxs []int + idxs []int + expectedIdxs []int + }{ + { + "basic", + []int{groupIdxWithGrandParent}, + []int{groupIdxWithGrandChild}, + []int{groupIdxWithGrandChild}, + }, + { + "same", + []int{groupIdxWithScene}, + []int{groupIdxWithScene}, + []int{groupIdxWithScene}, + }, + { + "no matches", + []int{groupIdxWithGrandParent}, + []int{groupIdxWithScene}, + nil, + }, + } + + qb := db.Group + + for _, tt := range tests { + ancestorIDs := indexesToIDs(groupIDs, tt.ancestorIdxs) + ids := indexesToIDs(groupIDs, tt.idxs) + expectedIDs := indexesToIDs(groupIDs, tt.expectedIdxs) + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + found, err := qb.FindInAncestors(ctx, ancestorIDs, ids) + if err != nil { + t.Errorf("GroupStore.FindInAncestors() error = %v", err) + return + } + + // get ids of groups + assert.ElementsMatch(t, found, expectedIDs) + }) + } +} + +func TestGroupReorderSubGroups(t *testing.T) { + tests := []struct { + name string + subGroupLen int + idxsToMove []int + insertLoc int + insertAfter bool + // order of elements, using original indexes + expectedIdxs []int + }{ + { + "move single back before", + 5, + []int{2}, + 1, + false, + []int{0, 2, 1, 3, 4}, + }, + { + "move single forward before", + 5, + []int{2}, + 4, + false, + []int{0, 1, 3, 2, 4}, + }, + { + "move multiple back before", + 5, + []int{3, 2, 4}, + 0, + false, + []int{3, 2, 4, 0, 1}, + }, + { + "move multiple forward before", + 5, + []int{2, 1, 0}, + 4, + false, + []int{3, 2, 1, 0, 4}, + }, + { + "move single back after", + 5, + []int{2}, + 0, + true, + []int{0, 2, 1, 3, 4}, + }, + { + "move single forward after", + 5, + []int{2}, + 4, + true, + []int{0, 1, 3, 4, 2}, + }, + { + "move multiple back after", + 5, + []int{3, 2, 4}, + 0, + false, + []int{0, 3, 2, 4, 1}, + }, + { + "move multiple forward after", + 5, + []int{2, 1, 0}, + 4, + false, + []int{3, 4, 2, 1, 0}, + }, + } + + qb := db.Group + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + // create the group + group := models.Group{ + Name: "TestGroupReorderSubGroups", + } + + if err := qb.Create(ctx, &group); err != nil { + t.Errorf("GroupStore.Create() error = %v", err) + return + } + + // and sub-groups + idxToId := make([]int, tt.subGroupLen) + + for i := 0; i < tt.subGroupLen; i++ { + subGroup := models.Group{ + Name: fmt.Sprintf("SubGroup %d", i), + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: group.ID}, + }), + } + + if err := qb.Create(ctx, &subGroup); err != nil { + t.Errorf("GroupStore.Create() error = %v", err) + return + } + + idxToId[i] = subGroup.ID + } + + // reorder + idsToMove := indexesToIDs(idxToId, tt.idxsToMove) + insertID := idxToId[tt.insertLoc] + if err := qb.ReorderSubGroups(ctx, group.ID, idsToMove, insertID, tt.insertAfter); err != nil { + t.Errorf("GroupStore.ReorderSubGroups() error = %v", err) + return + } + + // validate the new order + gd, err := qb.GetSubGroupDescriptions(ctx, group.ID) + if err != nil { + t.Errorf("GroupStore.GetSubGroupDescriptions() error = %v", err) + return + } + + // get ids of groups + newIDs := sliceutil.Map(gd, func(gd models.GroupIDDescription) int { return gd.GroupID }) + newIdxs := sliceutil.Map(newIDs, func(id int) int { return sliceutil.Index(idxToId, id) }) + + assert.ElementsMatch(t, tt.expectedIdxs, newIdxs) + }) + } +} + +func TestGroupAddSubGroups(t *testing.T) { + tests := []struct { + name string + existingSubGroupLen int + insertGroupsLen int + insertLoc int + // order of elements, using original indexes + expectedIdxs []int + }{ + { + "append single", + 4, + 1, + 999, + []int{0, 1, 2, 3, 4}, + }, + { + "insert single middle", + 4, + 1, + 2, + []int{0, 1, 4, 2, 3}, + }, + { + "insert single start", + 4, + 1, + 0, + []int{4, 0, 1, 2, 3}, + }, + { + "append multiple", + 4, + 2, + 999, + []int{0, 1, 2, 3, 4, 5}, + }, + { + "insert multiple middle", + 4, + 2, + 2, + []int{0, 1, 4, 5, 2, 3}, + }, + { + "insert multiple start", + 4, + 2, + 0, + []int{4, 5, 0, 1, 2, 3}, + }, + } + + qb := db.Group + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + // create the group + group := models.Group{ + Name: "TestGroupReorderSubGroups", + } + + if err := qb.Create(ctx, &group); err != nil { + t.Errorf("GroupStore.Create() error = %v", err) + return + } + + // and sub-groups + idxToId := make([]int, tt.existingSubGroupLen+tt.insertGroupsLen) + + for i := 0; i < tt.existingSubGroupLen; i++ { + subGroup := models.Group{ + Name: fmt.Sprintf("Existing SubGroup %d", i), + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: group.ID}, + }), + } + + if err := qb.Create(ctx, &subGroup); err != nil { + t.Errorf("GroupStore.Create() error = %v", err) + return + } + + idxToId[i] = subGroup.ID + } + + // and sub-groups to insert + for i := 0; i < tt.insertGroupsLen; i++ { + subGroup := models.Group{ + Name: fmt.Sprintf("Inserted SubGroup %d", i), + } + + if err := qb.Create(ctx, &subGroup); err != nil { + t.Errorf("GroupStore.Create() error = %v", err) + return + } + + idxToId[i+tt.existingSubGroupLen] = subGroup.ID + } + + // convert ids to description + idDescriptions := make([]models.GroupIDDescription, tt.insertGroupsLen) + for i, id := range idxToId[tt.existingSubGroupLen:] { + idDescriptions[i] = models.GroupIDDescription{GroupID: id} + } + + // add + if err := qb.AddSubGroups(ctx, group.ID, idDescriptions, &tt.insertLoc); err != nil { + t.Errorf("GroupStore.AddSubGroups() error = %v", err) + return + } + + // validate the new order + gd, err := qb.GetSubGroupDescriptions(ctx, group.ID) + if err != nil { + t.Errorf("GroupStore.GetSubGroupDescriptions() error = %v", err) + return + } + + // get ids of groups + newIDs := sliceutil.Map(gd, func(gd models.GroupIDDescription) int { return gd.GroupID }) + newIdxs := sliceutil.Map(newIDs, func(id int) int { return sliceutil.Index(idxToId, id) }) + + assert.ElementsMatch(t, tt.expectedIdxs, newIdxs) + }) + } +} + +func TestGroupRemoveSubGroups(t *testing.T) { + tests := []struct { + name string + subGroupLen int + removeIdxs []int + // order of elements, using original indexes + expectedIdxs []int + }{ + { + "remove last", + 4, + []int{3}, + []int{0, 1, 2}, + }, + { + "remove first", + 4, + []int{0}, + []int{1, 2, 3}, + }, + { + "remove middle", + 4, + []int{2}, + []int{0, 1, 3}, + }, + { + "remove multiple", + 4, + []int{1, 3}, + []int{0, 2}, + }, + { + "remove all", + 4, + []int{0, 1, 2, 3}, + []int{}, + }, + } + + qb := db.Group + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + // create the group + group := models.Group{ + Name: "TestGroupReorderSubGroups", + } + + if err := qb.Create(ctx, &group); err != nil { + t.Errorf("GroupStore.Create() error = %v", err) + return + } + + // and sub-groups + idxToId := make([]int, tt.subGroupLen) + + for i := 0; i < tt.subGroupLen; i++ { + subGroup := models.Group{ + Name: fmt.Sprintf("Existing SubGroup %d", i), + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: group.ID}, + }), + } + + if err := qb.Create(ctx, &subGroup); err != nil { + t.Errorf("GroupStore.Create() error = %v", err) + return + } + + idxToId[i] = subGroup.ID + } + + idsToRemove := indexesToIDs(idxToId, tt.removeIdxs) + if err := qb.RemoveSubGroups(ctx, group.ID, idsToRemove); err != nil { + t.Errorf("GroupStore.RemoveSubGroups() error = %v", err) + return + } + + // validate the new order + gd, err := qb.GetSubGroupDescriptions(ctx, group.ID) + if err != nil { + t.Errorf("GroupStore.GetSubGroupDescriptions() error = %v", err) + return + } + + // get ids of groups + newIDs := sliceutil.Map(gd, func(gd models.GroupIDDescription) int { return gd.GroupID }) + newIdxs := sliceutil.Map(newIDs, func(id int) int { return sliceutil.Index(idxToId, id) }) + + assert.ElementsMatch(t, tt.expectedIdxs, newIdxs) + }) + } +} + +func TestGroupFindSubGroupIDs(t *testing.T) { + tests := []struct { + name string + containingGroupIdx int + subIdxs []int + expectedIdxs []int + }{ + { + "overlap", + groupIdxWithGrandChild, + []int{groupIdxWithParentAndChild, groupIdxWithGrandParent}, + []int{groupIdxWithParentAndChild}, + }, + { + "non-overlap", + groupIdxWithGrandChild, + []int{groupIdxWithGrandParent}, + []int{}, + }, + { + "none", + groupIdxWithScene, + []int{groupIdxWithDupName}, + []int{}, + }, + { + "invalid", + invalidID, + []int{invalidID}, + []int{}, + }, + } + + qb := db.Group + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + subIDs := indexesToIDs(groupIDs, tt.subIdxs) + + id := indexToID(groupIDs, tt.containingGroupIdx) + + found, err := qb.FindSubGroupIDs(ctx, id, subIDs) + if err != nil { + t.Errorf("GroupStore.FindSubGroupIDs() error = %v", err) + return + } + + // get ids of groups + foundIdxs := sliceutil.Map(found, func(id int) int { return sliceutil.Index(groupIDs, id) }) + + assert.ElementsMatch(t, tt.expectedIdxs, foundIdxs) + }) + } +} + // TODO Update // TODO Destroy - ensure image is destroyed // TODO Find diff --git a/pkg/sqlite/migrations/67_group_relationships.up.sql b/pkg/sqlite/migrations/67_group_relationships.up.sql new file mode 100644 index 00000000000..76ac29cc83f --- /dev/null +++ b/pkg/sqlite/migrations/67_group_relationships.up.sql @@ -0,0 +1,13 @@ +CREATE TABLE `groups_relations` ( + `containing_id` integer not null, + `sub_id` integer not null, + `order_index` integer not null, + `description` varchar(255), + primary key (`containing_id`, `sub_id`), + foreign key (`containing_id`) references `groups`(`id`) on delete cascade, + foreign key (`sub_id`) references `groups`(`id`) on delete cascade, + check (`containing_id` != `sub_id`) +); + +CREATE INDEX `index_groups_relations_sub_id` ON `groups_relations` (`sub_id`); +CREATE UNIQUE INDEX `index_groups_relations_order_index_unique` ON `groups_relations` (`containing_id`, `order_index`); diff --git a/pkg/sqlite/query.go b/pkg/sqlite/query.go index 597ab66b98f..9c09d8beaed 100644 --- a/pkg/sqlite/query.go +++ b/pkg/sqlite/query.go @@ -110,6 +110,16 @@ func (qb *queryBuilder) addArg(args ...interface{}) { qb.args = append(qb.args, args...) } +func (qb *queryBuilder) hasJoin(alias string) bool { + for _, j := range qb.joins { + if j.alias() == alias { + return true + } + } + + return false +} + func (qb *queryBuilder) join(table, as, onClause string) { newJoin := join{ table: table, diff --git a/pkg/sqlite/scene.go b/pkg/sqlite/scene.go index 9b8bd73157b..c950be4d160 100644 --- a/pkg/sqlite/scene.go +++ b/pkg/sqlite/scene.go @@ -791,13 +791,6 @@ func (qb *SceneStore) FindByGroupID(ctx context.Context, groupID int) ([]*models return ret, nil } -func (qb *SceneStore) CountByGroupID(ctx context.Context, groupID int) (int, error) { - joinTable := scenesGroupsJoinTable - - q := dialect.Select(goqu.COUNT("*")).From(joinTable).Where(joinTable.Col(groupIDColumn).Eq(groupID)) - return count(ctx, q) -} - func (qb *SceneStore) Count(ctx context.Context) (int, error) { q := dialect.Select(goqu.COUNT("*")).From(qb.table()) return count(ctx, q) @@ -858,6 +851,7 @@ func (qb *SceneStore) PlayDuration(ctx context.Context) (float64, error) { return ret, nil } +// TODO - currently only used by unit test func (qb *SceneStore) CountByStudioID(ctx context.Context, studioID int) (int, error) { table := qb.table() @@ -865,13 +859,6 @@ func (qb *SceneStore) CountByStudioID(ctx context.Context, studioID int) (int, e return count(ctx, q) } -func (qb *SceneStore) CountByTagID(ctx context.Context, tagID int) (int, error) { - joinTable := scenesTagsJoinTable - - q := dialect.Select(goqu.COUNT("*")).From(joinTable).Where(joinTable.Col(tagIDColumn).Eq(tagID)) - return count(ctx, q) -} - func (qb *SceneStore) countMissingFingerprints(ctx context.Context, fpType string) (int, error) { fpTable := fingerprintTableMgr.table.As("fingerprints_temp") diff --git a/pkg/sqlite/scene_filter.go b/pkg/sqlite/scene_filter.go index 3f2233395fa..2e63dad975f 100644 --- a/pkg/sqlite/scene_filter.go +++ b/pkg/sqlite/scene_filter.go @@ -149,7 +149,7 @@ func (qb *sceneFilterHandler) criterionHandler() criterionHandler { studioCriterionHandler(sceneTable, sceneFilter.Studios), qb.groupsCriterionHandler(sceneFilter.Groups), - qb.groupsCriterionHandler(sceneFilter.Movies), + qb.moviesCriterionHandler(sceneFilter.Movies), qb.galleriesCriterionHandler(sceneFilter.Galleries), qb.performerTagsCriterionHandler(sceneFilter.PerformerTags), @@ -483,7 +483,8 @@ func (qb *sceneFilterHandler) performerAgeCriterionHandler(performerAge *models. } } -func (qb *sceneFilterHandler) groupsCriterionHandler(movies *models.MultiCriterionInput) criterionHandlerFunc { +// legacy handler +func (qb *sceneFilterHandler) moviesCriterionHandler(movies *models.MultiCriterionInput) criterionHandlerFunc { addJoinsFunc := func(f *filterBuilder) { sceneRepository.groups.join(f, "", "scenes.id") f.addLeftJoin("groups", "", "groups_scenes.group_id = groups.id") @@ -492,6 +493,23 @@ func (qb *sceneFilterHandler) groupsCriterionHandler(movies *models.MultiCriteri return h.handler(movies) } +func (qb *sceneFilterHandler) groupsCriterionHandler(groups *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + h := joinedHierarchicalMultiCriterionHandlerBuilder{ + primaryTable: sceneTable, + foreignTable: groupTable, + foreignFK: "group_id", + + relationsTable: groupRelationsTable, + parentFK: "containing_id", + childFK: "sub_id", + joinAs: "scene_group", + joinTable: groupsScenesTable, + primaryFK: sceneIDColumn, + } + + return h.handler(groups) +} + func (qb *sceneFilterHandler) galleriesCriterionHandler(galleries *models.MultiCriterionInput) criterionHandlerFunc { addJoinsFunc := func(f *filterBuilder) { sceneRepository.galleries.join(f, "", "scenes.id") diff --git a/pkg/sqlite/scene_test.go b/pkg/sqlite/scene_test.go index 9116158fc9f..a3174d7278d 100644 --- a/pkg/sqlite/scene_test.go +++ b/pkg/sqlite/scene_test.go @@ -16,6 +16,7 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/sliceutil" + "github.com/stashapp/stash/pkg/sliceutil/intslice" "github.com/stretchr/testify/assert" ) @@ -2217,7 +2218,7 @@ func TestSceneQuery(t *testing.T) { }, }) if (err != nil) != tt.wantErr { - t.Errorf("PerformerStore.Query() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("SceneStore.Query() error = %v, wantErr %v", err, tt.wantErr) return } @@ -3873,6 +3874,100 @@ func TestSceneQueryStudioDepth(t *testing.T) { }) } +func TestSceneGroups(t *testing.T) { + type criterion struct { + valueIdxs []int + modifier models.CriterionModifier + depth int + } + + tests := []struct { + name string + c criterion + q string + includeIdxs []int + excludeIdxs []int + }{ + { + "includes", + criterion{ + []int{groupIdxWithScene}, + models.CriterionModifierIncludes, + 0, + }, + "", + []int{sceneIdxWithGroup}, + nil, + }, + { + "excludes", + criterion{ + []int{groupIdxWithScene}, + models.CriterionModifierExcludes, + 0, + }, + getSceneStringValue(sceneIdxWithGroup, titleField), + nil, + []int{sceneIdxWithGroup}, + }, + { + "includes (depth = 1)", + criterion{ + []int{groupIdxWithChildWithScene}, + models.CriterionModifierIncludes, + 1, + }, + "", + []int{sceneIdxWithGroupWithParent}, + nil, + }, + } + + for _, tt := range tests { + valueIDs := indexesToIDs(groupIDs, tt.c.valueIdxs) + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + sceneFilter := &models.SceneFilterType{ + Groups: &models.HierarchicalMultiCriterionInput{ + Value: intslice.IntSliceToStringSlice(valueIDs), + Modifier: tt.c.modifier, + }, + } + + if tt.c.depth != 0 { + sceneFilter.Groups.Depth = &tt.c.depth + } + + findFilter := &models.FindFilterType{} + if tt.q != "" { + findFilter.Q = &tt.q + } + + results, err := db.Scene.Query(ctx, models.SceneQueryOptions{ + SceneFilter: sceneFilter, + QueryOptions: models.QueryOptions{ + FindFilter: findFilter, + }, + }) + if err != nil { + t.Errorf("SceneStore.Query() error = %v", err) + return + } + + include := indexesToIDs(sceneIDs, tt.includeIdxs) + exclude := indexesToIDs(sceneIDs, tt.excludeIdxs) + + assert.Subset(results.IDs, include) + + for _, e := range exclude { + assert.NotContains(results.IDs, e) + } + }) + } +} + func TestSceneQueryMovies(t *testing.T) { withTxn(func(ctx context.Context) error { sqb := db.Scene @@ -4188,78 +4283,6 @@ func verifyScenesPerformerCount(t *testing.T, performerCountCriterion models.Int }) } -func TestSceneCountByTagID(t *testing.T) { - withTxn(func(ctx context.Context) error { - sqb := db.Scene - - sceneCount, err := sqb.CountByTagID(ctx, tagIDs[tagIdxWithScene]) - - if err != nil { - t.Errorf("error calling CountByTagID: %s", err.Error()) - } - - assert.Equal(t, 1, sceneCount) - - sceneCount, err = sqb.CountByTagID(ctx, 0) - - if err != nil { - t.Errorf("error calling CountByTagID: %s", err.Error()) - } - - assert.Equal(t, 0, sceneCount) - - return nil - }) -} - -func TestSceneCountByGroupID(t *testing.T) { - withTxn(func(ctx context.Context) error { - sqb := db.Scene - - sceneCount, err := sqb.CountByGroupID(ctx, groupIDs[groupIdxWithScene]) - - if err != nil { - t.Errorf("error calling CountByGroupID: %s", err.Error()) - } - - assert.Equal(t, 1, sceneCount) - - sceneCount, err = sqb.CountByGroupID(ctx, 0) - - if err != nil { - t.Errorf("error calling CountByGroupID: %s", err.Error()) - } - - assert.Equal(t, 0, sceneCount) - - return nil - }) -} - -func TestSceneCountByStudioID(t *testing.T) { - withTxn(func(ctx context.Context) error { - sqb := db.Scene - - sceneCount, err := sqb.CountByStudioID(ctx, studioIDs[studioIdxWithScene]) - - if err != nil { - t.Errorf("error calling CountByStudioID: %s", err.Error()) - } - - assert.Equal(t, 1, sceneCount) - - sceneCount, err = sqb.CountByStudioID(ctx, 0) - - if err != nil { - t.Errorf("error calling CountByStudioID: %s", err.Error()) - } - - assert.Equal(t, 0, sceneCount) - - return nil - }) -} - func TestFindByMovieID(t *testing.T) { withTxn(func(ctx context.Context) error { sqb := db.Scene diff --git a/pkg/sqlite/setup_test.go b/pkg/sqlite/setup_test.go index aa6af73c4c9..624ffb4e222 100644 --- a/pkg/sqlite/setup_test.go +++ b/pkg/sqlite/setup_test.go @@ -78,6 +78,7 @@ const ( sceneIdxWithGrandChildStudio sceneIdxMissingPhash sceneIdxWithPerformerParentTag + sceneIdxWithGroupWithParent // new indexes above lastSceneIdx @@ -153,9 +154,15 @@ const ( groupIdxWithTag groupIdxWithTwoTags groupIdxWithThreeTags + groupIdxWithGrandChild + groupIdxWithChild + groupIdxWithParentAndChild + groupIdxWithParent + groupIdxWithGrandParent + groupIdxWithParentAndScene + groupIdxWithChildWithScene // groups with dup names start from the end - // create 7 more basic groups (can remove this if we add more indexes) - groupIdxWithDupName = groupIdxWithStudio + 7 + groupIdxWithDupName groupsNameCase = groupIdxWithDupName groupsNameNoCase = 1 @@ -390,7 +397,8 @@ var ( } sceneGroups = linkMap{ - sceneIdxWithGroup: {groupIdxWithScene}, + sceneIdxWithGroup: {groupIdxWithScene}, + sceneIdxWithGroupWithParent: {groupIdxWithParentAndScene}, } sceneStudios = map[int]int{ @@ -541,15 +549,31 @@ var ( } ) +var ( + groupParentLinks = [][2]int{ + {groupIdxWithChild, groupIdxWithParent}, + {groupIdxWithGrandChild, groupIdxWithParentAndChild}, + {groupIdxWithParentAndChild, groupIdxWithGrandParent}, + {groupIdxWithChildWithScene, groupIdxWithParentAndScene}, + } +) + func indexesToIDs(ids []int, indexes []int) []int { ret := make([]int, len(indexes)) for i, idx := range indexes { - ret[i] = ids[idx] + ret[i] = indexToID(ids, idx) } return ret } +func indexToID(ids []int, idx int) int { + if idx < 0 { + return invalidID + } + return ids[idx] +} + func indexFromID(ids []int, id int) int { for i, v := range ids { if v == id { @@ -697,6 +721,10 @@ func populateDB() error { return fmt.Errorf("error linking tags parent: %s", err.Error()) } + if err := linkGroupsParent(ctx, db.Group); err != nil { + return fmt.Errorf("error linking tags parent: %s", err.Error()) + } + for _, ms := range markerSpecs { if err := createMarker(ctx, db.SceneMarker, ms); err != nil { return fmt.Errorf("error creating scene marker: %s", err.Error()) @@ -1885,6 +1913,24 @@ func linkTagsParent(ctx context.Context, qb models.TagReaderWriter) error { }) } +func linkGroupsParent(ctx context.Context, qb models.GroupReaderWriter) error { + return doLinks(groupParentLinks, func(parentIndex, childIndex int) error { + groupID := groupIDs[childIndex] + + p := models.GroupPartial{ + ContainingGroups: &models.UpdateGroupDescriptions{ + Groups: []models.GroupIDDescription{ + {GroupID: groupIDs[parentIndex]}, + }, + Mode: models.RelationshipUpdateModeAdd, + }, + } + + _, err := qb.UpdatePartial(ctx, groupID, p) + return err + }) +} + func addTagImage(ctx context.Context, qb models.TagWriter, tagIndex int) error { return qb.UpdateImage(ctx, tagIDs[tagIndex], []byte("image")) } diff --git a/pkg/sqlite/tables.go b/pkg/sqlite/tables.go index 365abe81292..74a5ebe698c 100644 --- a/pkg/sqlite/tables.go +++ b/pkg/sqlite/tables.go @@ -37,8 +37,9 @@ var ( studiosTagsJoinTable = goqu.T(studiosTagsTable) studiosStashIDsJoinTable = goqu.T("studio_stash_ids") - groupsURLsJoinTable = goqu.T(groupURLsTable) - groupsTagsJoinTable = goqu.T(groupsTagsTable) + groupsURLsJoinTable = goqu.T(groupURLsTable) + groupsTagsJoinTable = goqu.T(groupsTagsTable) + groupRelationsJoinTable = goqu.T(groupRelationsTable) tagsAliasesJoinTable = goqu.T(tagAliasesTable) tagRelationsJoinTable = goqu.T(tagRelationsTable) @@ -361,6 +362,10 @@ var ( foreignTable: tagTableMgr, orderBy: tagTableMgr.table.Col("name").Asc(), } + + groupRelationshipTableMgr = &table{ + table: groupRelationsJoinTable, + } ) var ( diff --git a/pkg/sqlite/tag_filter.go b/pkg/sqlite/tag_filter.go index 26e33c36c6e..ba9e9bb08ec 100644 --- a/pkg/sqlite/tag_filter.go +++ b/pkg/sqlite/tag_filter.go @@ -2,7 +2,6 @@ package sqlite import ( "context" - "fmt" "github.com/stashapp/stash/pkg/models" ) @@ -51,6 +50,14 @@ func (qb *tagFilterHandler) handle(ctx context.Context, f *filterBuilder) { f.handleCriterion(ctx, qb.criterionHandler()) } +var tagHierarchyHandler = hierarchicalRelationshipHandler{ + primaryTable: tagTable, + relationTable: tagRelationsTable, + aliasPrefix: tagTable, + parentIDCol: "parent_id", + childIDCol: "child_id", +} + func (qb *tagFilterHandler) criterionHandler() criterionHandler { tagFilter := qb.tagFilter return compoundHandler{ @@ -72,10 +79,10 @@ func (qb *tagFilterHandler) criterionHandler() criterionHandler { qb.groupCountCriterionHandler(tagFilter.MovieCount), qb.markerCountCriterionHandler(tagFilter.MarkerCount), - qb.parentsCriterionHandler(tagFilter.Parents), - qb.childrenCriterionHandler(tagFilter.Children), - qb.parentCountCriterionHandler(tagFilter.ParentCount), - qb.childCountCriterionHandler(tagFilter.ChildCount), + tagHierarchyHandler.ParentsCriterionHandler(tagFilter.Parents), + tagHierarchyHandler.ChildrenCriterionHandler(tagFilter.Children), + tagHierarchyHandler.ParentCountCriterionHandler(tagFilter.ParentCount), + tagHierarchyHandler.ChildCountCriterionHandler(tagFilter.ChildCount), ×tampCriterionHandler{tagFilter.CreatedAt, "tags.created_at", nil}, ×tampCriterionHandler{tagFilter.UpdatedAt, "tags.updated_at", nil}, @@ -212,213 +219,3 @@ func (qb *tagFilterHandler) markerCountCriterionHandler(markerCount *models.IntC } } } - -func (qb *tagFilterHandler) parentsCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { - return func(ctx context.Context, f *filterBuilder) { - if criterion != nil { - tags := criterion.CombineExcludes() - - // validate the modifier - switch tags.Modifier { - case models.CriterionModifierIncludesAll, models.CriterionModifierIncludes, models.CriterionModifierExcludes, models.CriterionModifierIsNull, models.CriterionModifierNotNull: - // valid - default: - f.setError(fmt.Errorf("invalid modifier %s for tag parent/children", criterion.Modifier)) - } - - if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull { - var notClause string - if tags.Modifier == models.CriterionModifierNotNull { - notClause = "NOT" - } - - f.addLeftJoin("tags_relations", "parent_relations", "tags.id = parent_relations.child_id") - - f.addWhere(fmt.Sprintf("parent_relations.parent_id IS %s NULL", notClause)) - return - } - - if len(tags.Value) == 0 && len(tags.Excludes) == 0 { - return - } - - if len(tags.Value) > 0 { - var args []interface{} - for _, val := range tags.Value { - args = append(args, val) - } - - depthVal := 0 - if tags.Depth != nil { - depthVal = *tags.Depth - } - - var depthCondition string - if depthVal != -1 { - depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal) - } - - query := `parents AS ( - SELECT parent_id AS root_id, child_id AS item_id, 0 AS depth FROM tags_relations WHERE parent_id IN` + getInBinding(len(tags.Value)) + ` - UNION - SELECT root_id, child_id, depth + 1 FROM tags_relations INNER JOIN parents ON item_id = parent_id ` + depthCondition + ` - )` - - f.addRecursiveWith(query, args...) - - f.addLeftJoin("parents", "", "parents.item_id = tags.id") - - addHierarchicalConditionClauses(f, tags, "parents", "root_id") - } - - if len(tags.Excludes) > 0 { - var args []interface{} - for _, val := range tags.Excludes { - args = append(args, val) - } - - depthVal := 0 - if tags.Depth != nil { - depthVal = *tags.Depth - } - - var depthCondition string - if depthVal != -1 { - depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal) - } - - query := `parents2 AS ( - SELECT parent_id AS root_id, child_id AS item_id, 0 AS depth FROM tags_relations WHERE parent_id IN` + getInBinding(len(tags.Excludes)) + ` - UNION - SELECT root_id, child_id, depth + 1 FROM tags_relations INNER JOIN parents2 ON item_id = parent_id ` + depthCondition + ` - )` - - f.addRecursiveWith(query, args...) - - f.addLeftJoin("parents2", "", "parents2.item_id = tags.id") - - addHierarchicalConditionClauses(f, models.HierarchicalMultiCriterionInput{ - Value: tags.Excludes, - Depth: tags.Depth, - Modifier: models.CriterionModifierExcludes, - }, "parents2", "root_id") - } - } - } -} - -func (qb *tagFilterHandler) childrenCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { - return func(ctx context.Context, f *filterBuilder) { - if criterion != nil { - tags := criterion.CombineExcludes() - - // validate the modifier - switch tags.Modifier { - case models.CriterionModifierIncludesAll, models.CriterionModifierIncludes, models.CriterionModifierExcludes, models.CriterionModifierIsNull, models.CriterionModifierNotNull: - // valid - default: - f.setError(fmt.Errorf("invalid modifier %s for tag parent/children", criterion.Modifier)) - } - - if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull { - var notClause string - if tags.Modifier == models.CriterionModifierNotNull { - notClause = "NOT" - } - - f.addLeftJoin("tags_relations", "child_relations", "tags.id = child_relations.parent_id") - - f.addWhere(fmt.Sprintf("child_relations.child_id IS %s NULL", notClause)) - return - } - - if len(tags.Value) == 0 && len(tags.Excludes) == 0 { - return - } - - if len(tags.Value) > 0 { - var args []interface{} - for _, val := range tags.Value { - args = append(args, val) - } - - depthVal := 0 - if tags.Depth != nil { - depthVal = *tags.Depth - } - - var depthCondition string - if depthVal != -1 { - depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal) - } - - query := `children AS ( - SELECT child_id AS root_id, parent_id AS item_id, 0 AS depth FROM tags_relations WHERE child_id IN` + getInBinding(len(tags.Value)) + ` - UNION - SELECT root_id, parent_id, depth + 1 FROM tags_relations INNER JOIN children ON item_id = child_id ` + depthCondition + ` - )` - - f.addRecursiveWith(query, args...) - - f.addLeftJoin("children", "", "children.item_id = tags.id") - - addHierarchicalConditionClauses(f, tags, "children", "root_id") - } - - if len(tags.Excludes) > 0 { - var args []interface{} - for _, val := range tags.Excludes { - args = append(args, val) - } - - depthVal := 0 - if tags.Depth != nil { - depthVal = *tags.Depth - } - - var depthCondition string - if depthVal != -1 { - depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal) - } - - query := `children2 AS ( - SELECT child_id AS root_id, parent_id AS item_id, 0 AS depth FROM tags_relations WHERE child_id IN` + getInBinding(len(tags.Excludes)) + ` - UNION - SELECT root_id, parent_id, depth + 1 FROM tags_relations INNER JOIN children2 ON item_id = child_id ` + depthCondition + ` - )` - - f.addRecursiveWith(query, args...) - - f.addLeftJoin("children2", "", "children2.item_id = tags.id") - - addHierarchicalConditionClauses(f, models.HierarchicalMultiCriterionInput{ - Value: tags.Excludes, - Depth: tags.Depth, - Modifier: models.CriterionModifierExcludes, - }, "children2", "root_id") - } - } - } -} - -func (qb *tagFilterHandler) parentCountCriterionHandler(parentCount *models.IntCriterionInput) criterionHandlerFunc { - return func(ctx context.Context, f *filterBuilder) { - if parentCount != nil { - f.addLeftJoin("tags_relations", "parents_count", "parents_count.child_id = tags.id") - clause, args := getIntCriterionWhereClause("count(distinct parents_count.parent_id)", *parentCount) - - f.addHaving(clause, args...) - } - } -} - -func (qb *tagFilterHandler) childCountCriterionHandler(childCount *models.IntCriterionInput) criterionHandlerFunc { - return func(ctx context.Context, f *filterBuilder) { - if childCount != nil { - f.addLeftJoin("tags_relations", "children_count", "children_count.parent_id = tags.id") - clause, args := getIntCriterionWhereClause("count(distinct children_count.child_id)", *childCount) - - f.addHaving(clause, args...) - } - } -} diff --git a/pkg/sqlite/tag_test.go b/pkg/sqlite/tag_test.go index f673567f82e..5359be78517 100644 --- a/pkg/sqlite/tag_test.go +++ b/pkg/sqlite/tag_test.go @@ -712,7 +712,7 @@ func TestTagQueryParent(t *testing.T) { assert.Len(t, tags, 1) // ensure id is correct - assert.Equal(t, sceneIDs[tagIdxWithParentTag], tags[0].ID) + assert.Equal(t, tagIDs[tagIdxWithParentTag], tags[0].ID) tagCriterion.Modifier = models.CriterionModifierExcludes diff --git a/ui/v2.5/graphql/data/group.graphql b/ui/v2.5/graphql/data/group.graphql index 60f55e30948..963e8d6e672 100644 --- a/ui/v2.5/graphql/data/group.graphql +++ b/ui/v2.5/graphql/data/group.graphql @@ -15,11 +15,21 @@ fragment GroupData on Group { ...SlimTagData } + containing_groups { + group { + ...SlimGroupData + } + description + } + synopsis urls front_image_path back_image_path scene_count + scene_count_all: scene_count(depth: -1) + sub_group_count + sub_group_count_all: sub_group_count(depth: -1) scenes { id diff --git a/ui/v2.5/graphql/mutations/group.graphql b/ui/v2.5/graphql/mutations/group.graphql index fb739e84009..8065e4adbb7 100644 --- a/ui/v2.5/graphql/mutations/group.graphql +++ b/ui/v2.5/graphql/mutations/group.graphql @@ -23,3 +23,15 @@ mutation GroupDestroy($id: ID!) { mutation GroupsDestroy($ids: [ID!]!) { groupsDestroy(ids: $ids) } + +mutation AddGroupSubGroups($input: GroupSubGroupAddInput!) { + addGroupSubGroups(input: $input) +} + +mutation RemoveGroupSubGroups($input: GroupSubGroupRemoveInput!) { + removeGroupSubGroups(input: $input) +} + +mutation ReorderSubGroups($input: ReorderSubGroupsInput!) { + reorderSubGroups(input: $input) +} diff --git a/ui/v2.5/src/components/Groups/ContainingGroupsMultiSet.tsx b/ui/v2.5/src/components/Groups/ContainingGroupsMultiSet.tsx new file mode 100644 index 00000000000..25ad0be4a3e --- /dev/null +++ b/ui/v2.5/src/components/Groups/ContainingGroupsMultiSet.tsx @@ -0,0 +1,61 @@ +import React from "react"; +import * as GQL from "src/core/generated-graphql"; +import { MultiSetModeButtons } from "../Shared/MultiSet"; +import { + IRelatedGroupEntry, + RelatedGroupTable, +} from "./GroupDetails/RelatedGroupTable"; +import { Group, GroupSelect } from "./GroupSelect"; + +export const ContainingGroupsMultiSet: React.FC<{ + existingValue?: IRelatedGroupEntry[]; + value: IRelatedGroupEntry[]; + mode: GQL.BulkUpdateIdMode; + disabled?: boolean; + onUpdate: (value: IRelatedGroupEntry[]) => void; + onSetMode: (mode: GQL.BulkUpdateIdMode) => void; +}> = (props) => { + const { mode, onUpdate, existingValue } = props; + + function onSetMode(m: GQL.BulkUpdateIdMode) { + if (m === mode) { + return; + } + + // if going to Set, set the existing ids + if (m === GQL.BulkUpdateIdMode.Set && existingValue) { + onUpdate(existingValue); + // if going from Set, wipe the ids + } else if ( + m !== GQL.BulkUpdateIdMode.Set && + mode === GQL.BulkUpdateIdMode.Set + ) { + onUpdate([]); + } + + props.onSetMode(m); + } + + function onRemoveSet(items: Group[]) { + onUpdate(items.map((group) => ({ group }))); + } + + return ( +
+ + {mode !== GQL.BulkUpdateIdMode.Remove ? ( + + ) : ( + onRemoveSet(items)} + values={[]} + isDisabled={props.disabled} + /> + )} +
+ ); +}; diff --git a/ui/v2.5/src/components/Groups/EditGroupsDialog.tsx b/ui/v2.5/src/components/Groups/EditGroupsDialog.tsx index 5e0360d6964..d404ccf9c33 100644 --- a/ui/v2.5/src/components/Groups/EditGroupsDialog.tsx +++ b/ui/v2.5/src/components/Groups/EditGroupsDialog.tsx @@ -9,6 +9,7 @@ import { useToast } from "src/hooks/Toast"; import * as FormUtils from "src/utils/form"; import { RatingSystem } from "../Shared/Rating/RatingSystem"; import { + getAggregateIds, getAggregateInputIDs, getAggregateInputValue, getAggregateRating, @@ -18,12 +19,54 @@ import { import { faPencilAlt } from "@fortawesome/free-solid-svg-icons"; import { isEqual } from "lodash-es"; import { MultiSet } from "../Shared/MultiSet"; +import { ContainingGroupsMultiSet } from "./ContainingGroupsMultiSet"; +import { IRelatedGroupEntry } from "./GroupDetails/RelatedGroupTable"; interface IListOperationProps { selected: GQL.GroupDataFragment[]; onClose: (applied: boolean) => void; } +export function getAggregateContainingGroups( + state: Pick[] +) { + const sortedLists: IRelatedGroupEntry[][] = state.map((o) => + o.containing_groups + .map((oo) => ({ + group: oo.group, + description: oo.description, + })) + .sort((a, b) => a.group.id.localeCompare(b.group.id)) + ); + + return getAggregateIds(sortedLists); +} + +function getAggregateContainingGroupInput( + mode: GQL.BulkUpdateIdMode, + input: IRelatedGroupEntry[] | undefined, + aggregateValues: IRelatedGroupEntry[] +): GQL.BulkUpdateGroupDescriptionsInput | undefined { + if (mode === GQL.BulkUpdateIdMode.Set && (!input || input.length === 0)) { + // and all scenes have the same ids, + if (aggregateValues.length > 0) { + // then unset, otherwise ignore + return { mode, groups: [] }; + } + } else { + // if input non-empty, then we are setting them + return { + mode, + groups: + input?.map((e) => { + return { group_id: e.group.id, description: e.description }; + }) || [], + }; + } + + return undefined; +} + export const EditGroupsDialog: React.FC = ( props: IListOperationProps ) => { @@ -39,6 +82,12 @@ export const EditGroupsDialog: React.FC = ( const [tagIds, setTagIds] = useState(); const [existingTagIds, setExistingTagIds] = useState(); + const [containingGroupsMode, setGroupMode] = + React.useState(GQL.BulkUpdateIdMode.Add); + const [containingGroups, setGroups] = useState(); + const [existingContainingGroups, setExistingContainingGroups] = + useState(); + const [updateGroups] = useBulkGroupUpdate(getGroupInput()); const [isUpdating, setIsUpdating] = useState(false); @@ -47,17 +96,23 @@ export const EditGroupsDialog: React.FC = ( const aggregateRating = getAggregateRating(props.selected); const aggregateStudioId = getAggregateStudioId(props.selected); const aggregateTagIds = getAggregateTagIds(props.selected); + const aggregateGroups = getAggregateContainingGroups(props.selected); const groupInput: GQL.BulkGroupUpdateInput = { ids: props.selected.map((group) => group.id), director, }; - // if rating is undefined groupInput.rating100 = getAggregateInputValue(rating100, aggregateRating); groupInput.studio_id = getAggregateInputValue(studioId, aggregateStudioId); groupInput.tag_ids = getAggregateInputIDs(tagMode, tagIds, aggregateTagIds); + groupInput.containing_groups = getAggregateContainingGroupInput( + containingGroupsMode, + containingGroups, + aggregateGroups + ); + return groupInput; } @@ -85,17 +140,22 @@ export const EditGroupsDialog: React.FC = ( let updateRating: number | undefined; let updateStudioId: string | undefined; let updateTagIds: string[] = []; + let updateContainingGroupIds: IRelatedGroupEntry[] = []; let updateDirector: string | undefined; let first = true; state.forEach((group: GQL.GroupDataFragment) => { const groupTagIDs = (group.tags ?? []).map((p) => p.id).sort(); + const groupContainingGroupIDs = (group.containing_groups ?? []).sort( + (a, b) => a.group.id.localeCompare(b.group.id) + ); if (first) { first = false; updateRating = group.rating100 ?? undefined; updateStudioId = group.studio?.id ?? undefined; updateTagIds = groupTagIDs; + updateContainingGroupIds = groupContainingGroupIDs; updateDirector = group.director ?? undefined; } else { if (group.rating100 !== updateRating) { @@ -110,12 +170,16 @@ export const EditGroupsDialog: React.FC = ( if (!isEqual(groupTagIDs, updateTagIds)) { updateTagIds = []; } + if (!isEqual(groupContainingGroupIDs, updateContainingGroupIds)) { + updateTagIds = []; + } } }); setRating(updateRating); setStudioId(updateStudioId); setExistingTagIds(updateTagIds); + setExistingContainingGroups(updateContainingGroupIds); setDirector(updateDirector); }, [props.selected]); @@ -166,6 +230,19 @@ export const EditGroupsDialog: React.FC = ( /> + + + + + setGroups(v)} + onSetMode={(newMode) => setGroupMode(newMode)} + existingValue={existingContainingGroups ?? []} + value={containingGroups ?? []} + mode={containingGroupsMode} + /> + diff --git a/ui/v2.5/src/components/Groups/GroupCard.tsx b/ui/v2.5/src/components/Groups/GroupCard.tsx index ff84262533f..85fa9ed077e 100644 --- a/ui/v2.5/src/components/Groups/GroupCard.tsx +++ b/ui/v2.5/src/components/Groups/GroupCard.tsx @@ -1,4 +1,4 @@ -import React, { useEffect, useState } from "react"; +import React, { useEffect, useMemo, useState } from "react"; import { Button, ButtonGroup } from "react-bootstrap"; import * as GQL from "src/core/generated-graphql"; import { GridCard, calculateCardWidth } from "../Shared/GridCard/GridCard"; @@ -10,26 +10,66 @@ import { FormattedMessage } from "react-intl"; import { RatingBanner } from "../Shared/RatingBanner"; import { faPlayCircle, faTag } from "@fortawesome/free-solid-svg-icons"; import ScreenUtils from "src/utils/screen"; +import { RelatedGroupPopoverButton } from "./RelatedGroupPopover"; + +const Description: React.FC<{ + sceneNumber?: number; + description?: string; +}> = ({ sceneNumber, description }) => { + if (!sceneNumber && !description) return null; + + return ( + <> +
+ {sceneNumber !== undefined && ( + + #{sceneNumber} + + )} + {description !== undefined && ( + + {description} + + )} + + ); +}; interface IProps { group: GQL.GroupDataFragment; containerWidth?: number; - sceneIndex?: number; + sceneNumber?: number; selecting?: boolean; selected?: boolean; onSelectedChanged?: (selected: boolean, shiftKey: boolean) => void; + fromGroupId?: string; + onMove?: (srcIds: string[], targetId: string, after: boolean) => void; } export const GroupCard: React.FC = ({ group, - sceneIndex, + sceneNumber, containerWidth, selecting, selected, onSelectedChanged, + fromGroupId, + onMove, }) => { const [cardWidth, setCardWidth] = useState(); + const groupDescription = useMemo(() => { + if (!fromGroupId) { + return undefined; + } + + const containingGroup = group.containing_groups.find( + (cg) => cg.group.id === fromGroupId + ); + + return containingGroup?.description ?? undefined; + }, [fromGroupId, group.containing_groups]); + useEffect(() => { if (!containerWidth || ScreenUtils.isMobile()) return; @@ -41,19 +81,6 @@ export const GroupCard: React.FC = ({ setCardWidth(fittedCardWidth); }, [containerWidth]); - function maybeRenderSceneNumber() { - if (!sceneIndex) return; - - return ( - <> -
- - #{sceneIndex} - - - ); - } - function maybeRenderScenesPopoverButton() { if (group.scenes.length === 0) return; @@ -93,14 +120,28 @@ export const GroupCard: React.FC = ({ } function maybeRenderPopoverButtonGroup() { - if (sceneIndex || group.scenes.length > 0 || group.tags.length > 0) { + if ( + sceneNumber || + groupDescription || + group.scenes.length > 0 || + group.tags.length > 0 || + group.containing_groups.length > 0 || + group.sub_group_count > 0 + ) { return ( <> - {maybeRenderSceneNumber()} +
{maybeRenderScenesPopoverButton()} {maybeRenderTagPopoverButton()} + {(group.sub_group_count > 0 || + group.containing_groups.length > 0) && ( + + )} ); @@ -110,6 +151,8 @@ export const GroupCard: React.FC = ({ return ( ; onSelectChange: (id: string, selected: boolean, shiftKey: boolean) => void; + fromGroupId?: string; + onMove?: (srcIds: string[], targetId: string, after: boolean) => void; } export const GroupCardGrid: React.FC = ({ groups, selectedIds, onSelectChange, + fromGroupId, + onMove, }) => { const [componentRef, { width }] = useContainerDimensions(); return ( @@ -27,6 +31,8 @@ export const GroupCardGrid: React.FC = ({ onSelectedChanged={(selected: boolean, shiftKey: boolean) => onSelectChange(p.id, selected, shiftKey) } + fromGroupId={fromGroupId} + onMove={onMove} /> ))} diff --git a/ui/v2.5/src/components/Groups/GroupDetails/AddGroupsDialog.tsx b/ui/v2.5/src/components/Groups/GroupDetails/AddGroupsDialog.tsx new file mode 100644 index 00000000000..b893568101d --- /dev/null +++ b/ui/v2.5/src/components/Groups/GroupDetails/AddGroupsDialog.tsx @@ -0,0 +1,121 @@ +import React, { useCallback, useMemo, useState } from "react"; +import { Form } from "react-bootstrap"; +import { useIntl } from "react-intl"; +import * as GQL from "src/core/generated-graphql"; +import { useToast } from "src/hooks/Toast"; +import { faPlus } from "@fortawesome/free-solid-svg-icons"; +import { RelatedGroupTable, IRelatedGroupEntry } from "./RelatedGroupTable"; +import { ModalComponent } from "src/components/Shared/Modal"; +import { useAddSubGroups } from "src/core/StashService"; +import { ListFilterModel } from "src/models/list-filter/filter"; +import { + ContainingGroupsCriterionOption, + GroupsCriterion, +} from "src/models/list-filter/criteria/groups"; + +interface IListOperationProps { + containingGroup: GQL.GroupDataFragment; + onClose: (applied: boolean) => void; +} + +export const AddSubGroupsDialog: React.FC = ( + props: IListOperationProps +) => { + const intl = useIntl(); + const [isUpdating, setIsUpdating] = useState(false); + + const addSubGroups = useAddSubGroups(); + + const Toast = useToast(); + + const [entries, setEntries] = useState([]); + + const excludeIDs = useMemo( + () => [ + ...props.containingGroup.containing_groups.map((m) => m.group.id), + props.containingGroup.id, + ], + [props.containingGroup] + ); + + const filterHook = useCallback( + (f: ListFilterModel) => { + const groupValue = { + id: props.containingGroup.id, + label: props.containingGroup.name, + }; + + // filter out sub groups that are already in the containing group + const criterion = new GroupsCriterion(ContainingGroupsCriterionOption); + criterion.value = { + items: [groupValue], + depth: 1, + excluded: [], + }; + criterion.modifier = GQL.CriterionModifier.Excludes; + f.criteria.push(criterion); + + return f; + }, + [props.containingGroup] + ); + + const onSave = async () => { + setIsUpdating(true); + try { + // add the sub groups + await addSubGroups( + props.containingGroup.id, + entries.map((m) => ({ + group_id: m.group.id, + description: m.description, + })) + ); + + const imageCount = entries.length; + Toast.success( + intl.formatMessage( + { id: "toast.added_entity" }, + { + count: imageCount, + singularEntity: intl.formatMessage({ id: "group" }), + pluralEntity: intl.formatMessage({ id: "groups" }), + } + ) + ); + + props.onClose(true); + } catch (err) { + Toast.error(err); + } finally { + setIsUpdating(false); + } + }; + + return ( + props.onClose(false), + text: intl.formatMessage({ id: "actions.cancel" }), + variant: "secondary", + }} + isRunning={isUpdating} + > +
+ setEntries(input)} + excludeIDs={excludeIDs} + filterHook={filterHook} + /> + +
+ ); +}; diff --git a/ui/v2.5/src/components/Groups/GroupDetails/Group.tsx b/ui/v2.5/src/components/Groups/GroupDetails/Group.tsx index 7ef2ca8e5fd..0aa4dbd5472 100644 --- a/ui/v2.5/src/components/Groups/GroupDetails/Group.tsx +++ b/ui/v2.5/src/components/Groups/GroupDetails/Group.tsx @@ -9,7 +9,7 @@ import { useGroupUpdate, useGroupDestroy, } from "src/core/StashService"; -import { useHistory, RouteComponentProps } from "react-router-dom"; +import { useHistory, RouteComponentProps, Redirect } from "react-router-dom"; import { DetailsEditNavbar } from "src/components/Shared/DetailsEditNavbar"; import { ErrorMessage } from "src/components/Shared/ErrorMessage"; import { LoadingIndicator } from "src/components/Shared/LoadingIndicator"; @@ -35,16 +35,89 @@ import { ExpandCollapseButton } from "src/components/Shared/CollapseButton"; import { AliasList } from "src/components/Shared/DetailsPage/AliasList"; import { HeaderImage } from "src/components/Shared/DetailsPage/HeaderImage"; import { LightboxLink } from "src/hooks/Lightbox/LightboxLink"; +import { + TabTitleCounter, + useTabKey, +} from "src/components/Shared/DetailsPage/Tabs"; +import { Tab, Tabs } from "react-bootstrap"; +import { GroupSubGroupsPanel } from "./GroupSubGroupsPanel"; + +const validTabs = ["default", "scenes", "subgroups"] as const; +type TabKey = (typeof validTabs)[number]; + +function isTabKey(tab: string): tab is TabKey { + return validTabs.includes(tab as TabKey); +} + +const GroupTabs: React.FC<{ + tabKey?: TabKey; + group: GQL.GroupDataFragment; + abbreviateCounter: boolean; +}> = ({ tabKey, group, abbreviateCounter }) => { + const { scene_count: sceneCount, sub_group_count: groupCount } = group; + + const populatedDefaultTab = useMemo(() => { + if (sceneCount == 0 && groupCount !== 0) { + return "subgroups"; + } + + return "scenes"; + }, [sceneCount, groupCount]); + + const { setTabKey } = useTabKey({ + tabKey, + validTabs, + defaultTabKey: populatedDefaultTab, + baseURL: `/groups/${group.id}`, + }); + + return ( + + + } + > + + + + } + > + + + + ); +}; interface IProps { group: GQL.GroupDataFragment; + tabKey?: TabKey; } interface IGroupParams { id: string; + tab?: string; } -const GroupPage: React.FC = ({ group }) => { +const GroupPage: React.FC = ({ group, tabKey }) => { const intl = useIntl(); const history = useHistory(); const Toast = useToast(); @@ -55,6 +128,7 @@ const GroupPage: React.FC = ({ group }) => { const enableBackgroundImage = uiConfig?.enableMovieBackgroundImage ?? false; const compactExpandedDetails = uiConfig?.compactExpandedDetails ?? false; const showAllDetails = uiConfig?.showAllDetails ?? true; + const abbreviateCounter = uiConfig?.abbreviateCounters ?? false; const [collapsed, setCollapsed] = useState(!showAllDetails); const loadStickyHeader = useLoadStickyHeader(); @@ -230,14 +304,6 @@ const GroupPage: React.FC = ({ group }) => { } } - const renderTabs = () => ; - - function maybeRenderTab() { - if (!isEditing) { - return renderTabs(); - } - } - if (updating || deleting) return ; const headerClassName = cx("detail-header", { @@ -335,7 +401,15 @@ const GroupPage: React.FC = ({ group }) => {
-
{maybeRenderTab()}
+
+ {!isEditing && ( + + )} +
{renderDeleteAlert()} @@ -344,19 +418,33 @@ const GroupPage: React.FC = ({ group }) => { }; const GroupLoader: React.FC> = ({ + location, match, }) => { - const { id } = match.params; + const { id, tab } = match.params; const { data, loading, error } = useFindGroup(id); useScrollToTopOnMount(); + if (tab && !isTabKey(tab)) { + return ( + + ); + } + if (loading) return ; if (error) return ; if (!data?.findGroup) return ; - return ; + return ( + + ); }; export default GroupLoader; diff --git a/ui/v2.5/src/components/Groups/GroupDetails/GroupDetailsPanel.tsx b/ui/v2.5/src/components/Groups/GroupDetails/GroupDetailsPanel.tsx index eb3696550fe..6a20eb9081a 100644 --- a/ui/v2.5/src/components/Groups/GroupDetails/GroupDetailsPanel.tsx +++ b/ui/v2.5/src/components/Groups/GroupDetails/GroupDetailsPanel.tsx @@ -5,7 +5,28 @@ import TextUtils from "src/utils/text"; import { DetailItem } from "src/components/Shared/DetailItem"; import { Link } from "react-router-dom"; import { DirectorLink } from "src/components/Shared/Link"; -import { TagLink } from "src/components/Shared/TagLink"; +import { GroupLink, TagLink } from "src/components/Shared/TagLink"; + +interface IGroupDescription { + group: GQL.SlimGroupDataFragment; + description?: string | null; +} + +const GroupsList: React.FC<{ groups: IGroupDescription[] }> = ({ groups }) => { + if (!groups.length) { + return null; + } + + return ( +
    + {groups.map((entry) => ( +
  • + +
  • + ))} +
+ ); +}; interface IGroupDetailsPanel { group: GQL.GroupDataFragment; @@ -48,6 +69,13 @@ export const GroupDetailsPanel: React.FC = ({ value={renderTagsField()} fullWidth={fullWidth} /> + {group.containing_groups.length > 0 && ( + } + fullWidth={fullWidth} + /> + )} ); } diff --git a/ui/v2.5/src/components/Groups/GroupDetails/GroupEditPanel.tsx b/ui/v2.5/src/components/Groups/GroupDetails/GroupEditPanel.tsx index a3ccf5b8b33..0b94baf2791 100644 --- a/ui/v2.5/src/components/Groups/GroupDetails/GroupEditPanel.tsx +++ b/ui/v2.5/src/components/Groups/GroupDetails/GroupEditPanel.tsx @@ -1,4 +1,4 @@ -import React, { useEffect, useState } from "react"; +import React, { useEffect, useMemo, useState } from "react"; import { FormattedMessage, useIntl } from "react-intl"; import * as GQL from "src/core/generated-graphql"; import * as yup from "yup"; @@ -26,6 +26,8 @@ import { } from "src/utils/yup"; import { Studio, StudioSelect } from "src/components/Studios/StudioSelect"; import { useTagsEdit } from "src/hooks/tagsEdit"; +import { Group } from "src/components/Groups/GroupSelect"; +import { RelatedGroupTable, IRelatedGroupEntry } from "./RelatedGroupTable"; interface IGroupEditPanel { group: Partial; @@ -60,6 +62,7 @@ export const GroupEditPanel: React.FC = ({ const [scrapedGroup, setScrapedGroup] = useState(); const [studio, setStudio] = useState(null); + const [containingGroups, setContainingGroups] = useState([]); const schema = yup.object({ name: yup.string().required(), @@ -68,6 +71,14 @@ export const GroupEditPanel: React.FC = ({ date: yupDateString(intl), studio_id: yup.string().required().nullable(), tag_ids: yup.array(yup.string().required()).defined(), + containing_groups: yup + .array( + yup.object({ + group_id: yup.string().required(), + description: yup.string().nullable().ensure(), + }) + ) + .defined(), director: yup.string().ensure(), urls: yupUniqueStringList(intl), synopsis: yup.string().ensure(), @@ -82,6 +93,9 @@ export const GroupEditPanel: React.FC = ({ date: group?.date ?? "", studio_id: group?.studio?.id ?? null, tag_ids: (group?.tags ?? []).map((t) => t.id), + containing_groups: (group?.containing_groups ?? []).map((m) => { + return { group_id: m.group.id, description: m.description ?? "" }; + }), director: group?.director ?? "", urls: group?.urls ?? [], synopsis: group?.synopsis ?? "", @@ -101,6 +115,17 @@ export const GroupEditPanel: React.FC = ({ (ids) => formik.setFieldValue("tag_ids", ids) ); + const containingGroupEntries = useMemo(() => { + return formik.values.containing_groups + .map((m) => { + return { + group: containingGroups.find((mm) => mm.id === m.group_id), + description: m.description, + }; + }) + .filter((m) => m.group !== undefined) as IRelatedGroupEntry[]; + }, [formik.values.containing_groups, containingGroups]); + function onSetStudio(item: Studio | null) { setStudio(item); formik.setFieldValue("studio_id", item ? item.id : null); @@ -110,6 +135,10 @@ export const GroupEditPanel: React.FC = ({ setStudio(group.studio ?? null); }, [group.studio]); + useEffect(() => { + setContainingGroups(group.containing_groups?.map((m) => m.group) ?? []); + }, [group.containing_groups]); + // set up hotkeys useEffect(() => { // Mousetrap.bind("u", (e) => { @@ -366,6 +395,30 @@ export const GroupEditPanel: React.FC = ({ return renderField("tag_ids", title, tagsControl()); } + function onSetContainingGroupEntries(input: IRelatedGroupEntry[]) { + setContainingGroups(input.map((m) => m.group)); + + const newGroups = input.map((m) => ({ + group_id: m.group.id, + description: m.description, + })); + + formik.setFieldValue("containing_groups", newGroups); + } + + function renderContainingGroupsField() { + const title = intl.formatMessage({ id: "containing_groups" }); + const control = ( + + ); + + return renderField("containing_groups", title, control); + } + // TODO: CSS class return (
@@ -394,6 +447,7 @@ export const GroupEditPanel: React.FC = ({ {renderInputField("aliases")} {renderDurationField("duration")} {renderDateField("date")} + {renderContainingGroupsField()} {renderStudioField()} {renderInputField("director")} {renderURLListField("urls", onScrapeGroupURL, urlScrapable)} diff --git a/ui/v2.5/src/components/Groups/GroupDetails/GroupScenesPanel.tsx b/ui/v2.5/src/components/Groups/GroupDetails/GroupScenesPanel.tsx index acca9f0aae1..ab9ec3fea24 100644 --- a/ui/v2.5/src/components/Groups/GroupDetails/GroupScenesPanel.tsx +++ b/ui/v2.5/src/components/Groups/GroupDetails/GroupScenesPanel.tsx @@ -1,6 +1,9 @@ import React from "react"; import * as GQL from "src/core/generated-graphql"; -import { GroupsCriterion } from "src/models/list-filter/criteria/groups"; +import { + GroupsCriterion, + GroupsCriterionOption, +} from "src/models/list-filter/criteria/groups"; import { ListFilterModel } from "src/models/list-filter/filter"; import { SceneList } from "src/components/Scenes/SceneList"; import { View } from "src/components/List/views"; @@ -8,13 +11,14 @@ import { View } from "src/components/List/views"; interface IGroupScenesPanel { active: boolean; group: GQL.GroupDataFragment; + showSubGroupContent?: boolean; } -export const GroupScenesPanel: React.FC = ({ - active, - group, -}) => { - function filterHook(filter: ListFilterModel) { +function useFilterHook( + group: Pick, + showSubGroupContent?: boolean +) { + return (filter: ListFilterModel) => { const groupValue = { id: group.id, label: group.name }; // if group is already present, then we modify it, otherwise add let groupCriterion = filter.criteria.find((c) => { @@ -28,23 +32,35 @@ export const GroupScenesPanel: React.FC = ({ ) { // add the group if not present if ( - !groupCriterion.value.find((p) => { + !groupCriterion.value.items.find((p) => { return p.id === group.id; }) ) { - groupCriterion.value.push(groupValue); + groupCriterion.value.items.push(groupValue); } groupCriterion.modifier = GQL.CriterionModifier.IncludesAll; } else { // overwrite - groupCriterion = new GroupsCriterion(); - groupCriterion.value = [groupValue]; + groupCriterion = new GroupsCriterion(GroupsCriterionOption); + groupCriterion.value = { + items: [groupValue], + depth: showSubGroupContent ? -1 : 0, + excluded: [], + }; filter.criteria.push(groupCriterion); } return filter; - } + }; +} + +export const GroupScenesPanel: React.FC = ({ + active, + group, + showSubGroupContent, +}) => { + const filterHook = useFilterHook(group, showSubGroupContent); if (group && group.id) { return ( @@ -53,6 +69,7 @@ export const GroupScenesPanel: React.FC = ({ defaultSort="group_scene_number" alterQuery={active} view={View.GroupScenes} + fromGroupId={group.id} /> ); } diff --git a/ui/v2.5/src/components/Groups/GroupDetails/GroupSubGroupsPanel.tsx b/ui/v2.5/src/components/Groups/GroupDetails/GroupSubGroupsPanel.tsx new file mode 100644 index 00000000000..a2bb26e9511 --- /dev/null +++ b/ui/v2.5/src/components/Groups/GroupDetails/GroupSubGroupsPanel.tsx @@ -0,0 +1,204 @@ +import React, { useMemo } from "react"; +import * as GQL from "src/core/generated-graphql"; +import { GroupList } from "../GroupList"; +import { ListFilterModel } from "src/models/list-filter/filter"; +import { + ContainingGroupsCriterionOption, + GroupsCriterion, +} from "src/models/list-filter/criteria/groups"; +import { + useRemoveSubGroups, + useReorderSubGroupsMutation, +} from "src/core/StashService"; +import { ButtonToolbar } from "react-bootstrap"; +import { ListOperationButtons } from "src/components/List/ListOperationButtons"; +import { useListContext } from "src/components/List/ListProvider"; +import { + PageSizeSelector, + SearchTermInput, +} from "src/components/List/ListFilter"; +import { useFilter } from "src/components/List/FilterProvider"; +import { IFilteredListToolbar } from "src/components/List/FilteredListToolbar"; +import { + showWhenNoneSelected, + showWhenSelected, +} from "src/components/List/ItemList"; +import { faMinus, faPlus } from "@fortawesome/free-solid-svg-icons"; +import { useIntl } from "react-intl"; +import { useToast } from "src/hooks/Toast"; +import { useModal } from "src/hooks/modal"; +import { AddSubGroupsDialog } from "./AddGroupsDialog"; + +const useContainingGroupFilterHook = ( + group: Pick, + showSubGroupContent?: boolean +) => { + return (filter: ListFilterModel) => { + const groupValue = { id: group.id, label: group.name }; + // if studio is already present, then we modify it, otherwise add + let groupCriterion = filter.criteria.find((c) => { + return c.criterionOption.type === "containing_groups"; + }) as GroupsCriterion | undefined; + + if (groupCriterion) { + // add the group if not present + if ( + !groupCriterion.value.items.find((p) => { + return p.id === group.id; + }) + ) { + groupCriterion.value.items.push(groupValue); + } + } else { + groupCriterion = new GroupsCriterion(ContainingGroupsCriterionOption); + groupCriterion.value = { + items: [groupValue], + excluded: [], + depth: showSubGroupContent ? -1 : 0, + }; + groupCriterion.modifier = GQL.CriterionModifier.Includes; + filter.criteria.push(groupCriterion); + } + + filter.sortBy = "sub_group_order"; + filter.sortDirection = GQL.SortDirectionEnum.Asc; + + return filter; + }; +}; + +const Toolbar: React.FC = ({ + onEdit, + onDelete, + operations, +}) => { + const { getSelected, onSelectAll, onSelectNone } = useListContext(); + const { filter, setFilter } = useFilter(); + + return ( + +
+ +
+ setFilter(filter.setPageSize(size))} + /> + 0} + otherOperations={operations} + onEdit={onEdit} + onDelete={onDelete} + /> +
+ ); +}; + +interface IGroupSubGroupsPanel { + active: boolean; + group: GQL.GroupDataFragment; +} + +export const GroupSubGroupsPanel: React.FC = ({ + active, + group, +}) => { + const intl = useIntl(); + const Toast = useToast(); + const { modal, showModal, closeModal } = useModal(); + + const [reorderSubGroups] = useReorderSubGroupsMutation(); + const mutateRemoveSubGroups = useRemoveSubGroups(); + + const filterHook = useContainingGroupFilterHook(group); + + const defaultFilter = useMemo(() => { + const sortBy = "sub_group_order"; + const ret = new ListFilterModel(GQL.FilterMode.Groups, undefined, { + defaultSortBy: sortBy, + }); + + // unset the sort by so that its not included in the URL + ret.sortBy = undefined; + + return ret; + }, []); + + async function removeSubGroups( + result: GQL.FindGroupsQueryResult, + filter: ListFilterModel, + selectedIds: Set + ) { + try { + await mutateRemoveSubGroups(group.id, Array.from(selectedIds.values())); + + Toast.success( + intl.formatMessage( + { id: "toast.removed_entity" }, + { + count: selectedIds.size, + singularEntity: intl.formatMessage({ id: "group" }), + pluralEntity: intl.formatMessage({ id: "groups" }), + } + ) + ); + } catch (e) { + Toast.error(e); + } + } + + async function onAddSubGroups() { + showModal( + + ); + } + + const otherOperations = [ + { + text: intl.formatMessage({ id: "actions.add_sub_groups" }), + onClick: onAddSubGroups, + isDisplayed: showWhenNoneSelected, + postRefetch: true, + icon: faPlus, + buttonVariant: "secondary", + }, + { + text: intl.formatMessage({ id: "actions.remove_from_containing_group" }), + onClick: removeSubGroups, + isDisplayed: showWhenSelected, + postRefetch: true, + icon: faMinus, + buttonVariant: "danger", + }, + ]; + + function onMove(srcIds: string[], targetId: string, after: boolean) { + reorderSubGroups({ + variables: { + input: { + group_id: group.id, + sub_group_ids: srcIds, + insert_at_id: targetId, + insert_after: after, + }, + }, + }); + } + + return ( + <> + {modal} + } + /> + + ); +}; diff --git a/ui/v2.5/src/components/Groups/GroupDetails/RelatedGroupTable.tsx b/ui/v2.5/src/components/Groups/GroupDetails/RelatedGroupTable.tsx new file mode 100644 index 00000000000..feed49ad09b --- /dev/null +++ b/ui/v2.5/src/components/Groups/GroupDetails/RelatedGroupTable.tsx @@ -0,0 +1,137 @@ +import React, { useMemo } from "react"; +import { FormattedMessage } from "react-intl"; +import * as GQL from "src/core/generated-graphql"; +import { Form, Row, Col } from "react-bootstrap"; +import { Group, GroupSelect } from "src/components/Groups/GroupSelect"; +import cx from "classnames"; +import { ListFilterModel } from "src/models/list-filter/filter"; + +export type GroupSceneIndexMap = Map; + +export interface IRelatedGroupEntry { + group: Group; + description?: GQL.InputMaybe | undefined; +} + +export const RelatedGroupTable: React.FC<{ + value: IRelatedGroupEntry[]; + onUpdate: (input: IRelatedGroupEntry[]) => void; + excludeIDs?: string[]; + filterHook?: (f: ListFilterModel) => ListFilterModel; + disabled?: boolean; +}> = (props) => { + const { value, onUpdate } = props; + + const groupIDs = useMemo(() => value.map((m) => m.group.id), [value]); + + const excludeIDs = useMemo( + () => [...groupIDs, ...(props.excludeIDs ?? [])], + [props.excludeIDs, groupIDs] + ); + + const updateFieldChanged = (index: number, description: string | null) => { + const newValues = value.map((existing, i) => { + if (i === index) { + return { + ...existing, + description, + }; + } + return existing; + }); + + onUpdate(newValues); + }; + + function onGroupSet(index: number, groups: Group[]) { + if (!groups.length) { + // remove this entry + const newValues = value.filter((_, i) => i !== index); + onUpdate(newValues); + return; + } + + const group = groups[0]; + + const newValues = value.map((existing, i) => { + if (i === index) { + return { + ...existing, + group: group, + }; + } + return existing; + }); + + onUpdate(newValues); + } + + function onNewGroupSet(groups: Group[]) { + if (!groups.length) { + return; + } + + const group = groups[0]; + + const newValues = [ + ...value, + { + group: group, + scene_index: null, + }, + ]; + + onUpdate(newValues); + } + + return ( +
+ + + + + + + {value.map((m, i) => ( + + + onGroupSet(i, items)} + values={[m.group!]} + excludeIds={excludeIDs} + filterHook={props.filterHook} + isDisabled={props.disabled} + /> + + + ) => { + updateFieldChanged( + i, + e.currentTarget.value === "" ? null : e.currentTarget.value + ); + }} + disabled={props.disabled} + /> + + + ))} + + + onNewGroupSet(items)} + values={[]} + excludeIds={excludeIDs} + filterHook={props.filterHook} + isDisabled={props.disabled} + /> + + +
+ ); +}; diff --git a/ui/v2.5/src/components/Groups/GroupList.tsx b/ui/v2.5/src/components/Groups/GroupList.tsx index ba45912762b..d3f395037d1 100644 --- a/ui/v2.5/src/components/Groups/GroupList.tsx +++ b/ui/v2.5/src/components/Groups/GroupList.tsx @@ -1,4 +1,4 @@ -import React, { useState } from "react"; +import React, { PropsWithChildren, useState } from "react"; import { useIntl } from "react-intl"; import cloneDeep from "lodash-es/cloneDeep"; import Mousetrap from "mousetrap"; @@ -17,6 +17,35 @@ import { DeleteEntityDialog } from "../Shared/DeleteEntityDialog"; import { GroupCardGrid } from "./GroupCardGrid"; import { EditGroupsDialog } from "./EditGroupsDialog"; import { View } from "../List/views"; +import { + IFilteredListToolbar, + IItemListOperation, +} from "../List/FilteredListToolbar"; + +const GroupExportDialog: React.FC<{ + open?: boolean; + selectedIds: Set; + isExportAll?: boolean; + onClose: () => void; +}> = ({ open = false, selectedIds, isExportAll = false, onClose }) => { + if (!open) { + return null; + } + + return ( + + ); +}; + +const filterMode = GQL.FilterMode.Groups; function getItems(result: GQL.FindGroupsQueryResult) { return result?.data?.findGroups?.groups ?? []; @@ -26,24 +55,57 @@ function getCount(result: GQL.FindGroupsQueryResult) { return result?.data?.findGroups?.count ?? 0; } -interface IGroupList { +interface IGroupListContext { filterHook?: (filter: ListFilterModel) => ListFilterModel; + defaultFilter?: ListFilterModel; view?: View; alterQuery?: boolean; + selectable?: boolean; +} + +export const GroupListContext: React.FC< + PropsWithChildren +> = ({ alterQuery, filterHook, defaultFilter, view, selectable, children }) => { + return ( + + {children} + + ); +}; + +interface IGroupList extends IGroupListContext { + fromGroupId?: string; + onMove?: (srcIds: string[], targetId: string, after: boolean) => void; + renderToolbar?: (props: IFilteredListToolbar) => React.ReactNode; + otherOperations?: IItemListOperation[]; } export const GroupList: React.FC = ({ filterHook, alterQuery, + defaultFilter, view, + fromGroupId, + onMove, + selectable, + renderToolbar, + otherOperations: providedOperations = [], }) => { const intl = useIntl(); const history = useHistory(); const [isExportDialogOpen, setIsExportDialogOpen] = useState(false); const [isExportAll, setIsExportAll] = useState(false); - const filterMode = GQL.FilterMode.Groups; - const otherOperations = [ { text: intl.formatMessage({ id: "actions.view_random" }), @@ -58,6 +120,7 @@ export const GroupList: React.FC = ({ text: intl.formatMessage({ id: "actions.export_all" }), onClick: onExportAll, }, + ...providedOperations, ]; function addKeybinds( @@ -110,42 +173,23 @@ export const GroupList: React.FC = ({ selectedIds: Set, onSelectChange: (id: string, selected: boolean, shiftKey: boolean) => void ) { - function maybeRenderGroupExportDialog() { - if (isExportDialogOpen) { - return ( - setIsExportDialogOpen(false)} - /> - ); - } - } - - function renderGroups() { - if (!result.data?.findGroups) return; - - if (filter.displayMode === DisplayMode.Grid) { - return ( + return ( + <> + setIsExportDialogOpen(false)} + /> + {filter.displayMode === DisplayMode.Grid && ( - ); - } - if (filter.displayMode === DisplayMode.List) { - return

TODO

; - } - } - return ( - <> - {maybeRenderGroupExportDialog()} - {renderGroups()} + )} ); } @@ -173,15 +217,12 @@ export const GroupList: React.FC = ({ } return ( - = ({ renderContent={renderContent} renderEditDialog={renderEditDialog} renderDeleteDialog={renderDeleteDialog} + renderToolbar={renderToolbar} /> - + ); }; diff --git a/ui/v2.5/src/components/Groups/GroupSelect.tsx b/ui/v2.5/src/components/Groups/GroupSelect.tsx index 4f611e5e3b4..dd16088e9e7 100644 --- a/ui/v2.5/src/components/Groups/GroupSelect.tsx +++ b/ui/v2.5/src/components/Groups/GroupSelect.tsx @@ -56,13 +56,14 @@ const groupSelectSort = PatchFunction( sortGroupsByRelevance ); -const _GroupSelect: React.FC< +export const GroupSelect: React.FC< IFilterProps & IFilterValueProps & { hoverPlacement?: Placement; excludeIds?: string[]; + filterHook?: (f: ListFilterModel) => ListFilterModel; } -> = (props) => { +> = PatchComponent("GroupSelect", (props) => { const [createGroup] = useGroupCreate(); const { configuration } = React.useContext(ConfigurationContext); @@ -75,12 +76,17 @@ const _GroupSelect: React.FC< const exclude = useMemo(() => props.excludeIds ?? [], [props.excludeIds]); async function loadGroups(input: string): Promise { - const filter = new ListFilterModel(GQL.FilterMode.Groups); + let filter = new ListFilterModel(GQL.FilterMode.Groups); filter.searchTerm = input; filter.currentPage = 1; filter.itemsPerPage = maxOptionsShown; filter.sortBy = "name"; filter.sortDirection = GQL.SortDirectionEnum.Asc; + + if (props.filterHook) { + filter = props.filterHook(filter); + } + const query = await queryFindGroupsForSelect(filter); let ret = query.data.findGroups.groups.filter((group) => { // HACK - we should probably exclude these in the backend query, but @@ -255,9 +261,7 @@ const _GroupSelect: React.FC< closeMenuOnSelect={!props.isMulti} /> ); -}; - -export const GroupSelect = PatchComponent("GroupSelect", _GroupSelect); +}); const _GroupIDSelect: React.FC> = ( props diff --git a/ui/v2.5/src/components/Groups/GroupTag.tsx b/ui/v2.5/src/components/Groups/GroupTag.tsx new file mode 100644 index 00000000000..3443d18dfef --- /dev/null +++ b/ui/v2.5/src/components/Groups/GroupTag.tsx @@ -0,0 +1,28 @@ +import React from "react"; +import { Link } from "react-router-dom"; +import * as GQL from "src/core/generated-graphql"; +import { GroupLink } from "../Shared/TagLink"; + +export const GroupTag: React.FC<{ + group: Pick; + linkType?: "scene" | "sub_group" | "details"; + description?: string; +}> = ({ group, linkType, description }) => { + return ( +
+ + {group.name + + +
+ ); +}; diff --git a/ui/v2.5/src/components/Groups/RelatedGroupPopover.tsx b/ui/v2.5/src/components/Groups/RelatedGroupPopover.tsx new file mode 100644 index 00000000000..03095f284a4 --- /dev/null +++ b/ui/v2.5/src/components/Groups/RelatedGroupPopover.tsx @@ -0,0 +1,110 @@ +import { + faFilm, + faArrowUpLong, + faArrowDownLong, +} from "@fortawesome/free-solid-svg-icons"; +import React, { useMemo } from "react"; +import { Button, OverlayTrigger, Tooltip } from "react-bootstrap"; +import { Count } from "../Shared/PopoverCountButton"; +import { Icon } from "../Shared/Icon"; +import { HoverPopover } from "../Shared/HoverPopover"; +import { Link } from "react-router-dom"; +import NavUtils from "src/utils/navigation"; +import * as GQL from "src/core/generated-graphql"; +import { useIntl } from "react-intl"; +import { GroupTag } from "./GroupTag"; + +interface IProps { + group: Pick< + GQL.GroupDataFragment, + "id" | "name" | "containing_groups" | "sub_group_count" + >; +} + +const ContainingGroupsCount: React.FC = ({ group }) => { + const { containing_groups: containingGroups } = group; + + const popoverContent = useMemo(() => { + if (!containingGroups.length) { + return []; + } + + return containingGroups.map((entry) => ( + + )); + }, [containingGroups]); + + if (!containingGroups.length) { + return null; + } + + return ( + + + + + + + ); +}; + +const SubGroupCount: React.FC = ({ group }) => { + const intl = useIntl(); + + const count = group.sub_group_count; + + if (!count) { + return null; + } + + function getTitle() { + const pluralCategory = intl.formatPlural(count); + const options = { + one: "sub_group", + other: "sub_groups", + }; + const plural = intl.formatMessage({ + id: options[pluralCategory as "one"] || options.other, + }); + return `${count} ${plural}`; + } + + return ( + {getTitle()}} + placement="bottom" + > + + + + + + ); +}; + +export const RelatedGroupPopoverButton: React.FC = ({ group }) => { + return ( + + + + ); +}; diff --git a/ui/v2.5/src/components/Groups/styles.scss b/ui/v2.5/src/components/Groups/styles.scss index 3d1868fb815..1b80045c73d 100644 --- a/ui/v2.5/src/components/Groups/styles.scss +++ b/ui/v2.5/src/components/Groups/styles.scss @@ -14,7 +14,8 @@ width: 100%; } - .group-scene-number { + .group-scene-number, + .group-containing-group-description { text-align: center; } @@ -89,3 +90,24 @@ } } } + +.groups-list { + list-style-type: none; + padding-inline-start: 0; + + li { + display: inline; + } +} + +.related-group-popover-button { + .containing-group-count { + display: inline-block; + } + + .related-group-count .fa-icon { + color: $text-muted; + margin-left: 0; + margin-right: 0.25rem; + } +} diff --git a/ui/v2.5/src/components/List/FilteredListToolbar.tsx b/ui/v2.5/src/components/List/FilteredListToolbar.tsx index d6887c51de0..6018dd836f6 100644 --- a/ui/v2.5/src/components/List/FilteredListToolbar.tsx +++ b/ui/v2.5/src/components/List/FilteredListToolbar.tsx @@ -31,14 +31,16 @@ export interface IItemListOperation { buttonVariant?: string; } -export const FilteredListToolbar: React.FC<{ - showEditFilter: (editingCriterion?: string) => void; +export interface IFilteredListToolbar { + showEditFilter?: (editingCriterion?: string) => void; view?: View; onEdit?: () => void; onDelete?: () => void; operations?: IListFilterOperation[]; zoomable?: boolean; -}> = ({ +} + +export const FilteredListToolbar: React.FC = ({ showEditFilter, view, onEdit, @@ -60,13 +62,15 @@ export const FilteredListToolbar: React.FC<{ } return ( - - showEditFilter()} - view={view} - /> + + {showEditFilter && ( + showEditFilter()} + view={view} + /> + )} { @@ -59,6 +63,7 @@ interface IItemListProps { filter: ListFilterModel, selectedIds: Set ) => () => void; + renderToolbar?: (props: IFilteredListToolbar) => React.ReactNode; } export const ItemList = ( @@ -73,6 +78,7 @@ export const ItemList = ( renderDeleteDialog, renderMetadataByline, addKeybinds, + renderToolbar: providedToolbar, } = props; const { filter, setFilter: updateFilter } = useFilter(); @@ -142,28 +148,30 @@ export const ItemList = ( } }, [addKeybinds, result, effectiveFilter, selectedIds]); - async function onOperationClicked(o: IItemListOperation) { - await o.onClick(result, effectiveFilter, selectedIds); - if (o.postRefetch) { - result.refetch(); - } - } - - const operations = otherOperations?.map((o) => ({ - text: o.text, - onClick: () => { - onOperationClicked(o); - }, - isDisplayed: () => { - if (o.isDisplayed) { - return o.isDisplayed(result, effectiveFilter, selectedIds); + const operations = useMemo(() => { + async function onOperationClicked(o: IItemListOperation) { + await o.onClick(result, effectiveFilter, selectedIds); + if (o.postRefetch) { + result.refetch(); } + } - return true; - }, - icon: o.icon, - buttonVariant: o.buttonVariant, - })); + return otherOperations?.map((o) => ({ + text: o.text, + onClick: () => { + onOperationClicked(o); + }, + isDisplayed: () => { + if (o.isDisplayed) { + return o.isDisplayed(result, effectiveFilter, selectedIds); + } + + return true; + }, + icon: o.icon, + buttonVariant: o.buttonVariant, + })); + }, [result, effectiveFilter, selectedIds, otherOperations]); function onEdit() { if (!renderEditDialog) { @@ -215,16 +223,22 @@ export const ItemList = ( updateFilter(filter.clearCriteria()); } + const filterListToolbarProps = { + showEditFilter, + view: view, + operations: operations, + zoomable: zoomable, + onEdit: renderEditDialog ? onEdit : undefined, + onDelete: renderDeleteDialog ? onDelete : undefined, + }; + return (
- + {providedToolbar ? ( + providedToolbar(filterListToolbarProps) + ) : ( + + )} showEditFilter(c.criterionOption.type)} @@ -258,6 +272,7 @@ export const ItemList = ( interface IItemListContextProps { filterMode: GQL.FilterMode; defaultSort?: string; + defaultFilter?: ListFilterModel; useResult: (filter: ListFilterModel) => T; getCount: (data: T) => number; getItems: (data: T) => E[]; @@ -275,6 +290,7 @@ export const ItemListContext = ( const { filterMode, defaultSort, + defaultFilter: providedDefaultFilter, useResult, getCount, getItems, @@ -287,10 +303,11 @@ export const ItemListContext = ( const emptyFilter = useMemo( () => + providedDefaultFilter?.clone() ?? new ListFilterModel(filterMode, undefined, { defaultSortBy: defaultSort, }), - [filterMode, defaultSort] + [filterMode, defaultSort, providedDefaultFilter] ); const [filter, setFilterState] = useState( @@ -343,3 +360,11 @@ export const showWhenSingleSelection = ( ) => { return selectedIds.size == 1; }; + +export const showWhenNoneSelected = ( + result: T, + filter: ListFilterModel, + selectedIds: Set +) => { + return selectedIds.size === 0; +}; diff --git a/ui/v2.5/src/components/List/ListFilter.tsx b/ui/v2.5/src/components/List/ListFilter.tsx index bff14336cfd..24ea02af1d4 100644 --- a/ui/v2.5/src/components/List/ListFilter.tsx +++ b/ui/v2.5/src/components/List/ListFilter.tsx @@ -1,5 +1,11 @@ import cloneDeep from "lodash-es/cloneDeep"; -import React, { useCallback, useEffect, useRef, useState } from "react"; +import React, { + useCallback, + useEffect, + useMemo, + useRef, + useState, +} from "react"; import Mousetrap from "mousetrap"; import { SortDirectionEnum } from "src/core/generated-graphql"; import { @@ -102,36 +108,17 @@ export const SearchTermInput: React.FC<{ ); }; -interface IListFilterProps { - onFilterUpdate: (newFilter: ListFilterModel) => void; - filter: ListFilterModel; - view?: View; - openFilterDialog: () => void; -} - const PAGE_SIZE_OPTIONS = ["20", "40", "60", "120", "250", "500", "1000"]; -export const ListFilter: React.FC = ({ - onFilterUpdate, - filter, - openFilterDialog, - view, -}) => { - const [customPageSizeShowing, setCustomPageSizeShowing] = useState(false); - const perPageSelect = useRef(null); - const [perPageInput, perPageFocus] = useFocus(); - - const filterOptions = filter.options; - +export const PageSizeSelector: React.FC<{ + pageSize: number; + setPageSize: (pageSize: number) => void; +}> = ({ pageSize, setPageSize }) => { const intl = useIntl(); - useEffect(() => { - Mousetrap.bind("r", () => onReshuffleRandomSort()); - - return () => { - Mousetrap.unbind("r"); - }; - }); + const perPageSelect = useRef(null); + const [perPageInput, perPageFocus] = useFocus(); + const [customPageSizeShowing, setCustomPageSizeShowing] = useState(false); useEffect(() => { if (customPageSizeShowing) { @@ -139,6 +126,27 @@ export const ListFilter: React.FC = ({ } }, [customPageSizeShowing, perPageFocus]); + const pageSizeOptions = useMemo(() => { + const ret = PAGE_SIZE_OPTIONS.map((o) => { + return { + label: o, + value: o, + }; + }); + const currentPerPage = pageSize.toString(); + if (!ret.find((o) => o.value === currentPerPage)) { + ret.push({ label: currentPerPage, value: currentPerPage }); + ret.sort((a, b) => parseInt(a.value, 10) - parseInt(b.value, 10)); + } + + ret.push({ + label: `${intl.formatMessage({ id: "custom" })}...`, + value: "custom", + }); + + return ret; + }, [intl, pageSize]); + function onChangePageSize(val: string) { if (val === "custom") { // added timeout since Firefox seems to trigger the rootClose immediately @@ -154,6 +162,94 @@ export const ListFilter: React.FC = ({ return; } + setPageSize(pp); + } + + return ( +
+ onChangePageSize(e.target.value)} + value={pageSize.toString()} + className="btn-secondary" + > + {pageSizeOptions.map((s) => ( + + ))} + + setCustomPageSizeShowing(false)} + > + +
+ + ) => { + if (e.key === "Enter") { + onChangePageSize( + (perPageInput.current as HTMLInputElement)?.value ?? "" + ); + e.preventDefault(); + } + }} + /> + + + + +
+
+
+
+ ); +}; + +interface IListFilterProps { + onFilterUpdate: (newFilter: ListFilterModel) => void; + filter: ListFilterModel; + view?: View; + openFilterDialog: () => void; +} + +export const ListFilter: React.FC = ({ + onFilterUpdate, + filter, + openFilterDialog, + view, +}) => { + const filterOptions = filter.options; + + const intl = useIntl(); + + useEffect(() => { + Mousetrap.bind("r", () => onReshuffleRandomSort()); + + return () => { + Mousetrap.unbind("r"); + }; + }); + + function onChangePageSize(pp: number) { const newFilter = cloneDeep(filter); newFilter.itemsPerPage = pp; newFilter.currentPage = 1; @@ -211,25 +307,6 @@ export const ListFilter: React.FC = ({ (o) => o.value === filter.sortBy ); - const pageSizeOptions = PAGE_SIZE_OPTIONS.map((o) => { - return { - label: o, - value: o, - }; - }); - const currentPerPage = filter.itemsPerPage.toString(); - if (!pageSizeOptions.find((o) => o.value === currentPerPage)) { - pageSizeOptions.push({ label: currentPerPage, value: currentPerPage }); - pageSizeOptions.sort( - (a, b) => parseInt(a.value, 10) - parseInt(b.value, 10) - ); - } - - pageSizeOptions.push({ - label: `${intl.formatMessage({ id: "custom" })}...`, - value: "custom", - }); - return ( <>
@@ -301,63 +378,10 @@ export const ListFilter: React.FC = ({ )} -
- onChangePageSize(e.target.value)} - value={filter.itemsPerPage.toString()} - className="btn-secondary" - > - {pageSizeOptions.map((s) => ( - - ))} - - setCustomPageSizeShowing(false)} - > - -
- - ) => { - if (e.key === "Enter") { - onChangePageSize( - (perPageInput.current as HTMLInputElement)?.value ?? - "" - ); - e.preventDefault(); - } - }} - /> - - - - -
-
-
-
+ ); } diff --git a/ui/v2.5/src/components/List/ListOperationButtons.tsx b/ui/v2.5/src/components/List/ListOperationButtons.tsx index 4373d933847..92bcf9ebcd2 100644 --- a/ui/v2.5/src/components/List/ListOperationButtons.tsx +++ b/ui/v2.5/src/components/List/ListOperationButtons.tsx @@ -1,4 +1,4 @@ -import React, { useEffect } from "react"; +import React, { PropsWithChildren, useEffect } from "react"; import { Button, ButtonGroup, @@ -16,6 +16,23 @@ import { faTrash, } from "@fortawesome/free-solid-svg-icons"; +export const OperationDropdown: React.FC> = ({ + children, +}) => { + if (!children) return null; + + return ( + + + + + + {children} + + + ); +}; + export interface IListFilterOperation { text: string; onClick: () => void; @@ -154,6 +171,11 @@ export const ListOperationButtons: React.FC = ({ if (otherOperations) { otherOperations .filter((o) => { + // buttons with icons are rendered in the button group + if (o.icon) { + return false; + } + if (!o.isDisplayed) { return true; } @@ -173,18 +195,11 @@ export const ListOperationButtons: React.FC = ({ }); } - if (options.length > 0) { - return ( - - - - - - {options} - - - ); - } + return ( + + {options.length > 0 ? options : undefined} + + ); } return ( diff --git a/ui/v2.5/src/components/List/ListProvider.tsx b/ui/v2.5/src/components/List/ListProvider.tsx index a3a41a93d9f..6ef0d5055c0 100644 --- a/ui/v2.5/src/components/List/ListProvider.tsx +++ b/ui/v2.5/src/components/List/ListProvider.tsx @@ -66,6 +66,26 @@ export function useListContext() { return context as IListContextState; } +const emptyState: IListContextState = { + selectable: false, + selectedIds: new Set(), + getSelected: () => [], + onSelectChange: () => {}, + onSelectAll: () => {}, + onSelectNone: () => {}, + items: [], +}; + +export function useListContextOptional() { + const context = React.useContext(ListStateContext); + + if (context === null) { + return emptyState as IListContextState; + } + + return context as IListContextState; +} + interface IQueryResultContextOptions< T extends QueryResult, E extends IHasID = IHasID diff --git a/ui/v2.5/src/components/List/styles.scss b/ui/v2.5/src/components/List/styles.scss index edfb9d2a791..632ac9533fc 100644 --- a/ui/v2.5/src/components/List/styles.scss +++ b/ui/v2.5/src/components/List/styles.scss @@ -572,6 +572,10 @@ input[type="range"].zoom-slider { } } +.filtered-list-toolbar { + justify-content: center; +} + .search-term-input { margin-right: 0.5rem; } diff --git a/ui/v2.5/src/components/List/util.ts b/ui/v2.5/src/components/List/util.ts index 69d3528bd4b..3b85b666d86 100644 --- a/ui/v2.5/src/components/List/util.ts +++ b/ui/v2.5/src/components/List/util.ts @@ -7,6 +7,7 @@ import { QueryResult } from "@apollo/client"; import { IHasID } from "src/utils/data"; import { ConfigurationContext } from "src/hooks/Config"; import { View } from "./views"; +import { usePrevious } from "src/hooks/state"; export function useFilterURL( filter: ListFilterModel, @@ -180,6 +181,25 @@ export function useListSelect(items: T[]) { const [selectedIds, setSelectedIds] = useState>(new Set()); const [lastClickedId, setLastClickedId] = useState(); + const prevItems = usePrevious(items); + + useEffect(() => { + if (prevItems === items) { + return; + } + + // filter out any selectedIds that are no longer in the list + const newSelectedIds = new Set(); + + selectedIds.forEach((id) => { + if (items.some((item) => item.id === id)) { + newSelectedIds.add(id); + } + }); + + setSelectedIds(newSelectedIds); + }, [prevItems, items, selectedIds]); + function singleSelect(id: string, selected: boolean) { setLastClickedId(id); diff --git a/ui/v2.5/src/components/List/views.ts b/ui/v2.5/src/components/List/views.ts index 2b4179014e1..bb36a4c4ea6 100644 --- a/ui/v2.5/src/components/List/views.ts +++ b/ui/v2.5/src/components/List/views.ts @@ -31,4 +31,5 @@ export enum View { StudioChildren = "studio_children", GroupScenes = "group_scenes", + GroupSubGroups = "group_sub_groups", } diff --git a/ui/v2.5/src/components/Scenes/SceneCard.tsx b/ui/v2.5/src/components/Scenes/SceneCard.tsx index cbe3ee64e0c..b5d053c67e7 100644 --- a/ui/v2.5/src/components/Scenes/SceneCard.tsx +++ b/ui/v2.5/src/components/Scenes/SceneCard.tsx @@ -1,15 +1,10 @@ import React, { useEffect, useMemo, useRef, useState } from "react"; import { Button, ButtonGroup, OverlayTrigger, Tooltip } from "react-bootstrap"; -import { Link, useHistory } from "react-router-dom"; +import { useHistory } from "react-router-dom"; import cx from "classnames"; import * as GQL from "src/core/generated-graphql"; import { Icon } from "../Shared/Icon"; -import { - GalleryLink, - TagLink, - GroupLink, - SceneMarkerLink, -} from "../Shared/TagLink"; +import { GalleryLink, TagLink, SceneMarkerLink } from "../Shared/TagLink"; import { HoverPopover } from "../Shared/HoverPopover"; import { SweatDrops } from "../Shared/SweatDrops"; import { TruncatedText } from "../Shared/TruncatedText"; @@ -20,7 +15,7 @@ import { ConfigurationContext } from "src/hooks/Config"; import { PerformerPopoverButton } from "../Shared/PerformerPopoverButton"; import { GridCard, calculateCardWidth } from "../Shared/GridCard/GridCard"; import { RatingBanner } from "../Shared/RatingBanner"; -import { FormattedNumber } from "react-intl"; +import { FormattedMessage, FormattedNumber } from "react-intl"; import { faBox, faCopy, @@ -34,6 +29,7 @@ import { PreviewScrubber } from "./PreviewScrubber"; import { PatchComponent } from "src/patch"; import ScreenUtils from "src/utils/screen"; import { StudioOverlay } from "../Shared/GridCard/StudioOverlay"; +import { GroupTag } from "../Groups/GroupTag"; interface IScenePreviewProps { isPortrait: boolean; @@ -106,8 +102,26 @@ interface ISceneCardProps { selected?: boolean | undefined; zoomIndex?: number; onSelectedChanged?: (selected: boolean, shiftKey: boolean) => void; + fromGroupId?: string; } +const Description: React.FC<{ + sceneNumber?: number; +}> = ({ sceneNumber }) => { + if (!sceneNumber) return null; + + return ( + <> +
+ {sceneNumber !== undefined && ( + + #{sceneNumber} + + )} + + ); +}; + const SceneCardPopovers = PatchComponent( "SceneCard.Popovers", (props: ISceneCardProps) => { @@ -116,6 +130,17 @@ const SceneCardPopovers = PatchComponent( [props.scene] ); + const sceneNumber = useMemo(() => { + if (!props.fromGroupId) { + return undefined; + } + + const group = props.scene.groups.find( + (g) => g.group.id === props.fromGroupId + ); + return group?.scene_index ?? undefined; + }, [props.fromGroupId, props.scene.groups]); + function maybeRenderTagPopoverButton() { if (props.scene.tags.length <= 0) return; @@ -147,23 +172,7 @@ const SceneCardPopovers = PatchComponent( if (props.scene.groups.length <= 0) return; const popoverContent = props.scene.groups.map((sceneGroup) => ( -
- - {sceneGroup.group.name - - -
+ )); return ( @@ -283,10 +292,12 @@ const SceneCardPopovers = PatchComponent( props.scene.scene_markers.length > 0 || props.scene?.o_counter || props.scene.galleries.length > 0 || - props.scene.organized) + props.scene.organized || + sceneNumber !== undefined) ) { return ( <> +
{maybeRenderTagPopoverButton()} diff --git a/ui/v2.5/src/components/Scenes/SceneCardsGrid.tsx b/ui/v2.5/src/components/Scenes/SceneCardsGrid.tsx index b882b5ec543..9884e37a06f 100644 --- a/ui/v2.5/src/components/Scenes/SceneCardsGrid.tsx +++ b/ui/v2.5/src/components/Scenes/SceneCardsGrid.tsx @@ -10,6 +10,7 @@ interface ISceneCardsGrid { selectedIds: Set; zoomIndex: number; onSelectChange: (id: string, selected: boolean, shiftKey: boolean) => void; + fromGroupId?: string; } export const SceneCardsGrid: React.FC = ({ @@ -18,6 +19,7 @@ export const SceneCardsGrid: React.FC = ({ selectedIds, zoomIndex, onSelectChange, + fromGroupId, }) => { const [componentRef, { width }] = useContainerDimensions(); return ( @@ -35,6 +37,7 @@ export const SceneCardsGrid: React.FC = ({ onSelectedChanged={(selected: boolean, shiftKey: boolean) => onSelectChange(scene.id, selected, shiftKey) } + fromGroupId={fromGroupId} /> ))}
diff --git a/ui/v2.5/src/components/Scenes/SceneDetails/SceneGroupPanel.tsx b/ui/v2.5/src/components/Scenes/SceneDetails/SceneGroupPanel.tsx index 6f58b504cf9..53a5e174d30 100644 --- a/ui/v2.5/src/components/Scenes/SceneDetails/SceneGroupPanel.tsx +++ b/ui/v2.5/src/components/Scenes/SceneDetails/SceneGroupPanel.tsx @@ -13,7 +13,7 @@ export const SceneGroupPanel: React.FC = ( )); diff --git a/ui/v2.5/src/components/Scenes/SceneList.tsx b/ui/v2.5/src/components/Scenes/SceneList.tsx index 6fa7c5dbd18..e78e31d2cbd 100644 --- a/ui/v2.5/src/components/Scenes/SceneList.tsx +++ b/ui/v2.5/src/components/Scenes/SceneList.tsx @@ -75,6 +75,7 @@ interface ISceneList { defaultSort?: string; view?: View; alterQuery?: boolean; + fromGroupId?: string; } export const SceneList: React.FC = ({ @@ -82,6 +83,7 @@ export const SceneList: React.FC = ({ defaultSort, view, alterQuery, + fromGroupId, }) => { const intl = useIntl(); const history = useHistory(); @@ -297,6 +299,7 @@ export const SceneList: React.FC = ({ zoomIndex={filter.zoomIndex} selectedIds={selectedIds} onSelectChange={onSelectChange} + fromGroupId={fromGroupId} /> ); } diff --git a/ui/v2.5/src/components/Scenes/styles.scss b/ui/v2.5/src/components/Scenes/styles.scss index 57d68f94105..b9df2f7b5c6 100644 --- a/ui/v2.5/src/components/Scenes/styles.scss +++ b/ui/v2.5/src/components/Scenes/styles.scss @@ -208,6 +208,10 @@ textarea.scene-description { &-preview { aspect-ratio: 16/9; } + + .scene-group-scene-number { + text-align: center; + } } .scene-card, diff --git a/ui/v2.5/src/components/Shared/GridCard/GridCard.tsx b/ui/v2.5/src/components/Shared/GridCard/GridCard.tsx index 1d1a37528d9..33aa24e32cd 100644 --- a/ui/v2.5/src/components/Shared/GridCard/GridCard.tsx +++ b/ui/v2.5/src/components/Shared/GridCard/GridCard.tsx @@ -1,10 +1,18 @@ -import React, { MutableRefObject, useRef, useState } from "react"; +import React, { + MutableRefObject, + PropsWithChildren, + useRef, + useState, +} from "react"; import { Card, Form } from "react-bootstrap"; import { Link } from "react-router-dom"; import cx from "classnames"; import { TruncatedText } from "../TruncatedText"; import ScreenUtils from "src/utils/screen"; import useResizeObserver from "@react-hook/resize-observer"; +import { Icon } from "../Icon"; +import { faGripLines } from "@fortawesome/free-solid-svg-icons"; +import { DragSide, useDragMoveSelect } from "./dragMoveSelect"; interface ICardProps { className?: string; @@ -24,6 +32,10 @@ interface ICardProps { resumeTime?: number; duration?: number; interactiveHeatmap?: string; + + // move logic - both of the following are required to enable move dragging + objectId?: string; // required for move dragging + onMove?: (srcIds: string[], targetId: string, after: boolean) => void; } export const calculateCardWidth = ( @@ -66,60 +78,82 @@ export const useContainerDimensions = ( return [target, dimension]; }; -export const GridCard: React.FC = (props: ICardProps) => { - function handleImageClick(event: React.MouseEvent) { - const { shiftKey } = event; +const Checkbox: React.FC<{ + selected?: boolean; + onSelectedChanged?: (selected: boolean, shiftKey: boolean) => void; +}> = ({ selected = false, onSelectedChanged }) => { + let shiftKey = false; - if (!props.onSelectedChanged) { - return; - } + return ( + onSelectedChanged!(!selected, shiftKey)} + onClick={(event: React.MouseEvent) => { + shiftKey = event.shiftKey; + event.stopPropagation(); + }} + /> + ); +}; - if (props.selecting) { - props.onSelectedChanged(!props.selected, shiftKey); - event.preventDefault(); - } +const DragHandle: React.FC<{ + setInHandle: (inHandle: boolean) => void; +}> = ({ setInHandle }) => { + function onMouseEnter() { + setInHandle(true); } - function handleDrag(event: React.DragEvent) { - if (props.selecting) { - event.dataTransfer.setData("text/plain", ""); - event.dataTransfer.setDragImage(new Image(), 0, 0); - } + function onMouseLeave() { + setInHandle(false); } - function handleDragOver(event: React.DragEvent) { - const ev = event; - const shiftKey = false; - - if (!props.onSelectedChanged) { - return; - } + return ( + + + + ); +}; - if (props.selecting && !props.selected) { - props.onSelectedChanged(true, shiftKey); - } +const Controls: React.FC> = ({ children }) => { + return
{children}
; +}; - ev.dataTransfer.dropEffect = "move"; - ev.preventDefault(); +const MoveTarget: React.FC<{ dragSide: DragSide }> = ({ dragSide }) => { + if (dragSide === undefined) { + return null; } - let shiftKey = false; + return ( +
+ ); +}; - function maybeRenderCheckbox() { - if (props.onSelectedChanged) { - return ( - props.onSelectedChanged!(!props.selected, shiftKey)} - onClick={(event: React.MouseEvent) => { - shiftKey = event.shiftKey; - event.stopPropagation(); - }} - /> - ); +export const GridCard: React.FC = (props: ICardProps) => { + const { setInHandle, moveTarget, dragProps } = useDragMoveSelect({ + selecting: props.selecting || false, + selected: props.selected || false, + onSelectedChanged: props.onSelectedChanged, + objectId: props.objectId, + onMove: props.onMove, + }); + + function handleImageClick(event: React.MouseEvent) { + const { shiftKey } = event; + + if (!props.onSelectedChanged) { + return; + } + + if (props.selecting) { + props.onSelectedChanged(!props.selected, shiftKey); + event.preventDefault(); } } @@ -156,16 +190,26 @@ export const GridCard: React.FC = (props: ICardProps) => { - {maybeRenderCheckbox()} + {moveTarget !== undefined && } + + {props.onSelectedChanged && ( + + )} + + {!!props.objectId && props.onMove && ( + + )} +
void; + objectId?: string; + onMove?: (srcIds: string[], targetId: string, after: boolean) => void; +}) { + const { selectedIds } = useListContextOptional(); + + const [inHandle, setInHandle] = useState(false); + const [moveSrc, setMoveSrc] = useState(false); + const [moveTarget, setMoveTarget] = useState(); + + const canSelect = props.onSelectedChanged && props.selecting; + const canMove = !!props.objectId && props.onMove && inHandle; + const draggable = canSelect || canMove; + + function onDragStart(event: React.DragEvent) { + if (!draggable) { + event.preventDefault(); + return; + } + + if (!inHandle && props.selecting) { + event.dataTransfer.setData("text/plain", ""); + // event.dataTransfer.setDragImage(new Image(), 0, 0); + event.dataTransfer.effectAllowed = "copy"; + event.stopPropagation(); + } else if (inHandle && props.objectId) { + if (selectedIds.size > 1 && selectedIds.has(props.objectId)) { + // moving all selected + const movingIds = Array.from(selectedIds.values()).join(","); + event.dataTransfer.setData("text/plain", movingIds); + } else { + // moving single + setMoveSrc(true); + event.dataTransfer.setData("text/plain", props.objectId); + } + event.dataTransfer.effectAllowed = "move"; + event.stopPropagation(); + } + } + + function doSetMoveTarget(event: React.DragEvent) { + const isBefore = + event.nativeEvent.offsetX < event.currentTarget.clientWidth / 2; + if (isBefore && moveTarget !== DragSide.BEFORE) { + setMoveTarget(DragSide.BEFORE); + } else if (!isBefore && moveTarget !== DragSide.AFTER) { + setMoveTarget(DragSide.AFTER); + } + } + + function onDragEnter(event: React.DragEvent) { + const ev = event; + const shiftKey = false; + + if (ev.dataTransfer.effectAllowed === "copy") { + if (!props.onSelectedChanged) { + return; + } + + if (props.selecting && !props.selected) { + props.onSelectedChanged(true, shiftKey); + } + + ev.dataTransfer.dropEffect = "copy"; + ev.preventDefault(); + } else if (ev.dataTransfer.effectAllowed === "move" && !moveSrc) { + doSetMoveTarget(event); + ev.dataTransfer.dropEffect = "move"; + ev.preventDefault(); + } else { + ev.dataTransfer.dropEffect = "none"; + } + } + + function onDragLeave(event: React.DragEvent) { + if (event.currentTarget.contains(event.relatedTarget as Node)) { + return; + } + + setMoveTarget(undefined); + } + + function onDragOver(event: React.DragEvent) { + if (event.dataTransfer.effectAllowed === "move" && moveSrc) { + return; + } + + doSetMoveTarget(event); + + event.preventDefault(); + } + + function onDragEnd() { + setMoveTarget(undefined); + setMoveSrc(false); + } + + function onDrop(event: React.DragEvent) { + const ev = event; + + if ( + ev.dataTransfer.effectAllowed === "copy" || + !props.onMove || + !props.objectId + ) { + return; + } + + const srcIds = ev.dataTransfer.getData("text/plain").split(","); + const targetId = props.objectId; + const after = moveTarget === DragSide.AFTER; + + props.onMove(srcIds, targetId, after); + + onDragEnd(); + } + + return { + inHandle, + setInHandle, + moveTarget, + dragProps: { + draggable: draggable || undefined, + onDragStart, + onDragEnter, + onDragLeave, + onDragOver, + onDragEnd, + onDrop, + }, + }; +} diff --git a/ui/v2.5/src/components/Shared/GridCard/styles.scss b/ui/v2.5/src/components/Shared/GridCard/styles.scss index fcf699fe234..ece1f280aab 100644 --- a/ui/v2.5/src/components/Shared/GridCard/styles.scss +++ b/ui/v2.5/src/components/Shared/GridCard/styles.scss @@ -57,3 +57,28 @@ transition: opacity 0.5s; } } + +.move-target { + align-items: center; + background-color: $primary; + color: $secondary; + display: flex; + height: 100%; + justify-content: center; + opacity: 0.5; + pointer-events: none; + position: absolute; + width: 10%; + + &.move-target-before { + left: 0; + } + + &.move-target-after { + right: 0; + } +} + +.card-drag-handle { + filter: drop-shadow(1px 1px 1px rgba(0, 0, 0, 0.7)); +} diff --git a/ui/v2.5/src/components/Shared/Icon.tsx b/ui/v2.5/src/components/Shared/Icon.tsx index adbc3dcfdd3..32f4d0259f0 100644 --- a/ui/v2.5/src/components/Shared/Icon.tsx +++ b/ui/v2.5/src/components/Shared/Icon.tsx @@ -1,23 +1,16 @@ import React from "react"; -import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; -import { IconDefinition, SizeProp } from "@fortawesome/fontawesome-svg-core"; +import { + FontAwesomeIcon, + FontAwesomeIconProps, +} from "@fortawesome/react-fontawesome"; import { PatchComponent } from "src/patch"; -interface IIcon { - icon: IconDefinition; - className?: string; - color?: string; - size?: SizeProp; -} - -export const Icon: React.FC = PatchComponent( +export const Icon: React.FC = PatchComponent( "Icon", - ({ icon, className, color, size }) => ( + (props) => ( ) ); diff --git a/ui/v2.5/src/components/Shared/MultiSet.tsx b/ui/v2.5/src/components/Shared/MultiSet.tsx index 521a2577bef..4ed7a99ffe4 100644 --- a/ui/v2.5/src/components/Shared/MultiSet.tsx +++ b/ui/v2.5/src/components/Shared/MultiSet.tsx @@ -1,5 +1,5 @@ import React from "react"; -import { useIntl } from "react-intl"; +import { IntlShape, useIntl } from "react-intl"; import * as GQL from "src/core/generated-graphql"; import { Button, ButtonGroup } from "react-bootstrap"; @@ -52,70 +52,96 @@ const Select: React.FC = (props) => { ); }; -export const MultiSet: React.FC = (props) => { - const intl = useIntl(); - const modes = [ - GQL.BulkUpdateIdMode.Set, - GQL.BulkUpdateIdMode.Add, - GQL.BulkUpdateIdMode.Remove, - ]; - - function getModeText(mode: GQL.BulkUpdateIdMode) { - switch (mode) { - case GQL.BulkUpdateIdMode.Set: - return intl.formatMessage({ - id: "actions.overwrite", - defaultMessage: "Overwrite", - }); - case GQL.BulkUpdateIdMode.Add: - return intl.formatMessage({ id: "actions.add", defaultMessage: "Add" }); - case GQL.BulkUpdateIdMode.Remove: - return intl.formatMessage({ - id: "actions.remove", - defaultMessage: "Remove", - }); - } +function getModeText(intl: IntlShape, mode: GQL.BulkUpdateIdMode) { + switch (mode) { + case GQL.BulkUpdateIdMode.Set: + return intl.formatMessage({ + id: "actions.overwrite", + defaultMessage: "Overwrite", + }); + case GQL.BulkUpdateIdMode.Add: + return intl.formatMessage({ id: "actions.add", defaultMessage: "Add" }); + case GQL.BulkUpdateIdMode.Remove: + return intl.formatMessage({ + id: "actions.remove", + defaultMessage: "Remove", + }); } +} + +export const MultiSetModeButton: React.FC<{ + mode: GQL.BulkUpdateIdMode; + active: boolean; + onClick: () => void; + disabled?: boolean; +}> = ({ mode, active, onClick, disabled }) => { + const intl = useIntl(); + + return ( + + ); +}; + +const modes = [ + GQL.BulkUpdateIdMode.Set, + GQL.BulkUpdateIdMode.Add, + GQL.BulkUpdateIdMode.Remove, +]; - function onSetMode(mode: GQL.BulkUpdateIdMode) { - if (mode === props.mode) { +export const MultiSetModeButtons: React.FC<{ + mode: GQL.BulkUpdateIdMode; + onSetMode: (mode: GQL.BulkUpdateIdMode) => void; + disabled?: boolean; +}> = ({ mode, onSetMode, disabled }) => { + return ( + + {modes.map((m) => ( + onSetMode(m)} + disabled={disabled} + /> + ))} + + ); +}; + +export const MultiSet: React.FC = (props) => { + const { mode, onUpdate, existingIds } = props; + + function onSetMode(m: GQL.BulkUpdateIdMode) { + if (m === mode) { return; } // if going to Set, set the existing ids - if (mode === GQL.BulkUpdateIdMode.Set && props.existingIds) { - props.onUpdate(props.existingIds); + if (m === GQL.BulkUpdateIdMode.Set && existingIds) { + onUpdate(existingIds); // if going from Set, wipe the ids } else if ( - mode !== GQL.BulkUpdateIdMode.Set && - props.mode === GQL.BulkUpdateIdMode.Set + m !== GQL.BulkUpdateIdMode.Set && + mode === GQL.BulkUpdateIdMode.Set ) { - props.onUpdate([]); + onUpdate([]); } - props.onSetMode(mode); - } - - function renderModeButton(mode: GQL.BulkUpdateIdMode) { - return ( - - ); + props.onSetMode(m); } return (
- - {modes.map((m) => renderModeButton(m))} - +