Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mgyucht committed Apr 23, 2024
1 parent 762f7b0 commit 88234ce
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 27 deletions.
23 changes: 12 additions & 11 deletions databricks/sdk/dbutils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import base64
import json
import logging
import os.path
import threading
import typing
from typing import Callable, List, Dict
from collections import namedtuple

from .core import ApiClient, Config, DatabricksError
Expand Down Expand Up @@ -34,7 +35,7 @@ class SecretMetadata(namedtuple('SecretMetadata', ['key'])):
class _FsUtil:
""" Manipulates the Databricks filesystem (DBFS) """

def __init__(self, dbfs_ext: dbfs_ext.DbfsExt, proxy_factory: typing.Callable[[str], '_ProxyUtil']):
def __init__(self, dbfs_ext: dbfs_ext.DbfsExt, proxy_factory: Callable[[str], '_ProxyUtil']):
self._dbfs = dbfs_ext
self._proxy_factory = proxy_factory

Expand All @@ -48,9 +49,9 @@ def head(self, file: str, maxBytes: int = 65536) -> str:
with self._dbfs.download(file) as f:
return f.read(maxBytes).decode('utf8')

def ls(self, dir: str) -> typing.List[FileInfo]:
def ls(self, dir: str) -> List[FileInfo]:
"""Lists the contents of a directory """
return list(self._dbfs.list(dir))
return [FileInfo(f.path, os.path.basename(f.path), f.file_size, f.modification_time) for f in self._dbfs.list(dir)]

def mkdirs(self, dir: str) -> bool:
"""Creates the given directory if it does not exist, also creating any necessary parent directories """
Expand Down Expand Up @@ -78,7 +79,7 @@ def mount(self,
mount_point: str,
encryption_type: str = None,
owner: str = None,
extra_configs: 'typing.Dict[str, str]' = None) -> bool:
extra_configs: Dict[str, str] = None) -> bool:
"""Mounts the given source directory into DBFS at the given mount point"""
fs = self._proxy_factory('fs')
kwargs = {}
Expand All @@ -100,7 +101,7 @@ def updateMount(self,
mount_point: str,
encryption_type: str = None,
owner: str = None,
extra_configs: 'typing.Dict[str, str]' = None) -> bool:
extra_configs: Dict[str, str] = None) -> bool:
""" Similar to mount(), but updates an existing mount point (if present) instead of creating a new one """
fs = self._proxy_factory('fs')
kwargs = {}
Expand All @@ -112,7 +113,7 @@ def updateMount(self,
kwargs['extra_configs'] = extra_configs
return fs.updateMount(source, mount_point, **kwargs)

def mounts(self) -> typing.List[MountInfo]:
def mounts(self) -> List[MountInfo]:
""" Displays information about what is mounted within DBFS """
result = []
fs = self._proxy_factory('fs')
Expand Down Expand Up @@ -145,13 +146,13 @@ def get(self, scope: str, key: str) -> str:
string_value = val.decode()
return string_value

def list(self, scope) -> typing.List[SecretMetadata]:
def list(self, scope) -> List[SecretMetadata]:
"""Lists the metadata for secrets within the specified scope."""

# transform from SDK dataclass to dbutils-compatible namedtuple
return [SecretMetadata(v.key) for v in self._api.list_secrets(scope)]

def listScopes(self) -> typing.List[SecretScope]:
def listScopes(self) -> List[SecretScope]:
"""Lists the available scopes."""

# transform from SDK dataclass to dbutils-compatible namedtuple
Expand Down Expand Up @@ -240,7 +241,7 @@ class _ProxyUtil:
"""Enables temporary workaround to call remote in-REPL dbutils without having to re-implement them"""

def __init__(self, *, command_execution: compute.CommandExecutionAPI,
context_factory: typing.Callable[[],
context_factory: Callable[[],
compute.ContextStatusResponse], cluster_id: str, name: str):
self._commands = command_execution
self._cluster_id = cluster_id
Expand All @@ -262,7 +263,7 @@ def __getattr__(self, method: str) -> '_ProxyCall':
class _ProxyCall:

def __init__(self, *, command_execution: compute.CommandExecutionAPI,
context_factory: typing.Callable[[], compute.ContextStatusResponse], cluster_id: str,
context_factory: Callable[[], compute.ContextStatusResponse], cluster_id: str,
util: str, method: str):
self._commands = command_execution
self._cluster_id = cluster_id
Expand Down
6 changes: 3 additions & 3 deletions databricks/sdk/mixins/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def open(self, *, read=False, write=False, overwrite=False):
def list(self, recursive=False) -> Generator[files.FileInfo, None, None]:
if not self.is_dir:
st = self._path.stat()
yield files.FileInfo(path=str(self._path.absolute()),
yield files.FileInfo(path='file:' + str(self._path.absolute()),
is_dir=False,
file_size=st.st_size,
modification_time=int(st.st_mtime_ns / 1e6),
Expand All @@ -365,7 +365,7 @@ def list(self, recursive=False) -> Generator[files.FileInfo, None, None]:
queue.append(leaf)
continue
info = leaf.stat()
yield files.FileInfo(path=str(leaf.absolute()),
yield files.FileInfo(path='file:' + str(leaf.absolute()),
is_dir=False,
file_size=info.st_size,
modification_time=int(info.st_mtime_ns / 1e6),
Expand Down Expand Up @@ -491,7 +491,7 @@ def open(self, *, read=False, write=False, overwrite=False) -> BinaryIO:
def list(self, *, recursive=False) -> Generator[files.FileInfo, None, None]:
if not self.is_dir:
meta = self._api.get_status(self.as_string)
yield files.FileInfo(path=self.as_string,
yield files.FileInfo(path='dbfs:' + self.as_string,
is_dir=False,
file_size=meta.file_size,
modification_time=meta.modification_time,
Expand Down
21 changes: 21 additions & 0 deletions tests/integration/test_dbutils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import logging
import os

import pytest

Expand Down Expand Up @@ -152,6 +153,26 @@ def _test_mv_dir(fs, base_path, random):
fs.ls(path)


def test_dbutils_dbfs_mv_local_to_remote(w, random, tmp_path):
    """Moving a local file into dbfs:/tmp transfers the content and removes the source."""
    _test_mv_local_to_remote(w.dbutils.fs, 'dbfs:/tmp', random, tmp_path)


def test_dbutils_volumes_mv_local_to_remote(ucws, dbfs_volume, random, tmp_path):
    """Moving a local file into a UC volume transfers the content and removes the source."""
    _test_mv_local_to_remote(ucws.dbutils.fs, dbfs_volume, random, tmp_path)


def _test_mv_local_to_remote(fs, base_path, random, tmp_path):
path = base_path + "/dbc_qa_file-" + random()
with open(tmp_path / "test", "w") as f:
f.write("test")
fs.mv('file:' + str(tmp_path / "test"), path)
output = fs.head(path)
assert output == "test"
assert os.listdir(tmp_path) == []


def test_dbutils_dbfs_rm_file(w, random):
    """Removing a file under dbfs:/tmp via dbutils fs succeeds."""
    _test_rm_file(w.dbutils.fs, 'dbfs:/tmp', random)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_dbfs_mixins.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from databricks.sdk.errors import NotFound


def test_moving_dbfs_file_to_local_dir(config, tmp_path, mocker):
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.files import FileInfo, ReadResponse
Expand Down Expand Up @@ -35,17 +38,14 @@ def test_moving_local_dir_to_dbfs(config, tmp_path, mocker):

mocker.patch('databricks.sdk.service.files.DbfsAPI.create', return_value=CreateResponse(123))

def fake(path: str):
assert path == 'a'
raise DatabricksError('nope', error_code='RESOURCE_DOES_NOT_EXIST')

mocker.patch('databricks.sdk.service.files.DbfsAPI.get_status', wraps=fake)
get_status = mocker.patch('databricks.sdk.service.files.DbfsAPI.get_status', side_effect=NotFound())
add_block = mocker.patch('databricks.sdk.service.files.DbfsAPI.add_block')
close = mocker.patch('databricks.sdk.service.files.DbfsAPI.close')

w = WorkspaceClient(config=config)
w.dbfs.move_(f'file:{tmp_path}', 'a', recursive=True)

get_status.assert_called_with('a')
close.assert_called_with(123)
add_block.assert_called_with(123, 'aGVsbG8=')
assert not (tmp_path / 'a').exists()
28 changes: 20 additions & 8 deletions tests/test_dbutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

from .conftest import raises

from databricks.sdk.dbutils import FileInfo as DBUtilsFileInfo
from databricks.sdk.service.files import ReadResponse, FileInfo


@pytest.fixture
def dbutils(config):
Expand All @@ -18,30 +21,37 @@ def test_fs_cp(dbutils, mocker):


def test_fs_head(dbutils, mocker):
from databricks.sdk.service.files import ReadResponse
inner = mocker.patch('databricks.sdk.service.files.DbfsAPI.read',
return_value=ReadResponse(data='aGVsbG8='))
return_value=ReadResponse(data='aGVsbG8=', bytes_read=5))
inner2 = mocker.patch('databricks.sdk.service.files.DbfsAPI.get_status',
return_value=FileInfo(path='a', is_dir=False, file_size=5))

result = dbutils.fs.head('a')

inner.assert_called_with('a', length=65536, offset=0)
inner2.assert_called_with('a')
assert result == 'hello'


def test_fs_ls(dbutils, mocker):
from databricks.sdk.service.files import FileInfo
inner = mocker.patch('databricks.sdk.mixins.files.DbfsExt.list',
inner = mocker.patch('databricks.sdk.service.files.DbfsAPI.list',
return_value=[
FileInfo(path='b', file_size=10, modification_time=20),
FileInfo(path='c', file_size=30, modification_time=40),
FileInfo(path='a/b', file_size=10, modification_time=20),
FileInfo(path='a/c', file_size=30, modification_time=40),
])
inner2 = mocker.patch('databricks.sdk.service.files.DbfsAPI.get_status',
side_effect=[
FileInfo(path='a', is_dir=True, file_size=5),
FileInfo(path='a/b', is_dir=False, file_size=5),
FileInfo(path='a/c', is_dir=False, file_size=5),
])

result = dbutils.fs.ls('a')

from databricks.sdk.dbutils import FileInfo
inner.assert_called_with('a')
assert len(result) == 2
assert result[0] == FileInfo('dbfs:b', 'b', 10, 20)
assert result[0] == DBUtilsFileInfo('dbfs:a/b', 'b', 10, 20)
assert result[1] == DBUtilsFileInfo('dbfs:a/c', 'c', 30, 40)


def test_fs_mkdirs(dbutils, mocker):
Expand Down Expand Up @@ -85,6 +95,8 @@ def write(self, contents):

def test_fs_rm(dbutils, mocker):
inner = mocker.patch('databricks.sdk.service.files.DbfsAPI.delete')
inner2 = mocker.patch('databricks.sdk.service.files.DbfsAPI.get_status',
return_value=FileInfo(path='a', is_dir=False, file_size=5))

dbutils.fs.rm('a')

Expand Down

0 comments on commit 88234ce

Please sign in to comment.