Skip to content

Commit

Permalink
Review comments:
Browse files Browse the repository at this point in the history
- Exposed allowlist check as external API
- Added DMA operation check for invalid operations
- Added size==0 check as invalid
- Added missing return codes
- Expanded testing for new API and edge cases
  • Loading branch information
bigbrett committed Apr 16, 2024
1 parent 144d1ca commit cd0c4db
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 25 deletions.
75 changes: 50 additions & 25 deletions src/wh_server_dma.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@
* efficient representation (requiring sorted array and binary search, or
* building a binary tree, etc.) */

static int _checkOperValid(whServerDmaOper oper)
{
if (oper < WH_DMA_OPER_CLIENT_READ_PRE ||
oper > WH_DMA_OPER_CLIENT_WRITE_POST) {
return WH_ERROR_BADARGS;
}

return WH_ERROR_OK;
}

static int _checkAddrAgainstAllowList(const whServerDmaAddrList allowList, void* addr,
size_t size)
Expand Down Expand Up @@ -34,28 +43,49 @@ static int _checkAddrAgainstAllowList(const whServerDmaAddrList allowList, void*
return WH_ERROR_ACCESS;
}

static int _checkMemOperAgainstAllowList(const whServerDmaAddrAllowList* allowList,
static int _checkMemOperAgainstAllowList(const whServerContext* server,
whServerDmaOper oper, void* addr,
size_t size)
{
int rc = WH_ERROR_OK;

rc = _checkOperValid(oper);
if (rc != WH_ERROR_OK) {
return rc;
}

/* If no allowlist is registered, anything goes */
if (server->dma.dmaAddrAllowList == NULL) {
return WH_ERROR_OK;
}

/* If a read/write operation is requested, check the transformed address
* against the appropriate allowlist
*
* TODO: do we need to allowlist check on POST in case there are subsequent
* memory operations for some reason?
*/
if (oper == WH_DMA_OPER_CLIENT_READ_PRE) {
rc = _checkAddrAgainstAllowList(allowList->readList, addr, size);
rc = _checkAddrAgainstAllowList(server->dma.dmaAddrAllowList->readList, addr, size);
}
else if (oper == WH_DMA_OPER_CLIENT_WRITE_PRE) {
rc = _checkAddrAgainstAllowList(allowList->writeList, addr, size);
rc = _checkAddrAgainstAllowList(server->dma.dmaAddrAllowList->writeList, addr, size);
}

return rc;
}

int wh_Server_DmaCheckMemOperAllowed(const whServerContext* server,
whServerDmaOper oper, void* addr, size_t size)
{
/* NULL addr is allowed here, since 0 is a valid address */
if (NULL == server || 0 == size) {
return WH_ERROR_BADARGS;
}

return _checkMemOperAgainstAllowList(server, oper, addr, size);
}

int wh_Server_DmaRegisterCb32(whServerContext* server, whServerDmaClientMem32Cb cb)
{
if (NULL == server || NULL == cb) {
Expand Down Expand Up @@ -110,19 +140,16 @@ int wh_Server_DmaProcessClientAddress32(whServerContext* server,
if (NULL != server->dma.cb32) {
rc = server->dma.cb32(server, clientAddr, xformedCliAddr, len, oper,
flags);
if (rc != WH_ERROR_OK) {
return rc;
}
}

/* if the server has a allowlist registered, check transformed address
* against it */
if (server->dma.dmaAddrAllowList != NULL) {
rc = _checkMemOperAgainstAllowList(server->dma.dmaAddrAllowList, oper,
*xformedCliAddr, len);
}

return rc;
return _checkMemOperAgainstAllowList(server, oper, *xformedCliAddr, len);
}


int wh_Server_DmaProcessClientAddress64(whServerContext* server,
uint64_t clientAddr,
void** xformedCliAddr, uint64_t len,
Expand All @@ -140,16 +167,14 @@ int wh_Server_DmaProcessClientAddress64(whServerContext* server,
/* Perform user-supplied address transformation, cache manipulation, etc */
if (NULL != server->dma.cb64) {
rc = server->dma.cb64(server, clientAddr, xformedCliAddr, len, oper,
flags);
flags);
if (rc != WH_ERROR_OK) {
return rc;
}
}

/* if the server has a allowlist registered, check address against it */
if (server->dma.dmaAddrAllowList != NULL) {
rc = _checkMemOperAgainstAllowList(server->dma.dmaAddrAllowList, oper,
*xformedCliAddr, len);
}

return rc;
return _checkMemOperAgainstAllowList(server, oper, *xformedCliAddr, len);
}


Expand All @@ -167,8 +192,8 @@ int whServerDma_CopyFromClient32(struct whServerContext_t* server,
}

/* Check the server address against the allow list */
rc = _checkMemOperAgainstAllowList(
server->dma.dmaAddrAllowList, WH_DMA_OPER_CLIENT_READ_PRE, serverPtr, len);
rc = _checkMemOperAgainstAllowList(server, WH_DMA_OPER_CLIENT_READ_PRE,
serverPtr, len);
if (rc != WH_ERROR_OK) {
return rc;
}
Expand Down Expand Up @@ -208,8 +233,8 @@ int whServerDma_CopyFromClient64(struct whServerContext_t* server,
}

/* Check the server address against the allow list */
rc = _checkMemOperAgainstAllowList(
server->dma.dmaAddrAllowList, WH_DMA_OPER_CLIENT_READ_PRE, serverPtr, len);
rc = _checkMemOperAgainstAllowList(server, WH_DMA_OPER_CLIENT_READ_PRE,
serverPtr, len);
if (rc != WH_ERROR_OK) {
return rc;
}
Expand Down Expand Up @@ -248,8 +273,8 @@ int whServerDma_CopyToClient32(struct whServerContext_t* server,
}

/* Check the server address against the allow list */
rc = _checkMemOperAgainstAllowList(
server->dma.dmaAddrAllowList, WH_DMA_OPER_CLIENT_WRITE_PRE, serverPtr, len);
rc = _checkMemOperAgainstAllowList(server, WH_DMA_OPER_CLIENT_WRITE_PRE,
serverPtr, len);
if (rc != WH_ERROR_OK) {
return rc;
}
Expand Down Expand Up @@ -289,8 +314,8 @@ int whServerDma_CopyToClient64(struct whServerContext_t* server,
}

/* Check the server address against the allow list */
rc = _checkMemOperAgainstAllowList(
server->dma.dmaAddrAllowList, WH_DMA_OPER_CLIENT_WRITE_PRE, serverPtr, len);
rc = _checkMemOperAgainstAllowList(server, WH_DMA_OPER_CLIENT_WRITE_PRE,
serverPtr, len);
if (rc != WH_ERROR_OK) {
return rc;
}
Expand Down
66 changes: 66 additions & 0 deletions test/wh_test_clientserver.c
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,46 @@ static int _testDma(whServerContext* server, whClientContext* client)
/* Register our custom allow list */
WH_TEST_RETURN_ON_FAIL(wh_Server_DmaRegisterAllowList(server, &allowList));

/* Check allowed operations for addresses in the allowlist */
WH_TEST_ASSERT_RETURN(
WH_ERROR_OK == wh_Server_DmaCheckMemOperAllowed(
server, WH_DMA_OPER_CLIENT_READ_PRE,
testMem.srvBufAllow, sizeof(testMem.srvBufAllow)));
WH_TEST_ASSERT_RETURN(
WH_ERROR_OK ==
wh_Server_DmaCheckMemOperAllowed(server, WH_DMA_OPER_CLIENT_READ_PRE,
testMem.srvRemapBufAllow,
sizeof(testMem.srvRemapBufAllow)));
WH_TEST_ASSERT_RETURN(
WH_ERROR_OK == wh_Server_DmaCheckMemOperAllowed(
server, WH_DMA_OPER_CLIENT_WRITE_PRE,
testMem.srvBufAllow, sizeof(testMem.srvBufAllow)));
WH_TEST_ASSERT_RETURN(
WH_ERROR_OK ==
wh_Server_DmaCheckMemOperAllowed(server, WH_DMA_OPER_CLIENT_WRITE_PRE,
testMem.srvRemapBufAllow,
sizeof(testMem.srvRemapBufAllow)));

/* Ensure an address not in the allowlist is denied for all operations */
WH_TEST_ASSERT_RETURN(WH_ERROR_ACCESS ==
wh_Server_DmaCheckMemOperAllowed(
server, WH_DMA_OPER_CLIENT_READ_PRE,
testMem.srvBufDeny, sizeof(testMem.srvBufDeny)));
WH_TEST_ASSERT_RETURN(WH_ERROR_ACCESS ==
wh_Server_DmaCheckMemOperAllowed(
server, WH_DMA_OPER_CLIENT_WRITE_PRE,
testMem.srvBufDeny, sizeof(testMem.srvBufDeny)));

/* Zero sized operations should be denied */
WH_TEST_ASSERT_RETURN(
WH_ERROR_BADARGS ==
wh_Server_DmaCheckMemOperAllowed(server, WH_DMA_OPER_CLIENT_READ_PRE,
testMem.srvBufAllow, 0));
WH_TEST_ASSERT_RETURN(
WH_ERROR_BADARGS ==
wh_Server_DmaCheckMemOperAllowed(server, WH_DMA_OPER_CLIENT_WRITE_PRE,
testMem.srvBufAllow, 0));

/* Set known pattern in client Buffer */
memset(testMem.cliBuf, TEST_MEM_CLI_BYTE, sizeof(testMem.cliBuf));

Expand Down Expand Up @@ -314,6 +354,32 @@ static int _testDma(whServerContext* server, whClientContext* client)
sizeof(testMem.srvBufDeny), (whServerDmaFlags){0}));
}

/* Check that zero-sized copies fail, even from allowed addresses */
if (sizeof(void*) == sizeof(uint64_t)) {
/* 64-bit host system */
WH_TEST_ASSERT_RETURN(WH_ERROR_BADARGS ==
whServerDma_CopyFromClient64(
server, testMem.srvBufAllow,
(uint64_t)((uintptr_t)testMem.cliBuf), 0,
(whServerDmaFlags){0}));
WH_TEST_ASSERT_RETURN(WH_ERROR_BADARGS ==
whServerDma_CopyToClient64(
server, (uint64_t)((uintptr_t)testMem.cliBuf),
testMem.srvBufAllow, 0, (whServerDmaFlags){0}));
}
else if (sizeof(void*) == sizeof(uint32_t)) {
/* 32-bit host system */
WH_TEST_ASSERT_RETURN(WH_ERROR_BADARGS ==
whServerDma_CopyFromClient32(
server, testMem.srvBufAllow,
(uint32_t)((uintptr_t)testMem.cliBuf), 0,
(whServerDmaFlags){0}));
WH_TEST_ASSERT_RETURN(WH_ERROR_BADARGS ==
whServerDma_CopyToClient32(
server, (uint32_t)((uintptr_t)testMem.cliBuf),
testMem.srvBufAllow, 0, (whServerDmaFlags){0}));
}

return rc;
}

Expand Down
7 changes: 7 additions & 0 deletions wolfhsm/wh_server_dma.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ int wh_Server_DmaRegisterCb64(struct whServerContext_t* server,
int wh_Server_DmaRegisterAllowList(struct whServerContext_t* server,
const whServerDmaAddrAllowList* allowlist);

/* Checks a desired memory operation against the server allowlist */
int wh_Server_DmaCheckMemOperAllowed(const struct whServerContext_t* server,
whServerDmaOper oper, void* addr,
size_t size);

/* Helper functions to invoke user supplied client address DMA callbacks */
int wh_Server_DmaProcessClientAddress32(struct whServerContext_t* server,
uint32_t clientAddr, void** serverPtr,
Expand All @@ -89,6 +94,8 @@ int wh_Server_DmaProcessClientAddress64(struct whServerContext_t* server,
uint64_t len, whServerDmaOper oper,
whServerDmaFlags flags);

/* Helper functions to copy data to/from client addresses that invoke the
* appropriate callbacks and allowlist checks */
int whServerDma_CopyFromClient32(struct whServerContext_t* server,
void* serverPtr, uint32_t clientAddr,
size_t len, whServerDmaFlags flags);
Expand Down

0 comments on commit cd0c4db

Please sign in to comment.