diff --git a/src/wh_server_dma.c b/src/wh_server_dma.c index 9cdc2263..e6f1b6e1 100644 --- a/src/wh_server_dma.c +++ b/src/wh_server_dma.c @@ -6,6 +6,15 @@ * efficient representation (requiring sorted array and binary search, or * building a binary tree, etc.) */ +static int _checkOperValid(whServerDmaOper oper) +{ + if (oper < WH_DMA_OPER_CLIENT_READ_PRE || + oper > WH_DMA_OPER_CLIENT_WRITE_POST) { + return WH_ERROR_BADARGS; + } + + return WH_ERROR_OK; +} static int _checkAddrAgainstAllowList(const whServerDmaAddrList allowList, void* addr, size_t size) @@ -34,12 +43,22 @@ static int _checkAddrAgainstAllowList(const whServerDmaAddrList allowList, void* return WH_ERROR_ACCESS; } -static int _checkMemOperAgainstAllowList(const whServerDmaAddrAllowList* allowList, +static int _checkMemOperAgainstAllowList(const whServerContext* server, whServerDmaOper oper, void* addr, size_t size) { int rc = WH_ERROR_OK; + rc = _checkOperValid(oper); + if (rc != WH_ERROR_OK) { + return rc; + } + + /* If no allowlist is registered, anything goes */ + if (server->dma.dmaAddrAllowList == NULL) { + return WH_ERROR_OK; + } + /* If a read/write operation is requested, check the transformed address * against the appropriate allowlist * @@ -47,15 +66,26 @@ static int _checkMemOperAgainstAllowList(const whServerDmaAddrAllowList* allowLi * memory operations for some reason? */ if (oper == WH_DMA_OPER_CLIENT_READ_PRE) { - rc = _checkAddrAgainstAllowList(allowList->readList, addr, size); + rc = _checkAddrAgainstAllowList(server->dma.dmaAddrAllowList->readList, addr, size); } else if (oper == WH_DMA_OPER_CLIENT_WRITE_PRE) { - rc = _checkAddrAgainstAllowList(allowList->writeList, addr, size); + rc = _checkAddrAgainstAllowList(server->dma.dmaAddrAllowList->writeList, addr, size); } return rc; } +int wh_Server_DmaCheckMemOperAllowed(const whServerContext* server, + whServerDmaOper oper, void* addr, size_t size) +{ + /* NULL addr is allowed here, since 0 is a valid address */ + if (NULL == server || 0 == size) { + return WH_ERROR_BADARGS; + } + + return _checkMemOperAgainstAllowList(server, oper, addr, size); +} + int wh_Server_DmaRegisterCb32(whServerContext* server, whServerDmaClientMem32Cb cb) { if (NULL == server || NULL == cb) { @@ -110,19 +140,16 @@ int wh_Server_DmaProcessClientAddress32(whServerContext* server, if (NULL != server->dma.cb32) { rc = server->dma.cb32(server, clientAddr, xformedCliAddr, len, oper, flags); + if (rc != WH_ERROR_OK) { + return rc; + } } /* if the server has a allowlist registered, check transformed address * against it */ - if (server->dma.dmaAddrAllowList != NULL) { - rc = _checkMemOperAgainstAllowList(server->dma.dmaAddrAllowList, oper, - *xformedCliAddr, len); - } - - return rc; + return _checkMemOperAgainstAllowList(server, oper, *xformedCliAddr, len); } - int wh_Server_DmaProcessClientAddress64(whServerContext* server, uint64_t clientAddr, void** xformedCliAddr, uint64_t len, @@ -140,16 +167,14 @@ int wh_Server_DmaProcessClientAddress64(whServerContext* server, /* Perform user-supplied address transformation, cache manipulation, etc */ if (NULL != server->dma.cb64) { rc = server->dma.cb64(server, clientAddr, xformedCliAddr, len, oper, - flags); + flags); + if (rc != WH_ERROR_OK) { + return rc; + } } /* if the server has a allowlist registered, check address against it */ - if (server->dma.dmaAddrAllowList != NULL) { - rc = _checkMemOperAgainstAllowList(server->dma.dmaAddrAllowList, oper, - *xformedCliAddr, len); - } - - return rc; + return _checkMemOperAgainstAllowList(server, oper, *xformedCliAddr, len); } @@ -167,8 +192,8 @@ int whServerDma_CopyFromClient32(struct whServerContext_t* server, } /* Check the server address against the allow list */ - rc = _checkMemOperAgainstAllowList( - server->dma.dmaAddrAllowList, WH_DMA_OPER_CLIENT_READ_PRE, serverPtr, len); + rc = _checkMemOperAgainstAllowList(server, WH_DMA_OPER_CLIENT_READ_PRE, + serverPtr, len); if (rc != WH_ERROR_OK) { return rc; } @@ -208,8 +233,8 @@ int whServerDma_CopyFromClient64(struct whServerContext_t* server, } /* Check the server address against the allow list */ - rc = _checkMemOperAgainstAllowList( - server->dma.dmaAddrAllowList, WH_DMA_OPER_CLIENT_READ_PRE, serverPtr, len); + rc = _checkMemOperAgainstAllowList(server, WH_DMA_OPER_CLIENT_READ_PRE, + serverPtr, len); if (rc != WH_ERROR_OK) { return rc; } @@ -248,8 +273,8 @@ int whServerDma_CopyToClient32(struct whServerContext_t* server, } /* Check the server address against the allow list */ - rc = _checkMemOperAgainstAllowList( - server->dma.dmaAddrAllowList, WH_DMA_OPER_CLIENT_WRITE_PRE, serverPtr, len); + rc = _checkMemOperAgainstAllowList(server, WH_DMA_OPER_CLIENT_WRITE_PRE, + serverPtr, len); if (rc != WH_ERROR_OK) { return rc; } @@ -289,8 +314,8 @@ int whServerDma_CopyToClient64(struct whServerContext_t* server, } /* Check the server address against the allow list */ - rc = _checkMemOperAgainstAllowList( - server->dma.dmaAddrAllowList, WH_DMA_OPER_CLIENT_WRITE_PRE, serverPtr, len); + rc = _checkMemOperAgainstAllowList(server, WH_DMA_OPER_CLIENT_WRITE_PRE, + serverPtr, len); if (rc != WH_ERROR_OK) { return rc; } diff --git a/test/wh_test_clientserver.c b/test/wh_test_clientserver.c index 66f975d5..0d98fb98 100644 --- a/test/wh_test_clientserver.c +++ b/test/wh_test_clientserver.c @@ -235,6 +235,46 @@ static int _testDma(whServerContext* server, whClientContext* client) /* Register our custom allow list */ WH_TEST_RETURN_ON_FAIL(wh_Server_DmaRegisterAllowList(server, &allowList)); + /* Check allowed operations for addresses in the allowlist */ + WH_TEST_ASSERT_RETURN( + WH_ERROR_OK == wh_Server_DmaCheckMemOperAllowed( + server, WH_DMA_OPER_CLIENT_READ_PRE, + testMem.srvBufAllow, sizeof(testMem.srvBufAllow))); + WH_TEST_ASSERT_RETURN( + WH_ERROR_OK == + wh_Server_DmaCheckMemOperAllowed(server, WH_DMA_OPER_CLIENT_READ_PRE, + testMem.srvRemapBufAllow, + sizeof(testMem.srvRemapBufAllow))); + WH_TEST_ASSERT_RETURN( + WH_ERROR_OK == wh_Server_DmaCheckMemOperAllowed( + server, WH_DMA_OPER_CLIENT_WRITE_PRE, + testMem.srvBufAllow, sizeof(testMem.srvBufAllow))); + WH_TEST_ASSERT_RETURN( + WH_ERROR_OK == + wh_Server_DmaCheckMemOperAllowed(server, WH_DMA_OPER_CLIENT_WRITE_PRE, + testMem.srvRemapBufAllow, + sizeof(testMem.srvRemapBufAllow))); + + /* Ensure an address not in the allowlist is denied for all operations */ + WH_TEST_ASSERT_RETURN(WH_ERROR_ACCESS == + wh_Server_DmaCheckMemOperAllowed( + server, WH_DMA_OPER_CLIENT_READ_PRE, + testMem.srvBufDeny, sizeof(testMem.srvBufDeny))); + WH_TEST_ASSERT_RETURN(WH_ERROR_ACCESS == + wh_Server_DmaCheckMemOperAllowed( + server, WH_DMA_OPER_CLIENT_WRITE_PRE, + testMem.srvBufDeny, sizeof(testMem.srvBufDeny))); + + /* Zero sized operations should be denied */ + WH_TEST_ASSERT_RETURN( + WH_ERROR_BADARGS == + wh_Server_DmaCheckMemOperAllowed(server, WH_DMA_OPER_CLIENT_READ_PRE, + testMem.srvBufAllow, 0)); + WH_TEST_ASSERT_RETURN( + WH_ERROR_BADARGS == + wh_Server_DmaCheckMemOperAllowed(server, WH_DMA_OPER_CLIENT_WRITE_PRE, + testMem.srvBufAllow, 0)); + /* Set known pattern in client Buffer */ memset(testMem.cliBuf, TEST_MEM_CLI_BYTE, sizeof(testMem.cliBuf)); @@ -314,6 +354,32 @@ static int _testDma(whServerContext* server, whClientContext* client) sizeof(testMem.srvBufDeny), (whServerDmaFlags){0})); } + /* Check that zero-sized copies fail, even from allowed addresses */ + if (sizeof(void*) == sizeof(uint64_t)) { + /* 64-bit host system */ + WH_TEST_ASSERT_RETURN(WH_ERROR_BADARGS == + whServerDma_CopyFromClient64( + server, testMem.srvBufAllow, + (uint64_t)((uintptr_t)testMem.cliBuf), 0, + (whServerDmaFlags){0})); + WH_TEST_ASSERT_RETURN(WH_ERROR_BADARGS == + whServerDma_CopyToClient64( + server, (uint64_t)((uintptr_t)testMem.cliBuf), + testMem.srvBufAllow, 0, (whServerDmaFlags){0})); + } + else if (sizeof(void*) == sizeof(uint32_t)) { + /* 32-bit host system */ + WH_TEST_ASSERT_RETURN(WH_ERROR_BADARGS == + whServerDma_CopyFromClient32( + server, testMem.srvBufAllow, + (uint32_t)((uintptr_t)testMem.cliBuf), 0, + (whServerDmaFlags){0})); + WH_TEST_ASSERT_RETURN(WH_ERROR_BADARGS == + whServerDma_CopyToClient32( + server, (uint32_t)((uintptr_t)testMem.cliBuf), + testMem.srvBufAllow, 0, (whServerDmaFlags){0})); + } + return rc; } diff --git a/wolfhsm/wh_server_dma.h b/wolfhsm/wh_server_dma.h index 81c70924..19d9271b 100644 --- a/wolfhsm/wh_server_dma.h +++ b/wolfhsm/wh_server_dma.h @@ -79,6 +79,11 @@ int wh_Server_DmaRegisterCb64(struct whServerContext_t* server, int wh_Server_DmaRegisterAllowList(struct whServerContext_t* server, const whServerDmaAddrAllowList* allowlist); +/* Checks a desired memory operation against the server allowlist */ +int wh_Server_DmaCheckMemOperAllowed(const struct whServerContext_t* server, + whServerDmaOper oper, void* addr, + size_t size); + /* Helper functions to invoke user supplied client address DMA callbacks */ int wh_Server_DmaProcessClientAddress32(struct whServerContext_t* server, uint32_t clientAddr, void** serverPtr, @@ -89,6 +94,8 @@ int wh_Server_DmaProcessClientAddress64(struct whServerContext_t* server, uint64_t len, whServerDmaOper oper, whServerDmaFlags flags); +/* Helper functions to copy data to/from client addresses that invoke the + * appropriate callbacks and allowlist checks */ int whServerDma_CopyFromClient32(struct whServerContext_t* server, void* serverPtr, uint32_t clientAddr, size_t len, whServerDmaFlags flags);