//
// Software License for AMDiS
//
// Copyright (c) 2010 Dresden University of Technology 
// All rights reserved.
// Authors: Simon Vey, Thomas Witkowski et al.
//
// This file is part of AMDiS
//
// See also license.opensource.txt in the distribution.


#include <algorithm>
#include <climits>
#include <cmath>
#include <cstddef>
#include <boost/numeric/mtl/mtl.hpp>
#include "DOFMatrix.h"
#include "QPsiPhi.h"
#include "BasisFunction.h"
#include "Boundary.h"
#include "DOFAdmin.h"
#include "ElInfo.h"
#include "FiniteElemSpace.h"
#include "Mesh.h"
#include "DOFVector.h"
#include "Operator.h"
#include "BoundaryCondition.h"
#include "BoundaryManager.h"
#include "Assembler.h"
#include "Serializer.h"

namespace AMDiS {

  using namespace mtl;

  /// Definition of the class-wide (static) pointer member; starts out null.
  DOFMatrix* DOFMatrix::traversePtr(NULL);

  DOFMatrix::DOFMatrix()
    : rowFeSpace(NULL),
      colFeSpace(NULL),
      elementMatrix(3, 3),
      nRow(0),
      nCol(0),
      nnzPerRow(0),
      inserter(NULL)
  {}


  DOFMatrix::DOFMatrix(const FiniteElemSpace* rowSpace,
		       const FiniteElemSpace* colSpace,
		       std::string n)
    : rowFeSpace(rowSpace),
      colFeSpace(colSpace),
      name(n), 
      coupleMatrix(false),
      nnzPerRow(0),
      inserter(NULL)
  {
    FUNCNAME("DOFMatrix::DOFMatrix()");

    TEST_EXIT(rowFeSpace)("No fe space for row!\n");
  
    if (!colFeSpace)
      colFeSpace = rowFeSpace;

    if (rowFeSpace && rowFeSpace->getAdmin())
      (const_cast<DOFAdmin*>(rowFeSpace->getAdmin()))->addDOFIndexed(this);

    boundaryManager = new BoundaryManager(rowFeSpace);

    nRow = rowFeSpace->getBasisFcts()->getNumber();
    nCol = colFeSpace->getBasisFcts()->getNumber();
    elementMatrix.change_dim(nRow, nCol);
    rowIndices.resize(nRow);
    colIndices.resize(nCol);

    applyDBCs.clear();
  }


  DOFMatrix::DOFMatrix(const DOFMatrix& rhs)
    : name(rhs.name + "copy")
  {
    FUNCNAME("DOFMatrix::DOFMatrix()");

    *this = rhs;
    if (rowFeSpace && rowFeSpace->getAdmin())
      (const_cast<DOFAdmin*>( rowFeSpace->getAdmin()))->addDOFIndexed(this);

    TEST_EXIT(rhs.inserter == 0)("Cannot copy during insertion!\n");
    inserter = 0;
  }


  DOFMatrix::~DOFMatrix()
  {
    FUNCNAME("DOFMatrix::~DOFMatrix()");

    if (rowFeSpace && rowFeSpace->getAdmin())
      (const_cast<DOFAdmin*>(rowFeSpace->getAdmin()))->removeDOFIndexed(this);
    if (boundaryManager) 
      delete boundaryManager;
    if (inserter) 
      delete inserter;
  }


  void DOFMatrix::print() const
  {
    FUNCNAME("DOFMatrix::print()");

    if (inserter) 
      inserter->print();
  }


  bool DOFMatrix::symmetric()
  {
    FUNCNAME("DOFMatrix::symmetric()");

    double tol = 1e-5;

    using mtl::tag::major; using mtl::tag::nz; using mtl::begin; using mtl::end;
    namespace traits= mtl::traits;
    typedef base_matrix_type   Matrix;

    traits::row<Matrix>::type                                 row(matrix);
    traits::col<Matrix>::type                                 col(matrix);
    traits::const_value<Matrix>::type                         value(matrix);

    typedef traits::range_generator<major, Matrix>::type      cursor_type;
    typedef traits::range_generator<nz, cursor_type>::type    icursor_type;
    
    for (cursor_type cursor = begin<major>(matrix), cend = end<major>(matrix); cursor != cend; ++cursor)
      for (icursor_type icursor = begin<nz>(cursor), icend = end<nz>(cursor); icursor != icend; ++icursor)
	// Compare each non-zero entry with its transposed
	if (abs(value(*icursor) - matrix[col(*icursor)][row(*icursor)]) > tol)
	  return false;
    return true;
  }


  void DOFMatrix::test()
  {
    FUNCNAME("DOFMatrix::test()");

    int non_symmetric = !symmetric();

    if (non_symmetric)
      MSG("Matrix `%s' not symmetric.\n", name.data());
    else
      MSG("Matrix `%s' is symmetric.\n", name.data());
  }


  DOFMatrix& DOFMatrix::operator=(const DOFMatrix& rhs)
  {
    rowFeSpace = rhs.rowFeSpace;
    colFeSpace = rhs.colFeSpace;
    operators = rhs.operators;
    operatorFactor = rhs.operatorFactor;
    coupleMatrix = rhs.coupleMatrix;

    /// The matrix values may only be copyed, if there is no active insertion.
    if (rhs.inserter == 0 && inserter == 0)
      matrix = rhs.matrix;

    if (rhs.boundaryManager)
      boundaryManager = new BoundaryManager(*rhs.boundaryManager);
    else
      boundaryManager = NULL;
    
    nRow = rhs.nRow;
    nCol = rhs.nCol;
    elementMatrix.change_dim(nRow, nCol);

    return *this;
  }


  void DOFMatrix::addElementMatrix(const ElementMatrix& elMat, 
				   const BoundaryType *bound,
				   ElInfo* rowElInfo,
				   ElInfo* colElInfo)
  {
    FUNCNAME("DOFMatrix::addElementMatrix()");

    TEST_EXIT_DBG(inserter)("DOFMatrix is not in insertion mode");
    inserter_type &ins= *inserter;
 
    // === Get indices mapping from local to global matrix indices. ===

    rowFeSpace->getBasisFcts()->getLocalIndices(rowElInfo->getElement(),
						rowFeSpace->getAdmin(),
						rowIndices);
    if (rowFeSpace == colFeSpace) {
      colIndices = rowIndices;
    } else {
      if (colElInfo) {
	colFeSpace->getBasisFcts()->getLocalIndices(colElInfo->getElement(),
						    colFeSpace->getAdmin(),
						    colIndices);
      } else {
	// If there is no colElInfo pointer, use rowElInfo the get the indices.
	colFeSpace->getBasisFcts()->getLocalIndices(rowElInfo->getElement(),
						    colFeSpace->getAdmin(),
						    colIndices);
      }
    }

    using namespace mtl;

#if 0
    std::cout << "----- PRINT MAT --------" << std::endl;
    std::cout << elMat << std::endl;
    std::cout << "rows: ";
    for (int i = 0; i < rowIndices.size(); i++)
      std::cout << rowIndices[i] << " ";
    std::cout << std::endl;
    std::cout << "cols: ";
    for (int i = 0; i < colIndices.size(); i++)
      std::cout << colIndices[i] << " ";
    std::cout << std::endl;
#endif
         
    for (int i = 0; i < nRow; i++)  {
      DegreeOfFreedom row = rowIndices[i];

      BoundaryCondition *condition = 
	bound ? boundaryManager->getBoundaryCondition(bound[i]) : NULL;

      if (condition && condition->isDirichlet()) {
	if (condition->applyBoundaryCondition()) {
#ifdef HAVE_PARALLEL_DOMAIN_AMDIS
	  if ((*rankDofs)[rowIndices[i]]) 
	    applyDBCs.insert(static_cast<int>(row));
#else
	  applyDBCs.insert(static_cast<int>(row));
#endif
	}
      } else {
	for (int j = 0; j < nCol; j++) {
	  DegreeOfFreedom col = colIndices[j];
	  ins[row][col] += elMat[i][j];
	}
      }
    }
  }


  /// Returns the entry in row a, column b of the underlying matrix.
  double DOFMatrix::logAcc(DegreeOfFreedom a, DegreeOfFreedom b) const
  {
    const double entry = matrix[a][b];
    return entry;
  }


  /// Hook called for a freed DOF index; intentionally does nothing here.
  void DOFMatrix::freeDOFContent(int index)
  {
    // No per-DOF content to release.
  }


  void DOFMatrix::assemble(double factor, ElInfo *elInfo, const BoundaryType *bound)
  {
    FUNCNAME("DOFMatrix::assemble()");

    set_to_zero(elementMatrix);

    std::vector<Operator*>::iterator it = operators.begin();
    std::vector<double*>::iterator factorIt = operatorFactor.begin();
    for (; it != operators.end(); ++it, ++factorIt)
      if ((*it)->getNeedDualTraverse() == false)
	(*it)->getElementMatrix(elInfo,	elementMatrix, *factorIt ? **factorIt : 1.0);      

    if (factor != 1.0)
      elementMatrix *= factor;

    addElementMatrix(elementMatrix, bound, elInfo, NULL); 
  }


  /// Assembles the element matrix of the single operator `op' on elInfo,
  /// scales it by `factor' and adds it to the global matrix.
  void DOFMatrix::assemble(double factor, ElInfo *elInfo, const BoundaryType *bound,
                           Operator *op)
  {
    FUNCNAME("DOFMatrix::assemble()");

    TEST_EXIT_DBG(op)("No operator!\n");

    set_to_zero(elementMatrix);
    // getElementMatrix() takes a scaling factor as its third argument (the
    // other overloads pass the operator factor there). Previously `factor'
    // was passed there AND multiplied in below, which scaled the result by
    // factor^2. Compute the contribution unscaled and apply `factor' once.
    op->getElementMatrix(elInfo, elementMatrix, 1.0);

    if (factor != 1.0)
      elementMatrix *= factor;

    addElementMatrix(elementMatrix, bound, elInfo, NULL);
  }


  void DOFMatrix::assemble(double factor, 
			   ElInfo *rowElInfo, ElInfo *colElInfo,
			   ElInfo *smallElInfo, ElInfo *largeElInfo,
			   const BoundaryType *bound, Operator *op)
  {
    FUNCNAME("DOFMatrix::assemble()");

    if (!op && operators.size() == 0)
      return;

    set_to_zero(elementMatrix);

    if (op) {
      op->getElementMatrix(rowElInfo, colElInfo, smallElInfo, largeElInfo, 
			   false, elementMatrix);
    } else {
      std::vector<Operator*>::iterator it = operators.begin();
      std::vector<double*>::iterator factorIt = operatorFactor.begin();
      for (; it != operators.end(); ++it, ++factorIt)
	(*it)->getElementMatrix(rowElInfo, 
				colElInfo,
				smallElInfo, 
				largeElInfo,
				false,
				elementMatrix, 
				*factorIt ? **factorIt : 1.0);	     
    }

    if (factor != 1.0)
      elementMatrix *= factor;

    addElementMatrix(elementMatrix, bound, rowElInfo, colElInfo);       
  }


  void DOFMatrix::assemble2(double factor, 
			    ElInfo *mainElInfo, ElInfo *auxElInfo,
			    ElInfo *smallElInfo, ElInfo *largeElInfo,
			    const BoundaryType *bound, Operator *op)
  {
    FUNCNAME("DOFMatrix::assemble2()");

    if (!op && operators.size() == 0)
      return;

    set_to_zero(elementMatrix);
    
    if (op) {
      ERROR_EXIT("TODO");
//       op->getElementMatrix(rowElInfo, colElInfo, 
// 			   smallElInfo, largeElInfo,
// 			   elementMatrix);
    } else {
      std::vector<Operator*>::iterator it;
      std::vector<double*>::iterator factorIt;
      for(it = operators.begin(), factorIt = operatorFactor.begin();	
	   it != operators.end(); 
	   ++it, ++factorIt) {
	if ((*it)->getNeedDualTraverse()) {
	  (*it)->getElementMatrix(mainElInfo, 
				  auxElInfo,
				  smallElInfo, 
				  largeElInfo,
				  rowFeSpace == colFeSpace,
				  elementMatrix, 
				  *factorIt ? **factorIt : 1.0);
	}
      }      
    }

    if (factor != 1.0)
      elementMatrix *= factor;

    addElementMatrix(elementMatrix, bound, mainElInfo, NULL);       
  }


  void DOFMatrix::finishAssembling()
  {
    // call the operatos cleanup procedures
    for (std::vector<Operator*>::iterator it = operators.begin();
	 it != operators.end(); ++it)
      (*it)->finishAssembling();
  }


  /// Returns the bitwise OR of the fill flags of all registered operators.
  Flag DOFMatrix::getAssembleFlag()
  {
    Flag combined(0);
    for (std::vector<Operator*>::size_type k = 0; k < operators.size(); ++k)
      combined |= operators[k]->getFillFlag();

    return combined;
  }


  /// Accumulates a * x + y into this matrix: matrix += a * x.matrix + y.matrix.
  /// NOTE(review): unlike BLAS axpy (y = a*x + y), the current values of *this
  /// are kept and the sum is added to them. Confirm that callers expect this.
  /// No dimension checks are done here.
  void DOFMatrix::axpy(double a, const DOFMatrix& x, const DOFMatrix& y)
  {
    matrix+= a * x.matrix + y.matrix;
  }


  /// Scales the matrix in place by the scalar b.
  void DOFMatrix::scal(double b) 
  {
    const double scale = b;
    matrix *= scale;
  }


  /// Registers an operator together with its assembly factor and its
  /// estimator factor. A NULL assembly factor means 1.0 during assembly.
  void DOFMatrix::addOperator(Operator *op, double* factor, double* estFactor) 
  {
    // The three vectors are kept parallel: index k refers to one operator.
    operators.push_back(op);
    operatorFactor.push_back(factor);
    operatorEstFactor.push_back(estFactor);
  }


  void DOFMatrix::serialize(std::ostream &out)
  {
    using namespace mtl; 
    
    typedef traits::range_generator<tag::major, base_matrix_type>::type c_type;
    typedef traits::range_generator<tag::nz, c_type>::type ic_type;
    
    typedef Collection<base_matrix_type>::size_type size_type;
    typedef Collection<base_matrix_type>::value_type value_type;
    
    traits::row<base_matrix_type>::type row(matrix); 
    traits::col<base_matrix_type>::type col(matrix);
    traits::const_value<base_matrix_type>::type value(matrix); 
    
    size_type rows= num_rows(matrix), cols= num_cols(matrix), total= matrix.nnz();
    SerUtil::serialize(out, rows);
    SerUtil::serialize(out, cols);
    SerUtil::serialize(out, total);
    
    for (c_type cursor(mtl::begin<tag::major>(matrix)), 
	   cend(mtl::end<tag::major>(matrix)); cursor != cend; ++cursor)
      for (ic_type icursor(mtl::begin<tag::nz>(cursor)), 
	     icend(mtl::end<tag::nz>(cursor)); icursor != icend; ++icursor) {
	size_type   my_row= row(*icursor), my_col= col(*icursor);
	value_type  my_value= value(*icursor);
	SerUtil::serialize(out, my_row);
	SerUtil::serialize(out, my_col);
	SerUtil::serialize(out, my_value);
      }
  }

  void DOFMatrix::deserialize(std::istream &in)
  {
    using namespace mtl;
    
    typedef Collection<base_matrix_type>::size_type size_type;
    typedef Collection<base_matrix_type>::value_type value_type;
    
    size_type rows, cols, total;
    SerUtil::deserialize(in, rows);
    SerUtil::deserialize(in, cols);
    SerUtil::deserialize(in, total);
    
    // Prepare matrix insertion
    clear();
    // matrix.change_dim(rows, cols) // do we want this?
    inserter_type ins(matrix);
    
    for (size_type i = 0; i < total; ++i) {
      size_type   my_row, my_col;
      value_type  my_value;
      SerUtil::deserialize(in, my_row);
      SerUtil::deserialize(in, my_col);
      SerUtil::deserialize(in, my_value);
      ins(my_row, my_col) << my_value;
    }    
  }


  void DOFMatrix::copy(const DOFMatrix& rhs) 
  {
    matrix = rhs.matrix;
  }


  void DOFMatrix::removeRowsWithDBC(std::set<int> *rows)
  {      
    FUNCNAME("DOFMatrix::removeRowsWithDBC()");

    inserter_type &ins = *inserter;  
    for (std::set<int>::iterator it = rows->begin(); it != rows->end(); ++it)
      ins[*it][*it] = 1.0;    

    rows->clear();
  }


  /// Estimated memory usage of the matrix in bytes: (rows + nnz) index
  /// entries of size_type plus nnz values of value_type.
  /// The result saturates at INT_MAX instead of overflowing the int return type.
  int DOFMatrix::memsize() 
  {   
    const std::size_t bytes =
      (num_rows(matrix) + matrix.nnz()) * sizeof(base_matrix_type::size_type)
      + matrix.nnz() * sizeof(base_matrix_type::value_type);

    // Converting a value above INT_MAX to int would wrap (implementation-defined).
    if (bytes > static_cast<std::size_t>(INT_MAX))
      return INT_MAX;
    return static_cast<int>(bytes);
  }


  /// Enters insertion mode. Any previous inserter is destroyed first, then a
  /// new one is created for the matrix with nnz_per_row as the expected
  /// number of entries per row.
  void DOFMatrix::startInsertion(int nnz_per_row)
  {
    delete inserter;   // no-op for NULL
    inserter = NULL;   // stays NULL should the allocation below throw
    inserter = new inserter_type(matrix, nnz_per_row);
  }

}