Skip to content

Commit

Permalink
Add NoTable reference mode (#21)
Browse files Browse the repository at this point in the history
* Add NoTable reference mode

* Add Refs for multiple references

* Fix lint
  • Loading branch information
vearutop authored Nov 26, 2024
1 parent 66a3996 commit 1b9c96c
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 27 deletions.
121 changes: 95 additions & 26 deletions referencer.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ type Referencer struct {
// Default QuoteNoop.
IdentifierQuoter func(tableAndColumn ...string) string

refs map[interface{}]Quoted
columnNames map[interface{}]string
structColumns map[interface{}][]string
refs map[interface{}]Quoted
quotedCols map[interface{}]Quoted
columnNames map[interface{}]string
structRefs map[interface{}][]string
}

// ColumnsOf makes a Mapper option to prefix columns with table alias.
Expand Down Expand Up @@ -93,6 +94,36 @@ func (r *Referencer) ColumnsOf(rowStructPtr interface{}) func(o *Options) {
}
}

// QuotedNoTable is a container of field pointer that should be referenced without table.
type QuotedNoTable struct {
ptr interface{}
}

// NoTable enables references without table prefix.
// So that `my_table`.`my_column` would be rendered as `my_column`.
//
// r.Ref(sqluct.NoTable(&row.MyColumn))
// r.Fmt("%s = 1", sqluct.NoTable(&row.MyColumn))
//
// Such references may be useful for INSERT/UPDATE column expressions.
func NoTable(ptr interface{}) QuotedNoTable {
return QuotedNoTable{ptr: ptr}
}

// NoTableAll enables references without table prefix for all field pointers.
// It can be useful to prepare multiple variadic arguments.
//
// r.Fmt("ON CONFLICT(%s) DO UPDATE SET %s = excluded.%s, %s = excluded.%s",
// sqluct.NoTableAll(&row.ID, &row.F1, &row.F1, &row.F2, &row.F3)...)
func NoTableAll(ptrs ...interface{}) []interface{} {
res := make([]interface{}, 0, len(ptrs))
for _, ptr := range ptrs {
res = append(res, NoTable(ptr))
}

return res
}

// AddTableAlias creates string references for row pointer and all suitable field pointers in it.
//
// Empty alias is not added to column reference.
Expand All @@ -106,37 +137,42 @@ func (r *Referencer) AddTableAlias(rowStructPtr interface{}, alias string) {
r.refs = make(map[interface{}]Quoted, len(f)+1)
}

if r.quotedCols == nil {
r.quotedCols = make(map[interface{}]Quoted, len(f)+1)
}

if r.columnNames == nil {
r.columnNames = make(map[interface{}]string, len(f))
}

if r.structColumns == nil {
r.structColumns = make(map[interface{}][]string)
if r.structRefs == nil {
r.structRefs = make(map[interface{}][]string)
}

if alias != "" {
r.refs[rowStructPtr] = r.Q(alias)
}

columns := make([]string, 0, len(f))
refs := make([]string, 0, len(f))

for ptr, fieldName := range f {
var col string
var ref Quoted

if alias == "" {
col = string(r.Q(fieldName))
ref = r.Q(fieldName)
} else {
col = string(r.Q(alias, fieldName))
ref = r.Q(alias, fieldName)
}

columns = append(columns, col)
r.refs[ptr] = Quoted(col)
refs = append(refs, string(ref))
r.refs[ptr] = ref
r.quotedCols[ptr] = r.Q(fieldName)
r.columnNames[ptr] = fieldName
}

sort.Strings(columns)
sort.Strings(refs)

r.structColumns[rowStructPtr] = columns
r.structRefs[rowStructPtr] = refs
}

// Quoted is a string that can be interpolated into an SQL statement as is.
Expand All @@ -155,11 +191,49 @@ func (r *Referencer) Q(tableAndColumn ...string) Quoted {
//
// It panics if pointer is unknown.
func (r *Referencer) Ref(ptr interface{}) string {
if ref, found := r.refs[ptr]; found {
return string(ref)
s, err := r.ref(ptr)
if err != nil {
panic(err)
}

panic(errUnknownFieldOrRow)
return s
}

func (r *Referencer) ref(ptr interface{}) (string, error) {
if q, ok := ptr.(Quoted); ok {
return string(q), nil
}

refs := r.refs

if nt, ok := ptr.(QuotedNoTable); ok {
ptr = nt.ptr
refs = r.quotedCols
}

if ref, found := refs[ptr]; found {
return string(ref), nil
}

return "", errUnknownFieldOrRow
}

// Refs returns reference strings for multiple field pointers.
//
// It panics if pointer is unknown.
func (r *Referencer) Refs(ptrs ...interface{}) []string {
args := make([]string, 0, len(ptrs))

for i, fieldPtr := range ptrs {
ref, err := r.ref(fieldPtr)
if err != nil {
panic(fmt.Errorf("%w at position %d", err, i))
}

args = append(args, ref)
}

return args
}

// Col returns unescaped column name for field pointer that was previously added with AddTableAlias.
Expand All @@ -181,25 +255,20 @@ func (r *Referencer) Fmt(format string, ptrs ...interface{}) string {
args := make([]interface{}, 0, len(ptrs))

for i, fieldPtr := range ptrs {
if q, ok := fieldPtr.(Quoted); ok {
args = append(args, string(q))

continue
ref, err := r.ref(fieldPtr)
if err != nil {
panic(fmt.Errorf("%w at position %d", err, i))
}

if ref, found := r.refs[fieldPtr]; found {
args = append(args, ref)
} else {
panic(fmt.Errorf("%w at position %d", errUnknownFieldOrRow, i))
}
args = append(args, ref)
}

return fmt.Sprintf(format, args...)
}

// Cols returns column references of a row structure.
func (r *Referencer) Cols(ptr interface{}) []string {
if cols, found := r.structColumns[ptr]; found {
if cols, found := r.structRefs[ptr]; found {
return cols
}

Expand Down
26 changes: 25 additions & 1 deletion referencer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func BenchmarkReferencer_Fmt_lite(b *testing.B) {

for i := 0; i < b.N; i++ {
// Find direct reports that share same last name and manager is not named John.
qb := squirrel.StatementBuilder.Select(rf.Fmt("%s, %s", &dr.ManagerID, &dr.EmployeeID)).
qb := squirrel.StatementBuilder.Select(rf.Refs(&dr.ManagerID, &dr.EmployeeID)...).
From(rf.Fmt("%s AS %s", rf.Q("users"), manager)).
InnerJoin(rf.Fmt("%s AS %s ON %s = %s AND %s = %s",
rf.Q("direct_reports"), dr,
Expand Down Expand Up @@ -258,3 +258,27 @@ func BenchmarkReferencer_Fmt_raw(b *testing.B) {
}
}
}

func TestNoTable(t *testing.T) {
ref := sqluct.Referencer{}
ref.Mapper = &sqluct.Mapper{Dialect: sqluct.DialectSQLite3}
ref.IdentifierQuoter = sqluct.QuoteBackticks

type User struct {
ID int `db:"id"`
FirstName string `db:"first_name"`
LastName string `db:"last_name"`
}

row := &User{}

ref.AddTableAlias(row, "users")

expr := ref.Fmt("ON CONFLICT(%s) DO UPDATE SET %s = excluded.%s, %s = excluded.%s",
sqluct.NoTableAll(&row.ID, &row.FirstName, &row.FirstName, &row.LastName, &row.LastName)...)

assert.Equal(t, "ON CONFLICT(`id`) DO UPDATE SET `first_name` = excluded.`first_name`, `last_name` = excluded.`last_name`", expr)

assert.Equal(t, "`first_name`", ref.Ref(sqluct.NoTable(&row.FirstName)))
assert.Equal(t, "`users`.`first_name`", ref.Ref(&row.FirstName))
}

0 comments on commit 1b9c96c

Please sign in to comment.