Skip to content

Commit

Permalink
Ensure DHT protocol ID matches latest Starknet spec
Browse files Browse the repository at this point in the history
Use dht.ProtocolExtension to set chainID in  protocol ID format for DHT.
Add test to verify protocol ID format for different networks:
- /starknet/SN_SEPOLIA/kad/1.0.0 for Sepolia
- /starknet/SN_MAIN/kad/1.0.0 for Mainnet

The change ensures that DHT protocol follow latest Starknet
specification.
  • Loading branch information
wojciechos committed Dec 12, 2024
1 parent ec24744 commit df1ed30
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 23 deletions.
16 changes: 8 additions & 8 deletions p2p/p2p.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai
}
}

p2pdht, err := makeDHT(p2phost, peersAddrInfoS)
p2pdht, err := MakeDHT(p2phost, peersAddrInfoS, snNetwork.L2ChainID)
if err != nil {
return nil, err
}
Expand All @@ -164,9 +164,9 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai
return s, nil
}

func makeDHT(p2phost host.Host, addrInfos []peer.AddrInfo) (*dht.IpfsDHT, error) {
func MakeDHT(p2phost host.Host, addrInfos []peer.AddrInfo, chainID string) (*dht.IpfsDHT, error) {
return dht.New(context.Background(), p2phost,
dht.ProtocolPrefix(starknet.Prefix),
dht.ProtocolPrefix(starknet.ChainPID(chainID)),
dht.BootstrapPeers(addrInfos...),
dht.RoutingTableRefreshPeriod(routingTableRefreshPeriod),
dht.Mode(dht.ModeServer),
Expand Down Expand Up @@ -282,11 +282,11 @@ func (s *Service) Run(ctx context.Context) error {
}

func (s *Service) setProtocolHandlers() {
s.SetProtocolHandler(starknet.HeadersPID(), s.handler.HeadersHandler)
s.SetProtocolHandler(starknet.EventsPID(), s.handler.EventsHandler)
s.SetProtocolHandler(starknet.TransactionsPID(), s.handler.TransactionsHandler)
s.SetProtocolHandler(starknet.ClassesPID(), s.handler.ClassesHandler)
s.SetProtocolHandler(starknet.StateDiffPID(), s.handler.StateDiffHandler)
s.SetProtocolHandler(starknet.HeadersPID(s.network.L2ChainID), s.handler.HeadersHandler)
s.SetProtocolHandler(starknet.EventsPID(s.network.L2ChainID), s.handler.EventsHandler)
s.SetProtocolHandler(starknet.TransactionsPID(s.network.L2ChainID), s.handler.TransactionsHandler)
s.SetProtocolHandler(starknet.ClassesPID(s.network.L2ChainID), s.handler.ClassesHandler)
s.SetProtocolHandler(starknet.StateDiffPID(s.network.L2ChainID), s.handler.StateDiffHandler)
}

func (s *Service) callAndLogErr(f func() error, msg string) {
Expand Down
34 changes: 34 additions & 0 deletions p2p/p2p_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/libp2p/go-libp2p/core/protocol"
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
"github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -206,3 +207,36 @@ func TestLoadAndPersistPeers(t *testing.T) {
)
require.NoError(t, err)
}

func TestMakeDHTProtocolName(t *testing.T) {
net, err := mocknet.FullMeshLinked(1)
require.NoError(t, err)
testHost := net.Hosts()[0]

testCases := []struct {
name string
network *utils.Network
expected string
}{
{
name: "sepolia network",
network: &utils.Sepolia,
expected: "/starknet/SN_SEPOLIA/sync/kad/1.0.0",
},
{
name: "mainnet network",
network: &utils.Mainnet,
expected: "/starknet/SN_MAIN/sync/kad/1.0.0",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
dht, err := p2p.MakeDHT(testHost, nil, tc.network.L2ChainID)
require.NoError(t, err)

protocols := dht.Host().Mux().Protocols()
assert.Contains(t, protocols, protocol.ID(tc.expected), "protocol list: %v", protocols)
})
}
}
10 changes: 5 additions & 5 deletions p2p/starknet/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,22 +104,22 @@ func (c *Client) RequestBlockHeaders(
ctx context.Context, req *spec.BlockHeadersRequest,
) (iter.Seq[*spec.BlockHeadersResponse], error) {
return requestAndReceiveStream[*spec.BlockHeadersRequest, *spec.BlockHeadersResponse](
ctx, c.newStream, HeadersPID(), req, c.log)
ctx, c.newStream, HeadersPID(c.network.L2ChainID), req, c.log)
}

func (c *Client) RequestEvents(ctx context.Context, req *spec.EventsRequest) (iter.Seq[*spec.EventsResponse], error) {
return requestAndReceiveStream[*spec.EventsRequest, *spec.EventsResponse](ctx, c.newStream, EventsPID(), req, c.log)
return requestAndReceiveStream[*spec.EventsRequest, *spec.EventsResponse](ctx, c.newStream, EventsPID(c.network.L2ChainID), req, c.log)
}

func (c *Client) RequestClasses(ctx context.Context, req *spec.ClassesRequest) (iter.Seq[*spec.ClassesResponse], error) {
return requestAndReceiveStream[*spec.ClassesRequest, *spec.ClassesResponse](ctx, c.newStream, ClassesPID(), req, c.log)
return requestAndReceiveStream[*spec.ClassesRequest, *spec.ClassesResponse](ctx, c.newStream, ClassesPID(c.network.L2ChainID), req, c.log)
}

func (c *Client) RequestStateDiffs(ctx context.Context, req *spec.StateDiffsRequest) (iter.Seq[*spec.StateDiffsResponse], error) {
return requestAndReceiveStream[*spec.StateDiffsRequest, *spec.StateDiffsResponse](ctx, c.newStream, StateDiffPID(), req, c.log)
return requestAndReceiveStream[*spec.StateDiffsRequest, *spec.StateDiffsResponse](ctx, c.newStream, StateDiffPID(c.network.L2ChainID), req, c.log)
}

func (c *Client) RequestTransactions(ctx context.Context, req *spec.TransactionsRequest) (iter.Seq[*spec.TransactionsResponse], error) {
return requestAndReceiveStream[*spec.TransactionsRequest, *spec.TransactionsResponse](
ctx, c.newStream, TransactionsPID(), req, c.log)
ctx, c.newStream, TransactionsPID(c.network.L2ChainID), req, c.log)
}
24 changes: 14 additions & 10 deletions p2p/starknet/ids.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,26 @@ import (

const Prefix = "/starknet"

func HeadersPID() protocol.ID {
return Prefix + "/headers/0.1.0-rc.0"
func HeadersPID(chainID string) protocol.ID {
return protocol.ID(Prefix + "/" + chainID + "/sync/headers/0.1.0-rc.0")
}

func EventsPID() protocol.ID {
return Prefix + "/events/0.1.0-rc.0"
func EventsPID(chainID string) protocol.ID {
return protocol.ID(Prefix + "/" + chainID + "/sync/events/0.1.0-rc.0")
}

func TransactionsPID() protocol.ID {
return Prefix + "/transactions/0.1.0-rc.0"
func TransactionsPID(chainID string) protocol.ID {
return protocol.ID(Prefix + "/" + chainID + "/sync/transactions/0.1.0-rc.0")
}

func ClassesPID() protocol.ID {
return Prefix + "/classes/0.1.0-rc.0"
func ClassesPID(chainID string) protocol.ID {
return protocol.ID(Prefix + "/" + chainID + "/sync/classes/0.1.0-rc.0")
}

func StateDiffPID() protocol.ID {
return Prefix + "/state_diffs/0.1.0-rc.0"
func StateDiffPID(chainID string) protocol.ID {
return protocol.ID(Prefix + "/" + chainID + "/sync/state_diffs/0.1.0-rc.0")
}

func ChainPID(chainID string) protocol.ID {
return protocol.ID(Prefix + "/" + chainID + "/sync")
}
66 changes: 66 additions & 0 deletions p2p/starknet/ids_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package starknet

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestProtocolIDs(t *testing.T) {
testCases := []struct {
name string
chainID string
pidFunc func(string) string
expected string
}{
{
name: "HeadersPID with SN_MAIN",
chainID: "SN_MAIN",
pidFunc: func(c string) string { return string(HeadersPID(c)) },
expected: "/starknet/SN_MAIN/sync/headers/0.1.0-rc.0",
},
{
name: "EventsPID with SN_MAIN",
chainID: "SN_MAIN",
pidFunc: func(c string) string { return string(EventsPID(c)) },
expected: "/starknet/SN_MAIN/sync/events/0.1.0-rc.0",
},
{
name: "TransactionsPID with SN_MAIN",
chainID: "SN_MAIN",
pidFunc: func(c string) string { return string(TransactionsPID(c)) },
expected: "/starknet/SN_MAIN/sync/transactions/0.1.0-rc.0",
},
{
name: "ClassesPID with SN_MAIN",
chainID: "SN_MAIN",
pidFunc: func(c string) string { return string(ClassesPID(c)) },
expected: "/starknet/SN_MAIN/sync/classes/0.1.0-rc.0",
},
{
name: "StateDiffPID with SN_MAIN",
chainID: "SN_MAIN",
pidFunc: func(c string) string { return string(StateDiffPID(c)) },
expected: "/starknet/SN_MAIN/sync/state_diffs/0.1.0-rc.0",
},
{
name: "ChainPID with SN_MAIN",
chainID: "SN_MAIN",
pidFunc: func(c string) string { return string(ChainPID(c)) },
expected: "/starknet/SN_MAIN/sync",
},
{
name: "HeadersPID with SN_SEPOLIA",
chainID: "SN_SEPOLIA",
pidFunc: func(c string) string { return string(HeadersPID(c)) },
expected: "/starknet/SN_SEPOLIA/sync/headers/0.1.0-rc.0",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := tc.pidFunc(tc.chainID)
assert.Equal(t, tc.expected, result)
})
}
}

0 comments on commit df1ed30

Please sign in to comment.