diff --git a/iter.go b/iter.go index 1d46899d..1b2853cf 100644 --- a/iter.go +++ b/iter.go @@ -18,7 +18,7 @@ import ( "go.aporeto.io/elemental" ) -const iterDefaultBlockSize = 10000 +const iterDefaultBlockSize = 1000 // IterFunc calls RetrieveMany on the given Manipulator, and will retrieve the data by block // of the given blockSize. @@ -29,7 +29,9 @@ const iterDefaultBlockSize = 10000 // current data block. If the function returns an error, the error is returned to the caller // of IterFunc and the iteration stops. // -// The given context will be used if the underlying manipulator honors it. +// The given context will be used if the underlying manipulator honors it. Be careful to NOT pass +// a filter matching objects then updating the objects to not match anynmore. This would shift +// pagination and will produce unexpected results. To do so, prefer using manipulate.IterUntilFunc // // The given manipulate.Context will be used to retry any failed batch recovery. // @@ -37,7 +39,7 @@ const iterDefaultBlockSize = 10000 // hold the data block. It is reset at every iteration. Do not rely on it to be filled // once IterFunc is complete. // -// Finally, if the given blockSize is <= 0, then it will use the default that is 10000. +// Finally, if the given blockSize is <= 0, then it will use the default that is 1000. func IterFunc( ctx context.Context, manipulator Manipulator, @@ -46,46 +48,23 @@ func IterFunc( iteratorFunc func(block elemental.Identifiables) error, blockSize int, ) error { + return doIterFunc(ctx, manipulator, identifiablesTemplate, mctx, iteratorFunc, blockSize, false) +} - if manipulator == nil { - panic("manipulator must not be nil") - } - - if iteratorFunc == nil { - panic("iteratorFunc must not be nil") - } - - if identifiablesTemplate == nil { - panic("identifiablesTemplate must not be nil") - } - - if mctx == nil { - mctx = NewContext(ctx) - } - - if blockSize <= 0 { - blockSize = iterDefaultBlockSize - } - - var page int - - for { - page++ - - objects := identifiablesTemplate.Copy() - - if err := manipulator.RetrieveMany(mctx.Derive(ContextOptionPage(page, blockSize)), objects); err != nil { - return fmt.Errorf("unable to retrieve objects for page %d: %s", page, err.Error()) - } - - if len(objects.List()) == 0 { - return nil - } - - if err := iteratorFunc(objects); err != nil { - return fmt.Errorf("iter function returned an error on page %d: %s", page, err) - } - } +// IterUntilFunc works as IterFunc but pagination will not increase. +// It will always retrieve the first page with a size of given blockSize. +// +// The goal of this function is to be used with a filter, then update (or delete) the +// objects that match until no more are matching. +func IterUntilFunc( + ctx context.Context, + manipulator Manipulator, + identifiablesTemplate elemental.Identifiables, + mctx Context, + iteratorFunc func(block elemental.Identifiables) error, + blockSize int, +) error { + return doIterFunc(ctx, manipulator, identifiablesTemplate, mctx, iteratorFunc, blockSize, true) } // Iter is a helper function for IterFunc. @@ -125,3 +104,61 @@ func Iter( return identifiablesTemplate, nil } + +func doIterFunc( + ctx context.Context, + manipulator Manipulator, + identifiablesTemplate elemental.Identifiables, + mctx Context, + iteratorFunc func(block elemental.Identifiables) error, + blockSize int, + disablePageIncrease bool, +) error { + + if manipulator == nil { + panic("manipulator must not be nil") + } + + if iteratorFunc == nil { + panic("iteratorFunc must not be nil") + } + + if identifiablesTemplate == nil { + panic("identifiablesTemplate must not be nil") + } + + if mctx == nil { + mctx = NewContext(ctx) + } + + if blockSize <= 0 { + blockSize = iterDefaultBlockSize + } + + var page int + + if disablePageIncrease { + page = 1 + } + + for { + + if !disablePageIncrease { + page++ + } + + objects := identifiablesTemplate.Copy() + + if err := manipulator.RetrieveMany(mctx.Derive(ContextOptionPage(page, blockSize)), objects); err != nil { + return fmt.Errorf("unable to retrieve objects for page %d: %s", page, err.Error()) + } + + if len(objects.List()) == 0 { + return nil + } + + if err := iteratorFunc(objects); err != nil { + return fmt.Errorf("iter function returned an error on page %d: %s", page, err) + } + } +} diff --git a/iter_test.go b/iter_test.go index 9a95baea..5a408e5e 100644 --- a/iter_test.go +++ b/iter_test.go @@ -35,8 +35,10 @@ func makeData(size int) testmodel.ListsList { // A testManipulator is an empty TransactionalManipulator that can be easily mocked. type testManipulator struct { - data testmodel.ListsList - err error + data testmodel.ListsList + err error + stopAtIteration int + iteration int } func (m *testManipulator) RetrieveMany(mctx Context, dest elemental.Identifiables) error { @@ -45,6 +47,13 @@ func (m *testManipulator) RetrieveMany(mctx Context, dest elemental.Identifiable return m.err } + if m.stopAtIteration != 0 { + m.iteration++ + if m.iteration == m.stopAtIteration+1 { + return nil + } + } + start := (mctx.Page() - 1) * mctx.PageSize() end := start + mctx.PageSize() @@ -55,6 +64,7 @@ func (m *testManipulator) RetrieveMany(mctx Context, dest elemental.Identifiable if end > len(m.data) { end = len(m.data) } + *dest.(*testmodel.ListsList) = append(*dest.(*testmodel.ListsList), m.data[start:end]...) return nil @@ -85,14 +95,14 @@ func (m *testManipulator) Count(mctx Context, identity elemental.Identity) (int, return 0, nil } -func TestIterFunc(t *testing.T) { +func TestDoIterFunc(t *testing.T) { - Convey("Given I call IterFunc with no manipulator", t, func() { + Convey("Given I call doIterFunc with no manipulator", t, func() { Convey("Then it should panic", func() { So( func() { - _ = IterFunc(nil, nil, nil, nil, nil, 0) // nolint + _ = doIterFunc(nil, nil, nil, nil, nil, 0, false) // nolint }, ShouldPanicWith, "manipulator must not be nil", @@ -100,12 +110,12 @@ func TestIterFunc(t *testing.T) { }) }) - Convey("Given I call IterFunc with no iterator", t, func() { + Convey("Given I call doIterFunc with no iterator", t, func() { Convey("Then it should panic", func() { So( func() { - _ = IterFunc(nil, &testManipulator{}, nil, nil, nil, 0) // nolint + _ = doIterFunc(nil, &testManipulator{}, nil, nil, nil, 0, false) // nolint }, ShouldPanicWith, "iteratorFunc must not be nil", @@ -113,12 +123,12 @@ func TestIterFunc(t *testing.T) { }) }) - Convey("Given I call IterFunc with no identifiablesTemplate", t, func() { + Convey("Given I call doIterFunc with no identifiablesTemplate", t, func() { Convey("Then it should panic", func() { So( func() { - _ = IterFunc(nil, &testManipulator{}, nil, nil, func(elemental.Identifiables) error { return nil }, 0) // nolint + _ = doIterFunc(nil, &testManipulator{}, nil, nil, func(elemental.Identifiables) error { return nil }, 0, false) // nolint }, ShouldPanicWith, "identifiablesTemplate must not be nil", @@ -136,19 +146,20 @@ func TestIterFunc(t *testing.T) { return nil } - Convey("When I call IterFunc on a round page", func() { + Convey("When I call doIterFunc on a round page", func() { m := &testManipulator{ data: makeData(40), } - err := IterFunc( + err := doIterFunc( context.Background(), m, testmodel.ListsList{}, nil, iter, 10, + false, ) Convey("Then err should be nil", func() { @@ -164,19 +175,20 @@ func TestIterFunc(t *testing.T) { }) }) - Convey("When I call IterFunc on a non round page", func() { + Convey("When I call doIterFunc on a non round page", func() { m := &testManipulator{ data: makeData(45), } - err := IterFunc( + err := doIterFunc( context.Background(), m, testmodel.ListsList{}, nil, iter, 11, + false, ) Convey("Then err should be nil", func() { @@ -192,19 +204,20 @@ func TestIterFunc(t *testing.T) { }) }) - Convey("When I call IterFunc with the default block size", func() { + Convey("When I call doIterFunc with the default block size", func() { m := &testManipulator{ data: makeData(45), } - err := IterFunc( + err := doIterFunc( context.Background(), m, testmodel.ListsList{}, nil, iter, 0, + false, ) Convey("Then err should be nil", func() { @@ -220,19 +233,20 @@ func TestIterFunc(t *testing.T) { }) }) - Convey("When I call IterFunc but there are no objects", func() { + Convey("When I call doIterFunc but there are no objects", func() { m := &testManipulator{ data: makeData(0), } - err := IterFunc( + err := doIterFunc( context.Background(), m, testmodel.ListsList{}, nil, iter, 0, + false, ) Convey("Then err should be nil", func() { @@ -248,20 +262,21 @@ func TestIterFunc(t *testing.T) { }) }) - Convey("When I call IterFunc but manipulate returns an error", func() { + Convey("When I call doIterFunc but manipulate returns an error", func() { m := &testManipulator{ data: makeData(45), err: fmt.Errorf("boom"), } - err := IterFunc( + err := doIterFunc( context.Background(), m, testmodel.ListsList{}, nil, iter, 0, + false, ) Convey("Then err should be nil", func() { @@ -278,7 +293,7 @@ func TestIterFunc(t *testing.T) { }) }) - Convey("When I call IterFunc but iter returns an error", func() { + Convey("When I call doIterFunc but iter returns an error", func() { m := &testManipulator{ data: makeData(45), @@ -288,13 +303,14 @@ func TestIterFunc(t *testing.T) { return fmt.Errorf("paf") } - err := IterFunc( + err := doIterFunc( context.Background(), m, testmodel.ListsList{}, nil, iter, 0, + false, ) Convey("Then err should be nil", func() { @@ -396,3 +412,102 @@ func TestIter(t *testing.T) { }) }) } + +func TestIterUntilFunc(t *testing.T) { + + Convey("Given I have a manipulator and some objects in the db", t, func() { + + m := &testManipulator{ + data: makeData(45), + stopAtIteration: 3, + } + + Convey("When I call Iter", func() { + + dest := testmodel.ListsList{} + err := IterUntilFunc( + context.Background(), + m, + testmodel.ListsList{}, + nil, + func(block elemental.Identifiables) error { + dest = append(dest, *block.(*testmodel.ListsList)...) + return nil + }, + 10, + ) + + Convey("Then err should be nil", func() { + So(err, ShouldBeNil) + }) + + Convey("Then dest should be correct", func() { + So(len(dest.List()), ShouldEqual, len(m.data[:30])) + }) + }) + }) + + Convey("Given I have a manipulator and no object in the db", t, func() { + + m := &testManipulator{ + data: makeData(0), + } + + Convey("When I call Iter", func() { + + dest := testmodel.ListsList{} + err := IterUntilFunc( + context.Background(), + m, + testmodel.ListsList{}, + nil, + func(block elemental.Identifiables) error { + dest = append(dest, *block.(*testmodel.ListsList)...) + return nil + }, + 10, + ) + + Convey("Then err should be nil", func() { + So(err, ShouldBeNil) + }) + + Convey("Then dest should be correct", func() { + So(len(dest.List()), ShouldEqual, 0) + }) + }) + }) + + Convey("Given I have a manipulator but it returns an error", t, func() { + + m := &testManipulator{ + data: makeData(43), + err: fmt.Errorf("pif"), + } + + Convey("When I call Iter", func() { + + dest := testmodel.ListsList{} + err := IterUntilFunc( + context.Background(), + m, + testmodel.ListsList{}, + nil, + func(block elemental.Identifiables) error { + dest = append(dest, *block.(*testmodel.ListsList)...) + return nil + }, + 10, + ) + + Convey("Then err should not be nil", func() { + So(err, ShouldNotBeNil) + So(err.Error(), ShouldEqual, "unable to retrieve objects for page 1: pif") + }) + + Convey("Then dest should be correct", func() { + So(len(dest.List()), ShouldEqual, 0) + }) + }) + }) +}