diff --git a/typed/revision.go b/typed/revision.go index ff4f9571..4825220e 100644 --- a/typed/revision.go +++ b/typed/revision.go @@ -146,11 +146,11 @@ func getRevisionV1PresetTypes() map[string]TypeDefinition { Parameters: []TypeParameter{ { Name: "low", - Type: "U128", + Type: "u128", }, { Name: "high", - Type: "U128", + Type: "u128", }, }, }, @@ -177,3 +177,22 @@ func isStandardType(typeName string) bool { return false } + +// Check if the provided type name is a basic type defined at the SNIP 12, also validates arrays +func isBasicType(typeName string) bool { + typeName, _ = strings.CutSuffix(typeName, "*") + + if slices.Contains(revision_0_basic_types, typeName) || + slices.Contains(revision_1_basic_types, typeName) { + return true + } + + return false +} + +// Check if the provided type name is a preset type defined at the SNIP 12, also validates arrays +func isPresetType(typeName string) bool { + typeName, _ = strings.CutSuffix(typeName, "*") + + return slices.Contains(revision_1_preset_types, typeName) +} diff --git a/typed/typed.go b/typed/typed.go index 0bf50c53..170d7a84 100644 --- a/typed/typed.go +++ b/typed/typed.go @@ -83,13 +83,21 @@ func NewTypedData(types []TypeDefinition, primaryType string, domain Domain, mes } for k, v := range td.Types { - enc, err := getTypeHash(k, td.Types, td.Revision.Version()) + enc, err := getTypeHash(k, td.Types, td.Revision) if err != nil { return td, fmt.Errorf("error encoding type hash: %s %w", k, err) } v.Enconding = enc td.Types[k] = v } + for k, v := range td.Revision.types.Preset { + enc, err := getTypeHash(k, td.Revision.types.Preset, td.Revision) + if err != nil { + return td, fmt.Errorf("error encoding type hash: %s %w", k, err) + } + v.Enconding = enc + td.Revision.types.Preset[k] = v + } return td, nil } @@ -181,11 +189,11 @@ func shortGetStructHash( // - err: any error if any func (td *TypedData) GetTypeHash(typeName string) (ret *felt.Felt, err error) { //TODO: create/update methods descriptions - return getTypeHash(typeName, td.Types, td.Revision.Version()) + return getTypeHash(typeName, td.Types, td.Revision) } -func getTypeHash(typeName string, types map[string]TypeDefinition, revisionVersion uint8) (ret *felt.Felt, err error) { - enc, err := encodeType(typeName, types, revisionVersion) +func getTypeHash(typeName string, types map[string]TypeDefinition, revision *revision) (ret *felt.Felt, err error) { + enc, err := encodeType(typeName, types, revision) if err != nil { return ret, err } @@ -199,14 +207,14 @@ func getTypeHash(typeName string, types map[string]TypeDefinition, revisionVersi // Returns: // - enc: the encoded type // - err: any error if any -func encodeType(typeName string, types map[string]TypeDefinition, revisionVersion uint8) (enc string, err error) { +func encodeType(typeName string, types map[string]TypeDefinition, revision *revision) (enc string, err error) { customTypesEncodeResp := make(map[string]string) var getEncodeType func(typeName string, typeDef TypeDefinition) (result string, err error) getEncodeType = func(typeName string, typeDef TypeDefinition) (result string, err error) { var buf bytes.Buffer quotationMark := "" - if revisionVersion == 1 { + if revision.Version() == 1 { quotationMark = `"` } @@ -220,12 +228,24 @@ func encodeType(typeName string, types map[string]TypeDefinition, revisionVersio if i != (len(typeDef.Parameters) - 1) { buf.WriteString(",") } - // e.g.: "felt" or "felt*" - if isStandardType(param.Type) { + + if isBasicType(param.Type) { continue } else if _, ok = customTypesEncodeResp[param.Type]; !ok { - var customTypeDef TypeDefinition - if customTypeDef, ok = types[param.Type]; !ok { //OBS: this is wrong on V1 + if isPresetType(param.Type) { + typeDef, ok := revision.Types().Preset[param.Type] + if !ok { + return result, fmt.Errorf("error trying to get the type definition of '%s'", param.Type) + } + customTypesEncodeResp[param.Type], err = getEncodeType(param.Type, typeDef) + if err != nil { + return "", err + } + + continue + } + customTypeDef, ok := types[param.Type] + if !ok { return "", fmt.Errorf("can't parse type %s from types %v", param.Type, types) } customTypesEncodeResp[param.Type], err = getEncodeType(param.Type, customTypeDef) @@ -305,13 +325,13 @@ func encodeData( } var handleStandardTypes func(param TypeParameter, data any, rev *revision) (resp *felt.Felt, err error) - var handleObjectTypes func(typeName string, data any) (resp *felt.Felt, err error) + var handleObjectTypes func(typeDef *TypeDefinition, data any) (resp *felt.Felt, err error) var handleArrays func(param TypeParameter, data any, rev *revision) (resp *felt.Felt, err error) getData := func(key string) (any, error) { value, ok := data[key] if !ok { - return value, fmt.Errorf("error trying to get the value of the %s param", key) + return value, fmt.Errorf("error trying to get the value of the '%s' param", key) } return value, nil } @@ -330,7 +350,11 @@ func encodeData( return resp, nil case "enum": case "NftId", "TokenAmount", "u256": - resp, err := handleObjectTypes(param.Type, data) + typeDef, ok := rev.Types().Preset[param.Type] + if !ok { + return resp, fmt.Errorf("error trying to get the type definition of '%s'", param.Type) + } + resp, err := handleObjectTypes(&typeDef, data) if err != nil { return resp, err } @@ -345,21 +369,18 @@ func encodeData( return resp, fmt.Errorf("error trying to encode the data of '%s'", param.Type) } - handleObjectTypes = func(typeName string, data any) (resp *felt.Felt, err error) { + handleObjectTypes = func(typeDef *TypeDefinition, data any) (resp *felt.Felt, err error) { mapData, ok := data.(map[string]any) if !ok { - return resp, fmt.Errorf("error trying to convert the value of '%s' to an map", typeName) + return resp, fmt.Errorf("error trying to convert the value of '%s' to an map", typeDef) } - if nextTypeDef, ok := typedData.Types[typeName]; ok { - resp, err := shortGetStructHash(&nextTypeDef, typedData, mapData, context...) - if err != nil { - return resp, err - } - - return resp, nil + resp, err = shortGetStructHash(typeDef, typedData, mapData, context...) + if err != nil { + return resp, err } - return resp, fmt.Errorf("error trying to get the type definition of '%s'", typeName) + + return resp, nil } handleArrays = func(param TypeParameter, data any, rev *revision) (resp *felt.Felt, err error) { @@ -381,8 +402,12 @@ func encodeData( return rev.HashMethod(localEncode...), nil } + typeDef, ok := rev.Types().Preset[singleParamType] + if !ok { + return resp, fmt.Errorf("error trying to get the type definition of '%s'", singleParamType) + } for _, item := range dataArray { - resp, err := handleObjectTypes(singleParamType, item) + resp, err := handleObjectTypes(&typeDef, item) if err != nil { return resp, err } @@ -415,11 +440,16 @@ func encodeData( continue } - resp, err := handleObjectTypes(param.Type, localData) + nextTypeDef, ok := typedData.Types[param.Type] + if !ok { + return enc, fmt.Errorf("error trying to get the type definition of '%s'", param.Type) + } + resp, err := handleObjectTypes(&nextTypeDef, localData) if err != nil { return enc, err } enc = append(enc, resp) + } return enc, nil diff --git a/typed/typed_test.go b/typed/typed_test.go index 24861a6c..fce7649e 100644 --- a/typed/typed_test.go +++ b/typed/typed_test.go @@ -62,7 +62,7 @@ func TestMain(m *testing.M) { "example_array", "example_baseTypes", // "example_enum", - // "example_presetTypes", + "example_presetTypes", // "mail_StructArray", // "session_MerkleTree", // "v1Nested", @@ -284,6 +284,11 @@ func TestGetMessageHash(t *testing.T) { Address: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", ExpectedMessageHash: "0xdb7829db8909c0c5496f5952bcfc4fc894341ce01842537fc4f448743480b6", }, + { + TypedData: typedDataExamples["example_presetTypes"], + Address: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", + ExpectedMessageHash: "0x185b339d5c566a883561a88fb36da301051e2c0225deb325c91bb7aa2f3473a", + }, // { // TypedData: typedDataExamples["session_MerkleTree"], // Address: "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", @@ -402,9 +407,14 @@ func TestEncodeType(t *testing.T) { TypeName: "Example", ExpectedEncode: `"Example"("n0":"felt","n1":"bool","n2":"string","n3":"selector","n4":"u128","n5":"i128","n6":"ContractAddress","n7":"ClassHash","n8":"timestamp","n9":"shortstring")`, }, + { + TypedData: typedDataExamples["example_presetTypes"], + TypeName: "Example", + ExpectedEncode: `"Example"("n0":"TokenAmount","n1":"NftId")"NftId"("collection_address":"ContractAddress","token_id":"u256")"TokenAmount"("token_address":"ContractAddress","amount":"u256")"u256"("low":"u128","high":"u128")`, + }, } for _, test := range testSet { - encode, err := encodeType(test.TypeName, test.TypedData.Types, test.TypedData.Revision.Version()) + encode, err := encodeType(test.TypeName, test.TypedData.Types, test.TypedData.Revision) require.NoError(t, err) require.Equal(t, test.ExpectedEncode, encode) @@ -443,7 +453,7 @@ func TestGetStructHash(t *testing.T) { { TypedData: typedDataExamples["example_baseTypes"], TypeName: "Example", - ExpectedHash: "0x2288b5f74a05d6e2f2efea4e2275a7fdfff532707e6ba77187c14ea84f1b778", + ExpectedHash: "0x75db031c1f5bf980cc48f46943b236cb85a95c8f3b3c8203572453075d3d39", }, } for _, test := range testSet {