Skip to content

Commit

Permalink
Fixes bug with merkletree
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagodeev committed Nov 29, 2024
1 parent f0416c9 commit 641019a
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 15 deletions.
4 changes: 2 additions & 2 deletions typed/revision.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ func GetRevision(version uint8) (rev *revision, err error) {
}

func getRevisionV1PresetTypes() map[string]TypeDefinition {
NftIdEnc, _ := new(felt.Felt).SetString("0x14648649d4413eb385eea9ac7e6f2b9769671f5d9d7ad40f7b4aadd67839d4")
TokenAmountEnc, _ := new(felt.Felt).SetString("0xaf7d0f5e34446178d80fadf5ddaaed52347121d2fac19ff184ff508d4776f2")
NftIdEnc, _ := new(felt.Felt).SetString("0xaf7d0f5e34446178d80fadf5ddaaed52347121d2fac19ff184ff508d4776f2")
TokenAmountEnc, _ := new(felt.Felt).SetString("0x14648649d4413eb385eea9ac7e6f2b9769671f5d9d7ad40f7b4aadd67839d4")
u256dEnc, _ := new(felt.Felt).SetString("0x3b143be38b811560b45593fb2a071ec4ddd0a020e10782be62ffe6f39e0e82c")

presetTypes := []TypeDefinition{
Expand Down
43 changes: 34 additions & 9 deletions typed/typed.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ func NewTypedData(types []TypeDefinition, primaryType string, domain Domain, mes
}
typesMap[primaryType] = primaryTypeDef

for _, typeDef := range typesMap {
if typeDef.EncoddingString == "" {
return td, fmt.Errorf("'encodeTypes' failed: type '%s' doesn't have encode value", typeDef.Name)
}
}

td = &TypedData{
Types: typesMap,
PrimaryType: primaryType,
Expand Down Expand Up @@ -160,7 +166,9 @@ func (td *TypedData) GetMessageHash(account string) (hash *felt.Felt, err error)
func (td *TypedData) GetStructHash(typeName string, context ...string) (hash *felt.Felt, err error) {
typeDef, ok := td.Types[typeName]
if !ok {
return hash, fmt.Errorf("error getting the type definition of %s", typeName)
if typeDef, ok = td.Revision.Types().Preset[typeName]; !ok {
return hash, fmt.Errorf("error getting the type definition of %s", typeName)
}
}
encTypeData, err := EncodeData(&typeDef, td, context...)
if err != nil {
Expand Down Expand Up @@ -192,9 +200,15 @@ func shortGetStructHash(
// Returns:
// - ret: the hash of the given type
// - err: any error if any
func (td *TypedData) GetTypeHash(typeName string) *felt.Felt {
func (td *TypedData) GetTypeHash(typeName string) (*felt.Felt, error) {
//TODO: create/update methods descriptions
return td.Types[typeName].Enconding
typeDef, ok := td.Types[typeName]
if !ok {
if typeDef, ok = td.Revision.Types().Preset[typeName]; !ok {
return typeDef.Enconding, fmt.Errorf("type '%s' not found", typeName)
}
}
return typeDef.Enconding, nil
}

// EncodeType encodes the given inType using the TypedData struct.
Expand All @@ -206,10 +220,21 @@ func (td *TypedData) GetTypeHash(typeName string) *felt.Felt {
// - err: any error if any
func encodeTypes(typeName string, types map[string]TypeDefinition, revision *revision, isEnum ...bool) (newTypeDef TypeDefinition, err error) {
getTypeEncodeString := func(typeName string, typeDef TypeDefinition, customTypesStringEnc *[]string, isEnum ...bool) (result string, err error) {
verifyTypeName := func(typeName string, isEnum ...bool) error {
singleTypeName, _ := strings.CutSuffix(typeName, "*")
verifyTypeName := func(param TypeParameter, isEnum ...bool) error {
singleTypeName, _ := strings.CutSuffix(param.Type, "*")

if isBasicType(singleTypeName) {
if singleTypeName == "merkletree" {
if param.Contains == "" {
return fmt.Errorf("missing 'contains' value from '%s'", param.Name)
}
newTypeDef, err := encodeTypes(param.Contains, types, revision)
if err != nil {
return err
}

types[param.Contains] = newTypeDef
}
return nil
}

Expand Down Expand Up @@ -265,7 +290,7 @@ func encodeTypes(typeName string, types map[string]TypeDefinition, revision *rev
buf.WriteString(fmt.Sprintf(quotationMark+"%s"+quotationMark+":"+`(`+"%s"+`)`, param.Name, fullTypeName))

for _, typeNam := range typesArr {
err = verifyTypeName(typeNam)
err = verifyTypeName(TypeParameter{Type: typeNam})
if err != nil {
return "", err
}
Expand All @@ -278,15 +303,15 @@ func encodeTypes(typeName string, types map[string]TypeDefinition, revision *rev
return "", fmt.Errorf("missing 'contains' value from '%s'", param.Name)
}
currentTypeName = param.Contains
err = verifyTypeName(currentTypeName, true)
err = verifyTypeName(TypeParameter{Type: currentTypeName}, true)
if err != nil {
return "", err
}
}

buf.WriteString(fmt.Sprintf(quotationMark+"%s"+quotationMark+":"+quotationMark+"%s"+quotationMark, param.Name, currentTypeName))

err = verifyTypeName(param.Type)
err = verifyTypeName(param)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -442,7 +467,7 @@ func encodeData(
return resp, fmt.Errorf("error trying to convert the value of '%s' to an map", typeDef)
}

resp, err = shortGetStructHash(typeDef, typedData, mapData, context...)
resp, err = shortGetStructHash(typeDef, typedData, mapData)
if err != nil {
return resp, err
}
Expand Down
5 changes: 1 addition & 4 deletions typed/typed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,10 +444,7 @@ func TestEncodeType(t *testing.T) {
},
}
for _, test := range testSet {
encode, err := encodeType(test.TypeName, test.TypedData.Types, test.TypedData.Revision)
require.NoError(t, err)

require.Equal(t, test.ExpectedEncode, encode)
require.Equal(t, test.ExpectedEncode, test.TypedData.Types[test.TypeName].EncoddingString)
}
}

Expand Down

0 comments on commit 641019a

Please sign in to comment.