diff --git a/typed/typed.go b/typed/typed.go index 478c0fca..d532639c 100644 --- a/typed/typed.go +++ b/typed/typed.go @@ -15,11 +15,39 @@ import ( ) type TypedData struct { - Types map[string]TypeDefinition `json:"types"` - PrimaryType string `json:"primaryType"` - Domain Domain `json:"domain"` - Message map[string]any `json:"message"` - Revision *revision `json:"-"` + types map[string]TypeDefinition + primaryType string + domain Domain + message map[string]any + revision *revision +} + +func (td *TypedData) Types() map[string]TypeDefinition { + copyMap := make(map[string]TypeDefinition, len(td.types)) + for k, v := range td.types { + copyMap[k] = v + } + return copyMap +} + +func (td *TypedData) PrimaryType() string { + return td.primaryType +} + +func (td *TypedData) Domain() Domain { + return td.domain +} + +func (td *TypedData) Message() map[string]any { + copyMap := make(map[string]any, len(td.message)) + for k, v := range td.message { + copyMap[k] = v + } + return copyMap +} + +func (td *TypedData) Revision() revision { + return *td.revision } type Domain struct { @@ -102,11 +130,11 @@ func NewTypedData(types []TypeDefinition, primaryType string, domain Domain, mes } td = &TypedData{ - Types: typesMap, - PrimaryType: primaryType, - Domain: domain, - Message: messageMap, - Revision: revision, + types: typesMap, + primaryType: primaryType, + domain: domain, + message: messageMap, + revision: revision, } return td, nil @@ -133,7 +161,7 @@ func (td *TypedData) GetMessageHash(account string) (hash *felt.Felt, err error) elements = append(elements, starknetMessage) //Enc[domain_separator] - domEnc, err := td.GetStructHash(td.Revision.Domain()) + domEnc, err := td.GetStructHash(td.revision.Domain()) if err != nil { return hash, err } @@ -147,13 +175,13 @@ func (td *TypedData) GetMessageHash(account string) (hash *felt.Felt, err error) elements = append(elements, accountFelt) //Enc[message] - msgEnc, err := td.GetStructHash(td.PrimaryType) + msgEnc, err := td.GetStructHash(td.primaryType) if err != nil { return hash, err } elements = append(elements, msgEnc) - return td.Revision.HashMethod(elements...), nil + return td.revision.HashMethod(elements...), nil } // GetStructHash calculates the hash of a type and its respective data. @@ -164,9 +192,9 @@ func (td *TypedData) GetMessageHash(account string) (hash *felt.Felt, err error) // - hash: A pointer to a felt.Felt representing the calculated hash. // - err: any error if any func (td *TypedData) GetStructHash(typeName string, context ...string) (hash *felt.Felt, err error) { - typeDef, ok := td.Types[typeName] + typeDef, ok := td.types[typeName] if !ok { - if typeDef, ok = td.Revision.Types().Preset[typeName]; !ok { + if typeDef, ok = td.revision.Types().Preset[typeName]; !ok { return hash, fmt.Errorf("error getting the type definition of %s", typeName) } } @@ -175,7 +203,7 @@ func (td *TypedData) GetStructHash(typeName string, context ...string) (hash *fe return hash, err } - return td.Revision.HashMethod(append([]*felt.Felt{typeDef.Enconding}, encTypeData...)...), nil + return td.revision.HashMethod(append([]*felt.Felt{typeDef.Enconding}, encTypeData...)...), nil } func shortGetStructHash( @@ -192,9 +220,9 @@ func shortGetStructHash( } if isEnum { - return typedData.Revision.HashMethod(encTypeData...), nil + return typedData.revision.HashMethod(encTypeData...), nil } - return typedData.Revision.HashMethod(append([]*felt.Felt{typeDef.Enconding}, encTypeData...)...), nil + return typedData.revision.HashMethod(append([]*felt.Felt{typeDef.Enconding}, encTypeData...)...), nil } // GetTypeHash returns the hash of the given type. @@ -206,9 +234,9 @@ func shortGetStructHash( // - err: any error if any func (td *TypedData) GetTypeHash(typeName string) (*felt.Felt, error) { //TODO: create/update methods descriptions - typeDef, ok := td.Types[typeName] + typeDef, ok := td.types[typeName] if !ok { - if typeDef, ok = td.Revision.Types().Preset[typeName]; !ok { + if typeDef, ok = td.revision.Types().Preset[typeName]; !ok { return typeDef.Enconding, fmt.Errorf("type '%s' not found", typeName) } } @@ -382,7 +410,7 @@ func encodeTypes(typeName string, types map[string]TypeDefinition, revision *rev func EncodeData(typeDef *TypeDefinition, td *TypedData, context ...string) (enc []*felt.Felt, err error) { if typeDef.Name == "StarkNetDomain" || typeDef.Name == "StarknetDomain" { domainMap := make(map[string]any) - domainBytes, err := json.Marshal(td.Domain) + domainBytes, err := json.Marshal(td.domain) if err != nil { return enc, err } @@ -397,7 +425,7 @@ func EncodeData(typeDef *TypeDefinition, td *TypedData, context ...string) (enc return encodeData(typeDef, td, domainMap, false, context...) } - return encodeData(typeDef, td, td.Message, false, context...) + return encodeData(typeDef, td, td.message, false, context...) } func encodeData( @@ -441,7 +469,7 @@ func encodeData( } return resp, nil case "enum": - typeDef, ok := typedData.Types[param.Contains] + typeDef, ok := typedData.types[param.Contains] if !ok { return resp, fmt.Errorf("error trying to get the type definition of '%s' in contains of '%s'", param.Contains, param.Name) } @@ -529,7 +557,7 @@ func encodeData( if isPresetType(singleParamType) { typeDef, ok = rev.Types().Preset[singleParamType] } else { - typeDef, ok = typedData.Types[singleParamType] + typeDef, ok = typedData.types[singleParamType] } if !ok { return resp, fmt.Errorf("error trying to get the type definition of '%s'", singleParamType) @@ -551,7 +579,7 @@ func encodeData( //function logic if strings.HasSuffix(param.Type, "*") { - resp, err := handleArrays(param, data, typedData.Revision) + resp, err := handleArrays(param, data, typedData.revision) if err != nil { return resp, err } @@ -559,14 +587,14 @@ func encodeData( } if isStandardType(param.Type) { - resp, err := handleStandardTypes(param, data, typedData.Revision) + resp, err := handleStandardTypes(param, data, typedData.revision) if err != nil { return resp, err } return resp, nil } - nextTypeDef, ok := typedData.Types[param.Type] + nextTypeDef, ok := typedData.types[param.Type] if !ok { return resp, fmt.Errorf("error trying to get the type definition of '%s'", param.Type) } @@ -733,7 +761,7 @@ func encodePieceOfData(typeName string, data any, rev *revision) (resp *felt.Fel } } -func (typedData *TypedData) UnmarshalJSON(data []byte) error { +func (td *TypedData) UnmarshalJSON(data []byte) error { var dec map[string]json.RawMessage if err := json.Unmarshal(data, &dec); err != nil { return err @@ -787,7 +815,7 @@ func (typedData *TypedData) UnmarshalJSON(data []byte) error { return err } - *typedData = *resultTypedData + *td = *resultTypedData return nil } diff --git a/typed/typed_test.go b/typed/typed_test.go index 04b7bb98..bf2129bf 100644 --- a/typed/typed_test.go +++ b/typed/typed_test.go @@ -3,56 +3,12 @@ package typed import ( "encoding/json" "fmt" - "math/big" "os" "testing" "github.com/stretchr/testify/require" ) -type Mail struct { - From Person `json:"from"` - To Person `json:"to"` - Contents string `json:"contents"` -} - -type Person struct { - Name string `json:"name"` - Wallet string `json:"wallet"` -} - -var types = []TypeDefinition{ - { - Name: "StarkNetDomain", - Parameters: []TypeParameter{ - {Name: "name", Type: "felt"}, - {Name: "version", Type: "felt"}, - {Name: "chainId", Type: "felt"}, - }, - }, - { - Name: "Mail", - Parameters: []TypeParameter{ - {Name: "from", Type: "Person"}, - {Name: "to", Type: "Person"}, - {Name: "contents", Type: "felt"}, - }, - }, - { - Name: "Person", - Parameters: []TypeParameter{ - {Name: "name", Type: "felt"}, - {Name: "wallet", Type: "felt"}, - }, - }, -} - -var dm = Domain{ - Name: "StarkNet Mail", - Version: "1", - ChainId: "1", -} - var typedDataExamples = make(map[string]TypedData) func TestMain(m *testing.M) { @@ -97,159 +53,6 @@ func BMockTypedData(b *testing.B) (ttd TypedData) { return } -// The TestUnmarshal function tests the ability to correctly unmarshal (deserialize) JSON content from -// a file into a Go TypedData struct. It starts by reading a json file. The JSON content is then unmarshaled -// into a TypedData struct using the json.Unmarshal function. After unmarshaling, the test checks if there were -// any errors during the unmarshaling process, and if an error is found, the test will fail. -// -// Parameters: -// - t: a testing.T object that provides methods for testing functions -// Returns: -// - None -func TestUnmarshal(t *testing.T) { - content, err := os.ReadFile("./tests/baseExample.json") - require.NoError(t, err) - - var typedData TypedData - err = json.Unmarshal(content, &typedData) - require.NoError(t, err) -} - -func TestGeneral_CreateMessageWithTypes(t *testing.T) { - t.Skip("TODO: need to implement encodeData method") - // for testSetType 2 - type Example1 struct { - N0 Felt `json:"n0"` - N1 Bool `json:"n1"` - N2 String `json:"n2"` - N3 Selector `json:"n3"` - N4 U128 `json:"n4"` - N5 I128 `json:"n5"` - N6 ContractAddress `json:"n6"` - N7 ClassHash `json:"n7"` - N8 Timestamp `json:"n8"` - N9 Shortstring `json:"n9"` - } - - // for testSetType 3 - type Example2 struct { - N0 TokenAmount `json:"n0"` - N1 NftId `json:"n1"` - } - - hex1, ok := new(big.Int).SetString("0x3e8", 0) - require.True(t, ok) - hex2, ok := new(big.Int).SetString("0x0", 0) - require.True(t, ok) - - type testSetType struct { - MessageWithString string - MessageWithTypes any - } - testSet := []testSetType{ - { - MessageWithString: ` - { - "from": { - "name": "Cow", - "wallet": "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826" - }, - "to": { - "name": "Bob", - "wallet": "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB" - }, - "contents": "Hello, Bob!" - }`, - MessageWithTypes: Mail{ - From: Person{ - Name: "Cow", - Wallet: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", - }, - To: Person{ - Name: "Bob", - Wallet: "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB", - }, - Contents: "Hello, Bob!", - }, - }, - { - MessageWithString: ` - { - "n0": "0x3e8", - "n1": true, - "n2": "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.", - "n3": "transfer", - "n4": 10, - "n5": -10, - "n6": "0x3e8", - "n7": "0x3e8", - "n8": 1000, - "n9": "transfer" - }`, - MessageWithTypes: Example1{ - N0: "0x3e8", - N1: true, - N2: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.", - N3: "transfer", - N4: big.NewInt(10), - N5: big.NewInt(-10), - N6: "0x3e8", - N7: "0x3e8", - N8: big.NewInt(1000), - N9: "transfer", - }, - }, - { - MessageWithString: ` - { - "n0": { - "token_address": "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", - "amount": { - "low": "0x3e8", - "high": "0x0" - } - }, - "n1": { - "collection_address": "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", - "token_id": { - "low": "0x3e8", - "high": "0x0" - } - } - }`, - MessageWithTypes: Example2{ - N0: TokenAmount{ - TokenAddress: "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", - Amount: U256{ - Low: hex1, - High: hex2, - }, - }, - N1: NftId{ - CollectionAddress: "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", - TokenID: U256{ - Low: hex1, - High: hex2, - }, - }, - }, - }, - } - for _, test := range testSet { - ttd1, err := NewTypedData(types, "Mail", dm, []byte(test.MessageWithString)) - require.NoError(t, err) - - bytes, err := json.Marshal(test.MessageWithTypes) - require.NoError(t, err) - - ttd2, err := NewTypedData(types, "Mail", dm, bytes) - require.NoError(t, err) - - require.EqualValues(t, ttd1, ttd2) - - } -} - // TestMessageHash tests the GetMessageHash function. // // It creates a mock TypedData and sets up a test case for hashing a mail message. @@ -465,7 +268,7 @@ func TestEncodeType(t *testing.T) { }, } for _, test := range testSet { - require.Equal(t, test.ExpectedEncode, test.TypedData.Types[test.TypeName].EncoddingString) + require.Equal(t, test.ExpectedEncode, test.TypedData.types[test.TypeName].EncoddingString) } } diff --git a/typed/types.go b/typed/types.go deleted file mode 100644 index b83c2ca0..00000000 --- a/typed/types.go +++ /dev/null @@ -1,31 +0,0 @@ -package typed - -import "math/big" - -type ( - Felt string - Bool bool - String string - Selector string - U128 *big.Int - I128 *big.Int - ContractAddress string - ClassHash string - Timestamp U128 - Shortstring string -) - -type U256 struct { - Low U128 `json:"low"` - High U128 `json:"high"` -} - -type TokenAmount struct { - TokenAddress ContractAddress `json:"token_address"` - Amount U256 `json:"amount"` -} - -type NftId struct { - CollectionAddress ContractAddress `json:"collection_address"` - TokenID U256 `json:"token_id"` -}