/** \file BlockMTLMatrix.hpp */

#pragma once

#include <array>
#include <cstddef>
#include <type_traits>

#include <boost/numeric/mtl/matrices.hpp>

#include <amdis/common/Utility.hpp>
#include <amdis/common/Literals.hpp>
#include <amdis/common/Loops.hpp>
#include <amdis/linear_algebra/LinearAlgebraBase.hpp>

namespace AMDiS
{
  /// A wrapper for AMDiS::SolverMatrix to be used in MTL/ITL solvers
17
  template <class MTLMatrix, std::size_t _N, std::size_t _M>
18
  class BlockMTLMatrix
19
      : public std::array<std::array<MTLMatrix, _M>, _N>
20
21
  {
    using Self = BlockMTLMatrix;
22

23
  public:
24
    /// The index/size - type
25
    using size_type  = typename MTLMatrix::size_type;
26

27
    /// The type of the elements of the MTLMatrix
28
    using value_type = typename MTLMatrix::value_type;
29

30
31
    /// The underlying mtl matrix type
    using BaseMatrix = MTLMatrix;
32

33
34
  public:
    /// Return the (R,C)'th matrix block
35
36
    template <std::size_t R, std::size_t C>
    auto& operator()(const index_t<R>, const index_t<C>)
37
38
39
40
    {
      static_assert(R < N() && C < M(), "Indices out of range [0,N)x[0,M)");
      return std::get<C>(std::get<R>(*this));
    }
41

42
    /// Return the (R,C)'th matrix block
43
44
    template <std::size_t R, std::size_t C>
    auto const& operator()(const index_t<R>, const index_t<C>) const
45
46
47
48
    {
      static_assert(R < N() && C < M(), "Indices out of range [0,N)x[0,M)");
      return std::get<C>(std::get<R>(*this));
    }
49

50
    /// Return the number of row blocks
51
    static constexpr std::size_t N() { return _N; }
52

53
    /// Return the number of column blocks
54
    static constexpr std::size_t M() { return _M; }
55
56
57
58
59
60
61
62
63

    /// perform blockwise multiplication A*b -> x
    template <class VectorIn, class VectorOut, class Assign>
    void mult(VectorIn const& b, VectorOut& x, Assign) const
    {
      // create iranges to access array blocks
      std::array<mtl::irange, _N> r_rows;
      std::array<mtl::irange, _M> r_cols;
      getRanges(r_rows, r_cols);
64

65
      forEach(range_<0, _N>, [&](const auto _i)
66
      {
67
        bool first = true;
68

69
        // a reference to the i'th block of x
70
        VectorOut x_i(x[r_rows[_i]]);
71
        forEach(range_<0, _M>, [&](const auto _j)
72
        {
73
          auto const& A_ij = this->operator()(_i, _j);
74
          if (num_rows(A_ij) > 0 && A_ij.nnz() > 0) {
75
            // a reference to the j'th block of b
76
            const VectorIn b_j(b[r_cols[_j]]);
77

78
            if (first) {
79
              Assign::first_update(x_i, A_ij * b_j);
80
81
82
              first = false;
            }
            else {
83
              Assign::update(x_i, A_ij * b_j);
84
85
86
87
88
89
            }
          }
        });
      });
    }

90
91
    /// A Multiplication operator returns a multiplication-expresssion.
    /// Calls \ref mult internally.
92
    template <class VectorIn>
93
    mtl::vec::mat_cvec_multiplier<Self, VectorIn>
94
95
96
97
    operator*(VectorIn const& v) const
    {
      return {*this, v};
    }
98
99

    /// Fill an array of irange corresponding to the row-sizes, used
100
    /// to access sub-vectors
101
    void getRowRanges(std::array<mtl::irange, _N>& r_rows) const
102
    {
103
104
105
      std::size_t start = 0;
      forEach(range_<0, _N>, [&](const auto _r) {
        std::size_t finish = start + num_rows((*this)(_r, 0_c));
106
107
108
109
        r_rows[_r].set(start, finish);
        start = finish;
      });
    }
110
111

    /// Fill an array of irange corresponding to the column-sizes, used
112
    /// to access sub-vectors
113
    void getColRanges(std::array<mtl::irange, _M>& r_cols) const
114
    {
115
116
117
      std::size_t start = 0;
      forEach(range_<0, _M>, [&](const auto _c) {
        std::size_t finish = start + num_cols((*this)(0_c, _c));
118
119
120
121
        r_cols[_c].set(start, finish);
        start = finish;
      });
    }
122

123
124
    /// Fill two arrays of irange corresponding to row and column sizes.
    /// \see getRowRanges() and \see getColRanges()
125
    void getRanges(std::array<mtl::irange, _N>& r_rows,
126
                   std::array<mtl::irange, _M>& r_cols) const
127
    {
128
129
130
131
      getRowRanges(r_rows);
      getColRanges(r_cols);
    }
  };
132
133


134
135
136
  namespace Impl
  {
    /// Specialization of Impl::MTLMatrix from \file LinearAlgebraBase.hpp
137
    template <class MTLMatrix, std::size_t _N, std::size_t _M>
138
139
140
141
142
    struct BaseMatrix<BlockMTLMatrix<MTLMatrix, _N, _M>>
    {
      using type = MTLMatrix;
    };
  }
143

144
  /// Return the number of overall rows of a BlockMTLMatrix
145
146
  template <class MTLMatrix, std::size_t _N, std::size_t _M>
  inline std::size_t num_rows(BlockMTLMatrix<MTLMatrix, _N, _M> const& A)
147
  {
148
149
    std::size_t nRows = 0;
    forEach(range_<0, _N>, [&](const auto _r) {
150
      nRows += num_rows(A(_r, 0_c));
151
152
153
154
    });
    return nRows;
  }

155
  /// Return the number of overall columns of a BlockMTLMatrix
156
157
  template <class MTLMatrix, std::size_t _N, std::size_t _M>
  inline std::size_t num_cols(BlockMTLMatrix<MTLMatrix, _N, _M> const& A)
158
  {
159
160
    std::size_t nCols = 0;
    forEach(range_<0, _M>, [&](const auto _c) {
161
      nCols += num_cols(A(0_c, _c));
162
163
164
165
    });
    return nCols;
  }

166
  /// Return the size, i.e. rows*columns of a BlockMTLMatrix
167
168
  template <class MTLMatrix, std::size_t _N, std::size_t _M>
  inline std::size_t size(BlockMTLMatrix<MTLMatrix, _N, _M> const& A)
169
170
171
172
  {
    return num_rows(A) * num_cols(A);
  }

173
  /// Nullify a BlockMTLMatrix, i.e. nullify each block.
174
  template <class MTLMatrix, std::size_t _N, std::size_t _M>
175
  inline void set_to_zero(BlockMTLMatrix<MTLMatrix, _N, _M>& A)
176
  {
177
178
    forEach(range_<0, _N>, [&](const auto _r) {
      forEach(range_<0, _M>, [&](const auto _c) {
179
180
181
182
183
184
185
        set_to_zero(A(_r,_c));
      });
    });
  }

} // end namespace AMDiS

186
187

/// \cond HIDDEN_SYMBOLS
188
189
namespace mtl
{
190
  template <class MTLMatrix, std::size_t _N, std::size_t _M>
191
192
193
194
195
196
197
198
  struct Collection<AMDiS::BlockMTLMatrix<MTLMatrix, _N, _M>>
  {
    using value_type = typename MTLMatrix::value_type;
    using size_type  = typename MTLMatrix::size_type;
  };

  namespace ashape
  {
199
    template <class MTLMatrix, std::size_t _N, std::size_t _M>
200
201
202
203
204
205
206
207
    struct ashape_aux<AMDiS::BlockMTLMatrix<MTLMatrix, _N, _M>>
    {
      using type = nonscal;
    };

  } // end namespace ashape

} // end namespace mtl
208
/// \endcond