From 586824e7ed2f4e29a47944cd4c7ba14031939c98 Mon Sep 17 00:00:00 2001 From: wojo Date: Wed, 27 Nov 2024 11:30:16 +0100 Subject: [PATCH] Ensure DHT protocol ID matches latest Starknet spec Use dht.ProtocolExtension to set chainID in protocol ID format for DHT. Add test to verify protocol ID format for different networks: - /starknet/SN_SEPOLIA/kad/1.0.0 for Sepolia - /starknet/SN_MAIN/kad/1.0.0 for Mainnet The change ensures that DHT protocol follow latest Starknet specification. --- p2p/p2p.go | 5 +++-- p2p/p2p_test.go | 34 ++++++++++++++++++++++++++++++++++ p2p/starknet/ids.go | 4 ++++ 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/p2p/p2p.go b/p2p/p2p.go index 49633f49ee..fae78cd781 100644 --- a/p2p/p2p.go +++ b/p2p/p2p.go @@ -142,7 +142,7 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai } } - p2pdht, err := makeDHT(p2phost, peersAddrInfoS) + p2pdht, err := MakeDHT(p2phost, peersAddrInfoS, snNetwork.L2ChainID) if err != nil { return nil, err } @@ -164,9 +164,10 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai return s, nil } -func makeDHT(p2phost host.Host, addrInfos []peer.AddrInfo) (*dht.IpfsDHT, error) { +func MakeDHT(p2phost host.Host, addrInfos []peer.AddrInfo, chainID string) (*dht.IpfsDHT, error) { return dht.New(context.Background(), p2phost, dht.ProtocolPrefix(starknet.Prefix), + dht.ProtocolExtension(starknet.ChainPID(chainID)), dht.BootstrapPeers(addrInfos...), dht.RoutingTableRefreshPeriod(routingTableRefreshPeriod), dht.Mode(dht.ModeServer), diff --git a/p2p/p2p_test.go b/p2p/p2p_test.go index 070a9eedb8..1d5e1f3ae6 100644 --- a/p2p/p2p_test.go +++ b/p2p/p2p_test.go @@ -17,6 +17,7 @@ import ( "github.com/libp2p/go-libp2p/core/protocol" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -206,3 +207,36 @@ func TestLoadAndPersistPeers(t *testing.T) { ) require.NoError(t, err) } + +func TestMakeDHTProtocolName(t *testing.T) { + net, err := mocknet.FullMeshLinked(1) + require.NoError(t, err) + testHost := net.Hosts()[0] + + testCases := []struct { + name string + network *utils.Network + expected string + }{ + { + name: "sepolia network", + network: &utils.Sepolia, + expected: "/starknet/SN_SEPOLIA/kad/1.0.0", + }, + { + name: "mainnet network", + network: &utils.Mainnet, + expected: "/starknet/SN_MAIN/kad/1.0.0", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dht, err := p2p.MakeDHT(testHost, nil, tc.network.L2ChainID) + require.NoError(t, err) + + protocols := dht.Host().Mux().Protocols() + assert.Contains(t, protocols, protocol.ID(tc.expected), "protocol list: %v", protocols) + }) + } +} diff --git a/p2p/starknet/ids.go b/p2p/starknet/ids.go index d1b97b0ad2..9b54f7500a 100644 --- a/p2p/starknet/ids.go +++ b/p2p/starknet/ids.go @@ -25,3 +25,7 @@ func ClassesPID() protocol.ID { func StateDiffPID() protocol.ID { return Prefix + "/state_diffs/0.1.0-rc.0" } + +func ChainPID(chainID string) protocol.ID { + return protocol.ID("/" + chainID) +}