Skip to content

Commit

Permalink
Merge pull request #48 from jpbland1/separate-big-key
Browse files Browse the repository at this point in the history
refactor the HSM server and keystore to handle a
  • Loading branch information
billphipps authored Aug 9, 2024
2 parents fcd66ba + 88da7ef commit 79f49b5
Show file tree
Hide file tree
Showing 5 changed files with 367 additions and 131 deletions.
153 changes: 86 additions & 67 deletions src/wh_server_crypto.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,26 @@
#ifndef NO_RSA
static int hsmCacheKeyRsa(whServerContext* server, RsaKey* key, whKeyId* outId)
{
uint8_t* cacheBuf;
whNvmMetadata* cacheMeta;
int ret = 0;
int slotIdx = 0;
/* wc_RsaKeyToDer doesn't have a length check option so we need to just pass
* the big key size if compiled */
const uint16_t keySz = WOLFHSM_CFG_SERVER_KEYCACHE_BIG_BUFSIZE;
whKeyId keyId = WH_KEYTYPE_CRYPTO;
/* get a free slot */
ret = slotIdx = hsmCacheFindSlot(server);
if (ret >= 0) {
ret = hsmCacheFindSlotAndZero(server, keySz, &cacheBuf, &cacheMeta);
if (ret == 0)
ret = hsmGetUniqueId(server, &keyId);
}
if (ret == 0) {
/* export key */
/* TODO: Fix wolfCrypto to allow KeyToDer when KEY_GEN is NOT set */
XMEMSET((uint8_t*)&server->cache[slotIdx], 0, sizeof(whServerCacheSlot));
ret = wc_RsaKeyToDer(key, server->cache[slotIdx].buffer,
WOLFHSM_CFG_SERVER_KEYCACHE_BUFSIZE);
ret = wc_RsaKeyToDer(key, cacheBuf, keySz);
}
if (ret > 0) {
/* set meta */
server->cache[slotIdx].meta->id = keyId;
server->cache[slotIdx].meta->len = ret;
cacheMeta->id = keyId;
cacheMeta->len = ret;
/* export keyId */
*outId = keyId;
ret = 0;
Expand All @@ -77,29 +78,30 @@ static int hsmCacheKeyRsa(whServerContext* server, RsaKey* key, whKeyId* outId)

static int hsmLoadKeyRsa(whServerContext* server, RsaKey* key, whKeyId keyId)
{
uint8_t* cacheBuf;
whNvmMetadata* cacheMeta;
int ret = 0;
int slotIdx = 0;
uint32_t idx = 0;
uint32_t size;
keyId |= (WH_KEYTYPE_CRYPTO | (server->comm->client_id << 8));
/* freshen the key */
ret = slotIdx = hsmFreshenKey(server, keyId);
ret = hsmFreshenKey(server, keyId, &cacheBuf, &cacheMeta);
/* decode the key */
if (ret >= 0) {
size = WOLFHSM_CFG_SERVER_KEYCACHE_BUFSIZE;
ret = wc_RsaPrivateKeyDecode(server->cache[slotIdx].buffer, (word32*)&idx, key,
size);
if (ret == 0) {
ret = wc_RsaPrivateKeyDecode(cacheBuf, (word32*)&idx, key,
cacheMeta->len);
}
return ret;
}

#ifdef WOLFSSL_KEY_GEN
static int hsmCryptoRsaKeyGen(whServerContext* server, whPacket* packet,
uint16_t* size)
{
int ret;
whKeyId keyId = WH_KEYID_ERASED;
/* init the rsa key */
ret = wc_InitRsaKey_ex(server->crypto->algoCtx.rsa, NULL, server->crypto->devId);
ret = wc_InitRsaKey_ex(server->crypto->algoCtx.rsa, NULL,
server->crypto->devId);
/* make the rsa key with the given params */
if (ret == 0) {
#ifdef DEBUG_CRYPTOCB_VERBOSE
Expand Down Expand Up @@ -139,7 +141,8 @@ static int hsmCryptoRsaFunction(whServerContext* server, whPacket* packet,
byte* in = (uint8_t*)(&packet->pkRsaReq + 1);
byte* out = (uint8_t*)(&packet->pkRsaRes + 1);
/* init rsa key */
ret = wc_InitRsaKey_ex(server->crypto->algoCtx.rsa, NULL, server->crypto->devId);
ret = wc_InitRsaKey_ex(server->crypto->algoCtx.rsa, NULL,
server->crypto->devId);
/* load the key from the keystore */
if (ret == 0) {
ret = hsmLoadKeyRsa(server, server->crypto->algoCtx.rsa,
Expand Down Expand Up @@ -176,7 +179,8 @@ static int hsmCryptoRsaGetSize(whServerContext* server, whPacket* packet,
{
int ret;
/* init rsa key */
ret = wc_InitRsaKey_ex(server->crypto->algoCtx.rsa, NULL, server->crypto->devId);
ret = wc_InitRsaKey_ex(server->crypto->algoCtx.rsa, NULL,
server->crypto->devId);
/* load the key from the keystore */
if (ret == 0) {
ret = hsmLoadKeyRsa(server, server->crypto->algoCtx.rsa,
Expand All @@ -200,27 +204,27 @@ static int hsmCryptoRsaGetSize(whServerContext* server, whPacket* packet,
static int hsmCacheKeyCurve25519(whServerContext* server, curve25519_key* key,
whKeyId* outId)
{
uint8_t* cacheBuf;
whNvmMetadata* cacheMeta;
int ret;
int slotIdx = 0;
word32 privSz = CURVE25519_KEYSIZE;
word32 pubSz = CURVE25519_KEYSIZE;
whKeyId keyId = WH_KEYTYPE_CRYPTO;
const uint16_t keySz = CURVE25519_KEYSIZE * 2;
/* get a free slot */
ret = slotIdx = hsmCacheFindSlot(server);
if (ret >= 0) {
ret = hsmCacheFindSlotAndZero(server, keySz, &cacheBuf,
&cacheMeta);
if (ret == 0)
ret = hsmGetUniqueId(server, &keyId);
}
if (ret == 0) {
XMEMSET((uint8_t*)&server->cache[slotIdx], 0, sizeof(whServerCacheSlot));
/* export key */
ret = wc_curve25519_export_key_raw(key,
server->cache[slotIdx].buffer + CURVE25519_KEYSIZE, &privSz,
server->cache[slotIdx].buffer, &pubSz);
ret = wc_curve25519_export_key_raw(key, cacheBuf + CURVE25519_KEYSIZE,
&privSz, cacheBuf, &pubSz);
}
if (ret == 0) {
/* set meta */
server->cache[slotIdx].meta->id = keyId;
server->cache[slotIdx].meta->len = CURVE25519_KEYSIZE * 2;
cacheMeta->id = keyId;
cacheMeta->len = keySz;
/* export keyId */
*outId = keyId;
}
Expand All @@ -230,22 +234,21 @@ static int hsmCacheKeyCurve25519(whServerContext* server, curve25519_key* key,
static int hsmLoadKeyCurve25519(whServerContext* server, curve25519_key* key,
whKeyId keyId)
{
uint8_t* cacheBuf;
whNvmMetadata* cacheMeta;
int ret = 0;
int slotIdx = 0;
uint32_t privSz = CURVE25519_KEYSIZE;
uint32_t pubSz = CURVE25519_KEYSIZE;
keyId |= WH_KEYTYPE_CRYPTO;
/* freshen the key */
ret = slotIdx = hsmFreshenKey(server, keyId);
ret = hsmFreshenKey(server, keyId, &cacheBuf, &cacheMeta);
/* decode the key */
if (ret >= 0) {
ret = wc_curve25519_import_public(server->cache[slotIdx].buffer,
(word32)pubSz, key);
}
if (ret == 0)
ret = wc_curve25519_import_public(cacheBuf, (word32)pubSz, key);
/* only import private if what we got back holds 2 keys */
if (ret == 0 && server->cache[slotIdx].meta->len == CURVE25519_KEYSIZE * 2) {
ret = wc_curve25519_import_private(
server->cache[slotIdx].buffer + pubSz, (word32)privSz, key);
if (ret == 0 && cacheMeta->len == CURVE25519_KEYSIZE * 2) {
ret = wc_curve25519_import_private( cacheBuf + pubSz, (word32)privSz,
key);
}
return ret;
}
Expand All @@ -266,8 +269,8 @@ static int hsmCryptoCurve25519KeyGen(whServerContext* server, whPacket* packet,
}
/* cache the generated key */
if (ret == 0) {
ret = hsmCacheKeyCurve25519(server, server->crypto->algoCtx.curve25519Private,
&keyId);
ret = hsmCacheKeyCurve25519(server,
server->crypto->algoCtx.curve25519Private, &keyId);
}
/* set the assigned id */
wc_curve25519_free(server->crypto->algoCtx.curve25519Private);
Expand Down Expand Up @@ -331,8 +334,9 @@ static int hsmCryptoCurve25519(whServerContext* server, whPacket* packet,
#ifdef HAVE_ECC
static int hsmCacheKeyEcc(whServerContext* server, ecc_key* key, whKeyId* outId)
{
uint8_t* cacheBuf;
whNvmMetadata* cacheMeta;
int ret;
int slotIdx = 0;
word32 qxLen = 0;
word32 qyLen = 0;
word32 qdLen = 0;
Expand All @@ -341,22 +345,21 @@ static int hsmCacheKeyEcc(whServerContext* server, ecc_key* key, whKeyId* outId)
byte* qyBuf = NULL;
byte* qdBuf = NULL;
/* get a free slot */
ret = slotIdx = hsmCacheFindSlot(server);
if (ret >= 0) {
ret = hsmCacheFindSlotAndZero(server, qxLen + qyLen + qdLen, &cacheBuf,
&cacheMeta);
if (ret == 0)
ret = hsmGetUniqueId(server, &keyId);
}
/* export key */
if (ret == 0) {
XMEMSET((uint8_t*)&server->cache[slotIdx], 0, sizeof(whServerCacheSlot));
if (key->type != ECC_PRIVATEKEY_ONLY) {
qxLen = qyLen = key->dp->size;
qxBuf = server->cache[slotIdx].buffer;
qxBuf = cacheBuf;
qyBuf = qxBuf + qxLen;
}
if (key->type == ECC_PRIVATEKEY_ONLY || key->type == ECC_PRIVATEKEY) {
qdLen = key->dp->size;
if (key->type == ECC_PRIVATEKEY_ONLY) {
qdBuf = server->cache[slotIdx].buffer;
qdBuf = cacheBuf;
}
else {
qdBuf = qyBuf + qyLen;
Expand All @@ -367,8 +370,8 @@ static int hsmCacheKeyEcc(whServerContext* server, ecc_key* key, whKeyId* outId)
}
if (ret == 0) {
/* set meta */
server->cache[slotIdx].meta->id = keyId;
server->cache[slotIdx].meta->len = qxLen + qyLen + qdLen;
cacheMeta->id = keyId;
cacheMeta->len = qxLen + qyLen + qdLen;
/* export keyId */
*outId = keyId;
}
Expand All @@ -378,8 +381,9 @@ static int hsmCacheKeyEcc(whServerContext* server, ecc_key* key, whKeyId* outId)
static int hsmLoadKeyEcc(whServerContext* server, ecc_key* key, uint16_t keyId,
int curveId)
{
uint8_t* cacheBuf;
whNvmMetadata* cacheMeta;
int ret;
int slotIdx = 0;
int curveIdx;
word32 qxLen = 0;
word32 qyLen = 0;
Expand All @@ -390,7 +394,7 @@ static int hsmLoadKeyEcc(whServerContext* server, ecc_key* key, uint16_t keyId,
byte* qdBuf = NULL;
keyId |= WH_KEYTYPE_CRYPTO;
/* freshen the key */
ret = slotIdx = hsmFreshenKey(server, keyId);
ret = hsmFreshenKey(server, keyId, &cacheBuf, &cacheMeta);
/* get the size by curveId */
if (ret >= 0) {
ret = curveIdx = wc_ecc_get_curve_idx(curveId);
Expand All @@ -402,20 +406,20 @@ static int hsmLoadKeyEcc(whServerContext* server, ecc_key* key, uint16_t keyId,
if (ret >= 0) {
/* determine which buffers should be set by size, wc_ecc_import_unsigned
* will set the key type accordingly */
if (server->cache[slotIdx].meta->len == keySz * 3) {
if (cacheMeta->len == keySz * 3) {
qxLen = qyLen = qdLen = keySz;
qxBuf = server->cache[slotIdx].buffer;
qxBuf = cacheBuf;
qyBuf = qxBuf + qxLen;
qdBuf = qyBuf + qyLen;
}
else if (server->cache[slotIdx].meta->len == keySz * 2) {
else if (cacheMeta->len == keySz * 2) {
qxLen = qyLen = keySz;
qxBuf = server->cache[slotIdx].buffer;
qxBuf = cacheBuf;
qyBuf = qxBuf + qxLen;
}
else {
qxLen = qyLen = qdLen = keySz;
qdBuf = server->cache[slotIdx].buffer;
qdBuf = cacheBuf;
}
ret = wc_ecc_import_unsigned(key, qxBuf, qyBuf, qdBuf, curveId);
}
Expand All @@ -437,8 +441,10 @@ static int hsmCryptoEcKeyGen(whServerContext* server, whPacket* packet,
packet->pkEckgReq.curveId);
}
/* cache the generated key */
if (ret == 0)
ret = hsmCacheKeyEcc(server, server->crypto->algoCtx.eccPrivate, &keyId);
if (ret == 0) {
ret = hsmCacheKeyEcc(server, server->crypto->algoCtx.eccPrivate,
&keyId);
}
/* set the assigned id */
wc_ecc_free(server->crypto->algoCtx.eccPrivate);
if (ret == 0) {
Expand Down Expand Up @@ -468,8 +474,10 @@ static int hsmCryptoEcdh(whServerContext* server, whPacket* packet,
packet->pkEcdhReq.privateKeyId, packet->pkEcdhReq.curveId);
}
/* set rng */
if (ret == 0)
ret = wc_ecc_set_rng(server->crypto->algoCtx.eccPrivate, server->crypto->rng);
if (ret == 0) {
ret = wc_ecc_set_rng(server->crypto->algoCtx.eccPrivate,
server->crypto->rng);
}
/* load the public key */
if (ret == 0) {
ret = hsmLoadKeyEcc(server, server->crypto->pubKey.eccPublic,
Expand Down Expand Up @@ -540,7 +548,8 @@ static int hsmCryptoEcdsaVerify(whServerContext* server, whPacket* packet,
/* verify the signature */
if (ret == 0) {
ret = wc_ecc_verify_hash(sig, packet->pkEccVerifyReq.sigSz, hash,
packet->pkEccVerifyReq.hashSz, &res, server->crypto->pubKey.eccPublic);
packet->pkEccVerifyReq.hashSz, &res,
server->crypto->pubKey.eccPublic);
}
wc_ecc_free(server->crypto->pubKey.eccPublic);
if (ret == 0) {
Expand Down Expand Up @@ -603,8 +612,10 @@ static int hsmCryptoAesCbc(whServerContext* server, whPacket* packet,
}
}
/* init key with possible hardware */
if (ret == 0)
ret = wc_AesInit(server->crypto->algoCtx.aes, NULL, server->crypto->devId);
if (ret == 0) {
ret = wc_AesInit(server->crypto->algoCtx.aes, NULL,
server->crypto->devId);
}
/* load the key */
if (ret == 0) {
ret = wc_AesSetKey(server->crypto->algoCtx.aes, key,
Expand Down Expand Up @@ -662,8 +673,10 @@ static int hsmCryptoAesGcm(whServerContext* server, whPacket* packet,
}
}
/* init key with possible hardware */
if (ret == 0)
ret = wc_AesInit(server->crypto->algoCtx.aes, NULL, server->crypto->devId);
if (ret == 0) {
ret = wc_AesInit(server->crypto->algoCtx.aes, NULL,
server->crypto->devId);
}
/* load the key */
if (ret == 0) {
ret = wc_AesGcmSetKey(server->crypto->algoCtx.aes, key,
Expand All @@ -681,8 +694,8 @@ static int hsmCryptoAesGcm(whServerContext* server, whPacket* packet,
/* copy authTagSz since it will be overwritten */
packet->cipherAesGcmRes.authTagSz =
packet->cipherAesGcmReq.authTagSz;
ret = wc_AesGcmEncrypt(server->crypto->algoCtx.aes, out, in, len, iv,
packet->cipherAesGcmReq.ivSz, authTag,
ret = wc_AesGcmEncrypt(server->crypto->algoCtx.aes, out, in, len,
iv, packet->cipherAesGcmReq.ivSz, authTag,
packet->cipherAesGcmReq.authTagSz, authIn,
packet->cipherAesGcmReq.authInSz);
}
Expand Down Expand Up @@ -724,6 +737,7 @@ static int hsmCryptoCmac(whServerContext* server, whPacket* packet,
byte* key = in + packet->cmacReq.inSz;
byte* out = (uint8_t*)(&packet->cmacRes + 1);
whNvmMetadata meta[1] = {{0}};
uint8_t moveToBigCache = 0;
/* do oneshot if all fields are present */
if (packet->cmacReq.inSz != 0 && packet->cmacReq.keySz != 0 &&
packet->cmacReq.outSz != 0) {
Expand Down Expand Up @@ -753,6 +767,7 @@ static int hsmCryptoCmac(whServerContext* server, whPacket* packet,
* overwrite the existing key on exit */
if (len == AES_128_KEY_SIZE || len == AES_192_KEY_SIZE ||
len == AES_256_KEY_SIZE) {
moveToBigCache = 1;
XMEMCPY(tmpKey, (uint8_t*)server->crypto->algoCtx.cmac, len);
/* type is not a part of the update call, assume AES */
ret = wc_InitCmac_ex(server->crypto->algoCtx.cmac, tmpKey, len,
Expand Down Expand Up @@ -807,6 +822,10 @@ static int hsmCryptoCmac(whServerContext* server, whPacket* packet,
keyId = WH_MAKE_KEYID(WH_KEYTYPE_CRYPTO,
server->comm->client_id, packet->cmacReq.keyId);
}
/* evict the aes sized key in the normal cache */
if (moveToBigCache == 1) {
ret = hsmEvictKey(server, keyId);
}
meta->id = keyId;
meta->len = sizeof(server->crypto->algoCtx.cmac);
ret = hsmCacheKey(server, meta, (uint8_t*)server->crypto->algoCtx.cmac);
Expand Down
Loading

0 comments on commit 79f49b5

Please sign in to comment.