// ============================================================================
// ==                                                                        ==
// == AMDiS - Adaptive multidimensional simulations                          ==
// ==                                                                        ==
// ============================================================================
// ==                                                                        ==
// ==  TU Dresden                                                            ==
// ==                                                                        ==
// ==  Institut für Wissenschaftliches Rechnen                               ==
// ==  Zellescher Weg 12-14                                                  ==
// ==  01069 Dresden                                                         ==
// ==  germany                                                               ==
// ==                                                                        ==
// ============================================================================
// ==                                                                        ==
// ==  https://gforge.zih.tu-dresden.de/projects/amdis/                      ==
// ==                                                                        ==
// ============================================================================

/** \file AdaptInfo.h */

#ifndef AMDIS_ADAPTINFO_H
#define AMDIS_ADAPTINFO_H

#include "MatrixVector.h"
#include "Parameters.h"
#include "Serializable.h"

namespace AMDiS {

  /**
   * \ingroup Adaption
   * 
   * \brief
   * Holds adapt parameters and infos about the problem. Base class
   * for AdaptInfoScal and AdaptInfoVec.
   */
  class AdaptInfo : public Serializable
  {
  protected:
    /** \brief
     * Stores adapt infos for a scalar problem or for one component of a 
     * vector valued problem.
     */
    class ScalContent {
    public:
      /// Constructor.
      ScalContent(std::string prefix) 
	: est_sum(0.0),
	  est_t_sum(0.0),
	  est_max(0.0),
	  est_t_max(0.0),
	  fac_max(0.0),
	  fac_sum(1.0),
	  spaceTolerance(0.0),
	  timeTolerance(0.0),
	  timeErrLow(0.0),
	  coarsenAllowed(0),
	  refinementAllowed(1),
	  refineBisections(1),
	  coarseBisections(1)	  	
      {
	double timeTheta2 = 0.3;

	// TODO: obsolete parameters timeTheta2, relTimeErr, relSpaceErr

	GET_PARAMETER(0, prefix + "->tolerance", "%f", &spaceTolerance);
	GET_PARAMETER(0, prefix + "->time tolerance", "%f", &timeTolerance);
	GET_PARAMETER(0, prefix + "->time theta 2", "%f", &timeTheta2);
	GET_PARAMETER(0, prefix + "->coarsen allowed", "%d", &coarsenAllowed);
	GET_PARAMETER(0, prefix + "->refinement allowed", "%d", &refinementAllowed);
	GET_PARAMETER(0, prefix + "->refine bisections", "%d", &refineBisections);
	GET_PARAMETER(0, prefix + "->coarsen bisections", "%d", &coarseBisections);
	GET_PARAMETER(0, prefix + "->sum factor", "%f", &fac_sum);
	GET_PARAMETER(0, prefix + "->max factor", "%f", &fac_max);

	timeErrLow = timeTolerance * timeTheta2;
      }

      /// Sum of all error estimates
      double est_sum;

      /// Sum of all time error estimates
      double est_t_sum;

      /// maximal local error estimate.
      double est_max;

      /// Maximum of all time error estimates
      double est_t_max;
      
      /// factors to combine max and integral time estimate
      double fac_max, fac_sum;

      /// Tolerance for the (absolute or relative) error
      double spaceTolerance;

      /// Time tolerance.
      double timeTolerance;

      /// Lower bound for the time error.
      double timeErrLow;

      /// true if coarsening is allowed, false otherwise.
      int coarsenAllowed;

      /// true if refinement is allowed, false otherwise.
      int refinementAllowed;

      /** \brief
       * parameter to tell the marking strategy how many bisections should be 
       * performed when an element is marked for refinement; usually the value is
       * 1 or DIM
       */
      int refineBisections;

      /** \brief
       * parameter to tell the marking strategy how many bisections should
       * be undone when an element is marked for coarsening; usually the value is 
       * 1 or DIM
       */                          
      int coarseBisections;    
    };

  public:
    /// Constructor.
    AdaptInfo(std::string name_, int size = 1) 
      : name(name_), 
	spaceIteration(-1),
	maxSpaceIteration(-1),
	timestepIteration(0),
	maxTimestepIteration(30),
	timeIteration(0),
	maxTimeIteration(30),
	time(0.0),
	startTime(0.0),
	endTime(1.0),
	timestep(0.0),
	lastProcessedTimestep(0.0),
	minTimestep(0.0),
	maxTimestep(1.0),
	timestepNumber(0),
	nTimesteps(0),
	solverIterations(0),
	maxSolverIterations(0),
	solverTolerance(1e-8),
	solverResidual(0.0),
        scalContents(size),
	deserialized(false),
	rosenbrockMode(false)
    {
      GET_PARAMETER(0, name_ + "->start time", "%f", &startTime);
      time = startTime;
      GET_PARAMETER(0, name_ + "->timestep", "%f", &timestep);
      GET_PARAMETER(0, name_ + "->end time", "%f", &endTime);
      GET_PARAMETER(0, name_ + "->max iteration", "%d", &maxSpaceIteration);
      GET_PARAMETER(0, name_ + "->max timestep iteration", "%d", &maxTimestepIteration);
      GET_PARAMETER(0, name_ + "->max time iteration", "%d", &maxTimeIteration);

      GET_PARAMETER(0, name_ + "->min timestep", "%f", &minTimestep);
      GET_PARAMETER(0, name_ + "->max timestep", "%f", &maxTimestep);

      GET_PARAMETER(0, name_ + "->number of timesteps", "%d", &nTimesteps);

      if (size == 1) {
	scalContents[0] = new ScalContent(name); 
      } else {
	char number[5];
	for (int i = 0; i < size; i++) {
	  sprintf(number, "[%d]", i);
	  scalContents[i] = new ScalContent(name + std::string(number)); 
	}
      }
    }

    /// Destructor.
    virtual ~AdaptInfo() 
    {
      for (unsigned int i = 0;  i < scalContents.size(); i++)
	delete scalContents[i];
    }

    inline void reset() 
    {
      spaceIteration = -1;
      timestepIteration = 0;
      timeIteration = 0;
      time = 0.0;
      timestep = 0.0;
      timestepNumber = 0;
      solverIterations = 0;
      solverResidual = 0.0;

      GET_PARAMETER(0, name + "->timestep", "%f", &timestep);
      lastProcessedTimestep=timestep;
    }

    /// Returns whether space tolerance is reached.
    virtual bool spaceToleranceReached() 
    {
      for (unsigned int i = 0; i < scalContents.size(); i++) {
	std::cout << "est_sum:" <<scalContents[i]->est_sum 
		  << " spaceTol: " << scalContents[i]->spaceTolerance 
		  << std::endl;
	if (!(scalContents[i]->est_sum < scalContents[i]->spaceTolerance))
	  return false;
      }

      return true;
    }

    /// Returns whether space tolerance of component i is reached.
    virtual bool spaceToleranceReached(int i) 
    {
      if (!(scalContents[i]->est_sum < scalContents[i]->spaceTolerance))
	return false;
      else
	return true;
    }

    /// Returns whether time tolerance is reached.
    virtual bool timeToleranceReached() 
    {
      for (unsigned int i = 0; i < scalContents.size(); i++)
	if (!(getTimeEstCombined(i) < scalContents[i]->timeTolerance))
	  return false;

      return true;
    }

    /// Returns whether time tolerance of component i is reached.
    virtual bool timeToleranceReached(int i) 
    {
      if (!(getTimeEstCombined(i) < scalContents[i]->timeTolerance))
	return false;
      else
	return true;
    }

    /// Returns whether time error is under its lower bound.
    virtual bool timeErrorLow() 
    {
      for (unsigned int i = 0; i < scalContents.size(); i++)
	if (!(getTimeEstCombined(i) < scalContents[i]->timeErrLow))
	  return false;

      return true;
    }
    /// Returns the time estimation as a combination 
    /// of maximal and integral time error 
    double getTimeEstCombined(unsigned i) const 
    { 
      return scalContents[i]->est_t_max*scalContents[i]->fac_max
	+scalContents[i]->est_t_sum*scalContents[i]->fac_sum; 
    }


    /// Print debug information about time error and its bound.
    void printTimeErrorLowInfo() 
    {
      for (unsigned int i = 0; i < scalContents.size(); i++){
	std::cout << "    Time error estimate  = " << getTimeEstCombined(i)
		  << "    Time error estimate sum = " << scalContents[i]->est_t_sum 
		  << "    Time error estimate max = " << scalContents[i]->est_t_max 
		  << "    Time error low bound = " << scalContents[i]->timeErrLow  
		  << "    Time error high bound = " << scalContents[i]->timeTolerance << "\n";
      }
    }

    /// Returns \ref spaceIteration.
    inline int getSpaceIteration() 
    { 
      return spaceIteration; 
    }

    /// Sets \ref spaceIteration.
    inline void setSpaceIteration(int it) 
    { 
      spaceIteration = it; 
    }
  
    /// Returns \ref maxSpaceIteration.
    inline int getMaxSpaceIteration() 
    { 
      return maxSpaceIteration;
    }

    /// Sets \ref maxSpaceIteration.
    inline void setMaxSpaceIteration(int it) 
    { 
      maxSpaceIteration = it; 
    }
  
    /// Increments \ref spaceIteration by 1;
    inline void incSpaceIteration() 
    { 
      spaceIteration++; 
    }

    /// Sets \ref timestepIteration.
    inline void setTimestepIteration(int it) 
    { 
      timestepIteration = it; 
    }
  
    /// Returns \ref timestepIteration.
    inline int getTimestepIteration() 
    { 
      return timestepIteration; 
    }

    /// Increments \ref timestepIteration by 1;
    inline void incTimestepIteration() 
    { 
      timestepIteration++; 
    }

    /// Returns \ref maxTimestepIteration.
    inline int getMaxTimestepIteration() 
    { 
      return maxTimestepIteration; 
    }

    /// Sets \ref maxTimestepIteration.
    inline void setMaxTimestepIteration(int it) 
    { 
      maxTimestepIteration = it; 
    }
  
    /// Sets \ref timeIteration.
    inline void setTimeIteration(int it) 
    { 
      timeIteration = it; 
    }
  
    /// Returns \ref timeIteration.
    inline int getTimeIteration() 
    { 
      return timeIteration; 
    }

    /// Increments \ref timesIteration by 1;
    inline void incTimeIteration() 
    { 
      timeIteration++; 
    }

    /// Returns \ref maxTimeIteration.
    inline int getMaxTimeIteration() 
    { 
      return maxTimeIteration; 
    }

    /// Sets \ref maxTimeIteration.
    inline void setMaxTimeIteration(int it) 
    { 
      maxTimeIteration = it; 
    }
  
    /// Returns \ref timestepNumber.
    inline int getTimestepNumber() 
    { 
      return timestepNumber; 
    }

    /// Returns \ref nTimesteps.
    inline int getNumberOfTimesteps() 
    {
      return nTimesteps;
    }

    /// Increments \ref timestepNumber by 1;
    inline void incTimestepNumber() 
    { 
      timestepNumber++; 
    }

    /// Sets \ref est_sum.
    inline void setEstSum(double e, int index) 
    {
      scalContents[index]->est_sum = e;
    }

    /// Sets \ref est_max.
    inline void setEstMax(double e, int index) 
    {
      scalContents[index]->est_max = e;
    }

    /// Sets \ref est_max.
    inline void setTimeEstMax(double e, int index) 
    {
      scalContents[index]->est_t_max = e;
    }

    /// Sets \ref est_t_sum.
    inline void setTimeEstSum(double e, int index) 
    {
      scalContents[index]->est_t_sum = e;
    }

    /// Returns \ref est_sum.
    inline double getEstSum(int index) 
    { 
      FUNCNAME("AdaptInfo::getEstSum()");

      TEST_EXIT_DBG(static_cast<unsigned int>(index) < scalContents.size())
	("Wrong index for adaptInfo!\n");

      return scalContents[index]->est_sum; 
    }

    /// Returns \ref est_t_sum.
    inline double getEstTSum(int index) 
    { 
      return scalContents[index]->est_t_sum; 
    }

    /// Returns \ref est_max.
    inline double getEstMax(int index) 
    { 
      FUNCNAME("AdaptInfo::getEstSum()");

      TEST_EXIT_DBG(static_cast<unsigned int>(index) < scalContents.size())
	("Wrong index for adaptInfo!\n");

      return scalContents[index]->est_max; 
    }

    /// Returns \ref est_max.
    inline double getTimeEstMax(int index) 
    { 
      return scalContents[index]->est_t_max; 
    }

    /// Returns \ref est_t_sum.
    inline double getTimeEstSum(int index) 
    { 
      return scalContents[index]->est_t_sum; 
    }

    /// Returns \ref spaceTolerance.
    inline double getSpaceTolerance(int index) 
    { 
      return scalContents[index]->spaceTolerance; 
    }  

    /// Sets \ref spaceTolerance.
    inline void setSpaceTolerance(int index, double tol) 
    { 
      scalContents[index]->spaceTolerance = tol; 
    }  

    /// Returns \ref timeTolerance.
    inline double getTimeTolerance(int index) 
    { 
      return scalContents[index]->timeTolerance; 
    }  

    /// Sets \ref time
    inline double setTime(double t) 
    { 
      time = t; 
      if (time > endTime) 
	time = endTime;
      if (time < startTime) 
	time = startTime;

      return time;
    }

    /// Gets \ref time
    inline double getTime() 
    { 
      return time; 
    }  

    /// Gets \ref &time
    inline double *getTimePtr() 
    { 
      return &time; 
    }  

    /// Sets \ref timestep
    inline double setTimestep(double t) 
    { 
      timestep = t; 
      if (timestep > maxTimestep)
	timestep = maxTimestep;
      if (timestep < minTimestep)
	timestep = minTimestep;
      if (time + timestep > endTime)
	timestep = endTime - time;
      
      return timestep;
    }
    /// Gets \ref timestep
    inline double getTimestep() 
    { 
      return timestep; 
    }

    inline void setLastProcessedTimestep(double t){
	lastProcessedTimestep=t;
    } 

    inline double getLastProcessedTimestep(){
	return lastProcessedTimestep;
    } 

    /** \brief
     * Returns true, if the end time is reached and no more timestep
     * computations must be done.
     */
    inline bool reachedEndTime() 
    {
      if (nTimesteps > 0) 
	return !(timestepNumber < nTimesteps);

      return !(time < endTime - DBL_TOL);
    }


    /// Sets \ref minTimestep
    inline void setMinTimestep(double t) 
    { 
      minTimestep = t; 
    }

    /// Gets \ref minTimestep
    inline double getMinTimestep() 
    { 
      return minTimestep; 
    }  

    /// Sets \ref maxTimestep
    inline void setMaxTimestep(double t) 
    { 
      maxTimestep = t; 
    }

    /// Gets \ref maxTimestep
    inline double getMaxTimestep() 
    { 
      return maxTimestep; 
    }  

    /// Gets \ref &timestep
    inline double *getTimestepPtr() 
    { 
      return &timestep; 
    }  

    /// Sets \ref startTime = time
    inline void setStartTime(double time) 
    { 
      startTime = time; 
    }

    /// Sets \ref endTime = time
    inline void setEndTime(double time) 
    { 
      endTime = time; 
    }

    /// Returns \ref startTime
    inline double getStartTime() 
    { 
      return startTime; 
    }

    /// Returns \ref endTime
    inline double getEndTime() 
    { 
      return endTime; 
    }

    /// Returns \ref timeErrLow.
    inline double getTimeErrLow(int index) 
    { 
      return scalContents[index]->timeErrLow; 
    }  

    /// Returns whether coarsening is allowed or not.
    inline bool isCoarseningAllowed(int index) 
    {
      return (scalContents[index]->coarsenAllowed == 1);
    }

    /// Returns whether coarsening is allowed or not.
    inline bool isRefinementAllowed(int index) 
    {
      return (scalContents[index]->refinementAllowed == 1);
    }

    ///
    inline void allowRefinement(bool allow, int index) 
    {
      scalContents[index]->refinementAllowed = allow;
    }

    ///
    inline void allowCoarsening(bool allow, int index) 
    {
      scalContents[index]->coarsenAllowed = allow;
    }

    /// Returns \ref refineBisections
    inline const int getRefineBisections(int index) const 
    {
      return scalContents[index]->refineBisections;
    }

    /// Returns \ref coarseBisections
    inline const int getCoarseBisections(int index) const 
    {
      return scalContents[index]->coarseBisections;
    }    

    inline int getSize() 
    { 
      return scalContents.size(); 
    }

    inline bool getRosenbrockMode()
    {
      return rosenbrockMode;
    }

    inline void setSolverIterations(int it) 
    {
      solverIterations = it;
    }
  
    inline int getSolverIterations() 
    {
      return solverIterations;
    }
  
    inline void setMaxSolverIterations(int it) 
    {
      maxSolverIterations = it;
    }
  
    inline int getMaxSolverIterations() 
    {
      return maxSolverIterations;
    }
  
    inline void setSolverTolerance(double tol) 
    {
      solverTolerance = tol;
    }
  
    inline double getSolverTolerance() 
    {
      return solverTolerance;
    }
  
    inline void setSolverResidual(double res) 
    {
      solverResidual = res;
    }
  
    inline double getSolverResidual() 
    {
      return solverResidual;
    }

    /// Returns true, if the adaptive procedure was deserialized from a file.
    const bool isDeserialized() const 
    {
      return deserialized;
    }

    inline void setIsDeserialized(bool b) 
    {
      deserialized = b;
    }

    inline void setRosenbrockMode(bool b)
    {
      rosenbrockMode = b;
    }

    /// Creates new scalContents with the given size.
    void setScalContents(int newSize);

    /** \brief
     * Resets timestep, current time and time boundaries without
     * any check. Is used by the parareal algorithm.
     */
    void resetTimeValues(double newTimeStep,
			 double newStartTime,
			 double newEndTime)
    {
      time = newStartTime;
      startTime = newStartTime;
      endTime = newEndTime;
      timestep = newTimeStep;
      timestepNumber = 0;
    }

    void serialize(std::ostream& out);

    void deserialize(std::istream& in);

  protected:
    /// Name.
    std::string name;

    /// Current space iteration
    int spaceIteration;

    /** \brief
     * maximal allowed number of iterations of the adaptive procedure; if 
     * maxIteration <= 0, no iteration bound is used
     */
    int maxSpaceIteration;

    /// Current timestep iteration
    int timestepIteration;

    /// Maximal number of iterations for choosing a timestep
    int maxTimestepIteration;

    /// Current time iteration
    int timeIteration;

    /// Maximal number of time iterations
    int maxTimeIteration;

    /// Actual time, end of time interval for current time step
    double time;

    /// Initial time
    double startTime;

    /// Final time
    double endTime;

    ///Time step size to be used
    double timestep;

    /// Last processed time step size of finished iteration
    double lastProcessedTimestep;

    /// Minimal step size
    double minTimestep;

    /// Maximal step size
    double maxTimestep;

    /// Number of current time step
    int timestepNumber;

    /** \brief
     * Per default this value is 0 and not used. If it is set to a non-zero value,
     * the computation of the stationary problem is done nTimesteps times with a
     * fixed timestep.
     */
    int nTimesteps;
  
    /// number of iterations needed of linear or nonlinear solver
    int solverIterations;

    /// maximal number of iterations needed of linear or nonlinear solver
    int maxSolverIterations;

    ///
    double solverTolerance;

    ///
    double solverResidual;

    /// Scalar adapt infos.
    std::vector<ScalContent*> scalContents;

    /// Is true, if the adaptive procedure was deserialized from a file.
    bool deserialized;

    /// Is true, if the time adaption is controlled by a Rosenbrock Method.
    bool rosenbrockMode;
  };

}

#endif //  AMDIS_ADAPTINFO_H