From 36eb33e4727fe0225ad0d6410f8ecd4604226547 Mon Sep 17 00:00:00 2001
From: Simon Praetorius <simon.praetorius@tu-dresden.de>
Date: Mon, 20 Nov 2017 09:34:34 +0100
Subject: [PATCH] first attempt to exchange element-traverse and
 matrix-block-traverse

---
 dune/amdis/ProblemInstatBase.cpp |  30 ++--
 dune/amdis/ProblemStat.inc.hpp   | 236 +++++++++++++++++++++----------
 2 files changed, 173 insertions(+), 93 deletions(-)

diff --git a/dune/amdis/ProblemInstatBase.cpp b/dune/amdis/ProblemInstatBase.cpp
index 48687792..a8c7df3c 100644
--- a/dune/amdis/ProblemInstatBase.cpp
+++ b/dune/amdis/ProblemInstatBase.cpp
@@ -4,22 +4,22 @@
 #include "AdaptStationary.hpp"
 #include "StandardProblemIteration.hpp"
 
-namespace AMDiS
+namespace AMDiS {
+
+void ProblemInstatBase::setTime(AdaptInfo& adaptInfo)
+{
+  cTime = adaptInfo.getTime();      // cache the time reported by adaptInfo
+  tau = adaptInfo.getTimestep();    // cache the timestep reported by adaptInfo
+  invTau = 1.0 / tau;               // reciprocal of tau; NOTE(review): not guarded against tau == 0
+}
+
+
+void ProblemInstatBase::solveInitialProblem(AdaptInfo& adaptInfo)
 {
-  void ProblemInstatBase::setTime(AdaptInfo& adaptInfo)
-  {
-    cTime = adaptInfo.getTime();
-    tau = adaptInfo.getTimestep();
-    invTau = 1.0 / tau;
-  }
-  
+  StandardProblemIteration iteration(*initialProblem);
+  AdaptStationary initialAdapt(name + "->initial->adapt", iteration, adaptInfo);
 
-  void ProblemInstatBase::solveInitialProblem(AdaptInfo& adaptInfo)
-  {
-    StandardProblemIteration iteration(*initialProblem);
-    AdaptStationary initialAdapt(name + "->initial->adapt", iteration, adaptInfo);
+  initialAdapt.adapt();
+}
 
-    initialAdapt.adapt();
-  }
-  
 } // end namespace AMDiS
diff --git a/dune/amdis/ProblemStat.inc.hpp b/dune/amdis/ProblemStat.inc.hpp
index e0045018..9c7ca3af 100644
--- a/dune/amdis/ProblemStat.inc.hpp
+++ b/dune/amdis/ProblemStat.inc.hpp
@@ -245,82 +245,174 @@ solve(AdaptInfo& adaptInfo, bool createMatrixData, bool storeMatrixData)
 
 
 template <class Traits>
-void ProblemStat<Traits>::
-buildAfterCoarsen(AdaptInfo& /*adaptInfo*/, Flag flag, bool asmMatrix_, bool asmVector_)
+  template <int R, int C>
+bool ProblemStat<Traits>::
+assembleMatrix(bool asmMatrix_, const index_t<R> /*_r*/, const index_t<C> /*_c*/) const
 {
-  Timer t;
-
-  // update global feSpace, i.e. necessary after mesh change
-  forEach(range_<0, nComponents>, [this](auto const _r) {
-    this->getFeSpace(_r).update(this->leafGridView());
-  });
+  return asmMatrix_ && (!matrixAssembled[R][C] || matrixChanging[R][C]); // block (R,C) not yet assembled, or flagged as changing
+}
 
+template <class Traits>
+  template <int R>
+bool ProblemStat<Traits>::
+assembleVector(bool asmVector_, const index_t<R> /*_r*/) const
+{
+  return asmVector_ && (!vectorAssembled[R] || vectorChanging[R]); // rhs block R not yet assembled, or flagged as changing
+}
 
-  std::size_t nnz = 0;
+template <class Traits>
+void ProblemStat<Traits>::
+initMatrixVector(bool asmMatrix_, bool asmVector_)
+{
   forEach(range_<0, nComponents>, [&,this](auto const _r)
   {
     static const int R = decltype(_r)::value;
     msg(this->getFeSpace(_r).size(), " DOFs for FeSpace[", R, "]");
 
-    bool asmVector = asmVector_ && (!vectorAssembled[R] || vectorChanging[R]);
-
-    if (asmVector) {
+    if (assembleVector(asmVector_, _r)) {
       rhs->compress(_r);
       rhs->getVector(_r) = 0.0;
+
+      // init vector operators: entries hold the operator in .op; row space is passed for both arguments
+      for (auto& scaledOp : vectorOperators[R])
+        scaledOp.op->init(this->getFeSpace(_r), this->getFeSpace(_r));
+      for (auto& scaledOp : vectorBoundaryOperators[R])
+        scaledOp.op->init(this->getFeSpace(_r), this->getFeSpace(_r));
+      for (auto& scaledOp : vectorIntersectionOperators[R])
+        scaledOp.op->init(this->getFeSpace(_r), this->getFeSpace(_r));
     }
 
     forEach(range_<0, nComponents>, [&,this](auto const _c)
     {
       static const int C = decltype(_c)::value;
 
-      using MatrixData = typename ProblemStat<Traits>::template MatrixData<R, C>;
-      using VectorData = typename ProblemStat<Traits>::template VectorData<R>;
-
-      // The DOFMatrix which should be assembled
-      auto& dofmatrix    = (*systemMatrix)(_r, _c);
-      auto& solution_vec = (*solution)[_c];
-      auto& rhs_vec      = (*rhs)[_r];
+      bool asmMatrix = assembleMatrix(asmMatrix_, _r, _c);
+      (*systemMatrix)(_r, _c).init(asmMatrix);
 
-      bool assembleMatrix = asmMatrix_ && (!matrixAssembled[R][C] || matrixChanging[R][C]);
-      bool assembleVector = asmVector  && R == C;
+      if (asmMatrix) {
+        // init matrix operators with the row and column spaces of block (R,C)
+        for (auto& scaledOp : matrixOperators[R][C])
+          scaledOp.op->init(this->getFeSpace(_r), this->getFeSpace(_c));
+        for (auto& scaledOp : matrixBoundaryOperators[R][C])
+          scaledOp.op->init(this->getFeSpace(_r), this->getFeSpace(_c));
+        for (auto& scaledOp : matrixIntersectionOperators[R][C])
+          scaledOp.op->init(this->getFeSpace(_r), this->getFeSpace(_c));
 
-      if (assembleMatrix) {
         // init boundary condition
         for (int c = 0; c < nComponents; ++c)
           for (auto bc : dirichletBc[R][c])
-            bc->init(c == C, dofmatrix, solution_vec, rhs_vec);
+            bc->init(c == C, (*systemMatrix)(_r, _c), (*solution)[_c], (*rhs)[_r]);
       }
+    });
+  });
+}
 
-      auto mat = MatrixData{dofmatrix, matrixOperators[R][C], matrixBoundaryOperators[R][C], matrixIntersectionOperators[R][C], assembleMatrix};
-      auto vec = VectorData{rhs_vec, vectorOperators[R], vectorBoundaryOperators[R], vectorIntersectionOperators[R], assembleVector};
 
-      // assemble the DOFMatrix block and the corresponding rhs vector if r==c
+template <class Traits>
+std::size_t ProblemStat<Traits>::finishMatrixVector(bool asmMatrix_, bool asmVector_)
+{
+  std::size_t nnz = 0; // sum of getNnz() over all blocks assembled in this pass
+  forEach(range_<0, nComponents>, [&,this](auto const _r)
+  {
+    static const int R = decltype(_r)::value;
+    if (assembleVector(asmVector_, _r))
+      vectorAssembled[R] = true;
+    forEach(range_<0, nComponents>, [&,this](auto const _c)
+    {
+      static const int C = decltype(_c)::value;
+      bool asmMatrix = assembleMatrix(asmMatrix_, _r, _c); // evaluated before matrixAssembled[R][C] is updated below
 
-      dofmatrix.init(assembleMatrix);
-      this->assemble(mat, vec, this->leafGridView());
-      dofmatrix.finish();
+      (*systemMatrix)(_r, _c).finish();
 
-      if (assembleMatrix)
+      if (asmMatrix) {
         matrixAssembled[R][C] = true;
-      if (assembleVector)
-        vectorAssembled[R] = true;
 
-      if (assembleMatrix) {
         // finish boundary condition
         for (int c = 0; c < nComponents; ++c) {
           for (int r = 0; r < nComponents; ++r) {
             if (r != R && c != C)
               continue;
             for (auto bc : dirichletBc[r][c])
-              bc->finish(r == R, c == C, dofmatrix, solution_vec, rhs_vec);
+              bc->finish(r == R, c == C, (*systemMatrix)(_r, _c), (*solution)[_c], (*rhs)[_r]);
           }
         }
 
-        nnz += dofmatrix.getNnz();
+        nnz += (*systemMatrix)(_r, _c).getNnz();
       }
     });
   });
 
+  return nnz;
+}
+
+
+template <class Traits>
+  template <class Element>
+LocalFiniteElement<typename Traits::FeSpaces> ProblemStat<Traits>::
+initElement(Element const& element)
+{
+  LocalFiniteElement<FeSpaces> localFiniteElem(*feSpaces); // holds one local view / index set per component
+
+  forEach(range_<0, nComponents>, [&,this](auto const _i)
+  {
+    auto& localView = localFiniteElem.localView(_i);
+    auto& localIndexSet  = localFiniteElem.localIndexSet(_i);
+
+    localView.bind(element);
+    localIndexSet.bind(localView); // the index set is bound to the bound local view, not to the element
+  });
+
+  return localFiniteElem;
+}
+
+
+template <class Traits>
+void ProblemStat<Traits>::
+buildAfterCoarsen(AdaptInfo& /*adaptInfo*/, Flag flag, bool asmMatrix_, bool asmVector_)
+{
+  Timer t;
+
+  auto gridView = this->leafGridView();
+
+  // 1. update global feSpace. This is necessary after mesh change
+  forEach(range_<0, nComponents>, [&,this](auto const _r) {
+    this->getFeSpace(_r).update(gridView);
+  });
+
+  // 2. init matrix and rhs vector and initialize dirichlet boundary conditions
+  initMatrixVector(asmMatrix_, asmVector_);
+
+  // 3. assemble operators: a single element traversal that visits every (R,C) block per element
+  for (auto const& element : elements(gridView))
+  {
+    auto localFiniteElem = initElement(element);
+
+    forEach(range_<0, nComponents>, [&,this](auto const _r)
+    {
+      forEach(range_<0, nComponents>, [&,this](auto const _c)
+      {
+        static const int R = decltype(_r)::value;
+        static const int C = decltype(_c)::value;
+
+        using MatrixData = typename ProblemStat<Traits>::template MatrixData<R, C>;
+        using VectorData = typename ProblemStat<Traits>::template VectorData<R>;
+
+        auto mat = MatrixData{(*systemMatrix)(_r, _c),
+          matrixOperators[R][C], matrixBoundaryOperators[R][C], matrixIntersectionOperators[R][C],
+          assembleMatrix(asmMatrix_, _r, _c)};
+        auto vec = VectorData{(*rhs)[_r],
+          vectorOperators[R], vectorBoundaryOperators[R], vectorIntersectionOperators[R],
+          assembleVector(asmVector_, _r) && R==C};
+
+        // add this element's contribution to block (R,C) and, if R==C, to the rhs vector
+        this->template assemble<R, C>(localFiniteElem, mat, vec);
+      });
+    });
+  }
+
+  // 4. finish matrix insertion and apply dirichlet boundary conditions
+  std::size_t nnz = finishMatrixVector(asmMatrix_, asmVector_);
+
   msg("fillin of assembled matrix: ", nnz);
   msg("buildAfterCoarsen needed ", t.elapsed(), " seconds");
 }
@@ -337,56 +429,44 @@ writeFiles(AdaptInfo& adaptInfo, bool force)
 
 
 template <class Traits>
-  template <class LhsData, class RhsData, class GV>
+  template <int R, int C, class LocalFE>
 void ProblemStat<Traits>::
-assemble(LhsData lhs, RhsData rhs, GV const& gridView)
+assemble(LocalFE& localFiniteElem, MatrixData<R,C> lhs, VectorData<R> rhs)
 {
-  auto const& rowFeSpace = lhs.matrix.getRowFeSpace();
-  auto const& colFeSpace = lhs.matrix.getColFeSpace();
-
-  if ((lhs.operators.empty() || !lhs.assemble) &&
-      (rhs.operators.empty() || !rhs.assemble))
+  if (((lhs.operators.empty() &&
+        lhs.boundary_operators.empty() &&
+        lhs.intersection_operators.empty()) || !lhs.assemble) &&
+      ((rhs.operators.empty() &&
+        rhs.boundary_operators.empty() &&
+        rhs.intersection_operators.empty()) || !rhs.assemble))
     return; // nothing to do
 
-  for (auto scaledOp : lhs.operators)
-    scaledOp.op->init(rowFeSpace, colFeSpace);
-  for (auto scaledOp : rhs.operators)
-    scaledOp.op->init(rowFeSpace, colFeSpace);
-
-  auto rowLocalView = rowFeSpace.localView();
-  auto rowIndexSet  = rowFeSpace.localIndexSet();
-
-  auto colLocalView = colFeSpace.localView();
-  auto colIndexSet  = colFeSpace.localIndexSet();
-
-  for (auto const& element : elements(gridView)) {
-    // TODO: use only one localView and localIndexSet if feSpaces are equal
-    rowLocalView.bind(element);
-    colLocalView.bind(element);
-    rowIndexSet.bind(rowLocalView); // NOTE: expensive operation!
-    colIndexSet.bind(colLocalView);
-
-    if (lhs.assemble) {
-      ElementMatrix elementMatrix;
-      bool add = getElementMatrix(rowLocalView, colLocalView, elementMatrix,
-                    lhs.operators, lhs.boundary_operators, lhs.intersection_operators);
-      if (add)
-        addElementMatrix(lhs.matrix, rowIndexSet, colIndexSet, elementMatrix);
-    }
+  const index_t<R> _r{}; // index tag selecting the row component in localFiniteElem
+  const index_t<C> _c{}; // index tag selecting the column component in localFiniteElem
 
-    if (rhs.assemble) {
-      ElementVector elementVector;
-      bool add = getElementVector(rowLocalView, elementVector,
-                    rhs.operators, rhs.boundary_operators, rhs.intersection_operators);
-      if (add)
-        addElementVector(rhs.vector, rowIndexSet, elementVector);
-    }
+  auto& rowLocalView = localFiniteElem.localView(_r);
+  auto& rowIndexSet  = localFiniteElem.localIndexSet(_r);
+
+  auto& colLocalView = localFiniteElem.localView(_c);
+  auto& colIndexSet  = localFiniteElem.localIndexSet(_c);
 
-    rowIndexSet.unbind();
-    rowLocalView.unbind();
-    colLocalView.unbind();
-    colIndexSet.unbind();
+
+  if (lhs.assemble) {
+    ElementMatrix elementMatrix;
+    bool add = getElementMatrix(rowLocalView, colLocalView, elementMatrix,
+                  lhs.operators, lhs.boundary_operators, lhs.intersection_operators);
+    if (add)
+      addElementMatrix(lhs.matrix, rowIndexSet, colIndexSet, elementMatrix);
+  }
+
+  if (rhs.assemble) {
+    ElementVector elementVector;
+    bool add = getElementVector(rowLocalView, elementVector,
+                  rhs.operators, rhs.boundary_operators, rhs.intersection_operators);
+    if (add)
+      addElementVector(rhs.vector, rowIndexSet, elementVector);
   }
+
 }
 
 
-- 
GitLab