diff --git a/README.md b/README.md
index b82e442..2431e07 100644
--- a/README.md
+++ b/README.md
@@ -96,6 +96,7 @@ You set the validation rules following the "validate:" tag according to the rule
| min | Check whether value is greater than or equal to the specified value
e.g. `validate:"min=1"` |
| max | Check whether value is less than or equal to the specified value
e.g. `validate:"max=100"` |
| len | Check whether the length of the value is equal to the specified value
e.g. `validate:"len=10"` |
+| oneof | Check whether value is included in the specified values
e.g. `validate:"oneof=male female prefer_not_to"` |
## License
[MIT License](./LICENSE)
diff --git a/csv_test.go b/csv_test.go
index 319e97d..bc473ca 100644
--- a/csv_test.go
+++ b/csv_test.go
@@ -193,6 +193,41 @@ func TestCSV_Decode(t *testing.T) {
t.Errorf("CSV.Decode() mismatch (-got +want):\n%s", diff)
}
})
+
+ t.Run("validate oneof: success case", func(t *testing.T) {
+ t.Parallel()
+
+ input := `id,gender
+1,male
+2,female
+3,prefer_not_to
+`
+ c, err := NewCSV(bytes.NewBufferString(input))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ type person struct {
+ ID int // no validate
+ Gender string `validate:"oneof=male female prefer_not_to"`
+ }
+
+ people := make([]person, 0)
+ errs := c.Decode(&people)
+ if len(errs) != 0 {
+ t.Errorf("CSV.Decode() got errors: %v", errs)
+ }
+
+ want := []person{
+ {ID: 1, Gender: "male"},
+ {ID: 2, Gender: "female"},
+ {ID: 3, Gender: "prefer_not_to"},
+ }
+
+ if diff := cmp.Diff(people, want); diff != "" {
+ t.Errorf("CSV.Decode() mismatch (-got +want):\n%s", diff)
+ }
+ })
}
func Test_ErrCheck(t *testing.T) {
@@ -346,4 +381,43 @@ func Test_ErrCheck(t *testing.T) {
}
}
})
+
+ t.Run("validate oneof: error case", func(t *testing.T) {
+ t.Parallel()
+
+ input := `id,gender
+1,smale
+2,child
+3,prefer_not_tooa
+`
+
+ c, err := NewCSV(bytes.NewBufferString(input))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ type person struct {
+ ID int // no validate
+ Gender string `validate:"oneof=male female prefer_not_to"`
+ }
+
+ people := make([]person, 0)
+ errs := c.Decode(&people)
+ for i, err := range errs {
+ switch i {
+ case 0:
+ if err.Error() != "line:2 column gender: target is not one of the values: oneof=male female prefer_not_to, value=smale" {
+ t.Errorf("CSV.Decode() got errors: %v", err)
+ }
+ case 1:
+ if err.Error() != "line:3 column gender: target is not one of the values: oneof=male female prefer_not_to, value=child" {
+ t.Errorf("CSV.Decode() got errors: %v", err)
+ }
+ case 2:
+ if err.Error() != "line:4 column gender: target is not one of the values: oneof=male female prefer_not_to, value=prefer_not_tooa" {
+ t.Errorf("CSV.Decode() got errors: %v", err)
+ }
+ }
+ }
+ })
}
diff --git a/errors.go b/errors.go
index 02cd5ac..591058f 100644
--- a/errors.go
+++ b/errors.go
@@ -5,6 +5,11 @@ import "errors"
var (
// ErrStructSlicePointer is returned when the value is not a pointer to a struct slice.
ErrStructSlicePointer = errors.New("value is not a pointer to a struct slice")
+ // ErrInvalidOneOfFormat is returned when the target is not one of the values.
+ ErrInvalidOneOfFormat = errors.New("target is not one of the values")
+ // ErrInvalidThresholdFormat is returned when the threshold value is not an integer.
+ ErrInvalidThresholdFormat = errors.New("threshold format is invalid")
+
// ErrInvalidBoolean is returned when the target is not a boolean.
ErrInvalidBoolean = errors.New("target is not a boolean")
// ErrInvalidAlphabet is returned when the target is not an alphabetic character.
@@ -19,8 +24,6 @@ var (
ErrEqual = errors.New("target is not equal to the threshold value")
// ErrInvalidThreshold is returned when the target is not greater than the value.
ErrInvalidThreshold = errors.New("threshold value is invalid")
- // ErrInvalidThresholdFormat is returned when the threshold value is not an integer.
- ErrInvalidThresholdFormat = errors.New("threshold format is invalid")
// ErrNotEqual is returned when the target is equal to the value.
ErrNotEqual = errors.New("target is equal to threshold the value")
// ErrGreaterThan is returned when the target is not greater than the value.
@@ -37,4 +40,6 @@ var (
ErrMax = errors.New("target is greater than the maximum value")
// ErrLength is returned when the target length is not equal to the value.
ErrLength = errors.New("target length is not equal to the threshold value")
+ // ErrOneOf is returned when the target is not one of the values.
+ ErrOneOf = errors.New("target is not one of the values")
)
diff --git a/parser.go b/parser.go
index eaf867a..b58e8d1 100644
--- a/parser.go
+++ b/parser.go
@@ -121,6 +121,13 @@ func parseValidateTag(tags string) (validators, error) {
return nil, err
}
validatorList = append(validatorList, newLengthValidator(threshold))
+
+ case strings.HasPrefix(t, oneOfTagValue.String()):
+ oneOf, err := parseOneOf(t)
+ if err != nil {
+ return nil, err
+ }
+ validatorList = append(validatorList, newOneOfValidator(oneOf))
}
}
return validatorList, nil
@@ -140,3 +147,14 @@ func parseThreshold(tagValue string) (float64, error) {
}
return 0, fmt.Errorf("%w: %s", ErrInvalidThresholdFormat, tagValue)
}
+
+// parseOneOf parses the oneOf value.
+// tagValue is the value of the struct tag. e.g. oneof=male female prefer_not_to
+func parseOneOf(tagValue string) ([]string, error) {
+ parts := strings.Split(tagValue, "=")
+
+ if len(parts) == 2 {
+ return strings.Split(parts[1], " "), nil
+ }
+ return nil, fmt.Errorf("%w: %s", ErrInvalidOneOfFormat, tagValue)
+}
diff --git a/tag.go b/tag.go
index 85e5e3a..ac32c3f 100644
--- a/tag.go
+++ b/tag.go
@@ -40,6 +40,8 @@ const (
maxTagValue tagValue = "max"
// lengthTagValue is the struct tag name for length fields.
lengthTagValue tagValue = "len"
+ // oneOfTagValue is the struct tag name for one of fields.
+ oneOfTagValue tagValue = "oneof"
)
// String returns the string representation of the tag.
diff --git a/validation.go b/validation.go
index a17fbf3..8ef68e7 100644
--- a/validation.go
+++ b/validation.go
@@ -3,6 +3,7 @@ package csv
import (
"fmt"
"strconv"
+ "strings"
"github.com/rivo/uniseg"
)
@@ -382,3 +383,28 @@ func (l *lengthValidator) Do(target any) error {
}
return nil
}
+
+// oneOfValidator is a struct that contains the validation rules for a one of column.
+type oneOfValidator struct {
+ oneOf []string
+}
+
+// newOneOfValidator returns a new oneOfValidator.
+func newOneOfValidator(oneOf []string) *oneOfValidator {
+ return &oneOfValidator{oneOf: oneOf}
+}
+
+// Do validates the target is one of the oneOf values.
+func (o *oneOfValidator) Do(target any) error {
+ v, ok := target.(string)
+ if !ok {
+ return fmt.Errorf("%w: value=%v", ErrOneOf, target) //nolint
+ }
+
+ for _, s := range o.oneOf {
+ if v == s {
+ return nil
+ }
+ }
+ return fmt.Errorf("%w: oneof=%s, value=%v", ErrOneOf, strings.Join(o.oneOf, " "), target) //nolint
+}