Skip to content

Commit

Permalink
x/sqlbuilder/upserter: Handle nullable DB value uniquely
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexisMontagne committed Sep 17, 2024
1 parent c446a3e commit b1c0f8d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
16 changes: 10 additions & 6 deletions x/sqlbuilder/upserter/execer.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,27 +201,31 @@ func newExecer(te txExecutor, stmt Statement) sqlbuilder.Execer {

func cloneValue(v any) (any, error) {
if dv, ok := v.(driver.Valuer); ok {
var err error

v, err = dv.Value()
vv, err := dv.Value()

if err != nil {
return nil, err
}

if vv != nil {
v = vv
}
}

return reflect.New(reflect.TypeOf(v)).Interface(), nil
}

func equalValues(x, y any) (bool, error) {
if dy, ok := y.(driver.Valuer); ok {
var err error

y, err = dy.Value()
yy, err := dy.Value()

if err != nil {
return false, err
}

if yy != nil {
y = yy
}
}

switch yy := y.(type) {
Expand Down
23 changes: 19 additions & 4 deletions x/sqlbuilder/upserter/upserter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func TestUpserterRegular(t *testing.T) {

res, err := e.Exec(
ctx,
map[string]interface{}{"x": "foo", "y": stringOverloaded("bar"), "z": "buz"},
map[string]interface{}{"x": "foo", "y": nilDBValue{}, "z": "buz"},
)

if err != nil {
Expand All @@ -115,7 +115,18 @@ func TestUpserterRegular(t *testing.T) {

res, err = e.Exec(
ctx,
map[string]interface{}{"x": "foo", "y": stringOverloaded("bar"), "z": "buz"},
map[string]interface{}{"x": "foo", "y": stringDBValue("bar"), "z": "buz"},
)

if err != nil {
t.Fatalf("Exec() = %v [ want nil ]", err)
}

assertResultAffected(t, res, 1)

res, err = e.Exec(
ctx,
map[string]interface{}{"x": "foo", "y": stringDBValue("bar"), "z": "buz"},
)

if err != nil {
Expand Down Expand Up @@ -449,9 +460,13 @@ func TestUpserterOnlyQueryValues(t *testing.T) {
})
}

type stringOverloaded string
type nilDBValue struct{}

func (nilDBValue) Value() (driver.Value, error) { return nil, nil }

type stringDBValue string

func (s stringOverloaded) Value() (driver.Value, error) { return []byte(s), nil }
func (s stringDBValue) Value() (driver.Value, error) { return []byte(s), nil }

func TestInTxUpserterPristine(t *testing.T) {
sqltest.NewTestCase(
Expand Down

0 comments on commit b1c0f8d

Please sign in to comment.