Skip to content

Commit

Permalink
Add "oneof" tag value
Browse files Browse the repository at this point in the history
  • Loading branch information
nao1215 committed May 13, 2024
1 parent 1069485 commit 2e1e56a
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 2 deletions.
74 changes: 74 additions & 0 deletions csv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,41 @@ func TestCSV_Decode(t *testing.T) {
t.Errorf("CSV.Decode() mismatch (-got +want):\n%s", diff)
}
})

t.Run("validate oneof: success case", func(t *testing.T) {
t.Parallel()

input := `id,gender
1,male
2,female
3,prefer_not_to
`
c, err := NewCSV(bytes.NewBufferString(input))
if err != nil {
t.Fatal(err)
}

type person struct {
ID int // no validate
Gender string `validate:"oneof=male female prefer_not_to"`
}

people := make([]person, 0)
errs := c.Decode(&people)
if len(errs) != 0 {
t.Errorf("CSV.Decode() got errors: %v", errs)
}

want := []person{
{ID: 1, Gender: "male"},
{ID: 2, Gender: "female"},
{ID: 3, Gender: "prefer_not_to"},
}

if diff := cmp.Diff(people, want); diff != "" {
t.Errorf("CSV.Decode() mismatch (-got +want):\n%s", diff)
}
})
}

func Test_ErrCheck(t *testing.T) {
Expand Down Expand Up @@ -346,4 +381,43 @@ func Test_ErrCheck(t *testing.T) {
}
}
})

t.Run("validate oneof: error case", func(t *testing.T) {
t.Parallel()

input := `id,gender
1,smale
2,child
3,prefer_not_tooa
`

c, err := NewCSV(bytes.NewBufferString(input))
if err != nil {
t.Fatal(err)
}

type person struct {
ID int // no validate
Gender string `validate:"oneof=male female prefer_not_to"`
}

people := make([]person, 0)
errs := c.Decode(&people)
for i, err := range errs {
switch i {
case 0:
if err.Error() != "line:2 column gender: target is not one of the values: oneof=male female prefer_not_to, value=smale" {
t.Errorf("CSV.Decode() got errors: %v", err)
}
case 1:
if err.Error() != "line:3 column gender: target is not one of the values: oneof=male female prefer_not_to, value=child" {
t.Errorf("CSV.Decode() got errors: %v", err)
}
case 2:
if err.Error() != "line:4 column gender: target is not one of the values: oneof=male female prefer_not_to, value=prefer_not_tooa" {
t.Errorf("CSV.Decode() got errors: %v", err)
}
}
}
})
}
9 changes: 7 additions & 2 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ import "errors"
var (
// ErrStructSlicePointer is returned when the value is not a pointer to a struct slice.
ErrStructSlicePointer = errors.New("value is not a pointer to a struct slice")
// ErrInvalidOneOfFormat is returned when the target is not one of the values.
ErrInvalidOneOfFormat = errors.New("target is not one of the values")
// ErrInvalidThresholdFormat is returned when the threshold value is not an integer.
ErrInvalidThresholdFormat = errors.New("threshold format is invalid")

// ErrInvalidBoolean is returned when the target is not a boolean.
ErrInvalidBoolean = errors.New("target is not a boolean")
// ErrInvalidAlphabet is returned when the target is not an alphabetic character.
Expand All @@ -19,8 +24,6 @@ var (
ErrEqual = errors.New("target is not equal to the threshold value")
// ErrInvalidThreshold is returned when the target is not greater than the value.
ErrInvalidThreshold = errors.New("threshold value is invalid")
// ErrInvalidThresholdFormat is returned when the threshold value is not an integer.
ErrInvalidThresholdFormat = errors.New("threshold format is invalid")
// ErrNotEqual is returned when the target is equal to the value.
ErrNotEqual = errors.New("target is equal to threshold the value")
// ErrGreaterThan is returned when the target is not greater than the value.
Expand All @@ -37,4 +40,6 @@ var (
ErrMax = errors.New("target is greater than the maximum value")
// ErrLength is returned when the target length is not equal to the value.
ErrLength = errors.New("target length is not equal to the threshold value")
// ErrOneOf is returned when the target is not one of the values.
ErrOneOf = errors.New("target is not one of the values")
)
18 changes: 18 additions & 0 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ func parseValidateTag(tags string) (validators, error) {
return nil, err
}
validatorList = append(validatorList, newLengthValidator(threshold))

case strings.HasPrefix(t, oneOfTagValue.String()):
oneOf, err := parseOneOf(t)
if err != nil {
return nil, err
}
validatorList = append(validatorList, newOneOfValidator(oneOf))
}
}
return validatorList, nil
Expand All @@ -140,3 +147,14 @@ func parseThreshold(tagValue string) (float64, error) {
}
return 0, fmt.Errorf("%w: %s", ErrInvalidThresholdFormat, tagValue)
}

// parseOneOf parses the oneOf value.
// tagValue is the value of the struct tag. e.g. oneof=male female prefer_not_to
func parseOneOf(tagValue string) ([]string, error) {
parts := strings.Split(tagValue, "=")

if len(parts) == 2 {
return strings.Split(parts[1], " "), nil
}
return nil, fmt.Errorf("%w: %s", ErrInvalidOneOfFormat, tagValue)
}
2 changes: 2 additions & 0 deletions tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ const (
maxTagValue tagValue = "max"
// lengthTagValue is the struct tag name for length fields.
lengthTagValue tagValue = "len"
// oneOfTagValue is the struct tag name for one of fields.
oneOfTagValue tagValue = "oneof"
)

// String returns the string representation of the tag.
Expand Down
26 changes: 26 additions & 0 deletions validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package csv
import (
"fmt"
"strconv"
"strings"

"github.com/rivo/uniseg"
)
Expand Down Expand Up @@ -382,3 +383,28 @@ func (l *lengthValidator) Do(target any) error {
}
return nil
}

// oneOfValidator is a struct that contains the validation rules for a one of column.
type oneOfValidator struct {
oneOf []string
}

// newOneOfValidator returns a new oneOfValidator.
func newOneOfValidator(oneOf []string) *oneOfValidator {
return &oneOfValidator{oneOf: oneOf}
}

// Do validates the target is one of the oneOf values.
func (o *oneOfValidator) Do(target any) error {
v, ok := target.(string)
if !ok {
return fmt.Errorf("%w: value=%v", ErrOneOf, target) //nolint
}

for _, s := range o.oneOf {
if v == s {
return nil
}
}
return fmt.Errorf("%w: oneof=%s, value=%v", ErrOneOf, strings.Join(o.oneOf, " "), target) //nolint
}

0 comments on commit 2e1e56a

Please sign in to comment.