
#include <algorithm>
#include <cstring>
#include <iostream>
#include <vector>

#include "gmock/gmock.h"

#include "cm_thread.h"
#include "mt_auto_ptr.h"
#include "mt_shim_clear_memory.h"
#include "cm_socket_if.h"
#include "cm_socket_server.h"
#include "cm_socket_client.h"
#include "cm_event.h"
#include "mt_range.h"
#include "cm_mutex.h"
#include "cm_condition.h"

namespace {

/// Handshake state shared between a test body and the server thread it starts.
///
/// The server thread sets `is_init_finished` while holding `mutex` and then
/// broadcasts `cond`; the test body waits on `cond` (with `mutex` held) until
/// the flag is set.
struct ThreadArg
{
    ThreadArg()
        : mutex(),
          is_init_finished(false)
    {
    }

    cm::Mutex mutex;        // protects is_init_finished
    cm::Condition cond;     // broadcast once is_init_finished becomes true
    bool is_init_finished;  // set by the server thread after registering its accept handler
};

struct FindSocket
{
    FindSocket(cm::SocketIf& socket_if)
        : socket_if_(socket_if)
    {}

    template <typename InputIterator>
    InputIterator operator()(InputIterator begin, const InputIterator& end)
    {
        for (; begin != end; ++begin) {
            if (*begin == &socket_if_) {
                return begin;
            }
        }
        return end;
    }

    const cm::SocketIf& socket_if_;
};

// Expected payloads, indexed by a static read counter in
// SocketIfMock::doHandleRead. That counter is shared by every mock in the
// process, so the entries are laid out for both tests in this file, each of
// whose clients sends "hello world!" and then "good bye world!".
// "invalid" is not sent by any client in this file; presumably a guard
// entry (NOTE(review): confirm intent).
// Both the array and its elements are const so the table cannot be altered.
static const char* const read_data_table[] = {
    "hello world!",
    "good bye world!",
    "hello world!",
    "good bye world!",
    "invalid"
};

class SocketIfMock : public cm::SocketIf
{
public:
    SocketIfMock(cm::SocketIf* socket_if)
        : socket_if_(socket_if)
    {
        EXPECT_CALL(*this, read(::testing::_, ::testing::_, ::testing::_))
            .WillOnce(::testing::Invoke(this, &SocketIfMock::doHandleRead))   // data read
            .WillOnce(::testing::Invoke(this, &SocketIfMock::doHandleDisconnection));
    }

    bool doHandleRead(size_t& bytes_read, void* buf, size_t size_to_read)
    {
        bool ret = socket_if_->read(bytes_read, buf, size_to_read);
        EXPECT_EQ(ret, true);
        EXPECT_STREQ(reinterpret_cast<char*>(buf), read_data_table[mock_creation_count_++]);
        return ret;
    }

    bool doHandleDisconnection(size_t& bytes_read, void* buf, size_t size_to_read)
    {
        bool ret = socket_if_->read(bytes_read, buf, size_to_read);
        EXPECT_EQ(ret, false);
        return ret;
    }

    MOCK_METHOD3(read, bool(size_t& bytes_read, void* buf, size_t size_to_read));
    MOCK_METHOD3(write, bool(size_t& bytes_read, const void* buf, size_t size_to_write));
    MOCK_METHOD0(release, int());

    virtual int getFD() const
    {
        return socket_if_->getFD();
    }

    ~SocketIfMock()
    {
        delete socket_if_;
    }

private:
    SocketIf* socket_if_;
    static unsigned int mock_creation_count_;

    MOCK_CONST_METHOD1(doClone, SocketIf*(int fd));
};

unsigned int SocketIfMock::mock_creation_count_ = 0u;

template <unsigned short PortNumber = 8888>
class CmSocketServerThread
{
public:
    CmSocketServerThread(ThreadArg* arg)
        : server_(cm::SOCKET_TYPE_INET_STREAM, "0.0.0.0", PortNumber),
          event_(cm::Event::tsdInstance()), mutex_(arg->mutex),
          accepted_sockets_(), write_socket_()
    {
        cm::Mutex::Lock lock(mutex_);
        event_.addHandlerRead(*this, &CmSocketServerThread<PortNumber>::accept, server_);
        arg->is_init_finished = true;
        arg->cond.broadcast();
    }

    bool accept(cm::SocketServer& server)
    {
        cm::Mutex::Lock lock(mutex_);

        mt::AutoPtr<cm::SocketIf> sock = server.accept();

        std::cout << "[" << __func__ << "] : entry FD[" << getFD(*sock.get()) << "]" << std::endl;
        cm::SocketIf* socket_ptr = sock.get();
        SocketIfMock* socket_if_mock_ptr = new SocketIfMock(socket_ptr);
        event_.addHandlerRead(*this, &CmSocketServerThread<PortNumber>::readSocket, *socket_if_mock_ptr);
        accepted_sockets_.push_back(socket_if_mock_ptr);
        sock.release();

        std::cout << "[" << __func__ << "] : exit" << std::endl;
        return false;
    }

    ~CmSocketServerThread()
    {
        EXPECT_EQ(accepted_sockets_.size(), 0u);
    }

    bool readSocket(cm::SocketIf& socket)
    {
        char buffer[20];
        size_t bytes_read = 0u;
        if (socket.read(bytes_read, buffer, sizeof(buffer))) {
            std::cout << "received " << buffer << std::endl;
        }
        else {
            event_.delHandlerRead(socket);
            FindSocket socket_finder(socket);
            std::vector<cm::SocketIf*>::iterator it = socket_finder(mt::begin(accepted_sockets_),
                                                                    mt::end(accepted_sockets_));
            std::cout << "closing accepted socket[" << getFD(socket) << "]" << std::endl;
            delete *it;
            accepted_sockets_.erase(it);
        }
        return false;
    }

    void run()
    {
        bool is_started = false;
        while (!is_started || accepted_sockets_.size() > 0) {
            event_.pend();
            is_started = true;
        }
    }

private:
    cm::SocketServer server_;
    cm::Event& event_;
    cm::Mutex& mutex_;
    std::vector<cm::SocketIf*> accepted_sockets_;
    mt::AutoPtr<cm::SocketIf> write_socket_;
};

/// Client side of the cm_event_addHandlerWrite test.
///
/// On construction connects to 127.0.0.1:`port` and registers
/// handleSocketConnected as the write handler of the new socket on this
/// thread's cm::Event.
class CmClientSocket
{
public:
    /// @param mutex  held while connecting and registering the handler.
    /// @param port   server port; defaults to the port used by the tests.
    explicit CmClientSocket(cm::Mutex& mutex, unsigned short port = 8889u)
        : client_(cm::SOCKET_TYPE_INET_STREAM, true),  // NOTE(review): meaning of `true` not visible here — confirm in cm_socket_client.h
          event_(cm::Event::tsdInstance()),
          write_socket_(), address_(), port_(0u)
    {
        cm::Mutex::Lock lock(mutex);
        // connect() already stores the socket in write_socket_ and returns a
        // reference to it; assigning that back to write_socket_ would be an
        // AutoPtr self-assignment, so the return value is not used here.
        this->connect("127.0.0.1", port);
        event_.addHandlerWrite(*this, &CmClientSocket::handleSocketConnected, *write_socket_.get());
    }

    /// Connects to address:port, stores the resulting socket in the member
    /// AutoPtr and returns a reference to that member.
    /// `address` is stored as a raw pointer (not copied), so the caller must
    /// keep the string alive for as long as address_ may be read.
    mt::AutoPtr<cm::SocketIf>& connect(const char* address, unsigned short port)
    {
        address_ = address;
        port_ = port;
        write_socket_ = client_.connect(address, port);
        return write_socket_;
    }

    /// Write handler: unregisters itself so it runs only once.
    bool handleSocketConnected(cm::SocketIf& connected_sock)
    {
        std::cout << "[" << __func__ << "] : entry FD[" << getFD(connected_sock) << "]" << std::endl;
        event_.delHandlerWrite(connected_sock);
        std::cout << "[" << __func__ << "] : finish deleting from Event FD[" << getFD(connected_sock) << "]" << std::endl;
        return false;
    }

    mt::AutoPtr<cm::SocketIf>& getWriteSocket()
    {
        return write_socket_;
    }

    const mt::AutoPtr<cm::SocketIf>& getWriteSocket() const
    {
        return write_socket_;
    }

private:
    cm::SocketClient client_;
    cm::Event& event_;
    mt::AutoPtr<cm::SocketIf> write_socket_;
    const char* address_;  // last address passed to connect(); not owned
    unsigned short port_;  // last port passed to connect()
};

TEST(CmSocketServerTest, inet_anyaddr)
{
    ThreadArg arg;
    cm::Thread<CmSocketServerThread<8888u>, ThreadArg> server_thread("server_thread");
    server_thread.create(&arg);

    {
        cm::Mutex::Lock lock(arg.mutex);
        // Wait until the server thread has registered its accept handler.
        // The original `if (arg.is_init_finished)` was inverted: it blocked
        // forever when the broadcast had already happened and skipped the
        // wait when it had not. A loop also tolerates spurious wakeups.
        while (!arg.is_init_finished) {
            arg.cond.wait(arg.mutex);
        }
        cm::SocketClient client(cm::SOCKET_TYPE_INET_STREAM);

        mt::AutoPtr<cm::SocketIf> sock = client.connect("127.0.0.1", 8888);
        EXPECT_NE(sock.get(), static_cast<cm::SocketIf*>(0));

        char buffer[20];
        mt::clearMemory(buffer);
        strcpy(buffer, "hello world!");
        size_t bytes_written = 0u;
        if (sock.get() != 0) {
            sock->write(bytes_written, buffer, sizeof(buffer));
        }
        sock.reset();

        mt::clearMemory(buffer);
        strcpy(buffer, "good bye world!");

        sock = client.connect("127.0.0.1", 8888);
        EXPECT_NE(sock.get(), static_cast<cm::SocketIf*>(0));

        if (sock.get() != 0) {
            sock->write(bytes_written, buffer, sizeof(buffer));
        }
        sock.reset();
    }

    server_thread.join();
}

TEST(CmSocketServerTest, cm_event_addHandlerWrite)
{
    ThreadArg arg;
    cm::Thread<CmSocketServerThread<8889u>, ThreadArg> server_thread("server_thread");
    server_thread.create(&arg);

    {
        cm::Mutex::Lock lock(arg.mutex);
        // Re-check the predicate in a loop so a spurious wakeup cannot let
        // the client connect before the server thread is ready.
        while (!arg.is_init_finished) {
            arg.cond.wait(arg.mutex);
        }

        CmClientSocket client_socket(arg.mutex);
        cm::Event::tsdInstance().pend();

        char buffer[20];
        mt::clearMemory(buffer);
        strcpy(buffer, "hello world!");

        size_t bytes_written = 0u;
        mt::AutoPtr<cm::SocketIf>& sock = client_socket.getWriteSocket();
        sock->write(bytes_written, buffer, sizeof(buffer));

        CmClientSocket client_socket2(arg.mutex);
        cm::Event::tsdInstance().pend();

        mt::clearMemory(buffer);
        strcpy(buffer, "good bye world!");

        mt::AutoPtr<cm::SocketIf>& sock2 = client_socket2.getWriteSocket();
        sock2->write(bytes_written, buffer, sizeof(buffer));
        sock2.reset();
    }

    server_thread.join();
}

} // namespace
