From 2ee3f54fadf251226ff77b3a1f4cee7304f5eaad Mon Sep 17 00:00:00 2001 From: Martin Strobel Date: Sun, 19 Dec 2021 17:44:54 -0800 Subject: [PATCH] Use context instead of channels for cancellation. (#17) In v1, raw channels are used. Not really a problem, but a new idiom has become popular since this library was originally authored, and this makes the muscle memory easier. Fixes #7 --- README.md | 6 +-- dictionary.go | 11 ++--- dictionary_examples_test.go | 5 ++- dictionary_test.go | 3 +- doc.go | 2 +- fibonacci.go | 6 ++- filesystem.go | 8 ++-- filesystem_test.go | 8 ++-- linkedlist.go | 7 +-- linkedlist_examples_test.go | 3 +- linkedlist_test.go | 9 ++-- list.go | 5 ++- lru_cache.go | 35 ++++++++------- lru_example_test.go | 6 +-- query.go | 90 +++++++++++++++---------------------- query_examples_test.go | 58 ++++++++++-------------- query_test.go | 3 +- queue.go | 5 ++- stack.go | 5 ++- 19 files changed, 130 insertions(+), 145 deletions(-) diff --git a/README.md b/README.md index e6815be..2922996 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ Converting between slices and a queryable structure is as trivial as it should b original := []interface{}{"a", "b", "c"} subject := collection.AsEnumerable(original...) -for entry := range subject.Enumerate(nil) { +for entry := range subject.Enumerate(context.Background()) { fmt.Println(entry) } // Output: @@ -28,7 +28,7 @@ subject := collection.AsEnumerable(1, 2, 3, 4, 5, 6) filtered := collection.Where(subject, func(num interface{}) bool{ return num.(int) > 3 }) -for entry := range filtered.Enumerate(nil) { +for entry := range filtered.Enumerate(context.Background()) { fmt.Println(entry) } // Output: @@ -42,7 +42,7 @@ subject := collection.AsEnumerable(1, 2, 3, 4, 5, 6) updated := collection.Select(subject, func(num interface{}) interface{}{ return num.(int) + 10 }) -for entry := range updated.Enumerate(nil) { +for entry := range updated.Enumerate(context.Background()) { fmt.Println(entry) } // Output: diff --git a/dictionary.go b/dictionary.go index 4eaa6bf..5732ed3 100644 --- a/dictionary.go +++ b/dictionary.go @@ -1,6 +1,7 @@ package collection import ( + "context" "sort" ) @@ -130,14 +131,14 @@ func (dict Dictionary) Size() int64 { } // Enumerate lists each word in the Dictionary alphabetically. -func (dict Dictionary) Enumerate(cancel <-chan struct{}) Enumerator { +func (dict Dictionary) Enumerate(ctx context.Context) Enumerator { if dict.root == nil { - return Empty.Enumerate(cancel) + return Empty.Enumerate(ctx) } - return dict.root.Enumerate(cancel) + return dict.root.Enumerate(ctx) } -func (node trieNode) Enumerate(cancel <-chan struct{}) Enumerator { +func (node trieNode) Enumerate(ctx context.Context) Enumerator { var enumerateHelper func(trieNode, string) results := make(chan interface{}) @@ -146,7 +147,7 @@ func (node trieNode) Enumerate(cancel <-chan struct{}) Enumerator { if subject.IsWord { select { case results <- prefix: - case <-cancel: + case <-ctx.Done(): return } } diff --git a/dictionary_examples_test.go b/dictionary_examples_test.go index e04dfbb..99eb5c5 100644 --- a/dictionary_examples_test.go +++ b/dictionary_examples_test.go @@ -1,6 +1,7 @@ package collection_test import ( + "context" "fmt" "strings" @@ -54,11 +55,11 @@ func ExampleDictionary_Enumerate() { return strings.ToUpper(x.(string)) }) - for word := range subject.Enumerate(nil) { + for word := range subject.Enumerate(context.Background()) { fmt.Println(word) } - for word := range upperCase.Enumerate(nil) { + for word := range upperCase.Enumerate(context.Background()) { fmt.Println(word) } diff --git a/dictionary_test.go b/dictionary_test.go index 2072378..a9ff63e 100644 --- a/dictionary_test.go +++ b/dictionary_test.go @@ -1,6 +1,7 @@ package collection import ( + "context" "strings" "testing" ) @@ -38,7 +39,7 @@ func TestDictionary_Enumerate(t *testing.T) { } prev := "" - for result := range subject.Enumerate(nil) { + for result := range subject.Enumerate(context.Background()) { t.Logf(result.(string)) if alreadySeen, ok := expected[result.(string)]; !ok { t.Logf("An unadded value was returned") diff --git a/doc.go b/doc.go index 05b3bb4..9191d1a 100644 --- a/doc.go +++ b/doc.go @@ -11,7 +11,7 @@ // Location: "./", // } // -// results := myDir.Enumerate(nil).Where(func(x interface{}) bool { +// results := myDir.Enumerate(context.Background()).Where(func(x interface{}) bool { // return strings.HasSuffix(x.(string), ".go") // }) // diff --git a/fibonacci.go b/fibonacci.go index 85eaa99..ca9e6ca 100644 --- a/fibonacci.go +++ b/fibonacci.go @@ -1,11 +1,13 @@ package collection +import "context" + type fibonacciGenerator struct{} // Fibonacci is an Enumerable which will dynamically generate the fibonacci sequence. var Fibonacci Enumerable = fibonacciGenerator{} -func (gen fibonacciGenerator) Enumerate(cancel <-chan struct{}) Enumerator { +func (gen fibonacciGenerator) Enumerate(ctx context.Context) Enumerator { retval := make(chan interface{}) go func() { @@ -16,7 +18,7 @@ func (gen fibonacciGenerator) Enumerate(cancel <-chan struct{}) Enumerator { select { case retval <- a: a, b = b, a+b - case <-cancel: + case <-ctx.Done(): return } } diff --git a/filesystem.go b/filesystem.go index 520bdca..58b40dc 100644 --- a/filesystem.go +++ b/filesystem.go @@ -1,7 +1,7 @@ package collection import ( - "errors" + "context" "os" "path/filepath" ) @@ -40,7 +40,7 @@ func (d Directory) applyOptions(loc string, info os.FileInfo) bool { } // Enumerate lists the items in a `Directory` -func (d Directory) Enumerate(cancel <-chan struct{}) Enumerator { +func (d Directory) Enumerate(ctx context.Context) Enumerator { results := make(chan interface{}) go func() { @@ -64,8 +64,8 @@ func (d Directory) Enumerate(cancel <-chan struct{}) Enumerator { select { case results <- currentLocation: // Intentionally Left Blank - case <-cancel: - err = errors.New("directory enumeration cancelled") + case <-ctx.Done(): + err = ctx.Err() } } diff --git a/filesystem_test.go b/filesystem_test.go index d30ec4b..89032cb 100644 --- a/filesystem_test.go +++ b/filesystem_test.go @@ -1,6 +1,7 @@ package collection import ( + "context" "fmt" "math" "path" @@ -54,9 +55,7 @@ func ExampleDirectory_Enumerate() { Options: DirectoryOptionsExcludeDirectories, } - done := make(chan struct{}) - - filesOfInterest := traverser.Enumerate(done).Select(func(subject interface{}) (result interface{}) { + filesOfInterest := traverser.Enumerate(context.Background()).Select(func(subject interface{}) (result interface{}) { cast, ok := subject.(string) if ok { result = path.Base(cast) @@ -75,7 +74,6 @@ func ExampleDirectory_Enumerate() { for entry := range filesOfInterest { fmt.Println(entry.(string)) } - close(done) // Output: filesystem_test.go } @@ -146,7 +144,7 @@ func TestDirectory_Enumerate(t *testing.T) { for _, tc := range testCases { subject.Options = tc.options t.Run(fmt.Sprintf("%d", uint(tc.options)), func(t *testing.T) { - for entry := range subject.Enumerate(nil) { + for entry := range subject.Enumerate(context.Background()) { cast := entry.(string) if _, ok := tc.expected[cast]; !ok { t.Logf("unexpected result: %q", cast) diff --git a/linkedlist.go b/linkedlist.go index 2004042..c46006d 100644 --- a/linkedlist.go +++ b/linkedlist.go @@ -2,6 +2,7 @@ package collection import ( "bytes" + "context" "errors" "fmt" "strings" @@ -98,7 +99,7 @@ func (list *LinkedList) addNodeFront(node *llNode) { } // Enumerate creates a new instance of Enumerable which can be executed on. -func (list *LinkedList) Enumerate(cancel <-chan struct{}) Enumerator { +func (list *LinkedList) Enumerate(ctx context.Context) Enumerator { retval := make(chan interface{}) go func() { @@ -111,7 +112,7 @@ func (list *LinkedList) Enumerate(cancel <-chan struct{}) Enumerator { select { case retval <- current.payload: break - case <-cancel: + case <-ctx.Done(): return } current = current.next @@ -356,7 +357,7 @@ func (list *LinkedList) moveToFront(node *llNode) { // ToSlice converts the contents of the LinkedList into a slice. func (list *LinkedList) ToSlice() []interface{} { - return list.Enumerate(nil).ToSlice() + return list.Enumerate(context.Background()).ToSlice() } func findLast(head *llNode) *llNode { diff --git a/linkedlist_examples_test.go b/linkedlist_examples_test.go index 552bf4e..fc66881 100644 --- a/linkedlist_examples_test.go +++ b/linkedlist_examples_test.go @@ -1,6 +1,7 @@ package collection_test import ( + "context" "fmt" "github.com/marstr/collection/v2" @@ -27,7 +28,7 @@ func ExampleLinkedList_AddBack() { func ExampleLinkedList_Enumerate() { subject := collection.NewLinkedList(2, 3, 5, 8) - results := subject.Enumerate(nil).Select(func(a interface{}) interface{} { + results := subject.Enumerate(context.Background()).Select(func(a interface{}) interface{} { return -1 * a.(int) }) for entry := range results { diff --git a/linkedlist_test.go b/linkedlist_test.go index b9472dd..8b2dcd6 100644 --- a/linkedlist_test.go +++ b/linkedlist_test.go @@ -1,6 +1,9 @@ package collection -import "testing" +import ( + "context" + "testing" +) func TestLinkedList_findLast_empty(t *testing.T) { if result := findLast(nil); result != nil { @@ -101,7 +104,7 @@ func TestLinkedList_mergeSort_repair(t *testing.T) { for _, tc := range testCases { t.Run(tc.String(), func(t *testing.T) { originalLength := tc.Length() - originalElements := tc.Enumerate(nil).ToSlice() + originalElements := tc.Enumerate(context.Background()).ToSlice() originalContents := tc.String() if err := tc.Sorti(); err != ErrUnexpectedType { @@ -116,7 +119,7 @@ func TestLinkedList_mergeSort_repair(t *testing.T) { t.Fail() } - remaining := tc.Enumerate(nil).ToSlice() + remaining := tc.Enumerate(context.Background()).ToSlice() for _, desired := range originalElements { found := false diff --git a/list.go b/list.go index 29f549a..4335036 100644 --- a/list.go +++ b/list.go @@ -2,6 +2,7 @@ package collection import ( "bytes" + "context" "fmt" "sync" ) @@ -38,7 +39,7 @@ func (l *List) AddAt(pos uint, entries ...interface{}) { } // Enumerate lists each element present in the collection -func (l *List) Enumerate(cancel <-chan struct{}) Enumerator { +func (l *List) Enumerate(ctx context.Context) Enumerator { retval := make(chan interface{}) go func() { @@ -50,7 +51,7 @@ func (l *List) Enumerate(cancel <-chan struct{}) Enumerator { select { case retval <- entry: break - case <-cancel: + case <-ctx.Done(): return } } diff --git a/lru_cache.go b/lru_cache.go index 8ff5f28..aff2d7a 100644 --- a/lru_cache.go +++ b/lru_cache.go @@ -1,19 +1,22 @@ package collection -import "sync" +import ( + "context" + "sync" +) // LRUCache hosts up to a given number of items. When more are presented, the least recently used item // is evicted from the cache. type LRUCache struct { capacity uint - entries map[interface{}]*lruEntry - touched *LinkedList - key sync.RWMutex + entries map[interface{}]*lruEntry + touched *LinkedList + key sync.RWMutex } type lruEntry struct { - Node *llNode - Key interface{} + Node *llNode + Key interface{} Value interface{} } @@ -21,7 +24,7 @@ type lruEntry struct { func NewLRUCache(capacity uint) *LRUCache { return &LRUCache{ capacity: capacity, - entries: make(map[interface{}]*lruEntry, capacity + 1), + entries: make(map[interface{}]*lruEntry, capacity+1), touched: NewLinkedList(), } } @@ -36,8 +39,8 @@ func (lru *LRUCache) Put(key interface{}, value interface{}) { lru.touched.removeNode(entry.Node) } else { entry = &lruEntry{ - Node: &llNode{}, - Key: key, + Node: &llNode{}, + Key: key, } } @@ -85,10 +88,10 @@ func (lru *LRUCache) Remove(key interface{}) bool { } // Enumerate lists each value in the cache. -func (lru *LRUCache) Enumerate(cancel <-chan struct{}) Enumerator { +func (lru *LRUCache) Enumerate(ctx context.Context) Enumerator { retval := make(chan interface{}) - nested := lru.touched.Enumerate(cancel) + nested := lru.touched.Enumerate(ctx) go func() { lru.key.RLock() @@ -99,7 +102,7 @@ func (lru *LRUCache) Enumerate(cancel <-chan struct{}) Enumerator { select { case retval <- entry.(*lruEntry).Value: break - case <-cancel: + case <-ctx.Done(): return } } @@ -109,21 +112,21 @@ func (lru *LRUCache) Enumerate(cancel <-chan struct{}) Enumerator { } // EnumerateKeys lists each key in the cache. -func (lru *LRUCache) EnumerateKeys(cancel <-chan struct{}) Enumerator { +func (lru *LRUCache) EnumerateKeys(ctx context.Context) Enumerator { retval := make(chan interface{}) - nested := lru.touched.Enumerate(cancel) + nested := lru.touched.Enumerate(ctx) go func() { lru.key.RLock() - defer lru.key.RUnlock() + defer lru.key.RUnlock() defer close(retval) for entry := range nested { select { case retval <- entry.(*lruEntry).Key: break - case <-cancel: + case <-ctx.Done(): return } } diff --git a/lru_example_test.go b/lru_example_test.go index 20a7984..9d1b1b6 100644 --- a/lru_example_test.go +++ b/lru_example_test.go @@ -20,14 +20,13 @@ func ExampleLRUCache() { } func ExampleLRUCache_Enumerate() { - ctx := context.Background() subject := collection.NewLRUCache(3) subject.Put(1, "one") subject.Put(2, "two") subject.Put(3, "three") subject.Put(4, "four") - for key := range subject.Enumerate(ctx.Done()) { + for key := range subject.Enumerate(context.Background()) { fmt.Println(key) } @@ -38,14 +37,13 @@ func ExampleLRUCache_Enumerate() { } func ExampleLRUCache_EnumerateKeys() { - ctx := context.Background() subject := collection.NewLRUCache(3) subject.Put(1, "one") subject.Put(2, "two") subject.Put(3, "three") subject.Put(4, "four") - for key := range subject.EnumerateKeys(ctx.Done()) { + for key := range subject.EnumerateKeys(context.Background()) { fmt.Println(key) } diff --git a/query.go b/query.go index ac3632e..4e37c67 100644 --- a/query.go +++ b/query.go @@ -1,6 +1,7 @@ package collection import ( + "context" "errors" "reflect" "runtime" @@ -10,7 +11,7 @@ import ( // Enumerable offers a means of easily converting into a channel. It is most // useful for types where mutability is not in question. type Enumerable interface { - Enumerate(cancel <-chan struct{}) Enumerator + Enumerate(ctx context.Context) Enumerator } // Enumerator exposes a new syntax for querying familiar data structures. @@ -52,7 +53,7 @@ var Identity Transform = func(value interface{}) interface{} { // Empty is an Enumerable that has no elements, and will never have any elements. var Empty Enumerable = &emptyEnumerable{} -func (e emptyEnumerable) Enumerate(cancel <-chan struct{}) Enumerator { +func (e emptyEnumerable) Enumerate(ctx context.Context) Enumerator { results := make(chan interface{}) close(results) return results @@ -60,10 +61,7 @@ func (e emptyEnumerable) Enumerate(cancel <-chan struct{}) Enumerator { // All tests whether or not all items present in an Enumerable meet a criteria. func All(subject Enumerable, p Predicate) bool { - done := make(chan struct{}) - defer close(done) - - return subject.Enumerate(done).All(p) + return subject.Enumerate(context.Background()).All(p) } // All tests whether or not all items present meet a criteria. @@ -78,10 +76,7 @@ func (iter Enumerator) All(p Predicate) bool { // Any tests an Enumerable to see if there are any elements present. func Any(iterator Enumerable) bool { - done := make(chan struct{}) - defer close(done) - - for range iterator.Enumerate(done) { + for range iterator.Enumerate(context.Background()) { return true } return false @@ -89,10 +84,7 @@ func Any(iterator Enumerable) bool { // Anyp tests an Enumerable to see if there are any elements present that meet a criteria. func Anyp(iterator Enumerable, p Predicate) bool { - done := make(chan struct{}) - defer close(done) - - for element := range iterator.Enumerate(done) { + for element := range iterator.Enumerate(context.Background()) { if p(element) { return true } @@ -102,7 +94,7 @@ func Anyp(iterator Enumerable, p Predicate) bool { type enumerableSlice []interface{} -func (f enumerableSlice) Enumerate(cancel <-chan struct{}) Enumerator { +func (f enumerableSlice) Enumerate(ctx context.Context) Enumerator { results := make(chan interface{}) go func() { @@ -111,7 +103,7 @@ func (f enumerableSlice) Enumerate(cancel <-chan struct{}) Enumerator { select { case results <- entry: break - case <-cancel: + case <-ctx.Done(): return } } @@ -124,7 +116,7 @@ type enumerableValue struct { reflect.Value } -func (v enumerableValue) Enumerate(cancel <-chan struct{}) Enumerator { +func (v enumerableValue) Enumerate(ctx context.Context) Enumerator { results := make(chan interface{}) go func() { @@ -136,7 +128,7 @@ func (v enumerableValue) Enumerate(cancel <-chan struct{}) Enumerator { select { case results <- v.Index(i).Interface(): break - case <-cancel: + case <-ctx.Done(): return } } @@ -169,7 +161,7 @@ func (iter Enumerator) AsEnumerable() Enumerable { // Count iterates over a list and keeps a running tally of the number of elements // satisfy a predicate. func Count(iter Enumerable, p Predicate) int { - return iter.Enumerate(nil).Count(p) + return iter.Enumerate(context.Background()).Count(p) } // Count iterates over a list and keeps a running tally of the number of elements @@ -186,7 +178,7 @@ func (iter Enumerator) Count(p Predicate) int { // CountAll iterates over a list and keeps a running tally of how many it's seen. func CountAll(iter Enumerable) int { - return iter.Enumerate(nil).CountAll() + return iter.Enumerate(context.Background()).CountAll() } // CountAll iterates over a list and keeps a running tally of how many it's seen. @@ -208,9 +200,7 @@ func (iter Enumerator) Discard() { // ElementAt retreives an item at a particular position in an Enumerator. func ElementAt(iter Enumerable, n uint) interface{} { - done := make(chan struct{}) - defer close(done) - return iter.Enumerate(done).ElementAt(n) + return iter.Enumerate(context.Background()).ElementAt(n) } // ElementAt retreives an item at a particular position in an Enumerator. @@ -223,23 +213,20 @@ func (iter Enumerator) ElementAt(n uint) interface{} { // First retrieves just the first item in the list, or returns an error if there are no elements in the array. func First(subject Enumerable) (retval interface{}, err error) { - done := make(chan struct{}) - err = errNoElements var isOpen bool - if retval, isOpen = <-subject.Enumerate(done); isOpen { + if retval, isOpen = <-subject.Enumerate(context.Background()); isOpen { err = nil } - close(done) return } // Last retreives the item logically behind all other elements in the list. func Last(iter Enumerable) interface{} { - return iter.Enumerate(nil).Last() + return iter.Enumerate(context.Background()).Last() } // Last retreives the item logically behind all other elements in the list. @@ -254,7 +241,7 @@ type merger struct { originals []Enumerable } -func (m merger) Enumerate(cancel <-chan struct{}) Enumerator { +func (m merger) Enumerate(ctx context.Context) Enumerator { retval := make(chan interface{}) var wg sync.WaitGroup @@ -262,7 +249,7 @@ func (m merger) Enumerate(cancel <-chan struct{}) Enumerator { for _, item := range m.originals { go func(input Enumerable) { defer wg.Done() - for value := range input.Enumerate(cancel) { + for value := range input.Enumerate(ctx) { retval <- value } }(item) @@ -315,8 +302,8 @@ type parallelSelecter struct { operation Transform } -func (ps parallelSelecter) Enumerate(cancel <-chan struct{}) Enumerator { - return ps.original.Enumerate(cancel).ParallelSelect(ps.operation) +func (ps parallelSelecter) Enumerate(ctx context.Context) Enumerator { + return ps.original.Enumerate(ctx).ParallelSelect(ps.operation) } // ParallelSelect creates an Enumerable which will use all logically available CPUs to @@ -348,8 +335,8 @@ func Reverse(original Enumerable) Enumerable { } } -func (r reverser) Enumerate(cancel <-chan struct{}) Enumerator { - return r.original.Enumerate(cancel).Reverse() +func (r reverser) Enumerate(ctx context.Context) Enumerator { + return r.original.Enumerate(ctx).Reverse() } // Reverse returns items in the opposite order it encountered them in. @@ -376,8 +363,8 @@ type selecter struct { transform Transform } -func (s selecter) Enumerate(cancel <-chan struct{}) Enumerator { - return s.original.Enumerate(cancel).Select(s.transform) +func (s selecter) Enumerate(ctx context.Context) Enumerator { + return s.original.Enumerate(ctx).Select(s.transform) } // Select creates a reusable stream of transformed values. @@ -407,8 +394,8 @@ type selectManyer struct { toMany Unfolder } -func (s selectManyer) Enumerate(cancel <-chan struct{}) Enumerator { - return s.original.Enumerate(cancel).SelectMany(s.toMany) +func (s selectManyer) Enumerate(ctx context.Context) Enumerator { + return s.original.Enumerate(ctx).SelectMany(s.toMany) } // SelectMany allows for unfolding of values. @@ -437,13 +424,10 @@ func (iter Enumerator) SelectMany(lister Unfolder) Enumerator { // Single retreives the only element from a list, or returns nil and an error. func Single(iter Enumerable) (retval interface{}, err error) { - done := make(chan struct{}) - defer close(done) - err = errNoElements firstPass := true - for entry := range iter.Enumerate(done) { + for entry := range iter.Enumerate(context.Background()) { if firstPass { retval = entry err = nil @@ -470,8 +454,8 @@ type skipper struct { skipCount uint } -func (s skipper) Enumerate(cancel <-chan struct{}) Enumerator { - return s.original.Enumerate(cancel).Skip(s.skipCount) +func (s skipper) Enumerate(ctx context.Context) Enumerator { + return s.original.Enumerate(ctx).Skip(s.skipCount) } // Skip creates a reusable stream which will skip the first `n` elements before iterating @@ -536,8 +520,8 @@ type taker struct { n uint } -func (t taker) Enumerate(cancel <-chan struct{}) Enumerator { - return t.original.Enumerate(cancel).Take(t.n) +func (t taker) Enumerate(ctx context.Context) Enumerator { + return t.original.Enumerate(ctx).Take(t.n) } // Take retreives just the first `n` elements from an Enumerable. @@ -572,8 +556,8 @@ type takeWhiler struct { criteria func(interface{}, uint) bool } -func (tw takeWhiler) Enumerate(cancel <-chan struct{}) Enumerator { - return tw.original.Enumerate(cancel).TakeWhile(tw.criteria) +func (tw takeWhiler) Enumerate(ctx context.Context) Enumerator { + return tw.original.Enumerate(ctx).TakeWhile(tw.criteria) } // TakeWhile creates a reusable stream which will halt once some criteria is no longer met. @@ -621,7 +605,7 @@ func (iter Enumerator) Tee() (Enumerator, Enumerator) { // ToSlice places all iterated over values in a Slice for easy consumption. func ToSlice(iter Enumerable) []interface{} { - return iter.Enumerate(nil).ToSlice() + return iter.Enumerate(context.Background()).ToSlice() } // ToSlice places all iterated over values in a Slice for easy consumption. @@ -638,12 +622,12 @@ type wherer struct { filter Predicate } -func (w wherer) Enumerate(cancel <-chan struct{}) Enumerator { +func (w wherer) Enumerate(ctx context.Context) Enumerator { retval := make(chan interface{}) go func() { defer close(retval) - for entry := range w.original.Enumerate(cancel) { + for entry := range w.original.Enumerate(ctx) { if w.filter(entry) { retval <- entry } @@ -680,7 +664,7 @@ func (iter Enumerator) Where(predicate Predicate) Enumerator { // UCount iterates over a list and keeps a running tally of the number of elements // satisfy a predicate. func UCount(iter Enumerable, p Predicate) uint { - return iter.Enumerate(nil).UCount(p) + return iter.Enumerate(context.Background()).UCount(p) } // UCount iterates over a list and keeps a running tally of the number of elements @@ -697,7 +681,7 @@ func (iter Enumerator) UCount(p Predicate) uint { // UCountAll iterates over a list and keeps a running tally of how many it's seen. func UCountAll(iter Enumerable) uint { - return iter.Enumerate(nil).UCountAll() + return iter.Enumerate(context.Background()).UCountAll() } // UCountAll iterates over a list and keeps a running tally of how many it's seen. diff --git a/query_examples_test.go b/query_examples_test.go index 1124091..44e9677 100644 --- a/query_examples_test.go +++ b/query_examples_test.go @@ -1,6 +1,7 @@ package collection_test import ( + "context" "fmt" "sync" @@ -12,14 +13,14 @@ func ExampleAsEnumerable() { original := []int{1, 2, 3, 4, 5} wrapped := collection.AsEnumerable(original) - for entry := range wrapped.Enumerate(nil) { + for entry := range wrapped.Enumerate(context.Background()) { fmt.Print(entry) } fmt.Println() // When multiple values are provided, regardless of their type, they are each treated as enumerable values. wrapped = collection.AsEnumerable("red", "orange", "yellow", "green", "blue", "indigo", "violet") - for entry := range wrapped.Enumerate(nil) { + for entry := range wrapped.Enumerate(context.Background()) { fmt.Println(entry) } // Output: @@ -35,7 +36,7 @@ func ExampleAsEnumerable() { func ExampleEnumerator_Count() { subject := collection.AsEnumerable("str1", "str1", "str2") - count1 := subject.Enumerate(nil).Count(func(a interface{}) bool { + count1 := subject.Enumerate(context.Background()).Count(func(a interface{}) bool { return a == "str1" }) fmt.Println(count1) @@ -44,14 +45,12 @@ func ExampleEnumerator_Count() { func ExampleEnumerator_CountAll() { subject := collection.AsEnumerable('a', 'b', 'c', 'd', 'e') - fmt.Println(subject.Enumerate(nil).CountAll()) + fmt.Println(subject.Enumerate(context.Background()).CountAll()) // Output: 5 } func ExampleEnumerator_ElementAt() { - done := make(chan struct{}) - defer close(done) - fmt.Print(collection.Fibonacci.Enumerate(done).ElementAt(4)) + fmt.Print(collection.Fibonacci.Enumerate(context.Background()).ElementAt(4)) // Output: 3 } @@ -75,7 +74,7 @@ func ExampleLast() { func ExampleEnumerator_Last() { subject := collection.AsEnumerable(1, 2, 3) - fmt.Print(subject.Enumerate(nil).Last()) + fmt.Print(subject.Enumerate(context.Background()).Last()) //Output: 3 } @@ -84,13 +83,13 @@ func ExampleMerge() { b := collection.AsEnumerable(8, 16, 32) c := collection.Merge(a, b) sum := 0 - for x := range c.Enumerate(nil) { + for x := range c.Enumerate(context.Background()) { sum += x.(int) } fmt.Println(sum) product := 1 - for y := range a.Enumerate(nil) { + for y := range a.Enumerate(context.Background()) { product *= y.(int) } fmt.Println(product) @@ -100,7 +99,7 @@ func ExampleMerge() { } func ExampleEnumerator_Reverse() { - a := collection.AsEnumerable(1, 2, 3).Enumerate(nil) + a := collection.AsEnumerable(1, 2, 3).Enumerate(context.Background()) a = a.Reverse() fmt.Println(a.ToSlice()) // Output: [3 2 1] @@ -119,7 +118,7 @@ func ExampleSelect() { } func ExampleEnumerator_Select() { - subject := collection.AsEnumerable('a', 'b', 'c').Enumerate(nil) + subject := collection.AsEnumerable('a', 'b', 'c').Enumerate(context.Background()) const offset = 'a' - 1 results := subject.Select(func(a interface{}) interface{} { return a.(rune) - offset @@ -166,8 +165,8 @@ func ExampleEnumerator_SelectMany() { }, ) - beers := breweries.Enumerate(nil).SelectMany(func(brewer interface{}) collection.Enumerator { - return brewer.(BrewHouse).Beers.Enumerate(nil) + beers := breweries.Enumerate(context.Background()).SelectMany(func(brewer interface{}) collection.Enumerator { + return brewer.(BrewHouse).Beers.Enumerate(context.Background()) }) for beer := range beers { @@ -185,11 +184,8 @@ func ExampleEnumerator_SelectMany() { } func ExampleSkip() { - done := make(chan struct{}) - defer close(done) - trimmed := collection.Take(collection.Skip(collection.Fibonacci, 1), 3) - for entry := range trimmed.Enumerate(done) { + for entry := range trimmed.Enumerate(context.Background()) { fmt.Println(entry) } // Output: @@ -200,7 +196,7 @@ func ExampleSkip() { func ExampleEnumerator_Skip() { subject := collection.AsEnumerable(1, 2, 3, 4, 5, 6, 7) - skipped := subject.Enumerate(nil).Skip(5) + skipped := subject.Enumerate(context.Background()).Skip(5) for entry := range skipped { fmt.Println(entry) } @@ -210,11 +206,8 @@ func ExampleEnumerator_Skip() { } func ExampleTake() { - done := make(chan struct{}) - defer close(done) - taken := collection.Take(collection.Fibonacci, 4) - for entry := range taken.Enumerate(done) { + for entry := range taken.Enumerate(context.Background()) { fmt.Println(entry) } // Output: @@ -225,10 +218,7 @@ func ExampleTake() { } func ExampleEnumerator_Take() { - done := make(chan struct{}) - defer close(done) - - taken := collection.Fibonacci.Enumerate(done).Skip(4).Take(2) + taken := collection.Fibonacci.Enumerate(context.Background()).Skip(4).Take(2) for entry := range taken { fmt.Println(entry) } @@ -241,7 +231,7 @@ func ExampleTakeWhile() { taken := collection.TakeWhile(collection.Fibonacci, func(x interface{}, n uint) bool { return x.(int) < 10 }) - for entry := range taken.Enumerate(nil) { + for entry := range taken.Enumerate(context.Background()) { fmt.Println(entry) } // Output: @@ -255,7 +245,7 @@ func ExampleTakeWhile() { } func ExampleEnumerator_TakeWhile() { - taken := collection.Fibonacci.Enumerate(nil).TakeWhile(func(x interface{}, n uint) bool { + taken := collection.Fibonacci.Enumerate(context.Background()).TakeWhile(func(x interface{}, n uint) bool { return x.(int) < 6 }) for entry := range taken { @@ -272,7 +262,7 @@ func ExampleEnumerator_TakeWhile() { func ExampleEnumerator_Tee() { base := collection.AsEnumerable(1, 2, 4) - left, right := base.Enumerate(nil).Tee() + left, right := base.Enumerate(context.Background()).Tee() var wg sync.WaitGroup wg.Add(2) @@ -313,7 +303,7 @@ func ExampleUCount() { func ExampleEnumerator_UCount() { subject := collection.AsEnumerable("str1", "str1", "str2") - count1 := subject.Enumerate(nil).UCount(func(a interface{}) bool { + count1 := subject.Enumerate(context.Background()).UCount(func(a interface{}) bool { return a == "str1" }) fmt.Println(count1) @@ -328,14 +318,12 @@ func ExampleUCountAll() { func ExampleEnumerator_UCountAll() { subject := collection.AsEnumerable('a', 2, "str1") - fmt.Println(subject.Enumerate(nil).UCountAll()) + fmt.Println(subject.Enumerate(context.Background()).UCountAll()) // Output: 3 } func ExampleEnumerator_Where() { - done := make(chan struct{}) - defer close(done) - results := collection.Fibonacci.Enumerate(done).Where(func(a interface{}) bool { + results := collection.Fibonacci.Enumerate(context.Background()).Where(func(a interface{}) bool { return a.(int) > 8 }).Take(3) fmt.Println(results.ToSlice()) diff --git a/query_test.go b/query_test.go index 993353f..8c919c5 100644 --- a/query_test.go +++ b/query_test.go @@ -1,6 +1,7 @@ package collection import ( + "context" "testing" "time" ) @@ -31,7 +32,7 @@ func BenchmarkEnumerator_Sum(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - for range nums.Enumerate(nil).Select(sleepIdentity) { + for range nums.Enumerate(context.Background()).Select(sleepIdentity) { // Intentionally Left Blank } } diff --git a/queue.go b/queue.go index 0ca83a7..bf4f4d8 100644 --- a/queue.go +++ b/queue.go @@ -1,6 +1,7 @@ package collection import ( + "context" "sync" ) @@ -29,10 +30,10 @@ func (q *Queue) Add(entry interface{}) { } // Enumerate peeks at each element of this queue without mutating it. -func (q *Queue) Enumerate(cancel <-chan struct{}) Enumerator { +func (q *Queue) Enumerate(ctx context.Context) Enumerator { q.key.RLock() defer q.key.RUnlock() - return q.underlyer.Enumerate(cancel) + return q.underlyer.Enumerate(ctx) } // IsEmpty tests the Queue to determine if it is populate or not. diff --git a/stack.go b/stack.go index ecb9cc6..8a5f140 100644 --- a/stack.go +++ b/stack.go @@ -1,6 +1,7 @@ package collection import ( + "context" "sync" ) @@ -22,11 +23,11 @@ func NewStack(entries ...interface{}) *Stack { } // Enumerate peeks at each element in the stack without mutating it. -func (stack *Stack) Enumerate(cancel <-chan struct{}) Enumerator { +func (stack *Stack) Enumerate(ctx context.Context) Enumerator { stack.key.RLock() defer stack.key.RUnlock() - return stack.underlyer.Enumerate(cancel) + return stack.underlyer.Enumerate(ctx) } // IsEmpty tests the Stack to determine if it is populate or not.