diff --git a/x/sqlbuilder/upserter/execer.go b/x/sqlbuilder/upserter/execer.go index 2d7ca30..1710d41 100644 --- a/x/sqlbuilder/upserter/execer.go +++ b/x/sqlbuilder/upserter/execer.go @@ -1,12 +1,9 @@ package upserter import ( - "bytes" "context" "database/sql/driver" "fmt" - "reflect" - "time" "github.com/upfluence/errors" "github.com/upfluence/sql" @@ -127,6 +124,12 @@ func newExecer(te txExecutor, stmt Statement) sqlbuilder.Execer { } } + var selectClauses = []sqlbuilder.Marker{oneMarker} + + for _, m := range stmt.SetValues { + selectClauses = append(selectClauses, &assertMarker{Marker: m}) + } + var ( clauses = make([]sqlbuilder.PredicateClause, len(stmt.QueryValues)) @@ -136,7 +139,7 @@ func newExecer(te txExecutor, stmt Statement) sqlbuilder.Execer { sfs: make([]string, len(stmt.SetValues)), ss: sqlbuilder.SelectStatement{ Table: stmt.Table, - SelectClauses: append([]sqlbuilder.Marker{oneMarker}, stmt.SetValues...), + SelectClauses: selectClauses, }, us: sqlbuilder.UpdateStatement{ Table: stmt.Table, @@ -199,49 +202,19 @@ func newExecer(te txExecutor, stmt Statement) sqlbuilder.Execer { return &e } -func cloneValue(v any) (any, error) { - if dv, ok := v.(driver.Valuer); ok { - vv, err := dv.Value() - - if err != nil { - return nil, err - } - - if vv != nil { - v = vv - } - } - - return reflect.New(reflect.TypeOf(v)).Interface(), nil +type assertMarker struct { + sqlbuilder.Marker } -func equalValues(x, y any) (bool, error) { - if dy, ok := y.(driver.Valuer); ok { - yy, err := dy.Value() - - if err != nil { - return false, err - } - - if yy != nil { - y = yy - } - } - - switch yy := y.(type) { - case time.Time: - if xx, ok := x.(time.Time); ok { - return xx.Equal(yy), nil - } - case []byte: - if xx, ok := x.([]byte); ok { - return bytes.Equal(yy, xx), nil - } - default: - return reflect.DeepEqual(x, y), nil - } +func (am *assertMarker) WriteTo(qw sqlbuilder.QueryWriter, vs map[string]interface{}) error { + _, err := fmt.Fprintf( + qw, + "%s = %s", + am.ToSQL(), + qw.RedeemVariable(vs["assert_"+am.Binding()]), + ) - return false, nil + return err } func (e *execer) Exec(ctx context.Context, vs map[string]interface{}) (sql.Result, error) { @@ -271,13 +244,11 @@ func (e *execer) Exec(ctx context.Context, vs map[string]interface{}) (sql.Resul return nil, sqlbuilder.ErrMissingKey{Key: f} } - var err error + var val sql.NullBool - existing[f], err = cloneValue(v) + qvs["assert_"+f] = v - if err != nil { - return nil, err - } + existing[f] = &val } if m := e.returningMarker; m != nil { @@ -296,14 +267,9 @@ func (e *execer) Exec(ctx context.Context, vs map[string]interface{}) (sql.Resul pristine := true for _, sf := range e.sfs { - ok, err := equalValues(reflect.ValueOf(existing[sf]).Elem().Interface(), vs[sf]) - - if err != nil { - return err - } - - if !ok { + if val := existing[sf].(*sql.NullBool); !val.Bool { pristine = false + break } }