MultiTypeMatrix.hpp 3.42 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#pragma once

#include <tuple>

#include <dune/amdis/common/Concepts.hpp>
#include <dune/amdis/common/FieldTraits.hpp>
#include <dune/amdis/common/Loops.hpp>
#include <dune/amdis/common/Mpl.hpp>
#include <dune/amdis/common/MultiTypeVector.hpp>
#include <dune/amdis/common/Size.hpp>
#include <dune/amdis/utility/MultiIndex.hpp>

namespace AMDiS
{
  /// Forward declaration of the row-wise heterogeneous matrix type.
  /// It lets `Dune::FieldTraits` be specialized for `MultiTypeMatrix`
  /// before the class itself is defined.
  template <typename... Rows>
  class MultiTypeMatrix;
} // end namespace AMDiS

namespace Dune
{
  template <class... Rows>
  struct FieldTraits<AMDiS::MultiTypeMatrix<Rows...>>
  {
    using field_type = typename AMDiS::CommonFieldTraits<Rows...>::field_type;
    using real_type = typename AMDiS::CommonFieldTraits<Rows...>::real_type;
  };
}


namespace AMDiS
{
  /// \brief Matrix with a fixed number of rows, where each row may have a
  /// different type.
  ///
  /// The rows are stored as the elements of the `std::tuple` base class.
  /// Each row is expected to be a MultiTypeVector, or a type with the same
  /// interface: a static `dimension`, `std::get`-based element access and
  /// the scalar/row arithmetic operators used below.
  ///
  /// \tparam Rows  The row types. All of them must have the same `dimension`.
  template <class... Rows>
  class MultiTypeMatrix
      : public std::tuple<Rows...>
  {
    using Self = MultiTypeMatrix;
    using Super = std::tuple<Rows...>;

    // Every row must have the same number of entries (= number of columns).
    static_assert(is_equal<int, Rows::dimension...>::value,
      "All rows must have the same length.");

  public:
    using field_type = typename Dune::FieldTraits<Self>::field_type;
    using real_type = typename Dune::FieldTraits<Self>::real_type;
    using size_type = std::size_t;

    enum {
      /// Number of rows, i.e. the number of tuple elements.
      rows = std::tuple_size<Super>::value,
      /// Number of columns, i.e. the (common) `dimension` of the rows.
      cols = Math::max(Rows::dimension...)
    };

    /// Construct the matrix from one argument per row. Each argument is
    /// perfectly forwarded to the corresponding row of the tuple base.
    /// This constructor participates in overload resolution only if
    /// `Concepts::Similar<Types<Rows...>, Types<Rows_...>>` holds.
    template <class... Rows_,
      REQUIRES( Concepts::Similar<Types<Rows...>, Types<Rows_...>> )>
    MultiTypeMatrix(Rows_&&... args)
      : Super(std::forward<Rows_>(args)...)
    {}

    /// Default construction of all rows.
    MultiTypeMatrix() = default;

    /// Construct the matrix by assigning the constant \p value to every row.
    MultiTypeMatrix(real_type value)
    {
      *this = value;
    }

    /// Assign the constant \p value to every row.
    MultiTypeMatrix& operator=(real_type value)
    {
      forEach(*this, [value](auto& row) { row = value; });
      return *this;
    }

    /// Row-wise compound addition: `(*this)[i] += that[i]` for every row i.
    MultiTypeMatrix& operator+=(MultiTypeMatrix const& that)
    {
      forEach(range_<0,rows>, [&that,this](auto const _i) { (*this)[_i] += that[_i]; });
      return *this;
    }

    /// Row-wise compound subtraction: `(*this)[i] -= that[i]` for every row i.
    MultiTypeMatrix& operator-=(MultiTypeMatrix const& that)
    {
      forEach(range_<0,rows>, [&that,this](auto const _i) { (*this)[_i] -= that[_i]; });
      return *this;
    }

    /// Scale every row by the constant \p value.
    MultiTypeMatrix& operator*=(real_type value)
    {
      forEach(*this, [value](auto& row) { row *= value; });
      return *this;
    }

    /// Divide every row by the constant \p value.
    MultiTypeMatrix& operator/=(real_type value)
    {
      forEach(*this, [value](auto& row) { row /= value; });
      return *this;
    }

    /// Const access to the row with compile-time index I.
    /// Required by the row-wise operators `+=` and `-=`.
    template <std::size_t I>
    decltype(auto) operator[](const index_t<I>&) const
    {
      return std::get<I>(*this);
    }

    /// Mutable access to the row with compile-time index I.
    /// Required by the row-wise operators `+=` and `-=`.
    template <std::size_t I>
    decltype(auto) operator[](const index_t<I>&)
    {
      return std::get<I>(*this);
    }

    /// Const access to the entry in row I, column J.
    template <std::size_t I, std::size_t J>
    decltype(auto) operator()(const index_t<I>&, const index_t<J>&) const
    {
      return std::get<J>(std::get<I>(*this));
    }

    /// Mutable access to the entry in row I, column J.
    template <std::size_t I, std::size_t J>
    decltype(auto) operator()(const index_t<I>&, const index_t<J>&)
    {
      return std::get<J>(std::get<I>(*this));
    }

    /// Return the number of rows (the number of tuple elements).
    static constexpr std::size_t num_rows()
    {
      return rows;
    }

    /// Return the number of columns (the common length of all rows).
    static constexpr std::size_t num_cols()
    {
      return cols;
    }
  };

} // end namespace AMDiS