Skip to content

Commit

Permalink
Some code adjustments
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagodeev committed Nov 20, 2024
1 parent 1fdcd72 commit 2ba0749
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 56 deletions.
67 changes: 26 additions & 41 deletions typed/typed.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type TypedData struct {
PrimaryType string `json:"primaryType"`
Domain Domain `json:"domain"`
Message map[string]any `json:"message"`
Revision Revision
Revision Revision `json:"-"`
}

type Domain struct {
Expand Down Expand Up @@ -85,6 +85,10 @@ func NewTypedData(types []TypeDefinition, primaryType string, domain Domain, mes
typesMap[typeDef.Name] = typeDef
}

if _, ok := typesMap[primaryType]; !ok {
return td, fmt.Errorf("invalid primary type: %s", primaryType)
}

messageMap := make(map[string]any)
err = json.Unmarshal(message, &messageMap)
if err != nil {
Expand All @@ -103,9 +107,6 @@ func NewTypedData(types []TypeDefinition, primaryType string, domain Domain, mes
Message: messageMap,
Revision: revision,
}
if _, ok := td.Types[primaryType]; !ok {
return td, fmt.Errorf("invalid primary type: %s", primaryType)
}

for k, v := range td.Types {
enc, err := getTypeHash(k, td.Types)
Expand All @@ -127,26 +128,32 @@ func NewTypedData(types []TypeDefinition, primaryType string, domain Domain, mes
// Returns:
// - hash: A pointer to a felt.Felt representing the calculated hash.
func (td *TypedData) GetMessageHash(account string) (hash *felt.Felt, err error) {
//signed_data = encode(PREFIX_MESSAGE, Enc[domain_separator], account, Enc[message])

elements := []*felt.Felt{}

//PREFIX_MESSAGE
starknetMessage, err := utils.HexToFelt(utils.StrToHex("StarkNet Message"))
if err != nil {
return hash, err
}
elements = append(elements, starknetMessage)

//Enc[domain_separator]
domEnc, err := td.GetStructHash(td.Revision.Domain())
if err != nil {
return hash, err
}
elements = append(elements, domEnc)

//account
accountFelt, err := utils.HexToFelt(account)
if err != nil {
return hash, err
}
elements = append(elements, accountFelt)

//Enc[message]
msgEnc, err := td.GetStructHash(td.PrimaryType)
if err != nil {
return hash, err
Expand Down Expand Up @@ -179,7 +186,7 @@ func (td *TypedData) GetStructHash(typeName string, context ...string) (hash *fe
func shortGetStructHash(
typeDef *TypeDefinition,
typedData *TypedData,
data *map[string]any,
data map[string]any,
context ...string,
) (hash *felt.Felt, err error) {

Expand Down Expand Up @@ -282,8 +289,7 @@ func encodeType(typeName string, types map[string]TypeDefinition) (enc string, e
}

func EncodeData(typeDef *TypeDefinition, td *TypedData, context ...string) (enc []*felt.Felt, err error) {
localTypeDef := *typeDef
if localTypeDef.Name == "StarkNetDomain" || localTypeDef.Name == "StarknetDomain" {
if typeDef.Name == "StarkNetDomain" || typeDef.Name == "StarknetDomain" {
domainMap := make(map[string]any)
domainBytes, err := json.Marshal(td.Domain)
if err != nil {
Expand All @@ -294,36 +300,34 @@ func EncodeData(typeDef *TypeDefinition, td *TypedData, context ...string) (enc
return enc, err
}

return encodeData(typeDef, td, &domainMap, context...)
return encodeData(typeDef, td, domainMap, context...)
}

return encodeData(typeDef, td, &td.Message, context...)
return encodeData(typeDef, td, td.Message, context...)
}

func encodeData(
typeDef *TypeDefinition,
typedData *TypedData,
data *map[string]any,
data map[string]any,
context ...string,
) (enc []*felt.Felt, err error) {
localData := *data

if len(context) != 0 {
for _, paramName := range context {
value, ok := localData[paramName]
value, ok := data[paramName]
if !ok {
return enc, fmt.Errorf("context error: parameter '%s' not found in the data map", paramName)
}
newData, ok := value.(map[string]any)
if !ok {
return enc, fmt.Errorf("context error: error generating the new data map")
}
localData = newData
data = newData
}
}

getStringFromData := func(key string) (resp string, err error) {
value, ok := localData[key]
value, ok := data[key]
if !ok {
return resp, fmt.Errorf("error trying to get the value of the %s type", key)
}
Expand Down Expand Up @@ -394,49 +398,32 @@ func encodeData(
}

func (typedData *TypedData) UnmarshalJSON(data []byte) error {
var dec map[string]interface{}
var dec map[string]json.RawMessage
if err := json.Unmarshal(data, &dec); err != nil {
return err
}

// primaryType
rawPrimaryType, ok := dec["primaryType"]
if !ok {
return fmt.Errorf("invalid typedData json: missing field 'primaryType'")
}
primaryType, ok := rawPrimaryType.(string)
if !ok {
return fmt.Errorf("failed to unmarshal 'primaryType', it's not a string")
primaryType, err := utils.GetAndUnmarshalJSONFromMap[string](dec, "primaryType")
if err != nil {
return err
}

// domain
rawDomain, ok := dec["domain"]
if !ok {
return fmt.Errorf("invalid typedData json: missing field 'domain'")
}
bytesDomain, err := json.Marshal(rawDomain)
domain, err := utils.GetAndUnmarshalJSONFromMap[Domain](dec, "domain")
if err != nil {
return err
}
var domain Domain
if err := json.Unmarshal(bytesDomain, &domain); err != nil {
return err
}

// types
rawTypes, err := utils.UnwrapJSON(dec, "types")
rawTypes, err := utils.GetAndUnmarshalJSONFromMap[map[string]json.RawMessage](dec, "types")
if err != nil {
return err
}
var types []TypeDefinition
for key, value := range rawTypes {
bytesValue, err := json.Marshal(value)
if err != nil {
return err
}

var params []TypeParameter
if err := json.Unmarshal(bytesValue, &params); err != nil {
if err := json.Unmarshal(value, &params); err != nil {
return err
}

Expand Down Expand Up @@ -466,6 +453,4 @@ func (typedData *TypedData) UnmarshalJSON(data []byte) error {

*typedData = *resultTypedData
return nil

// TODO: implement typedMessage unmarshal
}
14 changes: 0 additions & 14 deletions typed/typed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,6 @@ var dm = Domain{
ChainId: "1",
}

var message = `
{
"from": {
"name": "Cow",
"wallet": "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826"
},
"to": {
"name": "Bob",
"wallet": "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB"
},
"contents": "Hello, Bob!"
}
`

// MockTypedData generates a TypedData object for testing purposes.
// It creates example types and initializes a Domain object. Then it uses the example types and the domain to create a new TypedData object.
// The function returns the generated TypedData object.
Expand Down
19 changes: 18 additions & 1 deletion utils/data.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package utils

import "encoding/json"
import (
"encoding/json"
"fmt"
)

func UnwrapJSON(data map[string]interface{}, tag string) (map[string]interface{}, error) {
if data[tag] != nil {
Expand All @@ -16,3 +19,17 @@ func UnwrapJSON(data map[string]interface{}, tag string) (map[string]interface{}
}
return data, nil
}

func GetAndUnmarshalJSONFromMap[T any](aMap map[string]json.RawMessage, key string) (result T, err error) {
value, ok := aMap[key]
if !ok {
return result, fmt.Errorf("invalid json: missing field %s", key)
}

err = json.Unmarshal(value, &result)
if err != nil {
return result, err
}

return result, nil
}

0 comments on commit 2ba0749

Please sign in to comment.