Skip to content

Commit

Permalink
Fix FmtCallData (#498)
Browse files Browse the repository at this point in the history
* Fix FmtCallData

* simplify fmtCallData usage
  • Loading branch information
rianhughes authored Dec 18, 2023
1 parent a2f9a0e commit 7d4e5b4
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 85 deletions.
90 changes: 47 additions & 43 deletions account/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type Account struct {
ChainId *felt.Felt
AccountAddress *felt.Felt
publicKey string
CairoVersion int
ks Keystore
}

Expand All @@ -61,12 +62,13 @@ type Account struct {
// It returns:
// - *Account: a pointer to newly created Account
// - error: an error if any
func NewAccount(provider rpc.RpcProvider, accountAddress *felt.Felt, publicKey string, keystore Keystore) (*Account, error) {
func NewAccount(provider rpc.RpcProvider, accountAddress *felt.Felt, publicKey string, keystore Keystore, cairoVersion int) (*Account, error) {
account := &Account{
provider: provider,
AccountAddress: accountAddress,
publicKey: publicKey,
ks: keystore,
CairoVersion: cairoVersion,
}

chainID, err := provider.ChainID(context.Background())
Expand Down Expand Up @@ -892,70 +894,72 @@ func (account *Account) GetTransactionStatus(ctx context.Context, Txnhash *felt.
// Returns:
// - a slice of *felt.Felt representing the formatted calldata.
// - an error if Cairo version is not supported.
func (account *Account) FmtCalldata(fnCalls []rpc.FunctionCall, cairoVersion int) ([]*felt.Felt, error) {
switch cairoVersion {
func (account *Account) FmtCalldata(fnCalls []rpc.FunctionCall) ([]*felt.Felt, error) {
switch account.CairoVersion {
case 0:
return FmtCalldataCairo0(fnCalls), nil
return FmtCallDataCairo0(fnCalls), nil
case 2:
return FmtCalldataCairo2(fnCalls), nil
return FmtCallDataCairo2(fnCalls), nil
default:
return nil, errors.New("Cairo version not supported")
}
}

// FmtCalldataCairo0 generates a slice of *felt.Felt that represents the calldata for the given function calls in Cairo 0 format.
// FmtCallDataCairo0 generates a slice of *felt.Felt that represents the calldata for the given function calls in Cairo 0 format.
//
// Parameters:
// - fnCalls: a slice of rpc.FunctionCall containing the function calls.
//
// Returns:
// - a slice of *felt.Felt representing the generated calldata.
func FmtCalldataCairo0(fnCalls []rpc.FunctionCall) []*felt.Felt {
execCallData := []*felt.Felt{}
execCallData = append(execCallData, new(felt.Felt).SetUint64(uint64(len(fnCalls))))

// Legacy : Cairo 0
concatCallData := []*felt.Felt{}
for _, fnCall := range fnCalls {
execCallData = append(
execCallData,
fnCall.ContractAddress,
fnCall.EntryPointSelector,
new(felt.Felt).SetUint64(uint64(len(concatCallData))),
new(felt.Felt).SetUint64(uint64(len(fnCall.Calldata))+1),
)
concatCallData = append(concatCallData, fnCall.Calldata...)
// https://github.com/project3fusion/StarkSharp/blob/main/StarkSharp/StarkSharp.Rpc/Modules/Transactions/Hash/TransactionHash.cs#L27
func FmtCallDataCairo0(callArray []rpc.FunctionCall) []*felt.Felt {
var calldata []*felt.Felt
var calls []*felt.Felt

calldata = append(calldata, new(felt.Felt).SetUint64(uint64(len(callArray))))

offset := uint64(0)
for _, call := range callArray {
calldata = append(calldata, call.ContractAddress)
calldata = append(calldata, call.EntryPointSelector)
calldata = append(calldata, new(felt.Felt).SetUint64(uint64(offset)))
callDataLen := uint64(len(call.Calldata))
calldata = append(calldata, new(felt.Felt).SetUint64(callDataLen))
offset += callDataLen

for _, data := range call.Calldata {
calls = append(calls, data)
}
}
execCallData = append(execCallData, new(felt.Felt).SetUint64(uint64(len(concatCallData))+1))
execCallData = append(execCallData, concatCallData...)
execCallData = append(execCallData, new(felt.Felt).SetUint64(0))

return execCallData
calldata = append(calldata, new(felt.Felt).SetUint64(offset))
calldata = append(calldata, calls...)

return calldata
}

// FmtCalldataCairo2 generates the calldata for the given function calls for Cairo 2 contracs.
// FmtCallDataCairo2 generates the calldata for the given function calls for Cairo 2 contracs.
//
// Parameters:
// - fnCalls: a slice of rpc.FunctionCall containing the function calls.
// Returns:
// - a slice of *felt.Felt representing the generated calldata.
func FmtCalldataCairo2(fnCalls []rpc.FunctionCall) []*felt.Felt {
execCallData := []*felt.Felt{}
execCallData = append(execCallData, new(felt.Felt).SetUint64(uint64(len(fnCalls))))

concatCallData := []*felt.Felt{}
for _, fnCall := range fnCalls {
execCallData = append(
execCallData,
fnCall.ContractAddress,
fnCall.EntryPointSelector,
new(felt.Felt).SetUint64(uint64(len(concatCallData))),
new(felt.Felt).SetUint64(uint64(len(fnCall.Calldata))),
)
concatCallData = append(concatCallData, fnCall.Calldata...)
// https://github.com/project3fusion/StarkSharp/blob/main/StarkSharp/StarkSharp.Rpc/Modules/Transactions/Hash/TransactionHash.cs#L22
func FmtCallDataCairo2(callArray []rpc.FunctionCall) []*felt.Felt {
var result []*felt.Felt

result = append(result, new(felt.Felt).SetUint64(uint64(len(callArray))))

for _, call := range callArray {
result = append(result, call.ContractAddress)
result = append(result, call.EntryPointSelector)

callDataLen := uint64(len(call.Calldata))
result = append(result, new(felt.Felt).SetUint64(callDataLen))

result = append(result, call.Calldata...)
}
execCallData = append(execCallData, new(felt.Felt).SetUint64(uint64(len(concatCallData))))
execCallData = append(execCallData, concatCallData...)

return execCallData
return result
}
94 changes: 69 additions & 25 deletions account/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func TestTransactionHashInvoke(t *testing.T) {
}

mockRpcProvider.EXPECT().ChainID(context.Background()).Return(test.ChainID, nil)
account, err := account.NewAccount(mockRpcProvider, test.AccountAddress, test.PubKey, ks)
account, err := account.NewAccount(mockRpcProvider, test.AccountAddress, test.PubKey, ks, 0)
require.NoError(t, err, "error returned from account.NewAccount()")
invokeTxn := rpc.InvokeTxnV1{
Calldata: test.FnCall.Calldata,
Expand Down Expand Up @@ -211,21 +211,65 @@ func TestFmtCallData(t *testing.T) {
CairoVersion: 0,
ChainID: "SN_GOERLI",
FnCall: rpc.FunctionCall{
ContractAddress: utils.TestHexToFelt(t, "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7"),
EntryPointSelector: utils.GetSelectorFromNameFelt("transfer"),
Calldata: utils.TestHexArrToFelt(t, []string{
"0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7",
"0x1"}),
ContractAddress: utils.TestHexToFelt(t, "0x05f7cd1fd465baff2ba9d2d1501ad0a2eb5337d9a885be319366b5205a414fdd"),
EntryPointSelector: utils.GetSelectorFromNameFelt("increase_balance"),
Calldata: []*felt.Felt{new(felt.Felt).SetUint64(2), new(felt.Felt).SetUint64(2)},
},
ExpectedCallData: utils.TestHexArrToFelt(t, []string{
"0x1",
"0x05f7cd1fd465baff2ba9d2d1501ad0a2eb5337d9a885be319366b5205a414fdd",
"0x0362398bec32bc0ebb411203221a35a0301193a96f317ebe5e40be9f60d15320",
"0x0",
"0x2",
"0x2",
"0x2",
"0x2",
}),
},
{
CairoVersion: 0,
ChainID: "SN_GOERLI",
FnCall: rpc.FunctionCall{
ContractAddress: utils.TestHexToFelt(t, "0x4c1337d55351eac9a0b74f3b8f0d3928e2bb781e5084686a892e66d49d510d"),
EntryPointSelector: utils.GetSelectorFromNameFelt("increase_value"),
Calldata: []*felt.Felt{},
},
ExpectedCallData: utils.TestHexArrToFelt(t, []string{
"0x1",
"0x4c1337d55351eac9a0b74f3b8f0d3928e2bb781e5084686a892e66d49d510d",
"0x034c4c150632e67baf44fc50e9a685184d72a822510a26a66f72058b5e7b2892",
"0x0",
"0x0",
"0x0",
}),
},
{
CairoVersion: 2,
ChainID: "SN_GOERLI",
FnCall: rpc.FunctionCall{
ContractAddress: utils.TestHexToFelt(t, "0x4c1337d55351eac9a0b74f3b8f0d3928e2bb781e5084686a892e66d49d510d"),
EntryPointSelector: utils.GetSelectorFromNameFelt("increase_value"),
Calldata: []*felt.Felt{},
},
ExpectedCallData: utils.TestHexArrToFelt(t, []string{
"0x1",
"0x49d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7",
"0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e",
"0x4c1337d55351eac9a0b74f3b8f0d3928e2bb781e5084686a892e66d49d510d",
"0x034c4c150632e67baf44fc50e9a685184d72a822510a26a66f72058b5e7b2892",
"0x0",
"0x3",
"0x3",
"0x49d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7",
}),
},
{
CairoVersion: 2,
ChainID: "SN_GOERLI",
FnCall: rpc.FunctionCall{
ContractAddress: utils.TestHexToFelt(t, "0x4c1337d55351eac9a0b74f3b8f0d3928e2bb781e5084686a892e66d49d510d"),
EntryPointSelector: utils.GetSelectorFromNameFelt("increase_value"),
Calldata: []*felt.Felt{},
},
ExpectedCallData: utils.TestHexArrToFelt(t, []string{
"0x1",
"0x4c1337d55351eac9a0b74f3b8f0d3928e2bb781e5084686a892e66d49d510d",
"0x034c4c150632e67baf44fc50e9a685184d72a822510a26a66f72058b5e7b2892",
"0x0",
}),
},
Expand All @@ -236,10 +280,10 @@ func TestFmtCallData(t *testing.T) {

for _, test := range testSet {
mockRpcProvider.EXPECT().ChainID(context.Background()).Return(test.ChainID, nil)
acnt, err := account.NewAccount(mockRpcProvider, &felt.Zero, "pubkey", account.NewMemKeystore())
acnt, err := account.NewAccount(mockRpcProvider, &felt.Zero, "pubkey", account.NewMemKeystore(), test.CairoVersion)
require.NoError(t, err)

fmtCallData, err := acnt.FmtCalldata([]rpc.FunctionCall{test.FnCall}, test.CairoVersion)
fmtCallData, err := acnt.FmtCalldata([]rpc.FunctionCall{test.FnCall})
require.NoError(t, err)
require.Equal(t, fmtCallData, test.ExpectedCallData)
}
Expand Down Expand Up @@ -286,7 +330,7 @@ func TestChainIdMOCK(t *testing.T) {

for _, test := range testSet {
mockRpcProvider.EXPECT().ChainID(context.Background()).Return(test.ChainID, nil)
account, err := account.NewAccount(mockRpcProvider, &felt.Zero, "pubkey", account.NewMemKeystore())
account, err := account.NewAccount(mockRpcProvider, &felt.Zero, "pubkey", account.NewMemKeystore(), 0)
require.NoError(t, err)
require.Equal(t, account.ChainId.String(), test.ExpectedID)
}
Expand Down Expand Up @@ -328,7 +372,7 @@ func TestChainId(t *testing.T) {
require.NoError(t, err, "Error in rpc.NewClient")
provider := rpc.NewProvider(client)

account, err := account.NewAccount(provider, &felt.Zero, "pubkey", account.NewMemKeystore())
account, err := account.NewAccount(provider, &felt.Zero, "pubkey", account.NewMemKeystore(), 0)
require.NoError(t, err)
require.Equal(t, account.ChainId.String(), test.ExpectedID)
}
Expand Down Expand Up @@ -390,7 +434,7 @@ func TestSignMOCK(t *testing.T) {
ks.Put(test.Address.String(), privKeyBI)

mockRpcProvider.EXPECT().ChainID(context.Background()).Return(test.ChainId, nil)
account, err := account.NewAccount(mockRpcProvider, test.Address, test.Address.String(), ks)
account, err := account.NewAccount(mockRpcProvider, test.Address, test.Address.String(), ks, 0)
require.NoError(t, err, "error returned from account.NewAccount()")

msg := utils.TestHexToFelt(t, "0x73cf79c4bfa0c7a41f473c07e1be5ac25faa7c2fdf9edcbd12c1438f40f13d8")
Expand Down Expand Up @@ -560,10 +604,10 @@ func TestAddInvoke(t *testing.T) {
ks.Put(test.PubKey.String(), fakePrivKeyBI)
}

acnt, err := account.NewAccount(provider, test.AccountAddress, test.PubKey.String(), ks)
acnt, err := account.NewAccount(provider, test.AccountAddress, test.PubKey.String(), ks, 0)
require.NoError(t, err)

test.InvokeTx.Calldata, err = acnt.FmtCalldata([]rpc.FunctionCall{test.FnCall}, test.CairoContractVersion)
test.InvokeTx.Calldata, err = acnt.FmtCalldata([]rpc.FunctionCall{test.FnCall})
require.NoError(t, err)

err = acnt.SignInvokeTransaction(context.Background(), &test.InvokeTx)
Expand Down Expand Up @@ -615,7 +659,7 @@ func TestAddDeployAccountDevnet(t *testing.T) {
require.True(t, ok)
ks.Put(fakeUser.PublicKey, fakePrivKeyBI)

acnt, err := account.NewAccount(provider, fakeUserAddr, fakeUser.PublicKey, ks)
acnt, err := account.NewAccount(provider, fakeUserAddr, fakeUser.PublicKey, ks, 0)
require.NoError(t, err)

classHash := utils.TestHexToFelt(t, "0x7b3e05f48f0c69e4a65ce5e076a66271a527aff2c34ce1083ec6e1526997a69") // preDeployed classhash
Expand Down Expand Up @@ -668,7 +712,7 @@ func TestTransactionHashDeclare(t *testing.T) {
mockRpcProvider := mocks.NewMockRpcProvider(mockCtrl)
mockRpcProvider.EXPECT().ChainID(context.Background()).Return("SN_GOERLI", nil)

acnt, err := account.NewAccount(mockRpcProvider, &felt.Zero, "", account.NewMemKeystore())
acnt, err := account.NewAccount(mockRpcProvider, &felt.Zero, "", account.NewMemKeystore(), 0)
require.NoError(t, err)

type testSetType struct {
Expand Down Expand Up @@ -738,7 +782,7 @@ func TestTransactionHashInvokeV3(t *testing.T) {
mockRpcProvider := mocks.NewMockRpcProvider(mockCtrl)
mockRpcProvider.EXPECT().ChainID(context.Background()).Return("SN_GOERLI", nil)

acnt, err := account.NewAccount(mockRpcProvider, &felt.Zero, "", account.NewMemKeystore())
acnt, err := account.NewAccount(mockRpcProvider, &felt.Zero, "", account.NewMemKeystore(), 0)
require.NoError(t, err)

type testSetType struct {
Expand Down Expand Up @@ -809,7 +853,7 @@ func TestTransactionHashdeployAccount(t *testing.T) {
mockRpcProvider := mocks.NewMockRpcProvider(mockCtrl)
mockRpcProvider.EXPECT().ChainID(context.Background()).Return("SN_GOERLI", nil)

acnt, err := account.NewAccount(mockRpcProvider, &felt.Zero, "", account.NewMemKeystore())
acnt, err := account.NewAccount(mockRpcProvider, &felt.Zero, "", account.NewMemKeystore(), 0)
require.NoError(t, err)

type testSetType struct {
Expand Down Expand Up @@ -900,7 +944,7 @@ func TestWaitForTransactionReceiptMOCK(t *testing.T) {
mockRpcProvider := mocks.NewMockRpcProvider(mockCtrl)

mockRpcProvider.EXPECT().ChainID(context.Background()).Return("SN_GOERLI", nil)
acnt, err := account.NewAccount(mockRpcProvider, &felt.Zero, "", account.NewMemKeystore())
acnt, err := account.NewAccount(mockRpcProvider, &felt.Zero, "", account.NewMemKeystore(), 0)
require.NoError(t, err, "error returned from account.NewAccount()")

type testSetType struct {
Expand Down Expand Up @@ -985,7 +1029,7 @@ func TestWaitForTransactionReceipt(t *testing.T) {
require.NoError(t, err, "Error in rpc.NewClient")
provider := rpc.NewProvider(client)

acnt, err := account.NewAccount(provider, &felt.Zero, "pubkey", account.NewMemKeystore())
acnt, err := account.NewAccount(provider, &felt.Zero, "pubkey", account.NewMemKeystore(), 0)
require.NoError(t, err, "error returned from account.NewAccount()")

type testSetType struct {
Expand Down Expand Up @@ -1051,7 +1095,7 @@ func TestAddDeclareTxn(t *testing.T) {
require.NoError(t, err, "Error in rpc.NewClient")
provider := rpc.NewProvider(client)

acnt, err := account.NewAccount(provider, AccountAddress, PubKey.String(), ks)
acnt, err := account.NewAccount(provider, AccountAddress, PubKey.String(), ks, 0)
require.NoError(t, err)

// Class Hash
Expand Down
16 changes: 10 additions & 6 deletions examples/deployAccount/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ import (
)

var (
network string = "testnet"
predeployedClassHash = "0x2794ce20e5f2ff0d40e632cb53845b9f4e526ebd8471983f7dbd355b721d5a"
accountAddress = "0xdeadbeef"
network string = "testnet"
predeployedClassHash = "0x2794ce20e5f2ff0d40e632cb53845b9f4e526ebd8471983f7dbd355b721d5a"
accountAddress = "0xdeadbeef"
accountContractVersion = 0 //Replace with the cairo version of your account contract
)

// main initializes the client, sets up the account, deploys a contract, and sends a transaction to the network.
Expand All @@ -27,9 +28,12 @@ var (
// and finally sends the transaction to the network.
//
// Parameters:
// none
//
// none
//
// Returns:
// none
//
// none
func main() {
// Initialise the client.
godotenv.Load(fmt.Sprintf(".env.%s", network))
Expand All @@ -49,7 +53,7 @@ func main() {
}

// Set up account
acnt, err := account.NewAccount(clientv02, accountAddressFelt, pub.String(), ks)
acnt, err := account.NewAccount(clientv02, accountAddressFelt, pub.String(), ks, accountContractVersion)
if err != nil {
panic(err)
}
Expand Down
Loading

0 comments on commit 7d4e5b4

Please sign in to comment.