diff --git a/README.md b/README.md index b82e442..2431e07 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,7 @@ You set the validation rules following the "validate:" tag according to the rule | min | Check whether value is greater than or equal to the specified value
e.g. `validate:"min=1"` | | max | Check whether value is less than or equal to the specified value
e.g. `validate:"max=100"` | | len | Check whether the length of the value is equal to the specified value
e.g. `validate:"len=10"` | +| oneof | Check whether value is included in the specified values
e.g. `validate:"oneof=male female prefer_not_to"` | ## License [MIT License](./LICENSE) diff --git a/csv_test.go b/csv_test.go index 319e97d..bc473ca 100644 --- a/csv_test.go +++ b/csv_test.go @@ -193,6 +193,41 @@ func TestCSV_Decode(t *testing.T) { t.Errorf("CSV.Decode() mismatch (-got +want):\n%s", diff) } }) + + t.Run("validate oneof: success case", func(t *testing.T) { + t.Parallel() + + input := `id,gender +1,male +2,female +3,prefer_not_to +` + c, err := NewCSV(bytes.NewBufferString(input)) + if err != nil { + t.Fatal(err) + } + + type person struct { + ID int // no validate + Gender string `validate:"oneof=male female prefer_not_to"` + } + + people := make([]person, 0) + errs := c.Decode(&people) + if len(errs) != 0 { + t.Errorf("CSV.Decode() got errors: %v", errs) + } + + want := []person{ + {ID: 1, Gender: "male"}, + {ID: 2, Gender: "female"}, + {ID: 3, Gender: "prefer_not_to"}, + } + + if diff := cmp.Diff(people, want); diff != "" { + t.Errorf("CSV.Decode() mismatch (-got +want):\n%s", diff) + } + }) } func Test_ErrCheck(t *testing.T) { @@ -346,4 +381,43 @@ func Test_ErrCheck(t *testing.T) { } } }) + + t.Run("validate oneof: error case", func(t *testing.T) { + t.Parallel() + + input := `id,gender +1,smale +2,child +3,prefer_not_tooa +` + + c, err := NewCSV(bytes.NewBufferString(input)) + if err != nil { + t.Fatal(err) + } + + type person struct { + ID int // no validate + Gender string `validate:"oneof=male female prefer_not_to"` + } + + people := make([]person, 0) + errs := c.Decode(&people) + for i, err := range errs { + switch i { + case 0: + if err.Error() != "line:2 column gender: target is not one of the values: oneof=male female prefer_not_to, value=smale" { + t.Errorf("CSV.Decode() got errors: %v", err) + } + case 1: + if err.Error() != "line:3 column gender: target is not one of the values: oneof=male female prefer_not_to, value=child" { + t.Errorf("CSV.Decode() got errors: %v", err) + } + case 2: + if err.Error() != "line:4 column gender: target is not one of the values: oneof=male female prefer_not_to, value=prefer_not_tooa" { + t.Errorf("CSV.Decode() got errors: %v", err) + } + } + } + }) } diff --git a/errors.go b/errors.go index 02cd5ac..591058f 100644 --- a/errors.go +++ b/errors.go @@ -5,6 +5,11 @@ import "errors" var ( // ErrStructSlicePointer is returned when the value is not a pointer to a struct slice. ErrStructSlicePointer = errors.New("value is not a pointer to a struct slice") + // ErrInvalidOneOfFormat is returned when the target is not one of the values. + ErrInvalidOneOfFormat = errors.New("target is not one of the values") + // ErrInvalidThresholdFormat is returned when the threshold value is not an integer. + ErrInvalidThresholdFormat = errors.New("threshold format is invalid") + // ErrInvalidBoolean is returned when the target is not a boolean. ErrInvalidBoolean = errors.New("target is not a boolean") // ErrInvalidAlphabet is returned when the target is not an alphabetic character. @@ -19,8 +24,6 @@ var ( ErrEqual = errors.New("target is not equal to the threshold value") // ErrInvalidThreshold is returned when the target is not greater than the value. ErrInvalidThreshold = errors.New("threshold value is invalid") - // ErrInvalidThresholdFormat is returned when the threshold value is not an integer. - ErrInvalidThresholdFormat = errors.New("threshold format is invalid") // ErrNotEqual is returned when the target is equal to the value. ErrNotEqual = errors.New("target is equal to threshold the value") // ErrGreaterThan is returned when the target is not greater than the value. @@ -37,4 +40,6 @@ var ( ErrMax = errors.New("target is greater than the maximum value") // ErrLength is returned when the target length is not equal to the value. ErrLength = errors.New("target length is not equal to the threshold value") + // ErrOneOf is returned when the target is not one of the values. + ErrOneOf = errors.New("target is not one of the values") ) diff --git a/parser.go b/parser.go index eaf867a..b58e8d1 100644 --- a/parser.go +++ b/parser.go @@ -121,6 +121,13 @@ func parseValidateTag(tags string) (validators, error) { return nil, err } validatorList = append(validatorList, newLengthValidator(threshold)) + + case strings.HasPrefix(t, oneOfTagValue.String()): + oneOf, err := parseOneOf(t) + if err != nil { + return nil, err + } + validatorList = append(validatorList, newOneOfValidator(oneOf)) } } return validatorList, nil @@ -140,3 +147,14 @@ func parseThreshold(tagValue string) (float64, error) { } return 0, fmt.Errorf("%w: %s", ErrInvalidThresholdFormat, tagValue) } + +// parseOneOf parses the oneOf value. +// tagValue is the value of the struct tag. e.g. oneof=male female prefer_not_to +func parseOneOf(tagValue string) ([]string, error) { + parts := strings.Split(tagValue, "=") + + if len(parts) == 2 { + return strings.Split(parts[1], " "), nil + } + return nil, fmt.Errorf("%w: %s", ErrInvalidOneOfFormat, tagValue) +} diff --git a/tag.go b/tag.go index 85e5e3a..ac32c3f 100644 --- a/tag.go +++ b/tag.go @@ -40,6 +40,8 @@ const ( maxTagValue tagValue = "max" // lengthTagValue is the struct tag name for length fields. lengthTagValue tagValue = "len" + // oneOfTagValue is the struct tag name for one of fields. + oneOfTagValue tagValue = "oneof" ) // String returns the string representation of the tag. diff --git a/validation.go b/validation.go index a17fbf3..8ef68e7 100644 --- a/validation.go +++ b/validation.go @@ -3,6 +3,7 @@ package csv import ( "fmt" "strconv" + "strings" "github.com/rivo/uniseg" ) @@ -382,3 +383,28 @@ func (l *lengthValidator) Do(target any) error { } return nil } + +// oneOfValidator is a struct that contains the validation rules for a one of column. +type oneOfValidator struct { + oneOf []string +} + +// newOneOfValidator returns a new oneOfValidator. +func newOneOfValidator(oneOf []string) *oneOfValidator { + return &oneOfValidator{oneOf: oneOf} +} + +// Do validates the target is one of the oneOf values. +func (o *oneOfValidator) Do(target any) error { + v, ok := target.(string) + if !ok { + return fmt.Errorf("%w: value=%v", ErrOneOf, target) //nolint + } + + for _, s := range o.oneOf { + if v == s { + return nil + } + } + return fmt.Errorf("%w: oneof=%s, value=%v", ErrOneOf, strings.Join(o.oneOf, " "), target) //nolint +}