From 428d373fc3ba845715db83e3d221a958dee5c386 Mon Sep 17 00:00:00 2001
From: Simon Praetorius <simon.praetorius@tu-dresden.de>
Date: Sun, 29 Apr 2018 23:10:07 +0200
Subject: [PATCH] assembler simplified, by adding an elementAssembler functor

---
 src/amdis/Assembler.hpp        |  18 +-----
 src/amdis/Assembler.inc.hpp    | 107 ++++++++++++---------------------
 src/amdis/utility/TreeData.hpp |  50 +++++++++++----
 3 files changed, 80 insertions(+), 95 deletions(-)

diff --git a/src/amdis/Assembler.hpp b/src/amdis/Assembler.hpp
index 11dec5bf..c5f6ab4e 100644
--- a/src/amdis/Assembler.hpp
+++ b/src/amdis/Assembler.hpp
@@ -54,23 +54,11 @@ namespace AMDiS
         SystemVectorType& rhs,
         bool asmMatrix, bool asmVector) const;
 
-
-    template <class ElementContainer, class Container, class Operators, class Geometry, class Basis>
-    void assembleElementOperators(
-        ElementContainer& elementContainer,
-        Container& container,
-        Operators& operators,
-        Geometry const& geometry,
-        Basis const& subBasis) const;
-
-    template <class ElementContainer, class Container, class Operators, class Geometry, class RowBasis, class ColBasis>
+    template <class Element, class Operators, class ElementAssembler>
     void assembleElementOperators(
-        ElementContainer& elementContainer,
-        Container& container,
+        Element const& element,
         Operators& operators,
-        Geometry const& geometry,
-        RowBasis const& rowBasis,
-        ColBasis const& colBasis) const;
+        ElementAssembler const& elementAssembler) const;
 
     /// Finish insertion into the matrix and assembles boundary conditions
     /// Return the number of nonzeros assembled into the matrix
diff --git a/src/amdis/Assembler.inc.hpp b/src/amdis/Assembler.inc.hpp
index 69d37bda..f71a9828 100644
--- a/src/amdis/Assembler.inc.hpp
+++ b/src/amdis/Assembler.inc.hpp
@@ -42,11 +42,19 @@ void Assembler<Traits>::assemble(
     {
       auto rowBasis = Dune::Functions::subspaceBasis(globalBasis_, rowTreePath);
       auto rowLocalView = rowBasis.localView();
-      rowLocalView.bind(element); // NOTE: Is this necessary?
+      rowLocalView.bind(element);
 
       auto& rhsOp = rhsOperators_[rowNode];
-      if (rhsOp.assemble(asmVector) && !rhsOp.empty())
-        this->assembleElementOperators(elementVector, rhs, rhsOp, geometry, rowLocalView);
+      if (rhsOp.assemble(asmVector) && !rhsOp.empty()) {
+        rhsOp.bind(element, geometry);
+
+        auto vecAssembler = [&](auto const& context, auto& operator_list) {
+          for (auto scaled : operator_list)
+            scaled.op->assemble(context, elementVector, rowLocalView.tree());
+        };
+
+        this->assembleElementOperators(element, rhsOp, vecAssembler);
+      }
 
       forEachNode(localView.tree(), [&,this](auto const& colNode, auto colTreePath)
       {
@@ -54,9 +62,16 @@ void Assembler<Traits>::assemble(
         if (matOp.assemble(asmMatrix) && !matOp.empty()) {
           auto colBasis = Dune::Functions::subspaceBasis(globalBasis_, colTreePath);
           auto colLocalView = colBasis.localView();
-          colLocalView.bind(element); // NOTE: Is this necessary?
+          colLocalView.bind(element);
+
+          matOp.bind(element, geometry);
+
+          auto matAssembler = [&](auto const& context, auto& operator_list) {
+            for (auto scaled : operator_list)
+              scaled.op->assemble(context, elementMatrix, rowLocalView.tree(), colLocalView.tree());
+          };
 
-          this->assembleElementOperators(elementMatrix, matrix, matOp, geometry, rowLocalView, colLocalView);
+          this->assembleElementOperators(element, matOp, matAssembler);
         }
       });
     });
@@ -82,6 +97,14 @@ void Assembler<Traits>::assemble(
       }
     }
 
+    // unbind all operators
+    forEachNode(localView.tree(), [&,this](auto const& rowNode, auto&&) {
+      rhsOperators_[rowNode].unbind();
+      forEachNode(localView.tree(), [&,this](auto const& colNode, auto&&) {
+        matrixOperators_[rowNode][colNode].unbind();
+      });
+    });
+
     localIndexSet.unbind();
     localView.unbind();
   }
@@ -89,85 +112,29 @@ void Assembler<Traits>::assemble(
   // 4. finish matrix insertion and apply dirichlet boundary conditions
   std::size_t nnz = finishMatrixVector(matrix, solution, rhs, asmMatrix, asmVector);
 
-  msg("fillin of assembled matrix: ", nnz);
+  msg("fill-in of assembled matrix: ", nnz);
 }
 
 
 template <class Traits>
-  template <class ElementContainer, class Container, class Operators, class Geometry, class LocalView>
+  template <class Element, class Operators, class ElementAssembler>
 void Assembler<Traits>::assembleElementOperators(
-    ElementContainer& elementContainer,
-    Container& container,
+    Element const& element,
     Operators& operators,
-    Geometry const& geometry,
-    LocalView const& localView) const
+    ElementAssembler const& elementAssembler) const
 {
-  auto const& element = getElement(localView);
-  auto const& gridView = getGridView(localView);
-
-  bool add = false;
-
-  auto assemble_operators = [&](auto const& context, auto& operator_list) {
-    for (auto scaled : operator_list) {
-      scaled.op->bind(element, geometry);
-      bool add_op = scaled.op->assemble(context, elementContainer, localView.tree());
-      scaled.op->unbind();
-      add = add || add_op;
-    }
-  };
-
-  // assemble element operators
-  assemble_operators(element, operators.element);
-
-  // assemble intersection operators
-  if (!operators.intersection.empty()
-      || (!operators.boundary.empty() && element.hasBoundaryIntersections()))
-  {
-    for (auto const& intersection : intersections(gridView, element)) {
-      if (intersection.boundary())
-        assemble_operators(intersection, operators.boundary);
-      else
-        assemble_operators(intersection, operators.intersection);
-    }
-  }
-}
-
-
-template <class Traits>
-  template <class ElementContainer, class Container, class Operators, class Geometry, class RowLocalView, class ColLocalView>
-void Assembler<Traits>::assembleElementOperators(
-    ElementContainer& elementContainer,
-    Container& container,
-    Operators& operators,
-    Geometry const& geometry,
-    RowLocalView const& rowLocalView, ColLocalView const& colLocalView) const
-{
-  auto const& element = getElement(rowLocalView, colLocalView);
-  auto const& gridView = getGridView(rowLocalView, colLocalView);
-
-  bool add = false;
-
-  auto assemble_operators = [&](auto const& context, auto& operator_list) {
-    for (auto scaled : operator_list) {
-      scaled.op->bind(element, geometry);
-      bool add_op = scaled.op->assemble(context, elementContainer, rowLocalView.tree(), colLocalView.tree());
-      scaled.op->unbind();
-      add = add || add_op;
-    }
-  };
-
   // assemble element operators
-  assemble_operators(element, operators.element);
+  elementAssembler(element, operators.element);
 
   // assemble intersection operators
   if (!operators.intersection.empty()
       || (!operators.boundary.empty() && element.hasBoundaryIntersections()))
   {
-    for (auto const& intersection : intersections(gridView, element)) {
+    for (auto const& intersection : intersections(globalBasis_.gridView(), element)) {
       if (intersection.boundary())
-        assemble_operators(intersection, operators.boundary);
+        elementAssembler(intersection, operators.boundary);
       else
-        assemble_operators(intersection, operators.intersection);
+        elementAssembler(intersection, operators.intersection);
     }
   }
 }
@@ -184,6 +151,8 @@ void Assembler<Traits>::initMatrixVector(
   matrix.init(asmMatrix);
   solution.compress();
   rhs.compress();
+  if (asmVector)
+    rhs = 0;
 
   auto localView = globalBasis_.localView();
   forEachNode(localView.tree(), [&,this](auto const& rowNode, auto rowTreePath)
diff --git a/src/amdis/utility/TreeData.hpp b/src/amdis/utility/TreeData.hpp
index 411969d3..b181a9e7 100644
--- a/src/amdis/utility/TreeData.hpp
+++ b/src/amdis/utility/TreeData.hpp
@@ -10,6 +10,11 @@
 
 namespace AMDiS
 {
+  namespace tag
+  {
+    struct store {};
+  }
+
   /**
   * \brief Container allowing to attach data to each node of a tree
   *
@@ -34,7 +39,7 @@ namespace AMDiS
   * \tparam ND The data stored for a node of type Node will be of type ND<Node>
   * \tparam LO Set this flag if data should only be attached to leaf nodes.
   */
-  template<class T, template<class> class ND, bool LO>
+  template <class T, template<class> class ND, bool LO>
   class TreeData
   {
   public:
@@ -48,7 +53,7 @@ namespace AMDiS
     static const bool leafOnly = LO;
 
     //! Template to determine the data type for given node type
-    template<class Node>
+    template <class Node>
     using NodeData = ND<Node>;
 
   public:
@@ -65,16 +70,21 @@ namespace AMDiS
       * of the tree data.
       * See also \ref init.
       */
-    explicit TreeData(const Tree& tree)
+    TreeData(Tree const& tree, tag::store)
       : tree_(&tree)
     {
       initData();
     }
 
+    explicit TreeData(Tree& tree)
+      : TreeData(tree, tag::store{})
+    {}
+
     //! Copy constructor
-    TreeData(const TreeData& other)
-      : TreeData(*other.tree_)
+    TreeData(TreeData const& other)
+      : tree_(other.tree_)
     {
+      initData();
       copyData(other);
     }
 
@@ -92,13 +102,18 @@ namespace AMDiS
       * A reference to the tree is stored because it's needed for destruction
       * of the tree data.
       */
-    void init(const Tree& tree)
+    void init(Tree const& tree, tag::store)
     {
       destroyData();
       tree_ = &tree;
       initData();
     }
 
+    void init(Tree& tree)
+    {
+      init(tree, tag::store{});
+    }
+
     //! Copy and Move assignment
     TreeData& operator=(TreeData other)
     {
@@ -122,6 +137,8 @@ namespace AMDiS
     template<class Node>
     NodeData<Node>& operator[](const Node& node)
     {
+      assert(initialized_);
+      assert(data_.size() > node.treeIndex());
       return *(NodeData<Node>*)(data_[node.treeIndex()]);
     }
 
@@ -129,6 +146,8 @@ namespace AMDiS
     template<class Node>
     const NodeData<Node>& operator[](const Node& node) const
     {
+      assert(initialized_);
+      assert(data_.size() > node.treeIndex());
       return *(NodeData<Node>*)(data_[node.treeIndex()]);
     }
 
@@ -138,6 +157,7 @@ namespace AMDiS
       using std::swap;
       swap(tree_, other.tree_);
       swap(data_, other.data_);
+      swap(initialized_, other.initialized_);
     }
 
   protected:
@@ -154,6 +174,7 @@ namespace AMDiS
         using Node = std::remove_reference_t<decltype(node)>;
         data_[node.treeIndex()] = new NodeData<Node>;
       });
+      initialized_ = true;
     }
 
     // Deep copy of node data
@@ -173,6 +194,7 @@ namespace AMDiS
         delete p;
       });
       tree_ = nullptr;
+      initialized_ = false;
     }
 
   protected:
@@ -193,6 +215,8 @@ namespace AMDiS
   protected:
     const Tree* tree_ = nullptr;
     std::vector<void*> data_;
+
+    bool initialized_ = false;
   };
 
 
@@ -221,18 +245,22 @@ namespace AMDiS
   class MatrixData
       : public TreeData<Tree, ND, false>
   {
-    using TD = TreeData<Tree, ND, false>;
+    using Super = TreeData<Tree, ND, false>;
 
   public:
-    void init(Tree const& tree)
+    void init(Tree const& tree, tag::store)
     {
-      TD::init(tree);
+      Super::init(tree, tag::store{});
       forEachNode(tree, [&,this](auto const& node, auto&&)
       {
-        (*this)[node].init(tree);
+        (*this)[node].init(tree, tag::store{});
       });
     }
-    using TD::init;
+
+    void init(Tree& tree)
+    {
+      init(tree, tag::store{});
+    }
   };
 
 
-- 
GitLab