aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIan Moffett <ian@osmora.org>2024-09-29 22:38:43 -0400
committerIan Moffett <ian@osmora.org>2024-09-29 22:38:43 -0400
commit73ead92c2d37d5d091992ef617c4abdfe9907a18 (patch)
tree1b689727607d72e525cee5bd298367aadc293615
parent788b1308e86320882245159540ef0a489209bcf1 (diff)
project: Massive fixups
- Fix client handling - Add multithreading - Fixup bad values Signed-off-by: Ian Moffett <ian@osmora.org>
-rw-r--r--lib/include/libostp/server.h17
-rw-r--r--lib/include/net/auth.h4
-rw-r--r--lib/libostp/auth.c86
-rw-r--r--lib/libostp/param.c16
-rw-r--r--lib/libostp/server.c76
-rw-r--r--ostp.d/init/main.c9
6 files changed, 123 insertions, 85 deletions
diff --git a/lib/include/libostp/server.h b/lib/include/libostp/server.h
index a7a737b..0e232f6 100644
--- a/lib/include/libostp/server.h
+++ b/lib/include/libostp/server.h
@@ -32,23 +32,32 @@
#include <sys/select.h>
#include <libostp/session.h>
+#include <pthread.h>
#include <stddef.h>
#define MAX_CLIENTS 32
+struct ostp_client {
+ struct ostp_session session;
+ int sockfd;
+ pthread_t td;
+};
+
struct ostp_listener {
- int(*on_recv)(struct ostp_session *session, const char *buf, size_t len);
+ int(*on_recv)(struct ostp_client *c, const char *buf, size_t len);
int port;
/* -- Private -- */
- int clients[MAX_CLIENTS];
+ struct ostp_client clients[MAX_CLIENTS];
+ size_t client_count;
int serv_sockfd;
fd_set client_fds;
};
void listener_init(struct ostp_listener *lp);
-int listener_bind(struct ostp_session *sp, struct ostp_listener *lp);
-int listener_poll(struct ostp_session *sp, struct ostp_listener *lp);
+int listener_bind(struct ostp_listener *lp);
+int listener_poll(struct ostp_listener *lp);
void listener_cleanup(struct ostp_listener *lp);
+void listener_close(struct ostp_listener *lp, struct ostp_client *c);
#endif /* !LIBOSTP_SERVER_H_ */
diff --git a/lib/include/net/auth.h b/lib/include/net/auth.h
index d672231..0f22d22 100644
--- a/lib/include/net/auth.h
+++ b/lib/include/net/auth.h
@@ -34,8 +34,8 @@
#include <libostp/session.h>
#include <libostp/server.h>
-int handle_srq(struct ostp_session *sp, struct ostp_listener *lp,
+int handle_srq(struct ostp_client *c, struct ostp_listener *lp,
struct session_request *srq);
-int negotiate_spw(struct ostp_session *sp, unsigned char *session_key);
+int negotiate_spw(struct ostp_client *c, unsigned char *session_key);
#endif /* NET_AUTH_H_ */
diff --git a/lib/libostp/auth.c b/lib/libostp/auth.c
index d32c06a..f2097bc 100644
--- a/lib/libostp/auth.c
+++ b/lib/libostp/auth.c
@@ -35,6 +35,13 @@
#include <stdio.h>
#include <unistd.h>
#include <stdlib.h>
+#include <errno.h>
+
+struct session_td_args {
+ struct ostp_client *c;
+ struct ostp_listener *lp;
+ unsigned char *session_key;
+};
/*
* Check a password to see if it matches with
@@ -67,7 +74,7 @@ pwcheck(char *username, char *pw)
}
static int
-passwd_auth(struct ostp_session *sp, const unsigned char *session_key)
+passwd_auth(struct ostp_client *c, const unsigned char *session_key)
{
int error;
struct session_auth auth;
@@ -77,7 +84,7 @@ passwd_auth(struct ostp_session *sp, const unsigned char *session_key)
return 0;
}
- error = recv_frame(sp->sockfd, sizeof(auth), session_key, &auth);
+ error = recv_frame(c->sockfd, sizeof(auth), session_key, &auth);
if (error < 0) {
return error;
}
@@ -85,7 +92,7 @@ passwd_auth(struct ostp_session *sp, const unsigned char *session_key)
if (pwcheck(auth.username, auth.password) != 0) {
printf("Got bad password for %s\n", auth.username);
auth.code = AUTH_BAD_PW;
- error = send_frame(sp->sockfd, &auth, sizeof(auth), session_key);
+ error = send_frame(c->sockfd, &auth, sizeof(auth), session_key);
if (error < 0) {
printf("Failed to ACK user authentication with frame\n");
}
@@ -93,7 +100,7 @@ passwd_auth(struct ostp_session *sp, const unsigned char *session_key)
}
auth.code = AUTH_SUCCESS;
- error = send_frame(sp->sockfd, &auth, sizeof(auth), session_key);
+ error = send_frame(c->sockfd, &auth, sizeof(auth), session_key);
if (error < 0) {
printf("Failed to ACK user authentication with frame\n");
return error;
@@ -102,31 +109,32 @@ passwd_auth(struct ostp_session *sp, const unsigned char *session_key)
}
static void
-send_motd(struct ostp_session *sp, const unsigned char *session_key)
+send_motd(struct ostp_client *c, const unsigned char *session_key)
{
char motd[] = MOTD;
printf("Sending MOTD...\n");
- if (send_frame(sp->sockfd, motd, sizeof(motd), session_key) < 0) {
+ if (send_frame(c->sockfd, motd, sizeof(motd), session_key) < 0) {
printf("Failed to session MOTD\n");
}
}
static int
-session_run(struct ostp_session *sp, struct ostp_listener *lp,
- const unsigned char *session_key)
+session_run(struct ostp_listener *lp, const unsigned char *session_key)
{
+ struct ostp_client *c;
char buf[4096];
size_t len;
while (1) {
for (int i = 1; i < MAX_CLIENTS; ++i) {
- if (lp->clients[i] <= 0)
+ c = &lp->clients[i];
+ if (c->sockfd <= 0)
continue;
- if (FD_ISSET(lp->clients[i], &lp->client_fds) <= 0)
+ if (FD_ISSET(c->sockfd, &lp->client_fds) <= 0)
continue;
- len = recv_frame(lp->clients[i], sizeof(buf) - 1, session_key, buf);
+ len = recv_frame(c->sockfd, sizeof(buf) - 1, session_key, buf);
if (len < 0) {
printf("recv_frame() failure, packet lost\n");
continue;
@@ -135,22 +143,45 @@ session_run(struct ostp_session *sp, struct ostp_listener *lp,
return 0;
}
if (lp->on_recv != NULL) {
- lp->on_recv(sp, buf, len);
+ lp->on_recv(c, buf, len);
}
}
}
}
+static void *
+session_td(void *args)
+{
+ struct session_td_args *tmp = args;
+ int error;
+
+ /* Try user auth, not needed if REQUIRE_USER_AUTH is 0 */
+ if (passwd_auth(tmp->c, tmp->session_key) != 0) {
+ free_session_key(tmp->session_key);
+ exit(-1);
+ }
+
+ /* Handle any requested session parameters */
+ if ((error = negotiate_spw(tmp->c, tmp->session_key)) < 0) {
+ free_session_key(tmp->session_key);
+ exit(error);
+ }
+
+ send_motd(tmp->c, tmp->session_key);
+ session_run(tmp->lp, tmp->session_key);
+ free(args);
+ return NULL;
+}
+
int
-handle_srq(struct ostp_session *sp, struct ostp_listener *lp, struct session_request *srq)
+handle_srq(struct ostp_client *c, struct ostp_listener *lp, struct session_request *srq)
{
struct x25519_keypair keypair;
+ struct session_td_args *sargs;
unsigned char *session_key;
- pid_t child;
int error;
if (REQUIRE_USER_AUTH && !ISSET(srq->options, SESSION_REQ_USER)) {
- printf("%x\n", srq->options);
printf("User authentication enforced but client 'U' bit not set\n");
printf("Closing connection...\n");
return -1;
@@ -164,7 +195,7 @@ handle_srq(struct ostp_session *sp, struct ostp_listener *lp, struct session_req
}
/* Send back our our public key */
- error = send(sp->sockfd, keypair.pubkey, keypair.pubkey_len, 0);
+ error = send(c->sockfd, keypair.pubkey, keypair.pubkey_len, 0);
if (error < 0) {
perror("Failed to send public key");
return error;
@@ -176,24 +207,19 @@ handle_srq(struct ostp_session *sp, struct ostp_listener *lp, struct session_req
return error;
}
- /* Try user auth, not needed if REQUIRE_USER_AUTH is 0 */
- if (passwd_auth(sp, session_key) != 0) {
- return -1;
+ sargs = malloc(sizeof(*sargs));
+ if (sargs == NULL) {
+ printf("Failed to allocate session args\n");
+ return errno;
}
- /* Handle any requested session parameters */
- if ((error = negotiate_spw(sp, session_key)) < 0) {
- free_session_key(session_key);
+ sargs->c = c;
+ sargs->lp = lp;
+ sargs->session_key = session_key;
+ error = pthread_create(&c->td, NULL, session_td, sargs);
+ if (error != 0) {
return error;
}
- send_motd(sp, session_key);
-
- /* Dispatch a thread and handle this session */
- child = fork();
- if (child == 0) {
- session_run(sp, lp, session_key);
- exit(0);
- }
return 0;
}
diff --git a/lib/libostp/param.c b/lib/libostp/param.c
index 4c14733..8b83f46 100644
--- a/lib/libostp/param.c
+++ b/lib/libostp/param.c
@@ -36,9 +36,9 @@
#include <stdio.h>
static int
-handle_pap(struct ostp_session *sp, const struct pap *pap, const unsigned char *session_key)
+handle_pap(struct ostp_client *c, const struct pap *pap, const unsigned char *session_key)
{
- int error = -1;
+ int error = 0;
uint8_t attempts = 0;
struct pap tmp_pap = *pap;
const size_t LEN = sizeof(struct pap);
@@ -48,7 +48,7 @@ handle_pap(struct ostp_session *sp, const struct pap *pap, const unsigned char *
/* Quick session request, jump right in! */
if (ISSET(tmp_pap.spw, PAP_SPW_QSR)) {
printf("Got QSR, starting session...\n");
- send_frame(sp->sockfd, &tmp_pap, LEN, session_key);
+ send_frame(c->sockfd, &tmp_pap, LEN, session_key);
return 0;
}
@@ -65,11 +65,11 @@ handle_pap(struct ostp_session *sp, const struct pap *pap, const unsigned char *
tmp_pap.code = PAP_BAD_SPW;
/* Send in PAP and wait for response */
- if ((error = send_frame(sp->sockfd, &tmp_pap, LEN, session_key)) < -1) {
+ if ((error = send_frame(c->sockfd, &tmp_pap, LEN, session_key)) < 0) {
printf("Failed to send PAP frame\n");
return -1;
}
- if ((error = recv_frame(sp->sockfd, LEN, session_key, &tmp_pap)) < -1) {
+ if ((error = recv_frame(c->sockfd, LEN, session_key, &tmp_pap)) < 0) {
printf("Failed to recv PAP frame\n");
return error;
}
@@ -81,16 +81,16 @@ handle_pap(struct ostp_session *sp, const struct pap *pap, const unsigned char *
}
int
-negotiate_spw(struct ostp_session *sp, unsigned char *session_key)
+negotiate_spw(struct ostp_client *c, unsigned char *session_key)
{
const size_t LEN = sizeof(struct pap);
struct pap pap;
int error;
/* Get PAP from the network */
- if ((error = recv_frame(sp->sockfd, LEN, session_key, &pap)) < -1) {
+ if ((error = recv_frame(c->sockfd, LEN, session_key, &pap)) < 0) {
return error;
}
- return handle_pap(sp, &pap, session_key);
+ return handle_pap(c, &pap, session_key);
}
diff --git a/lib/libostp/server.c b/lib/libostp/server.c
index 0013ce2..588a6d2 100644
--- a/lib/libostp/server.c
+++ b/lib/libostp/server.c
@@ -32,6 +32,7 @@
#include <net/stpsession.h>
#include <arpa/inet.h>
#include <string.h>
+#include <stdlib.h>
#include <unistd.h>
#include <stdio.h>
@@ -39,41 +40,34 @@
#define LISTEN_PORT 5352
static int
-handle_client(struct sockaddr_in *caddr, struct ostp_session *sp, struct ostp_listener *lp,
- int clientno)
+handle_client(struct sockaddr_in *caddr, struct ostp_client *c, struct ostp_listener *lp)
{
struct session_request srq;
ssize_t nread;
- sp->sockfd = lp->clients[clientno];
-
/* Try to read in the session request */
- if ((nread = read(sp->sockfd, &srq, sizeof(srq))) < 0) {
+ if ((nread = read(c->sockfd, &srq, sizeof(srq))) < 0) {
printf("Read failure...\n");
- close(sp->sockfd);
- lp->clients[clientno] = -1;
+ listener_close(lp, c);
return -1;
}
if (nread == 0) {
printf("Connection closed by peer\n");
- close(sp->sockfd);
- lp->clients[clientno] = -1;
+ listener_close(lp, c);
return -1;
}
/* Is this even a session request? */
if (nread != sizeof(srq)) {
printf("Rejecting data - not a session request...\n");
- close(sp->sockfd);
- lp->clients[clientno] = -1;
+ listener_close(lp, c);
return -1;
}
/* Handle the session request */
- if (handle_srq(sp, lp, &srq) < 0) {
- close(sp->sockfd);
- lp->clients[clientno] = -1;
+ if (handle_srq(c, lp, &srq) < 0) {
+ listener_close(lp, c);
return -1;
}
@@ -92,13 +86,14 @@ listener_init(struct ostp_listener *lp)
}
int
-listener_bind(struct ostp_session *sp, struct ostp_listener *lp)
+listener_bind(struct ostp_listener *lp)
{
+ struct ostp_session *session;
struct sockaddr_in saddr;
int error;
lp->serv_sockfd = socket(AF_INET, SOCK_STREAM, 0);
- if (sp->sockfd < 0) {
+ if (lp->serv_sockfd < 0) {
perror("Failed to create socket\n");
return -1;
}
@@ -122,22 +117,24 @@ listener_bind(struct ostp_session *sp, struct ostp_listener *lp)
}
int
-listener_poll(struct ostp_session *sp, struct ostp_listener *lp)
+listener_poll(struct ostp_listener *lp)
{
struct sockaddr_in caddr;
+ struct ostp_client *c;
socklen_t caddr_len;
+ pthread_t client_td;
int client_sock, error = 0;
char *ip;
memset(lp->clients, -1, sizeof(lp->clients));
- lp->clients[0] = lp->serv_sockfd;
+ lp->clients[0].sockfd = lp->serv_sockfd;
while (1) {
FD_ZERO(&lp->client_fds);
for (int i = 0; i < MAX_CLIENTS; ++i) {
- if (lp->clients[i] >= 0)
- FD_SET(lp->clients[i], &lp->client_fds);
+ if (lp->clients[i].sockfd >= 0)
+ FD_SET(lp->clients[i].sockfd, &lp->client_fds);
}
if (select(1024, &lp->client_fds, NULL, NULL, NULL) < 0) {
@@ -157,25 +154,22 @@ listener_poll(struct ostp_session *sp, struct ostp_listener *lp)
}
for (int i = 0; i < MAX_CLIENTS; ++i) {
- if (lp->clients[i] < 0) {
- lp->clients[i] = client_sock;
+ c = &lp->clients[i];
+ if (lp->client_count >= MAX_CLIENTS) {
+ printf("New connection rejected, max clients reached\n");
+ continue;
+ }
+ if (c->sockfd < 0) {
+ c->sockfd = client_sock;
ip = inet_ntoa(caddr.sin_addr);
+
printf("Incoming connection from %s\n", ip);
+ ++lp->client_count;
+ handle_client(&caddr, c, lp);
break;
}
}
}
-
- /* Handle from data from lp->clients */
- for (int i = 1; i < MAX_CLIENTS; ++i) {
- if (lp->clients[i] <= 0)
- continue;
- if (FD_ISSET(lp->clients[i], &lp->client_fds) <= 0)
- continue;
-
- handle_client(&caddr, sp, lp, i);
- break;
- }
}
close(client_sock);
@@ -185,11 +179,21 @@ listener_poll(struct ostp_session *sp, struct ostp_listener *lp)
void
listener_cleanup(struct ostp_listener *lp)
{
+ struct ostp_client *c;
+
for (int i = 0; i < MAX_CLIENTS; ++i) {
- if (lp->clients[i] > 0) {
- close(lp->clients[i]);
- }
+ c = &lp->clients[i];
+ listener_close(lp, c);
}
close(lp->serv_sockfd);
}
+
+void
+listener_close(struct ostp_listener *lp, struct ostp_client *c)
+{
+ close(c->sockfd);
+ c->sockfd = -1;
+ memset(&c->session, 0, sizeof(c->session));
+ --lp->client_count;
+}
diff --git a/ostp.d/init/main.c b/ostp.d/init/main.c
index 08e020e..9b2a836 100644
--- a/ostp.d/init/main.c
+++ b/ostp.d/init/main.c
@@ -33,7 +33,7 @@
#include <stdio.h>
static int
-blah(struct ostp_session *s, const char *buf, size_t len)
+handle_data(struct ostp_client *s, const char *buf, size_t len)
{
printf("Got data!\n");
return 0;
@@ -43,16 +43,15 @@ int
main(void)
{
struct ostp_listener l;
- struct ostp_session s;
int error;
listener_init(&l);
- l.on_recv = blah;
+ l.on_recv = handle_data;
- if ((error = listener_bind(&s, &l)) < 0) {
+ if ((error = listener_bind(&l)) < 0) {
return error;
}
- if ((error = listener_poll(&s, &l)) < 0) {
+ if ((error = listener_poll(&l)) < 0) {
return error;
}