Skip to content

Commit

Permalink
feat: replace pyjwkest with pyjwt
Browse files Browse the repository at this point in the history
  • Loading branch information
mumarkhan999 committed Jun 16, 2023
1 parent 758f710 commit 58d2482
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,62 +89,78 @@ KeyPair Generation

Here is code for generating a keypair::

import json
from Cryptodome.PublicKey import RSA
from jwkest import jwk
from jwt.algorithms import get_default_algorithms

rsa_key = RSA.generate(2048)
rsa_jwk = jwk.RSAKey(kid="your_key_id", key=rsa_key)
algo = get_default_algorithms()['RS512']
pem = rsa_key.export_key('PEM').decode()
rsa_jwk = json.loads(algo.to_jwk(algo.prepare_key(pem)))
public_rsa_jwk = json.loads(algo.to_jwk(algo.prepare_key(pem).public_key()))

kid = 'your_key_id'
rsa_jwk['kid'] = kid
public_rsa_jwk['kid'] = kid


To serialize the **public key** in a `JSON Web Key Set (JWK Set)`_::

public_keys = jwk.KEYS()
public_keys.append(rsa_jwk)
serialized_public_keys_json = public_keys.dump_jwks()
public_keys = { 'keys': [public_rsa_jwk] }
serialized_public_keys_json = json.dumps(public_keys)

and its sample output::

{
"keys": [
{
"kid": "your_key_id",
"e": "strawberry",
"kty": "RSA",
"n": "something"
}
]
}
"""
{
"keys": [
{
"kty": "RSA",
"key_ops": ["verify"],
"n": "banana",
"e": "AQAB",
"kid": "your_key_id"
}
]
}
"""

To serialize the **keypair** as a JWK::

serialized_keypair = rsa_jwk.serialize(private=True)
serialized_keypair_json = json.dumps(serialized_keypair)
serialized_keypair_json = json.dumps(rsa_jwk)

and its sample output::

{
"e": "strawberry",
"d": "apple",
"n": "banana",
"q": "pear",
"p": "plum",
"kid": "your_key_id",
"kty": "RSA"
}
"""
{
"kty": "RSA",
"key_ops": ["sign"],
"n": "banana",
"e": "AQAB",
"d": "apple",
"p": "peach",
"q": "pear",
"dp": "palm",
"dq": "pineapple",
"qi": "watermelon",
"kid": "your_key_id"
}
"""

Signing
~~~~~~~

To deserialize the keypair from above::

private_keys = jwk.KEYS()
serialized_keypair = json.loads(serialized_keypair_json)
private_keys.add(serialized_keypair)
from jwt.api_jwk import PyJWK

private_key = PyJWK.from_json(serialized_keypair_json)

To create a signature::

from jwkest.jws import JWS
jws = JWS("JWT payload", alg="RS512")
signed_message = jws.sign_compact(keys=private_keys)
import jwt

signed_message = jwt.encode("JWT payload in dict format", key=private_key.key, algorithm="RS512")

Note: we specify **RS512** above to identify *RSASSA-PKCS1-v1_5 using SHA-512* as
the signature algorithm value as described in the `JSON Web Algorithms (JWA)`_ spec.
Expand All @@ -156,19 +172,44 @@ Verify Signature

To verify the signature from above::

public_keys = jwk.KEYS()
public_keys.load_jwks(serialized_public_keys_json)
jws.verify_compact(signed_message, public_keys)
def _verify_jws_using_keyset(signed_message, key_set)
for i in range(len(key_set)):
try:
algorithms = None
if key_set[i].key_type == 'RSA':
algorithms = ['RS256', 'RS512',]
elif key_set[i].key_type == 'oct':
algorithms = ['HS256',]

return jwt.decode(
signed_message,
key=key_set[i].key,
algorithms=algorithms,
)
except Exception:
if i == len(key_set) - 1:
raise

key_set = PyJWKSet.from_json(serialized_public_keys_json).keys
verified_message = _verify_jws_using_keyset(signed_message, key_set)

Key Rotation
~~~~~~~~~~~~

When a new public key is added in the future, it should have a unique "kid"
value and added to the public keys JWK set::

new_rsa_key = RSA.generate(2048)
new_rsa_jwk = jwk.RSAKey(kid="new_id", key=new_rsa_key)
public_keys.append(new_rsa_jwk)
rsa_key = RSA.generate(2048)
algo = get_default_algorithms()['RS512']
pem = rsa_key.export_key('PEM').decode()
rsa_jwk = json.loads(algo.to_jwk(algo.prepare_key(pem)))
public_rsa_jwk = json.loads(algo.to_jwk(algo.prepare_key(pem).public_key()))

kid = 'new_id'
rsa_jwk['kid'] = kid
public_rsa_jwk['kid'] = kid

public_keys['keys'].append(public_rsa_jwk)

When a JWS is created, it is signed with a certain "kid"-identified keypair. When it
is later verified, the public key with the matching "kid" in the JWK set is used.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from Cryptodome.PublicKey import RSA
from django.conf import settings
from django.core.management.base import BaseCommand
from jwkest import jwk
from jwt.algorithms import get_default_algorithms

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -123,15 +123,23 @@ def _generate_key_id(self, size, chars=string.ascii_uppercase + string.digits):
def _generate_key_pair(self, key_size, key_id):
log.info('Generating new JWT signing keypair for key id %s.', key_id)
rsa_key = RSA.generate(key_size)
rsa_jwk = jwk.RSAKey(kid=key_id, key=rsa_key)
return rsa_jwk
algo = get_default_algorithms()['RS512']
key_data = algo.prepare_key(rsa_key.export_key('PEM').decode())
rsa_jwk = json.loads(algo.to_jwk(key_data))
public_rsa_jwk = json.loads(algo.to_jwk(key_data.public_key()))

rsa_jwk['kid'] = key_id
public_rsa_jwk['kid'] = key_id
return {'private': rsa_jwk, 'public': public_rsa_jwk}

def _output_public_keys(self, jwk_key, add_previous, strip_prefix):
public_keys = jwk.KEYS()
public_keys = {'keys': []}

if add_previous:
self._add_previous_public_keys(public_keys)
public_keys.append(jwk_key)
serialized_public_keys = public_keys.dump_jwks()

public_keys['keys'].append(jwk_key['public'])
serialized_public_keys = json.dumps(public_keys)

prefix = '' if strip_prefix else 'COMMON_'
public_signing_key = f'{prefix}JWT_PUBLIC_SIGNING_JWK_SET'
Expand All @@ -155,11 +163,10 @@ def _add_previous_public_keys(self, public_keys):
previous_signing_keys = settings.JWT_AUTH.get('JWT_PUBLIC_SIGNING_JWK_SET')
if previous_signing_keys:
log.info('Old JWT_PUBLIC_SIGNING_JWK_SET: %s.', previous_signing_keys)
public_keys.load_jwks(previous_signing_keys)
public_keys['keys'].extend(json.loads(previous_signing_keys)['keys'])

def _output_private_keys(self, jwk_key, strip_prefix):
serialized_keypair = jwk_key.serialize(private=True)
serialized_keypair_json = json.dumps(serialized_keypair)
serialized_keypair_json = json.dumps(jwk_key['private'])

prefix = '' if strip_prefix else 'EDXAPP_'
private_signing_key = f'{prefix}JWT_PRIVATE_SIGNING_JWK'
Expand Down

0 comments on commit 58d2482

Please sign in to comment.