From 3defce44b45d69a2678f9d57a3e06667f83acb06 Mon Sep 17 00:00:00 2001 From: Geoff Watson Date: Mon, 18 Mar 2024 20:23:45 -0700 Subject: [PATCH] Add support for nested message overwrite --- fmutils.go | 65 ++++++++++++++++++++-------- fmutils_test.go | 113 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 159 insertions(+), 19 deletions(-) diff --git a/fmutils.go b/fmutils.go index d73cdcd..c942e23 100644 --- a/fmutils.go +++ b/fmutils.go @@ -170,30 +170,57 @@ func (mask NestedMask) Overwrite(src, dest proto.Message) { mask.overwrite(src.ProtoReflect(), dest.ProtoReflect()) } -func (mask NestedMask) overwrite(src, dest protoreflect.Message) { - for k, v := range mask { - srcFD := src.Descriptor().Fields().ByName(protoreflect.Name(k)) - destFD := dest.Descriptor().Fields().ByName(protoreflect.Name(k)) - if srcFD == nil || destFD == nil { - continue - } - - // Leaf mask -> copy value from src to dest - if len(v) == 0 { - if srcFD.Kind() == destFD.Kind() { // TODO: Full type equality check - val := src.Get(srcFD) - if isValid(srcFD, val) { - dest.Set(destFD, val) +func (mask NestedMask) overwrite(srcRft, destRft protoreflect.Message) { + for srcFDName, submask := range mask { + srcFD := srcRft.Descriptor().Fields().ByName(protoreflect.Name(srcFDName)) + srcVal := srcRft.Get(srcFD) + if len(submask) == 0 { + if isValid(srcFD, srcVal) { + destRft.Set(srcFD, srcVal) + } else { + destRft.Clear(srcFD) + } + } else if srcFD.IsMap() && srcFD.Kind() == protoreflect.MessageKind { + srcMap := srcRft.Get(srcFD).Map() + destMap := destRft.Get(srcFD).Map() + srcMap.Range(func(mk protoreflect.MapKey, mv protoreflect.Value) bool { + if mi, ok := submask[mk.String()]; ok { + if i, ok := mv.Interface().(protoreflect.Message); ok && len(mi) > 0 { + destVal := protoreflect.ValueOf(mv) + destMap.Set(mk, destVal) + mi.overwrite(i, destVal.Message()) + } else { + destMap.Set(mk, mv) + } + } + return true + }) + } else if srcFD.IsList() && srcFD.Kind() == protoreflect.MessageKind { + srcList := srcRft.Get(srcFD).List() + destList := destRft.Mutable(srcFD).List() + // Truncate anything in dest that exceeds the length of src + if srcList.Len() < destList.Len() { + destList.Truncate(srcList.Len()) + } + for i := 0; i < srcList.Len(); i++ { + srcListItem := srcList.Get(i) + var destListItem protoreflect.Message + if destList.Len() > i { + // Overwrite existing items. + destListItem = destList.Get(i).Message() } else { - dest.Clear(destFD) + // Append new items to overwrite. + destListItem = destList.AppendMutable().Message() } + submask.overwrite(srcListItem.Message(), destListItem) } + } else if srcFD.Kind() == protoreflect.MessageKind { - // If dest field is nil - if !dest.Get(destFD).Message().IsValid() { - dest.Set(destFD, protoreflect.ValueOf(dest.Get(destFD).Message().New())) + // If the dest field is nil + if !destRft.Get(srcFD).Message().IsValid() { + destRft.Set(srcFD, protoreflect.ValueOf(destRft.Get(srcFD).Message().New())) } - v.overwrite(src.Get(srcFD).Message(), dest.Get(destFD).Message()) + submask.overwrite(srcRft.Get(srcFD).Message(), destRft.Get(srcFD).Message()) } } } diff --git a/fmutils_test.go b/fmutils_test.go index d3b4add..0993be7 100644 --- a/fmutils_test.go +++ b/fmutils_test.go @@ -883,6 +883,119 @@ func TestOverwrite(t *testing.T) { }, }, }, + { + name: "overwrite repeated message fields", + paths: []string{"gallery.path"}, + src: &testproto.Profile{ + User: &testproto.User{ + UserId: 567, + Name: "different-name", + }, + Photo: &testproto.Photo{ + Path: "photo-path", + }, + LoginTimestamps: []int64{1, 2, 3}, + Attributes: map[string]*testproto.Attribute{ + "src": {}, + }, + Gallery: []*testproto.Photo{ + { + PhotoId: 123, + Path: "test-path-1", + Dimensions: &testproto.Dimensions{ + Width: 345, + Height: 456, + }, + }, + { + PhotoId: 234, + Path: "test-path-2", + Dimensions: &testproto.Dimensions{ + Width: 3456, + Height: 4567, + }, + }, + { + PhotoId: 345, + Path: "test-path-3", + Dimensions: &testproto.Dimensions{ + Width: 34567, + Height: 45678, + }, + }, + }, + }, + dest: &testproto.Profile{ + User: &testproto.User{ + Name: "name", + }, + Gallery: []*testproto.Photo{ + { + PhotoId: 123, + Path: "test-path-7", + Dimensions: &testproto.Dimensions{ + Width: 345, + Height: 456, + }, + }, + { + PhotoId: 234, + Path: "test-path-6", + Dimensions: &testproto.Dimensions{ + Width: 3456, + Height: 4567, + }, + }, + { + PhotoId: 345, + Path: "test-path-5", + Dimensions: &testproto.Dimensions{ + Width: 34567, + Height: 45678, + }, + }, + { + PhotoId: 345, + Path: "test-path-4", + Dimensions: &testproto.Dimensions{ + Width: 34567, + Height: 45678, + }, + }, + }, + }, + want: &testproto.Profile{ + User: &testproto.User{ + Name: "name", + }, + Gallery: []*testproto.Photo{ + { + PhotoId: 123, + Path: "test-path-1", + Dimensions: &testproto.Dimensions{ + Width: 345, + Height: 456, + }, + }, + { + PhotoId: 234, + Path: "test-path-2", + Dimensions: &testproto.Dimensions{ + Width: 3456, + Height: 4567, + }, + }, + { + PhotoId: 345, + Path: "test-path-3", + Dimensions: &testproto.Dimensions{ + Width: 34567, + Height: 45678, + }, + }, + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {