diff --git a/dune/gfe/rodassembler.hh b/dune/gfe/rodassembler.hh
index cf93563fc6770751eee0957d85411165065d1c61..a800cc8ba1bc1626a4ca33ce16483b228fea3507 100644
--- a/dune/gfe/rodassembler.hh
+++ b/dune/gfe/rodassembler.hh
@@ -41,12 +41,9 @@ class RodAssembler<Basis,3> : public GeodesicFEAssembler<Basis, RigidBodyMotion<
 public:
         //! ???
     RodAssembler(const Basis& basis,
-                 RodLocalStiffness<GridView,double>* localStiffness)
-        : GeodesicFEAssembler<Basis, RigidBodyMotion<double,3> >(basis,nullptr)
-        , rodEnergy_(localStiffness)
+                 LocalGeodesicFEStiffness<Basis, RigidBodyMotion<double,3> >* localStiffness)
+    : GeodesicFEAssembler<Basis, RigidBodyMotion<double,3> >(basis,localStiffness)
         {
-            this->localStiffness_ = new LocalGeodesicFEFDStiffness<Basis,RigidBodyMotion<double,3>, double>(localStiffness);
-
             std::vector<RigidBodyMotion<double,3> > referenceConfiguration(basis.size());
 
     for (const auto vertex : Dune::vertices(basis.gridView()))
@@ -59,12 +56,19 @@ public:
                 referenceConfiguration[idx].q = Rotation<double,3>::identity();
             }
 
-    rodEnergy_->setReferenceConfiguration(referenceConfiguration);
+    rodEnergy()->setReferenceConfiguration(referenceConfiguration);
         }
 
+    auto rodEnergy()
+    {
+      // TODO: Does not work for other stiffness implementations
+      auto localFDStiffness = dynamic_cast<LocalGeodesicFEFDStiffness<Basis, RigidBodyMotion<double,3> >*>(this->localStiffness_);
+      return const_cast<RodLocalStiffness<GridView,double>*>(dynamic_cast<const RodLocalStiffness<GridView,double>*>(localFDStiffness->localEnergy_));
+    }
+
         std::vector<RigidBodyMotion<double,3> > getRefConfig()
     {
-      return rodEnergy_->referenceConfiguration_;
+      return rodEnergy()->referenceConfiguration_;
         }
 
   virtual void assembleGradient(const std::vector<RigidBodyMotion<double,3> >& sol,
@@ -82,8 +86,6 @@ public:
         template <class PatchGridView>
         Dune::FieldVector<double,6> getResultantForce(const BoundaryPatch<PatchGridView>& boundary,
                                                       const std::vector<RigidBodyMotion<double,3> >& sol) const;
-
-    RodLocalStiffness<GridView,double>* rodEnergy_;
     }; // end class
 
 
diff --git a/src/rod3d.cc b/src/rod3d.cc
index 6eda1ec68973059a28ca404d8f26def8842f32a2..90ac7842faf41c97a7c261a70d45e12982fe0cd6 100644
--- a/src/rod3d.cc
+++ b/src/rod3d.cc
@@ -128,7 +128,9 @@ int main (int argc, char *argv[]) try
     RodLocalStiffness<GridView,double> localStiffness(gridView,
                                                       A, J1, J2, E, nu);
 
-    RodAssembler<FEBasis,3> rodAssembler(gridView, &localStiffness);
+    LocalGeodesicFEFDStiffness<FEBasis,RigidBodyMotion<double,3> > localFDStiffness(&localStiffness);
+
+    RodAssembler<FEBasis,3> rodAssembler(gridView, &localFDStiffness);
 
     RiemannianTrustRegionSolver<FEBasis,RigidBodyMotion<double,3> > rodSolver;
 
diff --git a/test/frameinvariancetest.cc b/test/frameinvariancetest.cc
index 578627b6d96a1c806d464330b5be2c63141ecdf8..56c00781b97548aa00aecbdbc3ea8db1501b53c6 100644
--- a/test/frameinvariancetest.cc
+++ b/test/frameinvariancetest.cc
@@ -74,10 +74,12 @@ int main (int argc, char *argv[]) try
         rotatedX[i].q = rotation.mult(x[i].q);
     }
 
-    RodLocalStiffness<GridView,double> localStiffness(gridView,
-                                                      1,1,1,1e6,0.3);
+    RodLocalStiffness<GridView,double> localRodFirstOrderModel(gridView,
+                                                               1,1,1,1e6,0.3);
 
-    RodAssembler<FEBasis,3> assembler(feBasis, &localStiffness);
+    LocalGeodesicFEFDStiffness<FEBasis,RigidBodyMotion<double,3> > localFDStiffness(&localRodFirstOrderModel);
+
+    RodAssembler<FEBasis,3> assembler(feBasis, &localFDStiffness);
 
     if (std::abs(assembler.computeEnergy(x) - assembler.computeEnergy(rotatedX)) > 1e-6)
         DUNE_THROW(Dune::Exception, "Rod energy not invariant under rigid body motions!");
diff --git a/test/rodassemblertest.cc b/test/rodassemblertest.cc
index 28a9da324f5cb8510b392bbae2caf727f7fbab5d..20c186d3ee3c1ff42fee6676ad14c7a560ad5179 100644
--- a/test/rodassemblertest.cc
+++ b/test/rodassemblertest.cc
@@ -551,8 +551,9 @@ int main (int argc, char *argv[]) try
     RodLocalStiffness<GridView,double> localStiffness(gridView,
                                                                     0.01, 0.0001, 0.0001, 2.5e5, 0.3);
 
+    LocalGeodesicFEFDStiffness<Basis,RigidBodyMotion<double,3> > localFDStiffness(&localStiffness);
 
-    RodAssembler<Basis,3> rodAssembler(basis, &localStiffness);
+    RodAssembler<Basis,3> rodAssembler(basis, &localFDStiffness);
 
     std::cout << "Energy: " << rodAssembler.computeEnergy(x) << std::endl;