PfcPrecon.hpp 2.55 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#pragma once

#include <cmath>
#include <tuple>
#include <type_traits>
#include <utility>

#include <dune/istl/bcrsmatrix.hh>
#include <dune/istl/bvector.hh>
#include <dune/istl/multitypeblockmatrix.hh>
#include <dune/istl/multitypeblockvector.hh>
#include <dune/istl/operators.hh>
#include <dune/istl/preconditioner.hh>
#include <dune/istl/preconditioners.hh>
#include <dune/istl/solvers.hh>

namespace AMDiS {

/**
 * \brief Block preconditioner for a 3x3 MultiTypeBlockMatrix system.
 *
 * Only three sub-blocks of the system matrix are used:
 *  - L0 = matrix[1][0]
 *  - L  = matrix[2][1]
 *  - M  = matrix[2][2]
 *
 * pre() evaluates delta = sqrt(M0 * tau) and builds the shifted matrices
 *   MpL  = M + (1/delta) * L0
 *   MpL2 = M + sqrt(delta) * L
 * apply() then performs five inexact inner solves with M, MpL and MpL2.
 * Each inner solve is CG preconditioned with one Jacobi sweep.
 *
 * The object stores references to the system matrix and its blocks, and a
 * pointer to tau. All of them must outlive the preconditioner.
 *
 * NOTE(review): the class name suggests the phase-field-crystal (PFC)
 * preconditioner. Earlier comments in this file mentioned "M + delta*K",
 * "b1 - tau*y1" and an "eps" factor, none of which the code implements.
 * Confirm the scalings below against the derivation of the method.
 */
template <class Matrix, class Vector>
class PfcPrecon : public Dune::Preconditioner<Vector, Vector>
{
    static constexpr index_<0> _0 = index_<0>();
    static constexpr index_<1> _1 = index_<1>();
    static constexpr index_<2> _2 = index_<2>();

    using SubMatrix = std::decay_t<decltype( std::declval<Matrix>()[_0][_0] )>;
    using SubVector = std::decay_t<decltype( std::declval<Vector>()[_0] )>;

public:
    /// \param matrix  3x3 block system matrix (kept by reference).
    /// \param tauPtr  Pointer to the time step size tau. It is read in pre(),
    ///                so it may change between applications of the outer solver.
    /// \param M0      Scalar factor entering delta = sqrt(M0 * tau).
    PfcPrecon(Matrix const& matrix, double* tauPtr, double M0)
        : matrix(matrix)
        , tauPtr(tauPtr)
        , M0(M0)
        , matL0(matrix[_1][_0])
        , matL (matrix[_2][_1])
        , matM (matrix[_2][_2])
    {}

    /// Build the shifted matrices for the current tau and size the work vectors.
    ///
    /// The arguments x and b are not used.
    /// NOTE(review): BCRSMatrix::axpy expects the added matrix to share M's
    /// sparsity pattern. This presumably holds for L0 and L; confirm.
    void pre(Vector& /*x*/, Vector& /*b*/) override
    {
        // Cache delta so that apply() scales with exactly the value that was
        // used to build MpL/MpL2, even if *tauPtr changes in between.
        delta = std::sqrt(M0 * (*tauPtr));

        matMpL = matM;
        matMpL.axpy(1.0/delta, matL0);          // MpL  = M + (1/delta) * L0

        matMpL2 = matM;
        matMpL2.axpy(std::sqrt(delta), matL);   // MpL2 = M + sqrt(delta) * L

        auto const n = matM.N();
        y0.resize(n);
        y1.resize(n);
        tmp.resize(n);
    }

    /// Apply the preconditioner: x <- P^{-1} b (approximately, via inner CG).
    ///
    /// Requires pre() to have been called first, which builds MpL/MpL2 and
    /// caches delta.
    void apply(Vector& x, const Vector& b) override
    {
        // M * y0 = b0
        solve(matM, y0, b[_0]);

        // tmp = b1 - L0 * y0
        matL0.mv(y0, y1);
        tmp = b[_1];
        tmp -= y1;

        // (M + (1/delta) L0) * y1 = tmp
        solve(matMpL, y1, tmp);

        // First part of x0: x0 = y0 + (1/delta) * y1
        x[_0] = y0;
        x[_0].axpy(1.0/delta, y1);

        // x1 ~= MpL2^{-1} M MpL2^{-1} M y1, evaluated as two inner solves
        matM.mv(y1, tmp);                   // tmp = M * y1
        solve(matMpL2, y1, tmp);            // (M + sqrt(delta) L) * y1 = tmp
        matM.mv(y1, tmp);                   // tmp = M * y1
        solve(matMpL2, x[_1], tmp);         // (M + sqrt(delta) L) * x1 = tmp

        // Complete x0 = y0 + (1/delta) * (y1_MpL - x1), where y1_MpL is the
        // solution of the MpL solve above (already added into x0).
        x[_0].axpy(-1.0/delta, x[_1]);

        // M * x2 = b2 - L * x1
        matL.mv(x[_1], y1);
        tmp = b[_2];
        tmp -= y1;
        solve(matM, x[_2], tmp);
    }

    /// Nothing to clean up; the work storage is reused by the next pre()/apply().
    void post(Vector& /*x*/) override
    {}

private:
    /// Approximately solve A * x = b with Jacobi-preconditioned CG.
    ///
    /// b is taken by value because Dune's iterative solvers overwrite the
    /// right-hand side with the residual.
    template <class Mat, class Vec>
    void solve(Mat const& A, Vec& x, Vec b)
    {
        // Inner solver settings (loose on purpose: this is only a preconditioner).
        constexpr double reduction = 1.e-3;   // relative residual reduction
        constexpr int maxIter = 20;
        constexpr int verbosity = 0;

        Dune::MatrixAdapter<Mat, Vec, Vec> op(A);
        Dune::SeqJac<Mat, Vec, Vec> precon(A, 1, 1.0);   // 1 sweep, relaxation 1.0
        Dune::CGSolver<Vec> solver(op, precon, reduction, maxIter, verbosity);

        // CGSolver::apply uses x as the initial guess. Start from zero so the
        // result depends only on b, not on whatever the (member) work vector
        // held from a previous application.
        x = 0.0;

        Dune::InverseOperatorResult statistics;
        solver.apply(x, b, statistics);
    }

private:
    Matrix const& matrix;   // full system (only the blocks below are used)
    double* tauPtr;         // time step size, read in pre()
    double M0;

    SubMatrix const& matL0; // matrix[1][0]
    SubMatrix const& matL;  // matrix[2][1]
    SubMatrix const& matM;  // matrix[2][2]

    // State rebuilt in pre()
    double delta = 1.0;     // sqrt(M0 * tau), cached for apply()
    SubMatrix matMpL;       // M + (1/delta) * L0
    SubMatrix matMpL2;      // M + sqrt(delta) * L

    SubVector y0, y1, tmp;  // work vectors reused across applications
};

} // end namespace AMDiS