// BiCGStab2.hh
#include "Preconditioner.h"

namespace AMDiS
{

  template<typename VectorType>
  BiCGStab2<VectorType>::BiCGStab2(std::string name)
    : OEMSolver<VectorType>(name),
      r(NULL), rstar(NULL), u(NULL), v(NULL), s(NULL), w(NULL), t(NULL),
      xmin(NULL)
  {}

  template<typename VectorType>
  BiCGStab2<VectorType>::~BiCGStab2()
  {}

  template<typename VectorType>
  void BiCGStab2<VectorType>::init()
  {
    r     = this->vectorCreator->create();
    rstar = this->vectorCreator->create();
    u     = this->vectorCreator->create();
    v     = this->vectorCreator->create();
    s     = this->vectorCreator->create();
    w     = this->vectorCreator->create();
    t     = this->vectorCreator->create();
    xmin  = this->vectorCreator->create();
  }

  template<typename VectorType>
  void BiCGStab2<VectorType>::exit()
  {
    this->vectorCreator->free(r);
    this->vectorCreator->free(rstar);
    this->vectorCreator->free(u);
    this->vectorCreator->free(v);
    this->vectorCreator->free(s);
    this->vectorCreator->free(w);
    this->vectorCreator->free(t);
    this->vectorCreator->free(xmin);
  }

  template<typename VectorType>
  int BiCGStab2<VectorType>::solveSystem(MatVecMultiplier<VectorType> *mv,
					 VectorType *x, VectorType *b, bool reuseMatrix)
  {
    FUNCNAME("BiCGStab2::solveSystem()");

    double res, old_res = -1.0;
    double rho0, alpha, omega1, omega2, rho1, beta, gamma, mu, nu, tau;
    int    iter;

    const double TOL = 1e-30;

    /*------------------------------------------------------------------------*/
    /*---  Initalization  ----------------------------------------------------*/
    /*------------------------------------------------------------------------*/

    *u = *b;
    if (this->leftPrecon)
      this->leftPrecon->precon(u);
    double normB = norm(u);

    if (normB < TOL) {
      INFO(this->info, 2)("b == 0; x = 0 is the solution of the linear system\n");
      setValue(*x, 0.0);
      this->residual = 0.0;
      return(0);
    }
    
    double save_tolerance = this->tolerance;
    if (this->relative)
      this->tolerance *= normB;

    *xmin = *x;
    int imin = 0;

    // r = b - Ax
    mv->matVec(NoTranspose, *x, *r);
    *r *= -1.0;
    *r += *b;

    if (this->leftPrecon) 
      this->leftPrecon->precon(r);

    /*---  check initial residual  -------------------------------------------*/

    res = norm(r);

    START_INFO();
    if (SOLVE_INFO(0, res, &old_res)) {
      if (this->relative)
	this->tolerance = save_tolerance;
      return(0);
    }

    double normrmin = res;

    // setting for the method
    *rstar  = *r;
    *rstar *= 1.0 / res;

    rho0 = 1.0; 
    alpha = 0.0; 
    omega2 = 1.0;
    setValue(*u, 0.0);

    /*------------------------------------------------------------------------*/
    /*---  Iteration  --------------------------------------------------------*/
    /*------------------------------------------------------------------------*/

    for (iter = 1; iter <= this->max_iter; iter++) {
      rho0 *= -omega2;

      /*---  even BiCG step  -------------------------------------------------*/
      
      // updating u
      rho1 = *r * *rstar;
      beta = alpha * rho1 / rho0;
      rho0 = rho1;
      *u *= -beta;
      *u += *r;
      
      // computing v
      mv->matVec(NoTranspose, *u, *v);
      if (this->leftPrecon) 
	this->leftPrecon->precon(v);
      
      // Updating x and r
      gamma = *v * *rstar;
      alpha = rho0 / gamma;
      axpy(alpha, *u, *x);
      axpy(-alpha, *v, *r);
      
      // computing s
      mv->matVec(NoTranspose, *r, *s);
      if (this->leftPrecon) 
	this->leftPrecon->precon(s);
      
      /*---  odd BiCG step  --------------------------------------------------*/
      
      // updating v
      rho1 = *s * *rstar;
      beta = alpha * rho1 / rho0;
      rho0 = rho1;
      *v *= -beta;
      *v += *s;
      
      // computing w
      mv->matVec(NoTranspose, *v, *w);
      if (this->leftPrecon) 
	this->leftPrecon->precon(w);
      
      // updating u, r and s
      gamma = *w * *rstar;
      alpha = rho0 / gamma;
      *u *= -beta;
      *u += *r;
      axpy(-alpha, *v, *r);
      axpy(-alpha, *w, *s);

      // computing t
      mv->matVec(NoTranspose, *s, *t);
      if (this->leftPrecon) 
	this->leftPrecon->precon(t);
      
      /*---  CGR(2) part  ----------------------------------------------------*/
      
      // computing constants
      omega1  = *r * *s;
      mu      = *s * *s;
      nu      = *s * *t;
      tau     = *t * *t;
      omega2  = *r * *t;
      tau    -= nu * nu / mu;
      omega2 -= nu * omega1 / mu;
      omega2 /= tau;
      omega1 -= nu * omega2;
      omega1 /= mu;
      
      // updating x
      axpy(omega1, *r, *x);
      axpy(omega2, *s, *x);
      axpy(alpha, *u, *x);
      
      // updating r
      axpy(-omega1, *s, *r);
      axpy(-omega2, *t, *r);
      /*---  checking accuracy  ----------------------------------------------*/
      
      res = norm(r);
      if (SOLVE_INFO(iter, res, &old_res) == 1) {
	if (this->relative)
	  this->tolerance = save_tolerance;
	return(iter);
      }

      // update minimal norm quantities
      if (res < normrmin) {	
	normrmin = res;
	*xmin    = *x;
	imin     = iter;
      } else if (res > normrmin * 1e+6) {
	INFO(this->info,2)("Linear solver diverges.\n");
	INFO(this->info,2)("Current iteration: %d.\n", iter);
	INFO(this->info,2)("Current residual: %e.\n", res);
	break;
      }
      
      // updating u
      axpy(-omega1, *v, *u);
      axpy(-omega2, *w, *u);
    }
    
    // returned solution is first with minimal residual
    *x = *xmin;
    iter = imin;
    this->residual = normrmin;
    
    if (this->relative) 
      this->tolerance = save_tolerance;

    INFO(this->info,2)("The minimal norm was %e; it was achieved in iteration %d.\n",
		       this->residual, iter);

    return(iter);
  }
}