#include "AgglomClusterTree.hh"
#include "lcl_array.hh"
#include <stdexcept>
#include <cmath>

using namespace std;

//======================================  Distance inference methods.
//  Single linkage: the distance from a merged cluster to another cluster
//  is the smaller of the two candidate distances.  Returns d1 only when
//  d1 < d2; otherwise (including ties and NaN comparisons) returns d2.
static double
linkage_single(double d1, double d2) {
   if (d1 < d2) return d1;
   return d2;
}

//======================================  Constructor
//  glom: one merge step of the tree -- the IDs of the two joined nodes
//  and the distance at which they were joined.
AgglomClusterTree::glom::glom(size_t first, size_t second, double dist)
  : _link1(first),
    _link2(second),
    _distance(dist)
{
}

//======================================  Condensed-matrix index helper
//
//  Return the position of element (i, j), i < j, in a condensed distance
//  array for n items.  The condensed layout stores the strict upper
//  triangle row by row: (0,1), (0,2), ..., (0,n-1), (1,2), ..., (n-2,n-1),
//  which is the order the closest-pair search below scans.
static size_t
condensed_index(size_t i, size_t j, size_t n) {
  return i * n - i * (i + 1) / 2 + (j - i - 1);
}

//======================================  Build the cluster tree
//
//  Build an agglomerative cluster tree for N items.
//
//  N         Number of items.
//  distances Condensed pairwise distances, N*(N-1)/2 entries in the order
//            (0,1), (0,2), ..., (0,N-1), (1,2), ..., (N-2,N-1).  The data
//            are copied; the caller's array is not modified.
//  method    Linkage inference method.  Only "single" is recognized.
//
//  Each merge step appends one glom to _tree.  Item (leaf) IDs are
//  0..N-1; the node created by the k-th merge (k = 0, 1, ...) gets ID N+k.
//  With fewer than two items the tree is left empty.
//
//  Throws invalid_argument if method is not recognized.
AgglomClusterTree::AgglomClusterTree(size_t N, double distances[], 
				     const std::string& method) {
  //------------------------------------  Define distance inference method
  double (*link_method)(double, double) = &linkage_single; 
  if (method == "single") {
    link_method = &linkage_single;
  }
  else {
    throw invalid_argument(string("Unrecognized linkage inference method: ")
			   + method);
  }

  //------------------------------------  Nothing to merge for N < 2.  This
  //                                      also keeps reserve(N-1) from
  //                                      wrapping around when N == 0.
  if (N < 2) return;

  _tree.reserve(N-1);
  lcl_array<size_t> left(N);
  size_t distSize = N*(N-1)/2;
  lcl_array<double> dist(distSize);

  //------------------------------------  Work on a private copy of the input
  for (size_t i=0; i<distSize; i++) dist[i] = distances[i];
  for (size_t i=0; i<N; i++) left[i] = i;

  //------------------------------------  Iterate over cluster steps.
  for (size_t nLeft=N; nLeft>1; nLeft--) {

    //----------------------------------  Find the closest pair (the first
    //                                    minimum in scan order wins ties)
    size_t mininx  = 0;
    size_t minjnx  = 1;
    double mindist = dist[0];
    size_t dinx = 0;
    for (size_t i=0; i<nLeft; i++) {
      for (size_t j=i+1; j<nLeft; j++) {
	if (dist[dinx] < mindist) {
	  mininx = i;
	  minjnx = j;
	  mindist = dist[dinx];
	}
	dinx++;
      }
    }

    //----------------------------------  Add a cluster to the tree
    _tree.push_back(glom(left[mininx], left[minjnx], mindist));

    //----------------------------------  d[k, mininx] = link(d[k, mininx],
    //                                                       d[k, minjnx])
    //  Writes touch only pairs containing mininx and the second operand
    //  only reads pairs containing minjnx, so the order of k is irrelevant.
    for (size_t k=0; k<nLeft; k++) {
      if (k == mininx || k == minjnx) continue;
      size_t ik = (k < mininx) ? condensed_index(k, mininx, nLeft)
	                       : condensed_index(mininx, k, nLeft);
      size_t jk = (k < minjnx) ? condensed_index(k, minjnx, nLeft)
	                       : condensed_index(minjnx, k, nLeft);
      dist[ik] = (*link_method)(dist[ik], dist[jk]);
    }

    //----------------------------------  Squeeze out row/column minjnx.
    //  The write index never passes the read index, so the compaction is
    //  done in place.
    size_t outx = 0;
    size_t inx  = 0;
    for (size_t i=0; i<nLeft; i++) {
      for (size_t j=i+1; j<nLeft; j++, inx++) {
	if (i != minjnx && j != minjnx) dist[outx++] = dist[inx];
      }
    }

    //----------------------------------  Replace mininx by the new node and
    //                                    drop minjnx so left[] stays in
    //                                    step with the squeezed matrix.
    left[mininx] = 2 * N - nLeft;
    for (size_t k=minjnx+1; k<nLeft; k++) left[k-1] = left[k];
  }
}

//======================================  Cut the tree into clusters
//
//  Assign items to clusters by applying every merge whose distance is at
//  or below cutoff.
//
//  method  Not currently used.
//  cutoff  Largest merge distance that is accepted.
//  clust   Output, resized to the number of items (_tree.size() + 1).
//          Entry k is the cluster number (0..nClust-1) of item k, or
//          ACT_null_id if item k was not merged with anything.
//  returns The number of clusters formed.
//
//  Leaf IDs are 0..nItem-1 and the node created by merge i has ID nItem+i,
//  matching the constructor.  Assumes (true for single linkage) that a
//  node's merged children have merge distances no larger than the node's,
//  so an accepted node's merged children were accepted as well.
size_t
AgglomClusterTree::cluster(const std::string& method, double cutoff, 
			   std::vector<size_t>& clust) const {
   size_t nClust = 0;
   size_t N = _tree.size();
   //---------------------------------  A tree of N merges joins N+1 items.
   size_t nItem = N + 1;
   lcl_array<size_t> clustID(N);
   for (size_t i=0; i<N; i++) clustID[i] = ACT_null_id;
   clust.clear();
   clust.resize(nItem, ACT_null_id);
   for (size_t i=0; i<N; i++) {
      if (_tree[i]._distance <= cutoff) {
	 size_t link1 = _tree[i]._link1;
	 size_t link2 = _tree[i]._link2;
	 size_t tclust = 0;

	 //-----------------------------  Make a new cluster from two tiles.
	 if (link1 < nItem && link2 < nItem) {
	    tclust = nClust++;
	    clust[link1] = tclust;
	    clust[link2] = tclust;
	 }

	 //-----------------------------  Add first tile to cluster
	 else if (link1 < nItem) {
	    tclust = clustID[link2 - nItem];
	    clust[link1] = tclust;
	 }

	 //-----------------------------  Add second tile to cluster
	 else if (link2 < nItem) {
	    tclust = clustID[link1 - nItem];
	    clust[link2] = tclust;
	 }

	 //-----------------------------  Merge clusters: keep the lower
	 //                                number, fold the higher one into it
	 //                                and close the numbering gap.
	 else {
	    size_t clust1 = clustID[link1 - nItem];
	    size_t clust2 = clustID[link2 - nItem];
	    if (clust1 < clust2) {
	       tclust = clust1;
	    } else {
	       tclust = clust2;
	       clust2 = clust1;
	    }
	    for (size_t j=0; j<nItem; j++) {
	       if (clust[j] == clust2) {
		  clust[j] = tclust;
	       } else if (clust[j] > clust2 && clust[j] < nClust) {
		  clust[j]--;
	       }
	    }
	    for (size_t j=0; j<i; j++) {
	       if (clustID[j] == clust2) {
		  clustID[j] = tclust;
	       } else if (clustID[j] > clust2 && clustID[j] < nClust) {
		  clustID[j]--;
	       }
	    }
	    nClust--;
	 }
	 clustID[i] = tclust;
      }
   }
   return nClust;
}
