//
// Software License for AMDiS
//
// Copyright (c) 2010 Dresden University of Technology 
// All rights reserved.
// Authors: Simon Vey, Thomas Witkowski et al.
//
// This file is part of AMDiS
//
// See also license.opensource.txt in the distribution.


#include "time/RosenbrockStationary.h"
#include "io/VtkWriter.h"
#include "ProblemStat.h"
#include "SystemVector.h"
#include "OEMSolver.h"
#include "Debug.h"

#ifdef HAVE_PARALLEL_DOMAIN_AMDIS
#include "parallel/MeshDistributor.h"
#endif

namespace AMDiS {

  void RosenbrockStationary::acceptTimestep()
  {
    *solution = *newUn;
    *unVec = *newUn;
  }


  /// Allocates all work vectors used by the Rosenbrock stages. Each is a
  /// copy of the current solution vector, so it has the same components and
  /// FE spaces.
  ///
  /// stageSolution and unVec are zeroed. So is one stage vector per stage of
  /// rm, the Rosenbrock method. The remaining work vectors (timeRhsVec,
  /// newUn, tmp, lowSol) keep the copied solution values until they are
  /// overwritten.
  void RosenbrockStationary::init()
  {
    stageSolution = new SystemVector(*solution);
    unVec         = new SystemVector(*solution);
    timeRhsVec    = new SystemVector(*solution);
    newUn         = new SystemVector(*solution);
    tmp           = new SystemVector(*solution);
    lowSol        = new SystemVector(*solution);

    stageSolution->set(0.0);
    unVec->set(0.0);

    // One zero-initialised storage vector per stage: Y_0 ... Y_{s-1}.
    const int nStages = rm->getStages();
    stageSolutions.resize(nStages);
    for (int s = 0; s < nStages; ++s) {
      stageSolutions[s] = new SystemVector(*solution);
      stageSolutions[s]->set(0.0);
    }
  }


  /**
   * Performs one full Rosenbrock time step. Each stage of the method held in
   * rm is run as a stationary assemble-and-solve through the ProblemStat base
   * class.
   *
   * Notation: Y_j denotes stageSolutions[j]; s = rm->getStages().
   *
   * For every stage i = 0 .. s-1:
   *  - stageSolution = unVec + sum_{j<i} rm->getA(i, j) * Y_j.
   *    Operators added via addOperator() read this vector as their UhOld.
   *  - Each registered Dirichlet vector is set to its boundary function
   *    (interpolated) minus the matching component of stageSolution.
   *  - timeRhsVec = sum_{j<i} (rm->getC(i, j) / *tauPtr) * Y_j.
   *    The right-hand-side operator from addTimeOperator() reads it.
   *  - The system is assembled and solved; *solution then holds Y_i, which
   *    is copied into stageSolutions[i].
   *  - newUn  += rm->getM1(i) * Y_i
   *    lowSol += rm->getM2(i) * Y_i
   *    (both start from unVec).
   *
   * Afterwards, for every component, lowSol is overwritten with
   * lowSol - newUn and its L2 norm is passed to adaptInfo->setTimeEstSum().
   * NOTE(review): presumably newUn is the step solution and lowSol an
   * embedded lower-order solution, making this a time-error estimate.
   *
   * On return, *solution still holds the last stage vector Y_{s-1}, not
   * newUn. acceptTimestep() copies newUn into solution and unVec.
   *
   * @param adaptInfo  forwarded to ProblemStat; receives the per-component
   *                   time error estimates.
   * @param flag       forwarded unchanged to ProblemStat::buildAfterCoarsen().
   * @param asmMatrix  ignored: the matrix is assembled only in stage 0
   *                   (i == 0) and reused by the later stages.
   * @param asmVector  forwarded unchanged to ProblemStat::buildAfterCoarsen().
   */
  void RosenbrockStationary::buildAfterCoarsen(AdaptInfo *adaptInfo, Flag flag,
					       bool asmMatrix, bool asmVector)
  {        
    FUNCNAME("RosenbrockStationary::buildAfterCoarsen()");

    // The time step size is read through tauPtr below (timeRhsVec scaling).
    TEST_EXIT(tauPtr)("No tau pointer defined in stationary problem!\n");

    // On the very first call, take the current solution as U_n.
    if (first) {
      first = false;
      *unVec = *solution;      
    }
    
    // Both step accumulators start from U_n.
    *newUn = *unVec;    
    *lowSol = *unVec;

    for (int i = 0; i < rm->getStages(); i++) {      
      // Stage argument: U_n + sum_{j<i} a_ij * Y_j.
      *stageSolution = *unVec;
      for (int j = 0; j < i; j++) {
	*tmp = *(stageSolutions[j]);
	*tmp *= rm->getA(i, j);
	*stageSolution += *tmp;
      }

      // Dirichlet values for the stage unknown:
      // boundary function minus the current stage argument.
      // NOTE(review): subtracts component `row`. addDirichletBC() builds vec
      // on componentSpaces[row]. This is identical to using `col` when
      // row == col; confirm which component is meant when row != col.
      for (unsigned int j = 0; j < boundaries.size(); j++) {
	boundaries[j].vec->interpol(boundaries[j].fct);
	*(boundaries[j].vec) -= *(stageSolution->getDOFVector(boundaries[j].row));
      }

      // Coupling term to earlier stages: sum_{j<i} (c_ij / tau) * Y_j.
      timeRhsVec->set(0.0);
      for (int j = 0; j < i; j++) {
	*tmp = *(stageSolutions[j]);
	*tmp *= (rm->getC(i, j) / *tauPtr);
	*timeRhsVec += *tmp;
      }

      // Matrix assembled only in the first stage; the vector is assembled
      // per asmVector. NOTE(review): the two bool arguments of
      // ProblemStat::solve are presumably createMatrixData /
      // storeMatrixData (create at stage 0, keep while stages remain).
      // Confirm against ProblemStat.
      ProblemStat::buildAfterCoarsen(adaptInfo, flag, (i == 0), asmVector);
      ProblemStat::solve(adaptInfo, i == 0, i + 1 < rm->getStages());

      // *solution now holds Y_i.
      *(stageSolutions[i]) = *solution;
      
      // newUn += m1_i * Y_i
      *tmp = *solution;
      *tmp *= rm->getM1(i);

      *newUn += *tmp;

      // lowSol += m2_i * Y_i
      *tmp = *solution;
      *tmp *= rm->getM2(i);
      *lowSol += *tmp;
    }
    
    // Per-component time error estimate: || lowSol - newUn ||_{L2}.
    // lowSol is overwritten in place; it is re-initialised on the next call.
    for (int i = 0; i < nComponents; i++) {
      (*(lowSol->getDOFVector(i))) -= (*(newUn->getDOFVector(i)));
      adaptInfo->setTimeEstSum(lowSol->getDOFVector(i)->l2norm(), i);
    }   
  }


  void RosenbrockStationary::solve(AdaptInfo *adaptInfo, bool, bool)
  {}


  /// Registers op as a right-hand-side (vector) operator for equation `row`.
  ///
  /// Its UhOld is bound to component `col` of stageSolution, so the operator
  /// is evaluated at the current stage argument built in
  /// buildAfterCoarsen().
  ///
  /// @param op         operator to add; must not already have a UhOld set.
  /// @param row        equation (row) index.
  /// @param col        component of stageSolution the operator reads.
  /// @param factor     passed unchanged to ProblemStat::addVectorOperator().
  /// @param estFactor  passed unchanged to ProblemStat::addVectorOperator().
  void RosenbrockStationary::addOperator(Operator &op, int row, int col, 
					 double *factor, double *estFactor)
  {
    FUNCNAME("RosenbrockStationary::addOperator()");

    // UhOld is reserved for the stage argument; refuse a pre-set one.
    TEST_EXIT(!op.getUhOld())("UhOld is not allowed to be set!\n");

    DOFVector<double> *stageArgument = stageSolution->getDOFVector(col);
    op.setUhOld(stageArgument);

    ProblemStat::addVectorOperator(op, row, factor, estFactor);
  }
  

  /// Adds op as a matrix (Jacobian) block (row, col) of the stage system.
  ///
  /// Both the assembly and the estimator factor are &minusOne.
  /// NOTE(review): presumably minusOne == -1, i.e. the Jacobian enters with
  /// a negative sign.
  ///
  /// Caller-supplied factors are not supported: factor and estFactor must
  /// be NULL.
  void RosenbrockStationary::addJacobianOperator(Operator &op, int row, int col, 
						 double *factor, double *estFactor)
  {
    FUNCNAME("RosenbrockStationary::addJacobianOperator()");
    
    TEST_EXIT(!factor)("Not yet implemented!\n");
    TEST_EXIT(!estFactor)("Not yet implemented!\n");

    double *negated = &minusOne;
    ProblemStat::addMatrixOperator(op, row, col, negated, negated);
  }


  void RosenbrockStationary::addTimeOperator(int row, int col)
  {
    FUNCNAME("RosenbrockStationary::addTimeOperator()");

    TEST_EXIT(invTauGamma)("This should not happen!\n");

    Operator *op = new Operator(componentSpaces[row], componentSpaces[col]);
    op->addZeroOrderTerm(new Simple_ZOT);
    ProblemStat::addMatrixOperator(op, row, col, invTauGamma, invTauGamma);

    Operator *opRhs = new Operator(componentSpaces[row]);
    opRhs->addZeroOrderTerm(new VecAtQP_ZOT(timeRhsVec->getDOFVector(col)));
    ProblemStat::addVectorOperator(opRhs, row, &minusOne, &minusOne);
  }

  
  /// Registers a Dirichlet condition whose values are recomputed per stage.
  ///
  /// Instead of passing fct directly to ProblemStat, a DOF vector on
  /// componentSpaces[row] is created and handed over. The (fct, vector,
  /// row, col) tuple is remembered in boundaries so buildAfterCoarsen() can
  /// refill the vector before each stage solve.
  void RosenbrockStationary::addDirichletBC(BoundaryType type, int row, int col,
					    AbstractFunction<double, WorldVector<double> > *fct)
  {
    FUNCNAME("RosenbrockStationary::addDirichletBC()");

    DOFVector<double> *stageBcValues =
      new DOFVector<double>(componentSpaces[row], "vec");

    RosenbrockBoundary record = {fct, stageBcValues, row, col};
    boundaries.push_back(record);

    ProblemStat::addDirichletBC(type, row, col, stageBcValues);
  }

}