diff --git a/Makefile b/Makefile index d8b6e7d..0568ac4 100644 --- a/Makefile +++ b/Makefile @@ -27,6 +27,9 @@ test: -bench=$(TEST_PATTERN) \ -timeout=2m +lint: ## Run lint + golangci-lint run --timeout 5m + bench: @[ -d ${REPORTS_DIR} ] || mkdir -p ${REPORTS_DIR} @rm -rf ${REPORTS_DIR}/* diff --git a/basic.go b/basic.go index 8c33e43..f1c2105 100644 --- a/basic.go +++ b/basic.go @@ -61,12 +61,12 @@ func toString(v any) (string, bool) { return "", false } -func hasRequiredRule(rules []Rule) (Required, bool) { +func hasRequiredRule(rules []Rule) (*Required, bool) { for _, r := range rules { - if v, ok := r.(Required); ok { + if v, ok := r.(*Required); ok { return v, ok } } - return Required{}, false + return nil, false } diff --git a/callback.go b/callback.go index f439ffb..bbb1734 100644 --- a/callback.go +++ b/callback.go @@ -8,16 +8,48 @@ import ( type CallbackFunc[T any] func(ctx context.Context, value T) error type Callback[T any] struct { - f CallbackFunc[T] + f CallbackFunc[T] + whenFunc WhenFunc + skipEmpty bool } -func NewCallback[T any](f CallbackFunc[T]) Callback[T] { - return Callback[T]{ +func NewCallback[T any](f CallbackFunc[T]) *Callback[T] { + return &Callback[T]{ f: f, } } -func (c Callback[T]) ValidateValue(ctx context.Context, value any) error { +func (r *Callback[T]) When(v WhenFunc) *Callback[T] { + rc := *r + rc.whenFunc = v + + return &rc +} + +func (r *Callback[T]) when() WhenFunc { + return r.whenFunc +} + +func (r *Callback[T]) setWhen(v WhenFunc) { + r.whenFunc = v +} + +func (r *Callback[T]) SkipOnEmpty(v bool) *Callback[T] { + rc := *r + rc.skipEmpty = v + + return &rc +} + +func (r *Callback[T]) skipOnEmpty() bool { + return r.skipEmpty +} + +func (r *Callback[T]) setSkipOnEmpty(v bool) { + r.skipEmpty = v +} + +func (c *Callback[T]) ValidateValue(ctx context.Context, value any) error { v, ok := value.(T) if !ok { var v T diff --git a/callback_test.go b/callback_test.go index 3c536b8..14ee0b0 100644 --- a/callback_test.go +++ b/callback_test.go @@ -33,7 +33,7 @@ func TestCallback_ValidateValue_Error(t *testing.T) { rules := RuleSet{ "A": { NewCallback(func(ctx context.Context, value int) error { - if ds, ok := extractDataSet(ctx); ok { + if ds, ok := ExtractDataSet[DataSet](ctx); ok { if obj, ok := ds.Data().(*TestCallback); ok { if obj.B > value { return errAMustGreatB diff --git a/compare.go b/compare.go index f23015d..b626a00 100644 --- a/compare.go +++ b/compare.go @@ -15,8 +15,8 @@ type Compare struct { skipEmpty bool } -func NewCompare(targetValue any, targetAttribute, operator string) Compare { - c := Compare{ +func NewCompare(targetValue any, targetAttribute, operator string) *Compare { + c := &Compare{ targetValue: targetValue, targetAttribute: targetAttribute, operator: operator, @@ -43,28 +43,45 @@ func NewCompare(targetValue any, targetAttribute, operator string) Compare { return c } -func (c Compare) When(v WhenFunc) Compare { - c.whenFunc = v +func (r *Compare) WithMessage(v string) *Compare { + rc := *r + rc.message = v - return c + return &rc } -func (c Compare) when() WhenFunc { - return c.whenFunc +func (r *Compare) When(v WhenFunc) *Compare { + rc := *r + rc.whenFunc = v + + return &rc } -func (c Compare) SkipOnEmpty(v bool) Compare { - c.skipEmpty = v +func (r *Compare) when() WhenFunc { + return r.whenFunc +} - return c +func (r *Compare) setWhen(v WhenFunc) { + r.whenFunc = v +} + +func (r *Compare) SkipOnEmpty(v bool) *Compare { + rc := *r + rc.skipEmpty = v + + return &rc +} + +func (r *Compare) skipOnEmpty() bool { + return r.skipEmpty } -func (c Compare) skipOnEmpty() bool { - return c.skipEmpty +func (r *Compare) setSkipOnEmpty(v bool) { + r.skipEmpty = v } -func (c Compare) ValidateValue(ctx context.Context, value any) error { - if !c.operatorIsValid { +func (r *Compare) ValidateValue(ctx context.Context, value any) error { + if !r.operatorIsValid { return UnknownOperatorError } @@ -73,60 +90,60 @@ func (c Compare) ValidateValue(ctx context.Context, value any) error { targetValueOrAttr any err error ) - targetValue = c.targetValue - targetValueOrAttr = c.targetAttribute + targetValue = r.targetValue + targetValueOrAttr = r.targetAttribute - if c.targetValue == nil { - dataSet, ok := extractDataSet(ctx) + if r.targetValue == nil { + dataSet, ok := ExtractDataSet[DataSet](ctx) if !ok { return NotExistsDataSetIntoContextError } - targetValue, err = dataSet.FieldValue(c.targetAttribute) + targetValue, err = dataSet.FieldValue(r.targetAttribute) if err != nil { return err } targetValueOrAttr = targetValue } - switch c.operator { + switch r.operator { case "==": - if c.eq(value, targetValue) { + if r.eq(value, targetValue) { return nil } case "!=": - if !c.eq(value, targetValue) { + if !r.eq(value, targetValue) { return nil } case ">": - if c.gt(value, targetValue) { + if r.gt(value, targetValue) { return nil } case ">=": - if c.eq(value, targetValue) || c.gt(value, targetValue) { + if r.eq(value, targetValue) || r.gt(value, targetValue) { return nil } case "<": - if !c.eq(value, targetValue) && !c.gt(value, targetValue) { + if !r.eq(value, targetValue) && !r.gt(value, targetValue) { return nil } case "<=": - if c.eq(value, targetValue) || !c.gt(value, targetValue) { + if r.eq(value, targetValue) || !r.gt(value, targetValue) { return nil } } return NewResult(). WithError( - NewValidationError(c.message). + NewValidationError(r.message). WithParams(map[string]any{ - "targetValue": c.targetValue, - "targetAttribute": c.targetAttribute, + "targetValue": r.targetValue, + "targetAttribute": r.targetAttribute, "targetValueOrAttribute": targetValueOrAttr, }), ) } -func (c Compare) eq(a, b any) bool { +func (r *Compare) eq(a, b any) bool { if ia, ok := a.(int); ok { if ib, ok := b.(int); ok { return ia == ib @@ -166,7 +183,7 @@ func (c Compare) eq(a, b any) bool { return a == b } -func (c Compare) gt(a, b any) bool { +func (r *Compare) gt(a, b any) bool { if ia, ok := a.(int); ok { if ib, ok := b.(int); ok { return ia > ib diff --git a/context.go b/context.go index a2472da..5fd17af 100644 --- a/context.go +++ b/context.go @@ -10,6 +10,7 @@ type Key uint8 const ( KeyDataSet Key = iota + 1 + PreviousRulesErrored ) type DataSet interface { @@ -19,17 +20,73 @@ type DataSet interface { Data() any } +type Context struct { + context.Context + ds DataSet +} + +func NewContext(ctx context.Context) *Context { + return &Context{Context: ctx} +} + +func (c *Context) Value(key any) any { + if key == KeyDataSet { + return c.ds + } + + return c.Context.Value(key) +} + +func (c *Context) withDataSet(ds DataSet) *Context { + cc := *c + cc.ds = ds + + return &cc +} + +func (c *Context) dataSet() (DataSet, bool) { + return c.ds, c.ds != nil +} + +func DataSetFromContext[T DataSet](ctx *Context) (T, bool) { + if ds, ok := ctx.dataSet(); ok { + if dsT, ok2 := ds.(T); ok2 { + return dsT, true + } + } + var v T + + return v, false +} + +// todo: write funcs if context.Context interface + func withDataSet(ctx context.Context, ds DataSet) context.Context { - return context.WithValue(ctx, KeyDataSet, ds) + return NewContext(ctx).withDataSet(ds) + //return context.WithValue(ctx, KeyDataSet, ds) } -func extractDataSet(ctx context.Context) (DataSet, bool) { +func ExtractDataSet[T DataSet](ctx context.Context) (T, bool) { + var v T if ctx == nil { - return nil, false + return v, false } - if ds, ok := ctx.Value(KeyDataSet).(DataSet); ok { - return ds, true + + ds, ok := ctx.Value(KeyDataSet).(T) + if !ok { + return v, false } - return nil, false + return ds, true +} + +//func withPreviousRulesErrored(ctx context.Context) context.Context { +// return context.WithValue(ctx, PreviousRulesErrored, true) +//} + +func previousRulesErrored(ctx context.Context) bool { + if y, ok := ctx.Value(PreviousRulesErrored).(bool); ok { + return y + } + return false } diff --git a/each.go b/each.go index edb87f1..cfff94e 100644 --- a/each.go +++ b/each.go @@ -16,60 +16,66 @@ type Each struct { skipEmpty bool } -func NewEach(rules ...Rule) Each { - return Each{ +func NewEach(rules ...Rule) *Each { + return &Each{ message: "Value is invalid", incorrectInputMessage: "Value must be array", rules: rules, + normalizeRulesEnabled: true, } } -func (e Each) WithMessage(message string) Each { - e.message = message +func (r *Each) WithMessage(message string) *Each { + rc := *r + rc.message = message - return e + return &rc } -func (e Each) When(v WhenFunc) Each { - e.whenFunc = v +func (r *Each) When(v WhenFunc) *Each { + rc := *r + rc.whenFunc = v - return e + return &rc } -func (e Each) when() WhenFunc { - return e.whenFunc +func (r *Each) when() WhenFunc { + return r.whenFunc } -func (e Each) setWhen(v WhenFunc) { - e.whenFunc = v +func (r *Each) setWhen(v WhenFunc) { + r.whenFunc = v } -func (e Each) SkipOnEmpty(v bool) Each { - e.skipEmpty = v - return e +func (r *Each) SkipOnEmpty(v bool) *Each { + rc := *r + rc.skipEmpty = v + + return &rc } -func (e Each) skipOnEmpty() bool { - return e.skipEmpty +func (r *Each) skipOnEmpty() bool { + return r.skipEmpty } -func (e Each) setSkipOnEmpty(v bool) { - e.skipEmpty = v // because copy +func (r *Each) setSkipOnEmpty(v bool) { + r.skipEmpty = v } -func (e Each) WithIncorrectInputMessage(incorrectInputMessage string) Each { - e.incorrectInputMessage = incorrectInputMessage +func (r *Each) WithIncorrectInputMessage(incorrectInputMessage string) *Each { + rc := *r + rc.incorrectInputMessage = incorrectInputMessage - return e + return &rc } -func (e Each) ValidateValue(ctx context.Context, value any) error { - e.normalizeRules() +func (r *Each) ValidateValue(ctx context.Context, value any) error { + r.normalizeRules() result := NewResult() if reflect.TypeOf(value).Kind() != reflect.Slice { return result.WithError( - NewValidationError(e.incorrectInputMessage). + NewValidationError(r.incorrectInputMessage). WithParams(map[string]any{ //"attribute": "",//todo "value": value, @@ -81,7 +87,7 @@ func (e Each) ValidateValue(ctx context.Context, value any) error { for i := 0; i < vs.Len(); i++ { v := vs.Index(i).Interface() - if err := ValidateValue(ctx, v, e.rules...); err != nil { + if err := ValidateValue(ctx, v, r.rules...); err != nil { var r Result if errors.As(err, &r) { for _, err := range r.Errors() { @@ -107,21 +113,21 @@ func (e Each) ValidateValue(ctx context.Context, value any) error { return result } -func (e Each) normalizeRules() { - if !e.normalizeRulesEnabled { +func (r *Each) normalizeRules() { + if !r.normalizeRulesEnabled { return } - e.normalizeRulesEnabled = false // once + r.normalizeRulesEnabled = false - for i, r := range e.rules { - if rse, ok := r.(RuleSkipEmpty); ok { - rse.setSkipOnEmpty(e.skipEmpty) + for i, rule := range r.rules { + if rse, ok := rule.(RuleSkipEmpty); ok { + rse.setSkipOnEmpty(r.skipEmpty) } - if rw, ok := r.(RuleWhen); ok { - rw.setWhen(e.whenFunc) + if rw, ok := rule.(RuleWhen); ok { + rw.setWhen(r.whenFunc) } - e.rules[i] = r + r.rules[i] = rule } } diff --git a/email.go b/email.go index f9f0eae..64080d5 100644 --- a/email.go +++ b/email.go @@ -1,50 +1,14 @@ package validator -import "context" - -const ( - emailRegexp = `^[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,4}$` -) +const emailRegexp = `^[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,4}$` type Email struct { - basicRule MatchRegularExpression - whenFunc WhenFunc - skipEmpty bool + *MatchRegularExpression } -func NewEmail() Email { - return Email{ - basicRule: NewMatchRegularExpression(emailRegexp). +func NewEmail() *Email { + return &Email{ + MatchRegularExpression: NewMatchRegularExpression(emailRegexp). WithMessage("Email is not a valid email."), } } - -func (e Email) When(v WhenFunc) Email { - e.whenFunc = v - - return e -} - -func (e Email) when() WhenFunc { - return e.whenFunc -} - -func (e Email) SkipOnEmpty(v bool) Email { - e.skipEmpty = v - - return e -} - -func (e Email) skipOnEmpty() bool { - return e.skipEmpty -} - -func (e Email) WithMessage(message string) Email { - e.basicRule = e.basicRule.WithMessage(message) - - return e -} - -func (e Email) ValidateValue(ctx context.Context, value any) error { - return e.basicRule.ValidateValue(ctx, value) -} diff --git a/human_text.go b/human_text.go new file mode 100644 index 0000000..6b15c2e --- /dev/null +++ b/human_text.go @@ -0,0 +1,14 @@ +package validator + +const humanRegexp = `^[А-Яа-яЁёa-zA-Z0-9 ,.-]+$` + +type HumanText struct { + *MatchRegularExpression +} + +func NewHumanText() *HumanText { + return &HumanText{ + MatchRegularExpression: NewMatchRegularExpression(humanRegexp). + WithMessage("This value must be a normal text."), + } +} diff --git a/in_range.go b/in_range.go index e2753a8..0c302ac 100644 --- a/in_range.go +++ b/in_range.go @@ -10,47 +10,59 @@ type InRange struct { skipEmpty bool } -func NewInRange(rangeValues []any) InRange { - return InRange{ +func NewInRange(rangeValues []any) *InRange { + return &InRange{ message: "This value is invalid", rangeValues: rangeValues, not: false, } } -func (r InRange) When(v WhenFunc) InRange { - r.whenFunc = v +func (r *InRange) WithMessage(message string) *InRange { + rc := *r + rc.message = message - return r + return &rc } -func (r InRange) when() WhenFunc { - return r.whenFunc +func (r *InRange) Not() *InRange { + rc := *r + rc.not = true + + return &rc } -func (r InRange) SkipOnEmpty(v bool) InRange { - r.skipEmpty = v +func (r *InRange) When(v WhenFunc) *InRange { + rc := *r + rc.whenFunc = v - return r + return &rc } -func (r InRange) skipOnEmpty() bool { - return r.skipEmpty +func (r *InRange) when() WhenFunc { + return r.whenFunc } -func (r InRange) WithMessage(message string) InRange { - r.message = message +func (r *InRange) setWhen(v WhenFunc) { + r.whenFunc = v +} + +func (r *InRange) SkipOnEmpty(v bool) *InRange { + rc := *r + rc.skipEmpty = v - return r + return &rc } -func (r InRange) Not() InRange { - r.not = true +func (r *InRange) skipOnEmpty() bool { + return r.skipEmpty +} - return r +func (r *InRange) setSkipOnEmpty(v bool) { + r.skipEmpty = v } -func (r InRange) ValidateValue(_ context.Context, value any) error { +func (r *InRange) ValidateValue(_ context.Context, value any) error { v, valid := indirectValue(value) if !valid { return NewResult().WithError(NewValidationError(r.message)) diff --git a/interface.go b/interface.go index 42e4d14..b4f284a 100644 --- a/interface.go +++ b/interface.go @@ -19,3 +19,8 @@ type RuleSkipEmpty interface { skipOnEmpty() bool setSkipOnEmpty(v bool) } + +type RuleSkipError interface { + shouldSkipOnError(ctx context.Context) bool + setSkipOnError(v bool) +} diff --git a/ip.go b/ip.go index d8e3e00..efe0fde 100644 --- a/ip.go +++ b/ip.go @@ -15,8 +15,8 @@ type IP struct { skipEmpty bool } -func NewIP() IP { - return IP{ +func NewIP() *IP { + return &IP{ message: "Must be a valid IP address.", ipv4NotAllowedMessage: "Must not be an IPv4 address.", ipv6NotAllowedMessage: "Must not be an IPv6 address.", @@ -25,41 +25,52 @@ func NewIP() IP { } } -func (s IP) When(v WhenFunc) IP { - s.whenFunc = v +func (r *IP) WithMessage(v string) *IP { + rc := *r + rc.message = v - return s + return &rc } -func (s IP) when() WhenFunc { - return s.whenFunc +func (r *IP) When(v WhenFunc) *IP { + rc := *r + rc.whenFunc = v + + return &rc } -func (s IP) SkipOnEmpty(v bool) IP { - s.skipEmpty = v +func (r *IP) when() WhenFunc { + return r.whenFunc +} - return s +func (r *IP) setWhen(v WhenFunc) { + r.whenFunc = v } -func (s IP) skipOnEmpty() bool { - return s.skipEmpty +func (r *IP) SkipOnEmpty(v bool) *IP { + rc := *r + rc.skipEmpty = v + + return &rc } -func (s IP) WithMessage(v string) IP { - s.message = v +func (r *IP) skipOnEmpty() bool { + return r.skipEmpty +} - return s +func (r *IP) setSkipOnEmpty(v bool) { + r.skipEmpty = v } -func (s IP) ValidateValue(_ context.Context, value any) error { +func (r *IP) ValidateValue(_ context.Context, value any) error { v, ok := toString(value) if !ok { - return NewResult().WithError(NewValidationError(s.message)) + return NewResult().WithError(NewValidationError(r.message)) } ip := net.ParseIP(v) if ip == nil { - return NewResult().WithError(NewValidationError(s.message)) + return NewResult().WithError(NewValidationError(r.message)) } // TODO: implement ipv4 and ipv4 validations diff --git a/match_regular_expression.go b/match_regular_expression.go index 9b38416..0ad7d6a 100644 --- a/match_regular_expression.go +++ b/match_regular_expression.go @@ -13,52 +13,63 @@ type MatchRegularExpression struct { skipEmpty bool } -func NewMatchRegularExpression(pattern string) MatchRegularExpression { - return MatchRegularExpression{ +func NewMatchRegularExpression(pattern string) *MatchRegularExpression { + return &MatchRegularExpression{ message: "Value is invalid.", pattern: pattern, } } -func (s MatchRegularExpression) WithMessage(message string) MatchRegularExpression { - s.message = message +func (r *MatchRegularExpression) WithMessage(message string) *MatchRegularExpression { + rc := *r + rc.message = message - return s + return &rc } -func (s MatchRegularExpression) When(v WhenFunc) MatchRegularExpression { - s.whenFunc = v +func (r *MatchRegularExpression) When(v WhenFunc) *MatchRegularExpression { + rc := *r + rc.whenFunc = v - return s + return &rc } -func (s MatchRegularExpression) when() WhenFunc { - return s.whenFunc +func (r *MatchRegularExpression) when() WhenFunc { + return r.whenFunc } -func (s MatchRegularExpression) SkipOnEmpty(v bool) MatchRegularExpression { - s.skipEmpty = v +func (r *MatchRegularExpression) setWhen(v WhenFunc) { + r.whenFunc = v +} + +func (r *MatchRegularExpression) SkipOnEmpty(v bool) *MatchRegularExpression { + rc := *r + rc.skipEmpty = v + + return &rc +} - return s +func (r *MatchRegularExpression) skipOnEmpty() bool { + return r.skipEmpty } -func (s MatchRegularExpression) skipOnEmpty() bool { - return s.skipEmpty +func (r *MatchRegularExpression) setSkipOnEmpty(v bool) { + r.skipEmpty = v } -func (s MatchRegularExpression) ValidateValue(_ context.Context, value any) error { +func (r *MatchRegularExpression) ValidateValue(_ context.Context, value any) error { v, ok := toString(value) if !ok { - return NewResult().WithError(NewValidationError(s.message)) + return NewResult().WithError(NewValidationError(r.message)) } - r, err := regexpc.Compile(s.pattern) + rg, err := regexpc.Compile(r.pattern) if err != nil { return err } - if !r.MatchString(v) { - return NewResult().WithError(NewValidationError(s.message)) + if !rg.MatchString(v) { + return NewResult().WithError(NewValidationError(r.message)) } return nil diff --git a/msisdn.go b/msisdn.go index 4d7c23b..ff2e991 100644 --- a/msisdn.go +++ b/msisdn.go @@ -1,11 +1,9 @@ package validator -const ( - msisdnRegexp = `^\d+$` -) +const msisdnRegexp = `^\d+$` type MSISDN struct { - MatchRegularExpression + *MatchRegularExpression } func NewMSISDN() MSISDN { diff --git a/nested.go b/nested.go index 5f70d82..69b0ba0 100644 --- a/nested.go +++ b/nested.go @@ -24,53 +24,65 @@ type Nested struct { skipEmpty bool } -func NewNested(rules RuleSet) Nested { - return Nested{ +func NewNested(rules RuleSet) *Nested { + return &Nested{ normalizeRulesEnabled: true, rules: rules, message: "", } } -func (n Nested) WithMessage(message string) Nested { - n.message = message +func (r *Nested) WithMessage(message string) *Nested { + rc := *r + rc.message = message - return n + return &rc } -func (n Nested) When(v WhenFunc) Nested { - n.whenFunc = v +func (r *Nested) When(v WhenFunc) *Nested { + rc := *r + rc.whenFunc = v - return n + return &rc } -func (n Nested) when() WhenFunc { - return n.whenFunc +func (r *Nested) when() WhenFunc { + return r.whenFunc } -func (n Nested) SkipOnEmpty(v bool) Nested { - n.skipEmpty = v +func (r *Nested) setWhen(v WhenFunc) { + r.whenFunc = v +} + +func (r *Nested) SkipOnEmpty(v bool) *Nested { + rc := *r + rc.skipEmpty = v + + return &rc +} - return n +func (r *Nested) skipOnEmpty() bool { + return r.skipEmpty } -func (n Nested) skipOnEmpty() bool { - return n.skipEmpty +func (r *Nested) setSkipOnEmpty(v bool) { + r.skipEmpty = v } -func (n Nested) notNormalizeRules() Nested { - n.normalizeRulesEnabled = false +func (r *Nested) notNormalizeRules() *Nested { + rc := *r + rc.normalizeRulesEnabled = false - return n + return &rc } -func (n Nested) ValidateValue(ctx context.Context, value any) error { - if n.normalizeRulesEnabled { - n.normalizeRulesEnabled = false // once - if rules, err := n.normalizeRules(); err != nil { +func (r *Nested) ValidateValue(ctx context.Context, value any) error { + if r.normalizeRulesEnabled { + r.normalizeRulesEnabled = false // once + if rules, err := r.normalizeRules(); err != nil { return err } else { - n.rules = rules + r.rules = rules } } @@ -79,7 +91,7 @@ func (n Nested) ValidateValue(ctx context.Context, value any) error { vt = vt.Elem() } - if len(n.rules) == 0 { + if len(r.rules) == 0 { if vt.Kind() != reflect.Struct { return fmt.Errorf("nested rule without rules could be used for structs only. %s given", vt.Kind().String(), @@ -95,7 +107,7 @@ func (n Nested) ValidateValue(ctx context.Context, value any) error { } } - return Validate(ctx, data, n.rules) + return Validate(ctx, data, r.rules) } if vt.Kind() != reflect.Struct { @@ -118,9 +130,9 @@ func (n Nested) ValidateValue(ctx context.Context, value any) error { } compoundResult := NewResult() - results := make([]Result, 0, len(n.rules)) + results := make([]Result, 0, len(r.rules)) - for fieldName, rules := range n.rules { + for fieldName, rules := range r.rules { // todo: parse valuePath validatedValue, err := data.FieldValue(fieldName) @@ -167,8 +179,8 @@ func (n Nested) ValidateValue(ctx context.Context, value any) error { return nil } -func (n Nested) normalizeRules() (RuleSet, error) { - nRules := n.rules +func (r *Nested) normalizeRules() (RuleSet, error) { + nRules := r.rules for { rulesMap := make(map[string]RuleSet, len(nRules)) diff --git a/number.go b/number.go index a05a985..3a9fa8b 100644 --- a/number.go +++ b/number.go @@ -14,8 +14,8 @@ type Number struct { skipEmpty bool } -func NewNumber(min, max int64) Number { - return Number{ +func NewNumber(min, max int64) *Number { + return &Number{ min: min, max: max, notNumberMessage: "Value must be a number.", @@ -24,42 +24,58 @@ func NewNumber(min, max int64) Number { } } -func (n Number) WithTooBigMessage(message string) Number { - n.tooBigMessage = message - return n +func (r *Number) WithTooBigMessage(message string) *Number { + rc := *r + rc.tooBigMessage = message + + return &rc } -func (n Number) WithTooSmallMessage(message string) Number { - n.tooSmallMessage = message - return n +func (r *Number) WithTooSmallMessage(message string) *Number { + rc := *r + rc.tooSmallMessage = message + + return &rc } -func (n Number) WithNotNumberMessage(message string) Number { - n.notNumberMessage = message - return n +func (r *Number) WithNotNumberMessage(message string) *Number { + rc := *r + rc.notNumberMessage = message + + return &rc } -func (n Number) When(v WhenFunc) Number { - n.whenFunc = v +func (r *Number) When(v WhenFunc) *Number { + rc := *r + rc.whenFunc = v - return n + return &rc } -func (n Number) when() WhenFunc { - return n.whenFunc +func (r *Number) when() WhenFunc { + return r.whenFunc } -func (n Number) SkipOnEmpty(v bool) Number { - n.skipEmpty = v +func (r *Number) setWhen(v WhenFunc) { + r.whenFunc = v +} + +func (r *Number) SkipOnEmpty(v bool) *Number { + rc := *r + rc.skipEmpty = v + + return &rc +} - return n +func (r *Number) skipOnEmpty() bool { + return r.skipEmpty } -func (n Number) skipOnEmpty() bool { - return n.skipEmpty +func (r *Number) setSkipOnEmpty(v bool) { + r.skipEmpty = v } -func (n Number) ValidateValue(_ context.Context, value any) error { +func (r *Number) ValidateValue(_ context.Context, value any) error { var i int64 switch v := value.(type) { @@ -104,27 +120,27 @@ func (n Number) ValidateValue(_ context.Context, value any) error { case uint64: i = int64(v) default: - return NewResult().WithError(NewValidationError(n.notNumberMessage)) + return NewResult().WithError(NewValidationError(r.notNumberMessage)) } result := NewResult() - if i < n.min { + if i < r.min { result = result.WithError( - NewValidationError(n.tooSmallMessage). + NewValidationError(r.tooSmallMessage). WithParams(map[string]any{ - "min": n.min, - "max": n.max, + "min": r.min, + "max": r.max, }), ) } - if i > n.max { + if i > r.max { result = result.WithError( - NewValidationError(n.tooBigMessage). + NewValidationError(r.tooBigMessage). WithParams(map[string]any{ - "min": n.min, - "max": n.max, + "min": r.min, + "max": r.max, }), ) } diff --git a/required.go b/required.go index 1c94a2a..570438a 100644 --- a/required.go +++ b/required.go @@ -11,35 +11,42 @@ type Required struct { whenFunc WhenFunc } -func NewRequired() Required { - return Required{ +func NewRequired() *Required { + return &Required{ message: "Value cannot be blank.", allowZeroValue: false, } } -func (r Required) When(v WhenFunc) Required { - r.whenFunc = v +func (r *Required) When(v WhenFunc) *Required { + rc := *r + rc.whenFunc = v - return r + return &rc } -func (r Required) when() WhenFunc { +func (r *Required) when() WhenFunc { return r.whenFunc } -func (s Required) WithMessage(message string) Required { - s.message = message - return s +func (r *Required) setWhen(v WhenFunc) { + r.whenFunc = v +} + +func (r *Required) WithMessage(message string) *Required { + rc := *r + rc.message = message + + return &rc } -func (r Required) WithAllowZeroValue() Required { +func (r *Required) WithAllowZeroValue() *Required { r.allowZeroValue = true return r } -func (r Required) ValidateValue(_ context.Context, value any) error { +func (r *Required) ValidateValue(_ context.Context, value any) error { v := reflect.ValueOf(value) if valueIsEmpty(v, r.allowZeroValue) { return NewResult().WithError(NewValidationError(r.message)) diff --git a/result_test.go b/result_test.go index e6bed7d..61236df 100644 --- a/result_test.go +++ b/result_test.go @@ -8,10 +8,10 @@ import ( func TestResult_Errors_Successfully(t *testing.T) { res := NewResult().WithError(NewValidationError("test err")) - errs := res.Errors() - assert.Equal(t, []*ValidationError{{Message: "test err"}}, errs) + errors := res.Errors() + assert.Equal(t, []*ValidationError{{Message: "test err"}}, errors) - errs = append(errs, &ValidationError{Message: "invisible error"}) + _ = append(errors, &ValidationError{Message: "invisible error"}) assert.Equal(t, []*ValidationError{{Message: "test err"}}, res.Errors()) res = res.WithError(NewValidationError("test2 err")) diff --git a/string.go b/string.go index 4c4df54..795ba84 100644 --- a/string.go +++ b/string.go @@ -18,8 +18,8 @@ type StringLength struct { skipEmpty bool } -func NewStringLength(min, max int) StringLength { - return StringLength{ +func NewStringLength(min, max int) *StringLength { + return &StringLength{ message: "This value must be a string.", tooShortMessage: "This value should contain at least {min}.", tooLongMessage: "This value should contain at most {max}.", @@ -28,72 +28,85 @@ func NewStringLength(min, max int) StringLength { } } -func (s StringLength) WithMessage(message string) StringLength { - s.message = message +func (r *StringLength) WithMessage(message string) *StringLength { + rc := *r + rc.message = message - return s + return &rc } -func (s StringLength) WithTooShortMessage(message string) StringLength { - s.tooShortMessage = message +func (r *StringLength) WithTooShortMessage(message string) *StringLength { + rc := *r + rc.tooShortMessage = message - return s + return &rc } -func (s StringLength) WithTooLongMessage(message string) StringLength { - s.tooLongMessage = message +func (r *StringLength) WithTooLongMessage(message string) *StringLength { + rc := *r + rc.tooLongMessage = message - return s + return &rc } -func (s StringLength) When(v WhenFunc) StringLength { - s.whenFunc = v +func (r *StringLength) When(v WhenFunc) *StringLength { + rc := *r + rc.whenFunc = v - return s + return &rc } -func (s StringLength) when() WhenFunc { - return s.whenFunc +func (r *StringLength) when() WhenFunc { + return r.whenFunc } -func (s StringLength) SkipOnEmpty(v bool) StringLength { - s.skipEmpty = v +func (r *StringLength) setWhen(v WhenFunc) { + r.whenFunc = v +} + +func (r *StringLength) SkipOnEmpty(v bool) *StringLength { + rc := *r + rc.skipEmpty = v + + return &rc +} - return s +func (r *StringLength) skipOnEmpty() bool { + return r.skipEmpty } -func (s StringLength) skipOnEmpty() bool { - return s.skipEmpty +func (r *StringLength) setSkipOnEmpty(v bool) { + r.skipEmpty = v } -func (s StringLength) ValidateValue(_ context.Context, value any) error { +func (r *StringLength) ValidateValue(_ context.Context, value any) error { v, ok := toString(value) if !ok { - return NewResult().WithError(NewValidationError(s.message)) + return NewResult().WithError(NewValidationError(r.message)) } result := NewResult() v = strings.Trim(v, " ") l := utf8.RuneCountInString(v) - if l < s.min { + if l < r.min { result = NewResult(). WithError( - NewValidationError(s.tooShortMessage). + NewValidationError(r.tooShortMessage). WithParams(map[string]any{ - "min": s.min, - "max": s.max, + "min": r.min, + "max": r.max, }), ) } - if l > s.max { + if l > r.max { result = NewResult(). WithError( - NewValidationError(s.tooLongMessage). + NewValidationError(r.tooLongMessage). WithParams(map[string]any{ - "min": s.min, - "max": s.max, + "min": r.min, + "max": r.max, }), ) } diff --git a/time.go b/time.go index e8efea3..e5f4de2 100644 --- a/time.go +++ b/time.go @@ -21,8 +21,8 @@ type Time struct { skipEmpty bool } -func NewTime() Time { - return Time{ +func NewTime() *Time { + return &Time{ message: "Value is invalid", formatMessage: "Format of the time value must be equal {format}", tooBigMessage: "Time must be no greater than {max}.", @@ -33,91 +33,108 @@ func NewTime() Time { } } -func (t Time) WithMessage(message string) Time { - t.message = message +func (r *Time) WithMessage(message string) *Time { + rc := *r + rc.message = message - return t + return &rc } -func (t Time) WithFormatMessage(message string) Time { - t.formatMessage = message +func (r *Time) WithFormatMessage(message string) *Time { + rc := *r + rc.formatMessage = message - return t + return &rc } -func (t Time) WithTooSmallMessage(message string) Time { - t.tooSmallMessage = message +func (r *Time) WithTooSmallMessage(message string) *Time { + rc := *r + rc.tooSmallMessage = message - return t + return &rc } -func (t Time) WithTooBigMessage(message string) Time { - t.tooBigMessage = message +func (r *Time) WithTooBigMessage(message string) *Time { + rc := *r + rc.tooBigMessage = message - return t + return &rc } -func (t Time) WithFormat(format string) Time { - t.format = format +func (r *Time) WithFormat(format string) *Time { + rc := *r + rc.format = format - return t + return &rc } -func (t Time) WithMin(min TimeFunc) Time { - t.min = min +func (r *Time) WithMin(min TimeFunc) *Time { + rc := *r + rc.min = min - return t + return &rc } -func (t Time) WithMax(max TimeFunc) Time { - t.max = max +func (r *Time) WithMax(max TimeFunc) *Time { + rc := *r + rc.max = max - return t + return &rc } -func (t Time) When(v WhenFunc) Time { - t.whenFunc = v +func (r *Time) When(v WhenFunc) *Time { + rc := *r + rc.whenFunc = v - return t + return &rc } -func (t Time) when() WhenFunc { - return t.whenFunc +func (r *Time) when() WhenFunc { + return r.whenFunc } -func (t Time) SkipOnEmpty(v bool) Time { - t.skipEmpty = v +func (r *Time) setWhen(v WhenFunc) { + r.whenFunc = v +} + +func (r *Time) SkipOnEmpty(v bool) *Time { + rc := *r + rc.skipEmpty = v + + return &rc +} - return t +func (r *Time) skipOnEmpty() bool { + return r.skipEmpty } -func (t Time) skipOnEmpty() bool { - return t.skipEmpty +func (r *Time) setSkipOnEmpty(v bool) { + r.skipEmpty = v } -func (t Time) ValidateValue(_ context.Context, value any) error { +func (r *Time) ValidateValue(_ context.Context, value any) error { v, valid := indirectValue(value) if !valid { - return NewResult().WithError(NewValidationError(t.message)) + return NewResult().WithError(NewValidationError(r.message)) } vStr, okStr := toString(value) vObj, okObj := v.(vtype.Time) if !okStr && !okObj { - return NewResult().WithError(NewValidationError(t.message)) + return NewResult().WithError(NewValidationError(r.message)) } if okObj { vStr = vObj.String() } - vt, err := time.Parse(t.format, vStr) + vt, err := time.Parse(r.format, vStr) if err != nil { return NewResult().WithError( - NewValidationError(t.formatMessage). + NewValidationError(r.formatMessage). WithParams( map[string]any{ - "format": t.format, + "format": r.format, }, ), ) @@ -125,11 +142,11 @@ func (t Time) ValidateValue(_ context.Context, value any) error { result := NewResult() - if t.min != nil { - minTime := t.min() + if r.min != nil { + minTime := r.min() if vt.Before(minTime) { result = result.WithError( - NewValidationError(t.tooSmallMessage). + NewValidationError(r.tooSmallMessage). WithParams( map[string]any{ "min": minTime, @@ -139,11 +156,11 @@ func (t Time) ValidateValue(_ context.Context, value any) error { } } - if t.max != nil { - maxTime := t.max() + if r.max != nil { + maxTime := r.max() if vt.After(maxTime) { result = result.WithError( - NewValidationError(t.tooBigMessage). + NewValidationError(r.tooBigMessage). WithParams( map[string]any{ "max": maxTime, diff --git a/unique_values.go b/unique_values.go index 94b222e..3089e8d 100644 --- a/unique_values.go +++ b/unique_values.go @@ -11,39 +11,50 @@ type UniqueValues struct { skipEmpty bool } -func NewUniqueValues() UniqueValues { - return UniqueValues{ +func NewUniqueValues() *UniqueValues { + return &UniqueValues{ message: "The list of values must be unique.", } } -func (r UniqueValues) WithMessage(message string) UniqueValues { - r.message = message +func (r *UniqueValues) WithMessage(message string) *UniqueValues { + rc := *r + rc.message = message - return r + return &rc } -func (r UniqueValues) When(v WhenFunc) UniqueValues { - r.whenFunc = v +func (r *UniqueValues) When(v WhenFunc) *UniqueValues { + rc := *r + rc.whenFunc = v - return r + return &rc } -func (r UniqueValues) when() WhenFunc { +func (r *UniqueValues) when() WhenFunc { return r.whenFunc } -func (r UniqueValues) SkipOnEmpty(v bool) UniqueValues { - r.skipEmpty = v +func (r *UniqueValues) setWhen(v WhenFunc) { + r.whenFunc = v +} - return r +func (r *UniqueValues) SkipOnEmpty(v bool) *UniqueValues { + rc := *r + rc.skipEmpty = v + + return &rc } -func (r UniqueValues) skipOnEmpty() bool { +func (r *UniqueValues) skipOnEmpty() bool { return r.skipEmpty } -func (r UniqueValues) ValidateValue(_ context.Context, value any) error { +func (r *UniqueValues) setSkipOnEmpty(v bool) { + r.skipEmpty = v +} + +func (r *UniqueValues) ValidateValue(_ context.Context, value any) error { if reflect.TypeOf(value).Kind() != reflect.Slice { return NewResult().WithError(NewValidationError(r.message)) } diff --git a/url.go b/url.go index 49b0813..ea51bb7 100644 --- a/url.go +++ b/url.go @@ -22,71 +22,88 @@ type URL struct { skipEmpty bool } -func NewURL() URL { - return URL{ +func NewURL() *URL { + return &URL{ validSchemes: []string{"http", "https"}, enableIDN: false, message: "This value is not a valid URL.", } } -func (u URL) WithValidScheme(scheme ...string) URL { - u.validSchemes = scheme +func (r *URL) WithValidScheme(scheme ...string) *URL { + rc := *r + rc.validSchemes = scheme - return u + return &rc } -func (u URL) WithMessage(message string) URL { - u.message = message +func (r *URL) WithMessage(message string) *URL { + rc := *r + rc.message = message - return u + return &rc } -func (u URL) WithEnableIDN() URL { - u.enableIDN = true +func (r *URL) WithEnableIDN() *URL { + rc := *r + rc.enableIDN = true - return u + return &rc } -func (u URL) When(v WhenFunc) URL { - u.whenFunc = v +func (r *URL) When(v WhenFunc) *URL { + rc := *r + rc.whenFunc = v - return u + return &rc } -func (u URL) when() WhenFunc { - return u.whenFunc +func (r *URL) when() WhenFunc { + return r.whenFunc } -func (u URL) SkipOnEmpty(v bool) URL { - u.skipEmpty = v +func (r *URL) setWhen(v WhenFunc) { + r.whenFunc = v +} + +func (r *URL) SkipOnEmpty(v bool) *URL { + rc := *r + rc.skipEmpty = v + + return &rc +} + +func (r *URL) skipOnEmpty() bool { + return r.skipEmpty +} - return u +func (r *URL) setSkipOnEmpty(v bool) { + r.skipEmpty = v } -func (u URL) ValidateValue(_ context.Context, value any) error { +func (r *URL) ValidateValue(_ context.Context, value any) error { v, ok := toString(value) // make sure the length is limited to avoid DOS attacks if !ok || len(v) >= 2000 { - return NewResult().WithError(NewValidationError(u.message)) + return NewResult().WithError(NewValidationError(r.message)) } - if u.enableIDN { - v = u.convertIDN(v) + if r.enableIDN { + v = r.convertIDN(v) } uri, err := url.Parse(v) if err != nil { - return NewResult().WithError(NewValidationError(u.message)) + return NewResult().WithError(NewValidationError(r.message)) } if len(uri.Scheme) == 0 || (len(uri.Host) == 0 && len(uri.Opaque) == 0) { - return NewResult().WithError(NewValidationError(u.message)) + return NewResult().WithError(NewValidationError(r.message)) } - if len(u.validSchemes) > 0 && u.validSchemes[0] != AllowAnyURLSchema { + if len(r.validSchemes) > 0 && r.validSchemes[0] != AllowAnyURLSchema { isValidScheme := false - for _, s := range u.validSchemes { + for _, s := range r.validSchemes { if s == uri.Scheme { isValidScheme = true break @@ -94,25 +111,25 @@ func (u URL) ValidateValue(_ context.Context, value any) error { } if !isValidScheme { - return NewResult().WithError(NewValidationError(u.message)) + return NewResult().WithError(NewValidationError(r.message)) } } return nil } -func (u URL) convertIDN(value string) string { +func (r *URL) convertIDN(value string) string { if !strings.Contains(value, "://") { - return u.idnToASCII(value) + return r.idnToASCII(value) } return regexpDomain.ReplaceAllStringFunc(value, func(m string) string { p := regexpDomain.FindStringSubmatch(m) - return "://" + u.idnToASCII(p[1]) + return "://" + r.idnToASCII(p[1]) }) } -func (u URL) idnToASCII(idn string) string { +func (r *URL) idnToASCII(idn string) string { if d, err := idna.ToASCII(idn); err == nil { return d } else { diff --git a/uuid.go b/uuid.go index 2475a5e..1df986b 100644 --- a/uuid.go +++ b/uuid.go @@ -25,48 +25,65 @@ type UUID struct { skipEmpty bool } -func NewUUID() UUID { - return UUID{ +func NewUUID() *UUID { + return &UUID{ message: "Invalid UUID format.", invalidVersionMessage: "UUID version must be equal to {version}.", } } -func (r UUID) WithMessage(message string) UUID { - r.message = message +func (r *UUID) WithMessage(message string) *UUID { + rc := *r + rc.message = message - return r + return &rc } -func (r UUID) WithInvalidVersionMessage(message string) UUID { - r.invalidVersionMessage = message +func (r *UUID) WithInvalidVersionMessage(message string) *UUID { + rc := *r + rc.invalidVersionMessage = message - return r + return &rc } -func (r UUID) WithVersion(version UUIDVersion) UUID { - r.version = version +func (r *UUID) WithVersion(version UUIDVersion) *UUID { + rc := *r + rc.version = version - return r + return &rc } -func (r UUID) When(v WhenFunc) UUID { - r.whenFunc = v +func (r *UUID) When(v WhenFunc) *UUID { + rc := *r + rc.whenFunc = v - return r + return &rc } -func (r UUID) when() WhenFunc { +func (r *UUID) when() WhenFunc { return r.whenFunc } -func (r UUID) SkipOnEmpty(v bool) UUID { - r.skipEmpty = v +func (r *UUID) setWhen(v WhenFunc) { + r.whenFunc = v +} + +func (r *UUID) SkipOnEmpty(v bool) *UUID { + rc := *r + rc.skipEmpty = v - return r + return &rc +} + +func (r *UUID) skipOnEmpty() bool { + return r.skipEmpty +} + +func (r *UUID) setSkipOnEmpty(v bool) { + r.skipEmpty = v } -func (r UUID) ValidateValue(_ context.Context, value any) error { +func (r *UUID) ValidateValue(_ context.Context, value any) error { v, ok := toString(value) if !ok { return NewResult().WithError(NewValidationError(r.message)) diff --git a/validator.go b/validator.go index fd2331e..fcb5fa7 100644 --- a/validator.go +++ b/validator.go @@ -27,7 +27,7 @@ func ValidateValue(ctx context.Context, value any, rules ...Rule) error { return err } - if extDS, ok := extractDataSet(ctx); !ok || value != extDS { + if extDS, ok := ExtractDataSet[DataSet](ctx); !ok || value != extDS { ctx = withDataSet(ctx, dataSet) } @@ -81,7 +81,7 @@ func Validate(ctx context.Context, dataSet any, rules RuleSet) error { if isSkipValidate(ctx, fieldValue, validatorRule) { continue } - if _, ok := validatorRule.(Required); !ok { + if _, ok := validatorRule.(*Required); !ok { if fieldValue == nil { // if value is not required and is nil continue @@ -162,7 +162,7 @@ func normalizeRules(rules []Rule) []Rule { } for i := range rules { - if r, ok := rules[i].(Required); ok { + if r, ok := rules[i].(*Required); ok { if i == 0 { break } @@ -185,6 +185,12 @@ func isSkipValidate(ctx context.Context, value any, r Rule) bool { } } + if rser, ok := r.(RuleSkipError); ok { + if rser.shouldSkipOnError(ctx) && previousRulesErrored(ctx) { + return true + } + } + if rw, ok := r.(RuleWhen); ok { return rw.when() != nil && !rw.when()(ctx, value) }