From 16b2f8b8d1012ee93a8d6cb0f28e57031e0a6ec3 Mon Sep 17 00:00:00 2001
From: Simon Praetorius <simon.praetorius@tu-dresden.de>
Date: Mon, 9 May 2016 00:56:42 +0200
Subject: [PATCH] added a boundary-element/facet-iterator implementation, to be
 used in the problemStat->assemble method

---
 dune/amdis/BoundaryElementIterator.hpp | 213 +++++++++++++++++++++++++
 dune/amdis/Mesh.hpp                    |   4 +-
 dune/amdis/ProblemStat.hpp             |  34 +++-
 dune/amdis/ProblemStat.inc.hpp         |  81 +++++-----
 4 files changed, 283 insertions(+), 49 deletions(-)
 create mode 100644 dune/amdis/BoundaryElementIterator.hpp

diff --git a/dune/amdis/BoundaryElementIterator.hpp b/dune/amdis/BoundaryElementIterator.hpp
new file mode 100644
index 00000000..b3689038
--- /dev/null
+++ b/dune/amdis/BoundaryElementIterator.hpp
@@ -0,0 +1,213 @@
+#pragma once
+
+namespace AMDiS
+{
+  template <MeshView>
+  class BoundaryFacetIterator;
+
+  // An Iterator over all elements and when element hasBoundaryIntersections
+  template <MeshView>
+  class BoundaryElementIterator
+  {
+    friend template <class> class BoundaryFacetIterator;
+    
+    using Element = typename MeshView::Codim<0>::Entity;
+    using ElementIterator = typename MeshView::Codim<1>::Iterator;
+    
+    class Iterator
+    {
+    public:
+      Iterator(ElementIterator elementIt, ElementIterator endIt)
+        : elementIt(elementIt)
+        , endIt(endIt)
+      {}
+      
+      Iterator(Iterator const&) = default;
+      Iterator& operator=(Iterator const&) = default;
+      
+      Iterator& operator++()
+      {
+        ++elementIt;
+        while (!elementIt->hasBoundaryIntersections() && elementIt != endIt)
+          ++elementIt;
+        return *this;
+      }
+      
+      Iterator operator++(int)
+      {
+        auto tmp = *this;
+        ++(*this);
+        return tmp;
+      }
+      
+      Element& operator*() const
+      {
+        return *elementIt;
+      }
+      
+      Element* operator->() const
+      {
+        return &(*elementIt);
+      }
+      
+      bool operator==(Iterator const& that)
+      {
+        return elementIt == that.elementIt;
+      }
+      
+      bool operator!=(Iterator const& that)
+      {
+        return !(*this == that);
+      }
+      
+    private:
+      ElementIterator elementIt;
+      ElementIterator endIt;
+    };
+
+  public:
+    /// Constructor.
+    BoundaryElementIterator(MeshView& meshView)
+      : meshView(meshView)
+    {}
+    
+    Iterator begin() {
+      auto elementIt = elements(meshView).begin();
+      auto endIt = elements(meshView).end();
+      while (!elementIt->hasBoundaryIntersections() && elementIt != endIt)
+        ++elementIt;
+      
+      return {elementIt, endIt};
+    }
+    
+    Iterator end() {
+      return {elements(meshView).end(), elements(meshView).end()};
+    }
+    
+  private:
+    MeshView& meshView;
+  };
+  
+  
+  /// Generator function for the boundary-element iterator
+  template <class MeshView>
+  BoundaryElementIterator<MeshView> boundary_elements(MeshView& meshView)
+  {
+    return {meshView};
+  }
+  
+
+  // An Iterator over all elements and when element hasBoundaryIntersections, then 
+  // iterate over all boundary-intersections with given Index, oder for thos with 
+  // predicate returns true
+  template <MeshView>
+  class BoundaryFacetIterator
+  {
+    using Element = typename MeshView::Codim<0>::Entity;
+    using Facet   = typename MeshView::Codim<1>::Entity;
+    
+    using ElementIterator = typename BoundaryElementIterator<MeshView>::Iterator;
+    using FacetIterator = typename MeshView::IntersectionIterator;
+    
+    class Iterator
+    {
+    public:
+      Iterator(MeshView& meshView, 
+               ElementIterator elementIt, 
+               FacetIterator facetIt, 
+               ElementIterator endIt)
+        : meshView(meshView)
+        , elementIt(elementIt)
+        , facetIt(facetIt)
+        , endIt(endIt)
+      {}
+      
+      Iterator(Iterator const&) = default;
+      Iterator& operator=(Iterator const&) = default;
+      
+      Iterator& operator++()
+      {
+        ++facetIt;
+        do {
+          auto facetEndIt = intersections(meshView, *elementIt).end();
+          while (!facetIt->boundary() && facetIt != facetEndIt)
+            ++facetIt;
+          if (facetIt == facetEndIt)
+            ++elementIt;
+        } while (elementIt != endIt);
+        
+        return *this;
+      }
+      
+      Iterator operator++(int)
+      {
+        auto tmp = *this;
+        ++(*this);
+        return tmp;
+      }
+      
+      Facet& operator*() const
+      {
+        return *facetIt;
+      }
+      
+      Facet* operator->() const
+      {
+        return &(*facetIt);
+      }
+      
+      bool operator==(Iterator const& that)
+      {
+        return elementIt == that.elementIt && (elementIt == endIt || facetIt == that.facetIt);
+      }
+      
+      bool operator!=(Iterator const& that)
+      {
+        return !(*this == that);
+      }
+      
+    private:
+      MeshView&       meshView;
+      ElementIterator elementIt;
+      ElementIterator endIt;
+      FacetIterator   facetIt;
+    };
+
+  public:
+    /// Constructor.
+    BoundaryFacetIterator(MeshView& meshView)
+      : meshView(meshView)
+    {}
+    
+    Iterator begin() {
+      auto elementIt = boundary_elements(meshView).begin();
+      auto endElementIt = boundary_elements(meshView).end();
+      auto facetIt = intersections(meshView, *elementIt).begin();
+      auto endFacetIt = intersections(meshView, *elementIt).end();
+      
+      while (!facetIt->boundary() && facetIt != facetEndIt)
+        ++facetIt;
+      
+      return {elementIt, facetIt, endElementIt};
+    }
+    
+    Iterator end() {
+      auto elementIt = boundary_elements(meshView).begin();
+      auto endElementIt = boundary_elements(meshView).end();
+      auto facetIt = intersections(meshView, *elementIt).begin(); // TODO: what is the correct end_facet_iterator
+      return {endElementIt, facetIt, endElementIt};
+    }
+    
+  private:
+    MeshView& meshView;
+  };
+  
+  
+  /// Generator function for the boundary-element iterator
+  template <class MeshView>
+  BoundaryFacetIterator<MeshView> boundary_facets(MeshView& meshView)
+  {
+    return {meshView};
+  }
+
+} // end namespace AMDiS
diff --git a/dune/amdis/Mesh.hpp b/dune/amdis/Mesh.hpp
index ded9d67b..958eeaa3 100644
--- a/dune/amdis/Mesh.hpp
+++ b/dune/amdis/Mesh.hpp
@@ -147,7 +147,7 @@ namespace AMDiS
       Dune::FieldVector<double, dim> L; L = 1.0;  // extension of the domain
       Parameters::get(meshName + "->dimension", L);
       
-      auto s = Dune::fill_array<dim>(int{2}); // number of cells on coarse mesh in each direction
+      auto s = Dune::fill_array<int,dim>(2); // number of cells on coarse mesh in each direction
       Parameters::get(meshName + "->num cells", s);
       
       // TODO: add more parameters for yasp-grid (see constructor)      
@@ -168,7 +168,7 @@ namespace AMDiS
       Parameters::get(meshName + "->min corner", lowerleft);
       Parameters::get(meshName + "->max corner", upperright);
       
-      auto s = Dune::fill_array<dim>(int{2}); // number of cells on coarse mesh in each direction
+      auto s = Dune::fill_array<int,dim>(2); // number of cells on coarse mesh in each direction
       Parameters::get(meshName + "->num cells", s);
       
       // TODO: add more parameters for yasp-grid (see constructor)      
diff --git a/dune/amdis/ProblemStat.hpp b/dune/amdis/ProblemStat.hpp
index a8147fcb..46d44560 100644
--- a/dune/amdis/ProblemStat.hpp
+++ b/dune/amdis/ProblemStat.hpp
@@ -278,10 +278,8 @@ namespace AMDiS
     
   protected: // sub-methods to assemble DOFMatrix
     
-    template <class Matrix, class Vector>
-    void assemble(std::pair<int, int> row_col,
-                  Matrix& matrix, bool asmMatrix,
-                  Vector& rhs,    bool asmVector);
+    template <class LhsData, class RhsData, class Elements>
+    void assemble(LhsData lhs, RhsData rhs, Elements const& elements);
     
     template <class RowView, class ColView>
     bool getElementMatrix(RowView const& rowView,
@@ -382,6 +380,34 @@ namespace AMDiS
     VectorEntries<double*>                  vectorFactors;
     std::map< int, bool >                   vectorAssembled; // if false, do reassemble
     std::map< int, bool >                   vectorChanging;  // if true, or vectorAssembled false, do reassemble
+    
+    template <int r, int c>
+    struct MatrixData 
+    {
+      using DOFMatrixType = 
+        std::tuple_element_t<c, std::tuple_element_t<r, typename SystemMatrixType::DOFMatrices>>;
+        
+      DOFMatrixType&                       matrix;
+      std::list<shared_ptr<OperatorType>>& operators;
+      std::list<double*> const&            factors;
+      bool                                 assemble;
+      
+      std::pair<int,int> row_col = {r, c};
+    };
+    
+    template <int r>
+    struct VectorData
+    {
+      using DOFVectorType = 
+        std::tuple_element_t<r, typename SystemVectorType::DOFVectors>;
+      
+      DOFVectorType&                       vector;
+      std::list<shared_ptr<OperatorType>>& operators;
+      std::list<double*> const&            factors;
+      bool                                 assemble;
+      
+      int row = r;
+    };
   };
   
   
diff --git a/dune/amdis/ProblemStat.inc.hpp b/dune/amdis/ProblemStat.inc.hpp
index 73274b8e..16b61d6c 100644
--- a/dune/amdis/ProblemStat.inc.hpp
+++ b/dune/amdis/ProblemStat.inc.hpp
@@ -277,16 +277,20 @@ namespace AMDiS
 
       For<0, nComponents>::loop([this, &nnz, asmMatrix_, asmVector, _r](auto const _c) 
       {
+        using MatrixData = typename ProblemStatSeq<Traits>::template MatrixData<_r, _c>;
+        using VectorData = typename ProblemStatSeq<Traits>::template VectorData<_r>;
+        
         // The DOFMatrix which should be assembled
         auto& dofmatrix    = (*systemMatrix)(_r, _c);
         auto& solution_vec = (*solution)[_c];
         auto& rhs_vec      = (*rhs)[_r];
 	  
         auto row_col = std::make_pair(int(_r), int(_c));
-        bool asmMatrix = asmMatrix_ && (!matrixAssembled[row_col] || matrixChanging[row_col]);
+        bool assembleMatrix = asmMatrix_ && (!matrixAssembled[row_col] || matrixChanging[row_col]);
+        bool assembleVector = asmVector  && _r == _c;
         
         int r = 0, c = 0;
-	if (asmMatrix) {
+	if (assembleMatrix) {
 	  // init boundary condition
 	  for (auto bc_list : dirichletBc) {
 	    std::tie(r, c) = bc_list.first;
@@ -296,11 +300,22 @@ namespace AMDiS
 	    }
 	  }
         }
+        
+        auto mat = MatrixData{dofmatrix, matrixOperators[row_col], matrixFactors[row_col], assembleMatrix};
+        auto vec = VectorData{rhs_vec,   vectorOperators[int(_r)], vectorFactors[int(_r)], assembleVector};
 	  
         // assemble the DOFMatrix block and the corresponding rhs vector, of r==c
-        this->assemble(row_col, dofmatrix, asmMatrix, rhs_vec, (_r == _c && asmVector));
+    
+        dofmatrix.init(assembleMatrix);
+        this->assemble(mat, vec, elements(*meshView));
+        dofmatrix.finish();
+      
+        if (assembleMatrix)
+          matrixAssembled[row_col] = true;
+        if (assembleVector)
+          vectorAssembled[int(_r)] = true;
 
-        if (asmMatrix) {
+        if (assembleMatrix) {
 	  // finish boundary condition
 	  for (auto bc_list : dirichletBc) {
 	    std::tie(r, c) = bc_list.first;
@@ -331,71 +346,51 @@ namespace AMDiS
 
 
   template <class Traits>
-    template <class Matrix, class Vector>
-  void ProblemStatSeq<Traits>::assemble(std::pair<int, int> row_col,
-                                        Matrix& dofmatrix, bool asmMatrix,
-                                        Vector& rhs,       bool asmVector)
+    template <class LhsData, class RhsData, class Elements>
+  void ProblemStatSeq<Traits>::assemble(LhsData lhs, RhsData rhs,
+                                        Elements const& elements)
   {    
-    auto const& rowFeSpace = dofmatrix.getRowFeSpace();
-    auto const& colFeSpace = dofmatrix.getColFeSpace();
-    
-    auto matrixOp = matrixOperators[row_col];
-    auto matrixOpFac = matrixFactors[row_col];
-    auto vectorOp = vectorOperators[row_col.first];
-    auto vectorOpFac = vectorFactors[row_col.first];
-    
-    dofmatrix.init(asmMatrix);
+    auto const& rowFeSpace = lhs.matrix.getRowFeSpace();
+    auto const& colFeSpace = lhs.matrix.getColFeSpace();
     
-    // return if nothing to to.
-    if ((matrixOp.empty() || !asmMatrix) && 
-        (vectorOp.empty()  || !asmVector))
-    {
-      dofmatrix.finish();
-      return;
-    }
+    if ((lhs.operators.empty() || !lhs.assemble) && 
+        (rhs.operators.empty() || !rhs.assemble))
+      return; // nothing to do
     
-    for (auto op : matrixOp)
+    for (auto op : lhs.operators)
       op->init(rowFeSpace, colFeSpace);
-    for (auto op : vectorOp)
+    for (auto op : rhs.operators)
       op->init(rowFeSpace, colFeSpace);
     
     auto rowLocalView = rowFeSpace.localView();
-    auto rowIndexSet = rowFeSpace.localIndexSet();
+    auto rowIndexSet  = rowFeSpace.localIndexSet();
     
     auto colLocalView = colFeSpace.localView();
-    auto colIndexSet = colFeSpace.localIndexSet();
+    auto colIndexSet  = colFeSpace.localIndexSet();
     
-    for (auto const& element : elements(*meshView)) {
+    for (auto const& element : elements) {
       rowLocalView.bind(element);
       rowIndexSet.bind(rowLocalView);
       
       colLocalView.bind(element);
       colIndexSet.bind(colLocalView);
       
-      if (asmMatrix) {
+      if (lhs.assemble) {
         ElementMatrix elementMatrix;
         bool add = getElementMatrix(rowLocalView, colLocalView, elementMatrix, 
-                                    matrixOp, matrixOpFac);
+                                    lhs.operators, lhs.factors);
         if (add)
-          addElementMatrix(dofmatrix, rowIndexSet, colIndexSet, elementMatrix);
+          addElementMatrix(lhs.matrix, rowIndexSet, colIndexSet, elementMatrix);
       }
       
-      if (asmVector) {
+      if (rhs.assemble) {
         ElementVector elementVector;
         bool add = getElementVector(rowLocalView, elementVector, 
-                                    vectorOp, vectorOpFac);
+                                    rhs.operators, rhs.factors);
         if (add)
-          addElementVector(rhs, rowIndexSet, elementVector);
+          addElementVector(rhs.vector, rowIndexSet, elementVector);
       }
     }
-    
-    if (asmMatrix) {
-      dofmatrix.finish();
-      matrixAssembled[row_col] = true;
-    }
-    if (asmVector) {      
-      vectorAssembled[row_col.first] = true;
-    }
   }
   
   
-- 
GitLab