diff --git a/copy.go b/copy.go index 31b9687..e048af7 100644 --- a/copy.go +++ b/copy.go @@ -113,6 +113,12 @@ func structToStruct(filter FieldFilter, src, dst *reflect.Value, userOptions *op return errors.Errorf("dst type is %s, expected: %s ", dst.Type(), "*any.Any") } + // If subfilter is empty then copy the entire any without any unmarshalling. + if filter.IsEmpty() && !userOptions.UnmarshalAllAny { + dst.Set(*src) + break + } + srcProto, err := srcAny.UnmarshalNew() if err != nil { return errors.WithStack(err) @@ -250,6 +256,12 @@ type options struct { // It is called before copying the data from source to destination allowing custom processing. // If the visitor function returns true the visited field is skipped. MapVisitor mapVisitor + + // UnmarshalAllAny is used to indicate unmarshal all any fields. Default to true to keep backward compatibility. + // + // If an any field is encountered and this flag is not set, it will only Unmarshal it if there is a subfilter for that field. + // If set it will always Unmarshal all any fields + UnmarshalAllAny bool } // mapVisitor is called for every filtered field in structToMap. @@ -293,9 +305,18 @@ func WithMapVisitor(visitor mapVisitor) Option { } } +func WithUnmarshalAllAny(unmarshal bool) Option { + return func(o *options) { + o.UnmarshalAllAny = unmarshal + } +} + func newDefaultOptions() *options { // set default CopyListSize is func which return src.Len() - return &options{CopyListSize: func(src *reflect.Value) int { return src.Len() }} + return &options{ + CopyListSize: func(src *reflect.Value) int { return src.Len() }, + UnmarshalAllAny: true, + } } // fieldName gets the field name according to the field's tag, or gets StructField.Name default when the field's tag is empty. @@ -334,7 +355,7 @@ func structToMap(filter FieldFilter, src, dst reflect.Value, userOptions *option } srcType := src.Type() for i := 0; i < src.NumField(); i++ { - srcName := fieldName(userOptions.SrcTag, srcType.Field(i)) + srcName := fieldName(userOptions.SrcTag, srcType.Field(i)) if !isExported(srcType.Field(i)) { // Unexported fields can not be copied. continue diff --git a/copy_proto_test.go b/copy_proto_test.go index 720ba14..ed61ad5 100644 --- a/copy_proto_test.go +++ b/copy_proto_test.go @@ -277,6 +277,71 @@ func TestStructToStruct_NonProtoFail(t *testing.T) { assert.NotNil(t, err) } +func TestStructToStruct_UnknownAnyInSrcNoSubfieldMask(t *testing.T) { + userWithUnknown := &testproto.User{ + Details: []*anypb.Any{ + { + TypeUrl: "example.com/example/UnknownType", + Value: []byte("unknown"), + }, + }, + } + emptyUser := &testproto.User{} + + mask := fieldmask_utils.MaskFromString("Details") + err := fieldmask_utils.StructToStruct(mask, userWithUnknown, emptyUser, fieldmask_utils.WithUnmarshalAllAny(false)) + assert.NoError(t, err) + assert.Equal(t, userWithUnknown.Details, emptyUser.Details) +} + +func TestStructToStruct_UnknownAnyInDstNoSubfieldMask(t *testing.T) { + userWithUnknown := &testproto.User{ + Details: []*anypb.Any{ + { + TypeUrl: "example.com/example/UnknownType", + Value: []byte("unknown"), + }, + }, + } + emptyUser := &testproto.User{} + + mask := fieldmask_utils.MaskFromString("Details") + err := fieldmask_utils.StructToStruct(mask, emptyUser, userWithUnknown, fieldmask_utils.WithUnmarshalAllAny(false)) + assert.NoError(t, err) + assert.Equal(t, userWithUnknown.Details, emptyUser.Details) +} + +func TestStructToStruct_UnknownAnyDefault(t *testing.T) { + userWithUnknown := &testproto.User{ + Details: []*anypb.Any{ + { + TypeUrl: "example.com/example/UnknownType", + Value: []byte("unknown"), + }, + }, + } + emptyUser := &testproto.User{} + + mask := fieldmask_utils.MaskFromString("Details") + err := fieldmask_utils.StructToStruct(mask, userWithUnknown, emptyUser) + assert.Contains(t, err.Error(), "not found") +} + +func TestStructToStruct_UnknownAnySubfieldMask(t *testing.T) { + userWithUnknown := &testproto.User{ + Details: []*anypb.Any{ + { + TypeUrl: "example.com/example/UnknownType", + Value: []byte("unknown"), + }, + }, + } + emptyUser := &testproto.User{} + + mask := fieldmask_utils.MaskFromString("Details{Id}") + err := fieldmask_utils.StructToStruct(mask, userWithUnknown, emptyUser, fieldmask_utils.WithUnmarshalAllAny(false)) + assert.Contains(t, err.Error(), "not found") +} func TestStructToMap_Success(t *testing.T) { userDst := make(map[string]interface{}) mask := fieldmask_utils.MaskFromString(