diff --git a/fixtures.go b/fixtures.go index 6c8cefb..0cf7f3a 100644 --- a/fixtures.go +++ b/fixtures.go @@ -36,8 +36,10 @@ type BeforeSave interface { // Repsitory of fixtures that can be loaded and imported. type Repository struct { - log Logger - registry map[string]any + log Logger + registry map[string]any + order []string + skipResolve bool } // New creates a new fixtures repository. @@ -45,6 +47,7 @@ func New() *Repository { return &Repository{ log: defaultLogger{}, registry: make(map[string]any, 10), + order: make([]string, 0, 10), } } @@ -60,9 +63,15 @@ func (r *Repository) Register(v any) { } r.registry[name] = v + r.order = append(r.order, name) } // SetLogger sets the logger to be used by the repository to notify about warnings. func (r *Repository) SetLogger(l Logger) { r.log = l } + +// SetSkipResolve sets whether the repository should skip resolving relations and use registration order. +func (r *Repository) SetSkipResolve(skip bool) { + r.skipResolve = skip +} diff --git a/source.go b/source.go index a231909..67fd044 100644 --- a/source.go +++ b/source.go @@ -145,14 +145,17 @@ func (r *Repository) ImportDir(ctx context.Context, db rel.Repository, path stri } func (r *Repository) importData(ctx context.Context, db rel.Repository, data map[string][]any) error { - tables, err := r.importOrder() - if err != nil { - r.log.Warn(fmt.Sprintf("failed to get table import order: %v", err)) - - // Fallback to list of map keys - for table := range r.registry { - tables = append(tables, table) + var tables []string + if !r.skipResolve { + var err error + if tables, err = r.importOrder(); err != nil { + r.log.Warn(fmt.Sprintf("failed to get table import order: %v", err)) + + // Fallback to registration order + tables = r.order } + } else { + tables = r.order } return db.Transaction(ctx, func(ctx context.Context) error { diff --git a/source_test.go b/source_test.go index e6d16aa..2213ab8 100644 --- a/source_test.go +++ b/source_test.go @@ -119,3 +119,16 @@ func TestFixtures_ImportDir(t *testing.T) { qt.Check(t, qt.Equals(user.Address.ID, 1)) qt.Check(t, qt.Equals(user.Address.City, "New York")) } + +func TestFixtures_SkipResolve(t *testing.T) { + repo := New() + repo.SetSkipResolve(true) + repo.Register(&Address{}) + repo.Register(&User{}) + repo.Register(&Transaction{}) + + db := createTestDB(t) + + err := repo.ImportDir(context.TODO(), db, "testdata/sample/") + qt.Assert(t, qt.IsNil(err)) +}