diff --git a/libs/notebook/detect.go b/libs/notebook/detect.go index 582a88479f..480fae212e 100644 --- a/libs/notebook/detect.go +++ b/libs/notebook/detect.go @@ -139,3 +139,46 @@ func Detect(name string) (notebook bool, language workspace.Language, err error) b := filepath.Base(name) return DetectWithFS(os.DirFS(d), b) } + +type inMemoryFile struct { + content []byte + readIndex int64 +} + +type inMemoryFS struct { + content []byte +} + +func (f *inMemoryFile) Close() error { + return nil +} + +func (f *inMemoryFile) Stat() (fs.FileInfo, error) { + return nil, nil +} + +func (f *inMemoryFile) Read(b []byte) (n int, err error) { + if f.readIndex >= int64(len(f.content)) { + err = io.EOF + return + } + + n = copy(b, f.content[f.readIndex:]) + f.readIndex += int64(n) + return +} + +func (fs inMemoryFS) Open(name string) (fs.File, error) { + return &inMemoryFile{ + content: fs.content, + readIndex: 0, + }, nil +} + +func DetectWithContent(name string, content []byte) (notebook bool, language workspace.Language, err error) { + fs := inMemoryFS{ + content: content, + } + + return DetectWithFS(fs, name) +} diff --git a/libs/notebook/detect_test.go b/libs/notebook/detect_test.go index ad89d6dd53..a5892fe745 100644 --- a/libs/notebook/detect_test.go +++ b/libs/notebook/detect_test.go @@ -117,3 +117,22 @@ func TestDetectWithObjectInfo(t *testing.T) { assert.True(t, nb) assert.Equal(t, workspace.LanguagePython, lang) } + +func TestInMemoryFiles(t *testing.T) { + isNotebook, language, err := DetectWithContent("hello.py", []byte("# Databricks notebook source\n print('hello')")) + assert.True(t, isNotebook) + assert.Equal(t, workspace.LanguagePython, language) + require.NoError(t, err) + + isNotebook, language, err = DetectWithContent("hello.py", []byte("print('hello')")) + assert.False(t, isNotebook) + assert.Equal(t, workspace.Language(""), language) + require.NoError(t, err) + + fileContent, err := os.ReadFile("./testdata/py_ipynb.ipynb") + require.NoError(t, err) + isNotebook, language, err = DetectWithContent("py_ipynb.ipynb", fileContent) + assert.True(t, isNotebook) + assert.Equal(t, workspace.LanguagePython, language) + require.NoError(t, err) +} diff --git a/libs/template/file.go b/libs/template/file.go index c27857c9ca..3f1a3d7238 100644 --- a/libs/template/file.go +++ b/libs/template/file.go @@ -11,6 +11,7 @@ import ( "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/filer" + "github.com/databricks/cli/libs/notebook" "github.com/databricks/cli/libs/runtime" "github.com/databricks/databricks-sdk-go/service/workspace" ) @@ -107,12 +108,17 @@ func (f *inMemoryFile) PersistToDisk() error { return writeFile(f.ctx, path, f.content, f.perm) } -func shouldUseImportNotebook(ctx context.Context, path string) bool { - return strings.HasPrefix(path, "/Workspace/") && runtime.RunsOnDatabricks(ctx) && strings.HasSuffix(path, ".ipynb") +func shouldUseImportNotebook(ctx context.Context, path string, content []byte) bool { + if strings.HasPrefix(path, "/Workspace/") && runtime.RunsOnDatabricks(ctx) { + isNotebook, _, _ := notebook.DetectWithContent(path, content) + return isNotebook + } else { + return false + } } func writeFile(ctx context.Context, path string, content []byte, perm fs.FileMode) error { - if shouldUseImportNotebook(ctx, path) { + if shouldUseImportNotebook(ctx, path, content) { return importNotebook(ctx, path, content) } else { return os.WriteFile(path, content, perm) diff --git a/libs/template/file_test.go b/libs/template/file_test.go index 9fe3c36fd6..e1b37e5599 100644 --- a/libs/template/file_test.go +++ b/libs/template/file_test.go @@ -118,16 +118,18 @@ func TestTemplateCopyFilePersistToDiskForWindows(t *testing.T) { func TestShouldUseImportNotebook(t *testing.T) { ctx := context.Background() - assert.False(t, shouldUseImportNotebook(ctx, "./foo/bar")) - assert.False(t, shouldUseImportNotebook(ctx, "./foo/bar.ipynb")) - assert.False(t, shouldUseImportNotebook(ctx, "/Workspace/foo/bar")) - assert.False(t, shouldUseImportNotebook(ctx, "/Workspace/foo/bar.ipynb")) + data := []byte("# Databricks notebook source\n print('hello')") + + assert.False(t, shouldUseImportNotebook(ctx, "./foo/bar", data)) + assert.False(t, shouldUseImportNotebook(ctx, "./foo/bar.ipynb", data)) + assert.False(t, shouldUseImportNotebook(ctx, "/Workspace/foo/bar", data)) + assert.False(t, shouldUseImportNotebook(ctx, "/Workspace/foo/bar.ipynb", data)) t.Setenv("DATABRICKS_RUNTIME_VERSION", "14.3") - assert.False(t, shouldUseImportNotebook(ctx, "./foo/bar")) - assert.False(t, shouldUseImportNotebook(ctx, "./foo/bar.ipynb")) - assert.False(t, shouldUseImportNotebook(ctx, "/Workspace/foo/bar")) - assert.True(t, shouldUseImportNotebook(ctx, "/Workspace/foo/bar.ipynb")) + assert.False(t, shouldUseImportNotebook(ctx, "./foo/bar", data)) + assert.False(t, shouldUseImportNotebook(ctx, "./foo/bar.ipynb", data)) + assert.False(t, shouldUseImportNotebook(ctx, "/Workspace/foo/bar", data)) + assert.True(t, shouldUseImportNotebook(ctx, "/Workspace/foo/bar.py", data)) } func TestImportNotebook(t *testing.T) {