diff --git a/typed/revision.go b/typed/revision.go index 0c60d5f3..7cb0e009 100644 --- a/typed/revision.go +++ b/typed/revision.go @@ -70,6 +70,7 @@ func init() { } type revision struct { + //TODO: create a enum version uint8 domain string hashMethod func(felts ...*felt.Felt) *felt.Felt diff --git a/typed/typed.go b/typed/typed.go index a27fb925..b9584c42 100644 --- a/typed/typed.go +++ b/typed/typed.go @@ -109,7 +109,7 @@ func NewTypedData(types []TypeDefinition, primaryType string, domain Domain, mes } for k, v := range td.Types { - enc, err := getTypeHash(k, td.Types) + enc, err := getTypeHash(k, td.Types, td.Revision.Version()) if err != nil { return td, fmt.Errorf("error encoding type hash: %s %w", k, err) } @@ -207,11 +207,11 @@ func shortGetStructHash( // - err: any error if any func (td *TypedData) GetTypeHash(typeName string) (ret *felt.Felt, err error) { //TODO: create/update methods descriptions - return getTypeHash(typeName, td.Types) + return getTypeHash(typeName, td.Types, td.Revision.Version()) } -func getTypeHash(typeName string, types map[string]TypeDefinition) (ret *felt.Felt, err error) { - enc, err := encodeType(typeName, types) +func getTypeHash(typeName string, types map[string]TypeDefinition, revisionVersion uint8) (ret *felt.Felt, err error) { + enc, err := encodeType(typeName, types, revisionVersion) if err != nil { return ret, err } @@ -225,25 +225,29 @@ func getTypeHash(typeName string, types map[string]TypeDefinition) (ret *felt.Fe // Returns: // - enc: the encoded type // - err: any error if any -func encodeType(typeName string, types map[string]TypeDefinition) (enc string, err error) { +func encodeType(typeName string, types map[string]TypeDefinition, revisionVersion uint8) (enc string, err error) { customTypesEncodeResp := make(map[string]string) var getEncodeType func(typeName string, typeDef TypeDefinition) (result string, err error) getEncodeType = func(typeName string, typeDef TypeDefinition) (result string, err error) { var buf bytes.Buffer + quotationMark := "" + if revisionVersion == 1 { + quotationMark = `"` + } - buf.WriteString(typeName) + buf.WriteString(quotationMark + typeName + quotationMark) buf.WriteString("(") var ok bool for i, param := range typeDef.Parameters { - buf.WriteString(fmt.Sprintf("%s:%s", param.Name, param.Type)) + buf.WriteString(fmt.Sprintf(quotationMark+"%s"+quotationMark+":"+quotationMark+"%s"+quotationMark, param.Name, param.Type)) if i != (len(typeDef.Parameters) - 1) { buf.WriteString(",") } // e.g.: "felt" or "felt*" - if slices.Contains(REVISION_0_TYPES, param.Type) || slices.Contains(REVISION_0_TYPES, fmt.Sprintf("%s*", param.Type)) { + if isStandardType(param.Type) { continue } else if _, ok = customTypesEncodeResp[param.Type]; !ok { var customTypeDef TypeDefinition diff --git a/typed/typed_test.go b/typed/typed_test.go index c8a14929..15c4ec95 100644 --- a/typed/typed_test.go +++ b/typed/typed_test.go @@ -354,24 +354,34 @@ func TestEncodeType(t *testing.T) { type testSetType struct { TypeName string ExpectedEncode string + Revision revision } testSet := []testSetType{ // revision 0 { TypeName: "StarkNetDomain", ExpectedEncode: "StarkNetDomain(name:felt,version:felt,chainId:felt)", + Revision: RevisionV0, }, { - TypeName: "Person", - ExpectedEncode: "Person(name:felt,wallet:felt)", + TypeName: "Mail", + ExpectedEncode: "Mail(from:Person,to:Person,contents:felt)Person(name:felt,wallet:felt)", + Revision: RevisionV0, + }, + // revision 1 + { + TypeName: "StarkNetDomain", + ExpectedEncode: `"StarkNetDomain"("name":"felt","version":"felt","chainId":"felt")`, + Revision: RevisionV1, }, { TypeName: "Mail", - ExpectedEncode: "Mail(from:Person,to:Person,contents:felt)Person(name:felt,wallet:felt)", + ExpectedEncode: `"Mail"("from":"Person","to":"Person","contents":"felt")"Person"("name":"felt","wallet":"felt")`, + Revision: RevisionV1, }, } for _, test := range testSet { - encode, err := encodeType(test.TypeName, ttd.Types) + encode, err := encodeType(test.TypeName, ttd.Types, test.Revision.Version()) require.NoError(err) require.Equal(test.ExpectedEncode, encode)