Skip to content

Commit

Permalink
fix: allow updating zero values and fix pointer resolution on isZero
Browse files Browse the repository at this point in the history
  • Loading branch information
kataras committed Oct 28, 2023
1 parent 979ffb5 commit a5c06f9
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 27 deletions.
2 changes: 2 additions & 0 deletions db_information_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ func ExampleDB_ListColumns() {
`[customers.cognito_user_id] pg:"name=cognito_user_id,type=uuid,unique_index=customer_unique_idx"`,
`[customers.email] pg:"name=email,type=varchar(255),unique_index=customer_unique_idx"`,
`[customers.name] pg:"name=name,type=varchar(255),index=btree"`,
`[customers.username] pg:"name=username,type=varchar(255),default=''::character varying"`,
}

for i, column := range columns {
Expand Down Expand Up @@ -167,6 +168,7 @@ func ExampleDB_ListColumnsInformationSchema() {
// &desc.ColumnBasicInfo{TableName:"customers", TableDescription:"", TableType:0x0, Name:"cognito_user_id", OrdinalPosition:4, Description:"", Default:"", DataType:0x2f, DataTypeArgument:"", IsNullable:false, IsIdentity:false, IsGenerated:false}
// &desc.ColumnBasicInfo{TableName:"customers", TableDescription:"", TableType:0x0, Name:"email", OrdinalPosition:5, Description:"", Default:"", DataType:0xb, DataTypeArgument:"255", IsNullable:false, IsIdentity:false, IsGenerated:false}
// &desc.ColumnBasicInfo{TableName:"customers", TableDescription:"", TableType:0x0, Name:"name", OrdinalPosition:6, Description:"", Default:"", DataType:0xb, DataTypeArgument:"255", IsNullable:false, IsIdentity:false, IsGenerated:false}
// &desc.ColumnBasicInfo{TableName:"customers", TableDescription:"", TableType:0x0, Name:"username", OrdinalPosition:7, Description:"", Default:"''::character varying", DataType:0xb, DataTypeArgument:"255", IsNullable:false, IsIdentity:false, IsGenerated:false}
}

func ExampleDB_ListConstraints() {
Expand Down
20 changes: 14 additions & 6 deletions desc/argument.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func shiftToEndEnd[T any](s []T, x int) []T {

// extractArguments takes a reflect value of a struct and a table definition
// and returns a slice of arguments for each column in the table that is not auto-generated or has a default value.
func extractArguments(td *Table, structValue reflect.Value) (Arguments, error) {
func extractArguments(td *Table, structValue reflect.Value, filter func(columnName string) bool) (Arguments, error) {
args := make(Arguments, 0, len(td.Columns)) // create a slice to hold the arguments

for _, c := range td.Columns { // loop over each column in the table definition
Expand All @@ -69,13 +69,21 @@ func extractArguments(td *Table, structValue reflect.Value) (Arguments, error) {

fieldValue := field.Interface() // get the field value as an interface

if c.Default != "" {
if isZero(fieldValue) {
// skip this field if it has a default value and the field value is zero,
// the createTable function has configured the database's default value option
if filter != nil {
if !filter(c.Name) {
continue
}
} else if c.Type == UUID && c.PrimaryKey && !c.Nullable {
} else { // if no custom filter passed, then check by its zero value if no default value on database.
if c.Default != "" {
if isZero(field) {
// skip this field if it has a default value and the field value is zero,
// the createTable function has configured the database's default value option
continue
}
}
}

if c.Default != "" && c.Type == UUID && c.PrimaryKey && !c.Nullable {
if isZero(fieldValue) {
continue // skip this field if it is a UUID primary key and required and the field value is zero
}
Expand Down
2 changes: 1 addition & 1 deletion desc/exists_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
// BuildExistsQuery builds and returns an SQL query for checking of existing in a row in the table,
// based on the given struct value.
func BuildExistsQuery(td *Table, structValue reflect.Value) (string, []any, error) {
args, err := extractArguments(td, structValue)
args, err := extractArguments(td, structValue, nil)
if err != nil {
return "", nil, err // return the error if finding arguments fails
}
Expand Down
2 changes: 1 addition & 1 deletion desc/insert_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func BuildInsertQuery(td *Table, structValue reflect.Value, idPtr any, forceOnCo
}

// find the arguments for the SQL query based on the struct value and the table definition
args, err := extractArguments(td, structValue)
args, err := extractArguments(td, structValue, nil)
if err != nil {
return "", nil, err // return the error if finding arguments fails
}
Expand Down
30 changes: 18 additions & 12 deletions desc/update_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,28 @@ func extractUpdateArguments(value any, columnsToUpdate []string, primaryKey *Col
return nil, err
}

args, err := extractArguments(primaryKey.Table, structValue)
columnsToUpdateLength := len(columnsToUpdate)

args, err := extractArguments(primaryKey.Table, structValue, func(fieldName string) bool {
if columnsToUpdateLength == 0 {
// full update.
return true
}

for _, onlyColumnName := range columnsToUpdate {
if onlyColumnName == fieldName {
return true
}
}

return false
})
if err != nil {
return nil, err // return the error if finding arguments fails
}

if len(columnsToUpdate) > 0 { // if specific columns to update, then override the default behavior.
args = filterArguments(args, func(arg Argument) bool {
for _, onlyColumnName := range columnsToUpdate {
if arg.Column.Name == onlyColumnName {
return true
}
}

return false
})
} else { // otherwise full update, even zero values (e.g. integer 0) all except ID and any created_at, updated_at.
if columnsToUpdateLength == 0 {
// full update, even zero values (e.g. integer 0) all except ID and any created_at, updated_at.
args = filterArgumentsForFullUpdate(args)
}

Expand Down
113 changes: 109 additions & 4 deletions desc/zeroer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"encoding/json"
"math/big"
"net"
"reflect"
"time"
)

// Zeroer is an interface that defines a method to check if a value is zero.
Expand All @@ -23,11 +25,114 @@ func isZero(v any) bool {
}

switch t := v.(type) { // switch on the type of the value
case Zeroer: // if the value implements the Zeroer interface (this includes time.Time as well)
if t == nil { // if the value is nil, return true
return true
case *time.Time:
return t == nil || t.IsZero()
case *string:
return t == nil || *t == ""
case *int:
return t == nil || *t == 0
case *int8:
return t == nil || *t == 0
case *int16:
return t == nil || *t == 0
case *int32:
return t == nil || *t == 0
case *int64:
return t == nil || *t == 0
case *uint:
return t == nil || *t == 0
case *uint8:
return t == nil || *t == 0
case *uint16:
return t == nil || *t == 0
case *uint32:
return t == nil || *t == 0
case *uint64:
return t == nil || *t == 0
case *float32:
return t == nil || *t == 0
case *float64:
return t == nil || *t == 0
case *bool:
return t == nil || !*t
case *[]string:
return t == nil || len(*t) == 0
case *[]int:
return t == nil || len(*t) == 0
case *[]int8:
return t == nil || len(*t) == 0
case *[]int16:
return t == nil || len(*t) == 0
case *[]int32:
return t == nil || len(*t) == 0
case *[]int64:
return t == nil || len(*t) == 0
case *[]uint:
return t == nil || len(*t) == 0
case *[]uint8:
return t == nil || len(*t) == 0
case *[]uint16:
return t == nil || len(*t) == 0
case *[]uint32:
return t == nil || len(*t) == 0
case *[]uint64:
return t == nil || len(*t) == 0
case *[]float32:
return t == nil || len(*t) == 0
case *[]float64:
return t == nil || len(*t) == 0
case *[]bool:
return t == nil || len(*t) == 0
case *[]any:
return t == nil || len(*t) == 0
case *map[string]string:
return t == nil || len(*t) == 0
case *map[string]int:
return t == nil || len(*t) == 0
case *map[string]any:
return t == nil || len(*t) == 0
case *map[int]int:
return t == nil || len(*t) == 0
case *map[int]any:
return t == nil || len(*t) == 0
case *map[any]any:
return t == nil || len(*t) == 0
case *map[any]int:
return t == nil || len(*t) == 0
case *map[any]string:
return t == nil || len(*t) == 0
case *map[any]float64:
return t == nil || len(*t) == 0
case *map[any]bool:
return t == nil || len(*t) == 0
case *map[any][]any:
return t == nil || len(*t) == 0
case *map[any][]int:
return t == nil || len(*t) == 0
case *map[any][]string:
return t == nil || len(*t) == 0
case *map[any]map[any]any:
return t == nil || len(*t) == 0
case *map[any]map[any]int:
return t == nil || len(*t) == 0
case *map[any]map[any]string:
return t == nil || len(*t) == 0
case *map[any]map[any]float64:
return t == nil || len(*t) == 0
case *map[any]map[any]bool:
return t == nil || len(*t) == 0
case *map[any]map[any][]any:
return t == nil || len(*t) == 0
case *map[any]map[any][]int:
return t == nil || len(*t) == 0
case reflect.Value:
if t.Kind() == reflect.Ptr {
return t.IsNil()
}
return t.IsZero() // otherwise, call the IsZero method and return its result

return t.IsZero()
case Zeroer: // if the value implements the Zeroer interface
return t == nil || t.IsZero() // call the IsZero method on the value
case string: // if the value is a string
return t == "" // return true if the string is empty
case int: // if the value is an int
Expand Down
15 changes: 14 additions & 1 deletion desc/zeroer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@ import (
"encoding/json"
"math/big"
"net"
"reflect"
"testing"
"time"
)

// TestIsZero tests the isZero function with various inputs and outputs
func TestIsZero(t *testing.T) {
now := time.Now()
timePtr := &now
var nilTimePtr *time.Time

// Define a table of test cases
testCases := []struct {
input any // input value
Expand Down Expand Up @@ -44,10 +49,18 @@ func TestIsZero(t *testing.T) {
{net.IPv4(127, 0, 0, 1), false}, // non-empty net.IP should not be zero
{time.Time{}, true}, // empty time.Time (zero time) should be zero
{time.Now(), false}, // non-empty time.Time (current time) should not be zero
{timePtr, false}, // non-nil time.Time (current time) should not be zero
{nilTimePtr, true}, // nil time.Time should be zero
}

for i, tc := range testCases {
if tc.input == nil {
isNil := false

if val := reflect.ValueOf(tc.input); val.Kind() == reflect.Pointer {
isNil = val.IsNil()
}

if tc.input == nil || isNil {
t.Run("nil", func(t *testing.T) {
result := isZero(tc.input) // call the isZero function with the input
if result != tc.output { // compare the result with the expected output
Expand Down
21 changes: 19 additions & 2 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func Example() {
CognitoUserID: "373f90eb-00ac-410f-9fe0-1a7058d090ba",
Email: "[email protected]",
Name: "kataras",
Username: "kataras",
}

// Insert the customer into the database and get its ID.
Expand All @@ -43,8 +44,6 @@ func Example() {
// Update specific columns by id:
updated, err := customers.UpdateOnlyColumns(
context.Background(),
// TODO: make a generator for: models.Customers.Columns.CognitoUserID.Name so
// end-developers have static safety for columns.
[]string{"cognito_user_id"},
Customer{
BaseEntity: BaseEntity{
Expand All @@ -68,6 +67,21 @@ func Example() {
return fmt.Errorf("update: no record was updated")
}

// Update a default column to its zero value.
updated, err = customers.UpdateOnlyColumns(
context.Background(),
[]string{"username"},
Customer{
BaseEntity: BaseEntity{
ID: customerToInsert.ID,
},
Username: "",
})
if err != nil {
return fmt.Errorf("update username: %w", err)
} else if updated == 0 {
return fmt.Errorf("update username: no record was updated")
}
// Select the customer from the database by its ID.
customer, err := customers.SelectSingle(context.Background(), `SELECT * FROM customers WHERE id = $1;`, customerToInsert.ID)
if err != nil {
Expand Down Expand Up @@ -218,4 +232,7 @@ func Example() {
fmt.Printf("expected other_features to be equal but got %#+v and %#+v", otherFeatures, existingBlogPost.OtherFeatures)
return
}

// Output:
//
}
2 changes: 2 additions & 0 deletions schema_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type Customer struct {
Email string `pg:"type=varchar(255),unique_index=customer_unique_idx"`
// ^ optional: unique to allow upsert by "email"-only column confliction instead of the unique_index.
Name string `pg:"type=varchar(255),index=btree"`

Username string `pg:"type=varchar(255),default=''::character varying"`
}

// Blog is a struct that represents a blog entity in the database.
Expand Down

0 comments on commit a5c06f9

Please sign in to comment.