/**********************************************************************
 * conn.c                                                    August 2005
 *
 * L7VSD: Linux Virtual Server for Layer7 Load Balancing
 * Copyright (C) 2005  NTT COMWARE Corporation.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
 * 02110-1301 USA
 *
 **********************************************************************/

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/select.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <errno.h>
#include <unistd.h>
#include <fcntl.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <glib.h>
#include <limits.h>
#include "vanessa_logger.h"
#include "l7vs.h"

/*
 * Accepted connections that are not (yet) registered with a service.
 * l7vs_conn_create() appends to it; a connection leaves it when a service
 * is selected for it or when it is destroyed.
 */
static GList *l7vs_conn_pending = NULL;

/* iomux callbacks for the client-side and realserver-side sockets. */
static int l7vs_conn_cl_callback(struct l7vs_iomux *iom, int flags);
static int l7vs_conn_rs_callback(struct l7vs_iomux *iom, int flags);

/* Connection set-up and data relay helpers. */
static int l7vs_conn_rs_connected(struct l7vs_conn *conn);
static int l7vs_conn_relay(struct l7vs_conn *conn, int is_client);
static int l7vs_conn_relay_cl(struct l7vs_conn *conn);
static int l7vs_conn_relay_rs(struct l7vs_conn *conn);

/*
 * Accept a new client connection on the listening socket lfd.
 *
 * Allocates a struct l7vs_conn and its initial client-data buffer
 * (L7VS_CLDATA_CHUNKSIZE bytes), and accept()s the client. It records the
 * client's TCP MSS in conn->cmss so it can be reused for the realserver
 * socket. The client socket is registered with the iomux for reading, and
 * the connection is appended to the pending list until a service is chosen.
 *
 * Returns the new connection, or NULL on failure. On failure every
 * resource acquired so far is released.
 */
struct l7vs_conn *
l7vs_conn_create(int lfd, struct l7vs_lsock *lsock)
{
        socklen_t len;
        struct l7vs_conn *conn;
        int mss;
        int ret;

        conn = malloc(sizeof(*conn));
        if (conn == NULL) {
                VANESSA_LOGGER_ERR("Could not allocate memory");
                return NULL;
        }

        /*
         * Give fields a defined value before anything can read them.
         * riom.fd = -1 marks "no realserver socket" for l7vs_conn_destroy()
         * and l7vs_conn_closed().
         */
        conn->srv = NULL;
        conn->dest = NULL;
        conn->splice = 0;
        conn->riom.fd = -1;

        conn->cldata_len = 0;
        conn->cldata_bufsize = L7VS_CLDATA_CHUNKSIZE;
        conn->cldata = malloc(conn->cldata_bufsize);
        if (conn->cldata == NULL) {
                VANESSA_LOGGER_ERR("Could not allocate memory for the buffer");
                goto err_free_conn;
        }

        len = sizeof(conn->caddr);
        conn->ciom.fd = accept(lfd, (struct sockaddr *)&conn->caddr, &len);
        if (conn->ciom.fd < 0) {
                VANESSA_LOGGER_ERR_UNSAFE("accept: %s", strerror(errno));
                goto err_free_buf;
        }

        len = sizeof(mss);
        ret = getsockopt(conn->ciom.fd, IPPROTO_TCP, TCP_MAXSEG, &mss, &len);
        if (ret < 0) {
                VANESSA_LOGGER_ERR_UNSAFE("getsockopt TCP_MAXSEG: %s",
                                          strerror(errno));
                goto err_close;
        }
        conn->cmss = mss;
        VANESSA_LOGGER_DEBUG_UNSAFE("client %s mss %d",
                                    inet_ntoa(conn->caddr.sin_addr), mss);

        conn->lsock = lsock;
        conn->proto = lsock->proto;
        conn->state = L7VS_CONN_S_CL_CONNECTED;
        conn->ciom.callback = l7vs_conn_cl_callback;
        conn->ciom.data = conn;
        l7vs_iomux_add(&conn->ciom, L7VS_IOMUX_READ);

        l7vs_conn_pending = g_list_append(l7vs_conn_pending, conn);

        return conn;

        /* Error unwind: release resources in reverse order of acquisition. */
err_close:
        close(conn->ciom.fd);
err_free_buf:
        free(conn->cldata);
err_free_conn:
        free(conn);
        return NULL;
}

/*
 * Tear down a connection and free it.
 *
 * The connection is first unlinked from its owner. That is the service it
 * was registered with (via l7vs_service_remove_conn()), or the pending list
 * if no service was assigned. Any socket that is still open (fd >= 0) is
 * closed through the close helpers. Finally the client-data buffer and the
 * connection structure are freed; conn must not be used afterwards.
 */
void
l7vs_conn_destroy(struct l7vs_conn *conn)
{
        struct l7vs_service *owner = conn->srv;

        if (owner != NULL) {
                l7vs_service_remove_conn(owner, conn);
        } else {
                l7vs_conn_pending = g_list_remove(l7vs_conn_pending, conn);
        }

        /* Closed sockets have fd == -1 (set by the close helpers). */
        if (conn->ciom.fd > -1) {
                l7vs_conn_close_csock(conn);
        }
        if (conn->riom.fd > -1) {
                l7vs_conn_close_rsock(conn);
        }

        free(conn->cldata);
        free(conn);
}

/*
 * iomux callback for the client-side socket (conn->ciom).
 *
 * Behaviour depends on conn->state:
 *  - L7VS_CONN_S_CL_CONNECTED: buffer client data and ask the listening
 *    socket to select a service and destination. Then start connecting to
 *    the chosen real server. If connect() completes immediately, the
 *    connection is finished here via l7vs_conn_rs_connected().
 *  - L7VS_CONN_S_RS_CONNECTING: keep buffering client data until the
 *    realserver connection completes.
 *  - L7VS_CONN_S_RS_CONNECTED: relay client data to the real server.
 *
 * Returns an L7VS_IOMUX_LIST_* code for the iomux loop.
 * L7VS_IOMUX_LIST_REMOVED_OTHER is returned after the whole connection has
 * been destroyed.
 */
static int
l7vs_conn_cl_callback(struct l7vs_iomux *iom, int flags)
{
        struct l7vs_conn *conn;
        struct l7vs_service *srv;
        struct l7vs_lsock *lsock;
        struct l7vs_dest *dest;
        int tcps;
        int ret;
        int fcntl_flags;

        conn = (struct l7vs_conn *)iom->data;
        lsock = conn->lsock;
        if (flags & L7VS_IOMUX_EXCEPT) {
                VANESSA_LOGGER_DEBUG("out-of-band data received. disconnecting");
                goto destroy;
        }

        switch (conn->state) {
        case L7VS_CONN_S_CL_CONNECTED:
                /* Messages arrived from the client. */
                if ((flags & L7VS_IOMUX_READ) == 0) {
                        return L7VS_IOMUX_LIST_UNCHANGED;
                }

                /* On success ret is the total number of bytes buffered. */
                ret = l7vs_conn_recv_client(conn);
                if (ret <= 0) {
                        goto destroy;
                }

                ret = l7vs_lsock_select_service(lsock, conn, conn->cldata,
                                                ret, &srv, &dest, &tcps);
                if (ret < 0) {
                        VANESSA_LOGGER_DEBUG("no matching service found");
                        goto destroy;
                } else if (ret == 0) {
                        VANESSA_LOGGER_DEBUG("continue receiving");
                        break;
                }

                conn->splice = 0;
                ret = l7vs_conn_connect_rs(conn, dest);
                if (ret < 0) {
                        goto destroy;
                }

                /* The connection moves from the pending list to the service. */
                l7vs_conn_pending = g_list_remove(l7vs_conn_pending, conn);
                conn->srv = srv;
                l7vs_service_register_conn(srv, conn);

                if (conn->state == L7VS_CONN_S_RS_CONNECTED) {
                        /*
                         * connect() completed synchronously. Put the
                         * realserver socket into blocking mode, as the
                         * asynchronous path in l7vs_conn_rs_callback() does,
                         * so sends on it block instead of failing with EAGAIN.
                         */
                        fcntl_flags = fcntl(conn->riom.fd, F_GETFL, 0);
                        if (fcntl_flags < 0) {
                                VANESSA_LOGGER_ERR_UNSAFE("fcntl(F_GETFL): %s",
                                                          strerror(errno));
                                goto destroy;
                        }
                        ret = fcntl(conn->riom.fd, F_SETFL,
                                    fcntl_flags & ~O_NONBLOCK);
                        if (ret < 0) {
                                VANESSA_LOGGER_ERR_UNSAFE("fcntl(F_SETFL): %s",
                                                          strerror(errno));
                                goto destroy;
                        }

                        ret = l7vs_conn_rs_connected(conn);
                        if (ret < 0) {
                                goto destroy;
                        }
                }
                break;

        case L7VS_CONN_S_RS_CONNECTING:
                /*
                 * Keep buffering client data. l7vs_conn_rs_connected()
                 * sends conn->cldata once the realserver connect completes.
                 */
                ret = l7vs_conn_recv_client(conn);
                if (ret <= 0) {
                        goto destroy;
                }
                break;

        case L7VS_CONN_S_RS_CONNECTED:
                ret = l7vs_conn_relay_cl(conn);
                if (ret <= 0) {
                        int splice = conn->splice;

                        /* Client connection closed. */
                        l7vs_conn_close_csock(conn);
                        if (!splice) {
                                l7vs_conn_close_rsock(conn);
                        }

                        if (l7vs_conn_closed(conn)) {
                                l7vs_conn_destroy(conn);
                        }

                        if (splice) {
                                return L7VS_IOMUX_LIST_REMOVED_MYSELF;
                        } else {
                                return L7VS_IOMUX_LIST_REMOVED_OTHER;
                        }
                }
                break;

        default:
                break;
        }

        return L7VS_IOMUX_LIST_UNCHANGED;

destroy:
        /* Fatal error for this connection: release everything. */
        l7vs_conn_destroy(conn);
        return L7VS_IOMUX_LIST_REMOVED_OTHER;
}

/*
 * iomux callback for the realserver-side socket (conn->riom).
 *
 *  - L7VS_CONN_S_RS_CONNECTING: the non-blocking connect() has finished.
 *    Its outcome is read with getsockopt(SO_ERROR). On success the socket is
 *    switched to blocking mode and watched for reading, and
 *    l7vs_conn_rs_connected() forwards the buffered client data.
 *  - L7VS_CONN_S_RS_CONNECTED: relay realserver data to the client.
 *
 * Failures while connecting destroy the whole connection and return
 * L7VS_IOMUX_LIST_REMOVED_OTHER. Otherwise an L7VS_IOMUX_LIST_* code is
 * returned for the iomux loop.
 */
static int
l7vs_conn_rs_callback(struct l7vs_iomux *iom, int flags)
{
        struct l7vs_conn *conn;
        socklen_t len;
        int opt;
        int ret;
        int fcntl_flags;

        conn = (struct l7vs_conn *)iom->data;

        if (flags & L7VS_IOMUX_EXCEPT) {
                goto destroy;
        }

        assert(conn->state != L7VS_CONN_S_CL_CONNECTED);
        switch (conn->state) {
        case L7VS_CONN_S_RS_CONNECTING:
                /* The connect() attempt to the real server has finished. */
                if (conn->srv == NULL) {
                        VANESSA_LOGGER_INFO("Service destroyed during"
                                            " connecting to a real server");
                        goto destroy;
                }

                len = sizeof(opt);
                ret = getsockopt(conn->riom.fd, SOL_SOCKET, SO_ERROR,
                                 &opt, &len);
                if (ret < 0) {
                        VANESSA_LOGGER_ERR_UNSAFE("getsockopt(SO_ERROR): %s",
                                                  strerror(errno));
                        goto destroy;
                }

                if (opt != 0) {
                        VANESSA_LOGGER_ERR_UNSAFE("Connect to RS: %s",
                                                  strerror(opt));
                        goto destroy;
                }

                /*
                 * Switch to blocking mode BEFORE l7vs_conn_rs_connected()
                 * sends the buffered client data. On a non-blocking socket
                 * that send() could be short or fail with EAGAIN, losing
                 * client data.
                 */
                fcntl_flags = fcntl(conn->riom.fd, F_GETFL, 0);
                if (fcntl_flags < 0) {
                        VANESSA_LOGGER_ERR_UNSAFE("fcntl(F_GETFL): %s",
                                                  strerror(errno));
                        goto destroy;
                }
                ret = fcntl(conn->riom.fd, F_SETFL, fcntl_flags & ~O_NONBLOCK);
                if (ret < 0) {
                        VANESSA_LOGGER_ERR_UNSAFE("fcntl(F_SETFL): %s",
                                                  strerror(errno));
                        goto destroy;
                }

                /*
                 * Connection established to the real server.
                 * Now we can splice the connection.
                 */
                l7vs_iomux_change_flags(&conn->riom, L7VS_IOMUX_READ);
                ret = l7vs_conn_rs_connected(conn);
                if (ret < 0) {
                        goto destroy;
                }
                break;

        case L7VS_CONN_S_RS_CONNECTED:
                ret = l7vs_conn_relay_rs(conn);
                if (ret <= 0) {
                        int splice = conn->splice;

                        /* Connection closed. */
                        l7vs_conn_close_rsock(conn);
                        if (!splice) {
                                l7vs_conn_close_csock(conn);
                        }

                        if (l7vs_conn_closed(conn)) {
                                l7vs_conn_destroy(conn);
                        }

                        if (splice) {
                                return L7VS_IOMUX_LIST_REMOVED_MYSELF;
                        } else {
                                return L7VS_IOMUX_LIST_REMOVED_OTHER;
                        }
                }
                break;

        default:
                break;
        }

        return L7VS_IOMUX_LIST_UNCHANGED;

destroy:
        /* Fatal error for this connection: release everything. */
        l7vs_conn_destroy(conn);
        return L7VS_IOMUX_LIST_REMOVED_OTHER;
}

/*
 * Finish setting up a connection whose realserver connect() has completed.
 *
 * The steps are, in order:
 *  - If conn->splice is set, splice the two sockets.
 *  - Call l7vs_service_establish().
 *  - Mark the connection L7VS_CONN_S_RS_CONNECTED, counting it in
 *    conn->dest->nactive.
 *  - Forward all client data buffered so far (conn->cldata) to the real
 *    server.
 *
 * nactive is incremented only on the transition into
 * L7VS_CONN_S_RS_CONNECTED. l7vs_conn_connect_rs() already counts a
 * connect() that succeeded synchronously, and l7vs_conn_close_rsock()
 * decrements once per connected socket. Counting twice would leak
 * nactive.
 *
 * Returns the number of bytes sent (>= 0), or a negative value on error.
 */
static int
l7vs_conn_rs_connected(struct l7vs_conn *conn)
{
        size_t sent;
        ssize_t n;
        int ret;

        if (conn->splice) {
                ret = l7vs_conn_splice(conn);
                if (ret < 0) {
                        return ret;
                }
        }

        ret = l7vs_service_establish(conn->srv, conn);
        if (ret < 0) {
                return ret;
        }

        if (conn->state != L7VS_CONN_S_RS_CONNECTED) {
                conn->state = L7VS_CONN_S_RS_CONNECTED;
                conn->dest->nactive++;
        }

        /* send() may transmit fewer bytes than requested; loop until done. */
        sent = 0;
        do {
                n = send(conn->riom.fd, conn->cldata + sent,
                         (size_t)conn->cldata_len - sent, 0);
                if (n < 0) {
                        if (errno == EINTR) {
                                continue;
                        }
                        VANESSA_LOGGER_ERR_UNSAFE("send to realserver: %s",
                                                  strerror(errno));
                        return -1;
                }
                sent += (size_t)n;
        } while (sent < (size_t)conn->cldata_len);

        return (int)sent;
}

/*
 * Append all data currently readable on the client socket to conn->cldata.
 *
 * The caller has already select()ed conn->ciom.fd for reading, so at least
 * one recv() is possible. After each recv() a zero-timeout select() checks
 * whether more data is immediately available, and the loop stops as soon
 * as it is not. The buffer grows in L7VS_CLDATA_CHUNKSIZE steps and is kept
 * NUL-terminated; one byte is always reserved for the terminator.
 *
 * Returns the total number of buffered bytes (conn->cldata_len), 0 if the
 * client disconnected, or a negative value on error.
 */
int
l7vs_conn_recv_client(struct l7vs_conn *conn)
{
        int ret;
        char *newbuf;
        my_fd_set rfds;
        struct timeval tv;

        for (;;) {
                if (conn->cldata_bufsize - conn->cldata_len <= 1) {
                        /*
                         * Commit the new size only after realloc() succeeds,
                         * so cldata_bufsize never exceeds the real
                         * allocation. On failure the old buffer is still valid.
                         */
                        newbuf = realloc(conn->cldata,
                                         conn->cldata_bufsize +
                                         L7VS_CLDATA_CHUNKSIZE);
                        if (newbuf == NULL) {
                                VANESSA_LOGGER_ERR("realloc failed");
                                return -1;
                        }
                        conn->cldata = newbuf;
                        conn->cldata_bufsize += L7VS_CLDATA_CHUNKSIZE;
                }

                ret = recv(conn->ciom.fd, conn->cldata + conn->cldata_len,
                           conn->cldata_bufsize - conn->cldata_len - 1, 0);
                if (ret == 0) {
                        VANESSA_LOGGER_DEBUG("client disconnected");
                        return ret;
                } else if (ret < 0) {
                        VANESSA_LOGGER_ERR_UNSAFE("recv from client: %s\n", 
                                                  strerror(errno));
                        return ret;
                }
                conn->cldata_len += ret;
                conn->cldata[conn->cldata_len] = '\0';
                VANESSA_LOGGER_DEBUG_UNSAFE("received %s", conn->cldata);

                /* Poll (zero timeout): is more data ready right now? */
                FD_ZERO(&rfds);
                tv.tv_sec = 0;
                tv.tv_usec = 0;
                FD_SET(conn->ciom.fd, (fd_set *)&rfds);
                ret = select(conn->ciom.fd + 1, (fd_set *)&rfds, NULL, NULL, &tv);
                if (ret == 0) {
                        break;
                } else if (ret < 0) {
                        VANESSA_LOGGER_ERR_UNSAFE("select on cliend fd: %s\n",
                                                  strerror(errno));
                        return ret;
                }
                /* we can recv() more if ret > 0 */
        }
        return conn->cldata_len;
}

#define L7VS_CONN_RELAY_BUFSIZE         2048

/*
 * Relay one chunk of data between the client and the real server.
 *
 * If is_client != 0, read from the client socket and write to the
 * realserver socket. If is_client == 0, read from the realserver socket and
 * write to the client. When the connection belongs to a service, the
 * protocol module's analyze_rsdata() may first inspect or rewrite the data.
 * The stack buffer reserves L7VS_PROTOMOD_MAX_ADD_BUFSIZE extra bytes so the
 * module can grow the data by at most that amount.
 *
 * Returns the number of bytes written (> 0). Returns 0 when the peer closed
 * the connection or nothing is left to send after analysis, and a negative
 * value on error.
 */
static int
l7vs_conn_relay(struct l7vs_conn *conn, int is_client)
{
        char buf[L7VS_CONN_RELAY_BUFSIZE + L7VS_PROTOMOD_MAX_ADD_BUFSIZE];
        const char *rcvr, *sndr;
        int rfd, sfd;
        int ret, len;
        size_t len_ret;
        size_t off;
        int (*relayf)(struct l7vs_service *, struct l7vs_conn *,
                      char *, size_t *);

        /* Zero the whole buffer, including the protocol-module headroom. */
        memset(buf, 0, sizeof(buf));

        relayf = NULL;
        if (is_client) {
                rcvr = "client";
                sndr = "realserver";
                rfd = conn->ciom.fd;
                sfd = conn->riom.fd;
        } else {
                rcvr = "realserver";
                sndr = "client";
                rfd = conn->riom.fd;
                sfd = conn->ciom.fd;
                if (conn->srv != NULL) {
                        relayf = conn->srv->pm->analyze_rsdata;
                }
        }

        /* Read at most RELAY_BUFSIZE so the headroom stays free. */
        len = recv(rfd, buf, L7VS_CONN_RELAY_BUFSIZE, 0);
        if (len < 0) {
                return len;
        } else if (len == 0) {
                VANESSA_LOGGER_DEBUG_UNSAFE("%s connection closed\n", rcvr);
                return len;
        }

        if (relayf != NULL) {
                /*
                 * len_ret must be a size_t to match relayf's prototype.
                 * Passing an int's address would let the callee write past
                 * it on LP64 targets.
                 */
                len_ret = (size_t)len;
                ret = (*relayf)(conn->srv, conn, buf, &len_ret);
                if (ret != 0) {
                        VANESSA_LOGGER_ERR("failed to analyze realserver data");
                }
                /* The module may grow the data by at most the headroom. */
                if (len_ret > (size_t)len + L7VS_PROTOMOD_MAX_ADD_BUFSIZE) {
                        VANESSA_LOGGER_ERR("bufsize too long modified by protomod ");
                        return -1;
                }
                len = (int)len_ret;
        }

        /* send() may accept fewer bytes than requested; loop until done. */
        off = 0;
        while (off < (size_t)len) {
                ret = send(sfd, buf + off, (size_t)len - off, 0);
                if (ret < 0) {
                        if (errno == EINTR) {
                                continue;
                        }
                        VANESSA_LOGGER_ERR_UNSAFE("send on %s fd failed: %s\n",
                                                  sndr, strerror(errno));
                        return ret;
                }
                off += (size_t)ret;
        }

        return len;
}

/* Relay one chunk from the client socket to the realserver socket. */
static int
l7vs_conn_relay_cl(struct l7vs_conn *conn)
{
        const int from_client = 1;

        return l7vs_conn_relay(conn, from_client);
}

/* Relay one chunk from the realserver socket to the client socket. */
static int
l7vs_conn_relay_rs(struct l7vs_conn *conn)
{
        const int from_client = 0;

        return l7vs_conn_relay(conn, from_client);
}

/*
 * Set the TCPS_SO_SPLICE socket option (level IPPROTO_IP) on the client
 * socket, passing the realserver socket descriptor as the option value.
 * Presumably this asks a splice-capable kernel to forward between the two
 * sockets; TODO confirm against the TCPS_SO_SPLICE provider.
 *
 * Returns setsockopt()'s result: 0 on success, -1 on error.
 */
int
l7vs_conn_splice(struct l7vs_conn *conn)
{
        int level = IPPROTO_IP;

        return setsockopt(conn->ciom.fd, level, TCPS_SO_SPLICE,
                          &conn->riom.fd, sizeof conn->riom.fd);
}

/*
 * Open a non-blocking socket to the real server dest and start connecting.
 *
 * The socket type follows conn->proto (TCP -> SOCK_STREAM,
 * UDP -> SOCK_DGRAM). For TCP the client's MSS (conn->cmss) is copied to
 * the new socket. If connect() completes immediately, as it always does
 * for UDP, the connection becomes L7VS_CONN_S_RS_CONNECTED, is counted in
 * dest->nactive, and is watched for reading. Otherwise it becomes
 * L7VS_CONN_S_RS_CONNECTING and is watched for writability, so that
 * l7vs_conn_rs_callback() sees the connect() completion.
 *
 * Returns 0 on success, with conn->dest, conn->riom and the iomux
 * registration set up. Returns a negative value on failure; the new socket
 * is then closed and nothing is registered.
 */
int
l7vs_conn_connect_rs(struct l7vs_conn *conn, struct l7vs_dest *dest)
{
        int s;
        int stype;
        int flags;
        int ret;
        int iomux_flags;

        switch (conn->proto) {
        case IPPROTO_TCP:
                stype = SOCK_STREAM;
                break;
        case IPPROTO_UDP:
                stype = SOCK_DGRAM;
                break;
        default:
                VANESSA_LOGGER_ERR("Unknwon socket type");
                return -1;
        }

        s = socket(PF_INET, stype, conn->proto);
        if (s < 0) {
                VANESSA_LOGGER_ERR_UNSAFE("socket: %s", strerror(errno));
                return s;
        }

        flags = fcntl(s, F_GETFL, 0);
        if (flags < 0) {
                VANESSA_LOGGER_ERR_UNSAFE("fcntl(F_GETFL): %s",
                                          strerror(errno));
                close(s);
                return -1;
        }

        flags = fcntl(s, F_SETFL, flags | O_NONBLOCK);
        if (flags < 0) {
                VANESSA_LOGGER_ERR_UNSAFE("fcntl(F_SETFL): %s",
                                          strerror(errno));
                close(s);
                return flags;
        }

        /* TCP_MAXSEG is a TCP-level option; it would only fail on UDP. */
        if (conn->proto == IPPROTO_TCP && conn->cmss > 0) {
                ret = setsockopt(s, IPPROTO_TCP, TCP_MAXSEG, &conn->cmss,
                                 sizeof(conn->cmss));
                if (ret < 0) {
                        VANESSA_LOGGER_ERR_UNSAFE(
                                "setsockopt(TCP_MAXSEG): %s, continuing",
                                strerror(errno));
                }
        }

        /*
         * Record the destination before connect(). The synchronous-success
         * branch below dereferences conn->dest, which must not still hold a
         * stale or uninitialised pointer at that point.
         */
        conn->dest = dest;

        ret = connect(s, (struct sockaddr *)&dest->addr,
                      sizeof(dest->addr));
        if (ret >= 0) {
                conn->state = L7VS_CONN_S_RS_CONNECTED;
                VANESSA_LOGGER_DEBUG_UNSAFE("dest %p nactive %d->%d (imm.)",
                                            conn->dest, conn->dest->nactive,
                                            conn->dest->nactive + 1);
                iomux_flags = L7VS_IOMUX_READ;
                conn->dest->nactive++;
        } else {
                if (errno != EINPROGRESS) {
                        VANESSA_LOGGER_ERR_UNSAFE("connect: %s",
                                                  strerror(errno));
                        close(s);
                        return ret;
                }
                conn->state = L7VS_CONN_S_RS_CONNECTING;
                iomux_flags = L7VS_IOMUX_WRITE;
        }

        conn->riom.fd = s;
        conn->riom.callback = l7vs_conn_rs_callback;
        conn->riom.data = conn;
        l7vs_iomux_add(&conn->riom, iomux_flags);

        return 0;
}

void
l7vs_conn_close_csock(struct l7vs_conn *conn)
{
        l7vs_iomux_remove(&conn->ciom);
        close(conn->ciom.fd);
        conn->ciom.fd = -1;
}

void
l7vs_conn_close_rsock(struct l7vs_conn *conn)
{
        if (conn->state == L7VS_CONN_S_RS_CONNECTED) {
                VANESSA_LOGGER_DEBUG_UNSAFE("dest %p nactive %d->%d",
                                            conn->dest, conn->dest->nactive,
                                            conn->dest->nactive - 1);
                conn->dest->nactive--;
                conn->dest->ninact++;
                if(conn->dest->ninact == INT_MAX) {
                        conn->dest->ninact = 0;
                }
        }
        l7vs_iomux_remove(&conn->riom);
        close(conn->riom.fd);
        conn->riom.fd = -1;
}

/*
 * Report whether both sockets of the connection have been closed.
 * The close helpers set the fd fields to -1. Returns 1 when both are
 * negative, 0 otherwise.
 */
int
l7vs_conn_closed(struct l7vs_conn *conn)
{
        if (conn->ciom.fd >= 0 || conn->riom.fd >= 0) {
                return 0;
        }
        return 1;
}
