summaryrefslogtreecommitdiff
path: root/sys/kern/kern_socket.c
diff options
context:
space:
mode:
Diffstat (limited to 'sys/kern/kern_socket.c')
-rw-r--r--sys/kern/kern_socket.c313
1 files changed, 313 insertions, 0 deletions
diff --git a/sys/kern/kern_socket.c b/sys/kern/kern_socket.c
index 7f58921..8be5031 100644
--- a/sys/kern/kern_socket.c
+++ b/sys/kern/kern_socket.c
@@ -323,6 +323,7 @@ int
bind(int sockfd, const struct sockaddr *addr, socklen_t len)
{
struct ksocket *ksock;
+ struct cmsg_list *clp;
int error;
if ((error = get_ksock(sockfd, &ksock)) < 0) {
@@ -335,9 +336,321 @@ bind(int sockfd, const struct sockaddr *addr, socklen_t len)
if (ksock->mtx == NULL) {
return -ENOMEM;
}
+
+ /* Initialize the cmsg list queue */
+ clp = &ksock->cmsg_list;
+ TAILQ_INIT(&clp->list);
+ clp->is_init = 1;
+ return 0;
+}
+
+/*
+ * Send socket control message - POSIX.1-2008
+ *
+ * @socket: Socket to transmit on
+ * @msg: Further arguments
+ * @flags: Optional flags
+ *
+ * Returns zero on success, otherwise a less
+ * than zero errno.
+ */
+ssize_t
+sendmsg(int socket, const struct msghdr *msg, int flags)
+{
+ struct ksocket *ksock;
+ struct cmsg *cmsg;
+ struct sockaddr_un *un;
+ struct cmsg_list *clp;
+ size_t control_len = 0;
+ int error;
+
+ if ((error = get_ksock(socket, &ksock)) < 0) {
+ return error;
+ }
+
+ /* We cannot do sendmsg() non domain sockets */
+ un = &ksock->un;
+ if (un->sun_family != AF_UNIX) {
+ return -EBADF;
+ }
+
+ control_len = MALIGN(msg->msg_controllen);
+
+ /* Allocate a new cmsg */
+ cmsg = dynalloc(control_len + sizeof(struct cmsg));
+ if (cmsg == NULL) {
+ return -EINVAL;
+ }
+
+ memcpy(cmsg->buf, msg->msg_control, control_len);
+ clp = &ksock->cmsg_list;
+ cmsg->control_len = control_len;
+ TAILQ_INSERT_TAIL(&clp->list, cmsg, link);
return 0;
}
+/*
+ * Receive socket control message - POSIX.1‐2017
+ *
+ * @socket: Socket to receive on
+ * @msg: Further arguments
+ * @flags: Optional flags
+ *
+ * Returns zero on success, otherwise a less
+ * than zero errno.
+ */
+ssize_t
+recvmsg(int socket, struct msghdr *msg, int flags)
+{
+ struct ksocket *ksock;
+ struct sockaddr_un *un;
+ struct cmsg *cmsg, *tmp;
+ struct cmsghdr *cmsghdr;
+ struct cmsg_list *clp;
+ uint8_t *fds;
+ int error;
+
+ if (socket < 0) {
+ return -EINVAL;
+ }
+
+ /* Grab the socket descriptor */
+ if ((error = get_ksock(socket, &ksock)) < 0) {
+ return error;
+ }
+
+ /* Must be a unix domain socket */
+ un = &ksock->un;
+ if (un->sun_family != AF_UNIX) {
+ return -EBADF;
+ }
+
+ /* Grab the control message list */
+ clp = &ksock->cmsg_list;
+ cmsg = TAILQ_FIRST(&clp->list);
+
+ while (cmsg != NULL) {
+ cmsghdr = &cmsg->hdr;
+
+ /* Check the control message type */
+ switch (cmsghdr->cmsg_type) {
+ case SCM_RIGHTS:
+ {
+ fds = (uint8_t *)CMSG_DATA(cmsghdr);
+ pr_trace("SCM_RIGHTS -> fd %d\n", fds[0]);
+ break;
+ }
+ }
+
+ tmp = cmsg;
+ cmsg = TAILQ_NEXT(cmsg, link);
+
+ TAILQ_REMOVE(&clp->list, tmp, link);
+ dynfree(tmp);
+ }
+
+ return 0;
+}
+
+/*
+ * socket(7) syscall
+ *
+ * arg0: domain
+ * arg1: type
+ * arg2: protocol
+ */
+scret_t
+sys_socket(struct syscall_args *scargs)
+{
+ int domain = scargs->arg0;
+ int type = scargs->arg1;
+ int protocol = scargs->arg2;
+
+ return socket(domain, type, protocol);
+}
+
+/*
+ * bind(2) syscall
+ *
+ * arg0: sockfd
+ * arg1: addr
+ * arg2: len
+ */
+scret_t
+sys_bind(struct syscall_args *scargs)
+{
+ const struct sockaddr *u_addr = (void *)scargs->arg1;
+ struct sockaddr addr_copy;
+ int sockfd = scargs->arg0;
+ int len = scargs->arg2;
+ int error;
+
+ error = copyin(u_addr, &addr_copy, sizeof(addr_copy));
+ if (error < 0) {
+ return error;
+ }
+
+ return bind(sockfd, &addr_copy, len);
+}
+
+/*
+ * recv(2) syscall
+ *
+ * arg0: sockfd
+ * arg1: buf
+ * arg2: size
+ * arg3: flags
+ */
+scret_t
+sys_recv(struct syscall_args *scargs)
+{
+ char buf[NETBUF_LEN];
+ void *u_buf = (void *)scargs->arg1;
+ int sockfd = scargs->arg0;
+ size_t len = scargs->arg2;
+ int error, flags = scargs->arg3;
+
+ if (len > sizeof(buf)) {
+ return -ENOBUFS;
+ }
+
+ error = recv(sockfd, buf, len, flags);
+ if (error < 0) {
+ pr_error("sys_recv: recv() fail (fd=%d)\n", sockfd);
+ return error;
+ }
+
+ error = copyout(buf, u_buf, len);
+ return (error == 0) ? len : error;
+}
+
+/*
+ * send(2) syscall
+ *
+ * arg0: sockfd
+ * arg1: buf
+ * arg2: size
+ * arg3: flags
+ */
+scret_t
+sys_send(struct syscall_args *scargs)
+{
+ char buf[NETBUF_LEN];
+ const void *u_buf = (void *)scargs->arg1;
+ int sockfd = scargs->arg0;
+ size_t len = scargs->arg2;
+ int error, flags = scargs->arg3;
+
+ if (len > sizeof(buf)) {
+ return -ENOBUFS;
+ }
+
+ error = copyin(u_buf, buf, len);
+ if (error < 0) {
+ pr_error("sys_send: copyin() failure (fd=%d)\n", sockfd);
+ return error;
+ }
+
+ return send(sockfd, buf, len, flags);
+}
+
+/*
+ * recvmsg(3) syscall
+ *
+ * arg0: socket
+ * arg1: msg
+ * arg2: flags
+ */
+scret_t
+sys_recvmsg(struct syscall_args *scargs)
+{
+ struct msghdr *u_msg = (void *)scargs->arg1;
+ void *u_control;
+ size_t controllen;
+ struct iovec msg_iov;
+ struct msghdr msg;
+ ssize_t retval;
+ int socket = scargs->arg0;
+ int flags = scargs->arg2;
+ int error;
+
+ /* Read the message header */
+ error = copyin(u_msg, &msg, sizeof(msg));
+ if (error < 0) {
+ pr_error("sys_recvmsg: bad msg\n");
+ return error;
+ }
+
+ /* Grab the message I/O vector */
+ error = uio_copyin(msg.msg_iov, &msg_iov, msg.msg_iovlen);
+ if (error < 0) {
+ return error;
+ }
+
+ /* Save control fields */
+ u_control = msg.msg_control;
+ controllen = msg.msg_controllen;
+
+ /* Allocate a new control field to copy in */
+ msg.msg_control = dynalloc(controllen);
+ if (msg.msg_control == NULL) {
+ uio_copyin_clean(&msg_iov, msg.msg_iovlen);
+ return -ENOMEM;
+ }
+
+ error = copyin(u_control, msg.msg_control, controllen);
+ if (error < 0) {
+ retval = error;
+ goto done;
+ }
+
+ msg.msg_iov = &msg_iov;
+ retval = recvmsg(socket, &msg, flags);
+done:
+ uio_copyin_clean(&msg_iov, msg.msg_iovlen);
+ dynfree(msg.msg_control);
+ return retval;
+}
+
+/*
+ * sendmsg(3) syscall
+ *
+ * arg0: socket
+ * arg1: msg
+ * arg2: flags
+ */
+scret_t
+sys_sendmsg(struct syscall_args *scargs)
+{
+ struct iovec msg_iov;
+ struct msghdr *u_msg = (void *)scargs->arg1;
+ struct msghdr msg;
+ ssize_t retval;
+ int socket = scargs->arg0;
+ int flags = scargs->arg2;
+ int error;
+
+ /* Read the message header */
+ error = copyin(u_msg, &msg, sizeof(msg));
+ if (error < 0) {
+ pr_error("sys_sendmsg: bad msg\n");
+ return error;
+ }
+
+ /* Grab the message I/O vector */
+ error = uio_copyin(msg.msg_iov, &msg_iov, msg.msg_iovlen);
+ if (error < 0) {
+ return error;
+ }
+
+ msg.msg_name = NULL;
+ msg.msg_namelen = 0;
+ msg.msg_iov = &msg_iov;
+ retval = sendmsg(socket, &msg, flags);
+ uio_copyin_clean(&msg_iov, msg.msg_iovlen);
+ return retval;
+}
+
static struct vops socket_vops = {
.read = NULL,
.write = NULL,