diff --git a/src/comet_llm/experiment_api/comet_api_client.py b/src/comet_llm/experiment_api/comet_api_client.py index 62f28ce70..511abd8da 100644 --- a/src/comet_llm/experiment_api/comet_api_client.py +++ b/src/comet_llm/experiment_api/comet_api_client.py @@ -99,7 +99,6 @@ def log_experiment_tags( json={"experimentKey": experiment_key, "addedTags": tags}, ) - @request_exception_wrapper.wrap def _request(self, method: str, path: str, *args, **kwargs) -> ResponseContent: # type: ignore url = urllib.parse.urljoin(self._comet_url, path) response = requests.request(method, url, headers=self._headers, *args, **kwargs) diff --git a/src/comet_llm/experiment_api/experiment_api.py b/src/comet_llm/experiment_api/experiment_api.py index 5ed06e359..ca0115d58 100644 --- a/src/comet_llm/experiment_api/experiment_api.py +++ b/src/comet_llm/experiment_api/experiment_api.py @@ -14,10 +14,11 @@ from typing import IO, Any, List, Optional -from . import comet_api_client +from . import comet_api_client, request_exception_wrapper class ExperimentAPI: + @request_exception_wrapper.wrap(check_on_prem=True) def __init__( self, api_key: str, @@ -46,18 +47,22 @@ def _initialize_links(self, link: str) -> None: self._link = link self._project_link = link[: link.rfind("/")] + @request_exception_wrapper.wrap(check_on_prem=True) def log_asset_with_io(self, name: str, file: IO, asset_type: str) -> None: self._client.log_experiment_asset_with_io( self._experiment_key, name=name, file=file, asset_type=asset_type ) + @request_exception_wrapper.wrap() def log_parameter(self, name: str, value: Any) -> None: self._client.log_experiment_parameter( self._experiment_key, name=name, value=value ) + @request_exception_wrapper.wrap() def log_metric(self, name: str, value: Any) -> None: self._client.log_experiment_metric(self._experiment_key, name=name, value=value) + @request_exception_wrapper.wrap() def log_tags(self, tags: List[str]) -> None: self._client.log_experiment_tags(self._experiment_key, tags=tags) diff --git a/src/comet_llm/experiment_api/request_exception_wrapper.py b/src/comet_llm/experiment_api/request_exception_wrapper.py index 665e0b561..f41325e5e 100644 --- a/src/comet_llm/experiment_api/request_exception_wrapper.py +++ b/src/comet_llm/experiment_api/request_exception_wrapper.py @@ -14,19 +14,40 @@ import functools -from typing import Any, Callable +import urllib.parse +from typing import Any, Callable, List import requests # type: ignore -from .. import exceptions +from .. import config, exceptions -def wrap(func: Callable) -> Callable: - @functools.wraps(func) - def wrapper(*args, **kwargs) -> Any: # type: ignore - try: - return func(*args, **kwargs) - except requests.RequestException as exception: - raise exceptions.CometLLMException() from exception +def wrap(check_on_prem: bool = False) -> Callable: + def inner_wrap(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args, **kwargs) -> Any: # type: ignore + try: + return func(*args, **kwargs) + except requests.RequestException as exception: + exception_args: List[Any] = [] - return wrapper + if check_on_prem: + comet_url = config.comet_url() + if _is_on_prem(comet_url): + exception_args.append( + f"Failed to send prompt to your Comet installation at " + f"{comet_url}. Check that your Comet " + f"installation is up-to-date and check the traceback for more details." + ) + + raise exceptions.CometLLMException(*exception_args) from exception + + return wrapper + + return inner_wrap + + +def _is_on_prem(url: str) -> bool: + parsed = urllib.parse.urlparse(url) + root = f"{parsed.scheme}://{parsed.hostname}/" + return root != "https://www.comet.com/" diff --git a/tests/unit/experiment_api/test_request_exception_wrapper.py b/tests/unit/experiment_api/test_request_exception_wrapper.py index d5ff10b91..ac2849594 100644 --- a/tests/unit/experiment_api/test_request_exception_wrapper.py +++ b/tests/unit/experiment_api/test_request_exception_wrapper.py @@ -1,20 +1,26 @@ import pytest import requests +from testix import * from comet_llm import exceptions from comet_llm.experiment_api import request_exception_wrapper -def test_reraiser_no_exceptions(): - @request_exception_wrapper.wrap +@pytest.fixture(autouse=True) +def mock_imports(patch_module): + patch_module(request_exception_wrapper, "config") + + +def test_wrap_no_exceptions(): + @request_exception_wrapper.wrap() def f(): return "return-value" assert f() == "return-value" -def test_reraiser__request_exception_caught__comet_exception_raised(): - @request_exception_wrapper.wrap +def test_wrap__request_exception_caught__comet_exception_raised(): + @request_exception_wrapper.wrap() def f(): raise requests.RequestException @@ -22,10 +28,25 @@ def f(): f() -def test_reraiser__generic_exception_not_caught(): - @request_exception_wrapper.wrap +def test_wrap__on_prem_check_enabled__request_exception_caught__on_prem_detected__comet_exception_raised_with_additional_message(): + @request_exception_wrapper.wrap(check_on_prem=True) def f(): - raise Exception + raise requests.RequestException + + with Scenario() as s: + s.config.comet_url() >> "https://not.comet.cloud/ddf/" + with pytest.raises(exceptions.CometLLMException): + f() + + +def test_wrap__on_prem_check_enabled__request_exception_caught__on_prem_not_detected__comet_exception_raised_without_additional_message(): + @request_exception_wrapper.wrap(check_on_prem=True) + def f(): + raise requests.RequestException + + COMET_CLOUD_URL = "https://www.comet.com/clientlib/" - with pytest.raises(Exception): - f() \ No newline at end of file + with Scenario() as s: + s.config.comet_url() >> COMET_CLOUD_URL + with pytest.raises(exceptions.CometLLMException): + f() \ No newline at end of file