From 55655244443a830f558c926953de3dd135044a8d Mon Sep 17 00:00:00 2001 From: Juliya Smith Date: Tue, 24 Dec 2024 12:45:17 -0600 Subject: [PATCH] fix: needed connect and adjust test --- src/ape/api/networks.py | 23 ++++++++++++++++++----- tests/functional/test_cli.py | 2 +- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index 1eed33e820..6855cbce88 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -807,7 +807,6 @@ def empty(self) -> bool: """ ``True`` when there are no providers in the context. """ - return not self.connected_providers or not self.provider_stack def __enter__(self, *args, **kwargs): @@ -895,6 +894,17 @@ def disconnect_all(self): self.connected_providers = {} +def _connect_provider(provider: "ProviderAPI") -> "ProviderAPI": + connection_id = provider.connection_id + if connection_id in ProviderContextManager.connected_providers: + # Likely multi-chain testing or utilizing multiple on-going connections. + provider = ProviderContextManager.connected_providers[connection_id] + if not provider.is_connected: + provider.connect() + + return provider + + class NetworkAPI(BaseInterfaceModel): """ A wrapper around a provider for a specific ecosystem. @@ -1183,6 +1193,7 @@ def get_provider( self, provider_name: Optional[str] = None, provider_settings: Optional[dict] = None, + connect: bool = False, ): """ Get a provider for the given name. If given ``None``, returns the default provider. @@ -1192,6 +1203,7 @@ def get_provider( When ``None``, returns the default provider. provider_settings (dict, optional): Settings to apply to the provider. Defaults to ``None``. + connect (bool): Set to ``True`` when you also want the provider to connect. Returns: :class:`~ape.api.providers.ProviderAPI` @@ -1215,17 +1227,19 @@ def get_provider( provider_name = "node" if provider_name in self.providers: - return self.providers[provider_name](provider_settings=provider_settings) + provider = self.providers[provider_name](provider_settings=provider_settings) + return _connect_provider(provider) if connect else provider elif self.is_fork: # If it can fork Ethereum (and we are asking for it) assume it can fork this one. # TODO: Refactor this approach to work for custom-forked non-EVM networks. common_forking_providers = self.network_manager.ethereum.mainnet_fork.providers if provider_name in common_forking_providers: - return common_forking_providers[provider_name]( + provider = common_forking_providers[provider_name]( provider_settings=provider_settings, network=self, ) + return _connect_provider(provider) if connect else provider raise ProviderNotFoundError( provider_name, @@ -1274,7 +1288,7 @@ def use_provider( # NOTE: The main reason we allow a provider instance here is to avoid unnecessarily # re-initializing the class. provider_obj = ( - self.get_provider(provider_name=provider, provider_settings=settings) + self.get_provider(provider_name=provider, provider_settings=settings, connect=True) if isinstance(provider, str) else provider ) @@ -1434,7 +1448,6 @@ def upstream_provider(self) -> "UpstreamProvider": When not set, will attempt to use the default provider, if one exists. """ - config_choice: str = self.config.get("upstream_provider") if provider_name := config_choice or self.upstream_network.default_provider_name: return self.upstream_network.get_provider(provider_name) diff --git a/tests/functional/test_cli.py b/tests/functional/test_cli.py index 7d1087c98b..f8d8adc6a0 100644 --- a/tests/functional/test_cli.py +++ b/tests/functional/test_cli.py @@ -803,7 +803,7 @@ def test_connected_provider_command_with_network_option_and_cls_types_false(runn @network_option() def cmd(network): assert isinstance(network, str) - assert network == "ethereum:local:node" + assert network.startswith("ethereum:local") # NOTE: Must use a network that is not the default. spec = ("--network", "ethereum:local:node")