// -*- tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*-
// vi: set et ts=4 sw=2 sts=2:
#include <config.h>

#include <array>
#include <cmath>
#include <vector>

#include <dune/common/indices.hh>

#include <dune/geometry/quadraturerules.hh>

#include <dune/grid/yaspgrid.hh>

#include <dune/istl/matrix.hh>
#include <dune/istl/bcrsmatrix.hh>
#include <dune/istl/matrixindexset.hh>
#include <dune/istl/solvers.hh>
#include <dune/istl/preconditioners.hh>

#include <dune/functions/functionspacebases/interpolate.hh>
#include <dune/functions/backends/istlvectorbackend.hh>
#include <dune/functions/functionspacebases/compositebasis.hh>
#include <dune/functions/functionspacebases/lagrangebasis.hh>
#include <dune/functions/functionspacebases/subspacebasis.hh>

#include <dune/functions/functionspacebases/scalarbasis.hh>

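/**
 * This example solves the Poisson equation with homogeneous Neumann
 * boundary conditions. The solution is made unique by requiring u to have
 * zero mean:
 *
 * ```
 *    -\Delta u = f   in the domain,
 *   \int u \, dx = 0,
 * \partial_n u = 0   on the boundary.
 * ```
 *
 * The zero-mean condition is enforced with a Lagrange multiplier. The
 * multiplier is discretized by the scalar basis, which is the second
 * component of the composite basis.
 **/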

using namespace Dune;

template <class LocalView, class ElementMatrix>
void getLocalMatrix(const LocalView& localView, ElementMatrix& elementMatrix)
{
  using Element = typename LocalView::Element;
  const Element element = localView.element();

  const int dim = Element::dimension;
  auto geometry = element.geometry();

  elementMatrix = 0;

  using namespace Indices;
  const auto& uLocalFE = localView.tree().child(_0).finiteElement();
  const auto& sLocalFE = localView.tree().child(_1).finiteElement();

  int order = 2*(dim*uLocalFE.localBasis().order()-1);
  const auto& quad = QuadratureRules<double, dim>::rule(element.type(), order);

  for (const auto& qp : quad)
  {
    const auto JinvT = geometry.jacobianInverseTransposed(qp.position());
    const auto dx = geometry.integrationElement(qp.position()) * qp.weight();

    thread_local std::vector<FieldMatrix<double,1,dim> > refGrads;
    uLocalFE.localBasis().evaluateJacobian(qp.position(), refGrads);

    thread_local std::vector<FieldVector<double,dim> > gradients;
    gradients.resize(refGrads.size());
    for (std::size_t i=0; i<gradients.size(); i++)
      JinvT.mv(refGrads[i][0], gradients[i]);

    // laplace(u)
    for (std::size_t i=0; i<uLocalFE.size(); i++)
      for (std::size_t j=0; j<uLocalFE.size(); j++)
      {
        std::size_t row = localView.tree().child(_0).localIndex(i);
        std::size_t col = localView.tree().child(_0).localIndex(j);
        elementMatrix[row][col] += (gradients[i] * gradients[j]) * dx;
      }

    thread_local std::vector<FieldVector<double,1> > uValues;
    uLocalFE.localBasis().evaluateFunction(qp.position(), uValues);

    // int(u) == 0 condition
    for (std::size_t i=0; i<uLocalFE.size(); i++)
      for (std::size_t j=0; j<sLocalFE.size(); j++)
      {
        std::size_t uIndex = localView.tree().child(_0).localIndex(i);
        std::size_t sIndex = localView.tree().child(_1).localIndex(j);

        elementMatrix[sIndex][uIndex] += uValues[i] * dx;
        elementMatrix[uIndex][sIndex] += uValues[i] * dx;
      }
  }
}


template <class Basis, class MatrixType>
void setOccupationPattern(const Basis& basis, MatrixType& matrix)
{
  MatrixIndexSet nb(basis.dimension(), basis.dimension());

  auto localView = basis.localView();
  for(const auto& element : elements(basis.gridView()))
  {
    localView.bind(element);

    for (std::size_t i=0; i<localView.size(); i++) {
      auto row = localView.index(i);
      for (std::size_t j=0; j<localView.size(); j++) {
        auto col = localView.index(j);
        nb.add(row[0],col[0]);
      }
    }
  }

  // Give the matrix the occupation pattern we want.
  nb.exportIdx(matrix);
}


/**
 * Access the matrix entry addressed by a pair of global multi-indices.
 *
 * Only the first component of each multi-index is used (flat indexing).
 * The result is whatever matrix[r][c] yields, usually a reference, so it
 * can be assigned to or accumulated into.
 */
template<class Matrix, class MultiIndex>
decltype(auto) matrixEntry(Matrix& matrix, const MultiIndex& row, const MultiIndex& col)
{
  const auto r = row[0];
  const auto c = col[0];
  return matrix[r][c];
}


/**
 * Assemble the global system matrix.
 *
 * The sparsity pattern is set up by setOccupationPattern. The element
 * matrices computed by getLocalMatrix are then added into the global matrix
 * using the local-to-global index map of the bound local view.
 *
 * \param basis   the global basis (assumed to have the layout getLocalMatrix
 *                expects: child _0 = u, child _1 = multiplier)
 * \param matrix  output matrix; its pattern and values are overwritten
 */
template <class Basis, class MatrixType>
void assembleMatrix(const Basis& basis, MatrixType& matrix)
{
  setOccupationPattern(basis, matrix);
  matrix = 0;

  auto localView = basis.localView();

  // Allocated once with the largest possible local size and reused for
  // every element. Only the leading localView.size() x localView.size()
  // part is meaningful for a given element.
  Matrix<FieldMatrix<double,1,1> > elementMatrix;
  elementMatrix.setSize(localView.maxSize(), localView.maxSize());

  for (const auto& element : elements(basis.gridView()))
  {
    localView.bind(element);

    getLocalMatrix(localView, elementMatrix);

    // Loop over the DOFs of the bound element, not over maxSize().
    // Only i < localView.size() is a valid argument to localView.index(i).
    const std::size_t localSize = localView.size();
    for (std::size_t i=0; i<localSize; i++) {
      const auto row = localView.index(i);
      for (std::size_t j=0; j<localSize; j++) {
        const auto col = localView.index(j);
        matrixEntry(matrix, row, col) += elementMatrix[i][j];
      }
    }
  }
}


int main (int argc, char *argv[])
{
  MPIHelper::instance(argc, argv);

  const int dim = 2;
  using GridType = YaspGrid<dim>;
  GridType grid({1.0, 1.0}, {4, 4});

  using GridView = typename GridType::LeafGridView;
  GridView gridView = grid.leafGridView();


  using namespace Functions::BasisFactory;
  auto basis = makeBasis(gridView,
    composite(lagrange<1>(), scalar(), flatLexicographic() ));

  using VectorType = BlockVector<FieldVector<double,1> >;
  using MatrixType = BCRSMatrix<FieldMatrix<double,1,1> >;

  VectorType rhs;
  auto rhsBackend = Dune::Functions::istlVectorBackend(rhs);
  rhsBackend.resize(basis);
  rhs = 0;

  using namespace Indices;
  Functions::interpolate(
    Functions::subspaceBasis(basis, _0), rhs,
    [](auto const& x) { return std::sin(x[0]*6*M_PI) + std::cos(x[1]*4*M_PI); });

  MatrixType stiffnessMatrix;
  assembleMatrix(basis, stiffnessMatrix);

  MatrixAdapter<MatrixType,VectorType,VectorType> stiffnessOperator(stiffnessMatrix);
  Richardson<VectorType,VectorType> preconditioner(1.0);
  RestartedGMResSolver<VectorType> solver(
          stiffnessOperator,  // operator to invert
          preconditioner,     // preconditioner for interation
          1e-10,              // desired residual reduction factor
          500,                // number of iterations between restarts
          500,                // maximum number of iterations
          2);                 // verbosity of the solver

  InverseOperatorResult statistics;
  VectorType x = rhs;
  solver.apply(x, rhs, statistics);
 }