
#include <rcsc/formation/formation_factory.h>
#include <rcsc/formation/formation_dt.h>
#include <rcsc/geom/delaunay_triangulation.h>
#include <rcsc/geom/triangle_2d.h>
#include <rcsc/geom/vector_2d.h>

#include <algorithm>
#include <cstdio>
#include <ctime>
#include <exception>
#include <fstream>
#include <iostream>
#include <list>
#include <map>
#include <string>
#include <vector>

using namespace rcsc;

/*!
  \class LearnerDT
  \brief rebuilds a sample formation as a Delaunay-triangulation formation
  (rcsc::FormationDT) by adaptively choosing training snapshots.

  Training starts from a fixed set of ball positions covering a rectangle
  slightly larger than the pitch. It then repeatedly inserts a new ball
  position inside the largest triangle of the current triangulation, but
  only where the learned formation's prediction differs too much from the
  sample formation. The learned formation and its training data are written
  to disk by save(), which the destructor calls.
*/
class LearnerDT {
private:

    typedef DelaunayTriangulation::Triangle Triangle;
    typedef Triangle* TrianglePtr;

    static const double PITCH_LENGTH; //!< pitch length
    static const double PITCH_WIDTH;  //!< pitch width

    //! size of the rectangle in which ball positions are sampled
    static const double BOUNDING_RECT_LENGTH;
    static const double BOUNDING_RECT_WIDTH;

    //! a new sample is rejected if its ball is closer than this to an existing one
    static const double MIN_SAMPLE_DIST;

    //! error tolerance: MIN_ERROR + ERROR_RATE * (ball-to-player distance)
    static const double MIN_ERROR;
    static const double ERROR_RATE;

    //! training snapshots, kept in insertion order
    std::list< rcsc::Formation::Snapshot > M_train_data;

    //! the formation being learned
    rcsc::FormationDT M_formation;

public:

    LearnerDT();
    ~LearnerDT();


    void train( const rcsc::FormationPtr sample );

private:

    // Not copyable: the destructor calls save(), so every copy would write
    // the output files again when it is destroyed. Declared, never defined.
    LearnerDT( const LearnerDT & );
    LearnerDT & operator=( const LearnerDT & );

    void save();

    void createBoundingData( const rcsc::FormationPtr sample );

    bool addSamplePointInLargestTriangle( const rcsc::FormationPtr sample );

    bool existTooNearSample( const rcsc::Vector2D & ball_pos );

    bool isBigError( const rcsc::Formation::Snapshot & snapshot );

};


/*-------------------------------------------------------------------*/
/*!
  \brief ordering functor that ranks triangles by area, smallest first.

  After std::sort with this comparator, the largest triangle is at the back.
*/
struct TriangleAreaCmp {

    bool operator()( const Triangle2D & lhs,
                     const Triangle2D & rhs ) const
      {
          const double lhs_area = lhs.area();
          const double rhs_area = rhs.area();
          return rhs_area > lhs_area;
      }
};


/*-------------------------------------------------------------------*/
// pitch size
const double LearnerDT::PITCH_LENGTH = 105.0;
const double LearnerDT::PITCH_WIDTH = 68.0;

// the sampling rectangle extends 2.01 beyond the pitch on every side
const double LearnerDT::BOUNDING_RECT_LENGTH = PITCH_LENGTH + 2.0 * 2.01;
const double LearnerDT::BOUNDING_RECT_WIDTH = PITCH_WIDTH + 2.0 * 2.01;

// candidate ball positions closer than this to a stored sample are skipped
const double LearnerDT::MIN_SAMPLE_DIST = 2.0;

// a prediction is "bad" when its error exceeds
// MIN_ERROR + ERROR_RATE * (distance from the ball to the prediction)
const double LearnerDT::MIN_ERROR = 0.5;
const double LearnerDT::ERROR_RATE = 0.09;

/*-------------------------------------------------------------------*/
/*!
  \brief create a learner with no training data yet.
*/
LearnerDT::LearnerDT()
    : M_train_data()
{
    // M_formation is default-constructed; nothing else to prepare
}

/*-------------------------------------------------------------------*/
/*!
  \brief write the learned formation and its training data, via save().

  An exception must never escape a destructor: it calls std::terminate if
  it happens during stack unwinding, and destructors are implicitly
  noexcept since C++11. Any failure in save() is therefore reported on
  std::cerr and swallowed.
*/
LearnerDT::~LearnerDT()
{
    try
    {
        save();
    }
    catch ( const std::exception & e )
    {
        std::cerr << "LearnerDT::~LearnerDT() failed to save: "
                  << e.what() << std::endl;
    }
    catch ( ... )
    {
        std::cerr << "LearnerDT::~LearnerDT() failed to save: unknown error"
                  << std::endl;
    }
}

/*-------------------------------------------------------------------*/
/*!
  \brief write the learned formation and its training snapshots to disk.

  Two files are created in the working directory, both named
  "output-<methodName()><time suffix>":
   - ".conf": the formation as written by M_formation.print().
   - ".dat" : one line per training snapshot, "ball.x ball.y " followed by
              "x y " for every player position stored in the snapshot.

  The time suffix is "-YYYYmmdd-HHMMSS" in local time. If the local time
  cannot be obtained or formatted, it is "-time-<seconds>" instead.
  A file that cannot be opened is reported on std::cerr and skipped.
*/
void
LearnerDT::save()
{
    char time_str[64];
    const std::time_t current = std::time( NULL );
    const std::tm * local_time = std::localtime( &current );

    // strftime() returns 0 on failure and then leaves the buffer contents
    // indeterminate, so fall back to the numeric timestamp in that case too.
    if ( local_time == NULL
         || std::strftime( time_str, sizeof( time_str ),
                           "-%Y%m%d-%H%M%S", local_time ) == 0 )
    {
        // time_t is not guaranteed to be a long; cast so it matches "%ld".
        std::sprintf( time_str, "-time-%ld", static_cast< long >( current ) );
    }

    std::string filename = "output-";
    filename += M_formation.methodName();
    filename += time_str;

    // save formation conf file
    {
        const std::string conf_file = filename + ".conf";
        std::ofstream fout( conf_file.c_str() );
        if ( ! fout )
        {
            std::cerr << "LearnerDT::save() failed to open "
                      << conf_file << std::endl;
        }
        else
        {
            M_formation.print( fout );
        }
    } // fout closed by its destructor

    // save formation data file
    {
        const std::string data_file = filename + ".dat";
        std::ofstream fout( data_file.c_str() );
        if ( ! fout )
        {
            std::cerr << "LearnerDT::save() failed to open "
                      << data_file << std::endl;
            return;
        }

        for ( std::list< Formation::Snapshot >::const_iterator it = M_train_data.begin();
              it != M_train_data.end();
              ++it )
        {
            fout << it->ball_.x << ' ' << it->ball_.y << ' ';
            for ( std::vector< rcsc::Vector2D >::const_iterator p = it->players_.begin();
                  p != it->players_.end();
                  ++p )
            {
                fout << p->x << ' ' << p->y << ' ';
            }
            fout << '\n';
        }
        fout << std::flush;
    } // fout closed by its destructor
}

/*-------------------------------------------------------------------*/
/*!
  \brief learn M_formation from the given sample formation.

  Any previously collected training data is discarded. The sample's role
  assignment is copied, the fixed seed snapshots are created, and then new
  snapshots are added one at a time. This stops when 500 snapshots exist or
  when addSamplePointInLargestTriangle() can no longer add one.
  \param sample the formation to reproduce.
*/
void
LearnerDT::train( const rcsc::FormationPtr sample )
{
    // start from scratch
    M_train_data.clear();

    // mirror the sample's role assignment for players 1..11
    // (getSynmetryNumber is the library's own spelling)
    int unum = 1;
    while ( unum <= 11 )
    {
        M_formation.updateRole( unum,
                                sample->getSynmetryNumber( unum ),
                                sample->getRoleName( unum ) );
        ++unum;
    }

    createBoundingData( sample );

    // refine until the size limit is reached or no snapshot could be added
    bool added = true;
    while ( added && M_train_data.size() < 500 )
    {
        added = addSamplePointInLargestTriangle( sample );
    }

    std::cerr << "finished. data size = " << M_train_data.size()
              << std::endl;
}

/*-------------------------------------------------------------------*/
/*!
  \brief create the seed training snapshots and train M_formation on them.

  For each seed ball position, the 11 player positions are taken from
  \a sample and the snapshot is appended to M_train_data. M_formation is
  then trained on all of M_train_data.
  \param sample the formation being reproduced.
*/
void
LearnerDT::createBoundingData( const rcsc::FormationPtr sample )
{
    // Seed ball positions, as fractions of (BOUNDING_RECT_LENGTH,
    // BOUNDING_RECT_WIDTH). The center comes first. Then, row by row from
    // -y to +y, a 5x5 grid over the bounding rectangle, without the two
    // inner points of the y == 0 row. The listed order is the order in
    // which the snapshots are stored.
    //
    // Layouts kept in formerly disabled #elif/#else branches:
    //   center + 4 corners + 4 edge midpoints, or center + 4 corners only.
    static const double RATES[][2] = {
        {  0.0,   0.0  },

        { -0.5,  -0.5  }, { -0.25, -0.5  }, { 0.0, -0.5  }, { 0.25, -0.5  }, { 0.5, -0.5  },
        { -0.5,  -0.25 }, { -0.25, -0.25 }, { 0.0, -0.25 }, { 0.25, -0.25 }, { 0.5, -0.25 },
        { -0.5,   0.0  },                                                      { 0.5,  0.0  },
        { -0.5,   0.25 }, { -0.25,  0.25 }, { 0.0,  0.25 }, { 0.25,  0.25 }, { 0.5,  0.25 },
        { -0.5,   0.5  }, { -0.25,  0.5  }, { 0.0,  0.5  }, { 0.5,   0.5  }, { 0.25,  0.5 },
    };
    const int n_points = static_cast< int >( sizeof( RATES ) / sizeof( RATES[0] ) );

    for ( int i = 0; i < n_points; ++i )
    {
        Vector2D ball( BOUNDING_RECT_LENGTH * RATES[i][0],
                       BOUNDING_RECT_WIDTH * RATES[i][1] );

        Formation::Snapshot snapshot;
        snapshot.ball_ = ball;
        for ( int unum = 1; unum <= 11; ++unum )
        {
            snapshot.players_.push_back( sample->getPosition( unum, ball ) );
        }

        M_train_data.push_back( snapshot );
    }

    M_formation.train( M_train_data );
}

/*-------------------------------------------------------------------*/
/*!
  \brief add one training snapshot inside the largest triangle where the
  current formation is inaccurate.

  Algorithm:
   1. copy the triangles of M_formation's current triangulation and sort
      them by area, largest last.
   2. take the largest remaining triangle and use its centroid as the
      candidate ball position (a circumcenter variant is compiled out).
   3. if that position lies outside the bounding rectangle, or within
      MIN_SAMPLE_DIST of an existing sample, drop the triangle and go to 2.
   4. fill in the 11 player positions from \a sample for that ball position.
   5. if isBigError() says M_formation predicts this snapshot poorly,
      append it to M_train_data, retrain M_formation and return true.
      Otherwise drop the triangle and go to 2.

  \param sample the formation being reproduced.
  \return true if a snapshot was added. false if the triangulation has no
  triangles, or if no triangle produced an acceptable snapshot.
*/
bool
LearnerDT::addSamplePointInLargestTriangle( const rcsc::FormationPtr sample )
{
    // create current triangles

    // snapshot of the triangulation as plain geometric triangles
    std::vector< Triangle2D > triangles;

    const std::map< int, TrianglePtr >::const_iterator end
        = M_formation.triangulation().triangleMap().end();
    for ( std::map< int, TrianglePtr >::const_iterator it
              = M_formation.triangulation().triangleMap().begin();
          it != end;
          ++it )
    {
        triangles.push_back( Triangle2D( it->second->vertex( 0 )->pos(),
                                         it->second->vertex( 1 )->pos(),
                                         it->second->vertex( 2 )->pos() ) );
    }

    if ( triangles.empty() )
    {
        std::cerr << "LearnerDT::addSamplePointInLargestTriangle() "
                  << "empty triangles"
                  << std::endl;
        return false;
    }

    // ascending by area, so back() is always the largest remaining triangle
    std::sort( triangles.begin(), triangles.end(), TriangleAreaCmp() );

#if 0
    // debug only: print every triangle's area and abort without adding
    for ( std::vector< Triangle2D >::iterator it = triangles.begin();
          it != triangles.end();
          ++it )
    {
        std::cerr << " area = " << it->area() << std::endl;
    }
    return false;
#endif

    // try triangles from largest to smallest until one yields a new sample
    while ( ! triangles.empty() )
    {
        Triangle2D triangle = triangles.back();

        Formation::Snapshot snapshot;
#if 1
        // active: candidate ball position is the centroid
        snapshot.ball_ = triangle.getCentroid();
        //snapshot.ball_ = triangle.getIncenter();
#else
        // alternative: use the circumcenter. If it falls outside the
        // triangle, replace it with the midpoint of one edge, chosen from
        // the angular order of the vertices as seen from the circumcenter.
        snapshot.ball_ = triangle.getCircumcenter();

        if ( ! triangle.contains( snapshot.ball_ ) )
        {
            AngleDeg ball_to_a = ( triangle.a() - snapshot.ball_ ).th();
            AngleDeg ball_to_b = ( triangle.b() - snapshot.ball_ ).th();
            AngleDeg ball_to_c = ( triangle.c() - snapshot.ball_ ).th();

            if ( ball_to_a.isLeftOf( ball_to_b )
                 && ball_to_b.isLeftOf( ball_to_c ) )
            {
                snapshot.ball_ = ( triangle.a() + triangle.c() ) * 0.5;
            }
            else if ( ball_to_a.isLeftOf( ball_to_c )
                      && ball_to_c.isLeftOf( ball_to_b ) )
            {
                snapshot.ball_ = ( triangle.a() + triangle.b() ) * 0.5;
            }
            else
            {
                snapshot.ball_ = ( triangle.b() + triangle.c() ) * 0.5;
            }

            std::cerr << "circumcenter is out of triangle. adjust."
                      << triangle.a() << triangle.b() << triangle.c()
                      << " circumcenter="
                      << triangle.getCircumcenter() << " -> "
                      << snapshot.ball_
                      << std::endl;
        }

#endif
        // reject candidates outside the sampling rectangle
        if ( snapshot.ball_.absX() > BOUNDING_RECT_LENGTH * 0.5
             || snapshot.ball_.absY() > BOUNDING_RECT_WIDTH * 0.5 )
        {
            std::cerr << "ball pos is over the bounding area "
                      << snapshot.ball_
                      << std::endl;
            triangles.pop_back();
            continue;
        }

        // reject candidates that would nearly duplicate an existing sample
        if ( existTooNearSample( snapshot.ball_ ) )
        {
            triangles.pop_back();
            continue;
        }


        //std::cerr << "check triangle "
        //          << triangle.a() << triangle.b() << triangle.c()
        //          << " centroid = " << snapshot.ball_
        //          << std::endl;
        // target player positions come from the sample formation
        for ( int unum = 1; unum <= 11; ++unum )
        {
            Vector2D player_pos = sample->getPosition( unum, snapshot.ball_ );
            snapshot.players_.push_back( player_pos );
        }

        // only keep the snapshot if the current formation gets it wrong
        if ( isBigError( snapshot ) )
        {
            M_train_data.push_back( snapshot );
            M_formation.train( M_train_data );
            return true;
        }

        // remove largest triangle
        triangles.pop_back();
    }

    return false;
}

/*-------------------------------------------------------------------*/
/*!
  \brief check whether a ball position is too close to a stored sample.
  \param ball_pos candidate ball position.
  \return true, after logging the offending pair on std::cerr, as soon as
  one stored snapshot's ball lies closer than MIN_SAMPLE_DIST to
  \a ball_pos. false otherwise.
*/
bool
LearnerDT::existTooNearSample( const rcsc::Vector2D & ball_pos )
{
    typedef std::list< rcsc::Formation::Snapshot >::const_iterator SampleIter;

    const SampleIter last = M_train_data.end();
    for ( SampleIter s = M_train_data.begin(); s != last; ++s )
    {
        if ( ! ( s->ball_.dist( ball_pos ) < MIN_SAMPLE_DIST ) )
        {
            continue;
        }

        std::cerr << "found too near sample "
                  << s->ball_
                  << "  to "
                  << ball_pos
                  << std::endl;
        return true;
    }

    return false;
}

/*-------------------------------------------------------------------*/
/*!
  \brief check whether M_formation reproduces a snapshot poorly.
  \param snapshot ball position plus the target player positions, where
  players_[unum - 1] belongs to player unum.
  \return true if, for any player 1..11, the distance between the target
  position and M_formation's prediction exceeds
  MIN_ERROR + ERROR_RATE * (distance from the ball to the prediction).
  If an exception is thrown (e.g. std::out_of_range from at() when
  players_ has fewer than 11 entries), its message goes to std::cerr and
  false is returned.
*/
bool
LearnerDT::isBigError( const rcsc::Formation::Snapshot & snapshot )
{
    try
    {
        for ( int unum = 1; unum <= 11; ++unum )
        {
            const Vector2D pos = M_formation.getPosition( unum, snapshot.ball_ );
            // the tolerance grows with the predicted player's distance from the ball
            const double ball_dist = snapshot.ball_.dist( pos );
            if ( snapshot.players_.at( unum - 1 ).dist( pos )
                 > MIN_ERROR + ball_dist * ERROR_RATE )
            {
                return true;
            }
        }
    }
    catch ( const std::exception & e )
    {
        std::cerr << e.what() << std::endl;
    }

    return false;
}



/*!
  \brief entry point: reconstruct the formation given by a ".conf" argument.

  The last command-line argument that ends in ".conf" (and is longer than
  the suffix) names the sample formation. Returns 1 if no such argument
  exists, or if the formation cannot be created, opened or read. All
  errors are reported on std::cerr.
*/
int
main( int argc, char ** argv )
{
    std::cout << "start triangle reconstructor." << std::endl;

    std::string conf_file;

    for ( int i = 1; i < argc; ++i )
    {
        const std::string arg = argv[i];
        if ( arg.length() > 5
             && arg.compare( arg.length() - 5, 5, ".conf" ) == 0 )
        {
            conf_file = arg; // later matches override earlier ones
        }
    }

    if ( conf_file.empty() )
    {
        std::cerr << "empty conf file" << std::endl;
        return 1;
    }

    rcsc::FormationPtr formation = rcsc::make_formation( conf_file );
    if ( ! formation )
    {
        std::cerr << "failed to make the formation from " << conf_file << std::endl;
        return 1;
    }

    std::ifstream fin( conf_file.c_str() );
    if ( ! fin.is_open() )
    {
        std::cerr << "failed to open " << conf_file << std::endl;
        return 1;
    }

    if ( ! formation->read( fin ) )
    {
        std::cerr << "failed to read " << conf_file << std::endl;
        return 1;
    }

    // the learner writes its output files from its destructor, at scope exit
    LearnerDT learner;

    learner.train( formation );

    std::cout << "end triangle reconstructor." << std::endl;
    return 0;
}
