diff --git a/dirneucoupling.cc b/dirneucoupling.cc
index 23ad5d1aacaa4f9acc8611aedc0cd931ea4b87cf..38319c9e622aa04c3a8a82e1816666ccedc84fd9 100644
--- a/dirneucoupling.cc
+++ b/dirneucoupling.cc
@@ -35,6 +35,7 @@
 #include "src/riemanniantrsolver.hh"
 #include "src/geodesicdifference.hh"
 #include "src/rodwriter.hh"
+#include "src/makestraightrod.hh"
 
 // Space dimension
 const int dim = 3;
@@ -50,41 +51,6 @@ typedef vector<RigidBodyMotion<dim> >              RodSolutionType;
 typedef BlockVector<FieldVector<double, 6> >       RodDifferenceType;
 
 
-// Make a straight rod from two given endpoints
-void makeStraightRod(RodSolutionType& rod, int n,
-                     const FieldVector<double,3>& beginning, const FieldVector<double,3>& end)
-{
-    // Compute the correct orientation
-    Rotation<3,double> orientation = Rotation<3,double>::identity();
-
-    FieldVector<double,3> zAxis(0);
-    zAxis[2] = 1;
-    FieldVector<double,3> axis = crossProduct(end-beginning, zAxis);
-
-    if (axis.two_norm() != 0)
-        axis /= -axis.two_norm();
-
-    FieldVector<double,3> d3 = end-beginning;
-    d3 /= d3.two_norm();
-
-    double angle = std::acos(zAxis * d3);
-
-    if (angle != 0)
-        orientation = Rotation<3,double>(axis, angle);
-
-    // Set the values
-    rod.resize(n);
-    for (int i=0; i<n; i++) {
-
-        rod[i].r = beginning;
-        rod[i].r.axpy(double(i) / (n-1), end-beginning);
-        rod[i].q = orientation;
-
-    }
-
-
-}
-
 int main (int argc, char *argv[]) try
 {
     // Some types that I need
@@ -171,14 +137,7 @@ int main (int argc, char *argv[]) try
     // //////////////////////////
     //   Initial solution
     // //////////////////////////
-#if 0
-    for (int i=0; i<rodX.size(); i++) {
-        rodX[i].r[0] = 0.5;
-        rodX[i].r[1] = 0.5;
-        rodX[i].r[2] = 5 + (i* 5.0 /(rodX.size()-1));
-        rodX[i].q = Quaternion<double>::identity();
-    }
-#endif
+
     makeStraightRod(rodX, rodGrid.size(1), rodRestEndPoint[0], rodRestEndPoint[1]);
 
     // /////////////////////////////////////////
diff --git a/src/makestraightrod.hh b/src/makestraightrod.hh
new file mode 100644
index 0000000000000000000000000000000000000000..ab7bc62321f6e59f4536deda654d105b28551ffb
--- /dev/null
+++ b/src/makestraightrod.hh
@@ -0,0 +1,47 @@
+#ifndef MAKE_STRAIGHT_ROD_HH
+#define MAKE_STRAIGHT_ROD_HH
+
+#include <vector>
+#include <dune/common/fvector.hh>
+#include "rotation.hh"
+
+/** \brief Make a straight rod from two given endpoints
+
+\param[out] rod The new rod
+\param[in] n The number of vertices
+*/
+template <int dim>
+void makeStraightRod(std::vector<RigidBodyMotion<dim> >& rod, int n,
+                     const Dune::FieldVector<double,3>& beginning, const Dune::FieldVector<double,3>& end)
+{
+    // Compute the correct orientation
+    Rotation<3,double> orientation = Rotation<3,double>::identity();
+
+    Dune::FieldVector<double,3> zAxis(0);
+    zAxis[2] = 1;
+    Dune::FieldVector<double,3> axis = crossProduct(end-beginning, zAxis);
+    if (axis.two_norm() != 0)
+        axis /= -axis.two_norm();
+
+    Dune::FieldVector<double,3> d3 = end-beginning;
+    d3 /= d3.two_norm();
+
+    double angle = std::acos(zAxis * d3);
+
+    if (angle != 0)
+        orientation = Rotation<3,double>(axis, angle);
+
+    // Set the values
+    rod.resize(n);
+    for (int i=0; i<n; i++) {
+
+        rod[i].r = beginning;
+        rod[i].r.axpy(double(i) / (n-1), end-beginning);
+        rod[i].q = orientation;
+
+    }
+
+
+}
+
+#endif