#include "config.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>

#include <vector>

#include <CGAL/Exact_predicates_inexact_constructions_kernel.h>
#include <CGAL/Delaunay_triangulation_2.h>
#include <CGAL/natural_neighbor_coordinates_2.h>

#include <pipi.h>

/*
 * User-definable settings.
 */

/* The maximum message length */
#define MAX_MSG_LEN 140

/* The number of distinct characters a message can use */
//#define NUM_CHARACTERS 0x7fffffff // The sky's the limit
//#define NUM_CHARACTERS 1111998 // Full valid Unicode set
//#define NUM_CHARACTERS 100507 // Full graphic Unicode
#define NUM_CHARACTERS 32768 // Chinese characters
//#define NUM_CHARACTERS 127 // ASCII

/* The maximum image size we want to support */
#define MAX_W 4000
#define MAX_H 4000

/* Number of points stored per grid cell: one or two */
#define POINTS_PER_CELL 2

/* The number of levels for each point parameter: X, Y, red, green, blue
 * and "strength". Tested values (on the Mona Lisa) and the final distance
 * they reached:
 *  16 16 5 5 5 2 -> 0.06511725914
 *  16 16 6 7 6 1 -> 0.05731491348 *
 *  16 16 7 6 6 1 -> 0.06450513783
 *  14 14 7 7 6 1 -> 0.0637207893
 *  19 19 6 6 5 1 -> 0.06801999094 */
/* X/Y levels are also the cell size in pixels: main() resizes the image to
 * (dw * RANGE_X) x (dh * RANGE_Y) before analysing it */
static unsigned int RANGE_X = 16;
static unsigned int RANGE_Y = 16;
/* Colour levels per channel */
static unsigned int RANGE_R = 6;
static unsigned int RANGE_G = 6;
static unsigned int RANGE_B = 6;
/* Strength levels; with a single level rand_op() never picks the
 * strength flip */
static unsigned int RANGE_S = 1;

/*
 * These values are computed at runtime (in main()) from the settings above
 */

/* Bits carried by a MAX_MSG_LEN-character message over NUM_CHARACTERS */
static float TOTAL_BITS;
/* Bits reserved for the image size (up to MAX_W x MAX_H) */
static float HEADER_BITS;
/* TOTAL_BITS - HEADER_BITS: bits left for the point data */
static float DATA_BITS;
/* Bits needed to encode one grid cell */
static float POINT_BITS;

/* Number of grid cells that fit in DATA_BITS */
static unsigned int TOTAL_CELLS;

/* Place values for the mixed-radix point encoding manipulated by
 * apply_op(): from least significant, the digits are strength, y, x and
 * then the three colour levels (labelled R, G, B in apply_op()). Each
 * macro is the product of the ranges of the digits named in its suffix. */
#define RANGE_SY (RANGE_S*RANGE_Y)
#define RANGE_SYX (RANGE_S*RANGE_Y*RANGE_X)
#define RANGE_SYXR (RANGE_S*RANGE_Y*RANGE_X*RANGE_R)
#define RANGE_SYXRG (RANGE_S*RANGE_Y*RANGE_X*RANGE_R*RANGE_G)
#define RANGE_SYXRGB (RANGE_S*RANGE_Y*RANGE_X*RANGE_R*RANGE_G*RANGE_B)

/* CGAL kernel plus the triangulation and natural-neighbour coordinate
 * types used by render() */
struct K : CGAL::Exact_predicates_inexact_constructions_kernel {};
typedef CGAL::Delaunay_triangulation_2<K> Delaunay_triangulation;
typedef std::vector<std::pair<K::Point_2, K::FT> > Point_coordinate_vector;

/* Global aspect ratio: grid size in cells (width, height), chosen in
 * main() to approximate the source image's aspect ratio */
static unsigned int dw, dh;

/* Global point encoding: one mixed-radix value per point, with
 * POINTS_PER_CELL consecutive points per cell and cells in row-major
 * order on a dw-wide grid. NOTE(review): the array is fixed at 1024
 * entries and nothing here ties that size to TOTAL_CELLS. */
static uint32_t points[1024];
static int npoints = 0;

/* Global triangulation, cleared and rebuilt from points[] by render() */
static Delaunay_triangulation dt;

/* Deterministic pseudo-random generator: the sample rand() implementation
 * given in the C standard, with its own state so runs are reproducible.
 * Returns a value in [0, mod); the raw output is 15 bits, so nothing above
 * 32767 is ever produced whatever mod is. */
static unsigned int det_rand(unsigned int mod)
{
    static unsigned long state = 1;

    state = 1103515245 * state + 12345;

    unsigned int bits15 = (unsigned)(state >> 16) & 0x7fff;
    return bits15 % mod;
}

/* Quantise val (nominally in [0, 1]) to one of `range` integer levels.
 * The scale is just under `range` so that val == 1 lands in the top level;
 * anything outside [0, range - 1] is clamped. */
static inline int range2int(float val, int range)
{
    int level = (int)(val * ((float)range - 0.0001));

    if(level < 0)
        return 0;
    if(level > range - 1)
        return range - 1;
    return level;
}

/* Map level `val` of `range` levels back to [0, 1], returning the centre
 * of that level's sub-interval: (val + 1/2) / range. */
static inline float int2midrange(int val, int range)
{
    float numer = (float)(2 * val + 1);
    float denom = (float)(range * 2);
    return numer / denom;
}

/* Map level `val` of `range` levels to [0, 1] so that the first level is
 * 0 and the last is 1. A single (or empty) range always maps to 0. */
static inline float int2fullrange(int val, int range)
{
    if(range <= 1)
        return 0.0;

    return (float)val / (float)(range - 1);
}

/* Point values are mixed-radix numbers whose digits are, from least to
 * most significant: strength (RANGE_S levels), y offset within the cell
 * (RANGE_Y), x offset (RANGE_X), red (RANGE_R), green (RANGE_G) and blue
 * (RANGE_B). This is the layout apply_op() and the RANGE_SYX... macros
 * use, so set_point() and get_point() must follow it exactly: each colour
 * is stored in the digit apply_op() labels with its name, with that
 * channel's own range as the radix. */

/* Encode a point into points[index]. (x, y) are image coordinates inside
 * the cell owning the point (cell number index / POINTS_PER_CELL, laid out
 * row-major on a dw-wide grid); out-of-cell positions are clamped to the
 * cell edge. r, g, b and s are expected in [0, 1] and are clamped too. */
static inline void set_point(int index, float x, float y, float r,
                             float g, float b, float s)
{
    int dx = (index / POINTS_PER_CELL) % dw;
    int dy = (index / POINTS_PER_CELL) / dw;

    /* Position relative to the cell, nominally in [0, 1] */
    float fx = (x - dx * RANGE_X) / RANGE_X;
    float fy = (y - dy * RANGE_Y) / RANGE_Y;

    int is = range2int(s, RANGE_S);

    int ix = range2int(fx, RANGE_X);
    int iy = range2int(fy, RANGE_Y);

    int ir = range2int(r, RANGE_R);
    int ig = range2int(g, RANGE_G);
    int ib = range2int(b, RANGE_B);

    points[index] = is + RANGE_S * (iy + RANGE_Y * (ix + RANGE_X *
                               (ir + RANGE_R * (ig + RANGE_G * ib))));
}

/* Decode points[index] into image coordinates (the centre of the quantised
 * sub-cell), colour levels (centres of their intervals) and strength
 * (0 for the lowest level, 1 for the highest). Inverse of set_point(). */
static inline void get_point(int index, float *x, float *y, float *r,
                             float *g, float *b, float *s)
{
    uint32_t pt = points[index];

    unsigned int dx = (index / POINTS_PER_CELL) % dw;
    unsigned int dy = (index / POINTS_PER_CELL) / dw;

    *s = int2fullrange(pt % RANGE_S, RANGE_S); pt /= RANGE_S;

    float fy = int2midrange(pt % RANGE_Y, RANGE_Y); pt /= RANGE_Y;
    float fx = int2midrange(pt % RANGE_X, RANGE_X); pt /= RANGE_X;

    *x = (fx + dx) * RANGE_X;
    *y = (fy + dy) * RANGE_Y;

    *r = int2midrange(pt % RANGE_R, RANGE_R); pt /= RANGE_R;
    *g = int2midrange(pt % RANGE_G, RANGE_G); pt /= RANGE_G;
    *b = int2midrange(pt % RANGE_B, RANGE_B);
}

/* Snap x (nominally in [0, 1]) down to a multiple of 1/modulo. The scale
 * is just under modulo + 1, so x == 1 maps to exactly 1. */
static inline float clip(float x, int modulo)
{
    const float scale = (float)modulo + 0.9999;
    const int level = (int)(x * scale);

    return (float)level / (float)modulo;
}

static void add_point(float x, float y, float r, float g, float b, float s)
{
    set_point(npoints, x, y, r, g, b, s);
    npoints++;
}

static void add_random_point()
{
    points[npoints] = det_rand(RANGE_SYXRGB);
    npoints++;
}

#define NB_OPS 20

/* Pick a random operation number in [1, NB_OPS) for apply_op().
 * - Op 0 (an alias of op 1) is never returned.
 * - Op 1 (strength flip) is rejected outright when RANGE_S == 1, and
 *   otherwise survives a first coin toss.
 * - Ops 1-5, which are statistically less efficient, must then survive
 *   one more coin toss.
 * Rejected draws start over with a fresh det_rand() draw. */
static uint8_t rand_op(void)
{
    for(;;)
    {
        uint8_t op = det_rand(NB_OPS);

        if(op == 0)
            continue;
        if(op == 1 && (RANGE_S == 1 || det_rand(2)))
            continue;
        if(op <= 5 && det_rand(2))
            continue;

        return op;
    }
}

/* Move one digit of a mixed-radix point value by a single level.
 * `below` is the product of the ranges of all less significant digits and
 * `range` is the digit's own number of levels. The digit moves down
 * (dir < 0) or up (dir > 0); when it is already at that end of its range
 * it moves the other way instead. A digit with a single level cannot move,
 * so the value is returned unchanged rather than carrying into (or
 * wrapping around) a neighbouring digit. */
static uint32_t step_digit(uint32_t val, uint32_t below, uint32_t range,
                           int dir)
{
    if(range < 2)
        return val;

    uint32_t rem = val % below;
    uint32_t digit = (val / below) % range;

    if(dir < 0)
        digit = digit > 0 ? digit - 1 : digit + 1;
    else
        digit = digit < range - 1 ? digit + 1 : digit - 1;

    return (val / (below * range) * range + digit) * below + rem;
}

/* Apply mutation `op` (as returned by rand_op()) to the encoded point
 * value val and return the new value; unknown ops return val unchanged.
 * Moves and colour changes that would leave a digit's range bounce back
 * the other way, and parameters with a single level are never changed. */
static uint32_t apply_op(uint8_t op, uint32_t val)
{
    switch(op)
    {
    case 0: /* Flip strength value */
    case 1:
    {
        /* Statistics show that this helps often, but does not reduce
         * the error significantly. Mirror the strength level within its
         * range: for two levels this equals val ^ 1, but unlike val ^ 1 it
         * never touches the other digits when RANGE_S != 2. */
        uint32_t rem = val % RANGE_S;
        return val - rem + (RANGE_S - 1 - rem);
    }
    case 2: /* Move up; if impossible, down */
        return step_digit(val, RANGE_S, RANGE_Y, -1);
    case 3: /* Move down; if impossible, up */
        return step_digit(val, RANGE_S, RANGE_Y, +1);
    case 4: /* Move left; if impossible, right */
        return step_digit(val, RANGE_SY, RANGE_X, -1);
    case 5: /* Move right; if impossible, left */
        return step_digit(val, RANGE_SY, RANGE_X, +1);
    case 6: /* Corner 1 */
        return apply_op(2, apply_op(4, val));
    case 7: /* Corner 2 */
        return apply_op(2, apply_op(5, val));
    case 8: /* Corner 3 */
        return apply_op(3, apply_op(5, val));
    case 9: /* Corner 4 */
        return apply_op(3, apply_op(4, val));
    case 10: /* R-- (or R++) */
        return step_digit(val, RANGE_SYX, RANGE_R, -1);
    case 11: /* R++ (or R--) */
        return step_digit(val, RANGE_SYX, RANGE_R, +1);
    case 12: /* G-- (or G++) */
        return step_digit(val, RANGE_SYXR, RANGE_G, -1);
    case 13: /* G++ (or G--) */
        return step_digit(val, RANGE_SYXR, RANGE_G, +1);
    case 14: /* B-- (or B++) */
        return step_digit(val, RANGE_SYXRG, RANGE_B, -1);
    case 15: /* B++ (or B--) */
        return step_digit(val, RANGE_SYXRG, RANGE_B, +1);
    case 16: /* Double up */
        return apply_op(2, apply_op(2, val));
    case 17: /* Double down */
        return apply_op(3, apply_op(3, val));
    case 18: /* Double left */
        return apply_op(4, apply_op(4, val));
    case 19: /* Double right */
        return apply_op(5, apply_op(5, val));
    default:
        return val;
    }
}

/* Render the current points into the rectangle (rx, ry, rw, rh) of dst.
 * All points, plus four far-away corner points that close the
 * triangulation, are inserted into the global triangulation `dt`. Each
 * pixel then gets the natural-neighbour interpolation of its neighbours'
 * colours, each weight scaled by (1 + point strength). Corner points carry
 * no colour and are skipped. A pixel with no contributing point is written
 * as black. Alpha is always written as 0. */
static void render(pipi_image_t *dst, int rx, int ry, int rw, int rh)
{
    /* Maps a pixel position (on the dw * RANGE_X wide grid) back to the
     * index of the point placed there. It holds int rather than uint8_t so
     * that indices above 255 are not truncated. It is a std::vector rather
     * than a variable-length array: VLAs are not standard C++ and would
     * live on the stack. */
    std::vector<int> lookup(TOTAL_CELLS * RANGE_X * RANGE_Y, 0);
    pipi_pixels_t *p = pipi_get_pixels(dst, PIPI_PIXELS_RGBA_F32);
    float *data = (float *)p->pixels;
    int i, x, y;

    dt.clear();
    for(i = 0; i < npoints; i++)
    {
        float fx, fy, fr, fg, fb, fs;
        get_point(i, &fx, &fy, &fr, &fg, &fb, &fs);
        lookup[(int)fx + dw * RANGE_X * (int)fy] = i; /* Keep link to point */
        dt.insert(K::Point_2(fx, fy));
    }

    /* Add fake points to close the triangulation */
    dt.insert(K::Point_2(-p->w, -p->h));
    dt.insert(K::Point_2(2 * p->w, -p->h));
    dt.insert(K::Point_2(-p->w, 2 * p->h));
    dt.insert(K::Point_2(2 * p->w, 2 * p->h));

    for(y = ry; y < ry + rh; y++)
    {
        for(x = rx; x < rx + rw; x++)
        {
            K::Point_2 m(x, y);
            Point_coordinate_vector coords;
            CGAL::natural_neighbor_coordinates_2(dt, m,
                                                 std::back_inserter(coords));

            float r = 0.0f, g = 0.0f, b = 0.0f, norm = 0.0f;

            Point_coordinate_vector::iterator it;
            for(it = coords.begin(); it != coords.end(); ++it)
            {
                float fx, fy, fr, fg, fb, fs;

                fx = (*it).first.x();
                fy = (*it).first.y();

                /* Skip the fake corner points */
                if(fx < 0 || fy < 0 || fx > p->w - 1 || fy > p->h - 1)
                    continue;

                int index = lookup[(int)fx + dw * RANGE_X * (int)fy];

                get_point(index, &fx, &fy, &fr, &fg, &fb, &fs);

                float k = (*it).second * (1.00f + fs);

                r += k * fr;
                g += k * fg;
                b += k * fb;
                norm += k;
            }

            /* If only fake points (or none) were found, r, g and b are
             * still 0: keep them instead of computing 0 / 0 = NaN. */
            if(norm > 0.0f)
            {
                r /= norm;
                g /= norm;
                b /= norm;
            }

            data[4 * (x + y * p->w) + 0] = r;
            data[4 * (x + y * p->w) + 1] = g;
            data[4 * (x + y * p->w) + 2] = b;
            data[4 * (x + y * p->w) + 3] = 0.0;
        }
    }

    pipi_release_pixels(dst, p);
}

/* Seed the point set from the (already resized) source image.
 *
 * For each RANGE_X x RANGE_Y cell of the dw x dh grid, find the darkest
 * and the brightest pixel (brightness being the R+G+B sum). Add them as
 * points carrying that pixel's colour, darkest first.
 *
 * Strength depends on where the cell's mean brightness falls within its
 * [min, max] range:
 * - lowest quarter: the darkest point gets strength 1;
 * - highest quarter: the brightest point gets strength 1;
 * - otherwise: both get strength 0.
 *
 * With POINTS_PER_CELL == 1 only one of the two points is kept. */
static void analyse(pipi_image_t *src)
{
    pipi_pixels_t *p = pipi_get_pixels(src, PIPI_PIXELS_RGBA_F32);
    float *data = (float *)p->pixels;

    for(unsigned int dy = 0; dy < dh; dy++)
        for(unsigned int dx = 0; dx < dw; dx++)
        {
            /* Use sentinels outside any possible brightness so that the
             * cell's first pixel always initialises both extremes. A finite
             * guess such as 1.1 lies inside the R+G+B range: a cell with no
             * pixel darker than it would keep (xmin, ymin) at the image
             * origin and seed its point with the wrong pixel. */
            float min = HUGE_VALF, max = -HUGE_VALF;
            float total = 0.0;
            int xmin = RANGE_X * dx, xmax = RANGE_X * dx;
            int ymin = RANGE_Y * dy, ymax = RANGE_Y * dy;
            int npixels = 0;

            for(unsigned int iy = RANGE_Y * dy; iy < RANGE_Y * (dy + 1); iy++)
                for(unsigned int ix = RANGE_X * dx; ix < RANGE_X * (dx + 1); ix++)
                {
                    float lum = 0.0f;

                    lum += data[4 * (ix + iy * p->w) + 0];
                    lum += data[4 * (ix + iy * p->w) + 1];
                    lum += data[4 * (ix + iy * p->w) + 2];

                    if(lum < min)
                    {
                        min = lum;
                        xmin = ix;
                        ymin = iy;
                    }

                    if(lum > max)
                    {
                        max = lum;
                        xmax = ix;
                        ymax = iy;
                    }

                    total += lum;
                    npixels++;
                }

            total /= npixels;

            float wmin, wmax;

            if(total < min + (max - min) / 4)
                wmin = 1.0, wmax = 0.0;
            else if(total < min + (max - min) / 4 * 3)
                wmin = 0.0, wmax = 0.0;
            else
                wmin = 0.0, wmax = 1.0;

#if 0
            /* Experiment: random points instead of the extrema */
            add_random_point();
            add_random_point();
#else
#if POINTS_PER_CELL == 1
            /* Only one point per cell: keep the darkest pixel when the mean
             * is below the middle of [min, max], the brightest otherwise */
            bool keep_min = total < min + (max - min) / 2;
            bool keep_max = !keep_min;
#else
            bool keep_min = true, keep_max = true;
#endif
            if(keep_min)
                add_point(xmin, ymin,
                          data[4 * (xmin + ymin * p->w) + 0],
                          data[4 * (xmin + ymin * p->w) + 1],
                          data[4 * (xmin + ymin * p->w) + 2],
                          wmin);
            if(keep_max)
                add_point(xmax, ymax,
                          data[4 * (xmax + ymax * p->w) + 0],
                          data[4 * (xmax + ymax * p->w) + 1],
                          data[4 * (xmax + ymax * p->w) + 2],
                          wmax);
#endif
        }
}

/* Derive the bit budget from the compile-time settings:
 * - how many bits a MAX_MSG_LEN-character message over NUM_CHARACTERS
 *   symbols carries;
 * - how many of them the image-size header takes;
 * - how many grid cells fit in the rest.
 * Fills TOTAL_BITS, HEADER_BITS, DATA_BITS, POINT_BITS and TOTAL_CELLS and
 * reports them on stderr. TOTAL_CELLS is 0 when not even one cell fits. */
static void compute_bit_budget(void)
{
    fprintf(stderr, "Available characters: %i\n", NUM_CHARACTERS);
    fprintf(stderr, "Maximum message size: %i\n", MAX_MSG_LEN);
    TOTAL_BITS = MAX_MSG_LEN * logf(NUM_CHARACTERS) / logf(2);
    fprintf(stderr, "Available bits: %f\n", TOTAL_BITS);
    fprintf(stderr, "Maximum image resolution: %ix%i\n", MAX_W, MAX_H);
    HEADER_BITS = logf(MAX_W * MAX_H) / logf(2);
    fprintf(stderr, "Header bits: %f\n", HEADER_BITS);
    DATA_BITS = TOTAL_BITS - HEADER_BITS;
    fprintf(stderr, "Bits available for data: %f\n", DATA_BITS);
#if POINTS_PER_CELL == 1
    POINT_BITS = logf(RANGE_SYXRGB) / logf(2);
#else
    /* Two points per cell: an unordered (possibly equal) pair of positions
     * plus the colour and strength of each point */
    float coord_bits = logf((RANGE_Y * RANGE_X) * (RANGE_Y * RANGE_X + 1) / 2);
    float other_bits = logf(RANGE_R * RANGE_G * RANGE_B * RANGE_S);
    POINT_BITS = (coord_bits + 2 * other_bits) / logf(2);
#endif
    fprintf(stderr, "Cell bits: %f\n", POINT_BITS);
    /* A negative budget would wrap around in the unsigned TOTAL_CELLS */
    float cells = DATA_BITS / POINT_BITS;
    TOTAL_CELLS = cells >= 1.0f ? (unsigned int)cells : 0;
    fprintf(stderr, "Available cells: %u\n", TOTAL_CELLS);
    fprintf(stderr, "Wasted bits: %f\n", DATA_BITS - POINT_BITS * TOTAL_CELLS);
}

/* Choose the grid size dw x dh (in cells) whose aspect ratio is closest to
 * width:height using at most TOTAL_CELLS cells. Then grow either dimension
 * while the grid still fits, to reduce the number of wasted cells. */
static void choose_grid(int width, int height)
{
    float r = (float)width / (float)height;

    dw = 1; dh = TOTAL_CELLS;
    for(unsigned int i = 1; i <= TOTAL_CELLS; i++)
    {
        int j = TOTAL_CELLS / i;

        float ir = (float)i / (float)j;
        float dwr = (float)dw / (float)dh;

        if(fabs(logf(r / ir)) < fabs(logf(r / dwr)))
        {
            dw = i;
            dh = TOTAL_CELLS / dw;
        }
    }
    while((dh + 1) * dw <= TOTAL_CELLS) dh++;
    /* Growing the width must test dw + 1 (testing dh + 1 here would just
     * repeat the condition above, and the loop would never run) */
    while((dw + 1) * dh <= TOTAL_CELLS) dw++;
}

/* Print, for each operation, how many times it was tried and how many
 * times it lowered the error. opstats holds 2 * NB_OPS counters: attempts
 * at even indices, successes at the following odd ones. */
static void print_op_stats(const int *opstats)
{
    for(int j = 0; j < 2; j++)
    {
        fprintf(stderr,   "operation: ");
        for(int i = NB_OPS / 2 * j; i < NB_OPS / 2 * (j + 1); i++)
            fprintf(stderr, "%4i ", i);
        fprintf(stderr, "\nattempts:  ");
        for(int i = NB_OPS / 2 * j; i < NB_OPS / 2 * (j + 1); i++)
            fprintf(stderr, "%4i ", opstats[i * 2]);
        fprintf(stderr, "\nsuccesses: ");
        for(int i = NB_OPS / 2 * j; i < NB_OPS / 2 * (j + 1); i++)
            fprintf(stderr, "%4i ", opstats[i * 2 + 1]);
        fprintf(stderr, "\n");
    }
}

/* Usage: <program> <image>
 * Approximate the image with as many points as the message budget allows,
 * refine the points by random hill climbing, and save the result to
 * "lol.bmp" at the original size. */
int main(int argc, char *argv[])
{
    int opstats[2 * NB_OPS];
    pipi_image_t *src, *tmp, *dst;
    double error = 1.0;
    int width, height;

    if(argc < 2)
    {
        fprintf(stderr, "Usage: %s <image>\n", argc > 0 ? argv[0] : "program");
        return EXIT_FAILURE;
    }

    /* Compute bit allocation */
    compute_bit_budget();
    if(TOTAL_CELLS == 0)
    {
        fprintf(stderr, "Not enough bits to encode a single cell\n");
        return EXIT_FAILURE;
    }

    /* Load image */
    pipi_set_gamma(1.0);
    src = pipi_load(argv[1]);
    if(!src)
    {
        fprintf(stderr, "Cannot load image %s\n", argv[1]);
        return EXIT_FAILURE;
    }
    width = pipi_get_image_width(src);
    height = pipi_get_image_height(src);

    /* Compute best w/h ratio */
    choose_grid(width, height);
    fprintf(stderr, "Chosen image ratio: %u:%u (wasting %u point cells)\n",
            dw, dh, TOTAL_CELLS - dw * dh);
    fprintf(stderr, "Total wasted bits: %f\n",
            DATA_BITS - POINT_BITS * dw * dh);

    /* Resize and filter image to better state */
    tmp = pipi_resize(src, dw * RANGE_X, dh * RANGE_Y);
    pipi_free(src);
    src = pipi_median_ext(tmp, 1, 1);
    pipi_free(tmp);

    /* Analyse image */
    analyse(src);

    /* Render what we just computed */
    tmp = pipi_new(dw * RANGE_X, dh * RANGE_Y);
    render(tmp, 0, 0, dw * RANGE_X, dh * RANGE_Y);
    error = pipi_measure_rmsd(src, tmp);

    fprintf(stderr, "Distance: %2.10g\n", error);

    /* Hill climbing: mutate one random point at a time and keep the change
     * only if it lowers the distance to the source image */
    memset(opstats, 0, sizeof(opstats));
    for(int iter = 0; iter < 10000; iter++)
    {
        pipi_image_t *scrap = pipi_copy(tmp);

        /* Choose a point at random */
        int pt = det_rand(npoints);
        uint32_t oldval = points[pt];

        /* Compute the affected image zone: the point's cell and its
         * neighbours (in cells), clipped to the dw x dh grid */
        float fx, fy, fr, fg, fb, fs;
        get_point(pt, &fx, &fy, &fr, &fg, &fb, &fs);
        int zonex = (int)fx / (int)RANGE_X - 1;
        int zoney = (int)fy / (int)RANGE_Y - 1;
        int zonew = 3;
        int zoneh = 3;
        if(zonex < 0) { zonex = 0; zonew--; }
        if(zoney < 0) { zoney = 0; zoneh--; }
        /* The zone spans cells zonex .. zonex + zonew - 1, so it only
         * overflows the grid when zonex + zonew exceeds dw */
        if(zonex + zonew > (int)dw) { zonew--; }
        if(zoney + zoneh > (int)dh) { zoneh--; }

        /* Choose a random operation and measure its effect */
        uint8_t op1 = rand_op();

        enum { NCANDIDATES = 1 };
        uint32_t candidates[NCANDIDATES];
        double besterr = error + 1.0;
        int bestop = -1;
        candidates[0] = apply_op(op1, oldval);

        for(int i = 0; i < NCANDIDATES; i++)
        {
            if(oldval == candidates[i])
                continue;

            points[pt] = candidates[i];

            render(scrap, zonex * RANGE_X, zoney * RANGE_Y,
                   zonew * RANGE_X, zoneh * RANGE_Y);

            double newerr = pipi_measure_rmsd(src, scrap);
            if(newerr < besterr)
            {
                besterr = newerr;
                bestop = i;
            }
        }

        opstats[op1 * 2]++;

        if(besterr < error)
        {
            points[pt] = candidates[bestop];
            /* scrap holds the render of the last candidate tried; redraw
             * only if a different candidate won */
            if(bestop != NCANDIDATES - 1)
                render(scrap, zonex * RANGE_X, zoney * RANGE_Y,
                       zonew * RANGE_X, zoneh * RANGE_Y);

            pipi_free(tmp);
            tmp = scrap;
            fprintf(stderr, "%08i -.%08i %2.010g after op%i(%i)\n", iter,
                    (int)((error - besterr) * 100000000), error, op1, pt);
            error = besterr;
            opstats[op1 * 2 + 1]++;
        }
        else
        {
            pipi_free(scrap);
            points[pt] = oldval;
        }
    }

    print_op_stats(opstats);

    fprintf(stderr, "Distance: %2.10g\n", error);

    dst = pipi_resize(tmp, width, height);
    pipi_free(tmp);

    /* Save image and bail out */
    pipi_save(dst, "lol.bmp");
    pipi_free(dst);

    return 0;
}