Skip to content

Commit

Permalink
add tests & fix InTx
Browse files Browse the repository at this point in the history
  • Loading branch information
wroge committed Dec 8, 2024
1 parent 87d9dd5 commit a9b6079
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 15 deletions.
7 changes: 4 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ go 1.23

require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/jba/templatecheck v0.7.0
github.com/jba/templatecheck v0.7.1
github.com/spf13/afero v1.11.0
)

require (
github.com/google/safehtml v0.0.2 // indirect
golang.org/x/text v0.17.0 // indirect
github.com/google/safehtml v0.1.0 // indirect
golang.org/x/text v0.21.0 // indirect
)
8 changes: 8 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,18 @@ github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7Oputl
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/google/safehtml v0.0.2 h1:ZOt2VXg4x24bW0m2jtzAOkhoXV0iM8vNKc0paByCZqM=
github.com/google/safehtml v0.0.2/go.mod h1:L4KWwDsUJdECRAEpZoBn3O64bQaywRscowZjJAzjHnU=
github.com/google/safehtml v0.1.0 h1:EwLKo8qawTKfsi0orxcQAZzu07cICaBeFMegAU9eaT8=
github.com/google/safehtml v0.1.0/go.mod h1:L4KWwDsUJdECRAEpZoBn3O64bQaywRscowZjJAzjHnU=
github.com/jba/templatecheck v0.7.0 h1:wjTb/VhGgSFeim5zjWVePBdaMo28X74bGLSABZV+zIA=
github.com/jba/templatecheck v0.7.0/go.mod h1:n1Etw+Rrw1mDDD8dDRsEKTwMZsJ98EkktgNJC6wLUGo=
github.com/jba/templatecheck v0.7.1 h1:yOEIFazBEwzdTPYHZF3Pm81NF1ksxx1+vJncSEwvjKc=
github.com/jba/templatecheck v0.7.1/go.mod h1:n1Etw+Rrw1mDDD8dDRsEKTwMZsJ98EkktgNJC6wLUGo=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
14 changes: 8 additions & 6 deletions sqlt.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func InTx(ctx context.Context, opts *sql.TxOptions, db *sql.DB, do func(db DB) e
panic(p)
}
} else if err != nil {
panic(errors.Join(err, tx.Rollback()))
err = errors.Join(err, tx.Rollback())
} else {
err = tx.Commit()
}
Expand Down Expand Up @@ -377,6 +377,8 @@ func (r *Runner) QueryRow(db DB, param any) (*sql.Row, error) {
func Stmt[Param any](opts ...Option) *Statement[Param] {
_, file, line, _ := runtime.Caller(1)

location := fmt.Sprintf("%s:%d", file, line)

config := &Config{
Placeholder: "?",
}
Expand All @@ -397,12 +399,12 @@ func Stmt[Param any](opts ...Option) *Statement[Param] {
for _, to := range config.TemplateOptions {
tpl, err = to(tpl)
if err != nil {
panic(fmt.Errorf("location: [%s:%d]: %w", file, line, err))
panic(fmt.Errorf("location: [%s]: %w", location, err))
}
}

if err = templatecheck.CheckText(tpl, *new(Param)); err != nil {
panic(fmt.Errorf("location: [%s:%d]: %w", file, line, err))
panic(fmt.Errorf("location: [%s]: %w", location, err))
}

escape(tpl)
Expand All @@ -417,13 +419,13 @@ func Stmt[Param any](opts ...Option) *Statement[Param] {
New: func() any {
t, err := tpl.Clone()
if err != nil {
panic(fmt.Errorf("location: [%s:%d]: %w", file, line, err))
panic(fmt.Errorf("location: [%s]: %w", location, err))
}

runner := &Runner{
Template: t,
SQL: &SQL{},
Location: fmt.Sprintf("%s:%d", file, line),
Location: location,
}

t.Funcs(template.FuncMap{
Expand Down Expand Up @@ -831,7 +833,7 @@ func (w *SQL) String() string {
}

if w.data[len(w.data)-1] == ' ' {
return string(w.data[:len(w.data)-1])
w.data = w.data[:len(w.data)-1]
}

return string(w.data)
Expand Down
Loading

0 comments on commit a9b6079

Please sign in to comment.