diff --git a/pkg/konnect/login_service.go b/pkg/konnect/login_service.go index ce40fb5..d211599 100644 --- a/pkg/konnect/login_service.go +++ b/pkg/konnect/login_service.go @@ -45,11 +45,16 @@ func (s *AuthService) Login(ctx context.Context, email, return authResponse, nil } +// getGlobalEndpoint returns the global endpoint for a given base Konnect URL. +func getGlobalEndpoint(baseURL string) string { + parts := strings.Split(baseURL, "api.konghq") + return baseEndpointUS + parts[len(parts)-1] +} + // getGlobalAuthEndpoint returns the global auth endpoint // given a base Konnect URL. func getGlobalAuthEndpoint(baseURL string) string { - parts := strings.Split(baseURL, "api.konghq") - return baseEndpointUS + parts[len(parts)-1] + authEndpointV2 + return getGlobalEndpoint(baseURL) + authEndpointV2 } func createAuthRequest(baseURL, email, password string) (*http.Request, error) { @@ -129,8 +134,7 @@ func (s *AuthService) LoginV2(ctx context.Context, email, func (s *AuthService) OrgUserInfo(ctx context.Context) (*OrgUserInfo, error) { // replace geo-specific endpoint with global one for retrieving org info client := *s.client - client.baseURL = strings.Replace(s.client.baseURL, "eu.", "global.", 1) - client.baseURL = strings.Replace(client.baseURL, "au.", "global.", 1) + client.baseURL = getGlobalEndpoint(client.baseURL) req, err := client.NewRequest(http.MethodGet, "/v2/organizations/me", nil, nil) if err != nil { diff --git a/pkg/konnect/login_service_test.go b/pkg/konnect/login_service_test.go index 89f384f..99765fc 100644 --- a/pkg/konnect/login_service_test.go +++ b/pkg/konnect/login_service_test.go @@ -1,11 +1,26 @@ package konnect import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +type mockRoundTripper struct{ mockHost string } + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req.Host = req.URL.Host + req.URL.Scheme = "http" + req.URL.Host = m.mockHost + + return (&http.Client{}).Do(req) +} + func TestGetGlobalAuthEndpoint(t *testing.T) { tests := []struct { baseURL string @@ -40,3 +55,43 @@ func TestGetGlobalAuthEndpoint(t *testing.T) { assert.Equal(t, tt.expected, getGlobalAuthEndpoint(tt.baseURL)) } } + +func TestAuthService_OrgUserInfo(t *testing.T) { + expectedResp := OrgUserInfo{Name: "test-org", OrgID: "1234"} + + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.Host, "global.api.konghq.com") + + if r.URL.Path == "/v2/organizations/me" && r.Method == http.MethodGet { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + resp, err := json.Marshal(expectedResp) + require.NoError(t, err) + + _, err = w.Write(resp) + require.NoError(t, err) + + return + } + + http.NotFound(w, r) + })) + defer mockServer.Close() + + authService := &AuthService{ + client: &Client{ + baseURL: "https://some-geo.api.konghq.com", + client: &http.Client{ + Transport: &mockRoundTripper{ + mockHost: mockServer.Listener.Addr().String(), + }, + }, + }, + } + + info, err := authService.OrgUserInfo(context.Background()) + require.NoError(t, err) + assert.Equal(t, expectedResp.Name, info.Name) + assert.Equal(t, expectedResp.OrgID, info.OrgID) +}