diff --git a/README.md b/README.md index 19a19a3..0c4da41 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,8 @@ You set the validation rules following the "validate:" tag according to the rule | gte | Check whether value is greater than or equal to the specified value
e.g. `validate:"gte=1"` | | lt | Check whether value is less than the specified value
e.g. `validate:"lt=1"` | | lte | Check whether value is less than or equal to the specified value
e.g. `validate:"lte=1"` | +| min | Check whether value is greater than or equal to the specified value
e.g. `validate:"min=1"` | +| max | Check whether value is less than or equal to the specified value
e.g. `validate:"max=100"` | ## License [MIT License](./LICENSE) diff --git a/csv_test.go b/csv_test.go index f920c4e..3bddccd 100644 --- a/csv_test.go +++ b/csv_test.go @@ -1,6 +1,7 @@ package csv import ( + "bytes" "os" "path/filepath" "testing" @@ -10,81 +11,6 @@ import ( func TestCSV_Decode(t *testing.T) { t.Parallel() - - t.Run("all error: `id,name,age,password` header", func(t *testing.T) { - t.Parallel() - - f, err := os.Open(filepath.Join("testdata", "all_error.csv")) - if err != nil { - t.Fatal(err) - } - - c, err := NewCSV(f) - if err != nil { - t.Fatal(err) - } - - type person struct { - ID int `validate:"numeric,gte=1"` - Name string `validate:"alpha"` - Age int `validate:"numeric,gt=-1,lt=120,gte=0"` - Password string `validate:"required,alphanumeric"` - IsAdmin bool `validate:"boolean"` - Zero int `validate:"numeric,eq=0,lte=1,ne=1"` - } - people := make([]person, 0) - - got := c.Decode(&people) - for i, err := range got { - switch i { - case 0: - if err.Error() != "line:2 column id: target is not greater than or equal to the threshold value: threshold=1.000000, value=0.000000" { - t.Errorf("CSV.Decode() got errors: %v", err) - } - case 1: - if err.Error() != "line:3 column password: target is not an alphanumeric character: value=password-bad" { - t.Errorf("CSV.Decode() got errors: %v", err) - } - case 2: - if err.Error() != "line:4 column password: target is required but is empty: value=" { - t.Errorf("CSV.Decode() got errors: %v", err) - } - case 3: - if err.Error() != "line:5 column name: target is not an alphabetic character: value=1Joyless" { - t.Errorf("CSV.Decode() got errors: %v", err) - } - case 4: - if err.Error() != "line:5 column zero: target is not equal to the threshold value: threshold=0.000000, value=1.000000" { - t.Errorf("CSV.Decode() got errors: %v", err) - } - case 5: - if err.Error() != "line:5 column zero: target is equal to threshold the value: threshold=1.000000, value=1.000000" { - t.Errorf("CSV.Decode() got errors: %v", err) - } - case 6: - if err.Error() != "line:6 column age: target is not less than the threshold value: threshold=120.000000, value=120.000000" { - t.Errorf("CSV.Decode() got errors: %v", err) - } - case 7: - if err.Error() != "line:7 column is_admin: target is not a boolean: value=2" { - t.Errorf("CSV.Decode() got errors: %v", err) - } - case 8: - if err.Error() != "line:8 column age: target is not greater than the threshold value: threshold=-1.000000, value=-1.000000" { - t.Errorf("CSV.Decode() got errors: %v", err) - } - case 9: - if err.Error() != "line:8 column age: target is not greater than or equal to the threshold value: threshold=0.000000, value=-1.000000" { - t.Errorf("CSV.Decode() got errors: %v", err) - } - case 10: - if err.Error() != "line:9 column id: target is not a numeric character: value=a" { - t.Errorf("CSV.Decode() got errors: %v", err) - } - } - } - }) - t.Run("read `id,name,age` header with value", func(t *testing.T) { t.Parallel() @@ -194,4 +120,157 @@ func TestCSV_Decode(t *testing.T) { t.Errorf("CSV.Decode() mismatch (-got +want):\n%s", diff) } }) + + t.Run("validate min, max: success case", func(t *testing.T) { + t.Parallel() + + input := `id,age +1,0 +2,1 +3,120 +4,119 +` + + c, err := NewCSV(bytes.NewBufferString(input)) + if err != nil { + t.Fatal(err) + } + + type person struct { + ID int // no validate + Age int `validate:"min=0,max=120.0"` + } + + people := make([]person, 0) + errs := c.Decode(&people) + if len(errs) != 0 { + t.Errorf("CSV.Decode() got errors: %v", errs) + } + + want := []person{ + {ID: 1, Age: 0}, + {ID: 2, Age: 1}, + {ID: 3, Age: 120}, + {ID: 4, Age: 119}, + } + + if diff := cmp.Diff(people, want); diff != "" { + t.Errorf("CSV.Decode() mismatch (-got +want):\n%s", diff) + } + }) +} + +func Test_ErrCheck(t *testing.T) { + t.Parallel() + + t.Run("error: `id,name,age,password` header", func(t *testing.T) { + t.Parallel() + + f, err := os.Open(filepath.Join("testdata", "all_error.csv")) + if err != nil { + t.Fatal(err) + } + + c, err := NewCSV(f) + if err != nil { + t.Fatal(err) + } + + type person struct { + ID int `validate:"numeric,gte=1"` + Name string `validate:"alpha"` + Age int `validate:"numeric,gt=-1,lt=120,gte=0"` + Password string `validate:"required,alphanumeric"` + IsAdmin bool `validate:"boolean"` + Zero int `validate:"numeric,eq=0,lte=1,ne=1"` + } + people := make([]person, 0) + + got := c.Decode(&people) + for i, err := range got { + switch i { + case 0: + if err.Error() != "line:2 column id: target is not greater than or equal to the threshold value: threshold=1.000000, value=0.000000" { + t.Errorf("CSV.Decode() got errors: %v", err) + } + case 1: + if err.Error() != "line:3 column password: target is not an alphanumeric character: value=password-bad" { + t.Errorf("CSV.Decode() got errors: %v", err) + } + case 2: + if err.Error() != "line:4 column password: target is required but is empty: value=" { + t.Errorf("CSV.Decode() got errors: %v", err) + } + case 3: + if err.Error() != "line:5 column name: target is not an alphabetic character: value=1Joyless" { + t.Errorf("CSV.Decode() got errors: %v", err) + } + case 4: + if err.Error() != "line:5 column zero: target is not equal to the threshold value: threshold=0.000000, value=1.000000" { + t.Errorf("CSV.Decode() got errors: %v", err) + } + case 5: + if err.Error() != "line:5 column zero: target is equal to threshold the value: threshold=1.000000, value=1.000000" { + t.Errorf("CSV.Decode() got errors: %v", err) + } + case 6: + if err.Error() != "line:6 column age: target is not less than the threshold value: threshold=120.000000, value=120.000000" { + t.Errorf("CSV.Decode() got errors: %v", err) + } + case 7: + if err.Error() != "line:7 column is_admin: target is not a boolean: value=2" { + t.Errorf("CSV.Decode() got errors: %v", err) + } + case 8: + if err.Error() != "line:8 column age: target is not greater than the threshold value: threshold=-1.000000, value=-1.000000" { + t.Errorf("CSV.Decode() got errors: %v", err) + } + case 9: + if err.Error() != "line:8 column age: target is not greater than or equal to the threshold value: threshold=0.000000, value=-1.000000" { + t.Errorf("CSV.Decode() got errors: %v", err) + } + case 10: + if err.Error() != "line:9 column id: target is not a numeric character: value=a" { + t.Errorf("CSV.Decode() got errors: %v", err) + } + } + } + }) + + t.Run("validate min, max: error case", func(t *testing.T) { + t.Parallel() + + input := `id,age +1,0 +2,-1 +3,120 +4,120.1 +` + + c, err := NewCSV(bytes.NewBufferString(input)) + if err != nil { + t.Fatal(err) + } + + type person struct { + ID int // no validate + Age int `validate:"min=0,max=120.0"` + } + + people := make([]person, 0) + errs := c.Decode(&people) + + for i, err := range errs { + switch i { + case 0: + if err.Error() != "line:3 column age: target is less than the minimum value: threshold=0.000000, value=-1.000000" { + t.Errorf("CSV.Decode() got errors: %v", err) + } + case 1: + if err.Error() != "line:5 column age: target is greater than the maximum value: threshold=120.000000, value=120.100000" { + t.Errorf("CSV.Decode() got errors: %v", err) + } + } + } + }) } diff --git a/errors.go b/errors.go index 4ae0051..b526ef4 100644 --- a/errors.go +++ b/errors.go @@ -31,4 +31,8 @@ var ( ErrLessThan = errors.New("target is not less than the threshold value") // ErrLessThanEqual is returned when the target is not less than or equal to the value. ErrLessThanEqual = errors.New("target is not less than or equal to the threshold value") + // ErrMin is returned when the target is less than the minimum value. + ErrMin = errors.New("target is less than the minimum value") + // ErrMax is returned when the target is greater than the maximum value. + ErrMax = errors.New("target is greater than the maximum value") ) diff --git a/parser.go b/parser.go index 8b9a634..4680124 100644 --- a/parser.go +++ b/parser.go @@ -103,6 +103,18 @@ func parseValidateTag(tags string) (validators, error) { return nil, err } validatorList = append(validatorList, newLessThanEqualValidator(threshold)) + case strings.HasPrefix(t, minTagValue.String()): + threshold, err := parseThreshold(t) + if err != nil { + return nil, err + } + validatorList = append(validatorList, newMinValidator(threshold)) + case strings.HasPrefix(t, maxTagValue.String()): + threshold, err := parseThreshold(t) + if err != nil { + return nil, err + } + validatorList = append(validatorList, newMaxValidator(threshold)) } } return validatorList, nil diff --git a/tag.go b/tag.go index 3ea0b25..0696a0c 100644 --- a/tag.go +++ b/tag.go @@ -34,6 +34,10 @@ const ( lessThanTagValue tagValue = "lt" // lessThanEqualTagValue is the struct tag name for less than or equal fields. lessThanEqualTagValue tagValue = "lte" + // minTagValue is the struct tag name for minimum fields. + minTagValue tagValue = "min" + // maxTagValue is the struct tag name for maximum fields. + maxTagValue tagValue = "max" ) // String returns the string representation of the tag. diff --git a/validation.go b/validation.go index bbfb9ef..ce5f1c7 100644 --- a/validation.go +++ b/validation.go @@ -300,3 +300,59 @@ func (l *lessThanEqualValidator) Do(target any) error { } return nil } + +// minValidator is a struct that contains the validation rules for a minimum column. +type minValidator struct { + threshold float64 +} + +// newMinValidator returns a new minValidator. +func newMinValidator(threshold float64) *minValidator { + return &minValidator{threshold: threshold} +} + +// Do validates the target is greater than or equal to the threshold. +func (m *minValidator) Do(target any) error { + v, ok := target.(string) + if !ok { + return fmt.Errorf("%w: value=%v", ErrMin, target) //nolint + } + + value, err := strconv.ParseFloat(v, 64) + if err != nil { + return fmt.Errorf("%w: value=%v", ErrMin, target) //nolint + } + + if value < m.threshold { + return fmt.Errorf("%w: threshold=%f, value=%f", ErrMin, m.threshold, value) //nolint + } + return nil +} + +// maxValidator is a struct that contains the validation rules for a maximum column. +type maxValidator struct { + threshold float64 +} + +// newMaxValidator returns a new maxValidator. +func newMaxValidator(threshold float64) *maxValidator { + return &maxValidator{threshold: threshold} +} + +// Do validates the target is less than or equal to the threshold. +func (m *maxValidator) Do(target any) error { + v, ok := target.(string) + if !ok { + return fmt.Errorf("%w: value=%v", ErrMax, target) //nolint + } + + value, err := strconv.ParseFloat(v, 64) + if err != nil { + return fmt.Errorf("%w: value=%v", ErrMax, target) //nolint + } + + if value > m.threshold { + return fmt.Errorf("%w: threshold=%f, value=%f", ErrMax, m.threshold, value) //nolint + } + return nil +}