Skip to content

Commit

Permalink
set default optional foreign key UUIDs to null on full update
Browse files Browse the repository at this point in the history
  • Loading branch information
kataras committed Nov 14, 2023
1 parent 11f04ff commit 5359735
Show file tree
Hide file tree
Showing 5 changed files with 374 additions and 11 deletions.
24 changes: 19 additions & 5 deletions desc/argument.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package desc
import (
"fmt"
"reflect"

"github.com/jackc/pgx/v5/pgtype/zeronull"
)

// Argument represents a single argument for a database query
Expand Down Expand Up @@ -69,6 +71,7 @@ func extractArguments(td *Table, structValue reflect.Value, filter func(columnNa

fieldValue := field.Interface() // get the field value as an interface

// If filter passed, respect just the filter.
if filter != nil {
if !filter(c.Name) {
continue
Expand All @@ -83,7 +86,7 @@ func extractArguments(td *Table, structValue reflect.Value, filter func(columnNa
}
}

if c.Default != "" && c.Type == UUID && c.PrimaryKey && !c.Nullable {
if c.Default != "" && c.Type == UUID && !c.Nullable && c.PrimaryKey {
if isZero(fieldValue) {
continue // skip this field if it is a UUID primary key and required and the field value is zero
}
Expand Down Expand Up @@ -117,10 +120,10 @@ func extractArguments(td *Table, structValue reflect.Value, filter func(columnNa
}

// filterArguments takes a slice of arguments and a filter function and returns a slice of arguments.
func filterArguments(args Arguments, filter func(arg Argument) bool) Arguments {
func filterArguments(args Arguments, filter func(arg *Argument) bool) Arguments {
var filtered Arguments
for _, arg := range args {
if filter(arg) {
if filter(&arg) {
filtered = append(filtered, arg)
}
}
Expand All @@ -129,8 +132,19 @@ func filterArguments(args Arguments, filter func(arg Argument) bool) Arguments {

// FilterArgumentsForInsert takes a slice of arguments and returns a slice of arguments for insert.
func filterArgumentsForFullUpdate(args Arguments) Arguments {
return filterArguments(args, func(arg Argument) bool {
return !arg.Column.IsGenerated() && !arg.Column.Presenter // && !arg.Column.Unscannable
return filterArguments(args, func(arg *Argument) bool {
c := arg.Column

if (c.PrimaryKey || c.ReferenceColumnName != "") && c.Default != "" && c.Type == UUID && c.Nullable {
if isZero(arg.Value) { // fixes full update of a record which contains an optional reference UUID, we allow setting it to null, but
// we have to replace empty string with zeronull.UUID{}. Note that on insert we omit it from the query, as it will default to the default sql line default value.
arg.Value = zeronull.UUID{}
}

return true
}

return !c.IsGenerated() && !c.Presenter // && !arg.Column.Unscannable
})
}

Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/crypto v0.15.0 // indirect
golang.org/x/sync v0.5.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/text v0.14.0 // indirect
)
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA=
golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g=
golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0=
golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down
184 changes: 184 additions & 0 deletions http_controller.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
package pg

import (
"encoding/json"
"fmt"
"net/http"
"reflect"
"strings"

"github.com/kataras/pg/desc"
)

// Exampler is an interface used by testing to generate example values for a specific struct field.
type Exampler interface {
ListExamples() any
}

type HTTPController[T any] struct {
repository *Repository[T]
primaryKeyType desc.DataType

// ErrorHandler defaults to the PG's error handler. It can be customized for this controller.
// Setting this to nil will panic the application on the first error.
ErrorHandler func(w http.ResponseWriter, r *http.Request, err error)

// AfterPayloadRead is called after the payload is read.
// It can be used to validate the payload or set default fields based on the request Context.
AfterPayloadRead func(w http.ResponseWriter, r *http.Request, payload T) (T, bool)
}

func NewHTTPController[T any](repository *Repository[T]) *HTTPController[T] {
return &HTTPController[T]{
repository: repository,
}
}

type (
jsonSchema[T any] struct {
Description string `json:"description,omitempty"`
Types []jsonSchemaFieldType `json:"types,omitempty"`
Fields []jsonSchemaField `json:"fields"`
}

jsonSchemaFieldType struct {
Name string `json:"name"`
Example any `json:"example,omitempty"`
}

jsonSchemaField struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Type string `json:"type"`
DataType string `json:"data_type"`
Required bool `json:"required"`
}
)

func newJSONSchema[T any](td *desc.Table) *jsonSchema[T] {
var fieldTypes []jsonSchemaFieldType
seenFieldTypes := make(map[reflect.Type]struct{})

fields := make([]jsonSchemaField, 0, len(td.Columns))
for _, col := range td.Columns {
fieldName, ok := getJSONTag(col.Table.StructType, col.FieldIndex)
if !ok {
fieldName = col.Name
}

// Get the field type examples.
if _, seen := seenFieldTypes[col.FieldType]; !seen {
seenFieldTypes[col.FieldType] = struct{}{}

colValue := reflect.New(col.FieldType).Interface()
if exampler, ok := colValue.(Exampler); ok {
exampleValues := exampler.ListExamples()
fieldTypes = append(fieldTypes, jsonSchemaFieldType{
Name: col.FieldType.String(),
Example: exampleValues,
})
}
}

field := jsonSchemaField{
// Here we want the json tag name, not the column name.
Name: fieldName,
Description: col.Description,
Type: col.FieldType.String(),
DataType: col.Type.String(),
Required: !col.Nullable,
}

fields = append(fields, field)
}

return &jsonSchema[T]{
Description: td.Description,
Types: fieldTypes,
Fields: fields,
}
}

func getJSONTag(t reflect.Type, fieldIndex []int) (string, bool) {
if t.Kind() != reflect.Struct {
return "", false
}

f := t.FieldByIndex(fieldIndex)
jsonTag := f.Tag.Get("json")
if jsonTag == "" {
return "", false
}

return strings.Split(jsonTag, ",")[0], true
}

func writeJSON(w http.ResponseWriter, code int, v any) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
return json.NewEncoder(w).Encode(v)
}

func readJSON(r *http.Request, v any) error {
return json.NewDecoder(r.Body).Decode(v)
}

func (c *HTTPController[T]) getSchema(s *jsonSchema[T]) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, s)
})
}

type idPayload struct {
ID any `json:"id"`
}

func toUUIDv4(v [16]uint8) string {
slice := v[:]
// Modify the 7th element to have the form 4xxx
slice[6] = (slice[6] & 0x0f) | 0x40
// Modify the 9th element to have the form yxxx
slice[8] = (slice[8] & 0x3f) | 0x80
// Convert to UUIDv4 string
s := fmt.Sprintf("%x-%x-%x-%x-%x", slice[0:4], slice[4:6], slice[6:8], slice[8:10], slice[10:])
return s
}

// readPayload reads the request body and returns the entity.
func (c *HTTPController[T]) readPayload(w http.ResponseWriter, r *http.Request) (T, bool) {
var payload T
err := readJSON(r, &payload)
if err != nil {
c.ErrorHandler(w, r, err)
return payload, false
}

if c.AfterPayloadRead != nil {
return c.AfterPayloadRead(w, r, payload)
}

return payload, true
}

// create creates a new entity.
func (c *HTTPController[T]) create(w http.ResponseWriter, r *http.Request) {
entry, ok := c.readPayload(w, r)
if !ok {
return
}

var id any
err := c.repository.InsertSingle(r.Context(), entry, &id)
if err != nil {
c.ErrorHandler(w, r, err)
return
}

switch c.primaryKeyType {
case desc.UUID:
// A special case to convert from [16]uint8 to string (uuidv4). We do this in order to not accept a 2nd generic parameter of V.
id = toUUIDv4(id.([16]uint8))
}

writeJSON(w, http.StatusCreated, idPayload{ID: id})
}
Loading

0 comments on commit 5359735

Please sign in to comment.