Skip to content

Commit

Permalink
Rewrite mjwt library to better support keystores
Browse files Browse the repository at this point in the history
  • Loading branch information
mrmelon54 committed Jul 27, 2024
1 parent 5d1bd6f commit cd2d80c
Show file tree
Hide file tree
Showing 26 changed files with 684 additions and 1,262 deletions.
17 changes: 3 additions & 14 deletions auth/access-token.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,25 @@ package auth

import (
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"time"
)

// AccessTokenClaims contains the JWT claims for an access token
type AccessTokenClaims struct {
Perms *claims.PermStorage `json:"per"`
Perms *PermStorage `json:"per"`
}

func (a AccessTokenClaims) Valid() error { return nil }

func (a AccessTokenClaims) Type() string { return "access-token" }

// CreateAccessToken creates an access token with the default 15 minute duration
func CreateAccessToken(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
func CreateAccessToken(p *mjwt.Issuer, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, aud, perms)
}

// CreateAccessTokenWithDuration creates an access token with a custom duration
func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
func CreateAccessTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
return p.GenerateJwt(sub, id, aud, dur, &AccessTokenClaims{Perms: perms})
}

// CreateAccessTokenWithKID creates an access token with the default 15 minute duration and the specified kID
func CreateAccessTokenWithKID(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) {
return CreateAccessTokenWithDurationAndKID(p, time.Minute*15, sub, id, aud, perms, kID)
}

// CreateAccessTokenWithDurationAndKID creates an access token with a custom duration and the specified kID
func CreateAccessTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) {
return p.GenerateJwtWithKID(sub, id, aud, dur, &AccessTokenClaims{Perms: perms}, kID)
}
39 changes: 5 additions & 34 deletions auth/access-token_test.go
Original file line number Diff line number Diff line change
@@ -1,55 +1,26 @@
package auth

import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/stretchr/testify/assert"
"testing"
)

func TestCreateAccessToken(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)

ps := claims.NewPermStorage()
ps := NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")

s := mjwt.NewMJwtSigner("mjwt.test", key)

accessToken, err := CreateAccessToken(s, "1", "test", nil, ps)
assert.NoError(t, err)

_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
}

func TestCreateAccessTokenInvalid(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
assert.NoError(t, err)

kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey("test", key)

ps := claims.NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")

s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)

accessToken, err := CreateAccessTokenWithKID(s, "1", "test", nil, ps, "test")
accessToken, err := CreateAccessToken(s, "1", "test", nil, ps)
assert.NoError(t, err)

_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
Expand Down
25 changes: 2 additions & 23 deletions auth/pair.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,19 @@ package auth

import (
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"time"
)

// CreateTokenPair creates an access and refresh token pair using the default
// 15 minute and 7 day durations respectively
func CreateTokenPair(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
func CreateTokenPair(p *mjwt.Issuer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms)
}

// CreateTokenPairWithDuration creates an access and refresh token pair using
// custom durations for the access and refresh tokens
func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
func CreateTokenPairWithDuration(p *mjwt.Issuer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, aud, perms)
if err != nil {
return "", "", err
Expand All @@ -26,23 +25,3 @@ func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Durat
}
return accessToken, refreshToken, nil
}

// CreateTokenPairWithKID creates an access and refresh token pair using the default
// 15 minute and 7 day durations respectively using the specified kID
func CreateTokenPairWithKID(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) {
return CreateTokenPairWithDurationAndKID(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms, kID)
}

// CreateTokenPairWithDurationAndKID creates an access and refresh token pair using
// custom durations for the access and refresh tokens
func CreateTokenPairWithDurationAndKID(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) {
accessToken, err := CreateAccessTokenWithDurationAndKID(p, accessDur, sub, id, aud, perms, kID)
if err != nil {
return "", "", err
}
refreshToken, err := CreateRefreshTokenWithDurationAndKID(p, refreshDur, sub, rId, id, rAud, kID)
if err != nil {
return "", "", err
}
return accessToken, refreshToken, nil
}
46 changes: 6 additions & 40 deletions auth/pair_test.go
Original file line number Diff line number Diff line change
@@ -1,68 +1,34 @@
package auth

import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/stretchr/testify/assert"
"testing"
)

func TestCreateTokenPair(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)

ps := claims.NewPermStorage()
ps := NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")

s := mjwt.NewMJwtSigner("mjwt.test", key)

accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", "test2", nil, nil, ps)
assert.NoError(t, err)

_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))

_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b2.Subject)
assert.Equal(t, "test2", b2.ID)
}

func TestCreateTokenPairWithKID(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key2", kStore)
assert.NoError(t, err)

kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey("test", key)

ps := claims.NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")

s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)

accessToken, refreshToken, err := CreateTokenPairWithKID(s, "1", "test", "test2", nil, nil, ps, "test")
accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", "test2", nil, nil, ps)
assert.NoError(t, err)

_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))

_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b2.Subject)
assert.Equal(t, "test2", b2.ID)
Expand Down
2 changes: 1 addition & 1 deletion claims/perms.go → auth/perms.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package claims
package auth

import (
"bufio"
Expand Down
2 changes: 1 addition & 1 deletion claims/perms_test.go → auth/perms_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package claims
package auth

import (
"bytes"
Expand Down
14 changes: 2 additions & 12 deletions auth/refresh-token.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,11 @@ func (r RefreshTokenClaims) Valid() error { return nil }
func (r RefreshTokenClaims) Type() string { return "refresh-token" }

// CreateRefreshToken creates a refresh token with the default 7 day duration
func CreateRefreshToken(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
func CreateRefreshToken(p *mjwt.Issuer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, ati, aud)
}

// CreateRefreshTokenWithDuration creates a refresh token with a custom duration
func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
func CreateRefreshTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
return p.GenerateJwt(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati})
}

// CreateRefreshTokenWithKID creates a refresh token with the default 7 day duration and the specified kID
func CreateRefreshTokenWithKID(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) {
return CreateRefreshTokenWithDurationAndKID(p, time.Hour*24*7, sub, id, ati, aud, kID)
}

// CreateRefreshTokenWithDurationAndKID creates a refresh token with a custom duration and the specified kID
func CreateRefreshTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) {
return p.GenerateJwtWithKID(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati}, kID)
}
30 changes: 4 additions & 26 deletions auth/refresh-token_test.go
Original file line number Diff line number Diff line change
@@ -1,44 +1,22 @@
package auth

import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/stretchr/testify/assert"
"testing"
)

func TestCreateRefreshToken(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)

s := mjwt.NewMJwtSigner("mjwt.test", key)

refreshToken, err := CreateRefreshToken(s, "1", "test", "test2", nil)
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
assert.NoError(t, err)

_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
assert.Equal(t, "test2", b.Claims.AccessTokenId)
}

func TestCreateRefreshTokenWithKID(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)

kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey("test", key)

s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)

refreshToken, err := CreateRefreshTokenWithKID(s, "1", "test", "test2", nil, "test")
refreshToken, err := CreateRefreshToken(s, "1", "test", "test2", nil)
assert.NoError(t, err)

_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
Expand Down
8 changes: 4 additions & 4 deletions mjwt.go → claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ import (
var ErrClaimTypeMismatch = errors.New("claim type mismatch")

// wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct
func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
func wrapClaims[T Claims](sub, id, issuer string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
now := time.Now()
return (&BaseTypeClaims[T]{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: p.Issuer(),
Issuer: issuer,
Subject: sub,
Audience: aud,
ExpiresAt: jwt.NewNumericDate(now.Add(dur)),
Expand All @@ -28,12 +28,12 @@ func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur ti

// ExtractClaims uses a Verifier to validate the MJWT token and returns the parsed
// token and BaseTypeClaims
func ExtractClaims[T Claims](p Verifier, token string) (*jwt.Token, BaseTypeClaims[T], error) {
func ExtractClaims[T Claims](ks *KeyStore, token string) (*jwt.Token, BaseTypeClaims[T], error) {
b := BaseTypeClaims[T]{
RegisteredClaims: jwt.RegisteredClaims{},
Claims: *new(T),
}
tok, err := p.VerifyJwt(token, &b)
tok, err := ks.VerifyJwt(token, &b)
return tok, b, err
}

Expand Down
Loading

0 comments on commit cd2d80c

Please sign in to comment.