diff --git a/databricks/sdk/dbutils.py b/databricks/sdk/dbutils.py index 5ec013856..20a05f606 100644 --- a/databricks/sdk/dbutils.py +++ b/databricks/sdk/dbutils.py @@ -1,8 +1,9 @@ import base64 import json import logging +import os.path import threading -import typing +from typing import Callable, List, Dict from collections import namedtuple from .core import ApiClient, Config, DatabricksError @@ -34,7 +35,7 @@ class SecretMetadata(namedtuple('SecretMetadata', ['key'])): class _FsUtil: """ Manipulates the Databricks filesystem (DBFS) """ - def __init__(self, dbfs_ext: dbfs_ext.DbfsExt, proxy_factory: typing.Callable[[str], '_ProxyUtil']): + def __init__(self, dbfs_ext: dbfs_ext.DbfsExt, proxy_factory: Callable[[str], '_ProxyUtil']): self._dbfs = dbfs_ext self._proxy_factory = proxy_factory @@ -48,9 +49,9 @@ def head(self, file: str, maxBytes: int = 65536) -> str: with self._dbfs.download(file) as f: return f.read(maxBytes).decode('utf8') - def ls(self, dir: str) -> typing.List[FileInfo]: + def ls(self, dir: str) -> List[FileInfo]: """Lists the contents of a directory """ - return list(self._dbfs.list(dir)) + return [FileInfo(f.path, os.path.basename(f.path), f.file_size, f.modification_time) for f in self._dbfs.list(dir)] def mkdirs(self, dir: str) -> bool: """Creates the given directory if it does not exist, also creating any necessary parent directories """ @@ -78,7 +79,7 @@ def mount(self, mount_point: str, encryption_type: str = None, owner: str = None, - extra_configs: 'typing.Dict[str, str]' = None) -> bool: + extra_configs: Dict[str, str] = None) -> bool: """Mounts the given source directory into DBFS at the given mount point""" fs = self._proxy_factory('fs') kwargs = {} @@ -100,7 +101,7 @@ def updateMount(self, mount_point: str, encryption_type: str = None, owner: str = None, - extra_configs: 'typing.Dict[str, str]' = None) -> bool: + extra_configs: Dict[str, str] = None) -> bool: """ Similar to mount(), but updates an existing mount point (if present) instead of creating a new one """ fs = self._proxy_factory('fs') kwargs = {} @@ -112,7 +113,7 @@ def updateMount(self, kwargs['extra_configs'] = extra_configs return fs.updateMount(source, mount_point, **kwargs) - def mounts(self) -> typing.List[MountInfo]: + def mounts(self) -> List[MountInfo]: """ Displays information about what is mounted within DBFS """ result = [] fs = self._proxy_factory('fs') @@ -145,13 +146,13 @@ def get(self, scope: str, key: str) -> str: string_value = val.decode() return string_value - def list(self, scope) -> typing.List[SecretMetadata]: + def list(self, scope) -> List[SecretMetadata]: """Lists the metadata for secrets within the specified scope.""" # transform from SDK dataclass to dbutils-compatible namedtuple return [SecretMetadata(v.key) for v in self._api.list_secrets(scope)] - def listScopes(self) -> typing.List[SecretScope]: + def listScopes(self) -> List[SecretScope]: """Lists the available scopes.""" # transform from SDK dataclass to dbutils-compatible namedtuple @@ -240,7 +241,7 @@ class _ProxyUtil: """Enables temporary workaround to call remote in-REPL dbutils without having to re-implement them""" def __init__(self, *, command_execution: compute.CommandExecutionAPI, - context_factory: typing.Callable[[], + context_factory: Callable[[], compute.ContextStatusResponse], cluster_id: str, name: str): self._commands = command_execution self._cluster_id = cluster_id @@ -262,7 +263,7 @@ def __getattr__(self, method: str) -> '_ProxyCall': class _ProxyCall: def __init__(self, *, command_execution: compute.CommandExecutionAPI, - context_factory: typing.Callable[[], compute.ContextStatusResponse], cluster_id: str, + context_factory: Callable[[], compute.ContextStatusResponse], cluster_id: str, util: str, method: str): self._commands = command_execution self._cluster_id = cluster_id diff --git a/databricks/sdk/mixins/files.py b/databricks/sdk/mixins/files.py index 2300818f8..413a4b534 100644 --- a/databricks/sdk/mixins/files.py +++ b/databricks/sdk/mixins/files.py @@ -350,7 +350,7 @@ def open(self, *, read=False, write=False, overwrite=False): def list(self, recursive=False) -> Generator[files.FileInfo, None, None]: if not self.is_dir: st = self._path.stat() - yield files.FileInfo(path=str(self._path.absolute()), + yield files.FileInfo(path='file:' + str(self._path.absolute()), is_dir=False, file_size=st.st_size, modification_time=int(st.st_mtime_ns / 1e6), @@ -365,7 +365,7 @@ def list(self, recursive=False) -> Generator[files.FileInfo, None, None]: queue.append(leaf) continue info = leaf.stat() - yield files.FileInfo(path=str(leaf.absolute()), + yield files.FileInfo(path='file:' + str(leaf.absolute()), is_dir=False, file_size=info.st_size, modification_time=int(info.st_mtime_ns / 1e6), @@ -491,7 +491,7 @@ def open(self, *, read=False, write=False, overwrite=False) -> BinaryIO: def list(self, *, recursive=False) -> Generator[files.FileInfo, None, None]: if not self.is_dir: meta = self._api.get_status(self.as_string) - yield files.FileInfo(path=self.as_string, + yield files.FileInfo(path='dbfs:' + self.as_string, is_dir=False, file_size=meta.file_size, modification_time=meta.modification_time, diff --git a/tests/integration/test_dbutils.py b/tests/integration/test_dbutils.py index 85773abaf..3c4980aeb 100644 --- a/tests/integration/test_dbutils.py +++ b/tests/integration/test_dbutils.py @@ -1,5 +1,6 @@ import base64 import logging +import os import pytest @@ -152,6 +153,26 @@ def _test_mv_dir(fs, base_path, random): fs.ls(path) +def test_dbutils_dbfs_mv_local_to_remote(w, random, tmp_path): + fs = w.dbutils.fs + _test_mv_local_to_remote(fs, 'dbfs:/tmp', random, tmp_path) + + +def test_dbutils_volumes_mv_local_to_remote(ucws, dbfs_volume, random, tmp_path): + fs = ucws.dbutils.fs + _test_mv_local_to_remote(fs, dbfs_volume, random, tmp_path) + + +def _test_mv_local_to_remote(fs, base_path, random, tmp_path): + path = base_path + "/dbc_qa_file-" + random() + with open(tmp_path / "test", "w") as f: + f.write("test") + fs.mv('file:' + str(tmp_path / "test"), path) + output = fs.head(path) + assert output == "test" + assert os.listdir(tmp_path) == [] + + def test_dbutils_dbfs_rm_file(w, random): fs = w.dbutils.fs _test_rm_file(fs, 'dbfs:/tmp', random) diff --git a/tests/test_dbfs_mixins.py b/tests/test_dbfs_mixins.py index 6d63eaf10..1036ca99a 100644 --- a/tests/test_dbfs_mixins.py +++ b/tests/test_dbfs_mixins.py @@ -1,3 +1,6 @@ +from databricks.sdk.errors import NotFound + + def test_moving_dbfs_file_to_local_dir(config, tmp_path, mocker): from databricks.sdk import WorkspaceClient from databricks.sdk.service.files import FileInfo, ReadResponse @@ -35,17 +38,14 @@ def test_moving_local_dir_to_dbfs(config, tmp_path, mocker): mocker.patch('databricks.sdk.service.files.DbfsAPI.create', return_value=CreateResponse(123)) - def fake(path: str): - assert path == 'a' - raise DatabricksError('nope', error_code='RESOURCE_DOES_NOT_EXIST') - - mocker.patch('databricks.sdk.service.files.DbfsAPI.get_status', wraps=fake) + get_status = mocker.patch('databricks.sdk.service.files.DbfsAPI.get_status', side_effect=NotFound()) add_block = mocker.patch('databricks.sdk.service.files.DbfsAPI.add_block') close = mocker.patch('databricks.sdk.service.files.DbfsAPI.close') w = WorkspaceClient(config=config) w.dbfs.move_(f'file:{tmp_path}', 'a', recursive=True) + get_status.assert_called_with('a') close.assert_called_with(123) add_block.assert_called_with(123, 'aGVsbG8=') assert not (tmp_path / 'a').exists() diff --git a/tests/test_dbutils.py b/tests/test_dbutils.py index 5a8cb1edf..1d96d9e9f 100644 --- a/tests/test_dbutils.py +++ b/tests/test_dbutils.py @@ -2,6 +2,9 @@ from .conftest import raises +from databricks.sdk.dbutils import FileInfo as DBUtilsFileInfo +from databricks.sdk.service.files import ReadResponse, FileInfo + @pytest.fixture def dbutils(config): @@ -18,30 +21,37 @@ def test_fs_cp(dbutils, mocker): def test_fs_head(dbutils, mocker): - from databricks.sdk.service.files import ReadResponse inner = mocker.patch('databricks.sdk.service.files.DbfsAPI.read', - return_value=ReadResponse(data='aGVsbG8=')) + return_value=ReadResponse(data='aGVsbG8=', bytes_read=5)) + inner2 = mocker.patch('databricks.sdk.service.files.DbfsAPI.get_status', + return_value=FileInfo(path='a', is_dir=False, file_size=5)) result = dbutils.fs.head('a') inner.assert_called_with('a', length=65536, offset=0) + inner2.assert_called_with('a') assert result == 'hello' def test_fs_ls(dbutils, mocker): - from databricks.sdk.service.files import FileInfo - inner = mocker.patch('databricks.sdk.mixins.files.DbfsExt.list', + inner = mocker.patch('databricks.sdk.service.files.DbfsAPI.list', return_value=[ - FileInfo(path='b', file_size=10, modification_time=20), - FileInfo(path='c', file_size=30, modification_time=40), + FileInfo(path='a/b', file_size=10, modification_time=20), + FileInfo(path='a/c', file_size=30, modification_time=40), ]) + inner2 = mocker.patch('databricks.sdk.service.files.DbfsAPI.get_status', + side_effect=[ + FileInfo(path='a', is_dir=True, file_size=5), + FileInfo(path='a/b', is_dir=False, file_size=5), + FileInfo(path='a/c', is_dir=False, file_size=5), + ]) result = dbutils.fs.ls('a') - from databricks.sdk.dbutils import FileInfo inner.assert_called_with('a') assert len(result) == 2 - assert result[0] == FileInfo('dbfs:b', 'b', 10, 20) + assert result[0] == DBUtilsFileInfo('dbfs:a/b', 'b', 10, 20) + assert result[1] == DBUtilsFileInfo('dbfs:a/c', 'c', 30, 40) def test_fs_mkdirs(dbutils, mocker): @@ -85,6 +95,8 @@ def write(self, contents): def test_fs_rm(dbutils, mocker): inner = mocker.patch('databricks.sdk.service.files.DbfsAPI.delete') + inner2 = mocker.patch('databricks.sdk.service.files.DbfsAPI.get_status', + return_value=FileInfo(path='a', is_dir=False, file_size=5)) dbutils.fs.rm('a')