-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy pathenum_test.go
219 lines (189 loc) · 4.49 KB
/
enum_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
package dbcommon_test
import (
"context"
"database/sql/driver"
"fmt"
"github.com/Flaque/filet"
"github.com/ipfs/go-log"
. "github.com/stretchr/testify/assert"
"github.com/synapsecns/sanguine/core/dbcommon"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"os"
"testing"
)
// testDBLogger is the "dbcommon" logger passed to dbcommon.GetGormLogger when
// RunEnumExample opens its gorm connection.
var (
	testDBLogger = log.Logger("dbcommon")
)
// TestEnum round-trips every Fruit through a temporary SQLite database via
// RunEnumExample and checks that each row read back carries the value that
// AllFruits lists at the same position.
func (s DbSuite) TestEnum() {
	res, err := RunEnumExample(s.GetTestContext(), fmt.Sprintf("%s/sql.db", filet.TmpDir(s.T(), "")))
	if !NoError(s.T(), err) {
		return
	}

	// Check the length before indexing: a short result must fail the
	// assertion rather than panic with an index out of range.
	if !Len(s.T(), res, len(AllFruits)) {
		return
	}

	for i, fruit := range AllFruits {
		Equal(s.T(), fruit.Int(), res[i].Fruit.Int())
	}
}
// ExampleEnum shows a dbcommon.EnumInter implementation (Fruit, defined in
// this file) being stored in and read back from a SQLite database via gorm,
// printing the name of every row it reads.
//
// NOTE: the database file lives at os.TempDir()/sql.db and is never removed,
// so rows inserted by earlier runs are printed too. There is no "// Output:"
// comment, so `go test` compiles this example but does not execute it.
func ExampleEnum() {
	dbPath := fmt.Sprintf("%s/sql.db", os.TempDir())

	rows, err := RunEnumExample(context.Background(), dbPath)
	if err != nil {
		panic(err)
	}

	for i := range rows {
		fmt.Printf("got result %s \n", rows[i].Fruit.String())
	}
}
// RunEnumExample holds the logic shared by TestEnum and ExampleEnum. It opens
// (or creates) the SQLite database file at dbDir, auto-migrates InventoryModel,
// inserts one row per entry in AllFruits, and returns every row read back.
//
// Despite its name, dbDir is the path of the database file itself, not a
// directory. The underlying connection is closed before returning; the
// returned rows are fully loaded, so they remain usable afterwards.
func RunEnumExample(ctx context.Context, dbDir string) (res []InventoryModel, err error) {
	gdb, err := gorm.Open(sqlite.Open(dbDir), &gorm.Config{
		Logger: dbcommon.GetGormLogger(testDBLogger),
	})
	if err != nil {
		return res, fmt.Errorf("could not open db: %w", err)
	}

	// Release the connection pool (and with it the open database file) when
	// done, so temporary directories holding the file can be cleaned up.
	sqlDB, err := gdb.DB()
	if err != nil {
		return res, fmt.Errorf("could not get underlying db: %w", err)
	}
	defer func() {
		// Best effort: results are already in memory, so a close failure
		// must not turn a successful run into an error.
		_ = sqlDB.Close()
	}()

	// migrate the inventory model
	err = gdb.WithContext(ctx).AutoMigrate(&InventoryModel{})
	if err != nil {
		return res, fmt.Errorf("could not migrate db: %w", err)
	}

	for _, fruit := range AllFruits {
		// gorm reports failures on the returned *gorm.DB, not via err, so
		// wrap tx.Error (err is nil here and would hide the real cause).
		if tx := gdb.WithContext(ctx).Create(&InventoryModel{Fruit: fruit}); tx.Error != nil {
			return res, fmt.Errorf("could not insert fruit: %w", tx.Error)
		}
	}

	if tx := gdb.WithContext(ctx).Find(&res); tx.Error != nil {
		return res, fmt.Errorf("could not query db: %w", tx.Error)
	}
	return res, nil
}
// InventoryModel is an example gorm model for an inventory table of fruit.
// It embeds gorm.Model for gorm's standard bookkeeping columns.
type InventoryModel struct {
	gorm.Model
	// Fruit is the fruit stored in this row. Its column type and conversion
	// to and from the database come from Fruit's GormDataType, Scan and
	// Value methods.
	Fruit Fruit
}
// Fruit is an example enum that is persisted by InventoryModel.
type Fruit uint8

// Fruit values are written as explicit integers rather than with iota. These
// numbers are what gets stored in the database, and iota would tie each value
// to its declaration position, so adding or reordering constants could
// silently renumber existing rows.
const (
	// Apple is an example Fruit value.
	Apple Fruit = 0
	// Pear is an example Fruit value.
	Pear Fruit = 1
)

// AllFruits lists every defined Fruit, in ascending value order.
var AllFruits = []Fruit{Apple, Pear}

// String returns the fruit's name, or "" for a value that has no name.
// In production code this method should be generated with stringer instead;
// see https://pkg.go.dev/golang.org/x/tools/cmd/stringer for details.
func (f Fruit) String() string {
	names := [...]string{
		Apple: "Apple",
		Pear:  "Pear",
	}
	if int(f) < len(names) {
		return names[f]
	}
	return ""
}

// Int returns the underlying integer value of the fruit.
func (f Fruit) Int() uint8 {
	return uint8(f)
}
// GormDataType returns the column data type gorm should use for Fruit,
// which is dbcommon.EnumDataType.
func (f Fruit) GormDataType() string {
	return dbcommon.EnumDataType
}

// Scan implements sql.Scanner: it decodes a value read from the database
// with dbcommon.EnumScan and stores the result in f.
func (f *Fruit) Scan(src interface{}) error {
	res, err := dbcommon.EnumScan(src)
	if err != nil {
		return fmt.Errorf("could not scan: %w", err)
	}
	*f = Fruit(res)
	return nil
}

// Value implements driver.Valuer by converting f with dbcommon.EnumValue.
//
// It uses a value receiver, as sql.NullString does, so that both Fruit and
// *Fruit satisfy driver.Valuer. With a pointer receiver a plain Fruit field
// (as in InventoryModel) is not a Valuer, so database/sql bypasses this method.
// With a value receiver, database/sql also turns a nil *Fruit into NULL
// instead of dereferencing it.
func (f Fruit) Value() (driver.Value, error) {
	return dbcommon.EnumValue(f) //nolint:wrapcheck
}

var (
	_ dbcommon.EnumInter = (*Fruit)(nil)
	_ driver.Valuer      = Fruit(0)
)
// testEnum is a minimal dbcommon.EnumInter implementation used as a
// TestEnumValue fixture.
type testEnum uint8

// Fixture values for TestEnumValue.
const (
	testEnumValue1 = testEnum(1)
	testEnumValue2 = testEnum(2)
)

// Int returns the raw integer value of the enum.
func (e testEnum) Int() uint8 { return uint8(e) }
// TestEnumValue checks that dbcommon.EnumValue turns an EnumInter into an
// int64 driver value equal to the enum's Int().
func TestEnumValue(t *testing.T) {
	tests := []struct {
		name    string
		enum    dbcommon.EnumInter
		want    int64
		wantErr error
	}{
		// Subtest names must be unique; duplicates are silently renamed by
		// t.Run (e.g. "...#01"), which makes -run filtering ambiguous.
		{
			name: "Valid enum value 1",
			enum: testEnumValue1,
			want: 1,
		},
		{
			name: "Valid enum value 2",
			enum: testEnumValue2,
			want: 2,
		},
	}

	for i := range tests {
		tt := tests[i]
		t.Run(tt.name, func(t *testing.T) {
			got, err := dbcommon.EnumValue(tt.enum)
			if tt.wantErr != nil {
				ErrorIs(t, err, tt.wantErr)
				return
			}
			NoError(t, err)
			Equal(t, tt.want, got)
		})
	}
}
// TestEnumScan covers dbcommon.EnumScan for integer driver values and for a
// source that cannot be converted, in which case the exact error text is
// asserted.
func TestEnumScan(t *testing.T) {
	cases := []struct {
		name    string
		src     interface{}
		want    uint8
		wantErr string
	}{
		{name: "Valid int64 value", src: int64(1), want: 1},
		{name: "Valid int32 value", src: int32(2), want: 2},
		{
			name:    "Invalid type",
			src:     "invalid",
			want:    0,
			wantErr: "could not scan enum: converting driver.Value type string (\"invalid\") to a int32: invalid syntax",
		},
	}

	for idx := range cases {
		tc := cases[idx]
		t.Run(tc.name, func(t *testing.T) {
			got, err := dbcommon.EnumScan(tc.src)
			if tc.wantErr == "" {
				NoError(t, err)
				Equal(t, tc.want, got)
				return
			}
			Error(t, err)
			EqualError(t, err, tc.wantErr)
		})
	}
}