diff --git a/graphql/schema/schema.graphql b/graphql/schema/schema.graphql index 5ec16b17b07..c5b8d60892d 100644 --- a/graphql/schema/schema.graphql +++ b/graphql/schema/schema.graphql @@ -325,6 +325,7 @@ type Mutation { tagDestroy(input: TagDestroyInput!): Boolean! tagsDestroy(ids: [ID!]!): Boolean! tagsMerge(input: TagsMergeInput!): Tag + bulkTagUpdate(input: BulkTagUpdateInput!): [Tag!] """ Moves the given files to the given destination. Returns true if successful. diff --git a/graphql/schema/types/tag.graphql b/graphql/schema/types/tag.graphql index 69b8221c5f1..6438b52e1fa 100644 --- a/graphql/schema/types/tag.graphql +++ b/graphql/schema/types/tag.graphql @@ -60,3 +60,14 @@ input TagsMergeInput { source: [ID!]! destination: ID! } + +input BulkTagUpdateInput { + ids: [ID!] + description: String + aliases: BulkUpdateStrings + ignore_auto_tag: Boolean + favorite: Boolean + + parent_ids: BulkUpdateIds + child_ids: BulkUpdateIds +} diff --git a/internal/api/resolver_model_tag.go b/internal/api/resolver_model_tag.go index 9124b18f483..d219fcc66d7 100644 --- a/internal/api/resolver_model_tag.go +++ b/internal/api/resolver_model_tag.go @@ -3,6 +3,7 @@ package api import ( "context" + "github.com/stashapp/stash/internal/api/loaders" "github.com/stashapp/stash/internal/api/urlbuilders" "github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/image" @@ -12,36 +13,43 @@ import ( ) func (r *tagResolver) Parents(ctx context.Context, obj *models.Tag) (ret []*models.Tag, err error) { - if err := r.withReadTxn(ctx, func(ctx context.Context) error { - ret, err = r.repository.Tag.FindByChildTagID(ctx, obj.ID) - return err - }); err != nil { - return nil, err + if !obj.ParentIDs.Loaded() { + if err := r.withReadTxn(ctx, func(ctx context.Context) error { + return obj.LoadParentIDs(ctx, r.repository.Tag) + }); err != nil { + return nil, err + } } - return ret, nil + var errs []error + ret, errs = loaders.From(ctx).TagByID.LoadAll(obj.ParentIDs.List()) + return ret, firstError(errs) } func (r *tagResolver) Children(ctx context.Context, obj *models.Tag) (ret []*models.Tag, err error) { - if err := r.withReadTxn(ctx, func(ctx context.Context) error { - ret, err = r.repository.Tag.FindByParentTagID(ctx, obj.ID) - return err - }); err != nil { - return nil, err + if !obj.ChildIDs.Loaded() { + if err := r.withReadTxn(ctx, func(ctx context.Context) error { + return obj.LoadChildIDs(ctx, r.repository.Tag) + }); err != nil { + return nil, err + } } - return ret, nil + var errs []error + ret, errs = loaders.From(ctx).TagByID.LoadAll(obj.ChildIDs.List()) + return ret, firstError(errs) } func (r *tagResolver) Aliases(ctx context.Context, obj *models.Tag) (ret []string, err error) { - if err := r.withReadTxn(ctx, func(ctx context.Context) error { - ret, err = r.repository.Tag.GetAliases(ctx, obj.ID) - return err - }); err != nil { - return nil, err + if !obj.Aliases.Loaded() { + if err := r.withReadTxn(ctx, func(ctx context.Context) error { + return obj.LoadAliases(ctx, r.repository.Tag) + }); err != nil { + return nil, err + } } - return ret, err + return obj.Aliases.List(), nil } func (r *tagResolver) SceneCount(ctx context.Context, obj *models.Tag, depth *int) (ret int, err error) { diff --git a/internal/api/resolver_mutation_tag.go b/internal/api/resolver_mutation_tag.go index 2c3128c58d4..2554f1bb55f 100644 --- a/internal/api/resolver_mutation_tag.go +++ b/internal/api/resolver_mutation_tag.go @@ -33,26 +33,21 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput) newTag := models.NewTag() newTag.Name = input.Name + newTag.Aliases = models.NewRelatedStrings(input.Aliases) newTag.Favorite = translator.bool(input.Favorite) newTag.Description = translator.string(input.Description) newTag.IgnoreAutoTag = translator.bool(input.IgnoreAutoTag) var err error - var parentIDs []int - if len(input.ParentIds) > 0 { - parentIDs, err = stringslice.StringSliceToIntSlice(input.ParentIds) - if err != nil { - return nil, fmt.Errorf("converting parent ids: %w", err) - } + newTag.ParentIDs, err = translator.relatedIds(input.ParentIds) + if err != nil { + return nil, fmt.Errorf("converting parent tag ids: %w", err) } - var childIDs []int - if len(input.ChildIds) > 0 { - childIDs, err = stringslice.StringSliceToIntSlice(input.ChildIds) - if err != nil { - return nil, fmt.Errorf("converting child ids: %w", err) - } + newTag.ChildIDs, err = translator.relatedIds(input.ChildIds) + if err != nil { + return nil, fmt.Errorf("converting child tag ids: %w", err) } // Process the base 64 encoded image string @@ -68,8 +63,7 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput) if err := r.withTxn(ctx, func(ctx context.Context) error { qb := r.repository.Tag - // ensure name is unique - if err := tag.EnsureTagNameUnique(ctx, 0, newTag.Name, qb); err != nil { + if err := tag.ValidateCreate(ctx, newTag, qb); err != nil { return err } @@ -85,36 +79,6 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput) } } - if len(input.Aliases) > 0 { - if err := tag.EnsureAliasesUnique(ctx, newTag.ID, input.Aliases, qb); err != nil { - return err - } - - if err := qb.UpdateAliases(ctx, newTag.ID, input.Aliases); err != nil { - return err - } - } - - if len(parentIDs) > 0 { - if err := qb.UpdateParentTags(ctx, newTag.ID, parentIDs); err != nil { - return err - } - } - - if len(childIDs) > 0 { - if err := qb.UpdateChildTags(ctx, newTag.ID, childIDs); err != nil { - return err - } - } - - // FIXME: This should be called before any changes are made, but - // requires a rewrite of ValidateHierarchy. - if len(parentIDs) > 0 || len(childIDs) > 0 { - if err := tag.ValidateHierarchy(ctx, &newTag, parentIDs, childIDs, qb); err != nil { - return err - } - } - return nil }); err != nil { return nil, err @@ -137,24 +101,21 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput) // Populate tag from the input updatedTag := models.NewTagPartial() + updatedTag.Name = translator.optionalString(input.Name, "name") updatedTag.Favorite = translator.optionalBool(input.Favorite, "favorite") updatedTag.IgnoreAutoTag = translator.optionalBool(input.IgnoreAutoTag, "ignore_auto_tag") updatedTag.Description = translator.optionalString(input.Description, "description") - var parentIDs []int - if translator.hasField("parent_ids") { - parentIDs, err = stringslice.StringSliceToIntSlice(input.ParentIds) - if err != nil { - return nil, fmt.Errorf("converting parent ids: %w", err) - } + updatedTag.Aliases = translator.updateStrings(input.Aliases, "aliases") + + updatedTag.ParentIDs, err = translator.updateIds(input.ParentIds, "parent_ids") + if err != nil { + return nil, fmt.Errorf("converting parent tag ids: %w", err) } - var childIDs []int - if translator.hasField("child_ids") { - childIDs, err = stringslice.StringSliceToIntSlice(input.ChildIds) - if err != nil { - return nil, fmt.Errorf("converting child ids: %w", err) - } + updatedTag.ChildIDs, err = translator.updateIds(input.ChildIds, "child_ids") + if err != nil { + return nil, fmt.Errorf("converting child tag ids: %w", err) } var imageData []byte @@ -171,24 +132,10 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput) if err := r.withTxn(ctx, func(ctx context.Context) error { qb := r.repository.Tag - // ensure name is unique - t, err = qb.Find(ctx, tagID) - if err != nil { + if err := tag.ValidateUpdate(ctx, tagID, updatedTag, qb); err != nil { return err } - if t == nil { - return fmt.Errorf("tag with id %d not found", tagID) - } - - if input.Name != nil && t.Name != *input.Name { - if err := tag.EnsureTagNameUnique(ctx, tagID, *input.Name, qb); err != nil { - return err - } - - updatedTag.Name = models.NewOptionalString(*input.Name) - } - t, err = qb.UpdatePartial(ctx, tagID, updatedTag) if err != nil { return err @@ -201,35 +148,61 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput) } } - if translator.hasField("aliases") { - if err := tag.EnsureAliasesUnique(ctx, tagID, input.Aliases, qb); err != nil { - return err - } + return nil + }); err != nil { + return nil, err + } - if err := qb.UpdateAliases(ctx, tagID, input.Aliases); err != nil { - return err - } - } + r.hookExecutor.ExecutePostHooks(ctx, t.ID, hook.TagUpdatePost, input, translator.getFields()) + return r.getTag(ctx, t.ID) +} - if parentIDs != nil { - if err := qb.UpdateParentTags(ctx, tagID, parentIDs); err != nil { - return err - } - } +func (r *mutationResolver) BulkTagUpdate(ctx context.Context, input BulkTagUpdateInput) ([]*models.Tag, error) { + tagIDs, err := stringslice.StringSliceToIntSlice(input.Ids) + if err != nil { + return nil, fmt.Errorf("converting ids: %w", err) + } + + translator := changesetTranslator{ + inputMap: getUpdateInputMap(ctx), + } + + // Populate scene from the input + updatedTag := models.NewTagPartial() + + updatedTag.Description = translator.optionalString(input.Description, "description") + updatedTag.Favorite = translator.optionalBool(input.Favorite, "favorite") + updatedTag.IgnoreAutoTag = translator.optionalBool(input.IgnoreAutoTag, "ignore_auto_tag") + + updatedTag.Aliases = translator.updateStringsBulk(input.Aliases, "aliases") + + updatedTag.ParentIDs, err = translator.updateIdsBulk(input.ParentIds, "parent_ids") + if err != nil { + return nil, fmt.Errorf("converting parent tag ids: %w", err) + } + + updatedTag.ChildIDs, err = translator.updateIdsBulk(input.ChildIds, "child_ids") + if err != nil { + return nil, fmt.Errorf("converting child tag ids: %w", err) + } + + ret := []*models.Tag{} + + // Start the transaction and save the scenes + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Tag - if childIDs != nil { - if err := qb.UpdateChildTags(ctx, tagID, childIDs); err != nil { + for _, tagID := range tagIDs { + if err := tag.ValidateUpdate(ctx, tagID, updatedTag, qb); err != nil { return err } - } - // FIXME: This should be called before any changes are made, but - // requires a rewrite of ValidateHierarchy. - if parentIDs != nil || childIDs != nil { - if err := tag.ValidateHierarchy(ctx, t, parentIDs, childIDs, qb); err != nil { - logger.Errorf("Error saving tag: %s", err) + tag, err := qb.UpdatePartial(ctx, tagID, updatedTag) + if err != nil { return err } + + ret = append(ret, tag) } return nil @@ -237,8 +210,20 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput) return nil, err } - r.hookExecutor.ExecutePostHooks(ctx, t.ID, hook.TagUpdatePost, input, translator.getFields()) - return r.getTag(ctx, t.ID) + // execute post hooks outside of txn + var newRet []*models.Tag + for _, tag := range ret { + r.hookExecutor.ExecutePostHooks(ctx, tag.ID, hook.TagUpdatePost, input, translator.getFields()) + + tag, err = r.getTag(ctx, tag.ID) + if err != nil { + return nil, err + } + + newRet = append(newRet, tag) + } + + return newRet, nil } func (r *mutationResolver) TagDestroy(ctx context.Context, input TagDestroyInput) (bool, error) { @@ -331,7 +316,7 @@ func (r *mutationResolver) TagsMerge(ctx context.Context, input TagsMergeInput) return err } - err = tag.ValidateHierarchy(ctx, t, parents, children, qb) + err = tag.ValidateHierarchyExisting(ctx, t, parents, children, qb) if err != nil { logger.Errorf("Error merging tag: %s", err) return err diff --git a/pkg/models/mocks/TagReaderWriter.go b/pkg/models/mocks/TagReaderWriter.go index 9b610e49b6e..f4c494016f3 100644 --- a/pkg/models/mocks/TagReaderWriter.go +++ b/pkg/models/mocks/TagReaderWriter.go @@ -450,6 +450,29 @@ func (_m *TagReaderWriter) GetAliases(ctx context.Context, relatedID int) ([]str return r0, r1 } +// GetChildIDs provides a mock function with given fields: ctx, relatedID +func (_m *TagReaderWriter) GetChildIDs(ctx context.Context, relatedID int) ([]int, error) { + ret := _m.Called(ctx, relatedID) + + var r0 []int + if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok { + r0 = rf(ctx, relatedID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, relatedID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetImage provides a mock function with given fields: ctx, tagID func (_m *TagReaderWriter) GetImage(ctx context.Context, tagID int) ([]byte, error) { ret := _m.Called(ctx, tagID) @@ -473,6 +496,29 @@ func (_m *TagReaderWriter) GetImage(ctx context.Context, tagID int) ([]byte, err return r0, r1 } +// GetParentIDs provides a mock function with given fields: ctx, relatedID +func (_m *TagReaderWriter) GetParentIDs(ctx context.Context, relatedID int) ([]int, error) { + ret := _m.Called(ctx, relatedID) + + var r0 []int + if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok { + r0 = rf(ctx, relatedID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, relatedID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // HasImage provides a mock function with given fields: ctx, tagID func (_m *TagReaderWriter) HasImage(ctx context.Context, tagID int) (bool, error) { ret := _m.Called(ctx, tagID) diff --git a/pkg/models/model_tag.go b/pkg/models/model_tag.go index 04f5ac1a2ec..e8a797e8760 100644 --- a/pkg/models/model_tag.go +++ b/pkg/models/model_tag.go @@ -1,6 +1,7 @@ package models import ( + "context" "time" ) @@ -12,6 +13,10 @@ type Tag struct { IgnoreAutoTag bool `json:"ignore_auto_tag"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` + + Aliases RelatedStrings `json:"aliases"` + ParentIDs RelatedIDs `json:"parent_ids"` + ChildIDs RelatedIDs `json:"tag_ids"` } func NewTag() Tag { @@ -22,6 +27,24 @@ func NewTag() Tag { } } +func (s *Tag) LoadAliases(ctx context.Context, l AliasLoader) error { + return s.Aliases.load(func() ([]string, error) { + return l.GetAliases(ctx, s.ID) + }) +} + +func (s *Tag) LoadParentIDs(ctx context.Context, l TagRelationLoader) error { + return s.ParentIDs.load(func() ([]int, error) { + return l.GetParentIDs(ctx, s.ID) + }) +} + +func (s *Tag) LoadChildIDs(ctx context.Context, l TagRelationLoader) error { + return s.ChildIDs.load(func() ([]int, error) { + return l.GetChildIDs(ctx, s.ID) + }) +} + type TagPartial struct { Name OptionalString Description OptionalString @@ -29,6 +52,10 @@ type TagPartial struct { IgnoreAutoTag OptionalBool CreatedAt OptionalTime UpdatedAt OptionalTime + + Aliases *UpdateStrings + ParentIDs *UpdateIDs + ChildIDs *UpdateIDs } func NewTagPartial() TagPartial { diff --git a/pkg/models/relationships.go b/pkg/models/relationships.go index 29772890f04..021fab4dbfb 100644 --- a/pkg/models/relationships.go +++ b/pkg/models/relationships.go @@ -24,6 +24,11 @@ type TagIDLoader interface { GetTagIDs(ctx context.Context, relatedID int) ([]int, error) } +type TagRelationLoader interface { + GetParentIDs(ctx context.Context, relatedID int) ([]int, error) + GetChildIDs(ctx context.Context, relatedID int) ([]int, error) +} + type FileIDLoader interface { GetManyFileIDs(ctx context.Context, ids []int) ([][]FileID, error) } diff --git a/pkg/models/repository_tag.go b/pkg/models/repository_tag.go index ca8f6971bf7..6d38785e6d0 100644 --- a/pkg/models/repository_tag.go +++ b/pkg/models/repository_tag.go @@ -84,6 +84,7 @@ type TagReader interface { TagCounter AliasLoader + TagRelationLoader All(ctx context.Context) ([]*Tag, error) GetImage(ctx context.Context, tagID int) ([]byte, error) diff --git a/pkg/sqlite/tables.go b/pkg/sqlite/tables.go index 64d1e4eb236..701c503305d 100644 --- a/pkg/sqlite/tables.go +++ b/pkg/sqlite/tables.go @@ -36,6 +36,9 @@ var ( studiosStashIDsJoinTable = goqu.T("studio_stash_ids") moviesURLsJoinTable = goqu.T(movieURLsTable) + + tagsAliasesJoinTable = goqu.T(tagAliasesTable) + tagRelationsJoinTable = goqu.T(tagRelationsTable) ) var ( @@ -294,6 +297,24 @@ var ( table: goqu.T(tagTable), idColumn: goqu.T(tagTable).Col(idColumn), } + + tagsAliasesTableMgr = &stringTable{ + table: table{ + table: tagsAliasesJoinTable, + idColumn: tagsAliasesJoinTable.Col(tagIDColumn), + }, + stringColumn: tagsAliasesJoinTable.Col(tagAliasColumn), + } + + tagsParentTagsTableMgr = &joinTable{ + table: table{ + table: tagRelationsJoinTable, + idColumn: tagRelationsJoinTable.Col(tagChildIDColumn), + }, + fkColumn: tagRelationsJoinTable.Col(tagParentIDColumn), + } + + tagsChildTagsTableMgr = *tagsParentTagsTableMgr.invert() ) var ( diff --git a/pkg/sqlite/tag.go b/pkg/sqlite/tag.go index 99cc42edcdf..127ad3310e1 100644 --- a/pkg/sqlite/tag.go +++ b/pkg/sqlite/tag.go @@ -24,6 +24,10 @@ const ( tagAliasColumn = "alias" tagImageBlobColumn = "image_blob" + + tagRelationsTable = "tags_relations" + tagParentIDColumn = "parent_id" + tagChildIDColumn = "child_id" ) type tagRow struct { @@ -173,6 +177,24 @@ func (qb *TagStore) Create(ctx context.Context, newObject *models.Tag) error { return err } + if newObject.Aliases.Loaded() { + if err := tagsAliasesTableMgr.insertJoins(ctx, id, newObject.Aliases.List()); err != nil { + return err + } + } + + if newObject.ParentIDs.Loaded() { + if err := tagsParentTagsTableMgr.insertJoins(ctx, id, newObject.ParentIDs.List()); err != nil { + return err + } + } + + if newObject.ChildIDs.Loaded() { + if err := tagsChildTagsTableMgr.insertJoins(ctx, id, newObject.ChildIDs.List()); err != nil { + return err + } + } + updated, err := qb.find(ctx, id) if err != nil { return fmt.Errorf("finding after create: %w", err) @@ -198,6 +220,24 @@ func (qb *TagStore) UpdatePartial(ctx context.Context, id int, partial models.Ta } } + if partial.Aliases != nil { + if err := tagsAliasesTableMgr.modifyJoins(ctx, id, partial.Aliases.Values, partial.Aliases.Mode); err != nil { + return nil, err + } + } + + if partial.ParentIDs != nil { + if err := tagsParentTagsTableMgr.modifyJoins(ctx, id, partial.ParentIDs.IDs, partial.ParentIDs.Mode); err != nil { + return nil, err + } + } + + if partial.ChildIDs != nil { + if err := tagsChildTagsTableMgr.modifyJoins(ctx, id, partial.ChildIDs.IDs, partial.ChildIDs.Mode); err != nil { + return nil, err + } + } + return qb.find(ctx, id) } @@ -209,6 +249,24 @@ func (qb *TagStore) Update(ctx context.Context, updatedObject *models.Tag) error return err } + if updatedObject.Aliases.Loaded() { + if err := tagsAliasesTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.Aliases.List()); err != nil { + return err + } + } + + if updatedObject.ParentIDs.Loaded() { + if err := tagsParentTagsTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.ParentIDs.List()); err != nil { + return err + } + } + + if updatedObject.ChildIDs.Loaded() { + if err := tagsChildTagsTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.ChildIDs.List()); err != nil { + return err + } + } + return nil } @@ -423,6 +481,14 @@ func (qb *TagStore) FindByNames(ctx context.Context, names []string, nocase bool return ret, nil } +func (qb *TagStore) GetParentIDs(ctx context.Context, relatedID int) ([]int, error) { + return tagsParentTagsTableMgr.get(ctx, relatedID) +} + +func (qb *TagStore) GetChildIDs(ctx context.Context, relatedID int) ([]int, error) { + return tagsChildTagsTableMgr.get(ctx, relatedID) +} + func (qb *TagStore) FindByParentTagID(ctx context.Context, parentID int) ([]*models.Tag, error) { query := ` SELECT tags.* FROM tags diff --git a/pkg/tag/update.go b/pkg/tag/update.go index dcb78bf9cab..99e9b916569 100644 --- a/pkg/tag/update.go +++ b/pkg/tag/update.go @@ -33,6 +33,10 @@ type InvalidTagHierarchyError struct { } func (e *InvalidTagHierarchyError) Error() string { + if e.ApplyingTag == "" { + return fmt.Sprintf("cannot apply tag \"%s\" as a %s of tag as it is already %s", e.InvalidTag, e.Direction, e.CurrentRelation) + } + return fmt.Sprintf("cannot apply tag \"%s\" as a %s of \"%s\" as it is already %s (%s)", e.InvalidTag, e.Direction, e.ApplyingTag, e.CurrentRelation, e.TagPath) } @@ -80,16 +84,83 @@ func EnsureAliasesUnique(ctx context.Context, id int, aliases []string, qb model type RelationshipFinder interface { FindAllAncestors(ctx context.Context, tagID int, excludeIDs []int) ([]*models.TagPath, error) FindAllDescendants(ctx context.Context, tagID int, excludeIDs []int) ([]*models.TagPath, error) - FindByChildTagID(ctx context.Context, childID int) ([]*models.Tag, error) - FindByParentTagID(ctx context.Context, parentID int) ([]*models.Tag, error) + models.TagRelationLoader +} + +func ValidateHierarchyNew(ctx context.Context, parentIDs, childIDs []int, qb RelationshipFinder) error { + allAncestors := make(map[int]*models.TagPath) + allDescendants := make(map[int]*models.TagPath) + + for _, parentID := range parentIDs { + parentsAncestors, err := qb.FindAllAncestors(ctx, parentID, nil) + if err != nil { + return err + } + + for _, ancestorTag := range parentsAncestors { + allAncestors[ancestorTag.ID] = ancestorTag + } + } + + for _, childID := range childIDs { + childsDescendants, err := qb.FindAllDescendants(ctx, childID, nil) + if err != nil { + return err + } + + for _, descendentTag := range childsDescendants { + allDescendants[descendentTag.ID] = descendentTag + } + } + + // Validate that the tag is not a parent of any of its ancestors + validateParent := func(testID int) error { + if parentTag, exists := allDescendants[testID]; exists { + return &InvalidTagHierarchyError{ + Direction: "parent", + CurrentRelation: "a descendant", + InvalidTag: parentTag.Name, + TagPath: parentTag.Path, + } + } + + return nil + } + + // Validate that the tag is not a child of any of its ancestors + validateChild := func(testID int) error { + if childTag, exists := allAncestors[testID]; exists { + return &InvalidTagHierarchyError{ + Direction: "child", + CurrentRelation: "an ancestor", + InvalidTag: childTag.Name, + TagPath: childTag.Path, + } + } + + return nil + } + + for _, parentID := range parentIDs { + if err := validateParent(parentID); err != nil { + return err + } + } + + for _, childID := range childIDs { + if err := validateChild(childID); err != nil { + return err + } + } + + return nil } -func ValidateHierarchy(ctx context.Context, tag *models.Tag, parentIDs, childIDs []int, qb RelationshipFinder) error { - id := tag.ID +func ValidateHierarchyExisting(ctx context.Context, tag *models.Tag, parentIDs, childIDs []int, qb RelationshipFinder) error { allAncestors := make(map[int]*models.TagPath) allDescendants := make(map[int]*models.TagPath) - parentsAncestors, err := qb.FindAllAncestors(ctx, id, nil) + parentsAncestors, err := qb.FindAllAncestors(ctx, tag.ID, nil) if err != nil { return err } @@ -98,7 +169,7 @@ func ValidateHierarchy(ctx context.Context, tag *models.Tag, parentIDs, childIDs allAncestors[ancestorTag.ID] = ancestorTag } - childsDescendants, err := qb.FindAllDescendants(ctx, id, nil) + childsDescendants, err := qb.FindAllDescendants(ctx, tag.ID, nil) if err != nil { return err } @@ -135,28 +206,6 @@ func ValidateHierarchy(ctx context.Context, tag *models.Tag, parentIDs, childIDs return nil } - if parentIDs == nil { - parentTags, err := qb.FindByChildTagID(ctx, id) - if err != nil { - return err - } - - for _, parentTag := range parentTags { - parentIDs = append(parentIDs, parentTag.ID) - } - } - - if childIDs == nil { - childTags, err := qb.FindByParentTagID(ctx, id) - if err != nil { - return err - } - - for _, childTag := range childTags { - childIDs = append(childIDs, childTag.ID) - } - } - for _, parentID := range parentIDs { if err := validateParent(parentID); err != nil { return err @@ -176,38 +225,38 @@ func MergeHierarchy(ctx context.Context, destination int, sources []int, qb Rela var mergedParents, mergedChildren []int allIds := append([]int{destination}, sources...) - addTo := func(mergedItems []int, tags []*models.Tag) []int { + addTo := func(mergedItems []int, tagIDs []int) []int { Tags: - for _, tag := range tags { + for _, tagID := range tagIDs { // Ignore tags which are already set for _, existingItem := range mergedItems { - if tag.ID == existingItem { + if tagID == existingItem { continue Tags } } // Ignore tags which are being merged, as these are rolled up anyway (if A is merged into B any direct link between them can be ignored) for _, id := range allIds { - if tag.ID == id { + if tagID == id { continue Tags } } - mergedItems = append(mergedItems, tag.ID) + mergedItems = append(mergedItems, tagID) } return mergedItems } for _, id := range allIds { - parents, err := qb.FindByChildTagID(ctx, id) + parents, err := qb.GetParentIDs(ctx, id) if err != nil { return nil, nil, err } mergedParents = addTo(mergedParents, parents) - children, err := qb.FindByParentTagID(ctx, id) + children, err := qb.GetChildIDs(ctx, id) if err != nil { return nil, nil, err } diff --git a/pkg/tag/update_test.go b/pkg/tag/update_test.go index c581d34ac43..462c981434f 100644 --- a/pkg/tag/update_test.go +++ b/pkg/tag/update_test.go @@ -211,14 +211,11 @@ var testUniqueHierarchyCases = []testUniqueHierarchyCase{ func TestEnsureHierarchy(t *testing.T) { for _, tc := range testUniqueHierarchyCases { - testEnsureHierarchy(t, tc, false, false) - testEnsureHierarchy(t, tc, true, false) - testEnsureHierarchy(t, tc, false, true) - testEnsureHierarchy(t, tc, true, true) + testEnsureHierarchy(t, tc) } } -func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, queryChildren bool) { +func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase) { db := mocks.NewDatabase() var parentIDs, childIDs []int @@ -244,16 +241,6 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, } } - if queryParents { - parentIDs = nil - db.Tag.On("FindByChildTagID", testCtx, tc.id).Return(tc.parents, nil).Once() - } - - if queryChildren { - childIDs = nil - db.Tag.On("FindByParentTagID", testCtx, tc.id).Return(tc.children, nil).Once() - } - db.Tag.On("FindAllAncestors", testCtx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath { return tc.onFindAllAncestors }, func(ctx context.Context, tagID int, excludeIDs []int) error { @@ -272,7 +259,7 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, return fmt.Errorf("undefined descendants for: %d", tagID) }).Maybe() - res := ValidateHierarchy(testCtx, testUniqueHierarchyTags[tc.id], parentIDs, childIDs, db.Tag) + res := ValidateHierarchyExisting(testCtx, testUniqueHierarchyTags[tc.id], parentIDs, childIDs, db.Tag) assert := assert.New(t) diff --git a/pkg/tag/validate.go b/pkg/tag/validate.go new file mode 100644 index 00000000000..966cec9451b --- /dev/null +++ b/pkg/tag/validate.go @@ -0,0 +1,102 @@ +package tag + +import ( + "context" + "errors" + "fmt" + + "github.com/stashapp/stash/pkg/models" +) + +var ( + ErrNameMissing = errors.New("tag name must not be blank") +) + +type NotFoundError struct { + id int +} + +func (e *NotFoundError) Error() string { + return fmt.Sprintf("tag with id %d not found", e.id) +} + +func ValidateCreate(ctx context.Context, tag models.Tag, qb models.TagReader) error { + if tag.Name == "" { + return ErrNameMissing + } + + if err := EnsureTagNameUnique(ctx, 0, tag.Name, qb); err != nil { + return err + } + + if tag.Aliases.Loaded() { + if err := EnsureAliasesUnique(ctx, tag.ID, tag.Aliases.List(), qb); err != nil { + return err + } + } + + if len(tag.ParentIDs.List()) > 0 || len(tag.ChildIDs.List()) > 0 { + if err := ValidateHierarchyNew(ctx, tag.ParentIDs.List(), tag.ChildIDs.List(), qb); err != nil { + return err + } + } + + return nil +} + +func ValidateUpdate(ctx context.Context, id int, partial models.TagPartial, qb models.TagReader) error { + existing, err := qb.Find(ctx, id) + if err != nil { + return err + } + + if existing == nil { + return &NotFoundError{id} + } + + if partial.Name.Set { + if partial.Name.Value == "" { + return ErrNameMissing + } + + if err := EnsureTagNameUnique(ctx, id, partial.Name.Value, qb); err != nil { + return err + } + } + + if partial.Aliases != nil { + if err := existing.LoadAliases(ctx, qb); err != nil { + return err + } + + if err := EnsureAliasesUnique(ctx, id, partial.Aliases.Apply(existing.Aliases.List()), qb); err != nil { + return err + } + } + + if partial.ParentIDs != nil || partial.ChildIDs != nil { + if err := existing.LoadParentIDs(ctx, qb); err != nil { + return err + } + + if err := existing.LoadChildIDs(ctx, qb); err != nil { + return err + } + + parentIDs := partial.ParentIDs + if parentIDs == nil { + parentIDs = &models.UpdateIDs{IDs: existing.ParentIDs.List(), Mode: models.RelationshipUpdateModeSet} + } + + childIDs := partial.ChildIDs + if childIDs == nil { + childIDs = &models.UpdateIDs{IDs: existing.ChildIDs.List(), Mode: models.RelationshipUpdateModeSet} + } + + if err := ValidateHierarchyExisting(ctx, existing, parentIDs.Apply(existing.ParentIDs.List()), childIDs.Apply(existing.ChildIDs.List()), qb); err != nil { + return err + } + } + + return nil +} diff --git a/ui/v2.5/graphql/mutations/tag.graphql b/ui/v2.5/graphql/mutations/tag.graphql index 20e3b4b81a5..f2138e05702 100644 --- a/ui/v2.5/graphql/mutations/tag.graphql +++ b/ui/v2.5/graphql/mutations/tag.graphql @@ -18,6 +18,12 @@ mutation TagUpdate($input: TagUpdateInput!) { } } +mutation BulkTagUpdate($input: BulkTagUpdateInput!) { + bulkTagUpdate(input: $input) { + ...TagData + } +} + mutation TagsMerge($source: [ID!]!, $destination: ID!) { tagsMerge(input: { source: $source, destination: $destination }) { ...TagData diff --git a/ui/v2.5/src/components/Tags/EditTagsDialog.tsx b/ui/v2.5/src/components/Tags/EditTagsDialog.tsx new file mode 100644 index 00000000000..d771ea1c94c --- /dev/null +++ b/ui/v2.5/src/components/Tags/EditTagsDialog.tsx @@ -0,0 +1,237 @@ +import React, { useEffect, useState } from "react"; +import { Form } from "react-bootstrap"; +import { FormattedMessage, useIntl } from "react-intl"; +import { useBulkTagUpdate } from "src/core/StashService"; +import * as GQL from "src/core/generated-graphql"; +import { ModalComponent } from "../Shared/Modal"; +import { useToast } from "src/hooks/Toast"; +import { MultiSet } from "../Shared/MultiSet"; +import { + getAggregateState, + getAggregateStateObject, +} from "src/utils/bulkUpdate"; +import { IndeterminateCheckbox } from "../Shared/IndeterminateCheckbox"; +import { BulkUpdateTextInput } from "../Shared/BulkUpdateTextInput"; +import { faPencilAlt } from "@fortawesome/free-solid-svg-icons"; + +function Tags(props: { + isUpdating: boolean; + controlId: string; + messageId: string; + existingTagIds: string[] | undefined; + tagIDs: GQL.BulkUpdateIds; + setTagIDs: (value: React.SetStateAction) => void; +}) { + const { + isUpdating, + controlId, + messageId, + existingTagIds, + tagIDs, + setTagIDs, + } = props; + + return ( + + + + + + setTagIDs((existing) => ({ ...existing, ids: itemIDs })) + } + onSetMode={(newMode) => + setTagIDs((existing) => ({ ...existing, mode: newMode })) + } + existingIds={existingTagIds ?? []} + ids={tagIDs.ids ?? []} + mode={tagIDs.mode} + /> + + ); +} + +interface IListOperationProps { + selected: GQL.TagDataFragment[]; + onClose: (applied: boolean) => void; +} + +const tagFields = ["favorite", "description", "ignore_auto_tag"]; + +export const EditTagsDialog: React.FC = ( + props: IListOperationProps +) => { + const intl = useIntl(); + const Toast = useToast(); + + const [parentTagIDs, setParentTagIDs_] = useState({ + mode: GQL.BulkUpdateIdMode.Add, + }); + + function setParentTagIDs(value: React.SetStateAction) { + console.log(value); + setParentTagIDs_(value); + } + + const [existingParentTagIds, setExistingParentTagIds] = useState(); + + const [childTagIDs, setChildTagIDs] = useState({ + mode: GQL.BulkUpdateIdMode.Add, + }); + const [existingChildTagIds, setExistingChildTagIds] = useState(); + + const [updateInput, setUpdateInput] = useState({}); + + const [updateTags] = useBulkTagUpdate(getTagInput()); + + // Network state + const [isUpdating, setIsUpdating] = useState(false); + + function setUpdateField(input: Partial) { + setUpdateInput({ ...updateInput, ...input }); + } + + function getTagInput(): GQL.BulkTagUpdateInput { + const tagInput: GQL.BulkTagUpdateInput = { + ids: props.selected.map((tag) => { + return tag.id; + }), + ...updateInput, + parent_ids: parentTagIDs, + child_ids: childTagIDs, + }; + + return tagInput; + } + + async function onSave() { + setIsUpdating(true); + try { + await updateTags(); + Toast.success( + intl.formatMessage( + { id: "toast.updated_entity" }, + { + entity: intl.formatMessage({ id: "tags" }).toLocaleLowerCase(), + } + ) + ); + props.onClose(true); + } catch (e) { + Toast.error(e); + } + setIsUpdating(false); + } + + useEffect(() => { + const updateState: GQL.BulkTagUpdateInput = {}; + + const state = props.selected; + let updateParentTagIds: string[] = []; + let updateChildTagIds: string[] = []; + let first = true; + + state.forEach((tag: GQL.TagDataFragment) => { + getAggregateStateObject(updateState, tag, tagFields, first); + + const thisParents = (tag.parents ?? []).map((t) => t.id).sort(); + updateParentTagIds = + getAggregateState(updateParentTagIds, thisParents, first) ?? []; + + const thisChildren = (tag.children ?? []).map((t) => t.id).sort(); + updateChildTagIds = + getAggregateState(updateChildTagIds, thisChildren, first) ?? []; + + first = false; + }); + + setExistingParentTagIds(updateParentTagIds); + setExistingChildTagIds(updateChildTagIds); + setUpdateInput(updateState); + }, [props.selected]); + + function renderTextField( + name: string, + value: string | undefined | null, + setter: (newValue: string | undefined) => void + ) { + return ( + + + + + setter(newValue)} + unsetDisabled={props.selected.length < 2} + /> + + ); + } + + return ( + props.onClose(false), + text: intl.formatMessage({ id: "actions.cancel" }), + variant: "secondary", + }} + isRunning={isUpdating} + > +
+ + setUpdateField({ favorite: checked })} + checked={updateInput.favorite ?? undefined} + label={intl.formatMessage({ id: "favourite" })} + /> + + + {renderTextField("description", updateInput.description, (v) => + setUpdateField({ description: v }) + )} + + + + + + + + setUpdateField({ ignore_auto_tag: checked }) + } + checked={updateInput.ignore_auto_tag ?? undefined} + /> + + +
+ ); +}; diff --git a/ui/v2.5/src/components/Tags/TagList.tsx b/ui/v2.5/src/components/Tags/TagList.tsx index ea580c2c948..2458a273ba8 100644 --- a/ui/v2.5/src/components/Tags/TagList.tsx +++ b/ui/v2.5/src/components/Tags/TagList.tsx @@ -28,6 +28,7 @@ import { ExportDialog } from "../Shared/ExportDialog"; import { tagRelationHook } from "../../core/tags"; import { faTrashAlt } from "@fortawesome/free-solid-svg-icons"; import { TagCardGrid } from "./TagCardGrid"; +import { EditTagsDialog } from "./EditTagsDialog"; interface ITagList { filterHook?: (filter: ListFilterModel) => ListFilterModel; @@ -325,6 +326,13 @@ export const TagList: React.FC = ({ filterHook, alterQuery }) => { ); } + function renderEditDialog( + selectedTags: GQL.TagDataFragment[], + onClose: (confirmed: boolean) => void + ) { + return ; + } + function renderDeleteDialog( selectedTags: GQL.TagDataFragment[], onClose: (confirmed: boolean) => void @@ -361,6 +369,7 @@ export const TagList: React.FC = ({ filterHook, alterQuery }) => { addKeybinds={addKeybinds} renderContent={renderContent} renderDeleteDialog={renderDeleteDialog} + renderEditDialog={renderEditDialog} /> ); }; diff --git a/ui/v2.5/src/core/StashService.ts b/ui/v2.5/src/core/StashService.ts index 7559210a555..251df72f57a 100644 --- a/ui/v2.5/src/core/StashService.ts +++ b/ui/v2.5/src/core/StashService.ts @@ -1873,6 +1873,17 @@ export const useTagUpdate = () => }, }); +export const useBulkTagUpdate = (input: GQL.BulkTagUpdateInput) => + GQL.useBulkTagUpdateMutation({ + variables: { input }, + update(cache, result) { + if (!result.data?.bulkTagUpdate) return; + + evictTypeFields(cache, tagMutationImpactedTypeFields); + evictQueries(cache, tagMutationImpactedQueries); + }, + }); + export const useTagDestroy = (input: GQL.TagDestroyInput) => GQL.useTagDestroyMutation({ variables: input,