#include "ResidualEstimator.h"
#include "Operator.h"
#include "DOFMatrix.h"
#include "DOFVector.h"
#include "Assembler.h"
#include "Traverse.h"
#include "Parameters.h"

namespace AMDiS {

  /// Creates the estimator and reads the weights C0..C3 of the individual
  /// residual contributions from the parameter file under the keys
  /// "<name>->C0" ... "<name>->C3". Each weight starts at 1.0 (presumably kept
  /// when the key is absent — depends on GET_PARAMETER, TODO confirm).
  /// After reading, every member holds the SQUARED weight; weights not
  /// larger than 1e-25 are stored as exactly 0.0, so the corresponding term
  /// can be switched off by a plain truth test on the member.
  ResidualEstimator::ResidualEstimator(std::string name, int r) 
    : Estimator(name, r),
      C0(1.0), 
      C1(1.0), 
      C2(1.0), 
      C3(1.0)
  {
    double *weight[] = {&C0, &C1, &C2, &C3};
    const char *key[] = {"->C0", "->C1", "->C2", "->C3"};
    const int nWeights = static_cast<int>(sizeof(weight) / sizeof(weight[0]));

    for (int k = 0; k < nWeights; k++) {
      GET_PARAMETER(0, name + key[k], "%f", weight[k]);

      // Keep the square of the weight; treat (almost) zero or negative
      // weights as "term disabled".
      double &w = *weight[k];
      if (w > 1.e-25)
	w = sqr(w);
      else
	w = 0.0;
    }
  }

  /// Prepares the estimator for one pass over the mesh.
  ///
  /// @param ts  time step size; a non-zero value additionally allocates the
  ///            buffers for the old solution (used for the time residual).
  ///
  /// - collects the basis functions of all systems and provides a quadrature
  ///   of degree twice the highest basis-function degree (with second
  ///   derivatives of the basis functions if that degree exceeds 2),
  /// - allocates the work buffers that are released again in exit(),
  /// - clears the estimate of every leaf element and sets its mark to 1
  ///   (estimateElement() uses the mark to handle each interior face once),
  /// - resets the accumulated sums and maxima,
  /// - prepares the face quadrature and helpers for the jump residual when
  ///   C1 is active and dim > 1.
  void ResidualEstimator::init(double ts)
  {
    FUNCNAME("ResidualEstimator::init()");
    
    timestep = ts;

    // Validate the system setup BEFORE any uh[...] is dereferenced; otherwise
    // the check below could never catch an empty system.
    nSystems = static_cast<int>(uh.size());
    TEST_EXIT_DBG(nSystems > 0)("no system set\n");
    TEST_EXIT_DBG(row >= -1 && row < nSystems)("invalid row %d\n", row);

    // row == -1 means "whole system"; the mesh is then taken from the first
    // component.
    mesh = uh[row == -1 ? 0 : row]->getFESpace()->getMesh();

    dim = mesh->getDim();
    basFcts = new const BasisFunction*[nSystems];
    quadFast = new FastQuadrature*[nSystems];

    degree = 0;
    for (int system = 0; system < nSystems; system++) {
      basFcts[system] = uh[system]->getFESpace()->getBasisFcts();
      degree = std::max(degree, basFcts[system]->getDegree());
    }

    degree *= 2;

    quad = Quadrature::provideQuadrature(dim, degree);
    nPoints = quad->getNumPoints();

    // Second derivatives are only required (and only evaluated in
    // estimateElement()) for quadrature degrees above 2.
    Flag flag = INIT_PHI | INIT_GRD_PHI;
    if (degree > 2) {
      flag |= INIT_D2_PHI;
    }

    for (int system = 0; system < nSystems; system++) {
      quadFast[system] = FastQuadrature::provideFastQuadrature(basFcts[system], 
							       *quad, 
							       flag);
    }
  
    // Local coefficient buffers (one per system); the old-solution buffers
    // only exist for instationary problems.
    uhEl = new double*[nSystems];
    uhNeigh = new double*[nSystems];
    uhOldEl = timestep ? new double*[nSystems] : NULL;

    for (int system = 0; system < nSystems; system++) {
      uhEl[system] = new double[basFcts[system]->getNumber()]; 
      uhNeigh[system] = new double[basFcts[system]->getNumber()];
      if (timestep)
	uhOldEl[system] = new double[basFcts[system]->getNumber()];
    }

    // Values at quadrature points. In the stationary case uhQP, grdUh_qp and
    // D2uhqp are allocated lazily by estimateElement() when an operator needs
    // them.
    uhQP = timestep ? new double[nPoints] : NULL;
    uhOldQP = timestep ? new double[nPoints] : NULL;

    riq = new double[nPoints];

    grdUh_qp = NULL;
    D2uhqp = NULL;

    TraverseStack stack;
    ElInfo *elInfo = NULL;

    // clear error indicators and mark elements for jumpRes
    elInfo = stack.traverseFirst(mesh, -1, Mesh::CALL_LEAF_EL);
    while (elInfo) {
      elInfo->getElement()->setEstimation(0.0, row);
      elInfo->getElement()->setMark(1);
      elInfo = stack.traverseNext(elInfo);
    }

    est_sum = 0.0;
    est_max = 0.0;
    est_t_sum = 0.0;
    est_t_max = 0.0;

    traverseFlag = 
      Mesh::FILL_NEIGH      |
      Mesh::FILL_COORDS     |
      Mesh::FILL_OPP_COORDS |
      Mesh::FILL_BOUND      |
      Mesh::FILL_GRD_LAMBDA |
      Mesh::FILL_DET        |
      Mesh::CALL_LEAF_EL;

    neighInfo = mesh->createNewElInfo();

    // prepare data for computing jump residual
    if (C1 && (dim > 1)) {
      surfaceQuad_ = Quadrature::provideQuadrature(dim - 1, degree);
      nPointsSurface_ = surfaceQuad_->getNumPoints();
      grdUhEl_.resize(nPointsSurface_);
      grdUhNeigh_.resize(nPointsSurface_);
      jump_.resize(nPointsSurface_);
      localJump_.resize(nPointsSurface_);
      neighbours_ = Global::getGeo(NEIGH, dim);
      lambdaNeigh_ = new DimVec<WorldVector<double> >(dim, NO_INIT);
      lambda_ = new DimVec<double>(dim, NO_INIT);
    }
  }

  /// Finishes an estimation pass: converts the accumulated squared
  /// contributions into the global (time) estimate, optionally prints them,
  /// and releases every work buffer allocated in init() (or lazily in
  /// estimateElement()).
  ///
  /// @param output  if true, print the estimate (and the time estimate when
  ///                C3 is active).
  void ResidualEstimator::exit(bool output)
  {
    FUNCNAME("ResidualEstimator::exit()");

    // The sums were built from squared element contributions.
    est_sum = sqrt(est_sum);
    est_t_sum = sqrt(est_t_sum);

    // --- per-system coefficient buffers ---
    for (int s = 0; s < nSystems; s++) {
      delete [] uhNeigh[s];
      delete [] uhEl[s];
      if (timestep)
	delete [] uhOldEl[s];
    }
    delete [] uhNeigh;
    delete [] uhEl;

    // --- quadrature-point buffers ---
    // uhQP exists in both the stationary (lazily allocated, possibly NULL)
    // and the instationary case; delete[] on NULL is a no-op.
    delete [] uhQP;
    if (timestep) {
      delete [] uhOldQP;
      delete [] uhOldEl;
    }

    if (output) {
      MSG("estimate   = %.8e\n", est_sum);
      if (C3)
	MSG("time estimate   = %.8e\n", est_t_sum);
    }

    // Remaining buffers; the lazily allocated ones may still be NULL.
    delete [] D2uhqp;
    delete [] grdUh_qp;
    delete [] quadFast;
    delete [] basFcts;
    delete [] riq;

    // Jump-residual helpers only exist if init() created them.
    if (C1 && (dim > 1)) {
      delete lambda_;
      delete lambdaNeigh_;
    }

    delete neighInfo;
  }

  /// Computes the residual error indicator of one leaf element and adds it to
  /// the global sums (est_sum / est_max).
  ///
  /// Contributions, each only if its weight is non-zero:
  ///  - C0: squared element residual r(u_h), integrated over the element,
  ///  - C1: squared jump of the operator-weighted gradient over interior faces
  ///        (dim > 1 only) plus the boundary residual,
  ///  - C3: time residual |u_h - u_h^old|^2, accumulated separately in
  ///        est_t_sum / est_t_max.
  ///
  /// A face jump is added to both this element and its neighbour. This
  /// element's mark is cleared at the end, so the shared face is skipped when
  /// the neighbour is processed later (init() sets all marks to 1).
  void ResidualEstimator::estimateElement(ElInfo *elInfo)
  {    
    FUNCNAME("ResidualEstimator::estimateElement()");

    TEST_EXIT_DBG(nSystems > 0)("no system set\n");

    double val = 0.0;
    std::vector<Operator*>::iterator it;
    std::vector<double*>::iterator itfac;
    Element *el = elInfo->getElement();
    double det = elInfo->getDet();
    const DimVec<WorldVector<double> > &grdLambda = elInfo->getGrdLambda();
    // May already contain jump contributions added by a previously visited
    // neighbour.
    double est_el = el->getEstimation(row);
    double h2 = h2_from_det(det, dim);

    for (int iq = 0; iq < nPoints; iq++) {
      riq[iq] = 0.0;
    }

    for (int system = 0; system < nSystems; system++) {

      if (matrix[system] == NULL) 
	continue;

      // init assemblers of all operators that contribute to the estimate
      // (no factor, or a non-zero factor)
      for (it = const_cast<DOFMatrix*>(matrix[system])->getOperatorsBegin(),
	   itfac = const_cast<DOFMatrix*>(matrix[system])->getOperatorEstFactorBegin();
	   it != const_cast<DOFMatrix*>(matrix[system])->getOperatorsEnd(); 
	   ++it, ++itfac) {
	if (*itfac == NULL || **itfac != 0.0) {
	  (*it)->getAssembler(omp_get_thread_num())->initElement(elInfo, NULL, quad);
	}
      }

      if (C0) {
	for (it = const_cast<DOFVector<double>*>(fh[system])->getOperatorsBegin();
	     it != const_cast<DOFVector<double>*>(fh[system])->getOperatorsEnd(); 
	     ++it) {
	  (*it)->getAssembler(omp_get_thread_num())->initElement(elInfo, NULL, quad);	
	}
      }
	
      if (timestep && uhOld[system]) {
	TEST_EXIT_DBG(uhOld[system])("no uhOld\n");
	uhOld[system]->getLocalVector(el, uhOldEl[system]);
  
	// ===== time and element residuals       
	if (C0 || C3) {   
	  uh[system]->getVecAtQPs(elInfo, NULL, quadFast[system], uhQP);
	  uhOld[system]->getVecAtQPs(elInfo, NULL, quadFast[system], uhOldQP);
	  
	  if (C3 && uhOldQP && system == std::max(row, 0)) {
	    val = 0.0;
	    for (int iq = 0; iq < nPoints; iq++) {
	      double tiq = (uhQP[iq] - uhOldQP[iq]);
	      val += quad->getWeight(iq) * tiq * tiq;
	    }
	    double v = C3 * det * val;
	    est_t_sum += v;
	    est_t_max = max(est_t_max, v);
	  }
	}
      }
           
      if (C0) {  
	// Determine which quantities of u_h the active operators of this
	// system need at the quadrature points.
	bool needUh = false;
	bool needGrdUh = false;
	bool needD2Uh = false;

	for (it = const_cast<DOFMatrix*>(matrix[system])->getOperatorsBegin(),
	     itfac = const_cast<DOFMatrix*>(matrix[system])->getOperatorEstFactorBegin(); 
	     it != const_cast<DOFMatrix*>(matrix[system])->getOperatorsEnd(); 
	     ++it, ++itfac) {
	  if (*itfac == NULL || **itfac != 0.0) {
	    if ((*it)->zeroOrderTerms())
	      needUh = true;
	    if ((*it)->firstOrderTermsGrdPsi() || (*it)->firstOrderTermsGrdPhi())
	      needGrdUh = true;
	    if ((degree > 2) && (*it)->secondOrderTerms())
	      needD2Uh = true;
	  }
	}

	// The buffers are allocated only once (and freed in exit()), but their
	// contents depend on the current element AND system. They therefore
	// have to be re-evaluated here every time. Evaluating them only at
	// allocation would feed stale values of the first element into r().
	if (needUh) {
	  if (uhQP == NULL)
	    uhQP = new double[nPoints];
	  uh[system]->getVecAtQPs(elInfo, NULL, quadFast[system], uhQP);
	}
	if (needGrdUh) {
	  if (grdUh_qp == NULL)
	    grdUh_qp = new WorldVector<double>[nPoints];
	  uh[system]->getGrdAtQPs(elInfo, NULL, quadFast[system], grdUh_qp);
	}
	if (needD2Uh) {
	  if (D2uhqp == NULL)
	    D2uhqp = new WorldMatrix<double>[nPoints];
	  uh[system]->getD2AtQPs(elInfo, NULL, quadFast[system], D2uhqp);
	}
	
	r(elInfo,
	  nPoints, 
	  uhQP,
	  grdUh_qp,
	  D2uhqp,
	  uhOldQP,
	  NULL,  // grdUhOldQP 
	  NULL,  // D2UhOldQP
	  matrix[system], 
	  fh[system],
	  quad,
	  riq);
      }     
    }

    // add integral over r square
    val = 0.0;
    for (int iq = 0; iq < nPoints; iq++)
      val += quad->getWeight(iq) * riq[iq] * riq[iq];
   
    if (timestep != 0.0 || norm == NO_NORM || norm == L2_NORM)
      val = C0 * h2 * h2 * det * val;
    else
      val = C0 * h2 * det * val;
    
    est_el += val;

    // ===== jump residuals 
    if (C1 && (dim > 1)) {
      int dow = Global::getGeo(WORLD);

      for (int face = 0; face < neighbours_; face++) {  
	Element *neigh = const_cast<Element*>(elInfo->getNeighbour(face));
	// A cleared mark means the neighbour was already processed and the
	// jump over this face is already accounted for.
	if (neigh && neigh->getMark()) {      
	  int oppV = elInfo->getOppVertex(face);
	      
	  el->sortFaceIndices(face, &faceIndEl_);
	  neigh->sortFaceIndices(oppV, &faceIndNeigh_);
	    
	  // Build the geometry of the neighbour in neighInfo: the opposite
	  // vertex comes from elInfo, the face vertices from this element (or
	  // from the periodic coordinates for a periodic face).
	  neighInfo->setElement(const_cast<Element*>(neigh));
	  neighInfo->setFillFlag(Mesh::FILL_COORDS);
	      	
	  for (int i = 0; i < dow; i++)
	    neighInfo->getCoord(oppV)[i] = elInfo->getOppCoord(face)[i];
		
	  // periodic leaf data ?
	  ElementData *ldp = el->getElementData()->getElementData(PERIODIC);

	  bool periodicCoords = false;

	  if (ldp) {
	    std::list<LeafDataPeriodic::PeriodicInfo>::iterator it;
	    std::list<LeafDataPeriodic::PeriodicInfo>& infoList = 
		dynamic_cast<LeafDataPeriodic*>(ldp)->getInfoList();

	    for (it = infoList.begin(); it != infoList.end(); ++it) {
	      if (it->elementSide == face) {
		for (int i = 0; i < dim; i++) {
		  int i1 = faceIndEl_[i];
		  int i2 = faceIndNeigh_[i];

		  int j = 0;
		  for (; j < dim; j++) {
		    if (i1 == el->getVertexOfPosition(INDEX_OF_DIM(dim - 1, 
								   dim),
						      face,
						      j)) {
		      break;
		    }
		  }

		  TEST_EXIT_DBG(j != dim)("vertex i1 not on face ???\n");
		      
		  neighInfo->getCoord(i2) = (*(it->periodicCoords))[j];
		}
		periodicCoords = true;
		break;
	      }
	    }
	  }
      
	  if (!periodicCoords) {
	    for (int i = 0; i < dim; i++) {
	      int i1 = faceIndEl_[i];
	      int i2 = faceIndNeigh_[i];
	      for (int j = 0; j < dow; j++)
		neighInfo->getCoord(i2)[j] = elInfo->getCoord(i1)[j];
	    }
	  }
	      
	  Parametric *parametric = mesh->getParametric();
	  if (parametric) {
	    neighInfo = parametric->addParametricInfo(neighInfo);
	  }
	      
	  double detNeigh = abs(neighInfo->calcGrdLambda(*lambdaNeigh_));
	      
	  for (int iq = 0; iq < nPointsSurface_; iq++) {
	    jump_[iq].set(0.0);
	  }
	     

	  for (int system = 0; system < nSystems; system++) {	
	    if (matrix[system] == NULL) 
	      continue;
	      
	    uh[system]->getLocalVector(el, uhEl[system]);	
	    uh[system]->getLocalVector(neigh, uhNeigh[system]);
			
	    // Difference of the gradients from both sides at each face
	    // quadrature point (stored in grdUhEl_).
	    for (int iq = 0; iq < nPointsSurface_; iq++) {
	      (*lambda_)[face] = 0.0;
	      for (int i = 0; i < dim; i++) {
		(*lambda_)[faceIndEl_[i]] = surfaceQuad_->getLambda(iq, i);
	      }
		  
	      basFcts[system]->evalGrdUh(*lambda_, 
					 grdLambda, 
					 uhEl[system], 
					 &grdUhEl_[iq]);
		  
	      (*lambda_)[oppV] = 0.0;
	      for (int i = 0; i < dim; i++) {
		(*lambda_)[faceIndNeigh_[i]] = surfaceQuad_->getLambda(iq, i);
	      }
		  
	      basFcts[system]->evalGrdUh(*lambda_, 
					 *lambdaNeigh_, 
					 uhNeigh[system], 
					 &grdUhNeigh_[iq]);
		  
	      grdUhEl_[iq] -= grdUhNeigh_[iq];
	    }				

	    std::vector<double*>::iterator fac;

	    // Apply the second-order part of every active operator to the
	    // gradient jump and sum up the (factor-weighted) results.
	    for (it = const_cast<DOFMatrix*>(matrix[system])->getOperatorsBegin(),
		   fac = const_cast<DOFMatrix*>(matrix[system])->getOperatorEstFactorBegin(); 
		 it != const_cast<DOFMatrix*>(matrix[system])->getOperatorsEnd(); 
		 ++it, ++fac) {

	      if (*fac == NULL || **fac != 0.0) {
		for (int iq = 0; iq < nPointsSurface_; iq++) {
		  localJump_[iq].set(0.0);
		}
		
		(*it)->weakEvalSecondOrder(nPointsSurface_,
					   grdUhEl_.getValArray(),
					   localJump_.getValArray());
		double factor = *fac ? **fac : 1.0;
		if (factor != 1.0) {
		  for (int i = 0; i < nPointsSurface_; i++) {
		    localJump_[i] *= factor;
		  }
		}
		
		for (int i = 0; i < nPointsSurface_; i++) {
		  jump_[i] += localJump_[i];
		}
	      }		
	    }
	  }
	      
	  val = 0.0;
	  for (int iq = 0; iq < nPointsSurface_; iq++) {
	    val += surfaceQuad_->getWeight(iq) * (jump_[iq] * jump_[iq]);
	  }
	      
	  // Average of both element determinants.
	  double d = 0.5 * (det + detNeigh);

	  if (norm == NO_NORM || norm == L2_NORM)
	    val *= C1 * h2_from_det(d, dim) * d;
	  else
	    val *= C1 * d;
	      
	  if (parametric) {
	    neighInfo = parametric->removeParametricInfo(neighInfo);
	  }

	  neigh->setEstimation(neigh->getEstimation(row) + val, row);
	  est_el += val;
	} 
      } 
       
      val = fh[std::max(row, 0)]->
	getBoundaryManager()->
	boundResidual(elInfo, matrix[std::max(row, 0)], uh[std::max(row, 0)]);
      if (norm == NO_NORM || norm == L2_NORM)
	val *= C1 * h2;
      else
	val *= C1;
	
      est_el += val;
    } 
  

    el->setEstimation(est_el, row);

    est_sum += est_el;
    est_max = max(est_max, est_el);

    // Faces of this element are done; neighbours visited later skip them.
    elInfo->getElement()->setMark(0);  
  }

  void r(const ElInfo *elInfo,
	 int nPoints,
	 const double *uhIq,
	 const WorldVector<double> *grdUhIq,
	 const WorldMatrix<double> *D2UhIq,
	 const double *uhOldIq,
	 const WorldVector<double> *grdUhOldIq,
	 const WorldMatrix<double> *D2UhOldIq,
	 DOFMatrix *A, 
	 DOFVector<double> *fh,
	 Quadrature *quad,
	 double *result)
  {
    std::vector<Operator*>::iterator it;
    std::vector<double*>::iterator fac;

    // lhs
    for (it = const_cast<DOFMatrix*>(A)->getOperatorsBegin(),
	   fac = const_cast<DOFMatrix*>(A)->getOperatorEstFactorBegin(); 
	 it != const_cast<DOFMatrix*>(A)->getOperatorsEnd(); 
	 ++it, ++fac) {
     
      double factor = *fac ? **fac : 1.0;

      if (factor) {
	if (D2UhIq) {
	  (*it)->evalSecondOrder(nPoints, uhIq, grdUhIq, D2UhIq, result, -factor);
	}

	if (grdUhIq) {
	  (*it)->evalFirstOrderGrdPsi(nPoints, uhIq, grdUhIq, D2UhIq, result, factor);
	  (*it)->evalFirstOrderGrdPhi(nPoints, uhIq, grdUhIq, D2UhIq, result, factor);
	}
	
	if (uhIq) {
	  (*it)->evalZeroOrder(nPoints, uhIq, grdUhIq, D2UhIq, result, factor);
	}
      }
    }
    
    // rhs
    for (it = const_cast<DOFVector<double>*>(fh)->getOperatorsBegin(),
	 fac = const_cast<DOFVector<double>*>(fh)->getOperatorEstFactorBegin(); 
	 it != const_cast<DOFVector<double>*>(fh)->getOperatorsEnd(); 
	 ++it, ++fac) {

      double factor = *fac ? **fac : 1.0;

      if (factor) {
	if ((*it)->getUhOld()) {
	  if (D2UhOldIq) {
	    (*it)->evalSecondOrder(nPoints, 
				   uhOldIq, grdUhOldIq, D2UhOldIq, 
				   result, factor);
	  }
	  if (grdUhOldIq) {
	    (*it)->evalFirstOrderGrdPsi(nPoints, 
					uhOldIq, grdUhOldIq, D2UhOldIq, 
					result, -factor);
	    (*it)->evalFirstOrderGrdPhi(nPoints, 
					uhOldIq, grdUhOldIq, D2UhOldIq, 
					result, -factor);
	  }
	  if (uhOldIq) {
	    (*it)->evalZeroOrder(nPoints, 
				 uhOldIq, grdUhOldIq, D2UhOldIq, 
				 result, -factor);
	  }
	} else {
	  std::vector<double> fx(nPoints, 0.0);
	  (*it)->getC(elInfo, nPoints, fx);

	  for (int iq = 0; iq < nPoints; iq++)
	    result[iq] -= factor * fx[iq];
	}
      }
    }    
  }


}