Skip to content

Commit

Permalink
generator-go-sdk: fix unmarshalling fields that reference a model int…
Browse files Browse the repository at this point in the history
…erface (rather than a model struct)
  • Loading branch information
manicminer committed Sep 19, 2024
1 parent ce7e94e commit 66c8d0d
Showing 1 changed file with 161 additions and 85 deletions.
246 changes: 161 additions & 85 deletions tools/generator-go-sdk/internal/generator/templater_models.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,13 @@ func (c modelsTemplater) structCode(data GeneratorData) (*string, error) {
}
sort.Strings(fields)

structLines := make([]string, 0)
for _, fieldName := range fields {
fieldDetails := c.model.Fields[fieldName]
fieldTypeName := "FIXME"
fieldTypeVal, err := helpers.GolangTypeForSDKObjectDefinition(fieldDetails.ObjectDefinition, nil, data.commonTypesPackageName)
if err != nil {
return nil, fmt.Errorf("determining type information for %q: %+v", fieldName, err)
}
fieldTypeName = *fieldTypeVal

structLine, err := c.structLineForField(fieldName, fieldTypeName, fieldDetails, data)
if err != nil {
return nil, err
}

structLines = append(structLines, *structLine)
// Build struct field lines
structLines, err := c.structLinesForModel(data, fields, false, false)
if err != nil {
return nil, fmt.Errorf("building struct lines for %q model: %+v", c.name, err)
}

// then add any inherited fields
// If this is a child model, determine its ancestry
parentAssignmentInfo := ""
ancestorTypeNames := make([]string, 0)
if c.model.ParentTypeName != nil {
Expand All @@ -111,52 +99,6 @@ func (c modelsTemplater) structCode(data GeneratorData) (*string, error) {

// Since ancestor interfaces embed each other, we only need to satisfy the immediate parent interface
parentAssignmentInfo = fmt.Sprintf("var _ %[1]s = %[2]s{}", ancestorTypeNames[0], structName)

// We want to include fields from all ancestors, grouped by ancestor name
ancestorFields := make(map[string]map[string]models.SDKField)
for _, ancestorTypeName := range ancestorTypeNames {
parent, ok := data.models[ancestorTypeName]
if !ok {
return nil, fmt.Errorf("couldn't find Ancestor Model %q for Model %q", ancestorTypeName, c.name)
}
ancestorFields[ancestorTypeName] = make(map[string]models.SDKField)
for fieldName, fieldDetails := range parent.Fields {
ancestorFields[ancestorTypeName][fieldName] = fieldDetails
}
}

// Get sorted slices of ancestors' field names
ancestorFieldNames := make(map[string][]string)
for ancestorName := range ancestorFields {
ancestorFieldNames[ancestorName] = make([]string, 0, len(ancestorFields[ancestorName]))
for fieldName := range ancestorFields[ancestorName] {
ancestorFieldNames[ancestorName] = append(ancestorFieldNames[ancestorName], fieldName)
}
sort.Strings(ancestorFieldNames[ancestorName])
}

// Append fields from all ancestors to struct
for _, ancestorName := range ancestorTypeNames {
if len(ancestorFieldNames[ancestorName]) > 0 {
structLines = append(structLines, fmt.Sprintf("\n// Fields inherited from %s", ancestorName))
for _, fieldName := range ancestorFieldNames[ancestorName] {
fieldDetails := ancestorFields[ancestorName][fieldName]
fieldTypeName := "FIXME"
fieldTypeVal, err := helpers.GolangTypeForSDKObjectDefinition(fieldDetails.ObjectDefinition, nil, data.commonTypesPackageName)
if err != nil {
return nil, fmt.Errorf("determining type information for %q: %+v", fieldName, err)
}
fieldTypeName = *fieldTypeVal

structLine, err := c.structLineForField(fieldName, fieldTypeName, fieldDetails, data)
if err != nil {
return nil, err
}

structLines = append(structLines, *structLine)
}
}
}
}

// If this is a parent model, we output an Interface with a manual unmarshal func that gets called wherever it's used
Expand Down Expand Up @@ -184,12 +126,12 @@ type %[1]s interface {

// Format the model struct field lines
formattedStructLines := make([]string, 0)
for i, v := range structLines {
for i, v := range *structLines {
if strings.HasPrefix(strings.TrimSpace(v), "//") {
if i > 0 && !strings.HasSuffix(formattedStructLines[i-1], "\n") {
v = "\n" + v
}
if i < len(structLines)-1 {
if i < len(*structLines)-1 {
v += "\n"
}
}
Expand Down Expand Up @@ -292,7 +234,99 @@ func (c modelsTemplater) methods(data GeneratorData) (*string, error) {
return &output, nil
}

func (c modelsTemplater) structLineForField(fieldName, fieldType string, fieldDetails models.SDKField, data GeneratorData) (*string, error) {
func (c modelsTemplater) structLinesForModel(data GeneratorData, fieldNames []string, excludeComments, excludeDiscriminatedParentMembers bool) (*[]string, error) {
output := make([]string, 0)

for _, fieldName := range fieldNames {
fieldDetails := c.model.Fields[fieldName]

if excludeDiscriminatedParentMembers && fieldDetails.ObjectDefinition.ReferenceName != nil {
if referencedModel, ok := data.models[*fieldDetails.ObjectDefinition.ReferenceName]; ok && referencedModel.IsDiscriminatedParentType() {
// Skip fields that reference a parent model (i.e. an interface rather than a struct)
continue
}
}

fieldTypeName := "FIXME"
fieldTypeVal, err := helpers.GolangTypeForSDKObjectDefinition(fieldDetails.ObjectDefinition, nil, data.commonTypesPackageName)
if err != nil {
return nil, fmt.Errorf("determining type information for %q: %+v", fieldName, err)
}
fieldTypeName = *fieldTypeVal

structLine, err := c.structLineForField(fieldName, fieldTypeName, fieldDetails, data, excludeComments)
if err != nil {
return nil, err
}

output = append(output, *structLine)
}

// then add any inherited fields
ancestorTypeNames := make([]string, 0)
if c.model.ParentTypeName != nil {
ancestorTypeNames = append(ancestorTypeNames, *c.model.ParentTypeName)
if c.model.FieldNameContainingDiscriminatedValue != nil {
_, foundAncestorTypeNames, err := c.findModelAncestry(data, *c.model.ParentTypeName, *c.model.FieldNameContainingDiscriminatedValue)
if err != nil {
return nil, err
}
ancestorTypeNames = *foundAncestorTypeNames
}

// We want to include fields from all ancestors, grouped by ancestor name
ancestorFields := make(map[string]map[string]models.SDKField)
for _, ancestorTypeName := range ancestorTypeNames {
parent, ok := data.models[ancestorTypeName]
if !ok {
return nil, fmt.Errorf("couldn't find Ancestor Model %q for Model %q", ancestorTypeName, c.name)
}
ancestorFields[ancestorTypeName] = make(map[string]models.SDKField)
for fieldName, fieldDetails := range parent.Fields {
ancestorFields[ancestorTypeName][fieldName] = fieldDetails
}
}

// Get sorted slices of ancestors' field names
ancestorFieldNames := make(map[string][]string)
for ancestorName := range ancestorFields {
ancestorFieldNames[ancestorName] = make([]string, 0, len(ancestorFields[ancestorName]))
for fieldName := range ancestorFields[ancestorName] {
ancestorFieldNames[ancestorName] = append(ancestorFieldNames[ancestorName], fieldName)
}
sort.Strings(ancestorFieldNames[ancestorName])
}

// Append fields from all ancestors to struct
for _, ancestorName := range ancestorTypeNames {
if len(ancestorFieldNames[ancestorName]) > 0 {
if !excludeComments {
output = append(output, fmt.Sprintf("\n// Fields inherited from %s", ancestorName))
}
for _, fieldName := range ancestorFieldNames[ancestorName] {
fieldDetails := ancestorFields[ancestorName][fieldName]
fieldTypeName := "FIXME"
fieldTypeVal, err := helpers.GolangTypeForSDKObjectDefinition(fieldDetails.ObjectDefinition, nil, data.commonTypesPackageName)
if err != nil {
return nil, fmt.Errorf("determining type information for %q: %+v", fieldName, err)
}
fieldTypeName = *fieldTypeVal

structLine, err := c.structLineForField(fieldName, fieldTypeName, fieldDetails, data, excludeComments)
if err != nil {
return nil, err
}

output = append(output, *structLine)
}
}
}
}

return &output, nil
}

func (c modelsTemplater) structLineForField(fieldName, fieldType string, fieldDetails models.SDKField, data GeneratorData, excludeComments bool) (*string, error) {
jsonDetails := fieldDetails.JsonName

if strings.HasPrefix(fieldType, "nullable.") {
Expand All @@ -309,7 +343,7 @@ func (c modelsTemplater) structLineForField(fieldName, fieldType string, fieldDe

line := fmt.Sprintf("\t%s %s `json:\"%s\"`", fieldName, fieldType, jsonDetails)

if data.generateDescriptionsForModels && fieldDetails.Description != "" {
if data.generateDescriptionsForModels && !excludeComments && fieldDetails.Description != "" {
comment := wrapOnWordBoundary(fieldDetails.Description, 120, "//")
line = fmt.Sprintf("%s\n%s", comment, line)
}
Expand Down Expand Up @@ -663,8 +697,10 @@ func (c modelsTemplater) codeForUnmarshalParentFunction(data GeneratorData) (*st
}
}

// NOTE: unmarshaling null returns an empty map, which'll mean the `ok` fails
// the 'type' field being omitted will also mean that `ok` is false
// Fail forward when unmarshalling - if the `temp` map is nil or empty, or if the discriminated value field is
// missing, we will proceed to unmarshal into the base implementation struct, rather than returning nil. When
// assigning the discriminated value field, we'll also attempt to stringify it, in case the API returns a different
// type than we are expecting.
lines = append(lines, fmt.Sprintf(`
func Unmarshal%[1]sImplementation(input []byte) (%[1]s, error) {
if input == nil {
Expand All @@ -676,9 +712,9 @@ func Unmarshal%[1]sImplementation(input []byte) (%[1]s, error) {
return nil, fmt.Errorf("unmarshaling %[1]s into map[string]interface: %%+v", err)
}
value, ok := temp[%[2]q].(string)
if !ok {
return nil, nil
var value string
if v, ok := temp[%[2]q]; ok {
value = fmt.Sprintf("%%v", v)
}
`, c.name, discriminatedValueField.JsonName))

Expand All @@ -697,8 +733,10 @@ func Unmarshal%[1]sImplementation(input []byte) (%[1]s, error) {
`, *model.DiscriminatedValue, implementationName))
}

// if it doesn't match - we generate and deserialize into a 'Raw{Name}Impl' type - named intentionally
// so that we don't conflict with a generated 'Raw{Name}' type which exists in a handful of Swaggers
// If no child type was matched, we generate and deserialize into a `Raw{Name}Impl` type - named intentionally
// so that we don't conflict with a generated `Raw{Name}` type which exists in a handful of Swaggers. The
// `Raw{Name}Impl` type implements the parent model interface, so the parent fields can be retrieved that way, and
// arbitrary fields can be found in the `Values` map.
implementationStructName := fmt.Sprintf("Raw%sImpl", c.name)
lines = append(lines, fmt.Sprintf(`
var parent %[1]s
Expand Down Expand Up @@ -727,12 +765,21 @@ func (c modelsTemplater) codeForUnmarshalStructFunction(data GeneratorData) (*st
structName = fmt.Sprintf("Base%sImpl", c.name)
}

fieldNames := make([]string, 0)
for fieldName := range c.model.Fields {
fieldNames = append(fieldNames, fieldName)
}
sort.Strings(fieldNames)

lines := make([]string, 0)

// Determine which fields can be directly assigned and which must be explicitly unmarshalled
fieldsRequiringAssignment := make([]string, 0)
fieldsRequiringUnmarshalling := make(map[string]models.SDKField)
for fieldName, fieldDetails := range c.model.Fields {

for _, fieldName := range fieldNames {
fieldDetails := c.model.Fields[fieldName]

// Check if the model field references a model interface, which will require explicit unmarshalling
topLevelObject := helpers.InnerMostSDKObjectDefinition(fieldDetails.ObjectDefinition)
if topLevelObject.Type == models.ReferenceSDKObjectDefinitionType {
Expand Down Expand Up @@ -774,7 +821,14 @@ func (c modelsTemplater) codeForUnmarshalStructFunction(data GeneratorData) (*st
}
}

for fieldName, fieldDetails := range ancestorFields {
ancestorFieldNames := make([]string, 0)
for fieldName := range ancestorFields {
ancestorFieldNames = append(ancestorFieldNames, fieldName)
}
sort.Strings(ancestorFieldNames)

for _, fieldName := range ancestorFieldNames {
fieldDetails := ancestorFields[fieldName]
// Check if the ancestor field references a model interface, which requires explicit unmarshalling
topLevelObject := helpers.InnerMostSDKObjectDefinition(fieldDetails.ObjectDefinition)
if topLevelObject.Type == models.ReferenceSDKObjectDefinitionType {
Expand All @@ -792,24 +846,46 @@ func (c modelsTemplater) codeForUnmarshalStructFunction(data GeneratorData) (*st
}
}

// we only need a custom unmarshal function when there's something needing unmarshaling
// Determine struct fields for unmarshalling
aliasStructLines, err := c.structLinesForModel(data, fieldNames, true, true)
if err != nil {
return nil, fmt.Errorf("building struct lines for %q model: %+v", c.name, err)
}

// Format the model struct field lines
formattedStructLines := make([]string, 0)
for i, v := range *aliasStructLines {
if strings.HasPrefix(strings.TrimSpace(v), "//") {
if i > 0 && !strings.HasSuffix(formattedStructLines[i-1], "\n") {
v = "\n" + v
}
if i < len(*aliasStructLines)-1 {
v += "\n"
}
}
formattedStructLines = append(formattedStructLines, v)
}

aliasTypeDefinition := fmt.Sprintf(`struct {
%[1]s
}`, strings.Join(formattedStructLines, "\n"))

// we only need a custom unmarshal function when there's something needing unmarshalling
// else the default unmarshaler will be fine
if len(fieldsRequiringUnmarshalling) > 0 {
lines = append(lines, fmt.Sprintf(`
var _ json.Unmarshaler = &%[1]s{}
lines = append(lines, fmt.Sprintf(`var _ json.Unmarshaler = &%[1]s{}
func (s *%[1]s) UnmarshalJSON(bytes []byte) error {`, structName))

// first for each regular field, decode & assign that
if len(fieldsRequiringAssignment) > 0 {
lines = append(lines, fmt.Sprintf(`type alias %[1]s
var decoded alias
lines = append(lines, fmt.Sprintf(`
var decoded %[1]s
if err := json.Unmarshal(bytes, &decoded); err != nil {
return fmt.Errorf("unmarshaling into %[1]s: %%+v", err)
return fmt.Errorf("unmarshaling: %%+v", err)
}
`, structName))
`, aliasTypeDefinition))

sort.Strings(fieldsRequiringAssignment)
for _, fieldName := range fieldsRequiringAssignment {
lines = append(lines, fmt.Sprintf("s.%[1]s = decoded.%[1]s", fieldName))
}
Expand Down Expand Up @@ -935,7 +1011,7 @@ func (s *%[1]s) UnmarshalJSON(bytes []byte) error {`, structName))
}
}

lines = append(lines, "return nil")
lines = append(lines, "\nreturn nil")
lines = append(lines, "}")
}

Expand Down

0 comments on commit 66c8d0d

Please sign in to comment.