Skip to content

Commit

Permalink
Changes private fields of TypedData to public
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagodeev committed Dec 13, 2024
1 parent 1b5d956 commit f597d20
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 60 deletions.
85 changes: 26 additions & 59 deletions typedData/typedData.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,44 +15,11 @@ import (
)

type TypedData struct {
types map[string]TypeDefinition
primaryType string
domain Domain
message map[string]any
revision *revision
}

// Types returns a copy of the TypedData's type definitions map
func (td *TypedData) Types() map[string]TypeDefinition {
copyMap := make(map[string]TypeDefinition, len(td.types))
for k, v := range td.types {
copyMap[k] = v
}
return copyMap
}

// PrimaryType returns the primary type name of the TypedData
func (td *TypedData) PrimaryType() string {
return td.primaryType
}

// Domain returns the domain information of the TypedData
func (td *TypedData) Domain() Domain {
return td.domain
}

// Message returns a copy of the TypedData's message map
func (td *TypedData) Message() map[string]any {
copyMap := make(map[string]any, len(td.message))
for k, v := range td.message {
copyMap[k] = v
}
return copyMap
}

// Revision returns the revision value of the TypedData
func (td *TypedData) Revision() revision {
return *td.revision
Types map[string]TypeDefinition
PrimaryType string
Domain Domain
Message map[string]any
Revision *revision
}

type Domain struct {
Expand Down Expand Up @@ -134,11 +101,11 @@ func NewTypedData(types []TypeDefinition, primaryType string, domain Domain, mes
}

td = &TypedData{
types: typesMap,
primaryType: primaryType,
domain: domain,
message: messageMap,
revision: revision,
Types: typesMap,
PrimaryType: primaryType,
Domain: domain,
Message: messageMap,
Revision: revision,
}

return td, nil
Expand All @@ -162,7 +129,7 @@ func (td *TypedData) GetMessageHash(account string) (hash *felt.Felt, err error)
}

//Enc[domain_separator]
domEnc, err := td.GetStructHash(td.revision.Domain())
domEnc, err := td.GetStructHash(td.Revision.Domain())
if err != nil {
return hash, err
}
Expand All @@ -174,12 +141,12 @@ func (td *TypedData) GetMessageHash(account string) (hash *felt.Felt, err error)
}

//Enc[message]
msgEnc, err := td.GetStructHash(td.primaryType)
msgEnc, err := td.GetStructHash(td.PrimaryType)
if err != nil {
return hash, err
}

return td.revision.HashMethod(prefixMessage, domEnc, accountFelt, msgEnc), nil
return td.Revision.HashMethod(prefixMessage, domEnc, accountFelt, msgEnc), nil
}

// GetStructHash calculates the hash of a struct type and its respective data.
Expand All @@ -197,9 +164,9 @@ func (td *TypedData) GetMessageHash(account string) (hash *felt.Felt, err error)
// - hash: A pointer to a felt.Felt representing the calculated hash.
// - err: an error if any occurred during the hash calculation.
func (td *TypedData) GetStructHash(typeName string, context ...string) (hash *felt.Felt, err error) {
typeDef, ok := td.types[typeName]
typeDef, ok := td.Types[typeName]
if !ok {
if typeDef, ok = td.revision.Types().Preset[typeName]; !ok {
if typeDef, ok = td.Revision.Types().Preset[typeName]; !ok {
return hash, fmt.Errorf("error getting the type definition of %s", typeName)
}
}
Expand All @@ -208,7 +175,7 @@ func (td *TypedData) GetStructHash(typeName string, context ...string) (hash *fe
return hash, err
}

return td.revision.HashMethod(append([]*felt.Felt{typeDef.Enconding}, encTypeData...)...), nil
return td.Revision.HashMethod(append([]*felt.Felt{typeDef.Enconding}, encTypeData...)...), nil
}

// shortGetStructHash is a helper function that calculates the hash of a struct type and its respective data.
Expand All @@ -225,7 +192,7 @@ func shortGetStructHash(
return hash, err
}

return typedData.revision.HashMethod(append([]*felt.Felt{typeDef.Enconding}, encTypeData...)...), nil
return typedData.Revision.HashMethod(append([]*felt.Felt{typeDef.Enconding}, encTypeData...)...), nil
}

// GetTypeHash returns the hash of the given type.
Expand All @@ -237,9 +204,9 @@ func shortGetStructHash(
// - err: an error if any occurred during the hash calculation.
func (td *TypedData) GetTypeHash(typeName string) (*felt.Felt, error) {
//TODO: create/update methods descriptions
typeDef, ok := td.types[typeName]
typeDef, ok := td.Types[typeName]
if !ok {
if typeDef, ok = td.revision.Types().Preset[typeName]; !ok {
if typeDef, ok = td.Revision.Types().Preset[typeName]; !ok {
return typeDef.Enconding, fmt.Errorf("type '%s' not found", typeName)
}
}
Expand Down Expand Up @@ -416,7 +383,7 @@ func encodeTypes(typeName string, types map[string]TypeDefinition, revision *rev
func EncodeData(typeDef *TypeDefinition, td *TypedData, context ...string) (enc []*felt.Felt, err error) {
if typeDef.Name == "StarkNetDomain" || typeDef.Name == "StarknetDomain" {
domainMap := make(map[string]any)
domainBytes, err := json.Marshal(td.domain)
domainBytes, err := json.Marshal(td.Domain)
if err != nil {
return enc, err
}
Expand All @@ -431,7 +398,7 @@ func EncodeData(typeDef *TypeDefinition, td *TypedData, context ...string) (enc
return encodeData(typeDef, td, domainMap, false, context...)
}

return encodeData(typeDef, td, td.message, false, context...)
return encodeData(typeDef, td, td.Message, false, context...)
}

// encodeData is a helper function that encodes the given type definition using the TypedData struct.
Expand Down Expand Up @@ -491,7 +458,7 @@ func encodeData(
}
return resp, nil
case "enum":
typeDef, ok := typedData.types[param.Contains]
typeDef, ok := typedData.Types[param.Contains]
if !ok {
return resp, fmt.Errorf("error trying to get the type definition of '%s' in contains of '%s'", param.Contains, param.Name)
}
Expand Down Expand Up @@ -579,7 +546,7 @@ func encodeData(
if isPresetType(singleParamType) {
typeDef, ok = rev.Types().Preset[singleParamType]
} else {
typeDef, ok = typedData.types[singleParamType]
typeDef, ok = typedData.Types[singleParamType]
}
if !ok {
return resp, fmt.Errorf("error trying to get the type definition of '%s'", singleParamType)
Expand All @@ -601,22 +568,22 @@ func encodeData(

//function logic
if strings.HasSuffix(param.Type, "*") {
resp, err := handleArrays(param, data, typedData.revision)
resp, err := handleArrays(param, data, typedData.Revision)
if err != nil {
return resp, err
}
return resp, nil
}

if isStandardType(param.Type) {
resp, err := handleStandardTypes(param, data, typedData.revision)
resp, err := handleStandardTypes(param, data, typedData.Revision)
if err != nil {
return resp, err
}
return resp, nil
}

nextTypeDef, ok := typedData.types[param.Type]
nextTypeDef, ok := typedData.Types[param.Type]
if !ok {
return resp, fmt.Errorf("error trying to get the type definition of '%s'", param.Type)
}
Expand Down
2 changes: 1 addition & 1 deletion typedData/typedData_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func TestEncodeType(t *testing.T) {
},
}
for _, test := range testSet {
require.Equal(t, test.ExpectedEncode, test.TypedData.types[test.TypeName].EncoddingString)
require.Equal(t, test.ExpectedEncode, test.TypedData.Types[test.TypeName].EncoddingString)
}
}

Expand Down

0 comments on commit f597d20

Please sign in to comment.