Skip to content

Commit

Permalink
Ensure DHT protocol ID matches latest Starknet spec
Browse files Browse the repository at this point in the history
Use dht.ProtocolExtension to set chainID in  protocol ID format for DHT.
Add test to verify protocol ID format for different networks:
- /starknet/SN_SEPOLIA/kad/1.0.0 for Sepolia
- /starknet/SN_MAIN/kad/1.0.0 for Mainnet

The change ensures that DHT protocol follow latest Starknet
specification.
  • Loading branch information
wojciechos authored and kirugan committed Dec 3, 2024
1 parent ec24744 commit 586824e
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 2 deletions.
5 changes: 3 additions & 2 deletions p2p/p2p.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai
}
}

p2pdht, err := makeDHT(p2phost, peersAddrInfoS)
p2pdht, err := MakeDHT(p2phost, peersAddrInfoS, snNetwork.L2ChainID)
if err != nil {
return nil, err
}
Expand All @@ -164,9 +164,10 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai
return s, nil
}

func makeDHT(p2phost host.Host, addrInfos []peer.AddrInfo) (*dht.IpfsDHT, error) {
func MakeDHT(p2phost host.Host, addrInfos []peer.AddrInfo, chainID string) (*dht.IpfsDHT, error) {
return dht.New(context.Background(), p2phost,
dht.ProtocolPrefix(starknet.Prefix),
dht.ProtocolExtension(starknet.ChainPID(chainID)),
dht.BootstrapPeers(addrInfos...),
dht.RoutingTableRefreshPeriod(routingTableRefreshPeriod),
dht.Mode(dht.ModeServer),
Expand Down
34 changes: 34 additions & 0 deletions p2p/p2p_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/libp2p/go-libp2p/core/protocol"
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
"github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -206,3 +207,36 @@ func TestLoadAndPersistPeers(t *testing.T) {
)
require.NoError(t, err)
}

func TestMakeDHTProtocolName(t *testing.T) {
net, err := mocknet.FullMeshLinked(1)
require.NoError(t, err)
testHost := net.Hosts()[0]

testCases := []struct {
name string
network *utils.Network
expected string
}{
{
name: "sepolia network",
network: &utils.Sepolia,
expected: "/starknet/SN_SEPOLIA/kad/1.0.0",
},
{
name: "mainnet network",
network: &utils.Mainnet,
expected: "/starknet/SN_MAIN/kad/1.0.0",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
dht, err := p2p.MakeDHT(testHost, nil, tc.network.L2ChainID)
require.NoError(t, err)

protocols := dht.Host().Mux().Protocols()
assert.Contains(t, protocols, protocol.ID(tc.expected), "protocol list: %v", protocols)
})
}
}
4 changes: 4 additions & 0 deletions p2p/starknet/ids.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ func ClassesPID() protocol.ID {
func StateDiffPID() protocol.ID {
return Prefix + "/state_diffs/0.1.0-rc.0"
}

func ChainPID(chainID string) protocol.ID {
return protocol.ID("/" + chainID)
}

0 comments on commit 586824e

Please sign in to comment.