Skip to content

Commit

Permalink
auth: omit tenant argument when acquiring tokens when authenticated u…
Browse files Browse the repository at this point in the history
…sing a managed identity
  • Loading branch information
manicminer committed Feb 28, 2024
1 parent b41754d commit 3655f75
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 21 deletions.
24 changes: 22 additions & 2 deletions sdk/auth/azure_cli_authorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,29 @@ func (a *AzureCliAuthorizer) Token(_ context.Context, _ *http.Request) (*oauth2.
}
azArgs = append(azArgs, "--scope", *scope)

accountType, err := azurecli.GetAccountType()
if err != nil {
return nil, fmt.Errorf("determining account type: %+v", err)
}

accountName, err := azurecli.GetAccountName()
if err != nil {
return nil, fmt.Errorf("determining account name: %+v", err)
}

tenantIdRequired := true

// Try to detect if we're running in Cloud Shell
if cloudShell := os.Getenv("AZUREPS_HOST_ENVIRONMENT"); !strings.HasPrefix(cloudShell, "cloud-shell/") {
// Seemingly not, so we'll append the tenant ID to the az args
if cloudShell := os.Getenv("AZUREPS_HOST_ENVIRONMENT"); strings.HasPrefix(cloudShell, "cloud-shell/") {
tenantIdRequired = false
}

// Try to detect whether authenticated principal is a managed identity
if accountType != nil && accountName != nil && *accountType == "servicePrincipal" && (*accountName == "systemAssignedIdentity" || *accountName == "userAssignedIdentity") {
tenantIdRequired = false
}

if tenantIdRequired {
azArgs = append(azArgs, "--tenant", a.conf.TenantID)
}

Expand Down
86 changes: 67 additions & 19 deletions sdk/internal/azurecli/azcli.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,33 @@ import (
"github.com/hashicorp/go-version"
)

type azAccount struct {
EnvironmentName *string `json:"environmentName"`
HomeTenantId *string `json:"homeTenantId"`
Id *string `json:"id"`
Default *bool `json:"isDefault"`
Name *string `json:"name"`
State *string `json:"state"`
TenantId *string `json:"tenantId"`

ManagedByTenants *[]struct {
TenantId *string `json:"tenantId"`
} `json:"managedByTenants"`

User *struct {
AssignedIdentityInfo *string `json:"assignedIdentityInfo"`
Name *string `json:"name"`
Type *string `json:"type"`
}
}

type azVersion struct {
AzureCli *string `json:"azure-cli,omitempty"`
AzureCliCore *string `json:"azure-cli-core,omitempty"`
AzureCliTelemetry *string `json:"azure-cli-telemetry,omitempty"`
Extensions *interface{} `json:"extensions,omitempty"`
}

// CheckAzVersion tries to determine the version of Azure CLI in the path and checks for a compatible version
func CheckAzVersion() error {
currentVersion, err := getAzVersion()
Expand Down Expand Up @@ -60,39 +87,60 @@ func ValidateTenantID(tenantId string) (bool, error) {

// GetDefaultTenantID tries to determine the default tenant
func GetDefaultTenantID() (*string, error) {
var account struct {
TenantID string `json:"tenantId"`
}
if err := JSONUnmarshalAzCmd(true, &account, "account", "show"); err != nil {
account, err := getAzAccount()
if err != nil {
return nil, fmt.Errorf("obtaining tenant ID: %s", err)
}

return &account.TenantID, nil
return account.TenantId, nil
}

// GetDefaultSubscriptionID tries to determine the default subscription
func GetDefaultSubscriptionID() (*string, error) {
var account struct {
SubscriptionID string `json:"id"`
}
err := JSONUnmarshalAzCmd(true, &account, "account", "show")
account, err := getAzAccount()
if err != nil {
return nil, fmt.Errorf("obtaining subscription ID: %s", err)
}
return account.Id, nil
}

// GetAccountName returns the name of the authenticated principal
func GetAccountName() (*string, error) {
account, err := getAzAccount()
if err != nil {
return nil, fmt.Errorf("obtaining account name: %s", err)
}
if account.User == nil {
return nil, fmt.Errorf("account details were nil: %s", err)
}
return account.User.Name, nil
}

return &account.SubscriptionID, nil
// GetAccountType returns the account type of the authenticated principal
func GetAccountType() (*string, error) {
account, err := getAzAccount()
if err != nil {
return nil, fmt.Errorf("obtaining account type: %s", err)
}
if account.User == nil {
return nil, fmt.Errorf("account details were nil: %s", err)
}
return account.User.Type, nil
}

// getAzAccount returns the output of `az account show`
func getAzAccount() (*azAccount, error) {
var account azAccount
if err := JSONUnmarshalAzCmd(true, &account, "account", "show"); err != nil {
return nil, fmt.Errorf("obtaining account details: %s", err)
}
return &account, nil
}

// getAzVersion tries to determine the version of Azure CLI in the path.
func getAzVersion() (*string, error) {
var cliVersion *struct {
AzureCli *string `json:"azure-cli,omitempty"`
AzureCliCore *string `json:"azure-cli-core,omitempty"`
AzureCliTelemetry *string `json:"azure-cli-telemetry,omitempty"`
Extensions *interface{} `json:"extensions,omitempty"`
}
err := JSONUnmarshalAzCmd(true, &cliVersion, "version")
if err != nil {
var cliVersion azVersion

if err := JSONUnmarshalAzCmd(true, &cliVersion, "version"); err != nil {
return nil, fmt.Errorf("could not parse Azure CLI version: %v", err)
}

Expand Down

0 comments on commit 3655f75

Please sign in to comment.