Skip to content

Commit

Permalink
fix: needed connect and adjust test
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Dec 24, 2024
1 parent 7490a52 commit 5565524
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
23 changes: 18 additions & 5 deletions src/ape/api/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,6 @@ def empty(self) -> bool:
"""
``True`` when there are no providers in the context.
"""

return not self.connected_providers or not self.provider_stack

def __enter__(self, *args, **kwargs):
Expand Down Expand Up @@ -895,6 +894,17 @@ def disconnect_all(self):
self.connected_providers = {}


def _connect_provider(provider: "ProviderAPI") -> "ProviderAPI":
connection_id = provider.connection_id
if connection_id in ProviderContextManager.connected_providers:
# Likely multi-chain testing or utilizing multiple on-going connections.
provider = ProviderContextManager.connected_providers[connection_id]
if not provider.is_connected:
provider.connect()

return provider


class NetworkAPI(BaseInterfaceModel):
"""
A wrapper around a provider for a specific ecosystem.
Expand Down Expand Up @@ -1183,6 +1193,7 @@ def get_provider(
self,
provider_name: Optional[str] = None,
provider_settings: Optional[dict] = None,
connect: bool = False,
):
"""
Get a provider for the given name. If given ``None``, returns the default provider.
Expand All @@ -1192,6 +1203,7 @@ def get_provider(
When ``None``, returns the default provider.
provider_settings (dict, optional): Settings to apply to the provider. Defaults to
``None``.
connect (bool): Set to ``True`` when you also want the provider to connect.
Returns:
:class:`~ape.api.providers.ProviderAPI`
Expand All @@ -1215,17 +1227,19 @@ def get_provider(
provider_name = "node"

if provider_name in self.providers:
return self.providers[provider_name](provider_settings=provider_settings)
provider = self.providers[provider_name](provider_settings=provider_settings)
return _connect_provider(provider) if connect else provider

elif self.is_fork:
# If it can fork Ethereum (and we are asking for it) assume it can fork this one.
# TODO: Refactor this approach to work for custom-forked non-EVM networks.
common_forking_providers = self.network_manager.ethereum.mainnet_fork.providers
if provider_name in common_forking_providers:
return common_forking_providers[provider_name](
provider = common_forking_providers[provider_name](
provider_settings=provider_settings,
network=self,
)
return _connect_provider(provider) if connect else provider

raise ProviderNotFoundError(
provider_name,
Expand Down Expand Up @@ -1274,7 +1288,7 @@ def use_provider(
# NOTE: The main reason we allow a provider instance here is to avoid unnecessarily
# re-initializing the class.
provider_obj = (
self.get_provider(provider_name=provider, provider_settings=settings)
self.get_provider(provider_name=provider, provider_settings=settings, connect=True)
if isinstance(provider, str)
else provider
)
Expand Down Expand Up @@ -1434,7 +1448,6 @@ def upstream_provider(self) -> "UpstreamProvider":
When not set, will attempt to use the default provider, if one
exists.
"""

config_choice: str = self.config.get("upstream_provider")
if provider_name := config_choice or self.upstream_network.default_provider_name:
return self.upstream_network.get_provider(provider_name)
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ def test_connected_provider_command_with_network_option_and_cls_types_false(runn
@network_option()
def cmd(network):
assert isinstance(network, str)
assert network == "ethereum:local:node"
assert network.startswith("ethereum:local")

# NOTE: Must use a network that is not the default.
spec = ("--network", "ethereum:local:node")
Expand Down

0 comments on commit 5565524

Please sign in to comment.