From 67725fb12872dabf0032c2f377a755394e30e2f1 Mon Sep 17 00:00:00 2001 From: DingDongSoLong4 <99329275+DingDongSoLong4@users.noreply.github.com> Date: Sat, 23 Sep 2023 22:42:44 +0200 Subject: [PATCH 1/9] Remove manager.Repository --- internal/api/loaders/dataloaders.go | 31 ++-- internal/api/resolver.go | 8 +- internal/api/resolver_mutation_file.go | 2 +- internal/api/resolver_mutation_stash_box.go | 6 +- internal/api/resolver_mutation_tag_test.go | 4 +- internal/api/resolver_query_configuration.go | 2 +- internal/api/resolver_query_scraper.go | 2 +- internal/api/server.go | 42 +++-- internal/autotag/integration_test.go | 2 +- internal/identify/identify_test.go | 2 +- internal/manager/manager.go | 82 ++++----- internal/manager/manager_tasks.go | 13 +- internal/manager/repository.go | 51 ------ internal/manager/task_autotag.go | 87 ++++----- internal/manager/task_clean.go | 15 +- internal/manager/task_export.go | 167 +++++++++--------- internal/manager/task_generate.go | 32 ++-- ...task_generate_interactive_heatmap_speed.go | 7 +- internal/manager/task_generate_markers.go | 16 +- internal/manager/task_generate_phash.go | 10 +- internal/manager/task_generate_screenshot.go | 14 +- internal/manager/task_identify.go | 34 ++-- internal/manager/task_import.go | 155 ++++++++-------- internal/manager/task_scan.go | 153 +++++++++------- internal/manager/task_stash_box_tag.go | 66 ++++--- pkg/file/move.go | 2 +- pkg/models/repository.go | 17 +- pkg/sqlite/setup_test.go | 4 +- pkg/sqlite/transaction.go | 2 +- 29 files changed, 519 insertions(+), 509 deletions(-) diff --git a/internal/api/loaders/dataloaders.go b/internal/api/loaders/dataloaders.go index d98c663a146..714ee9c3899 100644 --- a/internal/api/loaders/dataloaders.go +++ b/internal/api/loaders/dataloaders.go @@ -17,9 +17,7 @@ import ( "net/http" "time" - "github.com/stashapp/stash/internal/manager" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/txn" ) type contextKey struct{ name string } @@ -49,8 +47,7 @@ type Loaders struct { } type Middleware struct { - DatabaseProvider txn.DatabaseProvider - Repository manager.Repository + Repository models.Repository } func (m Middleware) Middleware(next http.Handler) http.Handler { @@ -131,13 +128,9 @@ func toErrorSlice(err error) []error { return nil } -func (m Middleware) withTxn(ctx context.Context, fn func(ctx context.Context) error) error { - return txn.WithDatabase(ctx, m.DatabaseProvider, fn) -} - func (m Middleware) fetchScenes(ctx context.Context) func(keys []int) ([]*models.Scene, []error) { return func(keys []int) (ret []*models.Scene, errs []error) { - err := m.withTxn(ctx, func(ctx context.Context) error { + err := m.Repository.WithDB(ctx, func(ctx context.Context) error { var err error ret, err = m.Repository.Scene.FindMany(ctx, keys) return err @@ -148,7 +141,7 @@ func (m Middleware) fetchScenes(ctx context.Context) func(keys []int) ([]*models func (m Middleware) fetchImages(ctx context.Context) func(keys []int) ([]*models.Image, []error) { return func(keys []int) (ret []*models.Image, errs []error) { - err := m.withTxn(ctx, func(ctx context.Context) error { + err := m.Repository.WithDB(ctx, func(ctx context.Context) error { var err error ret, err = m.Repository.Image.FindMany(ctx, keys) return err @@ -160,7 +153,7 @@ func (m Middleware) fetchImages(ctx context.Context) func(keys []int) ([]*models func (m Middleware) fetchGalleries(ctx context.Context) func(keys []int) ([]*models.Gallery, []error) { return func(keys []int) (ret []*models.Gallery, errs []error) { - err := m.withTxn(ctx, func(ctx context.Context) error { + err := m.Repository.WithDB(ctx, func(ctx context.Context) error { var err error ret, err = m.Repository.Gallery.FindMany(ctx, keys) return err @@ -172,7 +165,7 @@ func (m Middleware) fetchGalleries(ctx context.Context) func(keys []int) ([]*mod func (m Middleware) fetchPerformers(ctx context.Context) func(keys []int) ([]*models.Performer, []error) { return func(keys []int) (ret []*models.Performer, errs []error) { - err := m.withTxn(ctx, func(ctx context.Context) error { + err := m.Repository.WithDB(ctx, func(ctx context.Context) error { var err error ret, err = m.Repository.Performer.FindMany(ctx, keys) return err @@ -184,7 +177,7 @@ func (m Middleware) fetchPerformers(ctx context.Context) func(keys []int) ([]*mo func (m Middleware) fetchStudios(ctx context.Context) func(keys []int) ([]*models.Studio, []error) { return func(keys []int) (ret []*models.Studio, errs []error) { - err := m.withTxn(ctx, func(ctx context.Context) error { + err := m.Repository.WithDB(ctx, func(ctx context.Context) error { var err error ret, err = m.Repository.Studio.FindMany(ctx, keys) return err @@ -195,7 +188,7 @@ func (m Middleware) fetchStudios(ctx context.Context) func(keys []int) ([]*model func (m Middleware) fetchTags(ctx context.Context) func(keys []int) ([]*models.Tag, []error) { return func(keys []int) (ret []*models.Tag, errs []error) { - err := m.withTxn(ctx, func(ctx context.Context) error { + err := m.Repository.WithDB(ctx, func(ctx context.Context) error { var err error ret, err = m.Repository.Tag.FindMany(ctx, keys) return err @@ -206,7 +199,7 @@ func (m Middleware) fetchTags(ctx context.Context) func(keys []int) ([]*models.T func (m Middleware) fetchMovies(ctx context.Context) func(keys []int) ([]*models.Movie, []error) { return func(keys []int) (ret []*models.Movie, errs []error) { - err := m.withTxn(ctx, func(ctx context.Context) error { + err := m.Repository.WithDB(ctx, func(ctx context.Context) error { var err error ret, err = m.Repository.Movie.FindMany(ctx, keys) return err @@ -217,7 +210,7 @@ func (m Middleware) fetchMovies(ctx context.Context) func(keys []int) ([]*models func (m Middleware) fetchFiles(ctx context.Context) func(keys []models.FileID) ([]models.File, []error) { return func(keys []models.FileID) (ret []models.File, errs []error) { - err := m.withTxn(ctx, func(ctx context.Context) error { + err := m.Repository.WithDB(ctx, func(ctx context.Context) error { var err error ret, err = m.Repository.File.Find(ctx, keys...) return err @@ -228,7 +221,7 @@ func (m Middleware) fetchFiles(ctx context.Context) func(keys []models.FileID) ( func (m Middleware) fetchScenesFileIDs(ctx context.Context) func(keys []int) ([][]models.FileID, []error) { return func(keys []int) (ret [][]models.FileID, errs []error) { - err := m.withTxn(ctx, func(ctx context.Context) error { + err := m.Repository.WithDB(ctx, func(ctx context.Context) error { var err error ret, err = m.Repository.Scene.GetManyFileIDs(ctx, keys) return err @@ -239,7 +232,7 @@ func (m Middleware) fetchScenesFileIDs(ctx context.Context) func(keys []int) ([] func (m Middleware) fetchImagesFileIDs(ctx context.Context) func(keys []int) ([][]models.FileID, []error) { return func(keys []int) (ret [][]models.FileID, errs []error) { - err := m.withTxn(ctx, func(ctx context.Context) error { + err := m.Repository.WithDB(ctx, func(ctx context.Context) error { var err error ret, err = m.Repository.Image.GetManyFileIDs(ctx, keys) return err @@ -250,7 +243,7 @@ func (m Middleware) fetchImagesFileIDs(ctx context.Context) func(keys []int) ([] func (m Middleware) fetchGalleriesFileIDs(ctx context.Context) func(keys []int) ([][]models.FileID, []error) { return func(keys []int) (ret [][]models.FileID, errs []error) { - err := m.withTxn(ctx, func(ctx context.Context) error { + err := m.Repository.WithDB(ctx, func(ctx context.Context) error { var err error ret, err = m.Repository.Gallery.GetManyFileIDs(ctx, keys) return err diff --git a/internal/api/resolver.go b/internal/api/resolver.go index ea0bd256c22..91b8dc5b742 100644 --- a/internal/api/resolver.go +++ b/internal/api/resolver.go @@ -13,7 +13,6 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/scraper" - "github.com/stashapp/stash/pkg/txn" ) var ( @@ -33,8 +32,7 @@ type hookExecutor interface { } type Resolver struct { - txnManager txn.Manager - repository manager.Repository + repository models.Repository sceneService manager.SceneService imageService manager.ImageService galleryService manager.GalleryService @@ -102,11 +100,11 @@ type tagResolver struct{ *Resolver } type savedFilterResolver struct{ *Resolver } func (r *Resolver) withTxn(ctx context.Context, fn func(ctx context.Context) error) error { - return txn.WithTxn(ctx, r.txnManager, fn) + return r.repository.WithTxn(ctx, fn) } func (r *Resolver) withReadTxn(ctx context.Context, fn func(ctx context.Context) error) error { - return txn.WithReadTxn(ctx, r.txnManager, fn) + return r.repository.WithReadTxn(ctx, fn) } func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*models.SceneMarker, err error) { diff --git a/internal/api/resolver_mutation_file.go b/internal/api/resolver_mutation_file.go index e8fecef80a8..10167a6d436 100644 --- a/internal/api/resolver_mutation_file.go +++ b/internal/api/resolver_mutation_file.go @@ -17,7 +17,7 @@ func (r *mutationResolver) MoveFiles(ctx context.Context, input MoveFilesInput) fileStore := r.repository.File folderStore := r.repository.Folder mover := file.NewMover(fileStore, folderStore) - mover.RegisterHooks(ctx, r.txnManager) + mover.RegisterHooks(ctx) var ( folder *models.Folder diff --git a/internal/api/resolver_mutation_stash_box.go b/internal/api/resolver_mutation_stash_box.go index 2f8593097f9..4f595d3ae9e 100644 --- a/internal/api/resolver_mutation_stash_box.go +++ b/internal/api/resolver_mutation_stash_box.go @@ -27,7 +27,7 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input return false, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) } - client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository()) + client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.repository.TxnManager, r.stashboxRepository()) return client.SubmitStashBoxFingerprints(ctx, input.SceneIds, boxes[input.StashBoxIndex].Endpoint) } @@ -49,7 +49,7 @@ func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input S return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) } - client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository()) + client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.repository.TxnManager, r.stashboxRepository()) id, err := strconv.Atoi(input.ID) if err != nil { @@ -91,7 +91,7 @@ func (r *mutationResolver) SubmitStashBoxPerformerDraft(ctx context.Context, inp return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) } - client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository()) + client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.repository.TxnManager, r.stashboxRepository()) id, err := strconv.Atoi(input.ID) if err != nil { diff --git a/internal/api/resolver_mutation_tag_test.go b/internal/api/resolver_mutation_tag_test.go index b4098512982..ba604061681 100644 --- a/internal/api/resolver_mutation_tag_test.go +++ b/internal/api/resolver_mutation_tag_test.go @@ -5,7 +5,6 @@ import ( "errors" "testing" - "github.com/stashapp/stash/internal/manager" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" "github.com/stashapp/stash/pkg/plugin" @@ -18,8 +17,7 @@ import ( func newResolver() *Resolver { txnMgr := &mocks.TxnManager{} return &Resolver{ - txnManager: txnMgr, - repository: manager.Repository{ + repository: models.Repository{ TxnManager: txnMgr, Tag: &mocks.TagReaderWriter{}, }, diff --git a/internal/api/resolver_query_configuration.go b/internal/api/resolver_query_configuration.go index 7de9bda0da6..c6fe587ae52 100644 --- a/internal/api/resolver_query_configuration.go +++ b/internal/api/resolver_query_configuration.go @@ -243,7 +243,7 @@ func makeConfigUIResult() map[string]interface{} { } func (r *queryResolver) ValidateStashBoxCredentials(ctx context.Context, input config.StashBoxInput) (*StashBoxValidationResult, error) { - client := stashbox.NewClient(models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}, r.txnManager, r.stashboxRepository()) + client := stashbox.NewClient(models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}, r.repository.TxnManager, r.stashboxRepository()) user, err := client.GetUser(ctx) valid := user != nil && user.Me != nil diff --git a/internal/api/resolver_query_scraper.go b/internal/api/resolver_query_scraper.go index 0220316b2fb..55d6564dece 100644 --- a/internal/api/resolver_query_scraper.go +++ b/internal/api/resolver_query_scraper.go @@ -238,7 +238,7 @@ func (r *queryResolver) getStashBoxClient(index int) (*stashbox.Client, error) { return nil, fmt.Errorf("%w: invalid stash_box_index %d", ErrInput, index) } - return stashbox.NewClient(*boxes[index], r.txnManager, r.stashboxRepository()), nil + return stashbox.NewClient(*boxes[index], r.repository.TxnManager, r.stashboxRepository()), nil } func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source scraper.Source, input ScrapeSingleSceneInput) ([]*scraper.ScrapedScene, error) { diff --git a/internal/api/server.go b/internal/api/server.go index b909914cdfd..15f72e416e7 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -82,11 +82,10 @@ func Start() error { return errors.New(message) } - txnManager := manager.GetInstance().Repository + repo := manager.GetInstance().Repository dataloaders := loaders.Middleware{ - DatabaseProvider: txnManager, - Repository: txnManager, + Repository: repo, } r.Use(dataloaders.Middleware) @@ -96,8 +95,7 @@ func Start() error { imageService := manager.GetInstance().ImageService galleryService := manager.GetInstance().GalleryService resolver := &Resolver{ - txnManager: txnManager, - repository: txnManager, + repository: repo, sceneService: sceneService, imageService: imageService, galleryService: galleryService, @@ -145,33 +143,33 @@ func Start() error { }) r.Mount("/performer", performerRoutes{ - txnManager: txnManager, - performerFinder: txnManager.Performer, + txnManager: repo.TxnManager, + performerFinder: repo.Performer, }.Routes()) r.Mount("/scene", sceneRoutes{ - txnManager: txnManager, - sceneFinder: txnManager.Scene, - fileGetter: txnManager.File, - captionFinder: txnManager.File, - sceneMarkerFinder: txnManager.SceneMarker, - tagFinder: txnManager.Tag, + txnManager: repo.TxnManager, + sceneFinder: repo.Scene, + fileGetter: repo.File, + captionFinder: repo.File, + sceneMarkerFinder: repo.SceneMarker, + tagFinder: repo.Tag, }.Routes()) r.Mount("/image", imageRoutes{ - txnManager: txnManager, - imageFinder: txnManager.Image, - fileGetter: txnManager.File, + txnManager: repo.TxnManager, + imageFinder: repo.Image, + fileGetter: repo.File, }.Routes()) r.Mount("/studio", studioRoutes{ - txnManager: txnManager, - studioFinder: txnManager.Studio, + txnManager: repo.TxnManager, + studioFinder: repo.Studio, }.Routes()) r.Mount("/movie", movieRoutes{ - txnManager: txnManager, - movieFinder: txnManager.Movie, + txnManager: repo.TxnManager, + movieFinder: repo.Movie, }.Routes()) r.Mount("/tag", tagRoutes{ - txnManager: txnManager, - tagFinder: txnManager.Tag, + txnManager: repo.TxnManager, + tagFinder: repo.Tag, }.Routes()) r.Mount("/downloads", downloadsRoutes{}.Routes()) diff --git a/internal/autotag/integration_test.go b/internal/autotag/integration_test.go index 84ae016987c..add74133cf0 100644 --- a/internal/autotag/integration_test.go +++ b/internal/autotag/integration_test.go @@ -62,7 +62,7 @@ func runTests(m *testing.M) int { panic(fmt.Sprintf("Could not initialize database: %s", err.Error())) } - r = db.TxnRepository() + r = db.Repository() // defer close and delete the database defer testTeardown(databaseFile) diff --git a/internal/identify/identify_test.go b/internal/identify/identify_test.go index 04ff0360765..c032cbe53ac 100644 --- a/internal/identify/identify_test.go +++ b/internal/identify/identify_test.go @@ -254,7 +254,7 @@ func TestSceneIdentifier_modifyScene(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tr.modifyScene(testCtx, repo, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr { + if err := tr.modifyScene(testCtx, repo.TxnManager, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr { t.Errorf("SceneIdentifier.modifyScene() error = %v, wantErr %v", err, tt.wantErr) } }) diff --git a/internal/manager/manager.go b/internal/manager/manager.go index e199f9ce78a..9bd7fbfae16 100644 --- a/internal/manager/manager.go +++ b/internal/manager/manager.go @@ -131,7 +131,7 @@ type Manager struct { DLNAService *dlna.Service Database *sqlite.Database - Repository Repository + Repository models.Repository SceneService SceneService ImageService ImageService @@ -174,6 +174,7 @@ func initialize() error { initProfiling(cfg.GetCPUProfilePath()) db := sqlite.NewDatabase() + repo := db.Repository() // start with empty paths emptyPaths := paths.Paths{} @@ -186,48 +187,48 @@ func initialize() error { PluginCache: plugin.NewCache(cfg), Database: db, - Repository: sqliteRepository(db), + Repository: repo, Paths: &emptyPaths, scanSubs: &subscriptionManager{}, } instance.SceneService = &scene.Service{ - File: db.File, - Repository: db.Scene, - MarkerRepository: db.SceneMarker, + File: repo.File, + Repository: repo.Scene, + MarkerRepository: repo.SceneMarker, PluginCache: instance.PluginCache, Paths: instance.Paths, Config: cfg, } instance.ImageService = &image.Service{ - File: db.File, - Repository: db.Image, + File: repo.File, + Repository: repo.Image, } instance.GalleryService = &gallery.Service{ - Repository: db.Gallery, - ImageFinder: db.Image, + Repository: repo.Gallery, + ImageFinder: repo.Image, ImageService: instance.ImageService, - File: db.File, - Folder: db.Folder, + File: repo.File, + Folder: repo.Folder, } instance.JobManager = initJobManager() sceneServer := SceneServer{ - TxnManager: instance.Repository, - SceneCoverGetter: instance.Repository.Scene, + TxnManager: repo.TxnManager, + SceneCoverGetter: repo.Scene, } - instance.DLNAService = dlna.NewService(instance.Repository, dlna.Repository{ - SceneFinder: instance.Repository.Scene, - FileGetter: instance.Repository.File, - StudioFinder: instance.Repository.Studio, - TagFinder: instance.Repository.Tag, - PerformerFinder: instance.Repository.Performer, - MovieFinder: instance.Repository.Movie, + instance.DLNAService = dlna.NewService(repo.TxnManager, dlna.Repository{ + SceneFinder: repo.Scene, + FileGetter: repo.File, + StudioFinder: repo.Studio, + TagFinder: repo.Tag, + PerformerFinder: repo.Performer, + MovieFinder: repo.Movie, }, instance.Config, &sceneServer) if !cfg.IsNewSystem() { @@ -268,8 +269,8 @@ func initialize() error { logger.Warnf("could not initialize FFMPEG subsystem: %v", err) } - instance.Scanner = makeScanner(db, instance.PluginCache) - instance.Cleaner = makeCleaner(db, instance.PluginCache) + instance.Scanner = makeScanner(repo, instance.PluginCache) + instance.Cleaner = makeCleaner(repo, instance.PluginCache) // if DLNA is enabled, start it now if instance.Config.GetDLNADefaultEnabled() { @@ -293,13 +294,13 @@ func galleryFileFilter(ctx context.Context, f models.File) bool { return isZip(f.Base().Basename) } -func makeScanner(db *sqlite.Database, pluginCache *plugin.Cache) *file.Scanner { +func makeScanner(repo models.Repository, pluginCache *plugin.Cache) *file.Scanner { return &file.Scanner{ Repository: file.Repository{ - Manager: db, - DatabaseProvider: db, - FileStore: db.File, - FolderStore: db.Folder, + Manager: repo.TxnManager, + DatabaseProvider: repo.TxnManager, + FileStore: repo.File, + FolderStore: repo.Folder, }, FileDecorators: []file.Decorator{ &file.FilteredDecorator{ @@ -320,14 +321,14 @@ func makeScanner(db *sqlite.Database, pluginCache *plugin.Cache) *file.Scanner { } } -func makeCleaner(db *sqlite.Database, pluginCache *plugin.Cache) *file.Cleaner { +func makeCleaner(repo models.Repository, pluginCache *plugin.Cache) *file.Cleaner { return &file.Cleaner{ FS: &file.OsFS{}, Repository: file.Repository{ - Manager: db, - DatabaseProvider: db, - FileStore: db.File, - FolderStore: db.Folder, + Manager: repo.TxnManager, + DatabaseProvider: repo.TxnManager, + FileStore: repo.File, + FolderStore: repo.Folder, }, Handlers: []file.CleanHandler{ &cleanHandler{}, @@ -523,13 +524,14 @@ func writeStashIcon() { // initScraperCache initializes a new scraper cache and returns it. func (s *Manager) initScraperCache() *scraper.Cache { - ret, err := scraper.NewCache(config.GetInstance(), s.Repository, scraper.Repository{ - SceneFinder: s.Repository.Scene, - GalleryFinder: s.Repository.Gallery, - TagFinder: s.Repository.Tag, - PerformerFinder: s.Repository.Performer, - MovieFinder: s.Repository.Movie, - StudioFinder: s.Repository.Studio, + repo := s.Repository + ret, err := scraper.NewCache(s.Config, repo.TxnManager, scraper.Repository{ + SceneFinder: repo.Scene, + GalleryFinder: repo.Gallery, + TagFinder: repo.Tag, + PerformerFinder: repo.Performer, + MovieFinder: repo.Movie, + StudioFinder: repo.Studio, }) if err != nil { @@ -697,7 +699,7 @@ func (s *Manager) Setup(ctx context.Context, input SetupInput) error { return fmt.Errorf("error initializing FFMPEG subsystem: %v", err) } - instance.Scanner = makeScanner(instance.Database, instance.PluginCache) + instance.Scanner = makeScanner(instance.Repository, instance.PluginCache) return nil } diff --git a/internal/manager/manager_tasks.go b/internal/manager/manager_tasks.go index ed4eea17116..dea78b53537 100644 --- a/internal/manager/manager_tasks.go +++ b/internal/manager/manager_tasks.go @@ -112,7 +112,8 @@ func (s *Manager) Import(ctx context.Context) (int, error) { j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) { task := ImportTask{ - txnManager: s.Repository, + repository: s.Repository, + resetter: s.Database, BaseDir: metadataPath, Reset: true, DuplicateBehaviour: ImportDuplicateEnumFail, @@ -136,7 +137,7 @@ func (s *Manager) Export(ctx context.Context) (int, error) { var wg sync.WaitGroup wg.Add(1) task := ExportTask{ - txnManager: s.Repository, + repository: s.Repository, full: true, fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(), } @@ -167,7 +168,7 @@ func (s *Manager) Generate(ctx context.Context, input GenerateMetadataInput) (in } j := &GenerateJob{ - txnManager: s.Repository, + repository: s.Repository, input: input, } @@ -212,7 +213,7 @@ func (s *Manager) generateScreenshot(ctx context.Context, sceneId string, at *fl } task := GenerateCoverTask{ - txnManager: s.Repository, + repository: s.Repository, Scene: *scene, ScreenshotAt: at, Overwrite: true, @@ -239,7 +240,7 @@ type AutoTagMetadataInput struct { func (s *Manager) AutoTag(ctx context.Context, input AutoTagMetadataInput) int { j := autoTagJob{ - txnManager: s.Repository, + repository: s.Repository, input: input, } @@ -255,7 +256,7 @@ type CleanMetadataInput struct { func (s *Manager) Clean(ctx context.Context, input CleanMetadataInput) int { j := cleanJob{ cleaner: s.Cleaner, - txnManager: s.Repository, + repository: s.Repository, sceneService: s.SceneService, imageService: s.ImageService, input: input, diff --git a/internal/manager/repository.go b/internal/manager/repository.go index 77859d06baa..fa0c865c683 100644 --- a/internal/manager/repository.go +++ b/internal/manager/repository.go @@ -6,59 +6,8 @@ import ( "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scene" - "github.com/stashapp/stash/pkg/sqlite" - "github.com/stashapp/stash/pkg/txn" ) -type Repository struct { - models.TxnManager - - File models.FileReaderWriter - Folder models.FolderReaderWriter - Gallery models.GalleryReaderWriter - GalleryChapter models.GalleryChapterReaderWriter - Image models.ImageReaderWriter - Movie models.MovieReaderWriter - Performer models.PerformerReaderWriter - Scene models.SceneReaderWriter - SceneMarker models.SceneMarkerReaderWriter - Studio models.StudioReaderWriter - Tag models.TagReaderWriter - SavedFilter models.SavedFilterReaderWriter -} - -func (r *Repository) WithTxn(ctx context.Context, fn txn.TxnFunc) error { - return txn.WithTxn(ctx, r, fn) -} - -func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error { - return txn.WithReadTxn(ctx, r, fn) -} - -func (r *Repository) WithDB(ctx context.Context, fn txn.TxnFunc) error { - return txn.WithDatabase(ctx, r, fn) -} - -func sqliteRepository(d *sqlite.Database) Repository { - txnRepo := d.TxnRepository() - - return Repository{ - TxnManager: txnRepo, - File: d.File, - Folder: d.Folder, - Gallery: d.Gallery, - GalleryChapter: txnRepo.GalleryChapter, - Image: d.Image, - Movie: txnRepo.Movie, - Performer: txnRepo.Performer, - Scene: d.Scene, - SceneMarker: txnRepo.SceneMarker, - Studio: txnRepo.Studio, - Tag: txnRepo.Tag, - SavedFilter: txnRepo.SavedFilter, - } -} - type SceneService interface { Create(ctx context.Context, input *models.Scene, fileIDs []models.FileID, coverImage []byte) (*models.Scene, error) AssignFile(ctx context.Context, sceneID int, fileID models.FileID) error diff --git a/internal/manager/task_autotag.go b/internal/manager/task_autotag.go index 0f1cadb2df0..e23437651f7 100644 --- a/internal/manager/task_autotag.go +++ b/internal/manager/task_autotag.go @@ -19,7 +19,7 @@ import ( ) type autoTagJob struct { - txnManager Repository + repository models.Repository input AutoTagMetadataInput cache match.Cache @@ -56,7 +56,7 @@ func (j *autoTagJob) autoTagFiles(ctx context.Context, progress *job.Progress, p studios: studios, tags: tags, progress: progress, - txnManager: j.txnManager, + repository: j.repository, cache: &j.cache, } @@ -73,8 +73,8 @@ func (j *autoTagJob) autoTagSpecific(ctx context.Context, progress *job.Progress studioCount := len(studioIds) tagCount := len(tagIds) - if err := j.txnManager.WithReadTxn(ctx, func(ctx context.Context) error { - r := j.txnManager + r := j.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { performerQuery := r.Performer studioQuery := r.Studio tagQuery := r.Tag @@ -123,16 +123,17 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre return } + r := j.repository tagger := autotag.Tagger{ - TxnManager: j.txnManager, + TxnManager: r.TxnManager, Cache: &j.cache, } for _, performerId := range performerIds { var performers []*models.Performer - if err := j.txnManager.WithDB(ctx, func(ctx context.Context) error { - performerQuery := j.txnManager.Performer + if err := r.WithDB(ctx, func(ctx context.Context) error { + performerQuery := r.Performer ignoreAutoTag := false perPage := -1 @@ -161,7 +162,7 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre return fmt.Errorf("performer with id %s not found", performerId) } - if err := performer.LoadAliases(ctx, j.txnManager.Performer); err != nil { + if err := performer.LoadAliases(ctx, r.Performer); err != nil { return fmt.Errorf("loading aliases for performer %d: %w", performer.ID, err) } performers = append(performers, performer) @@ -173,7 +174,6 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre } err := func() error { - r := j.txnManager if err := tagger.PerformerScenes(ctx, performer, paths, r.Scene); err != nil { return fmt.Errorf("processing scenes: %w", err) } @@ -215,9 +215,9 @@ func (j *autoTagJob) autoTagStudios(ctx context.Context, progress *job.Progress, return } - r := j.txnManager + r := j.repository tagger := autotag.Tagger{ - TxnManager: j.txnManager, + TxnManager: r.TxnManager, Cache: &j.cache, } @@ -308,15 +308,15 @@ func (j *autoTagJob) autoTagTags(ctx context.Context, progress *job.Progress, pa return } - r := j.txnManager + r := j.repository tagger := autotag.Tagger{ - TxnManager: j.txnManager, + TxnManager: r.TxnManager, Cache: &j.cache, } for _, tagId := range tagIds { var tags []*models.Tag - if err := j.txnManager.WithDB(ctx, func(ctx context.Context) error { + if err := r.WithDB(ctx, func(ctx context.Context) error { tagQuery := r.Tag ignoreAutoTag := false perPage := -1 @@ -402,7 +402,7 @@ type autoTagFilesTask struct { tags bool progress *job.Progress - txnManager Repository + repository models.Repository cache *match.Cache } @@ -482,7 +482,9 @@ func (t *autoTagFilesTask) makeGalleryFilter() *models.GalleryFilterType { return ret } -func (t *autoTagFilesTask) getCount(ctx context.Context, r Repository) (int, error) { +func (t *autoTagFilesTask) getCount(ctx context.Context) (int, error) { + r := t.repository + pp := 0 findFilter := &models.FindFilterType{ PerPage: &pp, @@ -522,7 +524,7 @@ func (t *autoTagFilesTask) getCount(ctx context.Context, r Repository) (int, err return sceneCount + imageCount + galleryCount, nil } -func (t *autoTagFilesTask) processScenes(ctx context.Context, r Repository) { +func (t *autoTagFilesTask) processScenes(ctx context.Context) { if job.IsCancelled(ctx) { return } @@ -534,10 +536,12 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r Repository) { findFilter := models.BatchFindFilter(batchSize) sceneFilter := t.makeSceneFilter() + r := t.repository + more := true for more { var scenes []*models.Scene - if err := t.txnManager.WithReadTxn(ctx, func(ctx context.Context) error { + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { var err error scenes, err = scene.Query(ctx, r.Scene, sceneFilter, findFilter) return err @@ -555,7 +559,7 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r Repository) { } tt := autoTagSceneTask{ - txnManager: t.txnManager, + repository: r, scene: ss, performers: t.performers, studios: t.studios, @@ -583,7 +587,7 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r Repository) { } } -func (t *autoTagFilesTask) processImages(ctx context.Context, r Repository) { +func (t *autoTagFilesTask) processImages(ctx context.Context) { if job.IsCancelled(ctx) { return } @@ -595,10 +599,12 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r Repository) { findFilter := models.BatchFindFilter(batchSize) imageFilter := t.makeImageFilter() + r := t.repository + more := true for more { var images []*models.Image - if err := t.txnManager.WithReadTxn(ctx, func(ctx context.Context) error { + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { var err error images, err = image.Query(ctx, r.Image, imageFilter, findFilter) return err @@ -616,7 +622,7 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r Repository) { } tt := autoTagImageTask{ - txnManager: t.txnManager, + repository: t.repository, image: ss, performers: t.performers, studios: t.studios, @@ -644,7 +650,7 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r Repository) { } } -func (t *autoTagFilesTask) processGalleries(ctx context.Context, r Repository) { +func (t *autoTagFilesTask) processGalleries(ctx context.Context) { if job.IsCancelled(ctx) { return } @@ -656,10 +662,12 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r Repository) { findFilter := models.BatchFindFilter(batchSize) galleryFilter := t.makeGalleryFilter() + r := t.repository + more := true for more { var galleries []*models.Gallery - if err := t.txnManager.WithReadTxn(ctx, func(ctx context.Context) error { + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { var err error galleries, _, err = r.Gallery.Query(ctx, galleryFilter, findFilter) return err @@ -677,7 +685,7 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r Repository) { } tt := autoTagGalleryTask{ - txnManager: t.txnManager, + repository: t.repository, gallery: ss, performers: t.performers, studios: t.studios, @@ -706,9 +714,8 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r Repository) { } func (t *autoTagFilesTask) process(ctx context.Context) { - r := t.txnManager - if err := r.WithReadTxn(ctx, func(ctx context.Context) error { - total, err := t.getCount(ctx, t.txnManager) + if err := t.repository.WithReadTxn(ctx, func(ctx context.Context) error { + total, err := t.getCount(ctx) if err != nil { return err } @@ -724,13 +731,13 @@ func (t *autoTagFilesTask) process(ctx context.Context) { return } - t.processScenes(ctx, r) - t.processImages(ctx, r) - t.processGalleries(ctx, r) + t.processScenes(ctx) + t.processImages(ctx) + t.processGalleries(ctx) } type autoTagSceneTask struct { - txnManager Repository + repository models.Repository scene *models.Scene performers bool @@ -742,8 +749,8 @@ type autoTagSceneTask struct { func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() - r := t.txnManager - if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { + r := t.repository + if err := r.WithTxn(ctx, func(ctx context.Context) error { if t.scene.Path == "" { // nothing to do return nil @@ -774,7 +781,7 @@ func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) { } type autoTagImageTask struct { - txnManager Repository + repository models.Repository image *models.Image performers bool @@ -786,8 +793,8 @@ type autoTagImageTask struct { func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() - r := t.txnManager - if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { + r := t.repository + if err := r.WithTxn(ctx, func(ctx context.Context) error { if t.performers { if err := autotag.ImagePerformers(ctx, t.image, r.Image, r.Performer, t.cache); err != nil { return fmt.Errorf("tagging image performers for %s: %v", t.image.DisplayName(), err) @@ -813,7 +820,7 @@ func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) { } type autoTagGalleryTask struct { - txnManager Repository + repository models.Repository gallery *models.Gallery performers bool @@ -825,8 +832,8 @@ type autoTagGalleryTask struct { func (t *autoTagGalleryTask) Start(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() - r := t.txnManager - if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { + r := t.repository + if err := r.WithTxn(ctx, func(ctx context.Context) error { if t.performers { if err := autotag.GalleryPerformers(ctx, t.gallery, r.Gallery, r.Performer, t.cache); err != nil { return fmt.Errorf("tagging gallery performers for %s: %v", t.gallery.DisplayName(), err) diff --git a/internal/manager/task_clean.go b/internal/manager/task_clean.go index 207c6381866..0ea4332f918 100644 --- a/internal/manager/task_clean.go +++ b/internal/manager/task_clean.go @@ -16,7 +16,6 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/scene" - "github.com/stashapp/stash/pkg/txn" ) type cleaner interface { @@ -25,7 +24,7 @@ type cleaner interface { type cleanJob struct { cleaner cleaner - txnManager Repository + repository models.Repository input CleanMetadataInput sceneService SceneService imageService ImageService @@ -61,10 +60,11 @@ func (j *cleanJob) cleanEmptyGalleries(ctx context.Context) { const batchSize = 1000 var toClean []int findFilter := models.BatchFindFilter(batchSize) - if err := txn.WithTxn(ctx, j.txnManager, func(ctx context.Context) error { + r := j.repository + if err := r.WithTxn(ctx, func(ctx context.Context) error { found := true for found { - emptyGalleries, _, err := j.txnManager.Gallery.Query(ctx, &models.GalleryFilterType{ + emptyGalleries, _, err := r.Gallery.Query(ctx, &models.GalleryFilterType{ ImageCount: &models.IntCriterionInput{ Value: 0, Modifier: models.CriterionModifierEquals, @@ -108,9 +108,10 @@ func (j *cleanJob) cleanEmptyGalleries(ctx context.Context) { func (j *cleanJob) deleteGallery(ctx context.Context, id int) { pluginCache := GetInstance().PluginCache - qb := j.txnManager.Gallery - if err := txn.WithTxn(ctx, j.txnManager, func(ctx context.Context) error { + r := j.repository + if err := r.WithTxn(ctx, func(ctx context.Context) error { + qb := r.Gallery g, err := qb.Find(ctx, id) if err != nil { return err @@ -120,7 +121,7 @@ func (j *cleanJob) deleteGallery(ctx context.Context, id int) { return fmt.Errorf("gallery with id %d not found", id) } - if err := g.LoadPrimaryFile(ctx, j.txnManager.File); err != nil { + if err := g.LoadPrimaryFile(ctx, r.File); err != nil { return err } diff --git a/internal/manager/task_export.go b/internal/manager/task_export.go index a7278253ecc..c268a3eecac 100644 --- a/internal/manager/task_export.go +++ b/internal/manager/task_export.go @@ -31,7 +31,7 @@ import ( ) type ExportTask struct { - txnManager Repository + repository models.Repository full bool baseDir string @@ -98,7 +98,7 @@ func CreateExportTask(a models.HashAlgorithm, input ExportObjectsInput) *ExportT } return &ExportTask{ - txnManager: GetInstance().Repository, + repository: GetInstance().Repository, fileNamingAlgorithm: a, scenes: newExportSpec(input.Scenes), images: newExportSpec(input.Images), @@ -148,29 +148,27 @@ func (t *ExportTask) Start(ctx context.Context, wg *sync.WaitGroup) { paths.EmptyJSONDirs(t.baseDir) paths.EnsureJSONDirs(t.baseDir) - txnErr := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { - r := t.txnManager - + txnErr := t.repository.WithTxn(ctx, func(ctx context.Context) error { // include movie scenes and gallery images if !t.full { // only include movie scenes if includeDependencies is also set if !t.scenes.all && t.includeDependencies { - t.populateMovieScenes(ctx, r) + t.populateMovieScenes(ctx) } // always export gallery images if !t.images.all { - t.populateGalleryImages(ctx, r) + t.populateGalleryImages(ctx) } } - t.ExportScenes(ctx, workerCount, r) - t.ExportImages(ctx, workerCount, r) - t.ExportGalleries(ctx, workerCount, r) - t.ExportMovies(ctx, workerCount, r) - t.ExportPerformers(ctx, workerCount, r) - t.ExportStudios(ctx, workerCount, r) - t.ExportTags(ctx, workerCount, r) + t.ExportScenes(ctx, workerCount) + t.ExportImages(ctx, workerCount) + t.ExportGalleries(ctx, workerCount) + t.ExportMovies(ctx, workerCount) + t.ExportPerformers(ctx, workerCount) + t.ExportStudios(ctx, workerCount) + t.ExportTags(ctx, workerCount) return nil }) @@ -277,9 +275,10 @@ func (t *ExportTask) zipFile(fn, outDir string, z *zip.Writer) error { return nil } -func (t *ExportTask) populateMovieScenes(ctx context.Context, repo Repository) { - reader := repo.Movie - sceneReader := repo.Scene +func (t *ExportTask) populateMovieScenes(ctx context.Context) { + r := t.repository + reader := r.Movie + sceneReader := r.Scene var movies []*models.Movie var err error @@ -307,9 +306,10 @@ func (t *ExportTask) populateMovieScenes(ctx context.Context, repo Repository) { } } -func (t *ExportTask) populateGalleryImages(ctx context.Context, repo Repository) { - reader := repo.Gallery - imageReader := repo.Image +func (t *ExportTask) populateGalleryImages(ctx context.Context) { + r := t.repository + reader := r.Gallery + imageReader := r.Image var galleries []*models.Gallery var err error @@ -342,10 +342,10 @@ func (t *ExportTask) populateGalleryImages(ctx context.Context, repo Repository) } } -func (t *ExportTask) ExportScenes(ctx context.Context, workers int, repo Repository) { +func (t *ExportTask) ExportScenes(ctx context.Context, workers int) { var scenesWg sync.WaitGroup - sceneReader := repo.Scene + sceneReader := t.repository.Scene var scenes []*models.Scene var err error @@ -367,7 +367,7 @@ func (t *ExportTask) ExportScenes(ctx context.Context, workers int, repo Reposit for w := 0; w < workers; w++ { // create export Scene workers scenesWg.Add(1) - go exportScene(ctx, &scenesWg, jobCh, repo, t) + go t.exportScene(ctx, &scenesWg, jobCh) } for i, scene := range scenes { @@ -385,7 +385,7 @@ func (t *ExportTask) ExportScenes(ctx context.Context, workers int, repo Reposit logger.Infof("[scenes] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func exportFile(f models.File, t *ExportTask) { +func (t *ExportTask) exportFile(f models.File) { newFileJSON := fileToJSON(f) fn := newFileJSON.Filename() @@ -449,7 +449,7 @@ func fileToJSON(f models.File) jsonschema.DirEntry { return &base } -func exportFolder(f models.Folder, t *ExportTask) { +func (t *ExportTask) exportFolder(f models.Folder) { newFileJSON := folderToJSON(f) fn := newFileJSON.Filename() @@ -475,15 +475,17 @@ func folderToJSON(f models.Folder) jsonschema.DirEntry { return &base } -func exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo Repository, t *ExportTask) { +func (t *ExportTask) exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Scene) { defer wg.Done() - sceneReader := repo.Scene - studioReader := repo.Studio - movieReader := repo.Movie - galleryReader := repo.Gallery - performerReader := repo.Performer - tagReader := repo.Tag - sceneMarkerReader := repo.SceneMarker + + r := t.repository + sceneReader := r.Scene + studioReader := r.Studio + movieReader := r.Movie + galleryReader := r.Gallery + performerReader := r.Performer + tagReader := r.Tag + sceneMarkerReader := r.SceneMarker for s := range jobChan { sceneHash := s.GetHash(t.fileNamingAlgorithm) @@ -500,7 +502,7 @@ func exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models // export files for _, f := range s.Files.List() { - exportFile(f, t) + t.exportFile(f) } newSceneJSON.Studio, err = scene.GetStudioName(ctx, studioReader, s) @@ -589,10 +591,11 @@ func exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models } } -func (t *ExportTask) ExportImages(ctx context.Context, workers int, repo Repository) { +func (t *ExportTask) ExportImages(ctx context.Context, workers int) { var imagesWg sync.WaitGroup - imageReader := repo.Image + r := t.repository + imageReader := r.Image var images []*models.Image var err error @@ -614,7 +617,7 @@ func (t *ExportTask) ExportImages(ctx context.Context, workers int, repo Reposit for w := 0; w < workers; w++ { // create export Image workers imagesWg.Add(1) - go exportImage(ctx, &imagesWg, jobCh, repo, t) + go t.exportImage(ctx, &imagesWg, jobCh) } for i, image := range images { @@ -632,22 +635,24 @@ func (t *ExportTask) ExportImages(ctx context.Context, workers int, repo Reposit logger.Infof("[images] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Image, repo Repository, t *ExportTask) { +func (t *ExportTask) exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Image) { defer wg.Done() - studioReader := repo.Studio - galleryReader := repo.Gallery - performerReader := repo.Performer - tagReader := repo.Tag + + r := t.repository + studioReader := r.Studio + galleryReader := r.Gallery + performerReader := r.Performer + tagReader := r.Tag for s := range jobChan { imageHash := s.Checksum - if err := s.LoadFiles(ctx, repo.Image); err != nil { + if err := s.LoadFiles(ctx, r.Image); err != nil { logger.Errorf("[images] <%s> error getting image files: %s", imageHash, err.Error()) continue } - if err := s.LoadURLs(ctx, repo.Image); err != nil { + if err := s.LoadURLs(ctx, r.Image); err != nil { logger.Errorf("[images] <%s> error getting image urls: %s", imageHash, err.Error()) continue } @@ -656,7 +661,7 @@ func exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models // export files for _, f := range s.Files.List() { - exportFile(f, t) + t.exportFile(f) } var err error @@ -715,10 +720,10 @@ func exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models } } -func (t *ExportTask) ExportGalleries(ctx context.Context, workers int, repo Repository) { +func (t *ExportTask) ExportGalleries(ctx context.Context, workers int) { var galleriesWg sync.WaitGroup - reader := repo.Gallery + reader := t.repository.Gallery var galleries []*models.Gallery var err error @@ -740,7 +745,7 @@ func (t *ExportTask) ExportGalleries(ctx context.Context, workers int, repo Repo for w := 0; w < workers; w++ { // create export Scene workers galleriesWg.Add(1) - go exportGallery(ctx, &galleriesWg, jobCh, repo, t) + go t.exportGallery(ctx, &galleriesWg, jobCh) } for i, gallery := range galleries { @@ -759,15 +764,17 @@ func (t *ExportTask) ExportGalleries(ctx context.Context, workers int, repo Repo logger.Infof("[galleries] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo Repository, t *ExportTask) { +func (t *ExportTask) exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Gallery) { defer wg.Done() - studioReader := repo.Studio - performerReader := repo.Performer - tagReader := repo.Tag - galleryChapterReader := repo.GalleryChapter + + r := t.repository + studioReader := r.Studio + performerReader := r.Performer + tagReader := r.Tag + galleryChapterReader := r.GalleryChapter for g := range jobChan { - if err := g.LoadFiles(ctx, repo.Gallery); err != nil { + if err := g.LoadFiles(ctx, r.Gallery); err != nil { logger.Errorf("[galleries] <%s> failed to fetch files for gallery: %s", g.DisplayName(), err.Error()) continue } @@ -782,12 +789,12 @@ func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *mode // export files for _, f := range g.Files.List() { - exportFile(f, t) + t.exportFile(f) } // export folder if necessary if g.FolderID != nil { - folder, err := repo.Folder.Find(ctx, *g.FolderID) + folder, err := r.Folder.Find(ctx, *g.FolderID) if err != nil { logger.Errorf("[galleries] <%s> error getting gallery folder: %v", galleryHash, err) continue @@ -798,7 +805,7 @@ func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *mode continue } - exportFolder(*folder, t) + t.exportFolder(*folder) } newGalleryJSON.Studio, err = gallery.GetStudioName(ctx, studioReader, g) @@ -857,10 +864,10 @@ func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *mode } } -func (t *ExportTask) ExportPerformers(ctx context.Context, workers int, repo Repository) { +func (t *ExportTask) ExportPerformers(ctx context.Context, workers int) { var performersWg sync.WaitGroup - reader := repo.Performer + reader := t.repository.Performer var performers []*models.Performer var err error all := t.full || (t.performers != nil && t.performers.all) @@ -880,7 +887,7 @@ func (t *ExportTask) ExportPerformers(ctx context.Context, workers int, repo Rep for w := 0; w < workers; w++ { // create export Performer workers performersWg.Add(1) - go t.exportPerformer(ctx, &performersWg, jobCh, repo) + go t.exportPerformer(ctx, &performersWg, jobCh) } for i, performer := range performers { @@ -896,10 +903,11 @@ func (t *ExportTask) ExportPerformers(ctx context.Context, workers int, repo Rep logger.Infof("[performers] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Performer, repo Repository) { +func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Performer) { defer wg.Done() - performerReader := repo.Performer + r := t.repository + performerReader := r.Performer for p := range jobChan { newPerformerJSON, err := performer.ToJSON(ctx, performerReader, p) @@ -909,7 +917,7 @@ func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jo continue } - tags, err := repo.Tag.FindByPerformerID(ctx, p.ID) + tags, err := r.Tag.FindByPerformerID(ctx, p.ID) if err != nil { logger.Errorf("[performers] <%s> error getting performer tags: %s", p.Name, err.Error()) continue @@ -929,10 +937,10 @@ func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jo } } -func (t *ExportTask) ExportStudios(ctx context.Context, workers int, repo Repository) { +func (t *ExportTask) ExportStudios(ctx context.Context, workers int) { var studiosWg sync.WaitGroup - reader := repo.Studio + reader := t.repository.Studio var studios []*models.Studio var err error all := t.full || (t.studios != nil && t.studios.all) @@ -953,7 +961,7 @@ func (t *ExportTask) ExportStudios(ctx context.Context, workers int, repo Reposi for w := 0; w < workers; w++ { // create export Studio workers studiosWg.Add(1) - go t.exportStudio(ctx, &studiosWg, jobCh, repo) + go t.exportStudio(ctx, &studiosWg, jobCh) } for i, studio := range studios { @@ -969,10 +977,10 @@ func (t *ExportTask) ExportStudios(ctx context.Context, workers int, repo Reposi logger.Infof("[studios] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func (t *ExportTask) exportStudio(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Studio, repo Repository) { +func (t *ExportTask) exportStudio(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Studio) { defer wg.Done() - studioReader := repo.Studio + studioReader := t.repository.Studio for s := range jobChan { newStudioJSON, err := studio.ToJSON(ctx, studioReader, s) @@ -990,10 +998,10 @@ func (t *ExportTask) exportStudio(ctx context.Context, wg *sync.WaitGroup, jobCh } } -func (t *ExportTask) ExportTags(ctx context.Context, workers int, repo Repository) { +func (t *ExportTask) ExportTags(ctx context.Context, workers int) { var tagsWg sync.WaitGroup - reader := repo.Tag + reader := t.repository.Tag var tags []*models.Tag var err error all := t.full || (t.tags != nil && t.tags.all) @@ -1014,7 +1022,7 @@ func (t *ExportTask) ExportTags(ctx context.Context, workers int, repo Repositor for w := 0; w < workers; w++ { // create export Tag workers tagsWg.Add(1) - go t.exportTag(ctx, &tagsWg, jobCh, repo) + go t.exportTag(ctx, &tagsWg, jobCh) } for i, tag := range tags { @@ -1030,10 +1038,10 @@ func (t *ExportTask) ExportTags(ctx context.Context, workers int, repo Repositor logger.Infof("[tags] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func (t *ExportTask) exportTag(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Tag, repo Repository) { +func (t *ExportTask) exportTag(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Tag) { defer wg.Done() - tagReader := repo.Tag + tagReader := t.repository.Tag for thisTag := range jobChan { newTagJSON, err := tag.ToJSON(ctx, tagReader, thisTag) @@ -1051,10 +1059,10 @@ func (t *ExportTask) exportTag(ctx context.Context, wg *sync.WaitGroup, jobChan } } -func (t *ExportTask) ExportMovies(ctx context.Context, workers int, repo Repository) { +func (t *ExportTask) ExportMovies(ctx context.Context, workers int) { var moviesWg sync.WaitGroup - reader := repo.Movie + reader := t.repository.Movie var movies []*models.Movie var err error all := t.full || (t.movies != nil && t.movies.all) @@ -1075,7 +1083,7 @@ func (t *ExportTask) ExportMovies(ctx context.Context, workers int, repo Reposit for w := 0; w < workers; w++ { // create export Studio workers moviesWg.Add(1) - go t.exportMovie(ctx, &moviesWg, jobCh, repo) + go t.exportMovie(ctx, &moviesWg, jobCh) } for i, movie := range movies { @@ -1091,11 +1099,12 @@ func (t *ExportTask) ExportMovies(ctx context.Context, workers int, repo Reposit logger.Infof("[movies] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func (t *ExportTask) exportMovie(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Movie, repo Repository) { +func (t *ExportTask) exportMovie(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Movie) { defer wg.Done() - movieReader := repo.Movie - studioReader := repo.Studio + r := t.repository + movieReader := r.Movie + studioReader := r.Studio for m := range jobChan { newMovieJSON, err := movie.ToJSON(ctx, movieReader, studioReader, m) diff --git a/internal/manager/task_generate.go b/internal/manager/task_generate.go index ce3d7100028..08d01c311c7 100644 --- a/internal/manager/task_generate.go +++ b/internal/manager/task_generate.go @@ -55,7 +55,7 @@ type GeneratePreviewOptionsInput struct { const generateQueueSize = 200000 type GenerateJob struct { - txnManager Repository + repository models.Repository input GenerateMetadataInput overwrite bool @@ -112,8 +112,9 @@ func (j *GenerateJob) Execute(ctx context.Context, progress *job.Progress) { Overwrite: j.overwrite, } - if err := j.txnManager.WithReadTxn(ctx, func(ctx context.Context) error { - qb := j.txnManager.Scene + r := j.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { + qb := r.Scene if len(j.input.SceneIDs) == 0 && len(j.input.MarkerIDs) == 0 { totals = j.queueTasks(ctx, g, queue) } else { @@ -129,7 +130,7 @@ func (j *GenerateJob) Execute(ctx context.Context, progress *job.Progress) { } if len(j.input.MarkerIDs) > 0 { - markers, err = j.txnManager.SceneMarker.FindMany(ctx, markerIDs) + markers, err = r.SceneMarker.FindMany(ctx, markerIDs) if err != nil { return err } @@ -229,12 +230,14 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que findFilter := models.BatchFindFilter(batchSize) + r := j.repository + for more := true; more; { if job.IsCancelled(ctx) { return totals } - scenes, err := scene.Query(ctx, j.txnManager.Scene, nil, findFilter) + scenes, err := scene.Query(ctx, r.Scene, nil, findFilter) if err != nil { logger.Errorf("Error encountered queuing files to scan: %s", err.Error()) return totals @@ -245,7 +248,7 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que return totals } - if err := ss.LoadFiles(ctx, j.txnManager.Scene); err != nil { + if err := ss.LoadFiles(ctx, r.Scene); err != nil { logger.Errorf("Error encountered queuing files to scan: %s", err.Error()) return totals } @@ -266,7 +269,7 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que return totals } - images, err := image.Query(ctx, j.txnManager.Image, nil, findFilter) + images, err := image.Query(ctx, r.Image, nil, findFilter) if err != nil { logger.Errorf("Error encountered queuing files to scan: %s", err.Error()) return totals @@ -277,7 +280,7 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que return totals } - if err := ss.LoadFiles(ctx, j.txnManager.Image); err != nil { + if err := ss.LoadFiles(ctx, r.Image); err != nil { logger.Errorf("Error encountered queuing files to scan: %s", err.Error()) return totals } @@ -331,9 +334,11 @@ func getGeneratePreviewOptions(optionsInput GeneratePreviewOptionsInput) generat } func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator, scene *models.Scene, queue chan<- Task, totals *totalsGenerate) { + r := j.repository + if j.input.Covers { task := &GenerateCoverTask{ - txnManager: j.txnManager, + repository: r, Scene: *scene, Overwrite: j.overwrite, } @@ -390,7 +395,7 @@ func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator, if j.input.Markers { task := &GenerateMarkersTask{ - TxnManager: j.txnManager, + repository: r, Scene: scene, Overwrite: j.overwrite, fileNamingAlgorithm: j.fileNamingAlgo, @@ -429,10 +434,9 @@ func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator, // generate for all files in scene for _, f := range scene.Files.List() { task := &GeneratePhashTask{ + repository: r, File: f, fileNamingAlgorithm: j.fileNamingAlgo, - txnManager: j.txnManager, - fileUpdater: j.txnManager.File, Overwrite: j.overwrite, } @@ -446,10 +450,10 @@ func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator, if j.input.InteractiveHeatmapsSpeeds { task := &GenerateInteractiveHeatmapSpeedTask{ + repository: r, Scene: *scene, Overwrite: j.overwrite, fileNamingAlgorithm: j.fileNamingAlgo, - TxnManager: j.txnManager, } if task.required() { @@ -462,7 +466,7 @@ func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator, func (j *GenerateJob) queueMarkerJob(g *generate.Generator, marker *models.SceneMarker, queue chan<- Task, totals *totalsGenerate) { task := &GenerateMarkersTask{ - TxnManager: j.txnManager, + repository: j.repository, Marker: marker, Overwrite: j.overwrite, fileNamingAlgorithm: j.fileNamingAlgo, diff --git a/internal/manager/task_generate_interactive_heatmap_speed.go b/internal/manager/task_generate_interactive_heatmap_speed.go index 4f91bd023ea..61350f09c2b 100644 --- a/internal/manager/task_generate_interactive_heatmap_speed.go +++ b/internal/manager/task_generate_interactive_heatmap_speed.go @@ -11,10 +11,10 @@ import ( ) type GenerateInteractiveHeatmapSpeedTask struct { + repository models.Repository Scene models.Scene Overwrite bool fileNamingAlgorithm models.HashAlgorithm - TxnManager Repository } func (t *GenerateInteractiveHeatmapSpeedTask) GetDescription() string { @@ -42,10 +42,11 @@ func (t *GenerateInteractiveHeatmapSpeedTask) Start(ctx context.Context) { median := generator.InteractiveSpeed - if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error { + r := t.repository + if err := r.WithTxn(ctx, func(ctx context.Context) error { primaryFile := t.Scene.Files.Primary() primaryFile.InteractiveSpeed = &median - qb := t.TxnManager.File + qb := r.File return qb.Update(ctx, primaryFile) }); err != nil && ctx.Err() == nil { logger.Error(err.Error()) diff --git a/internal/manager/task_generate_markers.go b/internal/manager/task_generate_markers.go index fa5ac902255..2d792f718f2 100644 --- a/internal/manager/task_generate_markers.go +++ b/internal/manager/task_generate_markers.go @@ -12,7 +12,7 @@ import ( ) type GenerateMarkersTask struct { - TxnManager Repository + repository models.Repository Scene *models.Scene Marker *models.SceneMarker Overwrite bool @@ -41,9 +41,10 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) { if t.Marker != nil { var scene *models.Scene - if err := t.TxnManager.WithReadTxn(ctx, func(ctx context.Context) error { + r := t.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { var err error - scene, err = t.TxnManager.Scene.Find(ctx, t.Marker.SceneID) + scene, err = r.Scene.Find(ctx, t.Marker.SceneID) if err != nil { return err } @@ -51,7 +52,7 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) { return fmt.Errorf("scene with id %d not found", t.Marker.SceneID) } - return scene.LoadPrimaryFile(ctx, t.TxnManager.File) + return scene.LoadPrimaryFile(ctx, r.File) }); err != nil { logger.Errorf("error finding scene for marker generation: %v", err) return @@ -70,9 +71,10 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) { func (t *GenerateMarkersTask) generateSceneMarkers(ctx context.Context) { var sceneMarkers []*models.SceneMarker - if err := t.TxnManager.WithReadTxn(ctx, func(ctx context.Context) error { + r := t.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { var err error - sceneMarkers, err = t.TxnManager.SceneMarker.FindBySceneID(ctx, t.Scene.ID) + sceneMarkers, err = r.SceneMarker.FindBySceneID(ctx, t.Scene.ID) return err }); err != nil { logger.Errorf("error getting scene markers: %s", err.Error()) @@ -129,7 +131,7 @@ func (t *GenerateMarkersTask) generateMarker(videoFile *models.VideoFile, scene func (t *GenerateMarkersTask) markersNeeded(ctx context.Context) int { markers := 0 - sceneMarkers, err := t.TxnManager.SceneMarker.FindBySceneID(ctx, t.Scene.ID) + sceneMarkers, err := t.repository.SceneMarker.FindBySceneID(ctx, t.Scene.ID) if err != nil { logger.Errorf("error finding scene markers: %s", err.Error()) return 0 diff --git a/internal/manager/task_generate_phash.go b/internal/manager/task_generate_phash.go index 9f3945da34c..1d04d6a8aff 100644 --- a/internal/manager/task_generate_phash.go +++ b/internal/manager/task_generate_phash.go @@ -7,15 +7,13 @@ import ( "github.com/stashapp/stash/pkg/hash/videophash" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/txn" ) type GeneratePhashTask struct { + repository models.Repository File *models.VideoFile Overwrite bool fileNamingAlgorithm models.HashAlgorithm - txnManager txn.Manager - fileUpdater models.FileUpdater } func (t *GeneratePhashTask) GetDescription() string { @@ -34,15 +32,15 @@ func (t *GeneratePhashTask) Start(ctx context.Context) { return } - if err := txn.WithTxn(ctx, t.txnManager, func(ctx context.Context) error { - qb := t.fileUpdater + r := t.repository + if err := r.WithTxn(ctx, func(ctx context.Context) error { hashValue := int64(*hash) t.File.Fingerprints = t.File.Fingerprints.AppendUnique(models.Fingerprint{ Type: models.FingerprintTypePhash, Fingerprint: hashValue, }) - return qb.Update(ctx, t.File) + return r.File.Update(ctx, t.File) }); err != nil && ctx.Err() == nil { logger.Errorf("Error setting phash: %v", err) } diff --git a/internal/manager/task_generate_screenshot.go b/internal/manager/task_generate_screenshot.go index 1050ebd1c05..f8bff653c87 100644 --- a/internal/manager/task_generate_screenshot.go +++ b/internal/manager/task_generate_screenshot.go @@ -10,9 +10,9 @@ import ( ) type GenerateCoverTask struct { + repository models.Repository Scene models.Scene ScreenshotAt *float64 - txnManager Repository Overwrite bool } @@ -23,11 +23,13 @@ func (t *GenerateCoverTask) GetDescription() string { func (t *GenerateCoverTask) Start(ctx context.Context) { scenePath := t.Scene.Path + r := t.repository + var required bool - if err := t.txnManager.WithReadTxn(ctx, func(ctx context.Context) error { + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { required = t.required(ctx) - return t.Scene.LoadPrimaryFile(ctx, t.txnManager.File) + return t.Scene.LoadPrimaryFile(ctx, r.File) }); err != nil { logger.Error(err) } @@ -70,8 +72,8 @@ func (t *GenerateCoverTask) Start(ctx context.Context) { return } - if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { - qb := t.txnManager.Scene + if err := r.WithTxn(ctx, func(ctx context.Context) error { + qb := r.Scene scenePartial := models.NewScenePartial() // update the scene cover table @@ -103,7 +105,7 @@ func (t *GenerateCoverTask) required(ctx context.Context) bool { } // if the scene has a cover, then we don't need to generate it - hasCover, err := t.txnManager.Scene.HasCover(ctx, t.Scene.ID) + hasCover, err := t.repository.Scene.HasCover(ctx, t.Scene.ID) if err != nil { logger.Errorf("Error getting cover: %v", err) return false diff --git a/internal/manager/task_identify.go b/internal/manager/task_identify.go index 0022a69ca31..8978e65cb0b 100644 --- a/internal/manager/task_identify.go +++ b/internal/manager/task_identify.go @@ -14,7 +14,6 @@ import ( "github.com/stashapp/stash/pkg/scraper" "github.com/stashapp/stash/pkg/scraper/stashbox" "github.com/stashapp/stash/pkg/sliceutil/stringslice" - "github.com/stashapp/stash/pkg/txn" ) var ErrInput = errors.New("invalid request input") @@ -52,7 +51,8 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) { // if scene ids provided, use those // otherwise, batch query for all scenes - ordering by path // don't use a transaction to query scenes - if err := txn.WithDatabase(ctx, instance.Repository, func(ctx context.Context) error { + r := instance.Repository + if err := r.WithDB(ctx, func(ctx context.Context) error { if len(j.input.SceneIDs) == 0 { return j.identifyAllScenes(ctx, sources) } @@ -70,7 +70,7 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) { // find the scene var err error - scene, err := instance.Repository.Scene.Find(ctx, id) + scene, err := r.Scene.Find(ctx, id) if err != nil { return fmt.Errorf("finding scene id %d: %w", id, err) } @@ -89,6 +89,8 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) { } func (j *IdentifyJob) identifyAllScenes(ctx context.Context, sources []identify.ScraperSource) error { + r := instance.Repository + // exclude organised organised := false sceneFilter := scene.FilterFromPaths(j.input.Paths) @@ -102,7 +104,7 @@ func (j *IdentifyJob) identifyAllScenes(ctx context.Context, sources []identify. // get the count pp := 0 findFilter.PerPage = &pp - countResult, err := instance.Repository.Scene.Query(ctx, models.SceneQueryOptions{ + countResult, err := r.Scene.Query(ctx, models.SceneQueryOptions{ QueryOptions: models.QueryOptions{ FindFilter: findFilter, Count: true, @@ -115,7 +117,7 @@ func (j *IdentifyJob) identifyAllScenes(ctx context.Context, sources []identify. j.progress.SetTotal(countResult.Count) - return scene.BatchProcess(ctx, instance.Repository.Scene, sceneFilter, findFilter, func(scene *models.Scene) error { + return scene.BatchProcess(ctx, r.Scene, sceneFilter, findFilter, func(scene *models.Scene) error { if job.IsCancelled(ctx) { return nil } @@ -132,18 +134,19 @@ func (j *IdentifyJob) identifyScene(ctx context.Context, s *models.Scene, source var taskError error j.progress.ExecuteTask("Identifying "+s.Path, func() { + r := instance.Repository task := identify.SceneIdentifier{ - SceneReaderUpdater: instance.Repository.Scene, - StudioReaderWriter: instance.Repository.Studio, - PerformerCreator: instance.Repository.Performer, - TagFinderCreator: instance.Repository.Tag, + SceneReaderUpdater: r.Scene, + StudioReaderWriter: r.Studio, + PerformerCreator: r.Performer, + TagFinderCreator: r.Tag, DefaultOptions: j.input.Options, Sources: sources, SceneUpdatePostHookExecutor: j.postHookExecutor, } - taskError = task.Identify(ctx, instance.Repository, s) + taskError = task.Identify(ctx, r.TxnManager, s) }) if taskError != nil { @@ -164,14 +167,15 @@ func (j *IdentifyJob) getSources() ([]identify.ScraperSource, error) { var src identify.ScraperSource if stashBox != nil { + r := instance.Repository src = identify.ScraperSource{ Name: "stash-box: " + stashBox.Endpoint, Scraper: stashboxSource{ - stashbox.NewClient(*stashBox, instance.Repository, stashbox.Repository{ - Scene: instance.Repository.Scene, - Performer: instance.Repository.Performer, - Tag: instance.Repository.Tag, - Studio: instance.Repository.Studio, + stashbox.NewClient(*stashBox, r.TxnManager, stashbox.Repository{ + Scene: r.Scene, + Performer: r.Performer, + Tag: r.Tag, + Studio: r.Studio, }), stashBox.Endpoint, }, diff --git a/internal/manager/task_import.go b/internal/manager/task_import.go index c0f97e254ae..5b3e30fe77d 100644 --- a/internal/manager/task_import.go +++ b/internal/manager/task_import.go @@ -25,8 +25,13 @@ import ( "github.com/stashapp/stash/pkg/tag" ) +type Resetter interface { + Reset() error +} + type ImportTask struct { - txnManager Repository + repository models.Repository + resetter Resetter json jsonUtils BaseDir string @@ -66,8 +71,10 @@ func CreateImportTask(a models.HashAlgorithm, input ImportObjectsInput) (*Import } } + mgr := GetInstance() return &ImportTask{ - txnManager: GetInstance().Repository, + repository: mgr.Repository, + resetter: mgr.Database, BaseDir: baseDir, TmpZip: tmpZip, Reset: false, @@ -109,7 +116,7 @@ func (t *ImportTask) Start(ctx context.Context) { } if t.Reset { - err := t.txnManager.Reset() + err := t.resetter.Reset() if err != nil { logger.Errorf("Error resetting database: %s", err.Error()) @@ -194,6 +201,8 @@ func (t *ImportTask) ImportPerformers(ctx context.Context) { return } + r := t.repository + for i, fi := range files { index := i + 1 performerJSON, err := jsonschema.LoadPerformerFile(filepath.Join(path, fi.Name())) @@ -204,11 +213,9 @@ func (t *ImportTask) ImportPerformers(ctx context.Context) { logger.Progressf("[performers] %d of %d", index, len(files)) - if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { - r := t.txnManager - readerWriter := r.Performer + if err := r.WithTxn(ctx, func(ctx context.Context) error { importer := &performer.Importer{ - ReaderWriter: readerWriter, + ReaderWriter: r.Performer, TagWriter: r.Tag, Input: *performerJSON, } @@ -237,6 +244,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) { return } + r := t.repository + for i, fi := range files { index := i + 1 studioJSON, err := jsonschema.LoadStudioFile(filepath.Join(path, fi.Name())) @@ -247,8 +256,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) { logger.Progressf("[studios] %d of %d", index, len(files)) - if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { - return t.ImportStudio(ctx, studioJSON, pendingParent, t.txnManager.Studio) + if err := r.WithTxn(ctx, func(ctx context.Context) error { + return t.importStudio(ctx, studioJSON, pendingParent) }); err != nil { if errors.Is(err, studio.ErrParentStudioNotExist) { // add to the pending parent list so that it is created after the parent @@ -269,8 +278,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) { for _, s := range pendingParent { for _, orphanStudioJSON := range s { - if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { - return t.ImportStudio(ctx, orphanStudioJSON, nil, t.txnManager.Studio) + if err := r.WithTxn(ctx, func(ctx context.Context) error { + return t.importStudio(ctx, orphanStudioJSON, nil) }); err != nil { logger.Errorf("[studios] <%s> failed to create: %s", orphanStudioJSON.Name, err.Error()) continue @@ -282,9 +291,9 @@ func (t *ImportTask) ImportStudios(ctx context.Context) { logger.Info("[studios] import complete") } -func (t *ImportTask) ImportStudio(ctx context.Context, studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, readerWriter studio.ImporterReaderWriter) error { +func (t *ImportTask) importStudio(ctx context.Context, studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio) error { importer := &studio.Importer{ - ReaderWriter: readerWriter, + ReaderWriter: t.repository.Studio, Input: *studioJSON, MissingRefBehaviour: t.MissingRefBehaviour, } @@ -302,7 +311,7 @@ func (t *ImportTask) ImportStudio(ctx context.Context, studioJSON *jsonschema.St s := pendingParent[studioJSON.Name] for _, childStudioJSON := range s { // map is nil since we're not checking parent studios at this point - if err := t.ImportStudio(ctx, childStudioJSON, nil, readerWriter); err != nil { + if err := t.importStudio(ctx, childStudioJSON, nil); err != nil { return fmt.Errorf("failed to create child studio <%s>: %s", childStudioJSON.Name, err.Error()) } } @@ -326,6 +335,8 @@ func (t *ImportTask) ImportMovies(ctx context.Context) { return } + r := t.repository + for i, fi := range files { index := i + 1 movieJSON, err := jsonschema.LoadMovieFile(filepath.Join(path, fi.Name())) @@ -336,14 +347,10 @@ func (t *ImportTask) ImportMovies(ctx context.Context) { logger.Progressf("[movies] %d of %d", index, len(files)) - if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { - r := t.txnManager - readerWriter := r.Movie - studioReaderWriter := r.Studio - + if err := r.WithTxn(ctx, func(ctx context.Context) error { movieImporter := &movie.Importer{ - ReaderWriter: readerWriter, - StudioWriter: studioReaderWriter, + ReaderWriter: r.Movie, + StudioWriter: r.Studio, Input: *movieJSON, MissingRefBehaviour: t.MissingRefBehaviour, } @@ -371,6 +378,8 @@ func (t *ImportTask) ImportFiles(ctx context.Context) { return } + r := t.repository + pendingParent := make(map[string][]jsonschema.DirEntry) for i, fi := range files { @@ -383,8 +392,8 @@ func (t *ImportTask) ImportFiles(ctx context.Context) { logger.Progressf("[files] %d of %d", index, len(files)) - if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { - return t.ImportFile(ctx, fileJSON, pendingParent) + if err := r.WithTxn(ctx, func(ctx context.Context) error { + return t.importFile(ctx, fileJSON, pendingParent) }); err != nil { if errors.Is(err, file.ErrZipFileNotExist) { // add to the pending parent list so that it is created after the parent @@ -405,8 +414,8 @@ func (t *ImportTask) ImportFiles(ctx context.Context) { for _, s := range pendingParent { for _, orphanFileJSON := range s { - if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { - return t.ImportFile(ctx, orphanFileJSON, nil) + if err := r.WithTxn(ctx, func(ctx context.Context) error { + return t.importFile(ctx, orphanFileJSON, nil) }); err != nil { logger.Errorf("[files] <%s> failed to create: %s", orphanFileJSON.DirEntry().Path, err.Error()) continue @@ -418,12 +427,11 @@ func (t *ImportTask) ImportFiles(ctx context.Context) { logger.Info("[files] import complete") } -func (t *ImportTask) ImportFile(ctx context.Context, fileJSON jsonschema.DirEntry, pendingParent map[string][]jsonschema.DirEntry) error { - r := t.txnManager - readerWriter := r.File +func (t *ImportTask) importFile(ctx context.Context, fileJSON jsonschema.DirEntry, pendingParent map[string][]jsonschema.DirEntry) error { + r := t.repository fileImporter := &file.Importer{ - ReaderWriter: readerWriter, + ReaderWriter: r.File, FolderStore: r.Folder, Input: fileJSON, } @@ -437,7 +445,7 @@ func (t *ImportTask) ImportFile(ctx context.Context, fileJSON jsonschema.DirEntr s := pendingParent[fileJSON.DirEntry().Path] for _, childFileJSON := range s { // map is nil since we're not checking parent studios at this point - if err := t.ImportFile(ctx, childFileJSON, nil); err != nil { + if err := t.importFile(ctx, childFileJSON, nil); err != nil { return fmt.Errorf("failed to create child file <%s>: %s", childFileJSON.DirEntry().Path, err.Error()) } } @@ -461,6 +469,8 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) { return } + r := t.repository + for i, fi := range files { index := i + 1 galleryJSON, err := jsonschema.LoadGalleryFile(filepath.Join(path, fi.Name())) @@ -471,21 +481,14 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) { logger.Progressf("[galleries] %d of %d", index, len(files)) - if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { - r := t.txnManager - readerWriter := r.Gallery - tagWriter := r.Tag - performerWriter := r.Performer - studioWriter := r.Studio - chapterWriter := r.GalleryChapter - + if err := r.WithTxn(ctx, func(ctx context.Context) error { galleryImporter := &gallery.Importer{ - ReaderWriter: readerWriter, + ReaderWriter: r.Gallery, FolderFinder: r.Folder, FileFinder: r.File, - PerformerWriter: performerWriter, - StudioWriter: studioWriter, - TagWriter: tagWriter, + PerformerWriter: r.Performer, + StudioWriter: r.Studio, + TagWriter: r.Tag, Input: *galleryJSON, MissingRefBehaviour: t.MissingRefBehaviour, } @@ -500,7 +503,7 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) { GalleryID: galleryImporter.ID, Input: m, MissingRefBehaviour: t.MissingRefBehaviour, - ReaderWriter: chapterWriter, + ReaderWriter: r.GalleryChapter, } if err := performImport(ctx, chapterImporter, t.DuplicateBehaviour); err != nil { @@ -532,6 +535,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) { return } + r := t.repository + for i, fi := range files { index := i + 1 tagJSON, err := jsonschema.LoadTagFile(filepath.Join(path, fi.Name())) @@ -542,8 +547,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) { logger.Progressf("[tags] %d of %d", index, len(files)) - if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { - return t.ImportTag(ctx, tagJSON, pendingParent, false, t.txnManager.Tag) + if err := r.WithTxn(ctx, func(ctx context.Context) error { + return t.importTag(ctx, tagJSON, pendingParent, false) }); err != nil { var parentError tag.ParentTagNotExistError if errors.As(err, &parentError) { @@ -558,8 +563,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) { for _, s := range pendingParent { for _, orphanTagJSON := range s { - if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { - return t.ImportTag(ctx, orphanTagJSON, nil, true, t.txnManager.Tag) + if err := r.WithTxn(ctx, func(ctx context.Context) error { + return t.importTag(ctx, orphanTagJSON, nil, true) }); err != nil { logger.Errorf("[tags] <%s> failed to create: %s", orphanTagJSON.Name, err.Error()) continue @@ -570,9 +575,9 @@ func (t *ImportTask) ImportTags(ctx context.Context) { logger.Info("[tags] import complete") } -func (t *ImportTask) ImportTag(ctx context.Context, tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool, readerWriter tag.ImporterReaderWriter) error { +func (t *ImportTask) importTag(ctx context.Context, tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool) error { importer := &tag.Importer{ - ReaderWriter: readerWriter, + ReaderWriter: t.repository.Tag, Input: *tagJSON, MissingRefBehaviour: t.MissingRefBehaviour, } @@ -587,7 +592,7 @@ func (t *ImportTask) ImportTag(ctx context.Context, tagJSON *jsonschema.Tag, pen } for _, childTagJSON := range pendingParent[tagJSON.Name] { - if err := t.ImportTag(ctx, childTagJSON, pendingParent, fail, readerWriter); err != nil { + if err := t.importTag(ctx, childTagJSON, pendingParent, fail); err != nil { var parentError tag.ParentTagNotExistError if errors.As(err, &parentError) { pendingParent[parentError.MissingParent()] = append(pendingParent[parentError.MissingParent()], childTagJSON) @@ -616,6 +621,8 @@ func (t *ImportTask) ImportScenes(ctx context.Context) { return } + r := t.repository + for i, fi := range files { index := i + 1 @@ -627,29 +634,20 @@ func (t *ImportTask) ImportScenes(ctx context.Context) { continue } - if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { - r := t.txnManager - readerWriter := r.Scene - tagWriter := r.Tag - galleryWriter := r.Gallery - movieWriter := r.Movie - performerWriter := r.Performer - studioWriter := r.Studio - markerWriter := r.SceneMarker - + if err := r.WithTxn(ctx, func(ctx context.Context) error { sceneImporter := &scene.Importer{ - ReaderWriter: readerWriter, + ReaderWriter: r.Scene, Input: *sceneJSON, FileFinder: r.File, FileNamingAlgorithm: t.fileNamingAlgorithm, MissingRefBehaviour: t.MissingRefBehaviour, - GalleryFinder: galleryWriter, - MovieWriter: movieWriter, - PerformerWriter: performerWriter, - StudioWriter: studioWriter, - TagWriter: tagWriter, + GalleryFinder: r.Gallery, + MovieWriter: r.Movie, + PerformerWriter: r.Performer, + StudioWriter: r.Studio, + TagWriter: r.Tag, } if err := performImport(ctx, sceneImporter, t.DuplicateBehaviour); err != nil { @@ -662,8 +660,8 @@ func (t *ImportTask) ImportScenes(ctx context.Context) { SceneID: sceneImporter.ID, Input: m, MissingRefBehaviour: t.MissingRefBehaviour, - ReaderWriter: markerWriter, - TagWriter: tagWriter, + ReaderWriter: r.SceneMarker, + TagWriter: r.Tag, } if err := performImport(ctx, markerImporter, t.DuplicateBehaviour); err != nil { @@ -693,6 +691,8 @@ func (t *ImportTask) ImportImages(ctx context.Context) { return } + r := t.repository + for i, fi := range files { index := i + 1 @@ -704,25 +704,18 @@ func (t *ImportTask) ImportImages(ctx context.Context) { continue } - if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { - r := t.txnManager - readerWriter := r.Image - tagWriter := r.Tag - galleryWriter := r.Gallery - performerWriter := r.Performer - studioWriter := r.Studio - + if err := r.WithTxn(ctx, func(ctx context.Context) error { imageImporter := &image.Importer{ - ReaderWriter: readerWriter, + ReaderWriter: r.Image, FileFinder: r.File, Input: *imageJSON, MissingRefBehaviour: t.MissingRefBehaviour, - GalleryFinder: galleryWriter, - PerformerWriter: performerWriter, - StudioWriter: studioWriter, - TagWriter: tagWriter, + GalleryFinder: r.Gallery, + PerformerWriter: r.Performer, + StudioWriter: r.Studio, + TagWriter: r.Tag, } return performImport(ctx, imageImporter, t.DuplicateBehaviour) diff --git a/internal/manager/task_scan.go b/internal/manager/task_scan.go index f1f3e39272f..26985e86fbc 100644 --- a/internal/manager/task_scan.go +++ b/internal/manager/task_scan.go @@ -19,6 +19,7 @@ import ( "github.com/stashapp/stash/pkg/job" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/models/paths" "github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/scene/generate" "github.com/stashapp/stash/pkg/txn" @@ -48,10 +49,14 @@ func (j *ScanJob) Execute(ctx context.Context, progress *job.Progress) { paths[i] = p.Path } + mgr := GetInstance() + c := mgr.Config + repo := mgr.Repository + start := time.Now() const taskQueueSize = 200000 - taskQueue := job.NewTaskQueue(ctx, progress, taskQueueSize, instance.Config.GetParallelTasksWithAutoDetection()) + taskQueue := job.NewTaskQueue(ctx, progress, taskQueueSize, c.GetParallelTasksWithAutoDetection()) var minModTime time.Time if j.input.Filter != nil && j.input.Filter.MinModTime != nil { @@ -59,13 +64,11 @@ func (j *ScanJob) Execute(ctx context.Context, progress *job.Progress) { } j.scanner.Scan(ctx, getScanHandlers(j.input, taskQueue, progress), file.ScanOptions{ - Paths: paths, - ScanFilters: []file.PathFilter{newScanFilter(instance.Config, minModTime)}, - ZipFileExtensions: instance.Config.GetGalleryExtensions(), - ParallelTasks: instance.Config.GetParallelTasksWithAutoDetection(), - HandlerRequiredFilters: []file.Filter{ - newHandlerRequiredFilter(instance.Config), - }, + Paths: paths, + ScanFilters: []file.PathFilter{newScanFilter(c, repo, minModTime)}, + ZipFileExtensions: c.GetGalleryExtensions(), + ParallelTasks: c.GetParallelTasksWithAutoDetection(), + HandlerRequiredFilters: []file.Filter{newHandlerRequiredFilter(c, repo)}, }, progress) taskQueue.Close() @@ -123,17 +126,16 @@ type handlerRequiredFilter struct { videoFileNamingAlgorithm models.HashAlgorithm } -func newHandlerRequiredFilter(c *config.Instance) *handlerRequiredFilter { - db := instance.Database +func newHandlerRequiredFilter(c *config.Instance, repo models.Repository) *handlerRequiredFilter { processes := c.GetParallelTasksWithAutoDetection() return &handlerRequiredFilter{ extensionConfig: newExtensionConfig(c), - txnManager: db, - SceneFinder: db.Scene, - ImageFinder: db.Image, - GalleryFinder: db.Gallery, - CaptionUpdater: db.File, + txnManager: repo.TxnManager, + SceneFinder: repo.Scene, + ImageFinder: repo.Image, + GalleryFinder: repo.Gallery, + CaptionUpdater: repo.File, FolderCache: lru.New(processes * 2), videoFileNamingAlgorithm: c.GetVideoFileNamingAlgorithm(), } @@ -226,6 +228,10 @@ func (f *handlerRequiredFilter) Accept(ctx context.Context, ff models.File) bool type scanFilter struct { extensionConfig + txnManager txn.Manager + FileFinder models.FileFinder + CaptionUpdater video.CaptionUpdater + stashPaths config.StashConfigs generatedPath string videoExcludeRegex []*regexp.Regexp @@ -233,9 +239,12 @@ type scanFilter struct { minModTime time.Time } -func newScanFilter(c *config.Instance, minModTime time.Time) *scanFilter { +func newScanFilter(c *config.Instance, repo models.Repository, minModTime time.Time) *scanFilter { return &scanFilter{ extensionConfig: newExtensionConfig(c), + txnManager: repo.TxnManager, + FileFinder: repo.File, + CaptionUpdater: repo.File, stashPaths: c.GetStashPaths(), generatedPath: c.GetGeneratedPath(), videoExcludeRegex: generateRegexps(c.GetExcludes()), @@ -263,7 +272,7 @@ func (f *scanFilter) Accept(ctx context.Context, path string, info fs.FileInfo) if fsutil.MatchExtension(path, video.CaptionExts) { // we don't include caption files in the file scan, but we do need // to handle them - video.AssociateCaptions(ctx, path, instance.Repository, instance.Database.File, instance.Database.File) + video.AssociateCaptions(ctx, path, f.txnManager, f.FileFinder, f.CaptionUpdater) return false } @@ -308,30 +317,37 @@ func (f *scanFilter) Accept(ctx context.Context, path string, info fs.FileInfo) type scanConfig struct { isGenerateThumbnails bool isGenerateClipPreviews bool + + createGalleriesFromFolders bool } func (c *scanConfig) GetCreateGalleriesFromFolders() bool { - return instance.Config.GetCreateGalleriesFromFolders() + return c.createGalleriesFromFolders } func getScanHandlers(options ScanMetadataInput, taskQueue *job.TaskQueue, progress *job.Progress) []file.Handler { - db := instance.Database - pluginCache := instance.PluginCache + mgr := GetInstance() + c := mgr.Config + r := mgr.Repository + pluginCache := mgr.PluginCache return []file.Handler{ &file.FilteredHandler{ Filter: file.FilterFunc(imageFileFilter), Handler: &image.ScanHandler{ - CreatorUpdater: db.Image, - GalleryFinder: db.Gallery, + CreatorUpdater: r.Image, + GalleryFinder: r.Gallery, ScanGenerator: &imageGenerators{ - input: options, - taskQueue: taskQueue, - progress: progress, + input: options, + taskQueue: taskQueue, + progress: progress, + paths: mgr.Paths, + sequentialScanning: c.GetSequentialScanning(), }, ScanConfig: &scanConfig{ - isGenerateThumbnails: options.ScanGenerateThumbnails, - isGenerateClipPreviews: options.ScanGenerateClipPreviews, + isGenerateThumbnails: options.ScanGenerateThumbnails, + isGenerateClipPreviews: options.ScanGenerateClipPreviews, + createGalleriesFromFolders: c.GetCreateGalleriesFromFolders(), }, PluginCache: pluginCache, Paths: instance.Paths, @@ -340,25 +356,28 @@ func getScanHandlers(options ScanMetadataInput, taskQueue *job.TaskQueue, progre &file.FilteredHandler{ Filter: file.FilterFunc(galleryFileFilter), Handler: &gallery.ScanHandler{ - CreatorUpdater: db.Gallery, - SceneFinderUpdater: db.Scene, - ImageFinderUpdater: db.Image, + CreatorUpdater: r.Gallery, + SceneFinderUpdater: r.Scene, + ImageFinderUpdater: r.Image, PluginCache: pluginCache, }, }, &file.FilteredHandler{ Filter: file.FilterFunc(videoFileFilter), Handler: &scene.ScanHandler{ - CreatorUpdater: db.Scene, + CreatorUpdater: r.Scene, + CaptionUpdater: r.File, PluginCache: pluginCache, - CaptionUpdater: db.File, ScanGenerator: &sceneGenerators{ - input: options, - taskQueue: taskQueue, - progress: progress, + input: options, + taskQueue: taskQueue, + progress: progress, + paths: mgr.Paths, + fileNamingAlgorithm: c.GetVideoFileNamingAlgorithm(), + sequentialScanning: c.GetSequentialScanning(), }, - FileNamingAlgorithm: instance.Config.GetVideoFileNamingAlgorithm(), - Paths: instance.Paths, + FileNamingAlgorithm: c.GetVideoFileNamingAlgorithm(), + Paths: mgr.Paths, }, }, } @@ -368,6 +387,9 @@ type imageGenerators struct { input ScanMetadataInput taskQueue *job.TaskQueue progress *job.Progress + + paths *paths.Paths + sequentialScanning bool } func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f models.File) error { @@ -376,8 +398,6 @@ func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f model progress := g.progress t := g.input path := f.Base().Path - config := instance.Config - sequentialScanning := config.GetSequentialScanning() if t.ScanGenerateThumbnails { // this should be quick, so always generate sequentially @@ -405,7 +425,7 @@ func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f model progress.Increment() } - if sequentialScanning { + if g.sequentialScanning { previewsFn(ctx) } else { g.taskQueue.Add(fmt.Sprintf("Generating preview for %s", path), previewsFn) @@ -416,7 +436,7 @@ func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f model } func (g *imageGenerators) generateThumbnail(ctx context.Context, i *models.Image, f models.File) error { - thumbPath := GetInstance().Paths.Generated.GetThumbnailPath(i.Checksum, models.DefaultGthumbWidth) + thumbPath := g.paths.Generated.GetThumbnailPath(i.Checksum, models.DefaultGthumbWidth) exists, _ := fsutil.FileExists(thumbPath) if exists { return nil @@ -435,13 +455,16 @@ func (g *imageGenerators) generateThumbnail(ctx context.Context, i *models.Image logger.Debugf("Generating thumbnail for %s", path) + mgr := GetInstance() + c := mgr.Config + clipPreviewOptions := image.ClipPreviewOptions{ - InputArgs: instance.Config.GetTranscodeInputArgs(), - OutputArgs: instance.Config.GetTranscodeOutputArgs(), - Preset: instance.Config.GetPreviewPreset().String(), + InputArgs: c.GetTranscodeInputArgs(), + OutputArgs: c.GetTranscodeOutputArgs(), + Preset: c.GetPreviewPreset().String(), } - encoder := image.NewThumbnailEncoder(instance.FFMPEG, instance.FFProbe, clipPreviewOptions) + encoder := image.NewThumbnailEncoder(mgr.FFMPEG, mgr.FFProbe, clipPreviewOptions) data, err := encoder.GetThumbnail(f, models.DefaultGthumbWidth) if err != nil { @@ -464,6 +487,10 @@ type sceneGenerators struct { input ScanMetadataInput taskQueue *job.TaskQueue progress *job.Progress + + paths *paths.Paths + fileNamingAlgorithm models.HashAlgorithm + sequentialScanning bool } func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *models.VideoFile) error { @@ -472,9 +499,8 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode progress := g.progress t := g.input path := f.Path - config := instance.Config - fileNamingAlgorithm := config.GetVideoFileNamingAlgorithm() - sequentialScanning := config.GetSequentialScanning() + + mgr := GetInstance() if t.ScanGenerateSprites { progress.AddTotal(1) @@ -482,13 +508,13 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode taskSprite := GenerateSpriteTask{ Scene: *s, Overwrite: overwrite, - fileNamingAlgorithm: fileNamingAlgorithm, + fileNamingAlgorithm: g.fileNamingAlgorithm, } taskSprite.Start(ctx) progress.Increment() } - if sequentialScanning { + if g.sequentialScanning { spriteFn(ctx) } else { g.taskQueue.Add(fmt.Sprintf("Generating sprites for %s", path), spriteFn) @@ -499,17 +525,16 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode progress.AddTotal(1) phashFn := func(ctx context.Context) { taskPhash := GeneratePhashTask{ + repository: mgr.Repository, File: f, - fileNamingAlgorithm: fileNamingAlgorithm, - txnManager: instance.Database, - fileUpdater: instance.Database.File, Overwrite: overwrite, + fileNamingAlgorithm: g.fileNamingAlgorithm, } taskPhash.Start(ctx) progress.Increment() } - if sequentialScanning { + if g.sequentialScanning { phashFn(ctx) } else { g.taskQueue.Add(fmt.Sprintf("Generating phash for %s", path), phashFn) @@ -521,12 +546,12 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode previewsFn := func(ctx context.Context) { options := getGeneratePreviewOptions(GeneratePreviewOptionsInput{}) - g := &generate.Generator{ - Encoder: instance.FFMPEG, - FFMpegConfig: instance.Config, - LockManager: instance.ReadLockManager, - MarkerPaths: instance.Paths.SceneMarkers, - ScenePaths: instance.Paths.Scene, + generator := &generate.Generator{ + Encoder: mgr.FFMPEG, + FFMpegConfig: mgr.Config, + LockManager: mgr.ReadLockManager, + MarkerPaths: g.paths.SceneMarkers, + ScenePaths: g.paths.Scene, Overwrite: overwrite, } @@ -535,14 +560,14 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode ImagePreview: t.ScanGenerateImagePreviews, Options: options, Overwrite: overwrite, - fileNamingAlgorithm: fileNamingAlgorithm, - generator: g, + fileNamingAlgorithm: g.fileNamingAlgorithm, + generator: generator, } taskPreview.Start(ctx) progress.Increment() } - if sequentialScanning { + if g.sequentialScanning { previewsFn(ctx) } else { g.taskQueue.Add(fmt.Sprintf("Generating preview for %s", path), previewsFn) @@ -553,8 +578,8 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode progress.AddTotal(1) g.taskQueue.Add(fmt.Sprintf("Generating cover for %s", path), func(ctx context.Context) { taskCover := GenerateCoverTask{ + repository: mgr.Repository, Scene: *s, - txnManager: instance.Repository, Overwrite: overwrite, } taskCover.Start(ctx) diff --git a/internal/manager/task_stash_box_tag.go b/internal/manager/task_stash_box_tag.go index 6833f166343..9cd48abc8f0 100644 --- a/internal/manager/task_stash_box_tag.go +++ b/internal/manager/task_stash_box_tag.go @@ -9,7 +9,6 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scraper/stashbox" "github.com/stashapp/stash/pkg/studio" - "github.com/stashapp/stash/pkg/txn" ) type StashBoxTagTaskType int @@ -92,18 +91,21 @@ func (t *StashBoxBatchTagTask) findStashBoxPerformer(ctx context.Context) (*mode var performer *models.ScrapedPerformer var err error - client := stashbox.NewClient(*t.box, instance.Repository, stashbox.Repository{ - Scene: instance.Repository.Scene, - Performer: instance.Repository.Performer, - Tag: instance.Repository.Tag, - Studio: instance.Repository.Studio, + r := instance.Repository + client := stashbox.NewClient(*t.box, r.TxnManager, stashbox.Repository{ + Scene: r.Scene, + Performer: r.Performer, + Tag: r.Tag, + Studio: r.Studio, }) if t.refresh { var remoteID string - if err := txn.WithReadTxn(ctx, instance.Repository, func(ctx context.Context) error { + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { + qb := r.Performer + if !t.performer.StashIDs.Loaded() { - err = t.performer.LoadStashIDs(ctx, instance.Repository.Performer) + err = t.performer.LoadStashIDs(ctx, qb) if err != nil { return err } @@ -145,8 +147,9 @@ func (t *StashBoxBatchTagTask) processMatchedPerformer(ctx context.Context, p *m } // Start the transaction and update the performer - err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { - qb := instance.Repository.Performer + r := instance.Repository + err = r.WithTxn(ctx, func(ctx context.Context) error { + qb := r.Performer existingStashIDs, err := qb.GetStashIDs(ctx, storedID) if err != nil { @@ -181,8 +184,10 @@ func (t *StashBoxBatchTagTask) processMatchedPerformer(ctx context.Context, p *m return } - err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { - qb := instance.Repository.Performer + r := instance.Repository + err = r.WithTxn(ctx, func(ctx context.Context) error { + qb := r.Performer + if err := qb.Create(ctx, newPerformer); err != nil { return err } @@ -233,18 +238,19 @@ func (t *StashBoxBatchTagTask) findStashBoxStudio(ctx context.Context) (*models. var studio *models.ScrapedStudio var err error - client := stashbox.NewClient(*t.box, instance.Repository, stashbox.Repository{ - Scene: instance.Repository.Scene, - Performer: instance.Repository.Performer, - Tag: instance.Repository.Tag, - Studio: instance.Repository.Studio, + r := instance.Repository + client := stashbox.NewClient(*t.box, r.TxnManager, stashbox.Repository{ + Scene: r.Scene, + Performer: r.Performer, + Tag: r.Tag, + Studio: r.Studio, }) if t.refresh { var remoteID string - if err := txn.WithReadTxn(ctx, instance.Repository, func(ctx context.Context) error { + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { if !t.studio.StashIDs.Loaded() { - err = t.studio.LoadStashIDs(ctx, instance.Repository.Studio) + err = t.studio.LoadStashIDs(ctx, r.Studio) if err != nil { return err } @@ -293,8 +299,9 @@ func (t *StashBoxBatchTagTask) processMatchedStudio(ctx context.Context, s *mode } // Start the transaction and update the studio - err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { - qb := instance.Repository.Studio + r := instance.Repository + err = r.WithTxn(ctx, func(ctx context.Context) error { + qb := r.Studio existingStashIDs, err := qb.GetStashIDs(ctx, storedID) if err != nil { @@ -341,8 +348,10 @@ func (t *StashBoxBatchTagTask) processMatchedStudio(ctx context.Context, s *mode } // Start the transaction and save the studio - err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { - qb := instance.Repository.Studio + r := instance.Repository + err = r.WithTxn(ctx, func(ctx context.Context) error { + qb := r.Studio + if err := qb.Create(ctx, newStudio); err != nil { return err } @@ -375,8 +384,10 @@ func (t *StashBoxBatchTagTask) processParentStudio(ctx context.Context, parent * } // Start the transaction and save the studio - err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { - qb := instance.Repository.Studio + r := instance.Repository + err = r.WithTxn(ctx, func(ctx context.Context) error { + qb := r.Studio + if err := qb.Create(ctx, newParentStudio); err != nil { return err } @@ -408,8 +419,9 @@ func (t *StashBoxBatchTagTask) processParentStudio(ctx context.Context, parent * } // Start the transaction and update the studio - err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { - qb := instance.Repository.Studio + r := instance.Repository + err = r.WithTxn(ctx, func(ctx context.Context) error { + qb := r.Studio existingStashIDs, err := qb.GetStashIDs(ctx, storedID) if err != nil { diff --git a/pkg/file/move.go b/pkg/file/move.go index 64a83fed645..04305c8b5a5 100644 --- a/pkg/file/move.go +++ b/pkg/file/move.go @@ -181,7 +181,7 @@ func (m *Mover) moveFile(oldPath, newPath string) error { return nil } -func (m *Mover) RegisterHooks(ctx context.Context, mgr txn.Manager) { +func (m *Mover) RegisterHooks(ctx context.Context) { txn.AddPostCommitHook(ctx, func(ctx context.Context) { m.commit() }) diff --git a/pkg/models/repository.go b/pkg/models/repository.go index 9ba4eead11a..3e06a49e14d 100644 --- a/pkg/models/repository.go +++ b/pkg/models/repository.go @@ -1,17 +1,18 @@ package models import ( + "context" + "github.com/stashapp/stash/pkg/txn" ) type TxnManager interface { txn.Manager txn.DatabaseProvider - Reset() error } type Repository struct { - TxnManager + TxnManager TxnManager File FileReaderWriter Folder FolderReaderWriter @@ -26,3 +27,15 @@ type Repository struct { Tag TagReaderWriter SavedFilter SavedFilterReaderWriter } + +func (r *Repository) WithTxn(ctx context.Context, fn txn.TxnFunc) error { + return txn.WithTxn(ctx, r.TxnManager, fn) +} + +func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error { + return txn.WithReadTxn(ctx, r.TxnManager, fn) +} + +func (r *Repository) WithDB(ctx context.Context, fn txn.TxnFunc) error { + return txn.WithDatabase(ctx, r.TxnManager, fn) +} diff --git a/pkg/sqlite/setup_test.go b/pkg/sqlite/setup_test.go index e182ef99b5b..cdead935d09 100644 --- a/pkg/sqlite/setup_test.go +++ b/pkg/sqlite/setup_test.go @@ -1176,7 +1176,7 @@ func makeImage(i int) *models.Image { } func createImages(ctx context.Context, n int) error { - qb := db.TxnRepository().Image + qb := db.Repository().Image fqb := db.File for i := 0; i < n; i++ { @@ -1262,7 +1262,7 @@ func makeGallery(i int, includeScenes bool) *models.Gallery { } func createGalleries(ctx context.Context, n int) error { - gqb := db.TxnRepository().Gallery + gqb := db.Repository().Gallery fqb := db.File for i := 0; i < n; i++ { diff --git a/pkg/sqlite/transaction.go b/pkg/sqlite/transaction.go index 726b927e714..c763d4e0c8a 100644 --- a/pkg/sqlite/transaction.go +++ b/pkg/sqlite/transaction.go @@ -123,7 +123,7 @@ func (db *Database) IsLocked(err error) bool { return false } -func (db *Database) TxnRepository() models.Repository { +func (db *Database) Repository() models.Repository { return models.Repository{ TxnManager: db, File: db.File, From ea432c7f8b239aaec2dc740676a1e842cc8e3f32 Mon Sep 17 00:00:00 2001 From: DingDongSoLong4 <99329275+DingDongSoLong4@users.noreply.github.com> Date: Sun, 6 Aug 2023 23:59:30 +0200 Subject: [PATCH 2/9] Refactor other repositories --- internal/api/resolver.go | 5 ++ internal/api/resolver_mutation_migrate.go | 18 ++--- internal/api/resolver_mutation_stash_box.go | 15 +---- internal/api/resolver_query_configuration.go | 4 +- internal/api/resolver_query_find_scene.go | 11 +-- internal/api/resolver_query_scraper.go | 2 +- internal/dlna/cds.go | 42 +++++++----- internal/dlna/dms.go | 12 ++-- internal/dlna/service.go | 26 +++++-- internal/identify/identify.go | 19 +++--- internal/identify/identify_test.go | 8 +-- internal/manager/manager.go | 37 ++-------- internal/manager/task_clean.go | 21 +++--- internal/manager/task_identify.go | 12 ++-- internal/manager/task_stash_box_tag.go | 18 ++--- pkg/file/clean.go | 33 +++++---- pkg/file/file.go | 29 ++++++-- pkg/file/folder_rename_detect.go | 8 ++- pkg/file/scan.go | 32 ++++----- pkg/scene/filename_parser.go | 36 +++++++--- pkg/scraper/autotag.go | 5 +- pkg/scraper/cache.go | 42 +++++++++--- pkg/scraper/postprocessing.go | 31 +++++---- pkg/scraper/stashbox/stash_box.go | 71 +++++++++++++------- 24 files changed, 300 insertions(+), 237 deletions(-) diff --git a/internal/api/resolver.go b/internal/api/resolver.go index 91b8dc5b742..7988182acaa 100644 --- a/internal/api/resolver.go +++ b/internal/api/resolver.go @@ -13,6 +13,7 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/scraper" + "github.com/stashapp/stash/pkg/scraper/stashbox" ) var ( @@ -107,6 +108,10 @@ func (r *Resolver) withReadTxn(ctx context.Context, fn func(ctx context.Context) return r.repository.WithReadTxn(ctx, fn) } +func (r *Resolver) stashboxRepository() stashbox.Repository { + return stashbox.NewRepository(r.repository) +} + func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*models.SceneMarker, err error) { if err := r.withReadTxn(ctx, func(ctx context.Context) error { ret, err = r.repository.SceneMarker.Wall(ctx, q) diff --git a/internal/api/resolver_mutation_migrate.go b/internal/api/resolver_mutation_migrate.go index 477f46b44d9..26a56a1b892 100644 --- a/internal/api/resolver_mutation_migrate.go +++ b/internal/api/resolver_mutation_migrate.go @@ -11,30 +11,30 @@ import ( ) func (r *mutationResolver) MigrateSceneScreenshots(ctx context.Context, input MigrateSceneScreenshotsInput) (string, error) { - db := manager.GetInstance().Database + mgr := manager.GetInstance() t := &task.MigrateSceneScreenshotsJob{ ScreenshotsPath: manager.GetInstance().Paths.Generated.Screenshots, Input: scene.MigrateSceneScreenshotsInput{ DeleteFiles: utils.IsTrue(input.DeleteFiles), OverwriteExisting: utils.IsTrue(input.OverwriteExisting), }, - SceneRepo: db.Scene, - TxnManager: db, + SceneRepo: mgr.Repository.Scene, + TxnManager: mgr.Repository.TxnManager, } - jobID := manager.GetInstance().JobManager.Add(ctx, "Migrating scene screenshots to blobs...", t) + jobID := mgr.JobManager.Add(ctx, "Migrating scene screenshots to blobs...", t) return strconv.Itoa(jobID), nil } func (r *mutationResolver) MigrateBlobs(ctx context.Context, input MigrateBlobsInput) (string, error) { - db := manager.GetInstance().Database + mgr := manager.GetInstance() t := &task.MigrateBlobsJob{ - TxnManager: db, - BlobStore: db.Blobs, - Vacuumer: db, + TxnManager: mgr.Database, + BlobStore: mgr.Database.Blobs, + Vacuumer: mgr.Database, DeleteOld: utils.IsTrue(input.DeleteOld), } - jobID := manager.GetInstance().JobManager.Add(ctx, "Migrating blobs...", t) + jobID := mgr.JobManager.Add(ctx, "Migrating blobs...", t) return strconv.Itoa(jobID), nil } diff --git a/internal/api/resolver_mutation_stash_box.go b/internal/api/resolver_mutation_stash_box.go index 4f595d3ae9e..2198ab6ff4a 100644 --- a/internal/api/resolver_mutation_stash_box.go +++ b/internal/api/resolver_mutation_stash_box.go @@ -11,15 +11,6 @@ import ( "github.com/stashapp/stash/pkg/scraper/stashbox" ) -func (r *Resolver) stashboxRepository() stashbox.Repository { - return stashbox.Repository{ - Scene: r.repository.Scene, - Performer: r.repository.Performer, - Tag: r.repository.Tag, - Studio: r.repository.Studio, - } -} - func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input StashBoxFingerprintSubmissionInput) (bool, error) { boxes := config.GetInstance().GetStashBoxes() @@ -27,7 +18,7 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input return false, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) } - client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.repository.TxnManager, r.stashboxRepository()) + client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.stashboxRepository()) return client.SubmitStashBoxFingerprints(ctx, input.SceneIds, boxes[input.StashBoxIndex].Endpoint) } @@ -49,7 +40,7 @@ func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input S return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) } - client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.repository.TxnManager, r.stashboxRepository()) + client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.stashboxRepository()) id, err := strconv.Atoi(input.ID) if err != nil { @@ -91,7 +82,7 @@ func (r *mutationResolver) SubmitStashBoxPerformerDraft(ctx context.Context, inp return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) } - client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.repository.TxnManager, r.stashboxRepository()) + client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.stashboxRepository()) id, err := strconv.Atoi(input.ID) if err != nil { diff --git a/internal/api/resolver_query_configuration.go b/internal/api/resolver_query_configuration.go index c6fe587ae52..8ac5f840d1e 100644 --- a/internal/api/resolver_query_configuration.go +++ b/internal/api/resolver_query_configuration.go @@ -243,7 +243,9 @@ func makeConfigUIResult() map[string]interface{} { } func (r *queryResolver) ValidateStashBoxCredentials(ctx context.Context, input config.StashBoxInput) (*StashBoxValidationResult, error) { - client := stashbox.NewClient(models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}, r.repository.TxnManager, r.stashboxRepository()) + box := models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey} + client := stashbox.NewClient(box, r.stashboxRepository()) + user, err := client.GetUser(ctx) valid := user != nil && user.Me != nil diff --git a/internal/api/resolver_query_find_scene.go b/internal/api/resolver_query_find_scene.go index 2b33d211585..c6492680ea6 100644 --- a/internal/api/resolver_query_find_scene.go +++ b/internal/api/resolver_query_find_scene.go @@ -191,16 +191,11 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model } func (r *queryResolver) ParseSceneFilenames(ctx context.Context, filter *models.FindFilterType, config models.SceneParserInput) (ret *SceneParserResultType, err error) { - parser := scene.NewFilenameParser(filter, config) + repo := scene.NewFilenameParserRepository(r.repository) + parser := scene.NewFilenameParser(filter, config, repo) if err := r.withReadTxn(ctx, func(ctx context.Context) error { - result, count, err := parser.Parse(ctx, scene.FilenameParserRepository{ - Scene: r.repository.Scene, - Performer: r.repository.Performer, - Studio: r.repository.Studio, - Movie: r.repository.Movie, - Tag: r.repository.Tag, - }) + result, count, err := parser.Parse(ctx) if err != nil { return err diff --git a/internal/api/resolver_query_scraper.go b/internal/api/resolver_query_scraper.go index 55d6564dece..97d3f526549 100644 --- a/internal/api/resolver_query_scraper.go +++ b/internal/api/resolver_query_scraper.go @@ -238,7 +238,7 @@ func (r *queryResolver) getStashBoxClient(index int) (*stashbox.Client, error) { return nil, fmt.Errorf("%w: invalid stash_box_index %d", ErrInput, index) } - return stashbox.NewClient(*boxes[index], r.repository.TxnManager, r.stashboxRepository()), nil + return stashbox.NewClient(*boxes[index], r.stashboxRepository()), nil } func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source scraper.Source, input ScrapeSingleSceneInput) ([]*scraper.ScrapedScene, error) { diff --git a/internal/dlna/cds.go b/internal/dlna/cds.go index eba98ac489f..a23f2177cfe 100644 --- a/internal/dlna/cds.go +++ b/internal/dlna/cds.go @@ -41,7 +41,6 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/sliceutil/stringslice" - "github.com/stashapp/stash/pkg/txn" ) var pageSize = 100 @@ -360,10 +359,11 @@ func (me *contentDirectoryService) handleBrowseMetadata(obj object, host string) } else { var scene *models.Scene - if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { - scene, err = me.repository.SceneFinder.Find(ctx, sceneID) + r := me.repository + if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error { + scene, err = r.SceneFinder.Find(ctx, sceneID) if scene != nil { - err = scene.LoadPrimaryFile(ctx, me.repository.FileGetter) + err = scene.LoadPrimaryFile(ctx, r.FileGetter) } if err != nil { @@ -452,7 +452,8 @@ func getSortDirection(sceneFilter *models.SceneFilterType, sort string) models.S func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType, parentID string, host string) []interface{} { var objs []interface{} - if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { + r := me.repository + if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error { sort := me.VideoSortOrder direction := getSortDirection(sceneFilter, sort) findFilter := &models.FindFilterType{ @@ -461,7 +462,7 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType Direction: &direction, } - scenes, total, err := scene.QueryWithCount(ctx, me.repository.SceneFinder, sceneFilter, findFilter) + scenes, total, err := scene.QueryWithCount(ctx, r.SceneFinder, sceneFilter, findFilter) if err != nil { return err } @@ -472,13 +473,13 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType parentID: parentID, } - objs, err = pager.getPages(ctx, me.repository.SceneFinder, total) + objs, err = pager.getPages(ctx, r.SceneFinder, total) if err != nil { return err } } else { for _, s := range scenes { - if err := s.LoadPrimaryFile(ctx, me.repository.FileGetter); err != nil { + if err := s.LoadPrimaryFile(ctx, r.FileGetter); err != nil { return err } @@ -497,7 +498,8 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType func (me *contentDirectoryService) getPageVideos(sceneFilter *models.SceneFilterType, parentID string, page int, host string) []interface{} { var objs []interface{} - if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { + r := me.repository + if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error { pager := scenePager{ sceneFilter: sceneFilter, parentID: parentID, @@ -506,7 +508,7 @@ func (me *contentDirectoryService) getPageVideos(sceneFilter *models.SceneFilter sort := me.VideoSortOrder direction := getSortDirection(sceneFilter, sort) var err error - objs, err = pager.getPageVideos(ctx, me.repository.SceneFinder, me.repository.FileGetter, page, host, sort, direction) + objs, err = pager.getPageVideos(ctx, r.SceneFinder, r.FileGetter, page, host, sort, direction) if err != nil { return err } @@ -540,8 +542,9 @@ func (me *contentDirectoryService) getAllScenes(host string) []interface{} { func (me *contentDirectoryService) getStudios() []interface{} { var objs []interface{} - if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { - studios, err := me.repository.StudioFinder.All(ctx) + r := me.repository + if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error { + studios, err := r.StudioFinder.All(ctx) if err != nil { return err } @@ -579,8 +582,9 @@ func (me *contentDirectoryService) getStudioScenes(paths []string, host string) func (me *contentDirectoryService) getTags() []interface{} { var objs []interface{} - if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { - tags, err := me.repository.TagFinder.All(ctx) + r := me.repository + if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error { + tags, err := r.TagFinder.All(ctx) if err != nil { return err } @@ -618,8 +622,9 @@ func (me *contentDirectoryService) getTagScenes(paths []string, host string) []i func (me *contentDirectoryService) getPerformers() []interface{} { var objs []interface{} - if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { - performers, err := me.repository.PerformerFinder.All(ctx) + r := me.repository + if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error { + performers, err := r.PerformerFinder.All(ctx) if err != nil { return err } @@ -657,8 +662,9 @@ func (me *contentDirectoryService) getPerformerScenes(paths []string, host strin func (me *contentDirectoryService) getMovies() []interface{} { var objs []interface{} - if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { - movies, err := me.repository.MovieFinder.All(ctx) + r := me.repository + if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error { + movies, err := r.MovieFinder.All(ctx) if err != nil { return err } diff --git a/internal/dlna/dms.go b/internal/dlna/dms.go index fe078aab022..4ca8eeddc2c 100644 --- a/internal/dlna/dms.go +++ b/internal/dlna/dms.go @@ -48,7 +48,6 @@ import ( "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/txn" ) type SceneFinder interface { @@ -271,7 +270,6 @@ type Server struct { // Time interval between SSPD announces NotifyInterval time.Duration - txnManager txn.Manager repository Repository sceneServer sceneServer ipWhitelistManager *ipWhitelistManager @@ -439,12 +437,13 @@ func (me *Server) serveIcon(w http.ResponseWriter, r *http.Request) { } var scene *models.Scene - err := txn.WithReadTxn(r.Context(), me.txnManager, func(ctx context.Context) error { + repo := me.repository + err := repo.WithReadTxn(r.Context(), func(ctx context.Context) error { idInt, err := strconv.Atoi(sceneId) if err != nil { return nil } - scene, _ = me.repository.SceneFinder.Find(ctx, idInt) + scene, _ = repo.SceneFinder.Find(ctx, idInt) return nil }) if err != nil { @@ -579,12 +578,13 @@ func (me *Server) initMux(mux *http.ServeMux) { mux.HandleFunc(resPath, func(w http.ResponseWriter, r *http.Request) { sceneId := r.URL.Query().Get("scene") var scene *models.Scene - err := txn.WithReadTxn(r.Context(), me.txnManager, func(ctx context.Context) error { + repo := me.repository + err := repo.WithReadTxn(r.Context(), func(ctx context.Context) error { sceneIdInt, err := strconv.Atoi(sceneId) if err != nil { return nil } - scene, _ = me.repository.SceneFinder.Find(ctx, sceneIdInt) + scene, _ = repo.SceneFinder.Find(ctx, sceneIdInt) return nil }) if err != nil { diff --git a/internal/dlna/service.go b/internal/dlna/service.go index d5399e6a11e..955edf6e05c 100644 --- a/internal/dlna/service.go +++ b/internal/dlna/service.go @@ -1,6 +1,7 @@ package dlna import ( + "context" "fmt" "net" "net/http" @@ -14,6 +15,8 @@ import ( ) type Repository struct { + TxnManager models.TxnManager + SceneFinder SceneFinder FileGetter models.FileGetter StudioFinder StudioFinder @@ -22,6 +25,22 @@ type Repository struct { MovieFinder MovieFinder } +func NewRepository(repo models.Repository) Repository { + return Repository{ + TxnManager: repo.TxnManager, + FileGetter: repo.File, + SceneFinder: repo.Scene, + StudioFinder: repo.Studio, + TagFinder: repo.Tag, + PerformerFinder: repo.Performer, + MovieFinder: repo.Movie, + } +} + +func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error { + return txn.WithReadTxn(ctx, r.TxnManager, fn) +} + type Status struct { Running bool `json:"running"` // If not currently running, time until it will be started. If running, time until it will be stopped @@ -60,7 +79,6 @@ type Config interface { } type Service struct { - txnManager txn.Manager repository Repository config Config sceneServer sceneServer @@ -133,9 +151,8 @@ func (s *Service) init() error { } s.server = &Server{ - txnManager: s.txnManager, - sceneServer: s.sceneServer, repository: s.repository, + sceneServer: s.sceneServer, ipWhitelistManager: s.ipWhitelistMgr, Interfaces: interfaces, HTTPConn: func() net.Listener { @@ -197,9 +214,8 @@ func (s *Service) init() error { // } // NewService initialises and returns a new DLNA service. -func NewService(txnManager txn.Manager, repo Repository, cfg Config, sceneServer sceneServer) *Service { +func NewService(repo Repository, cfg Config, sceneServer sceneServer) *Service { ret := &Service{ - txnManager: txnManager, repository: repo, sceneServer: sceneServer, config: cfg, diff --git a/internal/identify/identify.go b/internal/identify/identify.go index db8ca2f54ab..d55497df8ad 100644 --- a/internal/identify/identify.go +++ b/internal/identify/identify.go @@ -43,6 +43,7 @@ type ScraperSource struct { } type SceneIdentifier struct { + TxnManager txn.Manager SceneReaderUpdater SceneReaderUpdater StudioReaderWriter models.StudioReaderWriter PerformerCreator PerformerCreator @@ -53,8 +54,8 @@ type SceneIdentifier struct { SceneUpdatePostHookExecutor SceneUpdatePostHookExecutor } -func (t *SceneIdentifier) Identify(ctx context.Context, txnManager txn.Manager, scene *models.Scene) error { - result, err := t.scrapeScene(ctx, txnManager, scene) +func (t *SceneIdentifier) Identify(ctx context.Context, scene *models.Scene) error { + result, err := t.scrapeScene(ctx, scene) var multipleMatchErr *MultipleMatchesFoundError if err != nil { if !errors.As(err, &multipleMatchErr) { @@ -70,7 +71,7 @@ func (t *SceneIdentifier) Identify(ctx context.Context, txnManager txn.Manager, options := t.getOptions(multipleMatchErr.Source) if options.SkipMultipleMatchTag != nil && len(*options.SkipMultipleMatchTag) > 0 { // Tag it with the multiple results tag - err := t.addTagToScene(ctx, txnManager, scene, *options.SkipMultipleMatchTag) + err := t.addTagToScene(ctx, scene, *options.SkipMultipleMatchTag) if err != nil { return err } @@ -83,7 +84,7 @@ func (t *SceneIdentifier) Identify(ctx context.Context, txnManager txn.Manager, } // results were found, modify the scene - if err := t.modifyScene(ctx, txnManager, scene, result); err != nil { + if err := t.modifyScene(ctx, scene, result); err != nil { return fmt.Errorf("error modifying scene: %v", err) } @@ -95,7 +96,7 @@ type scrapeResult struct { source ScraperSource } -func (t *SceneIdentifier) scrapeScene(ctx context.Context, txnManager txn.Manager, scene *models.Scene) (*scrapeResult, error) { +func (t *SceneIdentifier) scrapeScene(ctx context.Context, scene *models.Scene) (*scrapeResult, error) { // iterate through the input sources for _, source := range t.Sources { // scrape using the source @@ -256,9 +257,9 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene, return ret, nil } -func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager txn.Manager, s *models.Scene, result *scrapeResult) error { +func (t *SceneIdentifier) modifyScene(ctx context.Context, s *models.Scene, result *scrapeResult) error { var updater *scene.UpdateSet - if err := txn.WithTxn(ctx, txnManager, func(ctx context.Context) error { + if err := txn.WithTxn(ctx, t.TxnManager, func(ctx context.Context) error { // load scene relationships if err := s.LoadURLs(ctx, t.SceneReaderUpdater); err != nil { return err @@ -311,8 +312,8 @@ func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager txn.Manage return nil } -func (t *SceneIdentifier) addTagToScene(ctx context.Context, txnManager txn.Manager, s *models.Scene, tagToAdd string) error { - if err := txn.WithTxn(ctx, txnManager, func(ctx context.Context) error { +func (t *SceneIdentifier) addTagToScene(ctx context.Context, s *models.Scene, tagToAdd string) error { + if err := txn.WithTxn(ctx, t.TxnManager, func(ctx context.Context) error { tagID, err := strconv.Atoi(tagToAdd) if err != nil { return fmt.Errorf("error converting tag ID %s: %w", tagToAdd, err) diff --git a/internal/identify/identify_test.go b/internal/identify/identify_test.go index c032cbe53ac..b8472c1186d 100644 --- a/internal/identify/identify_test.go +++ b/internal/identify/identify_test.go @@ -202,7 +202,7 @@ func TestSceneIdentifier_Identify(t *testing.T) { TagIDs: models.NewRelatedIDs([]int{}), StashIDs: models.NewRelatedStashIDs([]models.StashID{}), } - if err := identifier.Identify(testCtx, &mocks.TxnManager{}, scene); (err != nil) != tt.wantErr { + if err := identifier.Identify(testCtx, scene); (err != nil) != tt.wantErr { t.Errorf("SceneIdentifier.Identify() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -210,9 +210,6 @@ func TestSceneIdentifier_Identify(t *testing.T) { } func TestSceneIdentifier_modifyScene(t *testing.T) { - repo := models.Repository{ - TxnManager: &mocks.TxnManager{}, - } boolFalse := false defaultOptions := &MetadataOptions{ SetOrganized: &boolFalse, @@ -221,6 +218,7 @@ func TestSceneIdentifier_modifyScene(t *testing.T) { SkipSingleNamePerformers: &boolFalse, } tr := &SceneIdentifier{ + TxnManager: &mocks.TxnManager{}, DefaultOptions: defaultOptions, } @@ -254,7 +252,7 @@ func TestSceneIdentifier_modifyScene(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tr.modifyScene(testCtx, repo.TxnManager, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr { + if err := tr.modifyScene(testCtx, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr { t.Errorf("SceneIdentifier.modifyScene() error = %v, wantErr %v", err, tt.wantErr) } }) diff --git a/internal/manager/manager.go b/internal/manager/manager.go index 9bd7fbfae16..567af1c015f 100644 --- a/internal/manager/manager.go +++ b/internal/manager/manager.go @@ -222,14 +222,8 @@ func initialize() error { SceneCoverGetter: repo.Scene, } - instance.DLNAService = dlna.NewService(repo.TxnManager, dlna.Repository{ - SceneFinder: repo.Scene, - FileGetter: repo.File, - StudioFinder: repo.Studio, - TagFinder: repo.Tag, - PerformerFinder: repo.Performer, - MovieFinder: repo.Movie, - }, instance.Config, &sceneServer) + dlnaRepository := dlna.NewRepository(repo) + instance.DLNAService = dlna.NewService(dlnaRepository, cfg, &sceneServer) if !cfg.IsNewSystem() { logger.Infof("using config file: %s", cfg.GetConfigFile()) @@ -296,12 +290,7 @@ func galleryFileFilter(ctx context.Context, f models.File) bool { func makeScanner(repo models.Repository, pluginCache *plugin.Cache) *file.Scanner { return &file.Scanner{ - Repository: file.Repository{ - Manager: repo.TxnManager, - DatabaseProvider: repo.TxnManager, - FileStore: repo.File, - FolderStore: repo.Folder, - }, + Repository: file.NewRepository(repo), FileDecorators: []file.Decorator{ &file.FilteredDecorator{ Decorator: &video.Decorator{ @@ -323,13 +312,8 @@ func makeScanner(repo models.Repository, pluginCache *plugin.Cache) *file.Scanne func makeCleaner(repo models.Repository, pluginCache *plugin.Cache) *file.Cleaner { return &file.Cleaner{ - FS: &file.OsFS{}, - Repository: file.Repository{ - Manager: repo.TxnManager, - DatabaseProvider: repo.TxnManager, - FileStore: repo.File, - FolderStore: repo.Folder, - }, + FS: &file.OsFS{}, + Repository: file.NewRepository(repo), Handlers: []file.CleanHandler{ &cleanHandler{}, }, @@ -524,15 +508,8 @@ func writeStashIcon() { // initScraperCache initializes a new scraper cache and returns it. func (s *Manager) initScraperCache() *scraper.Cache { - repo := s.Repository - ret, err := scraper.NewCache(s.Config, repo.TxnManager, scraper.Repository{ - SceneFinder: repo.Scene, - GalleryFinder: repo.Gallery, - TagFinder: repo.Tag, - PerformerFinder: repo.Performer, - MovieFinder: repo.Movie, - StudioFinder: repo.Studio, - }) + scraperRepository := scraper.NewRepository(s.Repository) + ret, err := scraper.NewCache(s.Config, scraperRepository) if err != nil { logger.Errorf("Error reading scraper configs: %s", err.Error()) diff --git a/internal/manager/task_clean.go b/internal/manager/task_clean.go index 0ea4332f918..38e029252a3 100644 --- a/internal/manager/task_clean.go +++ b/internal/manager/task_clean.go @@ -254,9 +254,7 @@ func (f *cleanFilter) shouldCleanImage(path string, stash *config.StashConfig) b return false } -type cleanHandler struct { - PluginCache *plugin.Cache -} +type cleanHandler struct{} func (h *cleanHandler) HandleFile(ctx context.Context, fileDeleter *file.Deleter, fileID models.FileID) error { if err := h.handleRelatedScenes(ctx, fileDeleter, fileID); err != nil { @@ -278,7 +276,7 @@ func (h *cleanHandler) HandleFolder(ctx context.Context, fileDeleter *file.Delet func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *file.Deleter, fileID models.FileID) error { mgr := GetInstance() - sceneQB := mgr.Database.Scene + sceneQB := mgr.Repository.Scene scenes, err := sceneQB.FindByFileID(ctx, fileID) if err != nil { return err @@ -304,12 +302,9 @@ func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *fil return err } - checksum := scene.Checksum - oshash := scene.OSHash - mgr.PluginCache.RegisterPostHooks(ctx, scene.ID, plugin.SceneDestroyPost, plugin.SceneDestroyInput{ - Checksum: checksum, - OSHash: oshash, + Checksum: scene.Checksum, + OSHash: scene.OSHash, Path: scene.Path, }, nil) } else { @@ -336,7 +331,7 @@ func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *fil func (h *cleanHandler) handleRelatedGalleries(ctx context.Context, fileID models.FileID) error { mgr := GetInstance() - qb := mgr.Database.Gallery + qb := mgr.Repository.Gallery galleries, err := qb.FindByFileID(ctx, fileID) if err != nil { return err @@ -382,7 +377,7 @@ func (h *cleanHandler) handleRelatedGalleries(ctx context.Context, fileID models func (h *cleanHandler) deleteRelatedFolderGalleries(ctx context.Context, folderID models.FolderID) error { mgr := GetInstance() - qb := mgr.Database.Gallery + qb := mgr.Repository.Gallery galleries, err := qb.FindByFolderID(ctx, folderID) if err != nil { return err @@ -406,7 +401,7 @@ func (h *cleanHandler) deleteRelatedFolderGalleries(ctx context.Context, folderI func (h *cleanHandler) handleRelatedImages(ctx context.Context, fileDeleter *file.Deleter, fileID models.FileID) error { mgr := GetInstance() - imageQB := mgr.Database.Image + imageQB := mgr.Repository.Image images, err := imageQB.FindByFileID(ctx, fileID) if err != nil { return err @@ -414,7 +409,7 @@ func (h *cleanHandler) handleRelatedImages(ctx context.Context, fileDeleter *fil imageFileDeleter := &image.FileDeleter{ Deleter: fileDeleter, - Paths: GetInstance().Paths, + Paths: mgr.Paths, } for _, i := range images { diff --git a/internal/manager/task_identify.go b/internal/manager/task_identify.go index 8978e65cb0b..2f8bddcede9 100644 --- a/internal/manager/task_identify.go +++ b/internal/manager/task_identify.go @@ -136,6 +136,7 @@ func (j *IdentifyJob) identifyScene(ctx context.Context, s *models.Scene, source j.progress.ExecuteTask("Identifying "+s.Path, func() { r := instance.Repository task := identify.SceneIdentifier{ + TxnManager: r.TxnManager, SceneReaderUpdater: r.Scene, StudioReaderWriter: r.Studio, PerformerCreator: r.Performer, @@ -146,7 +147,7 @@ func (j *IdentifyJob) identifyScene(ctx context.Context, s *models.Scene, source SceneUpdatePostHookExecutor: j.postHookExecutor, } - taskError = task.Identify(ctx, r.TxnManager, s) + taskError = task.Identify(ctx, s) }) if taskError != nil { @@ -167,16 +168,11 @@ func (j *IdentifyJob) getSources() ([]identify.ScraperSource, error) { var src identify.ScraperSource if stashBox != nil { - r := instance.Repository + stashboxRepository := stashbox.NewRepository(instance.Repository) src = identify.ScraperSource{ Name: "stash-box: " + stashBox.Endpoint, Scraper: stashboxSource{ - stashbox.NewClient(*stashBox, r.TxnManager, stashbox.Repository{ - Scene: r.Scene, - Performer: r.Performer, - Tag: r.Tag, - Studio: r.Studio, - }), + stashbox.NewClient(*stashBox, stashboxRepository), stashBox.Endpoint, }, RemoteSite: stashBox.Endpoint, diff --git a/internal/manager/task_stash_box_tag.go b/internal/manager/task_stash_box_tag.go index 9cd48abc8f0..9a7700ba7c2 100644 --- a/internal/manager/task_stash_box_tag.go +++ b/internal/manager/task_stash_box_tag.go @@ -92,12 +92,9 @@ func (t *StashBoxBatchTagTask) findStashBoxPerformer(ctx context.Context) (*mode var err error r := instance.Repository - client := stashbox.NewClient(*t.box, r.TxnManager, stashbox.Repository{ - Scene: r.Scene, - Performer: r.Performer, - Tag: r.Tag, - Studio: r.Studio, - }) + + stashboxRepository := stashbox.NewRepository(r) + client := stashbox.NewClient(*t.box, stashboxRepository) if t.refresh { var remoteID string @@ -239,12 +236,9 @@ func (t *StashBoxBatchTagTask) findStashBoxStudio(ctx context.Context) (*models. var err error r := instance.Repository - client := stashbox.NewClient(*t.box, r.TxnManager, stashbox.Repository{ - Scene: r.Scene, - Performer: r.Performer, - Tag: r.Tag, - Studio: r.Studio, - }) + + stashboxRepository := stashbox.NewRepository(r) + client := stashbox.NewClient(*t.box, stashboxRepository) if t.refresh { var remoteID string diff --git a/pkg/file/clean.go b/pkg/file/clean.go index d3e27a774a2..ad8efd3ae25 100644 --- a/pkg/file/clean.go +++ b/pkg/file/clean.go @@ -11,7 +11,6 @@ import ( "github.com/stashapp/stash/pkg/job" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/txn" ) // Cleaner scans through stored file and folder instances and removes those that are no longer present on disk. @@ -112,14 +111,15 @@ func (j *cleanJob) execute(ctx context.Context) error { folderCount int ) - if err := txn.WithReadTxn(ctx, j.Repository, func(ctx context.Context) error { + r := j.Repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { var err error - fileCount, err = j.Repository.FileStore.CountAllInPaths(ctx, j.options.Paths) + fileCount, err = r.File.CountAllInPaths(ctx, j.options.Paths) if err != nil { return err } - folderCount, err = j.Repository.FolderStore.CountAllInPaths(ctx, j.options.Paths) + folderCount, err = r.Folder.CountAllInPaths(ctx, j.options.Paths) if err != nil { return err } @@ -172,13 +172,14 @@ func (j *cleanJob) assessFiles(ctx context.Context, toDelete *deleteSet) error { progress := j.progress more := true - if err := txn.WithReadTxn(ctx, j.Repository, func(ctx context.Context) error { + r := j.Repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { for more { if job.IsCancelled(ctx) { return nil } - files, err := j.Repository.FileStore.FindAllInPaths(ctx, j.options.Paths, batchSize, offset) + files, err := r.File.FindAllInPaths(ctx, j.options.Paths, batchSize, offset) if err != nil { return fmt.Errorf("error querying for files: %w", err) } @@ -223,8 +224,9 @@ func (j *cleanJob) assessFiles(ctx context.Context, toDelete *deleteSet) error { // flagFolderForDelete adds folders to the toDelete set, with the leaf folders added first func (j *cleanJob) flagFileForDelete(ctx context.Context, toDelete *deleteSet, f models.File) error { + r := j.Repository // add contained files first - containedFiles, err := j.Repository.FileStore.FindByZipFileID(ctx, f.Base().ID) + containedFiles, err := r.File.FindByZipFileID(ctx, f.Base().ID) if err != nil { return fmt.Errorf("error finding contained files for %q: %w", f.Base().Path, err) } @@ -235,7 +237,7 @@ func (j *cleanJob) flagFileForDelete(ctx context.Context, toDelete *deleteSet, f } // add contained folders as well - containedFolders, err := j.Repository.FolderStore.FindByZipFileID(ctx, f.Base().ID) + containedFolders, err := r.Folder.FindByZipFileID(ctx, f.Base().ID) if err != nil { return fmt.Errorf("error finding contained folders for %q: %w", f.Base().Path, err) } @@ -256,13 +258,14 @@ func (j *cleanJob) assessFolders(ctx context.Context, toDelete *deleteSet) error progress := j.progress more := true - if err := txn.WithReadTxn(ctx, j.Repository, func(ctx context.Context) error { + r := j.Repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { for more { if job.IsCancelled(ctx) { return nil } - folders, err := j.Repository.FolderStore.FindAllInPaths(ctx, j.options.Paths, batchSize, offset) + folders, err := r.Folder.FindAllInPaths(ctx, j.options.Paths, batchSize, offset) if err != nil { return fmt.Errorf("error querying for folders: %w", err) } @@ -380,14 +383,15 @@ func (j *cleanJob) shouldCleanFolder(ctx context.Context, f *models.Folder) bool func (j *cleanJob) deleteFile(ctx context.Context, fileID models.FileID, fn string) { // delete associated objects fileDeleter := NewDeleter() - if err := txn.WithTxn(ctx, j.Repository, func(ctx context.Context) error { + r := j.Repository + if err := r.WithTxn(ctx, func(ctx context.Context) error { fileDeleter.RegisterHooks(ctx) if err := j.fireHandlers(ctx, fileDeleter, fileID); err != nil { return err } - return j.Repository.FileStore.Destroy(ctx, fileID) + return r.File.Destroy(ctx, fileID) }); err != nil { logger.Errorf("Error deleting file %q from database: %s", fn, err.Error()) return @@ -397,14 +401,15 @@ func (j *cleanJob) deleteFile(ctx context.Context, fileID models.FileID, fn stri func (j *cleanJob) deleteFolder(ctx context.Context, folderID models.FolderID, fn string) { // delete associated objects fileDeleter := NewDeleter() - if err := txn.WithTxn(ctx, j.Repository, func(ctx context.Context) error { + r := j.Repository + if err := r.WithTxn(ctx, func(ctx context.Context) error { fileDeleter.RegisterHooks(ctx) if err := j.fireFolderHandlers(ctx, fileDeleter, folderID); err != nil { return err } - return j.Repository.FolderStore.Destroy(ctx, folderID) + return r.Folder.Destroy(ctx, folderID) }); err != nil { logger.Errorf("Error deleting folder %q from database: %s", fn, err.Error()) return diff --git a/pkg/file/file.go b/pkg/file/file.go index 179e1e01af7..72c7f8a1ab4 100644 --- a/pkg/file/file.go +++ b/pkg/file/file.go @@ -1,15 +1,36 @@ package file import ( + "context" + "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/txn" ) // Repository provides access to storage methods for files and folders. type Repository struct { - txn.Manager - txn.DatabaseProvider + TxnManager models.TxnManager + + File models.FileReaderWriter + Folder models.FolderReaderWriter +} + +func NewRepository(repo models.Repository) Repository { + return Repository{ + TxnManager: repo.TxnManager, + File: repo.File, + Folder: repo.Folder, + } +} + +func (r *Repository) WithTxn(ctx context.Context, fn txn.TxnFunc) error { + return txn.WithTxn(ctx, r.TxnManager, fn) +} + +func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error { + return txn.WithReadTxn(ctx, r.TxnManager, fn) +} - FileStore models.FileReaderWriter - FolderStore models.FolderReaderWriter +func (r *Repository) WithDB(ctx context.Context, fn txn.TxnFunc) error { + return txn.WithDatabase(ctx, r.TxnManager, fn) } diff --git a/pkg/file/folder_rename_detect.go b/pkg/file/folder_rename_detect.go index 0b57d9c087a..4f6d31bd50b 100644 --- a/pkg/file/folder_rename_detect.go +++ b/pkg/file/folder_rename_detect.go @@ -86,6 +86,8 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models. } // rejects is a set of folder ids which were found to still exist + r := s.Repository + if err := symWalk(file.fs, file.Path, func(path string, d fs.DirEntry, err error) error { if err != nil { // don't let errors prevent scanning @@ -118,7 +120,7 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models. } // check if the file exists in the database based on basename, size and mod time - existing, err := s.Repository.FileStore.FindByFileInfo(ctx, info, size) + existing, err := r.File.FindByFileInfo(ctx, info, size) if err != nil { return fmt.Errorf("checking for existing file %q: %w", path, err) } @@ -140,7 +142,7 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models. if c == nil { // need to check if the folder exists in the filesystem - pf, err := s.Repository.FolderStore.Find(ctx, e.Base().ParentFolderID) + pf, err := r.Folder.Find(ctx, e.Base().ParentFolderID) if err != nil { return fmt.Errorf("getting parent folder %d: %w", e.Base().ParentFolderID, err) } @@ -164,7 +166,7 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models. // parent folder is missing, possible candidate // count the total number of files in the existing folder - count, err := s.Repository.FileStore.CountByFolderID(ctx, parentFolderID) + count, err := r.File.CountByFolderID(ctx, parentFolderID) if err != nil { return fmt.Errorf("counting files in folder %d: %w", parentFolderID, err) } diff --git a/pkg/file/scan.go b/pkg/file/scan.go index a0d301e60c2..fe574c7eb7c 100644 --- a/pkg/file/scan.go +++ b/pkg/file/scan.go @@ -144,7 +144,7 @@ func (s *Scanner) Scan(ctx context.Context, handlers []Handler, options ScanOpti ProgressReports: progressReporter, options: options, txnRetryer: txn.Retryer{ - Manager: s.Repository, + Manager: s.Repository.TxnManager, Retries: maxRetries, }, } @@ -163,7 +163,7 @@ func (s *scanJob) withTxn(ctx context.Context, fn func(ctx context.Context) erro } func (s *scanJob) withDB(ctx context.Context, fn func(ctx context.Context) error) error { - return txn.WithDatabase(ctx, s.Repository, fn) + return s.Repository.WithDB(ctx, fn) } func (s *scanJob) execute(ctx context.Context) { @@ -439,7 +439,7 @@ func (s *scanJob) getFolderID(ctx context.Context, path string) (*models.FolderI return &v, nil } - ret, err := s.Repository.FolderStore.FindByPath(ctx, path) + ret, err := s.Repository.Folder.FindByPath(ctx, path) if err != nil { return nil, err } @@ -469,7 +469,7 @@ func (s *scanJob) getZipFileID(ctx context.Context, zipFile *scanFile) (*models. return &v, nil } - ret, err := s.Repository.FileStore.FindByPath(ctx, path) + ret, err := s.Repository.File.FindByPath(ctx, path) if err != nil { return nil, fmt.Errorf("getting zip file ID for %q: %w", path, err) } @@ -489,7 +489,7 @@ func (s *scanJob) handleFolder(ctx context.Context, file scanFile) error { defer s.incrementProgress(file) // determine if folder already exists in data store (by path) - f, err := s.Repository.FolderStore.FindByPath(ctx, path) + f, err := s.Repository.Folder.FindByPath(ctx, path) if err != nil { return fmt.Errorf("checking for existing folder %q: %w", path, err) } @@ -553,7 +553,7 @@ func (s *scanJob) onNewFolder(ctx context.Context, file scanFile) (*models.Folde logger.Infof("%s doesn't exist. Creating new folder entry...", file.Path) }) - if err := s.Repository.FolderStore.Create(ctx, toCreate); err != nil { + if err := s.Repository.Folder.Create(ctx, toCreate); err != nil { return nil, fmt.Errorf("creating folder %q: %w", file.Path, err) } @@ -589,7 +589,7 @@ func (s *scanJob) handleFolderRename(ctx context.Context, file scanFile) (*model renamedFrom.ParentFolderID = parentFolderID - if err := s.Repository.FolderStore.Update(ctx, renamedFrom); err != nil { + if err := s.Repository.Folder.Update(ctx, renamedFrom); err != nil { return nil, fmt.Errorf("updating folder for rename %q: %w", renamedFrom.Path, err) } @@ -621,7 +621,7 @@ func (s *scanJob) onExistingFolder(ctx context.Context, f scanFile, existing *mo if update { var err error - if err = s.Repository.FolderStore.Update(ctx, existing); err != nil { + if err = s.Repository.Folder.Update(ctx, existing); err != nil { return nil, fmt.Errorf("updating folder %q: %w", f.Path, err) } } @@ -642,7 +642,7 @@ func (s *scanJob) handleFile(ctx context.Context, f scanFile) error { if err := s.withDB(ctx, func(ctx context.Context) error { // determine if file already exists in data store var err error - ff, err = s.Repository.FileStore.FindByPath(ctx, f.Path) + ff, err = s.Repository.File.FindByPath(ctx, f.Path) if err != nil { return fmt.Errorf("checking for existing file %q: %w", f.Path, err) } @@ -740,7 +740,7 @@ func (s *scanJob) onNewFile(ctx context.Context, f scanFile) (models.File, error // if not renamed, queue file for creation if err := s.withTxn(ctx, func(ctx context.Context) error { - if err := s.Repository.FileStore.Create(ctx, file); err != nil { + if err := s.Repository.File.Create(ctx, file); err != nil { return fmt.Errorf("creating file %q: %w", path, err) } @@ -833,7 +833,7 @@ func (s *scanJob) handleRename(ctx context.Context, f models.File, fp []models.F var others []models.File for _, tfp := range fp { - thisOthers, err := s.Repository.FileStore.FindByFingerprint(ctx, tfp) + thisOthers, err := s.Repository.File.FindByFingerprint(ctx, tfp) if err != nil { return nil, fmt.Errorf("getting files by fingerprint %v: %w", tfp, err) } @@ -891,12 +891,12 @@ func (s *scanJob) handleRename(ctx context.Context, f models.File, fp []models.F fBase.Fingerprints = otherBase.Fingerprints if err := s.withTxn(ctx, func(ctx context.Context) error { - if err := s.Repository.FileStore.Update(ctx, f); err != nil { + if err := s.Repository.File.Update(ctx, f); err != nil { return fmt.Errorf("updating file for rename %q: %w", fBase.Path, err) } if s.isZipFile(fBase.Basename) { - if err := TransferZipFolderHierarchy(ctx, s.Repository.FolderStore, fBase.ID, otherBase.Path, fBase.Path); err != nil { + if err := TransferZipFolderHierarchy(ctx, s.Repository.Folder, fBase.ID, otherBase.Path, fBase.Path); err != nil { return fmt.Errorf("moving folder hierarchy for renamed zip file %q: %w", fBase.Path, err) } } @@ -958,7 +958,7 @@ func (s *scanJob) setMissingMetadata(ctx context.Context, f scanFile, existing m // queue file for update if err := s.withTxn(ctx, func(ctx context.Context) error { - if err := s.Repository.FileStore.Update(ctx, existing); err != nil { + if err := s.Repository.File.Update(ctx, existing); err != nil { return fmt.Errorf("updating file %q: %w", path, err) } @@ -981,7 +981,7 @@ func (s *scanJob) setMissingFingerprints(ctx context.Context, f scanFile, existi existing.SetFingerprints(fp) if err := s.withTxn(ctx, func(ctx context.Context) error { - if err := s.Repository.FileStore.Update(ctx, existing); err != nil { + if err := s.Repository.File.Update(ctx, existing); err != nil { return fmt.Errorf("updating file %q: %w", f.Path, err) } @@ -1030,7 +1030,7 @@ func (s *scanJob) onExistingFile(ctx context.Context, f scanFile, existing model // queue file for update if err := s.withTxn(ctx, func(ctx context.Context) error { - if err := s.Repository.FileStore.Update(ctx, existing); err != nil { + if err := s.Repository.File.Update(ctx, existing); err != nil { return fmt.Errorf("updating file %q: %w", path, err) } diff --git a/pkg/scene/filename_parser.go b/pkg/scene/filename_parser.go index b7c38863e54..0426696def5 100644 --- a/pkg/scene/filename_parser.go +++ b/pkg/scene/filename_parser.go @@ -410,17 +410,19 @@ type FilenameParser struct { ParserInput models.SceneParserInput Filter *models.FindFilterType whitespaceRE *regexp.Regexp + repository FilenameParserRepository performerCache map[string]*models.Performer studioCache map[string]*models.Studio movieCache map[string]*models.Movie tagCache map[string]*models.Tag } -func NewFilenameParser(filter *models.FindFilterType, config models.SceneParserInput) *FilenameParser { +func NewFilenameParser(filter *models.FindFilterType, config models.SceneParserInput, repo FilenameParserRepository) *FilenameParser { p := &FilenameParser{ Pattern: *filter.Q, ParserInput: config, Filter: filter, + repository: repo, } p.performerCache = make(map[string]*models.Performer) @@ -457,7 +459,17 @@ type FilenameParserRepository struct { Tag models.TagQueryer } -func (p *FilenameParser) Parse(ctx context.Context, repo FilenameParserRepository) ([]*models.SceneParserResult, int, error) { +func NewFilenameParserRepository(repo models.Repository) FilenameParserRepository { + return FilenameParserRepository{ + Scene: repo.Scene, + Performer: repo.Performer, + Studio: repo.Studio, + Movie: repo.Movie, + Tag: repo.Tag, + } +} + +func (p *FilenameParser) Parse(ctx context.Context) ([]*models.SceneParserResult, int, error) { // perform the query to find the scenes mapper, err := newParseMapper(p.Pattern, p.ParserInput.IgnoreWords) @@ -479,17 +491,17 @@ func (p *FilenameParser) Parse(ctx context.Context, repo FilenameParserRepositor p.Filter.Q = nil - scenes, total, err := QueryWithCount(ctx, repo.Scene, sceneFilter, p.Filter) + scenes, total, err := QueryWithCount(ctx, p.repository.Scene, sceneFilter, p.Filter) if err != nil { return nil, 0, err } - ret := p.parseScenes(ctx, repo, scenes, mapper) + ret := p.parseScenes(ctx, scenes, mapper) return ret, total, nil } -func (p *FilenameParser) parseScenes(ctx context.Context, repo FilenameParserRepository, scenes []*models.Scene, mapper *parseMapper) []*models.SceneParserResult { +func (p *FilenameParser) parseScenes(ctx context.Context, scenes []*models.Scene, mapper *parseMapper) []*models.SceneParserResult { var ret []*models.SceneParserResult for _, scene := range scenes { sceneHolder := mapper.parse(scene) @@ -498,7 +510,7 @@ func (p *FilenameParser) parseScenes(ctx context.Context, repo FilenameParserRep r := &models.SceneParserResult{ Scene: scene, } - p.setParserResult(ctx, repo, *sceneHolder, r) + p.setParserResult(ctx, *sceneHolder, r) ret = append(ret, r) } @@ -671,7 +683,7 @@ func (p *FilenameParser) setMovies(ctx context.Context, qb MovieNameFinder, h sc } } -func (p *FilenameParser) setParserResult(ctx context.Context, repo FilenameParserRepository, h sceneHolder, result *models.SceneParserResult) { +func (p *FilenameParser) setParserResult(ctx context.Context, h sceneHolder, result *models.SceneParserResult) { if h.result.Title != "" { title := h.result.Title title = p.replaceWhitespaceCharacters(title) @@ -692,15 +704,17 @@ func (p *FilenameParser) setParserResult(ctx context.Context, repo FilenameParse result.Rating = h.result.Rating } + r := p.repository + if len(h.performers) > 0 { - p.setPerformers(ctx, repo.Performer, h, result) + p.setPerformers(ctx, r.Performer, h, result) } if len(h.tags) > 0 { - p.setTags(ctx, repo.Tag, h, result) + p.setTags(ctx, r.Tag, h, result) } - p.setStudio(ctx, repo.Studio, h, result) + p.setStudio(ctx, r.Studio, h, result) if len(h.movies) > 0 { - p.setMovies(ctx, repo.Movie, h, result) + p.setMovies(ctx, r.Movie, h, result) } } diff --git a/pkg/scraper/autotag.go b/pkg/scraper/autotag.go index 5eb3922a804..6151a9794c0 100644 --- a/pkg/scraper/autotag.go +++ b/pkg/scraper/autotag.go @@ -18,7 +18,6 @@ const ( ) type autotagScraper struct { - // repository models.Repository txnManager txn.Manager performerReader models.PerformerAutoTagQueryer studioReader models.StudioAutoTagQueryer @@ -208,9 +207,9 @@ func (s autotagScraper) spec() Scraper { } } -func getAutoTagScraper(txnManager txn.Manager, repo Repository, globalConfig GlobalConfig) scraper { +func getAutoTagScraper(repo Repository, globalConfig GlobalConfig) scraper { base := autotagScraper{ - txnManager: txnManager, + txnManager: repo.TxnManager, performerReader: repo.PerformerFinder, studioReader: repo.StudioFinder, tagReader: repo.TagFinder, diff --git a/pkg/scraper/cache.go b/pkg/scraper/cache.go index c110944f624..07d50ee71a8 100644 --- a/pkg/scraper/cache.go +++ b/pkg/scraper/cache.go @@ -76,6 +76,8 @@ type GalleryFinder interface { } type Repository struct { + TxnManager models.TxnManager + SceneFinder SceneFinder GalleryFinder GalleryFinder TagFinder TagFinder @@ -84,12 +86,27 @@ type Repository struct { StudioFinder StudioFinder } +func NewRepository(repo models.Repository) Repository { + return Repository{ + TxnManager: repo.TxnManager, + SceneFinder: repo.Scene, + GalleryFinder: repo.Gallery, + TagFinder: repo.Tag, + PerformerFinder: repo.Performer, + MovieFinder: repo.Movie, + StudioFinder: repo.Studio, + } +} + +func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error { + return txn.WithReadTxn(ctx, r.TxnManager, fn) +} + // Cache stores the database of scrapers type Cache struct { client *http.Client scrapers map[string]scraper // Scraper ID -> Scraper globalConfig GlobalConfig - txnManager txn.Manager repository Repository } @@ -121,14 +138,13 @@ func newClient(gc GlobalConfig) *http.Client { // // Scraper configurations are loaded from yml files in the provided scrapers // directory and any subdirectories. -func NewCache(globalConfig GlobalConfig, txnManager txn.Manager, repo Repository) (*Cache, error) { +func NewCache(globalConfig GlobalConfig, repo Repository) (*Cache, error) { // HTTP Client setup client := newClient(globalConfig) ret := &Cache{ client: client, globalConfig: globalConfig, - txnManager: txnManager, repository: repo, } @@ -147,7 +163,7 @@ func (c *Cache) loadScrapers() (map[string]scraper, error) { // Add built-in scrapers freeOnes := getFreeonesScraper(c.globalConfig) - autoTag := getAutoTagScraper(c.txnManager, c.repository, c.globalConfig) + autoTag := getAutoTagScraper(c.repository, c.globalConfig) scrapers[freeOnes.spec().ID] = freeOnes scrapers[autoTag.spec().ID] = autoTag @@ -368,9 +384,12 @@ func (c Cache) ScrapeID(ctx context.Context, scraperID string, id int, ty Scrape func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error) { var ret *models.Scene - if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { + r := c.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { + qb := r.SceneFinder + var err error - ret, err = c.repository.SceneFinder.Find(ctx, sceneID) + ret, err = qb.Find(ctx, sceneID) if err != nil { return err } @@ -379,7 +398,7 @@ func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error) return fmt.Errorf("scene with id %d not found", sceneID) } - return ret.LoadURLs(ctx, c.repository.SceneFinder) + return ret.LoadURLs(ctx, qb) }); err != nil { return nil, err } @@ -388,9 +407,12 @@ func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error) func (c Cache) getGallery(ctx context.Context, galleryID int) (*models.Gallery, error) { var ret *models.Gallery - if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { + r := c.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { + qb := r.GalleryFinder + var err error - ret, err = c.repository.GalleryFinder.Find(ctx, galleryID) + ret, err = qb.Find(ctx, galleryID) if err != nil { return err } @@ -399,7 +421,7 @@ func (c Cache) getGallery(ctx context.Context, galleryID int) (*models.Gallery, return fmt.Errorf("gallery with id %d not found", galleryID) } - return ret.LoadFiles(ctx, c.repository.GalleryFinder) + return ret.LoadFiles(ctx, qb) }); err != nil { return nil, err } diff --git a/pkg/scraper/postprocessing.go b/pkg/scraper/postprocessing.go index e504e4d1cac..666aefc0a3e 100644 --- a/pkg/scraper/postprocessing.go +++ b/pkg/scraper/postprocessing.go @@ -6,7 +6,6 @@ import ( "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/txn" ) // postScrape handles post-processing of scraped content. If the content @@ -46,8 +45,9 @@ func (c Cache) postScrape(ctx context.Context, content ScrapedContent) (ScrapedC } func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerformer) (ScrapedContent, error) { - if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { - tqb := c.repository.TagFinder + r := c.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { + tqb := r.TagFinder tags, err := postProcessTags(ctx, tqb, p.Tags) if err != nil { @@ -72,8 +72,9 @@ func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerforme func (c Cache) postScrapeMovie(ctx context.Context, m models.ScrapedMovie) (ScrapedContent, error) { if m.Studio != nil { - if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { - return match.ScrapedStudio(ctx, c.repository.StudioFinder, m.Studio, nil) + r := c.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { + return match.ScrapedStudio(ctx, r.StudioFinder, m.Studio, nil) }); err != nil { return nil, err } @@ -113,11 +114,12 @@ func (c Cache) postScrapeScene(ctx context.Context, scene ScrapedScene) (Scraped scene.URLs = []string{*scene.URL} } - if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { - pqb := c.repository.PerformerFinder - mqb := c.repository.MovieFinder - tqb := c.repository.TagFinder - sqb := c.repository.StudioFinder + r := c.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { + pqb := r.PerformerFinder + mqb := r.MovieFinder + tqb := r.TagFinder + sqb := r.StudioFinder for _, p := range scene.Performers { if p == nil { @@ -167,10 +169,11 @@ func (c Cache) postScrapeScene(ctx context.Context, scene ScrapedScene) (Scraped } func (c Cache) postScrapeGallery(ctx context.Context, g ScrapedGallery) (ScrapedContent, error) { - if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { - pqb := c.repository.PerformerFinder - tqb := c.repository.TagFinder - sqb := c.repository.StudioFinder + r := c.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { + pqb := r.PerformerFinder + tqb := r.TagFinder + sqb := r.StudioFinder for _, p := range g.Performers { err := match.ScrapedPerformer(ctx, pqb, p, nil) diff --git a/pkg/scraper/stashbox/stash_box.go b/pkg/scraper/stashbox/stash_box.go index 7abff7032e2..46971aeefaa 100644 --- a/pkg/scraper/stashbox/stash_box.go +++ b/pkg/scraper/stashbox/stash_box.go @@ -56,22 +56,37 @@ type TagFinder interface { } type Repository struct { + TxnManager models.TxnManager + Scene SceneReader Performer PerformerReader Tag TagFinder Studio StudioReader } +func NewRepository(repo models.Repository) Repository { + return Repository{ + TxnManager: repo.TxnManager, + Scene: repo.Scene, + Performer: repo.Performer, + Tag: repo.Tag, + Studio: repo.Studio, + } +} + +func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error { + return txn.WithReadTxn(ctx, r.TxnManager, fn) +} + // Client represents the client interface to a stash-box server instance. type Client struct { client *graphql.Client - txnManager txn.Manager repository Repository box models.StashBox } // NewClient returns a new instance of a stash-box client. -func NewClient(box models.StashBox, txnManager txn.Manager, repo Repository) *Client { +func NewClient(box models.StashBox, repo Repository) *Client { authHeader := func(req *http.Request) { req.Header.Set("ApiKey", box.APIKey) } @@ -82,7 +97,6 @@ func NewClient(box models.StashBox, txnManager txn.Manager, repo Repository) *Cl return &Client{ client: client, - txnManager: txnManager, repository: repo, box: box, } @@ -129,8 +143,9 @@ func (c Client) FindStashBoxSceneByFingerprints(ctx context.Context, sceneID int func (c Client) FindStashBoxScenesByFingerprints(ctx context.Context, ids []int) ([][]*scraper.ScrapedScene, error) { var fingerprints [][]*graphql.FingerprintQueryInput - if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { - qb := c.repository.Scene + r := c.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { + qb := r.Scene for _, sceneID := range ids { scene, err := qb.Find(ctx, sceneID) @@ -142,7 +157,7 @@ func (c Client) FindStashBoxScenesByFingerprints(ctx context.Context, ids []int) return fmt.Errorf("scene with id %d not found", sceneID) } - if err := scene.LoadFiles(ctx, c.repository.Scene); err != nil { + if err := scene.LoadFiles(ctx, r.Scene); err != nil { return err } @@ -243,8 +258,9 @@ func (c Client) SubmitStashBoxFingerprints(ctx context.Context, sceneIDs []strin var fingerprints []graphql.FingerprintSubmission - if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { - qb := c.repository.Scene + r := c.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { + qb := r.Scene for _, sceneID := range ids { scene, err := qb.Find(ctx, sceneID) @@ -382,9 +398,9 @@ func (c Client) FindStashBoxPerformersByNames(ctx context.Context, performerIDs } var performers []*models.Performer - - if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { - qb := c.repository.Performer + r := c.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { + qb := r.Performer for _, performerID := range ids { performer, err := qb.Find(ctx, performerID) @@ -417,8 +433,9 @@ func (c Client) FindStashBoxPerformersByPerformerNames(ctx context.Context, perf var performers []*models.Performer - if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { - qb := c.repository.Performer + r := c.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { + qb := r.Performer for _, performerID := range ids { performer, err := qb.Find(ctx, performerID) @@ -739,14 +756,15 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen ss.URL = &s.Urls[0].URL } - if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { - pqb := c.repository.Performer - tqb := c.repository.Tag + r := c.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { + pqb := r.Performer + tqb := r.Tag if s.Studio != nil { ss.Studio = studioFragmentToScrapedStudio(*s.Studio) - err := match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio, &c.box.Endpoint) + err := match.ScrapedStudio(ctx, r.Studio, ss.Studio, &c.box.Endpoint) if err != nil { return err } @@ -761,7 +779,7 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen if parentStudio.FindStudio != nil { ss.Studio.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio) - err = match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio.Parent, &c.box.Endpoint) + err = match.ScrapedStudio(ctx, r.Studio, ss.Studio.Parent, &c.box.Endpoint) if err != nil { return err } @@ -809,8 +827,9 @@ func (c Client) FindStashBoxPerformerByID(ctx context.Context, id string) (*mode ret := performerFragmentToScrapedPerformer(*performer.FindPerformer) - if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { - err := match.ScrapedPerformer(ctx, c.repository.Performer, ret, &c.box.Endpoint) + r := c.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { + err := match.ScrapedPerformer(ctx, r.Performer, ret, &c.box.Endpoint) return err }); err != nil { return nil, err @@ -836,8 +855,9 @@ func (c Client) FindStashBoxPerformerByName(ctx context.Context, name string) (* return nil, nil } - if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { - err := match.ScrapedPerformer(ctx, c.repository.Performer, ret, &c.box.Endpoint) + r := c.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { + err := match.ScrapedPerformer(ctx, r.Performer, ret, &c.box.Endpoint) return err }); err != nil { return nil, err @@ -864,10 +884,11 @@ func (c Client) FindStashBoxStudio(ctx context.Context, query string) (*models.S var ret *models.ScrapedStudio if studio.FindStudio != nil { - if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { + r := c.repository + if err := r.WithReadTxn(ctx, func(ctx context.Context) error { ret = studioFragmentToScrapedStudio(*studio.FindStudio) - err = match.ScrapedStudio(ctx, c.repository.Studio, ret, &c.box.Endpoint) + err = match.ScrapedStudio(ctx, r.Studio, ret, &c.box.Endpoint) if err != nil { return err } @@ -881,7 +902,7 @@ func (c Client) FindStashBoxStudio(ctx context.Context, query string) (*models.S if parentStudio.FindStudio != nil { ret.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio) - err = match.ScrapedStudio(ctx, c.repository.Studio, ret.Parent, &c.box.Endpoint) + err = match.ScrapedStudio(ctx, r.Studio, ret.Parent, &c.box.Endpoint) if err != nil { return err } From 2bff76470a071863f9b079767604fc466063b3ba Mon Sep 17 00:00:00 2001 From: DingDongSoLong4 <99329275+DingDongSoLong4@users.noreply.github.com> Date: Fri, 1 Sep 2023 16:48:08 +0200 Subject: [PATCH 3/9] Fix tests and add database mock --- internal/api/resolver_mutation_tag_test.go | 35 +++---- internal/autotag/gallery_test.go | 83 +++++++-------- internal/autotag/image_test.go | 83 +++++++-------- internal/autotag/integration_test.go | 4 +- internal/autotag/performer_test.go | 36 +++---- internal/autotag/scene_test.go | 83 +++++++-------- internal/autotag/studio_test.go | 43 ++++---- internal/autotag/tag_test.go | 42 ++++---- internal/identify/identify_test.go | 29 ++++-- internal/identify/performer_test.go | 16 +-- internal/identify/scene_test.go | 35 ++++--- internal/identify/studio_test.go | 13 +-- pkg/gallery/export_test.go | 23 ++-- pkg/gallery/import_test.go | 84 +++++++-------- pkg/image/export_test.go | 12 +-- pkg/image/import_test.go | 84 +++++++-------- pkg/models/mocks/database.go | 91 ++++++++++++++++ pkg/models/mocks/transaction.go | 59 ----------- pkg/movie/export_test.go | 38 ++++--- pkg/movie/import_test.go | 79 +++++++------- pkg/performer/export_test.go | 12 +-- pkg/performer/import_test.go | 77 +++++++------- pkg/scene/export_test.go | 78 +++++++------- pkg/scene/import_test.go | 115 ++++++++++---------- pkg/scene/update_test.go | 18 ++-- pkg/sqlite/setup_test.go | 4 +- pkg/studio/export_test.go | 24 ++--- pkg/studio/import_test.go | 116 ++++++++++----------- pkg/tag/export_test.go | 44 ++++---- pkg/tag/import_test.go | 114 ++++++++++---------- pkg/tag/update_test.go | 15 ++- 31 files changed, 806 insertions(+), 783 deletions(-) create mode 100644 pkg/models/mocks/database.go delete mode 100644 pkg/models/mocks/transaction.go diff --git a/internal/api/resolver_mutation_tag_test.go b/internal/api/resolver_mutation_tag_test.go index ba604061681..0c9a5c733ce 100644 --- a/internal/api/resolver_mutation_tag_test.go +++ b/internal/api/resolver_mutation_tag_test.go @@ -14,13 +14,9 @@ import ( ) // TODO - move this into a common area -func newResolver() *Resolver { - txnMgr := &mocks.TxnManager{} +func newResolver(db *mocks.Database) *Resolver { return &Resolver{ - repository: models.Repository{ - TxnManager: txnMgr, - Tag: &mocks.TagReaderWriter{}, - }, + repository: db.Repository(), hookExecutor: &mockHookExecutor{}, } } @@ -43,9 +39,8 @@ func (*mockHookExecutor) ExecutePostHooks(ctx context.Context, id int, hookType } func TestTagCreate(t *testing.T) { - r := newResolver() - - tagRW := r.repository.Tag.(*mocks.TagReaderWriter) + db := mocks.NewDatabase() + r := newResolver(db) pp := 1 findFilter := &models.FindFilterType{ @@ -70,17 +65,17 @@ func TestTagCreate(t *testing.T) { } } - tagRW.On("Query", mock.Anything, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{ + db.Tag.On("Query", mock.Anything, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{ { ID: existingTagID, Name: existingTagName, }, }, 1, nil).Once() - tagRW.On("Query", mock.Anything, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once() - tagRW.On("Query", mock.Anything, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once() + db.Tag.On("Query", mock.Anything, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once() + db.Tag.On("Query", mock.Anything, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once() expectedErr := errors.New("TagCreate error") - tagRW.On("Create", mock.Anything, mock.AnythingOfType("*models.Tag")).Return(expectedErr) + db.Tag.On("Create", mock.Anything, mock.AnythingOfType("*models.Tag")).Return(expectedErr) // fails here because testCtx is empty // TODO: Fix this @@ -99,22 +94,22 @@ func TestTagCreate(t *testing.T) { }) assert.Equal(t, expectedErr, err) - tagRW.AssertExpectations(t) + db.Tag.AssertExpectations(t) - r = newResolver() - tagRW = r.repository.Tag.(*mocks.TagReaderWriter) + db = mocks.NewDatabase() + r = newResolver(db) - tagRW.On("Query", mock.Anything, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once() - tagRW.On("Query", mock.Anything, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once() + db.Tag.On("Query", mock.Anything, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once() + db.Tag.On("Query", mock.Anything, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once() newTag := &models.Tag{ ID: newTagID, Name: tagName, } - tagRW.On("Create", mock.Anything, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { + db.Tag.On("Create", mock.Anything, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { arg := args.Get(1).(*models.Tag) arg.ID = newTagID }).Return(nil) - tagRW.On("Find", mock.Anything, newTagID).Return(newTag, nil) + db.Tag.On("Find", mock.Anything, newTagID).Return(newTag, nil) tag, err := r.Mutation().TagCreate(testCtx, TagCreateInput{ Name: tagName, diff --git a/internal/autotag/gallery_test.go b/internal/autotag/gallery_test.go index 23c3d931ee6..b6214b1246b 100644 --- a/internal/autotag/gallery_test.go +++ b/internal/autotag/gallery_test.go @@ -52,11 +52,10 @@ func TestGalleryPerformers(t *testing.T) { assert := assert.New(t) for _, test := range testTables { - mockPerformerReader := &mocks.PerformerReaderWriter{} - mockGalleryReader := &mocks.GalleryReaderWriter{} + db := mocks.NewDatabase() - mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) - mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() + db.Performer.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + db.Performer.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() if test.Matches { matchPartial := mock.MatchedBy(func(got models.GalleryPartial) bool { @@ -69,7 +68,7 @@ func TestGalleryPerformers(t *testing.T) { return galleryPartialsEqual(got, expected) }) - mockGalleryReader.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once() + db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once() } gallery := models.Gallery{ @@ -77,11 +76,11 @@ func TestGalleryPerformers(t *testing.T) { Path: test.Path, PerformerIDs: models.NewRelatedIDs([]int{}), } - err := GalleryPerformers(testCtx, &gallery, mockGalleryReader, mockPerformerReader, nil) + err := GalleryPerformers(testCtx, &gallery, db.Gallery, db.Performer, nil) assert.Nil(err) - mockPerformerReader.AssertExpectations(t) - mockGalleryReader.AssertExpectations(t) + db.Performer.AssertExpectations(t) + db.Gallery.AssertExpectations(t) } } @@ -107,7 +106,7 @@ func TestGalleryStudios(t *testing.T) { assert := assert.New(t) - doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) { + doTest := func(db *mocks.Database, test pathTestTable) { if test.Matches { matchPartial := mock.MatchedBy(func(got models.GalleryPartial) bool { expected := models.GalleryPartial{ @@ -116,29 +115,28 @@ func TestGalleryStudios(t *testing.T) { return galleryPartialsEqual(got, expected) }) - mockGalleryReader.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once() + db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once() } gallery := models.Gallery{ ID: galleryID, Path: test.Path, } - err := GalleryStudios(testCtx, &gallery, mockGalleryReader, mockStudioReader, nil) + err := GalleryStudios(testCtx, &gallery, db.Gallery, db.Studio, nil) assert.Nil(err) - mockStudioReader.AssertExpectations(t) - mockGalleryReader.AssertExpectations(t) + db.Studio.AssertExpectations(t) + db.Gallery.AssertExpectations(t) } for _, test := range testTables { - mockStudioReader := &mocks.StudioReaderWriter{} - mockGalleryReader := &mocks.GalleryReaderWriter{} + db := mocks.NewDatabase() - mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) - mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() - mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() + db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() + db.Studio.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() - doTest(mockStudioReader, mockGalleryReader, test) + doTest(db, test) } // test against aliases @@ -146,17 +144,16 @@ func TestGalleryStudios(t *testing.T) { studio.Name = unmatchedName for _, test := range testTables { - mockStudioReader := &mocks.StudioReaderWriter{} - mockGalleryReader := &mocks.GalleryReaderWriter{} + db := mocks.NewDatabase() - mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) - mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() - mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{ + db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() + db.Studio.On("GetAliases", testCtx, studioID).Return([]string{ studioName, }, nil).Once() - mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once() + db.Studio.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once() - doTest(mockStudioReader, mockGalleryReader, test) + doTest(db, test) } } @@ -182,7 +179,7 @@ func TestGalleryTags(t *testing.T) { assert := assert.New(t) - doTest := func(mockTagReader *mocks.TagReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) { + doTest := func(db *mocks.Database, test pathTestTable) { if test.Matches { matchPartial := mock.MatchedBy(func(got models.GalleryPartial) bool { expected := models.GalleryPartial{ @@ -194,7 +191,7 @@ func TestGalleryTags(t *testing.T) { return galleryPartialsEqual(got, expected) }) - mockGalleryReader.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once() + db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once() } gallery := models.Gallery{ @@ -202,38 +199,36 @@ func TestGalleryTags(t *testing.T) { Path: test.Path, TagIDs: models.NewRelatedIDs([]int{}), } - err := GalleryTags(testCtx, &gallery, mockGalleryReader, mockTagReader, nil) + err := GalleryTags(testCtx, &gallery, db.Gallery, db.Tag, nil) assert.Nil(err) - mockTagReader.AssertExpectations(t) - mockGalleryReader.AssertExpectations(t) + db.Tag.AssertExpectations(t) + db.Gallery.AssertExpectations(t) } for _, test := range testTables { - mockTagReader := &mocks.TagReaderWriter{} - mockGalleryReader := &mocks.GalleryReaderWriter{} + db := mocks.NewDatabase() - mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) - mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() - mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() + db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() + db.Tag.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() - doTest(mockTagReader, mockGalleryReader, test) + doTest(db, test) } const unmatchedName = "unmatched" tag.Name = unmatchedName for _, test := range testTables { - mockTagReader := &mocks.TagReaderWriter{} - mockGalleryReader := &mocks.GalleryReaderWriter{} + db := mocks.NewDatabase() - mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) - mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() - mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{ + db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() + db.Tag.On("GetAliases", testCtx, tagID).Return([]string{ tagName, }, nil).Once() - mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once() + db.Tag.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once() - doTest(mockTagReader, mockGalleryReader, test) + doTest(db, test) } } diff --git a/internal/autotag/image_test.go b/internal/autotag/image_test.go index 06991beea1f..8c5b6ae0762 100644 --- a/internal/autotag/image_test.go +++ b/internal/autotag/image_test.go @@ -49,11 +49,10 @@ func TestImagePerformers(t *testing.T) { assert := assert.New(t) for _, test := range testTables { - mockPerformerReader := &mocks.PerformerReaderWriter{} - mockImageReader := &mocks.ImageReaderWriter{} + db := mocks.NewDatabase() - mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) - mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() + db.Performer.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + db.Performer.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() if test.Matches { matchPartial := mock.MatchedBy(func(got models.ImagePartial) bool { @@ -66,7 +65,7 @@ func TestImagePerformers(t *testing.T) { return imagePartialsEqual(got, expected) }) - mockImageReader.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once() + db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once() } image := models.Image{ @@ -74,11 +73,11 @@ func TestImagePerformers(t *testing.T) { Path: test.Path, PerformerIDs: models.NewRelatedIDs([]int{}), } - err := ImagePerformers(testCtx, &image, mockImageReader, mockPerformerReader, nil) + err := ImagePerformers(testCtx, &image, db.Image, db.Performer, nil) assert.Nil(err) - mockPerformerReader.AssertExpectations(t) - mockImageReader.AssertExpectations(t) + db.Performer.AssertExpectations(t) + db.Image.AssertExpectations(t) } } @@ -104,7 +103,7 @@ func TestImageStudios(t *testing.T) { assert := assert.New(t) - doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) { + doTest := func(db *mocks.Database, test pathTestTable) { if test.Matches { matchPartial := mock.MatchedBy(func(got models.ImagePartial) bool { expected := models.ImagePartial{ @@ -113,29 +112,28 @@ func TestImageStudios(t *testing.T) { return imagePartialsEqual(got, expected) }) - mockImageReader.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once() + db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once() } image := models.Image{ ID: imageID, Path: test.Path, } - err := ImageStudios(testCtx, &image, mockImageReader, mockStudioReader, nil) + err := ImageStudios(testCtx, &image, db.Image, db.Studio, nil) assert.Nil(err) - mockStudioReader.AssertExpectations(t) - mockImageReader.AssertExpectations(t) + db.Studio.AssertExpectations(t) + db.Image.AssertExpectations(t) } for _, test := range testTables { - mockStudioReader := &mocks.StudioReaderWriter{} - mockImageReader := &mocks.ImageReaderWriter{} + db := mocks.NewDatabase() - mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) - mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() - mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() + db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() + db.Studio.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() - doTest(mockStudioReader, mockImageReader, test) + doTest(db, test) } // test against aliases @@ -143,17 +141,16 @@ func TestImageStudios(t *testing.T) { studio.Name = unmatchedName for _, test := range testTables { - mockStudioReader := &mocks.StudioReaderWriter{} - mockImageReader := &mocks.ImageReaderWriter{} + db := mocks.NewDatabase() - mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) - mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() - mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{ + db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() + db.Studio.On("GetAliases", testCtx, studioID).Return([]string{ studioName, }, nil).Once() - mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once() + db.Studio.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once() - doTest(mockStudioReader, mockImageReader, test) + doTest(db, test) } } @@ -179,7 +176,7 @@ func TestImageTags(t *testing.T) { assert := assert.New(t) - doTest := func(mockTagReader *mocks.TagReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) { + doTest := func(db *mocks.Database, test pathTestTable) { if test.Matches { matchPartial := mock.MatchedBy(func(got models.ImagePartial) bool { expected := models.ImagePartial{ @@ -191,7 +188,7 @@ func TestImageTags(t *testing.T) { return imagePartialsEqual(got, expected) }) - mockImageReader.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once() + db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once() } image := models.Image{ @@ -199,22 +196,21 @@ func TestImageTags(t *testing.T) { Path: test.Path, TagIDs: models.NewRelatedIDs([]int{}), } - err := ImageTags(testCtx, &image, mockImageReader, mockTagReader, nil) + err := ImageTags(testCtx, &image, db.Image, db.Tag, nil) assert.Nil(err) - mockTagReader.AssertExpectations(t) - mockImageReader.AssertExpectations(t) + db.Tag.AssertExpectations(t) + db.Image.AssertExpectations(t) } for _, test := range testTables { - mockTagReader := &mocks.TagReaderWriter{} - mockImageReader := &mocks.ImageReaderWriter{} + db := mocks.NewDatabase() - mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) - mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() - mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() + db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() + db.Tag.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() - doTest(mockTagReader, mockImageReader, test) + doTest(db, test) } // test against aliases @@ -222,16 +218,15 @@ func TestImageTags(t *testing.T) { tag.Name = unmatchedName for _, test := range testTables { - mockTagReader := &mocks.TagReaderWriter{} - mockImageReader := &mocks.ImageReaderWriter{} + db := mocks.NewDatabase() - mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) - mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() - mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{ + db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() + db.Tag.On("GetAliases", testCtx, tagID).Return([]string{ tagName, }, nil).Once() - mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once() + db.Tag.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once() - doTest(mockTagReader, mockImageReader, test) + doTest(db, test) } } diff --git a/internal/autotag/integration_test.go b/internal/autotag/integration_test.go index add74133cf0..ce5bf7b212c 100644 --- a/internal/autotag/integration_test.go +++ b/internal/autotag/integration_test.go @@ -474,11 +474,11 @@ func createGallery(ctx context.Context, w models.GalleryWriter, o *models.Galler } func withTxn(f func(ctx context.Context) error) error { - return txn.WithTxn(context.TODO(), db, f) + return txn.WithTxn(testCtx, db, f) } func withDB(f func(ctx context.Context) error) error { - return txn.WithDatabase(context.TODO(), db, f) + return txn.WithDatabase(testCtx, db, f) } func populateDB() error { diff --git a/internal/autotag/performer_test.go b/internal/autotag/performer_test.go index aa0a43d92f8..1b798e24081 100644 --- a/internal/autotag/performer_test.go +++ b/internal/autotag/performer_test.go @@ -45,7 +45,7 @@ func TestPerformerScenes(t *testing.T) { } func testPerformerScenes(t *testing.T, performerName, expectedRegex string) { - mockSceneReader := &mocks.SceneReaderWriter{} + db := mocks.NewDatabase() const performerID = 2 @@ -84,7 +84,7 @@ func testPerformerScenes(t *testing.T, performerName, expectedRegex string) { Direction: &direction, } - mockSceneReader.On("Query", mock.Anything, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)). + db.Scene.On("Query", mock.Anything, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)). Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() for i := range matchingPaths { @@ -100,19 +100,19 @@ func testPerformerScenes(t *testing.T, performerName, expectedRegex string) { return scenePartialsEqual(got, expected) }) - mockSceneReader.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once() + db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once() } tagger := Tagger{ - TxnManager: &mocks.TxnManager{}, + TxnManager: db, } - err := tagger.PerformerScenes(testCtx, &performer, nil, mockSceneReader) + err := tagger.PerformerScenes(testCtx, &performer, nil, db.Scene) assert := assert.New(t) assert.Nil(err) - mockSceneReader.AssertExpectations(t) + db.Scene.AssertExpectations(t) } func TestPerformerImages(t *testing.T) { @@ -140,7 +140,7 @@ func TestPerformerImages(t *testing.T) { } func testPerformerImages(t *testing.T, performerName, expectedRegex string) { - mockImageReader := &mocks.ImageReaderWriter{} + db := mocks.NewDatabase() const performerID = 2 @@ -179,7 +179,7 @@ func testPerformerImages(t *testing.T, performerName, expectedRegex string) { Direction: &direction, } - mockImageReader.On("Query", mock.Anything, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)). + db.Image.On("Query", mock.Anything, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)). Return(mocks.ImageQueryResult(images, len(images)), nil).Once() for i := range matchingPaths { @@ -195,19 +195,19 @@ func testPerformerImages(t *testing.T, performerName, expectedRegex string) { return imagePartialsEqual(got, expected) }) - mockImageReader.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once() + db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once() } tagger := Tagger{ - TxnManager: &mocks.TxnManager{}, + TxnManager: db, } - err := tagger.PerformerImages(testCtx, &performer, nil, mockImageReader) + err := tagger.PerformerImages(testCtx, &performer, nil, db.Image) assert := assert.New(t) assert.Nil(err) - mockImageReader.AssertExpectations(t) + db.Image.AssertExpectations(t) } func TestPerformerGalleries(t *testing.T) { @@ -235,7 +235,7 @@ func TestPerformerGalleries(t *testing.T) { } func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) { - mockGalleryReader := &mocks.GalleryReaderWriter{} + db := mocks.NewDatabase() const performerID = 2 @@ -275,7 +275,7 @@ func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) { Direction: &direction, } - mockGalleryReader.On("Query", mock.Anything, expectedGalleryFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() + db.Gallery.On("Query", mock.Anything, expectedGalleryFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() for i := range matchingPaths { galleryID := i + 1 @@ -290,17 +290,17 @@ func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) { return galleryPartialsEqual(got, expected) }) - mockGalleryReader.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once() + db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once() } tagger := Tagger{ - TxnManager: &mocks.TxnManager{}, + TxnManager: db, } - err := tagger.PerformerGalleries(testCtx, &performer, nil, mockGalleryReader) + err := tagger.PerformerGalleries(testCtx, &performer, nil, db.Gallery) assert := assert.New(t) assert.Nil(err) - mockGalleryReader.AssertExpectations(t) + db.Gallery.AssertExpectations(t) } diff --git a/internal/autotag/scene_test.go b/internal/autotag/scene_test.go index a714c364c41..611c56c38a3 100644 --- a/internal/autotag/scene_test.go +++ b/internal/autotag/scene_test.go @@ -182,11 +182,10 @@ func TestScenePerformers(t *testing.T) { assert := assert.New(t) for _, test := range testTables { - mockPerformerReader := &mocks.PerformerReaderWriter{} - mockSceneReader := &mocks.SceneReaderWriter{} + db := mocks.NewDatabase() - mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) - mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() + db.Performer.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + db.Performer.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() scene := models.Scene{ ID: sceneID, @@ -205,14 +204,14 @@ func TestScenePerformers(t *testing.T) { return scenePartialsEqual(got, expected) }) - mockSceneReader.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once() + db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once() } - err := ScenePerformers(testCtx, &scene, mockSceneReader, mockPerformerReader, nil) + err := ScenePerformers(testCtx, &scene, db.Scene, db.Performer, nil) assert.Nil(err) - mockPerformerReader.AssertExpectations(t) - mockSceneReader.AssertExpectations(t) + db.Performer.AssertExpectations(t) + db.Scene.AssertExpectations(t) } } @@ -240,7 +239,7 @@ func TestSceneStudios(t *testing.T) { assert := assert.New(t) - doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) { + doTest := func(db *mocks.Database, test pathTestTable) { if test.Matches { matchPartial := mock.MatchedBy(func(got models.ScenePartial) bool { expected := models.ScenePartial{ @@ -249,29 +248,28 @@ func TestSceneStudios(t *testing.T) { return scenePartialsEqual(got, expected) }) - mockSceneReader.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once() + db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once() } scene := models.Scene{ ID: sceneID, Path: test.Path, } - err := SceneStudios(testCtx, &scene, mockSceneReader, mockStudioReader, nil) + err := SceneStudios(testCtx, &scene, db.Scene, db.Studio, nil) assert.Nil(err) - mockStudioReader.AssertExpectations(t) - mockSceneReader.AssertExpectations(t) + db.Studio.AssertExpectations(t) + db.Scene.AssertExpectations(t) } for _, test := range testTables { - mockStudioReader := &mocks.StudioReaderWriter{} - mockSceneReader := &mocks.SceneReaderWriter{} + db := mocks.NewDatabase() - mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) - mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() - mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() + db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() + db.Studio.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() - doTest(mockStudioReader, mockSceneReader, test) + doTest(db, test) } const unmatchedName = "unmatched" @@ -279,17 +277,16 @@ func TestSceneStudios(t *testing.T) { // test against aliases for _, test := range testTables { - mockStudioReader := &mocks.StudioReaderWriter{} - mockSceneReader := &mocks.SceneReaderWriter{} + db := mocks.NewDatabase() - mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) - mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() - mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{ + db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() + db.Studio.On("GetAliases", testCtx, studioID).Return([]string{ studioName, }, nil).Once() - mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once() + db.Studio.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once() - doTest(mockStudioReader, mockSceneReader, test) + doTest(db, test) } } @@ -315,7 +312,7 @@ func TestSceneTags(t *testing.T) { assert := assert.New(t) - doTest := func(mockTagReader *mocks.TagReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) { + doTest := func(db *mocks.Database, test pathTestTable) { if test.Matches { matchPartial := mock.MatchedBy(func(got models.ScenePartial) bool { expected := models.ScenePartial{ @@ -327,7 +324,7 @@ func TestSceneTags(t *testing.T) { return scenePartialsEqual(got, expected) }) - mockSceneReader.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once() + db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once() } scene := models.Scene{ @@ -335,22 +332,21 @@ func TestSceneTags(t *testing.T) { Path: test.Path, TagIDs: models.NewRelatedIDs([]int{}), } - err := SceneTags(testCtx, &scene, mockSceneReader, mockTagReader, nil) + err := SceneTags(testCtx, &scene, db.Scene, db.Tag, nil) assert.Nil(err) - mockTagReader.AssertExpectations(t) - mockSceneReader.AssertExpectations(t) + db.Tag.AssertExpectations(t) + db.Scene.AssertExpectations(t) } for _, test := range testTables { - mockTagReader := &mocks.TagReaderWriter{} - mockSceneReader := &mocks.SceneReaderWriter{} + db := mocks.NewDatabase() - mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) - mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() - mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() + db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() + db.Tag.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() - doTest(mockTagReader, mockSceneReader, test) + doTest(db, test) } const unmatchedName = "unmatched" @@ -358,16 +354,15 @@ func TestSceneTags(t *testing.T) { // test against aliases for _, test := range testTables { - mockTagReader := &mocks.TagReaderWriter{} - mockSceneReader := &mocks.SceneReaderWriter{} + db := mocks.NewDatabase() - mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) - mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() - mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{ + db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() + db.Tag.On("GetAliases", testCtx, tagID).Return([]string{ tagName, }, nil).Once() - mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once() + db.Tag.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once() - doTest(mockTagReader, mockSceneReader, test) + doTest(db, test) } } diff --git a/internal/autotag/studio_test.go b/internal/autotag/studio_test.go index aa52c9c5179..d5db603d34f 100644 --- a/internal/autotag/studio_test.go +++ b/internal/autotag/studio_test.go @@ -83,7 +83,7 @@ func testStudioScenes(t *testing.T, tc testStudioCase) { aliasName := tc.aliasName aliasRegex := tc.aliasRegex - mockSceneReader := &mocks.SceneReaderWriter{} + db := mocks.NewDatabase() var studioID = 2 @@ -130,7 +130,7 @@ func testStudioScenes(t *testing.T, tc testStudioCase) { } // if alias provided, then don't find by name - onNameQuery := mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)) + onNameQuery := db.Scene.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)) if aliasName == "" { onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() @@ -145,7 +145,7 @@ func testStudioScenes(t *testing.T, tc testStudioCase) { }, } - mockSceneReader.On("Query", mock.Anything, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). + db.Scene.On("Query", mock.Anything, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() } @@ -159,19 +159,19 @@ func testStudioScenes(t *testing.T, tc testStudioCase) { return scenePartialsEqual(got, expected) }) - mockSceneReader.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once() + db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once() } tagger := Tagger{ - TxnManager: &mocks.TxnManager{}, + TxnManager: db, } - err := tagger.StudioScenes(testCtx, &studio, nil, aliases, mockSceneReader) + err := tagger.StudioScenes(testCtx, &studio, nil, aliases, db.Scene) assert := assert.New(t) assert.Nil(err) - mockSceneReader.AssertExpectations(t) + db.Scene.AssertExpectations(t) } func TestStudioImages(t *testing.T) { @@ -188,7 +188,7 @@ func testStudioImages(t *testing.T, tc testStudioCase) { aliasName := tc.aliasName aliasRegex := tc.aliasRegex - mockImageReader := &mocks.ImageReaderWriter{} + db := mocks.NewDatabase() var studioID = 2 @@ -234,7 +234,7 @@ func testStudioImages(t *testing.T, tc testStudioCase) { } // if alias provided, then don't find by name - onNameQuery := mockImageReader.On("Query", mock.Anything, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)) + onNameQuery := db.Image.On("Query", mock.Anything, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)) if aliasName == "" { onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once() } else { @@ -248,7 +248,7 @@ func testStudioImages(t *testing.T, tc testStudioCase) { }, } - mockImageReader.On("Query", mock.Anything, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). + db.Image.On("Query", mock.Anything, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). Return(mocks.ImageQueryResult(images, len(images)), nil).Once() } @@ -262,19 +262,19 @@ func testStudioImages(t *testing.T, tc testStudioCase) { return imagePartialsEqual(got, expected) }) - mockImageReader.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once() + db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once() } tagger := Tagger{ - TxnManager: &mocks.TxnManager{}, + TxnManager: db, } - err := tagger.StudioImages(testCtx, &studio, nil, aliases, mockImageReader) + err := tagger.StudioImages(testCtx, &studio, nil, aliases, db.Image) assert := assert.New(t) assert.Nil(err) - mockImageReader.AssertExpectations(t) + db.Image.AssertExpectations(t) } func TestStudioGalleries(t *testing.T) { @@ -290,7 +290,8 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) { expectedRegex := tc.expectedRegex aliasName := tc.aliasName aliasRegex := tc.aliasRegex - mockGalleryReader := &mocks.GalleryReaderWriter{} + + db := mocks.NewDatabase() var studioID = 2 @@ -337,7 +338,7 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) { } // if alias provided, then don't find by name - onNameQuery := mockGalleryReader.On("Query", mock.Anything, expectedGalleryFilter, expectedFindFilter) + onNameQuery := db.Gallery.On("Query", mock.Anything, expectedGalleryFilter, expectedFindFilter) if aliasName == "" { onNameQuery.Return(galleries, len(galleries), nil).Once() } else { @@ -351,7 +352,7 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) { }, } - mockGalleryReader.On("Query", mock.Anything, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() + db.Gallery.On("Query", mock.Anything, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() } for i := range matchingPaths { @@ -364,17 +365,17 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) { return galleryPartialsEqual(got, expected) }) - mockGalleryReader.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once() + db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once() } tagger := Tagger{ - TxnManager: &mocks.TxnManager{}, + TxnManager: db, } - err := tagger.StudioGalleries(testCtx, &studio, nil, aliases, mockGalleryReader) + err := tagger.StudioGalleries(testCtx, &studio, nil, aliases, db.Gallery) assert := assert.New(t) assert.Nil(err) - mockGalleryReader.AssertExpectations(t) + db.Gallery.AssertExpectations(t) } diff --git a/internal/autotag/tag_test.go b/internal/autotag/tag_test.go index 4b183200490..cb2ae907b25 100644 --- a/internal/autotag/tag_test.go +++ b/internal/autotag/tag_test.go @@ -83,7 +83,7 @@ func testTagScenes(t *testing.T, tc testTagCase) { aliasName := tc.aliasName aliasRegex := tc.aliasRegex - mockSceneReader := &mocks.SceneReaderWriter{} + db := mocks.NewDatabase() const tagID = 2 @@ -131,7 +131,7 @@ func testTagScenes(t *testing.T, tc testTagCase) { } // if alias provided, then don't find by name - onNameQuery := mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)) + onNameQuery := db.Scene.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)) if aliasName == "" { onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() } else { @@ -145,7 +145,7 @@ func testTagScenes(t *testing.T, tc testTagCase) { }, } - mockSceneReader.On("Query", mock.Anything, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). + db.Scene.On("Query", mock.Anything, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() } @@ -162,19 +162,19 @@ func testTagScenes(t *testing.T, tc testTagCase) { return scenePartialsEqual(got, expected) }) - mockSceneReader.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once() + db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once() } tagger := Tagger{ - TxnManager: &mocks.TxnManager{}, + TxnManager: db, } - err := tagger.TagScenes(testCtx, &tag, nil, aliases, mockSceneReader) + err := tagger.TagScenes(testCtx, &tag, nil, aliases, db.Scene) assert := assert.New(t) assert.Nil(err) - mockSceneReader.AssertExpectations(t) + db.Scene.AssertExpectations(t) } func TestTagImages(t *testing.T) { @@ -191,7 +191,7 @@ func testTagImages(t *testing.T, tc testTagCase) { aliasName := tc.aliasName aliasRegex := tc.aliasRegex - mockImageReader := &mocks.ImageReaderWriter{} + db := mocks.NewDatabase() const tagID = 2 @@ -238,7 +238,7 @@ func testTagImages(t *testing.T, tc testTagCase) { } // if alias provided, then don't find by name - onNameQuery := mockImageReader.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)) + onNameQuery := db.Image.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)) if aliasName == "" { onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once() } else { @@ -252,7 +252,7 @@ func testTagImages(t *testing.T, tc testTagCase) { }, } - mockImageReader.On("Query", mock.Anything, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). + db.Image.On("Query", mock.Anything, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). Return(mocks.ImageQueryResult(images, len(images)), nil).Once() } @@ -269,19 +269,19 @@ func testTagImages(t *testing.T, tc testTagCase) { return imagePartialsEqual(got, expected) }) - mockImageReader.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once() + db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once() } tagger := Tagger{ - TxnManager: &mocks.TxnManager{}, + TxnManager: db, } - err := tagger.TagImages(testCtx, &tag, nil, aliases, mockImageReader) + err := tagger.TagImages(testCtx, &tag, nil, aliases, db.Image) assert := assert.New(t) assert.Nil(err) - mockImageReader.AssertExpectations(t) + db.Image.AssertExpectations(t) } func TestTagGalleries(t *testing.T) { @@ -298,7 +298,7 @@ func testTagGalleries(t *testing.T, tc testTagCase) { aliasName := tc.aliasName aliasRegex := tc.aliasRegex - mockGalleryReader := &mocks.GalleryReaderWriter{} + db := mocks.NewDatabase() const tagID = 2 @@ -346,7 +346,7 @@ func testTagGalleries(t *testing.T, tc testTagCase) { } // if alias provided, then don't find by name - onNameQuery := mockGalleryReader.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter) + onNameQuery := db.Gallery.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter) if aliasName == "" { onNameQuery.Return(galleries, len(galleries), nil).Once() } else { @@ -360,7 +360,7 @@ func testTagGalleries(t *testing.T, tc testTagCase) { }, } - mockGalleryReader.On("Query", mock.Anything, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() + db.Gallery.On("Query", mock.Anything, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() } for i := range matchingPaths { @@ -376,18 +376,18 @@ func testTagGalleries(t *testing.T, tc testTagCase) { return galleryPartialsEqual(got, expected) }) - mockGalleryReader.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once() + db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once() } tagger := Tagger{ - TxnManager: &mocks.TxnManager{}, + TxnManager: db, } - err := tagger.TagGalleries(testCtx, &tag, nil, aliases, mockGalleryReader) + err := tagger.TagGalleries(testCtx, &tag, nil, aliases, db.Gallery) assert := assert.New(t) assert.Nil(err) - mockGalleryReader.AssertExpectations(t) + db.Gallery.AssertExpectations(t) } diff --git a/internal/identify/identify_test.go b/internal/identify/identify_test.go index b8472c1186d..2c3232091e2 100644 --- a/internal/identify/identify_test.go +++ b/internal/identify/identify_test.go @@ -108,17 +108,17 @@ func TestSceneIdentifier_Identify(t *testing.T) { }, } - mockSceneReaderWriter := &mocks.SceneReaderWriter{} - mockSceneReaderWriter.On("GetURLs", mock.Anything, mock.Anything).Return(nil, nil) - mockSceneReaderWriter.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool { + db := mocks.NewDatabase() + + db.Scene.On("GetURLs", mock.Anything, mock.Anything).Return(nil, nil) + db.Scene.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool { return id == errUpdateID }), mock.Anything).Return(nil, errors.New("update error")) - mockSceneReaderWriter.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool { + db.Scene.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool { return id != errUpdateID }), mock.Anything).Return(nil, nil) - mockTagFinderCreator := &mocks.TagReaderWriter{} - mockTagFinderCreator.On("Find", mock.Anything, skipMultipleTagID).Return(&models.Tag{ + db.Tag.On("Find", mock.Anything, skipMultipleTagID).Return(&models.Tag{ ID: skipMultipleTagID, Name: skipMultipleTagIDStr, }, nil) @@ -185,8 +185,11 @@ func TestSceneIdentifier_Identify(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { identifier := SceneIdentifier{ - SceneReaderUpdater: mockSceneReaderWriter, - TagFinderCreator: mockTagFinderCreator, + TxnManager: db, + SceneReaderUpdater: db.Scene, + StudioReaderWriter: db.Studio, + PerformerCreator: db.Performer, + TagFinderCreator: db.Tag, DefaultOptions: defaultOptions, Sources: sources, SceneUpdatePostHookExecutor: mockHookExecutor{}, @@ -210,6 +213,8 @@ func TestSceneIdentifier_Identify(t *testing.T) { } func TestSceneIdentifier_modifyScene(t *testing.T) { + db := mocks.NewDatabase() + boolFalse := false defaultOptions := &MetadataOptions{ SetOrganized: &boolFalse, @@ -218,8 +223,12 @@ func TestSceneIdentifier_modifyScene(t *testing.T) { SkipSingleNamePerformers: &boolFalse, } tr := &SceneIdentifier{ - TxnManager: &mocks.TxnManager{}, - DefaultOptions: defaultOptions, + TxnManager: db, + SceneReaderUpdater: db.Scene, + StudioReaderWriter: db.Studio, + PerformerCreator: db.Performer, + TagFinderCreator: db.Tag, + DefaultOptions: defaultOptions, } type args struct { diff --git a/internal/identify/performer_test.go b/internal/identify/performer_test.go index 6903b86ab21..09690959de0 100644 --- a/internal/identify/performer_test.go +++ b/internal/identify/performer_test.go @@ -22,8 +22,9 @@ func Test_getPerformerID(t *testing.T) { remoteSiteID := "2" name := "name" - mockPerformerReaderWriter := mocks.PerformerReaderWriter{} - mockPerformerReaderWriter.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) { + db := mocks.NewDatabase() + + db.Performer.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) { p := args.Get(1).(*models.Performer) p.ID = validStoredID }).Return(nil) @@ -131,7 +132,7 @@ func Test_getPerformerID(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := getPerformerID(testCtx, tt.args.endpoint, &mockPerformerReaderWriter, tt.args.p, tt.args.createMissing, tt.args.skipSingleName) + got, err := getPerformerID(testCtx, tt.args.endpoint, db.Performer, tt.args.p, tt.args.createMissing, tt.args.skipSingleName) if (err != nil) != tt.wantErr { t.Errorf("getPerformerID() error = %v, wantErr %v", err, tt.wantErr) return @@ -151,15 +152,16 @@ func Test_createMissingPerformer(t *testing.T) { invalidName := "invalidName" performerID := 1 - mockPerformerReaderWriter := mocks.PerformerReaderWriter{} - mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool { + db := mocks.NewDatabase() + + db.Performer.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool { return p.Name == validName })).Run(func(args mock.Arguments) { p := args.Get(1).(*models.Performer) p.ID = performerID }).Return(nil) - mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool { + db.Performer.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool { return p.Name == invalidName })).Return(errors.New("error creating performer")) @@ -212,7 +214,7 @@ func Test_createMissingPerformer(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := createMissingPerformer(testCtx, tt.args.endpoint, &mockPerformerReaderWriter, tt.args.p) + got, err := createMissingPerformer(testCtx, tt.args.endpoint, db.Performer, tt.args.p) if (err != nil) != tt.wantErr { t.Errorf("createMissingPerformer() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/internal/identify/scene_test.go b/internal/identify/scene_test.go index bb0598b060a..272ca43cb1d 100644 --- a/internal/identify/scene_test.go +++ b/internal/identify/scene_test.go @@ -24,14 +24,15 @@ func Test_sceneRelationships_studio(t *testing.T) { Strategy: FieldStrategyMerge, } - mockStudioReaderWriter := &mocks.StudioReaderWriter{} - mockStudioReaderWriter.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) { + db := mocks.NewDatabase() + + db.Studio.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) { s := args.Get(1).(*models.Studio) s.ID = validStoredIDInt }).Return(nil) tr := sceneRelationships{ - studioReaderWriter: mockStudioReaderWriter, + studioReaderWriter: db.Studio, fieldOptions: make(map[string]*FieldOptions), } @@ -174,8 +175,10 @@ func Test_sceneRelationships_performers(t *testing.T) { }), } + db := mocks.NewDatabase() + tr := sceneRelationships{ - sceneReader: &mocks.SceneReaderWriter{}, + sceneReader: db.Scene, fieldOptions: make(map[string]*FieldOptions), } @@ -363,22 +366,21 @@ func Test_sceneRelationships_tags(t *testing.T) { StashIDs: models.NewRelatedStashIDs([]models.StashID{}), } - mockSceneReaderWriter := &mocks.SceneReaderWriter{} - mockTagReaderWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() - mockTagReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool { + db.Tag.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool { return p.Name == validName })).Run(func(args mock.Arguments) { t := args.Get(1).(*models.Tag) t.ID = validStoredIDInt }).Return(nil) - mockTagReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool { + db.Tag.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool { return p.Name == invalidName })).Return(errors.New("error creating tag")) tr := sceneRelationships{ - sceneReader: mockSceneReaderWriter, - tagCreator: mockTagReaderWriter, + sceneReader: db.Scene, + tagCreator: db.Tag, fieldOptions: make(map[string]*FieldOptions), } @@ -552,10 +554,10 @@ func Test_sceneRelationships_stashIDs(t *testing.T) { }), } - mockSceneReaderWriter := &mocks.SceneReaderWriter{} + db := mocks.NewDatabase() tr := sceneRelationships{ - sceneReader: mockSceneReaderWriter, + sceneReader: db.Scene, fieldOptions: make(map[string]*FieldOptions), } @@ -706,12 +708,13 @@ func Test_sceneRelationships_cover(t *testing.T) { newDataEncoded := base64Prefix + utils.GetBase64StringFromData(newData) invalidData := newDataEncoded + "!!!" - mockSceneReaderWriter := &mocks.SceneReaderWriter{} - mockSceneReaderWriter.On("GetCover", testCtx, sceneID).Return(existingData, nil) - mockSceneReaderWriter.On("GetCover", testCtx, errSceneID).Return(nil, errors.New("error getting cover")) + db := mocks.NewDatabase() + + db.Scene.On("GetCover", testCtx, sceneID).Return(existingData, nil) + db.Scene.On("GetCover", testCtx, errSceneID).Return(nil, errors.New("error getting cover")) tr := sceneRelationships{ - sceneReader: mockSceneReaderWriter, + sceneReader: db.Scene, fieldOptions: make(map[string]*FieldOptions), } diff --git a/internal/identify/studio_test.go b/internal/identify/studio_test.go index 458cf6da67d..5424a6a93c1 100644 --- a/internal/identify/studio_test.go +++ b/internal/identify/studio_test.go @@ -19,18 +19,19 @@ func Test_createMissingStudio(t *testing.T) { invalidName := "invalidName" createdID := 1 - mockStudioReaderWriter := &mocks.StudioReaderWriter{} - mockStudioReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool { + db := mocks.NewDatabase() + + db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool { return p.Name == validName })).Run(func(args mock.Arguments) { s := args.Get(1).(*models.Studio) s.ID = createdID }).Return(nil) - mockStudioReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool { + db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool { return p.Name == invalidName })).Return(errors.New("error creating studio")) - mockStudioReaderWriter.On("UpdatePartial", testCtx, models.StudioPartial{ + db.Studio.On("UpdatePartial", testCtx, models.StudioPartial{ ID: createdID, StashIDs: &models.UpdateStashIDs{ StashIDs: []models.StashID{ @@ -42,7 +43,7 @@ func Test_createMissingStudio(t *testing.T) { Mode: models.RelationshipUpdateModeSet, }, }).Return(nil, errors.New("error updating stash ids")) - mockStudioReaderWriter.On("UpdatePartial", testCtx, models.StudioPartial{ + db.Studio.On("UpdatePartial", testCtx, models.StudioPartial{ ID: createdID, StashIDs: &models.UpdateStashIDs{ StashIDs: []models.StashID{ @@ -106,7 +107,7 @@ func Test_createMissingStudio(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := createMissingStudio(testCtx, tt.args.endpoint, mockStudioReaderWriter, tt.args.studio) + got, err := createMissingStudio(testCtx, tt.args.endpoint, db.Studio, tt.args.studio) if (err != nil) != tt.wantErr { t.Errorf("createMissingStudio() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/pkg/gallery/export_test.go b/pkg/gallery/export_test.go index 3a6ffa2ec55..db0c0691043 100644 --- a/pkg/gallery/export_test.go +++ b/pkg/gallery/export_test.go @@ -157,19 +157,19 @@ var getStudioScenarios = []stringTestScenario{ } func TestGetStudioName(t *testing.T) { - mockStudioReader := &mocks.StudioReaderWriter{} + db := mocks.NewDatabase() studioErr := errors.New("error getting image") - mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{ + db.Studio.On("Find", testCtx, studioID).Return(&models.Studio{ Name: studioName, }, nil).Once() - mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once() - mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once() + db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil).Once() + db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once() for i, s := range getStudioScenarios { gallery := s.input - json, err := GetStudioName(testCtx, mockStudioReader, &gallery) + json, err := GetStudioName(testCtx, db.Studio, &gallery) switch { case !s.err && err != nil: @@ -181,7 +181,7 @@ func TestGetStudioName(t *testing.T) { } } - mockStudioReader.AssertExpectations(t) + db.Studio.AssertExpectations(t) } const ( @@ -258,17 +258,17 @@ var validChapters = []*models.GalleryChapter{ } func TestGetGalleryChaptersJSON(t *testing.T) { - mockChapterReader := &mocks.GalleryChapterReaderWriter{} + db := mocks.NewDatabase() chaptersErr := errors.New("error getting gallery chapters") - mockChapterReader.On("FindByGalleryID", testCtx, galleryID).Return(validChapters, nil).Once() - mockChapterReader.On("FindByGalleryID", testCtx, noChaptersID).Return(nil, nil).Once() - mockChapterReader.On("FindByGalleryID", testCtx, errChaptersID).Return(nil, chaptersErr).Once() + db.GalleryChapter.On("FindByGalleryID", testCtx, galleryID).Return(validChapters, nil).Once() + db.GalleryChapter.On("FindByGalleryID", testCtx, noChaptersID).Return(nil, nil).Once() + db.GalleryChapter.On("FindByGalleryID", testCtx, errChaptersID).Return(nil, chaptersErr).Once() for i, s := range getGalleryChaptersJSONScenarios { gallery := s.input - json, err := GetGalleryChaptersJSON(testCtx, mockChapterReader, &gallery) + json, err := GetGalleryChaptersJSON(testCtx, db.GalleryChapter, &gallery) switch { case !s.err && err != nil: @@ -279,5 +279,4 @@ func TestGetGalleryChaptersJSON(t *testing.T) { assert.Equal(t, s.expected, json, "[%d]", i) } } - } diff --git a/pkg/gallery/import_test.go b/pkg/gallery/import_test.go index 0997b4a57e2..e165fe3afb1 100644 --- a/pkg/gallery/import_test.go +++ b/pkg/gallery/import_test.go @@ -78,19 +78,19 @@ func TestImporterPreImport(t *testing.T) { } func TestImporterPreImportWithStudio(t *testing.T) { - studioReaderWriter := &mocks.StudioReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - StudioWriter: studioReaderWriter, + StudioWriter: db.Studio, Input: jsonschema.Gallery{ Studio: existingStudioName, }, } - studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{ + db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{ ID: existingStudioID, }, nil).Once() - studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() + db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() err := i.PreImport(testCtx) assert.Nil(t, err) @@ -100,22 +100,22 @@ func TestImporterPreImportWithStudio(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - studioReaderWriter.AssertExpectations(t) + db.Studio.AssertExpectations(t) } func TestImporterPreImportWithMissingStudio(t *testing.T) { - studioReaderWriter := &mocks.StudioReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - StudioWriter: studioReaderWriter, + StudioWriter: db.Studio, Input: jsonschema.Gallery{ Studio: missingStudioName, }, MissingRefBehaviour: models.ImportMissingRefEnumFail, } - studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) - studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { + db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { s := args.Get(1).(*models.Studio) s.ID = existingStudioID }).Return(nil) @@ -132,32 +132,32 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { assert.Nil(t, err) assert.Equal(t, existingStudioID, *i.gallery.StudioID) - studioReaderWriter.AssertExpectations(t) + db.Studio.AssertExpectations(t) } func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { - studioReaderWriter := &mocks.StudioReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - StudioWriter: studioReaderWriter, + StudioWriter: db.Studio, Input: jsonschema.Gallery{ Studio: missingStudioName, }, MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() - studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) + db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) } func TestImporterPreImportWithPerformer(t *testing.T) { - performerReaderWriter := &mocks.PerformerReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - PerformerWriter: performerReaderWriter, + PerformerWriter: db.Performer, MissingRefBehaviour: models.ImportMissingRefEnumFail, Input: jsonschema.Gallery{ Performers: []string{ @@ -166,13 +166,13 @@ func TestImporterPreImportWithPerformer(t *testing.T) { }, } - performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{ + db.Performer.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{ { ID: existingPerformerID, Name: existingPerformerName, }, }, nil).Once() - performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() + db.Performer.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() err := i.PreImport(testCtx) assert.Nil(t, err) @@ -182,14 +182,14 @@ func TestImporterPreImportWithPerformer(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - performerReaderWriter.AssertExpectations(t) + db.Performer.AssertExpectations(t) } func TestImporterPreImportWithMissingPerformer(t *testing.T) { - performerReaderWriter := &mocks.PerformerReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - PerformerWriter: performerReaderWriter, + PerformerWriter: db.Performer, Input: jsonschema.Gallery{ Performers: []string{ missingPerformerName, @@ -198,8 +198,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) - performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) { + db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) + db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) { performer := args.Get(1).(*models.Performer) performer.ID = existingPerformerID }).Return(nil) @@ -216,14 +216,14 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { assert.Nil(t, err) assert.Equal(t, []int{existingPerformerID}, i.gallery.PerformerIDs.List()) - performerReaderWriter.AssertExpectations(t) + db.Performer.AssertExpectations(t) } func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { - performerReaderWriter := &mocks.PerformerReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - PerformerWriter: performerReaderWriter, + PerformerWriter: db.Performer, Input: jsonschema.Gallery{ Performers: []string{ missingPerformerName, @@ -232,18 +232,18 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() - performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error")) + db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() + db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) } func TestImporterPreImportWithTag(t *testing.T) { - tagReaderWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - TagWriter: tagReaderWriter, + TagWriter: db.Tag, MissingRefBehaviour: models.ImportMissingRefEnumFail, Input: jsonschema.Gallery{ Tags: []string{ @@ -252,13 +252,13 @@ func TestImporterPreImportWithTag(t *testing.T) { }, } - tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{ + db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{ { ID: existingTagID, Name: existingTagName, }, }, nil).Once() - tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() + db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() err := i.PreImport(testCtx) assert.Nil(t, err) @@ -268,14 +268,14 @@ func TestImporterPreImportWithTag(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - tagReaderWriter.AssertExpectations(t) + db.Tag.AssertExpectations(t) } func TestImporterPreImportWithMissingTag(t *testing.T) { - tagReaderWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - TagWriter: tagReaderWriter, + TagWriter: db.Tag, Input: jsonschema.Gallery{ Tags: []string{ missingTagName, @@ -284,8 +284,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) - tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { + db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { t := args.Get(1).(*models.Tag) t.ID = existingTagID }).Return(nil) @@ -302,14 +302,14 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { assert.Nil(t, err) assert.Equal(t, []int{existingTagID}, i.gallery.TagIDs.List()) - tagReaderWriter.AssertExpectations(t) + db.Tag.AssertExpectations(t) } func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { - tagReaderWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - TagWriter: tagReaderWriter, + TagWriter: db.Tag, Input: jsonschema.Gallery{ Tags: []string{ missingTagName, @@ -318,8 +318,8 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() - tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) + db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/image/export_test.go b/pkg/image/export_test.go index 1a5897271ef..b228a371bc8 100644 --- a/pkg/image/export_test.go +++ b/pkg/image/export_test.go @@ -130,19 +130,19 @@ var getStudioScenarios = []stringTestScenario{ } func TestGetStudioName(t *testing.T) { - mockStudioReader := &mocks.StudioReaderWriter{} + db := mocks.NewDatabase() studioErr := errors.New("error getting image") - mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{ + db.Studio.On("Find", testCtx, studioID).Return(&models.Studio{ Name: studioName, }, nil).Once() - mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once() - mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once() + db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil).Once() + db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once() for i, s := range getStudioScenarios { image := s.input - json, err := GetStudioName(testCtx, mockStudioReader, &image) + json, err := GetStudioName(testCtx, db.Studio, &image) switch { case !s.err && err != nil: @@ -154,5 +154,5 @@ func TestGetStudioName(t *testing.T) { } } - mockStudioReader.AssertExpectations(t) + db.Studio.AssertExpectations(t) } diff --git a/pkg/image/import_test.go b/pkg/image/import_test.go index 3ab586359e8..ea1d899eeb4 100644 --- a/pkg/image/import_test.go +++ b/pkg/image/import_test.go @@ -40,19 +40,19 @@ func TestImporterPreImport(t *testing.T) { } func TestImporterPreImportWithStudio(t *testing.T) { - studioReaderWriter := &mocks.StudioReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - StudioWriter: studioReaderWriter, + StudioWriter: db.Studio, Input: jsonschema.Image{ Studio: existingStudioName, }, } - studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{ + db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{ ID: existingStudioID, }, nil).Once() - studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() + db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() err := i.PreImport(testCtx) assert.Nil(t, err) @@ -62,22 +62,22 @@ func TestImporterPreImportWithStudio(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - studioReaderWriter.AssertExpectations(t) + db.Studio.AssertExpectations(t) } func TestImporterPreImportWithMissingStudio(t *testing.T) { - studioReaderWriter := &mocks.StudioReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - StudioWriter: studioReaderWriter, + StudioWriter: db.Studio, Input: jsonschema.Image{ Studio: missingStudioName, }, MissingRefBehaviour: models.ImportMissingRefEnumFail, } - studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) - studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { + db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { s := args.Get(1).(*models.Studio) s.ID = existingStudioID }).Return(nil) @@ -94,32 +94,32 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { assert.Nil(t, err) assert.Equal(t, existingStudioID, *i.image.StudioID) - studioReaderWriter.AssertExpectations(t) + db.Studio.AssertExpectations(t) } func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { - studioReaderWriter := &mocks.StudioReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - StudioWriter: studioReaderWriter, + StudioWriter: db.Studio, Input: jsonschema.Image{ Studio: missingStudioName, }, MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() - studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) + db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) } func TestImporterPreImportWithPerformer(t *testing.T) { - performerReaderWriter := &mocks.PerformerReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - PerformerWriter: performerReaderWriter, + PerformerWriter: db.Performer, MissingRefBehaviour: models.ImportMissingRefEnumFail, Input: jsonschema.Image{ Performers: []string{ @@ -128,13 +128,13 @@ func TestImporterPreImportWithPerformer(t *testing.T) { }, } - performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{ + db.Performer.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{ { ID: existingPerformerID, Name: existingPerformerName, }, }, nil).Once() - performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() + db.Performer.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() err := i.PreImport(testCtx) assert.Nil(t, err) @@ -144,14 +144,14 @@ func TestImporterPreImportWithPerformer(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - performerReaderWriter.AssertExpectations(t) + db.Performer.AssertExpectations(t) } func TestImporterPreImportWithMissingPerformer(t *testing.T) { - performerReaderWriter := &mocks.PerformerReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - PerformerWriter: performerReaderWriter, + PerformerWriter: db.Performer, Input: jsonschema.Image{ Performers: []string{ missingPerformerName, @@ -160,8 +160,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) - performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) { + db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) + db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) { performer := args.Get(1).(*models.Performer) performer.ID = existingPerformerID }).Return(nil) @@ -178,14 +178,14 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { assert.Nil(t, err) assert.Equal(t, []int{existingPerformerID}, i.image.PerformerIDs.List()) - performerReaderWriter.AssertExpectations(t) + db.Performer.AssertExpectations(t) } func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { - performerReaderWriter := &mocks.PerformerReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - PerformerWriter: performerReaderWriter, + PerformerWriter: db.Performer, Input: jsonschema.Image{ Performers: []string{ missingPerformerName, @@ -194,18 +194,18 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() - performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error")) + db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() + db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) } func TestImporterPreImportWithTag(t *testing.T) { - tagReaderWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - TagWriter: tagReaderWriter, + TagWriter: db.Tag, MissingRefBehaviour: models.ImportMissingRefEnumFail, Input: jsonschema.Image{ Tags: []string{ @@ -214,13 +214,13 @@ func TestImporterPreImportWithTag(t *testing.T) { }, } - tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{ + db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{ { ID: existingTagID, Name: existingTagName, }, }, nil).Once() - tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() + db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() err := i.PreImport(testCtx) assert.Nil(t, err) @@ -230,14 +230,14 @@ func TestImporterPreImportWithTag(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - tagReaderWriter.AssertExpectations(t) + db.Tag.AssertExpectations(t) } func TestImporterPreImportWithMissingTag(t *testing.T) { - tagReaderWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - TagWriter: tagReaderWriter, + TagWriter: db.Tag, Input: jsonschema.Image{ Tags: []string{ missingTagName, @@ -246,8 +246,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) - tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { + db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { t := args.Get(1).(*models.Tag) t.ID = existingTagID }).Return(nil) @@ -264,14 +264,14 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { assert.Nil(t, err) assert.Equal(t, []int{existingTagID}, i.image.TagIDs.List()) - tagReaderWriter.AssertExpectations(t) + db.Tag.AssertExpectations(t) } func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { - tagReaderWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - TagWriter: tagReaderWriter, + TagWriter: db.Tag, Input: jsonschema.Image{ Tags: []string{ missingTagName, @@ -280,8 +280,8 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() - tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) + db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/models/mocks/database.go b/pkg/models/mocks/database.go new file mode 100644 index 00000000000..c3b6cd62631 --- /dev/null +++ b/pkg/models/mocks/database.go @@ -0,0 +1,91 @@ +package mocks + +import ( + "context" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/txn" +) + +type Database struct { + File *FileReaderWriter + Folder *FolderReaderWriter + Gallery *GalleryReaderWriter + GalleryChapter *GalleryChapterReaderWriter + Image *ImageReaderWriter + Movie *MovieReaderWriter + Performer *PerformerReaderWriter + Scene *SceneReaderWriter + SceneMarker *SceneMarkerReaderWriter + Studio *StudioReaderWriter + Tag *TagReaderWriter + SavedFilter *SavedFilterReaderWriter +} + +func (*Database) Begin(ctx context.Context, exclusive bool) (context.Context, error) { + return ctx, nil +} + +func (*Database) WithDatabase(ctx context.Context) (context.Context, error) { + return ctx, nil +} + +func (*Database) Commit(ctx context.Context) error { + return nil +} + +func (*Database) Rollback(ctx context.Context) error { + return nil +} + +func (*Database) Complete(ctx context.Context) { +} + +func (*Database) AddPostCommitHook(ctx context.Context, hook txn.TxnFunc) { +} + +func (*Database) AddPostRollbackHook(ctx context.Context, hook txn.TxnFunc) { +} + +func (*Database) IsLocked(err error) bool { + return false +} + +func (*Database) Reset() error { + return nil +} + +func NewDatabase() *Database { + return &Database{ + File: &FileReaderWriter{}, + Folder: &FolderReaderWriter{}, + Gallery: &GalleryReaderWriter{}, + GalleryChapter: &GalleryChapterReaderWriter{}, + Image: &ImageReaderWriter{}, + Movie: &MovieReaderWriter{}, + Performer: &PerformerReaderWriter{}, + Scene: &SceneReaderWriter{}, + SceneMarker: &SceneMarkerReaderWriter{}, + Studio: &StudioReaderWriter{}, + Tag: &TagReaderWriter{}, + SavedFilter: &SavedFilterReaderWriter{}, + } +} + +func (db *Database) Repository() models.Repository { + return models.Repository{ + TxnManager: db, + File: db.File, + Folder: db.Folder, + Gallery: db.Gallery, + GalleryChapter: db.GalleryChapter, + Image: db.Image, + Movie: db.Movie, + Performer: db.Performer, + Scene: db.Scene, + SceneMarker: db.SceneMarker, + Studio: db.Studio, + Tag: db.Tag, + SavedFilter: db.SavedFilter, + } +} diff --git a/pkg/models/mocks/transaction.go b/pkg/models/mocks/transaction.go deleted file mode 100644 index e7a0163d4de..00000000000 --- a/pkg/models/mocks/transaction.go +++ /dev/null @@ -1,59 +0,0 @@ -package mocks - -import ( - context "context" - - "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/txn" -) - -type TxnManager struct{} - -func (*TxnManager) Begin(ctx context.Context, exclusive bool) (context.Context, error) { - return ctx, nil -} - -func (*TxnManager) WithDatabase(ctx context.Context) (context.Context, error) { - return ctx, nil -} - -func (*TxnManager) Commit(ctx context.Context) error { - return nil -} - -func (*TxnManager) Rollback(ctx context.Context) error { - return nil -} - -func (*TxnManager) Complete(ctx context.Context) { -} - -func (*TxnManager) AddPostCommitHook(ctx context.Context, hook txn.TxnFunc) { -} - -func (*TxnManager) AddPostRollbackHook(ctx context.Context, hook txn.TxnFunc) { -} - -func (*TxnManager) IsLocked(err error) bool { - return false -} - -func (*TxnManager) Reset() error { - return nil -} - -func NewTxnRepository() models.Repository { - return models.Repository{ - TxnManager: &TxnManager{}, - Gallery: &GalleryReaderWriter{}, - GalleryChapter: &GalleryChapterReaderWriter{}, - Image: &ImageReaderWriter{}, - Movie: &MovieReaderWriter{}, - Performer: &PerformerReaderWriter{}, - Scene: &SceneReaderWriter{}, - SceneMarker: &SceneMarkerReaderWriter{}, - Studio: &StudioReaderWriter{}, - Tag: &TagReaderWriter{}, - SavedFilter: &SavedFilterReaderWriter{}, - } -} diff --git a/pkg/movie/export_test.go b/pkg/movie/export_test.go index 2f037a758da..d369caaa789 100644 --- a/pkg/movie/export_test.go +++ b/pkg/movie/export_test.go @@ -168,34 +168,32 @@ func initTestTable() { func TestToJSON(t *testing.T) { initTestTable() - mockMovieReader := &mocks.MovieReaderWriter{} + db := mocks.NewDatabase() imageErr := errors.New("error getting image") - mockMovieReader.On("GetFrontImage", testCtx, movieID).Return(frontImageBytes, nil).Once() - mockMovieReader.On("GetFrontImage", testCtx, missingStudioMovieID).Return(frontImageBytes, nil).Once() - mockMovieReader.On("GetFrontImage", testCtx, emptyID).Return(nil, nil).Once().Maybe() - mockMovieReader.On("GetFrontImage", testCtx, errFrontImageID).Return(nil, imageErr).Once() - mockMovieReader.On("GetFrontImage", testCtx, errBackImageID).Return(frontImageBytes, nil).Once() + db.Movie.On("GetFrontImage", testCtx, movieID).Return(frontImageBytes, nil).Once() + db.Movie.On("GetFrontImage", testCtx, missingStudioMovieID).Return(frontImageBytes, nil).Once() + db.Movie.On("GetFrontImage", testCtx, emptyID).Return(nil, nil).Once().Maybe() + db.Movie.On("GetFrontImage", testCtx, errFrontImageID).Return(nil, imageErr).Once() + db.Movie.On("GetFrontImage", testCtx, errBackImageID).Return(frontImageBytes, nil).Once() - mockMovieReader.On("GetBackImage", testCtx, movieID).Return(backImageBytes, nil).Once() - mockMovieReader.On("GetBackImage", testCtx, missingStudioMovieID).Return(backImageBytes, nil).Once() - mockMovieReader.On("GetBackImage", testCtx, emptyID).Return(nil, nil).Once() - mockMovieReader.On("GetBackImage", testCtx, errBackImageID).Return(nil, imageErr).Once() - mockMovieReader.On("GetBackImage", testCtx, errFrontImageID).Return(backImageBytes, nil).Maybe() - mockMovieReader.On("GetBackImage", testCtx, errStudioMovieID).Return(backImageBytes, nil).Maybe() - - mockStudioReader := &mocks.StudioReaderWriter{} + db.Movie.On("GetBackImage", testCtx, movieID).Return(backImageBytes, nil).Once() + db.Movie.On("GetBackImage", testCtx, missingStudioMovieID).Return(backImageBytes, nil).Once() + db.Movie.On("GetBackImage", testCtx, emptyID).Return(nil, nil).Once() + db.Movie.On("GetBackImage", testCtx, errBackImageID).Return(nil, imageErr).Once() + db.Movie.On("GetBackImage", testCtx, errFrontImageID).Return(backImageBytes, nil).Maybe() + db.Movie.On("GetBackImage", testCtx, errStudioMovieID).Return(backImageBytes, nil).Maybe() studioErr := errors.New("error getting studio") - mockStudioReader.On("Find", testCtx, studioID).Return(&movieStudio, nil) - mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil) - mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr) + db.Studio.On("Find", testCtx, studioID).Return(&movieStudio, nil) + db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil) + db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr) for i, s := range scenarios { movie := s.movie - json, err := ToJSON(testCtx, mockMovieReader, mockStudioReader, &movie) + json, err := ToJSON(testCtx, db.Movie, db.Studio, &movie) switch { case !s.err && err != nil: @@ -207,6 +205,6 @@ func TestToJSON(t *testing.T) { } } - mockMovieReader.AssertExpectations(t) - mockStudioReader.AssertExpectations(t) + db.Movie.AssertExpectations(t) + db.Studio.AssertExpectations(t) } diff --git a/pkg/movie/import_test.go b/pkg/movie/import_test.go index e4bca5a969e..c4957545da3 100644 --- a/pkg/movie/import_test.go +++ b/pkg/movie/import_test.go @@ -69,10 +69,11 @@ func TestImporterPreImport(t *testing.T) { } func TestImporterPreImportWithStudio(t *testing.T) { - studioReaderWriter := &mocks.StudioReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - StudioWriter: studioReaderWriter, + ReaderWriter: db.Movie, + StudioWriter: db.Studio, Input: jsonschema.Movie{ Name: movieName, FrontImage: frontImage, @@ -82,10 +83,10 @@ func TestImporterPreImportWithStudio(t *testing.T) { }, } - studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{ + db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{ ID: existingStudioID, }, nil).Once() - studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() + db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() err := i.PreImport(testCtx) assert.Nil(t, err) @@ -95,14 +96,15 @@ func TestImporterPreImportWithStudio(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - studioReaderWriter.AssertExpectations(t) + db.Studio.AssertExpectations(t) } func TestImporterPreImportWithMissingStudio(t *testing.T) { - studioReaderWriter := &mocks.StudioReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - StudioWriter: studioReaderWriter, + ReaderWriter: db.Movie, + StudioWriter: db.Studio, Input: jsonschema.Movie{ Name: movieName, FrontImage: frontImage, @@ -111,8 +113,8 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) - studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { + db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { s := args.Get(1).(*models.Studio) s.ID = existingStudioID }).Return(nil) @@ -129,14 +131,15 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { assert.Nil(t, err) assert.Equal(t, existingStudioID, *i.movie.StudioID) - studioReaderWriter.AssertExpectations(t) + db.Studio.AssertExpectations(t) } func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { - studioReaderWriter := &mocks.StudioReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - StudioWriter: studioReaderWriter, + ReaderWriter: db.Movie, + StudioWriter: db.Studio, Input: jsonschema.Movie{ Name: movieName, FrontImage: frontImage, @@ -145,27 +148,28 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() - studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) + db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) } func TestImporterPostImport(t *testing.T) { - readerWriter := &mocks.MovieReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Movie, + StudioWriter: db.Studio, frontImageData: frontImageBytes, backImageData: backImageBytes, } updateMovieImageErr := errors.New("UpdateImages error") - readerWriter.On("UpdateFrontImage", testCtx, movieID, frontImageBytes).Return(nil).Once() - readerWriter.On("UpdateBackImage", testCtx, movieID, backImageBytes).Return(nil).Once() - readerWriter.On("UpdateFrontImage", testCtx, errImageID, frontImageBytes).Return(updateMovieImageErr).Once() + db.Movie.On("UpdateFrontImage", testCtx, movieID, frontImageBytes).Return(nil).Once() + db.Movie.On("UpdateBackImage", testCtx, movieID, backImageBytes).Return(nil).Once() + db.Movie.On("UpdateFrontImage", testCtx, errImageID, frontImageBytes).Return(updateMovieImageErr).Once() err := i.PostImport(testCtx, movieID) assert.Nil(t, err) @@ -173,25 +177,26 @@ func TestImporterPostImport(t *testing.T) { err = i.PostImport(testCtx, errImageID) assert.NotNil(t, err) - readerWriter.AssertExpectations(t) + db.Movie.AssertExpectations(t) } func TestImporterFindExistingID(t *testing.T) { - readerWriter := &mocks.MovieReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Movie, + StudioWriter: db.Studio, Input: jsonschema.Movie{ Name: movieName, }, } errFindByName := errors.New("FindByName error") - readerWriter.On("FindByName", testCtx, movieName, false).Return(nil, nil).Once() - readerWriter.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{ + db.Movie.On("FindByName", testCtx, movieName, false).Return(nil, nil).Once() + db.Movie.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{ ID: existingMovieID, }, nil).Once() - readerWriter.On("FindByName", testCtx, movieNameErr, false).Return(nil, errFindByName).Once() + db.Movie.On("FindByName", testCtx, movieNameErr, false).Return(nil, errFindByName).Once() id, err := i.FindExistingID(testCtx) assert.Nil(t, id) @@ -207,11 +212,11 @@ func TestImporterFindExistingID(t *testing.T) { assert.Nil(t, id) assert.NotNil(t, err) - readerWriter.AssertExpectations(t) + db.Movie.AssertExpectations(t) } func TestCreate(t *testing.T) { - readerWriter := &mocks.MovieReaderWriter{} + db := mocks.NewDatabase() movie := models.Movie{ Name: movieName, @@ -222,16 +227,17 @@ func TestCreate(t *testing.T) { } i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Movie, + StudioWriter: db.Studio, movie: movie, } errCreate := errors.New("Create error") - readerWriter.On("Create", testCtx, &movie).Run(func(args mock.Arguments) { + db.Movie.On("Create", testCtx, &movie).Run(func(args mock.Arguments) { m := args.Get(1).(*models.Movie) m.ID = movieID }).Return(nil).Once() - readerWriter.On("Create", testCtx, &movieErr).Return(errCreate).Once() + db.Movie.On("Create", testCtx, &movieErr).Return(errCreate).Once() id, err := i.Create(testCtx) assert.Equal(t, movieID, *id) @@ -242,11 +248,11 @@ func TestCreate(t *testing.T) { assert.Nil(t, id) assert.NotNil(t, err) - readerWriter.AssertExpectations(t) + db.Movie.AssertExpectations(t) } func TestUpdate(t *testing.T) { - readerWriter := &mocks.MovieReaderWriter{} + db := mocks.NewDatabase() movie := models.Movie{ Name: movieName, @@ -257,7 +263,8 @@ func TestUpdate(t *testing.T) { } i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Movie, + StudioWriter: db.Studio, movie: movie, } @@ -265,7 +272,7 @@ func TestUpdate(t *testing.T) { // id needs to be set for the mock input movie.ID = movieID - readerWriter.On("Update", testCtx, &movie).Return(nil).Once() + db.Movie.On("Update", testCtx, &movie).Return(nil).Once() err := i.Update(testCtx, movieID) assert.Nil(t, err) @@ -274,10 +281,10 @@ func TestUpdate(t *testing.T) { // need to set id separately movieErr.ID = errImageID - readerWriter.On("Update", testCtx, &movieErr).Return(errUpdate).Once() + db.Movie.On("Update", testCtx, &movieErr).Return(errUpdate).Once() err = i.Update(testCtx, errImageID) assert.NotNil(t, err) - readerWriter.AssertExpectations(t) + db.Movie.AssertExpectations(t) } diff --git a/pkg/performer/export_test.go b/pkg/performer/export_test.go index d63a9e05eb8..483cf80d5d8 100644 --- a/pkg/performer/export_test.go +++ b/pkg/performer/export_test.go @@ -203,17 +203,17 @@ func initTestTable() { func TestToJSON(t *testing.T) { initTestTable() - mockPerformerReader := &mocks.PerformerReaderWriter{} + db := mocks.NewDatabase() imageErr := errors.New("error getting image") - mockPerformerReader.On("GetImage", testCtx, performerID).Return(imageBytes, nil).Once() - mockPerformerReader.On("GetImage", testCtx, noImageID).Return(nil, nil).Once() - mockPerformerReader.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once() + db.Performer.On("GetImage", testCtx, performerID).Return(imageBytes, nil).Once() + db.Performer.On("GetImage", testCtx, noImageID).Return(nil, nil).Once() + db.Performer.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once() for i, s := range scenarios { tag := s.input - json, err := ToJSON(testCtx, mockPerformerReader, &tag) + json, err := ToJSON(testCtx, db.Performer, &tag) switch { case !s.err && err != nil: @@ -225,5 +225,5 @@ func TestToJSON(t *testing.T) { } } - mockPerformerReader.AssertExpectations(t) + db.Performer.AssertExpectations(t) } diff --git a/pkg/performer/import_test.go b/pkg/performer/import_test.go index cb4bbd25fe1..26960a1a539 100644 --- a/pkg/performer/import_test.go +++ b/pkg/performer/import_test.go @@ -63,10 +63,11 @@ func TestImporterPreImport(t *testing.T) { } func TestImporterPreImportWithTag(t *testing.T) { - tagReaderWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - TagWriter: tagReaderWriter, + ReaderWriter: db.Performer, + TagWriter: db.Tag, MissingRefBehaviour: models.ImportMissingRefEnumFail, Input: jsonschema.Performer{ Tags: []string{ @@ -75,13 +76,13 @@ func TestImporterPreImportWithTag(t *testing.T) { }, } - tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{ + db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{ { ID: existingTagID, Name: existingTagName, }, }, nil).Once() - tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() + db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() err := i.PreImport(testCtx) assert.Nil(t, err) @@ -91,14 +92,15 @@ func TestImporterPreImportWithTag(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - tagReaderWriter.AssertExpectations(t) + db.Tag.AssertExpectations(t) } func TestImporterPreImportWithMissingTag(t *testing.T) { - tagReaderWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - TagWriter: tagReaderWriter, + ReaderWriter: db.Performer, + TagWriter: db.Tag, Input: jsonschema.Performer{ Tags: []string{ missingTagName, @@ -107,8 +109,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) - tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { + db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { t := args.Get(1).(*models.Tag) t.ID = existingTagID }).Return(nil) @@ -125,14 +127,15 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { assert.Nil(t, err) assert.Equal(t, existingTagID, i.performer.TagIDs.List()[0]) - tagReaderWriter.AssertExpectations(t) + db.Tag.AssertExpectations(t) } func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { - tagReaderWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - TagWriter: tagReaderWriter, + ReaderWriter: db.Performer, + TagWriter: db.Tag, Input: jsonschema.Performer{ Tags: []string{ missingTagName, @@ -141,25 +144,26 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() - tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) + db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) } func TestImporterPostImport(t *testing.T) { - readerWriter := &mocks.PerformerReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Performer, + TagWriter: db.Tag, imageData: imageBytes, } updatePerformerImageErr := errors.New("UpdateImage error") - readerWriter.On("UpdateImage", testCtx, performerID, imageBytes).Return(nil).Once() - readerWriter.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updatePerformerImageErr).Once() + db.Performer.On("UpdateImage", testCtx, performerID, imageBytes).Return(nil).Once() + db.Performer.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updatePerformerImageErr).Once() err := i.PostImport(testCtx, performerID) assert.Nil(t, err) @@ -167,14 +171,15 @@ func TestImporterPostImport(t *testing.T) { err = i.PostImport(testCtx, errImageID) assert.NotNil(t, err) - readerWriter.AssertExpectations(t) + db.Performer.AssertExpectations(t) } func TestImporterFindExistingID(t *testing.T) { - readerWriter := &mocks.PerformerReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Performer, + TagWriter: db.Tag, Input: jsonschema.Performer{ Name: performerName, }, @@ -195,13 +200,13 @@ func TestImporterFindExistingID(t *testing.T) { } errFindByNames := errors.New("FindByNames error") - readerWriter.On("Query", testCtx, performerFilter(performerName), findFilter).Return(nil, 0, nil).Once() - readerWriter.On("Query", testCtx, performerFilter(existingPerformerName), findFilter).Return([]*models.Performer{ + db.Performer.On("Query", testCtx, performerFilter(performerName), findFilter).Return(nil, 0, nil).Once() + db.Performer.On("Query", testCtx, performerFilter(existingPerformerName), findFilter).Return([]*models.Performer{ { ID: existingPerformerID, }, }, 1, nil).Once() - readerWriter.On("Query", testCtx, performerFilter(performerNameErr), findFilter).Return(nil, 0, errFindByNames).Once() + db.Performer.On("Query", testCtx, performerFilter(performerNameErr), findFilter).Return(nil, 0, errFindByNames).Once() id, err := i.FindExistingID(testCtx) assert.Nil(t, id) @@ -217,11 +222,11 @@ func TestImporterFindExistingID(t *testing.T) { assert.Nil(t, id) assert.NotNil(t, err) - readerWriter.AssertExpectations(t) + db.Performer.AssertExpectations(t) } func TestCreate(t *testing.T) { - readerWriter := &mocks.PerformerReaderWriter{} + db := mocks.NewDatabase() performer := models.Performer{ Name: performerName, @@ -232,16 +237,17 @@ func TestCreate(t *testing.T) { } i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Performer, + TagWriter: db.Tag, performer: performer, } errCreate := errors.New("Create error") - readerWriter.On("Create", testCtx, &performer).Run(func(args mock.Arguments) { + db.Performer.On("Create", testCtx, &performer).Run(func(args mock.Arguments) { arg := args.Get(1).(*models.Performer) arg.ID = performerID }).Return(nil).Once() - readerWriter.On("Create", testCtx, &performerErr).Return(errCreate).Once() + db.Performer.On("Create", testCtx, &performerErr).Return(errCreate).Once() id, err := i.Create(testCtx) assert.Equal(t, performerID, *id) @@ -252,11 +258,11 @@ func TestCreate(t *testing.T) { assert.Nil(t, id) assert.NotNil(t, err) - readerWriter.AssertExpectations(t) + db.Performer.AssertExpectations(t) } func TestUpdate(t *testing.T) { - readerWriter := &mocks.PerformerReaderWriter{} + db := mocks.NewDatabase() performer := models.Performer{ Name: performerName, @@ -267,7 +273,8 @@ func TestUpdate(t *testing.T) { } i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Performer, + TagWriter: db.Tag, performer: performer, } @@ -275,7 +282,7 @@ func TestUpdate(t *testing.T) { // id needs to be set for the mock input performer.ID = performerID - readerWriter.On("Update", testCtx, &performer).Return(nil).Once() + db.Performer.On("Update", testCtx, &performer).Return(nil).Once() err := i.Update(testCtx, performerID) assert.Nil(t, err) @@ -284,10 +291,10 @@ func TestUpdate(t *testing.T) { // need to set id separately performerErr.ID = errImageID - readerWriter.On("Update", testCtx, &performerErr).Return(errUpdate).Once() + db.Performer.On("Update", testCtx, &performerErr).Return(errUpdate).Once() err = i.Update(testCtx, errImageID) assert.NotNil(t, err) - readerWriter.AssertExpectations(t) + db.Performer.AssertExpectations(t) } diff --git a/pkg/scene/export_test.go b/pkg/scene/export_test.go index 19e12ecea70..0c44734ce54 100644 --- a/pkg/scene/export_test.go +++ b/pkg/scene/export_test.go @@ -186,17 +186,17 @@ var scenarios = []basicTestScenario{ } func TestToJSON(t *testing.T) { - mockSceneReader := &mocks.SceneReaderWriter{} + db := mocks.NewDatabase() imageErr := errors.New("error getting image") - mockSceneReader.On("GetCover", testCtx, sceneID).Return(imageBytes, nil).Once() - mockSceneReader.On("GetCover", testCtx, noImageID).Return(nil, nil).Once() - mockSceneReader.On("GetCover", testCtx, errImageID).Return(nil, imageErr).Once() + db.Scene.On("GetCover", testCtx, sceneID).Return(imageBytes, nil).Once() + db.Scene.On("GetCover", testCtx, noImageID).Return(nil, nil).Once() + db.Scene.On("GetCover", testCtx, errImageID).Return(nil, imageErr).Once() for i, s := range scenarios { scene := s.input - json, err := ToBasicJSON(testCtx, mockSceneReader, &scene) + json, err := ToBasicJSON(testCtx, db.Scene, &scene) switch { case !s.err && err != nil: @@ -208,7 +208,7 @@ func TestToJSON(t *testing.T) { } } - mockSceneReader.AssertExpectations(t) + db.Scene.AssertExpectations(t) } func createStudioScene(studioID int) models.Scene { @@ -242,19 +242,19 @@ var getStudioScenarios = []stringTestScenario{ } func TestGetStudioName(t *testing.T) { - mockStudioReader := &mocks.StudioReaderWriter{} + db := mocks.NewDatabase() studioErr := errors.New("error getting image") - mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{ + db.Studio.On("Find", testCtx, studioID).Return(&models.Studio{ Name: studioName, }, nil).Once() - mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once() - mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once() + db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil).Once() + db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once() for i, s := range getStudioScenarios { scene := s.input - json, err := GetStudioName(testCtx, mockStudioReader, &scene) + json, err := GetStudioName(testCtx, db.Studio, &scene) switch { case !s.err && err != nil: @@ -266,7 +266,7 @@ func TestGetStudioName(t *testing.T) { } } - mockStudioReader.AssertExpectations(t) + db.Studio.AssertExpectations(t) } type stringSliceTestScenario struct { @@ -305,17 +305,17 @@ func getTags(names []string) []*models.Tag { } func TestGetTagNames(t *testing.T) { - mockTagReader := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() tagErr := errors.New("error getting tag") - mockTagReader.On("FindBySceneID", testCtx, sceneID).Return(getTags(names), nil).Once() - mockTagReader.On("FindBySceneID", testCtx, noTagsID).Return(nil, nil).Once() - mockTagReader.On("FindBySceneID", testCtx, errTagsID).Return(nil, tagErr).Once() + db.Tag.On("FindBySceneID", testCtx, sceneID).Return(getTags(names), nil).Once() + db.Tag.On("FindBySceneID", testCtx, noTagsID).Return(nil, nil).Once() + db.Tag.On("FindBySceneID", testCtx, errTagsID).Return(nil, tagErr).Once() for i, s := range getTagNamesScenarios { scene := s.input - json, err := GetTagNames(testCtx, mockTagReader, &scene) + json, err := GetTagNames(testCtx, db.Tag, &scene) switch { case !s.err && err != nil: @@ -327,7 +327,7 @@ func TestGetTagNames(t *testing.T) { } } - mockTagReader.AssertExpectations(t) + db.Tag.AssertExpectations(t) } type sceneMoviesTestScenario struct { @@ -391,20 +391,21 @@ var getSceneMoviesJSONScenarios = []sceneMoviesTestScenario{ } func TestGetSceneMoviesJSON(t *testing.T) { - mockMovieReader := &mocks.MovieReaderWriter{} + db := mocks.NewDatabase() + movieErr := errors.New("error getting movie") - mockMovieReader.On("Find", testCtx, validMovie1).Return(&models.Movie{ + db.Movie.On("Find", testCtx, validMovie1).Return(&models.Movie{ Name: movie1Name, }, nil).Once() - mockMovieReader.On("Find", testCtx, validMovie2).Return(&models.Movie{ + db.Movie.On("Find", testCtx, validMovie2).Return(&models.Movie{ Name: movie2Name, }, nil).Once() - mockMovieReader.On("Find", testCtx, invalidMovie).Return(nil, movieErr).Once() + db.Movie.On("Find", testCtx, invalidMovie).Return(nil, movieErr).Once() for i, s := range getSceneMoviesJSONScenarios { scene := s.input - json, err := GetSceneMoviesJSON(testCtx, mockMovieReader, &scene) + json, err := GetSceneMoviesJSON(testCtx, db.Movie, &scene) switch { case !s.err && err != nil: @@ -416,7 +417,7 @@ func TestGetSceneMoviesJSON(t *testing.T) { } } - mockMovieReader.AssertExpectations(t) + db.Movie.AssertExpectations(t) } const ( @@ -542,27 +543,26 @@ var invalidMarkers2 = []*models.SceneMarker{ } func TestGetSceneMarkersJSON(t *testing.T) { - mockTagReader := &mocks.TagReaderWriter{} - mockMarkerReader := &mocks.SceneMarkerReaderWriter{} + db := mocks.NewDatabase() markersErr := errors.New("error getting scene markers") tagErr := errors.New("error getting tags") - mockMarkerReader.On("FindBySceneID", testCtx, sceneID).Return(validMarkers, nil).Once() - mockMarkerReader.On("FindBySceneID", testCtx, noMarkersID).Return(nil, nil).Once() - mockMarkerReader.On("FindBySceneID", testCtx, errMarkersID).Return(nil, markersErr).Once() - mockMarkerReader.On("FindBySceneID", testCtx, errFindPrimaryTagID).Return(invalidMarkers1, nil).Once() - mockMarkerReader.On("FindBySceneID", testCtx, errFindByMarkerID).Return(invalidMarkers2, nil).Once() + db.SceneMarker.On("FindBySceneID", testCtx, sceneID).Return(validMarkers, nil).Once() + db.SceneMarker.On("FindBySceneID", testCtx, noMarkersID).Return(nil, nil).Once() + db.SceneMarker.On("FindBySceneID", testCtx, errMarkersID).Return(nil, markersErr).Once() + db.SceneMarker.On("FindBySceneID", testCtx, errFindPrimaryTagID).Return(invalidMarkers1, nil).Once() + db.SceneMarker.On("FindBySceneID", testCtx, errFindByMarkerID).Return(invalidMarkers2, nil).Once() - mockTagReader.On("Find", testCtx, validTagID1).Return(&models.Tag{ + db.Tag.On("Find", testCtx, validTagID1).Return(&models.Tag{ Name: validTagName1, }, nil) - mockTagReader.On("Find", testCtx, validTagID2).Return(&models.Tag{ + db.Tag.On("Find", testCtx, validTagID2).Return(&models.Tag{ Name: validTagName2, }, nil) - mockTagReader.On("Find", testCtx, invalidTagID).Return(nil, tagErr) + db.Tag.On("Find", testCtx, invalidTagID).Return(nil, tagErr) - mockTagReader.On("FindBySceneMarkerID", testCtx, validMarkerID1).Return([]*models.Tag{ + db.Tag.On("FindBySceneMarkerID", testCtx, validMarkerID1).Return([]*models.Tag{ { Name: validTagName1, }, @@ -570,16 +570,16 @@ func TestGetSceneMarkersJSON(t *testing.T) { Name: validTagName2, }, }, nil) - mockTagReader.On("FindBySceneMarkerID", testCtx, validMarkerID2).Return([]*models.Tag{ + db.Tag.On("FindBySceneMarkerID", testCtx, validMarkerID2).Return([]*models.Tag{ { Name: validTagName2, }, }, nil) - mockTagReader.On("FindBySceneMarkerID", testCtx, invalidMarkerID2).Return(nil, tagErr).Once() + db.Tag.On("FindBySceneMarkerID", testCtx, invalidMarkerID2).Return(nil, tagErr).Once() for i, s := range getSceneMarkersJSONScenarios { scene := s.input - json, err := GetSceneMarkersJSON(testCtx, mockMarkerReader, mockTagReader, &scene) + json, err := GetSceneMarkersJSON(testCtx, db.SceneMarker, db.Tag, &scene) switch { case !s.err && err != nil: @@ -591,5 +591,5 @@ func TestGetSceneMarkersJSON(t *testing.T) { } } - mockTagReader.AssertExpectations(t) + db.Tag.AssertExpectations(t) } diff --git a/pkg/scene/import_test.go b/pkg/scene/import_test.go index f1bd5ceb373..bb13c96732d 100644 --- a/pkg/scene/import_test.go +++ b/pkg/scene/import_test.go @@ -56,20 +56,19 @@ func TestImporterPreImport(t *testing.T) { } func TestImporterPreImportWithStudio(t *testing.T) { - studioReaderWriter := &mocks.StudioReaderWriter{} - testCtx := context.Background() + db := mocks.NewDatabase() i := Importer{ - StudioWriter: studioReaderWriter, + StudioWriter: db.Studio, Input: jsonschema.Scene{ Studio: existingStudioName, }, } - studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{ + db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{ ID: existingStudioID, }, nil).Once() - studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() + db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() err := i.PreImport(testCtx) assert.Nil(t, err) @@ -79,22 +78,22 @@ func TestImporterPreImportWithStudio(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - studioReaderWriter.AssertExpectations(t) + db.Studio.AssertExpectations(t) } func TestImporterPreImportWithMissingStudio(t *testing.T) { - studioReaderWriter := &mocks.StudioReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - StudioWriter: studioReaderWriter, + StudioWriter: db.Studio, Input: jsonschema.Scene{ Studio: missingStudioName, }, MissingRefBehaviour: models.ImportMissingRefEnumFail, } - studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) - studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { + db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { s := args.Get(1).(*models.Studio) s.ID = existingStudioID }).Return(nil) @@ -111,32 +110,32 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { assert.Nil(t, err) assert.Equal(t, existingStudioID, *i.scene.StudioID) - studioReaderWriter.AssertExpectations(t) + db.Studio.AssertExpectations(t) } func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { - studioReaderWriter := &mocks.StudioReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - StudioWriter: studioReaderWriter, + StudioWriter: db.Studio, Input: jsonschema.Scene{ Studio: missingStudioName, }, MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() - studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) + db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) } func TestImporterPreImportWithPerformer(t *testing.T) { - performerReaderWriter := &mocks.PerformerReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - PerformerWriter: performerReaderWriter, + PerformerWriter: db.Performer, MissingRefBehaviour: models.ImportMissingRefEnumFail, Input: jsonschema.Scene{ Performers: []string{ @@ -145,13 +144,13 @@ func TestImporterPreImportWithPerformer(t *testing.T) { }, } - performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{ + db.Performer.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{ { ID: existingPerformerID, Name: existingPerformerName, }, }, nil).Once() - performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() + db.Performer.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() err := i.PreImport(testCtx) assert.Nil(t, err) @@ -161,14 +160,14 @@ func TestImporterPreImportWithPerformer(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - performerReaderWriter.AssertExpectations(t) + db.Performer.AssertExpectations(t) } func TestImporterPreImportWithMissingPerformer(t *testing.T) { - performerReaderWriter := &mocks.PerformerReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - PerformerWriter: performerReaderWriter, + PerformerWriter: db.Performer, Input: jsonschema.Scene{ Performers: []string{ missingPerformerName, @@ -177,8 +176,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) - performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) { + db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) + db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) { p := args.Get(1).(*models.Performer) p.ID = existingPerformerID }).Return(nil) @@ -195,14 +194,14 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { assert.Nil(t, err) assert.Equal(t, []int{existingPerformerID}, i.scene.PerformerIDs.List()) - performerReaderWriter.AssertExpectations(t) + db.Performer.AssertExpectations(t) } func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { - performerReaderWriter := &mocks.PerformerReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - PerformerWriter: performerReaderWriter, + PerformerWriter: db.Performer, Input: jsonschema.Scene{ Performers: []string{ missingPerformerName, @@ -211,19 +210,18 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() - performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error")) + db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() + db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) } func TestImporterPreImportWithMovie(t *testing.T) { - movieReaderWriter := &mocks.MovieReaderWriter{} - testCtx := context.Background() + db := mocks.NewDatabase() i := Importer{ - MovieWriter: movieReaderWriter, + MovieWriter: db.Movie, MissingRefBehaviour: models.ImportMissingRefEnumFail, Input: jsonschema.Scene{ Movies: []jsonschema.SceneMovie{ @@ -235,11 +233,11 @@ func TestImporterPreImportWithMovie(t *testing.T) { }, } - movieReaderWriter.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{ + db.Movie.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{ ID: existingMovieID, Name: existingMovieName, }, nil).Once() - movieReaderWriter.On("FindByName", testCtx, existingMovieErr, false).Return(nil, errors.New("FindByName error")).Once() + db.Movie.On("FindByName", testCtx, existingMovieErr, false).Return(nil, errors.New("FindByName error")).Once() err := i.PreImport(testCtx) assert.Nil(t, err) @@ -249,15 +247,14 @@ func TestImporterPreImportWithMovie(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - movieReaderWriter.AssertExpectations(t) + db.Movie.AssertExpectations(t) } func TestImporterPreImportWithMissingMovie(t *testing.T) { - movieReaderWriter := &mocks.MovieReaderWriter{} - testCtx := context.Background() + db := mocks.NewDatabase() i := Importer{ - MovieWriter: movieReaderWriter, + MovieWriter: db.Movie, Input: jsonschema.Scene{ Movies: []jsonschema.SceneMovie{ { @@ -268,8 +265,8 @@ func TestImporterPreImportWithMissingMovie(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - movieReaderWriter.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Times(3) - movieReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Run(func(args mock.Arguments) { + db.Movie.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Times(3) + db.Movie.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Run(func(args mock.Arguments) { m := args.Get(1).(*models.Movie) m.ID = existingMovieID }).Return(nil) @@ -286,14 +283,14 @@ func TestImporterPreImportWithMissingMovie(t *testing.T) { assert.Nil(t, err) assert.Equal(t, existingMovieID, i.scene.Movies.List()[0].MovieID) - movieReaderWriter.AssertExpectations(t) + db.Movie.AssertExpectations(t) } func TestImporterPreImportWithMissingMovieCreateErr(t *testing.T) { - movieReaderWriter := &mocks.MovieReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - MovieWriter: movieReaderWriter, + MovieWriter: db.Movie, Input: jsonschema.Scene{ Movies: []jsonschema.SceneMovie{ { @@ -304,18 +301,18 @@ func TestImporterPreImportWithMissingMovieCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - movieReaderWriter.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Once() - movieReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Return(errors.New("Create error")) + db.Movie.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Once() + db.Movie.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) } func TestImporterPreImportWithTag(t *testing.T) { - tagReaderWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - TagWriter: tagReaderWriter, + TagWriter: db.Tag, MissingRefBehaviour: models.ImportMissingRefEnumFail, Input: jsonschema.Scene{ Tags: []string{ @@ -324,13 +321,13 @@ func TestImporterPreImportWithTag(t *testing.T) { }, } - tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{ + db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{ { ID: existingTagID, Name: existingTagName, }, }, nil).Once() - tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() + db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() err := i.PreImport(testCtx) assert.Nil(t, err) @@ -340,14 +337,14 @@ func TestImporterPreImportWithTag(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - tagReaderWriter.AssertExpectations(t) + db.Tag.AssertExpectations(t) } func TestImporterPreImportWithMissingTag(t *testing.T) { - tagReaderWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - TagWriter: tagReaderWriter, + TagWriter: db.Tag, Input: jsonschema.Scene{ Tags: []string{ missingTagName, @@ -356,8 +353,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) - tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { + db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { t := args.Get(1).(*models.Tag) t.ID = existingTagID }).Return(nil) @@ -374,14 +371,14 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { assert.Nil(t, err) assert.Equal(t, []int{existingTagID}, i.scene.TagIDs.List()) - tagReaderWriter.AssertExpectations(t) + db.Tag.AssertExpectations(t) } func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { - tagReaderWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - TagWriter: tagReaderWriter, + TagWriter: db.Tag, Input: jsonschema.Scene{ Tags: []string{ missingTagName, @@ -390,8 +387,8 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() - tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) + db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() + db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/scene/update_test.go b/pkg/scene/update_test.go index c89be66f35b..2aee2bbcbc6 100644 --- a/pkg/scene/update_test.go +++ b/pkg/scene/update_test.go @@ -1,7 +1,6 @@ package scene import ( - "context" "errors" "strconv" "testing" @@ -105,8 +104,6 @@ func TestUpdater_Update(t *testing.T) { tagID ) - ctx := context.Background() - performerIDs := []int{performerID} tagIDs := []int{tagID} stashID := "stashID" @@ -119,14 +116,15 @@ func TestUpdater_Update(t *testing.T) { updateErr := errors.New("error updating") - qb := mocks.SceneReaderWriter{} - qb.On("UpdatePartial", ctx, mock.MatchedBy(func(id int) bool { + db := mocks.NewDatabase() + + db.Scene.On("UpdatePartial", testCtx, mock.MatchedBy(func(id int) bool { return id != badUpdateID }), mock.Anything).Return(validScene, nil) - qb.On("UpdatePartial", ctx, badUpdateID, mock.Anything).Return(nil, updateErr) + db.Scene.On("UpdatePartial", testCtx, badUpdateID, mock.Anything).Return(nil, updateErr) - qb.On("UpdateCover", ctx, sceneID, cover).Return(nil).Once() - qb.On("UpdateCover", ctx, badCoverID, cover).Return(updateErr).Once() + db.Scene.On("UpdateCover", testCtx, sceneID, cover).Return(nil).Once() + db.Scene.On("UpdateCover", testCtx, badCoverID, cover).Return(updateErr).Once() tests := []struct { name string @@ -204,7 +202,7 @@ func TestUpdater_Update(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.u.Update(ctx, &qb) + got, err := tt.u.Update(testCtx, db.Scene) if (err != nil) != tt.wantErr { t.Errorf("Updater.Update() error = %v, wantErr %v", err, tt.wantErr) return @@ -215,7 +213,7 @@ func TestUpdater_Update(t *testing.T) { }) } - qb.AssertExpectations(t) + db.Scene.AssertExpectations(t) } func TestUpdateSet_UpdateInput(t *testing.T) { diff --git a/pkg/sqlite/setup_test.go b/pkg/sqlite/setup_test.go index cdead935d09..aad8ee161c1 100644 --- a/pkg/sqlite/setup_test.go +++ b/pkg/sqlite/setup_test.go @@ -1176,7 +1176,7 @@ func makeImage(i int) *models.Image { } func createImages(ctx context.Context, n int) error { - qb := db.Repository().Image + qb := db.Image fqb := db.File for i := 0; i < n; i++ { @@ -1262,7 +1262,7 @@ func makeGallery(i int, includeScenes bool) *models.Gallery { } func createGalleries(ctx context.Context, n int) error { - gqb := db.Repository().Gallery + gqb := db.Gallery fqb := db.File for i := 0; i < n; i++ { diff --git a/pkg/studio/export_test.go b/pkg/studio/export_test.go index f1cce33465c..6682213b018 100644 --- a/pkg/studio/export_test.go +++ b/pkg/studio/export_test.go @@ -1,7 +1,6 @@ package studio import ( - "context" "errors" "github.com/stashapp/stash/pkg/models" @@ -162,27 +161,26 @@ func initTestTable() { func TestToJSON(t *testing.T) { initTestTable() - ctx := context.Background() - mockStudioReader := &mocks.StudioReaderWriter{} + db := mocks.NewDatabase() imageErr := errors.New("error getting image") - mockStudioReader.On("GetImage", ctx, studioID).Return(imageBytes, nil).Once() - mockStudioReader.On("GetImage", ctx, noImageID).Return(nil, nil).Once() - mockStudioReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once() - mockStudioReader.On("GetImage", ctx, missingParentStudioID).Return(imageBytes, nil).Maybe() - mockStudioReader.On("GetImage", ctx, errStudioID).Return(imageBytes, nil).Maybe() + db.Studio.On("GetImage", testCtx, studioID).Return(imageBytes, nil).Once() + db.Studio.On("GetImage", testCtx, noImageID).Return(nil, nil).Once() + db.Studio.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once() + db.Studio.On("GetImage", testCtx, missingParentStudioID).Return(imageBytes, nil).Maybe() + db.Studio.On("GetImage", testCtx, errStudioID).Return(imageBytes, nil).Maybe() parentStudioErr := errors.New("error getting parent studio") - mockStudioReader.On("Find", ctx, parentStudioID).Return(&parentStudio, nil) - mockStudioReader.On("Find", ctx, missingStudioID).Return(nil, nil) - mockStudioReader.On("Find", ctx, errParentStudioID).Return(nil, parentStudioErr) + db.Studio.On("Find", testCtx, parentStudioID).Return(&parentStudio, nil) + db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil) + db.Studio.On("Find", testCtx, errParentStudioID).Return(nil, parentStudioErr) for i, s := range scenarios { studio := s.input - json, err := ToJSON(ctx, mockStudioReader, &studio) + json, err := ToJSON(testCtx, db.Studio, &studio) switch { case !s.err && err != nil: @@ -194,5 +192,5 @@ func TestToJSON(t *testing.T) { } } - mockStudioReader.AssertExpectations(t) + db.Studio.AssertExpectations(t) } diff --git a/pkg/studio/import_test.go b/pkg/studio/import_test.go index d754a01c17f..78b40d3d9f5 100644 --- a/pkg/studio/import_test.go +++ b/pkg/studio/import_test.go @@ -25,6 +25,8 @@ const ( missingParentStudioName = "existingParentStudioName" ) +var testCtx = context.Background() + func TestImporterName(t *testing.T) { i := Importer{ Input: jsonschema.Studio{ @@ -43,22 +45,21 @@ func TestImporterPreImport(t *testing.T) { IgnoreAutoTag: autoTagIgnored, }, } - ctx := context.Background() - err := i.PreImport(ctx) + err := i.PreImport(testCtx) assert.NotNil(t, err) i.Input.Image = image - err = i.PreImport(ctx) + err = i.PreImport(testCtx) assert.Nil(t, err) i.Input = *createFullJSONStudio(studioName, image, []string{"alias"}) i.Input.ParentStudio = "" - err = i.PreImport(ctx) + err = i.PreImport(testCtx) assert.Nil(t, err) expectedStudio := createFullStudio(0, 0) @@ -67,11 +68,10 @@ func TestImporterPreImport(t *testing.T) { } func TestImporterPreImportWithParent(t *testing.T) { - readerWriter := &mocks.StudioReaderWriter{} - ctx := context.Background() + db := mocks.NewDatabase() i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Studio, Input: jsonschema.Studio{ Name: studioName, Image: image, @@ -79,28 +79,27 @@ func TestImporterPreImportWithParent(t *testing.T) { }, } - readerWriter.On("FindByName", ctx, existingParentStudioName, false).Return(&models.Studio{ + db.Studio.On("FindByName", testCtx, existingParentStudioName, false).Return(&models.Studio{ ID: existingStudioID, }, nil).Once() - readerWriter.On("FindByName", ctx, existingParentStudioErr, false).Return(nil, errors.New("FindByName error")).Once() + db.Studio.On("FindByName", testCtx, existingParentStudioErr, false).Return(nil, errors.New("FindByName error")).Once() - err := i.PreImport(ctx) + err := i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingStudioID, *i.studio.ParentID) i.Input.ParentStudio = existingParentStudioErr - err = i.PreImport(ctx) + err = i.PreImport(testCtx) assert.NotNil(t, err) - readerWriter.AssertExpectations(t) + db.Studio.AssertExpectations(t) } func TestImporterPreImportWithMissingParent(t *testing.T) { - readerWriter := &mocks.StudioReaderWriter{} - ctx := context.Background() + db := mocks.NewDatabase() i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Studio, Input: jsonschema.Studio{ Name: studioName, Image: image, @@ -109,33 +108,32 @@ func TestImporterPreImportWithMissingParent(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - readerWriter.On("FindByName", ctx, missingParentStudioName, false).Return(nil, nil).Times(3) - readerWriter.On("Create", ctx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { + db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Times(3) + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { s := args.Get(1).(*models.Studio) s.ID = existingStudioID }).Return(nil) - err := i.PreImport(ctx) + err := i.PreImport(testCtx) assert.NotNil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore - err = i.PreImport(ctx) + err = i.PreImport(testCtx) assert.Nil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumCreate - err = i.PreImport(ctx) + err = i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingStudioID, *i.studio.ParentID) - readerWriter.AssertExpectations(t) + db.Studio.AssertExpectations(t) } func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) { - readerWriter := &mocks.StudioReaderWriter{} - ctx := context.Background() + db := mocks.NewDatabase() i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Studio, Input: jsonschema.Studio{ Name: studioName, Image: image, @@ -144,19 +142,18 @@ func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - readerWriter.On("FindByName", ctx, missingParentStudioName, false).Return(nil, nil).Once() - readerWriter.On("Create", ctx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) + db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Once() + db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) - err := i.PreImport(ctx) + err := i.PreImport(testCtx) assert.NotNil(t, err) } func TestImporterPostImport(t *testing.T) { - readerWriter := &mocks.StudioReaderWriter{} - ctx := context.Background() + db := mocks.NewDatabase() i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Studio, Input: jsonschema.Studio{ Aliases: []string{"alias"}, }, @@ -165,56 +162,54 @@ func TestImporterPostImport(t *testing.T) { updateStudioImageErr := errors.New("UpdateImage error") - readerWriter.On("UpdateImage", ctx, studioID, imageBytes).Return(nil).Once() - readerWriter.On("UpdateImage", ctx, errImageID, imageBytes).Return(updateStudioImageErr).Once() + db.Studio.On("UpdateImage", testCtx, studioID, imageBytes).Return(nil).Once() + db.Studio.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updateStudioImageErr).Once() - err := i.PostImport(ctx, studioID) + err := i.PostImport(testCtx, studioID) assert.Nil(t, err) - err = i.PostImport(ctx, errImageID) + err = i.PostImport(testCtx, errImageID) assert.NotNil(t, err) - readerWriter.AssertExpectations(t) + db.Studio.AssertExpectations(t) } func TestImporterFindExistingID(t *testing.T) { - readerWriter := &mocks.StudioReaderWriter{} - ctx := context.Background() + db := mocks.NewDatabase() i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Studio, Input: jsonschema.Studio{ Name: studioName, }, } errFindByName := errors.New("FindByName error") - readerWriter.On("FindByName", ctx, studioName, false).Return(nil, nil).Once() - readerWriter.On("FindByName", ctx, existingStudioName, false).Return(&models.Studio{ + db.Studio.On("FindByName", testCtx, studioName, false).Return(nil, nil).Once() + db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{ ID: existingStudioID, }, nil).Once() - readerWriter.On("FindByName", ctx, studioNameErr, false).Return(nil, errFindByName).Once() + db.Studio.On("FindByName", testCtx, studioNameErr, false).Return(nil, errFindByName).Once() - id, err := i.FindExistingID(ctx) + id, err := i.FindExistingID(testCtx) assert.Nil(t, id) assert.Nil(t, err) i.Input.Name = existingStudioName - id, err = i.FindExistingID(ctx) + id, err = i.FindExistingID(testCtx) assert.Equal(t, existingStudioID, *id) assert.Nil(t, err) i.Input.Name = studioNameErr - id, err = i.FindExistingID(ctx) + id, err = i.FindExistingID(testCtx) assert.Nil(t, id) assert.NotNil(t, err) - readerWriter.AssertExpectations(t) + db.Studio.AssertExpectations(t) } func TestCreate(t *testing.T) { - readerWriter := &mocks.StudioReaderWriter{} - ctx := context.Background() + db := mocks.NewDatabase() studio := models.Studio{ Name: studioName, @@ -225,32 +220,31 @@ func TestCreate(t *testing.T) { } i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Studio, studio: studio, } errCreate := errors.New("Create error") - readerWriter.On("Create", ctx, &studio).Run(func(args mock.Arguments) { + db.Studio.On("Create", testCtx, &studio).Run(func(args mock.Arguments) { s := args.Get(1).(*models.Studio) s.ID = studioID }).Return(nil).Once() - readerWriter.On("Create", ctx, &studioErr).Return(errCreate).Once() + db.Studio.On("Create", testCtx, &studioErr).Return(errCreate).Once() - id, err := i.Create(ctx) + id, err := i.Create(testCtx) assert.Equal(t, studioID, *id) assert.Nil(t, err) i.studio = studioErr - id, err = i.Create(ctx) + id, err = i.Create(testCtx) assert.Nil(t, id) assert.NotNil(t, err) - readerWriter.AssertExpectations(t) + db.Studio.AssertExpectations(t) } func TestUpdate(t *testing.T) { - readerWriter := &mocks.StudioReaderWriter{} - ctx := context.Background() + db := mocks.NewDatabase() studio := models.Studio{ Name: studioName, @@ -261,7 +255,7 @@ func TestUpdate(t *testing.T) { } i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Studio, studio: studio, } @@ -269,19 +263,19 @@ func TestUpdate(t *testing.T) { // id needs to be set for the mock input studio.ID = studioID - readerWriter.On("Update", ctx, &studio).Return(nil).Once() + db.Studio.On("Update", testCtx, &studio).Return(nil).Once() - err := i.Update(ctx, studioID) + err := i.Update(testCtx, studioID) assert.Nil(t, err) i.studio = studioErr // need to set id separately studioErr.ID = errImageID - readerWriter.On("Update", ctx, &studioErr).Return(errUpdate).Once() + db.Studio.On("Update", testCtx, &studioErr).Return(errUpdate).Once() - err = i.Update(ctx, errImageID) + err = i.Update(testCtx, errImageID) assert.NotNil(t, err) - readerWriter.AssertExpectations(t) + db.Studio.AssertExpectations(t) } diff --git a/pkg/tag/export_test.go b/pkg/tag/export_test.go index c4f4691d73f..75a85b7eb15 100644 --- a/pkg/tag/export_test.go +++ b/pkg/tag/export_test.go @@ -1,7 +1,6 @@ package tag import ( - "context" "errors" "github.com/stashapp/stash/pkg/models" @@ -109,35 +108,34 @@ func initTestTable() { func TestToJSON(t *testing.T) { initTestTable() - ctx := context.Background() - mockTagReader := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() imageErr := errors.New("error getting image") aliasErr := errors.New("error getting aliases") parentsErr := errors.New("error getting parents") - mockTagReader.On("GetAliases", ctx, tagID).Return([]string{"alias"}, nil).Once() - mockTagReader.On("GetAliases", ctx, noImageID).Return(nil, nil).Once() - mockTagReader.On("GetAliases", ctx, errImageID).Return(nil, nil).Once() - mockTagReader.On("GetAliases", ctx, errAliasID).Return(nil, aliasErr).Once() - mockTagReader.On("GetAliases", ctx, withParentsID).Return(nil, nil).Once() - mockTagReader.On("GetAliases", ctx, errParentsID).Return(nil, nil).Once() - - mockTagReader.On("GetImage", ctx, tagID).Return(imageBytes, nil).Once() - mockTagReader.On("GetImage", ctx, noImageID).Return(nil, nil).Once() - mockTagReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once() - mockTagReader.On("GetImage", ctx, withParentsID).Return(imageBytes, nil).Once() - mockTagReader.On("GetImage", ctx, errParentsID).Return(nil, nil).Once() - - mockTagReader.On("FindByChildTagID", ctx, tagID).Return(nil, nil).Once() - mockTagReader.On("FindByChildTagID", ctx, noImageID).Return(nil, nil).Once() - mockTagReader.On("FindByChildTagID", ctx, withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once() - mockTagReader.On("FindByChildTagID", ctx, errParentsID).Return(nil, parentsErr).Once() - mockTagReader.On("FindByChildTagID", ctx, errImageID).Return(nil, nil).Once() + db.Tag.On("GetAliases", testCtx, tagID).Return([]string{"alias"}, nil).Once() + db.Tag.On("GetAliases", testCtx, noImageID).Return(nil, nil).Once() + db.Tag.On("GetAliases", testCtx, errImageID).Return(nil, nil).Once() + db.Tag.On("GetAliases", testCtx, errAliasID).Return(nil, aliasErr).Once() + db.Tag.On("GetAliases", testCtx, withParentsID).Return(nil, nil).Once() + db.Tag.On("GetAliases", testCtx, errParentsID).Return(nil, nil).Once() + + db.Tag.On("GetImage", testCtx, tagID).Return(imageBytes, nil).Once() + db.Tag.On("GetImage", testCtx, noImageID).Return(nil, nil).Once() + db.Tag.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once() + db.Tag.On("GetImage", testCtx, withParentsID).Return(imageBytes, nil).Once() + db.Tag.On("GetImage", testCtx, errParentsID).Return(nil, nil).Once() + + db.Tag.On("FindByChildTagID", testCtx, tagID).Return(nil, nil).Once() + db.Tag.On("FindByChildTagID", testCtx, noImageID).Return(nil, nil).Once() + db.Tag.On("FindByChildTagID", testCtx, withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once() + db.Tag.On("FindByChildTagID", testCtx, errParentsID).Return(nil, parentsErr).Once() + db.Tag.On("FindByChildTagID", testCtx, errImageID).Return(nil, nil).Once() for i, s := range scenarios { tag := s.tag - json, err := ToJSON(ctx, mockTagReader, &tag) + json, err := ToJSON(testCtx, db.Tag, &tag) switch { case !s.err && err != nil: @@ -149,5 +147,5 @@ func TestToJSON(t *testing.T) { } } - mockTagReader.AssertExpectations(t) + db.Tag.AssertExpectations(t) } diff --git a/pkg/tag/import_test.go b/pkg/tag/import_test.go index 997fb35f773..9378adf75a3 100644 --- a/pkg/tag/import_test.go +++ b/pkg/tag/import_test.go @@ -58,10 +58,10 @@ func TestImporterPreImport(t *testing.T) { } func TestImporterPostImport(t *testing.T) { - readerWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Tag, Input: jsonschema.Tag{ Aliases: []string{"alias"}, }, @@ -72,23 +72,23 @@ func TestImporterPostImport(t *testing.T) { updateTagAliasErr := errors.New("UpdateAlias error") updateTagParentsErr := errors.New("UpdateParentTags error") - readerWriter.On("UpdateAliases", testCtx, tagID, i.Input.Aliases).Return(nil).Once() - readerWriter.On("UpdateAliases", testCtx, errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once() - readerWriter.On("UpdateAliases", testCtx, withParentsID, i.Input.Aliases).Return(nil).Once() - readerWriter.On("UpdateAliases", testCtx, errParentsID, i.Input.Aliases).Return(nil).Once() + db.Tag.On("UpdateAliases", testCtx, tagID, i.Input.Aliases).Return(nil).Once() + db.Tag.On("UpdateAliases", testCtx, errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once() + db.Tag.On("UpdateAliases", testCtx, withParentsID, i.Input.Aliases).Return(nil).Once() + db.Tag.On("UpdateAliases", testCtx, errParentsID, i.Input.Aliases).Return(nil).Once() - readerWriter.On("UpdateImage", testCtx, tagID, imageBytes).Return(nil).Once() - readerWriter.On("UpdateImage", testCtx, errAliasID, imageBytes).Return(nil).Once() - readerWriter.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updateTagImageErr).Once() - readerWriter.On("UpdateImage", testCtx, withParentsID, imageBytes).Return(nil).Once() - readerWriter.On("UpdateImage", testCtx, errParentsID, imageBytes).Return(nil).Once() + db.Tag.On("UpdateImage", testCtx, tagID, imageBytes).Return(nil).Once() + db.Tag.On("UpdateImage", testCtx, errAliasID, imageBytes).Return(nil).Once() + db.Tag.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updateTagImageErr).Once() + db.Tag.On("UpdateImage", testCtx, withParentsID, imageBytes).Return(nil).Once() + db.Tag.On("UpdateImage", testCtx, errParentsID, imageBytes).Return(nil).Once() var parentTags []int - readerWriter.On("UpdateParentTags", testCtx, tagID, parentTags).Return(nil).Once() - readerWriter.On("UpdateParentTags", testCtx, withParentsID, []int{100}).Return(nil).Once() - readerWriter.On("UpdateParentTags", testCtx, errParentsID, []int{100}).Return(updateTagParentsErr).Once() + db.Tag.On("UpdateParentTags", testCtx, tagID, parentTags).Return(nil).Once() + db.Tag.On("UpdateParentTags", testCtx, withParentsID, []int{100}).Return(nil).Once() + db.Tag.On("UpdateParentTags", testCtx, errParentsID, []int{100}).Return(updateTagParentsErr).Once() - readerWriter.On("FindByName", testCtx, "Parent", false).Return(&models.Tag{ID: 100}, nil) + db.Tag.On("FindByName", testCtx, "Parent", false).Return(&models.Tag{ID: 100}, nil) err := i.PostImport(testCtx, tagID) assert.Nil(t, err) @@ -106,14 +106,14 @@ func TestImporterPostImport(t *testing.T) { err = i.PostImport(testCtx, errParentsID) assert.NotNil(t, err) - readerWriter.AssertExpectations(t) + db.Tag.AssertExpectations(t) } func TestImporterPostImportParentMissing(t *testing.T) { - readerWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Tag, Input: jsonschema.Tag{}, imageData: imageBytes, } @@ -133,33 +133,33 @@ func TestImporterPostImportParentMissing(t *testing.T) { var emptyParents []int - readerWriter.On("UpdateImage", testCtx, mock.Anything, mock.Anything).Return(nil) - readerWriter.On("UpdateAliases", testCtx, mock.Anything, mock.Anything).Return(nil) - - readerWriter.On("FindByName", testCtx, "Create", false).Return(nil, nil).Once() - readerWriter.On("FindByName", testCtx, "CreateError", false).Return(nil, nil).Once() - readerWriter.On("FindByName", testCtx, "CreateFindError", false).Return(nil, findError).Once() - readerWriter.On("FindByName", testCtx, "CreateFound", false).Return(&models.Tag{ID: 101}, nil).Once() - readerWriter.On("FindByName", testCtx, "Fail", false).Return(nil, nil).Once() - readerWriter.On("FindByName", testCtx, "FailFindError", false).Return(nil, findError) - readerWriter.On("FindByName", testCtx, "FailFound", false).Return(&models.Tag{ID: 102}, nil).Once() - readerWriter.On("FindByName", testCtx, "Ignore", false).Return(nil, nil).Once() - readerWriter.On("FindByName", testCtx, "IgnoreFindError", false).Return(nil, findError) - readerWriter.On("FindByName", testCtx, "IgnoreFound", false).Return(&models.Tag{ID: 103}, nil).Once() - - readerWriter.On("UpdateParentTags", testCtx, createID, []int{100}).Return(nil).Once() - readerWriter.On("UpdateParentTags", testCtx, createFoundID, []int{101}).Return(nil).Once() - readerWriter.On("UpdateParentTags", testCtx, failFoundID, []int{102}).Return(nil).Once() - readerWriter.On("UpdateParentTags", testCtx, ignoreID, emptyParents).Return(nil).Once() - readerWriter.On("UpdateParentTags", testCtx, ignoreFoundID, []int{103}).Return(nil).Once() - - readerWriter.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool { + db.Tag.On("UpdateImage", testCtx, mock.Anything, mock.Anything).Return(nil) + db.Tag.On("UpdateAliases", testCtx, mock.Anything, mock.Anything).Return(nil) + + db.Tag.On("FindByName", testCtx, "Create", false).Return(nil, nil).Once() + db.Tag.On("FindByName", testCtx, "CreateError", false).Return(nil, nil).Once() + db.Tag.On("FindByName", testCtx, "CreateFindError", false).Return(nil, findError).Once() + db.Tag.On("FindByName", testCtx, "CreateFound", false).Return(&models.Tag{ID: 101}, nil).Once() + db.Tag.On("FindByName", testCtx, "Fail", false).Return(nil, nil).Once() + db.Tag.On("FindByName", testCtx, "FailFindError", false).Return(nil, findError) + db.Tag.On("FindByName", testCtx, "FailFound", false).Return(&models.Tag{ID: 102}, nil).Once() + db.Tag.On("FindByName", testCtx, "Ignore", false).Return(nil, nil).Once() + db.Tag.On("FindByName", testCtx, "IgnoreFindError", false).Return(nil, findError) + db.Tag.On("FindByName", testCtx, "IgnoreFound", false).Return(&models.Tag{ID: 103}, nil).Once() + + db.Tag.On("UpdateParentTags", testCtx, createID, []int{100}).Return(nil).Once() + db.Tag.On("UpdateParentTags", testCtx, createFoundID, []int{101}).Return(nil).Once() + db.Tag.On("UpdateParentTags", testCtx, failFoundID, []int{102}).Return(nil).Once() + db.Tag.On("UpdateParentTags", testCtx, ignoreID, emptyParents).Return(nil).Once() + db.Tag.On("UpdateParentTags", testCtx, ignoreFoundID, []int{103}).Return(nil).Once() + + db.Tag.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool { return t.Name == "Create" })).Run(func(args mock.Arguments) { t := args.Get(1).(*models.Tag) t.ID = 100 }).Return(nil).Once() - readerWriter.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool { + db.Tag.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool { return t.Name == "CreateError" })).Return(errors.New("failed creating parent")).Once() @@ -206,25 +206,25 @@ func TestImporterPostImportParentMissing(t *testing.T) { err = i.PostImport(testCtx, ignoreFoundID) assert.Nil(t, err) - readerWriter.AssertExpectations(t) + db.Tag.AssertExpectations(t) } func TestImporterFindExistingID(t *testing.T) { - readerWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Tag, Input: jsonschema.Tag{ Name: tagName, }, } errFindByName := errors.New("FindByName error") - readerWriter.On("FindByName", testCtx, tagName, false).Return(nil, nil).Once() - readerWriter.On("FindByName", testCtx, existingTagName, false).Return(&models.Tag{ + db.Tag.On("FindByName", testCtx, tagName, false).Return(nil, nil).Once() + db.Tag.On("FindByName", testCtx, existingTagName, false).Return(&models.Tag{ ID: existingTagID, }, nil).Once() - readerWriter.On("FindByName", testCtx, tagNameErr, false).Return(nil, errFindByName).Once() + db.Tag.On("FindByName", testCtx, tagNameErr, false).Return(nil, errFindByName).Once() id, err := i.FindExistingID(testCtx) assert.Nil(t, id) @@ -240,11 +240,11 @@ func TestImporterFindExistingID(t *testing.T) { assert.Nil(t, id) assert.NotNil(t, err) - readerWriter.AssertExpectations(t) + db.Tag.AssertExpectations(t) } func TestCreate(t *testing.T) { - readerWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() tag := models.Tag{ Name: tagName, @@ -255,16 +255,16 @@ func TestCreate(t *testing.T) { } i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Tag, tag: tag, } errCreate := errors.New("Create error") - readerWriter.On("Create", testCtx, &tag).Run(func(args mock.Arguments) { + db.Tag.On("Create", testCtx, &tag).Run(func(args mock.Arguments) { t := args.Get(1).(*models.Tag) t.ID = tagID }).Return(nil).Once() - readerWriter.On("Create", testCtx, &tagErr).Return(errCreate).Once() + db.Tag.On("Create", testCtx, &tagErr).Return(errCreate).Once() id, err := i.Create(testCtx) assert.Equal(t, tagID, *id) @@ -275,11 +275,11 @@ func TestCreate(t *testing.T) { assert.Nil(t, id) assert.NotNil(t, err) - readerWriter.AssertExpectations(t) + db.Tag.AssertExpectations(t) } func TestUpdate(t *testing.T) { - readerWriter := &mocks.TagReaderWriter{} + db := mocks.NewDatabase() tag := models.Tag{ Name: tagName, @@ -290,7 +290,7 @@ func TestUpdate(t *testing.T) { } i := Importer{ - ReaderWriter: readerWriter, + ReaderWriter: db.Tag, tag: tag, } @@ -298,7 +298,7 @@ func TestUpdate(t *testing.T) { // id needs to be set for the mock input tag.ID = tagID - readerWriter.On("Update", testCtx, &tag).Return(nil).Once() + db.Tag.On("Update", testCtx, &tag).Return(nil).Once() err := i.Update(testCtx, tagID) assert.Nil(t, err) @@ -307,10 +307,10 @@ func TestUpdate(t *testing.T) { // need to set id separately tagErr.ID = errImageID - readerWriter.On("Update", testCtx, &tagErr).Return(errUpdate).Once() + db.Tag.On("Update", testCtx, &tagErr).Return(errUpdate).Once() err = i.Update(testCtx, errImageID) assert.NotNil(t, err) - readerWriter.AssertExpectations(t) + db.Tag.AssertExpectations(t) } diff --git a/pkg/tag/update_test.go b/pkg/tag/update_test.go index 4cc14e96142..bb137d818b4 100644 --- a/pkg/tag/update_test.go +++ b/pkg/tag/update_test.go @@ -219,8 +219,7 @@ func TestEnsureHierarchy(t *testing.T) { } func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, queryChildren bool) { - mockTagReader := &mocks.TagReaderWriter{} - ctx := context.Background() + db := mocks.NewDatabase() var parentIDs, childIDs []int find := make(map[int]*models.Tag) @@ -247,15 +246,15 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, if queryParents { parentIDs = nil - mockTagReader.On("FindByChildTagID", ctx, tc.id).Return(tc.parents, nil).Once() + db.Tag.On("FindByChildTagID", testCtx, tc.id).Return(tc.parents, nil).Once() } if queryChildren { childIDs = nil - mockTagReader.On("FindByParentTagID", ctx, tc.id).Return(tc.children, nil).Once() + db.Tag.On("FindByParentTagID", testCtx, tc.id).Return(tc.children, nil).Once() } - mockTagReader.On("FindAllAncestors", ctx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath { + db.Tag.On("FindAllAncestors", testCtx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath { return tc.onFindAllAncestors }, func(ctx context.Context, tagID int, excludeIDs []int) error { if tc.onFindAllAncestors != nil { @@ -264,7 +263,7 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, return fmt.Errorf("undefined ancestors for: %d", tagID) }).Maybe() - mockTagReader.On("FindAllDescendants", ctx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath { + db.Tag.On("FindAllDescendants", testCtx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath { return tc.onFindAllDescendants }, func(ctx context.Context, tagID int, excludeIDs []int) error { if tc.onFindAllDescendants != nil { @@ -273,7 +272,7 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, return fmt.Errorf("undefined descendants for: %d", tagID) }).Maybe() - res := ValidateHierarchy(ctx, testUniqueHierarchyTags[tc.id], parentIDs, childIDs, mockTagReader) + res := ValidateHierarchy(testCtx, testUniqueHierarchyTags[tc.id], parentIDs, childIDs, db.Tag) assert := assert.New(t) @@ -285,5 +284,5 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, assert.Nil(res) } - mockTagReader.AssertExpectations(t) + db.Tag.AssertExpectations(t) } From d3bf97e21d8a9bfbd8e2610970f899251e50f6a6 Mon Sep 17 00:00:00 2001 From: DingDongSoLong4 <99329275+DingDongSoLong4@users.noreply.github.com> Date: Mon, 17 Jul 2023 00:13:45 +0200 Subject: [PATCH 4/9] Add AssertExpectations method --- internal/api/resolver_mutation_tag_test.go | 3 ++- internal/autotag/gallery_test.go | 9 +++----- internal/autotag/image_test.go | 9 +++----- internal/autotag/performer_test.go | 6 +++--- internal/autotag/scene_test.go | 9 +++----- internal/autotag/studio_test.go | 6 +++--- internal/autotag/tag_test.go | 6 +++--- pkg/gallery/export_test.go | 4 +++- pkg/gallery/import_test.go | 18 ++++++++++------ pkg/image/export_test.go | 2 +- pkg/image/import_test.go | 18 ++++++++++------ pkg/models/mocks/database.go | 16 +++++++++++++++ pkg/movie/export_test.go | 3 +-- pkg/movie/import_test.go | 14 +++++++------ pkg/performer/export_test.go | 2 +- pkg/performer/import_test.go | 14 +++++++------ pkg/scene/export_test.go | 10 ++++----- pkg/scene/import_test.go | 24 ++++++++++++++-------- pkg/scene/update_test.go | 2 +- pkg/studio/export_test.go | 2 +- pkg/studio/import_test.go | 14 +++++++------ pkg/tag/export_test.go | 2 +- pkg/tag/import_test.go | 10 ++++----- pkg/tag/update_test.go | 2 +- 24 files changed, 120 insertions(+), 85 deletions(-) diff --git a/internal/api/resolver_mutation_tag_test.go b/internal/api/resolver_mutation_tag_test.go index 0c9a5c733ce..5f94e1e916e 100644 --- a/internal/api/resolver_mutation_tag_test.go +++ b/internal/api/resolver_mutation_tag_test.go @@ -94,7 +94,7 @@ func TestTagCreate(t *testing.T) { }) assert.Equal(t, expectedErr, err) - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) db = mocks.NewDatabase() r = newResolver(db) @@ -117,4 +117,5 @@ func TestTagCreate(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, tag) + db.AssertExpectations(t) } diff --git a/internal/autotag/gallery_test.go b/internal/autotag/gallery_test.go index b6214b1246b..6333a6c172a 100644 --- a/internal/autotag/gallery_test.go +++ b/internal/autotag/gallery_test.go @@ -79,8 +79,7 @@ func TestGalleryPerformers(t *testing.T) { err := GalleryPerformers(testCtx, &gallery, db.Gallery, db.Performer, nil) assert.Nil(err) - db.Performer.AssertExpectations(t) - db.Gallery.AssertExpectations(t) + db.AssertExpectations(t) } } @@ -125,8 +124,7 @@ func TestGalleryStudios(t *testing.T) { err := GalleryStudios(testCtx, &gallery, db.Gallery, db.Studio, nil) assert.Nil(err) - db.Studio.AssertExpectations(t) - db.Gallery.AssertExpectations(t) + db.AssertExpectations(t) } for _, test := range testTables { @@ -202,8 +200,7 @@ func TestGalleryTags(t *testing.T) { err := GalleryTags(testCtx, &gallery, db.Gallery, db.Tag, nil) assert.Nil(err) - db.Tag.AssertExpectations(t) - db.Gallery.AssertExpectations(t) + db.AssertExpectations(t) } for _, test := range testTables { diff --git a/internal/autotag/image_test.go b/internal/autotag/image_test.go index 8c5b6ae0762..88e42532d08 100644 --- a/internal/autotag/image_test.go +++ b/internal/autotag/image_test.go @@ -76,8 +76,7 @@ func TestImagePerformers(t *testing.T) { err := ImagePerformers(testCtx, &image, db.Image, db.Performer, nil) assert.Nil(err) - db.Performer.AssertExpectations(t) - db.Image.AssertExpectations(t) + db.AssertExpectations(t) } } @@ -122,8 +121,7 @@ func TestImageStudios(t *testing.T) { err := ImageStudios(testCtx, &image, db.Image, db.Studio, nil) assert.Nil(err) - db.Studio.AssertExpectations(t) - db.Image.AssertExpectations(t) + db.AssertExpectations(t) } for _, test := range testTables { @@ -199,8 +197,7 @@ func TestImageTags(t *testing.T) { err := ImageTags(testCtx, &image, db.Image, db.Tag, nil) assert.Nil(err) - db.Tag.AssertExpectations(t) - db.Image.AssertExpectations(t) + db.AssertExpectations(t) } for _, test := range testTables { diff --git a/internal/autotag/performer_test.go b/internal/autotag/performer_test.go index 1b798e24081..16df6a0ea86 100644 --- a/internal/autotag/performer_test.go +++ b/internal/autotag/performer_test.go @@ -112,7 +112,7 @@ func testPerformerScenes(t *testing.T, performerName, expectedRegex string) { assert := assert.New(t) assert.Nil(err) - db.Scene.AssertExpectations(t) + db.AssertExpectations(t) } func TestPerformerImages(t *testing.T) { @@ -207,7 +207,7 @@ func testPerformerImages(t *testing.T, performerName, expectedRegex string) { assert := assert.New(t) assert.Nil(err) - db.Image.AssertExpectations(t) + db.AssertExpectations(t) } func TestPerformerGalleries(t *testing.T) { @@ -302,5 +302,5 @@ func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) { assert := assert.New(t) assert.Nil(err) - db.Gallery.AssertExpectations(t) + db.AssertExpectations(t) } diff --git a/internal/autotag/scene_test.go b/internal/autotag/scene_test.go index 611c56c38a3..aaf015c8ff7 100644 --- a/internal/autotag/scene_test.go +++ b/internal/autotag/scene_test.go @@ -210,8 +210,7 @@ func TestScenePerformers(t *testing.T) { err := ScenePerformers(testCtx, &scene, db.Scene, db.Performer, nil) assert.Nil(err) - db.Performer.AssertExpectations(t) - db.Scene.AssertExpectations(t) + db.AssertExpectations(t) } } @@ -258,8 +257,7 @@ func TestSceneStudios(t *testing.T) { err := SceneStudios(testCtx, &scene, db.Scene, db.Studio, nil) assert.Nil(err) - db.Studio.AssertExpectations(t) - db.Scene.AssertExpectations(t) + db.AssertExpectations(t) } for _, test := range testTables { @@ -335,8 +333,7 @@ func TestSceneTags(t *testing.T) { err := SceneTags(testCtx, &scene, db.Scene, db.Tag, nil) assert.Nil(err) - db.Tag.AssertExpectations(t) - db.Scene.AssertExpectations(t) + db.AssertExpectations(t) } for _, test := range testTables { diff --git a/internal/autotag/studio_test.go b/internal/autotag/studio_test.go index d5db603d34f..3a806e0047b 100644 --- a/internal/autotag/studio_test.go +++ b/internal/autotag/studio_test.go @@ -171,7 +171,7 @@ func testStudioScenes(t *testing.T, tc testStudioCase) { assert := assert.New(t) assert.Nil(err) - db.Scene.AssertExpectations(t) + db.AssertExpectations(t) } func TestStudioImages(t *testing.T) { @@ -274,7 +274,7 @@ func testStudioImages(t *testing.T, tc testStudioCase) { assert := assert.New(t) assert.Nil(err) - db.Image.AssertExpectations(t) + db.AssertExpectations(t) } func TestStudioGalleries(t *testing.T) { @@ -377,5 +377,5 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) { assert := assert.New(t) assert.Nil(err) - db.Gallery.AssertExpectations(t) + db.AssertExpectations(t) } diff --git a/internal/autotag/tag_test.go b/internal/autotag/tag_test.go index cb2ae907b25..f14c1b5409c 100644 --- a/internal/autotag/tag_test.go +++ b/internal/autotag/tag_test.go @@ -174,7 +174,7 @@ func testTagScenes(t *testing.T, tc testTagCase) { assert := assert.New(t) assert.Nil(err) - db.Scene.AssertExpectations(t) + db.AssertExpectations(t) } func TestTagImages(t *testing.T) { @@ -281,7 +281,7 @@ func testTagImages(t *testing.T, tc testTagCase) { assert := assert.New(t) assert.Nil(err) - db.Image.AssertExpectations(t) + db.AssertExpectations(t) } func TestTagGalleries(t *testing.T) { @@ -389,5 +389,5 @@ func testTagGalleries(t *testing.T, tc testTagCase) { assert := assert.New(t) assert.Nil(err) - db.Gallery.AssertExpectations(t) + db.AssertExpectations(t) } diff --git a/pkg/gallery/export_test.go b/pkg/gallery/export_test.go index db0c0691043..12563c642e7 100644 --- a/pkg/gallery/export_test.go +++ b/pkg/gallery/export_test.go @@ -181,7 +181,7 @@ func TestGetStudioName(t *testing.T) { } } - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } const ( @@ -279,4 +279,6 @@ func TestGetGalleryChaptersJSON(t *testing.T) { assert.Equal(t, s.expected, json, "[%d]", i) } } + + db.AssertExpectations(t) } diff --git a/pkg/gallery/import_test.go b/pkg/gallery/import_test.go index e165fe3afb1..58e97ae0641 100644 --- a/pkg/gallery/import_test.go +++ b/pkg/gallery/import_test.go @@ -100,7 +100,7 @@ func TestImporterPreImportWithStudio(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingStudio(t *testing.T) { @@ -132,7 +132,7 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { assert.Nil(t, err) assert.Equal(t, existingStudioID, *i.gallery.StudioID) - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { @@ -151,6 +151,8 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { err := i.PreImport(testCtx) assert.NotNil(t, err) + + db.AssertExpectations(t) } func TestImporterPreImportWithPerformer(t *testing.T) { @@ -182,7 +184,7 @@ func TestImporterPreImportWithPerformer(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - db.Performer.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingPerformer(t *testing.T) { @@ -216,7 +218,7 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { assert.Nil(t, err) assert.Equal(t, []int{existingPerformerID}, i.gallery.PerformerIDs.List()) - db.Performer.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { @@ -237,6 +239,8 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { err := i.PreImport(testCtx) assert.NotNil(t, err) + + db.AssertExpectations(t) } func TestImporterPreImportWithTag(t *testing.T) { @@ -268,7 +272,7 @@ func TestImporterPreImportWithTag(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingTag(t *testing.T) { @@ -302,7 +306,7 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { assert.Nil(t, err) assert.Equal(t, []int{existingTagID}, i.gallery.TagIDs.List()) - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { @@ -323,4 +327,6 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { err := i.PreImport(testCtx) assert.NotNil(t, err) + + db.AssertExpectations(t) } diff --git a/pkg/image/export_test.go b/pkg/image/export_test.go index b228a371bc8..6adaf1d3321 100644 --- a/pkg/image/export_test.go +++ b/pkg/image/export_test.go @@ -154,5 +154,5 @@ func TestGetStudioName(t *testing.T) { } } - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } diff --git a/pkg/image/import_test.go b/pkg/image/import_test.go index ea1d899eeb4..9d63dd02e92 100644 --- a/pkg/image/import_test.go +++ b/pkg/image/import_test.go @@ -62,7 +62,7 @@ func TestImporterPreImportWithStudio(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingStudio(t *testing.T) { @@ -94,7 +94,7 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { assert.Nil(t, err) assert.Equal(t, existingStudioID, *i.image.StudioID) - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { @@ -113,6 +113,8 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { err := i.PreImport(testCtx) assert.NotNil(t, err) + + db.AssertExpectations(t) } func TestImporterPreImportWithPerformer(t *testing.T) { @@ -144,7 +146,7 @@ func TestImporterPreImportWithPerformer(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - db.Performer.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingPerformer(t *testing.T) { @@ -178,7 +180,7 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { assert.Nil(t, err) assert.Equal(t, []int{existingPerformerID}, i.image.PerformerIDs.List()) - db.Performer.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { @@ -199,6 +201,8 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { err := i.PreImport(testCtx) assert.NotNil(t, err) + + db.AssertExpectations(t) } func TestImporterPreImportWithTag(t *testing.T) { @@ -230,7 +234,7 @@ func TestImporterPreImportWithTag(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingTag(t *testing.T) { @@ -264,7 +268,7 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { assert.Nil(t, err) assert.Equal(t, []int{existingTagID}, i.image.TagIDs.List()) - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { @@ -285,4 +289,6 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { err := i.PreImport(testCtx) assert.NotNil(t, err) + + db.AssertExpectations(t) } diff --git a/pkg/models/mocks/database.go b/pkg/models/mocks/database.go index c3b6cd62631..83d2cbfabad 100644 --- a/pkg/models/mocks/database.go +++ b/pkg/models/mocks/database.go @@ -5,6 +5,7 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/txn" + "github.com/stretchr/testify/mock" ) type Database struct { @@ -72,6 +73,21 @@ func NewDatabase() *Database { } } +func (db *Database) AssertExpectations(t mock.TestingT) { + db.File.AssertExpectations(t) + db.Folder.AssertExpectations(t) + db.Gallery.AssertExpectations(t) + db.GalleryChapter.AssertExpectations(t) + db.Image.AssertExpectations(t) + db.Movie.AssertExpectations(t) + db.Performer.AssertExpectations(t) + db.Scene.AssertExpectations(t) + db.SceneMarker.AssertExpectations(t) + db.Studio.AssertExpectations(t) + db.Tag.AssertExpectations(t) + db.SavedFilter.AssertExpectations(t) +} + func (db *Database) Repository() models.Repository { return models.Repository{ TxnManager: db, diff --git a/pkg/movie/export_test.go b/pkg/movie/export_test.go index d369caaa789..51d57e2b6e8 100644 --- a/pkg/movie/export_test.go +++ b/pkg/movie/export_test.go @@ -205,6 +205,5 @@ func TestToJSON(t *testing.T) { } } - db.Movie.AssertExpectations(t) - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } diff --git a/pkg/movie/import_test.go b/pkg/movie/import_test.go index c4957545da3..d62f5a89004 100644 --- a/pkg/movie/import_test.go +++ b/pkg/movie/import_test.go @@ -96,7 +96,7 @@ func TestImporterPreImportWithStudio(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingStudio(t *testing.T) { @@ -131,7 +131,7 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { assert.Nil(t, err) assert.Equal(t, existingStudioID, *i.movie.StudioID) - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { @@ -153,6 +153,8 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { err := i.PreImport(testCtx) assert.NotNil(t, err) + + db.AssertExpectations(t) } func TestImporterPostImport(t *testing.T) { @@ -177,7 +179,7 @@ func TestImporterPostImport(t *testing.T) { err = i.PostImport(testCtx, errImageID) assert.NotNil(t, err) - db.Movie.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterFindExistingID(t *testing.T) { @@ -212,7 +214,7 @@ func TestImporterFindExistingID(t *testing.T) { assert.Nil(t, id) assert.NotNil(t, err) - db.Movie.AssertExpectations(t) + db.AssertExpectations(t) } func TestCreate(t *testing.T) { @@ -248,7 +250,7 @@ func TestCreate(t *testing.T) { assert.Nil(t, id) assert.NotNil(t, err) - db.Movie.AssertExpectations(t) + db.AssertExpectations(t) } func TestUpdate(t *testing.T) { @@ -286,5 +288,5 @@ func TestUpdate(t *testing.T) { err = i.Update(testCtx, errImageID) assert.NotNil(t, err) - db.Movie.AssertExpectations(t) + db.AssertExpectations(t) } diff --git a/pkg/performer/export_test.go b/pkg/performer/export_test.go index 483cf80d5d8..572634aa6a7 100644 --- a/pkg/performer/export_test.go +++ b/pkg/performer/export_test.go @@ -225,5 +225,5 @@ func TestToJSON(t *testing.T) { } } - db.Performer.AssertExpectations(t) + db.AssertExpectations(t) } diff --git a/pkg/performer/import_test.go b/pkg/performer/import_test.go index 26960a1a539..1ee569892d4 100644 --- a/pkg/performer/import_test.go +++ b/pkg/performer/import_test.go @@ -92,7 +92,7 @@ func TestImporterPreImportWithTag(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingTag(t *testing.T) { @@ -127,7 +127,7 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { assert.Nil(t, err) assert.Equal(t, existingTagID, i.performer.TagIDs.List()[0]) - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { @@ -149,6 +149,8 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { err := i.PreImport(testCtx) assert.NotNil(t, err) + + db.AssertExpectations(t) } func TestImporterPostImport(t *testing.T) { @@ -171,7 +173,7 @@ func TestImporterPostImport(t *testing.T) { err = i.PostImport(testCtx, errImageID) assert.NotNil(t, err) - db.Performer.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterFindExistingID(t *testing.T) { @@ -222,7 +224,7 @@ func TestImporterFindExistingID(t *testing.T) { assert.Nil(t, id) assert.NotNil(t, err) - db.Performer.AssertExpectations(t) + db.AssertExpectations(t) } func TestCreate(t *testing.T) { @@ -258,7 +260,7 @@ func TestCreate(t *testing.T) { assert.Nil(t, id) assert.NotNil(t, err) - db.Performer.AssertExpectations(t) + db.AssertExpectations(t) } func TestUpdate(t *testing.T) { @@ -296,5 +298,5 @@ func TestUpdate(t *testing.T) { err = i.Update(testCtx, errImageID) assert.NotNil(t, err) - db.Performer.AssertExpectations(t) + db.AssertExpectations(t) } diff --git a/pkg/scene/export_test.go b/pkg/scene/export_test.go index 0c44734ce54..04149bfa630 100644 --- a/pkg/scene/export_test.go +++ b/pkg/scene/export_test.go @@ -208,7 +208,7 @@ func TestToJSON(t *testing.T) { } } - db.Scene.AssertExpectations(t) + db.AssertExpectations(t) } func createStudioScene(studioID int) models.Scene { @@ -266,7 +266,7 @@ func TestGetStudioName(t *testing.T) { } } - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } type stringSliceTestScenario struct { @@ -327,7 +327,7 @@ func TestGetTagNames(t *testing.T) { } } - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) } type sceneMoviesTestScenario struct { @@ -417,7 +417,7 @@ func TestGetSceneMoviesJSON(t *testing.T) { } } - db.Movie.AssertExpectations(t) + db.AssertExpectations(t) } const ( @@ -591,5 +591,5 @@ func TestGetSceneMarkersJSON(t *testing.T) { } } - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) } diff --git a/pkg/scene/import_test.go b/pkg/scene/import_test.go index bb13c96732d..26180785627 100644 --- a/pkg/scene/import_test.go +++ b/pkg/scene/import_test.go @@ -78,7 +78,7 @@ func TestImporterPreImportWithStudio(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingStudio(t *testing.T) { @@ -110,7 +110,7 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { assert.Nil(t, err) assert.Equal(t, existingStudioID, *i.scene.StudioID) - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { @@ -129,6 +129,8 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { err := i.PreImport(testCtx) assert.NotNil(t, err) + + db.AssertExpectations(t) } func TestImporterPreImportWithPerformer(t *testing.T) { @@ -160,7 +162,7 @@ func TestImporterPreImportWithPerformer(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - db.Performer.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingPerformer(t *testing.T) { @@ -194,7 +196,7 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { assert.Nil(t, err) assert.Equal(t, []int{existingPerformerID}, i.scene.PerformerIDs.List()) - db.Performer.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { @@ -215,6 +217,8 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { err := i.PreImport(testCtx) assert.NotNil(t, err) + + db.AssertExpectations(t) } func TestImporterPreImportWithMovie(t *testing.T) { @@ -247,7 +251,7 @@ func TestImporterPreImportWithMovie(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - db.Movie.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingMovie(t *testing.T) { @@ -283,7 +287,7 @@ func TestImporterPreImportWithMissingMovie(t *testing.T) { assert.Nil(t, err) assert.Equal(t, existingMovieID, i.scene.Movies.List()[0].MovieID) - db.Movie.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingMovieCreateErr(t *testing.T) { @@ -306,6 +310,8 @@ func TestImporterPreImportWithMissingMovieCreateErr(t *testing.T) { err := i.PreImport(testCtx) assert.NotNil(t, err) + + db.AssertExpectations(t) } func TestImporterPreImportWithTag(t *testing.T) { @@ -337,7 +343,7 @@ func TestImporterPreImportWithTag(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingTag(t *testing.T) { @@ -371,7 +377,7 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { assert.Nil(t, err) assert.Equal(t, []int{existingTagID}, i.scene.TagIDs.List()) - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { @@ -392,4 +398,6 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { err := i.PreImport(testCtx) assert.NotNil(t, err) + + db.AssertExpectations(t) } diff --git a/pkg/scene/update_test.go b/pkg/scene/update_test.go index 2aee2bbcbc6..96ebb491f66 100644 --- a/pkg/scene/update_test.go +++ b/pkg/scene/update_test.go @@ -213,7 +213,7 @@ func TestUpdater_Update(t *testing.T) { }) } - db.Scene.AssertExpectations(t) + db.AssertExpectations(t) } func TestUpdateSet_UpdateInput(t *testing.T) { diff --git a/pkg/studio/export_test.go b/pkg/studio/export_test.go index 6682213b018..eb489f8a9a0 100644 --- a/pkg/studio/export_test.go +++ b/pkg/studio/export_test.go @@ -192,5 +192,5 @@ func TestToJSON(t *testing.T) { } } - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } diff --git a/pkg/studio/import_test.go b/pkg/studio/import_test.go index 78b40d3d9f5..e89256371cf 100644 --- a/pkg/studio/import_test.go +++ b/pkg/studio/import_test.go @@ -92,7 +92,7 @@ func TestImporterPreImportWithParent(t *testing.T) { err = i.PreImport(testCtx) assert.NotNil(t, err) - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingParent(t *testing.T) { @@ -126,7 +126,7 @@ func TestImporterPreImportWithMissingParent(t *testing.T) { assert.Nil(t, err) assert.Equal(t, existingStudioID, *i.studio.ParentID) - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) { @@ -147,6 +147,8 @@ func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) { err := i.PreImport(testCtx) assert.NotNil(t, err) + + db.AssertExpectations(t) } func TestImporterPostImport(t *testing.T) { @@ -171,7 +173,7 @@ func TestImporterPostImport(t *testing.T) { err = i.PostImport(testCtx, errImageID) assert.NotNil(t, err) - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterFindExistingID(t *testing.T) { @@ -205,7 +207,7 @@ func TestImporterFindExistingID(t *testing.T) { assert.Nil(t, id) assert.NotNil(t, err) - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } func TestCreate(t *testing.T) { @@ -240,7 +242,7 @@ func TestCreate(t *testing.T) { assert.Nil(t, id) assert.NotNil(t, err) - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } func TestUpdate(t *testing.T) { @@ -277,5 +279,5 @@ func TestUpdate(t *testing.T) { err = i.Update(testCtx, errImageID) assert.NotNil(t, err) - db.Studio.AssertExpectations(t) + db.AssertExpectations(t) } diff --git a/pkg/tag/export_test.go b/pkg/tag/export_test.go index 75a85b7eb15..1018f8d2d9d 100644 --- a/pkg/tag/export_test.go +++ b/pkg/tag/export_test.go @@ -147,5 +147,5 @@ func TestToJSON(t *testing.T) { } } - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) } diff --git a/pkg/tag/import_test.go b/pkg/tag/import_test.go index 9378adf75a3..6b856399200 100644 --- a/pkg/tag/import_test.go +++ b/pkg/tag/import_test.go @@ -106,7 +106,7 @@ func TestImporterPostImport(t *testing.T) { err = i.PostImport(testCtx, errParentsID) assert.NotNil(t, err) - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterPostImportParentMissing(t *testing.T) { @@ -206,7 +206,7 @@ func TestImporterPostImportParentMissing(t *testing.T) { err = i.PostImport(testCtx, ignoreFoundID) assert.Nil(t, err) - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) } func TestImporterFindExistingID(t *testing.T) { @@ -240,7 +240,7 @@ func TestImporterFindExistingID(t *testing.T) { assert.Nil(t, id) assert.NotNil(t, err) - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) } func TestCreate(t *testing.T) { @@ -275,7 +275,7 @@ func TestCreate(t *testing.T) { assert.Nil(t, id) assert.NotNil(t, err) - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) } func TestUpdate(t *testing.T) { @@ -312,5 +312,5 @@ func TestUpdate(t *testing.T) { err = i.Update(testCtx, errImageID) assert.NotNil(t, err) - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) } diff --git a/pkg/tag/update_test.go b/pkg/tag/update_test.go index bb137d818b4..c581d34ac43 100644 --- a/pkg/tag/update_test.go +++ b/pkg/tag/update_test.go @@ -284,5 +284,5 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, assert.Nil(res) } - db.Tag.AssertExpectations(t) + db.AssertExpectations(t) } From 40f37d9e60e36284f2b2f865c0207901324f376a Mon Sep 17 00:00:00 2001 From: DingDongSoLong4 <99329275+DingDongSoLong4@users.noreply.github.com> Date: Mon, 17 Jul 2023 00:24:02 +0200 Subject: [PATCH 5/9] Refactor routes --- internal/api/routes.go | 15 ++++++++++++ internal/api/routes_custom.go | 4 ++++ internal/api/routes_downloads.go | 4 ++++ internal/api/routes_image.go | 17 +++++++------ internal/api/routes_movie.go | 16 +++++++++---- internal/api/routes_performer.go | 14 +++++++---- internal/api/routes_scene.go | 36 ++++++++++++++++------------ internal/api/routes_studio.go | 14 +++++++---- internal/api/routes_tag.go | 16 +++++++++---- internal/api/server.go | 41 +++++++------------------------- 10 files changed, 104 insertions(+), 73 deletions(-) create mode 100644 internal/api/routes.go diff --git a/internal/api/routes.go b/internal/api/routes.go new file mode 100644 index 00000000000..e3a0f48c083 --- /dev/null +++ b/internal/api/routes.go @@ -0,0 +1,15 @@ +package api + +import ( + "net/http" + + "github.com/stashapp/stash/pkg/txn" +) + +type routes struct { + txnManager txn.Manager +} + +func (rs routes) withReadTxn(r *http.Request, fn txn.TxnFunc) error { + return txn.WithReadTxn(r.Context(), rs.txnManager, fn) +} diff --git a/internal/api/routes_custom.go b/internal/api/routes_custom.go index 731bbc58692..091aa061690 100644 --- a/internal/api/routes_custom.go +++ b/internal/api/routes_custom.go @@ -12,6 +12,10 @@ type customRoutes struct { servedFolders config.URLMap } +func getCustomRoutes(servedFolders config.URLMap) chi.Router { + return customRoutes{servedFolders: servedFolders}.Routes() +} + func (rs customRoutes) Routes() chi.Router { r := chi.NewRouter() diff --git a/internal/api/routes_downloads.go b/internal/api/routes_downloads.go index 16ceedaab56..24fd4831d15 100644 --- a/internal/api/routes_downloads.go +++ b/internal/api/routes_downloads.go @@ -10,6 +10,10 @@ import ( type downloadsRoutes struct{} +func getDownloadsRoutes() chi.Router { + return downloadsRoutes{}.Routes() +} + func (rs downloadsRoutes) Routes() chi.Router { r := chi.NewRouter() diff --git a/internal/api/routes_image.go b/internal/api/routes_image.go index 4cc2576718c..14e0cbf92e8 100644 --- a/internal/api/routes_image.go +++ b/internal/api/routes_image.go @@ -17,7 +17,6 @@ import ( "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) @@ -27,11 +26,19 @@ type ImageFinder interface { } type imageRoutes struct { - txnManager txn.Manager + routes imageFinder ImageFinder fileGetter models.FileGetter } +func getImageRoutes(repo models.Repository) chi.Router { + return imageRoutes{ + routes: routes{txnManager: repo.TxnManager}, + imageFinder: repo.Image, + fileGetter: repo.File, + }.Routes() +} + func (rs imageRoutes) Routes() chi.Router { r := chi.NewRouter() @@ -46,8 +53,6 @@ func (rs imageRoutes) Routes() chi.Router { return r } -// region Handlers - func (rs imageRoutes) Thumbnail(w http.ResponseWriter, r *http.Request) { img := r.Context().Value(imageKey).(*models.Image) filepath := manager.GetInstance().Paths.Generated.GetThumbnailPath(img.Checksum, models.DefaultGthumbWidth) @@ -148,15 +153,13 @@ func (rs imageRoutes) serveImage(w http.ResponseWriter, r *http.Request, i *mode utils.ServeImage(w, r, image) } -// endregion - func (rs imageRoutes) ImageCtx(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { imageIdentifierQueryParam := chi.URLParam(r, "imageId") imageID, _ := strconv.Atoi(imageIdentifierQueryParam) var image *models.Image - _ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + _ = rs.withReadTxn(r, func(ctx context.Context) error { qb := rs.imageFinder if imageID == 0 { images, _ := qb.FindByChecksum(ctx, imageIdentifierQueryParam) diff --git a/internal/api/routes_movie.go b/internal/api/routes_movie.go index 400587763b5..740d8886761 100644 --- a/internal/api/routes_movie.go +++ b/internal/api/routes_movie.go @@ -9,7 +9,6 @@ import ( "github.com/go-chi/chi" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) @@ -20,10 +19,17 @@ type MovieFinder interface { } type movieRoutes struct { - txnManager txn.Manager + routes movieFinder MovieFinder } +func getMovieRoutes(repo models.Repository) chi.Router { + return movieRoutes{ + routes: routes{txnManager: repo.TxnManager}, + movieFinder: repo.Movie, + }.Routes() +} + func (rs movieRoutes) Routes() chi.Router { r := chi.NewRouter() @@ -41,7 +47,7 @@ func (rs movieRoutes) FrontImage(w http.ResponseWriter, r *http.Request) { defaultParam := r.URL.Query().Get("default") var image []byte if defaultParam != "true" { - readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error { var err error image, err = rs.movieFinder.GetFrontImage(ctx, movie.ID) return err @@ -66,7 +72,7 @@ func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) { defaultParam := r.URL.Query().Get("default") var image []byte if defaultParam != "true" { - readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error { var err error image, err = rs.movieFinder.GetBackImage(ctx, movie.ID) return err @@ -95,7 +101,7 @@ func (rs movieRoutes) MovieCtx(next http.Handler) http.Handler { } var movie *models.Movie - _ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + _ = rs.withReadTxn(r, func(ctx context.Context) error { movie, _ = rs.movieFinder.Find(ctx, movieID) return nil }) diff --git a/internal/api/routes_performer.go b/internal/api/routes_performer.go index d05e5309570..0bc0cf2f4bd 100644 --- a/internal/api/routes_performer.go +++ b/internal/api/routes_performer.go @@ -10,7 +10,6 @@ import ( "github.com/stashapp/stash/internal/manager/config" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) @@ -20,10 +19,17 @@ type PerformerFinder interface { } type performerRoutes struct { - txnManager txn.Manager + routes performerFinder PerformerFinder } +func getPerformerRoutes(repo models.Repository) chi.Router { + return performerRoutes{ + routes: routes{txnManager: repo.TxnManager}, + performerFinder: repo.Performer, + }.Routes() +} + func (rs performerRoutes) Routes() chi.Router { r := chi.NewRouter() @@ -41,7 +47,7 @@ func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) { var image []byte if defaultParam != "true" { - readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error { var err error image, err = rs.performerFinder.GetImage(ctx, performer.ID) return err @@ -70,7 +76,7 @@ func (rs performerRoutes) PerformerCtx(next http.Handler) http.Handler { } var performer *models.Performer - _ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + _ = rs.withReadTxn(r, func(ctx context.Context) error { var err error performer, err = rs.performerFinder.Find(ctx, performerID) return err diff --git a/internal/api/routes_scene.go b/internal/api/routes_scene.go index e0584d6888b..3bbc5b258ce 100644 --- a/internal/api/routes_scene.go +++ b/internal/api/routes_scene.go @@ -16,7 +16,6 @@ import ( "github.com/stashapp/stash/pkg/fsutil" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) @@ -43,7 +42,7 @@ type CaptionFinder interface { } type sceneRoutes struct { - txnManager txn.Manager + routes sceneFinder SceneFinder fileGetter models.FileGetter captionFinder CaptionFinder @@ -51,6 +50,17 @@ type sceneRoutes struct { tagFinder SceneMarkerTagFinder } +func getSceneRoutes(repo models.Repository) chi.Router { + return sceneRoutes{ + routes: routes{txnManager: repo.TxnManager}, + sceneFinder: repo.Scene, + fileGetter: repo.File, + captionFinder: repo.File, + sceneMarkerFinder: repo.SceneMarker, + tagFinder: repo.Tag, + }.Routes() +} + func (rs sceneRoutes) Routes() chi.Router { r := chi.NewRouter() @@ -89,8 +99,6 @@ func (rs sceneRoutes) Routes() chi.Router { return r } -// region Handlers - func (rs sceneRoutes) StreamDirect(w http.ResponseWriter, r *http.Request) { scene := r.Context().Value(sceneKey).(*models.Scene) ss := manager.SceneServer{ @@ -270,13 +278,13 @@ func (rs sceneRoutes) Webp(w http.ResponseWriter, r *http.Request) { utils.ServeStaticFile(w, r, filepath) } -func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.SceneMarker) (*string, error) { +func (rs sceneRoutes) getChapterVttTitle(r *http.Request, marker *models.SceneMarker) (*string, error) { if marker.Title != "" { return &marker.Title, nil } var title string - if err := txn.WithReadTxn(ctx, rs.txnManager, func(ctx context.Context) error { + if err := rs.withReadTxn(r, func(ctx context.Context) error { qb := rs.tagFinder primaryTag, err := qb.Find(ctx, marker.PrimaryTagID) if err != nil { @@ -305,7 +313,7 @@ func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.Sce func (rs sceneRoutes) VttChapter(w http.ResponseWriter, r *http.Request) { scene := r.Context().Value(sceneKey).(*models.Scene) var sceneMarkers []*models.SceneMarker - readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error { var err error sceneMarkers, err = rs.sceneMarkerFinder.FindBySceneID(ctx, scene.ID) return err @@ -325,7 +333,7 @@ func (rs sceneRoutes) VttChapter(w http.ResponseWriter, r *http.Request) { time := utils.GetVTTTime(marker.Seconds) vttLines = append(vttLines, time+" --> "+time) - vttTitle, err := rs.getChapterVttTitle(r.Context(), marker) + vttTitle, err := rs.getChapterVttTitle(r, marker) if errors.Is(err, context.Canceled) { return } @@ -404,7 +412,7 @@ func (rs sceneRoutes) Caption(w http.ResponseWriter, r *http.Request, lang strin s := r.Context().Value(sceneKey).(*models.Scene) var captions []*models.VideoCaption - readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error { var err error primaryFile := s.Files.Primary() if primaryFile == nil { @@ -466,7 +474,7 @@ func (rs sceneRoutes) SceneMarkerStream(w http.ResponseWriter, r *http.Request) sceneHash := scene.GetHash(config.GetInstance().GetVideoFileNamingAlgorithm()) sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId")) var sceneMarker *models.SceneMarker - readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error { var err error sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID) return err @@ -494,7 +502,7 @@ func (rs sceneRoutes) SceneMarkerPreview(w http.ResponseWriter, r *http.Request) sceneHash := scene.GetHash(config.GetInstance().GetVideoFileNamingAlgorithm()) sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId")) var sceneMarker *models.SceneMarker - readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error { var err error sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID) return err @@ -530,7 +538,7 @@ func (rs sceneRoutes) SceneMarkerScreenshot(w http.ResponseWriter, r *http.Reque sceneHash := scene.GetHash(config.GetInstance().GetVideoFileNamingAlgorithm()) sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId")) var sceneMarker *models.SceneMarker - readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error { var err error sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID) return err @@ -561,8 +569,6 @@ func (rs sceneRoutes) SceneMarkerScreenshot(w http.ResponseWriter, r *http.Reque } } -// endregion - func (rs sceneRoutes) SceneCtx(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sceneID, err := strconv.Atoi(chi.URLParam(r, "sceneId")) @@ -572,7 +578,7 @@ func (rs sceneRoutes) SceneCtx(next http.Handler) http.Handler { } var scene *models.Scene - _ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + _ = rs.withReadTxn(r, func(ctx context.Context) error { qb := rs.sceneFinder scene, _ = qb.Find(ctx, sceneID) diff --git a/internal/api/routes_studio.go b/internal/api/routes_studio.go index 1cce3938532..6e4278d0b58 100644 --- a/internal/api/routes_studio.go +++ b/internal/api/routes_studio.go @@ -11,7 +11,6 @@ import ( "github.com/stashapp/stash/internal/static" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) @@ -21,10 +20,17 @@ type StudioFinder interface { } type studioRoutes struct { - txnManager txn.Manager + routes studioFinder StudioFinder } +func getStudioRoutes(repo models.Repository) chi.Router { + return studioRoutes{ + routes: routes{txnManager: repo.TxnManager}, + studioFinder: repo.Studio, + }.Routes() +} + func (rs studioRoutes) Routes() chi.Router { r := chi.NewRouter() @@ -42,7 +48,7 @@ func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) { var image []byte if defaultParam != "true" { - readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error { var err error image, err = rs.studioFinder.GetImage(ctx, studio.ID) return err @@ -78,7 +84,7 @@ func (rs studioRoutes) StudioCtx(next http.Handler) http.Handler { } var studio *models.Studio - _ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + _ = rs.withReadTxn(r, func(ctx context.Context) error { var err error studio, err = rs.studioFinder.Find(ctx, studioID) return err diff --git a/internal/api/routes_tag.go b/internal/api/routes_tag.go index 9ccf11a11c9..ceb359cc55f 100644 --- a/internal/api/routes_tag.go +++ b/internal/api/routes_tag.go @@ -11,7 +11,6 @@ import ( "github.com/stashapp/stash/internal/static" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) @@ -21,8 +20,15 @@ type TagFinder interface { } type tagRoutes struct { - txnManager txn.Manager - tagFinder TagFinder + routes + tagFinder TagFinder +} + +func getTagRoutes(repo models.Repository) chi.Router { + return tagRoutes{ + routes: routes{txnManager: repo.TxnManager}, + tagFinder: repo.Tag, + }.Routes() } func (rs tagRoutes) Routes() chi.Router { @@ -42,7 +48,7 @@ func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) { var image []byte if defaultParam != "true" { - readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error { var err error image, err = rs.tagFinder.GetImage(ctx, tag.ID) return err @@ -78,7 +84,7 @@ func (rs tagRoutes) TagCtx(next http.Handler) http.Handler { } var tag *models.Tag - _ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + _ = rs.withReadTxn(r, func(ctx context.Context) error { var err error tag, err = rs.tagFinder.Find(ctx, tagID) return err diff --git a/internal/api/server.go b/internal/api/server.go index 15f72e416e7..ba4607fb2f1 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -142,36 +142,13 @@ func Start() error { gqlPlayground.Handler("GraphQL playground", endpoint)(w, r) }) - r.Mount("/performer", performerRoutes{ - txnManager: repo.TxnManager, - performerFinder: repo.Performer, - }.Routes()) - r.Mount("/scene", sceneRoutes{ - txnManager: repo.TxnManager, - sceneFinder: repo.Scene, - fileGetter: repo.File, - captionFinder: repo.File, - sceneMarkerFinder: repo.SceneMarker, - tagFinder: repo.Tag, - }.Routes()) - r.Mount("/image", imageRoutes{ - txnManager: repo.TxnManager, - imageFinder: repo.Image, - fileGetter: repo.File, - }.Routes()) - r.Mount("/studio", studioRoutes{ - txnManager: repo.TxnManager, - studioFinder: repo.Studio, - }.Routes()) - r.Mount("/movie", movieRoutes{ - txnManager: repo.TxnManager, - movieFinder: repo.Movie, - }.Routes()) - r.Mount("/tag", tagRoutes{ - txnManager: repo.TxnManager, - tagFinder: repo.Tag, - }.Routes()) - r.Mount("/downloads", downloadsRoutes{}.Routes()) + r.Mount("/performer", getPerformerRoutes(repo)) + r.Mount("/scene", getSceneRoutes(repo)) + r.Mount("/image", getImageRoutes(repo)) + r.Mount("/studio", getStudioRoutes(repo)) + r.Mount("/movie", getMovieRoutes(repo)) + r.Mount("/tag", getTagRoutes(repo)) + r.Mount("/downloads", getDownloadsRoutes()) r.HandleFunc("/css", cssHandler(c, pluginCache)) r.HandleFunc("/javascript", javascriptHandler(c, pluginCache)) @@ -191,9 +168,7 @@ func Start() error { // Serve static folders customServedFolders := c.GetCustomServedFolders() if customServedFolders != nil { - r.Mount("/custom", customRoutes{ - servedFolders: customServedFolders, - }.Routes()) + r.Mount("/custom", getCustomRoutes(customServedFolders)) } customUILocation := c.GetCustomUILocation() From 29414487cc7eed0dd3bfa9462c29bdb98f90f9fb Mon Sep 17 00:00:00 2001 From: DingDongSoLong4 <99329275+DingDongSoLong4@users.noreply.github.com> Date: Mon, 17 Jul 2023 00:34:41 +0200 Subject: [PATCH 6/9] Move default movie image to internal/static and add convenience methods --- internal/api/images.go | 4 +- internal/api/resolver_mutation_movie.go | 13 ++--- internal/api/routes_image.go | 9 +--- internal/api/routes_movie.go | 8 ++- internal/api/routes_studio.go | 11 +--- internal/api/routes_tag.go | 11 +--- internal/manager/running_streams.go | 8 +-- internal/static/embed.go | 67 +++++++++++++++++++----- internal/static/movie/movie.png | Bin 0 -> 405 bytes pkg/models/model_movie.go | 2 - 10 files changed, 76 insertions(+), 57 deletions(-) create mode 100644 internal/static/movie/movie.png diff --git a/internal/api/images.go b/internal/api/images.go index 95ed4c8447f..bd554d77a38 100644 --- a/internal/api/images.go +++ b/internal/api/images.go @@ -61,11 +61,11 @@ var performerBoxCustom *imageBox func initialiseImages() { var err error - performerBox, err = newImageBox(&static.Performer) + performerBox, err = newImageBox(static.Sub(static.Performer)) if err != nil { logger.Warnf("error loading performer images: %v", err) } - performerBoxMale, err = newImageBox(&static.PerformerMale) + performerBoxMale, err = newImageBox(static.Sub(static.PerformerMale)) if err != nil { logger.Warnf("error loading male performer images: %v", err) } diff --git a/internal/api/resolver_mutation_movie.go b/internal/api/resolver_mutation_movie.go index ef2d2405afe..71363a327d9 100644 --- a/internal/api/resolver_mutation_movie.go +++ b/internal/api/resolver_mutation_movie.go @@ -5,6 +5,7 @@ import ( "fmt" "strconv" + "github.com/stashapp/stash/internal/static" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/sliceutil/stringslice" @@ -50,12 +51,6 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input MovieCreateInp return nil, fmt.Errorf("converting studio id: %w", err) } - // HACK: if back image is being set, set the front image to the default. - // This is because we can't have a null front image with a non-null back image. - if input.FrontImage == nil && input.BackImage != nil { - input.FrontImage = &models.DefaultMovieImage - } - // Process the base 64 encoded image string var frontimageData []byte if input.FrontImage != nil { @@ -74,6 +69,12 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input MovieCreateInp } } + // HACK: if back image is being set, set the front image to the default. + // This is because we can't have a null front image with a non-null back image. + if len(frontimageData) == 0 && len(backimageData) != 0 { + frontimageData = static.ReadAll(static.DefaultMovieImage) + } + // Start the transaction and save the movie if err := r.withTxn(ctx, func(ctx context.Context) error { qb := r.repository.Movie diff --git a/internal/api/routes_image.go b/internal/api/routes_image.go index 14e0cbf92e8..9c58bc72d8f 100644 --- a/internal/api/routes_image.go +++ b/internal/api/routes_image.go @@ -3,7 +3,6 @@ package api import ( "context" "errors" - "io" "io/fs" "net/http" "os/exec" @@ -124,8 +123,6 @@ func (rs imageRoutes) Image(w http.ResponseWriter, r *http.Request) { } func (rs imageRoutes) serveImage(w http.ResponseWriter, r *http.Request, i *models.Image, useDefault bool) { - const defaultImageImage = "image/image.svg" - if i.Files.Primary() != nil { err := i.Files.Primary().Base().Serve(&file.OsFS{}, w, r) if err == nil { @@ -146,10 +143,8 @@ func (rs imageRoutes) serveImage(w http.ResponseWriter, r *http.Request, i *mode return } - // fall back to static image - f, _ := static.Image.Open(defaultImageImage) - defer f.Close() - image, _ := io.ReadAll(f) + // fallback to default image + image := static.ReadAll(static.DefaultImageImage) utils.ServeImage(w, r, image) } diff --git a/internal/api/routes_movie.go b/internal/api/routes_movie.go index 740d8886761..7487755cfaa 100644 --- a/internal/api/routes_movie.go +++ b/internal/api/routes_movie.go @@ -7,6 +7,8 @@ import ( "strconv" "github.com/go-chi/chi" + + "github.com/stashapp/stash/internal/static" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/utils" @@ -60,8 +62,9 @@ func (rs movieRoutes) FrontImage(w http.ResponseWriter, r *http.Request) { } } + // fallback to default image if len(image) == 0 { - image, _ = utils.ProcessBase64Image(models.DefaultMovieImage) + image = static.ReadAll(static.DefaultMovieImage) } utils.ServeImage(w, r, image) @@ -85,8 +88,9 @@ func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) { } } + // fallback to default image if len(image) == 0 { - image, _ = utils.ProcessBase64Image(models.DefaultMovieImage) + image = static.ReadAll(static.DefaultMovieImage) } utils.ServeImage(w, r, image) diff --git a/internal/api/routes_studio.go b/internal/api/routes_studio.go index 6e4278d0b58..d61a1e7548d 100644 --- a/internal/api/routes_studio.go +++ b/internal/api/routes_studio.go @@ -3,7 +3,6 @@ package api import ( "context" "errors" - "io" "net/http" "strconv" @@ -61,15 +60,9 @@ func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) { } } + // fallback to default image if len(image) == 0 { - const defaultStudioImage = "studio/studio.svg" - - // fall back to static image - f, _ := static.Studio.Open(defaultStudioImage) - defer f.Close() - stat, _ := f.Stat() - http.ServeContent(w, r, "studio.svg", stat.ModTime(), f.(io.ReadSeeker)) - return + image = static.ReadAll(static.DefaultStudioImage) } utils.ServeImage(w, r, image) diff --git a/internal/api/routes_tag.go b/internal/api/routes_tag.go index ceb359cc55f..5c05ac07fb1 100644 --- a/internal/api/routes_tag.go +++ b/internal/api/routes_tag.go @@ -3,7 +3,6 @@ package api import ( "context" "errors" - "io" "net/http" "strconv" @@ -61,15 +60,9 @@ func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) { } } + // fallback to default image if len(image) == 0 { - const defaultTagImage = "tag/tag.svg" - - // fall back to static image - f, _ := static.Tag.Open(defaultTagImage) - defer f.Close() - stat, _ := f.Stat() - http.ServeContent(w, r, "tag.svg", stat.ModTime(), f.(io.ReadSeeker)) - return + image = static.ReadAll(static.DefaultTagImage) } utils.ServeImage(w, r, image) diff --git a/internal/manager/running_streams.go b/internal/manager/running_streams.go index 788b30f1c28..2d255b7308c 100644 --- a/internal/manager/running_streams.go +++ b/internal/manager/running_streams.go @@ -3,7 +3,6 @@ package manager import ( "context" "errors" - "io" "net/http" "github.com/stashapp/stash/internal/manager/config" @@ -58,8 +57,6 @@ func (s *SceneServer) StreamSceneDirect(scene *models.Scene, w http.ResponseWrit } func (s *SceneServer) ServeScreenshot(scene *models.Scene, w http.ResponseWriter, r *http.Request) { - const defaultSceneImage = "scene/scene.svg" - var cover []byte readTxnErr := txn.WithReadTxn(r.Context(), s.TxnManager, func(ctx context.Context) error { cover, _ = s.SceneCoverGetter.GetCover(ctx, scene.ID) @@ -92,10 +89,7 @@ func (s *SceneServer) ServeScreenshot(scene *models.Scene, w http.ResponseWriter } // fallback to default cover if none found - // should always be there - f, _ := static.Scene.Open(defaultSceneImage) - defer f.Close() - cover, _ = io.ReadAll(f) + cover = static.ReadAll(static.DefaultSceneImage) } utils.ServeImage(w, r, cover) diff --git a/internal/static/embed.go b/internal/static/embed.go index 9be76afa4a1..d82c0b66bd7 100644 --- a/internal/static/embed.go +++ b/internal/static/embed.go @@ -1,21 +1,62 @@ package static -import "embed" +import ( + "embed" + "fmt" + "io" + "io/fs" +) -//go:embed performer -var Performer embed.FS +//go:embed performer performer_male scene image tag studio movie +var data embed.FS -//go:embed performer_male -var PerformerMale embed.FS +const ( + Performer = "performer" + PerformerMale = "performer_male" -//go:embed scene -var Scene embed.FS + Scene = "scene" + DefaultSceneImage = "scene/scene.svg" -//go:embed image -var Image embed.FS + Image = "image" + DefaultImageImage = "image/image.svg" -//go:embed tag -var Tag embed.FS + Tag = "tag" + DefaultTagImage = "tag/tag.svg" -//go:embed studio -var Studio embed.FS + Studio = "studio" + DefaultStudioImage = "studio/studio.svg" + + Movie = "movie" + DefaultMovieImage = "movie/movie.png" +) + +// Sub returns an FS rooted at path, using fs.Sub. +// It will panic if an error occurs. +func Sub(path string) fs.FS { + ret, err := fs.Sub(data, path) + if err != nil { + panic(fmt.Sprintf("creating static SubFS: %v", err)) + } + return ret +} + +// Open opens the file at path for reading. +// It will panic if an error occurs. +func Open(path string) fs.File { + f, err := data.Open(path) + if err != nil { + panic(fmt.Sprintf("opening static file: %v", err)) + } + return f +} + +// ReadAll returns the contents of the file at path. +// It will panic if an error occurs. +func ReadAll(path string) []byte { + f := Open(path) + ret, err := io.ReadAll(f) + if err != nil { + panic(fmt.Sprintf("reading static file: %v", err)) + } + return ret +} diff --git a/internal/static/movie/movie.png b/internal/static/movie/movie.png new file mode 100644 index 0000000000000000000000000000000000000000..0bb8b00a6cfa5999b5d1218a348a42d31754c3c0 GIT binary patch literal 405 zcmeAS@N?(olHy`uVBq!ia0vp^DImWxe4W7D@5exIwKcl6`5 pnt7hzAcnHAbige@WAA&#_T0*4(eYz@nScSv;OXk;vd$@?2>^{`g)#sD literal 0 HcmV?d00001 diff --git a/pkg/models/model_movie.go b/pkg/models/model_movie.go index 152f0d3bbb5..5880ff2d137 100644 --- a/pkg/models/model_movie.go +++ b/pkg/models/model_movie.go @@ -49,5 +49,3 @@ func NewMoviePartial() MoviePartial { UpdatedAt: NewOptionalTime(currentTime), } } - -var DefaultMovieImage = "" From e261ecb6298245adc153a759961b500098f217a7 Mon Sep 17 00:00:00 2001 From: DingDongSoLong4 <99329275+DingDongSoLong4@users.noreply.github.com> Date: Sun, 24 Sep 2023 18:10:02 +0200 Subject: [PATCH 7/9] Refactor default performer image boxes --- internal/api/images.go | 77 ++++++++++++--------- internal/api/resolver_mutation_configure.go | 2 +- internal/api/routes_performer.go | 3 +- internal/api/server.go | 5 +- 4 files changed, 49 insertions(+), 38 deletions(-) diff --git a/internal/api/images.go b/internal/api/images.go index bd554d77a38..89a8e87b0c9 100644 --- a/internal/api/images.go +++ b/internal/api/images.go @@ -1,12 +1,13 @@ package api import ( + "errors" + "fmt" "io" "io/fs" "os" "strings" - "github.com/stashapp/stash/internal/manager/config" "github.com/stashapp/stash/internal/static" "github.com/stashapp/stash/pkg/hash" "github.com/stashapp/stash/pkg/logger" @@ -18,7 +19,7 @@ type imageBox struct { files []string } -var imageExtensions = []string{ +var imageBoxExts = []string{ ".jpg", ".jpeg", ".png", @@ -42,7 +43,7 @@ func newImageBox(box fs.FS) (*imageBox, error) { } baseName := strings.ToLower(d.Name()) - for _, ext := range imageExtensions { + for _, ext := range imageBoxExts { if strings.HasSuffix(baseName, ext) { ret.files = append(ret.files, path) break @@ -55,44 +56,59 @@ func newImageBox(box fs.FS) (*imageBox, error) { return ret, err } +func (box *imageBox) GetRandomImageByName(name string) ([]byte, error) { + files := box.files + if len(files) == 0 { + return nil, errors.New("box is empty") + } + + index := hash.IntFromString(name) % uint64(len(files)) + img, err := box.box.Open(files[index]) + if err != nil { + return nil, err + } + defer img.Close() + + return io.ReadAll(img) +} + var performerBox *imageBox var performerBoxMale *imageBox var performerBoxCustom *imageBox -func initialiseImages() { +func init() { var err error performerBox, err = newImageBox(static.Sub(static.Performer)) if err != nil { - logger.Warnf("error loading performer images: %v", err) + panic(fmt.Sprintf("loading performer images: %v", err)) } performerBoxMale, err = newImageBox(static.Sub(static.PerformerMale)) if err != nil { - logger.Warnf("error loading male performer images: %v", err) + panic(fmt.Sprintf("loading male performer images: %v", err)) } - initialiseCustomImages() } -func initialiseCustomImages() { - customPath := config.GetInstance().GetCustomPerformerImageLocation() +func initCustomPerformerImages(customPath string) { if customPath != "" { logger.Debugf("Loading custom performer images from %s", customPath) - // We need to set performerBoxCustom at runtime, as this is a custom path, and store it in a pointer. var err error performerBoxCustom, err = newImageBox(os.DirFS(customPath)) if err != nil { - logger.Warnf("error loading custom performer from %s: %v", customPath, err) + logger.Warnf("error loading custom performer images from %s: %v", customPath, err) } } else { performerBoxCustom = nil } } -func getRandomPerformerImageUsingName(name string, gender *models.GenderEnum, customPath string) ([]byte, error) { - var box *imageBox - - // If we have a custom path, we should return a new box in the given path. - if performerBoxCustom != nil && len(performerBoxCustom.files) > 0 { - box = performerBoxCustom +func getDefaultPerformerImage(name string, gender *models.GenderEnum) []byte { + // try the custom box first if we have one + if performerBoxCustom != nil { + ret, err := performerBoxCustom.GetRandomImageByName(name) + if err == nil { + return ret + } + logger.Warnf("error loading custom default performer image: %v", err) } var g models.GenderEnum @@ -100,24 +116,19 @@ func getRandomPerformerImageUsingName(name string, gender *models.GenderEnum, cu g = *gender } - if box == nil { - switch g { - case models.GenderEnumFemale, models.GenderEnumTransgenderFemale: - box = performerBox - case models.GenderEnumMale, models.GenderEnumTransgenderMale: - box = performerBoxMale - default: - box = performerBox - } + var box *imageBox + switch g { + case models.GenderEnumFemale, models.GenderEnumTransgenderFemale: + box = performerBox + case models.GenderEnumMale, models.GenderEnumTransgenderMale: + box = performerBoxMale + default: + box = performerBox } - imageFiles := box.files - index := hash.IntFromString(name) % uint64(len(imageFiles)) - img, err := box.box.Open(imageFiles[index]) + ret, err := box.GetRandomImageByName(name) if err != nil { - return nil, err + logger.Warnf("error loading default performer image: %v", err) } - defer img.Close() - - return io.ReadAll(img) + return ret } diff --git a/internal/api/resolver_mutation_configure.go b/internal/api/resolver_mutation_configure.go index f12b3aa0cec..e3537731a7e 100644 --- a/internal/api/resolver_mutation_configure.go +++ b/internal/api/resolver_mutation_configure.go @@ -316,7 +316,7 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input ConfigGen if input.CustomPerformerImageLocation != nil { c.Set(config.CustomPerformerImageLocation, *input.CustomPerformerImageLocation) - initialiseCustomImages() + initCustomPerformerImages(*input.CustomPerformerImageLocation) } if input.ScraperUserAgent != nil { diff --git a/internal/api/routes_performer.go b/internal/api/routes_performer.go index 0bc0cf2f4bd..8ac4ee34901 100644 --- a/internal/api/routes_performer.go +++ b/internal/api/routes_performer.go @@ -7,7 +7,6 @@ import ( "strconv" "github.com/go-chi/chi" - "github.com/stashapp/stash/internal/manager/config" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/utils" @@ -61,7 +60,7 @@ func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) { } if len(image) == 0 { - image, _ = getRandomPerformerImageUsingName(performer.Name, performer.Gender, config.GetInstance().GetCustomPerformerImageLocation()) + image = getDefaultPerformerImage(performer.Name, performer.Gender) } utils.ServeImage(w, r, image) diff --git a/internal/api/server.go b/internal/api/server.go index ba4607fb2f1..6c8ea857114 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -50,7 +50,9 @@ var uiBox = ui.UIBox var loginUIBox = ui.LoginUIBox func Start() error { - initialiseImages() + c := config.GetInstance() + + initCustomPerformerImages(c.GetCustomPerformerImageLocation()) r := chi.NewRouter() @@ -62,7 +64,6 @@ func Start() error { r.Use(middleware.Recoverer) - c := config.GetInstance() if c.GetLogAccess() { httpLogger := httplog.NewLogger("Stash", httplog.Options{ Concise: true, From 5293ffe665698e0497dea0f86352884cae9d06a9 Mon Sep 17 00:00:00 2001 From: WithoutPants <53250216+WithoutPants@users.noreply.github.com> Date: Mon, 16 Oct 2023 13:53:46 +1100 Subject: [PATCH 8/9] Lint --- pkg/scraper/postprocessing.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/scraper/postprocessing.go b/pkg/scraper/postprocessing.go index 3077d915dbc..6c985262a32 100644 --- a/pkg/scraper/postprocessing.go +++ b/pkg/scraper/postprocessing.go @@ -177,7 +177,7 @@ func (c Cache) postScrapeGallery(ctx context.Context, g ScrapedGallery) (Scraped g.URLs = []string{*g.URL} } - r := c.repository + r := c.repository if err := r.WithReadTxn(ctx, func(ctx context.Context) error { pqb := r.PerformerFinder tqb := r.TagFinder From e0255fbcb6c34217611b9d5719c1e84d405ac571 Mon Sep 17 00:00:00 2001 From: WithoutPants <53250216+WithoutPants@users.noreply.github.com> Date: Mon, 16 Oct 2023 14:03:08 +1100 Subject: [PATCH 9/9] Lint --- pkg/scraper/postprocessing.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/scraper/postprocessing.go b/pkg/scraper/postprocessing.go index 6c985262a32..0cf9b5a17fb 100644 --- a/pkg/scraper/postprocessing.go +++ b/pkg/scraper/postprocessing.go @@ -176,7 +176,7 @@ func (c Cache) postScrapeGallery(ctx context.Context, g ScrapedGallery) (Scraped if g.URL != nil && len(g.URLs) == 0 { g.URLs = []string{*g.URL} } - + r := c.repository if err := r.WithReadTxn(ctx, func(ctx context.Context) error { pqb := r.PerformerFinder