#ifndef LOCAL_GEODESIC_FE_STIFFNESS_HH
#define LOCAL_GEODESIC_FE_STIFFNESS_HH

#include "omp.h"

#include <dune/istl/bcrsmatrix.hh>
#include <dune/common/fmatrix.hh>
#include <dune/istl/matrixindexset.hh>
#include <dune/istl/matrix.hh>


template<class GridView, class LocalFiniteElement, class TargetSpace>
class LocalGeodesicFEStiffness
{
    // grid types
    typedef typename GridView::Grid::ctype DT;
    typedef typename TargetSpace::ctype RT;
    typedef typename GridView::template Codim<0>::Entity Entity;

    // some other sizes
    enum {gridDim=GridView::dimension};

public:

    //! Dimension of a tangent space
    enum { blocksize = TargetSpace::TangentVector::dimension };

    //! Dimension of the embedding space
    enum { embeddedBlocksize = TargetSpace::EmbeddedTangentVector::dimension };

    /** \brief Virtual destructor
     *
     * This class is a polymorphic base (energy() is pure virtual), so derived
     * objects must be safely destructible through a base-class pointer.
     */
    virtual ~LocalGeodesicFEStiffness() {}

    /** \brief Assemble the local stiffness matrix at the current position

    This default implementation uses finite-difference approximations to compute the second derivatives.
    The result is stored in A_.  Derivatives are taken with respect to the coordinates given by
    the rows of localSolution[i].orthonormalFrame() for each degree of freedom i.

    The formula for the Riemannian Hessian has been taken from Absil, Mahony, Sepulchre:
    'Optimization algorithms on matrix manifolds', page 107.  There it says that
    \f[
        \langle Hess f(x)[\xi], \eta \rangle
            = \frac 12 \frac{d^2}{dt^2} \Big(f(\exp_x(t(\xi + \eta))) - f(\exp_x(t\xi)) - f(\exp_x(t\eta))\Big)\Big|_{t=0}.
    \f]
    We compute that using a finite difference approximation.

    */
    virtual void assembleHessian(const Entity& e,
                                 const LocalFiniteElement& localFiniteElement,
                                 const std::vector<TargetSpace>& localSolution);

    /** \brief Compute the energy at the current configuration

    The default assembleGradient() and assembleHessian() implementations evaluate this
    function repeatedly on perturbed configurations.  Note that the default assembleHessian()
    calls it from inside OpenMP parallel loops, so implementations must be safe to call
    concurrently when OpenMP is enabled.
    */
    virtual RT energy (const Entity& e,
                       const LocalFiniteElement& localFiniteElement,
                       const std::vector<TargetSpace>& localSolution) const = 0;

    /** \brief Assemble the element gradient of the energy functional

    The default implementation in this class uses a central finite difference approximation.
    The i-th entry of the gradient holds the derivatives with respect to the coordinates given by
    the rows of solution[i].orthonormalFrame().  The gradient vector is resized to solution.size(). */
    virtual void assembleGradient(const Entity& element,
                                  const LocalFiniteElement& localFiniteElement,
                                  const std::vector<TargetSpace>& solution,
                                  std::vector<typename TargetSpace::TangentVector>& gradient) const;

    /** \brief Element matrix assembled by assembleHessian()

    Consists of localSolution.size() x localSolution.size() blocks, each of size blocksize x blocksize. */
    Dune::Matrix<Dune::FieldMatrix<double,blocksize,blocksize> > A_;

};


template <class GridView, class LocalFiniteElement, class TargetSpace>
void LocalGeodesicFEStiffness<GridView, LocalFiniteElement, TargetSpace>::
assembleGradient(const Entity& element,
                 const LocalFiniteElement& localFiniteElement,
                 const std::vector<TargetSpace>& localSolution,
                 std::vector<typename TargetSpace::TangentVector>& localGradient) const
{
    // Central finite differences along each orthonormal tangent direction xi of each dof:
    //   dE/dxi  ~  ( E(exp_x(h*xi)) - E(exp_x(-h*xi)) ) / (2h)
    const double h = 1e-6;

    const size_t nDofs = localSolution.size();
    localGradient.resize(nDofs);

    // Perturbed configurations.  Between outer iterations both equal localSolution;
    // only the entry of the current dof is ever changed.
    std::vector<TargetSpace> plus  = localSolution;
    std::vector<TargetSpace> minus = localSolution;

    for (size_t dof = 0; dof < nDofs; ++dof) {

        // Rows of this matrix are the basis vectors of the tangent space at localSolution[dof]
        const Dune::FieldMatrix<double,blocksize,embeddedBlocksize> frame = localSolution[dof].orthonormalFrame();

        for (int dir = 0; dir < blocksize; ++dir) {

            typename TargetSpace::EmbeddedTangentVector stepUp = frame[dir];
            stepUp *= h;

            typename TargetSpace::EmbeddedTangentVector stepDown = frame[dir];
            stepDown *= -h;

            plus[dof]  = TargetSpace::exp(localSolution[dof], stepUp);
            minus[dof] = TargetSpace::exp(localSolution[dof], stepDown);

            const RT energyUp   = energy(element, localFiniteElement, plus);
            const RT energyDown = energy(element, localFiniteElement, minus);

            localGradient[dof][dir] = (energyUp - energyDown) / (2*h);
        }

        // Undo the perturbation before moving on to the next dof
        plus[dof]  = localSolution[dof];
        minus[dof] = localSolution[dof];
    }
}


// ///////////////////////////////////////////////////////////
//   Compute Hessian by finite-difference approximation
// ///////////////////////////////////////////////////////////
template <class GridView, class LocalFiniteElement, class TargetSpace>
void LocalGeodesicFEStiffness<GridView, LocalFiniteElement, TargetSpace>::
assembleHessian(const Entity& element,
                const LocalFiniteElement& localFiniteElement,
                const std::vector<TargetSpace>& localSolution)
{
    // Number of degrees of freedom for this element
    const size_t nDofs = localSolution.size();

    // Resize and clear the element matrix
    A_.setSize(nDofs, nDofs);
    A_ = 0;

    // Finite-difference step size
    const double eps = 1e-4;

    // Orthonormal tangent frames of all dofs, pre-scaled by eps: row k of scaledFrames[i]
    // is the perturbation eps*xi for the k-th tangent direction at dof i.
    // Scaling once here avoids redoing it inside the quadruple loop below.
    std::vector<Dune::FieldMatrix<double,blocksize,embeddedBlocksize> > scaledFrames(nDofs);
    for (size_t i=0; i<nDofs; i++) {
        scaledFrames[i] = localSolution[i].orthonormalFrame();
        scaledFrames[i] *= eps;
    }

    // g(0) of the Absil et al. formula: f(x) - f(x) - f(x) = -f(x)
    // (negative because that is how it enters the 2nd-order fd formula)
    const double centerValue = -energy(element, localFiniteElement, localSolution);

    // Energies after perturbing a single dof i along a single tangent direction k:
    //   forwardEnergy[i][k]  = f(exp_x( eps*xi_ik)),
    //   backwardEnergy[i][k] = f(exp_x(-eps*xi_ik))
    std::vector<Dune::array<double,blocksize> > forwardEnergy(nDofs);
    std::vector<Dune::array<double,blocksize> > backwardEnergy(nDofs);

    // NOTE: energy() is called concurrently here when OpenMP is enabled.
    #pragma omp parallel for schedule (dynamic)
    for (size_t i=0; i<nDofs; i++) {

        // Work copies private to this iteration; only entry i is modified,
        // and it is overwritten for every direction i2.
        std::vector<TargetSpace> forwardSolution  = localSolution;
        std::vector<TargetSpace> backwardSolution = localSolution;

        for (size_t i2=0; i2<blocksize; i2++) {
            Dune::FieldVector<double,embeddedBlocksize> epsXi = scaledFrames[i][i2];
            Dune::FieldVector<double,embeddedBlocksize> minusEpsXi = epsXi;
            minusEpsXi *= -1;

            forwardSolution[i]  = TargetSpace::exp(localSolution[i], epsXi);
            backwardSolution[i] = TargetSpace::exp(localSolution[i], minusEpsXi);

            forwardEnergy[i][i2]  = energy(element, localFiniteElement, forwardSolution);
            backwardEnergy[i][i2] = energy(element, localFiniteElement, backwardSolution);
        }
    }

    // Finite-difference approximation of
    //   <Hess f(x)[xi], eta> = 1/2 g''(0),  g(t) = f(exp(t(xi+eta))) - f(exp(t xi)) - f(exp(t eta)),
    // with the central second difference  g''(0) ~ (g(eps) - 2 g(0) + g(-eps)) / eps^2.
    // We loop over the lower left triangular half of the matrix; the other half follows
    // from symmetry.  Iteration i only writes blocks (i,j) and (j,i) with j<=i, so different
    // iterations never write the same entry of A_.
    #pragma omp parallel for schedule (dynamic)
    for (size_t i=0; i<nDofs; i++) {

        // Work copies private to this iteration.  Only entries i and j are modified for each
        // evaluation, and they are restored to localSolution afterwards.
        std::vector<TargetSpace> forwardSolutionXiEta  = localSolution;
        std::vector<TargetSpace> backwardSolutionXiEta = localSolution;

        for (size_t i2=0; i2<blocksize; i2++) {
            for (size_t j=0; j<=i; j++) {
                for (size_t j2=0; j2<((i==j) ? i2+1 : blocksize); j2++) {

                    Dune::FieldVector<double,embeddedBlocksize> epsXi  = scaledFrames[i][i2];
                    Dune::FieldVector<double,embeddedBlocksize> epsEta = scaledFrames[j][j2];

                    Dune::FieldVector<double,embeddedBlocksize> minusEpsXi  = epsXi;   minusEpsXi  *= -1;
                    Dune::FieldVector<double,embeddedBlocksize> minusEpsEta = epsEta;  minusEpsEta *= -1;

                    if (i==j) {
                        // Both directions live in the same tangent space: perturb along their sum
                        forwardSolutionXiEta[i]  = TargetSpace::exp(localSolution[i], epsXi+epsEta);
                        backwardSolutionXiEta[i] = TargetSpace::exp(localSolution[i], minusEpsXi+minusEpsEta);
                    } else {
                        // Exponential map of the product manifold acts componentwise
                        forwardSolutionXiEta[i]  = TargetSpace::exp(localSolution[i], epsXi);
                        forwardSolutionXiEta[j]  = TargetSpace::exp(localSolution[j], epsEta);
                        backwardSolutionXiEta[i] = TargetSpace::exp(localSolution[i], minusEpsXi);
                        backwardSolutionXiEta[j] = TargetSpace::exp(localSolution[j], minusEpsEta);
                    }

                    // g(eps) and g(-eps)
                    const double forwardValue  = energy(element, localFiniteElement, forwardSolutionXiEta)
                                               - forwardEnergy[i][i2] - forwardEnergy[j][j2];
                    const double backwardValue = energy(element, localFiniteElement, backwardSolutionXiEta)
                                               - backwardEnergy[i][i2] - backwardEnergy[j][j2];

                    A_[i][j][i2][j2] = A_[j][i][j2][i2] = 0.5 * (forwardValue - 2*centerValue + backwardValue) / (eps*eps);

                    // Restore the work copies for the next evaluation
                    forwardSolutionXiEta[i]  = localSolution[i];
                    forwardSolutionXiEta[j]  = localSolution[j];
                    backwardSolutionXiEta[i] = localSolution[i];
                    backwardSolutionXiEta[j] = localSolution[j];
                }
            }
        }
    }
}

#endif