Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: utilize chain ID cache on re-connect in Ethereum node provider #2464

Merged
merged 6 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/ape/api/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,6 @@ def chain_id(self) -> int:
**NOTE**: Unless overridden, returns same as
:py:attr:`ape.api.providers.ProviderAPI.chain_id`.
"""

return self.provider.chain_id

@property
Expand Down
1 change: 0 additions & 1 deletion src/ape/managers/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,6 @@ def chain_id(self) -> int:
The blockchain ID.
See `ChainList <https://chainlist.org/>`__ for a comprehensive list of IDs.
"""

network_name = self.provider.network.name
if network_name not in self._chain_id_map:
self._chain_id_map[network_name] = self.provider.chain_id
Expand Down
31 changes: 19 additions & 12 deletions src/ape_ethereum/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,20 +572,28 @@ def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional["BlockID"] =
@cached_property
def chain_id(self) -> int:
default_chain_id = None
if self.network.name != "custom" and not self.network.is_dev:
# If using a live network, the chain ID is hardcoded.
if self.network.name not in ("adhoc", "custom") and not self.network.is_dev:
# If using a live plugin-based network, the chain ID is hardcoded.
default_chain_id = self.network.chain_id

try:
if hasattr(self.web3, "eth"):
return self.web3.eth.chain_id
return self._get_chain_id()

except ProviderNotConnectedError:
if default_chain_id is not None:
return default_chain_id

raise # Original error

except ValueError as err:
# Possible syncing error.
raise ProviderError(
err.args[0].get("message")
if all((hasattr(err, "args"), err.args, isinstance(err.args[0], dict)))
else "Error getting chain ID."
)

if default_chain_id is not None:
return default_chain_id

Expand All @@ -606,6 +614,13 @@ def priority_fee(self) -> int:
"eth_maxPriorityFeePerGas not supported in this RPC. Please specify manually."
) from err

def _get_chain_id(self) -> int:
result = self.make_request("eth_chainId", [])
if isinstance(result, int):
return result

return int(result, 16)

def get_block(self, block_id: "BlockID") -> BlockAPI:
if isinstance(block_id, str) and block_id.isnumeric():
block_id = int(block_id)
Expand Down Expand Up @@ -1603,15 +1618,7 @@ def _complete_connect(self):
if not self.network.is_dev:
self.web3.eth.set_gas_price_strategy(rpc_gas_price_strategy)

# Check for chain errors, including syncing
try:
chain_id = self.web3.eth.chain_id
except ValueError as err:
raise ProviderError(
err.args[0].get("message")
if all((hasattr(err, "args"), err.args, isinstance(err.args[0], dict)))
else "Error getting chain id."
)
chain_id = self.chain_id

# NOTE: We have to check both earliest and latest
# because if the chain was _ever_ PoA, we need
Expand Down
58 changes: 58 additions & 0 deletions tests/functional/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,35 @@ def test_chain_id_is_cached(eth_tester_provider):
eth_tester_provider._web3 = web3 # Undo


def test_chain_id_from_ethereum_base_provider_is_cached(mock_web3, ethereum, eth_tester_provider):
"""
Simulated chain ID from a plugin (using base-ethereum class) to ensure is
also cached.
"""

def make_request(rpc, arguments):
if rpc == "eth_chainId":
return {"result": 11155111} # Sepolia

return eth_tester_provider.make_request(rpc, arguments)

mock_web3.provider.make_request.side_effect = make_request

class PluginProvider(Web3Provider):
def connect(self):
return

def disconnect(self):
return

provider = PluginProvider(name="sim", network=ethereum.sepolia)
provider._web3 = mock_web3
assert provider.chain_id == 11155111
# Unset to web3 to prove it does not check it again (else it would fail).
provider._web3 = None
assert provider.chain_id == 11155111


def test_chain_id_when_disconnected(eth_tester_provider):
eth_tester_provider.disconnect()
try:
Expand Down Expand Up @@ -658,3 +687,32 @@ def test_update_settings_invalidates_snapshots(eth_tester_provider, chain):
assert snapshot in chain._snapshots[eth_tester_provider.chain_id]
eth_tester_provider.update_settings({})
assert snapshot not in chain._snapshots[eth_tester_provider.chain_id]


def test_connect_uses_cached_chain_id(mocker, mock_web3, ethereum, eth_tester_provider):
class PluginProvider(EthereumNodeProvider):
pass

web3_factory_patch = mocker.patch("ape_ethereum.provider._create_web3")
web3_factory_patch.return_value = mock_web3

class ChainIDTracker:
call_count = 0

def make_request(self, rpc, args):
if rpc == "eth_chainId":
self.call_count += 1
return {"result": "0xaa36a7"} # Sepolia

return eth_tester_provider.make_request(rpc, args)

chain_id_tracker = ChainIDTracker()
mock_web3.provider.make_request.side_effect = chain_id_tracker.make_request

provider = PluginProvider(name="node", network=ethereum.sepolia)
provider.connect()
assert chain_id_tracker.call_count == 1
provider.disconnect()
provider.connect()
# It is still cached from the previous connection.
assert chain_id_tracker.call_count == 1
Loading