diff --git a/test/unitvectortest.cc b/test/unitvectortest.cc index 123a667c314594678aa2fd7de0ab46316e5d0333..4c42b21bee18f8c2bef36a2c731c07f606545154 100644 --- a/test/unitvectortest.cc +++ b/test/unitvectortest.cc @@ -296,45 +296,26 @@ void testDerivativesOfSquaredDistance(const TargetSpace& a, const TargetSpace& b } -void testUnitVector2d() + +template <int N> +void testUnitVector() { - std::vector<UnitVector<2> > testPoints; - ValueFactory<UnitVector<2> >::get(testPoints); + std::vector<UnitVector<N> > testPoints; + ValueFactory<UnitVector<N> >::get(testPoints); int nTestPoints = testPoints.size(); - // Set up elements of S^1 + // Set up elements of S^{N-1} for (int i=0; i<nTestPoints; i++) { - testOrthonormalFrame<UnitVector<2>, 2>(testPoints[i]); + testOrthonormalFrame<UnitVector<N>, N>(testPoints[i]); for (int j=0; j<nTestPoints; j++) { - if (UnitVector<2>::distance(testPoints[i],testPoints[j]) > M_PI*0.98) + if (UnitVector<N>::distance(testPoints[i],testPoints[j]) > M_PI*0.98) continue; - testDerivativesOfSquaredDistance<UnitVector<2>, 2>(testPoints[i], testPoints[j]); - - } - - } -} - -void testUnitVector3d() -{ - std::vector<UnitVector<3> > testPoints; - ValueFactory<UnitVector<3> >::get(testPoints); - - int nTestPoints = testPoints.size(); - - // Set up elements of S^2 - for (int i=0; i<nTestPoints; i++) { - - testOrthonormalFrame<UnitVector<3>, 3>(testPoints[i]); - - for (int j=0; j<nTestPoints; j++) { - - testDerivativesOfSquaredDistance<UnitVector<3>, 3>(testPoints[i], testPoints[j]); + testDerivativesOfSquaredDistance<UnitVector<N>, N>(testPoints[i], testPoints[j]); } @@ -342,6 +323,7 @@ void testUnitVector3d() } + void testRotation3d() { int nTestPoints = 10; @@ -373,8 +355,8 @@ void testRotation3d() int main() try { - testUnitVector2d(); - testUnitVector3d(); + testUnitVector<2>(); + testUnitVector<3>(); testRotation3d(); } catch (Exception e) {