diff options
-rw-r--r-- | Makefile | 16 | ||||
-rw-r--r-- | client.c | 142 | ||||
-rw-r--r-- | main.c | 59 | ||||
-rw-r--r-- | server.c | 116 |
4 files changed, 333 insertions, 0 deletions
diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..ca84429 --- /dev/null +++ b/Makefile @@ -0,0 +1,16 @@ +CC=gcc +CFLAGS=-O3 -march=native -flto -pipe -g -fsanitize=address +LIBS=`pkg-config --cflags --libs openssl` -lpthread + +OBJ = main.o client.o server.o +HEADERS = client.h server.h + +%.o: %.c $(DEPS) $(HEADERS) + $(CC) -c -o $@ $< $(CFLAGS) $(LIBS) + +chat: $(OBJ) + $(CC) -o $@ $^ $(CFLAGS) $(LIBS) + +clean: + rm $(OBJ) + rm chat diff --git a/client.c b/client.c new file mode 100644 index 0000000..67bd6bd --- /dev/null +++ b/client.c @@ -0,0 +1,142 @@ +#include <sys/socket.h> +#include <netdb.h> +#include <arpa/inet.h> +#include <unistd.h> +#include <string.h> +#include <openssl/ssl.h> +#include "client.h" +#define MAX 512 + +int create_socket(char server[], char port[]) + { + struct addrinfo hints; + struct addrinfo *result, *rp; + int sockfd, host; + size_t len; + + /* Obtain address(es) matching host */ + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_INET;//IPv4 or IPv6 + hints.ai_socktype = SOCK_STREAM;//stream socket + hints.ai_flags = 0; + hints.ai_protocol = 0;//just don't set a protocol as openssl doesn't like that + + /* + getaddrinfo() returns a list of address structures. Try each address until we successfully connect(2). + If socket(2) (or connect(2)) fails, we (close the socket and) try the next address. + */ + + host = getaddrinfo(server, port, &hints, &result); + if (host != 0){fprintf(stderr, "Host not found\n"); return 0;}//host not found + + for (rp = result; rp != NULL; rp = rp->ai_next) + { + sockfd = socket(rp->ai_family, rp->ai_socktype,rp->ai_protocol); + if (sockfd == -1){continue;} + if (connect(sockfd, rp->ai_addr, rp->ai_addrlen) != -1){break;}//Success + close(sockfd); + } + + /* No address connect()'d successfully */ + if (!rp){fprintf(stderr, "Could not connect() to any host\n"); return 0;} + + freeaddrinfo(result);//No longer needed + + return sockfd; + } + +SSL *connectServer(char server[]) + { + const SSL_METHOD *method; + SSL_CTX *ctx; + SSL *tls; + int sockfd = 0; + + /* attempt to initalize TLS library */ + if(SSL_library_init() < 0){return 0;} + + /* set TLS client hello */ + method = TLS_client_method(); + + /* attempt to create new TLS context */ + if (!(ctx = SSL_CTX_new(method))){return 0;} + SSL_CTX_set_min_proto_version(ctx,TLS1_3_VERSION); + SSL_CTX_set_max_proto_version(ctx,TLS1_3_VERSION); + + /* create a new TLS connection state object */ + tls = SSL_new(ctx); + /* set min and max protocol version to TLSv1.3 */ + SSL_set_min_proto_version(tls,TLS1_3_VERSION); + SSL_set_max_proto_version(tls,TLS1_3_VERSION); + + char port[6] = "6697"; + + /* Copy string as you can't write to hardcoded strings */ + char buff[strlen(server)+1]; + memcpy(buff,server,strlen(server)+1); + + char *pos; + /* if the server contains a slash, we have a port number */ + if (pos = strchr(buff, '/')) + { + /* copy port value into port */ + strncpy(port,pos+1,strlen(port)); + /* terminate server string at slash */ + *pos = '\0'; + } + + /* create a TCP socket connection with the server */ + sockfd = create_socket(buff,port); + + /* return 0 on failure */ + if (!sockfd){return 0;} + + SSL_set_fd(tls, sockfd); + + /* try to tls connect here, success returns 1 */ + if (SSL_connect(tls) != 1){return 0;}//could not make a TLS session + + /* B L O A T */ +// cert = SSL_get_peer_certificate(tls); +// if (!cert){return 0;}//could not get a certificate + //free what we aren't using anymore +// X509_free(cert); + + return tls; + } + +void client() + { + /* make a tls session with a server */ + SSL *tlsSession = connectServer("127.0.0.1/6697"); + if (!tlsSession){fprintf(stderr,"Could not make a TLS session with host\n"); exit(1);} + + printf("Session made, you can start typing:\n"); + + char buffer[MAX]; + while (1) + { + int len = 0; + + char c; + while ((c = getchar()) != '\n' && len < MAX){buffer[len++] = c;} + buffer[len] = '\n'; + if (!(memcmp(buffer,"exit",4))){break;} + + SSL_write(tlsSession,buffer,len+1); + int num = SSL_read(tlsSession,buffer,MAX); + if (num <= 0){fprintf(stderr, "SSL_read() 0 or failed, giving up\n"); exit(1);} + buffer[num] = '\0'; + printf("%s", buffer); + if (!(strncmp(buffer,"Goodbye\n",num))){break;} + } + + + int fd = SSL_get_fd(tlsSession); + SSL_shutdown(tlsSession); + if (fd){close(fd);} + SSL_free(tlsSession); + } + + + @@ -0,0 +1,59 @@ +#include <stdio.h> +#include <stdlib.h> +#include <stdbool.h> +#include <getopt.h> +#include "client.h" +#include "server.h" + +void usage(char exname[]) + { + fprintf(stderr,"Usage: %s [OPTIONS]\n\ + -h, --help This help message.\n\ + -c, --client Enable client mode (default).\n\ + -s, --server Enable server mode.\n", exname); + + exit(0); + } + +int main(int argc, char *argv[]) + { + static struct option long_options[] = + { + {"help", no_argument, 0, 'h'}, + {"client", no_argument, 0, 'c'}, + {"server", no_argument, 0, 's'}, + {0,0,0,0} + }; + + bool isClient = true; + /* go through the options */ + while(1) + { + int option_index = 0; + char c = getopt_long(argc, argv, "hcs", long_options, &option_index); + /* End of args */ + if (c == -1){break;} + switch (c) + { + case 'h': + usage(argv[0]); + break; + case 'c': + isClient = true; + break; + case 's': + isClient = false; + break; + case '?': + break; + default: + exit(1); + } + } + + /* pass off to client and server */ + if (isClient){client();} + else {server();} + + return 0; + } diff --git a/server.c b/server.c new file mode 100644 index 0000000..4dfb04f --- /dev/null +++ b/server.c @@ -0,0 +1,116 @@ +#include <stdio.h> +#include <stdlib.h> +#include <unistd.h> +#include <string.h> +#include <sys/socket.h> +#include <arpa/inet.h> +#include <openssl/ssl.h> +#include <openssl/err.h> +#include "server.h" +#include <pthread.h> +#define MAX 512 +#define NICK_CHARS 32 + + +int bind_socket(int port) + { + int sock; + struct sockaddr_in addr; + + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + addr.sin_addr.s_addr = htonl(INADDR_ANY); + /* create the socket */ + if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0){fprintf(stderr,"Unable to create socket!\n"); exit(1);}//unable to create socket + /* bind to socket */ + if (bind(sock, (struct sockaddr*)&addr, sizeof(addr)) < 0){fprintf(stderr, "Unable to bind to socket!\n"); exit(1);}//unable to bind + /* listen on socket */ + if (listen(sock, 1) < 0){fprintf(stderr, "Unable to listen on socket!\n"); exit(1);}//unable to listen + + return sock; + } + +void init_openssl(){SSL_load_error_strings(); OpenSSL_add_ssl_algorithms();} +void cleanup_openssl(){EVP_cleanup();} + +void *session(SSL *tls) + { + char buffer[MAX+1]; + char nick[NICK_CHARS+1]; + + /* Get nick */ + while (1) + { + int num = SSL_read(tls,buffer,MAX); + if (num <= 0){fprintf(stdout, "SSL_read() 0 or error, giving up on client\n"); break;} + + char *newlinePos = strchr(buffer,'\n'); + if (!newlinePos){SSL_write(tls,"Missing newline\n",16); fprintf(stderr, "Missing newline\n"); continue;} + *newlinePos = '\0'; + + int length = newlinePos-buffer-5; + fprintf(stderr, "Length: %d\n", length); + if ((length > 0 && length <= NICK_CHARS) && !strncmp(buffer,"nick ", 5)) + { + memcpy(nick,buffer+5,length+1); + fprintf(stderr, "Nick: %s\n", nick); + SSL_write(tls, "Goodbye\n",8); + break; + } + SSL_write(tls,"That's not a nick or nick too long\n",35); + } + + int client = SSL_get_fd(tls); + + SSL_shutdown(tls); + SSL_free(tls); + close(client); + } + + +void server() + { + int sock; + + init_openssl(); + SSL_CTX *ctx; + if (!(ctx = SSL_CTX_new(TLS_server_method()))){fprintf(stderr,"Unable to create TLS context\n"); ERR_print_errors_fp(stderr); exit(1);} + SSL_CTX_set_ecdh_auto(ctx, 1); + SSL_CTX_set_min_proto_version(ctx,TLS1_3_VERSION); + SSL_CTX_set_max_proto_version(ctx,TLS1_3_VERSION); + if (SSL_CTX_use_certificate_file(ctx, "cert.pem", SSL_FILETYPE_PEM) <= 0){ERR_print_errors_fp(stderr); exit(1);} + if (SSL_CTX_use_PrivateKey_file(ctx, "key.pem", SSL_FILETYPE_PEM) <= 0){ERR_print_errors_fp(stderr); exit(1);} + + /* bind to socket */ + sock = bind_socket(6697); + + /* Handle connections */ + while(1) + { + struct sockaddr_in addr; + uint len = sizeof(addr); + + fprintf(stderr, "Listening for clients\n"); + + int client = accept(sock, (struct sockaddr*)&addr, &len); + if (client < 0){perror("Unable to accept"); exit(1);} + + SSL *tls = SSL_new(ctx); + if (!tls){fprintf(stderr, "Could not make TLS session"); exit(1);} + SSL_set_min_proto_version(tls,TLS1_3_VERSION); + SSL_set_max_proto_version(tls,TLS1_3_VERSION); + SSL_set_fd(tls, client); + + if (SSL_accept(tls) <= 0){ERR_print_errors_fp(stderr);} + else + { + /* We don't care about re-joining the threads */ + pthread_t thread; + pthread_create(&thread, NULL, (void *)session, tls); + } + } + + close(sock); + SSL_CTX_free(ctx); + cleanup_openssl(); +} |