From aa7e063adaf67b65d6b6e628eb2f15d5858ce0b3 Mon Sep 17 00:00:00 2001 From: M Umar Khan Date: Tue, 30 May 2023 13:30:43 +0500 Subject: [PATCH] feat: replace pyjwkest with pyjwt --- .../decisions/0008-use-asymmetric-jwts.rst | 117 ++++++++++++------ 1 file changed, 79 insertions(+), 38 deletions(-) diff --git a/openedx/core/djangoapps/oauth_dispatch/docs/decisions/0008-use-asymmetric-jwts.rst b/openedx/core/djangoapps/oauth_dispatch/docs/decisions/0008-use-asymmetric-jwts.rst index d65aaaf4a90a..03cfbaa128ad 100644 --- a/openedx/core/djangoapps/oauth_dispatch/docs/decisions/0008-use-asymmetric-jwts.rst +++ b/openedx/core/djangoapps/oauth_dispatch/docs/decisions/0008-use-asymmetric-jwts.rst @@ -89,62 +89,78 @@ KeyPair Generation Here is code for generating a keypair:: + import json from Cryptodome.PublicKey import RSA - from jwkest import jwk + from jwt.algorithms import get_default_algorithms rsa_key = RSA.generate(2048) - rsa_jwk = jwk.RSAKey(kid="your_key_id", key=rsa_key) + algo = get_default_algorithms()['RS512'] + pem = rsa_key.export_key('PEM').decode() + rsa_jwk = json.loads(algo.to_jwk(algo.prepare_key(pem))) + public_rsa_jwk = json.loads(algo.to_jwk(algo.prepare_key(pem).public_key())) + + kid = 'your_key_id' + rsa_jwk['kid'] = kid + public_rsa_jwk['kid'] = kid + To serialize the **public key** in a `JSON Web Key Set (JWK Set)`_:: - public_keys = jwk.KEYS() - public_keys.append(rsa_jwk) - serialized_public_keys_json = public_keys.dump_jwks() + public_keys = { 'keys': [public_rsa_jwk] } + serialized_public_keys_json = json.dumps(public_keys) and its sample output:: - { - "keys": [ - { - "kid": "your_key_id", - "e": "strawberry", - "kty": "RSA", - "n": "something" - } - ] - } + """ + { + "keys": [ + { + "kty": "RSA", + "key_ops": ["verify"], + "n": "banana", + "e": "AQAB", + "kid": "your_key_id" + } + ] + } + """ To serialize the **keypair** as a JWK:: - serialized_keypair = rsa_jwk.serialize(private=True) - serialized_keypair_json = json.dumps(serialized_keypair) + serialized_keypair_json = json.dumps(rsa_jwk) and its sample output:: - { - "e": "strawberry", - "d": "apple", - "n": "banana", - "q": "pear", - "p": "plum", - "kid": "your_key_id", - "kty": "RSA" - } + """ + { + "kty": "RSA", + "key_ops": ["sign"], + "n": "banana", + "e": "AQAB", + "d": "apple", + "p": "peach", + "q": "pear", + "dp": "palm", + "dq": "pineapple", + "qi": "watermelon", + "kid": "your_key_id" + } + """ Signing ~~~~~~~ To deserialize the keypair from above:: - private_keys = jwk.KEYS() - serialized_keypair = json.loads(serialized_keypair_json) - private_keys.add(serialized_keypair) + from jwt.api_jwk import PyJWK + + private_key = PyJWK.from_json(serialized_keypair_json) To create a signature:: - from jwkest.jws import JWS - jws = JWS("JWT payload", alg="RS512") - signed_message = jws.sign_compact(keys=private_keys) + import jwt + + signed_message = jwt.encode("JWT payload in dict format", key=private_key.key, algorithm="RS512") Note: we specify **RS512** above to identify *RSASSA-PKCS1-v1_5 using SHA-512* as the signature algorithm value as described in the `JSON Web Algorithms (JWA)`_ spec. @@ -156,9 +172,26 @@ Verify Signature To verify the signature from above:: - public_keys = jwk.KEYS() - public_keys.load_jwks(serialized_public_keys_json) - jws.verify_compact(signed_message, public_keys) + def _verify_jws_using_keyset(signed_message, key_set) + for i in range(len(key_set)): + try: + algorithms = None + if key_set[i].key_type == 'RSA': + algorithms = ['RS256', 'RS512',] + elif key_set[i].key_type == 'oct': + algorithms = ['HS256',] + + return jwt.decode( + signed_message, + key=key_set[i].key, + algorithms=algorithms, + ) + except Exception: + if i == len(key_set) - 1: + raise + + key_set = PyJWKSet.from_json(serialized_public_keys_json).keys + verified_message = _verify_jws_using_keyset(signed_message, key_set) Key Rotation ~~~~~~~~~~~~ @@ -166,9 +199,17 @@ Key Rotation When a new public key is added in the future, it should have a unique "kid" value and added to the public keys JWK set:: - new_rsa_key = RSA.generate(2048) - new_rsa_jwk = jwk.RSAKey(kid="new_id", key=new_rsa_key) - public_keys.append(new_rsa_jwk) + rsa_key = RSA.generate(2048) + algo = get_default_algorithms()['RS512'] + pem = rsa_key.export_key('PEM').decode() + rsa_jwk = json.loads(algo.to_jwk(algo.prepare_key(pem))) + public_rsa_jwk = json.loads(algo.to_jwk(algo.prepare_key(pem).public_key())) + + kid = 'new_id' + rsa_jwk['kid'] = kid + public_rsa_jwk['kid'] = kid + + public_keys['keys'].append(public_rsa_jwk) When a JWS is created, it is signed with a certain "kid"-identified keypair. When it is later verified, the public key with the matching "kid" in the JWK set is used.