Skip to content

Commit

Permalink
Adds revision 1 support to encodeType
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagodeev committed Nov 24, 2024
1 parent 4720826 commit 9ba0234
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 12 deletions.
1 change: 1 addition & 0 deletions typed/revision.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func init() {
}

type revision struct {
//TODO: create a enum
version uint8
domain string
hashMethod func(felts ...*felt.Felt) *felt.Felt
Expand Down
20 changes: 12 additions & 8 deletions typed/typed.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func NewTypedData(types []TypeDefinition, primaryType string, domain Domain, mes
}

for k, v := range td.Types {
enc, err := getTypeHash(k, td.Types)
enc, err := getTypeHash(k, td.Types, td.Revision.Version())
if err != nil {
return td, fmt.Errorf("error encoding type hash: %s %w", k, err)
}
Expand Down Expand Up @@ -207,11 +207,11 @@ func shortGetStructHash(
// - err: any error if any
func (td *TypedData) GetTypeHash(typeName string) (ret *felt.Felt, err error) {
//TODO: create/update methods descriptions
return getTypeHash(typeName, td.Types)
return getTypeHash(typeName, td.Types, td.Revision.Version())
}

func getTypeHash(typeName string, types map[string]TypeDefinition) (ret *felt.Felt, err error) {
enc, err := encodeType(typeName, types)
func getTypeHash(typeName string, types map[string]TypeDefinition, revisionVersion uint8) (ret *felt.Felt, err error) {
enc, err := encodeType(typeName, types, revisionVersion)
if err != nil {
return ret, err
}
Expand All @@ -225,25 +225,29 @@ func getTypeHash(typeName string, types map[string]TypeDefinition) (ret *felt.Fe
// Returns:
// - enc: the encoded type
// - err: any error if any
func encodeType(typeName string, types map[string]TypeDefinition) (enc string, err error) {
func encodeType(typeName string, types map[string]TypeDefinition, revisionVersion uint8) (enc string, err error) {
customTypesEncodeResp := make(map[string]string)

var getEncodeType func(typeName string, typeDef TypeDefinition) (result string, err error)
getEncodeType = func(typeName string, typeDef TypeDefinition) (result string, err error) {
var buf bytes.Buffer
quotationMark := ""
if revisionVersion == 1 {
quotationMark = `"`
}

buf.WriteString(typeName)
buf.WriteString(quotationMark + typeName + quotationMark)
buf.WriteString("(")

var ok bool

for i, param := range typeDef.Parameters {
buf.WriteString(fmt.Sprintf("%s:%s", param.Name, param.Type))
buf.WriteString(fmt.Sprintf(quotationMark+"%s"+quotationMark+":"+quotationMark+"%s"+quotationMark, param.Name, param.Type))
if i != (len(typeDef.Parameters) - 1) {
buf.WriteString(",")
}
// e.g.: "felt" or "felt*"
if slices.Contains(REVISION_0_TYPES, param.Type) || slices.Contains(REVISION_0_TYPES, fmt.Sprintf("%s*", param.Type)) {
if isStandardType(param.Type) {
continue
} else if _, ok = customTypesEncodeResp[param.Type]; !ok {
var customTypeDef TypeDefinition
Expand Down
18 changes: 14 additions & 4 deletions typed/typed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,24 +354,34 @@ func TestEncodeType(t *testing.T) {
type testSetType struct {
TypeName string
ExpectedEncode string
Revision revision
}
testSet := []testSetType{
// revision 0
{
TypeName: "StarkNetDomain",
ExpectedEncode: "StarkNetDomain(name:felt,version:felt,chainId:felt)",
Revision: RevisionV0,
},
{
TypeName: "Person",
ExpectedEncode: "Person(name:felt,wallet:felt)",
TypeName: "Mail",
ExpectedEncode: "Mail(from:Person,to:Person,contents:felt)Person(name:felt,wallet:felt)",
Revision: RevisionV0,
},
// revision 1
{
TypeName: "StarkNetDomain",
ExpectedEncode: `"StarkNetDomain"("name":"felt","version":"felt","chainId":"felt")`,
Revision: RevisionV1,
},
{
TypeName: "Mail",
ExpectedEncode: "Mail(from:Person,to:Person,contents:felt)Person(name:felt,wallet:felt)",
ExpectedEncode: `"Mail"("from":"Person","to":"Person","contents":"felt")"Person"("name":"felt","wallet":"felt")`,
Revision: RevisionV1,
},
}
for _, test := range testSet {
encode, err := encodeType(test.TypeName, ttd.Types)
encode, err := encodeType(test.TypeName, ttd.Types, test.Revision.Version())
require.NoError(err)

require.Equal(test.ExpectedEncode, encode)
Expand Down

0 comments on commit 9ba0234

Please sign in to comment.