Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite mjwt library to better support keystores #3

Merged
merged 4 commits into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 3 additions & 14 deletions auth/access-token.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,25 @@ package auth

import (
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"time"
)

// AccessTokenClaims contains the JWT claims for an access token
type AccessTokenClaims struct {
Perms *claims.PermStorage `json:"per"`
Perms *PermStorage `json:"per"`
}

func (a AccessTokenClaims) Valid() error { return nil }

func (a AccessTokenClaims) Type() string { return "access-token" }

// CreateAccessToken creates an access token with the default 15 minute duration
func CreateAccessToken(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
func CreateAccessToken(p *mjwt.Issuer, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, aud, perms)
}

// CreateAccessTokenWithDuration creates an access token with a custom duration
func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
func CreateAccessTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
return p.GenerateJwt(sub, id, aud, dur, &AccessTokenClaims{Perms: perms})
}

// CreateAccessTokenWithKID creates an access token with the default 15 minute duration and the specified kID
func CreateAccessTokenWithKID(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) {
return CreateAccessTokenWithDurationAndKID(p, time.Minute*15, sub, id, aud, perms, kID)
}

// CreateAccessTokenWithDurationAndKID creates an access token with a custom duration and the specified kID
func CreateAccessTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) {
return p.GenerateJwtWithKID(sub, id, aud, dur, &AccessTokenClaims{Perms: perms}, kID)
}
40 changes: 6 additions & 34 deletions auth/access-token_test.go
Original file line number Diff line number Diff line change
@@ -1,55 +1,27 @@
package auth

import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/assert"
"testing"
)

func TestCreateAccessToken(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)

ps := claims.NewPermStorage()
ps := NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")

s := mjwt.NewMJwtSigner("mjwt.test", key)

accessToken, err := CreateAccessToken(s, "1", "test", nil, ps)
assert.NoError(t, err)

_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
}

func TestCreateAccessTokenInvalid(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)

kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey("test", key)

ps := claims.NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")

s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)

accessToken, err := CreateAccessTokenWithKID(s, "1", "test", nil, ps, "test")
accessToken, err := CreateAccessToken(s, "1", "test", nil, ps)
assert.NoError(t, err)

_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
Expand Down
25 changes: 2 additions & 23 deletions auth/pair.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,19 @@ package auth

import (
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"time"
)

// CreateTokenPair creates an access and refresh token pair using the default
// 15 minute and 7 day durations respectively
func CreateTokenPair(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
func CreateTokenPair(p *mjwt.Issuer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms)
}

// CreateTokenPairWithDuration creates an access and refresh token pair using
// custom durations for the access and refresh tokens
func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
func CreateTokenPairWithDuration(p *mjwt.Issuer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, aud, perms)
if err != nil {
return "", "", err
Expand All @@ -26,23 +25,3 @@ func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Durat
}
return accessToken, refreshToken, nil
}

// CreateTokenPairWithKID creates an access and refresh token pair using the default
// 15 minute and 7 day durations respectively using the specified kID
func CreateTokenPairWithKID(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) {
return CreateTokenPairWithDurationAndKID(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms, kID)
}

// CreateTokenPairWithDurationAndKID creates an access and refresh token pair using
// custom durations for the access and refresh tokens
func CreateTokenPairWithDurationAndKID(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) {
accessToken, err := CreateAccessTokenWithDurationAndKID(p, accessDur, sub, id, aud, perms, kID)
if err != nil {
return "", "", err
}
refreshToken, err := CreateRefreshTokenWithDurationAndKID(p, refreshDur, sub, rId, id, rAud, kID)
if err != nil {
return "", "", err
}
return accessToken, refreshToken, nil
}
47 changes: 7 additions & 40 deletions auth/pair_test.go
Original file line number Diff line number Diff line change
@@ -1,68 +1,35 @@
package auth

import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/assert"
"testing"
)

func TestCreateTokenPair(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)

ps := claims.NewPermStorage()
ps := NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")

s := mjwt.NewMJwtSigner("mjwt.test", key)

accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", "test2", nil, nil, ps)
assert.NoError(t, err)

_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))

_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b2.Subject)
assert.Equal(t, "test2", b2.ID)
}

func TestCreateTokenPairWithKID(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key2", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)

kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey("test", key)

ps := claims.NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")

s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)

accessToken, refreshToken, err := CreateTokenPairWithKID(s, "1", "test", "test2", nil, nil, ps, "test")
accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", "test2", nil, nil, ps)
assert.NoError(t, err)

_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))

_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b2.Subject)
assert.Equal(t, "test2", b2.ID)
Expand Down
2 changes: 1 addition & 1 deletion claims/perms.go → auth/perms.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package claims
package auth

import (
"bufio"
Expand Down
2 changes: 1 addition & 1 deletion claims/perms_test.go → auth/perms_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package claims
package auth

import (
"bytes"
Expand Down
14 changes: 2 additions & 12 deletions auth/refresh-token.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,11 @@ func (r RefreshTokenClaims) Valid() error { return nil }
func (r RefreshTokenClaims) Type() string { return "refresh-token" }

// CreateRefreshToken creates a refresh token with the default 7 day duration
func CreateRefreshToken(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
func CreateRefreshToken(p *mjwt.Issuer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, ati, aud)
}

// CreateRefreshTokenWithDuration creates a refresh token with a custom duration
func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
func CreateRefreshTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
return p.GenerateJwt(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati})
}

// CreateRefreshTokenWithKID creates a refresh token with the default 7 day duration and the specified kID
func CreateRefreshTokenWithKID(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) {
return CreateRefreshTokenWithDurationAndKID(p, time.Hour*24*7, sub, id, ati, aud, kID)
}

// CreateRefreshTokenWithDurationAndKID creates a refresh token with a custom duration and the specified kID
func CreateRefreshTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) {
return p.GenerateJwtWithKID(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati}, kID)
}
31 changes: 5 additions & 26 deletions auth/refresh-token_test.go
Original file line number Diff line number Diff line change
@@ -1,44 +1,23 @@
package auth

import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/assert"
"testing"
)

func TestCreateRefreshToken(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)

s := mjwt.NewMJwtSigner("mjwt.test", key)

refreshToken, err := CreateRefreshToken(s, "1", "test", "test2", nil)
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", jwt.SigningMethodRS512, kStore)
assert.NoError(t, err)

_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
assert.Equal(t, "test2", b.Claims.AccessTokenId)
}

func TestCreateRefreshTokenWithKID(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)

kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey("test", key)

s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)

refreshToken, err := CreateRefreshTokenWithKID(s, "1", "test", "test2", nil, "test")
refreshToken, err := CreateRefreshToken(s, "1", "test", "test2", nil)
assert.NoError(t, err)

_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
Expand Down
8 changes: 4 additions & 4 deletions mjwt.go → claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ import (
var ErrClaimTypeMismatch = errors.New("claim type mismatch")

// wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct
func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
func wrapClaims[T Claims](sub, id, issuer string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
now := time.Now()
return (&BaseTypeClaims[T]{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: p.Issuer(),
Issuer: issuer,
Subject: sub,
Audience: aud,
ExpiresAt: jwt.NewNumericDate(now.Add(dur)),
Expand All @@ -28,12 +28,12 @@ func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur ti

// ExtractClaims uses a Verifier to validate the MJWT token and returns the parsed
// token and BaseTypeClaims
func ExtractClaims[T Claims](p Verifier, token string) (*jwt.Token, BaseTypeClaims[T], error) {
func ExtractClaims[T Claims](ks *KeyStore, token string) (*jwt.Token, BaseTypeClaims[T], error) {
b := BaseTypeClaims[T]{
RegisteredClaims: jwt.RegisteredClaims{},
Claims: *new(T),
}
tok, err := p.VerifyJwt(token, &b)
tok, err := ks.VerifyJwt(token, &b)
return tok, b, err
}

Expand Down
Loading
Loading