/*
 * =============================================================
 * JointBoosting with temporal consistency. The weak learner outputs
 * averages computed over an optimized time window T.
 * =============================================================
 */

/* $Revision: 1.9 $ */

#include "mex.h"
#include "math.h"
#include <limits.h>
#include <stdlib.h>
#include <gsl/gsl_vector.h>
#include <gsl/gsl_combination.h>
/*
 * factorial - returns victim! for victim in 0..20 (20! is the largest
 * factorial that fits in a 64-bit long long).
 *
 * Arguments <= 1, including negative ones, return 1. The previous recursive
 * version never terminated for negative input.
 *
 * The product is accumulated in unsigned arithmetic, so arguments above 20
 * wrap instead of triggering signed-overflow undefined behaviour. Such
 * results are meaningless, and converting them back to long long is
 * implementation-defined.
 */
long long int factorial(long long int victim)
{
	unsigned long long int product = 1;
	long long int k;

	for (k = 2; k <= victim; k++)
		product *= (unsigned long long int)k;
	return (long long int)product;
}
/* Best weak-learner fit found for one candidate class set. */
typedef struct {
	float a;            /* coefficient of phi */
	float b;            /* coefficient of chi */
	float J;            /* cost at the chosen (T, threshold) */
	unsigned int T;     /* 0-based window index; the window length is T+1 */
	float threshold;    /* chosen alphaVals entry */
} SetFit;

/* calloc that never requests zero bytes, so NULL always means failure. */
static void *zalloc(size_t count, size_t size)
{
	return calloc(count > 0 ? count : 1, size);
}

/*
 * Index of the smallest non-NaN entry of values[0..count-1]. On ties the
 * lowest index wins.
 *
 * NaN entries come from degenerate fits (zero determinant) and are skipped,
 * so they can never be selected. Returns 0 when count is 0 or every entry
 * is NaN.
 */
static size_t argminFloat(const float *values, size_t count)
{
	size_t best = 0, i;
	int found = 0;

	for (i = 0; i < count; i++) {
		if (isnan(values[i]))
			continue;
		if (!found || values[i] < values[best]) {
			best = i;
			found = 1;
		}
	}
	return best;
}

/*
 * Number of k-element subsets of an n-element set. Each partial product of
 * the multiplicative formula is itself a binomial coefficient, so every
 * division is exact. No factorials are formed, so nothing overflows for
 * n > 20.
 */
static size_t binomialCount(size_t n, size_t k)
{
	size_t result = 1, i;

	if (k > n)
		return 0;
	for (i = 1; i <= k; i++)
		result = result * (n - k + i) / i;
	return result;
}

/*
 * Solve the weighted normal equations by Cramer's rule:
 *   [wphi    wphichi] [a]   [wzphi]
 *   [wphichi wchi   ] [b] = [wzchi]
 * A zero determinant yields inf/NaN coefficients.
 */
static void solveAB(float wzphi, float wzchi, float wphi, float wchi,
                    float wphichi, float *a, float *b)
{
	const float det = (wphi * wchi) - (wphichi * wphichi);

	*a = (wzphi * wchi - wzchi * wphichi) / det;
	*b = (wphi * wzchi - wzphi * wphichi) / det;
}

/*
 * Cost contributions of the classes NOT flagged in inSet. Those classes are
 * predicted by their constant ksum[c]. The sums run in ascending class order:
 *   *sq    = sum ksum[c]^2 * sw[c]
 *   *cross = sum 2 * ksum[c] * Zclass[c]
 */
static void outsideTerms(const unsigned char *inSet, const float *ksum,
                         const float *sw, const float *Zclass, int Nclass,
                         float *sq, float *cross)
{
	float s1 = 0, s2 = 0;
	int c;

	for (c = 0; c < Nclass; c++) {
		if (inSet[c])
			continue;
		s1 += (ksum[c] * ksum[c]) * sw[c];
		s2 += 2 * ksum[c] * Zclass[c];
	}
	*sq = s1;
	*cross = s2;
}

/*
 * Pick the (T, threshold) candidate with the lowest cost J[0..nTth-1].
 * Candidate index k encodes T = k / MaxThre and th = k % MaxThre.
 */
static SetFit bestFit(const float *A, const float *B, const float *J,
                      size_t nTth, int MaxThre, const float *alphaVals)
{
	const size_t idx = argminFloat(J, nTth);
	SetFit fit;

	fit.a = A[idx];
	fit.b = B[idx];
	fit.J = J[idx];
	fit.T = (unsigned int)(idx / (size_t)MaxThre);
	fit.threshold = alphaVals[idx % (size_t)MaxThre];
	return fit;
}

/*
 * cicloPhiChi - weak-learner search for JointBoosting with temporal windows.
 *
 * Features. For every window index T = 0..Tmax-1 (window length T+1) and
 * threshold alphaVals[th], th = 0..MaxThre-1, each sample i >= T gets two
 * features over samples i-T..i:
 *   phi = fraction of the window with x > alphaVals[th]
 *   chi = fraction of the window with x <= alphaVals[th]
 * Both are set to 0 when any sample in the window has a label y different
 * from y[i].
 *
 * Fitting. Each candidate class set S gets a weighted least-squares fit
 * z ~ a*phi + b*chi on the moments pooled over S. Every class outside S is
 * predicted by its constant ksum[c]. The cost is
 *   J = ws + a^2*Σwφ² + b^2*Σwχ² + 2ab*Σwφχ - 2a*Σwzφ - 2b*Σwzχ
 *       + Σ_{c∉S} (ksum² sw - 2 ksum Zclass).
 * NOTE(review): this equals the weighted squared error only if z^2 == 1,
 * presumably z ∈ {-1,+1}. TODO confirm.
 *
 * Candidate sets, in output order:
 *   - the Nclass single classes;
 *   - then all subsets of size 2, 3, ..., Nclass-1, each size in GSL's
 *     lexicographic combination order.
 * The set of all classes is not a candidate (except when Nclass == 1). The
 * number of sets is 2^Nclass - 2, so Nclass must stay small.
 *
 * Layout. w and z hold element (c, i) at i*Nclass + c. x and y hold Nsamples
 * values. alphaVals holds MaxThre values. betaVals is not read.
 *
 * Outputs:
 *   *aMin, *bMin, *JMin  coefficients and cost of the winning set/candidate
 *   ksum[0..Nclass-1]    Σ w*z / Σ w for each class
 *   *minN                1-based index of the winning set (order above)
 *   *tOpt                winning window length (T+1)
 *   *thOpt               winning threshold
 *
 * Failure. If the dimensions are non-positive (Nsamples may be 0) or memory
 * runs out, *minN and *tOpt are set to 0 and the float outputs to NaN.
 * ksum is written only on success.
 */
void cicloPhiChi(float *x,float *w,float *z,
				 float *betaVals,float *alphaVals,
				 int Nsamples,int MaxThre,int Tmax, int Nclass,
				 float *aMin,
				 float *bMin,float *JMin,float *y,float *ksum, unsigned int *minN, unsigned int *tOpt, float *thOpt)
{
	size_t nTth, cells, block, total, countN, kk, n, NN, m, idx;
	int T, th, i, j, c, ik;
	float ws, kSum1, kSum2, constPart;
	float *wz = NULL, *perClass = NULL, *stats = NULL, *scratch = NULL, *JN = NULL;
	float *sw, *Zclass, *wzphi, *wzchi, *wphi, *wchi, *wphichi, *A, *B, *J;
	SetFit *fits = NULL;
	unsigned char *inSet = NULL;
	gsl_combination *comb = NULL;

	(void)betaVals; /* part of the interface, not read here */

	/* Failure signal, overwritten once a winner is found. Valid set
	 * indices are 1-based. */
	*minN = 0;
	*tOpt = 0;
	*aMin = *bMin = *JMin = *thOpt = NAN;

	if (Nsamples < 0 || MaxThre < 1 || Tmax < 1 || Nclass < 1)
		return;

	nTth = (size_t)Tmax * (size_t)MaxThre;
	cells = (size_t)Nsamples * (size_t)Nclass;
	block = (size_t)Nclass * nTth;
	/* Single classes first, then every subset of size 2..Nclass-1. */
	total = (size_t)Nclass;
	for (ik = 2; ik < Nclass; ik++)
		total += binomialCount((size_t)Nclass, (size_t)ik);

	wz       = zalloc(cells, sizeof *wz);
	perClass = zalloc(2 * (size_t)Nclass, sizeof *perClass);
	stats    = zalloc(5 * block, sizeof *stats);
	scratch  = zalloc(3 * nTth, sizeof *scratch);
	JN       = zalloc(total, sizeof *JN);
	fits     = zalloc(total, sizeof *fits);
	inSet    = zalloc((size_t)Nclass, sizeof *inSet);
	if (!wz || !perClass || !stats || !scratch || !JN || !fits || !inSet)
		goto done;

	/* Flat tables. The moment for class c and candidate k lives at
	 * [c * nTth + k]. */
	sw      = perClass;
	Zclass  = perClass + Nclass;
	wzphi   = stats;
	wzchi   = stats + block;
	wphi    = stats + 2 * block;
	wchi    = stats + 3 * block;
	wphichi = stats + 4 * block;
	A = scratch;
	B = scratch + nTth;
	J = scratch + 2 * nTth;

	/* 1. Per-class weighted targets: wz = w*z, sw = Σ w, Zclass = Σ w*z. */
	for (c = 0; c < Nclass; c++) {
		for (i = 0; i < Nsamples; i++) {
			const size_t at = (size_t)i * Nclass + c;
			wz[at] = w[at] * z[at];
			sw[c] += w[at];
			Zclass[c] += wz[at];
		}
	}
	ws = 0;
	for (c = 0; c < Nclass; c++)
		ws += sw[c];
	/* Constant prediction of each class. This does not depend on (T, th),
	 * so it is computed once. */
	for (c = 0; c < Nclass; c++)
		ksum[c] = Zclass[c] / sw[c];

	/* 2. Window features and per-class weighted moments for every (T, th). */
	for (T = 0; T < Tmax; T++) {
		for (th = 0; th < MaxThre; th++) {
			const size_t k = (size_t)T * (size_t)MaxThre + (size_t)th;

			for (i = T; i < Nsamples; i++) {
				int nAbove = 0, nBelow = 0, pure = 1;
				float phi = 0, chi = 0;

				for (j = i - T; j <= i; j++) {
					if (y[j] == y[i]) {
						if (x[j] > alphaVals[th])
							nAbove++;
						else
							nBelow++;
					} else {
						pure = 0;  /* window straddles a label change */
					}
				}
				if (pure) {
					phi = (float)nAbove / (T + 1);
					chi = (float)nBelow / (T + 1);
				}
				for (c = 0; c < Nclass; c++) {
					const size_t at = (size_t)i * Nclass + c;
					const size_t slot = (size_t)c * nTth + k;
					wzphi[slot]   += wz[at] * phi;
					wzchi[slot]   += wz[at] * chi;
					wphi[slot]    += w[at] * phi * phi;
					wchi[slot]    += w[at] * chi * chi;
					wphichi[slot] += w[at] * phi * chi;
				}
			}
		}
	}

	/* 3. Leaves of the sharing graph: each class on its own. */
	for (c = 0; c < Nclass; c++) {
		inSet[c] = 1;
		outsideTerms(inSet, ksum, sw, Zclass, Nclass, &kSum1, &kSum2);
		inSet[c] = 0;
		for (kk = 0; kk < nTth; kk++) {
			const size_t slot = (size_t)c * nTth + kk;
			float a, b, cost;

			solveAB(wzphi[slot], wzchi[slot], wphi[slot], wchi[slot],
			        wphichi[slot], &a, &b);
			cost = (a*a)*wphi[slot] + (b*b)*wchi[slot] + 2*a*b*wphichi[slot]
			       - 2*a*wzphi[slot] - 2*b*wzchi[slot];
			cost += ws + kSum1 - kSum2;
			A[kk] = a;
			B[kk] = b;
			J[kk] = cost;
		}
		fits[c] = bestFit(A, B, J, nTth, MaxThre, alphaVals);
		JN[c] = fits[c].J;
	}

	/* 4. Inner nodes: every set of ik classes, 2 <= ik < Nclass. GSL keeps
	 * combination members in increasing order, so the moments are pooled
	 * in ascending class order. */
	countN = (size_t)Nclass;
	for (ik = 2; ik < Nclass; ik++) {
		comb = gsl_combination_calloc((size_t)Nclass, (size_t)ik);
		if (comb == NULL)
			goto done;
		NN = binomialCount((size_t)Nclass, (size_t)ik);
		for (n = 0; n < NN; n++) {
			const size_t *members = gsl_combination_data(comb);

			for (c = 0; c < Nclass; c++)
				inSet[c] = 0;
			for (m = 0; m < (size_t)ik; m++)
				inSet[members[m]] = 1;
			outsideTerms(inSet, ksum, sw, Zclass, Nclass, &kSum1, &kSum2);
			constPart = kSum1 - kSum2;

			for (kk = 0; kk < nTth; kk++) {
				float szphi = 0, szchi = 0, sphi = 0, schi = 0, sphichi = 0, a, b;

				for (m = 0; m < (size_t)ik; m++) {
					const size_t slot = members[m] * nTth + kk;
					szphi   += wzphi[slot];
					szchi   += wzchi[slot];
					sphi    += wphi[slot];
					schi    += wchi[slot];
					sphichi += wphichi[slot];
				}
				solveAB(szphi, szchi, sphi, schi, sphichi, &a, &b);
				A[kk] = a;
				B[kk] = b;
				J[kk] = ws + (a*a)*sphi + (b*b)*schi + 2*a*b*sphichi
				        - 2*a*szphi - 2*b*szchi + constPart;
			}
			fits[countN] = bestFit(A, B, J, nTth, MaxThre, alphaVals);
			JN[countN] = fits[countN].J;
			countN++;
			gsl_combination_next(comb);
		}
		gsl_combination_free(comb);
		comb = NULL;
	}

	/* 5. Overall winner across all candidate sets. */
	idx = argminFloat(JN, countN);
	*aMin  = fits[idx].a;
	*bMin  = fits[idx].b;
	*JMin  = fits[idx].J;
	*tOpt  = fits[idx].T + 1;
	*thOpt = fits[idx].threshold;
	*minN  = (unsigned int)idx + 1;

done:
	if (comb != NULL)
		gsl_combination_free(comb);
	free(wz);
	free(perClass);
	free(stats);
	free(scratch);
	free(JN);
	free(fits);
	free(inSet);
}

/* The gateway routine */
void mexFunction(int nlhs, mxArray *plhs[],
                 int nrhs, const mxArray *prhs[])
{
  float *x,*y, *w,*z,*betaVals,*alphaVals,*a,*b,*J,*wz;
  int Tmax;
  float *k, *TH;
  int Nsamples, MaxThre, Nclass, size1[2],size2[2];
	unsigned int *N,*T;	
////printf("tou aqui1");
  //GET IMPUT VARIABLES
  /* Get the scalar input Tmax & Create pointers to 
  the input data - x,w,z, beta, alpha...*/
  Tmax = (int)mxGetScalar(prhs[0]);  
  x = mxGetPr(prhs[1]);
  w = mxGetPr(prhs[2]);
  z = mxGetPr(prhs[3]);
  betaVals = mxGetPr(prhs[4]);
  alphaVals = mxGetPr(prhs[5]);
  y = mxGetPr(prhs[6]);
////printf("tou aqui1");
  /* Get the dimensions of the data input. */
  Nclass = mxGetM(prhs[3]);
  Nsamples = mxGetN(prhs[1]);
  MaxThre = mxGetN(prhs[4]);
  size1[0]=1;
  size1[1]=1;
  size2[0]=Nclass;
  size2[1]=1;
  ////printf("tou aqui1");
  /* Set the output pointer to the output matrixs. */
  plhs[0] = mxCreateNumericArray(2,size1, mxSINGLE_CLASS, 0);//a
  plhs[1] = mxCreateNumericArray(2,size1, mxSINGLE_CLASS, 0);//b
  plhs[2] = mxCreateNumericArray(2,size1, mxSINGLE_CLASS, 0);//J
  plhs[3] = mxCreateNumericArray(2,size2, mxSINGLE_CLASS, 0);//k
  plhs[4] = mxCreateNumericArray(2,size1, mxUINT64_CLASS, 0);//N
  plhs[5] = mxCreateNumericArray(2,size1, mxUINT64_CLASS, 0);//T
  plhs[6] = mxCreateNumericArray(2,size1, mxSINGLE_CLASS, 0);//T
  ////printf("tou aqui1");
  /* Create a C pointer to a copy of the output matrix. */

  a			= mxGetPr(plhs[0]);
  b			= mxGetPr(plhs[1]);
  J			= mxGetPr(plhs[2]);
  k			= mxGetPr(plhs[3]);
  N			= mxGetPr(plhs[4]);
  T			= mxGetPr(plhs[5]);
  TH		= mxGetPr(plhs[6]);
  ////printf("tou aqui1");
  /* Call the C subroutine. */
  cicloPhiChi(x,w,z,betaVals,alphaVals,Nsamples,MaxThre,Tmax,Nclass,a,b,J,y,k,N,T,TH);
  //printf("cN = %u\n",*N);
}