#include "mex.h"
#include <cassert>
#include <cmath>
#include <algorithm>
#include <iostream>
using namespace std;


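// Idea: the n Droste regions are specified as an n-by-3 matrix D (one row per
// region: y-center, x-center, scale rho), and the (x, y) pairs as two more
// matrices X and Y of identical size (e.g. nx-by-ny).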


// Normalized offset of coordinate v from the fixed point s along one axis:
// 0 at s, 1 at the edge of [0, 1] on v's side of s. A NaN v yields NaN.
static double axis_ratio(double v, double s) {
  if (v == s) return 0.0;  // avoids 0/0 when s lies on an edge (s == 0 or s == 1)
  return v < s ? (s - v) / s : (v - s) / (1 - s);
}

// Pull the point (x, y) back out of every Droste region that contains it.
//
// D is an n-by-3 column-major array (element (i, c) at D[i + c*n]); row i
// describes region i as (y-center, x-center, rho). The region is the square of
// side rho with lower corner (xlo, ylo) = center - rho/2, i.e. the image of
// [0, 1]^2 under the contraction p -> s + rho*(p - s) about the fixed point
// s = (xlo, ylo) / (1 - rho). A point inside the region is mapped outward by
// the inverse contraction just enough times to leave it; passes over all
// regions repeat until no region claims the point.
//
// Outputs: (xout, yout) is the pulled-back point and hops is the total number
// of inverse contractions applied. A region leaves a point unchanged when the
// point is exactly at that region's fixed point or has a NaN coordinate.
void backup(double & xout, double & yout, double & hops, double x, double y, const double * D, int n) {
  hops = 0;
  bool changed;
  do {
    changed = false;
    for (int i = 0; i < n; ++i) {
      const double rho = D[i + 2*n];
      assert(0 < rho && rho < 1);
      const double xlo = D[i + 1*n] - rho/2;
      const double ylo = D[i + 0*n] - rho/2;

      const double sx = xlo / (1 - rho);
      assert(0 <= sx && sx <= 1);
      const double sy = ylo / (1 - rho);
      assert(0 <= sy && sy <= 1);

      // r <= rho exactly when (x, y) lies inside this region's square.
      const double xr = axis_ratio(x, sx);
      const double yr = axis_ratio(y, sy);
      const double r = std::max(xr, yr);

      // Not in jurisdiction of this Droste operator. Each axis is tested
      // separately so a NaN ratio fails the test instead of slipping through
      // std::max (NaN inputs would otherwise be "processed" forever).
      // r == 0 is the fixed point itself: scaling cannot move it, and log(0)
      // would produce an infinite depth.
      if (!(xr <= rho && yr <= rho && r > 0)) continue;

      // depth >= 1 with rho^(depth+1) < r <= rho^depth; undoing that many
      // contractions leaves r / rho^depth in (rho, 1], just outside the region.
      // Kept as double so an extreme depth (rho close to 1) cannot overflow int.
      const double depth = std::floor(std::log(r) / std::log(rho));
      assert(depth >= 1);
      const double scale = std::pow(rho, -depth);
      x = (x - sx)*scale + sx;
      y = (y - sy)*scale + sy;
      hops += depth;
      changed = true;
    }
  } while (changed);

  xout = x;
  yout = y;
}

// The MEX gateway pulls back every (x, y) pair given in X and Y through the regions in D.


/* The gateway function */
void mexFunction( int nlhs, mxArray *plhs[],
                  int nrhs, const mxArray *prhs[])
{
    if(nrhs != 3) {
        mexErrMsgIdAndTxt("bkup:nrhs","Three inputs required (X,Y,D).");
    }

    if(nlhs != 3) {
        mexErrMsgIdAndTxt("bkup:nlhs","Three outputs required (X,Y,H).");
    }

    if(!mxIsDouble(prhs[0])) {
        mexErrMsgIdAndTxt("bkup:notDouble","First input (X) must be of type double.");
    }

    if(!mxIsDouble(prhs[1])) {
        mexErrMsgIdAndTxt("bkup:notDouble","Second input (Y) must be of type double.");
    }

    if(!mxIsDouble(prhs[2])) {
        mexErrMsgIdAndTxt("bkup:notDouble","Third input (D) must be of type double.");
    }
    

    const size_t   ndimX = mxGetNumberOfDimensions(prhs[0]);
    const mwSize * dimsX = mxGetDimensions(prhs[0]);
    const size_t   ndimY = mxGetNumberOfDimensions(prhs[1]);
    const mwSize * dimsY = mxGetDimensions(prhs[1]);
    const size_t   ndimD = mxGetNumberOfDimensions(prhs[2]);
    const mwSize * dimsD = mxGetDimensions(prhs[2]);
    

    if (ndimX != ndimY) {
      mexErrMsgIdAndTxt("bkup:incompatible","Inputs must have the same number of dimensions.");
    }
    
    for (size_t i = 0; i < ndimX; ++i) {
      if (dimsX[i] != dimsY[i]) {
	mexErrMsgIdAndTxt("bkup:incompatible","Inputs must have the same sizes.");
      }
    }

    if (ndimD != 2 || dimsD[1] != 3) {
      mexErrMsgIdAndTxt("bkup:misshapen","D should be N by 3");
    }
    

    const size_t nelts = mxGetNumberOfElements(prhs[0]);
    assert(nelts == mxGetNumberOfElements(prhs[1]));

    /* get R and V */
    const double * X = mxGetPr(prhs[0]);
    const double * Y = mxGetPr(prhs[1]);
    const double * D = mxGetPr(prhs[2]);

    /* create the output matrices */
    plhs[0] = mxCreateNumericArray(ndimX, dimsX, mxDOUBLE_CLASS, mxREAL);
    plhs[1] = mxCreateNumericArray(ndimX, dimsX, mxDOUBLE_CLASS, mxREAL);
    plhs[2] = mxCreateNumericArray(ndimX, dimsX, mxDOUBLE_CLASS, mxREAL);

    /* get a pointer to the real data in the output matrix */
    double * outX = mxGetPr(plhs[0]);
    double * outY = mxGetPr(plhs[1]);
    double * outH = mxGetPr(plhs[2]);

    for (size_t i = 0; i < nelts; ++i) {
      backup(outX[i], outY[i], outH[i], X[i], Y[i], D, dimsD[0]);
    }
}
