diff --git a/README.md b/README.md index 3f919a0..474ccd8 100644 --- a/README.md +++ b/README.md @@ -70,8 +70,8 @@ fieldmask_utils.StructToStruct(mask, request.User, userDst) 2. Masks inside a protobuf `Map` are not supported. 3. When copying from a struct to struct the destination struct must have the same fields (or a subset) - as the source struct. Pointers must also be coherent: if a field is a pointer in the source struct, then - it also must be a pointer (not a value field) in the destination struct. + as the source struct. Either of source or destination fields can be a pointer as long as it is a pointer to + the type of the corresponding field. 4. `oneof` fields are represented differently in `fieldmaskpb.FieldMask` compared to `fieldmask_util.Mask`. In [FieldMask](https://pkg.go.dev/google.golang.org/protobuf/types/known/fieldmaskpb#:~:text=%23%20Field%20Masks%20and%20Oneof%20Fields) the fields are represented using their property name, in this library they are prefixed with the `oneof` name diff --git a/copy.go b/copy.go index 9ec304f..89e30b0 100644 --- a/copy.go +++ b/copy.go @@ -34,9 +34,24 @@ func StructToStruct(filter FieldFilter, src, dst interface{}, userOpts ...Option return structToStruct(filter, &srcVal, &dstVal, opts) } +func ensureCompatible(src, dst *reflect.Value) error { + srcKind := src.Kind() + if srcKind == reflect.Ptr { + srcKind = src.Type().Elem().Kind() + } + dstKind := dst.Kind() + if dstKind == reflect.Ptr { + dstKind = dst.Type().Elem().Kind() + } + if srcKind != dstKind { + return errors.Errorf("src kind %s differs from dst kind %s", srcKind, dstKind) + } + return nil +} + func structToStruct(filter FieldFilter, src, dst *reflect.Value, userOptions *options) error { - if src.Kind() != dst.Kind() { - return errors.Errorf("src kind %s differs from dst kind %s", src.Kind(), dst.Kind()) + if err := ensureCompatible(src, dst); err != nil { + return err } switch src.Kind() { @@ -46,6 +61,14 @@ func structToStruct(filter FieldFilter, src, dst *reflect.Value, userOptions *op return nil } + if dst.Kind() == reflect.Ptr { + if dst.IsNil() { + dst.Set(reflect.New(dst.Type().Elem())) + } + v := dst.Elem() + dst = &v + } + for i := 0; i < src.NumField(); i++ { srcType := src.Type() fieldName := srcType.Field(i).Name @@ -78,7 +101,7 @@ func structToStruct(filter FieldFilter, src, dst *reflect.Value, userOptions *op dst.Set(reflect.Zero(dst.Type())) break } - if dst.IsNil() { + if dst.Kind() == reflect.Ptr && dst.IsNil() { // If dst is nil create a new instance of the underlying type and set dst to the pointer of that instance. dst.Set(reflect.New(dst.Type().Elem())) } @@ -117,7 +140,11 @@ func structToStruct(filter FieldFilter, src, dst *reflect.Value, userOptions *op break } - srcElem, dstElem := src.Elem(), dst.Elem() + srcElem, dstElem := src.Elem(), *dst + if dst.Kind() == reflect.Ptr { + dstElem = dst.Elem() + } + if err := structToStruct(filter, &srcElem, &dstElem, userOptions); err != nil { return err } @@ -185,7 +212,14 @@ func structToStruct(filter FieldFilter, src, dst *reflect.Value, userOptions *op if !dst.CanSet() { return errors.Errorf("dst %s, %s is not settable", dst, dst.Type()) } - dst.Set(*src) + if dst.Kind() == reflect.Ptr { + if !src.CanAddr() { + return errors.Errorf("src %s, %s is not addressable", src, src.Type()) + } + dst.Set(src.Addr()) + } else { + dst.Set(*src) + } } return nil diff --git a/copy_test.go b/copy_test.go index 995e30c..0460ba7 100644 --- a/copy_test.go +++ b/copy_test.go @@ -47,6 +47,104 @@ func TestStructToStruct_PtrToInt(t *testing.T) { }, dst) } +func TestStructToStruct_StructToPointer(t *testing.T) { + v15 := 15 + v42 := 42 + + type N struct { + Field1 int + } + type S struct { + Field1 N + Field2 int + } + src := &S{ + Field1: N{ + Field1: v15, + }, + Field2: v42, + } + type SN struct { + Field1 *int + } + type D struct { + Field1 *SN + Field2 *int + } + dst := new(D) + + mask := fieldmask_utils.MaskFromString("Field1,Field2") + err := fieldmask_utils.StructToStruct(mask, src, dst) + require.NoError(t, err) + assert.Equal(t, &D{ + Field1: &SN{ + Field1: &v15, + }, + Field2: &v42, + }, dst) +} + +func TestStructToStruct_IntToPointer(t *testing.T) { + v := 42 + + type S struct { + Field2 int + } + src := &S{ + Field2: v, + } + type D struct { + Field2 *int + } + dst := new(D) + + mask := fieldmask_utils.MaskFromString("Field2") + err := fieldmask_utils.StructToStruct(mask, src, dst) + require.NoError(t, err) + assert.Equal(t, &D{ + Field2: &v, + }, dst) +} + +func TestStructToStruct_PointerToInt(t *testing.T) { + v := 42 + + type S struct { + Field2 *int + } + src := &S{ + Field2: &v, + } + type D struct { + Field2 int + } + dst := new(D) + + mask := fieldmask_utils.MaskFromString("Field2") + err := fieldmask_utils.StructToStruct(mask, src, dst) + require.NoError(t, err) + assert.Equal(t, &D{ + Field2: 42, + }, dst) +} + +func TestStructToStruct_Incompatible(t *testing.T) { + type S struct { + Field2 int + } + src := &S{ + Field2: 42, + } + type D struct { + Field2 string + } + dst := new(D) + + mask := fieldmask_utils.MaskFromString("Field2") + err := fieldmask_utils.StructToStruct(mask, src, dst) + require.EqualError(t, err, "src kind int differs from dst kind string") +} + func TestStructToStruct_PtrToStruct_EmptyDst(t *testing.T) { type A struct { Field1 string