Skip to content

Commit

Permalink
Extract functionality to detect if the CLI is running on DBR
Browse files Browse the repository at this point in the history
  • Loading branch information
pietern committed Nov 7, 2024
1 parent 26afab2 commit c8382b6
Show file tree
Hide file tree
Showing 8 changed files with 339 additions and 4 deletions.
6 changes: 2 additions & 4 deletions bundle/config/mutator/configure_wsfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@ import (
"strings"

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/libs/dbr"
"github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/filer"
"github.com/databricks/cli/libs/vfs"
)

const envDatabricksRuntimeVersion = "DATABRICKS_RUNTIME_VERSION"

type configureWSFS struct{}

func ConfigureWSFS() bundle.Mutator {
Expand All @@ -32,7 +30,7 @@ func (m *configureWSFS) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagno
}

// The executable must be running on DBR.
if _, ok := env.Lookup(ctx, envDatabricksRuntimeVersion); !ok {
if !dbr.RunsOnRuntime(ctx) {
return nil
}

Expand Down
59 changes: 59 additions & 0 deletions bundle/config/mutator/configure_wsfs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package mutator_test

import (
"context"
"testing"

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config/mutator"
"github.com/databricks/cli/libs/dbr"
"github.com/databricks/cli/libs/vfs"
"github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/stretchr/testify/assert"
)

func mockBundleForConfigureWSFS(t *testing.T, syncRootPath string) *bundle.Bundle {
b := &bundle.Bundle{
SyncRootPath: syncRootPath,
SyncRoot: vfs.MustNew(syncRootPath),
}

w := mocks.NewMockWorkspaceClient(t)
w.WorkspaceClient.Config = &config.Config{}
b.SetWorkpaceClient(w.WorkspaceClient)

return b
}

func TestConfigureWSFS_SkipsIfNotWorkspacePrefix(t *testing.T) {
b := mockBundleForConfigureWSFS(t, "/foo")
originalSyncRoot := b.SyncRoot

ctx := context.Background()
diags := bundle.Apply(ctx, b, mutator.ConfigureWSFS())
assert.Empty(t, diags)
assert.Equal(t, originalSyncRoot, b.SyncRoot)
}

func TestConfigureWSFS_SkipsIfNotRunningOnRuntime(t *testing.T) {
b := mockBundleForConfigureWSFS(t, "/Workspace/foo")
originalSyncRoot := b.SyncRoot

ctx := context.Background()
ctx = dbr.MockRuntime(ctx, false)
diags := bundle.Apply(ctx, b, mutator.ConfigureWSFS())
assert.Empty(t, diags)
assert.Equal(t, originalSyncRoot, b.SyncRoot)
}

func TestConfigureWSFS_SwapSyncRoot(t *testing.T) {
b := mockBundleForConfigureWSFS(t, "/Workspace/foo")
originalSyncRoot := b.SyncRoot

ctx := context.Background()
ctx = dbr.MockRuntime(ctx, true)
diags := bundle.Apply(ctx, b, mutator.ConfigureWSFS())
assert.Empty(t, diags)
assert.NotEqual(t, originalSyncRoot, b.SyncRoot)
}
4 changes: 4 additions & 0 deletions cmd/root/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/databricks/cli/internal/build"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/dbr"
"github.com/databricks/cli/libs/log"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -73,6 +74,9 @@ func New(ctx context.Context) *cobra.Command {
// get the context back
ctx = cmd.Context()

// Detect if the CLI is running on DBR and store this on the context.
ctx = dbr.DetectRuntime(ctx)

// Configure our user agent with the command that's about to be executed.
ctx = withCommandInUserAgent(ctx, cmd)
ctx = withCommandExecIdInUserAgent(ctx)
Expand Down
42 changes: 42 additions & 0 deletions libs/dbr/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package dbr

import "context"

// key is a private type to prevent collisions with other packages.
type key int

const (
// dbrKey is the context key for the detection result.
// The value of 1 is arbitrary and can be any number.
// Other keys in the same package must have different values.
dbrKey = key(1)
)

// DetectRuntime detects whether or not the current
// process is running inside a Databricks Runtime environment.
// It return a new context with the detection result cached.
func DetectRuntime(ctx context.Context) context.Context {
if v := ctx.Value(dbrKey); v != nil {
panic("dbr.DetectRuntime called twice on the same context")
}
return context.WithValue(ctx, dbrKey, detect(ctx))
}

// MockRuntime is a helper function to mock the detection result.
// It returns a new context with the detection result cached.
func MockRuntime(ctx context.Context, b bool) context.Context {
if v := ctx.Value(dbrKey); v != nil {
panic("dbr.MockRuntime called twice on the same context")
}
return context.WithValue(ctx, dbrKey, b)
}

// RunsOnRuntime returns the cached detection result from the context.
// It expects a context returned by [DetectRuntime] or [MockRuntime].
func RunsOnRuntime(ctx context.Context) bool {
v := ctx.Value(dbrKey)
if v == nil {
panic("dbr.RunsOnRuntime called without calling dbr.DetectRuntime first")
}
return v.(bool)
}
59 changes: 59 additions & 0 deletions libs/dbr/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package dbr

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
)

func TestContext_DetectRuntimePanics(t *testing.T) {
ctx := context.Background()

// Run detection.
ctx = DetectRuntime(ctx)

// Expect a panic if the detection is run twice.
assert.Panics(t, func() {
ctx = DetectRuntime(ctx)
})
}

func TestContext_MockRuntimePanics(t *testing.T) {
ctx := context.Background()

// Run detection.
ctx = MockRuntime(ctx, true)

// Expect a panic if the mock function is run twice.
assert.Panics(t, func() {
MockRuntime(ctx, true)
})
}

func TestContext_RunsOnRuntimePanics(t *testing.T) {
ctx := context.Background()

// Expect a panic if the detection is not run.
assert.Panics(t, func() {
RunsOnRuntime(ctx)
})
}

func TestContext_RunsOnRuntime(t *testing.T) {
ctx := context.Background()

// Run detection.
ctx = DetectRuntime(ctx)

// Expect no panic because detection has run.
assert.NotPanics(t, func() {
RunsOnRuntime(ctx)
})
}

func TestContext_RunsOnRuntimeWithMock(t *testing.T) {
ctx := context.Background()
assert.True(t, RunsOnRuntime(MockRuntime(ctx, true)))
assert.False(t, RunsOnRuntime(MockRuntime(ctx, false)))
}
35 changes: 35 additions & 0 deletions libs/dbr/detect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package dbr

import (
"context"
"os"
"runtime"

"github.com/databricks/cli/libs/env"
)

// Dereference [os.Stat] to allow mocking in tests.
var statFunc = os.Stat

// detect returns true if the current process is running on a Databricks Runtime.
// Its return value is meant to be cached in the context.
func detect(ctx context.Context) bool {
// Databricks Runtime implies Linux.
// Return early on other operating systems.
if runtime.GOOS != "linux" {
return false
}

// Databricks Runtime always has the DATABRICKS_RUNTIME_VERSION environment variable set.
if value, ok := env.Lookup(ctx, "DATABRICKS_RUNTIME_VERSION"); !ok || value == "" {
return false
}

// Expect to see a "/databricks" directory.
if fi, err := statFunc("/databricks"); err != nil || !fi.IsDir() {
return false
}

// All checks passed.
return true
}
83 changes: 83 additions & 0 deletions libs/dbr/detect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package dbr

import (
"context"
"io/fs"
"runtime"
"testing"

"github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/fakefs"
"github.com/stretchr/testify/assert"
)

func requireLinux(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("skipping test on %s", runtime.GOOS)
}
}

func configureStatFunc(t *testing.T, fi fs.FileInfo, err error) {
originalFunc := statFunc
statFunc = func(name string) (fs.FileInfo, error) {
assert.Equal(t, "/databricks", name)
return fi, err
}

t.Cleanup(func() {
statFunc = originalFunc
})
}

func TestDetect_NotLinux(t *testing.T) {
if runtime.GOOS == "linux" {
t.Skip("skipping test on Linux OS")
}

ctx := context.Background()
assert.False(t, detect(ctx))
}

func TestDetect_Env(t *testing.T) {
requireLinux(t)

// Configure other checks to pass.
configureStatFunc(t, fakefs.FileInfo{FakeDir: true}, nil)

t.Run("empty", func(t *testing.T) {
ctx := env.Set(context.Background(), "DATABRICKS_RUNTIME_VERSION", "")
assert.False(t, detect(ctx))
})

t.Run("non-empty cluster", func(t *testing.T) {
ctx := env.Set(context.Background(), "DATABRICKS_RUNTIME_VERSION", "15.4")
assert.True(t, detect(ctx))
})

t.Run("non-empty serverless", func(t *testing.T) {
ctx := env.Set(context.Background(), "DATABRICKS_RUNTIME_VERSION", "client.1.13")
assert.True(t, detect(ctx))
})
}

func TestDetect_Stat(t *testing.T) {
requireLinux(t)

// Configure other checks to pass.
ctx := env.Set(context.Background(), "DATABRICKS_RUNTIME_VERSION", "non-empty")

t.Run("error", func(t *testing.T) {
configureStatFunc(t, nil, fs.ErrNotExist)
assert.False(t, detect(ctx))
})

t.Run("not a directory", func(t *testing.T) {
configureStatFunc(t, fakefs.FileInfo{}, nil)
assert.False(t, detect(ctx))
})

t.Run("directory", func(t *testing.T) {
configureStatFunc(t, fakefs.FileInfo{FakeDir: true}, nil)
assert.True(t, detect(ctx))
})
}
55 changes: 55 additions & 0 deletions libs/fakefs/fakefs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package fakefs

import (
"io/fs"
"time"
)

// DirEntry is a fake implementation of [fs.DirEntry].
type DirEntry struct {
FileInfo
}

func (entry DirEntry) Type() fs.FileMode {
typ := fs.ModePerm
if entry.FakeDir {
typ |= fs.ModeDir
}
return typ
}

func (entry DirEntry) Info() (fs.FileInfo, error) {
return entry.FileInfo, nil
}

// FileInfo is a fake implementation of [fs.FileInfo].
type FileInfo struct {
FakeName string
FakeSize int64
FakeDir bool
FakeMode fs.FileMode
}

func (info FileInfo) Name() string {
return info.FakeName
}

func (info FileInfo) Size() int64 {
return info.FakeSize
}

func (info FileInfo) Mode() fs.FileMode {
return info.FakeMode
}

func (info FileInfo) ModTime() time.Time {
return time.Now()
}

func (info FileInfo) IsDir() bool {
return info.FakeDir
}

func (info FileInfo) Sys() any {
return nil
}

0 comments on commit c8382b6

Please sign in to comment.