From 8c8be22fe4da081b951f95ad71cb44e9411d14ec Mon Sep 17 00:00:00 2001 From: WithoutPants <53250216+WithoutPants@users.noreply.github.com> Date: Tue, 3 Dec 2024 13:49:55 +1100 Subject: [PATCH] Performer custom fields (#5487) * Backend changes * Show custom field values * Add custom fields table input * Add custom field filtering * Add unit tests * Include custom fields in import/export * Anonymise performer custom fields * Move json.Number handler functions to api * Handle json.Number conversion in api --- graphql/schema/types/filters.graphql | 8 + graphql/schema/types/metadata.graphql | 7 + graphql/schema/types/performer.graphql | 8 + internal/api/json.go | 36 ++ internal/api/json_test.go | 60 ++ .../api/loaders/customfieldsloader_gen.go | 221 +++++++ internal/api/loaders/dataloaders.go | 35 +- internal/api/resolver_model_performer.go | 13 + internal/api/resolver_mutation_configure.go | 10 +- internal/api/resolver_mutation_performer.go | 13 +- internal/autotag/integration_test.go | 2 +- internal/identify/performer.go | 2 +- internal/identify/performer_test.go | 10 +- internal/manager/task_stash_box_tag.go | 2 +- pkg/gallery/import.go | 4 +- pkg/gallery/import_test.go | 6 +- pkg/image/import.go | 4 +- pkg/image/import_test.go | 6 +- pkg/models/custom_fields.go | 17 + pkg/models/filter.go | 6 + pkg/models/jsonschema/performer.go | 2 + pkg/models/mocks/PerformerReaderWriter.go | 54 +- pkg/models/model_performer.go | 14 + pkg/models/performer.go | 7 + pkg/models/repository_performer.go | 6 +- pkg/performer/export.go | 7 + pkg/performer/export_test.go | 59 +- pkg/performer/import.go | 23 +- pkg/performer/import_test.go | 27 +- pkg/scene/import.go | 4 +- pkg/scene/import_test.go | 6 +- pkg/sqlite/anonymise.go | 74 +++ pkg/sqlite/custom_fields.go | 308 ++++++++++ pkg/sqlite/custom_fields_test.go | 176 ++++++ pkg/sqlite/database.go | 2 +- pkg/sqlite/filter.go | 7 +- pkg/sqlite/migrations/71_custom_fields.up.sql | 9 + pkg/sqlite/performer.go | 28 +- pkg/sqlite/performer_filter.go | 7 + pkg/sqlite/performer_test.go | 574 +++++++++++++++--- pkg/sqlite/query.go | 8 +- pkg/sqlite/repository.go | 4 +- pkg/sqlite/setup_test.go | 17 +- pkg/sqlite/tables.go | 1 + pkg/utils/json.go | 16 - pkg/utils/map.go | 17 - pkg/utils/map_test.go | 55 -- ui/v2.5/graphql/data/performer.graphql | 2 + .../PerformerDetailsPanel.tsx | 2 + .../PerformerDetails/PerformerEditPanel.tsx | 41 +- ui/v2.5/src/components/Performers/styles.scss | 11 + .../src/components/Shared/CollapseButton.tsx | 7 +- .../src/components/Shared/CustomFields.tsx | 308 ++++++++++ ui/v2.5/src/components/Shared/DetailItem.tsx | 21 +- ui/v2.5/src/components/Shared/styles.scss | 50 ++ ui/v2.5/src/locales/en-GB.json | 11 + 56 files changed, 2158 insertions(+), 277 deletions(-) create mode 100644 internal/api/json.go create mode 100644 internal/api/json_test.go create mode 100644 internal/api/loaders/customfieldsloader_gen.go create mode 100644 pkg/models/custom_fields.go create mode 100644 pkg/sqlite/custom_fields.go create mode 100644 pkg/sqlite/custom_fields_test.go create mode 100644 pkg/sqlite/migrations/71_custom_fields.up.sql delete mode 100644 pkg/utils/json.go create mode 100644 ui/v2.5/src/components/Shared/CustomFields.tsx diff --git a/graphql/schema/types/filters.graphql b/graphql/schema/types/filters.graphql index 23396a98ffd..7600b563b83 100644 --- a/graphql/schema/types/filters.graphql +++ b/graphql/schema/types/filters.graphql @@ -91,6 +91,12 @@ input StashIDCriterionInput { modifier: CriterionModifier! } +input CustomFieldCriterionInput { + field: String! + value: [Any!] + modifier: CriterionModifier! +} + input PerformerFilterType { AND: PerformerFilterType OR: PerformerFilterType @@ -182,6 +188,8 @@ input PerformerFilterType { created_at: TimestampCriterionInput "Filter by last update time" updated_at: TimestampCriterionInput + + custom_fields: [CustomFieldCriterionInput!] } input SceneMarkerFilterType { diff --git a/graphql/schema/types/metadata.graphql b/graphql/schema/types/metadata.graphql index 38c910d369c..923c25b4c32 100644 --- a/graphql/schema/types/metadata.graphql +++ b/graphql/schema/types/metadata.graphql @@ -338,3 +338,10 @@ type SystemStatus { input MigrateInput { backupPath: String! } + +input CustomFieldsInput { + "If populated, the entire custom fields map will be replaced with this value" + full: Map + "If populated, only the keys in this map will be updated" + partial: Map +} diff --git a/graphql/schema/types/performer.graphql b/graphql/schema/types/performer.graphql index d6f3dd832c4..fbb67ce8f07 100644 --- a/graphql/schema/types/performer.graphql +++ b/graphql/schema/types/performer.graphql @@ -58,6 +58,8 @@ type Performer { updated_at: Time! groups: [Group!]! movies: [Movie!]! @deprecated(reason: "use groups instead") + + custom_fields: Map! } input PerformerCreateInput { @@ -93,6 +95,8 @@ input PerformerCreateInput { hair_color: String weight: Int ignore_auto_tag: Boolean + + custom_fields: Map } input PerformerUpdateInput { @@ -129,6 +133,8 @@ input PerformerUpdateInput { hair_color: String weight: Int ignore_auto_tag: Boolean + + custom_fields: CustomFieldsInput } input BulkUpdateStrings { @@ -167,6 +173,8 @@ input BulkPerformerUpdateInput { hair_color: String weight: Int ignore_auto_tag: Boolean + + custom_fields: CustomFieldsInput } input PerformerDestroyInput { diff --git a/internal/api/json.go b/internal/api/json.go new file mode 100644 index 00000000000..edc5f9df80c --- /dev/null +++ b/internal/api/json.go @@ -0,0 +1,36 @@ +package api + +import ( + "encoding/json" + "strings" +) + +// JSONNumberToNumber converts a JSON number to either a float64 or int64. +func jsonNumberToNumber(n json.Number) interface{} { + if strings.Contains(string(n), ".") { + f, _ := n.Float64() + return f + } + ret, _ := n.Int64() + return ret +} + +// ConvertMapJSONNumbers converts all JSON numbers in a map to either float64 or int64. +func convertMapJSONNumbers(m map[string]interface{}) (ret map[string]interface{}) { + if m == nil { + return nil + } + + ret = make(map[string]interface{}) + for k, v := range m { + if n, ok := v.(json.Number); ok { + ret[k] = jsonNumberToNumber(n) + } else if mm, ok := v.(map[string]interface{}); ok { + ret[k] = convertMapJSONNumbers(mm) + } else { + ret[k] = v + } + } + + return ret +} diff --git a/internal/api/json_test.go b/internal/api/json_test.go new file mode 100644 index 00000000000..7c1b2fe90f0 --- /dev/null +++ b/internal/api/json_test.go @@ -0,0 +1,60 @@ +package api + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConvertMapJSONNumbers(t *testing.T) { + tests := []struct { + name string + input map[string]interface{} + expected map[string]interface{} + }{ + { + name: "Convert JSON numbers to numbers", + input: map[string]interface{}{ + "int": json.Number("12"), + "float": json.Number("12.34"), + "string": "foo", + }, + expected: map[string]interface{}{ + "int": int64(12), + "float": 12.34, + "string": "foo", + }, + }, + { + name: "Convert JSON numbers to numbers in nested maps", + input: map[string]interface{}{ + "foo": map[string]interface{}{ + "int": json.Number("56"), + "float": json.Number("56.78"), + "nested-string": "bar", + }, + "int": json.Number("12"), + "float": json.Number("12.34"), + "string": "foo", + }, + expected: map[string]interface{}{ + "foo": map[string]interface{}{ + "int": int64(56), + "float": 56.78, + "nested-string": "bar", + }, + "int": int64(12), + "float": 12.34, + "string": "foo", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertMapJSONNumbers(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/internal/api/loaders/customfieldsloader_gen.go b/internal/api/loaders/customfieldsloader_gen.go new file mode 100644 index 00000000000..d4dd3de78ab --- /dev/null +++ b/internal/api/loaders/customfieldsloader_gen.go @@ -0,0 +1,221 @@ +// Code generated by github.com/vektah/dataloaden, DO NOT EDIT. + +package loaders + +import ( + "sync" + "time" + + "github.com/stashapp/stash/pkg/models" +) + +// CustomFieldsLoaderConfig captures the config to create a new CustomFieldsLoader +type CustomFieldsLoaderConfig struct { + // Fetch is a method that provides the data for the loader + Fetch func(keys []int) ([]models.CustomFieldMap, []error) + + // Wait is how long wait before sending a batch + Wait time.Duration + + // MaxBatch will limit the maximum number of keys to send in one batch, 0 = not limit + MaxBatch int +} + +// NewCustomFieldsLoader creates a new CustomFieldsLoader given a fetch, wait, and maxBatch +func NewCustomFieldsLoader(config CustomFieldsLoaderConfig) *CustomFieldsLoader { + return &CustomFieldsLoader{ + fetch: config.Fetch, + wait: config.Wait, + maxBatch: config.MaxBatch, + } +} + +// CustomFieldsLoader batches and caches requests +type CustomFieldsLoader struct { + // this method provides the data for the loader + fetch func(keys []int) ([]models.CustomFieldMap, []error) + + // how long to done before sending a batch + wait time.Duration + + // this will limit the maximum number of keys to send in one batch, 0 = no limit + maxBatch int + + // INTERNAL + + // lazily created cache + cache map[int]models.CustomFieldMap + + // the current batch. keys will continue to be collected until timeout is hit, + // then everything will be sent to the fetch method and out to the listeners + batch *customFieldsLoaderBatch + + // mutex to prevent races + mu sync.Mutex +} + +type customFieldsLoaderBatch struct { + keys []int + data []models.CustomFieldMap + error []error + closing bool + done chan struct{} +} + +// Load a CustomFieldMap by key, batching and caching will be applied automatically +func (l *CustomFieldsLoader) Load(key int) (models.CustomFieldMap, error) { + return l.LoadThunk(key)() +} + +// LoadThunk returns a function that when called will block waiting for a CustomFieldMap. +// This method should be used if you want one goroutine to make requests to many +// different data loaders without blocking until the thunk is called. +func (l *CustomFieldsLoader) LoadThunk(key int) func() (models.CustomFieldMap, error) { + l.mu.Lock() + if it, ok := l.cache[key]; ok { + l.mu.Unlock() + return func() (models.CustomFieldMap, error) { + return it, nil + } + } + if l.batch == nil { + l.batch = &customFieldsLoaderBatch{done: make(chan struct{})} + } + batch := l.batch + pos := batch.keyIndex(l, key) + l.mu.Unlock() + + return func() (models.CustomFieldMap, error) { + <-batch.done + + var data models.CustomFieldMap + if pos < len(batch.data) { + data = batch.data[pos] + } + + var err error + // its convenient to be able to return a single error for everything + if len(batch.error) == 1 { + err = batch.error[0] + } else if batch.error != nil { + err = batch.error[pos] + } + + if err == nil { + l.mu.Lock() + l.unsafeSet(key, data) + l.mu.Unlock() + } + + return data, err + } +} + +// LoadAll fetches many keys at once. It will be broken into appropriate sized +// sub batches depending on how the loader is configured +func (l *CustomFieldsLoader) LoadAll(keys []int) ([]models.CustomFieldMap, []error) { + results := make([]func() (models.CustomFieldMap, error), len(keys)) + + for i, key := range keys { + results[i] = l.LoadThunk(key) + } + + customFieldMaps := make([]models.CustomFieldMap, len(keys)) + errors := make([]error, len(keys)) + for i, thunk := range results { + customFieldMaps[i], errors[i] = thunk() + } + return customFieldMaps, errors +} + +// LoadAllThunk returns a function that when called will block waiting for a CustomFieldMaps. +// This method should be used if you want one goroutine to make requests to many +// different data loaders without blocking until the thunk is called. +func (l *CustomFieldsLoader) LoadAllThunk(keys []int) func() ([]models.CustomFieldMap, []error) { + results := make([]func() (models.CustomFieldMap, error), len(keys)) + for i, key := range keys { + results[i] = l.LoadThunk(key) + } + return func() ([]models.CustomFieldMap, []error) { + customFieldMaps := make([]models.CustomFieldMap, len(keys)) + errors := make([]error, len(keys)) + for i, thunk := range results { + customFieldMaps[i], errors[i] = thunk() + } + return customFieldMaps, errors + } +} + +// Prime the cache with the provided key and value. If the key already exists, no change is made +// and false is returned. +// (To forcefully prime the cache, clear the key first with loader.clear(key).prime(key, value).) +func (l *CustomFieldsLoader) Prime(key int, value models.CustomFieldMap) bool { + l.mu.Lock() + var found bool + if _, found = l.cache[key]; !found { + l.unsafeSet(key, value) + } + l.mu.Unlock() + return !found +} + +// Clear the value at key from the cache, if it exists +func (l *CustomFieldsLoader) Clear(key int) { + l.mu.Lock() + delete(l.cache, key) + l.mu.Unlock() +} + +func (l *CustomFieldsLoader) unsafeSet(key int, value models.CustomFieldMap) { + if l.cache == nil { + l.cache = map[int]models.CustomFieldMap{} + } + l.cache[key] = value +} + +// keyIndex will return the location of the key in the batch, if its not found +// it will add the key to the batch +func (b *customFieldsLoaderBatch) keyIndex(l *CustomFieldsLoader, key int) int { + for i, existingKey := range b.keys { + if key == existingKey { + return i + } + } + + pos := len(b.keys) + b.keys = append(b.keys, key) + if pos == 0 { + go b.startTimer(l) + } + + if l.maxBatch != 0 && pos >= l.maxBatch-1 { + if !b.closing { + b.closing = true + l.batch = nil + go b.end(l) + } + } + + return pos +} + +func (b *customFieldsLoaderBatch) startTimer(l *CustomFieldsLoader) { + time.Sleep(l.wait) + l.mu.Lock() + + // we must have hit a batch limit and are already finalizing this batch + if b.closing { + l.mu.Unlock() + return + } + + l.batch = nil + l.mu.Unlock() + + b.end(l) +} + +func (b *customFieldsLoaderBatch) end(l *CustomFieldsLoader) { + b.data, b.error = l.fetch(b.keys) + close(b.done) +} diff --git a/internal/api/loaders/dataloaders.go b/internal/api/loaders/dataloaders.go index fca3e6c1842..493c353d785 100644 --- a/internal/api/loaders/dataloaders.go +++ b/internal/api/loaders/dataloaders.go @@ -13,6 +13,7 @@ //go:generate go run github.com/vektah/dataloaden SceneFileIDsLoader int []github.com/stashapp/stash/pkg/models.FileID //go:generate go run github.com/vektah/dataloaden ImageFileIDsLoader int []github.com/stashapp/stash/pkg/models.FileID //go:generate go run github.com/vektah/dataloaden GalleryFileIDsLoader int []github.com/stashapp/stash/pkg/models.FileID +//go:generate go run github.com/vektah/dataloaden CustomFieldsLoader int github.com/stashapp/stash/pkg/models.CustomFieldMap //go:generate go run github.com/vektah/dataloaden SceneOCountLoader int int //go:generate go run github.com/vektah/dataloaden ScenePlayCountLoader int int //go:generate go run github.com/vektah/dataloaden SceneOHistoryLoader int []time.Time @@ -51,13 +52,16 @@ type Loaders struct { ImageFiles *ImageFileIDsLoader GalleryFiles *GalleryFileIDsLoader - GalleryByID *GalleryLoader - ImageByID *ImageLoader - PerformerByID *PerformerLoader - StudioByID *StudioLoader - TagByID *TagLoader - GroupByID *GroupLoader - FileByID *FileLoader + GalleryByID *GalleryLoader + ImageByID *ImageLoader + + PerformerByID *PerformerLoader + PerformerCustomFields *CustomFieldsLoader + + StudioByID *StudioLoader + TagByID *TagLoader + GroupByID *GroupLoader + FileByID *FileLoader } type Middleware struct { @@ -88,6 +92,11 @@ func (m Middleware) Middleware(next http.Handler) http.Handler { maxBatch: maxBatch, fetch: m.fetchPerformers(ctx), }, + PerformerCustomFields: &CustomFieldsLoader{ + wait: wait, + maxBatch: maxBatch, + fetch: m.fetchPerformerCustomFields(ctx), + }, StudioByID: &StudioLoader{ wait: wait, maxBatch: maxBatch, @@ -214,6 +223,18 @@ func (m Middleware) fetchPerformers(ctx context.Context) func(keys []int) ([]*mo } } +func (m Middleware) fetchPerformerCustomFields(ctx context.Context) func(keys []int) ([]models.CustomFieldMap, []error) { + return func(keys []int) (ret []models.CustomFieldMap, errs []error) { + err := m.Repository.WithDB(ctx, func(ctx context.Context) error { + var err error + ret, err = m.Repository.Performer.GetCustomFieldsBulk(ctx, keys) + return err + }) + + return ret, toErrorSlice(err) + } +} + func (m Middleware) fetchStudios(ctx context.Context) func(keys []int) ([]*models.Studio, []error) { return func(keys []int) (ret []*models.Studio, errs []error) { err := m.Repository.WithDB(ctx, func(ctx context.Context) error { diff --git a/internal/api/resolver_model_performer.go b/internal/api/resolver_model_performer.go index b6f6af369ad..94da629323d 100644 --- a/internal/api/resolver_model_performer.go +++ b/internal/api/resolver_model_performer.go @@ -268,6 +268,19 @@ func (r *performerResolver) Groups(ctx context.Context, obj *models.Performer) ( return ret, nil } +func (r *performerResolver) CustomFields(ctx context.Context, obj *models.Performer) (map[string]interface{}, error) { + m, err := loaders.From(ctx).PerformerCustomFields.Load(obj.ID) + if err != nil { + return nil, err + } + + if m == nil { + return make(map[string]interface{}), nil + } + + return m, nil +} + // deprecated func (r *performerResolver) Movies(ctx context.Context, obj *models.Performer) (ret []*models.Group, err error) { return r.Groups(ctx, obj) diff --git a/internal/api/resolver_mutation_configure.go b/internal/api/resolver_mutation_configure.go index c4356ff5857..d9c71b09fca 100644 --- a/internal/api/resolver_mutation_configure.go +++ b/internal/api/resolver_mutation_configure.go @@ -645,13 +645,13 @@ func (r *mutationResolver) ConfigureUI(ctx context.Context, input map[string]int if input != nil { // #5483 - convert JSON numbers to float64 or int64 - input = utils.ConvertMapJSONNumbers(input) + input = convertMapJSONNumbers(input) c.SetUIConfiguration(input) } if partial != nil { // #5483 - convert JSON numbers to float64 or int64 - partial = utils.ConvertMapJSONNumbers(partial) + partial = convertMapJSONNumbers(partial) // merge partial into existing config existing := c.GetUIConfiguration() utils.MergeMaps(existing, partial) @@ -672,9 +672,9 @@ func (r *mutationResolver) ConfigureUISetting(ctx context.Context, key string, v // #5483 - convert JSON numbers to float64 or int64 if m, ok := value.(map[string]interface{}); ok { - value = utils.ConvertMapJSONNumbers(m) + value = convertMapJSONNumbers(m) } else if n, ok := value.(json.Number); ok { - value = utils.JSONNumberToNumber(n) + value = jsonNumberToNumber(n) } cfg.Set(key, value) @@ -686,7 +686,7 @@ func (r *mutationResolver) ConfigurePlugin(ctx context.Context, pluginID string, c := config.GetInstance() // #5483 - convert JSON numbers to float64 or int64 - input = utils.ConvertMapJSONNumbers(input) + input = convertMapJSONNumbers(input) c.SetPluginConfiguration(pluginID, input) if err := c.Write(); err != nil { diff --git a/internal/api/resolver_mutation_performer.go b/internal/api/resolver_mutation_performer.go index 87f0883ed24..47b02147de0 100644 --- a/internal/api/resolver_mutation_performer.go +++ b/internal/api/resolver_mutation_performer.go @@ -108,7 +108,13 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.Per return err } - err = qb.Create(ctx, &newPerformer) + i := &models.CreatePerformerInput{ + Performer: &newPerformer, + // convert json.Numbers to int/float + CustomFields: convertMapJSONNumbers(input.CustomFields), + } + + err = qb.Create(ctx, i) if err != nil { return err } @@ -290,6 +296,11 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.Per return nil, fmt.Errorf("converting tag ids: %w", err) } + updatedPerformer.CustomFields = input.CustomFields + // convert json.Numbers to int/float + updatedPerformer.CustomFields.Full = convertMapJSONNumbers(updatedPerformer.CustomFields.Full) + updatedPerformer.CustomFields.Partial = convertMapJSONNumbers(updatedPerformer.CustomFields.Partial) + var imageData []byte imageIncluded := translator.hasField("image") if input.Image != nil { diff --git a/internal/autotag/integration_test.go b/internal/autotag/integration_test.go index e74cb30aa66..565d73853c4 100644 --- a/internal/autotag/integration_test.go +++ b/internal/autotag/integration_test.go @@ -91,7 +91,7 @@ func createPerformer(ctx context.Context, pqb models.PerformerWriter) error { Name: testName, } - err := pqb.Create(ctx, &performer) + err := pqb.Create(ctx, &models.CreatePerformerInput{Performer: &performer}) if err != nil { return err } diff --git a/internal/identify/performer.go b/internal/identify/performer.go index 947bb09d6f8..7ee66b500c7 100644 --- a/internal/identify/performer.go +++ b/internal/identify/performer.go @@ -41,7 +41,7 @@ func createMissingPerformer(ctx context.Context, endpoint string, w PerformerCre return nil, err } - err = w.Create(ctx, newPerformer) + err = w.Create(ctx, &models.CreatePerformerInput{Performer: newPerformer}) if err != nil { return nil, fmt.Errorf("error creating performer: %w", err) } diff --git a/internal/identify/performer_test.go b/internal/identify/performer_test.go index 09690959de0..8d443763aa3 100644 --- a/internal/identify/performer_test.go +++ b/internal/identify/performer_test.go @@ -24,8 +24,8 @@ func Test_getPerformerID(t *testing.T) { db := mocks.NewDatabase() - db.Performer.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) { - p := args.Get(1).(*models.Performer) + db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Run(func(args mock.Arguments) { + p := args.Get(1).(*models.CreatePerformerInput) p.ID = validStoredID }).Return(nil) @@ -154,14 +154,14 @@ func Test_createMissingPerformer(t *testing.T) { db := mocks.NewDatabase() - db.Performer.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool { + db.Performer.On("Create", testCtx, mock.MatchedBy(func(p *models.CreatePerformerInput) bool { return p.Name == validName })).Run(func(args mock.Arguments) { - p := args.Get(1).(*models.Performer) + p := args.Get(1).(*models.CreatePerformerInput) p.ID = performerID }).Return(nil) - db.Performer.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool { + db.Performer.On("Create", testCtx, mock.MatchedBy(func(p *models.CreatePerformerInput) bool { return p.Name == invalidName })).Return(errors.New("error creating performer")) diff --git a/internal/manager/task_stash_box_tag.go b/internal/manager/task_stash_box_tag.go index 8bb39960140..e26edc8b1ab 100644 --- a/internal/manager/task_stash_box_tag.go +++ b/internal/manager/task_stash_box_tag.go @@ -194,7 +194,7 @@ func (t *StashBoxBatchTagTask) processMatchedPerformer(ctx context.Context, p *m return err } - if err := qb.Create(ctx, newPerformer); err != nil { + if err := qb.Create(ctx, &models.CreatePerformerInput{Performer: newPerformer}); err != nil { return err } diff --git a/pkg/gallery/import.go b/pkg/gallery/import.go index aaf37bd27e4..7cdf53691ad 100644 --- a/pkg/gallery/import.go +++ b/pkg/gallery/import.go @@ -188,7 +188,9 @@ func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*mod newPerformer := models.NewPerformer() newPerformer.Name = name - err := i.PerformerWriter.Create(ctx, &newPerformer) + err := i.PerformerWriter.Create(ctx, &models.CreatePerformerInput{ + Performer: &newPerformer, + }) if err != nil { return nil, err } diff --git a/pkg/gallery/import_test.go b/pkg/gallery/import_test.go index ec2cf7a77f5..b64f80d8f6b 100644 --- a/pkg/gallery/import_test.go +++ b/pkg/gallery/import_test.go @@ -201,8 +201,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { } db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) - db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) { - performer := args.Get(1).(*models.Performer) + db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Run(func(args mock.Arguments) { + performer := args.Get(1).(*models.CreatePerformerInput) performer.ID = existingPerformerID }).Return(nil) @@ -235,7 +235,7 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { } db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() - db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error")) + db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/image/import.go b/pkg/image/import.go index 660eb1da18d..ec200af047f 100644 --- a/pkg/image/import.go +++ b/pkg/image/import.go @@ -274,7 +274,9 @@ func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*mod newPerformer := models.NewPerformer() newPerformer.Name = name - err := i.PerformerWriter.Create(ctx, &newPerformer) + err := i.PerformerWriter.Create(ctx, &models.CreatePerformerInput{ + Performer: &newPerformer, + }) if err != nil { return nil, err } diff --git a/pkg/image/import_test.go b/pkg/image/import_test.go index 9d63dd02e92..286e51fe34b 100644 --- a/pkg/image/import_test.go +++ b/pkg/image/import_test.go @@ -163,8 +163,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { } db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) - db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) { - performer := args.Get(1).(*models.Performer) + db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Run(func(args mock.Arguments) { + performer := args.Get(1).(*models.CreatePerformerInput) performer.ID = existingPerformerID }).Return(nil) @@ -197,7 +197,7 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { } db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() - db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error")) + db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/models/custom_fields.go b/pkg/models/custom_fields.go new file mode 100644 index 00000000000..977c2fe89f3 --- /dev/null +++ b/pkg/models/custom_fields.go @@ -0,0 +1,17 @@ +package models + +import "context" + +type CustomFieldMap map[string]interface{} + +type CustomFieldsInput struct { + // If populated, the entire custom fields map will be replaced with this value + Full map[string]interface{} `json:"full"` + // If populated, only the keys in this map will be updated + Partial map[string]interface{} `json:"partial"` +} + +type CustomFieldsReader interface { + GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error) + GetCustomFieldsBulk(ctx context.Context, ids []int) ([]CustomFieldMap, error) +} diff --git a/pkg/models/filter.go b/pkg/models/filter.go index 577aef42be9..2d25f651636 100644 --- a/pkg/models/filter.go +++ b/pkg/models/filter.go @@ -194,3 +194,9 @@ type PhashDistanceCriterionInput struct { type OrientationCriterionInput struct { Value []OrientationEnum `json:"value"` } + +type CustomFieldCriterionInput struct { + Field string `json:"field"` + Value []any `json:"value"` + Modifier CriterionModifier `json:"modifier"` +} diff --git a/pkg/models/jsonschema/performer.go b/pkg/models/jsonschema/performer.go index 7ffa69983b4..5edd5724c63 100644 --- a/pkg/models/jsonschema/performer.go +++ b/pkg/models/jsonschema/performer.go @@ -65,6 +65,8 @@ type Performer struct { StashIDs []models.StashID `json:"stash_ids,omitempty"` IgnoreAutoTag bool `json:"ignore_auto_tag,omitempty"` + CustomFields map[string]interface{} `json:"custom_fields,omitempty"` + // deprecated - for import only URL string `json:"url,omitempty"` Twitter string `json:"twitter,omitempty"` diff --git a/pkg/models/mocks/PerformerReaderWriter.go b/pkg/models/mocks/PerformerReaderWriter.go index 0f3e2be02b6..dbf19a3cdce 100644 --- a/pkg/models/mocks/PerformerReaderWriter.go +++ b/pkg/models/mocks/PerformerReaderWriter.go @@ -80,11 +80,11 @@ func (_m *PerformerReaderWriter) CountByTagID(ctx context.Context, tagID int) (i } // Create provides a mock function with given fields: ctx, newPerformer -func (_m *PerformerReaderWriter) Create(ctx context.Context, newPerformer *models.Performer) error { +func (_m *PerformerReaderWriter) Create(ctx context.Context, newPerformer *models.CreatePerformerInput) error { ret := _m.Called(ctx, newPerformer) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *models.Performer) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, *models.CreatePerformerInput) error); ok { r0 = rf(ctx, newPerformer) } else { r0 = ret.Error(0) @@ -314,6 +314,52 @@ func (_m *PerformerReaderWriter) GetAliases(ctx context.Context, relatedID int) return r0, r1 } +// GetCustomFields provides a mock function with given fields: ctx, id +func (_m *PerformerReaderWriter) GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error) { + ret := _m.Called(ctx, id) + + var r0 map[string]interface{} + if rf, ok := ret.Get(0).(func(context.Context, int) map[string]interface{}); ok { + r0 = rf(ctx, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]interface{}) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetCustomFieldsBulk provides a mock function with given fields: ctx, ids +func (_m *PerformerReaderWriter) GetCustomFieldsBulk(ctx context.Context, ids []int) ([]models.CustomFieldMap, error) { + ret := _m.Called(ctx, ids) + + var r0 []models.CustomFieldMap + if rf, ok := ret.Get(0).(func(context.Context, []int) []models.CustomFieldMap); ok { + r0 = rf(ctx, ids) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]models.CustomFieldMap) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok { + r1 = rf(ctx, ids) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetImage provides a mock function with given fields: ctx, performerID func (_m *PerformerReaderWriter) GetImage(ctx context.Context, performerID int) ([]byte, error) { ret := _m.Called(ctx, performerID) @@ -502,11 +548,11 @@ func (_m *PerformerReaderWriter) QueryForAutoTag(ctx context.Context, words []st } // Update provides a mock function with given fields: ctx, updatedPerformer -func (_m *PerformerReaderWriter) Update(ctx context.Context, updatedPerformer *models.Performer) error { +func (_m *PerformerReaderWriter) Update(ctx context.Context, updatedPerformer *models.UpdatePerformerInput) error { ret := _m.Called(ctx, updatedPerformer) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *models.Performer) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, *models.UpdatePerformerInput) error); ok { r0 = rf(ctx, updatedPerformer) } else { r0 = ret.Error(0) diff --git a/pkg/models/model_performer.go b/pkg/models/model_performer.go index 85257ba38a4..566dcae1eff 100644 --- a/pkg/models/model_performer.go +++ b/pkg/models/model_performer.go @@ -39,6 +39,18 @@ type Performer struct { StashIDs RelatedStashIDs `json:"stash_ids"` } +type CreatePerformerInput struct { + *Performer + + CustomFields map[string]interface{} `json:"custom_fields"` +} + +type UpdatePerformerInput struct { + *Performer + + CustomFields CustomFieldsInput `json:"custom_fields"` +} + func NewPerformer() Performer { currentTime := time.Now() return Performer{ @@ -80,6 +92,8 @@ type PerformerPartial struct { Aliases *UpdateStrings TagIDs *UpdateIDs StashIDs *UpdateStashIDs + + CustomFields CustomFieldsInput } func NewPerformerPartial() PerformerPartial { diff --git a/pkg/models/performer.go b/pkg/models/performer.go index 47394996d3f..7301afb83bb 100644 --- a/pkg/models/performer.go +++ b/pkg/models/performer.go @@ -198,6 +198,9 @@ type PerformerFilterType struct { CreatedAt *TimestampCriterionInput `json:"created_at"` // Filter by updated at UpdatedAt *TimestampCriterionInput `json:"updated_at"` + + // Filter by custom fields + CustomFields []CustomFieldCriterionInput `json:"custom_fields"` } type PerformerCreateInput struct { @@ -234,6 +237,8 @@ type PerformerCreateInput struct { HairColor *string `json:"hair_color"` Weight *int `json:"weight"` IgnoreAutoTag *bool `json:"ignore_auto_tag"` + + CustomFields map[string]interface{} `json:"custom_fields"` } type PerformerUpdateInput struct { @@ -271,4 +276,6 @@ type PerformerUpdateInput struct { HairColor *string `json:"hair_color"` Weight *int `json:"weight"` IgnoreAutoTag *bool `json:"ignore_auto_tag"` + + CustomFields CustomFieldsInput `json:"custom_fields"` } diff --git a/pkg/models/repository_performer.go b/pkg/models/repository_performer.go index 3fd93619011..ad0b61da0f7 100644 --- a/pkg/models/repository_performer.go +++ b/pkg/models/repository_performer.go @@ -43,12 +43,12 @@ type PerformerCounter interface { // PerformerCreator provides methods to create performers. type PerformerCreator interface { - Create(ctx context.Context, newPerformer *Performer) error + Create(ctx context.Context, newPerformer *CreatePerformerInput) error } // PerformerUpdater provides methods to update performers. type PerformerUpdater interface { - Update(ctx context.Context, updatedPerformer *Performer) error + Update(ctx context.Context, updatedPerformer *UpdatePerformerInput) error UpdatePartial(ctx context.Context, id int, updatedPerformer PerformerPartial) (*Performer, error) UpdateImage(ctx context.Context, performerID int, image []byte) error } @@ -80,6 +80,8 @@ type PerformerReader interface { TagIDLoader URLLoader + CustomFieldsReader + All(ctx context.Context) ([]*Performer, error) GetImage(ctx context.Context, performerID int) ([]byte, error) HasImage(ctx context.Context, performerID int) (bool, error) diff --git a/pkg/performer/export.go b/pkg/performer/export.go index 8f720338f3d..1455fb7bfa0 100644 --- a/pkg/performer/export.go +++ b/pkg/performer/export.go @@ -17,6 +17,7 @@ type ImageAliasStashIDGetter interface { models.AliasLoader models.StashIDLoader models.URLLoader + models.CustomFieldsReader } // ToJSON converts a Performer object into its JSON equivalent. @@ -87,6 +88,12 @@ func ToJSON(ctx context.Context, reader ImageAliasStashIDGetter, performer *mode newPerformerJSON.StashIDs = performer.StashIDs.List() + var err error + newPerformerJSON.CustomFields, err = reader.GetCustomFields(ctx, performer.ID) + if err != nil { + return nil, fmt.Errorf("getting performer custom fields: %v", err) + } + image, err := reader.GetImage(ctx, performer.ID) if err != nil { logger.Errorf("Error getting performer image: %v", err) diff --git a/pkg/performer/export_test.go b/pkg/performer/export_test.go index 36353b17de7..e51049e1491 100644 --- a/pkg/performer/export_test.go +++ b/pkg/performer/export_test.go @@ -15,9 +15,11 @@ import ( ) const ( - performerID = 1 - noImageID = 2 - errImageID = 3 + performerID = 1 + noImageID = 2 + errImageID = 3 + customFieldsID = 4 + errCustomFieldsID = 5 ) const ( @@ -50,6 +52,11 @@ var ( penisLength = 1.23 circumcisedEnum = models.CircumisedEnumCut circumcised = circumcisedEnum.String() + + emptyCustomFields = make(map[string]interface{}) + customFields = map[string]interface{}{ + "customField1": "customValue1", + } ) var imageBytes = []byte("imageBytes") @@ -118,8 +125,8 @@ func createEmptyPerformer(id int) models.Performer { } } -func createFullJSONPerformer(name string, image string) *jsonschema.Performer { - return &jsonschema.Performer{ +func createFullJSONPerformer(name string, image string, withCustomFields bool) *jsonschema.Performer { + ret := &jsonschema.Performer{ Name: name, Disambiguation: disambiguation, URLs: []string{url, twitter, instagram}, @@ -152,7 +159,13 @@ func createFullJSONPerformer(name string, image string) *jsonschema.Performer { Weight: weight, StashIDs: stashIDs, IgnoreAutoTag: autoTagIgnored, + CustomFields: emptyCustomFields, } + + if withCustomFields { + ret.CustomFields = customFields + } + return ret } func createEmptyJSONPerformer() *jsonschema.Performer { @@ -166,13 +179,15 @@ func createEmptyJSONPerformer() *jsonschema.Performer { UpdatedAt: json.JSONTime{ Time: updateTime, }, + CustomFields: emptyCustomFields, } } type testScenario struct { - input models.Performer - expected *jsonschema.Performer - err bool + input models.Performer + customFields map[string]interface{} + expected *jsonschema.Performer + err bool } var scenarios []testScenario @@ -181,20 +196,36 @@ func initTestTable() { scenarios = []testScenario{ { *createFullPerformer(performerID, performerName), - createFullJSONPerformer(performerName, image), + emptyCustomFields, + createFullJSONPerformer(performerName, image, false), + false, + }, + { + *createFullPerformer(customFieldsID, performerName), + customFields, + createFullJSONPerformer(performerName, image, true), false, }, { createEmptyPerformer(noImageID), + emptyCustomFields, createEmptyJSONPerformer(), false, }, { *createFullPerformer(errImageID, performerName), - createFullJSONPerformer(performerName, ""), + emptyCustomFields, + createFullJSONPerformer(performerName, "", false), // failure to get image should not cause an error false, }, + { + *createFullPerformer(errCustomFieldsID, performerName), + customFields, + nil, + // failure to get custom fields should cause an error + true, + }, } } @@ -204,11 +235,19 @@ func TestToJSON(t *testing.T) { db := mocks.NewDatabase() imageErr := errors.New("error getting image") + customFieldsErr := errors.New("error getting custom fields") db.Performer.On("GetImage", testCtx, performerID).Return(imageBytes, nil).Once() + db.Performer.On("GetImage", testCtx, customFieldsID).Return(imageBytes, nil).Once() db.Performer.On("GetImage", testCtx, noImageID).Return(nil, nil).Once() db.Performer.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once() + db.Performer.On("GetCustomFields", testCtx, performerID).Return(emptyCustomFields, nil).Once() + db.Performer.On("GetCustomFields", testCtx, customFieldsID).Return(customFields, nil).Once() + db.Performer.On("GetCustomFields", testCtx, noImageID).Return(emptyCustomFields, nil).Once() + db.Performer.On("GetCustomFields", testCtx, errImageID).Return(emptyCustomFields, nil).Once() + db.Performer.On("GetCustomFields", testCtx, errCustomFieldsID).Return(nil, customFieldsErr).Once() + for i, s := range scenarios { tag := s.input json, err := ToJSON(testCtx, db.Performer, &tag) diff --git a/pkg/performer/import.go b/pkg/performer/import.go index 49a2ce291ae..3aaacdb8b69 100644 --- a/pkg/performer/import.go +++ b/pkg/performer/import.go @@ -25,13 +25,15 @@ type Importer struct { Input jsonschema.Performer MissingRefBehaviour models.ImportMissingRefEnum - ID int - performer models.Performer - imageData []byte + ID int + performer models.Performer + customFields models.CustomFieldMap + imageData []byte } func (i *Importer) PreImport(ctx context.Context) error { i.performer = performerJSONToPerformer(i.Input) + i.customFields = i.Input.CustomFields if err := i.populateTags(ctx); err != nil { return err @@ -165,7 +167,10 @@ func (i *Importer) FindExistingID(ctx context.Context) (*int, error) { } func (i *Importer) Create(ctx context.Context) (*int, error) { - err := i.ReaderWriter.Create(ctx, &i.performer) + err := i.ReaderWriter.Create(ctx, &models.CreatePerformerInput{ + Performer: &i.performer, + CustomFields: i.customFields, + }) if err != nil { return nil, fmt.Errorf("error creating performer: %v", err) } @@ -175,9 +180,13 @@ func (i *Importer) Create(ctx context.Context) (*int, error) { } func (i *Importer) Update(ctx context.Context, id int) error { - performer := i.performer - performer.ID = id - err := i.ReaderWriter.Update(ctx, &performer) + i.performer.ID = id + err := i.ReaderWriter.Update(ctx, &models.UpdatePerformerInput{ + Performer: &i.performer, + CustomFields: models.CustomFieldsInput{ + Full: i.customFields, + }, + }) if err != nil { return fmt.Errorf("error updating existing performer: %v", err) } diff --git a/pkg/performer/import_test.go b/pkg/performer/import_test.go index 1ee569892d4..0a3f862914a 100644 --- a/pkg/performer/import_test.go +++ b/pkg/performer/import_test.go @@ -53,13 +53,14 @@ func TestImporterPreImport(t *testing.T) { assert.NotNil(t, err) - i.Input = *createFullJSONPerformer(performerName, image) + i.Input = *createFullJSONPerformer(performerName, image, true) err = i.PreImport(testCtx) assert.Nil(t, err) expectedPerformer := *createFullPerformer(0, performerName) assert.Equal(t, expectedPerformer, i.performer) + assert.Equal(t, models.CustomFieldMap(customFields), i.customFields) } func TestImporterPreImportWithTag(t *testing.T) { @@ -234,10 +235,18 @@ func TestCreate(t *testing.T) { Name: performerName, } + performerInput := models.CreatePerformerInput{ + Performer: &performer, + } + performerErr := models.Performer{ Name: performerNameErr, } + performerErrInput := models.CreatePerformerInput{ + Performer: &performerErr, + } + i := Importer{ ReaderWriter: db.Performer, TagWriter: db.Tag, @@ -245,11 +254,11 @@ func TestCreate(t *testing.T) { } errCreate := errors.New("Create error") - db.Performer.On("Create", testCtx, &performer).Run(func(args mock.Arguments) { - arg := args.Get(1).(*models.Performer) + db.Performer.On("Create", testCtx, &performerInput).Run(func(args mock.Arguments) { + arg := args.Get(1).(*models.CreatePerformerInput) arg.ID = performerID }).Return(nil).Once() - db.Performer.On("Create", testCtx, &performerErr).Return(errCreate).Once() + db.Performer.On("Create", testCtx, &performerErrInput).Return(errCreate).Once() id, err := i.Create(testCtx) assert.Equal(t, performerID, *id) @@ -284,7 +293,10 @@ func TestUpdate(t *testing.T) { // id needs to be set for the mock input performer.ID = performerID - db.Performer.On("Update", testCtx, &performer).Return(nil).Once() + performerInput := models.UpdatePerformerInput{ + Performer: &performer, + } + db.Performer.On("Update", testCtx, &performerInput).Return(nil).Once() err := i.Update(testCtx, performerID) assert.Nil(t, err) @@ -293,7 +305,10 @@ func TestUpdate(t *testing.T) { // need to set id separately performerErr.ID = errImageID - db.Performer.On("Update", testCtx, &performerErr).Return(errUpdate).Once() + performerErrInput := models.UpdatePerformerInput{ + Performer: &performerErr, + } + db.Performer.On("Update", testCtx, &performerErrInput).Return(errUpdate).Once() err = i.Update(testCtx, errImageID) assert.NotNil(t, err) diff --git a/pkg/scene/import.go b/pkg/scene/import.go index c1b065bcf8a..e1248a77c3d 100644 --- a/pkg/scene/import.go +++ b/pkg/scene/import.go @@ -325,7 +325,9 @@ func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*mod newPerformer := models.NewPerformer() newPerformer.Name = name - err := i.PerformerWriter.Create(ctx, &newPerformer) + err := i.PerformerWriter.Create(ctx, &models.CreatePerformerInput{ + Performer: &newPerformer, + }) if err != nil { return nil, err } diff --git a/pkg/scene/import_test.go b/pkg/scene/import_test.go index 0e37dce16db..a6e3edcdfdb 100644 --- a/pkg/scene/import_test.go +++ b/pkg/scene/import_test.go @@ -327,8 +327,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { } db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) - db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) { - p := args.Get(1).(*models.Performer) + db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Run(func(args mock.Arguments) { + p := args.Get(1).(*models.CreatePerformerInput) p.ID = existingPerformerID }).Return(nil) @@ -361,7 +361,7 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { } db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() - db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error")) + db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/sqlite/anonymise.go b/pkg/sqlite/anonymise.go index 519489abfc6..f30779fd3bc 100644 --- a/pkg/sqlite/anonymise.go +++ b/pkg/sqlite/anonymise.go @@ -600,6 +600,10 @@ func (db *Anonymiser) anonymisePerformers(ctx context.Context) error { return err } + if err := db.anonymiseCustomFields(ctx, goqu.T(performersCustomFieldsTable.GetTable()), "performer_id"); err != nil { + return err + } + return nil } @@ -1050,3 +1054,73 @@ func (db *Anonymiser) obfuscateString(in string, dict string) string { return out.String() } + +func (db *Anonymiser) anonymiseCustomFields(ctx context.Context, table exp.IdentifierExpression, idColumn string) error { + lastID := 0 + lastField := "" + total := 0 + const logEvery = 10000 + + for gotSome := true; gotSome; { + if err := txn.WithTxn(ctx, db, func(ctx context.Context) error { + query := dialect.From(table).Select( + table.Col(idColumn), + table.Col("field"), + table.Col("value"), + ).Where( + goqu.L("("+idColumn+", field)").Gt(goqu.L("(?, ?)", lastID, lastField)), + ).Order( + table.Col(idColumn).Asc(), table.Col("field").Asc(), + ).Limit(1000) + + gotSome = false + + const single = false + return queryFunc(ctx, query, single, func(rows *sqlx.Rows) error { + var ( + id int + field string + value string + ) + + if err := rows.Scan( + &id, + &field, + &value, + ); err != nil { + return err + } + + set := goqu.Record{} + set["field"] = db.obfuscateString(field, letters) + set["value"] = db.obfuscateString(value, letters) + + if len(set) > 0 { + stmt := dialect.Update(table).Set(set).Where( + table.Col(idColumn).Eq(id), + table.Col("field").Eq(field), + ) + + if _, err := exec(ctx, stmt); err != nil { + return fmt.Errorf("anonymising %s: %w", table.GetTable(), err) + } + } + + lastID = id + lastField = field + gotSome = true + total++ + + if total%logEvery == 0 { + logger.Infof("Anonymised %d %s custom fields", total, table.GetTable()) + } + + return nil + }) + }); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/sqlite/custom_fields.go b/pkg/sqlite/custom_fields.go new file mode 100644 index 00000000000..bac6ae5e17f --- /dev/null +++ b/pkg/sqlite/custom_fields.go @@ -0,0 +1,308 @@ +package sqlite + +import ( + "context" + "fmt" + "regexp" + "strings" + + "github.com/doug-martin/goqu/v9" + "github.com/doug-martin/goqu/v9/exp" + "github.com/jmoiron/sqlx" + "github.com/stashapp/stash/pkg/models" +) + +const maxCustomFieldNameLength = 64 + +type customFieldsStore struct { + table exp.IdentifierExpression + fk exp.IdentifierExpression +} + +func (s *customFieldsStore) deleteForID(ctx context.Context, id int) error { + table := s.table + q := dialect.Delete(table).Where(s.fk.Eq(id)) + _, err := exec(ctx, q) + if err != nil { + return fmt.Errorf("deleting from %s: %w", s.table.GetTable(), err) + } + + return nil +} + +func (s *customFieldsStore) SetCustomFields(ctx context.Context, id int, values models.CustomFieldsInput) error { + var partial bool + var valMap map[string]interface{} + + switch { + case values.Full != nil: + partial = false + valMap = values.Full + case values.Partial != nil: + partial = true + valMap = values.Partial + default: + return nil + } + + if err := s.validateCustomFields(valMap); err != nil { + return err + } + + return s.setCustomFields(ctx, id, valMap, partial) +} + +func (s *customFieldsStore) validateCustomFields(values map[string]interface{}) error { + // ensure that custom field names are valid + // no leading or trailing whitespace, no empty strings + for k := range values { + if err := s.validateCustomFieldName(k); err != nil { + return fmt.Errorf("custom field name %q: %w", k, err) + } + } + + return nil +} + +func (s *customFieldsStore) validateCustomFieldName(fieldName string) error { + // ensure that custom field names are valid + // no leading or trailing whitespace, no empty strings + if strings.TrimSpace(fieldName) == "" { + return fmt.Errorf("custom field name cannot be empty") + } + if fieldName != strings.TrimSpace(fieldName) { + return fmt.Errorf("custom field name cannot have leading or trailing whitespace") + } + if len(fieldName) > maxCustomFieldNameLength { + return fmt.Errorf("custom field name must be less than %d characters", maxCustomFieldNameLength+1) + } + return nil +} + +func getSQLValueFromCustomFieldInput(input interface{}) (interface{}, error) { + switch v := input.(type) { + case []interface{}, map[string]interface{}: + // TODO - in future it would be nice to convert to a JSON string + // however, we would need some way to differentiate between a JSON string and a regular string + // for now, we will not support objects and arrays + return nil, fmt.Errorf("unsupported custom field value type: %T", input) + default: + return v, nil + } +} + +func (s *customFieldsStore) sqlValueToValue(value interface{}) interface{} { + // TODO - if we ever support objects and arrays we will need to add support here + return value +} + +func (s *customFieldsStore) setCustomFields(ctx context.Context, id int, values map[string]interface{}, partial bool) error { + if !partial { + // delete existing custom fields + if err := s.deleteForID(ctx, id); err != nil { + return err + } + } + + if len(values) == 0 { + return nil + } + + conflictKey := s.fk.GetCol().(string) + ", field" + // upsert new custom fields + q := dialect.Insert(s.table).Prepared(true).Cols(s.fk, "field", "value"). + OnConflict(goqu.DoUpdate(conflictKey, goqu.Record{"value": goqu.I("excluded.value")})) + r := make([]interface{}, len(values)) + var i int + for key, value := range values { + v, err := getSQLValueFromCustomFieldInput(value) + if err != nil { + return fmt.Errorf("getting SQL value for field %q: %w", key, err) + } + r[i] = goqu.Record{"field": key, "value": v, s.fk.GetCol().(string): id} + i++ + } + + if _, err := exec(ctx, q.Rows(r...)); err != nil { + return fmt.Errorf("inserting custom fields: %w", err) + } + + return nil +} + +func (s *customFieldsStore) GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error) { + q := dialect.Select("field", "value").From(s.table).Where(s.fk.Eq(id)) + + const single = false + ret := make(map[string]interface{}) + err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error { + var field string + var value interface{} + if err := rows.Scan(&field, &value); err != nil { + return fmt.Errorf("scanning custom fields: %w", err) + } + ret[field] = s.sqlValueToValue(value) + return nil + }) + if err != nil { + return nil, fmt.Errorf("getting custom fields: %w", err) + } + + return ret, nil +} + +func (s *customFieldsStore) GetCustomFieldsBulk(ctx context.Context, ids []int) ([]models.CustomFieldMap, error) { + q := dialect.Select(s.fk.As("id"), "field", "value").From(s.table).Where(s.fk.In(ids)) + + const single = false + ret := make([]models.CustomFieldMap, len(ids)) + + idi := make(map[int]int, len(ids)) + for i, id := range ids { + idi[id] = i + } + + err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error { + var id int + var field string + var value interface{} + if err := rows.Scan(&id, &field, &value); err != nil { + return fmt.Errorf("scanning custom fields: %w", err) + } + + i := idi[id] + m := ret[i] + if m == nil { + m = make(map[string]interface{}) + ret[i] = m + } + + m[field] = s.sqlValueToValue(value) + return nil + }) + if err != nil { + return nil, fmt.Errorf("getting custom fields: %w", err) + } + + return ret, nil +} + +type customFieldsFilterHandler struct { + table string + fkCol string + c []models.CustomFieldCriterionInput + idCol string +} + +func (h *customFieldsFilterHandler) innerJoin(f *filterBuilder, as string, field string) { + joinOn := fmt.Sprintf("%s = %s.%s AND %s.field = ?", h.idCol, as, h.fkCol, as) + f.addInnerJoin(h.table, as, joinOn, field) +} + +func (h *customFieldsFilterHandler) leftJoin(f *filterBuilder, as string, field string) { + joinOn := fmt.Sprintf("%s = %s.%s AND %s.field = ?", h.idCol, as, h.fkCol, as) + f.addLeftJoin(h.table, as, joinOn, field) +} + +func (h *customFieldsFilterHandler) handleCriterion(f *filterBuilder, joinAs string, cc models.CustomFieldCriterionInput) { + // convert values + cv := make([]interface{}, len(cc.Value)) + for i, v := range cc.Value { + var err error + cv[i], err = getSQLValueFromCustomFieldInput(v) + if err != nil { + f.setError(err) + return + } + } + + switch cc.Modifier { + case models.CriterionModifierEquals: + h.innerJoin(f, joinAs, cc.Field) + f.addWhere(fmt.Sprintf("%[1]s.value IN %s", joinAs, getInBinding(len(cv))), cv...) + case models.CriterionModifierNotEquals: + h.innerJoin(f, joinAs, cc.Field) + f.addWhere(fmt.Sprintf("%[1]s.value NOT IN %s", joinAs, getInBinding(len(cv))), cv...) + case models.CriterionModifierIncludes: + clauses := make([]sqlClause, len(cv)) + for i, v := range cv { + clauses[i] = makeClause(fmt.Sprintf("%s.value LIKE ?", joinAs), fmt.Sprintf("%%%v%%", v)) + } + h.innerJoin(f, joinAs, cc.Field) + f.whereClauses = append(f.whereClauses, clauses...) + case models.CriterionModifierExcludes: + for _, v := range cv { + f.addWhere(fmt.Sprintf("%[1]s.value NOT LIKE ?", joinAs), fmt.Sprintf("%%%v%%", v)) + } + h.leftJoin(f, joinAs, cc.Field) + case models.CriterionModifierMatchesRegex: + for _, v := range cv { + vs, ok := v.(string) + if !ok { + f.setError(fmt.Errorf("unsupported custom field criterion value type: %T", v)) + } + if _, err := regexp.Compile(vs); err != nil { + f.setError(err) + return + } + f.addWhere(fmt.Sprintf("(%s.value regexp ?)", joinAs), v) + } + h.innerJoin(f, joinAs, cc.Field) + case models.CriterionModifierNotMatchesRegex: + for _, v := range cv { + vs, ok := v.(string) + if !ok { + f.setError(fmt.Errorf("unsupported custom field criterion value type: %T", v)) + } + if _, err := regexp.Compile(vs); err != nil { + f.setError(err) + return + } + f.addWhere(fmt.Sprintf("(%s.value IS NULL OR %[1]s.value NOT regexp ?)", joinAs), v) + } + h.leftJoin(f, joinAs, cc.Field) + case models.CriterionModifierIsNull: + h.leftJoin(f, joinAs, cc.Field) + f.addWhere(fmt.Sprintf("%s.value IS NULL OR TRIM(%[1]s.value) = ''", joinAs)) + case models.CriterionModifierNotNull: + h.innerJoin(f, joinAs, cc.Field) + f.addWhere(fmt.Sprintf("TRIM(%[1]s.value) != ''", joinAs)) + case models.CriterionModifierBetween: + if len(cv) != 2 { + f.setError(fmt.Errorf("expected 2 values for custom field criterion modifier BETWEEN, got %d", len(cv))) + return + } + h.innerJoin(f, joinAs, cc.Field) + f.addWhere(fmt.Sprintf("%s.value BETWEEN ? AND ?", joinAs), cv[0], cv[1]) + case models.CriterionModifierNotBetween: + h.innerJoin(f, joinAs, cc.Field) + f.addWhere(fmt.Sprintf("%s.value NOT BETWEEN ? AND ?", joinAs), cv[0], cv[1]) + case models.CriterionModifierLessThan: + if len(cv) != 1 { + f.setError(fmt.Errorf("expected 1 value for custom field criterion modifier LESS_THAN, got %d", len(cv))) + return + } + h.innerJoin(f, joinAs, cc.Field) + f.addWhere(fmt.Sprintf("%s.value < ?", joinAs), cv[0]) + case models.CriterionModifierGreaterThan: + if len(cv) != 1 { + f.setError(fmt.Errorf("expected 1 value for custom field criterion modifier LESS_THAN, got %d", len(cv))) + return + } + h.innerJoin(f, joinAs, cc.Field) + f.addWhere(fmt.Sprintf("%s.value > ?", joinAs), cv[0]) + default: + f.setError(fmt.Errorf("unsupported custom field criterion modifier: %s", cc.Modifier)) + } +} + +func (h *customFieldsFilterHandler) handle(ctx context.Context, f *filterBuilder) { + if len(h.c) == 0 { + return + } + + for i, cc := range h.c { + join := fmt.Sprintf("custom_fields_%d", i) + h.handleCriterion(f, join, cc) + } +} diff --git a/pkg/sqlite/custom_fields_test.go b/pkg/sqlite/custom_fields_test.go new file mode 100644 index 00000000000..ce5c77487d9 --- /dev/null +++ b/pkg/sqlite/custom_fields_test.go @@ -0,0 +1,176 @@ +//go:build integration +// +build integration + +package sqlite_test + +import ( + "context" + "testing" + + "github.com/stashapp/stash/pkg/models" + "github.com/stretchr/testify/assert" +) + +func TestSetCustomFields(t *testing.T) { + performerIdx := performerIdx1WithScene + + mergeCustomFields := func(i map[string]interface{}) map[string]interface{} { + m := getPerformerCustomFields(performerIdx) + for k, v := range i { + m[k] = v + } + return m + } + + tests := []struct { + name string + input models.CustomFieldsInput + expected map[string]interface{} + wantErr bool + }{ + { + "valid full", + models.CustomFieldsInput{ + Full: map[string]interface{}{ + "key": "value", + }, + }, + map[string]interface{}{ + "key": "value", + }, + false, + }, + { + "valid partial", + models.CustomFieldsInput{ + Partial: map[string]interface{}{ + "key": "value", + }, + }, + mergeCustomFields(map[string]interface{}{ + "key": "value", + }), + false, + }, + { + "valid partial overwrite", + models.CustomFieldsInput{ + Partial: map[string]interface{}{ + "real": float64(4.56), + }, + }, + mergeCustomFields(map[string]interface{}{ + "real": float64(4.56), + }), + false, + }, + { + "leading space full", + models.CustomFieldsInput{ + Full: map[string]interface{}{ + " key": "value", + }, + }, + nil, + true, + }, + { + "trailing space full", + models.CustomFieldsInput{ + Full: map[string]interface{}{ + "key ": "value", + }, + }, + nil, + true, + }, + { + "leading space partial", + models.CustomFieldsInput{ + Partial: map[string]interface{}{ + " key": "value", + }, + }, + nil, + true, + }, + { + "trailing space partial", + models.CustomFieldsInput{ + Partial: map[string]interface{}{ + "key ": "value", + }, + }, + nil, + true, + }, + { + "big key full", + models.CustomFieldsInput{ + Full: map[string]interface{}{ + "12345678901234567890123456789012345678901234567890123456789012345": "value", + }, + }, + nil, + true, + }, + { + "big key partial", + models.CustomFieldsInput{ + Partial: map[string]interface{}{ + "12345678901234567890123456789012345678901234567890123456789012345": "value", + }, + }, + nil, + true, + }, + { + "empty key full", + models.CustomFieldsInput{ + Full: map[string]interface{}{ + "": "value", + }, + }, + nil, + true, + }, + { + "empty key partial", + models.CustomFieldsInput{ + Partial: map[string]interface{}{ + "": "value", + }, + }, + nil, + true, + }, + } + + // use performer custom fields store + store := db.Performer + id := performerIDs[performerIdx] + + assert := assert.New(t) + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + err := store.SetCustomFields(ctx, id, tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("SetCustomFields() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr { + return + } + + actual, err := store.GetCustomFields(ctx, id) + if err != nil { + t.Errorf("GetCustomFields() error = %v", err) + return + } + + assert.Equal(tt.expected, actual) + }) + } +} diff --git a/pkg/sqlite/database.go b/pkg/sqlite/database.go index d2c0a8191e5..5ed803c1753 100644 --- a/pkg/sqlite/database.go +++ b/pkg/sqlite/database.go @@ -34,7 +34,7 @@ const ( cacheSizeEnv = "STASH_SQLITE_CACHE_SIZE" ) -var appSchemaVersion uint = 70 +var appSchemaVersion uint = 71 //go:embed migrations/*.sql var migrationsBox embed.FS diff --git a/pkg/sqlite/filter.go b/pkg/sqlite/filter.go index f4b5e7e7726..143487af235 100644 --- a/pkg/sqlite/filter.go +++ b/pkg/sqlite/filter.go @@ -95,6 +95,7 @@ type join struct { as string onClause string joinType string + args []interface{} } // equals returns true if the other join alias/table is equal to this one @@ -229,12 +230,13 @@ func (f *filterBuilder) not(n *filterBuilder) { // The AS is omitted if as is empty. // This method does not add a join if it its alias/table name is already // present in another existing join. -func (f *filterBuilder) addLeftJoin(table, as, onClause string) { +func (f *filterBuilder) addLeftJoin(table, as, onClause string, args ...interface{}) { newJoin := join{ table: table, as: as, onClause: onClause, joinType: "LEFT", + args: args, } f.joins.add(newJoin) @@ -245,12 +247,13 @@ func (f *filterBuilder) addLeftJoin(table, as, onClause string) { // The AS is omitted if as is empty. // This method does not add a join if it its alias/table name is already // present in another existing join. -func (f *filterBuilder) addInnerJoin(table, as, onClause string) { +func (f *filterBuilder) addInnerJoin(table, as, onClause string, args ...interface{}) { newJoin := join{ table: table, as: as, onClause: onClause, joinType: "INNER", + args: args, } f.joins.add(newJoin) diff --git a/pkg/sqlite/migrations/71_custom_fields.up.sql b/pkg/sqlite/migrations/71_custom_fields.up.sql new file mode 100644 index 00000000000..3440c20b13f --- /dev/null +++ b/pkg/sqlite/migrations/71_custom_fields.up.sql @@ -0,0 +1,9 @@ +CREATE TABLE `performer_custom_fields` ( + `performer_id` integer NOT NULL, + `field` varchar(64) NOT NULL, + `value` BLOB NOT NULL, + PRIMARY KEY (`performer_id`, `field`), + foreign key(`performer_id`) references `performers`(`id`) on delete CASCADE +); + +CREATE INDEX `index_performer_custom_fields_field_value` ON `performer_custom_fields` (`field`, `value`); diff --git a/pkg/sqlite/performer.go b/pkg/sqlite/performer.go index e20dc9c4cc5..e291078b204 100644 --- a/pkg/sqlite/performer.go +++ b/pkg/sqlite/performer.go @@ -226,6 +226,7 @@ var ( type PerformerStore struct { blobJoinQueryBuilder + customFieldsStore tableMgr *table } @@ -236,6 +237,10 @@ func NewPerformerStore(blobStore *BlobStore) *PerformerStore { blobStore: blobStore, joinTable: performerTable, }, + customFieldsStore: customFieldsStore{ + table: performersCustomFieldsTable, + fk: performersCustomFieldsTable.Col(performerIDColumn), + }, tableMgr: performerTableMgr, } } @@ -248,9 +253,9 @@ func (qb *PerformerStore) selectDataset() *goqu.SelectDataset { return dialect.From(qb.table()).Select(qb.table().All()) } -func (qb *PerformerStore) Create(ctx context.Context, newObject *models.Performer) error { +func (qb *PerformerStore) Create(ctx context.Context, newObject *models.CreatePerformerInput) error { var r performerRow - r.fromPerformer(*newObject) + r.fromPerformer(*newObject.Performer) id, err := qb.tableMgr.insertID(ctx, r) if err != nil { @@ -282,12 +287,17 @@ func (qb *PerformerStore) Create(ctx context.Context, newObject *models.Performe } } + const partial = false + if err := qb.setCustomFields(ctx, id, newObject.CustomFields, partial); err != nil { + return err + } + updated, err := qb.find(ctx, id) if err != nil { return fmt.Errorf("finding after create: %w", err) } - *newObject = *updated + *newObject.Performer = *updated return nil } @@ -330,12 +340,16 @@ func (qb *PerformerStore) UpdatePartial(ctx context.Context, id int, partial mod } } + if err := qb.SetCustomFields(ctx, id, partial.CustomFields); err != nil { + return nil, err + } + return qb.find(ctx, id) } -func (qb *PerformerStore) Update(ctx context.Context, updatedObject *models.Performer) error { +func (qb *PerformerStore) Update(ctx context.Context, updatedObject *models.UpdatePerformerInput) error { var r performerRow - r.fromPerformer(*updatedObject) + r.fromPerformer(*updatedObject.Performer) if err := qb.tableMgr.updateByID(ctx, updatedObject.ID, r); err != nil { return err @@ -365,6 +379,10 @@ func (qb *PerformerStore) Update(ctx context.Context, updatedObject *models.Perf } } + if err := qb.SetCustomFields(ctx, updatedObject.ID, updatedObject.CustomFields); err != nil { + return err + } + return nil } diff --git a/pkg/sqlite/performer_filter.go b/pkg/sqlite/performer_filter.go index 72990a7febd..ae882c9503e 100644 --- a/pkg/sqlite/performer_filter.go +++ b/pkg/sqlite/performer_filter.go @@ -203,6 +203,13 @@ func (qb *performerFilterHandler) criterionHandler() criterionHandler { performerRepository.tags.innerJoin(f, "performer_tag", "performers.id") }, }, + + &customFieldsFilterHandler{ + table: performersCustomFieldsTable.GetTable(), + fkCol: performerIDColumn, + c: filter.CustomFields, + idCol: "performers.id", + }, } } diff --git a/pkg/sqlite/performer_test.go b/pkg/sqlite/performer_test.go index e0294f3e442..d24b4ca4e6a 100644 --- a/pkg/sqlite/performer_test.go +++ b/pkg/sqlite/performer_test.go @@ -16,6 +16,12 @@ import ( "github.com/stretchr/testify/assert" ) +var testCustomFields = map[string]interface{}{ + "string": "aaa", + "int": int64(123), // int64 to match the type of the field in the database + "real": 1.23, +} + func loadPerformerRelationships(ctx context.Context, expected models.Performer, actual *models.Performer) error { if expected.Aliases.Loaded() { if err := actual.LoadAliases(ctx, db.Performer); err != nil { @@ -81,57 +87,62 @@ func Test_PerformerStore_Create(t *testing.T) { tests := []struct { name string - newObject models.Performer + newObject models.CreatePerformerInput wantErr bool }{ { "full", - models.Performer{ - Name: name, - Disambiguation: disambiguation, - Gender: &gender, - URLs: models.NewRelatedStrings(urls), - Birthdate: &birthdate, - Ethnicity: ethnicity, - Country: country, - EyeColor: eyeColor, - Height: &height, - Measurements: measurements, - FakeTits: fakeTits, - PenisLength: &penisLength, - Circumcised: &circumcised, - CareerLength: careerLength, - Tattoos: tattoos, - Piercings: piercings, - Favorite: favorite, - Rating: &rating, - Details: details, - DeathDate: &deathdate, - HairColor: hairColor, - Weight: &weight, - IgnoreAutoTag: ignoreAutoTag, - TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithPerformer], tagIDs[tagIdx1WithDupName]}), - Aliases: models.NewRelatedStrings(aliases), - StashIDs: models.NewRelatedStashIDs([]models.StashID{ - { - StashID: stashID1, - Endpoint: endpoint1, - }, - { - StashID: stashID2, - Endpoint: endpoint2, - }, - }), - CreatedAt: createdAt, - UpdatedAt: updatedAt, + models.CreatePerformerInput{ + Performer: &models.Performer{ + Name: name, + Disambiguation: disambiguation, + Gender: &gender, + URLs: models.NewRelatedStrings(urls), + Birthdate: &birthdate, + Ethnicity: ethnicity, + Country: country, + EyeColor: eyeColor, + Height: &height, + Measurements: measurements, + FakeTits: fakeTits, + PenisLength: &penisLength, + Circumcised: &circumcised, + CareerLength: careerLength, + Tattoos: tattoos, + Piercings: piercings, + Favorite: favorite, + Rating: &rating, + Details: details, + DeathDate: &deathdate, + HairColor: hairColor, + Weight: &weight, + IgnoreAutoTag: ignoreAutoTag, + TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithPerformer], tagIDs[tagIdx1WithDupName]}), + Aliases: models.NewRelatedStrings(aliases), + StashIDs: models.NewRelatedStashIDs([]models.StashID{ + { + StashID: stashID1, + Endpoint: endpoint1, + }, + { + StashID: stashID2, + Endpoint: endpoint2, + }, + }), + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + CustomFields: testCustomFields, }, false, }, { "invalid tag id", - models.Performer{ - Name: name, - TagIDs: models.NewRelatedIDs([]int{invalidID}), + models.CreatePerformerInput{ + Performer: &models.Performer{ + Name: name, + TagIDs: models.NewRelatedIDs([]int{invalidID}), + }, }, true, }, @@ -155,16 +166,16 @@ func Test_PerformerStore_Create(t *testing.T) { assert.NotZero(p.ID) - copy := tt.newObject + copy := *tt.newObject.Performer copy.ID = p.ID // load relationships - if err := loadPerformerRelationships(ctx, copy, &p); err != nil { + if err := loadPerformerRelationships(ctx, copy, p.Performer); err != nil { t.Errorf("loadPerformerRelationships() error = %v", err) return } - assert.Equal(copy, p) + assert.Equal(copy, *p.Performer) // ensure can find the performer found, err := qb.Find(ctx, p.ID) @@ -183,6 +194,15 @@ func Test_PerformerStore_Create(t *testing.T) { } assert.Equal(copy, *found) + // ensure custom fields are set + cf, err := qb.GetCustomFields(ctx, p.ID) + if err != nil { + t.Errorf("PerformerStore.GetCustomFields() error = %v", err) + return + } + + assert.Equal(tt.newObject.CustomFields, cf) + return }) } @@ -228,77 +248,109 @@ func Test_PerformerStore_Update(t *testing.T) { tests := []struct { name string - updatedObject *models.Performer + updatedObject models.UpdatePerformerInput wantErr bool }{ { "full", - &models.Performer{ - ID: performerIDs[performerIdxWithGallery], - Name: name, - Disambiguation: disambiguation, - Gender: &gender, - URLs: models.NewRelatedStrings(urls), - Birthdate: &birthdate, - Ethnicity: ethnicity, - Country: country, - EyeColor: eyeColor, - Height: &height, - Measurements: measurements, - FakeTits: fakeTits, - PenisLength: &penisLength, - Circumcised: &circumcised, - CareerLength: careerLength, - Tattoos: tattoos, - Piercings: piercings, - Favorite: favorite, - Rating: &rating, - Details: details, - DeathDate: &deathdate, - HairColor: hairColor, - Weight: &weight, - IgnoreAutoTag: ignoreAutoTag, - Aliases: models.NewRelatedStrings(aliases), - TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithPerformer], tagIDs[tagIdx1WithDupName]}), - StashIDs: models.NewRelatedStashIDs([]models.StashID{ - { - StashID: stashID1, - Endpoint: endpoint1, - }, - { - StashID: stashID2, - Endpoint: endpoint2, - }, - }), - CreatedAt: createdAt, - UpdatedAt: updatedAt, + models.UpdatePerformerInput{ + Performer: &models.Performer{ + ID: performerIDs[performerIdxWithGallery], + Name: name, + Disambiguation: disambiguation, + Gender: &gender, + URLs: models.NewRelatedStrings(urls), + Birthdate: &birthdate, + Ethnicity: ethnicity, + Country: country, + EyeColor: eyeColor, + Height: &height, + Measurements: measurements, + FakeTits: fakeTits, + PenisLength: &penisLength, + Circumcised: &circumcised, + CareerLength: careerLength, + Tattoos: tattoos, + Piercings: piercings, + Favorite: favorite, + Rating: &rating, + Details: details, + DeathDate: &deathdate, + HairColor: hairColor, + Weight: &weight, + IgnoreAutoTag: ignoreAutoTag, + Aliases: models.NewRelatedStrings(aliases), + TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithPerformer], tagIDs[tagIdx1WithDupName]}), + StashIDs: models.NewRelatedStashIDs([]models.StashID{ + { + StashID: stashID1, + Endpoint: endpoint1, + }, + { + StashID: stashID2, + Endpoint: endpoint2, + }, + }), + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, }, false, }, { "clear nullables", - &models.Performer{ - ID: performerIDs[performerIdxWithGallery], - Aliases: models.NewRelatedStrings([]string{}), - URLs: models.NewRelatedStrings([]string{}), - TagIDs: models.NewRelatedIDs([]int{}), - StashIDs: models.NewRelatedStashIDs([]models.StashID{}), + models.UpdatePerformerInput{ + Performer: &models.Performer{ + ID: performerIDs[performerIdxWithGallery], + Aliases: models.NewRelatedStrings([]string{}), + URLs: models.NewRelatedStrings([]string{}), + TagIDs: models.NewRelatedIDs([]int{}), + StashIDs: models.NewRelatedStashIDs([]models.StashID{}), + }, }, false, }, { "clear tag ids", - &models.Performer{ - ID: performerIDs[sceneIdxWithTag], - TagIDs: models.NewRelatedIDs([]int{}), + models.UpdatePerformerInput{ + Performer: &models.Performer{ + ID: performerIDs[sceneIdxWithTag], + TagIDs: models.NewRelatedIDs([]int{}), + }, + }, + false, + }, + { + "set custom fields", + models.UpdatePerformerInput{ + Performer: &models.Performer{ + ID: performerIDs[performerIdxWithGallery], + }, + CustomFields: models.CustomFieldsInput{ + Full: testCustomFields, + }, + }, + false, + }, + { + "clear custom fields", + models.UpdatePerformerInput{ + Performer: &models.Performer{ + ID: performerIDs[performerIdxWithGallery], + }, + CustomFields: models.CustomFieldsInput{ + Full: map[string]interface{}{}, + }, }, false, }, { "invalid tag id", - &models.Performer{ - ID: performerIDs[sceneIdxWithGallery], - TagIDs: models.NewRelatedIDs([]int{invalidID}), + models.UpdatePerformerInput{ + Performer: &models.Performer{ + ID: performerIDs[sceneIdxWithGallery], + TagIDs: models.NewRelatedIDs([]int{invalidID}), + }, }, true, }, @@ -309,9 +361,9 @@ func Test_PerformerStore_Update(t *testing.T) { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) - copy := *tt.updatedObject + copy := *tt.updatedObject.Performer - if err := qb.Update(ctx, tt.updatedObject); (err != nil) != tt.wantErr { + if err := qb.Update(ctx, &tt.updatedObject); (err != nil) != tt.wantErr { t.Errorf("PerformerStore.Update() error = %v, wantErr %v", err, tt.wantErr) } @@ -331,6 +383,17 @@ func Test_PerformerStore_Update(t *testing.T) { } assert.Equal(copy, *s) + + // ensure custom fields are correct + if tt.updatedObject.CustomFields.Full != nil { + cf, err := qb.GetCustomFields(ctx, tt.updatedObject.ID) + if err != nil { + t.Errorf("PerformerStore.GetCustomFields() error = %v", err) + return + } + + assert.Equal(tt.updatedObject.CustomFields.Full, cf) + } }) } } @@ -573,6 +636,79 @@ func Test_PerformerStore_UpdatePartial(t *testing.T) { } } +func Test_PerformerStore_UpdatePartialCustomFields(t *testing.T) { + tests := []struct { + name string + id int + partial models.PerformerPartial + expected map[string]interface{} // nil to use the partial + }{ + { + "set custom fields", + performerIDs[performerIdxWithGallery], + models.PerformerPartial{ + CustomFields: models.CustomFieldsInput{ + Full: testCustomFields, + }, + }, + nil, + }, + { + "clear custom fields", + performerIDs[performerIdxWithGallery], + models.PerformerPartial{ + CustomFields: models.CustomFieldsInput{ + Full: map[string]interface{}{}, + }, + }, + nil, + }, + { + "partial custom fields", + performerIDs[performerIdxWithGallery], + models.PerformerPartial{ + CustomFields: models.CustomFieldsInput{ + Partial: map[string]interface{}{ + "string": "bbb", + "new_field": "new", + }, + }, + }, + map[string]interface{}{ + "int": int64(3), + "real": 1.3, + "string": "bbb", + "new_field": "new", + }, + }, + } + for _, tt := range tests { + qb := db.Performer + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + _, err := qb.UpdatePartial(ctx, tt.id, tt.partial) + if err != nil { + t.Errorf("PerformerStore.UpdatePartial() error = %v", err) + return + } + + // ensure custom fields are correct + cf, err := qb.GetCustomFields(ctx, tt.id) + if err != nil { + t.Errorf("PerformerStore.GetCustomFields() error = %v", err) + return + } + if tt.expected == nil { + assert.Equal(tt.partial.CustomFields.Full, cf) + } else { + assert.Equal(tt.expected, cf) + } + }) + } +} + func TestPerformerFindBySceneID(t *testing.T) { withTxn(func(ctx context.Context) error { pqb := db.Performer @@ -1042,6 +1178,242 @@ func TestPerformerQuery(t *testing.T) { } } +func TestPerformerQueryCustomFields(t *testing.T) { + tests := []struct { + name string + filter *models.PerformerFilterType + includeIdxs []int + excludeIdxs []int + wantErr bool + }{ + { + "equals", + &models.PerformerFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierEquals, + Value: []any{getPerformerStringValue(performerIdxWithGallery, "custom")}, + }, + }, + }, + []int{performerIdxWithGallery}, + nil, + false, + }, + { + "not equals", + &models.PerformerFilterType{ + Name: &models.StringCriterionInput{ + Value: getPerformerStringValue(performerIdxWithGallery, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierNotEquals, + Value: []any{getPerformerStringValue(performerIdxWithGallery, "custom")}, + }, + }, + }, + nil, + []int{performerIdxWithGallery}, + false, + }, + { + "includes", + &models.PerformerFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierIncludes, + Value: []any{getPerformerStringValue(performerIdxWithGallery, "custom")[9:]}, + }, + }, + }, + []int{performerIdxWithGallery}, + nil, + false, + }, + { + "excludes", + &models.PerformerFilterType{ + Name: &models.StringCriterionInput{ + Value: getPerformerStringValue(performerIdxWithGallery, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierExcludes, + Value: []any{getPerformerStringValue(performerIdxWithGallery, "custom")[9:]}, + }, + }, + }, + nil, + []int{performerIdxWithGallery}, + false, + }, + { + "regex", + &models.PerformerFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierMatchesRegex, + Value: []any{".*13_custom"}, + }, + }, + }, + []int{performerIdxWithGallery}, + nil, + false, + }, + { + "invalid regex", + &models.PerformerFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierMatchesRegex, + Value: []any{"["}, + }, + }, + }, + nil, + nil, + true, + }, + { + "not matches regex", + &models.PerformerFilterType{ + Name: &models.StringCriterionInput{ + Value: getPerformerStringValue(performerIdxWithGallery, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierNotMatchesRegex, + Value: []any{".*13_custom"}, + }, + }, + }, + nil, + []int{performerIdxWithGallery}, + false, + }, + { + "invalid not matches regex", + &models.PerformerFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierNotMatchesRegex, + Value: []any{"["}, + }, + }, + }, + nil, + nil, + true, + }, + { + "null", + &models.PerformerFilterType{ + Name: &models.StringCriterionInput{ + Value: getPerformerStringValue(performerIdxWithGallery, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "not existing", + Modifier: models.CriterionModifierIsNull, + }, + }, + }, + []int{performerIdxWithGallery}, + nil, + false, + }, + { + "null", + &models.PerformerFilterType{ + Name: &models.StringCriterionInput{ + Value: getPerformerStringValue(performerIdxWithGallery, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "string", + Modifier: models.CriterionModifierNotNull, + }, + }, + }, + []int{performerIdxWithGallery}, + nil, + false, + }, + { + "between", + &models.PerformerFilterType{ + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "real", + Modifier: models.CriterionModifierBetween, + Value: []any{0.05, 0.15}, + }, + }, + }, + []int{performerIdx1WithScene}, + nil, + false, + }, + { + "not between", + &models.PerformerFilterType{ + Name: &models.StringCriterionInput{ + Value: getPerformerStringValue(performerIdx1WithScene, "Name"), + Modifier: models.CriterionModifierEquals, + }, + CustomFields: []models.CustomFieldCriterionInput{ + { + Field: "real", + Modifier: models.CriterionModifierNotBetween, + Value: []any{0.05, 0.15}, + }, + }, + }, + nil, + []int{performerIdx1WithScene}, + false, + }, + } + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + performers, _, err := db.Performer.Query(ctx, tt.filter, nil) + if (err != nil) != tt.wantErr { + t.Errorf("PerformerStore.Query() error = %v, wantErr %v", err, tt.wantErr) + return + } + + ids := performersToIDs(performers) + include := indexesToIDs(performerIDs, tt.includeIdxs) + exclude := indexesToIDs(performerIDs, tt.excludeIdxs) + + for _, i := range include { + assert.Contains(ids, i) + } + for _, e := range exclude { + assert.NotContains(ids, e) + } + }) + } +} + func TestPerformerQueryPenisLength(t *testing.T) { var upper = 4.0 @@ -1172,7 +1544,7 @@ func TestPerformerUpdatePerformerImage(t *testing.T) { performer := models.Performer{ Name: name, } - err := qb.Create(ctx, &performer) + err := qb.Create(ctx, &models.CreatePerformerInput{Performer: &performer}) if err != nil { return fmt.Errorf("Error creating performer: %s", err.Error()) } @@ -1680,7 +2052,7 @@ func TestPerformerStashIDs(t *testing.T) { performer := &models.Performer{ Name: name, } - if err := qb.Create(ctx, performer); err != nil { + if err := qb.Create(ctx, &models.CreatePerformerInput{Performer: performer}); err != nil { return fmt.Errorf("Error creating performer: %s", err.Error()) } diff --git a/pkg/sqlite/query.go b/pkg/sqlite/query.go index 9c09d8beaed..4f4c0c8db56 100644 --- a/pkg/sqlite/query.go +++ b/pkg/sqlite/query.go @@ -133,6 +133,9 @@ func (qb *queryBuilder) join(table, as, onClause string) { func (qb *queryBuilder) addJoins(joins ...join) { qb.joins.add(joins...) + for _, j := range joins { + qb.args = append(qb.args, j.args...) + } } func (qb *queryBuilder) addFilter(f *filterBuilder) error { @@ -151,6 +154,9 @@ func (qb *queryBuilder) addFilter(f *filterBuilder) error { qb.args = append(args, qb.args...) } + // add joins here to insert args + qb.addJoins(f.getAllJoins()...) + clause, args = f.generateWhereClauses() if len(clause) > 0 { qb.addWhere(clause) @@ -169,8 +175,6 @@ func (qb *queryBuilder) addFilter(f *filterBuilder) error { qb.addArg(args...) } - qb.addJoins(f.getAllJoins()...) - return nil } diff --git a/pkg/sqlite/repository.go b/pkg/sqlite/repository.go index 2035b11c2fc..ac2954cfb24 100644 --- a/pkg/sqlite/repository.go +++ b/pkg/sqlite/repository.go @@ -222,8 +222,8 @@ func (r *repository) innerJoin(j joiner, as string, parentIDCol string) { } type joiner interface { - addLeftJoin(table, as, onClause string) - addInnerJoin(table, as, onClause string) + addLeftJoin(table, as, onClause string, args ...interface{}) + addInnerJoin(table, as, onClause string, args ...interface{}) } type joinRepository struct { diff --git a/pkg/sqlite/setup_test.go b/pkg/sqlite/setup_test.go index b63b6a04a2c..1d2854297ee 100644 --- a/pkg/sqlite/setup_test.go +++ b/pkg/sqlite/setup_test.go @@ -1508,6 +1508,18 @@ func performerAliases(i int) []string { return []string{getPerformerStringValue(i, "alias")} } +func getPerformerCustomFields(index int) map[string]interface{} { + if index%5 == 0 { + return nil + } + + return map[string]interface{}{ + "string": getPerformerStringValue(index, "custom"), + "int": int64(index % 5), + "real": float64(index) / 10, + } +} + // createPerformers creates n performers with plain Name and o performers with camel cased NaMe included func createPerformers(ctx context.Context, n int, o int) error { pqb := db.Performer @@ -1558,7 +1570,10 @@ func createPerformers(ctx context.Context, n int, o int) error { }) } - err := pqb.Create(ctx, &performer) + err := pqb.Create(ctx, &models.CreatePerformerInput{ + Performer: &performer, + CustomFields: getPerformerCustomFields(i), + }) if err != nil { return fmt.Errorf("Error creating performer %v+: %s", performer, err.Error()) diff --git a/pkg/sqlite/tables.go b/pkg/sqlite/tables.go index 481c4ee06a4..c6ab6a4d4d2 100644 --- a/pkg/sqlite/tables.go +++ b/pkg/sqlite/tables.go @@ -32,6 +32,7 @@ var ( performersURLsJoinTable = goqu.T(performerURLsTable) performersTagsJoinTable = goqu.T(performersTagsTable) performersStashIDsJoinTable = goqu.T("performer_stash_ids") + performersCustomFieldsTable = goqu.T("performer_custom_fields") studiosAliasesJoinTable = goqu.T(studioAliasesTable) studiosTagsJoinTable = goqu.T(studiosTagsTable) diff --git a/pkg/utils/json.go b/pkg/utils/json.go deleted file mode 100644 index ae69180688c..00000000000 --- a/pkg/utils/json.go +++ /dev/null @@ -1,16 +0,0 @@ -package utils - -import ( - "encoding/json" - "strings" -) - -// JSONNumberToNumber converts a JSON number to either a float64 or int64. -func JSONNumberToNumber(n json.Number) interface{} { - if strings.Contains(string(n), ".") { - f, _ := n.Float64() - return f - } - ret, _ := n.Int64() - return ret -} diff --git a/pkg/utils/map.go b/pkg/utils/map.go index dbef17646b2..0c555857443 100644 --- a/pkg/utils/map.go +++ b/pkg/utils/map.go @@ -1,7 +1,6 @@ package utils import ( - "encoding/json" "strings" ) @@ -80,19 +79,3 @@ func MergeMaps(dest map[string]interface{}, src map[string]interface{}) { dest[k] = v } } - -// ConvertMapJSONNumbers converts all JSON numbers in a map to either float64 or int64. -func ConvertMapJSONNumbers(m map[string]interface{}) (ret map[string]interface{}) { - ret = make(map[string]interface{}) - for k, v := range m { - if n, ok := v.(json.Number); ok { - ret[k] = JSONNumberToNumber(n) - } else if mm, ok := v.(map[string]interface{}); ok { - ret[k] = ConvertMapJSONNumbers(mm) - } else { - ret[k] = v - } - } - - return ret -} diff --git a/pkg/utils/map_test.go b/pkg/utils/map_test.go index 142cd639321..54dfacedd30 100644 --- a/pkg/utils/map_test.go +++ b/pkg/utils/map_test.go @@ -1,11 +1,8 @@ package utils import ( - "encoding/json" "reflect" "testing" - - "github.com/stretchr/testify/assert" ) func TestNestedMapGet(t *testing.T) { @@ -282,55 +279,3 @@ func TestMergeMaps(t *testing.T) { }) } } - -func TestConvertMapJSONNumbers(t *testing.T) { - tests := []struct { - name string - input map[string]interface{} - expected map[string]interface{} - }{ - { - name: "Convert JSON numbers to numbers", - input: map[string]interface{}{ - "int": json.Number("12"), - "float": json.Number("12.34"), - "string": "foo", - }, - expected: map[string]interface{}{ - "int": int64(12), - "float": 12.34, - "string": "foo", - }, - }, - { - name: "Convert JSON numbers to numbers in nested maps", - input: map[string]interface{}{ - "foo": map[string]interface{}{ - "int": json.Number("56"), - "float": json.Number("56.78"), - "nested-string": "bar", - }, - "int": json.Number("12"), - "float": json.Number("12.34"), - "string": "foo", - }, - expected: map[string]interface{}{ - "foo": map[string]interface{}{ - "int": int64(56), - "float": 56.78, - "nested-string": "bar", - }, - "int": int64(12), - "float": 12.34, - "string": "foo", - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := ConvertMapJSONNumbers(tt.input) - assert.Equal(t, tt.expected, result) - }) - } -} diff --git a/ui/v2.5/graphql/data/performer.graphql b/ui/v2.5/graphql/data/performer.graphql index 0aa60ce21bb..035c8abc72d 100644 --- a/ui/v2.5/graphql/data/performer.graphql +++ b/ui/v2.5/graphql/data/performer.graphql @@ -41,4 +41,6 @@ fragment PerformerData on Performer { death_date hair_color weight + + custom_fields } diff --git a/ui/v2.5/src/components/Performers/PerformerDetails/PerformerDetailsPanel.tsx b/ui/v2.5/src/components/Performers/PerformerDetails/PerformerDetailsPanel.tsx index e805c03e621..83fc64c956c 100644 --- a/ui/v2.5/src/components/Performers/PerformerDetails/PerformerDetailsPanel.tsx +++ b/ui/v2.5/src/components/Performers/PerformerDetails/PerformerDetailsPanel.tsx @@ -14,6 +14,7 @@ import { FormatWeight, } from "../PerformerList"; import { PatchComponent } from "src/patch"; +import { CustomFields } from "src/components/Shared/CustomFields"; interface IPerformerDetails { performer: GQL.PerformerDataFragment; @@ -176,6 +177,7 @@ export const PerformerDetailsPanel: React.FC = value={renderStashIDs()} fullWidth={fullWidth} /> + {fullWidth && } ); }); diff --git a/ui/v2.5/src/components/Performers/PerformerDetails/PerformerEditPanel.tsx b/ui/v2.5/src/components/Performers/PerformerDetails/PerformerEditPanel.tsx index 2adcb601e1e..df5b62b05f3 100644 --- a/ui/v2.5/src/components/Performers/PerformerDetails/PerformerEditPanel.tsx +++ b/ui/v2.5/src/components/Performers/PerformerDetails/PerformerEditPanel.tsx @@ -47,6 +47,8 @@ import { yupUniqueStringList, } from "src/utils/yup"; import { useTagsEdit } from "src/hooks/tagsEdit"; +import { CustomFieldsInput } from "src/components/Shared/CustomFields"; +import { cloneDeep } from "@apollo/client/utilities"; const isScraper = ( scraper: GQL.Scraper | GQL.StashBox @@ -61,6 +63,16 @@ interface IPerformerDetails { setEncodingImage: (loading: boolean) => void; } +function customFieldInput(isNew: boolean, input: {}) { + if (isNew) { + return input; + } else { + return { + full: input, + }; + } +} + export const PerformerEditPanel: React.FC = ({ performer, isVisible, @@ -115,6 +127,7 @@ export const PerformerEditPanel: React.FC = ({ ignore_auto_tag: yup.boolean().defined(), stash_ids: yup.mixed().defined(), image: yup.string().nullable().optional(), + custom_fields: yup.object().required().defined(), }); const initialValues = { @@ -142,15 +155,26 @@ export const PerformerEditPanel: React.FC = ({ tag_ids: (performer.tags ?? []).map((t) => t.id), ignore_auto_tag: performer.ignore_auto_tag ?? false, stash_ids: getStashIDs(performer.stash_ids), + custom_fields: cloneDeep(performer.custom_fields ?? {}), }; type InputValues = yup.InferType; + const [customFieldsError, setCustomFieldsError] = useState(); + + function submit(values: InputValues) { + const input = { + ...schema.cast(values), + custom_fields: customFieldInput(isNew, values.custom_fields), + }; + onSave(input); + } + const formik = useFormik({ initialValues, enableReinitialize: true, validate: yupFormikValidate(schema), - onSubmit: (values) => onSave(schema.cast(values)), + onSubmit: submit, }); const { tags, updateTagsStateFromScraper, tagsControl } = useTagsEdit( @@ -571,7 +595,11 @@ export const PerformerEditPanel: React.FC = ({ @@ -44,7 +45,7 @@ export const ExpandCollapseButton: React.FC<{ className="minimal expand-collapse" onClick={() => setCollapsed(!collapsed)} > - + ); diff --git a/ui/v2.5/src/components/Shared/CustomFields.tsx b/ui/v2.5/src/components/Shared/CustomFields.tsx new file mode 100644 index 00000000000..233254f7a15 --- /dev/null +++ b/ui/v2.5/src/components/Shared/CustomFields.tsx @@ -0,0 +1,308 @@ +import React, { useEffect, useMemo, useRef, useState } from "react"; +import { CollapseButton } from "./CollapseButton"; +import { DetailItem } from "./DetailItem"; +import { Button, Col, Form, FormGroup, InputGroup, Row } from "react-bootstrap"; +import { FormattedMessage, useIntl } from "react-intl"; +import { cloneDeep } from "@apollo/client/utilities"; +import { Icon } from "./Icon"; +import { faMinus, faPlus } from "@fortawesome/free-solid-svg-icons"; +import cx from "classnames"; + +const maxFieldNameLength = 64; + +export type CustomFieldMap = { + [key: string]: unknown; +}; + +interface ICustomFields { + values: CustomFieldMap; +} + +function convertValue(value: unknown): string { + if (typeof value === "string") { + return value; + } else if (typeof value === "number") { + return value.toString(); + } else if (typeof value === "boolean") { + return value ? "true" : "false"; + } else if (Array.isArray(value)) { + return value.join(", "); + } else { + return JSON.stringify(value); + } +} + +const CustomField: React.FC<{ field: string; value: unknown }> = ({ + field, + value, +}) => { + const valueStr = convertValue(value); + + // replace spaces with hyphen characters for css id + const id = field.toLowerCase().replace(/ /g, "-"); + + return ( + + ); +}; + +export const CustomFields: React.FC = ({ values }) => { + const intl = useIntl(); + if (Object.keys(values).length === 0) { + return null; + } + + return ( + // according to linter rule CSS classes shouldn't use underscores +
+ + {Object.entries(values).map(([key, value]) => ( + + ))} + +
+ ); +}; + +function isNumeric(v: string) { + return /^-?(?:0|(?:[1-9][0-9]*))(?:\.[0-9]+)?$/.test(v); +} + +function convertCustomValue(v: string) { + // if the value is numeric, convert it to a number + if (isNumeric(v)) { + return Number(v); + } else { + return v; + } +} + +const CustomFieldInput: React.FC<{ + field: string; + value: unknown; + onChange: (field: string, value: unknown) => void; + isNew?: boolean; + error?: string; +}> = ({ field, value, onChange, isNew = false, error }) => { + const intl = useIntl(); + const [currentField, setCurrentField] = useState(field); + const [currentValue, setCurrentValue] = useState(value as string); + + const fieldRef = useRef(null); + const valueRef = useRef(null); + + useEffect(() => { + setCurrentField(field); + setCurrentValue(value as string); + }, [field, value]); + + function onBlur() { + onChange(currentField, convertCustomValue(currentValue)); + } + + function onDelete() { + onChange("", ""); + } + + return ( + + + + {isNew ? ( + <> + setCurrentField(event.currentTarget.value)} + onBlur={onBlur} + /> + + ) : ( + {currentField} + )} + + + + setCurrentValue(event.currentTarget.value)} + onBlur={onBlur} + /> + + {!isNew && ( + + )} + + + + + {error} + + ); +}; + +interface ICustomField { + field: string; + value: unknown; +} + +interface ICustomFieldsInput { + values: CustomFieldMap; + error?: string; + onChange: (values: CustomFieldMap) => void; + setError: (error?: string) => void; +} + +export const CustomFieldsInput: React.FC = ({ + values, + error, + onChange, + setError, +}) => { + const intl = useIntl(); + + const [newCustomField, setNewCustomField] = useState({ + field: "", + value: "", + }); + + const fields = useMemo(() => { + const valueCopy = cloneDeep(values); + if (newCustomField.field !== "" && error === undefined) { + delete valueCopy[newCustomField.field]; + } + + const ret = Object.keys(valueCopy); + ret.sort(); + return ret; + }, [values, newCustomField, error]); + + function onSetNewField(v: ICustomField) { + // validate the field name + let newError = undefined; + if (v.field.length > maxFieldNameLength) { + newError = intl.formatMessage({ + id: "errors.custom_fields.field_name_length", + }); + } + if (v.field.trim() === "" && v.value !== "") { + newError = intl.formatMessage({ + id: "errors.custom_fields.field_name_required", + }); + } + if (v.field.trim() !== v.field) { + newError = intl.formatMessage({ + id: "errors.custom_fields.field_name_whitespace", + }); + } + if (fields.includes(v.field)) { + newError = intl.formatMessage({ + id: "errors.custom_fields.duplicate_field", + }); + } + + const oldField = newCustomField; + + setNewCustomField(v); + + const valuesCopy = cloneDeep(values); + if (oldField.field !== "" && error === undefined) { + delete valuesCopy[oldField.field]; + } + + // if valid, pass up + if (!newError && v.field !== "") { + valuesCopy[v.field] = v.value; + } + + onChange(valuesCopy); + setError(newError); + } + + function onAdd() { + const newValues = { + ...values, + [newCustomField.field]: newCustomField.value, + }; + setNewCustomField({ field: "", value: "" }); + onChange(newValues); + } + + function fieldChanged( + currentField: string, + newField: string, + value: unknown + ) { + let newValues = cloneDeep(values); + delete newValues[currentField]; + if (newField !== "") { + newValues[newField] = value; + } + onChange(newValues); + } + + return ( + + + + + + + + + + + + {fields.map((field) => ( + + fieldChanged(field, newField, newValue) + } + /> + ))} + onSetNewField({ field, value })} + isNew + /> + + + + + ); +}; diff --git a/ui/v2.5/src/components/Shared/DetailItem.tsx b/ui/v2.5/src/components/Shared/DetailItem.tsx index 304655a4c69..a92f75868d3 100644 --- a/ui/v2.5/src/components/Shared/DetailItem.tsx +++ b/ui/v2.5/src/components/Shared/DetailItem.tsx @@ -3,34 +3,39 @@ import { FormattedMessage } from "react-intl"; interface IDetailItem { id?: string | null; + label?: React.ReactNode; value?: React.ReactNode; + labelTitle?: string; title?: string; fullWidth?: boolean; + showEmpty?: boolean; } export const DetailItem: React.FC = ({ id, + label, value, + labelTitle, title, fullWidth, + showEmpty = false, }) => { - if (!id || !value || value === "Na") { + if (!id || (!showEmpty && (!value || value === "Na"))) { return <>; } - const message = ; + const message = label ?? ; + + // according to linter rule CSS classes shouldn't use underscores + const sanitisedID = id.replace(/_/g, "-"); return ( - // according to linter rule CSS classes shouldn't use underscores
- + {message} {fullWidth ? ":" : ""} - + {value}
diff --git a/ui/v2.5/src/components/Shared/styles.scss b/ui/v2.5/src/components/Shared/styles.scss index 644eff047cd..50777fff380 100644 --- a/ui/v2.5/src/components/Shared/styles.scss +++ b/ui/v2.5/src/components/Shared/styles.scss @@ -197,6 +197,15 @@ button.collapse-button.btn-primary:not(:disabled):not(.disabled):active { border: none; box-shadow: none; color: #f5f8fa; + text-align: left; +} + +button.collapse-button { + .fa-icon { + margin-left: 0; + } + + padding-left: 0; } .hover-popover-content { @@ -678,3 +687,44 @@ button.btn.favorite-button { } } } + +.custom-fields .detail-item .detail-item-title { + max-width: 130px; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.custom-fields-input > .collapse-button { + font-weight: 700; +} + +.custom-fields-row { + align-items: center; + font-family: "Courier New", Courier, monospace; + font-size: 0.875rem; + + .form-label { + margin-bottom: 0; + max-width: 100%; + overflow: hidden; + text-overflow: ellipsis; + vertical-align: middle; + white-space: nowrap; + } + + // labels with titles are styled with help cursor and dotted underline elsewhere + div.custom-fields-field label.form-label { + cursor: inherit; + text-decoration: inherit; + } + + .form-control, + .btn { + font-size: 0.875rem; + } + + &.custom-fields-new > div:not(:last-child) { + padding-right: 0; + } +} diff --git a/ui/v2.5/src/locales/en-GB.json b/ui/v2.5/src/locales/en-GB.json index f9e5c3c4902..f22ba1de574 100644 --- a/ui/v2.5/src/locales/en-GB.json +++ b/ui/v2.5/src/locales/en-GB.json @@ -854,6 +854,11 @@ "only": "Only" }, "custom": "Custom", + "custom_fields": { + "field": "Field", + "title": "Custom Fields", + "value": "Value" + }, "date": "Date", "date_format": "YYYY-MM-DD", "datetime_format": "YYYY-MM-DD HH:MM", @@ -1035,6 +1040,12 @@ }, "empty_server": "Add some scenes to your server to view recommendations on this page.", "errors": { + "custom_fields": { + "duplicate_field": "Field name must be unique", + "field_name_length": "Field name must fewer than 65 characters", + "field_name_required": "Field name is required", + "field_name_whitespace": "Field name cannot have leading or trailing whitespace" + }, "header": "Error", "image_index_greater_than_zero": "Image index must be greater than 0", "invalid_javascript_string": "Invalid javascript code: {error}",