Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generator-go-sdk: fix unmarshalling fields that reference a model interface #4427

Merged
merged 3 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 170 additions & 85 deletions tools/generator-go-sdk/internal/generator/templater_models.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,13 @@ func (c modelsTemplater) structCode(data GeneratorData) (*string, error) {
}
sort.Strings(fields)

structLines := make([]string, 0)
for _, fieldName := range fields {
fieldDetails := c.model.Fields[fieldName]
fieldTypeName := "FIXME"
fieldTypeVal, err := helpers.GolangTypeForSDKObjectDefinition(fieldDetails.ObjectDefinition, nil, data.commonTypesPackageName)
if err != nil {
return nil, fmt.Errorf("determining type information for %q: %+v", fieldName, err)
}
fieldTypeName = *fieldTypeVal

structLine, err := c.structLineForField(fieldName, fieldTypeName, fieldDetails, data)
if err != nil {
return nil, err
}

structLines = append(structLines, *structLine)
// Build struct field lines
structLines, err := c.structLinesForModel(data, fields, false, false)
if err != nil {
return nil, fmt.Errorf("building struct lines for %q model: %+v", c.name, err)
}

// then add any inherited fields
// If this is a child model, determine its ancestry
parentAssignmentInfo := ""
ancestorTypeNames := make([]string, 0)
if c.model.ParentTypeName != nil {
Expand All @@ -111,52 +99,6 @@ func (c modelsTemplater) structCode(data GeneratorData) (*string, error) {

// Since ancestor interfaces embed each other, we only need to satisfy the immediate parent interface
parentAssignmentInfo = fmt.Sprintf("var _ %[1]s = %[2]s{}", ancestorTypeNames[0], structName)

// We want to include fields from all ancestors, grouped by ancestor name
ancestorFields := make(map[string]map[string]models.SDKField)
for _, ancestorTypeName := range ancestorTypeNames {
parent, ok := data.models[ancestorTypeName]
if !ok {
return nil, fmt.Errorf("couldn't find Ancestor Model %q for Model %q", ancestorTypeName, c.name)
}
ancestorFields[ancestorTypeName] = make(map[string]models.SDKField)
for fieldName, fieldDetails := range parent.Fields {
ancestorFields[ancestorTypeName][fieldName] = fieldDetails
}
}

// Get sorted slices of ancestors' field names
ancestorFieldNames := make(map[string][]string)
for ancestorName := range ancestorFields {
ancestorFieldNames[ancestorName] = make([]string, 0, len(ancestorFields[ancestorName]))
for fieldName := range ancestorFields[ancestorName] {
ancestorFieldNames[ancestorName] = append(ancestorFieldNames[ancestorName], fieldName)
}
sort.Strings(ancestorFieldNames[ancestorName])
}

// Append fields from all ancestors to struct
for _, ancestorName := range ancestorTypeNames {
if len(ancestorFieldNames[ancestorName]) > 0 {
structLines = append(structLines, fmt.Sprintf("\n// Fields inherited from %s", ancestorName))
for _, fieldName := range ancestorFieldNames[ancestorName] {
fieldDetails := ancestorFields[ancestorName][fieldName]
fieldTypeName := "FIXME"
fieldTypeVal, err := helpers.GolangTypeForSDKObjectDefinition(fieldDetails.ObjectDefinition, nil, data.commonTypesPackageName)
if err != nil {
return nil, fmt.Errorf("determining type information for %q: %+v", fieldName, err)
}
fieldTypeName = *fieldTypeVal

structLine, err := c.structLineForField(fieldName, fieldTypeName, fieldDetails, data)
if err != nil {
return nil, err
}

structLines = append(structLines, *structLine)
}
}
}
}

// If this is a parent model, we output an Interface with a manual unmarshal func that gets called wherever it's used
Expand Down Expand Up @@ -184,12 +126,12 @@ type %[1]s interface {

// Format the model struct field lines
formattedStructLines := make([]string, 0)
for i, v := range structLines {
for i, v := range *structLines {
if strings.HasPrefix(strings.TrimSpace(v), "//") {
if i > 0 && !strings.HasSuffix(formattedStructLines[i-1], "\n") {
v = "\n" + v
}
if i < len(structLines)-1 {
if i < len(*structLines)-1 {
v += "\n"
}
}
Expand Down Expand Up @@ -292,7 +234,109 @@ func (c modelsTemplater) methods(data GeneratorData) (*string, error) {
return &output, nil
}

func (c modelsTemplater) structLineForField(fieldName, fieldType string, fieldDetails models.SDKField, data GeneratorData) (*string, error) {
func (c modelsTemplater) structLinesForModel(data GeneratorData, fieldNames []string, excludeComments, excludeDiscriminatedParentMembers bool) (*[]string, error) {
output := make([]string, 0)

for _, fieldName := range fieldNames {
fieldDetails := c.model.Fields[fieldName]

topLevelObject := helpers.InnerMostSDKObjectDefinition(fieldDetails.ObjectDefinition)
if excludeDiscriminatedParentMembers && topLevelObject.Type == models.ReferenceSDKObjectDefinitionType {
if referencedModel, ok := data.models[*topLevelObject.ReferenceName]; ok && referencedModel.IsDiscriminatedParentType() {
// Skip fields that reference a parent model (i.e. an interface rather than a struct)
continue
}
}

fieldTypeName := "FIXME"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any context on this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a holdover, we should probably be erroring here rather than outputting FIXME strings, I'll look to fix that up 👍

fieldTypeVal, err := helpers.GolangTypeForSDKObjectDefinition(fieldDetails.ObjectDefinition, nil, data.commonTypesPackageName)
if err != nil {
return nil, fmt.Errorf("determining type information for %q: %+v", fieldName, err)
}
fieldTypeName = *fieldTypeVal

structLine, err := c.structLineForField(fieldName, fieldTypeName, fieldDetails, data, excludeComments)
if err != nil {
return nil, err
}

output = append(output, *structLine)
}

// then add any inherited fields
ancestorTypeNames := make([]string, 0)
if c.model.ParentTypeName != nil {
ancestorTypeNames = append(ancestorTypeNames, *c.model.ParentTypeName)
if c.model.FieldNameContainingDiscriminatedValue != nil {
_, foundAncestorTypeNames, err := c.findModelAncestry(data, *c.model.ParentTypeName, *c.model.FieldNameContainingDiscriminatedValue)
if err != nil {
return nil, err
}
ancestorTypeNames = *foundAncestorTypeNames
}

// We want to include fields from all ancestors, grouped by ancestor name
ancestorFields := make(map[string]map[string]models.SDKField)
for _, ancestorTypeName := range ancestorTypeNames {
parent, ok := data.models[ancestorTypeName]
if !ok {
return nil, fmt.Errorf("couldn't find Ancestor Model %q for Model %q", ancestorTypeName, c.name)
}
ancestorFields[ancestorTypeName] = make(map[string]models.SDKField)
for fieldName, fieldDetails := range parent.Fields {
ancestorFields[ancestorTypeName][fieldName] = fieldDetails
}
}

// Get sorted slices of ancestors' field names
ancestorFieldNames := make(map[string][]string)
for ancestorName := range ancestorFields {
ancestorFieldNames[ancestorName] = make([]string, 0, len(ancestorFields[ancestorName]))
for fieldName := range ancestorFields[ancestorName] {
ancestorFieldNames[ancestorName] = append(ancestorFieldNames[ancestorName], fieldName)
}
sort.Strings(ancestorFieldNames[ancestorName])
}

// Append fields from all ancestors to struct
for _, ancestorName := range ancestorTypeNames {
if len(ancestorFieldNames[ancestorName]) > 0 {
if !excludeComments {
output = append(output, fmt.Sprintf("\n// Fields inherited from %s", ancestorName))
}
for _, fieldName := range ancestorFieldNames[ancestorName] {
fieldDetails := ancestorFields[ancestorName][fieldName]

topLevelObject := helpers.InnerMostSDKObjectDefinition(fieldDetails.ObjectDefinition)
if excludeDiscriminatedParentMembers && topLevelObject.Type == models.ReferenceSDKObjectDefinitionType {
if referencedModel, ok := data.models[*topLevelObject.ReferenceName]; ok && referencedModel.IsDiscriminatedParentType() {
// Skip fields that reference a parent model (i.e. an interface rather than a struct)
continue
}
}

fieldTypeName := "FIXME"
fieldTypeVal, err := helpers.GolangTypeForSDKObjectDefinition(fieldDetails.ObjectDefinition, nil, data.commonTypesPackageName)
if err != nil {
return nil, fmt.Errorf("determining type information for %q: %+v", fieldName, err)
}
fieldTypeName = *fieldTypeVal

structLine, err := c.structLineForField(fieldName, fieldTypeName, fieldDetails, data, excludeComments)
if err != nil {
return nil, err
}

output = append(output, *structLine)
}
}
}
}

return &output, nil
}

func (c modelsTemplater) structLineForField(fieldName, fieldType string, fieldDetails models.SDKField, data GeneratorData, excludeComments bool) (*string, error) {
jsonDetails := fieldDetails.JsonName

if strings.HasPrefix(fieldType, "nullable.") {
Expand All @@ -309,7 +353,7 @@ func (c modelsTemplater) structLineForField(fieldName, fieldType string, fieldDe

line := fmt.Sprintf("\t%s %s `json:\"%s\"`", fieldName, fieldType, jsonDetails)

if data.generateDescriptionsForModels && fieldDetails.Description != "" {
if data.generateDescriptionsForModels && !excludeComments && fieldDetails.Description != "" {
comment := wrapOnWordBoundary(fieldDetails.Description, 120, "//")
line = fmt.Sprintf("%s\n%s", comment, line)
}
Expand Down Expand Up @@ -663,8 +707,10 @@ func (c modelsTemplater) codeForUnmarshalParentFunction(data GeneratorData) (*st
}
}

// NOTE: unmarshaling null returns an empty map, which'll mean the `ok` fails
// the 'type' field being omitted will also mean that `ok` is false
// Fail forward when unmarshalling - if the `temp` map is nil or empty, or if the discriminated value field is
// missing, we will proceed to unmarshal into the base implementation struct, rather than returning nil. When
// assigning the discriminated value field, we'll also attempt to stringify it, in case the API returns a different
// type than we are expecting.
lines = append(lines, fmt.Sprintf(`
func Unmarshal%[1]sImplementation(input []byte) (%[1]s, error) {
if input == nil {
Expand All @@ -676,9 +722,9 @@ func Unmarshal%[1]sImplementation(input []byte) (%[1]s, error) {
return nil, fmt.Errorf("unmarshaling %[1]s into map[string]interface: %%+v", err)
}

value, ok := temp[%[2]q].(string)
if !ok {
return nil, nil
var value string
if v, ok := temp[%[2]q]; ok {
value = fmt.Sprintf("%%v", v)
}
`, c.name, discriminatedValueField.JsonName))

Expand All @@ -697,8 +743,10 @@ func Unmarshal%[1]sImplementation(input []byte) (%[1]s, error) {
`, *model.DiscriminatedValue, implementationName))
}

// if it doesn't match - we generate and deserialize into a 'Raw{Name}Impl' type - named intentionally
// so that we don't conflict with a generated 'Raw{Name}' type which exists in a handful of Swaggers
// If no child type was matched, we generate and deserialize into a `Raw{Name}Impl` type - named intentionally
// so that we don't conflict with a generated `Raw{Name}` type which exists in a handful of Swaggers. The
// `Raw{Name}Impl` type implements the parent model interface, so the parent fields can be retrieved that way, and
// arbitrary fields can be found in the `Values` map.
implementationStructName := fmt.Sprintf("Raw%sImpl", c.name)
lines = append(lines, fmt.Sprintf(`
var parent %[1]s
Expand Down Expand Up @@ -727,12 +775,21 @@ func (c modelsTemplater) codeForUnmarshalStructFunction(data GeneratorData) (*st
structName = fmt.Sprintf("Base%sImpl", c.name)
}

fieldNames := make([]string, 0)
for fieldName := range c.model.Fields {
fieldNames = append(fieldNames, fieldName)
}
sort.Strings(fieldNames)

lines := make([]string, 0)

// Determine which fields can be directly assigned and which must be explicitly unmarshalled
fieldsRequiringAssignment := make([]string, 0)
fieldsRequiringUnmarshalling := make(map[string]models.SDKField)
for fieldName, fieldDetails := range c.model.Fields {

for _, fieldName := range fieldNames {
fieldDetails := c.model.Fields[fieldName]

// Check if the model field references a model interface, which will require explicit unmarshalling
topLevelObject := helpers.InnerMostSDKObjectDefinition(fieldDetails.ObjectDefinition)
if topLevelObject.Type == models.ReferenceSDKObjectDefinitionType {
Expand Down Expand Up @@ -774,7 +831,14 @@ func (c modelsTemplater) codeForUnmarshalStructFunction(data GeneratorData) (*st
}
}

for fieldName, fieldDetails := range ancestorFields {
ancestorFieldNames := make([]string, 0)
for fieldName := range ancestorFields {
ancestorFieldNames = append(ancestorFieldNames, fieldName)
}
sort.Strings(ancestorFieldNames)

for _, fieldName := range ancestorFieldNames {
fieldDetails := ancestorFields[fieldName]
// Check if the ancestor field references a model interface, which requires explicit unmarshalling
topLevelObject := helpers.InnerMostSDKObjectDefinition(fieldDetails.ObjectDefinition)
if topLevelObject.Type == models.ReferenceSDKObjectDefinitionType {
Expand All @@ -792,24 +856,45 @@ func (c modelsTemplater) codeForUnmarshalStructFunction(data GeneratorData) (*st
}
}

// we only need a custom unmarshal function when there's something needing unmarshaling
// Determine struct fields for unmarshalling
aliasStructLines, err := c.structLinesForModel(data, fieldNames, true, true)
if err != nil {
return nil, fmt.Errorf("building struct lines for %q model: %+v", c.name, err)
}

// Format the model struct field lines
formattedStructLines := make([]string, 0)
for i, v := range *aliasStructLines {
if strings.HasPrefix(strings.TrimSpace(v), "//") {
if i > 0 && !strings.HasSuffix(formattedStructLines[i-1], "\n") {
v = "\n" + v
}
if i < len(*aliasStructLines)-1 {
v += "\n"
}
}
formattedStructLines = append(formattedStructLines, v)
}

aliasTypeDefinition := fmt.Sprintf(`struct {
%[1]s
}`, strings.Join(formattedStructLines, "\n"))

// we only need a custom unmarshal function when there's something needing unmarshalling
// else the default unmarshaler will be fine
if len(fieldsRequiringUnmarshalling) > 0 {
lines = append(lines, fmt.Sprintf(`
var _ json.Unmarshaler = &%[1]s{}
lines = append(lines, fmt.Sprintf(`var _ json.Unmarshaler = &%[1]s{}

func (s *%[1]s) UnmarshalJSON(bytes []byte) error {`, structName))

// first for each regular field, decode & assign that
if len(fieldsRequiringAssignment) > 0 {
lines = append(lines, fmt.Sprintf(`type alias %[1]s
var decoded alias
lines = append(lines, fmt.Sprintf(` var decoded %[1]s
if err := json.Unmarshal(bytes, &decoded); err != nil {
return fmt.Errorf("unmarshaling into %[1]s: %%+v", err)
return fmt.Errorf("unmarshaling: %%+v", err)
}
`, structName))
`, aliasTypeDefinition))

sort.Strings(fieldsRequiringAssignment)
for _, fieldName := range fieldsRequiringAssignment {
lines = append(lines, fmt.Sprintf("s.%[1]s = decoded.%[1]s", fieldName))
}
Expand Down Expand Up @@ -935,7 +1020,7 @@ func (s *%[1]s) UnmarshalJSON(bytes []byte) error {`, structName))
}
}

lines = append(lines, "return nil")
lines = append(lines, "\nreturn nil")
lines = append(lines, "}")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ func UnmarshalModeOfTransitImplementation(input []byte) (ModeOfTransit, error) {
return nil, fmt.Errorf("unmarshaling ModeOfTransit into map[string]interface: %+v", err)
}

value, ok := temp["type"].(string)
if !ok {
return nil, nil
var value string
if v, ok := temp["type"]; ok {
value = fmt.Sprintf("%v", v)
}

if strings.EqualFold(value, "car") {
Expand Down Expand Up @@ -524,10 +524,11 @@ func (s FirstImplementation) MarshalJSON() ([]byte, error) {
var _ json.Unmarshaler = &FirstImplementation{}

func (s *FirstImplementation) UnmarshalJSON(bytes []byte) error {
type alias FirstImplementation
var decoded alias
var decoded struct {
Type string ''json:"type"''
}
if err := json.Unmarshal(bytes, &decoded); err != nil {
return fmt.Errorf("unmarshaling into FirstImplementation: %+v", err)
return fmt.Errorf("unmarshaling: %+v", err)
}

s.Type = decoded.Type
Expand Down
Loading