diff --git a/typed/revision.go b/typed/revision.go index fd294a7b..8eb0966a 100644 --- a/typed/revision.go +++ b/typed/revision.go @@ -126,8 +126,8 @@ func GetRevision(version uint8) (rev *revision, err error) { } func getRevisionV1PresetTypes() map[string]TypeDefinition { - NftIdEnc, _ := new(felt.Felt).SetString("0x14648649d4413eb385eea9ac7e6f2b9769671f5d9d7ad40f7b4aadd67839d4") - TokenAmountEnc, _ := new(felt.Felt).SetString("0xaf7d0f5e34446178d80fadf5ddaaed52347121d2fac19ff184ff508d4776f2") + NftIdEnc, _ := new(felt.Felt).SetString("0xaf7d0f5e34446178d80fadf5ddaaed52347121d2fac19ff184ff508d4776f2") + TokenAmountEnc, _ := new(felt.Felt).SetString("0x14648649d4413eb385eea9ac7e6f2b9769671f5d9d7ad40f7b4aadd67839d4") u256dEnc, _ := new(felt.Felt).SetString("0x3b143be38b811560b45593fb2a071ec4ddd0a020e10782be62ffe6f39e0e82c") presetTypes := []TypeDefinition{ diff --git a/typed/typed.go b/typed/typed.go index a94156a9..eea1fe09 100644 --- a/typed/typed.go +++ b/typed/typed.go @@ -95,6 +95,12 @@ func NewTypedData(types []TypeDefinition, primaryType string, domain Domain, mes } typesMap[primaryType] = primaryTypeDef + for _, typeDef := range typesMap { + if typeDef.EncoddingString == "" { + return td, fmt.Errorf("'encodeTypes' failed: type '%s' doesn't have encode value", typeDef.Name) + } + } + td = &TypedData{ Types: typesMap, PrimaryType: primaryType, @@ -160,7 +166,9 @@ func (td *TypedData) GetMessageHash(account string) (hash *felt.Felt, err error) func (td *TypedData) GetStructHash(typeName string, context ...string) (hash *felt.Felt, err error) { typeDef, ok := td.Types[typeName] if !ok { - return hash, fmt.Errorf("error getting the type definition of %s", typeName) + if typeDef, ok = td.Revision.Types().Preset[typeName]; !ok { + return hash, fmt.Errorf("error getting the type definition of %s", typeName) + } } encTypeData, err := EncodeData(&typeDef, td, context...) if err != nil { @@ -192,9 +200,15 @@ func shortGetStructHash( // Returns: // - ret: the hash of the given type // - err: any error if any -func (td *TypedData) GetTypeHash(typeName string) *felt.Felt { +func (td *TypedData) GetTypeHash(typeName string) (*felt.Felt, error) { //TODO: create/update methods descriptions - return td.Types[typeName].Enconding + typeDef, ok := td.Types[typeName] + if !ok { + if typeDef, ok = td.Revision.Types().Preset[typeName]; !ok { + return typeDef.Enconding, fmt.Errorf("type '%s' not found", typeName) + } + } + return typeDef.Enconding, nil } // EncodeType encodes the given inType using the TypedData struct. @@ -206,10 +220,21 @@ func (td *TypedData) GetTypeHash(typeName string) *felt.Felt { // - err: any error if any func encodeTypes(typeName string, types map[string]TypeDefinition, revision *revision, isEnum ...bool) (newTypeDef TypeDefinition, err error) { getTypeEncodeString := func(typeName string, typeDef TypeDefinition, customTypesStringEnc *[]string, isEnum ...bool) (result string, err error) { - verifyTypeName := func(typeName string, isEnum ...bool) error { - singleTypeName, _ := strings.CutSuffix(typeName, "*") + verifyTypeName := func(param TypeParameter, isEnum ...bool) error { + singleTypeName, _ := strings.CutSuffix(param.Type, "*") if isBasicType(singleTypeName) { + if singleTypeName == "merkletree" { + if param.Contains == "" { + return fmt.Errorf("missing 'contains' value from '%s'", param.Name) + } + newTypeDef, err := encodeTypes(param.Contains, types, revision) + if err != nil { + return err + } + + types[param.Contains] = newTypeDef + } return nil } @@ -265,7 +290,7 @@ func encodeTypes(typeName string, types map[string]TypeDefinition, revision *rev buf.WriteString(fmt.Sprintf(quotationMark+"%s"+quotationMark+":"+`(`+"%s"+`)`, param.Name, fullTypeName)) for _, typeNam := range typesArr { - err = verifyTypeName(typeNam) + err = verifyTypeName(TypeParameter{Type: typeNam}) if err != nil { return "", err } @@ -278,7 +303,7 @@ func encodeTypes(typeName string, types map[string]TypeDefinition, revision *rev return "", fmt.Errorf("missing 'contains' value from '%s'", param.Name) } currentTypeName = param.Contains - err = verifyTypeName(currentTypeName, true) + err = verifyTypeName(TypeParameter{Type: currentTypeName}, true) if err != nil { return "", err } @@ -286,7 +311,7 @@ func encodeTypes(typeName string, types map[string]TypeDefinition, revision *rev buf.WriteString(fmt.Sprintf(quotationMark+"%s"+quotationMark+":"+quotationMark+"%s"+quotationMark, param.Name, currentTypeName)) - err = verifyTypeName(param.Type) + err = verifyTypeName(param) if err != nil { return "", err } @@ -442,7 +467,7 @@ func encodeData( return resp, fmt.Errorf("error trying to convert the value of '%s' to an map", typeDef) } - resp, err = shortGetStructHash(typeDef, typedData, mapData, context...) + resp, err = shortGetStructHash(typeDef, typedData, mapData) if err != nil { return resp, err } diff --git a/typed/typed_test.go b/typed/typed_test.go index 810f9f0d..a58cb5eb 100644 --- a/typed/typed_test.go +++ b/typed/typed_test.go @@ -444,10 +444,7 @@ func TestEncodeType(t *testing.T) { }, } for _, test := range testSet { - encode, err := encodeType(test.TypeName, test.TypedData.Types, test.TypedData.Revision) - require.NoError(t, err) - - require.Equal(t, test.ExpectedEncode, encode) + require.Equal(t, test.ExpectedEncode, test.TypedData.Types[test.TypeName].EncoddingString) } }