diff --git a/repositories/metadata/postgresql/blobs.go b/repositories/metadata/postgresql/blobs.go index 59d79bc..943f690 100644 --- a/repositories/metadata/postgresql/blobs.go +++ b/repositories/metadata/postgresql/blobs.go @@ -8,7 +8,7 @@ import ( ) func (r *repository) CreateBLOB(ctx context.Context, checksum string, size uint64, mimeType string) error { - _, err := psql. + _, err := insertQuery(ctx, r.db, psql. Insert("blobs"). Columns( "checksum", @@ -19,15 +19,12 @@ func (r *repository) CreateBLOB(ctx context.Context, checksum string, size uint6 checksum, size, mimeType, - ). - RunWith(r.db). - ExecContext(ctx) - - return errors.Wrap(err, "error executing SQL query") + )) + return mapSQLErrors(err) } func (r *repository) GetBlobKeyByObject(ctx context.Context, container, version, key string) (string, error) { - row := psql. + row, err := selectQueryRow(ctx, r.db, psql. Select("b.checksum AS checksum"). From("blobs b"). Join("objects o ON o.blob_id = b.id"). @@ -38,9 +35,10 @@ func (r *repository) GetBlobKeyByObject(ctx context.Context, container, version, "v.name": version, "o.key": key, "v.is_published": true, - }). - RunWith(r.db). - QueryRowContext(ctx) + })) + if err != nil { + return "", mapSQLErrors(err) + } var checksum string if err := row.Scan(&checksum); err != nil { @@ -51,15 +49,16 @@ func (r *repository) GetBlobKeyByObject(ctx context.Context, container, version, } func (r *repository) EnsureBlobKey(ctx context.Context, key string, size uint64) error { - row := psql. + row, err := selectQueryRow(ctx, r.db, psql. Select("id"). From("blobs"). Where(sq.Eq{ "checksum": key, "size": size, - }). - RunWith(r.db). - QueryRowContext(ctx) + })) + if err != nil { + return errors.Wrap(err, "error selecting BLOB") + } var blobID uint if err := row.Scan(&blobID); err != nil { diff --git a/repositories/metadata/postgresql/containers.go b/repositories/metadata/postgresql/containers.go index 2fd3e1f..ef44da3 100644 --- a/repositories/metadata/postgresql/containers.go +++ b/repositories/metadata/postgresql/containers.go @@ -4,33 +4,27 @@ import ( "context" sq "github.com/Masterminds/squirrel" - "github.com/pkg/errors" ) func (r *repository) CreateContainer(ctx context.Context, name string) error { - _, err := psql. + _, err := insertQuery(ctx, r.db, psql. Insert("containers"). Columns( "name", ). Values( name, - ). - RunWith(r.db). - ExecContext(ctx) - - return errors.Wrap(err, "error executing SQL query") + )) + return mapSQLErrors(err) } func (r *repository) ListContainers(ctx context.Context) ([]string, error) { - rows, err := psql. + rows, err := selectQuery(ctx, r.db, psql. Select("name"). From("containers"). - OrderBy("name"). - RunWith(r.db). - QueryContext(ctx) + OrderBy("name")) if err != nil { - return nil, errors.Wrap(mapSQLErrors(err), "error executing SQL query") + return nil, mapSQLErrors(err) } defer rows.Close() @@ -38,7 +32,7 @@ func (r *repository) ListContainers(ctx context.Context) ([]string, error) { for rows.Next() { var r string if err := rows.Scan(&r); err != nil { - return nil, errors.Wrap(err, "error decoding database result") + return nil, mapSQLErrors(err) } result = append(result, r) @@ -48,11 +42,8 @@ func (r *repository) ListContainers(ctx context.Context) ([]string, error) { } func (r *repository) DeleteContainer(ctx context.Context, name string) error { - _, err := psql. + _, err := deleteQuery(ctx, r.db, psql. Delete("containers"). - Where(sq.Eq{"name": name}). - RunWith(r.db). - ExecContext(ctx) - - return errors.Wrap(err, "error executing SQL query") + Where(sq.Eq{"name": name})) + return mapSQLErrors(err) } diff --git a/repositories/metadata/postgresql/containers_test.go b/repositories/metadata/postgresql/containers_test.go index 48ca312..ccd2dad 100644 --- a/repositories/metadata/postgresql/containers_test.go +++ b/repositories/metadata/postgresql/containers_test.go @@ -14,7 +14,7 @@ func (s *postgreSQLRepositoryTestSuite) TestContainerOperations() { err = s.repo.CreateContainer(s.ctx, "test-container9") s.Require().Error(err) s.Require().Equal( - `error executing SQL query: pq: duplicate key value violates unique constraint "containers_name_key"`, + `pq: duplicate key value violates unique constraint "containers_name_key"`, err.Error(), ) diff --git a/repositories/metadata/postgresql/objects.go b/repositories/metadata/postgresql/objects.go index 7ca9f18..6147401 100644 --- a/repositories/metadata/postgresql/objects.go +++ b/repositories/metadata/postgresql/objects.go @@ -14,7 +14,7 @@ func (r *repository) CreateObject(ctx context.Context, container, version, key, } defer tx.Rollback() - row := psql. + row, err := selectQueryRow(ctx, tx, psql. Select("v.id as id"). From("containers c"). Join("versions v ON v.container_id = c.id"). @@ -22,28 +22,30 @@ func (r *repository) CreateObject(ctx context.Context, container, version, key, "c.name": container, "v.name": version, "is_published": false, - }). - RunWith(tx). - QueryRowContext(ctx) + })) + if err != nil { + return mapSQLErrors(err) + } var versionID uint if err := row.Scan(&versionID); err != nil { - return errors.Wrap(mapSQLErrors(err), "error looking up version") + return mapSQLErrors(err) } - row = psql. + row, err = selectQueryRow(ctx, tx, psql. Select("id"). From("blobs"). - Where(sq.Eq{"checksum": casKey}). - RunWith(tx). - QueryRowContext(ctx) + Where(sq.Eq{"checksum": casKey})) + if err != nil { + return mapSQLErrors(err) + } var blobID uint if err := row.Scan(&blobID); err != nil { - return errors.Wrap(mapSQLErrors(err), "error looking up blob") + return mapSQLErrors(err) } - _, err = psql. + _, err = insertQuery(ctx, tx, psql. Insert("objects"). Columns( "version_id", @@ -54,11 +56,9 @@ func (r *repository) CreateObject(ctx context.Context, container, version, key, versionID, key, blobID, - ). - RunWith(tx). - ExecContext(ctx) + )) if err != nil { - return errors.Wrap(err, "error executing SQL query") + return mapSQLErrors(err) } if err := tx.Commit(); err != nil { @@ -68,36 +68,38 @@ func (r *repository) CreateObject(ctx context.Context, container, version, key, } func (r *repository) ListObjects(ctx context.Context, container, version string, offset, limit uint64) ([]string, error) { - row := psql. + row, err := selectQueryRow(ctx, r.db, psql. Select("id"). From("containers"). Where(sq.Eq{ "name": container, - }). - RunWith(r.db). - QueryRowContext(ctx) + })) + if err != nil { + return nil, mapSQLErrors(err) + } var containerID uint if err := row.Scan(&containerID); err != nil { return nil, mapSQLErrors(err) } - row = psql. + row, err = selectQueryRow(ctx, r.db, psql. Select("id"). From("versions"). Where(sq.Eq{ "container_id": containerID, "name": version, - }). - RunWith(r.db). - QueryRowContext(ctx) + })) + if err != nil { + return nil, mapSQLErrors(err) + } var versionID uint if err := row.Scan(&versionID); err != nil { return nil, mapSQLErrors(err) } - rows, err := psql. + rows, err := selectQuery(ctx, r.db, psql. Select("key"). From("objects"). Where(sq.Eq{ @@ -105,11 +107,9 @@ func (r *repository) ListObjects(ctx context.Context, container, version string, }). OrderBy("id"). Offset(offset). - Limit(limit). - RunWith(r.db). - QueryContext(ctx) + Limit(limit)) if err != nil { - return nil, errors.Wrap(err, "error executing SQL query") + return nil, mapSQLErrors(err) } defer rows.Close() @@ -133,32 +133,31 @@ func (r *repository) DeleteObject(ctx context.Context, container, version, key s } defer tx.Rollback() - row := psql. + row, err := selectQueryRow(ctx, tx, psql. Select("v.id"). From("versions v"). Join("containers c ON v.container_id = c.id"). Where(sq.Eq{ "c.name": container, "v.name": version, - }). - RunWith(tx). - QueryRowContext(ctx) + })) + if err != nil { + return mapSQLErrors(err) + } var versionID uint if err := row.Scan(&versionID); err != nil { return errors.Wrap(err, "error looking up version") } - _, err = psql. + _, err = deleteQuery(ctx, tx, psql. Delete("objects"). Where(sq.Eq{ "version_id": versionID, "key": key, - }). - RunWith(tx). - ExecContext(ctx) + })) if err != nil { - return errors.Wrap(err, "error executing SQL query") + return mapSQLErrors(err) } if err := tx.Commit(); err != nil { @@ -174,45 +173,45 @@ func (r *repository) RemapObject(ctx context.Context, container, version, key, n } defer tx.Rollback() - row := psql. + row, err := selectQueryRow(ctx, tx, psql. Select("v.id"). From("versions v"). Join("containers c ON v.container_id = c.id"). Where(sq.Eq{ "c.name": container, "v.name": version, - }). - RunWith(tx). - QueryRowContext(ctx) + })) + if err != nil { + return mapSQLErrors(err) + } var versionID uint if err := row.Scan(&versionID); err != nil { return errors.Wrap(err, "error looking up version") } - row = psql. + row, err = selectQueryRow(ctx, tx, psql. Select("id"). From("blobs"). - Where(sq.Eq{"checksum": newCASKey}). - RunWith(tx). - QueryRowContext(ctx) + Where(sq.Eq{"checksum": newCASKey})) + if err != nil { + return mapSQLErrors(err) + } var blobID uint if err := row.Scan(&blobID); err != nil { return errors.Wrap(err, "error looking up blob") } - _, err = psql. + _, err = updateQuery(ctx, tx, psql. Update("objects"). Set("blob_id", blobID). Where(sq.Eq{ "version_id": versionID, "key": key, - }). - RunWith(tx). - ExecContext(ctx) + })) if err != nil { - return errors.Wrap(err, "error executing SQL query") + return mapSQLErrors(err) } if err := tx.Commit(); err != nil { diff --git a/repositories/metadata/postgresql/postgresql.go b/repositories/metadata/postgresql/postgresql.go index 8ec41b2..f6291cf 100644 --- a/repositories/metadata/postgresql/postgresql.go +++ b/repositories/metadata/postgresql/postgresql.go @@ -1,10 +1,12 @@ package postgresql import ( + "context" "database/sql" "time" sq "github.com/Masterminds/squirrel" + log "github.com/sirupsen/logrus" "github.com/teran/archived/repositories/metadata" ) @@ -35,3 +37,75 @@ func mapSQLErrors(err error) error { return err } } + +type queryRunner interface { + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) +} + +type execRunner interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) +} + +type query interface { + ToSql() (string, []interface{}, error) +} + +func mkQuery(q query) (string, []any, error) { + sql, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + + log.WithFields(log.Fields{ + "query": sql, + "args": args, + }).Tracef("SQL query generated") + + return sql, args, nil +} + +func selectQueryRow(ctx context.Context, db queryRunner, q sq.SelectBuilder) (sq.RowScanner, error) { + sql, args, err := mkQuery(q) + if err != nil { + return nil, err + } + + return db.QueryRowContext(ctx, sql, args...), nil +} + +func selectQuery(ctx context.Context, db queryRunner, q sq.SelectBuilder) (*sql.Rows, error) { + sql, args, err := mkQuery(q) + if err != nil { + return nil, err + } + + return db.QueryContext(ctx, sql, args...) +} + +func insertQuery(ctx context.Context, db execRunner, q sq.InsertBuilder) (sql.Result, error) { + sql, args, err := mkQuery(q) + if err != nil { + return nil, err + } + + return db.ExecContext(ctx, sql, args...) +} + +func updateQuery(ctx context.Context, db execRunner, q sq.UpdateBuilder) (sql.Result, error) { + sql, args, err := mkQuery(q) + if err != nil { + return nil, err + } + + return db.ExecContext(ctx, sql, args...) +} + +func deleteQuery(ctx context.Context, db execRunner, q sq.DeleteBuilder) (sql.Result, error) { + sql, args, err := mkQuery(q) + if err != nil { + return nil, err + } + + return db.ExecContext(ctx, sql, args...) +} diff --git a/repositories/metadata/postgresql/versions.go b/repositories/metadata/postgresql/versions.go index 74b0c90..3692ab7 100644 --- a/repositories/metadata/postgresql/versions.go +++ b/repositories/metadata/postgresql/versions.go @@ -17,12 +17,13 @@ func (r *repository) CreateVersion(ctx context.Context, container string) (strin } defer tx.Rollback() - row := psql. + row, err := selectQueryRow(ctx, tx, psql. Select("id"). From("containers"). - Where(sq.Eq{"name": container}). - RunWith(tx). - QueryRowContext(ctx) + Where(sq.Eq{"name": container})) + if err != nil { + return "", mapSQLErrors(err) + } var containerID uint if err := row.Scan(&containerID); err != nil { @@ -31,7 +32,7 @@ func (r *repository) CreateVersion(ctx context.Context, container string) (strin versionID := r.tp().UTC().Format("20060102150405") - _, err = psql. + _, err = insertQuery(ctx, tx, psql. Insert("versions"). Columns( "container_id", @@ -42,11 +43,9 @@ func (r *repository) CreateVersion(ctx context.Context, container string) (strin containerID, versionID, false, - ). - RunWith(tx). - ExecContext(ctx) + )) if err != nil { - return "", errors.Wrap(err, "error executing SQL query") + return "", mapSQLErrors(err) } if err := tx.Commit(); err != nil { @@ -75,12 +74,13 @@ func (r *repository) listVersionsByContainer(ctx context.Context, container stri limit = defaultLimit } - row := psql. + row, err := selectQueryRow(ctx, r.db, psql. Select("id"). From("containers"). - Where(sq.Eq{"name": container}). - RunWith(r.db). - QueryRowContext(ctx) + Where(sq.Eq{"name": container})) + if err != nil { + return 0, nil, mapSQLErrors(err) + } var containerID uint64 if err := row.Scan(&containerID); err != nil { @@ -95,28 +95,28 @@ func (r *repository) listVersionsByContainer(ctx context.Context, container stri condition["is_published"] = *isPublished } - row = psql. + row, err = selectQueryRow(ctx, r.db, psql. Select("COUNT(*)"). From("versions"). - RunWith(r.db). - QueryRowContext(ctx) + Where(sq.Eq{"container_id": containerID})) + if err != nil { + return 0, nil, mapSQLErrors(err) + } var versionsTotal uint64 if err := row.Scan(&versionsTotal); err != nil { return 0, nil, mapSQLErrors(err) } - rows, err := psql. + rows, err := selectQuery(ctx, r.db, psql. Select("name"). From("versions"). Where(condition). OrderBy("id"). Offset(offset). - Limit(limit). - RunWith(r.db). - QueryContext(ctx) + Limit(limit)) if err != nil { - return 0, nil, errors.Wrap(err, "error executing SQL query") + return 0, nil, mapSQLErrors(err) } defer rows.Close() @@ -141,28 +141,28 @@ func (r *repository) MarkVersionPublished(ctx context.Context, container, versio defer tx.Rollback() var containerID uint - row := psql. + + row, err := selectQueryRow(ctx, tx, psql. Select("id"). From("containers"). - Where(sq.Eq{"name": container}). - RunWith(tx). - QueryRowContext(ctx) + Where(sq.Eq{"name": container})) + if err != nil { + return mapSQLErrors(err) + } if err := row.Scan(&containerID); err != nil { return errors.Wrap(err, "error looking up container") } - _, err = psql. + _, err = updateQuery(ctx, tx, psql. Update("versions"). Set("is_published", true). Where(sq.Eq{ "container_id": containerID, "name": version, - }). - RunWith(tx). - ExecContext(ctx) + })) if err != nil { - return errors.Wrap(err, "error executing SQL query") + return mapSQLErrors(err) } if err := tx.Commit(); err != nil {