Skip to content

Commit

Permalink
make error message more specific
Browse files Browse the repository at this point in the history
  • Loading branch information
pan-x-c committed May 14, 2024
1 parent 4416d4b commit 0eb8dbc
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 32 deletions.
10 changes: 5 additions & 5 deletions src/agentscope/agents/rpc_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
import grpc
from grpc import ServicerContext
from expiringdict import ExpiringDict
except ImportError:
except ImportError as import_error:
from agentscope.utils.tools import ImportErrorReporter

dill = ImportErrorReporter("dill", "distribute")
grpc = ImportErrorReporter("grpcio", "distribute")
ServicerContext = ImportErrorReporter("grpcio", "distribute")
ExpiringDict = ImportErrorReporter("expiringdict", "distribute")
dill = ImportErrorReporter(import_error, "distribute")
grpc = ImportErrorReporter(import_error, "distribute")
ServicerContext = ImportErrorReporter(import_error, "distribute")
ExpiringDict = ImportErrorReporter(import_error, "distribute")

from agentscope._init import init_process, _INIT_SETTINGS
from agentscope.agents.agent import AgentBase
Expand Down
10 changes: 5 additions & 5 deletions src/agentscope/rpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
from .rpc_agent_pb2_grpc import RpcAgentServicer
from .rpc_agent_pb2_grpc import RpcAgentStub
from .rpc_agent_pb2_grpc import add_RpcAgentServicer_to_server
except ImportError:
except ImportError as import_error:
from agentscope.utils.tools import ImportErrorReporter

RpcMsg = ImportErrorReporter("protobuf", "distribute") # type: ignore[misc]
RpcAgentServicer = ImportErrorReporter("grpcio", "distribute")
RpcAgentStub = ImportErrorReporter("grpcio", "distribute")
RpcMsg = ImportErrorReporter(import_error, "distribute") # type: ignore[misc]
RpcAgentServicer = ImportErrorReporter(import_error, "distribute")
RpcAgentStub = ImportErrorReporter(import_error, "distribute")
add_RpcAgentServicer_to_server = ImportErrorReporter(
"grpcio",
import_error,
"distribute",
)

Expand Down
10 changes: 5 additions & 5 deletions src/agentscope/rpc/rpc_agent_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from grpc import RpcError
from agentscope.rpc.rpc_agent_pb2 import RpcMsg # pylint: disable=E0611
from agentscope.rpc.rpc_agent_pb2_grpc import RpcAgentStub
except ImportError:
except ImportError as import_error:
from agentscope.utils.tools import ImportErrorReporter

dill = ImportErrorReporter("dill", "distribute")
grpc = ImportErrorReporter("grpcio", "distribute")
RpcMsg = ImportErrorReporter("protobuf", "distribute")
RpcAgentStub = ImportErrorReporter("grpcio", "distribute")
dill = ImportErrorReporter(import_error, "distribute")
grpc = ImportErrorReporter(import_error, "distribute")
RpcMsg = ImportErrorReporter(import_error, "distribute")
RpcAgentStub = ImportErrorReporter(import_error, "distribute")
RpcError = ImportError


Expand Down
4 changes: 2 additions & 2 deletions src/agentscope/rpc/rpc_agent_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
"""Client and server classes corresponding to protobuf-defined services."""
try:
import grpc
except ImportError:
except ImportError as import_error:
from agentscope.utils.tools import ImportErrorReporter

grpc = ImportErrorReporter("grpcio", "distribute")
grpc = ImportErrorReporter(import_error, "distribute")

import agentscope.rpc.rpc_agent_pb2 as rpc__agent__pb2

Expand Down
33 changes: 18 additions & 15 deletions src/agentscope/utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,28 +304,31 @@ class ImportErrorReporter:
the specified extras requirement.
"""

def __init__(self, package_name: str, extras_require: str = None) -> None:
def __init__(self, error: ImportError, extras_require: str = None) -> None:
"""Init the ImportErrorReporter.
Args:
package_name (`str`): the name of the package to be imported.
error (`ImportError`): the original ImportError.
extras_require (`str`): the extras requirement.
"""
self.package_name = package_name
self.error = error
self.extras_require = extras_require

def raise_error(self) -> Any:
"""Raise an ImportError."""
msg = f"Failed to import {self.package_name}."
if self.extras_require is not None:
msg += (
f" Please install [{self.extras_require}] version of"
" agentscope."
)
raise ImportError(msg)

def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.raise_error()
return self._raise_import_error()

def __getattr__(self, name: str) -> Any:
return self.raise_error()
return self._raise_import_error()

def __getitem__(self, __key: Any) -> Any:
return self._raise_import_error()

def _raise_import_error(self) -> Any:
"""Raise the ImportError"""
err_msg = f"ImportError occorred: [{self.error.msg}]."
if self.extras_require is not None:
err_msg += (
f" Please install [{self.extras_require}] version"
" of agentscope."
)
raise ImportError(err_msg)

0 comments on commit 0eb8dbc

Please sign in to comment.