diff --git a/libc/socket/unix-socket.c b/libc/socket/unix-socket.c index adb11a4f..a7477004 100644 --- a/libc/socket/unix-socket.c +++ b/libc/socket/unix-socket.c @@ -6,7 +6,7 @@ * unix socket tests * * Copyright 2021, 2024 Phoenix Systems - * Author: Ziemowit Leszczynski, Adam Debek + * Author: Ziemowit Leszczynski, Adam Debek, Adam Greloch * * This file is part of Phoenix-RTOS. * @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -33,12 +34,39 @@ #include "common.h" #include "unity_fixture.h" +#define BAD_FD 33333 /* should be bad descriptor */ + +#define ASSERT_ERRNO(rv, expected) \ + do { \ + if ((rv) < 0 && errno != (expected)) \ + FAIL("errno"); \ + } while (0) + +#define CHILD_ASSERT_ERRNO(rv, expected) \ + do { \ + if ((rv) < 0 && errno != (expected)) \ + exit(1); \ + } while (0) + +#define FAIL_OR_EXIT(pid, msg) \ + do { \ + if (pid) \ + FAIL(msg); \ + else \ + exit(1); \ + } while (0) + +#define CHILD_ASSERT(pred) \ + do { \ + if (!(pred)) \ + exit(1); \ + } while (0) char data[DATA_SIZE]; char buf[DATA_SIZE]; -ssize_t unix_named_socket(int type, const char *name) +static ssize_t unix_named_socket(int type, const char *name) { int fd; struct sockaddr_un addr = { 0 }; @@ -60,7 +88,7 @@ ssize_t unix_named_socket(int type, const char *name) } -int unix_connect(int fd, const char *name) +static int connect_to_named(int fd, const char *name) { struct sockaddr_un addr = { 0 }; @@ -71,7 +99,7 @@ int unix_connect(int fd, const char *name) } -int unlink_files(size_t cnt) +static int unlink_files(size_t cnt) { size_t i; char buf[64]; @@ -86,6 +114,21 @@ int unlink_files(size_t cnt) } +static pid_t safe_fork(void) +{ + pid_t pid; + if ((pid = fork()) < 0) { + if (errno == ENOSYS) { + TEST_IGNORE_MESSAGE("fork syscall not supported"); + } + else { + FAIL("fork"); + } + } + return pid; +} + + TEST_GROUP(test_unix_socket); @@ -208,7 +251,7 @@ TEST(test_unix_socket, zero_len_send) n = sendmsg(fd[0], &msg, 0); TEST_ASSERT(n == 0); - fd[2] = 33333; /* should be bad descriptor */ + fd[2] = BAD_FD; memset(&msg, 0, sizeof(msg)); msg.msg_iov = &iov; msg.msg_iovlen = 0; @@ -370,7 +413,7 @@ TEST(test_unix_socket, close) } -void unix_msg_data_only(int type) +static void unix_msg_data_only(int type) { unsigned int i; int fd[2]; @@ -413,7 +456,7 @@ TEST(test_unix_socket, msg_data_only) } -void unix_msg_data_and_fd(int type) +static void unix_msg_data_and_fd(int type) { int i; int fd[2]; @@ -488,7 +531,7 @@ TEST(test_unix_socket, dgram_sock_data_and_fd) } -void unix_msg_fork(int type) +static void unix_msg_fork(int type) { int fd[2]; pid_t pid; @@ -499,14 +542,7 @@ void unix_msg_fork(int type) if (socketpair(AF_UNIX, type, 0, fd) < 0) FAIL("socketpair"); - if ((pid = fork()) < 0) { - if (errno == ENOSYS) { - TEST_IGNORE_MESSAGE("fork syscall not supported"); - } - else { - FAIL("fork"); - } - } + pid = safe_fork(); if (pid) { int sfd[MAX_FD_CNT]; @@ -549,13 +585,13 @@ void unix_msg_fork(int type) exit(1); if (read_files(rfd, rfdcnt, data, buf) < 0) - FAIL("read_files"); + exit(1); if (close_files(rfd, rfdcnt) < 0) exit(2); if (stat_files(rfd, rfdcnt, 0) < 0) - FAIL("stat_files"); + exit(1); exit(0); } @@ -582,7 +618,7 @@ TEST(test_unix_socket, dgram_sock_msg_fork) } -int unix_data_cmp(char *buf, size_t pos, size_t len) +static int unix_data_cmp(char *buf, size_t pos, size_t len) { size_t i; @@ -595,7 +631,7 @@ int unix_data_cmp(char *buf, size_t pos, size_t len) } -void unix_transfer(int type) +static void unix_transfer(int type) { int fd[2]; pid_t pid; @@ -606,14 +642,7 @@ void unix_transfer(int type) if (socketpair(AF_UNIX, type | SOCK_NONBLOCK, 0, fd) < 0) FAIL("socketpair"); - if ((pid = fork()) < 0) { - if (errno == ENOSYS) { - TEST_IGNORE_MESSAGE("fork syscall not supported"); - } - else { - FAIL("fork"); - } - } + pid = safe_fork(); if (pid) { size_t max_len, len, pos = 0; @@ -648,9 +677,9 @@ void unix_transfer(int type) while (tot_len > 0) { n = recv(fd[1], buf, sizeof(buf), 0); - TEST_ASSERT(n > 0 || errno == EAGAIN); + CHILD_ASSERT(n > 0 || errno == EAGAIN); if (n > 0) { - TEST_ASSERT(unix_data_cmp(buf, pos, n) == 0); + CHILD_ASSERT(unix_data_cmp(buf, pos, n) == 0); tot_len -= n; pos = (pos + n) % sizeof(data); } @@ -672,7 +701,7 @@ TEST(test_unix_socket, transfer) } -void unix_close_connected(int type) +static void unix_close_connected(int type) { int fd[2]; @@ -698,13 +727,13 @@ TEST(test_unix_socket, close_connected) volatile int got_epipe; -void sighandler(int sig) +static void sighandler(int sig) { got_epipe = 1; } -void unix_send_after_close(int type, int epipe, int err) +static void unix_send_after_close(int type, int epipe, int err) { int fd[2]; ssize_t n; @@ -747,7 +776,7 @@ TEST(test_unix_socket, send_after_close) } -void unix_recv_after_close(int type) +static void unix_recv_after_close(int type) { int fd[2]; ssize_t n; @@ -791,7 +820,7 @@ TEST(test_unix_socket, recv_after_close) } -void unix_connect_after_close(int type) +static void unix_connect_after_close(int type) { int fd[3]; int rv; @@ -804,7 +833,7 @@ void unix_connect_after_close(int type) if ((fd[2] = unix_named_socket(SOCK_DGRAM, "/tmp/test_connect_after_close")) < 0) FAIL("unix_named_socket(SOCK_DGRAM, "); - rv = unix_connect(fd[0], "/tmp/test_connect_after_close"); + rv = connect_to_named(fd[0], "/tmp/test_connect_after_close"); TEST_ASSERT(rv == -1); /* EPROTOTYPE??? */ // TEST_ASSERT(errno == EISCONN); @@ -824,7 +853,7 @@ TEST(test_unix_socket, connect_after_close) } -void unix_poll(int type) +static void unix_poll(int type) { int fd[2]; struct pollfd fds[2]; @@ -904,6 +933,430 @@ TEST(test_unix_socket, poll) } +static void read_msg(int fd, pid_t pid, int send_flags) +{ + int msg_len = 128; + int rv; + + memset(buf, 0, msg_len); + if ((rv = read(fd, buf, msg_len)) < 0) + FAIL_OR_EXIT(pid, "read"); + + if (msg_len != rv) + FAIL_OR_EXIT(pid, "msg_len != rv"); + + if (0 != strncmp(buf, data, msg_len)) + FAIL_OR_EXIT(pid, "strncmp != 0"); +} + + +static void send_msg(int fd, pid_t pid, int send_flags) +{ + int msg_len = 128; + int rv; + + if ((rv = send(fd, data, msg_len, send_flags)) < 0) + FAIL_OR_EXIT(pid, "read"); + + if (msg_len != rv) + FAIL_OR_EXIT(pid, "msg_len != rv"); +} + + +/** Note: makes sense for child processes only */ +static int connect_to_named_or_timeout(int fd, const char *name, int timeout_ms) +{ + struct timespec ts[2]; + int ms, rv; + clock_gettime(CLOCK_REALTIME, &ts[0]); + while (true) { + rv = connect_to_named(fd, name); + if (rv == 0) + break; + else { + CHILD_ASSERT_ERRNO(rv, ECONNREFUSED); + clock_gettime(CLOCK_REALTIME, &ts[1]); + ms = (ts[1].tv_sec - ts[0].tv_sec) * 1000 + (ts[1].tv_nsec - ts[0].tv_nsec) / 1000000; + if (ms > timeout_ms) + exit(1); + + usleep(150); + } + } + return rv; +} + + +static void unix_accept_connect_errnos(int type) +{ + int fd, named, rv, conn; + + const char *socket_name = "/tmp/test_accept_connect_errnos"; + + rv = connect_to_named(BAD_FD, socket_name); + TEST_ASSERT(rv < 0); + TEST_ASSERT_EQUAL_INT(EBADF, errno); + + if ((fd = socket(AF_UNIX, type, 0)) < 0) + FAIL("socket"); + + rv = connect_to_named(fd, socket_name); + TEST_ASSERT(rv < 0); + TEST_ASSERT_EQUAL_INT(ECONNREFUSED, errno); + + if ((named = unix_named_socket(type, socket_name)) < 0) + FAIL("unix_named_socket"); + + if (set_nonblock(named, 1) < 0) + FAIL("set_nonblock"); + + if (listen(named, 0) < 0) + FAIL("listen"); + + conn = accept(named, NULL, NULL); + TEST_ASSERT(conn < 0); + TEST_ASSERT_EQUAL_INT(EWOULDBLOCK, errno); + + if (set_nonblock(fd, 1) < 0) + FAIL("set_nonblock"); + + rv = connect_to_named(fd, socket_name); + TEST_ASSERT(rv < 0); + TEST_ASSERT_EQUAL_INT(EINPROGRESS, errno); + + rv = connect_to_named(fd, socket_name); + TEST_ASSERT(rv < 0); + TEST_ASSERT_EQUAL_INT(EALREADY, errno); + + close(fd); + close(named); +} + + +TEST(test_unix_socket, accept_connect_errnos) +{ + unsigned int i; + + for (i = 0; i < CONNECTED_LOOP_CNT; ++i) { + unix_accept_connect_errnos(SOCK_STREAM); + unix_accept_connect_errnos(SOCK_SEQPACKET); + } +} + + +static void unix_accept_connect_async(int type) +{ + int client_fd, server_fd, rv, conn; + struct pollfd fds[3]; + + const char *socket_name = "/tmp/test_accept_connect_async"; + + if ((server_fd = unix_named_socket(type, socket_name)) < 0) + FAIL("unix_named_socket"); + + if (set_nonblock(server_fd, 1) < 0) + FAIL("set_nonblock"); + + if (listen(server_fd, 0) < 0) + FAIL("listen"); + + if ((client_fd = socket(AF_UNIX, type, 0)) < 0) + FAIL("socket"); + + if (set_nonblock(client_fd, 1) < 0) + FAIL("set_nonblock"); + + conn = accept(server_fd, NULL, NULL); + TEST_ASSERT(conn < 0); + TEST_ASSERT_EQUAL_INT(EWOULDBLOCK, errno); + + rv = connect_to_named(client_fd, socket_name); + TEST_ASSERT(rv < 0); + TEST_ASSERT_EQUAL_INT(EINPROGRESS, errno); + + fds[0].fd = server_fd; + fds[0].events = POLLIN; + fds[1].fd = client_fd; + fds[1].events = POLLOUT; + + /* poll for incoming connection on server_fd (POLLIN) */ + TEST_ASSERT_EQUAL_INT(1, poll(fds, 2, 1000)); + TEST_ASSERT_EQUAL_INT(POLLIN, fds[0].revents); + TEST_ASSERT_EQUAL_INT(0, fds[1].revents); + + fds[2].fd = accept(server_fd, NULL, NULL); + fds[2].events = POLLIN; + TEST_ASSERT(fds[2].fd > 0); + + rv = connect_to_named(client_fd, socket_name); + TEST_ASSERT(rv < 0); + TEST_ASSERT_EQUAL_INT(EISCONN, errno); + + /* poll for connection on client_fd to be established (POLLOUT) */ + TEST_ASSERT_EQUAL_INT(1, poll(fds, 2, 1000)); + TEST_ASSERT_EQUAL_INT(0, fds[0].revents); + TEST_ASSERT_EQUAL_INT(POLLOUT, fds[1].revents); + + rv = connect_to_named(client_fd, socket_name); + TEST_ASSERT(rv < 0); + TEST_ASSERT_EQUAL_INT(EISCONN, errno); + + send_msg(client_fd, 1, 0); + + TEST_ASSERT_EQUAL_INT(2, poll(fds, 3, 1000)); + TEST_ASSERT_EQUAL_INT(0, fds[0].revents); + TEST_ASSERT_EQUAL_INT(POLLOUT, fds[1].revents); + TEST_ASSERT_EQUAL_INT(POLLIN, fds[2].revents); + + read_msg(fds[2].fd, 1, 0); + + close(fds[0].fd); + close(fds[1].fd); + close(fds[2].fd); +} + + +TEST(test_unix_socket, accept_connect_async) +{ + unsigned int i; + + for (i = 0; i < CONNECTED_LOOP_CNT; ++i) { + unix_accept_connect_async(SOCK_STREAM); + unix_accept_connect_async(SOCK_SEQPACKET); + } +} + + +static void unix_accept_connect_liveness_helper(int type) +{ + pid_t pid; + int fd, named, rv, conn, status; + struct pollfd fds[2]; + + const char *socket_name = "/tmp/test_accept_connect"; + + /* blocking connect, blocking accept */ + + pid = safe_fork(); + + if (pid) { + if ((named = unix_named_socket(type, socket_name)) < 0) + FAIL("unix_named_socket"); + + if (listen(named, 0) < 0) + FAIL("listen"); + + if ((conn = accept(named, NULL, NULL)) < 0) + FAIL("accept"); + + read_msg(conn, pid, 0); + + if (waitpid(pid, &status, 0) < 0) + FAIL("waitpid"); + + TEST_ASSERT(WIFEXITED(status)); + TEST_ASSERT_EQUAL_INT(0, WEXITSTATUS(status)); + + close(conn); + close(named); + } + else { + if ((fd = socket(AF_UNIX, type, 0)) < 0) + exit(1); + + rv = connect_to_named_or_timeout(fd, socket_name, 3000); + + send_msg(fd, pid, 0); + + close(fd); + + exit(0); + } + + /* blocking connect, nonblocking accept */ + + pid = safe_fork(); + + if (pid) { + if ((named = unix_named_socket(type, socket_name)) < 0) + FAIL("unix_named_socket"); + + if (set_nonblock(named, 1) < 0) + FAIL("set_nonblock"); + + if (listen(named, 0) < 0) + FAIL("listen"); + + fds[0].fd = named; + fds[0].events = POLLIN; + + TEST_ASSERT_EQUAL_INT(1, poll(fds, 1, 500)); + TEST_ASSERT_EQUAL_INT(POLLIN, fds[0].revents); + + conn = accept(fds[0].fd, NULL, NULL); + TEST_ASSERT(conn > 0); + + send_msg(conn, pid, 0); + + if (waitpid(pid, &status, 0) < 0) + FAIL("waitpid"); + + TEST_ASSERT(WIFEXITED(status)); + TEST_ASSERT_EQUAL_INT(0, WEXITSTATUS(status)); + + close(conn); + close(named); + } + else { + if ((fd = socket(AF_UNIX, type, 0)) < 0) + exit(1); + + rv = connect_to_named_or_timeout(fd, socket_name, 3000); + + read_msg(fd, pid, 0); + + close(fd); + + exit(0); + } + + /* nonblocking connect, blocking accept */ + + pid = safe_fork(); + + if (pid) { + if ((fd = socket(AF_UNIX, type, 0)) < 0) + FAIL("socket"); + + if (set_nonblock(fd, 1) < 0) + FAIL("set_nonblock"); + + while (true) { + rv = connect_to_named(fd, socket_name); + if (rv >= 0) { + FAIL("should never happen - child proc should sleep for longer"); + } + else if (rv < 0 && errno == EINPROGRESS) { + break; + } + else { + ASSERT_ERRNO(rv, ECONNREFUSED); + usleep(50); + } + } + + fds[0].fd = fd; + fds[0].events = POLLOUT; + + TEST_ASSERT_EQUAL_INT(1, poll(fds, 1, 700)); + TEST_ASSERT_EQUAL_INT(POLLOUT, fds[0].revents); + TEST_ASSERT_EQUAL_INT(0, getsockopt(fds[0].fd, SOL_SOCKET, SO_ERROR, NULL, NULL)); + + fds[0].events = POLLIN; + TEST_ASSERT_EQUAL_INT(1, poll(fds, 1, 250)); + TEST_ASSERT_EQUAL_INT(POLLIN, fds[0].revents); + + read_msg(fds[0].fd, pid, 0); + + fds[0].events = POLLOUT; + TEST_ASSERT_EQUAL_INT(1, poll(fds, 1, 250)); + TEST_ASSERT_EQUAL_INT(POLLOUT, fds[0].revents); + send_msg(fds[0].fd, pid, 0); + + if (waitpid(pid, &status, 0) < 0) + FAIL("waitpid"); + + TEST_ASSERT(WIFEXITED(status)); + TEST_ASSERT_EQUAL_INT(0, WEXITSTATUS(status)); + + close(fds[0].fd); + } + else { + if ((named = unix_named_socket(type, socket_name)) < 0) + exit(1); + + if (listen(named, 0) < 0) + exit(1); + + usleep(500); /* sleep so that connect would block */ + + if ((conn = accept(named, NULL, NULL)) < 0) + exit(1); + + send_msg(conn, pid, 0); + + /* read something from parent so that parent does the first + * POLLOUT before conn gets closed */ + read_msg(conn, pid, 0); + close(conn); + + close(named); + + exit(0); + } +} + + +static void unix_accept_connect_liveness(int type) +{ + unsigned int i = 0; + + for (i = 0; i < 25; ++i) { + unix_accept_connect_liveness_helper(type); + } +} + + +TEST(test_unix_socket, accept_connect_liveness) +{ + unix_accept_connect_liveness(SOCK_STREAM); + unix_accept_connect_liveness(SOCK_SEQPACKET); +} + + +static void unix_socket_recv_msg_peek(int flags) +{ + int fd[2]; + int msg_len = 128; + ssize_t n; + + if (socketpair(AF_UNIX, SOCK_STREAM, 0, fd) < 0) + FAIL("socketpair"); + + n = write(fd[0], data, msg_len); + TEST_ASSERT(n == msg_len); + + n = write(fd[0], data, 1); + TEST_ASSERT(n == 1); + + /** Peek on first 2 iterations, on 3th do a normal read */ + for (int i = 0; i < 4; i++) { + n = recv(fd[1], buf, msg_len, flags | (i < 2 ? MSG_PEEK : 0)); + if (i < 3) { + /* Should read the same message 3 times */ + TEST_ASSERT(n == msg_len); + TEST_ASSERT(strncmp(buf, data, msg_len) == 0); + } + else + /* Should read one byte as previous message was normally read on 3rd + * iteration */ + TEST_ASSERT(n == 1); + } + + close(fd[0]); + close(fd[1]); +} + + +TEST(test_unix_socket, recv_msg_peek) +{ + unix_socket_recv_msg_peek(0); + unix_socket_recv_msg_peek(MSG_DONTWAIT); +} + + +// TODO: add listen() backlog test when implemented + TEST_GROUP_RUNNER(test_unix_socket) { RUN_TEST_CASE(test_unix_socket, zero_len_send); @@ -920,6 +1373,10 @@ TEST_GROUP_RUNNER(test_unix_socket) RUN_TEST_CASE(test_unix_socket, recv_after_close); RUN_TEST_CASE(test_unix_socket, connect_after_close); RUN_TEST_CASE(test_unix_socket, poll); + RUN_TEST_CASE(test_unix_socket, recv_msg_peek); + RUN_TEST_CASE(test_unix_socket, accept_connect_errnos); + RUN_TEST_CASE(test_unix_socket, accept_connect_async); + RUN_TEST_CASE(test_unix_socket, accept_connect_liveness); } void runner(void)