Skip to content

Commit

Permalink
[Fix] Use correct domain for Azure Gov and China (#4274)
Browse files Browse the repository at this point in the history
## Changes
<!-- Summary of your changes that are easy to understand -->

Resolves #4272

## Tests
<!-- 
How is this tested? Please see the checklist below and also describe any
other relevant tests
-->

- [x] `make test` run locally
- [ ] relevant change in `docs/` folder
- [ ] covered with integration tests in `internal/acceptance`
- [x] relevant acceptance tests are passing
- [ ] using Go SDK
  • Loading branch information
alexott authored Dec 3, 2024
1 parent a5eb85e commit 0a055b3
Show file tree
Hide file tree
Showing 12 changed files with 97 additions and 58 deletions.
2 changes: 1 addition & 1 deletion storage/adls_gen1_mount.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type AzureADLSGen1Mount struct {
}

// Source ...
func (m AzureADLSGen1Mount) Source() string {
func (m AzureADLSGen1Mount) Source(_ *common.DatabricksClient) string {
return fmt.Sprintf("adl://%s.azuredatalakestore.net%s", m.StorageResource, m.Directory)
}

Expand Down
28 changes: 25 additions & 3 deletions storage/adls_gen2_mount.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package storage

import (
"fmt"
"strings"

"github.com/databricks/terraform-provider-databricks/common"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
)

// AzureADLSGen2Mount describes the object for a azure datalake gen 2 storage mount
Expand All @@ -19,10 +21,23 @@ type AzureADLSGen2Mount struct {
InitializeFileSystem bool `json:"initialize_file_system"`
}

func getAzureDomain(client *common.DatabricksClient) string {
domains := map[string]string{
"PUBLIC": "core.windows.net",
"USGOVERNMENT": "core.usgovcloudapi.net",
"CHINA": "core.chinacloudapi.cn",
}
azureEnvironment := client.Config.Environment().AzureEnvironment.Name
domain, ok := domains[strings.ToUpper(azureEnvironment)]
if !ok {
panic(fmt.Sprintf("Unknown Azure environment: '%s'", azureEnvironment))
}
return domain
}

// Source returns ABFSS URI backing the mount
func (m AzureADLSGen2Mount) Source() string {
return fmt.Sprintf("abfss://%s@%s.dfs.core.windows.net%s",
m.ContainerName, m.StorageAccountName, m.Directory)
func (m AzureADLSGen2Mount) Source(client *common.DatabricksClient) string {
return fmt.Sprintf("abfss://%s@%s.dfs.%s%s", m.ContainerName, m.StorageAccountName, getAzureDomain(client), m.Directory)
}

func (m AzureADLSGen2Mount) Name() string {
Expand Down Expand Up @@ -106,5 +121,12 @@ func ResourceAzureAdlsGen2Mount() common.Resource {
Required: true,
ForceNew: true,
},
"environment": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice([]string{"PUBLIC", "USGOVERNMENT", "CHINA"}, false),
Default: "PUBLIC",
},
}))
}
12 changes: 6 additions & 6 deletions storage/adls_gen2_mount_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@ import (

"github.com/databricks/terraform-provider-databricks/qa"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestResourceAdlsGen2Mount_Create(t *testing.T) {
d, err := qa.ResourceFixture{
qa.ResourceFixture{
Fixtures: []qa.HTTPFixture{
{
Method: "GET",
Expand Down Expand Up @@ -51,8 +50,9 @@ func TestResourceAdlsGen2Mount_Create(t *testing.T) {
"initialize_file_system": true,
},
Create: true,
}.Apply(t)
require.NoError(t, err)
assert.Equal(t, "this_mount", d.Id())
assert.Equal(t, "abfss://[email protected]", d.Get("source"))
Azure: true,
}.ApplyAndExpectData(t, map[string]any{
"id": "this_mount",
"source": "abfss://[email protected]",
})
}
2 changes: 1 addition & 1 deletion storage/aws_s3_mount.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type AWSIamMount struct {
}

// Source ...
func (m AWSIamMount) Source() string {
func (m AWSIamMount) Source(_ *common.DatabricksClient) string {
return fmt.Sprintf("s3a://%s", m.S3BucketName)
}

Expand Down
6 changes: 3 additions & 3 deletions storage/azure_blob_mount.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ type AzureBlobMount struct {
}

// Source ...
func (m AzureBlobMount) Source() string {
return fmt.Sprintf("wasbs://%[1]s@%[2]s.blob.core.windows.net%[3]s",
m.ContainerName, m.StorageAccountName, m.Directory)
func (m AzureBlobMount) Source(client *common.DatabricksClient) string {
return fmt.Sprintf("wasbs://%[1]s@%[2]s.blob.%[3]s%[4]s",
m.ContainerName, m.StorageAccountName, getAzureDomain(client), m.Directory)
}

func (m AzureBlobMount) Name() string {
Expand Down
24 changes: 15 additions & 9 deletions storage/azure_blob_mount_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
)

func TestResourceAzureBlobMountCreate(t *testing.T) {
d, err := qa.ResourceFixture{
qa.ResourceFixture{
Fixtures: []qa.HTTPFixture{
{
Method: "GET",
Expand Down Expand Up @@ -50,11 +50,12 @@ func TestResourceAzureBlobMountCreate(t *testing.T) {
"token_secret_key": "g",
"token_secret_scope": "h",
},
Azure: true,
Create: true,
}.Apply(t)
require.NoError(t, err)
assert.Equal(t, "e", d.Id())
assert.Equal(t, "wasbs://[email protected]/d", d.Get("source"))
}.ApplyAndExpectData(t, map[string]any{
"id": "e",
"source": "wasbs://[email protected]/d",
})
}

func TestResourceAzureBlobMountCreate_Error(t *testing.T) {
Expand Down Expand Up @@ -86,6 +87,7 @@ func TestResourceAzureBlobMountCreate_Error(t *testing.T) {
"token_secret_scope": "h",
},
Create: true,
Azure: true,
}.Apply(t)
require.EqualError(t, err, "Some error")
assert.Equal(t, "e", d.Id())
Expand Down Expand Up @@ -124,8 +126,9 @@ func TestResourceAzureBlobMountRead(t *testing.T) {
"token_secret_key": "g",
"token_secret_scope": "h",
},
ID: "e",
Read: true,
ID: "e",
Read: true,
Azure: true,
}.Apply(t)
require.NoError(t, err)
assert.Equal(t, "e", d.Id())
Expand Down Expand Up @@ -165,6 +168,7 @@ func TestResourceAzureBlobMountRead_NotFound(t *testing.T) {
ID: "e",
Read: true,
Removed: true,
Azure: true,
}.ApplyNoError(t)
}

Expand Down Expand Up @@ -198,8 +202,9 @@ func TestResourceAzureBlobMountRead_Error(t *testing.T) {
"token_secret_key": "g",
"token_secret_scope": "h",
},
ID: "e",
Read: true,
ID: "e",
Azure: true,
Read: true,
}.Apply(t)
require.EqualError(t, err, "Some error")
assert.Equal(t, "e", d.Id())
Expand Down Expand Up @@ -239,6 +244,7 @@ func TestResourceAzureBlobMountDelete(t *testing.T) {
},
ID: "e",
Delete: true,
Azure: true,
}.Apply(t)
require.NoError(t, err)
assert.Equal(t, "e", d.Id())
Expand Down
30 changes: 14 additions & 16 deletions storage/generic_mounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ func (m GenericMount) getBlock() Mount {
}

// Source returns URI backing the mount
func (m GenericMount) Source() string {
func (m GenericMount) Source(client *common.DatabricksClient) string {
if block := m.getBlock(); block != nil {
return block.Source()
return block.Source(client)
}
return m.URI
}
Expand Down Expand Up @@ -96,7 +96,7 @@ func parseStorageContainerId(rid string) (string, string, error) {
return match[3], match[4], nil
}

func getContainerDefaults(d *schema.ResourceData, allowed_schemas []string, suffix string) (string, string, error) {
func getContainerDefaults(d *schema.ResourceData) (string, string, error) {
rid := d.Get("resource_id").(string)
if rid != "" {
acc, cont, err := parseStorageContainerId(rid)
Expand Down Expand Up @@ -134,9 +134,8 @@ type AzureADLSGen2MountGeneric struct {
}

// Source returns ABFSS URI backing the mount
func (m *AzureADLSGen2MountGeneric) Source() string {
return fmt.Sprintf("abfss://%s@%s.dfs.core.windows.net%s",
m.ContainerName, m.StorageAccountName, m.Directory)
func (m *AzureADLSGen2MountGeneric) Source(client *common.DatabricksClient) string {
return fmt.Sprintf("abfss://%s@%s.dfs.%s%s", m.ContainerName, m.StorageAccountName, getAzureDomain(client), m.Directory)
}

func (m *AzureADLSGen2MountGeneric) Name() string {
Expand All @@ -145,7 +144,7 @@ func (m *AzureADLSGen2MountGeneric) Name() string {

func (m *AzureADLSGen2MountGeneric) ValidateAndApplyDefaults(d *schema.ResourceData, client *common.DatabricksClient) error {
if m.ContainerName == "" || m.StorageAccountName == "" {
acc, cont, err := getContainerDefaults(d, []string{"abfs", "abfss"}, "dfs.core.windows.net")
acc, cont, err := getContainerDefaults(d)
if err != nil {
return err
}
Expand Down Expand Up @@ -194,7 +193,7 @@ type AzureADLSGen1MountGeneric struct {
}

// Source ...
func (m *AzureADLSGen1MountGeneric) Source() string {
func (m *AzureADLSGen1MountGeneric) Source(_ *common.DatabricksClient) string {
return fmt.Sprintf("adl://%s.azuredatalakestore.net%s", m.StorageResource, m.Directory)
}

Expand Down Expand Up @@ -237,10 +236,9 @@ func (m *AzureADLSGen1MountGeneric) Config(client *common.DatabricksClient) map[
aadEndpoint := client.Config.Environment().AzureActiveDirectoryEndpoint()
return map[string]string{
m.PrefixType + ".oauth2.access.token.provider.type": "ClientCredential",

m.PrefixType + ".oauth2.client.id": m.ClientID,
m.PrefixType + ".oauth2.credential": fmt.Sprintf("{{secrets/%s/%s}}", m.SecretScope, m.SecretKey),
m.PrefixType + ".oauth2.refresh.url": fmt.Sprintf("%s%s/oauth2/token", aadEndpoint, m.TenantID),
m.PrefixType + ".oauth2.client.id": m.ClientID,
m.PrefixType + ".oauth2.credential": fmt.Sprintf("{{secrets/%s/%s}}", m.SecretScope, m.SecretKey),
m.PrefixType + ".oauth2.refresh.url": fmt.Sprintf("%s%s/oauth2/token", aadEndpoint, m.TenantID),
}
}

Expand All @@ -257,9 +255,9 @@ type AzureBlobMountGeneric struct {
}

// Source ...
func (m *AzureBlobMountGeneric) Source() string {
return fmt.Sprintf("wasbs://%[1]s@%[2]s.blob.core.windows.net%[3]s",
m.ContainerName, m.StorageAccountName, m.Directory)
func (m *AzureBlobMountGeneric) Source(client *common.DatabricksClient) string {
return fmt.Sprintf("wasbs://%[1]s@%[2]s.blob.%[3]s%[4]s",
m.ContainerName, m.StorageAccountName, getAzureDomain(client), m.Directory)
}

func (m *AzureBlobMountGeneric) Name() string {
Expand All @@ -268,7 +266,7 @@ func (m *AzureBlobMountGeneric) Name() string {

func (m *AzureBlobMountGeneric) ValidateAndApplyDefaults(d *schema.ResourceData, client *common.DatabricksClient) error {
if m.ContainerName == "" || m.StorageAccountName == "" {
acc, cont, err := getContainerDefaults(d, []string{"wasb", "wasbs"}, "blob.core.windows.net")
acc, cont, err := getContainerDefaults(d)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion storage/gs.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type GSMount struct {
}

// Source ...
func (m GSMount) Source() string {
func (m GSMount) Source(_ *common.DatabricksClient) string {
return fmt.Sprintf("gs://%s", m.BucketName)
}

Expand Down
6 changes: 3 additions & 3 deletions storage/mounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (

// Mount exposes generic url & extra config map options
type Mount interface {
Source() string
Source(client *common.DatabricksClient) string
Config(client *common.DatabricksClient) map[string]string

Name() string
Expand Down Expand Up @@ -96,7 +96,7 @@ func (mp MountPoint) Mount(mo Mount, client *common.DatabricksClient) (source st
raise e
mount_source = safe_mount("/mnt/%s", "%v", %s, "%s")
dbutils.notebook.exit(mount_source)
`, mp.Name, mo.Source(), extraConfigs, mp.EncryptionType) // lgtm[go/unsafe-quoting]
`, mp.Name, mo.Source(client), extraConfigs, mp.EncryptionType) // lgtm[go/unsafe-quoting]
result := mp.Exec.Execute(mp.ClusterID, "python", command)
return result.Text(), result.Err()
}
Expand Down Expand Up @@ -235,7 +235,7 @@ func mountCreate(tpl any, r common.Resource) func(context.Context, *schema.Resou
if err != nil {
return err
}
log.Printf("[INFO] Mounting %s at /mnt/%s", mountConfig.Source(), d.Id())
log.Printf("[INFO] Mounting %s at /mnt/%s", mountConfig.Source(client), d.Id())
source, err := mountPoint.Mount(mountConfig, client)
if err != nil {
return err
Expand Down
20 changes: 10 additions & 10 deletions storage/mounts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ func testMountFuncHelper(t *testing.T, mountFunc func(mp MountPoint, mount Mount

type mockMount struct{}

func (t mockMount) Source() string { return "fake-mount" }
func (t mockMount) Name() string { return "fake-mount" }
func (t mockMount) Source(_ *common.DatabricksClient) string { return "fake-mount" }
func (t mockMount) Name() string { return "fake-mount" }
func (t mockMount) Config(client *common.DatabricksClient) map[string]string {
return map[string]string{"fake-key": "fake-value"}
}
Expand All @@ -84,6 +84,14 @@ func (m mockMount) ValidateAndApplyDefaults(d *schema.ResourceData, client *comm
}

func TestMountPoint_Mount(t *testing.T) {
client := common.DatabricksClient{
DatabricksClient: &client.DatabricksClient{
Config: &config.Config{
Host: ".",
Token: ".",
},
},
}
mount := mockMount{}
expectedMountSource := "fake-mount"
expectedMountConfig := `{"fake-key":"fake-value"}`
Expand All @@ -108,14 +116,6 @@ func TestMountPoint_Mount(t *testing.T) {
dbutils.notebook.exit(mount_source)
`, mountName, expectedMountSource, expectedMountConfig)
testMountFuncHelper(t, func(mp MountPoint, mount Mount) (s string, e error) {
client := common.DatabricksClient{
DatabricksClient: &client.DatabricksClient{
Config: &config.Config{
Host: ".",
Token: ".",
},
},
}
return mp.Mount(mount, &client)
}, mount, mountName, expectedCommand)
}
Expand Down
Loading

0 comments on commit 0a055b3

Please sign in to comment.