diff options
| -rw-r--r-- | lib/include/server.h | 2 | ||||
| -rw-r--r-- | lib/libostp/auth.c | 42 | ||||
| -rw-r--r-- | lib/libostp/server.c | 87 | ||||
| -rw-r--r-- | lib/libostp/session.c | 5 | 
4 files changed, 92 insertions, 44 deletions
| diff --git a/lib/include/server.h b/lib/include/server.h index 52f92ce..1c23f61 100644 --- a/lib/include/server.h +++ b/lib/include/server.h @@ -34,6 +34,7 @@  #include <ostp/session.h>  #include <pthread.h>  #include <stddef.h> +#include <stdint.h>  #define MAX_CLIENTS 32 @@ -41,6 +42,7 @@ struct ostp_client {      struct ostp_session session;      int sockfd;      pthread_t td; +    volatile uint8_t authed : 1;  };  struct ostp_listener { diff --git a/lib/libostp/auth.c b/lib/libostp/auth.c index 559e3b3..b807485 100644 --- a/lib/libostp/auth.c +++ b/lib/libostp/auth.c @@ -125,36 +125,6 @@ send_motd(struct ostp_client *c, const unsigned char *session_key)      }  } -static int -session_run(struct ostp_listener *lp, const unsigned char *session_key) -{ -    struct ostp_client *c; -    char buf[4096]; -    size_t len; - -    while (1) { -        for (int i = 1; i < MAX_CLIENTS; ++i) { -            c = &lp->clients[i]; -            if (c->sockfd <= 0) -                continue; -            if (FD_ISSET(c->sockfd, &lp->client_fds) <= 0) -                continue; - -            len = recv_frame(c->sockfd, sizeof(buf) - 1, session_key, buf); -            if (len < 0) { -                printf("recv_frame() failure, packet lost\n"); -                continue; -            } -            if (len == 0) { -                return 0; -            } -            if (lp->on_recv != NULL) { -                lp->on_recv(c, buf, len); -            } -        } -    } -} -  static void *  session_td(void *args)  { @@ -174,7 +144,7 @@ session_td(void *args)      }      send_motd(tmp->c, tmp->session_key); -    session_run(tmp->lp, tmp->session_key); +    tmp->c->authed = 1;      free(args);      return NULL;  } @@ -184,7 +154,7 @@ handle_srq(struct ostp_client *c, struct ostp_listener *lp, struct session_reque  {      struct x25519_keypair keypair;      struct session_td_args *sargs; -    unsigned char *session_key; +    struct ostp_session *session;      int error;      if (REQUIRE_USER_AUTH && !ISSET(srq->options, SESSION_REQ_USER)) { @@ -207,8 +177,12 @@ handle_srq(struct ostp_client *c, struct ostp_listener *lp, struct session_reque          return error;      } +    /* Setup client session descriptor */ +    session = &c->session; +    session->sockfd = c->sockfd; +      printf("Deriving session key...\n"); -    error = gen_session_key(keypair.privkey, srq->pubkey, &session_key); +    error = gen_session_key(keypair.privkey, srq->pubkey, &session->session_key);      if (error < 0) {          return error;      } @@ -221,7 +195,7 @@ handle_srq(struct ostp_client *c, struct ostp_listener *lp, struct session_reque      sargs->c = c;      sargs->lp = lp; -    sargs->session_key = session_key; +    sargs->session_key = session->session_key;      error = pthread_create(&c->td, NULL, session_td, sargs);      if (error != 0) {          return error; diff --git a/lib/libostp/server.c b/lib/libostp/server.c index aad91e6..3bede26 100644 --- a/lib/libostp/server.c +++ b/lib/libostp/server.c @@ -37,10 +37,19 @@  #include <stdlib.h>  #include <unistd.h>  #include <stdio.h> +#include <pthread.h>  #define MAX_BACKLOG 4  #define LISTEN_PORT 5352 +struct recv_args { +    void *buf; +    struct ostp_client *c; +    struct ostp_listener *lp; +    size_t n; +    pthread_t td; +}; +  static int  handle_client(struct sockaddr_in *caddr, struct ostp_client *c, struct ostp_listener *lp)  { @@ -77,6 +86,24 @@ handle_client(struct sockaddr_in *caddr, struct ostp_client *c, struct ostp_list  }  /* + * When we get data from a client, a thread is sent + * here to handle it. + */ +static void * +recv_td(void *args) +{ +    struct recv_args *ap = args; +    struct ostp_listener *lp; + +    lp = ap->lp; +    if (lp->on_recv != NULL) { +        lp->on_recv(ap->c, ap->buf, ap->n); +    } +    free(args); +    return NULL; +} + +/*   * Put a listener in a known state during   * initialization.   */ @@ -122,14 +149,17 @@ int  listener_poll(struct ostp_listener *lp)  {      struct sockaddr_in caddr; +    struct recv_args *recv_ap;      struct ostp_client *c;      socklen_t caddr_len;      pthread_t client_td; -    int client_sock, error = 0; -    char *ip; +    int client_sock, n, error = 0; +    unsigned char *session_key; +    char *ip, buf[4096];      memset(lp->clients, -1, sizeof(lp->clients));      lp->clients[0].sockfd = lp->serv_sockfd; +    lp->client_count = 1;      while (1) {          FD_ZERO(&lp->client_fds); @@ -155,23 +185,64 @@ listener_poll(struct ostp_listener *lp)                  continue;              } +            if (lp->client_count >= MAX_CLIENTS) { +                printf("New connection rejected, max clients reached\n"); +                close(client_sock); +                continue; +            } +              for (int i = 0; i < MAX_CLIENTS; ++i) {                  c = &lp->clients[i]; -                if (lp->client_count >= MAX_CLIENTS) { -                    printf("New connection rejected, max clients reached\n"); -                    continue; -                }                  if (c->sockfd < 0) {                      c->sockfd = client_sock; -                    ip = inet_ntoa(caddr.sin_addr); +                    c->authed = 0; +                    ++lp->client_count; +                    ip = inet_ntoa(caddr.sin_addr);                      printf("Incoming connection from %s\n", ip); -                    ++lp->client_count;                      handle_client(&caddr, c, lp);                      break;                  }              }          } + +        /* Handle data from clients */ +        for (int i = 1; i < MAX_CLIENTS; ++i) { +            c = &lp->clients[i]; +            if (c->sockfd < 0 || !c->authed) +                continue; +            if (!FD_ISSET(c->sockfd, &lp->client_fds)) +                continue; + +            session_key = c->session.session_key; +            n = recv_frame(c->sockfd, sizeof(buf) - 1, session_key, buf); +            if (n < 0) { +                printf("recv_frame() failure, packet lost\n"); +                continue; +            } +            if (n == 0) { +                printf("Peer disconnected\n"); +                listener_close(lp, c); +                continue; +            } + +            if (lp->on_recv != NULL) { +                recv_ap = malloc(sizeof(*recv_ap)); +                if (recv_ap == NULL) +                    continue; + +                /* Prepare on_recv() args */ +                recv_ap->c = c; +                recv_ap->buf = buf; +                recv_ap->n = n; + +                error = pthread_create(&recv_ap->td, NULL, recv_td, recv_ap); +                if (error != 0) { +                    free(recv_ap); +                    continue; +                } +            } +        }      }      close(client_sock); diff --git a/lib/libostp/session.c b/lib/libostp/session.c index 151cad9..7087cc7 100644 --- a/lib/libostp/session.c +++ b/lib/libostp/session.c @@ -63,7 +63,7 @@ static const char *auth_codestr[] = {  };  static int -send_auth(int sockfd, const unsigned char *session_key) +send_auth(int sockfd, const unsigned char *session_key, struct ostp_session *s)  {      struct session_auth auth;      struct termios oldt, newt; @@ -129,6 +129,7 @@ send_auth(int sockfd, const unsigned char *session_key)          return -1;      } +    memcpy(s->username, auth.username, sizeof(auth.username));      return 0;  } @@ -260,7 +261,7 @@ session_new(const char *host, struct ostp_session *res)      gen_session_key(keypair.privkey, serv_pubkey, &session_key);      /* User authentication occurs before sending SPWs */ -    if ((error = send_auth(sockfd, session_key)) < 0) { +    if ((error = send_auth(sockfd, session_key, res)) < 0) {          return error;      } | 
