diff --git a/graphql/schema/schema.graphql b/graphql/schema/schema.graphql index f11edb46f36..251c2af838c 100644 --- a/graphql/schema/schema.graphql +++ b/graphql/schema/schema.graphql @@ -359,6 +359,12 @@ type Mutation { groupsDestroy(ids: [ID!]!): Boolean! bulkGroupUpdate(input: BulkGroupUpdateInput!): [Group!] + addGroupSubGroups(input: GroupSubGroupAddInput!): Boolean! + removeGroupSubGroups(input: GroupSubGroupRemoveInput!): Boolean! + + "Reorder sub groups within a group. Returns true if successful." + reorderSubGroups(input: ReorderSubGroupsInput!): Boolean! + tagCreate(input: TagCreateInput!): Tag tagUpdate(input: TagUpdateInput!): Tag tagDestroy(input: TagDestroyInput!): Boolean! diff --git a/graphql/schema/types/filters.graphql b/graphql/schema/types/filters.graphql index 1ca8c1fb08d..f0f84efda8c 100644 --- a/graphql/schema/types/filters.graphql +++ b/graphql/schema/types/filters.graphql @@ -261,7 +261,7 @@ input SceneFilterType { "Filter to only include scenes with this movie" movies: MultiCriterionInput @deprecated(reason: "use groups instead") "Filter to only include scenes with this group" - groups: MultiCriterionInput + groups: HierarchicalMultiCriterionInput "Filter to only include scenes with this gallery" galleries: MultiCriterionInput "Filter to only include scenes with these tags" @@ -390,6 +390,15 @@ input GroupFilterType { "Filter by last update time" updated_at: TimestampCriterionInput + "Filter by containing groups" + containing_groups: HierarchicalMultiCriterionInput + "Filter by sub groups" + sub_groups: HierarchicalMultiCriterionInput + "Filter by number of containing groups the group has" + containing_group_count: IntCriterionInput + "Filter by number of sub-groups the group has" + sub_group_count: IntCriterionInput + "Filter by related scenes that meet this criteria" scenes_filter: SceneFilterType "Filter by related studios that meet this criteria" diff --git a/graphql/schema/types/group.graphql b/graphql/schema/types/group.graphql index 15bb3556ca6..b42e4fd1fef 100644 --- a/graphql/schema/types/group.graphql +++ b/graphql/schema/types/group.graphql @@ -1,3 +1,9 @@ +"GroupDescription represents a relationship to a group with a description of the relationship" +type GroupDescription { + group: Group! + description: String +} + type Group { id: ID! name: String! @@ -15,12 +21,21 @@ type Group { created_at: Time! updated_at: Time! + containing_groups: [GroupDescription!]! + sub_groups: [GroupDescription!]! + front_image_path: String # Resolver back_image_path: String # Resolver - scene_count: Int! # Resolver + scene_count(depth: Int): Int! # Resolver + sub_group_count(depth: Int): Int! # Resolver scenes: [Scene!]! } +input GroupDescriptionInput { + group_id: ID! + description: String +} + input GroupCreateInput { name: String! aliases: String @@ -34,6 +49,10 @@ input GroupCreateInput { synopsis: String urls: [String!] tag_ids: [ID!] + + containing_groups: [GroupDescriptionInput!] + sub_groups: [GroupDescriptionInput!] + "This should be a URL or a base64 encoded data URL" front_image: String "This should be a URL or a base64 encoded data URL" @@ -53,12 +72,21 @@ input GroupUpdateInput { synopsis: String urls: [String!] tag_ids: [ID!] + + containing_groups: [GroupDescriptionInput!] + sub_groups: [GroupDescriptionInput!] + "This should be a URL or a base64 encoded data URL" front_image: String "This should be a URL or a base64 encoded data URL" back_image: String } +input BulkUpdateGroupDescriptionsInput { + groups: [GroupDescriptionInput!]! + mode: BulkUpdateIdMode! +} + input BulkGroupUpdateInput { clientMutationId: String ids: [ID!] @@ -68,13 +96,42 @@ input BulkGroupUpdateInput { director: String urls: BulkUpdateStrings tag_ids: BulkUpdateIds + + containing_groups: BulkUpdateGroupDescriptionsInput + sub_groups: BulkUpdateGroupDescriptionsInput } input GroupDestroyInput { id: ID! } +input ReorderSubGroupsInput { + "ID of the group to reorder sub groups for" + group_id: ID! + """ + IDs of the sub groups to reorder. These must be a subset of the current sub groups. + Sub groups will be inserted in this order at the insert_index + """ + sub_group_ids: [ID!]! + "The sub-group ID at which to insert the sub groups" + insert_at_id: ID! + "If true, the sub groups will be inserted after the insert_index, otherwise they will be inserted before" + insert_after: Boolean +} + type FindGroupsResultType { count: Int! groups: [Group!]! } + +input GroupSubGroupAddInput { + containing_group_id: ID! + sub_groups: [GroupDescriptionInput!]! + "The index at which to insert the sub groups. If not provided, the sub groups will be appended to the end" + insert_index: Int +} + +input GroupSubGroupRemoveInput { + containing_group_id: ID! + sub_group_ids: [ID!]! +} diff --git a/graphql/schema/types/movie.graphql b/graphql/schema/types/movie.graphql index 0723bcc4f28..845827b3f17 100644 --- a/graphql/schema/types/movie.graphql +++ b/graphql/schema/types/movie.graphql @@ -18,7 +18,7 @@ type Movie { front_image_path: String # Resolver back_image_path: String # Resolver - scene_count: Int! # Resolver + scene_count(depth: Int): Int! # Resolver scenes: [Scene!]! } diff --git a/graphql/schema/types/performer.graphql b/graphql/schema/types/performer.graphql index 8ac6c6579ad..d6f3dd832c4 100644 --- a/graphql/schema/types/performer.graphql +++ b/graphql/schema/types/performer.graphql @@ -56,7 +56,7 @@ type Performer { weight: Int created_at: Time! updated_at: Time! - groups: [Group!]! @deprecated(reason: "use groups instead") + groups: [Group!]! movies: [Movie!]! @deprecated(reason: "use groups instead") } diff --git a/internal/api/changeset_translator.go b/internal/api/changeset_translator.go index efac25087d6..1170088aac9 100644 --- a/internal/api/changeset_translator.go +++ b/internal/api/changeset_translator.go @@ -434,3 +434,64 @@ func (t changesetTranslator) updateGroupIDsBulk(value *BulkUpdateIds, field stri Mode: value.Mode, }, nil } + +func groupsDescriptionsFromGroupInput(input []*GroupDescriptionInput) ([]models.GroupIDDescription, error) { + ret := make([]models.GroupIDDescription, len(input)) + + for i, v := range input { + gID, err := strconv.Atoi(v.GroupID) + if err != nil { + return nil, fmt.Errorf("invalid group ID: %s", v.GroupID) + } + + ret[i] = models.GroupIDDescription{ + GroupID: gID, + } + if v.Description != nil { + ret[i].Description = *v.Description + } + } + + return ret, nil +} + +func (t changesetTranslator) groupIDDescriptions(value []*GroupDescriptionInput) (models.RelatedGroupDescriptions, error) { + groupsScenes, err := groupsDescriptionsFromGroupInput(value) + if err != nil { + return models.RelatedGroupDescriptions{}, err + } + + return models.NewRelatedGroupDescriptions(groupsScenes), nil +} + +func (t changesetTranslator) updateGroupIDDescriptions(value []*GroupDescriptionInput, field string) (*models.UpdateGroupDescriptions, error) { + if !t.hasField(field) { + return nil, nil + } + + groupsScenes, err := groupsDescriptionsFromGroupInput(value) + if err != nil { + return nil, err + } + + return &models.UpdateGroupDescriptions{ + Groups: groupsScenes, + Mode: models.RelationshipUpdateModeSet, + }, nil +} + +func (t changesetTranslator) updateGroupIDDescriptionsBulk(value *BulkUpdateGroupDescriptionsInput, field string) (*models.UpdateGroupDescriptions, error) { + if !t.hasField(field) || value == nil { + return nil, nil + } + + groups, err := groupsDescriptionsFromGroupInput(value.Groups) + if err != nil { + return nil, err + } + + return &models.UpdateGroupDescriptions{ + Groups: groups, + Mode: value.Mode, + }, nil +} diff --git a/internal/api/resolver.go b/internal/api/resolver.go index e5c635b9a7d..ab6eead7e5e 100644 --- a/internal/api/resolver.go +++ b/internal/api/resolver.go @@ -37,6 +37,7 @@ type Resolver struct { sceneService manager.SceneService imageService manager.ImageService galleryService manager.GalleryService + groupService manager.GroupService hookExecutor hookExecutor } diff --git a/internal/api/resolver_model_movie.go b/internal/api/resolver_model_movie.go index abbbccaf10a..04018d81fbb 100644 --- a/internal/api/resolver_model_movie.go +++ b/internal/api/resolver_model_movie.go @@ -5,7 +5,9 @@ import ( "github.com/stashapp/stash/internal/api/loaders" "github.com/stashapp/stash/internal/api/urlbuilders" + "github.com/stashapp/stash/pkg/group" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/scene" ) func (r *groupResolver) Date(ctx context.Context, obj *models.Group) (*string, error) { @@ -71,6 +73,68 @@ func (r groupResolver) Tags(ctx context.Context, obj *models.Group) (ret []*mode return ret, firstError(errs) } +func (r groupResolver) relatedGroups(ctx context.Context, rgd models.RelatedGroupDescriptions) (ret []*GroupDescription, err error) { + // rgd must be loaded + gds := rgd.List() + ids := make([]int, len(gds)) + for i, gd := range gds { + ids[i] = gd.GroupID + } + + groups, errs := loaders.From(ctx).GroupByID.LoadAll(ids) + + err = firstError(errs) + if err != nil { + return + } + + ret = make([]*GroupDescription, len(groups)) + for i, group := range groups { + ret[i] = &GroupDescription{Group: group} + d := gds[i].Description + if d != "" { + ret[i].Description = &d + } + } + + return ret, firstError(errs) +} + +func (r groupResolver) ContainingGroups(ctx context.Context, obj *models.Group) (ret []*GroupDescription, err error) { + if !obj.ContainingGroups.Loaded() { + if err := r.withReadTxn(ctx, func(ctx context.Context) error { + return obj.LoadContainingGroupIDs(ctx, r.repository.Group) + }); err != nil { + return nil, err + } + } + + return r.relatedGroups(ctx, obj.ContainingGroups) +} + +func (r groupResolver) SubGroups(ctx context.Context, obj *models.Group) (ret []*GroupDescription, err error) { + if !obj.SubGroups.Loaded() { + if err := r.withReadTxn(ctx, func(ctx context.Context) error { + return obj.LoadSubGroupIDs(ctx, r.repository.Group) + }); err != nil { + return nil, err + } + } + + return r.relatedGroups(ctx, obj.SubGroups) +} + +func (r *groupResolver) SubGroupCount(ctx context.Context, obj *models.Group, depth *int) (ret int, err error) { + if err := r.withReadTxn(ctx, func(ctx context.Context) error { + ret, err = group.CountByContainingGroupID(ctx, r.repository.Group, obj.ID, depth) + return err + }); err != nil { + return 0, err + } + + return ret, nil +} + func (r *groupResolver) FrontImagePath(ctx context.Context, obj *models.Group) (*string, error) { var hasImage bool if err := r.withReadTxn(ctx, func(ctx context.Context) error { @@ -106,9 +170,9 @@ func (r *groupResolver) BackImagePath(ctx context.Context, obj *models.Group) (* return &imagePath, nil } -func (r *groupResolver) SceneCount(ctx context.Context, obj *models.Group) (ret int, err error) { +func (r *groupResolver) SceneCount(ctx context.Context, obj *models.Group, depth *int) (ret int, err error) { if err := r.withReadTxn(ctx, func(ctx context.Context) error { - ret, err = r.repository.Scene.CountByGroupID(ctx, obj.ID) + ret, err = scene.CountByGroupID(ctx, r.repository.Scene, obj.ID, depth) return err }); err != nil { return 0, err diff --git a/internal/api/resolver_mutation_group.go b/internal/api/resolver_mutation_group.go index d455dd1058c..d75994d1497 100644 --- a/internal/api/resolver_mutation_group.go +++ b/internal/api/resolver_mutation_group.go @@ -6,6 +6,7 @@ import ( "strconv" "github.com/stashapp/stash/internal/static" + "github.com/stashapp/stash/pkg/group" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/plugin/hook" "github.com/stashapp/stash/pkg/sliceutil/stringslice" @@ -43,6 +44,16 @@ func groupFromGroupCreateInput(ctx context.Context, input GroupCreateInput) (*mo return nil, fmt.Errorf("converting tag ids: %w", err) } + newGroup.ContainingGroups, err = translator.groupIDDescriptions(input.ContainingGroups) + if err != nil { + return nil, fmt.Errorf("converting containing group ids: %w", err) + } + + newGroup.SubGroups, err = translator.groupIDDescriptions(input.SubGroups) + if err != nil { + return nil, fmt.Errorf("converting containing group ids: %w", err) + } + if input.Urls != nil { newGroup.URLs = models.NewRelatedStrings(input.Urls) } @@ -82,26 +93,10 @@ func (r *mutationResolver) GroupCreate(ctx context.Context, input GroupCreateInp // Start the transaction and save the group if err := r.withTxn(ctx, func(ctx context.Context) error { - qb := r.repository.Group - - err = qb.Create(ctx, newGroup) - if err != nil { + if err = r.groupService.Create(ctx, newGroup, frontimageData, backimageData); err != nil { return err } - // update image table - if len(frontimageData) > 0 { - if err := qb.UpdateFrontImage(ctx, newGroup.ID, frontimageData); err != nil { - return err - } - } - - if len(backimageData) > 0 { - if err := qb.UpdateBackImage(ctx, newGroup.ID, backimageData); err != nil { - return err - } - } - return nil }); err != nil { return nil, err @@ -141,6 +136,18 @@ func groupPartialFromGroupUpdateInput(translator changesetTranslator, input Grou return } + updatedGroup.ContainingGroups, err = translator.updateGroupIDDescriptions(input.ContainingGroups, "containing_groups") + if err != nil { + err = fmt.Errorf("converting containing group ids: %w", err) + return + } + + updatedGroup.SubGroups, err = translator.updateGroupIDDescriptions(input.SubGroups, "sub_groups") + if err != nil { + err = fmt.Errorf("converting containing group ids: %w", err) + return + } + updatedGroup.URLs = translator.updateStrings(input.Urls, "urls") return updatedGroup, nil @@ -179,26 +186,20 @@ func (r *mutationResolver) GroupUpdate(ctx context.Context, input GroupUpdateInp } } - // Start the transaction and save the group - var group *models.Group if err := r.withTxn(ctx, func(ctx context.Context) error { - qb := r.repository.Group - group, err = qb.UpdatePartial(ctx, groupID, updatedGroup) - if err != nil { - return err + frontImage := group.ImageInput{ + Image: frontimageData, + Set: frontImageIncluded, } - // update image table - if frontImageIncluded { - if err := qb.UpdateFrontImage(ctx, group.ID, frontimageData); err != nil { - return err - } + backImage := group.ImageInput{ + Image: backimageData, + Set: backImageIncluded, } - if backImageIncluded { - if err := qb.UpdateBackImage(ctx, group.ID, backimageData); err != nil { - return err - } + _, err = r.groupService.UpdatePartial(ctx, groupID, updatedGroup, frontImage, backImage) + if err != nil { + return err } return nil @@ -207,9 +208,9 @@ func (r *mutationResolver) GroupUpdate(ctx context.Context, input GroupUpdateInp } // for backwards compatibility - run both movie and group hooks - r.hookExecutor.ExecutePostHooks(ctx, group.ID, hook.GroupUpdatePost, input, translator.getFields()) - r.hookExecutor.ExecutePostHooks(ctx, group.ID, hook.MovieUpdatePost, input, translator.getFields()) - return r.getGroup(ctx, group.ID) + r.hookExecutor.ExecutePostHooks(ctx, groupID, hook.GroupUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, groupID, hook.MovieUpdatePost, input, translator.getFields()) + return r.getGroup(ctx, groupID) } func groupPartialFromBulkGroupUpdateInput(translator changesetTranslator, input BulkGroupUpdateInput) (ret models.GroupPartial, err error) { @@ -230,6 +231,18 @@ func groupPartialFromBulkGroupUpdateInput(translator changesetTranslator, input return } + updatedGroup.ContainingGroups, err = translator.updateGroupIDDescriptionsBulk(input.ContainingGroups, "containing_groups") + if err != nil { + err = fmt.Errorf("converting containing group ids: %w", err) + return + } + + updatedGroup.SubGroups, err = translator.updateGroupIDDescriptionsBulk(input.SubGroups, "sub_groups") + if err != nil { + err = fmt.Errorf("converting containing group ids: %w", err) + return + } + updatedGroup.URLs = translator.optionalURLsBulk(input.Urls, nil) return updatedGroup, nil @@ -254,10 +267,8 @@ func (r *mutationResolver) BulkGroupUpdate(ctx context.Context, input BulkGroupU ret := []*models.Group{} if err := r.withTxn(ctx, func(ctx context.Context) error { - qb := r.repository.Group - for _, groupID := range groupIDs { - group, err := qb.UpdatePartial(ctx, groupID, updatedGroup) + group, err := r.groupService.UpdatePartial(ctx, groupID, updatedGroup, group.ImageInput{}, group.ImageInput{}) if err != nil { return err } @@ -333,3 +344,70 @@ func (r *mutationResolver) GroupsDestroy(ctx context.Context, groupIDs []string) return true, nil } + +func (r *mutationResolver) AddGroupSubGroups(ctx context.Context, input GroupSubGroupAddInput) (bool, error) { + groupID, err := strconv.Atoi(input.ContainingGroupID) + if err != nil { + return false, fmt.Errorf("converting group id: %w", err) + } + + subGroups, err := groupsDescriptionsFromGroupInput(input.SubGroups) + if err != nil { + return false, fmt.Errorf("converting sub group ids: %w", err) + } + + if err := r.withTxn(ctx, func(ctx context.Context) error { + return r.groupService.AddSubGroups(ctx, groupID, subGroups, input.InsertIndex) + }); err != nil { + return false, err + } + + return true, nil +} + +func (r *mutationResolver) RemoveGroupSubGroups(ctx context.Context, input GroupSubGroupRemoveInput) (bool, error) { + groupID, err := strconv.Atoi(input.ContainingGroupID) + if err != nil { + return false, fmt.Errorf("converting group id: %w", err) + } + + subGroupIDs, err := stringslice.StringSliceToIntSlice(input.SubGroupIds) + if err != nil { + return false, fmt.Errorf("converting sub group ids: %w", err) + } + + if err := r.withTxn(ctx, func(ctx context.Context) error { + return r.groupService.RemoveSubGroups(ctx, groupID, subGroupIDs) + }); err != nil { + return false, err + } + + return true, nil +} + +func (r *mutationResolver) ReorderSubGroups(ctx context.Context, input ReorderSubGroupsInput) (bool, error) { + groupID, err := strconv.Atoi(input.GroupID) + if err != nil { + return false, fmt.Errorf("converting group id: %w", err) + } + + subGroupIDs, err := stringslice.StringSliceToIntSlice(input.SubGroupIds) + if err != nil { + return false, fmt.Errorf("converting sub group ids: %w", err) + } + + insertPointID, err := strconv.Atoi(input.InsertAtID) + if err != nil { + return false, fmt.Errorf("converting insert at id: %w", err) + } + + insertAfter := utils.IsTrue(input.InsertAfter) + + if err := r.withTxn(ctx, func(ctx context.Context) error { + return r.groupService.ReorderSubGroups(ctx, groupID, subGroupIDs, insertPointID, insertAfter) + }); err != nil { + return false, err + } + + return true, nil +} diff --git a/internal/api/server.go b/internal/api/server.go index b32ee04a027..63a81da7c2e 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -158,11 +158,13 @@ func Initialize() (*Server, error) { sceneService := mgr.SceneService imageService := mgr.ImageService galleryService := mgr.GalleryService + groupService := mgr.GroupService resolver := &Resolver{ repository: repo, sceneService: sceneService, imageService: imageService, galleryService: galleryService, + groupService: groupService, hookExecutor: pluginCache, } diff --git a/internal/dlna/cds.go b/internal/dlna/cds.go index 531fc1cb55c..a38e0e55bed 100644 --- a/internal/dlna/cds.go +++ b/internal/dlna/cds.go @@ -682,7 +682,7 @@ func (me *contentDirectoryService) getGroups() []interface{} { func (me *contentDirectoryService) getGroupScenes(paths []string, host string) []interface{} { sceneFilter := &models.SceneFilterType{ - Groups: &models.MultiCriterionInput{ + Groups: &models.HierarchicalMultiCriterionInput{ Modifier: models.CriterionModifierIncludes, Value: []string{paths[0]}, }, diff --git a/internal/manager/init.go b/internal/manager/init.go index 347d08a153e..020ba944d40 100644 --- a/internal/manager/init.go +++ b/internal/manager/init.go @@ -17,6 +17,7 @@ import ( "github.com/stashapp/stash/pkg/ffmpeg" "github.com/stashapp/stash/pkg/fsutil" "github.com/stashapp/stash/pkg/gallery" + "github.com/stashapp/stash/pkg/group" "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/job" "github.com/stashapp/stash/pkg/logger" @@ -67,6 +68,10 @@ func Initialize(cfg *config.Config, l *log.Logger) (*Manager, error) { Folder: db.Folder, } + groupService := &group.Service{ + Repository: db.Group, + } + sceneServer := &SceneServer{ TxnManager: repo.TxnManager, SceneCoverGetter: repo.Scene, @@ -99,6 +104,7 @@ func Initialize(cfg *config.Config, l *log.Logger) (*Manager, error) { SceneService: sceneService, ImageService: imageService, GalleryService: galleryService, + GroupService: groupService, scanSubs: &subscriptionManager{}, } diff --git a/internal/manager/manager.go b/internal/manager/manager.go index 397503930dc..ffba184d2bd 100644 --- a/internal/manager/manager.go +++ b/internal/manager/manager.go @@ -66,6 +66,7 @@ type Manager struct { SceneService SceneService ImageService ImageService GalleryService GalleryService + GroupService GroupService scanSubs *subscriptionManager } diff --git a/internal/manager/repository.go b/internal/manager/repository.go index 766f8039f85..13e1e8ae81b 100644 --- a/internal/manager/repository.go +++ b/internal/manager/repository.go @@ -3,6 +3,7 @@ package manager import ( "context" + "github.com/stashapp/stash/pkg/group" "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scene" @@ -33,3 +34,12 @@ type GalleryService interface { Updated(ctx context.Context, galleryID int) error } + +type GroupService interface { + Create(ctx context.Context, group *models.Group, frontimageData []byte, backimageData []byte) error + UpdatePartial(ctx context.Context, id int, updatedGroup models.GroupPartial, frontImage group.ImageInput, backImage group.ImageInput) (*models.Group, error) + + AddSubGroups(ctx context.Context, groupID int, subGroups []models.GroupIDDescription, insertIndex *int) error + RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error + ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertPointID int, insertAfter bool) error +} diff --git a/internal/manager/task_export.go b/internal/manager/task_export.go index ecbcf593af5..19abba2158d 100644 --- a/internal/manager/task_export.go +++ b/internal/manager/task_export.go @@ -1134,6 +1134,10 @@ func (t *ExportTask) exportGroup(ctx context.Context, wg *sync.WaitGroup, jobCha logger.Errorf("[groups] <%s> error getting group urls: %v", m.Name, err) continue } + if err := m.LoadSubGroupIDs(ctx, r.Group); err != nil { + logger.Errorf("[groups] <%s> error getting group sub-groups: %v", m.Name, err) + continue + } newGroupJSON, err := group.ToJSON(ctx, groupReader, studioReader, m) @@ -1150,6 +1154,25 @@ func (t *ExportTask) exportGroup(ctx context.Context, wg *sync.WaitGroup, jobCha newGroupJSON.Tags = tag.GetNames(tags) + subGroups := m.SubGroups.List() + if err := func() error { + for _, sg := range subGroups { + subGroup, err := groupReader.Find(ctx, sg.GroupID) + if err != nil { + return fmt.Errorf("error getting sub group: %v", err) + } + + newGroupJSON.SubGroups = append(newGroupJSON.SubGroups, jsonschema.SubGroupDescription{ + // TODO - this won't be unique + Group: subGroup.Name, + Description: sg.Description, + }) + } + return nil + }(); err != nil { + logger.Errorf("[groups] <%s> %v", m.Name, err) + } + if t.includeDependencies { if m.StudioID != nil { t.studios.IDs = sliceutil.AppendUnique(t.studios.IDs, *m.StudioID) diff --git a/internal/manager/task_import.go b/internal/manager/task_import.go index ae9a5865765..87185c66183 100644 --- a/internal/manager/task_import.go +++ b/internal/manager/task_import.go @@ -327,6 +327,7 @@ func (t *ImportTask) importStudio(ctx context.Context, studioJSON *jsonschema.St func (t *ImportTask) ImportGroups(ctx context.Context) { logger.Info("[groups] importing") + pendingSubs := make(map[string][]*jsonschema.Group) path := t.json.json.Groups files, err := os.ReadDir(path) @@ -351,24 +352,72 @@ func (t *ImportTask) ImportGroups(ctx context.Context) { logger.Progressf("[groups] %d of %d", index, len(files)) if err := r.WithTxn(ctx, func(ctx context.Context) error { - groupImporter := &group.Importer{ - ReaderWriter: r.Group, - StudioWriter: r.Studio, - TagWriter: r.Tag, - Input: *groupJSON, - MissingRefBehaviour: t.MissingRefBehaviour, + return t.importGroup(ctx, groupJSON, pendingSubs, false) + }); err != nil { + var subError group.SubGroupNotExistError + if errors.As(err, &subError) { + missingSub := subError.MissingSubGroup() + pendingSubs[missingSub] = append(pendingSubs[missingSub], groupJSON) + continue } - return performImport(ctx, groupImporter, t.DuplicateBehaviour) - }); err != nil { - logger.Errorf("[groups] <%s> import failed: %v", fi.Name(), err) + logger.Errorf("[groups] <%s> failed to import: %v", fi.Name(), err) continue } } + for _, s := range pendingSubs { + for _, orphanGroupJSON := range s { + if err := r.WithTxn(ctx, func(ctx context.Context) error { + return t.importGroup(ctx, orphanGroupJSON, nil, true) + }); err != nil { + logger.Errorf("[groups] <%s> failed to create: %v", orphanGroupJSON.Name, err) + continue + } + } + } + logger.Info("[groups] import complete") } +func (t *ImportTask) importGroup(ctx context.Context, groupJSON *jsonschema.Group, pendingSub map[string][]*jsonschema.Group, fail bool) error { + r := t.repository + + importer := &group.Importer{ + ReaderWriter: r.Group, + StudioWriter: r.Studio, + TagWriter: r.Tag, + Input: *groupJSON, + MissingRefBehaviour: t.MissingRefBehaviour, + } + + // first phase: return error if parent does not exist + if !fail { + importer.MissingRefBehaviour = models.ImportMissingRefEnumFail + } + + if err := performImport(ctx, importer, t.DuplicateBehaviour); err != nil { + return err + } + + for _, containingGroupJSON := range pendingSub[groupJSON.Name] { + if err := t.importGroup(ctx, containingGroupJSON, pendingSub, fail); err != nil { + var subError group.SubGroupNotExistError + if errors.As(err, &subError) { + missingSub := subError.MissingSubGroup() + pendingSub[missingSub] = append(pendingSub[missingSub], containingGroupJSON) + continue + } + + return fmt.Errorf("failed to create containing group <%s>: %v", containingGroupJSON.Name, err) + } + } + + delete(pendingSub, groupJSON.Name) + + return nil +} + func (t *ImportTask) ImportFiles(ctx context.Context) { logger.Info("[files] importing") diff --git a/pkg/group/create.go b/pkg/group/create.go new file mode 100644 index 00000000000..56d6b7a4ed4 --- /dev/null +++ b/pkg/group/create.go @@ -0,0 +1,41 @@ +package group + +import ( + "context" + "errors" + + "github.com/stashapp/stash/pkg/models" +) + +var ( + ErrEmptyName = errors.New("name cannot be empty") + ErrHierarchyLoop = errors.New("a group cannot be contained by one of its subgroups") +) + +func (s *Service) Create(ctx context.Context, group *models.Group, frontimageData []byte, backimageData []byte) error { + r := s.Repository + + if err := s.validateCreate(ctx, group); err != nil { + return err + } + + err := r.Create(ctx, group) + if err != nil { + return err + } + + // update image table + if len(frontimageData) > 0 { + if err := r.UpdateFrontImage(ctx, group.ID, frontimageData); err != nil { + return err + } + } + + if len(backimageData) > 0 { + if err := r.UpdateBackImage(ctx, group.ID, backimageData); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/group/import.go b/pkg/group/import.go index 4bf038c8776..589e75df30d 100644 --- a/pkg/group/import.go +++ b/pkg/group/import.go @@ -16,6 +16,18 @@ type ImporterReaderWriter interface { FindByName(ctx context.Context, name string, nocase bool) (*models.Group, error) } +type SubGroupNotExistError struct { + missingSubGroup string +} + +func (e SubGroupNotExistError) Error() string { + return fmt.Sprintf("sub group <%s> does not exist", e.missingSubGroup) +} + +func (e SubGroupNotExistError) MissingSubGroup() string { + return e.missingSubGroup +} + type Importer struct { ReaderWriter ImporterReaderWriter StudioWriter models.StudioFinderCreator @@ -202,6 +214,22 @@ func (i *Importer) createStudio(ctx context.Context, name string) (int, error) { } func (i *Importer) PostImport(ctx context.Context, id int) error { + subGroups, err := i.getSubGroups(ctx) + if err != nil { + return err + } + + if len(subGroups) > 0 { + if _, err := i.ReaderWriter.UpdatePartial(ctx, id, models.GroupPartial{ + SubGroups: &models.UpdateGroupDescriptions{ + Groups: subGroups, + Mode: models.RelationshipUpdateModeSet, + }, + }); err != nil { + return fmt.Errorf("error setting parents: %v", err) + } + } + if len(i.frontImageData) > 0 { if err := i.ReaderWriter.UpdateFrontImage(ctx, id, i.frontImageData); err != nil { return fmt.Errorf("error setting group front image: %v", err) @@ -256,3 +284,53 @@ func (i *Importer) Update(ctx context.Context, id int) error { return nil } + +func (i *Importer) getSubGroups(ctx context.Context) ([]models.GroupIDDescription, error) { + var subGroups []models.GroupIDDescription + for _, subGroup := range i.Input.SubGroups { + group, err := i.ReaderWriter.FindByName(ctx, subGroup.Group, false) + if err != nil { + return nil, fmt.Errorf("error finding parent by name: %v", err) + } + + if group == nil { + if i.MissingRefBehaviour == models.ImportMissingRefEnumFail { + return nil, SubGroupNotExistError{missingSubGroup: subGroup.Group} + } + + if i.MissingRefBehaviour == models.ImportMissingRefEnumIgnore { + continue + } + + if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { + parentID, err := i.createSubGroup(ctx, subGroup.Group) + if err != nil { + return nil, err + } + subGroups = append(subGroups, models.GroupIDDescription{ + GroupID: parentID, + Description: subGroup.Description, + }) + } + } else { + subGroups = append(subGroups, models.GroupIDDescription{ + GroupID: group.ID, + Description: subGroup.Description, + }) + } + } + + return subGroups, nil +} + +func (i *Importer) createSubGroup(ctx context.Context, name string) (int, error) { + newGroup := models.NewGroup() + newGroup.Name = name + + err := i.ReaderWriter.Create(ctx, &newGroup) + if err != nil { + return 0, err + } + + return newGroup.ID, nil +} diff --git a/pkg/group/query.go b/pkg/group/query.go index bc0753b0055..b3adafaf523 100644 --- a/pkg/group/query.go +++ b/pkg/group/query.go @@ -30,3 +30,15 @@ func CountByTagID(ctx context.Context, r models.GroupQueryer, id int, depth *int return r.QueryCount(ctx, filter, nil) } + +func CountByContainingGroupID(ctx context.Context, r models.GroupQueryer, id int, depth *int) (int, error) { + filter := &models.GroupFilterType{ + ContainingGroups: &models.HierarchicalMultiCriterionInput{ + Value: []string{strconv.Itoa(id)}, + Modifier: models.CriterionModifierIncludes, + Depth: depth, + }, + } + + return r.QueryCount(ctx, filter, nil) +} diff --git a/pkg/group/reorder.go b/pkg/group/reorder.go new file mode 100644 index 00000000000..b4afd1b0968 --- /dev/null +++ b/pkg/group/reorder.go @@ -0,0 +1,33 @@ +package group + +import ( + "context" + "errors" + + "github.com/stashapp/stash/pkg/models" +) + +var ErrInvalidInsertIndex = errors.New("invalid insert index") + +func (s *Service) ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertPointID int, insertAfter bool) error { + // get the group + existing, err := s.Repository.Find(ctx, groupID) + if err != nil { + return err + } + + // ensure it exists + if existing == nil { + return models.ErrNotFound + } + + // TODO - ensure the subgroups exist in the group + + // ensure the insert index is valid + if insertPointID < 0 { + return ErrInvalidInsertIndex + } + + // reorder the subgroups + return s.Repository.ReorderSubGroups(ctx, groupID, subGroupIDs, insertPointID, insertAfter) +} diff --git a/pkg/group/service.go b/pkg/group/service.go new file mode 100644 index 00000000000..ff6e0354184 --- /dev/null +++ b/pkg/group/service.go @@ -0,0 +1,46 @@ +package group + +import ( + "context" + + "github.com/stashapp/stash/pkg/models" +) + +type CreatorUpdater interface { + models.GroupGetter + models.GroupCreator + models.GroupUpdater + + models.ContainingGroupLoader + models.SubGroupLoader + + AnscestorFinder + SubGroupIDFinder + SubGroupAdder + SubGroupRemover + SubGroupReorderer +} + +type AnscestorFinder interface { + FindInAncestors(ctx context.Context, ascestorIDs []int, ids []int) ([]int, error) +} + +type SubGroupIDFinder interface { + FindSubGroupIDs(ctx context.Context, containingID int, ids []int) ([]int, error) +} + +type SubGroupAdder interface { + AddSubGroups(ctx context.Context, groupID int, subGroups []models.GroupIDDescription, insertIndex *int) error +} + +type SubGroupRemover interface { + RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error +} + +type SubGroupReorderer interface { + ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertID int, insertAfter bool) error +} + +type Service struct { + Repository CreatorUpdater +} diff --git a/pkg/group/update.go b/pkg/group/update.go new file mode 100644 index 00000000000..d0bc9602add --- /dev/null +++ b/pkg/group/update.go @@ -0,0 +1,112 @@ +package group + +import ( + "context" + "fmt" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sliceutil" +) + +type SubGroupAlreadyInGroupError struct { + GroupIDs []int +} + +func (e *SubGroupAlreadyInGroupError) Error() string { + return fmt.Sprintf("subgroups with IDs %v already in group", e.GroupIDs) +} + +type ImageInput struct { + Image []byte + Set bool +} + +func (s *Service) UpdatePartial(ctx context.Context, id int, updatedGroup models.GroupPartial, frontImage ImageInput, backImage ImageInput) (*models.Group, error) { + if err := s.validateUpdate(ctx, id, updatedGroup); err != nil { + return nil, err + } + + r := s.Repository + + group, err := r.UpdatePartial(ctx, id, updatedGroup) + if err != nil { + return nil, err + } + + // update image table + if frontImage.Set { + if err := r.UpdateFrontImage(ctx, id, frontImage.Image); err != nil { + return nil, err + } + } + + if backImage.Set { + if err := r.UpdateBackImage(ctx, id, backImage.Image); err != nil { + return nil, err + } + } + + return group, nil +} + +func (s *Service) AddSubGroups(ctx context.Context, groupID int, subGroups []models.GroupIDDescription, insertIndex *int) error { + // get the group + existing, err := s.Repository.Find(ctx, groupID) + if err != nil { + return err + } + + // ensure it exists + if existing == nil { + return models.ErrNotFound + } + + // ensure the subgroups aren't already sub-groups of the group + subGroupIDs := sliceutil.Map(subGroups, func(sg models.GroupIDDescription) int { + return sg.GroupID + }) + + existingSubGroupIDs, err := s.Repository.FindSubGroupIDs(ctx, groupID, subGroupIDs) + if err != nil { + return err + } + + if len(existingSubGroupIDs) > 0 { + return &SubGroupAlreadyInGroupError{ + GroupIDs: existingSubGroupIDs, + } + } + + // validate the hierarchy + d := &models.UpdateGroupDescriptions{ + Groups: subGroups, + Mode: models.RelationshipUpdateModeAdd, + } + if err := s.validateUpdateGroupHierarchy(ctx, existing, nil, d); err != nil { + return err + } + + // validate insert index + if insertIndex != nil && *insertIndex < 0 { + return ErrInvalidInsertIndex + } + + // add the subgroups + return s.Repository.AddSubGroups(ctx, groupID, subGroups, insertIndex) +} + +func (s *Service) RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error { + // get the group + existing, err := s.Repository.Find(ctx, groupID) + if err != nil { + return err + } + + // ensure it exists + if existing == nil { + return models.ErrNotFound + } + + // add the subgroups + return s.Repository.RemoveSubGroups(ctx, groupID, subGroupIDs) +} diff --git a/pkg/group/validate.go b/pkg/group/validate.go new file mode 100644 index 00000000000..723b9f6997a --- /dev/null +++ b/pkg/group/validate.go @@ -0,0 +1,117 @@ +package group + +import ( + "context" + "strings" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sliceutil" +) + +func (s *Service) validateCreate(ctx context.Context, group *models.Group) error { + if err := validateName(group.Name); err != nil { + return err + } + + containingIDs := group.ContainingGroups.IDs() + subIDs := group.SubGroups.IDs() + + if err := s.validateGroupHierarchy(ctx, containingIDs, subIDs); err != nil { + return err + } + + return nil +} + +func (s *Service) validateUpdate(ctx context.Context, id int, partial models.GroupPartial) error { + // get the existing group - ensure it exists + existing, err := s.Repository.Find(ctx, id) + if err != nil { + return err + } + + if existing == nil { + return models.ErrNotFound + } + + if partial.Name.Set { + if err := validateName(partial.Name.Value); err != nil { + return err + } + } + + if err := s.validateUpdateGroupHierarchy(ctx, existing, partial.ContainingGroups, partial.SubGroups); err != nil { + return err + } + + return nil +} + +func validateName(n string) error { + // ensure name is not empty + if strings.TrimSpace(n) == "" { + return ErrEmptyName + } + + return nil +} + +func (s *Service) validateGroupHierarchy(ctx context.Context, containingIDs []int, subIDs []int) error { + // only need to validate if both are non-empty + if len(containingIDs) == 0 || len(subIDs) == 0 { + return nil + } + + // ensure none of the containing groups are in the sub groups + found, err := s.Repository.FindInAncestors(ctx, containingIDs, subIDs) + if err != nil { + return err + } + + if len(found) > 0 { + return ErrHierarchyLoop + } + + return nil +} + +func (s *Service) validateUpdateGroupHierarchy(ctx context.Context, existing *models.Group, containingGroups *models.UpdateGroupDescriptions, subGroups *models.UpdateGroupDescriptions) error { + // no need to validate if there are no changes + if containingGroups == nil && subGroups == nil { + return nil + } + + if err := existing.LoadContainingGroupIDs(ctx, s.Repository); err != nil { + return err + } + existingContainingGroups := existing.ContainingGroups.List() + + if err := existing.LoadSubGroupIDs(ctx, s.Repository); err != nil { + return err + } + existingSubGroups := existing.SubGroups.List() + + effectiveContainingGroups := existingContainingGroups + if containingGroups != nil { + effectiveContainingGroups = containingGroups.Apply(existingContainingGroups) + } + + effectiveSubGroups := existingSubGroups + if subGroups != nil { + effectiveSubGroups = subGroups.Apply(existingSubGroups) + } + + containingIDs := idsFromGroupDescriptions(effectiveContainingGroups) + subIDs := idsFromGroupDescriptions(effectiveSubGroups) + + // ensure we haven't set the group as a subgroup of itself + if sliceutil.Contains(containingIDs, existing.ID) || sliceutil.Contains(subIDs, existing.ID) { + return ErrHierarchyLoop + } + + return s.validateGroupHierarchy(ctx, containingIDs, subIDs) +} + +func idsFromGroupDescriptions(v []models.GroupIDDescription) []int { + return sliceutil.Map(v, func(g models.GroupIDDescription) int { return g.GroupID }) +} diff --git a/pkg/models/group.go b/pkg/models/group.go index db7badccc90..6afda3f4890 100644 --- a/pkg/models/group.go +++ b/pkg/models/group.go @@ -23,6 +23,14 @@ type GroupFilterType struct { TagCount *IntCriterionInput `json:"tag_count"` // Filter by date Date *DateCriterionInput `json:"date"` + // Filter by containing groups + ContainingGroups *HierarchicalMultiCriterionInput `json:"containing_groups"` + // Filter by sub groups + SubGroups *HierarchicalMultiCriterionInput `json:"sub_groups"` + // Filter by number of containing groups the group has + ContainingGroupCount *IntCriterionInput `json:"containing_group_count"` + // Filter by number of sub-groups the group has + SubGroupCount *IntCriterionInput `json:"sub_group_count"` // Filter by related scenes that meet this criteria ScenesFilter *SceneFilterType `json:"scenes_filter"` // Filter by related studios that meet this criteria diff --git a/pkg/models/jsonschema/group.go b/pkg/models/jsonschema/group.go index fcf1ffe60a0..b284dab6e77 100644 --- a/pkg/models/jsonschema/group.go +++ b/pkg/models/jsonschema/group.go @@ -11,21 +11,27 @@ import ( "github.com/stashapp/stash/pkg/models/json" ) +type SubGroupDescription struct { + Group string `json:"name,omitempty"` + Description string `json:"description,omitempty"` +} + type Group struct { - Name string `json:"name,omitempty"` - Aliases string `json:"aliases,omitempty"` - Duration int `json:"duration,omitempty"` - Date string `json:"date,omitempty"` - Rating int `json:"rating,omitempty"` - Director string `json:"director,omitempty"` - Synopsis string `json:"synopsis,omitempty"` - FrontImage string `json:"front_image,omitempty"` - BackImage string `json:"back_image,omitempty"` - URLs []string `json:"urls,omitempty"` - Studio string `json:"studio,omitempty"` - Tags []string `json:"tags,omitempty"` - CreatedAt json.JSONTime `json:"created_at,omitempty"` - UpdatedAt json.JSONTime `json:"updated_at,omitempty"` + Name string `json:"name,omitempty"` + Aliases string `json:"aliases,omitempty"` + Duration int `json:"duration,omitempty"` + Date string `json:"date,omitempty"` + Rating int `json:"rating,omitempty"` + Director string `json:"director,omitempty"` + Synopsis string `json:"synopsis,omitempty"` + FrontImage string `json:"front_image,omitempty"` + BackImage string `json:"back_image,omitempty"` + URLs []string `json:"urls,omitempty"` + Studio string `json:"studio,omitempty"` + Tags []string `json:"tags,omitempty"` + SubGroups []SubGroupDescription `json:"sub_groups,omitempty"` + CreatedAt json.JSONTime `json:"created_at,omitempty"` + UpdatedAt json.JSONTime `json:"updated_at,omitempty"` // deprecated - for import only URL string `json:"url,omitempty"` diff --git a/pkg/models/mocks/GroupReaderWriter.go b/pkg/models/mocks/GroupReaderWriter.go index 5e3a2644ca8..dc745d09487 100644 --- a/pkg/models/mocks/GroupReaderWriter.go +++ b/pkg/models/mocks/GroupReaderWriter.go @@ -289,6 +289,29 @@ func (_m *GroupReaderWriter) GetBackImage(ctx context.Context, groupID int) ([]b return r0, r1 } +// GetContainingGroupDescriptions provides a mock function with given fields: ctx, id +func (_m *GroupReaderWriter) GetContainingGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) { + ret := _m.Called(ctx, id) + + var r0 []models.GroupIDDescription + if rf, ok := ret.Get(0).(func(context.Context, int) []models.GroupIDDescription); ok { + r0 = rf(ctx, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]models.GroupIDDescription) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetFrontImage provides a mock function with given fields: ctx, groupID func (_m *GroupReaderWriter) GetFrontImage(ctx context.Context, groupID int) ([]byte, error) { ret := _m.Called(ctx, groupID) @@ -312,6 +335,29 @@ func (_m *GroupReaderWriter) GetFrontImage(ctx context.Context, groupID int) ([] return r0, r1 } +// GetSubGroupDescriptions provides a mock function with given fields: ctx, id +func (_m *GroupReaderWriter) GetSubGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) { + ret := _m.Called(ctx, id) + + var r0 []models.GroupIDDescription + if rf, ok := ret.Get(0).(func(context.Context, int) []models.GroupIDDescription); ok { + r0 = rf(ctx, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]models.GroupIDDescription) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetTagIDs provides a mock function with given fields: ctx, relatedID func (_m *GroupReaderWriter) GetTagIDs(ctx context.Context, relatedID int) ([]int, error) { ret := _m.Called(ctx, relatedID) diff --git a/pkg/models/model_group.go b/pkg/models/model_group.go index af3ac56c68c..82c71996ae8 100644 --- a/pkg/models/model_group.go +++ b/pkg/models/model_group.go @@ -21,6 +21,9 @@ type Group struct { URLs RelatedStrings `json:"urls"` TagIDs RelatedIDs `json:"tag_ids"` + + ContainingGroups RelatedGroupDescriptions `json:"containing_groups"` + SubGroups RelatedGroupDescriptions `json:"sub_groups"` } func NewGroup() Group { @@ -43,20 +46,34 @@ func (m *Group) LoadTagIDs(ctx context.Context, l TagIDLoader) error { }) } +func (m *Group) LoadContainingGroupIDs(ctx context.Context, l ContainingGroupLoader) error { + return m.ContainingGroups.load(func() ([]GroupIDDescription, error) { + return l.GetContainingGroupDescriptions(ctx, m.ID) + }) +} + +func (m *Group) LoadSubGroupIDs(ctx context.Context, l SubGroupLoader) error { + return m.SubGroups.load(func() ([]GroupIDDescription, error) { + return l.GetSubGroupDescriptions(ctx, m.ID) + }) +} + type GroupPartial struct { Name OptionalString Aliases OptionalString Duration OptionalInt Date OptionalDate // Rating expressed in 1-100 scale - Rating OptionalInt - StudioID OptionalInt - Director OptionalString - Synopsis OptionalString - URLs *UpdateStrings - TagIDs *UpdateIDs - CreatedAt OptionalTime - UpdatedAt OptionalTime + Rating OptionalInt + StudioID OptionalInt + Director OptionalString + Synopsis OptionalString + URLs *UpdateStrings + TagIDs *UpdateIDs + ContainingGroups *UpdateGroupDescriptions + SubGroups *UpdateGroupDescriptions + CreatedAt OptionalTime + UpdatedAt OptionalTime } func NewGroupPartial() GroupPartial { diff --git a/pkg/models/model_joins.go b/pkg/models/model_joins.go index 189c2d7721f..7b7cae3e46a 100644 --- a/pkg/models/model_joins.go +++ b/pkg/models/model_joins.go @@ -68,3 +68,8 @@ func GroupsScenesFromInput(input []SceneMovieInput) ([]GroupsScenes, error) { return ret, nil } + +type GroupIDDescription struct { + GroupID int `json:"group_id"` + Description string `json:"description"` +} diff --git a/pkg/models/relationships.go b/pkg/models/relationships.go index 81528c26e95..5495f858b17 100644 --- a/pkg/models/relationships.go +++ b/pkg/models/relationships.go @@ -2,6 +2,8 @@ package models import ( "context" + + "github.com/stashapp/stash/pkg/sliceutil" ) type SceneIDLoader interface { @@ -37,6 +39,14 @@ type SceneGroupLoader interface { GetGroups(ctx context.Context, id int) ([]GroupsScenes, error) } +type ContainingGroupLoader interface { + GetContainingGroupDescriptions(ctx context.Context, id int) ([]GroupIDDescription, error) +} + +type SubGroupLoader interface { + GetSubGroupDescriptions(ctx context.Context, id int) ([]GroupIDDescription, error) +} + type StashIDLoader interface { GetStashIDs(ctx context.Context, relatedID int) ([]StashID, error) } @@ -185,6 +195,82 @@ func (r *RelatedGroups) load(fn func() ([]GroupsScenes, error)) error { return nil } +type RelatedGroupDescriptions struct { + list []GroupIDDescription +} + +// NewRelatedGroups returns a loaded RelateGroups object with the provided groups. +// Loaded will return true when called on the returned object if the provided slice is not nil. +func NewRelatedGroupDescriptions(list []GroupIDDescription) RelatedGroupDescriptions { + return RelatedGroupDescriptions{ + list: list, + } +} + +// Loaded returns true if the relationship has been loaded. +func (r RelatedGroupDescriptions) Loaded() bool { + return r.list != nil +} + +func (r RelatedGroupDescriptions) mustLoaded() { + if !r.Loaded() { + panic("list has not been loaded") + } +} + +// List returns the related Groups. Panics if the relationship has not been loaded. +func (r RelatedGroupDescriptions) List() []GroupIDDescription { + r.mustLoaded() + + return r.list +} + +// List returns the related Groups. Panics if the relationship has not been loaded. +func (r RelatedGroupDescriptions) IDs() []int { + r.mustLoaded() + + return sliceutil.Map(r.list, func(d GroupIDDescription) int { return d.GroupID }) +} + +// Add adds the provided ids to the list. Panics if the relationship has not been loaded. +func (r *RelatedGroupDescriptions) Add(groups ...GroupIDDescription) { + r.mustLoaded() + + r.list = append(r.list, groups...) +} + +// ForID returns the GroupsScenes object for the given group ID. Returns nil if not found. +func (r *RelatedGroupDescriptions) ForID(id int) *GroupIDDescription { + r.mustLoaded() + + for _, v := range r.list { + if v.GroupID == id { + return &v + } + } + + return nil +} + +func (r *RelatedGroupDescriptions) load(fn func() ([]GroupIDDescription, error)) error { + if r.Loaded() { + return nil + } + + ids, err := fn() + if err != nil { + return err + } + + if ids == nil { + ids = []GroupIDDescription{} + } + + r.list = ids + + return nil +} + type RelatedStashIDs struct { list []StashID } diff --git a/pkg/models/repository_group.go b/pkg/models/repository_group.go index 0396049b66e..704390d77b3 100644 --- a/pkg/models/repository_group.go +++ b/pkg/models/repository_group.go @@ -66,6 +66,8 @@ type GroupReader interface { GroupCounter URLLoader TagIDLoader + ContainingGroupLoader + SubGroupLoader All(ctx context.Context) ([]*Group, error) GetFrontImage(ctx context.Context, groupID int) ([]byte, error) diff --git a/pkg/models/repository_scene.go b/pkg/models/repository_scene.go index 60783fff5cd..e28347c5b82 100644 --- a/pkg/models/repository_scene.go +++ b/pkg/models/repository_scene.go @@ -37,10 +37,7 @@ type SceneQueryer interface { type SceneCounter interface { Count(ctx context.Context) (int, error) CountByPerformerID(ctx context.Context, performerID int) (int, error) - CountByGroupID(ctx context.Context, groupID int) (int, error) CountByFileID(ctx context.Context, fileID FileID) (int, error) - CountByStudioID(ctx context.Context, studioID int) (int, error) - CountByTagID(ctx context.Context, tagID int) (int, error) CountMissingChecksum(ctx context.Context) (int, error) CountMissingOSHash(ctx context.Context) (int, error) OCountByPerformerID(ctx context.Context, performerID int) (int, error) diff --git a/pkg/models/scene.go b/pkg/models/scene.go index 814c4a41d62..48317240276 100644 --- a/pkg/models/scene.go +++ b/pkg/models/scene.go @@ -56,7 +56,7 @@ type SceneFilterType struct { // Filter to only include scenes with this studio Studios *HierarchicalMultiCriterionInput `json:"studios"` // Filter to only include scenes with this group - Groups *MultiCriterionInput `json:"groups"` + Groups *HierarchicalMultiCriterionInput `json:"groups"` // Filter to only include scenes with this movie Movies *MultiCriterionInput `json:"movies"` // Filter to only include scenes with this gallery diff --git a/pkg/models/update.go b/pkg/models/update.go index 2302a2e699a..6aaff8c317f 100644 --- a/pkg/models/update.go +++ b/pkg/models/update.go @@ -133,3 +133,68 @@ func applyUpdate[T comparable](values []T, mode RelationshipUpdateMode, existing return nil } + +type UpdateGroupDescriptions struct { + Groups []GroupIDDescription `json:"groups"` + Mode RelationshipUpdateMode `json:"mode"` +} + +// Apply applies the update to a list of existing ids, returning the result. +func (u *UpdateGroupDescriptions) Apply(existing []GroupIDDescription) []GroupIDDescription { + if u == nil { + return existing + } + + switch u.Mode { + case RelationshipUpdateModeAdd: + return u.applyAdd(existing) + case RelationshipUpdateModeRemove: + return u.applyRemove(existing) + case RelationshipUpdateModeSet: + return u.Groups + } + + return nil +} + +func (u *UpdateGroupDescriptions) applyAdd(existing []GroupIDDescription) []GroupIDDescription { + // overwrite any existing values with the same id + ret := append([]GroupIDDescription{}, existing...) + for _, v := range u.Groups { + found := false + for i, vv := range ret { + if vv.GroupID == v.GroupID { + ret[i] = v + found = true + break + } + } + + if !found { + ret = append(ret, v) + } + } + + return ret +} + +func (u *UpdateGroupDescriptions) applyRemove(existing []GroupIDDescription) []GroupIDDescription { + // remove any existing values with the same id + var ret []GroupIDDescription + for _, v := range existing { + found := false + for _, vv := range u.Groups { + if vv.GroupID == v.GroupID { + found = true + break + } + } + + // if not found in the remove list, keep it + if !found { + ret = append(ret, v) + } + } + + return ret +} diff --git a/pkg/scene/query.go b/pkg/scene/query.go index a8b1993a6a0..c640266f9ef 100644 --- a/pkg/scene/query.go +++ b/pkg/scene/query.go @@ -144,3 +144,15 @@ func CountByTagID(ctx context.Context, r models.SceneQueryer, id int, depth *int return r.QueryCount(ctx, filter, nil) } + +func CountByGroupID(ctx context.Context, r models.SceneQueryer, id int, depth *int) (int, error) { + filter := &models.SceneFilterType{ + Groups: &models.HierarchicalMultiCriterionInput{ + Value: []string{strconv.Itoa(id)}, + Modifier: models.CriterionModifierIncludes, + Depth: depth, + }, + } + + return r.QueryCount(ctx, filter, nil) +} diff --git a/pkg/sliceutil/collections.go b/pkg/sliceutil/collections.go index bd4070cdc94..18930df259e 100644 --- a/pkg/sliceutil/collections.go +++ b/pkg/sliceutil/collections.go @@ -146,7 +146,7 @@ func Filter[T any](vs []T, f func(T) bool) []T { return ret } -// Filter returns the result of applying f to each element of the vs slice. +// Map returns the result of applying f to each element of the vs slice. func Map[T any, V any](vs []T, f func(T) V) []V { ret := make([]V, len(vs)) for i, v := range vs { diff --git a/pkg/sqlite/database.go b/pkg/sqlite/database.go index ee5e5399bf9..7dd4771d33f 100644 --- a/pkg/sqlite/database.go +++ b/pkg/sqlite/database.go @@ -30,7 +30,7 @@ const ( dbConnTimeout = 30 ) -var appSchemaVersion uint = 66 +var appSchemaVersion uint = 67 //go:embed migrations/*.sql var migrationsBox embed.FS diff --git a/pkg/sqlite/filter_hierarchical.go b/pkg/sqlite/filter_hierarchical.go new file mode 100644 index 00000000000..bc5ff9032b3 --- /dev/null +++ b/pkg/sqlite/filter_hierarchical.go @@ -0,0 +1,222 @@ +package sqlite + +import ( + "context" + "fmt" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/utils" +) + +// hierarchicalRelationshipHandler provides handlers for parent, children, parent count, and child count criteria. +type hierarchicalRelationshipHandler struct { + primaryTable string + relationTable string + aliasPrefix string + parentIDCol string + childIDCol string +} + +func (h hierarchicalRelationshipHandler) validateModifier(m models.CriterionModifier) error { + switch m { + case models.CriterionModifierIncludesAll, models.CriterionModifierIncludes, models.CriterionModifierExcludes, models.CriterionModifierIsNull, models.CriterionModifierNotNull: + // valid + return nil + default: + return fmt.Errorf("invalid modifier %s", m) + } +} + +func (h hierarchicalRelationshipHandler) handleNullNotNull(f *filterBuilder, m models.CriterionModifier, isParents bool) { + var notClause string + if m == models.CriterionModifierNotNull { + notClause = "NOT" + } + + as := h.aliasPrefix + "_parents" + col := h.childIDCol + if !isParents { + as = h.aliasPrefix + "_children" + col = h.parentIDCol + } + + // Based on: + // f.addLeftJoin("tags_relations", "parent_relations", "tags.id = parent_relations.child_id") + // f.addWhere(fmt.Sprintf("parent_relations.parent_id IS %s NULL", notClause)) + + f.addLeftJoin(h.relationTable, as, fmt.Sprintf("%s.id = %s.%s", h.primaryTable, as, col)) + f.addWhere(fmt.Sprintf("%s.%s IS %s NULL", as, col, notClause)) +} + +func (h hierarchicalRelationshipHandler) parentsAlias() string { + return h.aliasPrefix + "_parents" +} + +func (h hierarchicalRelationshipHandler) childrenAlias() string { + return h.aliasPrefix + "_children" +} + +func (h hierarchicalRelationshipHandler) valueQuery(value []string, depth int, alias string, isParents bool) string { + var depthCondition string + if depth != -1 { + depthCondition = fmt.Sprintf("WHERE depth < %d", depth) + } + + queryTempl := `{alias} AS ( +SELECT {root_id_col} AS root_id, {item_id_col} AS item_id, 0 AS depth FROM {relation_table} WHERE {root_id_col} IN` + getInBinding(len(value)) + ` +UNION +SELECT root_id, {item_id_col}, depth + 1 FROM {relation_table} INNER JOIN {alias} ON item_id = {root_id_col} ` + depthCondition + ` +)` + + var queryMap utils.StrFormatMap + if isParents { + queryMap = utils.StrFormatMap{ + "root_id_col": h.parentIDCol, + "item_id_col": h.childIDCol, + } + } else { + queryMap = utils.StrFormatMap{ + "root_id_col": h.childIDCol, + "item_id_col": h.parentIDCol, + } + } + + queryMap["alias"] = alias + queryMap["relation_table"] = h.relationTable + + return utils.StrFormat(queryTempl, queryMap) +} + +func (h hierarchicalRelationshipHandler) handleValues(f *filterBuilder, c models.HierarchicalMultiCriterionInput, isParents bool, aliasSuffix string) { + if len(c.Value) == 0 { + return + } + + var args []interface{} + for _, val := range c.Value { + args = append(args, val) + } + + depthVal := 0 + if c.Depth != nil { + depthVal = *c.Depth + } + + tableAlias := h.parentsAlias() + if !isParents { + tableAlias = h.childrenAlias() + } + tableAlias += aliasSuffix + + query := h.valueQuery(c.Value, depthVal, tableAlias, isParents) + f.addRecursiveWith(query, args...) + + f.addLeftJoin(tableAlias, "", fmt.Sprintf("%s.item_id = %s.id", tableAlias, h.primaryTable)) + addHierarchicalConditionClauses(f, c, tableAlias, "root_id") +} + +func (h hierarchicalRelationshipHandler) handleValuesSimple(f *filterBuilder, value string, isParents bool) { + joinCol := h.childIDCol + valueCol := h.parentIDCol + if !isParents { + joinCol = h.parentIDCol + valueCol = h.childIDCol + } + + tableAlias := h.parentsAlias() + if !isParents { + tableAlias = h.childrenAlias() + } + + f.addInnerJoin(h.relationTable, tableAlias, fmt.Sprintf("%s.%s = %s.id", tableAlias, joinCol, h.primaryTable)) + f.addWhere(fmt.Sprintf("%s.%s = ?", tableAlias, valueCol), value) +} + +func (h hierarchicalRelationshipHandler) hierarchicalCriterionHandler(criterion *models.HierarchicalMultiCriterionInput, isParents bool) criterionHandlerFunc { + return func(ctx context.Context, f *filterBuilder) { + if criterion != nil { + c := criterion.CombineExcludes() + + // validate the modifier + if err := h.validateModifier(c.Modifier); err != nil { + f.setError(err) + return + } + + if c.Modifier == models.CriterionModifierIsNull || c.Modifier == models.CriterionModifierNotNull { + h.handleNullNotNull(f, c.Modifier, isParents) + return + } + + if len(c.Value) == 0 && len(c.Excludes) == 0 { + return + } + + depth := 0 + if c.Depth != nil { + depth = *c.Depth + } + + // if we have a single include, no excludes, and no depth, we can use a simple join and where clause + if (c.Modifier == models.CriterionModifierIncludes || c.Modifier == models.CriterionModifierIncludesAll) && len(c.Value) == 1 && len(c.Excludes) == 0 && depth == 0 { + h.handleValuesSimple(f, c.Value[0], isParents) + return + } + + aliasSuffix := "" + h.handleValues(f, c, isParents, aliasSuffix) + + if len(c.Excludes) > 0 { + exCriterion := models.HierarchicalMultiCriterionInput{ + Value: c.Excludes, + Depth: c.Depth, + Modifier: models.CriterionModifierExcludes, + } + + aliasSuffix := "2" + h.handleValues(f, exCriterion, isParents, aliasSuffix) + } + } + } +} + +func (h hierarchicalRelationshipHandler) ParentsCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + const isParents = true + return h.hierarchicalCriterionHandler(criterion, isParents) +} + +func (h hierarchicalRelationshipHandler) ChildrenCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + const isParents = false + return h.hierarchicalCriterionHandler(criterion, isParents) +} + +func (h hierarchicalRelationshipHandler) countCriterionHandler(c *models.IntCriterionInput, isParents bool) criterionHandlerFunc { + tableAlias := h.parentsAlias() + col := h.childIDCol + otherCol := h.parentIDCol + if !isParents { + tableAlias = h.childrenAlias() + col = h.parentIDCol + otherCol = h.childIDCol + } + tableAlias += "_count" + + return func(ctx context.Context, f *filterBuilder) { + if c != nil { + f.addLeftJoin(h.relationTable, tableAlias, fmt.Sprintf("%s.%s = %s.id", tableAlias, col, h.primaryTable)) + clause, args := getIntCriterionWhereClause(fmt.Sprintf("count(distinct %s.%s)", tableAlias, otherCol), *c) + + f.addHaving(clause, args...) + } + } +} + +func (h hierarchicalRelationshipHandler) ParentCountCriterionHandler(parentCount *models.IntCriterionInput) criterionHandlerFunc { + const isParents = true + return h.countCriterionHandler(parentCount, isParents) +} + +func (h hierarchicalRelationshipHandler) ChildCountCriterionHandler(childCount *models.IntCriterionInput) criterionHandlerFunc { + const isParents = false + return h.countCriterionHandler(childCount, isParents) +} diff --git a/pkg/sqlite/group.go b/pkg/sqlite/group.go index 21c224242f8..603494fe71a 100644 --- a/pkg/sqlite/group.go +++ b/pkg/sqlite/group.go @@ -27,6 +27,8 @@ const ( groupURLsTable = "group_urls" groupURLColumn = "url" + + groupRelationsTable = "groups_relations" ) type groupRow struct { @@ -128,6 +130,7 @@ var ( type GroupStore struct { blobJoinQueryBuilder tagRelationshipStore + groupRelationshipStore tableMgr *table } @@ -143,6 +146,9 @@ func NewGroupStore(blobStore *BlobStore) *GroupStore { joinTable: groupsTagsTableMgr, }, }, + groupRelationshipStore: groupRelationshipStore{ + table: groupRelationshipTableMgr, + }, tableMgr: groupTableMgr, } @@ -176,6 +182,14 @@ func (qb *GroupStore) Create(ctx context.Context, newObject *models.Group) error return err } + if err := qb.groupRelationshipStore.createContainingRelationships(ctx, id, newObject.ContainingGroups); err != nil { + return err + } + + if err := qb.groupRelationshipStore.createSubRelationships(ctx, id, newObject.SubGroups); err != nil { + return err + } + updated, err := qb.find(ctx, id) if err != nil { return fmt.Errorf("finding after create: %w", err) @@ -211,6 +225,14 @@ func (qb *GroupStore) UpdatePartial(ctx context.Context, id int, partial models. return nil, err } + if err := qb.groupRelationshipStore.modifyContainingRelationships(ctx, id, partial.ContainingGroups); err != nil { + return nil, err + } + + if err := qb.groupRelationshipStore.modifySubRelationships(ctx, id, partial.SubGroups); err != nil { + return nil, err + } + return qb.find(ctx, id) } @@ -232,6 +254,14 @@ func (qb *GroupStore) Update(ctx context.Context, updatedObject *models.Group) e return err } + if err := qb.groupRelationshipStore.replaceContainingRelationships(ctx, updatedObject.ID, updatedObject.ContainingGroups); err != nil { + return err + } + + if err := qb.groupRelationshipStore.replaceSubRelationships(ctx, updatedObject.ID, updatedObject.SubGroups); err != nil { + return err + } + return nil } @@ -412,9 +442,7 @@ func (qb *GroupStore) makeQuery(ctx context.Context, groupFilter *models.GroupFi return nil, err } - var err error - query.sortAndPagination, err = qb.getGroupSort(findFilter) - if err != nil { + if err := qb.setGroupSort(&query, findFilter); err != nil { return nil, err } @@ -460,11 +488,12 @@ var groupSortOptions = sortOptions{ "random", "rating", "scenes_count", + "sub_group_order", "tag_count", "updated_at", } -func (qb *GroupStore) getGroupSort(findFilter *models.FindFilterType) (string, error) { +func (qb *GroupStore) setGroupSort(query *queryBuilder, findFilter *models.FindFilterType) error { var sort string var direction string if findFilter == nil { @@ -477,22 +506,31 @@ func (qb *GroupStore) getGroupSort(findFilter *models.FindFilterType) (string, e // CVE-2024-32231 - ensure sort is in the list of allowed sorts if err := groupSortOptions.validateSort(sort); err != nil { - return "", err + return err } - sortQuery := "" switch sort { + case "sub_group_order": + // sub_group_order is a special sort that sorts by the order_index of the subgroups + if query.hasJoin("groups_parents") { + query.sortAndPagination += getSort("order_index", direction, "groups_parents") + } else { + // this will give unexpected results if the query is not filtered by a parent group and + // the group has multiple parents and order indexes + query.join(groupRelationsTable, "", "groups.id = groups_relations.sub_id") + query.sortAndPagination += getSort("order_index", direction, groupRelationsTable) + } case "tag_count": - sortQuery += getCountSort(groupTable, groupsTagsTable, groupIDColumn, direction) + query.sortAndPagination += getCountSort(groupTable, groupsTagsTable, groupIDColumn, direction) case "scenes_count": // generic getSort won't work for this - sortQuery += getCountSort(groupTable, groupsScenesTable, groupIDColumn, direction) + query.sortAndPagination += getCountSort(groupTable, groupsScenesTable, groupIDColumn, direction) default: - sortQuery += getSort(sort, direction, "groups") + query.sortAndPagination += getSort(sort, direction, "groups") } // Whatever the sorting, always use name/id as a final sort - sortQuery += ", COALESCE(groups.name, groups.id) COLLATE NATURAL_CI ASC" - return sortQuery, nil + query.sortAndPagination += ", COALESCE(groups.name, groups.id) COLLATE NATURAL_CI ASC" + return nil } func (qb *GroupStore) queryGroups(ctx context.Context, query string, args []interface{}) ([]*models.Group, error) { @@ -592,3 +630,74 @@ WHERE groups.studio_id = ? func (qb *GroupStore) GetURLs(ctx context.Context, groupID int) ([]string, error) { return groupsURLsTableMgr.get(ctx, groupID) } + +// FindSubGroupIDs returns a list of group IDs where a group in the ids list is a sub-group of the parent group +func (qb *GroupStore) FindSubGroupIDs(ctx context.Context, containingID int, ids []int) ([]int, error) { + /* + SELECT gr.sub_id FROM groups_relations gr + WHERE gr.containing_id = :parentID AND gr.sub_id IN (:ids); + */ + table := groupRelationshipTableMgr.table + q := dialect.From(table).Prepared(true). + Select(table.Col("sub_id")).Where( + table.Col("containing_id").Eq(containingID), + table.Col("sub_id").In(ids), + ) + + const single = false + var ret []int + if err := queryFunc(ctx, q, single, func(r *sqlx.Rows) error { + var id int + if err := r.Scan(&id); err != nil { + return err + } + + ret = append(ret, id) + return nil + }); err != nil { + return nil, err + } + + return ret, nil +} + +// FindInAscestors returns a list of group IDs where a group in the ids list is an ascestor of the ancestor group IDs +func (qb *GroupStore) FindInAncestors(ctx context.Context, ascestorIDs []int, ids []int) ([]int, error) { + /* + WITH RECURSIVE ascestors AS ( + SELECT g.id AS parent_id FROM groups g WHERE g.id IN (:ascestorIDs) + UNION + SELECT gr.containing_id FROM groups_relations gr INNER JOIN ascestors a ON a.parent_id = gr.sub_id + ) + SELECT p.parent_id FROM ascestors p WHERE p.parent_id IN (:ids); + */ + table := qb.table() + const ascestors = "ancestors" + const parentID = "parent_id" + q := dialect.From(ascestors).Prepared(true). + WithRecursive(ascestors, + dialect.From(qb.table()).Select(table.Col(idColumn).As(parentID)). + Where(table.Col(idColumn).In(ascestorIDs)). + Union( + dialect.From(groupRelationsJoinTable).InnerJoin( + goqu.I(ascestors), goqu.On(goqu.I("parent_id").Eq(goqu.I("sub_id"))), + ).Select("containing_id"), + ), + ).Select(parentID).Where(goqu.I(parentID).In(ids)) + + const single = false + var ret []int + if err := queryFunc(ctx, q, single, func(r *sqlx.Rows) error { + var id int + if err := r.Scan(&id); err != nil { + return err + } + + ret = append(ret, id) + return nil + }); err != nil { + return nil, err + } + + return ret, nil +} diff --git a/pkg/sqlite/group_filter.go b/pkg/sqlite/group_filter.go index 97bde1f2474..dcb7bcdfc94 100644 --- a/pkg/sqlite/group_filter.go +++ b/pkg/sqlite/group_filter.go @@ -51,6 +51,14 @@ func (qb *groupFilterHandler) handle(ctx context.Context, f *filterBuilder) { f.handleCriterion(ctx, qb.criterionHandler()) } +var groupHierarchyHandler = hierarchicalRelationshipHandler{ + primaryTable: groupTable, + relationTable: groupRelationsTable, + aliasPrefix: groupTable, + parentIDCol: "containing_id", + childIDCol: "sub_id", +} + func (qb *groupFilterHandler) criterionHandler() criterionHandler { groupFilter := qb.groupFilter return compoundHandler{ @@ -66,6 +74,10 @@ func (qb *groupFilterHandler) criterionHandler() criterionHandler { qb.tagsCriterionHandler(groupFilter.Tags), qb.tagCountCriterionHandler(groupFilter.TagCount), &dateCriterionHandler{groupFilter.Date, "groups.date", nil}, + groupHierarchyHandler.ParentsCriterionHandler(groupFilter.ContainingGroups), + groupHierarchyHandler.ChildrenCriterionHandler(groupFilter.SubGroups), + groupHierarchyHandler.ParentCountCriterionHandler(groupFilter.ContainingGroupCount), + groupHierarchyHandler.ChildCountCriterionHandler(groupFilter.SubGroupCount), ×tampCriterionHandler{groupFilter.CreatedAt, "groups.created_at", nil}, ×tampCriterionHandler{groupFilter.UpdatedAt, "groups.updated_at", nil}, diff --git a/pkg/sqlite/group_relationships.go b/pkg/sqlite/group_relationships.go new file mode 100644 index 00000000000..fe94394f905 --- /dev/null +++ b/pkg/sqlite/group_relationships.go @@ -0,0 +1,457 @@ +package sqlite + +import ( + "context" + "fmt" + + "github.com/doug-martin/goqu/v9" + "github.com/doug-martin/goqu/v9/exp" + "github.com/jmoiron/sqlx" + "github.com/stashapp/stash/pkg/models" + "gopkg.in/guregu/null.v4" + "gopkg.in/guregu/null.v4/zero" +) + +type groupRelationshipRow struct { + ContainingID int `db:"containing_id"` + SubID int `db:"sub_id"` + OrderIndex int `db:"order_index"` + Description zero.String `db:"description"` +} + +func (r groupRelationshipRow) resolve(useContainingID bool) models.GroupIDDescription { + id := r.ContainingID + if !useContainingID { + id = r.SubID + } + + return models.GroupIDDescription{ + GroupID: id, + Description: r.Description.String, + } +} + +type groupRelationshipStore struct { + table *table +} + +func (s *groupRelationshipStore) GetContainingGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) { + const idIsContaining = false + return s.getGroupRelationships(ctx, id, idIsContaining) +} + +func (s *groupRelationshipStore) GetSubGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) { + const idIsContaining = true + return s.getGroupRelationships(ctx, id, idIsContaining) +} + +func (s *groupRelationshipStore) getGroupRelationships(ctx context.Context, id int, idIsContaining bool) ([]models.GroupIDDescription, error) { + col := "containing_id" + if !idIsContaining { + col = "sub_id" + } + + table := s.table.table + q := dialect.Select(table.All()). + From(table). + Where(table.Col(col).Eq(id)). + Order(table.Col("order_index").Asc()) + + const single = false + var ret []models.GroupIDDescription + if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error { + var row groupRelationshipRow + if err := rows.StructScan(&row); err != nil { + return err + } + + ret = append(ret, row.resolve(!idIsContaining)) + + return nil + }); err != nil { + return nil, fmt.Errorf("getting group relationships from %s: %w", table.GetTable(), err) + } + + return ret, nil +} + +// getMaxOrderIndex gets the maximum order index for the containing group with the given id +func (s *groupRelationshipStore) getMaxOrderIndex(ctx context.Context, containingID int) (int, error) { + idColumn := s.table.table.Col("containing_id") + + q := dialect.Select(goqu.MAX("order_index")). + From(s.table.table). + Where(idColumn.Eq(containingID)) + + var maxOrderIndex zero.Int + if err := querySimple(ctx, q, &maxOrderIndex); err != nil { + return 0, fmt.Errorf("getting max order index: %w", err) + } + + return int(maxOrderIndex.Int64), nil +} + +// createRelationships creates relationships between a group and other groups. +// If idIsContaining is true, the provided id is the containing group. +func (s *groupRelationshipStore) createRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions, idIsContaining bool) error { + if d.Loaded() { + for i, v := range d.List() { + orderIndex := i + 1 + + r := groupRelationshipRow{ + ContainingID: id, + SubID: v.GroupID, + OrderIndex: orderIndex, + Description: zero.StringFrom(v.Description), + } + + if !idIsContaining { + // get the max order index of the containing groups sub groups + containingID := v.GroupID + maxOrderIndex, err := s.getMaxOrderIndex(ctx, containingID) + if err != nil { + return err + } + + r.ContainingID = v.GroupID + r.SubID = id + r.OrderIndex = maxOrderIndex + 1 + } + + _, err := s.table.insert(ctx, r) + if err != nil { + return fmt.Errorf("inserting into %s: %w", s.table.table.GetTable(), err) + } + } + + return nil + } + + return nil +} + +// createRelationships creates relationships between a group and other groups. +// If idIsContaining is true, the provided id is the containing group. +func (s *groupRelationshipStore) createContainingRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error { + const idIsContaining = false + return s.createRelationships(ctx, id, d, idIsContaining) +} + +// createRelationships creates relationships between a group and other groups. +// If idIsContaining is true, the provided id is the containing group. +func (s *groupRelationshipStore) createSubRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error { + const idIsContaining = true + return s.createRelationships(ctx, id, d, idIsContaining) +} + +func (s *groupRelationshipStore) replaceRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions, idIsContaining bool) error { + // always destroy the existing relationships even if the new list is empty + if err := s.destroyAllJoins(ctx, id, idIsContaining); err != nil { + return err + } + + return s.createRelationships(ctx, id, d, idIsContaining) +} + +func (s *groupRelationshipStore) replaceContainingRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error { + const idIsContaining = false + return s.replaceRelationships(ctx, id, d, idIsContaining) +} + +func (s *groupRelationshipStore) replaceSubRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error { + const idIsContaining = true + return s.replaceRelationships(ctx, id, d, idIsContaining) +} + +func (s *groupRelationshipStore) modifyRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions, idIsContaining bool) error { + if v == nil { + return nil + } + + switch v.Mode { + case models.RelationshipUpdateModeSet: + return s.replaceJoins(ctx, id, *v, idIsContaining) + case models.RelationshipUpdateModeAdd: + return s.addJoins(ctx, id, v.Groups, idIsContaining) + case models.RelationshipUpdateModeRemove: + toRemove := make([]int, len(v.Groups)) + for i, vv := range v.Groups { + toRemove[i] = vv.GroupID + } + return s.destroyJoins(ctx, id, toRemove, idIsContaining) + } + + return nil +} + +func (s *groupRelationshipStore) modifyContainingRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions) error { + const idIsContaining = false + return s.modifyRelationships(ctx, id, v, idIsContaining) +} + +func (s *groupRelationshipStore) modifySubRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions) error { + const idIsContaining = true + return s.modifyRelationships(ctx, id, v, idIsContaining) +} + +func (s *groupRelationshipStore) addJoins(ctx context.Context, id int, groups []models.GroupIDDescription, idIsContaining bool) error { + // if we're adding to a containing group, get the max order index first + var maxOrderIndex int + if idIsContaining { + var err error + maxOrderIndex, err = s.getMaxOrderIndex(ctx, id) + if err != nil { + return err + } + } + + for i, vv := range groups { + r := groupRelationshipRow{ + Description: zero.StringFrom(vv.Description), + } + + if idIsContaining { + r.ContainingID = id + r.SubID = vv.GroupID + r.OrderIndex = maxOrderIndex + (i + 1) + } else { + // get the max order index of the containing groups sub groups + containingMaxOrderIndex, err := s.getMaxOrderIndex(ctx, vv.GroupID) + if err != nil { + return err + } + + r.ContainingID = vv.GroupID + r.SubID = id + r.OrderIndex = containingMaxOrderIndex + 1 + } + + _, err := s.table.insert(ctx, r) + if err != nil { + return fmt.Errorf("inserting into %s: %w", s.table.table.GetTable(), err) + } + } + + return nil +} + +func (s *groupRelationshipStore) destroyAllJoins(ctx context.Context, id int, idIsContaining bool) error { + table := s.table.table + idColumn := table.Col("containing_id") + if !idIsContaining { + idColumn = table.Col("sub_id") + } + + q := dialect.Delete(table).Where(idColumn.Eq(id)) + + if _, err := exec(ctx, q); err != nil { + return fmt.Errorf("destroying %s: %w", table.GetTable(), err) + } + + return nil +} + +func (s *groupRelationshipStore) replaceJoins(ctx context.Context, id int, v models.UpdateGroupDescriptions, idIsContaining bool) error { + if err := s.destroyAllJoins(ctx, id, idIsContaining); err != nil { + return err + } + + // convert to RelatedGroupDescriptions + rgd := models.NewRelatedGroupDescriptions(v.Groups) + return s.createRelationships(ctx, id, rgd, idIsContaining) +} + +func (s *groupRelationshipStore) destroyJoins(ctx context.Context, id int, toRemove []int, idIsContaining bool) error { + table := s.table.table + idColumn := table.Col("containing_id") + fkColumn := table.Col("sub_id") + if !idIsContaining { + idColumn = table.Col("sub_id") + fkColumn = table.Col("containing_id") + } + + q := dialect.Delete(table).Where(idColumn.Eq(id), fkColumn.In(toRemove)) + + if _, err := exec(ctx, q); err != nil { + return fmt.Errorf("destroying %s: %w", table.GetTable(), err) + } + + return nil +} + +func (s *groupRelationshipStore) getOrderIndexOfSubGroup(ctx context.Context, containingGroupID int, subGroupID int) (int, error) { + table := s.table.table + q := dialect.Select("order_index"). + From(table). + Where( + table.Col("containing_id").Eq(containingGroupID), + table.Col("sub_id").Eq(subGroupID), + ) + + var orderIndex null.Int + if err := querySimple(ctx, q, &orderIndex); err != nil { + return 0, fmt.Errorf("getting order index: %w", err) + } + + if !orderIndex.Valid { + return 0, fmt.Errorf("sub-group %d not found in containing group %d", subGroupID, containingGroupID) + } + + return int(orderIndex.Int64), nil +} + +func (s *groupRelationshipStore) getGroupIDAtOrderIndex(ctx context.Context, containingGroupID int, orderIndex int) (*int, error) { + table := s.table.table + q := dialect.Select(table.Col("sub_id")).From(table).Where( + table.Col("containing_id").Eq(containingGroupID), + table.Col("order_index").Eq(orderIndex), + ) + + var ret null.Int + if err := querySimple(ctx, q, &ret); err != nil { + return nil, fmt.Errorf("getting sub id for order index: %w", err) + } + + if !ret.Valid { + return nil, nil + } + + intRet := int(ret.Int64) + return &intRet, nil +} + +func (s *groupRelationshipStore) getOrderIndexAfterOrderIndex(ctx context.Context, containingGroupID int, orderIndex int) (int, error) { + table := s.table.table + q := dialect.Select(goqu.MIN("order_index")).From(table).Where( + table.Col("containing_id").Eq(containingGroupID), + table.Col("order_index").Gt(orderIndex), + ) + + var ret null.Int + if err := querySimple(ctx, q, &ret); err != nil { + return 0, fmt.Errorf("getting order index: %w", err) + } + + if !ret.Valid { + return orderIndex + 1, nil + } + + return int(ret.Int64), nil +} + +// incrementOrderIndexes increments the order_index value of all sub-groups in the containing group at or after the given index +func (s *groupRelationshipStore) incrementOrderIndexes(ctx context.Context, groupID int, indexBefore int) error { + table := s.table.table + + // WORKAROUND - sqlite won't allow incrementing the value directly since it causes a + // unique constraint violation. + // Instead, we first set the order index to a negative value temporarily + // see https://stackoverflow.com/a/7703239/695786 + q := dialect.Update(table).Set(exp.Record{ + "order_index": goqu.L("-order_index"), + }).Where( + table.Col("containing_id").Eq(groupID), + table.Col("order_index").Gte(indexBefore), + ) + + if _, err := exec(ctx, q); err != nil { + return fmt.Errorf("updating %s: %w", table.GetTable(), err) + } + + q = dialect.Update(table).Set(exp.Record{ + "order_index": goqu.L("1-order_index"), + }).Where( + table.Col("containing_id").Eq(groupID), + table.Col("order_index").Lt(0), + ) + + if _, err := exec(ctx, q); err != nil { + return fmt.Errorf("updating %s: %w", table.GetTable(), err) + } + + return nil +} + +func (s *groupRelationshipStore) reorderSubGroup(ctx context.Context, groupID int, subGroupID int, insertPointID int, insertAfter bool) error { + insertPointIndex, err := s.getOrderIndexOfSubGroup(ctx, groupID, insertPointID) + if err != nil { + return err + } + + // if we're setting before + if insertAfter { + insertPointIndex, err = s.getOrderIndexAfterOrderIndex(ctx, groupID, insertPointIndex) + if err != nil { + return err + } + } + + // increment the order index of all sub-groups after and including the insertion point + if err := s.incrementOrderIndexes(ctx, groupID, int(insertPointIndex)); err != nil { + return err + } + + // set the order index of the sub-group to the insertion point + table := s.table.table + q := dialect.Update(table).Set(exp.Record{ + "order_index": insertPointIndex, + }).Where( + table.Col("containing_id").Eq(groupID), + table.Col("sub_id").Eq(subGroupID), + ) + + if _, err := exec(ctx, q); err != nil { + return fmt.Errorf("updating %s: %w", table.GetTable(), err) + } + + return nil +} + +func (s *groupRelationshipStore) AddSubGroups(ctx context.Context, groupID int, subGroups []models.GroupIDDescription, insertIndex *int) error { + const idIsContaining = true + + if err := s.addJoins(ctx, groupID, subGroups, idIsContaining); err != nil { + return err + } + + ids := make([]int, len(subGroups)) + for i, v := range subGroups { + ids[i] = v.GroupID + } + + if insertIndex != nil { + // get the id of the sub-group at the insert index + insertPointID, err := s.getGroupIDAtOrderIndex(ctx, groupID, *insertIndex) + if err != nil { + return err + } + + if insertPointID == nil { + // if the insert index is out of bounds, just assume adding to the end + return nil + } + + // reorder the sub-groups + const insertAfter = false + if err := s.ReorderSubGroups(ctx, groupID, ids, *insertPointID, insertAfter); err != nil { + return err + } + } + + return nil +} + +func (s *groupRelationshipStore) RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error { + const idIsContaining = true + return s.destroyJoins(ctx, groupID, subGroupIDs, idIsContaining) +} + +func (s *groupRelationshipStore) ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertPointID int, insertAfter bool) error { + for _, id := range subGroupIDs { + if err := s.reorderSubGroup(ctx, groupID, id, insertPointID, insertAfter); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/sqlite/group_test.go b/pkg/sqlite/group_test.go index 45171337970..1d3637c8611 100644 --- a/pkg/sqlite/group_test.go +++ b/pkg/sqlite/group_test.go @@ -14,6 +14,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sliceutil" + "github.com/stashapp/stash/pkg/sliceutil/intslice" ) func loadGroupRelationships(ctx context.Context, expected models.Group, actual *models.Group) error { @@ -27,22 +29,34 @@ func loadGroupRelationships(ctx context.Context, expected models.Group, actual * return err } } + if expected.ContainingGroups.Loaded() { + if err := actual.LoadContainingGroupIDs(ctx, db.Group); err != nil { + return err + } + } + if expected.SubGroups.Loaded() { + if err := actual.LoadSubGroupIDs(ctx, db.Group); err != nil { + return err + } + } return nil } func Test_GroupStore_Create(t *testing.T) { var ( - name = "name" - url = "url" - aliases = "alias1, alias2" - director = "director" - rating = 60 - duration = 34 - synopsis = "synopsis" - date, _ = models.ParseDate("2003-02-01") - createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) - updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + name = "name" + url = "url" + aliases = "alias1, alias2" + director = "director" + rating = 60 + duration = 34 + synopsis = "synopsis" + date, _ = models.ParseDate("2003-02-01") + containingGroupDescription = "containingGroupDescription" + subGroupDescription = "subGroupDescription" + createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) ) tests := []struct { @@ -53,15 +67,21 @@ func Test_GroupStore_Create(t *testing.T) { { "full", models.Group{ - Name: name, - Duration: &duration, - Date: &date, - Rating: &rating, - StudioID: &studioIDs[studioIdxWithGroup], - Director: director, - Synopsis: synopsis, - URLs: models.NewRelatedStrings([]string{url}), - TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithGroup]}), + Name: name, + Duration: &duration, + Date: &date, + Rating: &rating, + StudioID: &studioIDs[studioIdxWithGroup], + Director: director, + Synopsis: synopsis, + URLs: models.NewRelatedStrings([]string{url}), + TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithGroup]}), + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithScene], Description: containingGroupDescription}, + }), + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithStudio], Description: subGroupDescription}, + }), Aliases: aliases, CreatedAt: createdAt, UpdatedAt: updatedAt, @@ -76,6 +96,22 @@ func Test_GroupStore_Create(t *testing.T) { }, true, }, + { + "invalid containing group id", + models.Group{ + Name: name, + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{{GroupID: invalidID}}), + }, + true, + }, + { + "invalid sub group id", + models.Group{ + Name: name, + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{{GroupID: invalidID}}), + }, + true, + }, } qb := db.Group @@ -131,36 +167,44 @@ func Test_GroupStore_Create(t *testing.T) { func Test_groupQueryBuilder_Update(t *testing.T) { var ( - name = "name" - url = "url" - aliases = "alias1, alias2" - director = "director" - rating = 60 - duration = 34 - synopsis = "synopsis" - date, _ = models.ParseDate("2003-02-01") - createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) - updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + name = "name" + url = "url" + aliases = "alias1, alias2" + director = "director" + rating = 60 + duration = 34 + synopsis = "synopsis" + date, _ = models.ParseDate("2003-02-01") + containingGroupDescription = "containingGroupDescription" + subGroupDescription = "subGroupDescription" + createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) ) tests := []struct { name string - updatedObject *models.Group + updatedObject models.Group wantErr bool }{ { "full", - &models.Group{ - ID: groupIDs[groupIdxWithTag], - Name: name, - Duration: &duration, - Date: &date, - Rating: &rating, - StudioID: &studioIDs[studioIdxWithGroup], - Director: director, - Synopsis: synopsis, - URLs: models.NewRelatedStrings([]string{url}), - TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithGroup]}), + models.Group{ + ID: groupIDs[groupIdxWithTag], + Name: name, + Duration: &duration, + Date: &date, + Rating: &rating, + StudioID: &studioIDs[studioIdxWithGroup], + Director: director, + Synopsis: synopsis, + URLs: models.NewRelatedStrings([]string{url}), + TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithGroup]}), + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithScene], Description: containingGroupDescription}, + }), + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithStudio], Description: subGroupDescription}, + }), Aliases: aliases, CreatedAt: createdAt, UpdatedAt: updatedAt, @@ -169,16 +213,34 @@ func Test_groupQueryBuilder_Update(t *testing.T) { }, { "clear tag ids", - &models.Group{ + models.Group{ ID: groupIDs[groupIdxWithTag], Name: name, TagIDs: models.NewRelatedIDs([]int{}), }, false, }, + { + "clear containing ids", + models.Group{ + ID: groupIDs[groupIdxWithParent], + Name: name, + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{}), + }, + false, + }, + { + "clear sub ids", + models.Group{ + ID: groupIDs[groupIdxWithChild], + Name: name, + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{}), + }, + false, + }, { "invalid studio id", - &models.Group{ + models.Group{ ID: groupIDs[groupIdxWithScene], Name: name, StudioID: &invalidID, @@ -187,13 +249,31 @@ func Test_groupQueryBuilder_Update(t *testing.T) { }, { "invalid tag id", - &models.Group{ + models.Group{ ID: groupIDs[groupIdxWithScene], Name: name, TagIDs: models.NewRelatedIDs([]int{invalidID}), }, true, }, + { + "invalid containing group id", + models.Group{ + ID: groupIDs[groupIdxWithScene], + Name: name, + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{{GroupID: invalidID}}), + }, + true, + }, + { + "invalid sub group id", + models.Group{ + ID: groupIDs[groupIdxWithScene], + Name: name, + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{{GroupID: invalidID}}), + }, + true, + }, } qb := db.Group @@ -201,9 +281,10 @@ func Test_groupQueryBuilder_Update(t *testing.T) { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) - copy := *tt.updatedObject + actual := tt.updatedObject + expected := tt.updatedObject - if err := qb.Update(ctx, tt.updatedObject); (err != nil) != tt.wantErr { + if err := qb.Update(ctx, &actual); (err != nil) != tt.wantErr { t.Errorf("groupQueryBuilder.Update() error = %v, wantErr %v", err, tt.wantErr) } @@ -211,49 +292,61 @@ func Test_groupQueryBuilder_Update(t *testing.T) { return } - s, err := qb.Find(ctx, tt.updatedObject.ID) + s, err := qb.Find(ctx, actual.ID) if err != nil { t.Errorf("groupQueryBuilder.Find() error = %v", err) } // load relationships - if err := loadGroupRelationships(ctx, copy, s); err != nil { + if err := loadGroupRelationships(ctx, expected, s); err != nil { t.Errorf("loadGroupRelationships() error = %v", err) return } - assert.Equal(copy, *s) + assert.Equal(expected, *s) }) } } -func clearGroupPartial() models.GroupPartial { +var clearGroupPartial = models.GroupPartial{ // leave mandatory fields - return models.GroupPartial{ - Aliases: models.OptionalString{Set: true, Null: true}, - Synopsis: models.OptionalString{Set: true, Null: true}, - Director: models.OptionalString{Set: true, Null: true}, - Duration: models.OptionalInt{Set: true, Null: true}, - URLs: &models.UpdateStrings{Mode: models.RelationshipUpdateModeSet}, - Date: models.OptionalDate{Set: true, Null: true}, - Rating: models.OptionalInt{Set: true, Null: true}, - StudioID: models.OptionalInt{Set: true, Null: true}, - TagIDs: &models.UpdateIDs{Mode: models.RelationshipUpdateModeSet}, + Aliases: models.OptionalString{Set: true, Null: true}, + Synopsis: models.OptionalString{Set: true, Null: true}, + Director: models.OptionalString{Set: true, Null: true}, + Duration: models.OptionalInt{Set: true, Null: true}, + URLs: &models.UpdateStrings{Mode: models.RelationshipUpdateModeSet}, + Date: models.OptionalDate{Set: true, Null: true}, + Rating: models.OptionalInt{Set: true, Null: true}, + StudioID: models.OptionalInt{Set: true, Null: true}, + TagIDs: &models.UpdateIDs{Mode: models.RelationshipUpdateModeSet}, + ContainingGroups: &models.UpdateGroupDescriptions{Mode: models.RelationshipUpdateModeSet}, + SubGroups: &models.UpdateGroupDescriptions{Mode: models.RelationshipUpdateModeSet}, +} + +func emptyGroup(idx int) models.Group { + return models.Group{ + ID: groupIDs[idx], + Name: groupNames[idx], + TagIDs: models.NewRelatedIDs([]int{}), + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{}), + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{}), } } func Test_groupQueryBuilder_UpdatePartial(t *testing.T) { var ( - name = "name" - url = "url" - aliases = "alias1, alias2" - director = "director" - rating = 60 - duration = 34 - synopsis = "synopsis" - date, _ = models.ParseDate("2003-02-01") - createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) - updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + name = "name" + url = "url" + aliases = "alias1, alias2" + director = "director" + rating = 60 + duration = 34 + synopsis = "synopsis" + date, _ = models.ParseDate("2003-02-01") + containingGroupDescription = "containingGroupDescription" + subGroupDescription = "subGroupDescription" + createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) ) tests := []struct { @@ -285,6 +378,20 @@ func Test_groupQueryBuilder_UpdatePartial(t *testing.T) { IDs: []int{tagIDs[tagIdx1WithGroup], tagIDs[tagIdx1WithDupName]}, Mode: models.RelationshipUpdateModeSet, }, + ContainingGroups: &models.UpdateGroupDescriptions{ + Groups: []models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithStudio], Description: containingGroupDescription}, + {GroupID: groupIDs[groupIdxWithThreeTags], Description: containingGroupDescription}, + }, + Mode: models.RelationshipUpdateModeSet, + }, + SubGroups: &models.UpdateGroupDescriptions{ + Groups: []models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithTag], Description: subGroupDescription}, + {GroupID: groupIDs[groupIdxWithDupName], Description: subGroupDescription}, + }, + Mode: models.RelationshipUpdateModeSet, + }, }, models.Group{ ID: groupIDs[groupIdxWithScene], @@ -300,17 +407,113 @@ func Test_groupQueryBuilder_UpdatePartial(t *testing.T) { CreatedAt: createdAt, UpdatedAt: updatedAt, TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithGroup]}), + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithStudio], Description: containingGroupDescription}, + {GroupID: groupIDs[groupIdxWithThreeTags], Description: containingGroupDescription}, + }), + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithTag], Description: subGroupDescription}, + {GroupID: groupIDs[groupIdxWithDupName], Description: subGroupDescription}, + }), }, false, }, { "clear all", groupIDs[groupIdxWithScene], - clearGroupPartial(), + clearGroupPartial, + emptyGroup(groupIdxWithScene), + false, + }, + { + "clear tag ids", + groupIDs[groupIdxWithTag], + clearGroupPartial, + emptyGroup(groupIdxWithTag), + false, + }, + { + "clear group relationships", + groupIDs[groupIdxWithParentAndChild], + clearGroupPartial, + emptyGroup(groupIdxWithParentAndChild), + false, + }, + { + "add containing group", + groupIDs[groupIdxWithParent], + models.GroupPartial{ + ContainingGroups: &models.UpdateGroupDescriptions{ + Groups: []models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithScene], Description: containingGroupDescription}, + }, + Mode: models.RelationshipUpdateModeAdd, + }, + }, models.Group{ - ID: groupIDs[groupIdxWithScene], - Name: groupNames[groupIdxWithScene], - TagIDs: models.NewRelatedIDs([]int{}), + ID: groupIDs[groupIdxWithParent], + Name: groupNames[groupIdxWithParent], + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithChild]}, + {GroupID: groupIDs[groupIdxWithScene], Description: containingGroupDescription}, + }), + }, + false, + }, + { + "add sub group", + groupIDs[groupIdxWithChild], + models.GroupPartial{ + SubGroups: &models.UpdateGroupDescriptions{ + Groups: []models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithScene], Description: subGroupDescription}, + }, + Mode: models.RelationshipUpdateModeAdd, + }, + }, + models.Group{ + ID: groupIDs[groupIdxWithChild], + Name: groupNames[groupIdxWithChild], + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithParent]}, + {GroupID: groupIDs[groupIdxWithScene], Description: subGroupDescription}, + }), + }, + false, + }, + { + "remove containing group", + groupIDs[groupIdxWithParent], + models.GroupPartial{ + ContainingGroups: &models.UpdateGroupDescriptions{ + Groups: []models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithChild]}, + }, + Mode: models.RelationshipUpdateModeRemove, + }, + }, + models.Group{ + ID: groupIDs[groupIdxWithParent], + Name: groupNames[groupIdxWithParent], + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{}), + }, + false, + }, + { + "remove sub group", + groupIDs[groupIdxWithChild], + models.GroupPartial{ + SubGroups: &models.UpdateGroupDescriptions{ + Groups: []models.GroupIDDescription{ + {GroupID: groupIDs[groupIdxWithParent]}, + }, + Mode: models.RelationshipUpdateModeRemove, + }, + }, + models.Group{ + ID: groupIDs[groupIdxWithChild], + Name: groupNames[groupIdxWithChild], + SubGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{}), }, false, }, @@ -784,7 +987,38 @@ func TestGroupQuerySorting(t *testing.T) { groups = queryGroups(ctx, t, nil, &findFilter) lastGroup := groups[len(groups)-1] - assert.Equal(t, groupIDs[groupIdxWithScene], lastGroup.ID) + assert.Equal(t, groupIDs[groupIdxWithParentAndScene], lastGroup.ID) + + return nil + }) +} + +func TestGroupQuerySortOrderIndex(t *testing.T) { + sort := "sub_group_order" + direction := models.SortDirectionEnumDesc + findFilter := models.FindFilterType{ + Sort: &sort, + Direction: &direction, + } + + groupFilter := models.GroupFilterType{ + ContainingGroups: &models.HierarchicalMultiCriterionInput{ + Value: intslice.IntSliceToStringSlice([]int{groupIdxWithChild}), + Modifier: models.CriterionModifierIncludes, + }, + } + + withTxn(func(ctx context.Context) error { + // just ensure there are no errors + _, _, err := db.Group.Query(ctx, &groupFilter, &findFilter) + if err != nil { + t.Errorf("Error querying group: %s", err.Error()) + } + + _, _, err = db.Group.Query(ctx, nil, &findFilter) + if err != nil { + t.Errorf("Error querying group: %s", err.Error()) + } return nil }) @@ -830,6 +1064,832 @@ func TestGroupUpdateBackImage(t *testing.T) { } } +func TestGroupQueryContainingGroups(t *testing.T) { + const nameField = "Name" + + type criterion struct { + valueIdxs []int + modifier models.CriterionModifier + depth int + } + + tests := []struct { + name string + c criterion + q string + includeIdxs []int + }{ + { + "includes", + criterion{ + []int{groupIdxWithChild}, + models.CriterionModifierIncludes, + 0, + }, + "", + []int{groupIdxWithParent}, + }, + { + "excludes", + criterion{ + []int{groupIdxWithChild}, + models.CriterionModifierExcludes, + 0, + }, + getGroupStringValue(groupIdxWithParent, nameField), + nil, + }, + { + "includes (all levels)", + criterion{ + []int{groupIdxWithGrandChild}, + models.CriterionModifierIncludes, + -1, + }, + "", + []int{groupIdxWithParentAndChild, groupIdxWithGrandParent}, + }, + { + "includes (1 level)", + criterion{ + []int{groupIdxWithGrandChild}, + models.CriterionModifierIncludes, + 1, + }, + "", + []int{groupIdxWithParentAndChild, groupIdxWithGrandParent}, + }, + { + "is null", + criterion{ + nil, + models.CriterionModifierIsNull, + 0, + }, + getGroupStringValue(groupIdxWithParent, nameField), + nil, + }, + { + "not null", + criterion{ + nil, + models.CriterionModifierNotNull, + 0, + }, + "", + []int{groupIdxWithParentAndChild, groupIdxWithParent, groupIdxWithGrandParent, groupIdxWithParentAndScene}, + }, + } + + qb := db.Group + + for _, tt := range tests { + valueIDs := indexesToIDs(groupIDs, tt.c.valueIdxs) + expectedIDs := indexesToIDs(groupIDs, tt.includeIdxs) + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + groupFilter := &models.GroupFilterType{ + ContainingGroups: &models.HierarchicalMultiCriterionInput{ + Value: intslice.IntSliceToStringSlice(valueIDs), + Modifier: tt.c.modifier, + }, + } + + if tt.c.depth != 0 { + groupFilter.ContainingGroups.Depth = &tt.c.depth + } + + findFilter := models.FindFilterType{} + if tt.q != "" { + findFilter.Q = &tt.q + } + + groups, _, err := qb.Query(ctx, groupFilter, &findFilter) + if err != nil { + t.Errorf("GroupStore.Query() error = %v", err) + return + } + + // get ids of groups + groupIDs := sliceutil.Map(groups, func(g *models.Group) int { return g.ID }) + assert.ElementsMatch(t, expectedIDs, groupIDs) + }) + } +} + +func TestGroupQuerySubGroups(t *testing.T) { + const nameField = "Name" + + type criterion struct { + valueIdxs []int + modifier models.CriterionModifier + depth int + } + + tests := []struct { + name string + c criterion + q string + expectedIdxs []int + }{ + { + "includes", + criterion{ + []int{groupIdxWithParent}, + models.CriterionModifierIncludes, + 0, + }, + "", + []int{groupIdxWithChild}, + }, + { + "excludes", + criterion{ + []int{groupIdxWithParent}, + models.CriterionModifierExcludes, + 0, + }, + getGroupStringValue(groupIdxWithChild, nameField), + nil, + }, + { + "includes (all levels)", + criterion{ + []int{groupIdxWithGrandParent}, + models.CriterionModifierIncludes, + -1, + }, + "", + []int{groupIdxWithGrandChild, groupIdxWithParentAndChild}, + }, + { + "includes (1 level)", + criterion{ + []int{groupIdxWithGrandParent}, + models.CriterionModifierIncludes, + 1, + }, + "", + []int{groupIdxWithGrandChild, groupIdxWithParentAndChild}, + }, + { + "is null", + criterion{ + nil, + models.CriterionModifierIsNull, + 0, + }, + getGroupStringValue(groupIdxWithChild, nameField), + nil, + }, + { + "not null", + criterion{ + nil, + models.CriterionModifierNotNull, + 0, + }, + "", + []int{groupIdxWithGrandChild, groupIdxWithChild, groupIdxWithParentAndChild, groupIdxWithChildWithScene}, + }, + } + + qb := db.Group + + for _, tt := range tests { + valueIDs := indexesToIDs(groupIDs, tt.c.valueIdxs) + expectedIDs := indexesToIDs(groupIDs, tt.expectedIdxs) + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + groupFilter := &models.GroupFilterType{ + SubGroups: &models.HierarchicalMultiCriterionInput{ + Value: intslice.IntSliceToStringSlice(valueIDs), + Modifier: tt.c.modifier, + }, + } + + if tt.c.depth != 0 { + groupFilter.SubGroups.Depth = &tt.c.depth + } + + findFilter := models.FindFilterType{} + if tt.q != "" { + findFilter.Q = &tt.q + } + + groups, _, err := qb.Query(ctx, groupFilter, &findFilter) + if err != nil { + t.Errorf("GroupStore.Query() error = %v", err) + return + } + + // get ids of groups + groupIDs := sliceutil.Map(groups, func(g *models.Group) int { return g.ID }) + assert.ElementsMatch(t, expectedIDs, groupIDs) + }) + } +} + +func TestGroupQueryContainingGroupCount(t *testing.T) { + const nameField = "Name" + + tests := []struct { + name string + value int + modifier models.CriterionModifier + q string + expectedIdxs []int + }{ + { + "equals", + 1, + models.CriterionModifierEquals, + "", + []int{groupIdxWithParent, groupIdxWithGrandParent, groupIdxWithParentAndChild, groupIdxWithParentAndScene}, + }, + { + "not equals", + 1, + models.CriterionModifierNotEquals, + getGroupStringValue(groupIdxWithParent, nameField), + nil, + }, + { + "less than", + 1, + models.CriterionModifierLessThan, + getGroupStringValue(groupIdxWithParent, nameField), + nil, + }, + { + "greater than", + 0, + models.CriterionModifierGreaterThan, + "", + []int{groupIdxWithParent, groupIdxWithGrandParent, groupIdxWithParentAndChild, groupIdxWithParentAndScene}, + }, + } + + qb := db.Group + + for _, tt := range tests { + expectedIDs := indexesToIDs(groupIDs, tt.expectedIdxs) + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + groupFilter := &models.GroupFilterType{ + ContainingGroupCount: &models.IntCriterionInput{ + Value: tt.value, + Modifier: tt.modifier, + }, + } + + findFilter := models.FindFilterType{} + if tt.q != "" { + findFilter.Q = &tt.q + } + + groups, _, err := qb.Query(ctx, groupFilter, &findFilter) + if err != nil { + t.Errorf("GroupStore.Query() error = %v", err) + return + } + + // get ids of groups + groupIDs := sliceutil.Map(groups, func(g *models.Group) int { return g.ID }) + assert.ElementsMatch(t, expectedIDs, groupIDs) + }) + } +} + +func TestGroupQuerySubGroupCount(t *testing.T) { + const nameField = "Name" + + tests := []struct { + name string + value int + modifier models.CriterionModifier + q string + expectedIdxs []int + }{ + { + "equals", + 1, + models.CriterionModifierEquals, + "", + []int{groupIdxWithChild, groupIdxWithGrandChild, groupIdxWithParentAndChild, groupIdxWithChildWithScene}, + }, + { + "not equals", + 1, + models.CriterionModifierNotEquals, + getGroupStringValue(groupIdxWithChild, nameField), + nil, + }, + { + "less than", + 1, + models.CriterionModifierLessThan, + getGroupStringValue(groupIdxWithChild, nameField), + nil, + }, + { + "greater than", + 0, + models.CriterionModifierGreaterThan, + "", + []int{groupIdxWithChild, groupIdxWithGrandChild, groupIdxWithParentAndChild, groupIdxWithChildWithScene}, + }, + } + + qb := db.Group + + for _, tt := range tests { + expectedIDs := indexesToIDs(groupIDs, tt.expectedIdxs) + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + groupFilter := &models.GroupFilterType{ + SubGroupCount: &models.IntCriterionInput{ + Value: tt.value, + Modifier: tt.modifier, + }, + } + + findFilter := models.FindFilterType{} + if tt.q != "" { + findFilter.Q = &tt.q + } + + groups, _, err := qb.Query(ctx, groupFilter, &findFilter) + if err != nil { + t.Errorf("GroupStore.Query() error = %v", err) + return + } + + // get ids of groups + groupIDs := sliceutil.Map(groups, func(g *models.Group) int { return g.ID }) + assert.ElementsMatch(t, expectedIDs, groupIDs) + }) + } +} + +func TestGroupFindInAncestors(t *testing.T) { + tests := []struct { + name string + ancestorIdxs []int + idxs []int + expectedIdxs []int + }{ + { + "basic", + []int{groupIdxWithGrandParent}, + []int{groupIdxWithGrandChild}, + []int{groupIdxWithGrandChild}, + }, + { + "same", + []int{groupIdxWithScene}, + []int{groupIdxWithScene}, + []int{groupIdxWithScene}, + }, + { + "no matches", + []int{groupIdxWithGrandParent}, + []int{groupIdxWithScene}, + nil, + }, + } + + qb := db.Group + + for _, tt := range tests { + ancestorIDs := indexesToIDs(groupIDs, tt.ancestorIdxs) + ids := indexesToIDs(groupIDs, tt.idxs) + expectedIDs := indexesToIDs(groupIDs, tt.expectedIdxs) + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + found, err := qb.FindInAncestors(ctx, ancestorIDs, ids) + if err != nil { + t.Errorf("GroupStore.FindInAncestors() error = %v", err) + return + } + + // get ids of groups + assert.ElementsMatch(t, found, expectedIDs) + }) + } +} + +func TestGroupReorderSubGroups(t *testing.T) { + tests := []struct { + name string + subGroupLen int + idxsToMove []int + insertLoc int + insertAfter bool + // order of elements, using original indexes + expectedIdxs []int + }{ + { + "move single back before", + 5, + []int{2}, + 1, + false, + []int{0, 2, 1, 3, 4}, + }, + { + "move single forward before", + 5, + []int{2}, + 4, + false, + []int{0, 1, 3, 2, 4}, + }, + { + "move multiple back before", + 5, + []int{3, 2, 4}, + 0, + false, + []int{3, 2, 4, 0, 1}, + }, + { + "move multiple forward before", + 5, + []int{2, 1, 0}, + 4, + false, + []int{3, 2, 1, 0, 4}, + }, + { + "move single back after", + 5, + []int{2}, + 0, + true, + []int{0, 2, 1, 3, 4}, + }, + { + "move single forward after", + 5, + []int{2}, + 4, + true, + []int{0, 1, 3, 4, 2}, + }, + { + "move multiple back after", + 5, + []int{3, 2, 4}, + 0, + false, + []int{0, 3, 2, 4, 1}, + }, + { + "move multiple forward after", + 5, + []int{2, 1, 0}, + 4, + false, + []int{3, 4, 2, 1, 0}, + }, + } + + qb := db.Group + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + // create the group + group := models.Group{ + Name: "TestGroupReorderSubGroups", + } + + if err := qb.Create(ctx, &group); err != nil { + t.Errorf("GroupStore.Create() error = %v", err) + return + } + + // and sub-groups + idxToId := make([]int, tt.subGroupLen) + + for i := 0; i < tt.subGroupLen; i++ { + subGroup := models.Group{ + Name: fmt.Sprintf("SubGroup %d", i), + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: group.ID}, + }), + } + + if err := qb.Create(ctx, &subGroup); err != nil { + t.Errorf("GroupStore.Create() error = %v", err) + return + } + + idxToId[i] = subGroup.ID + } + + // reorder + idsToMove := indexesToIDs(idxToId, tt.idxsToMove) + insertID := idxToId[tt.insertLoc] + if err := qb.ReorderSubGroups(ctx, group.ID, idsToMove, insertID, tt.insertAfter); err != nil { + t.Errorf("GroupStore.ReorderSubGroups() error = %v", err) + return + } + + // validate the new order + gd, err := qb.GetSubGroupDescriptions(ctx, group.ID) + if err != nil { + t.Errorf("GroupStore.GetSubGroupDescriptions() error = %v", err) + return + } + + // get ids of groups + newIDs := sliceutil.Map(gd, func(gd models.GroupIDDescription) int { return gd.GroupID }) + newIdxs := sliceutil.Map(newIDs, func(id int) int { return sliceutil.Index(idxToId, id) }) + + assert.ElementsMatch(t, tt.expectedIdxs, newIdxs) + }) + } +} + +func TestGroupAddSubGroups(t *testing.T) { + tests := []struct { + name string + existingSubGroupLen int + insertGroupsLen int + insertLoc int + // order of elements, using original indexes + expectedIdxs []int + }{ + { + "append single", + 4, + 1, + 999, + []int{0, 1, 2, 3, 4}, + }, + { + "insert single middle", + 4, + 1, + 2, + []int{0, 1, 4, 2, 3}, + }, + { + "insert single start", + 4, + 1, + 0, + []int{4, 0, 1, 2, 3}, + }, + { + "append multiple", + 4, + 2, + 999, + []int{0, 1, 2, 3, 4, 5}, + }, + { + "insert multiple middle", + 4, + 2, + 2, + []int{0, 1, 4, 5, 2, 3}, + }, + { + "insert multiple start", + 4, + 2, + 0, + []int{4, 5, 0, 1, 2, 3}, + }, + } + + qb := db.Group + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + // create the group + group := models.Group{ + Name: "TestGroupReorderSubGroups", + } + + if err := qb.Create(ctx, &group); err != nil { + t.Errorf("GroupStore.Create() error = %v", err) + return + } + + // and sub-groups + idxToId := make([]int, tt.existingSubGroupLen+tt.insertGroupsLen) + + for i := 0; i < tt.existingSubGroupLen; i++ { + subGroup := models.Group{ + Name: fmt.Sprintf("Existing SubGroup %d", i), + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: group.ID}, + }), + } + + if err := qb.Create(ctx, &subGroup); err != nil { + t.Errorf("GroupStore.Create() error = %v", err) + return + } + + idxToId[i] = subGroup.ID + } + + // and sub-groups to insert + for i := 0; i < tt.insertGroupsLen; i++ { + subGroup := models.Group{ + Name: fmt.Sprintf("Inserted SubGroup %d", i), + } + + if err := qb.Create(ctx, &subGroup); err != nil { + t.Errorf("GroupStore.Create() error = %v", err) + return + } + + idxToId[i+tt.existingSubGroupLen] = subGroup.ID + } + + // convert ids to description + idDescriptions := make([]models.GroupIDDescription, tt.insertGroupsLen) + for i, id := range idxToId[tt.existingSubGroupLen:] { + idDescriptions[i] = models.GroupIDDescription{GroupID: id} + } + + // add + if err := qb.AddSubGroups(ctx, group.ID, idDescriptions, &tt.insertLoc); err != nil { + t.Errorf("GroupStore.AddSubGroups() error = %v", err) + return + } + + // validate the new order + gd, err := qb.GetSubGroupDescriptions(ctx, group.ID) + if err != nil { + t.Errorf("GroupStore.GetSubGroupDescriptions() error = %v", err) + return + } + + // get ids of groups + newIDs := sliceutil.Map(gd, func(gd models.GroupIDDescription) int { return gd.GroupID }) + newIdxs := sliceutil.Map(newIDs, func(id int) int { return sliceutil.Index(idxToId, id) }) + + assert.ElementsMatch(t, tt.expectedIdxs, newIdxs) + }) + } +} + +func TestGroupRemoveSubGroups(t *testing.T) { + tests := []struct { + name string + subGroupLen int + removeIdxs []int + // order of elements, using original indexes + expectedIdxs []int + }{ + { + "remove last", + 4, + []int{3}, + []int{0, 1, 2}, + }, + { + "remove first", + 4, + []int{0}, + []int{1, 2, 3}, + }, + { + "remove middle", + 4, + []int{2}, + []int{0, 1, 3}, + }, + { + "remove multiple", + 4, + []int{1, 3}, + []int{0, 2}, + }, + { + "remove all", + 4, + []int{0, 1, 2, 3}, + []int{}, + }, + } + + qb := db.Group + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + // create the group + group := models.Group{ + Name: "TestGroupReorderSubGroups", + } + + if err := qb.Create(ctx, &group); err != nil { + t.Errorf("GroupStore.Create() error = %v", err) + return + } + + // and sub-groups + idxToId := make([]int, tt.subGroupLen) + + for i := 0; i < tt.subGroupLen; i++ { + subGroup := models.Group{ + Name: fmt.Sprintf("Existing SubGroup %d", i), + ContainingGroups: models.NewRelatedGroupDescriptions([]models.GroupIDDescription{ + {GroupID: group.ID}, + }), + } + + if err := qb.Create(ctx, &subGroup); err != nil { + t.Errorf("GroupStore.Create() error = %v", err) + return + } + + idxToId[i] = subGroup.ID + } + + idsToRemove := indexesToIDs(idxToId, tt.removeIdxs) + if err := qb.RemoveSubGroups(ctx, group.ID, idsToRemove); err != nil { + t.Errorf("GroupStore.RemoveSubGroups() error = %v", err) + return + } + + // validate the new order + gd, err := qb.GetSubGroupDescriptions(ctx, group.ID) + if err != nil { + t.Errorf("GroupStore.GetSubGroupDescriptions() error = %v", err) + return + } + + // get ids of groups + newIDs := sliceutil.Map(gd, func(gd models.GroupIDDescription) int { return gd.GroupID }) + newIdxs := sliceutil.Map(newIDs, func(id int) int { return sliceutil.Index(idxToId, id) }) + + assert.ElementsMatch(t, tt.expectedIdxs, newIdxs) + }) + } +} + +func TestGroupFindSubGroupIDs(t *testing.T) { + tests := []struct { + name string + containingGroupIdx int + subIdxs []int + expectedIdxs []int + }{ + { + "overlap", + groupIdxWithGrandChild, + []int{groupIdxWithParentAndChild, groupIdxWithGrandParent}, + []int{groupIdxWithParentAndChild}, + }, + { + "non-overlap", + groupIdxWithGrandChild, + []int{groupIdxWithGrandParent}, + []int{}, + }, + { + "none", + groupIdxWithScene, + []int{groupIdxWithDupName}, + []int{}, + }, + { + "invalid", + invalidID, + []int{invalidID}, + []int{}, + }, + } + + qb := db.Group + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + subIDs := indexesToIDs(groupIDs, tt.subIdxs) + + id := indexToID(groupIDs, tt.containingGroupIdx) + + found, err := qb.FindSubGroupIDs(ctx, id, subIDs) + if err != nil { + t.Errorf("GroupStore.FindSubGroupIDs() error = %v", err) + return + } + + // get ids of groups + foundIdxs := sliceutil.Map(found, func(id int) int { return sliceutil.Index(groupIDs, id) }) + + assert.ElementsMatch(t, tt.expectedIdxs, foundIdxs) + }) + } +} + // TODO Update // TODO Destroy - ensure image is destroyed // TODO Find diff --git a/pkg/sqlite/migrations/67_group_relationships.up.sql b/pkg/sqlite/migrations/67_group_relationships.up.sql new file mode 100644 index 00000000000..76ac29cc83f --- /dev/null +++ b/pkg/sqlite/migrations/67_group_relationships.up.sql @@ -0,0 +1,13 @@ +CREATE TABLE `groups_relations` ( + `containing_id` integer not null, + `sub_id` integer not null, + `order_index` integer not null, + `description` varchar(255), + primary key (`containing_id`, `sub_id`), + foreign key (`containing_id`) references `groups`(`id`) on delete cascade, + foreign key (`sub_id`) references `groups`(`id`) on delete cascade, + check (`containing_id` != `sub_id`) +); + +CREATE INDEX `index_groups_relations_sub_id` ON `groups_relations` (`sub_id`); +CREATE UNIQUE INDEX `index_groups_relations_order_index_unique` ON `groups_relations` (`containing_id`, `order_index`); diff --git a/pkg/sqlite/query.go b/pkg/sqlite/query.go index 597ab66b98f..9c09d8beaed 100644 --- a/pkg/sqlite/query.go +++ b/pkg/sqlite/query.go @@ -110,6 +110,16 @@ func (qb *queryBuilder) addArg(args ...interface{}) { qb.args = append(qb.args, args...) } +func (qb *queryBuilder) hasJoin(alias string) bool { + for _, j := range qb.joins { + if j.alias() == alias { + return true + } + } + + return false +} + func (qb *queryBuilder) join(table, as, onClause string) { newJoin := join{ table: table, diff --git a/pkg/sqlite/scene.go b/pkg/sqlite/scene.go index 9b8bd73157b..c950be4d160 100644 --- a/pkg/sqlite/scene.go +++ b/pkg/sqlite/scene.go @@ -791,13 +791,6 @@ func (qb *SceneStore) FindByGroupID(ctx context.Context, groupID int) ([]*models return ret, nil } -func (qb *SceneStore) CountByGroupID(ctx context.Context, groupID int) (int, error) { - joinTable := scenesGroupsJoinTable - - q := dialect.Select(goqu.COUNT("*")).From(joinTable).Where(joinTable.Col(groupIDColumn).Eq(groupID)) - return count(ctx, q) -} - func (qb *SceneStore) Count(ctx context.Context) (int, error) { q := dialect.Select(goqu.COUNT("*")).From(qb.table()) return count(ctx, q) @@ -858,6 +851,7 @@ func (qb *SceneStore) PlayDuration(ctx context.Context) (float64, error) { return ret, nil } +// TODO - currently only used by unit test func (qb *SceneStore) CountByStudioID(ctx context.Context, studioID int) (int, error) { table := qb.table() @@ -865,13 +859,6 @@ func (qb *SceneStore) CountByStudioID(ctx context.Context, studioID int) (int, e return count(ctx, q) } -func (qb *SceneStore) CountByTagID(ctx context.Context, tagID int) (int, error) { - joinTable := scenesTagsJoinTable - - q := dialect.Select(goqu.COUNT("*")).From(joinTable).Where(joinTable.Col(tagIDColumn).Eq(tagID)) - return count(ctx, q) -} - func (qb *SceneStore) countMissingFingerprints(ctx context.Context, fpType string) (int, error) { fpTable := fingerprintTableMgr.table.As("fingerprints_temp") diff --git a/pkg/sqlite/scene_filter.go b/pkg/sqlite/scene_filter.go index 3f2233395fa..2e63dad975f 100644 --- a/pkg/sqlite/scene_filter.go +++ b/pkg/sqlite/scene_filter.go @@ -149,7 +149,7 @@ func (qb *sceneFilterHandler) criterionHandler() criterionHandler { studioCriterionHandler(sceneTable, sceneFilter.Studios), qb.groupsCriterionHandler(sceneFilter.Groups), - qb.groupsCriterionHandler(sceneFilter.Movies), + qb.moviesCriterionHandler(sceneFilter.Movies), qb.galleriesCriterionHandler(sceneFilter.Galleries), qb.performerTagsCriterionHandler(sceneFilter.PerformerTags), @@ -483,7 +483,8 @@ func (qb *sceneFilterHandler) performerAgeCriterionHandler(performerAge *models. } } -func (qb *sceneFilterHandler) groupsCriterionHandler(movies *models.MultiCriterionInput) criterionHandlerFunc { +// legacy handler +func (qb *sceneFilterHandler) moviesCriterionHandler(movies *models.MultiCriterionInput) criterionHandlerFunc { addJoinsFunc := func(f *filterBuilder) { sceneRepository.groups.join(f, "", "scenes.id") f.addLeftJoin("groups", "", "groups_scenes.group_id = groups.id") @@ -492,6 +493,23 @@ func (qb *sceneFilterHandler) groupsCriterionHandler(movies *models.MultiCriteri return h.handler(movies) } +func (qb *sceneFilterHandler) groupsCriterionHandler(groups *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + h := joinedHierarchicalMultiCriterionHandlerBuilder{ + primaryTable: sceneTable, + foreignTable: groupTable, + foreignFK: "group_id", + + relationsTable: groupRelationsTable, + parentFK: "containing_id", + childFK: "sub_id", + joinAs: "scene_group", + joinTable: groupsScenesTable, + primaryFK: sceneIDColumn, + } + + return h.handler(groups) +} + func (qb *sceneFilterHandler) galleriesCriterionHandler(galleries *models.MultiCriterionInput) criterionHandlerFunc { addJoinsFunc := func(f *filterBuilder) { sceneRepository.galleries.join(f, "", "scenes.id") diff --git a/pkg/sqlite/scene_test.go b/pkg/sqlite/scene_test.go index 9116158fc9f..a3174d7278d 100644 --- a/pkg/sqlite/scene_test.go +++ b/pkg/sqlite/scene_test.go @@ -16,6 +16,7 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/sliceutil" + "github.com/stashapp/stash/pkg/sliceutil/intslice" "github.com/stretchr/testify/assert" ) @@ -2217,7 +2218,7 @@ func TestSceneQuery(t *testing.T) { }, }) if (err != nil) != tt.wantErr { - t.Errorf("PerformerStore.Query() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("SceneStore.Query() error = %v, wantErr %v", err, tt.wantErr) return } @@ -3873,6 +3874,100 @@ func TestSceneQueryStudioDepth(t *testing.T) { }) } +func TestSceneGroups(t *testing.T) { + type criterion struct { + valueIdxs []int + modifier models.CriterionModifier + depth int + } + + tests := []struct { + name string + c criterion + q string + includeIdxs []int + excludeIdxs []int + }{ + { + "includes", + criterion{ + []int{groupIdxWithScene}, + models.CriterionModifierIncludes, + 0, + }, + "", + []int{sceneIdxWithGroup}, + nil, + }, + { + "excludes", + criterion{ + []int{groupIdxWithScene}, + models.CriterionModifierExcludes, + 0, + }, + getSceneStringValue(sceneIdxWithGroup, titleField), + nil, + []int{sceneIdxWithGroup}, + }, + { + "includes (depth = 1)", + criterion{ + []int{groupIdxWithChildWithScene}, + models.CriterionModifierIncludes, + 1, + }, + "", + []int{sceneIdxWithGroupWithParent}, + nil, + }, + } + + for _, tt := range tests { + valueIDs := indexesToIDs(groupIDs, tt.c.valueIdxs) + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + sceneFilter := &models.SceneFilterType{ + Groups: &models.HierarchicalMultiCriterionInput{ + Value: intslice.IntSliceToStringSlice(valueIDs), + Modifier: tt.c.modifier, + }, + } + + if tt.c.depth != 0 { + sceneFilter.Groups.Depth = &tt.c.depth + } + + findFilter := &models.FindFilterType{} + if tt.q != "" { + findFilter.Q = &tt.q + } + + results, err := db.Scene.Query(ctx, models.SceneQueryOptions{ + SceneFilter: sceneFilter, + QueryOptions: models.QueryOptions{ + FindFilter: findFilter, + }, + }) + if err != nil { + t.Errorf("SceneStore.Query() error = %v", err) + return + } + + include := indexesToIDs(sceneIDs, tt.includeIdxs) + exclude := indexesToIDs(sceneIDs, tt.excludeIdxs) + + assert.Subset(results.IDs, include) + + for _, e := range exclude { + assert.NotContains(results.IDs, e) + } + }) + } +} + func TestSceneQueryMovies(t *testing.T) { withTxn(func(ctx context.Context) error { sqb := db.Scene @@ -4188,78 +4283,6 @@ func verifyScenesPerformerCount(t *testing.T, performerCountCriterion models.Int }) } -func TestSceneCountByTagID(t *testing.T) { - withTxn(func(ctx context.Context) error { - sqb := db.Scene - - sceneCount, err := sqb.CountByTagID(ctx, tagIDs[tagIdxWithScene]) - - if err != nil { - t.Errorf("error calling CountByTagID: %s", err.Error()) - } - - assert.Equal(t, 1, sceneCount) - - sceneCount, err = sqb.CountByTagID(ctx, 0) - - if err != nil { - t.Errorf("error calling CountByTagID: %s", err.Error()) - } - - assert.Equal(t, 0, sceneCount) - - return nil - }) -} - -func TestSceneCountByGroupID(t *testing.T) { - withTxn(func(ctx context.Context) error { - sqb := db.Scene - - sceneCount, err := sqb.CountByGroupID(ctx, groupIDs[groupIdxWithScene]) - - if err != nil { - t.Errorf("error calling CountByGroupID: %s", err.Error()) - } - - assert.Equal(t, 1, sceneCount) - - sceneCount, err = sqb.CountByGroupID(ctx, 0) - - if err != nil { - t.Errorf("error calling CountByGroupID: %s", err.Error()) - } - - assert.Equal(t, 0, sceneCount) - - return nil - }) -} - -func TestSceneCountByStudioID(t *testing.T) { - withTxn(func(ctx context.Context) error { - sqb := db.Scene - - sceneCount, err := sqb.CountByStudioID(ctx, studioIDs[studioIdxWithScene]) - - if err != nil { - t.Errorf("error calling CountByStudioID: %s", err.Error()) - } - - assert.Equal(t, 1, sceneCount) - - sceneCount, err = sqb.CountByStudioID(ctx, 0) - - if err != nil { - t.Errorf("error calling CountByStudioID: %s", err.Error()) - } - - assert.Equal(t, 0, sceneCount) - - return nil - }) -} - func TestFindByMovieID(t *testing.T) { withTxn(func(ctx context.Context) error { sqb := db.Scene diff --git a/pkg/sqlite/setup_test.go b/pkg/sqlite/setup_test.go index aa6af73c4c9..624ffb4e222 100644 --- a/pkg/sqlite/setup_test.go +++ b/pkg/sqlite/setup_test.go @@ -78,6 +78,7 @@ const ( sceneIdxWithGrandChildStudio sceneIdxMissingPhash sceneIdxWithPerformerParentTag + sceneIdxWithGroupWithParent // new indexes above lastSceneIdx @@ -153,9 +154,15 @@ const ( groupIdxWithTag groupIdxWithTwoTags groupIdxWithThreeTags + groupIdxWithGrandChild + groupIdxWithChild + groupIdxWithParentAndChild + groupIdxWithParent + groupIdxWithGrandParent + groupIdxWithParentAndScene + groupIdxWithChildWithScene // groups with dup names start from the end - // create 7 more basic groups (can remove this if we add more indexes) - groupIdxWithDupName = groupIdxWithStudio + 7 + groupIdxWithDupName groupsNameCase = groupIdxWithDupName groupsNameNoCase = 1 @@ -390,7 +397,8 @@ var ( } sceneGroups = linkMap{ - sceneIdxWithGroup: {groupIdxWithScene}, + sceneIdxWithGroup: {groupIdxWithScene}, + sceneIdxWithGroupWithParent: {groupIdxWithParentAndScene}, } sceneStudios = map[int]int{ @@ -541,15 +549,31 @@ var ( } ) +var ( + groupParentLinks = [][2]int{ + {groupIdxWithChild, groupIdxWithParent}, + {groupIdxWithGrandChild, groupIdxWithParentAndChild}, + {groupIdxWithParentAndChild, groupIdxWithGrandParent}, + {groupIdxWithChildWithScene, groupIdxWithParentAndScene}, + } +) + func indexesToIDs(ids []int, indexes []int) []int { ret := make([]int, len(indexes)) for i, idx := range indexes { - ret[i] = ids[idx] + ret[i] = indexToID(ids, idx) } return ret } +func indexToID(ids []int, idx int) int { + if idx < 0 { + return invalidID + } + return ids[idx] +} + func indexFromID(ids []int, id int) int { for i, v := range ids { if v == id { @@ -697,6 +721,10 @@ func populateDB() error { return fmt.Errorf("error linking tags parent: %s", err.Error()) } + if err := linkGroupsParent(ctx, db.Group); err != nil { + return fmt.Errorf("error linking tags parent: %s", err.Error()) + } + for _, ms := range markerSpecs { if err := createMarker(ctx, db.SceneMarker, ms); err != nil { return fmt.Errorf("error creating scene marker: %s", err.Error()) @@ -1885,6 +1913,24 @@ func linkTagsParent(ctx context.Context, qb models.TagReaderWriter) error { }) } +func linkGroupsParent(ctx context.Context, qb models.GroupReaderWriter) error { + return doLinks(groupParentLinks, func(parentIndex, childIndex int) error { + groupID := groupIDs[childIndex] + + p := models.GroupPartial{ + ContainingGroups: &models.UpdateGroupDescriptions{ + Groups: []models.GroupIDDescription{ + {GroupID: groupIDs[parentIndex]}, + }, + Mode: models.RelationshipUpdateModeAdd, + }, + } + + _, err := qb.UpdatePartial(ctx, groupID, p) + return err + }) +} + func addTagImage(ctx context.Context, qb models.TagWriter, tagIndex int) error { return qb.UpdateImage(ctx, tagIDs[tagIndex], []byte("image")) } diff --git a/pkg/sqlite/tables.go b/pkg/sqlite/tables.go index 365abe81292..74a5ebe698c 100644 --- a/pkg/sqlite/tables.go +++ b/pkg/sqlite/tables.go @@ -37,8 +37,9 @@ var ( studiosTagsJoinTable = goqu.T(studiosTagsTable) studiosStashIDsJoinTable = goqu.T("studio_stash_ids") - groupsURLsJoinTable = goqu.T(groupURLsTable) - groupsTagsJoinTable = goqu.T(groupsTagsTable) + groupsURLsJoinTable = goqu.T(groupURLsTable) + groupsTagsJoinTable = goqu.T(groupsTagsTable) + groupRelationsJoinTable = goqu.T(groupRelationsTable) tagsAliasesJoinTable = goqu.T(tagAliasesTable) tagRelationsJoinTable = goqu.T(tagRelationsTable) @@ -361,6 +362,10 @@ var ( foreignTable: tagTableMgr, orderBy: tagTableMgr.table.Col("name").Asc(), } + + groupRelationshipTableMgr = &table{ + table: groupRelationsJoinTable, + } ) var ( diff --git a/pkg/sqlite/tag_filter.go b/pkg/sqlite/tag_filter.go index 26e33c36c6e..ba9e9bb08ec 100644 --- a/pkg/sqlite/tag_filter.go +++ b/pkg/sqlite/tag_filter.go @@ -2,7 +2,6 @@ package sqlite import ( "context" - "fmt" "github.com/stashapp/stash/pkg/models" ) @@ -51,6 +50,14 @@ func (qb *tagFilterHandler) handle(ctx context.Context, f *filterBuilder) { f.handleCriterion(ctx, qb.criterionHandler()) } +var tagHierarchyHandler = hierarchicalRelationshipHandler{ + primaryTable: tagTable, + relationTable: tagRelationsTable, + aliasPrefix: tagTable, + parentIDCol: "parent_id", + childIDCol: "child_id", +} + func (qb *tagFilterHandler) criterionHandler() criterionHandler { tagFilter := qb.tagFilter return compoundHandler{ @@ -72,10 +79,10 @@ func (qb *tagFilterHandler) criterionHandler() criterionHandler { qb.groupCountCriterionHandler(tagFilter.MovieCount), qb.markerCountCriterionHandler(tagFilter.MarkerCount), - qb.parentsCriterionHandler(tagFilter.Parents), - qb.childrenCriterionHandler(tagFilter.Children), - qb.parentCountCriterionHandler(tagFilter.ParentCount), - qb.childCountCriterionHandler(tagFilter.ChildCount), + tagHierarchyHandler.ParentsCriterionHandler(tagFilter.Parents), + tagHierarchyHandler.ChildrenCriterionHandler(tagFilter.Children), + tagHierarchyHandler.ParentCountCriterionHandler(tagFilter.ParentCount), + tagHierarchyHandler.ChildCountCriterionHandler(tagFilter.ChildCount), ×tampCriterionHandler{tagFilter.CreatedAt, "tags.created_at", nil}, ×tampCriterionHandler{tagFilter.UpdatedAt, "tags.updated_at", nil}, @@ -212,213 +219,3 @@ func (qb *tagFilterHandler) markerCountCriterionHandler(markerCount *models.IntC } } } - -func (qb *tagFilterHandler) parentsCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { - return func(ctx context.Context, f *filterBuilder) { - if criterion != nil { - tags := criterion.CombineExcludes() - - // validate the modifier - switch tags.Modifier { - case models.CriterionModifierIncludesAll, models.CriterionModifierIncludes, models.CriterionModifierExcludes, models.CriterionModifierIsNull, models.CriterionModifierNotNull: - // valid - default: - f.setError(fmt.Errorf("invalid modifier %s for tag parent/children", criterion.Modifier)) - } - - if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull { - var notClause string - if tags.Modifier == models.CriterionModifierNotNull { - notClause = "NOT" - } - - f.addLeftJoin("tags_relations", "parent_relations", "tags.id = parent_relations.child_id") - - f.addWhere(fmt.Sprintf("parent_relations.parent_id IS %s NULL", notClause)) - return - } - - if len(tags.Value) == 0 && len(tags.Excludes) == 0 { - return - } - - if len(tags.Value) > 0 { - var args []interface{} - for _, val := range tags.Value { - args = append(args, val) - } - - depthVal := 0 - if tags.Depth != nil { - depthVal = *tags.Depth - } - - var depthCondition string - if depthVal != -1 { - depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal) - } - - query := `parents AS ( - SELECT parent_id AS root_id, child_id AS item_id, 0 AS depth FROM tags_relations WHERE parent_id IN` + getInBinding(len(tags.Value)) + ` - UNION - SELECT root_id, child_id, depth + 1 FROM tags_relations INNER JOIN parents ON item_id = parent_id ` + depthCondition + ` - )` - - f.addRecursiveWith(query, args...) - - f.addLeftJoin("parents", "", "parents.item_id = tags.id") - - addHierarchicalConditionClauses(f, tags, "parents", "root_id") - } - - if len(tags.Excludes) > 0 { - var args []interface{} - for _, val := range tags.Excludes { - args = append(args, val) - } - - depthVal := 0 - if tags.Depth != nil { - depthVal = *tags.Depth - } - - var depthCondition string - if depthVal != -1 { - depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal) - } - - query := `parents2 AS ( - SELECT parent_id AS root_id, child_id AS item_id, 0 AS depth FROM tags_relations WHERE parent_id IN` + getInBinding(len(tags.Excludes)) + ` - UNION - SELECT root_id, child_id, depth + 1 FROM tags_relations INNER JOIN parents2 ON item_id = parent_id ` + depthCondition + ` - )` - - f.addRecursiveWith(query, args...) - - f.addLeftJoin("parents2", "", "parents2.item_id = tags.id") - - addHierarchicalConditionClauses(f, models.HierarchicalMultiCriterionInput{ - Value: tags.Excludes, - Depth: tags.Depth, - Modifier: models.CriterionModifierExcludes, - }, "parents2", "root_id") - } - } - } -} - -func (qb *tagFilterHandler) childrenCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { - return func(ctx context.Context, f *filterBuilder) { - if criterion != nil { - tags := criterion.CombineExcludes() - - // validate the modifier - switch tags.Modifier { - case models.CriterionModifierIncludesAll, models.CriterionModifierIncludes, models.CriterionModifierExcludes, models.CriterionModifierIsNull, models.CriterionModifierNotNull: - // valid - default: - f.setError(fmt.Errorf("invalid modifier %s for tag parent/children", criterion.Modifier)) - } - - if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull { - var notClause string - if tags.Modifier == models.CriterionModifierNotNull { - notClause = "NOT" - } - - f.addLeftJoin("tags_relations", "child_relations", "tags.id = child_relations.parent_id") - - f.addWhere(fmt.Sprintf("child_relations.child_id IS %s NULL", notClause)) - return - } - - if len(tags.Value) == 0 && len(tags.Excludes) == 0 { - return - } - - if len(tags.Value) > 0 { - var args []interface{} - for _, val := range tags.Value { - args = append(args, val) - } - - depthVal := 0 - if tags.Depth != nil { - depthVal = *tags.Depth - } - - var depthCondition string - if depthVal != -1 { - depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal) - } - - query := `children AS ( - SELECT child_id AS root_id, parent_id AS item_id, 0 AS depth FROM tags_relations WHERE child_id IN` + getInBinding(len(tags.Value)) + ` - UNION - SELECT root_id, parent_id, depth + 1 FROM tags_relations INNER JOIN children ON item_id = child_id ` + depthCondition + ` - )` - - f.addRecursiveWith(query, args...) - - f.addLeftJoin("children", "", "children.item_id = tags.id") - - addHierarchicalConditionClauses(f, tags, "children", "root_id") - } - - if len(tags.Excludes) > 0 { - var args []interface{} - for _, val := range tags.Excludes { - args = append(args, val) - } - - depthVal := 0 - if tags.Depth != nil { - depthVal = *tags.Depth - } - - var depthCondition string - if depthVal != -1 { - depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal) - } - - query := `children2 AS ( - SELECT child_id AS root_id, parent_id AS item_id, 0 AS depth FROM tags_relations WHERE child_id IN` + getInBinding(len(tags.Excludes)) + ` - UNION - SELECT root_id, parent_id, depth + 1 FROM tags_relations INNER JOIN children2 ON item_id = child_id ` + depthCondition + ` - )` - - f.addRecursiveWith(query, args...) - - f.addLeftJoin("children2", "", "children2.item_id = tags.id") - - addHierarchicalConditionClauses(f, models.HierarchicalMultiCriterionInput{ - Value: tags.Excludes, - Depth: tags.Depth, - Modifier: models.CriterionModifierExcludes, - }, "children2", "root_id") - } - } - } -} - -func (qb *tagFilterHandler) parentCountCriterionHandler(parentCount *models.IntCriterionInput) criterionHandlerFunc { - return func(ctx context.Context, f *filterBuilder) { - if parentCount != nil { - f.addLeftJoin("tags_relations", "parents_count", "parents_count.child_id = tags.id") - clause, args := getIntCriterionWhereClause("count(distinct parents_count.parent_id)", *parentCount) - - f.addHaving(clause, args...) - } - } -} - -func (qb *tagFilterHandler) childCountCriterionHandler(childCount *models.IntCriterionInput) criterionHandlerFunc { - return func(ctx context.Context, f *filterBuilder) { - if childCount != nil { - f.addLeftJoin("tags_relations", "children_count", "children_count.parent_id = tags.id") - clause, args := getIntCriterionWhereClause("count(distinct children_count.child_id)", *childCount) - - f.addHaving(clause, args...) - } - } -} diff --git a/pkg/sqlite/tag_test.go b/pkg/sqlite/tag_test.go index f673567f82e..5359be78517 100644 --- a/pkg/sqlite/tag_test.go +++ b/pkg/sqlite/tag_test.go @@ -712,7 +712,7 @@ func TestTagQueryParent(t *testing.T) { assert.Len(t, tags, 1) // ensure id is correct - assert.Equal(t, sceneIDs[tagIdxWithParentTag], tags[0].ID) + assert.Equal(t, tagIDs[tagIdxWithParentTag], tags[0].ID) tagCriterion.Modifier = models.CriterionModifierExcludes diff --git a/ui/v2.5/graphql/data/group.graphql b/ui/v2.5/graphql/data/group.graphql index 60f55e30948..963e8d6e672 100644 --- a/ui/v2.5/graphql/data/group.graphql +++ b/ui/v2.5/graphql/data/group.graphql @@ -15,11 +15,21 @@ fragment GroupData on Group { ...SlimTagData } + containing_groups { + group { + ...SlimGroupData + } + description + } + synopsis urls front_image_path back_image_path scene_count + scene_count_all: scene_count(depth: -1) + sub_group_count + sub_group_count_all: sub_group_count(depth: -1) scenes { id diff --git a/ui/v2.5/graphql/mutations/group.graphql b/ui/v2.5/graphql/mutations/group.graphql index fb739e84009..8065e4adbb7 100644 --- a/ui/v2.5/graphql/mutations/group.graphql +++ b/ui/v2.5/graphql/mutations/group.graphql @@ -23,3 +23,15 @@ mutation GroupDestroy($id: ID!) { mutation GroupsDestroy($ids: [ID!]!) { groupsDestroy(ids: $ids) } + +mutation AddGroupSubGroups($input: GroupSubGroupAddInput!) { + addGroupSubGroups(input: $input) +} + +mutation RemoveGroupSubGroups($input: GroupSubGroupRemoveInput!) { + removeGroupSubGroups(input: $input) +} + +mutation ReorderSubGroups($input: ReorderSubGroupsInput!) { + reorderSubGroups(input: $input) +} diff --git a/ui/v2.5/src/components/Groups/ContainingGroupsMultiSet.tsx b/ui/v2.5/src/components/Groups/ContainingGroupsMultiSet.tsx new file mode 100644 index 00000000000..25ad0be4a3e --- /dev/null +++ b/ui/v2.5/src/components/Groups/ContainingGroupsMultiSet.tsx @@ -0,0 +1,61 @@ +import React from "react"; +import * as GQL from "src/core/generated-graphql"; +import { MultiSetModeButtons } from "../Shared/MultiSet"; +import { + IRelatedGroupEntry, + RelatedGroupTable, +} from "./GroupDetails/RelatedGroupTable"; +import { Group, GroupSelect } from "./GroupSelect"; + +export const ContainingGroupsMultiSet: React.FC<{ + existingValue?: IRelatedGroupEntry[]; + value: IRelatedGroupEntry[]; + mode: GQL.BulkUpdateIdMode; + disabled?: boolean; + onUpdate: (value: IRelatedGroupEntry[]) => void; + onSetMode: (mode: GQL.BulkUpdateIdMode) => void; +}> = (props) => { + const { mode, onUpdate, existingValue } = props; + + function onSetMode(m: GQL.BulkUpdateIdMode) { + if (m === mode) { + return; + } + + // if going to Set, set the existing ids + if (m === GQL.BulkUpdateIdMode.Set && existingValue) { + onUpdate(existingValue); + // if going from Set, wipe the ids + } else if ( + m !== GQL.BulkUpdateIdMode.Set && + mode === GQL.BulkUpdateIdMode.Set + ) { + onUpdate([]); + } + + props.onSetMode(m); + } + + function onRemoveSet(items: Group[]) { + onUpdate(items.map((group) => ({ group }))); + } + + return ( +