diff --git a/test/unitvectortest.cc b/test/unitvectortest.cc
index 123a667c314594678aa2fd7de0ab46316e5d0333..4c42b21bee18f8c2bef36a2c731c07f606545154 100644
--- a/test/unitvectortest.cc
+++ b/test/unitvectortest.cc
@@ -296,45 +296,26 @@ void testDerivativesOfSquaredDistance(const TargetSpace& a, const TargetSpace& b
 
 }
 
-void testUnitVector2d()
+
+template <int N>
+void testUnitVector()
 {
-    std::vector<UnitVector<2> > testPoints;
-    ValueFactory<UnitVector<2> >::get(testPoints);
+    std::vector<UnitVector<N> > testPoints;
+    ValueFactory<UnitVector<N> >::get(testPoints);
     
     int nTestPoints = testPoints.size();
     
-    // Set up elements of S^1
+    // Set up elements of S^{N-1}
     for (int i=0; i<nTestPoints; i++) {
         
-        testOrthonormalFrame<UnitVector<2>, 2>(testPoints[i]);
+        testOrthonormalFrame<UnitVector<N>, N>(testPoints[i]);
         
         for (int j=0; j<nTestPoints; j++) {
             
-            if (UnitVector<2>::distance(testPoints[i],testPoints[j]) > M_PI*0.98)
+            if (UnitVector<N>::distance(testPoints[i],testPoints[j]) > M_PI*0.98)
                 continue;
 
-            testDerivativesOfSquaredDistance<UnitVector<2>, 2>(testPoints[i], testPoints[j]);
-            
-        }
-        
-    }
-}
-
-void testUnitVector3d()
-{
-    std::vector<UnitVector<3> > testPoints;
-    ValueFactory<UnitVector<3> >::get(testPoints);
-    
-    int nTestPoints = testPoints.size();
-                                  
-    // Set up elements of S^2
-    for (int i=0; i<nTestPoints; i++) {
-        
-        testOrthonormalFrame<UnitVector<3>, 3>(testPoints[i]);
-        
-        for (int j=0; j<nTestPoints; j++) {
-            
-            testDerivativesOfSquaredDistance<UnitVector<3>, 3>(testPoints[i], testPoints[j]);
+            testDerivativesOfSquaredDistance<UnitVector<N>, N>(testPoints[i], testPoints[j]);
             
         }
         
@@ -342,6 +323,7 @@ void testUnitVector3d()
     
 }
 
+
 void testRotation3d()
 {
     int nTestPoints = 10;
@@ -373,8 +355,8 @@ void testRotation3d()
 
 int main() try
 {
-    testUnitVector2d();
-    testUnitVector3d();
+    testUnitVector<2>();
+    testUnitVector<3>();
 
     testRotation3d();
 } catch (Exception e) {