Skip to content

Commit

Permalink
Merge pull request #1451 from alesstimec/fix-internal-db-group
Browse files Browse the repository at this point in the history
fix(internal/db/group.go): fixes CRUD methods for groups
  • Loading branch information
alesstimec authored Nov 20, 2024
2 parents 9e1c84d + a549044 commit 97910cd
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 25 deletions.
28 changes: 14 additions & 14 deletions internal/db/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ func (d *Database) GetGroup(ctx context.Context, group *dbmodel.GroupEntry) (err
return errors.E(op, err)
}

if group.UUID == "" && group.Name == "" {
return errors.E(op, "must specify uuid or name")
}

durationObserver := servermon.DurationObserver(servermon.DBQueryDurationHistogram, string(op))
defer durationObserver()
defer servermon.ErrorCounter(servermon.DBQueryErrorCount, &err, string(op))
Expand Down Expand Up @@ -108,15 +112,12 @@ func (d *Database) ListGroups(ctx context.Context, limit, offset int, match stri
return groups, nil
}

// UpdateGroupName updates the group name identified by its ID or UUID.
func (d *Database) UpdateGroupName(ctx context.Context, group *dbmodel.GroupEntry) (err error) {
// UpdateGroupName updates the name of the group identified by its UUID.
func (d *Database) UpdateGroupName(ctx context.Context, uuid, name string) (err error) {
const op = errors.Op("db.UpdateGroup")

if group.ID == 0 {
return errors.E(errors.CodeNotFound)
}
if group.UUID == "" {
return errors.E("group uuid not specified", errors.CodeNotFound)
if uuid == "" {
return errors.E(op, "uuid must be specified")
}

if err := d.ready(); err != nil {
Expand All @@ -127,8 +128,10 @@ func (d *Database) UpdateGroupName(ctx context.Context, group *dbmodel.GroupEntr
defer durationObserver()
defer servermon.ErrorCounter(servermon.DBQueryErrorCount, &err, string(op))

if err := d.DB.WithContext(ctx).Save(group).Error; err != nil {
return errors.E(op, dbError(err))
model := d.DB.WithContext(ctx).Model(&dbmodel.GroupEntry{})
model.Where("uuid = ?", uuid)
if model.Update("name", name).RowsAffected == 0 {
return errors.E(op, errors.CodeNotFound, "group not found")
}
return nil
}
Expand All @@ -137,11 +140,8 @@ func (d *Database) UpdateGroupName(ctx context.Context, group *dbmodel.GroupEntr
func (d *Database) RemoveGroup(ctx context.Context, group *dbmodel.GroupEntry) (err error) {
const op = errors.Op("db.RemoveGroup")

if group.ID == 0 {
return errors.E(errors.CodeNotFound)
}
if group.UUID == "" {
return errors.E(errors.CodeNotFound)
if group.ID == 0 && group.UUID == "" {
return errors.E("neither UUID or ID specified", errors.CodeNotFound)
}

if err := d.ready(); err != nil {
Expand Down
16 changes: 7 additions & 9 deletions internal/db/group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import (
"context"
"fmt"
"testing"
"time"

qt "github.com/frankban/quicktest"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"

"github.com/canonical/jimm/v3/internal/db"
Expand Down Expand Up @@ -120,17 +122,13 @@ func (s *dbSuite) TestGetGroup(c *qt.C) {
}

func (s *dbSuite) TestUpdateGroupName(c *qt.C) {
err := s.Database.UpdateGroupName(context.Background(), &dbmodel.GroupEntry{Name: "test-group"})
c.Check(errors.ErrorCode(err), qt.Equals, errors.CodeNotFound)
err := s.Database.UpdateGroupName(context.Background(), "test-group", "new-name")
c.Check(errors.ErrorCode(err), qt.Equals, errors.CodeUpgradeInProgress)

err = s.Database.Migrate(context.Background(), false)
c.Assert(err, qt.IsNil)

ge := &dbmodel.GroupEntry{
Name: "test-group",
}

err = s.Database.UpdateGroupName(context.Background(), ge)
err = s.Database.UpdateGroupName(context.Background(), "test-group", "new-name")
c.Check(errors.ErrorCode(err), qt.Equals, errors.CodeNotFound)

_, err = s.Database.AddGroup(context.Background(), "test-group")
Expand All @@ -143,15 +141,15 @@ func (s *dbSuite) TestUpdateGroupName(c *qt.C) {
c.Assert(err, qt.IsNil)

ge1.Name = "renamed-group"
err = s.Database.UpdateGroupName(context.Background(), ge1)
err = s.Database.UpdateGroupName(context.Background(), ge1.UUID, ge1.Name)
c.Check(err, qt.IsNil)

ge2 := &dbmodel.GroupEntry{
Name: "renamed-group",
}
err = s.Database.GetGroup(context.Background(), ge2)
c.Check(err, qt.IsNil)
c.Assert(ge2, qt.DeepEquals, ge1)
c.Assert(ge2, qt.CmpEquals(cmpopts.IgnoreTypes(time.Time{})), ge1)
}

func (s *dbSuite) TestRemoveGroup(c *qt.C) {
Expand Down
3 changes: 1 addition & 2 deletions internal/jimm/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -755,9 +755,8 @@ func (j *JIMM) RenameGroup(ctx context.Context, user *openfga.User, oldName, new
if err != nil {
return errors.E(op, err)
}
group.Name = newName

if err := j.Database.UpdateGroupName(ctx, group); err != nil {
if err := j.Database.UpdateGroupName(ctx, group.UUID, newName); err != nil {
return errors.E(op, err)
}
return nil
Expand Down

0 comments on commit 97910cd

Please sign in to comment.