Skip to content

Commit

Permalink
fixes errors in encodeType, 'example_presetTypes' supported
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagodeev committed Nov 26, 2024
1 parent e261812 commit f081714
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 30 deletions.
23 changes: 21 additions & 2 deletions typed/revision.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,11 @@ func getRevisionV1PresetTypes() map[string]TypeDefinition {
Parameters: []TypeParameter{
{
Name: "low",
Type: "U128",
Type: "u128",
},
{
Name: "high",
Type: "U128",
Type: "u128",
},
},
},
Expand All @@ -177,3 +177,22 @@ func isStandardType(typeName string) bool {

return false
}

// Check if the provided type name is a basic type defined at the SNIP 12, also validates arrays
func isBasicType(typeName string) bool {
typeName, _ = strings.CutSuffix(typeName, "*")

if slices.Contains(revision_0_basic_types, typeName) ||
slices.Contains(revision_1_basic_types, typeName) {
return true
}

return false
}

// Check if the provided type name is a preset type defined at the SNIP 12, also validates arrays
func isPresetType(typeName string) bool {
typeName, _ = strings.CutSuffix(typeName, "*")

return slices.Contains(revision_1_preset_types, typeName)
}
80 changes: 55 additions & 25 deletions typed/typed.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,21 @@ func NewTypedData(types []TypeDefinition, primaryType string, domain Domain, mes
}

for k, v := range td.Types {
enc, err := getTypeHash(k, td.Types, td.Revision.Version())
enc, err := getTypeHash(k, td.Types, td.Revision)
if err != nil {
return td, fmt.Errorf("error encoding type hash: %s %w", k, err)
}
v.Enconding = enc
td.Types[k] = v
}
for k, v := range td.Revision.types.Preset {
enc, err := getTypeHash(k, td.Revision.types.Preset, td.Revision)
if err != nil {
return td, fmt.Errorf("error encoding type hash: %s %w", k, err)
}
v.Enconding = enc
td.Revision.types.Preset[k] = v
}
return td, nil
}

Expand Down Expand Up @@ -181,11 +189,11 @@ func shortGetStructHash(
// - err: any error if any
func (td *TypedData) GetTypeHash(typeName string) (ret *felt.Felt, err error) {
//TODO: create/update methods descriptions
return getTypeHash(typeName, td.Types, td.Revision.Version())
return getTypeHash(typeName, td.Types, td.Revision)
}

func getTypeHash(typeName string, types map[string]TypeDefinition, revisionVersion uint8) (ret *felt.Felt, err error) {
enc, err := encodeType(typeName, types, revisionVersion)
func getTypeHash(typeName string, types map[string]TypeDefinition, revision *revision) (ret *felt.Felt, err error) {
enc, err := encodeType(typeName, types, revision)
if err != nil {
return ret, err
}
Expand All @@ -199,14 +207,14 @@ func getTypeHash(typeName string, types map[string]TypeDefinition, revisionVersi
// Returns:
// - enc: the encoded type
// - err: any error if any
func encodeType(typeName string, types map[string]TypeDefinition, revisionVersion uint8) (enc string, err error) {
func encodeType(typeName string, types map[string]TypeDefinition, revision *revision) (enc string, err error) {
customTypesEncodeResp := make(map[string]string)

var getEncodeType func(typeName string, typeDef TypeDefinition) (result string, err error)
getEncodeType = func(typeName string, typeDef TypeDefinition) (result string, err error) {
var buf bytes.Buffer
quotationMark := ""
if revisionVersion == 1 {
if revision.Version() == 1 {
quotationMark = `"`
}

Expand All @@ -220,12 +228,24 @@ func encodeType(typeName string, types map[string]TypeDefinition, revisionVersio
if i != (len(typeDef.Parameters) - 1) {
buf.WriteString(",")
}
// e.g.: "felt" or "felt*"
if isStandardType(param.Type) {

if isBasicType(param.Type) {
continue
} else if _, ok = customTypesEncodeResp[param.Type]; !ok {
var customTypeDef TypeDefinition
if customTypeDef, ok = types[param.Type]; !ok { //OBS: this is wrong on V1
if isPresetType(param.Type) {
typeDef, ok := revision.Types().Preset[param.Type]
if !ok {
return result, fmt.Errorf("error trying to get the type definition of '%s'", param.Type)
}
customTypesEncodeResp[param.Type], err = getEncodeType(param.Type, typeDef)
if err != nil {
return "", err
}

continue
}
customTypeDef, ok := types[param.Type]
if !ok {
return "", fmt.Errorf("can't parse type %s from types %v", param.Type, types)
}
customTypesEncodeResp[param.Type], err = getEncodeType(param.Type, customTypeDef)
Expand Down Expand Up @@ -305,13 +325,13 @@ func encodeData(
}

var handleStandardTypes func(param TypeParameter, data any, rev *revision) (resp *felt.Felt, err error)
var handleObjectTypes func(typeName string, data any) (resp *felt.Felt, err error)
var handleObjectTypes func(typeDef *TypeDefinition, data any) (resp *felt.Felt, err error)
var handleArrays func(param TypeParameter, data any, rev *revision) (resp *felt.Felt, err error)

getData := func(key string) (any, error) {
value, ok := data[key]
if !ok {
return value, fmt.Errorf("error trying to get the value of the %s param", key)
return value, fmt.Errorf("error trying to get the value of the '%s' param", key)
}
return value, nil
}
Expand All @@ -330,7 +350,11 @@ func encodeData(
return resp, nil
case "enum":
case "NftId", "TokenAmount", "u256":
resp, err := handleObjectTypes(param.Type, data)
typeDef, ok := rev.Types().Preset[param.Type]
if !ok {
return resp, fmt.Errorf("error trying to get the type definition of '%s'", param.Type)
}
resp, err := handleObjectTypes(&typeDef, data)
if err != nil {
return resp, err
}
Expand All @@ -345,21 +369,18 @@ func encodeData(
return resp, fmt.Errorf("error trying to encode the data of '%s'", param.Type)
}

handleObjectTypes = func(typeName string, data any) (resp *felt.Felt, err error) {
handleObjectTypes = func(typeDef *TypeDefinition, data any) (resp *felt.Felt, err error) {
mapData, ok := data.(map[string]any)
if !ok {
return resp, fmt.Errorf("error trying to convert the value of '%s' to an map", typeName)
return resp, fmt.Errorf("error trying to convert the value of '%s' to an map", typeDef)
}

if nextTypeDef, ok := typedData.Types[typeName]; ok {
resp, err := shortGetStructHash(&nextTypeDef, typedData, mapData, context...)
if err != nil {
return resp, err
}

return resp, nil
resp, err = shortGetStructHash(typeDef, typedData, mapData, context...)
if err != nil {
return resp, err
}
return resp, fmt.Errorf("error trying to get the type definition of '%s'", typeName)

return resp, nil
}

handleArrays = func(param TypeParameter, data any, rev *revision) (resp *felt.Felt, err error) {
Expand All @@ -381,8 +402,12 @@ func encodeData(
return rev.HashMethod(localEncode...), nil
}

typeDef, ok := rev.Types().Preset[singleParamType]
if !ok {
return resp, fmt.Errorf("error trying to get the type definition of '%s'", singleParamType)
}
for _, item := range dataArray {
resp, err := handleObjectTypes(singleParamType, item)
resp, err := handleObjectTypes(&typeDef, item)
if err != nil {
return resp, err
}
Expand Down Expand Up @@ -415,11 +440,16 @@ func encodeData(
continue
}

resp, err := handleObjectTypes(param.Type, localData)
nextTypeDef, ok := typedData.Types[param.Type]
if !ok {
return enc, fmt.Errorf("error trying to get the type definition of '%s'", param.Type)
}
resp, err := handleObjectTypes(&nextTypeDef, localData)
if err != nil {
return enc, err
}
enc = append(enc, resp)

}

return enc, nil
Expand Down
16 changes: 13 additions & 3 deletions typed/typed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestMain(m *testing.M) {
"example_array",
"example_baseTypes",
// "example_enum",
// "example_presetTypes",
"example_presetTypes",
// "mail_StructArray",
// "session_MerkleTree",
// "v1Nested",
Expand Down Expand Up @@ -284,6 +284,11 @@ func TestGetMessageHash(t *testing.T) {
Address: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826",
ExpectedMessageHash: "0xdb7829db8909c0c5496f5952bcfc4fc894341ce01842537fc4f448743480b6",
},
{
TypedData: typedDataExamples["example_presetTypes"],
Address: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826",
ExpectedMessageHash: "0x185b339d5c566a883561a88fb36da301051e2c0225deb325c91bb7aa2f3473a",
},
// {
// TypedData: typedDataExamples["session_MerkleTree"],
// Address: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826",
Expand Down Expand Up @@ -402,9 +407,14 @@ func TestEncodeType(t *testing.T) {
TypeName: "Example",
ExpectedEncode: `"Example"("n0":"felt","n1":"bool","n2":"string","n3":"selector","n4":"u128","n5":"i128","n6":"ContractAddress","n7":"ClassHash","n8":"timestamp","n9":"shortstring")`,
},
{
TypedData: typedDataExamples["example_presetTypes"],
TypeName: "Example",
ExpectedEncode: `"Example"("n0":"TokenAmount","n1":"NftId")"NftId"("collection_address":"ContractAddress","token_id":"u256")"TokenAmount"("token_address":"ContractAddress","amount":"u256")"u256"("low":"u128","high":"u128")`,
},
}
for _, test := range testSet {
encode, err := encodeType(test.TypeName, test.TypedData.Types, test.TypedData.Revision.Version())
encode, err := encodeType(test.TypeName, test.TypedData.Types, test.TypedData.Revision)
require.NoError(t, err)

require.Equal(t, test.ExpectedEncode, encode)
Expand Down Expand Up @@ -443,7 +453,7 @@ func TestGetStructHash(t *testing.T) {
{
TypedData: typedDataExamples["example_baseTypes"],
TypeName: "Example",
ExpectedHash: "0x2288b5f74a05d6e2f2efea4e2275a7fdfff532707e6ba77187c14ea84f1b778",
ExpectedHash: "0x75db031c1f5bf980cc48f46943b236cb85a95c8f3b3c8203572453075d3d39",
},
}
for _, test := range testSet {
Expand Down

0 comments on commit f081714

Please sign in to comment.