From 3655f75f58d4c6f42a591a395e7a030baea49b72 Mon Sep 17 00:00:00 2001 From: Tom Bamford Date: Wed, 28 Feb 2024 17:27:44 +0000 Subject: [PATCH] auth: omit tenant argument when acquiring tokens when authenticated using a managed identity --- sdk/auth/azure_cli_authorizer.go | 24 ++++++++- sdk/internal/azurecli/azcli.go | 86 +++++++++++++++++++++++++------- 2 files changed, 89 insertions(+), 21 deletions(-) diff --git a/sdk/auth/azure_cli_authorizer.go b/sdk/auth/azure_cli_authorizer.go index ac697ea301d..ba1853eca60 100644 --- a/sdk/auth/azure_cli_authorizer.go +++ b/sdk/auth/azure_cli_authorizer.go @@ -65,9 +65,29 @@ func (a *AzureCliAuthorizer) Token(_ context.Context, _ *http.Request) (*oauth2. } azArgs = append(azArgs, "--scope", *scope) + accountType, err := azurecli.GetAccountType() + if err != nil { + return nil, fmt.Errorf("determining account type: %+v", err) + } + + accountName, err := azurecli.GetAccountName() + if err != nil { + return nil, fmt.Errorf("determining account name: %+v", err) + } + + tenantIdRequired := true + // Try to detect if we're running in Cloud Shell - if cloudShell := os.Getenv("AZUREPS_HOST_ENVIRONMENT"); !strings.HasPrefix(cloudShell, "cloud-shell/") { - // Seemingly not, so we'll append the tenant ID to the az args + if cloudShell := os.Getenv("AZUREPS_HOST_ENVIRONMENT"); strings.HasPrefix(cloudShell, "cloud-shell/") { + tenantIdRequired = false + } + + // Try to detect whether authenticated principal is a managed identity + if accountType != nil && accountName != nil && *accountType == "servicePrincipal" && (*accountName == "systemAssignedIdentity" || *accountName == "userAssignedIdentity") { + tenantIdRequired = false + } + + if tenantIdRequired { azArgs = append(azArgs, "--tenant", a.conf.TenantID) } diff --git a/sdk/internal/azurecli/azcli.go b/sdk/internal/azurecli/azcli.go index caaffd8f48c..bb225d18187 100644 --- a/sdk/internal/azurecli/azcli.go +++ b/sdk/internal/azurecli/azcli.go @@ -15,6 +15,33 @@ import ( "github.com/hashicorp/go-version" ) +type azAccount struct { + EnvironmentName *string `json:"environmentName"` + HomeTenantId *string `json:"homeTenantId"` + Id *string `json:"id"` + Default *bool `json:"isDefault"` + Name *string `json:"name"` + State *string `json:"state"` + TenantId *string `json:"tenantId"` + + ManagedByTenants *[]struct { + TenantId *string `json:"tenantId"` + } `json:"managedByTenants"` + + User *struct { + AssignedIdentityInfo *string `json:"assignedIdentityInfo"` + Name *string `json:"name"` + Type *string `json:"type"` + } +} + +type azVersion struct { + AzureCli *string `json:"azure-cli,omitempty"` + AzureCliCore *string `json:"azure-cli-core,omitempty"` + AzureCliTelemetry *string `json:"azure-cli-telemetry,omitempty"` + Extensions *interface{} `json:"extensions,omitempty"` +} + // CheckAzVersion tries to determine the version of Azure CLI in the path and checks for a compatible version func CheckAzVersion() error { currentVersion, err := getAzVersion() @@ -60,39 +87,60 @@ func ValidateTenantID(tenantId string) (bool, error) { // GetDefaultTenantID tries to determine the default tenant func GetDefaultTenantID() (*string, error) { - var account struct { - TenantID string `json:"tenantId"` - } - if err := JSONUnmarshalAzCmd(true, &account, "account", "show"); err != nil { + account, err := getAzAccount() + if err != nil { return nil, fmt.Errorf("obtaining tenant ID: %s", err) } - - return &account.TenantID, nil + return account.TenantId, nil } // GetDefaultSubscriptionID tries to determine the default subscription func GetDefaultSubscriptionID() (*string, error) { - var account struct { - SubscriptionID string `json:"id"` - } - err := JSONUnmarshalAzCmd(true, &account, "account", "show") + account, err := getAzAccount() if err != nil { return nil, fmt.Errorf("obtaining subscription ID: %s", err) } + return account.Id, nil +} + +// GetAccountName returns the name of the authenticated principal +func GetAccountName() (*string, error) { + account, err := getAzAccount() + if err != nil { + return nil, fmt.Errorf("obtaining account name: %s", err) + } + if account.User == nil { + return nil, fmt.Errorf("account details were nil: %s", err) + } + return account.User.Name, nil +} - return &account.SubscriptionID, nil +// GetAccountType returns the account type of the authenticated principal +func GetAccountType() (*string, error) { + account, err := getAzAccount() + if err != nil { + return nil, fmt.Errorf("obtaining account type: %s", err) + } + if account.User == nil { + return nil, fmt.Errorf("account details were nil: %s", err) + } + return account.User.Type, nil +} + +// getAzAccount returns the output of `az account show` +func getAzAccount() (*azAccount, error) { + var account azAccount + if err := JSONUnmarshalAzCmd(true, &account, "account", "show"); err != nil { + return nil, fmt.Errorf("obtaining account details: %s", err) + } + return &account, nil } // getAzVersion tries to determine the version of Azure CLI in the path. func getAzVersion() (*string, error) { - var cliVersion *struct { - AzureCli *string `json:"azure-cli,omitempty"` - AzureCliCore *string `json:"azure-cli-core,omitempty"` - AzureCliTelemetry *string `json:"azure-cli-telemetry,omitempty"` - Extensions *interface{} `json:"extensions,omitempty"` - } - err := JSONUnmarshalAzCmd(true, &cliVersion, "version") - if err != nil { + var cliVersion azVersion + + if err := JSONUnmarshalAzCmd(true, &cliVersion, "version"); err != nil { return nil, fmt.Errorf("could not parse Azure CLI version: %v", err) }