#include "DOFVector.h"

#include <dune/istl/bvector.hh>
#include <dune/istl/bcrsmatrix.hh>
#include <dune/istl/ilu.hh>
#include <dune/istl/operators.hh>
#include <dune/istl/solvers.hh>
#include <dune/istl/preconditioners.hh>
#include <dune/istl/matrix.hh>

namespace AMDiS {

  /// Constructs the DUNE solver wrapper and reads its options from the
  /// parameter file. All options live below the key "<name>->dune ...":
  ///   verbose, solver, precon, precon->relaxation, precon->iterations.
  /// Solver and preconditioner type are mandatory; an empty value for either
  /// triggers TEST_EXIT.
  template<typename VectorType>
  DuneSolver<VectorType>::DuneSolver(std::string name)
    : OEMSolver<VectorType>(name),
      verbose_(1),
      solverType_(""),
      preconType_(""),
      preconRelaxation_(1.0),
      preconIterations_(1)
  {
    const std::string dunePrefix = name + "->dune ";
    const std::string preconPrefix = dunePrefix + "precon->";

    // General DUNE options.
    GET_PARAMETER(0, dunePrefix + "verbose", "%d", &verbose_);
    GET_PARAMETER(0, dunePrefix + "solver", &solverType_);
    GET_PARAMETER(0, dunePrefix + "precon", &preconType_);

    // Preconditioner tuning (relaxation factor, sweep count / fill level).
    GET_PARAMETER(0, preconPrefix + "relaxation", "%f", &preconRelaxation_);
    GET_PARAMETER(0, preconPrefix + "iterations", "%d", &preconIterations_);

    TEST_EXIT(!solverType_.empty())("No DUNE solver chosen!\n");
    TEST_EXIT(!preconType_.empty())("No DUNE preconditioner chosen!\n");
  }


  /// Destructor.
  template<typename VectorType>
  DuneSolver<VectorType>::~DuneSolver()
  {
    // Intentionally empty: nothing is released explicitly at this level.
  }


  /// Creates the scalar (non-system) preconditioner selected by the
  /// "->dune precon" parameter.
  ///
  /// Recognised values: "ilu0", "ilun", "jac", "gs", "sor", "ssor".
  /// preconRelaxation_ is passed as the relaxation factor; preconIterations_
  /// is passed as the sweep count (for "ilun" it is the ILU fill-in level).
  ///
  /// @param matrix  Matrix the preconditioner is built on (dereferenced).
  /// @return        A preconditioner allocated with `new`; the caller is
  ///                responsible for deleting it. An unknown type triggers
  ///                ERROR_EXIT.
  template<typename VectorType>
  template<typename M, typename V>
  Dune::Preconditioner<V, V>* DuneSolver<VectorType>::getPreconditioner(M *matrix)
  {
    M &A = *matrix;

    if (preconType_ == "ilu0")
      return new Dune::SeqILU0<M, V, V>(A, preconRelaxation_);

    if (preconType_ == "ilun")
      return new Dune::SeqILUn<M, V, V>(A, preconIterations_, preconRelaxation_);

    if (preconType_ == "jac")
      return new Dune::SeqJac<M, V, V>(A, preconIterations_, preconRelaxation_);

    if (preconType_ == "gs")
      return new Dune::SeqGS<M, V, V>(A, preconIterations_, preconRelaxation_);

    if (preconType_ == "sor")
      return new Dune::SeqSOR<M, V, V>(A, preconIterations_, preconRelaxation_);

    if (preconType_ == "ssor")
      return new Dune::SeqSSOR<M, V, V>(A, preconIterations_, preconRelaxation_);

    ERROR_EXIT("Wrong DUNE preconditioner type!\n");

    // Only reached if ERROR_EXIT returns; keeps the compiler happy.
    return NULL;
  }


  /// Creates the preconditioner for a blocked system matrix, selected by the
  /// "->dune precon" parameter.
  ///
  /// Only the relaxation-type preconditioners ("jac", "gs", "sor", "ssor")
  /// are offered; they are instantiated with block recursion level 2.
  /// "ilu0" and "ilun" are rejected via ERROR_EXIT, as is any unknown type.
  ///
  /// @param matrix  Block matrix the preconditioner is built on (dereferenced).
  /// @return        A preconditioner allocated with `new`; the caller is
  ///                responsible for deleting it.
  template<typename VectorType>
  template<typename M, typename V>
  Dune::Preconditioner<V, V>* DuneSolver<VectorType>::getSystemPreconditioner(M *matrix)
  {
    M &A = *matrix;

    // Supported relaxation preconditioners (recursion level 2).
    if (preconType_ == "jac")
      return new Dune::SeqJac<M, V, V, 2>(A, preconIterations_, preconRelaxation_);

    if (preconType_ == "gs")
      return new Dune::SeqGS<M, V, V, 2>(A, preconIterations_, preconRelaxation_);

    if (preconType_ == "sor")
      return new Dune::SeqSOR<M, V, V, 2>(A, preconIterations_, preconRelaxation_);

    if (preconType_ == "ssor")
      return new Dune::SeqSSOR<M, V, V, 2>(A, preconIterations_, preconRelaxation_);

    // Known but unsupported types get a dedicated message first.
    if (preconType_ == "ilu0") {
      ERROR_EXIT("ILU0 is not supported for systems!\n");
    } else if (preconType_ == "ilun") {
      ERROR_EXIT("ILUN is not supported for systems!\n");
    }

    ERROR_EXIT("Wrong DUNE preconditioner type!\n");

    // Only reached if ERROR_EXIT returns; keeps the compiler happy.
    return NULL;
  }


  /// Creates the Krylov/iterative solver selected by the "->dune solver"
  /// parameter ("bicgstab", "cg", "gradient" or "loop").
  ///
  /// this->tolerance is handed to Dune as the `reduction` argument (relative
  /// defect reduction) and this->max_iter as the iteration limit; verbose_
  /// controls Dune's output level.
  ///
  /// @param ma      Linear operator wrapping the system matrix (dereferenced).
  /// @param precon  Preconditioner to use (dereferenced). Both are passed by
  ///                reference; presumably they must outlive the solver.
  /// @return        A solver allocated with `new`; the caller is responsible
  ///                for deleting it. An unknown type triggers ERROR_EXIT.
  template<typename VectorType>
  template<typename M, typename V>
  Dune::InverseOperator<V, V>* DuneSolver<VectorType>::getSolver(Dune::MatrixAdapter<M, V, V>* ma,
                                                                 Dune::Preconditioner<V, V>* precon)
  {
    Dune::MatrixAdapter<M, V, V> &op = *ma;
    Dune::Preconditioner<V, V> &prec = *precon;

    if (solverType_ == "bicgstab")
      return new Dune::BiCGSTABSolver<V>(op, prec, this->tolerance, this->max_iter, verbose_);

    if (solverType_ == "cg")
      return new Dune::CGSolver<V>(op, prec, this->tolerance, this->max_iter, verbose_);

    if (solverType_ == "gradient")
      return new Dune::GradientSolver<V>(op, prec, this->tolerance, this->max_iter, verbose_);

    if (solverType_ == "loop")
      return new Dune::LoopSolver<V>(op, prec, this->tolerance, this->max_iter, verbose_);

    ERROR_EXIT("Wrong DUNE solver type!\n");

    // Only reached if ERROR_EXIT returns; keeps the compiler happy.
    return NULL;
  }
 

  /// Copies an AMDiS DOFMatrix into a DUNE sparse matrix in three passes:
  ///   1. announce the number of entries of every row (setrowsize),
  ///   2. insert the column pattern (addindex),
  ///   3. copy the entry values.
  /// The setrowsize/addindex sequence assumes duneMatrix was created in
  /// Dune's "random" build mode and is already sized — TODO confirm at caller.
  ///
  /// Row entries whose column index is negative are skipped in all passes.
  /// Such slots are presumably DOFMatrix placeholders for unused or
  /// end-of-row entries (confirm against DOFMatrix). Passing them to
  /// addindex/operator[] would turn them into huge unsigned column indices
  /// and corrupt the sparsity pattern. If no such entries exist, skipping
  /// them changes nothing.
  ///
  /// NOTE(review): rows are numbered densely over USED_DOFS (rowIndex), while
  /// columns use the raw `col` value. The two agree only if the used DOFs are
  /// contiguous from 0 (e.g. after DOF compression) — verify at caller.
  template<typename VectorType>
  void DuneSolver<VectorType>::mapDOFMatrix(DOFMatrix *dofMatrix, DuneMatrix *duneMatrix)
  {
    DOFMatrix::Iterator matrixRowIt(dofMatrix, USED_DOFS);

    // Pass 1: row sizes, counting only entries with a valid column index.
    int rowIndex = 0;
    for (matrixRowIt.reset(); !matrixRowIt.end(); ++matrixRowIt, rowIndex++) {
      int nCols = static_cast<int>((*matrixRowIt).size());
      int nValid = 0;
      for (int i = 0; i < nCols; i++) {
        if ((*matrixRowIt)[i].col >= 0)
          nValid++;
      }
      duneMatrix->setrowsize(rowIndex, nValid);
    }

    duneMatrix->endrowsizes();

    // Pass 2: column pattern.
    rowIndex = 0;
    for (matrixRowIt.reset(); !matrixRowIt.end(); ++matrixRowIt, rowIndex++) {
      int nCols = static_cast<int>((*matrixRowIt).size());
      for (int i = 0; i < nCols; i++) {
        if ((*matrixRowIt)[i].col < 0)
          continue;
        duneMatrix->addindex(rowIndex, (*matrixRowIt)[i].col);
      }
    }

    duneMatrix->endindices();

    // Pass 3: values; the DUNE row iterator advances in lock-step with the
    // DOF row iterator.
    DuneMatrix::Iterator duneMatIt = duneMatrix->begin();
    for (matrixRowIt.reset(); !matrixRowIt.end(); ++matrixRowIt, ++duneMatIt) {
      int nCols = static_cast<int>((*matrixRowIt).size());
      for (int i = 0; i < nCols; i++) {
        if ((*matrixRowIt)[i].col < 0)
          continue;
        (*duneMatIt)[(*matrixRowIt)[i].col] = (*matrixRowIt)[i].entry;
      }
    }
  }

  /// Copies the used DOFs of an AMDiS DOFVector, in iteration order, into
  /// consecutive entries of a DUNE vector (starting at its first entry).
  ///
  /// The DUNE vector must have at least as many entries as there are used
  /// DOFs. The original loop wrote past the end of the DUNE vector's storage
  /// when it was too small; that case is now rejected via TEST_EXIT.
  template<typename VectorType>
  void DuneSolver<VectorType>::mapDOFVector(DOFVector<double> *dofVector, DuneVector *duneVector)
  {
    DuneVector::Iterator duneVecIt = duneVector->begin();
    DuneVector::Iterator duneVecEnd = duneVector->end();
    DOFVector<double>::Iterator dofVecIt(dofVector, USED_DOFS);
    for (dofVecIt.reset(); !dofVecIt.end(); ++dofVecIt, ++duneVecIt) {
      TEST_EXIT(duneVecIt != duneVecEnd)("DUNE vector is too small for the DOF vector!\n");
      *duneVecIt = *dofVecIt;
    }
  }
}