#include "PardisoSolver.h"

#include <algorithm>
#include <vector>

#include "SystemVector.h"
#include "MatVecMultiplier.h"

#ifdef HAVE_MKL

#include <mkl.h>
#include <mkl_pardiso.h>

namespace AMDiS {

  // Solves the block system described by \p matVec with the Intel MKL PARDISO
  // direct solver and stores the result in \p x.
  //
  // The matrix of DOF matrices is flattened into one sparse matrix in
  // compressed row format with 1-based indices. Unknowns are interleaved:
  // the k-th visited DOF of component c is mapped to index
  // k * nComponents + c. The right hand side \p b is reordered the same way
  // and the solution is scattered back into the component vectors of \p x.
  //
  // \param matVec       Must be a StandardMatVec<Matrix<DOFMatrix*>, SystemVector>;
  //                     the function exits via TEST_EXIT otherwise.
  // \param x            Receives the solution; must have as many components as \p b.
  // \param b            Right hand side.
  // \param reuseMatrix  Currently ignored: the matrix is assembled and
  //                     factorized anew on every call.
  // \return Always 1.
  //
  // Side effects: overwrites the members p and r and stores the norm of the
  // residual b - A*x in this->residual.
  //
  // NOTE(review): the row index is the running count of USED_DOFS rows, whereas
  // the column index is taken directly from the matrix entry. This presumably
  // requires the DOF indices to be compressed (no holes) - confirm with callers.
  // NOTE(review): the size of the composed matrix is derived from block (0, 0)
  // only, i.e. all components are assumed to share its number of DOFs.
  template<>
  int PardisoSolver<SystemVector>::solveSystem(MatVecMultiplier<SystemVector> *matVec,
                                               SystemVector *x, SystemVector *b,
                                               bool reuseMatrix)
  {
    FUNCNAME("PardisoSolver::solveSystem()");

    TEST_EXIT(x->getSize() == b->getSize())("Vectors x and b must have the same size!");

    // Extract the matrix of DOF-matrices. The cast fails (returns 0) for any
    // other kind of matrix-vector multiplier, so it has to be checked.
    typedef StandardMatVec<Matrix<DOFMatrix*>, SystemVector> BlockMatVec;
    BlockMatVec *stdMatVec = dynamic_cast<BlockMatVec*>(matVec);
    TEST_EXIT(stdMatVec)("matVec must be a StandardMatVec over a matrix of DOF matrices!\n");

    Matrix<DOFMatrix*> *m = stdMatVec->getMatrix();
    TEST_EXIT(m)("No system matrix available!\n");

    // Number of systems.
    int nComponents = m->getSize();
    TEST_EXIT(nComponents > 0)("System matrix has no components!\n");
    TEST_EXIT(b->getSize() == nComponents)
      ("Number of vector components does not match the system matrix!\n");

    // Block (0, 0) defines the number of DOFs per component.
    TEST_EXIT((*m)[0][0])("Block (0, 0) of the system matrix must be set!\n");
    int nDofs = ((*m)[0][0])->getFeSpace()->getAdmin()->getUsedSize();

    // Size of the new composed matrix.
    int newMatrixSize = nDofs * nComponents;
    TEST_EXIT(newMatrixSize > 0)("System matrix is empty!\n");

    // The new matrix has to be stored in compressed row format, therefore
    // the rows are collected.
    std::vector< std::vector<MatEntry> > rows(newMatrixSize);

    // Counter for the number of non-zero elements in the new matrix.
    int nElements = 0;

    for (int stencilRow = 0; stencilRow < nComponents; stencilRow++) {
      for (int stencilCol = 0; stencilCol < nComponents; stencilCol++) {
        DOFMatrix *block = (*m)[stencilRow][stencilCol];

        // Unset blocks contribute nothing.
        if (!block) {
          continue;
        }

        DOFMatrix::Iterator matrixRow(block, USED_DOFS);
        int rowIndex = 0;
        for (matrixRow.reset(); !matrixRow.end(); matrixRow++, rowIndex++) {
          // Guards the access to rows below against blocks with more rows
          // than block (0, 0).
          TEST_EXIT(rowIndex < nDofs)
            ("Block (%d, %d) has more rows than block (0, 0)!\n", stencilRow, stencilCol);

          std::vector<MatEntry> &targetRow = rows[(rowIndex * nComponents) + stencilRow];
          int rowLength = static_cast<int>((*matrixRow).size());

          for (int i = 0; i < rowLength; i++) {
            // Entries with a negative column index are skipped.
            if ((*matrixRow)[i].col >= 0) {
              MatEntry me;
              me.entry = (*matrixRow)[i].entry;
              // The col field stores the column index within the composed matrix.
              me.col = ((*matrixRow)[i].col * nComponents) + stencilCol;

              targetRow.push_back(me);
              nElements++;
            }
          }
        }
      }
    }

    TEST_EXIT(nElements > 0)("System matrix has no non-zero entries!\n");

    // Compressed row storage. std::vector releases the memory on every exit
    // path; all vectors are non-empty here, so taking &v[0] below is valid.
    std::vector<double> a(nElements);
    std::vector<MKL_INT> ja(nElements);
    std::vector<MKL_INT> ia(newMatrixSize + 1);
    // Zero initialized, so that entries not visited below are well defined.
    std::vector<double> bvec(newMatrixSize, 0.0);
    std::vector<double> xvec(newMatrixSize, 0.0);

    int elCounter = 0;
    // Pardiso is called with 1-based (Fortran style) indices.
    ia[0] = 1;

    for (int row = 0; row < newMatrixSize; row++) {
      std::vector<MatEntry> &rowEntries = rows[row];

      // Column indices are stored in increasing order within each row.
      std::sort(rowEntries.begin(), rowEntries.end(), CmpMatEntryCol());

      ia[row + 1] = ia[row] + static_cast<MKL_INT>(rowEntries.size());

      for (std::vector<MatEntry>::const_iterator entryIt = rowEntries.begin();
           entryIt != rowEntries.end();
           ++entryIt) {
        a[elCounter] = entryIt->entry;
        ja[elCounter] = static_cast<MKL_INT>(entryIt->col) + 1;
        elCounter++;
      }
    }

    // Resort the right hand side of the linear system.
    for (int i = 0; i < nComponents; i++) {
      DOFVector<double>::Iterator it(b->getDOFVector(i), USED_DOFS);

      int counter = 0;
      for (it.reset(); !it.end(); ++it, counter++) {
        TEST_EXIT(counter < nDofs)
          ("Component %d of b has more DOFs than the system matrix!\n", i);
        bvec[counter * nComponents + i] = *it;
      }
    }

    // real unsymmetric matrix
    MKL_INT mtype = 11;

    // number of right hand sides
    MKL_INT nRhs = 1;

    // Pardiso internal memory, must be zeroed before the first call.
    void *pt[64];
    for (int i = 0; i < 64; i++) {
      pt[i] = 0;
    }

    // Pardiso control parameters
    MKL_INT iparm[64];
    for (int i = 0; i < 64; i++) {
      iparm[i] = 0;
    }

    iparm[0] = 1; // No solver default
    iparm[1] = 2; // Fill-in reordering from METIS
    iparm[2] = mkl_get_max_threads(); // Number of threads
    iparm[7] = 2; // Max numbers of iterative refinement steps
    iparm[9] = 13; // Perturb the pivot elements with 1e-13
    iparm[10] = 1; // Use nonsymmetric permutation and scaling MPS
    iparm[17] = 0; // Output: Number of nonzeros in the factor LU
    iparm[18] = 0; // Output: Mflops for LU factorization

    // Maximum number of numerical factorizations
    MKL_INT maxfct = 1;

    // Which factorization to use
    MKL_INT mnum = 1;

    // Print statistical information in file
    MKL_INT msglvl = 0;

    // Error flag
    MKL_INT error = 0;

    MKL_INT n = newMatrixSize;

    // Dummy arguments for parameters that are not used in a given phase.
    double ddum = 0.0;
    MKL_INT idum = 0;

    // Reordering and symbolic factorization
    MKL_INT phase = 11;

    PARDISO(pt, &maxfct, &mnum, &mtype, &phase, &n, &a[0], &ia[0], &ja[0], &idum, &nRhs,
            iparm, &msglvl, &ddum, &ddum, &error);

    // MKL_INT may be wider than int, hence the explicit cast for %d.
    TEST_EXIT(error == 0)("Intel MKL Pardiso error during symbolic factorization: %d\n",
                          static_cast<int>(error));

    // Numerical factorization
    phase = 22;

    PARDISO(pt, &maxfct, &mnum, &mtype, &phase, &n, &a[0], &ia[0], &ja[0], &idum, &nRhs,
            iparm, &msglvl, &ddum, &ddum, &error);

    TEST_EXIT(error == 0)("Intel MKL Pardiso error during numerical factorization: %d\n",
                          static_cast<int>(error));

    // Back substitution and iterative refinement
    phase = 33;

    PARDISO(pt, &maxfct, &mnum, &mtype, &phase, &n, &a[0], &ia[0], &ja[0], &idum, &nRhs,
            iparm, &msglvl, &bvec[0], &xvec[0], &error);

    TEST_EXIT(error == 0)("Intel MKL Pardiso error during solution: %d\n",
                          static_cast<int>(error));

    // Copy and resort solution.
    for (int i = 0; i < nComponents; i++) {
      DOFVector<double>::Iterator it(x->getDOFVector(i), USED_DOFS);

      int counter = 0;
      for (it.reset(); !it.end(); ++it, counter++) {
        TEST_EXIT(counter < nDofs)
          ("Component %d of x has more DOFs than the system matrix!\n", i);
        *it = xvec[counter * nComponents + i];
      }
    }

    // Calculate and print the residual r = b - A * x.
    *p = *x;
    *p *= -1.0;
    matVec->matVec(NoTranspose, *p, *r);
    *r += *b;

    this->residual = norm(r);

    MSG("Residual: %e\n", this->residual);

    // Termination and release of Pardiso internal memory
    phase = -1;

    PARDISO(pt, &maxfct, &mnum, &mtype, &phase, &n, &a[0], &ia[0], &ja[0], &idum, &nRhs,
            iparm, &msglvl, &ddum, &ddum, &error);

    // The solution is already stored in x, so a failed release is only reported.
    if (error != 0) {
      MSG("Intel MKL Pardiso error during release of memory: %d\n", static_cast<int>(error));
    }

    return(1);
  }
}

#endif // HAVE_MKL