From 1b5d956038986f146eff0b77e5124e8ed889cf8f Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Fri, 13 Dec 2024 01:21:33 -0300 Subject: [PATCH] Improves 'chainId' validation --- typedData/typedData.go | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/typedData/typedData.go b/typedData/typedData.go index c52d2d88..57b7e986 100644 --- a/typedData/typedData.go +++ b/typedData/typedData.go @@ -852,7 +852,7 @@ func (domain *Domain) UnmarshalJSON(data []byte) error { getField := func(fieldName string) (string, error) { value, ok := dec[fieldName] if !ok { - return "", fmt.Errorf("error getting value of '%s' from 'domain' struct", fieldName) + return "", fmt.Errorf("error getting the value of '%s' from 'domain' struct", fieldName) } return fmt.Sprintf("%v", value), nil } @@ -867,16 +867,6 @@ func (domain *Domain) UnmarshalJSON(data []byte) error { return err } - chainId, err := getField("chainId") - if err != nil { - var err2 error - // ref: https://community.starknet.io/t/signing-transactions-and-off-chain-messages/66 - chainId, err2 = getField("chain_id") - if err2 != nil { - return err - } - } - revision, err := getField("revision") if err != nil { revision = "0" @@ -886,6 +876,19 @@ func (domain *Domain) UnmarshalJSON(data []byte) error { return err } + chainId, err := getField("chainId") + if err != nil { + if numRevision == 1 { + return err + } + var err2 error + // ref: https://community.starknet.io/t/signing-transactions-and-off-chain-messages/66 + chainId, err2 = getField("chain_id") + if err2 != nil { + return fmt.Errorf("%w: %w", err, err2) + } + } + *domain = Domain{ Name: name, Version: version,