From f597d20306b4d3be320b864937e8730f16890182 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Fri, 13 Dec 2024 01:30:38 -0300 Subject: [PATCH] Changes private fields of TypedData to public --- typedData/typedData.go | 85 ++++++++++++------------------------- typedData/typedData_test.go | 2 +- 2 files changed, 27 insertions(+), 60 deletions(-) diff --git a/typedData/typedData.go b/typedData/typedData.go index 57b7e986..e56cf372 100644 --- a/typedData/typedData.go +++ b/typedData/typedData.go @@ -15,44 +15,11 @@ import ( ) type TypedData struct { - types map[string]TypeDefinition - primaryType string - domain Domain - message map[string]any - revision *revision -} - -// Types returns a copy of the TypedData's type definitions map -func (td *TypedData) Types() map[string]TypeDefinition { - copyMap := make(map[string]TypeDefinition, len(td.types)) - for k, v := range td.types { - copyMap[k] = v - } - return copyMap -} - -// PrimaryType returns the primary type name of the TypedData -func (td *TypedData) PrimaryType() string { - return td.primaryType -} - -// Domain returns the domain information of the TypedData -func (td *TypedData) Domain() Domain { - return td.domain -} - -// Message returns a copy of the TypedData's message map -func (td *TypedData) Message() map[string]any { - copyMap := make(map[string]any, len(td.message)) - for k, v := range td.message { - copyMap[k] = v - } - return copyMap -} - -// Revision returns the revision value of the TypedData -func (td *TypedData) Revision() revision { - return *td.revision + Types map[string]TypeDefinition + PrimaryType string + Domain Domain + Message map[string]any + Revision *revision } type Domain struct { @@ -134,11 +101,11 @@ func NewTypedData(types []TypeDefinition, primaryType string, domain Domain, mes } td = &TypedData{ - types: typesMap, - primaryType: primaryType, - domain: domain, - message: messageMap, - revision: revision, + Types: typesMap, + PrimaryType: primaryType, + Domain: domain, + Message: messageMap, + Revision: revision, } return td, nil @@ -162,7 +129,7 @@ func (td *TypedData) GetMessageHash(account string) (hash *felt.Felt, err error) } //Enc[domain_separator] - domEnc, err := td.GetStructHash(td.revision.Domain()) + domEnc, err := td.GetStructHash(td.Revision.Domain()) if err != nil { return hash, err } @@ -174,12 +141,12 @@ func (td *TypedData) GetMessageHash(account string) (hash *felt.Felt, err error) } //Enc[message] - msgEnc, err := td.GetStructHash(td.primaryType) + msgEnc, err := td.GetStructHash(td.PrimaryType) if err != nil { return hash, err } - return td.revision.HashMethod(prefixMessage, domEnc, accountFelt, msgEnc), nil + return td.Revision.HashMethod(prefixMessage, domEnc, accountFelt, msgEnc), nil } // GetStructHash calculates the hash of a struct type and its respective data. @@ -197,9 +164,9 @@ func (td *TypedData) GetMessageHash(account string) (hash *felt.Felt, err error) // - hash: A pointer to a felt.Felt representing the calculated hash. // - err: an error if any occurred during the hash calculation. func (td *TypedData) GetStructHash(typeName string, context ...string) (hash *felt.Felt, err error) { - typeDef, ok := td.types[typeName] + typeDef, ok := td.Types[typeName] if !ok { - if typeDef, ok = td.revision.Types().Preset[typeName]; !ok { + if typeDef, ok = td.Revision.Types().Preset[typeName]; !ok { return hash, fmt.Errorf("error getting the type definition of %s", typeName) } } @@ -208,7 +175,7 @@ func (td *TypedData) GetStructHash(typeName string, context ...string) (hash *fe return hash, err } - return td.revision.HashMethod(append([]*felt.Felt{typeDef.Enconding}, encTypeData...)...), nil + return td.Revision.HashMethod(append([]*felt.Felt{typeDef.Enconding}, encTypeData...)...), nil } // shortGetStructHash is a helper function that calculates the hash of a struct type and its respective data. @@ -225,7 +192,7 @@ func shortGetStructHash( return hash, err } - return typedData.revision.HashMethod(append([]*felt.Felt{typeDef.Enconding}, encTypeData...)...), nil + return typedData.Revision.HashMethod(append([]*felt.Felt{typeDef.Enconding}, encTypeData...)...), nil } // GetTypeHash returns the hash of the given type. @@ -237,9 +204,9 @@ func shortGetStructHash( // - err: an error if any occurred during the hash calculation. func (td *TypedData) GetTypeHash(typeName string) (*felt.Felt, error) { //TODO: create/update methods descriptions - typeDef, ok := td.types[typeName] + typeDef, ok := td.Types[typeName] if !ok { - if typeDef, ok = td.revision.Types().Preset[typeName]; !ok { + if typeDef, ok = td.Revision.Types().Preset[typeName]; !ok { return typeDef.Enconding, fmt.Errorf("type '%s' not found", typeName) } } @@ -416,7 +383,7 @@ func encodeTypes(typeName string, types map[string]TypeDefinition, revision *rev func EncodeData(typeDef *TypeDefinition, td *TypedData, context ...string) (enc []*felt.Felt, err error) { if typeDef.Name == "StarkNetDomain" || typeDef.Name == "StarknetDomain" { domainMap := make(map[string]any) - domainBytes, err := json.Marshal(td.domain) + domainBytes, err := json.Marshal(td.Domain) if err != nil { return enc, err } @@ -431,7 +398,7 @@ func EncodeData(typeDef *TypeDefinition, td *TypedData, context ...string) (enc return encodeData(typeDef, td, domainMap, false, context...) } - return encodeData(typeDef, td, td.message, false, context...) + return encodeData(typeDef, td, td.Message, false, context...) } // encodeData is a helper function that encodes the given type definition using the TypedData struct. @@ -491,7 +458,7 @@ func encodeData( } return resp, nil case "enum": - typeDef, ok := typedData.types[param.Contains] + typeDef, ok := typedData.Types[param.Contains] if !ok { return resp, fmt.Errorf("error trying to get the type definition of '%s' in contains of '%s'", param.Contains, param.Name) } @@ -579,7 +546,7 @@ func encodeData( if isPresetType(singleParamType) { typeDef, ok = rev.Types().Preset[singleParamType] } else { - typeDef, ok = typedData.types[singleParamType] + typeDef, ok = typedData.Types[singleParamType] } if !ok { return resp, fmt.Errorf("error trying to get the type definition of '%s'", singleParamType) @@ -601,7 +568,7 @@ func encodeData( //function logic if strings.HasSuffix(param.Type, "*") { - resp, err := handleArrays(param, data, typedData.revision) + resp, err := handleArrays(param, data, typedData.Revision) if err != nil { return resp, err } @@ -609,14 +576,14 @@ func encodeData( } if isStandardType(param.Type) { - resp, err := handleStandardTypes(param, data, typedData.revision) + resp, err := handleStandardTypes(param, data, typedData.Revision) if err != nil { return resp, err } return resp, nil } - nextTypeDef, ok := typedData.types[param.Type] + nextTypeDef, ok := typedData.Types[param.Type] if !ok { return resp, fmt.Errorf("error trying to get the type definition of '%s'", param.Type) } diff --git a/typedData/typedData_test.go b/typedData/typedData_test.go index 99bb5b6d..318e9c62 100644 --- a/typedData/typedData_test.go +++ b/typedData/typedData_test.go @@ -282,7 +282,7 @@ func TestEncodeType(t *testing.T) { }, } for _, test := range testSet { - require.Equal(t, test.ExpectedEncode, test.TypedData.types[test.TypeName].EncoddingString) + require.Equal(t, test.ExpectedEncode, test.TypedData.Types[test.TypeName].EncoddingString) } }