diff options
author | Ian Moffett <ian@osmora.org> | 2024-09-29 22:38:43 -0400 |
---|---|---|
committer | Ian Moffett <ian@osmora.org> | 2024-09-29 22:38:43 -0400 |
commit | 73ead92c2d37d5d091992ef617c4abdfe9907a18 (patch) | |
tree | 1b689727607d72e525cee5bd298367aadc293615 | |
parent | 788b1308e86320882245159540ef0a489209bcf1 (diff) |
project: Massive fixups
- Fix client handling
- Add multithreading
- Fixup bad values
Signed-off-by: Ian Moffett <ian@osmora.org>
-rw-r--r-- | lib/include/libostp/server.h | 17 | ||||
-rw-r--r-- | lib/include/net/auth.h | 4 | ||||
-rw-r--r-- | lib/libostp/auth.c | 86 | ||||
-rw-r--r-- | lib/libostp/param.c | 16 | ||||
-rw-r--r-- | lib/libostp/server.c | 76 | ||||
-rw-r--r-- | ostp.d/init/main.c | 9 |
6 files changed, 123 insertions, 85 deletions
diff --git a/lib/include/libostp/server.h b/lib/include/libostp/server.h index a7a737b..0e232f6 100644 --- a/lib/include/libostp/server.h +++ b/lib/include/libostp/server.h @@ -32,23 +32,32 @@ #include <sys/select.h> #include <libostp/session.h> +#include <pthread.h> #include <stddef.h> #define MAX_CLIENTS 32 +struct ostp_client { + struct ostp_session session; + int sockfd; + pthread_t td; +}; + struct ostp_listener { - int(*on_recv)(struct ostp_session *session, const char *buf, size_t len); + int(*on_recv)(struct ostp_client *c, const char *buf, size_t len); int port; /* -- Private -- */ - int clients[MAX_CLIENTS]; + struct ostp_client clients[MAX_CLIENTS]; + size_t client_count; int serv_sockfd; fd_set client_fds; }; void listener_init(struct ostp_listener *lp); -int listener_bind(struct ostp_session *sp, struct ostp_listener *lp); -int listener_poll(struct ostp_session *sp, struct ostp_listener *lp); +int listener_bind(struct ostp_listener *lp); +int listener_poll(struct ostp_listener *lp); void listener_cleanup(struct ostp_listener *lp); +void listener_close(struct ostp_listener *lp, struct ostp_client *c); #endif /* !LIBOSTP_SERVER_H_ */ diff --git a/lib/include/net/auth.h b/lib/include/net/auth.h index d672231..0f22d22 100644 --- a/lib/include/net/auth.h +++ b/lib/include/net/auth.h @@ -34,8 +34,8 @@ #include <libostp/session.h> #include <libostp/server.h> -int handle_srq(struct ostp_session *sp, struct ostp_listener *lp, +int handle_srq(struct ostp_client *c, struct ostp_listener *lp, struct session_request *srq); -int negotiate_spw(struct ostp_session *sp, unsigned char *session_key); +int negotiate_spw(struct ostp_client *c, unsigned char *session_key); #endif /* NET_AUTH_H_ */ diff --git a/lib/libostp/auth.c b/lib/libostp/auth.c index d32c06a..f2097bc 100644 --- a/lib/libostp/auth.c +++ b/lib/libostp/auth.c @@ -35,6 +35,13 @@ #include <stdio.h> #include <unistd.h> #include <stdlib.h> +#include <errno.h> + +struct session_td_args { + struct ostp_client *c; + struct ostp_listener *lp; + unsigned char *session_key; +}; /* * Check a password to see if it matches with @@ -67,7 +74,7 @@ pwcheck(char *username, char *pw) } static int -passwd_auth(struct ostp_session *sp, const unsigned char *session_key) +passwd_auth(struct ostp_client *c, const unsigned char *session_key) { int error; struct session_auth auth; @@ -77,7 +84,7 @@ passwd_auth(struct ostp_session *sp, const unsigned char *session_key) return 0; } - error = recv_frame(sp->sockfd, sizeof(auth), session_key, &auth); + error = recv_frame(c->sockfd, sizeof(auth), session_key, &auth); if (error < 0) { return error; } @@ -85,7 +92,7 @@ passwd_auth(struct ostp_session *sp, const unsigned char *session_key) if (pwcheck(auth.username, auth.password) != 0) { printf("Got bad password for %s\n", auth.username); auth.code = AUTH_BAD_PW; - error = send_frame(sp->sockfd, &auth, sizeof(auth), session_key); + error = send_frame(c->sockfd, &auth, sizeof(auth), session_key); if (error < 0) { printf("Failed to ACK user authentication with frame\n"); } @@ -93,7 +100,7 @@ passwd_auth(struct ostp_session *sp, const unsigned char *session_key) } auth.code = AUTH_SUCCESS; - error = send_frame(sp->sockfd, &auth, sizeof(auth), session_key); + error = send_frame(c->sockfd, &auth, sizeof(auth), session_key); if (error < 0) { printf("Failed to ACK user authentication with frame\n"); return error; @@ -102,31 +109,32 @@ passwd_auth(struct ostp_session *sp, const unsigned char *session_key) } static void -send_motd(struct ostp_session *sp, const unsigned char *session_key) +send_motd(struct ostp_client *c, const unsigned char *session_key) { char motd[] = MOTD; printf("Sending MOTD...\n"); - if (send_frame(sp->sockfd, motd, sizeof(motd), session_key) < 0) { + if (send_frame(c->sockfd, motd, sizeof(motd), session_key) < 0) { printf("Failed to session MOTD\n"); } } static int -session_run(struct ostp_session *sp, struct ostp_listener *lp, - const unsigned char *session_key) +session_run(struct ostp_listener *lp, const unsigned char *session_key) { + struct ostp_client *c; char buf[4096]; size_t len; while (1) { for (int i = 1; i < MAX_CLIENTS; ++i) { - if (lp->clients[i] <= 0) + c = &lp->clients[i]; + if (c->sockfd <= 0) continue; - if (FD_ISSET(lp->clients[i], &lp->client_fds) <= 0) + if (FD_ISSET(c->sockfd, &lp->client_fds) <= 0) continue; - len = recv_frame(lp->clients[i], sizeof(buf) - 1, session_key, buf); + len = recv_frame(c->sockfd, sizeof(buf) - 1, session_key, buf); if (len < 0) { printf("recv_frame() failure, packet lost\n"); continue; @@ -135,22 +143,45 @@ session_run(struct ostp_session *sp, struct ostp_listener *lp, return 0; } if (lp->on_recv != NULL) { - lp->on_recv(sp, buf, len); + lp->on_recv(c, buf, len); } } } } +static void * +session_td(void *args) +{ + struct session_td_args *tmp = args; + int error; + + /* Try user auth, not needed if REQUIRE_USER_AUTH is 0 */ + if (passwd_auth(tmp->c, tmp->session_key) != 0) { + free_session_key(tmp->session_key); + exit(-1); + } + + /* Handle any requested session parameters */ + if ((error = negotiate_spw(tmp->c, tmp->session_key)) < 0) { + free_session_key(tmp->session_key); + exit(error); + } + + send_motd(tmp->c, tmp->session_key); + session_run(tmp->lp, tmp->session_key); + free(args); + return NULL; +} + int -handle_srq(struct ostp_session *sp, struct ostp_listener *lp, struct session_request *srq) +handle_srq(struct ostp_client *c, struct ostp_listener *lp, struct session_request *srq) { struct x25519_keypair keypair; + struct session_td_args *sargs; unsigned char *session_key; - pid_t child; int error; if (REQUIRE_USER_AUTH && !ISSET(srq->options, SESSION_REQ_USER)) { - printf("%x\n", srq->options); printf("User authentication enforced but client 'U' bit not set\n"); printf("Closing connection...\n"); return -1; @@ -164,7 +195,7 @@ handle_srq(struct ostp_session *sp, struct ostp_listener *lp, struct session_req } /* Send back our our public key */ - error = send(sp->sockfd, keypair.pubkey, keypair.pubkey_len, 0); + error = send(c->sockfd, keypair.pubkey, keypair.pubkey_len, 0); if (error < 0) { perror("Failed to send public key"); return error; @@ -176,24 +207,19 @@ handle_srq(struct ostp_session *sp, struct ostp_listener *lp, struct session_req return error; } - /* Try user auth, not needed if REQUIRE_USER_AUTH is 0 */ - if (passwd_auth(sp, session_key) != 0) { - return -1; + sargs = malloc(sizeof(*sargs)); + if (sargs == NULL) { + printf("Failed to allocate session args\n"); + return errno; } - /* Handle any requested session parameters */ - if ((error = negotiate_spw(sp, session_key)) < 0) { - free_session_key(session_key); + sargs->c = c; + sargs->lp = lp; + sargs->session_key = session_key; + error = pthread_create(&c->td, NULL, session_td, sargs); + if (error != 0) { return error; } - send_motd(sp, session_key); - - /* Dispatch a thread and handle this session */ - child = fork(); - if (child == 0) { - session_run(sp, lp, session_key); - exit(0); - } return 0; } diff --git a/lib/libostp/param.c b/lib/libostp/param.c index 4c14733..8b83f46 100644 --- a/lib/libostp/param.c +++ b/lib/libostp/param.c @@ -36,9 +36,9 @@ #include <stdio.h> static int -handle_pap(struct ostp_session *sp, const struct pap *pap, const unsigned char *session_key) +handle_pap(struct ostp_client *c, const struct pap *pap, const unsigned char *session_key) { - int error = -1; + int error = 0; uint8_t attempts = 0; struct pap tmp_pap = *pap; const size_t LEN = sizeof(struct pap); @@ -48,7 +48,7 @@ handle_pap(struct ostp_session *sp, const struct pap *pap, const unsigned char * /* Quick session request, jump right in! */ if (ISSET(tmp_pap.spw, PAP_SPW_QSR)) { printf("Got QSR, starting session...\n"); - send_frame(sp->sockfd, &tmp_pap, LEN, session_key); + send_frame(c->sockfd, &tmp_pap, LEN, session_key); return 0; } @@ -65,11 +65,11 @@ handle_pap(struct ostp_session *sp, const struct pap *pap, const unsigned char * tmp_pap.code = PAP_BAD_SPW; /* Send in PAP and wait for response */ - if ((error = send_frame(sp->sockfd, &tmp_pap, LEN, session_key)) < -1) { + if ((error = send_frame(c->sockfd, &tmp_pap, LEN, session_key)) < 0) { printf("Failed to send PAP frame\n"); return -1; } - if ((error = recv_frame(sp->sockfd, LEN, session_key, &tmp_pap)) < -1) { + if ((error = recv_frame(c->sockfd, LEN, session_key, &tmp_pap)) < 0) { printf("Failed to recv PAP frame\n"); return error; } @@ -81,16 +81,16 @@ handle_pap(struct ostp_session *sp, const struct pap *pap, const unsigned char * } int -negotiate_spw(struct ostp_session *sp, unsigned char *session_key) +negotiate_spw(struct ostp_client *c, unsigned char *session_key) { const size_t LEN = sizeof(struct pap); struct pap pap; int error; /* Get PAP from the network */ - if ((error = recv_frame(sp->sockfd, LEN, session_key, &pap)) < -1) { + if ((error = recv_frame(c->sockfd, LEN, session_key, &pap)) < 0) { return error; } - return handle_pap(sp, &pap, session_key); + return handle_pap(c, &pap, session_key); } diff --git a/lib/libostp/server.c b/lib/libostp/server.c index 0013ce2..588a6d2 100644 --- a/lib/libostp/server.c +++ b/lib/libostp/server.c @@ -32,6 +32,7 @@ #include <net/stpsession.h> #include <arpa/inet.h> #include <string.h> +#include <stdlib.h> #include <unistd.h> #include <stdio.h> @@ -39,41 +40,34 @@ #define LISTEN_PORT 5352 static int -handle_client(struct sockaddr_in *caddr, struct ostp_session *sp, struct ostp_listener *lp, - int clientno) +handle_client(struct sockaddr_in *caddr, struct ostp_client *c, struct ostp_listener *lp) { struct session_request srq; ssize_t nread; - sp->sockfd = lp->clients[clientno]; - /* Try to read in the session request */ - if ((nread = read(sp->sockfd, &srq, sizeof(srq))) < 0) { + if ((nread = read(c->sockfd, &srq, sizeof(srq))) < 0) { printf("Read failure...\n"); - close(sp->sockfd); - lp->clients[clientno] = -1; + listener_close(lp, c); return -1; } if (nread == 0) { printf("Connection closed by peer\n"); - close(sp->sockfd); - lp->clients[clientno] = -1; + listener_close(lp, c); return -1; } /* Is this even a session request? */ if (nread != sizeof(srq)) { printf("Rejecting data - not a session request...\n"); - close(sp->sockfd); - lp->clients[clientno] = -1; + listener_close(lp, c); return -1; } /* Handle the session request */ - if (handle_srq(sp, lp, &srq) < 0) { - close(sp->sockfd); - lp->clients[clientno] = -1; + if (handle_srq(c, lp, &srq) < 0) { + listener_close(lp, c); return -1; } @@ -92,13 +86,14 @@ listener_init(struct ostp_listener *lp) } int -listener_bind(struct ostp_session *sp, struct ostp_listener *lp) +listener_bind(struct ostp_listener *lp) { + struct ostp_session *session; struct sockaddr_in saddr; int error; lp->serv_sockfd = socket(AF_INET, SOCK_STREAM, 0); - if (sp->sockfd < 0) { + if (lp->serv_sockfd < 0) { perror("Failed to create socket\n"); return -1; } @@ -122,22 +117,24 @@ listener_bind(struct ostp_session *sp, struct ostp_listener *lp) } int -listener_poll(struct ostp_session *sp, struct ostp_listener *lp) +listener_poll(struct ostp_listener *lp) { struct sockaddr_in caddr; + struct ostp_client *c; socklen_t caddr_len; + pthread_t client_td; int client_sock, error = 0; char *ip; memset(lp->clients, -1, sizeof(lp->clients)); - lp->clients[0] = lp->serv_sockfd; + lp->clients[0].sockfd = lp->serv_sockfd; while (1) { FD_ZERO(&lp->client_fds); for (int i = 0; i < MAX_CLIENTS; ++i) { - if (lp->clients[i] >= 0) - FD_SET(lp->clients[i], &lp->client_fds); + if (lp->clients[i].sockfd >= 0) + FD_SET(lp->clients[i].sockfd, &lp->client_fds); } if (select(1024, &lp->client_fds, NULL, NULL, NULL) < 0) { @@ -157,25 +154,22 @@ listener_poll(struct ostp_session *sp, struct ostp_listener *lp) } for (int i = 0; i < MAX_CLIENTS; ++i) { - if (lp->clients[i] < 0) { - lp->clients[i] = client_sock; + c = &lp->clients[i]; + if (lp->client_count >= MAX_CLIENTS) { + printf("New connection rejected, max clients reached\n"); + continue; + } + if (c->sockfd < 0) { + c->sockfd = client_sock; ip = inet_ntoa(caddr.sin_addr); + printf("Incoming connection from %s\n", ip); + ++lp->client_count; + handle_client(&caddr, c, lp); break; } } } - - /* Handle from data from lp->clients */ - for (int i = 1; i < MAX_CLIENTS; ++i) { - if (lp->clients[i] <= 0) - continue; - if (FD_ISSET(lp->clients[i], &lp->client_fds) <= 0) - continue; - - handle_client(&caddr, sp, lp, i); - break; - } } close(client_sock); @@ -185,11 +179,21 @@ listener_poll(struct ostp_session *sp, struct ostp_listener *lp) void listener_cleanup(struct ostp_listener *lp) { + struct ostp_client *c; + for (int i = 0; i < MAX_CLIENTS; ++i) { - if (lp->clients[i] > 0) { - close(lp->clients[i]); - } + c = &lp->clients[i]; + listener_close(lp, c); } close(lp->serv_sockfd); } + +void +listener_close(struct ostp_listener *lp, struct ostp_client *c) +{ + close(c->sockfd); + c->sockfd = -1; + memset(&c->session, 0, sizeof(c->session)); + --lp->client_count; +} diff --git a/ostp.d/init/main.c b/ostp.d/init/main.c index 08e020e..9b2a836 100644 --- a/ostp.d/init/main.c +++ b/ostp.d/init/main.c @@ -33,7 +33,7 @@ #include <stdio.h> static int -blah(struct ostp_session *s, const char *buf, size_t len) +handle_data(struct ostp_client *s, const char *buf, size_t len) { printf("Got data!\n"); return 0; @@ -43,16 +43,15 @@ int main(void) { struct ostp_listener l; - struct ostp_session s; int error; listener_init(&l); - l.on_recv = blah; + l.on_recv = handle_data; - if ((error = listener_bind(&s, &l)) < 0) { + if ((error = listener_bind(&l)) < 0) { return error; } - if ((error = listener_poll(&s, &l)) < 0) { + if ((error = listener_poll(&l)) < 0) { return error; } |