Skip to content

Commit

Permalink
Ignore nested fields that are not recursively embedded (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
vearutop authored Jan 25, 2021
1 parent 605b794 commit 753cad9
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 27 deletions.
1 change: 0 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ github.com/Masterminds/squirrel v1.5.0/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA4
github.com/bool64/ctxd v0.1.3 h1:n+aQ6UdoZXlltETFXqIhZR2DoMxhMYE2CW9BQn8aNBY=
github.com/bool64/ctxd v0.1.3/go.mod h1:rhUkoNE4mKFSJmo9l+78u2j+FVQifRCj0MHRhyZ2GDA=
github.com/bool64/dev v0.1.0/go.mod h1:pn52JC52uSgpazChx9CeXyG+S3sW2V36HHoLNBbscdg=
github.com/bool64/dev v0.1.10 h1:4L6eLD+qo1QgWDy+Y7OhJxi/gLwOAuV1rd07noMc3dU=
github.com/bool64/dev v0.1.10/go.mod h1:pn52JC52uSgpazChx9CeXyG+S3sW2V36HHoLNBbscdg=
github.com/bool64/dev v0.1.12 h1:mbuWBtCtOGwqt2lN1/9oPmn70XaOHdv2YHzQ31Zf9ks=
github.com/bool64/dev v0.1.12/go.mod h1:pn52JC52uSgpazChx9CeXyG+S3sW2V36HHoLNBbscdg=
Expand Down
65 changes: 61 additions & 4 deletions mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sqluct
import (
"errors"
"reflect"
"sync"

"github.com/Masterminds/squirrel"
"github.com/jmoiron/sqlx/reflectx"
Expand All @@ -17,9 +18,15 @@ var (
// Mapper prepares select, insert and update statements.
type Mapper struct {
ReflectMapper *reflectx.Mapper

mu sync.Mutex
types map[reflect.Type]*reflectx.StructMap
}

var reflectMapper = reflectx.NewMapper("db")
var (
reflectMapper = reflectx.NewMapper("db")
defaultMapper = &Mapper{}
)

// SkipZeroValues instructs mapper to ignore fields with zero values.
func SkipZeroValues(o *Options) {
Expand Down Expand Up @@ -170,7 +177,7 @@ func (sm *Mapper) colType(v reflect.Value, options ...func(*Options)) (*reflectx
panic("struct or slice/array of struct expected in sql query mapper")
}

tm := sm.reflectMapper().TypeMap(t)
tm := sm.typeMap(t)

return tm, o, skipValues
}
Expand Down Expand Up @@ -263,7 +270,7 @@ func (sm *Mapper) FindColumnName(structPtr, fieldPtr interface{}) (string, error
return "", errNotAPointer
}

tm := sm.reflectMapper().TypeMap(t)
tm := sm.typeMap(t)
for _, fi := range tm.Index {
fv := reflectx.FieldByIndexesReadOnly(v, fi.Index)
if fv.Addr().Interface() == fieldPtr {
Expand All @@ -274,6 +281,56 @@ func (sm *Mapper) FindColumnName(structPtr, fieldPtr interface{}) (string, error
return "", errFieldNotFound
}

func (sm *Mapper) typeMap(t reflect.Type) *reflectx.StructMap {
if sm == nil {
sm = defaultMapper
}

sm.mu.Lock()
defer sm.mu.Unlock()

tm, found := sm.types[t]
if found {
return tm
}

tm = sm.reflectMapper().TypeMap(t)
index := make([]*reflectx.FieldInfo, 0, len(tm.Index))

for _, fi := range tm.Index {
skip := false
p := fi.Parent

// Field is allowed to be a column if does not have a named parent (with non-empty path)
// or all parents are embedded.
for p != nil && p.Path != "" {
if !p.Embedded {
skip = true

break
}

p = p.Parent
}

if skip {
continue
}

index = append(index, fi)
}

tm.Index = index

if sm.types == nil {
sm.types = make(map[reflect.Type]*reflectx.StructMap, 1)
}

sm.types[t] = tm

return tm
}

// FindColumnNames returns column names mapped by a pointer to a field.
func (sm *Mapper) FindColumnNames(structPtr interface{}) (map[interface{}]string, error) {
if structPtr == nil {
Expand All @@ -289,7 +346,7 @@ func (sm *Mapper) FindColumnNames(structPtr interface{}) (map[interface{}]string

res := make(map[interface{}]string)

tm := sm.reflectMapper().TypeMap(t)
tm := sm.typeMap(t)
for _, fi := range tm.Index {
fv := reflectx.FieldByIndexesReadOnly(v, fi.Index)
res[fv.Addr().Interface()] = fi.Name
Expand Down
84 changes: 62 additions & 22 deletions mapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,36 @@ import (

type (
Sample struct {
A int `db:"a"`
A int `db:"a"`
DeeplyEmbedded // Recursively embedded fields are used as root fields.
Meta AnotherRow `db:"meta"` // Meta is a column, but its fields are not.
}

DeeplyEmbedded struct {
SampleEmbedded
E string `db:"e"`
}

SampleEmbedded struct {
B float64 `db:"b"`
C string `db:"c"`
}

AnotherRow struct {
SampleEmbedded // These embedded fields won't show up in Sample statements.
D string `db:"d"` // This field won't show up in Sample statements.
}
)

func TestInsertValue(t *testing.T) {
z := Sample{
A: 1,
SampleEmbedded: SampleEmbedded{
B: 2.2,
C: "3",
DeeplyEmbedded: DeeplyEmbedded{
SampleEmbedded: SampleEmbedded{
B: 2.2,
C: "3",
},
E: "e!",
},
}

Expand All @@ -35,8 +49,8 @@ func TestInsertValue(t *testing.T) {
q := sm.Insert(ps.Insert("sample"), z)
query, args, err := q.ToSql()
assert.NoError(t, err)
assert.Equal(t, "INSERT INTO sample (a,b,c) VALUES ($1,$2,$3)", query)
assert.Equal(t, []interface{}{1, 2.2, "3"}, args)
assert.Equal(t, "INSERT INTO sample (a,meta,e,b,c) VALUES ($1,$2,$3,$4,$5)", query)
assert.Equal(t, []interface{}{1, AnotherRow{SampleEmbedded: SampleEmbedded{B: 0, C: ""}, D: ""}, "e!", 2.2, "3"}, args)
}

func TestMapper_Insert_nil(t *testing.T) {
Expand All @@ -61,16 +75,22 @@ func TestInsertValueSlice(t *testing.T) {
z := []Sample{
{
A: 1,
SampleEmbedded: SampleEmbedded{
B: 2.2,
C: "3",
DeeplyEmbedded: DeeplyEmbedded{
SampleEmbedded: SampleEmbedded{
B: 2.2,
C: "3",
},
E: "e!",
},
},
{
A: 4,
SampleEmbedded: SampleEmbedded{
B: 5.5,
C: "6",
DeeplyEmbedded: DeeplyEmbedded{
SampleEmbedded: SampleEmbedded{
B: 5.5,
C: "6",
},
E: "ee!",
},
},
}
Expand All @@ -82,24 +102,37 @@ func TestInsertValueSlice(t *testing.T) {
q = sm.Insert(q, z)
query, args, err := q.ToSql()
assert.NoError(t, err)
assert.Equal(t, "INSERT INTO sample (a,b,c) VALUES ($1,$2,$3),($4,$5,$6)", query)
assert.Equal(t, []interface{}{1, 2.2, "3", 4, 5.5, "6"}, args)
assert.Equal(t, "INSERT INTO sample (a,meta,e,b,c) VALUES ($1,$2,$3,$4,$5),($6,$7,$8,$9,$10)", query)
assert.Equal(t, []interface{}{
1,
AnotherRow{SampleEmbedded: SampleEmbedded{B: 0, C: ""}, D: ""},
"e!", 2.2, "3",
4,
AnotherRow{SampleEmbedded: SampleEmbedded{B: 0, C: ""}, D: ""},
"ee!", 5.5, "6",
}, args)
}

func TestInsertValueSlicePtr(t *testing.T) {
z := []Sample{
{
A: 1,
SampleEmbedded: SampleEmbedded{
B: 2.2,
C: "3",
DeeplyEmbedded: DeeplyEmbedded{
SampleEmbedded: SampleEmbedded{
B: 2.2,
C: "3",
},
E: "e!",
},
},
{
A: 4,
SampleEmbedded: SampleEmbedded{
B: 5.5,
C: "6",
DeeplyEmbedded: DeeplyEmbedded{
SampleEmbedded: SampleEmbedded{
B: 5.5,
C: "6",
},
E: "ee!",
},
},
}
Expand All @@ -109,8 +142,15 @@ func TestInsertValueSlicePtr(t *testing.T) {
q := sm.Insert(ps.Insert("sample"), z)
query, args, err := q.ToSql()
assert.NoError(t, err)
assert.Equal(t, "INSERT INTO sample (a,b,c) VALUES ($1,$2,$3),($4,$5,$6)", query)
assert.Equal(t, []interface{}{1, 2.2, "3", 4, 5.5, "6"}, args)
assert.Equal(t, "INSERT INTO sample (a,meta,e,b,c) VALUES ($1,$2,$3,$4,$5),($6,$7,$8,$9,$10)", query)
assert.Equal(t, []interface{}{
1,
AnotherRow{SampleEmbedded: SampleEmbedded{B: 0, C: ""}, D: ""},
"e!", 2.2, "3",
4,
AnotherRow{SampleEmbedded: SampleEmbedded{B: 0, C: ""}, D: ""},
"ee!", 5.5, "6",
}, args)
}

func TestMapper_Update(t *testing.T) {
Expand Down

0 comments on commit 753cad9

Please sign in to comment.