diff --git a/dictionary.go b/dictionary.go index 4eaa6bf..46e0344 100644 --- a/dictionary.go +++ b/dictionary.go @@ -130,17 +130,17 @@ func (dict Dictionary) Size() int64 { } // Enumerate lists each word in the Dictionary alphabetically. -func (dict Dictionary) Enumerate(cancel <-chan struct{}) Enumerator { +func (dict Dictionary) Enumerate(cancel <-chan struct{}) Enumerator[string] { if dict.root == nil { - return Empty.Enumerate(cancel) + return Empty[string]().Enumerate(cancel) } return dict.root.Enumerate(cancel) } -func (node trieNode) Enumerate(cancel <-chan struct{}) Enumerator { +func (node trieNode) Enumerate(cancel <-chan struct{}) Enumerator[string] { var enumerateHelper func(trieNode, string) - results := make(chan interface{}) + results := make(chan string) enumerateHelper = func(subject trieNode, prefix string) { if subject.IsWord { diff --git a/dictionary_examples_test.go b/dictionary_examples_test.go index 08f2570..886b3de 100644 --- a/dictionary_examples_test.go +++ b/dictionary_examples_test.go @@ -31,12 +31,12 @@ func ExampleDictionary_Clear() { subject.Add("world") fmt.Println(subject.Size()) - fmt.Println(collection.CountAll(subject)) + fmt.Println(collection.CountAll[string](subject)) subject.Clear() fmt.Println(subject.Size()) - fmt.Println(collection.Any(subject)) + fmt.Println(collection.Any[string](subject)) // Output: // 2 @@ -50,8 +50,8 @@ func ExampleDictionary_Enumerate() { subject.Add("world") subject.Add("hello") - upperCase := collection.Select(subject, func(x interface{}) interface{} { - return strings.ToUpper(x.(string)) + upperCase := collection.Select[string](subject, func(x string) string { + return strings.ToUpper(x) }) for word := range subject.Enumerate(nil) { @@ -76,13 +76,13 @@ func ExampleDictionary_Remove() { subject.Add(world) fmt.Println(subject.Size()) - fmt.Println(collection.CountAll(subject)) + fmt.Println(collection.CountAll[string](subject)) subject.Remove(world) fmt.Println(subject.Size()) - fmt.Println(collection.CountAll(subject)) - fmt.Println(collection.Any(subject)) + fmt.Println(collection.CountAll[string](subject)) + fmt.Println(collection.Any[string](subject)) // Output: // 2 diff --git a/dictionary_test.go b/dictionary_test.go index 2072378..bcb36cc 100644 --- a/dictionary_test.go +++ b/dictionary_test.go @@ -32,29 +32,29 @@ func TestDictionary_Enumerate(t *testing.T) { t.Fail() } - if subjectSize := CountAll(subject); subjectSize != expectedSize { + if subjectSize := CountAll[string](subject); subjectSize != expectedSize { t.Logf("`CountAll` returned %d elements, expected %d", subjectSize, expectedSize) t.Fail() } prev := "" for result := range subject.Enumerate(nil) { - t.Logf(result.(string)) - if alreadySeen, ok := expected[result.(string)]; !ok { + t.Logf(result) + if alreadySeen, ok := expected[result]; !ok { t.Logf("An unadded value was returned") t.Fail() } else if alreadySeen { - t.Logf("\"%s\" was duplicated", result.(string)) + t.Logf("\"%s\" was duplicated", result) t.Fail() } - if stringle(result.(string), prev) { - t.Logf("Results \"%s\" and \"%s\" were not alphabetized.", prev, result.(string)) + if stringle(result, prev) { + t.Logf("Results \"%s\" and \"%s\" were not alphabetized.", prev, result) t.Fail() } - prev = result.(string) + prev = result - expected[result.(string)] = true + expected[result] = true } }) } diff --git a/fibonacci.go b/fibonacci.go index 85eaa99..3139317 100644 --- a/fibonacci.go +++ b/fibonacci.go @@ -3,14 +3,14 @@ package collection type fibonacciGenerator struct{} // Fibonacci is an Enumerable which will dynamically generate the fibonacci sequence. -var Fibonacci Enumerable = fibonacciGenerator{} +var Fibonacci Enumerable[uint] = fibonacciGenerator{} -func (gen fibonacciGenerator) Enumerate(cancel <-chan struct{}) Enumerator { - retval := make(chan interface{}) +func (gen fibonacciGenerator) Enumerate(cancel <-chan struct{}) Enumerator[uint] { + retval := make(chan uint) go func() { defer close(retval) - a, b := 0, 1 + var a, b uint = 0, 1 for { select { diff --git a/filesystem.go b/filesystem.go index 520bdca..5300817 100644 --- a/filesystem.go +++ b/filesystem.go @@ -40,8 +40,8 @@ func (d Directory) applyOptions(loc string, info os.FileInfo) bool { } // Enumerate lists the items in a `Directory` -func (d Directory) Enumerate(cancel <-chan struct{}) Enumerator { - results := make(chan interface{}) +func (d Directory) Enumerate(cancel <-chan struct{}) Enumerator[string] { + results := make(chan string) go func() { defer close(results) @@ -75,3 +75,4 @@ func (d Directory) Enumerate(cancel <-chan struct{}) Enumerator { return results } + diff --git a/filesystem_test.go b/filesystem_test.go index d30ec4b..b21e12e 100644 --- a/filesystem_test.go +++ b/filesystem_test.go @@ -56,24 +56,16 @@ func ExampleDirectory_Enumerate() { done := make(chan struct{}) - filesOfInterest := traverser.Enumerate(done).Select(func(subject interface{}) (result interface{}) { - cast, ok := subject.(string) - if ok { - result = path.Base(cast) - } else { - result = subject - } - return - }).Where(func(subject interface{}) bool { - cast, ok := subject.(string) - if !ok { - return false - } - return cast == "filesystem_test.go" + fileNames := Select[string](traverser, func(subject string) string { + return path.Base(subject) + }) + + filesOfInterest := Where(fileNames, func(subject string) bool { + return subject == "filesystem_test.go" }) - for entry := range filesOfInterest { - fmt.Println(entry.(string)) + for entry := range filesOfInterest.Enumerate(done) { + fmt.Println(entry) } close(done) @@ -92,45 +84,45 @@ func TestDirectory_Enumerate(t *testing.T) { { options: 0, expected: map[string]struct{}{ - filepath.Join("testdata", "foo", "a.txt"): struct{}{}, - filepath.Join("testdata", "foo", "c.txt"): struct{}{}, - filepath.Join("testdata", "foo", "bar"): struct{}{}, + filepath.Join("testdata", "foo", "a.txt"): {}, + filepath.Join("testdata", "foo", "c.txt"): {}, + filepath.Join("testdata", "foo", "bar"): {}, }, }, { options: DirectoryOptionsExcludeFiles, expected: map[string]struct{}{ - filepath.Join("testdata", "foo", "bar"): struct{}{}, + filepath.Join("testdata", "foo", "bar"): {}, }, }, { options: DirectoryOptionsExcludeDirectories, expected: map[string]struct{}{ - filepath.Join("testdata", "foo", "a.txt"): struct{}{}, - filepath.Join("testdata", "foo", "c.txt"): struct{}{}, + filepath.Join("testdata", "foo", "a.txt"): {}, + filepath.Join("testdata", "foo", "c.txt"): {}, }, }, { options: DirectoryOptionsRecursive, expected: map[string]struct{}{ - filepath.Join("testdata", "foo", "bar"): struct{}{}, - filepath.Join("testdata", "foo", "bar", "b.txt"): struct{}{}, - filepath.Join("testdata", "foo", "a.txt"): struct{}{}, - filepath.Join("testdata", "foo", "c.txt"): struct{}{}, + filepath.Join("testdata", "foo", "bar"): {}, + filepath.Join("testdata", "foo", "bar", "b.txt"): {}, + filepath.Join("testdata", "foo", "a.txt"): {}, + filepath.Join("testdata", "foo", "c.txt"): {}, }, }, { options: DirectoryOptionsExcludeFiles | DirectoryOptionsRecursive, expected: map[string]struct{}{ - filepath.Join("testdata", "foo", "bar"): struct{}{}, + filepath.Join("testdata", "foo", "bar"): {}, }, }, { options: DirectoryOptionsRecursive | DirectoryOptionsExcludeDirectories, expected: map[string]struct{}{ - filepath.Join("testdata", "foo", "a.txt"): struct{}{}, - filepath.Join("testdata", "foo", "bar", "b.txt"): struct{}{}, - filepath.Join("testdata", "foo", "c.txt"): struct{}{}, + filepath.Join("testdata", "foo", "a.txt"): {}, + filepath.Join("testdata", "foo", "bar", "b.txt"): {}, + filepath.Join("testdata", "foo", "c.txt"): {}, }, }, { @@ -147,12 +139,11 @@ func TestDirectory_Enumerate(t *testing.T) { subject.Options = tc.options t.Run(fmt.Sprintf("%d", uint(tc.options)), func(t *testing.T) { for entry := range subject.Enumerate(nil) { - cast := entry.(string) - if _, ok := tc.expected[cast]; !ok { - t.Logf("unexpected result: %q", cast) + if _, ok := tc.expected[entry]; !ok { + t.Logf("unexpected result: %q", entry) t.Fail() } - delete(tc.expected, cast) + delete(tc.expected, entry) } if len(tc.expected) != 0 { diff --git a/go.mod b/go.mod index 796cbad..eee511c 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/marstr/collection -go 1.11 +go 1.18 diff --git a/linkedlist.go b/linkedlist.go index 2004042..c2e6b67 100644 --- a/linkedlist.go +++ b/linkedlist.go @@ -4,29 +4,28 @@ import ( "bytes" "errors" "fmt" - "strings" "sync" ) // LinkedList encapsulates a list where each entry is aware of only the next entry in the list. -type LinkedList struct { - first *llNode - last *llNode +type LinkedList[T any] struct { + first *llNode[T] + last *llNode[T] length uint key sync.RWMutex } -type llNode struct { - payload interface{} - next *llNode - prev *llNode +type llNode[T any] struct { + payload T + next *llNode[T] + prev *llNode[T] } // Comparator is a function which evaluates two values to determine their relation to one another. // - Zero is returned when `a` and `b` are equal. // - Positive numbers are returned when `a` is greater than `b`. // - Negative numbers are returned when `a` is less than `b`. -type Comparator func(a, b interface{}) (int, error) +type Comparator[T any] func(a, b T) (int, error) // A collection of errors that may be thrown by functions in this file. var ( @@ -34,8 +33,8 @@ var ( ) // NewLinkedList instantiates a new LinkedList with the entries provided. -func NewLinkedList(entries ...interface{}) *LinkedList { - list := &LinkedList{} +func NewLinkedList[T any](entries ...T) *LinkedList[T] { + list := &LinkedList[T]{} for _, entry := range entries { list.AddBack(entry) @@ -45,18 +44,18 @@ func NewLinkedList(entries ...interface{}) *LinkedList { } // AddBack creates an entry in the LinkedList that is logically at the back of the list. -func (list *LinkedList) AddBack(entry interface{}) { +func (list *LinkedList[T]) AddBack(entry T) { list.key.Lock() defer list.key.Unlock() - toAppend := &llNode{ + toAppend := &llNode[T]{ payload: entry, } list.addNodeBack(toAppend) } -func (list *LinkedList) addNodeBack(node *llNode) { +func (list *LinkedList[T]) addNodeBack(node *llNode[T]) { list.length++ @@ -73,8 +72,8 @@ func (list *LinkedList) addNodeBack(node *llNode) { } // AddFront creates an entry in the LinkedList that is logically at the front of the list. -func (list *LinkedList) AddFront(entry interface{}) { - toAppend := &llNode{ +func (list *LinkedList[T]) AddFront(entry T) { + toAppend := &llNode[T]{ payload: entry, } @@ -84,7 +83,7 @@ func (list *LinkedList) AddFront(entry interface{}) { list.addNodeFront(toAppend) } -func (list *LinkedList) addNodeFront(node *llNode) { +func (list *LinkedList[T]) addNodeFront(node *llNode[T]) { list.length++ node.next = list.first @@ -98,8 +97,8 @@ func (list *LinkedList) addNodeFront(node *llNode) { } // Enumerate creates a new instance of Enumerable which can be executed on. -func (list *LinkedList) Enumerate(cancel <-chan struct{}) Enumerator { - retval := make(chan interface{}) +func (list *LinkedList[T]) Enumerate(cancel <-chan struct{}) Enumerator[T] { + retval := make(chan T) go func() { list.key.RLock() @@ -123,18 +122,18 @@ func (list *LinkedList) Enumerate(cancel <-chan struct{}) Enumerator { // Get finds the value from the LinkedList. // pos is expressed as a zero-based index begining from the 'front' of the list. -func (list *LinkedList) Get(pos uint) (interface{}, bool) { +func (list *LinkedList[T]) Get(pos uint) (T, bool) { list.key.RLock() defer list.key.RUnlock() node, ok := get(list.first, pos) if ok { return node.payload, true } - return nil, false + return *new(T), false } // IsEmpty tests the list to determine if it is populate or not. -func (list *LinkedList) IsEmpty() bool { +func (list *LinkedList[T]) IsEmpty() bool { list.key.RLock() defer list.key.RUnlock() @@ -142,7 +141,7 @@ func (list *LinkedList) IsEmpty() bool { } // Length returns the number of elements present in the LinkedList. -func (list *LinkedList) Length() uint { +func (list *LinkedList[T]) Length() uint { list.key.RLock() defer list.key.RUnlock() @@ -150,34 +149,34 @@ func (list *LinkedList) Length() uint { } // PeekBack returns the entry logicall stored at the back of the list without removing it. -func (list *LinkedList) PeekBack() (interface{}, bool) { +func (list *LinkedList[T]) PeekBack() (T, bool) { list.key.RLock() defer list.key.RUnlock() if list.last == nil { - return nil, false + return *new(T), false } return list.last.payload, true } // PeekFront returns the entry logically stored at the front of this list without removing it. -func (list *LinkedList) PeekFront() (interface{}, bool) { +func (list *LinkedList[T]) PeekFront() (T, bool) { list.key.RLock() defer list.key.RUnlock() if list.first == nil { - return nil, false + return *new(T), false } return list.first.payload, true } // RemoveFront returns the entry logically stored at the front of this list and removes it. -func (list *LinkedList) RemoveFront() (interface{}, bool) { +func (list *LinkedList[T]) RemoveFront() (T, bool) { list.key.Lock() defer list.key.Unlock() if list.first == nil { - return nil, false + return *new(T), false } retval := list.first.payload @@ -193,12 +192,12 @@ func (list *LinkedList) RemoveFront() (interface{}, bool) { } // RemoveBack returns the entry logically stored at the back of this list and removes it. -func (list *LinkedList) RemoveBack() (interface{}, bool) { +func (list *LinkedList[T]) RemoveBack() (T, bool) { list.key.Lock() defer list.key.Unlock() if list.last == nil { - return nil, false + return *new(T), false } retval := list.last.payload @@ -214,7 +213,7 @@ func (list *LinkedList) RemoveBack() (interface{}, bool) { } // removeNode -func (list *LinkedList) removeNode(target *llNode) { +func (list *LinkedList[T]) removeNode(target *llNode[T]) { if target == nil { return } @@ -246,7 +245,7 @@ func (list *LinkedList) removeNode(target *llNode) { // Sort rearranges the positions of the entries in this list so that they are // ascending. -func (list *LinkedList) Sort(comparator Comparator) error { +func (list *LinkedList[T]) Sort(comparator Comparator[T]) error { list.key.Lock() defer list.key.Unlock() var err error @@ -258,56 +257,8 @@ func (list *LinkedList) Sort(comparator Comparator) error { return err } -// Sorta rearranges the position of string entries in this list so that they -// are ascending. -func (list *LinkedList) Sorta() error { - list.key.Lock() - defer list.key.Unlock() - - var err error - list.first, err = mergeSort(list.first, func(a, b interface{}) (int, error) { - castA, ok := a.(string) - if !ok { - return 0, ErrUnexpectedType - } - castB, ok := b.(string) - if !ok { - return 0, ErrUnexpectedType - } - - return strings.Compare(castA, castB), nil - }) - list.last = findLast(list.first) - return err -} - -// Sorti rearranges the position of integer entries in this list so that they -// are ascending. -func (list *LinkedList) Sorti() (err error) { - list.key.Lock() - defer list.key.Unlock() - - list.first, err = mergeSort(list.first, func(a, b interface{}) (int, error) { - castA, ok := a.(int) - if !ok { - return 0, ErrUnexpectedType - } - castB, ok := b.(int) - if !ok { - return 0, ErrUnexpectedType - } - - return castA - castB, nil - }) - if err != nil { - return - } - list.last = findLast(list.first) - return -} - // String prints upto the first fifteen elements of the list in string format. -func (list *LinkedList) String() string { +func (list *LinkedList[T]) String() string { list.key.RLock() defer list.key.RUnlock() @@ -328,11 +279,11 @@ func (list *LinkedList) String() string { // Swap switches the positions in which two values are stored in this list. // x and y represent the indexes of the items that should be swapped. -func (list *LinkedList) Swap(x, y uint) error { +func (list *LinkedList[T]) Swap(x, y uint) error { list.key.Lock() defer list.key.Unlock() - var xNode, yNode *llNode + var xNode, yNode *llNode[T] if temp, ok := get(list.first, x); ok { xNode = temp } else { @@ -350,16 +301,12 @@ func (list *LinkedList) Swap(x, y uint) error { return nil } -func (list *LinkedList) moveToFront(node *llNode) { - -} - // ToSlice converts the contents of the LinkedList into a slice. -func (list *LinkedList) ToSlice() []interface{} { +func (list *LinkedList[T]) ToSlice() []T { return list.Enumerate(nil).ToSlice() } -func findLast(head *llNode) *llNode { +func findLast[T any](head *llNode[T]) *llNode[T] { if head == nil { return nil } @@ -370,7 +317,7 @@ func findLast(head *llNode) *llNode { return current } -func get(head *llNode, pos uint) (*llNode, bool) { +func get[T any](head *llNode[T], pos uint) (*llNode[T], bool) { for i := uint(0); i < pos; i++ { if head == nil { return nil, false @@ -382,13 +329,13 @@ func get(head *llNode, pos uint) (*llNode, bool) { // merge takes two sorted lists and merges them into one sorted list. // Behavior is undefined when you pass a non-sorted list as `left` or `right` -func merge(left, right *llNode, comparator Comparator) (first *llNode, err error) { +func merge[T any](left, right *llNode[T], comparator Comparator[T]) (first *llNode[T], err error) { curLeft := left curRight := right - var last *llNode + var last *llNode[T] - appendResults := func(updated *llNode) { + appendResults := func(updated *llNode[T]) { if last == nil { last = updated } else { @@ -422,14 +369,14 @@ func merge(left, right *llNode, comparator Comparator) (first *llNode, err error return } -func mergeSort(head *llNode, comparator Comparator) (*llNode, error) { +func mergeSort[T any](head *llNode[T], comparator Comparator[T]) (*llNode[T], error) { if head == nil { return nil, nil } left, right := split(head) - repair := func(left, right *llNode) *llNode { + repair := func(left, right *llNode[T]) *llNode[T] { lastLeft := findLast(left) lastLeft.next = right return left @@ -453,7 +400,7 @@ func mergeSort(head *llNode, comparator Comparator) (*llNode, error) { } // split breaks a list in half. -func split(head *llNode) (left, right *llNode) { +func split[T any](head *llNode[T]) (left, right *llNode[T]) { left = head if head == nil || head.next == nil { return diff --git a/linkedlist_examples_test.go b/linkedlist_examples_test.go index 18306d1..80464cd 100644 --- a/linkedlist_examples_test.go +++ b/linkedlist_examples_test.go @@ -27,10 +27,11 @@ func ExampleLinkedList_AddBack() { func ExampleLinkedList_Enumerate() { subject := collection.NewLinkedList(2, 3, 5, 8) - results := subject.Enumerate(nil).Select(func(a interface{}) interface{} { - return -1 * a.(int) + results := collection.Select[int](subject, func(a int) int { + return -1 * a }) - for entry := range results { + + for entry := range results.Enumerate(nil) { fmt.Println(entry) } // Output: @@ -63,50 +64,21 @@ func ExampleLinkedList_Sort() { // Sorti sorts into ascending order, this example demonstrates sorting // into descending order. subject := collection.NewLinkedList(2, 4, 3, 5, 7, 7) - subject.Sort(func(a, b interface{}) (int, error) { - castA, ok := a.(int) - if !ok { - return 0, collection.ErrUnexpectedType - } - castB, ok := b.(int) - if !ok { - return 0, collection.ErrUnexpectedType - } - - return castB - castA, nil + subject.Sort(func(a, b int) (int, error) { + return b - a, nil }) fmt.Println(subject) // Output: [7 7 5 4 3 2] } -func ExampleLinkedList_Sorta() { - subject := collection.NewLinkedList("charlie", "alfa", "bravo", "delta") - subject.Sorta() - for _, entry := range subject.ToSlice() { - fmt.Println(entry.(string)) - } - // Output: - // alfa - // bravo - // charlie - // delta -} - -func ExampleLinkedList_Sorti() { - subject := collection.NewLinkedList(7, 3, 2, 2, 3, 6) - subject.Sorti() - fmt.Println(subject) - // Output: [2 2 3 3 6 7] -} - func ExampleLinkedList_String() { - subject1 := collection.NewLinkedList() + subject1 := collection.NewLinkedList[int]() for i := 0; i < 20; i++ { subject1.AddBack(i) } fmt.Println(subject1) - subject2 := collection.NewLinkedList(1, 2, 3) + subject2 := collection.NewLinkedList[int](1, 2, 3) fmt.Println(subject2) // Output: // [0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 ...] diff --git a/linkedlist_test.go b/linkedlist_test.go index b9472dd..b804732 100644 --- a/linkedlist_test.go +++ b/linkedlist_test.go @@ -3,45 +3,45 @@ package collection import "testing" func TestLinkedList_findLast_empty(t *testing.T) { - if result := findLast(nil); result != nil { + if result := findLast[int](nil); result != nil { t.Logf("got: %v\nwant: %v", result, nil) } } func TestLinkedList_merge(t *testing.T) { testCases := []struct { - Left *LinkedList - Right *LinkedList + Left *LinkedList[int] + Right *LinkedList[int] Expected []int - Comp Comparator + Comp Comparator[int] }{ { - NewLinkedList(1, 3, 5), - NewLinkedList(2, 4), + NewLinkedList[int](1, 3, 5), + NewLinkedList[int](2, 4), []int{1, 2, 3, 4, 5}, UncheckedComparatori, }, { - NewLinkedList(1, 2, 3), - NewLinkedList(), + NewLinkedList[int](1, 2, 3), + NewLinkedList[int](), []int{1, 2, 3}, UncheckedComparatori, }, { - NewLinkedList(), - NewLinkedList(1, 2, 3), + NewLinkedList[int](), + NewLinkedList[int](1, 2, 3), []int{1, 2, 3}, UncheckedComparatori, }, { - NewLinkedList(), - NewLinkedList(), + NewLinkedList[int](), + NewLinkedList[int](), []int{}, UncheckedComparatori, }, { - NewLinkedList(1), - NewLinkedList(1), + NewLinkedList[int](1), + NewLinkedList[int](1), []int{1, 1}, UncheckedComparatori, }, @@ -53,12 +53,12 @@ func TestLinkedList_merge(t *testing.T) { }, { NewLinkedList(3), - NewLinkedList(), + NewLinkedList[int](), []int{3}, UncheckedComparatori, }, { - NewLinkedList(), + NewLinkedList[int](), NewLinkedList(10), []int{10}, UncheckedComparatori, @@ -75,7 +75,7 @@ func TestLinkedList_merge(t *testing.T) { i := 0 for cursor := result; cursor != nil; cursor, i = cursor.next, i+1 { if cursor.payload != tc.Expected[i] { - t.Logf("got: %d want: %d", cursor.payload.(int), tc.Expected[i]) + t.Logf("got: %d want: %d", cursor.payload, tc.Expected[i]) t.Fail() } } @@ -88,57 +88,8 @@ func TestLinkedList_merge(t *testing.T) { } } -func TestLinkedList_mergeSort_repair(t *testing.T) { - testCases := []*LinkedList{ - NewLinkedList(1, 2, "str1", 4, 5, 6), - NewLinkedList(1, 2, 3, "str1", 5, 6), - NewLinkedList(1, 'a', 3, 4, 5, 6), - NewLinkedList(1, 2, 3, 4, 5, uint(8)), - NewLinkedList("alpha", 0), - NewLinkedList(0, "kappa"), - } - - for _, tc := range testCases { - t.Run(tc.String(), func(t *testing.T) { - originalLength := tc.Length() - originalElements := tc.Enumerate(nil).ToSlice() - originalContents := tc.String() - - if err := tc.Sorti(); err != ErrUnexpectedType { - t.Log("`Sorti() should have thrown ErrUnexpectedType") - t.Fail() - } - - t.Logf("Contents:\n\tOriginal: \t%s\n\tPost Merge: \t%s", originalContents, tc.String()) - - if newLength := tc.Length(); newLength != originalLength { - t.Logf("Length changed. got: %d want: %d", newLength, originalLength) - t.Fail() - } - - remaining := tc.Enumerate(nil).ToSlice() - - for _, desired := range originalElements { - found := false - for i, got := range remaining { - if got == desired { - remaining = append(remaining[:i], remaining[i+1:]...) - found = true - break - } - } - - if !found { - t.Logf("couldn't find element: %v", desired) - t.Fail() - } - } - }) - } -} - -func UncheckedComparatori(a, b interface{}) (int, error) { - return a.(int) - b.(int), nil +func UncheckedComparatori(a, b int) (int, error) { + return a - b, nil } func TestLinkedList_RemoveBack_single(t *testing.T) { @@ -149,53 +100,6 @@ func TestLinkedList_RemoveBack_single(t *testing.T) { } } -func TestLinkedList_Sorti(t *testing.T) { - testCases := []struct { - *LinkedList - Expected []int - }{ - { - NewLinkedList(), - []int{}, - }, - { - NewLinkedList(1, 2, 3, 4), - []int{1, 2, 3, 4}, - }, - { - NewLinkedList(0, -1, 2, 8, 9), - []int{-1, 0, 2, 8, 9}, - }, - } - - for _, tc := range testCases { - t.Run(tc.String(), func(t *testing.T) { - if err := tc.Sorti(); err != nil { - t.Error(err) - } - - sorted := tc.ToSlice() - - if countSorted, countExpected := len(sorted), len(tc.Expected); countSorted != countExpected { - t.Logf("got: %d want: %d", countSorted, countExpected) - t.FailNow() - } - - for i, entry := range sorted { - cast, ok := entry.(int) - if !ok { - t.Errorf("Element was not an int: %v", entry) - } - - if cast != tc.Expected[i] { - t.Logf("got: %d want: %d at: %d", cast, tc.Expected[i], i) - t.Fail() - } - } - }) - } -} - func TestLinkedList_split_Even(t *testing.T) { subject := NewLinkedList(1, 2, 3, 4) @@ -244,7 +148,7 @@ func TestLinkedList_split_Odd(t *testing.T) { } func TestLinkedList_split_Empty(t *testing.T) { - subject := NewLinkedList() + subject := NewLinkedList[*int]() left, right := split(subject.first) @@ -330,7 +234,7 @@ func TestLinkedList_Swap_OutOfBounds(t *testing.T) { func TestLinkedList_Get_OutsideBounds(t *testing.T) { subject := NewLinkedList(2, 3, 5, 8, 13, 21) result, ok := subject.Get(10) - if !(result == nil && ok == false) { + if !(result == 0 && ok == false) { t.Logf("got: %v %v\nwant: %v %v", result, ok, nil, false) t.Fail() } @@ -348,8 +252,8 @@ func TestLinkedList_removeNode(t *testing.T) { } if first, ok := subject.Get(0); ok { - if first.(int) != 2 { - t.Logf("got %d, want %d", first.(int), 2) + if first != 2 { + t.Logf("got %d, want %d", first, 2) t.Fail() } } else { @@ -358,8 +262,8 @@ func TestLinkedList_removeNode(t *testing.T) { } if second, ok := subject.Get(1); ok { - if second.(int) != 3 { - t.Logf("got %d, want %d", second.(int), 3) + if second != 3 { + t.Logf("got %d, want %d", second, 3) t.Fail() } } else { @@ -379,8 +283,8 @@ func TestLinkedList_removeNode(t *testing.T) { } if first, ok := subject.Get(0); ok { - if first.(int) != 1 { - t.Logf("got %d, want %d", first.(int), 1) + if first != 1 { + t.Logf("got %d, want %d", first, 1) t.Fail() } } else { @@ -389,8 +293,8 @@ func TestLinkedList_removeNode(t *testing.T) { } if second, ok := subject.Get(1); ok { - if second.(int) != 2 { - t.Logf("got %d, want %d", second.(int), 2) + if second != 2 { + t.Logf("got %d, want %d", second, 2) t.Fail() } } else { @@ -410,8 +314,8 @@ func TestLinkedList_removeNode(t *testing.T) { } if first, ok := subject.Get(0); ok { - if first.(int) != 1 { - t.Logf("got %d, want %d", first.(int), 1) + if first != 1 { + t.Logf("got %d, want %d", first, 1) t.Fail() } } else { @@ -420,8 +324,8 @@ func TestLinkedList_removeNode(t *testing.T) { } if second, ok := subject.Get(1); ok { - if second.(int) != 3 { - t.Logf("got %d, want %d", second.(int), 3) + if second != 3 { + t.Logf("got %d, want %d", second, 3) t.Fail() } } else { diff --git a/list.go b/list.go index 29f549a..d1c4b12 100644 --- a/list.go +++ b/list.go @@ -8,20 +8,20 @@ import ( // List is a dynamically sized list akin to List in the .NET world, // ArrayList in the Java world, or vector in the C++ world. -type List struct { - underlyer []interface{} +type List[T any] struct { + underlyer []T key sync.RWMutex } // NewList creates a new list which contains the elements provided. -func NewList(entries ...interface{}) *List { - return &List{ +func NewList[T any](entries ...T) *List[T] { + return &List[T]{ underlyer: entries, } } // Add appends an entry to the logical end of the List. -func (l *List) Add(entries ...interface{}) { +func (l *List[T]) Add(entries ...T) { l.key.Lock() defer l.key.Unlock() l.underlyer = append(l.underlyer, entries...) @@ -30,7 +30,7 @@ func (l *List) Add(entries ...interface{}) { // AddAt injects values beginning at `pos`. If multiple values // are provided in `entries` they are placed in the same order // they are provided. -func (l *List) AddAt(pos uint, entries ...interface{}) { +func (l *List[T]) AddAt(pos uint, entries ...T) { l.key.Lock() defer l.key.Unlock() @@ -38,8 +38,8 @@ func (l *List) AddAt(pos uint, entries ...interface{}) { } // Enumerate lists each element present in the collection -func (l *List) Enumerate(cancel <-chan struct{}) Enumerator { - retval := make(chan interface{}) +func (l *List[T]) Enumerate(cancel <-chan struct{}) Enumerator[T] { + retval := make(chan T) go func() { l.key.RLock() @@ -62,37 +62,37 @@ func (l *List) Enumerate(cancel <-chan struct{}) Enumerator { // Get retreives the value stored in a particular position of the list. // If no item exists at the given position, the second parameter will be // returned as false. -func (l *List) Get(pos uint) (interface{}, bool) { +func (l *List[T]) Get(pos uint) (T, bool) { l.key.RLock() defer l.key.RUnlock() if pos > uint(len(l.underlyer)) { - return nil, false + return *new(T), false } return l.underlyer[pos], true } // IsEmpty tests to see if this List has any elements present. -func (l *List) IsEmpty() bool { +func (l *List[T]) IsEmpty() bool { l.key.RLock() defer l.key.RUnlock() return 0 == len(l.underlyer) } // Length returns the number of elements in the List. -func (l *List) Length() uint { +func (l *List[T]) Length() uint { l.key.RLock() defer l.key.RUnlock() return uint(len(l.underlyer)) } // Remove retreives a value from this List and shifts all other values. -func (l *List) Remove(pos uint) (interface{}, bool) { +func (l *List[T]) Remove(pos uint) (T, bool) { l.key.Lock() defer l.key.Unlock() if pos > uint(len(l.underlyer)) { - return nil, false + return *new(T), false } retval := l.underlyer[pos] l.underlyer = append(l.underlyer[:pos], l.underlyer[pos+1:]...) @@ -100,7 +100,7 @@ func (l *List) Remove(pos uint) (interface{}, bool) { } // Set updates the value stored at a given position in the List. -func (l *List) Set(pos uint, val interface{}) bool { +func (l *List[T]) Set(pos uint, val T) bool { l.key.Lock() defer l.key.Unlock() var retval bool @@ -115,7 +115,7 @@ func (l *List) Set(pos uint, val interface{}) bool { } // String generates a textual representation of the List for the sake of debugging. -func (l *List) String() string { +func (l *List[T]) String() string { l.key.RLock() defer l.key.RUnlock() @@ -134,13 +134,13 @@ func (l *List) String() string { } // Swap switches the values that are stored at positions `x` and `y` -func (l *List) Swap(x, y uint) bool { +func (l *List[T]) Swap(x, y uint) bool { l.key.Lock() defer l.key.Unlock() return l.swap(x, y) } -func (l *List) swap(x, y uint) bool { +func (l *List[T]) swap(x, y uint) bool { count := uint(len(l.underlyer)) if x < count && y < count { temp := l.underlyer[x] diff --git a/lru_cache.go b/lru_cache.go index 8ff5f28..730d5e0 100644 --- a/lru_cache.go +++ b/lru_cache.go @@ -4,30 +4,30 @@ import "sync" // LRUCache hosts up to a given number of items. When more are presented, the least recently used item // is evicted from the cache. -type LRUCache struct { +type LRUCache[K comparable, V any] struct { capacity uint - entries map[interface{}]*lruEntry - touched *LinkedList - key sync.RWMutex + entries map[K]*lruEntry[K, V] + touched *LinkedList[*lruEntry[K, V]] + key sync.RWMutex } -type lruEntry struct { - Node *llNode - Key interface{} - Value interface{} +type lruEntry[K any, V any] struct { + Node *llNode[*lruEntry[K, V]] + Key K + Value V } // NewLRUCache creates an empty cache, which will accommodate the given number of items. -func NewLRUCache(capacity uint) *LRUCache { - return &LRUCache{ +func NewLRUCache[K comparable, V any](capacity uint) *LRUCache[K, V] { + return &LRUCache[K, V]{ capacity: capacity, - entries: make(map[interface{}]*lruEntry, capacity + 1), - touched: NewLinkedList(), + entries: make(map[K]*lruEntry[K, V], capacity+1), + touched: NewLinkedList[*lruEntry[K, V]](), } } // Put adds a value to the cache. The added value may be expelled without warning. -func (lru *LRUCache) Put(key interface{}, value interface{}) { +func (lru *LRUCache[K, V]) Put(key K, value V) { lru.key.Lock() defer lru.key.Unlock() @@ -35,9 +35,9 @@ func (lru *LRUCache) Put(key interface{}, value interface{}) { if ok { lru.touched.removeNode(entry.Node) } else { - entry = &lruEntry{ - Node: &llNode{}, - Key: key, + entry = &lruEntry[K, V]{ + Node: &llNode[*lruEntry[K, V]]{}, + Key: key, } } @@ -49,28 +49,28 @@ func (lru *LRUCache) Put(key interface{}, value interface{}) { if lru.touched.Length() > lru.capacity { removed, ok := lru.touched.RemoveBack() if ok { - delete(lru.entries, removed.(*lruEntry).Key) + delete(lru.entries, removed.Key) } } } // Get retrieves a cached value, if it is still present. -func (lru *LRUCache) Get(key interface{}) (interface{}, bool) { +func (lru *LRUCache[K, V]) Get(key K) (V, bool) { lru.key.RLock() defer lru.key.RUnlock() entry, ok := lru.entries[key] if !ok { - return nil, false + return *new(V), false } lru.touched.removeNode(entry.Node) lru.touched.addNodeFront(entry.Node) - return entry.Node.payload.(*lruEntry).Value, true + return entry.Node.payload.Value, true } // Remove explicitly takes an item out of the cache. -func (lru *LRUCache) Remove(key interface{}) bool { +func (lru *LRUCache[K, V]) Remove(key K) bool { lru.key.RLock() defer lru.key.RUnlock() @@ -85,8 +85,8 @@ func (lru *LRUCache) Remove(key interface{}) bool { } // Enumerate lists each value in the cache. -func (lru *LRUCache) Enumerate(cancel <-chan struct{}) Enumerator { - retval := make(chan interface{}) +func (lru *LRUCache[K, V]) Enumerate(cancel <-chan struct{}) Enumerator[V] { + retval := make(chan V) nested := lru.touched.Enumerate(cancel) @@ -97,7 +97,7 @@ func (lru *LRUCache) Enumerate(cancel <-chan struct{}) Enumerator { for entry := range nested { select { - case retval <- entry.(*lruEntry).Value: + case retval <- entry.Value: break case <-cancel: return @@ -109,19 +109,19 @@ func (lru *LRUCache) Enumerate(cancel <-chan struct{}) Enumerator { } // EnumerateKeys lists each key in the cache. -func (lru *LRUCache) EnumerateKeys(cancel <-chan struct{}) Enumerator { - retval := make(chan interface{}) +func (lru *LRUCache[K, V]) EnumerateKeys(cancel <-chan struct{}) Enumerator[K] { + retval := make(chan K) nested := lru.touched.Enumerate(cancel) go func() { lru.key.RLock() - defer lru.key.RUnlock() + defer lru.key.RUnlock() defer close(retval) for entry := range nested { select { - case retval <- entry.(*lruEntry).Key: + case retval <- entry.Key: break case <-cancel: return diff --git a/lru_cache_test.go b/lru_cache_test.go index b13604f..e5fbd55 100644 --- a/lru_cache_test.go +++ b/lru_cache_test.go @@ -7,7 +7,7 @@ func TestLRUCache_Put_replace(t *testing.T) { const firstPut = "first" const secondPut = "second" - subject := NewLRUCache(10) + subject := NewLRUCache[int, string](10) subject.Put(key, firstPut) subject.Put(key, secondPut) @@ -25,7 +25,7 @@ func TestLRUCache_Put_replace(t *testing.T) { } func TestLRUCache_Remove_empty(t *testing.T) { - subject := NewLRUCache(10) + subject := NewLRUCache[int, int](10) got := subject.Remove(7) if got != false { t.Fail() @@ -34,7 +34,7 @@ func TestLRUCache_Remove_empty(t *testing.T) { func TestLRUCache_Remove_present(t *testing.T) { const key = 10 - subject := NewLRUCache(6) + subject := NewLRUCache[int, string](6) subject.Put(key, "ten") ok := subject.Remove(key) if !ok { @@ -50,7 +50,7 @@ func TestLRUCache_Remove_present(t *testing.T) { func TestLRUCache_Remove_notPresent(t *testing.T) { const key1 = 10 const key2 = key1 + 1 - subject := NewLRUCache(6) + subject := NewLRUCache[int, string](6) subject.Put(key2, "eleven") ok := subject.Remove(key1) if ok { @@ -61,4 +61,4 @@ func TestLRUCache_Remove_notPresent(t *testing.T) { if !ok { t.Fail() } -} \ No newline at end of file +} diff --git a/lru_example_test.go b/lru_example_test.go index 80b99b9..43091ed 100644 --- a/lru_example_test.go +++ b/lru_example_test.go @@ -7,7 +7,7 @@ import ( ) func ExampleLRUCache() { - subject := collection.NewLRUCache(3) + subject := collection.NewLRUCache[int, string](3) subject.Put(1, "one") subject.Put(2, "two") subject.Put(3, "three") @@ -15,13 +15,13 @@ func ExampleLRUCache() { fmt.Println(subject.Get(1)) fmt.Println(subject.Get(4)) // Output: - // false + // false // four true } func ExampleLRUCache_Enumerate() { ctx := context.Background() - subject := collection.NewLRUCache(3) + subject := collection.NewLRUCache[int, string](3) subject.Put(1, "one") subject.Put(2, "two") subject.Put(3, "three") @@ -39,7 +39,7 @@ func ExampleLRUCache_Enumerate() { func ExampleLRUCache_EnumerateKeys() { ctx := context.Background() - subject := collection.NewLRUCache(3) + subject := collection.NewLRUCache[int, string](3) subject.Put(1, "one") subject.Put(2, "two") subject.Put(3, "three") @@ -53,4 +53,4 @@ func ExampleLRUCache_EnumerateKeys() { // 4 // 3 // 2 -} \ No newline at end of file +} diff --git a/query.go b/query.go index ac3632e..44c9fb2 100644 --- a/query.go +++ b/query.go @@ -2,30 +2,29 @@ package collection import ( "errors" - "reflect" "runtime" "sync" ) // Enumerable offers a means of easily converting into a channel. It is most // useful for types where mutability is not in question. -type Enumerable interface { - Enumerate(cancel <-chan struct{}) Enumerator +type Enumerable[T any] interface { + Enumerate(cancel <-chan struct{}) Enumerator[T] } // Enumerator exposes a new syntax for querying familiar data structures. -type Enumerator <-chan interface{} +type Enumerator[T any] <-chan T // Predicate defines an interface for funcs that make some logical test. -type Predicate func(interface{}) bool +type Predicate[T any] func(T) bool // Transform defines a function which takes a value, and returns some value based on the original. -type Transform func(interface{}) interface{} +type Transform[T any, E any] func(T) E // Unfolder defines a function which takes a single value, and exposes many of them as an Enumerator -type Unfolder func(interface{}) Enumerator +type Unfolder[T any, E any] func(T) Enumerator[E] -type emptyEnumerable struct{} +type emptyEnumerable[T any] struct{} var ( errNoElements = errors.New("enumerator encountered no elements") @@ -44,22 +43,25 @@ func IsErrorMultipleElements(err error) bool { return err == errMultipleElements } -// Identity is a trivial Transform which applies no operation on the value. -var Identity Transform = func(value interface{}) interface{} { - return value +// Identity returns a trivial Transform which applies no operation on the value. +func Identity[T any]() Transform[T, T] { + return func(value T) T { + return value + } } -// Empty is an Enumerable that has no elements, and will never have any elements. -var Empty Enumerable = &emptyEnumerable{} +func Empty[T any]() Enumerable[T] { + return &emptyEnumerable[T]{} +} -func (e emptyEnumerable) Enumerate(cancel <-chan struct{}) Enumerator { - results := make(chan interface{}) +func (e emptyEnumerable[T]) Enumerate(cancel <-chan struct{}) Enumerator[T] { + results := make(chan T) close(results) return results } // All tests whether or not all items present in an Enumerable meet a criteria. -func All(subject Enumerable, p Predicate) bool { +func All[T any](subject Enumerable[T], p Predicate[T]) bool { done := make(chan struct{}) defer close(done) @@ -67,7 +69,7 @@ func All(subject Enumerable, p Predicate) bool { } // All tests whether or not all items present meet a criteria. -func (iter Enumerator) All(p Predicate) bool { +func (iter Enumerator[T]) All(p Predicate[T]) bool { for entry := range iter { if !p(entry) { return false @@ -77,7 +79,7 @@ func (iter Enumerator) All(p Predicate) bool { } // Any tests an Enumerable to see if there are any elements present. -func Any(iterator Enumerable) bool { +func Any[T any](iterator Enumerable[T]) bool { done := make(chan struct{}) defer close(done) @@ -88,7 +90,7 @@ func Any(iterator Enumerable) bool { } // Anyp tests an Enumerable to see if there are any elements present that meet a criteria. -func Anyp(iterator Enumerable, p Predicate) bool { +func Anyp[T any](iterator Enumerable[T], p Predicate[T]) bool { done := make(chan struct{}) defer close(done) @@ -100,10 +102,10 @@ func Anyp(iterator Enumerable, p Predicate) bool { return false } -type enumerableSlice []interface{} +type EnumerableSlice[T any] []T -func (f enumerableSlice) Enumerate(cancel <-chan struct{}) Enumerator { - results := make(chan interface{}) +func (f EnumerableSlice[T]) Enumerate(cancel <-chan struct{}) Enumerator[T] { + results := make(chan T) go func() { defer close(results) @@ -120,61 +122,25 @@ func (f enumerableSlice) Enumerate(cancel <-chan struct{}) Enumerator { return results } -type enumerableValue struct { - reflect.Value -} - -func (v enumerableValue) Enumerate(cancel <-chan struct{}) Enumerator { - results := make(chan interface{}) - - go func() { - defer close(results) - - elements := v.Len() - - for i := 0; i < elements; i++ { - select { - case results <- v.Index(i).Interface(): - break - case <-cancel: - return - } - } - }() - - return results -} - // AsEnumerable allows for easy conversion of a slice to a re-usable Enumerable object. -func AsEnumerable(entries ...interface{}) Enumerable { - if len(entries) != 1 { - return enumerableSlice(entries) - } - - val := reflect.ValueOf(entries[0]) - - if kind := val.Kind(); kind == reflect.Slice || kind == reflect.Array { - return enumerableValue{ - Value: val, - } - } - return enumerableSlice(entries) +func AsEnumerable[T any](entries ...T) Enumerable[T] { + return EnumerableSlice[T](entries) } // AsEnumerable stores the results of an Enumerator so the results can be enumerated over repeatedly. -func (iter Enumerator) AsEnumerable() Enumerable { - return enumerableSlice(iter.ToSlice()) +func (iter Enumerator[T]) AsEnumerable() Enumerable[T] { + return EnumerableSlice[T](iter.ToSlice()) } // Count iterates over a list and keeps a running tally of the number of elements // satisfy a predicate. -func Count(iter Enumerable, p Predicate) int { +func Count[T any](iter Enumerable[T], p Predicate[T]) int { return iter.Enumerate(nil).Count(p) } // Count iterates over a list and keeps a running tally of the number of elements // satisfy a predicate. -func (iter Enumerator) Count(p Predicate) int { +func (iter Enumerator[T]) Count(p Predicate[T]) int { tally := 0 for entry := range iter { if p(entry) { @@ -185,12 +151,12 @@ func (iter Enumerator) Count(p Predicate) int { } // CountAll iterates over a list and keeps a running tally of how many it's seen. -func CountAll(iter Enumerable) int { +func CountAll[T any](iter Enumerable[T]) int { return iter.Enumerate(nil).CountAll() } // CountAll iterates over a list and keeps a running tally of how many it's seen. -func (iter Enumerator) CountAll() int { +func (iter Enumerator[T]) CountAll() int { tally := 0 for range iter { tally++ @@ -200,21 +166,21 @@ func (iter Enumerator) CountAll() int { // Discard reads an enumerator to the end but does nothing with it. // This method should be used in circumstances when it doesn't make sense to explicitly cancel the Enumeration. -func (iter Enumerator) Discard() { +func (iter Enumerator[T]) Discard() { for range iter { // Intentionally Left Blank } } // ElementAt retreives an item at a particular position in an Enumerator. -func ElementAt(iter Enumerable, n uint) interface{} { +func ElementAt[T any](iter Enumerable[T], n uint) T { done := make(chan struct{}) defer close(done) return iter.Enumerate(done).ElementAt(n) } // ElementAt retreives an item at a particular position in an Enumerator. -func (iter Enumerator) ElementAt(n uint) interface{} { +func (iter Enumerator[T]) ElementAt(n uint) T { for i := uint(0); i < n; i++ { <-iter } @@ -222,7 +188,7 @@ func (iter Enumerator) ElementAt(n uint) interface{} { } // First retrieves just the first item in the list, or returns an error if there are no elements in the array. -func First(subject Enumerable) (retval interface{}, err error) { +func First[T any](subject Enumerable[T]) (retval T, err error) { done := make(chan struct{}) err = errNoElements @@ -238,29 +204,29 @@ func First(subject Enumerable) (retval interface{}, err error) { } // Last retreives the item logically behind all other elements in the list. -func Last(iter Enumerable) interface{} { +func Last[T any](iter Enumerable[T]) T { return iter.Enumerate(nil).Last() } // Last retreives the item logically behind all other elements in the list. -func (iter Enumerator) Last() (retval interface{}) { +func (iter Enumerator[T]) Last() (retval T) { for retval = range iter { // Intentionally Left Blank } return } -type merger struct { - originals []Enumerable +type merger[T any] struct { + originals []Enumerable[T] } -func (m merger) Enumerate(cancel <-chan struct{}) Enumerator { - retval := make(chan interface{}) +func (m merger[T]) Enumerate(cancel <-chan struct{}) Enumerator[T] { + retval := make(chan T) var wg sync.WaitGroup wg.Add(len(m.originals)) for _, item := range m.originals { - go func(input Enumerable) { + go func(input Enumerable[T]) { defer wg.Done() for value := range input.Enumerate(cancel) { retval <- value @@ -277,21 +243,21 @@ func (m merger) Enumerate(cancel <-chan struct{}) Enumerator { // Merge takes the results as it receives them from several channels and directs // them into a single channel. -func Merge(channels ...Enumerable) Enumerable { - return merger{ +func Merge[T any](channels ...Enumerable[T]) Enumerable[T] { + return merger[T]{ originals: channels, } } // Merge takes the results of this Enumerator and others, and funnels them into // a single Enumerator. The order of in which they will be combined is non-deterministic. -func (iter Enumerator) Merge(others ...Enumerator) Enumerator { - retval := make(chan interface{}) +func (iter Enumerator[T]) Merge(others ...Enumerator[T]) Enumerator[T] { + retval := make(chan T) var wg sync.WaitGroup wg.Add(len(others) + 1) - funnel := func(prevResult Enumerator) { + funnel := func(prevResult Enumerator[T]) { for entry := range prevResult { retval <- entry } @@ -310,56 +276,66 @@ func (iter Enumerator) Merge(others ...Enumerator) Enumerator { return retval } -type parallelSelecter struct { - original Enumerable - operation Transform +type parallelSelecter[T any, E any] struct { + original Enumerable[T] + operation Transform[T, E] } -func (ps parallelSelecter) Enumerate(cancel <-chan struct{}) Enumerator { - return ps.original.Enumerate(cancel).ParallelSelect(ps.operation) +func (ps parallelSelecter[T, E]) Enumerate(cancel <-chan struct{}) Enumerator[E] { + iter := ps.original.Enumerate(cancel) + if cpus := runtime.NumCPU(); cpus != 1 { + intermediate := splitN(iter, ps.operation, uint(cpus)) + return intermediate[0].Merge(intermediate[1:]...) + } + + return Select(ps.original, ps.operation).Enumerate(cancel) } // ParallelSelect creates an Enumerable which will use all logically available CPUs to // execute a Transform. -func ParallelSelect(original Enumerable, operation Transform) Enumerable { - return parallelSelecter{ +func ParallelSelect[T any, E any](original Enumerable[T], operation Transform[T, E]) Enumerable[E] { + return parallelSelecter[T, E]{ original: original, operation: operation, } } // ParallelSelect will execute a Transform across all logical CPUs available to the current process. -func (iter Enumerator) ParallelSelect(operation Transform) Enumerator { - if cpus := runtime.NumCPU(); cpus != 1 { - intermediate := iter.splitN(operation, uint(cpus)) - return intermediate[0].Merge(intermediate[1:]...) - } - return iter -} - -type reverser struct { - original Enumerable +// +// This is commented out, because Go 1.18 adds support for generics, but disallows methods from having type parameters +// not declared by their receivers. +// +//func (iter Enumerator[T]) ParallelSelect[E any](operation Transform[T, E]) Enumerator[E] { +// if cpus := runtime.NumCPU(); cpus != 1 { +// intermediate := iter.splitN(operation, uint(cpus)) +// return intermediate[0].Merge(intermediate[1:]...) +// } +// return iter +//} + +type reverser[T any] struct { + original Enumerable[T] } // Reverse will enumerate all values of an enumerable, store them in a Stack, then replay them all. -func Reverse(original Enumerable) Enumerable { - return reverser{ +func Reverse[T any](original Enumerable[T]) Enumerable[T] { + return reverser[T]{ original: original, } } -func (r reverser) Enumerate(cancel <-chan struct{}) Enumerator { +func (r reverser[T]) Enumerate(cancel <-chan struct{}) Enumerator[T] { return r.original.Enumerate(cancel).Reverse() } // Reverse returns items in the opposite order it encountered them in. -func (iter Enumerator) Reverse() Enumerator { - cache := NewStack() +func (iter Enumerator[T]) Reverse() Enumerator[T] { + cache := NewStack[T]() for entry := range iter { cache.Push(entry) } - retval := make(chan interface{}) + retval := make(chan T) go func() { for !cache.IsEmpty() { @@ -371,72 +347,100 @@ func (iter Enumerator) Reverse() Enumerator { return retval } -type selecter struct { - original Enumerable - transform Transform +type selecter[T any, E any] struct { + original Enumerable[T] + transform Transform[T, E] } -func (s selecter) Enumerate(cancel <-chan struct{}) Enumerator { - return s.original.Enumerate(cancel).Select(s.transform) +func (s selecter[T, E]) Enumerate(cancel <-chan struct{}) Enumerator[E] { + retval := make(chan E) + + go func() { + defer close(retval) + + for item := range s.original.Enumerate(cancel) { + select { + case retval <- s.transform(item): + // Intentionally Left Blank + case <-cancel: + return + } + } + }() + + return retval } // Select creates a reusable stream of transformed values. -func Select(subject Enumerable, transform Transform) Enumerable { - return selecter{ +func Select[T any, E any](subject Enumerable[T], transform Transform[T, E]) Enumerable[E] { + return selecter[T, E]{ original: subject, transform: transform, } } // Select iterates over a list and returns a transformed item. -func (iter Enumerator) Select(transform Transform) Enumerator { - retval := make(chan interface{}) +// +// This is commented out because Go 1.18 added support for +// +//func (iter Enumerator[T]) Select[E any](transform Transform[T, E]) Enumerator[E] { +// retval := make(chan interface{}) +// +// go func() { +// for item := range iter { +// retval <- transform(item) +// } +// close(retval) +// }() +// +// return retval +//} + +type selectManyer[T any, E any] struct { + original Enumerable[T] + toMany Unfolder[T, E] +} + +func (s selectManyer[T, E]) Enumerate(cancel <-chan struct{}) Enumerator[E] { + retval := make(chan E) go func() { - for item := range iter { - retval <- transform(item) + for parent := range s.original.Enumerate(cancel) { + for child := range s.toMany(parent) { + retval <- child + } } close(retval) }() - return retval } -type selectManyer struct { - original Enumerable - toMany Unfolder -} - -func (s selectManyer) Enumerate(cancel <-chan struct{}) Enumerator { - return s.original.Enumerate(cancel).SelectMany(s.toMany) -} - // SelectMany allows for unfolding of values. -func SelectMany(subject Enumerable, toMany Unfolder) Enumerable { - return selectManyer{ +func SelectMany[T any, E any](subject Enumerable[T], toMany Unfolder[T, E]) Enumerable[E] { + return selectManyer[T, E]{ original: subject, toMany: toMany, } } -// SelectMany allows for flattening of data structures. -func (iter Enumerator) SelectMany(lister Unfolder) Enumerator { - retval := make(chan interface{}) - - go func() { - for parent := range iter { - for child := range lister(parent) { - retval <- child - } - } - close(retval) - }() - - return retval -} +//// SelectMany allows for flattening of data structures. +//func (iter Enumerator[T]) SelectMany[E any](lister Unfolder[T, E]) Enumerator[E] { +// retval := make(chan E) +// +// go func() { +// for parent := range iter { +// for child := range lister(parent) { +// retval <- child +// } +// } +// close(retval) +// }() +// +// return retval +//} // Single retreives the only element from a list, or returns nil and an error. -func Single(iter Enumerable) (retval interface{}, err error) { +func Single[T any](iter Enumerable[T]) (retval T, err error) { done := make(chan struct{}) defer close(done) @@ -448,7 +452,7 @@ func Single(iter Enumerable) (retval interface{}, err error) { retval = entry err = nil } else { - retval = nil + retval = *new(T) err = errMultipleElements break } @@ -460,32 +464,32 @@ func Single(iter Enumerable) (retval interface{}, err error) { // Singlep retrieces the only element from a list that matches a criteria. If // no match is found, or two or more are found, `Singlep` returns nil and an // error. -func Singlep(iter Enumerable, pred Predicate) (retval interface{}, err error) { +func Singlep[T any](iter Enumerable[T], pred Predicate[T]) (retval T, err error) { iter = Where(iter, pred) return Single(iter) } -type skipper struct { - original Enumerable +type skipper[T any] struct { + original Enumerable[T] skipCount uint } -func (s skipper) Enumerate(cancel <-chan struct{}) Enumerator { +func (s skipper[T]) Enumerate(cancel <-chan struct{}) Enumerator[T] { return s.original.Enumerate(cancel).Skip(s.skipCount) } // Skip creates a reusable stream which will skip the first `n` elements before iterating // over the rest of the elements in an Enumerable. -func Skip(subject Enumerable, n uint) Enumerable { - return skipper{ +func Skip[T any](subject Enumerable[T], n uint) Enumerable[T] { + return skipper[T]{ original: subject, skipCount: n, } } // Skip retreives all elements after the first 'n' elements. -func (iter Enumerator) Skip(n uint) Enumerator { - results := make(chan interface{}) +func (iter Enumerator[T]) Skip(n uint) Enumerator[T] { + results := make(chan T) go func() { defer close(results) @@ -505,11 +509,11 @@ func (iter Enumerator) Skip(n uint) Enumerator { // splitN creates N Enumerators, each will be a subset of the original Enumerator and will have // distinct populations from one another. -func (iter Enumerator) splitN(operation Transform, n uint) []Enumerator { - results, cast := make([]chan interface{}, n, n), make([]Enumerator, n, n) +func splitN[T any, E any](iter Enumerator[T], operation Transform[T, E], n uint) []Enumerator[E] { + results, cast := make([]chan E, n, n), make([]Enumerator[E], n, n) for i := uint(0); i < n; i++ { - results[i] = make(chan interface{}) + results[i] = make(chan E) cast[i] = results[i] } @@ -531,26 +535,26 @@ func (iter Enumerator) splitN(operation Transform, n uint) []Enumerator { return cast } -type taker struct { - original Enumerable +type taker[T any] struct { + original Enumerable[T] n uint } -func (t taker) Enumerate(cancel <-chan struct{}) Enumerator { +func (t taker[T]) Enumerate(cancel <-chan struct{}) Enumerator[T] { return t.original.Enumerate(cancel).Take(t.n) } // Take retreives just the first `n` elements from an Enumerable. -func Take(subject Enumerable, n uint) Enumerable { - return taker{ +func Take[T any](subject Enumerable[T], n uint) Enumerable[T] { + return taker[T]{ original: subject, n: n, } } // Take retreives just the first 'n' elements from an Enumerator. -func (iter Enumerator) Take(n uint) Enumerator { - results := make(chan interface{}) +func (iter Enumerator[T]) Take(n uint) Enumerator[T] { + results := make(chan T) go func() { defer close(results) @@ -567,26 +571,26 @@ func (iter Enumerator) Take(n uint) Enumerator { return results } -type takeWhiler struct { - original Enumerable - criteria func(interface{}, uint) bool +type takeWhiler[T any] struct { + original Enumerable[T] + criteria func(T, uint) bool } -func (tw takeWhiler) Enumerate(cancel <-chan struct{}) Enumerator { +func (tw takeWhiler[T]) Enumerate(cancel <-chan struct{}) Enumerator[T] { return tw.original.Enumerate(cancel).TakeWhile(tw.criteria) } // TakeWhile creates a reusable stream which will halt once some criteria is no longer met. -func TakeWhile(subject Enumerable, criteria func(interface{}, uint) bool) Enumerable { - return takeWhiler{ +func TakeWhile[T any](subject Enumerable[T], criteria func(T, uint) bool) Enumerable[T] { + return takeWhiler[T]{ original: subject, criteria: criteria, } } // TakeWhile continues returning items as long as 'criteria' holds true. -func (iter Enumerator) TakeWhile(criteria func(interface{}, uint) bool) Enumerator { - results := make(chan interface{}) +func (iter Enumerator[T]) TakeWhile(criteria func(T, uint) bool) Enumerator[T] { + results := make(chan T) go func() { defer close(results) @@ -604,8 +608,8 @@ func (iter Enumerator) TakeWhile(criteria func(interface{}, uint) bool) Enumerat } // Tee creates two Enumerators which will have identical contents as one another. -func (iter Enumerator) Tee() (Enumerator, Enumerator) { - left, right := make(chan interface{}), make(chan interface{}) +func (iter Enumerator[T]) Tee() (Enumerator[T], Enumerator[T]) { + left, right := make(chan T), make(chan T) go func() { for entry := range iter { @@ -620,26 +624,26 @@ func (iter Enumerator) Tee() (Enumerator, Enumerator) { } // ToSlice places all iterated over values in a Slice for easy consumption. -func ToSlice(iter Enumerable) []interface{} { +func ToSlice[T any](iter Enumerable[T]) []T { return iter.Enumerate(nil).ToSlice() } // ToSlice places all iterated over values in a Slice for easy consumption. -func (iter Enumerator) ToSlice() []interface{} { - retval := make([]interface{}, 0) +func (iter Enumerator[T]) ToSlice() []T { + retval := make([]T, 0) for entry := range iter { retval = append(retval, entry) } return retval } -type wherer struct { - original Enumerable - filter Predicate +type wherer[T any] struct { + original Enumerable[T] + filter Predicate[T] } -func (w wherer) Enumerate(cancel <-chan struct{}) Enumerator { - retval := make(chan interface{}) +func (w wherer[T]) Enumerate(cancel <-chan struct{}) Enumerator[T] { + retval := make(chan T) go func() { defer close(retval) @@ -654,8 +658,8 @@ func (w wherer) Enumerate(cancel <-chan struct{}) Enumerator { } // Where creates a reusable means of filtering a stream. -func Where(original Enumerable, p Predicate) Enumerable { - return wherer{ +func Where[T any](original Enumerable[T], p Predicate[T]) Enumerable[T] { + return wherer[T]{ original: original, filter: p, } @@ -663,8 +667,8 @@ func Where(original Enumerable, p Predicate) Enumerable { // Where iterates over a list and returns only the elements that satisfy a // predicate. -func (iter Enumerator) Where(predicate Predicate) Enumerator { - retval := make(chan interface{}) +func (iter Enumerator[T]) Where(predicate Predicate[T]) Enumerator[T] { + retval := make(chan T) go func() { for item := range iter { if predicate(item) { @@ -679,13 +683,13 @@ func (iter Enumerator) Where(predicate Predicate) Enumerator { // UCount iterates over a list and keeps a running tally of the number of elements // satisfy a predicate. -func UCount(iter Enumerable, p Predicate) uint { +func UCount[T any](iter Enumerable[T], p Predicate[T]) uint { return iter.Enumerate(nil).UCount(p) } // UCount iterates over a list and keeps a running tally of the number of elements // satisfy a predicate. -func (iter Enumerator) UCount(p Predicate) uint { +func (iter Enumerator[T]) UCount(p Predicate[T]) uint { tally := uint(0) for entry := range iter { if p(entry) { @@ -696,12 +700,12 @@ func (iter Enumerator) UCount(p Predicate) uint { } // UCountAll iterates over a list and keeps a running tally of how many it's seen. -func UCountAll(iter Enumerable) uint { +func UCountAll[T any](iter Enumerable[T]) uint { return iter.Enumerate(nil).UCountAll() } // UCountAll iterates over a list and keeps a running tally of how many it's seen. -func (iter Enumerator) UCountAll() uint { +func (iter Enumerator[T]) UCountAll() uint { tally := uint(0) for range iter { tally++ diff --git a/query_examples_test.go b/query_examples_test.go index 80822e6..2e7b5d3 100644 --- a/query_examples_test.go +++ b/query_examples_test.go @@ -7,19 +7,20 @@ import ( "github.com/marstr/collection" ) -func ExampleAsEnumerable() { +func ExampleEnumerableSlice_Enumerate() { // When a single value is provided, and it is an array or slice, each value in the array or slice is treated as an enumerable value. - original := []int{1, 2, 3, 4, 5} - wrapped := collection.AsEnumerable(original) + originalInts := []int{1, 2, 3, 4, 5} + wrappedInts := collection.EnumerableSlice[int](originalInts) - for entry := range wrapped.Enumerate(nil) { + for entry := range wrappedInts.Enumerate(nil) { fmt.Print(entry) } fmt.Println() - // When multiple values are provided, regardless of their type, they are each treated as enumerable values. - wrapped = collection.AsEnumerable("red", "orange", "yellow", "green", "blue", "indigo", "violet") - for entry := range wrapped.Enumerate(nil) { + // It's easy to convert arrays to slices for these enumerations as well. + originalStrings := [7]string{"red", "orange", "yellow", "green", "blue", "indigo", "violet"} + wrappedStrings := collection.EnumerableSlice[string](originalStrings[:]) + for entry := range wrappedStrings.Enumerate(nil) { fmt.Println(entry) } // Output: @@ -35,7 +36,7 @@ func ExampleAsEnumerable() { func ExampleEnumerator_Count() { subject := collection.AsEnumerable("str1", "str1", "str2") - count1 := subject.Enumerate(nil).Count(func(a interface{}) bool { + count1 := subject.Enumerate(nil).Count(func(a string) bool { return a == "str1" }) fmt.Println(count1) @@ -56,20 +57,20 @@ func ExampleEnumerator_ElementAt() { } func ExampleFirst() { - empty := collection.NewQueue() + empty := collection.NewQueue[int]() notEmpty := collection.NewQueue(1, 2, 3, 4) - fmt.Println(collection.First(empty)) - fmt.Println(collection.First(notEmpty)) + fmt.Println(collection.First[int](empty)) + fmt.Println(collection.First[int](notEmpty)) // Output: - // enumerator encountered no elements + // 0 enumerator encountered no elements // 1 } func ExampleLast() { subject := collection.NewList(1, 2, 3, 4) - fmt.Println(collection.Last(subject)) + fmt.Println(collection.Last[int](subject)) // Output: 4 } @@ -85,13 +86,13 @@ func ExampleMerge() { c := collection.Merge(a, b) sum := 0 for x := range c.Enumerate(nil) { - sum += x.(int) + sum += x } fmt.Println(sum) product := 1 for y := range a.Enumerate(nil) { - product *= y.(int) + product *= y } fmt.Println(product) // Output: @@ -109,68 +110,57 @@ func ExampleEnumerator_Reverse() { func ExampleSelect() { const offset = 'a' - 1 - subject := collection.AsEnumerable('a', 'b', 'c') - subject = collection.Select(subject, func(a interface{}) interface{} { - return a.(rune) - offset + subject := collection.AsEnumerable[rune]('a', 'b', 'c') + subject = collection.Select(subject, func(a rune) rune { + return a - offset }) fmt.Println(collection.ToSlice(subject)) // Output: [1 2 3] } -func ExampleEnumerator_Select() { - subject := collection.AsEnumerable('a', 'b', 'c').Enumerate(nil) - const offset = 'a' - 1 - results := subject.Select(func(a interface{}) interface{} { - return a.(rune) - offset - }) - - fmt.Println(results.ToSlice()) - // Output: [1 2 3] -} - func ExampleEnumerator_SelectMany() { type BrewHouse struct { Name string - Beers collection.Enumerable + Beers collection.Enumerable[string] } breweries := collection.AsEnumerable( BrewHouse{ "Mac & Jacks", - collection.AsEnumerable( + collection.AsEnumerable[string]( "African Amber", "Ibis IPA", ), }, BrewHouse{ "Post Doc", - collection.AsEnumerable( + collection.AsEnumerable[string]( "Prereq Pale", ), }, BrewHouse{ "Resonate", - collection.AsEnumerable( + collection.AsEnumerable[string]( "Comfortably Numb IPA", "Lithium Altbier", ), }, BrewHouse{ "Triplehorn", - collection.AsEnumerable( + collection.AsEnumerable[string]( "Samson", "Pepper Belly", ), }, ) - beers := breweries.Enumerate(nil).SelectMany(func(brewer interface{}) collection.Enumerator { - return brewer.(BrewHouse).Beers.Enumerate(nil) + beers := collection.SelectMany(breweries, func(brewer BrewHouse) collection.Enumerator[string] { + return brewer.Beers.Enumerate(nil) }) - for beer := range beers { + for beer := range beers.Enumerate(nil) { fmt.Println(beer) } @@ -238,8 +228,8 @@ func ExampleEnumerator_Take() { } func ExampleTakeWhile() { - taken := collection.TakeWhile(collection.Fibonacci, func(x interface{}, n uint) bool { - return x.(int) < 10 + taken := collection.TakeWhile(collection.Fibonacci, func(x, n uint) bool { + return x < 10 }) for entry := range taken.Enumerate(nil) { fmt.Println(entry) @@ -254,22 +244,6 @@ func ExampleTakeWhile() { // 8 } -func ExampleEnumerator_TakeWhile() { - taken := collection.Fibonacci.Enumerate(nil).TakeWhile(func(x interface{}, n uint) bool { - return x.(int) < 6 - }) - for entry := range taken { - fmt.Println(entry) - } - // Output: - // 0 - // 1 - // 1 - // 2 - // 3 - // 5 -} - func ExampleEnumerator_Tee() { base := collection.AsEnumerable(1, 2, 4) left, right := base.Enumerate(nil).Tee() @@ -279,7 +253,7 @@ func ExampleEnumerator_Tee() { product := 1 go func() { for x := range left { - product *= x.(int) + product *= x } wg.Done() }() @@ -287,7 +261,7 @@ func ExampleEnumerator_Tee() { sum := 0 go func() { for x := range right { - sum += x.(int) + sum += x } wg.Done() }() @@ -302,8 +276,8 @@ func ExampleEnumerator_Tee() { } func ExampleUCount() { - subject := collection.NewStack(9, 'a', "str1") - result := collection.UCount(subject, func(a interface{}) bool { + subject := collection.NewStack[any](9, 'a', "str1") + result := collection.UCount[any](subject, func(a interface{}) bool { _, ok := a.(string) return ok }) @@ -312,8 +286,8 @@ func ExampleUCount() { } func ExampleEnumerator_UCount() { - subject := collection.AsEnumerable("str1", "str1", "str2") - count1 := subject.Enumerate(nil).UCount(func(a interface{}) bool { + subject := collection.EnumerableSlice[string]([]string{"str1", "str1", "str2"}) + count1 := subject.Enumerate(nil).UCount(func(a string) bool { return a == "str1" }) fmt.Println(count1) @@ -322,12 +296,12 @@ func ExampleEnumerator_UCount() { func ExampleUCountAll() { subject := collection.NewStack(8, 9, 10, 11) - fmt.Println(collection.UCountAll(subject)) + fmt.Println(collection.UCountAll[int](subject)) // Output: 4 } func ExampleEnumerator_UCountAll() { - subject := collection.AsEnumerable('a', 2, "str1") + subject := collection.EnumerableSlice[any]([]interface{}{'a', 2, "str1"}) fmt.Println(subject.Enumerate(nil).UCountAll()) // Output: 3 } @@ -335,16 +309,17 @@ func ExampleEnumerator_UCountAll() { func ExampleEnumerator_Where() { done := make(chan struct{}) defer close(done) - results := collection.Fibonacci.Enumerate(done).Where(func(a interface{}) bool { - return a.(int) > 8 + results := collection.Fibonacci.Enumerate(done).Where(func(a uint) bool { + return a > 8 }).Take(3) fmt.Println(results.ToSlice()) // Output: [13 21 34] } func ExampleWhere() { - results := collection.Where(collection.AsEnumerable(1, 2, 3, 4, 5), func(a interface{}) bool { - return a.(int) < 3 + nums := collection.EnumerableSlice[int]([]int{1, 2, 3, 4, 5}) + results := collection.Where[int](nums, func(a int) bool { + return a < 3 }) fmt.Println(collection.ToSlice(results)) // Output: [1 2] diff --git a/query_test.go b/query_test.go index 993353f..2a9634e 100644 --- a/query_test.go +++ b/query_test.go @@ -6,47 +6,47 @@ import ( ) func Test_Empty(t *testing.T) { - if Any(Empty) { + if Any(Empty[int]()) { t.Log("empty should not have any elements") t.Fail() } - if CountAll(Empty) != 0 { + if CountAll(Empty[int]()) != 0 { t.Log("empty should have counted to zero elements") t.Fail() } - alwaysTrue := func(x interface{}) bool { + alwaysTrue := func(x int) bool { return true } - if Count(Empty, alwaysTrue) != 0 { + if Count(Empty[int](), alwaysTrue) != 0 { t.Log("empty should have counted to zero even when discriminating") t.Fail() } } func BenchmarkEnumerator_Sum(b *testing.B) { - nums := AsEnumerable(getInitializedSequentialArray()...) + var nums EnumerableSlice[int] = getInitializedSequentialArray[int]() b.ResetTimer() for i := 0; i < b.N; i++ { - for range nums.Enumerate(nil).Select(sleepIdentity) { + slowNums := Select[int, int](nums, sleepIdentity[int]) + for range slowNums.Enumerate(nil) { // Intentionally Left Blank } } } -func sleepIdentity(num interface{}) interface{} { +func sleepIdentity[T any](val T) T { time.Sleep(2 * time.Millisecond) - return Identity(num) + return val } -func getInitializedSequentialArray() []interface{} { - - rawNums := make([]interface{}, 1000, 1000) +func getInitializedSequentialArray[T ~int]() []T { + rawNums := make([]T, 1000, 1000) for i := range rawNums { - rawNums[i] = i + 1 + rawNums[i] = T(i + 1) } return rawNums } diff --git a/queue.go b/queue.go index 0ca83a7..f2c4953 100644 --- a/queue.go +++ b/queue.go @@ -5,45 +5,45 @@ import ( ) // Queue implements a basic FIFO structure. -type Queue struct { - underlyer *LinkedList +type Queue[T any] struct { + underlyer *LinkedList[T] key sync.RWMutex } // NewQueue instantiates a new FIFO structure. -func NewQueue(entries ...interface{}) *Queue { - retval := &Queue{ - underlyer: NewLinkedList(entries...), +func NewQueue[T any](entries ...T) *Queue[T] { + retval := &Queue[T]{ + underlyer: NewLinkedList[T](entries...), } return retval } // Add places an item at the back of the Queue. -func (q *Queue) Add(entry interface{}) { +func (q *Queue[T]) Add(entry T) { q.key.Lock() defer q.key.Unlock() if nil == q.underlyer { - q.underlyer = NewLinkedList() + q.underlyer = NewLinkedList[T]() } q.underlyer.AddBack(entry) } // Enumerate peeks at each element of this queue without mutating it. -func (q *Queue) Enumerate(cancel <-chan struct{}) Enumerator { +func (q *Queue[T]) Enumerate(cancel <-chan struct{}) Enumerator[T] { q.key.RLock() defer q.key.RUnlock() return q.underlyer.Enumerate(cancel) } // IsEmpty tests the Queue to determine if it is populate or not. -func (q *Queue) IsEmpty() bool { +func (q *Queue[T]) IsEmpty() bool { q.key.RLock() defer q.key.RUnlock() return q.underlyer == nil || q.underlyer.IsEmpty() } // Length returns the number of items in the Queue. -func (q *Queue) Length() uint { +func (q *Queue[T]) Length() uint { q.key.RLock() defer q.key.RUnlock() if nil == q.underlyer { @@ -53,32 +53,32 @@ func (q *Queue) Length() uint { } // Next removes and returns the next item in the Queue. -func (q *Queue) Next() (interface{}, bool) { +func (q *Queue[T]) Next() (T, bool) { q.key.Lock() defer q.key.Unlock() if q.underlyer == nil { - return nil, false + return *new(T), false } return q.underlyer.RemoveFront() } // Peek returns the next item in the Queue without removing it. -func (q *Queue) Peek() (interface{}, bool) { +func (q *Queue[T]) Peek() (T, bool) { q.key.RLock() defer q.key.RUnlock() if q.underlyer == nil { - return nil, false + return *new(T), false } return q.underlyer.PeekFront() } // ToSlice converts a Queue into a slice. -func (q *Queue) ToSlice() []interface{} { +func (q *Queue[T]) ToSlice() []T { q.key.RLock() defer q.key.RUnlock() if q.underlyer == nil { - return []interface{}{} + return []T{} } return q.underlyer.ToSlice() } diff --git a/queue_test.go b/queue_test.go index fcece89..2d1f13a 100644 --- a/queue_test.go +++ b/queue_test.go @@ -6,7 +6,7 @@ import ( ) func ExampleQueue_Add() { - subject := &Queue{} + subject := &Queue[int]{} subject.Add(1) subject.Add(2) res, _ := subject.Peek() @@ -15,7 +15,7 @@ func ExampleQueue_Add() { } func ExampleNewQueue() { - empty := NewQueue() + empty := NewQueue[int]() fmt.Println(empty.Length()) populated := NewQueue(1, 2, 3, 5, 8, 13) @@ -26,7 +26,7 @@ func ExampleNewQueue() { } func ExampleQueue_IsEmpty() { - empty := NewQueue() + empty := NewQueue[int]() fmt.Println(empty.IsEmpty()) populated := NewQueue(1, 2, 3, 5, 8, 13) @@ -52,7 +52,7 @@ func ExampleQueue_Next() { } func TestQueue_Length(t *testing.T) { - empty := NewQueue() + empty := NewQueue[int]() if count := empty.Length(); count != 0 { t.Logf("got: %d\nwant: %d", count, 0) t.Fail() @@ -74,7 +74,7 @@ func TestQueue_Length(t *testing.T) { } func TestQueue_Length_NonConstructed(t *testing.T) { - subject := &Queue{} + subject := &Queue[int]{} if got := subject.Length(); got != 0 { t.Logf("got: %d\nwant: %d", got, 0) t.Fail() @@ -82,12 +82,13 @@ func TestQueue_Length_NonConstructed(t *testing.T) { } func TestQueue_Next_NonConstructed(t *testing.T) { - subject := &Queue{} + const expected = 0 + subject := &Queue[int]{} if got, ok := subject.Next(); ok { t.Logf("Next should not have been ok") t.Fail() - } else if got != nil { - t.Logf("got: %v\nwant: %v", got, nil) + } else if got != expected { + t.Logf("got: %v\nwant: %v", got, expected) t.Fail() } } @@ -107,12 +108,13 @@ func TestQueue_Peek_DoesntRemove(t *testing.T) { } func TestQueue_Peek_NonConstructed(t *testing.T) { - subject := &Queue{} + const expected = 0 + subject := &Queue[int]{} if got, ok := subject.Peek(); ok { t.Logf("Peek should not have been ok") t.Fail() - } else if got != nil { - t.Logf("got: %v\nwant: %v", got, nil) + } else if got != expected { + t.Logf("got: %v\nwant: %v", got, expected) t.Fail() } } @@ -130,7 +132,7 @@ func TestQueue_ToSlice(t *testing.T) { } func TestQueue_ToSlice_Empty(t *testing.T) { - subject := NewQueue() + subject := NewQueue[int]() result := subject.ToSlice() if len(result) != 0 { @@ -146,7 +148,7 @@ func TestQueue_ToSlice_Empty(t *testing.T) { } func TestQueue_ToSlice_NotConstructed(t *testing.T) { - subject := &Queue{} + subject := &Queue[int]{} result := subject.ToSlice() if len(result) != 0 { diff --git a/stack.go b/stack.go index ecb9cc6..7004d78 100644 --- a/stack.go +++ b/stack.go @@ -5,15 +5,15 @@ import ( ) // Stack implements a basic FILO structure. -type Stack struct { - underlyer *LinkedList +type Stack[T any] struct { + underlyer *LinkedList[T] key sync.RWMutex } // NewStack instantiates a new FILO structure. -func NewStack(entries ...interface{}) *Stack { - retval := &Stack{} - retval.underlyer = NewLinkedList() +func NewStack[T any](entries ...T) *Stack[T] { + retval := &Stack[T]{} + retval.underlyer = NewLinkedList[T]() for _, entry := range entries { retval.Push(entry) @@ -22,7 +22,7 @@ func NewStack(entries ...interface{}) *Stack { } // Enumerate peeks at each element in the stack without mutating it. -func (stack *Stack) Enumerate(cancel <-chan struct{}) Enumerator { +func (stack *Stack[T]) Enumerate(cancel <-chan struct{}) Enumerator[T] { stack.key.RLock() defer stack.key.RUnlock() @@ -30,43 +30,43 @@ func (stack *Stack) Enumerate(cancel <-chan struct{}) Enumerator { } // IsEmpty tests the Stack to determine if it is populate or not. -func (stack *Stack) IsEmpty() bool { +func (stack *Stack[T]) IsEmpty() bool { stack.key.RLock() defer stack.key.RUnlock() return stack.underlyer == nil || stack.underlyer.IsEmpty() } // Push adds an entry to the top of the Stack. -func (stack *Stack) Push(entry interface{}) { +func (stack *Stack[T]) Push(entry T) { stack.key.Lock() defer stack.key.Unlock() if nil == stack.underlyer { - stack.underlyer = NewLinkedList() + stack.underlyer = NewLinkedList[T]() } stack.underlyer.AddFront(entry) } // Pop returns the entry at the top of the Stack then removes it. -func (stack *Stack) Pop() (interface{}, bool) { +func (stack *Stack[T]) Pop() (T, bool) { stack.key.Lock() defer stack.key.Unlock() if nil == stack.underlyer { - return nil, false + return *new(T), false } return stack.underlyer.RemoveFront() } // Peek returns the entry at the top of the Stack without removing it. -func (stack *Stack) Peek() (interface{}, bool) { +func (stack *Stack[T]) Peek() (T, bool) { stack.key.RLock() defer stack.key.RUnlock() return stack.underlyer.PeekFront() } // Size returns the number of entries populating the Stack. -func (stack *Stack) Size() uint { +func (stack *Stack[T]) Size() uint { stack.key.RLock() defer stack.key.RUnlock() if stack.underlyer == nil { diff --git a/stack_test.go b/stack_test.go index f65769e..ccc7fe7 100644 --- a/stack_test.go +++ b/stack_test.go @@ -6,7 +6,7 @@ import ( ) func TestStack_NewStack_FromEmpty(t *testing.T) { - subject := NewStack() + subject := NewStack[string]() subject.Push("alfa") subject.Push("bravo") subject.Push("charlie") @@ -42,7 +42,7 @@ func ExampleNewStack() { } func TestStack_Push_NonConstructor(t *testing.T) { - subject := &Stack{} + subject := &Stack[int]{} sizeAssertion := func(want uint) { if got := subject.Size(); got != want { @@ -67,12 +67,12 @@ func TestStack_Push_NonConstructor(t *testing.T) { } func TestStack_Pop_NonConstructorEmpty(t *testing.T) { - subject := &Stack{} + subject := &Stack[string]{} if result, ok := subject.Pop(); ok { t.Logf("Pop should not have been okay") t.Fail() - } else if result != nil { + } else if result != "" { t.Logf("got: %v\nwant: %v", result, nil) } }