diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py
index 068069f0..d27110b8 100755
--- a/databricks/sdk/__init__.py
+++ b/databricks/sdk/__init__.py
@@ -5,7 +5,7 @@
 from databricks.sdk import azure
 from databricks.sdk.credentials_provider import CredentialsStrategy
 from databricks.sdk.mixins.compute import ClustersExt
-from databricks.sdk.mixins.files import DbfsExt
+from databricks.sdk.mixins.files import DbfsExt, FilesExt
 from databricks.sdk.mixins.jobs import JobsExt
 from databricks.sdk.mixins.open_ai_client import ServingEndpointsExt
 from databricks.sdk.mixins.workspace import WorkspaceExt
@@ -114,6 +114,13 @@ def _make_dbutils(config: client.Config):
     return runtime_dbutils
 
 
+def _make_files_client(apiClient: client.ApiClient, config: client.Config):
+    if config.enable_experimental_files_api_client:
+        return FilesExt(apiClient, config)
+    else:
+        return FilesAPI(apiClient)
+
+
 class WorkspaceClient:
     """
     The WorkspaceClient is a client for the workspace-level Databricks REST API.
@@ -203,7 +210,7 @@ def __init__(self,
         self._dbsql_permissions = DbsqlPermissionsAPI(self._api_client)
         self._experiments = ExperimentsAPI(self._api_client)
         self._external_locations = ExternalLocationsAPI(self._api_client)
-        self._files = FilesAPI(self._api_client)
+        self._files = _make_files_client(self._api_client, self._config)
         self._functions = FunctionsAPI(self._api_client)
         self._genie = GenieAPI(self._api_client)
         self._git_credentials = GitCredentialsAPI(self._api_client)
diff --git a/databricks/sdk/_base_client.py b/databricks/sdk/_base_client.py
index ed85dc47..e61dd39c 100644
--- a/databricks/sdk/_base_client.py
+++ b/databricks/sdk/_base_client.py
@@ -1,6 +1,7 @@
 import io
 import logging
 import urllib.parse
+from abc import ABC, abstractmethod
 from datetime import timedelta
 from types import TracebackType
 from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List,
@@ -285,8 +286,20 @@ def _record_request_log(self, response: requests.Response, raw: bool = False) ->
         logger.debug(RoundTrip(response, self._debug_headers, self._debug_truncate_bytes, raw).generate())
 
 
+class _RawResponse(ABC):
+
+    @abstractmethod
+    # follows Response signature: https://github.com/psf/requests/blob/main/src/requests/models.py#L799
+    def iter_content(self, chunk_size: int = 1, decode_unicode: bool = False):
+        pass
+
+    @abstractmethod
+    def close(self):
+        pass
+
+
 class _StreamingResponse(BinaryIO):
-    _response: requests.Response
+    _response: _RawResponse
     _buffer: bytes
     _content: Union[Iterator[bytes], None]
     _chunk_size: Union[int, None]
@@ -298,7 +311,7 @@ def fileno(self) -> int:
     def flush(self) -> int:
         pass
 
-    def __init__(self, response: requests.Response, chunk_size: Union[int, None] = None):
+    def __init__(self, response: _RawResponse, chunk_size: Union[int, None] = None):
         self._response = response
         self._buffer = b''
         self._content = None
@@ -308,7 +321,7 @@ def _open(self) -> None:
         if self._closed:
             raise ValueError("I/O operation on closed file")
         if not self._content:
-            self._content = self._response.iter_content(chunk_size=self._chunk_size)
+            self._content = self._response.iter_content(chunk_size=self._chunk_size, decode_unicode=False)
 
     def __enter__(self) -> BinaryIO:
         self._open()
diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py
index 387fa65c..490c6ba4 100644
--- a/databricks/sdk/config.py
+++ b/databricks/sdk/config.py
@@ -92,6 +92,11 @@ class Config:
     max_connections_per_pool: int = ConfigAttribute()
     databricks_environment: Optional[DatabricksEnvironment] = None
 
+    enable_experimental_files_api_client: bool = ConfigAttribute(
+        env='DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT')
+    files_api_client_download_max_total_recovers = None
+    files_api_client_download_max_total_recovers_without_progressing = 1
+
     def __init__(
             self,
             *,
diff --git a/databricks/sdk/mixins/files.py b/databricks/sdk/mixins/files.py
index 1e109a1a..678b4b63 100644
--- a/databricks/sdk/mixins/files.py
+++ b/databricks/sdk/mixins/files.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import base64
+import logging
 import os
 import pathlib
 import platform
@@ -8,19 +9,27 @@
 import sys
 from abc import ABC, abstractmethod
 from collections import deque
+from collections.abc import Iterator
 from io import BytesIO
 from types import TracebackType
 from typing import (TYPE_CHECKING, AnyStr, BinaryIO, Generator, Iterable,
-                    Iterator, Type, Union)
+                    Optional, Type, Union)
 from urllib import parse
 
+from requests import RequestException
+
+from .._base_client import _RawResponse, _StreamingResponse
 from .._property import _cached_property
 from ..errors import NotFound
 from ..service import files
+from ..service._internal import _escape_multi_segment_path_parameter
+from ..service.files import DownloadResponse
 
 if TYPE_CHECKING:
     from _typeshed import Self
 
+_LOG = logging.getLogger(__name__)
+
 
 class _DbfsIO(BinaryIO):
     MAX_CHUNK_SIZE = 1024 * 1024
@@ -636,3 +645,177 @@ def delete(self, path: str, *, recursive=False):
         if p.is_dir and not recursive:
             raise IOError('deleting directories requires recursive flag')
         p.delete(recursive=recursive)
+
+
+class FilesExt(files.FilesAPI):
+    __doc__ = files.FilesAPI.__doc__
+
+    def __init__(self, api_client, config: Config):
+        super().__init__(api_client)
+        self._config = config.copy()
+
+    def download(self, file_path: str) -> DownloadResponse:
+        """Download a file.
+
+        Downloads a file of any size. The file contents are the response body.
+        This is a standard HTTP file download, not a JSON RPC.
+
+        For fault tolerance, it is strongly recommended to consume the stream
+        iteratively, using read(size) with a bounded size, rather than reading
+        the entire contents at once.
+
+        :param file_path: str
+          The remote path of the file, e.g. /Volumes/path/to/your/file
+
+        :returns: :class:`DownloadResponse`
+        """
+
+        initial_response: DownloadResponse = self._download_raw_stream(file_path=file_path,
+                                                                       start_byte_offset=0,
+                                                                       if_unmodified_since_timestamp=None)
+
+        wrapped_response = self._wrap_stream(file_path, initial_response)
+        initial_response.contents._response = wrapped_response
+        return initial_response
+
+    def _download_raw_stream(self,
+                             file_path: str,
+                             start_byte_offset: int,
+                             if_unmodified_since_timestamp: Optional[str] = None) -> DownloadResponse:
+        headers = {'Accept': 'application/octet-stream', }
+
+        if start_byte_offset and not if_unmodified_since_timestamp:
+            raise Exception("if_unmodified_since_timestamp is required if start_byte_offset is specified")
+
+        if start_byte_offset:
+            headers['Range'] = f'bytes={start_byte_offset}-'
+
+        if if_unmodified_since_timestamp:
+            headers['If-Unmodified-Since'] = if_unmodified_since_timestamp
+
+        response_headers = ['content-length', 'content-type', 'last-modified', ]
+        res = self._api.do('GET',
+                           f'/api/2.0/fs/files{_escape_multi_segment_path_parameter(file_path)}',
+                           headers=headers,
+                           response_headers=response_headers,
+                           raw=True)
+
+        result = DownloadResponse.from_dict(res)
+        if not isinstance(result.contents, _StreamingResponse):
+            raise Exception("Internal error: response contents is of unexpected type: " +
+                            type(result.contents).__name__)
+
+        return result
+
+    def _wrap_stream(self, file_path: str, downloadResponse: DownloadResponse):
+        underlying_response = _ResilientIterator._extract_raw_response(downloadResponse)
+        return _ResilientResponse(self,
+                                  file_path,
+                                  downloadResponse.last_modified,
+                                  offset=0,
+                                  underlying_response=underlying_response)
+
+
+class _ResilientResponse(_RawResponse):
+
+    def __init__(self, api: FilesExt, file_path: str, file_last_modified: str, offset: int,
+                 underlying_response: _RawResponse):
+        self.api = api
+        self.file_path = file_path
+        self.underlying_response = underlying_response
+        self.offset = offset
+        self.file_last_modified = file_last_modified
+
+    def iter_content(self, chunk_size=1, decode_unicode=False):
+        if decode_unicode:
+            raise ValueError('Decode unicode is not supported')
+
+        iterator = self.underlying_response.iter_content(chunk_size=chunk_size, decode_unicode=False)
+        self.iterator = _ResilientIterator(iterator, self.file_path, self.file_last_modified, self.offset,
+                                           self.api, chunk_size)
+        return self.iterator
+
+    def close(self):
+        self.iterator.close()
+
+
+class _ResilientIterator(Iterator):
+    # This class tracks the current offset (i.e. bytes already returned to the client code)
+    # and recovers from failures by re-requesting the download from that offset.
+
+    @staticmethod
+    def _extract_raw_response(download_response: DownloadResponse) -> _RawResponse:
+        streaming_response: _StreamingResponse = download_response.contents  # this is an instance of _StreamingResponse
+        return streaming_response._response
+
+    def __init__(self, underlying_iterator, file_path: str, file_last_modified: str, offset: int,
+                 api: FilesExt, chunk_size: int):
+        self._underlying_iterator = underlying_iterator
+        self._api = api
+        self._file_path = file_path
+
+        # Absolute current offset (0-based), i.e. the number of bytes from the
+        # beginning of the file that have been returned to the caller so far.
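+        # Note: download() always constructs this iterator with offset=0; the
+        # offset only advances as bytes are returned, and a recover re-requests
+        # the download from exactly that point.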
+        self._offset = offset
+        self._file_last_modified = file_last_modified
+        self._chunk_size = chunk_size
+
+        self._total_recovers_count: int = 0
+        self._recovers_without_progressing_count: int = 0
+        self._closed: bool = False
+
+    def _should_recover(self) -> bool:
+        if self._total_recovers_count == self._api._config.files_api_client_download_max_total_recovers:
+            _LOG.debug("Total recovers limit exceeded")
+            return False
+        if self._api._config.files_api_client_download_max_total_recovers_without_progressing is not None and self._recovers_without_progressing_count >= self._api._config.files_api_client_download_max_total_recovers_without_progressing:
+            _LOG.debug("No progression recovers limit exceeded")
+            return False
+        return True
+
+    def _recover(self) -> bool:
+        if not self._should_recover():
+            return False  # recover suppressed, rethrow original exception
+
+        self._total_recovers_count += 1
+        self._recovers_without_progressing_count += 1
+
+        try:
+            self._underlying_iterator.close()
+
+            _LOG.debug("Trying to recover from offset " + str(self._offset))
+
+            # the following call includes all the required network retries
+            downloadResponse = self._api._download_raw_stream(self._file_path, self._offset,
+                                                              self._file_last_modified)
+            underlying_response = _ResilientIterator._extract_raw_response(downloadResponse)
+            self._underlying_iterator = underlying_response.iter_content(chunk_size=self._chunk_size,
+                                                                         decode_unicode=False)
+            _LOG.debug("Recover succeeded")
+            return True
+        except Exception:
+            return False  # recover failed, rethrow original exception
+
+    def __next__(self):
+        if self._closed:
+            # following _BaseClient
+            raise ValueError("I/O operation on closed file")
+
+        while True:
+            try:
+                returned_bytes = next(self._underlying_iterator)
+                self._offset += len(returned_bytes)
+                self._recovers_without_progressing_count = 0
+                return returned_bytes
+
+            except StopIteration:
+                raise
+
+            # https://requests.readthedocs.io/en/latest/user/quickstart/#errors-and-exceptions
+            except RequestException:
+                if not self._recover():
+                    raise
+
+    def close(self):
+        self._underlying_iterator.close()
+        self._closed = True
diff --git a/tests/test_base_client.py b/tests/test_base_client.py
index 4b6aaa71..a9a9d5cc 100644
--- a/tests/test_base_client.py
+++ b/tests/test_base_client.py
@@ -5,17 +5,17 @@
 from unittest.mock import Mock
 
 import pytest
-import requests
 
 from databricks.sdk import errors, useragent
-from databricks.sdk._base_client import _BaseClient, _StreamingResponse
+from databricks.sdk._base_client import (_BaseClient, _RawResponse,
+                                         _StreamingResponse)
 from databricks.sdk.core import DatabricksError
 
 from .clock import FakeClock
 from .fixture_server import http_fixture_server
 
 
-class DummyResponse(requests.Response):
+class DummyResponse(_RawResponse):
     _content: Iterator[bytes]
     _closed: bool = False
 
@@ -293,9 +293,9 @@ def test_streaming_response_chunk_size(chunk_size, expected_chunks, data_size):
     test_data = bytes(rng.getrandbits(8) for _ in range(data_size))
 
     content_chunks = []
-    mock_response = Mock(spec=requests.Response)
+    mock_response = Mock(spec=_RawResponse)
 
-    def mock_iter_content(chunk_size):
+    def mock_iter_content(chunk_size: int, decode_unicode: bool):
         # Simulate how requests would chunk the data.
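+        # decode_unicode is accepted only to match the _RawResponse.iter_content
+        # signature; this mock always yields raw bytes.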
         for i in range(0, len(test_data), chunk_size):
             chunk = test_data[i:i + chunk_size]
diff --git a/tests/test_files.py b/tests/test_files.py
new file mode 100644
index 00000000..f4d916f6
--- /dev/null
+++ b/tests/test_files.py
@@ -0,0 +1,340 @@
+import logging
+import os
+import re
+from dataclasses import dataclass
+from typing import List, Union
+
+import pytest
+from requests import RequestException
+
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.core import Config
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class RequestData:
+
+    def __init__(self, offset: int):
+        self._offset: int = offset
+
+
+class DownloadTestCase:
+
+    def __init__(self, name: str, enable_new_client: bool, file_size: int,
+                 failure_at_absolute_offset: List[int], max_recovers_total: Union[int, None],
+                 max_recovers_without_progressing: Union[int, None], expected_success: bool,
+                 expected_requested_offsets: List[int]):
+        self.name = name
+        self.enable_new_client = enable_new_client
+        self.file_size = file_size
+        self.failure_at_absolute_offset = failure_at_absolute_offset
+        self.max_recovers_total = max_recovers_total
+        self.max_recovers_without_progressing = max_recovers_without_progressing
+        self.expected_success = expected_success
+        self.expected_requested_offsets = expected_requested_offsets
+
+    @staticmethod
+    def to_string(test_case):
+        return test_case.name
+
+    def run(self, config: Config):
+        config = config.copy()
+        config.enable_experimental_files_api_client = self.enable_new_client
+        config.files_api_client_download_max_total_recovers = self.max_recovers_total
+        config.files_api_client_download_max_total_recovers_without_progressing = self.max_recovers_without_progressing
+
+        w = WorkspaceClient(config=config)
+
+        session = MockSession(self)
+        w.files._api._api_client._session = session
+
+        response = w.files.download("/test").contents
+        if self.expected_success:
+            actual_content = response.read()
+            assert len(actual_content) == len(session.content)
+            assert actual_content == session.content
+        else:
+            with pytest.raises(RequestException):
+                response.read()
+
+        received_requests = session.received_requests
+
+        assert len(self.expected_requested_offsets) == len(received_requests)
+        for idx, requested_offset in enumerate(self.expected_requested_offsets):
+            assert requested_offset == received_requests[idx]._offset
+
+
+class MockSession:
+
+    def __init__(self, test_case: DownloadTestCase):
+        self.test_case: DownloadTestCase = test_case
+        self.received_requests: List[RequestData] = []
+        self.content: bytes = os.urandom(self.test_case.file_size)
+        self.failure_pointer = 0
+        self.last_modified = 'Thu, 28 Nov 2024 16:39:14 GMT'
+
+    # following the signature of Session.request()
+    def request(self,
+                method,
+                url,
+                params=None,
+                data=None,
+                headers=None,
+                cookies=None,
+                files=None,
+                auth=None,
+                timeout=None,
+                allow_redirects=True,
+                proxies=None,
+                hooks=None,
+                stream=None,
+                verify=None,
+                cert=None,
+                json=None):
+        assert method == 'GET'
+        assert stream is True
+
+        offset = 0
+        if "Range" in headers:
+            range_header = headers["Range"]
+            match = re.search("^bytes=(\\d+)-$", range_header)
+            if match:
+                offset = int(match.group(1))
+            else:
+                raise Exception("Unexpected range header: " + range_header)
+
+            if "If-Unmodified-Since" in headers:
+                assert headers["If-Unmodified-Since"] == self.last_modified
+            else:
+                raise Exception("If-Unmodified-Since header should be passed along with Range")
+
+        logger.info("Client requested offset: %s", offset)
+
+        if offset > len(self.content):
+            raise Exception(f"Offset {offset} exceeds file length {len(self.content)}")
+
+        self.received_requests.append(RequestData(offset))
+        return MockResponse(self, offset, MockRequest(url))
+
+
+# required only for correct logging
+class MockRequest:
+
+    def __init__(self, url: str):
+        self.url = url
+        self.method = 'GET'
+        self.headers = dict()
+        self.body = None
+
+
+class MockResponse:
+
+    def __init__(self, session: MockSession, offset: int, request: MockRequest):
+        self.session = session
+        self.offset = offset
+        self.request = request
+        self.status_code = 200
+        self.reason = 'OK'
+        self.headers = dict()
+        self.headers['Content-Length'] = len(session.content) - offset
+        self.headers['Content-Type'] = 'application/octet-stream'
+        self.headers['Last-Modified'] = session.last_modified
+        self.ok = True
+        self.url = request.url
+
+    def iter_content(self, chunk_size: int, decode_unicode: bool):
+        assert decode_unicode is False
+        return MockIterator(self, chunk_size)
+
+
+class MockIterator:
+
+    def __init__(self, response: MockResponse, chunk_size: int):
+        self.response = response
+        self.chunk_size = chunk_size
+        self.offset = 0
+
+    def __next__(self):
+        start_offset = self.response.offset + self.offset
+        if start_offset == len(self.response.session.content):
+            raise StopIteration
+
+        end_offset = start_offset + self.chunk_size  # exclusive, might be out of range
+
+        if self.response.session.failure_pointer < len(
+                self.response.session.test_case.failure_at_absolute_offset):
+            failure_after_byte = self.response.session.test_case.failure_at_absolute_offset[
+                self.response.session.failure_pointer]
+            if failure_after_byte < end_offset:
+                self.response.session.failure_pointer += 1
+                raise RequestException("Fake error")
+
+        result = self.response.session.content[start_offset:end_offset]
+        self.offset += len(result)
+        return result
+
+    def close(self):
+        pass
+
+
+class _Constants:
+    underlying_chunk_size = 1024 * 1024  # see ticket #832
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        DownloadTestCase(name="Old client: no failures, file of 5 bytes",
+                         enable_new_client=False,
+                         file_size=5,
+                         failure_at_absolute_offset=[],
+                         max_recovers_total=0,
+                         max_recovers_without_progressing=0,
+                         expected_success=True,
+                         expected_requested_offsets=[0]),
+        DownloadTestCase(name="Old client: no failures, file of 1.5 chunks",
+                         enable_new_client=False,
+                         file_size=int(1.5 * _Constants.underlying_chunk_size),
+                         failure_at_absolute_offset=[],
+                         max_recovers_total=0,
+                         max_recovers_without_progressing=0,
+                         expected_success=True,
+                         expected_requested_offsets=[0]),
+        DownloadTestCase(
+            name="Old client: failure",
+            enable_new_client=False,
+            file_size=1024,
+            failure_at_absolute_offset=[100],
+            max_recovers_total=None,  # unlimited but ignored
+            max_recovers_without_progressing=None,  # unlimited but ignored
+            expected_success=False,
+            expected_requested_offsets=[0]),
+        DownloadTestCase(name="New client: no failures, file of 5 bytes",
+                         enable_new_client=True,
+                         file_size=5,
+                         failure_at_absolute_offset=[],
+                         max_recovers_total=0,
+                         max_recovers_without_progressing=0,
+                         expected_success=True,
+                         expected_requested_offsets=[0]),
+        DownloadTestCase(name="New client: no failures, file of 1 Kb",
+                         enable_new_client=True,
+                         file_size=1024,
+                         max_recovers_total=None,
+                         max_recovers_without_progressing=None,
+                         failure_at_absolute_offset=[],
+                         expected_success=True,
+                         expected_requested_offsets=[0]),
+        DownloadTestCase(name="New client: no failures, file of 1.5 chunks",
+                         enable_new_client=True,
+                         file_size=int(1.5 * _Constants.underlying_chunk_size),
+                         failure_at_absolute_offset=[],
+                         max_recovers_total=0,
+                         max_recovers_without_progressing=0,
+                         expected_success=True,
+                         expected_requested_offsets=[0]),
+        DownloadTestCase(name="New client: no failures, file of 10 chunks",
+                         enable_new_client=True,
+                         file_size=10 * _Constants.underlying_chunk_size,
+                         failure_at_absolute_offset=[],
+                         max_recovers_total=0,
+                         max_recovers_without_progressing=0,
+                         expected_success=True,
+                         expected_requested_offsets=[0]),
+        DownloadTestCase(name="New client: recovers are disabled, first failure leads to download abort",
+                         enable_new_client=True,
+                         file_size=10000,
+                         failure_at_absolute_offset=[5],
+                         max_recovers_total=0,
+                         max_recovers_without_progressing=0,
+                         expected_success=False,
+                         expected_requested_offsets=[0]),
+        DownloadTestCase(
+            name="New client: unlimited recovers allowed",
+            enable_new_client=True,
+            file_size=_Constants.underlying_chunk_size * 5,
+            # causes errors on requesting the third chunk
+            failure_at_absolute_offset=[
+                _Constants.underlying_chunk_size - 1, _Constants.underlying_chunk_size - 1,
+                _Constants.underlying_chunk_size - 1, _Constants.underlying_chunk_size + 1,
+                _Constants.underlying_chunk_size * 3,
+            ],
+            max_recovers_total=None,
+            max_recovers_without_progressing=None,
+            expected_success=True,
+            expected_requested_offsets=[
+                0, 0, 0, 0, _Constants.underlying_chunk_size, _Constants.underlying_chunk_size * 3
+            ]),
+        DownloadTestCase(
+            name="New client: we respect limit on total recovers when progressing",
+            enable_new_client=True,
+            file_size=_Constants.underlying_chunk_size * 10,
+            failure_at_absolute_offset=[
+                1,
+                _Constants.underlying_chunk_size + 1,  # progressing
+                _Constants.underlying_chunk_size * 2 + 1,  # progressing
+                _Constants.underlying_chunk_size * 3 + 1  # progressing
+            ],
+            max_recovers_total=3,
+            max_recovers_without_progressing=None,
+            expected_success=False,
+            expected_requested_offsets=[
+                0, 0, _Constants.underlying_chunk_size * 1, _Constants.underlying_chunk_size * 2
+            ]),
+        DownloadTestCase(name="New client: we respect limit on total recovers when not progressing",
+                         enable_new_client=True,
+                         file_size=_Constants.underlying_chunk_size * 10,
+                         failure_at_absolute_offset=[1, 1, 1, 1],
+                         max_recovers_total=3,
+                         max_recovers_without_progressing=None,
+                         expected_success=False,
+                         expected_requested_offsets=[0, 0, 0, 0]),
+        DownloadTestCase(name="New client: we respect limit on non-progressing recovers",
+                         enable_new_client=True,
+                         file_size=_Constants.underlying_chunk_size * 2,
+                         failure_at_absolute_offset=[
+                             _Constants.underlying_chunk_size - 1, _Constants.underlying_chunk_size - 1,
+                             _Constants.underlying_chunk_size - 1, _Constants.underlying_chunk_size - 1
+                         ],
+                         max_recovers_total=None,
+                         max_recovers_without_progressing=3,
+                         expected_success=False,
+                         expected_requested_offsets=[0, 0, 0, 0]),
+        DownloadTestCase(
+            name="New client: non-progressing recovers count is reset when progressing",
+            enable_new_client=True,
+            file_size=_Constants.underlying_chunk_size * 10,
+            failure_at_absolute_offset=[
+                _Constants.underlying_chunk_size + 1,  # this recover is after progressing
+                _Constants.underlying_chunk_size + 1,  # this is not
+                _Constants.underlying_chunk_size * 2 + 1,  # this recover is after progressing
+                _Constants.underlying_chunk_size * 2 + 1,  # this is not
+                _Constants.underlying_chunk_size * 2 + 1,  # this is not, we abort here
+            ],
+            max_recovers_total=None,
+            max_recovers_without_progressing=2,
+            expected_success=False,
+            expected_requested_offsets=[
+                0, _Constants.underlying_chunk_size, _Constants.underlying_chunk_size,
+                _Constants.underlying_chunk_size * 2, _Constants.underlying_chunk_size * 2
+            ]),
+        DownloadTestCase(name="New client: non-progressing recovers count is reset when progressing - 2",
+                         enable_new_client=True,
+                         file_size=_Constants.underlying_chunk_size * 10,
+                         failure_at_absolute_offset=[
+                             1, _Constants.underlying_chunk_size + 1, _Constants.underlying_chunk_size * 2
+                             + 1, _Constants.underlying_chunk_size * 3 + 1
+                         ],
+                         max_recovers_total=None,
+                         max_recovers_without_progressing=1,
+                         expected_success=True,
+                         expected_requested_offsets=[
+                             0, 0, _Constants.underlying_chunk_size, _Constants.underlying_chunk_size * 2,
+                             _Constants.underlying_chunk_size * 3
+                         ]),
+    ],
+    ids=DownloadTestCase.to_string)
+def test_download_recover(config: Config, test_case: DownloadTestCase):
+    test_case.run(config)
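
A minimal usage sketch of the new client, based on the Config attributes and the test setup above. It assumes host/token come from the environment; the volume path and chunk size are illustrative placeholders.

    from databricks.sdk import WorkspaceClient
    from databricks.sdk.core import Config

    # Opt in to the resilient download client (also switchable via the
    # DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT environment variable).
    config = Config(enable_experimental_files_api_client=True)
    # None means unlimited total recovers; the values shown match the
    # defaults introduced in config.py.
    config.files_api_client_download_max_total_recovers = None
    config.files_api_client_download_max_total_recovers_without_progressing = 1

    w = WorkspaceClient(config=config)

    # Consume the stream with bounded reads, as the download() docstring recommends.
    with w.files.download("/Volumes/my_catalog/my_schema/my_volume/file.bin").contents as stream:
        while True:
            chunk = stream.read(1024 * 1024)
            if not chunk:
                break
            # process chunk ...

Reading in bounded chunks is what lets _ResilientIterator recover mid-stream: each successfully returned chunk advances the tracked offset, so a recover re-requests only the remaining byte range.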