diff options
Diffstat (limited to 'sys/kern/kern_socket.c')
-rw-r--r-- | sys/kern/kern_socket.c | 313 |
1 files changed, 313 insertions, 0 deletions
diff --git a/sys/kern/kern_socket.c b/sys/kern/kern_socket.c index 7f58921..8be5031 100644 --- a/sys/kern/kern_socket.c +++ b/sys/kern/kern_socket.c @@ -323,6 +323,7 @@ int bind(int sockfd, const struct sockaddr *addr, socklen_t len) { struct ksocket *ksock; + struct cmsg_list *clp; int error; if ((error = get_ksock(sockfd, &ksock)) < 0) { @@ -335,9 +336,321 @@ bind(int sockfd, const struct sockaddr *addr, socklen_t len) if (ksock->mtx == NULL) { return -ENOMEM; } + + /* Initialize the cmsg list queue */ + clp = &ksock->cmsg_list; + TAILQ_INIT(&clp->list); + clp->is_init = 1; + return 0; +} + +/* + * Send socket control message - POSIX.1-2008 + * + * @socket: Socket to transmit on + * @msg: Further arguments + * @flags: Optional flags + * + * Returns zero on success, otherwise a less + * than zero errno. + */ +ssize_t +sendmsg(int socket, const struct msghdr *msg, int flags) +{ + struct ksocket *ksock; + struct cmsg *cmsg; + struct sockaddr_un *un; + struct cmsg_list *clp; + size_t control_len = 0; + int error; + + if ((error = get_ksock(socket, &ksock)) < 0) { + return error; + } + + /* We cannot do sendmsg() non domain sockets */ + un = &ksock->un; + if (un->sun_family != AF_UNIX) { + return -EBADF; + } + + control_len = MALIGN(msg->msg_controllen); + + /* Allocate a new cmsg */ + cmsg = dynalloc(control_len + sizeof(struct cmsg)); + if (cmsg == NULL) { + return -EINVAL; + } + + memcpy(cmsg->buf, msg->msg_control, control_len); + clp = &ksock->cmsg_list; + cmsg->control_len = control_len; + TAILQ_INSERT_TAIL(&clp->list, cmsg, link); return 0; } +/* + * Receive socket control message - POSIX.1‐2017 + * + * @socket: Socket to receive on + * @msg: Further arguments + * @flags: Optional flags + * + * Returns zero on success, otherwise a less + * than zero errno. + */ +ssize_t +recvmsg(int socket, struct msghdr *msg, int flags) +{ + struct ksocket *ksock; + struct sockaddr_un *un; + struct cmsg *cmsg, *tmp; + struct cmsghdr *cmsghdr; + struct cmsg_list *clp; + uint8_t *fds; + int error; + + if (socket < 0) { + return -EINVAL; + } + + /* Grab the socket descriptor */ + if ((error = get_ksock(socket, &ksock)) < 0) { + return error; + } + + /* Must be a unix domain socket */ + un = &ksock->un; + if (un->sun_family != AF_UNIX) { + return -EBADF; + } + + /* Grab the control message list */ + clp = &ksock->cmsg_list; + cmsg = TAILQ_FIRST(&clp->list); + + while (cmsg != NULL) { + cmsghdr = &cmsg->hdr; + + /* Check the control message type */ + switch (cmsghdr->cmsg_type) { + case SCM_RIGHTS: + { + fds = (uint8_t *)CMSG_DATA(cmsghdr); + pr_trace("SCM_RIGHTS -> fd %d\n", fds[0]); + break; + } + } + + tmp = cmsg; + cmsg = TAILQ_NEXT(cmsg, link); + + TAILQ_REMOVE(&clp->list, tmp, link); + dynfree(tmp); + } + + return 0; +} + +/* + * socket(7) syscall + * + * arg0: domain + * arg1: type + * arg2: protocol + */ +scret_t +sys_socket(struct syscall_args *scargs) +{ + int domain = scargs->arg0; + int type = scargs->arg1; + int protocol = scargs->arg2; + + return socket(domain, type, protocol); +} + +/* + * bind(2) syscall + * + * arg0: sockfd + * arg1: addr + * arg2: len + */ +scret_t +sys_bind(struct syscall_args *scargs) +{ + const struct sockaddr *u_addr = (void *)scargs->arg1; + struct sockaddr addr_copy; + int sockfd = scargs->arg0; + int len = scargs->arg2; + int error; + + error = copyin(u_addr, &addr_copy, sizeof(addr_copy)); + if (error < 0) { + return error; + } + + return bind(sockfd, &addr_copy, len); +} + +/* + * recv(2) syscall + * + * arg0: sockfd + * arg1: buf + * arg2: size + * arg3: flags + */ +scret_t +sys_recv(struct syscall_args *scargs) +{ + char buf[NETBUF_LEN]; + void *u_buf = (void *)scargs->arg1; + int sockfd = scargs->arg0; + size_t len = scargs->arg2; + int error, flags = scargs->arg3; + + if (len > sizeof(buf)) { + return -ENOBUFS; + } + + error = recv(sockfd, buf, len, flags); + if (error < 0) { + pr_error("sys_recv: recv() fail (fd=%d)\n", sockfd); + return error; + } + + error = copyout(buf, u_buf, len); + return (error == 0) ? len : error; +} + +/* + * send(2) syscall + * + * arg0: sockfd + * arg1: buf + * arg2: size + * arg3: flags + */ +scret_t +sys_send(struct syscall_args *scargs) +{ + char buf[NETBUF_LEN]; + const void *u_buf = (void *)scargs->arg1; + int sockfd = scargs->arg0; + size_t len = scargs->arg2; + int error, flags = scargs->arg3; + + if (len > sizeof(buf)) { + return -ENOBUFS; + } + + error = copyin(u_buf, buf, len); + if (error < 0) { + pr_error("sys_send: copyin() failure (fd=%d)\n", sockfd); + return error; + } + + return send(sockfd, buf, len, flags); +} + +/* + * recvmsg(3) syscall + * + * arg0: socket + * arg1: msg + * arg2: flags + */ +scret_t +sys_recvmsg(struct syscall_args *scargs) +{ + struct msghdr *u_msg = (void *)scargs->arg1; + void *u_control; + size_t controllen; + struct iovec msg_iov; + struct msghdr msg; + ssize_t retval; + int socket = scargs->arg0; + int flags = scargs->arg2; + int error; + + /* Read the message header */ + error = copyin(u_msg, &msg, sizeof(msg)); + if (error < 0) { + pr_error("sys_recvmsg: bad msg\n"); + return error; + } + + /* Grab the message I/O vector */ + error = uio_copyin(msg.msg_iov, &msg_iov, msg.msg_iovlen); + if (error < 0) { + return error; + } + + /* Save control fields */ + u_control = msg.msg_control; + controllen = msg.msg_controllen; + + /* Allocate a new control field to copy in */ + msg.msg_control = dynalloc(controllen); + if (msg.msg_control == NULL) { + uio_copyin_clean(&msg_iov, msg.msg_iovlen); + return -ENOMEM; + } + + error = copyin(u_control, msg.msg_control, controllen); + if (error < 0) { + retval = error; + goto done; + } + + msg.msg_iov = &msg_iov; + retval = recvmsg(socket, &msg, flags); +done: + uio_copyin_clean(&msg_iov, msg.msg_iovlen); + dynfree(msg.msg_control); + return retval; +} + +/* + * sendmsg(3) syscall + * + * arg0: socket + * arg1: msg + * arg2: flags + */ +scret_t +sys_sendmsg(struct syscall_args *scargs) +{ + struct iovec msg_iov; + struct msghdr *u_msg = (void *)scargs->arg1; + struct msghdr msg; + ssize_t retval; + int socket = scargs->arg0; + int flags = scargs->arg2; + int error; + + /* Read the message header */ + error = copyin(u_msg, &msg, sizeof(msg)); + if (error < 0) { + pr_error("sys_sendmsg: bad msg\n"); + return error; + } + + /* Grab the message I/O vector */ + error = uio_copyin(msg.msg_iov, &msg_iov, msg.msg_iovlen); + if (error < 0) { + return error; + } + + msg.msg_name = NULL; + msg.msg_namelen = 0; + msg.msg_iov = &msg_iov; + retval = sendmsg(socket, &msg, flags); + uio_copyin_clean(&msg_iov, msg.msg_iovlen); + return retval; +} + static struct vops socket_vops = { .read = NULL, .write = NULL, |