Skip to content

Commit

Permalink
Add support for nested message overwrite
Browse files Browse the repository at this point in the history
  • Loading branch information
Geoff Watson committed Mar 19, 2024
1 parent 3bb1eb3 commit 3defce4
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 19 deletions.
65 changes: 46 additions & 19 deletions fmutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,30 +170,57 @@ func (mask NestedMask) Overwrite(src, dest proto.Message) {
mask.overwrite(src.ProtoReflect(), dest.ProtoReflect())
}

func (mask NestedMask) overwrite(src, dest protoreflect.Message) {
for k, v := range mask {
srcFD := src.Descriptor().Fields().ByName(protoreflect.Name(k))
destFD := dest.Descriptor().Fields().ByName(protoreflect.Name(k))
if srcFD == nil || destFD == nil {
continue
}

// Leaf mask -> copy value from src to dest
if len(v) == 0 {
if srcFD.Kind() == destFD.Kind() { // TODO: Full type equality check
val := src.Get(srcFD)
if isValid(srcFD, val) {
dest.Set(destFD, val)
func (mask NestedMask) overwrite(srcRft, destRft protoreflect.Message) {
for srcFDName, submask := range mask {
srcFD := srcRft.Descriptor().Fields().ByName(protoreflect.Name(srcFDName))
srcVal := srcRft.Get(srcFD)
if len(submask) == 0 {
if isValid(srcFD, srcVal) {
destRft.Set(srcFD, srcVal)
} else {
destRft.Clear(srcFD)
}
} else if srcFD.IsMap() && srcFD.Kind() == protoreflect.MessageKind {
srcMap := srcRft.Get(srcFD).Map()
destMap := destRft.Get(srcFD).Map()
srcMap.Range(func(mk protoreflect.MapKey, mv protoreflect.Value) bool {
if mi, ok := submask[mk.String()]; ok {
if i, ok := mv.Interface().(protoreflect.Message); ok && len(mi) > 0 {
destVal := protoreflect.ValueOf(mv)
destMap.Set(mk, destVal)
mi.overwrite(i, destVal.Message())
} else {
destMap.Set(mk, mv)
}

Check warning on line 194 in fmutils.go

View check run for this annotation

Codecov / codecov/patch

fmutils.go#L184-L194

Added lines #L184 - L194 were not covered by tests
}
return true

Check warning on line 196 in fmutils.go

View check run for this annotation

Codecov / codecov/patch

fmutils.go#L196

Added line #L196 was not covered by tests
})
} else if srcFD.IsList() && srcFD.Kind() == protoreflect.MessageKind {
srcList := srcRft.Get(srcFD).List()
destList := destRft.Mutable(srcFD).List()
// Truncate anything in dest that exceeds the length of src
if srcList.Len() < destList.Len() {
destList.Truncate(srcList.Len())
}
for i := 0; i < srcList.Len(); i++ {
srcListItem := srcList.Get(i)
var destListItem protoreflect.Message
if destList.Len() > i {
// Overwrite existing items.
destListItem = destList.Get(i).Message()
} else {
dest.Clear(destFD)
// Append new items to overwrite.
destListItem = destList.AppendMutable().Message()

Check warning on line 213 in fmutils.go

View check run for this annotation

Codecov / codecov/patch

fmutils.go#L212-L213

Added lines #L212 - L213 were not covered by tests
}
submask.overwrite(srcListItem.Message(), destListItem)
}

} else if srcFD.Kind() == protoreflect.MessageKind {
// If dest field is nil
if !dest.Get(destFD).Message().IsValid() {
dest.Set(destFD, protoreflect.ValueOf(dest.Get(destFD).Message().New()))
// If the dest field is nil
if !destRft.Get(srcFD).Message().IsValid() {
destRft.Set(srcFD, protoreflect.ValueOf(destRft.Get(srcFD).Message().New()))
}
v.overwrite(src.Get(srcFD).Message(), dest.Get(destFD).Message())
submask.overwrite(srcRft.Get(srcFD).Message(), destRft.Get(srcFD).Message())
}
}
}
Expand Down
113 changes: 113 additions & 0 deletions fmutils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,119 @@ func TestOverwrite(t *testing.T) {
},
},
},
{
name: "overwrite repeated message fields",
paths: []string{"gallery.path"},
src: &testproto.Profile{
User: &testproto.User{
UserId: 567,
Name: "different-name",
},
Photo: &testproto.Photo{
Path: "photo-path",
},
LoginTimestamps: []int64{1, 2, 3},
Attributes: map[string]*testproto.Attribute{
"src": {},
},
Gallery: []*testproto.Photo{
{
PhotoId: 123,
Path: "test-path-1",
Dimensions: &testproto.Dimensions{
Width: 345,
Height: 456,
},
},
{
PhotoId: 234,
Path: "test-path-2",
Dimensions: &testproto.Dimensions{
Width: 3456,
Height: 4567,
},
},
{
PhotoId: 345,
Path: "test-path-3",
Dimensions: &testproto.Dimensions{
Width: 34567,
Height: 45678,
},
},
},
},
dest: &testproto.Profile{
User: &testproto.User{
Name: "name",
},
Gallery: []*testproto.Photo{
{
PhotoId: 123,
Path: "test-path-7",
Dimensions: &testproto.Dimensions{
Width: 345,
Height: 456,
},
},
{
PhotoId: 234,
Path: "test-path-6",
Dimensions: &testproto.Dimensions{
Width: 3456,
Height: 4567,
},
},
{
PhotoId: 345,
Path: "test-path-5",
Dimensions: &testproto.Dimensions{
Width: 34567,
Height: 45678,
},
},
{
PhotoId: 345,
Path: "test-path-4",
Dimensions: &testproto.Dimensions{
Width: 34567,
Height: 45678,
},
},
},
},
want: &testproto.Profile{
User: &testproto.User{
Name: "name",
},
Gallery: []*testproto.Photo{
{
PhotoId: 123,
Path: "test-path-1",
Dimensions: &testproto.Dimensions{
Width: 345,
Height: 456,
},
},
{
PhotoId: 234,
Path: "test-path-2",
Dimensions: &testproto.Dimensions{
Width: 3456,
Height: 4567,
},
},
{
PhotoId: 345,
Path: "test-path-3",
Dimensions: &testproto.Dimensions{
Width: 34567,
Height: 45678,
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down

0 comments on commit 3defce4

Please sign in to comment.