diff --git a/dune/gfe/surfacecosseratenergy.hh b/dune/gfe/surfacecosseratenergy.hh
index 366e233747f0a98c2e986b40fb8c606be0963e4f..e34f674016f88a23f353913b8981e01145634e3b 100644
--- a/dune/gfe/surfacecosseratenergy.hh
+++ b/dune/gfe/surfacecosseratenergy.hh
@@ -15,11 +15,18 @@
 #include <dune/gfe/orthogonalmatrix.hh>
 #include <dune/gfe/rigidbodymotion.hh>
 #include <dune/gfe/tensor3.hh>
-#include <dune/gfe/vertexnormals.hh>
 
-namespace Dune::GFE {
+#include <dune/curvedgeometry/curvedgeometry.hh>
+#include <dune/localfunctions/lagrange/lfecache.hh>
 
-template<class Basis, class... TargetSpaces>
+namespace Dune::GFE {
+/** \brief Assembles the cosserat energy on the given boundary for a single element.
+ *
+ * \tparam CurvedGeometryGridFunction Type of the grid function that gives the geometry of the deformed surface
+ * \tparam Basis Type of the Basis used for assembling
+ * \tparam TargetSpaces The List of TargetSpaces - SurfaceCosseratEnergy needs exactly two TargetSpaces!
+ */
+template<class CurvedGeometryGridFunction, class Basis, class... TargetSpaces>
 class SurfaceCosseratEnergy
 : public Dune::GFE::LocalEnergy<Basis, TargetSpaces...>
 {
@@ -70,16 +77,19 @@ public:
 
   /** \brief Constructor with a set of material parameters
    * \param parameters The material parameters
+   * \param shellBoundary The shellBoundary contains the faces where the cosserat energy is assembled
+   * \param curvedGeometryGridFunction The curvedGeometryGridFunction gives the geometry of the shell in stress-free state.
+            When assembling, we deform the intersections using the curvedGeometryGridFunction and then use the deformed geometries.
+   * \param thicknessF The shell thickness parameter, given as a function and evaluated at each quadrature point
+   * \param lameF The Lame parameters, given as a function and evaluated at each quadrature point
    */
   SurfaceCosseratEnergy(const Dune::ParameterTree& parameters,
-    const std::vector<UnitVector<double,dimWorld> >& vertexNormals,
     const BoundaryPatch<GridView>* shellBoundary,
-    const std::unordered_map<typename GridView::Grid::GlobalIdSet::IdType,Dune::MultiLinearGeometry<double, dimWorld-1, dimWorld>>& geometriesOnShellBoundary,
+    const CurvedGeometryGridFunction& curvedGeometryGridFunction,
     const std::function<double(Dune::FieldVector<double,dimWorld>)> thicknessF,
     const std::function<Dune::FieldVector<double,2>(Dune::FieldVector<double,dimWorld>)> lameF)
   : shellBoundary_(shellBoundary),
-    vertexNormals_(vertexNormals),
-    geometriesOnShellBoundary_(geometriesOnShellBoundary),
+    curvedGeometryGridFunction_(curvedGeometryGridFunction),
     thicknessF_(thicknessF),
     lameF_(lameF)
   {
@@ -98,7 +108,7 @@ public:
 RT energy(const typename Basis::LocalView& localView,
           const std::vector<TargetSpaces>&... localSolutions) const
 { 
-  static_assert(sizeof...(TargetSpaces) == 2, "SurfaceCosseratEnergy needs exactly two TargetSpace!");
+  static_assert(sizeof...(TargetSpaces) == 2, "SurfaceCosseratEnergy needs exactly two TargetSpaces!");
 
   using namespace Dune::Indices;
   using TargetSpace0 = typename std::tuple_element<0, std::tuple<TargetSpaces...> >::type;
@@ -110,21 +120,6 @@ RT energy(const typename Basis::LocalView& localView,
   auto element = localView.element();
   auto gridView = localView.globalBasis().gridView();
 
-  ////////////////////////////////////////////////////////////////////////////////////
-  //  Construct a linear (i.e., non-constant!) normal field on each element
-  ////////////////////////////////////////////////////////////////////////////////////
-  typedef typename Dune::PQkLocalFiniteElementCache<DT, double, gridDim, 1> P1FiniteElementCache;
-  typedef typename P1FiniteElementCache::FiniteElementType P1LocalFiniteElement;
-  //it.type()?
-  P1FiniteElementCache p1FiniteElementCache;
-  const auto& p1LocalFiniteElement = p1FiniteElementCache.get(element.type());
-
-  assert(vertexNormals_.size() == gridView.indexSet().size(gridDim));
-  std::vector<UnitVector<double,3> > cornerNormals(element.subEntities(gridDim));
-  for (size_t i=0; i<cornerNormals.size(); i++)
-    cornerNormals[i] = vertexNormals_[gridView.indexSet().subIndex(element,i,gridDim)];
-  Dune::GFE::LocalProjectedFEFunction<gridDim, DT, P1LocalFiniteElement, UnitVector<double,3> > unitNormals(p1LocalFiniteElement, cornerNormals);
-
   ////////////////////////////////////////////////////////////////////////////////////
   //  Set up the local nonlinear finite element function
   ////////////////////////////////////////////////////////////////////////////////////
@@ -159,14 +154,25 @@ RT energy(const typename Basis::LocalView& localView,
 
   RT energy = 0;
 
-  auto& idSet = gridView.grid().globalIdSet();
-
   for (auto&& it : intersections(shellBoundary_->gridView(), element)) {
     if (not shellBoundary_->contains(it))
       continue;
-    
-    auto id = idSet.subId(it.inside(), it.indexInInside(), 1);
-    auto boundaryGeometry = geometriesOnShellBoundary_.at(id);
+
+    auto localGridFunction = localFunction(curvedGeometryGridFunction_);
+    auto curvedGeometryGridFunctionOrder = deformationLocalFiniteElement.localBasis().order();//curvedGeometryGridFunction_.basis().localView().tree().child(0).finiteElement().localBasis().order();
+    localGridFunction.bind(element);
+    auto referenceElement = Dune::referenceElement<DT,boundaryDim>(it.type());
+
+    // Construct the geometry on the boundary using the map lGF(localGeometry.global(local)):
+    // The variable local holds the local coordinates in the 2D reference element, localGeometry.global maps them to the 3D reference element.
+    // The function lGF is the gridfunction bound to the current element, so lGF(localGeometry.global(local)) is the value of curvedGeometryGridFunction_ at
+    // the point on the intersection face.
+    using BoundaryGeometry = Dune::CurvedGeometry<DT, boundaryDim, dimWorld, Dune::CurvedGeometryTraits<DT, Dune::LagrangeLFECache<DT,DT,boundaryDim>>>;
+    BoundaryGeometry boundaryGeometry(referenceElement,
+      [localGridFunction, localGeometry=it.geometryInInside()](const auto& local) {
+        return localGridFunction(localGeometry.global(local));
+      }, curvedGeometryGridFunctionOrder);
+
     auto quadOrder = (it.type().isSimplex()) ? deformationLocalFiniteElement.localBasis().order()
                                                   : deformationLocalFiniteElement.localBasis().order() * boundaryDim;
 
@@ -243,7 +249,6 @@ RT energy(const typename Basis::LocalView& localView,
       // If dimWorld==3, then the first two lines of aCovariant are simply the jacobianTransposed
       // of the element.  If dimWorld<3 (i.e., ==2), we have to explicitly enters 0.0 in the last column.
       const auto jacobianTransposed = boundaryGeometry.jacobianTransposed(quad[pt].position());
-      // auto jacobianTransposed = geometry.jacobianTransposed(quadPos);
 
       for (int i=0; i<2; i++)
       {
@@ -253,7 +258,7 @@ RT energy(const typename Basis::LocalView& localView,
           aCovariant[i][j] = 0.0;
       }
 
-      aCovariant[2] = Dune::MatrixVector::crossProduct(aCovariant[0], aCovariant[1]);
+      aCovariant[2] = Dune::FMatrixHelp::Impl::crossProduct(aCovariant[0], aCovariant[1]);
       aCovariant[2] /= aCovariant[2].two_norm();
 
       auto aContravariant = aCovariant;
@@ -277,9 +282,8 @@ RT energy(const typename Basis::LocalView& localView,
         for (int beta=0; beta<2; beta++)
           c += aScalar * eps[alpha][beta] * Dune::GFE::dyadicProduct(aContravariant[alpha], aContravariant[beta]);
 
-      // Second fundamental form
-      // The derivative of the normal field
-      auto normalDerivative = unitNormals.evaluateDerivative(quadPos);
+      // Second fundamental form: The derivative of the normal field
+      auto normalDerivative = boundaryGeometry.normalGradient(quad[pt].position());
 
       Dune::FieldMatrix<double,3,3> b(0);
       for (int alpha=0; alpha<boundaryDim; alpha++)
@@ -375,11 +379,8 @@ private:
   /** \brief The shell boundary */
   const BoundaryPatch<GridView>* shellBoundary_;
 
-  /** \brief Stress-free geometries of the shell elements*/
-  const std::unordered_map<typename GridView::Grid::GlobalIdSet::IdType, Dune::MultiLinearGeometry<double, dimWorld-1, dimWorld>> geometriesOnShellBoundary_;
-
-  /** \brief The normal vectors at the grid vertices. They are used to compute the reference surface curvature. */
-  std::vector<UnitVector<double,3> > vertexNormals_;
+  /** \brief The function used to create the Geometries used for assembling */
+  const CurvedGeometryGridFunction curvedGeometryGridFunction_;
 
   /** \brief The shell thickness as a function*/
   std::function<double(Dune::FieldVector<double,dimWorld>)> thicknessF_;
diff --git a/problems/film-on-substrate.parset b/problems/film-on-substrate.parset
index dd3d24ace40844d69c626da2c59aa01d1abab4bf..61e3fde5626f44f40bbf3d0198a212b7559a7f9f 100644
--- a/problems/film-on-substrate.parset
+++ b/problems/film-on-substrate.parset
@@ -18,6 +18,7 @@ numLevels = 1
 startFromFile = false
 pathToGridDeformationFile = ./
 gridDeformationFile = deformation
+writeOutStressFreeData = true
 
 # When not starting from a file, deformation of the surface shell part can be given here using the gridDeformation function
 gridDeformation="[1.3*x[0], x[1], x[2]]"
diff --git a/src/film-on-substrate.cc b/src/film-on-substrate.cc
index 3921c142c19f218566896afbbd85f3f82f206df6..883b37eb7ae94bf12c458c317e9789c14adb9a83 100644
--- a/src/film-on-substrate.cc
+++ b/src/film-on-substrate.cc
@@ -47,7 +47,6 @@
 #include <dune/gfe/neumannenergy.hh>
 #include <dune/gfe/surfacecosseratenergy.hh>
 #include <dune/gfe/sumenergy.hh>
-#include <dune/gfe/vertexnormals.hh>
 
 #if MIXED_SPACE
 #include <dune/gfe/mixedriemanniantrsolver.hh>
@@ -63,6 +62,7 @@
 #include <dune/solvers/solvers/iterativesolver.hh>
 #include <dune/solvers/norms/energynorm.hh>
 
+
 // grid dimension
 #ifndef WORLD_DIM
 #  define WORLD_DIM 3
@@ -74,6 +74,8 @@ const int targetDim = WORLD_DIM;
 const int displacementOrder = 2;
 const int rotationOrder = 2;
 
+const int stressFreeDataOrder = 2;
+
 #if !MIXED_SPACE
 static_assert(displacementOrder==rotationOrder, "displacement and rotation order do not match!");
 #endif
@@ -343,32 +345,28 @@ int main (int argc, char *argv[]) try
   vtkWriter.write(resultPath + "finite-strain_homotopy_" + parameterSet.get<std::string>("energy") + "_0");
   
   /////////////////////////////////////////////////////////////
-  //               INITIAL SURFACE SHELL DATA
+  //               STRESS-FREE SURFACE SHELL DATA
   /////////////////////////////////////////////////////////////
+  auto stressFreeFEBasis = makeBasis(
+    gridView,
+    power<dim>(
+      lagrange<stressFreeDataOrder>(),
+      blockedInterleaved()
+  ));
 
-  typedef MultiLinearGeometry<double, dim-1, dim> ML;
-  std::unordered_map<GridType::GlobalIdSet::IdType, ML> geometriesOnShellBoundary;
-  
   auto& idSet = grid->globalIdSet();
+  GlobalIndexSet<GridView> globalVertexIndexSet(gridView,dim);
+  BlockVector<FieldVector<double,dim> > stressFreeShellVector(stressFreeFEBasis.size());
 
-  // Read in the grid deformation
   if (startFromFile) {
-    // Create a basis of order 1 in order to deform the geometries on the surface shell boundary
     const std::string pathToGridDeformationFile = parameterSet.get("pathToGridDeformationFile", "");
-    // for this, we create a basis of order 1 in order to deform the geometries on the surface shell boundary
-    auto feBasisOrder1 = makeBasis(
-      gridView,
-      power<dim>(
-        lagrange<1>(),
-        blockedInterleaved()
-    ));
-    GlobalIndexSet<GridView> globalVertexIndexSet(gridView,dim);
 
     std::unordered_map<std::string, FieldVector<double,3>> deformationMap;
     std::string line, displacement, entry;
     if (mpiHelper.rank() == 0) 
-      std::cout << "Reading in deformation file: " << pathToGridDeformationFile + parameterSet.get<std::string>("gridDeformationFile") << std::endl;
+      std::cout << "Reading in deformation file ("  << "order is "  << stressFreeDataOrder  << "): " << pathToGridDeformationFile + parameterSet.get<std::string>("gridDeformationFile") << std::endl;
     // Read grid deformation information from the file specified in the parameter set via gridDeformationFile
+
     std::ifstream file(pathToGridDeformationFile + parameterSet.get<std::string>("gridDeformationFile"), std::ios::in);
     if (file.is_open()) {
       while (std::getline(file, line)) {
@@ -384,62 +382,40 @@ int main (int argc, char *argv[]) try
       }
       if (mpiHelper.rank() == 0)
         std::cout << "... done: The grid has " << globalVertexIndexSet.size(dim) << " vertices and the defomation file has " << deformationMap.size() << " entries." << std::endl;
-      if (deformationMap.size() != globalVertexIndexSet.size(dim))
+      if (stressFreeDataOrder == 1 && deformationMap.size() != globalVertexIndexSet.size(dim))
         DUNE_THROW(Exception, "Error: Grid and deformation vector do not match!");
       file.close();
     } else {
       DUNE_THROW(Exception, "Error: Could not open the file containing the deformation vector!");
     }
-  
-    BlockVector<FieldVector<double,dim>> gridDeformationFromFile;
-    Dune::Functions::interpolate(feBasisOrder1, gridDeformationFromFile, [](FieldVector<double,dim> x){ return x; });
+    Dune::Functions::interpolate(stressFreeFEBasis, stressFreeShellVector, [](FieldVector<double,dim> x){ return x; });
 
-    for (auto& entry : gridDeformationFromFile) {
+    for (auto& entry : stressFreeShellVector) {
       std::stringstream stream;
       stream << entry;
-      entry = deformationMap.at(stream.str()); //Look up the deformation for this vertex in the deformationMap
-    }
-
-    auto gridDeformationFromFileFunction = Dune::Functions::makeDiscreteGlobalBasisFunction<FieldVector<double,dim>>(feBasisOrder1, gridDeformationFromFile);
-    auto localGridDeformationFromFileFunction = localFunction(gridDeformationFromFileFunction);
-
-    //Write out the stress-free geometries that were read in
-    SubsamplingVTKWriter<GridView> vtkWriter(gridView, Dune::refinementLevels(0));
-    vtkWriter.addVertexData(localGridDeformationFromFileFunction, VTK::FieldInfo("displacement", VTK::FieldInfo::Type::scalar, dim));
-    vtkWriter.write("stress-free-geometries");
-
-    //Iterate over boundary, each facet on the boundary has an element (boundaryElement.inside()) with a unique global id (idSet.subId);
-    //we store the new geometry in the map with this id as reference
-    for (auto boundaryElement : surfaceShellBoundary) {
-      localGridDeformationFromFileFunction.bind(boundaryElement.inside());
-      std::vector<Dune::FieldVector<double,dim>> corners;
-      for (int i = 0; i < boundaryElement.geometry().corners(); i++) {
-        auto corner = boundaryElement.geometry().corner(i);
-        corner += localGridDeformationFromFileFunction(boundaryElement.inside().geometry().local(boundaryElement.geometry().corner(i)));
-        corners.push_back(corner);
-      }
-      localGridDeformationFromFileFunction.unbind();
-      GridType::GlobalIdSet::IdType id = idSet.subId(boundaryElement.inside(), boundaryElement.indexInInside(), 1);
-      ML boundaryGeometry(boundaryElement.geometry().type(), corners);
-      geometriesOnShellBoundary.insert({id, boundaryGeometry});
+      entry += deformationMap.at(stream.str()); //Look up the displacement for this vertex in the deformationMap
     }
   } else {
     // Read grid deformation from deformation function
     auto gridDeformationLambda = std::string("lambda x: (") + parameterSet.get<std::string>("gridDeformation") + std::string(")");
     auto gridDeformation = Python::make_function<FieldVector<double,dim> >(Python::evaluate(gridDeformationLambda));
+    Dune::Functions::interpolate(stressFreeFEBasis, stressFreeShellVector, gridDeformation);
+  }
 
-    //Iterate over boundary, each facet on the boundary has an element (boundaryElement.inside()) with a unique global id (idSet.subId);
-    //we store the new geometry in the map with this id as reference
-    for (auto boundaryElement : surfaceShellBoundary) {
-      std::vector<Dune::FieldVector<double,dim>> corners;
-      for (int i = 0; i < boundaryElement.geometry().corners(); i++) {
-        auto corner = gridDeformation(boundaryElement.geometry().corner(i));
-        corners.push_back(corner);
-      }
-      GridType::GlobalIdSet::IdType id = idSet.subId(boundaryElement.inside(), boundaryElement.indexInInside(), 1);
-      ML boundaryGeometry(boundaryElement.geometry().type(), corners);
-      geometriesOnShellBoundary.insert({id, boundaryGeometry});
+  auto stressFreeShellFunction = Dune::Functions::makeDiscreteGlobalBasisFunction<FieldVector<double,dim>>(stressFreeFEBasis, stressFreeShellVector);
+  
+  if (parameterSet.hasKey("writeOutStressFreeData") && parameterSet.get<bool>("writeOutStressFreeData")) {
+    BlockVector<FieldVector<double,dim> > stressFreeDisplacement(stressFreeFEBasis.size());
+    Dune::Functions::interpolate(stressFreeFEBasis, stressFreeDisplacement, [](FieldVector<double,dim> x){ return (-1.0)*x; });
+
+    for (int i = 0; i < stressFreeFEBasis.size(); i++) {
+      stressFreeDisplacement[i] += stressFreeShellVector[i];
     }
+    auto stressFreeDisplacementFunction = Dune::Functions::makeDiscreteGlobalBasisFunction<FieldVector<double,dim>>(stressFreeFEBasis, stressFreeDisplacement);
+    //Write out the stress-free shell function that was read in
+    SubsamplingVTKWriter<GridView> vtkWriterStressFree(gridView, Dune::refinementLevels(1));
+    vtkWriterStressFree.addVertexData(stressFreeDisplacementFunction, VTK::FieldInfo("displacement", VTK::FieldInfo::Type::scalar, dim));
+    vtkWriterStressFree.write("stress-free-shell-function");
   }
 
   /////////////////////////////////////////////////////////////
@@ -450,13 +426,6 @@ int main (int argc, char *argv[]) try
     neumannValues = parameterSet.get<FieldVector<double,dim> >("neumannValues");
   std::cout << "Neumann values: " << neumannValues << std::endl;
 
-  // Vertex Normals for the 3D-Part
-  std::vector<UnitVector<double,dim> > vertexNormals(gridView.size(dim));
-  Dune::FieldVector<double,dim> vertexNormalRaw = {0,0,1};
-  for (int i = 0; i< vertexNormals.size(); i++) {
-    UnitVector vertexNormal(vertexNormalRaw);
-    vertexNormals[i] = vertexNormal;
-  }
   //Function for the Cosserat material parameters
   const ParameterTree& materialParameters = parameterSet.sub("materialParameters");
   Python::Reference surfaceShellClass = Python::import(materialParameters.get<std::string>("surfaceShellParameters"));
@@ -512,7 +481,13 @@ int main (int argc, char *argv[]) try
 
     auto elasticEnergy = std::make_shared<GFE::LocalIntegralEnergy<CompositeBasis, RealTuple<ValueType,targetDim>, Rotation<ValueType,dim>>>(elasticDensity);
     auto neumannEnergy = std::make_shared<GFE::NeumannEnergy<CompositeBasis, RealTuple<ValueType,targetDim>, Rotation<ValueType,dim>>>(neumannBoundary,*neumannFunctionPtr);
-    auto surfaceCosseratEnergy = std::make_shared<GFE::SurfaceCosseratEnergy<CompositeBasis, RealTuple<ValueType,dim>, Rotation<ValueType,dim> >>(materialParameters, std::move(vertexNormals), &surfaceShellBoundary, std::move(geometriesOnShellBoundary), fThickness, fLame);
+    auto surfaceCosseratEnergy = std::make_shared<GFE::SurfaceCosseratEnergy<
+        decltype(stressFreeShellFunction), CompositeBasis, RealTuple<ValueType,dim>, Rotation<ValueType,dim> >>(
+          materialParameters,
+          &surfaceShellBoundary,
+          stressFreeShellFunction,
+          fThickness,
+          fLame);
 
     GFE::SumEnergy<CompositeBasis, RealTuple<ValueType,targetDim>, Rotation<ValueType,targetDim>> sumEnergy;
     sumEnergy.addLocalEnergy(neumannEnergy);