Skip to content

Commit

Permalink
x/sqlbuilder/upserter: Directly assert in the SQL query if the value …
Browse files Browse the repository at this point in the history
…are the correct ones
  • Loading branch information
AlexisMontagne committed Sep 17, 2024
1 parent b1c0f8d commit 60a3d13
Showing 1 changed file with 22 additions and 56 deletions.
78 changes: 22 additions & 56 deletions x/sqlbuilder/upserter/execer.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
package upserter

import (
"bytes"
"context"
"database/sql/driver"
"fmt"
"reflect"
"time"

"github.com/upfluence/errors"
"github.com/upfluence/sql"
Expand Down Expand Up @@ -127,6 +124,12 @@ func newExecer(te txExecutor, stmt Statement) sqlbuilder.Execer {
}
}

var selectClauses = []sqlbuilder.Marker{oneMarker}

for _, m := range stmt.SetValues {
selectClauses = append(selectClauses, &assertMarker{Marker: m})
}

var (
clauses = make([]sqlbuilder.PredicateClause, len(stmt.QueryValues))

Expand All @@ -136,7 +139,7 @@ func newExecer(te txExecutor, stmt Statement) sqlbuilder.Execer {
sfs: make([]string, len(stmt.SetValues)),
ss: sqlbuilder.SelectStatement{
Table: stmt.Table,
SelectClauses: append([]sqlbuilder.Marker{oneMarker}, stmt.SetValues...),
SelectClauses: selectClauses,
},
us: sqlbuilder.UpdateStatement{
Table: stmt.Table,
Expand Down Expand Up @@ -199,49 +202,19 @@ func newExecer(te txExecutor, stmt Statement) sqlbuilder.Execer {
return &e
}

func cloneValue(v any) (any, error) {
if dv, ok := v.(driver.Valuer); ok {
vv, err := dv.Value()

if err != nil {
return nil, err
}

if vv != nil {
v = vv
}
}

return reflect.New(reflect.TypeOf(v)).Interface(), nil
type assertMarker struct {
sqlbuilder.Marker
}

func equalValues(x, y any) (bool, error) {
if dy, ok := y.(driver.Valuer); ok {
yy, err := dy.Value()

if err != nil {
return false, err
}

if yy != nil {
y = yy
}
}

switch yy := y.(type) {
case time.Time:
if xx, ok := x.(time.Time); ok {
return xx.Equal(yy), nil
}
case []byte:
if xx, ok := x.([]byte); ok {
return bytes.Equal(yy, xx), nil
}
default:
return reflect.DeepEqual(x, y), nil
}
func (am *assertMarker) WriteTo(qw sqlbuilder.QueryWriter, vs map[string]interface{}) error {
_, err := fmt.Fprintf(
qw,
"%s = %s",
am.ToSQL(),
qw.RedeemVariable(vs["assert_"+am.Binding()]),
)

return false, nil
return err
}

func (e *execer) Exec(ctx context.Context, vs map[string]interface{}) (sql.Result, error) {
Expand Down Expand Up @@ -271,13 +244,11 @@ func (e *execer) Exec(ctx context.Context, vs map[string]interface{}) (sql.Resul
return nil, sqlbuilder.ErrMissingKey{Key: f}
}

var err error
var val sql.NullBool

existing[f], err = cloneValue(v)
qvs["assert_"+f] = v

if err != nil {
return nil, err
}
existing[f] = &val
}

if m := e.returningMarker; m != nil {
Expand All @@ -296,14 +267,9 @@ func (e *execer) Exec(ctx context.Context, vs map[string]interface{}) (sql.Resul
pristine := true

for _, sf := range e.sfs {
ok, err := equalValues(reflect.ValueOf(existing[sf]).Elem().Interface(), vs[sf])

if err != nil {
return err
}

if !ok {
if val := existing[sf].(*sql.NullBool); !val.Bool {
pristine = false

break
}
}
Expand Down

0 comments on commit 60a3d13

Please sign in to comment.