diff --git a/openapi.go b/openapi.go index a2815223..53ff2b6e 100644 --- a/openapi.go +++ b/openapi.go @@ -161,7 +161,7 @@ func RegisterOpenAPIOperation[T, B any](s *Server, method, path string) (*openap } // Request body - bodyTag := schemaTagFromType[B](s, *new(B)) + bodyTag := schemaTagFromType(s, *new(B)) if (method == http.MethodPost || method == http.MethodPut || method == http.MethodPatch) && bodyTag.name != "unknown-interface" && bodyTag.name != "string" { content := openapi3.NewContentWithSchemaRef(&bodyTag.SchemaRef, []string{"application/json"}) requestBody := openapi3.NewRequestBody(). @@ -180,7 +180,7 @@ func RegisterOpenAPIOperation[T, B any](s *Server, method, path string) (*openap } } - responseSchema := schemaTagFromType[T](s, *new(T)) + responseSchema := schemaTagFromType(s, *new(T)) content := openapi3.NewContentWithSchemaRef(&responseSchema.SchemaRef, []string{"application/json"}) response := openapi3.NewResponse(). WithDescription("OK"). @@ -205,7 +205,7 @@ type schemaTag struct { name string } -func schemaTagFromType[V any](s *Server, v any) schemaTag { +func schemaTagFromType(s *Server, v any) schemaTag { if v == nil { // ensure we add unknown-interface to our schemas s.getOrCreateSchema("unknown-interface", struct{}{}) @@ -217,7 +217,7 @@ func schemaTagFromType[V any](s *Server, v any) schemaTag { } } - return dive[V](s, reflect.TypeOf(v), schemaTag{}, 5) + return dive(s, reflect.TypeOf(v), schemaTag{}, 5) } // dive returns a schemaTag which includes the generated openapi3.SchemaRef and @@ -227,7 +227,7 @@ func schemaTagFromType[V any](s *Server, v any) schemaTag { // If the type is a slice or array type it will dive into the type as well as // build and openapi3.Schema where Type is array and Ref is set to the proper // components Schema -func dive[V any](s *Server, t reflect.Type, tag schemaTag, maxDepth int) schemaTag { +func dive(s *Server, t reflect.Type, tag schemaTag, maxDepth int) schemaTag { if maxDepth == 0 { return schemaTag{ name: "default", @@ -239,10 +239,10 @@ func dive[V any](s *Server, t reflect.Type, tag schemaTag, maxDepth int) schemaT switch t.Kind() { case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Func, reflect.UnsafePointer: - return dive[V](s, t.Elem(), tag, maxDepth-1) + return dive(s, t.Elem(), tag, maxDepth-1) case reflect.Slice, reflect.Array: - item := dive[V](s, t.Elem(), tag, maxDepth-1) + item := dive(s, t.Elem(), tag, maxDepth-1) tag.name = item.name tag.Value = &openapi3.Schema{ Type: "array", @@ -253,7 +253,7 @@ func dive[V any](s *Server, t reflect.Type, tag schemaTag, maxDepth int) schemaT default: tag.name = t.Name() tag.Ref = "#/components/schemas/" + tag.name - tag.Value = s.getOrCreateSchema(tag.name, new(V)) + tag.Value = s.getOrCreateSchema(tag.name, reflect.New(t).Interface()) return tag } } diff --git a/openapi_test.go b/openapi_test.go index bf4832c4..442dc6e5 100644 --- a/openapi_test.go +++ b/openapi_test.go @@ -32,13 +32,6 @@ type testCaseForTagType[V any] struct { expectedTagValue string } -func runTestCase[V any](tc testCaseForTagType[V]) func(t *testing.T) { - return func(t *testing.T) { - tag := schemaTagFromType[V](tc.s, tc.inputType) - assert.Equal(t, tc.expectedTagValue, tag.name, tc.description) - } -} - func Test_tagFromType(t *testing.T) { s := NewServer() type DeeplyNested *[]MyStruct @@ -139,7 +132,10 @@ func Test_tagFromType(t *testing.T) { } for _, tc := range tcs { - t.Run(tc.name, runTestCase(tc)) + t.Run(tc.name, func(t *testing.T) { + tag := schemaTagFromType(tc.s, tc.inputType) + assert.Equal(t, tc.expectedTagValue, tag.name, tc.description) + }) } }