diff --git a/p2p/p2p.go b/p2p/p2p.go index f0b54c3381..f274401071 100644 --- a/p2p/p2p.go +++ b/p2p/p2p.go @@ -138,7 +138,7 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai } } - p2pdht, err := makeDHT(p2phost, peersAddrInfoS) + p2pdht, err := MakeDHT(p2phost, peersAddrInfoS, snNetwork) if err != nil { return nil, err } @@ -159,9 +159,9 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai return s, nil } -func makeDHT(p2phost host.Host, addrInfos []peer.AddrInfo) (*dht.IpfsDHT, error) { +func MakeDHT(p2phost host.Host, addrInfos []peer.AddrInfo, network *utils.Network) (*dht.IpfsDHT, error) { return dht.New(context.Background(), p2phost, - dht.ProtocolPrefix(starknet.Prefix), + dht.ProtocolPrefix(starknet.DHTPrefixPID(network)), dht.BootstrapPeers(addrInfos...), dht.RoutingTableRefreshPeriod(routingTableRefreshPeriod), dht.Mode(dht.ModeServer), @@ -249,11 +249,11 @@ func (s *Service) Run(ctx context.Context) error { } func (s *Service) setProtocolHandlers() { - s.SetProtocolHandler(starknet.HeadersPID(), s.handler.HeadersHandler) - s.SetProtocolHandler(starknet.EventsPID(), s.handler.EventsHandler) - s.SetProtocolHandler(starknet.TransactionsPID(), s.handler.TransactionsHandler) - s.SetProtocolHandler(starknet.ClassesPID(), s.handler.ClassesHandler) - s.SetProtocolHandler(starknet.StateDiffPID(), s.handler.StateDiffHandler) + s.SetProtocolHandler(starknet.HeadersPID(s.network), s.handler.HeadersHandler) + s.SetProtocolHandler(starknet.EventsPID(s.network), s.handler.EventsHandler) + s.SetProtocolHandler(starknet.TransactionsPID(s.network), s.handler.TransactionsHandler) + s.SetProtocolHandler(starknet.ClassesPID(s.network), s.handler.ClassesHandler) + s.SetProtocolHandler(starknet.StateDiffPID(s.network), s.handler.StateDiffHandler) } func (s *Service) callAndLogErr(f func() error, msg string) { diff --git a/p2p/p2p_test.go b/p2p/p2p_test.go index 54b19d5900..0e10a4c04c 100644 --- a/p2p/p2p_test.go +++ b/p2p/p2p_test.go @@ -8,7 +8,10 @@ import ( "github.com/NethermindEth/juno/p2p" "github.com/NethermindEth/juno/utils" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -64,3 +67,36 @@ func TestLoadAndPersistPeers(t *testing.T) { ) require.NoError(t, err) } + +func TestMakeDHTProtocolName(t *testing.T) { + net, err := mocknet.FullMeshLinked(1) + require.NoError(t, err) + testHost := net.Hosts()[0] + + testCases := []struct { + name string + network *utils.Network + expected string + }{ + { + name: "sepolia network", + network: &utils.Sepolia, + expected: "/starknet/SN_SEPOLIA/sync/kad/1.0.0", + }, + { + name: "mainnet network", + network: &utils.Mainnet, + expected: "/starknet/SN_MAIN/sync/kad/1.0.0", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dht, err := p2p.MakeDHT(testHost, nil, tc.network) + require.NoError(t, err) + + protocols := dht.Host().Mux().Protocols() + assert.Contains(t, protocols, protocol.ID(tc.expected), "protocol list: %v", protocols) + }) + } +} diff --git a/p2p/starknet/client.go b/p2p/starknet/client.go index bfeed7ab7a..3e720e9597 100644 --- a/p2p/starknet/client.go +++ b/p2p/starknet/client.go @@ -104,22 +104,24 @@ func (c *Client) RequestBlockHeaders( ctx context.Context, req *spec.BlockHeadersRequest, ) (iter.Seq[*spec.BlockHeadersResponse], error) { return requestAndReceiveStream[*spec.BlockHeadersRequest, *spec.BlockHeadersResponse]( - ctx, c.newStream, HeadersPID(), req, c.log) + ctx, c.newStream, HeadersPID(c.network), req, c.log) } func (c *Client) RequestEvents(ctx context.Context, req *spec.EventsRequest) (iter.Seq[*spec.EventsResponse], error) { - return requestAndReceiveStream[*spec.EventsRequest, *spec.EventsResponse](ctx, c.newStream, EventsPID(), req, c.log) + return requestAndReceiveStream[*spec.EventsRequest, *spec.EventsResponse](ctx, c.newStream, EventsPID(c.network), req, c.log) } func (c *Client) RequestClasses(ctx context.Context, req *spec.ClassesRequest) (iter.Seq[*spec.ClassesResponse], error) { - return requestAndReceiveStream[*spec.ClassesRequest, *spec.ClassesResponse](ctx, c.newStream, ClassesPID(), req, c.log) + return requestAndReceiveStream[*spec.ClassesRequest, *spec.ClassesResponse](ctx, c.newStream, ClassesPID(c.network), req, c.log) } func (c *Client) RequestStateDiffs(ctx context.Context, req *spec.StateDiffsRequest) (iter.Seq[*spec.StateDiffsResponse], error) { - return requestAndReceiveStream[*spec.StateDiffsRequest, *spec.StateDiffsResponse](ctx, c.newStream, StateDiffPID(), req, c.log) + return requestAndReceiveStream[*spec.StateDiffsRequest, *spec.StateDiffsResponse]( + ctx, c.newStream, StateDiffPID(c.network), req, c.log, + ) } func (c *Client) RequestTransactions(ctx context.Context, req *spec.TransactionsRequest) (iter.Seq[*spec.TransactionsResponse], error) { return requestAndReceiveStream[*spec.TransactionsRequest, *spec.TransactionsResponse]( - ctx, c.newStream, TransactionsPID(), req, c.log) + ctx, c.newStream, TransactionsPID(c.network), req, c.log) } diff --git a/p2p/starknet/ids.go b/p2p/starknet/ids.go index d1b97b0ad2..1924c1705a 100644 --- a/p2p/starknet/ids.go +++ b/p2p/starknet/ids.go @@ -1,27 +1,32 @@ package starknet import ( + "github.com/NethermindEth/juno/utils" "github.com/libp2p/go-libp2p/core/protocol" ) const Prefix = "/starknet" -func HeadersPID() protocol.ID { - return Prefix + "/headers/0.1.0-rc.0" +func HeadersPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/headers/0.1.0-rc.0") } -func EventsPID() protocol.ID { - return Prefix + "/events/0.1.0-rc.0" +func EventsPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/events/0.1.0-rc.0") } -func TransactionsPID() protocol.ID { - return Prefix + "/transactions/0.1.0-rc.0" +func TransactionsPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/transactions/0.1.0-rc.0") } -func ClassesPID() protocol.ID { - return Prefix + "/classes/0.1.0-rc.0" +func ClassesPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/classes/0.1.0-rc.0") } -func StateDiffPID() protocol.ID { - return Prefix + "/state_diffs/0.1.0-rc.0" +func StateDiffPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/state_diffs/0.1.0-rc.0") +} + +func DHTPrefixPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync") } diff --git a/p2p/starknet/ids_test.go b/p2p/starknet/ids_test.go new file mode 100644 index 0000000000..c44e2cb276 --- /dev/null +++ b/p2p/starknet/ids_test.go @@ -0,0 +1,67 @@ +package starknet + +import ( + "testing" + + "github.com/NethermindEth/juno/utils" + "github.com/stretchr/testify/assert" +) + +func TestProtocolIDs(t *testing.T) { + testCases := []struct { + name string + network *utils.Network + pidFunc func(*utils.Network) string + expected string + }{ + { + name: "HeadersPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(HeadersPID(n)) }, + expected: "/starknet/SN_MAIN/sync/headers/0.1.0-rc.0", + }, + { + name: "EventsPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(EventsPID(n)) }, + expected: "/starknet/SN_MAIN/sync/events/0.1.0-rc.0", + }, + { + name: "TransactionsPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(TransactionsPID(n)) }, + expected: "/starknet/SN_MAIN/sync/transactions/0.1.0-rc.0", + }, + { + name: "ClassesPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(ClassesPID(n)) }, + expected: "/starknet/SN_MAIN/sync/classes/0.1.0-rc.0", + }, + { + name: "StateDiffPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(StateDiffPID(n)) }, + expected: "/starknet/SN_MAIN/sync/state_diffs/0.1.0-rc.0", + }, + { + name: "DHTPrefixPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(DHTPrefixPID(n)) }, + expected: "/starknet/SN_MAIN/sync", + }, + { + name: "HeadersPID with SN_SEPOLIA", + network: &utils.Sepolia, + pidFunc: func(n *utils.Network) string { return string(HeadersPID(n)) }, + expected: "/starknet/SN_SEPOLIA/sync/headers/0.1.0-rc.0", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := tc.pidFunc(tc.network) + assert.Equal(t, tc.expected, result) + }) + } +}