diff --git a/go.sum b/go.sum index 6e89490..4903c30 100644 --- a/go.sum +++ b/go.sum @@ -5,7 +5,6 @@ github.com/Masterminds/squirrel v1.5.0/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA4 github.com/bool64/ctxd v0.1.3 h1:n+aQ6UdoZXlltETFXqIhZR2DoMxhMYE2CW9BQn8aNBY= github.com/bool64/ctxd v0.1.3/go.mod h1:rhUkoNE4mKFSJmo9l+78u2j+FVQifRCj0MHRhyZ2GDA= github.com/bool64/dev v0.1.0/go.mod h1:pn52JC52uSgpazChx9CeXyG+S3sW2V36HHoLNBbscdg= -github.com/bool64/dev v0.1.10 h1:4L6eLD+qo1QgWDy+Y7OhJxi/gLwOAuV1rd07noMc3dU= github.com/bool64/dev v0.1.10/go.mod h1:pn52JC52uSgpazChx9CeXyG+S3sW2V36HHoLNBbscdg= github.com/bool64/dev v0.1.12 h1:mbuWBtCtOGwqt2lN1/9oPmn70XaOHdv2YHzQ31Zf9ks= github.com/bool64/dev v0.1.12/go.mod h1:pn52JC52uSgpazChx9CeXyG+S3sW2V36HHoLNBbscdg= diff --git a/mapper.go b/mapper.go index f609edc..580cb64 100644 --- a/mapper.go +++ b/mapper.go @@ -3,6 +3,7 @@ package sqluct import ( "errors" "reflect" + "sync" "github.com/Masterminds/squirrel" "github.com/jmoiron/sqlx/reflectx" @@ -17,9 +18,15 @@ var ( // Mapper prepares select, insert and update statements. type Mapper struct { ReflectMapper *reflectx.Mapper + + mu sync.Mutex + types map[reflect.Type]*reflectx.StructMap } -var reflectMapper = reflectx.NewMapper("db") +var ( + reflectMapper = reflectx.NewMapper("db") + defaultMapper = &Mapper{} +) // SkipZeroValues instructs mapper to ignore fields with zero values. func SkipZeroValues(o *Options) { @@ -170,7 +177,7 @@ func (sm *Mapper) colType(v reflect.Value, options ...func(*Options)) (*reflectx panic("struct or slice/array of struct expected in sql query mapper") } - tm := sm.reflectMapper().TypeMap(t) + tm := sm.typeMap(t) return tm, o, skipValues } @@ -263,7 +270,7 @@ func (sm *Mapper) FindColumnName(structPtr, fieldPtr interface{}) (string, error return "", errNotAPointer } - tm := sm.reflectMapper().TypeMap(t) + tm := sm.typeMap(t) for _, fi := range tm.Index { fv := reflectx.FieldByIndexesReadOnly(v, fi.Index) if fv.Addr().Interface() == fieldPtr { @@ -274,6 +281,56 @@ func (sm *Mapper) FindColumnName(structPtr, fieldPtr interface{}) (string, error return "", errFieldNotFound } +func (sm *Mapper) typeMap(t reflect.Type) *reflectx.StructMap { + if sm == nil { + sm = defaultMapper + } + + sm.mu.Lock() + defer sm.mu.Unlock() + + tm, found := sm.types[t] + if found { + return tm + } + + tm = sm.reflectMapper().TypeMap(t) + index := make([]*reflectx.FieldInfo, 0, len(tm.Index)) + + for _, fi := range tm.Index { + skip := false + p := fi.Parent + + // Field is allowed to be a column if does not have a named parent (with non-empty path) + // or all parents are embedded. + for p != nil && p.Path != "" { + if !p.Embedded { + skip = true + + break + } + + p = p.Parent + } + + if skip { + continue + } + + index = append(index, fi) + } + + tm.Index = index + + if sm.types == nil { + sm.types = make(map[reflect.Type]*reflectx.StructMap, 1) + } + + sm.types[t] = tm + + return tm +} + // FindColumnNames returns column names mapped by a pointer to a field. func (sm *Mapper) FindColumnNames(structPtr interface{}) (map[interface{}]string, error) { if structPtr == nil { @@ -289,7 +346,7 @@ func (sm *Mapper) FindColumnNames(structPtr interface{}) (map[interface{}]string res := make(map[interface{}]string) - tm := sm.reflectMapper().TypeMap(t) + tm := sm.typeMap(t) for _, fi := range tm.Index { fv := reflectx.FieldByIndexesReadOnly(v, fi.Index) res[fv.Addr().Interface()] = fi.Name diff --git a/mapper_test.go b/mapper_test.go index fef1d0f..d7617f9 100644 --- a/mapper_test.go +++ b/mapper_test.go @@ -11,22 +11,36 @@ import ( type ( Sample struct { - A int `db:"a"` + A int `db:"a"` + DeeplyEmbedded // Recursively embedded fields are used as root fields. + Meta AnotherRow `db:"meta"` // Meta is a column, but its fields are not. + } + + DeeplyEmbedded struct { SampleEmbedded + E string `db:"e"` } SampleEmbedded struct { B float64 `db:"b"` C string `db:"c"` } + + AnotherRow struct { + SampleEmbedded // These embedded fields won't show up in Sample statements. + D string `db:"d"` // This field won't show up in Sample statements. + } ) func TestInsertValue(t *testing.T) { z := Sample{ A: 1, - SampleEmbedded: SampleEmbedded{ - B: 2.2, - C: "3", + DeeplyEmbedded: DeeplyEmbedded{ + SampleEmbedded: SampleEmbedded{ + B: 2.2, + C: "3", + }, + E: "e!", }, } @@ -35,8 +49,8 @@ func TestInsertValue(t *testing.T) { q := sm.Insert(ps.Insert("sample"), z) query, args, err := q.ToSql() assert.NoError(t, err) - assert.Equal(t, "INSERT INTO sample (a,b,c) VALUES ($1,$2,$3)", query) - assert.Equal(t, []interface{}{1, 2.2, "3"}, args) + assert.Equal(t, "INSERT INTO sample (a,meta,e,b,c) VALUES ($1,$2,$3,$4,$5)", query) + assert.Equal(t, []interface{}{1, AnotherRow{SampleEmbedded: SampleEmbedded{B: 0, C: ""}, D: ""}, "e!", 2.2, "3"}, args) } func TestMapper_Insert_nil(t *testing.T) { @@ -61,16 +75,22 @@ func TestInsertValueSlice(t *testing.T) { z := []Sample{ { A: 1, - SampleEmbedded: SampleEmbedded{ - B: 2.2, - C: "3", + DeeplyEmbedded: DeeplyEmbedded{ + SampleEmbedded: SampleEmbedded{ + B: 2.2, + C: "3", + }, + E: "e!", }, }, { A: 4, - SampleEmbedded: SampleEmbedded{ - B: 5.5, - C: "6", + DeeplyEmbedded: DeeplyEmbedded{ + SampleEmbedded: SampleEmbedded{ + B: 5.5, + C: "6", + }, + E: "ee!", }, }, } @@ -82,24 +102,37 @@ func TestInsertValueSlice(t *testing.T) { q = sm.Insert(q, z) query, args, err := q.ToSql() assert.NoError(t, err) - assert.Equal(t, "INSERT INTO sample (a,b,c) VALUES ($1,$2,$3),($4,$5,$6)", query) - assert.Equal(t, []interface{}{1, 2.2, "3", 4, 5.5, "6"}, args) + assert.Equal(t, "INSERT INTO sample (a,meta,e,b,c) VALUES ($1,$2,$3,$4,$5),($6,$7,$8,$9,$10)", query) + assert.Equal(t, []interface{}{ + 1, + AnotherRow{SampleEmbedded: SampleEmbedded{B: 0, C: ""}, D: ""}, + "e!", 2.2, "3", + 4, + AnotherRow{SampleEmbedded: SampleEmbedded{B: 0, C: ""}, D: ""}, + "ee!", 5.5, "6", + }, args) } func TestInsertValueSlicePtr(t *testing.T) { z := []Sample{ { A: 1, - SampleEmbedded: SampleEmbedded{ - B: 2.2, - C: "3", + DeeplyEmbedded: DeeplyEmbedded{ + SampleEmbedded: SampleEmbedded{ + B: 2.2, + C: "3", + }, + E: "e!", }, }, { A: 4, - SampleEmbedded: SampleEmbedded{ - B: 5.5, - C: "6", + DeeplyEmbedded: DeeplyEmbedded{ + SampleEmbedded: SampleEmbedded{ + B: 5.5, + C: "6", + }, + E: "ee!", }, }, } @@ -109,8 +142,15 @@ func TestInsertValueSlicePtr(t *testing.T) { q := sm.Insert(ps.Insert("sample"), z) query, args, err := q.ToSql() assert.NoError(t, err) - assert.Equal(t, "INSERT INTO sample (a,b,c) VALUES ($1,$2,$3),($4,$5,$6)", query) - assert.Equal(t, []interface{}{1, 2.2, "3", 4, 5.5, "6"}, args) + assert.Equal(t, "INSERT INTO sample (a,meta,e,b,c) VALUES ($1,$2,$3,$4,$5),($6,$7,$8,$9,$10)", query) + assert.Equal(t, []interface{}{ + 1, + AnotherRow{SampleEmbedded: SampleEmbedded{B: 0, C: ""}, D: ""}, + "e!", 2.2, "3", + 4, + AnotherRow{SampleEmbedded: SampleEmbedded{B: 0, C: ""}, D: ""}, + "ee!", 5.5, "6", + }, args) } func TestMapper_Update(t *testing.T) {