/****************************************************************************************
* mexschur.c : C mex file to compute  
*          
*    mexschur(blk,Avec,nzlistA1,nzlistA2,permA,U,V,colend,type,schur);  
*
*    schur(I,J) = schur(I,J) + Trace(Ai U Aj V),
*    where I = permA[i], J = permA[j],   1<=i,j<=colend. 
* 
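*   input:  blk  = a 1x2 cell array describing the block structure of A.
*           Avec = sparse array whose stored values are the entries of the Ai's;
*                  the k-th stored value pairs with the k-th row of nzlistA2.
*           nzlistA1 = vector of 0-based start offsets: the entries of Ai are
*                  those with offsets nzlistA1(i) to nzlistA1(i+1)-1.
*           nzlistA2 = 2-column array of 1-based (row,col) positions, row <= col.
*           permA = a permutation vector.  
*           U,V  = real symmetric matrices.
*           type = 0, compute Trace(Ai*(U Aj V + V Aj U)/2) = Trace(Ai*(U Aj V))
*                = 1, compute Trace(Ai*(U Aj U)).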
*
* SDPT3: version 3.0
* Copyright (c) 1997 by
* K.C. Toh, M.J. Todd, R.H. Tutuncu
* Last Modified: 2 Feb 01   
****************************************************************************************/

#include <math.h>
#include <mex.h>

/**********************************************************
*  schurij1: compute  Trace(B U*A*U)
*
*  A,B are assumed to be real,sparse,symmetric.
*  U  is assumed to be real,dense,symmetric. 
*
*  A is read from entries idxstart[i] .. idxstart[i+1]-1 and
*  B from entries idxstart[j] .. idxstart[j+1]-1 of Avec/nzlistA.
*  nzlistA[p] is the 1-based row and nzlistA[p+len] the 1-based
*  column of entry p.  U is addressed as U[row + col*n].
*  Entries of A with row < col are weighted by ir2, entries of B
*  with row < col are weighted by r2.
*  Calls mexErrMsgTxt when an entry of B has row > col.
**********************************************************/
double schurij1( int n, double r2, double ir2, 
                 double *Avec, double *idxstart, double *nzlistA, int len,
                 double *U, int i, int j)
{
  const int posA0 = (int)idxstart[i], posA1 = (int)idxstart[i+1];
  const int posB0 = (int)idxstart[j], posB1 = (int)idxstart[j+1];
  double accOff  = 0;   /* sum over entries of B with row <  col */
  double accDiag = 0;   /* sum over entries of B with row == col */
  int pB;

  for (pB = posB0; pB < posB1; ++pB) {
      const int rowB = (int)nzlistA[pB] - 1;
      const int colB = (int)nzlistA[pB+len] - 1;
      double sumOff = 0, sumDiag = 0;
      int offR, offC, pA;

      if (rowB > colB) { mexErrMsgTxt("mexschur: nzlistA2 is incorrect"); }
      offR = rowB*n;    /* start of column rowB of U */
      offC = colB*n;    /* start of column colB of U */

      for (pA = posA0; pA < posA1; ++pA) {
          const int rowA = (int)nzlistA[pA] - 1;
          const int colA = (int)nzlistA[pA+len] - 1;

          if (rowA >= colA) {
             sumDiag += Avec[pA] * U[rowA+offR]*U[colA+offC];
          }
          else {
             sumOff += Avec[pA] * (U[rowA+offR]*U[colA+offC]+U[rowA+offC]*U[colA+offR]);
          }
      }

      if (rowB < colB) { accOff  += Avec[pB]*(ir2*sumOff + sumDiag); }
      else             { accDiag += Avec[pB]*(ir2*sumOff + sumDiag); }
  }
  return r2*accOff+accDiag;
}
/**********************************************************
*  schurij3: compute  Trace(B (U*A*V + V*A*U)/2) = Trace(B U*A*V)
*
*  A,B are assumed to be real,sparse,symmetric.
*  U,V are assumed to be real,dense,symmetric. 
*
*  A is read from entries idxstart[i] .. idxstart[i+1]-1 and
*  B from entries idxstart[j] .. idxstart[j+1]-1 of Avec/nzlistA.
*  nzlistA[p] is the 1-based row and nzlistA[p+len] the 1-based
*  column of entry p.  U and V are addressed as X[row + col*n].
*  r2 is accepted but not used by this routine.
*  Calls mexErrMsgTxt when an entry of B has row > col.
**********************************************************/
double schurij3( int n, double r2, double ir2, 
                double *Avec, double *idxstart, double *nzlistA, int len,
                double *U, double *V, int i, int j)
{
  const int posA0 = (int)idxstart[i], posA1 = (int)idxstart[i+1];
  const int posB0 = (int)idxstart[j], posB1 = (int)idxstart[j+1];
  double accOff  = 0;   /* sum over entries of B with row <  col */
  double accDiag = 0;   /* sum over entries of B with row == col */
  int pB;

  for (pB = posB0; pB < posB1; ++pB) {
      const int rowB = (int)nzlistA[pB] - 1;
      const int colB = (int)nzlistA[pB+len] - 1;
      double sumOff = 0, sumDiag = 0;
      int offR, offC, pA;

      if (rowB > colB) { mexErrMsgTxt("mexschur: nzlistA2 is incorrect"); }
      offR = rowB*n;
      offC = colB*n;

      for (pA = posA0; pA < posA1; ++pA) {
          const int rowA = (int)nzlistA[pA] - 1;
          const int colA = (int)nzlistA[pA+len] - 1;
          const int rr = rowA+offR;      /* (rowA,rowB) */
          const int cc = colA+offC;      /* (colA,colB) */

          if (rowA < colA) {
             const int rc = rowA+offC;   /* (rowA,colB) */
             const int cr = colA+offR;   /* (colA,rowB) */
             sumOff += Avec[pA] *(U[rr]*V[cc]+U[cc]*V[rr]
                                  +U[rc]*V[cr]+U[cr]*V[rc]);
          }
          else {
             sumDiag += Avec[pA] * (U[rr]*V[cc]+U[cc]*V[rr]);
          }
      }

      if (rowB < colB) { accOff  += Avec[pB]*(ir2*sumOff+sumDiag); }
      else             { accDiag += Avec[pB]*(ir2*sumOff+sumDiag); }
  }
  return ir2*accOff+accDiag/2;
}
/**********************************************************
*  vec: scatter the stored entries of a compressed-column
*  matrix (A, irA, jcA) into the packed array B.
*
*  Block b covers columns cumblksize[b] .. cumblksize[b+1]-1
*  and is written column by column starting at B[blknnz[b]];
*  an entry in row r, column c of that block lands at
*     B[ blknnz[b] + (c-first)*dim + (r-first) ],
*  where first = cumblksize[b] and dim is the block size.
*  Positions of B that receive no entry are left untouched.
*  r2 is accepted but not used.
**********************************************************/
void vec(int numblk, int *cumblksize, int *blknnz, 
         double r2, double *A, int *irA, int *jcA, double *B) 
{
   int b;

   for (b = 0; b < numblk; b++) {
       const int colLo = cumblksize[b];
       const int colHi = cumblksize[b+1];
       const int dim   = colHi - colLo;
       const int base  = blknnz[b];
       int c;

       for (c = colLo; c < colHi; c++) {
           const int colOff = base + (c - colLo)*dim;
           const int pEnd   = jcA[c+1];
           int p;

           for (p = jcA[c]; p < pEnd; p++) {
               B[colOff + irA[p] - colLo] = A[p];
           }
       }
   }
}
/**********************************************************
*  schurij2: compute  Trace(B U*A*U)
*
*  A,B are assumed to be real,sparse,symmetric.
*  U  is assumed to be real,sparse,symmetric. 
*
*  Utmp is the packed block copy of U: block b starts at
*  Utmp[blknnz[b]], covers indices cumblksize[b] ..
*  cumblksize[b+1]-1 and is stored column by column.
*  blkidx[c] is the block that index c belongs to.
*  Only entries of A whose column lies in the same block as the
*  current entry of B contribute.  The scan over A stops at the
*  first entry whose column block is larger, and the next entry
*  of B resumes at the first contributing entry found so far.
*  numblk is accepted but not used.
*  Calls mexErrMsgTxt when an entry of B has row > col.
**********************************************************/
double schurij2( double r2, double ir2, double *Avec, 
                 double *idxstart, double *nzlistA, int len, double *Utmp, 
                 int numblk, int *cumblksize, int *blknnz, int *blkidx, int i, int j)
{
  const int posB0 = (int)idxstart[j], posB1 = (int)idxstart[j+1];
  const int posA1 = (int)idxstart[i+1];
  int    scanFrom = (int)idxstart[i];   /* where the scan over A starts */
  double accOff  = 0;   /* sum over entries of B with row <  col */
  double accDiag = 0;   /* sum over entries of B with row == col */
  int    pB;

  for (pB = posB0; pB < posB1; ++pB) {
      const int rowB = (int)nzlistA[pB] - 1;
      const int colB = (int)nzlistA[pB+len] - 1;
      double sumOff = 0, sumDiag = 0;
      int blkB, blkLo, blkDim, baseR, baseC, firstHit, pA;

      if (rowB > colB) { mexErrMsgTxt("mexschur: nzlistA2 is incorrect"); }
      blkB   = blkidx[colB];
      blkLo  = cumblksize[blkB];
      blkDim = cumblksize[blkB+1] - blkLo;
      baseC  = blknnz[blkB] + (colB-blkLo)*blkDim;   /* packed column colB */
      baseR  = blknnz[blkB] + (rowB-blkLo)*blkDim;   /* packed column rowB */

      firstHit = posA1;   /* posA1 means "no entry of A in this block yet" */
      for (pA = scanFrom; pA < posA1; ++pA) {
          const int rowA = (int)nzlistA[pA] - 1;
          const int colA = (int)nzlistA[pA+len] - 1;
          const int blkA = blkidx[colA];
          int rr, cc, rc, cr;

          if (blkA > blkB) { break; }
          if (blkA < blkB) { continue; }

          rr = (rowA-blkLo)+baseR;
          cc = (colA-blkLo)+baseC;
          if (rowA < colA) {
             rc = (rowA-blkLo)+baseC;
             cr = (colA-blkLo)+baseR;
             sumOff += Avec[pA] * (Utmp[rr]*Utmp[cc]+Utmp[rc]*Utmp[cr]);
          }
          else {
             sumDiag += Avec[pA] * Utmp[rr]*Utmp[cc];
          }
          if (pA < firstHit) { firstHit = pA; }
      }
      if (firstHit < posA1) { scanFrom = firstHit; }

      if (rowB < colB) { accOff  += Avec[pB]*(ir2*sumOff + sumDiag); }
      else             { accDiag += Avec[pB]*(ir2*sumOff + sumDiag); }
  }
  return r2*accOff+accDiag;
}
/**********************************************************
*  schurij4: compute  Trace(B (U*A*V + V*A*U)/2) = Trace(B U*A*V)
*
*  A,B are assumed to be real,sparse,symmetric.
*  U,V are assumed to be real,sparse,symmetric. 
*
*  Utmp and Vtmp are the packed block copies of U and V: block b
*  starts at offset blknnz[b], covers indices cumblksize[b] ..
*  cumblksize[b+1]-1 and is stored column by column.
*  blkidx[c] is the block that index c belongs to.
*  Only entries of A whose column lies in the same block as the
*  current entry of B contribute.  The scan over A stops at the
*  first entry whose column block is larger, and the next entry
*  of B resumes at the first contributing entry found so far.
*  r2 and numblk are accepted but not used.
*  Calls mexErrMsgTxt when an entry of B has row > col.
**********************************************************/
double schurij4( double r2, double ir2, double *Avec, 
                 double *idxstart, double *nzlistA, int len,
                 double *Utmp, double *Vtmp, 
                 int numblk, int *cumblksize, int *blknnz, int *blkidx, int i, int j)
{
  const int posB0 = (int)idxstart[j], posB1 = (int)idxstart[j+1];
  const int posA1 = (int)idxstart[i+1];
  int    scanFrom = (int)idxstart[i];   /* where the scan over A starts */
  double accOff  = 0;   /* sum over entries of B with row <  col */
  double accDiag = 0;   /* sum over entries of B with row == col */
  int    pB;

  for (pB = posB0; pB < posB1; ++pB) {
      const int rowB = (int)nzlistA[pB] - 1;
      const int colB = (int)nzlistA[pB+len] - 1;
      double sumOff = 0, sumDiag = 0;
      int blkB, blkLo, blkDim, baseR, baseC, firstHit, pA;

      if (rowB > colB) { mexErrMsgTxt("mexschur: nzlistA2 is incorrect"); }
      blkB   = blkidx[colB];
      blkLo  = cumblksize[blkB];
      blkDim = cumblksize[blkB+1] - blkLo;
      baseC  = blknnz[blkB] + (colB-blkLo)*blkDim;   /* packed column colB */
      baseR  = blknnz[blkB] + (rowB-blkLo)*blkDim;   /* packed column rowB */

      firstHit = posA1;   /* posA1 means "no entry of A in this block yet" */
      for (pA = scanFrom; pA < posA1; ++pA) {
          const int rowA = (int)nzlistA[pA] - 1;
          const int colA = (int)nzlistA[pA+len] - 1;
          const int blkA = blkidx[colA];
          int rr, cc, rc, cr;

          if (blkA > blkB) { break; }
          if (blkA < blkB) { continue; }

          rr = (rowA-blkLo)+baseR;
          cc = (colA-blkLo)+baseC;
          if (rowA < colA) {
             rc = (rowA-blkLo)+baseC;
             cr = (colA-blkLo)+baseR;
             sumOff += Avec[pA] * (Utmp[rr]*Vtmp[cc] +Utmp[cc]*Vtmp[rr]
                                   +Utmp[rc]*Vtmp[cr] +Utmp[cr]*Vtmp[rc]);
          }
          else {
             sumDiag += Avec[pA] * (Utmp[rr]*Vtmp[cc] +Utmp[cc]*Vtmp[rr]);
          }
          if (pA < firstHit) { firstHit = pA; }
      }
      if (firstHit < posA1) { scanFrom = firstHit; }

      if (rowB < colB) { accOff  += Avec[pB]*(ir2*sumOff + sumDiag); }
      else             { accDiag += Avec[pB]*(ir2*sumOff + sumDiag); }
  }
  return ir2*accOff+accDiag/2;
}
/**********************************************************
*  mexFunction: MATLAB gateway for
*
*    mexschur(blk,Avec,nzlistA1,nzlistA2,permA,U,V,colend,type,schur)
*
*  prhs[0] = blk      cell array; element (1,2) holds the block sizes
*  prhs[1] = Avec     must be sparse; its stored values are used
*  prhs[2] = nzlistA1 start offsets (passed on as idxstart)
*  prhs[3] = nzlistA2 row/column positions (passed on as nzlistA)
*  prhs[4] = permA    1-based indices into schur; its column count
*                     must equal the row count of schur
*  prhs[5] = U, prhs[6] = V   both dense or both sparse
*  prhs[7] = colend   number of leading entries of permA to process
*  prhs[8] = type     0 or 1
*  prhs[9] = schur    updated IN PLACE
*
*  For 0 <= row <= col < colend the value of schurij1..4(row,col)
*  is added to schur(permA[row],permA[col]) and the result is
*  copied to schur(permA[col],permA[row]).
*  nlhs and plhs are not used.
**********************************************************/
void mexFunction(
     int nlhs,   mxArray  *plhs[], 
     int nrhs,   const mxArray  *prhs[] )
{    
     mxArray  *blk_cell_pr;  
     double   *Avec, *idxstart, *nzlistA, *permA, *U, *V, *schur;
     double   *blksize, *Utmp = NULL, *Vtmp = NULL;  
     /* NOTE(review): current MATLAB APIs declare mxGetIr/mxGetJc as
        returning mwIndex*; vec() consumes them as int*.  Confirm the
        build uses an index type of the same width as int. */
     int      *irU = NULL, *jcU = NULL, *irV = NULL, *jcV = NULL;
     int      *colm = NULL, *cumblksize = NULL, *blknnz = NULL, *blkidx = NULL; 

     int      subs[2];
     int      nsubs=2; 
     int      index, colend, type, isspU, isspV, numblk; 
     int      len, row, col, nU, nV, n, m, m1, idx1, idx2, k, nsub, n1, n2, opt;
     int      blk; 
     double   r2, ir2, tmp; 

/* CHECK THE DIMENSIONS */

   if (nrhs != 10) {
      mexErrMsgTxt(" mexschur: must have 10 inputs"); }
   if (!mxIsCell(prhs[0])) {
      mexErrMsgTxt("mexschur: 1ST input must be the cell array blk"); }  
    subs[0] = 0; 
    subs[1] = 1;
    index = mxCalcSingleSubscript(prhs[0],nsubs,subs); 
    blk_cell_pr = mxGetCell(prhs[0],index); 
    if (blk_cell_pr == NULL) {
       /* without this check mxGetN/mxGetPr below would receive NULL */
       mexErrMsgTxt("mexschur: blk{1,2} is missing or empty"); }
    numblk  = mxGetN(blk_cell_pr);
    blksize = mxGetPr(blk_cell_pr); 

/**** get pointers ****/    

    Avec = mxGetPr(prhs[1]); 
    if (!mxIsSparse(prhs[1])) { 
       mexErrMsgTxt("mexschur: Avec must be sparse"); }
    idxstart = mxGetPr(prhs[2]);  
    nzlistA = mxGetPr(prhs[3]); 
    len = mxGetM(prhs[3]);

    permA = mxGetPr(prhs[4]); 
    m1 = mxGetN(prhs[4]); 

    U = mxGetPr(prhs[5]);  nU = mxGetM(prhs[5]); 
    isspU = mxIsSparse(prhs[5]); 
    if (isspU) { irU = mxGetIr(prhs[5]); jcU = mxGetJc(prhs[5]); }
    V = mxGetPr(prhs[6]);  nV = mxGetM(prhs[6]); 
    isspV = mxIsSparse(prhs[6]);
    if (isspV) { irV = mxGetIr(prhs[6]); jcV = mxGetJc(prhs[6]); }
    if ((isspU && !isspV) || (!isspU && isspV)) { 
       mexErrMsgTxt("mexschur: U,V must be both dense or both sparse"); 
    }
    colend = (int)*mxGetPr(prhs[7]); 
    type   = (int)*mxGetPr(prhs[8]); 

    schur = mxGetPr(prhs[9]); 
    m = mxGetM(prhs[9]);    
    if (m!= m1) {
       mexErrMsgTxt("mexschur: schur and permA are not compatible"); }
    /* any other type would leave the kernel selection below undefined */
    if (type != 0 && type != 1) {
       mexErrMsgTxt("mexschur: type must be 0 or 1"); }
    /* permA[0..colend-1] is read below, so colend must fit inside permA */
    if (colend < 0 || colend > m1) {
       mexErrMsgTxt("mexschur: colend and permA are not compatible"); }
    
/************************************
* initialization 
************************************/
    /* set the scaling constants before anything receives them */
    r2 = sqrt(2);  ir2 = 1/r2; 

    /* From here on isspU and isspV agree (checked above). */
    if (isspU && isspV) { 
       /* cumblksize[k] = first index of block k,
          blknnz[k]     = offset of block k in the packed copies */
       cumblksize = mxCalloc(numblk+1,sizeof(int)); 
       blknnz = mxCalloc(numblk+1,sizeof(int)); 
       cumblksize[0] = 0; blknnz[0] = 0; 
       n1 = 0; n2 = 0; 
       for (k=0; k<numblk; ++k) {
           nsub = (int) blksize[k];
           n1  += nsub;  
           n2 += nsub*nsub;  
           cumblksize[k+1] = n1; 
           blknnz[k+1] = n2;  }
       if (nU != n1 || nV != n1) { 
          mexErrMsgTxt("mexschur: blk and dimension of U not compatible"); }
    }
    if (isspU) { 
       Utmp = mxCalloc(n2,sizeof(double)); 
       vec(numblk,cumblksize,blknnz,r2,U,irU,jcU,Utmp); }
    if (isspV && type == 0) { 
       /* V is only needed by the type 0 kernel */
       Vtmp = mxCalloc(n2,sizeof(double)); 
       vec(numblk,cumblksize,blknnz,r2,V,irV,jcV,Vtmp); }

    /* colm[k] = offset of column permA[k] in the m-row array schur */
    colm = mxCalloc(colend,sizeof(int));     
    for (k=0; k<colend; k++) { colm[k] = (permA[k]-1)*m; } 

    if (isspU) {
       /* blkidx[k] = block containing index k.  The while loop (rather
          than a single step) also skips over zero-sized blocks; it
          cannot run past numblk because nU == cumblksize[numblk]. */
       blkidx = mxCalloc(nU,sizeof(int));
       blk = 0; 
       for (k=0; k<nU; k++) {
           while (k>=cumblksize[blk+1]) { blk++; }  
           blkidx[k] = blk; }
    }
/************************************
* compute schur(i,j)
************************************/
    
    n = nU; 
    /* 1: dense/type 1, 2: sparse/type 1, 3: dense/type 0, 4: sparse/type 0 */
    if (type==1) { opt = isspU ? 2 : 1; } 
    else         { opt = isspU ? 4 : 3; }
    /*************************************/
    for (col=0; col<colend; col++) { 
        for (row=0; row<=col; row++) { 
            if (opt==1) { 
               tmp = schurij1(n,r2,ir2,Avec,idxstart,nzlistA,len,U,row,col); }
            else if (opt==2) {
               tmp = schurij2(r2,ir2,Avec,idxstart,nzlistA,len,Utmp,
                              numblk,cumblksize,blknnz,blkidx,row,col); }
            else if (opt==3) {
               tmp = schurij3(n,r2,ir2,Avec,idxstart,nzlistA,len,U,V,row,col); }
            else { 
               tmp = schurij4(r2,ir2,Avec,idxstart,nzlistA,len,Utmp,Vtmp,
                              numblk,cumblksize,blknnz,blkidx,row,col); 
            }
            idx1 = (permA[row]-1)+colm[col]; /*subtract 1 to adjust for matlab index */
            idx2 = (permA[col]-1)+colm[row]; 
            schur[idx1] += tmp;
            schur[idx2] = schur[idx1];  /* keep schur symmetric */
        }
    }

    /* release every work array allocated above */
    mxFree(colm); 
    if (isspU) { 
       mxFree(blkidx); 
       mxFree(Utmp); 
       mxFree(blknnz); 
       mxFree(cumblksize); } 
    if (isspV && type==0) { mxFree(Vtmp); } 
return;
}
/**********************************************************/







