diff --git a/README.md b/README.md index 8dc2e03e..b44273ae 100644 --- a/README.md +++ b/README.md @@ -116,6 +116,7 @@ go run main.go | `starknet_addDeployAccountTransaction` | :heavy_check_mark: | | `starknet_traceTransaction` | :heavy_check_mark: | | `starknet_simulateTransaction` | :heavy_check_mark: | +| `starknet_specVersion` | :heavy_check_mark: | | `starknet_traceBlockTransactions` | :heavy_check_mark: | ### Run Tests diff --git a/account/account.go b/account/account.go index 2b87297e..592caec6 100644 --- a/account/account.go +++ b/account/account.go @@ -389,6 +389,9 @@ func (account *Account) StorageAt(ctx context.Context, contractAddress *felt.Fel func (account *Account) StateUpdate(ctx context.Context, blockID rpc.BlockID) (*rpc.StateUpdateOutput, error) { return account.provider.StateUpdate(ctx, blockID) } +func (account *Account) SpecVersion(ctx context.Context) (string, error) { + return account.provider.SpecVersion(ctx) +} func (account *Account) Syncing(ctx context.Context) (*rpc.SyncStatus, error) { return account.provider.Syncing(ctx) } diff --git a/mocks/mock_rpc_provider.go b/mocks/mock_rpc_provider.go index 7f793039..7da07caf 100644 --- a/mocks/mock_rpc_provider.go +++ b/mocks/mock_rpc_provider.go @@ -321,6 +321,21 @@ func (mr *MockRpcProviderMockRecorder) SimulateTransactions(ctx, blockID, txns, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SimulateTransactions", reflect.TypeOf((*MockRpcProvider)(nil).SimulateTransactions), ctx, blockID, txns, simulationFlags) } +// SpecVersion mocks base method. +func (m *MockRpcProvider) SpecVersion(ctx context.Context) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SpecVersion", ctx) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SpecVersion indicates an expected call of SpecVersion. +func (mr *MockRpcProviderMockRecorder) SpecVersion(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SpecVersion", reflect.TypeOf((*MockRpcProvider)(nil).SpecVersion), ctx) +} + // StateUpdate mocks base method. func (m *MockRpcProvider) StateUpdate(ctx context.Context, blockID rpc.BlockID) (*rpc.StateUpdateOutput, error) { m.ctrl.T.Helper() diff --git a/rpc/block.go b/rpc/block.go index 7625d357..921b8d45 100644 --- a/rpc/block.go +++ b/rpc/block.go @@ -22,11 +22,8 @@ func (provider *Provider) BlockNumber(ctx context.Context) (uint64, error) { // BlockHashAndNumber gets block information given the block number or its hash. func (provider *Provider) BlockHashAndNumber(ctx context.Context) (*BlockHashAndNumberOutput, error) { var block BlockHashAndNumberOutput - if err := do(ctx, provider.c, "starknet_blockHashAndNumber", &block); err != nil { - if errors.Is(err, errNotFound) { - return nil, ErrNoBlocks - } - return nil, err + if err := do(ctx, provider.c, "starknet_blockHashAndNumber", &block); err != nil { + return nil, tryUnwrapToRPCErr(err, ErrNoBlocks ) } return &block, nil } @@ -52,11 +49,8 @@ func WithBlockTag(tag string) BlockID { // BlockWithTxHashes gets block information given the block id. func (provider *Provider) BlockWithTxHashes(ctx context.Context, blockID BlockID) (interface{}, error) { var result BlockTxHashes - if err := do(ctx, provider.c, "starknet_getBlockWithTxHashes", &result, blockID); err != nil { - if errors.Is(err, errNotFound) { - return nil, ErrBlockNotFound - } - return nil, err + if err := do(ctx, provider.c, "starknet_getBlockWithTxHashes", &result, blockID); err != nil { + return nil, tryUnwrapToRPCErr(err,ErrBlockNotFound ) } // if header.Hash == nil it's a pending block @@ -75,11 +69,8 @@ func (provider *Provider) BlockWithTxHashes(ctx context.Context, blockID BlockID // StateUpdate gets the information about the result of executing the requested block. func (provider *Provider) StateUpdate(ctx context.Context, blockID BlockID) (*StateUpdateOutput, error) { var state StateUpdateOutput - if err := do(ctx, provider.c, "starknet_getStateUpdate", &state, blockID); err != nil { - if errors.Is(err, errNotFound) { - return nil, ErrBlockNotFound - } - return nil, err + if err := do(ctx, provider.c, "starknet_getStateUpdate", &state, blockID); err != nil { + return nil,tryUnwrapToRPCErr(err,ErrBlockNotFound ) } return &state, nil } @@ -100,10 +91,7 @@ func (provider *Provider) BlockTransactionCount(ctx context.Context, blockID Blo func (provider *Provider) BlockWithTxs(ctx context.Context, blockID BlockID) (interface{}, error) { var result Block if err := do(ctx, provider.c, "starknet_getBlockWithTxs", &result, blockID); err != nil { - if errors.Is(err, errNotFound) { - return nil, ErrBlockNotFound - } - return nil, err + return nil, tryUnwrapToRPCErr(err,ErrBlockNotFound ) } // if header.Hash == nil it's a pending block if result.BlockHeader.BlockHash == nil { diff --git a/rpc/call.go b/rpc/call.go index d499df3a..44b52eae 100644 --- a/rpc/call.go +++ b/rpc/call.go @@ -2,7 +2,7 @@ package rpc import ( "context" - "errors" + "github.com/NethermindEth/juno/core/felt" ) @@ -15,15 +15,8 @@ func (provider *Provider) Call(ctx context.Context, request FunctionCall, blockI } var result []*felt.Felt if err := do(ctx, provider.c, "starknet_call", &result, request, blockID); err != nil { - switch { - case errors.Is(err, ErrContractNotFound): - return nil, ErrContractNotFound - case errors.Is(err, ErrContractError): - return nil, ErrContractError - case errors.Is(err, ErrBlockNotFound): - return nil, ErrBlockNotFound - } - return nil, err + + return nil, tryUnwrapToRPCErr(err, ErrContractNotFound, ErrContractError, ErrBlockNotFound) } return result, nil } diff --git a/rpc/contract.go b/rpc/contract.go index 940c550a..c56868f1 100644 --- a/rpc/contract.go +++ b/rpc/contract.go @@ -3,7 +3,7 @@ package rpc import ( "context" "encoding/json" - "errors" + "fmt" "github.com/NethermindEth/juno/core/felt" @@ -14,13 +14,8 @@ import ( func (provider *Provider) Class(ctx context.Context, blockID BlockID, classHash *felt.Felt) (ClassOutput, error) { var rawClass map[string]any if err := do(ctx, provider.c, "starknet_getClass", &rawClass, blockID, classHash); err != nil { - switch { - case errors.Is(err, ErrClassHashNotFound): - return nil, ErrClassHashNotFound - case errors.Is(err, ErrBlockNotFound): - return nil, ErrBlockNotFound - } - return nil, err + + return nil, tryUnwrapToRPCErr(err, ErrClassHashNotFound, ErrBlockNotFound) } return typecastClassOutput(&rawClass) @@ -31,13 +26,8 @@ func (provider *Provider) Class(ctx context.Context, blockID BlockID, classHash func (provider *Provider) ClassAt(ctx context.Context, blockID BlockID, contractAddress *felt.Felt) (ClassOutput, error) { var rawClass map[string]any if err := do(ctx, provider.c, "starknet_getClassAt", &rawClass, blockID, contractAddress); err != nil { - switch { - case errors.Is(err, ErrContractNotFound): - return nil, ErrContractNotFound - case errors.Is(err, ErrBlockNotFound): - return nil, ErrBlockNotFound - } - return nil, err + + return nil, tryUnwrapToRPCErr(err, ErrContractNotFound, ErrBlockNotFound) } return typecastClassOutput(&rawClass) } @@ -69,13 +59,8 @@ func typecastClassOutput(rawClass *map[string]any) (ClassOutput, error) { func (provider *Provider) ClassHashAt(ctx context.Context, blockID BlockID, contractAddress *felt.Felt) (*felt.Felt, error) { var result *felt.Felt if err := do(ctx, provider.c, "starknet_getClassHashAt", &result, blockID, contractAddress); err != nil { - switch { - case errors.Is(err, ErrContractNotFound): - return nil, ErrContractNotFound - case errors.Is(err, ErrBlockNotFound): - return nil, ErrBlockNotFound - } - return nil, err + + return nil, tryUnwrapToRPCErr(err, ErrContractNotFound, ErrBlockNotFound) } return result, nil } @@ -85,13 +70,8 @@ func (provider *Provider) StorageAt(ctx context.Context, contractAddress *felt.F var value string hashKey := fmt.Sprintf("0x%x", utils.GetSelectorFromName(key)) if err := do(ctx, provider.c, "starknet_getStorageAt", &value, contractAddress, hashKey, blockID); err != nil { - switch { - case errors.Is(err, ErrContractNotFound): - return "", ErrContractNotFound - case errors.Is(err, ErrBlockNotFound): - return "", ErrBlockNotFound - } - return "", err + + return "", tryUnwrapToRPCErr(err, ErrContractNotFound, ErrBlockNotFound) } return value, nil } @@ -100,13 +80,8 @@ func (provider *Provider) StorageAt(ctx context.Context, contractAddress *felt.F func (provider *Provider) Nonce(ctx context.Context, blockID BlockID, contractAddress *felt.Felt) (*string, error) { nonce := "" if err := do(ctx, provider.c, "starknet_getNonce", &nonce, blockID, contractAddress); err != nil { - switch { - case errors.Is(err, ErrContractNotFound): - return nil, ErrContractNotFound - case errors.Is(err, ErrBlockNotFound): - return nil, ErrBlockNotFound - } - return nil, err + + return nil, tryUnwrapToRPCErr(err, ErrContractNotFound, ErrBlockNotFound) } return &nonce, nil } @@ -115,15 +90,8 @@ func (provider *Provider) Nonce(ctx context.Context, blockID BlockID, contractAd func (provider *Provider) EstimateFee(ctx context.Context, requests []EstimateFeeInput, blockID BlockID) ([]FeeEstimate, error) { var raw []FeeEstimate if err := do(ctx, provider.c, "starknet_estimateFee", &raw, requests, blockID); err != nil { - switch { - case errors.Is(err, ErrContractNotFound): - return nil, ErrContractNotFound - case errors.Is(err, ErrContractError): - return nil, ErrContractError - case errors.Is(err, ErrBlockNotFound): - return nil, ErrBlockNotFound - } - return nil, err + + return nil, tryUnwrapToRPCErr(err, ErrContractNotFound,ErrContractError, ErrBlockNotFound) } return raw, nil } @@ -132,15 +100,8 @@ func (provider *Provider) EstimateFee(ctx context.Context, requests []EstimateFe func (provider *Provider) EstimateMessageFee(ctx context.Context, msg MsgFromL1, blockID BlockID) (*FeeEstimate, error) { var raw FeeEstimate if err := do(ctx, provider.c, "starknet_estimateMessageFee", &raw, msg, blockID); err != nil { - switch { - case errors.Is(err, ErrContractNotFound): - return nil, ErrContractNotFound - case errors.Is(err, ErrContractError): - return nil, ErrContractError - case errors.Is(err, ErrBlockNotFound): - return nil, ErrBlockNotFound - } - return nil, err + + return nil, tryUnwrapToRPCErr(err, ErrContractNotFound,ErrContractError, ErrBlockNotFound) } return &raw, nil } diff --git a/rpc/events.go b/rpc/events.go index a50860b8..d8d63651 100644 --- a/rpc/events.go +++ b/rpc/events.go @@ -2,24 +2,15 @@ package rpc import ( "context" - "errors" + ) // Events returns all events matching the given filter func (provider *Provider) Events(ctx context.Context, input EventsInput) (*EventChunk, error) { var result EventChunk if err := do(ctx, provider.c, "starknet_getEvents", &result, input); err != nil { - switch { - case errors.Is(err, ErrPageSizeTooBig): - return nil, ErrPageSizeTooBig - case errors.Is(err, ErrInvalidContinuationToken): - return nil, ErrInvalidContinuationToken - case errors.Is(err, ErrBlockNotFound): - return nil, ErrBlockNotFound - case errors.Is(err, ErrTooManyKeysInFilter): - return nil, ErrTooManyKeysInFilter - } - return nil, err + + return nil, tryUnwrapToRPCErr(err, ErrPageSizeTooBig , ErrInvalidContinuationToken , ErrBlockNotFound ,ErrTooManyKeysInFilter) } return &result, nil } diff --git a/rpc/provider.go b/rpc/provider.go index 7ed0516c..119a8e4d 100644 --- a/rpc/provider.go +++ b/rpc/provider.go @@ -47,6 +47,7 @@ type RpcProvider interface { SimulateTransactions(ctx context.Context, blockID BlockID, txns []Transaction, simulationFlags []SimulationFlag) ([]SimulatedTransaction, error) StateUpdate(ctx context.Context, blockID BlockID) (*StateUpdateOutput, error) StorageAt(ctx context.Context, contractAddress *felt.Felt, key string, blockID BlockID) (string, error) + SpecVersion(ctx context.Context) (string, error) Syncing(ctx context.Context) (*SyncStatus, error) TraceBlockTransactions(ctx context.Context, blockHash *felt.Felt) ([]Trace, error) TransactionByBlockIdAndIndex(ctx context.Context, blockID BlockID, index uint64) (Transaction, error) diff --git a/rpc/transaction.go b/rpc/transaction.go index 6eb93406..0e09f9a4 100644 --- a/rpc/transaction.go +++ b/rpc/transaction.go @@ -62,11 +62,8 @@ func (provider *Provider) TransactionByHash(ctx context.Context, hash *felt.Felt // todo: update to return a custom Transaction type, then use adapt function var tx TXN if err := do(ctx, provider.c, "starknet_getTransactionByHash", &tx, hash); err != nil { - if errors.Is(err, ErrHashNotFound) { - return nil, ErrHashNotFound - } - return nil, err - } + return nil, tryUnwrapToRPCErr(err,ErrHashNotFound) +} return adaptTransaction(tx) } @@ -74,13 +71,9 @@ func (provider *Provider) TransactionByHash(ctx context.Context, hash *felt.Felt func (provider *Provider) TransactionByBlockIdAndIndex(ctx context.Context, blockID BlockID, index uint64) (Transaction, error) { var tx TXN if err := do(ctx, provider.c, "starknet_getTransactionByBlockIdAndIndex", &tx, blockID, index); err != nil { - switch { - case errors.Is(err, ErrInvalidTxnIndex): - return nil, ErrInvalidTxnIndex - case errors.Is(err, ErrBlockNotFound): - return nil, ErrBlockNotFound - } - return nil, err + + return nil,tryUnwrapToRPCErr(err, ErrInvalidTxnIndex ,ErrBlockNotFound) + } return adaptTransaction(tx) } @@ -90,10 +83,7 @@ func (provider *Provider) TransactionReceipt(ctx context.Context, transactionHas var receipt UnknownTransactionReceipt err := do(ctx, provider.c, "starknet_getTransactionReceipt", &receipt, transactionHash) if err != nil { - if errors.Is(err, ErrHashNotFound) { - return nil, ErrHashNotFound - } - return nil, err + return nil, tryUnwrapToRPCErr(err,ErrHashNotFound) } return receipt.TransactionReceipt, nil } diff --git a/rpc/version.go b/rpc/version.go new file mode 100644 index 00000000..e46f4c23 --- /dev/null +++ b/rpc/version.go @@ -0,0 +1,12 @@ +package rpc + +import "context" + +// SpecVersion returns the version of the Starknet JSON-RPC specification being used +// Parameters: None +// Returns: String of the Starknet JSON-RPC specification +func (provider *Provider) SpecVersion(ctx context.Context) (string, error) { + var result string + err := do(ctx, provider.c, "starknet_specVersion", &result) + return result, err +} diff --git a/rpc/version_test.go b/rpc/version_test.go new file mode 100644 index 00000000..57a53fcf --- /dev/null +++ b/rpc/version_test.go @@ -0,0 +1,32 @@ +package rpc + +import ( + "context" + "testing" + + "github.com/test-go/testify/require" +) + +// TestSpecVersion tests starknet_specVersion +func TestSpecVersion(t *testing.T) { + + testConfig := beforeEach(t) + + type testSetType struct { + ExpectedResp string + } + testSet := map[string][]testSetType{ + "devnet": {}, + "mainnet": {}, + "mock": {}, + "testnet": {{ + ExpectedResp: "0.5.0", + }}, + }[testEnv] + + for _, test := range testSet { + resp, err := testConfig.provider.SpecVersion(context.Background()) + require.NoError(t, err) + require.Equal(t, test.ExpectedResp, resp) + } +} diff --git a/typed/typed.go b/typed/typed.go index 1bddc958..23f09647 100644 --- a/typed/typed.go +++ b/typed/typed.go @@ -105,11 +105,11 @@ func NewTypedData(types map[string]TypeDef, pType string, dom Domain) (td TypedD return td, nil } -// (ref: https://github.com/0xs34n/starknet.js/blob/767021a203ac0b9cdb282eb6d63b33bfd7614858/src/utils/typedData/index.ts#L166) +// (ref: https://github.com/starknet-io/starknet.js/blob/d7bfc37ede85448e0a55ee4efe65200ff2c45f91/src/utils/typedData.ts#L249) func (td TypedData) GetMessageHash(account *big.Int, msg TypedMessage, sc curve.StarkCurve) (hash *big.Int, err error) { - elements := []*big.Int{utils.UTF8StrToBig("Starknet Message")} + elements := []*big.Int{utils.UTF8StrToBig("StarkNet Message")} - domEnc, err := td.GetTypedMessageHash("StarknetDomain", td.Domain, sc) + domEnc, err := td.GetTypedMessageHash("StarkNetDomain", td.Domain, sc) if err != nil { return hash, fmt.Errorf("could not hash domain: %w", err) } diff --git a/typed/typed_test.go b/typed/typed_test.go index 9065fbbb..c75887ae 100644 --- a/typed/typed_test.go +++ b/typed/typed_test.go @@ -36,14 +36,14 @@ func (mail Mail) FmtDefinitionEncoding(field string) (fmtEnc []*big.Int) { func MockTypedData() (ttd TypedData) { exampleTypes := make(map[string]TypeDef) domDefs := []Definition{{"name", "felt"}, {"version", "felt"}, {"chainId", "felt"}} - exampleTypes["StarknetDomain"] = TypeDef{Definitions: domDefs} + exampleTypes["StarkNetDomain"] = TypeDef{Definitions: domDefs} mailDefs := []Definition{{"from", "Person"}, {"to", "Person"}, {"contents", "felt"}} exampleTypes["Mail"] = TypeDef{Definitions: mailDefs} persDefs := []Definition{{"name", "felt"}, {"wallet", "felt"}} exampleTypes["Person"] = TypeDef{Definitions: persDefs} dm := Domain{ - Name: "Starknet Mail", + Name: "StarkNet Mail", Version: "1", ChainId: "1", } @@ -101,7 +101,7 @@ func BenchmarkGetMessageHash(b *testing.B) { func TestGeneral_GetDomainHash(t *testing.T) { ttd := MockTypedData() - hash, err := ttd.GetTypedMessageHash("StarknetDomain", ttd.Domain, curve.Curve) + hash, err := ttd.GetTypedMessageHash("StarkNetDomain", ttd.Domain, curve.Curve) if err != nil { t.Errorf("Could not hash message: %v\n", err) } @@ -142,7 +142,7 @@ func TestGeneral_GetTypedMessageHash(t *testing.T) { func TestGeneral_GetTypeHash(t *testing.T) { tdd := MockTypedData() - hash, err := tdd.GetTypeHash("StarknetDomain") + hash, err := tdd.GetTypeHash("StarkNetDomain") if err != nil { t.Errorf("error enccoding type %v\n", err) } @@ -152,7 +152,7 @@ func TestGeneral_GetTypeHash(t *testing.T) { t.Errorf("type hash: %v does not match expected %v\n", utils.BigToHex(hash), exp) } - enc := tdd.Types["StarknetDomain"] + enc := tdd.Types["StarkNetDomain"] if utils.BigToHex(enc.Encoding) != exp { t.Errorf("type hash: %v does not match expected %v\n", utils.BigToHex(hash), exp) }