Skip to content

Commit

Permalink
refactored DMA context struct, improved DMA registration/initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
bigbrett committed Apr 11, 2024
1 parent 2ffd116 commit 1fd84e3
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 32 deletions.
10 changes: 5 additions & 5 deletions src/wh_server.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ int wh_Server_Init(whServerContext* server, whServerConfig* config)
return WH_ERROR_ABORTED;
}

/* Initialize DMA configuration and callbacks */
if (NULL != config->dmaAddrAllowList) {
server->dmaAddrAllowList = config->dmaAddrAllowList;
/* Initialize DMA configuration and callbacks, if provided */
if (NULL != config->dmaConfig) {
server->dma.dmaAddrAllowList = config->dmaConfig->dmaAddrAllowList;
server->dma.cb32 = config->dmaConfig->cb32;
server->dma.cb64 = config->dmaConfig->cb64;
}
server->dmaCb.cb32 = NULL;
server->dmaCb.cb64 = NULL;

return rc;
}
Expand Down
42 changes: 26 additions & 16 deletions src/wh_server_dma.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,27 @@ static int _checkMemOperAgainstAllowList(const whDmaAddrAllowList* allowList,
return rc;
}

int wh_Server_DmaRegisterCb(whServerContext* server, whDmaCb cb)
int wh_Server_DmaRegisterCb32(whServerContext* server, whDmaClientMem32Cb cb)
{
if (NULL == server) {
if (NULL == server || NULL == cb) {
return WH_ERROR_BADARGS;
}

server->dmaCb = cb;
server->dma.cb32 = cb;

return WH_ERROR_OK;
}

int wh_Server_DmaRegisterCb64(whServerContext* server, whDmaClientMem64Cb cb)
{
if (NULL == server || NULL == cb) {
return WH_ERROR_BADARGS;
}

server->dma.cb64 = cb;

return WH_ERROR_OK;
}

int wh_Server_DmaRegisterAllowList(whServerContext* server,
const whDmaAddrAllowList* allowlist)
Expand All @@ -67,7 +77,7 @@ int wh_Server_DmaRegisterAllowList(whServerContext* server,
return WH_ERROR_BADARGS;
}

server->dmaAddrAllowList = allowlist;
server->dma.dmaAddrAllowList = allowlist;

return WH_ERROR_OK;
}
Expand All @@ -88,15 +98,15 @@ int wh_Server_DmaProcessClientAddress32(whServerContext* server,
*xformedCliAddr = (void*)((uintptr_t)clientAddr);

/* Perform user-supplied address transformation, cache manipulation, etc */
if (NULL != server->dmaCb.cb32) {
rc = server->dmaCb.cb32(server, clientAddr, xformedCliAddr, len, oper,
if (NULL != server->dma.cb32) {
rc = server->dma.cb32(server, clientAddr, xformedCliAddr, len, oper,
flags);
}

/* if the server has a allowlist registered, check transformed address
* against it */
if (server->dmaAddrAllowList != NULL) {
rc = _checkMemOperAgainstAllowList(server->dmaAddrAllowList, oper,
if (server->dma.dmaAddrAllowList != NULL) {
rc = _checkMemOperAgainstAllowList(server->dma.dmaAddrAllowList, oper,
*xformedCliAddr, len);
}

Expand All @@ -119,14 +129,14 @@ int wh_Server_DmaProcessClientAddress64(whServerContext* server,
*xformedCliAddr = (void*)((uintptr_t)clientAddr);

/* Perform user-supplied address transformation, cache manipulation, etc */
if (NULL != server->dmaCb.cb64) {
rc = server->dmaCb.cb64(server, clientAddr, xformedCliAddr, len, oper,
if (NULL != server->dma.cb64) {
rc = server->dma.cb64(server, clientAddr, xformedCliAddr, len, oper,
flags);
}

/* if the server has a allowlist registered, check address against it */
if (server->dmaAddrAllowList != NULL) {
rc = _checkMemOperAgainstAllowList(server->dmaAddrAllowList, oper,
if (server->dma.dmaAddrAllowList != NULL) {
rc = _checkMemOperAgainstAllowList(server->dma.dmaAddrAllowList, oper,
*xformedCliAddr, len);
}

Expand All @@ -149,7 +159,7 @@ int whServerDma_CopyFromClient32(struct whServerContext_t* server,

/* Check the server address against the allow list */
rc = _checkMemOperAgainstAllowList(
server->dmaAddrAllowList, WH_DMA_OPER_CLIENT_READ_PRE, serverPtr, len);
server->dma.dmaAddrAllowList, WH_DMA_OPER_CLIENT_READ_PRE, serverPtr, len);
if (rc != WH_ERROR_OK) {
return rc;
}
Expand Down Expand Up @@ -190,7 +200,7 @@ int whServerDma_CopyFromClient64(struct whServerContext_t* server,

/* Check the server address against the allow list */
rc = _checkMemOperAgainstAllowList(
server->dmaAddrAllowList, WH_DMA_OPER_CLIENT_READ_PRE, serverPtr, len);
server->dma.dmaAddrAllowList, WH_DMA_OPER_CLIENT_READ_PRE, serverPtr, len);
if (rc != WH_ERROR_OK) {
return rc;
}
Expand Down Expand Up @@ -230,7 +240,7 @@ int whServerDma_CopyToClient32(struct whServerContext_t* server,

/* Check the server address against the allow list */
rc = _checkMemOperAgainstAllowList(
server->dmaAddrAllowList, WH_DMA_OPER_CLIENT_WRITE_PRE, serverPtr, len);
server->dma.dmaAddrAllowList, WH_DMA_OPER_CLIENT_WRITE_PRE, serverPtr, len);
if (rc != WH_ERROR_OK) {
return rc;
}
Expand Down Expand Up @@ -271,7 +281,7 @@ int whServerDma_CopyToClient64(struct whServerContext_t* server,

/* Check the server address against the allow list */
rc = _checkMemOperAgainstAllowList(
server->dmaAddrAllowList, WH_DMA_OPER_CLIENT_WRITE_PRE, serverPtr, len);
server->dma.dmaAddrAllowList, WH_DMA_OPER_CLIENT_WRITE_PRE, serverPtr, len);
if (rc != WH_ERROR_OK) {
return rc;
}
Expand Down
4 changes: 2 additions & 2 deletions test/wh_test_clientserver.c
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,8 @@ static int _testDma(whServerContext* server, whClientContext* client)
};

/* Register a custom DMA callback */
WH_TEST_RETURN_ON_FAIL(wh_Server_DmaRegisterCb(
server, (whDmaCb){_customServerDma32Cb, _customServerDma64Cb}));
WH_TEST_RETURN_ON_FAIL(wh_Server_DmaRegisterCb32(server, _customServerDma32Cb));
WH_TEST_RETURN_ON_FAIL(wh_Server_DmaRegisterCb64(server, _customServerDma64Cb));

/* Register our custom allow list */
WH_TEST_RETURN_ON_FAIL(wh_Server_DmaRegisterAllowList(server, &allowList));
Expand Down
8 changes: 4 additions & 4 deletions wolfhsm/wh_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ typedef struct whServerContext_t {
crypto_context* crypto;
CacheSlot cache[WOLFHSM_NUM_RAMKEYS];
whServerCustomCb customHandlerTable[WH_CUSTOM_CB_NUM_CALLBACKS];
whDmaCb dmaCb;
const whDmaAddrAllowList* dmaAddrAllowList;
whDmaContext dma;
} whServerContext;

typedef struct whServerConfig_t {
Expand All @@ -64,7 +63,7 @@ typedef struct whServerConfig_t {
#if defined WOLF_CRYPTO_CB /* TODO: should we be relying on wolfSSL defines? */
int devId;
#endif
whDmaAddrAllowList* dmaAddrAllowList;
whDmaConfig* dmaConfig;
} whServerConfig;

/* Initialize the comms and crypto cache components.
Expand Down Expand Up @@ -93,7 +92,8 @@ int wh_Server_HandleCustomCbRequest(whServerContext* server, uint16_t magic,

/* Registers custom client DMA callbacs to handle platform specific restrictions
* on accessing the client address space such as caching and address translation */
int wh_Server_DmaRegisterCb(whServerContext* server, whDmaCb cb);
int wh_Server_DmaRegisterCb32(whServerContext* server, whDmaClientMem32Cb cb);
int wh_Server_DmaRegisterCb64(whServerContext* server, whDmaClientMem64Cb cb);
int wh_Server_DmaRegisterAllowList(whServerContext* server, const whDmaAddrAllowList* allowlist);

/* Helper functions to invoke user supplied client address DMA callbacks */
Expand Down
17 changes: 12 additions & 5 deletions wolfhsm/wh_server_dma.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@ typedef int (*whDmaClientMem64Cb)(struct whServerContext_t* server,
uint64_t len, whDmaOper oper,
whDmaFlags flags);

typedef struct {
whDmaClientMem32Cb cb32;
whDmaClientMem64Cb cb64;
} whDmaCb;

typedef struct {
void* addr;
size_t size;
Expand All @@ -49,6 +44,18 @@ typedef struct {
whDmaAddrList writeList;
} whDmaAddrAllowList;

typedef struct {
whDmaClientMem32Cb cb32; /* DMA callback for 32-bit system */
whDmaClientMem64Cb cb64; /* DMA callback for 64-bit system */
const whDmaAddrAllowList* dmaAddrAllowList; /* list of allowed addresses */
} whDmaConfig;

typedef struct {
whDmaClientMem32Cb cb32; /* DMA callback for 32-bit system */
whDmaClientMem64Cb cb64; /* DMA callback for 64-bit system */
const whDmaAddrAllowList* dmaAddrAllowList; /* list of allowed addresses */
} whDmaContext;

int whServerDma_CopyFromClient32(struct whServerContext_t* server,
void* serverPtr, uint32_t clientAddr,
size_t len, whDmaFlags flags);
Expand Down

0 comments on commit 1fd84e3

Please sign in to comment.