From ebf34297c8db15a7f5aa399f68f3c62184d73e06 Mon Sep 17 00:00:00 2001
From: Oliver Sander <oliver.sander@tu-dresden.de>
Date: Sat, 26 Dec 2015 21:42:30 +0100
Subject: [PATCH] Use a composite dune-functions basis, instead of two separate
 basis objects

This is an important step towards making the 'mixed' Cosserat code more like
the regular one.  The mid-term goal is to merge the two implementations, because
there is way too much code duplication here.
---
 dune/gfe/mixedcosseratenergy.hh           |  40 ++--
 dune/gfe/mixedgfeassembler.hh             | 248 +++++++++-------------
 dune/gfe/mixedlocalgeodesicfestiffness.hh |  20 +-
 dune/gfe/mixedlocalgfeadolcstiffness.hh   |  59 ++---
 dune/gfe/mixedriemanniantrsolver.cc       |  28 ++-
 dune/gfe/mixedriemanniantrsolver.hh       |  11 +-
 src/mixed-cosserat-continuum.cc           |  25 ++-
 7 files changed, 195 insertions(+), 236 deletions(-)

diff --git a/dune/gfe/mixedcosseratenergy.hh b/dune/gfe/mixedcosseratenergy.hh
index 8c0bcf45..0605243e 100644
--- a/dune/gfe/mixedcosseratenergy.hh
+++ b/dune/gfe/mixedcosseratenergy.hh
@@ -20,15 +20,14 @@
 //#define QUADRATIC_MEMBRANE_ENERGY
 
 
-template<class DisplacementBasis, class OrientationBasis, int dim, class field_type=double>
+template<class Basis, int dim, class field_type=double>
 class MixedCosseratEnergy
-    : public MixedLocalGeodesicFEStiffness<DisplacementBasis,RealTuple<field_type,dim>,
-                                           OrientationBasis,Rotation<field_type,dim> >
+    : public MixedLocalGeodesicFEStiffness<Basis,
+                                           RealTuple<field_type,dim>,
+                                           Rotation<field_type,dim> >
 {
     // grid types
-    typedef typename DisplacementBasis::LocalView::Tree::FiniteElement DisplacementLocalFiniteElement;
-    typedef typename OrientationBasis::LocalView::Tree::FiniteElement OrientationLocalFiniteElement;
-    typedef typename DisplacementBasis::GridView GridView;
+    typedef typename Basis::GridView GridView;
     typedef typename GridView::ctype DT;
     typedef field_type RT;
     typedef typename GridView::template Codim<0>::Entity Entity;
@@ -146,10 +145,8 @@ public:
     }
 
     /** \brief Assemble the energy for a single element */
-    RT energy (const Entity& e,
-               const DisplacementLocalFiniteElement& displacementLocalFiniteElement,
+    RT energy (const typename Basis::LocalView& localView,
                const std::vector<RealTuple<field_type,dim> >& localDisplacementConfiguration,
-               const OrientationLocalFiniteElement& orientationLocalFiniteElement,
                const std::vector<Rotation<field_type,dim> >& localOrientationConfiguration) const;
 
     /** \brief The energy \f$ W_{mp}(\overline{U}) \f$, as written in
@@ -265,32 +262,33 @@ public:
     const Dune::VirtualFunction<Dune::FieldVector<double,gridDim>, Dune::FieldVector<double,3> >* neumannFunction_;
 };
 
-template <class DeformationBasis, class OrientationBasis, int dim, class field_type>
-typename MixedCosseratEnergy<DeformationBasis,OrientationBasis,dim,field_type>::RT
-MixedCosseratEnergy<DeformationBasis,OrientationBasis,dim,field_type>::
-energy(const Entity& element,
-       const DisplacementLocalFiniteElement& deformationLocalFiniteElement,
+template <class Basis, int dim, class field_type>
+typename MixedCosseratEnergy<Basis,dim,field_type>::RT
+MixedCosseratEnergy<Basis,dim,field_type>::
+energy(const typename Basis::LocalView& localView,
        const std::vector<RealTuple<field_type,dim> >& localDeformationConfiguration,
-       const OrientationLocalFiniteElement& orientationLocalFiniteElement,
        const std::vector<Rotation<field_type,dim> >& localOrientationConfiguration) const
 {
-    assert(element.type() == deformationLocalFiniteElement.type());
-    assert(element.type() == orientationLocalFiniteElement.type());
     typedef typename GridView::template Codim<0>::Entity::Geometry Geometry;
 
+    auto element = localView.element();
+
     RT energy = 0;
 
-    typedef LocalGeodesicFEFunction<gridDim, DT, DisplacementLocalFiniteElement, RealTuple<field_type,dim> > LocalDeformationGFEFunctionType;
+    using namespace Dune::TypeTree::Indices;
+    const auto& deformationLocalFiniteElement = localView.tree().child(_0).finiteElement();
+    const auto& orientationLocalFiniteElement = localView.tree().child(_1).finiteElement();
+
+    typedef LocalGeodesicFEFunction<gridDim, DT, decltype(deformationLocalFiniteElement), RealTuple<field_type,dim> > LocalDeformationGFEFunctionType;
     LocalDeformationGFEFunctionType localDeformationGFEFunction(deformationLocalFiniteElement,localDeformationConfiguration);
 
-    typedef LocalGeodesicFEFunction<gridDim, DT, OrientationLocalFiniteElement, Rotation<field_type,dim> > LocalOrientationGFEFunctionType;
+    typedef LocalGeodesicFEFunction<gridDim, DT, decltype(orientationLocalFiniteElement), Rotation<field_type,dim> > LocalOrientationGFEFunctionType;
     LocalOrientationGFEFunctionType localOrientationGFEFunction(orientationLocalFiniteElement,localOrientationConfiguration);
 
     // \todo Implement smarter quadrature rule selection for more efficiency, i.e., less evaluations of the Rotation GFE function
     int quadOrder = deformationLocalFiniteElement.localBasis().order() * ((element.type().isSimplex()) ? 1 : gridDim);
 
-    const Dune::QuadratureRule<DT, gridDim>& quad
-        = Dune::QuadratureRules<DT, gridDim>::rule(element.type(), quadOrder);
+    const auto& quad = Dune::QuadratureRules<DT, gridDim>::rule(element.type(), quadOrder);
 
     for (size_t pt=0; pt<quad.size(); pt++) {
 
diff --git a/dune/gfe/mixedgfeassembler.hh b/dune/gfe/mixedgfeassembler.hh
index 737cc99b..8949ff38 100644
--- a/dune/gfe/mixedgfeassembler.hh
+++ b/dune/gfe/mixedgfeassembler.hh
@@ -11,10 +11,10 @@
 
 /** \brief A global FE assembler for problems involving functions that map into non-Euclidean spaces
  */
-template <class Basis0, class TargetSpace0, class Basis1, class TargetSpace1>
+template <class Basis, class TargetSpace0, class TargetSpace1>
 class MixedGFEAssembler {
 
-    typedef typename Basis0::GridView GridView;
+    typedef typename Basis::GridView GridView;
     typedef typename GridView::template Codim<0>::template Partition<Dune::Interior_Partition>::Iterator ElementIterator;
 
     //! Dimension of the grid.
@@ -32,21 +32,18 @@ class MixedGFEAssembler {
 
 protected:
 public:
-    const Basis0 basis0_;
-    const Basis1 basis1_;
+    const Basis basis_;
 
-    MixedLocalGeodesicFEStiffness<Basis0, TargetSpace0,
-                                  Basis1, TargetSpace1>* localStiffness_;
+    MixedLocalGeodesicFEStiffness<Basis,
+                                  TargetSpace0,
+                                  TargetSpace1>* localStiffness_;
 
 public:
 
     /** \brief Constructor for a given grid */
-    MixedGFEAssembler(const Basis0& basis0,
-                      const Basis1& basis1,
-                      MixedLocalGeodesicFEStiffness<Basis0, TargetSpace0,
-                                                    Basis1, TargetSpace1>* localStiffness)
-        : basis0_(basis0),
-          basis1_(basis1),
+    MixedGFEAssembler(const Basis& basis,
+                      MixedLocalGeodesicFEStiffness<Basis, TargetSpace0, TargetSpace1>* localStiffness)
+        : basis_(basis),
           localStiffness_(localStiffness)
     {}
 
@@ -83,65 +80,50 @@ public:
 
 
 
-template <class Basis0, class TargetSpace0, class Basis1, class TargetSpace1>
-void MixedGFEAssembler<Basis0,TargetSpace0,Basis1,TargetSpace1>::
+template <class Basis, class TargetSpace0, class TargetSpace1>
+void MixedGFEAssembler<Basis,TargetSpace0,TargetSpace1>::
 getMatrixPattern(Dune::MatrixIndexSet& nb00,
                  Dune::MatrixIndexSet& nb01,
                  Dune::MatrixIndexSet& nb10,
                  Dune::MatrixIndexSet& nb11) const
 {
-    nb00.resize(basis0_.indexSet().size(), basis0_.indexSet().size());
-    nb01.resize(basis0_.indexSet().size(), basis1_.indexSet().size());
-    nb10.resize(basis1_.indexSet().size(), basis0_.indexSet().size());
-    nb11.resize(basis1_.indexSet().size(), basis1_.indexSet().size());
+    nb00.resize(basis_.size({0}), basis_.size({0}));
+    nb01.resize(basis_.size({0}), basis_.size({1}));
+    nb10.resize(basis_.size({1}), basis_.size({0}));
+    nb11.resize(basis_.size({1}), basis_.size({1}));
 
     // A view on the FE basis on a single element
-    auto localView0 = basis0_.localView();
-    auto localView1 = basis1_.localView();
-    auto localIndexSet0 = basis0_.indexSet().localIndexSet();
-    auto localIndexSet1 = basis1_.indexSet().localIndexSet();
-
-    // Grid view must be the same for both bases
-    ElementIterator it    = basis0_.gridView().template begin<0,Dune::Interior_Partition>();
-    ElementIterator endit = basis0_.gridView().template end<0,Dune::Interior_Partition>  ();
-
-    for (; it!=endit; ++it) {
+    auto localView = basis_.localView();
+    auto localIndexSet = basis_.localIndexSet();
 
+    // Loop over grid elements
+    for (const auto& element : elements(basis_.gridView(), Dune::Partitions::interior))
+    {
         // Bind the local FE basis view to the current element
-        localView0.bind(*it);
-        localView1.bind(*it);
-        localIndexSet0.bind(localView0);
-        localIndexSet1.bind(localView1);
-
-        for (size_t i=0; i<localView0.size(); i++) {
-
-            int iIdx = localIndexSet0.index(i)[0];
-
-            for (size_t j=0; j<localView0.size(); j++) {
-                int jIdx = localIndexSet0.index(j)[0];
-                nb00.add(iIdx, jIdx);
-            }
-
-            for (size_t j=0; j<localView1.size(); j++) {
-                int jIdx = localIndexSet1.index(j)[0];
-                nb01.add(iIdx, jIdx);
-            }
-
-        }
-
-        for (size_t i=0; i<localView1.size(); i++) {
-
-            int iIdx = localIndexSet1.index(i)[0];
-
-            for (size_t j=0; j<localView0.size(); j++) {
-                int jIdx = localIndexSet0.index(j)[0];
-                nb10.add(iIdx, jIdx);
-            }
-
-            for (size_t j=0; j<localView1.size(); j++) {
-                int jIdx = localIndexSet1.index(j)[0];
-                nb11.add(iIdx, jIdx);
-            }
+        localView.bind(element);
+        localIndexSet.bind(localView);
+
+        // Add element stiffness matrix onto the global stiffness matrix
+        for (size_t i=0; i<localIndexSet.size(); i++)
+        {
+          // The global index of the i-th local degree of freedom of the element 'e'
+          auto row = localIndexSet.index(i);
+
+          for (size_t j=0; j<localIndexSet.size(); j++ )
+          {
+            // The global index of the j-th local degree of freedom of the element 'e'
+            auto col = localIndexSet.index(j);
+
+            if (row[0]==0 and col[0]==0)
+              nb00.add(row[1],col[1]);
+            if (row[0]==0 and col[0]==1)
+              nb01.add(row[1],col[1]);
+            if (row[0]==1 and col[0]==0)
+              nb10.add(row[1],col[1]);
+            if (row[0]==1 and col[0]==1)
+              nb11.add(row[1],col[1]);
+
+          }
 
         }
 
@@ -149,8 +131,8 @@ getMatrixPattern(Dune::MatrixIndexSet& nb00,
 
 }
 
-template <class Basis0, class TargetSpace0, class Basis1, class TargetSpace1>
-void MixedGFEAssembler<Basis0,TargetSpace0,Basis1,TargetSpace1>::
+template <class Basis, class TargetSpace0, class TargetSpace1>
+void MixedGFEAssembler<Basis,TargetSpace0,TargetSpace1>::
 assembleGradientAndHessian(const std::vector<TargetSpace0>& configuration0,
                            const std::vector<TargetSpace1>& configuration1,
                            Dune::BlockVector<Dune::FieldVector<double, blocksize0> >& gradient0,
@@ -188,83 +170,69 @@ assembleGradientAndHessian(const std::vector<TargetSpace0>& configuration0,
     gradient1 = 0;
 
     // A view on the FE basis on a single element
-    auto localView0 = basis0_.localView();
-    auto localView1 = basis1_.localView();
-    auto localIndexSet0 = basis0_.indexSet().localIndexSet();
-    auto localIndexSet1 = basis1_.indexSet().localIndexSet();
-
-    ElementIterator it    = basis0_.gridView().template begin<0,Dune::Interior_Partition>();
-    ElementIterator endit = basis0_.gridView().template end<0,Dune::Interior_Partition>  ();
-
-    for( ; it != endit; ++it ) {
+    auto localView = basis_.localView();
+    auto localIndexSet = basis_.localIndexSet();
 
+    for (const auto& element : elements(basis_.gridView(), Dune::Partitions::interior))
+    {
         // Bind the local FE basis view to the current element
-        localView0.bind(*it);
-        localView1.bind(*it);
-        localIndexSet0.bind(localView0);
-        localIndexSet1.bind(localView1);
+        localView.bind(element);
+        localIndexSet.bind(localView);
 
-        const int nDofs0 = localView0.size();
-        const int nDofs1 = localView1.size();
+        using namespace Dune::TypeTree::Indices;
+        const int nDofs0 = localView.tree().child(_0).finiteElement().size();
+        const int nDofs1 = localView.tree().child(_1).finiteElement().size();
 
         // Extract local solution
         std::vector<TargetSpace0> localConfiguration0(nDofs0);
         std::vector<TargetSpace1> localConfiguration1(nDofs1);
 
-        for (int i=0; i<nDofs0; i++)
-            localConfiguration0[i] = configuration0[localIndexSet0.index(i)[0]];
-
-        for (int i=0; i<nDofs1; i++)
-            localConfiguration1[i] = configuration1[localIndexSet1.index(i)[0]];
+        for (int i=0; i<nDofs0+nDofs1; i++)
+        {
+          if (localIndexSet.index(i)[0] == 0)
+            localConfiguration0[i] = configuration0[localIndexSet.index(i)[1]];
+          else
+            localConfiguration1[i-nDofs0] = configuration1[localIndexSet.index(i)[1]];
+        }
 
         std::vector<Dune::FieldVector<double,blocksize0> > localGradient0(nDofs0);
         std::vector<Dune::FieldVector<double,blocksize1> > localGradient1(nDofs1);
 
         // setup local matrix and gradient
-        localStiffness_->assembleGradientAndHessian(*it,
-                                                    localView0.tree().finiteElement(), localConfiguration0,
-                                                    localView1.tree().finiteElement(), localConfiguration1,
+        localStiffness_->assembleGradientAndHessian(localView,
+                                                    localConfiguration0, localConfiguration1,
                                                     localGradient0, localGradient1);
 
         // Add element matrix to global stiffness matrix
-        for (int i=0; i<nDofs0; i++) {
+        for (int i=0; i<nDofs0+nDofs1; i++)
+        {
+            auto row = localIndexSet.index(i);
 
-            int row = localIndexSet0.index(i)[0];
+            for (int j=0; j<nDofs0+nDofs1; j++ )
+            {
+                auto col = localIndexSet.index(j);
 
-            for (int j=0; j<nDofs0; j++ ) {
-                int col = localIndexSet0.index(j)[0];
-                hessian00[row][col] += localStiffness_->A00_[i][j];
-            }
-
-            for (int j=0; j<nDofs1; j++ ) {
-                int col = localIndexSet1.index(j)[0];
-                hessian01[row][col] += localStiffness_->A01_[i][j];
-            }
-        }
+                if (row[0]==0 and col[0]==0)
+                  hessian00[row[1]][col[1]] += localStiffness_->A00_[i][j];
 
-        for (int i=0; i<nDofs1; i++) {
+                if (row[0]==0 and col[0]==1)
+                  hessian01[row[1]][col[1]] += localStiffness_->A01_[i][j-nDofs0];
 
-            int row = localIndexSet1.index(i)[0];
+                if (row[0]==1 and col[0]==0)
+                  hessian10[row[1]][col[1]] += localStiffness_->A10_[i-nDofs0][j];
 
-            for (int j=0; j<nDofs0; j++ ) {
-                int col = localIndexSet0.index(j)[0];
-                hessian10[row][col] += localStiffness_->A10_[i][j];
+                if (row[0]==1 and col[0]==1)
+                  hessian11[row[1]][col[1]] += localStiffness_->A11_[i-nDofs0][j-nDofs0];
             }
 
-            for (int j=0; j<nDofs1; j++ ) {
-                int col = localIndexSet1.index(j)[0];
-                hessian11[row][col] += localStiffness_->A11_[i][j];
-            }
+            // Add local gradient to global gradient
+            if (localIndexSet.index(i)[0] == 0)
+              gradient0[localIndexSet.index(i)[1]] += localGradient0[i];
+            else
+              gradient1[localIndexSet.index(i)[1]] += localGradient1[i-nDofs0];
         }
 
-        // Add local gradient to global gradient
-        for (int i=0; i<nDofs0; i++)
-            gradient0[localIndexSet0.index(i)[0]] += localGradient0[i];
-
-        for (int i=0; i<nDofs1; i++)
-            gradient1[localIndexSet1.index(i)[0]] += localGradient1[i];
     }
-
 }
 
 #if 0
@@ -308,53 +276,49 @@ assembleGradient(const std::vector<TargetSpace>& sol,
 }
 #endif
 
-template <class Basis0, class TargetSpace0, class Basis1, class TargetSpace1>
-double MixedGFEAssembler<Basis0, TargetSpace0, Basis1, TargetSpace1>::
+template <class Basis, class TargetSpace0, class TargetSpace1>
+double MixedGFEAssembler<Basis, TargetSpace0, TargetSpace1>::
 computeEnergy(const std::vector<TargetSpace0>& configuration0,
               const std::vector<TargetSpace1>& configuration1) const
 {
     double energy = 0;
 
-    if (configuration0.size()!=basis0_.indexSet().size())
+    if (configuration0.size()!=basis_.size({0}))
         DUNE_THROW(Dune::Exception, "Configuration vector 0 doesn't match the basis!");
 
-    if (configuration1.size()!=basis1_.indexSet().size())
+    if (configuration1.size()!=basis_.size({1}))
         DUNE_THROW(Dune::Exception, "Configuration vector 1 doesn't match the basis!");
 
     // A view on the FE basis on a single element
-    auto localView0 = basis0_.localView();
-    auto localView1 = basis1_.localView();
-    auto localIndexSet0 = basis0_.indexSet().localIndexSet();
-    auto localIndexSet1 = basis1_.indexSet().localIndexSet();
-
-    ElementIterator it    = basis0_.gridView().template begin<0,Dune::Interior_Partition>();
-    ElementIterator endIt = basis0_.gridView().template end<0,Dune::Interior_Partition>();
+    auto localView = basis_.localView();
+    auto localIndexSet = basis_.localIndexSet();
 
     // Loop over all elements
-    for (; it!=endIt; ++it) {
-
+    for (const auto& element : elements(basis_.gridView(), Dune::Partitions::interior))
+    {
         // Bind the local FE basis view to the current element
-        localView0.bind(*it);
-        localView1.bind(*it);
-        localIndexSet0.bind(localView0);
-        localIndexSet1.bind(localView1);
+        localView.bind(element);
+        localIndexSet.bind(localView);
 
         // Number of degrees of freedom on this element
-        size_t nDofs0 = localView0.size();
-        size_t nDofs1 = localView1.size();
+        using namespace Dune::TypeTree::Indices;
+        const int nDofs0 = localView.tree().child(_0).finiteElement().size();
+        const int nDofs1 = localView.tree().child(_1).finiteElement().size();
 
         std::vector<TargetSpace0> localConfiguration0(nDofs0);
         std::vector<TargetSpace1> localConfiguration1(nDofs1);
 
-        for (size_t i=0; i<nDofs0; i++)
-            localConfiguration0[i] = configuration0[localIndexSet0.index(i)[0]];
-
-        for (size_t i=0; i<nDofs1; i++)
-            localConfiguration1[i] = configuration1[localIndexSet1.index(i)[0]];
+        for (int i=0; i<nDofs0+nDofs1; i++)
+        {
+          if (localIndexSet.index(i)[0] == 0)
+            localConfiguration0[i] = configuration0[localIndexSet.index(i)[1]];
+          else
+            localConfiguration1[i-nDofs0] = configuration1[localIndexSet.index(i)[1]];
+        }
 
-        energy += localStiffness_->energy(*it,
-                                          localView0.tree().finiteElement(), localConfiguration0,
-                                          localView1.tree().finiteElement(), localConfiguration1);
+        energy += localStiffness_->energy(localView,
+                                          localConfiguration0,
+                                          localConfiguration1);
 
     }
 
diff --git a/dune/gfe/mixedlocalgeodesicfestiffness.hh b/dune/gfe/mixedlocalgeodesicfestiffness.hh
index bc0a827c..2afa40a2 100644
--- a/dune/gfe/mixedlocalgeodesicfestiffness.hh
+++ b/dune/gfe/mixedlocalgeodesicfestiffness.hh
@@ -7,18 +7,12 @@
 #include <dune/istl/matrix.hh>
 
 
-template<class DeformationBasis, class DeformationTargetSpace,
-         class OrientationBasis, class OrientationTargetSpace>
+template<class Basis, class DeformationTargetSpace, class OrientationTargetSpace>
 class MixedLocalGeodesicFEStiffness
 {
-    static_assert(std::is_same<typename DeformationBasis::GridView, typename OrientationBasis::GridView>::value,
-                  "DeformationBasis and OrientationBasis must be designed on the same GridView!");
-
     // grid types
-    typedef typename DeformationBasis::LocalView::Tree::FiniteElement DeformationLocalFiniteElement;
-    typedef typename OrientationBasis::LocalView::Tree::FiniteElement OrientationLocalFiniteElement;
-    typedef typename DeformationBasis::GridView GridView;
-    typedef typename GridView::Grid::ctype DT;
+    typedef typename Basis::GridView GridView;
+    typedef typename GridView::ctype DT;
     typedef typename DeformationTargetSpace::ctype RT;
     typedef typename GridView::template Codim<0>::Entity Entity;
 
@@ -44,10 +38,8 @@ public:
     We compute that using a finite difference approximation.
 
     */
-    virtual void assembleGradientAndHessian(const Entity& e,
-                                            const DeformationLocalFiniteElement& displacementLocalFiniteElement,
+    virtual void assembleGradientAndHessian(const typename Basis::LocalView& localView,
                                             const std::vector<DeformationTargetSpace>& localDisplacementConfiguration,
-                                            const OrientationLocalFiniteElement& orientationLocalFiniteElement,
                                             const std::vector<OrientationTargetSpace>& localOrientationConfiguration,
                                             std::vector<typename DeformationTargetSpace::TangentVector>& localDeformationGradient,
                                             std::vector<typename OrientationTargetSpace::TangentVector>& localOrientationGradient)
@@ -56,10 +48,8 @@ public:
     }
 
     /** \brief Compute the energy at the current configuration */
-    virtual RT energy (const Entity& e,
-                       const DeformationLocalFiniteElement& deformationLocalFiniteElement,
+    virtual RT energy (const typename Basis::LocalView& localView,
                        const std::vector<DeformationTargetSpace>& localDeformationConfiguration,
-                       const OrientationLocalFiniteElement& orientationLocalFiniteElement,
                        const std::vector<OrientationTargetSpace>& localOrientationConfiguration) const = 0;
 #if 0
     /** \brief Assemble the element gradient of the energy functional
diff --git a/dune/gfe/mixedlocalgfeadolcstiffness.hh b/dune/gfe/mixedlocalgfeadolcstiffness.hh
index 9ba08d40..ddf89f96 100644
--- a/dune/gfe/mixedlocalgfeadolcstiffness.hh
+++ b/dune/gfe/mixedlocalgfeadolcstiffness.hh
@@ -18,20 +18,13 @@
 
 /** \brief Assembles energy gradient and Hessian with ADOL-C (automatic differentiation)
  */
-template<class Basis0, class TargetSpace0,
-         class Basis1, class TargetSpace1>
+template<class Basis, class TargetSpace0, class TargetSpace1>
 class MixedLocalGFEADOLCStiffness
-    : public MixedLocalGeodesicFEStiffness<Basis0,TargetSpace0,
-                                           Basis1,TargetSpace1>
+    : public MixedLocalGeodesicFEStiffness<Basis,TargetSpace0,TargetSpace1>
 {
-    static_assert(std::is_same<typename Basis0::GridView, typename Basis1::GridView>::value,
-                  "Basis0 and Basis1 must be designed on the same GridView!");
-
     // grid types
-    typedef typename Basis0::GridView GridView;
-    typedef typename Basis0::LocalView::Tree::FiniteElement LocalFiniteElement0;
-    typedef typename Basis1::LocalView::Tree::FiniteElement LocalFiniteElement1;
-    typedef typename GridView::Grid::ctype DT;
+    typedef typename Basis::GridView GridView;
+    typedef typename GridView::ctype DT;
     typedef typename TargetSpace0::ctype RT;
     typedef typename GridView::template Codim<0>::Entity Entity;
 
@@ -52,16 +45,14 @@ public:
     enum { embeddedBlocksize0 = TargetSpace0::EmbeddedTangentVector::dimension };
     enum { embeddedBlocksize1 = TargetSpace1::EmbeddedTangentVector::dimension };
 
-    MixedLocalGFEADOLCStiffness(const MixedLocalGeodesicFEStiffness<Basis0, ATargetSpace0,
-                                                                    Basis1, ATargetSpace1>* energy)
+    MixedLocalGFEADOLCStiffness(const MixedLocalGeodesicFEStiffness<Basis, ATargetSpace0,
+                                                                    ATargetSpace1>* energy)
     : localEnergy_(energy)
     {}
 
     /** \brief Compute the energy at the current configuration */
-    virtual RT energy (const Entity& e,
-                       const LocalFiniteElement0& localFiniteElement0,
+    virtual RT energy (const typename Basis::LocalView& localView,
                        const std::vector<TargetSpace0>& localConfiguration0,
-                       const LocalFiniteElement1& localFiniteElement1,
                        const std::vector<TargetSpace1>& localConfiguration1) const;
 #if 0
     /** \brief Assemble the element gradient of the energy functional
@@ -77,26 +68,22 @@ public:
 
     This uses the automatic differentiation toolbox ADOL_C.
     */
-    virtual void assembleGradientAndHessian(const Entity& e,
-                                            const LocalFiniteElement0& localFiniteElement0,
+    virtual void assembleGradientAndHessian(const typename Basis::LocalView& localView,
                                             const std::vector<TargetSpace0>& localConfiguration0,
-                                            const LocalFiniteElement1& localFiniteElement1,
                                             const std::vector<TargetSpace1>& localConfiguration1,
                                             std::vector<typename TargetSpace0::TangentVector>& localGradient0,
                                             std::vector<typename TargetSpace1::TangentVector>& localGradient1);
 
-    const MixedLocalGeodesicFEStiffness<Basis0, ATargetSpace0, Basis1, ATargetSpace1>* localEnergy_;
+    const MixedLocalGeodesicFEStiffness<Basis, ATargetSpace0, ATargetSpace1>* localEnergy_;
 
 };
 
 
-template <class Basis0, class TargetSpace0, class Basis1, class TargetSpace1>
-typename MixedLocalGFEADOLCStiffness<Basis0, TargetSpace0, Basis1, TargetSpace1>::RT
-MixedLocalGFEADOLCStiffness<Basis0, TargetSpace0, Basis1, TargetSpace1>::
-energy(const Entity& element,
-       const LocalFiniteElement0& localFiniteElement0,
+template <class Basis, class TargetSpace0, class TargetSpace1>
+typename MixedLocalGFEADOLCStiffness<Basis, TargetSpace0, TargetSpace1>::RT
+MixedLocalGFEADOLCStiffness<Basis, TargetSpace0, TargetSpace1>::
+energy(const typename Basis::LocalView& localView,
        const std::vector<TargetSpace0>& localConfiguration0,
-       const LocalFiniteElement1& localFiniteElement1,
        const std::vector<TargetSpace1>& localConfiguration1) const
 {
     double pureEnergy;
@@ -134,9 +121,10 @@ energy(const Entity& element,
       localAConfiguration1[i] = aRaw1[i];  // may contain a projection onto M -- needs to be done in adouble
     }
 
-    energy = localEnergy_->energy(element,
-                                  localFiniteElement0,localAConfiguration0,
-                                  localFiniteElement1,localAConfiguration1);
+    using namespace Dune::TypeTree::Indices;
+    energy = localEnergy_->energy(localView,
+                                  localAConfiguration0,
+                                  localAConfiguration1);
 
     energy >>= pureEnergy;
 
@@ -201,18 +189,16 @@ assembleGradient(const Entity& element,
 //   To compute the Hessian we need to compute the gradient anyway, so we may
 //   as well return it.  This saves assembly time.
 // ///////////////////////////////////////////////////////////
-template <class Basis0, class TargetSpace0, class Basis1, class TargetSpace1>
-void MixedLocalGFEADOLCStiffness<Basis0, TargetSpace0, Basis1, TargetSpace1>::
-assembleGradientAndHessian(const Entity& element,
-                           const LocalFiniteElement0& localFiniteElement0,
+template <class Basis, class TargetSpace0, class TargetSpace1>
+void MixedLocalGFEADOLCStiffness<Basis, TargetSpace0, TargetSpace1>::
+assembleGradientAndHessian(const typename Basis::LocalView& localView,
                            const std::vector<TargetSpace0>& localConfiguration0,
-                           const LocalFiniteElement1& localFiniteElement1,
                            const std::vector<TargetSpace1>& localConfiguration1,
                            std::vector<typename TargetSpace0::TangentVector>& localGradient0,
                            std::vector<typename TargetSpace1::TangentVector>& localGradient1)
 {
     // Tape energy computation.  We may not have to do this every time, but it's comparatively cheap.
-    energy(element, localFiniteElement0, localConfiguration0, localFiniteElement1, localConfiguration1);
+    energy(localView, localConfiguration0, localConfiguration1);
 
     /////////////////////////////////////////////////////////////////
     // Compute the gradient.  It is needed to transform the Hessian
@@ -517,9 +503,6 @@ assembleGradientAndHessian(const Entity& element,
       }
 
     }
-
-//     std::cout << "ADOL-C stiffness:\n";
-//     printmatrix(std::cout, this->A_, "foo", "--");
 }
 
 #endif
diff --git a/dune/gfe/mixedriemanniantrsolver.cc b/dune/gfe/mixedriemanniantrsolver.cc
index a33ee1ef..1c83bae1 100644
--- a/dune/gfe/mixedriemanniantrsolver.cc
+++ b/dune/gfe/mixedriemanniantrsolver.cc
@@ -30,11 +30,14 @@
 #include <dune/gfe/cosseratvtkwriter.hh>
 
 template <class GridType,
+          class Basis,
           class Basis0, class TargetSpace0,
           class Basis1, class TargetSpace1>
-void MixedRiemannianTrustRegionSolver<GridType,Basis0,TargetSpace0,Basis1,TargetSpace1>::
+void MixedRiemannianTrustRegionSolver<GridType,Basis,Basis0,TargetSpace0,Basis1,TargetSpace1>::
 setup(const GridType& grid,
-      const MixedGFEAssembler<Basis0, TargetSpace0, Basis1, TargetSpace1>* assembler,
+      const MixedGFEAssembler<Basis, TargetSpace0, TargetSpace1>* assembler,
+      const Basis0& tmpBasis0,
+      const Basis1& tmpBasis1,
          const SolutionType& x,
          const Dune::BitSetVector<blocksize0>& dirichletNodes0,
          const Dune::BitSetVector<blocksize1>& dirichletNodes1,
@@ -54,6 +57,8 @@ setup(const GridType& grid,
 
     grid_                     = &grid;
     assembler_                = assembler;
+    basis0_                   = std::unique_ptr<Basis0>(new Basis0(tmpBasis0));
+    basis1_                   = std::unique_ptr<Basis1>(new Basis1(tmpBasis1));
     x_                        = x;
     tolerance_                = tolerance;
     maxTrustRegionSteps_      = maxTrustRegionSteps;
@@ -206,7 +211,7 @@ setup(const GridType& grid,
         TransferOperatorType pkToP1TransferMatrix;
         assembleBasisInterpolationMatrix<TransferOperatorType,
                                          DuneFunctionsBasis<Dune::Functions::PQkNodalBasis<typename GridType::LeafGridView,1> >,
-                                         FufemBasis0>(pkToP1TransferMatrix,p1Basis,assembler->basis0_);
+                                         FufemBasis0>(pkToP1TransferMatrix,p1Basis,*basis0_);
 
         mmgStep0->mgTransfer_.back() = new TruncatedCompressedMGTransfer<CorrectionType0>;
         Dune::shared_ptr<TransferOperatorType> topTransferOperator = Dune::make_shared<TransferOperatorType>(pkToP1TransferMatrix);
@@ -246,7 +251,7 @@ setup(const GridType& grid,
         TransferOperatorType pkToP1TransferMatrix;
         assembleBasisInterpolationMatrix<TransferOperatorType,
                                          DuneFunctionsBasis<Dune::Functions::PQkNodalBasis<typename GridType::LeafGridView,1> >,
-                                         FufemBasis1>(pkToP1TransferMatrix,p1Basis,assembler->basis1_);
+                                         FufemBasis1>(pkToP1TransferMatrix,p1Basis,*basis1_);
 
         mmgStep1->mgTransfer_.back() = new TruncatedCompressedMGTransfer<CorrectionType1>;
         Dune::shared_ptr<TransferOperatorType> topTransferOperator = Dune::make_shared<TransferOperatorType>(pkToP1TransferMatrix);
@@ -283,8 +288,8 @@ setup(const GridType& grid,
       #if 0
       hasObstacle0_.resize(guIndex_->nGlobalEntity(), true);
       #else
-      hasObstacle0_.resize(assembler->basis0_.indexSet().size(), true);
-      hasObstacle1_.resize(assembler->basis1_.indexSet().size(), true);
+      hasObstacle0_.resize(assembler->basis_.size({0}), true);
+      hasObstacle1_.resize(assembler->basis_.size({1}), true);
       #endif
       mmgStep0->hasObstacle_ = &hasObstacle0_;
       mmgStep1->hasObstacle_ = &hasObstacle1_;
@@ -294,9 +299,10 @@ setup(const GridType& grid,
 
 
 template <class GridType,
+          class Basis,
           class Basis0, class TargetSpace0,
           class Basis1, class TargetSpace1>
-void MixedRiemannianTrustRegionSolver<GridType,Basis0,TargetSpace0,Basis1,TargetSpace1>::solve()
+void MixedRiemannianTrustRegionSolver<GridType,Basis,Basis0,TargetSpace0,Basis1,TargetSpace1>::solve()
 {
     int argc = 0;
     char** argv;
@@ -304,8 +310,8 @@ void MixedRiemannianTrustRegionSolver<GridType,Basis0,TargetSpace0,Basis1,Target
     int rank = grid_->comm().rank();
 
     // \todo Use global index set instead of basis for parallel computations
-    MaxNormTrustRegion<blocksize0> trustRegion0(assembler_->basis0_.indexSet().size(), initialTrustRegionRadius_);
-    MaxNormTrustRegion<blocksize1> trustRegion1(assembler_->basis1_.indexSet().size(), initialTrustRegionRadius_);
+    MaxNormTrustRegion<blocksize0> trustRegion0(assembler_->basis_.size({0}), initialTrustRegionRadius_);
+    MaxNormTrustRegion<blocksize1> trustRegion1(assembler_->basis_.size({1}), initialTrustRegionRadius_);
     trustRegion0.set(initialTrustRegionRadius_, std::get<0>(scaling_));
     trustRegion1.set(initialTrustRegionRadius_, std::get<1>(scaling_));
 
@@ -549,7 +555,7 @@ void MixedRiemannianTrustRegionSolver<GridType,Basis0,TargetSpace0,Basis1,Target
             if (this->verbosity_ == NumProc::FULL and rank==0)
                 std::cout << "Unsuccessful iteration!" << std::endl;
         }
-
+#if 0
         // Output each iterate, to better understand what the algorithm does
         DuneFunctionsBasis<Basis0> fufemBasis0(assembler_->basis0_);
         DuneFunctionsBasis<Basis1> fufemBasis1(assembler_->basis1_);
@@ -558,7 +564,7 @@ void MixedRiemannianTrustRegionSolver<GridType,Basis0,TargetSpace0,Basis1,Target
         CosseratVTKWriter<GridType>::template writeMixed<DuneFunctionsBasis<Basis0>, DuneFunctionsBasis<Basis1> >(fufemBasis0,x_[_0],
                                                                         fufemBasis1,x_[_1],
                                                                         "mixed-cosserat_iterate_" + iAsAscii.str());
-
+#endif
         if (rank==0)
             std::cout << "iteration took " << totalTimer.elapsed() << " sec." << std::endl;
 
diff --git a/dune/gfe/mixedriemanniantrsolver.hh b/dune/gfe/mixedriemanniantrsolver.hh
index fd3da73f..c59aa8b9 100644
--- a/dune/gfe/mixedriemanniantrsolver.hh
+++ b/dune/gfe/mixedriemanniantrsolver.hh
@@ -23,6 +23,7 @@
 
 /** \brief Riemannian trust-region solver for geodesic finite-element problems */
 template <class GridType,
+          class Basis,
           class Basis0, class TargetSpace0,
           class Basis1, class TargetSpace1>
 class MixedRiemannianTrustRegionSolver
@@ -69,7 +70,9 @@ public:
 
     /** \brief Set up the solver using a monotone multigrid method as the inner solver */
     void setup(const GridType& grid,
-               const MixedGFEAssembler<Basis0, TargetSpace0, Basis1, TargetSpace1>* assembler,
+               const MixedGFEAssembler<Basis, TargetSpace0, TargetSpace1>* assembler,
+               const Basis0& basis0,
+               const Basis1& basis1,
                const SolutionType& x,
                const Dune::BitSetVector<blocksize0>& dirichletNodes0,
                const Dune::BitSetVector<blocksize1>& dirichletNodes1,
@@ -144,7 +147,11 @@ protected:
     std::unique_ptr<MatrixType> hessianMatrix_;
 
     /** \brief The assembler for the material law */
-    const MixedGFEAssembler<Basis0, TargetSpace0, Basis1, TargetSpace1>* assembler_;
+    const MixedGFEAssembler<Basis, TargetSpace0, TargetSpace1>* assembler_;
+
+    /** \brief TEMPORARY: The two separate matrices */
+    std::unique_ptr<Basis0> basis0_;
+    std::unique_ptr<Basis1> basis1_;
 
     /** \brief The solver for the quadratic inner problems */
     std::shared_ptr<Solver> innerSolver_;
diff --git a/src/mixed-cosserat-continuum.cc b/src/mixed-cosserat-continuum.cc
index 8d208ce4..57f6211b 100644
--- a/src/mixed-cosserat-continuum.cc
+++ b/src/mixed-cosserat-continuum.cc
@@ -24,6 +24,7 @@
 
 #include <dune/functions/common/tuplevector.hh>
 #include <dune/functions/functionspacebases/pqknodalbasis.hh>
+#include <dune/functions/functionspacebases/compositebasis.hh>
 
 #include <dune/fufem/boundarypatch.hh>
 #include <dune/fufem/functiontools/boundarydofs.hh>
@@ -144,6 +145,16 @@ int main (int argc, char *argv[]) try
     typedef GridType::LeafGridView GridView;
     GridView gridView = grid->leafGridView();
 
+    using namespace Dune::Functions::BasisBuilder;
+
+    auto compositeBasis = makeBasis(
+      gridView,
+      composite(
+          pq<2>(),
+          pq<1>()
+      )
+    );
+
     typedef Dune::Functions::PQkNodalBasis<GridView,2> DeformationFEBasis;
     typedef Dune::Functions::PQkNodalBasis<GridView,1> OrientationFEBasis;
 
@@ -285,31 +296,31 @@ int main (int argc, char *argv[]) try
         }
 
     // Assembler using ADOL-C
-    MixedCosseratEnergy<DeformationFEBasis,
-                        OrientationFEBasis,
+    MixedCosseratEnergy<decltype(compositeBasis),
                         3,adouble> cosseratEnergyADOLCLocalStiffness(materialParameters,
                                                                      &neumannBoundary,
                                                                      neumannFunction.get());
 
-    MixedLocalGFEADOLCStiffness<DeformationFEBasis,
+    MixedLocalGFEADOLCStiffness<decltype(compositeBasis),
                                 RealTuple<double,3>,
-                                OrientationFEBasis,
                                 Rotation<double,3> > localGFEADOLCStiffness(&cosseratEnergyADOLCLocalStiffness);
 
-    MixedGFEAssembler<DeformationFEBasis,
+    MixedGFEAssembler<decltype(compositeBasis),
                       RealTuple<double,3>,
-                      OrientationFEBasis,
-                      Rotation<double,3> > assembler(deformationFEBasis, orientationFEBasis, &localGFEADOLCStiffness);
+                      Rotation<double,3> > assembler(compositeBasis, &localGFEADOLCStiffness);
 
     // /////////////////////////////////////////////////
     //   Create a Riemannian trust-region solver
     // /////////////////////////////////////////////////
 
     MixedRiemannianTrustRegionSolver<GridType,
+                                     decltype(compositeBasis),
                                      DeformationFEBasis, RealTuple<double,3>,
                                      OrientationFEBasis, Rotation<double,3> > solver;
     solver.setup(*grid,
                  &assembler,
+                 deformationFEBasis,
+                 orientationFEBasis,
                  x,
                  deformationDirichletDofs,
                  orientationDirichletDofs,
-- 
GitLab