//
// Software License for AMDiS
//
// Copyright (c) 2010 Dresden University of Technology 
// All rights reserved.
// Authors: Simon Vey, Thomas Witkowski et al.
//
// This file is part of AMDiS
//
// See also license.opensource.txt in the distribution.


#include <vector>
#include <algorithm>
#include <boost/numeric/mtl/mtl.hpp>
#include "Assembler.h"
#include "Operator.h"
#include "Element.h"
#include "QPsiPhi.h"
#include "DOFVector.h"

namespace AMDiS {

  // Constructs the common assembler state for operator \p op.
  //
  // \param op   operator whose terms are assembled (stored in operat)
  // \param row  finite element space of the rows
  // \param col  finite element space of the columns; when null, \p row is
  //             used for the columns as well
  //
  // The element matrix/vector buffers are sized from the number of basis
  // functions of the row and column spaces. Caching of element results starts
  // switched off (rememberElMat/rememberElVec), while the general permission
  // to cache (remember) starts switched on.
  //
  // The sub-assembler pointers are not set here; the derived constructors in
  // this file assign them.
  //
  // NOTE(review): nRow/nCol and the buffers are initialized from the members
  // rowFeSpace/colFeSpace/nRow/nCol, so this relies on those members being
  // declared in this order in the class definition - confirm in the header.
  Assembler::Assembler(Operator *op,
		       const FiniteElemSpace *row,
		       const FiniteElemSpace *col) 
    : operat(op),
      rowFeSpace(row),
      colFeSpace(col ? col : row),    // fall back to the row space
      nRow(rowFeSpace->getBasisFcts()->getNumber()),
      nCol(colFeSpace->getBasisFcts()->getNumber()),
      remember(true),
      rememberElMat(false),
      rememberElVec(false),
      elementMatrix(nRow, nCol),
      elementVector(nRow),
      tmpMat(nRow, nCol),
      lastMatEl(NULL),                // no element matrix computed yet
      lastVecEl(NULL),                // no element vector computed yet
      lastTraverseId(-1)
  {}


  // Empty destructor; nothing is released here.
  // NOTE(review): the sub-assembler pointers obtained through
  // getSubAssembler() in the derived constructors are not deleted -
  // presumably they are owned elsewhere; confirm.
  Assembler::~Assembler()
  {}


  void Assembler::calculateElementMatrix(const ElInfo *elInfo, 
					 ElementMatrix& userMat,
					 double factor)
  {
    FUNCNAME("Assembler::calculateElementMatrix()");

    if (remember && (factor != 1.0 || operat->uhOld))
      rememberElMat = true;

    Element *el = elInfo->getElement();

    if (el != lastMatEl || !operat->isOptimized()) {
      initElement(elInfo);

      if (rememberElMat)
	set_to_zero(elementMatrix);

      lastMatEl = el;
    } else {
      if (rememberElMat) {
	userMat += factor * elementMatrix;
	return;
      }
    }

    ElementMatrix& mat = rememberElMat ? elementMatrix : userMat;

    if (secondOrderAssembler)
      secondOrderAssembler->calculateElementMatrix(elInfo, mat);
    if (firstOrderAssemblerGrdPsi)
      firstOrderAssemblerGrdPsi->calculateElementMatrix(elInfo, mat);
    if (firstOrderAssemblerGrdPhi)
      firstOrderAssemblerGrdPhi->calculateElementMatrix(elInfo, mat);
    if (zeroOrderAssembler)
      zeroOrderAssembler->calculateElementMatrix(elInfo, mat);

    if (rememberElMat && &userMat != &elementMatrix)
      userMat += factor * elementMatrix;
  }


  // Overload of calculateElementMatrix() taking separate ElInfo objects for
  // rows and columns.
  //
  // \param rowElInfo           ElInfo associated with the rows
  // \param colElInfo           ElInfo associated with the columns
  // \param smallElInfo         one of the two objects above; all sub-assemblers
  //                            are evaluated on it (presumably the finer
  //                            element - confirm against callers)
  // \param largeElInfo         the other of the two objects (presumably the
  //                            coarser element - confirm against callers)
  // \param rowColFeSpaceEqual  when false, the matrix is multiplied with a
  //                            sub-element matrix of smallElInfo after each
  //                            sub-assembler
  // \param userMat             matrix receiving the contribution
  // \param factor              scaling applied when the cached matrix is added
  //                            to \p userMat
  //
  // NOTE(review): each multiplication below replaces the whole of mat.
  // Whatever mat already holds at that point (results of earlier
  // sub-assemblers, or prior content of userMat when caching is off) is
  // therefore multiplied as well - confirm this is intended when more than one
  // sub-assembler is active.
  void Assembler::calculateElementMatrix(const ElInfo *rowElInfo,
					 const ElInfo *colElInfo,
					 const ElInfo *smallElInfo,
					 const ElInfo *largeElInfo,
					 bool rowColFeSpaceEqual,
					 ElementMatrix& userMat,
					 double factor)
  {
    FUNCNAME("Assembler::calculateElementMatrix()");

    // Switch to caching once a scaling factor or an uhOld vector is involved.
    if (remember && (factor != 1.0 || operat->uhOld))
      rememberElMat = true;
  
    Element *el = smallElInfo->getElement();   
    // Both "last element" markers are cleared on every call, so the two
    // comparisons below can only fail for a null element.
    // NOTE(review): this makes the cache-hit else-branch further down
    // practically unreachable - confirm that disabling reuse here is intended.
    lastVecEl = lastMatEl = NULL;
   
    if ((el != lastMatEl && el != lastVecEl) || !operat->isOptimized()) {
      // Single-element initialization when the small element is the column
      // element, two-element initialization otherwise.
      if (smallElInfo == colElInfo)
	initElement(smallElInfo);	
      else
	initElement(smallElInfo, largeElInfo);      
    }      

    if (el != lastMatEl || !operat->isOptimized()) {
      if (rememberElMat)
	set_to_zero(elementMatrix);

      lastMatEl = el;
    } else {
      if (rememberElMat) {
	userMat += factor * elementMatrix;
	return;
      }
    }
 
    // Destination the sub-assemblers are invoked on.
    ElementMatrix& mat = rememberElMat ? elementMatrix : userMat;

    if (secondOrderAssembler) {
      secondOrderAssembler->calculateElementMatrix(smallElInfo, mat);

      // Fetched unconditionally; only used when rowColFeSpaceEqual is false.
      ElementMatrix &m = 
	smallElInfo->getSubElemGradCoordsMat(rowFeSpace->getBasisFcts()->getDegree());
      
      if (!rowColFeSpaceEqual) {
	// Multiply from the left if the small element is the column element,
	// otherwise from the right with the transposed matrix.
	if (smallElInfo == colElInfo)
	  tmpMat = m * mat;	
	else
	  tmpMat = mat * trans(m);
	
	mat = tmpMat;
      }
    }

    if (firstOrderAssemblerGrdPsi) {
      firstOrderAssemblerGrdPsi->calculateElementMatrix(smallElInfo, mat);

      if (!rowColFeSpaceEqual) {
	// Gradient-coordinate matrix from the left when the large element is
	// the row element, plain coordinate matrix (transposed) from the right
	// otherwise.
	if (largeElInfo == rowElInfo) {
	  ElementMatrix &m = 
	    smallElInfo->getSubElemGradCoordsMat(rowFeSpace->getBasisFcts()->getDegree());
	  
	  tmpMat = m * mat;
	} else {
	  ElementMatrix &m = 
	    smallElInfo->getSubElemCoordsMat(rowFeSpace->getBasisFcts()->getDegree());
	  
	  tmpMat = mat * trans(m);
	}
	
	mat = tmpMat;
      }
    }

    if (firstOrderAssemblerGrdPhi) {
      firstOrderAssemblerGrdPhi->calculateElementMatrix(smallElInfo, mat);

      if (!rowColFeSpaceEqual) {
	// Mirror image of the GrdPsi case: gradient-coordinate matrix
	// (transposed) from the right when the large element is the column
	// element, plain coordinate matrix from the left otherwise.
	if (largeElInfo == colElInfo) {
	  ElementMatrix &m = 
	    smallElInfo->getSubElemGradCoordsMat(rowFeSpace->getBasisFcts()->getDegree());
	  
	  tmpMat = mat * trans(m);
	} else {
	  ElementMatrix &m = 
	    smallElInfo->getSubElemCoordsMat(rowFeSpace->getBasisFcts()->getDegree());
	  
	  tmpMat = m * mat;	
	}
	
	mat = tmpMat;
      }
    }

    if (zeroOrderAssembler) {
      zeroOrderAssembler->calculateElementMatrix(smallElInfo, mat);

      if (!rowColFeSpaceEqual) {
	ElementMatrix &m = 
	  smallElInfo->getSubElemCoordsMat(rowFeSpace->getBasisFcts()->getDegree());
	
	// Same left/right rule as for the second order terms, but with the
	// plain coordinate matrix.
	if (smallElInfo == colElInfo)
	  tmpMat = m * mat;
	else 
	  tmpMat = mat * trans(m);
	
	mat = tmpMat;
      }
    }

    // Skip the accumulation when the caller handed in the cache itself
    // (the four-ElInfo matVecAssemble() does that).
    if (rememberElMat && &userMat != &elementMatrix)
      userMat += factor * elementMatrix;       
  }


  void Assembler::calculateElementVector(const ElInfo *elInfo, 
					 ElementVector& userVec,
					 double factor)
  {
    FUNCNAME("Assembler::calculateElementVector()");

    if (remember && factor != 1.0)
      rememberElVec = true;

    Element *el = elInfo->getElement();

    if ((el != lastMatEl && el != lastVecEl) || !operat->isOptimized())
      initElement(elInfo);
    
    if (el != lastVecEl || !operat->isOptimized()) {
      if (rememberElVec)
	set_to_zero(elementVector);
	
      lastVecEl = el;
    } else {
      if (rememberElVec) {
	userVec += factor * elementVector;
	return;
      }
    }

    ElementVector& vec = rememberElVec ? elementVector : userVec;

    if (operat->uhOld && remember) {
      matVecAssemble(elInfo, vec);
      if (rememberElVec)
	userVec += factor * elementVector;      

      return;
    } 

    if (firstOrderAssemblerGrdPsi)
      firstOrderAssemblerGrdPsi->calculateElementVector(elInfo, vec);
    if (zeroOrderAssembler)
      zeroOrderAssembler->calculateElementVector(elInfo, vec);
      
    if (rememberElVec)
      userVec += factor * elementVector;    
  }


  void Assembler::calculateElementVector(const ElInfo *mainElInfo, 
					 const ElInfo *auxElInfo,
					 const ElInfo *smallElInfo,
					 const ElInfo *largeElInfo,
					 ElementVector& userVec, 
					 double factor)
  {
    FUNCNAME("Assembler::calculateElementVector()");

    if (remember && factor != 1.0)
      rememberElVec = true;

    Element *el = mainElInfo->getElement();

    if ((el != lastMatEl && el != lastVecEl) || !operat->isOptimized())
      initElement(smallElInfo, largeElInfo);
   
    if (el != lastVecEl || !operat->isOptimized()) {
      if (rememberElVec)
	set_to_zero(elementVector);

      lastVecEl = el;
    } else {
      if (rememberElVec) {
	userVec += factor * elementVector;
	return;
      }
    }
    ElementVector& vec = rememberElVec ? elementVector : userVec;

    if (operat->uhOld && remember) {
      if (smallElInfo->getLevel() == largeElInfo->getLevel())
	matVecAssemble(auxElInfo, vec);
      else
	matVecAssemble(mainElInfo, auxElInfo, smallElInfo, largeElInfo, vec);      

      if (rememberElVec)
	userVec += factor * elementVector;      

      return;
    } 

    if (firstOrderAssemblerGrdPsi) {
      ERROR_EXIT("Not yet implemented!\n");
    }

    if (zeroOrderAssembler) {
      zeroOrderAssembler->calculateElementVector(smallElInfo, vec);
      
      if (smallElInfo != mainElInfo) {
	ElementVector tmpVec(vec);	
	ElementMatrix &m = 
	  smallElInfo->getSubElemCoordsMat(rowFeSpace->getBasisFcts()->getDegree());

	tmpVec = m * vec;	
	vec = tmpVec;
      }      
    }

    if (rememberElVec)
      userVec += factor * elementVector;    
  }


  void Assembler::matVecAssemble(const ElInfo *elInfo, ElementVector& vec)
  {
    FUNCNAME("Assembler::matVecAssemble()");

    Element *el = elInfo->getElement(); 
    ElementVector uhOldLoc(operat->uhOld->getFeSpace() == rowFeSpace ? 
			   nRow : nCol);
    operat->uhOld->getLocalVector(el, uhOldLoc);
    
    if (el != lastMatEl) {
      set_to_zero(elementMatrix);
      calculateElementMatrix(elInfo, elementMatrix);
    }

    for (int i = 0; i < nRow; i++) {
      double val = 0.0;
      for (int j = 0; j < nCol; j++)
	val += elementMatrix[i][j] * uhOldLoc[j];
      
      vec[i] += val;
    }   
  }


  void Assembler::matVecAssemble(const ElInfo *mainElInfo, const ElInfo *auxElInfo,
				 const ElInfo *smallElInfo, const ElInfo *largeElInfo,
				 ElementVector& vec)
  {
    FUNCNAME("Assembler::matVecAssemble()");

    TEST_EXIT(rowFeSpace->getBasisFcts() == colFeSpace->getBasisFcts())
      ("Works only for equal basis functions for different components!\n");

    TEST_EXIT(operat->uhOld->getFeSpace()->getMesh() == auxElInfo->getMesh())
      ("Da stimmt was nicht!\n");

    Element *mainEl = mainElInfo->getElement(); 
    Element *auxEl = auxElInfo->getElement();

    const BasisFunction *basFcts = rowFeSpace->getBasisFcts();
    int nBasFcts = basFcts->getNumber();
    ElementVector uhOldLoc(nBasFcts);

    operat->uhOld->getLocalVector(auxEl, uhOldLoc);

    if (mainEl != lastMatEl) {
      set_to_zero(elementMatrix);
      calculateElementMatrix(mainElInfo, auxElInfo, smallElInfo, largeElInfo, 
			     false, elementMatrix);    
    }

    for (int i = 0; i < nBasFcts; i++) {
      double val = 0.0;
      for (int j = 0; j < nBasFcts; j++)
 	val += elementMatrix[i][j] * uhOldLoc[j];
      vec[i] += val;
    }   
  }


  // Forwards the element initialization to every sub-assembler that is set.
  // All three arguments are passed through unchanged.
  void Assembler::initElement(const ElInfo *smallElInfo, 
                              const ElInfo *largeElInfo,
                              Quadrature *quad)
  {
    // Second order terms.
    if (secondOrderAssembler) {
      secondOrderAssembler->initElement(smallElInfo, largeElInfo, quad);
    }

    // First order terms (GrdPsi, then GrdPhi).
    if (firstOrderAssemblerGrdPsi) {
      firstOrderAssemblerGrdPsi->initElement(smallElInfo, largeElInfo, quad);
    }

    if (firstOrderAssemblerGrdPhi) {
      firstOrderAssemblerGrdPhi->initElement(smallElInfo, largeElInfo, quad);
    }

    // Zero order terms.
    if (zeroOrderAssembler) {
      zeroOrderAssembler->initElement(smallElInfo, largeElInfo, quad);
    }
  }


  void Assembler::checkQuadratures()
  { 
    if (secondOrderAssembler) {
      // create quadrature
      if (!secondOrderAssembler->getQuadrature()) {
	int dim = rowFeSpace->getMesh()->getDim();
	int degree = operat->getQuadratureDegree(2);
	Quadrature *quadrature = Quadrature::provideQuadrature(dim, degree);
	secondOrderAssembler->setQuadrature(quadrature);
      }
    }
    if (firstOrderAssemblerGrdPsi) {
      // create quadrature
      if (!firstOrderAssemblerGrdPsi->getQuadrature()) {
	int dim = rowFeSpace->getMesh()->getDim();
	int degree = operat->getQuadratureDegree(1, GRD_PSI);
	Quadrature *quadrature = Quadrature::provideQuadrature(dim, degree);
	firstOrderAssemblerGrdPsi->setQuadrature(quadrature);
      }
    }
    if (firstOrderAssemblerGrdPhi) {
      // create quadrature
      if (!firstOrderAssemblerGrdPhi->getQuadrature()) {
	int dim = rowFeSpace->getMesh()->getDim();
	int degree = operat->getQuadratureDegree(1, GRD_PHI);
	Quadrature *quadrature = Quadrature::provideQuadrature(dim, degree);
	firstOrderAssemblerGrdPhi->setQuadrature(quadrature);
      }
    }
    if (zeroOrderAssembler) {
      // create quadrature
      if (!zeroOrderAssembler->getQuadrature()) {
	int dim = rowFeSpace->getMesh()->getDim();
	int degree = operat->getQuadratureDegree(0);
	Quadrature *quadrature = Quadrature::provideQuadrature(dim, degree);
	zeroOrderAssembler->setQuadrature(quadrature);
      }
    }
  }


  // Forgets which elements the cached element matrix and vector belong to, so
  // the next calculateElementMatrix()/calculateElementVector() call does not
  // treat its element as already handled.
  void Assembler::finishAssembling()
  {
    lastMatEl = NULL;
    lastVecEl = NULL;
  }


  // Creates an assembler whose sub-assemblers are requested with the
  // "optimized" flag set whenever row and column space use the same basis
  // functions object.
  //
  // \param op           operator whose terms are assembled
  // \param quad2        quadrature passed to the second order sub-assembler
  // \param quad1GrdPsi  quadrature passed to the first order GrdPsi sub-assembler
  // \param quad1GrdPhi  quadrature passed to the first order GrdPhi sub-assembler
  // \param quad0        quadrature passed to the zero order sub-assembler
  // \param rowFeSpace   finite element space of the rows
  // \param colFeSpace   finite element space of the columns; may be null, in
  //                     which case the row space is used (as in the base class)
  //
  // Sub-assemblers left without a quadrature get one through
  // checkQuadratures().
  OptimizedAssembler::OptimizedAssembler(Operator  *op,
                                         Quadrature *quad2,
                                         Quadrature *quad1GrdPsi,
                                         Quadrature *quad1GrdPhi,
                                         Quadrature *quad0,
                                         const FiniteElemSpace *rowFeSpace,
                                         const FiniteElemSpace *colFeSpace) 
    : Assembler(op, rowFeSpace, colFeSpace)
  {
    // The base constructor accepts a null column space and substitutes the row
    // space. Apply the same fallback here instead of dereferencing a null
    // pointer when comparing the basis functions.
    const FiniteElemSpace *effectiveColFeSpace = 
      colFeSpace ? colFeSpace : rowFeSpace;

    bool opt = 
      (rowFeSpace->getBasisFcts() == effectiveColFeSpace->getBasisFcts());

    // create sub assemblers
    secondOrderAssembler = 
      SecondOrderAssembler::getSubAssembler(op, this, quad2, opt);
    firstOrderAssemblerGrdPsi = 
      FirstOrderAssembler::getSubAssembler(op, this, quad1GrdPsi, GRD_PSI, opt);
    firstOrderAssemblerGrdPhi = 
      FirstOrderAssembler::getSubAssembler(op, this, quad1GrdPhi, GRD_PHI, opt);
    zeroOrderAssembler = 
      ZeroOrderAssembler::getSubAssembler(op, this, quad0, opt);

    checkQuadratures();
  }


  // Creates an assembler that never caches element results (remember is
  // switched off) and requests all sub-assemblers with the "optimized" flag
  // cleared.
  //
  // The quadrature arguments are handed to the respective sub-assemblers;
  // sub-assemblers left without a quadrature get one through
  // checkQuadratures().
  StandardAssembler::StandardAssembler(Operator *op,
                                       Quadrature *quad2,
                                       Quadrature *quad1GrdPsi,
                                       Quadrature *quad1GrdPhi,
                                       Quadrature *quad0,
                                       const FiniteElemSpace *rowFeSpace,
                                       const FiniteElemSpace *colFeSpace) 
    : Assembler(op, rowFeSpace, colFeSpace)
  {
    // Flag handed to every getSubAssembler() call below.
    const bool optimized = false;

    remember = false;

    // Second order terms.
    secondOrderAssembler = 
      SecondOrderAssembler::getSubAssembler(op, this, quad2, optimized);

    // First order terms, GrdPsi and GrdPhi.
    firstOrderAssemblerGrdPsi = 
      FirstOrderAssembler::getSubAssembler(op, this, quad1GrdPsi, GRD_PSI, optimized);
    firstOrderAssemblerGrdPhi = 
      FirstOrderAssembler::getSubAssembler(op, this, quad1GrdPhi, GRD_PHI, optimized);

    // Zero order terms.
    zeroOrderAssembler = 
      ZeroOrderAssembler::getSubAssembler(op, this, quad0, optimized);

    checkQuadratures();
  }

}